├── Sources ├── MetalShaders │ ├── MetalShaders.c │ ├── module.modulemap │ ├── include │ │ ├── MetalShaders.h │ │ ├── NearestNeighbor.h │ │ ├── SIFTExtrema.h │ │ ├── ConvolutionSeries.h │ │ ├── SIFTOrientation.h │ │ ├── SIFTDescriptor.h │ │ └── SIFTInterpolate.h │ └── Metal │ │ ├── Subtract.metal │ │ ├── NearestNeighborDownScale.metal │ │ ├── NearestNeighborUpScale.metal │ │ ├── ConvertSRGBToGrayscale.metal │ │ ├── SIFTGradient.metal │ │ ├── Convolution.metal │ │ ├── ConvolutionSeries.metal │ │ ├── BilinearUpScale.metal │ │ ├── Common.hpp │ │ ├── SIFTExtrema.metal │ │ ├── SIFTOrientation.metal │ │ ├── SIFTDescriptor.metal │ │ └── SIFTInterpolate.metal └── SIFTMetal │ ├── SIFT │ ├── SIFTKeypointOrientations.swift │ ├── SIFTCorrespondence.swift │ ├── SIFTKeypoint.swift │ ├── SIFTPatch.swift │ ├── SIFT.swift │ └── SIFTDescriptor.swift │ ├── Extensions │ ├── CoreGraphicsExtensions.swift │ ├── FoundationExtensions.swift │ └── CoreImageExtensions.swift │ ├── Utilities │ ├── Performance.swift │ ├── Math.swift │ ├── MetalExtensions.swift │ ├── CoreVideoMetalCache.swift │ ├── ImageConversion.swift │ ├── Buffer.swift │ ├── Quad.swift │ ├── Image.swift │ └── SIFTRenderer.swift │ └── Metal Compute │ ├── NearestNeighborUpScaleKernel.swift │ ├── SIFTDescriptorKernel.swift │ ├── BilinearUpScaleKernel.swift │ ├── SIFTExtremaKernel.swift │ ├── ConvertSRGBToGrayscaleKernel.swift │ ├── SIFTExtremaListKernel.swift │ ├── SubtractKernel.swift │ ├── SIFTGradientKernel.swift │ ├── SIFTInterpolateKernel.swift │ ├── SIFTOrientationKernel.swift │ ├── GaussianKernel.swift │ ├── NearestNeighborDownScaleKernel.swift │ ├── Convolution1DKernel.swift │ ├── ConvolutionSeriesKernel.swift │ └── GaussianSeriesKernel.swift ├── Tests └── SIFTMetalTests │ ├── Resources │ ├── butterfly.png │ ├── DoG_butterfly_o000_s001.png │ ├── DoG_butterfly_o000_s002.png │ ├── DoG_butterfly_o000_s003.png │ ├── DoG_butterfly_o001_s001.png │ ├── DoG_butterfly_o001_s002.png │ ├── DoG_butterfly_o001_s003.png │ ├── 
DoG_butterfly_o002_s001.png │ ├── DoG_butterfly_o002_s002.png │ ├── DoG_butterfly_o002_s003.png │ ├── DoG_butterfly_o003_s001.png │ ├── DoG_butterfly_o003_s002.png │ ├── DoG_butterfly_o003_s003.png │ ├── DoG_butterfly_o004_s001.png │ ├── DoG_butterfly_o004_s002.png │ ├── DoG_butterfly_o004_s003.png │ ├── scalespace_butterfly_o000_s000.png │ ├── scalespace_butterfly_o000_s001.png │ ├── scalespace_butterfly_o000_s002.png │ ├── scalespace_butterfly_o000_s003.png │ ├── scalespace_butterfly_o000_s004.png │ ├── scalespace_butterfly_o000_s005.png │ ├── scalespace_butterfly_o001_s000.png │ ├── scalespace_butterfly_o001_s001.png │ ├── scalespace_butterfly_o001_s002.png │ ├── scalespace_butterfly_o001_s003.png │ ├── scalespace_butterfly_o001_s004.png │ ├── scalespace_butterfly_o001_s005.png │ ├── scalespace_butterfly_o002_s000.png │ ├── scalespace_butterfly_o002_s001.png │ ├── scalespace_butterfly_o002_s002.png │ ├── scalespace_butterfly_o002_s003.png │ ├── scalespace_butterfly_o002_s004.png │ ├── scalespace_butterfly_o002_s005.png │ ├── scalespace_butterfly_o003_s000.png │ ├── scalespace_butterfly_o003_s001.png │ ├── scalespace_butterfly_o003_s002.png │ ├── scalespace_butterfly_o003_s003.png │ ├── scalespace_butterfly_o003_s004.png │ ├── scalespace_butterfly_o003_s005.png │ ├── scalespace_butterfly_o004_s000.png │ ├── scalespace_butterfly_o004_s001.png │ ├── scalespace_butterfly_o004_s002.png │ ├── scalespace_butterfly_o004_s003.png │ ├── scalespace_butterfly_o004_s004.png │ └── scalespace_butterfly_o004_s005.png │ ├── GaussianDifferenceTests.swift │ ├── KeypointTests.swift │ ├── SharedTestCase.swift │ ├── TrieTests.swift │ ├── DescriptorTests.swift │ └── DifferenceOfGaussiansTests.swift ├── .gitignore ├── .swiftpm └── xcode │ ├── package.xcworkspace │ └── xcshareddata │ │ └── IDEWorkspaceChecks.plist │ └── xcshareddata │ └── xcschemes │ └── SIFTMetal.xcscheme ├── LICENSE ├── README.md └── Package.swift /Sources/MetalShaders/MetalShaders.c: 
-------------------------------------------------------------------------------- 1 | // File intentionally left blank 2 | -------------------------------------------------------------------------------- /Sources/MetalShaders/module.modulemap: -------------------------------------------------------------------------------- 1 | module Shaders { 2 | header "MetalShaders.h" 3 | export * 4 | } 5 | -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/butterfly.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/butterfly.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/DoG_butterfly_o000_s001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o000_s001.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/DoG_butterfly_o000_s002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o000_s002.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/DoG_butterfly_o000_s003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o000_s003.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/DoG_butterfly_o001_s001.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o001_s001.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/DoG_butterfly_o001_s002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o001_s002.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/DoG_butterfly_o001_s003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o001_s003.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/DoG_butterfly_o002_s001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o002_s001.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/DoG_butterfly_o002_s002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o002_s002.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/DoG_butterfly_o002_s003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o002_s003.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/DoG_butterfly_o003_s001.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o003_s001.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/DoG_butterfly_o003_s002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o003_s002.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/DoG_butterfly_o003_s003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o003_s003.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/DoG_butterfly_o004_s001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o004_s001.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/DoG_butterfly_o004_s002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o004_s002.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/DoG_butterfly_o004_s003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o004_s003.png -------------------------------------------------------------------------------- 
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o000_s000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o000_s000.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o000_s001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o000_s001.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o000_s002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o000_s002.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o000_s003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o000_s003.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o000_s004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o000_s004.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o000_s005.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o000_s005.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o001_s000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o001_s000.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o001_s001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o001_s001.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o001_s002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o001_s002.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o001_s003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o001_s003.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o001_s004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o001_s004.png -------------------------------------------------------------------------------- 
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o001_s005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o001_s005.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o002_s000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o002_s000.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o002_s001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o002_s001.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o002_s002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o002_s002.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o002_s003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o002_s003.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o002_s004.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o002_s004.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o002_s005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o002_s005.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o003_s000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o003_s000.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o003_s001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o003_s001.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o003_s002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o003_s002.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o003_s003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o003_s003.png -------------------------------------------------------------------------------- 
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o003_s004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o003_s004.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o003_s005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o003_s005.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o004_s000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o004_s000.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o004_s001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o004_s001.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o004_s002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o004_s002.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o004_s003.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o004_s003.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o004_s004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o004_s004.png -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/Resources/scalespace_butterfly_o004_s005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o004_s005.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | /.build 3 | /Packages 4 | /*.xcodeproj 5 | xcuserdata/ 6 | DerivedData/ 7 | .swiftpm/config/registries.json 8 | .swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata 9 | .netrc 10 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/SIFT/SIFTKeypointOrientations.swift: -------------------------------------------------------------------------------- 1 | // 2 | // SIFTOrientedKeypoint.swift 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2023/01/10. 
6 | // 7 | 8 | import Foundation 9 | 10 | struct SIFTKeypointOrientations { 11 | let keypoint: SIFTKeypoint 12 | let orientations: [Float] 13 | } 14 | -------------------------------------------------------------------------------- /.swiftpm/xcode/package.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | IDEDidComputeMac32BitWarning 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /Sources/MetalShaders/include/MetalShaders.h: -------------------------------------------------------------------------------- 1 | // 2 | // Use this file to import your target's public headers that you would like to expose to Swift. 3 | // 4 | 5 | 6 | #import "SIFTExtrema.h" 7 | #import "SIFTInterpolate.h" 8 | #import "SIFTOrientation.h" 9 | #import "SIFTDescriptor.h" 10 | #import "ConvolutionSeries.h" 11 | #import "NearestNeighbor.h" 12 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/Extensions/CoreGraphicsExtensions.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Extensions.swift 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2022/12/17. 6 | // 7 | 8 | import CoreGraphics 9 | import simd 10 | 11 | 12 | extension CGPoint { 13 | init(_ point: simd_float2) { 14 | self.init(x: CGFloat(point.x), y: CGFloat(point.y)) 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/SIFT/SIFTCorrespondence.swift: -------------------------------------------------------------------------------- 1 | // 2 | // SIFTCorrespondence.swift 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2023/01/02. 
6 | // 7 | 8 | import Foundation 9 | 10 | 11 | public struct SIFTCorrespondence { 12 | 13 | public var source: SIFTDescriptor 14 | public var target: SIFTDescriptor 15 | public var featureDistance: Float 16 | } 17 | -------------------------------------------------------------------------------- /Sources/MetalShaders/include/NearestNeighbor.h: -------------------------------------------------------------------------------- 1 | // 2 | // NearestNeighbor.h 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2023/01/08. 6 | // 7 | 8 | #include 9 | 10 | #ifndef NearestNeighbor_h 11 | #define NearestNeighbor_h 12 | 13 | struct NearestNeighborScaleParameters { 14 | int32_t inputSlice; 15 | int32_t outputSlice; 16 | }; 17 | 18 | #endif /* NearestNeighbor_h */ 19 | -------------------------------------------------------------------------------- /Sources/MetalShaders/include/SIFTExtrema.h: -------------------------------------------------------------------------------- 1 | // 2 | // SIFTExtrema.h 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2023/01/10. 6 | // 7 | 8 | #include 9 | 10 | #ifndef SIFTExtrema_h 11 | #define SIFTExtrema_h 12 | 13 | // This should match SIFTInterpolateInputKeypoint 14 | struct SIFTExtremaResult { 15 | int32_t x; 16 | int32_t y; 17 | int32_t scale; 18 | }; 19 | 20 | #endif /* SIFTExtrema_h */ 21 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/Extensions/FoundationExtensions.swift: -------------------------------------------------------------------------------- 1 | // 2 | // File.swift 3 | // 4 | // 5 | // Created by Luke Van In on 2023/01/24. 6 | // 7 | 8 | import Foundation 9 | 10 | extension Bundle { 11 | 12 | // See: https://developer.apple.com/forums/thread/649579 13 | static var metalShaders: Bundle { 14 | let bundleURL = Bundle(for: SIFT.self).url(forResource: "SIFTMetal_MetalShaders", withExtension: "bundle")! 15 | let bundle = Bundle(url: bundleURL)! 
16 | return bundle 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /Sources/MetalShaders/include/ConvolutionSeries.h: -------------------------------------------------------------------------------- 1 | // 2 | // ConvolutionSeries.h 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2023/01/08. 6 | // 7 | 8 | #include 9 | 10 | #ifndef ConvolutionSeries_h 11 | #define ConvolutionSeries_h 12 | 13 | #define CONVOLUTION_WEIGHTS_LENGTH 32 14 | 15 | 16 | struct ConvolutionParameters { 17 | int32_t inputDepth; 18 | int32_t outputDepth; 19 | int32_t count; 20 | float weights[CONVOLUTION_WEIGHTS_LENGTH]; 21 | }; 22 | 23 | 24 | #endif /* ConvolutionSeries_h */ 25 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/Utilities/Performance.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Performance.swift 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2023/01/10. 6 | // 7 | 8 | import Foundation 9 | import OSLog 10 | 11 | 12 | private let signposter = OSSignposter(subsystem: Bundle.main.bundleIdentifier!, category: "performance") 13 | 14 | 15 | func measure(name: StaticString, worker: () -> Void) { 16 | let signpostID = signposter.makeSignpostID() 17 | let state = signposter.beginInterval(name, id: signpostID) 18 | worker() 19 | signposter.endInterval(name, state) 20 | } 21 | -------------------------------------------------------------------------------- /Sources/MetalShaders/Metal/Subtract.metal: -------------------------------------------------------------------------------- 1 | // 2 | // Subtract.metal 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2023/01/07. 
6 | // 7 | 8 | #include 9 | using namespace metal; 10 | 11 | 12 | kernel void subtract( 13 | texture2d_array outputTexture [[texture(0)]], 14 | texture2d_array inputTexture [[texture(1)]], 15 | ushort3 gid [[thread_position_in_grid]] 16 | ) { 17 | float4 a = inputTexture.read(gid.xy, gid.z + 1); 18 | float4 b = inputTexture.read(gid.xy, gid.z); 19 | float4 c = a - b; 20 | outputTexture.write(c, gid.xy, gid.z); 21 | } 22 | 23 | 24 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/Utilities/Math.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Math.swift 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2022/12/18. 6 | // 7 | 8 | import Foundation 9 | 10 | 11 | public struct IntegralSize { 12 | public var width: Int 13 | public var height: Int 14 | 15 | public init(width: Int, height: Int) { 16 | self.width = width 17 | self.height = height 18 | } 19 | } 20 | 21 | 22 | func modulus(_ x: Float, _ y: Float) -> Float { 23 | var z: Float = x 24 | var n: Int = 0 25 | if (z < 0) { 26 | n = Int(((-z) / y) + 1) 27 | z += Float(n) * y 28 | } 29 | n = Int(z / y) 30 | z -= Float(n) * y 31 | return z 32 | } 33 | -------------------------------------------------------------------------------- /Sources/MetalShaders/Metal/NearestNeighborDownScale.metal: -------------------------------------------------------------------------------- 1 | // 2 | // NearestNeighborDownScale.metal 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2023/01/07. 
6 | // 7 | 8 | #include 9 | 10 | #import "../include/NearestNeighbor.h" 11 | 12 | using namespace metal; 13 | 14 | 15 | kernel void nearestNeighborDownScale( 16 | texture2d_array outputTexture [[texture(0)]], 17 | texture2d_array inputTexture [[texture(1)]], 18 | device NearestNeighborScaleParameters & parameters [[buffer(0)]], 19 | ushort2 gid [[thread_position_in_grid]] 20 | ) { 21 | outputTexture.write(inputTexture.read(gid * 2, parameters.inputSlice), gid, parameters.outputSlice); 22 | } 23 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/Utilities/MetalExtensions.swift: -------------------------------------------------------------------------------- 1 | // 2 | // MetalExtensions.swift 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2023/01/10. 6 | // 7 | 8 | import Foundation 9 | import Metal 10 | 11 | func capture(commandQueue: MTLCommandQueue, capture: Bool = true, worker: () -> Void) { 12 | guard capture else { 13 | worker() 14 | return 15 | } 16 | let captureManager = MTLCaptureManager.shared() 17 | let captureDescriptor = MTLCaptureDescriptor() 18 | captureDescriptor.captureObject = commandQueue 19 | captureDescriptor.destination = .developerTools 20 | try! captureManager.startCapture(with: captureDescriptor) 21 | worker() 22 | captureManager.stopCapture() 23 | } 24 | 25 | -------------------------------------------------------------------------------- /Sources/MetalShaders/Metal/NearestNeighborUpScale.metal: -------------------------------------------------------------------------------- 1 | // 2 | // NearestNeighborUpScale.metal 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2023/01/07. 
6 | // 7 | 8 | #include 9 | using namespace metal; 10 | 11 | 12 | kernel void nearestNeighborUpScale( 13 | texture2d outputTexture [[texture(0)]], 14 | texture2d inputTexture [[texture(1)]], 15 | ushort2 gid [[thread_position_in_grid]] 16 | ) { 17 | ushort2 inputSize = ushort2(inputTexture.get_width(), inputTexture.get_height()); 18 | ushort2 outputSize = ushort2(outputTexture.get_width(), outputTexture.get_height()); 19 | 20 | ushort2 scale = outputSize / inputSize; 21 | outputTexture.write(inputTexture.read(gid / scale), gid); 22 | } 23 | -------------------------------------------------------------------------------- /Sources/MetalShaders/Metal/ConvertSRGBToGrayscale.metal: -------------------------------------------------------------------------------- 1 | // 2 | // ConvertSRGBToGrayscale.metal 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2023/01/07. 6 | // 7 | 8 | #include 9 | using namespace metal; 10 | 11 | kernel void convertSRGBToGrayscale( 12 | texture2d outputTexture [[texture(0)]], 13 | texture2d inputTexture [[texture(1)]], 14 | ushort2 gid [[thread_position_in_grid]] 15 | ) { 16 | const float4 input = inputTexture.read(gid); 17 | const float i = 0 + 18 | (0.212639005871510 * input.r) + 19 | (0.715168678767756 * input.g) + 20 | (0.072192315360734 * input.b); 21 | const float4 output = float4(i, i, i, input.a); 22 | outputTexture.write(output, gid); 23 | } 24 | 25 | 26 | -------------------------------------------------------------------------------- /Sources/MetalShaders/include/SIFTOrientation.h: -------------------------------------------------------------------------------- 1 | // 2 | // SIFTOrientation.h 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2023/01/08. 
6 | // 7 | #include 8 | 9 | #ifndef SIFTOrientation_h 10 | #define SIFTOrientation_h 11 | 12 | #define SIFT_ORIENTATION_HISTOGRAM_BINS 36 13 | 14 | struct SIFTOrientationParameters { 15 | float delta; 16 | float lambda; 17 | float orientationThreshold; 18 | }; 19 | 20 | 21 | struct SIFTOrientationKeypoint { 22 | int32_t index; 23 | int32_t absoluteX; 24 | int32_t absoluteY; 25 | int32_t scale; 26 | float sigma; 27 | }; 28 | 29 | 30 | struct SIFTOrientationResult { 31 | int32_t keypoint; 32 | int32_t count; 33 | float orientations[SIFT_ORIENTATION_HISTOGRAM_BINS]; 34 | }; 35 | 36 | #endif /* SIFTOrientation_h */ 37 | -------------------------------------------------------------------------------- /Sources/MetalShaders/include/SIFTDescriptor.h: -------------------------------------------------------------------------------- 1 | // 2 | // SIFTDescriptor.h 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2023/01/08. 6 | // 7 | 8 | #include 9 | 10 | #ifndef SIFTDescriptor_h 11 | #define SIFTDescriptor_h 12 | 13 | 14 | #define SIFT_DESCRIPTOR_HISTOGRAM_WIDTH 4 15 | #define SIFT_DESCRIPTOR_ORIENTATION_BINS 8 16 | #define SIFT_DESCRIPTOR_FEATURE_COUNT 128 17 | 18 | 19 | struct SIFTDescriptorParameters { 20 | float delta; 21 | int32_t scalesPerOctave; 22 | int32_t width; 23 | int32_t height; 24 | }; 25 | 26 | 27 | struct SIFTDescriptorInput { 28 | int32_t keypoint; 29 | int32_t absoluteX; 30 | int32_t absoluteY; 31 | int32_t scale; 32 | float subScale; 33 | float theta; 34 | }; 35 | 36 | 37 | struct SIFTDescriptorResult { 38 | int32_t valid; 39 | int32_t keypoint; 40 | float theta; 41 | int32_t features[SIFT_DESCRIPTOR_FEATURE_COUNT]; 42 | }; 43 | 44 | 45 | #endif /* SIFTDescriptor_h */ 46 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/Extensions/CoreImageExtensions.swift: -------------------------------------------------------------------------------- 1 | // 2 | // CoreImageExtensions.swift 3 | // SkyLight 4 | 
// 5 | // Created by Luke Van In on 2022/12/17. 6 | // 7 | 8 | import CoreImage 9 | import CoreImage.CIFilterBuiltins 10 | import simd 11 | 12 | 13 | extension CIImage { 14 | public func perspectiveTransformed(by matrix: simd_float3x3) -> CIImage { 15 | let bounds = Quad(rect: extent).transformed(by: matrix) 16 | let filter = CIFilter.perspectiveTransform() 17 | filter.topLeft = CGPoint(bounds.topLeft) 18 | filter.topRight = CGPoint(bounds.topRight) 19 | filter.bottomRight = CGPoint(bounds.bottomRight) 20 | filter.bottomLeft = CGPoint(bounds.bottomLeft) 21 | filter.inputImage = self 22 | return filter.outputImage! 23 | } 24 | 25 | public func colorInverted() -> CIImage { 26 | let filter = CIFilter.colorInvert() 27 | filter.inputImage = self 28 | return filter.outputImage! 29 | } 30 | } 31 | 32 | -------------------------------------------------------------------------------- /Sources/MetalShaders/include/SIFTInterpolate.h: -------------------------------------------------------------------------------- 1 | // 2 | // SIFTInterpolate.h 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2023/01/09. 
6 | // 7 | 8 | #include 9 | 10 | #ifndef SIFTInterpolate_h 11 | #define SIFTInterpolate_h 12 | 13 | 14 | struct SIFTInterpolateParameters { 15 | float dogThreshold; 16 | int32_t maxIterations; 17 | float maxOffset; 18 | int32_t width; 19 | int32_t height; 20 | float octaveDelta; 21 | float edgeThreshold; 22 | int32_t numberOfScales; 23 | }; 24 | 25 | 26 | // This should match SIFTExtremaResult 27 | struct SIFTInterpolateInputKeypoint { 28 | int32_t x; 29 | int32_t y; 30 | int32_t scale; 31 | }; 32 | 33 | 34 | struct SIFTInterpolateOutputKeypoint { 35 | int32_t converged; 36 | int32_t scale; 37 | float subScale; 38 | int32_t relativeX; 39 | int32_t relativeY; 40 | float absoluteX; 41 | float absoluteY; 42 | float value; 43 | float alphaX; 44 | float alphaY; 45 | float alphaZ; 46 | }; 47 | 48 | 49 | #endif /* SIFTInterpolate_h */ 50 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/Utilities/CoreVideoMetalCache.swift: -------------------------------------------------------------------------------- 1 | // 2 | // CoreVideoMetalCache.swift 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2022/12/18. 6 | // 7 | 8 | import Foundation 9 | import CoreVideo 10 | 11 | 12 | final class CoreVideoMetalCache { 13 | 14 | private let textureCache: CVMetalTextureCache 15 | 16 | init(device: MTLDevice) { 17 | var textureCache: CVMetalTextureCache? 18 | CVMetalTextureCacheCreate(kCFAllocatorDefault, nil, device, nil, &textureCache) 19 | self.textureCache = textureCache! 20 | 21 | } 22 | 23 | func makeTexture(from input: CVImageBuffer, size: IntegralSize) -> MTLTexture { 24 | var cvMetalTexture: CVMetalTexture? 25 | let result = CVMetalTextureCacheCreateTextureFromImage(kCFAllocatorDefault, textureCache, input, nil, .bgra8Unorm, size.width, size.height, 0, &cvMetalTexture) 26 | guard result == kCVReturnSuccess else { 27 | fatalError("Cannot create texture") 28 | } 29 | let texture = CVMetalTextureGetTexture(cvMetalTexture!)! 
30 | return texture 31 | } 32 | 33 | } 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Luke Van In 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/Utilities/ImageConversion.swift: -------------------------------------------------------------------------------- 1 | // 2 | // ImageConversion.swift 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2022/12/18. 
6 | // 7 | 8 | import Foundation 9 | import CoreImage 10 | import CoreVideo 11 | import Metal 12 | 13 | 14 | public final class ImageConversion { 15 | 16 | private let ciContext: CIContext 17 | 18 | public init(device: MTLDevice, colorSpace: CGColorSpace) { 19 | self.ciContext = CIContext( 20 | mtlDevice: device, 21 | options: [ 22 | .useSoftwareRenderer: false, 23 | .outputColorSpace: colorSpace, 24 | .workingColorSpace: colorSpace, 25 | ] 26 | ) 27 | } 28 | 29 | public func makeCGImage(_ input: CVImageBuffer) -> CGImage { 30 | let ciImage = CIImage( 31 | cvPixelBuffer: input, 32 | options: [.applyOrientationProperty: true] 33 | ) 34 | return makeCGImage(ciImage) 35 | } 36 | 37 | public func makeCGImage(_ input: MTLTexture) -> CGImage { 38 | let ciImage = CIImage(mtlTexture: input)! 39 | return makeCGImage(ciImage) 40 | } 41 | 42 | public func makeCGImage(_ input: CIImage) -> CGImage { 43 | let output = input.transformed(by: input.orientationTransform(for: .downMirrored)) 44 | return ciContext.createCGImage(output, from: output.extent)! 45 | } 46 | 47 | } 48 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/Utilities/Buffer.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Buffer.swift 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2023/01/06. 6 | // 7 | 8 | import Foundation 9 | import Metal 10 | 11 | 12 | public final class Buffer { 13 | 14 | public private(set) var count: Int 15 | let data: MTLBuffer 16 | let pointer: UnsafeMutablePointer 17 | public let capacity: Int 18 | 19 | public init(device: MTLDevice, label: String, capacity: Int) { 20 | self.capacity = capacity 21 | let numberOfBytes = MemoryLayout.stride * capacity 22 | self.data = device.makeBuffer( 23 | length: numberOfBytes, 24 | options: [.hazardTrackingModeTracked, .storageModeShared] 25 | )! 
26 | self.data.label = label 27 | self.pointer = data.contents().bindMemory(to: T.self, capacity: capacity) 28 | self.count = 0 29 | } 30 | 31 | deinit { 32 | data.setPurgeableState(.empty) 33 | } 34 | 35 | public func allocate(_ count: Int) { 36 | precondition(count >= 0) 37 | precondition(count <= capacity) 38 | self.count = count 39 | } 40 | 41 | public subscript(i: Int) -> T { 42 | get { 43 | precondition(i >= 0) 44 | precondition(i < count) 45 | return pointer[i] 46 | } 47 | set { 48 | precondition(i >= 0) 49 | precondition(i < count) 50 | pointer[i] = newValue 51 | } 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /Sources/MetalShaders/Metal/SIFTGradient.metal: -------------------------------------------------------------------------------- 1 | // 2 | // SIFTGradient.metal 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2023/01/07. 6 | // 7 | 8 | #include 9 | 10 | #include "Common.hpp" 11 | 12 | using namespace metal; 13 | 14 | 15 | kernel void siftGradient( 16 | texture2d_array outputTexture [[texture(0)]], 17 | texture2d_array inputTexture [[texture(1)]], 18 | ushort3 gid [[thread_position_in_grid]] 19 | ) { 20 | const int gx = (int)gid.x; 21 | const int gy = (int)gid.y; 22 | const int gz = (int)gid.z; 23 | const int dx = inputTexture.get_width(); 24 | const int dy = inputTexture.get_height(); 25 | const ushort px = symmetrizedCoordinates(gx + 1, dx); 26 | const ushort mx = symmetrizedCoordinates(gx - 1, dx); 27 | const ushort py = symmetrizedCoordinates(gy + 1, dy); 28 | const ushort my = symmetrizedCoordinates(gy - 1, dy); 29 | const float cpx = inputTexture.read(ushort2(px, gy), gz).r; 30 | const float cmx = inputTexture.read(ushort2(mx, gy), gz).r; 31 | const float cpy = inputTexture.read(ushort2(gx, py), gz).r; 32 | const float cmy = inputTexture.read(ushort2(gx, my), gz).r; 33 | const float tx = (cpx - cmx) * 0.5; 34 | const float ty = (cpy - cmy) * 0.5; 35 | #warning("FIXME: IPOL 
implementation swaps dx and dy") 36 | float oa = atan2(tx, ty); 37 | float om = sqrt(tx * tx + ty * ty); 38 | outputTexture.write(float4(oa, om, 0, 0), ushort2(gx, gy), gz); 39 | } 40 | 41 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/Utilities/Quad.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Quad.swift 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2022/12/17. 6 | // 7 | 8 | import Foundation 9 | import simd 10 | 11 | 12 | struct Quad { 13 | var topLeft: simd_float2 14 | var topRight: simd_float2 15 | var bottomRight: simd_float2 16 | var bottomLeft: simd_float2 17 | 18 | var points: [simd_float2] { 19 | [topLeft, topRight, bottomRight, bottomLeft] 20 | } 21 | 22 | func transformed(by matrix: simd_float3x3) -> Quad { 23 | Quad( 24 | points: points 25 | .map { point in 26 | simd_float3(point.x, point.y, 1) 27 | } 28 | .map { point in 29 | matrix * point 30 | } 31 | .map { point in 32 | simd_float2(point.x / point.z, point.y / point.z) 33 | } 34 | ) 35 | } 36 | } 37 | 38 | extension Quad { 39 | init(rect: CGRect) { 40 | self.init( 41 | topLeft: simd_float2(Float(rect.minX), Float(rect.maxY)), 42 | topRight: simd_float2(Float(rect.maxX), Float(rect.maxY)), 43 | bottomRight: simd_float2(Float(rect.maxX), Float(rect.minY)), 44 | bottomLeft: simd_float2(Float(rect.minX), Float(rect.minY)) 45 | ) 46 | } 47 | 48 | init(points: [simd_float2]) { 49 | self.init( 50 | topLeft: points[0], 51 | topRight: points[1], 52 | bottomRight: points[2], 53 | bottomLeft: points[3] 54 | ) 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/SIFT/SIFTKeypoint.swift: -------------------------------------------------------------------------------- 1 | // 2 | // SIFTKeypoint.swift 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2023/01/02. 
6 | // 7 | 8 | import Foundation 9 | 10 | 11 | public struct SIFTKeypoint { 12 | 13 | // Index of the level of the difference-of-gaussians pyramid. 14 | public var octave: Int 15 | 16 | // Index of the image in the octave. 17 | public var scale: Int 18 | 19 | // 20 | public var subScale: Float 21 | 22 | // Coordinate relative to the difference-of-gaussians image size. 23 | public var scaledCoordinate: SIMD2 24 | 25 | // Coordinate relative to the original image. 26 | public var absoluteCoordinate: SIMD2 27 | 28 | // Coordinate relative to the normal space (0...1, 0...1) 29 | public var normalizedCoordinate: SIMD2 30 | 31 | // "Blur" 32 | public var sigma: Float 33 | 34 | // Pixel color (intensity) 35 | public var value: Float 36 | 37 | public init( 38 | octave: Int, 39 | scale: Int, 40 | subScale: Float, 41 | scaledCoordinate: SIMD2, 42 | absoluteCoordinate: SIMD2, 43 | normalizedCoordinate: SIMD2, 44 | sigma: Float, 45 | value: Float 46 | ) { 47 | self.octave = octave 48 | self.scale = scale 49 | self.subScale = subScale 50 | self.scaledCoordinate = scaledCoordinate 51 | self.absoluteCoordinate = absoluteCoordinate 52 | self.normalizedCoordinate = normalizedCoordinate 53 | self.sigma = sigma 54 | self.value = value 55 | } 56 | 57 | } 58 | -------------------------------------------------------------------------------- /Sources/MetalShaders/Metal/Convolution.metal: -------------------------------------------------------------------------------- 1 | // 2 | // Convolution.metal 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2023/01/07. 
6 | // 7 | 8 | #include 9 | 10 | #include "Common.hpp" 11 | 12 | using namespace metal; 13 | 14 | 15 | kernel void convolutionX( 16 | texture2d outputTexture [[texture(0)]], 17 | texture2d inputTexture [[texture(1)]], 18 | device float * weights [[buffer(0)]], 19 | device uint & numberOfWeights [[buffer(1)]], 20 | ushort2 gid [[thread_position_in_grid]] 21 | ) { 22 | const int width = inputTexture.get_width(); 23 | 24 | float sum = 0; 25 | const int n = (int)numberOfWeights; 26 | const int o = (int)gid.x - (n / 2); 27 | for (int i = 0; i < n; i++) { 28 | int x = symmetrizedCoordinates(o + i, width); 29 | sum += weights[i] * inputTexture.read(ushort2(x, gid.y)).r; 30 | } 31 | outputTexture.write(float4(sum, 0, 0, 1), gid); 32 | } 33 | 34 | 35 | kernel void convolutionY( 36 | texture2d outputTexture [[texture(0)]], 37 | texture2d inputTexture [[texture(1)]], 38 | device float * weights [[buffer(0)]], 39 | device uint & numberOfWeights [[buffer(1)]], 40 | ushort2 gid [[thread_position_in_grid]] 41 | ) { 42 | const int height = inputTexture.get_height(); 43 | 44 | float sum = 0; 45 | const int n = (int)numberOfWeights; 46 | const int o = (int)gid.y - (n / 2); 47 | for (int i = 0; i < n; i++) { 48 | int y = symmetrizedCoordinates(o + i, height); 49 | sum += weights[i] * inputTexture.read(ushort2(gid.x, y)).r; 50 | } 51 | outputTexture.write(float4(sum, 0, 0, 1), gid); 52 | } 53 | -------------------------------------------------------------------------------- /Sources/MetalShaders/Metal/ConvolutionSeries.metal: -------------------------------------------------------------------------------- 1 | // 2 | // Convolution.metal 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2023/01/07. 
6 | // 7 | 8 | #include 9 | 10 | #include "Common.hpp" 11 | #include "../include/ConvolutionSeries.h" 12 | 13 | using namespace metal; 14 | 15 | 16 | kernel void convolutionSeriesX( 17 | texture2d_array outputTexture [[texture(0)]], 18 | texture2d_array inputTexture [[texture(1)]], 19 | device ConvolutionParameters & parameters [[buffer(0)]], 20 | ushort2 gid [[thread_position_in_grid]] 21 | ) { 22 | const int width = inputTexture.get_width(); 23 | 24 | float sum = 0; 25 | const int n = (int)parameters.count; 26 | const int o = (int)gid.x - (n / 2); 27 | for (int i = 0; i < n; i++) { 28 | int x = symmetrizedCoordinates(o + i, width); 29 | float c = inputTexture.read(ushort2(x, gid.y), parameters.inputDepth).r; 30 | sum += parameters.weights[i] * c; 31 | } 32 | outputTexture.write(float4(sum, 0, 0, 1), gid, parameters.outputDepth); 33 | } 34 | 35 | 36 | kernel void convolutionSeriesY( 37 | texture2d_array outputTexture [[texture(0)]], 38 | texture2d_array inputTexture [[texture(1)]], 39 | device ConvolutionParameters & parameters [[buffer(0)]], 40 | ushort2 gid [[thread_position_in_grid]] 41 | ) { 42 | const int height = inputTexture.get_height(); 43 | 44 | float sum = 0; 45 | const int n = (int)parameters.count; 46 | const int o = (int)gid.y - (n / 2); 47 | for (int i = 0; i < n; i++) { 48 | int y = symmetrizedCoordinates(o + i, height); 49 | float c = inputTexture.read(ushort2(gid.x, y), parameters.inputDepth).r; 50 | sum += parameters.weights[i] * c; 51 | } 52 | outputTexture.write(float4(sum, 0, 0, 1), gid, parameters.outputDepth); 53 | } 54 | -------------------------------------------------------------------------------- /Sources/MetalShaders/Metal/BilinearUpScale.metal: -------------------------------------------------------------------------------- 1 | // 2 | // BilinearUpScale.metal 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2023/01/07. 
6 | // 7 | 8 | #include 9 | using namespace metal; 10 | 11 | 12 | kernel void bilinearUpScale( 13 | texture2d outputTexture [[texture(0)]], 14 | texture2d inputTexture [[texture(1)]], 15 | uint2 gid [[thread_position_in_grid]] 16 | ) { 17 | 18 | const int wo = outputTexture.get_width(); 19 | const int ho = outputTexture.get_height(); 20 | 21 | const int wi = inputTexture.get_width(); 22 | const int hi = inputTexture.get_height(); 23 | 24 | const float dx = (float)wi / (float)wo; 25 | const float dy = (float)hi / (float)ho; 26 | 27 | int i = gid.x; 28 | int j = gid.y; 29 | const float x = (float)i * dx; 30 | const float y = (float)j * dy; 31 | int im = (int)x; 32 | int jm = (int)y; 33 | int ip = im + 1; 34 | int jp = jm + 1; 35 | 36 | //image extension by symmetrization 37 | if (ip >= wi) { 38 | ip = 2 * wi - 1 - ip; 39 | } 40 | if (im >= wi) { 41 | im = 2 * wi - 1 - im; 42 | } 43 | if (jp >= hi) { 44 | jp = 2 * hi - 1 - jp; 45 | } 46 | if (jm >= hi) { 47 | jm = 2 * hi - 1 - jm; 48 | } 49 | 50 | const float fractional_x = x - floor(x); 51 | const float fractional_y = y - floor(y); 52 | 53 | const float c0 = inputTexture.read(uint2(ip, jp)).r; 54 | const float c1 = inputTexture.read(uint2(ip, jm)).r; 55 | const float c2 = inputTexture.read(uint2(im, jp)).r; 56 | const float c3 = inputTexture.read(uint2(im, jm)).r; 57 | 58 | const float output = fractional_x * (fractional_y * c0 59 | + (1 - fractional_y) * c1 ) 60 | + (1 - fractional_x) * ( fractional_y * c2 61 | + (1 - fractional_y) * c3 ); 62 | 63 | outputTexture.write(float4(output, 0, 0, 1), gid); 64 | } 65 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/Metal Compute/NearestNeighborUpScaleKernel.swift: -------------------------------------------------------------------------------- 1 | // 2 | // NearestNeighborScaleKernel.swift 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2022/12/25. 
6 | // 7 | 8 | import Foundation 9 | import MetalPerformanceShaders 10 | 11 | 12 | final class NearestNeighborUpScaleKernel { 13 | 14 | private let computePipelineState: MTLComputePipelineState 15 | // Loads the shader from the MetalShaders bundle, like the other kernels in this module. 16 | init(device: MTLDevice) { 17 | let library = try! device.makeDefaultLibrary(bundle: .metalShaders) 18 | 19 | let function = library.makeFunction(name: "nearestNeighborUpScale")! 20 | 21 | self.computePipelineState = try! device.makeComputePipelineState( 22 | function: function 23 | ) 24 | } 25 | 26 | func encode( 27 | commandBuffer: MTLCommandBuffer, 28 | inputTexture: MTLTexture, 29 | outputTexture: MTLTexture 30 | ) { 31 | precondition(outputTexture.width % inputTexture.width == 0) 32 | precondition(outputTexture.height % inputTexture.height == 0) 33 | 34 | let encoder = commandBuffer.makeComputeCommandEncoder()! 35 | encoder.setComputePipelineState(computePipelineState) 36 | encoder.setTexture(outputTexture, index: 0) 37 | encoder.setTexture(inputTexture, index: 1) 38 | 39 | // Set the compute kernel's threadgroup size to 16x16 40 | // TODO: Get threadgroup size from command buffer. 
41 | let threadgroupSize = MTLSize( 42 | width: 16, 43 | height: 16, 44 | depth: 1 45 | ) 46 | // Calculate the number of rows and columns of threadgroups given the width of the input image 47 | // Ensure that you cover the entire image (or more) so you process every pixel 48 | // Since we're only dealing with a 2D data set, set depth to 1 49 | let threadgroupCount = MTLSize( 50 | width: (outputTexture.width + threadgroupSize.width - 1) / threadgroupSize.width, 51 | height: (outputTexture.height + threadgroupSize.height - 1) / threadgroupSize.height, 52 | depth: 1 53 | ) 54 | encoder.dispatchThreadgroups( 55 | threadgroupCount, 56 | threadsPerThreadgroup: threadgroupSize 57 | ) 58 | encoder.endEncoding() 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/Metal Compute/SIFTDescriptorKernel.swift: -------------------------------------------------------------------------------- 1 | // 2 | // SIFTOrientationKernel.swift 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2022/12/25. 6 | // 7 | 8 | import Foundation 9 | import MetalPerformanceShaders 10 | 11 | import MetalShaders 12 | 13 | public final class SIFTDescriptorKernel { 14 | 15 | private let maximumKeypoints = 4096 16 | 17 | private let computePipelineState: MTLComputePipelineState 18 | 19 | public init(device: MTLDevice) { 20 | let library = try! device.makeDefaultLibrary(bundle: .metalShaders) 21 | 22 | let function = library.makeFunction(name: "siftDescriptors")! 23 | 24 | self.computePipelineState = try! 
device.makeComputePipelineState( 25 | function: function 26 | ) 27 | } 28 | 29 | public func encode( 30 | commandBuffer: MTLCommandBuffer, 31 | parameters: Buffer, 32 | gradientTextures: MTLTexture, 33 | inputKeypoints: Buffer, 34 | outputDescriptors: Buffer 35 | ) { 36 | precondition(inputKeypoints.count == outputDescriptors.count) 37 | precondition(gradientTextures.textureType == .type2DArray) 38 | precondition(gradientTextures.pixelFormat == .rg32Float) 39 | 40 | let encoder = commandBuffer.makeComputeCommandEncoder()! 41 | encoder.setComputePipelineState(computePipelineState) 42 | encoder.setBuffer(outputDescriptors.data, offset: 0, index: 0) 43 | encoder.setBuffer(inputKeypoints.data, offset: 0, index: 1) 44 | encoder.setBuffer(parameters.data, offset: 0, index: 2) 45 | encoder.setTexture(gradientTextures, index: 0) 46 | 47 | let threadsPerThreadgroup = MTLSize( 48 | width: computePipelineState.maxTotalThreadsPerThreadgroup, 49 | height: 1, 50 | depth: 1 51 | ) 52 | let threadsPerGrid = MTLSize( 53 | width: outputDescriptors.count, 54 | height: 1, 55 | depth: 1 56 | ) 57 | 58 | encoder.dispatchThreads( 59 | threadsPerGrid, 60 | threadsPerThreadgroup: threadsPerThreadgroup 61 | ) 62 | encoder.endEncoding() 63 | } 64 | } 65 | 66 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/Metal Compute/BilinearUpScaleKernel.swift: -------------------------------------------------------------------------------- 1 | // 2 | // NearestNeighborScaleKernel.swift 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2022/12/25. 6 | // 7 | 8 | import Foundation 9 | import MetalPerformanceShaders 10 | 11 | 12 | final class BilinearUpScaleKernel { 13 | 14 | private let computePipelineState: MTLComputePipelineState 15 | 16 | init(device: MTLDevice) { 17 | let library = try! device.makeDefaultLibrary(bundle: .metalShaders) 18 | 19 | let function = library.makeFunction(name: "bilinearUpScale")! 
20 | 21 | self.computePipelineState = try! device.makeComputePipelineState( 22 | function: function 23 | ) 24 | } 25 | 26 | func encode( 27 | commandBuffer: MTLCommandBuffer, 28 | inputTexture: MTLTexture, 29 | outputTexture: MTLTexture 30 | ) { 31 | precondition(outputTexture.width > inputTexture.width) 32 | precondition(outputTexture.height > inputTexture.height) 33 | precondition(outputTexture.pixelFormat == inputTexture.pixelFormat) 34 | 35 | let encoder = commandBuffer.makeComputeCommandEncoder()! 36 | encoder.setComputePipelineState(computePipelineState) 37 | encoder.setTexture(outputTexture, index: 0) 38 | encoder.setTexture(inputTexture, index: 1) 39 | 40 | // Set the compute kernel's threadgroup size of 16x16 41 | // TODO: Get threadgroup size from command buffer. 42 | let threadgroupSize = MTLSize( 43 | width: 16, 44 | height: 16, 45 | depth: 1 46 | ) 47 | // Calculate the number of rows and columns of threadgroups given the width of the input image 48 | // Ensure that you cover the entire image (or more) so you process every pixel 49 | // Since we're only dealing with a 2D data set, set depth to 1 50 | let threadgroupCount = MTLSize( 51 | width: (outputTexture.width + threadgroupSize.width - 1) / threadgroupSize.width, 52 | height: (outputTexture.height + threadgroupSize.height - 1) / threadgroupSize.height, 53 | depth: 1 54 | ) 55 | encoder.dispatchThreadgroups( 56 | threadgroupCount, 57 | threadsPerThreadgroup: threadgroupSize 58 | ) 59 | encoder.endEncoding() 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/Metal Compute/SIFTExtremaKernel.swift: -------------------------------------------------------------------------------- 1 | // 2 | // SIFTExtremaKernel.swift 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2022/12/20. 
6 | // 7 | 8 | import Foundation 9 | import MetalPerformanceShaders 10 | 11 | 12 | final class SIFTExtremaFunction { 13 | 14 | private let computePipelineState: MTLComputePipelineState 15 | // Loads the shader from the MetalShaders bundle, like the other kernels in this module. 16 | init(device: MTLDevice) { 17 | let library = try! device.makeDefaultLibrary(bundle: .metalShaders) 18 | 19 | let function = library.makeFunction(name: "siftExtrema")! 20 | function.label = "siftExtremaFunction" 21 | 22 | self.computePipelineState = try! device.makeComputePipelineState( 23 | function: function 24 | ) 25 | } 26 | 27 | func encode( 28 | commandBuffer: MTLCommandBuffer, 29 | inputTexture: MTLTexture, 30 | outputTexture: MTLTexture 31 | ) { 32 | precondition(inputTexture.width == outputTexture.width) 33 | precondition(inputTexture.height == outputTexture.height) 34 | precondition(inputTexture.arrayLength == outputTexture.arrayLength + 2) 35 | precondition(inputTexture.textureType == .type2DArray) 36 | precondition(inputTexture.pixelFormat == .r32Float) 37 | precondition(outputTexture.textureType == .type2DArray) 38 | precondition(outputTexture.pixelFormat == .rg32Float) 39 | 40 | let encoder = commandBuffer.makeComputeCommandEncoder()! 
41 | encoder.label = "siftExtremaFunctionComputeEncoder" 42 | encoder.setComputePipelineState(computePipelineState) 43 | encoder.setTexture(outputTexture, index: 0) 44 | encoder.setTexture(inputTexture, index: 1) 45 | 46 | let threadsPerDimension = Int(cbrt(Float(computePipelineState.maxTotalThreadsPerThreadgroup))) 47 | let threadsPerThreadgroup = MTLSize( 48 | width: threadsPerDimension, 49 | height: threadsPerDimension, 50 | depth: threadsPerDimension 51 | ) 52 | let threadsPerGrid = MTLSize( 53 | width: outputTexture.width - 2, 54 | height: outputTexture.height - 2, 55 | depth: outputTexture.arrayLength 56 | ) 57 | 58 | encoder.dispatchThreads( 59 | threadsPerGrid, 60 | threadsPerThreadgroup: threadsPerThreadgroup 61 | ) 62 | encoder.endEncoding() 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/Metal Compute/ConvertSRGBToGrayscaleKernel.swift: -------------------------------------------------------------------------------- 1 | // 2 | // NearestNeighborScaleKernel.swift 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2022/12/25. 6 | // 7 | 8 | import Foundation 9 | import MetalPerformanceShaders 10 | 11 | 12 | final class ConvertSRGBToGrayscaleKernel { 13 | 14 | private let computePipelineState: MTLComputePipelineState 15 | 16 | init(device: MTLDevice) { 17 | let library = try! device.makeDefaultLibrary(bundle: .metalShaders) 18 | 19 | let function = library.makeFunction(name: "convertSRGBToGrayscale")! 20 | 21 | self.computePipelineState = try! 
device.makeComputePipelineState( 22 | function: function 23 | ) 24 | } 25 | 26 | func encode( 27 | commandBuffer: MTLCommandBuffer, 28 | inputTexture: MTLTexture, 29 | outputTexture: MTLTexture 30 | ) { 31 | precondition(inputTexture.width == outputTexture.width) 32 | precondition(inputTexture.height == outputTexture.height) 33 | // precondition(inputTexture.pixelFormat == .bgra8Unorm_srgb) 34 | precondition(inputTexture.pixelFormat == .bgra8Unorm) 35 | precondition(outputTexture.pixelFormat == .r32Float) 36 | 37 | let encoder = commandBuffer.makeComputeCommandEncoder()! 38 | encoder.setComputePipelineState(computePipelineState) 39 | encoder.setTexture(outputTexture, index: 0) 40 | encoder.setTexture(inputTexture, index: 1) 41 | 42 | // Set the compute kernel's threadgroup size to 16x16 43 | // TODO: Get threadgroup size from command buffer. 44 | let threadgroupSize = MTLSize( 45 | width: 16, 46 | height: 16, 47 | depth: 1 48 | ) 49 | // Calculate the number of rows and columns of threadgroups given the width of the input image 50 | // Ensure that you cover the entire image (or more) so you process every pixel 51 | // Since we're only dealing with a 2D data set, set depth to 1 52 | let threadgroupCount = MTLSize( 53 | width: (outputTexture.width + threadgroupSize.width - 1) / threadgroupSize.width, 54 | height: (outputTexture.height + threadgroupSize.height - 1) / threadgroupSize.height, 55 | depth: 1 56 | ) 57 | encoder.dispatchThreadgroups( 58 | threadgroupCount, 59 | threadsPerThreadgroup: threadgroupSize 60 | ) 61 | encoder.endEncoding() 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/Metal Compute/SIFTExtremaListKernel.swift: -------------------------------------------------------------------------------- 1 | // 2 | // SIFTExtremaKernel.swift 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2022/12/20. 
6 | // 7 | 8 | import Foundation 9 | import MetalPerformanceShaders 10 | 11 | import MetalShaders 12 | 13 | final class SIFTExtremaListFunction { 14 | 15 | private let computePipelineState: MTLComputePipelineState 16 | 17 | let indexBuffer: Buffer 18 | 19 | init(device: MTLDevice) { 20 | let library = try! device.makeDefaultLibrary(bundle: .metalShaders) 21 | 22 | let function = library.makeFunction(name: "siftExtremaList")! 23 | function.label = "siftExtremaListFunction" 24 | 25 | self.computePipelineState = try! device.makeComputePipelineState( 26 | function: function 27 | ) 28 | self.indexBuffer = Buffer( 29 | device: device, 30 | label: "siftExtremaListIndex", 31 | capacity: 1 32 | ) 33 | self.indexBuffer.allocate(1) 34 | indexBuffer[0] = 0 35 | } 36 | 37 | func encode( 38 | commandBuffer: MTLCommandBuffer, 39 | inputTexture: MTLTexture, 40 | outputBuffer: Buffer 41 | ) { 42 | precondition(inputTexture.textureType == .type2DArray) 43 | precondition(inputTexture.pixelFormat == .r32Float) 44 | 45 | let encoder = commandBuffer.makeComputeCommandEncoder()! 
46 | encoder.label = "siftExtremaListFunctionComputeEncoder" 47 | encoder.setComputePipelineState(computePipelineState) 48 | encoder.setBuffer(outputBuffer.data, offset: 0, index: 0) 49 | encoder.setBuffer(indexBuffer.data, offset: 0, index: 1) 50 | encoder.setTexture(inputTexture, index: 0) 51 | 52 | let threadsPerDimension = Int(cbrt(Float(computePipelineState.maxTotalThreadsPerThreadgroup))) 53 | let threadsPerThreadgroup = MTLSize( 54 | width: threadsPerDimension, 55 | height: threadsPerDimension, 56 | depth: threadsPerDimension 57 | ) 58 | let threadsPerGrid = MTLSize( 59 | width: inputTexture.width - 2, 60 | height: inputTexture.height - 2, 61 | depth: inputTexture.arrayLength - 2 62 | ) 63 | 64 | encoder.dispatchThreads( 65 | threadsPerGrid, 66 | threadsPerThreadgroup: threadsPerThreadgroup 67 | ) 68 | encoder.endEncoding() 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/SIFT/SIFTPatch.swift: -------------------------------------------------------------------------------- 1 | // 2 | // SIFTHistograms.swift 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2023/01/02. 
6 | // 7 | 8 | import Foundation 9 | 10 | 11 | final class SIFTPatch { 12 | 13 | private let buffer: UnsafeMutableBufferPointer 14 | private let side: Int 15 | private let bins: Int 16 | // Allocates side * side * bins accumulators, all starting at zero. 17 | init(side: Int, bins: Int) { 18 | self.side = side 19 | self.bins = bins 20 | self.buffer = UnsafeMutableBufferPointer.allocate(capacity: side * side * bins) 21 | // allocate() returns uninitialized memory; zero it because addValue accumulates with +=. 22 | self.buffer.initialize(repeating: 0) 23 | } 24 | 25 | deinit { 26 | buffer.deallocate() 27 | } 28 | 29 | func addValue(x: Float, y: Float, bin: Float, value: Float) { 30 | let ca = SIMD2(x: Int(floor(x)), y: Int(floor(y))) 31 | let cb = SIMD2(x: Int(ceil(x)), y: Int(floor(y))) 32 | let cc = SIMD2(x: Int(ceil(x)), y: Int(ceil(y))) 33 | let cd = SIMD2(x: Int(floor(x)), y: Int(ceil(y))) 34 | 35 | let ba = Int(floor(bin)) 36 | let bb = Int(ceil(bin)) 37 | 38 | let iMax = x - floor(x) 39 | let iMin = 1 - iMax 40 | let jMax = y - floor(y) 41 | let jMin = 1 - jMax 42 | let bMax = bin - floor(bin) 43 | let bMin = 1 - bMax 44 | 45 | addValue(x: ca.x, y: ca.y, bin: ba, value: (iMin * jMin * bMin) * value) 46 | addValue(x: ca.x, y: ca.y, bin: bb, value: (iMin * jMin * bMax) * value) 47 | 48 | addValue(x: cb.x, y: cb.y, bin: ba, value: (iMax * jMin * bMin) * value) 49 | addValue(x: cb.x, y: cb.y, bin: bb, value: (iMax * jMin * bMax) * value) 50 | 51 | addValue(x: cc.x, y: cc.y, bin: ba, value: (iMax * jMax * bMin) * value) 52 | addValue(x: cc.x, y: cc.y, bin: bb, value: (iMax * jMax * bMax) * value) 53 | 54 | addValue(x: cd.x, y: cd.y, bin: ba, value: (iMin * jMax * bMin) * value) 55 | addValue(x: cd.x, y: cd.y, bin: bb, value: (iMin * jMax * bMax) * value) 56 | } 57 | 58 | func addValue(x: Int, y: Int, bin: Int, value: Float) { 59 | guard x >= 0 && x < side && y >= 0 && y < side else { 60 | return 61 | } 62 | var bin = bin 63 | while bin < 0 { 64 | bin += bins 65 | } 66 | while bin >= bins { 67 | bin -= bins 68 | } 69 | buffer[offset(x: x, y: y, bin: bin)] += value 70 | } 71 | 72 | private func offset(x: Int, y: Int, bin: Int) -> Int { 73 | (y * side * 
bins) + (x * bins) + bin 74 | } 75 | 76 | func features() -> [Float] { 77 | return Array(buffer) 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/Metal Compute/SubtractKernel.swift: -------------------------------------------------------------------------------- 1 | // 2 | // SubtractKernel.swift 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2022/12/25. 6 | // 7 | 8 | import Foundation 9 | import MetalPerformanceShaders 10 | 11 | 12 | final class SubtractKernel { 13 | 14 | private let computePipelineState: MTLComputePipelineState 15 | 16 | init(device: MTLDevice) { 17 | let library = try! device.makeDefaultLibrary(bundle: .metalShaders) 18 | 19 | let function = library.makeFunction(name: "subtract")! 20 | function.label = "subtractFunction" 21 | 22 | self.computePipelineState = try! device.makeComputePipelineState( 23 | function: function 24 | ) 25 | } 26 | 27 | func encode( 28 | commandBuffer: MTLCommandBuffer, 29 | inputTexture: MTLTexture, 30 | outputTexture: MTLTexture 31 | ) { 32 | precondition(inputTexture.width == outputTexture.width) 33 | precondition(inputTexture.width == outputTexture.width) 34 | precondition(inputTexture.arrayLength == outputTexture.arrayLength + 1) 35 | precondition(inputTexture.textureType == .type2DArray) 36 | precondition(inputTexture.pixelFormat == .r32Float) 37 | precondition(outputTexture.textureType == .type2DArray) 38 | precondition(outputTexture.pixelFormat == .r32Float) 39 | 40 | let encoder = commandBuffer.makeComputeCommandEncoder()! 41 | encoder.label = "subtractFunctionComputeEncoder" 42 | encoder.setComputePipelineState(computePipelineState) 43 | encoder.setTexture(outputTexture, index: 0) 44 | encoder.setTexture(inputTexture, index: 1) 45 | 46 | // Set the compute kernel's threadgroup size of 16x16 47 | // TODO: Ger threadgroup size from command buffer. 
48 | let threadgroupSize = MTLSize( 49 | width: 8, 50 | height: 8, 51 | depth: 8 52 | ) 53 | // Calculate the number of rows and columns of threadgroups given the width of the input image 54 | // Ensure that you cover the entire image (or more) so you process every pixel 55 | // Since we're only dealing with a 2D data set, set depth to 1 56 | let threadgroupCount = MTLSize( 57 | width: (outputTexture.width + threadgroupSize.width - 1) / threadgroupSize.width, 58 | height: (outputTexture.height + threadgroupSize.height - 1) / threadgroupSize.height, 59 | depth: (outputTexture.arrayLength + threadgroupSize.depth - 1) 60 | ) 61 | encoder.dispatchThreadgroups( 62 | threadgroupCount, 63 | threadsPerThreadgroup: threadgroupSize 64 | ) 65 | encoder.endEncoding() 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SIFTMetal 2 | 3 | Luke Van In, 2023 4 | 5 | An implementation of the Scale Invariant Feature Transform (SIFT) algorithm, for 6 | Apple devices, written in Swift using Metal compute. 7 | 8 | SIFT is described in the paper "Distinctive Image Features from Scale-Invariant 9 | Keypoints" by David Lowe published in 2004[1]. 10 | 11 | This implementation is based on the source code from the "Anatomy of the SIFT 12 | Method" by Ives Rey-Otero and Mauricio Delbracio published in the Image 13 | Processing Online (IPOL) journal in 2014[3], and source code by Rob Hess[2]. 14 | 15 | The scale-invariant feature transform (SIFT) is a computer vision algorithm to 16 | detect, describe, and match local features in images, invented by David Lowe in 17 | 1999.
Applications include object recognition, robotic mapping and navigation, 18 | image stitching, 3D modeling, gesture recognition, video tracking, individual 19 | identification of wildlife and match moving.[4] 20 | 21 | A novel Approximate K Nearest Neighbors algorithm is provided for matching SIFT 22 | descriptors, using a trie data structure. The complexity of the algorithm is: 23 | - Initial construction and update is linear O(n) complexity. 24 | - Nearest neighbor search is O(1) complexity. 25 | 26 | SIFT keypoints of objects are first extracted from a set of reference images[1] and stored in a database. An object is recognized in a new image by individually comparing each feature from the new image to this database and finding candidate matching features based on Euclidean distance of their feature vectors. From the full set of matches, subsets of keypoints that agree on the object and its location, scale, and orientation in the new image are identified to filter out good matches. The determination of consistent clusters is performed rapidly by using an efficient hash table implementation of the generalised Hough transform. Each cluster of 3 or more features that agree on an object and its pose is then subject to further detailed model verification and subsequently outliers are discarded. Finally the probability that a particular set of features indicates the presence of an object is computed, given the accuracy of fit and number of probable false matches. 
///
/// Computes the gradient orientation and magnitude for every pixel in the input image.
///
/// `encode` expects an `.r32Float` input and an `.rg32Float` output, both 2D
/// texture arrays with the same width, height and array length.
///
final class SIFTGradientKernel {

    private let computePipelineState: MTLComputePipelineState

    /// Builds the compute pipeline for the `siftGradient` shader function.
    ///
    /// Traps if the shader library, the function or the pipeline state cannot
    /// be created.
    init(device: MTLDevice) {
        let shaderLibrary = try! device.makeDefaultLibrary(bundle: .metalShaders)

        let gradientFunction = shaderLibrary.makeFunction(name: "siftGradient")!
        gradientFunction.label = "siftGradientFunction"

        computePipelineState = try! device.makeComputePipelineState(
            function: gradientFunction
        )
    }

    /// Encodes one gradient pass into `commandBuffer`.
    ///
    /// - Parameters:
    ///   - commandBuffer: Command buffer that receives the compute pass.
    ///   - inputTexture: Source texture array, bound at texture index 1.
    ///   - outputTexture: Destination texture array, bound at texture index 0.
    func encode(
        commandBuffer: MTLCommandBuffer,
        inputTexture: MTLTexture,
        outputTexture: MTLTexture
    ) {
        precondition(inputTexture.width == outputTexture.width)
        precondition(inputTexture.height == outputTexture.height)
        precondition(inputTexture.arrayLength == outputTexture.arrayLength)
        precondition(inputTexture.textureType == .type2DArray)
        precondition(inputTexture.pixelFormat == .r32Float)
        precondition(outputTexture.textureType == .type2DArray)
        precondition(outputTexture.pixelFormat == .rg32Float)

        let computeEncoder = commandBuffer.makeComputeCommandEncoder()!
        computeEncoder.label = "siftGradientComputeEncoder"
        computeEncoder.setComputePipelineState(computePipelineState)
        computeEncoder.setTexture(outputTexture, index: 0)
        computeEncoder.setTexture(inputTexture, index: 1)

        // Fixed 8 x 8 x 8 threadgroups; the depth axis runs over array slices.
        // TODO: Get threadgroup size from command buffer.
        let groupSize = MTLSize(width: 8, height: 8, depth: 8)

        // Round up so that partially filled groups at the far edges still
        // cover the whole output.
        func groupsNeeded(for extent: Int, groupExtent: Int) -> Int {
            (extent + groupExtent - 1) / groupExtent
        }

        let groupsPerGrid = MTLSize(
            width: groupsNeeded(for: outputTexture.width, groupExtent: groupSize.width),
            height: groupsNeeded(for: outputTexture.height, groupExtent: groupSize.height),
            depth: groupsNeeded(for: outputTexture.arrayLength, groupExtent: groupSize.depth)
        )

        computeEncoder.dispatchThreadgroups(
            groupsPerGrid,
            threadsPerThreadgroup: groupSize
        )
        computeEncoder.endEncoding()
    }
}
24 | 25 | // let descriptor = MTLTextureDescriptor() 26 | // descriptor.textureType = .type2DArray 27 | // descriptor.pixelFormat = .r32Float 28 | // descriptor.width = textureSize.width 29 | // descriptor.height = textureSize.height 30 | // descriptor.arrayLength = numberOfTextures 31 | // descriptor.mipmapLevelCount = 0 32 | // descriptor.storageMode = .shared 33 | // descriptor.usage = [.shaderRead, .shaderWrite] 34 | 35 | self.computePipelineState = try! device.makeComputePipelineState( 36 | function: function 37 | ) 38 | // self.differenceTextureArray = device.makeTexture(descriptor: descriptor)! 39 | } 40 | 41 | func encode( 42 | commandBuffer: MTLCommandBuffer, 43 | parameters: Buffer, 44 | differenceTextures: MTLTexture, 45 | inputKeypoints: Buffer, 46 | outputKeypoints: Buffer 47 | ) { 48 | precondition(inputKeypoints.count == outputKeypoints.count) 49 | precondition(differenceTextures.textureType == .type2DArray) 50 | precondition(differenceTextures.pixelFormat == .r32Float) 51 | 52 | let encoder = commandBuffer.makeComputeCommandEncoder()! 
53 | encoder.setComputePipelineState(computePipelineState) 54 | encoder.setBuffer(outputKeypoints.data, offset: 0, index: 0) 55 | encoder.setBuffer(inputKeypoints.data, offset: 0, index: 1) 56 | encoder.setBuffer(parameters.data, offset: 0, index: 2) 57 | encoder.setTexture(differenceTextures, index: 0) 58 | 59 | let threadsPerThreadgroup = MTLSize( 60 | width: computePipelineState.maxTotalThreadsPerThreadgroup, 61 | height: 1, 62 | depth: 1 63 | ) 64 | let threadsPerGrid = MTLSize( 65 | width: outputKeypoints.count, 66 | height: 1, 67 | depth: 1 68 | ) 69 | encoder.dispatchThreads( 70 | threadsPerGrid, 71 | threadsPerThreadgroup: threadsPerThreadgroup 72 | ) 73 | encoder.endEncoding() 74 | } 75 | } 76 | 77 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/Metal Compute/SIFTOrientationKernel.swift: -------------------------------------------------------------------------------- 1 | // 2 | // SIFTOrientationKernel.swift 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2022/12/25. 6 | // 7 | 8 | import Foundation 9 | import MetalPerformanceShaders 10 | 11 | import MetalShaders 12 | 13 | final class SIFTOrientationKernel { 14 | 15 | // struct Parameters { 16 | // let delta: Float32 17 | // let lambda: Float32 18 | // let orientationThreshold: Float32 19 | // } 20 | // 21 | // struct InputKeypoint { 22 | // let absoluteX: Int32 23 | // let absoluteY: Int32 24 | // let scale: Int32 25 | // let sigma: Float32 26 | // } 27 | // 28 | // struct OutputKeypoint { 29 | // let count: Int32 30 | // let orientations: [Float32] 31 | // } 32 | 33 | // typealias Parameters = SIFTOrientationKeypoint 34 | 35 | private let maximumKeypoints = 4096 36 | 37 | private let computePipelineState: MTLComputePipelineState 38 | 39 | init(device: MTLDevice) { 40 | let library = try! device.makeDefaultLibrary(bundle: .metalShaders) 41 | 42 | let function = library.makeFunction(name: "siftOrientation")! 
43 | function.label = "siftOrientation" 44 | 45 | self.computePipelineState = try! device.makeComputePipelineState( 46 | function: function 47 | ) 48 | } 49 | 50 | func encode( 51 | commandBuffer: MTLCommandBuffer, 52 | parameters: Buffer, 53 | gradientTextures: MTLTexture, 54 | inputKeypoints: Buffer, 55 | outputKeypoints: Buffer 56 | ) { 57 | precondition(inputKeypoints.count == outputKeypoints.count) 58 | precondition(gradientTextures.textureType == .type2DArray) 59 | precondition(gradientTextures.pixelFormat == .rg32Float) 60 | 61 | let encoder = commandBuffer.makeComputeCommandEncoder()! 62 | encoder.label = "siftOrientationComputeEncoder" 63 | encoder.setComputePipelineState(computePipelineState) 64 | encoder.setBuffer(outputKeypoints.data, offset: 0, index: 0) 65 | encoder.setBuffer(inputKeypoints.data, offset: 0, index: 1) 66 | encoder.setBuffer(parameters.data, offset: 0, index: 2) 67 | encoder.setTexture(gradientTextures, index: 0) 68 | 69 | let threadsPerThreadgroup = MTLSize( 70 | width: computePipelineState.maxTotalThreadsPerThreadgroup, 71 | height: 1, 72 | depth: 1 73 | ) 74 | let threadsPerGrid = MTLSize( 75 | width: outputKeypoints.count, 76 | height: 1, 77 | depth: 1 78 | ) 79 | 80 | encoder.dispatchThreads( 81 | threadsPerGrid, 82 | threadsPerThreadgroup: threadsPerThreadgroup 83 | ) 84 | encoder.endEncoding() 85 | } 86 | } 87 | 88 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/Metal Compute/GaussianKernel.swift: -------------------------------------------------------------------------------- 1 | // 2 | // SubtractKernel.swift 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2022/12/25. 6 | // 7 | 8 | import Foundation 9 | import MetalPerformanceShaders 10 | 11 | 12 | final class GaussianKernel { 13 | 14 | private var workingTexture: MTLTexture! 
15 | 16 | private let device: MTLDevice 17 | private let convolutionX: Convolution1DKernel 18 | private let convolutionY: Convolution1DKernel 19 | 20 | init(device: MTLDevice, sigma s: Float) { 21 | let radius = Int(ceil(4 * s)) 22 | let size = (radius * 2) + 1 23 | print("GaussianKernel sigma=\(s) radius=\(radius) size=\(size)") 24 | 25 | var weights = [Float]() 26 | var t = Float(0) 27 | let ss = s * s 28 | for k in -radius ... radius { 29 | let kk = Float(k * k) 30 | let w = exp(-0.5 * (kk / ss)) 31 | weights.append(w) 32 | t += w 33 | } 34 | 35 | precondition(weights.count == size) 36 | precondition(weights[radius] == 1.0) 37 | 38 | // Normalize weights 39 | for i in 0 ..< size { 40 | weights[i] = weights[i] / t 41 | } 42 | 43 | precondition(abs(1 - weights.reduce(0, +)) < 0.001) 44 | 45 | self.device = device 46 | 47 | self.convolutionX = Convolution1DKernel( 48 | device: device, 49 | axis: .x, 50 | weights: weights 51 | ) 52 | 53 | self.convolutionY = Convolution1DKernel( 54 | device: device, 55 | axis: .y, 56 | weights: weights 57 | ) 58 | } 59 | 60 | func encode( 61 | commandBuffer: MTLCommandBuffer, 62 | inputTexture: MTLTexture, 63 | outputTexture: MTLTexture 64 | ) { 65 | precondition(inputTexture.width == outputTexture.width) 66 | precondition(inputTexture.height == outputTexture.height) 67 | precondition(inputTexture.pixelFormat == outputTexture.pixelFormat) 68 | 69 | if workingTexture?.width != inputTexture.width || workingTexture?.height != inputTexture.height { 70 | let descriptor = MTLTextureDescriptor.texture2DDescriptor( 71 | pixelFormat: inputTexture.pixelFormat, 72 | width: inputTexture.width, 73 | height: inputTexture.height, 74 | mipmapped: false 75 | ) 76 | descriptor.storageMode = .private 77 | descriptor.usage = [.shaderRead, .shaderWrite] 78 | workingTexture = device.makeTexture(descriptor: descriptor) 79 | } 80 | 81 | convolutionX.encode( 82 | commandBuffer: commandBuffer, 83 | inputTexture: inputTexture, 84 | outputTexture: 
/// Encodes the `nearestNeighborDownScale` shader function, which writes one
/// slice of a half-size texture array from one slice of a full-size array.
final class NearestNeighborDownScaleKernel {

    private let computePipelineState: MTLComputePipelineState

    /// Builds the compute pipeline for the `nearestNeighborDownScale` function.
    ///
    /// Traps if the shader library, the function or the pipeline state cannot
    /// be created.
    init(device: MTLDevice) {
        let library = try! device.makeDefaultLibrary(bundle: .metalShaders)

        let function = library.makeFunction(name: "nearestNeighborDownScale")!

        self.computePipelineState = try! device.makeComputePipelineState(
            function: function
        )
    }

    /// Encodes one down-scale pass into `commandBuffer`.
    ///
    /// - Parameters:
    ///   - commandBuffer: Command buffer that receives the compute pass.
    ///   - inputTexture: `.r32Float` 2D texture array, bound at texture index 1.
    ///   - inputSlice: Slice of `inputTexture` to read.
    ///   - outputTexture: `.r32Float` 2D texture array with half the width and
    ///     height of the input, bound at texture index 0.
    ///   - outputSlice: Slice of `outputTexture` to write.
    func encode(
        commandBuffer: MTLCommandBuffer,
        inputTexture: MTLTexture,
        inputSlice: Int,
        outputTexture: MTLTexture,
        outputSlice: Int
    ) {
        precondition((inputTexture.width / 2) == outputTexture.width)
        precondition((inputTexture.height / 2) == outputTexture.height)
        precondition(inputTexture.textureType == .type2DArray)
        precondition(inputTexture.pixelFormat == .r32Float)
        precondition(outputTexture.textureType == .type2DArray)
        precondition(outputTexture.pixelFormat == .r32Float)

        var parameters = NearestNeighborScaleParameters(
            inputSlice: Int32(inputSlice),
            outputSlice: Int32(outputSlice)
        )

        let encoder = commandBuffer.makeComputeCommandEncoder()!
        encoder.setComputePipelineState(computePipelineState)
        encoder.setTexture(outputTexture, index: 0)
        encoder.setTexture(inputTexture, index: 1)
        // The parameters are copied into the command stream at encode time.
        // A single shared MTLBuffer rewritten on every call would make all
        // passes encoded before the GPU runs see the last slice pair only.
        encoder.setBytes(
            &parameters,
            length: MemoryLayout<NearestNeighborScaleParameters>.stride,
            index: 0
        )

        // Fixed 16 x 16 threadgroups; each pass handles exactly one slice,
        // so the depth of both the group and the grid is 1.
        // TODO: Get threadgroup size from command buffer.
        let threadgroupSize = MTLSize(
            width: 16,
            height: 16,
            depth: 1
        )
        // Round up so that partially filled groups at the right and bottom
        // edges still cover every output pixel.
        let threadgroupCount = MTLSize(
            width: (outputTexture.width + threadgroupSize.width - 1) / threadgroupSize.width,
            height: (outputTexture.height + threadgroupSize.height - 1) / threadgroupSize.height,
            depth: 1
        )
        encoder.dispatchThreadgroups(
            threadgroupCount,
            threadsPerThreadgroup: threadgroupSize
        )
        encoder.endEncoding()
    }
}
device.makeComputePipelineState( 36 | function: function 37 | ) 38 | 39 | var weights = weights 40 | var numberOfWeights: UInt32 = UInt32(weights.count) 41 | self.weightsBuffer = device.makeBuffer( 42 | bytes: &weights, 43 | length: MemoryLayout.stride * weights.count 44 | )! 45 | self.parametersBuffer = device.makeBuffer( 46 | bytes: &numberOfWeights, 47 | length: MemoryLayout.stride 48 | )! 49 | } 50 | 51 | func encode( 52 | commandBuffer: MTLCommandBuffer, 53 | inputTexture: MTLTexture, 54 | outputTexture: MTLTexture 55 | ) { 56 | precondition(inputTexture.width == outputTexture.width) 57 | precondition(inputTexture.height == outputTexture.height) 58 | 59 | let encoder = commandBuffer.makeComputeCommandEncoder()! 60 | encoder.setComputePipelineState(computePipelineState) 61 | encoder.setTexture(outputTexture, index: 0) 62 | encoder.setTexture(inputTexture, index: 1) 63 | encoder.setBuffer(weightsBuffer, offset: 0, index: 0) 64 | encoder.setBuffer(parametersBuffer, offset: 0, index: 1) 65 | 66 | // Set the compute kernel's threadgroup size of 16x16 67 | // TODO: Get threadgroup size from command buffer. 
68 | let threadgroupSize = MTLSize( 69 | width: 16, 70 | height: 16, 71 | depth: 1 72 | ) 73 | // Calculate the number of rows and columns of threadgroups given the width of the input image 74 | // Ensure that you cover the entire image (or more) so you process every pixel 75 | // Since we're only dealing with a 2D data set, set depth to 1 76 | let threadgroupCount = MTLSize( 77 | width: (outputTexture.width + threadgroupSize.width - 1) / threadgroupSize.width, 78 | height: (outputTexture.height + threadgroupSize.height - 1) / threadgroupSize.height, 79 | depth: 1 80 | ) 81 | encoder.dispatchThreadgroups( 82 | threadgroupCount, 83 | threadsPerThreadgroup: threadgroupSize 84 | ) 85 | encoder.endEncoding() 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/Utilities/Image.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Image.swift 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2022/12/19. 
/// CPU-side copy of one slice of a texture, addressable by pixel coordinate.
///
/// The copy is filled with `defaultValue` on creation and refreshed from the
/// texture only when `updateFromTexture()` is called.
final class Image<T> {

    /// Width and height of the backing texture, in pixels.
    let size: IntegralSize

    private let texture: MTLTexture
    private let slice: Int
    private let buffer: UnsafeMutableBufferPointer<T>

    /// - Parameters:
    ///   - texture: Texture to mirror; its `label` is overwritten with `label`.
    ///   - label: Debug label assigned to `texture`.
    ///   - slice: Slice of `texture` read by `updateFromTexture()`.
    ///   - defaultValue: Initial value of every pixel in the CPU copy.
    init(texture: MTLTexture, label: String, slice: Int, defaultValue: T) {
        self.size = IntegralSize(width: texture.width, height: texture.height)
        self.slice = slice
        self.texture = texture
        self.texture.label = label

        let pixels = UnsafeMutableBufferPointer<T>.allocate(
            capacity: texture.width * texture.height
        )
        pixels.initialize(repeating: defaultValue)
        self.buffer = pixels
    }

    deinit {
        buffer.deallocate()
    }

    /// Copies mip level 0 of the configured slice into the CPU buffer.
    ///
    /// The row pitch is computed from `MemoryLayout<T>.stride`, so this
    /// assumes one texel occupies exactly one `T`; that is not checked here.
    func updateFromTexture() {
        let width = texture.width
        let height = texture.height

        let rowBytes = MemoryLayout<T>.stride * width
        let imageBytes = rowBytes * height
        let wholeSlice = MTLRegion(
            origin: MTLOrigin(x: 0, y: 0, z: 0),
            size: MTLSize(width: width, height: height, depth: 1)
        )
        let destination = UnsafeMutableRawPointer(buffer.baseAddress)!

        texture.getBytes(
            destination,
            bytesPerRow: rowBytes,
            bytesPerImage: imageBytes,
            from: wholeSlice,
            mipmapLevel: 0,
            slice: slice
        )
    }

    /// Reads or writes the pixel at column `x`, row `y` of the CPU copy.
    /// Traps when the coordinate lies outside the texture.
    subscript(x: Int, y: Int) -> T {
        get {
            buffer[offset(x: x, y: y)]
        }
        set {
            buffer[offset(x: x, y: y)] = newValue
        }
    }

    /// Row-major index of a pixel; traps when it is out of bounds.
    private func offset(x: Int, y: Int) -> Int {
        let insideColumns = x >= 0 && x < texture.width
        let insideRows = y >= 0 && y < texture.height
        precondition(insideColumns && insideRows)
        return x + (y * texture.width)
    }
}
6 | // 7 | 8 | import XCTest 9 | import MetalPerformanceShaders 10 | 11 | @testable import SIFTMetal 12 | 13 | final class GaussianDifferenceTests: SharedTestCase { 14 | 15 | /* 16 | func testGaussianDifference() throws { 17 | let sourceTexture = try loadTexture(name: "butterfly", device: device) 18 | 19 | // let gaussian = MPSImageGaussianBlur(device: device, sigma: 2.5) 20 | // gaussian.edgeMode = .clamp 21 | // gaussian.offset = MPSOffset(x: 3, y: 3, z: 0) 22 | let gaussian = GaussianKernel(device: device, sigma: 2.5) 23 | // gaussian.edgeMode = .clamp 24 | // gaussian.offset = MPSOffset(x: 3, y: 3, z: 0) 25 | 26 | let subtract = MPSImageSubtract(device: device) 27 | 28 | let textureDescriptor = MTLTextureDescriptor.texture2DDescriptor( 29 | pixelFormat: .r32Float, 30 | width: sourceTexture.width, 31 | height: sourceTexture.height, 32 | mipmapped: false 33 | ) 34 | textureDescriptor.usage = [.shaderRead, .shaderWrite] 35 | textureDescriptor.storageMode = .shared 36 | 37 | let inputTexture = device.makeTexture(descriptor: textureDescriptor)! 38 | let gaussianTexture = device.makeTexture(descriptor: textureDescriptor)! 39 | // let gaussian1Texture = device.makeTexture(descriptor: textureDescriptor)! 40 | let resultTexture = device.makeTexture(descriptor: textureDescriptor)! 41 | 42 | 43 | let commandBuffer = commandQueue.makeCommandBuffer()! 
44 | 45 | convertSRGBToLinearGrayscaleFunction.encode( 46 | commandBuffer: commandBuffer, 47 | sourceTexture: sourceTexture, 48 | destinationTexture: inputTexture 49 | ) 50 | 51 | gaussian.encode( 52 | commandBuffer: commandBuffer, 53 | inputTexture: inputTexture, 54 | outputTexture: gaussianTexture 55 | ) 56 | 57 | subtract.encode( 58 | commandBuffer: commandBuffer, 59 | primaryTexture: gaussianTexture, 60 | secondaryTexture: inputTexture, 61 | destinationTexture: resultTexture 62 | ) 63 | 64 | commandBuffer.commit() 65 | commandBuffer.waitUntilCompleted() 66 | 67 | attachImage( 68 | name: "input", 69 | uiImage: makeUIImage( 70 | ciImage: smearColor( 71 | ciImage: makeCIImage( 72 | texture: inputTexture 73 | ) 74 | ), 75 | context: ciContext 76 | ) 77 | ) 78 | 79 | attachImage( 80 | name: "gaussian", 81 | uiImage: makeUIImage( 82 | ciImage: smearColor( 83 | ciImage: makeCIImage( 84 | texture: gaussianTexture 85 | ) 86 | ), 87 | context: ciContext 88 | ) 89 | ) 90 | 91 | attachImage( 92 | name: "result", 93 | uiImage: makeUIImage( 94 | ciImage: mapColor( 95 | ciImage: smearColor( 96 | ciImage: normalizeColor( 97 | ciImage: makeCIImage( 98 | texture: resultTexture 99 | ) 100 | ) 101 | ) 102 | ), 103 | context: ciContext 104 | ) 105 | ) 106 | } 107 | */ 108 | } 109 | -------------------------------------------------------------------------------- /Sources/MetalShaders/Metal/Common.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // Common.h 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2023/01/07. 
// Maps an arbitrary index `i` onto [0, l) by mirroring at both ends, so that
// -1 maps to 0, -2 to 1, l to l - 1, l + 1 to l - 2, and so on with period 2l.
// Returns 0 when `l` is not positive.
static inline int symmetrizedCoordinates(int i, int l) {
    if (l <= 0) {
        // Nothing to index into; also avoids a modulo by zero.
        return 0;
    }
    const int period = 2 * l;
    // `%` keeps the sign of the dividend, so fold negative remainders back
    // into [0, period). Adding the period once beforehand is only enough for
    // i >= -period.
    i %= period;
    if (i < 0) {
        i += period;
    }
    // The second half of the period walks back down from l - 1 to 0.
    if (i > l - 1) {
        i = period - 1 - i;
    }
    return i;
}

// Computes the inverse of a 3x3 matrix using the cross product and
// triple product.
// https://en.wikipedia.org/wiki/Invertible_matrix#Inversion_of_3_×_3_matrices
// See: https://www.onlinemathstutor.org/post/3x3_inverses
//
// The result is not finite when the determinant is zero; callers have to
// reject singular input themselves.
static inline float3x3 invert(const float3x3 input) {
    const float3 x0 = input[0];
    const float3 x1 = input[1];
    const float3 x2 = input[2];

    const float d = determinant(input); // dot(x0, cross(x1, x2));

    // With x0, x1, x2 the columns of the input, the cross products below are
    // the ROWS of det * inverse. The three-vector constructor fills columns,
    // so the matrix is transposed to put them in place. For symmetric input
    // the transpose changes nothing.
    const float3x3 scaledInverse = transpose(float3x3(
        cross(x1, x2),
        cross(x2, x0),
        cross(x0, x1)
    ));
    return (1.0f / d) * scaledInverse;
}
// Evaluates an isotropic Gaussian of standard deviation `sigma` at offset
// (x, y) from its centre.
// https://metalbyexample.com/fundamentals-of-image-processing/
//
// NOTE(review): the scale factor is 1 / sqrt(2 * pi * sigma^2), the
// normaliser of a one-dimensional Gaussian, although the exponent is
// two-dimensional. Presumably callers renormalise the sampled weights;
// confirm before relying on the absolute value.
static inline float gaussian(float x, float y, float sigma) {
    const float variance = sigma * sigma;
    const float squaredX = x * x;
    const float squaredY = y * y;
    const float normaliser = sqrt(2 * M_PI_F * variance);
    const float falloff = (squaredX + squaredY) / (2 * variance);
    return (1 / normaliser) * exp(-falloff);
}
//

import Foundation
import MetalPerformanceShaders

import MetalShaders

/// Encodes one pass of a one-dimensional convolution between two `r32Float`
/// 2D array textures, using the `convolutionSeriesX` / `convolutionSeriesY`
/// compute functions from the MetalShaders bundle.
///
/// The convolution weights and the input/output depths are uploaded once, at
/// initialisation, into a parameters buffer that is bound for every encode.
final class ConvolutionSeriesKernel {

    /// Axis along which the convolution is applied.
    enum Axis {
        case x
        case y
    }

    private let computePipelineState: MTLComputePipelineState
    private let parametersBuffer: MTLBuffer

    /// Creates the pipeline state and the parameters buffer.
    ///
    /// - Parameters:
    ///   - device: Device used to build the pipeline and the buffer.
    ///   - axis: Selects the `convolutionSeriesX` or `convolutionSeriesY` function.
    ///   - inputDepth: Stored in `ConvolutionParameters.inputDepth`
    ///     (presumably the array slice the shader reads — confirm against the shader).
    ///   - outputDepth: Stored in `ConvolutionParameters.outputDepth`
    ///     (presumably the array slice the shader writes — confirm against the shader).
    ///   - weights: Convolution weights. Must fit in the fixed-size
    ///     `ConvolutionParameters.weights` storage; a larger array traps.
    init(device: MTLDevice, axis: Axis, inputDepth: Int, outputDepth: Int, weights: [Float]) {
        let library = try! device.makeDefaultLibrary(bundle: .metalShaders)

        let function: MTLFunction

        switch axis {
        case .x:
            function = library.makeFunction(name: "convolutionSeriesX")!
            function.label = "convolutionSeriesXFunction"
        case .y:
            function = library.makeFunction(name: "convolutionSeriesY")!
            function.label = "convolutionSeriesYFunction"
        }

        self.computePipelineState = try! device.makeComputePipelineState(
            function: function
        )

        var parameters = ConvolutionParameters()

        // The shader-side weights storage has a fixed size. Copying more
        // weights than it holds would overwrite memory past `parameters`.
        // Assumes the storage is a packed array of `float`, as the original
        // code did when binding it to `Float`.
        let capacity = MemoryLayout.size(ofValue: parameters.weights) / MemoryLayout<Float>.stride
        precondition(
            weights.count <= capacity,
            "ConvolutionSeriesKernel supports at most \(capacity) weights, got \(weights.count)"
        )

        parameters.inputDepth = Int32(inputDepth)
        parameters.outputDepth = Int32(outputDepth)
        parameters.count = Int32(weights.count)
        withUnsafeMutableBytes(of: &parameters.weights) { destination in
            weights.withUnsafeBytes { source in
                destination.copyMemory(from: source)
            }
        }
        self.parametersBuffer = device.makeBuffer(
            bytes: &parameters,
            length: MemoryLayout<ConvolutionParameters>.stride
        )!
    }

    /// Encodes the convolution into a new compute encoder on `commandBuffer`.
    ///
    /// Both textures must be `r32Float` 2D array textures with identical
    /// width and height; any violation traps.
    ///
    /// - Parameters:
    ///   - commandBuffer: Command buffer that receives the compute pass.
    ///   - inputTexture: Source texture, bound at texture index 1.
    ///   - outputTexture: Destination texture, bound at texture index 0.
    func encode(
        commandBuffer: MTLCommandBuffer,
        inputTexture: MTLTexture,
        outputTexture: MTLTexture
    ) {
        precondition(inputTexture.width == outputTexture.width)
        precondition(inputTexture.height == outputTexture.height)
        // Array lengths are deliberately not compared.
//        precondition(inputTexture.arrayLength == outputTexture.arrayLength)
        precondition(inputTexture.textureType == .type2DArray)
        precondition(inputTexture.pixelFormat == .r32Float)
        precondition(outputTexture.textureType == .type2DArray)
        precondition(outputTexture.pixelFormat == .r32Float)

        let encoder = commandBuffer.makeComputeCommandEncoder()!
        encoder.label = "convolutionSeriesComputeEncoder"
        encoder.setComputePipelineState(computePipelineState)
        encoder.setTexture(outputTexture, index: 0)
        encoder.setTexture(inputTexture, index: 1)
        encoder.setBuffer(parametersBuffer, offset: 0, index: 0)

        // Set the compute kernel's threadgroup size of 16x16
        // TODO: Get threadgroup size from command buffer.
        let threadgroupSize = MTLSize(
            width: 16,
            height: 16,
            depth: 1
        )
        // Round the threadgroup count up so the grid covers the entire image
        // (or more). Only one layer of threadgroups is dispatched in depth.
        let threadgroupCount = MTLSize(
            width: (outputTexture.width + threadgroupSize.width - 1) / threadgroupSize.width,
            height: (outputTexture.height + threadgroupSize.height - 1) / threadgroupSize.height,
            depth: 1
        )
        encoder.dispatchThreadgroups(
            threadgroupCount,
            threadsPerThreadgroup: threadgroupSize
        )
        encoder.endEncoding()
    }
}

--------------------------------------------------------------------------------
/Package.swift:
--------------------------------------------------------------------------------
// swift-tools-version: 5.7
// The swift-tools-version declares the minimum version of Swift required to build this package.

import PackageDescription

let package = Package(
    name: "SIFTMetal",
    platforms: [.iOS(.v16), .macOS(.v13)],
    products: [
        // Products define the executables and libraries a package produces, and make them visible to other packages.
        .library(
            name: "SIFTMetal",
            targets: ["SIFTMetal", "MetalShaders"]),
    ],
    dependencies: [
        // Dependencies declare other packages that this package depends on.
        // .package(url: /* package url */, from: "1.0.0"),
    ],
    targets: [
        // Targets are the basic building blocks of a package. A target can define a module or a test suite.
        // Targets can depend on other targets in this package, and on products in packages this package depends on.
.target(
            name: "MetalShaders",
            resources: [
                // Metal sources are processed into the target's resource bundle.
                .process("Metal")
            ]
        ),
        .target(
            name: "SIFTMetal",
            dependencies: ["MetalShaders"],
            resources: [
                .process("Resources")
            ],
            swiftSettings: [
                // NOTE(review): SwiftPM refuses to build targets that use
                // unsafeFlags when the package is consumed as a versioned
                // dependency — confirm that forcing -O here is intentional.
                .unsafeFlags(["-O"])
            ]
        ),
        .testTarget(
            name: "SIFTMetalTests",
            dependencies: ["SIFTMetal", "MetalShaders"],
            resources: [
                // Input image and reference data for the butterfly test case.
                .process("Resources/butterfly.png"),
                .process("Resources/butterfly-descriptors.txt"),

                .process("Resources/extra_OnEdgeResp_butterfly.txt"),

                // Reference scale-space images, named by octave (oNNN) and scale (sNNN).
                .process("Resources/scalespace_butterfly_o000_s000.png"),
                .process("Resources/scalespace_butterfly_o000_s001.png"),
                .process("Resources/scalespace_butterfly_o000_s002.png"),
                .process("Resources/scalespace_butterfly_o000_s003.png"),
                .process("Resources/scalespace_butterfly_o000_s004.png"),
                .process("Resources/scalespace_butterfly_o000_s005.png"),

                .process("Resources/scalespace_butterfly_o001_s000.png"),
                .process("Resources/scalespace_butterfly_o001_s001.png"),
                .process("Resources/scalespace_butterfly_o001_s002.png"),
                .process("Resources/scalespace_butterfly_o001_s003.png"),
                .process("Resources/scalespace_butterfly_o001_s004.png"),
                .process("Resources/scalespace_butterfly_o001_s005.png"),

                .process("Resources/scalespace_butterfly_o002_s000.png"),
                .process("Resources/scalespace_butterfly_o002_s001.png"),
                .process("Resources/scalespace_butterfly_o002_s002.png"),
                .process("Resources/scalespace_butterfly_o002_s003.png"),
                .process("Resources/scalespace_butterfly_o002_s004.png"),
                .process("Resources/scalespace_butterfly_o002_s005.png"),

                .process("Resources/scalespace_butterfly_o003_s000.png"),
                .process("Resources/scalespace_butterfly_o003_s001.png"),
                .process("Resources/scalespace_butterfly_o003_s002.png"),
                .process("Resources/scalespace_butterfly_o003_s003.png"),
.process("Resources/scalespace_butterfly_o003_s004.png"), 73 | .process("Resources/scalespace_butterfly_o003_s005.png"), 74 | 75 | .process("Resources/scalespace_butterfly_o004_s000.png"), 76 | .process("Resources/scalespace_butterfly_o004_s001.png"), 77 | .process("Resources/scalespace_butterfly_o004_s002.png"), 78 | .process("Resources/scalespace_butterfly_o004_s003.png"), 79 | .process("Resources/scalespace_butterfly_o004_s004.png"), 80 | .process("Resources/scalespace_butterfly_o004_s005.png"), 81 | ] 82 | ), 83 | ] 84 | ) 85 | -------------------------------------------------------------------------------- /Sources/SIFTMetal/Metal Compute/GaussianSeriesKernel.swift: -------------------------------------------------------------------------------- 1 | // 2 | // SubtractKernel.swift 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2022/12/25. 6 | // 7 | 8 | import Foundation 9 | import MetalPerformanceShaders 10 | 11 | 12 | final class GaussianSeriesKernel { 13 | 14 | private let device: MTLDevice 15 | private let count: Int 16 | private let convolutionX: [ConvolutionSeriesKernel] 17 | private let convolutionY: [ConvolutionSeriesKernel] 18 | private let workingTexture: MTLTexture 19 | 20 | init(device: MTLDevice, sigmas: [Float], textureSize: IntegralSize, arrayLength: Int) { 21 | 22 | let count = sigmas.count 23 | 24 | var convolutionsX = [ConvolutionSeriesKernel]() 25 | var convolutionsY = [ConvolutionSeriesKernel]() 26 | 27 | for i in 0 ..< count { 28 | let s = sigmas[i] 29 | let radius = Int(ceil(4 * s)) 30 | let size = (radius * 2) + 1 31 | print("GaussianKernel sigma=\(s) radius=\(radius) size=\(size)") 32 | 33 | var weights = [Float]() 34 | var t = Float(0) 35 | let ss = s * s 36 | for k in -radius ... 
radius { 37 | let kk = Float(k * k) 38 | let w = exp(-0.5 * (kk / ss)) 39 | weights.append(w) 40 | t += w 41 | } 42 | 43 | precondition(weights.count == size) 44 | precondition(weights[radius] == 1.0) 45 | 46 | // Normalize weights 47 | for i in 0 ..< size { 48 | weights[i] = weights[i] / t 49 | } 50 | 51 | precondition(abs(1 - weights.reduce(0, +)) < 0.001) 52 | 53 | let convolutionX = ConvolutionSeriesKernel( 54 | device: device, 55 | axis: .x, 56 | inputDepth: i, 57 | outputDepth: 0, 58 | weights: weights 59 | ) 60 | convolutionsX.append(convolutionX) 61 | 62 | let convolutionY = ConvolutionSeriesKernel( 63 | device: device, 64 | axis: .y, 65 | inputDepth: 0, 66 | outputDepth: i + 1, 67 | weights: weights 68 | ) 69 | convolutionsY.append(convolutionY) 70 | 71 | } 72 | 73 | self.device = device 74 | self.count = count 75 | self.convolutionX = convolutionsX 76 | self.convolutionY = convolutionsY 77 | 78 | self.workingTexture = { 79 | let descriptor = MTLTextureDescriptor() 80 | descriptor.textureType = .type2DArray 81 | descriptor.pixelFormat = .r32Float 82 | descriptor.width = textureSize.width 83 | descriptor.height = textureSize.height 84 | descriptor.arrayLength = 1 85 | descriptor.mipmapLevelCount = 1 86 | descriptor.storageMode = .private 87 | descriptor.usage = [.shaderRead, .shaderWrite] 88 | return device.makeTexture(descriptor: descriptor)! 
89 | }() 90 | } 91 | 92 | func encode( 93 | commandBuffer: MTLCommandBuffer, 94 | texture: MTLTexture 95 | ) { 96 | // precondition(inputTexture.width == workingTexture.width) 97 | // precondition(inputTexture.height == workingTexture.height) 98 | // precondition(outputTexture.width == workingTexture.width) 99 | // precondition(outputTexture.height == workingTexture.height) 100 | // precondition(inputTexture.pixelFormat == .r32Float) 101 | // precondition(inputTexture.textureType == .type2DArray) 102 | // precondition(inputTexture.arrayLength == workingTexture.arrayLength) 103 | // precondition(outputTexture.pixelFormat == .r32Float) 104 | // precondition(outputTexture.textureType == .type2DArray) 105 | // precondition(outputTexture.arrayLength == workingTexture.arrayLength) 106 | 107 | for i in 0 ..< count { 108 | convolutionX[i].encode( 109 | commandBuffer: commandBuffer, 110 | inputTexture: texture, 111 | outputTexture: workingTexture 112 | ) 113 | convolutionY[i].encode( 114 | commandBuffer: commandBuffer, 115 | inputTexture: workingTexture, 116 | outputTexture: texture 117 | ) 118 | } 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/KeypointTests.swift: -------------------------------------------------------------------------------- 1 | // 2 | // KeypointTests.swift 3 | // SkyLightTests 4 | // 5 | // Created by Luke Van In on 2022/12/26. 
//

import XCTest

@testable import SIFTMetal


/// Runs the SIFT keypoint detector on the butterfly image and attaches a
/// rendering of the found keypoints next to the reference keypoints.
final class KeypointTests: SharedTestCase {

    /// Detects keypoints in `butterfly.png` and attaches an overlay image.
    /// The test makes no assertions on the keypoints themselves.
    func testKeypoints() throws {

        let inputTexture = try device.loadTexture(name: "butterfly", extension: "png", srgb: false)
        let configuration = SIFT.Configuration(
            inputSize: IntegralSize(
                width: inputTexture.width,
                height: inputTexture.height
            )
        )
        let subject = SIFT(device: device, configuration: configuration)
        let octaveKeypoints = subject.getKeypoints(inputTexture)
        let keypoints = Array(octaveKeypoints.joined())
        print("Found", keypoints.count, "keypoints")

//        let referenceKeypoints: [SIFTKeypoint] = []
        // Fail the test, rather than crash the runner, if the texture cannot
        // be wrapped in a CIImage.
        let originalImage = try XCTUnwrap(
            CIImage(
                mtlTexture: inputTexture,
                options: [
                    CIImageOption.colorSpace: CGColorSpace(name: CGColorSpace.sRGB)!,
                ]
            )
        )
        let referenceImage: CGImage = ciContext.makeCGImage(
            ciImage: originalImage
                .oriented(.downMirrored)
                .smearColor()
        )

//        let referenceKeypoints = try loadKeypoints(filename: "extra_DoGSoftThresh_butterfly")
//        let referenceKeypoints = try loadKeypoints(filename: "extra_ExtrInterp_butterfly")
//        let referenceKeypoints = try loadKeypoints(filename: "extra_DoGThresh_butterfly")
        let referenceKeypoints = try loadKeypoints(filename: "extra_OnEdgeResp_butterfly")
//        let referenceKeypoints = try loadKeypoints(filename: "extra_FarFromBorder_butterfly")
//        let referenceImage = UIImage(named: "butterfly-keypoints-raw")!.cgImage!
//        for i in 0 ..< keypoints.count {
//            var keypoint = keypoints[i]
//            keypoint.x = Float(inputTexture.width) - keypoint.x
//            keypoint.y = Float(inputTexture.height) - keypoint.y
//            keypoints[i] = keypoint
//        }

        let renderer = SIFTRenderer()
        attachImage(
            name: "keypoints",
            uiImage: renderer.drawKeypoints(
                sourceImage: referenceImage,
                referenceKeypoints: referenceKeypoints,
                foundKeypoints: keypoints
            )
        )

//        for (scale, octave) in subject.octaves.enumerated() {
//
//            for (index, texture) in octave.keypointTextures.enumerated() {
//
//                attachImage(
//                    name: "keypoints(\(scale), \(index))",
//                    uiImage: ciContext.makeUIImage(
//                        ciImage: CIImage(
//                            mtlTexture: texture,
//                            options: [
//                                .colorSpace: CGColorSpace(name: CGColorSpace.genericGrayGamma2_2)!
//                            ]
//                        )!
//                            .oriented(.downMirrored)
//                            .smearColor()
//                    )
//                )
//            }
//
//        }

    }


    /// Loads reference keypoints from a bundled text file.
    ///
    /// Each non-empty line holds space-separated values whose first three
    /// fields are read as `y`, `x` and `sigma`. Every other keypoint field is
    /// set to zero.
    ///
    /// - Parameters:
    ///   - filename: Resource name without its extension.
    ///   - extension: Resource extension.
    /// - Returns: One keypoint per line, in file order.
    /// - Throws: If the resource is missing, is not UTF-8, or a line has
    ///   fewer than three fields or a non-numeric field.
    private func loadKeypoints(filename: String, extension: String = "txt") throws -> [SIFTKeypoint] {
        let fileURL = try XCTUnwrap(
            bundle.url(forResource: filename, withExtension: `extension`),
            "Missing keypoint resource \(filename).\(`extension`)"
        )
        let string = try String(contentsOf: fileURL, encoding: .utf8)

        return try string.split(separator: "\n").map { line in
            let components = line.split(separator: " ")
            guard components.count >= 3 else {
                throw XCTSkip("Malformed keypoint line: \(line)").asFailure()
            }
            let y = try XCTUnwrap(Float(components[0]), "Invalid y in keypoint line: \(line)")
            let x = try XCTUnwrap(Float(components[1]), "Invalid x in keypoint line: \(line)")
            let s = try XCTUnwrap(Float(components[2]), "Invalid sigma in keypoint line: \(line)")
            return SIFTKeypoint(
                octave: 0,
                scale: 0,
                subScale: 0,
                scaledCoordinate: .zero,
                absoluteCoordinate: SIMD2<Float>(x: x, y: y),
                sigma: s,
                value: 0
            )
        }
    }

}

private extension XCTSkip {
    /// Converts the skip message into an ordinary thrown failure, so that a
    /// malformed reference file fails the test instead of skipping it.
    func asFailure() -> Error {
        NSError(
            domain: "KeypointTests",
            code: 1,
            userInfo: [NSLocalizedDescriptionKey: message ?? "Malformed keypoint file"]
        )
    }
}

--------------------------------------------------------------------------------
/Sources/MetalShaders/Metal/SIFTExtrema.metal:
--------------------------------------------------------------------------------
//
//  SIFTExtrema.metal
//  SkyLight
//
//  Created by Luke Van In on 2023/01/07.
//

#include <metal_stdlib>

#include "../include/SIFTExtrema.h"

using namespace metal;


// Number of entries in neighborOffsets.
constant int neighborCount = 26;

// Offsets (x, y, scale) of the 26 neighbours of a sample in a 3x3x3 block.
// The centre (0, 0, 0) is not part of the table.
constant int3 neighborOffsets[] = {
    int3(-1, -1, -1),
    int3( 0, -1, -1),
    int3(+1, -1, -1),
    int3(-1,  0, -1),
    int3( 0,  0, -1),
    int3(+1,  0, -1),
    int3(-1, +1, -1),
    int3( 0, +1, -1),
    int3(+1, +1, -1),

    int3(-1, -1,  0),
    int3( 0, -1,  0),
    int3(+1, -1,  0),
    int3(-1,  0,  0),

    int3(+1,  0,  0),
    int3(-1, +1,  0),
    int3( 0, +1,  0),
    int3(+1, +1,  0),

    int3(-1, -1, +1),
    int3( 0, -1, +1),
    int3(+1, -1, +1),
    int3(-1,  0, +1),
    int3( 0,  0, +1),
    int3(+1,  0, +1),
    int3(-1, +1, +1),
    int3( 0, +1, +1),
    int3(+1, +1, +1),
};


// Reads the red channel of the i-th neighbour of the sample at pixel g in
// array slice s.
// NOTE(review): the texture template arguments were lost in this copy of the
// file and are reconstructed here — confirm against the repository.
static inline float fetch(
    texture2d_array<float, access::read> texture [[texture(0)]],
    const int2 g,
    const int s,
    const int i
) {
    const int3 neighborOffset = neighborOffsets[i];
    const int2 neighborDelta = g + neighborOffset.xy;
    const int textureIndex = s + neighborOffset.z;
    const float neighborValue = texture.read((ushort2)neighborDelta, (ushort)textureIndex).r;
    return neighborValue;
}


// Appends every strict local extremum of the input texture array to `output`.
//
// A thread at grid position gid tests the sample at (gid.xy + 1, gid.z + 1)
// against all 26 neighbours. Results are first collected per threadgroup and
// then copied to `output` by the first thread of the group, which reserves a
// contiguous range through the atomic counter `outputCount`.
//
// NOTE(review): nothing here checks the capacity of `output`; presumably the
// caller sizes it for the worst case — confirm.
kernel void siftExtremaList(
    device SIFTExtremaResult * output [[buffer(0)]],
    device atomic_uint * outputCount [[buffer(1)]],
    texture2d_array<float, access::read> inputTexture [[texture(0)]],
    ushort3 gid [[thread_position_in_grid]],
    ushort3 lid [[thread_position_in_threadgroup]],
    ushort tid [[thread_index_in_threadgroup]]
) {
    // Thread group runs [0...output.width - 2][0...output.height - 2]
    const ushort threadsInThreadgroup = 1024;
    threadgroup SIFTExtremaResult localResults[threadsInThreadgroup];
    threadgroup atomic_int localCount;
    if (tid == 0) {
        atomic_store_explicit(&localCount, 0, memory_order_relaxed);
    }
    // The counter lives in threadgroup memory, so the barrier must order
    // threadgroup memory accesses.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const int2 g = (int2)gid.xy + 1;
    const int s = (int)gid.z + 1;
    const float value = inputTexture.read((ushort2)g, (ushort)s).r;

    float minimum = +1000;
    float maximum = -1000;

    // Visit every neighbour, starting at index 0.
    for (int i = 0; i < neighborCount; i++) {
        float neighborValue = fetch(inputTexture, g, s, i);
        minimum = min(minimum, neighborValue);
        maximum = max(maximum, neighborValue);
    }

    if ((value < minimum) || (value > maximum)) {
        const int slot = atomic_fetch_add_explicit(&localCount, 1, memory_order_relaxed);
        // Never write past the local list, even if the threadgroup is larger
        // than the list was sized for.
        if (slot < (int)threadsInThreadgroup) {
            SIFTExtremaResult result;
            result.x = g.x;
            result.y = g.y;
            result.scale = s;
            localResults[slot] = result;
        }
    }

    // Copy local results to output
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        const int count = min(
            atomic_load_explicit(&localCount, memory_order_relaxed),
            (int)threadsInThreadgroup
        );
        if (count > 0) {
            const uint base = atomic_fetch_add_explicit(outputCount, (uint)count, memory_order_relaxed);
            for (int i = 0; i < count; i++) {
                output[base + i] = localResults[i];
            }
        }
    }
}


// Writes 1 to the output texture where the input sample is a strict local
// extremum of its 26 neighbours, and 0 elsewhere.
//
// The tested sample is (gid.xy + 1, gid.z + 1) of the input; the result is
// written to (gid.xy + 1, gid.z) of the output.
kernel void siftExtrema(
    texture2d_array<float, access::write> outputTexture [[texture(0)]],
    texture2d_array<float, access::read> inputTexture [[texture(1)]],
    ushort3 gid [[thread_position_in_grid]],
    ushort3 threadPositionInThreadGroup [[thread_position_in_threadgroup]],
    ushort3 threadsPerThreadGroup [[threads_per_threadgroup]]
) {
    // Thread group runs [0...output.width - 2][0...output.height - 2]

    const int2 g = int2(gid.xy) + 1;
    const int s = int(gid.z) + 1;
    const float value = inputTexture.read(ushort2(g), ushort(s)).r;

    float minValue = +1000;
    float maxValue = -1000;

    // Neighbours are centred on the same slice as `value` (gid.z + 1).
    for (int i = 0; i < neighborCount; i++) {
        float neighborValue = fetch(inputTexture, g, s, i);

        minValue = min(minValue, neighborValue);
        maxValue = max(maxValue, neighborValue);
    }

    float result = 0;

    if ((value < minValue) || (value > maxValue)) {
        result = 1;
    }

    outputTexture.write(float4(result, 0, 0, 1), gid.xy + 1, gid.z);
}


--------------------------------------------------------------------------------
/Sources/MetalShaders/Metal/SIFTOrientation.metal:
--------------------------------------------------------------------------------
//
//  SIFTOrientation.metal
//  SkyLight
//
//  Created by Luke Van In on 2023/01/07.
//

#include <metal_stdlib>

#include "Common.hpp"
#include "../include/SIFTOrientation.h"

using namespace metal;


// Converts a (possibly fractional) histogram bin index to an angle in
// radians. A single wrap by 2*pi is applied in either direction.
float orientationFromBin(float bin) {
    const int n = SIFT_ORIENTATION_HISTOGRAM_BINS;
    float t = bin / (float)n;
    float tau = 2 * M_PI_F;
    float orientation = t * tau;
    if (orientation < 0) {
        orientation += tau;
    }
    if (orientation >= tau) {
        orientation -= tau;
    }
    return orientation;
}


// Returns the offset of the vertex of the parabola through three consecutive
// histogram values, relative to the middle value h2.
// The denominator is zero when h1 + h3 == 2 * h2; the only caller in this
// file passes a strict local maximum as h2, which rules that out.
float interpolatePeak(float h1, float h2, float h3) {
    return (h1 - h3) / (2 * (h1 + h3 - 2 * h2));
}


// Extracts the principal orientations from a smoothed histogram.
//
// Every bin that is a strict local maximum (with wrap-around neighbours) and
// exceeds orientationThreshold times the global maximum yields one
// orientation, refined by parabolic interpolation.
//
//  histogram             SIFT_ORIENTATION_HISTOGRAM_BINS values.
//  orientationThreshold  Fraction of the global maximum a peak must exceed.
//  orientationsCount     Receives the number of orientations written.
//  orientations          Receives the orientations, in radians.
//
// NOTE(review): writes to `orientations` are not bounded by its capacity —
// confirm the destination array can hold every possible peak.
void getPrincipalOrientations(
    thread float * histogram,
    float orientationThreshold,
    thread int & orientationsCount,
    thread float * orientations
) {
    const int bins = SIFT_ORIENTATION_HISTOGRAM_BINS;

    float maximum = -INFINITY;
    for (int i = 0; i < bins; i++) {
        maximum = max(maximum, histogram[i]);
    }

    const float threshold = orientationThreshold * maximum;

    orientationsCount = 0;

    for (int i = 0; i < bins; i++) {
        float hm = histogram[((i - 1) + bins) % bins];
        float h0 = histogram[i];
        float hp = histogram[(i + 1) % bins];
        if ((h0 > threshold) && (h0 > hm) && (h0 > hp)) {
            float offset = interpolatePeak(hm, h0, hp);
            float orientation = orientationFromBin((float)i + offset);
            orientations[orientationsCount] = orientation;
            orientationsCount += 1;
        }
    }
}


// Smooths the histogram in place with a circular three-tap box filter,
// applied `iterations` times.
void smoothHistogram(
    thread float * histogram,
    int iterations
) {
    const int n = SIFT_ORIENTATION_HISTOGRAM_BINS;
    float temp[n];
    for (int j = 0; j < iterations; j++) {
        // Filter from a snapshot so each pass reads only the previous pass.
        for (int i = 0; i < n; i++) {
            temp[i] = histogram[i];
        }
        for (int i = 0; i < n; i++) {
            float h0 = temp[((i - 1) + n) % n];
            float h1 = temp[i];
            float h2 = temp[(i + 1) % n];
            float v = (h0 + h1 + h2) / 3.0;
            histogram[i] = v;
        }
    }
}


// Accumulates a Gaussian-weighted orientation histogram around a keypoint.
//
// The keypoint position and sigma are divided by `delta` before use. Samples
// are read from slice `scale` of `g` in a square window of radius
// ceil(3 * lambda * sigma). The red channel is used as the orientation and
// the green channel as the magnitude. Values are added to `histogram`, which
// the caller is expected to have cleared.
//
// NOTE(review): the texture template arguments were lost in this copy of the
// file and are reconstructed here — confirm against the repository.
void getOrientationsHistogram(
    texture2d_array<float, access::read> g,
    int absoluteX,
    int absoluteY,
    int scale,
    float keypointSigma,
    float delta,
    float lambda,
    thread float * histogram
) {
    const int bins = SIFT_ORIENTATION_HISTOGRAM_BINS;
    int x = round((float)absoluteX / delta);
    int y = round((float)absoluteY / delta);
    float sigma = keypointSigma / delta;

    float exponentDenominator = 2.0 * lambda * lambda;

    int r = ceil(3 * lambda * sigma);

    for (int j = -r; j <= r; j++) {
        for (int i = -r; i <= r; i++) {

            // Gaussian weighting
            float u = (float)i / sigma;
            float v = (float)j / sigma;
            float r2 = u * u + v * v;
            float w = exp(-r2 / exponentDenominator);

            // Gradient orientation
            float2 gradient = g.read(ushort2(x + i, y + j), scale).rg;
            float orientation = gradient.x;
            float magnitude = gradient.y;

            // Add to histogram
            float t = orientation / (2 * M_PI_F);
            int bin = round(t * (float)bins);
            if (bin < 0) {
                bin += bins;
            }
            if (bin >= bins) {
                bin -= bins;
            }

            float m = w * magnitude;

            histogram[bin] += m;
        }
    }
}



// Computes the principal orientations of one keypoint per thread.
//
// Thread gid reads keypoints[gid], builds and smooths (six passes) its
// orientation histogram, extracts the peaks and stores them in results[gid].
kernel void siftOrientation(
    device SIFTOrientationResult * results [[buffer(0)]],
    device SIFTOrientationKeypoint * keypoints [[buffer(1)]],
    device SIFTOrientationParameters & parameters [[buffer(2)]],
    texture2d_array<float, access::read> gradientTextures [[texture(0)]],
    ushort gid [[thread_position_in_grid]]
) {
    const int bins = SIFT_ORIENTATION_HISTOGRAM_BINS;
    const SIFTOrientationKeypoint keypoint = keypoints[gid];
    SIFTOrientationResult result;
    result.keypoint = keypoint.index;

    float histogram[bins];
    for (int i = 0; i < bins; i++) {
        histogram[i] = 0;
    }

    getOrientationsHistogram(
        gradientTextures,
        keypoint.absoluteX,
        keypoint.absoluteY,
        keypoint.scale,
        keypoint.sigma,
        parameters.delta,
        parameters.lambda,
        histogram
    );
    smoothHistogram(histogram, 6);
    getPrincipalOrientations(
        histogram,
        parameters.orientationThreshold,
        result.count,
        result.orientations
    );
    results[gid] = result;
}

--------------------------------------------------------------------------------
/.swiftpm/xcode/xcshareddata/xcschemes/SIFTMetal.xcscheme:
--------------------------------------------------------------------------------
 1 | 2 | 5 | 8 | 9 | 15 | 21 | 22 | 23 | 29 | 35 | 36 | 37 | 43 | 49 | 50 | 51 | 57 | 63 | 64 | 65 | 71 | 77 | 78 | 79 | 80 | 81 | 86 | 87 | 89 | 95 | 96 | 97 | 98 | 99 | 109 | 110 | 116 | 117 | 123 | 124 | 125 | 126 | 128 | 129 | 132 | 133 | 134 |
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/SharedTestCase.swift:
--------------------------------------------------------------------------------
//
//  SharedTestCase.swift
//  SkyLightTests
//
//  Created by Luke Van In on 2022/12/26.
//

import XCTest
import CoreImage
import CoreImage.CIFilterBuiltins
import MetalKit
import MetalPerformanceShaders

@testable import SIFTMetal


let bundle = Bundle.module


extension MTLDevice {

    func loadTexture(url: URL, srgb: Bool = true) throws -> MTLTexture {
        print("Loading texture \(url)")
        let loader = MTKTextureLoader(device: self)
        return try loader.newTexture(
            URL: url,
            options: [
                .SRGB: NSNumber(value: srgb),
            ]
        )
    }

    func loadTexture(name: String, extension: String, srgb: Bool = true) throws -> MTLTexture {
        let imageURL = bundle.url(forResource: name, withExtension: `extension`)!
return try loadTexture(url: imageURL, srgb: srgb)
    }
}



//func makeUIImage(texture: MTLTexture, context: CIContext) -> UIImage {
//    makeUIImage(ciImage: makeCIImage(texture: texture), context: context)
//}


extension CIContext {

    /// Renders the full extent of `ciImage` and wraps the result in a `UIImage`.
    func makeUIImage(ciImage: CIImage) -> UIImage {
        return UIImage(cgImage: makeCGImage(ciImage: ciImage))
    }

    /// Renders the full extent of `ciImage`. Traps if rendering fails.
    func makeCGImage(ciImage: CIImage) -> CGImage {
        return createCGImage(ciImage, from: ciImage.extent)!
    }
}


extension CIImage {

    /// Loads an image from the test bundle. Traps if the resource is missing
    /// or cannot be decoded.
    convenience init(name: String, extension: String) throws {
        let imageURL = bundle.url(forResource: name, withExtension: `extension`)!
        self.init(contentsOf: imageURL)!
    }

//    convenience init(texture: MTLTexture, orientation: CGImagePropertyOrientation = .downMirrored) {
//        var ciImage = CIImage(mtlTexture: texture)!
//        ciImage = ciImage.transformed(
//            by: ciImage.orientationTransform(
//                for: orientation
//            )
//        )
//        self = ciImage
//    }

    /// Returns the image transformed by the orientation transform for `orientation`.
    func orientation(_ orientation: CGImagePropertyOrientation) -> CIImage {
        return transformed(
            by: orientationTransform(
                for: orientation
            )
        )
    }

    /// Applies the built-in sRGB-tone-curve-to-linear filter.
    func sRGBToneCurveToLinear() -> CIImage {
        let filter = CIFilter.sRGBToneCurveToLinear()
        filter.inputImage = self
        return filter.outputImage!
    }


    /// Applies the built-in linear-to-sRGB-tone-curve filter.
    func linearToSRGBToneCurve() -> CIImage {
        let filter = CIFilter.linearToSRGBToneCurve()
        filter.inputImage = self
        return filter.outputImage!
    }

    /// Copies the red channel into the green and blue channels, cropped to
    /// the original extent.
    func smearColor() -> CIImage {
        let filter = CIFilter.colorMatrix()
        filter.rVector = CIVector(x: 1, y: 0, z: 0)
        filter.gVector = CIVector(x: 1, y: 0, z: 0)
        filter.bVector = CIVector(x: 1, y: 0, z: 0)
        filter.biasVector = CIVector(x: 0, y: 0, z: 0)
        filter.inputImage = self
        return filter.outputImage!.cropped(
            to: self.extent
        )
    }

    /// Maps the red channel `r` to `0.5 * r + 0.5` on all three colour
    /// channels, cropped to the original extent.
    func normalizeColor() -> CIImage {
        let filter = CIFilter.colorMatrix()
        filter.rVector = CIVector(x: 0.5, y: 0, z: 0)
        filter.gVector = CIVector(x: 0.5, y: 0, z: 0)
        filter.bVector = CIVector(x: 0.5, y: 0, z: 0)
        filter.biasVector = CIVector(x: 0.5, y: 0.5, z: 0.5)
        filter.inputImage = self
        return filter.outputImage!.cropped(
            to: self.extent
        )
    }

    /// Returns the colour-inverted image.
    func invertColor() -> CIImage {
        self.colorInverted()
    }


    /// Applies a colour map using the bundled `viridis.png` gradient. Traps
    /// if that resource is missing.
    func mapColor() -> CIImage {
        let imageFileURL = bundle.url(forResource: "viridis", withExtension: "png")!
        let filter = CIFilter.colorMap()
        filter.gradientImage = CIImage(contentsOf: imageFileURL)
        filter.inputImage = self
        return filter.outputImage!
    }
}


/// Base class for the Metal-backed tests. Creates a device, command queue,
/// Core Image context and MPS helpers before each test and releases them
/// afterwards.
class SharedTestCase: XCTestCase {

    var device: MTLDevice!
    var commandQueue: MTLCommandQueue!
    var ciContext: CIContext!
//    var upscaleFunction: NearestNeighborUpScaleKernel!
    var subtractFunction: MPSImageSubtract!
    var convertSRGBToLinearGrayscaleFunction: MPSImageConversion!
//    var convertLinearRGBToLinearGrayscaleFunction: MPSImageConversion!

    override func setUp() {
        super.setUp()
        device = MTLCreateSystemDefaultDevice()!
        commandQueue = device.makeCommandQueue()!
//        upscaleFunction = NearestNeighborUpScaleKernel(device: device)
        subtractFunction = MPSImageSubtract(device: device)
        convertSRGBToLinearGrayscaleFunction = MPSImageConversion(
            device: device,
            srcAlpha: .alphaIsOne,
            destAlpha: .alphaIsOne,
            backgroundColor: nil,
            conversionInfo: CGColorConversionInfo(
                src: CGColorSpace(name: CGColorSpace.sRGB)!,
                dst: CGColorSpace(name: CGColorSpace.linearGray)!
            )
        )
        ciContext = CIContext(
            mtlDevice: device,
            options: [
                .outputColorSpace: CGColorSpace(name: CGColorSpace.sRGB)!,
            ]
        )
    }

    override func tearDown() {
        // Release everything created in setUp, device last.
        ciContext = nil
//        upscaleFunction = nil
        convertSRGBToLinearGrayscaleFunction = nil
        subtractFunction = nil
        commandQueue = nil
        device = nil
        super.tearDown()
    }

    /// Attaches `uiImage` to the test report under `name`, kept regardless of
    /// the test outcome.
    func attachImage(name: String, uiImage: UIImage) {
        let attachment = XCTAttachment(
            image: uiImage,
            quality: .original
        )
        attachment.name = name
        attachment.lifetime = .keepAlways
        add(attachment)
    }
}

--------------------------------------------------------------------------------
/Sources/MetalShaders/Metal/SIFTDescriptor.metal:
--------------------------------------------------------------------------------
//
//  SIFTDescriptor.metal
//  SkyLight
//
//  Created by Luke Van In on 2023/01/08.
//

#include <metal_stdlib>

#include "../include/SIFTDescriptor.h"

using namespace metal;


// Scales `features` in place to unit Euclidean length.
// An all-zero vector is left untouched instead of being filled with
// non-finite values by a division by zero.
void normalizeFeatures(
    int count,
    thread float * features
) {
    float magnitude = 0;
    for (int i = 0; i < count; i++) {
        float f = features[i];
        magnitude += (f * f);
    }
    if (magnitude <= 0) {
        return;
    }
    const float d = 1.0 / sqrt(magnitude);
    for (int i = 0; i < count; i++) {
        features[i] *= d;
    }
}


// Clamps every feature, in place, to at most `threshold`.
void thresholdFeatures(
    int count,
    thread float * features,
    float threshold
) {
    for (int i = 0; i < count; i++) {
        features[i] = min(features[i], threshold);
    }
}


// Converts features to integers: each value is multiplied by 512, capped at
// 255 and truncated by the assignment to int.
void quantizeFeatures(
    int count,
    thread float * features,
    thread int * output
) {
    for (int i = 0; i < count; i++) {
        output[i] = min(255.0, features[i] * 512.0);
    }
}


// Flat index of orientation bin b of the histogram at cell (x, y) in the
// 4x4 grid of histograms.
int offset(int x, int y, int b) {
    const int side = 4;
    const int bins = SIFT_DESCRIPTOR_ORIENTATION_BINS;
    return (y * side * bins) + (x * bins) + b;
}


// Adds `value` to bin b of the histogram at cell (x, y).
// Cells outside the 4x4 grid are ignored; the bin index is wrapped once in
// either direction.
void addValue(
    thread float * patch,
    int x,
    int y,
    int b,
    float value
) {
    const int side = 4;
    const int bins = SIFT_DESCRIPTOR_ORIENTATION_BINS;
    if ((x < 0) || (x >= side) || (y < 0) || (y >= side)) {
        return;
    }
    if (b < 0) {
        b += bins;
    }
    if (b >= bins) {
        b -= bins;
    }
    patch[offset(x, y, b)] += value;
}


// Distributes `value` over the histogram cells and orientation bins that
// surround the fractional position (x, y, b), weighting each by its
// trilinear interpolation weight.
void addFeature(
    thread float * patch,
    float x,
    float y,
    float b,
    float value
) {
    // Integer coordinates of the four pixels surrounding the point x, y
    const int2 ca = int2(floor(x), floor(y));
    const int2 cb = int2(ceil(x), floor(y));
    const int2 cc = int2(ceil(x), ceil(y));
    const int2 cd = int2(floor(x), ceil(y));

    // Bins surrounding the bin at index b
    const int ba = floor(b);
    const int bb = ceil(b);

    // Fractional parts give the weight of the upper neighbour on each axis.
    // For an integral coordinate the upper weight is zero, so the coinciding
    // floor/ceil cells are not counted twice.
    const float iMax = x - floor(x);
    const float iMin = 1 - iMax;
    const float jMax = y - floor(y);
    const float jMin = 1 - jMax;
    const float bMax = b - floor(b);
    const float bMin = 1 - bMax;

    addValue(patch, ca.x, ca.y, ba, (iMin * jMin * bMin) * value);
    addValue(patch, ca.x, ca.y, bb, (iMin * jMin * bMax) * value);

    addValue(patch, cb.x, cb.y, ba, (iMax * jMin * bMin) * value);
    addValue(patch, cb.x, cb.y, bb, (iMax * jMin * bMax) * value);

    addValue(patch, cc.x, cc.y, ba, (iMax * jMax * bMin) * value);
    addValue(patch, cc.x, cc.y, bb, (iMax * jMax * bMax) * value);

    addValue(patch, cd.x, cd.y, ba, (iMin * jMax * bMin) * value);
    addValue(patch, cd.x, cd.y, bb, (iMin * jMax * bMax) * value);
}


kernel void siftDescriptors(
    device SIFTDescriptorResult * results [[buffer(0)]],
    device SIFTDescriptorInput * inputs [[buffer(1)]],
    device SIFTDescriptorParameters & parameters [[buffer(2)]],
    texture2d_array gradientTextures [[texture(0)]],
    ushort gid [[thread_position_in_grid]]
) {

//    let octave = dog.octaves[keypoint.octave]
//    let images = octave.gaussianImages
//    let histogramsPerAxis = configuration.descriptorHistogramsPerAxis
    SIFTDescriptorResult result;
    const SIFTDescriptorInput input = inputs[gid];


//    let image = octaves[keypoint.octave].gradientImages[keypoint.scale]

//    let delta = octave.delta
//    let lambda = configuration.lambdaDescriptor
//    let a = keypoint.absoluteCoordinate
    float px = float(input.absoluteX) / parameters.delta;
    float py = float(input.absoluteY) / parameters.delta;

    // Check that the keypoint is sufficiently far from the edge to include
    // entire area of the descriptor.
145 | 146 | #warning("TODO: Do this check after interpolation to avoid wasting work on extracting the orientation") 147 | // let diagonal = Float(2).squareRoot() * lambda * sigma 148 | // let f = Float(histogramsPerAxis + 1) / Float(histogramsPerAxis) 149 | // let side = Int((diagonal * f).rounded()) 150 | 151 | //let radius = lambda * f 152 | const int d = 4; // width of 2d array of histograms 153 | const int bins = SIFT_DESCRIPTOR_ORIENTATION_BINS; 154 | 155 | const float tau = 2 * M_PI_F; 156 | const float cosT = cos(input.theta); 157 | const float sinT = sin(input.theta); 158 | const float binsPerRadian = (float)bins / tau; 159 | const float exponentDenominator = (float)(d * d) * 0.5; 160 | const float interval = (float)input.scale + input.subScale; 161 | const float intervals = (float)parameters.scalesPerOctave; 162 | const float sigma = 1.6; 163 | const float scale = sigma * pow(2.0, interval / intervals); // identical to below 164 | // let _sigma = keypoint.sigma / octave.delta // identical to above 165 | const float histogramWidth = 3.0 * scale; // 3.0 constant from Whess (OpenSIFT) 166 | const int radius = histogramWidth * sqrt(2.0) * ((float)d + 1.0) * 0.5 + 0.5; 167 | 168 | // const int minX = radius; 169 | // const int minY = radius; 170 | // const int maxX = parameters.width - 1 - radius; 171 | // const int maxY = parameters.height - 1 - radius; 172 | // 173 | // if (px < minX) { 174 | // return; 175 | // } 176 | // if (py < minY) { 177 | // return; 178 | // } 179 | // if (px > maxX) { 180 | // return; 181 | // } 182 | // if (py > maxY) { 183 | // return; 184 | // } 185 | 186 | // Create histograms 187 | const int featureCount = d * d * bins; 188 | float features[featureCount]; 189 | 190 | for (int i = 0; i < featureCount; i++) { 191 | features[i] = 0; 192 | } 193 | 194 | for (int j = -radius; j <= +radius; j++) { 195 | for (int i = -radius; i <= +radius; i++) { 196 | 197 | float rx = ((float)j * cosT - (float)i * sinT) / histogramWidth; 198 | float ry 
= ((float)j * sinT + (float)i * cosT) / histogramWidth; 199 | float bx = rx + (float)(d / 2) - 0.5; 200 | float by = ry + (float)(d / 2) - 0.5; 201 | 202 | float2 g = gradientTextures.read(ushort2(px + j, py + i), input.scale).rg; 203 | float orientation = g.r - input.theta; 204 | float magnitude = g.g; 205 | while (orientation < 0) { 206 | orientation += tau; 207 | } 208 | while (orientation >= tau) { 209 | orientation -= tau; 210 | } 211 | 212 | // Bin 213 | float bin = orientation * binsPerRadian; 214 | 215 | // Total contribution 216 | float exponentNumerator = rx * rx + ry * ry; 217 | float w = exp(-exponentNumerator / exponentDenominator); 218 | float value = magnitude * w; 219 | 220 | addFeature(features, bx, by, bin, value); 221 | } 222 | } 223 | 224 | // print("feature x=\(Int(a.x)) y=\(Int(a.y)) scale=\(scale) sigma=\(_sigma) histogramWidth=\(histogramWidth) radius=\(radius)") 225 | 226 | // Serialize histograms into array 227 | normalizeFeatures(featureCount, features); 228 | thresholdFeatures(featureCount, features, 0.2); 229 | normalizeFeatures(featureCount, features); 230 | quantizeFeatures(featureCount, features, result.features); 231 | 232 | result.valid = true; 233 | result.keypoint = input.keypoint; 234 | result.theta = input.theta; 235 | 236 | results[gid] = result; 237 | } 238 | -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/TrieTests.swift: -------------------------------------------------------------------------------- 1 | // 2 | // TrieTests.swift 3 | // SkyLightTests 4 | // 5 | // Created by Luke Van In on 2023/01/20. 
6 | // 7 | 8 | import XCTest 9 | 10 | @testable import SIFTMetal 11 | 12 | /* 13 | final class TrieTests: XCTestCase { 14 | 15 | func testInsert_shouldMatchStructure() { 16 | let expected = Trie( 17 | Trie( 18 | nil, 19 | Trie( 20 | nil, 21 | nil, 22 | Trie( 23 | nil, 24 | nil, 25 | nil 26 | ) 27 | ), 28 | nil 29 | ), 30 | nil, 31 | nil 32 | ) 33 | let trie = Trie(numberOfBins: 3) 34 | trie.insert(key: FloatVector([0.0, 0.5, 1.0]), value: FloatVector([0.0, 0.5, 1.0])) 35 | XCTAssertEqual(trie, expected) 36 | } 37 | 38 | func testContains_shouldReturnFalse_whenTrieDoesNotContainExactMatch() { 39 | let trie = Trie(numberOfBins: 3) 40 | trie.insert(key: FloatVector([1.0, 1.0, 1.0]), value: FloatVector([1.0, 1.0, 1.0])) 41 | let result = trie.contains(FloatVector([0.0, 0.0, 0.0])) 42 | XCTAssertFalse(result) 43 | } 44 | 45 | func testContains_shouldReturnTrue_whenTrieContainsExactMatch() { 46 | let trie = Trie(numberOfBins: 3) 47 | trie.insert(key: FloatVector([0.1, 0.2, 0.3]), value: FloatVector([0.1, 0.2, 0.3])) 48 | let result = trie.contains(FloatVector([0.1, 0.2, 0.3])) 49 | XCTAssertTrue(result) 50 | } 51 | 52 | func testContains_shouldReturnTrue_whenTrieContainsPartialMatch() { 53 | let trie = Trie(numberOfBins: 3) 54 | trie.insert(key: FloatVector([0.1, 0.2, 0.3]), value: FloatVector([0.1, 0.2, 0.3])) 55 | let result = trie.contains(FloatVector([0.1, 0.2])) 56 | XCTAssertTrue(result) 57 | } 58 | 59 | func testContains_shouldReturnTrue_whenTrieContainsSimilarValues() { 60 | let trie = Trie(numberOfBins: 3) 61 | trie.insert(key: FloatVector([0, 0.5, 1.0]), value: FloatVector([0, 0.5, 1.0])) // bins: 0, 1, 2 62 | XCTAssertTrue(trie.contains(FloatVector([0.1, 0.5, 1.0]))) 63 | XCTAssertTrue(trie.contains(FloatVector([0.1, 0.6, 1.0]))) 64 | XCTAssertTrue(trie.contains(FloatVector([0.1, 0.6, 0.9]))) 65 | } 66 | 67 | func testNearest_shouldReturnNearestValue_whenTrieContainsSimilarValues() { 68 | let trie = Trie(numberOfBins: 3) 69 | trie.insert(key: FloatVector([0, 
0.5, 1.0]), value: FloatVector([0, 0.5, 1.0])) 70 | let results = trie.nearest(key: FloatVector([0.1, 0.6, 0.9]), query: FloatVector([0.1, 0.6, 0.9]), radius: 0, k: 1) 71 | let expected = FloatVector([0, 0.5, 1.0]) 72 | XCTAssertEqual(results[0].value, expected) 73 | } 74 | 75 | func testNearestAccuracy() { 76 | 77 | struct Neighbor: CustomStringConvertible { 78 | let id: Int 79 | let distance: Float 80 | 81 | var description: String { 82 | "" 83 | } 84 | } 85 | 86 | // capacity = 3^d * n 87 | let n = 1000 88 | let m = 100 89 | let d = 128 90 | var values: [FloatVector] = [] 91 | var queries: [FloatVector] = [] 92 | var nearestNeighbors: [Neighbor] = [] 93 | 94 | // Generate sample data 95 | print("generate sample data x", n) 96 | for _ in 0 ..< n { 97 | var vector: [Float] = Array(repeating: 0, count: d) 98 | for j in 0 ..< d { 99 | vector[j] = .random(in: 0...1) 100 | } 101 | let value = FloatVector(vector) 102 | values.append(value) 103 | } 104 | 105 | // Generate queries. 106 | print("generate queries x", m) 107 | for _ in 0 ..< m { 108 | var vector: [Float] = Array(repeating: 0, count: d) 109 | for j in 0 ..< d { 110 | vector[j] = .random(in: 0...1) 111 | } 112 | queries.append(FloatVector(vector)) 113 | } 114 | 115 | // Compute actual nearest neighbor for each node using brute force. 116 | print("compute nearest neighbor ground truth") 117 | for i in 0 ..< m { 118 | let query = queries[i] 119 | var nearestDistance: Float = .greatestFiniteMagnitude 120 | var nearestNeighbor: Neighbor! 121 | for j in 0 ..< n { 122 | let value = values[j] 123 | let distance = value.distance(to: query) 124 | guard distance < nearestDistance else { 125 | continue 126 | } 127 | nearestDistance = distance 128 | nearestNeighbor = Neighbor(id: j, distance: distance) 129 | } 130 | nearestNeighbors.append(nearestNeighbor) 131 | } 132 | 133 | // Create the trie. 
134 | print("create trie") 135 | let subject = Trie(numberOfBins: 4) 136 | for i in 0 ..< n { 137 | let value = values[i] 138 | subject.insert(key: value, value: value) 139 | } 140 | print("capacity", subject.capacity()) 141 | 142 | print("linking trie") 143 | subject.link() 144 | 145 | // Sanity check 146 | print("check contains") 147 | for i in 0 ..< n { 148 | let query = values[i] 149 | let result = subject.contains(query) 150 | XCTAssertTrue(result) 151 | } 152 | 153 | // Sanity check: nearest() with existing key/value should always return exact match 154 | print("check exact nearest") 155 | for i in 0 ..< n { 156 | let query = values[i] 157 | let matches = subject.nearest(key: query, query: query, radius: 0, k: 1) 158 | XCTAssertEqual(matches[0].value, query) 159 | } 160 | 161 | // Compute approximate nearest neighbor. 162 | var totalError: Float = 0 163 | var totalDistance: Float = 0 164 | var totalCorrect: Int = 0 165 | var totalQueries: Int = 0 166 | var totalNodes: Int = 0 167 | var totalFound: Int = 0 168 | // measure { 169 | for i in 0 ..< m { 170 | totalQueries += 1 171 | let query = queries[i] 172 | let foundNeighbors = subject.nearest(key: query, query: query, radius: 10, k: 1) 173 | guard let foundNeighbor = foundNeighbors.first else { 174 | continue 175 | } 176 | totalFound += 1 177 | totalNodes += subject.comparisonCountMetric 178 | let foundDistance = foundNeighbor.distance 179 | totalDistance = foundDistance 180 | let exactNeighbor = nearestNeighbors[i] 181 | let exactDistance = exactNeighbor.distance 182 | let delta = exactDistance - foundDistance 183 | let error = delta * delta 184 | totalError += error 185 | let correct = error == 0 186 | totalCorrect += correct ? 
1 : 0 187 | } 188 | // } 189 | let percentCorrect = Float(totalCorrect) / Float(totalQueries) 190 | let meanSquaredError = totalError / Float(totalFound) 191 | let averageNodesPerQuery = Float(totalNodes) / Float(totalFound) 192 | let averageDistance = Float(totalDistance) / Float(totalFound) 193 | print("Total queries: \(totalQueries)") 194 | print("Total found: \(totalFound)") 195 | print("Mean squared error: \(String(format: "%0.3f", meanSquaredError))") 196 | print("Average absolute distance: \(String(format: "%0.3f", averageDistance))") 197 | print("Correct: \(totalCorrect) out of \(totalQueries) = \(String(format: "%0.3f", percentCorrect))") 198 | print("Total nodes: \(totalNodes)") 199 | print("Average comparisons per query: \(String(format: "%0.3f", averageNodesPerQuery))") 200 | } 201 | } 202 | */ 203 | -------------------------------------------------------------------------------- /Tests/SIFTMetalTests/DescriptorTests.swift: -------------------------------------------------------------------------------- 1 | // 2 | // DescriptorTests.swift 3 | // SkyLightTests 4 | // 5 | // Created by Luke Van In on 2022/12/26. 
//

import XCTest
import simd

@testable import SIFTMetal


/// Runs the SIFT pipeline on the bundled "butterfly" image and compares the
/// result with reference descriptors loaded from the "butterfly-descriptors"
/// text resource.
final class DescriptorTests: SharedTestCase {

    /// Thrown by `loadDescriptors` when a non-empty line of the reference file
    /// has too few whitespace-separated fields.
    private struct MalformedDescriptorLine: Error, CustomStringConvertible {
        let lineNumber: Int
        let expectedCount: Int
        let actualCount: Int

        var description: String {
            "Descriptor line \(lineNumber) has \(actualCount) fields, expected at least \(expectedCount)"
        }
    }

    /// Computes descriptors for the butterfly image and attaches a rendering
    /// of the found descriptors alongside the reference descriptors.
    func testDescriptors() throws {

        let inputTexture = try device.loadTexture(name: "butterfly", extension: "png", srgb: false)
        let configuration = SIFT.Configuration(
            inputSize: IntegralSize(
                width: inputTexture.width,
                height: inputTexture.height
            )
        )
        let subject = SIFT(device: device, configuration: configuration)
        let keypoints = subject.getKeypoints(inputTexture)
        let descriptorOctaves = subject.getDescriptors(keypointOctaves: keypoints)
        let foundDescriptors = Array(descriptorOctaves.joined())
        print("Found", foundDescriptors.count, "descriptors")

        print("loading reference desriptors")
        let referenceDescriptors = try loadDescriptors(filename: "butterfly-descriptors")

        print("rendering image")
        let referenceImage: CGImage = {
            let originalImage = CIImage(
                mtlTexture: inputTexture,
                options: [
                    CIImageOption.colorSpace: CGColorSpace(name: CGColorSpace.sRGB)!,
                ]
            )!
                .oriented(.downMirrored)
                .smearColor()
            let cgImage = ciContext.makeCGImage(ciImage: originalImage)
            return cgImage
        }()

        let renderer = SIFTRenderer()
        attachImage(
            name: "descriptors",
            uiImage: renderer.drawDescriptors(
                sourceImage: referenceImage,
                referenceDescriptors: referenceDescriptors,
                foundDescriptors: foundDescriptors
            )
        )
    }

    /// Measures keypoint detection followed by descriptor extraction.
    /// Texture loading and pipeline construction happen outside the measured
    /// block.
    func testDescriptorsPerformance() throws {

        let inputTexture = try device.loadTexture(name: "butterfly", extension: "png", srgb: false)
        let configuration = SIFT.Configuration(
            inputSize: IntegralSize(
                width: inputTexture.width,
                height: inputTexture.height
            )
        )
        let subject = SIFT(device: device, configuration: configuration)
        measure {
            let keypoints = subject.getKeypoints(inputTexture)
            _ = subject.getDescriptors(keypointOctaves: keypoints)
        }
    }

    /// Asserts that at least 80% of the detected descriptors have a match in
    /// the reference set.
    ///
    /// - Parameters:
    ///   - detected: Descriptors produced by the pipeline under test.
    ///   - reference: Descriptors loaded from the reference file.
    private func matchDescriptors(detected: [SIFTDescriptor], reference: [SIFTDescriptor]) {

        // Without this guard the rate below is 0 / 0 = NaN.
        guard !detected.isEmpty else {
            XCTFail("No descriptors were detected")
            return
        }

        let matches = SIFTDescriptor.match(
            source: detected,
            target: reference,
            absoluteThreshold: 300,
            relativeThreshold: 0.6
        )

        // `rate` is a fraction in 0...1; it is only scaled to a percentage
        // for display, so the threshold must be expressed as a fraction too.
        let rate = Float(matches.count) / Float(detected.count)
        print("found \(matches.count) out of \(detected.count) = \(rate * 100)%")
        XCTAssertGreaterThanOrEqual(rate, 0.8)
    }

    /// Matches the computed descriptors against the reference descriptors and
    /// attaches a rendering of the matches.
    func testMatches() throws {

        let inputTexture = try device.loadTexture(name: "butterfly", extension: "png", srgb: false)
        let configuration = SIFT.Configuration(
            inputSize: IntegralSize(
                width: inputTexture.width,
                height: inputTexture.height
            )
        )
        let subject = SIFT(device: device, configuration: configuration)
        let keypoints = subject.getKeypoints(inputTexture)
        let descriptorOctaves = subject.getDescriptors(keypointOctaves: keypoints)
        let foundDescriptors = Array(descriptorOctaves.joined())
        print("Found", foundDescriptors.count, "descriptors")

        let referenceImage: CGImage = {
            let originalImage = CIImage(
                mtlTexture: inputTexture,
                options: [
                    CIImageOption.colorSpace: CGColorSpace(name: CGColorSpace.sRGB)!,
                ]
            )!
                .oriented(.downMirrored)
                .smearColor()
            let cgImage = ciContext.makeCGImage(ciImage: originalImage)
            return cgImage
        }()

        let referenceDescriptors = try loadDescriptors(filename: "butterfly-descriptors")
        print("Loaded \(referenceDescriptors.count) reference descriptors")

        let matches = SIFTDescriptor.match(
            source: foundDescriptors,
            target: referenceDescriptors,
            absoluteThreshold: 300,
            relativeThreshold: 0.6
        )

        print("Found \(matches.count) matches")

        print("drawing matches")
        let renderer = SIFTRenderer()
        attachImage(
            name: "matches",
            uiImage: renderer.drawMatches(
                sourceImage: referenceImage,
                targetImage: referenceImage,
                matches: matches
            )
        )
    }

    /// Measures descriptor matching only; detection, extraction and loading
    /// of the reference file happen outside the measured block.
    func testMatchesPerformance() throws {
        let inputTexture = try device.loadTexture(name: "butterfly", extension: "png", srgb: false)
        let configuration = SIFT.Configuration(
            inputSize: IntegralSize(
                width: inputTexture.width,
                height: inputTexture.height
            )
        )
        let subject = SIFT(device: device, configuration: configuration)
        let keypoints = subject.getKeypoints(inputTexture)
        let descriptorOctaves = subject.getDescriptors(keypointOctaves: keypoints)
        let foundDescriptors = Array(descriptorOctaves.joined())
        print("Found", foundDescriptors.count, "descriptors")

        let referenceDescriptors = try loadDescriptors(filename: "butterfly-descriptors")
        print("Loaded \(referenceDescriptors.count) reference descriptors")

        // matchDescriptors(detected: foundDescriptors, reference: referenceDescriptors)

        print("Finding matches")
        measure {
            _ = SIFTDescriptor.match(
                source: foundDescriptors,
                target: referenceDescriptors,
                absoluteThreshold: 300,
                relativeThreshold: 0.6
            )
        }
    }

    /// Returns every `step`-th element of `input` (starting with the first),
    /// truncated to at most `limit` elements.
    private func filter<E>(_ input: [E], every step: Int = 10, limit: Int = 10000) -> [E] {
        Array(input.enumerated().filter({ $0.offset % step == 0 }).map({ $0.element }).prefix(limit))
    }


    /// Loads reference descriptors from a text resource in the test bundle.
    ///
    /// Each non-empty line holds space-separated fields: `y x sigma theta`
    /// followed by 4 * 4 * 8 integer feature values.
    ///
    /// - Parameters:
    ///   - filename: Resource name without extension.
    ///   - extension: Resource file extension.
    /// - Returns: One descriptor per non-empty line, in file order.
    /// - Throws: If the resource is missing, is not UTF-8, or contains a line
    ///   that has too few fields or a field that is not a number.
    private func loadDescriptors(filename: String, extension: String = "txt") throws -> [SIFTDescriptor] {
        let headerCount = 4
        let featureCount = 4 * 4 * 8

        let fileURL = try XCTUnwrap(
            bundle.url(forResource: filename, withExtension: `extension`),
            "Missing descriptor resource \(filename)"
        )
        let data = try Data(contentsOf: fileURL)
        let string = try XCTUnwrap(
            String(data: data, encoding: .utf8),
            "Descriptor resource \(filename) is not valid UTF-8"
        )
        let lines = string.split(separator: "\n")

        var descriptors = [SIFTDescriptor]()
        descriptors.reserveCapacity(lines.count)

        for (index, line) in lines.enumerated() {
            let components = line.split(separator: " ")
            guard !components.isEmpty else {
                continue
            }
            // Validate the field count up front so a truncated line reports an
            // error instead of trapping on an out-of-range subscript.
            guard components.count >= headerCount + featureCount else {
                throw MalformedDescriptorLine(
                    lineNumber: index + 1,
                    expectedCount: headerCount + featureCount,
                    actualCount: components.count
                )
            }
            let y = try XCTUnwrap(Float(components[0]), "Invalid y on descriptor line \(index + 1)")
            let x = try XCTUnwrap(Float(components[1]), "Invalid x on descriptor line \(index + 1)")
            let s = try XCTUnwrap(Float(components[2]), "Invalid sigma on descriptor line \(index + 1)")
            let theta = try XCTUnwrap(Float(components[3]), "Invalid theta on descriptor line \(index + 1)")
            let features = try components[headerCount ..< headerCount + featureCount].map {
                try XCTUnwrap(Int($0), "Invalid feature value on descriptor line \(index + 1)")
            }

            let descriptor = SIFTDescriptor(
                keypoint: SIFTKeypoint(
                    octave: 0,
                    scale: 0,
                    subScale: 0,
                    scaledCoordinate: .zero,
                    absoluteCoordinate: SIMD2(x: x, y: y),
                    sigma: s,
                    value: 0
                ),
                theta: theta,
                // rawFeatures: [],
                features: IntVector(features)
            )
            descriptors.append(descriptor)
        }

        return descriptors
    }

}

-------------------------------------------------------------------------------- /Sources/MetalShaders/Metal/SIFTInterpolate.metal: -------------------------------------------------------------------------------- 1 | // 2 | // siftInterpolate.metal 3 | // SkyLight 4 | // 5 | // Created by Luke Van In on 2023/01/07.
//

#include <metal_stdlib>

#include "Common.hpp"

#include "../include/SIFTInterpolate.h"

using namespace metal;

// NOTE(review): the texture element type and access qualifier were lost from
// the extracted source; <float, access::read> matches how the textures are
// used in this file (read() of the .r channel) - confirm against the original.


// Returns true when the sample at (x, y) in slice `s` should be rejected as an
// edge response.
//
// The 2D Hessian is estimated with finite differences. The sample is rejected
// when the determinant is not positive, or when trace^2 / determinant is at
// least (edgeThreshold + 1)^2 / edgeThreshold.
//
// The names hxx / hyy follow the IPOL convention noted below (hxx is the
// second difference along y). Only the trace and determinant are used, and
// both are symmetric in hxx and hyy, so the naming does not affect the result.
bool isOnEdge(
    texture2d_array<float, access::read> t [[texture(0)]],
    int x,
    int y,
    int s,
    float edgeThreshold
) {
    const float v = t.read(ushort2(x, y), s).r;

    // Compute the 2d Hessian at pixel (i,j) - i = y, j = x
    // IPOL implementation uses hxx for y axis, and hyy for x axis
    const float zn = t.read(ushort2(x, y - 1), s).r;
    const float zp = t.read(ushort2(x, y + 1), s).r;
    const float pz = t.read(ushort2(x + 1, y), s).r;
    const float nz = t.read(ushort2(x - 1, y), s).r;
    const float pp = t.read(ushort2(x + 1, y + 1), s).r;
    const float np = t.read(ushort2(x - 1, y + 1), s).r;
    const float pn = t.read(ushort2(x + 1, y - 1), s).r;
    const float nn = t.read(ushort2(x - 1, y - 1), s).r;

    const float hxx = zn + zp - 2 * v;
    const float hyy = pz + nz - 2 * v;
    const float hxy = ((pp - np) - (pn - nn)) * 0.25;

    // Whess
    const float trace = hxx + hyy;
    const float determinant = (hxx * hyy) - (hxy * hxy);

    if (determinant <= 0) {
        // Negative determinant -> curvatures have different signs
        return true;
    }

    const float threshold = ((edgeThreshold + 1) * (edgeThreshold + 1)) / edgeThreshold;
    const float curvature = (trace * trace) / determinant;

    if (curvature >= threshold) {
        // Feature is on an edge
        return true;
    }

    // Feature is not on an edge
    return false;
}


// Returns the gradient (d/dx, d/dy, d/ds) at (x, y, s) using central
// differences. Reads the four in-plane neighbours and the same pixel in
// slices s - 1 and s + 1, so the caller must keep all of them in range.
float3 derivatives3D(
    texture2d_array<float, access::read> t [[texture(0)]],
    int x,
    int y,
    int s
) {
    const float pzz = t.read(ushort2(x + 1, y), s).r;
    const float nzz = t.read(ushort2(x - 1, y), s).r;
    const float zpz = t.read(ushort2(x, y + 1), s).r;
    const float znz = t.read(ushort2(x, y - 1), s).r;
    const float zzp = t.read(ushort2(x, y), s + 1).r;
    const float zzn = t.read(ushort2(x, y), s - 1).r;

    return float3(
        (pzz - nzz) * 0.5,
        (zpz - znz) * 0.5,
        (zzp - zzn) * 0.5
    );
}


// Returns the value of the local quadratic model at the sub-sample offset
// `alpha` from (x, y, s):
//
//     D(alpha) = D + 0.5 * dot(gradient, alpha)
//
// All three components (x, y and scale) of the gradient contribute. The
// previous implementation multiplied component-wise and used only the x
// component, which dropped the y and scale terms.
float interpolateContrast(
    texture2d_array<float, access::read> t [[texture(0)]],
    int x,
    int y,
    int s,
    float3 alpha
) {
    const float3 dD = derivatives3D(t, x, y, s);
    const float v = t.read(ushort2(x, y), s).r;
    return v + dot(dD, alpha) * 0.5;
}


// Computes the 3D Hessian matrix.
//  ⎡ Ixx Ixy Ixs ⎤
//
//    Ixy Iyy Iys
//
//  ⎣ Ixs Iys Iss ⎦
//
// Second derivatives are estimated with finite differences over the
// neighbourhood of (x, y, s), including slices s - 1 and s + 1.
float3x3 hessian3D(
    texture2d_array<float, access::read> t [[texture(0)]],
    int x,
    int y,
    int s
) {
    // z = zero, p = positive, n = negative
    const float zzz = t.read(ushort2(x, y), s).r;

    const float pzz = t.read(ushort2(x + 1, y), s).r;
    const float nzz = t.read(ushort2(x - 1, y), s).r;

    const float zpz = t.read(ushort2(x, y + 1), s).r;
    const float znz = t.read(ushort2(x, y - 1), s).r;

    const float zzp = t.read(ushort2(x, y), s + 1).r;
    const float zzn = t.read(ushort2(x, y), s - 1).r;

    const float ppz = t.read(ushort2(x + 1, y + 1), s).r;
    const float nnz = t.read(ushort2(x - 1, y - 1), s).r;

    const float npz = t.read(ushort2(x - 1, y + 1), s).r;
    const float pnz = t.read(ushort2(x + 1, y - 1), s).r;

    const float pzp = t.read(ushort2(x + 1, y), s + 1).r;
    const float nzp = t.read(ushort2(x - 1, y), s + 1).r;
    const float zpp = t.read(ushort2(x, y + 1), s + 1).r;
    const float znp = t.read(ushort2(x, y - 1), s + 1).r;

    const float pzn = t.read(ushort2(x + 1, y), s - 1).r;
    const float nzn = t.read(ushort2(x - 1, y), s - 1).r;
    const float zpn = t.read(ushort2(x, y + 1), s - 1).r;
    const float znn = t.read(ushort2(x, y - 1), s - 1).r;

    // Pure second derivatives.
    const float dxx = pzz + nzz - 2 * zzz;
    const float dyy = zpz + znz - 2 * zzz;
    const float dss = zzp + zzn - 2 * zzz;

    // Mixed second derivatives.
    const float dxy = (ppz - npz - pnz + nnz) * 0.25;
    const float dxs = (pzp - nzp - pzn + nzn) * 0.25;
    const float dys = (zpp - znp - zpn + znn) * 0.25;

    return float3x3(
        float3(dxx, dxy, dxs),
        float3(dxy, dyy, dys),
        float3(dxs, dys, dss)
    );
}


// Returns the offset -H^-1 * gradient of the extremum of the local quadratic
// model from (x, y, scale).
//
// NOTE(review): invert() is declared outside this file (presumably in
// Common.hpp); its behaviour for a singular Hessian is not visible here.
float3 interpolationStep(
    texture2d_array<float, access::read> t [[texture(0)]],
    int x,
    int y,
    int scale
) {
    const float3x3 H = hessian3D(t, x, y, scale);
    float3x3 Hi = -1.0 * invert(H);
    const float3 dD = derivatives3D(t, x, y, scale);
    return Hi * dD;
}


// Returns true when (x, y, scale) is outside the region that may be
// interpolated: closer than 5 pixels to an image border, or a scale outside
// 1 ... scales.
//
// NOTE(review): derivatives3D / hessian3D read slice scale + 1, so `scales`
// presumably must be at most the number of texture slices minus 2 - confirm
// against the host code that fills SIFTInterpolateParameters.
bool outOfBounds(int x, int y, int scale, int width, int height, int scales) {
    // TODO: Configurable border.
    const int border = 5;
    const int minX = border;
    const int maxX = width - border - 1;
    const int minY = border;
    const int maxY = height - border - 1;
    const int minS = 1;
    const int maxS = scales;
    return x < minX || x > maxX || y < minY || y > maxY || scale < minS || scale > maxS;
}


// Refines one candidate extremum per thread to sub-pixel / sub-scale accuracy.
//
// outputKeypoints: one result per candidate (index gid). `converged` is 0 for
//                  rejected candidates; their remaining fields are not written
//                  with meaningful values.
// inputKeypoints:  integer candidate positions (x, y, scale).
// parameters:      thresholds, iteration limit and octave geometry.
// dogTextures:     texture array; the response is read from the .r channel.
// gid:             index of the candidate handled by this thread.
//
// A candidate is rejected when it is out of bounds, its response is at most
// 80% of dogThreshold, the iteration leaves the valid region or does not
// converge within maxIterations, the interpolated response is at most
// dogThreshold, or it lies on an edge.
kernel void siftInterpolate(
    device SIFTInterpolateOutputKeypoint * outputKeypoints [[buffer(0)]],
    device SIFTInterpolateInputKeypoint * inputKeypoints [[buffer(1)]],
    device SIFTInterpolateParameters & parameters [[buffer(2)]],
    texture2d_array<float, access::read> dogTextures [[texture(0)]],
    ushort gid [[thread_position_in_grid]]
) {
    SIFTInterpolateInputKeypoint input = inputKeypoints[gid];
    SIFTInterpolateOutputKeypoint output;
    // Mark the slot as rejected up front so every early return below leaves a
    // well-defined `converged` flag.
    output.converged = 0;
    outputKeypoints[gid] = output;

    const int maxIterations = parameters.maxIterations;
    const float maxOffset = parameters.maxOffset;
    const int width = parameters.width;
    const int height = parameters.height;
    const int scales = parameters.numberOfScales;
    const float delta = parameters.octaveDelta;

    int x = input.x;
    int y = input.y;
    int scale = input.scale;

    // Validate the position before the first texture read. The result is the
    // same as checking afterwards (both paths reject the candidate), but no
    // out-of-range coordinate is ever passed to read().
    if (outOfBounds(x, y, scale, width, height, scales)) {
        return;
    }

    float value = dogTextures.read(ushort2(x, y), scale).r;

    // Discard keypoint that is way below the brightness threshold
    if (abs(value) <= parameters.dogThreshold * 0.8) {
        return;
    }

    bool converged = false;
    float3 alpha = float3(0);

    int i = 0;
    while (i < maxIterations) {
        alpha = interpolationStep(dogTextures, x, y, scale);

        if ((abs(alpha.x) < maxOffset) && (abs(alpha.y) < maxOffset) && (abs(alpha.z) < maxOffset)) {
            converged = true;
            break;
        }

        // Whess
        // coordinate.x += Int(alpha.x.rounded())
        // coordinate.y += Int(alpha.y.rounded())
        // coordinate.z += Int(alpha.z.rounded())

        // IPOL: move by one sample along each axis whose offset is too large.
        // TODO: >=
        if (alpha.x > +maxOffset) {
            x += 1;
        }
        if (alpha.x < -maxOffset) {
            x -= 1;
        }
        if (alpha.y > +maxOffset) {
            y += 1;
        }
        if (alpha.y < -maxOffset) {
            y -= 1;
        }
        if (alpha.z > +maxOffset) {
            scale += 1;
        }
        if (alpha.z < -maxOffset) {
            scale -= 1;
        }

        if (outOfBounds(x, y, scale, width, height, scales)) {
            return;
        }

        i += 1;
    }

    if (!converged) {
        return;
    }

    value = interpolateContrast(dogTextures, x, y, scale, alpha);

    if (abs(value) <= parameters.dogThreshold) {
        return;
    }

    // Discard keypoint with high edge response
    if (isOnEdge(dogTextures, x, y, scale, parameters.edgeThreshold)) {
        return;
    }

    // Return keypoint
    output.converged = 1;
    output.scale = scale;
    output.subScale = alpha.z;
    output.relativeX = x;
    output.relativeY = y;
    output.absoluteX = ((float)x + alpha.x) * delta;
    output.absoluteY = ((float)y + alpha.y) * delta;
    output.value = value;
    output.alphaX = alpha.x;
    output.alphaY = alpha.y;
    output.alphaZ = alpha.z;
    outputKeypoints[gid] = output;
}

-------------------------------------------------------------------------------- /Tests/SIFTMetalTests/DifferenceOfGaussiansTests.swift: -------------------------------------------------------------------------------- 1 | // 2 | // DifferenceOfGaussiansTests.swift 3 | // SkyLightTests 4 | // 5 | // Created by Luke Van In on 2022/12/24.
6 | // 7 | 8 | import XCTest 9 | import CoreImage 10 | import CoreImage.CIFilterBuiltins 11 | import MetalPerformanceShaders 12 | 13 | @testable import SIFTMetal 14 | 15 | /* 16 | 17 | final class DifferenceOfGaussiansTests: SharedTestCase { 18 | 19 | 20 | func testComputeDifferenceOfGaussians() throws { 21 | 22 | let referenceGaussianImages = try loadScaleSpaceTextures( 23 | name: "scalespace_butterfly", 24 | extension: "png", 25 | octaves: 1, 26 | scalesPerOctave: 5 27 | ) 28 | 29 | let inputTexture = try device.loadTexture(name: "butterfly", extension: "png", srgb: false) 30 | attachImage( 31 | name: "input", 32 | uiImage: ciContext.makeUIImage( 33 | ciImage: CIImage( 34 | mtlTexture: inputTexture, 35 | options: [ 36 | CIImageOption.colorSpace: CGColorSpace(name: CGColorSpace.genericGrayGamma2_2)!, 37 | ] 38 | )! 39 | .oriented(.downMirrored) 40 | .smearColor() 41 | ) 42 | ) 43 | 44 | let configuration = DifferenceOfGaussians.Configuration( 45 | inputDimensions: IntegralSize( 46 | width: Int(inputTexture.width), 47 | height: Int(inputTexture.height) 48 | ) 49 | ) 50 | let subject = DifferenceOfGaussians( 51 | device: device, 52 | configuration: configuration 53 | ) 54 | 55 | print("Encoding") 56 | let commandBuffer = commandQueue.makeCommandBuffer()! 
57 | subject.encode( 58 | commandBuffer: commandBuffer, 59 | originalTexture: inputTexture 60 | ) 61 | commandBuffer.commit() 62 | commandBuffer.waitUntilCompleted() 63 | 64 | print("Saving attachments") 65 | 66 | 67 | var resultGaussainImages = [MTLTexture]() 68 | // for (o, octave) in subject.octaves.enumerated() { 69 | // for (s, texture) in octave.gaussianTextures.enumerated() { 70 | // resultGaussainImages.append(texture) 71 | // } 72 | // } 73 | 74 | compare( 75 | referenceImages: referenceGaussianImages, 76 | testImages: resultGaussainImages 77 | ) 78 | 79 | 80 | // attachImage( 81 | // name: "v(1, 0): luminosity", 82 | // uiImage: makeUIImage( 83 | // ciImage: smearColor( 84 | // ciImage: makeCIImage( 85 | // texture: subject.luminosityTexture 86 | // ) 87 | // ), 88 | // context: ciContext 89 | // ) 90 | // ) 91 | 92 | // attachImage( 93 | // name: "v(1, 0): scaled", 94 | // uiImage: makeUIImage( 95 | // ciImage: smearColor( 96 | // ciImage: makeCIImage( 97 | // texture: subject.scaledTexture 98 | // ) 99 | // ), 100 | // context: ciContext 101 | // ) 102 | // ) 103 | 104 | // attachImage( 105 | // name: "v(1, 0): seed", 106 | // uiImage: makeUIImage( 107 | // ciImage: smearColor( 108 | // ciImage: makeCIImage( 109 | // texture: subject.seedTexture 110 | // ) 111 | // ), 112 | // context: ciContext 113 | // ) 114 | // ) 115 | 116 | // for (o, octave) in subject.octaves.enumerated() { 117 | // 118 | // for (s, texture) in octave.gaussianTextures.enumerated() { 119 | // 120 | // attachImage( 121 | // name: "v[\(o), \(s)]", 122 | // uiImage: makeUIImage( 123 | // ciImage: smearColor( 124 | // ciImage: makeCIImage( 125 | // texture: texture 126 | // ) 127 | // ), 128 | // context: ciContext 129 | // ) 130 | // ) 131 | // } 132 | // 133 | // for (s, texture) in octave.differenceTextures.enumerated() { 134 | // 135 | // attachImage( 136 | // name: "w[\(o), \(s)]", 137 | // uiImage: makeUIImage( 138 | // ciImage: mapColor( 139 | // ciImage: normalizeColor( 140 | 
// ciImage: smearColor( 141 | // ciImage: makeCIImage( 142 | // texture: texture 143 | // ) 144 | // ) 145 | // ) 146 | // ), 147 | // context: ciContext 148 | // ) 149 | // ) 150 | // } 151 | // } 152 | } 153 | 154 | private func compare( 155 | referenceImages: [CIImage], 156 | testImages testTextures: [MTLTexture] 157 | ) { 158 | // let referenceImages = referenceTextures.map { 159 | // CIImage(mtlTexture: $0)!.oriented(.downMirrored) 160 | // } 161 | 162 | let testImages = testTextures.map { 163 | let image = CIImage( 164 | mtlTexture: $0, 165 | options: [ 166 | CIImageOption.colorSpace: CGColorSpace(name: CGColorSpace.genericGrayGamma2_2)!, 167 | ] 168 | )! 169 | let outputImage = image 170 | .settingAlphaOne(in: image.extent) 171 | .oriented(.downMirrored) 172 | .smearColor() 173 | return outputImage 174 | } 175 | 176 | XCTAssert(referenceImages.count == referenceImages.count) 177 | 178 | let differenceFilter = CIFilter.colorAbsoluteDifference() 179 | 180 | let colorFilter = CIFilter.colorControls() 181 | colorFilter.brightness = 0.5 182 | colorFilter.contrast = 2.0 183 | colorFilter.saturation = 1 184 | 185 | let thresholdFilter = CIFilter.colorThreshold() 186 | // thresholdFilter.threshold = 0.005 187 | thresholdFilter.threshold = 0.005 188 | 189 | let clampFilter = CIFilter.colorClamp() 190 | // clampFilter.minComponents = CIVector(x: 0.005, y: 0.005, z: 0.005, w: 0) 191 | // clampFilter.maxComponents = CIVector(x: 0.995, y: 0.995, z: 0.995, w: 1) 192 | clampFilter.minComponents = CIVector(x: 0, y: 0, z: 0, w: 0) 193 | clampFilter.maxComponents = CIVector(x: 1, y: 1, z: 1, w: 1) 194 | 195 | for i in 0 ..< referenceImages.count { 196 | let referenceImage = referenceImages[i] 197 | let testImage = testImages[i] 198 | 199 | let scaledTestImage: CIImage 200 | 201 | if testImage.extent.size != referenceImage.extent.size { 202 | scaledTestImage = testImage.samplingNearest().transformed( 203 | by: .identity.scaledBy( 204 | x: referenceImage.extent.width / 
testImage.extent.width, 205 | y: referenceImage.extent.height / testImage.extent.height 206 | ) 207 | ) 208 | } 209 | else { 210 | scaledTestImage = testImage 211 | } 212 | 213 | clampFilter.inputImage = scaledTestImage 214 | differenceFilter.inputImage = clampFilter.outputImage 215 | 216 | clampFilter.inputImage = referenceImage 217 | differenceFilter.inputImage2 = clampFilter.outputImage 218 | 219 | colorFilter.inputImage = differenceFilter.outputImage 220 | 221 | thresholdFilter.inputImage = differenceFilter.outputImage 222 | 223 | let differenceImage = colorFilter.outputImage! 224 | let thresholdImage = thresholdFilter.outputImage! 225 | 226 | attachImage( 227 | name: "scalespace \(i): reference", 228 | uiImage: ciContext.makeUIImage(ciImage: referenceImage) 229 | ) 230 | 231 | attachImage( 232 | name: "scalespace \(i): test", 233 | uiImage: ciContext.makeUIImage(ciImage: scaledTestImage) 234 | ) 235 | 236 | attachImage( 237 | name: "scalespace \(i): difference", 238 | uiImage: ciContext.makeUIImage(ciImage: differenceImage) 239 | ) 240 | 241 | attachImage( 242 | name: "scalespace \(i): threshold", 243 | uiImage: ciContext.makeUIImage(ciImage: thresholdImage) 244 | ) 245 | } 246 | } 247 | 248 | private func loadScaleSpaceTextures( 249 | name: String, 250 | extension: String, 251 | octaves: Int, 252 | scalesPerOctave: Int, 253 | file: StaticString = #file, 254 | line: UInt = #line 255 | ) throws -> [CIImage] { 256 | var output = [CIImage]() 257 | for o in 0 ..< octaves { 258 | for s in 0 ..< scalesPerOctave { 259 | let octaveName = String(format: "%03d", o) 260 | let scaleName = String(format: "%03d", s) 261 | let filename = "\(name)_o\(octaveName)_s\(scaleName)" 262 | let image = try CIImage(name: filename, extension: `extension`) 263 | // let image = try device.loadTexture(name: filename, extension: `extension`) 264 | output.append(image) 265 | } 266 | } 267 | return output 268 | } 269 | } 270 | */ 271 | 272 | 
--------------------------------------------------------------------------------
/Sources/SIFTMetal/SIFT/SIFT.swift:
--------------------------------------------------------------------------------
//
//  SIFT.swift
//  SkyLight
//
//  Created by Luke Van In on 2022/12/18.
//

import Foundation
import OSLog
import Metal

import MetalShaders

private let logger = Logger(
    // `bundleIdentifier` is optional; fall back to a fixed subsystem name
    // instead of trapping at module load when the host has no identifier.
    subsystem: Bundle.main.bundleIdentifier ?? "SIFTMetal",
    category: "SIFT"
)


///
/// Performs the Scale Invariant Feature Transform (SIFT) on an image.
///
/// Extracts a set of robust feature descriptors for identifiable points on an image using the SIFT
/// algorithm[1]. Uses Metal compute shaders to execute tasks using GPU hardware.
///
/// This implementation is mostly based on the "Anatomy of the SIFT method"[2] paper and source code
/// published by the Image Processing Online (IPOL) Journal. The implementation has some notable
/// characteristics not explicitly described in the paper. Their relevance toward the accuracy or
/// correctness of the implementation is not indicated.
/// - sRGB images are not converted to linear grayscale space for analysis. The SIFT analysis is
///   performed on the gamma-encoded (^2.2) color space.
/// - RGB colors are converted to grayscale using the ITU BT.709-5 (NTSC) color conversion formula.
/// - The convolution kernel used for Gaussian blur is centered symmetrically on pixels. This differs
///   from the positioning used by convolution kernels provided by Metal Performance Shaders. A custom
///   convolution kernel is used to provide similarity to the IPOL implementation.
///
/// Note: A novel method is used for matching SIFT descriptors, which is different to the methods used by
/// Lowe and IPOL. Our method matches descriptors using a trie structure with leaf nodes forming a linked
/// list. Construction of the trie takes O(n) time. Queries run in O(1) constant time.
///
/// [1]: https://www.cs.ubc.ca/~lowe/papers/ijcv04.pdf "Distinctive Image Features from Scale-Invariant Keypoints", Lowe, 2004
/// [2]: http://www.ipol.im/pub/art/2014/82/article.pdf "Anatomy of the SIFT Method", Rey-Otero & Delbracio, 2014
///
/// Additional references:
/// See: https://www.cs.ubc.ca/~lowe/keypoints/
/// See: https://en.wikipedia.org/wiki/Scale-invariant_feature_transform
/// See: https://docs.opencv.org/4.x/da/df5/tutorial_py_sift_intro.html
/// See: https://www.youtube.com/watch?v=4AvTMVD9ig0&t=232
/// See: https://www.youtube.com/watch?v=U0wqePj4Mx0
/// See: https://www.youtube.com/watch?v=ram-jbLJjFg&t=2s
/// See: https://www.youtube.com/watch?v=NPcMS49V5hg
/// See: https://github.com/robwhess/opensift
/// See: https://medium.com/jun94-devpblog/cv-13-scale-invariant-local-feature-extraction-3-sift-315b5de72d48
///
public final class SIFT {

    /// Tunable parameters for the SIFT pipeline. Only `inputSize` is required;
    /// every other value has a default.
    public struct Configuration {

        // Dimensions of the input image.
        var inputSize: IntegralSize

        // Threshold over the Difference of Gaussians response (value
        // relative to scales per octave = 3)
        var differenceOfGaussiansThreshold: Float = 0.0133

        // Threshold over the ratio of principal curvatures (edgeness).
        var edgeThreshold: Float = 10.0

        // Maximum number of consecutive unsuccessful interpolations.
        var maximumInterpolationIterations: Int = 5

        // Width of border in which to ignore keypoints
        var imageBorder: Int = 5

        // Sets how local is the analysis of the gradient distribution.
        var lambdaOrientation: Float = 1.5

        // Number of bins in the orientation histogram.
        var orientationBins: Int = 36

        // Threshold for considering local maxima in the orientation histogram.
        var orientationThreshold: Float = 0.8

        // Number of iterations used to smooth the orientation histogram
        var orientationSmoothingIterations: Int = 6

        // Number of normalized histograms in the normalized patch in the
        // descriptor. This must be a square integer number so that both x
        // and y axes have the same length.
        var descriptorHistogramsPerAxis: Int = 4

        // Number of bins in the descriptor histogram.
        var descriptorOrientationBins: Int = 8

        // How local the descriptor is (size of the descriptor).
        // Gaussian window of lambdaDescriptor * sigma
        // Descriptor patch width of 2 * lambdaDescriptor * sigma
        var lambdaDescriptor: Float = 6

        public init(inputSize: IntegralSize) {
            self.inputSize = inputSize
        }
    }

    let configuration: Configuration
    let dog: DifferenceOfGaussians
    let octaves: [SIFTOctave]

    private let device: MTLDevice
    private let commandQueue: MTLCommandQueue

    /// Builds the difference-of-Gaussians stage for `configuration.inputSize`
    /// and one `SIFTOctave` per octave of that stage, all sharing a single
    /// gradient kernel and a single command queue.
    ///
    /// - Parameters:
    ///   - device: Metal device used for every kernel and for the command queue.
    ///   - configuration: Pipeline parameters.
    public init(
        device: MTLDevice,
        configuration: Configuration
    ) {
        self.device = device

        let dog = DifferenceOfGaussians(
            device: device,
            configuration: DifferenceOfGaussians.Configuration(
                inputDimensions: configuration.inputSize
            )
        )

        // One gradient kernel instance is shared by all octaves.
        let gradientFunction = SIFTGradientKernel(device: device)
        let octaves: [SIFTOctave] = dog.octaves.map { scale in
            SIFTOctave(
                device: device,
                scale: scale,
                gradientFunction: gradientFunction
            )
        }

        // The initializer is non-failable, so an unavailable queue is still
        // fatal; the guard only makes the failure self-describing.
        guard let commandQueue = device.makeCommandQueue() else {
            fatalError("SIFT: unable to create a Metal command queue")
        }

        self.commandQueue = commandQueue
        self.configuration = configuration
        self.dog = dog
        self.octaves = octaves
    }

    // MARK: Keypoints

    /// Runs keypoint detection on `inputTexture` and returns the interpolated
    /// keypoints, one inner array per octave.
    public func getKeypoints(_ inputTexture: MTLTexture) -> [[SIFTKeypoint]] {
        findKeypoints(inputTexture: inputTexture)
        let keypointOctaves = getKeypointsFromOctaves()
        return interpolateKeypoints(keypointOctaves: keypointOctaves)
    }

    /// Encodes the difference-of-Gaussians stage followed by every octave into
    /// one command buffer, then blocks until the GPU has finished.
    private func findKeypoints(inputTexture: MTLTexture) {
        measure(name: "findKeypoints") {
            capture(commandQueue: commandQueue, capture: false) {
                guard let commandBuffer = commandQueue.makeCommandBuffer() else {
                    fatalError("SIFT: unable to create a Metal command buffer")
                }
                commandBuffer.label = "siftKeypointsCommandBuffer"

                dog.encode(
                    commandBuffer: commandBuffer,
                    originalTexture: inputTexture
                )

                for octave in octaves {
                    octave.encode(
                        commandBuffer: commandBuffer
                    )
                }

                commandBuffer.commit()
                commandBuffer.waitUntilCompleted()
            }
        }
    }

    /// Collects the keypoint buffer of every octave, in octave order, and logs
    /// the total number of keypoints found.
    private func getKeypointsFromOctaves() -> [Buffer] {
        var output = [Buffer]()
        measure(name: "getKeypointsFromOctaves") {
            for octave in octaves {
                output.append(octave.getKeypoints())
            }
        }
        let totalKeypoints = output.reduce(into: 0) { $0 += $1.count }
        logger.info("getKeypointsFromOctaves: Found \(totalKeypoints) keypoints")
        return output
    }

    /// Interpolates each octave's keypoints with the octave that produced them.
    private func interpolateKeypoints(keypointOctaves: [Buffer]) -> [[SIFTKeypoint]] {
        var output = [[SIFTKeypoint]]()
        measure(name: "interpolateKeypoints") {
            for (octave, keypoints) in zip(octaves, keypointOctaves) {
                output.append(octave.interpolateKeypoints(
                    commandQueue: commandQueue,
                    keypoints: keypoints
                ))
            }
        }
        return output
    }


    // MARK: Descriptors

    /// Computes descriptors for previously detected keypoints.
    ///
    /// - Parameter keypointOctaves: Keypoints grouped per octave; must contain
    ///   exactly one entry per octave of this instance (as returned by
    ///   `getKeypoints(_:)`).
    /// - Returns: Descriptors grouped per octave, in the same order.
    public func getDescriptors(keypointOctaves: [[SIFTKeypoint]]) -> [[SIFTDescriptor]] {
        precondition(
            keypointOctaves.count == octaves.count,
            "Expected one keypoint list per octave"
        )

        // Get all orientations for all keypoints.
        var orientationOctaves = [[SIFTKeypointOrientations]]()
        measure(name: "getDescriptors(orientations)") {
            for (octave, keypoints) in zip(octaves, keypointOctaves) {
                let orientationOctave = octave.getKeypointOrientations(
                    commandQueue: commandQueue,
                    keypoints: keypoints
                )
                orientationOctaves.append(orientationOctave)
            }
        }

        // Get descriptors for each orientation.
        var output: [[SIFTDescriptor]] = []
        measure(name: "getDescriptors(descriptors)") {
            for (octave, orientationOctave) in zip(octaves, orientationOctaves) {
                let descriptors = octave.getDescriptors(
                    commandQueue: commandQueue,
                    orientationOctave: orientationOctave
                )
                output.append(descriptors)
            }
        }
        return output
    }
}

--------------------------------------------------------------------------------
/Sources/SIFTMetal/Utilities/SIFTRenderer.swift:
--------------------------------------------------------------------------------
import UIKit
import CoreGraphics


/// Renders debugging visualisations of SIFT keypoints, descriptors and matches
/// on top of the images they were computed from.
public final class SIFTRenderer {

    public init() {

    }

    /// Draws `sourceImage` darkened by `overlayColor`, then one circle of
    /// radius `sigma` per keypoint: reference keypoints first, found keypoints
    /// on top using the screen blend mode.
    public func drawKeypoints(
        sourceImage: CGImage,
        overlayColor: UIColor = UIColor.black.withAlphaComponent(0.8),
        referenceColor: UIColor = UIColor.red,
        foundColor: UIColor = UIColor.green,
        referenceKeypoints: [SIFTKeypoint],
        foundKeypoints: [SIFTKeypoint]
    ) -> UIImage {

        let bounds = CGRect(x: 0, y: 0, width: sourceImage.width, height: sourceImage.height)

        let renderer = UIGraphicsImageRenderer(size: bounds.size)
        return renderer.image { context in
            let cgContext = context.cgContext

            drawBackdrop(
                cgContext: cgContext,
                bounds: bounds,
                overlayColor: overlayColor,
                images: [(image: sourceImage, frame: bounds)]
            )

            // Reference keypoints keep the context's current blend mode.
            strokeKeypoints(
                cgContext: cgContext,
                color: referenceColor,
                keypoints: referenceKeypoints
            )

            strokeKeypoints(
                cgContext: cgContext,
                color: foundColor,
                blendMode: .screen,
                keypoints: foundKeypoints
            )
        }
    }


    /// Draws `sourceImage` darkened by `overlayColor`, then the reference
    /// descriptors followed by the found descriptors (circle plus primary
    /// orientation line for each).
    func drawDescriptors(
        sourceImage: CGImage,
        overlayColor: UIColor = UIColor.black.withAlphaComponent(0.8),
        referenceColor: UIColor = UIColor.green,
        foundColor: UIColor = UIColor.red,
        referenceDescriptors: [SIFTDescriptor],
        foundDescriptors: [SIFTDescriptor]
    ) -> UIImage {

        let bounds = CGRect(x: 0, y: 0, width: sourceImage.width, height: sourceImage.height)

        let renderer = UIGraphicsImageRenderer(size: bounds.size)
        return renderer.image { context in
            let cgContext = context.cgContext

            drawBackdrop(
                cgContext: cgContext,
                bounds: bounds,
                overlayColor: overlayColor,
                images: [(image: sourceImage, frame: bounds)]
            )

            drawDescriptors(cgContext: cgContext, color: referenceColor, descriptors: referenceDescriptors)

            drawDescriptors(cgContext: cgContext, color: foundColor, descriptors: foundDescriptors)
        }
    }


    /// Draws the source and target images side by side (source on the left),
    /// the matched descriptors of each, and a line joining every matched pair.
    ///
    /// Both images must have identical pixel dimensions. Every tenth match is
    /// drawn with a thicker, more opaque line.
    public func drawMatches(
        sourceImage: CGImage,
        targetImage: CGImage,
        overlayColor: UIColor = UIColor.black.withAlphaComponent(0.8),
        sourceColor: UIColor = UIColor.cyan,
        targetColor: UIColor = UIColor.yellow,
        matches: [SIFTCorrespondence]
    ) -> UIImage {

        precondition(targetImage.width == sourceImage.width)
        precondition(targetImage.height == sourceImage.height)
        let imageSize = CGSize(width: sourceImage.width, height: sourceImage.height)
        let bounds = CGRect(x: 0, y: 0, width: imageSize.width * 2, height: imageSize.height)

        let sourceOffset = CGPoint(x: 0, y: 0)
        let targetOffset = CGPoint(x: imageSize.width, y: 0)
        let sourceBounds = CGRect(origin: sourceOffset, size: imageSize)
        let targetBounds = CGRect(origin: targetOffset, size: imageSize)

        let sourceDescriptors = matches.map { $0.source }
        let targetDescriptors = matches.map { $0.target }

        let colors: [UIColor] = [
            .systemRed, .systemOrange, .systemYellow, .systemGreen, .systemTeal, .systemBlue, .systemPurple, .systemIndigo
        ]

        let renderer = UIGraphicsImageRenderer(size: bounds.size)
        return renderer.image { context in
            let cgContext = context.cgContext

            // Images and overlay
            drawBackdrop(
                cgContext: cgContext,
                bounds: bounds,
                overlayColor: overlayColor,
                images: [
                    (image: sourceImage, frame: sourceBounds),
                    (image: targetImage, frame: targetBounds),
                ]
            )

            // Descriptors
            drawDescriptors(cgContext: cgContext, at: sourceOffset, color: sourceColor, descriptors: sourceDescriptors)
            drawDescriptors(cgContext: cgContext, at: targetOffset, color: targetColor, descriptors: targetDescriptors)

            cgContext.saveGState()
            cgContext.setLineWidth(1)
            cgContext.setBlendMode(.screen)

            for (index, match) in matches.enumerated() {
                // Cycle through the whole palette. The previous modulus of
                // `colors.count - 1` never selected the last color.
                let color = colors[index % colors.count]

                let ks = match.source.keypoint.absoluteCoordinate
                let kt = match.target.keypoint.absoluteCoordinate
                let ps = CGPoint(
                    x: sourceOffset.x + CGFloat(ks.x),
                    y: sourceOffset.y + CGFloat(ks.y)
                )
                let pt = CGPoint(
                    x: targetOffset.x + CGFloat(kt.x),
                    y: targetOffset.y + CGFloat(kt.y)
                )

                cgContext.move(to: ps)
                cgContext.addLine(to: pt)

                if index % 10 == 0 {
                    cgContext.setLineWidth(2)
                    cgContext.setStrokeColor(color.withAlphaComponent(0.5).cgColor)
                }
                else {
                    cgContext.setLineWidth(0.5)
                    cgContext.setStrokeColor(color.withAlphaComponent(0.3).cgColor)
                }

                // Stroke per match because each line has its own style.
                cgContext.strokePath()
            }

            cgContext.restoreGState()
        }
    }


    /// Draws each image into its frame with the y axis flipped over
    /// `bounds.height`, then multiplies `overlayColor` over all of `bounds`.
    private func drawBackdrop(
        cgContext: CGContext,
        bounds: CGRect,
        overlayColor: UIColor,
        images: [(image: CGImage, frame: CGRect)]
    ) {
        cgContext.saveGState()
        cgContext.scaleBy(x: 1, y: -1)
        cgContext.translateBy(x: 0, y: -bounds.height)
        for entry in images {
            cgContext.draw(entry.image, in: entry.frame)
        }
        cgContext.restoreGState()

        cgContext.saveGState()
        cgContext.setBlendMode(.multiply)
        cgContext.setFillColor(overlayColor.cgColor)
        cgContext.fill([bounds])
        cgContext.restoreGState()
    }


    /// Strokes one 1pt circle of radius `sigma` around every keypoint. The
    /// blend mode is only changed when `blendMode` is provided.
    private func strokeKeypoints(
        cgContext: CGContext,
        color: UIColor,
        blendMode: CGBlendMode? = nil,
        keypoints: [SIFTKeypoint]
    ) {
        cgContext.saveGState()
        cgContext.setLineWidth(1)
        cgContext.setStrokeColor(color.cgColor)
        if let blendMode = blendMode {
            cgContext.setBlendMode(blendMode)
        }
        for keypoint in keypoints {
            let radius = CGFloat(keypoint.sigma)
            let frame = CGRect(
                x: CGFloat(keypoint.absoluteCoordinate.x) - radius,
                y: CGFloat(keypoint.absoluteCoordinate.y) - radius,
                width: radius * 2,
                height: radius * 2
            )
            cgContext.addEllipse(in: frame)
        }
        cgContext.strokePath()
        cgContext.restoreGState()
    }


    /// Strokes, for every descriptor, a circle of radius `1.5 * sigma` centered
    /// on the keypoint (shifted by `offset`) and a line from the center to the
    /// circle in the direction of `theta`.
    private func drawDescriptors(
        cgContext: CGContext,
        at offset: CGPoint = .zero,
        color: UIColor,
        descriptors: [SIFTDescriptor]
    ) {
        cgContext.saveGState()
        cgContext.setLineWidth(0.5)
        cgContext.setStrokeColor(color.cgColor)
        cgContext.setBlendMode(.screen)
        for descriptor in descriptors {
            let keypoint = descriptor.keypoint
            let radius = 1.5 * CGFloat(keypoint.sigma)
            let center = CGPoint(
                x: offset.x + CGFloat(keypoint.absoluteCoordinate.x),
                y: offset.y + CGFloat(keypoint.absoluteCoordinate.y)
            )
            let frame = CGRect(
                x: center.x - radius,
                y: center.y - radius,
                width: radius * 2,
                height: radius * 2
            )
            cgContext.addEllipse(in: frame)

            // Primary Orientation
            cgContext.move(to: center)
            cgContext.addLine(
                to: CGPoint(
                    x: center.x + cos(CGFloat(descriptor.theta)) * radius,
                    y: center.y + sin(CGFloat(descriptor.theta)) * radius
                )
            )
        }
        cgContext.strokePath()
        cgContext.restoreGState()
    }
}

--------------------------------------------------------------------------------
/Sources/SIFTMetal/SIFT/SIFTDescriptor.swift:
--------------------------------------------------------------------------------
//
//  SIFTDescriptor.swift
//  SkyLight
//
//  Created by Luke Van In on 2023/01/02.
//

import Foundation
import simd


public struct SIFTDescriptor {

    // Detected keypoint.
    public let keypoint: SIFTKeypoint
    // Principal orientation.
public let theta: Float
// Quantized features
public let features: IntVector

// `features` rescaled by 1/255.
public let rawFeatures: FloatVector

// Mean of each 8-component patch, in the reordered patch order used below.
private let indexKey: FloatVector
// All patch components concatenated in the reordered patch order.
private let indexValue: FloatVector

/// Creates a descriptor and derives its floating-point and index vectors.
///
/// The feature vector is cut into consecutive patches of 8 components. The
/// first 16 patches are reordered (center, corners, edges of a 4x4 grid) to
/// build `indexValue`, and the mean of each patch builds `indexKey`.
///
/// - Parameters:
///   - keypoint: Keypoint the descriptor was computed for.
///   - theta: Principal orientation.
///   - features: Quantized features; must yield at least 16 patches of 8
///     components.
public init(
    keypoint: SIFTKeypoint,
    theta: Float,
    features: IntVector
) {
    precondition(features.count > 0)
    self.keypoint = keypoint
    self.theta = theta
    self.features = features

    let rawFeatures = FloatVector(
        features.components.map { Float($0) / Float(255) }
    )
    self.rawFeatures = rawFeatures

    let indexFeatures: [FloatVector] = {
        let patchSize = 8
        let patches = stride(from: 0, to: rawFeatures.count, by: patchSize)
            .map { start -> FloatVector in
                let end = min(start + patchSize, rawFeatures.count)
                return FloatVector(Array(rawFeatures.components[start ..< end]))
            }
        // The reordering below reads patches 0...15. Fail with a clear
        // message instead of an out-of-range subscript on short input.
        precondition(
            patches.count >= 16,
            "SIFTDescriptor requires at least 16 feature patches of 8 components"
        )
        let order = [
            // Center
            5, 6, 9, 10,
            // Corners
            0, 3, 12, 15,
            // Edges
            1, 2, 4, 7, 8, 11, 13, 14,
        ]
        return order.map { patches[$0] }
    }()

    self.indexValue = FloatVector(
        Array(indexFeatures.map { $0.components }.joined())
    )

    self.indexKey = FloatVector(
        indexFeatures.map { $0.mean() }
    )
}

/// Distance between the reordered feature vectors of two descriptors. Both
/// descriptors must carry exactly 128 index components.
private static func distance(_ a: SIFTDescriptor, _ b: SIFTDescriptor) -> Float {
    precondition(a.indexValue.count == 128)
    precondition(b.indexValue.count == 128)
    return a.indexValue.distance(to: b.indexValue)
}

/// Matches `source` against `target` and scores how consistent the geometry
/// of the matched keypoints is.
///
/// - Returns: `0` when fewer than 7 matches are found; otherwise the score
///   computed by `compareGeometry` over at most the first 80 matches.
public static func matchGeometry(
    source: [SIFTDescriptor],
    target: [SIFTDescriptor],
    absoluteThreshold: Float = 1.176,
    relativeThreshold: Float = 0.6
) -> Float {
    let minimumSampleSize = 7
    let maximumSampleSize = max(minimumSampleSize, 80)

    let matches = match(
        source: source,
        target: target,
        absoluteThreshold: absoluteThreshold,
        relativeThreshold: relativeThreshold
    )
    guard matches.count >= minimumSampleSize else {
        print("matchGeometry: rejected: Not enough matches: \(matches.count) out of \(minimumSampleSize)")
        return 0
    }

    let sample = Array(matches.prefix(maximumSampleSize))
    return compareGeometry(
        matches: sample,
        minimumSampleSize: minimumSampleSize
    )
}

/// Coordinate used for geometric comparison: the keypoint's absolute
/// coordinate.
private static func makeCoordinate(_ keypoint: SIFTKeypoint) -> SIMD2<Float> {
    return keypoint.absoluteCoordinate
}

/// Dot product remapped from [-1, 1] to [0, 1] and clamped to that range.
private static func dotProduct(_ a: SIMD2<Float>, _ b: SIMD2<Float>) -> Float {
    return simd_clamp((simd_dot(a, b) * 0.5) + 0.5, 0, 1)
}

private static func compareGeometry(
    matches: [SIFTCorrespondence],
    minimumSampleSize: Int
) -> Float {

    print("compareGeometry:", "samples", matches.count, "minimum samples", minimumSampleSize)

    let minimumLength: Float = 2

    var sum: Float = 0
    var count: Int = 0
    var scores: [Float] = []
    for i in stride(from: 0, to: matches.count - 3, by: 1) {

//            if count >= minimumSampleSize {
//                break
//            }

        let m0 = matches[i + 0]
        let m1 = matches[i + 1]

        let sourceBase = makeCoordinate(m1.source.keypoint) - makeCoordinate(m0.source.keypoint)
        let targetBase = makeCoordinate(m1.target.keypoint) - makeCoordinate(m0.target.keypoint)

        let sourceBaseLength = simd_length(sourceBase)
        let targetBaseLength = simd_length(targetBase)

        guard sourceBaseLength >= minimumLength else {
            continue
        }

        guard targetBaseLength >= minimumLength else {
            continue
        }

        let sourceBaseNormal = simd_normalize(sourceBase)
        let targetBaseNormal = simd_normalize(targetBase)

        let m2 = matches[i + 2]
        let m3 = matches[i + 3]
        let sourceTest = makeCoordinate(m3.source.keypoint) - makeCoordinate(m2.source.keypoint)
let targetTest = makeCoordinate(m3.target.keypoint) - makeCoordinate(m2.target.keypoint) 205 | 206 | let sourceTestLength = simd_length(sourceTest) 207 | let targetTestLength = simd_length(targetTest) 208 | 209 | guard sourceTestLength >= minimumLength else { 210 | continue 211 | } 212 | 213 | guard targetTestLength >= minimumLength else { 214 | continue 215 | } 216 | 217 | let sourceTestNormal = simd_normalize(sourceTest) 218 | let targetTestNormal = simd_normalize(targetTest) 219 | 220 | let sourceRatio = sourceTestLength / sourceBaseLength 221 | let targetRatio = targetTestLength / targetBaseLength 222 | 223 | let sourceDotProduct = dotProduct(sourceTestNormal, sourceBaseNormal) 224 | let targetDotProduct = dotProduct(targetTestNormal, targetBaseNormal) 225 | 226 | precondition(sourceDotProduct >= 0) 227 | precondition(sourceDotProduct <= 1) 228 | precondition(targetDotProduct >= 0) 229 | precondition(targetDotProduct <= 1) 230 | 231 | let orientationSimilarity = 1.0 - abs(sourceDotProduct - targetDotProduct) 232 | precondition(orientationSimilarity >= 0) 233 | precondition(orientationSimilarity <= 1) 234 | 235 | let scaleSimilarity: Float 236 | if sourceRatio < targetRatio { 237 | scaleSimilarity = simd_clamp(sourceRatio / targetRatio, 0, 1) 238 | } 239 | else { 240 | scaleSimilarity = simd_clamp(targetRatio / sourceRatio, 0, 1) 241 | } 242 | precondition(scaleSimilarity >= 0) 243 | precondition(scaleSimilarity <= 1) 244 | 245 | let similarity = orientationSimilarity * scaleSimilarity 246 | let score = similarity * similarity 247 | scores.append(score) 248 | sum += score 249 | count += 1 250 | } 251 | 252 | guard count >= minimumSampleSize else { 253 | return 0 254 | } 255 | 256 | let mean = sum / Float(count) 257 | 258 | var error: Float = 0 259 | for score in scores { 260 | let delta = score - mean 261 | error += (delta * delta) 262 | } 263 | let variance = error / Float(count - 1) 264 | let standardDeviation = sqrt(variance) 265 | 266 | var zscores: [Float] 
= [] 267 | for score in scores { 268 | let zscore = abs((score - mean) / standardDeviation) 269 | zscores.append(zscore) 270 | } 271 | 272 | var fairMeanSum: Float = 0 273 | var fairMeanCount: Float = 0 274 | for i in 0 ..< scores.count { 275 | let zscore = zscores[i] 276 | if zscore <= 2 { 277 | let score = scores[i] 278 | fairMeanSum += score 279 | fairMeanCount += 1 280 | } 281 | } 282 | let fairMean = fairMeanSum / fairMeanCount 283 | 284 | print( 285 | "compareGeometry", 286 | "count", count, 287 | "mean", mean, 288 | "fairMean", fairMean, 289 | "variance", variance, 290 | "standard deviation", standardDeviation, 291 | "scores", scores, 292 | "zscores", zscores 293 | ) 294 | 295 | return fairMean 296 | } 297 | 298 | public static func match( 299 | source: [SIFTDescriptor], 300 | target: [SIFTDescriptor], 301 | absoluteThreshold: Float = 1.176, 302 | relativeThreshold: Float = 0.6 303 | ) -> [SIFTCorrespondence] { 304 | var output = [SIFTCorrespondence]() 305 | 306 | for descriptor in source { 307 | let correspondence = match( 308 | descriptor: descriptor, 309 | target: target, 310 | absoluteThreshold: absoluteThreshold, 311 | relativeThreshold: relativeThreshold 312 | ) 313 | if let correspondence { 314 | output.append(correspondence) 315 | } 316 | } 317 | return output 318 | } 319 | 320 | public static func match( 321 | descriptor: SIFTDescriptor, 322 | target: [SIFTDescriptor], 323 | absoluteThreshold: Float = 300, 324 | relativeThreshold: Float = 0.6 325 | ) -> SIFTCorrespondence? { 326 | 327 | var bestMatch: SIFTDescriptor? 328 | var bestMatchDistance: Float = .greatestFiniteMagnitude 329 | var secondBestMatchDistance: Float? 
330 | 331 | for t in target { 332 | let distance = descriptor.indexValue.distance(to: t.indexValue) 333 | if distance < bestMatchDistance { 334 | bestMatch = t 335 | secondBestMatchDistance = bestMatchDistance 336 | bestMatchDistance = distance 337 | } 338 | } 339 | 340 | guard let bestMatch else { 341 | return nil 342 | } 343 | 344 | guard let secondBestMatchDistance else { 345 | return nil 346 | } 347 | 348 | guard bestMatchDistance < absoluteThreshold else { 349 | return nil 350 | } 351 | 352 | guard bestMatchDistance < (secondBestMatchDistance * relativeThreshold) else { 353 | return nil 354 | } 355 | 356 | return SIFTCorrespondence( 357 | source: descriptor, 358 | target: bestMatch, 359 | featureDistance: bestMatchDistance 360 | ) 361 | } 362 | public static func approximateMatch( 363 | source: [SIFTDescriptor], 364 | target: [SIFTDescriptor], 365 | absoluteThreshold: Float = 300, 366 | relativeThreshold: Float = 0.6 367 | ) -> [SIFTCorrespondence] { 368 | var output = [SIFTCorrespondence]() 369 | 370 | let trie = Trie(numberOfBins: 8) 371 | for descriptor in target { 372 | trie.insert(key: descriptor.indexKey, value: descriptor) 373 | } 374 | trie.link() 375 | 376 | for descriptor in source { 377 | let correspondence = approximateMatch( 378 | descriptor: descriptor, 379 | target: trie, 380 | absoluteThreshold: absoluteThreshold, 381 | relativeThreshold: relativeThreshold 382 | ) 383 | if let correspondence { 384 | output.append(correspondence) 385 | } 386 | } 387 | return output 388 | } 389 | 390 | public static func approximateMatch( 391 | descriptor: SIFTDescriptor, 392 | target: Trie, 393 | absoluteThreshold: Float = 300, 394 | relativeThreshold: Float = 0.6 395 | ) -> SIFTCorrespondence? 
{ 396 | let matches = target.nearest(key: descriptor.indexKey, query: descriptor, radius: 10, k: 2) 397 | guard matches.count == 2 else { 398 | return nil 399 | } 400 | 401 | let bestMatch = matches[0] 402 | let secondBestMatch = matches[1] 403 | 404 | guard bestMatch.distance < absoluteThreshold else { 405 | return nil 406 | } 407 | 408 | guard bestMatch.distance < (secondBestMatch.distance * relativeThreshold) else { 409 | return nil 410 | } 411 | 412 | return SIFTCorrespondence( 413 | source: descriptor, 414 | target: bestMatch.value, 415 | featureDistance: bestMatch.distance 416 | ) 417 | } 418 | } 419 | 420 | extension SIFTDescriptor { 421 | 422 | func distance(to other: SIFTDescriptor) -> Float { 423 | features.distance(to: other.features) 424 | } 425 | 426 | func distanceSquared(to other: SIFTDescriptor) -> Float { 427 | Float(features.distanceSquared(to: other.features)) 428 | } 429 | } 430 | --------------------------------------------------------------------------------