├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── LICENSE.md ├── README.md ├── cuda_rasterizer ├── auxiliary.h ├── backward.cu ├── backward.h ├── config.h ├── forward.cu ├── forward.h ├── rasterizer.h ├── rasterizer_impl.cu └── rasterizer_impl.h ├── diff_gauss └── __init__.py ├── ext.cpp ├── rasterize_points.cu ├── rasterize_points.h ├── setup.py └── third_party └── stbi_image_write.h /.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | *.egg-info/ 3 | dist/ 4 | __pycache__/ 5 | /*.sh 6 | .vscode/ -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/glm"] 2 | path = third_party/glm 3 | url = https://github.com/g-truc/glm.git 4 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | cmake_minimum_required(VERSION 3.20) 13 | 14 | project(DiffRast LANGUAGES CUDA CXX) 15 | 16 | set(CMAKE_CXX_STANDARD 17) 17 | set(CMAKE_CXX_EXTENSIONS OFF) 18 | set(CMAKE_CUDA_STANDARD 17) 19 | 20 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") 21 | 22 | add_library(CudaRasterizer 23 | cuda_rasterizer/backward.h 24 | cuda_rasterizer/backward.cu 25 | cuda_rasterizer/forward.h 26 | cuda_rasterizer/forward.cu 27 | cuda_rasterizer/auxiliary.h 28 | cuda_rasterizer/rasterizer_impl.cu 29 | cuda_rasterizer/rasterizer_impl.h 30 | cuda_rasterizer/rasterizer.h 31 | ) 32 | 33 | set_target_properties(CudaRasterizer PROPERTIES CUDA_ARCHITECTURES "70;75;86") 34 | 35 | target_include_directories(CudaRasterizer PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cuda_rasterizer) 36 | target_include_directories(CudaRasterizer PRIVATE third_party/glm ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) 37 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Gaussian-Splatting License 2 | =========================== 3 | 4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**. 5 | The *Software* is in the process of being registered with the Agence pour la Protection des 6 | Programmes (APP). 7 | 8 | The *Software* is still being developed by the *Licensor*. 9 | 10 | *Licensor*'s goal is to allow the research community to use, test and evaluate 11 | the *Software*. 12 | 13 | ## 1. Definitions 14 | 15 | *Licensee* means any person or entity that uses the *Software* and distributes 16 | its *Work*. 17 | 18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII 19 | 20 | *Software* means the original work of authorship made available under this 21 | License ie gaussian-splatting. 22 | 23 | *Work* means the *Software* and any additions to or derivative works of the 24 | *Software* that are made available under this License. 25 | 26 | 27 | ## 2. Purpose 28 | This license is intended to define the rights granted to the *Licensee* by 29 | Licensors under the *Software*. 30 | 31 | ## 3. Rights granted 32 | 33 | For the above reasons Licensors have decided to distribute the *Software*. 34 | Licensors grant non-exclusive rights to use the *Software* for research purposes 35 | to research users (both academic and industrial), free of charge, without right 36 | to sublicense.. The *Software* may be used "non-commercially", i.e., for research 37 | and/or evaluation purposes only. 38 | 39 | Subject to the terms and conditions of this License, you are granted a 40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of, 41 | publicly display, publicly perform and distribute its *Work* and any resulting 42 | derivative works in any form. 43 | 44 | ## 4. Limitations 45 | 46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do 47 | so under this License, (b) you include a complete copy of this License with 48 | your distribution, and (c) you retain without modification any copyright, 49 | patent, trademark, or attribution notices that are present in the *Work*. 50 | 51 | **4.2 Derivative Works.** You may specify that additional or different terms apply 52 | to the use, reproduction, and distribution of your derivative works of the *Work* 53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in 54 | Section 2 applies to your derivative works, and (b) you identify the specific 55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms, 56 | this License (including the redistribution requirements in Section 3.1) will 57 | continue to apply to the *Work* itself. 58 | 59 | **4.3** Any other use without of prior consent of Licensors is prohibited. Research 60 | users explicitly acknowledge having received from Licensors all information 61 | allowing to appreciate the adequacy between of the *Software* and their needs and 62 | to undertake all necessary precautions for its execution and use. 63 | 64 | **4.4** The *Software* is provided both as a compiled library file and as source 65 | code. In case of using the *Software* for a publication or other results obtained 66 | through the use of the *Software*, users are strongly encouraged to cite the 67 | corresponding publications as explained in the documentation of the *Software*. 68 | 69 | ## 5. Disclaimer 70 | 71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES 72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY 73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL 74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES 75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL 76 | USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR 77 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE 78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE 80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) 81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR 83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. 84 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Differential Gaussian Rasterization 2 | 3 | **What's new** : Except for the RGB image, we also support render depth map, alpha map, normal map and extra per-Gaussian attributes (both forward and backward process) compared with the [original repository](https://github.com/graphdeco-inria/diff-gaussian-rasterization). 4 | 5 | :mega: We recently support computing the gradient w.r.t. camera pose. However, this feature hasn't been fully validated, thus we don't merge it into the main branch. You can switch to pose branch for more information. If you find any bugs, leave a message in the issues. Thank you! 6 | 7 | We modify the dependency name as **diff_gauss** to avoid dependecy conflict with the original version. You can install our repo by executing the following command lines 8 | ```shell 9 | git clone --recurse-submodules https://github.com/slothfulxtx/diff-gaussian-rasterization.git 10 | cd diff-gaussian-rasterization 11 | python setup.py install 12 | ``` 13 | 14 | Here's an example of our modified differential gaussian rasterization repo 15 | ```python 16 | from diff_gauss import GaussianRasterizationSettings, GaussianRasterizer 17 | 18 | rendered_image, rendered_depth, rendered_norm, rendered_alpha, radii, extra = rasterizer( 19 | means3D = means3D, 20 | means2D = means2D, 21 | shs = shs, 22 | colors_precomp = colors_precomp, 23 | opacities = opacity, 24 | scales = scales, 25 | rotations = rotations, 26 | cov3Ds_precomp = cov3D_precomp, 27 | extra_attrs = extra_attrs 28 | ) 29 | ``` 30 | 31 | Details: 32 | 33 | - Depth: By default, the depth is calculated as 'median depth', where the depth values of each pixels covered by 3D Gaussian Splatting are set to be the depth of the 3D Gaussian center. Thus, there exist numerical errors when the scales of 3D Gaussian are large. However, thanks to the densificaiton scheme, most 3D Gaussians are small. Currently, we ignore the numerical error of depth maps. 34 | - Normal: By default, the normal is considered as the shortest axis direction of the covariance matrix. Notably both the directions inwards and outwards the surface satisfy the above condition. To obtain a meaningful normal, we reverse the inwards directions using the view direction. 35 | - Per-Gaussian attributes: The maximum value of per-Gaussian attributes is 34, which is a magic number for my NVIDIA 3090 Ti GPU (larger value will raise error). If your compiling process raises error, you can change it to a lower value in `cuda_rasterizer/auxiliary.h`. 36 | 37 | Used as the rasterization engine for the paper "3D Gaussian Splatting for Real-Time Rendering of Radiance Fields". If you can make use of it in your own research, please be so kind to cite us. 38 | 39 |
40 |
41 |

BibTeX

42 |
@Article{kerbl3Dgaussians,
43 |       author       = {Kerbl, Bernhard and Kopanas, Georgios and Leimk{\"u}hler, Thomas and Drettakis, George},
44 |       title        = {3D Gaussian Splatting for Real-Time Radiance Field Rendering},
45 |       journal      = {ACM Transactions on Graphics},
46 |       number       = {4},
47 |       volume       = {42},
48 |       month        = {July},
49 |       year         = {2023},
50 |       url          = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/}
51 | }
52 |
53 |
54 | -------------------------------------------------------------------------------- /cuda_rasterizer/auxiliary.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_AUXILIARY_H_INCLUDED 13 | #define CUDA_RASTERIZER_AUXILIARY_H_INCLUDED 14 | 15 | #include "config.h" 16 | #include "stdio.h" 17 | 18 | #define BLOCK_SIZE (BLOCK_X * BLOCK_Y) 19 | #define NUM_WARPS (BLOCK_SIZE/32) 20 | #define MAX_EXTRA_DIMS 34 21 | 22 | // Spherical harmonics coefficients 23 | __device__ const float SH_C0 = 0.28209479177387814f; 24 | __device__ const float SH_C1 = 0.4886025119029199f; 25 | __device__ const float SH_C2[] = { 26 | 1.0925484305920792f, 27 | -1.0925484305920792f, 28 | 0.31539156525252005f, 29 | -1.0925484305920792f, 30 | 0.5462742152960396f 31 | }; 32 | __device__ const float SH_C3[] = { 33 | -0.5900435899266435f, 34 | 2.890611442640554f, 35 | -0.4570457994644658f, 36 | 0.3731763325901154f, 37 | -0.4570457994644658f, 38 | 1.445305721320277f, 39 | -0.5900435899266435f 40 | }; 41 | 42 | __forceinline__ __device__ float ndc2Pix(float v, int S) 43 | { 44 | return ((v + 1.0) * S - 1.0) * 0.5; 45 | } 46 | 47 | __forceinline__ __device__ void getRect(const float2 p, int max_radius, uint2& rect_min, uint2& rect_max, dim3 grid) 48 | { 49 | rect_min = { 50 | min(grid.x, max((int)0, (int)((p.x - max_radius) / BLOCK_X))), 51 | min(grid.y, max((int)0, (int)((p.y - max_radius) / BLOCK_Y))) 52 | }; 53 | rect_max = { 54 | min(grid.x, max((int)0, (int)((p.x + max_radius + BLOCK_X - 1) / BLOCK_X))), 55 | min(grid.y, max((int)0, (int)((p.y + max_radius + BLOCK_Y - 1) / BLOCK_Y))) 56 | }; 57 | } 58 | 59 | __forceinline__ __device__ float3 transformPoint4x3(const float3& p, const float* matrix) 60 | { 61 | float3 transformed = { 62 | matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12], 63 | matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13], 64 | matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14], 65 | }; 66 | return transformed; 67 | } 68 | 69 | __forceinline__ __device__ float4 transformPoint4x4(const float3& p, const float* matrix) 70 | { 71 | float4 transformed = { 72 | matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12], 73 | matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13], 74 | matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14], 75 | matrix[3] * p.x + matrix[7] * p.y + matrix[11] * p.z + matrix[15] 76 | }; 77 | return transformed; 78 | } 79 | 80 | __forceinline__ __device__ float3 transformVec4x3(const float3& p, const float* matrix) 81 | { 82 | float3 transformed = { 83 | matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z, 84 | matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z, 85 | matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z, 86 | }; 87 | return transformed; 88 | } 89 | 90 | __forceinline__ __device__ float3 transformVec4x3Transpose(const float3& p, const float* matrix) 91 | { 92 | float3 transformed = { 93 | matrix[0] * p.x + matrix[1] * p.y + matrix[2] * p.z, 94 | matrix[4] * p.x + matrix[5] * p.y + matrix[6] * p.z, 95 | matrix[8] * p.x + matrix[9] * p.y + matrix[10] * p.z, 96 | }; 97 | return transformed; 98 | } 99 | 100 | __forceinline__ __device__ float dnormvdz(float3 v, float3 dv) 101 | { 102 | float sum2 = v.x * v.x + v.y * v.y + v.z * v.z; 103 | float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); 104 | float dnormvdz = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32; 105 | return dnormvdz; 106 | } 107 | 108 | __forceinline__ __device__ float3 dnormvdv(float3 v, float3 dv) 109 | { 110 | float sum2 = v.x * v.x + v.y * v.y + v.z * v.z; 111 | float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); 112 | 113 | float3 dnormvdv; 114 | dnormvdv.x = ((+sum2 - v.x * v.x) * dv.x - v.y * v.x * dv.y - v.z * v.x * dv.z) * invsum32; 115 | dnormvdv.y = (-v.x * v.y * dv.x + (sum2 - v.y * v.y) * dv.y - v.z * v.y * dv.z) * invsum32; 116 | dnormvdv.z = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32; 117 | return dnormvdv; 118 | } 119 | 120 | __forceinline__ __device__ float4 dnormvdv(float4 v, float4 dv) 121 | { 122 | float sum2 = v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w; 123 | float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); 124 | 125 | float4 vdv = { v.x * dv.x, v.y * dv.y, v.z * dv.z, v.w * dv.w }; 126 | float vdv_sum = vdv.x + vdv.y + vdv.z + vdv.w; 127 | float4 dnormvdv; 128 | dnormvdv.x = ((sum2 - v.x * v.x) * dv.x - v.x * (vdv_sum - vdv.x)) * invsum32; 129 | dnormvdv.y = ((sum2 - v.y * v.y) * dv.y - v.y * (vdv_sum - vdv.y)) * invsum32; 130 | dnormvdv.z = ((sum2 - v.z * v.z) * dv.z - v.z * (vdv_sum - vdv.z)) * invsum32; 131 | dnormvdv.w = ((sum2 - v.w * v.w) * dv.w - v.w * (vdv_sum - vdv.w)) * invsum32; 132 | return dnormvdv; 133 | } 134 | 135 | __forceinline__ __device__ float sigmoid(float x) 136 | { 137 | return 1.0f / (1.0f + expf(-x)); 138 | } 139 | 140 | __forceinline__ __device__ bool in_frustum(int idx, 141 | const float* orig_points, 142 | const float* viewmatrix, 143 | const float* projmatrix, 144 | bool prefiltered, 145 | float3& p_view) 146 | { 147 | float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] }; 148 | 149 | // Bring points to screen space 150 | float4 p_hom = transformPoint4x4(p_orig, projmatrix); 151 | float p_w = 1.0f / (p_hom.w + 0.0000001f); 152 | float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w }; 153 | p_view = transformPoint4x3(p_orig, viewmatrix); 154 | 155 | if (p_view.z <= 0.2f)// || ((p_proj.x < -1.3 || p_proj.x > 1.3 || p_proj.y < -1.3 || p_proj.y > 1.3))) 156 | { 157 | if (prefiltered) 158 | { 159 | printf("Point is filtered although prefiltered is set. This shouldn't happen!"); 160 | __trap(); 161 | } 162 | return false; 163 | } 164 | return true; 165 | } 166 | 167 | #define CHECK_CUDA(A, debug) \ 168 | A; if(debug) { \ 169 | auto ret = cudaDeviceSynchronize(); \ 170 | if (ret != cudaSuccess) { \ 171 | std::cerr << "\n[CUDA ERROR] in " << __FILE__ << "\nLine " << __LINE__ << ": " << cudaGetErrorString(ret); \ 172 | throw std::runtime_error(cudaGetErrorString(ret)); \ 173 | } \ 174 | } 175 | 176 | #endif -------------------------------------------------------------------------------- /cuda_rasterizer/backward.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include "backward.h" 13 | #include "auxiliary.h" 14 | #include 15 | #include 16 | namespace cg = cooperative_groups; 17 | 18 | // Backward pass for conversion of spherical harmonics to RGB for 19 | // each Gaussian. 20 | __device__ void computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, const bool* clamped, const glm::vec3* dL_dcolor, glm::vec3* dL_dmeans, glm::vec3* dL_dshs) 21 | { 22 | // Compute intermediate values, as it is done during forward 23 | glm::vec3 pos = means[idx]; 24 | glm::vec3 dir_orig = pos - campos; 25 | glm::vec3 dir = dir_orig / glm::length(dir_orig); 26 | 27 | glm::vec3* sh = ((glm::vec3*)shs) + idx * max_coeffs; 28 | 29 | // Use PyTorch rule for clamping: if clamping was applied, 30 | // gradient becomes 0. 31 | glm::vec3 dL_dRGB = dL_dcolor[idx]; 32 | dL_dRGB.x *= clamped[3 * idx + 0] ? 0 : 1; 33 | dL_dRGB.y *= clamped[3 * idx + 1] ? 0 : 1; 34 | dL_dRGB.z *= clamped[3 * idx + 2] ? 0 : 1; 35 | 36 | glm::vec3 dRGBdx(0, 0, 0); 37 | glm::vec3 dRGBdy(0, 0, 0); 38 | glm::vec3 dRGBdz(0, 0, 0); 39 | float x = dir.x; 40 | float y = dir.y; 41 | float z = dir.z; 42 | 43 | // Target location for this Gaussian to write SH gradients to 44 | glm::vec3* dL_dsh = dL_dshs + idx * max_coeffs; 45 | 46 | // No tricks here, just high school-level calculus. 47 | float dRGBdsh0 = SH_C0; 48 | dL_dsh[0] = dRGBdsh0 * dL_dRGB; 49 | if (deg > 0) 50 | { 51 | float dRGBdsh1 = -SH_C1 * y; 52 | float dRGBdsh2 = SH_C1 * z; 53 | float dRGBdsh3 = -SH_C1 * x; 54 | dL_dsh[1] = dRGBdsh1 * dL_dRGB; 55 | dL_dsh[2] = dRGBdsh2 * dL_dRGB; 56 | dL_dsh[3] = dRGBdsh3 * dL_dRGB; 57 | 58 | dRGBdx = -SH_C1 * sh[3]; 59 | dRGBdy = -SH_C1 * sh[1]; 60 | dRGBdz = SH_C1 * sh[2]; 61 | 62 | if (deg > 1) 63 | { 64 | float xx = x * x, yy = y * y, zz = z * z; 65 | float xy = x * y, yz = y * z, xz = x * z; 66 | 67 | float dRGBdsh4 = SH_C2[0] * xy; 68 | float dRGBdsh5 = SH_C2[1] * yz; 69 | float dRGBdsh6 = SH_C2[2] * (2.f * zz - xx - yy); 70 | float dRGBdsh7 = SH_C2[3] * xz; 71 | float dRGBdsh8 = SH_C2[4] * (xx - yy); 72 | dL_dsh[4] = dRGBdsh4 * dL_dRGB; 73 | dL_dsh[5] = dRGBdsh5 * dL_dRGB; 74 | dL_dsh[6] = dRGBdsh6 * dL_dRGB; 75 | dL_dsh[7] = dRGBdsh7 * dL_dRGB; 76 | dL_dsh[8] = dRGBdsh8 * dL_dRGB; 77 | 78 | dRGBdx += SH_C2[0] * y * sh[4] + SH_C2[2] * 2.f * -x * sh[6] + SH_C2[3] * z * sh[7] + SH_C2[4] * 2.f * x * sh[8]; 79 | dRGBdy += SH_C2[0] * x * sh[4] + SH_C2[1] * z * sh[5] + SH_C2[2] * 2.f * -y * sh[6] + SH_C2[4] * 2.f * -y * sh[8]; 80 | dRGBdz += SH_C2[1] * y * sh[5] + SH_C2[2] * 2.f * 2.f * z * sh[6] + SH_C2[3] * x * sh[7]; 81 | 82 | if (deg > 2) 83 | { 84 | float dRGBdsh9 = SH_C3[0] * y * (3.f * xx - yy); 85 | float dRGBdsh10 = SH_C3[1] * xy * z; 86 | float dRGBdsh11 = SH_C3[2] * y * (4.f * zz - xx - yy); 87 | float dRGBdsh12 = SH_C3[3] * z * (2.f * zz - 3.f * xx - 3.f * yy); 88 | float dRGBdsh13 = SH_C3[4] * x * (4.f * zz - xx - yy); 89 | float dRGBdsh14 = SH_C3[5] * z * (xx - yy); 90 | float dRGBdsh15 = SH_C3[6] * x * (xx - 3.f * yy); 91 | dL_dsh[9] = dRGBdsh9 * dL_dRGB; 92 | dL_dsh[10] = dRGBdsh10 * dL_dRGB; 93 | dL_dsh[11] = dRGBdsh11 * dL_dRGB; 94 | dL_dsh[12] = dRGBdsh12 * dL_dRGB; 95 | dL_dsh[13] = dRGBdsh13 * dL_dRGB; 96 | dL_dsh[14] = dRGBdsh14 * dL_dRGB; 97 | dL_dsh[15] = dRGBdsh15 * dL_dRGB; 98 | 99 | dRGBdx += ( 100 | SH_C3[0] * sh[9] * 3.f * 2.f * xy + 101 | SH_C3[1] * sh[10] * yz + 102 | SH_C3[2] * sh[11] * -2.f * xy + 103 | SH_C3[3] * sh[12] * -3.f * 2.f * xz + 104 | SH_C3[4] * sh[13] * (-3.f * xx + 4.f * zz - yy) + 105 | SH_C3[5] * sh[14] * 2.f * xz + 106 | SH_C3[6] * sh[15] * 3.f * (xx - yy)); 107 | 108 | dRGBdy += ( 109 | SH_C3[0] * sh[9] * 3.f * (xx - yy) + 110 | SH_C3[1] * sh[10] * xz + 111 | SH_C3[2] * sh[11] * (-3.f * yy + 4.f * zz - xx) + 112 | SH_C3[3] * sh[12] * -3.f * 2.f * yz + 113 | SH_C3[4] * sh[13] * -2.f * xy + 114 | SH_C3[5] * sh[14] * -2.f * yz + 115 | SH_C3[6] * sh[15] * -3.f * 2.f * xy); 116 | 117 | dRGBdz += ( 118 | SH_C3[1] * sh[10] * xy + 119 | SH_C3[2] * sh[11] * 4.f * 2.f * yz + 120 | SH_C3[3] * sh[12] * 3.f * (2.f * zz - xx - yy) + 121 | SH_C3[4] * sh[13] * 4.f * 2.f * xz + 122 | SH_C3[5] * sh[14] * (xx - yy)); 123 | } 124 | } 125 | } 126 | 127 | // The view direction is an input to the computation. View direction 128 | // is influenced by the Gaussian's mean, so SHs gradients 129 | // must propagate back into 3D position. 130 | glm::vec3 dL_ddir(glm::dot(dRGBdx, dL_dRGB), glm::dot(dRGBdy, dL_dRGB), glm::dot(dRGBdz, dL_dRGB)); 131 | 132 | // Account for normalization of direction 133 | float3 dL_dmean = dnormvdv(float3{ dir_orig.x, dir_orig.y, dir_orig.z }, float3{ dL_ddir.x, dL_ddir.y, dL_ddir.z }); 134 | 135 | // Gradients of loss w.r.t. Gaussian means, but only the portion 136 | // that is caused because the mean affects the view-dependent color. 137 | // Additional mean gradient is accumulated in below methods. 138 | dL_dmeans[idx] += glm::vec3(dL_dmean.x, dL_dmean.y, dL_dmean.z); 139 | } 140 | 141 | // Backward version of INVERSE 2D covariance matrix computation 142 | // (due to length launched as separate kernel before other 143 | // backward steps contained in preprocess) 144 | __global__ void computeCov2DCUDA(int P, 145 | const float3* means, 146 | const int* radii, 147 | const float* cov3Ds, 148 | const float h_x, float h_y, 149 | const float tan_fovx, float tan_fovy, 150 | const float* view_matrix, 151 | const float* dL_dconics, 152 | float3* dL_dmeans, 153 | float* dL_dcov) 154 | { 155 | auto idx = cg::this_grid().thread_rank(); 156 | if (idx >= P || !(radii[idx] > 0)) 157 | return; 158 | 159 | // Reading location of 3D covariance for this Gaussian 160 | const float* cov3D = cov3Ds + 6 * idx; 161 | 162 | // Fetch gradients, recompute 2D covariance and relevant 163 | // intermediate forward results needed in the backward. 164 | float3 mean = means[idx]; 165 | float3 dL_dconic = { dL_dconics[4 * idx], dL_dconics[4 * idx + 1], dL_dconics[4 * idx + 3] }; 166 | float3 t = transformPoint4x3(mean, view_matrix); 167 | 168 | const float limx = 1.3f * tan_fovx; 169 | const float limy = 1.3f * tan_fovy; 170 | const float txtz = t.x / t.z; 171 | const float tytz = t.y / t.z; 172 | t.x = min(limx, max(-limx, txtz)) * t.z; 173 | t.y = min(limy, max(-limy, tytz)) * t.z; 174 | 175 | const float x_grad_mul = txtz < -limx || txtz > limx ? 0 : 1; 176 | const float y_grad_mul = tytz < -limy || tytz > limy ? 0 : 1; 177 | 178 | glm::mat3 J = glm::mat3(h_x / t.z, 0.0f, -(h_x * t.x) / (t.z * t.z), 179 | 0.0f, h_y / t.z, -(h_y * t.y) / (t.z * t.z), 180 | 0, 0, 0); 181 | 182 | glm::mat3 W = glm::mat3( 183 | view_matrix[0], view_matrix[4], view_matrix[8], 184 | view_matrix[1], view_matrix[5], view_matrix[9], 185 | view_matrix[2], view_matrix[6], view_matrix[10]); 186 | 187 | glm::mat3 Vrk = glm::mat3( 188 | cov3D[0], cov3D[1], cov3D[2], 189 | cov3D[1], cov3D[3], cov3D[4], 190 | cov3D[2], cov3D[4], cov3D[5]); 191 | 192 | glm::mat3 T = W * J; 193 | 194 | glm::mat3 cov2D = glm::transpose(T) * glm::transpose(Vrk) * T; 195 | 196 | // Use helper variables for 2D covariance entries. More compact. 197 | float a = cov2D[0][0] += 0.3f; 198 | float b = cov2D[0][1]; 199 | float c = cov2D[1][1] += 0.3f; 200 | 201 | float denom = a * c - b * b; 202 | float dL_da = 0, dL_db = 0, dL_dc = 0; 203 | float denom2inv = 1.0f / ((denom * denom) + 0.0000001f); 204 | 205 | if (denom2inv != 0) 206 | { 207 | // Gradients of loss w.r.t. entries of 2D covariance matrix, 208 | // given gradients of loss w.r.t. conic matrix (inverse covariance matrix). 209 | // e.g., dL / da = dL / d_conic_a * d_conic_a / d_a 210 | dL_da = denom2inv * (-c * c * dL_dconic.x + 2 * b * c * dL_dconic.y + (denom - a * c) * dL_dconic.z); 211 | dL_dc = denom2inv * (-a * a * dL_dconic.z + 2 * a * b * dL_dconic.y + (denom - a * c) * dL_dconic.x); 212 | dL_db = denom2inv * 2 * (b * c * dL_dconic.x - (denom + 2 * b * b) * dL_dconic.y + a * b * dL_dconic.z); 213 | 214 | // Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry, 215 | // given gradients w.r.t. 2D covariance matrix (diagonal). 216 | // cov2D = transpose(T) * transpose(Vrk) * T; 217 | dL_dcov[6 * idx + 0] = (T[0][0] * T[0][0] * dL_da + T[0][0] * T[1][0] * dL_db + T[1][0] * T[1][0] * dL_dc); 218 | dL_dcov[6 * idx + 3] = (T[0][1] * T[0][1] * dL_da + T[0][1] * T[1][1] * dL_db + T[1][1] * T[1][1] * dL_dc); 219 | dL_dcov[6 * idx + 5] = (T[0][2] * T[0][2] * dL_da + T[0][2] * T[1][2] * dL_db + T[1][2] * T[1][2] * dL_dc); 220 | 221 | // Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry, 222 | // given gradients w.r.t. 2D covariance matrix (off-diagonal). 223 | // Off-diagonal elements appear twice --> double the gradient. 224 | // cov2D = transpose(T) * transpose(Vrk) * T; 225 | dL_dcov[6 * idx + 1] = 2 * T[0][0] * T[0][1] * dL_da + (T[0][0] * T[1][1] + T[0][1] * T[1][0]) * dL_db + 2 * T[1][0] * T[1][1] * dL_dc; 226 | dL_dcov[6 * idx + 2] = 2 * T[0][0] * T[0][2] * dL_da + (T[0][0] * T[1][2] + T[0][2] * T[1][0]) * dL_db + 2 * T[1][0] * T[1][2] * dL_dc; 227 | dL_dcov[6 * idx + 4] = 2 * T[0][2] * T[0][1] * dL_da + (T[0][1] * T[1][2] + T[0][2] * T[1][1]) * dL_db + 2 * T[1][1] * T[1][2] * dL_dc; 228 | } 229 | else 230 | { 231 | for (int i = 0; i < 6; i++) 232 | dL_dcov[6 * idx + i] = 0; 233 | } 234 | 235 | // Gradients of loss w.r.t. upper 2x3 portion of intermediate matrix T 236 | // cov2D = transpose(T) * transpose(Vrk) * T; 237 | float dL_dT00 = 2 * (T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) * dL_da + 238 | (T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) * dL_db; 239 | float dL_dT01 = 2 * (T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) * dL_da + 240 | (T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) * dL_db; 241 | float dL_dT02 = 2 * (T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) * dL_da + 242 | (T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) * dL_db; 243 | float dL_dT10 = 2 * (T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) * dL_dc + 244 | (T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) * dL_db; 245 | float dL_dT11 = 2 * (T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) * dL_dc + 246 | (T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) * dL_db; 247 | float dL_dT12 = 2 * (T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) * dL_dc + 248 | (T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) * dL_db; 249 | 250 | // Gradients of loss w.r.t. upper 3x2 non-zero entries of Jacobian matrix 251 | // T = W * J 252 | float dL_dJ00 = W[0][0] * dL_dT00 + W[0][1] * dL_dT01 + W[0][2] * dL_dT02; 253 | float dL_dJ02 = W[2][0] * dL_dT00 + W[2][1] * dL_dT01 + W[2][2] * dL_dT02; 254 | float dL_dJ11 = W[1][0] * dL_dT10 + W[1][1] * dL_dT11 + W[1][2] * dL_dT12; 255 | float dL_dJ12 = W[2][0] * dL_dT10 + W[2][1] * dL_dT11 + W[2][2] * dL_dT12; 256 | 257 | float tz = 1.f / t.z; 258 | float tz2 = tz * tz; 259 | float tz3 = tz2 * tz; 260 | 261 | // Gradients of loss w.r.t. transformed Gaussian mean t 262 | float dL_dtx = x_grad_mul * -h_x * tz2 * dL_dJ02; 263 | float dL_dty = y_grad_mul * -h_y * tz2 * dL_dJ12; 264 | float dL_dtz = -h_x * tz2 * dL_dJ00 - h_y * tz2 * dL_dJ11 + (2 * h_x * t.x) * tz3 * dL_dJ02 + (2 * h_y * t.y) * tz3 * dL_dJ12; 265 | 266 | // Account for transformation of mean to t 267 | // t = transformPoint4x3(mean, view_matrix); 268 | float3 dL_dmean = transformVec4x3Transpose({ dL_dtx, dL_dty, dL_dtz }, view_matrix); 269 | 270 | // Gradients of loss w.r.t. Gaussian means, but only the portion 271 | // that is caused because the mean affects the covariance matrix. 272 | // Additional mean gradient is accumulated in BACKWARD::preprocess. 273 | dL_dmeans[idx] = dL_dmean; 274 | } 275 | 276 | // Backward pass for the conversion of scale and rotation to a 277 | // 3D covariance matrix for each Gaussian. 278 | __device__ void computeCov3D(int idx, const glm::vec3 scale, float mod, const glm::vec4 rot, const float* dL_dcov3Ds, glm::vec3* dL_dscales, glm::vec4* dL_drots) 279 | { 280 | // Recompute (intermediate) results for the 3D covariance computation. 281 | glm::vec4 q = rot;// / glm::length(rot); 282 | float r = q.x; 283 | float x = q.y; 284 | float y = q.z; 285 | float z = q.w; 286 | 287 | glm::mat3 R = glm::mat3( 288 | 1.f - 2.f * (y * y + z * z), 2.f * (x * y - r * z), 2.f * (x * z + r * y), 289 | 2.f * (x * y + r * z), 1.f - 2.f * (x * x + z * z), 2.f * (y * z - r * x), 290 | 2.f * (x * z - r * y), 2.f * (y * z + r * x), 1.f - 2.f * (x * x + y * y) 291 | ); 292 | 293 | glm::mat3 S = glm::mat3(1.0f); 294 | 295 | glm::vec3 s = mod * scale; 296 | S[0][0] = s.x; 297 | S[1][1] = s.y; 298 | S[2][2] = s.z; 299 | 300 | glm::mat3 M = S * R; 301 | 302 | const float* dL_dcov3D = dL_dcov3Ds + 6 * idx; 303 | 304 | glm::vec3 dunc(dL_dcov3D[0], dL_dcov3D[3], dL_dcov3D[5]); 305 | glm::vec3 ounc = 0.5f * glm::vec3(dL_dcov3D[1], dL_dcov3D[2], dL_dcov3D[4]); 306 | 307 | // Convert per-element covariance loss gradients to matrix form 308 | glm::mat3 dL_dSigma = glm::mat3( 309 | dL_dcov3D[0], 0.5f * dL_dcov3D[1], 0.5f * dL_dcov3D[2], 310 | 0.5f * dL_dcov3D[1], dL_dcov3D[3], 0.5f * dL_dcov3D[4], 311 | 0.5f * dL_dcov3D[2], 0.5f * dL_dcov3D[4], dL_dcov3D[5] 312 | ); 313 | 314 | // Compute loss gradient w.r.t. matrix M 315 | // dSigma_dM = 2 * M 316 | glm::mat3 dL_dM = 2.0f * M * dL_dSigma; 317 | 318 | glm::mat3 Rt = glm::transpose(R); 319 | glm::mat3 dL_dMt = glm::transpose(dL_dM); 320 | 321 | // Gradients of loss w.r.t. scale 322 | glm::vec3* dL_dscale = dL_dscales + idx; 323 | dL_dscale->x = glm::dot(Rt[0], dL_dMt[0]); 324 | dL_dscale->y = glm::dot(Rt[1], dL_dMt[1]); 325 | dL_dscale->z = glm::dot(Rt[2], dL_dMt[2]); 326 | 327 | dL_dMt[0] *= s.x; 328 | dL_dMt[1] *= s.y; 329 | dL_dMt[2] *= s.z; 330 | 331 | // Gradients of loss w.r.t. normalized quaternion 332 | glm::vec4 dL_dq; 333 | dL_dq.x = 2 * z * (dL_dMt[0][1] - dL_dMt[1][0]) + 2 * y * (dL_dMt[2][0] - dL_dMt[0][2]) + 2 * x * (dL_dMt[1][2] - dL_dMt[2][1]); 334 | dL_dq.y = 2 * y * (dL_dMt[1][0] + dL_dMt[0][1]) + 2 * z * (dL_dMt[2][0] + dL_dMt[0][2]) + 2 * r * (dL_dMt[1][2] - dL_dMt[2][1]) - 4 * x * (dL_dMt[2][2] + dL_dMt[1][1]); 335 | dL_dq.z = 2 * x * (dL_dMt[1][0] + dL_dMt[0][1]) + 2 * r * (dL_dMt[2][0] - dL_dMt[0][2]) + 2 * z * (dL_dMt[1][2] + dL_dMt[2][1]) - 4 * y * (dL_dMt[2][2] + dL_dMt[0][0]); 336 | dL_dq.w = 2 * r * (dL_dMt[0][1] - dL_dMt[1][0]) + 2 * x * (dL_dMt[2][0] + dL_dMt[0][2]) + 2 * y * (dL_dMt[1][2] + dL_dMt[2][1]) - 4 * z * (dL_dMt[1][1] + dL_dMt[0][0]); 337 | 338 | // Gradients of loss w.r.t. unnormalized quaternion 339 | float4* dL_drot = (float4*)(dL_drots + idx); 340 | *dL_drot = float4{ dL_dq.x, dL_dq.y, dL_dq.z, dL_dq.w };//dnormvdv(float4{ rot.x, rot.y, rot.z, rot.w }, float4{ dL_dq.x, dL_dq.y, dL_dq.z, dL_dq.w }); 341 | } 342 | 343 | 344 | __device__ void computeNorm3D(int idx, const glm::vec3 scale, const glm::vec4 rot, const glm::vec3 norm3D, const glm::vec3 dL_dnorm3D, glm::vec3* dL_dscales, glm::vec4* dL_drots) 345 | { 346 | // Recompute (intermediate) results for the 3D covariance computation. 347 | glm::vec4 q = rot;// / glm::length(rot); 348 | float r = q.x; 349 | float x = q.y; 350 | float y = q.z; 351 | float z = q.w; 352 | 353 | glm::mat3 R = glm::mat3( 354 | 1.f - 2.f * (y * y + z * z), 2.f * (x * y - r * z), 2.f * (x * z + r * y), 355 | 2.f * (x * y + r * z), 1.f - 2.f * (x * x + z * z), 2.f * (y * z - r * x), 356 | 2.f * (x * z - r * y), 2.f * (y * z + r * x), 1.f - 2.f * (x * x + y * y) 357 | ); 358 | 359 | glm::vec3 norm; 360 | if(scale.x > scale.z && scale.y > scale.z) 361 | { 362 | norm = glm::vec3(0.0, 0.0, 1.0); 363 | } 364 | else if(scale.x > scale.y && scale.z > scale.y) 365 | { 366 | norm = glm::vec3(0.0, 1.0, 0.0); 367 | } 368 | else 369 | { 370 | norm = glm::vec3(1.0, 0.0, 0.0); 371 | } 372 | 373 | if(glm::dot(R * norm, norm3D) < 0) 374 | norm = -norm; 375 | 376 | glm::mat3 dL_dR = glm::mat3( 377 | dL_dnorm3D.x * norm.x, dL_dnorm3D.x * norm.y, dL_dnorm3D.x * norm.z, 378 | dL_dnorm3D.y * norm.x, dL_dnorm3D.y * norm.y, dL_dnorm3D.y * norm.z, 379 | dL_dnorm3D.z * norm.x, dL_dnorm3D.z * norm.y, dL_dnorm3D.z * norm.z 380 | ); 381 | glm::mat3 dL_dRt = glm::transpose(dL_dR); 382 | 383 | // Gradients of loss w.r.t. normalized quaternion 384 | glm::vec4 dL_dq; 385 | dL_dq.x = 2 * z * (dL_dRt[0][1] - dL_dRt[1][0]) + 2 * y * (dL_dRt[2][0] - dL_dRt[0][2]) + 2 * x * (dL_dRt[1][2] - dL_dRt[2][1]); 386 | dL_dq.y = 2 * y * (dL_dRt[1][0] + dL_dRt[0][1]) + 2 * z * (dL_dRt[2][0] + dL_dRt[0][2]) + 2 * r * (dL_dRt[1][2] - dL_dRt[2][1]) - 4 * x * (dL_dRt[2][2] + dL_dRt[1][1]); 387 | dL_dq.z = 2 * x * (dL_dRt[1][0] + dL_dRt[0][1]) + 2 * r * (dL_dRt[2][0] - dL_dRt[0][2]) + 2 * z * (dL_dRt[1][2] + dL_dRt[2][1]) - 4 * y * (dL_dRt[2][2] + dL_dRt[0][0]); 388 | dL_dq.w = 2 * r * (dL_dRt[0][1] - dL_dRt[1][0]) + 2 * x * (dL_dRt[2][0] + dL_dRt[0][2]) + 2 * y * (dL_dRt[1][2] + dL_dRt[2][1]) - 4 * z * (dL_dRt[1][1] + dL_dRt[0][0]); 389 | 390 | // Gradients of loss w.r.t. unnormalized quaternion 391 | dL_drots[idx] += dL_dq; 392 | } 393 | 394 | // Backward pass of the preprocessing steps, except 395 | // for the covariance computation and inversion 396 | // (those are handled by a previous kernel call) 397 | template 398 | __global__ void preprocessCUDA( 399 | int P, int D, int M, 400 | const float3* means, 401 | const int* radii, 402 | const float* shs, 403 | const glm::vec3* norm3Ds, 404 | bool is_norm3Ds_precomp, 405 | const bool* clamped, 406 | const glm::vec3* scales, 407 | const glm::vec4* rotations, 408 | const float scale_modifier, 409 | const float* view, 410 | const float* proj, 411 | const glm::vec3* campos, 412 | const float3* dL_dmean2D, 413 | glm::vec3* dL_dmeans, 414 | float* dL_dcolor, 415 | float* dL_ddepth, 416 | float* dL_dcov3D, 417 | glm::vec3* dL_dnorm3D, 418 | float* dL_dsh, 419 | glm::vec3* dL_dscale, 420 | glm::vec4* dL_drot) 421 | { 422 | auto idx = cg::this_grid().thread_rank(); 423 | if (idx >= P || !(radii[idx] > 0)) 424 | return; 425 | 426 | float3 m = means[idx]; 427 | 428 | // Taking care of gradients from the screenspace points 429 | float4 m_hom = transformPoint4x4(m, proj); 430 | float m_w = 1.0f / (m_hom.w + 0.0000001f); 431 | 432 | // Compute loss gradient w.r.t. 3D means due to gradients of 2D means 433 | // from rendering procedure 434 | glm::vec3 dL_dmean; 435 | float mul1 = (proj[0] * m.x + proj[4] * m.y + proj[8] * m.z + proj[12]) * m_w * m_w; 436 | float mul2 = (proj[1] * m.x + proj[5] * m.y + proj[9] * m.z + proj[13]) * m_w * m_w; 437 | dL_dmean.x = (proj[0] * m_w - proj[3] * mul1) * dL_dmean2D[idx].x + (proj[1] * m_w - proj[3] * mul2) * dL_dmean2D[idx].y; 438 | dL_dmean.y = (proj[4] * m_w - proj[7] * mul1) * dL_dmean2D[idx].x + (proj[5] * m_w - proj[7] * mul2) * dL_dmean2D[idx].y; 439 | dL_dmean.z = (proj[8] * m_w - proj[11] * mul1) * dL_dmean2D[idx].x + (proj[9] * m_w - proj[11] * mul2) * dL_dmean2D[idx].y; 440 | 441 | // That's the second part of the mean gradient. Previous computation 442 | // of cov2D and following SH conversion also affects it. 443 | dL_dmeans[idx] += dL_dmean; 444 | 445 | // the w must be equal to 1 for view^T * [x,y,z,1] 446 | float3 m_view = transformPoint4x3(m, view); 447 | 448 | // Compute loss gradient w.r.t. 3D means due to gradients of depth 449 | // from rendering procedure 450 | glm::vec3 dL_dmean2; 451 | float mul3 = view[2] * m.x + view[6] * m.y + view[10] * m.z + view[14]; 452 | dL_dmean2.x = (view[2] - view[3] * mul3) * dL_ddepth[idx]; 453 | dL_dmean2.y = (view[6] - view[7] * mul3) * dL_ddepth[idx]; 454 | dL_dmean2.z = (view[10] - view[11] * mul3) * dL_ddepth[idx]; 455 | 456 | // That's the third part of the mean gradient. 457 | dL_dmeans[idx] += dL_dmean2; 458 | 459 | // Compute gradient updates due to computing colors from SHs 460 | if (shs) 461 | computeColorFromSH(idx, D, M, (glm::vec3*)means, *campos, shs, clamped, (glm::vec3*)dL_dcolor, (glm::vec3*)dL_dmeans, (glm::vec3*)dL_dsh); 462 | 463 | // Compute gradient updates due to computing covariance from scale/rotation 464 | if (scales) 465 | computeCov3D(idx, scales[idx], scale_modifier, rotations[idx], dL_dcov3D, dL_dscale, dL_drot); 466 | 467 | if (!is_norm3Ds_precomp) 468 | computeNorm3D(idx, scales[idx], rotations[idx], norm3Ds[idx], dL_dnorm3D[idx], dL_dscale, dL_drot); 469 | 470 | } 471 | 472 | // Backward version of the rendering procedure. 473 | template 474 | __global__ void __launch_bounds__(BLOCK_X * BLOCK_Y) 475 | renderCUDA( 476 | const uint2* __restrict__ ranges, 477 | const uint32_t* __restrict__ point_list, 478 | int W, int H, int ED, 479 | const float* __restrict__ bg_color, 480 | const float2* __restrict__ points_xy_image, 481 | const float4* __restrict__ conic_opacity, 482 | const float* __restrict__ colors, 483 | const float* __restrict__ depths, 484 | const float* __restrict__ norms, 485 | const float* __restrict__ extras, 486 | const float* __restrict__ accum_alphas, 487 | const uint32_t* __restrict__ n_contrib, 488 | const float* __restrict__ dL_dpixels, 489 | const float* __restrict__ dL_dpixel_depths, 490 | const float* __restrict__ dL_dpixel_norms, 491 | const float* __restrict__ dL_dpixel_alphas, 492 | const float* __restrict__ dL_dpixel_extras, 493 | float3* __restrict__ dL_dmean2D, 494 | float4* __restrict__ dL_dconic2D, 495 | float* __restrict__ dL_dopacity, 496 | float* __restrict__ dL_dcolors, 497 | float* __restrict__ dL_ddepths, 498 | float* __restrict__ dL_dnorm3Ds, 499 | float* __restrict__ dL_dextras) 500 | { 501 | // We rasterize again. Compute necessary block info. 502 | auto block = cg::this_thread_block(); 503 | const uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X; 504 | const uint2 pix_min = { block.group_index().x * BLOCK_X, block.group_index().y * BLOCK_Y }; 505 | const uint2 pix_max = { min(pix_min.x + BLOCK_X, W), min(pix_min.y + BLOCK_Y , H) }; 506 | const uint2 pix = { pix_min.x + block.thread_index().x, pix_min.y + block.thread_index().y }; 507 | const uint32_t pix_id = W * pix.y + pix.x; 508 | const float2 pixf = { (float)pix.x, (float)pix.y }; 509 | 510 | const bool inside = pix.x < W&& pix.y < H; 511 | const uint2 range = ranges[block.group_index().y * horizontal_blocks + block.group_index().x]; 512 | 513 | const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE); 514 | 515 | bool done = !inside; 516 | int toDo = range.y - range.x; 517 | 518 | __shared__ int collected_id[BLOCK_SIZE]; 519 | __shared__ float2 collected_xy[BLOCK_SIZE]; 520 | __shared__ float4 collected_conic_opacity[BLOCK_SIZE]; 521 | __shared__ float collected_colors[C * BLOCK_SIZE]; 522 | __shared__ float collected_depths[BLOCK_SIZE]; 523 | __shared__ float collected_norms[3 * BLOCK_SIZE]; 524 | __shared__ float collected_extras[MAX_EXTRA_DIMS * BLOCK_SIZE]; 525 | 526 | // In the forward, we stored the final value for T, the 527 | // product of all (1 - alpha) factors. 528 | const float T_final = inside ? (1 - accum_alphas[pix_id]) : 0; 529 | float T = T_final; 530 | 531 | // We start from the back. The ID of the last contributing 532 | // Gaussian is known from each pixel from the forward. 533 | uint32_t contributor = toDo; 534 | const int last_contributor = inside ? n_contrib[pix_id] : 0; 535 | 536 | float accum_rec[C] = { 0 }; 537 | float accum_red = 0; 538 | float accum_ren[3] = { 0 }; 539 | float accum_rea = 0; 540 | float accum_ree[MAX_EXTRA_DIMS] = { 0 }; 541 | float dL_dpixel[C]; 542 | float dL_dpixel_depth; 543 | float dL_dpixel_norm[3]; 544 | float dL_dpixel_alpha; 545 | float dL_dpixel_extra[MAX_EXTRA_DIMS]; 546 | if (inside) 547 | { 548 | for (int i = 0; i < C; i++) 549 | dL_dpixel[i] = dL_dpixels[i * H * W + pix_id]; 550 | dL_dpixel_depth = dL_dpixel_depths[pix_id]; 551 | for (int i = 0; i < 3; i++) 552 | dL_dpixel_norm[i] = dL_dpixel_norms[i * H * W + pix_id]; 553 | dL_dpixel_alpha = dL_dpixel_alphas[pix_id]; 554 | for (int i = 0; i < ED; i++) 555 | dL_dpixel_extra[i] = dL_dpixel_extras[i * H * W + pix_id]; 556 | } 557 | float last_alpha = 0; 558 | float last_color[C] = { 0 }; 559 | float last_depth = 0; 560 | float last_norm[3] = { 0 }; 561 | float last_extra[MAX_EXTRA_DIMS] = { 0 }; 562 | // Gradient of pixel coordinate w.r.t. normalized 563 | // screen-space viewport corrdinates (-1 to 1) 564 | const float ddelx_dx = 0.5 * W; 565 | const float ddely_dy = 0.5 * H; 566 | 567 | // Traverse all Gaussians 568 | for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE) 569 | { 570 | // Load auxiliary data into shared memory, start in the BACK 571 | // and load them in revers order. 572 | block.sync(); 573 | const int progress = i * BLOCK_SIZE + block.thread_rank(); 574 | if (range.x + progress < range.y) 575 | { 576 | const int coll_id = point_list[range.y - progress - 1]; 577 | collected_id[block.thread_rank()] = coll_id; 578 | collected_xy[block.thread_rank()] = points_xy_image[coll_id]; 579 | collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id]; 580 | for (int i = 0; i < C; i++) 581 | collected_colors[i * BLOCK_SIZE + block.thread_rank()] = colors[coll_id * C + i]; 582 | for (int i = 0; i < 3; i++) 583 | collected_norms[i * BLOCK_SIZE + block.thread_rank()] = norms[coll_id * 3 + i]; 584 | collected_depths[block.thread_rank()] = depths[coll_id]; 585 | for (int i = 0; i < ED; i++) 586 | collected_extras[i * BLOCK_SIZE + block.thread_rank()] = extras[coll_id * ED + i]; 587 | } 588 | block.sync(); 589 | 590 | // Iterate over Gaussians 591 | for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++) 592 | { 593 | // Keep track of current Gaussian ID. Skip, if this one 594 | // is behind the last contributor for this pixel. 595 | contributor--; 596 | if (contributor >= last_contributor) 597 | continue; 598 | 599 | // Compute blending values, as before. 600 | const float2 xy = collected_xy[j]; 601 | const float2 d = { xy.x - pixf.x, xy.y - pixf.y }; 602 | const float4 con_o = collected_conic_opacity[j]; 603 | const float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) - con_o.y * d.x * d.y; 604 | if (power > 0.0f) 605 | continue; 606 | 607 | const float G = exp(power); 608 | const float alpha = min(0.99f, con_o.w * G); 609 | if (alpha < 1.0f / 255.0f) 610 | continue; 611 | 612 | T = T / (1.f - alpha); 613 | const float weight = alpha * T; 614 | // const float dchannel_dcolor = alpha * T; 615 | // const float dpixel_depth_ddepth = alpha * T; 616 | // const float dpixel_norm_dnorm = alpha * T; 617 | // const float dpixel_extra_dextra = alpha * T; 618 | 619 | // Propagate gradients to per-Gaussian colors and keep 620 | // gradients w.r.t. alpha (blending factor for a Gaussian/pixel 621 | // pair). 622 | float dL_dalpha = 0.0f; 623 | const int global_id = collected_id[j]; 624 | for (int ch = 0; ch < C; ch++) 625 | { 626 | const float c = collected_colors[ch * BLOCK_SIZE + j]; 627 | // Update last color (to be used in the next iteration) 628 | accum_rec[ch] = last_alpha * last_color[ch] + (1.f - last_alpha) * accum_rec[ch]; 629 | last_color[ch] = c; 630 | 631 | const float dL_dchannel = dL_dpixel[ch]; 632 | dL_dalpha += (c - accum_rec[ch]) * dL_dchannel; 633 | // Update the gradients w.r.t. color of the Gaussian. 634 | // Atomic, since this pixel is just one of potentially 635 | // many that were affected by this Gaussian. 636 | atomicAdd(&(dL_dcolors[global_id * C + ch]), weight * dL_dchannel); 637 | } 638 | const float dep = collected_depths[j]; 639 | accum_red = last_alpha * last_depth + (1.f - last_alpha) * accum_red; 640 | last_depth = dep; 641 | dL_dalpha += (dep-accum_red) * dL_dpixel_depth; 642 | atomicAdd(&(dL_ddepths[global_id]), weight * dL_dpixel_depth); 643 | 644 | for (int ch = 0; ch < 3; ch++) 645 | { 646 | const float n = collected_norms[ch * BLOCK_SIZE + j]; 647 | // Update last norm (to be used in the next iteration) 648 | accum_ren[ch] = last_alpha * last_norm[ch] + (1.f - last_alpha) * accum_ren[ch]; 649 | last_norm[ch] = n; 650 | 651 | const float dL_dnormch = dL_dpixel_norm[ch]; 652 | dL_dalpha += (n - accum_ren[ch]) * dL_dnormch; 653 | // Update the gradients w.r.t. norm of the Gaussian. 654 | // Atomic, since this pixel is just one of potentially 655 | // many that were affected by this Gaussian. 656 | atomicAdd(&(dL_dnorm3Ds[global_id * 3 + ch]), weight * dL_dnormch); 657 | } 658 | 659 | for (int ch = 0; ch < ED; ch++) 660 | { 661 | const float e = collected_extras[ch * BLOCK_SIZE + j]; 662 | // Update last norm (to be used in the next iteration) 663 | accum_ree[ch] = last_alpha * last_extra[ch] + (1.f - last_alpha) * accum_ree[ch]; 664 | last_extra[ch] = e; 665 | 666 | const float dL_dextrach = dL_dpixel_extra[ch]; 667 | dL_dalpha += (e - accum_ree[ch]) * dL_dextrach; 668 | // Update the gradients w.r.t. norm of the Gaussian. 669 | // Atomic, since this pixel is just one of potentially 670 | // many that were affected by this Gaussian. 671 | atomicAdd(&(dL_dextras[global_id * ED + ch]), weight * dL_dextrach); 672 | } 673 | 674 | accum_rea = last_alpha + (1.f - last_alpha) * accum_rea; 675 | dL_dalpha += (1 - accum_rea) * dL_dpixel_alpha; 676 | 677 | 678 | dL_dalpha *= T; 679 | // Update last alpha (to be used in the next iteration) 680 | last_alpha = alpha; 681 | 682 | // Account for fact that alpha also influences how much of 683 | // the background color is added if nothing left to blend 684 | float bg_dot_dpixel = 0; 685 | for (int i = 0; i < C; i++) 686 | bg_dot_dpixel += bg_color[i] * dL_dpixel[i]; 687 | dL_dalpha += (-T_final / (1.f - alpha)) * bg_dot_dpixel; 688 | 689 | // Set background depth value == 0, thus no contribution for 690 | // dL_dalpha 691 | 692 | // Helpful reusable temporary variables 693 | const float dL_dG = con_o.w * dL_dalpha; 694 | const float gdx = G * d.x; 695 | const float gdy = G * d.y; 696 | const float dG_ddelx = -gdx * con_o.x - gdy * con_o.y; 697 | const float dG_ddely = -gdy * con_o.z - gdx * con_o.y; 698 | 699 | // Update gradients w.r.t. 2D mean position of the Gaussian 700 | atomicAdd(&dL_dmean2D[global_id].x, dL_dG * dG_ddelx * ddelx_dx); 701 | atomicAdd(&dL_dmean2D[global_id].y, dL_dG * dG_ddely * ddely_dy); 702 | 703 | // Update gradients w.r.t. 2D covariance (2x2 matrix, symmetric) 704 | atomicAdd(&dL_dconic2D[global_id].x, -0.5f * gdx * d.x * dL_dG); 705 | atomicAdd(&dL_dconic2D[global_id].y, -0.5f * gdx * d.y * dL_dG); 706 | atomicAdd(&dL_dconic2D[global_id].w, -0.5f * gdy * d.y * dL_dG); 707 | 708 | // Update gradients w.r.t. opacity of the Gaussian 709 | atomicAdd(&(dL_dopacity[global_id]), G * dL_dalpha); 710 | } 711 | } 712 | } 713 | 714 | void BACKWARD::preprocess( 715 | int P, int D, int M, 716 | const float3* means3D, 717 | const int* radii, 718 | const float* shs, 719 | const bool* clamped, 720 | const glm::vec3* scales, 721 | const glm::vec4* rotations, 722 | const float scale_modifier, 723 | const float* cov3Ds, 724 | const glm::vec3* norm3Ds, 725 | bool is_norm3Ds_precomp, 726 | const float* viewmatrix, 727 | const float* projmatrix, 728 | const float focal_x, float focal_y, 729 | const float tan_fovx, float tan_fovy, 730 | const glm::vec3* campos, 731 | const float3* dL_dmean2D, 732 | const float* dL_dconic, 733 | glm::vec3* dL_dmean3D, 734 | float* dL_dcolor, 735 | float* dL_ddepth, 736 | float* dL_dcov3D, 737 | glm::vec3* dL_dnorm3D, 738 | float* dL_dsh, 739 | glm::vec3* dL_dscale, 740 | glm::vec4* dL_drot) 741 | { 742 | // Propagate gradients for the path of 2D conic matrix computation. 743 | // Somewhat long, thus it is its own kernel rather than being part of 744 | // "preprocess". When done, loss gradient w.r.t. 3D means has been 745 | // modified and gradient w.r.t. 3D covariance matrix has been computed. 746 | computeCov2DCUDA << <(P + 255) / 256, 256 >> > ( 747 | P, 748 | means3D, 749 | radii, 750 | cov3Ds, 751 | focal_x, 752 | focal_y, 753 | tan_fovx, 754 | tan_fovy, 755 | viewmatrix, 756 | dL_dconic, 757 | (float3*)dL_dmean3D, 758 | dL_dcov3D); 759 | 760 | // Propagate gradients for remaining steps: finish 3D mean gradients, 761 | // propagate color gradients to SH (if desireD), propagate 3D covariance 762 | // matrix gradients to scale and rotation. 763 | preprocessCUDA << < (P + 255) / 256, 256 >> > ( 764 | P, D, M, 765 | (float3*)means3D, 766 | radii, 767 | shs, 768 | (glm::vec3*)norm3Ds, 769 | is_norm3Ds_precomp, 770 | clamped, 771 | (glm::vec3*)scales, 772 | (glm::vec4*)rotations, 773 | scale_modifier, 774 | viewmatrix, 775 | projmatrix, 776 | campos, 777 | (float3*)dL_dmean2D, 778 | (glm::vec3*)dL_dmean3D, 779 | dL_dcolor, 780 | dL_ddepth, 781 | dL_dcov3D, 782 | (glm::vec3*)dL_dnorm3D, 783 | dL_dsh, 784 | dL_dscale, 785 | dL_drot); 786 | } 787 | 788 | void BACKWARD::render( 789 | const dim3 grid, const dim3 block, 790 | const uint2* ranges, 791 | const uint32_t* point_list, 792 | int W, int H, int ED, 793 | const float* bg_color, 794 | const float2* means2D, 795 | const float4* conic_opacity, 796 | const float* colors, 797 | const float* depths, 798 | const float* norms, 799 | const float* extras, 800 | const float* accum_alphas, 801 | const uint32_t* n_contrib, 802 | const float* dL_dpixels, 803 | const float* dL_dpixel_depths, 804 | const float* dL_dpixel_norms, 805 | const float* dL_dpixel_alphas, 806 | const float* dL_dpixel_extras, 807 | float3* dL_dmean2D, 808 | float4* dL_dconic2D, 809 | float* dL_dopacity, 810 | float* dL_dcolors, 811 | float* dL_ddepths, 812 | float* dL_dnorm3Ds, 813 | float* dL_dextras) 814 | { 815 | renderCUDA << > >( 816 | ranges, 817 | point_list, 818 | W, H, ED, 819 | bg_color, 820 | means2D, 821 | conic_opacity, 822 | colors, 823 | depths, 824 | norms, 825 | extras, 826 | accum_alphas, 827 | n_contrib, 828 | dL_dpixels, 829 | dL_dpixel_depths, 830 | dL_dpixel_norms, 831 | dL_dpixel_alphas, 832 | dL_dpixel_extras, 833 | dL_dmean2D, 834 | dL_dconic2D, 835 | dL_dopacity, 836 | dL_dcolors, 837 | dL_ddepths, 838 | dL_dnorm3Ds, 839 | dL_dextras); 840 | } -------------------------------------------------------------------------------- /cuda_rasterizer/backward.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_BACKWARD_H_INCLUDED 13 | #define CUDA_RASTERIZER_BACKWARD_H_INCLUDED 14 | 15 | #include 16 | #include "cuda_runtime.h" 17 | #include "device_launch_parameters.h" 18 | #define GLM_FORCE_CUDA 19 | #include 20 | 21 | namespace BACKWARD 22 | { 23 | void render( 24 | const dim3 grid, dim3 block, 25 | const uint2* ranges, 26 | const uint32_t* point_list, 27 | int W, int H, int ED, 28 | const float* bg_color, 29 | const float2* means2D, 30 | const float4* conic_opacity, 31 | const float* colors, 32 | const float* depths, 33 | const float* norms, 34 | const float* extras, 35 | const float* accum_alphas, 36 | const uint32_t* n_contrib, 37 | const float* dL_dpixels, 38 | const float* dL_dpixel_depths, 39 | const float* dL_dpixel_norms, 40 | const float* dL_dpixel_alphas, 41 | const float* dL_dpixel_extras, 42 | float3* dL_dmean2D, 43 | float4* dL_dconic2D, 44 | float* dL_dopacity, 45 | float* dL_dcolors, 46 | float* dL_ddepths, 47 | float* dL_dnorm3Ds, 48 | float* dL_dextras); 49 | 50 | void preprocess( 51 | int P, int D, int M, 52 | const float3* means, 53 | const int* radii, 54 | const float* shs, 55 | const bool* clamped, 56 | const glm::vec3* scales, 57 | const glm::vec4* rotations, 58 | const float scale_modifier, 59 | const float* cov3Ds, 60 | const glm::vec3* norm3Ds, 61 | bool is_norm3Ds_precomp, 62 | const float* view, 63 | const float* proj, 64 | const float focal_x, float focal_y, 65 | const float tan_fovx, float tan_fovy, 66 | const glm::vec3* campos, 67 | const float3* dL_dmean2D, 68 | const float* dL_dconics, 69 | glm::vec3* dL_dmeans, 70 | float* dL_dcolor, 71 | float* dL_ddepth, 72 | float* dL_dcov3D, 73 | glm::vec3* dL_dnorm3D, 74 | float* dL_dsh, 75 | glm::vec3* dL_dscale, 76 | glm::vec4* dL_drot); 77 | } 78 | 79 | #endif -------------------------------------------------------------------------------- /cuda_rasterizer/config.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_CONFIG_H_INCLUDED 13 | #define CUDA_RASTERIZER_CONFIG_H_INCLUDED 14 | 15 | #define NUM_CHANNELS 3 // Default 3, RGB 16 | #define BLOCK_X 16 17 | #define BLOCK_Y 16 18 | 19 | #endif -------------------------------------------------------------------------------- /cuda_rasterizer/forward.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include "forward.h" 13 | #include "auxiliary.h" 14 | #include 15 | #include 16 | namespace cg = cooperative_groups; 17 | 18 | // Forward method for converting the input spherical harmonics 19 | // coefficients of each Gaussian to a simple RGB color. 20 | __device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, bool* clamped) 21 | { 22 | // The implementation is loosely based on code for 23 | // "Differentiable Point-Based Radiance Fields for 24 | // Efficient View Synthesis" by Zhang et al. (2022) 25 | glm::vec3 pos = means[idx]; 26 | glm::vec3 dir = pos - campos; 27 | dir = dir / glm::length(dir); 28 | 29 | glm::vec3* sh = ((glm::vec3*)shs) + idx * max_coeffs; 30 | glm::vec3 result = SH_C0 * sh[0]; 31 | 32 | if (deg > 0) 33 | { 34 | float x = dir.x; 35 | float y = dir.y; 36 | float z = dir.z; 37 | result = result - SH_C1 * y * sh[1] + SH_C1 * z * sh[2] - SH_C1 * x * sh[3]; 38 | 39 | if (deg > 1) 40 | { 41 | float xx = x * x, yy = y * y, zz = z * z; 42 | float xy = x * y, yz = y * z, xz = x * z; 43 | result = result + 44 | SH_C2[0] * xy * sh[4] + 45 | SH_C2[1] * yz * sh[5] + 46 | SH_C2[2] * (2.0f * zz - xx - yy) * sh[6] + 47 | SH_C2[3] * xz * sh[7] + 48 | SH_C2[4] * (xx - yy) * sh[8]; 49 | 50 | if (deg > 2) 51 | { 52 | result = result + 53 | SH_C3[0] * y * (3.0f * xx - yy) * sh[9] + 54 | SH_C3[1] * xy * z * sh[10] + 55 | SH_C3[2] * y * (4.0f * zz - xx - yy) * sh[11] + 56 | SH_C3[3] * z * (2.0f * zz - 3.0f * xx - 3.0f * yy) * sh[12] + 57 | SH_C3[4] * x * (4.0f * zz - xx - yy) * sh[13] + 58 | SH_C3[5] * z * (xx - yy) * sh[14] + 59 | SH_C3[6] * x * (xx - 3.0f * yy) * sh[15]; 60 | } 61 | } 62 | } 63 | result += 0.5f; 64 | 65 | // RGB colors are clamped to positive values. If values are 66 | // clamped, we need to keep track of this for the backward pass. 67 | clamped[3 * idx + 0] = (result.x < 0); 68 | clamped[3 * idx + 1] = (result.y < 0); 69 | clamped[3 * idx + 2] = (result.z < 0); 70 | return glm::max(result, 0.0f); 71 | } 72 | 73 | // Forward version of 2D covariance matrix computation 74 | __device__ float3 computeCov2D(const float3& mean, float focal_x, float focal_y, float tan_fovx, float tan_fovy, const float* cov3D, const float* viewmatrix) 75 | { 76 | // The following models the steps outlined by equations 29 77 | // and 31 in "EWA Splatting" (Zwicker et al., 2002). 78 | // Additionally considers aspect / scaling of viewport. 79 | // Transposes used to account for row-/column-major conventions. 80 | float3 t = transformPoint4x3(mean, viewmatrix); 81 | 82 | const float limx = 1.3f * tan_fovx; 83 | const float limy = 1.3f * tan_fovy; 84 | const float txtz = t.x / t.z; 85 | const float tytz = t.y / t.z; 86 | t.x = min(limx, max(-limx, txtz)) * t.z; 87 | t.y = min(limy, max(-limy, tytz)) * t.z; 88 | 89 | glm::mat3 J = glm::mat3( 90 | focal_x / t.z, 0.0f, -(focal_x * t.x) / (t.z * t.z), 91 | 0.0f, focal_y / t.z, -(focal_y * t.y) / (t.z * t.z), 92 | 0, 0, 0); 93 | 94 | glm::mat3 W = glm::mat3( 95 | viewmatrix[0], viewmatrix[4], viewmatrix[8], 96 | viewmatrix[1], viewmatrix[5], viewmatrix[9], 97 | viewmatrix[2], viewmatrix[6], viewmatrix[10]); 98 | 99 | glm::mat3 T = W * J; 100 | 101 | glm::mat3 Vrk = glm::mat3( 102 | cov3D[0], cov3D[1], cov3D[2], 103 | cov3D[1], cov3D[3], cov3D[4], 104 | cov3D[2], cov3D[4], cov3D[5]); 105 | 106 | glm::mat3 cov = glm::transpose(T) * glm::transpose(Vrk) * T; 107 | 108 | // Apply low-pass filter: every Gaussian should be at least 109 | // one pixel wide/high. Discard 3rd row and column. 110 | cov[0][0] += 0.3f; 111 | cov[1][1] += 0.3f; 112 | return { float(cov[0][0]), float(cov[0][1]), float(cov[1][1]) }; 113 | } 114 | 115 | // Forward method for converting scale and rotation properties of each 116 | // Gaussian to a 3D covariance matrix in world space. Also takes care 117 | // of quaternion normalization. 118 | __device__ void computeCov3D(const glm::vec3 scale, float mod, const glm::vec4 rot, float* cov3D) 119 | { 120 | // Create scaling matrix 121 | glm::mat3 S = glm::mat3(1.0f); 122 | S[0][0] = mod * scale.x; 123 | S[1][1] = mod * scale.y; 124 | S[2][2] = mod * scale.z; 125 | 126 | // Normalize quaternion to get valid rotation 127 | glm::vec4 q = rot;// / glm::length(rot); 128 | float r = q.x; 129 | float x = q.y; 130 | float y = q.z; 131 | float z = q.w; 132 | 133 | // Compute rotation matrix from quaternion 134 | glm::mat3 R = glm::mat3( 135 | 1.f - 2.f * (y * y + z * z), 2.f * (x * y - r * z), 2.f * (x * z + r * y), 136 | 2.f * (x * y + r * z), 1.f - 2.f * (x * x + z * z), 2.f * (y * z - r * x), 137 | 2.f * (x * z - r * y), 2.f * (y * z + r * x), 1.f - 2.f * (x * x + y * y) 138 | ); 139 | 140 | glm::mat3 M = S * R; 141 | 142 | // Compute 3D world covariance matrix Sigma 143 | glm::mat3 Sigma = glm::transpose(M) * M; 144 | 145 | // Covariance is symmetric, only store upper right 146 | cov3D[0] = Sigma[0][0]; 147 | cov3D[1] = Sigma[0][1]; 148 | cov3D[2] = Sigma[0][2]; 149 | cov3D[3] = Sigma[1][1]; 150 | cov3D[4] = Sigma[1][2]; 151 | cov3D[5] = Sigma[2][2]; 152 | } 153 | 154 | // Forward method for converting scale and rotation properties of each 155 | // Gaussian to a 3D norm vector in world space. 156 | __device__ void computeNorm3D(const glm::vec3 scale, float mod, const glm::vec4 rot, float* norm3D, int idx, const glm::vec3* means, glm::vec3 campos) 157 | { 158 | // Create scaling matrix 159 | glm::mat3 S = glm::mat3(1.0f); 160 | S[0][0] = mod * scale.x; 161 | S[1][1] = mod * scale.y; 162 | S[2][2] = mod * scale.z; 163 | 164 | // Normalize quaternion to get valid rotation 165 | glm::vec4 q = rot;// / glm::length(rot); 166 | float r = q.x; 167 | float x = q.y; 168 | float y = q.z; 169 | float z = q.w; 170 | 171 | // Compute rotation matrix from quaternion 172 | glm::mat3 R = glm::mat3( 173 | 1.f - 2.f * (y * y + z * z), 2.f * (x * y - r * z), 2.f * (x * z + r * y), 174 | 2.f * (x * y + r * z), 1.f - 2.f * (x * x + z * z), 2.f * (y * z - r * x), 175 | 2.f * (x * z - r * y), 2.f * (y * z + r * x), 1.f - 2.f * (x * x + y * y) 176 | ); 177 | 178 | glm::vec3 norm; 179 | if(scale.x > scale.z && scale.y > scale.z) 180 | { 181 | norm = glm::vec3(0.0, 0.0, 1.0); 182 | } 183 | else if(scale.x > scale.y && scale.z > scale.y) 184 | { 185 | norm = glm::vec3(0.0, 1.0, 0.0); 186 | } 187 | else 188 | { 189 | norm = glm::vec3(1.0, 0.0, 0.0); 190 | } 191 | norm = glm::transpose(R) * norm; 192 | 193 | glm::vec3 raydir = means[idx] - campos; 194 | if(glm::dot(raydir, norm) > 0) 195 | norm = -norm; 196 | 197 | norm3D[0] = norm.x; 198 | norm3D[1] = norm.y; 199 | norm3D[2] = norm.z; 200 | } 201 | 202 | // Perform initial steps for each Gaussian prior to rasterization. 203 | template 204 | __global__ void preprocessCUDA(int P, int D, int M, 205 | const float* orig_points, 206 | const glm::vec3* scales, 207 | const float scale_modifier, 208 | const glm::vec4* rotations, 209 | const float* opacities, 210 | const float* shs, 211 | bool* clamped, 212 | const float* cov3Ds_precomp, 213 | const float* norm3Ds_precomp, 214 | const float* colors_precomp, 215 | const float* viewmatrix, 216 | const float* projmatrix, 217 | const glm::vec3* cam_pos, 218 | const int W, int H, 219 | const float tan_fovx, float tan_fovy, 220 | const float focal_x, float focal_y, 221 | int* radii, 222 | float2* points_xy_image, 223 | float* depths, 224 | float* cov3Ds, 225 | float* norm3Ds, 226 | float* rgb, 227 | float4* conic_opacity, 228 | const dim3 grid, 229 | uint32_t* tiles_touched, 230 | bool prefiltered) 231 | { 232 | auto idx = cg::this_grid().thread_rank(); 233 | if (idx >= P) 234 | return; 235 | 236 | // Initialize radius and touched tiles to 0. If this isn't changed, 237 | // this Gaussian will not be processed further. 238 | radii[idx] = 0; 239 | tiles_touched[idx] = 0; 240 | 241 | // Perform near culling, quit if outside. 242 | float3 p_view; 243 | if (!in_frustum(idx, orig_points, viewmatrix, projmatrix, prefiltered, p_view)) 244 | return; 245 | 246 | // Transform point by projecting 247 | float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] }; 248 | float4 p_hom = transformPoint4x4(p_orig, projmatrix); 249 | float p_w = 1.0f / (p_hom.w + 0.0000001f); 250 | float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w }; 251 | 252 | // If 3D covariance matrix is precomputed, use it, otherwise compute 253 | // from scaling and rotation parameters. 254 | const float* cov3D; 255 | if (cov3Ds_precomp != nullptr) 256 | { 257 | cov3D = cov3Ds_precomp + idx * 6; 258 | } 259 | else 260 | { 261 | computeCov3D(scales[idx], scale_modifier, rotations[idx], cov3Ds + idx * 6); 262 | cov3D = cov3Ds + idx * 6; 263 | } 264 | 265 | const float* norm3D; 266 | if (norm3Ds_precomp != nullptr) 267 | { 268 | norm3D = norm3Ds_precomp + idx * 3; 269 | } 270 | else 271 | { 272 | computeNorm3D(scales[idx], scale_modifier, rotations[idx], norm3Ds + idx * 3, idx, (glm::vec3*)orig_points, *cam_pos); 273 | norm3D = norm3Ds + idx * 3; 274 | } 275 | // Compute 2D screen-space covariance matrix 276 | float3 cov = computeCov2D(p_orig, focal_x, focal_y, tan_fovx, tan_fovy, cov3D, viewmatrix); 277 | 278 | // Invert covariance (EWA algorithm) 279 | float det = (cov.x * cov.z - cov.y * cov.y); 280 | if (det == 0.0f) 281 | return; 282 | float det_inv = 1.f / det; 283 | float3 conic = { cov.z * det_inv, -cov.y * det_inv, cov.x * det_inv }; 284 | 285 | // Compute extent in screen space (by finding eigenvalues of 286 | // 2D covariance matrix). Use extent to compute a bounding rectangle 287 | // of screen-space tiles that this Gaussian overlaps with. Quit if 288 | // rectangle covers 0 tiles. 289 | float mid = 0.5f * (cov.x + cov.z); 290 | float lambda1 = mid + sqrt(max(0.1f, mid * mid - det)); 291 | float lambda2 = mid - sqrt(max(0.1f, mid * mid - det)); 292 | float my_radius = ceil(3.f * sqrt(max(lambda1, lambda2))); 293 | float2 point_image = { ndc2Pix(p_proj.x, W), ndc2Pix(p_proj.y, H) }; 294 | uint2 rect_min, rect_max; 295 | getRect(point_image, my_radius, rect_min, rect_max, grid); 296 | if ((rect_max.x - rect_min.x) * (rect_max.y - rect_min.y) == 0) 297 | return; 298 | 299 | // If colors have been precomputed, use them, otherwise convert 300 | // spherical harmonics coefficients to RGB color. 301 | if (colors_precomp == nullptr) 302 | { 303 | glm::vec3 result = computeColorFromSH(idx, D, M, (glm::vec3*)orig_points, *cam_pos, shs, clamped); 304 | rgb[idx * C + 0] = result.x; 305 | rgb[idx * C + 1] = result.y; 306 | rgb[idx * C + 2] = result.z; 307 | } 308 | 309 | // Store some useful helper data for the next steps. 310 | depths[idx] = p_view.z; 311 | radii[idx] = my_radius; 312 | points_xy_image[idx] = point_image; 313 | // Inverse 2D covariance and opacity neatly pack into one float4 314 | conic_opacity[idx] = { conic.x, conic.y, conic.z, opacities[idx] }; 315 | tiles_touched[idx] = (rect_max.y - rect_min.y) * (rect_max.x - rect_min.x); 316 | } 317 | 318 | // Main rasterization method. Collaboratively works on one tile per 319 | // block, each thread treats one pixel. Alternates between fetching 320 | // and rasterizing data. 321 | template 322 | __global__ void __launch_bounds__(BLOCK_X * BLOCK_Y) 323 | renderCUDA( 324 | const uint2* __restrict__ ranges, 325 | const uint32_t* __restrict__ point_list, 326 | int W, int H, 327 | const int ED, 328 | const float2* __restrict__ points_xy_image, 329 | const float* __restrict__ features, 330 | const float* __restrict__ norms, 331 | const float* __restrict__ depths, 332 | const float* __restrict__ extras, 333 | const float4* __restrict__ conic_opacity, 334 | float* __restrict__ out_alpha, 335 | uint32_t* __restrict__ n_contrib, 336 | const float* __restrict__ bg_color, 337 | float* __restrict__ out_color, 338 | float* __restrict__ out_depth, 339 | float* __restrict__ out_norm, 340 | float* __restrict__ out_extra) 341 | { 342 | // Identify current tile and associated min/max pixel range. 343 | auto block = cg::this_thread_block(); 344 | uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X; 345 | uint2 pix_min = { block.group_index().x * BLOCK_X, block.group_index().y * BLOCK_Y }; 346 | uint2 pix_max = { min(pix_min.x + BLOCK_X, W), min(pix_min.y + BLOCK_Y , H) }; 347 | uint2 pix = { pix_min.x + block.thread_index().x, pix_min.y + block.thread_index().y }; 348 | uint32_t pix_id = W * pix.y + pix.x; 349 | float2 pixf = { (float)pix.x, (float)pix.y }; 350 | 351 | // Check if this thread is associated with a valid pixel or outside. 352 | bool inside = pix.x < W&& pix.y < H; 353 | // Done threads can help with fetching, but don't rasterize 354 | bool done = !inside; 355 | 356 | // Load start/end range of IDs to process in bit sorted list. 357 | uint2 range = ranges[block.group_index().y * horizontal_blocks + block.group_index().x]; 358 | const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE); 359 | int toDo = range.y - range.x; 360 | 361 | // Allocate storage for batches of collectively fetched data. 362 | __shared__ int collected_id[BLOCK_SIZE]; 363 | __shared__ float2 collected_xy[BLOCK_SIZE]; 364 | __shared__ float4 collected_conic_opacity[BLOCK_SIZE]; 365 | 366 | // Initialize helper variables 367 | float T = 1.0f; 368 | uint32_t contributor = 0; 369 | uint32_t last_contributor = 0; 370 | float C[CHANNELS] = { 0 }; 371 | float D = 0; 372 | float N[3] = {0}; 373 | float E[MAX_EXTRA_DIMS] = {0}; 374 | // We assure the extra feature dim ED <= 8 375 | 376 | // Iterate over batches until all done or range is complete 377 | for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE) 378 | { 379 | // End if entire block votes that it is done rasterizing 380 | int num_done = __syncthreads_count(done); 381 | if (num_done == BLOCK_SIZE) 382 | break; 383 | 384 | // Collectively fetch per-Gaussian data from global to shared 385 | int progress = i * BLOCK_SIZE + block.thread_rank(); 386 | if (range.x + progress < range.y) 387 | { 388 | int coll_id = point_list[range.x + progress]; 389 | collected_id[block.thread_rank()] = coll_id; 390 | collected_xy[block.thread_rank()] = points_xy_image[coll_id]; 391 | collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id]; 392 | } 393 | block.sync(); 394 | 395 | // Iterate over current batch 396 | for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++) 397 | { 398 | // Keep track of current position in range 399 | contributor++; 400 | 401 | // Resample using conic matrix (cf. "Surface 402 | // Splatting" by Zwicker et al., 2001) 403 | float2 xy = collected_xy[j]; 404 | float2 d = { xy.x - pixf.x, xy.y - pixf.y }; 405 | float4 con_o = collected_conic_opacity[j]; 406 | float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) - con_o.y * d.x * d.y; 407 | if (power > 0.0f) 408 | continue; 409 | 410 | // Eq. (2) from 3D Gaussian splatting paper. 411 | // Obtain alpha by multiplying with Gaussian opacity 412 | // and its exponential falloff from mean. 413 | // Avoid numerical instabilities (see paper appendix). 414 | float alpha = min(0.99f, con_o.w * exp(power)); 415 | if (alpha < 1.0f / 255.0f) 416 | continue; 417 | float test_T = T * (1 - alpha); 418 | if (test_T < 0.0001f) 419 | { 420 | done = true; 421 | continue; 422 | } 423 | 424 | // Eq. (3) from 3D Gaussian splatting paper. 425 | for (int ch = 0; ch < CHANNELS; ch++) 426 | C[ch] += features[collected_id[j] * CHANNELS + ch] * alpha * T; 427 | D += depths[collected_id[j]] * alpha * T; 428 | for (int ch = 0; ch < 3; ch++) 429 | N[ch] += norms[collected_id[j] * 3 + ch] * alpha * T; 430 | for(int ch = 0; ch < ED; ch++) 431 | E[ch] += extras[collected_id[j] * ED + ch] * alpha * T; 432 | T = test_T; 433 | 434 | // Keep track of last range entry to update this 435 | // pixel. 436 | last_contributor = contributor; 437 | } 438 | } 439 | 440 | // All threads that treat valid pixel write out their final 441 | // rendering data to the frame and auxiliary buffers. 442 | if (inside) 443 | { 444 | out_alpha[pix_id] = 1 - T; 445 | n_contrib[pix_id] = last_contributor; 446 | for (int ch = 0; ch < CHANNELS; ch++) 447 | out_color[ch * H * W + pix_id] = C[ch] + T * bg_color[ch]; 448 | out_depth[pix_id] = D; 449 | // float len = sqrt(N[0]*N[0] + N[1]*N[1] + N[2]*N[2]) + 1e-6; 450 | for (int ch = 0; ch < 3; ch++) 451 | out_norm[ch * H * W + pix_id] = N[ch]; 452 | for (int ch = 0; ch < ED; ch++) 453 | out_extra[ch * H * W + pix_id] = E[ch]; 454 | 455 | } 456 | } 457 | 458 | void FORWARD::render( 459 | const dim3 grid, dim3 block, 460 | const uint2* ranges, 461 | const uint32_t* point_list, 462 | int W, int H, 463 | const int ED, 464 | const float2* means2D, 465 | const float* colors, 466 | const float* norms, 467 | const float* depths, 468 | const float* extras, 469 | const float4* conic_opacity, 470 | float* out_alpha, 471 | uint32_t* n_contrib, 472 | const float* bg_color, 473 | float* out_color, 474 | float* out_depth, 475 | float* out_norm, 476 | float* out_extra) 477 | { 478 | renderCUDA << > > ( 479 | ranges, 480 | point_list, 481 | W, H, ED, 482 | means2D, 483 | colors, 484 | norms, 485 | depths, 486 | extras, 487 | conic_opacity, 488 | out_alpha, 489 | n_contrib, 490 | bg_color, 491 | out_color, 492 | out_depth, 493 | out_norm, 494 | out_extra); 495 | } 496 | 497 | void FORWARD::preprocess(int P, int D, int M, 498 | const float* means3D, 499 | const glm::vec3* scales, 500 | const float scale_modifier, 501 | const glm::vec4* rotations, 502 | const float* opacities, 503 | const float* shs, 504 | bool* clamped, 505 | const float* cov3Ds_precomp, 506 | const float* norm3Ds_precomp, 507 | const float* colors_precomp, 508 | const float* viewmatrix, 509 | const float* projmatrix, 510 | const glm::vec3* cam_pos, 511 | const int W, int H, 512 | const float focal_x, float focal_y, 513 | const float tan_fovx, float tan_fovy, 514 | int* radii, 515 | float2* means2D, 516 | float* depths, 517 | float* cov3Ds, 518 | float* norm3Ds, 519 | float* rgb, 520 | float4* conic_opacity, 521 | const dim3 grid, 522 | uint32_t* tiles_touched, 523 | bool prefiltered) 524 | { 525 | preprocessCUDA << <(P + 255) / 256, 256 >> > ( 526 | P, D, M, 527 | means3D, 528 | scales, 529 | scale_modifier, 530 | rotations, 531 | opacities, 532 | shs, 533 | clamped, 534 | cov3Ds_precomp, 535 | norm3Ds_precomp, 536 | colors_precomp, 537 | viewmatrix, 538 | projmatrix, 539 | cam_pos, 540 | W, H, 541 | tan_fovx, tan_fovy, 542 | focal_x, focal_y, 543 | radii, 544 | means2D, 545 | depths, 546 | cov3Ds, 547 | norm3Ds, 548 | rgb, 549 | conic_opacity, 550 | grid, 551 | tiles_touched, 552 | prefiltered 553 | ); 554 | } -------------------------------------------------------------------------------- /cuda_rasterizer/forward.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_FORWARD_H_INCLUDED 13 | #define CUDA_RASTERIZER_FORWARD_H_INCLUDED 14 | 15 | #include 16 | #include "cuda_runtime.h" 17 | #include "device_launch_parameters.h" 18 | #define GLM_FORCE_CUDA 19 | #include 20 | 21 | namespace FORWARD 22 | { 23 | // Perform initial steps for each Gaussian prior to rasterization. 24 | void preprocess(int P, int D, int M, 25 | const float* orig_points, 26 | const glm::vec3* scales, 27 | const float scale_modifier, 28 | const glm::vec4* rotations, 29 | const float* opacities, 30 | const float* shs, 31 | bool* clamped, 32 | const float* cov3Ds_precomp, 33 | const float* norm3Ds_precomp, 34 | const float* colors_precomp, 35 | const float* viewmatrix, 36 | const float* projmatrix, 37 | const glm::vec3* cam_pos, 38 | const int W, int H, 39 | const float focal_x, float focal_y, 40 | const float tan_fovx, float tan_fovy, 41 | int* radii, 42 | float2* points_xy_image, 43 | float* depths, 44 | float* cov3Ds, 45 | float* norm3Ds, 46 | float* colors, 47 | float4* conic_opacity, 48 | const dim3 grid, 49 | uint32_t* tiles_touched, 50 | bool prefiltered); 51 | 52 | // Main rasterization method. 53 | void render( 54 | const dim3 grid, dim3 block, 55 | const uint2* ranges, 56 | const uint32_t* point_list, 57 | int W, int H, int ED, 58 | const float2* points_xy_image, 59 | const float* features, 60 | const float* norms, 61 | const float* depths, 62 | const float* extras, 63 | const float4* conic_opacity, 64 | float* out_alpha, 65 | uint32_t* n_contrib, 66 | const float* bg_color, 67 | float* out_color, 68 | float* out_depth, 69 | float* out_norm, 70 | float* out_extra); 71 | } 72 | 73 | 74 | #endif -------------------------------------------------------------------------------- /cuda_rasterizer/rasterizer.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_H_INCLUDED 13 | #define CUDA_RASTERIZER_H_INCLUDED 14 | 15 | #include 16 | #include 17 | 18 | namespace CudaRasterizer 19 | { 20 | class Rasterizer 21 | { 22 | public: 23 | 24 | static void markVisible( 25 | int P, 26 | float* means3D, 27 | float* viewmatrix, 28 | float* projmatrix, 29 | bool* present); 30 | 31 | static int forward( 32 | std::function geometryBuffer, 33 | std::function binningBuffer, 34 | std::function imageBuffer, 35 | const int P, const int D, const int M, const int ED, 36 | const float* background, 37 | const int width, const int height, 38 | const float* means3D, 39 | const float* shs, 40 | const float* colors_precomp, 41 | const float* opacities, 42 | const float* scales, 43 | const float scale_modifier, 44 | const float* rotations, 45 | const float* cov3Ds_precomp, 46 | const float* norm3Ds_precomp, 47 | const float* extra_attrs, 48 | const float* viewmatrix, 49 | const float* projmatrix, 50 | const float* cam_pos, 51 | const float tan_fovx, const float tan_fovy, 52 | const bool prefiltered, 53 | float* out_color, 54 | float* out_depth, 55 | float* out_norm, 56 | float* out_alpha, 57 | float* out_extra, 58 | int* radii, 59 | bool debug = false); 60 | 61 | static void backward( 62 | const int P, const int D, const int M, const int R, const int ED, 63 | const float* background, 64 | const int width, const int height, 65 | const float* means3D, 66 | const float* shs, 67 | const float* colors_precomp, 68 | const float* scales, 69 | const float scale_modifier, 70 | const float* rotations, 71 | const float* cov3Ds_precomp, 72 | const float* norm3Ds_precomp, 73 | const float* extra_attrs, 74 | const float* viewmatrix, 75 | const float* projmatrix, 76 | const float* campos, 77 | const float tan_fovx, const float tan_fovy, 78 | const int* radii, 79 | char* geom_buffer, 80 | char* binning_buffer, 81 | char* image_buffer, 82 | const float* accum_alphas, 83 | const float* dL_dpix, 84 | const float* dL_dpix_depth, 85 | const float* dL_dpix_norm, 86 | const float* dL_dpix_dalpha, 87 | const float* dL_dpix_dextra, 88 | float* dL_dmean2D, 89 | float* dL_dconic, 90 | float* dL_dopacity, 91 | float* dL_dcolor, 92 | float* dL_ddepth, 93 | float* dL_dmean3D, 94 | float* dL_dcov3D, 95 | float* dL_dnorm3D, 96 | float* dL_dsh, 97 | float* dL_dscale, 98 | float* dL_drot, 99 | float* dL_dextra, 100 | bool debug); 101 | }; 102 | }; 103 | 104 | #endif -------------------------------------------------------------------------------- /cuda_rasterizer/rasterizer_impl.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include "rasterizer_impl.h" 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include "cuda_runtime.h" 19 | #include "device_launch_parameters.h" 20 | #include 21 | #include 22 | #define GLM_FORCE_CUDA 23 | #include 24 | 25 | #include 26 | #include 27 | namespace cg = cooperative_groups; 28 | 29 | #include "auxiliary.h" 30 | #include "forward.h" 31 | #include "backward.h" 32 | 33 | // Helper function to find the next-highest bit of the MSB 34 | // on the CPU. 35 | uint32_t getHigherMsb(uint32_t n) 36 | { 37 | uint32_t msb = sizeof(n) * 4; 38 | uint32_t step = msb; 39 | while (step > 1) 40 | { 41 | step /= 2; 42 | if (n >> msb) 43 | msb += step; 44 | else 45 | msb -= step; 46 | } 47 | if (n >> msb) 48 | msb++; 49 | return msb; 50 | } 51 | 52 | // Wrapper method to call auxiliary coarse frustum containment test. 53 | // Mark all Gaussians that pass it. 54 | __global__ void checkFrustum(int P, 55 | const float* orig_points, 56 | const float* viewmatrix, 57 | const float* projmatrix, 58 | bool* present) 59 | { 60 | auto idx = cg::this_grid().thread_rank(); 61 | if (idx >= P) 62 | return; 63 | 64 | float3 p_view; 65 | present[idx] = in_frustum(idx, orig_points, viewmatrix, projmatrix, false, p_view); 66 | } 67 | 68 | // Generates one key/value pair for all Gaussian / tile overlaps. 69 | // Run once per Gaussian (1:N mapping). 70 | __global__ void duplicateWithKeys( 71 | int P, 72 | const float2* points_xy, 73 | const float* depths, 74 | const uint32_t* offsets, 75 | uint64_t* gaussian_keys_unsorted, 76 | uint32_t* gaussian_values_unsorted, 77 | int* radii, 78 | dim3 grid) 79 | { 80 | auto idx = cg::this_grid().thread_rank(); 81 | if (idx >= P) 82 | return; 83 | 84 | // Generate no key/value pair for invisible Gaussians 85 | if (radii[idx] > 0) 86 | { 87 | // Find this Gaussian's offset in buffer for writing keys/values. 88 | uint32_t off = (idx == 0) ? 0 : offsets[idx - 1]; 89 | uint2 rect_min, rect_max; 90 | 91 | getRect(points_xy[idx], radii[idx], rect_min, rect_max, grid); 92 | 93 | // For each tile that the bounding rect overlaps, emit a 94 | // key/value pair. The key is | tile ID | depth |, 95 | // and the value is the ID of the Gaussian. Sorting the values 96 | // with this key yields Gaussian IDs in a list, such that they 97 | // are first sorted by tile and then by depth. 98 | for (int y = rect_min.y; y < rect_max.y; y++) 99 | { 100 | for (int x = rect_min.x; x < rect_max.x; x++) 101 | { 102 | uint64_t key = y * grid.x + x; 103 | key <<= 32; 104 | key |= *((uint32_t*)&depths[idx]); 105 | gaussian_keys_unsorted[off] = key; 106 | gaussian_values_unsorted[off] = idx; 107 | off++; 108 | } 109 | } 110 | } 111 | } 112 | 113 | // Check keys to see if it is at the start/end of one tile's range in 114 | // the full sorted list. If yes, write start/end of this tile. 115 | // Run once per instanced (duplicated) Gaussian ID. 116 | __global__ void identifyTileRanges(int L, uint64_t* point_list_keys, uint2* ranges) 117 | { 118 | auto idx = cg::this_grid().thread_rank(); 119 | if (idx >= L) 120 | return; 121 | 122 | // Read tile ID from key. Update start/end of tile range if at limit. 123 | uint64_t key = point_list_keys[idx]; 124 | uint32_t currtile = key >> 32; 125 | if (idx == 0) 126 | ranges[currtile].x = 0; 127 | else 128 | { 129 | uint32_t prevtile = point_list_keys[idx - 1] >> 32; 130 | if (currtile != prevtile) 131 | { 132 | ranges[prevtile].y = idx; 133 | ranges[currtile].x = idx; 134 | } 135 | } 136 | if (idx == L - 1) 137 | ranges[currtile].y = L; 138 | } 139 | 140 | // Mark Gaussians as visible/invisible, based on view frustum testing 141 | void CudaRasterizer::Rasterizer::markVisible( 142 | int P, 143 | float* means3D, 144 | float* viewmatrix, 145 | float* projmatrix, 146 | bool* present) 147 | { 148 | checkFrustum << <(P + 255) / 256, 256 >> > ( 149 | P, 150 | means3D, 151 | viewmatrix, projmatrix, 152 | present); 153 | } 154 | 155 | CudaRasterizer::GeometryState CudaRasterizer::GeometryState::fromChunk(char*& chunk, size_t P) 156 | { 157 | GeometryState geom; 158 | obtain(chunk, geom.depths, P, 128); 159 | obtain(chunk, geom.clamped, P * 3, 128); 160 | obtain(chunk, geom.means2D, P, 128); 161 | obtain(chunk, geom.cov3D, P * 6, 128); 162 | obtain(chunk, geom.conic_opacity, P, 128); 163 | obtain(chunk, geom.rgb, P * 3, 128); 164 | obtain(chunk, geom.norm3D, P * 3, 128); 165 | obtain(chunk, geom.tiles_touched, P, 128); 166 | cub::DeviceScan::InclusiveSum(nullptr, geom.scan_size, geom.tiles_touched, geom.tiles_touched, P); 167 | obtain(chunk, geom.scanning_space, geom.scan_size, 128); 168 | obtain(chunk, geom.point_offsets, P, 128); 169 | return geom; 170 | } 171 | 172 | CudaRasterizer::ImageState CudaRasterizer::ImageState::fromChunk(char*& chunk, size_t N) 173 | { 174 | ImageState img; 175 | obtain(chunk, img.n_contrib, N, 128); 176 | obtain(chunk, img.ranges, N, 128); 177 | return img; 178 | } 179 | 180 | CudaRasterizer::BinningState CudaRasterizer::BinningState::fromChunk(char*& chunk, size_t P) 181 | { 182 | BinningState binning; 183 | obtain(chunk, binning.point_list, P, 128); 184 | obtain(chunk, binning.point_list_unsorted, P, 128); 185 | obtain(chunk, binning.point_list_keys, P, 128); 186 | obtain(chunk, binning.point_list_keys_unsorted, P, 128); 187 | cub::DeviceRadixSort::SortPairs( 188 | nullptr, binning.sorting_size, 189 | binning.point_list_keys_unsorted, binning.point_list_keys, 190 | binning.point_list_unsorted, binning.point_list, P); 191 | obtain(chunk, binning.list_sorting_space, binning.sorting_size, 128); 192 | return binning; 193 | } 194 | 195 | // Forward rendering procedure for differentiable rasterization 196 | // of Gaussians. 197 | int CudaRasterizer::Rasterizer::forward( 198 | std::function geometryBuffer, 199 | std::function binningBuffer, 200 | std::function imageBuffer, 201 | const int P, const int D, const int M, 202 | const int ED, 203 | const float* background, 204 | const int width, const int height, 205 | const float* means3D, 206 | const float* shs, 207 | const float* colors_precomp, 208 | const float* opacities, 209 | const float* scales, 210 | const float scale_modifier, 211 | const float* rotations, 212 | const float* cov3Ds_precomp, 213 | const float* norm3Ds_precomp, 214 | const float* extra_attrs, 215 | const float* viewmatrix, 216 | const float* projmatrix, 217 | const float* cam_pos, 218 | const float tan_fovx, const float tan_fovy, 219 | const bool prefiltered, 220 | float* out_color, 221 | float* out_depth, 222 | float* out_norm, 223 | float* out_alpha, 224 | float* out_extra, 225 | int* radii, 226 | bool debug) 227 | { 228 | const float focal_y = height / (2.0f * tan_fovy); 229 | const float focal_x = width / (2.0f * tan_fovx); 230 | 231 | size_t chunk_size = required(P); 232 | char* chunkptr = geometryBuffer(chunk_size); 233 | GeometryState geomState = GeometryState::fromChunk(chunkptr, P); 234 | 235 | dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1); 236 | dim3 block(BLOCK_X, BLOCK_Y, 1); 237 | 238 | // Dynamically resize image-based auxiliary buffers during training 239 | size_t img_chunk_size = required(width * height); 240 | char* img_chunkptr = imageBuffer(img_chunk_size); 241 | ImageState imgState = ImageState::fromChunk(img_chunkptr, width * height); 242 | 243 | if (NUM_CHANNELS != 3 && colors_precomp == nullptr) 244 | { 245 | throw std::runtime_error("For non-RGB, provide precomputed Gaussian colors!"); 246 | } 247 | 248 | // Run preprocessing per-Gaussian (transformation, bounding, conversion of SHs to RGB) 249 | CHECK_CUDA(FORWARD::preprocess( 250 | P, D, M, 251 | means3D, 252 | (glm::vec3*)scales, 253 | scale_modifier, 254 | (glm::vec4*)rotations, 255 | opacities, 256 | shs, 257 | geomState.clamped, 258 | cov3Ds_precomp, 259 | norm3Ds_precomp, 260 | colors_precomp, 261 | viewmatrix, projmatrix, 262 | (glm::vec3*)cam_pos, 263 | width, height, 264 | focal_x, focal_y, 265 | tan_fovx, tan_fovy, 266 | radii, 267 | geomState.means2D, 268 | geomState.depths, 269 | geomState.cov3D, 270 | geomState.norm3D, 271 | geomState.rgb, 272 | geomState.conic_opacity, 273 | tile_grid, 274 | geomState.tiles_touched, 275 | prefiltered 276 | ), debug) 277 | 278 | // Compute prefix sum over full list of touched tile counts by Gaussians 279 | // E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8] 280 | CHECK_CUDA(cub::DeviceScan::InclusiveSum(geomState.scanning_space, geomState.scan_size, geomState.tiles_touched, geomState.point_offsets, P), debug) 281 | 282 | // Retrieve total number of Gaussian instances to launch and resize aux buffers 283 | int num_rendered; 284 | CHECK_CUDA(cudaMemcpy(&num_rendered, geomState.point_offsets + P - 1, sizeof(int), cudaMemcpyDeviceToHost), debug); 285 | 286 | size_t binning_chunk_size = required(num_rendered); 287 | char* binning_chunkptr = binningBuffer(binning_chunk_size); 288 | BinningState binningState = BinningState::fromChunk(binning_chunkptr, num_rendered); 289 | 290 | // For each instance to be rendered, produce adequate [ tile | depth ] key 291 | // and corresponding dublicated Gaussian indices to be sorted 292 | duplicateWithKeys << <(P + 255) / 256, 256 >> > ( 293 | P, 294 | geomState.means2D, 295 | geomState.depths, 296 | geomState.point_offsets, 297 | binningState.point_list_keys_unsorted, 298 | binningState.point_list_unsorted, 299 | radii, 300 | tile_grid) 301 | CHECK_CUDA(, debug) 302 | 303 | int bit = getHigherMsb(tile_grid.x * tile_grid.y); 304 | 305 | // Sort complete list of (duplicated) Gaussian indices by keys 306 | CHECK_CUDA(cub::DeviceRadixSort::SortPairs( 307 | binningState.list_sorting_space, 308 | binningState.sorting_size, 309 | binningState.point_list_keys_unsorted, binningState.point_list_keys, 310 | binningState.point_list_unsorted, binningState.point_list, 311 | num_rendered, 0, 32 + bit), debug) 312 | 313 | CHECK_CUDA(cudaMemset(imgState.ranges, 0, tile_grid.x * tile_grid.y * sizeof(uint2)), debug); 314 | 315 | // Identify start and end of per-tile workloads in sorted list 316 | if (num_rendered > 0) 317 | identifyTileRanges << <(num_rendered + 255) / 256, 256 >> > ( 318 | num_rendered, 319 | binningState.point_list_keys, 320 | imgState.ranges); 321 | CHECK_CUDA(, debug) 322 | 323 | // Let each tile blend its range of Gaussians independently in parallel 324 | const float* feature_ptr = colors_precomp != nullptr ? colors_precomp : geomState.rgb; 325 | const float* norm_ptr = norm3Ds_precomp != nullptr ? norm3Ds_precomp : geomState.norm3D; 326 | CHECK_CUDA(FORWARD::render( 327 | tile_grid, block, 328 | imgState.ranges, 329 | binningState.point_list, 330 | width, height, ED, 331 | geomState.means2D, 332 | feature_ptr, 333 | norm_ptr, 334 | geomState.depths, 335 | extra_attrs, 336 | geomState.conic_opacity, 337 | out_alpha, 338 | imgState.n_contrib, 339 | background, 340 | out_color, 341 | out_depth, 342 | out_norm, 343 | out_extra), debug) 344 | 345 | return num_rendered; 346 | } 347 | 348 | // Produce necessary gradients for optimization, corresponding 349 | // to forward render pass 350 | void CudaRasterizer::Rasterizer::backward( 351 | const int P, const int D, const int M, const int R, const int ED, 352 | const float* background, 353 | const int width, const int height, 354 | const float* means3D, 355 | const float* shs, 356 | const float* colors_precomp, 357 | const float* scales, 358 | const float scale_modifier, 359 | const float* rotations, 360 | const float* cov3Ds_precomp, 361 | const float* norm3Ds_precomp, 362 | const float* extra_attrs, 363 | const float* viewmatrix, 364 | const float* projmatrix, 365 | const float* campos, 366 | const float tan_fovx, const float tan_fovy, 367 | const int* radii, 368 | char* geom_buffer, 369 | char* binning_buffer, 370 | char* img_buffer, 371 | const float* accum_alphas, 372 | const float* dL_dpix, 373 | const float* dL_dpix_depth, 374 | const float* dL_dpix_norm, 375 | const float* dL_dpix_alpha, 376 | const float* dL_dpix_extra, 377 | float* dL_dmean2D, 378 | float* dL_dconic, 379 | float* dL_dopacity, 380 | float* dL_dcolor, 381 | float* dL_ddepth, 382 | float* dL_dmean3D, 383 | float* dL_dcov3D, 384 | float* dL_dnorm3D, 385 | float* dL_dsh, 386 | float* dL_dscale, 387 | float* dL_drot, 388 | float* dL_dextra, 389 | bool debug) 390 | { 391 | GeometryState geomState = GeometryState::fromChunk(geom_buffer, P); 392 | BinningState binningState = BinningState::fromChunk(binning_buffer, R); 393 | ImageState imgState = ImageState::fromChunk(img_buffer, width * height); 394 | 395 | const float focal_y = height / (2.0f * tan_fovy); 396 | const float focal_x = width / (2.0f * tan_fovx); 397 | 398 | const dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1); 399 | const dim3 block(BLOCK_X, BLOCK_Y, 1); 400 | 401 | // Compute loss gradients w.r.t. 2D mean position, conic matrix, 402 | // opacity and RGB of Gaussians from per-pixel loss gradients. 403 | // If we were given precomputed colors and not SHs, use them. 404 | const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : geomState.rgb; 405 | const float* norm_ptr = (norm3Ds_precomp != nullptr) ? norm3Ds_precomp : geomState.norm3D; 406 | CHECK_CUDA(BACKWARD::render( 407 | tile_grid, 408 | block, 409 | imgState.ranges, 410 | binningState.point_list, 411 | width, height, ED, 412 | background, 413 | geomState.means2D, 414 | geomState.conic_opacity, 415 | color_ptr, 416 | geomState.depths, 417 | norm_ptr, 418 | extra_attrs, 419 | accum_alphas, 420 | imgState.n_contrib, 421 | dL_dpix, 422 | dL_dpix_depth, 423 | dL_dpix_norm, 424 | dL_dpix_alpha, 425 | dL_dpix_extra, 426 | (float3*)dL_dmean2D, 427 | (float4*)dL_dconic, 428 | dL_dopacity, 429 | dL_dcolor, 430 | dL_ddepth, 431 | dL_dnorm3D, 432 | dL_dextra), debug) 433 | 434 | // Take care of the rest of preprocessing. Was the precomputed covariance 435 | // given to us or a scales/rot pair? If precomputed, pass that. If not, 436 | // use the one we computed ourselves. 437 | const float* cov3D_ptr = (cov3Ds_precomp != nullptr) ? cov3Ds_precomp : geomState.cov3D; 438 | CHECK_CUDA(BACKWARD::preprocess(P, D, M, 439 | (float3*)means3D, 440 | radii, 441 | shs, 442 | geomState.clamped, 443 | (glm::vec3*)scales, 444 | (glm::vec4*)rotations, 445 | scale_modifier, 446 | cov3D_ptr, 447 | (glm::vec3*)norm_ptr, 448 | (norm3Ds_precomp != nullptr), 449 | viewmatrix, 450 | projmatrix, 451 | focal_x, focal_y, 452 | tan_fovx, tan_fovy, 453 | (glm::vec3*)campos, 454 | (float3*)dL_dmean2D, 455 | dL_dconic, 456 | (glm::vec3*)dL_dmean3D, 457 | dL_dcolor, 458 | dL_ddepth, 459 | dL_dcov3D, 460 | (glm::vec3*)dL_dnorm3D, 461 | dL_dsh, 462 | (glm::vec3*)dL_dscale, 463 | (glm::vec4*)dL_drot), debug) 464 | } -------------------------------------------------------------------------------- /cuda_rasterizer/rasterizer_impl.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #pragma once 13 | 14 | #include 15 | #include 16 | #include "rasterizer.h" 17 | #include 18 | 19 | namespace CudaRasterizer 20 | { 21 | template 22 | static void obtain(char*& chunk, T*& ptr, std::size_t count, std::size_t alignment) 23 | { 24 | std::size_t offset = (reinterpret_cast(chunk) + alignment - 1) & ~(alignment - 1); 25 | ptr = reinterpret_cast(offset); 26 | chunk = reinterpret_cast(ptr + count); 27 | } 28 | 29 | struct GeometryState 30 | { 31 | size_t scan_size; 32 | float* depths; 33 | char* scanning_space; 34 | bool* clamped; 35 | float2* means2D; 36 | float* cov3D; 37 | float* norm3D; 38 | float4* conic_opacity; 39 | float* rgb; 40 | uint32_t* point_offsets; 41 | uint32_t* tiles_touched; 42 | 43 | static GeometryState fromChunk(char*& chunk, size_t P); 44 | }; 45 | 46 | struct ImageState 47 | { 48 | uint2* ranges; 49 | uint32_t* n_contrib; 50 | 51 | static ImageState fromChunk(char*& chunk, size_t N); 52 | }; 53 | 54 | struct BinningState 55 | { 56 | size_t sorting_size; 57 | uint64_t* point_list_keys_unsorted; 58 | uint64_t* point_list_keys; 59 | uint32_t* point_list_unsorted; 60 | uint32_t* point_list; 61 | char* list_sorting_space; 62 | 63 | static BinningState fromChunk(char*& chunk, size_t P); 64 | }; 65 | 66 | template 67 | size_t required(size_t P) 68 | { 69 | char* size = nullptr; 70 | T::fromChunk(size, P); 71 | return ((size_t)size) + 128; 72 | } 73 | }; -------------------------------------------------------------------------------- /diff_gauss/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from typing import NamedTuple 13 | import torch.nn as nn 14 | import torch 15 | from . import _C 16 | 17 | def cpu_deep_copy_tuple(input_tuple): 18 | copied_tensors = [item.cpu().clone() if isinstance(item, torch.Tensor) else item for item in input_tuple] 19 | return tuple(copied_tensors) 20 | 21 | def rasterize_gaussians( 22 | means3D, 23 | means2D, 24 | sh, 25 | colors_precomp, 26 | opacities, 27 | scales, 28 | rotations, 29 | cov3Ds_precomp, 30 | norm3Ds_precomp, 31 | extra_attrs, 32 | raster_settings, 33 | ): 34 | color, depth, norm, alpha, radii, extra = _RasterizeGaussians.apply( 35 | means3D, 36 | means2D, 37 | sh, 38 | colors_precomp, 39 | opacities, 40 | scales, 41 | rotations, 42 | cov3Ds_precomp, 43 | norm3Ds_precomp, 44 | extra_attrs, 45 | raster_settings, 46 | ) 47 | 48 | norm = torch.nn.functional.normalize(norm, p=2, dim=0) 49 | # 3, H, W 50 | return color, depth, norm, alpha, radii, extra 51 | 52 | class _RasterizeGaussians(torch.autograd.Function): 53 | @staticmethod 54 | def forward( 55 | ctx, 56 | means3D, 57 | means2D, 58 | sh, 59 | colors_precomp, 60 | opacities, 61 | scales, 62 | rotations, 63 | cov3Ds_precomp, 64 | norm3Ds_precomp, 65 | extra_attrs, 66 | raster_settings, 67 | ): 68 | # restrict the length of extra attr values to avoid dynamically sized shared memory allocation 69 | assert extra_attrs.shape[0] == 0 or extra_attrs.shape[1] <= 34 70 | # Restructure arguments the way that the C++ lib expects them 71 | args = ( 72 | raster_settings.bg, 73 | means3D, 74 | colors_precomp, 75 | opacities, 76 | scales, 77 | rotations, 78 | raster_settings.scale_modifier, 79 | cov3Ds_precomp, 80 | norm3Ds_precomp, 81 | extra_attrs, 82 | extra_attrs.shape[1] if extra_attrs.shape[0] != 0 else 0, 83 | raster_settings.viewmatrix, 84 | raster_settings.projmatrix, 85 | raster_settings.tanfovx, 86 | raster_settings.tanfovy, 87 | raster_settings.image_height, 88 | raster_settings.image_width, 89 | sh, 90 | raster_settings.sh_degree, 91 | raster_settings.campos, 92 | raster_settings.prefiltered, 93 | raster_settings.debug 94 | ) 95 | 96 | # Invoke C++/CUDA rasterizer 97 | if raster_settings.debug: 98 | cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted 99 | try: 100 | num_rendered, color, depth, norm, alpha, radii, extra, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args) 101 | except Exception as ex: 102 | torch.save(cpu_args, "snapshot_fw.dump") 103 | print("\nAn error occured in forward. Please forward snapshot_fw.dump for debugging.") 104 | raise ex 105 | else: 106 | num_rendered, color, depth, norm, alpha, radii, extra, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args) 107 | 108 | # Keep relevant tensors for backward 109 | ctx.raster_settings = raster_settings 110 | ctx.num_rendered = num_rendered 111 | ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, norm3Ds_precomp, radii, extra_attrs, sh, geomBuffer, binningBuffer, imgBuffer, alpha) 112 | return color, depth, norm, alpha, radii, extra 113 | 114 | @staticmethod 115 | def backward(ctx, grad_out_color, grad_out_depth, grad_out_norm, grad_out_alpha, _, grad_out_extra): 116 | 117 | # Restore necessary values from context 118 | num_rendered = ctx.num_rendered 119 | raster_settings = ctx.raster_settings 120 | colors_precomp, means3D, scales, rotations, cov3Ds_precomp, norm3Ds_precomp, radii, extra_attrs, sh, geomBuffer, binningBuffer, imgBuffer, alpha = ctx.saved_tensors 121 | 122 | # Restructure args as C++ method expects them 123 | args = (raster_settings.bg, 124 | means3D, 125 | radii, 126 | colors_precomp, 127 | scales, 128 | rotations, 129 | extra_attrs, 130 | raster_settings.scale_modifier, 131 | cov3Ds_precomp, 132 | norm3Ds_precomp, 133 | raster_settings.viewmatrix, 134 | raster_settings.projmatrix, 135 | raster_settings.tanfovx, 136 | raster_settings.tanfovy, 137 | grad_out_color, 138 | grad_out_depth, 139 | grad_out_norm, 140 | grad_out_alpha, 141 | grad_out_extra, 142 | sh, 143 | raster_settings.sh_degree, 144 | raster_settings.campos, 145 | geomBuffer, 146 | num_rendered, 147 | binningBuffer, 148 | imgBuffer, 149 | alpha, 150 | raster_settings.debug) 151 | 152 | # Compute gradients for relevant tensors by invoking backward method 153 | if raster_settings.debug: 154 | cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted 155 | try: 156 | grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_norm3Ds_precomp, grad_sh, grad_scales, grad_rotations, grad_extra_attrs = _C.rasterize_gaussians_backward(*args) 157 | except Exception as ex: 158 | torch.save(cpu_args, "snapshot_bw.dump") 159 | print("\nAn error occured in backward. Writing snapshot_bw.dump for debugging.\n") 160 | raise ex 161 | else: 162 | grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_norm3Ds_precomp, grad_sh, grad_scales, grad_rotations, grad_extra_attrs = _C.rasterize_gaussians_backward(*args) 163 | 164 | grads = ( 165 | grad_means3D, 166 | grad_means2D, 167 | grad_sh, 168 | grad_colors_precomp, 169 | grad_opacities, 170 | grad_scales, 171 | grad_rotations, 172 | grad_cov3Ds_precomp, 173 | grad_norm3Ds_precomp, 174 | grad_extra_attrs, 175 | None 176 | ) 177 | 178 | return grads 179 | 180 | class GaussianRasterizationSettings(NamedTuple): 181 | image_height: int 182 | image_width: int 183 | tanfovx : float 184 | tanfovy : float 185 | bg : torch.Tensor 186 | scale_modifier : float 187 | viewmatrix : torch.Tensor 188 | projmatrix : torch.Tensor 189 | sh_degree : int 190 | campos : torch.Tensor 191 | prefiltered : bool 192 | debug : bool 193 | 194 | class GaussianRasterizer(nn.Module): 195 | def __init__(self, raster_settings): 196 | super().__init__() 197 | self.raster_settings = raster_settings 198 | 199 | def markVisible(self, positions): 200 | # Mark visible points (based on frustum culling for camera) with a boolean 201 | with torch.no_grad(): 202 | raster_settings = self.raster_settings 203 | visible = _C.mark_visible( 204 | positions, 205 | raster_settings.viewmatrix, 206 | raster_settings.projmatrix) 207 | 208 | return visible 209 | 210 | def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3Ds_precomp = None, norm3Ds_precomp=None, extra_attrs=None): 211 | 212 | raster_settings = self.raster_settings 213 | 214 | if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None): 215 | raise Exception('Please provide excatly one of either SHs or precomputed colors!') 216 | 217 | if ((scales is None or rotations is None) and cov3Ds_precomp is None) or ((scales is not None or rotations is not None) and cov3Ds_precomp is not None): 218 | raise Exception('Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!') 219 | 220 | if shs is None: 221 | shs = torch.Tensor([]) 222 | if colors_precomp is None: 223 | colors_precomp = torch.Tensor([]) 224 | 225 | if scales is None: 226 | raise ValueError('To support norm and depth prediction, scales == None is not allowed') 227 | scales = torch.Tensor([]) 228 | if rotations is None: 229 | raise ValueError('To support norm and depth prediction, rotations == None is not allowed') 230 | rotations = torch.Tensor([]) 231 | if cov3Ds_precomp is None: 232 | cov3Ds_precomp = torch.Tensor([]) 233 | if norm3Ds_precomp is None: 234 | norm3Ds_precomp = torch.Tensor([]) 235 | if extra_attrs is None: 236 | extra_attrs = torch.Tensor([]) 237 | # Invoke C++/CUDA rasterization routine 238 | return rasterize_gaussians( 239 | means3D, 240 | means2D, 241 | shs, 242 | colors_precomp, 243 | opacities, 244 | scales, 245 | rotations, 246 | cov3Ds_precomp, 247 | norm3Ds_precomp, 248 | extra_attrs, 249 | raster_settings, 250 | ) 251 | 252 | -------------------------------------------------------------------------------- /ext.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include 13 | #include "rasterize_points.h" 14 | 15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 16 | m.def("rasterize_gaussians", &RasterizeGaussiansCUDA); 17 | m.def("rasterize_gaussians_backward", &RasterizeGaussiansBackwardCUDA); 18 | m.def("mark_visible", &markVisible); 19 | } -------------------------------------------------------------------------------- /rasterize_points.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include "cuda_rasterizer/config.h" 22 | #include "cuda_rasterizer/rasterizer.h" 23 | #include 24 | #include 25 | #include 26 | 27 | std::function resizeFunctional(torch::Tensor& t) { 28 | auto lambda = [&t](size_t N) { 29 | t.resize_({(long long)N}); 30 | return reinterpret_cast(t.contiguous().data_ptr()); 31 | }; 32 | return lambda; 33 | } 34 | 35 | std::tuple 36 | RasterizeGaussiansCUDA( 37 | const torch::Tensor& background, 38 | const torch::Tensor& means3D, 39 | const torch::Tensor& colors, 40 | const torch::Tensor& opacity, 41 | const torch::Tensor& scales, 42 | const torch::Tensor& rotations, 43 | const float scale_modifier, 44 | const torch::Tensor& cov3Ds_precomp, 45 | const torch::Tensor& norm3Ds_precomp, 46 | const torch::Tensor& extra_attrs, 47 | const int attr_degree, 48 | const torch::Tensor& viewmatrix, 49 | const torch::Tensor& projmatrix, 50 | const float tan_fovx, 51 | const float tan_fovy, 52 | const int image_height, 53 | const int image_width, 54 | const torch::Tensor& sh, 55 | const int degree, 56 | const torch::Tensor& campos, 57 | const bool prefiltered, 58 | const bool debug) 59 | { 60 | if (means3D.ndimension() != 2 || means3D.size(1) != 3) { 61 | AT_ERROR("means3D must have dimensions (num_points, 3)"); 62 | } 63 | 64 | const int P = means3D.size(0); 65 | const int H = image_height; 66 | const int W = image_width; 67 | const int F = attr_degree; 68 | 69 | auto int_opts = means3D.options().dtype(torch::kInt32); 70 | auto float_opts = means3D.options().dtype(torch::kFloat32); 71 | 72 | torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts); 73 | torch::Tensor out_depth = torch::full({1, H, W}, 0.0, float_opts); 74 | torch::Tensor out_alpha = torch::full({1, H, W}, 0.0, float_opts); 75 | torch::Tensor out_norm = torch::full({3, H, W}, 0.0, float_opts); 76 | torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32)); 77 | torch::Tensor out_extra; 78 | if(F > 0) 79 | out_extra = torch::full({F, H, W}, 0.0, float_opts); 80 | else 81 | out_extra = torch::empty({0}, float_opts); 82 | 83 | torch::Device device(torch::kCUDA); 84 | torch::TensorOptions options(torch::kByte); 85 | torch::Tensor geomBuffer = torch::empty({0}, options.device(device)); 86 | torch::Tensor binningBuffer = torch::empty({0}, options.device(device)); 87 | torch::Tensor imgBuffer = torch::empty({0}, options.device(device)); 88 | std::function geomFunc = resizeFunctional(geomBuffer); 89 | std::function binningFunc = resizeFunctional(binningBuffer); 90 | std::function imgFunc = resizeFunctional(imgBuffer); 91 | 92 | int rendered = 0; 93 | if(P != 0) 94 | { 95 | int M = 0; 96 | if(sh.size(0) != 0) 97 | { 98 | M = sh.size(1); 99 | } 100 | 101 | rendered = CudaRasterizer::Rasterizer::forward( 102 | geomFunc, 103 | binningFunc, 104 | imgFunc, 105 | P, degree, M, F, 106 | background.contiguous().data(), 107 | W, H, 108 | means3D.contiguous().data(), 109 | sh.contiguous().data_ptr(), 110 | colors.contiguous().data(), 111 | opacity.contiguous().data(), 112 | scales.contiguous().data_ptr(), 113 | scale_modifier, 114 | rotations.contiguous().data_ptr(), 115 | cov3Ds_precomp.contiguous().data(), 116 | norm3Ds_precomp.contiguous().data(), 117 | extra_attrs.contiguous().data(), 118 | viewmatrix.contiguous().data(), 119 | projmatrix.contiguous().data(), 120 | campos.contiguous().data(), 121 | tan_fovx, 122 | tan_fovy, 123 | prefiltered, 124 | out_color.contiguous().data(), 125 | out_depth.contiguous().data(), 126 | out_norm.contiguous().data(), 127 | out_alpha.contiguous().data(), 128 | out_extra.contiguous().data(), 129 | radii.contiguous().data(), 130 | debug); 131 | } 132 | return std::make_tuple(rendered, out_color, out_depth, out_norm, out_alpha, radii, out_extra, geomBuffer, binningBuffer, imgBuffer); 133 | } 134 | 135 | std::tuple 136 | RasterizeGaussiansBackwardCUDA( 137 | const torch::Tensor& background, 138 | const torch::Tensor& means3D, 139 | const torch::Tensor& radii, 140 | const torch::Tensor& colors, 141 | const torch::Tensor& scales, 142 | const torch::Tensor& rotations, 143 | const torch::Tensor& extra_attrs, 144 | const float scale_modifier, 145 | const torch::Tensor& cov3Ds_precomp, 146 | const torch::Tensor& norm3Ds_precomp, 147 | const torch::Tensor& viewmatrix, 148 | const torch::Tensor& projmatrix, 149 | const float tan_fovx, 150 | const float tan_fovy, 151 | const torch::Tensor& dL_dout_color, 152 | const torch::Tensor& dL_dout_depth, 153 | const torch::Tensor& dL_dout_norm, 154 | const torch::Tensor& dL_dout_alpha, 155 | const torch::Tensor& dL_dout_extra, 156 | const torch::Tensor& sh, 157 | const int degree, 158 | const torch::Tensor& campos, 159 | const torch::Tensor& geomBuffer, 160 | const int R, 161 | const torch::Tensor& binningBuffer, 162 | const torch::Tensor& imageBuffer, 163 | const torch::Tensor& out_alpha, 164 | const bool debug) 165 | { 166 | const int P = means3D.size(0); 167 | const int H = dL_dout_color.size(1); 168 | const int W = dL_dout_color.size(2); 169 | const int F = (extra_attrs.size(0) != 0 ? extra_attrs.size(1) : 0); 170 | 171 | int M = 0; 172 | if(sh.size(0) != 0) 173 | { 174 | M = sh.size(1); 175 | } 176 | 177 | torch::Tensor dL_dmeans3D = torch::zeros({P, 3}, means3D.options()); 178 | torch::Tensor dL_dmeans2D = torch::zeros({P, 3}, means3D.options()); 179 | torch::Tensor dL_dcolors = torch::zeros({P, NUM_CHANNELS}, means3D.options()); 180 | // just for storing intermediate results 181 | torch::Tensor dL_ddepths = torch::zeros({P, 1}, means3D.options()); 182 | torch::Tensor dL_dconic = torch::zeros({P, 2, 2}, means3D.options()); 183 | torch::Tensor dL_dopacity = torch::zeros({P, 1}, means3D.options()); 184 | torch::Tensor dL_dcov3D = torch::zeros({P, 6}, means3D.options()); 185 | torch::Tensor dL_dnorm3D = torch::zeros({P, 3}, means3D.options()); 186 | torch::Tensor dL_dsh = torch::zeros({P, M, 3}, means3D.options()); 187 | torch::Tensor dL_dscales = torch::zeros({P, 3}, means3D.options()); 188 | torch::Tensor dL_drotations = torch::zeros({P, 4}, means3D.options()); 189 | torch::Tensor dL_dextra_attrs; 190 | if(F > 0) 191 | dL_dextra_attrs = torch::zeros({P, F}, means3D.options()); 192 | else 193 | dL_dextra_attrs = torch::empty({0}, means3D.options()); 194 | if(P != 0) 195 | { 196 | CudaRasterizer::Rasterizer::backward(P, degree, M, R, F, 197 | background.contiguous().data(), 198 | W, H, 199 | means3D.contiguous().data(), 200 | sh.contiguous().data(), 201 | colors.contiguous().data(), 202 | scales.data_ptr(), 203 | scale_modifier, 204 | rotations.data_ptr(), 205 | cov3Ds_precomp.contiguous().data(), 206 | norm3Ds_precomp.contiguous().data(), 207 | extra_attrs.contiguous().data(), 208 | viewmatrix.contiguous().data(), 209 | projmatrix.contiguous().data(), 210 | campos.contiguous().data(), 211 | tan_fovx, 212 | tan_fovy, 213 | radii.contiguous().data(), 214 | reinterpret_cast(geomBuffer.contiguous().data_ptr()), 215 | reinterpret_cast(binningBuffer.contiguous().data_ptr()), 216 | reinterpret_cast(imageBuffer.contiguous().data_ptr()), 217 | out_alpha.contiguous().data(), 218 | dL_dout_color.contiguous().data(), 219 | dL_dout_depth.contiguous().data(), 220 | dL_dout_norm.contiguous().data(), 221 | dL_dout_alpha.contiguous().data(), 222 | dL_dout_extra.contiguous().data(), 223 | dL_dmeans2D.contiguous().data(), 224 | dL_dconic.contiguous().data(), 225 | dL_dopacity.contiguous().data(), 226 | dL_dcolors.contiguous().data(), 227 | dL_ddepths.contiguous().data(), 228 | dL_dmeans3D.contiguous().data(), 229 | dL_dcov3D.contiguous().data(), 230 | dL_dnorm3D.contiguous().data(), 231 | dL_dsh.contiguous().data(), 232 | dL_dscales.contiguous().data(), 233 | dL_drotations.contiguous().data(), 234 | dL_dextra_attrs.contiguous().data(), 235 | debug); 236 | } 237 | 238 | return std::make_tuple(dL_dmeans2D, dL_dcolors, dL_dopacity, dL_dmeans3D, dL_dcov3D, dL_dnorm3D, dL_dsh, dL_dscales, dL_drotations, dL_dextra_attrs); 239 | } 240 | 241 | torch::Tensor markVisible( 242 | torch::Tensor& means3D, 243 | torch::Tensor& viewmatrix, 244 | torch::Tensor& projmatrix) 245 | { 246 | const int P = means3D.size(0); 247 | 248 | torch::Tensor present = torch::full({P}, false, means3D.options().dtype(at::kBool)); 249 | 250 | if(P != 0) 251 | { 252 | CudaRasterizer::Rasterizer::markVisible(P, 253 | means3D.contiguous().data(), 254 | viewmatrix.contiguous().data(), 255 | projmatrix.contiguous().data(), 256 | present.contiguous().data()); 257 | } 258 | 259 | return present; 260 | } -------------------------------------------------------------------------------- /rasterize_points.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #pragma once 13 | #include 14 | #include 15 | #include 16 | #include 17 | std::tuple 18 | RasterizeGaussiansCUDA( 19 | const torch::Tensor& background, 20 | const torch::Tensor& means3D, 21 | const torch::Tensor& colors, 22 | const torch::Tensor& opacity, 23 | const torch::Tensor& scales, 24 | const torch::Tensor& rotations, 25 | const float scale_modifier, 26 | const torch::Tensor& cov3Ds_precomp, 27 | const torch::Tensor& norm3Ds_precomp, 28 | const torch::Tensor& extra_attrs, 29 | const int attr_degree, 30 | const torch::Tensor& viewmatrix, 31 | const torch::Tensor& projmatrix, 32 | const float tan_fovx, 33 | const float tan_fovy, 34 | const int image_height, 35 | const int image_width, 36 | const torch::Tensor& sh, 37 | const int degree, 38 | const torch::Tensor& campos, 39 | const bool prefiltered, 40 | const bool debug); 41 | 42 | std::tuple 43 | RasterizeGaussiansBackwardCUDA( 44 | const torch::Tensor& background, 45 | const torch::Tensor& means3D, 46 | const torch::Tensor& radii, 47 | const torch::Tensor& colors, 48 | const torch::Tensor& scales, 49 | const torch::Tensor& rotations, 50 | const torch::Tensor& extra_attrs, 51 | const float scale_modifier, 52 | const torch::Tensor& cov3Ds_precomp, 53 | const torch::Tensor& norm3Ds_precomp, 54 | const torch::Tensor& viewmatrix, 55 | const torch::Tensor& projmatrix, 56 | const float tan_fovx, 57 | const float tan_fovy, 58 | const torch::Tensor& dL_dout_color, 59 | const torch::Tensor& dL_dout_depth, 60 | const torch::Tensor& dL_dout_norm, 61 | const torch::Tensor& dL_dout_alpha, 62 | const torch::Tensor& dL_dout_extra, 63 | const torch::Tensor& sh, 64 | const int degree, 65 | const torch::Tensor& campos, 66 | const torch::Tensor& geomBuffer, 67 | const int R, 68 | const torch::Tensor& binningBuffer, 69 | const torch::Tensor& imageBuffer, 70 | const torch::Tensor& out_alpha, 71 | const bool debug); 72 | 73 | torch::Tensor markVisible( 74 | torch::Tensor& means3D, 75 | torch::Tensor& viewmatrix, 76 | torch::Tensor& projmatrix); 77 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from setuptools import setup 13 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 14 | import os 15 | os.path.dirname(os.path.abspath(__file__)) 16 | 17 | setup( 18 | name="diff_gauss", 19 | packages=['diff_gauss'], 20 | version="1.0.10.0", 21 | ext_modules=[ 22 | CUDAExtension( 23 | name="diff_gauss._C", 24 | sources=[ 25 | "cuda_rasterizer/rasterizer_impl.cu", 26 | "cuda_rasterizer/forward.cu", 27 | "cuda_rasterizer/backward.cu", 28 | "rasterize_points.cu", 29 | "ext.cpp"], 30 | extra_compile_args={"nvcc": ["-I" + os.path.join(os.path.dirname(os.path.abspath(__file__)), "third_party/glm/")]}) 31 | ], 32 | cmdclass={ 33 | 'build_ext': BuildExtension 34 | } 35 | ) 36 | -------------------------------------------------------------------------------- /third_party/stbi_image_write.h: -------------------------------------------------------------------------------- 1 | /* stb_image_write - v1.16 - public domain - http://nothings.org/stb 2 | writes out PNG/BMP/TGA/JPEG/HDR images to C stdio - Sean Barrett 2010-2015 3 | no warranty implied; use at your own risk 4 | 5 | Before #including, 6 | 7 | #define STB_IMAGE_WRITE_IMPLEMENTATION 8 | 9 | in the file that you want to have the implementation. 10 | 11 | Will probably not work correctly with strict-aliasing optimizations. 12 | 13 | ABOUT: 14 | 15 | This header file is a library for writing images to C stdio or a callback. 16 | 17 | The PNG output is not optimal; it is 20-50% larger than the file 18 | written by a decent optimizing implementation; though providing a custom 19 | zlib compress function (see STBIW_ZLIB_COMPRESS) can mitigate that. 20 | This library is designed for source code compactness and simplicity, 21 | not optimal image file size or run-time performance. 22 | 23 | BUILDING: 24 | 25 | You can #define STBIW_ASSERT(x) before the #include to avoid using assert.h. 26 | You can #define STBIW_MALLOC(), STBIW_REALLOC(), and STBIW_FREE() to replace 27 | malloc,realloc,free. 28 | You can #define STBIW_MEMMOVE() to replace memmove() 29 | You can #define STBIW_ZLIB_COMPRESS to use a custom zlib-style compress function 30 | for PNG compression (instead of the builtin one), it must have the following signature: 31 | unsigned char * my_compress(unsigned char *data, int data_len, int *out_len, int quality); 32 | The returned data will be freed with STBIW_FREE() (free() by default), 33 | so it must be heap allocated with STBIW_MALLOC() (malloc() by default), 34 | 35 | UNICODE: 36 | 37 | If compiling for Windows and you wish to use Unicode filenames, compile 38 | with 39 | #define STBIW_WINDOWS_UTF8 40 | and pass utf8-encoded filenames. Call stbiw_convert_wchar_to_utf8 to convert 41 | Windows wchar_t filenames to utf8. 42 | 43 | USAGE: 44 | 45 | There are five functions, one for each image file format: 46 | 47 | int stbi_write_png(char const *filename, int w, int h, int comp, const void *data, int stride_in_bytes); 48 | int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data); 49 | int stbi_write_tga(char const *filename, int w, int h, int comp, const void *data); 50 | int stbi_write_jpg(char const *filename, int w, int h, int comp, const void *data, int quality); 51 | int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data); 52 | 53 | void stbi_flip_vertically_on_write(int flag); // flag is non-zero to flip data vertically 54 | 55 | There are also five equivalent functions that use an arbitrary write function. You are 56 | expected to open/close your file-equivalent before and after calling these: 57 | 58 | int stbi_write_png_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data, int stride_in_bytes); 59 | int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data); 60 | int stbi_write_tga_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data); 61 | int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const float *data); 62 | int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality); 63 | 64 | where the callback is: 65 | void stbi_write_func(void *context, void *data, int size); 66 | 67 | You can configure it with these global variables: 68 | int stbi_write_tga_with_rle; // defaults to true; set to 0 to disable RLE 69 | int stbi_write_png_compression_level; // defaults to 8; set to higher for more compression 70 | int stbi_write_force_png_filter; // defaults to -1; set to 0..5 to force a filter mode 71 | 72 | 73 | You can define STBI_WRITE_NO_STDIO to disable the file variant of these 74 | functions, so the library will not use stdio.h at all. However, this will 75 | also disable HDR writing, because it requires stdio for formatted output. 76 | 77 | Each function returns 0 on failure and non-0 on success. 78 | 79 | The functions create an image file defined by the parameters. The image 80 | is a rectangle of pixels stored from left-to-right, top-to-bottom. 81 | Each pixel contains 'comp' channels of data stored interleaved with 8-bits 82 | per channel, in the following order: 1=Y, 2=YA, 3=RGB, 4=RGBA. (Y is 83 | monochrome color.) The rectangle is 'w' pixels wide and 'h' pixels tall. 84 | The *data pointer points to the first byte of the top-left-most pixel. 85 | For PNG, "stride_in_bytes" is the distance in bytes from the first byte of 86 | a row of pixels to the first byte of the next row of pixels. 87 | 88 | PNG creates output files with the same number of components as the input. 89 | The BMP format expands Y to RGB in the file format and does not 90 | output alpha. 91 | 92 | PNG supports writing rectangles of data even when the bytes storing rows of 93 | data are not consecutive in memory (e.g. sub-rectangles of a larger image), 94 | by supplying the stride between the beginning of adjacent rows. The other 95 | formats do not. (Thus you cannot write a native-format BMP through the BMP 96 | writer, both because it is in BGR order and because it may have padding 97 | at the end of the line.) 98 | 99 | PNG allows you to set the deflate compression level by setting the global 100 | variable 'stbi_write_png_compression_level' (it defaults to 8). 101 | 102 | HDR expects linear float data. Since the format is always 32-bit rgb(e) 103 | data, alpha (if provided) is discarded, and for monochrome data it is 104 | replicated across all three channels. 105 | 106 | TGA supports RLE or non-RLE compressed data. To use non-RLE-compressed 107 | data, set the global variable 'stbi_write_tga_with_rle' to 0. 108 | 109 | JPEG does ignore alpha channels in input data; quality is between 1 and 100. 110 | Higher quality looks better but results in a bigger image. 111 | JPEG baseline (no JPEG progressive). 112 | 113 | CREDITS: 114 | 115 | 116 | Sean Barrett - PNG/BMP/TGA 117 | Baldur Karlsson - HDR 118 | Jean-Sebastien Guay - TGA monochrome 119 | Tim Kelsey - misc enhancements 120 | Alan Hickman - TGA RLE 121 | Emmanuel Julien - initial file IO callback implementation 122 | Jon Olick - original jo_jpeg.cpp code 123 | Daniel Gibson - integrate JPEG, allow external zlib 124 | Aarni Koskela - allow choosing PNG filter 125 | 126 | bugfixes: 127 | github:Chribba 128 | Guillaume Chereau 129 | github:jry2 130 | github:romigrou 131 | Sergio Gonzalez 132 | Jonas Karlsson 133 | Filip Wasil 134 | Thatcher Ulrich 135 | github:poppolopoppo 136 | Patrick Boettcher 137 | github:xeekworx 138 | Cap Petschulat 139 | Simon Rodriguez 140 | Ivan Tikhonov 141 | github:ignotion 142 | Adam Schackart 143 | Andrew Kensler 144 | 145 | LICENSE 146 | 147 | See end of file for license information. 148 | 149 | */ 150 | 151 | #ifndef INCLUDE_STB_IMAGE_WRITE_H 152 | #define INCLUDE_STB_IMAGE_WRITE_H 153 | 154 | #include 155 | 156 | // if STB_IMAGE_WRITE_STATIC causes problems, try defining STBIWDEF to 'inline' or 'static inline' 157 | #ifndef STBIWDEF 158 | #ifdef STB_IMAGE_WRITE_STATIC 159 | #define STBIWDEF static 160 | #else 161 | #ifdef __cplusplus 162 | #define STBIWDEF extern "C" 163 | #else 164 | #define STBIWDEF extern 165 | #endif 166 | #endif 167 | #endif 168 | 169 | #ifndef STB_IMAGE_WRITE_STATIC // C++ forbids static forward declarations 170 | STBIWDEF int stbi_write_tga_with_rle; 171 | STBIWDEF int stbi_write_png_compression_level; 172 | STBIWDEF int stbi_write_force_png_filter; 173 | #endif 174 | 175 | #ifndef STBI_WRITE_NO_STDIO 176 | STBIWDEF int stbi_write_png(char const *filename, int w, int h, int comp, const void *data, int stride_in_bytes); 177 | STBIWDEF int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data); 178 | STBIWDEF int stbi_write_tga(char const *filename, int w, int h, int comp, const void *data); 179 | STBIWDEF int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data); 180 | STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality); 181 | 182 | #ifdef STBIW_WINDOWS_UTF8 183 | STBIWDEF int stbiw_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input); 184 | #endif 185 | #endif 186 | 187 | typedef void stbi_write_func(void *context, void *data, int size); 188 | 189 | STBIWDEF int stbi_write_png_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data, int stride_in_bytes); 190 | STBIWDEF int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data); 191 | STBIWDEF int stbi_write_tga_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data); 192 | STBIWDEF int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const float *data); 193 | STBIWDEF int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality); 194 | 195 | STBIWDEF void stbi_flip_vertically_on_write(int flip_boolean); 196 | 197 | #endif//INCLUDE_STB_IMAGE_WRITE_H 198 | 199 | #ifdef STB_IMAGE_WRITE_IMPLEMENTATION 200 | 201 | #ifdef _WIN32 202 | #ifndef _CRT_SECURE_NO_WARNINGS 203 | #define _CRT_SECURE_NO_WARNINGS 204 | #endif 205 | #ifndef _CRT_NONSTDC_NO_DEPRECATE 206 | #define _CRT_NONSTDC_NO_DEPRECATE 207 | #endif 208 | #endif 209 | 210 | #ifndef STBI_WRITE_NO_STDIO 211 | #include 212 | #endif // STBI_WRITE_NO_STDIO 213 | 214 | #include 215 | #include 216 | #include 217 | #include 218 | 219 | #if defined(STBIW_MALLOC) && defined(STBIW_FREE) && (defined(STBIW_REALLOC) || defined(STBIW_REALLOC_SIZED)) 220 | // ok 221 | #elif !defined(STBIW_MALLOC) && !defined(STBIW_FREE) && !defined(STBIW_REALLOC) && !defined(STBIW_REALLOC_SIZED) 222 | // ok 223 | #else 224 | #error "Must define all or none of STBIW_MALLOC, STBIW_FREE, and STBIW_REALLOC (or STBIW_REALLOC_SIZED)." 225 | #endif 226 | 227 | #ifndef STBIW_MALLOC 228 | #define STBIW_MALLOC(sz) malloc(sz) 229 | #define STBIW_REALLOC(p,newsz) realloc(p,newsz) 230 | #define STBIW_FREE(p) free(p) 231 | #endif 232 | 233 | #ifndef STBIW_REALLOC_SIZED 234 | #define STBIW_REALLOC_SIZED(p,oldsz,newsz) STBIW_REALLOC(p,newsz) 235 | #endif 236 | 237 | 238 | #ifndef STBIW_MEMMOVE 239 | #define STBIW_MEMMOVE(a,b,sz) memmove(a,b,sz) 240 | #endif 241 | 242 | 243 | #ifndef STBIW_ASSERT 244 | #include 245 | #define STBIW_ASSERT(x) assert(x) 246 | #endif 247 | 248 | #define STBIW_UCHAR(x) (unsigned char) ((x) & 0xff) 249 | 250 | #ifdef STB_IMAGE_WRITE_STATIC 251 | static int stbi_write_png_compression_level = 8; 252 | static int stbi_write_tga_with_rle = 1; 253 | static int stbi_write_force_png_filter = -1; 254 | #else 255 | int stbi_write_png_compression_level = 8; 256 | int stbi_write_tga_with_rle = 1; 257 | int stbi_write_force_png_filter = -1; 258 | #endif 259 | 260 | static int stbi__flip_vertically_on_write = 0; 261 | 262 | STBIWDEF void stbi_flip_vertically_on_write(int flag) 263 | { 264 | stbi__flip_vertically_on_write = flag; 265 | } 266 | 267 | typedef struct 268 | { 269 | stbi_write_func *func; 270 | void *context; 271 | unsigned char buffer[64]; 272 | int buf_used; 273 | } stbi__write_context; 274 | 275 | // initialize a callback-based context 276 | static void stbi__start_write_callbacks(stbi__write_context *s, stbi_write_func *c, void *context) 277 | { 278 | s->func = c; 279 | s->context = context; 280 | } 281 | 282 | #ifndef STBI_WRITE_NO_STDIO 283 | 284 | static void stbi__stdio_write(void *context, void *data, int size) 285 | { 286 | fwrite(data,1,size,(FILE*) context); 287 | } 288 | 289 | #if defined(_WIN32) && defined(STBIW_WINDOWS_UTF8) 290 | #ifdef __cplusplus 291 | #define STBIW_EXTERN extern "C" 292 | #else 293 | #define STBIW_EXTERN extern 294 | #endif 295 | STBIW_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char *str, int cbmb, wchar_t *widestr, int cchwide); 296 | STBIW_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, char *str, int cbmb, const char *defchar, int *used_default); 297 | 298 | STBIWDEF int stbiw_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input) 299 | { 300 | return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, (int) bufferlen, NULL, NULL); 301 | } 302 | #endif 303 | 304 | static FILE *stbiw__fopen(char const *filename, char const *mode) 305 | { 306 | FILE *f; 307 | #if defined(_WIN32) && defined(STBIW_WINDOWS_UTF8) 308 | wchar_t wMode[64]; 309 | wchar_t wFilename[1024]; 310 | if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, sizeof(wFilename)/sizeof(*wFilename))) 311 | return 0; 312 | 313 | if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode)/sizeof(*wMode))) 314 | return 0; 315 | 316 | #if defined(_MSC_VER) && _MSC_VER >= 1400 317 | if (0 != _wfopen_s(&f, wFilename, wMode)) 318 | f = 0; 319 | #else 320 | f = _wfopen(wFilename, wMode); 321 | #endif 322 | 323 | #elif defined(_MSC_VER) && _MSC_VER >= 1400 324 | if (0 != fopen_s(&f, filename, mode)) 325 | f=0; 326 | #else 327 | f = fopen(filename, mode); 328 | #endif 329 | return f; 330 | } 331 | 332 | static int stbi__start_write_file(stbi__write_context *s, const char *filename) 333 | { 334 | FILE *f = stbiw__fopen(filename, "wb"); 335 | stbi__start_write_callbacks(s, stbi__stdio_write, (void *) f); 336 | return f != NULL; 337 | } 338 | 339 | static void stbi__end_write_file(stbi__write_context *s) 340 | { 341 | fclose((FILE *)s->context); 342 | } 343 | 344 | #endif // !STBI_WRITE_NO_STDIO 345 | 346 | typedef unsigned int stbiw_uint32; 347 | typedef int stb_image_write_test[sizeof(stbiw_uint32)==4 ? 1 : -1]; 348 | 349 | static void stbiw__writefv(stbi__write_context *s, const char *fmt, va_list v) 350 | { 351 | while (*fmt) { 352 | switch (*fmt++) { 353 | case ' ': break; 354 | case '1': { unsigned char x = STBIW_UCHAR(va_arg(v, int)); 355 | s->func(s->context,&x,1); 356 | break; } 357 | case '2': { int x = va_arg(v,int); 358 | unsigned char b[2]; 359 | b[0] = STBIW_UCHAR(x); 360 | b[1] = STBIW_UCHAR(x>>8); 361 | s->func(s->context,b,2); 362 | break; } 363 | case '4': { stbiw_uint32 x = va_arg(v,int); 364 | unsigned char b[4]; 365 | b[0]=STBIW_UCHAR(x); 366 | b[1]=STBIW_UCHAR(x>>8); 367 | b[2]=STBIW_UCHAR(x>>16); 368 | b[3]=STBIW_UCHAR(x>>24); 369 | s->func(s->context,b,4); 370 | break; } 371 | default: 372 | STBIW_ASSERT(0); 373 | return; 374 | } 375 | } 376 | } 377 | 378 | static void stbiw__writef(stbi__write_context *s, const char *fmt, ...) 379 | { 380 | va_list v; 381 | va_start(v, fmt); 382 | stbiw__writefv(s, fmt, v); 383 | va_end(v); 384 | } 385 | 386 | static void stbiw__write_flush(stbi__write_context *s) 387 | { 388 | if (s->buf_used) { 389 | s->func(s->context, &s->buffer, s->buf_used); 390 | s->buf_used = 0; 391 | } 392 | } 393 | 394 | static void stbiw__putc(stbi__write_context *s, unsigned char c) 395 | { 396 | s->func(s->context, &c, 1); 397 | } 398 | 399 | static void stbiw__write1(stbi__write_context *s, unsigned char a) 400 | { 401 | if ((size_t)s->buf_used + 1 > sizeof(s->buffer)) 402 | stbiw__write_flush(s); 403 | s->buffer[s->buf_used++] = a; 404 | } 405 | 406 | static void stbiw__write3(stbi__write_context *s, unsigned char a, unsigned char b, unsigned char c) 407 | { 408 | int n; 409 | if ((size_t)s->buf_used + 3 > sizeof(s->buffer)) 410 | stbiw__write_flush(s); 411 | n = s->buf_used; 412 | s->buf_used = n+3; 413 | s->buffer[n+0] = a; 414 | s->buffer[n+1] = b; 415 | s->buffer[n+2] = c; 416 | } 417 | 418 | static void stbiw__write_pixel(stbi__write_context *s, int rgb_dir, int comp, int write_alpha, int expand_mono, unsigned char *d) 419 | { 420 | unsigned char bg[3] = { 255, 0, 255}, px[3]; 421 | int k; 422 | 423 | if (write_alpha < 0) 424 | stbiw__write1(s, d[comp - 1]); 425 | 426 | switch (comp) { 427 | case 2: // 2 pixels = mono + alpha, alpha is written separately, so same as 1-channel case 428 | case 1: 429 | if (expand_mono) 430 | stbiw__write3(s, d[0], d[0], d[0]); // monochrome bmp 431 | else 432 | stbiw__write1(s, d[0]); // monochrome TGA 433 | break; 434 | case 4: 435 | if (!write_alpha) { 436 | // composite against pink background 437 | for (k = 0; k < 3; ++k) 438 | px[k] = bg[k] + ((d[k] - bg[k]) * d[3]) / 255; 439 | stbiw__write3(s, px[1 - rgb_dir], px[1], px[1 + rgb_dir]); 440 | break; 441 | } 442 | /* FALLTHROUGH */ 443 | case 3: 444 | stbiw__write3(s, d[1 - rgb_dir], d[1], d[1 + rgb_dir]); 445 | break; 446 | } 447 | if (write_alpha > 0) 448 | stbiw__write1(s, d[comp - 1]); 449 | } 450 | 451 | static void stbiw__write_pixels(stbi__write_context *s, int rgb_dir, int vdir, int x, int y, int comp, void *data, int write_alpha, int scanline_pad, int expand_mono) 452 | { 453 | stbiw_uint32 zero = 0; 454 | int i,j, j_end; 455 | 456 | if (y <= 0) 457 | return; 458 | 459 | if (stbi__flip_vertically_on_write) 460 | vdir *= -1; 461 | 462 | if (vdir < 0) { 463 | j_end = -1; j = y-1; 464 | } else { 465 | j_end = y; j = 0; 466 | } 467 | 468 | for (; j != j_end; j += vdir) { 469 | for (i=0; i < x; ++i) { 470 | unsigned char *d = (unsigned char *) data + (j*x+i)*comp; 471 | stbiw__write_pixel(s, rgb_dir, comp, write_alpha, expand_mono, d); 472 | } 473 | stbiw__write_flush(s); 474 | s->func(s->context, &zero, scanline_pad); 475 | } 476 | } 477 | 478 | static int stbiw__outfile(stbi__write_context *s, int rgb_dir, int vdir, int x, int y, int comp, int expand_mono, void *data, int alpha, int pad, const char *fmt, ...) 479 | { 480 | if (y < 0 || x < 0) { 481 | return 0; 482 | } else { 483 | va_list v; 484 | va_start(v, fmt); 485 | stbiw__writefv(s, fmt, v); 486 | va_end(v); 487 | stbiw__write_pixels(s,rgb_dir,vdir,x,y,comp,data,alpha,pad, expand_mono); 488 | return 1; 489 | } 490 | } 491 | 492 | static int stbi_write_bmp_core(stbi__write_context *s, int x, int y, int comp, const void *data) 493 | { 494 | if (comp != 4) { 495 | // write RGB bitmap 496 | int pad = (-x*3) & 3; 497 | return stbiw__outfile(s,-1,-1,x,y,comp,1,(void *) data,0,pad, 498 | "11 4 22 4" "4 44 22 444444", 499 | 'B', 'M', 14+40+(x*3+pad)*y, 0,0, 14+40, // file header 500 | 40, x,y, 1,24, 0,0,0,0,0,0); // bitmap header 501 | } else { 502 | // RGBA bitmaps need a v4 header 503 | // use BI_BITFIELDS mode with 32bpp and alpha mask 504 | // (straight BI_RGB with alpha mask doesn't work in most readers) 505 | return stbiw__outfile(s,-1,-1,x,y,comp,1,(void *)data,1,0, 506 | "11 4 22 4" "4 44 22 444444 4444 4 444 444 444 444", 507 | 'B', 'M', 14+108+x*y*4, 0, 0, 14+108, // file header 508 | 108, x,y, 1,32, 3,0,0,0,0,0, 0xff0000,0xff00,0xff,0xff000000u, 0, 0,0,0, 0,0,0, 0,0,0, 0,0,0); // bitmap V4 header 509 | } 510 | } 511 | 512 | STBIWDEF int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data) 513 | { 514 | stbi__write_context s = { 0 }; 515 | stbi__start_write_callbacks(&s, func, context); 516 | return stbi_write_bmp_core(&s, x, y, comp, data); 517 | } 518 | 519 | #ifndef STBI_WRITE_NO_STDIO 520 | STBIWDEF int stbi_write_bmp(char const *filename, int x, int y, int comp, const void *data) 521 | { 522 | stbi__write_context s = { 0 }; 523 | if (stbi__start_write_file(&s,filename)) { 524 | int r = stbi_write_bmp_core(&s, x, y, comp, data); 525 | stbi__end_write_file(&s); 526 | return r; 527 | } else 528 | return 0; 529 | } 530 | #endif //!STBI_WRITE_NO_STDIO 531 | 532 | static int stbi_write_tga_core(stbi__write_context *s, int x, int y, int comp, void *data) 533 | { 534 | int has_alpha = (comp == 2 || comp == 4); 535 | int colorbytes = has_alpha ? comp-1 : comp; 536 | int format = colorbytes < 2 ? 3 : 2; // 3 color channels (RGB/RGBA) = 2, 1 color channel (Y/YA) = 3 537 | 538 | if (y < 0 || x < 0) 539 | return 0; 540 | 541 | if (!stbi_write_tga_with_rle) { 542 | return stbiw__outfile(s, -1, -1, x, y, comp, 0, (void *) data, has_alpha, 0, 543 | "111 221 2222 11", 0, 0, format, 0, 0, 0, 0, 0, x, y, (colorbytes + has_alpha) * 8, has_alpha * 8); 544 | } else { 545 | int i,j,k; 546 | int jend, jdir; 547 | 548 | stbiw__writef(s, "111 221 2222 11", 0,0,format+8, 0,0,0, 0,0,x,y, (colorbytes + has_alpha) * 8, has_alpha * 8); 549 | 550 | if (stbi__flip_vertically_on_write) { 551 | j = 0; 552 | jend = y; 553 | jdir = 1; 554 | } else { 555 | j = y-1; 556 | jend = -1; 557 | jdir = -1; 558 | } 559 | for (; j != jend; j += jdir) { 560 | unsigned char *row = (unsigned char *) data + j * x * comp; 561 | int len; 562 | 563 | for (i = 0; i < x; i += len) { 564 | unsigned char *begin = row + i * comp; 565 | int diff = 1; 566 | len = 1; 567 | 568 | if (i < x - 1) { 569 | ++len; 570 | diff = memcmp(begin, row + (i + 1) * comp, comp); 571 | if (diff) { 572 | const unsigned char *prev = begin; 573 | for (k = i + 2; k < x && len < 128; ++k) { 574 | if (memcmp(prev, row + k * comp, comp)) { 575 | prev += comp; 576 | ++len; 577 | } else { 578 | --len; 579 | break; 580 | } 581 | } 582 | } else { 583 | for (k = i + 2; k < x && len < 128; ++k) { 584 | if (!memcmp(begin, row + k * comp, comp)) { 585 | ++len; 586 | } else { 587 | break; 588 | } 589 | } 590 | } 591 | } 592 | 593 | if (diff) { 594 | unsigned char header = STBIW_UCHAR(len - 1); 595 | stbiw__write1(s, header); 596 | for (k = 0; k < len; ++k) { 597 | stbiw__write_pixel(s, -1, comp, has_alpha, 0, begin + k * comp); 598 | } 599 | } else { 600 | unsigned char header = STBIW_UCHAR(len - 129); 601 | stbiw__write1(s, header); 602 | stbiw__write_pixel(s, -1, comp, has_alpha, 0, begin); 603 | } 604 | } 605 | } 606 | stbiw__write_flush(s); 607 | } 608 | return 1; 609 | } 610 | 611 | STBIWDEF int stbi_write_tga_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data) 612 | { 613 | stbi__write_context s = { 0 }; 614 | stbi__start_write_callbacks(&s, func, context); 615 | return stbi_write_tga_core(&s, x, y, comp, (void *) data); 616 | } 617 | 618 | #ifndef STBI_WRITE_NO_STDIO 619 | STBIWDEF int stbi_write_tga(char const *filename, int x, int y, int comp, const void *data) 620 | { 621 | stbi__write_context s = { 0 }; 622 | if (stbi__start_write_file(&s,filename)) { 623 | int r = stbi_write_tga_core(&s, x, y, comp, (void *) data); 624 | stbi__end_write_file(&s); 625 | return r; 626 | } else 627 | return 0; 628 | } 629 | #endif 630 | 631 | // ************************************************************************************************* 632 | // Radiance RGBE HDR writer 633 | // by Baldur Karlsson 634 | 635 | #define stbiw__max(a, b) ((a) > (b) ? (a) : (b)) 636 | 637 | #ifndef STBI_WRITE_NO_STDIO 638 | 639 | static void stbiw__linear_to_rgbe(unsigned char *rgbe, float *linear) 640 | { 641 | int exponent; 642 | float maxcomp = stbiw__max(linear[0], stbiw__max(linear[1], linear[2])); 643 | 644 | if (maxcomp < 1e-32f) { 645 | rgbe[0] = rgbe[1] = rgbe[2] = rgbe[3] = 0; 646 | } else { 647 | float normalize = (float) frexp(maxcomp, &exponent) * 256.0f/maxcomp; 648 | 649 | rgbe[0] = (unsigned char)(linear[0] * normalize); 650 | rgbe[1] = (unsigned char)(linear[1] * normalize); 651 | rgbe[2] = (unsigned char)(linear[2] * normalize); 652 | rgbe[3] = (unsigned char)(exponent + 128); 653 | } 654 | } 655 | 656 | static void stbiw__write_run_data(stbi__write_context *s, int length, unsigned char databyte) 657 | { 658 | unsigned char lengthbyte = STBIW_UCHAR(length+128); 659 | STBIW_ASSERT(length+128 <= 255); 660 | s->func(s->context, &lengthbyte, 1); 661 | s->func(s->context, &databyte, 1); 662 | } 663 | 664 | static void stbiw__write_dump_data(stbi__write_context *s, int length, unsigned char *data) 665 | { 666 | unsigned char lengthbyte = STBIW_UCHAR(length); 667 | STBIW_ASSERT(length <= 128); // inconsistent with spec but consistent with official code 668 | s->func(s->context, &lengthbyte, 1); 669 | s->func(s->context, data, length); 670 | } 671 | 672 | static void stbiw__write_hdr_scanline(stbi__write_context *s, int width, int ncomp, unsigned char *scratch, float *scanline) 673 | { 674 | unsigned char scanlineheader[4] = { 2, 2, 0, 0 }; 675 | unsigned char rgbe[4]; 676 | float linear[3]; 677 | int x; 678 | 679 | scanlineheader[2] = (width&0xff00)>>8; 680 | scanlineheader[3] = (width&0x00ff); 681 | 682 | /* skip RLE for images too small or large */ 683 | if (width < 8 || width >= 32768) { 684 | for (x=0; x < width; x++) { 685 | switch (ncomp) { 686 | case 4: /* fallthrough */ 687 | case 3: linear[2] = scanline[x*ncomp + 2]; 688 | linear[1] = scanline[x*ncomp + 1]; 689 | linear[0] = scanline[x*ncomp + 0]; 690 | break; 691 | default: 692 | linear[0] = linear[1] = linear[2] = scanline[x*ncomp + 0]; 693 | break; 694 | } 695 | stbiw__linear_to_rgbe(rgbe, linear); 696 | s->func(s->context, rgbe, 4); 697 | } 698 | } else { 699 | int c,r; 700 | /* encode into scratch buffer */ 701 | for (x=0; x < width; x++) { 702 | switch(ncomp) { 703 | case 4: /* fallthrough */ 704 | case 3: linear[2] = scanline[x*ncomp + 2]; 705 | linear[1] = scanline[x*ncomp + 1]; 706 | linear[0] = scanline[x*ncomp + 0]; 707 | break; 708 | default: 709 | linear[0] = linear[1] = linear[2] = scanline[x*ncomp + 0]; 710 | break; 711 | } 712 | stbiw__linear_to_rgbe(rgbe, linear); 713 | scratch[x + width*0] = rgbe[0]; 714 | scratch[x + width*1] = rgbe[1]; 715 | scratch[x + width*2] = rgbe[2]; 716 | scratch[x + width*3] = rgbe[3]; 717 | } 718 | 719 | s->func(s->context, scanlineheader, 4); 720 | 721 | /* RLE each component separately */ 722 | for (c=0; c < 4; c++) { 723 | unsigned char *comp = &scratch[width*c]; 724 | 725 | x = 0; 726 | while (x < width) { 727 | // find first run 728 | r = x; 729 | while (r+2 < width) { 730 | if (comp[r] == comp[r+1] && comp[r] == comp[r+2]) 731 | break; 732 | ++r; 733 | } 734 | if (r+2 >= width) 735 | r = width; 736 | // dump up to first run 737 | while (x < r) { 738 | int len = r-x; 739 | if (len > 128) len = 128; 740 | stbiw__write_dump_data(s, len, &comp[x]); 741 | x += len; 742 | } 743 | // if there's a run, output it 744 | if (r+2 < width) { // same test as what we break out of in search loop, so only true if we break'd 745 | // find next byte after run 746 | while (r < width && comp[r] == comp[x]) 747 | ++r; 748 | // output run up to r 749 | while (x < r) { 750 | int len = r-x; 751 | if (len > 127) len = 127; 752 | stbiw__write_run_data(s, len, comp[x]); 753 | x += len; 754 | } 755 | } 756 | } 757 | } 758 | } 759 | } 760 | 761 | static int stbi_write_hdr_core(stbi__write_context *s, int x, int y, int comp, float *data) 762 | { 763 | if (y <= 0 || x <= 0 || data == NULL) 764 | return 0; 765 | else { 766 | // Each component is stored separately. Allocate scratch space for full output scanline. 767 | unsigned char *scratch = (unsigned char *) STBIW_MALLOC(x*4); 768 | int i, len; 769 | char buffer[128]; 770 | char header[] = "#?RADIANCE\n# Written by stb_image_write.h\nFORMAT=32-bit_rle_rgbe\n"; 771 | s->func(s->context, header, sizeof(header)-1); 772 | 773 | #ifdef __STDC_LIB_EXT1__ 774 | len = sprintf_s(buffer, sizeof(buffer), "EXPOSURE= 1.0000000000000\n\n-Y %d +X %d\n", y, x); 775 | #else 776 | len = sprintf(buffer, "EXPOSURE= 1.0000000000000\n\n-Y %d +X %d\n", y, x); 777 | #endif 778 | s->func(s->context, buffer, len); 779 | 780 | for(i=0; i < y; i++) 781 | stbiw__write_hdr_scanline(s, x, comp, scratch, data + comp*x*(stbi__flip_vertically_on_write ? y-1-i : i)); 782 | STBIW_FREE(scratch); 783 | return 1; 784 | } 785 | } 786 | 787 | STBIWDEF int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const float *data) 788 | { 789 | stbi__write_context s = { 0 }; 790 | stbi__start_write_callbacks(&s, func, context); 791 | return stbi_write_hdr_core(&s, x, y, comp, (float *) data); 792 | } 793 | 794 | STBIWDEF int stbi_write_hdr(char const *filename, int x, int y, int comp, const float *data) 795 | { 796 | stbi__write_context s = { 0 }; 797 | if (stbi__start_write_file(&s,filename)) { 798 | int r = stbi_write_hdr_core(&s, x, y, comp, (float *) data); 799 | stbi__end_write_file(&s); 800 | return r; 801 | } else 802 | return 0; 803 | } 804 | #endif // STBI_WRITE_NO_STDIO 805 | 806 | 807 | ////////////////////////////////////////////////////////////////////////////// 808 | // 809 | // PNG writer 810 | // 811 | 812 | #ifndef STBIW_ZLIB_COMPRESS 813 | // stretchy buffer; stbiw__sbpush() == vector<>::push_back() -- stbiw__sbcount() == vector<>::size() 814 | #define stbiw__sbraw(a) ((int *) (void *) (a) - 2) 815 | #define stbiw__sbm(a) stbiw__sbraw(a)[0] 816 | #define stbiw__sbn(a) stbiw__sbraw(a)[1] 817 | 818 | #define stbiw__sbneedgrow(a,n) ((a)==0 || stbiw__sbn(a)+n >= stbiw__sbm(a)) 819 | #define stbiw__sbmaybegrow(a,n) (stbiw__sbneedgrow(a,(n)) ? stbiw__sbgrow(a,n) : 0) 820 | #define stbiw__sbgrow(a,n) stbiw__sbgrowf((void **) &(a), (n), sizeof(*(a))) 821 | 822 | #define stbiw__sbpush(a, v) (stbiw__sbmaybegrow(a,1), (a)[stbiw__sbn(a)++] = (v)) 823 | #define stbiw__sbcount(a) ((a) ? stbiw__sbn(a) : 0) 824 | #define stbiw__sbfree(a) ((a) ? STBIW_FREE(stbiw__sbraw(a)),0 : 0) 825 | 826 | static void *stbiw__sbgrowf(void **arr, int increment, int itemsize) 827 | { 828 | int m = *arr ? 2*stbiw__sbm(*arr)+increment : increment+1; 829 | void *p = STBIW_REALLOC_SIZED(*arr ? stbiw__sbraw(*arr) : 0, *arr ? (stbiw__sbm(*arr)*itemsize + sizeof(int)*2) : 0, itemsize * m + sizeof(int)*2); 830 | STBIW_ASSERT(p); 831 | if (p) { 832 | if (!*arr) ((int *) p)[1] = 0; 833 | *arr = (void *) ((int *) p + 2); 834 | stbiw__sbm(*arr) = m; 835 | } 836 | return *arr; 837 | } 838 | 839 | static unsigned char *stbiw__zlib_flushf(unsigned char *data, unsigned int *bitbuffer, int *bitcount) 840 | { 841 | while (*bitcount >= 8) { 842 | stbiw__sbpush(data, STBIW_UCHAR(*bitbuffer)); 843 | *bitbuffer >>= 8; 844 | *bitcount -= 8; 845 | } 846 | return data; 847 | } 848 | 849 | static int stbiw__zlib_bitrev(int code, int codebits) 850 | { 851 | int res=0; 852 | while (codebits--) { 853 | res = (res << 1) | (code & 1); 854 | code >>= 1; 855 | } 856 | return res; 857 | } 858 | 859 | static unsigned int stbiw__zlib_countm(unsigned char *a, unsigned char *b, int limit) 860 | { 861 | int i; 862 | for (i=0; i < limit && i < 258; ++i) 863 | if (a[i] != b[i]) break; 864 | return i; 865 | } 866 | 867 | static unsigned int stbiw__zhash(unsigned char *data) 868 | { 869 | stbiw_uint32 hash = data[0] + (data[1] << 8) + (data[2] << 16); 870 | hash ^= hash << 3; 871 | hash += hash >> 5; 872 | hash ^= hash << 4; 873 | hash += hash >> 17; 874 | hash ^= hash << 25; 875 | hash += hash >> 6; 876 | return hash; 877 | } 878 | 879 | #define stbiw__zlib_flush() (out = stbiw__zlib_flushf(out, &bitbuf, &bitcount)) 880 | #define stbiw__zlib_add(code,codebits) \ 881 | (bitbuf |= (code) << bitcount, bitcount += (codebits), stbiw__zlib_flush()) 882 | #define stbiw__zlib_huffa(b,c) stbiw__zlib_add(stbiw__zlib_bitrev(b,c),c) 883 | // default huffman tables 884 | #define stbiw__zlib_huff1(n) stbiw__zlib_huffa(0x30 + (n), 8) 885 | #define stbiw__zlib_huff2(n) stbiw__zlib_huffa(0x190 + (n)-144, 9) 886 | #define stbiw__zlib_huff3(n) stbiw__zlib_huffa(0 + (n)-256,7) 887 | #define stbiw__zlib_huff4(n) stbiw__zlib_huffa(0xc0 + (n)-280,8) 888 | #define stbiw__zlib_huff(n) ((n) <= 143 ? stbiw__zlib_huff1(n) : (n) <= 255 ? stbiw__zlib_huff2(n) : (n) <= 279 ? stbiw__zlib_huff3(n) : stbiw__zlib_huff4(n)) 889 | #define stbiw__zlib_huffb(n) ((n) <= 143 ? stbiw__zlib_huff1(n) : stbiw__zlib_huff2(n)) 890 | 891 | #define stbiw__ZHASH 16384 892 | 893 | #endif // STBIW_ZLIB_COMPRESS 894 | 895 | STBIWDEF unsigned char * stbi_zlib_compress(unsigned char *data, int data_len, int *out_len, int quality) 896 | { 897 | #ifdef STBIW_ZLIB_COMPRESS 898 | // user provided a zlib compress implementation, use that 899 | return STBIW_ZLIB_COMPRESS(data, data_len, out_len, quality); 900 | #else // use builtin 901 | static unsigned short lengthc[] = { 3,4,5,6,7,8,9,10,11,13,15,17,19,23,27,31,35,43,51,59,67,83,99,115,131,163,195,227,258, 259 }; 902 | static unsigned char lengtheb[]= { 0,0,0,0,0,0,0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0 }; 903 | static unsigned short distc[] = { 1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193,257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577, 32768 }; 904 | static unsigned char disteb[] = { 0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13 }; 905 | unsigned int bitbuf=0; 906 | int i,j, bitcount=0; 907 | unsigned char *out = NULL; 908 | unsigned char ***hash_table = (unsigned char***) STBIW_MALLOC(stbiw__ZHASH * sizeof(unsigned char**)); 909 | if (hash_table == NULL) 910 | return NULL; 911 | if (quality < 5) quality = 5; 912 | 913 | stbiw__sbpush(out, 0x78); // DEFLATE 32K window 914 | stbiw__sbpush(out, 0x5e); // FLEVEL = 1 915 | stbiw__zlib_add(1,1); // BFINAL = 1 916 | stbiw__zlib_add(1,2); // BTYPE = 1 -- fixed huffman 917 | 918 | for (i=0; i < stbiw__ZHASH; ++i) 919 | hash_table[i] = NULL; 920 | 921 | i=0; 922 | while (i < data_len-3) { 923 | // hash next 3 bytes of data to be compressed 924 | int h = stbiw__zhash(data+i)&(stbiw__ZHASH-1), best=3; 925 | unsigned char *bestloc = 0; 926 | unsigned char **hlist = hash_table[h]; 927 | int n = stbiw__sbcount(hlist); 928 | for (j=0; j < n; ++j) { 929 | if (hlist[j]-data > i-32768) { // if entry lies within window 930 | int d = stbiw__zlib_countm(hlist[j], data+i, data_len-i); 931 | if (d >= best) { best=d; bestloc=hlist[j]; } 932 | } 933 | } 934 | // when hash table entry is too long, delete half the entries 935 | if (hash_table[h] && stbiw__sbn(hash_table[h]) == 2*quality) { 936 | STBIW_MEMMOVE(hash_table[h], hash_table[h]+quality, sizeof(hash_table[h][0])*quality); 937 | stbiw__sbn(hash_table[h]) = quality; 938 | } 939 | stbiw__sbpush(hash_table[h],data+i); 940 | 941 | if (bestloc) { 942 | // "lazy matching" - check match at *next* byte, and if it's better, do cur byte as literal 943 | h = stbiw__zhash(data+i+1)&(stbiw__ZHASH-1); 944 | hlist = hash_table[h]; 945 | n = stbiw__sbcount(hlist); 946 | for (j=0; j < n; ++j) { 947 | if (hlist[j]-data > i-32767) { 948 | int e = stbiw__zlib_countm(hlist[j], data+i+1, data_len-i-1); 949 | if (e > best) { // if next match is better, bail on current match 950 | bestloc = NULL; 951 | break; 952 | } 953 | } 954 | } 955 | } 956 | 957 | if (bestloc) { 958 | int d = (int) (data+i - bestloc); // distance back 959 | STBIW_ASSERT(d <= 32767 && best <= 258); 960 | for (j=0; best > lengthc[j+1]-1; ++j); 961 | stbiw__zlib_huff(j+257); 962 | if (lengtheb[j]) stbiw__zlib_add(best - lengthc[j], lengtheb[j]); 963 | for (j=0; d > distc[j+1]-1; ++j); 964 | stbiw__zlib_add(stbiw__zlib_bitrev(j,5),5); 965 | if (disteb[j]) stbiw__zlib_add(d - distc[j], disteb[j]); 966 | i += best; 967 | } else { 968 | stbiw__zlib_huffb(data[i]); 969 | ++i; 970 | } 971 | } 972 | // write out final bytes 973 | for (;i < data_len; ++i) 974 | stbiw__zlib_huffb(data[i]); 975 | stbiw__zlib_huff(256); // end of block 976 | // pad with 0 bits to byte boundary 977 | while (bitcount) 978 | stbiw__zlib_add(0,1); 979 | 980 | for (i=0; i < stbiw__ZHASH; ++i) 981 | (void) stbiw__sbfree(hash_table[i]); 982 | STBIW_FREE(hash_table); 983 | 984 | // store uncompressed instead if compression was worse 985 | if (stbiw__sbn(out) > data_len + 2 + ((data_len+32766)/32767)*5) { 986 | stbiw__sbn(out) = 2; // truncate to DEFLATE 32K window and FLEVEL = 1 987 | for (j = 0; j < data_len;) { 988 | int blocklen = data_len - j; 989 | if (blocklen > 32767) blocklen = 32767; 990 | stbiw__sbpush(out, data_len - j == blocklen); // BFINAL = ?, BTYPE = 0 -- no compression 991 | stbiw__sbpush(out, STBIW_UCHAR(blocklen)); // LEN 992 | stbiw__sbpush(out, STBIW_UCHAR(blocklen >> 8)); 993 | stbiw__sbpush(out, STBIW_UCHAR(~blocklen)); // NLEN 994 | stbiw__sbpush(out, STBIW_UCHAR(~blocklen >> 8)); 995 | memcpy(out+stbiw__sbn(out), data+j, blocklen); 996 | stbiw__sbn(out) += blocklen; 997 | j += blocklen; 998 | } 999 | } 1000 | 1001 | { 1002 | // compute adler32 on input 1003 | unsigned int s1=1, s2=0; 1004 | int blocklen = (int) (data_len % 5552); 1005 | j=0; 1006 | while (j < data_len) { 1007 | for (i=0; i < blocklen; ++i) { s1 += data[j+i]; s2 += s1; } 1008 | s1 %= 65521; s2 %= 65521; 1009 | j += blocklen; 1010 | blocklen = 5552; 1011 | } 1012 | stbiw__sbpush(out, STBIW_UCHAR(s2 >> 8)); 1013 | stbiw__sbpush(out, STBIW_UCHAR(s2)); 1014 | stbiw__sbpush(out, STBIW_UCHAR(s1 >> 8)); 1015 | stbiw__sbpush(out, STBIW_UCHAR(s1)); 1016 | } 1017 | *out_len = stbiw__sbn(out); 1018 | // make returned pointer freeable 1019 | STBIW_MEMMOVE(stbiw__sbraw(out), out, *out_len); 1020 | return (unsigned char *) stbiw__sbraw(out); 1021 | #endif // STBIW_ZLIB_COMPRESS 1022 | } 1023 | 1024 | static unsigned int stbiw__crc32(unsigned char *buffer, int len) 1025 | { 1026 | #ifdef STBIW_CRC32 1027 | return STBIW_CRC32(buffer, len); 1028 | #else 1029 | static unsigned int crc_table[256] = 1030 | { 1031 | 0x00000000, 0x77073096, 0xEE0E612C, 0x990951BA, 0x076DC419, 0x706AF48F, 0xE963A535, 0x9E6495A3, 1032 | 0x0eDB8832, 0x79DCB8A4, 0xE0D5E91E, 0x97D2D988, 0x09B64C2B, 0x7EB17CBD, 0xE7B82D07, 0x90BF1D91, 1033 | 0x1DB71064, 0x6AB020F2, 0xF3B97148, 0x84BE41DE, 0x1ADAD47D, 0x6DDDE4EB, 0xF4D4B551, 0x83D385C7, 1034 | 0x136C9856, 0x646BA8C0, 0xFD62F97A, 0x8A65C9EC, 0x14015C4F, 0x63066CD9, 0xFA0F3D63, 0x8D080DF5, 1035 | 0x3B6E20C8, 0x4C69105E, 0xD56041E4, 0xA2677172, 0x3C03E4D1, 0x4B04D447, 0xD20D85FD, 0xA50AB56B, 1036 | 0x35B5A8FA, 0x42B2986C, 0xDBBBC9D6, 0xACBCF940, 0x32D86CE3, 0x45DF5C75, 0xDCD60DCF, 0xABD13D59, 1037 | 0x26D930AC, 0x51DE003A, 0xC8D75180, 0xBFD06116, 0x21B4F4B5, 0x56B3C423, 0xCFBA9599, 0xB8BDA50F, 1038 | 0x2802B89E, 0x5F058808, 0xC60CD9B2, 0xB10BE924, 0x2F6F7C87, 0x58684C11, 0xC1611DAB, 0xB6662D3D, 1039 | 0x76DC4190, 0x01DB7106, 0x98D220BC, 0xEFD5102A, 0x71B18589, 0x06B6B51F, 0x9FBFE4A5, 0xE8B8D433, 1040 | 0x7807C9A2, 0x0F00F934, 0x9609A88E, 0xE10E9818, 0x7F6A0DBB, 0x086D3D2D, 0x91646C97, 0xE6635C01, 1041 | 0x6B6B51F4, 0x1C6C6162, 0x856530D8, 0xF262004E, 0x6C0695ED, 0x1B01A57B, 0x8208F4C1, 0xF50FC457, 1042 | 0x65B0D9C6, 0x12B7E950, 0x8BBEB8EA, 0xFCB9887C, 0x62DD1DDF, 0x15DA2D49, 0x8CD37CF3, 0xFBD44C65, 1043 | 0x4DB26158, 0x3AB551CE, 0xA3BC0074, 0xD4BB30E2, 0x4ADFA541, 0x3DD895D7, 0xA4D1C46D, 0xD3D6F4FB, 1044 | 0x4369E96A, 0x346ED9FC, 0xAD678846, 0xDA60B8D0, 0x44042D73, 0x33031DE5, 0xAA0A4C5F, 0xDD0D7CC9, 1045 | 0x5005713C, 0x270241AA, 0xBE0B1010, 0xC90C2086, 0x5768B525, 0x206F85B3, 0xB966D409, 0xCE61E49F, 1046 | 0x5EDEF90E, 0x29D9C998, 0xB0D09822, 0xC7D7A8B4, 0x59B33D17, 0x2EB40D81, 0xB7BD5C3B, 0xC0BA6CAD, 1047 | 0xEDB88320, 0x9ABFB3B6, 0x03B6E20C, 0x74B1D29A, 0xEAD54739, 0x9DD277AF, 0x04DB2615, 0x73DC1683, 1048 | 0xE3630B12, 0x94643B84, 0x0D6D6A3E, 0x7A6A5AA8, 0xE40ECF0B, 0x9309FF9D, 0x0A00AE27, 0x7D079EB1, 1049 | 0xF00F9344, 0x8708A3D2, 0x1E01F268, 0x6906C2FE, 0xF762575D, 0x806567CB, 0x196C3671, 0x6E6B06E7, 1050 | 0xFED41B76, 0x89D32BE0, 0x10DA7A5A, 0x67DD4ACC, 0xF9B9DF6F, 0x8EBEEFF9, 0x17B7BE43, 0x60B08ED5, 1051 | 0xD6D6A3E8, 0xA1D1937E, 0x38D8C2C4, 0x4FDFF252, 0xD1BB67F1, 0xA6BC5767, 0x3FB506DD, 0x48B2364B, 1052 | 0xD80D2BDA, 0xAF0A1B4C, 0x36034AF6, 0x41047A60, 0xDF60EFC3, 0xA867DF55, 0x316E8EEF, 0x4669BE79, 1053 | 0xCB61B38C, 0xBC66831A, 0x256FD2A0, 0x5268E236, 0xCC0C7795, 0xBB0B4703, 0x220216B9, 0x5505262F, 1054 | 0xC5BA3BBE, 0xB2BD0B28, 0x2BB45A92, 0x5CB36A04, 0xC2D7FFA7, 0xB5D0CF31, 0x2CD99E8B, 0x5BDEAE1D, 1055 | 0x9B64C2B0, 0xEC63F226, 0x756AA39C, 0x026D930A, 0x9C0906A9, 0xEB0E363F, 0x72076785, 0x05005713, 1056 | 0x95BF4A82, 0xE2B87A14, 0x7BB12BAE, 0x0CB61B38, 0x92D28E9B, 0xE5D5BE0D, 0x7CDCEFB7, 0x0BDBDF21, 1057 | 0x86D3D2D4, 0xF1D4E242, 0x68DDB3F8, 0x1FDA836E, 0x81BE16CD, 0xF6B9265B, 0x6FB077E1, 0x18B74777, 1058 | 0x88085AE6, 0xFF0F6A70, 0x66063BCA, 0x11010B5C, 0x8F659EFF, 0xF862AE69, 0x616BFFD3, 0x166CCF45, 1059 | 0xA00AE278, 0xD70DD2EE, 0x4E048354, 0x3903B3C2, 0xA7672661, 0xD06016F7, 0x4969474D, 0x3E6E77DB, 1060 | 0xAED16A4A, 0xD9D65ADC, 0x40DF0B66, 0x37D83BF0, 0xA9BCAE53, 0xDEBB9EC5, 0x47B2CF7F, 0x30B5FFE9, 1061 | 0xBDBDF21C, 0xCABAC28A, 0x53B39330, 0x24B4A3A6, 0xBAD03605, 0xCDD70693, 0x54DE5729, 0x23D967BF, 1062 | 0xB3667A2E, 0xC4614AB8, 0x5D681B02, 0x2A6F2B94, 0xB40BBE37, 0xC30C8EA1, 0x5A05DF1B, 0x2D02EF8D 1063 | }; 1064 | 1065 | unsigned int crc = ~0u; 1066 | int i; 1067 | for (i=0; i < len; ++i) 1068 | crc = (crc >> 8) ^ crc_table[buffer[i] ^ (crc & 0xff)]; 1069 | return ~crc; 1070 | #endif 1071 | } 1072 | 1073 | #define stbiw__wpng4(o,a,b,c,d) ((o)[0]=STBIW_UCHAR(a),(o)[1]=STBIW_UCHAR(b),(o)[2]=STBIW_UCHAR(c),(o)[3]=STBIW_UCHAR(d),(o)+=4) 1074 | #define stbiw__wp32(data,v) stbiw__wpng4(data, (v)>>24,(v)>>16,(v)>>8,(v)); 1075 | #define stbiw__wptag(data,s) stbiw__wpng4(data, s[0],s[1],s[2],s[3]) 1076 | 1077 | static void stbiw__wpcrc(unsigned char **data, int len) 1078 | { 1079 | unsigned int crc = stbiw__crc32(*data - len - 4, len+4); 1080 | stbiw__wp32(*data, crc); 1081 | } 1082 | 1083 | static unsigned char stbiw__paeth(int a, int b, int c) 1084 | { 1085 | int p = a + b - c, pa = abs(p-a), pb = abs(p-b), pc = abs(p-c); 1086 | if (pa <= pb && pa <= pc) return STBIW_UCHAR(a); 1087 | if (pb <= pc) return STBIW_UCHAR(b); 1088 | return STBIW_UCHAR(c); 1089 | } 1090 | 1091 | // @OPTIMIZE: provide an option that always forces left-predict or paeth predict 1092 | static void stbiw__encode_png_line(unsigned char *pixels, int stride_bytes, int width, int height, int y, int n, int filter_type, signed char *line_buffer) 1093 | { 1094 | static int mapping[] = { 0,1,2,3,4 }; 1095 | static int firstmap[] = { 0,1,0,5,6 }; 1096 | int *mymap = (y != 0) ? mapping : firstmap; 1097 | int i; 1098 | int type = mymap[filter_type]; 1099 | unsigned char *z = pixels + stride_bytes * (stbi__flip_vertically_on_write ? height-1-y : y); 1100 | int signed_stride = stbi__flip_vertically_on_write ? -stride_bytes : stride_bytes; 1101 | 1102 | if (type==0) { 1103 | memcpy(line_buffer, z, width*n); 1104 | return; 1105 | } 1106 | 1107 | // first loop isn't optimized since it's just one pixel 1108 | for (i = 0; i < n; ++i) { 1109 | switch (type) { 1110 | case 1: line_buffer[i] = z[i]; break; 1111 | case 2: line_buffer[i] = z[i] - z[i-signed_stride]; break; 1112 | case 3: line_buffer[i] = z[i] - (z[i-signed_stride]>>1); break; 1113 | case 4: line_buffer[i] = (signed char) (z[i] - stbiw__paeth(0,z[i-signed_stride],0)); break; 1114 | case 5: line_buffer[i] = z[i]; break; 1115 | case 6: line_buffer[i] = z[i]; break; 1116 | } 1117 | } 1118 | switch (type) { 1119 | case 1: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - z[i-n]; break; 1120 | case 2: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - z[i-signed_stride]; break; 1121 | case 3: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - ((z[i-n] + z[i-signed_stride])>>1); break; 1122 | case 4: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - stbiw__paeth(z[i-n], z[i-signed_stride], z[i-signed_stride-n]); break; 1123 | case 5: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - (z[i-n]>>1); break; 1124 | case 6: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - stbiw__paeth(z[i-n], 0,0); break; 1125 | } 1126 | } 1127 | 1128 | STBIWDEF unsigned char *stbi_write_png_to_mem(const unsigned char *pixels, int stride_bytes, int x, int y, int n, int *out_len) 1129 | { 1130 | int force_filter = stbi_write_force_png_filter; 1131 | int ctype[5] = { -1, 0, 4, 2, 6 }; 1132 | unsigned char sig[8] = { 137,80,78,71,13,10,26,10 }; 1133 | unsigned char *out,*o, *filt, *zlib; 1134 | signed char *line_buffer; 1135 | int j,zlen; 1136 | 1137 | if (stride_bytes == 0) 1138 | stride_bytes = x * n; 1139 | 1140 | if (force_filter >= 5) { 1141 | force_filter = -1; 1142 | } 1143 | 1144 | filt = (unsigned char *) STBIW_MALLOC((x*n+1) * y); if (!filt) return 0; 1145 | line_buffer = (signed char *) STBIW_MALLOC(x * n); if (!line_buffer) { STBIW_FREE(filt); return 0; } 1146 | for (j=0; j < y; ++j) { 1147 | int filter_type; 1148 | if (force_filter > -1) { 1149 | filter_type = force_filter; 1150 | stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, force_filter, line_buffer); 1151 | } else { // Estimate the best filter by running through all of them: 1152 | int best_filter = 0, best_filter_val = 0x7fffffff, est, i; 1153 | for (filter_type = 0; filter_type < 5; filter_type++) { 1154 | stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, filter_type, line_buffer); 1155 | 1156 | // Estimate the entropy of the line using this filter; the less, the better. 1157 | est = 0; 1158 | for (i = 0; i < x*n; ++i) { 1159 | est += abs((signed char) line_buffer[i]); 1160 | } 1161 | if (est < best_filter_val) { 1162 | best_filter_val = est; 1163 | best_filter = filter_type; 1164 | } 1165 | } 1166 | if (filter_type != best_filter) { // If the last iteration already got us the best filter, don't redo it 1167 | stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, best_filter, line_buffer); 1168 | filter_type = best_filter; 1169 | } 1170 | } 1171 | // when we get here, filter_type contains the filter type, and line_buffer contains the data 1172 | filt[j*(x*n+1)] = (unsigned char) filter_type; 1173 | STBIW_MEMMOVE(filt+j*(x*n+1)+1, line_buffer, x*n); 1174 | } 1175 | STBIW_FREE(line_buffer); 1176 | zlib = stbi_zlib_compress(filt, y*( x*n+1), &zlen, stbi_write_png_compression_level); 1177 | STBIW_FREE(filt); 1178 | if (!zlib) return 0; 1179 | 1180 | // each tag requires 12 bytes of overhead 1181 | out = (unsigned char *) STBIW_MALLOC(8 + 12+13 + 12+zlen + 12); 1182 | if (!out) return 0; 1183 | *out_len = 8 + 12+13 + 12+zlen + 12; 1184 | 1185 | o=out; 1186 | STBIW_MEMMOVE(o,sig,8); o+= 8; 1187 | stbiw__wp32(o, 13); // header length 1188 | stbiw__wptag(o, "IHDR"); 1189 | stbiw__wp32(o, x); 1190 | stbiw__wp32(o, y); 1191 | *o++ = 8; 1192 | *o++ = STBIW_UCHAR(ctype[n]); 1193 | *o++ = 0; 1194 | *o++ = 0; 1195 | *o++ = 0; 1196 | stbiw__wpcrc(&o,13); 1197 | 1198 | stbiw__wp32(o, zlen); 1199 | stbiw__wptag(o, "IDAT"); 1200 | STBIW_MEMMOVE(o, zlib, zlen); 1201 | o += zlen; 1202 | STBIW_FREE(zlib); 1203 | stbiw__wpcrc(&o, zlen); 1204 | 1205 | stbiw__wp32(o,0); 1206 | stbiw__wptag(o, "IEND"); 1207 | stbiw__wpcrc(&o,0); 1208 | 1209 | STBIW_ASSERT(o == out + *out_len); 1210 | 1211 | return out; 1212 | } 1213 | 1214 | #ifndef STBI_WRITE_NO_STDIO 1215 | STBIWDEF int stbi_write_png(char const *filename, int x, int y, int comp, const void *data, int stride_bytes) 1216 | { 1217 | FILE *f; 1218 | int len; 1219 | unsigned char *png = stbi_write_png_to_mem((const unsigned char *) data, stride_bytes, x, y, comp, &len); 1220 | if (png == NULL) return 0; 1221 | 1222 | f = stbiw__fopen(filename, "wb"); 1223 | if (!f) { STBIW_FREE(png); return 0; } 1224 | fwrite(png, 1, len, f); 1225 | fclose(f); 1226 | STBIW_FREE(png); 1227 | return 1; 1228 | } 1229 | #endif 1230 | 1231 | STBIWDEF int stbi_write_png_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int stride_bytes) 1232 | { 1233 | int len; 1234 | unsigned char *png = stbi_write_png_to_mem((const unsigned char *) data, stride_bytes, x, y, comp, &len); 1235 | if (png == NULL) return 0; 1236 | func(context, png, len); 1237 | STBIW_FREE(png); 1238 | return 1; 1239 | } 1240 | 1241 | 1242 | /* *************************************************************************** 1243 | * 1244 | * JPEG writer 1245 | * 1246 | * This is based on Jon Olick's jo_jpeg.cpp: 1247 | * public domain Simple, Minimalistic JPEG writer - http://www.jonolick.com/code.html 1248 | */ 1249 | 1250 | static const unsigned char stbiw__jpg_ZigZag[] = { 0,1,5,6,14,15,27,28,2,4,7,13,16,26,29,42,3,8,12,17,25,30,41,43,9,11,18, 1251 | 24,31,40,44,53,10,19,23,32,39,45,52,54,20,22,33,38,46,51,55,60,21,34,37,47,50,56,59,61,35,36,48,49,57,58,62,63 }; 1252 | 1253 | static void stbiw__jpg_writeBits(stbi__write_context *s, int *bitBufP, int *bitCntP, const unsigned short *bs) { 1254 | int bitBuf = *bitBufP, bitCnt = *bitCntP; 1255 | bitCnt += bs[1]; 1256 | bitBuf |= bs[0] << (24 - bitCnt); 1257 | while(bitCnt >= 8) { 1258 | unsigned char c = (bitBuf >> 16) & 255; 1259 | stbiw__putc(s, c); 1260 | if(c == 255) { 1261 | stbiw__putc(s, 0); 1262 | } 1263 | bitBuf <<= 8; 1264 | bitCnt -= 8; 1265 | } 1266 | *bitBufP = bitBuf; 1267 | *bitCntP = bitCnt; 1268 | } 1269 | 1270 | static void stbiw__jpg_DCT(float *d0p, float *d1p, float *d2p, float *d3p, float *d4p, float *d5p, float *d6p, float *d7p) { 1271 | float d0 = *d0p, d1 = *d1p, d2 = *d2p, d3 = *d3p, d4 = *d4p, d5 = *d5p, d6 = *d6p, d7 = *d7p; 1272 | float z1, z2, z3, z4, z5, z11, z13; 1273 | 1274 | float tmp0 = d0 + d7; 1275 | float tmp7 = d0 - d7; 1276 | float tmp1 = d1 + d6; 1277 | float tmp6 = d1 - d6; 1278 | float tmp2 = d2 + d5; 1279 | float tmp5 = d2 - d5; 1280 | float tmp3 = d3 + d4; 1281 | float tmp4 = d3 - d4; 1282 | 1283 | // Even part 1284 | float tmp10 = tmp0 + tmp3; // phase 2 1285 | float tmp13 = tmp0 - tmp3; 1286 | float tmp11 = tmp1 + tmp2; 1287 | float tmp12 = tmp1 - tmp2; 1288 | 1289 | d0 = tmp10 + tmp11; // phase 3 1290 | d4 = tmp10 - tmp11; 1291 | 1292 | z1 = (tmp12 + tmp13) * 0.707106781f; // c4 1293 | d2 = tmp13 + z1; // phase 5 1294 | d6 = tmp13 - z1; 1295 | 1296 | // Odd part 1297 | tmp10 = tmp4 + tmp5; // phase 2 1298 | tmp11 = tmp5 + tmp6; 1299 | tmp12 = tmp6 + tmp7; 1300 | 1301 | // The rotator is modified from fig 4-8 to avoid extra negations. 1302 | z5 = (tmp10 - tmp12) * 0.382683433f; // c6 1303 | z2 = tmp10 * 0.541196100f + z5; // c2-c6 1304 | z4 = tmp12 * 1.306562965f + z5; // c2+c6 1305 | z3 = tmp11 * 0.707106781f; // c4 1306 | 1307 | z11 = tmp7 + z3; // phase 5 1308 | z13 = tmp7 - z3; 1309 | 1310 | *d5p = z13 + z2; // phase 6 1311 | *d3p = z13 - z2; 1312 | *d1p = z11 + z4; 1313 | *d7p = z11 - z4; 1314 | 1315 | *d0p = d0; *d2p = d2; *d4p = d4; *d6p = d6; 1316 | } 1317 | 1318 | static void stbiw__jpg_calcBits(int val, unsigned short bits[2]) { 1319 | int tmp1 = val < 0 ? -val : val; 1320 | val = val < 0 ? val-1 : val; 1321 | bits[1] = 1; 1322 | while(tmp1 >>= 1) { 1323 | ++bits[1]; 1324 | } 1325 | bits[0] = val & ((1<0)&&(DU[end0pos]==0); --end0pos) { 1368 | } 1369 | // end0pos = first element in reverse order !=0 1370 | if(end0pos == 0) { 1371 | stbiw__jpg_writeBits(s, bitBuf, bitCnt, EOB); 1372 | return DU[0]; 1373 | } 1374 | for(i = 1; i <= end0pos; ++i) { 1375 | int startpos = i; 1376 | int nrzeroes; 1377 | unsigned short bits[2]; 1378 | for (; DU[i]==0 && i<=end0pos; ++i) { 1379 | } 1380 | nrzeroes = i-startpos; 1381 | if ( nrzeroes >= 16 ) { 1382 | int lng = nrzeroes>>4; 1383 | int nrmarker; 1384 | for (nrmarker=1; nrmarker <= lng; ++nrmarker) 1385 | stbiw__jpg_writeBits(s, bitBuf, bitCnt, M16zeroes); 1386 | nrzeroes &= 15; 1387 | } 1388 | stbiw__jpg_calcBits(DU[i], bits); 1389 | stbiw__jpg_writeBits(s, bitBuf, bitCnt, HTAC[(nrzeroes<<4)+bits[1]]); 1390 | stbiw__jpg_writeBits(s, bitBuf, bitCnt, bits); 1391 | } 1392 | if(end0pos != 63) { 1393 | stbiw__jpg_writeBits(s, bitBuf, bitCnt, EOB); 1394 | } 1395 | return DU[0]; 1396 | } 1397 | 1398 | static int stbi_write_jpg_core(stbi__write_context *s, int width, int height, int comp, const void* data, int quality) { 1399 | // Constants that don't pollute global namespace 1400 | static const unsigned char std_dc_luminance_nrcodes[] = {0,0,1,5,1,1,1,1,1,1,0,0,0,0,0,0,0}; 1401 | static const unsigned char std_dc_luminance_values[] = {0,1,2,3,4,5,6,7,8,9,10,11}; 1402 | static const unsigned char std_ac_luminance_nrcodes[] = {0,0,2,1,3,3,2,4,3,5,5,4,4,0,0,1,0x7d}; 1403 | static const unsigned char std_ac_luminance_values[] = { 1404 | 0x01,0x02,0x03,0x00,0x04,0x11,0x05,0x12,0x21,0x31,0x41,0x06,0x13,0x51,0x61,0x07,0x22,0x71,0x14,0x32,0x81,0x91,0xa1,0x08, 1405 | 0x23,0x42,0xb1,0xc1,0x15,0x52,0xd1,0xf0,0x24,0x33,0x62,0x72,0x82,0x09,0x0a,0x16,0x17,0x18,0x19,0x1a,0x25,0x26,0x27,0x28, 1406 | 0x29,0x2a,0x34,0x35,0x36,0x37,0x38,0x39,0x3a,0x43,0x44,0x45,0x46,0x47,0x48,0x49,0x4a,0x53,0x54,0x55,0x56,0x57,0x58,0x59, 1407 | 0x5a,0x63,0x64,0x65,0x66,0x67,0x68,0x69,0x6a,0x73,0x74,0x75,0x76,0x77,0x78,0x79,0x7a,0x83,0x84,0x85,0x86,0x87,0x88,0x89, 1408 | 0x8a,0x92,0x93,0x94,0x95,0x96,0x97,0x98,0x99,0x9a,0xa2,0xa3,0xa4,0xa5,0xa6,0xa7,0xa8,0xa9,0xaa,0xb2,0xb3,0xb4,0xb5,0xb6, 1409 | 0xb7,0xb8,0xb9,0xba,0xc2,0xc3,0xc4,0xc5,0xc6,0xc7,0xc8,0xc9,0xca,0xd2,0xd3,0xd4,0xd5,0xd6,0xd7,0xd8,0xd9,0xda,0xe1,0xe2, 1410 | 0xe3,0xe4,0xe5,0xe6,0xe7,0xe8,0xe9,0xea,0xf1,0xf2,0xf3,0xf4,0xf5,0xf6,0xf7,0xf8,0xf9,0xfa 1411 | }; 1412 | static const unsigned char std_dc_chrominance_nrcodes[] = {0,0,3,1,1,1,1,1,1,1,1,1,0,0,0,0,0}; 1413 | static const unsigned char std_dc_chrominance_values[] = {0,1,2,3,4,5,6,7,8,9,10,11}; 1414 | static const unsigned char std_ac_chrominance_nrcodes[] = {0,0,2,1,2,4,4,3,4,7,5,4,4,0,1,2,0x77}; 1415 | static const unsigned char std_ac_chrominance_values[] = { 1416 | 0x00,0x01,0x02,0x03,0x11,0x04,0x05,0x21,0x31,0x06,0x12,0x41,0x51,0x07,0x61,0x71,0x13,0x22,0x32,0x81,0x08,0x14,0x42,0x91, 1417 | 0xa1,0xb1,0xc1,0x09,0x23,0x33,0x52,0xf0,0x15,0x62,0x72,0xd1,0x0a,0x16,0x24,0x34,0xe1,0x25,0xf1,0x17,0x18,0x19,0x1a,0x26, 1418 | 0x27,0x28,0x29,0x2a,0x35,0x36,0x37,0x38,0x39,0x3a,0x43,0x44,0x45,0x46,0x47,0x48,0x49,0x4a,0x53,0x54,0x55,0x56,0x57,0x58, 1419 | 0x59,0x5a,0x63,0x64,0x65,0x66,0x67,0x68,0x69,0x6a,0x73,0x74,0x75,0x76,0x77,0x78,0x79,0x7a,0x82,0x83,0x84,0x85,0x86,0x87, 1420 | 0x88,0x89,0x8a,0x92,0x93,0x94,0x95,0x96,0x97,0x98,0x99,0x9a,0xa2,0xa3,0xa4,0xa5,0xa6,0xa7,0xa8,0xa9,0xaa,0xb2,0xb3,0xb4, 1421 | 0xb5,0xb6,0xb7,0xb8,0xb9,0xba,0xc2,0xc3,0xc4,0xc5,0xc6,0xc7,0xc8,0xc9,0xca,0xd2,0xd3,0xd4,0xd5,0xd6,0xd7,0xd8,0xd9,0xda, 1422 | 0xe2,0xe3,0xe4,0xe5,0xe6,0xe7,0xe8,0xe9,0xea,0xf2,0xf3,0xf4,0xf5,0xf6,0xf7,0xf8,0xf9,0xfa 1423 | }; 1424 | // Huffman tables 1425 | static const unsigned short YDC_HT[256][2] = { {0,2},{2,3},{3,3},{4,3},{5,3},{6,3},{14,4},{30,5},{62,6},{126,7},{254,8},{510,9}}; 1426 | static const unsigned short UVDC_HT[256][2] = { {0,2},{1,2},{2,2},{6,3},{14,4},{30,5},{62,6},{126,7},{254,8},{510,9},{1022,10},{2046,11}}; 1427 | static const unsigned short YAC_HT[256][2] = { 1428 | {10,4},{0,2},{1,2},{4,3},{11,4},{26,5},{120,7},{248,8},{1014,10},{65410,16},{65411,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1429 | {12,4},{27,5},{121,7},{502,9},{2038,11},{65412,16},{65413,16},{65414,16},{65415,16},{65416,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1430 | {28,5},{249,8},{1015,10},{4084,12},{65417,16},{65418,16},{65419,16},{65420,16},{65421,16},{65422,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1431 | {58,6},{503,9},{4085,12},{65423,16},{65424,16},{65425,16},{65426,16},{65427,16},{65428,16},{65429,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1432 | {59,6},{1016,10},{65430,16},{65431,16},{65432,16},{65433,16},{65434,16},{65435,16},{65436,16},{65437,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1433 | {122,7},{2039,11},{65438,16},{65439,16},{65440,16},{65441,16},{65442,16},{65443,16},{65444,16},{65445,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1434 | {123,7},{4086,12},{65446,16},{65447,16},{65448,16},{65449,16},{65450,16},{65451,16},{65452,16},{65453,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1435 | {250,8},{4087,12},{65454,16},{65455,16},{65456,16},{65457,16},{65458,16},{65459,16},{65460,16},{65461,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1436 | {504,9},{32704,15},{65462,16},{65463,16},{65464,16},{65465,16},{65466,16},{65467,16},{65468,16},{65469,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1437 | {505,9},{65470,16},{65471,16},{65472,16},{65473,16},{65474,16},{65475,16},{65476,16},{65477,16},{65478,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1438 | {506,9},{65479,16},{65480,16},{65481,16},{65482,16},{65483,16},{65484,16},{65485,16},{65486,16},{65487,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1439 | {1017,10},{65488,16},{65489,16},{65490,16},{65491,16},{65492,16},{65493,16},{65494,16},{65495,16},{65496,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1440 | {1018,10},{65497,16},{65498,16},{65499,16},{65500,16},{65501,16},{65502,16},{65503,16},{65504,16},{65505,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1441 | {2040,11},{65506,16},{65507,16},{65508,16},{65509,16},{65510,16},{65511,16},{65512,16},{65513,16},{65514,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1442 | {65515,16},{65516,16},{65517,16},{65518,16},{65519,16},{65520,16},{65521,16},{65522,16},{65523,16},{65524,16},{0,0},{0,0},{0,0},{0,0},{0,0}, 1443 | {2041,11},{65525,16},{65526,16},{65527,16},{65528,16},{65529,16},{65530,16},{65531,16},{65532,16},{65533,16},{65534,16},{0,0},{0,0},{0,0},{0,0},{0,0} 1444 | }; 1445 | static const unsigned short UVAC_HT[256][2] = { 1446 | {0,2},{1,2},{4,3},{10,4},{24,5},{25,5},{56,6},{120,7},{500,9},{1014,10},{4084,12},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1447 | {11,4},{57,6},{246,8},{501,9},{2038,11},{4085,12},{65416,16},{65417,16},{65418,16},{65419,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1448 | {26,5},{247,8},{1015,10},{4086,12},{32706,15},{65420,16},{65421,16},{65422,16},{65423,16},{65424,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1449 | {27,5},{248,8},{1016,10},{4087,12},{65425,16},{65426,16},{65427,16},{65428,16},{65429,16},{65430,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1450 | {58,6},{502,9},{65431,16},{65432,16},{65433,16},{65434,16},{65435,16},{65436,16},{65437,16},{65438,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1451 | {59,6},{1017,10},{65439,16},{65440,16},{65441,16},{65442,16},{65443,16},{65444,16},{65445,16},{65446,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1452 | {121,7},{2039,11},{65447,16},{65448,16},{65449,16},{65450,16},{65451,16},{65452,16},{65453,16},{65454,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1453 | {122,7},{2040,11},{65455,16},{65456,16},{65457,16},{65458,16},{65459,16},{65460,16},{65461,16},{65462,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1454 | {249,8},{65463,16},{65464,16},{65465,16},{65466,16},{65467,16},{65468,16},{65469,16},{65470,16},{65471,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1455 | {503,9},{65472,16},{65473,16},{65474,16},{65475,16},{65476,16},{65477,16},{65478,16},{65479,16},{65480,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1456 | {504,9},{65481,16},{65482,16},{65483,16},{65484,16},{65485,16},{65486,16},{65487,16},{65488,16},{65489,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1457 | {505,9},{65490,16},{65491,16},{65492,16},{65493,16},{65494,16},{65495,16},{65496,16},{65497,16},{65498,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1458 | {506,9},{65499,16},{65500,16},{65501,16},{65502,16},{65503,16},{65504,16},{65505,16},{65506,16},{65507,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1459 | {2041,11},{65508,16},{65509,16},{65510,16},{65511,16},{65512,16},{65513,16},{65514,16},{65515,16},{65516,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1460 | {16352,14},{65517,16},{65518,16},{65519,16},{65520,16},{65521,16},{65522,16},{65523,16},{65524,16},{65525,16},{0,0},{0,0},{0,0},{0,0},{0,0}, 1461 | {1018,10},{32707,15},{65526,16},{65527,16},{65528,16},{65529,16},{65530,16},{65531,16},{65532,16},{65533,16},{65534,16},{0,0},{0,0},{0,0},{0,0},{0,0} 1462 | }; 1463 | static const int YQT[] = {16,11,10,16,24,40,51,61,12,12,14,19,26,58,60,55,14,13,16,24,40,57,69,56,14,17,22,29,51,87,80,62,18,22, 1464 | 37,56,68,109,103,77,24,35,55,64,81,104,113,92,49,64,78,87,103,121,120,101,72,92,95,98,112,100,103,99}; 1465 | static const int UVQT[] = {17,18,24,47,99,99,99,99,18,21,26,66,99,99,99,99,24,26,56,99,99,99,99,99,47,66,99,99,99,99,99,99, 1466 | 99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99}; 1467 | static const float aasf[] = { 1.0f * 2.828427125f, 1.387039845f * 2.828427125f, 1.306562965f * 2.828427125f, 1.175875602f * 2.828427125f, 1468 | 1.0f * 2.828427125f, 0.785694958f * 2.828427125f, 0.541196100f * 2.828427125f, 0.275899379f * 2.828427125f }; 1469 | 1470 | int row, col, i, k, subsample; 1471 | float fdtbl_Y[64], fdtbl_UV[64]; 1472 | unsigned char YTable[64], UVTable[64]; 1473 | 1474 | if(!data || !width || !height || comp > 4 || comp < 1) { 1475 | return 0; 1476 | } 1477 | 1478 | quality = quality ? quality : 90; 1479 | subsample = quality <= 90 ? 1 : 0; 1480 | quality = quality < 1 ? 1 : quality > 100 ? 100 : quality; 1481 | quality = quality < 50 ? 5000 / quality : 200 - quality * 2; 1482 | 1483 | for(i = 0; i < 64; ++i) { 1484 | int uvti, yti = (YQT[i]*quality+50)/100; 1485 | YTable[stbiw__jpg_ZigZag[i]] = (unsigned char) (yti < 1 ? 1 : yti > 255 ? 255 : yti); 1486 | uvti = (UVQT[i]*quality+50)/100; 1487 | UVTable[stbiw__jpg_ZigZag[i]] = (unsigned char) (uvti < 1 ? 1 : uvti > 255 ? 255 : uvti); 1488 | } 1489 | 1490 | for(row = 0, k = 0; row < 8; ++row) { 1491 | for(col = 0; col < 8; ++col, ++k) { 1492 | fdtbl_Y[k] = 1 / (YTable [stbiw__jpg_ZigZag[k]] * aasf[row] * aasf[col]); 1493 | fdtbl_UV[k] = 1 / (UVTable[stbiw__jpg_ZigZag[k]] * aasf[row] * aasf[col]); 1494 | } 1495 | } 1496 | 1497 | // Write Headers 1498 | { 1499 | static const unsigned char head0[] = { 0xFF,0xD8,0xFF,0xE0,0,0x10,'J','F','I','F',0,1,1,0,0,1,0,1,0,0,0xFF,0xDB,0,0x84,0 }; 1500 | static const unsigned char head2[] = { 0xFF,0xDA,0,0xC,3,1,0,2,0x11,3,0x11,0,0x3F,0 }; 1501 | const unsigned char head1[] = { 0xFF,0xC0,0,0x11,8,(unsigned char)(height>>8),STBIW_UCHAR(height),(unsigned char)(width>>8),STBIW_UCHAR(width), 1502 | 3,1,(unsigned char)(subsample?0x22:0x11),0,2,0x11,1,3,0x11,1,0xFF,0xC4,0x01,0xA2,0 }; 1503 | s->func(s->context, (void*)head0, sizeof(head0)); 1504 | s->func(s->context, (void*)YTable, sizeof(YTable)); 1505 | stbiw__putc(s, 1); 1506 | s->func(s->context, UVTable, sizeof(UVTable)); 1507 | s->func(s->context, (void*)head1, sizeof(head1)); 1508 | s->func(s->context, (void*)(std_dc_luminance_nrcodes+1), sizeof(std_dc_luminance_nrcodes)-1); 1509 | s->func(s->context, (void*)std_dc_luminance_values, sizeof(std_dc_luminance_values)); 1510 | stbiw__putc(s, 0x10); // HTYACinfo 1511 | s->func(s->context, (void*)(std_ac_luminance_nrcodes+1), sizeof(std_ac_luminance_nrcodes)-1); 1512 | s->func(s->context, (void*)std_ac_luminance_values, sizeof(std_ac_luminance_values)); 1513 | stbiw__putc(s, 1); // HTUDCinfo 1514 | s->func(s->context, (void*)(std_dc_chrominance_nrcodes+1), sizeof(std_dc_chrominance_nrcodes)-1); 1515 | s->func(s->context, (void*)std_dc_chrominance_values, sizeof(std_dc_chrominance_values)); 1516 | stbiw__putc(s, 0x11); // HTUACinfo 1517 | s->func(s->context, (void*)(std_ac_chrominance_nrcodes+1), sizeof(std_ac_chrominance_nrcodes)-1); 1518 | s->func(s->context, (void*)std_ac_chrominance_values, sizeof(std_ac_chrominance_values)); 1519 | s->func(s->context, (void*)head2, sizeof(head2)); 1520 | } 1521 | 1522 | // Encode 8x8 macroblocks 1523 | { 1524 | static const unsigned short fillBits[] = {0x7F, 7}; 1525 | int DCY=0, DCU=0, DCV=0; 1526 | int bitBuf=0, bitCnt=0; 1527 | // comp == 2 is grey+alpha (alpha is ignored) 1528 | int ofsG = comp > 2 ? 1 : 0, ofsB = comp > 2 ? 2 : 0; 1529 | const unsigned char *dataR = (const unsigned char *)data; 1530 | const unsigned char *dataG = dataR + ofsG; 1531 | const unsigned char *dataB = dataR + ofsB; 1532 | int x, y, pos; 1533 | if(subsample) { 1534 | for(y = 0; y < height; y += 16) { 1535 | for(x = 0; x < width; x += 16) { 1536 | float Y[256], U[256], V[256]; 1537 | for(row = y, pos = 0; row < y+16; ++row) { 1538 | // row >= height => use last input row 1539 | int clamped_row = (row < height) ? row : height - 1; 1540 | int base_p = (stbi__flip_vertically_on_write ? (height-1-clamped_row) : clamped_row)*width*comp; 1541 | for(col = x; col < x+16; ++col, ++pos) { 1542 | // if col >= width => use pixel from last input column 1543 | int p = base_p + ((col < width) ? col : (width-1))*comp; 1544 | float r = dataR[p], g = dataG[p], b = dataB[p]; 1545 | Y[pos]= +0.29900f*r + 0.58700f*g + 0.11400f*b - 128; 1546 | U[pos]= -0.16874f*r - 0.33126f*g + 0.50000f*b; 1547 | V[pos]= +0.50000f*r - 0.41869f*g - 0.08131f*b; 1548 | } 1549 | } 1550 | DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+0, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT); 1551 | DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+8, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT); 1552 | DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+128, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT); 1553 | DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+136, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT); 1554 | 1555 | // subsample U,V 1556 | { 1557 | float subU[64], subV[64]; 1558 | int yy, xx; 1559 | for(yy = 0, pos = 0; yy < 8; ++yy) { 1560 | for(xx = 0; xx < 8; ++xx, ++pos) { 1561 | int j = yy*32+xx*2; 1562 | subU[pos] = (U[j+0] + U[j+1] + U[j+16] + U[j+17]) * 0.25f; 1563 | subV[pos] = (V[j+0] + V[j+1] + V[j+16] + V[j+17]) * 0.25f; 1564 | } 1565 | } 1566 | DCU = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, subU, 8, fdtbl_UV, DCU, UVDC_HT, UVAC_HT); 1567 | DCV = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, subV, 8, fdtbl_UV, DCV, UVDC_HT, UVAC_HT); 1568 | } 1569 | } 1570 | } 1571 | } else { 1572 | for(y = 0; y < height; y += 8) { 1573 | for(x = 0; x < width; x += 8) { 1574 | float Y[64], U[64], V[64]; 1575 | for(row = y, pos = 0; row < y+8; ++row) { 1576 | // row >= height => use last input row 1577 | int clamped_row = (row < height) ? row : height - 1; 1578 | int base_p = (stbi__flip_vertically_on_write ? (height-1-clamped_row) : clamped_row)*width*comp; 1579 | for(col = x; col < x+8; ++col, ++pos) { 1580 | // if col >= width => use pixel from last input column 1581 | int p = base_p + ((col < width) ? col : (width-1))*comp; 1582 | float r = dataR[p], g = dataG[p], b = dataB[p]; 1583 | Y[pos]= +0.29900f*r + 0.58700f*g + 0.11400f*b - 128; 1584 | U[pos]= -0.16874f*r - 0.33126f*g + 0.50000f*b; 1585 | V[pos]= +0.50000f*r - 0.41869f*g - 0.08131f*b; 1586 | } 1587 | } 1588 | 1589 | DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y, 8, fdtbl_Y, DCY, YDC_HT, YAC_HT); 1590 | DCU = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, U, 8, fdtbl_UV, DCU, UVDC_HT, UVAC_HT); 1591 | DCV = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, V, 8, fdtbl_UV, DCV, UVDC_HT, UVAC_HT); 1592 | } 1593 | } 1594 | } 1595 | 1596 | // Do the bit alignment of the EOI marker 1597 | stbiw__jpg_writeBits(s, &bitBuf, &bitCnt, fillBits); 1598 | } 1599 | 1600 | // EOI 1601 | stbiw__putc(s, 0xFF); 1602 | stbiw__putc(s, 0xD9); 1603 | 1604 | return 1; 1605 | } 1606 | 1607 | STBIWDEF int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality) 1608 | { 1609 | stbi__write_context s = { 0 }; 1610 | stbi__start_write_callbacks(&s, func, context); 1611 | return stbi_write_jpg_core(&s, x, y, comp, (void *) data, quality); 1612 | } 1613 | 1614 | 1615 | #ifndef STBI_WRITE_NO_STDIO 1616 | STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality) 1617 | { 1618 | stbi__write_context s = { 0 }; 1619 | if (stbi__start_write_file(&s,filename)) { 1620 | int r = stbi_write_jpg_core(&s, x, y, comp, data, quality); 1621 | stbi__end_write_file(&s); 1622 | return r; 1623 | } else 1624 | return 0; 1625 | } 1626 | #endif 1627 | 1628 | #endif // STB_IMAGE_WRITE_IMPLEMENTATION 1629 | 1630 | /* Revision history 1631 | 1.16 (2021-07-11) 1632 | make Deflate code emit uncompressed blocks when it would otherwise expand 1633 | support writing BMPs with alpha channel 1634 | 1.15 (2020-07-13) unknown 1635 | 1.14 (2020-02-02) updated JPEG writer to downsample chroma channels 1636 | 1.13 1637 | 1.12 1638 | 1.11 (2019-08-11) 1639 | 1640 | 1.10 (2019-02-07) 1641 | support utf8 filenames in Windows; fix warnings and platform ifdefs 1642 | 1.09 (2018-02-11) 1643 | fix typo in zlib quality API, improve STB_I_W_STATIC in C++ 1644 | 1.08 (2018-01-29) 1645 | add stbi__flip_vertically_on_write, external zlib, zlib quality, choose PNG filter 1646 | 1.07 (2017-07-24) 1647 | doc fix 1648 | 1.06 (2017-07-23) 1649 | writing JPEG (using Jon Olick's code) 1650 | 1.05 ??? 1651 | 1.04 (2017-03-03) 1652 | monochrome BMP expansion 1653 | 1.03 ??? 1654 | 1.02 (2016-04-02) 1655 | avoid allocating large structures on the stack 1656 | 1.01 (2016-01-16) 1657 | STBIW_REALLOC_SIZED: support allocators with no realloc support 1658 | avoid race-condition in crc initialization 1659 | minor compile issues 1660 | 1.00 (2015-09-14) 1661 | installable file IO function 1662 | 0.99 (2015-09-13) 1663 | warning fixes; TGA rle support 1664 | 0.98 (2015-04-08) 1665 | added STBIW_MALLOC, STBIW_ASSERT etc 1666 | 0.97 (2015-01-18) 1667 | fixed HDR asserts, rewrote HDR rle logic 1668 | 0.96 (2015-01-17) 1669 | add HDR output 1670 | fix monochrome BMP 1671 | 0.95 (2014-08-17) 1672 | add monochrome TGA output 1673 | 0.94 (2014-05-31) 1674 | rename private functions to avoid conflicts with stb_image.h 1675 | 0.93 (2014-05-27) 1676 | warning fixes 1677 | 0.92 (2010-08-01) 1678 | casts to unsigned char to fix warnings 1679 | 0.91 (2010-07-17) 1680 | first public release 1681 | 0.90 first internal release 1682 | */ 1683 | 1684 | /* 1685 | ------------------------------------------------------------------------------ 1686 | This software is available under 2 licenses -- choose whichever you prefer. 1687 | ------------------------------------------------------------------------------ 1688 | ALTERNATIVE A - MIT License 1689 | Copyright (c) 2017 Sean Barrett 1690 | Permission is hereby granted, free of charge, to any person obtaining a copy of 1691 | this software and associated documentation files (the "Software"), to deal in 1692 | the Software without restriction, including without limitation the rights to 1693 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 1694 | of the Software, and to permit persons to whom the Software is furnished to do 1695 | so, subject to the following conditions: 1696 | The above copyright notice and this permission notice shall be included in all 1697 | copies or substantial portions of the Software. 1698 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 1699 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 1700 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 1701 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 1702 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 1703 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 1704 | SOFTWARE. 1705 | ------------------------------------------------------------------------------ 1706 | ALTERNATIVE B - Public Domain (www.unlicense.org) 1707 | This is free and unencumbered software released into the public domain. 1708 | Anyone is free to copy, modify, publish, use, compile, sell, or distribute this 1709 | software, either in source code form or as a compiled binary, for any purpose, 1710 | commercial or non-commercial, and by any means. 1711 | In jurisdictions that recognize copyright laws, the author or authors of this 1712 | software dedicate any and all copyright interest in the software to the public 1713 | domain. We make this dedication for the benefit of the public at large and to 1714 | the detriment of our heirs and successors. We intend this dedication to be an 1715 | overt act of relinquishment in perpetuity of all present and future rights to 1716 | this software under copyright law. 1717 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 1718 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 1719 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 1720 | AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN 1721 | ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 1722 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 1723 | ------------------------------------------------------------------------------ 1724 | */ --------------------------------------------------------------------------------