├── .clang-format
├── .clangd_template
├── .github
└── workflows
│ ├── building.yml
│ ├── core_tests.yml
│ └── cuda
│ ├── Linux-env.sh
│ ├── Linux.sh
│ ├── Windows-env.sh
│ └── Windows.sh
├── .gitignore
├── .gitmodules
├── LICENSE
├── MANIFEST.in
├── README.md
├── assets
└── test_garden.npz
├── docs
└── _static
│ └── imgs
│ └── front-fig-stacked.jpg
├── examples
├── lidar_rendering.ipynb
└── rolling_shutter.ipynb
├── gsplat
├── __init__.py
├── _helper.py
├── compression
│ ├── __init__.py
│ ├── png_compression.py
│ └── sort.py
├── cuda
│ ├── __init__.py
│ ├── _backend.py
│ ├── _torch_impl.py
│ ├── _wrapper.py
│ └── csrc
│ │ ├── bindings.h
│ │ ├── ext.cpp
│ │ ├── helpers.cuh
│ │ ├── projection.cu
│ │ ├── rasterization.cu
│ │ ├── relocation.cu
│ │ ├── sh.cu
│ │ └── utils.cuh
├── cuda_legacy
│ ├── __init__.py
│ ├── _backend.py
│ ├── _torch_impl.py
│ ├── _wrapper.py
│ └── csrc
│ │ ├── CMakeLists.txt
│ │ ├── backward.cu
│ │ ├── backward.cuh
│ │ ├── bindings.cu
│ │ ├── bindings.h
│ │ ├── config.h
│ │ ├── ext.cpp
│ │ ├── forward.cu
│ │ ├── forward.cuh
│ │ ├── helpers.cuh
│ │ └── sh.cuh
├── distributed.py
├── profile.py
├── relocation.py
├── rendering.py
├── strategy
│ ├── __init__.py
│ ├── base.py
│ ├── default.py
│ ├── mcmc.py
│ └── ops.py
├── utils.py
└── version.py
├── setup.py
└── tests
├── _test_distributed.py
├── test_basic.py
├── test_compression.py
├── test_rasterization.py
└── test_strategy.py
/.clang-format:
--------------------------------------------------------------------------------
1 | BasedOnStyle: LLVM
2 | AlignAfterOpenBracket: BlockIndent
3 | BinPackArguments: false
4 | BinPackParameters: false
5 | IndentWidth: 4
6 |
--------------------------------------------------------------------------------
/.clangd_template:
--------------------------------------------------------------------------------
1 | # To set up the clangd config.
2 | # 1. Activate an environment with cuda installed. (probably via conda, skip if using system CUDA)
3 | # 2. Run:
4 | # echo "# Autogenerated, see .clangd_template\!" > .clangd && sed -e "/^#/d" -e "s|YOUR_CUDA_PATH|$(dirname $(dirname $(which nvcc)))|" .clangd_template >> .clangd
5 | CompileFlags:
6 | Add:
7 | - -Xclang
8 | - -fcuda-allow-variadic-functions
9 | - --cuda-path=YOUR_CUDA_PATH
10 | Remove:
11 | - --diag_suppress=*
12 | - --generate-code=*
13 | - -gencode=*
14 | - -forward-unknown-to-host-compiler
15 | - -Xcompiler
16 | - -Xcudafe
17 | - --use_fast_math
18 | - --options-file
19 | - --compiler-options
20 | - --expt-relaxed-constexpr
21 |
--------------------------------------------------------------------------------
/.github/workflows/building.yml:
--------------------------------------------------------------------------------
1 | name: Building Wheels
2 |
3 | on: [workflow_dispatch]
4 |
5 | jobs:
6 |
7 | wheel:
8 | runs-on: ${{ matrix.os }}
9 | environment: production
10 |
11 | strategy:
12 | fail-fast: false
13 | matrix:
14 | # os: [windows-2019]
15 | # python-version: ['3.7']
16 | # torch-version: [1.11.0]
17 | # cuda-version: ['cu113']
18 | os: [ubuntu-20.04, windows-2019]
19 | # support version based on: https://download.pytorch.org/whl/torch/
20 | python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
21 | torch-version: [1.11.0, 1.12.0, 1.13.0, 2.0.0]
22 | cuda-version: ['cu113', 'cu115', 'cu116', 'cu117', 'cu118']
23 | exclude:
24 | - torch-version: 1.11.0
25 | python-version: '3.11'
26 | - torch-version: 1.11.0
27 | cuda-version: 'cu116'
28 | - torch-version: 1.11.0
29 | cuda-version: 'cu117'
30 | - torch-version: 1.11.0
31 | cuda-version: 'cu118'
32 |
33 | - torch-version: 1.12.0
34 | python-version: '3.11'
35 | - torch-version: 1.12.0
36 | cuda-version: 'cu115'
37 | - torch-version: 1.12.0
38 | cuda-version: 'cu117'
39 | - torch-version: 1.12.0
40 | cuda-version: 'cu118'
41 |
42 | - torch-version: 1.13.0
43 | cuda-version: 'cu102'
44 | - torch-version: 1.13.0
45 | cuda-version: 'cu113'
46 | - torch-version: 1.13.0
47 | cuda-version: 'cu115'
48 | - torch-version: 1.13.0
49 | cuda-version: 'cu118'
50 |
51 | - torch-version: 2.0.0
52 | python-version: '3.7'
53 | - torch-version: 2.0.0
54 | cuda-version: 'cu102'
55 | - torch-version: 2.0.0
56 | cuda-version: 'cu113'
57 | - torch-version: 2.0.0
58 | cuda-version: 'cu115'
59 | - torch-version: 2.0.0
60 | cuda-version: 'cu116'
61 |
62 | - os: windows-2019
63 | cuda-version: 'cu102'
64 | - os: windows-2019
65 | torch-version: 1.13.0
66 | python-version: '3.11'
67 |
68 | # - os: windows-2019
69 | # torch-version: 1.13.0
70 | # cuda-version: 'cu117'
71 | # python-version: '3.9'
72 |
73 |
74 |
75 | steps:
76 | - uses: actions/checkout@v3
77 | with:
78 | submodules: 'recursive'
79 |
80 | - name: Set up Python ${{ matrix.python-version }}
81 | uses: actions/setup-python@v4
82 | with:
83 | python-version: ${{ matrix.python-version }}
84 |
85 | - name: Upgrade pip
86 | run: |
87 | pip install --upgrade setuptools
88 | pip install ninja
89 |
90 | - name: Free up disk space
91 | if: ${{ runner.os == 'Linux' }}
92 | run: |
93 | sudo rm -rf /usr/share/dotnet
94 |
95 | - name: Install CUDA ${{ matrix.cuda-version }}
96 | if: ${{ matrix.cuda-version != 'cpu' }}
97 | run: |
98 | bash .github/workflows/cuda/${{ runner.os }}.sh ${{ matrix.cuda-version }}
99 |
100 | - name: Install PyTorch ${{ matrix.torch-version }}+${{ matrix.cuda-version }}
101 | run: |
102 | pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/${{ matrix.cuda-version }}
103 | python -c "import torch; print('PyTorch:', torch.__version__)"
104 | python -c "import torch; print('CUDA:', torch.version.cuda)"
105 | python -c "import torch; print('CUDA Available:', torch.cuda.is_available())"
106 |
107 | - name: Patch PyTorch static constexpr on Windows
108 | if: ${{ runner.os == 'Windows' }}
109 | run: |
110 | Torch_DIR=`python -c 'import os; import torch; print(os.path.dirname(torch.__file__))'`
111 | sed -i '31,38c\
112 | TORCH_API void lazy_init_num_threads();' ${Torch_DIR}/include/ATen/Parallel.h
113 | shell: bash
114 |
115 | - name: Set version
116 | if: ${{ runner.os != 'macOS' }}
117 | run: |
118 | VERSION=`sed -n 's/^__version__ = "\(.*\)"/\1/p' gsplat/version.py`
119 | TORCH_VERSION=`echo "pt${{ matrix.torch-version }}" | sed "s/..$//" | sed "s/\.//g"`
120 | CUDA_VERSION=`echo ${{ matrix.cuda-version }}`
121 | echo "New version name: $VERSION+$TORCH_VERSION$CUDA_VERSION"
122 | sed -i "s/$VERSION/$VERSION+$TORCH_VERSION$CUDA_VERSION/" gsplat/version.py
123 | shell:
124 | bash
125 |
126 | - name: Install main package for CPU
127 | if: ${{ matrix.cuda-version == 'cpu' }}
128 | run: |
129 | BUILD_NO_CUDA=1 pip install .
130 |
131 | - name: Build wheel
132 | run: |
133 | pip install wheel
134 | source .github/workflows/cuda/${{ runner.os }}-env.sh ${{ matrix.cuda-version }}
135 | python setup.py bdist_wheel --dist-dir=dist
136 | shell: bash # `source` does not exist in windows powershell
137 |
138 | - name: Test wheel
139 | run: |
140 | cd dist
141 | ls -lah
142 | pip install *.whl
143 | python -c "import gsplat; print('gsplat:', gsplat.__version__)"
144 | cd ..
145 | shell: bash # `ls -lah` does not exist in windows powershell
146 |
--------------------------------------------------------------------------------
/.github/workflows/core_tests.yml:
--------------------------------------------------------------------------------
1 | name: Core Tests.
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | branches: [main]
8 |
9 | permissions:
10 | contents: read
11 |
12 | jobs:
13 | build:
14 | runs-on: ubuntu-latest
15 |
16 | steps:
17 | - uses: actions/checkout@v3
18 | with:
19 | submodules: 'recursive'
20 |
21 | - name: Set up Python 3.8.12
22 | uses: actions/setup-python@v4
23 | with:
24 | python-version: "3.8.12"
25 | - name: Install dependencies
26 | run: |
27 | pip install pytest
28 | pip install torch==2.0.0 --index-url https://download.pytorch.org/whl/cpu
29 | BUILD_NO_CUDA=1 pip install .
30 | - name: Run Tests.
31 | run: pytest tests/
32 |
--------------------------------------------------------------------------------
/.github/workflows/cuda/Linux-env.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Taken from https://github.com/pyg-team/pyg-lib/
4 |
5 | case ${1} in
6 | cu118)
7 | export CUDA_HOME=/usr/local/cuda-11.8
8 | export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
9 | export PATH=${CUDA_HOME}/bin:${PATH}
10 | export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
11 | ;;
12 | cu117)
13 | export CUDA_HOME=/usr/local/cuda-11.7
14 | export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
15 | export PATH=${CUDA_HOME}/bin:${PATH}
16 | export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
17 | ;;
18 | cu116)
19 | export CUDA_HOME=/usr/local/cuda-11.6
20 | export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
21 | export PATH=${CUDA_HOME}/bin:${PATH}
22 | export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
23 | ;;
24 | cu115)
25 | export CUDA_HOME=/usr/local/cuda-11.5
26 | export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
27 | export PATH=${CUDA_HOME}/bin:${PATH}
28 | export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
29 | ;;
30 | cu113)
31 | export CUDA_HOME=/usr/local/cuda-11.3
32 | export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
33 | export PATH=${CUDA_HOME}/bin:${PATH}
34 | export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
35 | ;;
36 | cu102)
37 | export CUDA_HOME=/usr/local/cuda-10.2
38 | export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
39 | export PATH=${CUDA_HOME}/bin:${PATH}
40 | export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5"
41 | ;;
42 | *)
43 | ;;
44 | esac
--------------------------------------------------------------------------------
/.github/workflows/cuda/Linux.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Taken from https://github.com/pyg-team/pyg-lib/
4 |
5 | OS=ubuntu2004
6 |
7 | case ${1} in
8 | cu118)
9 | CUDA=11.8
10 | APT_KEY=${OS}-${CUDA/./-}-local
11 | FILENAME=cuda-repo-${APT_KEY}_${CUDA}.0-520.61.05-1_amd64.deb
12 | URL=https://developer.download.nvidia.com/compute/cuda/${CUDA}.0/local_installers
13 | ;;
14 | cu117)
15 | CUDA=11.7
16 | APT_KEY=${OS}-${CUDA/./-}-local
17 | FILENAME=cuda-repo-${APT_KEY}_${CUDA}.1-515.65.01-1_amd64.deb
18 | URL=https://developer.download.nvidia.com/compute/cuda/${CUDA}.1/local_installers
19 | ;;
20 | cu116)
21 | CUDA=11.6
22 | APT_KEY=${OS}-${CUDA/./-}-local
23 | FILENAME=cuda-repo-${APT_KEY}_${CUDA}.2-510.47.03-1_amd64.deb
24 | URL=https://developer.download.nvidia.com/compute/cuda/${CUDA}.2/local_installers
25 | ;;
26 | cu115)
27 | CUDA=11.5
28 | APT_KEY=${OS}-${CUDA/./-}-local
29 | FILENAME=cuda-repo-${APT_KEY}_${CUDA}.2-495.29.05-1_amd64.deb
30 | URL=https://developer.download.nvidia.com/compute/cuda/${CUDA}.2/local_installers
31 | ;;
32 | cu113)
33 | CUDA=11.3
34 | APT_KEY=${OS}-${CUDA/./-}-local
35 | FILENAME=cuda-repo-${APT_KEY}_${CUDA}.0-465.19.01-1_amd64.deb
36 | URL=https://developer.download.nvidia.com/compute/cuda/${CUDA}.0/local_installers
37 | ;;
38 | cu102)
39 | CUDA=10.2
40 | APT_KEY=${CUDA/./-}-local-${CUDA}.89-440.33.01
41 | FILENAME=cuda-repo-${OS}-${APT_KEY}_1.0-1_amd64.deb
42 | URL=https://developer.download.nvidia.com/compute/cuda/${CUDA}/Prod/local_installers
43 | ;;
44 | *)
45 | echo "Unrecognized CUDA_VERSION=${1}"
46 | exit 1
47 | ;;
48 | esac
49 |
50 | wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin
51 | sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600
52 | wget -nv ${URL}/${FILENAME}
53 | sudo dpkg -i ${FILENAME}
54 |
55 | if [ "${1}" = "cu117" ] || [ "${1}" = "cu118" ]; then
56 | sudo cp /var/cuda-repo-${APT_KEY}/cuda-*-keyring.gpg /usr/share/keyrings/
57 | else
58 | sudo apt-key add /var/cuda-repo-${APT_KEY}/7fa2af80.pub
59 | fi
60 |
61 | sudo apt-get update
62 | sudo apt-get -y install cuda
63 |
64 | rm -f ${FILENAME}
--------------------------------------------------------------------------------
/.github/workflows/cuda/Windows-env.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Taken from https://github.com/pyg-team/pyg-lib/
4 |
5 | case ${1} in
6 | cu118)
7 | CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v11.8
8 | PATH=${CUDA_HOME}/bin:$PATH
9 | PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH
10 | export TORCH_CUDA_ARCH_LIST="6.0+PTX"
11 | ;;
12 | cu117)
13 | CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v11.7
14 | PATH=${CUDA_HOME}/bin:$PATH
15 | PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH
16 | export TORCH_CUDA_ARCH_LIST="6.0+PTX"
17 | ;;
18 | cu116)
19 | CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v11.6
20 | PATH=${CUDA_HOME}/bin:$PATH
21 | PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH
22 | export TORCH_CUDA_ARCH_LIST="6.0+PTX"
23 | ;;
24 | cu115)
25 | CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v11.5
26 | PATH=${CUDA_HOME}/bin:$PATH
27 | PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH
28 | export TORCH_CUDA_ARCH_LIST="6.0+PTX"
29 | ;;
30 | cu113)
31 | CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v11.3
32 | PATH=${CUDA_HOME}/bin:$PATH
33 | PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH
34 | export TORCH_CUDA_ARCH_LIST="6.0+PTX"
35 | ;;
36 | *)
37 | ;;
38 | esac
--------------------------------------------------------------------------------
/.github/workflows/cuda/Windows.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Taken from https://github.com/pyg-team/pyg-lib/
4 |
5 | # Install NVIDIA drivers, see:
6 | # https://github.com/pytorch/vision/blob/master/packaging/windows/internal/cuda_install.bat#L99-L102
7 | curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "/tmp/gpu_driver_dlls.zip"
8 | 7z x "/tmp/gpu_driver_dlls.zip" -o"/c/Windows/System32"
9 |
10 | case ${1} in
11 | cu118)
12 | CUDA_SHORT=11.8
13 | CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.0/local_installers
14 | CUDA_FILE=cuda_${CUDA_SHORT}.0_522.06_windows.exe
15 | ;;
16 | cu117)
17 | CUDA_SHORT=11.7
18 | CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.1/local_installers
19 | CUDA_FILE=cuda_${CUDA_SHORT}.1_516.94_windows.exe
20 | ;;
21 | cu116)
22 | CUDA_SHORT=11.3
23 | CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.0/local_installers
24 | CUDA_FILE=cuda_${CUDA_SHORT}.0_465.89_win10.exe
25 | ;;
26 | cu115)
27 | CUDA_SHORT=11.3
28 | CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.0/local_installers
29 | CUDA_FILE=cuda_${CUDA_SHORT}.0_465.89_win10.exe
30 | ;;
31 | cu113)
32 | CUDA_SHORT=11.3
33 | CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.0/local_installers
34 | CUDA_FILE=cuda_${CUDA_SHORT}.0_465.89_win10.exe
35 | ;;
36 | *)
37 | echo "Unrecognized CUDA_VERSION=${1}"
38 | exit 1
39 | ;;
40 | esac
41 |
42 | curl -k -L "${CUDA_URL}/${CUDA_FILE}" --output "${CUDA_FILE}"
43 | echo ""
44 | echo "Installing from ${CUDA_FILE}..."
45 | PowerShell -Command "Start-Process -FilePath \"${CUDA_FILE}\" -ArgumentList \"-s nvcc_${CUDA_SHORT} cuobjdump_${CUDA_SHORT} nvprune_${CUDA_SHORT} cupti_${CUDA_SHORT} cublas_dev_${CUDA_SHORT} cudart_${CUDA_SHORT} cufft_dev_${CUDA_SHORT} curand_dev_${CUDA_SHORT} cusolver_dev_${CUDA_SHORT} cusparse_dev_${CUDA_SHORT} thrust_${CUDA_SHORT} npp_dev_${CUDA_SHORT} nvrtc_dev_${CUDA_SHORT} nvml_dev_${CUDA_SHORT}\" -Wait -NoNewWindow"
46 | echo "Done!"
47 | rm -f "${CUDA_FILE}"
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .clangd
2 | compile_commands.json
3 |
4 | # Visual Studio Code configs.
5 | .vscode/
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | # lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # pyenv
82 | .python-version
83 |
84 | # celery beat schedule file
85 | celerybeat-schedule
86 |
87 | # SageMath parsed files
88 | *.sage.py
89 |
90 | # Environments
91 | .env
92 | .venv
93 | env/
94 | venv/
95 | ENV/
96 | env.bak/
97 | venv.bak/
98 |
99 | # Spyder project settings
100 | .spyderproject
101 | .spyproject
102 |
103 | # Rope project settings
104 | .ropeproject
105 |
106 | # mkdocs documentation
107 | /site
108 |
109 | # mypy
110 | .mypy_cache/
111 |
112 | .DS_Store
113 |
114 | # Direnv config.
115 | .envrc
116 |
117 | # line_profiler
118 | *.lprof
119 |
120 | *build
121 | compile_commands.json
122 | *.dump
123 |
124 | data
125 | results
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "gsplat/cuda/csrc/third_party/glm"]
2 | path = gsplat/cuda/csrc/third_party/glm
3 | url = https://github.com/g-truc/glm.git
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | recursive-include gsplat/cuda/csrc *
2 | recursive-include gsplat/cuda_legacy/csrc *
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
SplatAD
12 |
13 | Real-Time Lidar and Camera Rendering with 3D Gaussian Splatting for Autonomous Driving
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 | [Project page](https://research.zenseact.com/publications/splatad/)
26 |
27 |
28 |
29 |
30 | # About
31 | This is the official repository for the CVPR 2025 paper [_SplatAD: Real-Time Lidar and Camera Rendering with 3D Gaussian Splatting for Autonomous Driving_](https://arxiv.org/abs/2411.16816). The code in this repository builds upon the open-source library [gsplat](https://github.com/nerfstudio-project/gsplat), with modifications and extensions designed for autonomous driving data.
32 |
33 | While the code contains all components needed to efficiently render camera and lidar data, the SplatAD-model itself, including dataloading, decoders, etc., will be released through [neurad-studio](https://github.com/georghess/neurad-studio).
34 |
35 | **We welcome all contributions!**
36 |
37 | # Key Features
38 | - Efficient lidar rendering
39 | - Projection to spherical coordinates
40 | - Depth and feature rasterization for a non-linear grid of points
41 | - Rolling shutter compensation for camera and lidar
42 |
43 |
44 | # Installation
45 | Our code introduces no additional dependencies. We thus refer to the original documentation from gsplat for both [installation](https://github.com/nerfstudio-project/gsplat#installation) and [development setup](https://github.com/nerfstudio-project/gsplat/blob/main/docs/DEV.md).
46 |
47 | # Usage
48 | See [`rasterization`](gsplat/rendering.py#L22) and [`lidar_rasterization`](gsplat/rendering.py#L443) for entry points to camera and lidar rasterization.
49 | Additionally, we provide example notebooks under [examples](examples) that demonstrate lidar rendering and rolling shutter compensation.
50 | For further examples, check out the [test files](tests).
51 |
52 |
53 | # Built On
54 | - [gsplat](https://github.com/nerfstudio-project/gsplat) - Collaboration friendly library for CUDA accelerated rasterization of Gaussians with python bindings
55 | - [3dgs-deblur](https://github.com/SpectacularAI/3dgs-deblur) - Inspiration for the rolling shutter compensation
56 |
57 | # Citation
58 |
59 | You can find our paper on [arXiv](https://arxiv.org/abs/2411.16816).
60 |
61 | If you use this code or find our paper useful, please consider citing:
62 |
63 | ```bibtex
64 | @article{hess2024splatad,
65 | title={SplatAD: Real-Time Lidar and Camera Rendering with 3D Gaussian Splatting for Autonomous Driving},
66 | author={Hess, Georg and Lindstr{\"o}m, Carl and Fatemi, Maryam and Petersson, Christoffer and Svensson, Lennart},
67 | journal={arXiv preprint arXiv:2411.16816},
68 | year={2024}
69 | }
70 | ```
71 |
72 | # Contributors
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 | \+ [gsplat contributors](https://github.com/nerfstudio-project/gsplat/graphs/contributors)
--------------------------------------------------------------------------------
/assets/test_garden.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carlinds/splatad/5d3c9bd4b856f142c707a9c0b78161f555bacb10/assets/test_garden.npz
--------------------------------------------------------------------------------
/docs/_static/imgs/front-fig-stacked.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carlinds/splatad/5d3c9bd4b856f142c707a9c0b78161f555bacb10/docs/_static/imgs/front-fig-stacked.jpg
--------------------------------------------------------------------------------
/gsplat/__init__.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | from .cuda._torch_impl import accumulate
4 | from .cuda._wrapper import (
5 | fully_fused_projection,
6 | isect_offset_encode,
7 | isect_tiles,
8 | persp_proj,
9 | quat_scale_to_covar_preci,
10 | rasterize_to_indices_in_range,
11 | rasterize_to_pixels,
12 | rasterize_to_points,
13 | spherical_harmonics,
14 | world_to_cam,
15 | map_points_to_lidar_tiles,
16 | points_mapping_offset_encode,
17 | populate_image_from_points,
18 | )
19 | from .rendering import (
20 | rasterization,
21 | rasterization_inria_wrapper,
22 | rasterization_legacy_wrapper,
23 | )
24 | from .version import __version__
25 |
26 |
27 | def rasterize_gaussians(*args, **kwargs):
28 | from .cuda_legacy._wrapper import rasterize_gaussians
29 |
30 | warnings.warn(
31 | "'rasterize_gaussians is deprecated and will be removed in a future release. "
32 | "Use gsplat.rasterization for end-to-end rasterizing GSs to images instead.",
33 | DeprecationWarning,
34 | )
35 | return rasterize_gaussians(*args, **kwargs)
36 |
37 |
38 | def project_gaussians(*args, **kwargs):
39 | from .cuda_legacy._wrapper import project_gaussians
40 |
41 | warnings.warn(
42 | "'project_gaussians is deprecated and will be removed in a future release. "
43 | "Use gsplat.rasterization for end-to-end rasterizing GSs to images instead.",
44 | DeprecationWarning,
45 | )
46 | return project_gaussians(*args, **kwargs)
47 |
48 |
49 | def map_gaussian_to_intersects(*args, **kwargs):
50 | from .cuda_legacy._wrapper import map_gaussian_to_intersects
51 |
52 | warnings.warn(
53 | "'map_gaussian_to_intersects is deprecated and will be removed in a future release. "
54 | "Use gsplat.rasterization for end-to-end rasterizing GSs to images instead.",
55 | DeprecationWarning,
56 | )
57 | return map_gaussian_to_intersects(*args, **kwargs)
58 |
59 |
60 | def bin_and_sort_gaussians(*args, **kwargs):
61 | from .cuda_legacy._wrapper import bin_and_sort_gaussians
62 |
63 | warnings.warn(
64 | "'bin_and_sort_gaussians is deprecated and will be removed in a future release. "
65 | "Use gsplat.rasterization for end-to-end rasterizing GSs to images instead.",
66 | DeprecationWarning,
67 | )
68 | return bin_and_sort_gaussians(*args, **kwargs)
69 |
70 |
71 | def compute_cumulative_intersects(*args, **kwargs):
72 | from .cuda_legacy._wrapper import compute_cumulative_intersects
73 |
74 | warnings.warn(
75 | "'compute_cumulative_intersects is deprecated and will be removed in a future release. "
76 | "Use gsplat.rasterization for end-to-end rasterizing GSs to images instead.",
77 | DeprecationWarning,
78 | )
79 | return compute_cumulative_intersects(*args, **kwargs)
80 |
81 |
82 | def compute_cov2d_bounds(*args, **kwargs):
83 | from .cuda_legacy._wrapper import compute_cov2d_bounds
84 |
85 | warnings.warn(
86 | "'compute_cov2d_bounds is deprecated and will be removed in a future release. "
87 | "Use gsplat.rasterization for end-to-end rasterizing GSs to images instead.",
88 | DeprecationWarning,
89 | )
90 | return compute_cov2d_bounds(*args, **kwargs)
91 |
92 |
93 | def get_tile_bin_edges(*args, **kwargs):
94 | from .cuda_legacy._wrapper import get_tile_bin_edges
95 |
96 | warnings.warn(
97 | "'get_tile_bin_edges is deprecated and will be removed in a future release. "
98 | "Use gsplat.rasterization for end-to-end rasterizing GSs to images instead.",
99 | DeprecationWarning,
100 | )
101 | return get_tile_bin_edges(*args, **kwargs)
102 |
103 |
104 | all = [
105 | "rasterization",
106 | "rasterization_legacy_wrapper",
107 | "rasterization_inria_wrapper",
108 | "spherical_harmonics",
109 | "isect_offset_encode",
110 | "isect_tiles",
111 | "isect_lidar_tiles",
112 | "map_points_to_lidar_tiles",
113 | "points_mapping_offset_encode",
114 | "populate_image_from_points",
115 | "persp_proj",
116 | "fully_fused_projection",
117 | "quat_scale_to_covar_preci",
118 | "rasterize_to_pixels",
119 | "rasterize_to_points",
120 | "world_to_cam",
121 | "accumulate",
122 | "rasterize_to_indices_in_range",
123 | "rasterize_to_indices_in_range_lidar",
124 | "__version__",
125 | # deprecated
126 | "rasterize_gaussians",
127 | "project_gaussians",
128 | "map_gaussian_to_intersects",
129 | "bin_and_sort_gaussians",
130 | "compute_cumulative_intersects",
131 | "compute_cov2d_bounds",
132 | "get_tile_bin_edges",
133 | ]
134 |
--------------------------------------------------------------------------------
/gsplat/_helper.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Optional, Tuple
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 |
8 |
9 | def load_test_data(
10 | data_path: Optional[str] = None,
11 | device="cuda",
12 | scene_crop: Tuple[float, float, float, float, float, float] = (-2, -2, -2, 2, 2, 2),
13 | scene_grid: int = 1,
14 | ):
15 | """Load the test data."""
16 | assert scene_grid % 2 == 1, "scene_grid must be odd"
17 |
18 | if data_path is None:
19 | data_path = os.path.join(os.path.dirname(__file__), "../assets/test_garden.npz")
20 | data = np.load(data_path)
21 | height, width = data["height"].item(), data["width"].item()
22 | viewmats = torch.from_numpy(data["viewmats"]).float().to(device)
23 | Ks = torch.from_numpy(data["Ks"]).float().to(device)
24 | means = torch.from_numpy(data["means3d"]).float().to(device)
25 | colors = torch.from_numpy(data["colors"] / 255.0).float().to(device)
26 | C = len(viewmats)
27 |
28 | # crop
29 | aabb = torch.tensor(scene_crop, device=device)
30 | edges = aabb[3:] - aabb[:3]
31 | sel = ((means >= aabb[:3]) & (means <= aabb[3:])).all(dim=-1)
32 | sel = torch.where(sel)[0]
33 | means, colors = means[sel], colors[sel]
34 |
35 | # repeat the scene into a grid (to mimic a large-scale setting)
36 | repeats = scene_grid
37 | gridx, gridy = torch.meshgrid(
38 | [
39 | torch.arange(-(repeats // 2), repeats // 2 + 1, device=device),
40 | torch.arange(-(repeats // 2), repeats // 2 + 1, device=device),
41 | ],
42 | indexing="ij",
43 | )
44 | grid = torch.stack([gridx, gridy, torch.zeros_like(gridx)], dim=-1).reshape(-1, 3)
45 | means = means[None, :, :] + grid[:, None, :] * edges[None, None, :]
46 | means = means.reshape(-1, 3)
47 | colors = colors.repeat(repeats**2, 1)
48 |
49 | # create gaussian attributes
50 | N = len(means)
51 | scales = torch.rand((N, 3), device=device) * 0.02
52 | quats = F.normalize(torch.randn((N, 4), device=device), dim=-1)
53 | opacities = torch.rand((N,), device=device)
54 |
55 | return means, quats, scales, opacities, colors, viewmats, Ks, width, height
56 |
--------------------------------------------------------------------------------
/gsplat/compression/__init__.py:
--------------------------------------------------------------------------------
1 | from .png_compression import PngCompression
2 |
--------------------------------------------------------------------------------
/gsplat/compression/png_compression.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from dataclasses import dataclass
4 | from typing import Any, Callable, Dict
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | from torch import Tensor
10 |
11 | from gsplat.compression.sort import sort_splats
12 | from gsplat.utils import inverse_log_transform, log_transform
13 |
14 |
15 | @dataclass
16 | class PngCompression:
17 | """Uses quantization and sorting to compress splats into PNG files and uses
18 | K-means clustering to compress the spherical harmonic coefficents.
19 |
20 | .. warning::
21 | This class requires the `imageio `_,
22 | `plas `_
23 | and `torchpq `_ packages to be installed.
24 |
25 | .. warning::
26 | This class might throw away a few lowest opacities splats if the number of
27 | splats is not a square number.
28 |
29 | .. note::
30 | The splats parameters are expected to be pre-activation values. It expects
31 | the following fields in the splats dictionary: "means", "scales", "quats",
32 | "opacities", "sh0", "shN". More fields can be added to the dictionary, but
33 | they will only be compressed using NPZ compression.
34 |
35 | References:
36 | - `Compact 3D Scene Representation via Self-Organizing Gaussian Grids `_
37 | - `Making Gaussian Splats more smaller `_
38 |
39 | Args:
40 | use_sort (bool, optional): Whether to sort splats before compression. Defaults to True.
41 | verbose (bool, optional): Whether to print verbose information. Default to True.
42 | """
43 |
44 | use_sort: bool = True
45 | verbose: bool = True
46 |
47 | def _get_compress_fn(self, param_name: str) -> Callable:
48 | compress_fn_map = {
49 | "means": _compress_png_16bit,
50 | "scales": _compress_png,
51 | "quats": _compress_png,
52 | "opacities": _compress_png,
53 | "sh0": _compress_png,
54 | "shN": _compress_kmeans,
55 | }
56 | if param_name in compress_fn_map:
57 | return compress_fn_map[param_name]
58 | else:
59 | return _compress_npz
60 |
61 | def _get_decompress_fn(self, param_name: str) -> Callable:
62 | decompress_fn_map = {
63 | "means": _decompress_png_16bit,
64 | "scales": _decompress_png,
65 | "quats": _decompress_png,
66 | "opacities": _decompress_png,
67 | "sh0": _decompress_png,
68 | "shN": _decompress_kmeans,
69 | }
70 | if param_name in decompress_fn_map:
71 | return decompress_fn_map[param_name]
72 | else:
73 | return _decompress_npz
74 |
75 | def compress(self, compress_dir: str, splats: Dict[str, Tensor]) -> None:
76 | """Run compression
77 |
78 | Args:
79 | compress_dir (str): directory to save compressed files
80 | splats (Dict[str, Tensor]): Gaussian splats to compress
81 | """
82 |
83 | # Param-specific preprocessing
84 | splats["means"] = log_transform(splats["means"])
85 | splats["quats"] = F.normalize(splats["quats"], dim=-1)
86 |
87 | n_gs = len(splats["means"])
88 | n_sidelen = int(n_gs**0.5)
89 | n_crop = n_gs - n_sidelen**2
90 | if n_crop != 0:
91 | splats = _crop_n_splats(splats, n_crop)
92 | print(
93 | f"Warning: Number of Gaussians was not square. Removed {n_crop} Gaussians."
94 | )
95 |
96 | if self.use_sort:
97 | splats = sort_splats(splats)
98 |
99 | meta = {}
100 | for param_name in splats.keys():
101 | compress_fn = self._get_compress_fn(param_name)
102 | kwargs = {
103 | "n_sidelen": n_sidelen,
104 | "verbose": self.verbose,
105 | }
106 | meta[param_name] = compress_fn(
107 | compress_dir, param_name, splats[param_name], **kwargs
108 | )
109 |
110 | with open(os.path.join(compress_dir, "meta.json"), "w") as f:
111 | json.dump(meta, f)
112 |
113 | def decompress(self, compress_dir: str) -> Dict[str, Tensor]:
114 | """Run decompression
115 |
116 | Args:
117 | compress_dir (str): directory that contains compressed files
118 |
119 | Returns:
120 | Dict[str, Tensor]: decompressed Gaussian splats
121 | """
122 | with open(os.path.join(compress_dir, "meta.json"), "r") as f:
123 | meta = json.load(f)
124 |
125 | splats = {}
126 | for param_name, param_meta in meta.items():
127 | decompress_fn = self._get_decompress_fn(param_name)
128 | splats[param_name] = decompress_fn(compress_dir, param_name, param_meta)
129 |
130 | # Param-specific postprocessing
131 | splats["means"] = inverse_log_transform(splats["means"])
132 | return splats
133 |
134 |
135 | def _crop_n_splats(splats: Dict[str, Tensor], n_crop: int) -> Dict[str, Tensor]:
136 | opacities = splats["opacities"]
137 | keep_indices = torch.argsort(opacities, descending=True)[:-n_crop]
138 | for k, v in splats.items():
139 | splats[k] = v[keep_indices]
140 | return splats
141 |
142 |
143 | def _compress_png(
144 | compress_dir: str, param_name: str, params: Tensor, n_sidelen: int, **kwargs
145 | ) -> Dict[str, Any]:
146 | """Compress parameters with 8-bit quantization and lossless PNG compression.
147 |
148 | Args:
149 | compress_dir (str): compression directory
150 | param_name (str): parameter field name
151 | params (Tensor): parameters
152 | n_sidelen (int): image side length
153 |
154 | Returns:
155 | Dict[str, Any]: metadata
156 | """
157 | import imageio.v2 as imageio
158 |
159 | if torch.numel == 0:
160 | meta = {
161 | "shape": list(params.shape),
162 | "dtype": str(params.dtype).split(".")[1],
163 | }
164 | return meta
165 |
166 | grid = params.reshape((n_sidelen, n_sidelen, -1))
167 | mins = torch.amin(grid, dim=(0, 1))
168 | maxs = torch.amax(grid, dim=(0, 1))
169 | grid_norm = (grid - mins) / (maxs - mins)
170 | img_norm = grid_norm.detach().cpu().numpy()
171 |
172 | img = (img_norm * (2**8 - 1)).round().astype(np.uint8)
173 | img = img.squeeze()
174 | imageio.imwrite(os.path.join(compress_dir, f"{param_name}.png"), img)
175 |
176 | meta = {
177 | "shape": list(params.shape),
178 | "dtype": str(params.dtype).split(".")[1],
179 | "mins": mins.tolist(),
180 | "maxs": maxs.tolist(),
181 | }
182 | return meta
183 |
184 |
185 | def _decompress_png(compress_dir: str, param_name: str, meta: Dict[str, Any]) -> Tensor:
186 | """Decompress parameters from PNG file.
187 |
188 | Args:
189 | compress_dir (str): compression directory
190 | param_name (str): parameter field name
191 | meta (Dict[str, Any]): metadata
192 |
193 | Returns:
194 | Tensor: parameters
195 | """
196 | import imageio.v2 as imageio
197 |
198 | if not np.all(meta["shape"]):
199 | params = torch.zeros(meta["shape"], dtype=getattr(torch, meta["dtype"]))
200 | return meta
201 |
202 | img = imageio.imread(os.path.join(compress_dir, f"{param_name}.png"))
203 | img_norm = img / (2**8 - 1)
204 |
205 | grid_norm = torch.tensor(img_norm)
206 | mins = torch.tensor(meta["mins"])
207 | maxs = torch.tensor(meta["maxs"])
208 | grid = grid_norm * (maxs - mins) + mins
209 |
210 | params = grid.reshape(meta["shape"])
211 | params = params.to(dtype=getattr(torch, meta["dtype"]))
212 | return params
213 |
214 |
215 | def _compress_png_16bit(
216 | compress_dir: str, param_name: str, params: Tensor, n_sidelen: int, **kwargs
217 | ) -> Dict[str, Any]:
218 | """Compress parameters with 16-bit quantization and PNG compression.
219 |
220 | Args:
221 | compress_dir (str): compression directory
222 | param_name (str): parameter field name
223 | params (Tensor): parameters
224 | n_sidelen (int): image side length
225 |
226 | Returns:
227 | Dict[str, Any]: metadata
228 | """
229 | import imageio.v2 as imageio
230 |
231 | if torch.numel == 0:
232 | meta = {
233 | "shape": list(params.shape),
234 | "dtype": str(params.dtype).split(".")[1],
235 | }
236 | return meta
237 |
238 | grid = params.reshape((n_sidelen, n_sidelen, -1))
239 | mins = torch.amin(grid, dim=(0, 1))
240 | maxs = torch.amax(grid, dim=(0, 1))
241 | grid_norm = (grid - mins) / (maxs - mins)
242 | img_norm = grid_norm.detach().cpu().numpy()
243 | img = (img_norm * (2**16 - 1)).round().astype(np.uint16)
244 |
245 | img_l = img & 0xFF
246 | img_u = (img >> 8) & 0xFF
247 | imageio.imwrite(
248 | os.path.join(compress_dir, f"{param_name}_l.png"), img_l.astype(np.uint8)
249 | )
250 | imageio.imwrite(
251 | os.path.join(compress_dir, f"{param_name}_u.png"), img_u.astype(np.uint8)
252 | )
253 |
254 | meta = {
255 | "shape": list(params.shape),
256 | "dtype": str(params.dtype).split(".")[1],
257 | "mins": mins.tolist(),
258 | "maxs": maxs.tolist(),
259 | }
260 | return meta
261 |
262 |
263 | def _decompress_png_16bit(
264 | compress_dir: str, param_name: str, meta: Dict[str, Any]
265 | ) -> Tensor:
266 | """Decompress parameters from PNG files.
267 |
268 | Args:
269 | compress_dir (str): compression directory
270 | param_name (str): parameter field name
271 | meta (Dict[str, Any]): metadata
272 |
273 | Returns:
274 | Tensor: parameters
275 | """
276 | import imageio.v2 as imageio
277 |
278 | if not np.all(meta["shape"]):
279 | params = torch.zeros(meta["shape"], dtype=getattr(torch, meta["dtype"]))
280 | return meta
281 |
282 | img_l = imageio.imread(os.path.join(compress_dir, f"{param_name}_l.png"))
283 | img_u = imageio.imread(os.path.join(compress_dir, f"{param_name}_u.png"))
284 | img_u = img_u.astype(np.uint16)
285 | img = (img_u << 8) + img_l
286 |
287 | img_norm = img / (2**16 - 1)
288 | grid_norm = torch.tensor(img_norm)
289 | mins = torch.tensor(meta["mins"])
290 | maxs = torch.tensor(meta["maxs"])
291 | grid = grid_norm * (maxs - mins) + mins
292 |
293 | params = grid.reshape(meta["shape"])
294 | params = params.to(dtype=getattr(torch, meta["dtype"]))
295 | return params
296 |
297 |
298 | def _compress_npz(
299 | compress_dir: str, param_name: str, params: Tensor, **kwargs
300 | ) -> Dict[str, Any]:
301 | """Compress parameters with numpy's NPZ compression."""
302 | npz_dict = {"arr": params.detach().cpu().numpy()}
303 | save_fp = os.path.join(compress_dir, f"{param_name}.npz")
304 | os.makedirs(os.path.dirname(save_fp), exist_ok=True)
305 | np.savez_compressed(save_fp, **npz_dict)
306 | meta = {
307 | "shape": params.shape,
308 | "dtype": str(params.dtype).split(".")[1],
309 | }
310 | return meta
311 |
312 |
313 | def _decompress_npz(compress_dir: str, param_name: str, meta: Dict[str, Any]) -> Tensor:
314 | """Decompress parameters with numpy's NPZ compression."""
315 | arr = np.load(os.path.join(compress_dir, f"{param_name}.npz"))["arr"]
316 | params = torch.tensor(arr)
317 | params = params.reshape(meta["shape"])
318 | params = params.to(dtype=getattr(torch, meta["dtype"]))
319 | return params
320 |
321 |
322 | def _compress_kmeans(
323 | compress_dir: str,
324 | param_name: str,
325 | params: Tensor,
326 | n_clusters: int = 65536,
327 | quantization: int = 6,
328 | verbose: bool = True,
329 | **kwargs,
330 | ) -> Dict[str, Any]:
331 | """Run K-means clustering on parameters and save centroids and labels to a npz file.
332 |
333 | .. warning::
334 | TorchPQ must installed to use K-means clustering.
335 |
336 | Args:
337 | compress_dir (str): compression directory
338 | param_name (str): parameter field name
339 | params (Tensor): parameters to compress
340 | n_clusters (int): number of K-means clusters
341 | quantization (int): number of bits in quantization
342 | verbose (bool, optional): Whether to print verbose information. Default to True.
343 |
344 | Returns:
345 | Dict[str, Any]: metadata
346 | """
347 | try:
348 | from torchpq.clustering import KMeans
349 | except:
350 | raise ImportError(
351 | "Please install torchpq with 'pip install torchpq' to use K-means clustering"
352 | )
353 |
354 | if torch.numel == 0:
355 | meta = {
356 | "shape": list(params.shape),
357 | "dtype": str(params.dtype).split(".")[1],
358 | }
359 | return meta
360 |
361 | kmeans = KMeans(n_clusters=n_clusters, distance="manhattan", verbose=verbose)
362 | x = params.reshape(params.shape[0], -1).permute(1, 0).contiguous()
363 | labels = kmeans.fit(x)
364 | labels = labels.detach().cpu().numpy()
365 | centroids = kmeans.centroids.permute(1, 0)
366 |
367 | mins = torch.min(centroids)
368 | maxs = torch.max(centroids)
369 | centroids_norm = (centroids - mins) / (maxs - mins)
370 | centroids_norm = centroids_norm.detach().cpu().numpy()
371 | centroids_quant = (
372 | (centroids_norm * (2**quantization - 1)).round().astype(np.uint8)
373 | )
374 | labels = labels.astype(np.uint16)
375 |
376 | npz_dict = {
377 | "centroids": centroids_quant,
378 | "labels": labels,
379 | }
380 | np.savez_compressed(os.path.join(compress_dir, f"{param_name}.npz"), **npz_dict)
381 | meta = {
382 | "shape": list(params.shape),
383 | "dtype": str(params.dtype).split(".")[1],
384 | "mins": mins.tolist(),
385 | "maxs": maxs.tolist(),
386 | "quantization": quantization,
387 | }
388 | return meta
389 |
390 |
391 | def _decompress_kmeans(
392 | compress_dir: str, param_name: str, meta: Dict[str, Any], **kwargs
393 | ) -> Tensor:
394 | """Decompress parameters from K-means compression.
395 |
396 | Args:
397 | compress_dir (str): compression directory
398 | param_name (str): parameter field name
399 | meta (Dict[str, Any]): metadata
400 |
401 | Returns:
402 | Tensor: parameters
403 | """
404 | if not np.all(meta["shape"]):
405 | params = torch.zeros(meta["shape"], dtype=getattr(torch, meta["dtype"]))
406 | return meta
407 |
408 | npz_dict = np.load(os.path.join(compress_dir, f"{param_name}.npz"))
409 | centroids_quant = npz_dict["centroids"]
410 | labels = npz_dict["labels"]
411 |
412 | centroids_norm = centroids_quant / (2 ** meta["quantization"] - 1)
413 | centroids_norm = torch.tensor(centroids_norm)
414 | mins = torch.tensor(meta["mins"])
415 | maxs = torch.tensor(meta["maxs"])
416 | centroids = centroids_norm * (maxs - mins) + mins
417 |
418 | params = centroids[labels]
419 | params = params.reshape(meta["shape"])
420 | params = params.to(dtype=getattr(torch, meta["dtype"]))
421 | return params
422 |
--------------------------------------------------------------------------------
/gsplat/compression/sort.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import torch
4 | from torch import Tensor
5 |
6 |
7 | def sort_splats(splats: Dict[str, Tensor], verbose: bool = True) -> Dict[str, Tensor]:
8 | """Sort splats with Parallel Linear Assignment Sorting from the paper `Compact 3D Scene Representation via
9 | Self-Organizing Gaussian Grids `_.
10 |
11 | .. warning::
12 | PLAS must installed to use sorting.
13 |
14 | Args:
15 | splats (Dict[str, Tensor]): splats
16 | verbose (bool, optional): Whether to print verbose information. Default to True.
17 |
18 | Returns:
19 | Dict[str, Tensor]: sorted splats
20 | """
21 | try:
22 | from plas import sort_with_plas
23 | except:
24 | raise ImportError(
25 | "Please install PLAS with 'pip install git+https://github.com/fraunhoferhhi/PLAS.git' to use sorting"
26 | )
27 |
28 | n_gs = len(splats["means"])
29 | n_sidelen = int(n_gs**0.5)
30 | assert n_sidelen**2 == n_gs, "Must be a perfect square"
31 |
32 | sort_keys = [k for k in splats if k != "shN"]
33 | params_to_sort = torch.cat([splats[k].reshape(n_gs, -1) for k in sort_keys], dim=-1)
34 | shuffled_indices = torch.randperm(
35 | params_to_sort.shape[0], device=params_to_sort.device
36 | )
37 | params_to_sort = params_to_sort[shuffled_indices]
38 | grid = params_to_sort.reshape((n_sidelen, n_sidelen, -1))
39 | _, sorted_indices = sort_with_plas(
40 | grid.permute(2, 0, 1), improvement_break=1e-4, verbose=verbose
41 | )
42 | sorted_indices = sorted_indices.squeeze().flatten()
43 | sorted_indices = shuffled_indices[sorted_indices]
44 | for k, v in splats.items():
45 | splats[k] = v[sorted_indices]
46 | return splats
47 |
--------------------------------------------------------------------------------
/gsplat/cuda/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carlinds/splatad/5d3c9bd4b856f142c707a9c0b78161f555bacb10/gsplat/cuda/__init__.py
--------------------------------------------------------------------------------
/gsplat/cuda/_backend.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import os
4 | import shutil
5 | from subprocess import DEVNULL, call
6 |
7 | from rich.console import Console
8 | from torch.utils.cpp_extension import (
9 | _get_build_directory,
10 | _import_module_from_library,
11 | load,
12 | )
13 |
14 | PATH = os.path.dirname(os.path.abspath(__file__))
15 | NO_FAST_MATH = os.getenv("NO_FAST_MATH", "0") == "1"
16 | MAX_JOBS = os.getenv("MAX_JOBS")
17 | need_to_unset_max_jobs = False
18 | if not MAX_JOBS:
19 | need_to_unset_max_jobs = True
20 | os.environ["MAX_JOBS"] = "10"
21 |
22 |
23 | def load_extension(
24 | name,
25 | sources,
26 | extra_cflags=None,
27 | extra_cuda_cflags=None,
28 | extra_ldflags=None,
29 | extra_include_paths=None,
30 | build_directory=None,
31 | ):
32 | """Load a JIT compiled extension."""
33 | # Make sure the build directory exists.
34 | if build_directory:
35 | os.makedirs(build_directory, exist_ok=True)
36 |
37 | # If the JIT build happens concurrently in multiple processes,
38 | # race conditions can occur when removing the lock file at:
39 | # https://github.com/pytorch/pytorch/blob/e3513fb2af7951ddf725d8c5b6f6d962a053c9da/torch/utils/cpp_extension.py#L1736
40 | # But it's ok so we catch this exception and ignore it.
41 | try:
42 | return load(
43 | name,
44 | sources,
45 | extra_cflags=extra_cflags,
46 | extra_cuda_cflags=extra_cuda_cflags,
47 | extra_ldflags=extra_ldflags,
48 | extra_include_paths=extra_include_paths,
49 | build_directory=build_directory,
50 | )
51 | except OSError:
52 | # The module should be already compiled
53 | return _import_module_from_library(name, build_directory, True)
54 |
55 |
56 | def cuda_toolkit_available():
57 | """Check if the nvcc is avaiable on the machine."""
58 | try:
59 | call(["nvcc"], stdout=DEVNULL, stderr=DEVNULL)
60 | return True
61 | except FileNotFoundError:
62 | return False
63 |
64 |
65 | def cuda_toolkit_version():
66 | """Get the cuda toolkit version."""
67 | cuda_home = os.path.join(os.path.dirname(shutil.which("nvcc")), "..")
68 | if os.path.exists(os.path.join(cuda_home, "version.txt")):
69 | with open(os.path.join(cuda_home, "version.txt")) as f:
70 | cuda_version = f.read().strip().split()[-1]
71 | elif os.path.exists(os.path.join(cuda_home, "version.json")):
72 | with open(os.path.join(cuda_home, "version.json")) as f:
73 | cuda_version = json.load(f)["cuda"]["version"]
74 | else:
75 | raise RuntimeError("Cannot find the cuda version.")
76 | return cuda_version
77 |
78 |
79 | _C = None
80 |
81 | try:
82 | # try to import the compiled module (via setup.py)
83 | from gsplat import csrc as _C
84 | except ImportError:
85 | # if failed, try with JIT compilation
86 | if cuda_toolkit_available():
87 | name = "gsplat_cuda"
88 | build_dir = _get_build_directory(name, verbose=False)
89 | current_dir = os.path.dirname(os.path.abspath(__file__))
90 | glm_path = os.path.join(current_dir, "csrc", "third_party", "glm")
91 |
92 | extra_include_paths = [os.path.join(PATH, "csrc/"), glm_path]
93 | extra_cflags = ["-O3"]
94 | if NO_FAST_MATH:
95 | extra_cuda_cflags = ["-O3"]
96 | else:
97 | extra_cuda_cflags = ["-O3", "--use_fast_math"]
98 | sources = list(glob.glob(os.path.join(PATH, "csrc/*.cu"))) + list(
99 | glob.glob(os.path.join(PATH, "csrc/*.cpp"))
100 | )
101 |
102 | # If JIT is interrupted it might leave a lock in the build directory.
103 | # We dont want it to exist in any case.
104 | try:
105 | os.remove(os.path.join(build_dir, "lock"))
106 | except OSError:
107 | pass
108 |
109 | if os.path.exists(os.path.join(build_dir, "gsplat_cuda.so")) or os.path.exists(
110 | os.path.join(build_dir, "gsplat_cuda.lib")
111 | ):
112 | # If the build exists, we assume the extension has been built
113 | # and we can load it.
114 | _C = load_extension(
115 | name=name,
116 | sources=sources,
117 | extra_cflags=extra_cflags,
118 | extra_cuda_cflags=extra_cuda_cflags,
119 | extra_include_paths=extra_include_paths,
120 | build_directory=build_dir,
121 | )
122 | else:
123 | # Build from scratch. Remove the build directory just to be safe: pytorch jit might stuck
124 | # if the build directory exists with a lock file in it.
125 | shutil.rmtree(build_dir)
126 | with Console().status(
127 | f"[bold yellow]gsplat: Setting up CUDA with MAX_JOBS={os.environ['MAX_JOBS']} (This may take a few minutes the first time)",
128 | spinner="bouncingBall",
129 | ):
130 | _C = load_extension(
131 | name=name,
132 | sources=sources,
133 | extra_cflags=extra_cflags,
134 | extra_cuda_cflags=extra_cuda_cflags,
135 | extra_include_paths=extra_include_paths,
136 | build_directory=build_dir,
137 | )
138 |
139 | else:
140 | Console().print(
141 | "[yellow]gsplat: No CUDA toolkit found. gsplat will be disabled.[/yellow]"
142 | )
143 |
144 | if need_to_unset_max_jobs:
145 | os.environ.pop("MAX_JOBS")
146 |
147 |
148 | __all__ = ["_C"]
149 |
--------------------------------------------------------------------------------
/gsplat/cuda/csrc/ext.cpp:
--------------------------------------------------------------------------------
1 | #include "bindings.h"
2 | #include
3 |
4 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
5 | m.def("compute_sh_fwd", &compute_sh_fwd_tensor);
6 | m.def("compute_sh_bwd", &compute_sh_bwd_tensor);
7 |
8 | m.def("quat_scale_to_covar_preci_fwd", &quat_scale_to_covar_preci_fwd_tensor);
9 | m.def("quat_scale_to_covar_preci_bwd", &quat_scale_to_covar_preci_bwd_tensor);
10 |
11 | m.def("persp_proj_fwd", &persp_proj_fwd_tensor);
12 | m.def("persp_proj_bwd", &persp_proj_bwd_tensor);
13 |
14 | m.def("lidar_proj_fwd", &lidar_proj_fwd_tensor);
15 | m.def("lidar_proj_bwd", &lidar_proj_bwd_tensor);
16 |
17 | m.def("world_to_cam_fwd", &world_to_cam_fwd_tensor);
18 | m.def("world_to_cam_bwd", &world_to_cam_bwd_tensor);
19 |
20 | m.def("compute_pix_velocity_fwd", &compute_pix_velocity_fwd_tensor);
21 | m.def("compute_pix_velocity_bwd", &compute_pix_velocity_bwd_tensor);
22 |
23 | m.def("compute_lidar_velocity_fwd", &compute_lidar_velocity_fwd_tensor);
24 | m.def("compute_lidar_velocity_bwd", &compute_lidar_velocity_bwd_tensor);
25 |
26 | m.def("fully_fused_projection_fwd", &fully_fused_projection_fwd_tensor);
27 | m.def("fully_fused_projection_bwd", &fully_fused_projection_bwd_tensor);
28 |
29 | m.def("fully_fused_lidar_projection_fwd", &fully_fused_lidar_projection_fwd_tensor);
30 | m.def("fully_fused_lidar_projection_bwd", &fully_fused_lidar_projection_bwd_tensor);
31 |
32 | m.def("isect_tiles", &isect_tiles_tensor);
33 | m.def("isect_lidar_tiles", &isect_lidar_tiles_tensor);
34 | m.def("isect_offset_encode", &isect_offset_encode_tensor);
35 |
36 | m.def("map_points_to_lidar_tiles", &map_points_to_lidar_tiles_tensor);
37 | m.def("points_mapping_offset_encode", &points_mapping_offset_encode_tensor);
38 | m.def("populate_image_from_points", &populate_image_from_points_tensor);
39 |
40 | m.def("rasterize_to_pixels_fwd", &rasterize_to_pixels_fwd_tensor);
41 | m.def("rasterize_to_pixels_bwd", &rasterize_to_pixels_bwd_tensor);
42 |
43 | m.def("rasterize_to_points_fwd", &rasterize_to_points_fwd_tensor);
44 | m.def("rasterize_to_points_bwd", &rasterize_to_points_bwd_tensor);
45 |
46 | m.def("rasterize_to_indices_in_range", &rasterize_to_indices_in_range_tensor);
47 | m.def("rasterize_to_indices_in_range_lidar", &rasterize_to_indices_in_range_lidar_tensor);
48 |
49 | // packed version
50 | m.def("fully_fused_projection_packed_fwd", &fully_fused_projection_packed_fwd_tensor);
51 | m.def("fully_fused_projection_packed_bwd", &fully_fused_projection_packed_bwd_tensor);
52 |
53 | m.def("compute_relocation", &compute_relocation_tensor);
54 | }
--------------------------------------------------------------------------------
/gsplat/cuda/csrc/helpers.cuh:
--------------------------------------------------------------------------------
1 | #ifndef GSPLAT_CUDA_HELPERS_H
2 | #define GSPLAT_CUDA_HELPERS_H
3 |
4 | #include "third_party/glm/glm/glm.hpp"
5 | #include "third_party/glm/glm/gtc/type_ptr.hpp"
6 | #include
7 | #include
8 |
9 | #include
10 |
11 | #define PRAGMA_UNROLL _Pragma("unroll")
12 |
13 | #define RAD_TO_DEG 57.2957795131f
14 |
15 | namespace cg = cooperative_groups;
16 |
// Warp-level reductions built on cooperative_groups::reduce.
// Restored: the template parameter lists and the template arguments of
// cg::plus / cg::greater were stripped from this dump (angle-bracket spans
// eaten); they are required for this header to compile.

// Sum an array of DIM values element-wise across the warp (result in all lanes).
template <uint32_t DIM, class T, class WarpT>
inline __device__ void warpSum(T *val, WarpT &warp) {
    PRAGMA_UNROLL
    for (uint32_t i = 0; i < DIM; i++) {
        val[i] = cg::reduce(warp, val[i], cg::plus<T>());
    }
}

// Component-wise warp sums for CUDA vector types.
template <class WarpT> inline __device__ void warpSum(float3 &val, WarpT &warp) {
    val.x = cg::reduce(warp, val.x, cg::plus<float>());
    val.y = cg::reduce(warp, val.y, cg::plus<float>());
    val.z = cg::reduce(warp, val.z, cg::plus<float>());
}

template <class WarpT> inline __device__ void warpSum(float2 &val, WarpT &warp) {
    val.x = cg::reduce(warp, val.x, cg::plus<float>());
    val.y = cg::reduce(warp, val.y, cg::plus<float>());
}

template <class WarpT> inline __device__ void warpSum(float &val, WarpT &warp) {
    val = cg::reduce(warp, val, cg::plus<float>());
}

// Component-wise warp sums for glm vector types.
template <class WarpT> inline __device__ void warpSum(glm::vec4 &val, WarpT &warp) {
    val.x = cg::reduce(warp, val.x, cg::plus<float>());
    val.y = cg::reduce(warp, val.y, cg::plus<float>());
    val.z = cg::reduce(warp, val.z, cg::plus<float>());
    val.w = cg::reduce(warp, val.w, cg::plus<float>());
}

template <class WarpT> inline __device__ void warpSum(glm::vec3 &val, WarpT &warp) {
    val.x = cg::reduce(warp, val.x, cg::plus<float>());
    val.y = cg::reduce(warp, val.y, cg::plus<float>());
    val.z = cg::reduce(warp, val.z, cg::plus<float>());
}

template <class WarpT> inline __device__ void warpSum(glm::vec2 &val, WarpT &warp) {
    val.x = cg::reduce(warp, val.x, cg::plus<float>());
    val.y = cg::reduce(warp, val.y, cg::plus<float>());
}

// Column-wise warp sums for glm matrix types (glm matrices index by column).
template <class WarpT> inline __device__ void warpSum(glm::mat4 &val, WarpT &warp) {
    warpSum(val[0], warp);
    warpSum(val[1], warp);
    warpSum(val[2], warp);
    warpSum(val[3], warp);
}

template <class WarpT> inline __device__ void warpSum(glm::mat3 &val, WarpT &warp) {
    warpSum(val[0], warp);
    warpSum(val[1], warp);
    warpSum(val[2], warp);
}

template <class WarpT> inline __device__ void warpSum(glm::mat2 &val, WarpT &warp) {
    warpSum(val[0], warp);
    warpSum(val[1], warp);
}

// Warp-wide maximum of a float (result in all lanes).
template <class WarpT> inline __device__ void warpMax(float &val, WarpT &warp) {
    val = cg::reduce(warp, val, cg::greater<float>());
}
79 |
// Project a point's view-space velocity into pixel space.
//
// The total view-space velocity is lin_vel + cross(ang_vel, p_view) - vel_view,
// mapped through the 2x3 perspective Jacobian J and negated (points move in
// the opposite direction of the sensor's own motion).
//
// NOTE(review): the exact frames of lin_vel / ang_vel / vel_view are not
// visible in this header — confirm against the calling kernels.
//
//   p_view        : point position in the camera/view frame
//   fx, fy, cx, cy: pinhole intrinsics; width/height: image size in pixels
//   total_vel_pix : output, 2-D velocity in pixel units
inline __device__ void compute_pix_velocity(
    const glm::vec3 p_view,
    const glm::vec3 lin_vel,
    const glm::vec3 ang_vel,
    const glm::vec3 vel_view,
    const float fx,
    const float fy,
    const float cx,
    const float cy,
    const uint32_t width,
    const uint32_t height,
    glm::vec2 &total_vel_pix
) {

    float x = p_view[0], y = p_view[1], z = p_view[2];

    // Frustum limits with a 30% margin; the point used for the Jacobian is
    // clamped so J stays bounded for points far outside the image.
    float tan_fovx = 0.5f * width / fx;
    float tan_fovy = 0.5f * height / fy;
    float lim_x_pos = (width - cx) / fx + 0.3f * tan_fovx;
    float lim_x_neg = cx / fx + 0.3f * tan_fovx;
    float lim_y_pos = (height - cy) / fy + 0.3f * tan_fovy;
    float lim_y_neg = cy / fy + 0.3f * tan_fovy;

    float rz = 1.f / z;
    float rz2 = rz * rz;
    // tx, ty: clamped view-space coordinates used in the Jacobian's 3rd column.
    float tx = z * min(lim_x_pos, max(-lim_x_neg, x * rz));
    float ty = z * min(lim_y_pos, max(-lim_y_neg, y * rz));

    // Jacobian of the pinhole projection at (tx, ty, z).
    // mat3x2 is 3 columns x 2 rows.
    glm::mat3x2 J = glm::mat3x2(fx * rz, 0.f, // 1st column
                                0.f, fy * rz, // 2nd column
                                -fx * tx * rz2, -fy * ty * rz2 // 3rd column
    );

    glm::vec3 rot_part = glm::cross(ang_vel, p_view);
    glm::vec3 total_vel = lin_vel + rot_part - vel_view;
    // negative sign: move points to the opposite direction as the camera
    total_vel_pix = -J * total_vel;
}
119 |
// Project a point's view-space velocity into lidar image coordinates
// (azimuth/elevation in degrees — note RAD_TO_DEG factors — plus range).
//
// J is an in/out parameter: when the caller passes a matrix whose first
// column is zero, the spherical-projection Jacobian is computed here and
// written back; otherwise the caller-provided J is reused as a cache.
inline __device__ void compute_lidar_velocity(
    const glm::vec3 p_view,
    const glm::vec3 lin_vel,
    const glm::vec3 ang_vel,
    const glm::vec3 vel_view,
    glm::mat3 &J,
    glm::vec3 &total_vel_pix
) {
    // Total view-space velocity: linear + rotational - sensor's own velocity.
    glm::vec3 rot_part = glm::cross(ang_vel, p_view);
    glm::vec3 total_vel = lin_vel + rot_part - vel_view;

    // Lazily compute the Jacobian of (azimuth, elevation, range) w.r.t. p_view.
    if (glm::length(J[0]) == 0) {
        const float x2 = p_view.x * p_view.x;
        const float y2 = p_view.y * p_view.y;
        const float z2 = p_view.z * p_view.z;
        const float r2 = x2 + y2 + z2;
        const float rinv = rsqrtf(r2);
        const float sqrtx2y2 = hypotf(p_view.x, p_view.y);
        const float sqrtx2y2_inv = rhypotf(p_view.x, p_view.y);
        const float xz = p_view.x * p_view.z;
        const float yz = p_view.y * p_view.z;
        const float r2sqrtx2y2_inv = 1.f / (r2) * sqrtx2y2_inv;

        // column major, mat3x2 is 3 columns x 2 rows.
        // Rows: d(azimuth)/dp, d(elevation)/dp (both in degrees), d(range)/dp.
        J = glm::mat3(
            -p_view.y / (x2 + y2) * RAD_TO_DEG, -xz * r2sqrtx2y2_inv * RAD_TO_DEG, p_view.x * rinv, // 1st column
            p_view.x / (x2 + y2) * RAD_TO_DEG, -yz * r2sqrtx2y2_inv * RAD_TO_DEG, p_view.y * rinv,// 2nd column
            0.f , sqrtx2y2 / r2 * RAD_TO_DEG, p_view.z * rinv // 3rd column
        );
    }

    // negative sign: move points to the opposite direction as the camera
    total_vel_pix = -J * total_vel;
}
154 |
// Backward pass (VJP) of compute_pix_velocity.
//
// Given the cotangent v_pix_velocity = dL/d(total_vel_pix), this accumulates
// the gradient w.r.t. the view-space point into v_p_view_accumulator and
// updates v_vel_view. Gradients w.r.t. lin_vel / ang_vel are not produced
// here. The forward quantities (limits, clamped tx/ty, J) are recomputed
// rather than cached.
inline __device__ void compute_and_sum_pix_velocity_vjp(
    const glm::vec3 p_view,
    const glm::vec3 lin_vel,
    const glm::vec3 ang_vel,
    const glm::vec3 vel_view,
    const float fx,
    const float fy,
    const float cx,
    const float cy,
    const uint32_t width,
    const uint32_t height,
    const glm::vec2 v_pix_velocity,
    glm::vec3 &v_p_view_accumulator,
    glm::vec3 &v_vel_view)
{
    float x = p_view[0], y = p_view[1], z = p_view[2];

    // Same clamping setup as the forward pass.
    float tan_fovx = 0.5f * width / fx;
    float tan_fovy = 0.5f * height / fy;
    float lim_x_pos = (width - cx) / fx + 0.3f * tan_fovx;
    float lim_x_neg = cx / fx + 0.3f * tan_fovx;
    float lim_y_pos = (height - cy) / fy + 0.3f * tan_fovy;
    float lim_y_neg = cy / fy + 0.3f * tan_fovy;

    float rz = 1.f / z;
    float rz2 = rz * rz;
    float tx = z * min(lim_x_pos, max(-lim_x_neg, x * rz));
    float ty = z * min(lim_y_pos, max(-lim_y_neg, y * rz));

    // mat3x2 is 3 columns x 2 rows.
    glm::mat3x2 J = glm::mat3x2(fx * rz, 0.f, // 1st column
                                0.f, fy * rz, // 2nd column
                                -fx * tx * rz2, -fy * ty * rz2 // 3rd column
    );

    glm::vec3 rot_part = glm::cross(ang_vel, p_view);
    glm::vec3 total_vel = lin_vel + rot_part - vel_view;

    // Derivative of J w.r.t. z (holding the clamped tx, ty dependence explicit).
    glm::mat3x2 dJ_dz = glm::mat3x2(
        -fx * rz2,
        0.f,
        0.f,
        -fy * rz2,
        2.f * fx * tx * rz2 * rz,
        2.f * fy * ty * rz2 * rz
    );

    // Branch on whether the forward clamp was active: inside the limits tx
    // depends on x; when clamped it depends on z instead.
    if (x * rz <= lim_x_pos && x * rz >= -lim_x_neg) {
        v_p_view_accumulator.x += v_pix_velocity.x * fx * rz2 * total_vel.z; //-glm::dot(v_pix_velocity, dJ_dx * total_vel);
    } else {
        v_p_view_accumulator.z += v_pix_velocity.x * fx * rz2 * rz * tx * total_vel.z; //-glm::dot(v_pix_velocity, dJ_dx * rz * tx * total_vel);
    }
    if (y * rz <= lim_y_pos && y * rz >= -lim_y_neg) {
        v_p_view_accumulator.y += v_pix_velocity.y * fy * rz2 * total_vel.z; //glm::dot(v_pix_velocity, dJ_dy * total_vel);
    } else {
        v_p_view_accumulator.z += v_pix_velocity.y * fy * rz2 * rz * ty * total_vel.z; // glm::dot(v_pix_velocity, dJ_dy * rz * ty * total_vel);
    }
    v_p_view_accumulator.z -= glm::dot(v_pix_velocity, dJ_dz * total_vel);

    glm::vec3 v_rot_part = -glm::transpose(J) * v_pix_velocity; // = v_total_vel

    // (v_rot_part^T * cross_prod_matrix(ang_vel))^T
    // = cross_prod_matrix(ang_vel)^T * v_rot_part // ... skew-symmetry
    // = -cross_prod_matrix(ang_vel) * v_rot_part
    // = -cross(ang_vel, v_rot_part)
    glm::vec3 v_p_view_rot = -glm::cross(ang_vel, v_rot_part);

    v_p_view_accumulator.x += v_p_view_rot[0];
    v_p_view_accumulator.y += v_p_view_rot[1];
    v_p_view_accumulator.z += v_p_view_rot[2];

    // total_vel depends on vel_view with a minus sign.
    v_vel_view -= v_rot_part;
}
228 |
// Backward pass (VJP) of compute_lidar_velocity.
//
// Given v_pix_velocity = dL/d(total_vel_pix), accumulates the gradient
// w.r.t. the view-space point into v_p_view_accumulator and updates
// v_vel_view. Unlike the forward pass, J is always recomputed here (the
// forward may reuse a caller-cached J). dJ_dx / dJ_dy / dJ_dz below are the
// element-wise derivatives of J w.r.t. the components of p_view.
inline __device__ void compute_and_sum_lidar_velocity_vjp(
    const glm::vec3 p_view,
    const glm::vec3 lin_vel,
    const glm::vec3 ang_vel,
    const glm::vec3 vel_view,
    const glm::vec3 v_pix_velocity,
    glm::vec3 &v_p_view_accumulator,
    glm::vec3 &v_vel_view)
{
    glm::vec3 rot_part = glm::cross(ang_vel, p_view);
    glm::vec3 total_vel = lin_vel + rot_part - vel_view;

    // Shared subexpressions for J and its derivatives.
    const float x = p_view.x;
    const float y = p_view.y;
    const float z = p_view.z;
    const float x2 = x * x;
    const float y2 = y * y;
    const float z2 = z * z;
    const float x4 = x2 * x2;
    const float y4 = y2 * y2;
    const float x2plusy2 = x2 + y2;
    const float sqrtx2y2 = hypot(x, y);
    const float sqrtx2y2_inv = rhypot(x, y);
    const float x2plusy2squared = x2plusy2 * x2plusy2;
    const float x2plusy2pow3by2 = sqrtx2y2 * x2plusy2;
    const float r2 = x2 + y2 + z2;
    const float r4 = r2 * r2;
    const float rinv = rsqrtf(r2);
    const float r3_inv = 1 / r2 * rinv;
    const float xz = x * z;
    const float xy = x * y;
    const float yz = y * z;
    const float r2sqrtx2y2_inv = 1.f / (r2) * sqrtx2y2_inv;
    const float xyz = x * y * z;
    const float denom1 = 1.f / (x2plusy2pow3by2 * r4);
    const float denom2 = 1.f / r4 * sqrtx2y2_inv;

    // Spherical-projection Jacobian, identical to the forward pass.
    // column major, mat3x2 is 3 columns x 2 rows.
    glm::mat3 J = glm::mat3(
        -p_view.y / (x2 + y2) * RAD_TO_DEG, -xz * r2sqrtx2y2_inv * RAD_TO_DEG, p_view.x * rinv, // 1st column
        p_view.x / (x2 + y2) * RAD_TO_DEG, -yz * r2sqrtx2y2_inv * RAD_TO_DEG, p_view.y * rinv, // 2nd column
        0.f, sqrtx2y2 / r2 * RAD_TO_DEG, p_view.z * rinv // 3rd column
    );


    // dJ/dx: derivative of every J entry w.r.t. p_view.x.
    glm::mat3 dJ_dx = glm::mat3(
        2.f * xy / x2plusy2squared * RAD_TO_DEG , z * (2.f * x4 + x2 * y2 - y2 * (y2 + z2)) * denom1 * RAD_TO_DEG, (y2 + z2) * r3_inv,
        (y2 - x2) / x2plusy2squared * RAD_TO_DEG, xyz * (3.f * x2 + 3.f * y2 + z2) * denom1 * RAD_TO_DEG , - xy * r3_inv,
        0.f , -x * (x2 + y2 - z2) * denom2 * RAD_TO_DEG , - xz * r3_inv
    );

    // dJ/dy: derivative of every J entry w.r.t. p_view.y.
    glm::mat3 dJ_dy = glm::mat3(
        (y2 - x2) / x2plusy2squared * RAD_TO_DEG, xyz * (3.f * x2 + 3.f * y2 + z2) * denom1 * RAD_TO_DEG , - xy * r3_inv,
        -2.f * xy / x2plusy2squared * RAD_TO_DEG, -z * (x4 + x2 * (z2 - y2) - 2.f * y4) * denom1 * RAD_TO_DEG, (x2 + z2) * r3_inv,
        0.f , -y * (x2 + y2 - z2) * denom2 * RAD_TO_DEG , - yz * r3_inv
    );

    // dJ/dz: derivative of every J entry w.r.t. p_view.z (azimuth row is
    // independent of z, hence the zero first column).
    glm::mat3 dJ_dz = glm::mat3(
        0.f , -x * (x2 + y2 - z2) * denom2 * RAD_TO_DEG, - xz * r3_inv,
        0.f , -y * (x2 + y2 - z2) * denom2 * RAD_TO_DEG, - yz * r3_inv,
        0.f , -2.f * z * sqrtx2y2 / r4 * RAD_TO_DEG , (x2 + y2) * r3_inv
    );

    // Chain rule through J's dependence on p_view (output = -J * total_vel).
    v_p_view_accumulator.x -= glm::dot(v_pix_velocity, dJ_dx * total_vel);
    v_p_view_accumulator.y -= glm::dot(v_pix_velocity, dJ_dy * total_vel);
    v_p_view_accumulator.z -= glm::dot(v_pix_velocity, dJ_dz * total_vel);

    glm::vec3 v_rot_part = -glm::transpose(J) * v_pix_velocity; // = v_total_vel

    // (v_rot_part^T * cross_prod_matrix(ang_vel))^T
    // = cross_prod_matrix(ang_vel)^T * v_rot_part // ... skew-symmetry
    // = -cross_prod_matrix(ang_vel) * v_rot_part
    // = -cross(ang_vel, v_rot_part)
    glm::vec3 v_p_view_rot = -glm::cross(ang_vel, v_rot_part);

    v_p_view_accumulator.x += v_p_view_rot[0];
    v_p_view_accumulator.y += v_p_view_rot[1];
    v_p_view_accumulator.z += v_p_view_rot[2];

    // total_vel depends on vel_view with a minus sign.
    v_vel_view -= v_rot_part;
}
310 |
311 | #endif // GSPLAT_CUDA_HELPERS_H
--------------------------------------------------------------------------------
/gsplat/cuda/csrc/relocation.cu:
--------------------------------------------------------------------------------
1 | #include "bindings.h"
2 |
// Equation (9) in "3D Gaussian Splatting as Markov Chain Monte Carlo"
// One thread per Gaussian: a Gaussian being replicated ratios[idx] times gets
// a new opacity and isotropically rescaled scales so that the rendered
// appearance is preserved. binoms is an (n_max x n_max) row-major table of
// binomial coefficients.
__global__ void compute_relocation_kernel(int N, float *opacities, float *scales,
                                          int *ratios, float *binoms, int n_max,
                                          float *new_opacities, float *new_scales) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx >= N)
        return;

    int n_idx = ratios[idx];
    float denom_sum = 0.0f;

    // compute new opacity: solves 1 - (1 - o_new)^n = o_old for o_new
    new_opacities[idx] = 1.0f - powf(1.0f - opacities[idx], 1.0f / n_idx);

    // compute new scale via the binomial series denominator of Eq. (9).
    // Restored: <float> in static_cast was stripped from this dump.
    for (int i = 1; i <= n_idx; ++i) {
        for (int k = 0; k <= (i - 1); ++k) {
            float bin_coeff = binoms[(i - 1) * n_max + k];
            float term = (pow(-1.0f, k) / sqrt(static_cast<float>(k + 1))) *
                         pow(new_opacities[idx], k + 1);
            denom_sum += (bin_coeff * term);
        }
    }
    float coeff = (opacities[idx] / denom_sum);
    for (int i = 0; i < 3; ++i)
        new_scales[idx * 3 + i] = coeff * scales[idx * 3 + i];
}
30 |
31 | std::tuple
32 | compute_relocation_tensor(torch::Tensor &opacities, torch::Tensor &scales,
33 | torch::Tensor &ratios, torch::Tensor &binoms,
34 | const int n_max) {
35 | DEVICE_GUARD(opacities);
36 | CHECK_INPUT(opacities);
37 | CHECK_INPUT(scales);
38 | CHECK_INPUT(ratios);
39 | CHECK_INPUT(binoms);
40 | torch::Tensor new_opacities = torch::empty_like(opacities);
41 | torch::Tensor new_scales = torch::empty_like(scales);
42 |
43 | uint32_t N = opacities.size(0);
44 | if (N) {
45 | at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
46 | compute_relocation_kernel<<<(N + N_THREADS - 1) / N_THREADS, N_THREADS, 0,
47 | stream>>>(
48 | N, opacities.data_ptr(), scales.data_ptr(),
49 | ratios.data_ptr(), binoms.data_ptr(), n_max,
50 | new_opacities.data_ptr(), new_scales.data_ptr());
51 | }
52 | return std::make_tuple(new_opacities, new_scales);
53 | }
54 |
--------------------------------------------------------------------------------
/gsplat/cuda_legacy/__init__.py:
--------------------------------------------------------------------------------
1 | # from typing import Callable
2 |
3 |
4 | # def _make_lazy_cuda_func(name: str) -> Callable:
5 | # def call_cuda(*args, **kwargs):
6 | # # pylint: disable=import-outside-toplevel
7 | # from ._backend import _C
8 |
9 | # return getattr(_C, name)(*args, **kwargs)
10 |
11 | # return call_cuda
12 |
13 |
14 | # nd_rasterize_forward = _make_lazy_cuda_func("nd_rasterize_forward")
15 | # nd_rasterize_backward = _make_lazy_cuda_func("nd_rasterize_backward")
16 | # rasterize_forward = _make_lazy_cuda_func("rasterize_forward")
17 | # rasterize_backward = _make_lazy_cuda_func("rasterize_backward")
18 | # compute_cov2d_bounds = _make_lazy_cuda_func("compute_cov2d_bounds")
19 | # project_gaussians_forward = _make_lazy_cuda_func("project_gaussians_forward")
20 | # project_gaussians_backward = _make_lazy_cuda_func("project_gaussians_backward")
21 | # compute_sh_forward = _make_lazy_cuda_func("compute_sh_forward")
22 | # compute_sh_backward = _make_lazy_cuda_func("compute_sh_backward")
23 | # map_gaussian_to_intersects = _make_lazy_cuda_func("map_gaussian_to_intersects")
24 | # get_tile_bin_edges = _make_lazy_cuda_func("get_tile_bin_edges")
25 | # rasterize_forward = _make_lazy_cuda_func("rasterize_forward")
26 | # nd_rasterize_forward = _make_lazy_cuda_func("nd_rasterize_forward")
27 |
--------------------------------------------------------------------------------
/gsplat/cuda_legacy/_backend.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import os
4 | import shutil
5 | from subprocess import DEVNULL, call
6 |
7 | from rich.console import Console
8 | from torch.utils.cpp_extension import _get_build_directory, load
9 |
10 | PATH = os.path.dirname(os.path.abspath(__file__))
11 |
12 |
def cuda_toolkit_available():
    """Return ``True`` if the ``nvcc`` compiler can be invoked on this machine."""
    try:
        # Running nvcc (output discarded) is the availability probe; only a
        # missing executable counts as "unavailable".
        call(["nvcc"], stdout=DEVNULL, stderr=DEVNULL)
    except FileNotFoundError:
        return False
    return True
20 |
21 |
def cuda_toolkit_version():
    """Return the CUDA toolkit version string (e.g. ``"11.8"``).

    Reads ``version.txt`` or ``version.json`` from the directory above the
    ``nvcc`` binary's location.

    Raises:
        RuntimeError: if ``nvcc`` is not on PATH or no version file is found.
    """
    nvcc_path = shutil.which("nvcc")
    if nvcc_path is None:
        # Previously this crashed with TypeError (os.path.dirname(None));
        # raise an explicit error instead.
        raise RuntimeError("Cannot find the cuda version: nvcc is not on PATH.")
    cuda_home = os.path.join(os.path.dirname(nvcc_path), "..")
    if os.path.exists(os.path.join(cuda_home, "version.txt")):
        with open(os.path.join(cuda_home, "version.txt")) as f:
            # Last whitespace-separated token of the file is the version.
            cuda_version = f.read().strip().split()[-1]
    elif os.path.exists(os.path.join(cuda_home, "version.json")):
        with open(os.path.join(cuda_home, "version.json")) as f:
            cuda_version = json.load(f)["cuda"]["version"]
    else:
        raise RuntimeError("Cannot find the cuda version.")
    return cuda_version
34 |
35 |
# Module-level loader for the legacy CUDA extension. Resolution order:
# 1) a prebuilt module installed via setup.py (gsplat.csrc_legacy);
# 2) JIT compilation with torch.utils.cpp_extension.load (requires nvcc);
# 3) otherwise _C stays None and a warning is printed.
_C = None

try:
    # try to import the compiled module (via setup.py)
    from gsplat import csrc_legacy as _C
except ImportError:
    # if failed, try with JIT compilation
    if cuda_toolkit_available():
        name = "gsplat_cuda_legacy"
        build_dir = _get_build_directory(name, verbose=False)
        # Headers shared with the non-legacy extension live in ../cuda/csrc.
        extra_include_paths = [os.path.join(PATH, "..", "cuda/", "csrc/")]
        extra_cflags = ["-O3"]
        extra_cuda_cflags = ["-O3"]
        # All .cu and .cpp sources under the legacy csrc directory.
        sources = list(glob.glob(os.path.join(PATH, "csrc/*.cu"))) + list(
            glob.glob(os.path.join(PATH, "csrc/*.cpp"))
        )

        # If JIT is interrupted it might leave a lock in the build directory.
        # We don't want it to exist in any case.
        try:
            os.remove(os.path.join(build_dir, "lock"))
        except OSError:
            pass

        # .so on Linux/macOS, .lib on Windows.
        if os.path.exists(
            os.path.join(build_dir, "gsplat_cuda_legacy.so")
        ) or os.path.exists(os.path.join(build_dir, "gsplat_cuda_legacy.lib")):
            # If the build exists, we assume the extension has been built
            # and we can load it.

            _C = load(
                name=name,
                sources=sources,
                extra_cflags=extra_cflags,
                extra_cuda_cflags=extra_cuda_cflags,
                extra_include_paths=extra_include_paths,
            )
        else:
            # Build from scratch. Remove the build directory just to be safe: pytorch jit might stuck
            # if the build directory exists with a lock file in it.
            shutil.rmtree(build_dir)
            with Console().status(
                "[bold yellow]gsplat (legacy): Setting up CUDA (This may take a few minutes the first time)",
                spinner="bouncingBall",
            ):
                _C = load(
                    name=name,
                    sources=sources,
                    extra_cflags=extra_cflags,
                    extra_cuda_cflags=extra_cuda_cflags,
                    extra_include_paths=extra_include_paths,
                )
    else:
        Console().print(
            "[yellow]gsplat (legacy): No CUDA toolkit found. gsplat will be disabled.[/yellow]"
        )


__all__ = ["_C"]
95 |
--------------------------------------------------------------------------------
/gsplat/cuda_legacy/csrc/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.12) # You can adjust the minimum required version
set(CMAKE_CUDA_ARCHITECTURES 70 75 89) # Ti 2080 uses 75. V100 uses 70. RTX 4090 uses 89.

project(gsplat CXX CUDA)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_CUDA_STANDARD 17)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")

# our library
add_library(gsplat forward.cu backward.cu helpers.cuh)
target_link_libraries(gsplat PUBLIC cuda)
target_include_directories(gsplat PRIVATE
    ${PROJECT_SOURCE_DIR}/third_party/glm
    ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
)
# Keep in sync with CMAKE_CUDA_ARCHITECTURES above. This previously said
# "70;75;86", contradicting both line 2 and its comment (RTX 4090 is 89).
set_target_properties(gsplat PROPERTIES CUDA_ARCHITECTURES "70;75;89")

# # To add an executable that uses the gsplat library,
# # follow example in the comments for a script `run_forward.cpp`
# # Add the executable
# add_executable(run_forward run_forward.cpp)

# # Link against CUDA runtime library
# target_link_libraries(run_forward PUBLIC cuda gsplat)

# # Include directories for the header-only library
# target_include_directories(run_forward PRIVATE
#     ${PROJECT_SOURCE_DIR}/third_party/glm
# )
--------------------------------------------------------------------------------
/gsplat/cuda_legacy/csrc/backward.cuh:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 |
5 | // for f : R(n) -> R(m), J in R(m, n),
6 | // v is cotangent in R(m), e.g. dL/df in R(m),
7 | // compute vjp i.e. vT J -> R(n)
8 | __global__ void project_gaussians_backward_kernel(
9 | const int num_points,
10 | const float3* __restrict__ means3d,
11 | const float3* __restrict__ scales,
12 | const float glob_scale,
13 | const float4* __restrict__ quats,
14 | const float* __restrict__ viewmat,
15 | const float4 intrins,
16 | const dim3 img_size,
17 | const float* __restrict__ cov3d,
18 | const int* __restrict__ radii,
19 | const float3* __restrict__ conics,
20 | const float* __restrict__ compensation,
21 | const float2* __restrict__ v_xy,
22 | const float* __restrict__ v_depth,
23 | const float3* __restrict__ v_conic,
24 | const float* __restrict__ v_compensation,
25 | float3* __restrict__ v_cov2d,
26 | float* __restrict__ v_cov3d,
27 | float3* __restrict__ v_mean3d,
28 | float3* __restrict__ v_scale,
29 | float4* __restrict__ v_quat
30 | );
31 |
32 | // compute jacobians of output image wrt binned and sorted gaussians
33 | __global__ void nd_rasterize_backward_kernel(
34 | const dim3 tile_bounds,
35 | const dim3 img_size,
36 | const unsigned channels,
37 | const int32_t* __restrict__ gaussians_ids_sorted,
38 | const int2* __restrict__ tile_bins,
39 | const float2* __restrict__ xys,
40 | const float3* __restrict__ conics,
41 | const float* __restrict__ rgbs,
42 | const float* __restrict__ opacities,
43 | const float* __restrict__ background,
44 | const float* __restrict__ final_Ts,
45 | const int* __restrict__ final_index,
46 | const float* __restrict__ v_output,
47 | const float* __restrict__ v_output_alpha,
48 | float2* __restrict__ v_xy,
49 | float2* __restrict__ v_xy_abs,
50 | float3* __restrict__ v_conic,
51 | float* __restrict__ v_rgb,
52 | float* __restrict__ v_opacity
53 | );
54 |
55 | __global__ void rasterize_backward_kernel(
56 | const dim3 tile_bounds,
57 | const dim3 img_size,
58 | const int32_t* __restrict__ gaussian_ids_sorted,
59 | const int2* __restrict__ tile_bins,
60 | const float2* __restrict__ xys,
61 | const float3* __restrict__ conics,
62 | const float3* __restrict__ rgbs,
63 | const float* __restrict__ opacities,
64 | const float3& __restrict__ background,
65 | const float* __restrict__ final_Ts,
66 | const int* __restrict__ final_index,
67 | const float3* __restrict__ v_output,
68 | const float* __restrict__ v_output_alpha,
69 | float2* __restrict__ v_xy,
70 | float2* __restrict__ v_xy_abs,
71 | float3* __restrict__ v_conic,
72 | float3* __restrict__ v_rgb,
73 | float* __restrict__ v_opacity
74 | );
75 |
76 | __device__ void project_cov3d_ewa_vjp(
77 | const float3 &mean3d,
78 | const float *cov3d,
79 | const float *viewmat,
80 | const float fx,
81 | const float fy,
82 | const float3 &v_cov2d,
83 | float3 &v_mean3d,
84 | float *v_cov3d
85 | );
86 |
87 | __device__ void scale_rot_to_cov3d_vjp(
88 | const float3 scale,
89 | const float glob_scale,
90 | const float4 quat,
91 | const float *v_cov3d,
92 | float3 &v_scale,
93 | float4 &v_quat
94 | );
95 |
--------------------------------------------------------------------------------
/gsplat/cuda_legacy/csrc/bindings.h:
--------------------------------------------------------------------------------
1 | #include "cuda_runtime.h"
2 | #include "forward.cuh"
3 | #include
4 | #include
5 | #include
6 | #include
7 | #include
8 | #include
9 |
10 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
11 | #define CHECK_CONTIGUOUS(x) \
12 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
13 | #define CHECK_INPUT(x) \
14 | CHECK_CUDA(x); \
15 | CHECK_CONTIGUOUS(x)
16 | #define DEVICE_GUARD(_ten) \
17 | const at::cuda::OptionalCUDAGuard device_guard(device_of(_ten));
18 |
19 | std::tuple<
20 | torch::Tensor, // output conics
21 | torch::Tensor> // output radii
22 | compute_cov2d_bounds_tensor(const int num_pts, torch::Tensor &A);
23 |
24 | torch::Tensor compute_sh_forward_tensor(
25 | const std::string &method,
26 | unsigned num_points,
27 | unsigned degree,
28 | unsigned degrees_to_use,
29 | torch::Tensor &viewdirs,
30 | torch::Tensor &coeffs
31 | );
32 |
33 | torch::Tensor compute_sh_backward_tensor(
34 | const std::string &method,
35 | unsigned num_points,
36 | unsigned degree,
37 | unsigned degrees_to_use,
38 | torch::Tensor &viewdirs,
39 | torch::Tensor &v_colors
40 | );
41 |
42 | std::tuple<
43 | torch::Tensor,
44 | torch::Tensor,
45 | torch::Tensor,
46 | torch::Tensor,
47 | torch::Tensor,
48 | torch::Tensor,
49 | torch::Tensor>
50 | project_gaussians_forward_tensor(
51 | const int num_points,
52 | torch::Tensor &means3d,
53 | torch::Tensor &scales,
54 | const float glob_scale,
55 | torch::Tensor &quats,
56 | torch::Tensor &viewmat,
57 | const float fx,
58 | const float fy,
59 | const float cx,
60 | const float cy,
61 | const unsigned img_height,
62 | const unsigned img_width,
63 | const unsigned block_width,
64 | const float clip_thresh
65 | );
66 |
67 | std::tuple<
68 | torch::Tensor,
69 | torch::Tensor,
70 | torch::Tensor,
71 | torch::Tensor,
72 | torch::Tensor>
73 | project_gaussians_backward_tensor(
74 | const int num_points,
75 | torch::Tensor &means3d,
76 | torch::Tensor &scales,
77 | const float glob_scale,
78 | torch::Tensor &quats,
79 | torch::Tensor &viewmat,
80 | const float fx,
81 | const float fy,
82 | const float cx,
83 | const float cy,
84 | const unsigned img_height,
85 | const unsigned img_width,
86 | torch::Tensor &cov3d,
87 | torch::Tensor &radii,
88 | torch::Tensor &conics,
89 | torch::Tensor &compensation,
90 | torch::Tensor &v_xy,
91 | torch::Tensor &v_depth,
92 | torch::Tensor &v_conic,
93 | torch::Tensor &v_compensation
94 | );
95 |
96 |
97 | std::tuple map_gaussian_to_intersects_tensor(
98 | const int num_points,
99 | const int num_intersects,
100 | const torch::Tensor &xys,
101 | const torch::Tensor &depths,
102 | const torch::Tensor &radii,
103 | const torch::Tensor &cum_tiles_hit,
104 | const std::tuple tile_bounds,
105 | const unsigned block_width
106 | );
107 |
108 | torch::Tensor get_tile_bin_edges_tensor(
109 | int num_intersects,
110 | const torch::Tensor &isect_ids_sorted,
111 | const std::tuple tile_bounds
112 | );
113 |
114 | std::tuple<
115 | torch::Tensor,
116 | torch::Tensor,
117 | torch::Tensor
118 | > rasterize_forward_tensor(
119 | const std::tuple tile_bounds,
120 | const std::tuple block,
121 | const std::tuple img_size,
122 | const torch::Tensor &gaussian_ids_sorted,
123 | const torch::Tensor &tile_bins,
124 | const torch::Tensor &xys,
125 | const torch::Tensor &conics,
126 | const torch::Tensor &colors,
127 | const torch::Tensor &opacities,
128 | const torch::Tensor &background
129 | );
130 |
131 | std::tuple<
132 | torch::Tensor,
133 | torch::Tensor,
134 | torch::Tensor
135 | > nd_rasterize_forward_tensor(
136 | const std::tuple tile_bounds,
137 | const std::tuple block,
138 | const std::tuple img_size,
139 | const torch::Tensor &gaussian_ids_sorted,
140 | const torch::Tensor &tile_bins,
141 | const torch::Tensor &xys,
142 | const torch::Tensor &conics,
143 | const torch::Tensor &colors,
144 | const torch::Tensor &opacities,
145 | const torch::Tensor &background
146 | );
147 |
148 |
149 | std::
150 | tuple<
151 | torch::Tensor, // dL_dxy
152 | torch::Tensor, // dL_dxy_abs
153 | torch::Tensor, // dL_dconic
154 | torch::Tensor, // dL_dcolors
155 | torch::Tensor // dL_dopacity
156 | >
157 | nd_rasterize_backward_tensor(
158 | const unsigned img_height,
159 | const unsigned img_width,
160 | const unsigned block_width,
161 | const torch::Tensor &gaussians_ids_sorted,
162 | const torch::Tensor &tile_bins,
163 | const torch::Tensor &xys,
164 | const torch::Tensor &conics,
165 | const torch::Tensor &colors,
166 | const torch::Tensor &opacities,
167 | const torch::Tensor &background,
168 | const torch::Tensor &final_Ts,
169 | const torch::Tensor &final_idx,
170 | const torch::Tensor &v_output, // dL_dout_color
171 | const torch::Tensor &v_output_alpha
172 | );
173 |
174 | std::
175 | tuple<
176 | torch::Tensor, // dL_dxy
177 | torch::Tensor, // dL_dxy_abs
178 | torch::Tensor, // dL_dconic
179 | torch::Tensor, // dL_dcolors
180 | torch::Tensor // dL_dopacity
181 | >
182 | rasterize_backward_tensor(
183 | const unsigned img_height,
184 | const unsigned img_width,
185 | const unsigned block_width,
186 | const torch::Tensor &gaussians_ids_sorted,
187 | const torch::Tensor &tile_bins,
188 | const torch::Tensor &xys,
189 | const torch::Tensor &conics,
190 | const torch::Tensor &colors,
191 | const torch::Tensor &opacities,
192 | const torch::Tensor &background,
193 | const torch::Tensor &final_Ts,
194 | const torch::Tensor &final_idx,
195 | const torch::Tensor &v_output, // dL_dout_color
196 | const torch::Tensor &v_output_alpha
197 | );
198 |
--------------------------------------------------------------------------------
/gsplat/cuda_legacy/csrc/config.h:
--------------------------------------------------------------------------------
/* Compile-time configuration for the legacy CUDA kernels. */

/* 16 * 16 = 256 threads; NOTE(review): presumably the per-tile block size
 * used by the rasterizer — confirm against forward.cu/backward.cu. */
#define MAX_BLOCK_SIZE ( 16 * 16 )
/* Default 1-D launch width for kernels. */
#define N_THREADS 256

/* NOTE(review): presumably the number of color channels kept in registers
 * before spilling — confirm against nd_rasterize kernels. */
#define MAX_REGISTER_CHANNELS 3

/* Print file/line and the last CUDA error, then abort, when a CUDA runtime
 * call does not return cudaSuccess. */
#define CUDA_CALL(x) \
    do { \
        if ((x) != cudaSuccess) { \
            printf( \
                "Error at %s:%d - %s\n", \
                __FILE__, \
                __LINE__, \
                cudaGetErrorString(cudaGetLastError()) \
            ); \
            exit(EXIT_FAILURE); \
        } \
    } while (0)
18 |
--------------------------------------------------------------------------------
/gsplat/cuda_legacy/csrc/ext.cpp:
--------------------------------------------------------------------------------
1 | #include "bindings.h"
2 | #include
3 |
4 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
5 | // auto diff functions
6 | m.def("nd_rasterize_forward", &nd_rasterize_forward_tensor);
7 | m.def("nd_rasterize_backward", &nd_rasterize_backward_tensor);
8 | m.def("rasterize_forward", &rasterize_forward_tensor);
9 | m.def("rasterize_backward", &rasterize_backward_tensor);
10 | m.def("project_gaussians_forward", &project_gaussians_forward_tensor);
11 | m.def("project_gaussians_backward", &project_gaussians_backward_tensor);
12 | m.def("compute_sh_forward", &compute_sh_forward_tensor);
13 | m.def("compute_sh_backward", &compute_sh_backward_tensor);
14 | // utils
15 | m.def("compute_cov2d_bounds", &compute_cov2d_bounds_tensor);
16 | m.def("map_gaussian_to_intersects", &map_gaussian_to_intersects_tensor);
17 | m.def("get_tile_bin_edges", &get_tile_bin_edges_tensor);
18 | }
19 |
--------------------------------------------------------------------------------
/gsplat/cuda_legacy/csrc/forward.cuh:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 |
5 | // compute the 2d gaussian parameters from 3d gaussian parameters
6 | __global__ void project_gaussians_forward_kernel(
7 | const int num_points,
8 | const float3* __restrict__ means3d,
9 | const float3* __restrict__ scales,
10 | const float glob_scale,
11 | const float4* __restrict__ quats,
12 | const float* __restrict__ viewmat,
13 | const float4 intrins,
14 | const dim3 img_size,
15 | const dim3 tile_bounds,
16 | const unsigned block_width,
17 | const float clip_thresh,
18 | float* __restrict__ covs3d,
19 | float2* __restrict__ xys,
20 | float* __restrict__ depths,
21 | int* __restrict__ radii,
22 | float3* __restrict__ conics,
23 | float* __restrict__ compensation,
24 | int32_t* __restrict__ num_tiles_hit
25 | );
26 |
27 | // compute output color image from binned and sorted gaussians
28 | __global__ void rasterize_forward(
29 | const dim3 tile_bounds,
30 | const dim3 img_size,
31 | const int32_t* __restrict__ gaussian_ids_sorted,
32 | const int2* __restrict__ tile_bins,
33 | const float2* __restrict__ xys,
34 | const float3* __restrict__ conics,
35 | const float3* __restrict__ colors,
36 | const float* __restrict__ opacities,
37 | float* __restrict__ final_Ts,
38 | int* __restrict__ final_index,
39 | float3* __restrict__ out_img,
40 | const float3& __restrict__ background
41 | );
42 |
43 | // compute output color image from binned and sorted gaussians
44 | __global__ void nd_rasterize_forward(
45 | const dim3 tile_bounds,
46 | const dim3 img_size,
47 | const unsigned channels,
48 | const int32_t* __restrict__ gaussian_ids_sorted,
49 | const int2* __restrict__ tile_bins,
50 | const float2* __restrict__ xys,
51 | const float3* __restrict__ conics,
52 | const float* __restrict__ colors,
53 | const float* __restrict__ opacities,
54 | float* __restrict__ final_Ts,
55 | int* __restrict__ final_index,
56 | float* __restrict__ out_img,
57 | const float* __restrict__ background
58 | );
59 |
60 | // device helper to approximate projected 2d cov from 3d mean and cov
61 | __device__ void project_cov3d_ewa(
62 | const float3 &mean3d,
63 | const float *cov3d,
64 | const float *viewmat,
65 | const float fx,
66 | const float fy,
67 | const float tan_fovx,
68 | const float tan_fovy,
69 | float3 &cov2d,
70 | float &comp
71 | );
72 |
73 | // device helper to get 3D covariance from scale and quat parameters
74 | __device__ void scale_rot_to_cov3d(
75 | const float3 scale, const float glob_scale, const float4 quat, float *cov3d
76 | );
77 |
78 | __global__ void map_gaussian_to_intersects(
79 | const int num_points,
80 | const float2* __restrict__ xys,
81 | const float* __restrict__ depths,
82 | const int* __restrict__ radii,
83 | const int32_t* __restrict__ cum_tiles_hit,
84 | const dim3 tile_bounds,
85 | const unsigned block_width,
86 | int64_t* __restrict__ isect_ids,
87 | int32_t* __restrict__ gaussian_ids
88 | );
89 |
90 | __global__ void get_tile_bin_edges(
91 | const int num_intersects, const int64_t* __restrict__ isect_ids_sorted, int2* __restrict__ tile_bins
92 | );
93 |
94 | __global__ void rasterize_forward(
95 | const dim3 tile_bounds,
96 | const dim3 img_size,
97 | const int32_t* __restrict__ gaussian_ids_sorted,
98 | const int2* __restrict__ tile_bins,
99 | const float2* __restrict__ xys,
100 | const float3* __restrict__ conics,
101 | const float3* __restrict__ colors,
102 | const float* __restrict__ opacities,
103 | float* __restrict__ final_Ts,
104 | int* __restrict__ final_index,
105 | float3* __restrict__ out_img,
106 | const float3& __restrict__ background
107 | );
108 |
109 | __global__ void nd_rasterize_forward(
110 | const dim3 tile_bounds,
111 | const dim3 img_size,
112 | const unsigned channels,
113 | const int32_t* __restrict__ gaussian_ids_sorted,
114 | const int2* __restrict__ tile_bins,
115 | const float2* __restrict__ xys,
116 | const float3* __restrict__ conics,
117 | const float* __restrict__ colors,
118 | const float* __restrict__ opacities,
119 | float* __restrict__ final_Ts,
120 | int* __restrict__ final_index,
121 | float* __restrict__ out_img,
122 | const float* __restrict__ background
123 | );
124 |
--------------------------------------------------------------------------------
/gsplat/cuda_legacy/csrc/helpers.cuh:
--------------------------------------------------------------------------------
1 | #include "config.h"
2 | #include "third_party/glm/glm/glm.hpp"
3 | #include "third_party/glm/glm/gtc/type_ptr.hpp"
4 | #include
5 | #include
6 |
7 | inline __device__ void get_bbox(const float2 center, const float2 dims,
8 | const dim3 img_size, uint2 &bb_min, uint2 &bb_max) {
9 | // get bounding box with center and dims, within bounds
10 | // bounding box coords returned in tile coords, inclusive min, exclusive max
11 | // clamp between 0 and tile bounds
12 | bb_min.x = min(max(0, (int)(center.x - dims.x)), img_size.x);
13 | bb_max.x = min(max(0, (int)(center.x + dims.x + 1)), img_size.x);
14 | bb_min.y = min(max(0, (int)(center.y - dims.y)), img_size.y);
15 | bb_max.y = min(max(0, (int)(center.y + dims.y + 1)), img_size.y);
16 | }
17 |
18 | inline __device__ void get_tile_bbox(const float2 pix_center, const float pix_radius,
19 | const dim3 tile_bounds, uint2 &tile_min,
20 | uint2 &tile_max, const int block_size) {
21 | // gets gaussian dimensions in tile space, i.e. the span of a gaussian in
22 | // tile_grid (image divided into tiles)
23 | float2 tile_center = {pix_center.x / (float)block_size,
24 | pix_center.y / (float)block_size};
25 | float2 tile_radius = {pix_radius / (float)block_size,
26 | pix_radius / (float)block_size};
27 | get_bbox(tile_center, tile_radius, tile_bounds, tile_min, tile_max);
28 | }
29 |
30 | inline __device__ bool compute_cov2d_bounds(const float3 cov2d, float3 &conic,
31 | float &radius) {
32 | // find eigenvalues of 2d covariance matrix
33 | // expects upper triangular values of cov matrix as float3
34 | // then compute the radius and conic dimensions
35 | // the conic is the inverse cov2d matrix, represented here with upper
36 | // triangular values.
37 | float det = cov2d.x * cov2d.z - cov2d.y * cov2d.y;
38 | if (det == 0.f)
39 | return false;
40 | float inv_det = 1.f / det;
41 |
42 | // inverse of 2x2 cov2d matrix
43 | conic.x = cov2d.z * inv_det;
44 | conic.y = -cov2d.y * inv_det;
45 | conic.z = cov2d.x * inv_det;
46 |
47 | float b = 0.5f * (cov2d.x + cov2d.z);
48 | float v1 = b + sqrt(max(0.1f, b * b - det));
49 | float v2 = b - sqrt(max(0.1f, b * b - det));
50 | // take 3 sigma of covariance
51 | radius = ceil(3.f * sqrt(max(v1, v2)));
52 | return true;
53 | }
54 |
55 | // compute vjp from df/d_conic to df/c_cov2d
56 | inline __device__ void cov2d_to_conic_vjp(const float3 &conic, const float3 &v_conic,
57 | float3 &v_cov2d) {
58 | // conic = inverse cov2d
59 | // df/d_cov2d = -conic * df/d_conic * conic
60 | glm::mat2 X = glm::mat2(conic.x, conic.y, conic.y, conic.z);
61 | glm::mat2 G = glm::mat2(v_conic.x, v_conic.y / 2.f, v_conic.y / 2.f, v_conic.z);
62 | glm::mat2 v_Sigma = -X * G * X;
63 | v_cov2d.x = v_Sigma[0][0];
64 | v_cov2d.y = v_Sigma[1][0] + v_Sigma[0][1];
65 | v_cov2d.z = v_Sigma[1][1];
66 | }
67 |
68 | inline __device__ void cov2d_to_compensation_vjp(const float compensation,
69 | const float3 &conic,
70 | const float v_compensation,
71 | float3 &v_cov2d) {
72 | // comp = sqrt(det(cov2d - 0.3 I) / det(cov2d))
73 | // conic = inverse(cov2d)
74 | // df / d_cov2d = df / d comp * 0.5 / comp * [ d comp^2 / d cov2d ]
75 | // d comp^2 / d cov2d = (1 - comp^2) * conic - 0.3 I * det(conic)
76 | float inv_det = conic.x * conic.z - conic.y * conic.y;
77 | float one_minus_sqr_comp = 1 - compensation * compensation;
78 | float v_sqr_comp = v_compensation * 0.5 / (compensation + 1e-6);
79 | v_cov2d.x += v_sqr_comp * (one_minus_sqr_comp * conic.x - 0.3 * inv_det);
80 | v_cov2d.y += 2 * v_sqr_comp * (one_minus_sqr_comp * conic.y);
81 | v_cov2d.z += v_sqr_comp * (one_minus_sqr_comp * conic.z - 0.3 * inv_det);
82 | }
83 |
84 | // helper for applying R^T * p for a ROW MAJOR 4x3 matrix [R, t], ignoring t
85 | inline __device__ float3 transform_4x3_rot_only_transposed(const float *mat,
86 | const float3 p) {
87 | float3 out = {
88 | mat[0] * p.x + mat[4] * p.y + mat[8] * p.z,
89 | mat[1] * p.x + mat[5] * p.y + mat[9] * p.z,
90 | mat[2] * p.x + mat[6] * p.y + mat[10] * p.z,
91 | };
92 | return out;
93 | }
94 |
95 | // helper for applying R * p + T, expect mat to be ROW MAJOR
96 | inline __device__ float3 transform_4x3(const float *mat, const float3 p) {
97 | float3 out = {
98 | mat[0] * p.x + mat[1] * p.y + mat[2] * p.z + mat[3],
99 | mat[4] * p.x + mat[5] * p.y + mat[6] * p.z + mat[7],
100 | mat[8] * p.x + mat[9] * p.y + mat[10] * p.z + mat[11],
101 | };
102 | return out;
103 | }
104 |
105 | // helper to apply 4x4 transform to 3d vector, return homo coords
106 | // expects mat to be ROW MAJOR
107 | inline __device__ float4 transform_4x4(const float *mat, const float3 p) {
108 | float4 out = {
109 | mat[0] * p.x + mat[1] * p.y + mat[2] * p.z + mat[3],
110 | mat[4] * p.x + mat[5] * p.y + mat[6] * p.z + mat[7],
111 | mat[8] * p.x + mat[9] * p.y + mat[10] * p.z + mat[11],
112 | mat[12] * p.x + mat[13] * p.y + mat[14] * p.z + mat[15],
113 | };
114 | return out;
115 | }
116 |
117 | inline __device__ float2 project_pix(const float2 fxfy, const float3 p_view,
118 | const float2 pp) {
119 | float rw = 1.f / (p_view.z + 1e-6f);
120 | float2 p_proj = {p_view.x * rw, p_view.y * rw};
121 | float2 p_pix = {p_proj.x * fxfy.x + pp.x, p_proj.y * fxfy.y + pp.y};
122 | return p_pix;
123 | }
124 |
125 | // given v_xy_pix, get v_xyz
126 | inline __device__ float3 project_pix_vjp(const float2 fxfy, const float3 p_view,
127 | const float2 v_xy) {
128 | float rw = 1.f / (p_view.z + 1e-6f);
129 | float2 v_proj = {fxfy.x * v_xy.x, fxfy.y * v_xy.y};
130 | float3 v_view = {v_proj.x * rw, v_proj.y * rw,
131 | -(v_proj.x * p_view.x + v_proj.y * p_view.y) * rw * rw};
132 | return v_view;
133 | }
134 |
135 | inline __device__ glm::mat3 quat_to_rotmat(const float4 quat) {
136 | // quat to rotation matrix
137 | float w = quat.x;
138 | float x = quat.y;
139 | float y = quat.z;
140 | float z = quat.w;
141 |
142 | // glm matrices are column-major
143 | return glm::mat3(
144 | 1.f - 2.f * (y * y + z * z), 2.f * (x * y + w * z), 2.f * (x * z - w * y),
145 | 2.f * (x * y - w * z), 1.f - 2.f * (x * x + z * z), 2.f * (y * z + w * x),
146 | 2.f * (x * z + w * y), 2.f * (y * z - w * x), 1.f - 2.f * (x * x + y * y));
147 | }
148 |
149 | inline __device__ float4 quat_to_rotmat_vjp(const float4 quat, const glm::mat3 v_R) {
150 | float w = quat.x;
151 | float x = quat.y;
152 | float y = quat.z;
153 | float z = quat.w;
154 |
155 | float4 v_quat;
156 | // v_R is COLUMN MAJOR
157 | // w element stored in x field
158 | v_quat.x = 2.f * (
159 | // v_quat.w = 2.f * (
160 | x * (v_R[1][2] - v_R[2][1]) + y * (v_R[2][0] - v_R[0][2]) +
161 | z * (v_R[0][1] - v_R[1][0]));
162 | // x element in y field
163 | v_quat.y =
164 | 2.f * (
165 | // v_quat.x = 2.f * (
166 | -2.f * x * (v_R[1][1] + v_R[2][2]) + y * (v_R[0][1] + v_R[1][0]) +
167 | z * (v_R[0][2] + v_R[2][0]) + w * (v_R[1][2] - v_R[2][1]));
168 | // y element in z field
169 | v_quat.z =
170 | 2.f * (
171 | // v_quat.y = 2.f * (
172 | x * (v_R[0][1] + v_R[1][0]) - 2.f * y * (v_R[0][0] + v_R[2][2]) +
173 | z * (v_R[1][2] + v_R[2][1]) + w * (v_R[2][0] - v_R[0][2]));
174 | // z element in w field
175 | v_quat.w =
176 | 2.f * (
177 | // v_quat.z = 2.f * (
178 | x * (v_R[0][2] + v_R[2][0]) + y * (v_R[1][2] + v_R[2][1]) -
179 | 2.f * z * (v_R[0][0] + v_R[1][1]) + w * (v_R[0][1] - v_R[1][0]));
180 | return v_quat;
181 | }
182 |
183 | inline __device__ glm::mat3 scale_to_mat(const float3 scale, const float glob_scale) {
184 | glm::mat3 S = glm::mat3(1.f);
185 | S[0][0] = glob_scale * scale.x;
186 | S[1][1] = glob_scale * scale.y;
187 | S[2][2] = glob_scale * scale.z;
188 | return S;
189 | }
190 |
191 | // device helper for culling near points
192 | inline __device__ bool clip_near_plane(const float3 p, const float *viewmat,
193 | float3 &p_view, float thresh) {
194 | p_view = transform_4x3(viewmat, p);
195 | if (p_view.z <= thresh) {
196 | return true;
197 | }
198 | return false;
199 | }
200 |
--------------------------------------------------------------------------------
/gsplat/cuda_legacy/csrc/sh.cuh:
--------------------------------------------------------------------------------
1 | #include <cooperative_groups.h>
2 | #include <cuda_runtime.h>
3 | #define CHANNELS 3
4 | namespace cg = cooperative_groups;
5 |
6 | enum class SHType {
7 | Poly,
8 | Fast,
9 | };
10 |
11 | __device__ __constant__ float SH_C0 = 0.28209479177387814f;
12 | __device__ __constant__ float SH_C1 = 0.4886025119029199f;
13 | __device__ __constant__ float SH_C2[] = {
14 | 1.0925484305920792f,
15 | -1.0925484305920792f,
16 | 0.31539156525252005f,
17 | -1.0925484305920792f,
18 | 0.5462742152960396f};
19 | __device__ __constant__ float SH_C3[] = {
20 | -0.5900435899266435f,
21 | 2.890611442640554f,
22 | -0.4570457994644658f,
23 | 0.3731763325901154f,
24 | -0.4570457994644658f,
25 | 1.445305721320277f,
26 | -0.5900435899266435f};
27 | __device__ __constant__ float SH_C4[] = {
28 | 2.5033429417967046f,
29 | -1.7701307697799304,
30 | 0.9461746957575601f,
31 | -0.6690465435572892f,
32 | 0.10578554691520431f,
33 | -0.6690465435572892f,
34 | 0.47308734787878004f,
35 | -1.7701307697799304f,
36 | 0.6258357354491761f};
37 |
38 | // This function is used in both host and device code
39 | __host__ __device__ unsigned num_sh_bases(const unsigned degree) {
40 | if (degree == 0)
41 | return 1;
42 | if (degree == 1)
43 | return 4;
44 | if (degree == 2)
45 | return 9;
46 | if (degree == 3)
47 | return 16;
48 | return 25;
49 | }
50 |
51 | // Evaluate spherical harmonics bases at unit direction for high orders using approach described by
52 | // Efficient Spherical Harmonic Evaluation, Peter-Pike Sloan, JCGT 2013
53 | // See https://jcgt.org/published/0002/02/06/ for reference implementation
54 | __device__ void sh_coeffs_to_color_fast(
55 | const unsigned degree,
56 | const float3 &viewdir,
57 | const float *coeffs,
58 | float *colors
59 | ) {
60 | for (int c = 0; c < CHANNELS; ++c) {
61 | colors[c] = 0.2820947917738781f * coeffs[c];
62 | }
63 | if (degree < 1) {
64 | return;
65 | }
66 |
67 | float norm = sqrt(
68 | viewdir.x * viewdir.x + viewdir.y * viewdir.y + viewdir.z * viewdir.z
69 | );
70 | float x = viewdir.x / norm;
71 | float y = viewdir.y / norm;
72 | float z = viewdir.z / norm;
73 |
74 | for (int c = 0; c < CHANNELS; ++c) {
75 | float fTmp0A = 0.48860251190292f;
76 | colors[c] += fTmp0A *
77 | (-y * coeffs[1 * CHANNELS + c] +
78 | z * coeffs[2 * CHANNELS + c] -
79 | x * coeffs[3 * CHANNELS + c]);
80 | }
81 | if (degree < 2) {
82 | return;
83 | }
84 | float z2 = z * z;
85 |
86 | float fTmp0B = -1.092548430592079f * z;
87 | float fTmp1A = 0.5462742152960395f;
88 | float fC1 = x * x - y * y;
89 | float fS1 = 2.f * x * y;
90 | float pSH6 = (0.9461746957575601f * z2 - 0.3153915652525201f);
91 | float pSH7 = fTmp0B * x;
92 | float pSH5 = fTmp0B * y;
93 | float pSH8 = fTmp1A * fC1;
94 | float pSH4 = fTmp1A * fS1;
95 | for (int c = 0; c < CHANNELS; ++c) {
96 | colors[c] +=
97 | pSH4 * coeffs[4 * CHANNELS + c] + pSH5 * coeffs[5 * CHANNELS + c] +
98 | pSH6 * coeffs[6 * CHANNELS + c] + pSH7 * coeffs[7 * CHANNELS + c] +
99 | pSH8 * coeffs[8 * CHANNELS + c];
100 | }
101 | if (degree < 3) {
102 | return;
103 | }
104 |
105 | float fTmp0C = -2.285228997322329f * z2 + 0.4570457994644658f;
106 | float fTmp1B = 1.445305721320277f * z;
107 | float fTmp2A = -0.5900435899266435f;
108 | float fC2 = x * fC1 - y * fS1;
109 | float fS2 = x * fS1 + y * fC1;
110 | float pSH12 = z * (1.865881662950577f * z2 - 1.119528997770346f);
111 | float pSH13 = fTmp0C * x;
112 | float pSH11 = fTmp0C * y;
113 | float pSH14 = fTmp1B * fC1;
114 | float pSH10 = fTmp1B * fS1;
115 | float pSH15 = fTmp2A * fC2;
116 | float pSH9 = fTmp2A * fS2;
117 | for (int c = 0; c < CHANNELS; ++c) {
118 | colors[c] += pSH9 * coeffs[9 * CHANNELS + c] +
119 | pSH10 * coeffs[10 * CHANNELS + c] +
120 | pSH11 * coeffs[11 * CHANNELS + c] +
121 | pSH12 * coeffs[12 * CHANNELS + c] +
122 | pSH13 * coeffs[13 * CHANNELS + c] +
123 | pSH14 * coeffs[14 * CHANNELS + c] +
124 | pSH15 * coeffs[15 * CHANNELS + c];
125 | }
126 | if (degree < 4) {
127 | return;
128 | }
129 |
130 | float fTmp0D = z * (-4.683325804901025f * z2 + 2.007139630671868f);
131 | float fTmp1C = 3.31161143515146f * z2 - 0.47308734787878f;
132 | float fTmp2B = -1.770130769779931f * z;
133 | float fTmp3A = 0.6258357354491763f;
134 | float fC3 = x * fC2 - y * fS2;
135 | float fS3 = x * fS2 + y * fC2;
136 | float pSH20 = (1.984313483298443f * z * pSH12 - 1.006230589874905f * pSH6);
137 | float pSH21 = fTmp0D * x;
138 | float pSH19 = fTmp0D * y;
139 | float pSH22 = fTmp1C * fC1;
140 | float pSH18 = fTmp1C * fS1;
141 | float pSH23 = fTmp2B * fC2;
142 | float pSH17 = fTmp2B * fS2;
143 | float pSH24 = fTmp3A * fC3;
144 | float pSH16 = fTmp3A * fS3;
145 | for (int c = 0; c < CHANNELS; ++c) {
146 | colors[c] += pSH16 * coeffs[16 * CHANNELS + c] +
147 | pSH17 * coeffs[17 * CHANNELS + c] +
148 | pSH18 * coeffs[18 * CHANNELS + c] +
149 | pSH19 * coeffs[19 * CHANNELS + c] +
150 | pSH20 * coeffs[20 * CHANNELS + c] +
151 | pSH21 * coeffs[21 * CHANNELS + c] +
152 | pSH22 * coeffs[22 * CHANNELS + c] +
153 | pSH23 * coeffs[23 * CHANNELS + c] +
154 | pSH24 * coeffs[24 * CHANNELS + c];
155 | }
156 | }
157 |
158 | __device__ void sh_coeffs_to_color_fast_vjp(
159 | const unsigned degree,
160 | const float3 &viewdir,
161 | const float *v_colors,
162 | float *v_coeffs
163 | ) {
164 | // Expects v_colors to be len CHANNELS
165 | // and v_coeffs to be num_bases * CHANNELS
166 | #pragma unroll
167 | for (int c = 0; c < CHANNELS; ++c) {
168 | v_coeffs[c] = 0.2820947917738781f * v_colors[c];
169 | }
170 | if (degree < 1) {
171 | return;
172 | }
173 | float norm = sqrt(
174 | viewdir.x * viewdir.x + viewdir.y * viewdir.y + viewdir.z * viewdir.z
175 | );
176 | float x = viewdir.x / norm;
177 | float y = viewdir.y / norm;
178 | float z = viewdir.z / norm;
179 |
180 |
181 | float fTmp0A = 0.48860251190292f;
182 | #pragma unroll
183 | for (int c = 0; c < CHANNELS; ++c) {
184 | v_coeffs[1 * CHANNELS + c] = -fTmp0A * y * v_colors[c];
185 | v_coeffs[2 * CHANNELS + c] = fTmp0A * z * v_colors[c];
186 | v_coeffs[3 * CHANNELS + c] = -fTmp0A * x * v_colors[c];
187 | }
188 | if (degree < 2) {
189 | return;
190 | }
191 |
192 | float z2 = z * z;
193 | float fTmp0B = -1.092548430592079f * z;
194 | float fTmp1A = 0.5462742152960395f;
195 | float fC1 = x * x - y * y;
196 | float fS1 = 2.f * x * y;
197 | float pSH6 = (0.9461746957575601f * z2 - 0.3153915652525201f);
198 | float pSH7 = fTmp0B * x;
199 | float pSH5 = fTmp0B * y;
200 | float pSH8 = fTmp1A * fC1;
201 | float pSH4 = fTmp1A * fS1;
202 | #pragma unroll
203 | for (int c = 0; c < CHANNELS; ++c) {
204 | v_coeffs[4 * CHANNELS + c] = pSH4 * v_colors[c];
205 | v_coeffs[5 * CHANNELS + c] = pSH5 * v_colors[c];
206 | v_coeffs[6 * CHANNELS + c] = pSH6 * v_colors[c];
207 | v_coeffs[7 * CHANNELS + c] = pSH7 * v_colors[c];
208 | v_coeffs[8 * CHANNELS + c] = pSH8 * v_colors[c];
209 | }
210 | if (degree < 3) {
211 | return;
212 | }
213 |
214 | float fTmp0C = -2.285228997322329f * z2 + 0.4570457994644658f;
215 | float fTmp1B = 1.445305721320277f * z;
216 | float fTmp2A = -0.5900435899266435f;
217 | float fC2 = x * fC1 - y * fS1;
218 | float fS2 = x * fS1 + y * fC1;
219 | float pSH12 = z * (1.865881662950577f * z2 - 1.119528997770346f);
220 | float pSH13 = fTmp0C * x;
221 | float pSH11 = fTmp0C * y;
222 | float pSH14 = fTmp1B * fC1;
223 | float pSH10 = fTmp1B * fS1;
224 | float pSH15 = fTmp2A * fC2;
225 | float pSH9 = fTmp2A * fS2;
226 | #pragma unroll
227 | for (int c = 0; c < CHANNELS; ++c) {
228 | v_coeffs[9 * CHANNELS + c] = pSH9 * v_colors[c];
229 | v_coeffs[10 * CHANNELS + c] = pSH10 * v_colors[c];
230 | v_coeffs[11 * CHANNELS + c] = pSH11 * v_colors[c];
231 | v_coeffs[12 * CHANNELS + c] = pSH12 * v_colors[c];
232 | v_coeffs[13 * CHANNELS + c] = pSH13 * v_colors[c];
233 | v_coeffs[14 * CHANNELS + c] = pSH14 * v_colors[c];
234 | v_coeffs[15 * CHANNELS + c] = pSH15 * v_colors[c];
235 | }
236 | if (degree < 4) {
237 | return;
238 | }
239 |
240 | float fTmp0D = z * (-4.683325804901025f * z2 + 2.007139630671868f);
241 | float fTmp1C = 3.31161143515146f * z2 - 0.47308734787878f;
242 | float fTmp2B = -1.770130769779931f * z;
243 | float fTmp3A = 0.6258357354491763f;
244 | float fC3 = x * fC2 - y * fS2;
245 | float fS3 = x * fS2 + y * fC2;
246 | float pSH20 = (1.984313483298443f * z * pSH12 + -1.006230589874905f * pSH6);
247 | float pSH21 = fTmp0D * x;
248 | float pSH19 = fTmp0D * y;
249 | float pSH22 = fTmp1C * fC1;
250 | float pSH18 = fTmp1C * fS1;
251 | float pSH23 = fTmp2B * fC2;
252 | float pSH17 = fTmp2B * fS2;
253 | float pSH24 = fTmp3A * fC3;
254 | float pSH16 = fTmp3A * fS3;
255 | #pragma unroll
256 | for (int c = 0; c < CHANNELS; ++c) {
257 | v_coeffs[16 * CHANNELS + c] = pSH16 * v_colors[c];
258 | v_coeffs[17 * CHANNELS + c] = pSH17 * v_colors[c];
259 | v_coeffs[18 * CHANNELS + c] = pSH18 * v_colors[c];
260 | v_coeffs[19 * CHANNELS + c] = pSH19 * v_colors[c];
261 | v_coeffs[20 * CHANNELS + c] = pSH20 * v_colors[c];
262 | v_coeffs[21 * CHANNELS + c] = pSH21 * v_colors[c];
263 | v_coeffs[22 * CHANNELS + c] = pSH22 * v_colors[c];
264 | v_coeffs[23 * CHANNELS + c] = pSH23 * v_colors[c];
265 | v_coeffs[24 * CHANNELS + c] = pSH24 * v_colors[c];
266 | }
267 | }
268 | __device__ void sh_coeffs_to_color(
269 | const unsigned degree,
270 | const float3 &viewdir,
271 | const float *coeffs,
272 | float *colors
273 | ) {
274 | // Expects v_colors to be len CHANNELS
275 | // and v_coeffs to be num_bases * CHANNELS
276 | for (int c = 0; c < CHANNELS; ++c) {
277 | colors[c] = SH_C0 * coeffs[c];
278 | }
279 | if (degree < 1) {
280 | return;
281 | }
282 |
283 | float norm = sqrt(
284 | viewdir.x * viewdir.x + viewdir.y * viewdir.y + viewdir.z * viewdir.z
285 | );
286 | float x = viewdir.x / norm;
287 | float y = viewdir.y / norm;
288 | float z = viewdir.z / norm;
289 |
290 | float xx = x * x;
291 | float xy = x * y;
292 | float xz = x * z;
293 | float yy = y * y;
294 | float yz = y * z;
295 | float zz = z * z;
296 | // expects CHANNELS * num_bases coefficients
297 | // supports up to num_bases = 25
298 | for (int c = 0; c < CHANNELS; ++c) {
299 | colors[c] += SH_C1 * (-y * coeffs[1 * CHANNELS + c] +
300 | z * coeffs[2 * CHANNELS + c] -
301 | x * coeffs[3 * CHANNELS + c]);
302 | if (degree < 2) {
303 | continue;
304 | }
305 | colors[c] +=
306 | (SH_C2[0] * xy * coeffs[4 * CHANNELS + c] +
307 | SH_C2[1] * yz * coeffs[5 * CHANNELS + c] +
308 | SH_C2[2] * (2.f * zz - xx - yy) * coeffs[6 * CHANNELS + c] +
309 | SH_C2[3] * xz * coeffs[7 * CHANNELS + c] +
310 | SH_C2[4] * (xx - yy) * coeffs[8 * CHANNELS + c]);
311 | if (degree < 3) {
312 | continue;
313 | }
314 | colors[c] +=
315 | (SH_C3[0] * y * (3.f * xx - yy) * coeffs[9 * CHANNELS + c] +
316 | SH_C3[1] * xy * z * coeffs[10 * CHANNELS + c] +
317 | SH_C3[2] * y * (4.f * zz - xx - yy) * coeffs[11 * CHANNELS + c] +
318 | SH_C3[3] * z * (2.f * zz - 3.f * xx - 3.f * yy) *
319 | coeffs[12 * CHANNELS + c] +
320 | SH_C3[4] * x * (4.f * zz - xx - yy) * coeffs[13 * CHANNELS + c] +
321 | SH_C3[5] * z * (xx - yy) * coeffs[14 * CHANNELS + c] +
322 | SH_C3[6] * x * (xx - 3.f * yy) * coeffs[15 * CHANNELS + c]);
323 | if (degree < 4) {
324 | continue;
325 | }
326 | colors[c] +=
327 | (SH_C4[0] * xy * (xx - yy) * coeffs[16 * CHANNELS + c] +
328 | SH_C4[1] * yz * (3.f * xx - yy) * coeffs[17 * CHANNELS + c] +
329 | SH_C4[2] * xy * (7.f * zz - 1.f) * coeffs[18 * CHANNELS + c] +
330 | SH_C4[3] * yz * (7.f * zz - 3.f) * coeffs[19 * CHANNELS + c] +
331 | SH_C4[4] * (zz * (35.f * zz - 30.f) + 3.f) *
332 | coeffs[20 * CHANNELS + c] +
333 | SH_C4[5] * xz * (7.f * zz - 3.f) * coeffs[21 * CHANNELS + c] +
334 | SH_C4[6] * (xx - yy) * (7.f * zz - 1.f) *
335 | coeffs[22 * CHANNELS + c] +
336 | SH_C4[7] * xz * (xx - 3.f * yy) * coeffs[23 * CHANNELS + c] +
337 | SH_C4[8] * (xx * (xx - 3.f * yy) - yy * (3.f * xx - yy)) *
338 | coeffs[24 * CHANNELS + c]);
339 | }
340 | }
341 |
342 | __device__ void sh_coeffs_to_color_vjp(
343 | const unsigned degree,
344 | const float3 &viewdir,
345 | const float *v_colors,
346 | float *v_coeffs
347 | ) {
348 | // Expects v_colors to be len CHANNELS
349 | // and v_coeffs to be num_bases * CHANNELS
350 | #pragma unroll
351 | for (int c = 0; c < CHANNELS; ++c) {
352 | v_coeffs[c] = SH_C0 * v_colors[c];
353 | }
354 | if (degree < 1) {
355 | return;
356 | }
357 |
358 | float norm = sqrt(
359 | viewdir.x * viewdir.x + viewdir.y * viewdir.y + viewdir.z * viewdir.z
360 | );
361 | float x = viewdir.x / norm;
362 | float y = viewdir.y / norm;
363 | float z = viewdir.z / norm;
364 |
365 | float xx = x * x;
366 | float xy = x * y;
367 | float xz = x * z;
368 | float yy = y * y;
369 | float yz = y * z;
370 | float zz = z * z;
371 |
372 | #pragma unroll
373 | for (int c = 0; c < CHANNELS; ++c) {
374 | float v1 = -SH_C1 * y;
375 | float v2 = SH_C1 * z;
376 | float v3 = -SH_C1 * x;
377 | v_coeffs[1 * CHANNELS + c] = v1 * v_colors[c];
378 | v_coeffs[2 * CHANNELS + c] = v2 * v_colors[c];
379 | v_coeffs[3 * CHANNELS + c] = v3 * v_colors[c];
380 | if (degree < 2) {
381 | continue;
382 | }
383 | float v4 = SH_C2[0] * xy;
384 | float v5 = SH_C2[1] * yz;
385 | float v6 = SH_C2[2] * (2.f * zz - xx - yy);
386 | float v7 = SH_C2[3] * xz;
387 | float v8 = SH_C2[4] * (xx - yy);
388 | v_coeffs[4 * CHANNELS + c] = v4 * v_colors[c];
389 | v_coeffs[5 * CHANNELS + c] = v5 * v_colors[c];
390 | v_coeffs[6 * CHANNELS + c] = v6 * v_colors[c];
391 | v_coeffs[7 * CHANNELS + c] = v7 * v_colors[c];
392 | v_coeffs[8 * CHANNELS + c] = v8 * v_colors[c];
393 | if (degree < 3) {
394 | continue;
395 | }
396 | float v9 = SH_C3[0] * y * (3.f * xx - yy);
397 | float v10 = SH_C3[1] * xy * z;
398 | float v11 = SH_C3[2] * y * (4.f * zz - xx - yy);
399 | float v12 = SH_C3[3] * z * (2.f * zz - 3.f * xx - 3.f * yy);
400 | float v13 = SH_C3[4] * x * (4.f * zz - xx - yy);
401 | float v14 = SH_C3[5] * z * (xx - yy);
402 | float v15 = SH_C3[6] * x * (xx - 3.f * yy);
403 | v_coeffs[9 * CHANNELS + c] = v9 * v_colors[c];
404 | v_coeffs[10 * CHANNELS + c] = v10 * v_colors[c];
405 | v_coeffs[11 * CHANNELS + c] = v11 * v_colors[c];
406 | v_coeffs[12 * CHANNELS + c] = v12 * v_colors[c];
407 | v_coeffs[13 * CHANNELS + c] = v13 * v_colors[c];
408 | v_coeffs[14 * CHANNELS + c] = v14 * v_colors[c];
409 | v_coeffs[15 * CHANNELS + c] = v15 * v_colors[c];
410 | if (degree < 4) {
411 | continue;
412 | }
413 | float v16 = SH_C4[0] * xy * (xx - yy);
414 | float v17 = SH_C4[1] * yz * (3.f * xx - yy);
415 | float v18 = SH_C4[2] * xy * (7.f * zz - 1.f);
416 | float v19 = SH_C4[3] * yz * (7.f * zz - 3.f);
417 | float v20 = SH_C4[4] * (zz * (35.f * zz - 30.f) + 3.f);
418 | float v21 = SH_C4[5] * xz * (7.f * zz - 3.f);
419 | float v22 = SH_C4[6] * (xx - yy) * (7.f * zz - 1.f);
420 | float v23 = SH_C4[7] * xz * (xx - 3.f * yy);
421 | float v24 = SH_C4[8] * (xx * (xx - 3.f * yy) - yy * (3.f * xx - yy));
422 | v_coeffs[16 * CHANNELS + c] = v16 * v_colors[c];
423 | v_coeffs[17 * CHANNELS + c] = v17 * v_colors[c];
424 | v_coeffs[18 * CHANNELS + c] = v18 * v_colors[c];
425 | v_coeffs[19 * CHANNELS + c] = v19 * v_colors[c];
426 | v_coeffs[20 * CHANNELS + c] = v20 * v_colors[c];
427 | v_coeffs[21 * CHANNELS + c] = v21 * v_colors[c];
428 | v_coeffs[22 * CHANNELS + c] = v22 * v_colors[c];
429 | v_coeffs[23 * CHANNELS + c] = v23 * v_colors[c];
430 | v_coeffs[24 * CHANNELS + c] = v24 * v_colors[c];
431 | }
432 | }
433 |
434 | template <SHType SH_TYPE>
435 | __global__ void compute_sh_forward_kernel(
436 | const unsigned num_points,
437 | const unsigned degree,
438 | const unsigned degrees_to_use,
439 | const float3* __restrict__ viewdirs,
440 | const float* __restrict__ coeffs,
441 | float* __restrict__ colors
442 | ) {
443 | unsigned idx = cg::this_grid().thread_rank();
444 | if (idx >= num_points) {
445 | return;
446 | }
447 | const unsigned num_channels = 3;
448 | unsigned num_bases = num_sh_bases(degree);
449 | unsigned idx_sh = num_bases * num_channels * idx;
450 | unsigned idx_col = num_channels * idx;
451 |
452 | switch (SH_TYPE)
453 | {
454 | case SHType::Poly:
455 | sh_coeffs_to_color(
456 | degrees_to_use, viewdirs[idx], &(coeffs[idx_sh]), &(colors[idx_col])
457 | );
458 | break;
459 | case SHType::Fast:
460 | sh_coeffs_to_color_fast(
461 | degrees_to_use, viewdirs[idx], &(coeffs[idx_sh]), &(colors[idx_col])
462 | );
463 | break;
464 | }
465 | }
466 |
467 | template <SHType SH_TYPE>
468 | __global__ void compute_sh_backward_kernel(
469 | const unsigned num_points,
470 | const unsigned degree,
471 | const unsigned degrees_to_use,
472 | const float3* __restrict__ viewdirs,
473 | const float* __restrict__ v_colors,
474 | float* __restrict__ v_coeffs
475 | ) {
476 | unsigned idx = cg::this_grid().thread_rank();
477 | if (idx >= num_points) {
478 | return;
479 | }
480 | const unsigned num_channels = 3;
481 | unsigned num_bases = num_sh_bases(degree);
482 | unsigned idx_sh = num_bases * num_channels * idx;
483 | unsigned idx_col = num_channels * idx;
484 |
485 | switch (SH_TYPE)
486 | {
487 | case SHType::Poly:
488 | sh_coeffs_to_color_vjp(
489 | degrees_to_use, viewdirs[idx], &(v_colors[idx_col]), &(v_coeffs[idx_sh])
490 | );
491 | break;
492 | case SHType::Fast:
493 | sh_coeffs_to_color_fast_vjp(
494 | degrees_to_use, viewdirs[idx], &(v_colors[idx_col]), &(v_coeffs[idx_sh])
495 | );
496 | break;
497 | }
498 | }
499 |
--------------------------------------------------------------------------------
/gsplat/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any, Callable, List, Optional, Union
3 |
4 | import torch
5 | import torch.distributed as dist
6 | import torch.distributed.nn.functional as distF
7 | from torch import Tensor
8 |
9 |
10 | def all_gather_int32(
11 | world_size: int, value: Union[int, Tensor], device: Optional[torch.device] = None
12 | ) -> List[int]:
13 |     """Gather a 32-bit integer from all ranks.
14 |
15 | .. note::
16 | This implementation is faster than using `torch.distributed.all_gather_object`.
17 |
18 | .. note::
19 | This function is not differentiable to the input tensor.
20 |
21 | Args:
22 | world_size: The total number of ranks.
23 | value: The integer to gather. Could be a scalar or a tensor.
24 | device: Only required if `value` is a scalar. The device to put the tensor on.
25 |
26 | Returns:
27 | A list of integers, where the i-th element is the value from the i-th rank.
28 | Could be a list of scalars or tensors based on the input `value`.
29 | """
30 | if world_size == 1:
31 | return [value]
32 |
33 | # move to CUDA
34 | if isinstance(value, int):
35 | assert device is not None, "device is required for scalar input"
36 | value_tensor = torch.tensor(value, dtype=torch.int, device=device)
37 | else:
38 | value_tensor = value
39 | assert value_tensor.is_cuda, "value should be on CUDA"
40 |
41 | # gather
42 | collected = torch.empty(
43 | world_size, dtype=value_tensor.dtype, device=value_tensor.device
44 | )
45 | dist.all_gather_into_tensor(collected, value_tensor)
46 |
47 | if isinstance(value, int):
48 | # return as list of integers on CPU
49 | return collected.tolist()
50 | else:
51 | # return as list of single-element tensors
52 | return collected.unbind()
53 |
54 |
def all_to_all_int32(
    world_size: int,
    values: List[Union[int, Tensor]],
    device: Optional[torch.device] = None,
) -> List[Union[int, Tensor]]:
    """Exchange 32-bit integers between all ranks in a many-to-many fashion.

    .. note::
        This function is not differentiable to the input tensors.

    Args:
        world_size: The total number of ranks.
        values: A list of integers to exchange. Could be a list of scalars or tensors.
            Should have the same length as `world_size`.
        device: Only required if `values` contains scalars. The device to put the tensors on.

    Returns:
        A list of integers. Each element is a scalar or a tensor, mirroring the type
        of the corresponding input element. Has the same length as `world_size`.
    """
    if world_size == 1:
        return values

    assert (
        len(values) == world_size
    ), "The length of values should be equal to world_size"

    if any(isinstance(v, int) for v in values):
        assert device is not None, "device is required for scalar input"

    # move to CUDA
    values_tensor = [
        (torch.tensor(v, dtype=torch.int, device=device) if isinstance(v, int) else v)
        for v in values
    ]

    # all_to_all
    collected = [torch.empty_like(v) for v in values_tensor]
    dist.all_to_all(collected, values_tensor)

    # Convert each received tensor back to a Python int iff the corresponding
    # *input* element was a scalar, so output types mirror input types.
    return [
        out.item() if isinstance(inp, int) else out
        for out, inp in zip(collected, values)
    ]
100 |
101 |
def all_gather_tensor_list(world_size: int, tensor_list: List[Tensor]) -> List[Tensor]:
    """Gather a list of tensors from all ranks.

    .. note::
        This function expects the tensors in the `tensor_list` to have the same shape
        and data type across all ranks.

    .. note::
        This function is differentiable to the tensors in `tensor_list`.

    .. note::
        For efficiency, this function internally concatenates the tensors in `tensor_list`
        and performs a single gather operation. Thus it requires all tensors in the list
        to have the same first-dimension size.

    Args:
        world_size: The total number of ranks.
        tensor_list: A list of tensors to gather. The size of the first dimension of all
            the tensors in this list should be the same. The rest dimensions can be
            arbitrary. Shape: [(N, *), (N, *), ...]

    Returns:
        A list of tensors gathered from all ranks, where the i-th element is corresponding
        to the i-th tensor in `tensor_list`. The returned tensors have the shape
        [(N * world_size, *), (N * world_size, *), ...]

    Examples:

    .. code-block:: python

        >>> # on rank 0
        >>> # tensor_list = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])]
        >>> # on rank 1
        >>> # tensor_list = [torch.tensor([7, 8, 9]), torch.tensor([10, 11, 12])]
        >>> collected = all_gather_tensor_list(world_size, tensor_list)
        >>> # on both ranks
        >>> # [torch.tensor([1, 2, 3, 7, 8, 9]), torch.tensor([4, 5, 6, 10, 11, 12])]

    """
    if world_size == 1:
        return tensor_list

    if not tensor_list:
        # Nothing to gather; avoid indexing into an empty list below.
        return tensor_list

    N = len(tensor_list[0])
    for tensor in tensor_list:
        assert len(tensor) == N, "All tensors should have the same first dimension size"

    # concatenate tensors and record their sizes
    data = torch.cat([t.reshape(N, -1) for t in tensor_list], dim=-1)
    sizes = [t.numel() // N for t in tensor_list]

    if data.requires_grad:
        # differentiable gather
        collected = distF.all_gather(data)
    else:
        # non-differentiable gather
        collected = [torch.empty_like(data) for _ in range(world_size)]
        torch.distributed.all_gather(collected, data)
    collected = torch.cat(collected, dim=0)

    # split the collected tensor and reshape to the original shape
    out_tensor_tuple = torch.split(collected, sizes, dim=-1)
    out_tensor_list = []
    for out_tensor, tensor in zip(out_tensor_tuple, tensor_list):
        out_tensor = out_tensor.view(-1, *tensor.shape[1:])  # [N * world_size, *]
        out_tensor_list.append(out_tensor)
    return out_tensor_list
168 |
169 |
def all_to_all_tensor_list(
    world_size: int,
    tensor_list: List[Tensor],
    splits: List[Union[int, Tensor]],
    output_splits: Optional[List[Union[int, Tensor]]] = None,
) -> List[Tensor]:
    """Split and exchange tensors between all ranks in a many-to-many fashion.

    Args:
        world_size: The total number of ranks.
        tensor_list: A list of tensors to split and exchange. The size of the first
            dimension of all the tensors in this list should be the same. The rest
            dimensions can be arbitrary. Shape: [(N, *), (N, *), ...]
        splits: A list of integers representing the number of elements to send to each
            rank. It will be used to split the tensor in the `tensor_list`.
            The sum of the elements in this list should be equal to N. The size of this
            list should be equal to the `world_size`.
        output_splits: Splits of the output tensors. Could be pre-calculated via
            `all_to_all_int32(world_size, splits)`. If not provided, it will
            be calculated internally.

    Returns:
        A list of tensors exchanged between all ranks, where the i-th element is
        corresponding to the i-th tensor in `tensor_list`. Note the shape of the
        returned tensors might be different from the input tensors, depending on the
        splits.

    Examples:

    .. code-block:: python

        >>> # on rank 0
        >>> # tensor_list = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])]
        >>> # splits = [2, 1]

        >>> # on rank 1
        >>> # tensor_list = [torch.tensor([7, 8]), torch.tensor([9, 10])]
        >>> # splits = [1, 1]

        >>> collected = all_to_all_tensor_list(world_size, tensor_list, splits)

        >>> # on rank 0
        >>> # [torch.tensor([1, 2, 7]), torch.tensor([4, 5, 9])]
        >>> # on rank 1
        >>> # [torch.tensor([3, 8]), torch.tensor([6, 10])]

    """
    if world_size == 1:
        return tensor_list

    N = len(tensor_list[0])
    for tensor in tensor_list:
        assert len(tensor) == N, "All tensors should have the same first dimension size"

    assert (
        len(splits) == world_size
    ), "The length of splits should be equal to world_size"

    # concatenate tensors and record their sizes
    data = torch.cat([t.reshape(N, -1) for t in tensor_list], dim=-1)
    sizes = [t.numel() // N for t in tensor_list]

    # all_to_all: figure out how many rows each rank will send us
    if output_splits is not None:
        collected_splits = output_splits
    else:
        collected_splits = all_to_all_int32(world_size, splits, device=data.device)
    # Pre-allocate one receive buffer per rank with the expected row count.
    collected = [
        torch.empty((length, *data.shape[1:]), dtype=data.dtype, device=data.device)
        for length in collected_splits
    ]
    # torch.split requires tuple of integers
    splits = [s.item() if isinstance(s, Tensor) else s for s in splits]
    if data.requires_grad:
        # differentiable all_to_all
        distF.all_to_all(collected, data.split(splits, dim=0))
    else:
        # non-differentiable all_to_all
        torch.distributed.all_to_all(collected, list(data.split(splits, dim=0)))
    collected = torch.cat(collected, dim=0)

    # split the collected tensor and reshape to the original shape
    out_tensor_tuple = torch.split(collected, sizes, dim=-1)
    out_tensor_list = []
    for out_tensor, tensor in zip(out_tensor_tuple, tensor_list):
        out_tensor = out_tensor.view(-1, *tensor.shape[1:])
        out_tensor_list.append(out_tensor)
    return out_tensor_list
258 |
259 |
260 | def _find_free_port():
261 | import socket
262 |
263 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
264 | # Binding to port 0 will cause the OS to find an available port for us
265 | sock.bind(("", 0))
266 | port = sock.getsockname()[1]
267 | sock.close()
268 | # NOTE: there is still a chance the port could be taken by other processes.
269 | return port
270 |
271 |
def _distributed_worker(
    world_rank: int,
    world_size: int,
    fn: Callable,
    args: Any,
    local_rank: Optional[int] = None,
    verbose: bool = False,
) -> bool:
    """Per-rank entry point: set up NCCL, run `fn`, then tear down.

    When `world_size > 1`, initializes a NCCL process group, runs
    `fn(local_rank, world_rank, world_size, args)`, synchronizes all ranks
    with a barrier, and destroys the group. Returns True on completion.
    """
    if local_rank is None:  # single Node: local rank equals global rank
        local_rank = world_rank
    if verbose:
        print("Distributed worker: %d / %d" % (world_rank + 1, world_size))
    distributed = world_size > 1
    if distributed:
        # Pin this process to its GPU before creating the process group.
        torch.cuda.set_device(local_rank)
        torch.distributed.init_process_group(
            backend="nccl", world_size=world_size, rank=world_rank
        )
        # Dump collection that participates all ranks.
        # This initializes the communicator required by `batch_isend_irecv`.
        # See: https://github.com/pytorch/pytorch/pull/74701
        _ = [None for _ in range(world_size)]
        torch.distributed.all_gather_object(_, 0)
    fn(local_rank, world_rank, world_size, args)
    if distributed:
        # Wait for all ranks to finish before tearing down the group.
        torch.distributed.barrier()
        torch.distributed.destroy_process_group()
    if verbose:
        print("Job Done for worker: %d / %d" % (world_rank + 1, world_size))
    return True
302 |
303 |
def cli(fn: Callable, args: Any, verbose: bool = False) -> bool:
    """Wrapper to run a function in a distributed environment.

    The function `fn` should have the following signature:

    ```python
    def fn(local_rank: int, world_rank: int, world_size: int, args: Any) -> None:
        pass
    ```

    Usage:

    ```python
    # Launch with "CUDA_VISIBLE_DEVICES=0,1,2,3 python my_script.py"
    if __name__ == "__main__":
        cli(fn, None, verbose=True)
    ```

    Args:
        fn: The per-rank function to run.
        args: Arbitrary payload forwarded to `fn`.
        verbose: Whether to print progress information on every rank.

    Returns:
        True once all workers have finished.
    """
    assert torch.cuda.is_available(), "CUDA device is required!"
    if "OMPI_COMM_WORLD_SIZE" in os.environ:  # multi-node, launched via OpenMPI
        local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
        world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])  # dist.get_world_size()
        world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])  # dist.get_rank()
        return _distributed_worker(
            world_rank, world_size, fn, args, local_rank, verbose
        )

    world_size = torch.cuda.device_count()
    distributed = world_size > 1

    if distributed:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(_find_free_port())
        process_context = torch.multiprocessing.spawn(
            _distributed_worker,
            args=(world_size, fn, args, None, verbose),
            nprocs=world_size,
            join=False,
        )
        try:
            process_context.join()
        except KeyboardInterrupt:
            # this is important.
            # if we do not explicitly terminate all launched subprocesses,
            # they would continue living even after this main process ends,
            # eventually making the OD machine unusable!
            for i, process in enumerate(process_context.processes):
                if process.is_alive():
                    if verbose:
                        print("terminating process " + str(i) + "...")
                    process.terminate()
                process.join()
                if verbose:
                    print("process " + str(i) + " finished")
        return True
    else:
        # Bug fix: forward `verbose` on the single-process path (it was
        # previously dropped, silencing the worker regardless of the flag).
        return _distributed_worker(0, 1, fn=fn, args=args, verbose=verbose)
361 |
--------------------------------------------------------------------------------
/gsplat/profile.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from functools import wraps
4 | from typing import Callable, Optional
5 |
6 | import torch
7 |
8 | profiler = {}
9 |
10 |
class timeit(object):
    """Profiler that is controlled by the TIMEIT environment variable.

    If TIMEIT is set to 1, the profiler will measure the time taken by the
    decorated function (or `with` block) and accumulate the elapsed seconds
    into the module-level `profiler` dict, keyed by name.

    Usage:

    ```python
    @timeit()
    def my_function():
        pass

    # Or

    with timeit(name="stage1"):
        my_function()

    print(profiler)
    ```
    """

    def __init__(self, name: str = "unnamed"):
        self.name = name
        self.start_time: Optional[float] = None
        # Profiling is opt-in: only active when TIMEIT=1 in the environment.
        self.enabled = os.environ.get("TIMEIT", "0") == "1"

    def __enter__(self):
        if self.enabled:
            # Drain pending GPU work so it is not attributed to this block.
            # Guarded so TIMEIT=1 does not crash on CPU-only machines.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            self.start_time = time.perf_counter()
        # Bug fix: return self so `with timeit("x") as t:` binds the instance
        # (previously it bound None).
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.enabled:
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            end_time = time.perf_counter()
            total_time = end_time - self.start_time
            # Accumulate across repeated invocations under the same name.
            if self.name not in profiler:
                profiler[self.name] = total_time
            else:
                profiler[self.name] += total_time

    def __call__(self, f: Callable) -> Callable:
        @wraps(f)
        def decorated(*args, **kwargs):
            with self:
                # Use the wrapped function's name as the profiler key.
                self.name = f.__name__
                return f(*args, **kwargs)

        return decorated
60 |
--------------------------------------------------------------------------------
/gsplat/relocation.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Tuple
3 |
4 | import torch
5 | from torch import Tensor
6 |
7 | from .cuda._wrapper import _make_lazy_cuda_func
8 |
9 |
def compute_relocation(
    opacities: Tensor,  # [N]
    scales: Tensor,  # [N, 3]
    ratios: Tensor,  # [N]
    binoms: Tensor,  # [n_max, n_max]
) -> Tuple[Tensor, Tensor]:
    """Compute new Gaussians from a set of old Gaussians.

    This function interprets the Gaussians as samples from a likelihood
    distribution: the old opacities and scales are used to derive the opacities
    and scales of the relocated Gaussians. This is an implementation of the paper
    `3D Gaussian Splatting as Markov Chain Monte Carlo <https://arxiv.org/abs/2404.09591>`_.

    Args:
        opacities: The opacities of the Gaussians. [N]
        scales: The scales of the Gaussians. [N, 3]
        ratios: The relative frequencies for each of the Gaussians. [N]
        binoms: Precomputed lookup table for binomial coefficients used in
            Equation 9 in the paper. [n_max, n_max]

    Returns:
        A tuple:

        **new_opacities**: The opacities of the new Gaussians. [N]
        **new_scales**: The scales of the Gaussians. [N, 3]
    """

    num_gs = opacities.shape[0]
    n_max = binoms.shape[0]
    assert scales.shape == (num_gs, 3), scales.shape
    assert ratios.shape == (num_gs,), ratios.shape

    # The CUDA kernel expects contiguous inputs.
    opacities = opacities.contiguous()
    scales = scales.contiguous()
    # NOTE(review): this clamps the caller's `ratios` tensor in place, which
    # mirrors the original behavior; callers see the clamped values afterwards.
    ratios.clamp_(min=1, max=n_max)
    ratios = ratios.int().contiguous()

    kernel = _make_lazy_cuda_func("compute_relocation")
    new_opacities, new_scales = kernel(opacities, scales, ratios, binoms, n_max)
    return new_opacities, new_scales
50 |
--------------------------------------------------------------------------------
/gsplat/strategy/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import Strategy
2 | from .default import DefaultStrategy
3 | from .mcmc import MCMCStrategy
4 |
--------------------------------------------------------------------------------
/gsplat/strategy/base.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Dict, Union
3 |
4 | import torch
5 |
6 |
@dataclass
class Strategy:
    """Base class for the GS densification strategy.

    This class is a base class that defines the interface for the GS
    densification strategy.
    """

    def check_sanity(
        self,
        params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
        optimizers: Dict[str, torch.optim.Optimizer],
    ):
        """Sanity check for the parameters and optimizers.

        Raises:
            AssertionError: If `params` and `optimizers` do not share exactly
                the same keys, or if any optimizer has more than one param_group.
        """
        assert set(params.keys()) == set(optimizers.keys()), (
            "params and optimizers must have the same keys, "
            f"but got {params.keys()} and {optimizers.keys()}"
        )

        for optimizer in optimizers.values():
            # Typo fix in the message: "cooresponds" -> "corresponds".
            assert len(optimizer.param_groups) == 1, (
                "Each optimizer must have exactly one param_group, "
                "that corresponds to each parameter, "
                f"but got {len(optimizer.param_groups)}"
            )

    def step_pre_backward(
        self,
        *args,
        **kwargs,
    ):
        """Callback function to be executed before the `loss.backward()` call."""
        pass

    def step_post_backward(
        self,
        *args,
        **kwargs,
    ):
        """Callback function to be executed after the `loss.backward()` call."""
        pass
48 |
--------------------------------------------------------------------------------
/gsplat/strategy/default.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Any, Dict, Tuple, Union
3 |
4 | import torch
5 | from typing_extensions import Literal
6 |
7 | from .base import Strategy
8 | from .ops import duplicate, remove, reset_opa, split
9 |
10 |
@dataclass
class DefaultStrategy(Strategy):
    """A default strategy that follows the original 3DGS paper:

    `3D Gaussian Splatting for Real-Time Radiance Field Rendering <https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/>`_

    The strategy will:

    - Periodically duplicate GSs with high image plane gradients and small scales.
    - Periodically split GSs with high image plane gradients and large scales.
    - Periodically prune GSs with low opacity.
    - Periodically reset GSs to a lower opacity.

    If `absgrad=True`, it will use the absolute gradients instead of average gradients
    for GS duplicating & splitting, following the AbsGS paper:

    `AbsGS: Recovering Fine Details for 3D Gaussian Splatting <https://arxiv.org/abs/2404.10484>`_

    Which typically leads to better results but requires to set the `grow_grad2d` to a
    higher value, e.g., 0.0008. Also, the :func:`rasterization` function should be called
    with `absgrad=True` as well so that the absolute gradients are computed.

    Args:
        prune_opa (float): GSs with opacity below this value will be pruned. Default is 0.005.
        grow_grad2d (float): GSs with image plane gradient above this value will be
            split/duplicated. Default is 0.0002.
        grow_scale3d (float): GSs with 3d scale (normalized by scene_scale) below this
            value will be duplicated. Above will be split. Default is 0.01.
        grow_scale2d (float): GSs with 2d scale (normalized by image resolution) above
            this value will be split. Default is 0.05.
        prune_scale3d (float): GSs with 3d scale (normalized by scene_scale) above this
            value will be pruned. Default is 0.1.
        prune_scale2d (float): GSs with 2d scale (normalized by image resolution) above
            this value will be pruned. Default is 0.15.
        refine_scale2d_stop_iter (int): Stop refining GSs based on 2d scale after this
            iteration. Default is 0. Set to a positive value to enable this feature.
        refine_start_iter (int): Start refining GSs after this iteration. Default is 500.
        refine_stop_iter (int): Stop refining GSs after this iteration. Default is 15_000.
        reset_every (int): Reset opacities every this steps. Default is 3000.
        refine_every (int): Refine GSs every this steps. Default is 100.
        pause_refine_after_reset (int): Pause refining GSs until this number of steps after
            reset, Default is 0 (no pause at all) and one might want to set this number to the
            number of images in training set.
        absgrad (bool): Use absolute gradients for GS splitting. Default is False.
        revised_opacity (bool): Whether to use revised opacity heuristic from
            arXiv:2404.06109 (experimental). Default is False.
        verbose (bool): Whether to print verbose information. Default is False.
        key_for_gradient (str): Which variable uses for densification strategy.
            3DGS uses "means2d" gradient and 2DGS uses a similar gradient which stores
            in variable "gradient_2dgs".

    Examples:

        >>> from gsplat import DefaultStrategy, rasterization
        >>> params: Dict[str, torch.nn.Parameter] | torch.nn.ParameterDict = ...
        >>> optimizers: Dict[str, torch.optim.Optimizer] = ...
        >>> strategy = DefaultStrategy()
        >>> strategy.check_sanity(params, optimizers)
        >>> strategy_state = strategy.initialize_state()
        >>> for step in range(1000):
        ...     render_image, render_alpha, info = rasterization(...)
        ...     strategy.step_pre_backward(params, optimizers, strategy_state, step, info)
        ...     loss = ...
        ...     loss.backward()
        ...     strategy.step_post_backward(params, optimizers, strategy_state, step, info)

    """

    prune_opa: float = 0.005
    grow_grad2d: float = 0.0002
    grow_scale3d: float = 0.01
    grow_scale2d: float = 0.05
    prune_scale3d: float = 0.1
    prune_scale2d: float = 0.15
    refine_scale2d_stop_iter: int = 0
    refine_start_iter: int = 500
    refine_stop_iter: int = 15_000
    reset_every: int = 3000
    refine_every: int = 100
    pause_refine_after_reset: int = 0
    absgrad: bool = False
    revised_opacity: bool = False
    verbose: bool = False
    key_for_gradient: Literal["means2d", "gradient_2dgs"] = "means2d"

    def initialize_state(self, scene_scale: float = 1.0) -> Dict[str, Any]:
        """Initialize and return the running state for this strategy.

        The returned state should be passed to the `step_pre_backward()` and
        `step_post_backward()` functions.
        """
        # Postpone the initialization of the state to the first step so that we can
        # put them on the correct device.
        # - grad2d: running accum of the norm of the image plane gradients for each GS.
        # - count: running accum of how many time each GS is visible.
        # - radii: the radii of the GSs (normalized by the image resolution).
        state = {"grad2d": None, "count": None, "scene_scale": scene_scale}
        if self.refine_scale2d_stop_iter > 0:
            state["radii"] = None
        return state

    def check_sanity(
        self,
        params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
        optimizers: Dict[str, torch.optim.Optimizer],
    ):
        """Sanity check for the parameters and optimizers.

        Check if:
            * `params` and `optimizers` have the same keys.
            * Each optimizer has exactly one param_group, corresponding to each parameter.
            * The following keys are present: {"means", "scales", "quats", "opacities"}.

        Raises:
            AssertionError: If any of the above conditions is not met.

        .. note::
            It is not required but highly recommended for the user to call this function
            after initializing the strategy to ensure the convention of the parameters
            and optimizers is as expected.
        """

        super().check_sanity(params, optimizers)
        # The following keys are required for this strategy.
        for key in ["means", "scales", "quats", "opacities"]:
            assert key in params, f"{key} is required in params but missing."

    def step_pre_backward(
        self,
        params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
        optimizers: Dict[str, torch.optim.Optimizer],
        state: Dict[str, Any],
        step: int,
        info: Dict[str, Any],
    ):
        """Callback function to be executed before the `loss.backward()` call."""
        assert (
            self.key_for_gradient in info
        ), "The 2D means of the Gaussians is required but missing."
        # Keep the non-leaf gradient so `_update_state` can read it after backward.
        info[self.key_for_gradient].retain_grad()

    def step_post_backward(
        self,
        params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
        optimizers: Dict[str, torch.optim.Optimizer],
        state: Dict[str, Any],
        step: int,
        info: Dict[str, Any],
        packed: bool = False,
    ):
        """Callback function to be executed after the `loss.backward()` call."""
        if step >= self.refine_stop_iter:
            return

        self._update_state(params, state, info, packed=packed)

        # Refine only within the schedule window, and not too soon after a reset.
        if (
            step > self.refine_start_iter
            and step % self.refine_every == 0
            and step % self.reset_every >= self.pause_refine_after_reset
        ):
            # grow GSs
            n_dupli, n_split = self._grow_gs(params, optimizers, state, step)
            if self.verbose:
                print(
                    f"Step {step}: {n_dupli} GSs duplicated, {n_split} GSs split. "
                    f"Now having {len(params['means'])} GSs."
                )

            # prune GSs
            n_prune = self._prune_gs(params, optimizers, state, step)
            if self.verbose:
                print(
                    f"Step {step}: {n_prune} GSs pruned. "
                    f"Now having {len(params['means'])} GSs."
                )

            # reset running stats
            state["grad2d"].zero_()
            state["count"].zero_()
            if self.refine_scale2d_stop_iter > 0:
                state["radii"].zero_()
            torch.cuda.empty_cache()

        # Periodically reset opacities to a low value (original 3DGS schedule).
        if step % self.reset_every == 0:
            reset_opa(
                params=params,
                optimizers=optimizers,
                state=state,
                value=self.prune_opa * 2.0,
            )

    def _update_state(
        self,
        params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
        state: Dict[str, Any],
        info: Dict[str, Any],
        packed: bool = False,
    ):
        # Accumulate per-GS gradient norms, visibility counts, and (optionally)
        # max screen-space radii from this step's rasterization `info`.
        for key in [
            "width",
            "height",
            "n_cameras",
            "radii",
            "gaussian_ids",
            self.key_for_gradient,
        ]:
            assert key in info, f"{key} is required but missing."

        # normalize grads to [-1, 1] screen space
        if self.absgrad:
            grads = info[self.key_for_gradient].absgrad.clone()
        else:
            grads = info[self.key_for_gradient].grad.clone()
        grads[..., 0] *= info["width"] / 2.0 * info["n_cameras"]
        grads[..., 1] *= info["height"] / 2.0 * info["n_cameras"]

        # initialize state on the first run
        n_gaussian = len(list(params.values())[0])

        if state["grad2d"] is None:
            state["grad2d"] = torch.zeros(n_gaussian, device=grads.device)
        if state["count"] is None:
            state["count"] = torch.zeros(n_gaussian, device=grads.device)
        if self.refine_scale2d_stop_iter > 0 and state["radii"] is None:
            assert "radii" in info, "radii is required but missing."
            state["radii"] = torch.zeros(n_gaussian, device=grads.device)

        # update the running state
        if packed:
            # grads is [nnz, 2]
            gs_ids = info["gaussian_ids"]  # [nnz]
            radii = info["radii"].max(dim=-1).values  # [nnz]
        else:
            # grads is [C, N, 2]
            sel = info["radii"][..., 0] > 0.0  # [C, N]
            gs_ids = torch.where(sel)[1]  # [nnz]
            grads = grads[sel]  # [nnz, 2]
            radii = info["radii"][sel].max(dim=-1).values  # [nnz]

        state["grad2d"].index_add_(0, gs_ids, grads.norm(dim=-1))
        state["count"].index_add_(0, gs_ids, torch.ones_like(gs_ids, dtype=torch.float32))
        if self.refine_scale2d_stop_iter > 0:
            # Should be ideally using scatter max
            state["radii"][gs_ids] = torch.maximum(
                state["radii"][gs_ids],
                # normalize radii to [0, 1] screen space
                radii / float(max(info["width"], info["height"])),
            )

    @torch.no_grad()
    def _grow_gs(
        self,
        params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
        optimizers: Dict[str, torch.optim.Optimizer],
        state: Dict[str, Any],
        step: int,
    ) -> Tuple[int, int]:
        # Average accumulated gradient per visibility count (avoid div-by-zero).
        count = state["count"]
        grads = state["grad2d"] / count.clamp_min(1)
        device = grads.device

        # High-gradient + small-scale GSs get duplicated.
        is_grad_high = grads > self.grow_grad2d
        is_small = (
            torch.exp(params["scales"]).max(dim=-1).values
            <= self.grow_scale3d * state["scene_scale"]
        )
        is_dupli = is_grad_high & is_small
        n_dupli = is_dupli.sum().item()

        # High-gradient + large-scale (or large on screen) GSs get split.
        is_large = ~is_small
        is_split = is_large
        if step < self.refine_scale2d_stop_iter:
            is_split |= state["radii"] > self.grow_scale2d
        is_split = is_grad_high & is_split
        n_split = is_split.sum().item()

        # first duplicate
        if n_dupli > 0:
            duplicate(params=params, optimizers=optimizers, state=state, mask=is_dupli)

        # new GSs added by duplication will not be split
        is_split = torch.cat(
            [
                is_split,
                torch.zeros(n_dupli, dtype=torch.bool, device=device),
            ]
        )

        # then split
        if n_split > 0:
            split(
                params=params,
                optimizers=optimizers,
                state=state,
                mask=is_split,
                revised_opacity=self.revised_opacity,
            )
        return n_dupli, n_split

    @torch.no_grad()
    def _prune_gs(
        self,
        params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
        optimizers: Dict[str, torch.optim.Optimizer],
        state: Dict[str, Any],
        step: int,
    ) -> int:
        # Prune low-opacity GSs; after the first reset cycle also prune
        # oversized GSs in world (and optionally screen) space.
        is_prune = torch.sigmoid(params["opacities"].flatten()) < self.prune_opa
        if step > self.reset_every:
            is_too_big = (
                torch.exp(params["scales"]).max(dim=-1).values
                > self.prune_scale3d * state["scene_scale"]
            )
            # The official code also implements screen-size pruning but
            # it's actually not being used due to a bug:
            # https://github.com/graphdeco-inria/gaussian-splatting/issues/123
            # We implement it here for completeness but set `refine_scale2d_stop_iter`
            # to 0 by default to disable it.
            if step < self.refine_scale2d_stop_iter:
                is_too_big |= state["radii"] > self.prune_scale2d

            is_prune = is_prune | is_too_big

        n_prune = is_prune.sum().item()
        if n_prune > 0:
            remove(params=params, optimizers=optimizers, state=state, mask=is_prune)

        return n_prune
340 |
--------------------------------------------------------------------------------
/gsplat/strategy/mcmc.py:
--------------------------------------------------------------------------------
1 | import math
2 | from dataclasses import dataclass
3 | from typing import Any, Dict, Union
4 |
5 | import torch
6 | from torch import Tensor
7 |
8 | from .base import Strategy
9 | from .ops import inject_noise_to_position, relocate, sample_add
10 |
11 |
12 | @dataclass
13 | class MCMCStrategy(Strategy):
14 | """Strategy that follows the paper:
15 |
16 | `3D Gaussian Splatting as Markov Chain Monte Carlo `_
17 |
18 | This strategy will:
19 |
20 | - Periodically teleport GSs with low opacity to a place that has high opacity.
21 | - Periodically introduce new GSs sampled based on the opacity distribution.
22 | - Periodically perturb the GSs locations.
23 |
24 | Args:
25 | cap_max (int): Maximum number of GSs. Default to 1_000_000.
26 | noise_lr (float): MCMC samping noise learning rate. Default to 5e5.
27 | refine_start_iter (int): Start refining GSs after this iteration. Default to 500.
28 | refine_stop_iter (int): Stop refining GSs after this iteration. Default to 25_000.
29 | refine_every (int): Refine GSs every this steps. Default to 100.
30 | min_opacity (float): GSs with opacity below this value will be pruned. Default to 0.005.
31 | verbose (bool): Whether to print verbose information. Default to False.
32 |
33 | Examples:
34 |
35 | >>> from gsplat import MCMCStrategy, rasterization
36 | >>> params: Dict[str, torch.nn.Parameter] | torch.nn.ParameterDict = ...
37 | >>> optimizers: Dict[str, torch.optim.Optimizer] = ...
38 | >>> strategy = MCMCStrategy()
39 | >>> strategy.check_sanity(params, optimizers)
40 | >>> strategy_state = strategy.initialize_state()
41 | >>> for step in range(1000):
42 | ... render_image, render_alpha, info = rasterization(...)
43 | ... loss = ...
44 | ... loss.backward()
45 | ... strategy.step_post_backward(params, optimizers, strategy_state, step, info, lr=1e-3)
46 |
47 | """
48 |
49 | cap_max: int = 1_000_000
50 | noise_lr: float = 5e5
51 | refine_start_iter: int = 500
52 | refine_stop_iter: int = 25_000
53 | refine_every: int = 100
54 | min_opacity: float = 0.005
55 | verbose: bool = False
56 |
57 | def initialize_state(self) -> Dict[str, Any]:
58 | """Initialize and return the running state for this strategy."""
59 | n_max = 51
60 | binoms = torch.zeros((n_max, n_max))
61 | for n in range(n_max):
62 | for k in range(n + 1):
63 | binoms[n, k] = math.comb(n, k)
64 | return {"binoms": binoms}
65 |
66 | def check_sanity(
67 | self,
68 | params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
69 | optimizers: Dict[str, torch.optim.Optimizer],
70 | ):
71 | """Sanity check for the parameters and optimizers.
72 |
73 | Check if:
74 | * `params` and `optimizers` have the same keys.
75 | * Each optimizer has exactly one param_group, corresponding to each parameter.
76 | * The following keys are present: {"means", "scales", "quats", "opacities"}.
77 |
78 | Raises:
79 | AssertionError: If any of the above conditions is not met.
80 |
81 | .. note::
82 | It is not required but highly recommended for the user to call this function
83 | after initializing the strategy to ensure the convention of the parameters
84 | and optimizers is as expected.
85 | """
86 |
87 | super().check_sanity(params, optimizers)
88 | # The following keys are required for this strategy.
89 | for key in ["means", "scales", "quats", "opacities"]:
90 | assert key in params, f"{key} is required in params but missing."
91 |
92 | # def step_pre_backward(
93 | # self,
94 | # params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
95 | # optimizers: Dict[str, torch.optim.Optimizer],
96 | # # state: Dict[str, Any],
97 | # step: int,
98 | # info: Dict[str, Any],
99 | # ):
100 | # """Callback function to be executed before the `loss.backward()` call."""
101 | # pass
102 |
103 | def step_post_backward(
104 | self,
105 | params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
106 | optimizers: Dict[str, torch.optim.Optimizer],
107 | state: Dict[str, Any],
108 | step: int,
109 | info: Dict[str, Any],
110 | lr: float,
111 | ):
112 | """Callback function to be executed after the `loss.backward()` call.
113 |
114 | Args:
115 | lr (float): Learning rate for "means" attribute of the GS.
116 | """
117 | # move to the correct device
118 | state["binoms"] = state["binoms"].to(params["means"].device)
119 |
120 | binoms = state["binoms"]
121 |
122 | if (
123 | step < self.refine_stop_iter
124 | and step > self.refine_start_iter
125 | and step % self.refine_every == 0
126 | ):
127 | # teleport GSs
128 | n_relocated_gs = self._relocate_gs(params, optimizers, binoms)
129 | if self.verbose:
130 | print(f"Step {step}: Relocated {n_relocated_gs} GSs.")
131 |
132 | # add new GSs
133 | n_new_gs = self._add_new_gs(params, optimizers, binoms)
134 | if self.verbose:
135 | print(
136 | f"Step {step}: Added {n_new_gs} GSs. "
137 | f"Now having {len(params['means'])} GSs."
138 | )
139 |
140 | torch.cuda.empty_cache()
141 |
142 | # add noise to GSs
143 | inject_noise_to_position(
144 | params=params, optimizers=optimizers, state={}, scaler=lr * self.noise_lr
145 | )
146 |
147 | @torch.no_grad()
148 | def _relocate_gs(
149 | self,
150 | params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
151 | optimizers: Dict[str, torch.optim.Optimizer],
152 | binoms: Tensor,
153 | ) -> int:
154 | opacities = torch.sigmoid(params["opacities"].flatten())
155 | dead_mask = opacities <= self.min_opacity
156 | n_gs = dead_mask.sum().item()
157 | if n_gs > 0:
158 | relocate(
159 | params=params,
160 | optimizers=optimizers,
161 | state={},
162 | mask=dead_mask,
163 | binoms=binoms,
164 | min_opacity=self.min_opacity,
165 | )
166 | return n_gs
167 |
168 | @torch.no_grad()
169 | def _add_new_gs(
170 | self,
171 | params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
172 | optimizers: Dict[str, torch.optim.Optimizer],
173 | binoms: Tensor,
174 | ) -> int:
175 | current_n_points = len(params["means"])
176 | n_target = min(self.cap_max, int(1.05 * current_n_points))
177 | n_gs = max(0, n_target - current_n_points)
178 | if n_gs > 0:
179 | sample_add(
180 | params=params,
181 | optimizers=optimizers,
182 | state={},
183 | n=n_gs,
184 | binoms=binoms,
185 | min_opacity=self.min_opacity,
186 | )
187 | return n_gs
188 |
--------------------------------------------------------------------------------
/gsplat/strategy/ops.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Dict, List, Union
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import Tensor
7 |
8 | from gsplat import quat_scale_to_covar_preci
9 | from gsplat.relocation import compute_relocation
10 | from gsplat.utils import normalized_quat_to_rotmat
11 |
12 |
13 | @torch.no_grad()
14 | def _multinomial_sample(weights: Tensor, n: int, replacement: bool = True) -> Tensor:
15 | """Sample from a distribution using torch.multinomial or numpy.random.choice.
16 |
17 | This function adaptively chooses between `torch.multinomial` and `numpy.random.choice`
18 | based on the number of elements in `weights`. If the number of elements exceeds
19 | the torch.multinomial limit (2^24), it falls back to using `numpy.random.choice`.
20 |
21 | Args:
22 | weights (Tensor): A 1D tensor of weights for each element.
23 | n (int): The number of samples to draw.
24 | replacement (bool): Whether to sample with replacement. Default is True.
25 |
26 | Returns:
27 | Tensor: A 1D tensor of sampled indices.
28 | """
29 | num_elements = weights.size(0)
30 |
31 | if num_elements <= 2**24:
32 | # Use torch.multinomial for elements within the limit
33 | return torch.multinomial(weights, n, replacement=replacement)
34 | else:
35 | # Fallback to numpy.random.choice for larger element spaces
36 | weights = weights / weights.sum()
37 | weights_np = weights.detach().cpu().numpy()
38 | sampled_idxs_np = np.random.choice(
39 | num_elements, size=n, p=weights_np, replace=replacement
40 | )
41 | sampled_idxs = torch.from_numpy(sampled_idxs_np)
42 |
43 | # Return the sampled indices on the original device
44 | return sampled_idxs.to(weights.device)
45 |
46 |
@torch.no_grad()
def _update_param_with_optimizer(
    param_fn: Callable[[str, Tensor], Tensor],
    optimizer_fn: Callable[[str, Tensor], Tensor],
    params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
    optimizers: Dict[str, torch.optim.Optimizer],
    names: Union[List[str], None] = None,
):
    """Update the parameters and the state in the optimizers with defined functions.

    Args:
        param_fn: A function that takes the name of the parameter and the parameter itself,
            and returns the new parameter.
        optimizer_fn: A function that takes the key of the optimizer state and the state value,
            and returns the new state value.
        params: A dictionary of parameters.
        optimizers: A dictionary of optimizers, each corresponding to a parameter.
        names: A list of key names to update. If None, update all. Default: None.
    """
    if names is None:
        # If names is not provided, update all parameters
        names = list(params.keys())

    # Split the names into optimized and un-optimized
    un_optimized_names = set(names) - set(optimizers.keys())
    optimized_names = set(names) - un_optimized_names

    for name in optimized_names:
        optimizer = optimizers[name]
        for i, param_group in enumerate(optimizer.param_groups):
            # NOTE(review): assumes each param_group holds exactly one tensor,
            # the convention enforced elsewhere by Strategy.check_sanity.
            p = param_group["params"][0]
            # optimizer.state is keyed by tensor identity, so the old entry is
            # removed here and re-inserted under the new parameter below.
            p_state = optimizer.state[p]
            del optimizer.state[p]
            for key in p_state.keys():
                if key != "step":
                    # Transform every state tensor (e.g. Adam's exp_avg /
                    # exp_avg_sq) but leave the scalar "step" counter intact.
                    v = p_state[key]
                    p_state[key] = optimizer_fn(key, v)
            p_new = param_fn(name, p)
            optimizer.param_groups[i]["params"] = [p_new]
            optimizer.state[p_new] = p_state
            params[name] = p_new

    for name in un_optimized_names:
        # Parameters without an optimizer are transformed directly.
        params[name] = param_fn(name, params[name])
91 |
92 |
@torch.no_grad()
def duplicate(
    params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
    optimizers: Dict[str, torch.optim.Optimizer],
    state: Dict[str, Tensor],
    mask: Tensor,
):
    """Inplace duplicate the Gaussians selected by `mask`.

    Copies of the selected rows are appended at the end of every parameter;
    optimizer state for the copies starts at zero, and tensor entries in
    `state` are extended with copies of the selected rows as well.

    Args:
        params: A dictionary of parameters.
        optimizers: A dictionary of optimizers, each corresponding to a parameter.
        state: Extra running state; tensor values are extended in place.
        mask: A boolean mask to duplicate the Gaussians.
    """
    device = mask.device
    picked = torch.where(mask)[0]

    def param_fn(name: str, p: Tensor) -> Tensor:
        # Append a copy of each selected row.
        return torch.nn.Parameter(torch.cat([p, p[picked]]))

    def optimizer_fn(key: str, v: Tensor) -> Tensor:
        # New rows start with zeroed optimizer statistics.
        zeros = torch.zeros((len(picked), *v.shape[1:]), device=device)
        return torch.cat([v, zeros])

    # update the parameters and the state in the optimizers
    _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers)
    # update the extra running state
    for key, value in state.items():
        if isinstance(value, torch.Tensor):
            state[key] = torch.cat((value, value[picked]))
122 |
123 |
@torch.no_grad()
def split(
    params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
    optimizers: Dict[str, torch.optim.Optimizer],
    state: Dict[str, Tensor],
    mask: Tensor,
    revised_opacity: bool = False,
):
    """Inplace split the Gaussian with the given mask.

    Each selected Gaussian is replaced by two children whose positions are
    sampled from its own covariance and whose scales are shrunk by a factor
    of 1.6. The result is ordered [unselected, children]: rows are permuted
    relative to the input.

    Args:
        params: A dictionary of parameters.
        optimizers: A dictionary of optimizers, each corresponding to a parameter.
        state: Extra running state; tensor values are re-ordered/duplicated in place.
        mask: A boolean mask to split the Gaussians.
        revised_opacity: Whether to use revised opacity formulation
            from arXiv:2404.06109. Default: False.
    """
    device = mask.device
    sel = torch.where(mask)[0]
    rest = torch.where(~mask)[0]

    # "scales" are stored in log-space and "quats" unnormalized.
    scales = torch.exp(params["scales"][sel])
    quats = F.normalize(params["quats"][sel], dim=-1)
    rotmats = normalized_quat_to_rotmat(quats)  # [N, 3, 3]
    # Two position offsets per selected Gaussian, distributed as N(0, R S S R^T).
    samples = torch.einsum(
        "nij,nj,bnj->bni",
        rotmats,
        scales,
        torch.randn(2, len(scales), 3, device=device),
    )  # [2, N, 3]

    def param_fn(name: str, p: Tensor) -> Tensor:
        repeats = [2] + [1] * (p.dim() - 1)
        if name == "means":
            p_split = (p[sel] + samples).reshape(-1, 3)  # [2N, 3]
        elif name == "scales":
            # Children are 1.6x smaller, written back in log-space.
            p_split = torch.log(scales / 1.6).repeat(2, 1)  # [2N, 3]
        elif name == "opacities" and revised_opacity:
            # Revised formulation of arXiv:2404.06109: o_child = 1 - sqrt(1 - o).
            new_opacities = 1.0 - torch.sqrt(1.0 - torch.sigmoid(p[sel]))
            p_split = torch.logit(new_opacities).repeat(repeats)  # [2N]
        else:
            p_split = p[sel].repeat(repeats)
        p_new = torch.cat([p[rest], p_split])
        p_new = torch.nn.Parameter(p_new)
        return p_new

    def optimizer_fn(key: str, v: Tensor) -> Tensor:
        # Children start with zeroed optimizer state; kept rows keep theirs.
        v_split = torch.zeros((2 * len(sel), *v.shape[1:]), device=device)
        return torch.cat([v[rest], v_split])

    # update the parameters and the state in the optimizers
    _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers)
    # update the extra running state
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            repeats = [2] + [1] * (v.dim() - 1)
            v_new = v[sel].repeat(repeats)
            state[k] = torch.cat((v[rest], v_new))
182 |
183 |
@torch.no_grad()
def remove(
    params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
    optimizers: Dict[str, torch.optim.Optimizer],
    state: Dict[str, Tensor],
    mask: Tensor,
):
    """Inplace remove the Gaussians selected by `mask`.

    Parameters, optimizer state, and tensor entries of `state` are all
    filtered down to the surviving (unmasked) rows.

    Args:
        params: A dictionary of parameters.
        optimizers: A dictionary of optimizers, each corresponding to a parameter.
        state: Extra running state; tensor values are filtered in place.
        mask: A boolean mask to remove the Gaussians.
    """
    keep = torch.where(~mask)[0]

    def param_fn(name: str, p: Tensor) -> Tensor:
        return torch.nn.Parameter(p[keep])

    def optimizer_fn(key: str, v: Tensor) -> Tensor:
        return v[keep]

    # update the parameters and the state in the optimizers
    _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers)
    # update the extra running state
    for key, value in state.items():
        if isinstance(value, torch.Tensor):
            state[key] = value[keep]
212 |
213 |
@torch.no_grad()
def reset_opa(
    params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
    optimizers: Dict[str, torch.optim.Optimizer],
    state: Dict[str, Tensor],
    value: float,
):
    """Inplace clamp the opacities to at most the given post-sigmoid value.

    Optimizer statistics for "opacities" are reset to zero; other parameters
    and `state` are untouched.

    Args:
        params: A dictionary of parameters.
        optimizers: A dictionary of optimizers, each corresponding to a parameter.
        state: Extra running state (unused here).
        value: The post-sigmoid ceiling for the opacities.

    Raises:
        ValueError: If asked to transform a parameter other than "opacities".
    """

    def param_fn(name: str, p: Tensor) -> Tensor:
        if name != "opacities":
            raise ValueError(f"Unexpected parameter name: {name}")
        # `p` is pre-sigmoid, so the ceiling is logit(value).
        ceiling = torch.logit(torch.tensor(value)).item()
        return torch.nn.Parameter(torch.clamp(p, max=ceiling))

    def optimizer_fn(key: str, v: Tensor) -> Tensor:
        # Wipe the optimizer's running statistics for the reset parameter.
        return torch.zeros_like(v)

    # update the parameters and the state in the optimizers
    _update_param_with_optimizer(
        param_fn, optimizer_fn, params, optimizers, names=["opacities"]
    )
243 |
244 |
@torch.no_grad()
def relocate(
    params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
    optimizers: Dict[str, torch.optim.Optimizer],
    state: Dict[str, Tensor],
    mask: Tensor,
    binoms: Tensor,
    min_opacity: float = 0.005,
):
    """Inplace relocate some dead Gaussians to the lives ones.

    Live target Gaussians are sampled with probability proportional to their
    opacity; each target's opacity and scale are then re-computed via
    `compute_relocation` to account for the extra copies, and the dead rows
    are overwritten with the updated target rows.

    Args:
        params: A dictionary of parameters.
        optimizers: A dictionary of optimizers, each corresponding to a parameter.
        state: Extra running state; tensor entries are zeroed at the sampled rows.
        mask: A boolean mask to indicates which Gaussians are dead.
        binoms: Precomputed binomial-coefficient table consumed by `compute_relocation`.
        min_opacity: Lower clamp applied to the recomputed opacities.
    """
    # support "opacities" with shape [N,] or [N, 1]
    opacities = torch.sigmoid(params["opacities"])

    dead_indices = mask.nonzero(as_tuple=True)[0]
    alive_indices = (~mask).nonzero(as_tuple=True)[0]
    n = len(dead_indices)

    # Sample for new GSs
    eps = torch.finfo(torch.float32).eps
    probs = opacities[alive_indices].flatten()  # ensure its shape is [N,]
    sampled_idxs = _multinomial_sample(probs, n, replacement=True)
    # Map sampled positions back to indices in the full (alive + dead) array.
    sampled_idxs = alive_indices[sampled_idxs]
    new_opacities, new_scales = compute_relocation(
        opacities=opacities[sampled_idxs],
        scales=torch.exp(params["scales"])[sampled_idxs],
        # ratios = how many copies (incl. the original) each target ends up with.
        ratios=torch.bincount(sampled_idxs)[sampled_idxs] + 1,
        binoms=binoms,
    )
    new_opacities = torch.clamp(new_opacities, max=1.0 - eps, min=min_opacity)

    def param_fn(name: str, p: Tensor) -> Tensor:
        # Order matters: opacities/scales are rewritten at the sampled rows
        # first, and the dead rows then copy those already-updated rows.
        if name == "opacities":
            p[sampled_idxs] = torch.logit(new_opacities)
        elif name == "scales":
            p[sampled_idxs] = torch.log(new_scales)
        p[dead_indices] = p[sampled_idxs]
        return torch.nn.Parameter(p)

    def optimizer_fn(key: str, v: Tensor) -> Tensor:
        # Reset optimizer statistics at the sampled (source) rows.
        v[sampled_idxs] = 0
        return v

    # update the parameters and the state in the optimizers
    _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers)
    # update the extra running state
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            # NOTE(review): only the sampled (source) rows are zeroed here; the
            # dead (destination) rows keep their old state — confirm intended.
            v[sampled_idxs] = 0
299 |
300 |
@torch.no_grad()
def sample_add(
    params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
    optimizers: Dict[str, torch.optim.Optimizer],
    state: Dict[str, Tensor],
    n: int,
    binoms: Tensor,
    min_opacity: float = 0.005,
):
    """Inplace append `n` new Gaussians sampled from the existing ones.

    Source Gaussians are drawn with probability proportional to their opacity;
    their opacities and scales are re-computed via `compute_relocation` to
    account for the copies, and each copy is appended at the end with
    zero-initialized optimizer state. Tensor entries in `state` are extended
    with zero rows.

    Args:
        params: A dictionary of parameters.
        optimizers: A dictionary of optimizers, each corresponding to a parameter.
        state: Extra running state; tensor values are extended in place.
        n: The number of Gaussians to add.
        binoms: Precomputed binomial-coefficient table consumed by `compute_relocation`.
        min_opacity: Lower clamp applied to the recomputed opacities.
    """
    opacities = torch.sigmoid(params["opacities"])

    eps = torch.finfo(torch.float32).eps
    probs = opacities.flatten()
    sampled_idxs = _multinomial_sample(probs, n, replacement=True)
    new_opacities, new_scales = compute_relocation(
        opacities=opacities[sampled_idxs],
        scales=torch.exp(params["scales"])[sampled_idxs],
        # ratios = how many copies (incl. the original) each source ends up with.
        ratios=torch.bincount(sampled_idxs)[sampled_idxs] + 1,
        binoms=binoms,
    )
    new_opacities = torch.clamp(new_opacities, max=1.0 - eps, min=min_opacity)

    def param_fn(name: str, p: Tensor) -> Tensor:
        # Rewrite the sampled rows first, then append copies of them.
        if name == "opacities":
            p[sampled_idxs] = torch.logit(new_opacities)
        elif name == "scales":
            p[sampled_idxs] = torch.log(new_scales)
        p = torch.cat([p, p[sampled_idxs]])
        return torch.nn.Parameter(p)

    def optimizer_fn(key: str, v: Tensor) -> Tensor:
        # Appended rows start with zeroed optimizer statistics.
        v_new = torch.zeros((len(sampled_idxs), *v.shape[1:]), device=v.device)
        return torch.cat([v, v_new])

    # update the parameters and the state in the optimizers
    _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers)
    # update the extra running state
    for k, v in state.items():
        # Bug fix: check the type BEFORE touching `v.shape` — previously
        # `v.shape` was read unconditionally, raising AttributeError for any
        # non-tensor entry (cf. `duplicate`, which checks first).
        if isinstance(v, torch.Tensor):
            v_new = torch.zeros((len(sampled_idxs), *v.shape[1:]), device=v.device)
            state[k] = torch.cat((v, v_new))
342 |
343 |
@torch.no_grad()
def inject_noise_to_position(
    params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
    optimizers: Dict[str, torch.optim.Optimizer],
    state: Dict[str, Tensor],
    scaler: float,
):
    """Inplace perturb the Gaussian centers with covariance-shaped noise.

    The per-Gaussian noise magnitude is gated by a steep logistic of
    (1 - opacity): nearly transparent Gaussians receive close to the full
    noise while opaque ones are barely moved. The isotropic noise is then
    transformed by each Gaussian's covariance before being added to "means".

    Args:
        params: A dictionary of parameters.
        optimizers: Unused; kept for a uniform call signature.
        state: Unused; kept for a uniform call signature.
        scaler: Global multiplier on the noise magnitude.
    """
    opacities = torch.sigmoid(params["opacities"].flatten())
    scales = torch.exp(params["scales"])
    covars, _ = quat_scale_to_covar_preci(
        params["quats"],
        scales,
        compute_covar=True,
        compute_preci=False,
        triu=False,
    )

    def steep_sigmoid(x, k=100, x0=0.995):
        # Sharp logistic centered at x0.
        return 1 / (1 + torch.exp(-k * (x - x0)))

    gate = steep_sigmoid(1 - opacities).unsqueeze(-1)
    noise = torch.randn_like(params["means"]) * gate * scaler
    # Shape the isotropic noise by each Gaussian's covariance.
    noise = torch.einsum("bij,bj->bi", covars, noise)
    params["means"].add_(noise)
371 |
--------------------------------------------------------------------------------
/gsplat/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import Tensor
6 |
7 |
def normalized_quat_to_rotmat(quat: Tensor) -> Tensor:
    """Convert normalized quaternion to rotation matrix.

    Args:
        quat: Normalized quaternion in wxyz convention. (..., 4)

    Returns:
        Rotation matrix (..., 3, 3)
    """
    assert quat.shape[-1] == 4, quat.shape
    w, x, y, z = torch.unbind(quat, dim=-1)
    # Standard quaternion-to-rotation formula, assembled row by row.
    row0 = [1 - 2 * (y**2 + z**2), 2 * (x * y - w * z), 2 * (x * z + w * y)]
    row1 = [2 * (x * y + w * z), 1 - 2 * (x**2 + z**2), 2 * (y * z - w * x)]
    row2 = [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x**2 + y**2)]
    flat = torch.stack(row0 + row1 + row2, dim=-1)  # (..., 9)
    return flat.reshape(quat.shape[:-1] + (3, 3))
34 |
35 |
def log_transform(x):
    """Signed log transform: sign(x) * log(1 + |x|). Odd, monotone, 0 at 0."""
    return x.sign() * x.abs().log1p()
38 |
39 |
def inverse_log_transform(y):
    """Inverse of `log_transform`: sign(y) * (exp(|y|) - 1)."""
    return y.sign() * y.abs().expm1()
42 |
43 |
def depth_to_points(
    depths: Tensor, camtoworlds: Tensor, Ks: Tensor, z_depth: bool = True
) -> Tensor:
    """Convert depth maps to 3D points

    Args:
        depths: Depth maps [..., H, W, 1]
        camtoworlds: Camera-to-world transformation matrices [..., 4, 4]
        Ks: Camera intrinsics [..., 3, 3]
        z_depth: Whether the depth is in z-depth (True) or ray depth (False)

    Returns:
        points: 3D points in the world coordinate system [..., H, W, 3]
    """
    assert depths.shape[-1] == 1, f"Invalid depth shape: {depths.shape}"
    assert camtoworlds.shape[-2:] == (
        4,
        4,
    ), f"Invalid viewmats shape: {camtoworlds.shape}"
    assert Ks.shape[-2:] == (3, 3), f"Invalid Ks shape: {Ks.shape}"
    assert (
        depths.shape[:-3] == camtoworlds.shape[:-2] == Ks.shape[:-2]
    ), f"Shape mismatch! depths: {depths.shape}, viewmats: {camtoworlds.shape}, Ks: {Ks.shape}"

    device = depths.device
    height, width = depths.shape[-3:-1]

    # Pixel grid, [H, W] each (xy indexing: first axis varies along width).
    xs, ys = torch.meshgrid(
        torch.arange(width, device=device),
        torch.arange(height, device=device),
        indexing="xy",
    )

    fx = Ks[..., 0, 0]  # [...]
    fy = Ks[..., 1, 1]  # [...]
    cx = Ks[..., 0, 2]  # [...]
    cy = Ks[..., 1, 2]  # [...]

    # Per-pixel ray directions in camera coordinates on the z = 1 plane;
    # the +0.5 shifts samples to pixel centers.
    u = (xs - cx[..., None, None] + 0.5) / fx[..., None, None]
    v = (ys - cy[..., None, None] + 0.5) / fy[..., None, None]
    camera_dirs = torch.stack([u, v, torch.ones_like(u)], dim=-1)  # [..., H, W, 3]

    # Rotate the rays into world coordinates and pick up the camera origins.
    directions = torch.einsum(
        "...ij,...hwj->...hwi", camtoworlds[..., :3, :3], camera_dirs
    )  # [..., H, W, 3]
    origins = camtoworlds[..., :3, -1]  # [..., 3]

    if not z_depth:
        # Ray depth: distances are measured along unit-length rays.
        directions = F.normalize(directions, dim=-1)

    return origins[..., None, None, :] + depths * directions
106 |
107 |
def depth_to_normal(
    depths: Tensor, camtoworlds: Tensor, Ks: Tensor, z_depth: bool = True
) -> Tensor:
    """Convert depth maps to surface normals

    Args:
        depths: Depth maps [..., H, W, 1]
        camtoworlds: Camera-to-world transformation matrices [..., 4, 4]
        Ks: Camera intrinsics [..., 3, 3]
        z_depth: Whether the depth is in z-depth (True) or ray depth (False)

    Returns:
        normals: Surface normals in the world coordinate system [..., H, W, 3]
    """
    pts = depth_to_points(depths, camtoworlds, Ks, z_depth=z_depth)  # [..., H, W, 3]
    # Central differences over the interior pixels, both [..., H-2, W-2, 3].
    dx = pts[..., 2:, 1:-1, :] - pts[..., :-2, 1:-1, :]
    dy = pts[..., 1:-1, 2:, :] - pts[..., 1:-1, :-2, :]
    normals = F.normalize(torch.cross(dx, dy, dim=-1), dim=-1)  # [..., H-2, W-2, 3]
    # Zero-pad the one-pixel border back to the full [..., H, W, 3] resolution.
    return F.pad(normals, (0, 0, 1, 1, 1, 1), value=0.0)
132 |
133 |
def get_projection_matrix(znear, zfar, fovX, fovY, device="cuda"):
    """Create OpenGL-style projection matrix.

    Args:
        znear: Near-plane distance.
        zfar: Far-plane distance.
        fovX: Horizontal field of view, in radians.
        fovY: Vertical field of view, in radians.
        device: Device for the returned tensor. Default: "cuda".

    Returns:
        A [4, 4] perspective projection matrix.
    """
    # Frustum extents at the near plane (symmetric about the optical axis).
    right = znear * math.tan(fovX / 2)
    top = znear * math.tan(fovY / 2)
    left, bottom = -right, -top

    z_sign = 1.0
    P = torch.zeros(4, 4, device=device)
    P[0, 0] = 2.0 * znear / (right - left)
    P[1, 1] = 2.0 * znear / (top - bottom)
    P[0, 2] = (right + left) / (right - left)
    P[1, 2] = (top + bottom) / (top - bottom)
    P[3, 2] = z_sign
    P[2, 2] = z_sign * zfar / (zfar - znear)
    P[2, 3] = -(zfar * znear) / (zfar - znear)
    return P
156 |
157 |
158 | # def depth_to_normal(
159 | # depths: Tensor, camtoworlds: Tensor, Ks: Tensor, near_plane: float, far_plane: float
160 | # ) -> Tensor:
161 | # """
162 | # Convert depth to surface normal
163 |
164 | # Args:
165 | # depths: Z-depth of the Gaussians.
166 | # camtoworlds: camera to world transformation matrix.
167 | # Ks: camera intrinsics.
168 | # near_plane: Near plane distance.
169 | # far_plane: Far plane distance.
170 |
171 | # Returns:
172 | # Surface normals.
173 | # """
174 | # height, width = depths.shape[1:3]
175 | # viewmats = torch.linalg.inv(camtoworlds) # [C, 4, 4]
176 |
177 | # normals = []
178 | # for cid, depth in enumerate(depths):
179 | # FoVx = 2 * math.atan(width / (2 * Ks[cid, 0, 0].item()))
180 | # FoVy = 2 * math.atan(height / (2 * Ks[cid, 1, 1].item()))
181 | # world_view_transform = viewmats[cid].transpose(0, 1)
182 | # projection_matrix = _get_projection_matrix(
183 | # znear=near_plane, zfar=far_plane, fovX=FoVx, fovY=FoVy, device=depths.device
184 | # ).transpose(0, 1)
185 | # full_proj_transform = (
186 | # world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))
187 | # ).squeeze(0)
188 | # normal = _depth_to_normal(
189 | # depth,
190 | # world_view_transform,
191 | # full_proj_transform,
192 | # Ks[cid, 0, 0],
193 | # Ks[cid, 1, 1],
194 | # )
195 | # normals.append(normal)
196 | # normals = torch.stack(normals, dim=0)
197 | # return normals
198 |
199 |
200 | # # ref: https://github.com/hbb1/2d-gaussian-splatting/blob/61c7b417393d5e0c58b742ad5e2e5f9e9f240cc6/utils/point_utils.py#L26
201 | # def _depths_to_points(
202 | # depthmap, world_view_transform, full_proj_transform, fx, fy
203 | # ) -> Tensor:
204 | # c2w = (world_view_transform.T).inverse()
205 | # H, W = depthmap.shape[:2]
206 |
207 | # intrins = (
208 | # torch.tensor([[fx, 0.0, W / 2.0], [0.0, fy, H / 2.0], [0.0, 0.0, 1.0]])
209 | # .float()
210 | # .cuda()
211 | # )
212 |
213 | # grid_x, grid_y = torch.meshgrid(
214 | # torch.arange(W, device="cuda").float(),
215 | # torch.arange(H, device="cuda").float(),
216 | # indexing="xy",
217 | # )
218 | # points = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1).reshape(
219 | # -1, 3
220 | # )
221 | # rays_d = points @ intrins.inverse().T @ c2w[:3, :3].T
222 | # rays_o = c2w[:3, 3]
223 | # points = depthmap.reshape(-1, 1) * rays_d + rays_o
224 | # return points
225 |
226 |
227 | # def _depth_to_normal(
228 | # depth, world_view_transform, full_proj_transform, fx, fy
229 | # ) -> Tensor:
230 | # points = _depths_to_points(
231 | # depth,
232 | # world_view_transform,
233 | # full_proj_transform,
234 | # fx,
235 | # fy,
236 | # ).reshape(*depth.shape[:2], 3)
237 | # output = torch.zeros_like(points)
238 | # dx = torch.cat([points[2:, 1:-1] - points[:-2, 1:-1]], dim=0)
239 | # dy = torch.cat([points[1:-1, 2:] - points[1:-1, :-2]], dim=1)
240 | # normal_map = torch.nn.functional.normalize(torch.cross(dx, dy, dim=-1), dim=-1)
241 | # output[1:-1, 1:-1, :] = normal_map
242 | # return output
243 |
244 |
245 | # def _get_projection_matrix(znear, zfar, fovX, fovY, device="cuda") -> Tensor:
246 | # tanHalfFovY = math.tan((fovY / 2))
247 | # tanHalfFovX = math.tan((fovX / 2))
248 |
249 | # top = tanHalfFovY * znear
250 | # bottom = -top
251 | # right = tanHalfFovX * znear
252 | # left = -right
253 |
254 | # P = torch.zeros(4, 4, device=device)
255 |
256 | # z_sign = 1.0
257 |
258 | # P[0, 0] = 2.0 * znear / (right - left)
259 | # P[1, 1] = 2.0 * znear / (top - bottom)
260 | # P[0, 2] = (right + left) / (right - left)
261 | # P[1, 2] = (top + bottom) / (top - bottom)
262 | # P[3, 2] = z_sign
263 | # P[2, 2] = z_sign * zfar / (zfar - znear)
264 | # P[2, 3] = -(zfar * znear) / (zfar - znear)
265 | # return P
266 |
--------------------------------------------------------------------------------
/gsplat/version.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.0.0"
2 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import glob
import os
import os.path as osp
import platform
import sys

from setuptools import find_packages, setup

# Populated by exec()ing gsplat/version.py below.
__version__ = None
exec(open("gsplat/version.py", "r").read())

URL = "https://github.com/nerfstudio-project/gsplat"

# Build-time switches, all driven by environment variables:
#   BUILD_NO_CUDA=1  -> skip compiling the CUDA extensions entirely.
#   WITH_SYMBOLS=1   -> keep symbols (skip the "-s" strip link flag).
#   LINE_INFO=1      -> pass -lineinfo to nvcc for profiler line mapping.
BUILD_NO_CUDA = os.getenv("BUILD_NO_CUDA", "0") == "1"
WITH_SYMBOLS = os.getenv("WITH_SYMBOLS", "0") == "1"
LINE_INFO = os.getenv("LINE_INFO", "0") == "1"
18 |
def get_ext():
    """Return the torch BuildExtension command class (ninja off, no ABI suffix)."""
    from torch.utils.cpp_extension import BuildExtension

    options = dict(no_python_abi_suffix=True, use_ninja=False)
    return BuildExtension.with_options(**options)
23 |
24 |
def get_extensions():
    """Collect the two gsplat CUDA extensions (legacy v1 and current v2).

    Assembles compiler/linker flag sets based on the platform, PyTorch's
    OpenMP support, and environment variables (NVCC_FLAGS, LINE_INFO,
    WITH_SYMBOLS), then returns a list of two CUDAExtension objects.
    """
    import torch
    from torch.__config__ import parallel_info
    from torch.utils.cpp_extension import CUDAExtension

    def collect_sources(extensions_dir):
        # All .cu/.cpp sources in the directory, excluding hipified copies.
        sources = glob.glob(osp.join(extensions_dir, "*.cu")) + glob.glob(
            osp.join(extensions_dir, "*.cpp")
        )
        return [path for path in sources if "hip" not in path]

    extensions_dir_v1 = osp.join("gsplat", "cuda_legacy", "csrc")
    sources_v1 = collect_sources(extensions_dir_v1)

    extensions_dir_v2 = osp.join("gsplat", "cuda", "csrc")
    sources_v2 = collect_sources(extensions_dir_v2)

    undef_macros = []
    define_macros = []

    if sys.platform == "win32":
        define_macros += [("gsplat_EXPORTS", None)]

    extra_compile_args = {"cxx": ["-O3"]}
    if not os.name == "nt":  # Not on Windows:
        extra_compile_args["cxx"] += ["-Wno-sign-compare"]
    extra_link_args = [] if WITH_SYMBOLS else ["-s"]

    # Enable OpenMP only when PyTorch itself was built with it (and not on macOS).
    info = parallel_info()
    if (
        "backend: OpenMP" in info
        and "OpenMP not found" not in info
        and sys.platform != "darwin"
    ):
        extra_compile_args["cxx"] += ["-DAT_PARALLEL_OPENMP"]
        if sys.platform == "win32":
            extra_compile_args["cxx"] += ["/openmp"]
        else:
            extra_compile_args["cxx"] += ["-fopenmp"]
    else:
        print("Compiling without OpenMP...")

    # Compile for mac arm64
    if sys.platform == "darwin" and platform.machine() == "arm64":
        extra_compile_args["cxx"] += ["-arch", "arm64"]
        extra_link_args += ["-arch", "arm64"]

    nvcc_flags = os.getenv("NVCC_FLAGS", "")
    nvcc_flags = [] if nvcc_flags == "" else nvcc_flags.split(" ")
    nvcc_flags += ["-O3", "--use_fast_math"]
    if LINE_INFO:
        nvcc_flags += ["-lineinfo"]
    if torch.version.hip:
        # USE_ROCM was added to later versions of PyTorch.
        # Define here to support older PyTorch versions as well:
        define_macros += [("USE_ROCM", None)]
        undef_macros += ["__HIP_NO_HALF_CONVERSIONS__"]
    else:
        nvcc_flags += ["--expt-relaxed-constexpr"]
    extra_compile_args["nvcc"] = nvcc_flags
    if sys.platform == "win32":
        extra_compile_args["nvcc"] += ["-DWIN32_LEAN_AND_MEAN"]

    def make_extension(name, sources):
        # Both extensions include the v2 directory because glm lives in v2.
        return CUDAExtension(
            name,
            sources,
            include_dirs=[extensions_dir_v2],
            define_macros=define_macros,
            undef_macros=undef_macros,
            extra_compile_args=extra_compile_args,
            extra_link_args=extra_link_args,
        )

    # Fix: the names were previously written as f-strings with no
    # placeholders (lint F541); plain string literals are equivalent.
    return [
        make_extension("gsplat.csrc_legacy", sources_v1),
        make_extension("gsplat.csrc", sources_v2),
    ]
108 |
109 |
# Package metadata and build wiring. When BUILD_NO_CUDA=1, both the CUDA
# extensions and their custom build_ext command are skipped entirely.
setup(
    name="gsplat",
    # NOTE(review): the description string has a stray leading space —
    # harmless, but worth fixing in a behavior-changing commit.
    version=__version__,
    description=" Python package for differentiable rasterization of gaussians",
    keywords="gaussian, splatting, cuda",
    url=URL,
    download_url=f"{URL}/archive/gsplat-{__version__}.tar.gz",
    python_requires=">=3.7",
    install_requires=[
        "ninja",
        "numpy",
        "jaxtyping",
        "rich>=12",
        "torch",
        "typing_extensions; python_version<'3.8'",
    ],
    extras_require={
        # dev dependencies. Install them by `pip install gsplat[dev]`
        "dev": [
            "black[jupyter]==22.3.0",
            "isort==5.10.1",
            "pylint==2.13.4",
            "pytest==7.1.2",
            "pytest-xdist==2.5.0",
            "typeguard>=2.13.3",
            "pyyaml==6.0",
            "build",
            "twine",
        ],
    },
    ext_modules=get_extensions() if not BUILD_NO_CUDA else [],
    cmdclass={"build_ext": get_ext()} if not BUILD_NO_CUDA else {},
    packages=find_packages(),
    # https://github.com/pypa/setuptools/issues/1461#issuecomment-954725244
    include_package_data=True,
)
146 |
--------------------------------------------------------------------------------
/tests/_test_distributed.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from gsplat.distributed import (
5 | all_gather_int32,
6 | all_gather_tensor_list,
7 | all_to_all_int32,
8 | all_to_all_tensor_list,
9 | cli,
10 | )
11 |
12 |
def _main_all_gather_int32(local_rank: int, world_rank: int, world_size: int, _):
    """Worker: gather each rank's id (as int and as tensor) and check ordering."""
    device = torch.device("cuda", local_rank)

    # Plain Python-int round trip.
    gathered = all_gather_int32(world_size, world_rank, device=device)
    for rank in range(world_size):
        assert gathered[rank] == rank

    # Tensor round trip.
    value = torch.tensor(world_rank, device=device, dtype=torch.int)
    gathered = all_gather_int32(world_size, value, device=device)
    for rank in range(world_size):
        assert gathered[rank] == torch.tensor(rank, device=device, dtype=torch.int)
25 |
26 |
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_all_gather_int32():
    # Launch the worker through the distributed cli() runner.
    cli(_main_all_gather_int32, None, verbose=True)
30 |
31 |
def _main_all_to_all_int32(local_rank: int, world_rank: int, world_size: int, _):
    """Worker: all-to-all exchange of rank ids (list and tensor variants)."""
    device = torch.device("cuda", local_rank)

    # List-of-int variant: rank r sends value r to peer r, so every slot of
    # this worker's result should equal its own rank.
    outgoing = list(range(world_size))
    incoming = all_to_all_int32(world_size, outgoing, device=device)
    for slot in range(world_size):
        assert incoming[slot] == world_rank

    # Tensor variant with the same expectation.
    outgoing = torch.arange(world_size, device=device, dtype=torch.int)
    incoming = all_to_all_int32(world_size, outgoing, device=device)
    for slot in range(world_size):
        assert incoming[slot] == torch.tensor(
            world_rank, device=device, dtype=torch.int
        )
44 |
45 |
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_all_to_all_int32():
    # Launch the worker through the distributed CLI helper
    # (presumably one process per available GPU — confirm in gsplat.distributed.cli).
    cli(_main_all_to_all_int32, None, verbose=True)
49 |
50 |
def _main_all_gather_tensor_list(local_rank: int, world_rank: int, world_size: int, _):
    """Worker: gather a list of tensors; result is rank-order concatenation."""
    device = torch.device("cuda", local_rank)
    n = 10

    # Each rank contributes two tensors filled with its own rank id.
    local_tensors = [
        torch.full((n, 2), world_rank, device=device),
        torch.full((n, 3, 3), world_rank, device=device),
    ]

    # Expected result: per-rank chunks concatenated in rank order along dim 0.
    expected_tensors = [
        torch.cat(
            [torch.full((n, 2), rank, device=device) for rank in range(world_size)]
        ),
        torch.cat(
            [torch.full((n, 3, 3), rank, device=device) for rank in range(world_size)]
        ),
    ]

    gathered = all_gather_tensor_list(world_size, local_tensors)
    for got, want in zip(gathered, expected_tensors):
        assert torch.equal(got, want)
68 |
69 |
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_all_gather_tensor_list():
    # Launch the worker through the distributed CLI helper
    # (presumably one process per available GPU — confirm in gsplat.distributed.cli).
    cli(_main_all_gather_tensor_list, None, verbose=True)
73 |
74 |
def _main_all_to_all_tensor_list(local_rank: int, world_rank: int, world_size: int, _):
    """Worker: all-to-all tensor-list exchange with uneven per-rank splits.

    Rank r starts with sum(splits) rows filled with r and sends splits[i]
    rows to rank i; afterwards rank r holds world_size chunks of
    splits[r] rows each, where chunk i is filled with sender i's rank id.
    """
    device = torch.device("cuda", local_rank)
    splits = torch.arange(0, world_size, device=device)
    total_rows = splits.sum().item()

    local_tensors = [
        torch.full((total_rows, 2), world_rank, device=device),
        torch.full((total_rows, 3, 3), world_rank, device=device),
    ]

    # Rows this rank keeps from each sender: splits[world_rank] per sender.
    rows = splits[world_rank]
    expected_tensors = [
        torch.cat(
            [
                torch.full((rows, 2), sender, device=device)
                for sender in range(world_size)
            ]
        ),
        torch.cat(
            [
                torch.full((rows, 3, 3), sender, device=device)
                for sender in range(world_size)
            ]
        ),
    ]

    exchanged = all_to_all_tensor_list(world_size, local_tensors, splits)
    for got, want in zip(exchanged, expected_tensors):
        assert torch.equal(got, want)
103 |
104 |
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_all_to_all_tensor_list():
    # Launch the worker through the distributed CLI helper
    # (presumably one process per available GPU — confirm in gsplat.distributed.cli).
    cli(_main_all_to_all_tensor_list, None, verbose=True)
108 |
109 |
if __name__ == "__main__":
    # Allow running this module directly, outside of pytest.
    for _test in (
        test_all_gather_int32,
        test_all_to_all_int32,
        test_all_gather_tensor_list,
        test_all_to_all_tensor_list,
    ):
        _test()
115 |
--------------------------------------------------------------------------------
/tests/test_compression.py:
--------------------------------------------------------------------------------
1 | """Tests for the gsplat compression module.
2 |
3 | Usage:
4 | ```bash
5 | pytest -s
6 | ```
7 | """
8 |
9 | import pytest
10 | import torch
11 |
12 | device = torch.device("cuda:0")
13 |
14 |
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_png_compression():
    """Round-trip Gaussian splats through PNG compression and verify structure.

    Checks that decompression returns the same parameter names and shapes as
    the input. Exact values are not asserted since the codec may be lossy
    (quantization to PNG — see gsplat.compression.PngCompression).
    """
    from gsplat.compression import PngCompression

    torch.manual_seed(42)

    # Prepare Gaussians
    N = 100000
    splats = torch.nn.ParameterDict(
        {
            "means": torch.randn(N, 3),
            "scales": torch.randn(N, 3),
            "quats": torch.randn(N, 4),
            "opacities": torch.randn(N),
            "sh0": torch.randn(N, 1, 3),
            "shN": torch.randn(N, 24, 3),
            "features": torch.randn(N, 128),
        }
    ).to(device)
    compress_dir = "/tmp/gsplat/compression"

    compression_method = PngCompression()
    # run compression and save the compressed files to compress_dir
    compression_method.compress(compress_dir, splats)
    # decompress the compressed files
    splats_c = compression_method.decompress(compress_dir)

    # The original test discarded the decompressed result; verify the round
    # trip actually reproduces every parameter with its original shape.
    assert set(splats_c.keys()) == set(splats.keys())
    for name, param in splats.items():
        assert splats_c[name].shape == param.shape, name
41 |
42 |
if __name__ == "__main__":
    # Allow running this test module directly, outside of pytest.
    test_png_compression()
45 |
--------------------------------------------------------------------------------
/tests/test_rasterization.py:
--------------------------------------------------------------------------------
1 | """Tests for the gsplat rasterization functions.
2 |
3 | Usage:
4 | ```bash
5 | pytest -s
6 | ```
7 | """
8 |
9 | import math
10 | from typing import Optional
11 |
12 | import pytest
13 | import torch
14 |
15 | device = torch.device("cuda:0")
16 |
17 |
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
@pytest.mark.parametrize("per_view_color", [True, False])
@pytest.mark.parametrize("sh_degree", [None, 3])
@pytest.mark.parametrize("render_mode", ["RGB", "RGB+D", "D"])
@pytest.mark.parametrize("packed", [True, False])
def test_rasterization(per_view_color: bool, sh_degree: Optional[int], render_mode: str, packed: bool):
    """Smoke-test `rasterization` across color layouts, SH degrees and modes.

    Only the output image shape is asserted; the parametrization covers
    per-view vs shared colors, raw RGB vs SH coefficients, the three render
    modes, and packed vs unpacked rasterization.
    """
    from gsplat.rendering import rasterization

    torch.manual_seed(42)

    C, N = 2, 10_000

    # Random Gaussian parameters (same RNG draw order as before).
    means = torch.rand(N, 3, device=device)
    velocities = torch.randn(N, 3, device=device) * 0.01
    quats = torch.randn(N, 4, device=device)
    scales = torch.rand(N, 3, device=device)
    opacities = torch.rand(N, device=device)

    # Color layout: optionally a leading camera axis, optionally SH coeffs.
    leading = (C,) if per_view_color else ()
    if sh_degree is None:
        color_shape = leading + (N, 3)
    else:
        color_shape = leading + (N, (sh_degree + 1) ** 2, 3)
    colors = torch.rand(*color_shape, device=device)

    # Shared pinhole intrinsics (principal point at image center) and
    # identity extrinsics for every camera.
    width, height = 300, 200
    focal = 300.0
    Ks = torch.tensor(
        [[focal, 0.0, width / 2.0], [0.0, focal, height / 2.0], [0.0, 0.0, 1.0]],
        device=device,
    ).expand(C, -1, -1)
    viewmats = torch.eye(4, device=device).expand(C, -1, -1)

    # Small per-camera motion plus rolling-shutter readout times.
    linear_velocity = torch.randn(C, 3, device=device) * 0.01
    angular_velocity = torch.randn(C, 3, device=device) * 0.01
    rolling_shutter_time = torch.rand(C, device=device) * 0.1

    colors_out, _, _ = rasterization(
        means=means,
        quats=quats,
        scales=scales,
        opacities=opacities,
        colors=colors,
        velocities=velocities,
        viewmats=viewmats,
        Ks=Ks,
        width=width,
        height=height,
        linear_velocity=linear_velocity,
        angular_velocity=angular_velocity,
        rolling_shutter_time=rolling_shutter_time,
        sh_degree=sh_degree,
        render_mode=render_mode,
        packed=packed,
    )

    # Channel count of the rendered image depends on the render mode.
    expected_channels = {"RGB": 3, "RGB+D": 4, "D": 1}[render_mode]
    assert colors_out.shape == (C, height, width, expected_channels)
82 |
83 |
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
@pytest.mark.parametrize("channels", [3, 32, 128])
def test_lidar_rasterization(channels: int):
    """Smoke-test `lidar_rasterization` and check the output feature shape.

    Builds a random Gaussian scene and a spherical raster grid (elevation x
    azimuth with a range and timestamp per point), rasterizes `channels`
    lidar features per Gaussian, and asserts the rendered image carries one
    extra channel (presumably range/depth — confirm in
    gsplat.rendering.lidar_rasterization).
    """
    from gsplat.rendering import lidar_rasterization

    torch.manual_seed(42)

    # Two sensor poses, 10k random Gaussians with small random velocities.
    C, N = 2, 10_000
    means = torch.rand(N, 3, device=device)
    quats = torch.randn(N, 4, device=device)
    scales = torch.rand(N, 3, device=device)
    opacities = torch.rand(N, device=device)
    velocities = torch.randn(N, 3, device=device) * 0.01

    # Sensor field of view: full 360° azimuth sweep at 0.2°/pixel,
    # ±45° elevation split into 32 beam channels.
    min_azimuth = -180
    max_azimuth = 180
    min_elevation = -45
    max_elevation = 45
    n_elevation_channels = 32
    azimuth_resolution = 0.2

    # Tiling: the 32 elevation channels are split into 32 // 4 = 8 tile rows,
    # so linspace produces 9 boundary angles.
    tile_width = 64
    tile_height = 4
    tile_elevation_boundaries = torch.linspace(
        min_elevation, max_elevation, n_elevation_channels // tile_height + 1, device=means.device
    )

    # Identity extrinsics for both sensors; random per-Gaussian features.
    viewmats = torch.eye(4, device=device).expand(C, -1, -1)
    lidar_features = torch.randn(C, len(means), channels, device=device)

    # Pixel-center angles of the raster grid along each axis.
    image_width = math.ceil((max_azimuth - min_azimuth) / azimuth_resolution)
    raster_pts_azim = torch.linspace(
        min_azimuth + azimuth_resolution / 2, max_azimuth - azimuth_resolution / 2, image_width, device=means.device
    )
    raster_pts_elev = torch.linspace(
        min_elevation + (max_elevation - min_elevation) / n_elevation_channels / 2,
        max_elevation - (max_elevation - min_elevation) / n_elevation_channels / 2,
        n_elevation_channels,
        device=means.device,
    )
    # meshgrid yields (elev, azim) pairs; the [..., [1, 0]] swap reorders each
    # grid point to (azimuth, elevation).
    raster_pts = torch.stack(torch.meshgrid(raster_pts_elev, raster_pts_azim), dim=-1)[..., [1, 0]]
    # Random ranges in [0, 10); the mask leaves ~10% of points unperturbed
    # by the noise added below.
    ranges = torch.rand(n_elevation_channels, image_width, 1, device=device) * 10
    keep_range_mask = torch.rand(n_elevation_channels, image_width, device=device) > 0.1
    raster_pts = torch.cat([raster_pts, ranges], dim=-1)
    raster_pts = raster_pts.unsqueeze(0).repeat(C, 1, 1, 1)
    # add randomness
    raster_pts += torch.randn_like(raster_pts) * 0.01 * keep_range_mask[None, ..., None]

    # Small sensor ego-motion and per-sensor rolling-shutter durations.
    linear_velocity = torch.randn(C, 3, device=device) * 0.01
    angular_velocity = torch.randn(C, 3, device=device) * 0.01
    rolling_shutter_time = torch.rand(C, device=device) * 0.1

    # add timestamps
    raster_pts = torch.cat(
        [raster_pts, rolling_shutter_time.max() * torch.randn(raster_pts[..., 0:1].shape, device=raster_pts.device)],
        dim=-1,
    )

    render_lidar_features, _, _, _ = lidar_rasterization(
        means=means,
        quats=quats,
        scales=scales,
        opacities=opacities,
        lidar_features=lidar_features,
        velocities=velocities,
        linear_velocity=linear_velocity,
        angular_velocity=angular_velocity,
        rolling_shutter_time=rolling_shutter_time,
        viewmats=viewmats,
        min_azimuth=min_azimuth,
        max_azimuth=max_azimuth,
        min_elevation=min_elevation,
        max_elevation=max_elevation,
        n_elevation_channels=n_elevation_channels,
        azimuth_resolution=azimuth_resolution,
        raster_pts=raster_pts,
        tile_width=tile_width,
        tile_height=tile_height,
        tile_elevation_boundaries=tile_elevation_boundaries,
    )

    # One image per sensor: `channels` features plus one extra output channel.
    n_azimuth_pixels = math.ceil((max_azimuth - min_azimuth) / azimuth_resolution)
    assert render_lidar_features.shape == (C, n_elevation_channels, n_azimuth_pixels, channels + 1)
167 |
--------------------------------------------------------------------------------
/tests/test_strategy.py:
--------------------------------------------------------------------------------
1 | """Tests for the gsplat densification strategies.
2 |
3 | Usage:
4 | ```bash
5 | pytest -s
6 | ```
7 | """
8 |
9 | import pytest
10 | import torch
11 |
12 | device = torch.device("cuda:0")
13 |
14 |
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_strategy():
    """Exercise one densification step of both built-in strategies end to end."""
    from gsplat.rendering import rasterization
    from gsplat.strategy import DefaultStrategy, MCMCStrategy

    torch.manual_seed(42)

    # A tiny random set of Gaussians, one Adam optimizer per parameter group.
    num_gaussians = 100
    params = torch.nn.ParameterDict(
        {
            "means": torch.randn(num_gaussians, 3),
            "scales": torch.rand(num_gaussians, 3),
            "quats": torch.randn(num_gaussians, 4),
            "opacities": torch.rand(num_gaussians),
            "colors": torch.rand(num_gaussians, 3),
        }
    ).to(device)
    optimizers = {}
    for name, param in params.items():
        optimizers[name] = torch.optim.Adam([param], lr=1e-3)

    # A dummy rendering call to produce the `info` dict the strategies consume.
    render_colors, render_alphas, info = rasterization(
        means=params["means"],
        quats=params["quats"],  # F.normalize is fused into the kernel
        scales=torch.exp(params["scales"]),
        opacities=torch.sigmoid(params["opacities"]),
        colors=params["colors"],
        velocities=None,
        viewmats=torch.eye(4).unsqueeze(0).to(device),
        Ks=torch.eye(3).unsqueeze(0).to(device),
        width=10,
        height=10,
        packed=False,
    )
    loss = render_colors.mean()

    # DefaultStrategy: pre-backward hook, backward pass, post-backward hook.
    default_strategy = DefaultStrategy(verbose=True)
    default_strategy.check_sanity(params, optimizers)
    default_state = default_strategy.initialize_state()
    default_strategy.step_pre_backward(params, optimizers, default_state, step=600, info=info)
    loss.backward(retain_graph=True)
    default_strategy.step_post_backward(params, optimizers, default_state, step=600, info=info)

    # MCMCStrategy: no pre-backward hook; post-backward also takes the lr.
    mcmc_strategy = MCMCStrategy(verbose=True)
    mcmc_strategy.check_sanity(params, optimizers)
    mcmc_state = mcmc_strategy.initialize_state()
    loss.backward(retain_graph=True)
    mcmc_strategy.step_post_backward(params, optimizers, mcmc_state, step=600, info=info, lr=1e-3)
64 |
65 |
if __name__ == "__main__":
    # Allow running this test module directly, outside of pytest.
    test_strategy()
68 |
--------------------------------------------------------------------------------