├── .clang-format ├── .github └── workflows │ └── cmake-linux.yml ├── .gitignore ├── .gitmodules ├── .vscode └── launch.json ├── CMakeLists.txt ├── app ├── App.cpp ├── App.hpp ├── CMakeLists.txt ├── offscreen.cpp ├── offscreen_context.cpp ├── offscreen_context.hpp ├── readme.md └── shader │ ├── Binding.h │ ├── Constants.h │ ├── offscreen.frag │ ├── offscreen.frag.spv │ ├── offscreen.vert │ └── offscreen.vert.spv ├── include ├── GSContext.hpp ├── shaders │ ├── DataStruct.h │ ├── gpusort │ │ ├── downsweep.comp │ │ ├── downsweep.comp.spv │ │ ├── spine.comp │ │ ├── spine.comp.spv │ │ ├── upsweep.comp │ │ └── upsweep.comp.spv │ ├── inverseIndex.comp │ ├── inverseIndex.comp.spv │ ├── process.comp │ ├── process.comp.spv │ ├── projection.comp │ ├── projection.comp.spv │ ├── rank.comp │ ├── rank.comp.spv │ ├── splat.frag │ ├── splat.frag.spv │ ├── splat.vert │ └── splat.vert.spv └── sort.hpp ├── main.cpp ├── main1.jpg ├── nsightcompute.txt ├── origin_cuda ├── auxiliary.h ├── config.h ├── forward.cu ├── forward.h ├── rasterizer.h ├── rasterizer_impl.cu └── rasterizer_impl.h ├── point_cloud.ply ├── readme.md ├── showcase ├── origincuda.png └── output.png ├── src ├── GSContext.cpp └── sort.cpp └── temp.txt /.clang-format: -------------------------------------------------------------------------------- 1 | BasedOnStyle: WebKit 2 | AccessModifierOffset: -4 3 | AlignAfterOpenBracket: Align 4 | # AlignConsecutiveAssignments: true 5 | BinPackArguments: false 6 | BinPackParameters: false 7 | AllowAllArgumentsOnNextLine: false 8 | AllowShortFunctionsOnASingleLine: None 9 | BreakBeforeBinaryOperators: None 10 | PenaltyBreakAssignment: 10000 11 | # IndentWrappedFunctionNames: true 12 | 13 | # AlignConsecutiveAssignments: 'true' 14 | # AlignConsecutiveDeclarations: 'true' 15 | # AlignOperands: 'true' 16 | # AlignTrailingComments: 'true' 17 | # AllowAllParametersOfDeclarationOnNextLine: 'false' 18 | # AllowShortBlocksOnASingleLine: 'false' 19 | # AllowShortCaseLabelsOnASingleLine: 
'false' 20 | # AllowShortFunctionsOnASingleLine: Inline 21 | # AllowShortIfStatementsOnASingleLine: 'false' 22 | # AllowShortLoopsOnASingleLine: 'false' 23 | # AlwaysBreakAfterReturnType: None 24 | # AlwaysBreakBeforeMultilineStrings: 'true' 25 | # AlwaysBreakTemplateDeclarations: 'true' 26 | # BinPackArguments: 'true' 27 | # BinPackParameters: 'false' 28 | # ExperimentalAutoDetectBinPacking: 'false' 29 | # BreakBeforeBinaryOperators: NonAssignment 30 | # BreakBeforeBraces: Custom 31 | # BreakBeforeTernaryOperators: 'false' 32 | # BreakConstructorInitializersBeforeComma: 'true' 33 | # ColumnLimit: '120' 34 | # ConstructorInitializerAllOnOneLineOrOnePerLine: 'false' 35 | # Cpp11BracedListStyle: 'true' 36 | # IndentCaseLabels: 'true' 37 | # IndentWidth: '2' 38 | # KeepEmptyLinesAtTheStartOfBlocks: 'true' 39 | # Language: Cpp 40 | # MaxEmptyLinesToKeep: '2' 41 | # NamespaceIndentation: None 42 | # ObjCSpaceBeforeProtocolList: 'true' 43 | # PointerAlignment: Left 44 | # SpaceAfterCStyleCast: 'false' 45 | # SpaceBeforeAssignmentOperators: 'true' 46 | # SpaceBeforeParens: Never 47 | # SpaceInEmptyParentheses: 'false' 48 | # SpacesBeforeTrailingComments: '2' 49 | # SpacesInAngles: 'false' 50 | # SpacesInCStyleCastParentheses: 'false' 51 | # SpacesInParentheses: 'false' 52 | # SpacesInSquareBrackets: 'false' 53 | # Standard: Cpp11 54 | # TabWidth: '2' 55 | # UseTab: Never 56 | # SortIncludes: 'false' 57 | # ReflowComments: 'false' 58 | # BraceWrapping: { 59 | # AfterClass: 'true' 60 | # AfterControlStatement: 'true' 61 | # AfterEnum: 'true' 62 | # AfterFunction: 'true' 63 | # AfterNamespace: 'false' 64 | # AfterStruct: 'true' 65 | # AfterUnion: 'true' 66 | # BeforeCatch: 'true' 67 | # BeforeElse: 'true' 68 | # IndentBraces: 'false' 69 | # } 70 | # PenaltyExcessCharacter: 1 71 | # PenaltyBreakBeforeFirstCallParameter: 40 72 | # PenaltyBreakFirstLessLess: 1 73 | # PenaltyBreakComment: 30 74 | # PenaltyBreakString: 30 75 | # PenaltyReturnTypeOnItsOwnLine: 9999 76 | 
-------------------------------------------------------------------------------- /.github/workflows/cmake-linux.yml: -------------------------------------------------------------------------------- 1 | # This starter workflow is for a CMake project running on a single platform. There is a different starter workflow if you need cross-platform coverage. 2 | # See: https://github.com/actions/starter-workflows/blob/main/ci/cmake-multi-platform.yml 3 | name: CMake on a single platform 4 | 5 | on: 6 | push: 7 | branches: [ "main" ] 8 | pull_request: 9 | branches: [ "main" ] 10 | 11 | env: 12 | # Customize the CMake build type here (Release, Debug, RelWithDebInfo, etc.) 13 | BUILD_TYPE: Release 14 | 15 | jobs: 16 | build: 17 | # The CMake configure and build commands are platform agnostic and should work equally well on Windows or Mac. 18 | # You can convert this to a matrix build if you need cross-platform coverage. 19 | # See: https://docs.github.com/en/free-pro-team@latest/actions/learn-github-actions/managing-complex-workflows#using-a-build-matrix 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v3 24 | with: 25 | submodules: true 26 | - name: Initialize submodules 27 | run: git submodule update --init --recursive 28 | 29 | - name: Install TBB 30 | run: sudo apt-get update && 31 | sudo apt-get install -y libtbb-dev 32 | 33 | - name: Install GLM 34 | run: sudo apt-get install -y libglm-dev 35 | 36 | 37 | - name: Install assimp 38 | run: sudo apt-get install -y libassimp-dev 39 | - name: Install Vulkan 40 | run: sudo apt-get install -y libvulkan1 libvulkan-dev 41 | 42 | - name: Install glfw 43 | run: sudo apt-get install -y libglfw3 libglfw3-dev 44 | 45 | - name: Configure CMake 46 | # Configure CMake in a 'build' subdirectory. `CMAKE_BUILD_TYPE` is only required if you are using a single-configuration generator such as make. 
47 | # See https://cmake.org/cmake/help/latest/variable/CMAKE_BUILD_TYPE.html?highlight=cmake_build_type 48 | run: cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{env.BUILD_TYPE}} 49 | 50 | - name: Build 51 | # Build your program with the given configuration 52 | run: cmake --build ${{github.workspace}}/build --config ${{env.BUILD_TYPE}} -j8 53 | 54 | # - name: Test 55 | # working-directory: ${{github.workspace}}/build 56 | # # Execute tests defined by the CMake configuration. 57 | # # See https://cmake.org/cmake/help/latest/manual/ctest.1.html for more detail 58 | # run: ctest -C ${{env.BUILD_TYPE}} 59 | 60 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | .cache 3 | .idea 4 | .fleet 5 | cmake-build-* 6 | build 7 | # *.spv 8 | act -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "MCRT"] 2 | path = MCRT 3 | url = https://github.com/MouseChannel/MCRT.git 4 | branch = gaussian-splatting 5 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.2.0", 3 | "configurations": [ 4 | { 5 | "name": "CUDA C++: Launch", 6 | "type": "cuda-gdb", 7 | "request": "launch", 8 | "program": "${workspaceFolder}/build/MCGS", 9 | }, 10 | { 11 | "type": "lldb", 12 | "request": "launch", 13 | "name": "Launch", 14 | "program": "${workspaceFolder}/build/App/MCGS", 15 | "args": [], 16 | "cwd": "${workspaceFolder}" 17 | }, 18 | { 19 | "type": "lldb", 20 | "request": "launch", 21 | "name": "origin", 22 | "program": "${workspaceFolder}/build/ORIGIN_TEST", 23 | "args": [], 24 | "cwd": "${workspaceFolder}", 25 | // "xdebugSettings":{} 26 | }, 27 | { 28 | 
"name": "CMake: CMake Script", 29 | "type": "cmake", 30 | "request": "launch", 31 | "cmakeDebugType": "script", 32 | "scriptPath": "${workspaceFolder}/CMakeLists.txt" 33 | } 34 | ] 35 | } -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.27) 2 | 3 | 4 | set(CMAKE_CXX_STANDARD 20) 5 | 6 | 7 | project(MCGS LANGUAGES CXX ) 8 | 9 | 10 | 11 | file(GLOB_RECURSE SRC ${PROJECT_SOURCE_DIR}/src/GSContext.cpp ${PROJECT_SOURCE_DIR}/src/sort.cpp) 12 | 13 | 14 | add_library(src ${SRC}) 15 | message(STATUS ${SRC}) 16 | target_include_directories(src PUBLIC ${PROJECT_SOURCE_DIR}/include) 17 | 18 | 19 | set(MCRT_DIR MCRT/cmake) 20 | find_package(MCRT REQUIRED) 21 | target_link_libraries(src PUBLIC MCRT) 22 | 23 | 24 | 25 | add_subdirectory(${PROJECT_SOURCE_DIR}/app) 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /app/App.cpp: -------------------------------------------------------------------------------- 1 | #include "App.hpp" 2 | // #include "Helper/ImGui_Context.hpp" 3 | #include "Context/Context.hpp" 4 | 5 | #include "Imgui/imgui.h" 6 | #include "Rendering/AppWindow.hpp" 7 | #include "offscreen_context.hpp" 8 | inline void SetupImGuiStyle(bool bStyleDark_, float alpha_) 9 | { 10 | ImGuiStyle& style = ImGui::GetStyle(); 11 | // ImGui::SetWindowPos(ImVec2(100, 100), ImGuiCond_Always); 12 | // ImGui::SetWindowSize(ImVec2(500, 200), ImGuiCond_Always); 13 | 14 | // light style from Pacôme Danhiez (user itamago) https://github.com/ocornut/imgui/pull/511#issuecomment-175719267 15 | style.Alpha = 1.0f; 16 | style.FrameRounding = 3.0f; 17 | style.WindowBorderSize = 0.5f; 18 | style.WindowRounding = 10.f; 19 | style.Colors[ImGuiCol_Text] = ImVec4(0.00f, 0.00f, 0.00f, 1.00f); 20 | style.Colors[ImGuiCol_TextDisabled] = ImVec4(0.60f, 0.60f, 0.60f, 1.00f); 
21 | style.Colors[ImGuiCol_WindowBg] = ImVec4(0.94f, 0.94f, 0.94f, 0.94f); 22 | style.Colors[ImGuiCol_ChildBg] = ImVec4(0.00f, 0.00f, 0.00f, 0.00f); 23 | style.Colors[ImGuiCol_PopupBg] = ImVec4(1.00f, 1.00f, 1.00f, 0.94f); 24 | style.Colors[ImGuiCol_Border] = ImVec4(0.00f, 0.00f, 0.00f, 0.39f); 25 | style.Colors[ImGuiCol_BorderShadow] = ImVec4(1.00f, 1.00f, 1.00f, 0.10f); 26 | style.Colors[ImGuiCol_FrameBg] = ImVec4(1.00f, 1.00f, 1.00f, 0.94f); 27 | style.Colors[ImGuiCol_FrameBgHovered] = ImVec4(0.26f, 0.59f, 0.98f, 0.40f); 28 | style.Colors[ImGuiCol_FrameBgActive] = ImVec4(0.26f, 0.59f, 0.98f, 0.67f); 29 | style.Colors[ImGuiCol_TitleBg] = ImVec4(0.96f, 0.96f, 0.96f, 1.00f); 30 | style.Colors[ImGuiCol_TitleBgCollapsed] = ImVec4(1.00f, 1.00f, 1.00f, 0.51f); 31 | style.Colors[ImGuiCol_TitleBgActive] = ImVec4(0.82f, 0.82f, 0.82f, 1.00f); 32 | style.Colors[ImGuiCol_MenuBarBg] = ImVec4(0.86f, 0.86f, 0.86f, 1.00f); 33 | style.Colors[ImGuiCol_ScrollbarBg] = ImVec4(0.98f, 0.98f, 0.98f, 0.53f); 34 | style.Colors[ImGuiCol_ScrollbarGrab] = ImVec4(0.69f, 0.69f, 0.69f, 1.00f); 35 | style.Colors[ImGuiCol_ScrollbarGrabHovered] = ImVec4(0.59f, 0.59f, 0.59f, 1.00f); 36 | style.Colors[ImGuiCol_ScrollbarGrabActive] = ImVec4(0.49f, 0.49f, 0.49f, 1.00f); 37 | // style.Colors[ImGuiCol_ComboBg] = ImVec4(0.86f, 0.86f, 0.86f, 0.99f); 38 | style.Colors[ImGuiCol_CheckMark] = ImVec4(0.26f, 0.59f, 0.98f, 1.00f); 39 | style.Colors[ImGuiCol_SliderGrab] = ImVec4(0.24f, 0.52f, 0.88f, 1.00f); 40 | style.Colors[ImGuiCol_SliderGrabActive] = ImVec4(0.26f, 0.59f, 0.98f, 1.00f); 41 | style.Colors[ImGuiCol_Button] = ImVec4(0.26f, 0.59f, 0.98f, 0.40f); 42 | style.Colors[ImGuiCol_ButtonHovered] = ImVec4(0.26f, 0.59f, 0.98f, 1.00f); 43 | style.Colors[ImGuiCol_ButtonActive] = ImVec4(0.06f, 0.53f, 0.98f, 1.00f); 44 | style.Colors[ImGuiCol_Header] = ImVec4(0.26f, 0.59f, 0.98f, 0.31f); 45 | style.Colors[ImGuiCol_HeaderHovered] = ImVec4(0.26f, 0.59f, 0.98f, 0.80f); 46 | style.Colors[ImGuiCol_HeaderActive] = 
ImVec4(0.26f, 0.59f, 0.98f, 1.00f); 47 | style.Colors[ImGuiCol_Separator] = ImVec4(0.39f, 0.39f, 0.39f, 1.00f); 48 | style.Colors[ImGuiCol_SeparatorHovered] = ImVec4(0.26f, 0.59f, 0.98f, 0.78f); 49 | style.Colors[ImGuiCol_SeparatorActive] = ImVec4(0.26f, 0.59f, 0.98f, 1.00f); 50 | style.Colors[ImGuiCol_ResizeGrip] = ImVec4(1.00f, 1.00f, 1.00f, 0.50f); 51 | style.Colors[ImGuiCol_ResizeGripHovered] = ImVec4(0.26f, 0.59f, 0.98f, 0.67f); 52 | style.Colors[ImGuiCol_ResizeGripActive] = ImVec4(0.26f, 0.59f, 0.98f, 0.95f); 53 | // style.Colors[ImGuiCol_CloseButton] = ImVec4(0.59f, 0.59f, 0.59f, 0.50f); 54 | // style.Colors[ImGuiCol_CloseButtonHovered] = ImVec4(0.98f, 0.39f, 0.36f, 1.00f); 55 | // style.Colors[ImGuiCol_CloseButtonActive] = ImVec4(0.98f, 0.39f, 0.36f, 1.00f); 56 | style.Colors[ImGuiCol_PlotLines] = ImVec4(0.39f, 0.39f, 0.39f, 1.00f); 57 | style.Colors[ImGuiCol_PlotLinesHovered] = ImVec4(1.00f, 0.43f, 0.35f, 1.00f); 58 | style.Colors[ImGuiCol_PlotHistogram] = ImVec4(0.90f, 0.70f, 0.00f, 1.00f); 59 | style.Colors[ImGuiCol_PlotHistogramHovered] = ImVec4(1.00f, 0.60f, 0.00f, 1.00f); 60 | style.Colors[ImGuiCol_TextSelectedBg] = ImVec4(0.26f, 0.59f, 0.98f, 0.35f); 61 | style.Colors[ImGuiCol_ModalWindowDimBg] = ImVec4(0.20f, 0.20f, 0.20f, 0.35f); 62 | 63 | if (bStyleDark_) { 64 | /* dark variant: invert the value of low-saturation colours; '<' keeps i inside Colors[ImGuiCol_COUNT] */ for (int i = 0; i < ImGuiCol_COUNT; i++) { 65 | ImVec4& col = style.Colors[i]; 66 | float H, S, V; 67 | ImGui::ColorConvertRGBtoHSV(col.x, col.y, col.z, H, S, V); 68 | 69 | if (S < 0.1f) { 70 | V = 1.0f - V; 71 | } 72 | ImGui::ColorConvertHSVtoRGB(H, S, V, col.x, col.y, col.z); 73 | if (col.w < 1.00f) { 74 | col.w *= alpha_; 75 | } 76 | } 77 | } else { 78 | /* light variant: scale translucent colours by alpha_; stop before ImGuiCol_COUNT */ for (int i = 0; i < ImGuiCol_COUNT; i++) { 79 | ImVec4& col = style.Colors[i]; 80 | if (col.w < 1.00f) { 81 | col.x *= alpha_; 82 | col.y *= alpha_; 83 | col.z *= alpha_; 84 | col.w *= alpha_; 85 | } 86 | } 87 | } 88 | } 89 | 90 | namespace MCRT { 91 | void App::init() 92 | { 93 | window.reset(new Window(1600, 900)); 94 |
offscreen_context::Get_Singleton()->prepare(window); 95 | } 96 | 97 | void App::run() 98 | { 99 | // std::cout << "oonon" << std::endl; 100 | // SetupImGuiStyle(true, 0.95f); 101 | auto& context = offscreen_context::Get_Singleton(); 102 | while (!window->Should_Close()) { 103 | window->PollEvents(); 104 | auto cmd = context->Begin_Frame(); 105 | 106 | context->EndFrame(); 107 | } 108 | Context::Get_Singleton()->get_device()->get_handle().waitIdle(); 109 | } 110 | App::~App() 111 | { 112 | window.reset(); 113 | offscreen_context::Get_Singleton()->Quit(); 114 | } 115 | } // namespace MCRT 116 | -------------------------------------------------------------------------------- /app/App.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | // #include "Helper/ImGui_Context.hpp" 3 | #include 4 | 5 | namespace MCRT { 6 | // class ImGuiContext; 7 | class Window; 8 | // class Context; 9 | class App { 10 | public: 11 | App() = default; 12 | ~App(); 13 | void init(); 14 | void run(); 15 | 16 | private: 17 | std::shared_ptr window; 18 | 19 | // std::unique_ptr imgui; 20 | }; 21 | } // namespace MCRT -------------------------------------------------------------------------------- /app/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | #file(GLOB_RECURSE OFFSCREEN_PBR ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) 2 | # 3 | #add_library(offscreen_lib ${OFFSCREEN_PBR}) 4 | # 5 | #target_link_libraries(offscreen_lib rasterlib) 6 | # 7 | #add_executable(offscreen offscreen.cpp) 8 | # 9 | #target_link_libraries(offscreen PRIVATE offscreen_lib) 10 | 11 | 12 | 13 | 14 | # file(GLOB_RECURSE MCRT_RASTER_SRC ${PROJECT_SOURCE_DIR}/MCRT/example/base/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp ) 15 | # add_library(mcrt_raster_src ${MCRT_RASTER_SRC}) 16 | # target_link_libraries(mcrt_raster_src PUBLIC mcrt_src) 17 | 18 | # target_include_directories(mcrt_raster_src PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) 19 | 
20 | 21 | # include_directories(${CMAKE_CURRENT_SOURCE_DIR}) 22 | 23 | 24 | # target_link_libraries(src PUBLIC mcrt_raster_src imgui) 25 | 26 | add_executable(MCGS offscreen.cpp offscreen_context.cpp App.cpp ) 27 | target_link_libraries(MCGS src) 28 | 29 | # target_include_directories(MCGS PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) 30 | -------------------------------------------------------------------------------- /app/offscreen.cpp: -------------------------------------------------------------------------------- 1 | #define GLM_FORCE_DEPTH_ZERO_TO_ONE 2 | #include "App.hpp" 3 | #include 4 | #include 5 | 6 | int main(int, char**) 7 | { 8 | 9 | MCRT::App app; 10 | app.init(); 11 | try { 12 | 13 | app.run(); 14 | } catch (const std::system_error& e) { 15 | if (e.code().value() == VK_ERROR_DEVICE_LOST) { 16 | // #if _WIN32 17 | // MessageBoxA(nullptr, e.what(), "Fatal Error", MB_ICONERROR | MB_OK | MB_DEFBUTTON1); 18 | // #endif 19 | std::cout << "e.what()" << std::endl; 20 | } 21 | std::cout << e.what() << std::endl; 22 | return e.code().value(); 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /app/offscreen_context.cpp: -------------------------------------------------------------------------------- 1 | // #include "Wrapper/SubPass/ToneMapSubpass.hpp" 2 | #include "offscreen_context.hpp" 3 | #include "Imgui/imgui.h" 4 | #include "Wrapper/GraphicPass/UiPass.hpp" 5 | #include "Wrapper/Pipeline/Graphic_Pipeline.hpp" 6 | 7 | #define GLM_FORCE_DEPTH_ZERO_TO_ONE 8 | #include "Helper/Camera.hpp" 9 | #include "Helper/CommandManager.hpp" 10 | 11 | #include "Helper/Model_Loader/gltf_loader.hpp" 12 | #include "Rendering/ComputeContext.hpp" 13 | #include "Rendering/GraphicContext.hpp" 14 | #include "Rendering/Model.hpp" 15 | #include "Wrapper/CommandBuffer.hpp" 16 | #include "Wrapper/GraphicPass/GraphicPass.hpp" 17 | // #include "Wrapper/Skybox.hpp" 18 | #include "Wrapper/Texture.hpp" 19 | 20 | #include 
"Helper/DescriptorSetTarget/BufferDescriptorTarget.hpp" 21 | #include "Helper/DescriptorSetTarget/ImageDescriptorTarget.hpp" 22 | 23 | #include "GSContext.hpp" 24 | #include "Rendering/PBR/IBL_Manager.hpp" 25 | #include 26 | namespace MCRT { 27 | std::unique_ptr Context::_instance { new MCRT::offscreen_context }; 28 | float offscreen_context::light_pos_x = 0, offscreen_context::light_pos_y = 0, offscreen_context::light_pos_z = 5, offscreen_context::gamma = 2.2f; 29 | bool offscreen_context::use_normal_map = false, offscreen_context::use_r_rm_map = false, offscreen_context::use_ao = false; 30 | int irradiance_size = 512; 31 | 32 | enum DescriptorSetsIndex { 33 | MAIN, 34 | DESCRIPTORSET_COUNT 35 | }; 36 | enum PASS { 37 | eMainPass, 38 | eUIPass, 39 | PassCount 40 | }; 41 | enum PIPELINE { 42 | eMainPipeline, 43 | // eUIPipeline, 44 | PipelineCount 45 | }; 46 | 47 | offscreen_context::offscreen_context() 48 | { 49 | } 50 | 51 | offscreen_context::~offscreen_context() 52 | { 53 | } 54 | 55 | void offscreen_context::prepare(std::shared_ptr window) 56 | { 57 | raster_context::prepare(window); 58 | 59 | std::vector vertexs { 60 | { .pos { 1, 1, 0 }, .texCoord { 1, 1 } }, 61 | { .pos { 1, -3, 0 }, .texCoord { 1, -1 } }, 62 | { .pos { -3, 1, 0 }, .texCoord { -1, 1 } } 63 | }; 64 | std::vector faces { 0, 1, 2 }; 65 | offscreen_mesh.reset(new Mesh("offscreen", vertexs, faces, {})); 66 | 67 | PASS.resize(2); 68 | { 69 | // PASS[1] = std::make_shared("C:/Users/moche/project/gs/gaussian-splatting/output/80127b2a-4/point_cloud/iteration_30000/point_cloud.ply"); 70 | PASS[1] = std::make_shared("point_cloud.ply"); 71 | 72 | 73 | PASS[1]->prepare(); 74 | } 75 | { 76 | PASS[Pass_index::Graphic] = std::shared_ptr { new GraphicContext(m_device) }; 77 | 78 | auto graphic_context = std::reinterpret_pointer_cast(PASS[Graphic]); 79 | if (graphic_context == nullptr) { 80 | throw std::runtime_error("not graphic context"); 81 | } 82 | 83 | graphic_context->prepare(); 84 | 85 | { 86 | auto 
swapchain_renderTarget = graphic_context->AddSwapchainRenderTarget(); 87 | auto depth_renderTarget = graphic_context->AddDepthRenderTarget(); 88 | 89 | { 90 | graphic_context->descriptorSets.resize(DESCRIPTORSET_COUNT); 91 | graphic_context->descriptorSetPools.resize(DESCRIPTORSET_COUNT); 92 | graphic_context->descriptorSets[MAIN] = std::make_shared(); 93 | } 94 | { 95 | graphic_context->graphicPass.resize(PassCount); 96 | 97 | for (int i = 0; i < PassCount; i++) { 98 | 99 | graphic_context->graphicPass[i] = std::make_shared(graphic_context.get()); 100 | graphic_context->graphicPass[i]->set_subpass_index(i); 101 | } 102 | } 103 | 104 | 105 | { 106 | graphic_context->descriptorSets[MAIN]->AddBufferDescriptorTarget(std::reinterpret_pointer_cast(PASS[1])->instance_buffer, 107 | 1, 108 | vk::ShaderStageFlagBits::eVertex, 109 | vk::DescriptorType::eStorageBuffer); 110 | // if (graphic_context->descriptorSets[MAIN]->check_dirty()) { 111 | graphic_context->descriptorSetPools[MAIN].reset(new DescriptorPool({ graphic_context->descriptorSets[MAIN] }, graphic_context->get_frame_count())); 112 | graphic_context->descriptorSets[MAIN]->build(graphic_context->descriptorSetPools[MAIN], graphic_context->get_frame_count()); 113 | // } 114 | } 115 | graphic_context->m_pipelines.resize(PipelineCount); 116 | { 117 | auto mainPass = graphic_context->graphicPass[eMainPass]; 118 | mainPass->link_renderTarget({ swapchain_renderTarget }, 119 | {depth_renderTarget}, 120 | {}, 121 | {}); 122 | auto uiPass = graphic_context->graphicPass[eUIPass]; 123 | uiPass->link_renderTarget({ swapchain_renderTarget }, 124 | {}, 125 | {}, 126 | {}); 127 | } 128 | { 129 | graphic_context->AddSubPassDependency(vk::SubpassDependency() 130 | .setSrcSubpass(VK_SUBPASS_EXTERNAL) 131 | .setDstSubpass(eMainPass) 132 | .setSrcStageMask(vk::PipelineStageFlagBits::eColorAttachmentOutput) 133 | .setSrcAccessMask(vk::AccessFlagBits::eColorAttachmentWrite) 134 | 
.setDstStageMask(vk::PipelineStageFlagBits::eColorAttachmentOutput) 135 | .setDstAccessMask(vk::AccessFlagBits::eColorAttachmentWrite)); 136 | graphic_context->AddSubPassDependency(vk::SubpassDependency() 137 | .setSrcSubpass(eUIPass - 1) 138 | .setDstSubpass(eUIPass) 139 | .setSrcStageMask(vk::PipelineStageFlagBits::eColorAttachmentOutput) 140 | .setSrcAccessMask(vk::AccessFlagBits::eColorAttachmentWrite) 141 | .setDstStageMask(vk::PipelineStageFlagBits::eColorAttachmentOutput) 142 | .setDstAccessMask(vk::AccessFlagBits::eColorAttachmentWrite)); 143 | } 144 | { 145 | auto gbufferPass = graphic_context->graphicPass[eMainPass]; 146 | 147 | graphic_context->post_prepare(); 148 | 149 | auto uiPass = std::reinterpret_pointer_cast(graphic_context->graphicPass[eUIPass]); 150 | uiPass->Init(); 151 | } 152 | { 153 | auto rr = std::reinterpret_pointer_cast(PASS[1])->set; 154 | graphic_context->m_pipelines[eMainPipeline] 155 | .reset( 156 | new Graphic_Pipeline(graphic_context->Get_render_pass(), 157 | "include/shaders/splat.vert.spv", 158 | "include/shaders/splat.frag.spv", 159 | vk::CullModeFlagBits::eBack, 160 | true, 161 | false, 162 | vk::SampleCountFlagBits::e1, 163 | graphic_context->graphicPass[eMainPass]->get_subpass_index(), 164 | // { rr }, 165 | { graphic_context->descriptorSets[MAIN] }, 166 | 167 | 4, 168 | vk::ShaderStageFlagBits::eFragment, 169 | graphic_context->graphicPass[eMainPass]->color_references.size(), 170 | vk::PipelineColorBlendAttachmentState() 171 | .setBlendEnable(true) 172 | .setSrcColorBlendFactor(vk::BlendFactor ::eOne) 173 | .setSrcAlphaBlendFactor(vk::BlendFactor ::eOne) 174 | .setColorBlendOp(vk::BlendOp ::eAdd) 175 | .setDstColorBlendFactor(vk::BlendFactor ::eOneMinusSrcAlpha) 176 | .setDstAlphaBlendFactor(vk::BlendFactor ::eOneMinusSrcAlpha) 177 | .setAlphaBlendOp(vk::BlendOp ::eAdd) 178 | .setColorWriteMask(vk::ColorComponentFlagBits::eR | vk::ColorComponentFlagBits::eG | vk::ColorComponentFlagBits::eB | 
vk::ColorComponentFlagBits::eA))); 179 | } 180 | } 181 | } 182 | } 183 | 184 | std::shared_ptr offscreen_context::Begin_Frame() 185 | { 186 | 187 | return raster_context::Begin_Frame(); 188 | } 189 | 190 | void offscreen_context::EndFrame() 191 | { 192 | raster_context::EndFrame(); 193 | } 194 | 195 | std::shared_ptr offscreen_context::BeginGraphicFrame() 196 | { 197 | auto gs_context = std::reinterpret_pointer_cast(PASS[1]); 198 | auto render_context = std::reinterpret_pointer_cast(PASS[Graphic]); 199 | 200 | std::shared_ptr command = render_context->BeginFrame(); 201 | // std::shared_ptr command = gs_context->BeginFrame(); 202 | // gs_context->Submit(); 203 | // Context::Get_Singleton()->get_device()->Get_Graphic_queue().waitIdle(); 204 | 205 | gs_context->tick(command); 206 | 207 | render_context->Begin_RenderPass(command); 208 | { 209 | 210 | { 211 | auto cmd = command->get_handle(); 212 | 213 | cmd.setViewport(0, 214 | vk::Viewport() 215 | .setHeight(extent2d.height) 216 | .setWidth(extent2d.width) 217 | .setMinDepth(0) 218 | .setMaxDepth(1) 219 | .setX(0) 220 | .setY(0)); 221 | cmd.setScissor(0, 222 | vk::Rect2D() 223 | .setExtent(extent2d) 224 | .setOffset(vk::Offset2D() 225 | .setX(0) 226 | .setY(0))); 227 | 228 | { 229 | cmd.bindPipeline(vk::PipelineBindPoint ::eGraphics, render_context->m_pipelines[eMainPipeline]->get_handle()); 230 | 231 | cmd.bindDescriptorSets( 232 | vk::PipelineBindPoint ::eGraphics, 233 | render_context->m_pipelines[eMainPipeline]->get_layout(), 234 | 0, 235 | { render_context->descriptorSets[MAIN]->get_handle()[render_context->get_cur_index()] }, 236 | {}); 237 | 238 | // cmd.bindVertexBuffers(0, offscreen_mesh->get_vertex_buffer()->get_handle(), { 0 }); 239 | cmd.bindIndexBuffer(gs_context->index_buffer->get_handle(), 0, vk::IndexType ::eUint32); 240 | // cmd.draw(3, 1, 0, 0); 241 | cmd.drawIndexedIndirect(gs_context->indirct_cmd_buffer->get_handle(), 0, 1, 0); 242 | } 243 | 244 | { 245 | cmd.nextSubpass(vk::SubpassContents 
::eInline); 246 | auto uiPass = std::reinterpret_pointer_cast(render_context->graphicPass[eUIPass]); 247 | 248 | uiPass->DrawUI(cmd, []() { 249 | ImGui::Text("move:[W A S D Q E]"); 250 | ImGui::Text("Hold left Mouse Button To Rotate!!"); 251 | ImGui::SliderFloat("move-sensitivity", &Context::Get_Singleton()->get_camera()->m_sensitivity, 1e-2f, 1e-1f); 252 | 253 | ImGui::Checkbox("use_normal_map", &offscreen_context::use_normal_map); 254 | ImGui::Checkbox("rm", &use_r_rm_map); 255 | ImGui::Checkbox("AO", &offscreen_context::use_ao); 256 | ImGui::Text("fps : %7.3f", ImGui::GetIO().Framerate); 257 | }); 258 | } 259 | } 260 | } 261 | return command; 262 | } 263 | void offscreen_context::EndGraphicFrame() 264 | { 265 | auto& m_render_context = PASS[Graphic]; 266 | m_render_context->Submit(); 267 | m_render_context->EndFrame(); 268 | } 269 | 270 | } -------------------------------------------------------------------------------- /app/offscreen_context.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "Wrapper/Texture.hpp" 3 | 4 | #include "Context/raster_context.hpp" 5 | // #include "shaders/Data_struct.h" 6 | #include "Rendering/GraphicContext.hpp" 7 | 8 | namespace MCRT { 9 | class Buffer; 10 | // class Skybox; 11 | class offscreen_context : public raster_context { 12 | public: 13 | enum Pass_index { Graphic }; 14 | 15 | offscreen_context(); 16 | ~offscreen_context(); 17 | std::shared_ptr Begin_Frame() override; 18 | void EndFrame() override; 19 | std::shared_ptr get_rt_context() override 20 | { 21 | throw std::runtime_error("it is not Ray_Tracing context"); 22 | } 23 | std::shared_ptr get_compute_context() override 24 | { 25 | 26 | throw std::runtime_error("it is not compute context"); 27 | } 28 | 29 | std::shared_ptr get_graphic_context() override 30 | { 31 | auto base = PASS[Pass_index::Graphic]; 32 | if (auto context = std::reinterpret_pointer_cast(base); context != nullptr) { 33 | return context; 34 
| } 35 | /* lookup of the graphic pass failed: report the context kind this accessor actually serves */ throw std::runtime_error("it is not graphic context"); 36 | } 37 | 38 | static float light_pos_x, light_pos_y, light_pos_z; 39 | static bool use_normal_map, use_r_rm_map, use_ao; 40 | static float gamma; 41 | 42 | void prepare(std::shared_ptr window) override; 43 | 44 | private: 45 | std::shared_ptr BeginGraphicFrame() override; 46 | std::shared_ptr target_texture; 47 | void EndGraphicFrame() override; 48 | std::shared_ptr offscreen_mesh; 49 | }; 50 | 51 | } -------------------------------------------------------------------------------- /app/readme.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MouseChannel/MCGS/c9a89849d473c44133f052451ebaff322fde5e89/app/readme.md -------------------------------------------------------------------------------- /app/shader/Binding.h: -------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | #include 3 | using mat4 = glm::mat4; 4 | 5 | #define BEGIN_ENUM(a) enum class a { 6 | #define END_ENUM() } 7 | 8 | #else 9 | #define BEGIN_ENUM(a) const uint 10 | #define END_ENUM() 11 | #endif 12 | // #include "example/base/shaders/raster/Data_struct.h" 13 | 14 | BEGIN_ENUM(Graphic_Set) 15 | e_graphic = 0, 16 | // e_graphic_global = 1, 17 | graphic_count = 1 END_ENUM(); 18 | 19 | BEGIN_ENUM(Graphic_Binding) 20 | e_camera_matrix = 0, 21 | e_offscreen = 1 22 | // e_textures = 1, 23 | // e_skybox = 2, 24 | // e_irradiance_image = 3, 25 | // e_LUT_image = 4, 26 | // e_albedo_image = 5, 27 | // e_normal_image = 6, 28 | // e_m_r_image = 7, 29 | // e_tonemap_input = 8 30 | // e_albedo_texture = 9, 31 | // e_nrm_texture = 10, 32 | // e_arm_texture = 11 33 | // e_roughness = 11 34 | 35 | 36 | END_ENUM(); 37 | 38 | // const int e_out_uv = 5; 39 | -------------------------------------------------------------------------------- /app/shader/Constants.h: 
-------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | 3 | #include "glm/glm.hpp" 4 | using mat4 = glm::mat4; 5 | using vec3 = glm::vec3; 6 | using vec4 = glm::vec4; 7 | #endif // DEBUG 8 | #ifndef raster_push_constants 9 | #define raster_push_constants 10 | struct PC_Raster { 11 | mat4 model_matrix; 12 | vec4 light_pos; 13 | int color_texture_index; 14 | int metallicness_roughness_texture_index; 15 | int normal_texture_index; 16 | 17 | 18 | 19 | // bool use_normal_map; 20 | // bool use_r_m_map; 21 | // bool use_AO; 22 | 23 | int use_normal_map; 24 | int use_r_m_map; 25 | int use_AO; 26 | float gamma; 27 | }; 28 | struct PushContant_Compute { 29 | int frame; 30 | int open_filter; 31 | }; 32 | #endif -------------------------------------------------------------------------------- /app/shader/offscreen.frag: -------------------------------------------------------------------------------- 1 | #version 450 core 2 | #extension GL_GOOGLE_include_directive : enable 3 | #extension GL_EXT_debug_printf : enable 4 | 5 | #include "Binding.h" 6 | 7 | #include "Shader/Data_struct.h" 8 | 9 | layout(location = 0) out vec4 outColor; 10 | layout(location = e_texCoord) in vec2 in_texCoord1; 11 | layout(set = e_graphic, binding = e_offscreen) uniform sampler2D img; 12 | void main() 13 | { 14 | // debugPrintfEXT("Hewe\n"); 15 | 16 | outColor = texture(img, in_texCoord1).rgba; 17 | // outColor = vec4(in_texCoord1,0,1.f); 18 | } 19 | -------------------------------------------------------------------------------- /app/shader/offscreen.frag.spv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MouseChannel/MCGS/c9a89849d473c44133f052451ebaff322fde5e89/app/shader/offscreen.frag.spv -------------------------------------------------------------------------------- /app/shader/offscreen.vert: -------------------------------------------------------------------------------- 1 | 
#version 460
#extension GL_EXT_debug_printf : enable
#extension GL_GOOGLE_include_directive : enable
// Hard-coded triangle; only referenced by the commented-out path in main().
vec3[3] positions = vec3[3](vec3(1.0, 1.0, 0.0), vec3(1.0, -3.0, 0.0), vec3(-3.0, 1.0, 0.0));
vec2[3] uvs = vec2[3](vec2(1.0, 1.0), vec2(1.0, -1.0), vec2(-1.0, 1.0));
#include "Shader/Data_struct.h"

layout(location = e_pos) in vec3 inpos;
layout(location = e_texCoord) in vec2 in_texCoord;

layout(location = e_texCoord) out vec2 out_texCoord;
void rr()
{
}

// Passes the vertex position and texture coordinate straight through.
void main()
{
    // out_texCoord = uvs[gl_VertexIndex];
    // gl_Position = vec4(positions[gl_VertexIndex], 1);
    out_texCoord = in_texCoord;
    gl_Position = vec4(inpos, 1);
}
--------------------------------------------------------------------------------
/app/shader/offscreen.vert.spv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MouseChannel/MCGS/c9a89849d473c44133f052451ebaff322fde5e89/app/shader/offscreen.vert.spv
--------------------------------------------------------------------------------
/include/GSContext.hpp:
--------------------------------------------------------------------------------
#pragma once
// NOTE(review): this listing appears to have lost everything that was written
// between angle brackets (include targets, template arguments). The
// declarations below are reproduced exactly as found; restore the missing
// arguments from the real header before compiling.
#include "shaders/DataStruct.h"

#include "sort.hpp"
#include
#include

namespace MCRT {

class Buffer;

// Compute context holding the buffers and compute passes declared below.
class GSContext : public ComputeContext {
public:
    GSContext(std::string path);

    void prepare() override;
    std::shared_ptr BeginFrame() override;
    void tick(std::shared_ptr command);

    std::shared_ptr
        raw_data;

    std::shared_ptr instance_buffer;
    std::shared_ptr indirct_cmd_buffer;
    std::shared_ptr point_count_buffer;
    std::shared_ptr visiable_count_buffer;

    std::shared_ptr key_buffer;
    std::shared_ptr value_buffer;
    std::shared_ptr inverse_index_buffer;
    std::shared_ptr index_buffer;
    std::shared_ptr camera_buffer;

    std::unique_ptr> pre_process_pass;
    std::unique_ptr> rank_pass;
    std::unique_ptr> inverse_pass;
    std::unique_ptr> projection_pass;

    int all_point_count;
    // std::unique_ptr> duplicatePass;

    // std::vector sets;
    std::shared_ptr m_gpu_sort;

    std::shared_ptr rank_pipeline;

    std::shared_ptr set;
    std::shared_ptr setpool;
};
}
--------------------------------------------------------------------------------
/include/shaders/DataStruct.h:
--------------------------------------------------------------------------------
// Data layouts and descriptor binding indices shared between C++ and GLSL.
#ifdef __cplusplus
#pragma once
#include "glm/glm.hpp"

using vec3 = glm::vec3;

using vec4 = glm::vec4;
using vec2 = glm::vec2;
using uint = uint32_t;
using mat4x2 = glm::mat4x2;
using mat4 = glm::mat4;
using uvec2 = glm::uvec2;
using lowp_fvec4 = glm::lowp_fvec4;
#else
#extension GL_EXT_shader_16bit_storage : require
// Object-like macro: there must be no '=' here, otherwise every use would
// expand to "= lowp vec4".
#define lowp_fvec4 lowp vec4
#endif

// Descriptor binding indices.
#define e_gaussian_raw_point 0
#define e_instance_point 1
#define e_indir_cmd 2
#define e_point_count 3
#define e_visiable_count 4
#define e_camera 5
#define e_instance_key 6
#define e_instance_value 7
#define e_inverse_index 8

struct PointCount {
    uint all_count;
};

// One input gaussian. The explicit float members after each vec3 fill the
// vec3 up to four floats.
struct GaussianPoint {
    vec3 pos;
    float deldete1;
    vec3 scale;
    float deldete2;

    vec4 rot;
    vec4 sh[12];
    // lowp_fvec4 sh[12];
    // f16vec4 sh[12];

    // Six unique entries of the symmetric 3D covariance, three per vec4
    // (written by process.comp, read by projection.comp).
    vec4 conv3d[2];

    float opacity;
    float pad1;
    float pad2;
    float pad3;
};

// One projected splat, written by projection.comp.
struct InstancePoint {
    vec3 ndc_position;
    float pad0;
    vec2 scale;
    float theta;
    float pad1;
    // lowp_fvec4 sh;
    vec4 color;
};

struct IndexedIndirectCommand {
    uint indexCount;
    uint instanceCount;
    uint firstIndex;
    int vertexOffset;
    uint firstInstance;
};

struct CameraInfo {
    mat4 projection;
    mat4 view;
    vec3 camera_position;
    float pad;
    uvec2 screen_size;
};

// struct PushConstants {
//     uint32_t pass;
//     uint64_t elementCountReference;
//     uint64_t globalHistogramReference;
//     uint64_t partitionHistogramReference;
//     uint64_t keysInReference;
//     uint64_t keysOutReference;
//     uint64_t valuesInReference;
//     uint64_t valuesOutReference;
// };
--------------------------------------------------------------------------------
/include/shaders/gpusort/downsweep.comp:
--------------------------------------------------------------------------------
#version 460 core

#extension GL_EXT_buffer_reference : require
#extension GL_KHR_shader_subgroup_basic: enable
#extension GL_KHR_shader_subgroup_arithmetic: enable
#extension GL_KHR_shader_subgroup_ballot: enable

// Radix-sort scatter pass: each workgroup takes one partition of
// PARTITION_SIZE keys, ranks them by the 8-bit digit selected by `pass`, and
// writes keys and values to their destination slots.

const int RADIX = 256;
#define WORKGROUP_SIZE 512
#define PARTITION_DIVISION 8
const int PARTITION_SIZE = PARTITION_DIVISION * WORKGROUP_SIZE;

layout (local_size_x = WORKGROUP_SIZE) in;

layout (buffer_reference, std430) readonly buffer ElementCount {
    uint elementCount;
};

layout (buffer_reference, std430) readonly buffer GlobalHistogram {
    uint globalHistogram[];// (4, R)
};

layout (buffer_reference, std430) readonly buffer PartitionHistogram {
    uint partitionHistogram[];// (P, R)
};

layout (buffer_reference, std430) buffer Keys {
    uint keys[];// (N)
};

// #ifdef KEY_VALUE
layout (buffer_reference, std430) buffer Values {
    uint values[];// (N)
};
// #endif

layout (push_constant) uniform PushConstant {
    int pass;
    restrict ElementCount elementCountReference;
    restrict GlobalHistogram globalHistogramReference;
    restrict PartitionHistogram partitionHistogramReference;
    restrict Keys keysInReference;
    restrict Keys keysOutReference;
    // #ifdef KEY_VALUE
    restrict Values valuesInReference;
    restrict Values valuesOutReference;
    // #endif
};

const uint SHMEM_SIZE = PARTITION_SIZE;

// NOTE(review): the sizes below fit RADIX * gl_NumSubgroups only when the
// workgroup has at most 16 subgroups (subgroup size >= 32) - confirm on the
// target device.
shared uint localHistogram[SHMEM_SIZE];// (R, S=16)=4096, (P) for alias. take maximum.
shared uint localHistogramSum[RADIX];

// Returns a 128-bit mask (four 32-bit words) with bits [0, id) set,
// i.e. 0b00000....11111 where the most significant set bit is id-1.
//
// Every shift amount is kept in 1..31: shifting a 32-bit value by 32 or more
// (or by a wrapped-around negative amount) has an undefined result.
uvec4 GetExclusiveSubgroupMask(uint id) {
    uvec4 mask = uvec4(0u);
    for (uint word = 0u; word < 4u; ++word) {
        uint firstBit = 32u * word;
        if (id >= firstBit + 32u) {
            // the whole word lies below id
            mask[word] = 0xffffffffu;
        } else if (id > firstBit) {
            mask[word] = (1u << (id - firstBit)) - 1u;
        }
        // otherwise the whole word lies at or above id and stays 0
    }
    return mask;
}

// Total number of set bits across the four words.
uint GetBitCount(uvec4 value) {
    uvec4 result = bitCount(value);
    return result[0] + result[1] + result[2] + result[3];
}

void main() {
    uint threadIndex = gl_SubgroupInvocationID;// 0..31 for a 32-wide subgroup
    uint subgroupIndex = gl_SubgroupID;// 0..15 for a 32-wide subgroup
    uint index = subgroupIndex * gl_SubgroupSize + threadIndex;
    uvec4 subgroupMask = GetExclusiveSubgroupMask(threadIndex);

    uint partitionIndex = gl_WorkGroupID.x;
    uint partitionStart = partitionIndex * PARTITION_SIZE;

    uint elementCount = elementCountReference.elementCount;

    // the condition is identical for every invocation of the workgroup
    if (partitionStart >= elementCount) return;

    if (index < RADIX) {
        for (int i = 0; i < gl_NumSubgroups; ++i) {
            localHistogram[gl_NumSubgroups * index + i] = 0;
        }
    }
    barrier();

    // load from global memory, local histogram and offset
    uint localKeys[PARTITION_DIVISION];
    uint localRadix[PARTITION_DIVISION];
    uint localOffsets[PARTITION_DIVISION];
    uint subgroupHistogram[PARTITION_DIVISION];

    // #ifdef KEY_VALUE
    uint localValues[PARTITION_DIVISION];
    // #endif
    for (int i = 0; i < PARTITION_DIVISION; ++i) {
        uint keyIndex = partitionStart + (PARTITION_DIVISION * gl_SubgroupSize) * subgroupIndex + i * gl_SubgroupSize + threadIndex;
        // slots past the end are padded with the largest key
        uint key = keyIndex < elementCount ? keysInReference.keys[keyIndex] : 0xffffffff;
        localKeys[i] = key;

        // #ifdef KEY_VALUE
        localValues[i] = keyIndex < elementCount ? valuesInReference.values[keyIndex] : 0;
        // #endif

        uint radix = bitfieldExtract(key, pass * 8, 8);
        localRadix[i] = radix;

        // mask per digit: after the loop, `mask` marks the invocations of this
        // subgroup that hold the same radix as this one
        uvec4 mask = subgroupBallot(true);
        #pragma unroll
        for (int j = 0; j < 8; ++j) {
            uint digit = (radix >> j) & 1;
            uvec4 ballot = subgroupBallot(digit == 1);
            // digit - 1 is 0 or 0xffffffff. xor to flip.
            mask &= uvec4(digit - 1) ^ ballot;
        }

        // subgroup level offset for radix
        uint subgroupOffset = GetBitCount(subgroupMask & mask);
        uint radixCount = GetBitCount(mask);

        // elect a representative per radix, add to histogram
        if (subgroupOffset == 0) {
            // accumulate to local histogram
            atomicAdd(localHistogram[gl_NumSubgroups * radix + subgroupIndex], radixCount);
            subgroupHistogram[i] = radixCount;
        } else {
            subgroupHistogram[i] = 0;
        }

        localOffsets[i] = subgroupOffset;
    }
    barrier();

    // local histogram reduce 4096
    for (uint i = index; i < RADIX * gl_NumSubgroups; i += WORKGROUP_SIZE) {
        uint v = localHistogram[i];
        uint sum = subgroupAdd(v);
        uint excl = subgroupExclusiveAdd(v);
        localHistogram[i] = excl;
        if (threadIndex == 0) {
            localHistogramSum[i / gl_SubgroupSize] = sum;
        }
    }
    barrier();

    // local histogram reduce 128
    uint intermediateOffset0 = RADIX * gl_NumSubgroups / gl_SubgroupSize;
    if (index < intermediateOffset0) {
        uint v = localHistogramSum[index];
        uint sum = subgroupAdd(v);
        uint excl = subgroupExclusiveAdd(v);
        localHistogramSum[index] = excl;
        if (threadIndex == 0) {
            localHistogramSum[intermediateOffset0 + index / gl_SubgroupSize] = sum;
        }
    }
    barrier();

    // local histogram reduce 4
    uint intermediateSize1 = RADIX * gl_NumSubgroups / gl_SubgroupSize / gl_SubgroupSize;
    if (index < intermediateSize1) {
        uint v = localHistogramSum[intermediateOffset0 + index];
        uint excl = subgroupExclusiveAdd(v);
        localHistogramSum[intermediateOffset0 + index] = excl;
    }
    barrier();

    // local histogram add 128
    if (index < intermediateOffset0) {
        localHistogramSum[index] += localHistogramSum[intermediateOffset0 + index / gl_SubgroupSize];
    }
    barrier();

    // local histogram add 4096
    for (uint i = index; i < RADIX * gl_NumSubgroups; i += WORKGROUP_SIZE) {
        localHistogram[i] += localHistogramSum[i / gl_SubgroupSize];
    }
    barrier();

    // post-scan stage
    for (int i = 0; i < PARTITION_DIVISION; ++i) {
        uint radix = localRadix[i];
        localOffsets[i] += localHistogram[gl_NumSubgroups * radix + subgroupIndex];

        barrier();
        if (subgroupHistogram[i] > 0) {
            atomicAdd(localHistogram[gl_NumSubgroups * radix + subgroupIndex], subgroupHistogram[i]);
        }
        barrier();
    }

    // after atomicAdd, localHistogram contains inclusive sum
    if (index < RADIX) {
        uint v = index == 0 ? 0 : localHistogram[gl_NumSubgroups * index - 1];
        localHistogramSum[index] = globalHistogramReference.globalHistogram[RADIX * pass + index] + partitionHistogramReference.partitionHistogram[RADIX * partitionIndex + index] - v;
    }
    barrier();

    // rearrange keys. grouping keys together makes dstOffset to be almost sequential, grants huge speed boost.
    // now localHistogram is unused, so alias memory.
    for (int i = 0; i < PARTITION_DIVISION; ++i) {
        localHistogram[localOffsets[i]] = localKeys[i];
    }
    barrier();

    // binning
    for (uint i = index; i < PARTITION_SIZE; i += WORKGROUP_SIZE) {
        uint key = localHistogram[i];
        uint radix = bitfieldExtract(key, pass * 8, 8);
        uint dstOffset = localHistogramSum[radix] + i;
        if (dstOffset < elementCount) {
            keysOutReference.keys[dstOffset] = key;
        }

        // #ifdef KEY_VALUE
        // localKeys is reused to remember the destination for the value pass
        localKeys[i / WORKGROUP_SIZE] = dstOffset;
        // #endif
    }

    // #ifdef KEY_VALUE
    barrier();

    for (int i = 0; i < PARTITION_DIVISION; ++i) {
        localHistogram[localOffsets[i]] = localValues[i];
    }
    barrier();

    for (uint i = index; i < PARTITION_SIZE; i += WORKGROUP_SIZE) {
        uint value = localHistogram[i];
        uint dstOffset = localKeys[i / WORKGROUP_SIZE];
        // same bound as the key scatter above: padding slots of the last
        // partition map past the end and must not be written
        if (dstOffset < elementCount) {
            valuesOutReference.values[dstOffset] = value;
        }
    }
    // #endif
}
--------------------------------------------------------------------------------
/include/shaders/gpusort/downsweep.comp.spv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MouseChannel/MCGS/c9a89849d473c44133f052451ebaff322fde5e89/include/shaders/gpusort/downsweep.comp.spv
--------------------------------------------------------------------------------
/include/shaders/gpusort/spine.comp:
--------------------------------------------------------------------------------
#version 460 core

#extension GL_EXT_buffer_reference : require
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
// #extension GL_EXT_buffer_reference2 : require

// Radix-sort scan pass: turns the per-partition histogram of one radix into an
// exclusive prefix sum across partitions; workgroup 0 additionally scans the
// global histogram of the current pass.

const int RADIX = 256;
#define SUBGROUP_SIZE 32
#define WORKGROUP_SIZE 512
#define PARTITION_DIVISION 8
const int PARTITION_SIZE = PARTITION_DIVISION * WORKGROUP_SIZE;

// dispatch this shader (RADIX, 1, 1), so that gl_WorkGroupID.x is radix
layout(local_size_x = WORKGROUP_SIZE) in;

layout(buffer_reference, std430) readonly buffer ElementCount
{
    uint elementCount;
};

layout(buffer_reference, std430) buffer GlobalHistogram
{
    uint globalHistogram[];// (4, R)
};

layout(buffer_reference, std430) buffer PartitionHistogram
{
    uint partitionHistogram[];// (P, R)
};

layout(push_constant) uniform PushConstant
{
    int pass;
    restrict ElementCount elementCountReference;
    restrict GlobalHistogram globalHistogramReference;
    restrict PartitionHistogram partitionHistogramReference;
};

// running total carried from one chunk of WORKGROUP_SIZE partitions to the next
shared uint reduction;
// one slot per subgroup.
// NOTE(review): indexed by gl_SubgroupID, so this assumes at most
// SUBGROUP_SIZE subgroups per workgroup - confirm on the target device.
shared uint intermediate[SUBGROUP_SIZE];

void main()
{
    uint threadIndex = gl_SubgroupInvocationID;// 0..31 for a 32-wide subgroup
    uint subgroupIndex = gl_SubgroupID;// 0..15 for a 32-wide subgroup
    uint index = subgroupIndex * gl_SubgroupSize + threadIndex;
    uint radix = gl_WorkGroupID.x;

    uint elementCount = elementCountReference.elementCount;

    uint partitionCount = (elementCount + PARTITION_SIZE - 1) / PARTITION_SIZE;

    if (index == 0) {
        reduction = 0;
    }
    barrier();

    for (uint i = 0; WORKGROUP_SIZE * i < partitionCount; ++i) {
        uint partitionIndex = WORKGROUP_SIZE * i + index;
        uint value = partitionIndex < partitionCount
            ? partitionHistogramReference.partitionHistogram[RADIX * partitionIndex + radix]
            : 0;
        uint excl = subgroupExclusiveAdd(value) + reduction;
        uint sum = subgroupAdd(value);

        if (subgroupElect()) {
            intermediate[subgroupIndex] = sum;
        }
        barrier();

        // scan the per-subgroup totals and advance the carried total
        if (index < gl_NumSubgroups) {
            uint subgroupExcl = subgroupExclusiveAdd(intermediate[index]);
            uint chunkSum = subgroupAdd(intermediate[index]);
            intermediate[index] = subgroupExcl;

            if (index == 0) {
                reduction += chunkSum;
            }
        }
        barrier();

        if (partitionIndex < partitionCount) {
            excl += intermediate[subgroupIndex];
            partitionHistogramReference.partitionHistogram[RADIX * partitionIndex + radix] = excl;
        }
        barrier();
    }
    // uint carry;
    // uaddCarry(1, 1, carry);
    if (gl_WorkGroupID.x == 0) {
        // one workgroup is responsible for global histogram prefix sum.
        // Only the first RADIX invocations carry data, but barrier() must be
        // reached by every invocation of the workgroup, so the barriers sit
        // outside the index test.
        bool inRange = index < RADIX;
        uint excl = 0;
        if (inRange) {
            uint value = globalHistogramReference.globalHistogram[RADIX * pass + index];
            excl = subgroupExclusiveAdd(value);
            uint sum = subgroupAdd(value);

            if (subgroupElect()) {
                intermediate[subgroupIndex] = sum;
            }
        }
        barrier();

        if (index < RADIX / gl_SubgroupSize) {
            uint subgroupExcl = subgroupExclusiveAdd(intermediate[index]);
            intermediate[index] = subgroupExcl;
        }
        barrier();

        if (inRange) {
            excl += intermediate[subgroupIndex];
            globalHistogramReference.globalHistogram[RADIX * pass + index] = excl;
        }
    }
}
--------------------------------------------------------------------------------
/include/shaders/gpusort/spine.comp.spv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MouseChannel/MCGS/c9a89849d473c44133f052451ebaff322fde5e89/include/shaders/gpusort/spine.comp.spv
--------------------------------------------------------------------------------
/include/shaders/gpusort/upsweep.comp:
--------------------------------------------------------------------------------
#version 460 core

#extension GL_EXT_buffer_reference : require
#extension GL_KHR_shader_subgroup_basic: enable

// Radix-sort histogram pass: each workgroup counts, for one partition of
// PARTITION_SIZE keys, how many keys fall into each of the RADIX buckets of
// the 8-bit digit selected by `pass`. The counts are stored per partition and
// also accumulated into the global histogram.

const int RADIX = 256;
#define WORKGROUP_SIZE 512
#define PARTITION_DIVISION 8
const int PARTITION_SIZE = PARTITION_DIVISION * WORKGROUP_SIZE;

layout (local_size_x = WORKGROUP_SIZE) in;

layout (buffer_reference, std430) readonly buffer ElementCount {
    uint elementCount;
};

layout (buffer_reference, std430) buffer GlobalHistogram {
    uint globalHistogram[];// (4, R)
};

layout (buffer_reference, std430) writeonly buffer PartitionHistogram {
    uint partitionHistogram[];// (P, R)
};

layout (buffer_reference, std430) readonly buffer Keys {
    uint keys[];// (N)
};

layout (push_constant) uniform PushConstant {
    int pass;
    restrict ElementCount elementCountReference;
    restrict GlobalHistogram globalHistogramReference;
    restrict PartitionHistogram partitionHistogramReference;
    restrict Keys keysInReference;
};

shared uint localHistogram[RADIX];

void main() {
    uint threadIndex = gl_SubgroupInvocationID;// 0..31 for a 32-wide subgroup
    uint subgroupIndex = gl_SubgroupID;// 0..15 for a 32-wide subgroup (512 / 32)
    uint index = subgroupIndex * gl_SubgroupSize + threadIndex;

    uint elementCount = elementCountReference.elementCount;

    uint partitionIndex = gl_WorkGroupID.x;
    uint partitionStart = partitionIndex * PARTITION_SIZE;

    // discard all workgroup invocations
    if (partitionStart >= elementCount) {
        return;
    }

    if (index < RADIX) {
        localHistogram[index] = 0;
    }
    barrier();

    // local histogram
    for (int i = 0; i < PARTITION_DIVISION; ++i) {
        uint keyIndex = partitionStart + WORKGROUP_SIZE * i + index;
        // slots past the end count as the largest key
        uint key = keyIndex < elementCount ? keysInReference.keys[keyIndex] : 0xffffffff;
        uint radix = bitfieldExtract(key, 8 * pass, 8);
        atomicAdd(localHistogram[radix], 1);
    }
    barrier();

    if (index < RADIX) {
        // set to partition histogram
        partitionHistogramReference.partitionHistogram[RADIX * partitionIndex + index] = localHistogram[index];

        // add to global histogram
        atomicAdd(globalHistogramReference.globalHistogram[RADIX * pass + index], localHistogram[index]);
    }
}
--------------------------------------------------------------------------------
/include/shaders/gpusort/upsweep.comp.spv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MouseChannel/MCGS/c9a89849d473c44133f052451ebaff322fde5e89/include/shaders/gpusort/upsweep.comp.spv
--------------------------------------------------------------------------------
/include/shaders/inverseIndex.comp:
--------------------------------------------------------------------------------
#version 460
// #include is not core GLSL; request it explicitly, as projection.comp does.
#extension GL_GOOGLE_include_directive : enable

#include "DataStruct.h"

// Builds the inverse of the value array: for every slot `id` below
// visiable_count, records at position index[id] that it ended up in slot `id`.

layout(binding = e_point_count) buffer Info
{
    uint all_count;
};
layout(binding = e_visiable_count) buffer _visiable_count
{
    uint visiable_count;
};
layout(std430, binding = e_instance_value) readonly buffer InstanceIndex
{
    uint index[];
};
layout(std430, binding = e_instance_key) readonly buffer InstanceKey
{
    uint key[];
};
layout(std430, binding = e_inverse_index) writeonly buffer InverseMap
{
    int inverse_i[]; // (N), inverse map from id to sorted index
};
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

void main()
{
    uint id = gl_GlobalInvocationID.x;
    if (id < visiable_count) {
        inverse_i[index[id]] = int(id);
    }
}
--------------------------------------------------------------------------------
/include/shaders/inverseIndex.comp.spv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MouseChannel/MCGS/c9a89849d473c44133f052451ebaff322fde5e89/include/shaders/inverseIndex.comp.spv
--------------------------------------------------------------------------------
/include/shaders/process.comp:
--------------------------------------------------------------------------------
#version 460
// #include is not core GLSL; request it explicitly, as projection.comp does.
#extension GL_GOOGLE_include_directive : enable
#extension GL_EXT_debug_printf : enable
#include "DataStruct.h"

// Precomputes, for every gaussian, the 3D covariance R * S^2 * R^T from its
// rotation quaternion and scale, and stores the six unique entries in conv3d.

layout(local_size_x = 256) in;
// NOTE(review): these two bindings are literal 0 / 1 rather than the e_*
// indices from DataStruct.h - confirm this pass uses its own descriptor layout.
layout(binding = 0) buffer Info
{
    uint all_count;
};
layout(std430, binding = 1) buffer _GaussianPoint
{
    GaussianPoint points[];
};

// layout(std430, set = 1, binding = 2) writeonly buffer GaussianCov3d
// {
//     // float gaussian_cov3d[];// (N, 6)
//     mat3 gaussian_cov3d[];
// };
void main()
{
    uint id = gl_GlobalInvocationID.x;
    if (id < all_count) {
        vec4 q = points[id].rot;
        vec3 s = points[id].scale;

        // rotation matrix from the quaternion q
        mat3 rot;
        float xx = q.x * q.x;
        float yy = q.y * q.y;
        float zz = q.z * q.z;
        float xy = q.x * q.y;
        float xz = q.x * q.z;
        float yz = q.y * q.z;
        float wx = q.w * q.x;
        float wy = q.w * q.y;
        float wz = q.w * q.z;
        rot[0][0] = 1.f - 2.f * (yy + zz);
        rot[0][1] = 2.f * (xy + wz);
        rot[0][2] = 2.f * (xz - wy);
        rot[1][0] = 2.f * (xy - wz);
        rot[1][1] = 1.f - 2.f * (xx + zz);
        rot[1][2] = 2.f * (yz + wx);
        rot[2][0] = 2.f * (xz + wy);
        rot[2][1] = 2.f * (yz - wx);
        rot[2][2] = 1.f - 2.f * (xx + yy);

        // squared scale on the diagonal
        mat3 ss = mat3(0.f);
        ss[0][0] = s[0] * s[0];
        ss[1][1] = s[1] * s[1];
        ss[2][2] = s[2] * s[2];
        mat3 res = rot * ss * transpose(rot);

        // res is symmetric: keep six entries, three per vec4
        points[id].conv3d[0] = vec4(res[0][0], res[1][0], res[2][0], 1);
        points[id].conv3d[1] = vec4(res[1][1], res[2][1], res[2][2], 1);

        if (id == 0) {
            debugPrintfEXT(
                "message %f %f %f | %f %f %f | %f %f %f |%f %f %f \n",
                float(points[id].sh[0][0]),
                float(points[id].sh[0][1]),
                float(points[id].sh[0][2]),
                float(points[id].sh[0][3]),
                float(points[id].sh[1][0]),
                float(points[id].sh[1][1]),
                float(points[id].sh[1][2]),
                float(points[id].sh[1][3]),
                float(points[id].sh[2][0]),
                float(points[id].sh[2][1]),
                float(points[id].sh[2][2]),
                float(points[id].sh[2][3]));
        }
    }
}
--------------------------------------------------------------------------------
/include/shaders/process.comp.spv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MouseChannel/MCGS/c9a89849d473c44133f052451ebaff322fde5e89/include/shaders/process.comp.spv
--------------------------------------------------------------------------------
/include/shaders/projection.comp:
--------------------------------------------------------------------------------
#version 460

#extension GL_GOOGLE_include_directive : enable
#include "DataStruct.h"

// Projects every gaussian that has a slot in the inverse map to a 2D splat
// (NDC position, two axis scales, rotation angle, colour) and writes it to
// that slot of the instance buffer. Invocation 0 also fills in the indirect
// draw command.

layout(binding = e_camera) uniform Camera
{
    mat4 projection;
    mat4 view;
    vec3 camera_position;
    float pad0;
    uvec2 screen_size;// (width, height)
};

layout(push_constant, std430) uniform PushConstants
{
    mat4 model;
};
layout(std430, binding = e_visiable_count) buffer _visiable_count
{
    uint visiable_count;
};
layout(std430, binding = e_indir_cmd) buffer _IndexedIndirectCommand
{
    IndexedIndirectCommand indexedIndirectCommand;
};

layout(std430, binding = e_gaussian_raw_point) readonly buffer _GaussianPoint
{
    GaussianPoint points[];
};
layout(std430, binding = e_instance_point) writeonly buffer _Instance
{
    InstancePoint instances[];
};
layout(std430, binding = e_inverse_index) readonly buffer InverseMap
{
    int inverse_map[];// (N), inverse map from id to sorted index
};
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

void main()
{
    uint id = gl_GlobalInvocationID.x;
    // if (id >= visiable_count) return;
    if (id == 0) {
        // six indices per splat
        indexedIndirectCommand.indexCount = visiable_count * 6;
        indexedIndirectCommand.instanceCount = 1;
    }
    // NOTE(review): id is not bounds-checked here - presumably the inverse map
    // covers every dispatched invocation; verify against the dispatch size.
    int inverse_id = inverse_map[id];
    if (inverse_id == -1)
        return;

    vec3 v0 = vec3(points[id].conv3d[0]);
    vec3 v1 = vec3(points[id].conv3d[1]);
    vec4 pos = vec4(points[id].pos, 1.f);
    lowp vec4[12] gaussian_sh = points[id].sh;

    // direction in model space for SH calculation
    vec4 camera_podel_position = inverse(model) * vec4(camera_position, 1.f);
    camera_podel_position = camera_podel_position / camera_podel_position.w;
    vec3 dir = normalize(pos.xyz - camera_podel_position.xyz);

    // [v0.x v0.y v0.z]
    // [v0.y v1.x v1.y]
    // [v0.z v1.y v1.z]
    mat3 cov3d = mat3(v0, v0.y, v1.xy, v0.z, v1.yz);

    // model matrix
    mat3 model3d = mat3(model);
    cov3d = model3d * cov3d * transpose(model3d);
    pos = model * pos;

    // view matrix
    mat3 view3d = mat3(view);
    cov3d = view3d * cov3d * transpose(view3d);
    pos = view * pos;

    // projection
    float r = length(vec3(pos));
    mat3 J = mat3(-1.f / pos.z, 0.f, -2.f * pos.x / r, 0.f, -1.f / pos.z,
        -2.f * pos.y / r, pos.x / pos.z / pos.z, pos.y / pos.z / pos.z,
        -2.f * pos.z / r);
    cov3d = J * cov3d * transpose(J);

    // projection xy
    mat2 projection_scale = mat2(projection);
    mat2 cov2d = projection_scale * mat2(cov3d) * projection_scale;

    // low-pass filter
    cov2d[0][0] += 1.f / screen_size.x / screen_size.x;
    cov2d[1][1] += 1.f / screen_size.y / screen_size.y;

    // eigendecomposition
    // [a c] = [x y]
    // [c b]   [y z]
    float a = cov2d[0][0];
    float b = cov2d[1][1];
    float c = cov2d[1][0];
    float D = sqrt((a - b) * (a - b) + 4.f * c * c);
    float s0 = sqrt(0.5f * (a + b + D));
    // rounding can push a + b - D slightly below zero; clamp so the square
    // root cannot produce NaN
    float s1 = sqrt(max(0.5f * (a + b - D), 0.f));

    // decompose to R^T S^2 R.
    // D == 0 means a == b and c == 0 (isotropic splat): any angle is valid, and
    // dividing by D would yield NaN, so the angle stays 0.
    float theta = 0.f;
    if (D > 0.f) {
        float sin2t = 2.f * c / D;
        float cos2t = (a - b) / D;
        theta = atan(sin2t, cos2t) / 2.f;
    }

    pos = projection * pos;
    pos = pos / pos.w;
    vec3 color;

    // calculate spherical harmonics
    const float C0 = 0.28209479177387814f;
    const float C1 = 0.4886025119029199f;
    const float C20 = 1.0925484305920792f;
    const float C21 = 0.31539156525252005f;
    const float C22 = 0.5462742152960396f;
    const float C30 = 0.5900435899266435f;
    const float C31 = 2.890611442640554f;
    const float C32 = 0.4570457994644658f;
    const float C33 = 0.3731763325901154f;
    const float C34 = 1.445305721320277f;
    float x = dir.x;
    float y = dir.y;
    float z = dir.z;
    float xx = x * x;
    float yy = y * y;
    float zz = z * z;
    float xy = x * y;
    float yz = y * z;
    float xz = x * z;
    vec4 basis0 = vec4(C0, -C1 * y, C1 * z, -C1 * x);
    vec4 basis1 = vec4(C20 * xy, -C20 * yz, C21 * (2.f * zz - xx - yy), -C20 * xz);
    vec4 basis2 = vec4(C22 * (xx - yy), -C30 * y * (3.f * xx - yy), C31 * xy * z, -C32 * y * (4.f * zz - xx - yy));
    vec4 basis3 = vec4(C33 * z * (2.f * zz - 3.f * xx - 3.f * yy),
        -C32 * x * (4.f * zz - xx - yy),
        C34 * z * (xx - yy),
        -C30 * x * (xx - 3.f * yy));

    // each matrix gathers one vec4 of coefficients per colour channel
    mat3x4 sh0 = mat3x4(gaussian_sh[0], gaussian_sh[4], gaussian_sh[8]);
    mat3x4 sh1 = mat3x4(gaussian_sh[1], gaussian_sh[5], gaussian_sh[9]);
    mat3x4 sh2 = mat3x4(gaussian_sh[2], gaussian_sh[6], gaussian_sh[10]);
    mat3x4 sh3 = mat3x4(gaussian_sh[3], gaussian_sh[7], gaussian_sh[11]);
    // row vector-matrix multiplication
    color = basis0 * sh0 + basis1 * sh1 + basis2 * sh2 + basis3 * sh3;

    // translation and clip
    color = max(color + 0.5f, 0.f);

    float opacity = points[id].opacity;

    instances[inverse_id].ndc_position = vec3(pos);
    instances[inverse_id].scale = vec2(s0, s1);
    // instances[inverse_id].scale = vec2(screen_size.x, screen_size.y);
    instances[inverse_id].theta = theta;
    instances[inverse_id].color = vec4(color, opacity);
}
--------------------------------------------------------------------------------
/include/shaders/projection.comp.spv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MouseChannel/MCGS/c9a89849d473c44133f052451ebaff322fde5e89/include/shaders/projection.comp.spv
--------------------------------------------------------------------------------
/include/shaders/rank.comp:
--------------------------------------------------------------------------------
#version 460 core
// Extension directives come before the include so that they precede every
// declaration the header brings in; #include itself needs the first one.
#extension GL_GOOGLE_include_directive : enable
#extension GL_EXT_debug_printf : enable
#include "DataStruct.h"

layout(std430, binding = e_gaussian_raw_point) buffer _GaussianPoint
{
    GaussianPoint points[];
};
layout (std430, binding = 13) readonly buffer GaussianPosition {
    float gaussian_position[]; // (N, 3)
};
layout(binding = e_point_count) buffer Info 13 | { 14 | uint all_count; 15 | }; 16 | layout(std430, binding = e_visiable_count) buffer _visiable_count 17 | { 18 | uint visiable_count; 19 | }; 20 | // layout(std430, binding = e_indir_cmd) buffer DrawIndirect 21 | //{ 22 | // 23 | // // IndexedIndirectCommand indirectDrawcmd; 24 | // uint indexCount; 25 | // uint instanceCount; 26 | // uint firstIndex; 27 | // int vertexOffset; 28 | // uint firstInstance; 29 | // }; 30 | layout(push_constant, std430) uniform PushConstants 31 | { 32 | mat4 model; 33 | }; 34 | layout(binding = e_camera) uniform Camera 35 | { 36 | mat4 projection; 37 | mat4 view; 38 | vec3 camera_position; 39 | // float pad0; 40 | // uvec2 screen_size; // (width, height) 41 | }; 42 | 43 | layout(std430, binding = e_instance_key) writeonly buffer InstanceKey 44 | { 45 | uint key[]; 46 | }; 47 | 48 | layout(std430, binding = e_instance_value) writeonly buffer InstanceIndex 49 | { 50 | uint index[]; 51 | }; 52 | layout(local_size_x = 256) in; 53 | void main() // one invocation per input point: cull its center against NDC and append a (key, index) record if it survives 54 | { 55 | uint id = gl_GlobalInvocationID.x; 56 | if (id >= all_count) 57 | return; 58 | // Project the point center to clip space, then divide by w to reach NDC. 59 | vec4 pos = vec4(points[id].pos, 1.f); 60 | pos = projection * view * model * pos; 61 | pos = pos / pos.w; 62 | float depth = pos.z; 63 | // valid only when center is inside NDC clip space.
64 | if (abs(pos.x) <= 1.f && abs(pos.y) <= 1.f && pos.z >= 0.f && pos.z <= 1.f) { 65 | // indirectDrawcmd.indexCount = atomicAdd(indirectDrawcmd.indexCount, 1); 66 | uint instance_index = atomicAdd(visiable_count, 1); // atomically reserve a unique output slot 67 | key[instance_index] = floatBitsToUint(1.f - depth); // key is the bit pattern of (1 - NDC depth) 68 | index[instance_index] = id; 69 | // (leftover debugPrintfEXT removed: it executed once for every visible point) 70 | } 71 | } -------------------------------------------------------------------------------- /include/shaders/rank.comp.spv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MouseChannel/MCGS/c9a89849d473c44133f052451ebaff322fde5e89/include/shaders/rank.comp.spv -------------------------------------------------------------------------------- /include/shaders/splat.frag: -------------------------------------------------------------------------------- 1 | #version 460 2 | 3 | layout(location = 0) in vec4 color; 4 | layout(location = 1) in vec2 position; 5 | 6 | layout(location = 0) out vec4 out_color; 7 | 8 | void main() 9 | { 10 | float gaussian_alpha = exp(-0.5f * dot(position, position)); 11 | float alpha = color.a * gaussian_alpha; 12 | // premultiplied alpha 13 | out_color = vec4(color.rgb * alpha, alpha); 14 | } 15 | -------------------------------------------------------------------------------- /include/shaders/splat.frag.spv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MouseChannel/MCGS/c9a89849d473c44133f052451ebaff322fde5e89/include/shaders/splat.frag.spv -------------------------------------------------------------------------------- /include/shaders/splat.vert: -------------------------------------------------------------------------------- 1 | #version 450 2 | #extension GL_EXT_debug_printf : enable 3 | 4 | #include "DataStruct.h" 5 | layout(std430, binding = e_instance_point) readonly buffer Instances 6 | { 7 | Instance instances[]; // (N, 10).
3 for ndc position, 3 for scale rot, 4 for color 8 | }; 9 | layout(location = 0) out vec4 out_color; 10 | layout(location = 1) out vec2 out_position; 11 | // Expands each Instance into one corner of a rotated, scaled quad centered on its NDC position. 12 | void main() 13 | { 14 | // index [0,1,2,2,1,3], 4 vertices for a splat. 15 | int index = gl_VertexIndex / 4; 16 | vec3 ndc_position = instances[index].ndc_position; 17 | // if(index == 0){ 18 | // debugPrintfEXT("message %f %f %f\n", ndc_position.x, ndc_position.y, ndc_position.z); 19 | // } 20 | vec2 scale = instances[index].scale; 21 | float theta = instances[index].theta; 22 | vec4 color = instances[index].color; 23 | 24 | // quad positions (-1, -1), (-1, 1), (1, -1), (1, 1), ccw in screen space. 25 | int vert_index = gl_VertexIndex % 4; 26 | vec2 position = vec2(vert_index / 2, vert_index % 2) * 2.f - 1.f; 27 | // 2D rotation by theta; the mat2 constructor takes its arguments column by column. 28 | mat2 rot = mat2(cos(theta), sin(theta), -sin(theta), cos(theta)); 29 | // The quad corner is pushed out to 3x the per-axis scale; out_position carries the same factor to the next stage. 30 | float confidence_radius = 3.f; 31 | 32 | gl_Position = vec4(ndc_position + vec3(rot * (scale * position) * confidence_radius, 0.f), 1.f); 33 | out_color = color; 34 | out_position = position * confidence_radius; 35 | } -------------------------------------------------------------------------------- /include/shaders/splat.vert.spv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MouseChannel/MCGS/c9a89849d473c44133f052451ebaff322fde5e89/include/shaders/splat.vert.spv -------------------------------------------------------------------------------- /include/sort.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "Wrapper/ComputePass/ComputePass.hpp" 3 | #include "Wrapper/Pipeline/Compute_Pipeline.hpp" 4 | #include "vulkan/vulkan.hpp" 5 | #include 6 | namespace MCRT { 7 | class Buffer; 8 | class gpusort { 9 | public: 10 | 11 | gpusort() = default; 12 | void Init(uint all_point_count, std::shared_ptr indirect_buffer); 13 | void sort(vk::CommandBuffer cmd, std::shared_ptr key_buffer,
std::shared_ptr value_buffer); 14 | 15 | // private: 16 | std::shared_ptr storage_buffer; 17 | std::shared_ptr visiable_count_buffer; 18 | 19 | VkDeviceSize elementCountSize = sizeof(uint32_t); 20 | VkDeviceSize histogramSize; 21 | VkDeviceSize inoutSize; 22 | 23 | // std::shared_ptr key_buffer; 24 | // std::shared_ptr value_buffer; 25 | 26 | VkDeviceSize histogramOffset = sizeof(uint32_t); 27 | VkDeviceSize inoutOffset; 28 | uint32_t partitionCount; 29 | std::shared_ptr upsweepPipeline; 30 | std::shared_ptr spinePipeline; 31 | std::shared_ptr downsweepKeyValuePipeline; 32 | 33 | std::shared_ptr> upsweepPass; 34 | std::shared_ptr> spinePass; 35 | std::shared_ptr> downsweepKeyValuePass; 36 | 37 | }; 38 | } -------------------------------------------------------------------------------- /main.cpp: -------------------------------------------------------------------------------- 1 | #include "config.h" 2 | #include "ply_loader.hpp" 3 | #include "rasterizer.h" 4 | #include "torch/torch.h" 5 | #include "torch/utils.h" 6 | #include 7 | #include 8 | #include 9 | #include 10 | // #include 11 | #include 12 | #define STB_IMAGE_WRITE_IMPLEMENTATION 13 | #include "stb_image_write.h" 14 | 15 | std::function resizeFunctional(torch::Tensor& t) 16 | { 17 | auto lambda = [&t](size_t N) { 18 | t.resize_({ (long long)N }); 19 | return reinterpret_cast(t.contiguous().data_ptr()); 20 | }; 21 | return lambda; 22 | } 23 | 24 | std::tuple 25 | RasterizeGaussiansCUDA( 26 | const torch::Tensor& background, 27 | const torch::Tensor& means3D, 28 | const torch::Tensor& colors, 29 | const torch::Tensor& opacity, 30 | const torch::Tensor& scales, 31 | const torch::Tensor& rotations, 32 | const float scale_modifier, 33 | const torch::Tensor& cov3D_precomp, 34 | const torch::Tensor& viewmatrix, 35 | const torch::Tensor& projmatrix, 36 | const float tan_fovx, 37 | const float tan_fovy, 38 | const int image_height, 39 | const int image_width, 40 | const torch::Tensor& sh, 41 | const int degree, 42 | 
const torch::Tensor& campos, 43 | const bool prefiltered, 44 | const bool debug) 45 | { 46 | if (means3D.ndimension() != 2 || means3D.size(1) != 3) { 47 | AT_ERROR("means3D must have dimensions (num_points, 3)"); 48 | } 49 | 50 | const int P = means3D.size(0); 51 | const int H = image_height; 52 | const int W = image_width; 53 | 54 | auto int_opts = means3D.options().dtype(torch::kInt32); 55 | auto float_opts = means3D.options().dtype(torch::kFloat32); 56 | 57 | torch::Tensor out_color = torch::full({ NUM_CHANNELS, H, W }, 0.0, float_opts); 58 | torch::Tensor radii = 59 | torch::full({ P }, 0, means3D.options().dtype(torch::kInt32)); 60 | 61 | torch::Device device(torch::kCUDA); 62 | torch::TensorOptions options(torch::kByte); 63 | torch::Tensor geomBuffer = torch::empty({ 0 }, options.device(device)); 64 | torch::Tensor binningBuffer = torch::empty({ 0 }, options.device(device)); 65 | torch::Tensor imgBuffer = torch::empty({ 0 }, options.device(device)); 66 | std::function geomFunc = resizeFunctional(geomBuffer); 67 | std::function binningFunc = resizeFunctional(binningBuffer); 68 | std::function imgFunc = resizeFunctional(imgBuffer); 69 | 70 | int rendered = 0; 71 | if (P != 0) { 72 | int M = 0; 73 | if (sh.size(0) != 0) { 74 | M = sh.size(1); 75 | } 76 | 77 | rendered = CudaRasterizer::Rasterizer::forward( 78 | geomFunc, 79 | binningFunc, 80 | imgFunc, 81 | P, 82 | degree, 83 | M, 84 | background.contiguous().data(), 85 | W, 86 | H, 87 | means3D.contiguous().data(), 88 | sh.contiguous().data(), 89 | colors.contiguous().data(), 90 | opacity.contiguous().data(), 91 | scales.contiguous().data(), 92 | scale_modifier, 93 | rotations.contiguous().data(), 94 | cov3D_precomp.contiguous().data(), 95 | viewmatrix.contiguous().data(), 96 | projmatrix.contiguous().data(), 97 | campos.contiguous().data(), 98 | tan_fovx, 99 | tan_fovy, 100 | prefiltered, 101 | out_color.contiguous().data(), 102 | radii.contiguous().data(), 103 | debug); 104 | } 105 | return 
std::make_tuple(rendered, out_color, radii, geomBuffer, binningBuffer, imgBuffer); 106 | } 107 | // Offline driver: loads point_cloud.ply, rasterizes one fixed 800x800 view and writes it to main1.jpg. 108 | int main() 109 | { 110 | 111 | auto gs_data = MCGS::load_ply("point_cloud.ply"); 112 | 113 | auto background = torch::tensor({ 0., 0., 0. }); 114 | 115 | auto xyz = MCGS::get_xyz(gs_data); 116 | 117 | std::vector xyz_3 { 118 | xyz.begin(), 119 | xyz.end() - 3 120 | }; 121 | auto scale_d = MCGS::get_scale(gs_data); 122 | auto dc_012 = MCGS::get_dc_012(gs_data); 123 | auto dc_rest = MCGS::get_dc_rest(gs_data); 124 | 125 | auto opacity_d = MCGS::get_opacity(gs_data); 126 | auto rotations_d = MCGS::get_rotation(gs_data); 127 | 128 | torch::Tensor tensor_xyz = torch::from_blob(xyz.data(), xyz.size(), torch::kFloat); 129 | 130 | torch::Tensor tensor_dc_012 = torch::from_blob(dc_012.data(), dc_012.size(), torch::kFloat); 131 | 132 | torch::Tensor tensor_dc_rest = torch::from_blob(dc_rest.data(), dc_rest.size(), torch::kFloat); 133 | // std::vector feature = dc_012; 134 | std::vector feature(dc_012.size() + dc_rest.size()); 135 | for (int i = 0; i < dc_012.size() / 3; i++) { // pack 48 floats per point: 3 from dc_012 followed by 45 from dc_rest 136 | for (int j = 0; j < 3; j++) { 137 | 138 | feature[i * 48 + j] = dc_012[i * 3 + j]; 139 | } 140 | for (int j = 0; j < 45; j++) { 141 | feature[i * 48 + 3 + j] = dc_rest[i * 45 + j]; 142 | } 143 | } 144 | torch::Tensor tensor_feature = torch::from_blob(feature.data(), feature.size(), torch::kFloat); 145 | torch::Tensor tensor_opacity_d = torch::from_blob(opacity_d.data(), opacity_d.size(), torch::kFloat); 146 | 147 | torch::Tensor tensor_scale_d = torch::from_blob(scale_d.data(), scale_d.size(), torch::kFloat); 148 | 149 | torch::Tensor tensor_rotations_d = torch::from_blob(rotations_d.data(), rotations_d.size(), torch::kFloat); 150 | 151 | auto proj = torch::tensor({ -2.760816, 152 | .300833, 153 | -0.021124, 154 | -0.021122, 155 | .306501, 156 | 2.70976, 157 | -0.190277, 158 | -0.190258, 159 | -0., 160 | -0.531742, 161 | -0.981605, 162 | -0.981507, 163 | -0., 164 | -0., 165 | 4.021532, 166 |
4.031129 }, 167 | torch::kFloat); 168 | 169 | auto c2w = torch::tensor({ -.993989, 170 | .1083, 171 | -.021122, 172 | 0., 173 | .11034, 174 | .97551, 175 | -.19026, 176 | 0., 177 | 0., 178 | -.19143, 179 | -.98151, 180 | 0., 181 | 0., 182 | 0., 183 | 4.0311, 184 | 1. }, 185 | torch::kFloat); 186 | 187 | auto camera_pos = torch::tensor({ 0.0851, 0.7670, 0.39566 }, torch::kFloat); 188 | auto color = torch::tensor({}, torch::kFloat); 189 | auto pre_comp = torch::tensor({}, torch::kFloat); 190 | 191 | auto res = RasterizeGaussiansCUDA(background, 192 | tensor_xyz.reshape({ -1, 3 }), 193 | color, 194 | tensor_opacity_d.reshape({ -1, 1 }), 195 | tensor_scale_d.reshape({ -1, 3 }), 196 | tensor_rotations_d.reshape({ -1, 4 }), 197 | 1.f, 198 | pre_comp, 199 | c2w.reshape({ 4, 4 }), 200 | proj.reshape({ 4, 4 }), 201 | 0.36, 202 | 0.36, 203 | 800, 204 | 800, 205 | tensor_feature.reshape({ -1, 16, 3 }), 206 | 3, 207 | camera_pos, 208 | false, 209 | true); 210 | 211 | auto rr = std::get<1>(res); 212 | auto num_rendered = std::get<0>(res); 213 | // Reorder the (channel, row, column) image to (row, column, channel) so the channels of a pixel are adjacent. 214 | rr = rr.permute({ 1, 2, 0 }); 215 | 216 | rr = rr.contiguous(); 217 | auto tt = rr.type(); 218 | auto data_ptr = rr.data_ptr(); 219 | size_t vector_size = rr.numel(); 220 | std::vector vec(data_ptr, data_ptr + vector_size); 221 | std::for_each(vec.begin(), 222 | vec.end(), 223 | [](float& i) { 224 | i *= 255.f; 225 | }); 226 | 227 | std::vector data(800 * 800 * 3); 228 | for (size_t i = 0; i < data.size() && i < vec.size(); i++) { 229 | // Clamp before narrowing: converting a float outside [0, 255] (or NaN) to unsigned char is undefined behaviour. 230 | data[i] = (unsigned char)(!(vec[i] > 0.f) ? 0.f : (vec[i] > 255.f ? 255.f : vec[i])); 231 | } 232 | 233 | // The last argument of stbi_write_jpg is the JPEG quality (1..100), not a row stride; 234 | // the former value 800 * 3 was clamped to 100 by the library, so 100 keeps the output unchanged. 235 | 236 | auto suc = stbi_write_jpg("main1.jpg", 800, 800, 3, data.data(), 100); 237 | std::cout << suc << std::endl; 238 | 239 | return 0; 240 | } 241 | -------------------------------------------------------------------------------- /main1.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/MouseChannel/MCGS/c9a89849d473c44133f052451ebaff322fde5e89/main1.jpg -------------------------------------------------------------------------------- /nsightcompute.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MouseChannel/MCGS/c9a89849d473c44133f052451ebaff322fde5e89/nsightcompute.txt -------------------------------------------------------------------------------- /origin_cuda/auxiliary.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_AUXILIARY_H_INCLUDED 13 | #define CUDA_RASTERIZER_AUXILIARY_H_INCLUDED 14 | 15 | #include "config.h" 16 | #include "stdio.h" 17 | 18 | #define BLOCK_SIZE (BLOCK_X * BLOCK_Y) 19 | #define NUM_WARPS (BLOCK_SIZE/32) 20 | 21 | // Spherical harmonics coefficients 22 | __device__ const float SH_C0 = 0.28209479177387814f; 23 | __device__ const float SH_C1 = 0.4886025119029199f; 24 | __device__ const float SH_C2[] = { 25 | 1.0925484305920792f, 26 | -1.0925484305920792f, 27 | 0.31539156525252005f, 28 | -1.0925484305920792f, 29 | 0.5462742152960396f 30 | }; 31 | __device__ const float SH_C3[] = { 32 | -0.5900435899266435f, 33 | 2.890611442640554f, 34 | -0.4570457994644658f, 35 | 0.3731763325901154f, 36 | -0.4570457994644658f, 37 | 1.445305721320277f, 38 | -0.5900435899266435f 39 | }; 40 | 41 | __forceinline__ __device__ float ndc2Pix(float v, int S) 42 | { 43 | return ((v + 1.0) * S - 1.0) * 0.5; 44 | } 45 | 46 | __forceinline__ __device__ void getRect(const float2 p, int max_radius, uint2& rect_min, uint2& rect_max, dim3 grid) 47 | { 48 | rect_min 
= { 49 | min(grid.x, max((int)0, (int)((p.x - max_radius) / BLOCK_X))), 50 | min(grid.y, max((int)0, (int)((p.y - max_radius) / BLOCK_Y))) 51 | }; 52 | rect_max = { 53 | min(grid.x, max((int)0, (int)((p.x + max_radius + BLOCK_X - 1) / BLOCK_X))), 54 | min(grid.y, max((int)0, (int)((p.y + max_radius + BLOCK_Y - 1) / BLOCK_Y))) 55 | }; 56 | } 57 | // Transforms point p (implicit w = 1) by the first three rows of a 4x4 matrix stored column-major: element (row r, column c) is matrix[4 * c + r]. 58 | __forceinline__ __device__ float3 transformPoint4x3(const float3& p, const float* matrix) 59 | { 60 | float3 transformed = { 61 | matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12], 62 | matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13], 63 | matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14], 64 | }; 65 | return transformed; 66 | } 67 | // Same as transformPoint4x3 but also returns the fourth (w) row; no perspective divide is applied here. 68 | __forceinline__ __device__ float4 transformPoint4x4(const float3& p, const float* matrix) 69 | { 70 | float4 transformed = { 71 | matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12], 72 | matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13], 73 | matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14], 74 | matrix[3] * p.x + matrix[7] * p.y + matrix[11] * p.z + matrix[15] 75 | }; 76 | return transformed; 77 | } 78 | // Transforms direction p by the upper-left 3x3 of the matrix; the translation entries (12..14) are not used. 79 | __forceinline__ __device__ float3 transformVec4x3(const float3& p, const float* matrix) 80 | { 81 | float3 transformed = { 82 | matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z, 83 | matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z, 84 | matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z, 85 | }; 86 | return transformed; 87 | } 88 | // Like transformVec4x3, but multiplies by the transpose of the upper-left 3x3. 89 | __forceinline__ __device__ float3 transformVec4x3Transpose(const float3& p, const float* matrix) 90 | { 91 | float3 transformed = { 92 | matrix[0] * p.x + matrix[1] * p.y + matrix[2] * p.z, 93 | matrix[4] * p.x + matrix[5] * p.y + matrix[6] * p.z, 94 | matrix[8] * p.x + matrix[9] * p.y + matrix[10] * p.z, 95 | }; 96 | return transformed; 97 | } 98 | 99 | __forceinline__ __device__ float dnormvdz(float3 v, float3 dv) 100 | { 101 | float sum2 =
v.x * v.x + v.y * v.y + v.z * v.z; 102 | float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); 103 | float dnormvdz = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32; 104 | return dnormvdz; 105 | } 106 | 107 | __forceinline__ __device__ float3 dnormvdv(float3 v, float3 dv) 108 | { 109 | float sum2 = v.x * v.x + v.y * v.y + v.z * v.z; 110 | float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); 111 | 112 | float3 dnormvdv; 113 | dnormvdv.x = ((+sum2 - v.x * v.x) * dv.x - v.y * v.x * dv.y - v.z * v.x * dv.z) * invsum32; 114 | dnormvdv.y = (-v.x * v.y * dv.x + (sum2 - v.y * v.y) * dv.y - v.z * v.y * dv.z) * invsum32; 115 | dnormvdv.z = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32; 116 | return dnormvdv; 117 | } 118 | 119 | __forceinline__ __device__ float4 dnormvdv(float4 v, float4 dv) 120 | { 121 | float sum2 = v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w; 122 | float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); 123 | 124 | float4 vdv = { v.x * dv.x, v.y * dv.y, v.z * dv.z, v.w * dv.w }; 125 | float vdv_sum = vdv.x + vdv.y + vdv.z + vdv.w; 126 | float4 dnormvdv; 127 | dnormvdv.x = ((sum2 - v.x * v.x) * dv.x - v.x * (vdv_sum - vdv.x)) * invsum32; 128 | dnormvdv.y = ((sum2 - v.y * v.y) * dv.y - v.y * (vdv_sum - vdv.y)) * invsum32; 129 | dnormvdv.z = ((sum2 - v.z * v.z) * dv.z - v.z * (vdv_sum - vdv.z)) * invsum32; 130 | dnormvdv.w = ((sum2 - v.w * v.w) * dv.w - v.w * (vdv_sum - vdv.w)) * invsum32; 131 | return dnormvdv; 132 | } 133 | 134 | __forceinline__ __device__ float sigmoid(float x) 135 | { 136 | return 1.0f / (1.0f + expf(-x)); 137 | } 138 | 139 | __forceinline__ __device__ bool in_frustum(int idx, 140 | const float* orig_points, 141 | const float* viewmatrix, 142 | const float* projmatrix, 143 | bool prefiltered, 144 | float3& p_view) 145 | { 146 | float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] }; 147 | 148 | // Bring points to screen space 149 | float4 p_hom = 
transformPoint4x4(p_orig, projmatrix); 150 | float p_w = 1.0f / (p_hom.w + 0.0000001f); 151 | float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w }; 152 | p_view = transformPoint4x3(p_orig, viewmatrix); 153 | 154 | if (p_view.z <= 0.2f)// || ((p_proj.x < -1.3 || p_proj.x > 1.3 || p_proj.y < -1.3 || p_proj.y > 1.3))) 155 | { 156 | if (prefiltered) 157 | { 158 | printf("Point is filtered although prefiltered is set. This shouldn't happen!"); 159 | __trap(); 160 | } 161 | return false; 162 | } 163 | return true; 164 | } 165 | 166 | #define CHECK_CUDA(A, debug) \ 167 | A; if(debug) { \ 168 | auto ret = cudaDeviceSynchronize(); \ 169 | if (ret != cudaSuccess) { \ 170 | std::cerr << "\n[CUDA ERROR] in " << __FILE__ << "\nLine " << __LINE__ << ": " << cudaGetErrorString(ret); \ 171 | throw std::runtime_error(cudaGetErrorString(ret)); \ 172 | } \ 173 | } 174 | 175 | #endif -------------------------------------------------------------------------------- /origin_cuda/config.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_CONFIG_H_INCLUDED 13 | #define CUDA_RASTERIZER_CONFIG_H_INCLUDED 14 | 15 | #define NUM_CHANNELS 3 // Default 3, RGB 16 | #define BLOCK_X 16 17 | #define BLOCK_Y 16 18 | 19 | #endif -------------------------------------------------------------------------------- /origin_cuda/forward.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 
5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include "auxiliary.h" 13 | #include "forward.h" 14 | #include 15 | #include 16 | #include 17 | namespace cg = cooperative_groups; 18 | 19 | // Forward method for converting the input spherical harmonics 20 | // coefficients of each Gaussian to a simple RGB color. 21 | __device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, bool* clamped) 22 | { 23 | // The implementation is loosely based on code for 24 | // "Differentiable Point-Based Radiance Fields for 25 | // Efficient View Synthesis" by Zhang et al. (2022) 26 | glm::vec3 pos = means[idx]; 27 | glm::vec3 dir = pos - campos; 28 | dir = dir / glm::length(dir); 29 | 30 | glm::vec3* sh = ((glm::vec3*)shs) + idx * max_coeffs; 31 | glm::vec3 result = SH_C0 * sh[0]; 32 | 33 | if (deg > 0) { 34 | float x = dir.x; 35 | float y = dir.y; 36 | float z = dir.z; 37 | result = result - SH_C1 * y * sh[1] + SH_C1 * z * sh[2] - SH_C1 * x * sh[3]; 38 | 39 | if (deg > 1) { 40 | float xx = x * x, yy = y * y, zz = z * z; 41 | float xy = x * y, yz = y * z, xz = x * z; 42 | result = result + 43 | SH_C2[0] * xy * sh[4] + 44 | SH_C2[1] * yz * sh[5] + 45 | SH_C2[2] * (2.0f * zz - xx - yy) * sh[6] + 46 | SH_C2[3] * xz * sh[7] + 47 | SH_C2[4] * (xx - yy) * sh[8]; 48 | 49 | if (deg > 2) { 50 | result = result + 51 | SH_C3[0] * y * (3.0f * xx - yy) * sh[9] + 52 | SH_C3[1] * xy * z * sh[10] + 53 | SH_C3[2] * y * (4.0f * zz - xx - yy) * sh[11] + 54 | SH_C3[3] * z * (2.0f * zz - 3.0f * xx - 3.0f * yy) * sh[12] + 55 | SH_C3[4] * x * (4.0f * zz - xx - yy) * sh[13] + 56 | SH_C3[5] * z * (xx - yy) * sh[14] + 57 | SH_C3[6] * x * (xx - 3.0f * yy) * sh[15]; 58 | } 59 | } 60 | } 61 | result += 0.5f; 62 | 63 | // RGB colors are clamped to positive values. 
If values are 64 | // clamped, we need to keep track of this for the backward pass. 65 | clamped[3 * idx + 0] = (result.x < 0); 66 | clamped[3 * idx + 1] = (result.y < 0); 67 | clamped[3 * idx + 2] = (result.z < 0); 68 | return glm::max(result, 0.0f); 69 | } 70 | 71 | // Forward version of 2D covariance matrix computation 72 | __device__ float3 computeCov2D(const float3& mean, float focal_x, float focal_y, float tan_fovx, float tan_fovy, const float* cov3D, const float* viewmatrix) 73 | { 74 | // The following models the steps outlined by equations 29 75 | // and 31 in "EWA Splatting" (Zwicker et al., 2002). 76 | // Additionally considers aspect / scaling of viewport. 77 | // Transposes used to account for row-/column-major conventions. 78 | float3 t = transformPoint4x3(mean, viewmatrix); 79 | 80 | const float limx = 1.3f * tan_fovx; 81 | const float limy = 1.3f * tan_fovy; 82 | const float txtz = t.x / t.z; 83 | const float tytz = t.y / t.z; 84 | t.x = min(limx, max(-limx, txtz)) * t.z; 85 | t.y = min(limy, max(-limy, tytz)) * t.z; 86 | 87 | glm::mat3 J = glm::mat3( 88 | focal_x / t.z, 89 | 0.0f, 90 | -(focal_x * t.x) / (t.z * t.z), 91 | 0.0f, 92 | focal_y / t.z, 93 | -(focal_y * t.y) / (t.z * t.z), 94 | 0, 95 | 0, 96 | 0); 97 | 98 | glm::mat3 W = glm::mat3( 99 | viewmatrix[0], 100 | viewmatrix[4], 101 | viewmatrix[8], 102 | viewmatrix[1], 103 | viewmatrix[5], 104 | viewmatrix[9], 105 | viewmatrix[2], 106 | viewmatrix[6], 107 | viewmatrix[10]); 108 | 109 | glm::mat3 T = W * J; 110 | 111 | glm::mat3 Vrk = glm::mat3( 112 | cov3D[0], 113 | cov3D[1], 114 | cov3D[2], 115 | cov3D[1], 116 | cov3D[3], 117 | cov3D[4], 118 | cov3D[2], 119 | cov3D[4], 120 | cov3D[5]); 121 | 122 | glm::mat3 cov = glm::transpose(T) * glm::transpose(Vrk) * T; 123 | 124 | // Apply low-pass filter: every Gaussian should be at least 125 | // one pixel wide/high. Discard 3rd row and column. 
126 | cov[0][0] += 0.3f; 127 | cov[1][1] += 0.3f; 128 | return { float(cov[0][0]), float(cov[0][1]), float(cov[1][1]) }; 129 | } 130 | 131 | // Forward method for converting scale and rotation properties of each 132 | // Gaussian to a 3D covariance matrix in world space. The quaternion is used 133 | // as given: its normalization is commented out below (presumably callers pass unit quaternions - TODO confirm). 134 | __device__ void computeCov3D(const glm::vec3 scale, float mod, const glm::vec4 rot, float* cov3D) 135 | { 136 | // Create scaling matrix 137 | glm::mat3 S = glm::mat3(1.0f); 138 | S[0][0] = mod * scale.x; 139 | S[1][1] = mod * scale.y; 140 | S[2][2] = mod * scale.z; 141 | 142 | // Quaternion is taken as-is; the division by its length is disabled 143 | glm::vec4 q = rot; // / glm::length(rot); 144 | float r = q.x; // rot is read as (r, x, y, z), scalar part first 145 | float x = q.y; 146 | float y = q.z; 147 | float z = q.w; 148 | 149 | // Compute rotation matrix from quaternion 150 | glm::mat3 R = glm::mat3( 151 | 1.f - 2.f * (y * y + z * z), 152 | 2.f * (x * y - r * z), 153 | 2.f * (x * z + r * y), 154 | 2.f * (x * y + r * z), 155 | 1.f - 2.f * (x * x + z * z), 156 | 2.f * (y * z - r * x), 157 | 2.f * (x * z - r * y), 158 | 2.f * (y * z + r * x), 159 | 1.f - 2.f * (x * x + y * y)); 160 | 161 | glm::mat3 M = S * R; 162 | 163 | // Compute 3D world covariance matrix Sigma 164 | glm::mat3 Sigma = glm::transpose(M) * M; 165 | 166 | // Covariance is symmetric, only store upper right (6 values written to cov3D[0..5]) 167 | cov3D[0] = Sigma[0][0]; 168 | cov3D[1] = Sigma[0][1]; 169 | cov3D[2] = Sigma[0][2]; 170 | cov3D[3] = Sigma[1][1]; 171 | cov3D[4] = Sigma[1][2]; 172 | cov3D[5] = Sigma[2][2]; 173 | } 174 | 175 | // Perform initial steps for each Gaussian prior to rasterization.
176 | template 177 | __global__ void preprocessCUDA(int P, 178 | int D, 179 | int M, 180 | const float* orig_points, 181 | glm::vec3* scales, 182 | const float scale_modifier, 183 | const glm::vec4* rotations, 184 | const float* opacities, 185 | const float* shs, 186 | bool* clamped, 187 | const float* cov3D_precomp, 188 | const float* colors_precomp, 189 | const float* viewmatrix, 190 | const float* projmatrix, 191 | const glm::vec3* cam_pos, 192 | const int W, 193 | int H, 194 | const float tan_fovx, 195 | float tan_fovy, 196 | const float focal_x, 197 | float focal_y, 198 | int* radii, 199 | float2* points_xy_image, 200 | float* depths, 201 | float* cov3Ds, 202 | float* rgb, 203 | float4* conic_opacity, 204 | const dim3 grid, 205 | uint32_t* tiles_touched, 206 | bool prefiltered) 207 | { 208 | auto idx = cg::this_grid().thread_rank(); 209 | if (idx >= P) 210 | return; 211 | 212 | // cuPrintf(); 213 | // Initialize radius and touched tiles to 0. If this isn't changed, 214 | // this Gaussian will not be processed further. 215 | radii[idx] = 0; 216 | tiles_touched[idx] = 0; 217 | 218 | // Perform near culling, quit if outside. 
219 | float3 p_view; 220 | if (!in_frustum(idx, orig_points, viewmatrix, projmatrix, prefiltered, p_view)) 221 | return; 222 | 223 | // Transform point by projecting 224 | float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] }; 225 | float4 p_hom = transformPoint4x4(p_orig, projmatrix); 226 | float p_w = 1.0f / (p_hom.w + 0.0000001f); 227 | float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w }; 228 | // if (idx == 1) { 229 | 230 | // for (int i = 0; i < 16; i++) { 231 | // printf("in2 %f \n", projmatrix[i]); 232 | // } 233 | // float tt = projmatrix[1] * p_orig.x; 234 | // float tt2 = projmatrix[5] * p_orig.y; 235 | // float tt3 = projmatrix[9] * p_orig.z; 236 | // float tt4 = projmatrix[13]; 237 | // // + projmatrix[5] * p_orig.y + projmatrix[9] * p_orig.z + projmatrix[13]; 238 | // printf("herehere %f %f %f %f\n", tt,tt2,tt3,tt4); 239 | // printf("incudaasf %f %f %f \n ", p_orig.x, p_orig.y, p_orig.z); 240 | // printf("incudaasf %f %f %f %f \n ", p_hom.x, p_hom.y, p_hom.z, p_hom.w); 241 | // } 242 | 243 | // If 3D covariance matrix is precomputed, use it, otherwise compute 244 | // from scaling and rotation parameters. 245 | const float* cov3D; 246 | if (cov3D_precomp != nullptr) { 247 | cov3D = cov3D_precomp + idx * 6; 248 | } else { 249 | computeCov3D(scales[idx], scale_modifier, rotations[idx], cov3Ds + idx * 6); 250 | cov3D = cov3Ds + idx * 6; 251 | } 252 | // Compute 2D screen-space covariance matrix 253 | float3 cov = computeCov2D(p_orig, focal_x, focal_y, tan_fovx, tan_fovy, cov3D, viewmatrix); 254 | // Invert covariance (EWA algorithm) 255 | 256 | float det = (cov.x * cov.z - cov.y * cov.y); 257 | if (det == 0.0f) 258 | return; 259 | float det_inv = 1.f / det; 260 | float3 conic = { cov.z * det_inv, -cov.y * det_inv, cov.x * det_inv }; 261 | 262 | // Compute extent in screen space (by finding eigenvalues of 263 | // 2D covariance matrix). 
Use extent to compute a bounding rectangle 264 | // of screen-space tiles that this Gaussian overlaps with. Quit if 265 | // rectangle covers 0 tiles. 266 | float mid = 0.5f * (cov.x + cov.z); 267 | float lambda1 = mid + sqrt(max(0.1f, mid * mid - det)); 268 | float lambda2 = mid - sqrt(max(0.1f, mid * mid - det)); 269 | float my_radius = ceil(3.f * sqrt(max(lambda1, lambda2))); 270 | float2 point_image = { ndc2Pix(p_proj.x, W), ndc2Pix(p_proj.y, H) }; 271 | uint2 rect_min, rect_max; 272 | getRect(point_image, my_radius, rect_min, rect_max, grid); 273 | 274 | if ((rect_max.x - rect_min.x) * (rect_max.y - rect_min.y) == 0) 275 | return; 276 | // if (idx == 0) { 277 | // printf("incuda%f %f %d %d %d %d\n", p_proj.x, p_proj.y, rect_min.x, rect_min.y, rect_max.x, rect_max.y); 278 | // } 279 | 280 | // If colors have been precomputed, use them, otherwise convert 281 | // spherical harmonics coefficients to RGB color. 282 | if (colors_precomp == nullptr) { 283 | glm::vec3 result = computeColorFromSH(idx, D, M, (glm::vec3*)orig_points, *cam_pos, shs, clamped); 284 | rgb[idx * C + 0] = result.x; 285 | rgb[idx * C + 1] = result.y; 286 | rgb[idx * C + 2] = result.z; 287 | } 288 | 289 | // Store some useful helper data for the next steps. 
290 | depths[idx] = p_view.z; 291 | radii[idx] = my_radius; 292 | points_xy_image[idx] = point_image; 293 | // Inverse 2D covariance and opacity neatly pack into one float4 294 | conic_opacity[idx] = { conic.x, conic.y, conic.z, opacities[idx] }; 295 | tiles_touched[idx] = (rect_max.y - rect_min.y) * (rect_max.x - rect_min.x); 296 | } 297 | 298 | __global__ void preprocessCUDA1(int P, int D, int M, const float* orig_points, const glm::vec3* scales, const float scale_modifier, const glm::vec4* rotations, const float* opacities, const float* shs, bool* clamped, const float* cov3D_precomp, const float* colors_precomp, const float* viewmatrix, const float* projmatrix, const glm::vec3* cam_pos, const int W, int H, const float tan_fovx, float tan_fovy, const float focal_x, float focal_y, int* radii, float2* points_xy_image, float* depths, float* cov3Ds, float* rgb, float4* conic_opacity, const dim3 grid, uint32_t* tiles_touched, bool prefiltered) 299 | { 300 | auto idx = cg::this_grid().thread_rank(); 301 | 302 | if (idx >= P) 303 | return; 304 | // printf("%f scale outside= %f \n",focal_x ,(*scales).y); 305 | // cuPrintf(); 306 | // Initialize radius and touched tiles to 0. If this isn't changed, 307 | // this Gaussian will not be processed further. 308 | radii[idx] = 0; 309 | tiles_touched[idx] = 0; 310 | 311 | // Perform near culling, quit if outside. 312 | float3 p_view; 313 | if (!in_frustum(idx, orig_points, viewmatrix, projmatrix, prefiltered, p_view)) 314 | return; 315 | 316 | // Transform point by projecting 317 | float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] }; 318 | float4 p_hom = transformPoint4x4(p_orig, projmatrix); 319 | float p_w = 1.0f / (p_hom.w + 0.0000001f); 320 | float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w }; 321 | 322 | // If 3D covariance matrix is precomputed, use it, otherwise compute 323 | // from scaling and rotation parameters. 
324 | const float* cov3D; 325 | if (cov3D_precomp != nullptr) { 326 | cov3D = cov3D_precomp + idx * 6; 327 | } else { 328 | computeCov3D(scales[idx], scale_modifier, rotations[idx], cov3Ds + idx * 6); 329 | cov3D = cov3Ds + idx * 6; 330 | } 331 | 332 | // Compute 2D screen-space covariance matrix 333 | float3 cov = computeCov2D(p_orig, focal_x, focal_y, tan_fovx, tan_fovy, cov3D, viewmatrix); 334 | 335 | // Invert covariance (EWA algorithm) 336 | 337 | float det = (cov.x * cov.z - cov.y * cov.y); 338 | if (det == 0.0f) 339 | return; 340 | float det_inv = 1.f / det; 341 | float3 conic = { cov.z * det_inv, -cov.y * det_inv, cov.x * det_inv }; 342 | 343 | // Compute extent in screen space (by finding eigenvalues of 344 | // 2D covariance matrix). Use extent to compute a bounding rectangle 345 | // of screen-space tiles that this Gaussian overlaps with. Quit if 346 | // rectangle covers 0 tiles. 347 | float mid = 0.5f * (cov.x + cov.z); 348 | float lambda1 = mid + sqrt(max(0.1f, mid * mid - det)); 349 | float lambda2 = mid - sqrt(max(0.1f, mid * mid - det)); 350 | float my_radius = ceil(3.f * sqrt(max(lambda1, lambda2))); 351 | float2 point_image = { ndc2Pix(p_proj.x, W), ndc2Pix(p_proj.y, H) }; 352 | uint2 rect_min, rect_max; 353 | getRect(point_image, my_radius, rect_min, rect_max, grid); 354 | if ((rect_max.x - rect_min.x) * (rect_max.y - rect_min.y) == 0) 355 | return; 356 | 357 | // If colors have been precomputed, use them, otherwise convert 358 | // spherical harmonics coefficients to RGB color. 359 | if (colors_precomp == nullptr) { 360 | glm::vec3 result = computeColorFromSH(idx, D, M, (glm::vec3*)orig_points, *cam_pos, shs, clamped); 361 | rgb[idx * 3 + 0] = result.x; 362 | rgb[idx * 3 + 1] = result.y; 363 | rgb[idx * 3 + 2] = result.z; 364 | } 365 | 366 | // Store some useful helper data for the next steps. 
367 | depths[idx] = p_view.z; 368 | radii[idx] = my_radius; 369 | points_xy_image[idx] = point_image; 370 | // Inverse 2D covariance and opacity neatly pack into one float4 371 | conic_opacity[idx] = { conic.x, conic.y, conic.z, opacities[idx] }; 372 | tiles_touched[idx] = (rect_max.y - rect_min.y) * (rect_max.x - rect_min.x); 373 | 374 | // printf("1212 %s \n",orig_points[0]); 375 | } 376 | 377 | // Main rasterization method. Collaboratively works on one tile per 378 | // block, each thread treats one pixel. Alternates between fetching 379 | // and rasterizing data. 380 | template 381 | __global__ void __launch_bounds__(BLOCK_X* BLOCK_Y) 382 | renderCUDA( 383 | const uint2* __restrict__ ranges, 384 | const uint32_t* __restrict__ point_list, 385 | int W, 386 | int H, 387 | const float2* __restrict__ points_xy_image, 388 | const float* __restrict__ features, 389 | const float4* __restrict__ conic_opacity, 390 | float* __restrict__ final_T, 391 | uint32_t* __restrict__ n_contrib, 392 | const float* __restrict__ bg_color, 393 | float* __restrict__ out_color) 394 | { 395 | // Identify current tile and associated min/max pixel range. 396 | auto block = cg::this_thread_block(); 397 | uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X; 398 | uint2 pix_min = { block.group_index().x * BLOCK_X, block.group_index().y * BLOCK_Y }; 399 | uint2 pix_max = { min(pix_min.x + BLOCK_X, W), min(pix_min.y + BLOCK_Y, H) }; 400 | uint2 pix = { pix_min.x + block.thread_index().x, pix_min.y + block.thread_index().y }; 401 | uint32_t pix_id = W * pix.y + pix.x; 402 | float2 pixf = { (float)pix.x, (float)pix.y }; 403 | 404 | // Check if this thread is associated with a valid pixel or outside. 405 | bool inside = pix.x < W && pix.y < H; 406 | // Done threads can help with fetching, but don't rasterize 407 | bool done = !inside; 408 | 409 | // Load start/end range of IDs to process in bit sorted list. 
410 | uint2 range = ranges[block.group_index().y * horizontal_blocks + block.group_index().x]; 411 | const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE); 412 | int toDo = range.y - range.x; 413 | 414 | // Allocate storage for batches of collectively fetched data. 415 | __shared__ int collected_id[BLOCK_SIZE]; 416 | __shared__ float2 collected_xy[BLOCK_SIZE]; 417 | __shared__ float4 collected_conic_opacity[BLOCK_SIZE]; 418 | 419 | // Initialize helper variables 420 | float T = 1.0f; 421 | uint32_t contributor = 0; 422 | uint32_t last_contributor = 0; 423 | float C[CHANNELS] = { 0 }; 424 | 425 | // Iterate over batches until all done or range is complete 426 | for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE) { 427 | // End if entire block votes that it is done rasterizing 428 | int num_done = __syncthreads_count(done); 429 | if (num_done == BLOCK_SIZE) 430 | break; 431 | 432 | // Collectively fetch per-Gaussian data from global to shared 433 | int progress = i * BLOCK_SIZE + block.thread_rank(); 434 | if (range.x + progress < range.y) { 435 | int coll_id = point_list[range.x + progress]; 436 | collected_id[block.thread_rank()] = coll_id; 437 | collected_xy[block.thread_rank()] = points_xy_image[coll_id]; 438 | collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id]; 439 | } 440 | block.sync(); 441 | 442 | // Iterate over current batch 443 | for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++) { 444 | // Keep track of current position in range 445 | contributor++; 446 | 447 | // Resample using conic matrix (cf. "Surface 448 | // Splatting" by Zwicker et al., 2001) 449 | float2 xy = collected_xy[j]; 450 | float2 d = { xy.x - pixf.x, xy.y - pixf.y }; 451 | float4 con_o = collected_conic_opacity[j]; 452 | float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) - con_o.y * d.x * d.y; 453 | if (power > 0.0f) 454 | continue; 455 | 456 | // Eq. (2) from 3D Gaussian splatting paper. 
457 | // Obtain alpha by multiplying with Gaussian opacity 458 | // and its exponential falloff from mean. 459 | // Avoid numerical instabilities (see paper appendix). 460 | float alpha = min(0.99f, con_o.w * exp(power)); 461 | if (alpha < 1.0f / 255.0f) 462 | continue; 463 | float test_T = T * (1 - alpha); 464 | if (test_T < 0.0001f) { 465 | done = true; 466 | continue; 467 | } 468 | 469 | // Eq. (3) from 3D Gaussian splatting paper. 470 | for (int ch = 0; ch < CHANNELS; ch++) 471 | C[ch] += features[collected_id[j] * CHANNELS + ch] * alpha * T; 472 | 473 | T = test_T; 474 | 475 | // Keep track of last range entry to update this 476 | // pixel. 477 | last_contributor = contributor; 478 | } 479 | } 480 | 481 | // All threads that treat valid pixel write out their final 482 | // rendering data to the frame and auxiliary buffers. 483 | if (inside) { 484 | final_T[pix_id] = T; 485 | n_contrib[pix_id] = last_contributor; 486 | for (int ch = 0; ch < CHANNELS; ch++) 487 | out_color[ch * H * W + pix_id] = C[ch] + T * bg_color[ch]; 488 | } 489 | } 490 | 491 | void FORWARD::render( 492 | const dim3 grid, 493 | dim3 block, 494 | const uint2* ranges, 495 | const uint32_t* point_list, 496 | int W, 497 | int H, 498 | const float2* means2D, 499 | const float* colors, 500 | const float4* conic_opacity, 501 | float* final_T, 502 | uint32_t* n_contrib, 503 | const float* bg_color, 504 | float* out_color) 505 | { 506 | renderCUDA<<>>( 507 | ranges, 508 | point_list, 509 | W, 510 | H, 511 | means2D, 512 | colors, 513 | conic_opacity, 514 | final_T, 515 | n_contrib, 516 | bg_color, 517 | out_color); 518 | } 519 | 520 | void FORWARD::preprocess(int P, int D, int M, const float* means3D, glm::vec3* scales, const float scale_modifier, const glm::vec4* rotations, const float* opacities, const float* shs, bool* clamped, const float* cov3D_precomp, const float* colors_precomp, const float* viewmatrix, const float* projmatrix, const glm::vec3* cam_pos, const int W, int H, const float 
focal_x, float focal_y, const float tan_fovx, float tan_fovy, int* radii, float2* means2D, float* depths, float* cov3Ds, float* rgb, float4* conic_opacity, const dim3 grid, uint32_t* tiles_touched, bool prefiltered) 521 | { 522 | // printf("%f scale outside= %f \n",focal_x ,rotations[0]); 523 | // preprocessCUDA1<<<(P + 255) / 256, 256>>>( 524 | // P, 525 | // D, 526 | // M, 527 | // means3D, 528 | // scales, 529 | // scale_modifier, 530 | // rotations, 531 | // opacities, 532 | // shs, 533 | // clamped, 534 | // cov3D_precomp, 535 | // colors_precomp, 536 | // viewmatrix, 537 | // projmatrix, 538 | // cam_pos, 539 | // W, 540 | // H, 541 | // tan_fovx, 542 | // tan_fovy, 543 | // focal_x, 544 | // focal_y, 545 | // radii, 546 | // means2D, 547 | // depths, 548 | // cov3Ds, 549 | // rgb, 550 | // conic_opacity, 551 | // grid, 552 | // tiles_touched, 553 | // prefiltered); 554 | // printf("%f scale outside= %f \n",focal_x ,scales[0]); 555 | // printf( "%d out \n", (u_int64_t)scales); 556 | preprocessCUDA<<<(P + 255) / 256, 256>>>( 557 | // preprocessCUDA<<<1,1>>>( 558 | P, 559 | D, 560 | M, 561 | means3D, 562 | scales, 563 | scale_modifier, 564 | rotations, 565 | opacities, 566 | shs, 567 | clamped, 568 | cov3D_precomp, 569 | colors_precomp, 570 | viewmatrix, 571 | projmatrix, 572 | cam_pos, 573 | W, 574 | H, 575 | tan_fovx, 576 | tan_fovy, 577 | focal_x, 578 | focal_y, 579 | radii, 580 | means2D, 581 | depths, 582 | cov3Ds, 583 | rgb, 584 | conic_opacity, 585 | grid, 586 | tiles_touched, 587 | prefiltered); 588 | // printf("%f scale outside2= %f \n",focal_x ,scales[0]); 589 | } -------------------------------------------------------------------------------- /origin_cuda/forward.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 
5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_FORWARD_H_INCLUDED 13 | #define CUDA_RASTERIZER_FORWARD_H_INCLUDED 14 | 15 | #include 16 | #include "cuda_runtime.h" 17 | #include "device_launch_parameters.h" 18 | #define GLM_FORCE_CUDA 19 | #include 20 | 21 | namespace FORWARD 22 | { 23 | // Perform initial steps for each Gaussian prior to rasterization. 24 | void preprocess(int P, int D, int M, 25 | const float* orig_points, 26 | glm::vec3* scales, 27 | const float scale_modifier, 28 | const glm::vec4* rotations, 29 | const float* opacities, 30 | const float* shs, 31 | bool* clamped, 32 | const float* cov3D_precomp, 33 | const float* colors_precomp, 34 | const float* viewmatrix, 35 | const float* projmatrix, 36 | const glm::vec3* cam_pos, 37 | const int W, int H, 38 | const float focal_x, float focal_y, 39 | const float tan_fovx, float tan_fovy, 40 | int* radii, 41 | float2* points_xy_image, 42 | float* depths, 43 | float* cov3Ds, 44 | float* colors, 45 | float4* conic_opacity, 46 | const dim3 grid, 47 | uint32_t* tiles_touched, 48 | bool prefiltered); 49 | 50 | // Main rasterization method. 51 | void render( 52 | const dim3 grid, dim3 block, 53 | const uint2* ranges, 54 | const uint32_t* point_list, 55 | int W, int H, 56 | const float2* points_xy_image, 57 | const float* features, 58 | const float4* conic_opacity, 59 | float* final_T, 60 | uint32_t* n_contrib, 61 | const float* bg_color, 62 | float* out_color); 63 | } 64 | 65 | 66 | #endif -------------------------------------------------------------------------------- /origin_cuda/rasterizer.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 
5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_H_INCLUDED 13 | #define CUDA_RASTERIZER_H_INCLUDED 14 | 15 | #include 16 | #include 17 | 18 | namespace CudaRasterizer 19 | { 20 | class Rasterizer 21 | { 22 | public: 23 | 24 | static void markVisible( 25 | int P, 26 | float* means3D, 27 | float* viewmatrix, 28 | float* projmatrix, 29 | bool* present); 30 | 31 | static int forward( 32 | std::function geometryBuffer, 33 | std::function binningBuffer, 34 | std::function imageBuffer, 35 | const int P, int D, int M, 36 | const float* background, 37 | const int width, int height, 38 | const float* means3D, 39 | const float* shs, 40 | const float* colors_precomp, 41 | const float* opacities, 42 | const float* scales, 43 | const float scale_modifier, 44 | const float* rotations, 45 | const float* cov3D_precomp, 46 | const float* viewmatrix, 47 | const float* projmatrix, 48 | const float* cam_pos, 49 | const float tan_fovx, float tan_fovy, 50 | const bool prefiltered, 51 | float* out_color, 52 | int* radii = nullptr, 53 | bool debug = false); 54 | 55 | // static void backward( 56 | // const int P, int D, int M, int R, 57 | // const float* background, 58 | // const int width, int height, 59 | // const float* means3D, 60 | // const float* shs, 61 | // const float* colors_precomp, 62 | // const float* scales, 63 | // const float scale_modifier, 64 | // const float* rotations, 65 | // const float* cov3D_precomp, 66 | // const float* viewmatrix, 67 | // const float* projmatrix, 68 | // const float* campos, 69 | // const float tan_fovx, float tan_fovy, 70 | // const int* radii, 71 | // char* geom_buffer, 72 | // char* binning_buffer, 73 | // char* image_buffer, 74 | // const float* dL_dpix, 75 | // float* dL_dmean2D, 76 | // float* dL_dconic, 77 | // float* dL_dopacity, 78 | // float* dL_dcolor, 79 | 
// float* dL_dmean3D, 80 | // float* dL_dcov3D, 81 | // float* dL_dsh, 82 | // float* dL_dscale, 83 | // float* dL_drot, 84 | // bool debug); 85 | }; 86 | }; 87 | 88 | #endif -------------------------------------------------------------------------------- /origin_cuda/rasterizer_impl.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include "cuda_runtime.h" 13 | #include "device_launch_parameters.h" 14 | #include "rasterizer_impl.h" 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #define GLM_FORCE_CUDA 23 | #include 24 | 25 | #include 26 | #include 27 | namespace cg = cooperative_groups; 28 | 29 | #include "auxiliary.h" 30 | #include "forward.h" 31 | // #include "backward.h" 32 | 33 | // Helper function to find the next-highest bit of the MSB 34 | // on the CPU. 35 | uint32_t getHigherMsb(uint32_t n) 36 | { 37 | uint32_t msb = sizeof(n) * 4; 38 | uint32_t step = msb; 39 | while (step > 1) { 40 | step /= 2; 41 | if (n >> msb) 42 | msb += step; 43 | else 44 | msb -= step; 45 | } 46 | if (n >> msb) 47 | msb++; 48 | return msb; 49 | } 50 | 51 | // Wrapper method to call auxiliary coarse frustum containment test. 52 | // Mark all Gaussians that pass it. 
__global__ void checkFrustum(int P, const float* orig_points, const float* viewmatrix, const float* projmatrix, bool* present)
{
    // One thread per Gaussian; threads whose rank is not below P have no
    // Gaussian to test.
    const auto gaussian = cg::this_grid().thread_rank();
    if (gaussian < P) {
        // in_frustum takes a float3 it can write to; that value is not used here.
        float3 view_pos;
        present[gaussian] = in_frustum(gaussian, orig_points, viewmatrix, projmatrix, false, view_pos);
    }
}

// Emits one key/value pair for every tile that a visible Gaussian's
// bounding rectangle overlaps. Launched with one thread per Gaussian, so a
// single thread may write several pairs (1:N mapping).
__global__ void duplicateWithKeys(int P, const float2* points_xy, const float* depths, const uint32_t* offsets, uint64_t* gaussian_keys_unsorted, uint32_t* gaussian_values_unsorted, int* radii, dim3 grid)
{
    const auto gaussian = cg::this_grid().thread_rank();

    // Out-of-range threads and Gaussians with a non-positive radius
    // (treated as invisible) emit nothing.
    if (gaussian >= P || radii[gaussian] <= 0)
        return;

    // This Gaussian's pairs are written starting at the previous entry of
    // `offsets`; the first Gaussian starts at 0.
    uint32_t cursor = 0;
    if (gaussian != 0)
        cursor = offsets[gaussian - 1];

    uint2 tile_begin, tile_end;
    getRect(points_xy[gaussian], radii[gaussian], tile_begin, tile_end, grid);

    // Key layout: | tile ID (upper 32 bits) | depth bits (lower 32 bits) |.
    // The value is the Gaussian's index, so sorting the values by this key
    // orders them by tile first and by depth second.
    for (int ty = tile_begin.y; ty < tile_end.y; ty++) {
        for (int tx = tile_begin.x; tx < tile_end.x; tx++) {
            uint64_t tile_and_depth = ty * grid.x + tx;
            tile_and_depth <<= 32;
            // The float depth's bit pattern is reused as an unsigned integer.
            tile_and_depth |= *((uint32_t*)&depths[gaussian]);
            gaussian_keys_unsorted[cursor] = tile_and_depth;
            gaussian_values_unsorted[cursor] = gaussian;
            cursor++;
        }
    }
}

// Check keys to see if it is at the start/end of one tile's range in
// the full sorted list. If yes, write start/end of this tile.
// Run once per instanced (duplicated) Gaussian ID.
__global__ void identifyTileRanges(int L, uint64_t* point_list_keys, uint2* ranges)
{
    auto idx = cg::this_grid().thread_rank();
    if (idx >= L)
        return;

    // Read tile ID from key (the upper 32 bits). Update start/end of tile
    // range if at limit.
    uint64_t key = point_list_keys[idx];
    uint32_t currtile = key >> 32;
    if (idx == 0)
        ranges[currtile].x = 0;
    else {
        uint32_t prevtile = point_list_keys[idx - 1] >> 32;
        if (currtile != prevtile) {
            // Tile boundary: the previous tile's range ends here (exclusive)
            // and the current tile's range starts here.
            ranges[prevtile].y = idx;
            ranges[currtile].x = idx;
        }
    }
    // The last key closes the range of its tile.
    if (idx == L - 1)
        ranges[currtile].y = L;
}

// Mark Gaussians as visible/invisible, based on view frustum testing.
// Launches checkFrustum with 256 threads per block, enough blocks to cover P.
void CudaRasterizer::Rasterizer::markVisible(int P, float* means3D, float* viewmatrix, float* projmatrix, bool* present)
{
    checkFrustum<<<(P + 255) / 256, 256>>>(P, means3D, viewmatrix, projmatrix, present);
}

// Carves the per-Gaussian arrays of a GeometryState out of `chunk`, each
// aligned to 128 bytes, and advances `chunk` past them.
CudaRasterizer::GeometryState
CudaRasterizer::GeometryState::fromChunk(char*& chunk, size_t P)
{
    GeometryState geom;
    obtain(chunk, geom.depths, P, 128);
    obtain(chunk, geom.clamped, P * 3, 128);
    obtain(chunk, geom.internal_radii, P, 128);
    obtain(chunk, geom.means2D, P, 128);
    obtain(chunk, geom.cov3D, P * 6, 128);
    obtain(chunk, geom.conic_opacity, P, 128);
    obtain(chunk, geom.rgb, P * 3, 128);
    obtain(chunk, geom.tiles_touched, P, 128);
    // Size query only: with a null temp-storage pointer CUB reports the
    // scratch size it needs in scan_size and does no scanning.
    cub::DeviceScan::InclusiveSum(nullptr, geom.scan_size, geom.tiles_touched, geom.tiles_touched, P);
    obtain(chunk, geom.scanning_space, geom.scan_size, 128);
    obtain(chunk, geom.point_offsets, P, 128);
    return geom;
}

// Carves the per-pixel arrays of an ImageState (N entries each) out of
// `chunk` and advances `chunk` past them.
CudaRasterizer::ImageState CudaRasterizer::ImageState::fromChunk(char*& chunk,
                                                                 size_t N)
{
    ImageState img;
    obtain(chunk, img.accum_alpha, N, 128);
    obtain(chunk, img.n_contrib, N, 128);
    obtain(chunk, img.ranges, N, 128);
    return img;
}

// Carves the key/value lists used for sorting (P entries each) plus the
// radix-sort scratch space out of `chunk` and advances `chunk` past them.
CudaRasterizer::BinningState
CudaRasterizer::BinningState::fromChunk(char*& chunk, size_t P)
{
    BinningState binning;
    obtain(chunk, binning.point_list, P, 128);
    obtain(chunk, binning.point_list_unsorted, P, 128);
    obtain(chunk, binning.point_list_keys, P, 128);
    obtain(chunk, binning.point_list_keys_unsorted, P, 128);
    // Size query only (null temp-storage pointer): fills sorting_size.
    cub::DeviceRadixSort::SortPairs(
        nullptr,
        binning.sorting_size,
        binning.point_list_keys_unsorted,
        binning.point_list_keys,
        binning.point_list_unsorted,
        binning.point_list,
        P);
    obtain(chunk, binning.list_sorting_space, binning.sorting_size, 128);
    return binning;
}

// Forward rendering procedure for differentiable rasterization
// of Gaussians.
//
// geometryBuffer / binningBuffer / imageBuffer: callbacks that receive a
//   size in bytes and return the start of a buffer of at least that size.
//   The returned memory is read with cudaMemcpyDeviceToHost and written by
//   kernels, so it must be device memory.
// radii: optional output; when null, the internal array of the geometry
//   state is used instead.
// debug: forwarded to CHECK_CUDA.
// Returns the total number of Gaussian/tile instances (the last element of
//   the inclusive prefix sum over tiles_touched), or 0 when P is 0.
// Throws std::runtime_error when NUM_CHANNELS != 3 and no precomputed
//   colors are supplied.
//
// NOTE(review): CHECK_CUDA is defined outside this section (presumably in
// auxiliary.h). Some call sites below have no trailing semicolon, and the
// kernel launches are followed by `CHECK_CUDA(, debug)`; they rely on the
// macro supplying the statement terminator. Confirm the macro's shape
// before reformatting those lines.
int CudaRasterizer::Rasterizer::forward(
    std::function<char*(size_t)> geometryBuffer,
    std::function<char*(size_t)> binningBuffer,
    std::function<char*(size_t)> imageBuffer,
    const int P,
    int D,
    int M,
    const float* background,
    const int width,
    int height,
    const float* means3D,
    const float* shs,
    const float* colors_precomp,
    const float* opacities,
    const float* scales,
    const float scale_modifier,
    const float* rotations,
    const float* cov3D_precomp,
    const float* viewmatrix,
    const float* projmatrix,
    const float* cam_pos,
    const float tan_fovx,
    float tan_fovy,
    const bool prefiltered,
    float* out_color,
    int* radii,
    bool debug)
{
    const float focal_y = height / (2.0f * tan_fovy);
    const float focal_x = width / (2.0f * tan_fovx);

    // required<T>(n) measures how many bytes T::fromChunk(…, n) consumes.
    size_t chunk_size = required<GeometryState>(P);
    char* chunkptr = geometryBuffer(chunk_size);
    GeometryState geomState = GeometryState::fromChunk(chunkptr, P);

    if (radii == nullptr) {
        radii = geomState.internal_radii;
    }

    dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X,
                   (height + BLOCK_Y - 1) / BLOCK_Y,
                   1);
    dim3 block(BLOCK_X, BLOCK_Y, 1);

    // Dynamically resize image-based auxiliary buffers during training
    size_t img_chunk_size = required<ImageState>(width * height);
    char* img_chunkptr = imageBuffer(img_chunk_size);
    ImageState imgState = ImageState::fromChunk(img_chunkptr, width * height);

    if (NUM_CHANNELS != 3 && colors_precomp == nullptr) {
        throw std::runtime_error(
            "For non-RGB, provide precomputed Gaussian colors!");
    }

    // Run preprocessing per-Gaussian (transformation, bounding, conversion of SHs
    // to RGB)
    CHECK_CUDA(FORWARD::preprocess(
                   P,
                   D,
                   M,
                   means3D,
                   (glm::vec3*)scales,
                   scale_modifier,
                   (glm::vec4*)rotations,
                   opacities,
                   shs,
                   geomState.clamped,
                   cov3D_precomp,
                   colors_precomp,
                   viewmatrix,
                   projmatrix,
                   (glm::vec3*)cam_pos,
                   width,
                   height,
                   focal_x,
                   focal_y,
                   tan_fovx,
                   tan_fovy,
                   radii,
                   geomState.means2D,
                   geomState.depths,
                   geomState.cov3D,
                   geomState.rgb,
                   geomState.conic_opacity,
                   tile_grid,
                   geomState.tiles_touched,
                   prefiltered),
               debug)

    // Compute prefix sum over full list of touched tile counts by Gaussians
    // E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8]
    CHECK_CUDA(cub::DeviceScan::InclusiveSum(
                   geomState.scanning_space,
                   geomState.scan_size,
                   geomState.tiles_touched,
                   geomState.point_offsets,
                   P),
               debug)

    // Retrieve total number of Gaussian instances to launch and resize aux
    // buffers. The total is the last element of the inclusive prefix sum;
    // with P == 0 there is no such element, so the read is skipped.
    uint32_t total_instances = 0;
    if (P > 0) {
        CHECK_CUDA(cudaMemcpy(&total_instances, geomState.point_offsets + P - 1, sizeof(total_instances), cudaMemcpyDeviceToHost),
                   debug);
    }
    const int num_rendered = static_cast<int>(total_instances);

    size_t binning_chunk_size = required<BinningState>(num_rendered);
    char* binning_chunkptr = binningBuffer(binning_chunk_size);
    BinningState binningState =
        BinningState::fromChunk(binning_chunkptr, num_rendered);

    // For each instance to be rendered, produce adequate [ tile | depth ] key
    // and corresponding duplicated Gaussian indices to be sorted
    duplicateWithKeys<<<(P + 255) / 256, 256>>>(
        P,
        geomState.means2D,
        geomState.depths,
        geomState.point_offsets,
        binningState.point_list_keys_unsorted,
        binningState.point_list_unsorted,
        radii,
        tile_grid) CHECK_CUDA(, debug)

    int bit = getHigherMsb(tile_grid.x * tile_grid.y);

    // Sort complete list of (duplicated) Gaussian indices by keys.
    // Only key bits [0, 32 + bit) take part in the sort: the 32 depth bits
    // plus as many tile-ID bits as the tile count requires.
    CHECK_CUDA(cub::DeviceRadixSort::SortPairs(
                   binningState.list_sorting_space,
                   binningState.sorting_size,
                   binningState.point_list_keys_unsorted,
                   binningState.point_list_keys,
                   binningState.point_list_unsorted,
                   binningState.point_list,
                   num_rendered,
                   0,
                   32 + bit),
               debug)

    // Tiles that receive no Gaussian keep an all-zero range.
    CHECK_CUDA(
        cudaMemset(imgState.ranges, 0, tile_grid.x * tile_grid.y * sizeof(uint2)),
        debug);

    // Identify start and end of per-tile workloads in sorted list
    if (num_rendered > 0)
        identifyTileRanges<<<(num_rendered + 255) / 256, 256>>>(
            num_rendered,
            binningState.point_list_keys,
            imgState.ranges);
    CHECK_CUDA(, debug)

    // Let each tile blend its range of Gaussians independently in parallel
    const float* feature_ptr =
        colors_precomp != nullptr ? colors_precomp : geomState.rgb;
    CHECK_CUDA(FORWARD::render(tile_grid,
                               block,
                               imgState.ranges,
                               binningState.point_list,
                               width,
                               height,
                               geomState.means2D,
                               feature_ptr,
                               geomState.conic_opacity,
                               imgState.accum_alpha,
                               imgState.n_contrib,
                               background,
                               out_color),
               debug)

    return num_rendered;
}

// Produce necessary gradients for optimization, corresponding
// to forward render pass
// void CudaRasterizer::Rasterizer::backward(
//     const int P, int D, int M, int R, const float *background, const int width,
//     int height, const float *means3D, const float *shs,
//     const float *colors_precomp, const float *scales,
//     const float scale_modifier, const float *rotations,
//     const float *cov3D_precomp, const float *viewmatrix,
//     const float *projmatrix, const float *campos, const float tan_fovx,
//     float tan_fovy, const int *radii, char *geom_buffer, char *binning_buffer,
//     char *img_buffer, const float *dL_dpix, float *dL_dmean2D, float *dL_dconic,
//     float *dL_dopacity, float *dL_dcolor, float *dL_dmean3D, float *dL_dcov3D,
//     float *dL_dsh, float *dL_dscale, float *dL_drot, bool debug) {
//     GeometryState geomState = GeometryState::fromChunk(geom_buffer, P);
//     BinningState binningState = BinningState::fromChunk(binning_buffer, R);
//     ImageState imgState = ImageState::fromChunk(img_buffer, width * height);

//     if (radii == nullptr) {
//         radii = geomState.internal_radii;
//     }

//     const float focal_y = height / (2.0f * tan_fovy);
//     const float focal_x = width / (2.0f * tan_fovx);

//     const dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X,
//                          (height + BLOCK_Y - 1) / BLOCK_Y, 1);
//     const dim3 block(BLOCK_X, BLOCK_Y, 1);

//     // Compute loss gradients w.r.t.
2D mean position, conic matrix, 388 | // // opacity and RGB of Gaussians from per-pixel loss gradients. 389 | // // If we were given precomputed colors and not SHs, use them. 390 | // const float *color_ptr = 391 | // (colors_precomp != nullptr) ? colors_precomp : geomState.rgb; 392 | // CHECK_CUDA(BACKWARD::render( 393 | // tile_grid, block, imgState.ranges, binningState.point_list, 394 | // width, height, background, geomState.means2D, 395 | // geomState.conic_opacity, color_ptr, imgState.accum_alpha, 396 | // imgState.n_contrib, dL_dpix, (float3 *)dL_dmean2D, 397 | // (float4 *)dL_dconic, dL_dopacity, dL_dcolor), 398 | // debug) 399 | 400 | // // Take care of the rest of preprocessing. Was the precomputed covariance 401 | // // given to us or a scales/rot pair? If precomputed, pass that. If not, 402 | // // use the one we computed ourselves. 403 | // const float *cov3D_ptr = 404 | // (cov3D_precomp != nullptr) ? cov3D_precomp : geomState.cov3D; 405 | // CHECK_CUDA(BACKWARD::preprocess( 406 | // P, D, M, (float3 *)means3D, radii, shs, geomState.clamped, 407 | // (glm::vec3 *)scales, (glm::vec4 *)rotations, scale_modifier, 408 | // cov3D_ptr, viewmatrix, projmatrix, focal_x, focal_y, tan_fovx, 409 | // tan_fovy, (glm::vec3 *)campos, (float3 *)dL_dmean2D, dL_dconic, 410 | // (glm::vec3 *)dL_dmean3D, dL_dcolor, dL_dcov3D, dL_dsh, 411 | // (glm::vec3 *)dL_dscale, (glm::vec4 *)dL_drot), 412 | // debug) 413 | 414 | // } -------------------------------------------------------------------------------- /origin_cuda/rasterizer_impl.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #pragma once 13 | 14 | #include 15 | #include 16 | #include "rasterizer.h" 17 | #include 18 | 19 | namespace CudaRasterizer 20 | { 21 | template 22 | static void obtain(char*& chunk, T*& ptr, std::size_t count, std::size_t alignment) 23 | { 24 | std::size_t offset = (reinterpret_cast(chunk) + alignment - 1) & ~(alignment - 1); 25 | ptr = reinterpret_cast(offset); 26 | chunk = reinterpret_cast(ptr + count); 27 | } 28 | 29 | struct GeometryState 30 | { 31 | size_t scan_size; 32 | float* depths; 33 | char* scanning_space; 34 | bool* clamped; 35 | int* internal_radii; 36 | float2* means2D; 37 | float* cov3D; 38 | float4* conic_opacity; 39 | float* rgb; 40 | uint32_t* point_offsets; 41 | uint32_t* tiles_touched; 42 | 43 | static GeometryState fromChunk(char*& chunk, size_t P); 44 | }; 45 | 46 | struct ImageState 47 | { 48 | uint2* ranges; 49 | uint32_t* n_contrib; 50 | float* accum_alpha; 51 | 52 | static ImageState fromChunk(char*& chunk, size_t N); 53 | }; 54 | 55 | struct BinningState 56 | { 57 | size_t sorting_size; 58 | uint64_t* point_list_keys_unsorted; 59 | uint64_t* point_list_keys; 60 | uint32_t* point_list_unsorted; 61 | uint32_t* point_list; 62 | char* list_sorting_space; 63 | 64 | static BinningState fromChunk(char*& chunk, size_t P); 65 | }; 66 | 67 | template 68 | size_t required(size_t P) 69 | { 70 | char* size = nullptr; 71 | T::fromChunk(size, P); 72 | return ((size_t)size) + 128; 73 | } 74 | }; -------------------------------------------------------------------------------- /point_cloud.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MouseChannel/MCGS/c9a89849d473c44133f052451ebaff322fde5e89/point_cloud.ply -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | ## MC Gaussian 2 | 3 | 
A rewrite of [3D Gaussian splatting](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/) using Vulkan compute shaders. 4 | 5 | ### roadmap 6 | 7 | - ✅ Minimal example to verify feasibility (2024-4-10 🎊 works!) 8 | - ☑️ **SOTA** GPU sort, based on [Nvidia 2022 research](https://research.nvidia.com/publication/2022-06_onesweep-faster-least-significant-digit-radix-sort-gpus) 9 | - (2024-4-21) multi-pass sort makes it **4x** faster than single-pass, but it is still slower than SOTA 10 | 11 | - ✅ Camera control 12 | - ⬜ OpenXR support (currently in development) 13 | 14 | ### showcase 15 | 16 | | | original CUDA | Compute Shader | 17 | |---------------------|-----------------|-----------------| 18 | | quality | ![origincuda](https://github.com/MouseChannel/MCGS/blob/main/showcase/origincuda.png) | ![output](https://github.com/MouseChannel/MCGS/blob/main/showcase/output.png) | 19 | |PSNR |------------------------ | $+\infty$ (identical in every single pixel 😏) | 20 | | FPS | 227 fps / 4.4ms | 65 fps / 15.2ms ( :arrow_heading_down:245%) | 21 | 22 | 23 | > GPU sort costs 11.5ms of the 15.2ms frame (75.6%), so it must be optimized!! 
24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /showcase/origincuda.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MouseChannel/MCGS/c9a89849d473c44133f052451ebaff322fde5e89/showcase/origincuda.png -------------------------------------------------------------------------------- /showcase/output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MouseChannel/MCGS/c9a89849d473c44133f052451ebaff322fde5e89/showcase/output.png -------------------------------------------------------------------------------- /src/GSContext.cpp: -------------------------------------------------------------------------------- 1 | #include "GSContext.hpp" 2 | #include "Helper/Camera.hpp" 3 | #include "Helper/CommandManager.hpp" 4 | #include "Rendering/GraphicContext.hpp" 5 | #include "Wrapper/CommandBuffer.hpp" 6 | #include "Wrapper/ComputePass/ComputePass.hpp" 7 | #include "shaders/DataStruct.h" 8 | 9 | #include 10 | // pos + normal+ de_012 + de_rest+opacity + scale + rot 11 | const int vert_attr = 3 + 3 + 3 + 45 + 1 + 3 + 4; 12 | constexpr int local_size = 256; 13 | /* Reads a binary PLY file: scans the text header for the "element vertex " count, stops at the end_header line, then reads count * vert_attr floats and returns them as one row of vert_attr floats per point. Returns an empty result when the file cannot be opened or no payload can be allocated; points that were not read completely are dropped. */ std::vector> load_ply(std::string ply_path) 14 | { 15 | 16 | FILE* fp; 17 | 18 | std::vector> Pts; 19 | int numPts = 0; 20 | fp = fopen(ply_path.c_str(), "rb"); 21 | 22 | char strLine[102400]; 23 | 24 | char end_flag[] = "end_header "; 25 | char num_flag[] = "element vertex "; 26 | char* p; 27 | char num[10000] = { 0 }; 28 | 29 | if (fp == NULL) { 30 | printf("Error:Open input.c file fail!\n"); /* nothing to parse: return the empty result instead of handing a null FILE* to fgets/fread */ return Pts; 31 | } 32 | 33 | while (fgets(strLine, 102400, fp) != NULL) // read one header line per iteration; stop at end of file or read error 34 | { 35 | 36 | 37 | if (strlen(strLine) == (strlen(end_flag))) { 38 | break; 39 | } 40 | 41 | if ((p = strstr(strLine, num_flag)) != NULL) { 42 | int start = strlen(num_flag); 43 | int sub_len = strlen(strLine) - strlen(num_flag); 44 | 45 | /* bounded copy: num is zero-initialised, so it stays terminated for atoi */ for (int i = 0; i < sub_len && i < (int)sizeof(num) - 1; i++) { 
46 | num[i] = strLine[start + i]; 47 | } 48 | numPts = atoi(num); 49 | } 50 | } 51 | 52 | /* element count kept in size_t: a float cannot represent counts above 2^24 exactly */ if (numPts < 0) { numPts = 0; } size_t cnt = (size_t)numPts * vert_attr; 53 | float* pts = (float*)malloc(cnt * sizeof(float)); 54 | if (pts == NULL) { fclose(fp); return Pts; } 55 | 56 | /* keep only the points whose vert_attr floats were all read */ numPts = (int)(fread(pts, sizeof(float), cnt, fp) / vert_attr); 57 | 58 | fclose(fp); 59 | 60 | for (int i = 0; i < numPts; i++) { 61 | std::vector temp; 62 | for (int j = 0; j < vert_attr; j++) { 63 | temp.push_back(pts[vert_attr * i + j]); 64 | } 65 | Pts.push_back(temp); 66 | } 67 | 68 | // std::string out_file = "bin_pts_check.txt"; 69 | // std::ofstream out(out_file.c_str()); 70 | // for (int pt_idx = 0; pt_idx < numPts; ++pt_idx) { 71 | // for (int j = 0; j < vert_attr; ++j) { 72 | // out << pts[pt_idx * vert_attr + j] << " "; 73 | // } 74 | // out << std::endl; 75 | // } 76 | // out.close(); 77 | free(pts); return Pts; 78 | } 79 | 80 | /* Returns columns [from, end) of every row of `origin`, concatenated row by row into one flat vector of origin.size() * (end - from) values. */ std::vector get_sub(std::vector>& origin, 81 | int from, 82 | int end) 83 | { 84 | std::vector> res(origin.size(), std::vector {}); 85 | std::transform( 86 | origin.begin(), 87 | origin.end(), 88 | res.begin(), 89 | [&](std::vector& data) { 90 | return std::vector { data.data() + from, data.data() + end }; 91 | }); 92 | auto len = (end - from); 93 | std::vector d(origin.size() * len); 94 | for (int i = 0; i < res.size(); i++) { 95 | for (int j = 0; j < len; j++) { 96 | d[i * len + j] = res[i][j]; 97 | } 98 | } 99 | return d; 100 | } 101 | 102 | /* Columns 0..2 of every point. */ inline std::vector get_xyz(std::vector>& origin) 103 | { 104 | return get_sub(origin, 0, 3); 105 | } 106 | 107 | /* Columns 3..5 of every point. */ inline std::vector get_normal(std::vector>& origin) 108 | { 109 | return get_sub(origin, 3, 6); 110 | } 111 | 112 | /* Columns 6..8 of every point. */ inline std::vector get_dc_012(std::vector>& origin) 113 | { 114 | return get_sub(origin, 6, 9); 115 | } 116 | 117 | /* Columns 9..53 (45 values) of every point, returned unchanged. NOTE(review): the early `return data;` makes the re-interleaving loop after it unreachable; it is kept as found. */ inline std::vector get_dc_rest(std::vector>& origin) 118 | { 119 | auto data = get_sub(origin, 9, 54); 120 | return data; 121 | std::vector fixed(data.size()); 122 | for (int i = 0; i < data.size(); i += 45) { 123 | for (int j = 0; j < 45; j++) { 124 | auto a = j / 15; 125 | auto b = 
j % 15; 126 | fixed[i + b * 3 + a] = data[i + j]; 127 | } 128 | } 129 | 130 | return fixed; 131 | } 132 | 133 | /* Column 54 of every point, mapped through the logistic function 1 / (1 + exp(-x)). */ inline std::vector get_opacity(std::vector>& origin) 134 | { 135 | auto data = get_sub(origin, 54, 55); 136 | std::for_each(data.begin(), data.end(), [](auto& item) { 137 | item = 1 / (1 + std::exp(-item)); 138 | }); 139 | return data; 140 | } 141 | 142 | /* Columns 55..57 of every point, each mapped through exp(x). */ inline std::vector get_scale(std::vector>& origin) 143 | { 144 | auto data = get_sub(origin, 55, 58); 145 | std::for_each(data.begin(), data.end(), [](auto& item) { 146 | item = std::exp(item); 147 | }); 148 | return data; 149 | } 150 | 151 | /* Columns 58..61 of every point, with each group of four scaled to unit Euclidean length. An all-zero group divides by zero. */ inline std::vector get_rotation(std::vector>& origin) 152 | { 153 | auto data = get_sub(origin, 58, 62); 154 | // std::for_each(data.begin(), data.end(), [](auto& item) { 155 | // item = std::norma(item); 156 | // }); 157 | 158 | for (int i = 0; i < data.size() / 4; i++) { 159 | float cur_sum = 0.; 160 | for (int j = 0; j < 4; j++) { 161 | cur_sum += std::pow(data[i * 4 + j], 2); 162 | } 163 | auto div = std::sqrt(cur_sum); 164 | for (int j = 0; j < 4; j++) { 165 | data[i * 4 + j] /= div; 166 | } 167 | } 168 | return data; 169 | } 170 | 171 | namespace MCRT { 172 | 173 | /* Debug helper: waits for the graphics queue, copies `size` bytes of device_buffer into a newly created buffer and copies its mapped data into a local vector. Nothing is returned; the trailing `int r` only gives a debugger somewhere to stop. Assumes Get_mapped_data() exposes at least `size` bytes - TODO confirm. */ void LookDeviceBuffer1(vk::Buffer device_buffer, int size) 174 | { 175 | Context::Get_Singleton()->get_device()->Get_Graphic_queue().waitIdle(); 176 | auto host_buffer = Buffer::create_buffer(nullptr, size, vk::BufferUsageFlagBits::eTransferDst); 177 | CommandManager::ExecuteCmd(Context::Get_Singleton()->get_device()->Get_Graphic_queue(), 178 | [&](vk::CommandBuffer& cmd) { 179 | cmd.copyBuffer( 180 | device_buffer, 181 | host_buffer->get_handle(), 182 | vk::BufferCopy() 183 | .setDstOffset(0) 184 | .setSrcOffset(0) 185 | .setSize(size)); 186 | }); 187 | auto cpu_raw_data = host_buffer->Get_mapped_data(); 188 | std::vector data(size / 4); 189 | /* copy the payload; sizeof(data) would only be the size of the vector object itself */ std::memcpy(data.data(), cpu_raw_data.data(), data.size() * sizeof(data[0])); 190 | int r = 0; 191 | } 192 | /* Loads the point cloud at `path`, converts the per-attribute arrays into one record per point (position, rotation reordered from w,x,y,z to x,y,z,w, scale, opacity, 12 vec4 of SH coefficients) and uploads it; also creates the point-count and visible-count buffers and the sorter object. */ GSContext::GSContext(std::string path) 193 | { 194 | auto gs_data 
= load_ply(path); 195 | 196 | auto xyz_d = get_xyz(gs_data); 197 | 198 | 199 | 200 | 201 | 202 | auto scale_d = get_scale(gs_data); 203 | auto dc_012 = get_dc_012(gs_data); 204 | auto dc_rest = get_dc_rest(gs_data); 205 | 206 | auto opacity_d = get_opacity(gs_data); 207 | auto rotations_d = get_rotation(gs_data); 208 | 209 | /* 48 floats per point: three channel blocks of 16, each holding the DC term followed by 15 rest coefficients */ std::vector feature_d(dc_012.size() + dc_rest.size()); 210 | for (int j = 0; j < (int)dc_012.size() / 3; j++) { 211 | 212 | feature_d[j * 48] = dc_012[j * 3 + 0]; 213 | feature_d[j * 48 + 16] = dc_012[j * 3 + 1]; 214 | feature_d[j * 48 + 32] = dc_012[j * 3 + 2]; 215 | for (int i = 0; i < 15; i++) { 216 | feature_d[j * 48 + 1 + i] = dc_rest[j * 45 + i]; 217 | feature_d[j * 48 + 17 + i] = dc_rest[j * 45 + 15 + i]; 218 | feature_d[j * 48 + 33 + i] = dc_rest[j * 45 + 30 + i]; 219 | } 220 | 221 | } 222 | std::vector raw_data_cpu(opacity_d.size()); 223 | for (int i = 0; i < opacity_d.size(); i++) { 224 | auto& data = raw_data_cpu[i]; 225 | data.opacity = opacity_d[i]; 226 | data.pos = glm::vec3(xyz_d[i * 3], 227 | xyz_d[i * 3 + 1], 228 | xyz_d[i * 3 + 2]); 229 | data.rot = glm::vec4( 230 | rotations_d[i * 4 + 1], 231 | rotations_d[i * 4 + 2], 232 | rotations_d[i * 4 + 3], 233 | rotations_d[i * 4]); 234 | data.scale = glm::vec3(scale_d[i * 3], 235 | scale_d[i * 3 + 1], 236 | scale_d[i * 3 + 2]); 237 | 238 | for (int j = 0; j < 12; j++) { 239 | glm::vec4 temp; 240 | for (int k = 0; k < 4; k++) { 241 | temp[k] = feature_d[i * 48 + 4 * j + k]; 242 | } 243 | data.sh[j] = temp; 244 | } 245 | } 246 | raw_data = Buffer::CreateDeviceBuffer(raw_data_cpu.data(), sizeof(raw_data_cpu[0]) * raw_data_cpu.size(), vk::BufferUsageFlagBits::eStorageBuffer); 247 | all_point_count = opacity_d.size(); 248 | 249 | point_count_buffer = Buffer::CreateDeviceBuffer( 250 | &all_point_count, 251 | sizeof(all_point_count), 252 | vk::BufferUsageFlagBits::eStorageBuffer); 253 | visiable_count_buffer = Buffer::CreateDeviceBuffer( 
254 | nullptr, 255 | sizeof(uint32_t), 256 | vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferDst | vk::BufferUsageFlagBits::eTransferSrc); 257 | 258 | m_gpu_sort = std::make_shared(); 259 | } 260 | /* One-time GPU setup: runs the process.comp pass once over all points, creates the key/value/inverse-index/indirect/index/instance/camera buffers, binds them in one descriptor set, builds the rank, inverseIndex and projection compute passes and initialises the sorter. */ void GSContext::prepare() 261 | { 262 | ComputeContext::prepare(); 263 | // auto rr = new ComputePass({ sets }); 264 | auto pre_set = std::make_shared(); 265 | std::shared_ptr pre_pool; 266 | pre_set->AddBufferDescriptorTarget( 267 | point_count_buffer, 268 | 0, 269 | vk::ShaderStageFlagBits::eCompute, 270 | vk::DescriptorType ::eStorageBuffer); 271 | pre_set->AddBufferDescriptorTarget( 272 | raw_data, 273 | 1, 274 | vk::ShaderStageFlagBits::eCompute, 275 | vk::DescriptorType::eStorageBuffer); 276 | pre_pool.reset(new DescriptorPool({ pre_set })); 277 | pre_set->build(pre_pool, 1); 278 | pre_process_pass.reset(new ComputePass({ pre_set }, 4, "include/shaders/process.comp.spv")); 279 | CommandManager::ExecuteCmd(Context::Get_Singleton()->get_device()->Get_Graphic_queue(), 280 | [&](vk::CommandBuffer& cmd) { 281 | pre_process_pass->Dispach( 282 | cmd, 283 | /* round up like tick() does; integer division before ceil dropped the last partial workgroup */ std::ceil(float(all_point_count) / local_size), 284 | 1, 285 | 1); 286 | }); 287 | 288 | key_buffer = Buffer::CreateDeviceBuffer(nullptr, all_point_count * sizeof(uint32_t), vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eShaderDeviceAddress | vk::BufferUsageFlagBits::eTransferSrc); 289 | 290 | value_buffer = Buffer::CreateDeviceBuffer(nullptr, all_point_count * sizeof(uint32_t), vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eShaderDeviceAddress | vk::BufferUsageFlagBits::eTransferSrc); 291 | 292 | inverse_index_buffer = Buffer::CreateDeviceBuffer(nullptr, all_point_count * sizeof(uint32_t), vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferDst); 293 | vk::DrawIndexedIndirectCommand indirct_cmd; 294 | 295 | indirct_cmd_buffer = Buffer::CreateDeviceBuffer(&indirct_cmd, sizeof(vk::DrawIndexedIndirectCommand), 
vk::BufferUsageFlagBits::eIndirectBuffer | vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc); 296 | /* six indices per point: two triangles (0,1,2) and (2,1,3) over four vertices */ std::vector splat_index; 297 | splat_index.reserve(all_point_count * 6); 298 | for (int i = 0; i < all_point_count; ++i) { 299 | splat_index.push_back(4 * i + 0); 300 | splat_index.push_back(4 * i + 1); 301 | splat_index.push_back(4 * i + 2); 302 | splat_index.push_back(4 * i + 2); 303 | splat_index.push_back(4 * i + 1); 304 | splat_index.push_back(4 * i + 3); 305 | } 306 | 307 | index_buffer = Buffer::CreateDeviceBuffer( 308 | splat_index.data(), 309 | splat_index.size() * sizeof(uint32_t), 310 | vk::BufferUsageFlagBits::eIndexBuffer); 311 | 312 | instance_buffer = Buffer::CreateDeviceBuffer( 313 | nullptr, 314 | all_point_count * sizeof(InstancePoint), 315 | vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eVertexBuffer); 316 | 317 | /* NOTE(review): screen size is hard-coded to 1600x900 here and in tick() */ CameraInfo camera { 318 | .projection = Context::Get_Singleton()->get_camera()->Get_p_matrix(), 319 | .view = Context::Get_Singleton()->get_camera()->Get_v_matrix(), 320 | .camera_position = Context::Get_Singleton()->get_camera()->get_pos(), 321 | .pad = 0, 322 | .screen_size = glm::uvec2(1600, 900) 323 | }; 324 | camera_buffer = Buffer::CreateDeviceBuffer(&camera, sizeof(CameraInfo), vk::BufferUsageFlagBits::eUniformBuffer); 325 | { 326 | set = std::make_shared(); 327 | 328 | set->AddBufferDescriptorTarget(indirct_cmd_buffer, 329 | e_indir_cmd, 330 | vk::ShaderStageFlagBits::eCompute, 331 | vk::DescriptorType ::eStorageBuffer); 332 | set->AddBufferDescriptorTarget(point_count_buffer, 333 | e_point_count, 334 | vk::ShaderStageFlagBits::eCompute, 335 | vk::DescriptorType ::eStorageBuffer); 336 | set->AddBufferDescriptorTarget(visiable_count_buffer, 337 | e_visiable_count, 338 | vk::ShaderStageFlagBits::eCompute, 339 | vk::DescriptorType ::eStorageBuffer); 340 | set->AddBufferDescriptorTarget(raw_data, 341 | e_gaussian_raw_point, 342 | vk::ShaderStageFlagBits::eCompute, 343 | 
vk::DescriptorType::eStorageBuffer); 344 | set->AddBufferDescriptorTarget(key_buffer, 345 | e_instance_key, 346 | vk::ShaderStageFlagBits::eCompute, 347 | vk::DescriptorType::eStorageBuffer); 348 | 349 | set->AddBufferDescriptorTarget(value_buffer, 350 | e_instance_value, 351 | vk::ShaderStageFlagBits::eCompute, 352 | vk::DescriptorType::eStorageBuffer); 353 | set->AddBufferDescriptorTarget(instance_buffer, 354 | e_instance_point, 355 | vk::ShaderStageFlagBits::eCompute | vk::ShaderStageFlagBits::eVertex, 356 | vk::DescriptorType::eStorageBuffer); 357 | set->AddBufferDescriptorTarget(inverse_index_buffer, 358 | e_inverse_index, 359 | vk::ShaderStageFlagBits::eCompute, 360 | vk::DescriptorType::eStorageBuffer); 361 | set->AddBufferDescriptorTarget(camera_buffer, 362 | e_camera, 363 | vk::ShaderStageFlagBits::eCompute, 364 | vk::DescriptorType::eUniformBuffer); 365 | setpool.reset(new DescriptorPool({ set })); 366 | set->build(setpool, 1); 367 | } 368 | { 369 | rank_pass.reset(new ComputePass( 370 | { set }, 371 | sizeof(glm::mat4), 372 | "include/shaders/rank.comp.spv")); 373 | } 374 | { 375 | inverse_pass.reset(new ComputePass( 376 | { set }, 377 | sizeof(glm::mat4), 378 | "include/shaders/inverseIndex.comp.spv")); 379 | ; 380 | } 381 | { 382 | projection_pass.reset(new ComputePass( 383 | { set }, 384 | sizeof(glm::mat4), 385 | "include/shaders/projection.comp.spv")); 386 | ; 387 | } 388 | 389 | m_gpu_sort->Init(all_point_count, visiable_count_buffer); 390 | } 391 | /* Records one frame of compute work into `command`: update the camera buffer, zero the visible count, then rank -> sort -> inverse index -> projection, with a pipeline barrier between stages and a final barrier towards the indirect draw and vertex stages. */ void GSContext::tick(std::shared_ptr command) 392 | { 393 | auto cmd = command->get_handle(); 394 | { 395 | cmd.updateBuffer(camera_buffer->get_handle(), 396 | 0, 397 | CameraInfo { 398 | .projection = Context::Get_Singleton()->get_camera()->Get_p_matrix(), 399 | .view = Context::Get_Singleton()->get_camera()->Get_v_matrix(), 400 | .camera_position = Context::Get_Singleton()->get_camera()->get_pos(), 401 | 402 | .pad = 0, 403 | .screen_size = glm::uvec2(1600, 900) }); 404 | } 405 | { // rank 406 | 407 | 
cmd.fillBuffer(visiable_count_buffer->get_handle(), 0, sizeof(uint32_t), 0); 408 | 409 | cmd.pipelineBarrier( 410 | vk::PipelineStageFlagBits::eTransfer, 411 | vk::PipelineStageFlagBits::eComputeShader, 412 | vk::DependencyFlagBits::eByRegion, 413 | vk::MemoryBarrier() 414 | .setSrcAccessMask(vk::AccessFlagBits::eTransferWrite) 415 | /* the camera written by updateBuffer is read as a uniform buffer, and the zeroed counter is presumably also written by the shader - TODO confirm in rank.comp */ .setDstAccessMask(vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eUniformRead), 416 | {}, 417 | {}); 418 | 419 | cmd.pushConstants( 420 | // rank_pipeline->get_layout(), 421 | rank_pass->get_pipeline()->get_layout(), 422 | vk::ShaderStageFlagBits::eCompute, 423 | 0, 424 | glm::mat4(1)); 425 | rank_pass->Dispach(cmd, ceil(float(all_point_count) / local_size), 1, 1); 426 | } 427 | /* sort() starts with a copyBuffer out of visiable_count_buffer, so the transfer stage must wait for the rank writes too */ cmd.pipelineBarrier(vk::PipelineStageFlagBits::eComputeShader, 428 | vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer, 429 | vk::DependencyFlagBits::eByRegion, 430 | vk::MemoryBarrier() 431 | .setSrcAccessMask(vk::AccessFlagBits::eShaderWrite) 432 | .setDstAccessMask(vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eTransferRead), 433 | {}, 434 | {}); 435 | { 436 | m_gpu_sort->sort(cmd, key_buffer, value_buffer); 437 | } 438 | cmd.pipelineBarrier(vk::PipelineStageFlagBits::eComputeShader, 439 | vk::PipelineStageFlagBits::eComputeShader, 440 | vk::DependencyFlagBits::eByRegion, 441 | vk::MemoryBarrier() 442 | .setSrcAccessMask(vk::AccessFlagBits::eShaderWrite) 443 | .setDstAccessMask(vk::AccessFlagBits::eShaderRead), 444 | {}, 445 | {}); 446 | { 447 | cmd.fillBuffer(inverse_index_buffer->get_handle(), 0, all_point_count * sizeof(uint32_t), -1); 448 | 449 | cmd.pipelineBarrier( 450 | vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer, 451 | vk::PipelineStageFlagBits::eComputeShader, 452 | vk::DependencyFlagBits::eByRegion, 453 | vk::MemoryBarrier() 454 | .setSrcAccessMask(vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferWrite) 455 | .setDstAccessMask(vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eShaderRead), 456 | {}, 457 | {}); 458 | 459 | inverse_pass->Dispach(cmd, 
ceil(float(all_point_count) / local_size), 1, 1); 460 | 461 | } 462 | cmd.pipelineBarrier(vk::PipelineStageFlagBits::eComputeShader, 463 | vk::PipelineStageFlagBits::eComputeShader, 464 | vk::DependencyFlagBits::eByRegion, 465 | vk::MemoryBarrier() 466 | .setSrcAccessMask(vk::AccessFlagBits::eShaderWrite) 467 | .setDstAccessMask(vk::AccessFlagBits::eShaderRead), 468 | {}, 469 | {}); 470 | { 471 | cmd.pushConstants( 472 | projection_pass->get_pipeline()->get_layout(), 473 | vk::ShaderStageFlagBits::eCompute, 474 | 0, 475 | glm::mat4(1)); 476 | projection_pass->Dispach(cmd, ceil(float(all_point_count) / local_size), 1, 1); 477 | 478 | } 479 | 480 | /* instance_buffer is also bound as a vertex-stage storage buffer in prepare(), so the vertex shader stage is included alongside vertex input and indirect draw */ cmd.pipelineBarrier(vk::PipelineStageFlagBits::eComputeShader, 481 | vk::PipelineStageFlagBits::eVertexInput | vk::PipelineStageFlagBits::eVertexShader | vk::PipelineStageFlagBits::eDrawIndirect, 482 | vk::DependencyFlagBits::eByRegion, 483 | vk::MemoryBarrier() 484 | .setSrcAccessMask(vk::AccessFlagBits::eShaderWrite) 485 | .setDstAccessMask(vk::AccessFlagBits::eIndirectCommandRead | vk::AccessFlagBits::eVertexAttributeRead | vk::AccessFlagBits::eShaderRead), 486 | {}, 487 | {}); 488 | } 489 | /* Forwards to ComputeContext::BeginFrame(). */ std::shared_ptr GSContext::BeginFrame() 490 | { 491 | 492 | return ComputeContext::BeginFrame(); 493 | } 494 | } -------------------------------------------------------------------------------- /src/sort.cpp: -------------------------------------------------------------------------------- 1 | #include "sort.hpp" 2 | #include "Helper/CommandManager.hpp" 3 | #include "Wrapper/Buffer.hpp" 4 | #include "Wrapper/ComputePass/ComputePass.hpp" 5 | constexpr uint32_t RADIX = 256; 6 | constexpr int WORKGROUP_SIZE = 512; 7 | constexpr int PARTITION_DIVISION = 8; 8 | constexpr int PARTITION_SIZE = PARTITION_DIVISION * WORKGROUP_SIZE; 9 | static constexpr uint32_t MAX_SPLAT_COUNT = 1 << 23; 10 | namespace MCRT { 11 | /* Push-constant block handed to all three sort shaders: the pass index plus the device addresses of the element count, the two histograms and the in/out key and value arrays. */ struct PushConstants { 12 | uint32_t pass; 13 | VkDeviceAddress elementCountReference; 14 | VkDeviceAddress globalHistogramReference; 15 | VkDeviceAddress partitionHistogramReference; 16 | VkDeviceAddress 
keysInReference; 17 | VkDeviceAddress keysOutReference; 18 | VkDeviceAddress valuesInReference; 19 | VkDeviceAddress valuesOutReference; 20 | }; 21 | /* Integer ceil(a / b); b must be non-zero. */ uint32_t RoundUp(uint32_t a, uint32_t b) 22 | { 23 | return (a + b - 1) / b; 24 | } 25 | /* Bytes for one counter word, four RADIX-sized global histograms and one RADIX-sized histogram per partition. */ VkDeviceSize HistogramSize(uint32_t elementCount) 26 | { 27 | return (1 + 4 * RADIX + RoundUp(elementCount, PARTITION_SIZE) * RADIX) * 28 | sizeof(uint32_t); 29 | } 30 | 31 | /* Bytes for one array of elementCount 32-bit entries. */ VkDeviceSize InoutSize(uint32_t elementCount) 32 | { 33 | return elementCount * sizeof(uint32_t); 34 | } 35 | // void LookDeviceBuffer(vk::Buffer device_buffer, int size) 36 | // { 37 | // Context::Get_Singleton()->get_device()->Get_Graphic_queue().waitIdle(); 38 | // auto host_buffer = Buffer::create_buffer(nullptr, size, vk::BufferUsageFlagBits::eTransferDst); 39 | // CommandManager::ExecuteCmd(Context::Get_Singleton()->get_device()->Get_Graphic_queue(), 40 | // [&](vk::CommandBuffer& cmd) { 41 | // cmd.copyBuffer( 42 | // device_buffer, 43 | // host_buffer->get_handle(), 44 | // vk::BufferCopy() 45 | // .setDstOffset(0) 46 | // .setSrcOffset(0) 47 | // .setSize(size)); 48 | // }); 49 | // auto cpu_raw_data = host_buffer->Get_mapped_data(); 50 | // std::vector data(size / 4); 51 | // std::memcpy(data.data(), cpu_raw_data.data(), sizeof(data)); 52 | // int r = 0; 53 | // } 54 | /* Sizes and allocates the sorter's storage buffer for up to all_point_count elements (layout: element count | histograms | scratch keys | scratch values), remembers the visible-count buffer and builds the upsweep, spine and downsweep passes. */ void gpusort::Init(uint all_point_count, std::shared_ptr _visiable_buffer) 55 | { 56 | visiable_count_buffer = _visiable_buffer; 57 | 58 | elementCountSize = sizeof(uint32_t); 59 | histogramSize = HistogramSize(all_point_count); 60 | inoutSize = InoutSize(all_point_count); 61 | 62 | histogramOffset = elementCountSize; 63 | inoutOffset = histogramOffset + histogramSize; 64 | // 2x for key value 65 | VkDeviceSize storageSize = inoutOffset + 2 * inoutSize; 66 | 67 | storage_buffer = Buffer::CreateDeviceBuffer( 68 | nullptr, 69 | storageSize, 70 | vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferDst | vk::BufferUsageFlagBits::eTransferSrc | 71 | 
vk::BufferUsageFlagBits::eShaderDeviceAddress); 72 | partitionCount = 73 | RoundUp(all_point_count, PARTITION_SIZE); 74 | 75 | auto upsweep_shader = std::make_shared("include/shaders/gpusort/upsweep.comp.spv"); 76 | 77 | auto spineshader = std::make_shared("include/shaders/gpusort/spine.comp.spv"); 78 | auto downsweep_shader = std::make_shared("include/shaders/gpusort/downsweep.comp.spv"); 79 | 80 | 81 | upsweepPass.reset(new ComputePass({}, sizeof(PushConstants), upsweep_shader)); 82 | spinePass.reset(new ComputePass({}, sizeof(PushConstants), spineshader)); 83 | downsweepKeyValuePass.reset(new ComputePass({}, sizeof(PushConstants), downsweep_shader)); 84 | } 85 | 86 | /* Records a four-pass sort into `cmd`: copies the visible count into the storage buffer, zeroes the global histograms, then for pass 0..3 dispatches upsweep, spine and downsweep with barriers in between. Even passes read key_buffer/value_buffer and write the scratch arrays, odd passes go the other way, so after pass 3 the output is back in key_buffer/value_buffer. */ void gpusort::sort(vk::CommandBuffer cmd, std::shared_ptr key_buffer, std::shared_ptr value_buffer) 87 | { 88 | 89 | { 90 | cmd.copyBuffer( 91 | visiable_count_buffer->get_handle(), 92 | storage_buffer->get_handle(), 93 | vk::BufferCopy() 94 | .setSize(sizeof(uint32_t)) 95 | .setDstOffset(0) 96 | .setSrcOffset(0)); 97 | } 98 | // reset global histogram. partition histogram is set by shader. 
99 | cmd.fillBuffer( 100 | storage_buffer->get_handle(), 101 | sizeof(uint32_t), 102 | 4 * RADIX * sizeof(uint32_t), 103 | 0); 104 | 105 | cmd.pipelineBarrier(vk::PipelineStageFlagBits::eTransfer, 106 | vk::PipelineStageFlagBits::eComputeShader, 107 | vk::DependencyFlagBits::eDeviceGroup, 108 | // 0, 109 | vk::MemoryBarrier() 110 | .setSrcAccessMask(vk::AccessFlagBits::eTransferWrite) 111 | .setDstAccessMask(vk::AccessFlagBits::eShaderRead), 112 | {}, 113 | {}); 114 | 115 | auto storageAddress = storage_buffer->get_address(); 116 | auto keysAddress = key_buffer->get_address(); 117 | auto valuesAddress = value_buffer->get_address(); 118 | for (int i = 0; i < 4; ++i) { 119 | // pushConstants.pass = i; 120 | auto native_keysInReference = keysAddress; 121 | auto native_keysOutReference = storageAddress + inoutOffset; 122 | auto native_valuesInReference = valuesAddress; 123 | auto native_valuesOutReference = storageAddress + inoutOffset + inoutSize; 124 | 125 | /* pushed once with the upsweep layout and reused by the spine and downsweep dispatches below - assumes the three pipeline layouts are push-constant compatible, TODO confirm */ cmd.pushConstants( 126 | upsweepPass->get_pipeline()->get_layout(), 127 | vk::ShaderStageFlagBits::eCompute, 128 | 0, 129 | PushConstants { 130 | .pass = static_cast(i), 131 | .elementCountReference = storageAddress, 132 | .globalHistogramReference = storageAddress + sizeof(uint32_t), 133 | .partitionHistogramReference = storageAddress + sizeof(uint32_t) + sizeof(uint32_t) * 4 * RADIX, 134 | 135 | .keysInReference = i == 0 || i == 2 ? native_keysInReference : native_keysOutReference, 136 | .keysOutReference = i == 0 || i == 2 ? native_keysOutReference : native_keysInReference, 137 | .valuesInReference = i == 0 || i == 2 ? native_valuesInReference : native_valuesOutReference, 138 | .valuesOutReference = i == 0 || i == 2 ? 
native_valuesOutReference : native_valuesInReference 139 | 140 | }); 141 | 142 | cmd.bindPipeline(vk::PipelineBindPoint ::eCompute, upsweepPass->get_pipeline()->get_handle()); 143 | 144 | cmd.dispatch(partitionCount, 1, 1); 145 | 146 | // 147 | cmd.pipelineBarrier( 148 | vk::PipelineStageFlagBits::eComputeShader, 149 | vk::PipelineStageFlagBits::eComputeShader, 150 | vk::DependencyFlagBits::eByRegion, 151 | vk::MemoryBarrier() 152 | .setSrcAccessMask(vk::AccessFlagBits::eShaderWrite) 153 | .setDstAccessMask(vk::AccessFlagBits::eShaderRead), 154 | {}, 155 | {}); 156 | { 157 | cmd.bindPipeline( 158 | vk::PipelineBindPoint ::eCompute, 159 | spinePass->get_pipeline()->get_handle()); 160 | 161 | cmd.dispatch(RADIX, 1, 1); 162 | } 163 | cmd.pipelineBarrier( 164 | vk::PipelineStageFlagBits::eComputeShader, 165 | vk::PipelineStageFlagBits::eComputeShader, 166 | vk::DependencyFlagBits::eByRegion, 167 | vk::MemoryBarrier() 168 | .setSrcAccessMask(vk::AccessFlagBits::eShaderWrite) 169 | .setDstAccessMask(vk::AccessFlagBits::eShaderRead), 170 | {}, 171 | {}); 172 | { 173 | cmd.bindPipeline( 174 | vk::PipelineBindPoint ::eCompute, 175 | downsweepKeyValuePass->get_pipeline()->get_handle()); 176 | cmd.dispatch(partitionCount, 1, 1); 177 | } 178 | /* no barrier after the last pass; the caller records the next one */ if (i < 3) { 179 | 180 | cmd.pipelineBarrier( 181 | vk::PipelineStageFlagBits::eComputeShader, 182 | vk::PipelineStageFlagBits::eComputeShader, 183 | vk::DependencyFlagBits::eByRegion, 184 | vk::MemoryBarrier() 185 | .setSrcAccessMask(vk::AccessFlagBits::eShaderWrite) 186 | .setDstAccessMask(vk::AccessFlagBits::eShaderRead), 187 | {}, 188 | {}); 189 | } 190 | } 191 | } 192 | 193 | } 194 | --------------------------------------------------------------------------------