├── .clang-format ├── .gitattributes ├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── README.md ├── cmake └── aliked-config.cmake.in ├── examples └── main.cpp ├── include ├── feature │ ├── ALIKED.hpp │ ├── DKD.hpp │ ├── SDDH.hpp │ ├── blocks.hpp │ ├── cuda_helpers.h │ ├── deform_conv2d.h │ ├── get_patches.hpp │ ├── get_patches_cuda.h │ └── input_padder.hpp └── matcher │ └── lightglue │ ├── attention.hpp │ ├── core.hpp │ ├── encoding.hpp │ ├── matcher.hpp │ └── transformer.hpp ├── models ├── aliked-n16.pt ├── aliked-n16rot.pt ├── aliked-n32.pt ├── aliked-t16.pt └── aliked_lightglue.pt └── src ├── feature ├── ALIKED.cpp ├── DKD.cpp ├── SDDH.cpp ├── blocks.cpp ├── deform_conv2d.cpp ├── deform_conv2d_kernel.cu ├── get_patches.cpp ├── get_patches_cuda.cu └── input_padder.cpp └── matcher └── lightglue ├── attention.cpp ├── core.cpp ├── encoding.cpp ├── matcher.cpp └── transformer.cpp /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | Standard: Latest 4 | 5 | # Access modifiers and indentation 6 | AccessModifierOffset: -4 7 | IndentWidth: 4 8 | ContinuationIndentWidth: 4 9 | TabWidth: 4 10 | UseTab: Never 11 | NamespaceIndentation: All 12 | 13 | # Alignment 14 | AlignAfterOpenBracket: Align 15 | AlignConsecutiveAssignments: false 16 | AlignConsecutiveDeclarations: false 17 | AlignConsecutiveMacros: Consecutive 18 | AlignEscapedNewlines: Left 19 | AlignOperands: true 20 | AlignTrailingComments: true 21 | 22 | # Allow behaviors 23 | AllowAllArgumentsOnNextLine: false 24 | AllowAllParametersOfDeclarationOnNextLine: false 25 | AllowShortBlocksOnASingleLine: Always 26 | AllowShortCaseLabelsOnASingleLine: true 27 | AllowShortEnumsOnASingleLine: true 28 | AllowShortFunctionsOnASingleLine: All 29 | AllowShortLambdasOnASingleLine: All 30 | AllowShortIfStatementsOnASingleLine: Never 31 | AllowShortLoopsOnASingleLine: false 32 | AllowAllConstructorInitializersOnNextLine: false 33 | 34 | # Breaking and wrapping 35 | AlwaysBreakAfterDefinitionReturnType: None 36 | AlwaysBreakAfterReturnType: None 37 | AlwaysBreakBeforeMultilineStrings: false 38 | AlwaysBreakTemplateDeclarations: true 39 | BreakBeforeBinaryOperators: None 40 | BreakBeforeTernaryOperators: true 41 | BreakConstructorInitializersBeforeComma: false 42 | BreakConstructorInitializers: BeforeColon 43 | BreakInheritanceList: BeforeColon 44 | BreakStringLiterals: true 45 | ColumnLimit: 0 46 | 47 | # Brace wrapping 48 | BreakBeforeBraces: Custom 49 | BraceWrapping: 50 | AfterCaseLabel: false 51 | AfterClass: false 52 | AfterControlStatement: Always 53 | AfterEnum: false 54 | AfterFunction: false 55 | AfterNamespace: false 56 | AfterObjCDeclaration: false 57 | AfterStruct: false 58 | AfterUnion: false 59 | AfterExternBlock: false 60 | BeforeCatch: false 61 | BeforeElse: false 62 | IndentBraces: false 63 | SplitEmptyFunction: true 64 | SplitEmptyRecord: true 65 | SplitEmptyNamespace: true 66 | 67 | # Constructor initialization 68 | ConstructorInitializerIndentWidth: 4 69 | ConstructorInitializerAllOnOneLineOrOnePerLine: true 70 | 71 | # Empty lines and spacing 72 | EmptyLineBeforeAccessModifier: Always 73 | KeepEmptyLinesAtTheStartOfBlocks: true 74 | MaxEmptyLinesToKeep: 1 75 | SpaceAfterCStyleCast: false 76 | SpaceAfterTemplateKeyword: true 77 | SpaceBeforeAssignmentOperators: true 78 | SpaceBeforeParens: ControlStatements 79 | SpaceInEmptyParentheses: false 80 | SpacesBeforeTrailingComments: 1 81 | SpacesInAngles: false 82 | SpacesInContainerLiterals: true 83 | 
SpacesInCStyleCastParentheses: false 84 | SpacesInParentheses: false 85 | SpacesInSquareBrackets: false 86 | 87 | # Include ordering 88 | SortIncludes: CaseInsensitive 89 | IncludeBlocks: Regroup 90 | IncludeCategories: 91 | # C++ Standard Library headers 92 | - Regex: '^<(cctype|span|cstring|string|string_view|vector|map|fstream|typeindex|source_location|stacktrace|array|iostream|memory|future|stdexcept|algorithm|random|atomic|sstream|chrono|cstdint|expected|filesystem|functional|mutex|queue|optional|shared_mutex|thread|utility|variant|unordered_map|unordered_set|condition_variable)>$' 93 | Priority: 1 94 | SortPriority: 3 95 | # C Standard Library headers 96 | - Regex: '^<(ft2build\.h|GL/|GLFW/|glm/|spdlog/|fmt/).*>' 97 | Priority: 2 98 | SortPriority: 2 99 | - Regex: '^"(ft2build\.h|GL/|GLFW/|glm/|spdlog/|fmt/).*"' 100 | Priority: 2 101 | SortPriority: 2 102 | # All project headers 103 | - Regex: '.*' 104 | Priority: 3 105 | SortPriority: 1 106 | 107 | # Other settings 108 | Cpp11BracedListStyle: true 109 | DerivePointerAlignment: false 110 | FixNamespaceComments: true 111 | IndentCaseBlocks: false 112 | IndentCaseLabels: false 113 | IndentGotoLabels: false 114 | IndentPPDirectives: None 115 | IndentWrappedFunctionNames: false 116 | PointerAlignment: Left 117 | ReflowComments: true 118 | --- 119 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | models filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | src/code 2 | external/libtorch 3 | src/copies.sh 4 | logs 5 | build 6 | .idea 7 | cmake-* 8 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.26) 2 | project(LightGlue 3 | VERSION 1.0.0 4 | DESCRIPTION "C++ implementation of LightGlue" 5 | LANGUAGES CUDA CXX) 6 | 7 | # Set default build type if not specified 8 | if(NOT CMAKE_BUILD_TYPE) 9 | set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Choose the type of build." 
FORCE) 10 | endif() 11 | set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "RelWithDebInfo" "MinSizeRel") 12 | 13 | # Enable LTO/IPO only for Release builds 14 | include(CheckIPOSupported) 15 | check_ipo_supported(RESULT IPO_SUPPORTED OUTPUT IPO_ERROR) 16 | if(IPO_SUPPORTED AND CMAKE_BUILD_TYPE STREQUAL "Release") 17 | set(CMAKE_INTERPROCEDURAL_OPTIMIZATION ON) 18 | endif() 19 | 20 | # Core configuration 21 | set(CMAKE_CXX_STANDARD 20) 22 | set(CMAKE_CUDA_STANDARD 17) 23 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 24 | set(CMAKE_CUDA_STANDARD_REQUIRED ON) 25 | set(CMAKE_CUDA_ARCHITECTURES native) 26 | set(CMAKE_POSITION_INDEPENDENT_CODE ON) 27 | set(CMAKE_EXPORT_COMPILE_COMMANDS ON) 28 | 29 | # Configure paths 30 | set(LIGHTGLUE_MODELS_DIR "${CMAKE_CURRENT_SOURCE_DIR}/models" CACHE PATH "Path to model weights directory") 31 | 32 | # Find dependencies 33 | set(LIBTORCH_DIR "${CMAKE_CURRENT_SOURCE_DIR}/external/libtorch") 34 | set(CMAKE_PREFIX_PATH ${LIBTORCH_DIR}) 35 | 36 | find_package(Torch REQUIRED) 37 | find_package(OpenCV REQUIRED) 38 | find_package(CUDAToolkit REQUIRED) 39 | 40 | # Check CUDA version 41 | if(CUDAToolkit_VERSION VERSION_LESS "12.1") 42 | message(FATAL_ERROR "This project requires CUDA 12.1 or higher (found: ${CUDAToolkit_VERSION})") 43 | endif() 44 | 45 | # Debug flags for different compilers - removed problematic flags 46 | if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") 47 | set(DEBUG_FLAGS 48 | -g 49 | -Wall 50 | -Wextra 51 | -fno-omit-frame-pointer 52 | ) 53 | endif() 54 | 55 | # Release flags 56 | if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") 57 | set(RELEASE_FLAGS 58 | -O3 59 | -march=native 60 | -mtune=native 61 | -fomit-frame-pointer 62 | -ffast-math 63 | -DNDEBUG 64 | ) 65 | endif() 66 | 67 | # CUDA debug flags - simplified 68 | set(CUDA_DEBUG_FLAGS 69 | -G 70 | -g 71 | -lineinfo 72 | ) 73 | 74 | # CUDA release flags 75 | set(CUDA_RELEASE_FLAGS 76 | -O3 77 | --use_fast_math 78 | -DNDEBUG 79 | ) 80 | 81 | # Add models directory definition 82 | add_definitions(-DLIGHTGLUE_MODELS_DIR="${LIGHTGLUE_MODELS_DIR}") 83 | 84 | # Source files 85 | set(LIGHTGLUE_HEADERS 86 | include/feature/ALIKED.hpp 87 | include/feature/DKD.hpp 88 | include/feature/SDDH.hpp 89 | include/feature/blocks.hpp 90 | include/feature/get_patches.hpp 91 | include/feature/input_padder.hpp 92 | include/feature/deform_conv2d.h 93 | include/feature/get_patches_cuda.h 94 | include/feature/cuda_helpers.h 95 | 96 | include/matcher/lightglue/attention.hpp 97 | include/matcher/lightglue/core.hpp 98 | include/matcher/lightglue/encoding.hpp 99 | include/matcher/lightglue/matcher.hpp 100 | include/matcher/lightglue/transformer.hpp 101 | ) 102 | 103 | set(LIGHTGLUE_SOURCES 104 | src/feature/ALIKED.cpp 105 | src/feature/DKD.cpp 106 | src/feature/input_padder.cpp 107 | src/feature/get_patches.cpp 108 | src/feature/SDDH.cpp 109 | src/feature/deform_conv2d.cpp 110 | src/feature/deform_conv2d_kernel.cu 111 | src/feature/get_patches_cuda.cu 112 | 113 | src/matcher/lightglue/attention.cpp 114 | src/matcher/lightglue/core.cpp 115 | src/matcher/lightglue/encoding.cpp 116 | src/matcher/lightglue/matcher.cpp 117 | src/matcher/lightglue/transformer.cpp 118 | src/feature/blocks.cpp 119 | ) 120 | 121 | # Library target 122 | add_library(${PROJECT_NAME}_lib STATIC 123 | ${LIGHTGLUE_SOURCES} 124 | ${LIGHTGLUE_HEADERS} 125 | ) 126 | 127 | add_library(${PROJECT_NAME}::lib ALIAS ${PROJECT_NAME}_lib) 128 | 129 | target_include_directories(${PROJECT_NAME}_lib 130 | PUBLIC 131 | $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include> 132 | $<INSTALL_INTERFACE:include> 133 | ) 134 | 135 | # 
Configure compile options based on build type - simplified 136 | target_compile_options(${PROJECT_NAME}_lib 137 | PRIVATE 138 | $<$<AND:$<COMPILE_LANGUAGE:CXX>,$<CONFIG:Debug>>:${DEBUG_FLAGS}> 139 | $<$<AND:$<COMPILE_LANGUAGE:CXX>,$<CONFIG:Release>>:${RELEASE_FLAGS}> 140 | $<$<AND:$<COMPILE_LANGUAGE:CUDA>,$<CONFIG:Debug>>:${CUDA_DEBUG_FLAGS}> 141 | $<$<AND:$<COMPILE_LANGUAGE:CUDA>,$<CONFIG:Release>>:${CUDA_RELEASE_FLAGS}> 142 | ) 143 | 144 | target_link_libraries(${PROJECT_NAME}_lib 145 | PUBLIC 146 | ${TORCH_LIBRARIES} 147 | ${OpenCV_LIBS} 148 | PRIVATE 149 | CUDA::cudart 150 | CUDA::curand 151 | CUDA::cublas 152 | ) 153 | 154 | # Properties for debug/release configurations 155 | set_target_properties(${PROJECT_NAME}_lib PROPERTIES 156 | CUDA_SEPARABLE_COMPILATION ON 157 | CUDA_RESOLVE_DEVICE_SYMBOLS ON 158 | POSITION_INDEPENDENT_CODE ON 159 | DEBUG_POSTFIX "d" 160 | ) 161 | 162 | # Debug configuration specific settings - removed problematic definitions 163 | if(CMAKE_BUILD_TYPE STREQUAL "Debug") 164 | target_compile_definitions(${PROJECT_NAME}_lib 165 | PRIVATE 166 | DEBUG 167 | ) 168 | endif() 169 | 170 | # Example application 171 | add_executable(${PROJECT_NAME} examples/main.cpp) 172 | target_link_libraries(${PROJECT_NAME} PRIVATE ${PROJECT_NAME}::lib) 173 | 174 | # Output directories with debug/release suffixes 175 | set_target_properties(${PROJECT_NAME} ${PROJECT_NAME}_lib PROPERTIES 176 | RUNTIME_OUTPUT_DIRECTORY_DEBUG "${CMAKE_BINARY_DIR}/bin/debug" 177 | RUNTIME_OUTPUT_DIRECTORY_RELEASE "${CMAKE_BINARY_DIR}/bin/release" 178 | LIBRARY_OUTPUT_DIRECTORY_DEBUG "${CMAKE_BINARY_DIR}/lib/debug" 179 | LIBRARY_OUTPUT_DIRECTORY_RELEASE "${CMAKE_BINARY_DIR}/lib/release" 180 | ARCHIVE_OUTPUT_DIRECTORY_DEBUG "${CMAKE_BINARY_DIR}/lib/debug" 181 | ARCHIVE_OUTPUT_DIRECTORY_RELEASE "${CMAKE_BINARY_DIR}/lib/release" 182 | ) 183 | 184 | # Create models directory if it doesn't exist 185 | add_custom_target(create_models_dir ALL 186 | COMMAND ${CMAKE_COMMAND} -E make_directory ${LIGHTGLUE_MODELS_DIR} 187 | ) 188 | 189 | # Print configuration information 190 | message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") 191 | message(STATUS "Models directory: ${LIGHTGLUE_MODELS_DIR}") -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 ETH Zurich 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LightGlue C++ Implementation 2 | 3 | This repository contains a C++ implementation of [LightGlue](https://github.com/cvg/LightGlue) using LibTorch. LightGlue is a lightweight feature matcher that achieves state-of-the-art performance while being significantly faster than traditional approaches. 4 | 5 | If you are interested in collaborating on this project, would like to reach out to me, or are considering contributing to the discussion of the overall NeRF/GS pipeline, please join the Discord: https://discord.gg/NqwTqVYVmj 6 | 7 | ## Features 8 | 9 | - Complete C++ implementation of LightGlue using LibTorch 10 | - CUDA acceleration support 11 | - Integration with OpenCV for image handling 12 | - [ALIKED](https://github.com/MrNeRF/ALIKED_CPP) feature extractor implementation 13 | - Efficient memory management with move semantics 14 | - Visualization support for matches and pruning 15 | 16 | ## Prerequisites 17 | 18 | - CMake >= 3.26 19 | - CUDA >= 12.1 20 | - LibTorch (C++ distribution of PyTorch) 21 | - OpenCV 22 | - C++20 compliant compiler 23 | - git-lfs (for model weights) 24 | 25 | ## Building 26 | Unpack the LibTorch distribution into `external/libtorch` first; `CMakeLists.txt` points `CMAKE_PREFIX_PATH` at that directory. Then: 27 | ```bash 28 | mkdir build && cd build 29 | cmake .. 30 | make -j 31 | ``` 32 | 33 | ## Usage 34 | 35 | The repository includes a sample application demonstrating feature matching between two images (a condensed API sketch is also shown further below): 36 | 37 | ```bash 38 | ./LightGlue path/to/image1.jpg path/to/image2.jpg 39 | ``` 40 | 41 | ## Project Structure 42 | 43 | ``` 44 | . 45 | ├── include/ 46 | │ ├── feature/ # Feature extraction components 47 | │ └── matcher/ # LightGlue matcher implementation 48 | ├── src/ 49 | │ ├── feature/ # Feature extraction implementations 50 | │ └── matcher/ # Matcher implementations 51 | ├── examples/ # Example applications 52 | ├── models/ # Directory for model weights 53 | └── CMakeLists.txt 54 | ``` 55 | 56 | ## Implementation Details 57 | 58 | The implementation follows the original Python architecture while leveraging C++ and LibTorch features: 59 | - CUDA optimizations for performance 60 | - Move semantics for efficient memory handling 61 | - LibTorch's automatic differentiation (though primarily used for inference) 62 | - OpenCV integration for image processing and visualization 63 | 64 | ## Model Weights 65 | 66 | Place the model weights in the `models/` directory; they are tracked with git-lfs, so run `git lfs pull` after cloning. The following models are supported: 67 | - ALIKED feature extractor weights 68 | - LightGlue matcher weights 69 | 70 | ## Future Development 71 | 72 | ### TODO 73 | - [ ] Batch Processing Support 74 | - Implement efficient batch processing for multiple image pairs 75 | - Optimize memory usage for batch operations 76 | - Add batch-specific configuration options 77 | 78 | - [ ] Flash Attention Implementation 79 | - Add efficient Flash Attention mechanism 80 | - Optimize for different GPU architectures 81 | - Implement memory-efficient attention patterns 82 | 83 | ## Contributing 84 | 85 | Contributions are welcome! Please feel free to submit pull requests or create issues for bugs and feature requests. 
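## Minimal API Example

The snippet below is a condensed sketch of the flow in `examples/main.cpp`: extract ALIKED features for two images, hand them to the LightGlue matcher, and count the resulting matches. Class names follow the headers under `include/`, and the dictionary keys (`keypoints`, `image_size`, `matches0`) are the ones used by the bundled example; error handling and the visualization code are omitted, so treat this as a sketch rather than a drop-in program.

```cpp
#include "feature/ALIKED.hpp"
#include "matcher/lightglue/matcher.hpp"

#include <opencv2/opencv.hpp>
#include <torch/torch.h>

#include <iostream>
#include <memory>

int main(int argc, char* argv[]) {
    if (argc != 3) {
        std::cerr << "Usage: " << argv[0] << " <image1> <image2>" << std::endl;
        return 1;
    }

    const torch::Device device = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;

    // Feature extractor and matcher; weights are loaded from the models/ directory.
    auto extractor = std::make_shared<ALIKED>("aliked-n16", device.str());
    auto matcher = std::make_shared<matcher::LightGlue>();
    matcher->to(device);

    // Load the images and convert BGR -> RGB (the extractor expects RGB input).
    cv::Mat img0 = cv::imread(argv[1]);
    cv::Mat img1 = cv::imread(argv[2]);
    cv::cvtColor(img0, img0, cv::COLOR_BGR2RGB);
    cv::cvtColor(img1, img1, cv::COLOR_BGR2RGB);

    // Per-image features: keypoints, descriptors, scores, ...
    auto feats0 = extractor->run(img0);
    auto feats1 = extractor->run(img1);

    // The matcher also needs the original image sizes.
    feats0.insert("image_size", torch::tensor({static_cast<float>(img0.cols), static_cast<float>(img0.rows)}).unsqueeze(0));
    feats1.insert("image_size", torch::tensor({static_cast<float>(img1.cols), static_cast<float>(img1.rows)}).unsqueeze(0));

    // "matches0"[0][i] is the index of the keypoint in image 1 matched to keypoint i of image 0 (-1 = unmatched).
    auto matches01 = matcher->forward(feats0, feats1);
    const auto valid = (matches01.at("matches0") != -1).sum().item<int64_t>();
    std::cout << "valid matches: " << valid << std::endl;
    return 0;
}
```

See `examples/main.cpp` for the full application, including match visualization and the pruning overlay.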
86 | 87 | ## Citations 88 | 89 | If you use this implementation, please cite both the original LightGlue paper and the C++ implementation: 90 | 91 | ```bibtex 92 | @inproceedings{lindenberger2023lightglue, 93 | author = {Philipp Lindenberger and 94 | Paul-Edouard Sarlin and 95 | Marc Pollefeys}, 96 | title = {{LightGlue: Local Feature Matching at Light Speed}}, 97 | booktitle = {ICCV}, 98 | year = {2023} 99 | } 100 | 101 | @misc{patas2024lightgluecpp, 102 | author = {Janusch Patas}, 103 | title = {LightGlue C++ Implementation}, 104 | year = {2024}, 105 | publisher = {GitHub}, 106 | journal = {GitHub Repository}, 107 | howpublished = {\url{https://github.com/MrNeRF/Light_Glue_CPP}} 108 | } 109 | ``` -------------------------------------------------------------------------------- /cmake/aliked-config.cmake.in: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrNeRF/Light_Glue_CPP/bd825e8d76f88c024ec37b41c6178e2b873059de/cmake/aliked-config.cmake.in -------------------------------------------------------------------------------- /examples/main.cpp: -------------------------------------------------------------------------------- 1 | #include "feature/ALIKED.hpp" 2 | #include "matcher/lightglue/matcher.hpp" 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | // Helper function to load and preprocess image 10 | cv::Mat load_image(const std::string& path) { 11 | cv::Mat img = cv::imread(path); 12 | if (img.empty()) 13 | { 14 | throw std::runtime_error("Failed to load image: " + path); 15 | } 16 | 17 | // Convert BGR to RGB 18 | cv::Mat img_rgb; 19 | cv::cvtColor(img, img_rgb, cv::COLOR_BGR2RGB); 20 | return img_rgb; 21 | } 22 | 23 | // Helper function to generate a colormap for pruning visualization 24 | cv::Scalar cm_prune(float value, float max_value) { 25 | float norm_value = (value == max_value) ? 
-1.0f : (value - 1.0f) / 9.0f; 26 | if (norm_value < 0) 27 | { // Blue for pruned points 28 | return cv::Scalar(255, 0, 0); 29 | } 30 | float green = std::min(1.0f, std::max(0.0f, norm_value * 2.0f)); 31 | float red = 1.0f - green; 32 | return {0, green * 255, red * 255}; 33 | } 34 | 35 | void draw_matches_and_prune(cv::Mat& img1, cv::Mat& img2, 36 | const torch::Tensor& kpts0, const torch::Tensor& kpts1, 37 | const torch::Tensor& matches, const torch::Tensor& scores, 38 | const torch::Tensor& prune0, const torch::Tensor& prune1, 39 | const long stop_layer) { 40 | int height = std::max(img1.rows, img2.rows); 41 | int width = img1.cols + img2.cols; 42 | cv::Mat output(height, width, CV_8UC3); 43 | img1.copyTo(output(cv::Rect(0, 0, img1.cols, img1.rows))); 44 | img2.copyTo(output(cv::Rect(img1.cols, 0, img2.cols, img2.rows))); 45 | 46 | // Move tensors to CPU 47 | const auto kpts0_cpu = kpts0.cpu(); 48 | const auto kpts1_cpu = kpts1.cpu(); 49 | const auto matches_cpu = matches.cpu(); 50 | const auto scores_cpu = scores.cpu(); 51 | const auto prune0_cpu = prune0.cpu(); 52 | const auto prune1_cpu = prune1.cpu(); 53 | 54 | // Debug first few values 55 | const auto max_prune0 = prune0_cpu.flatten().max().item(); 56 | const auto max_prune1 = prune1_cpu.flatten().max().item(); 57 | 58 | // Get number of matches from the second dimension 59 | const auto num_matches = matches_cpu.size(1); 60 | 61 | const float min_score_threshold = 0.1f; 62 | // Draw matches 63 | for (int i = 0; i < num_matches; i++) 64 | { 65 | // Get the match index directly 66 | const auto idx = matches_cpu[0][i].item(); 67 | const auto idx0 = i; 68 | const auto idx1 = idx; 69 | 70 | // Get the score for this match 71 | const auto score = scores_cpu[0][i].item(); 72 | 73 | // Skip low confidence matches 74 | if (score < min_score_threshold) 75 | { 76 | continue; 77 | } 78 | 79 | // Debug keypoint indices 80 | if (idx0 >= kpts0_cpu.size(0) || idx1 >= kpts1_cpu.size(0)) 81 | { 82 | std::cerr << "Error: Keypoint index out of range. 
idx0: " << idx0 << ", idx1: " << idx1 << std::endl; 83 | continue; 84 | } 85 | 86 | const auto x0 = kpts0_cpu[idx0][0].item(); 87 | const auto y0 = kpts0_cpu[idx0][1].item(); 88 | const auto x1 = kpts1_cpu[idx1][0].item(); 89 | const auto y1 = kpts1_cpu[idx1][1].item(); 90 | 91 | cv::Point2f pt1(x0, y0); 92 | cv::Point2f pt2(x1 + img1.cols, y1); 93 | 94 | // Draw the match only if score is above threshold 95 | cv::Scalar color(0, 255 * score, 0); 96 | cv::line(output, pt1, pt2, color, 1, cv::LINE_AA); 97 | cv::circle(output, pt1, 3, color, -1, cv::LINE_AA); 98 | cv::circle(output, pt2, 3, color, -1, cv::LINE_AA); 99 | } 100 | 101 | // Visualize pruning (uncomment) 102 | //for (int i = 0; i < kpts0_cpu.size(0); i++) 103 | //{ 104 | // const auto x0 = kpts0_cpu[i][0].item(); 105 | // const auto y0 = kpts0_cpu[i][1].item(); 106 | // cv::Scalar color = cm_prune(prune0_cpu[0][i].item(), max_prune0); 107 | // cv::circle(output, cv::Point2f(x0, y0), 5, color, -1, cv::LINE_AA); 108 | //} 109 | // for (int i = 0; i < kpts1_cpu.size(0); i++) 110 | //{ 111 | // const auto x1 = kpts1_cpu[i][0].item() + img1.cols; 112 | // const auto y1 = kpts1_cpu[i][1].item(); 113 | // cv::Scalar color = cm_prune(prune1_cpu[0][i].item(), max_prune1); 114 | // cv::circle(output, cv::Point2f(x1, y1), 5, color, -1, cv::LINE_AA); 115 | //} 116 | 117 | // Add text annotation 118 | std::string text = "Stopped after " + std::to_string(stop_layer) + " layers"; 119 | cv::putText(output, text, cv::Point(10, 30), cv::FONT_HERSHEY_SIMPLEX, 1.0, cv::Scalar(255, 255, 255), 2); 120 | 121 | // Show and save the result 122 | cv::imshow("Matches and Pruning", output); 123 | cv::waitKey(0); 124 | cv::imwrite("matches_and_pruning.png", output); 125 | } 126 | 127 | int main(int argc, char* argv[]) { 128 | if (argc != 3) 129 | { 130 | std::cerr << "Usage: " << argv[0] << " " << std::endl; 131 | return 1; 132 | } 133 | 134 | try 135 | { 136 | // Device selection 137 | torch::Device device = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU; 138 | std::cout << "Using device: " << (device.is_cuda() ? "CUDA" : "CPU") << std::endl; 139 | 140 | // Initialize models 141 | auto extractor = std::make_shared("aliked-n16", device.str()); 142 | auto matcher = std::make_shared(); 143 | 144 | // Move matcher to device 145 | matcher->to(device); 146 | 147 | // Load and process images 148 | cv::Mat img0 = load_image(argv[1]); 149 | cv::Mat img1 = load_image(argv[2]); 150 | 151 | // Extract features 152 | std::cout << "Extracting features..." << std::endl; 153 | auto feats0 = extractor->run(img0); 154 | auto feats1 = extractor->run(img1); 155 | 156 | // Match features 157 | std::cout << "Matching features..." 
<< std::endl; 158 | feats0.insert("image_size", torch::tensor({static_cast(img0.cols), static_cast(img0.rows)}, torch::kFloat32).unsqueeze(0)); 159 | feats1.insert("image_size", torch::tensor({static_cast(img1.cols), static_cast(img1.rows)}, torch::kFloat32).unsqueeze(0)); 160 | auto matches01 = matcher->forward(feats0, feats1); 161 | 162 | // Get keypoints, matches, scores, and pruning information 163 | const auto& kpts0 = feats0.at("keypoints"); 164 | const auto& kpts1 = feats1.at("keypoints"); 165 | const auto& matches = matches01.at("matches0"); 166 | const auto& matching_scores = matches01.at("matching_scores0"); 167 | const auto& prune0 = matches01.at("prune0"); 168 | const auto& prune1 = matches01.at("prune1"); 169 | const auto stop_layer = matches01.at("stop").item(); 170 | 171 | // Print statistics 172 | std::cout << "Number of keypoints in image 0: " << kpts0.size(0) << std::endl; 173 | std::cout << "Number of keypoints in image 1: " << kpts1.size(0) << std::endl; 174 | 175 | // Count valid matches (where matches != -1) 176 | auto valid_matches = (matches != -1).sum().item(); 177 | std::cout << "Number of valid matches: " << valid_matches << std::endl; 178 | std::cout << "Stopped after " << stop_layer << " layers" << std::endl; 179 | 180 | // Visualize matches and pruning 181 | draw_matches_and_prune(img0, img1, kpts0, kpts1, matches, matching_scores, prune0, prune1, stop_layer); 182 | 183 | } catch (const std::exception& e) 184 | { 185 | std::cerr << "Error: " << e.what() << std::endl; 186 | return 1; 187 | } 188 | 189 | return 0; 190 | } 191 | -------------------------------------------------------------------------------- /include/feature/ALIKED.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "blocks.hpp" 4 | #include "input_padder.hpp" 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | struct AlikedConfig { 14 | int c1, c2, c3, c4, dim, K, M; 15 | }; 16 | 17 | class DKD; 18 | class SDDH; 19 | 20 | // Static configuration map 21 | inline const std::unordered_map ALIKED_CFGS = { 22 | {"aliked-t16", {8, 16, 32, 64, 64, 3, 16}}, 23 | {"aliked-n16", {16, 32, 64, 128, 128, 3, 16}}, 24 | {"aliked-n16rot", {16, 32, 64, 128, 128, 3, 16}}, 25 | {"aliked-n32", {16, 32, 64, 128, 128, 3, 32}}}; 26 | 27 | class ALIKED : public torch::nn::Module { 28 | public: 29 | explicit ALIKED(std::string_view model_name = "aliked-n32", 30 | std::string_view device = "cuda", 31 | int top_k = -1, 32 | float scores_th = 0.2, 33 | int n_limit = 20000); 34 | 35 | // Move semantics for tensor operations 36 | std::tuple 37 | extract_dense_map(torch::Tensor image) &&; 38 | 39 | torch::Dict 40 | forward(torch::Tensor image) &&; 41 | 42 | torch::Dict 43 | forward(const torch::Tensor& image) &; 44 | 45 | torch::Dict run(cv::Mat& img_rgb); 46 | 47 | private: 48 | void init_layers(std::string_view model_name); 49 | void load_weights(std::string_view model_name); 50 | void load_parameters(std::string_view pt_pth); 51 | 52 | static std::vector get_the_bytes(std::string_view filename); 53 | 54 | torch::nn::AvgPool2d pool2_{nullptr}, pool4_{nullptr}; 55 | std::shared_ptr block1_; 56 | std::shared_ptr block2_; 57 | std::shared_ptr block3_; 58 | std::shared_ptr block4_; 59 | torch::nn::Conv2d conv1_{nullptr}, conv2_{nullptr}, 60 | conv3_{nullptr}, conv4_{nullptr}; 61 | torch::nn::Sequential score_head_{nullptr}; 62 | 63 | std::shared_ptr dkd_; 64 | std::shared_ptr desc_head_; 65 | 66 | torch::Device device_; 
67 | int dim_{}; 68 | }; -------------------------------------------------------------------------------- /include/feature/DKD.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | #include 5 | 6 | class DKD : public torch::nn::Module { 7 | public: 8 | DKD(int radius = 2, int top_k = -1, float scores_th = 0.2, int n_limit = 20000); 9 | 10 | std::tuple, std::vector, std::vector> 11 | detect_keypoints(torch::Tensor scores_map, bool sub_pixel = true) &&; 12 | 13 | std::tuple, std::vector, std::vector> 14 | detect_keypoints(const torch::Tensor& scores_map, bool sub_pixel = true) &; 15 | 16 | torch::Tensor simple_nms(torch::Tensor scores, int nms_radius) &&; 17 | torch::Tensor simple_nms(const torch::Tensor& scores, int nms_radius) &; 18 | 19 | std::tuple, std::vector, std::vector> 20 | forward(torch::Tensor scores_map, bool sub_pixel = true) &&; 21 | 22 | std::tuple, std::vector, std::vector> 23 | forward(const torch::Tensor& scores_map, bool sub_pixel = true) &; 24 | 25 | private: 26 | static constexpr int calculateKernelSize(int radius) { return 2 * radius + 1; } 27 | 28 | const int radius_; 29 | const int top_k_; 30 | const float scores_th_; 31 | const int n_limit_; 32 | const int kernel_size_; 33 | const float temperature_; 34 | torch::nn::Unfold unfold_{nullptr}; 35 | torch::Tensor hw_grid_; 36 | }; -------------------------------------------------------------------------------- /include/feature/SDDH.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | #include 5 | 6 | class SDDH : public torch::nn::Module { 7 | public: 8 | SDDH(int dims, int kernel_size = 3, int n_pos = 8, 9 | bool conv2D = false, bool mask = false); 10 | 11 | // Overloaded forward with move semantics 12 | std::tuple, std::vector> 13 | forward(torch::Tensor x, std::vector& keypoints) &&; 14 | 15 | std::tuple, std::vector> 16 | forward(const torch::Tensor& x, std::vector& keypoints) &; 17 | 18 | private: 19 | const int kernel_size_; 20 | const int n_pos_; 21 | const bool conv2D_; 22 | const bool mask_; 23 | torch::nn::Sequential offset_conv_{nullptr}; 24 | torch::nn::Conv2d sf_conv_{nullptr}; 25 | torch::nn::Conv2d convM_{nullptr}; 26 | torch::Tensor agg_weights_; 27 | 28 | // Helper functions for processing features 29 | torch::Tensor process_features(torch::Tensor features, int64_t num_keypoints) &&; 30 | torch::Tensor process_features(const torch::Tensor& features, int64_t num_keypoints) &; 31 | }; -------------------------------------------------------------------------------- /include/feature/blocks.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | class DeformableConv2d : public torch::nn::Module { 8 | public: 9 | DeformableConv2d(int in_channels, int out_channels, 10 | int kernel_size = 3, int stride = 1, 11 | int padding = 1, bool bias = false); 12 | 13 | torch::Tensor forward(const torch::Tensor& x) &; 14 | torch::Tensor forward(torch::Tensor x) &&; 15 | 16 | private: 17 | torch::nn::Conv2d offset_conv_{nullptr}; 18 | torch::nn::Conv2d regular_conv_{nullptr}; 19 | int padding_; 20 | int groups_ = 1; 21 | int mask_offset_ = 1; 22 | }; 23 | 24 | class ConvBlock : public torch::nn::Module { 25 | public: 26 | ConvBlock(int in_channels, int out_channels, 27 | std::string_view conv_type = "conv", 28 | bool mask = false); 29 | 30 | torch::Tensor forward(torch::Tensor x) &&; 31 | 
torch::Tensor forward(const torch::Tensor& x) &; 32 | 33 | private: 34 | torch::nn::Conv2d conv1_{nullptr}, conv2_{nullptr}; 35 | std::shared_ptr deform1_{nullptr}, deform2_{nullptr}; 36 | torch::nn::BatchNorm2d bn1_{nullptr}, bn2_{nullptr}; 37 | }; 38 | 39 | class ResBlock : public torch::nn::Module { 40 | public: 41 | ResBlock(int inplanes, int planes, int stride = 1, 42 | const torch::nn::Conv2d& downsample = nullptr, 43 | std::string_view conv_type = "conv"); 44 | 45 | torch::Tensor forward(torch::Tensor x) &&; 46 | torch::Tensor forward(const torch::Tensor& x) &; 47 | 48 | private: 49 | torch::nn::Conv2d conv1_{nullptr}, conv2_{nullptr}; 50 | std::shared_ptr deform1_{nullptr}, deform2_{nullptr}; 51 | torch::nn::BatchNorm2d bn1_{nullptr}, bn2_{nullptr}; 52 | torch::nn::Conv2d downsample_; 53 | }; -------------------------------------------------------------------------------- /include/feature/cuda_helpers.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace vision { 4 | namespace ops { 5 | 6 | #define CUDA_1D_KERNEL_LOOP_T(i, n, index_t) \ 7 | for (index_t i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); \ 8 | i += (blockDim.x * gridDim.x)) 9 | 10 | #define CUDA_1D_KERNEL_LOOP(i, n) CUDA_1D_KERNEL_LOOP_T(i, n, int) 11 | 12 | template 13 | constexpr __host__ __device__ inline integer ceil_div(integer n, integer m) { 14 | return (n + m - 1) / m; 15 | } 16 | 17 | } // namespace ops 18 | } // namespace vision 19 | -------------------------------------------------------------------------------- /include/feature/deform_conv2d.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace vision { 6 | namespace ops { 7 | 8 | at::Tensor deform_conv2d( 9 | const at::Tensor& input, 10 | const at::Tensor& weight, 11 | const at::Tensor& offset, 12 | const at::Tensor& mask, 13 | const at::Tensor& bias, 14 | int64_t stride_h, 15 | int64_t stride_w, 16 | int64_t pad_h, 17 | int64_t pad_w, 18 | int64_t dilation_h, 19 | int64_t dilation_w, 20 | int64_t groups, 21 | int64_t offset_groups, 22 | bool use_mask); 23 | 24 | at::Tensor deform_conv2d_symint( 25 | const at::Tensor& input, 26 | const at::Tensor& weight, 27 | const at::Tensor& offset, 28 | const at::Tensor& mask, 29 | const at::Tensor& bias, 30 | c10::SymInt stride_h, 31 | c10::SymInt stride_w, 32 | c10::SymInt pad_h, 33 | c10::SymInt pad_w, 34 | c10::SymInt dilation_h, 35 | c10::SymInt dilation_w, 36 | c10::SymInt groups, 37 | c10::SymInt offset_groups, 38 | bool use_mask); 39 | 40 | namespace detail { 41 | 42 | std::tuple 43 | _deform_conv2d_backward( 44 | const at::Tensor& grad, 45 | const at::Tensor& input, 46 | const at::Tensor& weight, 47 | const at::Tensor& offset, 48 | const at::Tensor& mask, 49 | const at::Tensor& bias, 50 | int64_t stride_h, 51 | int64_t stride_w, 52 | int64_t pad_h, 53 | int64_t pad_w, 54 | int64_t dilation_h, 55 | int64_t dilation_w, 56 | int64_t groups, 57 | int64_t offset_groups, 58 | bool use_mask); 59 | 60 | std::tuple 61 | _deform_conv2d_backward_symint( 62 | const at::Tensor& grad, 63 | const at::Tensor& input, 64 | const at::Tensor& weight, 65 | const at::Tensor& offset, 66 | const at::Tensor& mask, 67 | const at::Tensor& bias, 68 | c10::SymInt stride_h, 69 | c10::SymInt stride_w, 70 | c10::SymInt pad_h, 71 | c10::SymInt pad_w, 72 | c10::SymInt dilation_h, 73 | c10::SymInt dilation_w, 74 | c10::SymInt groups, 75 | c10::SymInt offset_groups, 76 | bool use_mask); 77 | 78 
| } // namespace detail 79 | 80 | } // namespace ops 81 | } // namespace vision 82 | -------------------------------------------------------------------------------- /include/feature/get_patches.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace custom_ops { 6 | 7 | torch::Tensor get_patches_forward(const torch::Tensor& map, torch::Tensor& points, int64_t radius); 8 | torch::Tensor get_patches_backward(const torch::Tensor& d_patches, torch::Tensor& points, int64_t H, int64_t W); 9 | } // namespace custom_ops 10 | -------------------------------------------------------------------------------- /include/feature/get_patches_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | // CUDA declarations 4 | torch::Tensor get_patches_forward_cuda(const torch::Tensor& map, torch::Tensor& points, int64_t radius); 5 | torch::Tensor get_patches_backward_cuda(const torch::Tensor& d_patches, torch::Tensor& points, int64_t H, int64_t W); 6 | -------------------------------------------------------------------------------- /include/feature/input_padder.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | class InputPadder { 8 | public: 9 | InputPadder(int h, int w, int div_by = 8) 10 | : ht_(h), 11 | wd_(w) { 12 | int pad_ht = (((ht_ / div_by) + 1) * div_by - ht_) % div_by; 13 | int pad_wd = (((wd_ / div_by) + 1) * div_by - wd_) % div_by; 14 | 15 | pad_ = {pad_wd / 2, pad_wd - pad_wd / 2, 16 | pad_ht / 2, pad_ht - pad_ht / 2}; 17 | } 18 | 19 | // Move semantics for pad operation 20 | torch::Tensor pad(torch::Tensor x) &&; 21 | torch::Tensor pad(const torch::Tensor& x) &; 22 | 23 | // Move semantics for unpad operation 24 | torch::Tensor unpad(torch::Tensor x) &&; 25 | 26 | private: 27 | int ht_; 28 | int wd_; 29 | std::array pad_; 30 | }; -------------------------------------------------------------------------------- /include/matcher/lightglue/attention.hpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | namespace matcher { 4 | class TokenConfidence : public torch::nn::Module { 5 | public: 6 | explicit TokenConfidence(int dim); 7 | 8 | // Returns confidence scores for both descriptors 9 | std::tuple forward( 10 | const torch::Tensor& desc0, 11 | const torch::Tensor& desc1); 12 | 13 | private: 14 | torch::nn::Sequential token_{nullptr}; 15 | }; 16 | 17 | // Attention Module 18 | class Attention : public torch::nn::Module { 19 | public: 20 | explicit Attention(bool allow_flash); 21 | 22 | torch::Tensor forward( 23 | const torch::Tensor& q, 24 | const torch::Tensor& k, 25 | const torch::Tensor& v); 26 | 27 | private: 28 | bool enable_flash_; 29 | bool has_sdp_; 30 | }; 31 | 32 | // Self-Attention Block 33 | class SelfBlock : public torch::nn::Module { 34 | public: 35 | SelfBlock(int embed_dim, int num_heads, bool flash = false, bool bias = true); 36 | 37 | torch::Tensor apply_cached_rotary_emb( 38 | const torch::Tensor& freqs, const torch::Tensor& t); 39 | 40 | torch::Tensor forward( 41 | const torch::Tensor& x, 42 | const torch::Tensor& encoding); 43 | 44 | private: 45 | int embed_dim_; 46 | int num_heads_; 47 | int head_dim_; 48 | torch::nn::Linear Wqkv_{nullptr}; 49 | std::shared_ptr inner_attn_; 50 | torch::nn::Linear out_proj_{nullptr}; 51 | torch::nn::Sequential ffn_{nullptr}; 52 | torch::Tensor 
rotate_half(const torch::Tensor& x); 53 | }; 54 | 55 | // Cross-Attention Block 56 | class CrossBlock : public torch::nn::Module { 57 | public: 58 | CrossBlock(int embed_dim, int num_heads, bool flash = false, bool bias = true); 59 | 60 | std::tuple forward( 61 | const torch::Tensor& x0, 62 | const torch::Tensor& x1, 63 | const torch::optional& mask = torch::nullopt); 64 | 65 | private: 66 | int heads_; 67 | float scale_; 68 | torch::nn::Linear to_qk_{nullptr}; 69 | torch::nn::Linear to_v_{nullptr}; 70 | torch::nn::Linear to_out_{nullptr}; 71 | torch::nn::Sequential ffn_{nullptr}; 72 | std::shared_ptr flash_; 73 | }; 74 | 75 | } 76 | 77 | -------------------------------------------------------------------------------- /include/matcher/lightglue/core.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | namespace matcher { 6 | struct LightGlueConfig { 7 | static constexpr float DEFAULT_DEPTH_CONFIDENCE = 0.95f; 8 | static constexpr float DEFAULT_WIDTH_CONFIDENCE = 0.99f; 9 | static constexpr float DEFAULT_FILTER_THRESHOLD = 0.1f; 10 | 11 | std::string name{"lightglue"}; 12 | int input_dim{128}; 13 | int descriptor_dim{256}; 14 | bool add_scale_ori{false}; 15 | int n_layers{9}; 16 | int num_heads{4}; 17 | bool flash{false}; 18 | bool mp{false}; 19 | float depth_confidence{DEFAULT_DEPTH_CONFIDENCE}; 20 | float width_confidence{DEFAULT_WIDTH_CONFIDENCE}; 21 | float filter_threshold{DEFAULT_FILTER_THRESHOLD}; 22 | std::string weights; 23 | }; 24 | } 25 | 26 | namespace matcher::utils { 27 | torch::Tensor normalize_keypoints(const torch::Tensor& kpts, 28 | const torch::optional& size = torch::nullopt); 29 | 30 | std::tuple 31 | filter_matches(const torch::Tensor& scores, float threshold); 32 | } -------------------------------------------------------------------------------- /include/matcher/lightglue/encoding.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace matcher { 5 | // Learnable Fourier Positional Encoding 6 | class LearnableFourierPosEnc : public torch::nn::Module { 7 | public: 8 | LearnableFourierPosEnc(int M, int dim, torch::optional F_dim = torch::nullopt, float gamma = 1.0); 9 | 10 | // Forward function returns the position encoding 11 | torch::Tensor forward(const torch::Tensor& x); 12 | 13 | private: 14 | float gamma_; 15 | torch::nn::Linear Wr_{nullptr}; 16 | }; 17 | } 18 | 19 | -------------------------------------------------------------------------------- /include/matcher/lightglue/matcher.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include "matcher/lightglue/core.hpp" 4 | 5 | 6 | namespace matcher { 7 | class MatchAssignment : public torch::nn::Module { 8 | public: 9 | explicit MatchAssignment(int dim); 10 | 11 | torch::Tensor sigmoid_log_double_softmax( 12 | const torch::Tensor& sim, 13 | const torch::Tensor& z0, 14 | const torch::Tensor& z1); 15 | 16 | torch::Tensor forward( 17 | const torch::Tensor& desc0, 18 | const torch::Tensor& desc1); 19 | 20 | torch::Tensor get_matchability(const torch::Tensor& desc); 21 | 22 | private: 23 | int dim_; 24 | torch::nn::Linear matchability_{nullptr}; 25 | torch::nn::Linear final_proj_{nullptr}; 26 | }; 27 | 28 | 29 | class LearnableFourierPosEnc; 30 | class TokenConfidence; 31 | class TransformerLayer; 32 | class MatchAssignment; 33 | 34 | class LightGlue : public torch::nn::Module { 35 | public: 
36 | explicit LightGlue(const std::string& feature_type = "aliked", 37 | const LightGlueConfig& config = LightGlueConfig()); 38 | 39 | // Main forward function to process features and find matches 40 | torch::Dict forward( 41 | const torch::Dict& data0, 42 | const torch::Dict& data1); 43 | 44 | // Method to move all components to specified device 45 | void to(const torch::Device& device); 46 | 47 | private: 48 | torch::Tensor get_pruning_mask( 49 | const torch::optional& confidences, 50 | const torch::Tensor& scores, 51 | int layer_index); 52 | 53 | bool check_if_stop( 54 | const torch::Tensor& confidences0, 55 | const torch::Tensor& confidences1, 56 | int layer_index, 57 | int num_points); 58 | 59 | void load_weights(const std::string& feature_type); 60 | 61 | private: 62 | LightGlueConfig config_; 63 | torch::Device device_; 64 | 65 | // Neural network components 66 | torch::nn::Linear input_proj_{nullptr}; 67 | std::shared_ptr posenc_; 68 | std::vector> transformers_; 69 | std::vector> log_assignment_; 70 | std::vector> token_confidence_; 71 | std::vector confidence_thresholds_; 72 | 73 | static const std::unordered_map pruning_keypoint_thresholds_; 74 | void load_parameters(const std::string& pt_path); 75 | std::vector get_the_bytes(const std::string& filename); 76 | }; 77 | } 78 | -------------------------------------------------------------------------------- /include/matcher/lightglue/transformer.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace matcher{ 7 | class SelfBlock; 8 | class CrossBlock; 9 | 10 | class TransformerLayer : public torch::nn::Module { 11 | public: 12 | TransformerLayer(int embed_dim, int num_heads, bool flash = false, bool bias = true); 13 | 14 | std::tuple forward( 15 | const torch::Tensor& desc0, 16 | const torch::Tensor& desc1, 17 | const torch::Tensor& encoding0, 18 | const torch::Tensor& encoding1); 19 | 20 | private: 21 | std::shared_ptr self_attn_; 22 | std::shared_ptr cross_attn_; 23 | }; 24 | } 25 | -------------------------------------------------------------------------------- /models/aliked-n16.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrNeRF/Light_Glue_CPP/bd825e8d76f88c024ec37b41c6178e2b873059de/models/aliked-n16.pt -------------------------------------------------------------------------------- /models/aliked-n16rot.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrNeRF/Light_Glue_CPP/bd825e8d76f88c024ec37b41c6178e2b873059de/models/aliked-n16rot.pt -------------------------------------------------------------------------------- /models/aliked-n32.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrNeRF/Light_Glue_CPP/bd825e8d76f88c024ec37b41c6178e2b873059de/models/aliked-n32.pt -------------------------------------------------------------------------------- /models/aliked-t16.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrNeRF/Light_Glue_CPP/bd825e8d76f88c024ec37b41c6178e2b873059de/models/aliked-t16.pt -------------------------------------------------------------------------------- /models/aliked_lightglue.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MrNeRF/Light_Glue_CPP/bd825e8d76f88c024ec37b41c6178e2b873059de/models/aliked_lightglue.pt -------------------------------------------------------------------------------- /src/feature/ALIKED.cpp: -------------------------------------------------------------------------------- 1 | #include "feature/ALIKED.hpp" 2 | 3 | #include "feature/DKD.hpp" 4 | #include "feature/SDDH.hpp" 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | namespace fs = std::filesystem; 11 | 12 | ALIKED::ALIKED(std::string_view model_name, 13 | std::string_view device, 14 | int top_k, 15 | float scores_th, 16 | int n_limit) 17 | : device_(torch::Device(std::string(device))), 18 | dim_(-1) { 19 | 20 | // Initialize DKD and descriptor head 21 | dkd_ = std::make_shared(2, top_k, scores_th, n_limit); 22 | const auto& config = ALIKED_CFGS.at(std::string(model_name)); 23 | desc_head_ = std::make_shared(config.dim, config.K, config.M); 24 | 25 | // Initialize layers 26 | init_layers(model_name); 27 | 28 | // Load weights first 29 | load_weights(model_name); 30 | 31 | // Move everything to the specified device 32 | this->to(device_); 33 | dkd_->to(device_); // Explicitly move DKD 34 | desc_head_->to(device_); // Explicitly move SDDH 35 | 36 | // Double check all submodules are on the correct device 37 | for (const auto& param : this->parameters()) 38 | { 39 | if (param.device() != device_) 40 | { 41 | param.to(device_); 42 | } 43 | } 44 | 45 | for (const auto& buffer : this->buffers()) 46 | { 47 | if (buffer.device() != device_) 48 | { 49 | buffer.to(device_); 50 | } 51 | } 52 | 53 | this->eval(); 54 | } 55 | 56 | std::tuple 57 | ALIKED::extract_dense_map(torch::Tensor image) && { 58 | // Create padder for input 59 | auto padder = InputPadder(image.size(2), image.size(3), 32); 60 | image = std::move(padder).pad(std::move(image)); 61 | 62 | // Feature extraction with move semantics 63 | auto x1 = std::dynamic_pointer_cast(block1_)->forward(image); 64 | auto x2 = std::dynamic_pointer_cast(block2_)->forward(pool2_->forward(x1)); 65 | auto x3 = std::dynamic_pointer_cast(block3_)->forward(pool4_->forward(x2)); 66 | auto x4 = std::dynamic_pointer_cast(block4_)->forward(pool4_->forward(x3)); 67 | 68 | // Feature aggregation 69 | auto x1_processed = torch::selu(conv1_->forward(x1)); 70 | auto x2_processed = torch::selu(conv2_->forward(x2)); 71 | auto x3_processed = torch::selu(conv3_->forward(x3)); 72 | auto x4_processed = torch::selu(conv4_->forward(x4)); 73 | 74 | // Upsample with move semantics 75 | auto options = torch::nn::functional::InterpolateFuncOptions() 76 | .mode(torch::kBilinear) 77 | .align_corners(true); 78 | 79 | auto x2_up = torch::nn::functional::interpolate(x2_processed, 80 | options.size(std::vector{x1.size(2), x1.size(3)})); 81 | auto x3_up = torch::nn::functional::interpolate(x3_processed, 82 | options.size(std::vector{x1.size(2), x1.size(3)})); 83 | auto x4_up = torch::nn::functional::interpolate(x4_processed, 84 | options.size(std::vector{x1.size(2), x1.size(3)})); 85 | 86 | auto x1234 = torch::cat({std::move(x1_processed), 87 | std::move(x2_up), 88 | std::move(x3_up), 89 | std::move(x4_up)}, 90 | 1); 91 | 92 | // Generate score map and feature map 93 | auto score_map = torch::sigmoid(score_head_->forward(x1234.clone())); 94 | auto feature_map = torch::nn::functional::normalize(x1234, 95 | torch::nn::functional::NormalizeFuncOptions().p(2).dim(1)); 96 | 97 | // Unpad tensors with move semantics 98 | feature_map = std::move(padder).unpad(std::move(feature_map)); 99 | score_map = 
std::move(padder).unpad(std::move(score_map)); 100 | 101 | return std::make_tuple(std::move(feature_map), std::move(score_map)); 102 | } 103 | 104 | torch::Dict 105 | ALIKED::forward(torch::Tensor image) && { 106 | 107 | auto start_time = std::chrono::high_resolution_clock::now(); 108 | 109 | auto [feature_map, score_map] = std::move(*this).extract_dense_map(std::move(image)); 110 | auto [keypoints, kptscores, scoredispersitys] = std::move(*dkd_).forward(score_map); 111 | auto [descriptors, offsets] = std::move(*desc_head_).forward(feature_map, keypoints); 112 | 113 | auto end_time = std::chrono::high_resolution_clock::now(); 114 | auto duration = duration_cast(end_time - start_time).count() / 1000.0f; 115 | 116 | torch::Dict output; 117 | output.insert("keypoints", std::move(keypoints[0])); 118 | output.insert("descriptors", std::move(descriptors[0])); 119 | output.insert("scores", std::move(kptscores[0])); 120 | output.insert("score_dispersity", std::move(scoredispersitys[0])); 121 | output.insert("score_map", std::move(score_map)); 122 | output.insert("time", torch::tensor(duration)); 123 | 124 | return output; 125 | } 126 | 127 | torch::Dict 128 | ALIKED::forward(const torch::Tensor& image) & { 129 | auto image_copy = image.clone(); 130 | return std::move(*this).forward(std::move(image_copy)); 131 | } 132 | 133 | torch::Dict 134 | ALIKED::run(cv::Mat& img_rgb) { 135 | cv::Mat float_img; 136 | img_rgb.convertTo(float_img, CV_32F, 1.0 / 255.0); 137 | 138 | std::vector channels(3); 139 | cv::split(float_img, channels); 140 | 141 | auto options = torch::TensorOptions() 142 | .dtype(torch::kFloat32) 143 | .device(device_); 144 | 145 | std::vector tensor_channels; 146 | tensor_channels.reserve(3); 147 | 148 | for (const auto& channel : channels) 149 | { 150 | auto host_tensor = torch::from_blob( 151 | channel.data, 152 | {channel.rows, channel.cols}, 153 | torch::TensorOptions().dtype(torch::kFloat32)); 154 | tensor_channels.push_back(std::move(host_tensor).to(device_)); 155 | } 156 | 157 | auto img_tensor = torch::stack(std::move(tensor_channels), 0) 158 | .unsqueeze(0) 159 | .to(device_); 160 | 161 | // Forward pass with move semantics 162 | auto pred = std::move(*this).forward(std::move(img_tensor)); 163 | 164 | // Convert keypoints from normalized coordinates to image coordinates 165 | auto kpts = pred.at("keypoints"); 166 | TORCH_CHECK(pred.erase("keypoints"), "Failed to remove 'keypoints' from output dict"); 167 | const auto h = static_cast(float_img.rows); 168 | const auto w = static_cast(float_img.cols); 169 | const auto wh = torch::tensor({w - 1.0f, h - 1.0f}, kpts.options()); 170 | kpts = wh * (kpts + 1) / 2; 171 | auto [iter, success] = pred.insert("keypoints", kpts); 172 | TORCH_CHECK(success, "Failed to insert 'keypoints' into output dict"); 173 | return pred; 174 | } 175 | 176 | void ALIKED::init_layers(std::string_view model_name) { 177 | const auto& config = ALIKED_CFGS.at(std::string(model_name)); 178 | dim_ = config.dim; 179 | 180 | // Basic layers 181 | pool2_ = register_module("pool2", 182 | torch::nn::AvgPool2d(torch::nn::AvgPool2dOptions(2).stride(2))); 183 | pool4_ = register_module("pool4", 184 | torch::nn::AvgPool2d(torch::nn::AvgPool2dOptions(4).stride(4))); 185 | 186 | // Blocks with move semantics 187 | block1_ = register_module( 188 | "block1", 189 | std::make_shared(3, config.c1, "conv", false)); 190 | 191 | auto downsample2 = torch::nn::Conv2d( 192 | torch::nn::Conv2dOptions(config.c1, config.c2, 1)); 193 | block2_ = register_module( 194 | "block2", 195 | 
std::make_shared(config.c1, config.c2, 1, downsample2, "conv")); 196 | 197 | auto downsample3 = torch::nn::Conv2d( 198 | torch::nn::Conv2dOptions(config.c2, config.c3, 1)); 199 | block3_ = register_module( 200 | "block3", 201 | std::make_shared(config.c2, config.c3, 1, downsample3, "dcn")); 202 | 203 | auto downsample4 = torch::nn::Conv2d( 204 | torch::nn::Conv2dOptions(config.c3, config.c4, 1)); 205 | block4_ = register_module( 206 | "block4", 207 | std::make_shared(config.c3, config.c4, 1, downsample4, "dcn")); 208 | 209 | // Convolution layers 210 | const int out_channels = dim_ / 4; 211 | conv1_ = register_module("conv1", 212 | torch::nn::Conv2d(torch::nn::Conv2dOptions(config.c1, out_channels, 1).stride(1).bias(false))); 213 | conv2_ = register_module("conv2", 214 | torch::nn::Conv2d(torch::nn::Conv2dOptions(config.c2, out_channels, 1).stride(1).bias(false))); 215 | conv3_ = register_module("conv3", 216 | torch::nn::Conv2d(torch::nn::Conv2dOptions(config.c3, out_channels, 1).stride(1).bias(false))); 217 | conv4_ = register_module("conv4", 218 | torch::nn::Conv2d(torch::nn::Conv2dOptions(config.c4, out_channels, 1).stride(1).bias(false))); 219 | 220 | // Score head 221 | torch::nn::Sequential score_head; 222 | score_head->push_back(torch::nn::Conv2d( 223 | torch::nn::Conv2dOptions(dim_, 8, 1).stride(1).bias(false))); 224 | score_head->push_back(torch::nn::SELU()); 225 | score_head->push_back(torch::nn::Conv2d( 226 | torch::nn::Conv2dOptions(8, 4, 3).padding(1).stride(1).bias(false))); 227 | score_head->push_back(torch::nn::SELU()); 228 | score_head->push_back(torch::nn::Conv2d( 229 | torch::nn::Conv2dOptions(4, 4, 3).padding(1).stride(1).bias(false))); 230 | score_head->push_back(torch::nn::SELU()); 231 | score_head->push_back(torch::nn::Conv2d( 232 | torch::nn::Conv2dOptions(4, 1, 3).padding(1).stride(1).bias(false))); 233 | 234 | score_head_ = register_module("score_head", score_head); 235 | register_module("desc_head", desc_head_); 236 | register_module("dkd", dkd_); 237 | } 238 | 239 | void ALIKED::load_weights(std::string_view model_name) { 240 | std::vector search_paths = { 241 | std::filesystem::path(LIGHTGLUE_MODELS_DIR) / (std::string(model_name) + ".pt"), 242 | std::filesystem::current_path() / "models" / (std::string(model_name) + ".pt"), 243 | std::filesystem::current_path() / (std::string(model_name) + ".pt")}; 244 | 245 | std::filesystem::path model_path; 246 | bool found = false; 247 | 248 | for (const auto& path : search_paths) 249 | { 250 | if (std::filesystem::exists(path)) 251 | { 252 | model_path = path; 253 | found = true; 254 | break; 255 | } 256 | } 257 | 258 | if (!found) 259 | { 260 | std::string error_msg = "Cannot find pretrained model. 
Searched in:\n"; 261 | for (const auto& path : search_paths) 262 | { 263 | error_msg += " " + path.string() + "\n"; 264 | } 265 | error_msg += "Please place the model file in one of these locations."; 266 | throw std::runtime_error(error_msg); 267 | } 268 | 269 | std::cout << "Loading model from: " << model_path << std::endl; 270 | load_parameters(model_path.string()); 271 | } 272 | 273 | void ALIKED::load_parameters(std::string_view pt_pth) { 274 | auto f = get_the_bytes(pt_pth); 275 | auto weights = torch::pickle_load(f).toGenericDict(); 276 | 277 | // Use unordered_maps for O(1) lookup 278 | std::unordered_map param_map; 279 | std::unordered_map buffer_map; 280 | 281 | auto model_params = named_parameters(); 282 | auto model_buffers = named_buffers(); 283 | // Pre-allocate with expected size 284 | param_map.reserve(model_params.size()); 285 | buffer_map.reserve(model_buffers.size()); 286 | 287 | // Collect parameter names 288 | for (const auto& p : model_params) 289 | { 290 | param_map.emplace(p.key(), p.value()); 291 | } 292 | 293 | // Collect buffer names 294 | for (const auto& b : model_buffers) 295 | { 296 | buffer_map.emplace(b.key(), b.value()); 297 | } 298 | 299 | // Update parameters and buffers 300 | torch::NoGradGuard no_grad; 301 | 302 | for (const auto& w : weights) 303 | { 304 | const auto name = w.key().toStringRef(); 305 | const auto& param = w.value().toTensor(); 306 | 307 | // Try parameters first 308 | if (auto it = param_map.find(name); it != param_map.end()) 309 | { 310 | if (it->second.sizes() == param.sizes()) 311 | { 312 | it->second.copy_(param); 313 | } else 314 | { 315 | throw std::runtime_error( 316 | "Shape mismatch for parameter: " + name + 317 | " Expected: " + std::to_string(it->second.numel()) + 318 | " Got: " + std::to_string(param.numel())); 319 | } 320 | continue; 321 | } 322 | 323 | // Then try buffers 324 | if (auto it = buffer_map.find(name); it != buffer_map.end()) 325 | { 326 | if (it->second.sizes() == param.sizes()) 327 | { 328 | it->second.copy_(param); 329 | } else 330 | { 331 | throw std::runtime_error( 332 | "Shape mismatch for buffer: " + name + 333 | " Expected: " + std::to_string(it->second.numel()) + 334 | " Got: " + std::to_string(param.numel())); 335 | } 336 | continue; 337 | } 338 | 339 | // Parameter not found in model 340 | std::cerr << "Warning: " << name 341 | << " not found in model parameters or buffers\n"; 342 | } 343 | } 344 | 345 | std::vector ALIKED::get_the_bytes(std::string_view filename) { 346 | // Use RAII file handling 347 | std::ifstream file(std::string(filename), std::ios::binary); 348 | if (!file) 349 | { 350 | throw std::runtime_error( 351 | "Failed to open file: " + std::string(filename)); 352 | } 353 | 354 | // Get file size 355 | file.seekg(0, std::ios::end); 356 | const auto size = file.tellg(); 357 | file.seekg(0, std::ios::beg); 358 | 359 | // Pre-allocate vector 360 | std::vector buffer; 361 | buffer.reserve(size); 362 | 363 | // Read file in chunks for better performance 364 | constexpr size_t CHUNK_SIZE = 8192; 365 | char chunk[CHUNK_SIZE]; 366 | 367 | while (file.read(chunk, CHUNK_SIZE)) 368 | { 369 | buffer.insert(buffer.end(), chunk, chunk + file.gcount()); 370 | } 371 | if (file.gcount() > 0) 372 | { 373 | buffer.insert(buffer.end(), chunk, chunk + file.gcount()); 374 | } 375 | 376 | return buffer; 377 | } -------------------------------------------------------------------------------- /src/feature/DKD.cpp: -------------------------------------------------------------------------------- 1 | #include 
"feature/DKD.hpp" 2 | 3 | #include 4 | 5 | namespace F = torch::nn::functional; 6 | using namespace torch::indexing; 7 | 8 | DKD::DKD(int radius, int top_k, float scores_th, int n_limit) 9 | : radius_(radius), 10 | top_k_(top_k), 11 | scores_th_(scores_th), 12 | n_limit_(n_limit), 13 | kernel_size_(calculateKernelSize(radius)), 14 | temperature_(0.1f), 15 | unfold_(torch::nn::UnfoldOptions(kernel_size_).padding(radius)) { 16 | 17 | auto x = torch::linspace(-radius_, radius_, kernel_size_); 18 | auto meshgrid = torch::meshgrid({x, x}); 19 | hw_grid_ = torch::stack({meshgrid[1], meshgrid[0]}, -1) 20 | .reshape({-1, 2}) 21 | .contiguous(); // Ensure contiguous memory layout 22 | } 23 | 24 | torch::Tensor DKD::simple_nms(torch::Tensor scores, int nms_radius) && { 25 | auto zeros = torch::zeros_like(scores); 26 | auto max_pool_options = F::MaxPool2dFuncOptions(nms_radius * 2 + 1) 27 | .stride(1) 28 | .padding(nms_radius); 29 | 30 | auto max_mask = std::move(scores) == F::max_pool2d(scores, max_pool_options); 31 | 32 | for (int i = 0; i < 2; ++i) 33 | { 34 | auto supp_mask = F::max_pool2d(max_mask.to(torch::kFloat), max_pool_options) > 0; 35 | auto supp_scores = torch::where(supp_mask, zeros, scores); 36 | auto new_max_mask = supp_scores == F::max_pool2d(supp_scores, max_pool_options); 37 | max_mask = max_mask | (new_max_mask & (~supp_mask)); 38 | } 39 | 40 | return torch::where(max_mask, scores, std::move(zeros)); 41 | } 42 | 43 | torch::Tensor DKD::simple_nms(const torch::Tensor& scores, int nms_radius) & { 44 | auto scores_copy = scores.clone(); 45 | return std::move(*this).simple_nms(std::move(scores_copy), nms_radius); 46 | } 47 | 48 | std::tuple, std::vector, std::vector> 49 | DKD::detect_keypoints(torch::Tensor scores_map, bool sub_pixel) && { 50 | const auto batch_size = scores_map.size(0); 51 | const auto height = scores_map.size(2); 52 | const auto width = scores_map.size(3); 53 | const auto device = scores_map.device(); 54 | 55 | auto scores_nograd = scores_map.detach(); 56 | auto nms_scores = std::move(*this).simple_nms(std::move(scores_nograd), 2); 57 | 58 | auto border_mask = torch::ones_like(nms_scores, 59 | torch::TensorOptions() 60 | .dtype(torch::kBool) 61 | .device(device)); 62 | 63 | border_mask.index_put_({Slice(), Slice(), Slice(None, radius_), Slice()}, false); 64 | border_mask.index_put_({Slice(), Slice(), Slice(), Slice(None, radius_)}, false); 65 | border_mask.index_put_({Slice(), Slice(), Slice(-radius_, None), Slice()}, false); 66 | border_mask.index_put_({Slice(), Slice(), Slice(), Slice(-radius_, None)}, false); 67 | 68 | nms_scores = torch::where(border_mask, nms_scores, torch::zeros_like(nms_scores)); 69 | 70 | std::vector keypoints; 71 | std::vector scoredispersitys; 72 | std::vector kptscores; 73 | keypoints.reserve(batch_size); 74 | scoredispersitys.reserve(batch_size); 75 | kptscores.reserve(batch_size); 76 | 77 | // Create wh tensor on the correct device 78 | auto wh = torch::tensor( 79 | {static_cast(width - 1), static_cast(height - 1)}, 80 | torch::TensorOptions().dtype(scores_map.dtype()).device(device)); 81 | 82 | // Ensure hw_grid_ is on the correct device 83 | if (hw_grid_.device() != device) 84 | { 85 | hw_grid_ = hw_grid_.to(device); 86 | } 87 | 88 | if (sub_pixel) 89 | { 90 | auto patches = unfold_(scores_map); 91 | 92 | for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) 93 | { 94 | auto patch = patches[batch_idx].transpose(0, 1); 95 | 96 | torch::Tensor indices_kpt; 97 | if (top_k_ > 0) 98 | { 99 | auto scores_view = 
nms_scores[batch_idx].reshape(-1); 100 | auto topk = scores_view.topk(top_k_); 101 | indices_kpt = std::get<1>(topk); 102 | } else 103 | { 104 | auto scores_view = nms_scores[batch_idx].reshape(-1); 105 | auto mask = scores_view > scores_th_; 106 | indices_kpt = mask.nonzero().squeeze(1); 107 | if (indices_kpt.size(0) > n_limit_) 108 | { 109 | auto kpts_sc = scores_view.index_select(0, indices_kpt); 110 | auto sort_idx = std::get<1>(kpts_sc.sort(true)); 111 | indices_kpt = indices_kpt.index_select(0, sort_idx.slice(0, n_limit_)); 112 | } 113 | } 114 | 115 | auto patch_scores = patch.index_select(0, indices_kpt); 116 | auto keypoints_xy_nms = torch::stack({indices_kpt % width, 117 | torch::div(indices_kpt, width, /*rounding_mode=*/"floor")}, 118 | 1) 119 | .to(device); 120 | 121 | auto [max_v, _] = patch_scores.max(1, true); 122 | auto x_exp = ((patch_scores - max_v.detach()) / temperature_).exp(); 123 | auto xy_residual = (x_exp.unsqueeze(2) * hw_grid_.unsqueeze(0)).sum(1) / 124 | x_exp.sum(1, true); 125 | 126 | auto dist2 = (hw_grid_.unsqueeze(0) - xy_residual.unsqueeze(1)) 127 | .div(radius_) 128 | .norm(2, -1) 129 | .pow(2); 130 | 131 | auto scoredispersity = (x_exp * dist2).sum(1) / x_exp.sum(1); 132 | auto keypoints_xy = keypoints_xy_nms + xy_residual; 133 | keypoints_xy = keypoints_xy.div(wh).mul(2).sub(1); 134 | 135 | auto kptscore = torch::nn::functional::grid_sample( 136 | scores_map[batch_idx].unsqueeze(0), 137 | keypoints_xy.view({1, 1, -1, 2}), 138 | torch::nn::functional::GridSampleFuncOptions() 139 | .mode(torch::kBilinear) 140 | .align_corners(true))[0][0][0]; 141 | 142 | keypoints.push_back(std::move(keypoints_xy)); 143 | scoredispersitys.push_back(std::move(scoredispersity)); 144 | kptscores.push_back(std::move(kptscore)); 145 | } 146 | } else 147 | { 148 | for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) 149 | { 150 | torch::Tensor indices_kpt; 151 | if (top_k_ > 0) 152 | { 153 | auto scores_view = nms_scores[batch_idx].reshape(-1); 154 | auto topk = scores_view.topk(top_k_); 155 | indices_kpt = std::get<1>(topk); 156 | } else 157 | { 158 | auto scores_view = nms_scores[batch_idx].reshape(-1); 159 | auto mask = scores_view > scores_th_; 160 | indices_kpt = mask.nonzero().squeeze(1); 161 | if (indices_kpt.size(0) > n_limit_) 162 | { 163 | auto kpts_sc = scores_view.index_select(0, indices_kpt); 164 | auto sort_idx = std::get<1>(kpts_sc.sort(true)); 165 | indices_kpt = indices_kpt.index_select(0, sort_idx.slice(0, n_limit_)); 166 | } 167 | } 168 | 169 | auto keypoints_xy = torch::stack({indices_kpt % width, 170 | torch::div(indices_kpt, width, /*rounding_mode=*/"floor")}, 171 | 1) 172 | .to(device); 173 | 174 | keypoints_xy = keypoints_xy.div(wh).mul(2).sub(1); 175 | 176 | auto kptscore = torch::nn::functional::grid_sample( 177 | scores_map[batch_idx].unsqueeze(0), 178 | keypoints_xy.view({1, 1, -1, 2}), 179 | torch::nn::functional::GridSampleFuncOptions() 180 | .mode(torch::kBilinear) 181 | .align_corners(true))[0][0][0]; 182 | 183 | keypoints.push_back(std::move(keypoints_xy)); 184 | scoredispersitys.push_back(kptscore.clone()); 185 | kptscores.push_back(std::move(kptscore)); 186 | } 187 | } 188 | 189 | return std::make_tuple(std::move(keypoints), 190 | std::move(scoredispersitys), 191 | std::move(kptscores)); 192 | } 193 | 194 | std::tuple, std::vector, std::vector> 195 | DKD::detect_keypoints(const torch::Tensor& scores_map, bool sub_pixel) & { 196 | auto scores_map_copy = scores_map.clone(); 197 | return 
std::move(*this).detect_keypoints(std::move(scores_map_copy), sub_pixel); 198 | } 199 | 200 | std::tuple, std::vector, std::vector> 201 | DKD::forward(torch::Tensor scores_map, bool sub_pixel) && { 202 | return std::move(*this).detect_keypoints(std::move(scores_map), sub_pixel); 203 | } 204 | 205 | std::tuple, std::vector, std::vector> 206 | DKD::forward(const torch::Tensor& scores_map, bool sub_pixel) & { 207 | return this->detect_keypoints(scores_map, sub_pixel); 208 | } -------------------------------------------------------------------------------- /src/feature/SDDH.cpp: -------------------------------------------------------------------------------- 1 | #include "feature/SDDH.hpp" 2 | 3 | #include "feature/get_patches.hpp" 4 | #include 5 | 6 | using namespace torch::indexing; 7 | 8 | SDDH::SDDH(int dims, int kernel_size, int n_pos, bool conv2D, bool mask) 9 | : kernel_size_(kernel_size), 10 | n_pos_(n_pos), 11 | conv2D_(conv2D), 12 | mask_(mask) { 13 | 14 | // Channel num for offsets 15 | const int channel_num = mask ? 3 * n_pos : 2 * n_pos; 16 | 17 | // Build offset convolution layers 18 | torch::nn::Sequential offset_conv; 19 | offset_conv->push_back(torch::nn::Conv2d( 20 | torch::nn::Conv2dOptions(dims, channel_num, kernel_size) 21 | .stride(1) 22 | .padding(0) 23 | .bias(true))); 24 | offset_conv->push_back(torch::nn::SELU()); 25 | offset_conv->push_back(torch::nn::Conv2d( 26 | torch::nn::Conv2dOptions(channel_num, channel_num, 1) 27 | .stride(1) 28 | .padding(0) 29 | .bias(true))); 30 | 31 | register_module("offset_conv", offset_conv); 32 | offset_conv_ = offset_conv; 33 | 34 | // Sampled feature convolution 35 | sf_conv_ = register_module("sf_conv", 36 | torch::nn::Conv2d(torch::nn::Conv2dOptions(dims, dims, 1) 37 | .stride(1) 38 | .padding(0) 39 | .bias(false))); 40 | 41 | if (!conv2D) 42 | { 43 | // Register deformable desc weights 44 | agg_weights_ = register_parameter("agg_weights", 45 | torch::randn({n_pos, dims, dims})); 46 | } else 47 | { 48 | // Register convM 49 | convM_ = register_module("convM", 50 | torch::nn::Conv2d(torch::nn::Conv2dOptions(dims * n_pos, dims, 1) 51 | .stride(1) 52 | .padding(0) 53 | .bias(false))); 54 | } 55 | } 56 | 57 | torch::Tensor SDDH::process_features(torch::Tensor features, int64_t num_keypoints) && { 58 | if (!conv2D_) 59 | { 60 | return torch::einsum("ncp,pcd->nd", 61 | {std::move(features), agg_weights_}); 62 | } else 63 | { 64 | features = std::move(features) 65 | .reshape({num_keypoints, -1}) 66 | .unsqueeze(-1) 67 | .unsqueeze(-1); 68 | return convM_->forward(std::move(features)).squeeze(); 69 | } 70 | } 71 | 72 | torch::Tensor SDDH::process_features(const torch::Tensor& features, int64_t num_keypoints) & { 73 | auto features_copy = features.clone(); 74 | return std::move(*this).process_features(std::move(features_copy), num_keypoints); 75 | } 76 | 77 | std::tuple, std::vector> 78 | SDDH::forward(torch::Tensor x, std::vector& keypoints) && { 79 | // Make input tensor contiguous if it isn't already 80 | if (!x.is_contiguous()) 81 | { 82 | x = x.contiguous(); 83 | } 84 | 85 | const auto batch_size = x.size(0); 86 | const auto channels = x.size(1); 87 | const auto height = x.size(2); 88 | const auto width = x.size(3); 89 | const auto device = x.device(); 90 | 91 | const auto wh = torch::tensor({width - 1.0f, height - 1.0f}, 92 | torch::TensorOptions() 93 | .dtype(x.dtype()) 94 | .device(device)); 95 | 96 | const float max_offset = std::max(height, width) / 4.0f; 97 | 98 | std::vector offsets; 99 | std::vector descriptors; 100 | 
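// The per-image loop below implements the deformable descriptor head: for every
// keypoint it (1) crops a kernel_size_ x kernel_size_ patch of the feature map
// via custom_ops::get_patches_forward (or a single pixel when kernel_size_ == 1),
// (2) predicts n_pos_ (x, y) sampling offsets, plus an optional sigmoid mask
// weight, with offset_conv_, (3) samples features at keypoint + offset through
// grid_sample after mapping pixel coordinates into grid_sample's [-1, 1] range
// (pos = 2 * (kptsi_wh + offset) / wh - 1), and (4) aggregates the samples with
// agg_weights_ (einsum) or convM_ before L2-normalizing each descriptor.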
offsets.reserve(batch_size); 101 | descriptors.reserve(batch_size); 102 | 103 | for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) 104 | { 105 | auto xi = x[batch_idx]; 106 | // Ensure xi is contiguous 107 | if (!xi.is_contiguous()) 108 | { 109 | xi = xi.contiguous(); 110 | } 111 | 112 | const auto& kptsi = keypoints[batch_idx]; 113 | auto kptsi_wh = (kptsi / 2 + 0.5) * wh; 114 | const auto num_keypoints = kptsi_wh.size(0); 115 | 116 | torch::Tensor patch; 117 | if (kernel_size_ > 1) 118 | { 119 | // Ensure inputs to get_patches_forward are contiguous 120 | auto kptsi_wh_long = kptsi_wh.to(torch::kLong).contiguous(); 121 | patch = custom_ops::get_patches_forward(xi.contiguous(), kptsi_wh_long, kernel_size_); 122 | } else 123 | { 124 | auto kptsi_wh_long = kptsi_wh.to(torch::kLong).contiguous(); 125 | patch = xi.index({Slice(), 126 | kptsi_wh_long.index({Slice(), 1}), 127 | kptsi_wh_long.index({Slice(), 0})}) 128 | .transpose(0, 1) 129 | .reshape({num_keypoints, channels, 1, 1}) 130 | .contiguous(); 131 | } 132 | 133 | // Rest of the code remains the same... 134 | auto offset = offset_conv_->forward(std::move(patch)); 135 | offset = offset.clamp(-max_offset, max_offset); 136 | 137 | torch::Tensor mask_weight; 138 | if (mask_) 139 | { 140 | offset = offset.index({Slice(), Slice(), 0, 0}) 141 | .view({num_keypoints, 3, n_pos_}) 142 | .permute({0, 2, 1}) 143 | .contiguous(); 144 | auto offset_xy = offset.index({Slice(), Slice(), Slice(None, 2)}); 145 | mask_weight = torch::sigmoid(offset.index({Slice(), Slice(), 2})); 146 | offset = offset_xy; 147 | } else 148 | { 149 | offset = offset.index({Slice(), Slice(), 0, 0}) 150 | .view({num_keypoints, 2, n_pos_}) 151 | .permute({0, 2, 1}) 152 | .contiguous(); 153 | } 154 | 155 | offsets.push_back(offset); 156 | 157 | auto pos = kptsi_wh.unsqueeze(1) + offset; 158 | pos = 2.0 * pos / wh - 1; 159 | pos = pos.reshape({1, num_keypoints * n_pos_, 1, 2}).contiguous(); 160 | 161 | auto features = torch::nn::functional::grid_sample( 162 | xi.unsqueeze(0), pos, 163 | torch::nn::functional::GridSampleFuncOptions() 164 | .mode(torch::kBilinear) 165 | .align_corners(true)); 166 | 167 | features = features.reshape({channels, num_keypoints, n_pos_, 1}) 168 | .permute({1, 0, 2, 3}) 169 | .contiguous(); 170 | 171 | if (mask_) 172 | { 173 | features = features * mask_weight.unsqueeze(1).unsqueeze(-1); 174 | } 175 | 176 | features = torch::selu(sf_conv_->forward(std::move(features))).squeeze(-1); 177 | 178 | torch::Tensor descs; 179 | if (!conv2D_) 180 | { 181 | descs = torch::einsum("ncp,pcd->nd", {features, agg_weights_}); 182 | } else 183 | { 184 | features = features.reshape({num_keypoints, -1}).unsqueeze(-1).unsqueeze(-1).contiguous(); 185 | descs = convM_->forward(std::move(features)).squeeze(); 186 | } 187 | 188 | descs = torch::nn::functional::normalize(std::move(descs), 189 | torch::nn::functional::NormalizeFuncOptions() 190 | .p(2) 191 | .dim(1)); 192 | 193 | descriptors.push_back(std::move(descs)); 194 | } 195 | 196 | return std::make_tuple(std::move(descriptors), std::move(offsets)); 197 | } 198 | 199 | std::tuple, std::vector> 200 | SDDH::forward(const torch::Tensor& x, std::vector& keypoints) & { 201 | auto x_copy = x.clone(); 202 | return std::move(*this).forward(std::move(x_copy), keypoints); 203 | } -------------------------------------------------------------------------------- /src/feature/blocks.cpp: -------------------------------------------------------------------------------- 1 | #include "feature/blocks.hpp" 2 | 3 | #include 
"feature/deform_conv2d.h" 4 | 5 | DeformableConv2d::DeformableConv2d(int in_channels, int out_channels, 6 | int kernel_size, int stride, int padding, 7 | bool bias) { 8 | padding_ = padding; 9 | const int channel_num = 2 * kernel_size * kernel_size; 10 | 11 | // Register offset conv 12 | offset_conv_ = register_module("offset_conv", 13 | torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, channel_num, kernel_size) 14 | .stride(stride) 15 | .padding(padding) 16 | .bias(true))); 17 | 18 | // Register regular conv 19 | regular_conv_ = register_module("regular_conv", 20 | torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, out_channels, kernel_size) 21 | .stride(stride) 22 | .padding(padding) 23 | .bias(bias))); 24 | } 25 | 26 | torch::Tensor DeformableConv2d::forward(const torch::Tensor& x) & { 27 | auto h = x.size(2); 28 | auto w = x.size(3); 29 | float max_offset = std::max(h, w) / 4.0f; 30 | 31 | // Offset and mask 32 | auto offset = offset_conv_->forward(x); 33 | auto mask = torch::zeros( 34 | {offset.size(0), 1}, 35 | torch::TensorOptions().device(offset.device()).dtype(offset.dtype())); 36 | 37 | offset = offset.clamp(-max_offset, max_offset); 38 | 39 | if (!regular_conv_->bias.defined()) 40 | { 41 | regular_conv_->bias = torch::zeros( 42 | {regular_conv_->weight.size(0)}, 43 | torch::TensorOptions().device(x.device()).dtype(x.dtype())); 44 | } 45 | 46 | return vision::ops::deform_conv2d( 47 | x, 48 | regular_conv_->weight, 49 | offset, 50 | mask, 51 | regular_conv_->bias, 52 | 1, 1, 53 | padding_, padding_, 54 | 1, 1, 55 | groups_, 56 | mask_offset_, 57 | false); 58 | } 59 | 60 | torch::Tensor DeformableConv2d::forward(torch::Tensor x) && { 61 | auto h = x.size(2); 62 | auto w = x.size(3); 63 | float max_offset = std::max(h, w) / 4.0f; 64 | 65 | // Offset and mask 66 | auto offset = offset_conv_->forward(std::move(x)); 67 | auto mask = torch::zeros( 68 | {offset.size(0), 1}, 69 | torch::TensorOptions().device(offset.device()).dtype(offset.dtype())); 70 | 71 | offset = std::move(offset).clamp(-max_offset, max_offset); 72 | 73 | if (!regular_conv_->bias.defined()) 74 | { 75 | regular_conv_->bias = torch::zeros( 76 | {regular_conv_->weight.size(0)}, 77 | torch::TensorOptions().device(x.device()).dtype(x.dtype())); 78 | } 79 | 80 | return vision::ops::deform_conv2d( 81 | std::move(x), 82 | regular_conv_->weight, 83 | std::move(offset), 84 | std::move(mask), 85 | regular_conv_->bias, 86 | 1, 1, 87 | padding_, padding_, 88 | 1, 1, 89 | groups_, 90 | mask_offset_, 91 | false); 92 | } 93 | 94 | ConvBlock::ConvBlock(int in_channels, int out_channels, 95 | std::string_view conv_type, bool mask) { 96 | 97 | if (conv_type == "conv") 98 | { 99 | auto conv1 = torch::nn::Conv2d((torch::nn::Conv2dOptions(in_channels, out_channels, 3) 100 | .stride(1) 101 | .padding(1) 102 | .bias(false))); 103 | conv1_ = register_module("conv1", conv1); 104 | 105 | auto conv2 = torch::nn::Conv2d((torch::nn::Conv2dOptions(out_channels, out_channels, 3) 106 | .stride(1) 107 | .padding(1) 108 | .bias(false))); 109 | conv2_ = register_module("conv2", conv2); 110 | 111 | } else 112 | { 113 | auto conv1 = std::make_shared( 114 | in_channels, 115 | out_channels, 116 | 3, 117 | 1, 118 | 1, 119 | false); 120 | deform1_ = register_module("conv1", conv1); 121 | 122 | auto conv2 = std::make_shared( 123 | out_channels, 124 | out_channels, 125 | 3, 126 | 1, 127 | 1, 128 | false); 129 | deform2_ = register_module("conv2", conv2); 130 | } 131 | 132 | bn1_ = register_module("bn1", torch::nn::BatchNorm2d(out_channels)); 133 | 
bn2_ = register_module("bn2", torch::nn::BatchNorm2d(out_channels)); 134 | } 135 | 136 | ResBlock::ResBlock(int inplanes, int planes, int stride, 137 | const torch::nn::Conv2d& downsample, 138 | std::string_view conv_type) 139 | : downsample_(downsample) { 140 | 141 | if (conv_type == "conv") 142 | { 143 | auto conv1 = torch::nn::Conv2d((torch::nn::Conv2dOptions(inplanes, planes, 3) 144 | .stride(stride) 145 | .padding(1) 146 | .bias(false))); 147 | conv1_ = register_module("conv1", conv1); 148 | 149 | auto conv2 = torch::nn::Conv2d((torch::nn::Conv2dOptions(planes, planes, 3) 150 | .stride(stride) 151 | .padding(1) 152 | .bias(false))); 153 | conv2_ = register_module("conv2", conv2); 154 | 155 | } else 156 | { 157 | auto conv1 = std::make_shared( 158 | inplanes, 159 | planes, 160 | 3, 161 | 1, 162 | 1, 163 | false); 164 | deform1_ = register_module("conv1", conv1); 165 | 166 | auto conv2 = std::make_shared( 167 | planes, 168 | planes, 169 | 3, 170 | 1, 171 | 1, 172 | false); 173 | deform2_ = register_module("conv2", conv2); 174 | } 175 | 176 | bn1_ = register_module("bn1", 177 | torch::nn::BatchNorm2d(planes)); 178 | bn2_ = register_module("bn2", 179 | torch::nn::BatchNorm2d(planes)); 180 | 181 | if (downsample) 182 | { 183 | register_module("downsample", downsample); 184 | } 185 | } 186 | 187 | torch::Tensor ConvBlock::forward(torch::Tensor x) && { 188 | return std::move(*this).forward(std::move(x)); 189 | } 190 | 191 | torch::Tensor ConvBlock::forward(const torch::Tensor& x) & { 192 | if (conv1_ && conv2_) 193 | { 194 | auto tmp = torch::selu(bn1_->forward(conv1_->forward(x))); 195 | return torch::selu(bn2_->forward(conv2_->forward(std::move(tmp)))); 196 | } else 197 | { 198 | auto tmp = torch::selu(bn1_->forward(deform1_->forward(x))); 199 | return torch::selu(bn2_->forward(deform2_->forward(std::move(tmp)))); 200 | } 201 | } 202 | 203 | torch::Tensor ResBlock::forward(torch::Tensor x) && { 204 | return std::move(*this).forward(std::move(x)); 205 | } 206 | 207 | torch::Tensor ResBlock::forward(const torch::Tensor& x) & { 208 | auto identity = x; 209 | 210 | torch::Tensor processed; 211 | if (conv1_ && conv2_) 212 | { 213 | auto tmp = conv1_->forward(x); 214 | tmp = bn1_->forward(std::move(tmp)); 215 | tmp = torch::selu(std::move(tmp)); 216 | 217 | processed = conv2_->forward(std::move(tmp)); 218 | processed = bn2_->forward(std::move(processed)); 219 | } else 220 | { 221 | auto tmp = deform1_->forward(x); 222 | tmp = bn1_->forward(std::move(tmp)); 223 | tmp = torch::selu(std::move(tmp)); 224 | 225 | processed = deform2_->forward(std::move(tmp)); 226 | processed = bn2_->forward(std::move(processed)); 227 | } 228 | 229 | if (downsample_) 230 | { 231 | identity = downsample_->as()->forward(std::move(identity)); 232 | } 233 | 234 | processed += identity; 235 | return torch::selu(std::move(processed)); 236 | } -------------------------------------------------------------------------------- /src/feature/deform_conv2d.cpp: -------------------------------------------------------------------------------- 1 | #include "feature/deform_conv2d.h" 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace vision { 8 | namespace ops { 9 | 10 | at::Tensor deform_conv2d( 11 | const at::Tensor& input, 12 | const at::Tensor& weight, 13 | const at::Tensor& offset, 14 | const at::Tensor& mask, 15 | const at::Tensor& bias, 16 | int64_t stride_h, 17 | int64_t stride_w, 18 | int64_t pad_h, 19 | int64_t pad_w, 20 | int64_t dilation_h, 21 | int64_t dilation_w, 22 | int64_t groups, 23 | int64_t offset_groups, 
24 | bool use_mask) { 25 | C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.deform_conv2d.deform_conv2d"); 26 | static auto op = c10::Dispatcher::singleton() 27 | .findSchemaOrThrow("torchvision::deform_conv2d", "") 28 | .typed(); 29 | return op.call( 30 | input, 31 | weight, 32 | offset, 33 | mask, 34 | bias, 35 | stride_h, 36 | stride_w, 37 | pad_h, 38 | pad_w, 39 | dilation_h, 40 | dilation_w, 41 | groups, 42 | offset_groups, 43 | use_mask); 44 | } 45 | 46 | at::Tensor deform_conv2d_symint( 47 | const at::Tensor& input, 48 | const at::Tensor& weight, 49 | const at::Tensor& offset, 50 | const at::Tensor& mask, 51 | const at::Tensor& bias, 52 | c10::SymInt stride_h, 53 | c10::SymInt stride_w, 54 | c10::SymInt pad_h, 55 | c10::SymInt pad_w, 56 | c10::SymInt dilation_h, 57 | c10::SymInt dilation_w, 58 | c10::SymInt groups, 59 | c10::SymInt offset_groups, 60 | bool use_mask) { 61 | C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.deform_conv2d.deform_conv2d"); 62 | static auto op = c10::Dispatcher::singleton() 63 | .findSchemaOrThrow("torchvision::deform_conv2d", "") 64 | .typed(); 65 | return op.call( 66 | input, 67 | weight, 68 | offset, 69 | mask, 70 | bias, 71 | stride_h, 72 | stride_w, 73 | pad_h, 74 | pad_w, 75 | dilation_h, 76 | dilation_w, 77 | groups, 78 | offset_groups, 79 | use_mask); 80 | } 81 | 82 | namespace detail { 83 | 84 | std::tuple 85 | _deform_conv2d_backward( 86 | const at::Tensor& grad, 87 | const at::Tensor& input, 88 | const at::Tensor& weight, 89 | const at::Tensor& offset, 90 | const at::Tensor& mask, 91 | const at::Tensor& bias, 92 | int64_t stride_h, 93 | int64_t stride_w, 94 | int64_t pad_h, 95 | int64_t pad_w, 96 | int64_t dilation_h, 97 | int64_t dilation_w, 98 | int64_t groups, 99 | int64_t offset_groups, 100 | bool use_mask) { 101 | static auto op = 102 | c10::Dispatcher::singleton() 103 | .findSchemaOrThrow("torchvision::_deform_conv2d_backward", "") 104 | .typed(); 105 | return op.call( 106 | grad, 107 | input, 108 | weight, 109 | offset, 110 | mask, 111 | bias, 112 | stride_h, 113 | stride_w, 114 | pad_h, 115 | pad_w, 116 | dilation_h, 117 | dilation_w, 118 | groups, 119 | offset_groups, 120 | use_mask); 121 | } 122 | 123 | std::tuple 124 | _deform_conv2d_backward_symint( 125 | const at::Tensor& grad, 126 | const at::Tensor& input, 127 | const at::Tensor& weight, 128 | const at::Tensor& offset, 129 | const at::Tensor& mask, 130 | const at::Tensor& bias, 131 | c10::SymInt stride_h, 132 | c10::SymInt stride_w, 133 | c10::SymInt pad_h, 134 | c10::SymInt pad_w, 135 | c10::SymInt dilation_h, 136 | c10::SymInt dilation_w, 137 | c10::SymInt groups, 138 | c10::SymInt offset_groups, 139 | bool use_mask) { 140 | static auto op = 141 | c10::Dispatcher::singleton() 142 | .findSchemaOrThrow("torchvision::_deform_conv2d_backward", "") 143 | .typed(); 144 | return op.call( 145 | grad, 146 | input, 147 | weight, 148 | offset, 149 | mask, 150 | bias, 151 | stride_h, 152 | stride_w, 153 | pad_h, 154 | pad_w, 155 | dilation_h, 156 | dilation_w, 157 | groups, 158 | offset_groups, 159 | use_mask); 160 | } 161 | 162 | } // namespace detail 163 | 164 | TORCH_LIBRARY_FRAGMENT(torchvision, m) { 165 | m.def(TORCH_SELECTIVE_SCHEMA( 166 | "torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> Tensor")); 167 | m.def(TORCH_SELECTIVE_SCHEMA( 168 | "torchvision::_deform_conv2d_backward(Tensor grad, 
Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)")); 169 | } 170 | 171 | } // namespace ops 172 | } // namespace vision 173 | -------------------------------------------------------------------------------- /src/feature/deform_conv2d_kernel.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | ******************* BEGIN Caffe Copyright Notice and Disclaimer 3 | ***************** 4 | * 5 | * COPYRIGHT 6 | * 7 | * All contributions by the University of California: 8 | * Copyright (c) 2014-2017 The Regents of the University of California (Regents) 9 | * All rights reserved. 10 | * 11 | * All other contributions: 12 | * Copyright (c) 2014-2017, the respective contributors 13 | * All rights reserved. 14 | * 15 | * Caffe uses a shared copyright model: each contributor holds copyright over 16 | * their contributions to Caffe. The project versioning records all such 17 | * contribution and copyright details. If a contributor wants to further mark 18 | * their specific copyright on a particular contribution, they should indicate 19 | * their copyright solely in the commit message of the change when it is 20 | * committed. 21 | * 22 | * LICENSE 23 | * 24 | * Redistribution and use in source and binary forms, with or without 25 | * modification, are permitted provided that the following conditions are met: 26 | * 27 | * 1. Redistributions of source code must retain the above copyright notice, 28 | *this list of conditions and the following disclaimer. 29 | * 2. Redistributions in binary form must reproduce the above copyright notice, 30 | * this list of conditions and the following disclaimer in the documentation 31 | * and/or other materials provided with the distribution. 32 | * 33 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 34 | *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 35 | *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 36 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE 37 | *FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 38 | *DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 39 | *SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 40 | *CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 41 | *OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 42 | *OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 43 | * 44 | * CONTRIBUTION AGREEMENT 45 | * 46 | * By contributing to the BVLC/caffe repository through pull-request, comment, 47 | * or otherwise, the contributor releases their content to the 48 | * license and copyright terms herein. 49 | * 50 | ***************** END Caffe Copyright Notice and Disclaimer 51 | ********************* 52 | * 53 | * Copyright (c) 2018 Microsoft 54 | * Licensed under The MIT License [see LICENSE for details] 55 | * \file modulated_deformable_im2col.cuh 56 | * \brief Function definitions of converting an image to 57 | * column matrix based on kernel, padding, dilation, and offset. 58 | * These functions are mainly used in deformable convolution operators. 
59 | * \ref: https://arxiv.org/abs/1703.06211 60 | * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng 61 | */ 62 | 63 | // modified from 64 | // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu 65 | 66 | // modified from 67 | // https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp 68 | 69 | #include "feature/cuda_helpers.h" 70 | #include 71 | #include 72 | #include 73 | #include 74 | #include 75 | 76 | namespace vision { 77 | namespace ops { 78 | 79 | namespace { 80 | 81 | const int kMaxParallelImgs = 32; 82 | 83 | inline unsigned int GET_THREADS() { 84 | #ifdef WITH_HIP 85 | return 256; 86 | #endif 87 | return 512; 88 | } 89 | 90 | inline unsigned int GET_BLOCKS(const unsigned int THREADS, const int64_t N) { 91 | int64_t kMaxGridNum = at::cuda::getCurrentDeviceProperties()->maxGridSize[0]; 92 | return (unsigned int)std::min(kMaxGridNum, (N + THREADS - 1) / THREADS); 93 | } 94 | 95 | template 96 | __device__ scalar_t bilinear_interpolate( 97 | const scalar_t* in, 98 | index_t height, 99 | index_t width, 100 | scalar_t h, 101 | scalar_t w) { 102 | if (h <= -1 || height <= h || w <= -1 || width <= w) 103 | { 104 | return 0; 105 | } 106 | 107 | index_t h_low = floor(h); 108 | index_t w_low = floor(w); 109 | index_t h_high = h_low + 1; 110 | index_t w_high = w_low + 1; 111 | 112 | scalar_t lh = h - h_low; 113 | scalar_t lw = w - w_low; 114 | scalar_t hh = 1 - lh, hw = 1 - lw; 115 | 116 | scalar_t v1 = 0; 117 | if (h_low >= 0 && w_low >= 0) 118 | v1 = in[h_low * width + w_low]; 119 | scalar_t v2 = 0; 120 | if (h_low >= 0 && w_high <= width - 1) 121 | v2 = in[h_low * width + w_high]; 122 | scalar_t v3 = 0; 123 | if (h_high <= height - 1 && w_low >= 0) 124 | v3 = in[h_high * width + w_low]; 125 | scalar_t v4 = 0; 126 | if (h_high <= height - 1 && w_high <= width - 1) 127 | v4 = in[h_high * width + w_high]; 128 | 129 | scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; 130 | 131 | scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); 132 | return val; 133 | } 134 | 135 | template 136 | __global__ void deformable_im2col_kernel( 137 | index_t n, 138 | const scalar_t* input_ptr, 139 | const scalar_t* offset_ptr, 140 | const scalar_t* mask_ptr, 141 | index_t height, 142 | index_t width, 143 | index_t weight_h, 144 | index_t weight_w, 145 | index_t pad_h, 146 | index_t pad_w, 147 | index_t stride_h, 148 | index_t stride_w, 149 | index_t dilation_h, 150 | index_t dilation_w, 151 | index_t batch_sz, 152 | index_t n_in_channels, 153 | index_t n_offset_grps, 154 | index_t out_h, 155 | index_t out_w, 156 | bool use_mask, 157 | scalar_t* columns_ptr) { 158 | CUDA_1D_KERNEL_LOOP_T(index, n, index_t) { 159 | const index_t out_x = index % out_w; 160 | const index_t out_y = (index / out_w) % out_h; 161 | const index_t out_b = (index / (out_w * out_h)) % batch_sz; 162 | const index_t in_c = index / (out_w * out_h * batch_sz); 163 | const index_t out_c = in_c * weight_h * weight_w; 164 | 165 | index_t c_per_offset_grp = n_in_channels / n_offset_grps; 166 | const index_t grp_idx = in_c / c_per_offset_grp; 167 | 168 | columns_ptr += 169 | (out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) + 170 | out_y * out_w + out_x); 171 | 172 | input_ptr += 173 | (out_b * (n_in_channels * height * width) + in_c * (height * width)); 174 | 175 | offset_ptr += (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * 176 | out_h * out_w; 177 | 178 | if 
(use_mask) 179 | { 180 | mask_ptr += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * 181 | out_h * out_w; 182 | } 183 | 184 | for (int i = 0; i < weight_h; ++i) 185 | { 186 | for (int j = 0; j < weight_w; ++j) 187 | { 188 | const index_t mask_idx = i * weight_w + j; 189 | const index_t offset_idx = 2 * mask_idx; 190 | 191 | scalar_t mask_value = 1; 192 | if (use_mask) 193 | { 194 | mask_value = 195 | mask_ptr[mask_idx * (out_h * out_w) + out_y * out_w + out_x]; 196 | } 197 | 198 | const scalar_t offset_h = 199 | offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x]; 200 | const scalar_t offset_w = offset_ptr 201 | [(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x]; 202 | const scalar_t y = 203 | (out_y * stride_h - pad_h) + i * dilation_h + offset_h; 204 | const scalar_t x = 205 | (out_x * stride_w - pad_w) + j * dilation_w + offset_w; 206 | *columns_ptr = 207 | mask_value * bilinear_interpolate(input_ptr, height, width, y, x); 208 | columns_ptr += batch_sz * out_h * out_w; 209 | } 210 | } 211 | } 212 | } 213 | 214 | void deformable_im2col( 215 | const at::Tensor& input, 216 | const at::Tensor& data_offset, 217 | const at::Tensor& data_mask, 218 | int n_in_channels, 219 | int height, 220 | int width, 221 | int weight_h, 222 | int weight_w, 223 | int pad_h, 224 | int pad_w, 225 | int stride_h, 226 | int stride_w, 227 | int dilation_h, 228 | int dilation_w, 229 | int out_h, 230 | int out_w, 231 | int parallel_imgs, 232 | int deformable_group, 233 | bool use_mask, 234 | at::Tensor data_col) { 235 | at::cuda::CUDAGuard device_guard(input.get_device()); 236 | 237 | const int64_t num_kernels = 238 | (int64_t)n_in_channels * out_h * out_w * parallel_imgs; 239 | 240 | const unsigned int threads = GET_THREADS(); 241 | const unsigned int blocks = GET_BLOCKS(threads, num_kernels); 242 | 243 | // Checks if we should use 64bits indexing 244 | // https://github.com/pytorch/vision/issues/4269 245 | bool use_64bits_indexing = false; 246 | // Checks if num_kernels or columns numel larger than 2 ** 31 247 | use_64bits_indexing |= num_kernels > (1 << 31); 248 | use_64bits_indexing |= 249 | ((int64_t)n_in_channels * weight_h * weight_w * parallel_imgs * out_h * 250 | out_w > 251 | (1 << 31)); 252 | 253 | if (use_64bits_indexing) 254 | { 255 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 256 | input.scalar_type(), "deformable_im2col", ([&] { 257 | deformable_im2col_kernel<<>>( 258 | num_kernels, 259 | input.data_ptr(), 260 | data_offset.data_ptr(), 261 | data_mask.data_ptr(), 262 | height, 263 | width, 264 | weight_h, 265 | weight_w, 266 | pad_h, 267 | pad_w, 268 | stride_h, 269 | stride_w, 270 | dilation_h, 271 | dilation_w, 272 | parallel_imgs, 273 | n_in_channels, 274 | deformable_group, 275 | out_h, 276 | out_w, 277 | use_mask, 278 | data_col.data_ptr()); 279 | })); 280 | 281 | } else 282 | { 283 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 284 | input.scalar_type(), "deformable_im2col", ([&] { 285 | deformable_im2col_kernel<<>>( 286 | num_kernels, 287 | input.data_ptr(), 288 | data_offset.data_ptr(), 289 | data_mask.data_ptr(), 290 | height, 291 | width, 292 | weight_h, 293 | weight_w, 294 | pad_h, 295 | pad_w, 296 | stride_h, 297 | stride_w, 298 | dilation_h, 299 | dilation_w, 300 | parallel_imgs, 301 | n_in_channels, 302 | deformable_group, 303 | out_h, 304 | out_w, 305 | use_mask, 306 | data_col.data_ptr()); 307 | })); 308 | } 309 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 310 | } 311 | 312 | int get_greatest_divisor_below_bound(int n, int bound) { 313 | for (int k = bound; k > 1; --k) 314 | { 
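// Counting down from `bound`, the first k that divides n evenly is the greatest
// divisor of n that does not exceed bound; the helper falls back to 1 when no
// larger divisor exists.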
315 | if (n % k == 0) 316 | { 317 | return k; 318 | } 319 | } 320 | return 1; 321 | } 322 | 323 | template 324 | __global__ void deformable_col2im_kernel( 325 | index_t n, 326 | const scalar_t* col, 327 | const scalar_t* offset_ptr, 328 | const scalar_t* mask_ptr, 329 | index_t channels, 330 | index_t height, 331 | index_t width, 332 | index_t kernel_h, 333 | index_t kernel_w, 334 | index_t pad_h, 335 | index_t pad_w, 336 | index_t stride_h, 337 | index_t stride_w, 338 | index_t dilation_h, 339 | index_t dilation_w, 340 | index_t batch_sz, 341 | index_t n_offset_grps, 342 | index_t out_h, 343 | index_t out_w, 344 | bool use_mask, 345 | scalar_t* grad_im) { 346 | const index_t grad_im_numel = width * height * channels * batch_sz; 347 | 348 | CUDA_1D_KERNEL_LOOP_T(index, n, int64_t) { 349 | const index_t out_x = index % out_w; 350 | const index_t out_y = (index / out_w) % out_h; 351 | const index_t b = (index / (out_w * out_h)) % batch_sz; 352 | const index_t j = (index / (out_w * out_h * batch_sz)) % kernel_w; 353 | const index_t i = 354 | (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h; 355 | const index_t c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h); 356 | 357 | index_t c_per_offset_grp = channels / n_offset_grps; 358 | const index_t offset_grp = c / c_per_offset_grp; 359 | 360 | offset_ptr += (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * 361 | out_h * out_w; 362 | 363 | if (use_mask) 364 | { 365 | mask_ptr += (b * n_offset_grps + offset_grp) * kernel_h * kernel_w * 366 | out_h * out_w; 367 | } 368 | 369 | const index_t mask_idx = i * kernel_w + j; 370 | const index_t offset_idx = 2 * mask_idx; 371 | 372 | const index_t offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x; 373 | const index_t offset_w_ptr = 374 | ((offset_idx + 1) * out_h + out_y) * out_w + out_x; 375 | 376 | const scalar_t offset_h = offset_ptr[offset_h_ptr]; 377 | const scalar_t offset_w = offset_ptr[offset_w_ptr]; 378 | 379 | scalar_t mask_value = 1; 380 | if (use_mask) 381 | { 382 | mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; 383 | } 384 | 385 | const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; 386 | const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; 387 | 388 | for (index_t dy = -1; dy <= 1; dy++) 389 | { 390 | for (index_t dx = -1; dx <= 1; dx++) 391 | { 392 | index_t yp = (index_t)y + dy; 393 | index_t xp = (index_t)x + dx; 394 | if (0 <= yp && yp < height && 0 <= xp && xp < width && 395 | std::abs(y - yp) < 1 && std::abs(x - xp) < 1) 396 | { 397 | index_t grad_pos = ((b * channels + c) * height + yp) * width + xp; 398 | scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp)); 399 | at::native::fastAtomicAdd( 400 | grad_im, 401 | grad_pos, 402 | grad_im_numel, 403 | mask_value * weight * col[index], 404 | true); 405 | } 406 | } 407 | } 408 | } 409 | } 410 | 411 | void compute_grad_input( 412 | const at::Tensor& columns, 413 | const at::Tensor& offset, 414 | const at::Tensor& mask, 415 | int channels, 416 | int height, 417 | int width, 418 | int weight_h, 419 | int weight_w, 420 | int pad_h, 421 | int pad_w, 422 | int stride_h, 423 | int stride_w, 424 | int dilation_h, 425 | int dilation_w, 426 | int parallel_imgs, 427 | int n_offset_grps, 428 | bool use_mask, 429 | at::Tensor grad_im) { 430 | at::cuda::CUDAGuard device_guard(columns.get_device()); 431 | 432 | const int out_h = 433 | (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; 434 | const int out_w = 435 | 
(width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; 436 | 437 | const int64_t num_kernels = 438 | (int64_t)channels * weight_h * weight_w * out_h * out_w * parallel_imgs; 439 | 440 | const unsigned int threads = GET_THREADS(); 441 | const unsigned int blocks = GET_BLOCKS(threads, num_kernels); 442 | 443 | // Checks if we should use 64bits indexing 444 | // https://github.com/pytorch/vision/issues/4269 445 | bool use_64bits_indexing = false; 446 | // Checks if num_kernels or columns numel larger than 2 ** 31 447 | use_64bits_indexing |= num_kernels > (1 << 31); 448 | 449 | at::globalContext().alertNotDeterministic("compute_grad_input"); 450 | 451 | if (use_64bits_indexing) 452 | { 453 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 454 | columns.scalar_type(), "compute_grad_input", ([&] { 455 | deformable_col2im_kernel<<>>( 456 | num_kernels, 457 | columns.data_ptr(), 458 | offset.data_ptr(), 459 | mask.data_ptr(), 460 | channels, 461 | height, 462 | width, 463 | weight_h, 464 | weight_w, 465 | pad_h, 466 | pad_w, 467 | stride_h, 468 | stride_w, 469 | dilation_h, 470 | dilation_w, 471 | parallel_imgs, 472 | n_offset_grps, 473 | out_h, 474 | out_w, 475 | use_mask, 476 | grad_im.data_ptr()); 477 | })); 478 | } else 479 | { 480 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 481 | columns.scalar_type(), "compute_grad_input", ([&] { 482 | deformable_col2im_kernel<<>>( 483 | num_kernels, 484 | columns.data_ptr(), 485 | offset.data_ptr(), 486 | mask.data_ptr(), 487 | channels, 488 | height, 489 | width, 490 | weight_h, 491 | weight_w, 492 | pad_h, 493 | pad_w, 494 | stride_h, 495 | stride_w, 496 | dilation_h, 497 | dilation_w, 498 | parallel_imgs, 499 | n_offset_grps, 500 | out_h, 501 | out_w, 502 | use_mask, 503 | grad_im.data_ptr()); 504 | })); 505 | } 506 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 507 | } 508 | 509 | template 510 | __device__ scalar_t get_coordinate_weight( 511 | const scalar_t* im_data, 512 | index_t height, 513 | index_t width, 514 | scalar_t y, 515 | scalar_t x, 516 | bool is_y_direction) { 517 | index_t y_l = floor(y); 518 | index_t x_l = floor(x); 519 | index_t y_h = y_l + 1; 520 | index_t x_h = x_l + 1; 521 | 522 | bool valid_y_l = 0 <= y_l && y_l < height; 523 | bool valid_y_h = 0 <= y_h && y_h < height; 524 | bool valid_x_l = 0 <= x_l && x_l < width; 525 | bool valid_x_h = 0 <= x_h && x_h < width; 526 | 527 | scalar_t zero = 0; 528 | scalar_t v_yx = (valid_y_l && valid_x_l) ? im_data[y_l * width + x_l] : zero; 529 | scalar_t v_yX = (valid_y_l && valid_x_h) ? im_data[y_l * width + x_h] : zero; 530 | scalar_t v_Yx = (valid_y_h && valid_x_l) ? im_data[y_h * width + x_l] : zero; 531 | scalar_t v_YX = (valid_y_h && valid_x_h) ? 
im_data[y_h * width + x_h] : zero; 532 | 533 | if (is_y_direction) 534 | { 535 | scalar_t dx = x - x_l; 536 | return dx * (v_YX - v_yX) + (1 - dx) * (v_Yx - v_yx); 537 | } else 538 | { 539 | scalar_t dy = y - y_l; 540 | return dy * (v_YX - v_Yx) + (1 - dy) * (v_yX - v_yx); 541 | } 542 | } 543 | 544 | template 545 | __global__ void deformable_col2im_coord_kernel( 546 | index_t n, 547 | const scalar_t* col_ptr, 548 | const scalar_t* im_ptr, 549 | const scalar_t* offset_ptr, 550 | const scalar_t* mask_ptr, 551 | index_t channels, 552 | index_t height, 553 | index_t width, 554 | index_t weight_h, 555 | index_t weight_w, 556 | index_t pad_h, 557 | index_t pad_w, 558 | index_t stride_h, 559 | index_t stride_w, 560 | index_t dilation_h, 561 | index_t dilation_w, 562 | index_t batch_sz, 563 | index_t offset_channels, 564 | index_t n_offset_grps, 565 | index_t out_h, 566 | index_t out_w, 567 | const bool use_mask, 568 | scalar_t* grad_offset, 569 | scalar_t* grad_mask) { 570 | CUDA_1D_KERNEL_LOOP_T(index, n, int64_t) { 571 | scalar_t grad_offset_val = 0; 572 | scalar_t grad_mask_val = 0; 573 | 574 | index_t w = index % out_w; 575 | index_t h = (index / out_w) % out_h; 576 | index_t w_w = (index / (out_w * out_h * 2)) % weight_w; 577 | index_t w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h; 578 | index_t c = (index / (out_w * out_h)) % offset_channels; 579 | index_t b = index / (out_w * out_h * offset_channels); 580 | 581 | const index_t offset_grp = c / (2 * weight_h * weight_w); 582 | const index_t col_step = weight_h * weight_w; 583 | 584 | index_t c_per_offset_grp = channels / n_offset_grps; 585 | 586 | col_ptr += offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz * 587 | out_w * out_h; 588 | im_ptr += 589 | (b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width; 590 | offset_ptr += (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * 591 | out_h * out_w; 592 | 593 | if (use_mask) 594 | { 595 | mask_ptr += (b * n_offset_grps + offset_grp) * weight_h * weight_w * 596 | out_h * out_w; 597 | } 598 | 599 | const index_t offset_c = c - offset_grp * 2 * weight_h * weight_w; 600 | const bool is_y_direction = offset_c % 2 == 0; 601 | 602 | const index_t c_bound = c_per_offset_grp * weight_h * weight_w; 603 | for (index_t col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) 604 | { 605 | const index_t col_pos = 606 | (((col_c * batch_sz + b) * out_h) + h) * out_w + w; 607 | 608 | index_t out_x = col_pos % out_w; 609 | index_t out_y = (col_pos / out_w) % out_h; 610 | index_t j = (col_pos / (out_w * out_h * batch_sz)) % weight_w; 611 | index_t i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h; 612 | 613 | const index_t mask_idx = i * weight_w + j; 614 | 615 | const index_t offset_h_ptr = 616 | (((2 * mask_idx) * out_h + out_y) * out_w + out_x); 617 | const index_t offset_w_ptr = 618 | (((2 * mask_idx + 1) * out_h + out_y) * out_w + out_x); 619 | const scalar_t offset_h = offset_ptr[offset_h_ptr]; 620 | const scalar_t offset_w = offset_ptr[offset_w_ptr]; 621 | 622 | scalar_t mask_value = 1; 623 | if (use_mask) 624 | { 625 | mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; 626 | } 627 | 628 | scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; 629 | scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; 630 | 631 | const scalar_t weight = 632 | get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction); 633 | grad_offset_val += mask_value * weight * col_ptr[col_pos]; 634 | 635 | if 
(use_mask && is_y_direction) 636 | { 637 | grad_mask_val += col_ptr[col_pos] * 638 | bilinear_interpolate(im_ptr, height, width, y, x); 639 | } 640 | 641 | im_ptr += height * width; 642 | } 643 | 644 | grad_offset[index] = grad_offset_val; 645 | 646 | if (use_mask && is_y_direction) 647 | { 648 | const index_t idx = 649 | ((((b * n_offset_grps + offset_grp) * weight_h + w_h) * weight_w + 650 | w_w) * 651 | out_h + 652 | h) * 653 | out_w + 654 | w; 655 | grad_mask[idx] = grad_mask_val; 656 | } 657 | } 658 | } 659 | 660 | void compute_grad_offset_and_mask( 661 | const at::Tensor& columns, 662 | const at::Tensor& input, 663 | const at::Tensor& offset, 664 | const at::Tensor& mask, 665 | int channels, 666 | int height, 667 | int width, 668 | int weight_h, 669 | int weight_w, 670 | int pad_h, 671 | int pad_w, 672 | int stride_h, 673 | int stride_w, 674 | int dilation_h, 675 | int dilation_w, 676 | int parallel_imgs, 677 | int n_offset_grps, 678 | bool use_mask, 679 | at::Tensor grad_offset, 680 | at::Tensor grad_mask) { 681 | at::cuda::CUDAGuard device_guard(columns.get_device()); 682 | 683 | const int out_h = 684 | (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; 685 | const int out_w = 686 | (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; 687 | const int64_t num_kernels = (int64_t)out_h * out_w * 2 * weight_h * weight_w * 688 | n_offset_grps * parallel_imgs; 689 | 690 | const unsigned int threads = GET_THREADS(); 691 | const unsigned int blocks = GET_BLOCKS(threads, num_kernels); 692 | 693 | // Checks if we should use 64bits indexing 694 | // https://github.com/pytorch/vision/issues/4269 695 | bool use_64bits_indexing = false; 696 | // Checks if columns numel is larger than 2 ** 31 697 | use_64bits_indexing |= num_kernels > (1 << 31); 698 | use_64bits_indexing |= 699 | ((int64_t)channels * weight_h * weight_w * parallel_imgs * out_h * out_w > 700 | (1 << 31)); 701 | 702 | if (use_64bits_indexing) 703 | { 704 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 705 | columns.scalar_type(), "compute_grad_offset_and_mask", ([&] { 706 | deformable_col2im_coord_kernel 707 | <<>>( 708 | num_kernels, 709 | columns.data_ptr(), 710 | input.data_ptr(), 711 | offset.data_ptr(), 712 | mask.data_ptr(), 713 | channels, 714 | height, 715 | width, 716 | weight_h, 717 | weight_w, 718 | pad_h, 719 | pad_w, 720 | stride_h, 721 | stride_w, 722 | dilation_h, 723 | dilation_w, 724 | parallel_imgs, 725 | 2 * weight_h * weight_w * n_offset_grps, 726 | n_offset_grps, 727 | out_h, 728 | out_w, 729 | use_mask, 730 | grad_offset.data_ptr(), 731 | grad_mask.data_ptr()); 732 | })); 733 | } else 734 | { 735 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 736 | columns.scalar_type(), "compute_grad_offset_and_mask", ([&] { 737 | deformable_col2im_coord_kernel<<>>( 738 | num_kernels, 739 | columns.data_ptr(), 740 | input.data_ptr(), 741 | offset.data_ptr(), 742 | mask.data_ptr(), 743 | channels, 744 | height, 745 | width, 746 | weight_h, 747 | weight_w, 748 | pad_h, 749 | pad_w, 750 | stride_h, 751 | stride_w, 752 | dilation_h, 753 | dilation_w, 754 | parallel_imgs, 755 | 2 * weight_h * weight_w * n_offset_grps, 756 | n_offset_grps, 757 | out_h, 758 | out_w, 759 | use_mask, 760 | grad_offset.data_ptr(), 761 | grad_mask.data_ptr()); 762 | })); 763 | } 764 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 765 | } 766 | 767 | std::tuple backward_gradient_inputs( 768 | at::Tensor input, 769 | at::Tensor weight, 770 | at::Tensor offset, 771 | at::Tensor mask, 772 | at::Tensor grad_out, 773 | int stride_h, 774 | int 
stride_w, 775 | int pad_h, 776 | int pad_w, 777 | int dilation_h, 778 | int dilation_w, 779 | int n_weight_grps, 780 | int n_offset_grps, 781 | int n_parallel_imgs, 782 | bool use_mask) { 783 | at::DeviceGuard guard(input.device()); 784 | 785 | int batch_sz = input.size(0); 786 | long n_in_channels = input.size(1); 787 | long in_h = input.size(2); 788 | long in_w = input.size(3); 789 | 790 | n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); 791 | 792 | long n_out_channels = weight.size(0); 793 | int weight_h = weight.size(2); 794 | int weight_w = weight.size(3); 795 | 796 | long out_w = 797 | (in_w + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; 798 | long out_h = 799 | (in_h + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; 800 | 801 | auto grad_input = at::zeros_like(input); 802 | auto grad_offset = at::zeros_like(offset); 803 | auto grad_mask = at::zeros_like(mask); 804 | 805 | if (batch_sz == 0) 806 | { 807 | return std::make_tuple(grad_input, grad_offset, grad_mask); 808 | } 809 | 810 | auto columns = at::empty( 811 | {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, 812 | input.options()); 813 | 814 | // Separate into blocks 815 | grad_input = grad_input.reshape( 816 | {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); 817 | input = input.reshape( 818 | {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); 819 | 820 | grad_offset = grad_offset.reshape( 821 | {batch_sz / n_parallel_imgs, 822 | n_parallel_imgs, 823 | n_offset_grps * 2 * weight_h * weight_w, 824 | out_h, 825 | out_w}); 826 | offset = offset.reshape( 827 | {batch_sz / n_parallel_imgs, 828 | n_parallel_imgs, 829 | n_offset_grps * 2 * weight_h * weight_w, 830 | out_h, 831 | out_w}); 832 | 833 | if (use_mask) 834 | { 835 | grad_mask = grad_mask.reshape( 836 | {batch_sz / n_parallel_imgs, 837 | n_parallel_imgs, 838 | n_offset_grps * weight_h * weight_w, 839 | out_h, 840 | out_w}); 841 | mask = mask.reshape( 842 | {batch_sz / n_parallel_imgs, 843 | n_parallel_imgs, 844 | n_offset_grps * weight_h * weight_w, 845 | out_h, 846 | out_w}); 847 | } 848 | 849 | grad_out = grad_out 850 | .reshape( 851 | {batch_sz / n_parallel_imgs, 852 | n_parallel_imgs, 853 | n_weight_grps, 854 | n_out_channels / n_weight_grps, 855 | out_h, 856 | out_w}) 857 | .permute({0, 2, 3, 1, 4, 5}); 858 | 859 | weight = weight.reshape( 860 | {n_weight_grps, 861 | weight.size(0) / n_weight_grps, 862 | weight.size(1), 863 | weight.size(2), 864 | weight.size(3)}); 865 | 866 | columns = columns.view( 867 | {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); 868 | for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) 869 | { 870 | columns.zero_(); 871 | // Separate into weight groups 872 | for (int g = 0; g < n_weight_grps; g++) 873 | { 874 | columns[g] = columns[g].addmm_( 875 | weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1)); 876 | } 877 | 878 | compute_grad_offset_and_mask( 879 | columns, 880 | input[elt], 881 | offset[elt], 882 | mask[elt], 883 | n_in_channels, 884 | in_h, 885 | in_w, 886 | weight_h, 887 | weight_w, 888 | pad_h, 889 | pad_w, 890 | stride_h, 891 | stride_w, 892 | dilation_h, 893 | dilation_w, 894 | n_parallel_imgs, 895 | n_offset_grps, 896 | use_mask, 897 | grad_offset[elt], 898 | grad_mask[elt]); 899 | 900 | compute_grad_input( 901 | columns, 902 | offset[elt], 903 | mask[elt], 904 | n_in_channels, 905 | in_h, 906 | in_w, 907 | weight_h, 908 | weight_w, 909 | pad_h, 910 | pad_w, 911 | stride_h, 912 | 
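// Rough shape guide for the addmm in the weight-group loop above, writing
// G = n_weight_grps and P = n_parallel_imgs:
//   weight[g].flatten(1).transpose(0, 1) : [(C_in/G)*kh*kw, C_out/G]
//   grad_out[elt][g].flatten(1)          : [C_out/G, P*out_h*out_w]
//   columns[g]                           : [(C_in/G)*kh*kw, P*out_h*out_w]
// i.e. the output gradient is re-expressed in im2col layout, then consumed by
// compute_grad_offset_and_mask above and by compute_grad_input to scatter the
// gradient back onto the input.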
stride_w, 913 | dilation_h, 914 | dilation_w, 915 | n_parallel_imgs, 916 | n_offset_grps, 917 | use_mask, 918 | grad_input[elt]); 919 | } 920 | 921 | grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w}); 922 | grad_offset = grad_offset.view( 923 | {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); 924 | 925 | if (use_mask) 926 | { 927 | grad_mask = grad_mask.view( 928 | {batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w}); 929 | } 930 | 931 | return std::make_tuple(grad_input, grad_offset, grad_mask); 932 | } 933 | 934 | at::Tensor backward_gradient_parameters( 935 | at::Tensor input, 936 | const at::Tensor& weight, 937 | at::Tensor offset, 938 | at::Tensor mask, 939 | const at::Tensor& grad_out, 940 | int stride_h, 941 | int stride_w, 942 | int pad_h, 943 | int pad_w, 944 | int dilation_h, 945 | int dilation_w, 946 | int n_weight_grps, 947 | int n_offset_grps, 948 | int n_parallel_imgs, 949 | bool use_mask) { 950 | at::DeviceGuard guard(input.device()); 951 | 952 | int batch_sz = input.size(0); 953 | long n_in_channels = input.size(1); 954 | long in_h = input.size(2); 955 | long in_w = input.size(3); 956 | 957 | n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); 958 | 959 | long n_out_channels = weight.size(0); 960 | int weight_h = weight.size(2); 961 | int weight_w = weight.size(3); 962 | 963 | long out_h = grad_out.size(2); 964 | long out_w = grad_out.size(3); 965 | 966 | auto grad_weight = at::zeros_like(weight); 967 | if (batch_sz == 0) 968 | { 969 | return grad_weight; 970 | } 971 | 972 | at::Tensor grad_out_buf = grad_out 973 | .reshape( 974 | {batch_sz / n_parallel_imgs, 975 | n_parallel_imgs, 976 | n_weight_grps, 977 | n_out_channels / n_weight_grps, 978 | out_h, 979 | out_w}) 980 | .permute({0, 2, 3, 1, 4, 5}) 981 | .contiguous(); 982 | 983 | input = input.reshape( 984 | {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); 985 | 986 | offset = offset.reshape( 987 | {batch_sz / n_parallel_imgs, 988 | n_parallel_imgs, 989 | n_offset_grps * 2 * weight_h * weight_w, 990 | out_h, 991 | out_w}); 992 | 993 | if (use_mask) 994 | { 995 | mask = mask.reshape( 996 | {batch_sz / n_parallel_imgs, 997 | n_parallel_imgs, 998 | n_offset_grps * weight_h * weight_w, 999 | out_h, 1000 | out_w}); 1001 | } 1002 | 1003 | grad_weight = grad_weight.reshape( 1004 | {n_weight_grps, 1005 | grad_weight.size(0) / n_weight_grps, 1006 | grad_weight.size(1), 1007 | grad_weight.size(2), 1008 | grad_weight.size(3)}); 1009 | 1010 | auto columns = at::empty( 1011 | {n_weight_grps, 1012 | n_in_channels * weight_w * weight_h / n_weight_grps, 1013 | n_parallel_imgs * out_h * out_w}, 1014 | input.options()); 1015 | 1016 | for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) 1017 | { 1018 | deformable_im2col( 1019 | input[elt], 1020 | offset[elt], 1021 | mask[elt], 1022 | n_in_channels, 1023 | in_h, 1024 | in_w, 1025 | weight_h, 1026 | weight_w, 1027 | pad_h, 1028 | pad_w, 1029 | stride_h, 1030 | stride_w, 1031 | dilation_h, 1032 | dilation_w, 1033 | out_h, 1034 | out_w, 1035 | n_parallel_imgs, 1036 | n_offset_grps, 1037 | use_mask, 1038 | columns); 1039 | 1040 | for (int g = 0; g < n_weight_grps; g++) 1041 | { 1042 | grad_weight[g] = 1043 | grad_weight[g] 1044 | .flatten(1) 1045 | .addmm_( 1046 | grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0)) 1047 | .view_as(grad_weight[g]); 1048 | } 1049 | } 1050 | 1051 | grad_weight = grad_weight.view( 1052 | {grad_weight.size(0) * grad_weight.size(1), 1053 | grad_weight.size(2), 1054 | 
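// backward_gradient_parameters accumulates, per weight group g,
//   grad_weight[g].flatten(1)            : [C_out/G, (C_in/G)*kh*kw]
//     += grad_out_buf[elt][g].flatten(1) : [C_out/G, P*out_h*out_w]
//      @ columns[g].transpose(1, 0)      : [P*out_h*out_w, (C_in/G)*kh*kw]
// over batch chunks of n_parallel_imgs images. Judging by its name,
// get_greatest_divisor_below_bound picks the largest divisor of the batch
// size not above kMaxParallelImgs, so e.g. a batch of 12 with a bound of 8
// would be processed in chunks of 6. The view being rebuilt here simply folds
// the group dimension back into the output-channel dimension.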
grad_weight.size(3), 1055 | grad_weight.size(4)}); 1056 | return grad_weight; 1057 | } 1058 | 1059 | at::Tensor deform_conv2d_forward_kernel( 1060 | const at::Tensor& input, 1061 | const at::Tensor& weight, 1062 | const at::Tensor& offset, 1063 | const at::Tensor& mask, 1064 | const at::Tensor& bias, 1065 | int64_t stride_h, 1066 | int64_t stride_w, 1067 | int64_t pad_h, 1068 | int64_t pad_w, 1069 | int64_t dilation_h, 1070 | int64_t dilation_w, 1071 | int64_t n_weight_grps, 1072 | int64_t n_offset_grps, 1073 | bool use_mask) { 1074 | at::Tensor input_c = input.contiguous(); 1075 | at::Tensor offset_c = offset.contiguous(); 1076 | at::Tensor weight_c = weight.contiguous(); 1077 | at::Tensor mask_c = mask.contiguous(); 1078 | at::Tensor bias_c = bias.contiguous(); 1079 | 1080 | TORCH_CHECK(input_c.ndimension() == 4); 1081 | TORCH_CHECK(offset_c.ndimension() == 4); 1082 | TORCH_CHECK(!use_mask || mask_c.ndimension() == 4); 1083 | TORCH_CHECK(weight_c.ndimension() == 4); 1084 | TORCH_CHECK(input_c.is_cuda(), "input must be a CUDA tensor"); 1085 | 1086 | at::DeviceGuard guard(input_c.device()); 1087 | 1088 | int batch_sz = input_c.size(0); 1089 | int in_channels = input_c.size(1); 1090 | int in_h = input_c.size(2); 1091 | int in_w = input_c.size(3); 1092 | 1093 | int n_parallel_imgs = 1094 | get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); 1095 | 1096 | int out_channels = weight_c.size(0); 1097 | int weight_h = weight_c.size(2); 1098 | int weight_w = weight_c.size(3); 1099 | 1100 | int ker_h = dilation_h * (weight_h - 1) + 1; 1101 | int ker_w = dilation_w * (weight_w - 1) + 1; 1102 | int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1; 1103 | int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1; 1104 | 1105 | TORCH_CHECK( 1106 | weight_h > 0 && weight_w > 0, 1107 | "weight_h: ", 1108 | weight_h, 1109 | " weight_w: ", 1110 | weight_w); 1111 | TORCH_CHECK( 1112 | stride_h > 0 && stride_w > 0, 1113 | "stride_h: ", 1114 | stride_h, 1115 | " stride_w: ", 1116 | stride_w); 1117 | TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w); 1118 | TORCH_CHECK( 1119 | dilation_h > 0 && dilation_w > 0, 1120 | "dilation_h: ", 1121 | dilation_h, 1122 | " dilation_w: ", 1123 | dilation_w); 1124 | 1125 | TORCH_CHECK(weight_c.size(1) * n_weight_grps == input_c.size(1)); 1126 | TORCH_CHECK(weight_c.size(0) % n_weight_grps == 0); 1127 | TORCH_CHECK( 1128 | (offset_c.size(1) == n_offset_grps * 2 * weight_h * weight_w), 1129 | "offset.shape[1] is not valid: got: ", 1130 | offset_c.size(1), 1131 | " expected: ", 1132 | n_offset_grps * 2 * weight_h * weight_w); 1133 | TORCH_CHECK( 1134 | (!use_mask || mask_c.size(1) == n_offset_grps * weight_h * weight_w), 1135 | "mask.shape[1] is not valid: got: ", 1136 | mask_c.size(1), 1137 | " expected: ", 1138 | n_offset_grps * weight_h * weight_w); 1139 | TORCH_CHECK(input_c.size(1) % n_offset_grps == 0); 1140 | 1141 | TORCH_CHECK( 1142 | (offset_c.size(0) == input_c.size(0)), "invalid batch size of offset"); 1143 | TORCH_CHECK( 1144 | (offset_c.size(2) == out_h && offset_c.size(3) == out_w), 1145 | "offset output dims: (", 1146 | offset_c.size(2), 1147 | ", ", 1148 | offset_c.size(3), 1149 | ") - ", 1150 | "computed output dims: (", 1151 | out_h, 1152 | ", ", 1153 | out_w, 1154 | ")"); 1155 | TORCH_CHECK( 1156 | (mask_c.size(0) == input_c.size(0)), "invalid batch size of mask"); 1157 | TORCH_CHECK( 1158 | (!use_mask || (mask_c.size(2) == out_h && mask_c.size(3) == out_w)), 1159 | "mask output dims: (", 1160 | mask_c.size(2), 1161 | ", ", 
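// The expected out_h / out_w used by these shape checks come from the usual
// convolution arithmetic computed above:
//   out = (in + 2*pad - (dilation*(k - 1) + 1)) / stride + 1.
// For illustration: in = 480, pad = 1, k = 3, dilation = 1, stride = 2 gives
// (480 + 2 - 3) / 2 + 1 = 240, so the offset and mask tensors must already be
// laid out on that 240-row output grid.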
1162 | mask_c.size(3), 1163 | ") - ", 1164 | "computed output dims: (", 1165 | out_h, 1166 | ", ", 1167 | out_w, 1168 | ")"); 1169 | TORCH_CHECK( 1170 | out_h > 0 && out_w > 0, 1171 | "Calculated output size too small - out_h: ", 1172 | out_h, 1173 | " out_w: ", 1174 | out_w); 1175 | 1176 | auto out = 1177 | at::zeros({batch_sz, out_channels, out_h, out_w}, input_c.options()); 1178 | if (batch_sz == 0) 1179 | { 1180 | return out; 1181 | } 1182 | 1183 | // Separate batches into blocks 1184 | out = out.view( 1185 | {batch_sz / n_parallel_imgs, 1186 | n_parallel_imgs, 1187 | out_channels, 1188 | out_h, 1189 | out_w}); 1190 | input_c = input_c.view( 1191 | {batch_sz / n_parallel_imgs, n_parallel_imgs, in_channels, in_h, in_w}); 1192 | 1193 | offset_c = offset_c.view( 1194 | {batch_sz / n_parallel_imgs, 1195 | n_parallel_imgs, 1196 | n_offset_grps * 2 * weight_h * weight_w, 1197 | out_h, 1198 | out_w}); 1199 | 1200 | if (use_mask) 1201 | { 1202 | mask_c = mask_c.view( 1203 | {batch_sz / n_parallel_imgs, 1204 | n_parallel_imgs, 1205 | n_offset_grps * weight_h * weight_w, 1206 | out_h, 1207 | out_w}); 1208 | } 1209 | 1210 | at::Tensor out_buf = at::zeros( 1211 | {batch_sz / n_parallel_imgs, 1212 | out_channels, 1213 | n_parallel_imgs * out_h, 1214 | out_w}, 1215 | out.options()); 1216 | 1217 | // Separate channels into convolution groups 1218 | out_buf = out_buf.view( 1219 | {out_buf.size(0), 1220 | n_weight_grps, 1221 | out_buf.size(1) / n_weight_grps, 1222 | out_buf.size(2), 1223 | out_buf.size(3)}); 1224 | weight_c = weight_c.view( 1225 | {n_weight_grps, 1226 | weight_c.size(0) / n_weight_grps, 1227 | weight_c.size(1), 1228 | weight_c.size(2), 1229 | weight_c.size(3)}); 1230 | 1231 | // Sample points and perform convolution 1232 | auto columns = at::zeros( 1233 | {in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w}, 1234 | input_c.options()); 1235 | for (int b = 0; b < batch_sz / n_parallel_imgs; b++) 1236 | { 1237 | deformable_im2col( 1238 | input_c[b], 1239 | offset_c[b], 1240 | mask_c[b], 1241 | in_channels, 1242 | in_h, 1243 | in_w, 1244 | weight_h, 1245 | weight_w, 1246 | pad_h, 1247 | pad_w, 1248 | stride_h, 1249 | stride_w, 1250 | dilation_h, 1251 | dilation_w, 1252 | out_h, 1253 | out_w, 1254 | n_parallel_imgs, 1255 | n_offset_grps, 1256 | use_mask, 1257 | columns); 1258 | 1259 | columns = columns.view( 1260 | {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); 1261 | for (int g = 0; g < n_weight_grps; g++) 1262 | { 1263 | out_buf[b][g] = out_buf[b][g] 1264 | .flatten(1) 1265 | .addmm_(weight_c[g].flatten(1), columns[g]) 1266 | .view_as(out_buf[b][g]); 1267 | } 1268 | columns = 1269 | columns.view({columns.size(0) * columns.size(1), columns.size(2)}); 1270 | } 1271 | 1272 | out_buf = out_buf.view( 1273 | {batch_sz / n_parallel_imgs, 1274 | out_channels, 1275 | n_parallel_imgs, 1276 | out_h, 1277 | out_w}); 1278 | out_buf.transpose_(1, 2); 1279 | out.copy_(out_buf); 1280 | out = out.view({batch_sz, out_channels, out_h, out_w}); 1281 | 1282 | return out + bias_c.view({1, out_channels, 1, 1}); 1283 | } 1284 | 1285 | std::tuple 1286 | deform_conv2d_backward_kernel( 1287 | const at::Tensor& grad_out, 1288 | const at::Tensor& input, 1289 | const at::Tensor& weight, 1290 | const at::Tensor& offset, 1291 | const at::Tensor& mask, 1292 | const at::Tensor& bias, 1293 | int64_t stride_h, 1294 | int64_t stride_w, 1295 | int64_t pad_h, 1296 | int64_t pad_w, 1297 | int64_t dilation_h, 1298 | int64_t dilation_w, 1299 | int64_t n_weight_grps, 1300 | int64_t 
n_offset_grps, 1301 | bool use_mask) { 1302 | at::Tensor grad_out_c = grad_out.contiguous(); 1303 | at::Tensor input_c = input.contiguous(); 1304 | at::Tensor weight_c = weight.contiguous(); 1305 | at::Tensor offset_c = offset.contiguous(); 1306 | at::Tensor mask_c = mask.contiguous(); 1307 | at::Tensor bias_c = bias.contiguous(); 1308 | 1309 | const int batch_sz = input_c.size(0); 1310 | const int n_parallel_imgs = 1311 | get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); 1312 | 1313 | auto grad_input_and_offset_and_mask = backward_gradient_inputs( 1314 | input_c, 1315 | weight_c, 1316 | offset_c, 1317 | mask_c, 1318 | grad_out_c, 1319 | stride_h, 1320 | stride_w, 1321 | pad_h, 1322 | pad_w, 1323 | dilation_h, 1324 | dilation_w, 1325 | n_weight_grps, 1326 | n_offset_grps, 1327 | n_parallel_imgs, 1328 | use_mask); 1329 | 1330 | auto grad_input = std::get<0>(grad_input_and_offset_and_mask); 1331 | auto grad_offset = std::get<1>(grad_input_and_offset_and_mask); 1332 | auto grad_mask = std::get<2>(grad_input_and_offset_and_mask); 1333 | 1334 | auto grad_weight = backward_gradient_parameters( 1335 | input_c, 1336 | weight_c, 1337 | offset_c, 1338 | mask_c, 1339 | grad_out_c, 1340 | stride_h, 1341 | stride_w, 1342 | pad_h, 1343 | pad_w, 1344 | dilation_h, 1345 | dilation_w, 1346 | n_weight_grps, 1347 | n_offset_grps, 1348 | n_parallel_imgs, 1349 | use_mask); 1350 | 1351 | auto value = grad_out_c.sum({0, 2, 3}); 1352 | auto grad_bias = at::ones_like(bias_c) * value; 1353 | 1354 | return std::make_tuple( 1355 | grad_input, grad_weight, grad_offset, grad_mask, grad_bias); 1356 | } 1357 | 1358 | } // namespace 1359 | 1360 | TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { 1361 | m.impl( 1362 | TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"), 1363 | TORCH_FN(deform_conv2d_forward_kernel)); 1364 | m.impl( 1365 | TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"), 1366 | TORCH_FN(deform_conv2d_backward_kernel)); 1367 | } 1368 | 1369 | } // namespace ops 1370 | } // namespace vision 1371 | -------------------------------------------------------------------------------- /src/feature/get_patches.cpp: -------------------------------------------------------------------------------- 1 | #include "feature/get_patches_cuda.h" 2 | #include 3 | #include 4 | 5 | // map: CxHxW 6 | // points: Nx2 7 | // kernel_size: int 8 | // return: N x C x kernel_size x kernel_size 9 | namespace custom_ops { 10 | torch::Tensor get_patches_forward_cpu(const torch::Tensor& map, torch::Tensor& points, int64_t kernel_size) { 11 | namespace F = torch::nn::functional; 12 | using namespace torch::indexing; 13 | 14 | auto N = points.size(0); 15 | auto C = map.size(0); 16 | // kernel_size=2, radius=0.5, pad_left_top=0, pad_right_bottom=1 17 | // kernel_size=3, radius=1.0, pad_left_top=1, pad_right_bottom=1 18 | // kernel_size=4, radius=1.5, pad_left_top=1, pad_right_bottom=2 19 | // kernel_size=5, radius=2.0, pad_left_top=2, pad_right_bottom=2 20 | auto radius = (kernel_size - 1.0) / 2.0; 21 | int pad_left_top = floor(radius); 22 | int pad_right_bottom = ceil(radius); 23 | 24 | // pad map 25 | auto options = F::PadFuncOptions({pad_left_top, pad_right_bottom, pad_left_top, pad_right_bottom}).mode(torch::kConstant); 26 | auto map_pad = F::pad(map.unsqueeze(0), options).squeeze(0); // Cx(H+2*radius)x(W+2*radius) 27 | 28 | // get patches 29 | torch::Tensor patches = torch::zeros({N, C, kernel_size, kernel_size}, map.options()); 30 | auto a_points = points.accessor(); // Nx2 31 | auto a_map_pad = map_pad.accessor(); // 
Cx(H+2*radius)x(W+2*radius) 32 | auto a_patches = patches.accessor(); // N x C x kernel_size x kernel_size 33 | 34 | for (auto in = 0; in < N; in++) 35 | { 36 | auto w_start = a_points[in][0]; 37 | auto h_start = a_points[in][1]; 38 | 39 | // copy data 40 | for (auto ic = 0; ic < C; ic++) 41 | { 42 | for (auto ih = 0; ih < kernel_size; ih++) 43 | { 44 | for (auto iw = 0; iw < kernel_size; iw++) 45 | { 46 | a_patches[in][ic][ih][iw] = a_map_pad[ic][ih + h_start][iw + w_start]; 47 | } 48 | } 49 | } 50 | } 51 | return patches; 52 | } 53 | 54 | // patches: NxCx(2*radius+1)x(2*radius+1) 55 | // points: Nx2 56 | torch::Tensor 57 | get_patches_backward_cpu(const torch::Tensor& d_patches, torch::Tensor& points, int64_t H, int64_t W) { 58 | namespace F = torch::nn::functional; 59 | using namespace torch::indexing; 60 | 61 | auto N = d_patches.size(0); 62 | auto C = d_patches.size(1); 63 | // kernel_size=2, radius=0.5, pad_left_top=0, pad_right_bottom=1 64 | // kernel_size=3, radius=1.0, pad_left_top=1, pad_right_bottom=1 65 | // kernel_size=4, radius=1.5, pad_left_top=1, pad_right_bottom=2 66 | // kernel_size=5, radius=2.0, pad_left_top=2, pad_right_bottom=2 67 | auto kernel_size = d_patches.size(2); 68 | auto radius = (kernel_size - 1.0) / 2.0; 69 | int pad_left_top = floor(radius); 70 | int pad_right_bottom = ceil(radius); 71 | // printf("kernel_size=%d, radius=%f, pad_left_top=%d, pad_right_bottom=%d\n", 72 | // kernel_size, 73 | // radius, 74 | // pad_left_top, 75 | // pad_right_bottom); 76 | 77 | torch::Tensor d_map_pad = torch::zeros({C, H + int(2 * radius), W + int(2 * radius)}, d_patches.options()); 78 | 79 | auto a_points = points.accessor(); // Nx2 80 | auto a_d_map_pad = d_map_pad.accessor(); // Cx(H+2*radius)x(W+2*radius) 81 | auto a_p_patches = d_patches.accessor(); // NxCxkernel_sizexkernel_size 82 | for (auto in = 0; in < N; in++) 83 | { 84 | // long w_start = static_cast(*(p_points + in * 2 + 0)); 85 | // long h_start = static_cast(*(p_points + in * 2 + 1)); 86 | auto w_start = a_points[in][0]; 87 | auto h_start = a_points[in][1]; 88 | 89 | // copy data 90 | for (auto ic = 0; ic < C; ic++) 91 | { 92 | for (auto ih = 0; ih < kernel_size; ih++) 93 | { 94 | for (auto iw = 0; iw < kernel_size; iw++) 95 | { 96 | a_d_map_pad[ic][ih + h_start][iw + w_start] = a_p_patches[in][ic][ih][iw]; 97 | } 98 | } 99 | } 100 | } 101 | 102 | auto d_map = d_map_pad.index( 103 | {Slice(), Slice(pad_left_top, -pad_right_bottom), Slice(pad_left_top, -pad_right_bottom)}); 104 | 105 | return d_map; 106 | } 107 | 108 | torch::Tensor get_patches_forward(const torch::Tensor& map, torch::Tensor& points, int64_t kernel_size) { 109 | if (map.device() == torch::kCPU) 110 | return get_patches_forward_cpu(map, points, kernel_size); 111 | else 112 | { 113 | return get_patches_forward_cuda(map, points, kernel_size); 114 | } 115 | } 116 | 117 | torch::Tensor get_patches_backward(const torch::Tensor& d_patches, torch::Tensor& points, int64_t H, int64_t W) { 118 | if (d_patches.device() == torch::kCPU) 119 | return get_patches_backward_cpu(d_patches, points, H, W); 120 | else 121 | return get_patches_backward_cuda(d_patches, points, H, W); 122 | } 123 | } // namespace custom_ops -------------------------------------------------------------------------------- /src/feature/get_patches_cuda.cu: -------------------------------------------------------------------------------- 1 | #include "feature/get_patches_cuda.h" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | namespace F 
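// Worked example of the padding scheme shared by the CPU and CUDA
// get_patches paths: for kernel_size = 4, radius = (4 - 1) / 2 = 1.5, so
// pad_left_top = floor(1.5) = 1 and pad_right_bottom = ceil(1.5) = 2. A
// keypoint at row h then reads padded rows h .. h+3, which correspond to rows
// h-1 .. h+2 of the original map, i.e. even kernel sizes are centred slightly
// towards the top-left, exactly as the comment tables in get_patches.cpp
// describe.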
= torch::nn::functional; 11 | 12 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 13 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 14 | #define CHECK_INPUT(x) \ 15 | CHECK_CUDA(x); \ 16 | CHECK_CONTIGUOUS(x) 17 | 18 | // CUDA: grid stride looping 19 | // 20 | // int64_t _i_n_d_e_x specifically prevents overflow in the loop increment. 21 | // If input.numel() < INT_MAX, _i_n_d_e_x < INT_MAX, except after the final 22 | // iteration of the loop where _i_n_d_e_x += blockDim.x * gridDim.x can be 23 | // greater than INT_MAX. But in that case _i_n_d_e_x >= n, so there are no 24 | // further iterations and the overflowed value in i=_i_n_d_e_x is not used. 25 | #define CUDA_KERNEL_LOOP_TYPE(i, n, index_type) \ 26 | int64_t _i_n_d_e_x = blockIdx.x * blockDim.x + threadIdx.x; \ 27 | for (index_type i = _i_n_d_e_x; _i_n_d_e_x < (n); _i_n_d_e_x += blockDim.x * gridDim.x, i = _i_n_d_e_x) 28 | 29 | #define CUDA_KERNEL_LOOP(i, n) CUDA_KERNEL_LOOP_TYPE(i, n, int) 30 | 31 | // Use 1024 threads per block, which requires cuda sm_2x or above 32 | // constexpr int CUDA_NUM_THREADS = 1024; 33 | constexpr int CUDA_NUM_THREADS = 16; 34 | 35 | // CUDA: number of blocks for threads. 36 | inline int GET_BLOCKS(const int64_t N, const int64_t max_threads_per_block = CUDA_NUM_THREADS) { 37 | TORCH_INTERNAL_ASSERT(N > 0, "CUDA kernel launch blocks must be positive, but got N=", N); 38 | constexpr int64_t max_int = std::numeric_limits::max(); 39 | 40 | // Round up division for positive number that cannot cause integer overflow 41 | auto block_num = (N - 1) / max_threads_per_block + 1; 42 | TORCH_INTERNAL_ASSERT(block_num <= max_int, "Can't schedule too many blocks on CUDA device"); 43 | 44 | return static_cast(block_num); 45 | } 46 | 47 | template 48 | C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS) 49 | __global__ void get_patches_forward_cuda_kernel(const int64_t n, 50 | const scalar_t* p_map, // Cx(H+2*radius)x(W+2*radius) 51 | const int64_t* p_points, // Nx2 52 | int64_t n_input_plane, int64_t input_height, int64_t input_width, int64_t n_points, 53 | int64_t pad_left_top, int64_t pad_right_bottom, int64_t kernel_size, 54 | scalar_t* p_patches // NxCxkernel_sizexkernel_size 55 | ) { 56 | CUDA_KERNEL_LOOP(index, n) { 57 | int64_t n_out = index % n_points; // point idx 58 | int64_t channel_idx = index / n_points; // channel idx 59 | 60 | int64_t w_in = *(p_points + 2 * n_out); 61 | int64_t h_in = *(p_points + 2 * n_out + 1); 62 | 63 | const scalar_t* im = p_map + (channel_idx * input_height + h_in) * input_width + w_in; 64 | scalar_t* dst_patches = p_patches + (n_out * n_input_plane + channel_idx) * kernel_size * kernel_size; 65 | 66 | // copy data 67 | for (int64_t i = 0; i < kernel_size; ++i) 68 | { 69 | for (int64_t j = 0; j < kernel_size; ++j) 70 | { 71 | int64_t h = h_in + i - pad_left_top; 72 | int64_t w = w_in + j - pad_left_top; 73 | 74 | *(dst_patches + i * kernel_size + j) = (h >= 0 && w >= 0 && h < input_height && w < input_width) 75 | ? 
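// This bounds check reproduces the effect of the explicit constant padding
// used by the CPU path and by get_patches_forward_cuda1 further down:
// coordinates falling outside the original HxW map contribute zeros to the
// patch, so this kernel can read directly from the unpadded input tensor.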
im[(i - pad_left_top) * input_width + j - pad_left_top] 76 | : static_cast(0); 77 | } 78 | } 79 | } 80 | } 81 | 82 | template 83 | __global__ void 84 | get_patches_forward_cuda_kernel1(const torch::PackedTensorAccessor32 map_pad, // Cx(H+2*radius)x(W+2*radius) 85 | const torch::PackedTensorAccessor32 points, // Nx2 86 | torch::PackedTensorAccessor32 patches, // NxCxkernel_sizexkernel_size 87 | int64_t kernel_size) { 88 | const int in = blockIdx.x * blockDim.x + threadIdx.x; 89 | const int N = points.size(0); 90 | const int C = map_pad.size(0); 91 | 92 | if (in < N) 93 | { 94 | long w_start = points[in][0]; 95 | long h_start = points[in][1]; 96 | 97 | // copy data 98 | for (long ic = 0; ic < C; ic++) 99 | { 100 | for (long ih = 0; ih < kernel_size; ih++) 101 | { 102 | for (long iw = 0; iw < kernel_size; iw++) 103 | { 104 | patches[in][ic][ih][iw] = map_pad[ic][h_start + ih][w_start + iw]; 105 | } 106 | } 107 | } 108 | } 109 | } 110 | 111 | template 112 | __global__ void 113 | get_patches_backward_cuda_kernel(torch::PackedTensorAccessor32 d_map_pad, // Cx(H+2*radius)x(W+2*radius) 114 | const torch::PackedTensorAccessor32 points, // Nx2 115 | const torch::PackedTensorAccessor32 d_patches, // NxCxkernel_sizexkernel_size 116 | int64_t kernel_size) { 117 | const int in = blockIdx.x * blockDim.x + threadIdx.x; 118 | const int N = points.size(0); 119 | const int C = d_map_pad.size(0); 120 | 121 | if (in < N) 122 | { 123 | long w_start = points[in][0]; 124 | long h_start = points[in][1]; 125 | 126 | // copy data 127 | for (long ic = 0; ic < C; ic++) 128 | { 129 | for (long ih = 0; ih < kernel_size; ih++) 130 | { 131 | for (long iw = 0; iw < kernel_size; iw++) 132 | { 133 | d_map_pad[ic][h_start + ih][w_start + iw] = d_patches[in][ic][ih][iw]; 134 | } 135 | } 136 | } 137 | } 138 | } 139 | 140 | torch::Tensor get_patches_forward_cuda(const torch::Tensor& input, torch::Tensor& points, int64_t kernel_size) { 141 | CHECK_INPUT(input); 142 | CHECK_INPUT(points); 143 | 144 | int64_t n_input_plane = input.size(0); 145 | int64_t input_height = input.size(1); 146 | int64_t input_width = input.size(2); 147 | // kernel_size=2, radius=0.5, pad_left_top=0, pad_right_bottom=1 148 | // kernel_size=3, radius=1.0, pad_left_top=1, pad_right_bottom=1 149 | // kernel_size=4, radius=1.5, pad_left_top=1, pad_right_bottom=2 150 | // kernel_size=5, radius=2.0, pad_left_top=2, pad_right_bottom=2 151 | auto radius = (kernel_size - 1.0) / 2.0; 152 | int64_t pad_left_top = floor(radius); 153 | int64_t pad_right_bottom = ceil(radius); 154 | int64_t n_points = points.size(0); 155 | 156 | // create output patches 157 | torch::Tensor patches = torch::zeros({n_points, n_input_plane, kernel_size, kernel_size}, input.options()); 158 | 159 | // cuda kernel 160 | int64_t num_kernels = n_input_plane * n_points; 161 | auto stream = at::cuda::getCurrentCUDAStream(); 162 | AT_DISPATCH_FLOATING_TYPES(input.type(), "get_patches_forward_cuda", 163 | ( 164 | [&] { 165 | get_patches_forward_cuda_kernel<<>>( 166 | num_kernels, input.data_ptr(), points.data_ptr(), n_input_plane, input_height, 167 | input_width, n_points, pad_left_top, pad_right_bottom, kernel_size, patches.data_ptr()); 168 | })); 169 | 170 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 171 | 172 | return patches; 173 | } 174 | 175 | torch::Tensor get_patches_forward_cuda1(const torch::Tensor& map, torch::Tensor& points, int64_t kernel_size) { 176 | CHECK_INPUT(map); 177 | CHECK_INPUT(points); 178 | 179 | auto N = points.size(0); 180 | auto C = map.size(0); 181 | // kernel_size=2, radius=0.5, 
pad_left_top=0, pad_right_bottom=1 182 | // kernel_size=3, radius=1.0, pad_left_top=1, pad_right_bottom=1 183 | // kernel_size=4, radius=1.5, pad_left_top=1, pad_right_bottom=2 184 | // kernel_size=5, radius=2.0, pad_left_top=2, pad_right_bottom=2 185 | auto radius = (kernel_size - 1.0) / 2.0; 186 | int pad_left_top = floor(radius); 187 | int pad_right_bottom = ceil(radius); 188 | 189 | // pad map 190 | auto options = F::PadFuncOptions({pad_left_top, pad_right_bottom, pad_left_top, pad_right_bottom}).mode(torch::kConstant); 191 | auto map_pad = F::pad(map.unsqueeze(0), options).squeeze(0); // Cx(H+2*radius)x(W+2*radius) 192 | 193 | // create patches 194 | torch::Tensor patches = torch::empty({N, C, kernel_size, kernel_size}, map.options()); 195 | 196 | // cuda kernel 197 | const int threads = CUDA_NUM_THREADS; 198 | const int blocks = (N + threads - 1) / threads; 199 | AT_DISPATCH_FLOATING_TYPES(map_pad.type(), "get_patches_forward_cuda", 200 | ( 201 | [&] { 202 | get_patches_forward_cuda_kernel1 203 | <<>>(map_pad.packed_accessor32(), 204 | points.packed_accessor32(), 205 | patches.packed_accessor32(), kernel_size); 206 | })); 207 | 208 | // get error 209 | cudaDeviceSynchronize(); 210 | cudaError_t cudaerr = cudaGetLastError(); 211 | if (cudaerr != cudaSuccess) 212 | printf("kernel launch failed with error \"%s\".\n", cudaGetErrorString(cudaerr)); 213 | 214 | return patches; 215 | } 216 | 217 | torch::Tensor get_patches_backward_cuda(const torch::Tensor& d_patches, torch::Tensor& points, int64_t H, int64_t W) { 218 | CHECK_INPUT(d_patches); 219 | CHECK_INPUT(points); 220 | 221 | auto N = d_patches.size(0); 222 | auto C = d_patches.size(1); 223 | // kernel_size=2, radius=0.5, pad_left_top=0, pad_right_bottom=1 224 | // kernel_size=3, radius=1.0, pad_left_top=1, pad_right_bottom=1 225 | // kernel_size=4, radius=1.5, pad_left_top=1, pad_right_bottom=2 226 | // kernel_size=5, radius=2.0, pad_left_top=2, pad_right_bottom=2 227 | auto kernel_size = d_patches.size(2); 228 | auto radius = (kernel_size - 1.0) / 2.0; 229 | int pad_left_top = floor(radius); 230 | int pad_right_bottom = ceil(radius); 231 | 232 | torch::Tensor d_map_pad = torch::zeros({C, H + int(2 * radius), W + int(2 * radius)}, d_patches.options()); 233 | 234 | // cuda kernel 235 | const int threads = CUDA_NUM_THREADS; 236 | const int blocks = (N + threads - 1) / threads; 237 | AT_DISPATCH_FLOATING_TYPES(d_map_pad.type(), "get_patches_backward_cuda", 238 | ( 239 | [&] { 240 | get_patches_backward_cuda_kernel 241 | <<>>(d_map_pad.packed_accessor32(), 242 | points.packed_accessor32(), 243 | d_patches.packed_accessor32(), kernel_size); 244 | })); 245 | 246 | // get error 247 | cudaDeviceSynchronize(); 248 | cudaError_t cudaerr = cudaGetLastError(); 249 | if (cudaerr != cudaSuccess) 250 | printf("kernel launch failed with error \"%s\".\n", cudaGetErrorString(cudaerr)); 251 | 252 | using namespace torch::indexing; 253 | auto d_map = d_map_pad.index({Slice(), Slice(pad_left_top, -pad_right_bottom), Slice(pad_left_top, -pad_right_bottom)}); 254 | 255 | return d_map; 256 | } 257 | -------------------------------------------------------------------------------- /src/feature/input_padder.cpp: -------------------------------------------------------------------------------- 1 | #include "feature/input_padder.hpp" 2 | 3 | torch::Tensor InputPadder::pad(torch::Tensor x) && { 4 | return torch::nn::functional::pad( 5 | x, 6 | torch::nn::functional::PadFuncOptions({pad_[0], pad_[1], pad_[2], pad_[3]}) 7 | .mode(torch::kReplicate)); 8 | } 9 | 10 | 
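// InputPadder replicate-pads by pad_[0..3] (left, right, top, bottom),
// typically so the feature extractor sees dimensions it can downsample
// cleanly, and unpad() below slices the same amounts back off. A minimal,
// hypothetical usage sketch (the constructor lives in input_padder.hpp and
// its exact signature is not shown here):
//   InputPadder padder(/*height*/ 479, /*width*/ 639);  // hypothetical ctor call
//   auto padded  = padder.pad(image);                   // pads H by pad_[2]+pad_[3], W by pad_[0]+pad_[1]
//   auto cropped = std::move(padder).unpad(output);     // rvalue overload restores the original size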
torch::Tensor InputPadder::pad(const torch::Tensor& x) & { 11 | return torch::nn::functional::pad( 12 | x, 13 | torch::nn::functional::PadFuncOptions({pad_[0], pad_[1], pad_[2], pad_[3]}) 14 | .mode(torch::kReplicate)); 15 | } 16 | 17 | [[maybe_unused]] torch::Tensor InputPadder::unpad(torch::Tensor x) && { 18 | int h = x.size(-2); 19 | int w = x.size(-1); 20 | return std::move(x).index({torch::indexing::Slice(), 21 | torch::indexing::Slice(), 22 | torch::indexing::Slice(pad_[2], h - pad_[3]), 23 | torch::indexing::Slice(pad_[0], w - pad_[1])}); 24 | } -------------------------------------------------------------------------------- /src/matcher/lightglue/attention.cpp: -------------------------------------------------------------------------------- 1 | #include "matcher/lightglue/attention.hpp" 2 | #include 3 | 4 | namespace matcher { 5 | 6 | SelfBlock::SelfBlock(int embed_dim, int num_heads, bool flash, bool bias) 7 | : embed_dim_(embed_dim), 8 | num_heads_(num_heads), 9 | head_dim_(embed_dim / num_heads), 10 | Wqkv_(torch::nn::Linear(torch::nn::LinearOptions(embed_dim, 3 * embed_dim).bias(bias))), 11 | inner_attn_(std::make_shared(flash)), 12 | out_proj_(torch::nn::Linear(torch::nn::LinearOptions(embed_dim, embed_dim).bias(bias))), 13 | ffn_(torch::nn::Sequential( 14 | torch::nn::Linear(torch::nn::LinearOptions(2 * embed_dim, 2 * embed_dim).bias(bias)), 15 | torch::nn::LayerNorm(torch::nn::LayerNormOptions({2 * embed_dim}).elementwise_affine(true)), 16 | torch::nn::GELU(), 17 | torch::nn::Linear(torch::nn::LinearOptions(2 * embed_dim, embed_dim).bias(bias)))) { 18 | register_module("Wqkv", Wqkv_); 19 | register_module("out_proj", out_proj_); 20 | register_module("ffn", ffn_); 21 | } 22 | 23 | torch::Tensor SelfBlock::rotate_half(const torch::Tensor& x) { 24 | auto x_split = x.unflatten(-1, {-1, 2}); 25 | auto x1 = x_split.select(-1, 0); 26 | auto x2 = x_split.select(-1, 1); 27 | return torch::stack({-x2, x1}, -1).flatten(-2); 28 | } 29 | 30 | torch::Tensor SelfBlock::apply_cached_rotary_emb( 31 | const torch::Tensor& freqs, 32 | const torch::Tensor& t) { 33 | 34 | return (t * freqs.select(0, 0)) + 35 | (rotate_half(t) * freqs.select(0, 1)); 36 | } 37 | 38 | torch::Tensor SelfBlock::forward( 39 | const torch::Tensor& x, 40 | const torch::Tensor& encoding) { 41 | 42 | // Project to QKV 43 | auto qkv = Wqkv_->forward(x); 44 | qkv = qkv.unflatten(-1, {num_heads_, -1, 3}).transpose(1, 2); 45 | 46 | // Split into query, key, value 47 | auto q = qkv.select(-1, 0); 48 | auto k = qkv.select(-1, 1); 49 | auto v = qkv.select(-1, 2); 50 | 51 | // Apply rotary embeddings 52 | q = apply_cached_rotary_emb(encoding, q); 53 | k = apply_cached_rotary_emb(encoding, k); 54 | 55 | // Apply attention 56 | auto context = inner_attn_->forward(q, k, v); 57 | 58 | // Project output and apply residual connection 59 | auto message = out_proj_->forward( 60 | context.transpose(1, 2).flatten(/*start_dim=*/-2)); 61 | 62 | // Combine with input using ffn 63 | return x + ffn_->forward(torch::cat({x, message}, -1)); 64 | } 65 | 66 | CrossBlock::CrossBlock(int embed_dim, int num_heads, bool flash, bool bias) 67 | : heads_(num_heads), 68 | scale_(1.0f / sqrt(embed_dim / num_heads)) { 69 | 70 | auto dim_head = embed_dim / num_heads; 71 | auto inner_dim = dim_head * num_heads; 72 | 73 | // Initialize projections using LinearOptions 74 | to_qk_ = register_module( 75 | "to_qk", torch::nn::Linear(torch::nn::LinearOptions(embed_dim, inner_dim).bias(bias))); 76 | to_v_ = register_module( 77 | "to_v", 
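// Note on SelfBlock above: apply_cached_rotary_emb is the standard rotary
// position embedding. encoding[0] carries the cosine terms and encoding[1]
// the sine terms (see LearnableFourierPosEnc), each repeated once per 2-D
// channel pair, so for a pair (t1, t2) at angle theta the result is
//   (t1*cos(theta) - t2*sin(theta), t2*cos(theta) + t1*sin(theta)),
// i.e. a plane rotation of the query/key channels by a keypoint-dependent
// angle.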
torch::nn::Linear(torch::nn::LinearOptions(embed_dim, inner_dim).bias(bias))); 78 | to_out_ = register_module( 79 | "to_out", torch::nn::Linear(torch::nn::LinearOptions(inner_dim, embed_dim).bias(bias))); 80 | 81 | // Initialize feed-forward network 82 | auto ffn = torch::nn::Sequential( 83 | torch::nn::Linear(torch::nn::LinearOptions(2 * embed_dim, 2 * embed_dim).bias(true)), 84 | torch::nn::LayerNorm(torch::nn::LayerNormOptions({2 * embed_dim}).elementwise_affine(true)), 85 | torch::nn::GELU(), 86 | torch::nn::Linear(torch::nn::LinearOptions(2 * embed_dim, embed_dim).bias(true))); 87 | 88 | ffn_ = register_module("ffn", ffn); 89 | 90 | // Initialize flash attention if requested 91 | if (flash && torch::cuda::is_available()) 92 | { 93 | flash_ = register_module("flash", std::make_shared(true)); 94 | } 95 | } 96 | 97 | std::tuple CrossBlock::forward( 98 | const torch::Tensor& x0, 99 | const torch::Tensor& x1, 100 | const torch::optional& mask) { 101 | 102 | // Project inputs 103 | auto qk0 = to_qk_->forward(x0); 104 | auto qk1 = to_qk_->forward(x1); 105 | auto v0 = to_v_->forward(x0); 106 | auto v1 = to_v_->forward(x1); 107 | 108 | // Reshape for attention 109 | auto reshape_for_attention = [this](torch::Tensor t) { 110 | return t.unflatten(-1, {heads_, -1}).transpose(1, 2); 111 | }; 112 | 113 | qk0 = reshape_for_attention(qk0); 114 | qk1 = reshape_for_attention(qk1); 115 | v0 = reshape_for_attention(v0); 116 | v1 = reshape_for_attention(v1); 117 | 118 | torch::Tensor m0, m1; 119 | 120 | if (flash_ && x0.device().is_cuda()) 121 | { 122 | // Use flash attention 123 | m0 = flash_->forward(qk0, qk1, v1); 124 | m1 = flash_->forward(qk1, qk0, v0); 125 | } else 126 | { 127 | // Manual attention computation 128 | qk0 = qk0 * sqrt(scale_); 129 | qk1 = qk1 * sqrt(scale_); 130 | 131 | auto sim = torch::einsum("bhid,bhjd->bhij", {qk0, qk1}); 132 | 133 | if (mask.has_value()) 134 | { 135 | sim.masked_fill_(~mask.value(), -INFINITY); 136 | } 137 | 138 | auto attn01 = torch::softmax(sim, -1); 139 | auto attn10 = torch::softmax(sim.transpose(-2, -1).contiguous(), -1); 140 | 141 | m0 = torch::einsum("bhij,bhjd->bhid", {attn01, v1}); 142 | m1 = torch::einsum("bhji,bhjd->bhid", 143 | {attn10.transpose(-2, -1), v0}); 144 | 145 | if (mask.has_value()) 146 | { 147 | m0 = m0.nan_to_num(); 148 | m1 = m1.nan_to_num(); 149 | } 150 | } 151 | 152 | // Project back to original dimensions 153 | auto project_out = [this](torch::Tensor t) { 154 | return to_out_->forward(t.transpose(1, 2).flatten(/*start_dim=*/-2)); 155 | }; 156 | 157 | m0 = project_out(m0); 158 | m1 = project_out(m1); 159 | 160 | // Apply FFN with residual connections 161 | auto out0 = x0 + ffn_->forward(torch::cat({x0, m0}, -1)); 162 | auto out1 = x1 + ffn_->forward(torch::cat({x1, m1}, -1)); 163 | 164 | return std::make_tuple(out0, out1); 165 | } 166 | 167 | TokenConfidence::TokenConfidence(int dim) { 168 | // Build sequential module for token confidence 169 | torch::nn::Sequential token; 170 | token->push_back(torch::nn::Linear(dim, 1)); 171 | token->push_back(torch::nn::Sigmoid()); 172 | 173 | token_ = register_module("token", token); 174 | } 175 | 176 | std::tuple TokenConfidence::forward( 177 | const torch::Tensor& desc0, 178 | const torch::Tensor& desc1) { 179 | 180 | return std::make_tuple( 181 | token_->forward(desc0.detach()).squeeze(-1), 182 | token_->forward(desc1.detach()).squeeze(-1)); 183 | } 184 | 185 | Attention::Attention(bool allow_flash) { 186 | // TODO: fix this 187 | // enable_flash_ = allow_flash && FLASH_AVAILABLE; 188 | // 
has_sdp_ = torch::cuda::is_available() && 189 | // torch::cuda::is_available(); // && 190 | // torch::().major >= 8; 191 | enable_flash_ = false; 192 | // if (enable_flash_) { 193 | // torch::cuda::set_device(torch::cuda::current_device()); 194 | // } 195 | } 196 | 197 | torch::Tensor Attention::forward( 198 | const torch::Tensor& q, 199 | const torch::Tensor& k, 200 | const torch::Tensor& v) { 201 | 202 | // Handle empty tensors 203 | if (q.size(-2) == 0 || k.size(-2) == 0) 204 | { 205 | return q.new_zeros({*q.sizes().begin(), q.size(-2), v.size(-1)}); 206 | } 207 | 208 | // Use scaled dot-product attention if available 209 | if (enable_flash_ && q.device().is_cuda()) 210 | { 211 | if (has_sdp_) 212 | { 213 | auto args_q = q.to(torch::kHalf).contiguous(); 214 | auto args_k = k.to(torch::kHalf).contiguous(); 215 | auto args_v = v.to(torch::kHalf).contiguous(); 216 | 217 | auto result = torch::scaled_dot_product_attention( 218 | args_q, args_k, args_v); 219 | 220 | result = result.to(q.dtype()); 221 | return result; 222 | } 223 | } 224 | 225 | // Fall back to manual implementation 226 | const auto scale = 1.f / sqrt(q.size(-1)); 227 | auto sim = torch::einsum("...id,...jd->...ij", {q, k}) * scale; 228 | auto attn = torch::softmax(sim, -1); 229 | return torch::einsum("...ij,...jd->...id", {attn, v}); 230 | } 231 | } -------------------------------------------------------------------------------- /src/matcher/lightglue/core.cpp: -------------------------------------------------------------------------------- 1 | #include "matcher/lightglue/core.hpp" 2 | 3 | namespace matcher::utils { 4 | torch::Tensor normalize_keypoints( 5 | const torch::Tensor& kpts, 6 | const torch::optional& size) { 7 | 8 | torch::Tensor size_tensor; 9 | if (!size.has_value()) 10 | { 11 | // Compute the size as the range of keypoints 12 | size_tensor = 1 + std::get<0>(torch::max(kpts, /*dim=*/-2)) - std::get<0>(torch::min(kpts, /*dim=*/-2)); 13 | } else 14 | { 15 | // If size is provided but not a tensor, convert it to a tensor 16 | size_tensor = size.value().to(kpts); 17 | } 18 | 19 | // Compute shift and scale 20 | auto shift = size_tensor / 2; 21 | auto scale = std::get<0>(size_tensor.max(-1)) / 2; 22 | 23 | return (kpts - shift.unsqueeze(-2)) / scale.unsqueeze(-1).unsqueeze(-1); 24 | } 25 | 26 | 27 | std::tuple 28 | filter_matches(const torch::Tensor& scores, float threshold) { 29 | int64_t M = scores.size(1) - 1; // 1708 30 | int64_t N = scores.size(2) - 1; // 1519 31 | 32 | auto scores_slice = scores.slice(1, 0, -1).slice(2, 0, -1); 33 | 34 | // Get max values and indices 35 | auto max0 = scores_slice.max(2); 36 | auto max1 = scores_slice.max(1); 37 | 38 | auto m0 = std::get<1>(max0); // shape: [1, M] 39 | auto max0_values = std::get<0>(max0); // shape: [1, M] 40 | auto m1 = std::get<1>(max1); // shape: [1, N] 41 | 42 | // Create index tensors with correct shape 43 | auto indices0 = torch::arange(M, m0.options()).unsqueeze(0); 44 | auto indices1 = torch::arange(N, m1.options()).unsqueeze(0); 45 | 46 | // Ensure all tensors are properly shaped before operations 47 | m0 = m0.view({1, M}); 48 | m1 = m1.view({1, N}); 49 | indices0 = indices0.view({1, M}); 50 | indices1 = indices1.view({1, N}); 51 | 52 | // Calculate mutual matches 53 | auto mutual0 = indices0 == m1.index_select(1, m0.squeeze()).view({1, M}); 54 | auto mutual1 = indices1 == m0.index_select(1, m1.squeeze()).view({1, N}); 55 | 56 | // Calculate scores 57 | auto max0_exp = max0_values.exp(); 58 | auto zero0 = torch::zeros({1, M}, max0_exp.options()); 59 | 
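// filter_matches keeps mutual nearest neighbours whose exponentiated
// log-score clears the threshold: m0[i] is keypoint i's best column, m1[j]
// the best row, and mutual0[i] requires m1[m0[i]] == i. A tiny example with
// two rows and three columns of scores:
//   row 0 -> best col 2, row 1 -> best col 0
//   col 0 -> best row 1, col 1 -> best row 0, col 2 -> best row 0
// gives mutual matches (0, 2) and (1, 0); column 1 has no mutual partner and
// its entry is set to -1 further down.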
auto zero1 = torch::zeros({1, N}, max0_exp.options()); 60 | auto mscores0 = torch::where(mutual0, max0_exp, zero0); 61 | 62 | // Ensure proper shapes for score calculation 63 | auto mscores0_expanded = mscores0.index_select(1, m1.squeeze()); 64 | auto mscores1 = torch::where(mutual1, mscores0_expanded.view({1, N}), zero1); 65 | 66 | // Calculate valid matches 67 | auto valid0 = mutual0 & (mscores0 > threshold); 68 | auto valid1 = mutual1 & (mscores1 > threshold); 69 | 70 | // Create output tensors with correct shape 71 | auto m0_valid = torch::where(valid0, m0, torch::full({1, M}, -1, m0.options())); 72 | auto m1_valid = torch::where(valid1, m1, torch::full({1, N}, -1, m1.options())); 73 | 74 | return std::make_tuple(m0_valid, m1_valid, mscores0, mscores1); 75 | } 76 | } -------------------------------------------------------------------------------- /src/matcher/lightglue/encoding.cpp: -------------------------------------------------------------------------------- 1 | #include "matcher/lightglue/encoding.hpp" 2 | 3 | namespace matcher { 4 | LearnableFourierPosEnc::LearnableFourierPosEnc( 5 | int M, int dim, torch::optional F_dim, float gamma) 6 | : gamma_(gamma) { 7 | 8 | int f_dim = F_dim.value_or(dim); 9 | // Initialize Wr with normal distribution 10 | Wr_ = register_module("Wr", 11 | torch::nn::Linear(torch::nn::LinearOptions(M, f_dim / 2).bias(false))); 12 | 13 | // Initialize weights according to the paper 14 | auto std = gamma_ * gamma_; 15 | torch::nn::init::normal_(Wr_->weight, 0.0, std); 16 | } 17 | 18 | torch::Tensor LearnableFourierPosEnc::forward(const torch::Tensor& x) { 19 | // Project and compute trig functions 20 | auto projected = Wr_->forward(x); 21 | auto cosines = torch::cos(projected); 22 | auto sines = torch::sin(projected); 23 | 24 | // Stack and reshape 25 | auto emb = torch::stack({cosines, sines}, 0).unsqueeze(-3); 26 | return emb.repeat_interleave(2, -1); 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/matcher/lightglue/matcher.cpp: -------------------------------------------------------------------------------- 1 | #include "matcher/lightglue/matcher.hpp" 2 | #include "matcher/lightglue/encoding.hpp" 3 | #include "matcher/lightglue/attention.hpp" 4 | #include "matcher/lightglue/transformer.hpp" 5 | 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | namespace { 14 | std::string map_python_to_cpp(const std::string& python_name) { 15 | std::string cpp_name = python_name; 16 | 17 | size_t pos_transformer = cpp_name.find("transformers"); 18 | size_t pos_assignment = cpp_name.find("log_assignment"); 19 | size_t pos_confidence = cpp_name.find("token_confidence"); 20 | 21 | size_t pos = std::min({pos_transformer, pos_assignment, pos_confidence}); 22 | // Replace "." 
with "" 23 | size_t dot_pos = cpp_name.find_first_of("0123456789", pos); 24 | if (dot_pos != std::string::npos && cpp_name[dot_pos - 1] == '.') 25 | { 26 | cpp_name.erase(dot_pos - 1, 1); // Remove the dot before the number 27 | } 28 | 29 | return cpp_name; 30 | } 31 | } 32 | 33 | namespace matcher { 34 | MatchAssignment::MatchAssignment(int dim) 35 | : dim_(dim), 36 | matchability_(torch::nn::Linear(torch::nn::LinearOptions(dim_, 1).bias(true))), // Adjust the dimensions as needed 37 | final_proj_(torch::nn::LinearOptions(dim_, dim_).bias(true)) // Adjust the dimensions as needed 38 | { 39 | register_module("matchability", matchability_); 40 | register_module("final_proj", final_proj_); 41 | } 42 | 43 | torch::Tensor MatchAssignment::forward( 44 | const torch::Tensor& desc0, 45 | const torch::Tensor& desc1) { 46 | 47 | // Project descriptors 48 | auto mdesc0 = final_proj_->forward(desc0); 49 | auto mdesc1 = final_proj_->forward(desc1); 50 | 51 | // Scale by dimension 52 | auto d = mdesc0.size(-1); 53 | auto scale = 1.0f / std::pow(d, 0.25f); 54 | mdesc0 = mdesc0 * scale; 55 | mdesc1 = mdesc1 * scale; 56 | auto sim = torch::einsum("bmd,bnd->bmn", {mdesc0, mdesc1}); 57 | auto z0 = matchability_->forward(desc0); 58 | auto z1 = matchability_->forward(desc1); 59 | auto scores = sigmoid_log_double_softmax(sim, z0, z1); 60 | return scores; 61 | } 62 | 63 | torch::Tensor MatchAssignment::get_matchability(const torch::Tensor& desc) { 64 | // Debug input tensor 65 | auto weight = matchability_->weight.data(); 66 | auto bias = matchability_->bias.data(); 67 | auto result = torch::sigmoid(matchability_->forward(desc)).squeeze(-1); 68 | return result; 69 | } 70 | 71 | torch::Tensor MatchAssignment::sigmoid_log_double_softmax( 72 | const torch::Tensor& sim, 73 | const torch::Tensor& z0, 74 | const torch::Tensor& z1) { 75 | 76 | auto batch_size = sim.size(0); 77 | auto m = sim.size(1); 78 | auto n = sim.size(2); 79 | 80 | auto certainties = torch::log_sigmoid(z0) + 81 | torch::log_sigmoid(z1).transpose(1, 2); 82 | auto scores0 = torch::log_softmax(sim, 2); 83 | auto scores1 = torch::log_softmax( 84 | sim.transpose(-1, -2).contiguous(), 2) 85 | .transpose(-1, -2); 86 | 87 | auto scores = torch::full( 88 | {batch_size, m + 1, n + 1}, 0.0f, 89 | torch::TensorOptions().device(sim.device()).dtype(sim.dtype())); 90 | 91 | scores.index_put_( 92 | {torch::indexing::Slice(), 93 | torch::indexing::Slice(torch::indexing::None, m), 94 | torch::indexing::Slice(torch::indexing::None, n)}, 95 | scores0 + scores1 + certainties); 96 | 97 | scores.index_put_( 98 | {torch::indexing::Slice(), 99 | torch::indexing::Slice(torch::indexing::None, -1), 100 | n}, 101 | torch::log_sigmoid(-z0.squeeze(-1))); 102 | 103 | scores.index_put_( 104 | {torch::indexing::Slice(), 105 | m, 106 | torch::indexing::Slice(torch::indexing::None, -1)}, 107 | torch::log_sigmoid(-z1.squeeze(-1))); 108 | 109 | return scores; 110 | } 111 | // Static member initialization 112 | const std::unordered_map LightGlue::pruning_keypoint_thresholds_ = { 113 | {"cpu", -1}, 114 | {"mps", -1}, 115 | {"cuda", 1024}, 116 | {"flash", 1536}}; 117 | 118 | // Feature configurations 119 | static const std::unordered_map> FEATURES = { 120 | {"aliked", {"aliked_lightglue", 128}}}; 121 | 122 | 123 | LightGlue::LightGlue(const std::string& feature_type, const LightGlueConfig& config) 124 | : config_(config), 125 | device_(torch::kCPU) { 126 | 127 | // Configure based on feature type 128 | auto it = FEATURES.find(feature_type); 129 | if (it == FEATURES.end()) 130 | { 131 | 
throw std::runtime_error("Unsupported feature type: " + feature_type); 132 | } 133 | 134 | config_.weights = it->second.first; 135 | config_.input_dim = it->second.second; 136 | 137 | // Initialize input projection if needed 138 | if (config_.input_dim != config_.descriptor_dim) 139 | { 140 | input_proj_ = register_module("input_proj", 141 | torch::nn::Linear(config_.input_dim, config_.descriptor_dim)); 142 | } 143 | 144 | // Initialize positional encoding 145 | posenc_ = register_module("posenc", 146 | std::make_shared( 147 | 2 + 2 * config_.add_scale_ori, 148 | config_.descriptor_dim / config_.num_heads, 149 | config_.descriptor_dim / config_.num_heads)); 150 | 151 | // Initialize transformer layers 152 | for (int i = 0; i < config_.n_layers; ++i) 153 | { 154 | auto layer = std::make_shared( 155 | config_.descriptor_dim, 156 | config_.num_heads, 157 | config_.flash); 158 | 159 | transformers_.push_back(layer); 160 | register_module("transformers" + std::to_string(i), layer); 161 | } 162 | 163 | // Initialize assignment and token confidence layers 164 | for (int i = 0; i < config_.n_layers; ++i) 165 | { 166 | auto assign = std::make_shared(config_.descriptor_dim); 167 | log_assignment_.push_back(assign); 168 | register_module("log_assignment" + std::to_string(i), assign); 169 | 170 | if (i < config_.n_layers - 1) 171 | { 172 | auto conf = std::make_shared(config_.descriptor_dim); 173 | token_confidence_.push_back(conf); 174 | register_module("token_confidence" + std::to_string(i), conf); 175 | } 176 | } 177 | 178 | // Register confidence thresholds buffer 179 | confidence_thresholds_.reserve(config_.n_layers); 180 | 181 | auto confidence_threshold = [](int layer_index, int n_layers) -> float { 182 | float progress = static_cast(layer_index) / n_layers; 183 | float threshold = 0.8f + 0.1f * std::exp(-4.0f * progress); 184 | return std::clamp(threshold, 0.0f, 1.0f); 185 | }; 186 | 187 | for (int i = 0; i < config_.n_layers; ++i) 188 | { 189 | confidence_thresholds_.push_back(confidence_threshold(i, config.n_layers)); 190 | } 191 | 192 | // Load weights if specified 193 | if (!config_.weights.empty()) 194 | { 195 | load_weights(config_.weights); 196 | } 197 | 198 | // Move to device if CUDA is available 199 | if (torch::cuda::is_available()) 200 | { 201 | device_ = torch::kCUDA; 202 | this->to(device_); 203 | } 204 | } 205 | 206 | void LightGlue::to(const torch::Device& device) { 207 | device_ = device; 208 | torch::nn::Module::to(device); 209 | } 210 | 211 | 212 | torch::Tensor LightGlue::get_pruning_mask( 213 | const torch::optional& confidences, 214 | const torch::Tensor& scores, 215 | int layer_index) { 216 | 217 | // Initialize keep mask based on scores 218 | auto keep = scores > (1.0f - config_.width_confidence); 219 | 220 | // Include low-confidence points if confidences are provided 221 | if (confidences.has_value()) 222 | { 223 | keep = keep | (confidences.value() <= confidence_thresholds_[layer_index]); 224 | } 225 | 226 | return keep; 227 | } 228 | 229 | bool LightGlue::check_if_stop( 230 | const torch::Tensor& confidences0, 231 | const torch::Tensor& confidences1, 232 | int layer_index, 233 | int num_points) { 234 | 235 | auto confidences = torch::cat({confidences0, confidences1}, -1); 236 | auto threshold = confidence_thresholds_[layer_index]; 237 | auto ratio_confident = 1.0f - 238 | (confidences < threshold).to(torch::kFloat32).sum().item() / num_points; 239 | return ratio_confident > config_.depth_confidence; 240 | } 241 | 242 | torch::Dict LightGlue::forward( 243 | const 
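// The per-layer confidence thresholds built above follow
//   threshold(i) = clamp(0.8 + 0.1 * exp(-4 * i / n_layers), 0, 1),
// so layer 0 uses 0.9 and later layers relax towards roughly 0.8 (for
// n_layers = 9, layer 4 gives 0.8 + 0.1 * exp(-16/9) ~ 0.817). check_if_stop
// above then ends the transformer loop early once the fraction of points
// whose token confidence reaches the current threshold exceeds
// depth_confidence.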
torch::Dict& data0, 244 | const torch::Dict& data1) { 245 | 246 | // Extract keypoints and descriptors 247 | // TODO: Batching 248 | const auto& kpts0_ref = data0.at("keypoints"); 249 | const auto& kpts1_ref = data1.at("keypoints"); 250 | const auto& desc0_ref = data0.at("descriptors"); 251 | const auto& desc1_ref = data1.at("descriptors"); 252 | 253 | // Single operation instead of multiple calls 254 | auto kpts0 = kpts0_ref.detach().contiguous().unsqueeze(0); 255 | auto kpts1 = kpts1_ref.detach().contiguous().unsqueeze(0); 256 | auto desc0 = desc0_ref.detach().contiguous().unsqueeze(0); 257 | auto desc1 = desc1_ref.detach().contiguous().unsqueeze(0); 258 | 259 | // Pre-calculate sizes once 260 | const int64_t b = kpts0.size(0); 261 | const int64_t m = kpts0.size(1); 262 | const int64_t n = kpts1.size(1); 263 | 264 | // Get image sizes if available 265 | torch::optional size0, size1; 266 | if (data0.contains("image_size")) 267 | size0 = data0.at("image_size"); 268 | if (data1.contains("image_size")) 269 | size1 = data1.at("image_size"); 270 | 271 | // Normalize keypoints 272 | kpts0 = matcher::utils::normalize_keypoints(kpts0, size0).clone(); 273 | kpts1 = matcher::utils::normalize_keypoints(kpts1, size1).clone(); 274 | 275 | // Add scale and orientation if configured 276 | if (config_.add_scale_ori) 277 | { 278 | kpts0 = torch::cat({kpts0, 279 | data0.at("scales").unsqueeze(-1), 280 | data0.at("oris").unsqueeze(-1)}, 281 | -1); 282 | kpts1 = torch::cat({kpts1, 283 | data1.at("scales").unsqueeze(-1), 284 | data1.at("oris").unsqueeze(-1)}, 285 | -1); 286 | } 287 | 288 | // Convert to fp16 if mixed precision is enabled 289 | if (config_.mp && device_.is_cuda()) 290 | { 291 | desc0 = desc0.to(torch::kHalf); 292 | desc1 = desc1.to(torch::kHalf); 293 | } 294 | 295 | // Project descriptors if needed 296 | if (config_.input_dim != config_.descriptor_dim) 297 | { 298 | desc0 = input_proj_->forward(desc0); 299 | desc1 = input_proj_->forward(desc1); 300 | } 301 | 302 | // Generate positional encodings 303 | auto encoding0 = posenc_->forward(kpts0); 304 | auto encoding1 = posenc_->forward(kpts1); 305 | 306 | // Initialize pruning if enabled 307 | const bool do_early_stop = config_.depth_confidence > 0.f; 308 | const bool do_point_pruning = config_.width_confidence > 0.f; 309 | const auto pruning_th = pruning_keypoint_thresholds_.at( 310 | config_.flash ? "flash" : device_.is_cuda() ? 
"cuda" 311 | : "cpu"); 312 | 313 | torch::Tensor ind0, ind1, prune0, prune1; 314 | if (do_point_pruning) 315 | { 316 | ind0 = torch::arange(m, torch::TensorOptions().device(device_)).unsqueeze(0); 317 | ind1 = torch::arange(n, torch::TensorOptions().device(device_)).unsqueeze(0); 318 | prune0 = torch::ones_like(ind0); 319 | prune1 = torch::ones_like(ind1); 320 | } 321 | 322 | // Process through transformer layers 323 | torch::optional token0, token1; 324 | int last_layer; 325 | for (int i = 0; i < config_.n_layers; ++i) 326 | { 327 | last_layer = i; 328 | if (desc0.size(1) == 0 || desc1.size(1) == 0) 329 | break; 330 | 331 | std::tie(desc0, desc1) = transformers_[i]->forward( 332 | desc0, desc1, encoding0, encoding1); 333 | 334 | if (i == config_.n_layers - 1) 335 | continue; 336 | 337 | // Early stopping check 338 | if (do_early_stop) 339 | { 340 | std::tie(token0, token1) = token_confidence_[i]->forward(desc0, desc1); 341 | 342 | if (check_if_stop( 343 | token0.value().index({torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, m)}), 344 | token1.value().index({torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, n)}), 345 | i, m + n)) 346 | { 347 | break; 348 | } 349 | } 350 | 351 | if (do_point_pruning && desc0.size(-2) > pruning_th) 352 | { 353 | auto scores0 = log_assignment_[i]->get_matchability(desc0); 354 | auto prunemask0 = get_pruning_mask(token0, scores0, i); 355 | 356 | if (prunemask0.dtype() != torch::kBool) 357 | { 358 | prunemask0 = prunemask0.to(torch::kBool); 359 | } 360 | 361 | auto where_result = torch::where(prunemask0); 362 | auto keep0 = where_result[1]; 363 | 364 | if (keep0.numel() > 0) 365 | { 366 | ind0 = ind0.index_select(1, keep0); 367 | desc0 = desc0.index_select(1, keep0); 368 | encoding0 = encoding0.index_select(-2, keep0); 369 | prune0.index_put_({torch::indexing::Slice(), ind0}, prune0.index({torch::indexing::Slice(), ind0}) + 1); 370 | } else 371 | { 372 | std::cout << "No points kept after pruning for desc0." << std::endl; 373 | } 374 | } 375 | 376 | if (do_point_pruning && desc1.size(-2) > pruning_th) 377 | { 378 | auto scores1 = log_assignment_[i]->get_matchability(desc1); 379 | auto prunemask1 = get_pruning_mask(token1, scores1, i); 380 | if (prunemask1.dtype() != torch::kBool) 381 | { 382 | prunemask1 = prunemask1.to(torch::kBool); 383 | } 384 | 385 | auto where_result = torch::where(prunemask1); 386 | auto keep1 = where_result[1]; 387 | if (keep1.numel() > 0) 388 | { 389 | ind1 = ind1.index_select(1, keep1); 390 | desc1 = desc1.index_select(1, keep1); 391 | encoding1 = encoding1.index_select(-2, keep1); 392 | prune1.index_put_({torch::indexing::Slice(), ind1}, prune1.index({torch::indexing::Slice(), ind1}) + 1); 393 | } else 394 | { 395 | std::cout << "No points kept after pruning for desc1." 
396 |             }
397 |         }
398 |     }
399 | 
400 |     // Handle empty descriptor case
401 |     if (desc0.size(1) == 0 || desc1.size(1) == 0)
402 |     {
403 |         auto m0 = torch::full({b, m}, -1, torch::TensorOptions().dtype(torch::kLong).device(device_));
404 |         auto m1 = torch::full({b, n}, -1, torch::TensorOptions().dtype(torch::kLong).device(device_));
405 |         auto mscores0 = torch::zeros({b, m}, device_);
406 |         auto mscores1 = torch::zeros({b, n}, device_);
407 | 
408 |         if (!do_point_pruning)
409 |         {
410 |             prune0 = torch::ones_like(mscores0) * config_.n_layers;
411 |             prune1 = torch::ones_like(mscores1) * config_.n_layers;
412 |         }
413 | 
414 |         torch::Dict<std::string, torch::Tensor> output;
415 |         output.insert("matches0", m0);
416 |         output.insert("matches1", m1);
417 |         output.insert("matching_scores0", mscores0);
418 |         output.insert("matching_scores1", mscores1);
419 |         output.insert("stop", torch::tensor(last_layer + 1));
420 |         output.insert("prune0", prune0);
421 |         output.insert("prune1", prune1);
422 | 
423 |         return output;
424 |     }
425 | 
426 |     // Remove padding and compute assignment
427 |     desc0 = desc0.index({torch::indexing::Slice(), torch::indexing::Slice(0, m), torch::indexing::Slice()});
428 |     desc1 = desc1.index({torch::indexing::Slice(), torch::indexing::Slice(0, n), torch::indexing::Slice()});
429 | 
430 |     auto scores = log_assignment_[last_layer]->forward(desc0, desc1);
431 |     auto [m0, m1, mscores0, mscores1] = matcher::utils::filter_matches(scores, config_.filter_threshold);
432 |     torch::Tensor m_indices_0, m_indices_1;
433 | 
434 |     if (do_point_pruning)
435 |     {
436 |         // Get the number of match slots from m0 (one per keypoint kept in image 0)
437 |         int64_t num_matches = m0.size(1);
438 | 
439 |         // Create batch indices tensor and repeat for each match
440 |         auto batch_indices = torch::arange(b, torch::TensorOptions().device(device_));
441 |         m_indices_0 = batch_indices.unsqueeze(1).expand({b, num_matches}).reshape(-1);
442 | 
443 |         // Flatten match indices and create mask for valid matches
444 |         m_indices_1 = m0.reshape(-1);
445 |         auto valid_mask = m_indices_1 >= 0;
446 | 
447 |         // Apply mask to both tensors using masked_select
448 |         m_indices_0 = m_indices_0.masked_select(valid_mask);
449 |         m_indices_1 = m_indices_1.masked_select(valid_mask);
450 | 
451 |         // Use advanced indexing to select final indices
452 |         if (m_indices_0.numel() > 0 && m_indices_1.numel() > 0)
453 |         {
454 |             m_indices_0 = ind0.index({torch::indexing::Slice(), m_indices_0});
455 |             m_indices_1 = ind1.index({torch::indexing::Slice(), m_indices_1});
456 |         }
457 |     }
458 | 
459 |     auto matches = torch::stack({m_indices_0, m_indices_1}, 0);
460 | 
461 |     // Update m0, m1, mscores tensors
462 |     if (do_point_pruning)
463 |     {
464 |         auto m0_ = torch::full({b, m}, -1, torch::TensorOptions().dtype(torch::kLong).device(device_));
465 |         auto m1_ = torch::full({b, n}, -1, torch::TensorOptions().dtype(torch::kLong).device(device_));
466 |         auto mscores0_ = torch::zeros({b, m}, device_);
467 |         auto mscores1_ = torch::zeros({b, n}, device_);
468 | 
469 |         m0_.index_put_({torch::indexing::Slice(), ind0},
470 |                        torch::where(m0 == -1, -1, ind1.gather(1, m0.clamp(0))));
471 |         m1_.index_put_({torch::indexing::Slice(), ind1},
472 |                        torch::where(m1 == -1, -1, ind0.gather(1, m1.clamp(0))));
473 | 
474 |         mscores0_.index_put_({torch::indexing::Slice(), ind0}, mscores0);
475 |         mscores1_.index_put_({torch::indexing::Slice(), ind1}, mscores1);
476 | 
477 |         m0 = m0_;
478 |         m1 = m1_;
479 |         mscores0 = mscores0_;
480 |         mscores1 = mscores1_;
481 |     } else
482 |     {
483 |         prune0 = torch::ones_like(mscores0) * config_.n_layers;
484 |         prune1 = torch::ones_like(mscores1) * config_.n_layers;
485 |     }
486 | 
487 |     // Prepare output
488 |     torch::Dict<std::string, torch::Tensor> output;
489 |     output.insert("matches0", m0);
490 |     output.insert("matches1", m1);
491 |     output.insert("matching_scores0", mscores0);
492 |     output.insert("matching_scores1", mscores1);
493 |     output.insert("matches", matches);
494 |     output.insert("stop", torch::tensor(last_layer + 1));
495 |     output.insert("prune0", prune0);
496 |     output.insert("prune1", prune1);
497 | 
498 |     return output;
499 | }
500 | 
501 | void LightGlue::load_weights(const std::string& feature_type) {
502 |     std::vector<std::filesystem::path> search_paths = {
503 |         std::filesystem::path(LIGHTGLUE_MODELS_DIR) / (std::string(feature_type) + ".pt"),
504 |         std::filesystem::current_path() / "models" / (std::string(feature_type) + ".pt"),
505 |         std::filesystem::current_path() / (std::string(feature_type) + ".pt")};
506 | 
507 |     std::filesystem::path model_path;
508 |     bool found = false;
509 | 
510 |     for (const auto& path : search_paths)
511 |     {
512 |         if (std::filesystem::exists(path))
513 |         {
514 |             model_path = path;
515 |             found = true;
516 |             break;
517 |         }
518 |     }
519 | 
520 |     if (!found)
521 |     {
522 |         std::string error_msg = "Cannot find pretrained model. Searched in:\n";
523 |         for (const auto& path : search_paths)
524 |         {
525 |             error_msg += " " + path.string() + "\n";
526 |         }
527 |         error_msg += "Please place the model file in one of these locations.";
528 |         throw std::runtime_error(error_msg);
529 |     }
530 | 
531 |     std::cout << "Loading model from: " << model_path << std::endl;
532 |     load_parameters(model_path.string());
533 | }
534 | 
535 | void LightGlue::load_parameters(const std::string& pt_path) {
536 |     auto f = get_the_bytes(pt_path);
537 |     auto weights = torch::pickle_load(f).toGenericDict();
538 | 
539 |     // Use unordered_maps for O(1) lookup
540 |     std::unordered_map<std::string, torch::Tensor> param_map;
541 |     std::unordered_map<std::string, torch::Tensor> buffer_map;
542 | 
543 |     auto model_params = named_parameters();
544 |     auto model_buffers = named_buffers();
545 |     // Pre-allocate with expected size
546 |     param_map.reserve(model_params.size());
547 |     buffer_map.reserve(model_buffers.size());
548 | 
549 |     // Collect parameter names
550 |     for (const auto& p : model_params)
551 |     {
552 |         param_map.emplace(p.key(), p.value());
553 |     }
554 | 
555 |     // Collect buffer names
556 |     for (const auto& b : model_buffers)
557 |     {
558 |         buffer_map.emplace(b.key(), b.value());
559 |     }
560 | 
561 |     // Update parameters and buffers
562 |     torch::NoGradGuard no_grad;
563 | 
564 |     for (const auto& w : weights)
565 |     {
566 |         const auto name = map_python_to_cpp(w.key().toStringRef());
567 |         const auto& param = w.value().toTensor();
568 | 
569 |         // Try parameters first
570 |         if (auto it = param_map.find(name); it != param_map.end())
571 |         {
572 |             if (it->second.sizes() == param.sizes())
573 |             {
574 |                 it->second.copy_(param);
575 |             } else
576 |             {
577 |                 throw std::runtime_error(
578 |                     "Shape mismatch for parameter: " + name +
579 |                     " Expected: " + std::to_string(it->second.numel()) +
580 |                     " Got: " + std::to_string(param.numel()));
581 |             }
582 |             continue;
583 |         }
584 | 
585 |         // Then try buffers
586 |         if (auto it = buffer_map.find(name); it != buffer_map.end())
587 |         {
588 |             if (it->second.sizes() == param.sizes())
589 |             {
590 |                 it->second.copy_(param);
591 |             } else
592 |             {
593 |                 std::cout << "buffer name: " << name << " Expected: " << it->second.sizes() << ", Got: " << param.sizes() << std::endl;
594 |                 throw std::runtime_error(
595 |                     "Shape mismatch for buffer: " + name +
596 |                     " Expected: " + std::to_string(it->second.numel()) +
597 |                     " Got: " + std::to_string(param.numel()));
598 |             }
599 |             continue;
600 |         }
601 | 
602 |         // Parameter not found in model
603 |         std::cerr << "Warning: " << name
604 |                   << " not found in model parameters or buffers\n";
605 |     }
606 | }
607 | 
608 | std::vector<char> LightGlue::get_the_bytes(const std::string& filename) {
609 |     // Use RAII file handling
610 |     std::ifstream file(filename, std::ios::binary);
611 |     if (!file)
612 |     {
613 |         throw std::runtime_error(
614 |             "Failed to open file: " + filename);
615 |     }
616 | 
617 |     // Get file size
618 |     file.seekg(0, std::ios::end);
619 |     const auto size = file.tellg();
620 |     file.seekg(0, std::ios::beg);
621 | 
622 |     // Pre-allocate vector
623 |     std::vector<char> buffer;
624 |     buffer.reserve(size);
625 | 
626 |     // Read file in chunks for better performance
627 |     constexpr size_t CHUNK_SIZE = 8192;
628 |     char chunk[CHUNK_SIZE];
629 | 
630 |     while (file.read(chunk, CHUNK_SIZE))
631 |     {
632 |         buffer.insert(buffer.end(), chunk, chunk + file.gcount());
633 |     }
634 |     if (file.gcount() > 0)
635 |     {
636 |         buffer.insert(buffer.end(), chunk, chunk + file.gcount());
637 |     }
638 | 
639 |     return buffer;
640 | }
641 | }
--------------------------------------------------------------------------------
/src/matcher/lightglue/transformer.cpp:
--------------------------------------------------------------------------------
1 | #include "matcher/lightglue/transformer.hpp"
2 | #include "matcher/lightglue/attention.hpp"
3 | 
4 | namespace matcher {
5 |     TransformerLayer::TransformerLayer(int embed_dim, int num_heads, bool flash, bool bias) {
6 |         // Initialize self-attention block
7 |         self_attn_ = register_module("self_attn",
8 |                                      std::make_shared<SelfBlock>(embed_dim, num_heads, flash, bias));
9 | 
10 |         // Initialize cross-attention block
11 |         cross_attn_ = register_module("cross_attn",
12 |                                       std::make_shared<CrossBlock>(embed_dim, num_heads, flash, bias));
13 |     }
14 | 
15 |     std::tuple<torch::Tensor, torch::Tensor> TransformerLayer::forward(
16 |         const torch::Tensor& desc0,
17 |         const torch::Tensor& desc1,
18 |         const torch::Tensor& encoding0,
19 |         const torch::Tensor& encoding1) {
20 | 
21 |         auto desc0_sa = self_attn_->forward(desc0, encoding0);
22 |         auto desc1_sa = self_attn_->forward(desc1, encoding1);
23 | 
24 |         // Apply cross-attention between the two sets
25 |         return cross_attn_->forward(desc0_sa, desc1_sa);
26 |     }
27 | }
--------------------------------------------------------------------------------
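To make the matcher API above easier to follow, here is a minimal usage sketch; it is illustrative only and not code from this repository. It sticks to what core.cpp shows: LightGlue::forward() consumes two torch::Dict<std::string, torch::Tensor> maps keyed by "keypoints", "descriptors" and optionally "image_size", and returns "matches0", "matching_scores0", "matches", "stop" and the pruning tensors. The matcher::LightGlue constructor arguments, the make_feature_dict helper, and the 512x128 random descriptors are assumptions made for the sketch; the project's real entry point is examples/main.cpp together with include/matcher/lightglue/matcher.hpp.

#include <iostream>
#include <string>

#include <torch/torch.h>

#include "matcher/lightglue/matcher.hpp" // assumed public header for matcher::LightGlue

// Hypothetical helper: packs one image's features into the dict layout that
// LightGlue::forward expects ("keypoints": [N, 2], "descriptors": [N, D]).
static torch::Dict<std::string, torch::Tensor> make_feature_dict(torch::Tensor keypoints,
                                                                 torch::Tensor descriptors,
                                                                 torch::Tensor image_size) {
    torch::Dict<std::string, torch::Tensor> data;
    data.insert("keypoints", std::move(keypoints));
    data.insert("descriptors", std::move(descriptors));
    data.insert("image_size", std::move(image_size)); // optional: enables keypoint normalization
    return data;
}

int main() {
    // Constructor arguments are an assumption; check matcher.hpp for the real signature.
    matcher::LightGlue lightglue("aliked_lightglue");

    // Random stand-ins for extractor output: 512 keypoints with unit-norm descriptors.
    namespace F = torch::nn::functional;
    auto desc_opts = F::NormalizeFuncOptions().dim(-1);
    auto data0 = make_feature_dict(torch::rand({512, 2}) * 640,
                                   F::normalize(torch::rand({512, 128}), desc_opts),
                                   torch::tensor({640.f, 480.f}));
    auto data1 = make_feature_dict(torch::rand({512, 2}) * 640,
                                   F::normalize(torch::rand({512, 128}), desc_opts),
                                   torch::tensor({640.f, 480.f}));

    auto out = lightglue.forward(data0, data1);

    // matches0[b, i] is the index of the keypoint in image 1 matched to keypoint i
    // of image 0, or -1 if keypoint i is unmatched.
    auto matches0 = out.at("matches0");
    auto num_valid = (matches0 >= 0).sum().item<int64_t>();
    std::cout << "stopped after layer " << out.at("stop").item<int>()
              << ", valid matches: " << num_valid << std::endl;
    return 0;
}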