├── .gitignore
├── .gitmodules
├── CMakeLists.txt
├── LICENSE.md
├── README.md
├── cuda_rasterizer
│   ├── auxiliary.h
│   ├── backward.cu
│   ├── backward.h
│   ├── config.h
│   ├── forward.cu
│   ├── forward.h
│   ├── rasterizer.h
│   ├── rasterizer_impl.cu
│   └── rasterizer_impl.h
├── diff_surfel_rasterization
│   └── __init__.py
├── ext.cpp
├── rasterize_points.cu
├── rasterize_points.h
├── setup.py
└── third_party
    └── stbi_image_write.h

/.gitignore:
--------------------------------------------------------------------------------
1 | build/
2 | diff_surfel_rasterization.egg-info/
3 | dist/
4 | *.so
5 | .vscode
6 | **__pycache__**
7 | *.pyc

--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "third_party/glm"]
2 | path = third_party/glm
3 | url = https://github.com/g-truc/glm.git
4 | 

--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | cmake_minimum_required(VERSION 3.20) 13 | 14 | project(DiffRast LANGUAGES CUDA CXX) 15 | 16 | set(CMAKE_CXX_STANDARD 17) 17 | set(CMAKE_CXX_EXTENSIONS OFF) 18 | set(CMAKE_CUDA_STANDARD 17) 19 | 20 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") 21 | 22 | add_library(CudaRasterizer 23 | cuda_rasterizer/backward.h 24 | cuda_rasterizer/backward.cu 25 | cuda_rasterizer/forward.h 26 | cuda_rasterizer/forward.cu 27 | cuda_rasterizer/auxiliary.h 28 | cuda_rasterizer/rasterizer_impl.cu 29 | cuda_rasterizer/rasterizer_impl.h 30 | cuda_rasterizer/rasterizer.h 31 | ) 32 | 33 | set_target_properties(CudaRasterizer PROPERTIES CUDA_ARCHITECTURES "70;75;86") 34 | 35 | target_include_directories(CudaRasterizer PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cuda_rasterizer) 36 | target_include_directories(CudaRasterizer PRIVATE third_party/glm ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) 37 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Gaussian-Splatting License 2 | =========================== 3 | 4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**. 5 | The *Software* is in the process of being registered with the Agence pour la Protection des 6 | Programmes (APP). 7 | 8 | The *Software* is still being developed by the *Licensor*. 9 | 10 | *Licensor*'s goal is to allow the research community to use, test and evaluate 11 | the *Software*. 12 | 13 | ## 1. Definitions 14 | 15 | *Licensee* means any person or entity that uses the *Software* and distributes 16 | its *Work*. 17 | 18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII 19 | 20 | *Software* means the original work of authorship made available under this 21 | License ie gaussian-splatting. 
22 | 
23 | *Work* means the *Software* and any additions to or derivative works of the
24 | *Software* that are made available under this License.
25 | 
26 | 
27 | ## 2. Purpose
28 | This license is intended to define the rights granted to the *Licensee* by
29 | Licensors under the *Software*.
30 | 
31 | ## 3. Rights granted
32 | 
33 | For the above reasons Licensors have decided to distribute the *Software*.
34 | Licensors grant non-exclusive rights to use the *Software* for research purposes
35 | to research users (both academic and industrial), free of charge, without right
36 | to sublicense. The *Software* may be used "non-commercially", i.e., for research
37 | and/or evaluation purposes only.
38 | 
39 | Subject to the terms and conditions of this License, you are granted a
40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of,
41 | publicly display, publicly perform and distribute its *Work* and any resulting
42 | derivative works in any form.
43 | 
44 | ## 4. Limitations
45 | 
46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do
47 | so under this License, (b) you include a complete copy of this License with
48 | your distribution, and (c) you retain without modification any copyright,
49 | patent, trademark, or attribution notices that are present in the *Work*.
50 | 
51 | **4.2 Derivative Works.** You may specify that additional or different terms apply
52 | to the use, reproduction, and distribution of your derivative works of the *Work*
53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in
54 | Section 2 applies to your derivative works, and (b) you identify the specific
55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms,
56 | this License (including the redistribution requirements in Section 3.1) will
57 | continue to apply to the *Work* itself.
58 | 
59 | **4.3** Any other use without prior consent of Licensors is prohibited.
Research
60 | users explicitly acknowledge having received from Licensors all information
61 | allowing them to appreciate the adequacy between the *Software* and their needs and
62 | to undertake all necessary precautions for its execution and use.
63 | 
64 | **4.4** The *Software* is provided both as a compiled library file and as source
65 | code. In case of using the *Software* for a publication or other results obtained
66 | through the use of the *Software*, users are strongly encouraged to cite the
67 | corresponding publications as explained in the documentation of the *Software*.
68 | 
69 | ## 5. Disclaimer
70 | 
71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES
72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY
73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL
74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES
75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL
76 | USE, PROFESSIONAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR
77 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE
78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION)
81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR
83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*.
84 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Differential Surfel Rasterization
2 | 
3 | This is the rasterization engine for the paper "2D Gaussian Splatting for Geometrically Accurate Radiance Fields". If you can make use of it in your own research, please be so kind as to cite us.
4 | 
5 | 
6 | <details>
7 | <summary>BibTeX</summary>
8 | <pre><code>@inproceedings{Huang2DGS2024,
9 |     title={2D Gaussian Splatting for Geometrically Accurate Radiance Fields},
10 |     author={Huang, Binbin and Yu, Zehao and Chen, Anpei and Geiger, Andreas and Gao, Shenghua},
11 |     publisher = {Association for Computing Machinery},
12 |     booktitle = {SIGGRAPH 2024 Conference Papers},
13 |     year      = {2024},
14 |     doi       = {10.1145/3641519.3657428}
15 | }
16 | </code></pre>
17 | </details>
18 | 
19 | 
20 | <details>
21 | <summary>BibTeX</summary>
22 | <pre><code>@Article{kerbl3Dgaussians,
23 |       author       = {Kerbl, Bernhard and Kopanas, Georgios and Leimk{\"u}hler, Thomas and Drettakis, George},
24 |       title        = {3D Gaussian Splatting for Real-Time Radiance Field Rendering},
25 |       journal      = {ACM Transactions on Graphics},
26 |       number       = {4},
27 |       volume       = {42},
28 |       month        = {July},
29 |       year         = {2023},
30 |       url          = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/}
31 | }
32 | </code></pre>
33 | </details>
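
The extension is built by `setup.py` as a PyTorch CUDA extension. As an editorial sketch (not part of the original README; the clone URL is not stated in this file, so placeholders are used, and a working CUDA toolchain plus a PyTorch install are assumed), a typical setup looks like:

```shell
# Clone with --recursive so the glm submodule (declared in .gitmodules) is fetched.
git clone --recursive <repository-url>
cd <repository-directory>

# Build and install the diff_surfel_rasterization CUDA extension.
pip install .
```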
-------------------------------------------------------------------------------- /cuda_rasterizer/auxiliary.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_AUXILIARY_H_INCLUDED 13 | #define CUDA_RASTERIZER_AUXILIARY_H_INCLUDED 14 | 15 | #include "config.h" 16 | #include "stdio.h" 17 | 18 | #define BLOCK_SIZE (BLOCK_X * BLOCK_Y) 19 | #define NUM_WARPS (BLOCK_SIZE/32) 20 | 21 | #define TIGHTBBOX 0 22 | #define RENDER_AXUTILITY 1 23 | #define DEPTH_OFFSET 0 24 | #define ALPHA_OFFSET 1 25 | #define NORMAL_OFFSET 2 26 | #define MIDDEPTH_OFFSET 5 27 | #define DISTORTION_OFFSET 6 28 | // #define MEDIAN_WEIGHT_OFFSET 7 29 | 30 | // distortion helper macros 31 | #define BACKFACE_CULL 1 32 | #define DUAL_VISIABLE 1 33 | // #define NEAR_PLANE 0.2 34 | // #define FAR_PLANE 100.0 35 | #define DETACH_WEIGHT 0 36 | 37 | __device__ const float near_n = 0.2; 38 | __device__ const float far_n = 100.0; 39 | __device__ const float FilterSize = 0.707106; // sqrt(2) / 2 40 | __device__ const float FilterInvSquare = 2.0f; 41 | 42 | // Spherical harmonics coefficients 43 | __device__ const float SH_C0 = 0.28209479177387814f; 44 | __device__ const float SH_C1 = 0.4886025119029199f; 45 | __device__ const float SH_C2[] = { 46 | 1.0925484305920792f, 47 | -1.0925484305920792f, 48 | 0.31539156525252005f, 49 | -1.0925484305920792f, 50 | 0.5462742152960396f 51 | }; 52 | __device__ const float SH_C3[] = { 53 | -0.5900435899266435f, 54 | 2.890611442640554f, 55 | -0.4570457994644658f, 56 | 0.3731763325901154f, 57 | -0.4570457994644658f, 58 | 1.445305721320277f, 59 | -0.5900435899266435f 60 | }; 61 | 62 | 
__forceinline__ __device__ float ndc2Pix(float v, int S) 63 | { 64 | return ((v + 1.0) * S - 1.0) * 0.5; 65 | } 66 | 67 | __forceinline__ __device__ void getRect(const float2 p, int max_radius, uint2& rect_min, uint2& rect_max, dim3 grid) 68 | { 69 | rect_min = { 70 | min(grid.x, max((int)0, (int)((p.x - max_radius) / BLOCK_X))), 71 | min(grid.y, max((int)0, (int)((p.y - max_radius) / BLOCK_Y))) 72 | }; 73 | rect_max = { 74 | min(grid.x, max((int)0, (int)((p.x + max_radius + BLOCK_X - 1) / BLOCK_X))), 75 | min(grid.y, max((int)0, (int)((p.y + max_radius + BLOCK_Y - 1) / BLOCK_Y))) 76 | }; 77 | } 78 | 79 | __forceinline__ __device__ float3 transformPoint4x3(const float3& p, const float* matrix) 80 | { 81 | float3 transformed = { 82 | matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12], 83 | matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13], 84 | matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14], 85 | }; 86 | return transformed; 87 | } 88 | 89 | __forceinline__ __device__ float4 transformPoint4x4(const float3& p, const float* matrix) 90 | { 91 | float4 transformed = { 92 | matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12], 93 | matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13], 94 | matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14], 95 | matrix[3] * p.x + matrix[7] * p.y + matrix[11] * p.z + matrix[15] 96 | }; 97 | return transformed; 98 | } 99 | 100 | __forceinline__ __device__ float3 transformVec4x3(const float3& p, const float* matrix) 101 | { 102 | float3 transformed = { 103 | matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z, 104 | matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z, 105 | matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z, 106 | }; 107 | return transformed; 108 | } 109 | 110 | __forceinline__ __device__ float3 transformVec4x3Transpose(const float3& p, const float* matrix) 111 | { 112 | float3 transformed = { 113 | matrix[0] * p.x + matrix[1] * p.y + 
matrix[2] * p.z, 114 | matrix[4] * p.x + matrix[5] * p.y + matrix[6] * p.z, 115 | matrix[8] * p.x + matrix[9] * p.y + matrix[10] * p.z, 116 | }; 117 | return transformed; 118 | } 119 | 120 | __forceinline__ __device__ float dnormvdz(float3 v, float3 dv) 121 | { 122 | float sum2 = v.x * v.x + v.y * v.y + v.z * v.z; 123 | float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); 124 | float dnormvdz = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32; 125 | return dnormvdz; 126 | } 127 | 128 | __forceinline__ __device__ float3 dnormvdv(float3 v, float3 dv) 129 | { 130 | float sum2 = v.x * v.x + v.y * v.y + v.z * v.z; 131 | float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); 132 | 133 | float3 dnormvdv; 134 | dnormvdv.x = ((+sum2 - v.x * v.x) * dv.x - v.y * v.x * dv.y - v.z * v.x * dv.z) * invsum32; 135 | dnormvdv.y = (-v.x * v.y * dv.x + (sum2 - v.y * v.y) * dv.y - v.z * v.y * dv.z) * invsum32; 136 | dnormvdv.z = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32; 137 | return dnormvdv; 138 | } 139 | 140 | __forceinline__ __device__ float4 dnormvdv(float4 v, float4 dv) 141 | { 142 | float sum2 = v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w; 143 | float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); 144 | 145 | float4 vdv = { v.x * dv.x, v.y * dv.y, v.z * dv.z, v.w * dv.w }; 146 | float vdv_sum = vdv.x + vdv.y + vdv.z + vdv.w; 147 | float4 dnormvdv; 148 | dnormvdv.x = ((sum2 - v.x * v.x) * dv.x - v.x * (vdv_sum - vdv.x)) * invsum32; 149 | dnormvdv.y = ((sum2 - v.y * v.y) * dv.y - v.y * (vdv_sum - vdv.y)) * invsum32; 150 | dnormvdv.z = ((sum2 - v.z * v.z) * dv.z - v.z * (vdv_sum - vdv.z)) * invsum32; 151 | dnormvdv.w = ((sum2 - v.w * v.w) * dv.w - v.w * (vdv_sum - vdv.w)) * invsum32; 152 | return dnormvdv; 153 | } 154 | 155 | __forceinline__ __device__ float3 cross(float3 a, float3 b){return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x);} 156 | 157 | __forceinline__ __device__ float3 operator*(float3 a, 
float3 b){return make_float3(a.x * b.x, a.y * b.y, a.z*b.z);} 158 | 159 | __forceinline__ __device__ float2 operator*(float2 a, float2 b){return make_float2(a.x * b.x, a.y * b.y);} 160 | 161 | __forceinline__ __device__ float3 operator*(float f, float3 a){return make_float3(f * a.x, f * a.y, f * a.z);} 162 | 163 | __forceinline__ __device__ float2 operator*(float f, float2 a){return make_float2(f * a.x, f * a.y);} 164 | 165 | __forceinline__ __device__ float3 operator-(float3 a, float3 b){return make_float3(a.x - b.x, a.y - b.y, a.z - b.z);} 166 | 167 | __forceinline__ __device__ float2 operator-(float2 a, float2 b){return make_float2(a.x - b.x, a.y - b.y);} 168 | 169 | __forceinline__ __device__ float sumf3(float3 a){return a.x + a.y + a.z;} 170 | 171 | __forceinline__ __device__ float sumf2(float2 a){return a.x + a.y;} 172 | 173 | __forceinline__ __device__ float3 sqrtf3(float3 a){return make_float3(sqrtf(a.x), sqrtf(a.y), sqrtf(a.z));} 174 | 175 | __forceinline__ __device__ float2 sqrtf2(float2 a){return make_float2(sqrtf(a.x), sqrtf(a.y));} 176 | 177 | __forceinline__ __device__ float3 minf3(float f, float3 a){return make_float3(min(f, a.x), min(f, a.y), min(f, a.z));} 178 | 179 | __forceinline__ __device__ float2 minf2(float f, float2 a){return make_float2(min(f, a.x), min(f, a.y));} 180 | 181 | __forceinline__ __device__ float3 maxf3(float f, float3 a){return make_float3(max(f, a.x), max(f, a.y), max(f, a.z));} 182 | 183 | __forceinline__ __device__ float2 maxf2(float f, float2 a){return make_float2(max(f, a.x), max(f, a.y));} 184 | 185 | __forceinline__ __device__ bool in_frustum(int idx, 186 | const float* orig_points, 187 | const float* viewmatrix, 188 | const float* projmatrix, 189 | bool prefiltered, 190 | float3& p_view) 191 | { 192 | float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] }; 193 | 194 | // Bring points to screen space 195 | float4 p_hom = transformPoint4x4(p_orig, projmatrix); 196 | float p_w = 1.0f / 
(p_hom.w + 0.0000001f); 197 | float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w }; 198 | p_view = transformPoint4x3(p_orig, viewmatrix); 199 | 200 | if (p_view.z <= 0.2f)// || ((p_proj.x < -1.3 || p_proj.x > 1.3 || p_proj.y < -1.3 || p_proj.y > 1.3))) 201 | { 202 | if (prefiltered) 203 | { 204 | printf("Point is filtered although prefiltered is set. This shouldn't happen!"); 205 | __trap(); 206 | } 207 | return false; 208 | } 209 | return true; 210 | } 211 | 212 | // adopt from gsplat: https://github.com/nerfstudio-project/gsplat/blob/main/gsplat/cuda/csrc/forward.cu 213 | inline __device__ glm::mat3 quat_to_rotmat(const glm::vec4 quat) { 214 | // quat to rotation matrix 215 | float s = rsqrtf( 216 | quat.w * quat.w + quat.x * quat.x + quat.y * quat.y + quat.z * quat.z 217 | ); 218 | float w = quat.x * s; 219 | float x = quat.y * s; 220 | float y = quat.z * s; 221 | float z = quat.w * s; 222 | 223 | // glm matrices are column-major 224 | return glm::mat3( 225 | 1.f - 2.f * (y * y + z * z), 226 | 2.f * (x * y + w * z), 227 | 2.f * (x * z - w * y), 228 | 2.f * (x * y - w * z), 229 | 1.f - 2.f * (x * x + z * z), 230 | 2.f * (y * z + w * x), 231 | 2.f * (x * z + w * y), 232 | 2.f * (y * z - w * x), 233 | 1.f - 2.f * (x * x + y * y) 234 | ); 235 | } 236 | 237 | 238 | inline __device__ glm::vec4 239 | quat_to_rotmat_vjp(const glm::vec4 quat, const glm::mat3 v_R) { 240 | float s = rsqrtf( 241 | quat.w * quat.w + quat.x * quat.x + quat.y * quat.y + quat.z * quat.z 242 | ); 243 | float w = quat.x * s; 244 | float x = quat.y * s; 245 | float y = quat.z * s; 246 | float z = quat.w * s; 247 | 248 | glm::vec4 v_quat; 249 | // v_R is COLUMN MAJOR 250 | // w element stored in x field 251 | v_quat.x = 252 | 2.f * ( 253 | // v_quat.w = 2.f * ( 254 | x * (v_R[1][2] - v_R[2][1]) + y * (v_R[2][0] - v_R[0][2]) + 255 | z * (v_R[0][1] - v_R[1][0]) 256 | ); 257 | // x element in y field 258 | v_quat.y = 259 | 2.f * 260 | ( 261 | // v_quat.x = 2.f * ( 262 | -2.f * x * 
(v_R[1][1] + v_R[2][2]) + y * (v_R[0][1] + v_R[1][0]) + 263 | z * (v_R[0][2] + v_R[2][0]) + w * (v_R[1][2] - v_R[2][1]) 264 | ); 265 | // y element in z field 266 | v_quat.z = 267 | 2.f * 268 | ( 269 | // v_quat.y = 2.f * ( 270 | x * (v_R[0][1] + v_R[1][0]) - 2.f * y * (v_R[0][0] + v_R[2][2]) + 271 | z * (v_R[1][2] + v_R[2][1]) + w * (v_R[2][0] - v_R[0][2]) 272 | ); 273 | // z element in w field 274 | v_quat.w = 275 | 2.f * 276 | ( 277 | // v_quat.z = 2.f * ( 278 | x * (v_R[0][2] + v_R[2][0]) + y * (v_R[1][2] + v_R[2][1]) - 279 | 2.f * z * (v_R[0][0] + v_R[1][1]) + w * (v_R[0][1] - v_R[1][0]) 280 | ); 281 | return v_quat; 282 | } 283 | 284 | 285 | inline __device__ glm::mat3 286 | scale_to_mat(const glm::vec2 scale, const float glob_scale) { 287 | glm::mat3 S = glm::mat3(1.f); 288 | S[0][0] = glob_scale * scale.x; 289 | S[1][1] = glob_scale * scale.y; 290 | // S[2][2] = glob_scale * scale.z; 291 | return S; 292 | } 293 | 294 | 295 | 296 | #define CHECK_CUDA(A, debug) \ 297 | A; if(debug) { \ 298 | auto ret = cudaDeviceSynchronize(); \ 299 | if (ret != cudaSuccess) { \ 300 | std::cerr << "\n[CUDA ERROR] in " << __FILE__ << "\nLine " << __LINE__ << ": " << cudaGetErrorString(ret); \ 301 | throw std::runtime_error(cudaGetErrorString(ret)); \ 302 | } \ 303 | } 304 | 305 | #endif -------------------------------------------------------------------------------- /cuda_rasterizer/backward.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | *
9 | * For inquiries contact george.drettakis@inria.fr
10 | */
11 | 
12 | #include "backward.h"
13 | #include "auxiliary.h"
14 | #include <cooperative_groups.h>
15 | #include <cooperative_groups/reduce.h>
16 | namespace cg = cooperative_groups;
17 | 
18 | // Backward pass for conversion of spherical harmonics to RGB for
19 | // each Gaussian.
20 | __device__ void computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, const bool* clamped, const glm::vec3* dL_dcolor, glm::vec3* dL_dmeans, glm::vec3* dL_dshs)
21 | {
22 | // Compute intermediate values, as it is done during forward
23 | glm::vec3 pos = means[idx];
24 | glm::vec3 dir_orig = pos - campos;
25 | glm::vec3 dir = dir_orig / glm::length(dir_orig);
26 | 
27 | glm::vec3* sh = ((glm::vec3*)shs) + idx * max_coeffs;
28 | 
29 | // Use PyTorch rule for clamping: if clamping was applied,
30 | // gradient becomes 0.
31 | glm::vec3 dL_dRGB = dL_dcolor[idx];
32 | dL_dRGB.x *= clamped[3 * idx + 0] ? 0 : 1;
33 | dL_dRGB.y *= clamped[3 * idx + 1] ? 0 : 1;
34 | dL_dRGB.z *= clamped[3 * idx + 2] ? 0 : 1;
35 | 
36 | glm::vec3 dRGBdx(0, 0, 0);
37 | glm::vec3 dRGBdy(0, 0, 0);
38 | glm::vec3 dRGBdz(0, 0, 0);
39 | float x = dir.x;
40 | float y = dir.y;
41 | float z = dir.z;
42 | 
43 | // Target location for this Gaussian to write SH gradients to
44 | glm::vec3* dL_dsh = dL_dshs + idx * max_coeffs;
45 | 
46 | // No tricks here, just high school-level calculus.
47 | float dRGBdsh0 = SH_C0; 48 | dL_dsh[0] = dRGBdsh0 * dL_dRGB; 49 | if (deg > 0) 50 | { 51 | float dRGBdsh1 = -SH_C1 * y; 52 | float dRGBdsh2 = SH_C1 * z; 53 | float dRGBdsh3 = -SH_C1 * x; 54 | dL_dsh[1] = dRGBdsh1 * dL_dRGB; 55 | dL_dsh[2] = dRGBdsh2 * dL_dRGB; 56 | dL_dsh[3] = dRGBdsh3 * dL_dRGB; 57 | 58 | dRGBdx = -SH_C1 * sh[3]; 59 | dRGBdy = -SH_C1 * sh[1]; 60 | dRGBdz = SH_C1 * sh[2]; 61 | 62 | if (deg > 1) 63 | { 64 | float xx = x * x, yy = y * y, zz = z * z; 65 | float xy = x * y, yz = y * z, xz = x * z; 66 | 67 | float dRGBdsh4 = SH_C2[0] * xy; 68 | float dRGBdsh5 = SH_C2[1] * yz; 69 | float dRGBdsh6 = SH_C2[2] * (2.f * zz - xx - yy); 70 | float dRGBdsh7 = SH_C2[3] * xz; 71 | float dRGBdsh8 = SH_C2[4] * (xx - yy); 72 | dL_dsh[4] = dRGBdsh4 * dL_dRGB; 73 | dL_dsh[5] = dRGBdsh5 * dL_dRGB; 74 | dL_dsh[6] = dRGBdsh6 * dL_dRGB; 75 | dL_dsh[7] = dRGBdsh7 * dL_dRGB; 76 | dL_dsh[8] = dRGBdsh8 * dL_dRGB; 77 | 78 | dRGBdx += SH_C2[0] * y * sh[4] + SH_C2[2] * 2.f * -x * sh[6] + SH_C2[3] * z * sh[7] + SH_C2[4] * 2.f * x * sh[8]; 79 | dRGBdy += SH_C2[0] * x * sh[4] + SH_C2[1] * z * sh[5] + SH_C2[2] * 2.f * -y * sh[6] + SH_C2[4] * 2.f * -y * sh[8]; 80 | dRGBdz += SH_C2[1] * y * sh[5] + SH_C2[2] * 2.f * 2.f * z * sh[6] + SH_C2[3] * x * sh[7]; 81 | 82 | if (deg > 2) 83 | { 84 | float dRGBdsh9 = SH_C3[0] * y * (3.f * xx - yy); 85 | float dRGBdsh10 = SH_C3[1] * xy * z; 86 | float dRGBdsh11 = SH_C3[2] * y * (4.f * zz - xx - yy); 87 | float dRGBdsh12 = SH_C3[3] * z * (2.f * zz - 3.f * xx - 3.f * yy); 88 | float dRGBdsh13 = SH_C3[4] * x * (4.f * zz - xx - yy); 89 | float dRGBdsh14 = SH_C3[5] * z * (xx - yy); 90 | float dRGBdsh15 = SH_C3[6] * x * (xx - 3.f * yy); 91 | dL_dsh[9] = dRGBdsh9 * dL_dRGB; 92 | dL_dsh[10] = dRGBdsh10 * dL_dRGB; 93 | dL_dsh[11] = dRGBdsh11 * dL_dRGB; 94 | dL_dsh[12] = dRGBdsh12 * dL_dRGB; 95 | dL_dsh[13] = dRGBdsh13 * dL_dRGB; 96 | dL_dsh[14] = dRGBdsh14 * dL_dRGB; 97 | dL_dsh[15] = dRGBdsh15 * dL_dRGB; 98 | 99 | dRGBdx += ( 100 | SH_C3[0] * sh[9] * 
3.f * 2.f * xy + 101 | SH_C3[1] * sh[10] * yz + 102 | SH_C3[2] * sh[11] * -2.f * xy + 103 | SH_C3[3] * sh[12] * -3.f * 2.f * xz + 104 | SH_C3[4] * sh[13] * (-3.f * xx + 4.f * zz - yy) + 105 | SH_C3[5] * sh[14] * 2.f * xz + 106 | SH_C3[6] * sh[15] * 3.f * (xx - yy)); 107 | 108 | dRGBdy += ( 109 | SH_C3[0] * sh[9] * 3.f * (xx - yy) + 110 | SH_C3[1] * sh[10] * xz + 111 | SH_C3[2] * sh[11] * (-3.f * yy + 4.f * zz - xx) + 112 | SH_C3[3] * sh[12] * -3.f * 2.f * yz + 113 | SH_C3[4] * sh[13] * -2.f * xy + 114 | SH_C3[5] * sh[14] * -2.f * yz + 115 | SH_C3[6] * sh[15] * -3.f * 2.f * xy); 116 | 117 | dRGBdz += ( 118 | SH_C3[1] * sh[10] * xy + 119 | SH_C3[2] * sh[11] * 4.f * 2.f * yz + 120 | SH_C3[3] * sh[12] * 3.f * (2.f * zz - xx - yy) + 121 | SH_C3[4] * sh[13] * 4.f * 2.f * xz + 122 | SH_C3[5] * sh[14] * (xx - yy)); 123 | } 124 | } 125 | } 126 | 127 | // The view direction is an input to the computation. View direction 128 | // is influenced by the Gaussian's mean, so SHs gradients 129 | // must propagate back into 3D position. 130 | glm::vec3 dL_ddir(glm::dot(dRGBdx, dL_dRGB), glm::dot(dRGBdy, dL_dRGB), glm::dot(dRGBdz, dL_dRGB)); 131 | 132 | // Account for normalization of direction 133 | float3 dL_dmean = dnormvdv(float3{ dir_orig.x, dir_orig.y, dir_orig.z }, float3{ dL_ddir.x, dL_ddir.y, dL_ddir.z }); 134 | 135 | // Gradients of loss w.r.t. Gaussian means, but only the portion 136 | // that is caused because the mean affects the view-dependent color. 137 | // Additional mean gradient is accumulated in below methods. 138 | dL_dmeans[idx] += glm::vec3(dL_dmean.x, dL_dmean.y, dL_dmean.z); 139 | } 140 | 141 | 142 | // Backward version of the rendering procedure. 
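
Editorial note (not part of the source): the backward renderer below revisits splats back-to-front, recovering the transmittance in front of each splat via `T = T / (1 - alpha)` and rebuilding the color composited behind it in `accum_rec`, so that `dL/dalpha = (c - accum_rec) * T`. A minimal single-channel Python sketch of that recurrence, checkable against finite differences of the forward compositing:

```python
def composite(colors, alphas):
    # Front-to-back alpha compositing: C = sum_i c_i * a_i * T_i,
    # where T_i is the transmittance in front of splat i.
    C, T = 0.0, 1.0
    for c, a in zip(colors, alphas):
        C += c * a * T
        T *= 1.0 - a
    return C, T  # T is now the final transmittance T_final

def grad_alpha(colors, alphas, dL_dC=1.0):
    # Back-to-front recurrence mirroring renderCUDA's alpha gradient
    # (single channel, no background term, no alpha clamping).
    _, T = composite(colors, alphas)        # start from T_final
    grads = [0.0] * len(alphas)
    accum_rec, last_alpha, last_color = 0.0, 0.0, 0.0
    for j in range(len(alphas) - 1, -1, -1):
        a, c = alphas[j], colors[j]
        T = T / (1.0 - a)                   # transmittance in front of splat j
        # Color accumulated behind splat j (suffix composite).
        accum_rec = last_alpha * last_color + (1.0 - last_alpha) * accum_rec
        last_alpha, last_color = a, c
        grads[j] = (c - accum_rec) * dL_dC * T
    return grads
```

For example, with `colors = [0.8, 0.2, 0.6]` and `alphas = [0.5, 0.3, 0.7]`, the returned gradients agree with centered finite differences of `composite` to well under 1e-5.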
143 | template <uint32_t C>
144 | __global__ void __launch_bounds__(BLOCK_X * BLOCK_Y)
145 | renderCUDA(
146 | const uint2* __restrict__ ranges,
147 | const uint32_t* __restrict__ point_list,
148 | int W, int H,
149 | float focal_x, float focal_y,
150 | const float* __restrict__ bg_color,
151 | const float2* __restrict__ points_xy_image,
152 | const float4* __restrict__ normal_opacity,
153 | const float* __restrict__ transMats,
154 | const float* __restrict__ colors,
155 | const float* __restrict__ depths,
156 | const float* __restrict__ final_Ts,
157 | const uint32_t* __restrict__ n_contrib,
158 | const float* __restrict__ dL_dpixels,
159 | const float* __restrict__ dL_depths,
160 | float * __restrict__ dL_dtransMat,
161 | float3* __restrict__ dL_dmean2D,
162 | float* __restrict__ dL_dnormal3D,
163 | float* __restrict__ dL_dopacity,
164 | float* __restrict__ dL_dcolors)
165 | {
166 | // We rasterize again. Compute necessary block info.
167 | auto block = cg::this_thread_block();
168 | const uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X;
169 | const uint2 pix_min = { block.group_index().x * BLOCK_X, block.group_index().y * BLOCK_Y };
170 | const uint2 pix_max = { min(pix_min.x + BLOCK_X, W), min(pix_min.y + BLOCK_Y, H) };
171 | const uint2 pix = { pix_min.x + block.thread_index().x, pix_min.y + block.thread_index().y };
172 | const uint32_t pix_id = W * pix.y + pix.x;
173 | const float2 pixf = {(float)pix.x, (float)pix.y};
174 | 
175 | const bool inside = pix.x < W && pix.y < H;
176 | const uint2 range = ranges[block.group_index().y * horizontal_blocks + block.group_index().x];
177 | 
178 | const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE);
179 | 
180 | bool done = !inside;
181 | int toDo = range.y - range.x;
182 | 
183 | __shared__ int collected_id[BLOCK_SIZE];
184 | __shared__ float2 collected_xy[BLOCK_SIZE];
185 | __shared__ float4 collected_normal_opacity[BLOCK_SIZE];
186 | __shared__ float collected_colors[C * BLOCK_SIZE];
187 | __shared__
float3 collected_Tu[BLOCK_SIZE];
188 | __shared__ float3 collected_Tv[BLOCK_SIZE];
189 | __shared__ float3 collected_Tw[BLOCK_SIZE];
190 | // __shared__ float collected_depths[BLOCK_SIZE];
191 | 
192 | // In the forward, we stored the final value for T, the
193 | // product of all (1 - alpha) factors.
194 | const float T_final = inside ? final_Ts[pix_id] : 0;
195 | float T = T_final;
196 | 
197 | // We start from the back. The ID of the last contributing
198 | // Gaussian is known for each pixel from the forward.
199 | uint32_t contributor = toDo;
200 | const int last_contributor = inside ? n_contrib[pix_id] : 0;
201 | 
202 | float accum_rec[C] = { 0 };
203 | float dL_dpixel[C];
204 | 
205 | #if RENDER_AXUTILITY
206 | float dL_dreg;
207 | float dL_ddepth;
208 | float dL_daccum;
209 | float dL_dnormal2D[3];
210 | const int median_contributor = inside ? n_contrib[pix_id + H * W] : 0;
211 | float dL_dmedian_depth;
212 | float dL_dmax_dweight;
213 | 
214 | if (inside) {
215 | dL_ddepth = dL_depths[DEPTH_OFFSET * H * W + pix_id];
216 | dL_daccum = dL_depths[ALPHA_OFFSET * H * W + pix_id];
217 | dL_dreg = dL_depths[DISTORTION_OFFSET * H * W + pix_id];
218 | for (int i = 0; i < 3; i++)
219 | dL_dnormal2D[i] = dL_depths[(NORMAL_OFFSET + i) * H * W + pix_id];
220 | 
221 | dL_dmedian_depth = dL_depths[MIDDEPTH_OFFSET * H * W + pix_id];
222 | // dL_dmax_dweight = dL_depths[MEDIAN_WEIGHT_OFFSET * H * W + pix_id];
223 | }
224 | 
225 | // for computing gradients with respect to depth and normal
226 | float last_depth = 0;
227 | float last_normal[3] = { 0 };
228 | float accum_depth_rec = 0;
229 | float accum_alpha_rec = 0;
230 | float accum_normal_rec[3] = {0};
231 | // for computing gradients with respect to the distortion map
232 | const float final_D = inside ? final_Ts[pix_id + H * W] : 0;
233 | const float final_D2 = inside ?
final_Ts[pix_id + 2 * H * W] : 0;
234 | const float final_A = 1 - T_final;
235 | float last_dL_dT = 0;
236 | #endif
237 | 
238 | if (inside){
239 | for (int i = 0; i < C; i++)
240 | dL_dpixel[i] = dL_dpixels[i * H * W + pix_id];
241 | }
242 | 
243 | float last_alpha = 0;
244 | float last_color[C] = { 0 };
245 | 
246 | // Gradient of pixel coordinate w.r.t. normalized
247 | // screen-space viewport coordinates (-1 to 1)
248 | const float ddelx_dx = 0.5 * W;
249 | const float ddely_dy = 0.5 * H;
250 | 
251 | // Traverse all Gaussians
252 | for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE)
253 | {
254 | // Load auxiliary data into shared memory, start in the BACK
255 | // and load them in reverse order.
256 | block.sync();
257 | const int progress = i * BLOCK_SIZE + block.thread_rank();
258 | if (range.x + progress < range.y)
259 | {
260 | const int coll_id = point_list[range.y - progress - 1];
261 | collected_id[block.thread_rank()] = coll_id;
262 | collected_xy[block.thread_rank()] = points_xy_image[coll_id];
263 | collected_normal_opacity[block.thread_rank()] = normal_opacity[coll_id];
264 | collected_Tu[block.thread_rank()] = {transMats[9 * coll_id+0], transMats[9 * coll_id+1], transMats[9 * coll_id+2]};
265 | collected_Tv[block.thread_rank()] = {transMats[9 * coll_id+3], transMats[9 * coll_id+4], transMats[9 * coll_id+5]};
266 | collected_Tw[block.thread_rank()] = {transMats[9 * coll_id+6], transMats[9 * coll_id+7], transMats[9 * coll_id+8]};
267 | for (int i = 0; i < C; i++)
268 | collected_colors[i * BLOCK_SIZE + block.thread_rank()] = colors[coll_id * C + i];
269 | // collected_depths[block.thread_rank()] = depths[coll_id];
270 | }
271 | block.sync();
272 | 
273 | // Iterate over Gaussians
274 | for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++)
275 | {
276 | // Keep track of current Gaussian ID. Skip if this one
277 | // is behind the last contributor for this pixel.
278 | contributor--;
279 | if (contributor >= last_contributor)
280 | continue;
281 | 
282 | // compute ray-splat intersection as before
283 | // First compute two homogeneous planes, see Eq. (8)
284 | const float2 xy = collected_xy[j];
285 | const float3 Tu = collected_Tu[j];
286 | const float3 Tv = collected_Tv[j];
287 | const float3 Tw = collected_Tw[j];
288 | float3 k = pix.x * Tw - Tu;
289 | float3 l = pix.y * Tw - Tv;
290 | float3 p = cross(k, l);
291 | if (p.z == 0.0) continue;
292 | float2 s = {p.x / p.z, p.y / p.z};
293 | float rho3d = (s.x * s.x + s.y * s.y);
294 | float2 d = {xy.x - pixf.x, xy.y - pixf.y};
295 | float rho2d = FilterInvSquare * (d.x * d.x + d.y * d.y);
296 | float rho = min(rho3d, rho2d);
297 | 
298 | // compute depth
299 | float c_d = (s.x * Tw.x + s.y * Tw.y) + Tw.z; // Tw * [u,v,1]
300 | // if a point is too small, its depth is not reliable?
301 | // c_d = (rho3d <= rho2d) ? c_d : Tw.z;
302 | if (c_d < near_n) continue;
303 | 
304 | float4 nor_o = collected_normal_opacity[j];
305 | float normal[3] = {nor_o.x, nor_o.y, nor_o.z};
306 | float opa = nor_o.w;
307 | 
308 | // accumulations
309 | 
310 | float power = -0.5f * rho;
311 | if (power > 0.0f)
312 | continue;
313 | 
314 | const float G = exp(power);
315 | const float alpha = min(0.99f, opa * G);
316 | if (alpha < 1.0f / 255.0f)
317 | continue;
318 | 
319 | T = T / (1.f - alpha);
320 | const float dchannel_dcolor = alpha * T;
321 | const float w = alpha * T;
322 | // Propagate gradients to per-Gaussian colors and keep
323 | // gradients w.r.t. alpha (blending factor for a Gaussian/pixel
324 | // pair).
325 | float dL_dalpha = 0.0f; 326 | const int global_id = collected_id[j]; 327 | for (int ch = 0; ch < C; ch++) 328 | { 329 | const float c = collected_colors[ch * BLOCK_SIZE + j]; 330 | // Update last color (to be used in the next iteration) 331 | accum_rec[ch] = last_alpha * last_color[ch] + (1.f - last_alpha) * accum_rec[ch]; 332 | last_color[ch] = c; 333 | 334 | const float dL_dchannel = dL_dpixel[ch]; 335 | dL_dalpha += (c - accum_rec[ch]) * dL_dchannel; 336 | // Update the gradients w.r.t. color of the Gaussian. 337 | // Atomic, since this pixel is just one of potentially 338 | // many that were affected by this Gaussian. 339 | atomicAdd(&(dL_dcolors[global_id * C + ch]), dchannel_dcolor * dL_dchannel); 340 | } 341 | 342 | float dL_dz = 0.0f; 343 | float dL_dweight = 0; 344 | #if RENDER_AXUTILITY 345 | const float m_d = far_n / (far_n - near_n) * (1 - near_n / c_d); 346 | const float dmd_dd = (far_n * near_n) / ((far_n - near_n) * c_d * c_d); 347 | if (contributor == median_contributor-1) { 348 | dL_dz += dL_dmedian_depth; 349 | // dL_dweight += dL_dmax_dweight; 350 | } 351 | #if DETACH_WEIGHT 352 | // if the weight is not detached, the optimization can sometimes 353 | // bias toward creating elongated 2D Gaussians near the front 354 | dL_dweight += 0; 355 | #else 356 | dL_dweight += (final_D2 + m_d * m_d * final_A - 2 * m_d * final_D) * dL_dreg; 357 | #endif 358 | dL_dalpha += dL_dweight - last_dL_dT; 359 | // propagate the current weight W_{i} to next weight W_{i-1} 360 | last_dL_dT = dL_dweight * alpha + (1 - alpha) * last_dL_dT; 361 | const float dL_dmd = 2.0f * (T * alpha) * (m_d * final_A - final_D) * dL_dreg; 362 | dL_dz += dL_dmd * dmd_dd; 363 | 364 | // Propagate gradients w.r.t. ray-splat depths 365 | accum_depth_rec = last_alpha * last_depth + (1.f - last_alpha) * accum_depth_rec; 366 | last_depth = c_d; 367 | dL_dalpha += (c_d - accum_depth_rec) * dL_ddepth; 368 | // Propagate gradients w.r.t.
color ray-splat alphas 369 | accum_alpha_rec = last_alpha * 1.0 + (1.f - last_alpha) * accum_alpha_rec; 370 | dL_dalpha += (1 - accum_alpha_rec) * dL_daccum; 371 | 372 | // Propagate gradients to per-Gaussian normals 373 | for (int ch = 0; ch < 3; ch++) { 374 | accum_normal_rec[ch] = last_alpha * last_normal[ch] + (1.f - last_alpha) * accum_normal_rec[ch]; 375 | last_normal[ch] = normal[ch]; 376 | dL_dalpha += (normal[ch] - accum_normal_rec[ch]) * dL_dnormal2D[ch]; 377 | atomicAdd((&dL_dnormal3D[global_id * 3 + ch]), alpha * T * dL_dnormal2D[ch]); 378 | } 379 | #endif 380 | 381 | dL_dalpha *= T; 382 | // Update last alpha (to be used in the next iteration) 383 | last_alpha = alpha; 384 | 385 | // Account for fact that alpha also influences how much of 386 | // the background color is added if nothing left to blend 387 | float bg_dot_dpixel = 0; 388 | for (int i = 0; i < C; i++) 389 | bg_dot_dpixel += bg_color[i] * dL_dpixel[i]; 390 | dL_dalpha += (-T_final / (1.f - alpha)) * bg_dot_dpixel; 391 | 392 | 393 | // Helpful reusable temporary variables 394 | const float dL_dG = nor_o.w * dL_dalpha; 395 | #if RENDER_AXUTILITY 396 | dL_dz += alpha * T * dL_ddepth; 397 | #endif 398 | 399 | if (rho3d <= rho2d) { 400 | // Update gradients w.r.t. 
covariance of Gaussian 3x3 (T) 401 | const float2 dL_ds = { 402 | dL_dG * -G * s.x + dL_dz * Tw.x, 403 | dL_dG * -G * s.y + dL_dz * Tw.y 404 | }; 405 | const float3 dz_dTw = {s.x, s.y, 1.0}; 406 | const float dsx_pz = dL_ds.x / p.z; 407 | const float dsy_pz = dL_ds.y / p.z; 408 | const float3 dL_dp = {dsx_pz, dsy_pz, -(dsx_pz * s.x + dsy_pz * s.y)}; 409 | const float3 dL_dk = cross(l, dL_dp); 410 | const float3 dL_dl = cross(dL_dp, k); 411 | 412 | const float3 dL_dTu = {-dL_dk.x, -dL_dk.y, -dL_dk.z}; 413 | const float3 dL_dTv = {-dL_dl.x, -dL_dl.y, -dL_dl.z}; 414 | const float3 dL_dTw = { 415 | pixf.x * dL_dk.x + pixf.y * dL_dl.x + dL_dz * dz_dTw.x, 416 | pixf.x * dL_dk.y + pixf.y * dL_dl.y + dL_dz * dz_dTw.y, 417 | pixf.x * dL_dk.z + pixf.y * dL_dl.z + dL_dz * dz_dTw.z}; 418 | 419 | 420 | // Update gradients w.r.t. 3D covariance (3x3 matrix) 421 | atomicAdd(&dL_dtransMat[global_id * 9 + 0], dL_dTu.x); 422 | atomicAdd(&dL_dtransMat[global_id * 9 + 1], dL_dTu.y); 423 | atomicAdd(&dL_dtransMat[global_id * 9 + 2], dL_dTu.z); 424 | atomicAdd(&dL_dtransMat[global_id * 9 + 3], dL_dTv.x); 425 | atomicAdd(&dL_dtransMat[global_id * 9 + 4], dL_dTv.y); 426 | atomicAdd(&dL_dtransMat[global_id * 9 + 5], dL_dTv.z); 427 | atomicAdd(&dL_dtransMat[global_id * 9 + 6], dL_dTw.x); 428 | atomicAdd(&dL_dtransMat[global_id * 9 + 7], dL_dTw.y); 429 | atomicAdd(&dL_dtransMat[global_id * 9 + 8], dL_dTw.z); 430 | } else { 431 | // // Update gradients w.r.t. 
center of Gaussian 2D mean position 432 | const float dG_ddelx = -G * FilterInvSquare * d.x; 433 | const float dG_ddely = -G * FilterInvSquare * d.y; 434 | atomicAdd(&dL_dmean2D[global_id].x, dL_dG * dG_ddelx); // not scaled 435 | atomicAdd(&dL_dmean2D[global_id].y, dL_dG * dG_ddely); // not scaled 436 | // // Propagate the gradients of depth 437 | atomicAdd(&dL_dtransMat[global_id * 9 + 6], s.x * dL_dz); 438 | atomicAdd(&dL_dtransMat[global_id * 9 + 7], s.y * dL_dz); 439 | atomicAdd(&dL_dtransMat[global_id * 9 + 8], dL_dz); 440 | } 441 | 442 | // Update gradients w.r.t. opacity of the Gaussian 443 | atomicAdd(&(dL_dopacity[global_id]), G * dL_dalpha); 444 | } 445 | } 446 | } 447 | 448 | 449 | __device__ void compute_transmat_aabb( 450 | int idx, 451 | const float* Ts_precomp, 452 | const float3* p_origs, 453 | const glm::vec2* scales, 454 | const glm::vec4* rots, 455 | const float* projmatrix, 456 | const float* viewmatrix, 457 | const int W, const int H, 458 | const float3* dL_dnormals, 459 | const float3* dL_dmean2Ds, 460 | float* dL_dTs, 461 | glm::vec3* dL_dmeans, 462 | glm::vec2* dL_dscales, 463 | glm::vec4* dL_drots) 464 | { 465 | glm::mat3 T; 466 | float3 normal; 467 | glm::mat3x4 P; 468 | glm::mat3 R; 469 | glm::mat3 S; 470 | float3 p_orig; 471 | glm::vec4 rot; 472 | glm::vec2 scale; 473 | 474 | // Get transformation matrix of the Gaussian 475 | if (Ts_precomp != nullptr) { 476 | T = glm::mat3( 477 | Ts_precomp[idx * 9 + 0], Ts_precomp[idx * 9 + 1], Ts_precomp[idx * 9 + 2], 478 | Ts_precomp[idx * 9 + 3], Ts_precomp[idx * 9 + 4], Ts_precomp[idx * 9 + 5], 479 | Ts_precomp[idx * 9 + 6], Ts_precomp[idx * 9 + 7], Ts_precomp[idx * 9 + 8] 480 | ); 481 | normal = {0.0, 0.0, 0.0}; 482 | } else { 483 | p_orig = p_origs[idx]; 484 | rot = rots[idx]; 485 | scale = scales[idx]; 486 | R = quat_to_rotmat(rot); 487 | S = scale_to_mat(scale, 1.0f); 488 | 489 | glm::mat3 L = R * S; 490 | glm::mat3x4 M = glm::mat3x4( 491 | glm::vec4(L[0], 0.0), 492 | glm::vec4(L[1], 0.0), 493 
| glm::vec4(p_orig.x, p_orig.y, p_orig.z, 1) 494 | ); 495 | 496 | glm::mat4 world2ndc = glm::mat4( 497 | projmatrix[0], projmatrix[4], projmatrix[8], projmatrix[12], 498 | projmatrix[1], projmatrix[5], projmatrix[9], projmatrix[13], 499 | projmatrix[2], projmatrix[6], projmatrix[10], projmatrix[14], 500 | projmatrix[3], projmatrix[7], projmatrix[11], projmatrix[15] 501 | ); 502 | 503 | glm::mat3x4 ndc2pix = glm::mat3x4( 504 | glm::vec4(float(W) / 2.0, 0.0, 0.0, float(W-1) / 2.0), 505 | glm::vec4(0.0, float(H) / 2.0, 0.0, float(H-1) / 2.0), 506 | glm::vec4(0.0, 0.0, 0.0, 1.0) 507 | ); 508 | 509 | P = world2ndc * ndc2pix; 510 | T = glm::transpose(M) * P; 511 | normal = transformVec4x3({L[2].x, L[2].y, L[2].z}, viewmatrix); 512 | } 513 | 514 | // Update gradients w.r.t. transformation matrix of the Gaussian 515 | glm::mat3 dL_dT = glm::mat3( 516 | dL_dTs[idx*9+0], dL_dTs[idx*9+1], dL_dTs[idx*9+2], 517 | dL_dTs[idx*9+3], dL_dTs[idx*9+4], dL_dTs[idx*9+5], 518 | dL_dTs[idx*9+6], dL_dTs[idx*9+7], dL_dTs[idx*9+8] 519 | ); 520 | float3 dL_dmean2D = dL_dmean2Ds[idx]; 521 | if(dL_dmean2D.x != 0 || dL_dmean2D.y != 0) 522 | { 523 | glm::vec3 t_vec = glm::vec3(9.0f, 9.0f, -1.0f); 524 | float d = glm::dot(t_vec, T[2] * T[2]); 525 | glm::vec3 f_vec = t_vec * (1.0f / d); 526 | glm::vec3 dL_dT0 = dL_dmean2D.x * f_vec * T[2]; 527 | glm::vec3 dL_dT1 = dL_dmean2D.y * f_vec * T[2]; 528 | glm::vec3 dL_dT3 = dL_dmean2D.x * f_vec * T[0] + dL_dmean2D.y * f_vec * T[1]; 529 | glm::vec3 dL_df = dL_dmean2D.x * T[0] * T[2] + dL_dmean2D.y * T[1] * T[2]; 530 | float dL_dd = glm::dot(dL_df, f_vec) * (-1.0 / d); 531 | glm::vec3 dd_dT3 = t_vec * T[2] * 2.0f; 532 | dL_dT3 += dL_dd * dd_dT3; 533 | dL_dT[0] += dL_dT0; 534 | dL_dT[1] += dL_dT1; 535 | dL_dT[2] += dL_dT3; 536 | 537 | if (Ts_precomp != nullptr) { 538 | dL_dTs[idx * 9 + 0] = dL_dT[0].x; 539 | dL_dTs[idx * 9 + 1] = dL_dT[0].y; 540 | dL_dTs[idx * 9 + 2] = dL_dT[0].z; 541 | dL_dTs[idx * 9 + 3] = dL_dT[1].x; 542 | dL_dTs[idx * 9 + 4] = 
dL_dT[1].y; 543 | dL_dTs[idx * 9 + 5] = dL_dT[1].z; 544 | dL_dTs[idx * 9 + 6] = dL_dT[2].x; 545 | dL_dTs[idx * 9 + 7] = dL_dT[2].y; 546 | dL_dTs[idx * 9 + 8] = dL_dT[2].z; 547 | return; 548 | } 549 | } 550 | 551 | if (Ts_precomp != nullptr) return; 552 | 553 | // Update gradients w.r.t. scaling, rotation, position of the Gaussian 554 | glm::mat3x4 dL_dM = P * glm::transpose(dL_dT); 555 | float3 dL_dtn = transformVec4x3Transpose(dL_dnormals[idx], viewmatrix); 556 | #if DUAL_VISIABLE 557 | float3 p_view = transformPoint4x3(p_orig, viewmatrix); 558 | float cos = -sumf3(p_view * normal); 559 | float multiplier = cos > 0 ? 1: -1; 560 | dL_dtn = multiplier * dL_dtn; 561 | #endif 562 | glm::mat3 dL_dRS = glm::mat3( 563 | glm::vec3(dL_dM[0]), 564 | glm::vec3(dL_dM[1]), 565 | glm::vec3(dL_dtn.x, dL_dtn.y, dL_dtn.z) 566 | ); 567 | 568 | glm::mat3 dL_dR = glm::mat3( 569 | dL_dRS[0] * glm::vec3(scale.x), 570 | dL_dRS[1] * glm::vec3(scale.y), 571 | dL_dRS[2]); 572 | 573 | dL_drots[idx] = quat_to_rotmat_vjp(rot, dL_dR); 574 | dL_dscales[idx] = glm::vec2( 575 | (float)glm::dot(dL_dRS[0], R[0]), 576 | (float)glm::dot(dL_dRS[1], R[1]) 577 | ); 578 | dL_dmeans[idx] = glm::vec3(dL_dM[2]); 579 | } 580 | 581 | template<int C> 582 | __global__ void preprocessCUDA( 583 | int P, int D, int M, 584 | const float3* means3D, 585 | const float* transMats, 586 | const int* radii, 587 | const float* shs, 588 | const bool* clamped, 589 | const glm::vec2* scales, 590 | const glm::vec4* rotations, 591 | const float scale_modifier, 592 | const float* viewmatrix, 593 | const float* projmatrix, 594 | const float focal_x, 595 | const float focal_y, 596 | const float tan_fovx, 597 | const float tan_fovy, 598 | const glm::vec3* campos, 599 | // grad input 600 | float* dL_dtransMats, 601 | const float* dL_dnormal3Ds, 602 | float* dL_dcolors, 603 | float* dL_dshs, 604 | float3* dL_dmean2Ds, 605 | glm::vec3* dL_dmean3Ds, 606 | glm::vec2* dL_dscales, 607 | glm::vec4* dL_drots) 608 | { 609 | auto idx =
cg::this_grid().thread_rank(); 610 | if (idx >= P || !(radii[idx] > 0)) 611 | return; 612 | 613 | const int W = int(focal_x * tan_fovx * 2); 614 | const int H = int(focal_y * tan_fovy * 2); 615 | const float * Ts_precomp = (scales) ? nullptr : transMats; 616 | compute_transmat_aabb( 617 | idx, 618 | Ts_precomp, 619 | means3D, scales, rotations, 620 | projmatrix, viewmatrix, W, H, 621 | (float3*)dL_dnormal3Ds, 622 | dL_dmean2Ds, 623 | (dL_dtransMats), 624 | dL_dmean3Ds, 625 | dL_dscales, 626 | dL_drots 627 | ); 628 | 629 | if (shs) 630 | computeColorFromSH(idx, D, M, (glm::vec3*)means3D, *campos, shs, clamped, (glm::vec3*)dL_dcolors, (glm::vec3*)dL_dmean3Ds, (glm::vec3*)dL_dshs); 631 | 632 | // hack the gradient here for densification 633 | float depth = transMats[idx * 9 + 8]; 634 | dL_dmean2Ds[idx].x = dL_dtransMats[idx * 9 + 2] * depth * 0.5 * float(W); // to ndc 635 | dL_dmean2Ds[idx].y = dL_dtransMats[idx * 9 + 5] * depth * 0.5 * float(H); // to ndc 636 | } 637 | 638 | 639 | void BACKWARD::preprocess( 640 | int P, int D, int M, 641 | const float3* means3D, 642 | const int* radii, 643 | const float* shs, 644 | const bool* clamped, 645 | const glm::vec2* scales, 646 | const glm::vec4* rotations, 647 | const float scale_modifier, 648 | const float* transMats, 649 | const float* viewmatrix, 650 | const float* projmatrix, 651 | const float focal_x, const float focal_y, 652 | const float tan_fovx, const float tan_fovy, 653 | const glm::vec3* campos, 654 | float3* dL_dmean2Ds, 655 | const float* dL_dnormal3Ds, 656 | float* dL_dtransMats, 657 | float* dL_dcolors, 658 | float* dL_dshs, 659 | glm::vec3* dL_dmean3Ds, 660 | glm::vec2* dL_dscales, 661 | glm::vec4* dL_drots) 662 | { 663 | preprocessCUDA<NUM_CHANNELS> << <(P + 255) / 256, 256 >> > ( 664 | P, D, M, 665 | (float3*)means3D, 666 | transMats, 667 | radii, 668 | shs, 669 | clamped, 670 | (glm::vec2*)scales, 671 | (glm::vec4*)rotations, 672 | scale_modifier, 673 | viewmatrix, 674 | projmatrix, 675 | focal_x, 676 | focal_y, 677 |
tan_fovx, 678 | tan_fovy, 679 | campos, 680 | dL_dtransMats, 681 | dL_dnormal3Ds, 682 | dL_dcolors, 683 | dL_dshs, 684 | dL_dmean2Ds, 685 | dL_dmean3Ds, 686 | dL_dscales, 687 | dL_drots 688 | ); 689 | } 690 | 691 | void BACKWARD::render( 692 | const dim3 grid, const dim3 block, 693 | const uint2* ranges, 694 | const uint32_t* point_list, 695 | int W, int H, 696 | float focal_x, float focal_y, 697 | const float* bg_color, 698 | const float2* means2D, 699 | const float4* normal_opacity, 700 | const float* colors, 701 | const float* transMats, 702 | const float* depths, 703 | const float* final_Ts, 704 | const uint32_t* n_contrib, 705 | const float* dL_dpixels, 706 | const float* dL_depths, 707 | float * dL_dtransMat, 708 | float3* dL_dmean2D, 709 | float* dL_dnormal3D, 710 | float* dL_dopacity, 711 | float* dL_dcolors) 712 | { 713 | renderCUDA<NUM_CHANNELS> << <grid, block >> >( 714 | ranges, 715 | point_list, 716 | W, H, 717 | focal_x, focal_y, 718 | bg_color, 719 | means2D, 720 | normal_opacity, 721 | transMats, 722 | colors, 723 | depths, 724 | final_Ts, 725 | n_contrib, 726 | dL_dpixels, 727 | dL_depths, 728 | dL_dtransMat, 729 | dL_dmean2D, 730 | dL_dnormal3D, 731 | dL_dopacity, 732 | dL_dcolors 733 | ); 734 | } 735 | -------------------------------------------------------------------------------- /cuda_rasterizer/backward.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file.
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_BACKWARD_H_INCLUDED 13 | #define CUDA_RASTERIZER_BACKWARD_H_INCLUDED 14 | 15 | #include <cuda.h> 16 | #include "cuda_runtime.h" 17 | #include "device_launch_parameters.h" 18 | #define GLM_FORCE_CUDA 19 | #include <glm/glm.hpp> 20 | 21 | namespace BACKWARD 22 | { 23 | void render( 24 | const dim3 grid, dim3 block, 25 | const uint2* ranges, 26 | const uint32_t* point_list, 27 | int W, int H, 28 | float focal_x, float focal_y, 29 | const float* bg_color, 30 | const float2* means2D, 31 | const float4* normal_opacity, 32 | const float* transMats, 33 | const float* colors, 34 | const float* depths, 35 | const float* final_Ts, 36 | const uint32_t* n_contrib, 37 | const float* dL_dpixels, 38 | const float* dL_depths, 39 | float * dL_dtransMat, 40 | float3* dL_dmean2D, 41 | float* dL_dnormal3D, 42 | float* dL_dopacity, 43 | float* dL_dcolors); 44 | 45 | void preprocess( 46 | int P, int D, int M, 47 | const float3* means, 48 | const int* radii, 49 | const float* shs, 50 | const bool* clamped, 51 | const glm::vec2* scales, 52 | const glm::vec4* rotations, 53 | const float scale_modifier, 54 | const float* transMats, 55 | const float* view, 56 | const float* proj, 57 | const float focal_x, const float focal_y, 58 | const float tan_fovx, const float tan_fovy, 59 | const glm::vec3* campos, 60 | float3* dL_dmean2D, 61 | const float* dL_dnormal3D, 62 | float* dL_dtransMat, 63 | float* dL_dcolor, 64 | float* dL_dsh, 65 | glm::vec3* dL_dmeans, 66 | glm::vec2* dL_dscale, 67 | glm::vec4* dL_drot); 68 | } 69 | 70 | #endif 71 | -------------------------------------------------------------------------------- /cuda_rasterizer/config.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved.
5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_CONFIG_H_INCLUDED 13 | #define CUDA_RASTERIZER_CONFIG_H_INCLUDED 14 | 15 | #define NUM_CHANNELS 3 // Default 3, RGB 16 | #define BLOCK_X 16 17 | #define BLOCK_Y 16 18 | 19 | #endif -------------------------------------------------------------------------------- /cuda_rasterizer/forward.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include "forward.h" 13 | #include "auxiliary.h" 14 | #include <cooperative_groups.h> 15 | #include <cooperative_groups/reduce.h> 16 | namespace cg = cooperative_groups; 17 | 18 | // Forward method for converting the input spherical harmonics 19 | // coefficients of each Gaussian to a simple RGB color. 20 | __device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, bool* clamped) 21 | { 22 | // The implementation is loosely based on code for 23 | // "Differentiable Point-Based Radiance Fields for 24 | // Efficient View Synthesis" by Zhang et al.
(2022) 25 | glm::vec3 pos = means[idx]; 26 | glm::vec3 dir = pos - campos; 27 | dir = dir / glm::length(dir); 28 | 29 | glm::vec3* sh = ((glm::vec3*)shs) + idx * max_coeffs; 30 | glm::vec3 result = SH_C0 * sh[0]; 31 | 32 | if (deg > 0) 33 | { 34 | float x = dir.x; 35 | float y = dir.y; 36 | float z = dir.z; 37 | result = result - SH_C1 * y * sh[1] + SH_C1 * z * sh[2] - SH_C1 * x * sh[3]; 38 | 39 | if (deg > 1) 40 | { 41 | float xx = x * x, yy = y * y, zz = z * z; 42 | float xy = x * y, yz = y * z, xz = x * z; 43 | result = result + 44 | SH_C2[0] * xy * sh[4] + 45 | SH_C2[1] * yz * sh[5] + 46 | SH_C2[2] * (2.0f * zz - xx - yy) * sh[6] + 47 | SH_C2[3] * xz * sh[7] + 48 | SH_C2[4] * (xx - yy) * sh[8]; 49 | 50 | if (deg > 2) 51 | { 52 | result = result + 53 | SH_C3[0] * y * (3.0f * xx - yy) * sh[9] + 54 | SH_C3[1] * xy * z * sh[10] + 55 | SH_C3[2] * y * (4.0f * zz - xx - yy) * sh[11] + 56 | SH_C3[3] * z * (2.0f * zz - 3.0f * xx - 3.0f * yy) * sh[12] + 57 | SH_C3[4] * x * (4.0f * zz - xx - yy) * sh[13] + 58 | SH_C3[5] * z * (xx - yy) * sh[14] + 59 | SH_C3[6] * x * (xx - 3.0f * yy) * sh[15]; 60 | } 61 | } 62 | } 63 | result += 0.5f; 64 | 65 | // RGB colors are clamped to positive values. If values are 66 | // clamped, we need to keep track of this for the backward pass. 67 | clamped[3 * idx + 0] = (result.x < 0); 68 | clamped[3 * idx + 1] = (result.y < 0); 69 | clamped[3 * idx + 2] = (result.z < 0); 70 | return glm::max(result, 0.0f); 71 | } 72 | 73 | // Compute a 2D-to-2D mapping matrix from a tangent plane into an image plane 74 | // given 2D Gaussian parameters.
75 | __device__ void compute_transmat( 76 | const float3& p_orig, 77 | const glm::vec2 scale, 78 | float mod, 79 | const glm::vec4 rot, 80 | const float* projmatrix, 81 | const float* viewmatrix, 82 | const int W, 83 | const int H, 84 | glm::mat3 &T, 85 | float3 &normal 86 | ) { 87 | 88 | glm::mat3 R = quat_to_rotmat(rot); 89 | glm::mat3 S = scale_to_mat(scale, mod); 90 | glm::mat3 L = R * S; 91 | 92 | // center of Gaussians in the camera coordinate 93 | glm::mat3x4 splat2world = glm::mat3x4( 94 | glm::vec4(L[0], 0.0), 95 | glm::vec4(L[1], 0.0), 96 | glm::vec4(p_orig.x, p_orig.y, p_orig.z, 1) 97 | ); 98 | 99 | glm::mat4 world2ndc = glm::mat4( 100 | projmatrix[0], projmatrix[4], projmatrix[8], projmatrix[12], 101 | projmatrix[1], projmatrix[5], projmatrix[9], projmatrix[13], 102 | projmatrix[2], projmatrix[6], projmatrix[10], projmatrix[14], 103 | projmatrix[3], projmatrix[7], projmatrix[11], projmatrix[15] 104 | ); 105 | 106 | glm::mat3x4 ndc2pix = glm::mat3x4( 107 | glm::vec4(float(W) / 2.0, 0.0, 0.0, float(W-1) / 2.0), 108 | glm::vec4(0.0, float(H) / 2.0, 0.0, float(H-1) / 2.0), 109 | glm::vec4(0.0, 0.0, 0.0, 1.0) 110 | ); 111 | 112 | T = glm::transpose(splat2world) * world2ndc * ndc2pix; 113 | normal = transformVec4x3({L[2].x, L[2].y, L[2].z}, viewmatrix); 114 | 115 | } 116 | 117 | // Computing the bounding box of the 2D Gaussian and its center 118 | // The center of the bounding box is used to create a low pass filter 119 | __device__ bool compute_aabb( 120 | glm::mat3 T, 121 | float cutoff, 122 | float2& point_image, 123 | float2& extent 124 | ) { 125 | glm::vec3 t = glm::vec3(cutoff * cutoff, cutoff * cutoff, -1.0f); 126 | float d = glm::dot(t, T[2] * T[2]); 127 | if (d == 0.0) return false; 128 | glm::vec3 f = (1 / d) * t; 129 | 130 | glm::vec2 p = glm::vec2( 131 | glm::dot(f, T[0] * T[2]), 132 | glm::dot(f, T[1] * T[2]) 133 | ); 134 | 135 | glm::vec2 h0 = p * p - 136 | glm::vec2( 137 | glm::dot(f, T[0] * T[0]), 138 | glm::dot(f, T[1] * T[1]) 139 | ); 140 | 
141 | glm::vec2 h = sqrt(max(glm::vec2(1e-4, 1e-4), h0)); 142 | point_image = {p.x, p.y}; 143 | extent = {h.x, h.y}; 144 | return true; 145 | } 146 | 147 | // Perform initial steps for each Gaussian prior to rasterization. 148 | template<int C> 149 | __global__ void preprocessCUDA(int P, int D, int M, 150 | const float* orig_points, 151 | const glm::vec2* scales, 152 | const float scale_modifier, 153 | const glm::vec4* rotations, 154 | const float* opacities, 155 | const float* shs, 156 | bool* clamped, 157 | const float* transMat_precomp, 158 | const float* colors_precomp, 159 | const float* viewmatrix, 160 | const float* projmatrix, 161 | const glm::vec3* cam_pos, 162 | const int W, int H, 163 | const float tan_fovx, const float tan_fovy, 164 | const float focal_x, const float focal_y, 165 | int* radii, 166 | float2* points_xy_image, 167 | float* depths, 168 | float* transMats, 169 | float* rgb, 170 | float4* normal_opacity, 171 | const dim3 grid, 172 | uint32_t* tiles_touched, 173 | bool prefiltered) 174 | { 175 | auto idx = cg::this_grid().thread_rank(); 176 | if (idx >= P) 177 | return; 178 | 179 | // Initialize radius and touched tiles to 0. If this isn't changed, 180 | // this Gaussian will not be processed further. 181 | radii[idx] = 0; 182 | tiles_touched[idx] = 0; 183 | 184 | // Perform near culling, quit if outside.
185 | float3 p_view; 186 | if (!in_frustum(idx, orig_points, viewmatrix, projmatrix, prefiltered, p_view)) 187 | return; 188 | 189 | // Compute transformation matrix 190 | glm::mat3 T; 191 | float3 normal; 192 | if (transMat_precomp == nullptr) 193 | { 194 | compute_transmat(((float3*)orig_points)[idx], scales[idx], scale_modifier, rotations[idx], projmatrix, viewmatrix, W, H, T, normal); 195 | float3 *T_ptr = (float3*)transMats; 196 | T_ptr[idx * 3 + 0] = {T[0][0], T[0][1], T[0][2]}; 197 | T_ptr[idx * 3 + 1] = {T[1][0], T[1][1], T[1][2]}; 198 | T_ptr[idx * 3 + 2] = {T[2][0], T[2][1], T[2][2]}; 199 | } else { 200 | glm::vec3 *T_ptr = (glm::vec3*)transMat_precomp; 201 | T = glm::mat3( 202 | T_ptr[idx * 3 + 0], 203 | T_ptr[idx * 3 + 1], 204 | T_ptr[idx * 3 + 2] 205 | ); 206 | normal = make_float3(0.0, 0.0, 1.0); 207 | } 208 | 209 | #if DUAL_VISIABLE 210 | float cos = -sumf3(p_view * normal); 211 | if (cos == 0) return; 212 | float multiplier = cos > 0 ? 1: -1; 213 | normal = multiplier * normal; 214 | #endif 215 | 216 | #if TIGHTBBOX // not used in the paper, but it does help speed. 217 | // the effective extent now depends on the opacity of the Gaussian. 218 | float cutoff = sqrtf(max(9.f + 2.f * logf(opacities[idx]), 0.000001)); 219 | #else 220 | float cutoff = 3.0f; 221 | #endif 222 | 223 | // Compute center and radius 224 | float2 point_image; 225 | float radius; 226 | { 227 | float2 extent; 228 | bool ok = compute_aabb(T, cutoff, point_image, extent); 229 | if (!ok) return; 230 | radius = ceil(max(max(extent.x, extent.y), cutoff * FilterSize)); 231 | } 232 | 233 | uint2 rect_min, rect_max; 234 | getRect(point_image, radius, rect_min, rect_max, grid); 235 | if ((rect_max.x - rect_min.x) * (rect_max.y - rect_min.y) == 0) 236 | return; 237 | 238 | // Compute colors 239 | if (colors_precomp == nullptr) { 240 | glm::vec3 result = computeColorFromSH(idx, D, M, (glm::vec3*)orig_points, *cam_pos, shs, clamped); 241 | rgb[idx * C + 0] = result.x; 242 | rgb[idx * C + 1] = result.y; 243 | rgb[idx * C + 2] = result.z; 244 | } 245 | 246 | depths[idx] = p_view.z; 247 | radii[idx] = (int)radius; 248 | points_xy_image[idx] = point_image; 249 | normal_opacity[idx] = {normal.x, normal.y, normal.z, opacities[idx]}; 250 | tiles_touched[idx] = (rect_max.y - rect_min.y) * (rect_max.x - rect_min.x); 251 | } 252 | 253 | // Main rasterization method. Collaboratively works on one tile per 254 | // block, each thread treats one pixel. Alternates between fetching 255 | // and rasterizing data.
256 | template <uint32_t CHANNELS> 257 | __global__ void __launch_bounds__(BLOCK_X * BLOCK_Y) 258 | renderCUDA( 259 | const uint2* __restrict__ ranges, 260 | const uint32_t* __restrict__ point_list, 261 | int W, int H, 262 | float focal_x, float focal_y, 263 | const float2* __restrict__ points_xy_image, 264 | const float* __restrict__ features, 265 | const float* __restrict__ transMats, 266 | const float* __restrict__ depths, 267 | const float4* __restrict__ normal_opacity, 268 | float* __restrict__ final_T, 269 | uint32_t* __restrict__ n_contrib, 270 | const float* __restrict__ bg_color, 271 | float* __restrict__ out_color, 272 | float* __restrict__ out_others) 273 | { 274 | // Identify current tile and associated min/max pixel range. 275 | auto block = cg::this_thread_block(); 276 | uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X; 277 | uint2 pix_min = { block.group_index().x * BLOCK_X, block.group_index().y * BLOCK_Y }; 278 | uint2 pix_max = { min(pix_min.x + BLOCK_X, W), min(pix_min.y + BLOCK_Y , H) }; 279 | uint2 pix = { pix_min.x + block.thread_index().x, pix_min.y + block.thread_index().y }; 280 | uint32_t pix_id = W * pix.y + pix.x; 281 | float2 pixf = { (float)pix.x, (float)pix.y}; 282 | 283 | // Check if this thread is associated with a valid pixel or outside. 284 | bool inside = pix.x < W && pix.y < H; 285 | // Done threads can help with fetching, but don't rasterize 286 | bool done = !inside; 287 | 288 | // Load start/end range of IDs to process in bit sorted list. 289 | uint2 range = ranges[block.group_index().y * horizontal_blocks + block.group_index().x]; 290 | const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE); 291 | int toDo = range.y - range.x; 292 | 293 | // Allocate storage for batches of collectively fetched data.
294 | __shared__ int collected_id[BLOCK_SIZE]; 295 | __shared__ float2 collected_xy[BLOCK_SIZE]; 296 | __shared__ float4 collected_normal_opacity[BLOCK_SIZE]; 297 | __shared__ float3 collected_Tu[BLOCK_SIZE]; 298 | __shared__ float3 collected_Tv[BLOCK_SIZE]; 299 | __shared__ float3 collected_Tw[BLOCK_SIZE]; 300 | 301 | // Initialize helper variables 302 | float T = 1.0f; 303 | uint32_t contributor = 0; 304 | uint32_t last_contributor = 0; 305 | float C[CHANNELS] = { 0 }; 306 | 307 | 308 | #if RENDER_AXUTILITY 309 | // render auxiliary output 310 | float N[3] = {0}; 311 | float D = { 0 }; 312 | float M1 = {0}; 313 | float M2 = {0}; 314 | float distortion = {0}; 315 | float median_depth = {0}; 316 | // float median_weight = {0}; 317 | float median_contributor = {-1}; 318 | 319 | #endif 320 | 321 | // Iterate over batches until all done or range is complete 322 | for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE) 323 | { 324 | // End if entire block votes that it is done rasterizing 325 | int num_done = __syncthreads_count(done); 326 | if (num_done == BLOCK_SIZE) 327 | break; 328 | 329 | // Collectively fetch per-Gaussian data from global to shared 330 | int progress = i * BLOCK_SIZE + block.thread_rank(); 331 | if (range.x + progress < range.y) 332 | { 333 | int coll_id = point_list[range.x + progress]; 334 | collected_id[block.thread_rank()] = coll_id; 335 | collected_xy[block.thread_rank()] = points_xy_image[coll_id]; 336 | collected_normal_opacity[block.thread_rank()] = normal_opacity[coll_id]; 337 | collected_Tu[block.thread_rank()] = {transMats[9 * coll_id+0], transMats[9 * coll_id+1], transMats[9 * coll_id+2]}; 338 | collected_Tv[block.thread_rank()] = {transMats[9 * coll_id+3], transMats[9 * coll_id+4], transMats[9 * coll_id+5]}; 339 | collected_Tw[block.thread_rank()] = {transMats[9 * coll_id+6], transMats[9 * coll_id+7], transMats[9 * coll_id+8]}; 340 | } 341 | block.sync(); 342 | 343 | // Iterate over current batch 344 | for (int j = 0; !done && j <
min(BLOCK_SIZE, toDo); j++) 345 | { 346 | // Keep track of current position in range 347 | contributor++; 348 | 349 | // First compute two homogeneous planes, See Eq. (8) 350 | const float2 xy = collected_xy[j]; 351 | const float3 Tu = collected_Tu[j]; 352 | const float3 Tv = collected_Tv[j]; 353 | const float3 Tw = collected_Tw[j]; 354 | // Transform the two planes into local u-v system. 355 | float3 k = pix.x * Tw - Tu; 356 | float3 l = pix.y * Tw - Tv; 357 | // Cross product of two planes is a line, Eq. (9) 358 | float3 p = cross(k, l); 359 | if (p.z == 0.0) continue; 360 | // Perspective division to get the intersection (u,v), Eq. (10) 361 | float2 s = {p.x / p.z, p.y / p.z}; 362 | float rho3d = (s.x * s.x + s.y * s.y); 363 | // Add low pass filter 364 | float2 d = {xy.x - pixf.x, xy.y - pixf.y}; 365 | float rho2d = FilterInvSquare * (d.x * d.x + d.y * d.y); 366 | float rho = min(rho3d, rho2d); 367 | 368 | // compute depth 369 | float depth = (s.x * Tw.x + s.y * Tw.y) + Tw.z; 370 | // if a point is too small, its depth is not reliable? 371 | // depth = (rho3d <= rho2d) ? depth : Tw.z 372 | if (depth < near_n) continue; 373 | 374 | float4 nor_o = collected_normal_opacity[j]; 375 | float normal[3] = {nor_o.x, nor_o.y, nor_o.z}; 376 | float opa = nor_o.w; 377 | 378 | float power = -0.5f * rho; 379 | if (power > 0.0f) 380 | continue; 381 | 382 | // Eq. (2) from 3D Gaussian splatting paper. 383 | // Obtain alpha by multiplying with Gaussian opacity 384 | // and its exponential falloff from mean. 385 | // Avoid numerical instabilities (see paper appendix). 386 | float alpha = min(0.99f, opa * exp(power)); 387 | if (alpha < 1.0f / 255.0f) 388 | continue; 389 | float test_T = T * (1 - alpha); 390 | if (test_T < 0.0001f) 391 | { 392 | done = true; 393 | continue; 394 | } 395 | 396 | float w = alpha * T; 397 | #if RENDER_AXUTILITY 398 | // Render depth distortion map 399 | // Efficient implementation of distortion loss, see 2DGS' paper appendix.
400 | float A = 1-T; 401 | float m = far_n / (far_n - near_n) * (1 - near_n / depth); 402 | distortion += (m * m * A + M2 - 2 * m * M1) * w; 403 | D += depth * w; 404 | M1 += m * w; 405 | M2 += m * m * w; 406 | 407 | if (T > 0.5) { 408 | median_depth = depth; 409 | // median_weight = w; 410 | median_contributor = contributor; 411 | } 412 | // Render normal map 413 | for (int ch=0; ch<3; ch++) N[ch] += normal[ch] * w; 414 | #endif 415 | 416 | // Eq. (3) from 3D Gaussian splatting paper. 417 | for (int ch = 0; ch < CHANNELS; ch++) 418 | C[ch] += features[collected_id[j] * CHANNELS + ch] * w; 419 | T = test_T; 420 | 421 | // Keep track of last range entry to update this 422 | // pixel. 423 | last_contributor = contributor; 424 | } 425 | } 426 | 427 | // All threads that treat valid pixel write out their final 428 | // rendering data to the frame and auxiliary buffers. 429 | if (inside) 430 | { 431 | final_T[pix_id] = T; 432 | n_contrib[pix_id] = last_contributor; 433 | for (int ch = 0; ch < CHANNELS; ch++) 434 | out_color[ch * H * W + pix_id] = C[ch] + T * bg_color[ch]; 435 | 436 | #if RENDER_AXUTILITY 437 | n_contrib[pix_id + H * W] = median_contributor; 438 | final_T[pix_id + H * W] = M1; 439 | final_T[pix_id + 2 * H * W] = M2; 440 | out_others[pix_id + DEPTH_OFFSET * H * W] = D; 441 | out_others[pix_id + ALPHA_OFFSET * H * W] = 1 - T; 442 | for (int ch=0; ch<3; ch++) out_others[pix_id + (NORMAL_OFFSET+ch) * H * W] = N[ch]; 443 | out_others[pix_id + MIDDEPTH_OFFSET * H * W] = median_depth; 444 | out_others[pix_id + DISTORTION_OFFSET * H * W] = distortion; 445 | // out_others[pix_id + MEDIAN_WEIGHT_OFFSET * H * W] = median_weight; 446 | #endif 447 | } 448 | } 449 | 450 | void FORWARD::render( 451 | const dim3 grid, dim3 block, 452 | const uint2* ranges, 453 | const uint32_t* point_list, 454 | int W, int H, 455 | float focal_x, float focal_y, 456 | const float2* means2D, 457 | const float* colors, 458 | const float* transMats, 459 | const float* depths, 460 | const 
float4* normal_opacity, 461 | float* final_T, 462 | uint32_t* n_contrib, 463 | const float* bg_color, 464 | float* out_color, 465 | float* out_others) 466 | { 467 | renderCUDA<NUM_CHANNELS> << <grid, block >> > ( 468 | ranges, 469 | point_list, 470 | W, H, 471 | focal_x, focal_y, 472 | means2D, 473 | colors, 474 | transMats, 475 | depths, 476 | normal_opacity, 477 | final_T, 478 | n_contrib, 479 | bg_color, 480 | out_color, 481 | out_others); 482 | } 483 | 484 | void FORWARD::preprocess(int P, int D, int M, 485 | const float* means3D, 486 | const glm::vec2* scales, 487 | const float scale_modifier, 488 | const glm::vec4* rotations, 489 | const float* opacities, 490 | const float* shs, 491 | bool* clamped, 492 | const float* transMat_precomp, 493 | const float* colors_precomp, 494 | const float* viewmatrix, 495 | const float* projmatrix, 496 | const glm::vec3* cam_pos, 497 | const int W, const int H, 498 | const float focal_x, const float focal_y, 499 | const float tan_fovx, const float tan_fovy, 500 | int* radii, 501 | float2* means2D, 502 | float* depths, 503 | float* transMats, 504 | float* rgb, 505 | float4* normal_opacity, 506 | const dim3 grid, 507 | uint32_t* tiles_touched, 508 | bool prefiltered) 509 | { 510 | preprocessCUDA<NUM_CHANNELS> << <(P + 255) / 256, 256 >> > ( 511 | P, D, M, 512 | means3D, 513 | scales, 514 | scale_modifier, 515 | rotations, 516 | opacities, 517 | shs, 518 | clamped, 519 | transMat_precomp, 520 | colors_precomp, 521 | viewmatrix, 522 | projmatrix, 523 | cam_pos, 524 | W, H, 525 | tan_fovx, tan_fovy, 526 | focal_x, focal_y, 527 | radii, 528 | means2D, 529 | depths, 530 | transMats, 531 | rgb, 532 | normal_opacity, 533 | grid, 534 | tiles_touched, 535 | prefiltered 536 | ); 537 | } 538 | -------------------------------------------------------------------------------- /cuda_rasterizer/forward.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_FORWARD_H_INCLUDED 13 | #define CUDA_RASTERIZER_FORWARD_H_INCLUDED 14 | 15 | #include 16 | #include "cuda_runtime.h" 17 | #include "device_launch_parameters.h" 18 | #define GLM_FORCE_CUDA 19 | #include 20 | 21 | namespace FORWARD 22 | { 23 | // Perform initial steps for each Gaussian prior to rasterization. 24 | void preprocess(int P, int D, int M, 25 | const float* orig_points, 26 | const glm::vec2* scales, 27 | const float scale_modifier, 28 | const glm::vec4* rotations, 29 | const float* opacities, 30 | const float* shs, 31 | bool* clamped, 32 | const float* transMat_precomp, 33 | const float* colors_precomp, 34 | const float* viewmatrix, 35 | const float* projmatrix, 36 | const glm::vec3* cam_pos, 37 | const int W, int H, 38 | const float focal_x, float focal_y, 39 | const float tan_fovx, float tan_fovy, 40 | int* radii, 41 | float2* points_xy_image, 42 | float* depths, 43 | // float* isovals, 44 | // float3* normals, 45 | float* transMats, 46 | float* colors, 47 | float4* normal_opacity, 48 | const dim3 grid, 49 | uint32_t* tiles_touched, 50 | bool prefiltered); 51 | 52 | // Main rasterization method. 
53 | void render( 54 | const dim3 grid, dim3 block, 55 | const uint2* ranges, 56 | const uint32_t* point_list, 57 | int W, int H, 58 | float focal_x, float focal_y, 59 | const float2* points_xy_image, 60 | const float* features, 61 | const float* transMats, 62 | const float* depths, 63 | const float4* normal_opacity, 64 | float* final_T, 65 | uint32_t* n_contrib, 66 | const float* bg_color, 67 | float* out_color, 68 | float* out_others); 69 | } 70 | 71 | 72 | #endif 73 | -------------------------------------------------------------------------------- /cuda_rasterizer/rasterizer.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_H_INCLUDED 13 | #define CUDA_RASTERIZER_H_INCLUDED 14 | 15 | #include 16 | #include 17 | 18 | namespace CudaRasterizer 19 | { 20 | class Rasterizer 21 | { 22 | public: 23 | 24 | static void markVisible( 25 | int P, 26 | float* means3D, 27 | float* viewmatrix, 28 | float* projmatrix, 29 | bool* present); 30 | 31 | static int forward( 32 | std::function geometryBuffer, 33 | std::function binningBuffer, 34 | std::function imageBuffer, 35 | const int P, int D, int M, 36 | const float* background, 37 | const int width, int height, 38 | const float* means3D, 39 | const float* shs, 40 | const float* colors_precomp, 41 | const float* opacities, 42 | const float* scales, 43 | const float scale_modifier, 44 | const float* rotations, 45 | const float* transMat_precomp, 46 | const float* viewmatrix, 47 | const float* projmatrix, 48 | const float* cam_pos, 49 | const float tan_fovx, float tan_fovy, 50 | const bool prefiltered, 51 | float* out_color, 52 | float* 
out_others, 53 | int* radii = nullptr, 54 | bool debug = false); 55 | 56 | static void backward( 57 | const int P, int D, int M, int R, 58 | const float* background, 59 | const int width, int height, 60 | const float* means3D, 61 | const float* shs, 62 | const float* colors_precomp, 63 | const float* scales, 64 | const float scale_modifier, 65 | const float* rotations, 66 | const float* transMat_precomp, 67 | const float* viewmatrix, 68 | const float* projmatrix, 69 | const float* campos, 70 | const float tan_fovx, float tan_fovy, 71 | const int* radii, 72 | char* geom_buffer, 73 | char* binning_buffer, 74 | char* image_buffer, 75 | const float* dL_dpix, 76 | const float* dL_depths, 77 | float* dL_dmean2D, 78 | float* dL_dnormal, 79 | float* dL_dopacity, 80 | float* dL_dcolor, 81 | float* dL_dmean3D, 82 | float* dL_dtransMat, 83 | float* dL_dsh, 84 | float* dL_dscale, 85 | float* dL_drot, 86 | bool debug); 87 | }; 88 | }; 89 | 90 | #endif 91 | -------------------------------------------------------------------------------- /cuda_rasterizer/rasterizer_impl.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include "rasterizer_impl.h" 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include "cuda_runtime.h" 19 | #include "device_launch_parameters.h" 20 | #include 21 | #include 22 | #define GLM_FORCE_CUDA 23 | #include 24 | 25 | #include 26 | #include 27 | namespace cg = cooperative_groups; 28 | 29 | #include "auxiliary.h" 30 | #include "forward.h" 31 | #include "backward.h" 32 | 33 | // Helper function to find the next-highest bit of the MSB 34 | // on the CPU. 
35 | uint32_t getHigherMsb(uint32_t n) 36 | { 37 | uint32_t msb = sizeof(n) * 4; 38 | uint32_t step = msb; 39 | while (step > 1) 40 | { 41 | step /= 2; 42 | if (n >> msb) 43 | msb += step; 44 | else 45 | msb -= step; 46 | } 47 | if (n >> msb) 48 | msb++; 49 | return msb; 50 | } 51 | 52 | // Wrapper method to call auxiliary coarse frustum containment test. 53 | // Mark all Gaussians that pass it. 54 | __global__ void checkFrustum(int P, 55 | const float* orig_points, 56 | const float* viewmatrix, 57 | const float* projmatrix, 58 | bool* present) 59 | { 60 | auto idx = cg::this_grid().thread_rank(); 61 | if (idx >= P) 62 | return; 63 | 64 | float3 p_view; 65 | present[idx] = in_frustum(idx, orig_points, viewmatrix, projmatrix, false, p_view); 66 | } 67 | 68 | // Generates one key/value pair for all Gaussian / tile overlaps. 69 | // Run once per Gaussian (1:N mapping). 70 | __global__ void duplicateWithKeys( 71 | int P, 72 | const float2* points_xy, 73 | const float* depths, 74 | const uint32_t* offsets, 75 | uint64_t* gaussian_keys_unsorted, 76 | uint32_t* gaussian_values_unsorted, 77 | int* radii, 78 | dim3 grid) 79 | { 80 | auto idx = cg::this_grid().thread_rank(); 81 | if (idx >= P) 82 | return; 83 | 84 | // Generate no key/value pair for invisible Gaussians 85 | if (radii[idx] > 0) 86 | { 87 | // Find this Gaussian's offset in buffer for writing keys/values. 88 | uint32_t off = (idx == 0) ? 0 : offsets[idx - 1]; 89 | uint2 rect_min, rect_max; 90 | 91 | getRect(points_xy[idx], radii[idx], rect_min, rect_max, grid); 92 | 93 | // For each tile that the bounding rect overlaps, emit a 94 | // key/value pair. The key is | tile ID | depth |, 95 | // and the value is the ID of the Gaussian. Sorting the values 96 | // with this key yields Gaussian IDs in a list, such that they 97 | // are first sorted by tile and then by depth. 
98 | for (int y = rect_min.y; y < rect_max.y; y++) 99 | { 100 | for (int x = rect_min.x; x < rect_max.x; x++) 101 | { 102 | uint64_t key = y * grid.x + x; 103 | key <<= 32; 104 | key |= *((uint32_t*)&depths[idx]); 105 | gaussian_keys_unsorted[off] = key; 106 | gaussian_values_unsorted[off] = idx; 107 | off++; 108 | } 109 | } 110 | } 111 | } 112 | 113 | // Check keys to see if it is at the start/end of one tile's range in 114 | // the full sorted list. If yes, write start/end of this tile. 115 | // Run once per instanced (duplicated) Gaussian ID. 116 | __global__ void identifyTileRanges(int L, uint64_t* point_list_keys, uint2* ranges) 117 | { 118 | auto idx = cg::this_grid().thread_rank(); 119 | if (idx >= L) 120 | return; 121 | 122 | // Read tile ID from key. Update start/end of tile range if at limit. 123 | uint64_t key = point_list_keys[idx]; 124 | uint32_t currtile = key >> 32; 125 | if (idx == 0) 126 | ranges[currtile].x = 0; 127 | else 128 | { 129 | uint32_t prevtile = point_list_keys[idx - 1] >> 32; 130 | if (currtile != prevtile) 131 | { 132 | ranges[prevtile].y = idx; 133 | ranges[currtile].x = idx; 134 | } 135 | } 136 | if (idx == L - 1) 137 | ranges[currtile].y = L; 138 | } 139 | 140 | // Mark Gaussians as visible/invisible, based on view frustum testing 141 | void CudaRasterizer::Rasterizer::markVisible( 142 | int P, 143 | float* means3D, 144 | float* viewmatrix, 145 | float* projmatrix, 146 | bool* present) 147 | { 148 | checkFrustum << <(P + 255) / 256, 256 >> > ( 149 | P, 150 | means3D, 151 | viewmatrix, projmatrix, 152 | present); 153 | } 154 | 155 | CudaRasterizer::GeometryState CudaRasterizer::GeometryState::fromChunk(char*& chunk, size_t P) 156 | { 157 | GeometryState geom; 158 | obtain(chunk, geom.depths, P, 128); 159 | obtain(chunk, geom.clamped, P * 3, 128); 160 | obtain(chunk, geom.internal_radii, P, 128); 161 | obtain(chunk, geom.means2D, P, 128); 162 | obtain(chunk, geom.transMat, P * 9, 128); 163 | obtain(chunk, geom.normal_opacity, P, 
128); 164 | obtain(chunk, geom.rgb, P * 3, 128); 165 | obtain(chunk, geom.tiles_touched, P, 128); 166 | cub::DeviceScan::InclusiveSum(nullptr, geom.scan_size, geom.tiles_touched, geom.tiles_touched, P); 167 | obtain(chunk, geom.scanning_space, geom.scan_size, 128); 168 | obtain(chunk, geom.point_offsets, P, 128); 169 | return geom; 170 | } 171 | 172 | CudaRasterizer::ImageState CudaRasterizer::ImageState::fromChunk(char*& chunk, size_t N) 173 | { 174 | ImageState img; 175 | obtain(chunk, img.accum_alpha, N * 3, 128); 176 | obtain(chunk, img.n_contrib, N * 2, 128); 177 | obtain(chunk, img.ranges, N, 128); 178 | return img; 179 | } 180 | 181 | CudaRasterizer::BinningState CudaRasterizer::BinningState::fromChunk(char*& chunk, size_t P) 182 | { 183 | BinningState binning; 184 | obtain(chunk, binning.point_list, P, 128); 185 | obtain(chunk, binning.point_list_unsorted, P, 128); 186 | obtain(chunk, binning.point_list_keys, P, 128); 187 | obtain(chunk, binning.point_list_keys_unsorted, P, 128); 188 | cub::DeviceRadixSort::SortPairs( 189 | nullptr, binning.sorting_size, 190 | binning.point_list_keys_unsorted, binning.point_list_keys, 191 | binning.point_list_unsorted, binning.point_list, P); 192 | obtain(chunk, binning.list_sorting_space, binning.sorting_size, 128); 193 | return binning; 194 | } 195 | 196 | // Forward rendering procedure for differentiable rasterization 197 | // of Gaussians. 
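The `fromChunk` helpers above carve typed arrays out of one flat byte buffer with `obtain`, which rounds the cursor up to a 128-byte boundary before each array; running the same walk from a null base pointer (as `required<T>` does) measures the total allocation size in a dry run. A simplified Python sketch of that sizing trick, using offsets instead of pointers and only a few illustrative fields:

```python
def obtain(offset, count, itemsize, alignment=128):
    """Align the cursor, reserve count*itemsize bytes, return (start, new cursor)."""
    start = (offset + alignment - 1) & ~(alignment - 1)
    return start, start + count * itemsize

# Dry run mimicking a few fields of GeometryState::fromChunk for P points:
P = 1000
cursor = 0
for count, itemsize in [(P, 4), (P * 3, 1), (P, 4)]:  # e.g. depths, clamped, radii
    start, cursor = obtain(cursor, count, itemsize)
total = cursor + 128  # required() pads by one extra alignment block
```

Because sizing and allocation share the same code path, the layout can never drift out of sync with the size estimate.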
198 | int CudaRasterizer::Rasterizer::forward( 199 | std::function geometryBuffer, 200 | std::function binningBuffer, 201 | std::function imageBuffer, 202 | const int P, int D, int M, 203 | const float* background, 204 | const int width, int height, 205 | const float* means3D, 206 | const float* shs, 207 | const float* colors_precomp, 208 | const float* opacities, 209 | const float* scales, 210 | const float scale_modifier, 211 | const float* rotations, 212 | const float* transMat_precomp, 213 | const float* viewmatrix, 214 | const float* projmatrix, 215 | const float* cam_pos, 216 | const float tan_fovx, float tan_fovy, 217 | const bool prefiltered, 218 | float* out_color, 219 | float* out_others, 220 | int* radii, 221 | bool debug) 222 | { 223 | const float focal_y = height / (2.0f * tan_fovy); 224 | const float focal_x = width / (2.0f * tan_fovx); 225 | 226 | size_t chunk_size = required(P); 227 | char* chunkptr = geometryBuffer(chunk_size); 228 | GeometryState geomState = GeometryState::fromChunk(chunkptr, P); 229 | 230 | if (radii == nullptr) 231 | { 232 | radii = geomState.internal_radii; 233 | } 234 | 235 | dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1); 236 | dim3 block(BLOCK_X, BLOCK_Y, 1); 237 | 238 | // Dynamically resize image-based auxiliary buffers during training 239 | size_t img_chunk_size = required(width * height); 240 | char* img_chunkptr = imageBuffer(img_chunk_size); 241 | ImageState imgState = ImageState::fromChunk(img_chunkptr, width * height); 242 | 243 | if (NUM_CHANNELS != 3 && colors_precomp == nullptr) 244 | { 245 | throw std::runtime_error("For non-RGB, provide precomputed Gaussian colors!"); 246 | } 247 | 248 | // Run preprocessing per-Gaussian (transformation, bounding, conversion of SHs to RGB) 249 | CHECK_CUDA(FORWARD::preprocess( 250 | P, D, M, 251 | means3D, 252 | (glm::vec2*)scales, 253 | scale_modifier, 254 | (glm::vec4*)rotations, 255 | opacities, 256 | shs, 257 | geomState.clamped, 258 | 
transMat_precomp, 259 | colors_precomp, 260 | viewmatrix, projmatrix, 261 | (glm::vec3*)cam_pos, 262 | width, height, 263 | focal_x, focal_y, 264 | tan_fovx, tan_fovy, 265 | radii, 266 | geomState.means2D, 267 | geomState.depths, 268 | geomState.transMat, 269 | geomState.rgb, 270 | geomState.normal_opacity, 271 | tile_grid, 272 | geomState.tiles_touched, 273 | prefiltered 274 | ), debug) 275 | 276 | // Compute prefix sum over full list of touched tile counts by Gaussians 277 | // E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8] 278 | CHECK_CUDA(cub::DeviceScan::InclusiveSum(geomState.scanning_space, geomState.scan_size, geomState.tiles_touched, geomState.point_offsets, P), debug) 279 | 280 | // Retrieve total number of Gaussian instances to launch and resize aux buffers 281 | int num_rendered; 282 | CHECK_CUDA(cudaMemcpy(&num_rendered, geomState.point_offsets + P - 1, sizeof(int), cudaMemcpyDeviceToHost), debug); 283 | 284 | size_t binning_chunk_size = required<BinningState>(num_rendered); 285 | char* binning_chunkptr = binningBuffer(binning_chunk_size); 286 | BinningState binningState = BinningState::fromChunk(binning_chunkptr, num_rendered); 287 | 288 | // For each instance to be rendered, produce adequate [ tile | depth ] key 289 | // and corresponding duplicated Gaussian indices to be sorted 290 | duplicateWithKeys << <(P + 255) / 256, 256 >> > ( 291 | P, 292 | geomState.means2D, 293 | geomState.depths, 294 | geomState.point_offsets, 295 | binningState.point_list_keys_unsorted, 296 | binningState.point_list_unsorted, 297 | radii, 298 | tile_grid) 299 | CHECK_CUDA(, debug) 300 | 301 | int bit = getHigherMsb(tile_grid.x * tile_grid.y); 302 | 303 | // Sort complete list of (duplicated) Gaussian indices by keys 304 | CHECK_CUDA(cub::DeviceRadixSort::SortPairs( 305 | binningState.list_sorting_space, 306 | binningState.sorting_size, 307 | binningState.point_list_keys_unsorted, binningState.point_list_keys, 308 | binningState.point_list_unsorted, binningState.point_list, 309 |
num_rendered, 0, 32 + bit), debug) 310 | 311 | CHECK_CUDA(cudaMemset(imgState.ranges, 0, tile_grid.x * tile_grid.y * sizeof(uint2)), debug); 312 | 313 | // Identify start and end of per-tile workloads in sorted list 314 | if (num_rendered > 0) 315 | identifyTileRanges << <(num_rendered + 255) / 256, 256 >> > ( 316 | num_rendered, 317 | binningState.point_list_keys, 318 | imgState.ranges); 319 | CHECK_CUDA(, debug) 320 | 321 | // Let each tile blend its range of Gaussians independently in parallel 322 | const float* feature_ptr = colors_precomp != nullptr ? colors_precomp : geomState.rgb; 323 | const float* transMat_ptr = transMat_precomp != nullptr ? transMat_precomp : geomState.transMat; 324 | CHECK_CUDA(FORWARD::render( 325 | tile_grid, block, 326 | imgState.ranges, 327 | binningState.point_list, 328 | width, height, 329 | focal_x, focal_y, 330 | geomState.means2D, 331 | feature_ptr, 332 | transMat_ptr, 333 | geomState.depths, 334 | geomState.normal_opacity, 335 | imgState.accum_alpha, 336 | imgState.n_contrib, 337 | background, 338 | out_color, 339 | out_others), debug) 340 | 341 | return num_rendered; 342 | } 343 | 344 | // Produce necessary gradients for optimization, corresponding 345 | // to forward render pass 346 | void CudaRasterizer::Rasterizer::backward( 347 | const int P, int D, int M, int R, 348 | const float* background, 349 | const int width, int height, 350 | const float* means3D, 351 | const float* shs, 352 | const float* colors_precomp, 353 | const float* scales, 354 | const float scale_modifier, 355 | const float* rotations, 356 | const float* transMat_precomp, 357 | const float* viewmatrix, 358 | const float* projmatrix, 359 | const float* campos, 360 | const float tan_fovx, float tan_fovy, 361 | const int* radii, 362 | char* geom_buffer, 363 | char* binning_buffer, 364 | char* img_buffer, 365 | const float* dL_dpix, 366 | const float* dL_depths, 367 | float* dL_dmean2D, 368 | float* dL_dnormal, 369 | float* dL_dopacity, 370 | float* dL_dcolor, 
371 | float* dL_dmean3D, 372 | float* dL_dtransMat, 373 | float* dL_dsh, 374 | float* dL_dscale, 375 | float* dL_drot, 376 | bool debug) 377 | { 378 | GeometryState geomState = GeometryState::fromChunk(geom_buffer, P); 379 | BinningState binningState = BinningState::fromChunk(binning_buffer, R); 380 | ImageState imgState = ImageState::fromChunk(img_buffer, width * height); 381 | 382 | if (radii == nullptr) 383 | { 384 | radii = geomState.internal_radii; 385 | } 386 | 387 | const float focal_y = height / (2.0f * tan_fovy); 388 | const float focal_x = width / (2.0f * tan_fovx); 389 | 390 | const dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1); 391 | const dim3 block(BLOCK_X, BLOCK_Y, 1); 392 | 393 | // Compute loss gradients w.r.t. 2D mean position, conic matrix, 394 | // opacity and RGB of Gaussians from per-pixel loss gradients. 395 | // If we were given precomputed colors and not SHs, use them. 396 | const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : geomState.rgb; 397 | const float* depth_ptr = geomState.depths; 398 | const float* transMat_ptr = (transMat_precomp != nullptr) ? transMat_precomp : geomState.transMat; 399 | CHECK_CUDA(BACKWARD::render( 400 | tile_grid, 401 | block, 402 | imgState.ranges, 403 | binningState.point_list, 404 | width, height, 405 | focal_x, focal_y, 406 | background, 407 | geomState.means2D, 408 | geomState.normal_opacity, 409 | color_ptr, 410 | transMat_ptr, 411 | depth_ptr, 412 | imgState.accum_alpha, 413 | imgState.n_contrib, 414 | dL_dpix, 415 | dL_depths, 416 | dL_dtransMat, 417 | (float3*)dL_dmean2D, 418 | dL_dnormal, 419 | dL_dopacity, 420 | dL_dcolor), debug) 421 | 422 | // Take care of the rest of preprocessing. Was the precomputed covariance 423 | // given to us or a scales/rot pair? If precomputed, pass that. If not, 424 | // use the one we computed ourselves. 425 | // const float* transMat_ptr = (transMat_precomp != nullptr) ? 
transMat_precomp : geomState.transMat; 426 | CHECK_CUDA(BACKWARD::preprocess(P, D, M, 427 | (float3*)means3D, 428 | radii, 429 | shs, 430 | geomState.clamped, 431 | (glm::vec2*)scales, 432 | (glm::vec4*)rotations, 433 | scale_modifier, 434 | transMat_ptr, 435 | viewmatrix, 436 | projmatrix, 437 | focal_x, focal_y, 438 | tan_fovx, tan_fovy, 439 | (glm::vec3*)campos, 440 | (float3*)dL_dmean2D, // gradient inputs 441 | dL_dnormal, // gradient inputs 442 | dL_dtransMat, 443 | dL_dcolor, 444 | dL_dsh, 445 | (glm::vec3*)dL_dmean3D, 446 | (glm::vec2*)dL_dscale, 447 | (glm::vec4*)dL_drot), debug) 448 | } 449 | -------------------------------------------------------------------------------- /cuda_rasterizer/rasterizer_impl.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #pragma once 13 | 14 | #include 15 | #include 16 | #include "rasterizer.h" 17 | #include 18 | 19 | namespace CudaRasterizer 20 | { 21 | template 22 | static void obtain(char*& chunk, T*& ptr, std::size_t count, std::size_t alignment) 23 | { 24 | std::size_t offset = (reinterpret_cast(chunk) + alignment - 1) & ~(alignment - 1); 25 | ptr = reinterpret_cast(offset); 26 | chunk = reinterpret_cast(ptr + count); 27 | } 28 | 29 | struct GeometryState 30 | { 31 | size_t scan_size; 32 | float* depths; 33 | char* scanning_space; 34 | bool* clamped; 35 | int* internal_radii; 36 | float2* means2D; 37 | float* transMat; 38 | float4* normal_opacity; 39 | float* rgb; 40 | uint32_t* point_offsets; 41 | uint32_t* tiles_touched; 42 | 43 | static GeometryState fromChunk(char*& chunk, size_t P); 44 | }; 45 | 46 | struct ImageState 47 | { 48 | uint2* ranges; 49 | uint32_t* n_contrib; 50 | float* accum_alpha; 51 | 52 | static ImageState fromChunk(char*& chunk, size_t N); 53 | }; 54 | 55 | struct BinningState 56 | { 57 | size_t sorting_size; 58 | uint64_t* point_list_keys_unsorted; 59 | uint64_t* point_list_keys; 60 | uint32_t* point_list_unsorted; 61 | uint32_t* point_list; 62 | char* list_sorting_space; 63 | 64 | static BinningState fromChunk(char*& chunk, size_t P); 65 | }; 66 | 67 | template 68 | size_t required(size_t P) 69 | { 70 | char* size = nullptr; 71 | T::fromChunk(size, P); 72 | return ((size_t)size) + 128; 73 | } 74 | }; -------------------------------------------------------------------------------- /diff_surfel_rasterization/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from typing import NamedTuple 13 | import torch.nn as nn 14 | import torch 15 | from . import _C 16 | 17 | def cpu_deep_copy_tuple(input_tuple): 18 | copied_tensors = [item.cpu().clone() if isinstance(item, torch.Tensor) else item for item in input_tuple] 19 | return tuple(copied_tensors) 20 | 21 | def rasterize_gaussians( 22 | means3D, 23 | means2D, 24 | sh, 25 | colors_precomp, 26 | opacities, 27 | scales, 28 | rotations, 29 | cov3Ds_precomp, 30 | raster_settings, 31 | ): 32 | return _RasterizeGaussians.apply( 33 | means3D, 34 | means2D, 35 | sh, 36 | colors_precomp, 37 | opacities, 38 | scales, 39 | rotations, 40 | cov3Ds_precomp, 41 | raster_settings, 42 | ) 43 | 44 | class _RasterizeGaussians(torch.autograd.Function): 45 | @staticmethod 46 | def forward( 47 | ctx, 48 | means3D, 49 | means2D, 50 | sh, 51 | colors_precomp, 52 | opacities, 53 | scales, 54 | rotations, 55 | cov3Ds_precomp, 56 | raster_settings, 57 | ): 58 | 59 | # Restructure arguments the way that the C++ lib expects them 60 | args = ( 61 | raster_settings.bg, 62 | means3D, 63 | colors_precomp, 64 | opacities, 65 | scales, 66 | rotations, 67 | raster_settings.scale_modifier, 68 | cov3Ds_precomp, 69 | raster_settings.viewmatrix, 70 | raster_settings.projmatrix, 71 | raster_settings.tanfovx, 72 | raster_settings.tanfovy, 73 | raster_settings.image_height, 74 | raster_settings.image_width, 75 | sh, 76 | raster_settings.sh_degree, 77 | raster_settings.campos, 78 | raster_settings.prefiltered, 79 | raster_settings.debug 80 | ) 81 | 82 | # Invoke C++/CUDA rasterizer 83 | if raster_settings.debug: 84 | cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted 85 | try: 86 | num_rendered, color, depth, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args) 87 | except Exception as ex: 88 | torch.save(cpu_args, "snapshot_fw.dump") 89 | print("\nAn error occurred in forward. 
Please forward snapshot_fw.dump for debugging.") 90 | raise ex 91 | else: 92 | num_rendered, color, depth, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args) 93 | 94 | # Keep relevant tensors for backward 95 | ctx.raster_settings = raster_settings 96 | ctx.num_rendered = num_rendered 97 | ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer) 98 | return color, radii, depth 99 | 100 | @staticmethod 101 | def backward(ctx, grad_out_color, grad_radii, grad_depth): 102 | 103 | # Restore necessary values from context 104 | num_rendered = ctx.num_rendered 105 | raster_settings = ctx.raster_settings 106 | colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer = ctx.saved_tensors 107 | 108 | # Restructure args as C++ method expects them 109 | args = (raster_settings.bg, 110 | means3D, 111 | radii, 112 | colors_precomp, 113 | scales, 114 | rotations, 115 | raster_settings.scale_modifier, 116 | cov3Ds_precomp, 117 | raster_settings.viewmatrix, 118 | raster_settings.projmatrix, 119 | raster_settings.tanfovx, 120 | raster_settings.tanfovy, 121 | grad_out_color, 122 | grad_depth, 123 | sh, 124 | raster_settings.sh_degree, 125 | raster_settings.campos, 126 | geomBuffer, 127 | num_rendered, 128 | binningBuffer, 129 | imgBuffer, 130 | raster_settings.debug) 131 | 132 | # Compute gradients for relevant tensors by invoking backward method 133 | if raster_settings.debug: 134 | cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted 135 | try: 136 | grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args) 137 | except Exception as ex: 138 | torch.save(cpu_args, "snapshot_bw.dump") 139 | print("\nAn error occurred in backward. 
Writing snapshot_bw.dump for debugging.\n") 140 | raise ex 141 | else: 142 | grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args) 143 | 144 | grads = ( 145 | grad_means3D, 146 | grad_means2D, 147 | grad_sh, 148 | grad_colors_precomp, 149 | grad_opacities, 150 | grad_scales, 151 | grad_rotations, 152 | grad_cov3Ds_precomp, 153 | None, 154 | ) 155 | 156 | return grads 157 | 158 | class GaussianRasterizationSettings(NamedTuple): 159 | image_height: int 160 | image_width: int 161 | tanfovx : float 162 | tanfovy : float 163 | bg : torch.Tensor 164 | scale_modifier : float 165 | viewmatrix : torch.Tensor 166 | projmatrix : torch.Tensor 167 | sh_degree : int 168 | campos : torch.Tensor 169 | prefiltered : bool 170 | debug : bool 171 | 172 | class GaussianRasterizer(nn.Module): 173 | def __init__(self, raster_settings): 174 | super().__init__() 175 | self.raster_settings = raster_settings 176 | 177 | def markVisible(self, positions): 178 | # Mark visible points (based on frustum culling for camera) with a boolean 179 | with torch.no_grad(): 180 | raster_settings = self.raster_settings 181 | visible = _C.mark_visible( 182 | positions, 183 | raster_settings.viewmatrix, 184 | raster_settings.projmatrix) 185 | 186 | return visible 187 | 188 | def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None): 189 | 190 | raster_settings = self.raster_settings 191 | 192 | if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None): 193 | raise Exception('Please provide exactly one of either SHs or precomputed colors!') 194 | 195 | if ((scales is None or rotations is None) and cov3D_precomp is None) or ((scales is not None or rotations is not None) and cov3D_precomp is not None): 196 | raise Exception('Please provide exactly one of either scale/rotation pair or
precomputed 3D covariance!') 197 | 198 | if shs is None: 199 | shs = torch.Tensor([]).cuda() 200 | if colors_precomp is None: 201 | colors_precomp = torch.Tensor([]).cuda() 202 | 203 | if scales is None: 204 | scales = torch.Tensor([]).cuda() 205 | if rotations is None: 206 | rotations = torch.Tensor([]).cuda() 207 | if cov3D_precomp is None: 208 | cov3D_precomp = torch.Tensor([]).cuda() 209 | 210 | 211 | # Invoke C++/CUDA rasterization routine 212 | return rasterize_gaussians( 213 | means3D, 214 | means2D, 215 | shs, 216 | colors_precomp, 217 | opacities, 218 | scales, 219 | rotations, 220 | cov3D_precomp, 221 | raster_settings, 222 | ) 223 | 224 | -------------------------------------------------------------------------------- /ext.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include 13 | #include "rasterize_points.h" 14 | 15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 16 | m.def("rasterize_gaussians", &RasterizeGaussiansCUDA); 17 | m.def("rasterize_gaussians_backward", &RasterizeGaussiansBackwardCUDA); 18 | m.def("mark_visible", &markVisible); 19 | } -------------------------------------------------------------------------------- /rasterize_points.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include "cuda_rasterizer/config.h" 22 | #include "cuda_rasterizer/rasterizer.h" 23 | #include 24 | #include 25 | #include 26 | 27 | #define CHECK_INPUT(x) \ 28 | AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 29 | // AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 30 | 31 | std::function resizeFunctional(torch::Tensor& t) { 32 | auto lambda = [&t](size_t N) { 33 | t.resize_({(long long)N}); 34 | return reinterpret_cast(t.contiguous().data_ptr()); 35 | }; 36 | return lambda; 37 | } 38 | 39 | std::tuple 40 | RasterizeGaussiansCUDA( 41 | const torch::Tensor& background, 42 | const torch::Tensor& means3D, 43 | const torch::Tensor& colors, 44 | const torch::Tensor& opacity, 45 | const torch::Tensor& scales, 46 | const torch::Tensor& rotations, 47 | const float scale_modifier, 48 | const torch::Tensor& transMat_precomp, 49 | const torch::Tensor& viewmatrix, 50 | const torch::Tensor& projmatrix, 51 | const float tan_fovx, 52 | const float tan_fovy, 53 | const int image_height, 54 | const int image_width, 55 | const torch::Tensor& sh, 56 | const int degree, 57 | const torch::Tensor& campos, 58 | const bool prefiltered, 59 | const bool debug) 60 | { 61 | if (means3D.ndimension() != 2 || means3D.size(1) != 3) { 62 | AT_ERROR("means3D must have dimensions (num_points, 3)"); 63 | } 64 | 65 | 66 | const int P = means3D.size(0); 67 | const int H = image_height; 68 | const int W = image_width; 69 | 70 | CHECK_INPUT(background); 71 | CHECK_INPUT(means3D); 72 | CHECK_INPUT(colors); 73 | CHECK_INPUT(opacity); 74 | CHECK_INPUT(scales); 75 | CHECK_INPUT(rotations); 76 | CHECK_INPUT(transMat_precomp); 77 | CHECK_INPUT(viewmatrix); 78 | CHECK_INPUT(projmatrix); 79 | CHECK_INPUT(sh); 80 | CHECK_INPUT(campos); 81 | 82 | auto int_opts = 
means3D.options().dtype(torch::kInt32); 83 | auto float_opts = means3D.options().dtype(torch::kFloat32); 84 | 85 | torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts); 86 | torch::Tensor out_others = torch::full({3+3+1, H, W}, 0.0, float_opts); 87 | torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32)); 88 | 89 | torch::Device device(torch::kCUDA); 90 | torch::TensorOptions options(torch::kByte); 91 | torch::Tensor geomBuffer = torch::empty({0}, options.device(device)); 92 | torch::Tensor binningBuffer = torch::empty({0}, options.device(device)); 93 | torch::Tensor imgBuffer = torch::empty({0}, options.device(device)); 94 | std::function<char*(size_t)> geomFunc = resizeFunctional(geomBuffer); 95 | std::function<char*(size_t)> binningFunc = resizeFunctional(binningBuffer); 96 | std::function<char*(size_t)> imgFunc = resizeFunctional(imgBuffer); 97 | 98 | int rendered = 0; 99 | if(P != 0) 100 | { 101 | int M = 0; 102 | if(sh.size(0) != 0) 103 | { 104 | M = sh.size(1); 105 | } 106 | 107 | rendered = CudaRasterizer::Rasterizer::forward( 108 | geomFunc, 109 | binningFunc, 110 | imgFunc, 111 | P, degree, M, 112 | background.contiguous().data<float>(), 113 | W, H, 114 | means3D.contiguous().data<float>(), 115 | sh.contiguous().data_ptr<float>(), 116 | colors.contiguous().data<float>(), 117 | opacity.contiguous().data<float>(), 118 | scales.contiguous().data_ptr<float>(), 119 | scale_modifier, 120 | rotations.contiguous().data_ptr<float>(), 121 | transMat_precomp.contiguous().data<float>(), 122 | viewmatrix.contiguous().data<float>(), 123 | projmatrix.contiguous().data<float>(), 124 | campos.contiguous().data<float>(), 125 | tan_fovx, 126 | tan_fovy, 127 | prefiltered, 128 | out_color.contiguous().data<float>(), 129 | out_others.contiguous().data<float>(), 130 | radii.contiguous().data<int>(), 131 | debug); 132 | } 133 | return std::make_tuple(rendered, out_color, out_others, radii, geomBuffer, binningBuffer, imgBuffer); 134 | } 135 | 136 | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> 137 | RasterizeGaussiansBackwardCUDA( 138 | const torch::Tensor& background, 139 | const torch::Tensor& 
means3D, 140 | const torch::Tensor& radii, 141 | const torch::Tensor& colors, 142 | const torch::Tensor& scales, 143 | const torch::Tensor& rotations, 144 | const float scale_modifier, 145 | const torch::Tensor& transMat_precomp, 146 | const torch::Tensor& viewmatrix, 147 | const torch::Tensor& projmatrix, 148 | const float tan_fovx, 149 | const float tan_fovy, 150 | const torch::Tensor& dL_dout_color, 151 | const torch::Tensor& dL_dout_others, 152 | const torch::Tensor& sh, 153 | const int degree, 154 | const torch::Tensor& campos, 155 | const torch::Tensor& geomBuffer, 156 | const int R, 157 | const torch::Tensor& binningBuffer, 158 | const torch::Tensor& imageBuffer, 159 | const bool debug) 160 | { 161 | 162 | CHECK_INPUT(background); 163 | CHECK_INPUT(means3D); 164 | CHECK_INPUT(radii); 165 | CHECK_INPUT(colors); 166 | CHECK_INPUT(scales); 167 | CHECK_INPUT(rotations); 168 | CHECK_INPUT(transMat_precomp); 169 | CHECK_INPUT(viewmatrix); 170 | CHECK_INPUT(projmatrix); 171 | CHECK_INPUT(sh); 172 | CHECK_INPUT(campos); 173 | CHECK_INPUT(binningBuffer); 174 | CHECK_INPUT(imageBuffer); 175 | CHECK_INPUT(geomBuffer); 176 | 177 | const int P = means3D.size(0); 178 | const int H = dL_dout_color.size(1); 179 | const int W = dL_dout_color.size(2); 180 | 181 | int M = 0; 182 | if(sh.size(0) != 0) 183 | { 184 | M = sh.size(1); 185 | } 186 | 187 | torch::Tensor dL_dmeans3D = torch::zeros({P, 3}, means3D.options()); 188 | torch::Tensor dL_dmeans2D = torch::zeros({P, 3}, means3D.options()); 189 | torch::Tensor dL_dcolors = torch::zeros({P, NUM_CHANNELS}, means3D.options()); 190 | torch::Tensor dL_dnormal = torch::zeros({P, 3}, means3D.options()); 191 | torch::Tensor dL_dopacity = torch::zeros({P, 1}, means3D.options()); 192 | torch::Tensor dL_dtransMat = torch::zeros({P, 9}, means3D.options()); 193 | torch::Tensor dL_dsh = torch::zeros({P, M, 3}, means3D.options()); 194 | torch::Tensor dL_dscales = torch::zeros({P, 2}, means3D.options()); 195 | torch::Tensor dL_drotations = 
torch::zeros({P, 4}, means3D.options()); 196 | 197 | if(P != 0) 198 | { 199 | CudaRasterizer::Rasterizer::backward(P, degree, M, R, 200 | background.contiguous().data<float>(), 201 | W, H, 202 | means3D.contiguous().data<float>(), 203 | sh.contiguous().data<float>(), 204 | colors.contiguous().data<float>(), 205 | scales.data_ptr<float>(), 206 | scale_modifier, 207 | rotations.data_ptr<float>(), 208 | transMat_precomp.contiguous().data<float>(), 209 | viewmatrix.contiguous().data<float>(), 210 | projmatrix.contiguous().data<float>(), 211 | campos.contiguous().data<float>(), 212 | tan_fovx, 213 | tan_fovy, 214 | radii.contiguous().data<int>(), 215 | reinterpret_cast<char*>(geomBuffer.contiguous().data_ptr()), 216 | reinterpret_cast<char*>(binningBuffer.contiguous().data_ptr()), 217 | reinterpret_cast<char*>(imageBuffer.contiguous().data_ptr()), 218 | dL_dout_color.contiguous().data<float>(), 219 | dL_dout_others.contiguous().data<float>(), 220 | dL_dmeans2D.contiguous().data<float>(), 221 | dL_dnormal.contiguous().data<float>(), 222 | dL_dopacity.contiguous().data<float>(), 223 | dL_dcolors.contiguous().data<float>(), 224 | dL_dmeans3D.contiguous().data<float>(), 225 | dL_dtransMat.contiguous().data<float>(), 226 | dL_dsh.contiguous().data<float>(), 227 | dL_dscales.contiguous().data<float>(), 228 | dL_drotations.contiguous().data<float>(), 229 | debug); 230 | } 231 | 232 | return std::make_tuple(dL_dmeans2D, dL_dcolors, dL_dopacity, dL_dmeans3D, dL_dtransMat, dL_dsh, dL_dscales, dL_drotations); 233 | } 234 | 235 | torch::Tensor markVisible( 236 | torch::Tensor& means3D, 237 | torch::Tensor& viewmatrix, 238 | torch::Tensor& projmatrix) 239 | { 240 | const int P = means3D.size(0); 241 | 242 | torch::Tensor present = torch::full({P}, false, means3D.options().dtype(at::kBool)); 243 | 244 | if(P != 0) 245 | { 246 | CudaRasterizer::Rasterizer::markVisible(P, 247 | means3D.contiguous().data<float>(), 248 | viewmatrix.contiguous().data<float>(), 249 | projmatrix.contiguous().data<float>(), 250 | present.contiguous().data<bool>()); 251 | } 252 | 253 | return present; 254 | } 255 | -------------------------------------------------------------------------------- 
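The `resizeFunctional` helper in rasterize_points.cu hands the CUDA rasterizer a callback so the scratch buffers (`geomBuffer`, `binningBuffer`, `imgBuffer`) are sized lazily, once the kernel knows how many bytes it needs. A minimal Python analog of that pattern (name hypothetical):

```python
def resize_functional(buf: bytearray):
    # Return a callback that resizes the backing buffer to exactly n bytes
    # and hands its storage back, mirroring the lambda over
    # torch::Tensor::resize_ in resizeFunctional.
    def request(n: int) -> bytearray:
        if len(buf) < n:
            buf.extend(b"\0" * (n - len(buf)))
        else:
            del buf[n:]
        return buf
    return request
```

Keeping the tensor alive on the Python/C++ boundary (and returning it from forward) is what lets the backward pass reuse the same geometry, binning, and image state without recomputing it.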
/rasterize_points.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #pragma once 13 | #include <torch/extension.h> 14 | #include <cstdio> 15 | #include <tuple> 16 | #include <string> 17 | 18 | std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> 19 | RasterizeGaussiansCUDA( 20 | const torch::Tensor& background, 21 | const torch::Tensor& means3D, 22 | const torch::Tensor& colors, 23 | const torch::Tensor& opacity, 24 | const torch::Tensor& scales, 25 | const torch::Tensor& rotations, 26 | const float scale_modifier, 27 | const torch::Tensor& transMat_precomp, 28 | const torch::Tensor& viewmatrix, 29 | const torch::Tensor& projmatrix, 30 | const float tan_fovx, 31 | const float tan_fovy, 32 | const int image_height, 33 | const int image_width, 34 | const torch::Tensor& sh, 35 | const int degree, 36 | const torch::Tensor& campos, 37 | const bool prefiltered, 38 | const bool debug); 39 | 40 | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> 41 | RasterizeGaussiansBackwardCUDA( 42 | const torch::Tensor& background, 43 | const torch::Tensor& means3D, 44 | const torch::Tensor& radii, 45 | const torch::Tensor& colors, 46 | const torch::Tensor& scales, 47 | const torch::Tensor& rotations, 48 | const float scale_modifier, 49 | const torch::Tensor& transMat_precomp, 50 | const torch::Tensor& viewmatrix, 51 | const torch::Tensor& projmatrix, 52 | const float tan_fovx, 53 | const float tan_fovy, 54 | const torch::Tensor& dL_dout_color, 55 | const torch::Tensor& dL_dout_others, 56 | const torch::Tensor& sh, 57 | const int degree, 58 | const torch::Tensor& campos, 59 | const torch::Tensor& geomBuffer, 60 | const int R, 61 | const torch::Tensor& binningBuffer, 62 | const torch::Tensor& imageBuffer, 63 | const bool debug); 64 | 65 | 
torch::Tensor markVisible( 66 | torch::Tensor& means3D, 67 | torch::Tensor& viewmatrix, 68 | torch::Tensor& projmatrix); 69 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from setuptools import setup 13 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 14 | import os 15 | os.path.dirname(os.path.abspath(__file__)) 16 | 17 | setup( 18 | name="diff_surfel_rasterization", 19 | packages=['diff_surfel_rasterization'], 20 | version='0.0.1', 21 | ext_modules=[ 22 | CUDAExtension( 23 | name="diff_surfel_rasterization._C", 24 | sources=[ 25 | "cuda_rasterizer/rasterizer_impl.cu", 26 | "cuda_rasterizer/forward.cu", 27 | "cuda_rasterizer/backward.cu", 28 | "rasterize_points.cu", 29 | "ext.cpp"], 30 | extra_compile_args={"nvcc": ["-I" + os.path.join(os.path.dirname(os.path.abspath(__file__)), "third_party/glm/")]}) 31 | ], 32 | cmdclass={ 33 | 'build_ext': BuildExtension 34 | } 35 | ) 36 | -------------------------------------------------------------------------------- /third_party/stbi_image_write.h: -------------------------------------------------------------------------------- 1 | /* stb_image_write - v1.16 - public domain - http://nothings.org/stb 2 | writes out PNG/BMP/TGA/JPEG/HDR images to C stdio - Sean Barrett 2010-2015 3 | no warranty implied; use at your own risk 4 | 5 | Before #including, 6 | 7 | #define STB_IMAGE_WRITE_IMPLEMENTATION 8 | 9 | in the file that you want to have the implementation. 10 | 11 | Will probably not work correctly with strict-aliasing optimizations. 
12 | 13 | ABOUT: 14 | 15 | This header file is a library for writing images to C stdio or a callback. 16 | 17 | The PNG output is not optimal; it is 20-50% larger than the file 18 | written by a decent optimizing implementation; though providing a custom 19 | zlib compress function (see STBIW_ZLIB_COMPRESS) can mitigate that. 20 | This library is designed for source code compactness and simplicity, 21 | not optimal image file size or run-time performance. 22 | 23 | BUILDING: 24 | 25 | You can #define STBIW_ASSERT(x) before the #include to avoid using assert.h. 26 | You can #define STBIW_MALLOC(), STBIW_REALLOC(), and STBIW_FREE() to replace 27 | malloc,realloc,free. 28 | You can #define STBIW_MEMMOVE() to replace memmove() 29 | You can #define STBIW_ZLIB_COMPRESS to use a custom zlib-style compress function 30 | for PNG compression (instead of the builtin one), it must have the following signature: 31 | unsigned char * my_compress(unsigned char *data, int data_len, int *out_len, int quality); 32 | The returned data will be freed with STBIW_FREE() (free() by default), 33 | so it must be heap allocated with STBIW_MALLOC() (malloc() by default), 34 | 35 | UNICODE: 36 | 37 | If compiling for Windows and you wish to use Unicode filenames, compile 38 | with 39 | #define STBIW_WINDOWS_UTF8 40 | and pass utf8-encoded filenames. Call stbiw_convert_wchar_to_utf8 to convert 41 | Windows wchar_t filenames to utf8. 
42 | 43 | USAGE: 44 | 45 | There are five functions, one for each image file format: 46 | 47 | int stbi_write_png(char const *filename, int w, int h, int comp, const void *data, int stride_in_bytes); 48 | int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data); 49 | int stbi_write_tga(char const *filename, int w, int h, int comp, const void *data); 50 | int stbi_write_jpg(char const *filename, int w, int h, int comp, const void *data, int quality); 51 | int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data); 52 | 53 | void stbi_flip_vertically_on_write(int flag); // flag is non-zero to flip data vertically 54 | 55 | There are also five equivalent functions that use an arbitrary write function. You are 56 | expected to open/close your file-equivalent before and after calling these: 57 | 58 | int stbi_write_png_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data, int stride_in_bytes); 59 | int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data); 60 | int stbi_write_tga_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data); 61 | int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const float *data); 62 | int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality); 63 | 64 | where the callback is: 65 | void stbi_write_func(void *context, void *data, int size); 66 | 67 | You can configure it with these global variables: 68 | int stbi_write_tga_with_rle; // defaults to true; set to 0 to disable RLE 69 | int stbi_write_png_compression_level; // defaults to 8; set to higher for more compression 70 | int stbi_write_force_png_filter; // defaults to -1; set to 0..5 to force a filter mode 71 | 72 | 73 | You can define STBI_WRITE_NO_STDIO to disable the file variant of these 74 | functions, so the library will 
not use stdio.h at all. However, this will 75 | also disable HDR writing, because it requires stdio for formatted output. 76 | 77 | Each function returns 0 on failure and non-0 on success. 78 | 79 | The functions create an image file defined by the parameters. The image 80 | is a rectangle of pixels stored from left-to-right, top-to-bottom. 81 | Each pixel contains 'comp' channels of data stored interleaved with 8-bits 82 | per channel, in the following order: 1=Y, 2=YA, 3=RGB, 4=RGBA. (Y is 83 | monochrome color.) The rectangle is 'w' pixels wide and 'h' pixels tall. 84 | The *data pointer points to the first byte of the top-left-most pixel. 85 | For PNG, "stride_in_bytes" is the distance in bytes from the first byte of 86 | a row of pixels to the first byte of the next row of pixels. 87 | 88 | PNG creates output files with the same number of components as the input. 89 | The BMP format expands Y to RGB in the file format and does not 90 | output alpha. 91 | 92 | PNG supports writing rectangles of data even when the bytes storing rows of 93 | data are not consecutive in memory (e.g. sub-rectangles of a larger image), 94 | by supplying the stride between the beginning of adjacent rows. The other 95 | formats do not. (Thus you cannot write a native-format BMP through the BMP 96 | writer, both because it is in BGR order and because it may have padding 97 | at the end of the line.) 98 | 99 | PNG allows you to set the deflate compression level by setting the global 100 | variable 'stbi_write_png_compression_level' (it defaults to 8). 101 | 102 | HDR expects linear float data. Since the format is always 32-bit rgb(e) 103 | data, alpha (if provided) is discarded, and for monochrome data it is 104 | replicated across all three channels. 105 | 106 | TGA supports RLE or non-RLE compressed data. To use non-RLE-compressed 107 | data, set the global variable 'stbi_write_tga_with_rle' to 0. 108 | 109 | JPEG does ignore alpha channels in input data; quality is between 1 and 100. 
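The interleaved layout described above pins down exactly where each pixel lives in memory. A quick Python sketch of the addressing arithmetic (the explicit stride parameter is only meaningful for PNG; the other formats assume a tight stride):

```python
def pixel_offset(i, j, w, comp, stride_in_bytes=None):
    # Byte offset of pixel (i, j): rows stored top-to-bottom, pixels
    # left-to-right, 'comp' interleaved 8-bit channels per pixel.
    # Without an explicit stride, rows are tightly packed at w * comp.
    stride = w * comp if stride_in_bytes is None else stride_in_bytes
    return j * stride + i * comp
```

Passing a stride larger than `w * comp` is how PNG writing supports sub-rectangles of a larger image, as the text below explains.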
110 | Higher quality looks better but results in a bigger image. 111 | JPEG baseline (no JPEG progressive). 112 | 113 | CREDITS: 114 | 115 | 116 | Sean Barrett - PNG/BMP/TGA 117 | Baldur Karlsson - HDR 118 | Jean-Sebastien Guay - TGA monochrome 119 | Tim Kelsey - misc enhancements 120 | Alan Hickman - TGA RLE 121 | Emmanuel Julien - initial file IO callback implementation 122 | Jon Olick - original jo_jpeg.cpp code 123 | Daniel Gibson - integrate JPEG, allow external zlib 124 | Aarni Koskela - allow choosing PNG filter 125 | 126 | bugfixes: 127 | github:Chribba 128 | Guillaume Chereau 129 | github:jry2 130 | github:romigrou 131 | Sergio Gonzalez 132 | Jonas Karlsson 133 | Filip Wasil 134 | Thatcher Ulrich 135 | github:poppolopoppo 136 | Patrick Boettcher 137 | github:xeekworx 138 | Cap Petschulat 139 | Simon Rodriguez 140 | Ivan Tikhonov 141 | github:ignotion 142 | Adam Schackart 143 | Andrew Kensler 144 | 145 | LICENSE 146 | 147 | See end of file for license information. 148 | 149 | */ 150 | 151 | #ifndef INCLUDE_STB_IMAGE_WRITE_H 152 | #define INCLUDE_STB_IMAGE_WRITE_H 153 | 154 | #include <stdlib.h> 155 | 156 | // if STB_IMAGE_WRITE_STATIC causes problems, try defining STBIWDEF to 'inline' or 'static inline' 157 | #ifndef STBIWDEF 158 | #ifdef STB_IMAGE_WRITE_STATIC 159 | #define STBIWDEF static 160 | #else 161 | #ifdef __cplusplus 162 | #define STBIWDEF extern "C" 163 | #else 164 | #define STBIWDEF extern 165 | #endif 166 | #endif 167 | #endif 168 | 169 | #ifndef STB_IMAGE_WRITE_STATIC // C++ forbids static forward declarations 170 | STBIWDEF int stbi_write_tga_with_rle; 171 | STBIWDEF int stbi_write_png_compression_level; 172 | STBIWDEF int stbi_write_force_png_filter; 173 | #endif 174 | 175 | #ifndef STBI_WRITE_NO_STDIO 176 | STBIWDEF int stbi_write_png(char const *filename, int w, int h, int comp, const void *data, int stride_in_bytes); 177 | STBIWDEF int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data); 178 | STBIWDEF int 
stbi_write_tga(char const *filename, int w, int h, int comp, const void *data); 179 | STBIWDEF int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data); 180 | STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality); 181 | 182 | #ifdef STBIW_WINDOWS_UTF8 183 | STBIWDEF int stbiw_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input); 184 | #endif 185 | #endif 186 | 187 | typedef void stbi_write_func(void *context, void *data, int size); 188 | 189 | STBIWDEF int stbi_write_png_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data, int stride_in_bytes); 190 | STBIWDEF int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data); 191 | STBIWDEF int stbi_write_tga_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data); 192 | STBIWDEF int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const float *data); 193 | STBIWDEF int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality); 194 | 195 | STBIWDEF void stbi_flip_vertically_on_write(int flip_boolean); 196 | 197 | #endif//INCLUDE_STB_IMAGE_WRITE_H 198 | 199 | #ifdef STB_IMAGE_WRITE_IMPLEMENTATION 200 | 201 | #ifdef _WIN32 202 | #ifndef _CRT_SECURE_NO_WARNINGS 203 | #define _CRT_SECURE_NO_WARNINGS 204 | #endif 205 | #ifndef _CRT_NONSTDC_NO_DEPRECATE 206 | #define _CRT_NONSTDC_NO_DEPRECATE 207 | #endif 208 | #endif 209 | 210 | #ifndef STBI_WRITE_NO_STDIO 211 | #include <stdio.h> 212 | #endif // STBI_WRITE_NO_STDIO 213 | 214 | #include <stdarg.h> 215 | #include <stdlib.h> 216 | #include <string.h> 217 | #include <math.h> 218 | 219 | #if defined(STBIW_MALLOC) && defined(STBIW_FREE) && (defined(STBIW_REALLOC) || defined(STBIW_REALLOC_SIZED)) 220 | // ok 221 | #elif !defined(STBIW_MALLOC) && !defined(STBIW_FREE) && !defined(STBIW_REALLOC) && !defined(STBIW_REALLOC_SIZED) 222 | 
// ok 223 | #else 224 | #error "Must define all or none of STBIW_MALLOC, STBIW_FREE, and STBIW_REALLOC (or STBIW_REALLOC_SIZED)." 225 | #endif 226 | 227 | #ifndef STBIW_MALLOC 228 | #define STBIW_MALLOC(sz) malloc(sz) 229 | #define STBIW_REALLOC(p,newsz) realloc(p,newsz) 230 | #define STBIW_FREE(p) free(p) 231 | #endif 232 | 233 | #ifndef STBIW_REALLOC_SIZED 234 | #define STBIW_REALLOC_SIZED(p,oldsz,newsz) STBIW_REALLOC(p,newsz) 235 | #endif 236 | 237 | 238 | #ifndef STBIW_MEMMOVE 239 | #define STBIW_MEMMOVE(a,b,sz) memmove(a,b,sz) 240 | #endif 241 | 242 | 243 | #ifndef STBIW_ASSERT 244 | #include <assert.h> 245 | #define STBIW_ASSERT(x) assert(x) 246 | #endif 247 | 248 | #define STBIW_UCHAR(x) (unsigned char) ((x) & 0xff) 249 | 250 | #ifdef STB_IMAGE_WRITE_STATIC 251 | static int stbi_write_png_compression_level = 8; 252 | static int stbi_write_tga_with_rle = 1; 253 | static int stbi_write_force_png_filter = -1; 254 | #else 255 | int stbi_write_png_compression_level = 8; 256 | int stbi_write_tga_with_rle = 1; 257 | int stbi_write_force_png_filter = -1; 258 | #endif 259 | 260 | static int stbi__flip_vertically_on_write = 0; 261 | 262 | STBIWDEF void stbi_flip_vertically_on_write(int flag) 263 | { 264 | stbi__flip_vertically_on_write = flag; 265 | } 266 | 267 | typedef struct 268 | { 269 | stbi_write_func *func; 270 | void *context; 271 | unsigned char buffer[64]; 272 | int buf_used; 273 | } stbi__write_context; 274 | 275 | // initialize a callback-based context 276 | static void stbi__start_write_callbacks(stbi__write_context *s, stbi_write_func *c, void *context) 277 | { 278 | s->func = c; 279 | s->context = context; 280 | } 281 | 282 | #ifndef STBI_WRITE_NO_STDIO 283 | 284 | static void stbi__stdio_write(void *context, void *data, int size) 285 | { 286 | fwrite(data,1,size,(FILE*) context); 287 | } 288 | 289 | #if defined(_WIN32) && defined(STBIW_WINDOWS_UTF8) 290 | #ifdef __cplusplus 291 | #define STBIW_EXTERN extern "C" 292 | #else 293 | #define STBIW_EXTERN extern 294 | 
#endif 295 | STBIW_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char *str, int cbmb, wchar_t *widestr, int cchwide); 296 | STBIW_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, char *str, int cbmb, const char *defchar, int *used_default); 297 | 298 | STBIWDEF int stbiw_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input) 299 | { 300 | return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, (int) bufferlen, NULL, NULL); 301 | } 302 | #endif 303 | 304 | static FILE *stbiw__fopen(char const *filename, char const *mode) 305 | { 306 | FILE *f; 307 | #if defined(_WIN32) && defined(STBIW_WINDOWS_UTF8) 308 | wchar_t wMode[64]; 309 | wchar_t wFilename[1024]; 310 | if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, sizeof(wFilename)/sizeof(*wFilename))) 311 | return 0; 312 | 313 | if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode)/sizeof(*wMode))) 314 | return 0; 315 | 316 | #if defined(_MSC_VER) && _MSC_VER >= 1400 317 | if (0 != _wfopen_s(&f, wFilename, wMode)) 318 | f = 0; 319 | #else 320 | f = _wfopen(wFilename, wMode); 321 | #endif 322 | 323 | #elif defined(_MSC_VER) && _MSC_VER >= 1400 324 | if (0 != fopen_s(&f, filename, mode)) 325 | f=0; 326 | #else 327 | f = fopen(filename, mode); 328 | #endif 329 | return f; 330 | } 331 | 332 | static int stbi__start_write_file(stbi__write_context *s, const char *filename) 333 | { 334 | FILE *f = stbiw__fopen(filename, "wb"); 335 | stbi__start_write_callbacks(s, stbi__stdio_write, (void *) f); 336 | return f != NULL; 337 | } 338 | 339 | static void stbi__end_write_file(stbi__write_context *s) 340 | { 341 | fclose((FILE *)s->context); 342 | } 343 | 344 | #endif // !STBI_WRITE_NO_STDIO 345 | 346 | typedef unsigned int stbiw_uint32; 347 | typedef int stb_image_write_test[sizeof(stbiw_uint32)==4 ? 
1 : -1]; 348 | 349 | static void stbiw__writefv(stbi__write_context *s, const char *fmt, va_list v) 350 | { 351 | while (*fmt) { 352 | switch (*fmt++) { 353 | case ' ': break; 354 | case '1': { unsigned char x = STBIW_UCHAR(va_arg(v, int)); 355 | s->func(s->context,&x,1); 356 | break; } 357 | case '2': { int x = va_arg(v,int); 358 | unsigned char b[2]; 359 | b[0] = STBIW_UCHAR(x); 360 | b[1] = STBIW_UCHAR(x>>8); 361 | s->func(s->context,b,2); 362 | break; } 363 | case '4': { stbiw_uint32 x = va_arg(v,int); 364 | unsigned char b[4]; 365 | b[0]=STBIW_UCHAR(x); 366 | b[1]=STBIW_UCHAR(x>>8); 367 | b[2]=STBIW_UCHAR(x>>16); 368 | b[3]=STBIW_UCHAR(x>>24); 369 | s->func(s->context,b,4); 370 | break; } 371 | default: 372 | STBIW_ASSERT(0); 373 | return; 374 | } 375 | } 376 | } 377 | 378 | static void stbiw__writef(stbi__write_context *s, const char *fmt, ...) 379 | { 380 | va_list v; 381 | va_start(v, fmt); 382 | stbiw__writefv(s, fmt, v); 383 | va_end(v); 384 | } 385 | 386 | static void stbiw__write_flush(stbi__write_context *s) 387 | { 388 | if (s->buf_used) { 389 | s->func(s->context, &s->buffer, s->buf_used); 390 | s->buf_used = 0; 391 | } 392 | } 393 | 394 | static void stbiw__putc(stbi__write_context *s, unsigned char c) 395 | { 396 | s->func(s->context, &c, 1); 397 | } 398 | 399 | static void stbiw__write1(stbi__write_context *s, unsigned char a) 400 | { 401 | if ((size_t)s->buf_used + 1 > sizeof(s->buffer)) 402 | stbiw__write_flush(s); 403 | s->buffer[s->buf_used++] = a; 404 | } 405 | 406 | static void stbiw__write3(stbi__write_context *s, unsigned char a, unsigned char b, unsigned char c) 407 | { 408 | int n; 409 | if ((size_t)s->buf_used + 3 > sizeof(s->buffer)) 410 | stbiw__write_flush(s); 411 | n = s->buf_used; 412 | s->buf_used = n+3; 413 | s->buffer[n+0] = a; 414 | s->buffer[n+1] = b; 415 | s->buffer[n+2] = c; 416 | } 417 | 418 | static void stbiw__write_pixel(stbi__write_context *s, int rgb_dir, int comp, int write_alpha, int expand_mono, unsigned char *d) 
419 | { 420 | unsigned char bg[3] = { 255, 0, 255}, px[3]; 421 | int k; 422 | 423 | if (write_alpha < 0) 424 | stbiw__write1(s, d[comp - 1]); 425 | 426 | switch (comp) { 427 | case 2: // 2 pixels = mono + alpha, alpha is written separately, so same as 1-channel case 428 | case 1: 429 | if (expand_mono) 430 | stbiw__write3(s, d[0], d[0], d[0]); // monochrome bmp 431 | else 432 | stbiw__write1(s, d[0]); // monochrome TGA 433 | break; 434 | case 4: 435 | if (!write_alpha) { 436 | // composite against pink background 437 | for (k = 0; k < 3; ++k) 438 | px[k] = bg[k] + ((d[k] - bg[k]) * d[3]) / 255; 439 | stbiw__write3(s, px[1 - rgb_dir], px[1], px[1 + rgb_dir]); 440 | break; 441 | } 442 | /* FALLTHROUGH */ 443 | case 3: 444 | stbiw__write3(s, d[1 - rgb_dir], d[1], d[1 + rgb_dir]); 445 | break; 446 | } 447 | if (write_alpha > 0) 448 | stbiw__write1(s, d[comp - 1]); 449 | } 450 | 451 | static void stbiw__write_pixels(stbi__write_context *s, int rgb_dir, int vdir, int x, int y, int comp, void *data, int write_alpha, int scanline_pad, int expand_mono) 452 | { 453 | stbiw_uint32 zero = 0; 454 | int i,j, j_end; 455 | 456 | if (y <= 0) 457 | return; 458 | 459 | if (stbi__flip_vertically_on_write) 460 | vdir *= -1; 461 | 462 | if (vdir < 0) { 463 | j_end = -1; j = y-1; 464 | } else { 465 | j_end = y; j = 0; 466 | } 467 | 468 | for (; j != j_end; j += vdir) { 469 | for (i=0; i < x; ++i) { 470 | unsigned char *d = (unsigned char *) data + (j*x+i)*comp; 471 | stbiw__write_pixel(s, rgb_dir, comp, write_alpha, expand_mono, d); 472 | } 473 | stbiw__write_flush(s); 474 | s->func(s->context, &zero, scanline_pad); 475 | } 476 | } 477 | 478 | static int stbiw__outfile(stbi__write_context *s, int rgb_dir, int vdir, int x, int y, int comp, int expand_mono, void *data, int alpha, int pad, const char *fmt, ...) 
479 | { 480 | if (y < 0 || x < 0) { 481 | return 0; 482 | } else { 483 | va_list v; 484 | va_start(v, fmt); 485 | stbiw__writefv(s, fmt, v); 486 | va_end(v); 487 | stbiw__write_pixels(s,rgb_dir,vdir,x,y,comp,data,alpha,pad, expand_mono); 488 | return 1; 489 | } 490 | } 491 | 492 | static int stbi_write_bmp_core(stbi__write_context *s, int x, int y, int comp, const void *data) 493 | { 494 | if (comp != 4) { 495 | // write RGB bitmap 496 | int pad = (-x*3) & 3; 497 | return stbiw__outfile(s,-1,-1,x,y,comp,1,(void *) data,0,pad, 498 | "11 4 22 4" "4 44 22 444444", 499 | 'B', 'M', 14+40+(x*3+pad)*y, 0,0, 14+40, // file header 500 | 40, x,y, 1,24, 0,0,0,0,0,0); // bitmap header 501 | } else { 502 | // RGBA bitmaps need a v4 header 503 | // use BI_BITFIELDS mode with 32bpp and alpha mask 504 | // (straight BI_RGB with alpha mask doesn't work in most readers) 505 | return stbiw__outfile(s,-1,-1,x,y,comp,1,(void *)data,1,0, 506 | "11 4 22 4" "4 44 22 444444 4444 4 444 444 444 444", 507 | 'B', 'M', 14+108+x*y*4, 0, 0, 14+108, // file header 508 | 108, x,y, 1,32, 3,0,0,0,0,0, 0xff0000,0xff00,0xff,0xff000000u, 0, 0,0,0, 0,0,0, 0,0,0, 0,0,0); // bitmap V4 header 509 | } 510 | } 511 | 512 | STBIWDEF int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data) 513 | { 514 | stbi__write_context s = { 0 }; 515 | stbi__start_write_callbacks(&s, func, context); 516 | return stbi_write_bmp_core(&s, x, y, comp, data); 517 | } 518 | 519 | #ifndef STBI_WRITE_NO_STDIO 520 | STBIWDEF int stbi_write_bmp(char const *filename, int x, int y, int comp, const void *data) 521 | { 522 | stbi__write_context s = { 0 }; 523 | if (stbi__start_write_file(&s,filename)) { 524 | int r = stbi_write_bmp_core(&s, x, y, comp, data); 525 | stbi__end_write_file(&s); 526 | return r; 527 | } else 528 | return 0; 529 | } 530 | #endif //!STBI_WRITE_NO_STDIO 531 | 532 | static int stbi_write_tga_core(stbi__write_context *s, int x, int y, int comp, void *data) 533 | { 
534 | int has_alpha = (comp == 2 || comp == 4); 535 | int colorbytes = has_alpha ? comp-1 : comp; 536 | int format = colorbytes < 2 ? 3 : 2; // 3 color channels (RGB/RGBA) = 2, 1 color channel (Y/YA) = 3 537 | 538 | if (y < 0 || x < 0) 539 | return 0; 540 | 541 | if (!stbi_write_tga_with_rle) { 542 | return stbiw__outfile(s, -1, -1, x, y, comp, 0, (void *) data, has_alpha, 0, 543 | "111 221 2222 11", 0, 0, format, 0, 0, 0, 0, 0, x, y, (colorbytes + has_alpha) * 8, has_alpha * 8); 544 | } else { 545 | int i,j,k; 546 | int jend, jdir; 547 | 548 | stbiw__writef(s, "111 221 2222 11", 0,0,format+8, 0,0,0, 0,0,x,y, (colorbytes + has_alpha) * 8, has_alpha * 8); 549 | 550 | if (stbi__flip_vertically_on_write) { 551 | j = 0; 552 | jend = y; 553 | jdir = 1; 554 | } else { 555 | j = y-1; 556 | jend = -1; 557 | jdir = -1; 558 | } 559 | for (; j != jend; j += jdir) { 560 | unsigned char *row = (unsigned char *) data + j * x * comp; 561 | int len; 562 | 563 | for (i = 0; i < x; i += len) { 564 | unsigned char *begin = row + i * comp; 565 | int diff = 1; 566 | len = 1; 567 | 568 | if (i < x - 1) { 569 | ++len; 570 | diff = memcmp(begin, row + (i + 1) * comp, comp); 571 | if (diff) { 572 | const unsigned char *prev = begin; 573 | for (k = i + 2; k < x && len < 128; ++k) { 574 | if (memcmp(prev, row + k * comp, comp)) { 575 | prev += comp; 576 | ++len; 577 | } else { 578 | --len; 579 | break; 580 | } 581 | } 582 | } else { 583 | for (k = i + 2; k < x && len < 128; ++k) { 584 | if (!memcmp(begin, row + k * comp, comp)) { 585 | ++len; 586 | } else { 587 | break; 588 | } 589 | } 590 | } 591 | } 592 | 593 | if (diff) { 594 | unsigned char header = STBIW_UCHAR(len - 1); 595 | stbiw__write1(s, header); 596 | for (k = 0; k < len; ++k) { 597 | stbiw__write_pixel(s, -1, comp, has_alpha, 0, begin + k * comp); 598 | } 599 | } else { 600 | unsigned char header = STBIW_UCHAR(len - 129); 601 | stbiw__write1(s, header); 602 | stbiw__write_pixel(s, -1, comp, has_alpha, 0, begin); 603 | } 604 | } 
605 | } 606 | stbiw__write_flush(s); 607 | } 608 | return 1; 609 | } 610 | 611 | STBIWDEF int stbi_write_tga_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data) 612 | { 613 | stbi__write_context s = { 0 }; 614 | stbi__start_write_callbacks(&s, func, context); 615 | return stbi_write_tga_core(&s, x, y, comp, (void *) data); 616 | } 617 | 618 | #ifndef STBI_WRITE_NO_STDIO 619 | STBIWDEF int stbi_write_tga(char const *filename, int x, int y, int comp, const void *data) 620 | { 621 | stbi__write_context s = { 0 }; 622 | if (stbi__start_write_file(&s,filename)) { 623 | int r = stbi_write_tga_core(&s, x, y, comp, (void *) data); 624 | stbi__end_write_file(&s); 625 | return r; 626 | } else 627 | return 0; 628 | } 629 | #endif 630 | 631 | // ************************************************************************************************* 632 | // Radiance RGBE HDR writer 633 | // by Baldur Karlsson 634 | 635 | #define stbiw__max(a, b) ((a) > (b) ? (a) : (b)) 636 | 637 | #ifndef STBI_WRITE_NO_STDIO 638 | 639 | static void stbiw__linear_to_rgbe(unsigned char *rgbe, float *linear) 640 | { 641 | int exponent; 642 | float maxcomp = stbiw__max(linear[0], stbiw__max(linear[1], linear[2])); 643 | 644 | if (maxcomp < 1e-32f) { 645 | rgbe[0] = rgbe[1] = rgbe[2] = rgbe[3] = 0; 646 | } else { 647 | float normalize = (float) frexp(maxcomp, &exponent) * 256.0f/maxcomp; 648 | 649 | rgbe[0] = (unsigned char)(linear[0] * normalize); 650 | rgbe[1] = (unsigned char)(linear[1] * normalize); 651 | rgbe[2] = (unsigned char)(linear[2] * normalize); 652 | rgbe[3] = (unsigned char)(exponent + 128); 653 | } 654 | } 655 | 656 | static void stbiw__write_run_data(stbi__write_context *s, int length, unsigned char databyte) 657 | { 658 | unsigned char lengthbyte = STBIW_UCHAR(length+128); 659 | STBIW_ASSERT(length+128 <= 255); 660 | s->func(s->context, &lengthbyte, 1); 661 | s->func(s->context, &databyte, 1); 662 | } 663 | 664 | static void 
stbiw__write_dump_data(stbi__write_context *s, int length, unsigned char *data) 665 | { 666 | unsigned char lengthbyte = STBIW_UCHAR(length); 667 | STBIW_ASSERT(length <= 128); // inconsistent with spec but consistent with official code 668 | s->func(s->context, &lengthbyte, 1); 669 | s->func(s->context, data, length); 670 | } 671 | 672 | static void stbiw__write_hdr_scanline(stbi__write_context *s, int width, int ncomp, unsigned char *scratch, float *scanline) 673 | { 674 | unsigned char scanlineheader[4] = { 2, 2, 0, 0 }; 675 | unsigned char rgbe[4]; 676 | float linear[3]; 677 | int x; 678 | 679 | scanlineheader[2] = (width&0xff00)>>8; 680 | scanlineheader[3] = (width&0x00ff); 681 | 682 | /* skip RLE for images too small or large */ 683 | if (width < 8 || width >= 32768) { 684 | for (x=0; x < width; x++) { 685 | switch (ncomp) { 686 | case 4: /* fallthrough */ 687 | case 3: linear[2] = scanline[x*ncomp + 2]; 688 | linear[1] = scanline[x*ncomp + 1]; 689 | linear[0] = scanline[x*ncomp + 0]; 690 | break; 691 | default: 692 | linear[0] = linear[1] = linear[2] = scanline[x*ncomp + 0]; 693 | break; 694 | } 695 | stbiw__linear_to_rgbe(rgbe, linear); 696 | s->func(s->context, rgbe, 4); 697 | } 698 | } else { 699 | int c,r; 700 | /* encode into scratch buffer */ 701 | for (x=0; x < width; x++) { 702 | switch(ncomp) { 703 | case 4: /* fallthrough */ 704 | case 3: linear[2] = scanline[x*ncomp + 2]; 705 | linear[1] = scanline[x*ncomp + 1]; 706 | linear[0] = scanline[x*ncomp + 0]; 707 | break; 708 | default: 709 | linear[0] = linear[1] = linear[2] = scanline[x*ncomp + 0]; 710 | break; 711 | } 712 | stbiw__linear_to_rgbe(rgbe, linear); 713 | scratch[x + width*0] = rgbe[0]; 714 | scratch[x + width*1] = rgbe[1]; 715 | scratch[x + width*2] = rgbe[2]; 716 | scratch[x + width*3] = rgbe[3]; 717 | } 718 | 719 | s->func(s->context, scanlineheader, 4); 720 | 721 | /* RLE each component separately */ 722 | for (c=0; c < 4; c++) { 723 | unsigned char *comp = &scratch[width*c]; 724 | 725 
| x = 0; 726 | while (x < width) { 727 | // find first run 728 | r = x; 729 | while (r+2 < width) { 730 | if (comp[r] == comp[r+1] && comp[r] == comp[r+2]) 731 | break; 732 | ++r; 733 | } 734 | if (r+2 >= width) 735 | r = width; 736 | // dump up to first run 737 | while (x < r) { 738 | int len = r-x; 739 | if (len > 128) len = 128; 740 | stbiw__write_dump_data(s, len, &comp[x]); 741 | x += len; 742 | } 743 | // if there's a run, output it 744 | if (r+2 < width) { // same test as what we break out of in search loop, so only true if we break'd 745 | // find next byte after run 746 | while (r < width && comp[r] == comp[x]) 747 | ++r; 748 | // output run up to r 749 | while (x < r) { 750 | int len = r-x; 751 | if (len > 127) len = 127; 752 | stbiw__write_run_data(s, len, comp[x]); 753 | x += len; 754 | } 755 | } 756 | } 757 | } 758 | } 759 | } 760 | 761 | static int stbi_write_hdr_core(stbi__write_context *s, int x, int y, int comp, float *data) 762 | { 763 | if (y <= 0 || x <= 0 || data == NULL) 764 | return 0; 765 | else { 766 | // Each component is stored separately. Allocate scratch space for full output scanline. 767 | unsigned char *scratch = (unsigned char *) STBIW_MALLOC(x*4); 768 | int i, len; 769 | char buffer[128]; 770 | char header[] = "#?RADIANCE\n# Written by stb_image_write.h\nFORMAT=32-bit_rle_rgbe\n"; 771 | s->func(s->context, header, sizeof(header)-1); 772 | 773 | #ifdef __STDC_LIB_EXT1__ 774 | len = sprintf_s(buffer, sizeof(buffer), "EXPOSURE= 1.0000000000000\n\n-Y %d +X %d\n", y, x); 775 | #else 776 | len = sprintf(buffer, "EXPOSURE= 1.0000000000000\n\n-Y %d +X %d\n", y, x); 777 | #endif 778 | s->func(s->context, buffer, len); 779 | 780 | for(i=0; i < y; i++) 781 | stbiw__write_hdr_scanline(s, x, comp, scratch, data + comp*x*(stbi__flip_vertically_on_write ? 
y-1-i : i)); 782 | STBIW_FREE(scratch); 783 | return 1; 784 | } 785 | } 786 | 787 | STBIWDEF int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const float *data) 788 | { 789 | stbi__write_context s = { 0 }; 790 | stbi__start_write_callbacks(&s, func, context); 791 | return stbi_write_hdr_core(&s, x, y, comp, (float *) data); 792 | } 793 | 794 | STBIWDEF int stbi_write_hdr(char const *filename, int x, int y, int comp, const float *data) 795 | { 796 | stbi__write_context s = { 0 }; 797 | if (stbi__start_write_file(&s,filename)) { 798 | int r = stbi_write_hdr_core(&s, x, y, comp, (float *) data); 799 | stbi__end_write_file(&s); 800 | return r; 801 | } else 802 | return 0; 803 | } 804 | #endif // STBI_WRITE_NO_STDIO 805 | 806 | 807 | ////////////////////////////////////////////////////////////////////////////// 808 | // 809 | // PNG writer 810 | // 811 | 812 | #ifndef STBIW_ZLIB_COMPRESS 813 | // stretchy buffer; stbiw__sbpush() == vector<>::push_back() -- stbiw__sbcount() == vector<>::size() 814 | #define stbiw__sbraw(a) ((int *) (void *) (a) - 2) 815 | #define stbiw__sbm(a) stbiw__sbraw(a)[0] 816 | #define stbiw__sbn(a) stbiw__sbraw(a)[1] 817 | 818 | #define stbiw__sbneedgrow(a,n) ((a)==0 || stbiw__sbn(a)+n >= stbiw__sbm(a)) 819 | #define stbiw__sbmaybegrow(a,n) (stbiw__sbneedgrow(a,(n)) ? stbiw__sbgrow(a,n) : 0) 820 | #define stbiw__sbgrow(a,n) stbiw__sbgrowf((void **) &(a), (n), sizeof(*(a))) 821 | 822 | #define stbiw__sbpush(a, v) (stbiw__sbmaybegrow(a,1), (a)[stbiw__sbn(a)++] = (v)) 823 | #define stbiw__sbcount(a) ((a) ? stbiw__sbn(a) : 0) 824 | #define stbiw__sbfree(a) ((a) ? STBIW_FREE(stbiw__sbraw(a)),0 : 0) 825 | 826 | static void *stbiw__sbgrowf(void **arr, int increment, int itemsize) 827 | { 828 | int m = *arr ? 2*stbiw__sbm(*arr)+increment : increment+1; 829 | void *p = STBIW_REALLOC_SIZED(*arr ? stbiw__sbraw(*arr) : 0, *arr ? 
(stbiw__sbm(*arr)*itemsize + sizeof(int)*2) : 0, itemsize * m + sizeof(int)*2); 830 | STBIW_ASSERT(p); 831 | if (p) { 832 | if (!*arr) ((int *) p)[1] = 0; 833 | *arr = (void *) ((int *) p + 2); 834 | stbiw__sbm(*arr) = m; 835 | } 836 | return *arr; 837 | } 838 | 839 | static unsigned char *stbiw__zlib_flushf(unsigned char *data, unsigned int *bitbuffer, int *bitcount) 840 | { 841 | while (*bitcount >= 8) { 842 | stbiw__sbpush(data, STBIW_UCHAR(*bitbuffer)); 843 | *bitbuffer >>= 8; 844 | *bitcount -= 8; 845 | } 846 | return data; 847 | } 848 | 849 | static int stbiw__zlib_bitrev(int code, int codebits) 850 | { 851 | int res=0; 852 | while (codebits--) { 853 | res = (res << 1) | (code & 1); 854 | code >>= 1; 855 | } 856 | return res; 857 | } 858 | 859 | static unsigned int stbiw__zlib_countm(unsigned char *a, unsigned char *b, int limit) 860 | { 861 | int i; 862 | for (i=0; i < limit && i < 258; ++i) 863 | if (a[i] != b[i]) break; 864 | return i; 865 | } 866 | 867 | static unsigned int stbiw__zhash(unsigned char *data) 868 | { 869 | stbiw_uint32 hash = data[0] + (data[1] << 8) + (data[2] << 16); 870 | hash ^= hash << 3; 871 | hash += hash >> 5; 872 | hash ^= hash << 4; 873 | hash += hash >> 17; 874 | hash ^= hash << 25; 875 | hash += hash >> 6; 876 | return hash; 877 | } 878 | 879 | #define stbiw__zlib_flush() (out = stbiw__zlib_flushf(out, &bitbuf, &bitcount)) 880 | #define stbiw__zlib_add(code,codebits) \ 881 | (bitbuf |= (code) << bitcount, bitcount += (codebits), stbiw__zlib_flush()) 882 | #define stbiw__zlib_huffa(b,c) stbiw__zlib_add(stbiw__zlib_bitrev(b,c),c) 883 | // default huffman tables 884 | #define stbiw__zlib_huff1(n) stbiw__zlib_huffa(0x30 + (n), 8) 885 | #define stbiw__zlib_huff2(n) stbiw__zlib_huffa(0x190 + (n)-144, 9) 886 | #define stbiw__zlib_huff3(n) stbiw__zlib_huffa(0 + (n)-256,7) 887 | #define stbiw__zlib_huff4(n) stbiw__zlib_huffa(0xc0 + (n)-280,8) 888 | #define stbiw__zlib_huff(n) ((n) <= 143 ? stbiw__zlib_huff1(n) : (n) <= 255 ? 
stbiw__zlib_huff2(n) : (n) <= 279 ? stbiw__zlib_huff3(n) : stbiw__zlib_huff4(n)) 889 | #define stbiw__zlib_huffb(n) ((n) <= 143 ? stbiw__zlib_huff1(n) : stbiw__zlib_huff2(n)) 890 | 891 | #define stbiw__ZHASH 16384 892 | 893 | #endif // STBIW_ZLIB_COMPRESS 894 | 895 | STBIWDEF unsigned char * stbi_zlib_compress(unsigned char *data, int data_len, int *out_len, int quality) 896 | { 897 | #ifdef STBIW_ZLIB_COMPRESS 898 | // user provided a zlib compress implementation, use that 899 | return STBIW_ZLIB_COMPRESS(data, data_len, out_len, quality); 900 | #else // use builtin 901 | static unsigned short lengthc[] = { 3,4,5,6,7,8,9,10,11,13,15,17,19,23,27,31,35,43,51,59,67,83,99,115,131,163,195,227,258, 259 }; 902 | static unsigned char lengtheb[]= { 0,0,0,0,0,0,0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0 }; 903 | static unsigned short distc[] = { 1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193,257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577, 32768 }; 904 | static unsigned char disteb[] = { 0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13 }; 905 | unsigned int bitbuf=0; 906 | int i,j, bitcount=0; 907 | unsigned char *out = NULL; 908 | unsigned char ***hash_table = (unsigned char***) STBIW_MALLOC(stbiw__ZHASH * sizeof(unsigned char**)); 909 | if (hash_table == NULL) 910 | return NULL; 911 | if (quality < 5) quality = 5; 912 | 913 | stbiw__sbpush(out, 0x78); // DEFLATE 32K window 914 | stbiw__sbpush(out, 0x5e); // FLEVEL = 1 915 | stbiw__zlib_add(1,1); // BFINAL = 1 916 | stbiw__zlib_add(1,2); // BTYPE = 1 -- fixed huffman 917 | 918 | for (i=0; i < stbiw__ZHASH; ++i) 919 | hash_table[i] = NULL; 920 | 921 | i=0; 922 | while (i < data_len-3) { 923 | // hash next 3 bytes of data to be compressed 924 | int h = stbiw__zhash(data+i)&(stbiw__ZHASH-1), best=3; 925 | unsigned char *bestloc = 0; 926 | unsigned char **hlist = hash_table[h]; 927 | int n = stbiw__sbcount(hlist); 928 | for (j=0; j < n; ++j) { 929 | if (hlist[j]-data 
> i-32768) { // if entry lies within window 930 | int d = stbiw__zlib_countm(hlist[j], data+i, data_len-i); 931 | if (d >= best) { best=d; bestloc=hlist[j]; } 932 | } 933 | } 934 | // when hash table entry is too long, delete half the entries 935 | if (hash_table[h] && stbiw__sbn(hash_table[h]) == 2*quality) { 936 | STBIW_MEMMOVE(hash_table[h], hash_table[h]+quality, sizeof(hash_table[h][0])*quality); 937 | stbiw__sbn(hash_table[h]) = quality; 938 | } 939 | stbiw__sbpush(hash_table[h],data+i); 940 | 941 | if (bestloc) { 942 | // "lazy matching" - check match at *next* byte, and if it's better, do cur byte as literal 943 | h = stbiw__zhash(data+i+1)&(stbiw__ZHASH-1); 944 | hlist = hash_table[h]; 945 | n = stbiw__sbcount(hlist); 946 | for (j=0; j < n; ++j) { 947 | if (hlist[j]-data > i-32767) { 948 | int e = stbiw__zlib_countm(hlist[j], data+i+1, data_len-i-1); 949 | if (e > best) { // if next match is better, bail on current match 950 | bestloc = NULL; 951 | break; 952 | } 953 | } 954 | } 955 | } 956 | 957 | if (bestloc) { 958 | int d = (int) (data+i - bestloc); // distance back 959 | STBIW_ASSERT(d <= 32767 && best <= 258); 960 | for (j=0; best > lengthc[j+1]-1; ++j); 961 | stbiw__zlib_huff(j+257); 962 | if (lengtheb[j]) stbiw__zlib_add(best - lengthc[j], lengtheb[j]); 963 | for (j=0; d > distc[j+1]-1; ++j); 964 | stbiw__zlib_add(stbiw__zlib_bitrev(j,5),5); 965 | if (disteb[j]) stbiw__zlib_add(d - distc[j], disteb[j]); 966 | i += best; 967 | } else { 968 | stbiw__zlib_huffb(data[i]); 969 | ++i; 970 | } 971 | } 972 | // write out final bytes 973 | for (;i < data_len; ++i) 974 | stbiw__zlib_huffb(data[i]); 975 | stbiw__zlib_huff(256); // end of block 976 | // pad with 0 bits to byte boundary 977 | while (bitcount) 978 | stbiw__zlib_add(0,1); 979 | 980 | for (i=0; i < stbiw__ZHASH; ++i) 981 | (void) stbiw__sbfree(hash_table[i]); 982 | STBIW_FREE(hash_table); 983 | 984 | // store uncompressed instead if compression was worse 985 | if (stbiw__sbn(out) > data_len + 2 + 
((data_len+32766)/32767)*5) { 986 | stbiw__sbn(out) = 2; // truncate to DEFLATE 32K window and FLEVEL = 1 987 | for (j = 0; j < data_len;) { 988 | int blocklen = data_len - j; 989 | if (blocklen > 32767) blocklen = 32767; 990 | stbiw__sbpush(out, data_len - j == blocklen); // BFINAL = ?, BTYPE = 0 -- no compression 991 | stbiw__sbpush(out, STBIW_UCHAR(blocklen)); // LEN 992 | stbiw__sbpush(out, STBIW_UCHAR(blocklen >> 8)); 993 | stbiw__sbpush(out, STBIW_UCHAR(~blocklen)); // NLEN 994 | stbiw__sbpush(out, STBIW_UCHAR(~blocklen >> 8)); 995 | memcpy(out+stbiw__sbn(out), data+j, blocklen); 996 | stbiw__sbn(out) += blocklen; 997 | j += blocklen; 998 | } 999 | } 1000 | 1001 | { 1002 | // compute adler32 on input 1003 | unsigned int s1=1, s2=0; 1004 | int blocklen = (int) (data_len % 5552); 1005 | j=0; 1006 | while (j < data_len) { 1007 | for (i=0; i < blocklen; ++i) { s1 += data[j+i]; s2 += s1; } 1008 | s1 %= 65521; s2 %= 65521; 1009 | j += blocklen; 1010 | blocklen = 5552; 1011 | } 1012 | stbiw__sbpush(out, STBIW_UCHAR(s2 >> 8)); 1013 | stbiw__sbpush(out, STBIW_UCHAR(s2)); 1014 | stbiw__sbpush(out, STBIW_UCHAR(s1 >> 8)); 1015 | stbiw__sbpush(out, STBIW_UCHAR(s1)); 1016 | } 1017 | *out_len = stbiw__sbn(out); 1018 | // make returned pointer freeable 1019 | STBIW_MEMMOVE(stbiw__sbraw(out), out, *out_len); 1020 | return (unsigned char *) stbiw__sbraw(out); 1021 | #endif // STBIW_ZLIB_COMPRESS 1022 | } 1023 | 1024 | static unsigned int stbiw__crc32(unsigned char *buffer, int len) 1025 | { 1026 | #ifdef STBIW_CRC32 1027 | return STBIW_CRC32(buffer, len); 1028 | #else 1029 | static unsigned int crc_table[256] = 1030 | { 1031 | 0x00000000, 0x77073096, 0xEE0E612C, 0x990951BA, 0x076DC419, 0x706AF48F, 0xE963A535, 0x9E6495A3, 1032 | 0x0eDB8832, 0x79DCB8A4, 0xE0D5E91E, 0x97D2D988, 0x09B64C2B, 0x7EB17CBD, 0xE7B82D07, 0x90BF1D91, 1033 | 0x1DB71064, 0x6AB020F2, 0xF3B97148, 0x84BE41DE, 0x1ADAD47D, 0x6DDDE4EB, 0xF4D4B551, 0x83D385C7, 1034 | 0x136C9856, 0x646BA8C0, 0xFD62F97A, 0x8A65C9EC, 
0x14015C4F, 0x63066CD9, 0xFA0F3D63, 0x8D080DF5, 1035 | 0x3B6E20C8, 0x4C69105E, 0xD56041E4, 0xA2677172, 0x3C03E4D1, 0x4B04D447, 0xD20D85FD, 0xA50AB56B, 1036 | 0x35B5A8FA, 0x42B2986C, 0xDBBBC9D6, 0xACBCF940, 0x32D86CE3, 0x45DF5C75, 0xDCD60DCF, 0xABD13D59, 1037 | 0x26D930AC, 0x51DE003A, 0xC8D75180, 0xBFD06116, 0x21B4F4B5, 0x56B3C423, 0xCFBA9599, 0xB8BDA50F, 1038 | 0x2802B89E, 0x5F058808, 0xC60CD9B2, 0xB10BE924, 0x2F6F7C87, 0x58684C11, 0xC1611DAB, 0xB6662D3D, 1039 | 0x76DC4190, 0x01DB7106, 0x98D220BC, 0xEFD5102A, 0x71B18589, 0x06B6B51F, 0x9FBFE4A5, 0xE8B8D433, 1040 | 0x7807C9A2, 0x0F00F934, 0x9609A88E, 0xE10E9818, 0x7F6A0DBB, 0x086D3D2D, 0x91646C97, 0xE6635C01, 1041 | 0x6B6B51F4, 0x1C6C6162, 0x856530D8, 0xF262004E, 0x6C0695ED, 0x1B01A57B, 0x8208F4C1, 0xF50FC457, 1042 | 0x65B0D9C6, 0x12B7E950, 0x8BBEB8EA, 0xFCB9887C, 0x62DD1DDF, 0x15DA2D49, 0x8CD37CF3, 0xFBD44C65, 1043 | 0x4DB26158, 0x3AB551CE, 0xA3BC0074, 0xD4BB30E2, 0x4ADFA541, 0x3DD895D7, 0xA4D1C46D, 0xD3D6F4FB, 1044 | 0x4369E96A, 0x346ED9FC, 0xAD678846, 0xDA60B8D0, 0x44042D73, 0x33031DE5, 0xAA0A4C5F, 0xDD0D7CC9, 1045 | 0x5005713C, 0x270241AA, 0xBE0B1010, 0xC90C2086, 0x5768B525, 0x206F85B3, 0xB966D409, 0xCE61E49F, 1046 | 0x5EDEF90E, 0x29D9C998, 0xB0D09822, 0xC7D7A8B4, 0x59B33D17, 0x2EB40D81, 0xB7BD5C3B, 0xC0BA6CAD, 1047 | 0xEDB88320, 0x9ABFB3B6, 0x03B6E20C, 0x74B1D29A, 0xEAD54739, 0x9DD277AF, 0x04DB2615, 0x73DC1683, 1048 | 0xE3630B12, 0x94643B84, 0x0D6D6A3E, 0x7A6A5AA8, 0xE40ECF0B, 0x9309FF9D, 0x0A00AE27, 0x7D079EB1, 1049 | 0xF00F9344, 0x8708A3D2, 0x1E01F268, 0x6906C2FE, 0xF762575D, 0x806567CB, 0x196C3671, 0x6E6B06E7, 1050 | 0xFED41B76, 0x89D32BE0, 0x10DA7A5A, 0x67DD4ACC, 0xF9B9DF6F, 0x8EBEEFF9, 0x17B7BE43, 0x60B08ED5, 1051 | 0xD6D6A3E8, 0xA1D1937E, 0x38D8C2C4, 0x4FDFF252, 0xD1BB67F1, 0xA6BC5767, 0x3FB506DD, 0x48B2364B, 1052 | 0xD80D2BDA, 0xAF0A1B4C, 0x36034AF6, 0x41047A60, 0xDF60EFC3, 0xA867DF55, 0x316E8EEF, 0x4669BE79, 1053 | 0xCB61B38C, 0xBC66831A, 0x256FD2A0, 0x5268E236, 0xCC0C7795, 0xBB0B4703, 0x220216B9, 
0x5505262F, 1054 | 0xC5BA3BBE, 0xB2BD0B28, 0x2BB45A92, 0x5CB36A04, 0xC2D7FFA7, 0xB5D0CF31, 0x2CD99E8B, 0x5BDEAE1D, 1055 | 0x9B64C2B0, 0xEC63F226, 0x756AA39C, 0x026D930A, 0x9C0906A9, 0xEB0E363F, 0x72076785, 0x05005713, 1056 | 0x95BF4A82, 0xE2B87A14, 0x7BB12BAE, 0x0CB61B38, 0x92D28E9B, 0xE5D5BE0D, 0x7CDCEFB7, 0x0BDBDF21, 1057 | 0x86D3D2D4, 0xF1D4E242, 0x68DDB3F8, 0x1FDA836E, 0x81BE16CD, 0xF6B9265B, 0x6FB077E1, 0x18B74777, 1058 | 0x88085AE6, 0xFF0F6A70, 0x66063BCA, 0x11010B5C, 0x8F659EFF, 0xF862AE69, 0x616BFFD3, 0x166CCF45, 1059 | 0xA00AE278, 0xD70DD2EE, 0x4E048354, 0x3903B3C2, 0xA7672661, 0xD06016F7, 0x4969474D, 0x3E6E77DB, 1060 | 0xAED16A4A, 0xD9D65ADC, 0x40DF0B66, 0x37D83BF0, 0xA9BCAE53, 0xDEBB9EC5, 0x47B2CF7F, 0x30B5FFE9, 1061 | 0xBDBDF21C, 0xCABAC28A, 0x53B39330, 0x24B4A3A6, 0xBAD03605, 0xCDD70693, 0x54DE5729, 0x23D967BF, 1062 | 0xB3667A2E, 0xC4614AB8, 0x5D681B02, 0x2A6F2B94, 0xB40BBE37, 0xC30C8EA1, 0x5A05DF1B, 0x2D02EF8D 1063 | }; 1064 | 1065 | unsigned int crc = ~0u; 1066 | int i; 1067 | for (i=0; i < len; ++i) 1068 | crc = (crc >> 8) ^ crc_table[buffer[i] ^ (crc & 0xff)]; 1069 | return ~crc; 1070 | #endif 1071 | } 1072 | 1073 | #define stbiw__wpng4(o,a,b,c,d) ((o)[0]=STBIW_UCHAR(a),(o)[1]=STBIW_UCHAR(b),(o)[2]=STBIW_UCHAR(c),(o)[3]=STBIW_UCHAR(d),(o)+=4) 1074 | #define stbiw__wp32(data,v) stbiw__wpng4(data, (v)>>24,(v)>>16,(v)>>8,(v)); 1075 | #define stbiw__wptag(data,s) stbiw__wpng4(data, s[0],s[1],s[2],s[3]) 1076 | 1077 | static void stbiw__wpcrc(unsigned char **data, int len) 1078 | { 1079 | unsigned int crc = stbiw__crc32(*data - len - 4, len+4); 1080 | stbiw__wp32(*data, crc); 1081 | } 1082 | 1083 | static unsigned char stbiw__paeth(int a, int b, int c) 1084 | { 1085 | int p = a + b - c, pa = abs(p-a), pb = abs(p-b), pc = abs(p-c); 1086 | if (pa <= pb && pa <= pc) return STBIW_UCHAR(a); 1087 | if (pb <= pc) return STBIW_UCHAR(b); 1088 | return STBIW_UCHAR(c); 1089 | } 1090 | 1091 | // @OPTIMIZE: provide an option that always forces left-predict or paeth 
predict 1092 | static void stbiw__encode_png_line(unsigned char *pixels, int stride_bytes, int width, int height, int y, int n, int filter_type, signed char *line_buffer) 1093 | { 1094 | static int mapping[] = { 0,1,2,3,4 }; 1095 | static int firstmap[] = { 0,1,0,5,6 }; 1096 | int *mymap = (y != 0) ? mapping : firstmap; 1097 | int i; 1098 | int type = mymap[filter_type]; 1099 | unsigned char *z = pixels + stride_bytes * (stbi__flip_vertically_on_write ? height-1-y : y); 1100 | int signed_stride = stbi__flip_vertically_on_write ? -stride_bytes : stride_bytes; 1101 | 1102 | if (type==0) { 1103 | memcpy(line_buffer, z, width*n); 1104 | return; 1105 | } 1106 | 1107 | // first loop isn't optimized since it's just one pixel 1108 | for (i = 0; i < n; ++i) { 1109 | switch (type) { 1110 | case 1: line_buffer[i] = z[i]; break; 1111 | case 2: line_buffer[i] = z[i] - z[i-signed_stride]; break; 1112 | case 3: line_buffer[i] = z[i] - (z[i-signed_stride]>>1); break; 1113 | case 4: line_buffer[i] = (signed char) (z[i] - stbiw__paeth(0,z[i-signed_stride],0)); break; 1114 | case 5: line_buffer[i] = z[i]; break; 1115 | case 6: line_buffer[i] = z[i]; break; 1116 | } 1117 | } 1118 | switch (type) { 1119 | case 1: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - z[i-n]; break; 1120 | case 2: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - z[i-signed_stride]; break; 1121 | case 3: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - ((z[i-n] + z[i-signed_stride])>>1); break; 1122 | case 4: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - stbiw__paeth(z[i-n], z[i-signed_stride], z[i-signed_stride-n]); break; 1123 | case 5: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - (z[i-n]>>1); break; 1124 | case 6: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - stbiw__paeth(z[i-n], 0,0); break; 1125 | } 1126 | } 1127 | 1128 | STBIWDEF unsigned char *stbi_write_png_to_mem(const unsigned char *pixels, int stride_bytes, int x, int y, int n, int *out_len) 1129 | { 1130 | int 
force_filter = stbi_write_force_png_filter; 1131 | int ctype[5] = { -1, 0, 4, 2, 6 }; 1132 | unsigned char sig[8] = { 137,80,78,71,13,10,26,10 }; 1133 | unsigned char *out,*o, *filt, *zlib; 1134 | signed char *line_buffer; 1135 | int j,zlen; 1136 | 1137 | if (stride_bytes == 0) 1138 | stride_bytes = x * n; 1139 | 1140 | if (force_filter >= 5) { 1141 | force_filter = -1; 1142 | } 1143 | 1144 | filt = (unsigned char *) STBIW_MALLOC((x*n+1) * y); if (!filt) return 0; 1145 | line_buffer = (signed char *) STBIW_MALLOC(x * n); if (!line_buffer) { STBIW_FREE(filt); return 0; } 1146 | for (j=0; j < y; ++j) { 1147 | int filter_type; 1148 | if (force_filter > -1) { 1149 | filter_type = force_filter; 1150 | stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, force_filter, line_buffer); 1151 | } else { // Estimate the best filter by running through all of them: 1152 | int best_filter = 0, best_filter_val = 0x7fffffff, est, i; 1153 | for (filter_type = 0; filter_type < 5; filter_type++) { 1154 | stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, filter_type, line_buffer); 1155 | 1156 | // Estimate the entropy of the line using this filter; the less, the better. 
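The comment above is the whole filter-selection heuristic: apply each PNG filter to the row, sum the absolute values of the filtered (signed) bytes as a cheap entropy proxy, and keep the filter with the smallest sum. A minimal standalone sketch of that idea — hypothetical helper names, and only the None and Sub filters shown — might look like:

```c
#include <stdlib.h>

/* Hedged illustration, not stb's code: filter 0 (None) keeps the raw
 * bytes, filter 1 (Sub) subtracts the byte n positions back (the
 * previous pixel's corresponding channel). Smaller residual sums tend
 * to compress better under DEFLATE. */
static int residual_sum(const unsigned char *row, int len, int n, int filter)
{
    int i, sum = 0;
    for (i = 0; i < len; ++i) {
        signed char r = (filter == 1 && i >= n)
            ? (signed char)(row[i] - row[i - n])   /* Sub residual */
            : (signed char)row[i];                 /* None: raw byte */
        sum += abs(r);
    }
    return sum;
}

static int pick_filter(const unsigned char *row, int len, int n)
{
    int s0 = residual_sum(row, len, n, 0);
    int s1 = residual_sum(row, len, n, 1);
    return (s1 < s0) ? 1 : 0;   /* index of the cheaper filter */
}
```

On a smooth gradient row the Sub residuals collapse to a small constant, so Sub wins; on noisy data the raw bytes usually score no worse, so None wins.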
1157 | est = 0; 1158 | for (i = 0; i < x*n; ++i) { 1159 | est += abs((signed char) line_buffer[i]); 1160 | } 1161 | if (est < best_filter_val) { 1162 | best_filter_val = est; 1163 | best_filter = filter_type; 1164 | } 1165 | } 1166 | if (filter_type != best_filter) { // If the last iteration already got us the best filter, don't redo it 1167 | stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, best_filter, line_buffer); 1168 | filter_type = best_filter; 1169 | } 1170 | } 1171 | // when we get here, filter_type contains the filter type, and line_buffer contains the data 1172 | filt[j*(x*n+1)] = (unsigned char) filter_type; 1173 | STBIW_MEMMOVE(filt+j*(x*n+1)+1, line_buffer, x*n); 1174 | } 1175 | STBIW_FREE(line_buffer); 1176 | zlib = stbi_zlib_compress(filt, y*( x*n+1), &zlen, stbi_write_png_compression_level); 1177 | STBIW_FREE(filt); 1178 | if (!zlib) return 0; 1179 | 1180 | // each tag requires 12 bytes of overhead 1181 | out = (unsigned char *) STBIW_MALLOC(8 + 12+13 + 12+zlen + 12); 1182 | if (!out) return 0; 1183 | *out_len = 8 + 12+13 + 12+zlen + 12; 1184 | 1185 | o=out; 1186 | STBIW_MEMMOVE(o,sig,8); o+= 8; 1187 | stbiw__wp32(o, 13); // header length 1188 | stbiw__wptag(o, "IHDR"); 1189 | stbiw__wp32(o, x); 1190 | stbiw__wp32(o, y); 1191 | *o++ = 8; 1192 | *o++ = STBIW_UCHAR(ctype[n]); 1193 | *o++ = 0; 1194 | *o++ = 0; 1195 | *o++ = 0; 1196 | stbiw__wpcrc(&o,13); 1197 | 1198 | stbiw__wp32(o, zlen); 1199 | stbiw__wptag(o, "IDAT"); 1200 | STBIW_MEMMOVE(o, zlib, zlen); 1201 | o += zlen; 1202 | STBIW_FREE(zlib); 1203 | stbiw__wpcrc(&o, zlen); 1204 | 1205 | stbiw__wp32(o,0); 1206 | stbiw__wptag(o, "IEND"); 1207 | stbiw__wpcrc(&o,0); 1208 | 1209 | STBIW_ASSERT(o == out + *out_len); 1210 | 1211 | return out; 1212 | } 1213 | 1214 | #ifndef STBI_WRITE_NO_STDIO 1215 | STBIWDEF int stbi_write_png(char const *filename, int x, int y, int comp, const void *data, int stride_bytes) 1216 | { 1217 | FILE *f; 1218 | int len; 1219 | unsigned char *png = 
stbi_write_png_to_mem((const unsigned char *) data, stride_bytes, x, y, comp, &len); 1220 | if (png == NULL) return 0; 1221 | 1222 | f = stbiw__fopen(filename, "wb"); 1223 | if (!f) { STBIW_FREE(png); return 0; } 1224 | fwrite(png, 1, len, f); 1225 | fclose(f); 1226 | STBIW_FREE(png); 1227 | return 1; 1228 | } 1229 | #endif 1230 | 1231 | STBIWDEF int stbi_write_png_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int stride_bytes) 1232 | { 1233 | int len; 1234 | unsigned char *png = stbi_write_png_to_mem((const unsigned char *) data, stride_bytes, x, y, comp, &len); 1235 | if (png == NULL) return 0; 1236 | func(context, png, len); 1237 | STBIW_FREE(png); 1238 | return 1; 1239 | } 1240 | 1241 | 1242 | /* *************************************************************************** 1243 | * 1244 | * JPEG writer 1245 | * 1246 | * This is based on Jon Olick's jo_jpeg.cpp: 1247 | * public domain Simple, Minimalistic JPEG writer - http://www.jonolick.com/code.html 1248 | */ 1249 | 1250 | static const unsigned char stbiw__jpg_ZigZag[] = { 0,1,5,6,14,15,27,28,2,4,7,13,16,26,29,42,3,8,12,17,25,30,41,43,9,11,18, 1251 | 24,31,40,44,53,10,19,23,32,39,45,52,54,20,22,33,38,46,51,55,60,21,34,37,47,50,56,59,61,35,36,48,49,57,58,62,63 }; 1252 | 1253 | static void stbiw__jpg_writeBits(stbi__write_context *s, int *bitBufP, int *bitCntP, const unsigned short *bs) { 1254 | int bitBuf = *bitBufP, bitCnt = *bitCntP; 1255 | bitCnt += bs[1]; 1256 | bitBuf |= bs[0] << (24 - bitCnt); 1257 | while(bitCnt >= 8) { 1258 | unsigned char c = (bitBuf >> 16) & 255; 1259 | stbiw__putc(s, c); 1260 | if(c == 255) { 1261 | stbiw__putc(s, 0); 1262 | } 1263 | bitBuf <<= 8; 1264 | bitCnt -= 8; 1265 | } 1266 | *bitBufP = bitBuf; 1267 | *bitCntP = bitCnt; 1268 | } 1269 | 1270 | static void stbiw__jpg_DCT(float *d0p, float *d1p, float *d2p, float *d3p, float *d4p, float *d5p, float *d6p, float *d7p) { 1271 | float d0 = *d0p, d1 = *d1p, d2 = *d2p, d3 = *d3p, d4 = *d4p, d5 = 
*d5p, d6 = *d6p, d7 = *d7p; 1272 | float z1, z2, z3, z4, z5, z11, z13; 1273 | 1274 | float tmp0 = d0 + d7; 1275 | float tmp7 = d0 - d7; 1276 | float tmp1 = d1 + d6; 1277 | float tmp6 = d1 - d6; 1278 | float tmp2 = d2 + d5; 1279 | float tmp5 = d2 - d5; 1280 | float tmp3 = d3 + d4; 1281 | float tmp4 = d3 - d4; 1282 | 1283 | // Even part 1284 | float tmp10 = tmp0 + tmp3; // phase 2 1285 | float tmp13 = tmp0 - tmp3; 1286 | float tmp11 = tmp1 + tmp2; 1287 | float tmp12 = tmp1 - tmp2; 1288 | 1289 | d0 = tmp10 + tmp11; // phase 3 1290 | d4 = tmp10 - tmp11; 1291 | 1292 | z1 = (tmp12 + tmp13) * 0.707106781f; // c4 1293 | d2 = tmp13 + z1; // phase 5 1294 | d6 = tmp13 - z1; 1295 | 1296 | // Odd part 1297 | tmp10 = tmp4 + tmp5; // phase 2 1298 | tmp11 = tmp5 + tmp6; 1299 | tmp12 = tmp6 + tmp7; 1300 | 1301 | // The rotator is modified from fig 4-8 to avoid extra negations. 1302 | z5 = (tmp10 - tmp12) * 0.382683433f; // c6 1303 | z2 = tmp10 * 0.541196100f + z5; // c2-c6 1304 | z4 = tmp12 * 1.306562965f + z5; // c2+c6 1305 | z3 = tmp11 * 0.707106781f; // c4 1306 | 1307 | z11 = tmp7 + z3; // phase 5 1308 | z13 = tmp7 - z3; 1309 | 1310 | *d5p = z13 + z2; // phase 6 1311 | *d3p = z13 - z2; 1312 | *d1p = z11 + z4; 1313 | *d7p = z11 - z4; 1314 | 1315 | *d0p = d0; *d2p = d2; *d4p = d4; *d6p = d6; 1316 | } 1317 | 1318 | static void stbiw__jpg_calcBits(int val, unsigned short bits[2]) { 1319 | int tmp1 = val < 0 ? -val : val; 1320 | val = val < 0 ? 
val-1 : val; 1321 | bits[1] = 1; 1322 | while(tmp1 >>= 1) { 1323 | ++bits[1]; 1324 | } 1325 | bits[0] = val & ((1<<bits[1])-1); 1326 | } 1327 | 1328 | static int stbiw__jpg_processDU(stbi__write_context *s, int *bitBuf, int *bitCnt, float *CDU, int du_stride, float *fdtbl, int DC, const unsigned short HTDC[256][2], const unsigned short HTAC[256][2]) { 1329 | const unsigned short EOB[2] = { HTAC[0x00][0], HTAC[0x00][1] }; 1330 | const unsigned short M16zeroes[2] = { HTAC[0xF0][0], HTAC[0xF0][1] }; 1331 | int dataOff, i, j, n, diff, end0pos, x, y; 1332 | int DU[64]; 1333 | 1334 | // DCT rows 1335 | for(dataOff=0, n=du_stride*8; dataOff<n; dataOff+=du_stride) { 1336 | stbiw__jpg_DCT(&CDU[dataOff], &CDU[dataOff+1], &CDU[dataOff+2], &CDU[dataOff+3], &CDU[dataOff+4], &CDU[dataOff+5], &CDU[dataOff+6], &CDU[dataOff+7]); 1337 | } 1338 | // DCT columns 1339 | for(dataOff=0; dataOff<8; ++dataOff) { 1340 | stbiw__jpg_DCT(&CDU[dataOff], &CDU[dataOff+du_stride], &CDU[dataOff+du_stride*2], &CDU[dataOff+du_stride*3], &CDU[dataOff+du_stride*4], &CDU[dataOff+du_stride*5], &CDU[dataOff+du_stride*6], &CDU[dataOff+du_stride*7]); 1341 | } 1342 | // Quantize/descale/zigzag the coefficients 1343 | for(y = 0, j=0; y < 8; ++y) { 1344 | for(x = 0; x < 8; ++x,++j) { 1345 | float v; 1346 | i = y*du_stride+x; 1347 | v = CDU[i]*fdtbl[j]; 1348 | // DU[stbiw__jpg_ZigZag[j]] = (int)(v < 0 ? v - 0.5f : v + 0.5f); 1349 | // ceilf() and floorf() are C99, not C89, but I /think/ they're not needed here anyway? 1350 | DU[stbiw__jpg_ZigZag[j]] = (int)(v < 0 ? ceilf(v - 0.5f) : floorf(v + 0.5f)); 1351 | } 1352 | } 1353 | 1354 | // Encode DC 1355 | diff = DU[0] - DC; 1356 | if (diff == 0) { 1357 | stbiw__jpg_writeBits(s, bitBuf, bitCnt, HTDC[0]); 1358 | } else { 1359 | unsigned short bits[2]; 1360 | stbiw__jpg_calcBits(diff, bits); 1361 | stbiw__jpg_writeBits(s, bitBuf, bitCnt, HTDC[bits[1]]); 1362 | stbiw__jpg_writeBits(s, bitBuf, bitCnt, bits); 1363 | } 1364 | 1365 | // Encode ACs 1366 | end0pos = 63; 1367 | for(; (end0pos>0)&&(DU[end0pos]==0); --end0pos) { 1368 | } 1369 | // end0pos = first element in reverse order !=0 1370 | if(end0pos == 0) { 1371 | stbiw__jpg_writeBits(s, bitBuf, bitCnt, EOB); 1372 | return DU[0]; 1373 | } 1374 | for(i = 1; i <= end0pos; ++i) { 1375 | int startpos = i; 1376 | int nrzeroes; 1377 | unsigned short bits[2]; 1378 | for (; DU[i]==0 && i<=end0pos; ++i) { 1379 | } 1380 | nrzeroes = i-startpos; 1381 | if ( nrzeroes >= 16 ) { 1382 | int lng = nrzeroes>>4; 1383 | int nrmarker; 1384 | for (nrmarker=1; nrmarker <= lng; ++nrmarker) 1385 | stbiw__jpg_writeBits(s, bitBuf, bitCnt, M16zeroes); 1386 | nrzeroes &= 15; 1387 | } 1388 | stbiw__jpg_calcBits(DU[i], bits); 1389 | stbiw__jpg_writeBits(s, bitBuf, bitCnt, HTAC[(nrzeroes<<4)+bits[1]]); 1390 | stbiw__jpg_writeBits(s, bitBuf, bitCnt, bits); 1391 | } 1392 | if(end0pos != 63) { 1393 | stbiw__jpg_writeBits(s, bitBuf, bitCnt, EOB); 1394 | } 1395 | return DU[0]; 1396 | } 1397 | 1398 | static int stbi_write_jpg_core(stbi__write_context *s, int width, int height, int comp, const void* data, int quality) { 1399 | // Constants that don't pollute global namespace 1400 | static const unsigned char std_dc_luminance_nrcodes[] = {0,0,1,5,1,1,1,1,1,1,0,0,0,0,0,0,0}; 1401 | static const unsigned char std_dc_luminance_values[] = {0,1,2,3,4,5,6,7,8,9,10,11}; 1402 | static const unsigned char std_ac_luminance_nrcodes[] = {0,0,2,1,3,3,2,4,3,5,5,4,4,0,0,1,0x7d}; 1403 | static const unsigned char std_ac_luminance_values[] = { 1404 | 0x01,0x02,0x03,0x00,0x04,0x11,0x05,0x12,0x21,0x31,0x41,0x06,0x13,0x51,0x61,0x07,0x22,0x71,0x14,0x32,0x81,0x91,0xa1,0x08, 1405 | 0x23,0x42,0xb1,0xc1,0x15,0x52,0xd1,0xf0,0x24,0x33,0x62,0x72,0x82,0x09,0x0a,0x16,0x17,0x18,0x19,0x1a,0x25,0x26,0x27,0x28, 1406 | 0x29,0x2a,0x34,0x35,0x36,0x37,0x38,0x39,0x3a,0x43,0x44,0x45,0x46,0x47,0x48,0x49,0x4a,0x53,0x54,0x55,0x56,0x57,0x58,0x59, 1407 |
0x5a,0x63,0x64,0x65,0x66,0x67,0x68,0x69,0x6a,0x73,0x74,0x75,0x76,0x77,0x78,0x79,0x7a,0x83,0x84,0x85,0x86,0x87,0x88,0x89, 1408 | 0x8a,0x92,0x93,0x94,0x95,0x96,0x97,0x98,0x99,0x9a,0xa2,0xa3,0xa4,0xa5,0xa6,0xa7,0xa8,0xa9,0xaa,0xb2,0xb3,0xb4,0xb5,0xb6, 1409 | 0xb7,0xb8,0xb9,0xba,0xc2,0xc3,0xc4,0xc5,0xc6,0xc7,0xc8,0xc9,0xca,0xd2,0xd3,0xd4,0xd5,0xd6,0xd7,0xd8,0xd9,0xda,0xe1,0xe2, 1410 | 0xe3,0xe4,0xe5,0xe6,0xe7,0xe8,0xe9,0xea,0xf1,0xf2,0xf3,0xf4,0xf5,0xf6,0xf7,0xf8,0xf9,0xfa 1411 | }; 1412 | static const unsigned char std_dc_chrominance_nrcodes[] = {0,0,3,1,1,1,1,1,1,1,1,1,0,0,0,0,0}; 1413 | static const unsigned char std_dc_chrominance_values[] = {0,1,2,3,4,5,6,7,8,9,10,11}; 1414 | static const unsigned char std_ac_chrominance_nrcodes[] = {0,0,2,1,2,4,4,3,4,7,5,4,4,0,1,2,0x77}; 1415 | static const unsigned char std_ac_chrominance_values[] = { 1416 | 0x00,0x01,0x02,0x03,0x11,0x04,0x05,0x21,0x31,0x06,0x12,0x41,0x51,0x07,0x61,0x71,0x13,0x22,0x32,0x81,0x08,0x14,0x42,0x91, 1417 | 0xa1,0xb1,0xc1,0x09,0x23,0x33,0x52,0xf0,0x15,0x62,0x72,0xd1,0x0a,0x16,0x24,0x34,0xe1,0x25,0xf1,0x17,0x18,0x19,0x1a,0x26, 1418 | 0x27,0x28,0x29,0x2a,0x35,0x36,0x37,0x38,0x39,0x3a,0x43,0x44,0x45,0x46,0x47,0x48,0x49,0x4a,0x53,0x54,0x55,0x56,0x57,0x58, 1419 | 0x59,0x5a,0x63,0x64,0x65,0x66,0x67,0x68,0x69,0x6a,0x73,0x74,0x75,0x76,0x77,0x78,0x79,0x7a,0x82,0x83,0x84,0x85,0x86,0x87, 1420 | 0x88,0x89,0x8a,0x92,0x93,0x94,0x95,0x96,0x97,0x98,0x99,0x9a,0xa2,0xa3,0xa4,0xa5,0xa6,0xa7,0xa8,0xa9,0xaa,0xb2,0xb3,0xb4, 1421 | 0xb5,0xb6,0xb7,0xb8,0xb9,0xba,0xc2,0xc3,0xc4,0xc5,0xc6,0xc7,0xc8,0xc9,0xca,0xd2,0xd3,0xd4,0xd5,0xd6,0xd7,0xd8,0xd9,0xda, 1422 | 0xe2,0xe3,0xe4,0xe5,0xe6,0xe7,0xe8,0xe9,0xea,0xf2,0xf3,0xf4,0xf5,0xf6,0xf7,0xf8,0xf9,0xfa 1423 | }; 1424 | // Huffman tables 1425 | static const unsigned short YDC_HT[256][2] = { {0,2},{2,3},{3,3},{4,3},{5,3},{6,3},{14,4},{30,5},{62,6},{126,7},{254,8},{510,9}}; 1426 | static const unsigned short UVDC_HT[256][2] = { 
{0,2},{1,2},{2,2},{6,3},{14,4},{30,5},{62,6},{126,7},{254,8},{510,9},{1022,10},{2046,11}}; 1427 | static const unsigned short YAC_HT[256][2] = { 1428 | {10,4},{0,2},{1,2},{4,3},{11,4},{26,5},{120,7},{248,8},{1014,10},{65410,16},{65411,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1429 | {12,4},{27,5},{121,7},{502,9},{2038,11},{65412,16},{65413,16},{65414,16},{65415,16},{65416,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1430 | {28,5},{249,8},{1015,10},{4084,12},{65417,16},{65418,16},{65419,16},{65420,16},{65421,16},{65422,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1431 | {58,6},{503,9},{4085,12},{65423,16},{65424,16},{65425,16},{65426,16},{65427,16},{65428,16},{65429,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1432 | {59,6},{1016,10},{65430,16},{65431,16},{65432,16},{65433,16},{65434,16},{65435,16},{65436,16},{65437,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1433 | {122,7},{2039,11},{65438,16},{65439,16},{65440,16},{65441,16},{65442,16},{65443,16},{65444,16},{65445,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1434 | {123,7},{4086,12},{65446,16},{65447,16},{65448,16},{65449,16},{65450,16},{65451,16},{65452,16},{65453,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1435 | {250,8},{4087,12},{65454,16},{65455,16},{65456,16},{65457,16},{65458,16},{65459,16},{65460,16},{65461,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1436 | {504,9},{32704,15},{65462,16},{65463,16},{65464,16},{65465,16},{65466,16},{65467,16},{65468,16},{65469,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1437 | {505,9},{65470,16},{65471,16},{65472,16},{65473,16},{65474,16},{65475,16},{65476,16},{65477,16},{65478,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1438 | {506,9},{65479,16},{65480,16},{65481,16},{65482,16},{65483,16},{65484,16},{65485,16},{65486,16},{65487,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1439 | {1017,10},{65488,16},{65489,16},{65490,16},{65491,16},{65492,16},{65493,16},{65494,16},{65495,16},{65496,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1440 | 
{1018,10},{65497,16},{65498,16},{65499,16},{65500,16},{65501,16},{65502,16},{65503,16},{65504,16},{65505,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1441 | {2040,11},{65506,16},{65507,16},{65508,16},{65509,16},{65510,16},{65511,16},{65512,16},{65513,16},{65514,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1442 | {65515,16},{65516,16},{65517,16},{65518,16},{65519,16},{65520,16},{65521,16},{65522,16},{65523,16},{65524,16},{0,0},{0,0},{0,0},{0,0},{0,0}, 1443 | {2041,11},{65525,16},{65526,16},{65527,16},{65528,16},{65529,16},{65530,16},{65531,16},{65532,16},{65533,16},{65534,16},{0,0},{0,0},{0,0},{0,0},{0,0} 1444 | }; 1445 | static const unsigned short UVAC_HT[256][2] = { 1446 | {0,2},{1,2},{4,3},{10,4},{24,5},{25,5},{56,6},{120,7},{500,9},{1014,10},{4084,12},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1447 | {11,4},{57,6},{246,8},{501,9},{2038,11},{4085,12},{65416,16},{65417,16},{65418,16},{65419,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1448 | {26,5},{247,8},{1015,10},{4086,12},{32706,15},{65420,16},{65421,16},{65422,16},{65423,16},{65424,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1449 | {27,5},{248,8},{1016,10},{4087,12},{65425,16},{65426,16},{65427,16},{65428,16},{65429,16},{65430,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1450 | {58,6},{502,9},{65431,16},{65432,16},{65433,16},{65434,16},{65435,16},{65436,16},{65437,16},{65438,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1451 | {59,6},{1017,10},{65439,16},{65440,16},{65441,16},{65442,16},{65443,16},{65444,16},{65445,16},{65446,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1452 | {121,7},{2039,11},{65447,16},{65448,16},{65449,16},{65450,16},{65451,16},{65452,16},{65453,16},{65454,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1453 | {122,7},{2040,11},{65455,16},{65456,16},{65457,16},{65458,16},{65459,16},{65460,16},{65461,16},{65462,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1454 | {249,8},{65463,16},{65464,16},{65465,16},{65466,16},{65467,16},{65468,16},{65469,16},{65470,16},{65471,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1455 | 
{503,9},{65472,16},{65473,16},{65474,16},{65475,16},{65476,16},{65477,16},{65478,16},{65479,16},{65480,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1456 | {504,9},{65481,16},{65482,16},{65483,16},{65484,16},{65485,16},{65486,16},{65487,16},{65488,16},{65489,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1457 | {505,9},{65490,16},{65491,16},{65492,16},{65493,16},{65494,16},{65495,16},{65496,16},{65497,16},{65498,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1458 | {506,9},{65499,16},{65500,16},{65501,16},{65502,16},{65503,16},{65504,16},{65505,16},{65506,16},{65507,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1459 | {2041,11},{65508,16},{65509,16},{65510,16},{65511,16},{65512,16},{65513,16},{65514,16},{65515,16},{65516,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, 1460 | {16352,14},{65517,16},{65518,16},{65519,16},{65520,16},{65521,16},{65522,16},{65523,16},{65524,16},{65525,16},{0,0},{0,0},{0,0},{0,0},{0,0}, 1461 | {1018,10},{32707,15},{65526,16},{65527,16},{65528,16},{65529,16},{65530,16},{65531,16},{65532,16},{65533,16},{65534,16},{0,0},{0,0},{0,0},{0,0},{0,0} 1462 | }; 1463 | static const int YQT[] = {16,11,10,16,24,40,51,61,12,12,14,19,26,58,60,55,14,13,16,24,40,57,69,56,14,17,22,29,51,87,80,62,18,22, 1464 | 37,56,68,109,103,77,24,35,55,64,81,104,113,92,49,64,78,87,103,121,120,101,72,92,95,98,112,100,103,99}; 1465 | static const int UVQT[] = {17,18,24,47,99,99,99,99,18,21,26,66,99,99,99,99,24,26,56,99,99,99,99,99,47,66,99,99,99,99,99,99, 1466 | 99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99}; 1467 | static const float aasf[] = { 1.0f * 2.828427125f, 1.387039845f * 2.828427125f, 1.306562965f * 2.828427125f, 1.175875602f * 2.828427125f, 1468 | 1.0f * 2.828427125f, 0.785694958f * 2.828427125f, 0.541196100f * 2.828427125f, 0.275899379f * 2.828427125f }; 1469 | 1470 | int row, col, i, k, subsample; 1471 | float fdtbl_Y[64], fdtbl_UV[64]; 1472 | unsigned char YTable[64], UVTable[64]; 1473 | 1474 | if(!data || !width || !height || comp > 4 || 
comp < 1) { 1475 | return 0; 1476 | } 1477 | 1478 | quality = quality ? quality : 90; 1479 | subsample = quality <= 90 ? 1 : 0; 1480 | quality = quality < 1 ? 1 : quality > 100 ? 100 : quality; 1481 | quality = quality < 50 ? 5000 / quality : 200 - quality * 2; 1482 | 1483 | for(i = 0; i < 64; ++i) { 1484 | int uvti, yti = (YQT[i]*quality+50)/100; 1485 | YTable[stbiw__jpg_ZigZag[i]] = (unsigned char) (yti < 1 ? 1 : yti > 255 ? 255 : yti); 1486 | uvti = (UVQT[i]*quality+50)/100; 1487 | UVTable[stbiw__jpg_ZigZag[i]] = (unsigned char) (uvti < 1 ? 1 : uvti > 255 ? 255 : uvti); 1488 | } 1489 | 1490 | for(row = 0, k = 0; row < 8; ++row) { 1491 | for(col = 0; col < 8; ++col, ++k) { 1492 | fdtbl_Y[k] = 1 / (YTable [stbiw__jpg_ZigZag[k]] * aasf[row] * aasf[col]); 1493 | fdtbl_UV[k] = 1 / (UVTable[stbiw__jpg_ZigZag[k]] * aasf[row] * aasf[col]); 1494 | } 1495 | } 1496 | 1497 | // Write Headers 1498 | { 1499 | static const unsigned char head0[] = { 0xFF,0xD8,0xFF,0xE0,0,0x10,'J','F','I','F',0,1,1,0,0,1,0,1,0,0,0xFF,0xDB,0,0x84,0 }; 1500 | static const unsigned char head2[] = { 0xFF,0xDA,0,0xC,3,1,0,2,0x11,3,0x11,0,0x3F,0 }; 1501 | const unsigned char head1[] = { 0xFF,0xC0,0,0x11,8,(unsigned char)(height>>8),STBIW_UCHAR(height),(unsigned char)(width>>8),STBIW_UCHAR(width), 1502 | 3,1,(unsigned char)(subsample?0x22:0x11),0,2,0x11,1,3,0x11,1,0xFF,0xC4,0x01,0xA2,0 }; 1503 | s->func(s->context, (void*)head0, sizeof(head0)); 1504 | s->func(s->context, (void*)YTable, sizeof(YTable)); 1505 | stbiw__putc(s, 1); 1506 | s->func(s->context, UVTable, sizeof(UVTable)); 1507 | s->func(s->context, (void*)head1, sizeof(head1)); 1508 | s->func(s->context, (void*)(std_dc_luminance_nrcodes+1), sizeof(std_dc_luminance_nrcodes)-1); 1509 | s->func(s->context, (void*)std_dc_luminance_values, sizeof(std_dc_luminance_values)); 1510 | stbiw__putc(s, 0x10); // HTYACinfo 1511 | s->func(s->context, (void*)(std_ac_luminance_nrcodes+1), sizeof(std_ac_luminance_nrcodes)-1); 1512 | s->func(s->context, 
(void*)std_ac_luminance_values, sizeof(std_ac_luminance_values)); 1513 | stbiw__putc(s, 1); // HTUDCinfo 1514 | s->func(s->context, (void*)(std_dc_chrominance_nrcodes+1), sizeof(std_dc_chrominance_nrcodes)-1); 1515 | s->func(s->context, (void*)std_dc_chrominance_values, sizeof(std_dc_chrominance_values)); 1516 | stbiw__putc(s, 0x11); // HTUACinfo 1517 | s->func(s->context, (void*)(std_ac_chrominance_nrcodes+1), sizeof(std_ac_chrominance_nrcodes)-1); 1518 | s->func(s->context, (void*)std_ac_chrominance_values, sizeof(std_ac_chrominance_values)); 1519 | s->func(s->context, (void*)head2, sizeof(head2)); 1520 | } 1521 | 1522 | // Encode 8x8 macroblocks 1523 | { 1524 | static const unsigned short fillBits[] = {0x7F, 7}; 1525 | int DCY=0, DCU=0, DCV=0; 1526 | int bitBuf=0, bitCnt=0; 1527 | // comp == 2 is grey+alpha (alpha is ignored) 1528 | int ofsG = comp > 2 ? 1 : 0, ofsB = comp > 2 ? 2 : 0; 1529 | const unsigned char *dataR = (const unsigned char *)data; 1530 | const unsigned char *dataG = dataR + ofsG; 1531 | const unsigned char *dataB = dataR + ofsB; 1532 | int x, y, pos; 1533 | if(subsample) { 1534 | for(y = 0; y < height; y += 16) { 1535 | for(x = 0; x < width; x += 16) { 1536 | float Y[256], U[256], V[256]; 1537 | for(row = y, pos = 0; row < y+16; ++row) { 1538 | // row >= height => use last input row 1539 | int clamped_row = (row < height) ? row : height - 1; 1540 | int base_p = (stbi__flip_vertically_on_write ? (height-1-clamped_row) : clamped_row)*width*comp; 1541 | for(col = x; col < x+16; ++col, ++pos) { 1542 | // if col >= width => use pixel from last input column 1543 | int p = base_p + ((col < width) ? 
col : (width-1))*comp; 1544 | float r = dataR[p], g = dataG[p], b = dataB[p]; 1545 | Y[pos]= +0.29900f*r + 0.58700f*g + 0.11400f*b - 128; 1546 | U[pos]= -0.16874f*r - 0.33126f*g + 0.50000f*b; 1547 | V[pos]= +0.50000f*r - 0.41869f*g - 0.08131f*b; 1548 | } 1549 | } 1550 | DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+0, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT); 1551 | DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+8, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT); 1552 | DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+128, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT); 1553 | DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+136, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT); 1554 | 1555 | // subsample U,V 1556 | { 1557 | float subU[64], subV[64]; 1558 | int yy, xx; 1559 | for(yy = 0, pos = 0; yy < 8; ++yy) { 1560 | for(xx = 0; xx < 8; ++xx, ++pos) { 1561 | int j = yy*32+xx*2; 1562 | subU[pos] = (U[j+0] + U[j+1] + U[j+16] + U[j+17]) * 0.25f; 1563 | subV[pos] = (V[j+0] + V[j+1] + V[j+16] + V[j+17]) * 0.25f; 1564 | } 1565 | } 1566 | DCU = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, subU, 8, fdtbl_UV, DCU, UVDC_HT, UVAC_HT); 1567 | DCV = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, subV, 8, fdtbl_UV, DCV, UVDC_HT, UVAC_HT); 1568 | } 1569 | } 1570 | } 1571 | } else { 1572 | for(y = 0; y < height; y += 8) { 1573 | for(x = 0; x < width; x += 8) { 1574 | float Y[64], U[64], V[64]; 1575 | for(row = y, pos = 0; row < y+8; ++row) { 1576 | // row >= height => use last input row 1577 | int clamped_row = (row < height) ? row : height - 1; 1578 | int base_p = (stbi__flip_vertically_on_write ? (height-1-clamped_row) : clamped_row)*width*comp; 1579 | for(col = x; col < x+8; ++col, ++pos) { 1580 | // if col >= width => use pixel from last input column 1581 | int p = base_p + ((col < width) ? 
col : (width-1))*comp; 1582 | float r = dataR[p], g = dataG[p], b = dataB[p]; 1583 | Y[pos]= +0.29900f*r + 0.58700f*g + 0.11400f*b - 128; 1584 | U[pos]= -0.16874f*r - 0.33126f*g + 0.50000f*b; 1585 | V[pos]= +0.50000f*r - 0.41869f*g - 0.08131f*b; 1586 | } 1587 | } 1588 | 1589 | DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y, 8, fdtbl_Y, DCY, YDC_HT, YAC_HT); 1590 | DCU = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, U, 8, fdtbl_UV, DCU, UVDC_HT, UVAC_HT); 1591 | DCV = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, V, 8, fdtbl_UV, DCV, UVDC_HT, UVAC_HT); 1592 | } 1593 | } 1594 | } 1595 | 1596 | // Do the bit alignment of the EOI marker 1597 | stbiw__jpg_writeBits(s, &bitBuf, &bitCnt, fillBits); 1598 | } 1599 | 1600 | // EOI 1601 | stbiw__putc(s, 0xFF); 1602 | stbiw__putc(s, 0xD9); 1603 | 1604 | return 1; 1605 | } 1606 | 1607 | STBIWDEF int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality) 1608 | { 1609 | stbi__write_context s = { 0 }; 1610 | stbi__start_write_callbacks(&s, func, context); 1611 | return stbi_write_jpg_core(&s, x, y, comp, (void *) data, quality); 1612 | } 1613 | 1614 | 1615 | #ifndef STBI_WRITE_NO_STDIO 1616 | STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality) 1617 | { 1618 | stbi__write_context s = { 0 }; 1619 | if (stbi__start_write_file(&s,filename)) { 1620 | int r = stbi_write_jpg_core(&s, x, y, comp, data, quality); 1621 | stbi__end_write_file(&s); 1622 | return r; 1623 | } else 1624 | return 0; 1625 | } 1626 | #endif 1627 | 1628 | #endif // STB_IMAGE_WRITE_IMPLEMENTATION 1629 | 1630 | /* Revision history 1631 | 1.16 (2021-07-11) 1632 | make Deflate code emit uncompressed blocks when it would otherwise expand 1633 | support writing BMPs with alpha channel 1634 | 1.15 (2020-07-13) unknown 1635 | 1.14 (2020-02-02) updated JPEG writer to downsample chroma channels 1636 | 1.13 1637 | 1.12 1638 | 1.11 (2019-08-11) 1639 | 1640 | 1.10 
(2019-02-07) 1641 | support utf8 filenames in Windows; fix warnings and platform ifdefs 1642 | 1.09 (2018-02-11) 1643 | fix typo in zlib quality API, improve STB_I_W_STATIC in C++ 1644 | 1.08 (2018-01-29) 1645 | add stbi__flip_vertically_on_write, external zlib, zlib quality, choose PNG filter 1646 | 1.07 (2017-07-24) 1647 | doc fix 1648 | 1.06 (2017-07-23) 1649 | writing JPEG (using Jon Olick's code) 1650 | 1.05 ??? 1651 | 1.04 (2017-03-03) 1652 | monochrome BMP expansion 1653 | 1.03 ??? 1654 | 1.02 (2016-04-02) 1655 | avoid allocating large structures on the stack 1656 | 1.01 (2016-01-16) 1657 | STBIW_REALLOC_SIZED: support allocators with no realloc support 1658 | avoid race-condition in crc initialization 1659 | minor compile issues 1660 | 1.00 (2015-09-14) 1661 | installable file IO function 1662 | 0.99 (2015-09-13) 1663 | warning fixes; TGA rle support 1664 | 0.98 (2015-04-08) 1665 | added STBIW_MALLOC, STBIW_ASSERT etc 1666 | 0.97 (2015-01-18) 1667 | fixed HDR asserts, rewrote HDR rle logic 1668 | 0.96 (2015-01-17) 1669 | add HDR output 1670 | fix monochrome BMP 1671 | 0.95 (2014-08-17) 1672 | add monochrome TGA output 1673 | 0.94 (2014-05-31) 1674 | rename private functions to avoid conflicts with stb_image.h 1675 | 0.93 (2014-05-27) 1676 | warning fixes 1677 | 0.92 (2010-08-01) 1678 | casts to unsigned char to fix warnings 1679 | 0.91 (2010-07-17) 1680 | first public release 1681 | 0.90 first internal release 1682 | */ 1683 | 1684 | /* 1685 | ------------------------------------------------------------------------------ 1686 | This software is available under 2 licenses -- choose whichever you prefer. 
1687 | ------------------------------------------------------------------------------ 1688 | ALTERNATIVE A - MIT License 1689 | Copyright (c) 2017 Sean Barrett 1690 | Permission is hereby granted, free of charge, to any person obtaining a copy of 1691 | this software and associated documentation files (the "Software"), to deal in 1692 | the Software without restriction, including without limitation the rights to 1693 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 1694 | of the Software, and to permit persons to whom the Software is furnished to do 1695 | so, subject to the following conditions: 1696 | The above copyright notice and this permission notice shall be included in all 1697 | copies or substantial portions of the Software. 1698 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 1699 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 1700 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 1701 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 1702 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 1703 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 1704 | SOFTWARE. 1705 | ------------------------------------------------------------------------------ 1706 | ALTERNATIVE B - Public Domain (www.unlicense.org) 1707 | This is free and unencumbered software released into the public domain. 1708 | Anyone is free to copy, modify, publish, use, compile, sell, or distribute this 1709 | software, either in source code form or as a compiled binary, for any purpose, 1710 | commercial or non-commercial, and by any means. 1711 | In jurisdictions that recognize copyright laws, the author or authors of this 1712 | software dedicate any and all copyright interest in the software to the public 1713 | domain. 
We make this dedication for the benefit of the public at large and to 1714 | the detriment of our heirs and successors. We intend this dedication to be an 1715 | overt act of relinquishment in perpetuity of all present and future rights to 1716 | this software under copyright law. 1717 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 1718 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 1719 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 1720 | AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN 1721 | ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 1722 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 1723 | ------------------------------------------------------------------------------ 1724 | */ --------------------------------------------------------------------------------