├── .clang-format ├── .gitattributes ├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── README.md ├── cmake └── aliked-config.cmake.in ├── examples └── main.cpp ├── include ├── aliked.hpp ├── blocks.hpp ├── cuda_helpers.h ├── deform_conv2d.h ├── dkd.hpp ├── get_patches.hpp ├── get_patches_cuda.h ├── input_padder.hpp └── sddh.hpp ├── models ├── aliked-n16.pt ├── aliked-n16rot.pt ├── aliked-n32.pt └── aliked-t16.pt └── src ├── aliked.cpp ├── blocks.cpp ├── deform_conv2d.cpp ├── deform_conv2d_kernel.cu ├── dkd.cpp ├── get_patches.cpp ├── get_patches_cuda.cu ├── input_padder.cpp └── sddh.cpp /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | Standard: Latest 4 | 5 | # Access modifiers and indentation 6 | AccessModifierOffset: -4 7 | IndentWidth: 4 8 | ContinuationIndentWidth: 4 9 | TabWidth: 4 10 | UseTab: Never 11 | NamespaceIndentation: All 12 | 13 | # Alignment 14 | AlignAfterOpenBracket: Align 15 | AlignConsecutiveAssignments: false 16 | AlignConsecutiveDeclarations: false 17 | AlignConsecutiveMacros: Consecutive 18 | AlignEscapedNewlines: Left 19 | AlignOperands: true 20 | AlignTrailingComments: true 21 | 22 | # Allow behaviors 23 | AllowAllArgumentsOnNextLine: false 24 | AllowAllParametersOfDeclarationOnNextLine: false 25 | AllowShortBlocksOnASingleLine: Always 26 | AllowShortCaseLabelsOnASingleLine: true 27 | AllowShortEnumsOnASingleLine: true 28 | AllowShortFunctionsOnASingleLine: All 29 | AllowShortLambdasOnASingleLine: All 30 | AllowShortIfStatementsOnASingleLine: Never 31 | AllowShortLoopsOnASingleLine: false 32 | AllowAllConstructorInitializersOnNextLine: false 33 | 34 | # Breaking and wrapping 35 | AlwaysBreakAfterDefinitionReturnType: None 36 | AlwaysBreakAfterReturnType: None 37 | AlwaysBreakBeforeMultilineStrings: false 38 | AlwaysBreakTemplateDeclarations: true 39 | BreakBeforeBinaryOperators: None 40 | BreakBeforeTernaryOperators: true 41 | BreakConstructorInitializersBeforeComma: false 42 | BreakConstructorInitializers: BeforeColon 43 | BreakInheritanceList: BeforeColon 44 | BreakStringLiterals: true 45 | ColumnLimit: 0 46 | 47 | # Brace wrapping 48 | BreakBeforeBraces: Custom 49 | BraceWrapping: 50 | AfterCaseLabel: false 51 | AfterClass: false 52 | AfterControlStatement: Always 53 | AfterEnum: false 54 | AfterFunction: false 55 | AfterNamespace: false 56 | AfterObjCDeclaration: false 57 | AfterStruct: false 58 | AfterUnion: false 59 | AfterExternBlock: false 60 | BeforeCatch: false 61 | BeforeElse: false 62 | IndentBraces: false 63 | SplitEmptyFunction: true 64 | SplitEmptyRecord: true 65 | SplitEmptyNamespace: true 66 | 67 | # Constructor initialization 68 | ConstructorInitializerIndentWidth: 4 69 | ConstructorInitializerAllOnOneLineOrOnePerLine: true 70 | 71 | # Empty lines and spacing 72 | EmptyLineBeforeAccessModifier: Always 73 | KeepEmptyLinesAtTheStartOfBlocks: true 74 | MaxEmptyLinesToKeep: 1 75 | SpaceAfterCStyleCast: false 76 | SpaceAfterTemplateKeyword: true 77 | SpaceBeforeAssignmentOperators: true 78 | SpaceBeforeParens: ControlStatements 79 | SpaceInEmptyParentheses: false 80 | SpacesBeforeTrailingComments: 1 81 | SpacesInAngles: false 82 | SpacesInContainerLiterals: true 83 | SpacesInCStyleCastParentheses: false 84 | SpacesInParentheses: false 85 | SpacesInSquareBrackets: false 86 | 87 | # Include ordering 88 | SortIncludes: CaseInsensitive 89 | IncludeBlocks: Regroup 90 | IncludeCategories: 91 | # C++ Standard Library headers 92 | - Regex: 
'^<(cctype|span|cstring|string|string_view|vector|map|fstream|typeindex|source_location|stacktrace|array|iostream|memory|future|stdexcept|algorithm|random|atomic|sstream|chrono|cstdint|expected|filesystem|functional|mutex|queue|optional|shared_mutex|thread|utility|variant|unordered_map|unordered_set|condition_variable)>$' 93 | Priority: 1 94 | SortPriority: 3 95 | # C Standard Library headers 96 | - Regex: '^<(ft2build\.h|GL/|GLFW/|glm/|spdlog/|fmt/).*>' 97 | Priority: 2 98 | SortPriority: 2 99 | - Regex: '^"(ft2build\.h|GL/|GLFW/|glm/|spdlog/|fmt/).*"' 100 | Priority: 2 101 | SortPriority: 2 102 | # All project headers 103 | - Regex: '.*' 104 | Priority: 3 105 | SortPriority: 1 106 | 107 | # Other settings 108 | Cpp11BracedListStyle: true 109 | DerivePointerAlignment: false 110 | FixNamespaceComments: true 111 | IndentCaseBlocks: false 112 | IndentCaseLabels: false 113 | IndentGotoLabels: false 114 | IndentPPDirectives: None 115 | IndentWrappedFunctionNames: false 116 | PointerAlignment: Left 117 | ReflowComments: true 118 | --- 119 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | models filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | src/code 2 | external/libtorch 3 | src/copies.sh 4 | logs 5 | build 6 | .idea 7 | cmake-* 8 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.26) 2 | project(aliked 3 | VERSION 1.0.0 4 | DESCRIPTION "C++ implementation of ALIKED" 5 | LANGUAGES CUDA CXX) 6 | 7 | # Enable LTO/IPO 8 | include(CheckIPOSupported) 9 | check_ipo_supported(RESULT IPO_SUPPORTED OUTPUT IPO_ERROR) 10 | if(IPO_SUPPORTED) 11 | set(CMAKE_INTERPROCEDURAL_OPTIMIZATION ON) 12 | endif() 13 | 14 | # Core configuration 15 | set(CMAKE_CXX_STANDARD 20) 16 | set(CMAKE_CUDA_STANDARD 17) 17 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 18 | set(CMAKE_CUDA_STANDARD_REQUIRED ON) 19 | set(CMAKE_CUDA_ARCHITECTURES native) 20 | set(CMAKE_POSITION_INDEPENDENT_CODE ON) 21 | 22 | # Configure paths 23 | set(ALIKED_MODELS_DIR "${CMAKE_CURRENT_SOURCE_DIR}/models" CACHE PATH "Path to model weights directory") 24 | 25 | # Find dependencies 26 | set(LIBTORCH_DIR "${CMAKE_CURRENT_SOURCE_DIR}/external/libtorch") 27 | set(CMAKE_PREFIX_PATH ${LIBTORCH_DIR}) 28 | 29 | find_package(Torch REQUIRED) 30 | find_package(OpenCV REQUIRED) 31 | find_package(CUDAToolkit REQUIRED) 32 | 33 | # Check CUDA version 34 | if(CUDAToolkit_VERSION VERSION_LESS "12.1") 35 | message(FATAL_ERROR "This project requires CUDA 12.1 or higher (found: ${CUDAToolkit_VERSION})") 36 | endif() 37 | 38 | # Performance flags 39 | if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") 40 | set(PERFORMANCE_FLAGS 41 | -O3 42 | -march=native 43 | -mtune=native 44 | -fomit-frame-pointer 45 | -ffast-math 46 | ) 47 | endif() 48 | 49 | # CUDA flags 50 | set(CUDA_FLAGS 51 | -O3 52 | --use_fast_math 53 | -DNDEBUG 54 | ) 55 | 56 | # Add models directory definition 57 | add_definitions(-DALIKED_MODELS_DIR="${ALIKED_MODELS_DIR}") 58 | 59 | # Source files 60 | set(ALIKED_HEADERS 61 | include/aliked.hpp 62 | include/dkd.hpp 63 | include/sddh.hpp 64 | include/blocks.hpp 65 | include/get_patches.hpp 66 | 
include/input_padder.hpp 67 | include/deform_conv2d.h 68 | include/get_patches_cuda.h 69 | include/cuda_helpers.h 70 | ) 71 | 72 | set(ALIKED_SOURCES 73 | src/blocks.cpp 74 | src/aliked.cpp 75 | src/dkd.cpp 76 | src/input_padder.cpp 77 | src/get_patches.cpp 78 | src/sddh.cpp 79 | src/deform_conv2d.cpp 80 | src/deform_conv2d_kernel.cu 81 | src/get_patches_cuda.cu 82 | ) 83 | 84 | # Library target 85 | add_library(${PROJECT_NAME}_lib STATIC 86 | ${ALIKED_SOURCES} 87 | ${ALIKED_HEADERS} 88 | ) 89 | 90 | add_library(${PROJECT_NAME}::lib ALIAS ${PROJECT_NAME}_lib) 91 | 92 | target_include_directories(${PROJECT_NAME}_lib 93 | PUBLIC 94 | $ 95 | $ 96 | ) 97 | 98 | target_compile_options(${PROJECT_NAME}_lib 99 | PRIVATE 100 | $<$:${PERFORMANCE_FLAGS}> 101 | $<$:${CUDA_FLAGS}> 102 | ) 103 | 104 | target_link_libraries(${PROJECT_NAME}_lib 105 | PUBLIC 106 | ${TORCH_LIBRARIES} 107 | ${OpenCV_LIBS} 108 | PRIVATE 109 | CUDA::cudart 110 | CUDA::curand 111 | CUDA::cublas 112 | ) 113 | 114 | # Properties for maximum performance 115 | set_target_properties(${PROJECT_NAME}_lib PROPERTIES 116 | CUDA_SEPARABLE_COMPILATION ON 117 | CUDA_RESOLVE_DEVICE_SYMBOLS ON 118 | POSITION_INDEPENDENT_CODE ON 119 | INTERPROCEDURAL_OPTIMIZATION ${IPO_SUPPORTED} 120 | ) 121 | 122 | # Example application 123 | add_executable(${PROJECT_NAME} examples/main.cpp) 124 | target_link_libraries(${PROJECT_NAME} PRIVATE ${PROJECT_NAME}::lib) 125 | 126 | # Set output directories 127 | set_target_properties(${PROJECT_NAME} ${PROJECT_NAME}_lib PROPERTIES 128 | RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin" 129 | LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" 130 | ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" 131 | ) 132 | 133 | # Create models directory if it doesn't exist 134 | add_custom_target(create_models_dir ALL 135 | COMMAND ${CMAKE_COMMAND} -E make_directory ${ALIKED_MODELS_DIR} 136 | ) 137 | 138 | # Print models directory location 139 | message(STATUS "Models directory: ${ALIKED_MODELS_DIR}") -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, Zhao Xiaoming 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ALIKED C++ Implementation 2 | 3 | This is a C++ implementation of ALIKED (Attentive Local and Implicit Keypoint Detector) using LibTorch and OpenCV. The implementation provides a high-performance, production-ready version of the ALIKED model for keypoint detection and matching. 4 | 5 | If you are interested in collaborating on this project, would like to reach out to me, or are considering to contribute to the discussion of the overall NeRF/GS pipeline, please join the Discord: https://discord.gg/NqwTqVYVmj 6 | 7 | ## Features 8 | 9 | - Complete C++ implementation of ALIKED model 10 | - CUDA-accelerated computations 11 | - OpenCV integration for image processing 12 | - Real-time keypoint detection and matching 13 | - Multiple model configurations (aliked-t16, aliked-n16, aliked-n16rot, aliked-n32) 14 | - Move semantics optimization for better performance 15 | - Simple tracking demo application 16 | 17 | ## Prerequisites 18 | 19 | - CMake (>= 3.26) 20 | - CUDA Toolkit (>= 12.1) 21 | - LibTorch (with CUDA support) 22 | - OpenCV 23 | - C++20 compatible compiler 24 | 25 | ## Directory Structure 26 | 27 | ``` 28 | . 29 | ├── include/ # Header files 30 | ├── src/ # Source files 31 | ├── examples/ # Example applications 32 | ├── models/ # Pre-trained model weights 33 | ├── external/ 34 | │ └── libtorch/ # LibTorch directory 35 | └── CMakeLists.txt # CMake configuration 36 | ``` 37 | 38 | ## Setup Instructions 39 | 40 | 1. Download and extract LibTorch: 41 | ```bash 42 | mkdir -p external 43 | cd external 44 | wget https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.1.0%2Bcu121.zip 45 | unzip libtorch-cxx11-abi-shared-with-deps-2.1.0+cu121.zip 46 | cd .. 47 | ``` 48 | 49 | 2. Build the project: 50 | ```bash 51 | mkdir build && cd build 52 | cmake -DCMAKE_BUILD_TYPE=Release .. 
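# note: CMakeLists.txt expects LibTorch under external/libtorch (LIBTORCH_DIR) and requires CUDA >= 12.1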
53 | make -j$(nproc) 54 | ``` 55 | 56 | ## Usage 57 | 58 | The demo application performs keypoint detection and matching between consecutive images: 59 | 60 | ```bash 61 | ./aliked /path/to/image/directory [options] 62 | ``` 63 | 64 | ### Options 65 | 66 | - `model_name`: Model configuration (default: "aliked-n32") 67 | - `device`: Computation device (default: "cuda") 68 | - `top_k`: Number of top keypoints (-1 for threshold-based selection, default: -1) 69 | - `scores_th`: Score threshold for keypoint selection (default: 0.2) 70 | - `n_limit`: Maximum number of keypoints (default: 5000) 71 | 72 | ### Example Code 73 | 74 | ```cpp 75 | #include "ALIKED.hpp" 76 | 77 | // Initialize model 78 | auto model = std::make_shared("aliked-n32", "cuda"); 79 | 80 | // Load and process image 81 | cv::Mat img = cv::imread("image.jpg"); 82 | cv::Mat img_rgb; 83 | cv::cvtColor(img, img_rgb, cv::COLOR_BGR2RGB); 84 | 85 | // Run inference 86 | auto pred = model->run(img_rgb); 87 | auto keypoints = pred.at("keypoints"); 88 | auto descriptors = pred.at("descriptors"); 89 | ``` 90 | 91 | ## Model Configurations 92 | 93 | | Model Name | Description | 94 | |--------------|-------------------------------------------| 95 | | aliked-t16 | Tiny model with 16 descriptor dimensions | 96 | | aliked-n16 | Normal model with 16 descriptor dimensions| 97 | | aliked-n16rot| Rotation-invariant model | 98 | | aliked-n32 | Normal model with 32 descriptor dimensions| 99 | 100 | ## Performance Optimizations 101 | 102 | The implementation includes several optimizations: 103 | 104 | - Link Time Optimization (LTO/IPO) 105 | - CPU architecture-specific optimizations 106 | - CUDA optimizations 107 | - Fast math operations 108 | - Position-independent code 109 | 110 | ## Custom Model Directory 111 | 112 | You can specify a custom location for model weights during build: 113 | 114 | ```bash 115 | cmake -DCMAKE_BUILD_TYPE=Release -DALIKED_MODELS_DIR=/path/to/models .. 116 | ``` 117 | 118 | ## Contributing 119 | 120 | Contributions are welcome! Please feel free to submit pull requests, create issues, or suggest improvements. 
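
## Descriptor Matching

The demo tracker matches consecutive frames by keeping mutual nearest neighbours between the two descriptor sets. The sketch below mirrors the matcher in `examples/main.cpp`; the function name `mnn_match` is illustrative, and the `0.9` similarity cutoff is the value hard-coded in the demo.

```cpp
#include <torch/torch.h>

// Mutual-nearest-neighbour matching between two descriptor sets (N1 x D and N2 x D).
torch::Tensor mnn_match(const torch::Tensor& desc1, const torch::Tensor& desc2) {
    // Descriptors are L2-normalised, so the dot product is the cosine similarity.
    auto sim = torch::matmul(desc1, desc2.t());
    sim = torch::where(sim < 0.9, torch::zeros_like(sim), sim);

    auto nn12 = std::get<1>(torch::max(sim, 1)); // best match in desc2 for each row of desc1
    auto nn21 = std::get<1>(torch::max(sim, 0)); // best match in desc1 for each row of desc2

    // Keep a pair (i, nn12[i]) only if it is mutual, i.e. nn21[nn12[i]] == i.
    auto ids1 = torch::arange(sim.size(0),
                              torch::TensorOptions().dtype(torch::kLong).device(sim.device()));
    auto mask = (ids1 == nn21.index({nn12}));
    return torch::stack({ids1.masked_select(mask), nn12.masked_select(mask)}, 1); // K x 2 index pairs
}
```

Feed it the `descriptors` tensors returned by `run()` for two images; each row of the result is a pair of indices into the first and second keypoint sets.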
121 | 122 | ## Acknowledgements 123 | 124 | - Original ALIKED paper and implementation 125 | - LibTorch and PyTorch teams 126 | - OpenCV team -------------------------------------------------------------------------------- /cmake/aliked-config.cmake.in: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrNeRF/ALIKED_CPP/3daaff1ae13537c4fed0aba2f5776b5c4f418330/cmake/aliked-config.cmake.in -------------------------------------------------------------------------------- /examples/main.cpp: -------------------------------------------------------------------------------- 1 | #include "aliked.hpp" 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | namespace fs = std::filesystem; 10 | 11 | class ImageLoader { 12 | public: 13 | explicit ImageLoader(const std::string& filepath) { 14 | for (const auto& entry : fs::directory_iterator(filepath)) 15 | { 16 | const auto& path = entry.path(); 17 | std::string ext = path.extension().string(); 18 | std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); 19 | 20 | if (ext == ".png" || ext == ".jpg" || ext == ".ppm") 21 | { 22 | images_.push_back(path.string()); 23 | } 24 | } 25 | std::sort(images_.begin(), images_.end()); 26 | std::cout << "Loading " << images_.size() << " images" << std::endl; 27 | } 28 | 29 | cv::Mat operator[](size_t idx) const { 30 | return cv::imread(images_[idx]); 31 | } 32 | 33 | size_t size() const { return images_.size(); } 34 | 35 | private: 36 | std::vector images_; 37 | }; 38 | 39 | class SimpleTracker { 40 | public: 41 | SimpleTracker() : pts_prev_(), 42 | desc_prev_() {} 43 | 44 | // Update function 45 | std::tuple update(const cv::Mat& img, const torch::Tensor& pts, const torch::Tensor& desc) { 46 | cv::Mat out = img.clone(); 47 | int N_matches = 0; 48 | 49 | if (!pts_prev_.defined()) 50 | { 51 | // First frame: Initialize points and descriptors 52 | pts_prev_ = pts.clone(); 53 | desc_prev_ = desc.clone(); 54 | 55 | // Draw keypoints 56 | for (int i = 0; i < pts.size(0); ++i) 57 | { 58 | cv::Point2f p1(pts[i][0].item(), pts[i][1].item()); 59 | cv::circle(out, p1, 1, cv::Scalar(0, 0, 255), -1, cv::LINE_AA); 60 | } 61 | } else 62 | { 63 | // Compute matches 64 | auto matches = mnn_matcher(desc_prev_, desc); 65 | N_matches = matches.size(0); 66 | 67 | // Draw matches 68 | for (int i = 0; i < N_matches; ++i) 69 | { 70 | int idx0 = matches[i][0].item(); 71 | int idx1 = matches[i][1].item(); 72 | 73 | cv::Point2f pt1(pts_prev_[idx0][0].item(), pts_prev_[idx0][1].item()); 74 | cv::Point2f pt2(pts[idx1][0].item(), pts[idx1][1].item()); 75 | cv::line(out, pt1, pt2, cv::Scalar(0, 255, 0), 1, cv::LINE_AA); 76 | cv::circle(out, pt2, 1, cv::Scalar(0, 0, 255), -1, cv::LINE_AA); 77 | } 78 | 79 | // Update previous points and descriptors 80 | pts_prev_ = pts.clone(); 81 | desc_prev_ = desc.clone(); 82 | } 83 | 84 | return {out, N_matches}; 85 | } 86 | 87 | private: 88 | // Nearest neighbor matcher 89 | static torch::Tensor mnn_matcher(const torch::Tensor& desc1, const torch::Tensor& desc2) { 90 | // Compute similarity matrix 91 | auto sim = torch::matmul(desc1, desc2.t()); 92 | sim = torch::where(sim < 0.9, torch::zeros_like(sim), sim); 93 | 94 | // Nearest neighbors 95 | auto nn12 = std::get<1>(torch::max(sim, 1)); // Nearest in desc2 for each desc1 96 | auto nn21 = std::get<1>(torch::max(sim, 0)); // Nearest in desc1 for each desc2 97 | 98 | // Mask to enforce mutual nearest neighbors 99 | auto ids1 = torch::arange(sim.size(0), 
torch::TensorOptions().device(sim.device())); 100 | auto mask = (ids1 == nn21.index({nn12})); 101 | auto matches = torch::stack({ids1.masked_select(mask), nn12.masked_select(mask)}, 1); 102 | 103 | return matches; 104 | } 105 | 106 | torch::Tensor pts_prev_; 107 | torch::Tensor desc_prev_; 108 | }; 109 | 110 | int main(int argc, char* argv[]) { 111 | if (argc < 2) 112 | { 113 | std::cerr << "Usage: " << argv[0] << " [options]" << std::endl; 114 | return 1; 115 | } 116 | 117 | // Parse command line arguments 118 | const std::string input_dir = argv[1]; 119 | const std::string model_name = "aliked-n32"; 120 | const std::string device = "cuda"; 121 | const int top_k = -1; 122 | const float scores_th = 0.2f; 123 | const int n_limit = 5000; 124 | 125 | // Initialize model 126 | std::cout << "Initializing ALIKED model..." << std::endl; 127 | const auto model = std::make_shared(model_name, device, top_k, scores_th, n_limit); 128 | 129 | // Load images 130 | ImageLoader image_loader(input_dir); 131 | if (image_loader.size() < 2) 132 | { 133 | std::cerr << "Need at least 2 images in the input directory" << std::endl; 134 | return 1; 135 | } 136 | 137 | // Initialize tracker 138 | SimpleTracker tracker; 139 | 140 | // Display prompt 141 | std::cout << "Press 'space' to start. \nPress 'q' or 'ESC' to stop!" << std::endl; 142 | 143 | for (size_t i = 0; i < image_loader.size(); i++) 144 | { 145 | cv::Mat img = image_loader[i]; 146 | if (img.empty()) 147 | break; 148 | 149 | // Convert image to RGB 150 | cv::Mat img_rgb; 151 | cv::cvtColor(img, img_rgb, cv::COLOR_BGR2RGB); 152 | 153 | // Run model 154 | const auto pred = model->run(img_rgb); 155 | auto kpts = pred.at("keypoints").cpu(); 156 | const auto desc = pred.at("descriptors").cpu(); 157 | 158 | // Normalize and scale keypoints to pixel coordinates 159 | const int img_width = img.cols; 160 | const int img_height = img.rows; 161 | 162 | kpts = (kpts + 1.0) * 0.5; // Normalize to [0, 1] 163 | kpts.select(1, 0).mul_(img_width); // Scale x-coordinates 164 | kpts.select(1, 1).mul_(img_height); // Scale y-coordinates 165 | 166 | // Plot keypoints on the current image 167 | for (int j = 0; j < kpts.size(0); ++j) 168 | { 169 | const auto x = kpts[j][0].item(); 170 | const auto y = kpts[j][1].item(); 171 | 172 | // Validate coordinates 173 | if (x >= 0 && x < img_width && y >= 0 && y < img_height) 174 | { 175 | cv::circle(img, cv::Point2f(x, y), 1, cv::Scalar(0, 0, 255), -1, cv::LINE_AA); 176 | } else 177 | { 178 | std::cerr << "Keypoint out of bounds: (" << x << ", " << y << ")" << std::endl; 179 | } 180 | } 181 | 182 | // Update tracker 183 | cv::Mat vis_img; 184 | int N_matches; 185 | std::tie(vis_img, N_matches) = tracker.update(img, kpts, desc); 186 | 187 | // Status message 188 | const std::string status = "matches/keypoints: " + 189 | std::to_string(N_matches) + "/" + 190 | std::to_string(kpts.size(0)); 191 | 192 | // Overlay status and instructions 193 | cv::putText(vis_img, "Press 'q' or 'ESC' to stop.", 194 | cv::Point(10, 30), cv::FONT_HERSHEY_SIMPLEX, 1, 195 | cv::Scalar(0, 0, 255), 2, cv::LINE_AA); 196 | 197 | cv::namedWindow(model_name); 198 | cv::setWindowTitle(model_name, model_name + ": " + status); 199 | cv::imshow(model_name, vis_img); 200 | 201 | // Handle user input 202 | const char c = static_cast(cv::waitKey(0)); 203 | if (c == 'q' || c == 27) 204 | break; // Quit on 'q' or 'ESC' 205 | } 206 | 207 | std::cout << "Finished!" << std::endl; 208 | std::cout << "Press any key to exit!" 
<< std::endl; 209 | 210 | cv::destroyAllWindows(); 211 | return 0; 212 | } 213 | -------------------------------------------------------------------------------- /include/aliked.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "blocks.hpp" 4 | #include "input_padder.hpp" 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | struct AlikedConfig { 14 | int c1, c2, c3, c4, dim, K, M; 15 | }; 16 | 17 | class DKD; 18 | class SDDH; 19 | 20 | // Static configuration map 21 | inline const std::unordered_map ALIKED_CFGS = { 22 | {"aliked-t16", {8, 16, 32, 64, 64, 3, 16}}, 23 | {"aliked-n16", {16, 32, 64, 128, 128, 3, 16}}, 24 | {"aliked-n16rot", {16, 32, 64, 128, 128, 3, 16}}, 25 | {"aliked-n32", {16, 32, 64, 128, 128, 3, 32}}}; 26 | 27 | class ALIKED : public torch::nn::Module { 28 | public: 29 | explicit ALIKED(std::string_view model_name = "aliked-n32", 30 | std::string_view device = "cuda", 31 | int top_k = -1, 32 | float scores_th = 0.2, 33 | int n_limit = 5000); 34 | 35 | // Move semantics for tensor operations 36 | std::tuple 37 | extract_dense_map(torch::Tensor image) &&; 38 | 39 | std::tuple 40 | extract_dense_map(const torch::Tensor& image) &; 41 | 42 | torch::Dict 43 | forward(torch::Tensor image) &&; 44 | 45 | torch::Dict 46 | forward(const torch::Tensor& image) &; 47 | 48 | torch::Dict run(cv::Mat& img_rgb); 49 | 50 | private: 51 | void init_layers(std::string_view model_name); 52 | void load_weights(std::string_view model_name); 53 | void load_parameters(std::string_view pt_pth); 54 | 55 | static std::vector get_the_bytes(std::string_view filename); 56 | 57 | torch::nn::AvgPool2d pool2_{nullptr}, pool4_{nullptr}; 58 | std::shared_ptr block1_; 59 | std::shared_ptr block2_; 60 | std::shared_ptr block3_; 61 | std::shared_ptr block4_; 62 | torch::nn::Conv2d conv1_{nullptr}, conv2_{nullptr}, 63 | conv3_{nullptr}, conv4_{nullptr}; 64 | torch::nn::Sequential score_head_{nullptr}; 65 | 66 | std::shared_ptr dkd_; 67 | std::shared_ptr desc_head_; 68 | 69 | torch::Device device_; 70 | int dim_{}; 71 | }; -------------------------------------------------------------------------------- /include/blocks.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | class DeformableConv2d : public torch::nn::Module { 8 | public: 9 | DeformableConv2d(int in_channels, int out_channels, 10 | int kernel_size = 3, int stride = 1, 11 | int padding = 1, bool bias = false); 12 | 13 | torch::Tensor forward(const torch::Tensor& x) &; 14 | torch::Tensor forward(torch::Tensor x) &&; 15 | 16 | private: 17 | torch::nn::Conv2d offset_conv_{nullptr}; 18 | torch::nn::Conv2d regular_conv_{nullptr}; 19 | int padding_; 20 | int groups_ = 1; 21 | int mask_offset_ = 1; 22 | }; 23 | 24 | class ConvBlock : public torch::nn::Module { 25 | public: 26 | ConvBlock(int in_channels, int out_channels, 27 | std::string_view conv_type = "conv", 28 | bool mask = false); 29 | 30 | torch::Tensor forward(torch::Tensor x) &&; 31 | torch::Tensor forward(const torch::Tensor& x) &; 32 | 33 | private: 34 | torch::nn::Conv2d conv1_{nullptr}, conv2_{nullptr}; 35 | std::shared_ptr deform1_{nullptr}, deform2_{nullptr}; 36 | torch::nn::BatchNorm2d bn1_{nullptr}, bn2_{nullptr}; 37 | }; 38 | 39 | class ResBlock : public torch::nn::Module { 40 | public: 41 | ResBlock(int inplanes, int planes, int stride = 1, 42 | const torch::nn::Conv2d& downsample = 
nullptr, 43 | std::string_view conv_type = "conv"); 44 | 45 | torch::Tensor forward(torch::Tensor x) &&; 46 | torch::Tensor forward(const torch::Tensor& x) &; 47 | 48 | private: 49 | torch::nn::Conv2d conv1_{nullptr}, conv2_{nullptr}; 50 | std::shared_ptr deform1_{nullptr}, deform2_{nullptr}; 51 | torch::nn::BatchNorm2d bn1_{nullptr}, bn2_{nullptr}; 52 | torch::nn::Conv2d downsample_; 53 | }; -------------------------------------------------------------------------------- /include/cuda_helpers.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace vision { 4 | namespace ops { 5 | 6 | #define CUDA_1D_KERNEL_LOOP_T(i, n, index_t) \ 7 | for (index_t i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); \ 8 | i += (blockDim.x * gridDim.x)) 9 | 10 | #define CUDA_1D_KERNEL_LOOP(i, n) CUDA_1D_KERNEL_LOOP_T(i, n, int) 11 | 12 | template 13 | constexpr __host__ __device__ inline integer ceil_div(integer n, integer m) { 14 | return (n + m - 1) / m; 15 | } 16 | 17 | } // namespace ops 18 | } // namespace vision 19 | -------------------------------------------------------------------------------- /include/deform_conv2d.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace vision { 6 | namespace ops { 7 | 8 | at::Tensor deform_conv2d( 9 | const at::Tensor& input, 10 | const at::Tensor& weight, 11 | const at::Tensor& offset, 12 | const at::Tensor& mask, 13 | const at::Tensor& bias, 14 | int64_t stride_h, 15 | int64_t stride_w, 16 | int64_t pad_h, 17 | int64_t pad_w, 18 | int64_t dilation_h, 19 | int64_t dilation_w, 20 | int64_t groups, 21 | int64_t offset_groups, 22 | bool use_mask); 23 | 24 | at::Tensor deform_conv2d_symint( 25 | const at::Tensor& input, 26 | const at::Tensor& weight, 27 | const at::Tensor& offset, 28 | const at::Tensor& mask, 29 | const at::Tensor& bias, 30 | c10::SymInt stride_h, 31 | c10::SymInt stride_w, 32 | c10::SymInt pad_h, 33 | c10::SymInt pad_w, 34 | c10::SymInt dilation_h, 35 | c10::SymInt dilation_w, 36 | c10::SymInt groups, 37 | c10::SymInt offset_groups, 38 | bool use_mask); 39 | 40 | namespace detail { 41 | 42 | std::tuple 43 | _deform_conv2d_backward( 44 | const at::Tensor& grad, 45 | const at::Tensor& input, 46 | const at::Tensor& weight, 47 | const at::Tensor& offset, 48 | const at::Tensor& mask, 49 | const at::Tensor& bias, 50 | int64_t stride_h, 51 | int64_t stride_w, 52 | int64_t pad_h, 53 | int64_t pad_w, 54 | int64_t dilation_h, 55 | int64_t dilation_w, 56 | int64_t groups, 57 | int64_t offset_groups, 58 | bool use_mask); 59 | 60 | std::tuple 61 | _deform_conv2d_backward_symint( 62 | const at::Tensor& grad, 63 | const at::Tensor& input, 64 | const at::Tensor& weight, 65 | const at::Tensor& offset, 66 | const at::Tensor& mask, 67 | const at::Tensor& bias, 68 | c10::SymInt stride_h, 69 | c10::SymInt stride_w, 70 | c10::SymInt pad_h, 71 | c10::SymInt pad_w, 72 | c10::SymInt dilation_h, 73 | c10::SymInt dilation_w, 74 | c10::SymInt groups, 75 | c10::SymInt offset_groups, 76 | bool use_mask); 77 | 78 | } // namespace detail 79 | 80 | } // namespace ops 81 | } // namespace vision 82 | -------------------------------------------------------------------------------- /include/dkd.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | #include 5 | 6 | class DKD : public torch::nn::Module { 7 | public: 8 | DKD(int radius = 2, int top_k = -1, float scores_th = 
0.2, int n_limit = 20000); 9 | 10 | std::tuple, std::vector, std::vector> 11 | detect_keypoints(torch::Tensor scores_map, bool sub_pixel = true) &&; 12 | 13 | std::tuple, std::vector, std::vector> 14 | detect_keypoints(const torch::Tensor& scores_map, bool sub_pixel = true) &; 15 | 16 | torch::Tensor simple_nms(torch::Tensor scores, int nms_radius) &&; 17 | torch::Tensor simple_nms(const torch::Tensor& scores, int nms_radius) &; 18 | 19 | std::tuple, std::vector, std::vector> 20 | forward(torch::Tensor scores_map, bool sub_pixel = true) &&; 21 | 22 | std::tuple, std::vector, std::vector> 23 | forward(const torch::Tensor& scores_map, bool sub_pixel = true) &; 24 | 25 | private: 26 | static constexpr int calculateKernelSize(int radius) { return 2 * radius + 1; } 27 | 28 | const int radius_; 29 | const int top_k_; 30 | const float scores_th_; 31 | const int n_limit_; 32 | const int kernel_size_; 33 | const float temperature_; 34 | torch::nn::Unfold unfold_{nullptr}; 35 | torch::Tensor hw_grid_; 36 | }; -------------------------------------------------------------------------------- /include/get_patches.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace custom_ops { 6 | 7 | torch::Tensor get_patches_forward(const torch::Tensor& map, torch::Tensor& points, int64_t radius); 8 | torch::Tensor get_patches_backward(const torch::Tensor& d_patches, torch::Tensor& points, int64_t H, int64_t W); 9 | } // namespace custom_ops 10 | -------------------------------------------------------------------------------- /include/get_patches_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | // CUDA declarations 4 | torch::Tensor get_patches_forward_cuda(const torch::Tensor& map, torch::Tensor& points, int64_t radius); 5 | torch::Tensor get_patches_backward_cuda(const torch::Tensor& d_patches, torch::Tensor& points, int64_t H, int64_t W); 6 | -------------------------------------------------------------------------------- /include/input_padder.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | class InputPadder { 8 | public: 9 | InputPadder(int h, int w, int div_by = 8) 10 | : ht_(h), 11 | wd_(w) { 12 | int pad_ht = (((ht_ / div_by) + 1) * div_by - ht_) % div_by; 13 | int pad_wd = (((wd_ / div_by) + 1) * div_by - wd_) % div_by; 14 | 15 | pad_ = {pad_wd / 2, pad_wd - pad_wd / 2, 16 | pad_ht / 2, pad_ht - pad_ht / 2}; 17 | } 18 | 19 | // Move semantics for pad operation 20 | torch::Tensor pad(torch::Tensor x) &&; 21 | torch::Tensor pad(const torch::Tensor& x) &; 22 | 23 | // Move semantics for unpad operation 24 | [[maybe_unused]] torch::Tensor unpad(torch::Tensor x) &&; 25 | torch::Tensor unpad(const torch::Tensor& x) &; 26 | 27 | void setPadding(std::span padding); 28 | std::span getPadding() const { return std::span(pad_); } 29 | 30 | private: 31 | int ht_; 32 | int wd_; 33 | std::array pad_; 34 | }; -------------------------------------------------------------------------------- /include/sddh.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | #include 5 | 6 | class SDDH : public torch::nn::Module { 7 | public: 8 | SDDH(int dims, int kernel_size = 3, int n_pos = 8, 9 | bool conv2D = false, bool mask = false); 10 | 11 | // Overloaded forward with move semantics 12 | std::tuple, 
std::vector> 13 | forward(torch::Tensor x, std::vector& keypoints) &&; 14 | 15 | std::tuple, std::vector> 16 | forward(const torch::Tensor& x, std::vector& keypoints) &; 17 | 18 | private: 19 | const int kernel_size_; 20 | const int n_pos_; 21 | const bool conv2D_; 22 | const bool mask_; 23 | torch::nn::Sequential offset_conv_{nullptr}; 24 | torch::nn::Conv2d sf_conv_{nullptr}; 25 | torch::nn::Conv2d convM_{nullptr}; 26 | torch::Tensor agg_weights_; 27 | 28 | // Helper functions for processing features 29 | torch::Tensor process_features(torch::Tensor features, int64_t num_keypoints) &&; 30 | torch::Tensor process_features(const torch::Tensor& features, int64_t num_keypoints) &; 31 | }; -------------------------------------------------------------------------------- /models/aliked-n16.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrNeRF/ALIKED_CPP/3daaff1ae13537c4fed0aba2f5776b5c4f418330/models/aliked-n16.pt -------------------------------------------------------------------------------- /models/aliked-n16rot.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrNeRF/ALIKED_CPP/3daaff1ae13537c4fed0aba2f5776b5c4f418330/models/aliked-n16rot.pt -------------------------------------------------------------------------------- /models/aliked-n32.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrNeRF/ALIKED_CPP/3daaff1ae13537c4fed0aba2f5776b5c4f418330/models/aliked-n32.pt -------------------------------------------------------------------------------- /models/aliked-t16.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrNeRF/ALIKED_CPP/3daaff1ae13537c4fed0aba2f5776b5c4f418330/models/aliked-t16.pt -------------------------------------------------------------------------------- /src/aliked.cpp: -------------------------------------------------------------------------------- 1 | #include "aliked.hpp" 2 | 3 | #include "dkd.hpp" 4 | #include "sddh.hpp" 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | namespace fs = std::filesystem; 11 | 12 | ALIKED::ALIKED(std::string_view model_name, 13 | std::string_view device, 14 | int top_k, 15 | float scores_th, 16 | int n_limit) 17 | : device_(torch::Device(std::string(device))), 18 | dim_(-1) { 19 | 20 | // Initialize DKD and descriptor head 21 | dkd_ = std::make_shared(2, top_k, scores_th, n_limit); 22 | const auto& config = ALIKED_CFGS.at(std::string(model_name)); 23 | desc_head_ = std::make_shared(config.dim, config.K, config.M); 24 | 25 | // Initialize layers 26 | init_layers(model_name); 27 | 28 | // Load weights first 29 | load_weights(model_name); 30 | 31 | // Move everything to the specified device 32 | this->to(device_); 33 | dkd_->to(device_); // Explicitly move DKD 34 | desc_head_->to(device_); // Explicitly move SDDH 35 | 36 | // Double check all submodules are on the correct device 37 | for (const auto& param : this->parameters()) 38 | { 39 | if (param.device() != device_) 40 | { 41 | param.to(device_); 42 | } 43 | } 44 | 45 | for (const auto& buffer : this->buffers()) 46 | { 47 | if (buffer.device() != device_) 48 | { 49 | buffer.to(device_); 50 | } 51 | } 52 | 53 | this->eval(); 54 | } 55 | 56 | std::tuple 57 | ALIKED::extract_dense_map(torch::Tensor image) && { 58 | // Create padder for input 59 | auto padder = InputPadder(image.size(2), image.size(3), 32); 60 | 
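// Pad H and W up to multiples of 32: the backbone pools by 2, 4 and 4 (pool2_ + two pool4_ stages), so the pooled maps stay integral and the upsampled features align with x1.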
image = std::move(padder).pad(std::move(image)); 61 | 62 | // Feature extraction with move semantics 63 | auto x1 = std::dynamic_pointer_cast(block1_)->forward(image); 64 | auto x2 = std::dynamic_pointer_cast(block2_)->forward(pool2_->forward(x1)); 65 | auto x3 = std::dynamic_pointer_cast(block3_)->forward(pool4_->forward(x2)); 66 | auto x4 = std::dynamic_pointer_cast(block4_)->forward(pool4_->forward(x3)); 67 | 68 | // Feature aggregation 69 | auto x1_processed = torch::selu(conv1_->forward(x1)); 70 | auto x2_processed = torch::selu(conv2_->forward(x2)); 71 | auto x3_processed = torch::selu(conv3_->forward(x3)); 72 | auto x4_processed = torch::selu(conv4_->forward(x4)); 73 | 74 | // Upsample with move semantics 75 | auto options = torch::nn::functional::InterpolateFuncOptions() 76 | .mode(torch::kBilinear) 77 | .align_corners(true); 78 | 79 | auto x2_up = torch::nn::functional::interpolate(x2_processed, 80 | options.size(std::vector{x1.size(2), x1.size(3)})); 81 | auto x3_up = torch::nn::functional::interpolate(x3_processed, 82 | options.size(std::vector{x1.size(2), x1.size(3)})); 83 | auto x4_up = torch::nn::functional::interpolate(x4_processed, 84 | options.size(std::vector{x1.size(2), x1.size(3)})); 85 | 86 | auto x1234 = torch::cat({std::move(x1_processed), 87 | std::move(x2_up), 88 | std::move(x3_up), 89 | std::move(x4_up)}, 90 | 1); 91 | 92 | // Generate score map and feature map 93 | auto score_map = torch::sigmoid(score_head_->forward(x1234.clone())); 94 | auto feature_map = torch::nn::functional::normalize(x1234, 95 | torch::nn::functional::NormalizeFuncOptions().p(2).dim(1)); 96 | 97 | // Unpad tensors with move semantics 98 | feature_map = std::move(padder).unpad(std::move(feature_map)); 99 | score_map = std::move(padder).unpad(std::move(score_map)); 100 | 101 | return std::make_tuple(std::move(feature_map), std::move(score_map)); 102 | } 103 | 104 | std::tuple 105 | ALIKED::extract_dense_map(const torch::Tensor& image) & { 106 | auto image_copy = image.clone(); 107 | return std::move(*this).extract_dense_map(std::move(image_copy)); 108 | } 109 | 110 | torch::Dict 111 | ALIKED::forward(torch::Tensor image) && { 112 | 113 | auto start_time = std::chrono::high_resolution_clock::now(); 114 | 115 | auto [feature_map, score_map] = std::move(*this).extract_dense_map(std::move(image)); 116 | auto [keypoints, kptscores, scoredispersitys] = std::move(*dkd_).forward(score_map); 117 | auto [descriptors, offsets] = std::move(*desc_head_).forward(feature_map, keypoints); 118 | 119 | auto end_time = std::chrono::high_resolution_clock::now(); 120 | auto duration = duration_cast(end_time - start_time).count() / 1000.0f; 121 | 122 | torch::Dict output; 123 | output.insert("keypoints", std::move(keypoints[0])); 124 | output.insert("descriptors", std::move(descriptors[0])); 125 | output.insert("scores", std::move(kptscores[0])); 126 | output.insert("score_dispersity", std::move(scoredispersitys[0])); 127 | output.insert("score_map", std::move(score_map)); 128 | output.insert("time", torch::tensor(duration)); 129 | 130 | return output; 131 | } 132 | 133 | torch::Dict 134 | ALIKED::forward(const torch::Tensor& image) & { 135 | auto image_copy = image.clone(); 136 | return std::move(*this).forward(std::move(image_copy)); 137 | } 138 | 139 | torch::Dict 140 | ALIKED::run(cv::Mat& img_rgb) { 141 | cv::Mat float_img; 142 | img_rgb.convertTo(float_img, CV_32F, 1.0 / 255.0); 143 | 144 | std::vector channels(3); 145 | cv::split(float_img, channels); 146 | 147 | auto options = torch::TensorOptions() 148 | 
.dtype(torch::kFloat32) 149 | .device(device_); 150 | 151 | std::vector tensor_channels; 152 | tensor_channels.reserve(3); 153 | 154 | for (const auto& channel : channels) 155 | { 156 | auto host_tensor = torch::from_blob( 157 | channel.data, 158 | {channel.rows, channel.cols}, 159 | torch::TensorOptions().dtype(torch::kFloat32)); 160 | tensor_channels.push_back(std::move(host_tensor).to(device_)); 161 | } 162 | 163 | auto img_tensor = torch::stack(std::move(tensor_channels), 0) 164 | .unsqueeze(0) 165 | .to(device_); 166 | 167 | // Forward pass with move semantics 168 | auto pred = std::move(*this).forward(std::move(img_tensor)); 169 | 170 | // Convert keypoints from normalized coordinates to image coordinates 171 | auto kpts = pred.at("keypoints"); 172 | const auto h = static_cast(float_img.rows); 173 | const auto w = static_cast(float_img.cols); 174 | const auto wh = torch::tensor({w - 1.0f, h - 1.0f}, kpts.options()); 175 | kpts = wh * (kpts + 1) / 2; 176 | 177 | pred.insert("keypoints", std::move(kpts)); 178 | return pred; 179 | } 180 | 181 | void ALIKED::init_layers(std::string_view model_name) { 182 | const auto& config = ALIKED_CFGS.at(std::string(model_name)); 183 | dim_ = config.dim; 184 | 185 | // Basic layers 186 | pool2_ = register_module("pool2", 187 | torch::nn::AvgPool2d(torch::nn::AvgPool2dOptions(2).stride(2))); 188 | pool4_ = register_module("pool4", 189 | torch::nn::AvgPool2d(torch::nn::AvgPool2dOptions(4).stride(4))); 190 | 191 | // Blocks with move semantics 192 | block1_ = register_module( 193 | "block1", 194 | std::make_shared(3, config.c1, "conv", false)); 195 | 196 | auto downsample2 = torch::nn::Conv2d( 197 | torch::nn::Conv2dOptions(config.c1, config.c2, 1)); 198 | block2_ = register_module( 199 | "block2", 200 | std::make_shared(config.c1, config.c2, 1, downsample2, "conv")); 201 | 202 | auto downsample3 = torch::nn::Conv2d( 203 | torch::nn::Conv2dOptions(config.c2, config.c3, 1)); 204 | block3_ = register_module( 205 | "block3", 206 | std::make_shared(config.c2, config.c3, 1, downsample3, "dcn")); 207 | 208 | auto downsample4 = torch::nn::Conv2d( 209 | torch::nn::Conv2dOptions(config.c3, config.c4, 1)); 210 | block4_ = register_module( 211 | "block4", 212 | std::make_shared(config.c3, config.c4, 1, downsample4, "dcn")); 213 | 214 | // Convolution layers 215 | const int out_channels = dim_ / 4; 216 | conv1_ = register_module("conv1", 217 | torch::nn::Conv2d(torch::nn::Conv2dOptions(config.c1, out_channels, 1).stride(1).bias(false))); 218 | conv2_ = register_module("conv2", 219 | torch::nn::Conv2d(torch::nn::Conv2dOptions(config.c2, out_channels, 1).stride(1).bias(false))); 220 | conv3_ = register_module("conv3", 221 | torch::nn::Conv2d(torch::nn::Conv2dOptions(config.c3, out_channels, 1).stride(1).bias(false))); 222 | conv4_ = register_module("conv4", 223 | torch::nn::Conv2d(torch::nn::Conv2dOptions(config.c4, out_channels, 1).stride(1).bias(false))); 224 | 225 | // Score head 226 | torch::nn::Sequential score_head; 227 | score_head->push_back(torch::nn::Conv2d( 228 | torch::nn::Conv2dOptions(dim_, 8, 1).stride(1).bias(false))); 229 | score_head->push_back(torch::nn::SELU()); 230 | score_head->push_back(torch::nn::Conv2d( 231 | torch::nn::Conv2dOptions(8, 4, 3).padding(1).stride(1).bias(false))); 232 | score_head->push_back(torch::nn::SELU()); 233 | score_head->push_back(torch::nn::Conv2d( 234 | torch::nn::Conv2dOptions(4, 4, 3).padding(1).stride(1).bias(false))); 235 | score_head->push_back(torch::nn::SELU()); 236 | score_head->push_back(torch::nn::Conv2d( 237 
| torch::nn::Conv2dOptions(4, 1, 3).padding(1).stride(1).bias(false))); 238 | 239 | score_head_ = register_module("score_head", score_head); 240 | register_module("desc_head", desc_head_); 241 | register_module("dkd", dkd_); 242 | } 243 | 244 | void ALIKED::load_weights(std::string_view model_name) { 245 | std::vector search_paths = { 246 | std::filesystem::path(ALIKED_MODELS_DIR) / (std::string(model_name) + ".pt"), 247 | std::filesystem::current_path() / "models" / (std::string(model_name) + ".pt"), 248 | std::filesystem::current_path() / (std::string(model_name) + ".pt") 249 | }; 250 | 251 | std::filesystem::path model_path; 252 | bool found = false; 253 | 254 | for (const auto& path : search_paths) { 255 | if (std::filesystem::exists(path)) { 256 | model_path = path; 257 | found = true; 258 | break; 259 | } 260 | } 261 | 262 | if (!found) { 263 | std::string error_msg = "Cannot find pretrained model. Searched in:\n"; 264 | for (const auto& path : search_paths) { 265 | error_msg += " " + path.string() + "\n"; 266 | } 267 | error_msg += "Please place the model file in one of these locations."; 268 | throw std::runtime_error(error_msg); 269 | } 270 | 271 | std::cout << "Loading model from: " << model_path << std::endl; 272 | load_parameters(model_path.string()); 273 | } 274 | 275 | void ALIKED::load_parameters(std::string_view pt_pth) { 276 | auto f = get_the_bytes(pt_pth); 277 | auto weights = torch::pickle_load(f).toGenericDict(); 278 | 279 | // Use unordered_maps for O(1) lookup 280 | std::unordered_map param_map; 281 | std::unordered_map buffer_map; 282 | 283 | auto model_params = named_parameters(); 284 | auto model_buffers = named_buffers(); 285 | // Pre-allocate with expected size 286 | param_map.reserve(model_params.size()); 287 | buffer_map.reserve(model_buffers.size()); 288 | 289 | // Collect parameter names 290 | for (const auto& p : model_params) 291 | { 292 | param_map.emplace(p.key(), p.value()); 293 | } 294 | 295 | // Collect buffer names 296 | for (const auto& b : model_buffers) 297 | { 298 | buffer_map.emplace(b.key(), b.value()); 299 | } 300 | 301 | // Update parameters and buffers 302 | torch::NoGradGuard no_grad; 303 | 304 | for (const auto& w : weights) 305 | { 306 | const auto name = w.key().toStringRef(); 307 | const auto& param = w.value().toTensor(); 308 | 309 | // Try parameters first 310 | if (auto it = param_map.find(name); it != param_map.end()) 311 | { 312 | if (it->second.sizes() == param.sizes()) 313 | { 314 | it->second.copy_(param); 315 | } else 316 | { 317 | throw std::runtime_error( 318 | "Shape mismatch for parameter: " + name + 319 | " Expected: " + std::to_string(it->second.numel()) + 320 | " Got: " + std::to_string(param.numel())); 321 | } 322 | continue; 323 | } 324 | 325 | // Then try buffers 326 | if (auto it = buffer_map.find(name); it != buffer_map.end()) 327 | { 328 | if (it->second.sizes() == param.sizes()) 329 | { 330 | it->second.copy_(param); 331 | } else 332 | { 333 | throw std::runtime_error( 334 | "Shape mismatch for buffer: " + name + 335 | " Expected: " + std::to_string(it->second.numel()) + 336 | " Got: " + std::to_string(param.numel())); 337 | } 338 | continue; 339 | } 340 | 341 | // Parameter not found in model 342 | std::cerr << "Warning: " << name 343 | << " not found in model parameters or buffers\n"; 344 | } 345 | } 346 | 347 | std::vector ALIKED::get_the_bytes(std::string_view filename) { 348 | // Use RAII file handling 349 | std::ifstream file(std::string(filename), std::ios::binary); 350 | if (!file) 351 | { 352 | throw 
std::runtime_error( 353 | "Failed to open file: " + std::string(filename)); 354 | } 355 | 356 | // Get file size 357 | file.seekg(0, std::ios::end); 358 | const auto size = file.tellg(); 359 | file.seekg(0, std::ios::beg); 360 | 361 | // Pre-allocate vector 362 | std::vector buffer; 363 | buffer.reserve(size); 364 | 365 | // Read file in chunks for better performance 366 | constexpr size_t CHUNK_SIZE = 8192; 367 | char chunk[CHUNK_SIZE]; 368 | 369 | while (file.read(chunk, CHUNK_SIZE)) 370 | { 371 | buffer.insert(buffer.end(), chunk, chunk + file.gcount()); 372 | } 373 | if (file.gcount() > 0) 374 | { 375 | buffer.insert(buffer.end(), chunk, chunk + file.gcount()); 376 | } 377 | 378 | return buffer; 379 | } -------------------------------------------------------------------------------- /src/blocks.cpp: -------------------------------------------------------------------------------- 1 | #include "blocks.hpp" 2 | 3 | #include "deform_conv2d.h" 4 | 5 | DeformableConv2d::DeformableConv2d(int in_channels, int out_channels, 6 | int kernel_size, int stride, int padding, 7 | bool bias) { 8 | padding_ = padding; 9 | const int channel_num = 2 * kernel_size * kernel_size; 10 | 11 | // Register offset conv 12 | offset_conv_ = register_module("offset_conv", 13 | torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, channel_num, kernel_size) 14 | .stride(stride) 15 | .padding(padding) 16 | .bias(true))); 17 | 18 | // Register regular conv 19 | regular_conv_ = register_module("regular_conv", 20 | torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, out_channels, kernel_size) 21 | .stride(stride) 22 | .padding(padding) 23 | .bias(bias))); 24 | } 25 | 26 | torch::Tensor DeformableConv2d::forward(const torch::Tensor& x) & { 27 | auto h = x.size(2); 28 | auto w = x.size(3); 29 | float max_offset = std::max(h, w) / 4.0f; 30 | 31 | // Offset and mask 32 | auto offset = offset_conv_->forward(x); 33 | auto mask = torch::zeros( 34 | {offset.size(0), 1}, 35 | torch::TensorOptions().device(offset.device()).dtype(offset.dtype())); 36 | 37 | offset = offset.clamp(-max_offset, max_offset); 38 | 39 | if (!regular_conv_->bias.defined()) 40 | { 41 | regular_conv_->bias = torch::zeros( 42 | {regular_conv_->weight.size(0)}, 43 | torch::TensorOptions().device(x.device()).dtype(x.dtype())); 44 | } 45 | 46 | return vision::ops::deform_conv2d( 47 | x, 48 | regular_conv_->weight, 49 | offset, 50 | mask, 51 | regular_conv_->bias, 52 | 1, 1, 53 | padding_, padding_, 54 | 1, 1, 55 | groups_, 56 | mask_offset_, 57 | false); 58 | } 59 | 60 | torch::Tensor DeformableConv2d::forward(torch::Tensor x) && { 61 | auto h = x.size(2); 62 | auto w = x.size(3); 63 | float max_offset = std::max(h, w) / 4.0f; 64 | 65 | // Offset and mask 66 | auto offset = offset_conv_->forward(std::move(x)); 67 | auto mask = torch::zeros( 68 | {offset.size(0), 1}, 69 | torch::TensorOptions().device(offset.device()).dtype(offset.dtype())); 70 | 71 | offset = std::move(offset).clamp(-max_offset, max_offset); 72 | 73 | if (!regular_conv_->bias.defined()) 74 | { 75 | regular_conv_->bias = torch::zeros( 76 | {regular_conv_->weight.size(0)}, 77 | torch::TensorOptions().device(x.device()).dtype(x.dtype())); 78 | } 79 | 80 | return vision::ops::deform_conv2d( 81 | std::move(x), 82 | regular_conv_->weight, 83 | std::move(offset), 84 | std::move(mask), 85 | regular_conv_->bias, 86 | 1, 1, 87 | padding_, padding_, 88 | 1, 1, 89 | groups_, 90 | mask_offset_, 91 | false); 92 | } 93 | 94 | ConvBlock::ConvBlock(int in_channels, int out_channels, 95 | std::string_view 
conv_type, bool mask) { 96 | 97 | if (conv_type == "conv") 98 | { 99 | auto conv1 = torch::nn::Conv2d((torch::nn::Conv2dOptions(in_channels, out_channels, 3) 100 | .stride(1) 101 | .padding(1) 102 | .bias(false))); 103 | conv1_ = register_module("conv1", conv1); 104 | 105 | auto conv2 = torch::nn::Conv2d((torch::nn::Conv2dOptions(out_channels, out_channels, 3) 106 | .stride(1) 107 | .padding(1) 108 | .bias(false))); 109 | conv2_ = register_module("conv2", conv2); 110 | 111 | } else 112 | { 113 | auto conv1 = std::make_shared( 114 | in_channels, 115 | out_channels, 116 | 3, 117 | 1, 118 | 1, 119 | false); 120 | deform1_ = register_module("conv1", conv1); 121 | 122 | auto conv2 = std::make_shared( 123 | out_channels, 124 | out_channels, 125 | 3, 126 | 1, 127 | 1, 128 | false); 129 | deform2_ = register_module("conv2", conv2); 130 | } 131 | 132 | bn1_ = register_module("bn1", torch::nn::BatchNorm2d(out_channels)); 133 | bn2_ = register_module("bn2", torch::nn::BatchNorm2d(out_channels)); 134 | } 135 | 136 | ResBlock::ResBlock(int inplanes, int planes, int stride, 137 | const torch::nn::Conv2d& downsample, 138 | std::string_view conv_type) 139 | : downsample_(downsample) { 140 | 141 | if (conv_type == "conv") 142 | { 143 | auto conv1 = torch::nn::Conv2d((torch::nn::Conv2dOptions(inplanes, planes, 3) 144 | .stride(stride) 145 | .padding(1) 146 | .bias(false))); 147 | conv1_ = register_module("conv1", conv1); 148 | 149 | auto conv2 = torch::nn::Conv2d((torch::nn::Conv2dOptions(planes, planes, 3) 150 | .stride(stride) 151 | .padding(1) 152 | .bias(false))); 153 | conv2_ = register_module("conv2", conv2); 154 | 155 | } else 156 | { 157 | auto conv1 = std::make_shared( 158 | inplanes, 159 | planes, 160 | 3, 161 | 1, 162 | 1, 163 | false); 164 | deform1_ = register_module("conv1", conv1); 165 | 166 | auto conv2 = std::make_shared( 167 | planes, 168 | planes, 169 | 3, 170 | 1, 171 | 1, 172 | false); 173 | deform2_ = register_module("conv2", conv2); 174 | } 175 | 176 | bn1_ = register_module("bn1", 177 | torch::nn::BatchNorm2d(planes)); 178 | bn2_ = register_module("bn2", 179 | torch::nn::BatchNorm2d(planes)); 180 | 181 | if (downsample) 182 | { 183 | register_module("downsample", downsample); 184 | } 185 | } 186 | 187 | torch::Tensor ConvBlock::forward(torch::Tensor x) && { 188 | return std::move(*this).forward(std::move(x)); 189 | } 190 | 191 | torch::Tensor ConvBlock::forward(const torch::Tensor& x) & { 192 | if (conv1_ && conv2_) 193 | { 194 | auto tmp = torch::selu(bn1_->forward(conv1_->forward(x))); 195 | return torch::selu(bn2_->forward(conv2_->forward(std::move(tmp)))); 196 | } else 197 | { 198 | auto tmp = torch::selu(bn1_->forward(deform1_->forward(x))); 199 | return torch::selu(bn2_->forward(deform2_->forward(std::move(tmp)))); 200 | } 201 | } 202 | 203 | torch::Tensor ResBlock::forward(torch::Tensor x) && { 204 | return std::move(*this).forward(std::move(x)); 205 | } 206 | 207 | torch::Tensor ResBlock::forward(const torch::Tensor& x) & { 208 | auto identity = x; 209 | 210 | torch::Tensor processed; 211 | if (conv1_ && conv2_) 212 | { 213 | auto tmp = conv1_->forward(x); 214 | tmp = bn1_->forward(std::move(tmp)); 215 | tmp = torch::selu(std::move(tmp)); 216 | 217 | processed = conv2_->forward(std::move(tmp)); 218 | processed = bn2_->forward(std::move(processed)); 219 | } else 220 | { 221 | auto tmp = deform1_->forward(x); 222 | tmp = bn1_->forward(std::move(tmp)); 223 | tmp = torch::selu(std::move(tmp)); 224 | 225 | processed = deform2_->forward(std::move(tmp)); 226 | processed = 
bn2_->forward(std::move(processed)); 227 | } 228 | 229 | if (downsample_) 230 | { 231 | identity = downsample_->as()->forward(std::move(identity)); 232 | } 233 | 234 | processed += identity; 235 | return torch::selu(std::move(processed)); 236 | } -------------------------------------------------------------------------------- /src/deform_conv2d.cpp: -------------------------------------------------------------------------------- 1 | #include "deform_conv2d.h" 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace vision { 8 | namespace ops { 9 | 10 | at::Tensor deform_conv2d( 11 | const at::Tensor& input, 12 | const at::Tensor& weight, 13 | const at::Tensor& offset, 14 | const at::Tensor& mask, 15 | const at::Tensor& bias, 16 | int64_t stride_h, 17 | int64_t stride_w, 18 | int64_t pad_h, 19 | int64_t pad_w, 20 | int64_t dilation_h, 21 | int64_t dilation_w, 22 | int64_t groups, 23 | int64_t offset_groups, 24 | bool use_mask) { 25 | C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.deform_conv2d.deform_conv2d"); 26 | static auto op = c10::Dispatcher::singleton() 27 | .findSchemaOrThrow("torchvision::deform_conv2d", "") 28 | .typed(); 29 | return op.call( 30 | input, 31 | weight, 32 | offset, 33 | mask, 34 | bias, 35 | stride_h, 36 | stride_w, 37 | pad_h, 38 | pad_w, 39 | dilation_h, 40 | dilation_w, 41 | groups, 42 | offset_groups, 43 | use_mask); 44 | } 45 | 46 | at::Tensor deform_conv2d_symint( 47 | const at::Tensor& input, 48 | const at::Tensor& weight, 49 | const at::Tensor& offset, 50 | const at::Tensor& mask, 51 | const at::Tensor& bias, 52 | c10::SymInt stride_h, 53 | c10::SymInt stride_w, 54 | c10::SymInt pad_h, 55 | c10::SymInt pad_w, 56 | c10::SymInt dilation_h, 57 | c10::SymInt dilation_w, 58 | c10::SymInt groups, 59 | c10::SymInt offset_groups, 60 | bool use_mask) { 61 | C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.deform_conv2d.deform_conv2d"); 62 | static auto op = c10::Dispatcher::singleton() 63 | .findSchemaOrThrow("torchvision::deform_conv2d", "") 64 | .typed(); 65 | return op.call( 66 | input, 67 | weight, 68 | offset, 69 | mask, 70 | bias, 71 | stride_h, 72 | stride_w, 73 | pad_h, 74 | pad_w, 75 | dilation_h, 76 | dilation_w, 77 | groups, 78 | offset_groups, 79 | use_mask); 80 | } 81 | 82 | namespace detail { 83 | 84 | std::tuple 85 | _deform_conv2d_backward( 86 | const at::Tensor& grad, 87 | const at::Tensor& input, 88 | const at::Tensor& weight, 89 | const at::Tensor& offset, 90 | const at::Tensor& mask, 91 | const at::Tensor& bias, 92 | int64_t stride_h, 93 | int64_t stride_w, 94 | int64_t pad_h, 95 | int64_t pad_w, 96 | int64_t dilation_h, 97 | int64_t dilation_w, 98 | int64_t groups, 99 | int64_t offset_groups, 100 | bool use_mask) { 101 | static auto op = 102 | c10::Dispatcher::singleton() 103 | .findSchemaOrThrow("torchvision::_deform_conv2d_backward", "") 104 | .typed(); 105 | return op.call( 106 | grad, 107 | input, 108 | weight, 109 | offset, 110 | mask, 111 | bias, 112 | stride_h, 113 | stride_w, 114 | pad_h, 115 | pad_w, 116 | dilation_h, 117 | dilation_w, 118 | groups, 119 | offset_groups, 120 | use_mask); 121 | } 122 | 123 | std::tuple 124 | _deform_conv2d_backward_symint( 125 | const at::Tensor& grad, 126 | const at::Tensor& input, 127 | const at::Tensor& weight, 128 | const at::Tensor& offset, 129 | const at::Tensor& mask, 130 | const at::Tensor& bias, 131 | c10::SymInt stride_h, 132 | c10::SymInt stride_w, 133 | c10::SymInt pad_h, 134 | c10::SymInt pad_w, 135 | c10::SymInt dilation_h, 136 | c10::SymInt dilation_w, 137 | c10::SymInt groups, 138 | 
c10::SymInt offset_groups, 139 | bool use_mask) { 140 | static auto op = 141 | c10::Dispatcher::singleton() 142 | .findSchemaOrThrow("torchvision::_deform_conv2d_backward", "") 143 | .typed(); 144 | return op.call( 145 | grad, 146 | input, 147 | weight, 148 | offset, 149 | mask, 150 | bias, 151 | stride_h, 152 | stride_w, 153 | pad_h, 154 | pad_w, 155 | dilation_h, 156 | dilation_w, 157 | groups, 158 | offset_groups, 159 | use_mask); 160 | } 161 | 162 | } // namespace detail 163 | 164 | TORCH_LIBRARY_FRAGMENT(torchvision, m) { 165 | m.def(TORCH_SELECTIVE_SCHEMA( 166 | "torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> Tensor")); 167 | m.def(TORCH_SELECTIVE_SCHEMA( 168 | "torchvision::_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)")); 169 | } 170 | 171 | } // namespace ops 172 | } // namespace vision 173 | -------------------------------------------------------------------------------- /src/deform_conv2d_kernel.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | ******************* BEGIN Caffe Copyright Notice and Disclaimer 3 | ***************** 4 | * 5 | * COPYRIGHT 6 | * 7 | * All contributions by the University of California: 8 | * Copyright (c) 2014-2017 The Regents of the University of California (Regents) 9 | * All rights reserved. 10 | * 11 | * All other contributions: 12 | * Copyright (c) 2014-2017, the respective contributors 13 | * All rights reserved. 14 | * 15 | * Caffe uses a shared copyright model: each contributor holds copyright over 16 | * their contributions to Caffe. The project versioning records all such 17 | * contribution and copyright details. If a contributor wants to further mark 18 | * their specific copyright on a particular contribution, they should indicate 19 | * their copyright solely in the commit message of the change when it is 20 | * committed. 21 | * 22 | * LICENSE 23 | * 24 | * Redistribution and use in source and binary forms, with or without 25 | * modification, are permitted provided that the following conditions are met: 26 | * 27 | * 1. Redistributions of source code must retain the above copyright notice, 28 | *this list of conditions and the following disclaimer. 29 | * 2. Redistributions in binary form must reproduce the above copyright notice, 30 | * this list of conditions and the following disclaimer in the documentation 31 | * and/or other materials provided with the distribution. 32 | * 33 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 34 | *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 35 | *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 36 | * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE 37 | *FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 38 | *DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 39 | *SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 40 | *CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 41 | *OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 42 | *OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 43 | * 44 | * CONTRIBUTION AGREEMENT 45 | * 46 | * By contributing to the BVLC/caffe repository through pull-request, comment, 47 | * or otherwise, the contributor releases their content to the 48 | * license and copyright terms herein. 49 | * 50 | ***************** END Caffe Copyright Notice and Disclaimer 51 | ********************* 52 | * 53 | * Copyright (c) 2018 Microsoft 54 | * Licensed under The MIT License [see LICENSE for details] 55 | * \file modulated_deformable_im2col.cuh 56 | * \brief Function definitions of converting an image to 57 | * column matrix based on kernel, padding, dilation, and offset. 58 | * These functions are mainly used in deformable convolution operators. 59 | * \ref: https://arxiv.org/abs/1703.06211 60 | * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng 61 | */ 62 | 63 | // modified from 64 | // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu 65 | 66 | // modified from 67 | // https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp 68 | 69 | #include "cuda_helpers.h" 70 | #include 71 | #include 72 | #include 73 | #include 74 | #include 75 | 76 | namespace vision { 77 | namespace ops { 78 | 79 | namespace { 80 | 81 | const int kMaxParallelImgs = 32; 82 | 83 | inline unsigned int GET_THREADS() { 84 | #ifdef WITH_HIP 85 | return 256; 86 | #endif 87 | return 512; 88 | } 89 | 90 | inline unsigned int GET_BLOCKS(const unsigned int THREADS, const int64_t N) { 91 | int64_t kMaxGridNum = at::cuda::getCurrentDeviceProperties()->maxGridSize[0]; 92 | return (unsigned int)std::min(kMaxGridNum, (N + THREADS - 1) / THREADS); 93 | } 94 | 95 | template 96 | __device__ scalar_t bilinear_interpolate( 97 | const scalar_t* in, 98 | index_t height, 99 | index_t width, 100 | scalar_t h, 101 | scalar_t w) { 102 | if (h <= -1 || height <= h || w <= -1 || width <= w) 103 | { 104 | return 0; 105 | } 106 | 107 | index_t h_low = floor(h); 108 | index_t w_low = floor(w); 109 | index_t h_high = h_low + 1; 110 | index_t w_high = w_low + 1; 111 | 112 | scalar_t lh = h - h_low; 113 | scalar_t lw = w - w_low; 114 | scalar_t hh = 1 - lh, hw = 1 - lw; 115 | 116 | scalar_t v1 = 0; 117 | if (h_low >= 0 && w_low >= 0) 118 | v1 = in[h_low * width + w_low]; 119 | scalar_t v2 = 0; 120 | if (h_low >= 0 && w_high <= width - 1) 121 | v2 = in[h_low * width + w_high]; 122 | scalar_t v3 = 0; 123 | if (h_high <= height - 1 && w_low >= 0) 124 | v3 = in[h_high * width + w_low]; 125 | scalar_t v4 = 0; 126 | if (h_high <= height - 1 && w_high <= width - 1) 127 | v4 = in[h_high * width + w_high]; 128 | 129 | scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; 130 | 131 | scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); 132 | return val; 133 | } 134 | 135 | template 136 | __global__ void deformable_im2col_kernel( 137 | index_t n, 138 | const scalar_t* input_ptr, 139 | const scalar_t* offset_ptr, 140 | const 
scalar_t* mask_ptr, 141 | index_t height, 142 | index_t width, 143 | index_t weight_h, 144 | index_t weight_w, 145 | index_t pad_h, 146 | index_t pad_w, 147 | index_t stride_h, 148 | index_t stride_w, 149 | index_t dilation_h, 150 | index_t dilation_w, 151 | index_t batch_sz, 152 | index_t n_in_channels, 153 | index_t n_offset_grps, 154 | index_t out_h, 155 | index_t out_w, 156 | bool use_mask, 157 | scalar_t* columns_ptr) { 158 | CUDA_1D_KERNEL_LOOP_T(index, n, index_t) { 159 | const index_t out_x = index % out_w; 160 | const index_t out_y = (index / out_w) % out_h; 161 | const index_t out_b = (index / (out_w * out_h)) % batch_sz; 162 | const index_t in_c = index / (out_w * out_h * batch_sz); 163 | const index_t out_c = in_c * weight_h * weight_w; 164 | 165 | index_t c_per_offset_grp = n_in_channels / n_offset_grps; 166 | const index_t grp_idx = in_c / c_per_offset_grp; 167 | 168 | columns_ptr += 169 | (out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) + 170 | out_y * out_w + out_x); 171 | 172 | input_ptr += 173 | (out_b * (n_in_channels * height * width) + in_c * (height * width)); 174 | 175 | offset_ptr += (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * 176 | out_h * out_w; 177 | 178 | if (use_mask) 179 | { 180 | mask_ptr += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * 181 | out_h * out_w; 182 | } 183 | 184 | for (int i = 0; i < weight_h; ++i) 185 | { 186 | for (int j = 0; j < weight_w; ++j) 187 | { 188 | const index_t mask_idx = i * weight_w + j; 189 | const index_t offset_idx = 2 * mask_idx; 190 | 191 | scalar_t mask_value = 1; 192 | if (use_mask) 193 | { 194 | mask_value = 195 | mask_ptr[mask_idx * (out_h * out_w) + out_y * out_w + out_x]; 196 | } 197 | 198 | const scalar_t offset_h = 199 | offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x]; 200 | const scalar_t offset_w = offset_ptr 201 | [(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x]; 202 | const scalar_t y = 203 | (out_y * stride_h - pad_h) + i * dilation_h + offset_h; 204 | const scalar_t x = 205 | (out_x * stride_w - pad_w) + j * dilation_w + offset_w; 206 | *columns_ptr = 207 | mask_value * bilinear_interpolate(input_ptr, height, width, y, x); 208 | columns_ptr += batch_sz * out_h * out_w; 209 | } 210 | } 211 | } 212 | } 213 | 214 | void deformable_im2col( 215 | const at::Tensor& input, 216 | const at::Tensor& data_offset, 217 | const at::Tensor& data_mask, 218 | int n_in_channels, 219 | int height, 220 | int width, 221 | int weight_h, 222 | int weight_w, 223 | int pad_h, 224 | int pad_w, 225 | int stride_h, 226 | int stride_w, 227 | int dilation_h, 228 | int dilation_w, 229 | int out_h, 230 | int out_w, 231 | int parallel_imgs, 232 | int deformable_group, 233 | bool use_mask, 234 | at::Tensor data_col) { 235 | at::cuda::CUDAGuard device_guard(input.get_device()); 236 | 237 | const int64_t num_kernels = 238 | (int64_t)n_in_channels * out_h * out_w * parallel_imgs; 239 | 240 | const unsigned int threads = GET_THREADS(); 241 | const unsigned int blocks = GET_BLOCKS(threads, num_kernels); 242 | 243 | // Checks if we should use 64bits indexing 244 | // https://github.com/pytorch/vision/issues/4269 245 | bool use_64bits_indexing = false; 246 | // Checks if num_kernels or columns numel larger than 2 ** 31 247 | use_64bits_indexing |= num_kernels > (1 << 31); 248 | use_64bits_indexing |= 249 | ((int64_t)n_in_channels * weight_h * weight_w * parallel_imgs * out_h * 250 | out_w > 251 | (1 << 31)); 252 | 253 | if (use_64bits_indexing) 254 | { 255 | 
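// Illustrative note (not part of the upstream source): this branch and the
// else-branch below launch the same deformable_im2col_kernel; the only
// difference is the index type handed to the template. int32 indexing is
// cheaper on the GPU, but the flattened column buffer can exceed 2^31
// elements. For example, with n_in_channels = 512, a 3x3 kernel,
// parallel_imgs = 32 and a 128x128 output,
//   512 * 9 * 32 * 128 * 128 = 2,415,919,104 > 2^31 - 1,
// so the 64-bit path is required; smaller problems take the faster
// 32-bit path below.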
AT_DISPATCH_FLOATING_TYPES_AND_HALF( 256 | input.scalar_type(), "deformable_im2col", ([&] { 257 | deformable_im2col_kernel<<>>( 258 | num_kernels, 259 | input.data_ptr(), 260 | data_offset.data_ptr(), 261 | data_mask.data_ptr(), 262 | height, 263 | width, 264 | weight_h, 265 | weight_w, 266 | pad_h, 267 | pad_w, 268 | stride_h, 269 | stride_w, 270 | dilation_h, 271 | dilation_w, 272 | parallel_imgs, 273 | n_in_channels, 274 | deformable_group, 275 | out_h, 276 | out_w, 277 | use_mask, 278 | data_col.data_ptr()); 279 | })); 280 | 281 | } else 282 | { 283 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 284 | input.scalar_type(), "deformable_im2col", ([&] { 285 | deformable_im2col_kernel<<>>( 286 | num_kernels, 287 | input.data_ptr(), 288 | data_offset.data_ptr(), 289 | data_mask.data_ptr(), 290 | height, 291 | width, 292 | weight_h, 293 | weight_w, 294 | pad_h, 295 | pad_w, 296 | stride_h, 297 | stride_w, 298 | dilation_h, 299 | dilation_w, 300 | parallel_imgs, 301 | n_in_channels, 302 | deformable_group, 303 | out_h, 304 | out_w, 305 | use_mask, 306 | data_col.data_ptr()); 307 | })); 308 | } 309 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 310 | } 311 | 312 | int get_greatest_divisor_below_bound(int n, int bound) { 313 | for (int k = bound; k > 1; --k) 314 | { 315 | if (n % k == 0) 316 | { 317 | return k; 318 | } 319 | } 320 | return 1; 321 | } 322 | 323 | template 324 | __global__ void deformable_col2im_kernel( 325 | index_t n, 326 | const scalar_t* col, 327 | const scalar_t* offset_ptr, 328 | const scalar_t* mask_ptr, 329 | index_t channels, 330 | index_t height, 331 | index_t width, 332 | index_t kernel_h, 333 | index_t kernel_w, 334 | index_t pad_h, 335 | index_t pad_w, 336 | index_t stride_h, 337 | index_t stride_w, 338 | index_t dilation_h, 339 | index_t dilation_w, 340 | index_t batch_sz, 341 | index_t n_offset_grps, 342 | index_t out_h, 343 | index_t out_w, 344 | bool use_mask, 345 | scalar_t* grad_im) { 346 | const index_t grad_im_numel = width * height * channels * batch_sz; 347 | 348 | CUDA_1D_KERNEL_LOOP_T(index, n, int64_t) { 349 | const index_t out_x = index % out_w; 350 | const index_t out_y = (index / out_w) % out_h; 351 | const index_t b = (index / (out_w * out_h)) % batch_sz; 352 | const index_t j = (index / (out_w * out_h * batch_sz)) % kernel_w; 353 | const index_t i = 354 | (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h; 355 | const index_t c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h); 356 | 357 | index_t c_per_offset_grp = channels / n_offset_grps; 358 | const index_t offset_grp = c / c_per_offset_grp; 359 | 360 | offset_ptr += (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * 361 | out_h * out_w; 362 | 363 | if (use_mask) 364 | { 365 | mask_ptr += (b * n_offset_grps + offset_grp) * kernel_h * kernel_w * 366 | out_h * out_w; 367 | } 368 | 369 | const index_t mask_idx = i * kernel_w + j; 370 | const index_t offset_idx = 2 * mask_idx; 371 | 372 | const index_t offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x; 373 | const index_t offset_w_ptr = 374 | ((offset_idx + 1) * out_h + out_y) * out_w + out_x; 375 | 376 | const scalar_t offset_h = offset_ptr[offset_h_ptr]; 377 | const scalar_t offset_w = offset_ptr[offset_w_ptr]; 378 | 379 | scalar_t mask_value = 1; 380 | if (use_mask) 381 | { 382 | mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; 383 | } 384 | 385 | const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; 386 | const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; 387 | 388 | 
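// Illustrative note (not part of the upstream source): the loop below
// scatters the column gradient col[index] back to the (at most) four integer
// pixels surrounding the fractional sample location (y, x). Each valid
// neighbour (yp, xp) receives the bilinear weight
//   w = (1 - |y - yp|) * (1 - |x - xp|),
// which matches the interpolation weights used by bilinear_interpolate() in
// the forward pass. fastAtomicAdd is required because several output
// positions (and kernel taps) can sample, and therefore accumulate gradients
// into, the same input pixel.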
for (index_t dy = -1; dy <= 1; dy++) 389 | { 390 | for (index_t dx = -1; dx <= 1; dx++) 391 | { 392 | index_t yp = (index_t)y + dy; 393 | index_t xp = (index_t)x + dx; 394 | if (0 <= yp && yp < height && 0 <= xp && xp < width && 395 | std::abs(y - yp) < 1 && std::abs(x - xp) < 1) 396 | { 397 | index_t grad_pos = ((b * channels + c) * height + yp) * width + xp; 398 | scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp)); 399 | at::native::fastAtomicAdd( 400 | grad_im, 401 | grad_pos, 402 | grad_im_numel, 403 | mask_value * weight * col[index], 404 | true); 405 | } 406 | } 407 | } 408 | } 409 | } 410 | 411 | void compute_grad_input( 412 | const at::Tensor& columns, 413 | const at::Tensor& offset, 414 | const at::Tensor& mask, 415 | int channels, 416 | int height, 417 | int width, 418 | int weight_h, 419 | int weight_w, 420 | int pad_h, 421 | int pad_w, 422 | int stride_h, 423 | int stride_w, 424 | int dilation_h, 425 | int dilation_w, 426 | int parallel_imgs, 427 | int n_offset_grps, 428 | bool use_mask, 429 | at::Tensor grad_im) { 430 | at::cuda::CUDAGuard device_guard(columns.get_device()); 431 | 432 | const int out_h = 433 | (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; 434 | const int out_w = 435 | (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; 436 | 437 | const int64_t num_kernels = 438 | (int64_t)channels * weight_h * weight_w * out_h * out_w * parallel_imgs; 439 | 440 | const unsigned int threads = GET_THREADS(); 441 | const unsigned int blocks = GET_BLOCKS(threads, num_kernels); 442 | 443 | // Checks if we should use 64bits indexing 444 | // https://github.com/pytorch/vision/issues/4269 445 | bool use_64bits_indexing = false; 446 | // Checks if num_kernels or columns numel larger than 2 ** 31 447 | use_64bits_indexing |= num_kernels > (1 << 31); 448 | 449 | at::globalContext().alertNotDeterministic("compute_grad_input"); 450 | 451 | if (use_64bits_indexing) 452 | { 453 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 454 | columns.scalar_type(), "compute_grad_input", ([&] { 455 | deformable_col2im_kernel<<>>( 456 | num_kernels, 457 | columns.data_ptr(), 458 | offset.data_ptr(), 459 | mask.data_ptr(), 460 | channels, 461 | height, 462 | width, 463 | weight_h, 464 | weight_w, 465 | pad_h, 466 | pad_w, 467 | stride_h, 468 | stride_w, 469 | dilation_h, 470 | dilation_w, 471 | parallel_imgs, 472 | n_offset_grps, 473 | out_h, 474 | out_w, 475 | use_mask, 476 | grad_im.data_ptr()); 477 | })); 478 | } else 479 | { 480 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 481 | columns.scalar_type(), "compute_grad_input", ([&] { 482 | deformable_col2im_kernel<<>>( 483 | num_kernels, 484 | columns.data_ptr(), 485 | offset.data_ptr(), 486 | mask.data_ptr(), 487 | channels, 488 | height, 489 | width, 490 | weight_h, 491 | weight_w, 492 | pad_h, 493 | pad_w, 494 | stride_h, 495 | stride_w, 496 | dilation_h, 497 | dilation_w, 498 | parallel_imgs, 499 | n_offset_grps, 500 | out_h, 501 | out_w, 502 | use_mask, 503 | grad_im.data_ptr()); 504 | })); 505 | } 506 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 507 | } 508 | 509 | template 510 | __device__ scalar_t get_coordinate_weight( 511 | const scalar_t* im_data, 512 | index_t height, 513 | index_t width, 514 | scalar_t y, 515 | scalar_t x, 516 | bool is_y_direction) { 517 | index_t y_l = floor(y); 518 | index_t x_l = floor(x); 519 | index_t y_h = y_l + 1; 520 | index_t x_h = x_l + 1; 521 | 522 | bool valid_y_l = 0 <= y_l && y_l < height; 523 | bool valid_y_h = 0 <= y_h && y_h < height; 524 | bool valid_x_l = 0 <= x_l && x_l 
< width; 525 | bool valid_x_h = 0 <= x_h && x_h < width; 526 | 527 | scalar_t zero = 0; 528 | scalar_t v_yx = (valid_y_l && valid_x_l) ? im_data[y_l * width + x_l] : zero; 529 | scalar_t v_yX = (valid_y_l && valid_x_h) ? im_data[y_l * width + x_h] : zero; 530 | scalar_t v_Yx = (valid_y_h && valid_x_l) ? im_data[y_h * width + x_l] : zero; 531 | scalar_t v_YX = (valid_y_h && valid_x_h) ? im_data[y_h * width + x_h] : zero; 532 | 533 | if (is_y_direction) 534 | { 535 | scalar_t dx = x - x_l; 536 | return dx * (v_YX - v_yX) + (1 - dx) * (v_Yx - v_yx); 537 | } else 538 | { 539 | scalar_t dy = y - y_l; 540 | return dy * (v_YX - v_Yx) + (1 - dy) * (v_yX - v_yx); 541 | } 542 | } 543 | 544 | template 545 | __global__ void deformable_col2im_coord_kernel( 546 | index_t n, 547 | const scalar_t* col_ptr, 548 | const scalar_t* im_ptr, 549 | const scalar_t* offset_ptr, 550 | const scalar_t* mask_ptr, 551 | index_t channels, 552 | index_t height, 553 | index_t width, 554 | index_t weight_h, 555 | index_t weight_w, 556 | index_t pad_h, 557 | index_t pad_w, 558 | index_t stride_h, 559 | index_t stride_w, 560 | index_t dilation_h, 561 | index_t dilation_w, 562 | index_t batch_sz, 563 | index_t offset_channels, 564 | index_t n_offset_grps, 565 | index_t out_h, 566 | index_t out_w, 567 | const bool use_mask, 568 | scalar_t* grad_offset, 569 | scalar_t* grad_mask) { 570 | CUDA_1D_KERNEL_LOOP_T(index, n, int64_t) { 571 | scalar_t grad_offset_val = 0; 572 | scalar_t grad_mask_val = 0; 573 | 574 | index_t w = index % out_w; 575 | index_t h = (index / out_w) % out_h; 576 | index_t w_w = (index / (out_w * out_h * 2)) % weight_w; 577 | index_t w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h; 578 | index_t c = (index / (out_w * out_h)) % offset_channels; 579 | index_t b = index / (out_w * out_h * offset_channels); 580 | 581 | const index_t offset_grp = c / (2 * weight_h * weight_w); 582 | const index_t col_step = weight_h * weight_w; 583 | 584 | index_t c_per_offset_grp = channels / n_offset_grps; 585 | 586 | col_ptr += offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz * 587 | out_w * out_h; 588 | im_ptr += 589 | (b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width; 590 | offset_ptr += (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * 591 | out_h * out_w; 592 | 593 | if (use_mask) 594 | { 595 | mask_ptr += (b * n_offset_grps + offset_grp) * weight_h * weight_w * 596 | out_h * out_w; 597 | } 598 | 599 | const index_t offset_c = c - offset_grp * 2 * weight_h * weight_w; 600 | const bool is_y_direction = offset_c % 2 == 0; 601 | 602 | const index_t c_bound = c_per_offset_grp * weight_h * weight_w; 603 | for (index_t col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) 604 | { 605 | const index_t col_pos = 606 | (((col_c * batch_sz + b) * out_h) + h) * out_w + w; 607 | 608 | index_t out_x = col_pos % out_w; 609 | index_t out_y = (col_pos / out_w) % out_h; 610 | index_t j = (col_pos / (out_w * out_h * batch_sz)) % weight_w; 611 | index_t i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h; 612 | 613 | const index_t mask_idx = i * weight_w + j; 614 | 615 | const index_t offset_h_ptr = 616 | (((2 * mask_idx) * out_h + out_y) * out_w + out_x); 617 | const index_t offset_w_ptr = 618 | (((2 * mask_idx + 1) * out_h + out_y) * out_w + out_x); 619 | const scalar_t offset_h = offset_ptr[offset_h_ptr]; 620 | const scalar_t offset_w = offset_ptr[offset_w_ptr]; 621 | 622 | scalar_t mask_value = 1; 623 | if (use_mask) 624 | { 625 | mask_value = mask_ptr[(mask_idx * 
out_h + out_y) * out_w + out_x]; 626 | } 627 | 628 | scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; 629 | scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; 630 | 631 | const scalar_t weight = 632 | get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction); 633 | grad_offset_val += mask_value * weight * col_ptr[col_pos]; 634 | 635 | if (use_mask && is_y_direction) 636 | { 637 | grad_mask_val += col_ptr[col_pos] * 638 | bilinear_interpolate(im_ptr, height, width, y, x); 639 | } 640 | 641 | im_ptr += height * width; 642 | } 643 | 644 | grad_offset[index] = grad_offset_val; 645 | 646 | if (use_mask && is_y_direction) 647 | { 648 | const index_t idx = 649 | ((((b * n_offset_grps + offset_grp) * weight_h + w_h) * weight_w + 650 | w_w) * 651 | out_h + 652 | h) * 653 | out_w + 654 | w; 655 | grad_mask[idx] = grad_mask_val; 656 | } 657 | } 658 | } 659 | 660 | void compute_grad_offset_and_mask( 661 | const at::Tensor& columns, 662 | const at::Tensor& input, 663 | const at::Tensor& offset, 664 | const at::Tensor& mask, 665 | int channels, 666 | int height, 667 | int width, 668 | int weight_h, 669 | int weight_w, 670 | int pad_h, 671 | int pad_w, 672 | int stride_h, 673 | int stride_w, 674 | int dilation_h, 675 | int dilation_w, 676 | int parallel_imgs, 677 | int n_offset_grps, 678 | bool use_mask, 679 | at::Tensor grad_offset, 680 | at::Tensor grad_mask) { 681 | at::cuda::CUDAGuard device_guard(columns.get_device()); 682 | 683 | const int out_h = 684 | (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; 685 | const int out_w = 686 | (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; 687 | const int64_t num_kernels = (int64_t)out_h * out_w * 2 * weight_h * weight_w * 688 | n_offset_grps * parallel_imgs; 689 | 690 | const unsigned int threads = GET_THREADS(); 691 | const unsigned int blocks = GET_BLOCKS(threads, num_kernels); 692 | 693 | // Checks if we should use 64bits indexing 694 | // https://github.com/pytorch/vision/issues/4269 695 | bool use_64bits_indexing = false; 696 | // Checks if columns numel is larger than 2 ** 31 697 | use_64bits_indexing |= num_kernels > (1 << 31); 698 | use_64bits_indexing |= 699 | ((int64_t)channels * weight_h * weight_w * parallel_imgs * out_h * out_w > 700 | (1 << 31)); 701 | 702 | if (use_64bits_indexing) 703 | { 704 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 705 | columns.scalar_type(), "compute_grad_offset_and_mask", ([&] { 706 | deformable_col2im_coord_kernel 707 | <<>>( 708 | num_kernels, 709 | columns.data_ptr(), 710 | input.data_ptr(), 711 | offset.data_ptr(), 712 | mask.data_ptr(), 713 | channels, 714 | height, 715 | width, 716 | weight_h, 717 | weight_w, 718 | pad_h, 719 | pad_w, 720 | stride_h, 721 | stride_w, 722 | dilation_h, 723 | dilation_w, 724 | parallel_imgs, 725 | 2 * weight_h * weight_w * n_offset_grps, 726 | n_offset_grps, 727 | out_h, 728 | out_w, 729 | use_mask, 730 | grad_offset.data_ptr(), 731 | grad_mask.data_ptr()); 732 | })); 733 | } else 734 | { 735 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 736 | columns.scalar_type(), "compute_grad_offset_and_mask", ([&] { 737 | deformable_col2im_coord_kernel<<>>( 738 | num_kernels, 739 | columns.data_ptr(), 740 | input.data_ptr(), 741 | offset.data_ptr(), 742 | mask.data_ptr(), 743 | channels, 744 | height, 745 | width, 746 | weight_h, 747 | weight_w, 748 | pad_h, 749 | pad_w, 750 | stride_h, 751 | stride_w, 752 | dilation_h, 753 | dilation_w, 754 | parallel_imgs, 755 | 2 * weight_h * weight_w * n_offset_grps, 756 | 
n_offset_grps, 757 | out_h, 758 | out_w, 759 | use_mask, 760 | grad_offset.data_ptr(), 761 | grad_mask.data_ptr()); 762 | })); 763 | } 764 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 765 | } 766 | 767 | std::tuple backward_gradient_inputs( 768 | at::Tensor input, 769 | at::Tensor weight, 770 | at::Tensor offset, 771 | at::Tensor mask, 772 | at::Tensor grad_out, 773 | int stride_h, 774 | int stride_w, 775 | int pad_h, 776 | int pad_w, 777 | int dilation_h, 778 | int dilation_w, 779 | int n_weight_grps, 780 | int n_offset_grps, 781 | int n_parallel_imgs, 782 | bool use_mask) { 783 | at::DeviceGuard guard(input.device()); 784 | 785 | int batch_sz = input.size(0); 786 | long n_in_channels = input.size(1); 787 | long in_h = input.size(2); 788 | long in_w = input.size(3); 789 | 790 | n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); 791 | 792 | long n_out_channels = weight.size(0); 793 | int weight_h = weight.size(2); 794 | int weight_w = weight.size(3); 795 | 796 | long out_w = 797 | (in_w + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; 798 | long out_h = 799 | (in_h + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; 800 | 801 | auto grad_input = at::zeros_like(input); 802 | auto grad_offset = at::zeros_like(offset); 803 | auto grad_mask = at::zeros_like(mask); 804 | 805 | if (batch_sz == 0) 806 | { 807 | return std::make_tuple(grad_input, grad_offset, grad_mask); 808 | } 809 | 810 | auto columns = at::empty( 811 | {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, 812 | input.options()); 813 | 814 | // Separate into blocks 815 | grad_input = grad_input.reshape( 816 | {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); 817 | input = input.reshape( 818 | {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); 819 | 820 | grad_offset = grad_offset.reshape( 821 | {batch_sz / n_parallel_imgs, 822 | n_parallel_imgs, 823 | n_offset_grps * 2 * weight_h * weight_w, 824 | out_h, 825 | out_w}); 826 | offset = offset.reshape( 827 | {batch_sz / n_parallel_imgs, 828 | n_parallel_imgs, 829 | n_offset_grps * 2 * weight_h * weight_w, 830 | out_h, 831 | out_w}); 832 | 833 | if (use_mask) 834 | { 835 | grad_mask = grad_mask.reshape( 836 | {batch_sz / n_parallel_imgs, 837 | n_parallel_imgs, 838 | n_offset_grps * weight_h * weight_w, 839 | out_h, 840 | out_w}); 841 | mask = mask.reshape( 842 | {batch_sz / n_parallel_imgs, 843 | n_parallel_imgs, 844 | n_offset_grps * weight_h * weight_w, 845 | out_h, 846 | out_w}); 847 | } 848 | 849 | grad_out = grad_out 850 | .reshape( 851 | {batch_sz / n_parallel_imgs, 852 | n_parallel_imgs, 853 | n_weight_grps, 854 | n_out_channels / n_weight_grps, 855 | out_h, 856 | out_w}) 857 | .permute({0, 2, 3, 1, 4, 5}); 858 | 859 | weight = weight.reshape( 860 | {n_weight_grps, 861 | weight.size(0) / n_weight_grps, 862 | weight.size(1), 863 | weight.size(2), 864 | weight.size(3)}); 865 | 866 | columns = columns.view( 867 | {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); 868 | for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) 869 | { 870 | columns.zero_(); 871 | // Separate into weight groups 872 | for (int g = 0; g < n_weight_grps; g++) 873 | { 874 | columns[g] = columns[g].addmm_( 875 | weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1)); 876 | } 877 | 878 | compute_grad_offset_and_mask( 879 | columns, 880 | input[elt], 881 | offset[elt], 882 | mask[elt], 883 | n_in_channels, 884 | in_h, 885 | in_w, 886 | weight_h, 887 | weight_w, 888 | pad_h, 889 | pad_w, 
890 | stride_h, 891 | stride_w, 892 | dilation_h, 893 | dilation_w, 894 | n_parallel_imgs, 895 | n_offset_grps, 896 | use_mask, 897 | grad_offset[elt], 898 | grad_mask[elt]); 899 | 900 | compute_grad_input( 901 | columns, 902 | offset[elt], 903 | mask[elt], 904 | n_in_channels, 905 | in_h, 906 | in_w, 907 | weight_h, 908 | weight_w, 909 | pad_h, 910 | pad_w, 911 | stride_h, 912 | stride_w, 913 | dilation_h, 914 | dilation_w, 915 | n_parallel_imgs, 916 | n_offset_grps, 917 | use_mask, 918 | grad_input[elt]); 919 | } 920 | 921 | grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w}); 922 | grad_offset = grad_offset.view( 923 | {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); 924 | 925 | if (use_mask) 926 | { 927 | grad_mask = grad_mask.view( 928 | {batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w}); 929 | } 930 | 931 | return std::make_tuple(grad_input, grad_offset, grad_mask); 932 | } 933 | 934 | at::Tensor backward_gradient_parameters( 935 | at::Tensor input, 936 | const at::Tensor& weight, 937 | at::Tensor offset, 938 | at::Tensor mask, 939 | const at::Tensor& grad_out, 940 | int stride_h, 941 | int stride_w, 942 | int pad_h, 943 | int pad_w, 944 | int dilation_h, 945 | int dilation_w, 946 | int n_weight_grps, 947 | int n_offset_grps, 948 | int n_parallel_imgs, 949 | bool use_mask) { 950 | at::DeviceGuard guard(input.device()); 951 | 952 | int batch_sz = input.size(0); 953 | long n_in_channels = input.size(1); 954 | long in_h = input.size(2); 955 | long in_w = input.size(3); 956 | 957 | n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); 958 | 959 | long n_out_channels = weight.size(0); 960 | int weight_h = weight.size(2); 961 | int weight_w = weight.size(3); 962 | 963 | long out_h = grad_out.size(2); 964 | long out_w = grad_out.size(3); 965 | 966 | auto grad_weight = at::zeros_like(weight); 967 | if (batch_sz == 0) 968 | { 969 | return grad_weight; 970 | } 971 | 972 | at::Tensor grad_out_buf = grad_out 973 | .reshape( 974 | {batch_sz / n_parallel_imgs, 975 | n_parallel_imgs, 976 | n_weight_grps, 977 | n_out_channels / n_weight_grps, 978 | out_h, 979 | out_w}) 980 | .permute({0, 2, 3, 1, 4, 5}) 981 | .contiguous(); 982 | 983 | input = input.reshape( 984 | {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); 985 | 986 | offset = offset.reshape( 987 | {batch_sz / n_parallel_imgs, 988 | n_parallel_imgs, 989 | n_offset_grps * 2 * weight_h * weight_w, 990 | out_h, 991 | out_w}); 992 | 993 | if (use_mask) 994 | { 995 | mask = mask.reshape( 996 | {batch_sz / n_parallel_imgs, 997 | n_parallel_imgs, 998 | n_offset_grps * weight_h * weight_w, 999 | out_h, 1000 | out_w}); 1001 | } 1002 | 1003 | grad_weight = grad_weight.reshape( 1004 | {n_weight_grps, 1005 | grad_weight.size(0) / n_weight_grps, 1006 | grad_weight.size(1), 1007 | grad_weight.size(2), 1008 | grad_weight.size(3)}); 1009 | 1010 | auto columns = at::empty( 1011 | {n_weight_grps, 1012 | n_in_channels * weight_w * weight_h / n_weight_grps, 1013 | n_parallel_imgs * out_h * out_w}, 1014 | input.options()); 1015 | 1016 | for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) 1017 | { 1018 | deformable_im2col( 1019 | input[elt], 1020 | offset[elt], 1021 | mask[elt], 1022 | n_in_channels, 1023 | in_h, 1024 | in_w, 1025 | weight_h, 1026 | weight_w, 1027 | pad_h, 1028 | pad_w, 1029 | stride_h, 1030 | stride_w, 1031 | dilation_h, 1032 | dilation_w, 1033 | out_h, 1034 | out_w, 1035 | n_parallel_imgs, 1036 | n_offset_grps, 1037 | use_mask, 1038 | columns); 1039 | 1040 | for (int 
g = 0; g < n_weight_grps; g++) 1041 | { 1042 | grad_weight[g] = 1043 | grad_weight[g] 1044 | .flatten(1) 1045 | .addmm_( 1046 | grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0)) 1047 | .view_as(grad_weight[g]); 1048 | } 1049 | } 1050 | 1051 | grad_weight = grad_weight.view( 1052 | {grad_weight.size(0) * grad_weight.size(1), 1053 | grad_weight.size(2), 1054 | grad_weight.size(3), 1055 | grad_weight.size(4)}); 1056 | return grad_weight; 1057 | } 1058 | 1059 | at::Tensor deform_conv2d_forward_kernel( 1060 | const at::Tensor& input, 1061 | const at::Tensor& weight, 1062 | const at::Tensor& offset, 1063 | const at::Tensor& mask, 1064 | const at::Tensor& bias, 1065 | int64_t stride_h, 1066 | int64_t stride_w, 1067 | int64_t pad_h, 1068 | int64_t pad_w, 1069 | int64_t dilation_h, 1070 | int64_t dilation_w, 1071 | int64_t n_weight_grps, 1072 | int64_t n_offset_grps, 1073 | bool use_mask) { 1074 | at::Tensor input_c = input.contiguous(); 1075 | at::Tensor offset_c = offset.contiguous(); 1076 | at::Tensor weight_c = weight.contiguous(); 1077 | at::Tensor mask_c = mask.contiguous(); 1078 | at::Tensor bias_c = bias.contiguous(); 1079 | 1080 | TORCH_CHECK(input_c.ndimension() == 4); 1081 | TORCH_CHECK(offset_c.ndimension() == 4); 1082 | TORCH_CHECK(!use_mask || mask_c.ndimension() == 4); 1083 | TORCH_CHECK(weight_c.ndimension() == 4); 1084 | TORCH_CHECK(input_c.is_cuda(), "input must be a CUDA tensor"); 1085 | 1086 | at::DeviceGuard guard(input_c.device()); 1087 | 1088 | int batch_sz = input_c.size(0); 1089 | int in_channels = input_c.size(1); 1090 | int in_h = input_c.size(2); 1091 | int in_w = input_c.size(3); 1092 | 1093 | int n_parallel_imgs = 1094 | get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); 1095 | 1096 | int out_channels = weight_c.size(0); 1097 | int weight_h = weight_c.size(2); 1098 | int weight_w = weight_c.size(3); 1099 | 1100 | int ker_h = dilation_h * (weight_h - 1) + 1; 1101 | int ker_w = dilation_w * (weight_w - 1) + 1; 1102 | int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1; 1103 | int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1; 1104 | 1105 | TORCH_CHECK( 1106 | weight_h > 0 && weight_w > 0, 1107 | "weight_h: ", 1108 | weight_h, 1109 | " weight_w: ", 1110 | weight_w); 1111 | TORCH_CHECK( 1112 | stride_h > 0 && stride_w > 0, 1113 | "stride_h: ", 1114 | stride_h, 1115 | " stride_w: ", 1116 | stride_w); 1117 | TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w); 1118 | TORCH_CHECK( 1119 | dilation_h > 0 && dilation_w > 0, 1120 | "dilation_h: ", 1121 | dilation_h, 1122 | " dilation_w: ", 1123 | dilation_w); 1124 | 1125 | TORCH_CHECK(weight_c.size(1) * n_weight_grps == input_c.size(1)); 1126 | TORCH_CHECK(weight_c.size(0) % n_weight_grps == 0); 1127 | TORCH_CHECK( 1128 | (offset_c.size(1) == n_offset_grps * 2 * weight_h * weight_w), 1129 | "offset.shape[1] is not valid: got: ", 1130 | offset_c.size(1), 1131 | " expected: ", 1132 | n_offset_grps * 2 * weight_h * weight_w); 1133 | TORCH_CHECK( 1134 | (!use_mask || mask_c.size(1) == n_offset_grps * weight_h * weight_w), 1135 | "mask.shape[1] is not valid: got: ", 1136 | mask_c.size(1), 1137 | " expected: ", 1138 | n_offset_grps * weight_h * weight_w); 1139 | TORCH_CHECK(input_c.size(1) % n_offset_grps == 0); 1140 | 1141 | TORCH_CHECK( 1142 | (offset_c.size(0) == input_c.size(0)), "invalid batch size of offset"); 1143 | TORCH_CHECK( 1144 | (offset_c.size(2) == out_h && offset_c.size(3) == out_w), 1145 | "offset output dims: (", 1146 | offset_c.size(2), 1147 | ", ", 1148 | 
offset_c.size(3), 1149 | ") - ", 1150 | "computed output dims: (", 1151 | out_h, 1152 | ", ", 1153 | out_w, 1154 | ")"); 1155 | TORCH_CHECK( 1156 | (mask_c.size(0) == input_c.size(0)), "invalid batch size of mask"); 1157 | TORCH_CHECK( 1158 | (!use_mask || (mask_c.size(2) == out_h && mask_c.size(3) == out_w)), 1159 | "mask output dims: (", 1160 | mask_c.size(2), 1161 | ", ", 1162 | mask_c.size(3), 1163 | ") - ", 1164 | "computed output dims: (", 1165 | out_h, 1166 | ", ", 1167 | out_w, 1168 | ")"); 1169 | TORCH_CHECK( 1170 | out_h > 0 && out_w > 0, 1171 | "Calculated output size too small - out_h: ", 1172 | out_h, 1173 | " out_w: ", 1174 | out_w); 1175 | 1176 | auto out = 1177 | at::zeros({batch_sz, out_channels, out_h, out_w}, input_c.options()); 1178 | if (batch_sz == 0) 1179 | { 1180 | return out; 1181 | } 1182 | 1183 | // Separate batches into blocks 1184 | out = out.view( 1185 | {batch_sz / n_parallel_imgs, 1186 | n_parallel_imgs, 1187 | out_channels, 1188 | out_h, 1189 | out_w}); 1190 | input_c = input_c.view( 1191 | {batch_sz / n_parallel_imgs, n_parallel_imgs, in_channels, in_h, in_w}); 1192 | 1193 | offset_c = offset_c.view( 1194 | {batch_sz / n_parallel_imgs, 1195 | n_parallel_imgs, 1196 | n_offset_grps * 2 * weight_h * weight_w, 1197 | out_h, 1198 | out_w}); 1199 | 1200 | if (use_mask) 1201 | { 1202 | mask_c = mask_c.view( 1203 | {batch_sz / n_parallel_imgs, 1204 | n_parallel_imgs, 1205 | n_offset_grps * weight_h * weight_w, 1206 | out_h, 1207 | out_w}); 1208 | } 1209 | 1210 | at::Tensor out_buf = at::zeros( 1211 | {batch_sz / n_parallel_imgs, 1212 | out_channels, 1213 | n_parallel_imgs * out_h, 1214 | out_w}, 1215 | out.options()); 1216 | 1217 | // Separate channels into convolution groups 1218 | out_buf = out_buf.view( 1219 | {out_buf.size(0), 1220 | n_weight_grps, 1221 | out_buf.size(1) / n_weight_grps, 1222 | out_buf.size(2), 1223 | out_buf.size(3)}); 1224 | weight_c = weight_c.view( 1225 | {n_weight_grps, 1226 | weight_c.size(0) / n_weight_grps, 1227 | weight_c.size(1), 1228 | weight_c.size(2), 1229 | weight_c.size(3)}); 1230 | 1231 | // Sample points and perform convolution 1232 | auto columns = at::zeros( 1233 | {in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w}, 1234 | input_c.options()); 1235 | for (int b = 0; b < batch_sz / n_parallel_imgs; b++) 1236 | { 1237 | deformable_im2col( 1238 | input_c[b], 1239 | offset_c[b], 1240 | mask_c[b], 1241 | in_channels, 1242 | in_h, 1243 | in_w, 1244 | weight_h, 1245 | weight_w, 1246 | pad_h, 1247 | pad_w, 1248 | stride_h, 1249 | stride_w, 1250 | dilation_h, 1251 | dilation_w, 1252 | out_h, 1253 | out_w, 1254 | n_parallel_imgs, 1255 | n_offset_grps, 1256 | use_mask, 1257 | columns); 1258 | 1259 | columns = columns.view( 1260 | {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); 1261 | for (int g = 0; g < n_weight_grps; g++) 1262 | { 1263 | out_buf[b][g] = out_buf[b][g] 1264 | .flatten(1) 1265 | .addmm_(weight_c[g].flatten(1), columns[g]) 1266 | .view_as(out_buf[b][g]); 1267 | } 1268 | columns = 1269 | columns.view({columns.size(0) * columns.size(1), columns.size(2)}); 1270 | } 1271 | 1272 | out_buf = out_buf.view( 1273 | {batch_sz / n_parallel_imgs, 1274 | out_channels, 1275 | n_parallel_imgs, 1276 | out_h, 1277 | out_w}); 1278 | out_buf.transpose_(1, 2); 1279 | out.copy_(out_buf); 1280 | out = out.view({batch_sz, out_channels, out_h, out_w}); 1281 | 1282 | return out + bias_c.view({1, out_channels, 1, 1}); 1283 | } 1284 | 1285 | std::tuple 1286 | deform_conv2d_backward_kernel( 1287 | const 
at::Tensor& grad_out, 1288 | const at::Tensor& input, 1289 | const at::Tensor& weight, 1290 | const at::Tensor& offset, 1291 | const at::Tensor& mask, 1292 | const at::Tensor& bias, 1293 | int64_t stride_h, 1294 | int64_t stride_w, 1295 | int64_t pad_h, 1296 | int64_t pad_w, 1297 | int64_t dilation_h, 1298 | int64_t dilation_w, 1299 | int64_t n_weight_grps, 1300 | int64_t n_offset_grps, 1301 | bool use_mask) { 1302 | at::Tensor grad_out_c = grad_out.contiguous(); 1303 | at::Tensor input_c = input.contiguous(); 1304 | at::Tensor weight_c = weight.contiguous(); 1305 | at::Tensor offset_c = offset.contiguous(); 1306 | at::Tensor mask_c = mask.contiguous(); 1307 | at::Tensor bias_c = bias.contiguous(); 1308 | 1309 | const int batch_sz = input_c.size(0); 1310 | const int n_parallel_imgs = 1311 | get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); 1312 | 1313 | auto grad_input_and_offset_and_mask = backward_gradient_inputs( 1314 | input_c, 1315 | weight_c, 1316 | offset_c, 1317 | mask_c, 1318 | grad_out_c, 1319 | stride_h, 1320 | stride_w, 1321 | pad_h, 1322 | pad_w, 1323 | dilation_h, 1324 | dilation_w, 1325 | n_weight_grps, 1326 | n_offset_grps, 1327 | n_parallel_imgs, 1328 | use_mask); 1329 | 1330 | auto grad_input = std::get<0>(grad_input_and_offset_and_mask); 1331 | auto grad_offset = std::get<1>(grad_input_and_offset_and_mask); 1332 | auto grad_mask = std::get<2>(grad_input_and_offset_and_mask); 1333 | 1334 | auto grad_weight = backward_gradient_parameters( 1335 | input_c, 1336 | weight_c, 1337 | offset_c, 1338 | mask_c, 1339 | grad_out_c, 1340 | stride_h, 1341 | stride_w, 1342 | pad_h, 1343 | pad_w, 1344 | dilation_h, 1345 | dilation_w, 1346 | n_weight_grps, 1347 | n_offset_grps, 1348 | n_parallel_imgs, 1349 | use_mask); 1350 | 1351 | auto value = grad_out_c.sum({0, 2, 3}); 1352 | auto grad_bias = at::ones_like(bias_c) * value; 1353 | 1354 | return std::make_tuple( 1355 | grad_input, grad_weight, grad_offset, grad_mask, grad_bias); 1356 | } 1357 | 1358 | } // namespace 1359 | 1360 | TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { 1361 | m.impl( 1362 | TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"), 1363 | TORCH_FN(deform_conv2d_forward_kernel)); 1364 | m.impl( 1365 | TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"), 1366 | TORCH_FN(deform_conv2d_backward_kernel)); 1367 | } 1368 | 1369 | } // namespace ops 1370 | } // namespace vision 1371 | -------------------------------------------------------------------------------- /src/dkd.cpp: -------------------------------------------------------------------------------- 1 | #include "dkd.hpp" 2 | 3 | #include 4 | 5 | namespace F = torch::nn::functional; 6 | using namespace torch::indexing; 7 | 8 | DKD::DKD(int radius, int top_k, float scores_th, int n_limit) 9 | : radius_(radius), 10 | top_k_(top_k), 11 | scores_th_(scores_th), 12 | n_limit_(n_limit), 13 | kernel_size_(calculateKernelSize(radius)), 14 | temperature_(0.1f), 15 | unfold_(torch::nn::UnfoldOptions(kernel_size_).padding(radius)) { 16 | 17 | auto x = torch::linspace(-radius_, radius_, kernel_size_); 18 | auto meshgrid = torch::meshgrid({x, x}); 19 | hw_grid_ = torch::stack({meshgrid[1], meshgrid[0]}, -1) 20 | .reshape({-1, 2}) 21 | .contiguous(); // Ensure contiguous memory layout 22 | } 23 | 24 | torch::Tensor DKD::simple_nms(torch::Tensor scores, int nms_radius) && { 25 | auto zeros = torch::zeros_like(scores); 26 | auto max_pool_options = F::MaxPool2dFuncOptions(nms_radius * 2 + 1) 27 | .stride(1) 28 | .padding(nms_radius); 29 | 30 | auto max_mask = 
std::move(scores) == F::max_pool2d(scores, max_pool_options); 31 | 32 | for (int i = 0; i < 2; ++i) 33 | { 34 | auto supp_mask = F::max_pool2d(max_mask.to(torch::kFloat), max_pool_options) > 0; 35 | auto supp_scores = torch::where(supp_mask, zeros, scores); 36 | auto new_max_mask = supp_scores == F::max_pool2d(supp_scores, max_pool_options); 37 | max_mask = max_mask | (new_max_mask & (~supp_mask)); 38 | } 39 | 40 | return torch::where(max_mask, scores, std::move(zeros)); 41 | } 42 | 43 | torch::Tensor DKD::simple_nms(const torch::Tensor& scores, int nms_radius) & { 44 | auto scores_copy = scores.clone(); 45 | return std::move(*this).simple_nms(std::move(scores_copy), nms_radius); 46 | } 47 | 48 | std::tuple, std::vector, std::vector> 49 | DKD::detect_keypoints(torch::Tensor scores_map, bool sub_pixel) && { 50 | const auto batch_size = scores_map.size(0); 51 | const auto height = scores_map.size(2); 52 | const auto width = scores_map.size(3); 53 | const auto device = scores_map.device(); 54 | 55 | auto scores_nograd = scores_map.detach(); 56 | auto nms_scores = std::move(*this).simple_nms(std::move(scores_nograd), 2); 57 | 58 | auto border_mask = torch::ones_like(nms_scores, 59 | torch::TensorOptions() 60 | .dtype(torch::kBool) 61 | .device(device)); 62 | 63 | border_mask.index_put_({Slice(), Slice(), Slice(None, radius_), Slice()}, false); 64 | border_mask.index_put_({Slice(), Slice(), Slice(), Slice(None, radius_)}, false); 65 | border_mask.index_put_({Slice(), Slice(), Slice(-radius_, None), Slice()}, false); 66 | border_mask.index_put_({Slice(), Slice(), Slice(), Slice(-radius_, None)}, false); 67 | 68 | nms_scores = torch::where(border_mask, nms_scores, torch::zeros_like(nms_scores)); 69 | 70 | std::vector keypoints; 71 | std::vector scoredispersitys; 72 | std::vector kptscores; 73 | keypoints.reserve(batch_size); 74 | scoredispersitys.reserve(batch_size); 75 | kptscores.reserve(batch_size); 76 | 77 | // Create wh tensor on the correct device 78 | auto wh = torch::tensor( 79 | {static_cast(width - 1), static_cast(height - 1)}, 80 | torch::TensorOptions().dtype(scores_map.dtype()).device(device)); 81 | 82 | // Ensure hw_grid_ is on the correct device 83 | if (hw_grid_.device() != device) 84 | { 85 | hw_grid_ = hw_grid_.to(device); 86 | } 87 | 88 | if (sub_pixel) 89 | { 90 | auto patches = unfold_(scores_map); 91 | 92 | for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) 93 | { 94 | auto patch = patches[batch_idx].transpose(0, 1); 95 | 96 | torch::Tensor indices_kpt; 97 | if (top_k_ > 0) 98 | { 99 | auto scores_view = nms_scores[batch_idx].reshape(-1); 100 | auto topk = scores_view.topk(top_k_); 101 | indices_kpt = std::get<1>(topk); 102 | } else 103 | { 104 | auto scores_view = nms_scores[batch_idx].reshape(-1); 105 | auto mask = scores_view > scores_th_; 106 | indices_kpt = mask.nonzero().squeeze(1); 107 | if (indices_kpt.size(0) > n_limit_) 108 | { 109 | auto kpts_sc = scores_view.index_select(0, indices_kpt); 110 | auto sort_idx = std::get<1>(kpts_sc.sort(true)); 111 | indices_kpt = indices_kpt.index_select(0, sort_idx.slice(0, n_limit_)); 112 | } 113 | } 114 | 115 | auto patch_scores = patch.index_select(0, indices_kpt); 116 | auto keypoints_xy_nms = torch::stack({indices_kpt % width, 117 | torch::div(indices_kpt, width, /*rounding_mode=*/"floor")}, 118 | 1) 119 | .to(device); 120 | 121 | auto [max_v, _] = patch_scores.max(1, true); 122 | auto x_exp = ((patch_scores - max_v.detach()) / temperature_).exp(); 123 | auto xy_residual = (x_exp.unsqueeze(2) * 
hw_grid_.unsqueeze(0)).sum(1) / 124 | x_exp.sum(1, true); 125 | 126 | auto dist2 = (hw_grid_.unsqueeze(0) - xy_residual.unsqueeze(1)) 127 | .div(radius_) 128 | .norm(2, -1) 129 | .pow(2); 130 | 131 | auto scoredispersity = (x_exp * dist2).sum(1) / x_exp.sum(1); 132 | auto keypoints_xy = keypoints_xy_nms + xy_residual; 133 | keypoints_xy = keypoints_xy.div(wh).mul(2).sub(1); 134 | 135 | auto kptscore = torch::nn::functional::grid_sample( 136 | scores_map[batch_idx].unsqueeze(0), 137 | keypoints_xy.view({1, 1, -1, 2}), 138 | torch::nn::functional::GridSampleFuncOptions() 139 | .mode(torch::kBilinear) 140 | .align_corners(true))[0][0][0]; 141 | 142 | keypoints.push_back(std::move(keypoints_xy)); 143 | scoredispersitys.push_back(std::move(scoredispersity)); 144 | kptscores.push_back(std::move(kptscore)); 145 | } 146 | } else 147 | { 148 | for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) 149 | { 150 | torch::Tensor indices_kpt; 151 | if (top_k_ > 0) 152 | { 153 | auto scores_view = nms_scores[batch_idx].reshape(-1); 154 | auto topk = scores_view.topk(top_k_); 155 | indices_kpt = std::get<1>(topk); 156 | } else 157 | { 158 | auto scores_view = nms_scores[batch_idx].reshape(-1); 159 | auto mask = scores_view > scores_th_; 160 | indices_kpt = mask.nonzero().squeeze(1); 161 | if (indices_kpt.size(0) > n_limit_) 162 | { 163 | auto kpts_sc = scores_view.index_select(0, indices_kpt); 164 | auto sort_idx = std::get<1>(kpts_sc.sort(true)); 165 | indices_kpt = indices_kpt.index_select(0, sort_idx.slice(0, n_limit_)); 166 | } 167 | } 168 | 169 | auto keypoints_xy = torch::stack({indices_kpt % width, 170 | torch::div(indices_kpt, width, /*rounding_mode=*/"floor")}, 171 | 1) 172 | .to(device); 173 | 174 | keypoints_xy = keypoints_xy.div(wh).mul(2).sub(1); 175 | 176 | auto kptscore = torch::nn::functional::grid_sample( 177 | scores_map[batch_idx].unsqueeze(0), 178 | keypoints_xy.view({1, 1, -1, 2}), 179 | torch::nn::functional::GridSampleFuncOptions() 180 | .mode(torch::kBilinear) 181 | .align_corners(true))[0][0][0]; 182 | 183 | keypoints.push_back(std::move(keypoints_xy)); 184 | scoredispersitys.push_back(kptscore.clone()); 185 | kptscores.push_back(std::move(kptscore)); 186 | } 187 | } 188 | 189 | return std::make_tuple(std::move(keypoints), 190 | std::move(scoredispersitys), 191 | std::move(kptscores)); 192 | } 193 | 194 | std::tuple, std::vector, std::vector> 195 | DKD::detect_keypoints(const torch::Tensor& scores_map, bool sub_pixel) & { 196 | auto scores_map_copy = scores_map.clone(); 197 | return std::move(*this).detect_keypoints(std::move(scores_map_copy), sub_pixel); 198 | } 199 | 200 | std::tuple, std::vector, std::vector> 201 | DKD::forward(torch::Tensor scores_map, bool sub_pixel) && { 202 | return std::move(*this).detect_keypoints(std::move(scores_map), sub_pixel); 203 | } 204 | 205 | std::tuple, std::vector, std::vector> 206 | DKD::forward(const torch::Tensor& scores_map, bool sub_pixel) & { 207 | return this->detect_keypoints(scores_map, sub_pixel); 208 | } -------------------------------------------------------------------------------- /src/get_patches.cpp: -------------------------------------------------------------------------------- 1 | #include "get_patches_cuda.h" 2 | #include 3 | #include 4 | 5 | // map: CxHxW 6 | // points: Nx2 7 | // kernel_size: int 8 | // return: N x C x kernel_size x kernel_size 9 | namespace custom_ops { 10 | torch::Tensor get_patches_forward_cpu(const torch::Tensor& map, torch::Tensor& points, int64_t kernel_size) { 11 | namespace F = 
torch::nn::functional; 12 | using namespace torch::indexing; 13 | 14 | auto N = points.size(0); 15 | auto C = map.size(0); 16 | // kernel_size=2, radius=0.5, pad_left_top=0, pad_right_bottom=1 17 | // kernel_size=3, radius=1.0, pad_left_top=1, pad_right_bottom=1 18 | // kernel_size=4, radius=1.5, pad_left_top=1, pad_right_bottom=2 19 | // kernel_size=5, radius=2.0, pad_left_top=2, pad_right_bottom=2 20 | auto radius = (kernel_size - 1.0) / 2.0; 21 | int pad_left_top = floor(radius); 22 | int pad_right_bottom = ceil(radius); 23 | 24 | // pad map 25 | auto options = F::PadFuncOptions({pad_left_top, pad_right_bottom, pad_left_top, pad_right_bottom}).mode(torch::kConstant); 26 | auto map_pad = F::pad(map.unsqueeze(0), options).squeeze(0); // Cx(H+2*radius)x(W+2*radius) 27 | 28 | // get patches 29 | torch::Tensor patches = torch::zeros({N, C, kernel_size, kernel_size}, map.options()); 30 | auto a_points = points.accessor(); // Nx2 31 | auto a_map_pad = map_pad.accessor(); // Cx(H+2*radius)x(W+2*radius) 32 | auto a_patches = patches.accessor(); // N x C x kernel_size x kernel_size 33 | 34 | for (auto in = 0; in < N; in++) 35 | { 36 | auto w_start = a_points[in][0]; 37 | auto h_start = a_points[in][1]; 38 | 39 | // copy data 40 | for (auto ic = 0; ic < C; ic++) 41 | { 42 | for (auto ih = 0; ih < kernel_size; ih++) 43 | { 44 | for (auto iw = 0; iw < kernel_size; iw++) 45 | { 46 | a_patches[in][ic][ih][iw] = a_map_pad[ic][ih + h_start][iw + w_start]; 47 | } 48 | } 49 | } 50 | } 51 | return patches; 52 | } 53 | 54 | // patches: NxCx(2*radius+1)x(2*radius+1) 55 | // points: Nx2 56 | torch::Tensor 57 | get_patches_backward_cpu(const torch::Tensor& d_patches, torch::Tensor& points, int64_t H, int64_t W) { 58 | namespace F = torch::nn::functional; 59 | using namespace torch::indexing; 60 | 61 | auto N = d_patches.size(0); 62 | auto C = d_patches.size(1); 63 | // kernel_size=2, radius=0.5, pad_left_top=0, pad_right_bottom=1 64 | // kernel_size=3, radius=1.0, pad_left_top=1, pad_right_bottom=1 65 | // kernel_size=4, radius=1.5, pad_left_top=1, pad_right_bottom=2 66 | // kernel_size=5, radius=2.0, pad_left_top=2, pad_right_bottom=2 67 | auto kernel_size = d_patches.size(2); 68 | auto radius = (kernel_size - 1.0) / 2.0; 69 | int pad_left_top = floor(radius); 70 | int pad_right_bottom = ceil(radius); 71 | // printf("kernel_size=%d, radius=%f, pad_left_top=%d, pad_right_bottom=%d\n", 72 | // kernel_size, 73 | // radius, 74 | // pad_left_top, 75 | // pad_right_bottom); 76 | 77 | torch::Tensor d_map_pad = torch::zeros({C, H + int(2 * radius), W + int(2 * radius)}, d_patches.options()); 78 | 79 | auto a_points = points.accessor(); // Nx2 80 | auto a_d_map_pad = d_map_pad.accessor(); // Cx(H+2*radius)x(W+2*radius) 81 | auto a_p_patches = d_patches.accessor(); // NxCxkernel_sizexkernel_size 82 | for (auto in = 0; in < N; in++) 83 | { 84 | // long w_start = static_cast(*(p_points + in * 2 + 0)); 85 | // long h_start = static_cast(*(p_points + in * 2 + 1)); 86 | auto w_start = a_points[in][0]; 87 | auto h_start = a_points[in][1]; 88 | 89 | // copy data 90 | for (auto ic = 0; ic < C; ic++) 91 | { 92 | for (auto ih = 0; ih < kernel_size; ih++) 93 | { 94 | for (auto iw = 0; iw < kernel_size; iw++) 95 | { 96 | a_d_map_pad[ic][ih + h_start][iw + w_start] = a_p_patches[in][ic][ih][iw]; 97 | } 98 | } 99 | } 100 | } 101 | 102 | auto d_map = d_map_pad.index( 103 | {Slice(), Slice(pad_left_top, -pad_right_bottom), Slice(pad_left_top, -pad_right_bottom)}); 104 | 105 | return d_map; 106 | } 107 | 108 | torch::Tensor 
get_patches_forward(const torch::Tensor& map, torch::Tensor& points, int64_t kernel_size) { 109 | if (map.device() == torch::kCPU) 110 | return get_patches_forward_cpu(map, points, kernel_size); 111 | else 112 | { 113 | return get_patches_forward_cuda(map, points, kernel_size); 114 | } 115 | } 116 | 117 | torch::Tensor get_patches_backward(const torch::Tensor& d_patches, torch::Tensor& points, int64_t H, int64_t W) { 118 | if (d_patches.device() == torch::kCPU) 119 | return get_patches_backward_cpu(d_patches, points, H, W); 120 | else 121 | return get_patches_backward_cuda(d_patches, points, H, W); 122 | } 123 | } // namespace custom_ops -------------------------------------------------------------------------------- /src/get_patches_cuda.cu: -------------------------------------------------------------------------------- 1 | #include "get_patches_cuda.h" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | namespace F = torch::nn::functional; 11 | 12 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 13 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 14 | #define CHECK_INPUT(x) \ 15 | CHECK_CUDA(x); \ 16 | CHECK_CONTIGUOUS(x) 17 | 18 | // CUDA: grid stride looping 19 | // 20 | // int64_t _i_n_d_e_x specifically prevents overflow in the loop increment. 21 | // If input.numel() < INT_MAX, _i_n_d_e_x < INT_MAX, except after the final 22 | // iteration of the loop where _i_n_d_e_x += blockDim.x * gridDim.x can be 23 | // greater than INT_MAX. But in that case _i_n_d_e_x >= n, so there are no 24 | // further iterations and the overflowed value in i=_i_n_d_e_x is not used. 25 | #define CUDA_KERNEL_LOOP_TYPE(i, n, index_type) \ 26 | int64_t _i_n_d_e_x = blockIdx.x * blockDim.x + threadIdx.x; \ 27 | for (index_type i = _i_n_d_e_x; _i_n_d_e_x < (n); _i_n_d_e_x += blockDim.x * gridDim.x, i = _i_n_d_e_x) 28 | 29 | #define CUDA_KERNEL_LOOP(i, n) CUDA_KERNEL_LOOP_TYPE(i, n, int) 30 | 31 | // Use 1024 threads per block, which requires cuda sm_2x or above 32 | // constexpr int CUDA_NUM_THREADS = 1024; 33 | constexpr int CUDA_NUM_THREADS = 16; 34 | 35 | // CUDA: number of blocks for threads. 
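// Illustrative note (not part of the upstream source): GET_BLOCKS below is a
// ceiling division, blocks = ceil(N / max_threads_per_block). With the
// CUDA_NUM_THREADS = 16 configured above, e.g. N = 1000 gives
//   (1000 - 1) / 16 + 1 = 63 blocks (63 * 16 = 1008 >= 1000),
// so every element gets a thread and at most one block is partially idle.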
36 | inline int GET_BLOCKS(const int64_t N, const int64_t max_threads_per_block = CUDA_NUM_THREADS) { 37 | TORCH_INTERNAL_ASSERT(N > 0, "CUDA kernel launch blocks must be positive, but got N=", N); 38 | constexpr int64_t max_int = std::numeric_limits::max(); 39 | 40 | // Round up division for positive number that cannot cause integer overflow 41 | auto block_num = (N - 1) / max_threads_per_block + 1; 42 | TORCH_INTERNAL_ASSERT(block_num <= max_int, "Can't schedule too many blocks on CUDA device"); 43 | 44 | return static_cast(block_num); 45 | } 46 | 47 | template 48 | C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS) 49 | __global__ void get_patches_forward_cuda_kernel(const int64_t n, 50 | const scalar_t* p_map, // Cx(H+2*radius)x(W+2*radius) 51 | const int64_t* p_points, // Nx2 52 | int64_t n_input_plane, int64_t input_height, int64_t input_width, int64_t n_points, 53 | int64_t pad_left_top, int64_t pad_right_bottom, int64_t kernel_size, 54 | scalar_t* p_patches // NxCxkernel_sizexkernel_size 55 | ) { 56 | CUDA_KERNEL_LOOP(index, n) { 57 | int64_t n_out = index % n_points; // point idx 58 | int64_t channel_idx = index / n_points; // channel idx 59 | 60 | int64_t w_in = *(p_points + 2 * n_out); 61 | int64_t h_in = *(p_points + 2 * n_out + 1); 62 | 63 | const scalar_t* im = p_map + (channel_idx * input_height + h_in) * input_width + w_in; 64 | scalar_t* dst_patches = p_patches + (n_out * n_input_plane + channel_idx) * kernel_size * kernel_size; 65 | 66 | // copy data 67 | for (int64_t i = 0; i < kernel_size; ++i) 68 | { 69 | for (int64_t j = 0; j < kernel_size; ++j) 70 | { 71 | int64_t h = h_in + i - pad_left_top; 72 | int64_t w = w_in + j - pad_left_top; 73 | 74 | *(dst_patches + i * kernel_size + j) = (h >= 0 && w >= 0 && h < input_height && w < input_width) 75 | ? 
im[(i - pad_left_top) * input_width + j - pad_left_top] 76 | : static_cast(0); 77 | } 78 | } 79 | } 80 | } 81 | 82 | template 83 | __global__ void 84 | get_patches_forward_cuda_kernel1(const torch::PackedTensorAccessor32 map_pad, // Cx(H+2*radius)x(W+2*radius) 85 | const torch::PackedTensorAccessor32 points, // Nx2 86 | torch::PackedTensorAccessor32 patches, // NxCxkernel_sizexkernel_size 87 | int64_t kernel_size) { 88 | const int in = blockIdx.x * blockDim.x + threadIdx.x; 89 | const int N = points.size(0); 90 | const int C = map_pad.size(0); 91 | 92 | if (in < N) 93 | { 94 | long w_start = points[in][0]; 95 | long h_start = points[in][1]; 96 | 97 | // copy data 98 | for (long ic = 0; ic < C; ic++) 99 | { 100 | for (long ih = 0; ih < kernel_size; ih++) 101 | { 102 | for (long iw = 0; iw < kernel_size; iw++) 103 | { 104 | patches[in][ic][ih][iw] = map_pad[ic][h_start + ih][w_start + iw]; 105 | } 106 | } 107 | } 108 | } 109 | } 110 | 111 | template 112 | __global__ void 113 | get_patches_backward_cuda_kernel(torch::PackedTensorAccessor32 d_map_pad, // Cx(H+2*radius)x(W+2*radius) 114 | const torch::PackedTensorAccessor32 points, // Nx2 115 | const torch::PackedTensorAccessor32 d_patches, // NxCxkernel_sizexkernel_size 116 | int64_t kernel_size) { 117 | const int in = blockIdx.x * blockDim.x + threadIdx.x; 118 | const int N = points.size(0); 119 | const int C = d_map_pad.size(0); 120 | 121 | if (in < N) 122 | { 123 | long w_start = points[in][0]; 124 | long h_start = points[in][1]; 125 | 126 | // copy data 127 | for (long ic = 0; ic < C; ic++) 128 | { 129 | for (long ih = 0; ih < kernel_size; ih++) 130 | { 131 | for (long iw = 0; iw < kernel_size; iw++) 132 | { 133 | d_map_pad[ic][h_start + ih][w_start + iw] = d_patches[in][ic][ih][iw]; 134 | } 135 | } 136 | } 137 | } 138 | } 139 | 140 | torch::Tensor get_patches_forward_cuda(const torch::Tensor& input, torch::Tensor& points, int64_t kernel_size) { 141 | CHECK_INPUT(input); 142 | CHECK_INPUT(points); 143 | 144 | int64_t n_input_plane = input.size(0); 145 | int64_t input_height = input.size(1); 146 | int64_t input_width = input.size(2); 147 | // kernel_size=2, radius=0.5, pad_left_top=0, pad_right_bottom=1 148 | // kernel_size=3, radius=1.0, pad_left_top=1, pad_right_bottom=1 149 | // kernel_size=4, radius=1.5, pad_left_top=1, pad_right_bottom=2 150 | // kernel_size=5, radius=2.0, pad_left_top=2, pad_right_bottom=2 151 | auto radius = (kernel_size - 1.0) / 2.0; 152 | int64_t pad_left_top = floor(radius); 153 | int64_t pad_right_bottom = ceil(radius); 154 | int64_t n_points = points.size(0); 155 | 156 | // create output patches 157 | torch::Tensor patches = torch::zeros({n_points, n_input_plane, kernel_size, kernel_size}, input.options()); 158 | 159 | // cuda kernel 160 | int64_t num_kernels = n_input_plane * n_points; 161 | auto stream = at::cuda::getCurrentCUDAStream(); 162 | AT_DISPATCH_FLOATING_TYPES(input.type(), "get_patches_forward_cuda", 163 | ( 164 | [&] { 165 | get_patches_forward_cuda_kernel<<>>( 166 | num_kernels, input.data_ptr(), points.data_ptr(), n_input_plane, input_height, 167 | input_width, n_points, pad_left_top, pad_right_bottom, kernel_size, patches.data_ptr()); 168 | })); 169 | 170 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 171 | 172 | return patches; 173 | } 174 | 175 | torch::Tensor get_patches_forward_cuda1(const torch::Tensor& map, torch::Tensor& points, int64_t kernel_size) { 176 | CHECK_INPUT(map); 177 | CHECK_INPUT(points); 178 | 179 | auto N = points.size(0); 180 | auto C = map.size(0); 181 | // kernel_size=2, radius=0.5, 
182 |     // kernel_size=3, radius=1.0, pad_left_top=1, pad_right_bottom=1
183 |     // kernel_size=4, radius=1.5, pad_left_top=1, pad_right_bottom=2
184 |     // kernel_size=5, radius=2.0, pad_left_top=2, pad_right_bottom=2
185 |     auto radius = (kernel_size - 1.0) / 2.0;
186 |     int pad_left_top = floor(radius);
187 |     int pad_right_bottom = ceil(radius);
188 | 
189 |     // pad map
190 |     auto options = F::PadFuncOptions({pad_left_top, pad_right_bottom, pad_left_top, pad_right_bottom}).mode(torch::kConstant);
191 |     auto map_pad = F::pad(map.unsqueeze(0), options).squeeze(0); // Cx(H+2*radius)x(W+2*radius)
192 | 
193 |     // create patches
194 |     torch::Tensor patches = torch::empty({N, C, kernel_size, kernel_size}, map.options());
195 | 
196 |     // cuda kernel
197 |     const int threads = CUDA_NUM_THREADS;
198 |     const int blocks = (N + threads - 1) / threads;
199 |     AT_DISPATCH_FLOATING_TYPES(map_pad.type(), "get_patches_forward_cuda",
200 |                                (
201 |                                    [&] {
202 |                                        get_patches_forward_cuda_kernel1<scalar_t>
203 |                                            <<<blocks, threads>>>(map_pad.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
204 |                                                                  points.packed_accessor32<long, 2, torch::RestrictPtrTraits>(),
205 |                                                                  patches.packed_accessor32<scalar_t, 4, torch::RestrictPtrTraits>(), kernel_size);
206 |                                    }));
207 | 
208 |     // get error
209 |     cudaDeviceSynchronize();
210 |     cudaError_t cudaerr = cudaGetLastError();
211 |     if (cudaerr != cudaSuccess)
212 |         printf("kernel launch failed with error \"%s\".\n", cudaGetErrorString(cudaerr));
213 | 
214 |     return patches;
215 | }
216 | 
217 | torch::Tensor get_patches_backward_cuda(const torch::Tensor& d_patches, torch::Tensor& points, int64_t H, int64_t W) {
218 |     CHECK_INPUT(d_patches);
219 |     CHECK_INPUT(points);
220 | 
221 |     auto N = d_patches.size(0);
222 |     auto C = d_patches.size(1);
223 |     // kernel_size=2, radius=0.5, pad_left_top=0, pad_right_bottom=1
224 |     // kernel_size=3, radius=1.0, pad_left_top=1, pad_right_bottom=1
225 |     // kernel_size=4, radius=1.5, pad_left_top=1, pad_right_bottom=2
226 |     // kernel_size=5, radius=2.0, pad_left_top=2, pad_right_bottom=2
227 |     auto kernel_size = d_patches.size(2);
228 |     auto radius = (kernel_size - 1.0) / 2.0;
229 |     int pad_left_top = floor(radius);
230 |     int pad_right_bottom = ceil(radius);
231 | 
232 |     torch::Tensor d_map_pad = torch::zeros({C, H + int(2 * radius), W + int(2 * radius)}, d_patches.options());
233 | 
234 |     // cuda kernel
235 |     const int threads = CUDA_NUM_THREADS;
236 |     const int blocks = (N + threads - 1) / threads;
237 |     AT_DISPATCH_FLOATING_TYPES(d_map_pad.type(), "get_patches_backward_cuda",
238 |                                (
239 |                                    [&] {
240 |                                        get_patches_backward_cuda_kernel<scalar_t>
241 |                                            <<<blocks, threads>>>(d_map_pad.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
242 |                                                                  points.packed_accessor32<long, 2, torch::RestrictPtrTraits>(),
243 |                                                                  d_patches.packed_accessor32<scalar_t, 4, torch::RestrictPtrTraits>(), kernel_size);
244 |                                    }));
245 | 
246 |     // get error
247 |     cudaDeviceSynchronize();
248 |     cudaError_t cudaerr = cudaGetLastError();
249 |     if (cudaerr != cudaSuccess)
250 |         printf("kernel launch failed with error \"%s\".\n", cudaGetErrorString(cudaerr));
251 | 
252 |     using namespace torch::indexing;
253 |     auto d_map = d_map_pad.index({Slice(), Slice(pad_left_top, -pad_right_bottom), Slice(pad_left_top, -pad_right_bottom)});
254 | 
255 |     return d_map;
256 | }
257 | 
--------------------------------------------------------------------------------
/src/input_padder.cpp:
--------------------------------------------------------------------------------
1 | #include "input_padder.hpp"
2 | 
3 | torch::Tensor InputPadder::pad(torch::Tensor x) && {
4 |     return torch::nn::functional::pad(
5 |         std::move(x),
6 |         torch::nn::functional::PadFuncOptions({pad_[0], pad_[1], pad_[2], pad_[3]})
7 |             .mode(torch::kReplicate));
8 | }
9 | 
10 | torch::Tensor InputPadder::pad(const torch::Tensor& x) & {
11 |     return torch::nn::functional::pad(
12 |         x,
13 |         torch::nn::functional::PadFuncOptions({pad_[0], pad_[1], pad_[2], pad_[3]})
14 |             .mode(torch::kReplicate));
15 | }
16 | 
17 | [[maybe_unused]] torch::Tensor InputPadder::unpad(torch::Tensor x) && {
18 |     int h = x.size(-2);
19 |     int w = x.size(-1);
20 |     return std::move(x).index({torch::indexing::Slice(),
21 |                                torch::indexing::Slice(),
22 |                                torch::indexing::Slice(pad_[2], h - pad_[3]),
23 |                                torch::indexing::Slice(pad_[0], w - pad_[1])});
24 | }
25 | 
26 | torch::Tensor InputPadder::unpad(const torch::Tensor& x) & {
27 |     int h = x.size(-2);
28 |     int w = x.size(-1);
29 |     return x.index({torch::indexing::Slice(),
30 |                     torch::indexing::Slice(),
31 |                     torch::indexing::Slice(pad_[2], h - pad_[3]),
32 |                     torch::indexing::Slice(pad_[0], w - pad_[1])});
33 | }
34 | 
35 | void InputPadder::setPadding(std::span<const int> padding) {
36 |     if (padding.size() != 4)
37 |     {
38 |         throw std::invalid_argument("Padding must have exactly 4 values");
39 |     }
40 |     std::copy(padding.begin(), padding.end(), pad_.begin());
41 | }
--------------------------------------------------------------------------------
/src/sddh.cpp:
--------------------------------------------------------------------------------
1 | #include "sddh.hpp"
2 | 
3 | #include "get_patches.hpp"
4 | #include <torch/torch.h>
5 | 
6 | using namespace torch::indexing;
7 | 
8 | SDDH::SDDH(int dims, int kernel_size, int n_pos, bool conv2D, bool mask)
9 |     : kernel_size_(kernel_size),
10 |       n_pos_(n_pos),
11 |       conv2D_(conv2D),
12 |       mask_(mask) {
13 | 
14 |     // Channel num for offsets
15 |     const int channel_num = mask ? 3 * n_pos : 2 * n_pos;
16 | 
17 |     // Build offset convolution layers
18 |     torch::nn::Sequential offset_conv;
19 |     offset_conv->push_back(torch::nn::Conv2d(
20 |         torch::nn::Conv2dOptions(dims, channel_num, kernel_size)
21 |             .stride(1)
22 |             .padding(0)
23 |             .bias(true)));
24 |     offset_conv->push_back(torch::nn::SELU());
25 |     offset_conv->push_back(torch::nn::Conv2d(
26 |         torch::nn::Conv2dOptions(channel_num, channel_num, 1)
27 |             .stride(1)
28 |             .padding(0)
29 |             .bias(true)));
30 | 
31 |     register_module("offset_conv", offset_conv);
32 |     offset_conv_ = offset_conv;
33 | 
34 |     // Sampled feature convolution
35 |     sf_conv_ = register_module("sf_conv",
36 |                                torch::nn::Conv2d(torch::nn::Conv2dOptions(dims, dims, 1)
37 |                                                      .stride(1)
38 |                                                      .padding(0)
39 |                                                      .bias(false)));
40 | 
41 |     if (!conv2D)
42 |     {
43 |         // Register deformable desc weights
44 |         agg_weights_ = register_parameter("agg_weights",
45 |                                           torch::randn({n_pos, dims, dims}));
46 |     } else
47 |     {
48 |         // Register convM
49 |         convM_ = register_module("convM",
50 |                                  torch::nn::Conv2d(torch::nn::Conv2dOptions(dims * n_pos, dims, 1)
51 |                                                        .stride(1)
52 |                                                        .padding(0)
53 |                                                        .bias(false)));
54 |     }
55 | }
56 | 
57 | torch::Tensor SDDH::process_features(torch::Tensor features, int64_t num_keypoints) && {
58 |     if (!conv2D_)
59 |     {
60 |         return torch::einsum("ncp,pcd->nd",
61 |                              {std::move(features), agg_weights_});
62 |     } else
63 |     {
64 |         features = std::move(features)
65 |                        .reshape({num_keypoints, -1})
66 |                        .unsqueeze(-1)
67 |                        .unsqueeze(-1);
68 |         return convM_->forward(std::move(features)).squeeze();
69 |     }
70 | }
71 | 
72 | torch::Tensor SDDH::process_features(const torch::Tensor& features, int64_t num_keypoints) & {
73 |     auto features_copy = features.clone();
74 |     return std::move(*this).process_features(std::move(features_copy), num_keypoints);
75 | }
76 | 
77 | std::tuple<std::vector<torch::Tensor>, std::vector<torch::Tensor>>
78 | SDDH::forward(torch::Tensor x, std::vector<torch::Tensor>& keypoints) && {
79 |     // Make input tensor contiguous if it isn't already
80 |     if (!x.is_contiguous())
81 |     {
82 |         x = x.contiguous();
83 |     }
84 | 
85 |     const auto batch_size = x.size(0);
86 |     const auto channels = x.size(1);
87 |     const auto height = x.size(2);
88 |     const auto width = x.size(3);
89 |     const auto device = x.device();
90 | 
91 |     const auto wh = torch::tensor({width - 1.0f, height - 1.0f},
92 |                                   torch::TensorOptions()
93 |                                       .dtype(x.dtype())
94 |                                       .device(device));
95 | 
96 |     const float max_offset = std::max(height, width) / 4.0f;
97 | 
98 |     std::vector<torch::Tensor> offsets;
99 |     std::vector<torch::Tensor> descriptors;
100 |     offsets.reserve(batch_size);
101 |     descriptors.reserve(batch_size);
102 | 
103 |     for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx)
104 |     {
105 |         auto xi = x[batch_idx];
106 |         // Ensure xi is contiguous
107 |         if (!xi.is_contiguous())
108 |         {
109 |             xi = xi.contiguous();
110 |         }
111 | 
112 |         const auto& kptsi = keypoints[batch_idx];
113 |         auto kptsi_wh = (kptsi / 2 + 0.5) * wh;
114 |         const auto num_keypoints = kptsi_wh.size(0);
115 | 
116 |         torch::Tensor patch;
117 |         if (kernel_size_ > 1)
118 |         {
119 |             // Ensure inputs to get_patches_forward are contiguous
120 |             auto kptsi_wh_long = kptsi_wh.to(torch::kLong).contiguous();
121 |             patch = custom_ops::get_patches_forward(xi.contiguous(), kptsi_wh_long, kernel_size_);
122 |         } else
123 |         {
124 |             auto kptsi_wh_long = kptsi_wh.to(torch::kLong).contiguous();
125 |             patch = xi.index({Slice(),
126 |                               kptsi_wh_long.index({Slice(), 1}),
127 |                               kptsi_wh_long.index({Slice(), 0})})
128 |                         .transpose(0, 1)
129 |                         .reshape({num_keypoints, channels, 1, 1})
130 |                         .contiguous();
131 |         }
132 | 
133 |         // Predict per-keypoint sampling offsets from the patch and clamp them
134 |         auto offset = offset_conv_->forward(std::move(patch));
135 |         offset = offset.clamp(-max_offset, max_offset);
136 | 
137 |         torch::Tensor mask_weight;
138 |         if (mask_)
139 |         {
140 |             offset = offset.index({Slice(), Slice(), 0, 0})
141 |                          .view({num_keypoints, 3, n_pos_})
142 |                          .permute({0, 2, 1})
143 |                          .contiguous();
144 |             auto offset_xy = offset.index({Slice(), Slice(), Slice(None, 2)});
145 |             mask_weight = torch::sigmoid(offset.index({Slice(), Slice(), 2}));
146 |             offset = offset_xy;
147 |         } else
148 |         {
149 |             offset = offset.index({Slice(), Slice(), 0, 0})
150 |                          .view({num_keypoints, 2, n_pos_})
151 |                          .permute({0, 2, 1})
152 |                          .contiguous();
153 |         }
154 | 
155 |         offsets.push_back(offset);
156 | 
157 |         auto pos = kptsi_wh.unsqueeze(1) + offset;
158 |         pos = 2.0 * pos / wh - 1;
159 |         pos = pos.reshape({1, num_keypoints * n_pos_, 1, 2}).contiguous();
160 | 
161 |         auto features = torch::nn::functional::grid_sample(
162 |             xi.unsqueeze(0), pos,
163 |             torch::nn::functional::GridSampleFuncOptions()
164 |                 .mode(torch::kBilinear)
165 |                 .align_corners(true));
166 | 
167 |         features = features.reshape({channels, num_keypoints, n_pos_, 1})
168 |                        .permute({1, 0, 2, 3})
169 |                        .contiguous();
170 | 
171 |         if (mask_)
172 |         {
173 |             features = features * mask_weight.unsqueeze(1).unsqueeze(-1);
174 |         }
175 | 
176 |         features = torch::selu(sf_conv_->forward(std::move(features))).squeeze(-1);
177 | 
178 |         torch::Tensor descs;
179 |         if (!conv2D_)
180 |         {
181 |             descs = torch::einsum("ncp,pcd->nd", {features, agg_weights_});
182 |         } else
183 |         {
184 |             features = features.reshape({num_keypoints, -1}).unsqueeze(-1).unsqueeze(-1).contiguous();
185 |             descs = convM_->forward(std::move(features)).squeeze();
186 |         }
187 | 
188 |         descs = torch::nn::functional::normalize(std::move(descs),
189 |                                                  torch::nn::functional::NormalizeFuncOptions()
190 |                                                      .p(2)
191 |                                                      .dim(1));
192 | 
193 |         descriptors.push_back(std::move(descs));
194 |     }
195 | 
196 |     return std::make_tuple(std::move(descriptors), std::move(offsets));
197 | }
198 | 
199 | std::tuple<std::vector<torch::Tensor>, std::vector<torch::Tensor>>
200 | SDDH::forward(const torch::Tensor& x, std::vector<torch::Tensor>& keypoints) & {
201 |     auto x_copy = x.clone();
202 |     return std::move(*this).forward(std::move(x_copy), keypoints);
203 | }
--------------------------------------------------------------------------------