├── Sources
├── MetalShaders
│ ├── MetalShaders.c
│ ├── module.modulemap
│ ├── include
│ │ ├── MetalShaders.h
│ │ ├── NearestNeighbor.h
│ │ ├── SIFTExtrema.h
│ │ ├── ConvolutionSeries.h
│ │ ├── SIFTOrientation.h
│ │ ├── SIFTDescriptor.h
│ │ └── SIFTInterpolate.h
│ └── Metal
│ │ ├── Subtract.metal
│ │ ├── NearestNeighborDownScale.metal
│ │ ├── NearestNeighborUpScale.metal
│ │ ├── ConvertSRGBToGrayscale.metal
│ │ ├── SIFTGradient.metal
│ │ ├── Convolution.metal
│ │ ├── ConvolutionSeries.metal
│ │ ├── BilinearUpScale.metal
│ │ ├── Common.hpp
│ │ ├── SIFTExtrema.metal
│ │ ├── SIFTOrientation.metal
│ │ ├── SIFTDescriptor.metal
│ │ └── SIFTInterpolate.metal
└── SIFTMetal
│ ├── SIFT
│ ├── SIFTKeypointOrientations.swift
│ ├── SIFTCorrespondence.swift
│ ├── SIFTKeypoint.swift
│ ├── SIFTPatch.swift
│ ├── SIFT.swift
│ └── SIFTDescriptor.swift
│ ├── Extensions
│ ├── CoreGraphicsExtensions.swift
│ ├── FoundationExtensions.swift
│ └── CoreImageExtensions.swift
│ ├── Utilities
│ ├── Performance.swift
│ ├── Math.swift
│ ├── MetalExtensions.swift
│ ├── CoreVideoMetalCache.swift
│ ├── ImageConversion.swift
│ ├── Buffer.swift
│ ├── Quad.swift
│ ├── Image.swift
│ └── SIFTRenderer.swift
│ └── Metal Compute
│ ├── NearestNeighborUpScaleKernel.swift
│ ├── SIFTDescriptorKernel.swift
│ ├── BilinearUpScaleKernel.swift
│ ├── SIFTExtremaKernel.swift
│ ├── ConvertSRGBToGrayscaleKernel.swift
│ ├── SIFTExtremaListKernel.swift
│ ├── SubtractKernel.swift
│ ├── SIFTGradientKernel.swift
│ ├── SIFTInterpolateKernel.swift
│ ├── SIFTOrientationKernel.swift
│ ├── GaussianKernel.swift
│ ├── NearestNeighborDownScaleKernel.swift
│ ├── Convolution1DKernel.swift
│ ├── ConvolutionSeriesKernel.swift
│ └── GaussianSeriesKernel.swift
├── Tests
└── SIFTMetalTests
│ ├── Resources
│ ├── butterfly.png
│ ├── DoG_butterfly_o000_s001.png
│ ├── DoG_butterfly_o000_s002.png
│ ├── DoG_butterfly_o000_s003.png
│ ├── DoG_butterfly_o001_s001.png
│ ├── DoG_butterfly_o001_s002.png
│ ├── DoG_butterfly_o001_s003.png
│ ├── DoG_butterfly_o002_s001.png
│ ├── DoG_butterfly_o002_s002.png
│ ├── DoG_butterfly_o002_s003.png
│ ├── DoG_butterfly_o003_s001.png
│ ├── DoG_butterfly_o003_s002.png
│ ├── DoG_butterfly_o003_s003.png
│ ├── DoG_butterfly_o004_s001.png
│ ├── DoG_butterfly_o004_s002.png
│ ├── DoG_butterfly_o004_s003.png
│ ├── scalespace_butterfly_o000_s000.png
│ ├── scalespace_butterfly_o000_s001.png
│ ├── scalespace_butterfly_o000_s002.png
│ ├── scalespace_butterfly_o000_s003.png
│ ├── scalespace_butterfly_o000_s004.png
│ ├── scalespace_butterfly_o000_s005.png
│ ├── scalespace_butterfly_o001_s000.png
│ ├── scalespace_butterfly_o001_s001.png
│ ├── scalespace_butterfly_o001_s002.png
│ ├── scalespace_butterfly_o001_s003.png
│ ├── scalespace_butterfly_o001_s004.png
│ ├── scalespace_butterfly_o001_s005.png
│ ├── scalespace_butterfly_o002_s000.png
│ ├── scalespace_butterfly_o002_s001.png
│ ├── scalespace_butterfly_o002_s002.png
│ ├── scalespace_butterfly_o002_s003.png
│ ├── scalespace_butterfly_o002_s004.png
│ ├── scalespace_butterfly_o002_s005.png
│ ├── scalespace_butterfly_o003_s000.png
│ ├── scalespace_butterfly_o003_s001.png
│ ├── scalespace_butterfly_o003_s002.png
│ ├── scalespace_butterfly_o003_s003.png
│ ├── scalespace_butterfly_o003_s004.png
│ ├── scalespace_butterfly_o003_s005.png
│ ├── scalespace_butterfly_o004_s000.png
│ ├── scalespace_butterfly_o004_s001.png
│ ├── scalespace_butterfly_o004_s002.png
│ ├── scalespace_butterfly_o004_s003.png
│ ├── scalespace_butterfly_o004_s004.png
│ └── scalespace_butterfly_o004_s005.png
│ ├── GaussianDifferenceTests.swift
│ ├── KeypointTests.swift
│ ├── SharedTestCase.swift
│ ├── TrieTests.swift
│ ├── DescriptorTests.swift
│ └── DifferenceOfGaussiansTests.swift
├── .gitignore
├── .swiftpm
└── xcode
│ ├── package.xcworkspace
│ └── xcshareddata
│ │ └── IDEWorkspaceChecks.plist
│ └── xcshareddata
│ └── xcschemes
│ └── SIFTMetal.xcscheme
├── LICENSE
├── README.md
└── Package.swift
/Sources/MetalShaders/MetalShaders.c:
--------------------------------------------------------------------------------
1 | // File intentionally left blank
2 |
--------------------------------------------------------------------------------
/Sources/MetalShaders/module.modulemap:
--------------------------------------------------------------------------------
1 | module Shaders {
2 | header "MetalShaders.h"
3 | export *
4 | }
5 |
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/butterfly.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/butterfly.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/DoG_butterfly_o000_s001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o000_s001.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/DoG_butterfly_o000_s002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o000_s002.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/DoG_butterfly_o000_s003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o000_s003.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/DoG_butterfly_o001_s001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o001_s001.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/DoG_butterfly_o001_s002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o001_s002.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/DoG_butterfly_o001_s003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o001_s003.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/DoG_butterfly_o002_s001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o002_s001.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/DoG_butterfly_o002_s002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o002_s002.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/DoG_butterfly_o002_s003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o002_s003.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/DoG_butterfly_o003_s001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o003_s001.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/DoG_butterfly_o003_s002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o003_s002.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/DoG_butterfly_o003_s003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o003_s003.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/DoG_butterfly_o004_s001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o004_s001.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/DoG_butterfly_o004_s002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o004_s002.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/DoG_butterfly_o004_s003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/DoG_butterfly_o004_s003.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o000_s000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o000_s000.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o000_s001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o000_s001.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o000_s002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o000_s002.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o000_s003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o000_s003.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o000_s004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o000_s004.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o000_s005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o000_s005.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o001_s000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o001_s000.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o001_s001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o001_s001.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o001_s002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o001_s002.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o001_s003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o001_s003.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o001_s004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o001_s004.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o001_s005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o001_s005.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o002_s000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o002_s000.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o002_s001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o002_s001.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o002_s002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o002_s002.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o002_s003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o002_s003.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o002_s004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o002_s004.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o002_s005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o002_s005.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o003_s000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o003_s000.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o003_s001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o003_s001.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o003_s002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o003_s002.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o003_s003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o003_s003.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o003_s004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o003_s004.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o003_s005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o003_s005.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o004_s000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o004_s000.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o004_s001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o004_s001.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o004_s002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o004_s002.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o004_s003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o004_s003.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o004_s004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o004_s004.png
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o004_s005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukevanin/SIFTMetal/HEAD/Tests/SIFTMetalTests/Resources/scalespace_butterfly_o004_s005.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | /.build
3 | /Packages
4 | /*.xcodeproj
5 | xcuserdata/
6 | DerivedData/
7 | .swiftpm/config/registries.json
8 | .swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata
9 | .netrc
10 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/SIFT/SIFTKeypointOrientations.swift:
--------------------------------------------------------------------------------
1 | //
2 | // SIFTKeypointOrientations.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/10.
6 | //
7 |
8 | import Foundation
9 |
10 | struct SIFTKeypointOrientations { // Pairs one keypoint with the orientation values stored for it.
11 | let keypoint: SIFTKeypoint // The keypoint the orientations belong to.
12 | let orientations: [Float] // Orientation values for `keypoint`; presumably angles in radians — TODO confirm against the orientation kernel.
13 | }
14 |
--------------------------------------------------------------------------------
/.swiftpm/xcode/package.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist:
--------------------------------------------------------------------------------
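1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
3 | <plist version="1.0">
4 | <dict>
5 | <key>IDEDidComputeMac32BitWarning</key>
6 | <true/>
7 | </dict>
8 | </plist>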
9 |
--------------------------------------------------------------------------------
/Sources/MetalShaders/include/MetalShaders.h:
--------------------------------------------------------------------------------
1 | //
2 | // Use this file to import your target's public headers that you would like to expose to Swift.
3 | //
4 |
5 |
6 | #import "SIFTExtrema.h"
7 | #import "SIFTInterpolate.h"
8 | #import "SIFTOrientation.h"
9 | #import "SIFTDescriptor.h"
10 | #import "ConvolutionSeries.h"
11 | #import "NearestNeighbor.h"
12 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Extensions/CoreGraphicsExtensions.swift:
--------------------------------------------------------------------------------
1 | //
2 | // CoreGraphicsExtensions.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2022/12/17.
6 | //
7 |
8 | import CoreGraphics
9 | import simd
10 |
11 |
12 | extension CGPoint {
13 | init(_ point: simd_float2) { // Builds a CGPoint from a two-component SIMD float vector.
14 | self.init(x: CGFloat(point.x), y: CGFloat(point.y)) // Convert each Float component to CGFloat.
15 | }
16 | }
17 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/SIFT/SIFTCorrespondence.swift:
--------------------------------------------------------------------------------
1 | //
2 | // SIFTCorrespondence.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/02.
6 | //
7 |
8 | import Foundation
9 |
10 |
11 | public struct SIFTCorrespondence { // A pair of descriptors together with a distance value for the pair.
12 |
13 | public var source: SIFTDescriptor // Descriptor on the source side of the pair.
14 | public var target: SIFTDescriptor // Descriptor on the target side of the pair.
15 | public var featureDistance: Float // Distance between the two descriptors; the metric is not visible here — TODO confirm.
16 | }
17 |
--------------------------------------------------------------------------------
/Sources/MetalShaders/include/NearestNeighbor.h:
--------------------------------------------------------------------------------
1 | //
2 | // NearestNeighbor.h
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/08.
6 | //
7 |
8 | #include
9 |
10 | #ifndef NearestNeighbor_h
11 | #define NearestNeighbor_h
12 |
13 | struct NearestNeighborScaleParameters { // Slice selection passed to the nearest-neighbour down-scale kernel.
14 | int32_t inputSlice; // Texture-array slice the kernel reads from.
15 | int32_t outputSlice; // Texture-array slice the kernel writes to.
16 | };
17 |
18 | #endif /* NearestNeighbor_h */
19 |
--------------------------------------------------------------------------------
/Sources/MetalShaders/include/SIFTExtrema.h:
--------------------------------------------------------------------------------
1 | //
2 | // SIFTExtrema.h
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/10.
6 | //
7 |
8 | #include
9 |
10 | #ifndef SIFTExtrema_h
11 | #define SIFTExtrema_h
12 |
13 | // The field order and types here should match SIFTInterpolateInputKeypoint.
14 | struct SIFTExtremaResult { // Integer location of one extremum candidate.
15 | int32_t x; // Horizontal coordinate; presumably in pixels of the octave's texture — TODO confirm.
16 | int32_t y; // Vertical coordinate; same assumption as x.
17 | int32_t scale; // Scale index; presumably the slice within the octave — TODO confirm.
18 | };
19 |
20 | #endif /* SIFTExtrema_h */
21 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Extensions/FoundationExtensions.swift:
--------------------------------------------------------------------------------
1 | //
2 | // FoundationExtensions.swift
3 | //
4 | //
5 | // Created by Luke Van In on 2023/01/24.
6 | //
7 |
8 | import Foundation
9 |
10 | extension Bundle {
11 |     /// The `SIFTMetal_MetalShaders` resource bundle, looked up inside the bundle that contains `SIFT`.
12 |     // See: https://developer.apple.com/forums/thread/649579
13 |     static var metalShaders: Bundle {
14 |         let hostBundle = Bundle(for: SIFT.self)
15 |         let shadersURL = hostBundle.url(forResource: "SIFTMetal_MetalShaders", withExtension: "bundle")!
16 |         return Bundle(url: shadersURL)!
17 |     }
18 | }
19 |
--------------------------------------------------------------------------------
/Sources/MetalShaders/include/ConvolutionSeries.h:
--------------------------------------------------------------------------------
1 | //
2 | // ConvolutionSeries.h
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/08.
6 | //
7 |
8 | #include
9 |
10 | #ifndef ConvolutionSeries_h
11 | #define ConvolutionSeries_h
12 |
13 | #define CONVOLUTION_WEIGHTS_LENGTH 32
14 |
15 |
16 | struct ConvolutionParameters { // Parameters for one convolution pass; meanings below are inferred from the names — TODO confirm against ConvolutionSeries.metal.
17 | int32_t inputDepth; // Presumably the texture-array slice to read.
18 | int32_t outputDepth; // Presumably the texture-array slice to write.
19 | int32_t count; // Presumably the number of entries of `weights` that are in use.
20 | float weights[CONVOLUTION_WEIGHTS_LENGTH]; // Fixed-capacity storage: CONVOLUTION_WEIGHTS_LENGTH (32) floats.
21 | };
22 |
23 |
24 | #endif /* ConvolutionSeries_h */
25 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Utilities/Performance.swift:
--------------------------------------------------------------------------------
1 | //
2 | // Performance.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/10.
6 | //
7 |
8 | import Foundation
9 | import OSLog
10 |
11 |
12 | private let signposter = OSSignposter(subsystem: Bundle.main.bundleIdentifier ?? "SIFTMetal", category: "performance") // Shared by `measure`. The fallback name avoids the crash the former force-unwrap caused in hosts whose main bundle has no identifier.
13 |
14 |
15 | func measure(name: StaticString, worker: () -> Void) {
16 |     // Wraps `worker` in a signpost interval called `name`; the interval is closed when this scope exits.
17 |     let interval = signposter.beginInterval(name, id: signposter.makeSignpostID())
18 |     defer { signposter.endInterval(name, interval) }
19 |     worker()
20 | }
21 |
--------------------------------------------------------------------------------
/Sources/MetalShaders/Metal/Subtract.metal:
--------------------------------------------------------------------------------
1 | //
2 | // Subtract.metal
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/07.
6 | //
7 |
8 | #include
9 | using namespace metal;
10 |
11 |
12 | kernel void subtract( // Writes, per pixel, the difference between two consecutive slices of the input array.
13 | texture2d_array outputTexture [[texture(0)]], // Destination array; slice gid.z receives the difference.
14 | texture2d_array inputTexture [[texture(1)]], // Source array; slices gid.z and gid.z + 1 are read.
15 | ushort3 gid [[thread_position_in_grid]] // xy = pixel coordinate, z = output slice index.
16 | ) {
17 | float4 a = inputTexture.read(gid.xy, gid.z + 1); // Pixel from the next slice. NOTE(review): the dispatch presumably covers one slice fewer than the input holds — confirm in SubtractKernel.swift.
18 | float4 b = inputTexture.read(gid.xy, gid.z); // Pixel from the current slice.
19 | float4 c = a - b; // Next slice minus current slice.
20 | outputTexture.write(c, gid.xy, gid.z);
21 | }
22 |
23 |
24 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Utilities/Math.swift:
--------------------------------------------------------------------------------
1 | //
2 | // Math.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2022/12/18.
6 | //
7 |
8 | import Foundation
9 |
10 |
11 | public struct IntegralSize { // A width/height pair stored as integers.
12 | public var width: Int // Horizontal extent.
13 | public var height: Int // Vertical extent.
14 |
15 | public init(width: Int, height: Int) { // Explicit public initializer that stores both values unchanged.
16 | self.width = width
17 | self.height = height
18 | }
19 | }
20 |
21 |
22 | func modulus(_ x: Float, _ y: Float) -> Float {
23 |     // Returns the remainder of x / y, wrapped into [0, y) when y is positive.
24 |     // A zero divisor or non-finite argument used to trap inside Int(...); NaN is returned instead.
25 |     guard x.isFinite, y.isFinite, y != 0 else { return .nan }
26 |     // Exact remainder carrying the sign of x; also safe when x / y is too large for Int.
27 |     let remainder = x.truncatingRemainder(dividingBy: y)
28 |     // For a negative divisor the previous arithmetic produced this sign-of-x remainder; keep it.
29 |     guard y > 0 else { return remainder }
30 |     let wrapped = remainder < 0 ? remainder + y : remainder
31 |     return wrapped < y ? wrapped : 0 // remainder + y can round up to exactly y for tiny negative remainders
32 | }
33 |
--------------------------------------------------------------------------------
/Sources/MetalShaders/Metal/NearestNeighborDownScale.metal:
--------------------------------------------------------------------------------
1 | //
2 | // NearestNeighborDownScale.metal
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/07.
6 | //
7 |
8 | #include
9 |
10 | #import "../include/NearestNeighbor.h"
11 |
12 | using namespace metal;
13 |
14 |
15 | kernel void nearestNeighborDownScale( // Copies every second input pixel into the output, between the slices named in `parameters`.
16 | texture2d_array outputTexture [[texture(0)]], // Destination array.
17 | texture2d_array inputTexture [[texture(1)]], // Source array.
18 | device NearestNeighborScaleParameters & parameters [[buffer(0)]], // Selects the input and output slices.
19 | ushort2 gid [[thread_position_in_grid]] // Output pixel coordinate.
20 | ) {
21 | outputTexture.write(inputTexture.read(gid * 2, parameters.inputSlice), gid, parameters.outputSlice); // Output pixel (x, y) takes input pixel (2x, 2y).
22 | }
23 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Utilities/MetalExtensions.swift:
--------------------------------------------------------------------------------
1 | //
2 | // MetalExtensions.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/10.
6 | //
7 |
8 | import Foundation
9 | import Metal
10 |
11 | func capture(commandQueue: MTLCommandQueue, capture: Bool = true, worker: () -> Void) {
12 |     // Runs `worker`, inside a GPU capture of `commandQueue` when `capture` is true; a capture that cannot start is reported and skipped.
13 |     guard capture else { return worker() }
14 |     let captureDescriptor = MTLCaptureDescriptor()
15 |     captureDescriptor.captureObject = commandQueue
16 |     captureDescriptor.destination = .developerTools
17 |     do { try MTLCaptureManager.shared().startCapture(with: captureDescriptor) } catch {
18 |         print("Metal capture could not start; running without capture: \(error)") // Previously `try!` crashed here.
19 |         return worker()
20 |     }
21 |     defer { MTLCaptureManager.shared().stopCapture() } // Runs after `worker` returns.
22 |     worker()
23 | }
24 |
25 |
--------------------------------------------------------------------------------
/Sources/MetalShaders/Metal/NearestNeighborUpScale.metal:
--------------------------------------------------------------------------------
1 | //
2 | // NearestNeighborUpScale.metal
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/07.
6 | //
7 |
8 | #include
9 | using namespace metal;
10 |
11 |
12 | kernel void nearestNeighborUpScale( // Fills the output by repeating input pixels, using an integer scale factor per axis.
13 | texture2d outputTexture [[texture(0)]], // Destination texture.
14 | texture2d inputTexture [[texture(1)]], // Source texture.
15 | ushort2 gid [[thread_position_in_grid]] // Output pixel coordinate.
16 | ) {
17 | ushort2 inputSize = ushort2(inputTexture.get_width(), inputTexture.get_height());
18 | ushort2 outputSize = ushort2(outputTexture.get_width(), outputTexture.get_height());
19 |
20 | ushort2 scale = outputSize / inputSize; // Integer division per axis. NOTE(review): this is 0 when the output is smaller than the input, which makes the division below undefined — confirm callers only ever upscale.
21 | outputTexture.write(inputTexture.read(gid / scale), gid); // Each output pixel copies the input pixel at gid / scale.
22 | }
23 |
--------------------------------------------------------------------------------
/Sources/MetalShaders/Metal/ConvertSRGBToGrayscale.metal:
--------------------------------------------------------------------------------
1 | //
2 | // ConvertSRGBToGrayscale.metal
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/07.
6 | //
7 |
8 | #include
9 | using namespace metal;
10 |
11 | kernel void convertSRGBToGrayscale( // Replaces r, g and b with one weighted-sum intensity and keeps alpha.
12 | texture2d outputTexture [[texture(0)]], // Destination texture.
13 | texture2d inputTexture [[texture(1)]], // Source texture.
14 | ushort2 gid [[thread_position_in_grid]] // Pixel coordinate, shared by input and output.
15 | ) {
16 | const float4 input = inputTexture.read(gid); // NOTE(review): no gamma decoding happens in this kernel; whether the values are already linear depends on the texture format — confirm.
17 | const float i = 0 + // Weighted sum of the channels; the weights look like the Rec. 709 / sRGB relative-luminance coefficients — TODO confirm.
18 | (0.212639005871510 * input.r) +
19 | (0.715168678767756 * input.g) +
20 | (0.072192315360734 * input.b);
21 | const float4 output = float4(i, i, i, input.a); // Same intensity in r, g and b; alpha copied through.
22 | outputTexture.write(output, gid);
23 | }
24 |
25 |
26 |
--------------------------------------------------------------------------------
/Sources/MetalShaders/include/SIFTOrientation.h:
--------------------------------------------------------------------------------
1 | //
2 | // SIFTOrientation.h
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/08.
6 | //
7 | #include
8 |
9 | #ifndef SIFTOrientation_h
10 | #define SIFTOrientation_h
11 |
12 | #define SIFT_ORIENTATION_HISTOGRAM_BINS 36
13 |
14 | struct SIFTOrientationParameters { // Scalar settings for the orientation kernel; meanings are inferred from the names — TODO confirm against SIFTOrientation.metal.
15 | float delta; // Presumably the pixel spacing of the current octave.
16 | float lambda; // Presumably the factor that sizes the orientation window.
17 | float orientationThreshold; // Presumably the relative histogram-peak threshold for accepting an orientation.
18 | };
19 |
20 |
21 | struct SIFTOrientationKeypoint { // Per-keypoint input record for the orientation kernel; meanings are inferred from the names — TODO confirm.
22 | int32_t index; // Presumably the keypoint's position in the caller's keypoint list.
23 | int32_t absoluteX; // Presumably the x coordinate in input-image pixels.
24 | int32_t absoluteY; // Presumably the y coordinate in input-image pixels.
25 | int32_t scale; // Presumably the scale (slice) index within the octave.
26 | float sigma; // Presumably the blur level associated with the keypoint.
27 | };
28 |
29 |
30 | struct SIFTOrientationResult { // One output record of the orientation stage.
31 | int32_t keypoint; // NOTE(review): presumably identifies the input keypoint this result belongs to — confirm.
32 | int32_t count; // NOTE(review): presumably the number of valid entries in `orientations` — confirm.
33 | float orientations[SIFT_ORIENTATION_HISTOGRAM_BINS]; // Fixed capacity of SIFT_ORIENTATION_HISTOGRAM_BINS (36) entries.
34 | };
35 |
36 | #endif /* SIFTOrientation_h */
37 |
--------------------------------------------------------------------------------
/Sources/MetalShaders/include/SIFTDescriptor.h:
--------------------------------------------------------------------------------
1 | //
2 | // SIFTDescriptor.h
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/08.
6 | //
7 |
8 | #include
9 |
10 | #ifndef SIFTDescriptor_h
11 | #define SIFTDescriptor_h
12 |
13 |
14 | #define SIFT_DESCRIPTOR_HISTOGRAM_WIDTH 4
15 | #define SIFT_DESCRIPTOR_ORIENTATION_BINS 8
16 | #define SIFT_DESCRIPTOR_FEATURE_COUNT 128
17 |
18 |
19 | struct SIFTDescriptorParameters { // Scalar settings for the descriptor stage. NOTE(review): presumably shared with Swift via the MetalShaders module, so field order and types are part of the layout — confirm.
20 | float delta; // NOTE(review): meaning not visible in this header; presumably the octave's pixel spacing — confirm.
21 | int32_t scalesPerOctave; // Number of scales per octave.
22 | int32_t width; // NOTE(review): presumably the gradient texture width in pixels — confirm.
23 | int32_t height; // NOTE(review): presumably the gradient texture height in pixels — confirm.
24 | };
25 |
26 |
27 | struct SIFTDescriptorInput { // One input record for the descriptor stage.
28 | int32_t keypoint; // NOTE(review): presumably the index of the source keypoint — confirm.
29 | int32_t absoluteX; // Integer x coordinate. NOTE(review): coordinate space not stated here — confirm.
30 | int32_t absoluteY; // Integer y coordinate. NOTE(review): coordinate space not stated here — confirm.
31 | int32_t scale; // NOTE(review): presumably the slice index within the octave — confirm.
32 | float subScale; // NOTE(review): presumably the fractional scale offset from interpolation — confirm.
33 | float theta; // NOTE(review): presumably the keypoint orientation angle; units not stated here — confirm.
34 | };
35 |
36 |
37 | struct SIFTDescriptorResult { // One output record of the descriptor stage.
38 | int32_t valid; // NOTE(review): presumably a 0/1 flag marking whether `features` was filled — confirm against SIFTDescriptor.metal.
39 | int32_t keypoint; // NOTE(review): presumably copied from SIFTDescriptorInput.keypoint — confirm.
40 | float theta; // NOTE(review): presumably copied from SIFTDescriptorInput.theta — confirm.
41 | int32_t features[SIFT_DESCRIPTOR_FEATURE_COUNT]; // SIFT_DESCRIPTOR_FEATURE_COUNT (128) entries; 128 equals 4 * 4 * 8, i.e. HISTOGRAM_WIDTH squared times ORIENTATION_BINS.
42 | };
43 |
44 |
45 | #endif /* SIFTDescriptor_h */
46 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Extensions/CoreImageExtensions.swift:
--------------------------------------------------------------------------------
1 | //
2 | // CoreImageExtensions.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2022/12/17.
6 | //
7 |
8 | import CoreImage
9 | import CoreImage.CIFilterBuiltins
10 | import simd
11 |
12 |
extension CIImage {

    /// Returns this image warped by the 3x3 matrix `matrix`.
    ///
    /// The four corners of `extent` are transformed by `matrix` and used as the
    /// target corners of Core Image's perspective-transform filter.
    public func perspectiveTransformed(by matrix: simd_float3x3) -> CIImage {
        let corners = Quad(rect: extent).transformed(by: matrix)

        let warp = CIFilter.perspectiveTransform()
        warp.inputImage = self
        warp.topLeft = CGPoint(corners.topLeft)
        warp.topRight = CGPoint(corners.topRight)
        warp.bottomRight = CGPoint(corners.bottomRight)
        warp.bottomLeft = CGPoint(corners.bottomLeft)
        return warp.outputImage!
    }

    /// Returns this image passed through Core Image's colour-invert filter.
    public func colorInverted() -> CIImage {
        let invert = CIFilter.colorInvert()
        invert.inputImage = self
        return invert.outputImage!
    }
}
31 |
32 |
--------------------------------------------------------------------------------
/Sources/MetalShaders/include/SIFTInterpolate.h:
--------------------------------------------------------------------------------
1 | //
2 | // SIFTInterpolate.h
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/09.
6 | //
7 |
8 | #include
9 |
10 | #ifndef SIFTInterpolate_h
11 | #define SIFTInterpolate_h
12 |
13 |
14 | struct SIFTInterpolateParameters { // Scalar settings for the interpolation stage. NOTE(review): presumably shared with Swift via the MetalShaders module, so field order and types are part of the layout — confirm.
15 | float dogThreshold; // NOTE(review): presumably the minimum difference-of-Gaussians response to keep a keypoint — confirm.
16 | int32_t maxIterations; // NOTE(review): presumably the cap on refinement iterations per keypoint — confirm.
17 | float maxOffset; // NOTE(review): presumably the largest accepted interpolation offset — confirm.
18 | int32_t width; // NOTE(review): presumably the octave texture width in pixels — confirm.
19 | int32_t height; // NOTE(review): presumably the octave texture height in pixels — confirm.
20 | float octaveDelta; // NOTE(review): presumably the octave's pixel spacing relative to the source image — confirm.
21 | float edgeThreshold; // NOTE(review): presumably the edge-response rejection threshold — confirm.
22 | int32_t numberOfScales; // Number of scales. NOTE(review): whether this counts the auxiliary scales is not visible here — confirm.
23 | };
24 |
25 |
26 | // This should match SIFTExtremaResult
27 | struct SIFTInterpolateInputKeypoint { // Integer candidate location; three int32_t fields that must stay in step with SIFTExtremaResult, as the comment above says.
28 | int32_t x; // Integer x coordinate of the candidate.
29 | int32_t y; // Integer y coordinate of the candidate.
30 | int32_t scale; // NOTE(review): presumably the slice index within the octave — confirm.
31 | };
32 |
33 |
34 | struct SIFTInterpolateOutputKeypoint { // One output record of the interpolation stage.
35 | int32_t converged; // NOTE(review): presumably a 0/1 flag; the other fields are likely meaningful only when it is non-zero — confirm.
36 | int32_t scale; // NOTE(review): presumably the refined slice index within the octave — confirm.
37 | float subScale; // NOTE(review): presumably the fractional scale position — confirm.
38 | int32_t relativeX; // Integer x. NOTE(review): presumably in the octave's pixel grid — confirm.
39 | int32_t relativeY; // Integer y. NOTE(review): presumably in the octave's pixel grid — confirm.
40 | float absoluteX; // Floating-point x. NOTE(review): presumably in source-image coordinates — confirm.
41 | float absoluteY; // Floating-point y. NOTE(review): presumably in source-image coordinates — confirm.
42 | float value; // NOTE(review): presumably the interpolated response at the keypoint — confirm.
43 | float alphaX; // NOTE(review): presumably the final interpolation offset along x — confirm.
44 | float alphaY; // NOTE(review): presumably the final interpolation offset along y — confirm.
45 | float alphaZ; // NOTE(review): presumably the final interpolation offset along the scale axis — confirm.
46 | };
47 |
48 |
49 | #endif /* SIFTInterpolate_h */
50 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Utilities/CoreVideoMetalCache.swift:
--------------------------------------------------------------------------------
1 | //
2 | // CoreVideoMetalCache.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2022/12/18.
6 | //
7 |
8 | import Foundation
9 | import CoreVideo
10 |
11 |
/// Wraps a `CVMetalTextureCache` and turns Core Video image buffers into Metal textures.
final class CoreVideoMetalCache {

    private let textureCache: CVMetalTextureCache

    /// Creates the underlying texture cache for `device`.
    ///
    /// Traps with a descriptive message if Core Video cannot create the cache.
    init(device: MTLDevice) {
        var cache: CVMetalTextureCache?
        let status = CVMetalTextureCacheCreate(kCFAllocatorDefault, nil, device, nil, &cache)
        // Check the status instead of force-unwrapping, so a failure reports its CVReturn code.
        guard status == kCVReturnSuccess, let createdCache = cache else {
            fatalError("Cannot create texture cache (CVReturn \(status))")
        }
        self.textureCache = createdCache
    }

    /// Returns a `.bgra8Unorm` Metal texture of `size` for plane 0 of `input`.
    ///
    /// Traps with "Cannot create texture" if Core Video reports a failure or
    /// produces no texture.
    ///
    /// NOTE(review): the intermediate `CVMetalTexture` is released when this
    /// method returns. Core Video's documentation asks for it to stay alive
    /// while the `MTLTexture` is in use — confirm this is safe for the callers.
    func makeTexture(from input: CVImageBuffer, size: IntegralSize) -> MTLTexture {
        var cvMetalTexture: CVMetalTexture?
        let result = CVMetalTextureCacheCreateTextureFromImage(
            kCFAllocatorDefault,
            textureCache,
            input,
            nil,
            .bgra8Unorm,
            size.width,
            size.height,
            0,
            &cvMetalTexture
        )
        guard
            result == kCVReturnSuccess,
            let wrapper = cvMetalTexture,
            let texture = CVMetalTextureGetTexture(wrapper)
        else {
            fatalError("Cannot create texture")
        }
        return texture
    }

}
34 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Luke Van In
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Utilities/ImageConversion.swift:
--------------------------------------------------------------------------------
1 | //
2 | // ImageConversion.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2022/12/18.
6 | //
7 |
8 | import Foundation
9 | import CoreImage
10 | import CoreVideo
11 | import Metal
12 |
13 |
/// Renders Core Video buffers, Metal textures and Core Image images into
/// `CGImage`s through one Metal-backed `CIContext`.
public final class ImageConversion {

    private let ciContext: CIContext

    /// Creates a converter that renders on `device`, with `colorSpace` as both
    /// the working and the output colour space and the software renderer disabled.
    public init(device: MTLDevice, colorSpace: CGColorSpace) {
        let contextOptions: [CIContextOption: Any] = [
            .useSoftwareRenderer: false,
            .outputColorSpace: colorSpace,
            .workingColorSpace: colorSpace,
        ]
        ciContext = CIContext(mtlDevice: device, options: contextOptions)
    }

    /// Renders a Core Video image buffer, with the `applyOrientationProperty`
    /// option set when wrapping it in a `CIImage`.
    public func makeCGImage(_ input: CVImageBuffer) -> CGImage {
        let wrapped: CIImage = CIImage(
            cvPixelBuffer: input,
            options: [.applyOrientationProperty: true]
        )
        return makeCGImage(wrapped)
    }

    /// Renders a Metal texture. Traps if Core Image cannot wrap the texture.
    public func makeCGImage(_ input: MTLTexture) -> CGImage {
        let wrapped: CIImage = CIImage(mtlTexture: input)!
        return makeCGImage(wrapped)
    }

    /// Renders a Core Image image after applying the `.downMirrored`
    /// orientation transform. Traps if the context fails to produce an image.
    public func makeCGImage(_ input: CIImage) -> CGImage {
        let mirror = input.orientationTransform(for: .downMirrored)
        let mirrored = input.transformed(by: mirror)
        return ciContext.createCGImage(mirrored, from: mirrored.extent)!
    }

}
48 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Utilities/Buffer.swift:
--------------------------------------------------------------------------------
1 | //
2 | // Buffer.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/06.
6 | //
7 |
8 | import Foundation
9 | import Metal
10 |
11 |
/// A fixed-capacity, typed view over a shared-storage `MTLBuffer`.
///
/// `capacity` is the number of `T` elements the Metal buffer can hold. `count`
/// is the number of elements currently exposed through the subscript; it
/// starts at zero and is changed with `allocate(_:)`.
public final class Buffer<T> {

    /// Number of elements currently addressable through the subscript.
    public private(set) var count: Int
    /// The backing Metal buffer, bound to compute encoders by the kernels.
    let data: MTLBuffer
    /// `data.contents()` bound to `T`.
    let pointer: UnsafeMutablePointer<T>
    /// Maximum number of elements the buffer can hold.
    public let capacity: Int

    /// Allocates a tracked, shared-storage Metal buffer with room for `capacity` elements.
    ///
    /// - Parameters:
    ///   - device: Device that allocates the buffer.
    ///   - label: Debug label assigned to the Metal buffer.
    ///   - capacity: Number of `T` elements to reserve.
    ///
    /// Traps if the device cannot allocate the buffer.
    public init(device: MTLDevice, label: String, capacity: Int) {
        self.capacity = capacity
        // Size by stride, not size, so consecutive elements are laid out as in an array.
        let numberOfBytes = MemoryLayout<T>.stride * capacity
        self.data = device.makeBuffer(
            length: numberOfBytes,
            options: [.hazardTrackingModeTracked, .storageModeShared]
        )!
        self.data.label = label
        self.pointer = data.contents().bindMemory(to: T.self, capacity: capacity)
        self.count = 0
    }

    deinit {
        // Marks the buffer's contents as discardable when this wrapper goes away.
        data.setPurgeableState(.empty)
    }

    /// Sets the number of addressable elements to `count`.
    ///
    /// No memory is allocated or cleared; only the bounds checked by the
    /// subscript change. Traps unless `0 <= count <= capacity`.
    public func allocate(_ count: Int) {
        precondition(count >= 0)
        precondition(count <= capacity)
        self.count = count
    }

    /// Reads or writes the element at index `i`. Traps unless `0 <= i < count`.
    public subscript(i: Int) -> T {
        get {
            precondition(i >= 0)
            precondition(i < count)
            return pointer[i]
        }
        set {
            precondition(i >= 0)
            precondition(i < count)
            pointer[i] = newValue
        }
    }
}
54 |
--------------------------------------------------------------------------------
/Sources/MetalShaders/Metal/SIFTGradient.metal:
--------------------------------------------------------------------------------
1 | //
2 | // SIFTGradient.metal
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/07.
6 | //
7 |
8 | #include
9 |
10 | #include "Common.hpp"
11 |
12 | using namespace metal;
13 |
14 |
// Computes a per-pixel gradient for every slice of `inputTexture`.
//
// For the pixel at (gid.x, gid.y) in slice gid.z the kernel takes central
// differences of the red channel along x and y, then writes the result of
// atan2() to the output's first channel and the gradient magnitude to the second.
// Neighbour coordinates go through `symmetrizedCoordinates` (declared in
// Common.hpp; presumably mirrors indices that fall outside the image — confirm).
//
// NOTE(review): the texture template arguments were missing from this listing;
// `float` and the access qualifiers below are inferred from the read()/write()
// calls — confirm against the original shader.
kernel void siftGradient(
    texture2d_array<float, access::write> outputTexture [[texture(0)]],
    texture2d_array<float, access::read> inputTexture [[texture(1)]],
    ushort3 gid [[thread_position_in_grid]]
) {
    const int gx = (int)gid.x;
    const int gy = (int)gid.y;
    const int gz = (int)gid.z;
    const int width = inputTexture.get_width();
    const int height = inputTexture.get_height();

    // Neighbour coordinates one pixel either side along each axis.
    const ushort xPlus = symmetrizedCoordinates(gx + 1, width);
    const ushort xMinus = symmetrizedCoordinates(gx - 1, width);
    const ushort yPlus = symmetrizedCoordinates(gy + 1, height);
    const ushort yMinus = symmetrizedCoordinates(gy - 1, height);

    const float right = inputTexture.read(ushort2(xPlus, gy), gz).r;
    const float left = inputTexture.read(ushort2(xMinus, gy), gz).r;
    const float below = inputTexture.read(ushort2(gx, yPlus), gz).r;
    const float above = inputTexture.read(ushort2(gx, yMinus), gz).r;

    // Central differences.
    const float tx = (right - left) * 0.5;
    const float ty = (below - above) * 0.5;
    #warning("FIXME: IPOL implementation swaps dx and dy")
    // Argument order (tx, ty) is deliberate here; see the FIXME above.
    const float orientation = atan2(tx, ty);
    const float magnitude = sqrt(tx * tx + ty * ty);
    outputTexture.write(float4(orientation, magnitude, 0, 0), ushort2(gx, gy), gz);
}
40 |
41 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Utilities/Quad.swift:
--------------------------------------------------------------------------------
1 | //
2 | // Quad.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2022/12/17.
6 | //
7 |
8 | import Foundation
9 | import simd
10 |
11 |
/// Four corner points of a quadrilateral.
struct Quad {
    var topLeft: simd_float2
    var topRight: simd_float2
    var bottomRight: simd_float2
    var bottomLeft: simd_float2

    /// The corners in the order top-left, top-right, bottom-right, bottom-left.
    var points: [simd_float2] {
        [topLeft, topRight, bottomRight, bottomLeft]
    }

    /// Returns the quad obtained by applying `matrix` to every corner.
    ///
    /// Each corner is lifted to (x, y, 1), multiplied by `matrix`, and divided
    /// by the resulting third component.
    func transformed(by matrix: simd_float3x3) -> Quad {
        let projected: [simd_float2] = points.map { corner in
            let homogeneous = matrix * simd_float3(corner.x, corner.y, 1)
            return simd_float2(homogeneous.x / homogeneous.z, homogeneous.y / homogeneous.z)
        }
        return Quad(points: projected)
    }
}
37 |
extension Quad {

    /// Creates a quad from the corners of `rect`.
    ///
    /// The "top" corners use `rect.maxY` and the "bottom" corners use `rect.minY`.
    init(rect: CGRect) {
        let left = Float(rect.minX)
        let right = Float(rect.maxX)
        let top = Float(rect.maxY)
        let bottom = Float(rect.minY)
        self.init(
            topLeft: simd_float2(left, top),
            topRight: simd_float2(right, top),
            bottomRight: simd_float2(right, bottom),
            bottomLeft: simd_float2(left, bottom)
        )
    }

    /// Creates a quad from the first four elements of `points`, in the order
    /// top-left, top-right, bottom-right, bottom-left.
    ///
    /// Traps if `points` has fewer than four elements.
    init(points: [simd_float2]) {
        let (first, second, third, fourth) = (points[0], points[1], points[2], points[3])
        self.init(
            topLeft: first,
            topRight: second,
            bottomRight: third,
            bottomLeft: fourth
        )
    }
}
57 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/SIFT/SIFTKeypoint.swift:
--------------------------------------------------------------------------------
1 | //
2 | // SIFTKeypoint.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/02.
6 | //
7 |
8 | import Foundation
9 |
10 |
/// A keypoint described by its place in the pyramid, its coordinates in three
/// coordinate spaces, its blur level and its pixel value.
///
/// NOTE(review): the generic arguments of the `SIMD2` fields are missing from
/// this listing; restore them from the original source.
public struct SIFTKeypoint {

    /// Index of the level of the difference-of-gaussians pyramid.
    public var octave: Int

    /// Index of the image in the octave.
    public var scale: Int

    /// NOTE(review): undocumented in the original; presumably the fractional
    /// scale offset produced by interpolation — confirm.
    public var subScale: Float

    /// Coordinate relative to the difference-of-gaussians image size.
    public var scaledCoordinate: SIMD2

    /// Coordinate relative to the original image.
    public var absoluteCoordinate: SIMD2

    /// Coordinate relative to the normal space (0...1, 0...1).
    public var normalizedCoordinate: SIMD2

    /// "Blur".
    public var sigma: Float

    /// Pixel color (intensity).
    public var value: Float

    /// Memberwise initializer; every argument is stored unchanged.
    public init(
        octave: Int,
        scale: Int,
        subScale: Float,
        scaledCoordinate: SIMD2,
        absoluteCoordinate: SIMD2,
        normalizedCoordinate: SIMD2,
        sigma: Float,
        value: Float
    ) {
        // Position in the pyramid.
        self.octave = octave
        self.scale = scale
        self.subScale = subScale

        // Coordinates.
        self.scaledCoordinate = scaledCoordinate
        self.absoluteCoordinate = absoluteCoordinate
        self.normalizedCoordinate = normalizedCoordinate

        // Appearance.
        self.sigma = sigma
        self.value = value
    }

}
58 |
--------------------------------------------------------------------------------
/Sources/MetalShaders/Metal/Convolution.metal:
--------------------------------------------------------------------------------
1 | //
2 | // Convolution.metal
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/07.
6 | //
7 |
8 | #include
9 |
10 | #include "Common.hpp"
11 |
12 | using namespace metal;
13 |
14 |
// Convolves along x: each thread sums `numberOfWeights` horizontally adjacent
// red-channel samples of `inputTexture`, weighted by `weights`, and writes the
// sum to the red channel of `outputTexture` at `gid`.
//
// Tap `i` reads column `gid.x - numberOfWeights / 2 + i` after passing it
// through `symmetrizedCoordinates` (declared in Common.hpp; presumably mirrors
// indices that fall outside the image — confirm).
//
// NOTE(review): the texture template arguments were missing from this listing;
// `float` and the access qualifiers below are inferred from the read()/write()
// calls — confirm against the original shader.
kernel void convolutionX(
    texture2d<float, access::write> outputTexture [[texture(0)]],
    texture2d<float, access::read> inputTexture [[texture(1)]],
    device float * weights [[buffer(0)]],
    device uint & numberOfWeights [[buffer(1)]],
    ushort2 gid [[thread_position_in_grid]]
) {
    const int width = inputTexture.get_width();
    const int tapCount = (int)numberOfWeights;
    // Column read by the first tap, so the kernel is centred on gid.x.
    const int firstColumn = (int)gid.x - (tapCount / 2);

    float accumulator = 0;
    for (int tap = 0; tap < tapCount; tap++) {
        const int column = symmetrizedCoordinates(firstColumn + tap, width);
        accumulator += weights[tap] * inputTexture.read(ushort2(column, gid.y)).r;
    }
    outputTexture.write(float4(accumulator, 0, 0, 1), gid);
}
33 |
34 |
// Convolves along y: each thread sums `numberOfWeights` vertically adjacent
// red-channel samples of `inputTexture`, weighted by `weights`, and writes the
// sum to the red channel of `outputTexture` at `gid`.
//
// Tap `i` reads row `gid.y - numberOfWeights / 2 + i` after passing it through
// `symmetrizedCoordinates` (declared in Common.hpp; presumably mirrors indices
// that fall outside the image — confirm).
//
// NOTE(review): the texture template arguments were missing from this listing;
// `float` and the access qualifiers below are inferred from the read()/write()
// calls — confirm against the original shader.
kernel void convolutionY(
    texture2d<float, access::write> outputTexture [[texture(0)]],
    texture2d<float, access::read> inputTexture [[texture(1)]],
    device float * weights [[buffer(0)]],
    device uint & numberOfWeights [[buffer(1)]],
    ushort2 gid [[thread_position_in_grid]]
) {
    const int height = inputTexture.get_height();
    const int tapCount = (int)numberOfWeights;
    // Row read by the first tap, so the kernel is centred on gid.y.
    const int firstRow = (int)gid.y - (tapCount / 2);

    float accumulator = 0;
    for (int tap = 0; tap < tapCount; tap++) {
        const int row = symmetrizedCoordinates(firstRow + tap, height);
        accumulator += weights[tap] * inputTexture.read(ushort2(gid.x, row)).r;
    }
    outputTexture.write(float4(accumulator, 0, 0, 1), gid);
}
53 |
--------------------------------------------------------------------------------
/Sources/MetalShaders/Metal/ConvolutionSeries.metal:
--------------------------------------------------------------------------------
1 | //
2 | // ConvolutionSeries.metal
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/07.
6 | //
7 |
8 | #include
9 |
10 | #include "Common.hpp"
11 | #include "../include/ConvolutionSeries.h"
12 |
13 | using namespace metal;
14 |
15 |
// Convolves along x between two slices of texture arrays.
//
// Each thread sums `parameters.count` horizontally adjacent red-channel samples
// from slice `parameters.inputDepth` of `inputTexture`, weighted by
// `parameters.weights`, and writes the sum to the red channel of slice
// `parameters.outputDepth` of `outputTexture` at `gid`.
// Columns pass through `symmetrizedCoordinates` (declared in Common.hpp;
// presumably mirrors indices that fall outside the image — confirm).
//
// NOTE(review): the texture template arguments were missing from this listing;
// `float` and the access qualifiers below are inferred from the read()/write()
// calls — confirm against the original shader.
kernel void convolutionSeriesX(
    texture2d_array<float, access::write> outputTexture [[texture(0)]],
    texture2d_array<float, access::read> inputTexture [[texture(1)]],
    device ConvolutionParameters & parameters [[buffer(0)]],
    ushort2 gid [[thread_position_in_grid]]
) {
    const int width = inputTexture.get_width();
    const int tapCount = (int)parameters.count;
    // Column read by the first tap, so the kernel is centred on gid.x.
    const int firstColumn = (int)gid.x - (tapCount / 2);

    float accumulator = 0;
    for (int tap = 0; tap < tapCount; tap++) {
        const int column = symmetrizedCoordinates(firstColumn + tap, width);
        const float sample = inputTexture.read(ushort2(column, gid.y), parameters.inputDepth).r;
        accumulator += parameters.weights[tap] * sample;
    }
    outputTexture.write(float4(accumulator, 0, 0, 1), gid, parameters.outputDepth);
}
34 |
35 |
// Convolves along y between two slices of texture arrays.
//
// Each thread sums `parameters.count` vertically adjacent red-channel samples
// from slice `parameters.inputDepth` of `inputTexture`, weighted by
// `parameters.weights`, and writes the sum to the red channel of slice
// `parameters.outputDepth` of `outputTexture` at `gid`.
// Rows pass through `symmetrizedCoordinates` (declared in Common.hpp;
// presumably mirrors indices that fall outside the image — confirm).
//
// NOTE(review): the texture template arguments were missing from this listing;
// `float` and the access qualifiers below are inferred from the read()/write()
// calls — confirm against the original shader.
kernel void convolutionSeriesY(
    texture2d_array<float, access::write> outputTexture [[texture(0)]],
    texture2d_array<float, access::read> inputTexture [[texture(1)]],
    device ConvolutionParameters & parameters [[buffer(0)]],
    ushort2 gid [[thread_position_in_grid]]
) {
    const int height = inputTexture.get_height();
    const int tapCount = (int)parameters.count;
    // Row read by the first tap, so the kernel is centred on gid.y.
    const int firstRow = (int)gid.y - (tapCount / 2);

    float accumulator = 0;
    for (int tap = 0; tap < tapCount; tap++) {
        const int row = symmetrizedCoordinates(firstRow + tap, height);
        const float sample = inputTexture.read(ushort2(gid.x, row), parameters.inputDepth).r;
        accumulator += parameters.weights[tap] * sample;
    }
    outputTexture.write(float4(accumulator, 0, 0, 1), gid, parameters.outputDepth);
}
54 |
--------------------------------------------------------------------------------
/Sources/MetalShaders/Metal/BilinearUpScale.metal:
--------------------------------------------------------------------------------
1 | //
2 | // BilinearUpScale.metal
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/07.
6 | //
7 |
8 | #include
9 | using namespace metal;
10 |
11 |
// Up-scales the red channel of `inputTexture` into `outputTexture` with
// bilinear interpolation.
//
// Each thread maps its output pixel back to a fractional input position,
// reads the four surrounding input pixels and blends them by the fractional
// parts of that position. Neighbours that fall past the right or bottom edge
// are mirrored back into the image.
//
// NOTE(review): the texture template arguments were missing from this listing;
// `float` and the access qualifiers below are inferred from the read()/write()
// calls — confirm against the original shader.
kernel void bilinearUpScale(
    texture2d<float, access::write> outputTexture [[texture(0)]],
    texture2d<float, access::read> inputTexture [[texture(1)]],
    uint2 gid [[thread_position_in_grid]]
) {
    // The host dispatches whole 16x16 threadgroups, so the grid can overhang
    // the texture; skip threads that fall outside the output.
    if (gid.x >= outputTexture.get_width() || gid.y >= outputTexture.get_height()) {
        return;
    }

    const int wo = outputTexture.get_width();
    const int ho = outputTexture.get_height();

    const int wi = inputTexture.get_width();
    const int hi = inputTexture.get_height();

    // Input pixels per output pixel along each axis.
    const float dx = (float)wi / (float)wo;
    const float dy = (float)hi / (float)ho;

    const int i = gid.x;
    const int j = gid.y;

    // Fractional position of this output pixel in input coordinates.
    const float x = (float)i * dx;
    const float y = (float)j * dy;

    // Lower ("m") and upper ("p") neighbour indices along each axis.
    int im = (int)x;
    int jm = (int)y;
    int ip = im + 1;
    int jp = jm + 1;

    // Image extension by symmetrization.
    if (ip >= wi) {
        ip = 2 * wi - 1 - ip;
    }
    if (im >= wi) {
        im = 2 * wi - 1 - im;
    }
    if (jp >= hi) {
        jp = 2 * hi - 1 - jp;
    }
    if (jm >= hi) {
        jm = 2 * hi - 1 - jm;
    }

    const float fractional_x = x - floor(x);
    const float fractional_y = y - floor(y);

    const float c0 = inputTexture.read(uint2(ip, jp)).r;
    const float c1 = inputTexture.read(uint2(ip, jm)).r;
    const float c2 = inputTexture.read(uint2(im, jp)).r;
    const float c3 = inputTexture.read(uint2(im, jm)).r;

    // Blend along y within each column, then along x between the two columns.
    const float output = fractional_x * (fractional_y * c0
                                         + (1 - fractional_y) * c1)
                       + (1 - fractional_x) * (fractional_y * c2
                                               + (1 - fractional_y) * c3);

    outputTexture.write(float4(output, 0, 0, 1), gid);
}
65 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Metal Compute/NearestNeighborUpScaleKernel.swift:
--------------------------------------------------------------------------------
1 | //
2 | // NearestNeighborUpScaleKernel.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2022/12/25.
6 | //
7 |
8 | import Foundation
9 | import MetalPerformanceShaders
10 |
11 |
/// Encodes the `nearestNeighborUpScale` compute function, which fills an
/// output texture from a smaller input texture.
final class NearestNeighborUpScaleKernel {

    private let computePipelineState: MTLComputePipelineState

    /// Builds the compute pipeline for `nearestNeighborUpScale`.
    ///
    /// Traps if the shader library, the function or the pipeline state cannot
    /// be created.
    init(device: MTLDevice) {
        // Load the library from the shader bundle, as the other kernels in this
        // module do. `makeDefaultLibrary()` only searches the main bundle.
        let library = try! device.makeDefaultLibrary(bundle: .metalShaders)

        let function = library.makeFunction(name: "nearestNeighborUpScale")!

        self.computePipelineState = try! device.makeComputePipelineState(
            function: function
        )
    }

    /// Encodes one up-scale pass into `commandBuffer`.
    ///
    /// - Parameters:
    ///   - commandBuffer: Command buffer that receives the compute pass.
    ///   - inputTexture: Source texture, bound at texture index 1.
    ///   - outputTexture: Destination texture, bound at texture index 0. Its
    ///     width and height must be whole multiples of the input's.
    func encode(
        commandBuffer: MTLCommandBuffer,
        inputTexture: MTLTexture,
        outputTexture: MTLTexture
    ) {
        precondition(outputTexture.width % inputTexture.width == 0)
        precondition(outputTexture.height % inputTexture.height == 0)

        let encoder = commandBuffer.makeComputeCommandEncoder()!
        encoder.setComputePipelineState(computePipelineState)
        encoder.setTexture(outputTexture, index: 0)
        encoder.setTexture(inputTexture, index: 1)

        // Set the compute kernel's threadgroup size of 16x16
        // TODO: Get threadgroup size from command buffer.
        let threadgroupSize = MTLSize(
            width: 16,
            height: 16,
            depth: 1
        )
        // Calculate the number of rows and columns of threadgroups given the size of the output image.
        // Round up so the grid covers the entire image (or more) and every pixel is processed.
        // Since we're only dealing with a 2D data set, set depth to 1.
        let threadgroupCount = MTLSize(
            width: (outputTexture.width + threadgroupSize.width - 1) / threadgroupSize.width,
            height: (outputTexture.height + threadgroupSize.height - 1) / threadgroupSize.height,
            depth: 1
        )
        encoder.dispatchThreadgroups(
            threadgroupCount,
            threadsPerThreadgroup: threadgroupSize
        )
        encoder.endEncoding()
    }
}
61 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Metal Compute/SIFTDescriptorKernel.swift:
--------------------------------------------------------------------------------
1 | //
2 | // SIFTDescriptorKernel.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2022/12/25.
6 | //
7 |
8 | import Foundation
9 | import MetalPerformanceShaders
10 |
11 | import MetalShaders
12 |
/// Encodes the `siftDescriptors` compute function, launching one thread per
/// output descriptor.
///
/// NOTE(review): the generic arguments of the `Buffer` parameters are missing
/// from this listing; restore them from the original source.
public final class SIFTDescriptorKernel {

    private let maximumKeypoints = 4096

    private let computePipelineState: MTLComputePipelineState

    /// Builds the compute pipeline for `siftDescriptors` from the shader bundle.
    ///
    /// Traps if the library, the function or the pipeline state cannot be created.
    public init(device: MTLDevice) {
        let shaderLibrary = try! device.makeDefaultLibrary(bundle: .metalShaders)
        let kernelFunction = shaderLibrary.makeFunction(name: "siftDescriptors")!
        computePipelineState = try! device.makeComputePipelineState(function: kernelFunction)
    }

    /// Encodes one descriptor pass into `commandBuffer`.
    ///
    /// - Parameters:
    ///   - commandBuffer: Command buffer that receives the compute pass.
    ///   - parameters: Bound at buffer index 2.
    ///   - gradientTextures: An `.rg32Float` 2D texture array, bound at texture index 0.
    ///   - inputKeypoints: Bound at buffer index 1; must have the same `count`
    ///     as `outputDescriptors`.
    ///   - outputDescriptors: Bound at buffer index 0; its `count` is the grid width.
    public func encode(
        commandBuffer: MTLCommandBuffer,
        parameters: Buffer,
        gradientTextures: MTLTexture,
        inputKeypoints: Buffer,
        outputDescriptors: Buffer
    ) {
        precondition(inputKeypoints.count == outputDescriptors.count)
        precondition(gradientTextures.textureType == .type2DArray)
        precondition(gradientTextures.pixelFormat == .rg32Float)

        let computeEncoder = commandBuffer.makeComputeCommandEncoder()!
        computeEncoder.setComputePipelineState(computePipelineState)
        computeEncoder.setBuffer(outputDescriptors.data, offset: 0, index: 0)
        computeEncoder.setBuffer(inputKeypoints.data, offset: 0, index: 1)
        computeEncoder.setBuffer(parameters.data, offset: 0, index: 2)
        computeEncoder.setTexture(gradientTextures, index: 0)

        // One-dimensional grid: one thread per descriptor, in threadgroups as
        // wide as the pipeline allows.
        let gridSize = MTLSize(width: outputDescriptors.count, height: 1, depth: 1)
        let groupSize = MTLSize(
            width: computePipelineState.maxTotalThreadsPerThreadgroup,
            height: 1,
            depth: 1
        )
        computeEncoder.dispatchThreads(gridSize, threadsPerThreadgroup: groupSize)
        computeEncoder.endEncoding()
    }
}
65 |
66 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Metal Compute/BilinearUpScaleKernel.swift:
--------------------------------------------------------------------------------
1 | //
2 | // BilinearUpScaleKernel.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2022/12/25.
6 | //
7 |
8 | import Foundation
9 | import MetalPerformanceShaders
10 |
11 |
/// Encodes the `bilinearUpScale` compute function, which fills a larger output
/// texture from a smaller input texture.
final class BilinearUpScaleKernel {

    private let computePipelineState: MTLComputePipelineState

    /// Builds the compute pipeline for `bilinearUpScale` from the shader bundle.
    ///
    /// Traps if the library, the function or the pipeline state cannot be created.
    init(device: MTLDevice) {
        let shaderLibrary = try! device.makeDefaultLibrary(bundle: .metalShaders)
        let kernelFunction = shaderLibrary.makeFunction(name: "bilinearUpScale")!
        computePipelineState = try! device.makeComputePipelineState(function: kernelFunction)
    }

    /// Encodes one up-scale pass into `commandBuffer`.
    ///
    /// - Parameters:
    ///   - commandBuffer: Command buffer that receives the compute pass.
    ///   - inputTexture: Source texture, bound at texture index 1.
    ///   - outputTexture: Destination texture, bound at texture index 0. It must
    ///     be strictly larger than the input in both dimensions and share its
    ///     pixel format.
    func encode(
        commandBuffer: MTLCommandBuffer,
        inputTexture: MTLTexture,
        outputTexture: MTLTexture
    ) {
        precondition(outputTexture.width > inputTexture.width)
        precondition(outputTexture.height > inputTexture.height)
        precondition(outputTexture.pixelFormat == inputTexture.pixelFormat)

        let computeEncoder = commandBuffer.makeComputeCommandEncoder()!
        computeEncoder.setComputePipelineState(computePipelineState)
        computeEncoder.setTexture(outputTexture, index: 0)
        computeEncoder.setTexture(inputTexture, index: 1)

        // Fixed 16x16 threadgroups.
        // TODO: Get threadgroup size from command buffer.
        let groupSize = MTLSize(width: 16, height: 16, depth: 1)

        // Round the group count up so the grid covers every output pixel; the
        // data set is two-dimensional, so depth stays 1.
        let groupsAcross = (outputTexture.width + groupSize.width - 1) / groupSize.width
        let groupsDown = (outputTexture.height + groupSize.height - 1) / groupSize.height
        let groupCount = MTLSize(width: groupsAcross, height: groupsDown, depth: 1)

        computeEncoder.dispatchThreadgroups(groupCount, threadsPerThreadgroup: groupSize)
        computeEncoder.endEncoding()
    }
}
62 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Metal Compute/SIFTExtremaKernel.swift:
--------------------------------------------------------------------------------
1 | //
2 | // SIFTExtremaKernel.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2022/12/20.
6 | //
7 |
8 | import Foundation
9 | import MetalPerformanceShaders
10 |
11 |
/// Encodes the `siftExtrema` compute function over a 2D texture array.
final class SIFTExtremaFunction {

    private let computePipelineState: MTLComputePipelineState

    /// Builds the compute pipeline for `siftExtrema`.
    ///
    /// Traps if the shader library, the function or the pipeline state cannot
    /// be created.
    init(device: MTLDevice) {
        // Load the library from the shader bundle, as the other kernels in this
        // module do. `makeDefaultLibrary()` only searches the main bundle.
        let library = try! device.makeDefaultLibrary(bundle: .metalShaders)

        let function = library.makeFunction(name: "siftExtrema")!
        function.label = "siftExtremaFunction"

        self.computePipelineState = try! device.makeComputePipelineState(
            function: function
        )
    }

    /// Encodes one extrema pass into `commandBuffer`.
    ///
    /// - Parameters:
    ///   - commandBuffer: Command buffer that receives the compute pass.
    ///   - inputTexture: An `.r32Float` 2D texture array, bound at texture
    ///     index 1. It must have two more slices than `outputTexture`.
    ///   - outputTexture: An `.rg32Float` 2D texture array with the same width
    ///     and height as the input, bound at texture index 0.
    func encode(
        commandBuffer: MTLCommandBuffer,
        inputTexture: MTLTexture,
        outputTexture: MTLTexture
    ) {
        precondition(inputTexture.width == outputTexture.width)
        precondition(inputTexture.height == outputTexture.height)
        precondition(inputTexture.arrayLength == outputTexture.arrayLength + 2)
        precondition(inputTexture.textureType == .type2DArray)
        precondition(inputTexture.pixelFormat == .r32Float)
        precondition(outputTexture.textureType == .type2DArray)
        precondition(outputTexture.pixelFormat == .rg32Float)

        let encoder = commandBuffer.makeComputeCommandEncoder()!
        encoder.label = "siftExtremaFunctionComputeEncoder"
        encoder.setComputePipelineState(computePipelineState)
        encoder.setTexture(outputTexture, index: 0)
        encoder.setTexture(inputTexture, index: 1)

        // Cubic threadgroups: the cube root keeps width * height * depth within
        // the pipeline's thread limit.
        let threadsPerDimension = Int(cbrt(Float(computePipelineState.maxTotalThreadsPerThreadgroup)))
        let threadsPerThreadgroup = MTLSize(
            width: threadsPerDimension,
            height: threadsPerDimension,
            depth: threadsPerDimension
        )
        // The grid is two pixels smaller than the texture in x and y.
        // NOTE(review): presumably because border pixels lack a full
        // neighbourhood — confirm against SIFTExtrema.metal.
        let threadsPerGrid = MTLSize(
            width: outputTexture.width - 2,
            height: outputTexture.height - 2,
            depth: outputTexture.arrayLength
        )

        encoder.dispatchThreads(
            threadsPerGrid,
            threadsPerThreadgroup: threadsPerThreadgroup
        )
        encoder.endEncoding()
    }
}
65 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Metal Compute/ConvertSRGBToGrayscaleKernel.swift:
--------------------------------------------------------------------------------
1 | //
2 | // ConvertSRGBToGrayscaleKernel.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2022/12/25.
6 | //
7 |
8 | import Foundation
9 | import MetalPerformanceShaders
10 |
11 |
/// Encodes the `convertSRGBToGrayscale` compute function, which turns a
/// `.bgra8Unorm` texture into an `.r32Float` texture of the same size.
final class ConvertSRGBToGrayscaleKernel {

    private let computePipelineState: MTLComputePipelineState

    /// Builds the compute pipeline for `convertSRGBToGrayscale` from the shader bundle.
    ///
    /// Traps if the library, the function or the pipeline state cannot be created.
    init(device: MTLDevice) {
        let shaderLibrary = try! device.makeDefaultLibrary(bundle: .metalShaders)
        let kernelFunction = shaderLibrary.makeFunction(name: "convertSRGBToGrayscale")!
        computePipelineState = try! device.makeComputePipelineState(function: kernelFunction)
    }

    /// Encodes one conversion pass into `commandBuffer`.
    ///
    /// - Parameters:
    ///   - commandBuffer: Command buffer that receives the compute pass.
    ///   - inputTexture: A `.bgra8Unorm` texture, bound at texture index 1.
    ///   - outputTexture: An `.r32Float` texture with the same width and height
    ///     as the input, bound at texture index 0.
    func encode(
        commandBuffer: MTLCommandBuffer,
        inputTexture: MTLTexture,
        outputTexture: MTLTexture
    ) {
        precondition(inputTexture.width == outputTexture.width)
        precondition(inputTexture.height == outputTexture.height)
        // precondition(inputTexture.pixelFormat == .bgra8Unorm_srgb)
        precondition(inputTexture.pixelFormat == .bgra8Unorm)
        precondition(outputTexture.pixelFormat == .r32Float)

        let computeEncoder = commandBuffer.makeComputeCommandEncoder()!
        computeEncoder.setComputePipelineState(computePipelineState)
        computeEncoder.setTexture(outputTexture, index: 0)
        computeEncoder.setTexture(inputTexture, index: 1)

        // Fixed 16x16 threadgroups.
        // TODO: Get threadgroup size from command buffer.
        let groupSize = MTLSize(width: 16, height: 16, depth: 1)

        // Round the group count up so the grid covers every output pixel; the
        // data set is two-dimensional, so depth stays 1.
        let groupsAcross = (outputTexture.width + groupSize.width - 1) / groupSize.width
        let groupsDown = (outputTexture.height + groupSize.height - 1) / groupSize.height
        let groupCount = MTLSize(width: groupsAcross, height: groupsDown, depth: 1)

        computeEncoder.dispatchThreadgroups(groupCount, threadsPerThreadgroup: groupSize)
        computeEncoder.endEncoding()
    }
}
64 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Metal Compute/SIFTExtremaListKernel.swift:
--------------------------------------------------------------------------------
1 | //
2 | // SIFTExtremaListKernel.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2022/12/20.
6 | //
7 |
8 | import Foundation
9 | import MetalPerformanceShaders
10 |
11 | import MetalShaders
12 |
/// Encodes the `siftExtremaList` compute function, which reads a 2D texture
/// array and writes into a caller-supplied output buffer.
///
/// NOTE(review): the generic arguments of the `Buffer` types are missing from
/// this listing; restore them from the original source.
final class SIFTExtremaListFunction {

    private let computePipelineState: MTLComputePipelineState

    /// Single-element buffer bound at buffer index 1; its element is set to 0
    /// when this object is created.
    /// NOTE(review): presumably the shader's shared output counter — confirm
    /// against SIFTExtrema.metal.
    let indexBuffer: Buffer

    /// Builds the compute pipeline for `siftExtremaList` from the shader bundle
    /// and creates the single-element index buffer.
    ///
    /// Traps if the library, the function or the pipeline state cannot be created.
    init(device: MTLDevice) {
        let shaderLibrary = try! device.makeDefaultLibrary(bundle: .metalShaders)

        let kernelFunction = shaderLibrary.makeFunction(name: "siftExtremaList")!
        kernelFunction.label = "siftExtremaListFunction"

        self.computePipelineState = try! device.makeComputePipelineState(function: kernelFunction)
        self.indexBuffer = Buffer(device: device, label: "siftExtremaListIndex", capacity: 1)
        self.indexBuffer.allocate(1)
        indexBuffer[0] = 0
    }

    /// Encodes one extrema-list pass into `commandBuffer`.
    ///
    /// - Parameters:
    ///   - commandBuffer: Command buffer that receives the compute pass.
    ///   - inputTexture: An `.r32Float` 2D texture array, bound at texture index 0.
    ///   - outputBuffer: Destination buffer, bound at buffer index 0.
    func encode(
        commandBuffer: MTLCommandBuffer,
        inputTexture: MTLTexture,
        outputBuffer: Buffer
    ) {
        precondition(inputTexture.textureType == .type2DArray)
        precondition(inputTexture.pixelFormat == .r32Float)

        let computeEncoder = commandBuffer.makeComputeCommandEncoder()!
        computeEncoder.label = "siftExtremaListFunctionComputeEncoder"
        computeEncoder.setComputePipelineState(computePipelineState)
        computeEncoder.setBuffer(outputBuffer.data, offset: 0, index: 0)
        computeEncoder.setBuffer(indexBuffer.data, offset: 0, index: 1)
        computeEncoder.setTexture(inputTexture, index: 0)

        // Cubic threadgroups: the cube root keeps width * height * depth within
        // the pipeline's thread limit.
        let edge = Int(cbrt(Float(computePipelineState.maxTotalThreadsPerThreadgroup)))
        let groupSize = MTLSize(width: edge, height: edge, depth: edge)

        // The grid is two smaller than the texture along every axis.
        let gridSize = MTLSize(
            width: inputTexture.width - 2,
            height: inputTexture.height - 2,
            depth: inputTexture.arrayLength - 2
        )

        computeEncoder.dispatchThreads(gridSize, threadsPerThreadgroup: groupSize)
        computeEncoder.endEncoding()
    }
}
71 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/SIFT/SIFTPatch.swift:
--------------------------------------------------------------------------------
1 | //
2 | // SIFTPatch.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/02.
6 | //
7 |
8 | import Foundation
9 |
10 |
/// A `side` x `side` grid of histograms with `bins` bins per cell, stored as
/// one flat array of `Float` accumulators.
///
/// Values added at fractional coordinates are spread over the eight
/// surrounding (x, y, bin) cells with trilinear weights. The bin axis wraps
/// around; the x and y axes do not (out-of-range cells are dropped).
final class SIFTPatch {

    private let buffer: UnsafeMutableBufferPointer<Float>
    private let side: Int
    private let bins: Int

    /// Allocates storage for `side * side * bins` accumulators, all zero.
    init(side: Int, bins: Int) {
        self.side = side
        self.bins = bins
        self.buffer = UnsafeMutableBufferPointer<Float>.allocate(
            capacity: side * side * bins
        )
        // `allocate` hands back uninitialised memory. Every accumulator is
        // only ever updated with `+=`, so it has to start from zero.
        self.buffer.initialize(repeating: 0)
    }

    deinit {
        buffer.deallocate()
    }

    /// Adds `value` at a fractional position, distributing it over the
    /// neighbouring cells and bins.
    ///
    /// - Parameters:
    ///   - x: Horizontal cell coordinate; may be fractional.
    ///   - y: Vertical cell coordinate; may be fractional.
    ///   - bin: Histogram bin coordinate; may be fractional and may lie
    ///     outside `0 ..< bins` (it wraps).
    ///   - value: Amount to distribute.
    func addValue(x: Float, y: Float, bin: Float, value: Float) {
        let lowX = Int(floor(x))
        let highX = Int(ceil(x))
        let lowY = Int(floor(y))
        let highY = Int(ceil(y))
        let lowBin = Int(floor(bin))
        let highBin = Int(ceil(bin))

        // The fractional part is the weight of the upper neighbour; the
        // remainder goes to the lower neighbour. For an integral coordinate
        // both neighbours coincide and the upper one receives weight zero.
        let highXWeight = x - floor(x)
        let lowXWeight = 1 - highXWeight
        let highYWeight = y - floor(y)
        let lowYWeight = 1 - highYWeight
        let highBinWeight = bin - floor(bin)
        let lowBinWeight = 1 - highBinWeight

        // Deposits into both bins of a single spatial cell.
        func deposit(_ cellX: Int, _ cellY: Int, _ cellWeight: Float) {
            addValue(x: cellX, y: cellY, bin: lowBin, value: (cellWeight * lowBinWeight) * value)
            addValue(x: cellX, y: cellY, bin: highBin, value: (cellWeight * highBinWeight) * value)
        }

        deposit(lowX, lowY, lowXWeight * lowYWeight)
        deposit(highX, lowY, highXWeight * lowYWeight)
        deposit(highX, highY, highXWeight * highYWeight)
        deposit(lowX, highY, lowXWeight * highYWeight)
    }

    /// Adds `value` to a single cell and bin.
    ///
    /// Cells outside the `side` x `side` grid are ignored. `bin` is wrapped
    /// into `0 ..< bins`.
    func addValue(x: Int, y: Int, bin: Int, value: Float) {
        guard bins > 0, x >= 0, x < side, y >= 0, y < side else {
            return
        }
        // Swift's `%` keeps the sign of the dividend, so shift the remainder
        // back into the non-negative range before using it as an index.
        let wrappedBin = ((bin % bins) + bins) % bins
        buffer[offset(x: x, y: y, bin: wrappedBin)] += value
    }

    /// Flat index of a cell and bin: rows first, then columns, then bins.
    private func offset(x: Int, y: Int, bin: Int) -> Int {
        (y * side * bins) + (x * bins) + bin
    }

    /// Returns a copy of all accumulators in storage order.
    func features() -> [Float] {
        return Array(buffer)
    }
}
80 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Metal Compute/SubtractKernel.swift:
--------------------------------------------------------------------------------
1 | //
2 | // SubtractKernel.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2022/12/25.
6 | //
7 |
8 | import Foundation
9 | import MetalPerformanceShaders
10 |
11 |
/// Encodes the `subtract` compute function between two 2D-array textures.
///
/// The input must have exactly one more slice than the output. Presumably
/// each output slice is the difference of two adjacent input slices; confirm
/// against Subtract.metal.
final class SubtractKernel {

    private let computePipelineState: MTLComputePipelineState

    /// Builds the compute pipeline for the `subtract` function.
    ///
    /// Traps if the shader library, the function, or the pipeline state
    /// cannot be created.
    init(device: MTLDevice) {
        let library = try! device.makeDefaultLibrary(bundle: .metalShaders)

        let function = library.makeFunction(name: "subtract")!
        function.label = "subtractFunction"

        self.computePipelineState = try! device.makeComputePipelineState(
            function: function
        )
    }

    /// Encodes one compute pass that covers every texel of `outputTexture`.
    ///
    /// - Parameters:
    ///   - commandBuffer: Command buffer that receives the compute pass.
    ///   - inputTexture: `.type2DArray`, `.r32Float`, same width and height
    ///     as the output, with one more slice. Bound at texture index 1.
    ///   - outputTexture: `.type2DArray`, `.r32Float`. Bound at texture
    ///     index 0.
    func encode(
        commandBuffer: MTLCommandBuffer,
        inputTexture: MTLTexture,
        outputTexture: MTLTexture
    ) {
        precondition(inputTexture.width == outputTexture.width)
        precondition(inputTexture.height == outputTexture.height)
        precondition(inputTexture.arrayLength == outputTexture.arrayLength + 1)
        precondition(inputTexture.textureType == .type2DArray)
        precondition(inputTexture.pixelFormat == .r32Float)
        precondition(outputTexture.textureType == .type2DArray)
        precondition(outputTexture.pixelFormat == .r32Float)

        let encoder = commandBuffer.makeComputeCommandEncoder()!
        encoder.label = "subtractFunctionComputeEncoder"
        encoder.setComputePipelineState(computePipelineState)
        encoder.setTexture(outputTexture, index: 0)
        encoder.setTexture(inputTexture, index: 1)

        // Fixed 8x8x8 threadgroup.
        // TODO: Derive the threadgroup size from the pipeline state.
        let threadgroupSize = MTLSize(
            width: 8,
            height: 8,
            depth: 8
        )
        // Round up on every axis so the grid covers the whole output,
        // including the slice axis.
        let threadgroupCount = MTLSize(
            width: (outputTexture.width + threadgroupSize.width - 1) / threadgroupSize.width,
            height: (outputTexture.height + threadgroupSize.height - 1) / threadgroupSize.height,
            depth: (outputTexture.arrayLength + threadgroupSize.depth - 1) / threadgroupSize.depth
        )
        encoder.dispatchThreadgroups(
            threadgroupCount,
            threadsPerThreadgroup: threadgroupSize
        )
        encoder.endEncoding()
    }
}
68 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SIFTMetal
2 |
3 | Luke Van In, 2023
4 |
5 | An implementation of the Scale Invariant Feature Transform (SIFT) algorithm, for
6 | Apple devices, written in Swift using Metal compute.
7 |
8 | SIFT is described in the paper "Distinctive Image Features from Scale-Invariant
9 | Keypoints" by David Lowe published in 2004[1].
10 |
11 | This implementation is based on the source code from the "Anatomy of the SIFT
12 | Method" by Ives Rey-Otero and Mauricio Delbracio published in the Image
13 | Processing Online (IPOL) journal in 2014[3], and source code by Rob Hess[2].
14 |
15 | The scale-invariant feature transform (SIFT) is a computer vision algorithm to
16 | detect, describe, and match local features in images, invented by David Lowe in
17 | 1999. Applications include object recognition, robotic mapping and navigation,
18 | image stitching, 3D modeling, gesture recognition, video tracking, individual
19 | identification of wildlife and match moving.[4]
20 |
21 | A novel Approximate K Nearest Neighbors algorithm is provided for matching SIFT
22 | descriptors, using a trie data structure. The complexity of the algorithm is:
23 | - Initial construction and update is linear O(n) complexity.
24 | - Nearest neighbor search is O(1) complexity.
25 |
26 | SIFT keypoints of objects are first extracted from a set of reference images[1] and stored in a database. An object is recognized in a new image by individually comparing each feature from the new image to this database and finding candidate matching features based on Euclidean distance of their feature vectors. From the full set of matches, subsets of keypoints that agree on the object and its location, scale, and orientation in the new image are identified to filter out good matches. The determination of consistent clusters is performed rapidly by using an efficient hash table implementation of the generalised Hough transform. Each cluster of 3 or more features that agree on an object and its pose is then subject to further detailed model verification and subsequently outliers are discarded. Finally the probability that a particular set of features indicates the presence of an object is computed, given the accuracy of fit and number of probable false matches. Object matches that pass all these tests can be identified as correct with high confidence.[2]
27 | [1]: https://www.cs.ubc.ca/~lowe/papers/ijcv04.pdf "Distinctive Image Features from Scale-Invariant Keypoints", Lowe, International Journal of Computer Vision, 2004
28 | [2]: https://github.com/robwhess/opensift OpenSIFT, Hess, GitHub, 2012
29 | [3]: http://www.ipol.im/pub/art/2014/82/article.pdf "Anatomy of the SIFT Method", Rey-Otero & Delbracio, IPOL, 2014
30 | [4]: https://en.wikipedia.org/wiki/Scale-invariant_feature_transform Scale-invariant feature transform, Wikipedia
31 |
32 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Metal Compute/SIFTGradientKernel.swift:
--------------------------------------------------------------------------------
1 | //
2 | // SIFTGradientKernel.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/03.
6 | //
7 |
8 | import Foundation
9 | import Metal
10 |
///
/// Encodes the `siftGradient` compute function, which computes the gradient
/// orientation and magnitude for every pixel of the input image.
///
final class SIFTGradientKernel {

    private let computePipelineState: MTLComputePipelineState

    /// Builds the compute pipeline for the `siftGradient` function.
    ///
    /// Traps if the shader library, the function, or the pipeline state
    /// cannot be created.
    init(device: MTLDevice) {
        let shaderLibrary = try! device.makeDefaultLibrary(bundle: .metalShaders)

        let gradientFunction = shaderLibrary.makeFunction(name: "siftGradient")!
        gradientFunction.label = "siftGradientFunction"

        self.computePipelineState = try! device.makeComputePipelineState(
            function: gradientFunction
        )
    }

    /// Encodes one compute pass that covers every texel of `outputTexture`.
    ///
    /// - Parameters:
    ///   - commandBuffer: Command buffer that receives the compute pass.
    ///   - inputTexture: `.type2DArray`, `.r32Float`. Bound at texture
    ///     index 1.
    ///   - outputTexture: `.type2DArray`, `.rg32Float`, with the same width,
    ///     height and slice count as the input. Bound at texture index 0.
    func encode(
        commandBuffer: MTLCommandBuffer,
        inputTexture: MTLTexture,
        outputTexture: MTLTexture
    ) {
        precondition(inputTexture.width == outputTexture.width)
        precondition(inputTexture.height == outputTexture.height)
        precondition(inputTexture.arrayLength == outputTexture.arrayLength)
        precondition(inputTexture.textureType == .type2DArray)
        precondition(inputTexture.pixelFormat == .r32Float)
        precondition(outputTexture.textureType == .type2DArray)
        precondition(outputTexture.pixelFormat == .rg32Float)

        let computeEncoder = commandBuffer.makeComputeCommandEncoder()!
        computeEncoder.label = "siftGradientComputeEncoder"
        computeEncoder.setComputePipelineState(computePipelineState)
        computeEncoder.setTexture(outputTexture, index: 0)
        computeEncoder.setTexture(inputTexture, index: 1)

        // Fixed 8x8x8 threadgroup.
        // TODO: Derive the threadgroup size from the pipeline state.
        let threadsPerThreadgroup = MTLSize(width: 8, height: 8, depth: 8)

        // Number of groups needed to cover `extent` threads, rounded up.
        func groupCount(covering extent: Int, groupExtent: Int) -> Int {
            (extent + groupExtent - 1) / groupExtent
        }

        let threadgroupsPerGrid = MTLSize(
            width: groupCount(
                covering: outputTexture.width,
                groupExtent: threadsPerThreadgroup.width
            ),
            height: groupCount(
                covering: outputTexture.height,
                groupExtent: threadsPerThreadgroup.height
            ),
            depth: groupCount(
                covering: outputTexture.arrayLength,
                groupExtent: threadsPerThreadgroup.depth
            )
        )
        computeEncoder.dispatchThreadgroups(
            threadgroupsPerGrid,
            threadsPerThreadgroup: threadsPerThreadgroup
        )
        computeEncoder.endEncoding()
    }
}
70 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Metal Compute/SIFTInterpolateKernel.swift:
--------------------------------------------------------------------------------
1 | //
2 | // SIFTInterpolateKernel.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2022/12/25.
6 | //
7 |
8 | import Foundation
9 | import MetalPerformanceShaders
10 |
11 | import MetalShaders
12 |
/// Encodes the `siftInterpolate` compute function, one thread per keypoint.
///
/// Resource bindings used by `encode`:
/// - buffer 0: output keypoints
/// - buffer 1: input keypoints
/// - buffer 2: parameters
/// - texture 0: difference textures
///
/// NOTE(review): `Buffer` is declared elsewhere in the project. It looks like
/// a generic type whose type argument is missing in this listing; confirm the
/// element types against SIFTInterpolate.h.
final class SIFTInterpolateKernel {

    // Not referenced anywhere in this class.
    private let maximumKeypoints = 4096

    private let computePipelineState: MTLComputePipelineState

    /// Builds the compute pipeline for the `siftInterpolate` function.
    ///
    /// Traps if the shader library, the function, or the pipeline state
    /// cannot be created.
    init(device: MTLDevice) {
        let library = try! device.makeDefaultLibrary(bundle: .metalShaders)

        let function = library.makeFunction(name: "siftInterpolate")!
        function.label = "siftInterpolateFunction"

        self.computePipelineState = try! device.makeComputePipelineState(
            function: function
        )
    }

    /// Encodes one dispatch with `outputKeypoints.count` threads.
    ///
    /// Nothing is encoded when there are no keypoints.
    ///
    /// - Parameters:
    ///   - commandBuffer: Command buffer that receives the compute pass.
    ///   - parameters: Bound at buffer index 2.
    ///   - differenceTextures: `.type2DArray`, `.r32Float`. Bound at texture
    ///     index 0.
    ///   - inputKeypoints: Bound at buffer index 1. Must have the same count
    ///     as `outputKeypoints`.
    ///   - outputKeypoints: Bound at buffer index 0.
    func encode(
        commandBuffer: MTLCommandBuffer,
        parameters: Buffer,
        differenceTextures: MTLTexture,
        inputKeypoints: Buffer,
        outputKeypoints: Buffer
    ) {
        precondition(inputKeypoints.count == outputKeypoints.count)
        precondition(differenceTextures.textureType == .type2DArray)
        precondition(differenceTextures.pixelFormat == .r32Float)

        // A zero-sized grid is not a valid dispatch; with no keypoints there
        // is no work to encode.
        guard outputKeypoints.count > 0 else {
            return
        }

        let encoder = commandBuffer.makeComputeCommandEncoder()!
        encoder.label = "siftInterpolateComputeEncoder"
        encoder.setComputePipelineState(computePipelineState)
        encoder.setBuffer(outputKeypoints.data, offset: 0, index: 0)
        encoder.setBuffer(inputKeypoints.data, offset: 0, index: 1)
        encoder.setBuffer(parameters.data, offset: 0, index: 2)
        encoder.setTexture(differenceTextures, index: 0)

        // One-dimensional dispatch using the largest threadgroup the
        // pipeline allows.
        let threadsPerThreadgroup = MTLSize(
            width: computePipelineState.maxTotalThreadsPerThreadgroup,
            height: 1,
            depth: 1
        )
        let threadsPerGrid = MTLSize(
            width: outputKeypoints.count,
            height: 1,
            depth: 1
        )
        encoder.dispatchThreads(
            threadsPerGrid,
            threadsPerThreadgroup: threadsPerThreadgroup
        )
        encoder.endEncoding()
    }
}
76 |
77 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Metal Compute/SIFTOrientationKernel.swift:
--------------------------------------------------------------------------------
1 | //
2 | // SIFTOrientationKernel.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2022/12/25.
6 | //
7 |
8 | import Foundation
9 | import MetalPerformanceShaders
10 |
11 | import MetalShaders
12 |
/// Encodes the `siftOrientation` compute function, one thread per keypoint.
///
/// Resource bindings used by `encode`:
/// - buffer 0: output keypoints
/// - buffer 1: input keypoints
/// - buffer 2: parameters
/// - texture 0: gradient textures
///
/// NOTE(review): `Buffer` is declared elsewhere in the project. It looks like
/// a generic type whose type argument is missing in this listing; confirm the
/// element types against SIFTOrientation.h.
final class SIFTOrientationKernel {

    // Earlier Swift-side sketches of the kernel's data layout, kept for
    // reference only:
    //
    //    struct Parameters {
    //        let delta: Float32
    //        let lambda: Float32
    //        let orientationThreshold: Float32
    //    }
    //
    //    struct InputKeypoint {
    //        let absoluteX: Int32
    //        let absoluteY: Int32
    //        let scale: Int32
    //        let sigma: Float32
    //    }
    //
    //    struct OutputKeypoint {
    //        let count: Int32
    //        let orientations: [Float32]
    //    }
    //
    //    typealias Parameters = SIFTOrientationKeypoint

    // Not referenced anywhere in this class.
    private let maximumKeypoints = 4096

    private let computePipelineState: MTLComputePipelineState

    /// Builds the compute pipeline for the `siftOrientation` function.
    ///
    /// Traps if the shader library, the function, or the pipeline state
    /// cannot be created.
    init(device: MTLDevice) {
        let library = try! device.makeDefaultLibrary(bundle: .metalShaders)

        let function = library.makeFunction(name: "siftOrientation")!
        function.label = "siftOrientation"

        self.computePipelineState = try! device.makeComputePipelineState(
            function: function
        )
    }

    /// Encodes one dispatch with `outputKeypoints.count` threads.
    ///
    /// Nothing is encoded when there are no keypoints.
    ///
    /// - Parameters:
    ///   - commandBuffer: Command buffer that receives the compute pass.
    ///   - parameters: Bound at buffer index 2.
    ///   - gradientTextures: `.type2DArray`, `.rg32Float`. Bound at texture
    ///     index 0.
    ///   - inputKeypoints: Bound at buffer index 1. Must have the same count
    ///     as `outputKeypoints`.
    ///   - outputKeypoints: Bound at buffer index 0.
    func encode(
        commandBuffer: MTLCommandBuffer,
        parameters: Buffer,
        gradientTextures: MTLTexture,
        inputKeypoints: Buffer,
        outputKeypoints: Buffer
    ) {
        precondition(inputKeypoints.count == outputKeypoints.count)
        precondition(gradientTextures.textureType == .type2DArray)
        precondition(gradientTextures.pixelFormat == .rg32Float)

        // A zero-sized grid is not a valid dispatch; with no keypoints there
        // is no work to encode.
        guard outputKeypoints.count > 0 else {
            return
        }

        let encoder = commandBuffer.makeComputeCommandEncoder()!
        encoder.label = "siftOrientationComputeEncoder"
        encoder.setComputePipelineState(computePipelineState)
        encoder.setBuffer(outputKeypoints.data, offset: 0, index: 0)
        encoder.setBuffer(inputKeypoints.data, offset: 0, index: 1)
        encoder.setBuffer(parameters.data, offset: 0, index: 2)
        encoder.setTexture(gradientTextures, index: 0)

        // One-dimensional dispatch using the largest threadgroup the
        // pipeline allows.
        let threadsPerThreadgroup = MTLSize(
            width: computePipelineState.maxTotalThreadsPerThreadgroup,
            height: 1,
            depth: 1
        )
        let threadsPerGrid = MTLSize(
            width: outputKeypoints.count,
            height: 1,
            depth: 1
        )

        encoder.dispatchThreads(
            threadsPerGrid,
            threadsPerThreadgroup: threadsPerThreadgroup
        )
        encoder.endEncoding()
    }
}
87 |
88 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Metal Compute/GaussianKernel.swift:
--------------------------------------------------------------------------------
1 | //
2 | // GaussianKernel.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2022/12/25.
6 | //
7 |
8 | import Foundation
9 | import MetalPerformanceShaders
10 |
11 |
/// Separable Gaussian blur built from two one-dimensional convolutions: an
/// x pass into a private working texture followed by a y pass into the
/// output.
final class GaussianKernel {

    /// Scratch texture between the two passes. Created lazily and reused
    /// while the requested size and pixel format stay the same.
    private var workingTexture: MTLTexture?

    private let device: MTLDevice
    private let convolutionX: Convolution1DKernel
    private let convolutionY: Convolution1DKernel

    /// Builds normalised Gaussian weights and the two convolution kernels.
    ///
    /// The weights span `-radius ... radius` with `radius = ceil(4 * sigma)`
    /// and sum to one.
    ///
    /// - Parameters:
    ///   - device: Metal device used for pipelines and the working texture.
    ///   - s: Standard deviation of the Gaussian. Must be positive.
    init(device: MTLDevice, sigma s: Float) {
        // A non-positive sigma produces NaN weights or an invalid range
        // further down; fail here with a clear message instead.
        precondition(s > 0, "GaussianKernel requires a positive sigma")

        let radius = Int(ceil(4 * s))
        let size = (radius * 2) + 1
        print("GaussianKernel sigma=\(s) radius=\(radius) size=\(size)")

        let ss = s * s
        let unnormalized = (-radius ... radius).map { k -> Float in
            exp(-0.5 * (Float(k * k) / ss))
        }

        precondition(unnormalized.count == size)
        // The centre tap is exp(0).
        precondition(unnormalized[radius] == 1.0)

        // Normalize weights
        let total = unnormalized.reduce(0, +)
        let weights = unnormalized.map { $0 / total }

        precondition(abs(1 - weights.reduce(0, +)) < 0.001)

        self.device = device

        self.convolutionX = Convolution1DKernel(
            device: device,
            axis: .x,
            weights: weights
        )

        self.convolutionY = Convolution1DKernel(
            device: device,
            axis: .y,
            weights: weights
        )
    }

    /// Encodes the x pass and then the y pass.
    ///
    /// - Parameters:
    ///   - commandBuffer: Command buffer that receives both passes.
    ///   - inputTexture: Source image.
    ///   - outputTexture: Destination image; must match the input's width,
    ///     height and pixel format.
    func encode(
        commandBuffer: MTLCommandBuffer,
        inputTexture: MTLTexture,
        outputTexture: MTLTexture
    ) {
        precondition(inputTexture.width == outputTexture.width)
        precondition(inputTexture.height == outputTexture.height)
        precondition(inputTexture.pixelFormat == outputTexture.pixelFormat)

        let scratchTexture = makeWorkingTexture(matching: inputTexture)

        convolutionX.encode(
            commandBuffer: commandBuffer,
            inputTexture: inputTexture,
            outputTexture: scratchTexture
        )

        convolutionY.encode(
            commandBuffer: commandBuffer,
            inputTexture: scratchTexture,
            outputTexture: outputTexture
        )
    }

    /// Returns the cached working texture, recreating it when the width,
    /// height or pixel format of `reference` differs from the cached one.
    private func makeWorkingTexture(matching reference: MTLTexture) -> MTLTexture {
        if let cached = workingTexture,
           cached.width == reference.width,
           cached.height == reference.height,
           cached.pixelFormat == reference.pixelFormat {
            return cached
        }

        let descriptor = MTLTextureDescriptor.texture2DDescriptor(
            pixelFormat: reference.pixelFormat,
            width: reference.width,
            height: reference.height,
            mipmapped: false
        )
        descriptor.storageMode = .private
        descriptor.usage = [.shaderRead, .shaderWrite]

        guard let texture = device.makeTexture(descriptor: descriptor) else {
            preconditionFailure("GaussianKernel could not allocate its working texture")
        }
        workingTexture = texture
        return texture
    }
}
94 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Metal Compute/NearestNeighborDownScaleKernel.swift:
--------------------------------------------------------------------------------
1 | //
2 | // NearestNeighborDownScaleKernel.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2022/12/25.
6 | //
7 |
8 | import Foundation
9 | import MetalPerformanceShaders
10 |
11 | import MetalShaders
12 |
13 |
/// Encodes the `nearestNeighborDownScale` compute function, which writes one
/// slice of a half-size 2D-array texture from one slice of the input.
final class NearestNeighborDownScaleKernel {

    private let computePipelineState: MTLComputePipelineState

    /// Builds the compute pipeline for the `nearestNeighborDownScale`
    /// function.
    ///
    /// Traps if the shader library, the function, or the pipeline state
    /// cannot be created.
    init(device: MTLDevice) {
        let library = try! device.makeDefaultLibrary(bundle: .metalShaders)

        let function = library.makeFunction(name: "nearestNeighborDownScale")!

        self.computePipelineState = try! device.makeComputePipelineState(
            function: function
        )
    }

    /// Encodes one compute pass that covers every texel of `outputTexture`.
    ///
    /// - Parameters:
    ///   - commandBuffer: Command buffer that receives the compute pass.
    ///   - inputTexture: `.type2DArray`, `.r32Float`. Bound at texture
    ///     index 1.
    ///   - inputSlice: Slice index passed to the kernel as `inputSlice`.
    ///   - outputTexture: `.type2DArray`, `.r32Float`, with half the input's
    ///     width and height (integer division). Bound at texture index 0.
    ///   - outputSlice: Slice index passed to the kernel as `outputSlice`.
    func encode(
        commandBuffer: MTLCommandBuffer,
        inputTexture: MTLTexture,
        inputSlice: Int,
        outputTexture: MTLTexture,
        outputSlice: Int
    ) {
        precondition((inputTexture.width / 2) == outputTexture.width)
        precondition((inputTexture.height / 2) == outputTexture.height)
        precondition(inputTexture.textureType == .type2DArray)
        precondition(inputTexture.pixelFormat == .r32Float)
        precondition(outputTexture.textureType == .type2DArray)
        precondition(outputTexture.pixelFormat == .r32Float)

        var parameters = NearestNeighborScaleParameters(
            inputSlice: Int32(inputSlice),
            outputSlice: Int32(outputSlice)
        )

        let encoder = commandBuffer.makeComputeCommandEncoder()!
        encoder.setComputePipelineState(computePipelineState)
        encoder.setTexture(outputTexture, index: 0)
        encoder.setTexture(inputTexture, index: 1)
        // Copy the parameters into the command stream at encode time. A
        // single shared buffer rewritten on every call would leave every
        // not-yet-executed dispatch reading the most recent slice indices.
        encoder.setBytes(
            &parameters,
            length: MemoryLayout<NearestNeighborScaleParameters>.stride,
            index: 0
        )

        // Fixed 16x16 threadgroup; each dispatch handles a single slice.
        // TODO: Derive the threadgroup size from the pipeline state.
        let threadgroupSize = MTLSize(
            width: 16,
            height: 16,
            depth: 1
        )
        // Round up so the grid covers the whole output image.
        let threadgroupCount = MTLSize(
            width: (outputTexture.width + threadgroupSize.width - 1) / threadgroupSize.width,
            height: (outputTexture.height + threadgroupSize.height - 1) / threadgroupSize.height,
            depth: 1
        )
        encoder.dispatchThreadgroups(
            threadgroupCount,
            threadsPerThreadgroup: threadgroupSize
        )
        encoder.endEncoding()
    }
}
80 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Metal Compute/Convolution1DKernel.swift:
--------------------------------------------------------------------------------
1 | //
2 | // Convolution1DKernel.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2022/12/25.
6 | //
7 |
8 | import Foundation
9 | import MetalPerformanceShaders
10 |
11 |
/// Encodes a one-dimensional convolution (`convolutionX` or `convolutionY`)
/// with a fixed set of weights.
final class Convolution1DKernel {

    /// Axis along which the weights are applied; selects the Metal function.
    enum Axis {
        case x
        case y
    }

    private let computePipelineState: MTLComputePipelineState
    /// The weights as consecutive `Float` values. Bound at buffer index 0.
    private let weightsBuffer: MTLBuffer
    /// A single `UInt32` holding the number of weights. Bound at buffer
    /// index 1.
    private let parametersBuffer: MTLBuffer

    /// Builds the pipeline for the chosen axis and uploads the weights.
    ///
    /// Traps if the shader library, the function, the pipeline state, or
    /// either buffer cannot be created.
    ///
    /// - Parameters:
    ///   - device: Metal device used for the pipeline and buffers.
    ///   - axis: Convolution direction.
    ///   - weights: Kernel taps; must not be empty.
    init(device: MTLDevice, axis: Axis, weights: [Float]) {
        // A zero-length buffer cannot be created, so reject empty weights
        // with a clear message rather than a force-unwrap failure below.
        precondition(!weights.isEmpty, "Convolution1DKernel requires at least one weight")

        let library = try! device.makeDefaultLibrary(bundle: .metalShaders)

        let function: MTLFunction

        switch axis {
        case .x:
            function = library.makeFunction(name: "convolutionX")!
        case .y:
            function = library.makeFunction(name: "convolutionY")!
        }

        self.computePipelineState = try! device.makeComputePipelineState(
            function: function
        )

        var numberOfWeights = UInt32(weights.count)
        self.weightsBuffer = device.makeBuffer(
            bytes: weights,
            length: MemoryLayout<Float>.stride * weights.count
        )!
        self.parametersBuffer = device.makeBuffer(
            bytes: &numberOfWeights,
            length: MemoryLayout<UInt32>.stride
        )!
    }

    /// Encodes one compute pass that covers every texel of `outputTexture`.
    ///
    /// - Parameters:
    ///   - commandBuffer: Command buffer that receives the compute pass.
    ///   - inputTexture: Source image. Bound at texture index 1.
    ///   - outputTexture: Destination image with the same width and height.
    ///     Bound at texture index 0.
    func encode(
        commandBuffer: MTLCommandBuffer,
        inputTexture: MTLTexture,
        outputTexture: MTLTexture
    ) {
        precondition(inputTexture.width == outputTexture.width)
        precondition(inputTexture.height == outputTexture.height)

        let encoder = commandBuffer.makeComputeCommandEncoder()!
        encoder.setComputePipelineState(computePipelineState)
        encoder.setTexture(outputTexture, index: 0)
        encoder.setTexture(inputTexture, index: 1)
        encoder.setBuffer(weightsBuffer, offset: 0, index: 0)
        encoder.setBuffer(parametersBuffer, offset: 0, index: 1)

        // Fixed 16x16 threadgroup over a two-dimensional grid.
        // TODO: Derive the threadgroup size from the pipeline state.
        let threadgroupSize = MTLSize(
            width: 16,
            height: 16,
            depth: 1
        )
        // Round up so the grid covers the whole output image.
        let threadgroupCount = MTLSize(
            width: (outputTexture.width + threadgroupSize.width - 1) / threadgroupSize.width,
            height: (outputTexture.height + threadgroupSize.height - 1) / threadgroupSize.height,
            depth: 1
        )
        encoder.dispatchThreadgroups(
            threadgroupCount,
            threadsPerThreadgroup: threadgroupSize
        )
        encoder.endEncoding()
    }
}
88 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Utilities/Image.swift:
--------------------------------------------------------------------------------
1 | //
2 | // Image.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2022/12/19.
6 | //
7 |
8 | import Foundation
9 | import Metal
10 |
11 |
/// A CPU-side copy of one slice of a Metal texture, addressable by pixel.
///
/// `T` is the in-memory type of one texel. The row stride used when reading
/// the texture is `MemoryLayout<T>.stride * width`, so `T` is assumed to
/// match the texture's pixel format; the class does not check this.
final class Image<T> {

    /// Width and height of the wrapped texture.
    let size: IntegralSize

    private let texture: MTLTexture
    private let slice: Int
    private let buffer: UnsafeMutableBufferPointer<T>

    /// Wraps `texture` and allocates a `width * height` element buffer
    /// filled with `defaultValue`.
    ///
    /// - Parameters:
    ///   - texture: Texture to read from. Its `label` is overwritten.
    ///   - label: Debug label assigned to the texture.
    ///   - slice: Texture slice read by `updateFromTexture()`.
    ///   - defaultValue: Initial value of every element.
    init(texture: MTLTexture, label: String, slice: Int, defaultValue: T) {
        self.size = IntegralSize(width: texture.width, height: texture.height)
        self.slice = slice
        self.texture = texture
        self.texture.label = label

        let storage = UnsafeMutableBufferPointer<T>.allocate(
            capacity: texture.width * texture.height
        )
        storage.initialize(repeating: defaultValue)
        self.buffer = storage
    }

    deinit {
        // The elements were initialised in `init`; tear them down before
        // releasing the memory so non-trivial element types are handled too.
        buffer.baseAddress?.deinitialize(count: buffer.count)
        buffer.deallocate()
    }

    /// Copies mipmap level 0 of the configured slice into the CPU buffer.
    func updateFromTexture() {
        // An empty buffer has no base address and nothing to copy into.
        guard let baseAddress = buffer.baseAddress else {
            return
        }

        let region = MTLRegion(
            origin: MTLOrigin(x: 0, y: 0, z: 0),
            size: MTLSize(
                width: texture.width,
                height: texture.height,
                depth: 1
            )
        )
        let bytesPerComponent = MemoryLayout<T>.stride
        let bytesPerRow = bytesPerComponent * texture.width
        let bytesPerImage = bytesPerRow * texture.height
        texture.getBytes(
            UnsafeMutableRawPointer(baseAddress),
            bytesPerRow: bytesPerRow,
            bytesPerImage: bytesPerImage,
            from: region,
            mipmapLevel: 0,
            slice: slice
        )
    }

    /// Reads or writes the element at column `x`, row `y` of the CPU copy.
    /// Traps when the coordinates lie outside the texture.
    subscript(x: Int, y: Int) -> T {
        get {
            buffer[offset(x: x, y: y)]
        }
        set {
            buffer[offset(x: x, y: y)] = newValue
        }
    }

    /// Row-major index of a pixel.
    private func offset(x: Int, y: Int) -> Int {
        precondition(x >= 0 && y >= 0 && x <= texture.width - 1 && y <= texture.height - 1)
        return (y * texture.width) + x
    }
}
80 |
81 | //extension Image where T == Float {
82 | //
83 | // func getGradient(x: Int, y: Int) -> Gradient {
84 | // #warning("FIXME: IPOL implementation seems to swap dx and dy")
85 | // let g = getGradientVector(x: x, y: y)
86 | // return Gradient(
87 | // orientation: atan2(g.x, g.y),
88 | // magnitude: sqrt(g.x * g.x + g.y * g.y)
89 | // )
90 | // }
91 | //
92 | // func getGradientVector(x: Int, y: Int) -> SIMD2 {
93 | // let px: Float = self[x + 1, y]
94 | // let mx: Float = self[x - 1, y]
95 | // let py: Float = self[x, y + 1]
96 | // let my: Float = self[x, y - 1]
97 | // return SIMD2(x: (px - mx) * 0.5, y: (py - my) * 0.5)
98 | // }
99 | //}
100 |
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/GaussianDifferenceTests.swift:
--------------------------------------------------------------------------------
1 | //
2 | // GaussianDifferenceTests.swift
3 | // SkyLightTests
4 | //
5 | // Created by Luke Van In on 2022/12/26.
6 | //
7 |
8 | import XCTest
9 | import MetalPerformanceShaders
10 |
11 | @testable import SIFTMetal
12 |
/// Test case for the Gaussian-difference pipeline.
///
/// The only test is currently disabled (kept below inside a block comment),
/// so this class runs no assertions. The disabled test blurs the `butterfly`
/// image with `GaussianKernel`, subtracts the unblurred input using
/// `MPSImageSubtract`, and attaches the input, blurred and difference images
/// to the test run.
final class GaussianDifferenceTests: SharedTestCase {

    /*
    func testGaussianDifference() throws {
        let sourceTexture = try loadTexture(name: "butterfly", device: device)

        // let gaussian = MPSImageGaussianBlur(device: device, sigma: 2.5)
        // gaussian.edgeMode = .clamp
        // gaussian.offset = MPSOffset(x: 3, y: 3, z: 0)
        let gaussian = GaussianKernel(device: device, sigma: 2.5)
        // gaussian.edgeMode = .clamp
        // gaussian.offset = MPSOffset(x: 3, y: 3, z: 0)

        let subtract = MPSImageSubtract(device: device)

        let textureDescriptor = MTLTextureDescriptor.texture2DDescriptor(
            pixelFormat: .r32Float,
            width: sourceTexture.width,
            height: sourceTexture.height,
            mipmapped: false
        )
        textureDescriptor.usage = [.shaderRead, .shaderWrite]
        textureDescriptor.storageMode = .shared

        let inputTexture = device.makeTexture(descriptor: textureDescriptor)!
        let gaussianTexture = device.makeTexture(descriptor: textureDescriptor)!
        // let gaussian1Texture = device.makeTexture(descriptor: textureDescriptor)!
        let resultTexture = device.makeTexture(descriptor: textureDescriptor)!


        let commandBuffer = commandQueue.makeCommandBuffer()!

        convertSRGBToLinearGrayscaleFunction.encode(
            commandBuffer: commandBuffer,
            sourceTexture: sourceTexture,
            destinationTexture: inputTexture
        )

        gaussian.encode(
            commandBuffer: commandBuffer,
            inputTexture: inputTexture,
            outputTexture: gaussianTexture
        )

        subtract.encode(
            commandBuffer: commandBuffer,
            primaryTexture: gaussianTexture,
            secondaryTexture: inputTexture,
            destinationTexture: resultTexture
        )

        commandBuffer.commit()
        commandBuffer.waitUntilCompleted()

        attachImage(
            name: "input",
            uiImage: makeUIImage(
                ciImage: smearColor(
                    ciImage: makeCIImage(
                        texture: inputTexture
                    )
                ),
                context: ciContext
            )
        )

        attachImage(
            name: "gaussian",
            uiImage: makeUIImage(
                ciImage: smearColor(
                    ciImage: makeCIImage(
                        texture: gaussianTexture
                    )
                ),
                context: ciContext
            )
        )

        attachImage(
            name: "result",
            uiImage: makeUIImage(
                ciImage: mapColor(
                    ciImage: smearColor(
                        ciImage: normalizeColor(
                            ciImage: makeCIImage(
                                texture: resultTexture
                            )
                        )
                    )
                ),
                context: ciContext
            )
        )
    }
    */
}
109 |
--------------------------------------------------------------------------------
/Sources/MetalShaders/Metal/Common.hpp:
--------------------------------------------------------------------------------
1 | //
2 | // Common.hpp
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/07.
6 | //
7 |
8 | #include
9 |
10 | #ifndef Common_h
11 | #define Common_h
12 |
13 | using namespace metal;
14 |
// Maps an arbitrary integer coordinate onto [0, l) by mirroring at the
// borders (period 2 * l): for l = 4 the sequence for i = -2 ... 5 is
// 1 0 | 0 1 2 3 | 3 2.
//
// i: coordinate, may lie outside [0, l) on either side by any amount.
// l: extent of the axis; must be positive.
static inline int symmetrizedCoordinates(int i, int l) {
    const int period = 2 * l;
    // `%` keeps the sign of the dividend, so bring negative remainders back
    // into [0, period). This stays correct however far below zero `i` is.
    int r = i % period;
    if (r < 0) {
        r += period;
    }
    // The second half of each period walks back down the axis.
    if (r > l - 1) {
        r = period - 1 - r;
    }
    return r;
}
23 |
24 | //constant float3x3 identity = float3x3(
25 | // float3(1, 0, 0),
26 | // float3(0, 1, 0),
27 | // float3(0, 0, 1)
28 | //);
29 |
30 | // Computes the inverse of a 3x3 matrix using the cross product and
31 | // triple product.
32 | // https://en.wikipedia.org/wiki/Invertible_matrix#Inversion_of_3_×_3_matrices
33 | // See: https://www.onlinemathstutor.org/post/3x3_inverses
34 | static inline float3x3 invert(const float3x3 input) {
35 | const float3 x0 = input[0];
36 | const float3 x1 = input[1];
37 | const float3 x2 = input[2];
38 |
39 | const float d = determinant(input); // dot(x0, cross(x1, x2));
40 |
41 | const float3x3 cp = float3x3(
42 | cross(x1, x2),
43 | cross(x2, x0),
44 | cross(x0, x1)
45 | );
46 | return (1.0 / d) * cp;
47 | }
48 |
49 | // https://github.com/markkilgard/glut/blob/master/lib/gle/vvector.h
50 | //#define SCALE_ADJOINT_3X3(a,s,m) \
51 | //{ \
52 | // a[0][0] = (s) * (m[1][1] * m[2][2] - m[1][2] * m[2][1]); \
53 | // a[1][0] = (s) * (m[1][2] * m[2][0] - m[1][0] * m[2][2]); \
54 | // a[2][0] = (s) * (m[1][0] * m[2][1] - m[1][1] * m[2][0]); \
55 | // \
56 | // a[0][1] = (s) * (m[0][2] * m[2][1] - m[0][1] * m[2][2]); \
57 | // a[1][1] = (s) * (m[0][0] * m[2][2] - m[0][2] * m[2][0]); \
58 | // a[2][1] = (s) * (m[0][1] * m[2][0] - m[0][0] * m[2][1]); \
59 | // \
60 | // a[0][2] = (s) * (m[0][1] * m[1][2] - m[0][2] * m[1][1]); \
61 | // a[1][2] = (s) * (m[0][2] * m[1][0] - m[0][0] * m[1][2]); \
62 | // a[2][2] = (s) * (m[0][0] * m[1][1] - m[0][1] * m[1][0]); \
63 | //}
64 | //float3x3 scaleAdjoint(const float3x3 m, const float s) {
65 | // float3x3 a;
66 | // a[0][0] = (s) * (m[1][1] * m[2][2] - m[1][2] * m[2][1]);
67 | // a[1][0] = (s) * (m[1][2] * m[2][0] - m[1][0] * m[2][2]);
68 | // a[2][0] = (s) * (m[1][0] * m[2][1] - m[1][1] * m[2][0]);
69 | //
70 | // a[0][1] = (s) * (m[0][2] * m[2][1] - m[0][1] * m[2][2]);
71 | // a[1][1] = (s) * (m[0][0] * m[2][2] - m[0][2] * m[2][0]);
72 | // a[2][1] = (s) * (m[0][1] * m[2][0] - m[0][0] * m[2][1]);
73 | //
74 | // a[0][2] = (s) * (m[0][1] * m[1][2] - m[0][2] * m[1][1]);
75 | // a[1][2] = (s) * (m[0][2] * m[1][0] - m[0][0] * m[1][2]);
76 | // a[2][2] = (s) * (m[0][0] * m[1][1] - m[0][1] * m[1][0]);
77 | //
78 | // return a;
79 | //}
80 |
81 | // https://github.com/markkilgard/glut/blob/master/lib/gle/vvector.h
82 | //#define INVERT_3X3(b,det,a) \
83 | //{ \
84 | // double tmp; \
85 | // DETERMINANT_3X3 (det, a); \
86 | // tmp = 1.0 / (det); \
87 | // SCALE_ADJOINT_3X3 (b, tmp, a); \
88 | //}
89 | //float3x3 invert(const float3x3 input) {
90 | // float d = 1.0 / determinant(input);
91 | // return scaleAdjoint(input, d);
92 | //}
93 |
94 |
// https://metalbyexample.com/fundamentals-of-image-processing/
//
// Evaluates exp(-(x^2 + y^2) / (2 * sigma^2)) scaled by
// 1 / sqrt(2 * pi * sigma^2) at the offset (x, y).
static inline float gaussian(float x, float y, float sigma) {
    const float variance = sigma * sigma;
    const float squaredDistance = (x * x) + (y * y);
    const float normalizer = sqrt(2 * M_PI_F * variance);
    const float falloff = exp(-(squaredDistance / (2 * variance)));
    return (1 / normalizer) * falloff;
}
104 |
105 | #endif /* Common_h */
106 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Metal Compute/ConvolutionSeriesKernel.swift:
--------------------------------------------------------------------------------
1 | //
2 | // ConvolutionSeriesKernel.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2022/12/25.
6 | //
7 |
8 | import Foundation
9 | import MetalPerformanceShaders
10 |
11 | import MetalShaders
12 |
/// Encodes one 1D convolution pass (along x or y) between two `r32Float`
/// 2D-array textures, using the `convolutionSeriesX` / `convolutionSeriesY`
/// compute functions from the `MetalShaders` bundle.
///
/// The slice indices and the weights are baked into a parameter buffer at
/// initialization, so one instance always performs the same pass.
final class ConvolutionSeriesKernel {

    /// Axis along which the weights are applied.
    enum Axis {
        case x
        case y
    }

    private let computePipelineState: MTLComputePipelineState
    private let parametersBuffer: MTLBuffer

    /// - Parameters:
    ///   - device: Device used to build the pipeline state and parameter buffer.
    ///   - axis: Selects the x or y convolution function.
    ///   - inputDepth: Stored in `ConvolutionParameters.inputDepth`.
    ///   - outputDepth: Stored in `ConvolutionParameters.outputDepth`.
    ///   - weights: Convolution weights. Their count must not exceed the
    ///     capacity of `ConvolutionParameters.weights`.
    init(device: MTLDevice, axis: Axis, inputDepth: Int, outputDepth: Int, weights: [Float]) {
        // The shader library ships with the package, so failing to load it
        // (or to find a function in it) is a build error, not a runtime state.
        let library = try! device.makeDefaultLibrary(bundle: .metalShaders)

        let functionName: String
        switch axis {
        case .x:
            functionName = "convolutionSeriesX"
        case .y:
            functionName = "convolutionSeriesY"
        }
        let function = library.makeFunction(name: functionName)!
        function.label = "\(functionName)Function"

        self.computePipelineState = try! device.makeComputePipelineState(
            function: function
        )

        var parameters = ConvolutionParameters()

        // `weights` is a fixed-size field of the C struct. Copying more
        // elements than it can hold would overwrite adjacent memory, so the
        // count is validated first.
        // NOTE(review): assumes the field is a homogeneous array of Float,
        // as the pointer rebinding in the previous implementation implied.
        let weightCapacity = MemoryLayout.size(ofValue: parameters.weights) / MemoryLayout<Float>.stride
        precondition(
            weights.count <= weightCapacity,
            "ConvolutionSeriesKernel supports at most \(weightCapacity) weights, got \(weights.count)"
        )

        parameters.inputDepth = Int32(inputDepth)
        parameters.outputDepth = Int32(outputDepth)
        parameters.count = Int32(weights.count)
        withUnsafeMutableBytes(of: &parameters.weights) { destination in
            weights.withUnsafeBytes { source in
                destination.copyMemory(from: source)
            }
        }

        self.parametersBuffer = device.makeBuffer(
            bytes: &parameters,
            length: MemoryLayout<ConvolutionParameters>.stride
        )!
    }

    /// Encodes the convolution into `commandBuffer`.
    ///
    /// Both textures must be `r32Float` 2D arrays with identical width and
    /// height. The output texture is bound at index 0 and the input texture
    /// at index 1.
    func encode(
        commandBuffer: MTLCommandBuffer,
        inputTexture: MTLTexture,
        outputTexture: MTLTexture
    ) {
        precondition(inputTexture.width == outputTexture.width)
        precondition(inputTexture.height == outputTexture.height)
        precondition(inputTexture.textureType == .type2DArray)
        precondition(inputTexture.pixelFormat == .r32Float)
        precondition(outputTexture.textureType == .type2DArray)
        precondition(outputTexture.pixelFormat == .r32Float)

        let encoder = commandBuffer.makeComputeCommandEncoder()!
        encoder.label = "convolutionSeriesComputeEncoder"
        encoder.setComputePipelineState(computePipelineState)
        encoder.setTexture(outputTexture, index: 0)
        encoder.setTexture(inputTexture, index: 1)
        encoder.setBuffer(parametersBuffer, offset: 0, index: 0)

        // Fixed 16x16 threadgroups.
        // TODO: Get threadgroup size from command buffer.
        let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1)

        // Round up so the grid covers the whole image even when its size is
        // not a multiple of the threadgroup size. Depth stays 1 because the
        // slices are selected through the parameter buffer.
        let threadgroupCount = MTLSize(
            width: (outputTexture.width + threadgroupSize.width - 1) / threadgroupSize.width,
            height: (outputTexture.height + threadgroupSize.height - 1) / threadgroupSize.height,
            depth: 1
        )
        encoder.dispatchThreadgroups(
            threadgroupCount,
            threadsPerThreadgroup: threadgroupSize
        )
        encoder.endEncoding()
    }
}
99 |
--------------------------------------------------------------------------------
/Package.swift:
--------------------------------------------------------------------------------
1 | // swift-tools-version: 5.7
2 | // The swift-tools-version declares the minimum version of Swift required to build this package.
3 |
4 | import PackageDescription
5 |
/// Individual fixtures used by the test target.
let fixtureResources: [Resource] = [
    .process("Resources/butterfly.png"),
    .process("Resources/butterfly-descriptors.txt"),
    .process("Resources/extra_OnEdgeResp_butterfly.txt"),
]

/// Scale-space reference images: octaves 0...4, each with scales 0...5,
/// named `scalespace_butterfly_oNNN_sNNN.png`.
let scaleSpaceResources: [Resource] = (0 ... 4).flatMap { octave in
    (0 ... 5).map { scale in
        Resource.process("Resources/scalespace_butterfly_o00\(octave)_s00\(scale).png")
    }
}

let package = Package(
    name: "SIFTMetal",
    platforms: [.iOS(.v16), .macOS(.v13)],
    products: [
        // The library product exposes both the Swift API and the shader module.
        .library(
            name: "SIFTMetal",
            targets: ["SIFTMetal", "MetalShaders"]
        ),
    ],
    dependencies: [
        // No external package dependencies.
    ],
    targets: [
        // Metal sources and the C headers shared with Swift.
        .target(
            name: "MetalShaders",
            resources: [
                .process("Metal"),
            ]
        ),
        // Swift implementation, always built with optimization enabled.
        .target(
            name: "SIFTMetal",
            dependencies: ["MetalShaders"],
            resources: [
                .process("Resources"),
            ],
            swiftSettings: [
                .unsafeFlags(["-O"]),
            ]
        ),
        .testTarget(
            name: "SIFTMetalTests",
            dependencies: ["SIFTMetal", "MetalShaders"],
            resources: fixtureResources + scaleSpaceResources
        ),
    ]
)
85 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Metal Compute/GaussianSeriesKernel.swift:
--------------------------------------------------------------------------------
1 | //
2 | // GaussianSeriesKernel.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2022/12/25.
6 | //
7 |
8 | import Foundation
9 | import MetalPerformanceShaders
10 |
11 |
/// Builds a chain of separable Gaussian blurs over the slices of a single
/// 2D-array texture.
///
/// For the sigma at position `i`, the horizontal pass reads slice `i` into
/// slice 0 of a private working texture, and the vertical pass reads that
/// working slice back into slice `i + 1` of the caller's texture.
final class GaussianSeriesKernel {

    private let device: MTLDevice
    private let count: Int
    private let convolutionX: [ConvolutionSeriesKernel]
    private let convolutionY: [ConvolutionSeriesKernel]
    private let workingTexture: MTLTexture

    /// - Parameters:
    ///   - device: Device used for the kernels and the working texture.
    ///   - sigmas: One standard deviation per blur step.
    ///   - textureSize: Width and height of the working texture.
    ///   - arrayLength: Not used by this initializer.
    init(device: MTLDevice, sigmas: [Float], textureSize: IntegralSize, arrayLength: Int) {

        var horizontalPasses = [ConvolutionSeriesKernel]()
        var verticalPasses = [ConvolutionSeriesKernel]()

        for (step, sigma) in sigmas.enumerated() {
            // The kernel spans four standard deviations on either side.
            let radius = Int(ceil(4 * sigma))
            let size = (radius * 2) + 1
            print("GaussianKernel sigma=\(sigma) radius=\(radius) size=\(size)")

            let variance = sigma * sigma
            let unnormalized: [Float] = (-radius ... radius).map { k in
                exp(-0.5 * (Float(k * k) / variance))
            }
            let total = unnormalized.reduce(Float(0), +)

            precondition(unnormalized.count == size)
            // The centre tap is exp(0).
            precondition(unnormalized[radius] == 1.0)

            // Scale the taps so that they sum to one.
            let weights = unnormalized.map { $0 / total }

            precondition(abs(1 - weights.reduce(0, +)) < 0.001)

            horizontalPasses.append(
                ConvolutionSeriesKernel(
                    device: device,
                    axis: .x,
                    inputDepth: step,
                    outputDepth: 0,
                    weights: weights
                )
            )
            verticalPasses.append(
                ConvolutionSeriesKernel(
                    device: device,
                    axis: .y,
                    inputDepth: 0,
                    outputDepth: step + 1,
                    weights: weights
                )
            )
        }

        self.device = device
        self.count = sigmas.count
        self.convolutionX = horizontalPasses
        self.convolutionY = verticalPasses

        // Single-slice scratch texture holding the horizontally blurred image.
        let descriptor = MTLTextureDescriptor()
        descriptor.textureType = .type2DArray
        descriptor.pixelFormat = .r32Float
        descriptor.width = textureSize.width
        descriptor.height = textureSize.height
        descriptor.arrayLength = 1
        descriptor.mipmapLevelCount = 1
        descriptor.storageMode = .private
        descriptor.usage = [.shaderRead, .shaderWrite]
        self.workingTexture = device.makeTexture(descriptor: descriptor)!
    }

    /// Encodes every blur step, in order, into `commandBuffer`. `texture`
    /// is both the source and the destination of the series.
    func encode(
        commandBuffer: MTLCommandBuffer,
        texture: MTLTexture
    ) {
        for (horizontal, vertical) in zip(convolutionX, convolutionY) {
            horizontal.encode(
                commandBuffer: commandBuffer,
                inputTexture: texture,
                outputTexture: workingTexture
            )
            vertical.encode(
                commandBuffer: commandBuffer,
                inputTexture: workingTexture,
                outputTexture: texture
            )
        }
    }
}
121 |
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/KeypointTests.swift:
--------------------------------------------------------------------------------
1 | //
2 | // KeypointTests.swift
3 | // SkyLightTests
4 | //
5 | // Created by Luke Van In on 2022/12/26.
6 | //
7 |
8 | import XCTest
9 |
10 | @testable import SIFTMetal
11 |
12 |
final class KeypointTests: SharedTestCase {

    /// Thrown when a record of a reference keypoint file cannot be parsed.
    private struct MalformedKeypointRecord: Error {
        let filename: String
        /// One-based position among the non-empty lines of the file.
        let record: Int
        let text: String
    }

    /// Runs keypoint detection on the butterfly image and attaches a
    /// rendering of the found keypoints next to the reference keypoints.
    func testKeypoints() throws {

        let inputTexture = try device.loadTexture(name: "butterfly", extension: "png", srgb: false)
        let configuration = SIFT.Configuration(
            inputSize: IntegralSize(
                width: inputTexture.width,
                height: inputTexture.height
            )
        )
        let subject = SIFT(device: device, configuration: configuration)
        let octaveKeypoints = subject.getKeypoints(inputTexture)
        let keypoints = Array(octaveKeypoints.joined())
        print("Found", keypoints.count, "keypoints")

        // Unwrap through XCTest so a missing color space or an unsupported
        // texture fails this test instead of crashing the whole test run.
        let colorSpace = try XCTUnwrap(CGColorSpace(name: CGColorSpace.sRGB))
        let originalImage = try XCTUnwrap(
            CIImage(
                mtlTexture: inputTexture,
                options: [
                    CIImageOption.colorSpace: colorSpace,
                ]
            )
        )
        .oriented(.downMirrored)
        .smearColor()
        let referenceImage = ciContext.makeCGImage(ciImage: originalImage)

        // Alternative reference sets:
        // let referenceKeypoints = try loadKeypoints(filename: "extra_DoGSoftThresh_butterfly")
        // let referenceKeypoints = try loadKeypoints(filename: "extra_ExtrInterp_butterfly")
        // let referenceKeypoints = try loadKeypoints(filename: "extra_DoGThresh_butterfly")
        // let referenceKeypoints = try loadKeypoints(filename: "extra_FarFromBorder_butterfly")
        let referenceKeypoints = try loadKeypoints(filename: "extra_OnEdgeResp_butterfly")

        let renderer = SIFTRenderer()
        attachImage(
            name: "keypoints",
            uiImage: renderer.drawKeypoints(
                sourceImage: referenceImage,
                referenceKeypoints: referenceKeypoints,
                foundKeypoints: keypoints
            )
        )
    }

    /// Loads reference keypoints from a bundled text file.
    ///
    /// Each non-empty line holds whitespace-separated numbers; the first
    /// three are read as `y`, `x` and `sigma`, in that order. Any further
    /// columns are ignored.
    ///
    /// - Throws: An XCTest unwrap error when the resource is missing, a
    ///   decoding error when the file is not valid UTF-8, or
    ///   `MalformedKeypointRecord` when a line has fewer than three numbers.
    private func loadKeypoints(filename: String, extension: String = "txt") throws -> [SIFTKeypoint] {
        let fileURL = try XCTUnwrap(
            bundle.url(forResource: filename, withExtension: `extension`),
            "Missing test resource \(filename).\(`extension`)"
        )
        let contents = try String(contentsOf: fileURL, encoding: .utf8)

        // Split on any newline character: "\r\n" is a single Character in
        // Swift, so splitting on "\n" alone would not break up CRLF files.
        let records = contents.split(whereSeparator: \.isNewline)

        var keypoints = [SIFTKeypoint]()
        keypoints.reserveCapacity(records.count)

        for (index, record) in records.enumerated() {
            let fields = record.split(whereSeparator: \.isWhitespace)
            if fields.isEmpty {
                // Line containing only whitespace.
                continue
            }
            guard
                fields.count >= 3,
                let y = Float(fields[0]),
                let x = Float(fields[1]),
                let sigma = Float(fields[2])
            else {
                throw MalformedKeypointRecord(
                    filename: filename,
                    record: index + 1,
                    text: String(record)
                )
            }
            keypoints.append(
                SIFTKeypoint(
                    octave: 0,
                    scale: 0,
                    subScale: 0,
                    scaledCoordinate: .zero,
                    absoluteCoordinate: SIMD2(x: x, y: y),
                    sigma: sigma,
                    value: 0
                )
            )
        }

        return keypoints
    }

}
119 |
--------------------------------------------------------------------------------
/Sources/MetalShaders/Metal/SIFTExtrema.metal:
--------------------------------------------------------------------------------
1 | //
2 | // SIFTExtrema.metal
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/07.
6 | //
7 |
8 | #include
9 |
10 | #include "../include/SIFTExtrema.h"
11 |
12 | using namespace metal;
13 |
14 |
// Offsets of the 26 samples surrounding the centre of a 3x3x3 block.
// The centre offset (0, 0, 0) is not listed, so indices run from 0 to 25.
constant int3 neighborOffsets[] = {
    // Third component = -1
    int3(-1, -1, -1), int3( 0, -1, -1), int3(+1, -1, -1),
    int3(-1,  0, -1), int3( 0,  0, -1), int3(+1,  0, -1),
    int3(-1, +1, -1), int3( 0, +1, -1), int3(+1, +1, -1),

    // Third component = 0 (centre omitted)
    int3(-1, -1,  0), int3( 0, -1,  0), int3(+1, -1,  0),
    int3(-1,  0,  0),                   int3(+1,  0,  0),
    int3(-1, +1,  0), int3( 0, +1,  0), int3(+1, +1,  0),

    // Third component = +1
    int3(-1, -1, +1), int3( 0, -1, +1), int3(+1, -1, +1),
    int3(-1,  0, +1), int3( 0,  0, +1), int3(+1,  0, +1),
    int3(-1, +1, +1), int3( 0, +1, +1), int3(+1, +1, +1),
};
46 |
47 |
// Reads the red channel of the i-th neighbour of sample (g, s).
//
//   texture - texture array to sample.
//             NOTE(review): the template arguments were missing from the
//             source; <float, access::read> is inferred from the float
//             result of read() and should be confirmed.
//   g       - pixel coordinate of the centre sample.
//   s       - array slice of the centre sample.
//   i       - index into neighborOffsets (0...25).
//
// The offset's xy components are added to the pixel coordinate and its z
// component to the slice index. No bounds checking is performed.
static inline float fetch(
    texture2d_array<float, access::read> texture [[texture(0)]],
    const int2 g,
    const int s,
    const int i
) {
    const int3 neighborOffset = neighborOffsets[i];
    const ushort2 coordinate = (ushort2)(g + neighborOffset.xy);
    // Slice indices are unsigned; cast to ushort rather than short so the
    // value matches the read() overload selected by the ushort2 coordinate.
    const ushort slice = (ushort)(s + neighborOffset.z);
    return texture.read(coordinate, slice).r;
}
60 |
61 |
// Finds samples that are strictly smaller or strictly larger than all 26 of
// their neighbours, and appends their coordinates to `output`.
//
//   output       - destination list. Its capacity is not checked here.
//   outputCount  - global counter of entries written to `output`.
//   inputTexture - texture array to scan.
//                  NOTE(review): template arguments were missing from the
//                  source; <float, access::read> is inferred - confirm.
//
// Each thread examines the sample at (gid.xy + 1, gid.z + 1).
// Results are first gathered in threadgroup memory, then thread 0 reserves a
// contiguous range of `output` with one atomic add and copies them over.
// NOTE(review): localResults holds 1024 entries, so this presumably expects
// threadgroups of at most 1024 threads - confirm against the host code.
kernel void siftExtremaList(
    device SIFTExtremaResult * output [[buffer(0)]],
    device atomic_uint * outputCount [[buffer(1)]],
    texture2d_array<float, access::read> inputTexture [[texture(0)]],
    ushort3 gid [[thread_position_in_grid]],
    ushort3 lid [[thread_position_in_threadgroup]],
    ushort tid [[thread_index_in_threadgroup]]
) {
    // Thread group runs [0...output.width - 2][0...output.height - 2]
    const ushort threadsInThreadgroup = 1024;
    const int neighborCount = 26;

    threadgroup SIFTExtremaResult localResults[threadsInThreadgroup];
    threadgroup atomic_int localCount;
    if (tid == 0) {
        atomic_store_explicit(&localCount, 0, memory_order_relaxed);
    }
    // The counter lives in threadgroup memory, so the barrier must order
    // threadgroup memory accesses; mem_none only synchronizes execution.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const int2 g = (int2)gid.xy + 1;
    const int s = (int)gid.z + 1;
    const float value = inputTexture.read((ushort2)g, (ushort)s).r;

    float minimum = +1000;
    float maximum = -1000;

    // Visit every neighbour, starting at index 0. Starting at 1 would leave
    // the (-1, -1, -1) neighbour out of the comparison.
    for (int i = 0; i < neighborCount; i++) {
        const float neighborValue = fetch(inputTexture, g, s, i);
        minimum = min(minimum, neighborValue);
        maximum = max(maximum, neighborValue);
    }

    if ((value < minimum) || (value > maximum)) {
        const int i = atomic_fetch_add_explicit(&localCount, 1, memory_order_relaxed);
        SIFTExtremaResult result;
        result.x = g.x;
        result.y = g.y;
        result.scale = s;
        localResults[i] = result;
    }

    // Make every thread's write to localResults visible before thread 0
    // copies the results to device memory.
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        const int count = atomic_load_explicit(&localCount, memory_order_relaxed);
        if (count > 0) {
            const int b = atomic_fetch_add_explicit(outputCount, count, memory_order_relaxed);
            for (int i = 0; i < count; i++) {
                output[b + i] = localResults[i];
            }
        }
    }
}
111 |
112 |
// Writes 1 to the output texture where the input sample is strictly smaller
// or strictly larger than all 26 of its neighbours, and 0 elsewhere.
//
//   outputTexture - receives the result in its red channel.
//   inputTexture  - texture array to scan.
//                   NOTE(review): template arguments were missing from the
//                   source; <float, access::write> / <float, access::read>
//                   are inferred from the write()/read() calls - confirm.
//
// Each thread examines the input sample at (gid.xy + 1, gid.z + 1) and
// writes to (gid.xy + 1, gid.z) of the output.
// NOTE(review): the output slice is gid.z, not gid.z + 1, exactly as in the
// original code - presumably the output has two fewer slices; confirm.
kernel void siftExtrema(
    texture2d_array<float, access::write> outputTexture [[texture(0)]],
    texture2d_array<float, access::read> inputTexture [[texture(1)]],
    ushort3 gid [[thread_position_in_grid]],
    ushort3 threadPositionInThreadGroup [[thread_position_in_threadgroup]],
    ushort3 threadsPerThreadGroup [[threads_per_threadgroup]]
) {
    // Thread group runs [0...output.width - 2][0...output.height - 2]

    const int2 center = int2(gid.xy) + 1;
    const int scale = int(gid.z) + 1;
    const float value = inputTexture.read(ushort2(center), ushort(scale)).r;

    float minValue = +1000;
    float maxValue = -1000;

    for (int i = 0; i < 26; i++) {
        const int3 neighborOffset = neighborOffsets[i];
        // Neighbours are taken relative to the same (centre, scale) sample
        // that `value` was read from: xy offsets the pixel, z the slice.
        // Offsetting the slice from gid.z instead of gid.z + 1 would compare
        // the centre sample with itself and underflow the slice index for
        // gid.z == 0.
        const ushort2 coordinate = ushort2(center + neighborOffset.xy);
        const ushort slice = ushort(scale + neighborOffset.z);
        const float neighborValue = inputTexture.read(coordinate, slice).r;

        minValue = min(minValue, neighborValue);
        maxValue = max(maxValue, neighborValue);
    }

    const float result = ((value < minValue) || (value > maxValue)) ? 1 : 0;

    outputTexture.write(float4(result, 0, 0, 1), gid.xy + 1, gid.z);
}
147 |
148 |
--------------------------------------------------------------------------------
/Sources/MetalShaders/Metal/SIFTOrientation.metal:
--------------------------------------------------------------------------------
1 | //
2 | // SIFTOrientation.metal
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/07.
6 | //
7 |
8 | #include
9 |
10 | #include "Common.hpp"
11 | #include "../include/SIFTOrientation.h"
12 |
13 | using namespace metal;
14 |
15 |
// Converts a (possibly fractional) histogram bin position into an angle in
// radians: bin / SIFT_ORIENTATION_HISTOGRAM_BINS of a full turn. Angles that
// fall outside [0, 2*pi) are shifted by one full turn.
float orientationFromBin(float bin) {
    const int binCount = SIFT_ORIENTATION_HISTOGRAM_BINS;
    const float fullTurn = 2 * M_PI_F;
    const float fraction = bin / (float)binCount;

    float angle = fraction * fullTurn;
    // The two checks are sequential on purpose: a value pushed up to exactly
    // one full turn by the first correction is brought back by the second.
    if (angle < 0) {
        angle += fullTurn;
    }
    if (angle >= fullTurn) {
        angle -= fullTurn;
    }
    return angle;
}
29 |
30 |
// Offset of the vertex of the parabola through three equally spaced samples
// h1, h2, h3, measured from the middle sample h2.
// The denominator is zero when the three samples are collinear.
float interpolatePeak(float h1, float h2, float h3) {
    const float slope = h1 - h3;
    const float curvature = h1 + h3 - 2 * h2;
    return slope / (2 * curvature);
}
34 |
35 |
// Extracts the dominant orientations from a circular histogram.
//
//   histogram            - SIFT_ORIENTATION_HISTOGRAM_BINS values.
//   orientationThreshold - fraction of the largest bin that a peak must
//                          exceed to be reported.
//   orientationsCount    - receives the number of orientations written.
//   orientations         - receives one refined angle per accepted peak.
//                          The capacity of this array is not checked here.
//
// A bin is accepted when it is above the threshold and strictly larger than
// both of its circular neighbours.
void getPrincipalOrientations(
    thread float * histogram,
    float orientationThreshold,
    thread int & orientationsCount,
    thread float * orientations
) {
    const int bins = SIFT_ORIENTATION_HISTOGRAM_BINS;

    float largest = INT_MIN;
    for (int bin = 0; bin < bins; bin++) {
        largest = max(largest, histogram[bin]);
    }
    const float cutoff = orientationThreshold * largest;

    int found = 0;
    for (int bin = 0; bin < bins; bin++) {
        const float previous = histogram[(bin + bins - 1) % bins];
        const float current = histogram[bin];
        const float next = histogram[(bin + 1) % bins];

        const bool isPeak = (current > cutoff) && (current > previous) && (current > next);
        if (!isPeak) {
            continue;
        }

        // Refine the peak position with a parabola through the three bins.
        const float refinedBin = (float)bin + interpolatePeak(previous, current, next);
        orientations[found] = orientationFromBin(refinedBin);
        found += 1;
    }
    orientationsCount = found;
}
65 |
66 |
// Smooths a circular histogram in place by repeatedly replacing every bin
// with the mean of itself and its two circular neighbours.
//
//   histogram  - SIFT_ORIENTATION_HISTOGRAM_BINS values, updated in place.
//   iterations - number of smoothing passes.
void smoothHistogram(
    thread float * histogram,
    int iterations
) {
    const int n = SIFT_ORIENTATION_HISTOGRAM_BINS;
    float snapshot[n];
    for (int pass = 0; pass < iterations; pass++) {
        // Work from a copy so every bin is averaged from the values of the
        // previous pass.
        for (int bin = 0; bin < n; bin++) {
            snapshot[bin] = histogram[bin];
        }
        for (int bin = 0; bin < n; bin++) {
            const int left = (bin + n - 1) % n;
            const int right = (bin + 1) % n;
            histogram[bin] = (snapshot[left] + snapshot[bin] + snapshot[right]) / 3.0;
        }
    }
}
86 |
87 |
// Accumulates a Gaussian-weighted histogram of gradient orientations over a
// square patch centred on a keypoint.
//
//   g             - texture array; the red channel is read as orientation
//                   and the green channel as magnitude.
//                   NOTE(review): template arguments were missing from the
//                   source; <float, access::read> is inferred - confirm.
//   absoluteX/Y   - keypoint position, divided by `delta` to obtain the
//                   pixel position in `g`.
//   scale         - array slice to read.
//   keypointSigma - keypoint sigma, divided by `delta` before use.
//   delta         - divisor applied to the position and sigma above.
//   lambda        - controls the patch radius (3 * lambda * sigma) and the
//                   width of the Gaussian weighting.
//   histogram     - SIFT_ORIENTATION_HISTOGRAM_BINS accumulators, updated
//                   in place. The caller is responsible for clearing them.
void getOrientationsHistogram(
    texture2d_array<float, access::read> g,
    int absoluteX,
    int absoluteY,
    int scale,
    float keypointSigma,
    float delta,
    float lambda,
    thread float * histogram
) {
    const int bins = SIFT_ORIENTATION_HISTOGRAM_BINS;
    const int width = (int)g.get_width();
    const int height = (int)g.get_height();

    int x = round((float)absoluteX / delta);
    int y = round((float)absoluteY / delta);
    float sigma = keypointSigma / delta;

    float exponentDenominator = 2.0 * lambda * lambda;

    int r = ceil(3 * lambda * sigma);

    for (int j = -r; j <= r; j++) {
        for (int i = -r; i <= r; i++) {

            // Skip samples outside the texture. Casting a negative or
            // oversized coordinate to ushort2 would wrap around and read
            // out of bounds.
            const int px = x + i;
            const int py = y + j;
            if ((px < 0) || (px >= width) || (py < 0) || (py >= height)) {
                continue;
            }

            // Gaussian weighting
            float u = (float)i / sigma;
            float v = (float)j / sigma;
            float r2 = u * u + v * v;
            float w = exp(-r2 / exponentDenominator);

            // Gradient orientation
            float2 gradient = g.read(ushort2(px, py), scale).rg;
            float orientation = gradient.x;
            float magnitude = gradient.y;

            // Add to histogram. The bin is wrapped with a full modulo so
            // that any finite orientation lands inside [0, bins).
            float t = orientation / (2 * M_PI_F);
            int bin = round(t * (float)bins);
            bin = ((bin % bins) + bins) % bins;

            float m = w * magnitude;

            histogram[bin] += m;
        }
    }
}
137 |
138 |
139 |
// Computes the principal orientations of one keypoint per thread.
//
//   results          - one SIFTOrientationResult written per thread.
//   keypoints        - one SIFTOrientationKeypoint read per thread.
//   parameters       - delta, lambda and orientationThreshold.
//   gradientTextures - gradient texture array passed to
//                      getOrientationsHistogram.
//                      NOTE(review): template arguments were missing from
//                      the source; <float, access::read> is inferred.
//   gid              - index of the keypoint handled by this thread. It is
//                      a uint so that more than 65535 keypoints can be
//                      addressed without the index wrapping around.
//
// The histogram is built, smoothed with 6 passes, and its peaks are written
// to the result. No bounds check is made on gid, so the host must not
// dispatch more threads than there are keypoints.
kernel void siftOrientation(
    device SIFTOrientationResult * results [[buffer(0)]],
    device SIFTOrientationKeypoint * keypoints [[buffer(1)]],
    device SIFTOrientationParameters & parameters [[buffer(2)]],
    texture2d_array<float, access::read> gradientTextures [[texture(0)]],
    uint gid [[thread_position_in_grid]]
) {
    const int bins = SIFT_ORIENTATION_HISTOGRAM_BINS;
    const SIFTOrientationKeypoint keypoint = keypoints[gid];
    SIFTOrientationResult result;
    result.keypoint = keypoint.index;

    // The histogram accumulates, so it must start from zero.
    float histogram[bins];
    for (int i = 0; i < bins; i++) {
        histogram[i] = 0;
    }

    getOrientationsHistogram(
        gradientTextures,
        keypoint.absoluteX,
        keypoint.absoluteY,
        keypoint.scale,
        keypoint.sigma,
        parameters.delta,
        parameters.lambda,
        histogram
    );
    smoothHistogram(histogram, 6);
    getPrincipalOrientations(
        histogram,
        parameters.orientationThreshold,
        result.count,
        result.orientations
    );
    results[gid] = result;
}
176 |
--------------------------------------------------------------------------------
/.swiftpm/xcode/xcshareddata/xcschemes/SIFTMetal.xcscheme:
--------------------------------------------------------------------------------
1 |
2 |
5 |
8 |
9 |
15 |
21 |
22 |
23 |
29 |
35 |
36 |
37 |
43 |
49 |
50 |
51 |
57 |
63 |
64 |
65 |
71 |
77 |
78 |
79 |
80 |
81 |
86 |
87 |
89 |
95 |
96 |
97 |
98 |
99 |
109 |
110 |
116 |
117 |
123 |
124 |
125 |
126 |
128 |
129 |
132 |
133 |
134 |
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/SharedTestCase.swift:
--------------------------------------------------------------------------------
1 | //
2 | // SharedTestCase.swift
3 | // SkyLightTests
4 | //
5 | // Created by Luke Van In on 2022/12/26.
6 | //
7 |
8 | import XCTest
9 | import CoreImage
10 | import CoreImage.CIFilterBuiltins
11 | import MetalKit
12 | import MetalPerformanceShaders
13 |
14 | @testable import SIFTMetal
15 |
16 |
/// Resource bundle of the test target, shared by the test helpers below.
let bundle: Bundle = .module
18 |
19 |
extension MTLDevice {

    /// Loads the image at `url` into a texture on this device.
    /// - Parameters:
    ///   - url: Location of the image file.
    ///   - srgb: Value passed for the texture loader's `.SRGB` option.
    func loadTexture(url: URL, srgb: Bool = true) throws -> MTLTexture {
        print("Loading texture \(url)")
        let options: [MTKTextureLoader.Option: Any] = [
            .SRGB: NSNumber(value: srgb),
        ]
        return try MTKTextureLoader(device: self).newTexture(URL: url, options: options)
    }

    /// Loads a bundled test image into a texture on this device.
    /// Traps when the resource is not present in the test bundle.
    func loadTexture(name: String, extension: String, srgb: Bool = true) throws -> MTLTexture {
        return try loadTexture(
            url: bundle.url(forResource: name, withExtension: `extension`)!,
            srgb: srgb
        )
    }
}
38 |
39 |
40 |
41 | //func makeUIImage(texture: MTLTexture, context: CIContext) -> UIImage {
42 | // makeUIImage(ciImage: makeCIImage(texture: texture), context: context)
43 | //}
44 |
45 |
extension CIContext {

    /// Renders the full extent of `ciImage` and wraps it in a `UIImage`.
    func makeUIImage(ciImage: CIImage) -> UIImage {
        let cgImage = makeCGImage(ciImage: ciImage)
        return UIImage(cgImage: cgImage)
    }

    /// Renders the full extent of `ciImage` into a `CGImage`.
    /// Traps when the context cannot create the image.
    func makeCGImage(ciImage: CIImage) -> CGImage {
        let rendered = createCGImage(ciImage, from: ciImage.extent)
        return rendered!
    }
}
56 |
57 |
extension CIImage {

    /// Loads a bundled test image. Traps when the resource is missing or
    /// cannot be decoded.
    convenience init(name: String, extension: String) throws {
        let imageURL = bundle.url(forResource: name, withExtension: `extension`)!
        self.init(contentsOf: imageURL)!
    }

    /// Returns the image transformed as if it had the given orientation.
    func orientation(_ orientation: CGImagePropertyOrientation) -> CIImage {
        let transform = orientationTransform(for: orientation)
        return transformed(by: transform)
    }

    /// Applies the `sRGBToneCurveToLinear` filter.
    func sRGBToneCurveToLinear() -> CIImage {
        let filter = CIFilter.sRGBToneCurveToLinear()
        filter.inputImage = self
        return filter.outputImage!
    }

    /// Applies the `linearToSRGBToneCurve` filter.
    func linearToSRGBToneCurve() -> CIImage {
        let filter = CIFilter.linearToSRGBToneCurve()
        filter.inputImage = self
        return filter.outputImage!
    }

    /// Copies the red channel into the red, green and blue channels.
    func smearColor() -> CIImage {
        return broadcastingRed(gain: 1, bias: 0)
    }

    /// Maps the red channel through `0.5 * red + 0.5` and copies the result
    /// into the red, green and blue channels.
    func normalizeColor() -> CIImage {
        return broadcastingRed(gain: 0.5, bias: 0.5)
    }

    /// Returns the color-inverted image.
    func invertColor() -> CIImage {
        return colorInverted()
    }

    /// Maps the image through the bundled `viridis.png` gradient.
    func mapColor() -> CIImage {
        let gradientURL = bundle.url(forResource: "viridis", withExtension: "png")!
        let filter = CIFilter.colorMap()
        filter.gradientImage = CIImage(contentsOf: gradientURL)
        filter.inputImage = self
        return filter.outputImage!
    }

    /// Builds a color-matrix filter whose red, green and blue outputs are all
    /// `gain * red + bias`, and crops the result to this image's extent.
    private func broadcastingRed(gain: CGFloat, bias: CGFloat) -> CIImage {
        let redOnly = CIVector(x: gain, y: 0, z: 0)
        let filter = CIFilter.colorMatrix()
        filter.rVector = redOnly
        filter.gVector = redOnly
        filter.bVector = redOnly
        filter.biasVector = CIVector(x: bias, y: bias, z: bias)
        filter.inputImage = self
        return filter.outputImage!.cropped(to: extent)
    }
}
133 |
134 |
/// Base class for the Metal test cases. Creates the device, command queue,
/// Core Image context and shared MPS kernels before each test and releases
/// them again afterwards.
class SharedTestCase: XCTestCase {

    var device: MTLDevice!
    var commandQueue: MTLCommandQueue!
    var ciContext: CIContext!
    // var upscaleFunction: NearestNeighborUpScaleKernel!
    var subtractFunction: MPSImageSubtract!
    var convertSRGBToLinearGrayscaleFunction: MPSImageConversion!
    // var convertLinearRGBToLinearGrayscaleFunction: MPSImageConversion!

    override func setUp() {
        super.setUp()
        device = MTLCreateSystemDefaultDevice()!
        commandQueue = device.makeCommandQueue()!
        // upscaleFunction = NearestNeighborUpScaleKernel(device: device)
        subtractFunction = MPSImageSubtract(device: device)
        convertSRGBToLinearGrayscaleFunction = MPSImageConversion(
            device: device,
            srcAlpha: .alphaIsOne,
            destAlpha: .alphaIsOne,
            backgroundColor: nil,
            conversionInfo: CGColorConversionInfo(
                src: CGColorSpace(name: CGColorSpace.sRGB)!,
                dst: CGColorSpace(name: CGColorSpace.linearGray)!
            )
        )
        ciContext = CIContext(
            mtlDevice: device,
            options: [
                .outputColorSpace: CGColorSpace(name: CGColorSpace.sRGB)!,
            ]
        )
    }

    override func tearDown() {
        // Release everything created in setUp, in reverse order, so that no
        // GPU object outlives the test that created it.
        ciContext = nil
        convertSRGBToLinearGrayscaleFunction = nil
        // upscaleFunction = nil
        subtractFunction = nil
        commandQueue = nil
        device = nil
        super.tearDown()
    }

    /// Attaches `uiImage` to the test results under `name`; the attachment
    /// is kept regardless of the test outcome.
    func attachImage(name: String, uiImage: UIImage) {
        let attachment = XCTAttachment(
            image: uiImage,
            quality: .original
        )
        attachment.name = name
        attachment.lifetime = .keepAlways
        add(attachment)
    }
}
186 |
--------------------------------------------------------------------------------
/Sources/MetalShaders/Metal/SIFTDescriptor.metal:
--------------------------------------------------------------------------------
1 | //
2 | // SIFTDescriptor.metal
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/08.
6 | //
7 |
8 | #include
9 |
10 | #include "../include/SIFTDescriptor.h"
11 |
12 | using namespace metal;
13 |
14 |
// Scales `features[0 ..< count]` in place so the vector has unit Euclidean
// (L2) length.
//
// A vector whose squared length is not strictly positive (all zeros, or
// count <= 0) is left untouched: dividing by sqrt(0) would otherwise turn
// every component into NaN.
void normalizeFeatures(
    int count,
    thread float * features
) {
    float squaredLength = 0;
    for (int i = 0; i < count; i++) {
        const float f = features[i];
        squaredLength += (f * f);
    }
    if (!(squaredLength > 0)) {
        // Nothing to scale (also catches a NaN sum).
        return;
    }
    const float d = 1.0 / sqrt(squaredLength);
    for (int i = 0; i < count; i++) {
        features[i] *= d;
    }
}
29 |
30 |
// Caps every element of `features[0 ..< count]` at `threshold`, in place.
// Elements already at or below the threshold keep their value.
void thresholdFeatures(
    int count,
    thread float * features,
    float threshold
) {
    int index = 0;
    while (index < count) {
        const float capped = min(features[index], threshold);
        features[index] = capped;
        index += 1;
    }
}
40 |
41 |
// Converts `features[0 ..< count]` to integers: each value is multiplied by
// 512, limited to at most 255 and stored (truncated) in `output`.
void quantizeFeatures(
    int count,
    thread float * features,
    thread int * output
) {
    int index = 0;
    while (index < count) {
        const float scaled = features[index] * 512.0;
        output[index] = min(255.0, scaled);
        index += 1;
    }
}
51 |
52 |
// Flat index of orientation bin `b` of the histogram in column `x`, row `y`
// of the 4 x 4 histogram grid (row-major, bins stored contiguously).
int offset(int x, int y, int b) {
    const int histogramsPerRow = 4;
    const int binsPerHistogram = SIFT_DESCRIPTOR_ORIENTATION_BINS;
    const int histogramIndex = (y * histogramsPerRow) + x;
    return (histogramIndex * binsPerHistogram) + b;
}
58 |
59 |
// Adds `value` to bin `b` of the histogram at grid cell (x, y) of `patch`.
//
// Cells outside the 4 x 4 grid are ignored. The bin index is wrapped once in
// either direction, so -1 maps to the last bin and `bins` maps to bin 0.
void addValue(
    thread float * patch,
    int x,
    int y,
    int b,
    float value
) {
    const int side = 4;
    const int bins = SIFT_DESCRIPTOR_ORIENTATION_BINS;

    const bool insideGrid = (x >= 0) && (x < side) && (y >= 0) && (y < side);
    if (!insideGrid) {
        return;
    }

    const int wrappedBin = (b < 0) ? (b + bins) : ((b >= bins) ? (b - bins) : b);
    patch[offset(x, y, wrappedBin)] += value;
}
80 |
81 |
// Distributes `value` over the histogram entries surrounding the fractional
// position (x, y, b): the two neighbouring cells along each grid axis and the
// two neighbouring orientation bins, each weighted by one minus the distance
// to it along that axis (trilinear interpolation).
void addFeature(
    thread float * patch,
    float x,
    float y,
    float b,
    float value
) {
    // Lower / upper integer neighbours along each axis.
    const int x0 = int(floor(x));
    const int x1 = int(ceil(x));
    const int y0 = int(floor(y));
    const int y1 = int(ceil(y));
    const int b0 = floor(b);
    const int b1 = ceil(b);

    // Weight of the upper neighbour is the fractional part; the lower
    // neighbour receives the remainder.
    const float xHi = x - floor(x);
    const float xLo = 1 - xHi;
    const float yHi = y - floor(y);
    const float yLo = 1 - yHi;
    const float bHi = b - floor(b);
    const float bLo = 1 - bHi;

    // (x0, y0)
    addValue(patch, x0, y0, b0, (xLo * yLo * bLo) * value);
    addValue(patch, x0, y0, b1, (xLo * yLo * bHi) * value);

    // (x1, y0)
    addValue(patch, x1, y0, b0, (xHi * yLo * bLo) * value);
    addValue(patch, x1, y0, b1, (xHi * yLo * bHi) * value);

    // (x1, y1)
    addValue(patch, x1, y1, b0, (xHi * yHi * bLo) * value);
    addValue(patch, x1, y1, b1, (xHi * yHi * bHi) * value);

    // (x0, y1)
    addValue(patch, x0, y1, b0, (xLo * yHi * bLo) * value);
    addValue(patch, x0, y1, b1, (xLo * yHi * bHi) * value);
}
118 |
119 |
// Computes one SIFT descriptor per input keypoint.
//
// The thread with index `gid` processes `inputs[gid]`: it walks a square
// window of texels around the keypoint in slice `input.scale` of the gradient
// texture, rotates each offset by `input.theta`, and accumulates the gradient
// magnitude (weighted by a Gaussian of the rotated distance) into a d x d grid
// of orientation histograms. The histogram vector is then normalised, capped
// at 0.2, normalised again and quantised into `result.features`.
//
// The gradient texture is read as (orientation, magnitude) from its (r, g)
// channels. Window samples that fall outside the texture are skipped.
kernel void siftDescriptors(
    device SIFTDescriptorResult * results [[buffer(0)]],
    device SIFTDescriptorInput * inputs [[buffer(1)]],
    device SIFTDescriptorParameters & parameters [[buffer(2)]],
    texture2d_array gradientTextures [[texture(0)]],
    ushort gid [[thread_position_in_grid]]
) {
    SIFTDescriptorResult result;
    const SIFTDescriptorInput input = inputs[gid];

    // Keypoint position expressed in texels of the gradient texture.
    const float px = float(input.absoluteX) / parameters.delta;
    const float py = float(input.absoluteY) / parameters.delta;

    // Check that the keypoint is sufficiently far from the edge to include
    // entire area of the descriptor.
    #warning("TODO: Do this check after interpolation to avoid wasting work on extracting the orientation")

    const int d = 4; // width of 2d array of histograms
    const int bins = SIFT_DESCRIPTOR_ORIENTATION_BINS;

    const float tau = 2 * M_PI_F;
    const float cosT = cos(input.theta);
    const float sinT = sin(input.theta);
    const float binsPerRadian = (float)bins / tau;
    const float exponentDenominator = (float)(d * d) * 0.5;
    const float interval = (float)input.scale + input.subScale;
    const float intervals = (float)parameters.scalesPerOctave;
    const float sigma = 1.6;
    const float scale = sigma * pow(2.0, interval / intervals);
    const float histogramWidth = 3.0 * scale; // 3.0 constant from Whess (OpenSIFT)
    const int radius = histogramWidth * sqrt(2.0) * ((float)d + 1.0) * 0.5 + 0.5;

    // Keypoints near the border are not rejected, so part of the window may
    // lie outside the texture; those samples are skipped in the loop below
    // instead of being read with out-of-range coordinates.
    const int textureWidth = gradientTextures.get_width();
    const int textureHeight = gradientTextures.get_height();

    // Create histograms
    const int featureCount = d * d * bins;
    float features[featureCount];

    for (int i = 0; i < featureCount; i++) {
        features[i] = 0;
    }

    for (int j = -radius; j <= +radius; j++) {
        for (int i = -radius; i <= +radius; i++) {

            // Texel under this sample (float -> integer truncation).
            const int tx = int(px + (float)j);
            const int ty = int(py + (float)i);
            if ((tx < 0) || (tx >= textureWidth) || (ty < 0) || (ty >= textureHeight)) {
                continue;
            }

            // Offset rotated into the keypoint's frame, in histogram widths,
            // then shifted so the grid centre sits at (d / 2 - 0.5).
            float rx = ((float)j * cosT - (float)i * sinT) / histogramWidth;
            float ry = ((float)j * sinT + (float)i * cosT) / histogramWidth;
            float bx = rx + (float)(d / 2) - 0.5;
            float by = ry + (float)(d / 2) - 0.5;

            float2 g = gradientTextures.read(ushort2(tx, ty), input.scale).rg;
            float orientation = g.r - input.theta;
            float magnitude = g.g;

            // Wrap the relative orientation into [0, tau).
            while (orientation < 0) {
                orientation += tau;
            }
            while (orientation >= tau) {
                orientation -= tau;
            }

            // Bin
            float bin = orientation * binsPerRadian;

            // Total contribution
            float exponentNumerator = rx * rx + ry * ry;
            float w = exp(-exponentNumerator / exponentDenominator);
            float value = magnitude * w;

            addFeature(features, bx, by, bin, value);
        }
    }

    // Serialize histograms into array
    normalizeFeatures(featureCount, features);
    thresholdFeatures(featureCount, features, 0.2);
    normalizeFeatures(featureCount, features);
    quantizeFeatures(featureCount, features, result.features);

    result.valid = true;
    result.keypoint = input.keypoint;
    result.theta = input.theta;

    results[gid] = result;
}
238 |
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/TrieTests.swift:
--------------------------------------------------------------------------------
1 | //
2 | // TrieTests.swift
3 | // SkyLightTests
4 | //
5 | // Created by Luke Van In on 2023/01/20.
6 | //
7 |
8 | import XCTest
9 |
10 | @testable import SIFTMetal
11 |
12 | /*
13 | final class TrieTests: XCTestCase {
14 |
15 | func testInsert_shouldMatchStructure() {
16 | let expected = Trie(
17 | Trie(
18 | nil,
19 | Trie(
20 | nil,
21 | nil,
22 | Trie(
23 | nil,
24 | nil,
25 | nil
26 | )
27 | ),
28 | nil
29 | ),
30 | nil,
31 | nil
32 | )
33 | let trie = Trie(numberOfBins: 3)
34 | trie.insert(key: FloatVector([0.0, 0.5, 1.0]), value: FloatVector([0.0, 0.5, 1.0]))
35 | XCTAssertEqual(trie, expected)
36 | }
37 |
38 | func testContains_shouldReturnFalse_whenTrieDoesNotContainExactMatch() {
39 | let trie = Trie(numberOfBins: 3)
40 | trie.insert(key: FloatVector([1.0, 1.0, 1.0]), value: FloatVector([1.0, 1.0, 1.0]))
41 | let result = trie.contains(FloatVector([0.0, 0.0, 0.0]))
42 | XCTAssertFalse(result)
43 | }
44 |
45 | func testContains_shouldReturnTrue_whenTrieContainsExactMatch() {
46 | let trie = Trie(numberOfBins: 3)
47 | trie.insert(key: FloatVector([0.1, 0.2, 0.3]), value: FloatVector([0.1, 0.2, 0.3]))
48 | let result = trie.contains(FloatVector([0.1, 0.2, 0.3]))
49 | XCTAssertTrue(result)
50 | }
51 |
52 | func testContains_shouldReturnTrue_whenTrieContainsPartialMatch() {
53 | let trie = Trie(numberOfBins: 3)
54 | trie.insert(key: FloatVector([0.1, 0.2, 0.3]), value: FloatVector([0.1, 0.2, 0.3]))
55 | let result = trie.contains(FloatVector([0.1, 0.2]))
56 | XCTAssertTrue(result)
57 | }
58 |
59 | func testContains_shouldReturnTrue_whenTrieContainsSimilarValues() {
60 | let trie = Trie(numberOfBins: 3)
61 | trie.insert(key: FloatVector([0, 0.5, 1.0]), value: FloatVector([0, 0.5, 1.0])) // bins: 0, 1, 2
62 | XCTAssertTrue(trie.contains(FloatVector([0.1, 0.5, 1.0])))
63 | XCTAssertTrue(trie.contains(FloatVector([0.1, 0.6, 1.0])))
64 | XCTAssertTrue(trie.contains(FloatVector([0.1, 0.6, 0.9])))
65 | }
66 |
67 | func testNearest_shouldReturnNearestValue_whenTrieContainsSimilarValues() {
68 | let trie = Trie(numberOfBins: 3)
69 | trie.insert(key: FloatVector([0, 0.5, 1.0]), value: FloatVector([0, 0.5, 1.0]))
70 | let results = trie.nearest(key: FloatVector([0.1, 0.6, 0.9]), query: FloatVector([0.1, 0.6, 0.9]), radius: 0, k: 1)
71 | let expected = FloatVector([0, 0.5, 1.0])
72 | XCTAssertEqual(results[0].value, expected)
73 | }
74 |
75 | func testNearestAccuracy() {
76 |
77 | struct Neighbor: CustomStringConvertible {
78 | let id: Int
79 | let distance: Float
80 |
81 | var description: String {
82 | ""
83 | }
84 | }
85 |
86 | // capacity = 3^d * n
87 | let n = 1000
88 | let m = 100
89 | let d = 128
90 | var values: [FloatVector] = []
91 | var queries: [FloatVector] = []
92 | var nearestNeighbors: [Neighbor] = []
93 |
94 | // Generate sample data
95 | print("generate sample data x", n)
96 | for _ in 0 ..< n {
97 | var vector: [Float] = Array(repeating: 0, count: d)
98 | for j in 0 ..< d {
99 | vector[j] = .random(in: 0...1)
100 | }
101 | let value = FloatVector(vector)
102 | values.append(value)
103 | }
104 |
105 | // Generate queries.
106 | print("generate queries x", m)
107 | for _ in 0 ..< m {
108 | var vector: [Float] = Array(repeating: 0, count: d)
109 | for j in 0 ..< d {
110 | vector[j] = .random(in: 0...1)
111 | }
112 | queries.append(FloatVector(vector))
113 | }
114 |
115 | // Compute actual nearest neighbor for each node using brute force.
116 | print("compute nearest neighbor ground truth")
117 | for i in 0 ..< m {
118 | let query = queries[i]
119 | var nearestDistance: Float = .greatestFiniteMagnitude
120 | var nearestNeighbor: Neighbor!
121 | for j in 0 ..< n {
122 | let value = values[j]
123 | let distance = value.distance(to: query)
124 | guard distance < nearestDistance else {
125 | continue
126 | }
127 | nearestDistance = distance
128 | nearestNeighbor = Neighbor(id: j, distance: distance)
129 | }
130 | nearestNeighbors.append(nearestNeighbor)
131 | }
132 |
133 | // Create the trie.
134 | print("create trie")
135 | let subject = Trie(numberOfBins: 4)
136 | for i in 0 ..< n {
137 | let value = values[i]
138 | subject.insert(key: value, value: value)
139 | }
140 | print("capacity", subject.capacity())
141 |
142 | print("linking trie")
143 | subject.link()
144 |
145 | // Sanity check
146 | print("check contains")
147 | for i in 0 ..< n {
148 | let query = values[i]
149 | let result = subject.contains(query)
150 | XCTAssertTrue(result)
151 | }
152 |
153 | // Sanity check: nearest() with existing key/value should always return exact match
154 | print("check exact nearest")
155 | for i in 0 ..< n {
156 | let query = values[i]
157 | let matches = subject.nearest(key: query, query: query, radius: 0, k: 1)
158 | XCTAssertEqual(matches[0].value, query)
159 | }
160 |
161 | // Compute approximate nearest neighbor.
162 | var totalError: Float = 0
163 | var totalDistance: Float = 0
164 | var totalCorrect: Int = 0
165 | var totalQueries: Int = 0
166 | var totalNodes: Int = 0
167 | var totalFound: Int = 0
168 | // measure {
169 | for i in 0 ..< m {
170 | totalQueries += 1
171 | let query = queries[i]
172 | let foundNeighbors = subject.nearest(key: query, query: query, radius: 10, k: 1)
173 | guard let foundNeighbor = foundNeighbors.first else {
174 | continue
175 | }
176 | totalFound += 1
177 | totalNodes += subject.comparisonCountMetric
178 | let foundDistance = foundNeighbor.distance
179 | totalDistance = foundDistance
180 | let exactNeighbor = nearestNeighbors[i]
181 | let exactDistance = exactNeighbor.distance
182 | let delta = exactDistance - foundDistance
183 | let error = delta * delta
184 | totalError += error
185 | let correct = error == 0
186 | totalCorrect += correct ? 1 : 0
187 | }
188 | // }
189 | let percentCorrect = Float(totalCorrect) / Float(totalQueries)
190 | let meanSquaredError = totalError / Float(totalFound)
191 | let averageNodesPerQuery = Float(totalNodes) / Float(totalFound)
192 | let averageDistance = Float(totalDistance) / Float(totalFound)
193 | print("Total queries: \(totalQueries)")
194 | print("Total found: \(totalFound)")
195 | print("Mean squared error: \(String(format: "%0.3f", meanSquaredError))")
196 | print("Average absolute distance: \(String(format: "%0.3f", averageDistance))")
197 | print("Correct: \(totalCorrect) out of \(totalQueries) = \(String(format: "%0.3f", percentCorrect))")
198 | print("Total nodes: \(totalNodes)")
199 | print("Average comparisons per query: \(String(format: "%0.3f", averageNodesPerQuery))")
200 | }
201 | }
202 | */
203 |
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/DescriptorTests.swift:
--------------------------------------------------------------------------------
1 | //
2 | // DescriptorTests.swift
3 | // SkyLightTests
4 | //
5 | // Created by Luke Van In on 2022/12/26.
6 | //
7 |
8 | import XCTest
9 | import simd
10 |
11 | @testable import SIFTMetal
12 |
13 |
/// Exercises SIFT descriptor extraction and descriptor matching on the
/// bundled `butterfly` image, using reference descriptors loaded from the
/// `butterfly-descriptors` text resource.
final class DescriptorTests: SharedTestCase {

    /// Extracts descriptors from the input image and attaches a rendering of
    /// the found and reference descriptors to the test report.
    func testDescriptors() throws {

        let inputTexture = try device.loadTexture(name: "butterfly", extension: "png", srgb: false)
        let configuration = SIFT.Configuration(
            inputSize: IntegralSize(
                width: inputTexture.width,
                height: inputTexture.height
            )
        )
        let subject = SIFT(device: device, configuration: configuration)
        let keypoints = subject.getKeypoints(inputTexture)
        let descriptorOctaves = subject.getDescriptors(keypointOctaves: keypoints)
        let foundDescriptors = Array(descriptorOctaves.joined())
        print("Found", foundDescriptors.count, "descriptors")

        print("loading reference desriptors")
        let referenceDescriptors = try loadDescriptors(filename: "butterfly-descriptors")

        print("rendering image")
        let referenceImage: CGImage = {
            let originalImage = CIImage(
                mtlTexture: inputTexture,
                options: [
                    CIImageOption.colorSpace: CGColorSpace(name: CGColorSpace.sRGB)!,
                ]
            )!
                .oriented(.downMirrored)
                .smearColor()
            let cgImage = ciContext.makeCGImage(ciImage: originalImage)
            return cgImage
        }()

        let renderer = SIFTRenderer()
        attachImage(
            name: "descriptors",
            uiImage: renderer.drawDescriptors(
                sourceImage: referenceImage,
                referenceDescriptors: referenceDescriptors,
                foundDescriptors: foundDescriptors
            )
        )
    }

    /// Measures keypoint detection followed by descriptor extraction.
    func testDescriptorsPerformance() throws {

        let inputTexture = try device.loadTexture(name: "butterfly", extension: "png", srgb: false)
        let configuration = SIFT.Configuration(
            inputSize: IntegralSize(
                width: inputTexture.width,
                height: inputTexture.height
            )
        )
        let subject = SIFT(device: device, configuration: configuration)
        measure {
            let keypoints = subject.getKeypoints(inputTexture)
            _ = subject.getDescriptors(keypointOctaves: keypoints)
        }
    }

    /// Matches `detected` against `reference` and asserts that at least 80%
    /// of the detected descriptors found a match.
    private func matchDescriptors(detected: [SIFTDescriptor], reference: [SIFTDescriptor]) {

        // Without detections the match rate would be 0 / 0.
        guard !detected.isEmpty else {
            XCTFail("No detected descriptors to match")
            return
        }

        let matches = SIFTDescriptor.match(
            source: detected,
            target: reference,
            absoluteThreshold: 300,
            relativeThreshold: 0.6
        )

        // `rate` is a fraction in 0...1, so the 80% bar is expressed as 0.8.
        let rate = Float(matches.count) / Float(detected.count)
        print("found \(matches.count) out of \(detected.count) = \(rate * 100)%")
        XCTAssertGreaterThanOrEqual(rate, 0.8)
    }

    /// Matches the extracted descriptors against the reference descriptors
    /// and attaches a rendering of the matches to the test report.
    func testMatches() throws {

        let inputTexture = try device.loadTexture(name: "butterfly", extension: "png", srgb: false)
        let configuration = SIFT.Configuration(
            inputSize: IntegralSize(
                width: inputTexture.width,
                height: inputTexture.height
            )
        )
        let subject = SIFT(device: device, configuration: configuration)
        let keypoints = subject.getKeypoints(inputTexture)
        let descriptorOctaves = subject.getDescriptors(keypointOctaves: keypoints)
        let foundDescriptors = Array(descriptorOctaves.joined())
        print("Found", foundDescriptors.count, "descriptors")

        let referenceImage: CGImage = {
            let originalImage = CIImage(
                mtlTexture: inputTexture,
                options: [
                    CIImageOption.colorSpace: CGColorSpace(name: CGColorSpace.sRGB)!,
                ]
            )!
                .oriented(.downMirrored)
                .smearColor()
            let cgImage = ciContext.makeCGImage(ciImage: originalImage)
            return cgImage
        }()

        let referenceDescriptors = try loadDescriptors(filename: "butterfly-descriptors")
        print("Loaded \(referenceDescriptors.count) reference descriptors")

        let matches = SIFTDescriptor.match(
            source: foundDescriptors,
            target: referenceDescriptors,
            absoluteThreshold: 300,
            relativeThreshold: 0.6
        )

        print("Found \(matches.count) matches")

        print("drawing matches")
        let renderer = SIFTRenderer()
        attachImage(
            name: "matches",
            uiImage: renderer.drawMatches(
                sourceImage: referenceImage,
                targetImage: referenceImage,
                matches: matches
            )
        )
    }

    /// Measures descriptor matching only; detection and resource loading
    /// happen outside the measured block.
    func testMatchesPerformance() throws {
        let inputTexture = try device.loadTexture(name: "butterfly", extension: "png", srgb: false)
        let configuration = SIFT.Configuration(
            inputSize: IntegralSize(
                width: inputTexture.width,
                height: inputTexture.height
            )
        )
        let subject = SIFT(device: device, configuration: configuration)
        let keypoints = subject.getKeypoints(inputTexture)
        let descriptorOctaves = subject.getDescriptors(keypointOctaves: keypoints)
        let foundDescriptors = Array(descriptorOctaves.joined())
        print("Found", foundDescriptors.count, "descriptors")

        let referenceDescriptors = try loadDescriptors(filename: "butterfly-descriptors")
        print("Loaded \(referenceDescriptors.count) reference descriptors")

        // matchDescriptors(detected: foundDescriptors, reference: referenceDescriptors)

        print("Finding matches")
        measure {
            _ = SIFTDescriptor.match(
                source: foundDescriptors,
                target: referenceDescriptors,
                absoluteThreshold: 300,
                relativeThreshold: 0.6
            )
        }
    }

    /// Returns every `step`-th element of `input` (starting with the first),
    /// truncated to at most `limit` elements.
    private func filter<E>(_ input: [E], every step: Int = 10, limit: Int = 10000) -> [E] {
        Array(input.enumerated().filter({ $0.offset % step == 0 }).map({ $0.element }).prefix(limit))
    }


    /// Loads reference descriptors from a text resource in `bundle`.
    ///
    /// Each non-empty line holds space-separated fields: `y x sigma theta`
    /// followed by 4 * 4 * 8 integer feature values. A line with too few
    /// fields records a test failure and is skipped; a field that cannot be
    /// parsed, a missing resource or an unreadable file makes the call throw.
    private func loadDescriptors(filename: String, extension fileExtension: String = "txt") throws -> [SIFTDescriptor] {
        let headerCount = 4
        let featureCount = 4 * 4 * 8

        let fileURL = try XCTUnwrap(
            bundle.url(forResource: filename, withExtension: fileExtension),
            "Missing test resource \(filename).\(fileExtension)"
        )
        let string = try String(contentsOf: fileURL, encoding: .utf8)

        var descriptors = [SIFTDescriptor]()

        for line in string.split(separator: "\n") {
            let components = line.split(separator: " ")
            guard !components.isEmpty else {
                continue
            }
            guard components.count >= headerCount + featureCount else {
                XCTFail("Malformed descriptor line with \(components.count) fields: \(line)")
                continue
            }

            // The file stores the row (y) before the column (x).
            let y = try XCTUnwrap(Float(components[0]), "Invalid y: \(components[0])")
            let x = try XCTUnwrap(Float(components[1]), "Invalid x: \(components[1])")
            let s = try XCTUnwrap(Float(components[2]), "Invalid sigma: \(components[2])")
            let theta = try XCTUnwrap(Float(components[3]), "Invalid theta: \(components[3])")
            let features = try components[headerCount ..< headerCount + featureCount].map { field in
                try XCTUnwrap(Int(field), "Invalid feature value: \(field)")
            }

            let descriptor = SIFTDescriptor(
                keypoint: SIFTKeypoint(
                    octave: 0,
                    scale: 0,
                    subScale: 0,
                    scaledCoordinate: .zero,
                    absoluteCoordinate: SIMD2(x: x, y: y),
                    sigma: s,
                    value: 0
                ),
                theta: theta,
                // rawFeatures: [],
                features: IntVector(features)
            )
            descriptors.append(descriptor)
        }

        return descriptors
    }

}
219 |
--------------------------------------------------------------------------------
/Sources/MetalShaders/Metal/SIFTInterpolate.metal:
--------------------------------------------------------------------------------
1 | //
2 | // siftInterpolate.metal
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/07.
6 | //
7 |
8 | #include
9 |
10 | #include "Common.hpp"
11 |
12 | #include "../include/SIFTInterpolate.h"
13 |
14 | using namespace metal;
15 |
16 |
// Reports whether the sample at (x, y) in slice `s` of `t` should be treated
// as lying on an edge.
//
// Builds the 2D Hessian from the 3 x 3 neighbourhood and returns true when
// its determinant is not positive, or when trace^2 / determinant reaches
// (edgeThreshold + 1)^2 / edgeThreshold.
bool isOnEdge(
    texture2d_array t [[texture(0)]],
    int x,
    int y,
    int s,
    float edgeThreshold
) {
    const float centre = t.read(ushort2(x, y), s).r;

    // Direct neighbours (north = y - 1, east = x + 1).
    const float north = t.read(ushort2(x, y - 1), s).r;
    const float south = t.read(ushort2(x, y + 1), s).r;
    const float east = t.read(ushort2(x + 1, y), s).r;
    const float west = t.read(ushort2(x - 1, y), s).r;

    // Diagonal neighbours.
    const float southEast = t.read(ushort2(x + 1, y + 1), s).r;
    const float southWest = t.read(ushort2(x - 1, y + 1), s).r;
    const float northEast = t.read(ushort2(x + 1, y - 1), s).r;
    const float northWest = t.read(ushort2(x - 1, y - 1), s).r;

    // Second derivatives. Following the IPOL convention, `hxx` is taken along
    // the y axis and `hyy` along the x axis.
    const float hxx = north + south - 2 * centre;
    const float hyy = east + west - 2 * centre;
    const float hxy = ((southEast - southWest) - (northEast - northWest)) * 0.25;

    const float trace = hxx + hyy;
    const float determinant = (hxx * hyy) - (hxy * hxy);

    // Non-positive determinant -> curvatures have different signs.
    if (determinant <= 0) {
        return true;
    }

    const float threshold = ((edgeThreshold + 1) * (edgeThreshold + 1)) / edgeThreshold;
    const float curvature = (trace * trace) / determinant;

    // At or above the threshold the feature is on an edge.
    return curvature >= threshold;
}
62 |
63 |
// First derivatives of the texture stack at (x, y, s), estimated with central
// differences: half the difference between the two neighbours along each of
// the x, y and slice axes. Returned as (d/dx, d/dy, d/ds).
float3 derivatives3D(
    texture2d_array t [[texture(0)]],
    int x,
    int y,
    int s
) {
    const float dx = (t.read(ushort2(x + 1, y), s).r - t.read(ushort2(x - 1, y), s).r) * 0.5;
    const float dy = (t.read(ushort2(x, y + 1), s).r - t.read(ushort2(x, y - 1), s).r) * 0.5;
    const float ds = (t.read(ushort2(x, y), s + 1).r - t.read(ushort2(x, y), s - 1).r) * 0.5;
    return float3(dx, dy, ds);
}
87 |
88 |
// Value of the texture stack at the sub-sample offset `alpha` from (x, y, s),
// using the first-order term of the Taylor expansion around that sample:
//
//     D(alpha) = D + 0.5 * dot(gradient, alpha)
//
// All three components (x, y and slice) of the gradient contribute.
float interpolateContrast(
    texture2d_array t [[texture(0)]],
    int x,
    int y,
    int s,
    float3 alpha
) {
    const float3 gradient = derivatives3D(t, x, y, s);
    const float v = t.read(ushort2(x, y), s).r;
    return v + dot(gradient, alpha) * 0.5;
}
101 |
102 |
// Computes the 3D Hessian matrix of the texture stack at (x, y, s) using
// finite differences over the neighbouring samples and slices:
//
//   ⎡ Ixx Ixy Ixs ⎤
//   ⎢ Ixy Iyy Iys ⎥
//   ⎣ Ixs Iys Iss ⎦
float3x3 hessian3D(
    texture2d_array t [[texture(0)]],
    int x,
    int y,
    int s
) {
    // Naming: Pos / Neg give the +1 / -1 offset along the named axis.
    const float centre = t.read(ushort2(x, y), s).r;

    // Neighbours along a single axis.
    const float xPos = t.read(ushort2(x + 1, y), s).r;
    const float xNeg = t.read(ushort2(x - 1, y), s).r;
    const float yPos = t.read(ushort2(x, y + 1), s).r;
    const float yNeg = t.read(ushort2(x, y - 1), s).r;
    const float sPos = t.read(ushort2(x, y), s + 1).r;
    const float sNeg = t.read(ushort2(x, y), s - 1).r;

    // Diagonals in the x-y plane (same slice).
    const float xPosYPos = t.read(ushort2(x + 1, y + 1), s).r;
    const float xNegYPos = t.read(ushort2(x - 1, y + 1), s).r;
    const float xPosYNeg = t.read(ushort2(x + 1, y - 1), s).r;
    const float xNegYNeg = t.read(ushort2(x - 1, y - 1), s).r;

    // Diagonals in the x-s plane.
    const float xPosSPos = t.read(ushort2(x + 1, y), s + 1).r;
    const float xNegSPos = t.read(ushort2(x - 1, y), s + 1).r;
    const float xPosSNeg = t.read(ushort2(x + 1, y), s - 1).r;
    const float xNegSNeg = t.read(ushort2(x - 1, y), s - 1).r;

    // Diagonals in the y-s plane.
    const float yPosSPos = t.read(ushort2(x, y + 1), s + 1).r;
    const float yNegSPos = t.read(ushort2(x, y - 1), s + 1).r;
    const float yPosSNeg = t.read(ushort2(x, y + 1), s - 1).r;
    const float yNegSNeg = t.read(ushort2(x, y - 1), s - 1).r;

    // Pure second derivatives.
    const float dxx = xPos + xNeg - 2 * centre;
    const float dyy = yPos + yNeg - 2 * centre;
    const float dss = sPos + sNeg - 2 * centre;

    // Mixed second derivatives.
    const float dxy = (xPosYPos - xNegYPos - xPosYNeg + xNegYNeg) * 0.25;
    const float dxs = (xPosSPos - xNegSPos - xPosSNeg + xNegSNeg) * 0.25;
    const float dys = (yPosSPos - yNegSPos - yPosSNeg + yNegSNeg) * 0.25;

    return float3x3(
        float3(dxx, dxy, dxs),
        float3(dxy, dyy, dys),
        float3(dxs, dys, dss)
    );
}
165 |
166 |
// One refinement step at (x, y, scale): returns the negated inverse of the 3D
// Hessian multiplied by the 3D gradient, i.e. the offset (x, y, slice) that
// the caller compares against its convergence limit.
float3 interpolationStep(
    texture2d_array t [[texture(0)]],
    int x,
    int y,
    int scale
) {
    const float3x3 negatedInverse = -1.0 * invert(hessian3D(t, x, y, scale));
    const float3 gradient = derivatives3D(t, x, y, scale);
    return negatedInverse * gradient;
}
178 |
179 |
// Returns true when (x, y, scale) lies outside the usable region: x and y
// must stay at least `border` samples away from the image edges, and `scale`
// must be within 1 ... scales (inclusive).
bool outOfBounds(int x, int y, int scale, int width, int height, int scales) {
    // TODO: Configurable border.
    const int border = 5;

    const bool xInside = (x >= border) && (x <= width - border - 1);
    const bool yInside = (y >= border) && (y <= height - border - 1);
    const bool scaleInside = (scale >= 1) && (scale <= scales);

    return !(xInside && yInside && scaleInside);
}
191 |
192 |
// Refines one candidate keypoint per thread.
//
// The thread with index `gid` reads `inputKeypoints[gid]`, repeatedly computes
// an interpolation offset and steps one sample / slice towards it until all
// three offset components are below `maxOffset`, then rejects the result if
// its interpolated value is at or below `dogThreshold` or if it lies on an
// edge. `outputKeypoints[gid].converged` is 1 only for keypoints that pass
// every check; all other exits leave it at 0.
kernel void siftInterpolate(
    device SIFTInterpolateOutputKeypoint * outputKeypoints [[buffer(0)]],
    device SIFTInterpolateInputKeypoint * inputKeypoints [[buffer(1)]],
    device SIFTInterpolateParameters & parameters [[buffer(2)]],
    texture2d_array dogTextures [[texture(0)]],
    ushort gid [[thread_position_in_grid]]
) {
    SIFTInterpolateInputKeypoint input = inputKeypoints[gid];

    // Store a "not converged" record up front so every early return below
    // leaves that flag in the output buffer.
    SIFTInterpolateOutputKeypoint output;
    output.converged = 0;
    outputKeypoints[gid] = output;

    float value = dogTextures.read(ushort2(input.x, input.y), input.scale).r;

    // Discard keypoint that is way below the brightness threshold
    if (abs(value) <= parameters.dogThreshold * 0.8) {
        return;
    }

    const int maxIterations = parameters.maxIterations;
    const float maxOffset = parameters.maxOffset;
    const int width = parameters.width;
    const int height = parameters.height;
    const int scales = parameters.numberOfScales;
    const float delta = parameters.octaveDelta;

    int x = input.x;
    int y = input.y;
    int scale = input.scale;

    if (outOfBounds(x, y, scale, width, height, scales)) {
        return;
    }

    bool converged = false;
    float3 alpha = float3(0);

    for (int i = 0; i < maxIterations; i++) {
        alpha = interpolationStep(dogTextures, x, y, scale);

        const float3 magnitude = abs(alpha);
        if ((magnitude.x < maxOffset) && (magnitude.y < maxOffset) && (magnitude.z < maxOffset)) {
            converged = true;
            break;
        }

        // IPOL: move a single step along every axis whose offset exceeds the
        // limit (+1 above +maxOffset, -1 below -maxOffset, otherwise 0).
        // TODO: >=
        x += int(alpha.x > +maxOffset) - int(alpha.x < -maxOffset);
        y += int(alpha.y > +maxOffset) - int(alpha.y < -maxOffset);
        scale += int(alpha.z > +maxOffset) - int(alpha.z < -maxOffset);

        if (outOfBounds(x, y, scale, width, height, scales)) {
            return;
        }
    }

    if (!converged) {
        return;
    }

    value = interpolateContrast(dogTextures, x, y, scale, alpha);

    if (abs(value) <= parameters.dogThreshold) {
        return;
    }

    // Discard keypoint with high edge response
    if (isOnEdge(dogTextures, x, y, scale, parameters.edgeThreshold)) {
        return;
    }

    // Return keypoint
    output.converged = 1;
    output.value = value;
    output.scale = scale;
    output.subScale = alpha.z;
    output.relativeX = x;
    output.relativeY = y;
    output.absoluteX = ((float)x + alpha.x) * delta;
    output.absoluteY = ((float)y + alpha.y) * delta;
    output.alphaX = alpha.x;
    output.alphaY = alpha.y;
    output.alphaZ = alpha.z;
    outputKeypoints[gid] = output;
}
301 |
--------------------------------------------------------------------------------
/Tests/SIFTMetalTests/DifferenceOfGaussiansTests.swift:
--------------------------------------------------------------------------------
1 | //
2 | // DifferenceOfGaussiansTests.swift
3 | // SkyLightTests
4 | //
5 | // Created by Luke Van In on 2022/12/24.
6 | //
7 |
8 | import XCTest
9 | import CoreImage
10 | import CoreImage.CIFilterBuiltins
11 | import MetalPerformanceShaders
12 |
13 | @testable import SIFTMetal
14 |
15 | /*
16 |
17 | final class DifferenceOfGaussiansTests: SharedTestCase {
18 |
19 |
20 | func testComputeDifferenceOfGaussians() throws {
21 |
22 | let referenceGaussianImages = try loadScaleSpaceTextures(
23 | name: "scalespace_butterfly",
24 | extension: "png",
25 | octaves: 1,
26 | scalesPerOctave: 5
27 | )
28 |
29 | let inputTexture = try device.loadTexture(name: "butterfly", extension: "png", srgb: false)
30 | attachImage(
31 | name: "input",
32 | uiImage: ciContext.makeUIImage(
33 | ciImage: CIImage(
34 | mtlTexture: inputTexture,
35 | options: [
36 | CIImageOption.colorSpace: CGColorSpace(name: CGColorSpace.genericGrayGamma2_2)!,
37 | ]
38 | )!
39 | .oriented(.downMirrored)
40 | .smearColor()
41 | )
42 | )
43 |
44 | let configuration = DifferenceOfGaussians.Configuration(
45 | inputDimensions: IntegralSize(
46 | width: Int(inputTexture.width),
47 | height: Int(inputTexture.height)
48 | )
49 | )
50 | let subject = DifferenceOfGaussians(
51 | device: device,
52 | configuration: configuration
53 | )
54 |
55 | print("Encoding")
56 | let commandBuffer = commandQueue.makeCommandBuffer()!
57 | subject.encode(
58 | commandBuffer: commandBuffer,
59 | originalTexture: inputTexture
60 | )
61 | commandBuffer.commit()
62 | commandBuffer.waitUntilCompleted()
63 |
64 | print("Saving attachments")
65 |
66 |
67 | var resultGaussainImages = [MTLTexture]()
68 | // for (o, octave) in subject.octaves.enumerated() {
69 | // for (s, texture) in octave.gaussianTextures.enumerated() {
70 | // resultGaussainImages.append(texture)
71 | // }
72 | // }
73 |
74 | compare(
75 | referenceImages: referenceGaussianImages,
76 | testImages: resultGaussainImages
77 | )
78 |
79 |
80 | // attachImage(
81 | // name: "v(1, 0): luminosity",
82 | // uiImage: makeUIImage(
83 | // ciImage: smearColor(
84 | // ciImage: makeCIImage(
85 | // texture: subject.luminosityTexture
86 | // )
87 | // ),
88 | // context: ciContext
89 | // )
90 | // )
91 |
92 | // attachImage(
93 | // name: "v(1, 0): scaled",
94 | // uiImage: makeUIImage(
95 | // ciImage: smearColor(
96 | // ciImage: makeCIImage(
97 | // texture: subject.scaledTexture
98 | // )
99 | // ),
100 | // context: ciContext
101 | // )
102 | // )
103 |
104 | // attachImage(
105 | // name: "v(1, 0): seed",
106 | // uiImage: makeUIImage(
107 | // ciImage: smearColor(
108 | // ciImage: makeCIImage(
109 | // texture: subject.seedTexture
110 | // )
111 | // ),
112 | // context: ciContext
113 | // )
114 | // )
115 |
116 | // for (o, octave) in subject.octaves.enumerated() {
117 | //
118 | // for (s, texture) in octave.gaussianTextures.enumerated() {
119 | //
120 | // attachImage(
121 | // name: "v[\(o), \(s)]",
122 | // uiImage: makeUIImage(
123 | // ciImage: smearColor(
124 | // ciImage: makeCIImage(
125 | // texture: texture
126 | // )
127 | // ),
128 | // context: ciContext
129 | // )
130 | // )
131 | // }
132 | //
133 | // for (s, texture) in octave.differenceTextures.enumerated() {
134 | //
135 | // attachImage(
136 | // name: "w[\(o), \(s)]",
137 | // uiImage: makeUIImage(
138 | // ciImage: mapColor(
139 | // ciImage: normalizeColor(
140 | // ciImage: smearColor(
141 | // ciImage: makeCIImage(
142 | // texture: texture
143 | // )
144 | // )
145 | // )
146 | // ),
147 | // context: ciContext
148 | // )
149 | // )
150 | // }
151 | // }
152 | }
153 |
154 | private func compare(
155 | referenceImages: [CIImage],
156 | testImages testTextures: [MTLTexture]
157 | ) {
158 | // let referenceImages = referenceTextures.map {
159 | // CIImage(mtlTexture: $0)!.oriented(.downMirrored)
160 | // }
161 |
162 | let testImages = testTextures.map {
163 | let image = CIImage(
164 | mtlTexture: $0,
165 | options: [
166 | CIImageOption.colorSpace: CGColorSpace(name: CGColorSpace.genericGrayGamma2_2)!,
167 | ]
168 | )!
169 | let outputImage = image
170 | .settingAlphaOne(in: image.extent)
171 | .oriented(.downMirrored)
172 | .smearColor()
173 | return outputImage
174 | }
175 |
176 | XCTAssert(referenceImages.count == testImages.count)
177 |
178 | let differenceFilter = CIFilter.colorAbsoluteDifference()
179 |
180 | let colorFilter = CIFilter.colorControls()
181 | colorFilter.brightness = 0.5
182 | colorFilter.contrast = 2.0
183 | colorFilter.saturation = 1
184 |
185 | let thresholdFilter = CIFilter.colorThreshold()
186 | // thresholdFilter.threshold = 0.005
187 | thresholdFilter.threshold = 0.005
188 |
189 | let clampFilter = CIFilter.colorClamp()
190 | // clampFilter.minComponents = CIVector(x: 0.005, y: 0.005, z: 0.005, w: 0)
191 | // clampFilter.maxComponents = CIVector(x: 0.995, y: 0.995, z: 0.995, w: 1)
192 | clampFilter.minComponents = CIVector(x: 0, y: 0, z: 0, w: 0)
193 | clampFilter.maxComponents = CIVector(x: 1, y: 1, z: 1, w: 1)
194 |
195 | for i in 0 ..< referenceImages.count {
196 | let referenceImage = referenceImages[i]
197 | let testImage = testImages[i]
198 |
199 | let scaledTestImage: CIImage
200 |
201 | if testImage.extent.size != referenceImage.extent.size {
202 | scaledTestImage = testImage.samplingNearest().transformed(
203 | by: .identity.scaledBy(
204 | x: referenceImage.extent.width / testImage.extent.width,
205 | y: referenceImage.extent.height / testImage.extent.height
206 | )
207 | )
208 | }
209 | else {
210 | scaledTestImage = testImage
211 | }
212 |
213 | clampFilter.inputImage = scaledTestImage
214 | differenceFilter.inputImage = clampFilter.outputImage
215 |
216 | clampFilter.inputImage = referenceImage
217 | differenceFilter.inputImage2 = clampFilter.outputImage
218 |
219 | colorFilter.inputImage = differenceFilter.outputImage
220 |
221 | thresholdFilter.inputImage = differenceFilter.outputImage
222 |
223 | let differenceImage = colorFilter.outputImage!
224 | let thresholdImage = thresholdFilter.outputImage!
225 |
226 | attachImage(
227 | name: "scalespace \(i): reference",
228 | uiImage: ciContext.makeUIImage(ciImage: referenceImage)
229 | )
230 |
231 | attachImage(
232 | name: "scalespace \(i): test",
233 | uiImage: ciContext.makeUIImage(ciImage: scaledTestImage)
234 | )
235 |
236 | attachImage(
237 | name: "scalespace \(i): difference",
238 | uiImage: ciContext.makeUIImage(ciImage: differenceImage)
239 | )
240 |
241 | attachImage(
242 | name: "scalespace \(i): threshold",
243 | uiImage: ciContext.makeUIImage(ciImage: thresholdImage)
244 | )
245 | }
246 | }
247 |
248 | private func loadScaleSpaceTextures(
249 | name: String,
250 | extension: String,
251 | octaves: Int,
252 | scalesPerOctave: Int,
253 | file: StaticString = #file,
254 | line: UInt = #line
255 | ) throws -> [CIImage] {
256 | var output = [CIImage]()
257 | for o in 0 ..< octaves {
258 | for s in 0 ..< scalesPerOctave {
259 | let octaveName = String(format: "%03d", o)
260 | let scaleName = String(format: "%03d", s)
261 | let filename = "\(name)_o\(octaveName)_s\(scaleName)"
262 | let image = try CIImage(name: filename, extension: `extension`)
263 | // let image = try device.loadTexture(name: filename, extension: `extension`)
264 | output.append(image)
265 | }
266 | }
267 | return output
268 | }
269 | }
270 | */
271 |
272 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/SIFT/SIFT.swift:
--------------------------------------------------------------------------------
1 | //
2 | // SIFT.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2022/12/18.
6 | //
7 |
8 | import Foundation
9 | import OSLog
10 | import Metal
11 |
12 | import MetalShaders
13 |
// File-wide logger for the SIFT pipeline.
// `Bundle.main.bundleIdentifier` is nil for processes without a bundle
// identifier (command-line tools, some test runners), so fall back to a fixed
// subsystem name instead of force-unwrapping and trapping at load time.
private let logger = Logger(
    subsystem: Bundle.main.bundleIdentifier ?? "SIFTMetal",
    category: "SIFT"
)
18 |
19 |
///
/// Performs the Scale Invariant Feature Transform (SIFT) on an image.
///
/// Extracts a set of robust feature descriptors for identifiable points on an image using the SIFT
/// algorithm[1]. Uses Metal compute shaders to execute tasks using GPU hardware.
///
/// This implementation is mostly based on the "Anatomy of the SIFT method"[2] paper and source code
/// published by the Image Processing Online (IPOL) Journal. The implementation has some notable
/// characteristics not explicitly described in the paper. Their relevance toward the accuracy or
/// correctness of the implementation is not indicated.
/// - sRGB images are not converted to linear grayscale space for analysis. The SIFT analysis is performed
///   on the logarithmic (^2.2) color space.
/// - RGB colors are converted to grayscale using the ITU BT.709-5 (NTSC) color conversion formula.
/// - The convolution kernel used for Gaussian blur is centered symmetrically on pixels. This differs from
///   the positioning used by convolution kernels provided by Metal Performance Shaders. A custom
///   convolution kernel is used to provide similarity to the IPOL implementation.
///
/// Note: A novel method is used for matching SIFT descriptors, which is different to the methods used by
/// Lowe and IPOL. Our method matches descriptors using a trie structure with leaf nodes forming a linked
/// list. Construction of the trie takes O(n) time. Queries run in O(1) constant time.
///
/// [1]: https://www.cs.ubc.ca/~lowe/papers/ijcv04.pdf "Distinctive Image Features from Scale-Invariant Keypoints", Lowe, 2004
/// [2]: http://www.ipol.im/pub/art/2014/82/article.pdf "Anatomy of the SIFT Method", Rey-Otero & Delbracio, 2014
///
/// Additional references:
/// See: https://www.cs.ubc.ca/~lowe/keypoints/
/// See: https://en.wikipedia.org/wiki/Scale-invariant_feature_transform
/// See: https://docs.opencv.org/4.x/da/df5/tutorial_py_sift_intro.html
/// See: https://www.youtube.com/watch?v=4AvTMVD9ig0&t=232
/// See: https://www.youtube.com/watch?v=U0wqePj4Mx0
/// See: https://www.youtube.com/watch?v=ram-jbLJjFg&t=2s
/// See: https://www.youtube.com/watch?v=NPcMS49V5hg
/// See: https://github.com/robwhess/opensift
/// See: https://medium.com/jun94-devpblog/cv-13-scale-invariant-local-feature-extraction-3-sift-315b5de72d48
///
public final class SIFT {

    /// Parameters of the SIFT pipeline.
    ///
    /// NOTE(review): inside this class only `inputSize` is read; confirm where
    /// the remaining values are consumed.
    public struct Configuration {

        /// Dimensions of the input image.
        var inputSize: IntegralSize

        /// Threshold over the Difference of Gaussians response (value
        /// relative to scales per octave = 3).
        var differenceOfGaussiansThreshold: Float = 0.0133

        /// Threshold over the ratio of principal curvatures (edgeness).
        var edgeThreshold: Float = 10.0

        /// Maximum number of consecutive unsuccessful interpolations.
        var maximumInterpolationIterations: Int = 5

        /// Width of the border in which keypoints are ignored.
        var imageBorder: Int = 5

        /// Sets how local the analysis of the gradient distribution is.
        var lambdaOrientation: Float = 1.5

        /// Number of bins in the orientation histogram.
        var orientationBins: Int = 36

        /// Threshold for considering local maxima in the orientation histogram.
        var orientationThreshold: Float = 0.8

        /// Number of iterations used to smooth the orientation histogram.
        var orientationSmoothingIterations: Int = 6

        /// Number of normalized histograms in the normalized patch in the
        /// descriptor. This must be a square integer number so that both x
        /// and y axes have the same length.
        var descriptorHistogramsPerAxis: Int = 4

        /// Number of bins in the descriptor histogram.
        var descriptorOrientationBins: Int = 8

        /// How local the descriptor is (size of the descriptor).
        /// Gaussian window of lambdaDescriptor * sigma.
        /// Descriptor patch width of 2 * lambdaDescriptor * sigma.
        var lambdaDescriptor: Float = 6

        public init(inputSize: IntegralSize) {
            self.inputSize = inputSize
        }
    }

    let configuration: Configuration
    let dog: DifferenceOfGaussians
    let octaves: [SIFTOctave]

    private let device: MTLDevice
    private let commandQueue: MTLCommandQueue

    /// Builds the difference-of-Gaussians stage for `configuration.inputSize`
    /// and one `SIFTOctave` per octave that stage exposes.
    ///
    /// - Parameters:
    ///   - device: Metal device used to create kernels and the command queue.
    ///   - configuration: Pipeline parameters.
    public init(
        device: MTLDevice,
        configuration: Configuration
    ) {
        let differenceOfGaussians = DifferenceOfGaussians(
            device: device,
            configuration: DifferenceOfGaussians.Configuration(
                inputDimensions: configuration.inputSize
            )
        )

        // A single gradient kernel instance is handed to every octave.
        let sharedGradientKernel = SIFTGradientKernel(device: device)
        let siftOctaves: [SIFTOctave] = differenceOfGaussians.octaves.map { scale in
            SIFTOctave(
                device: device,
                scale: scale,
                gradientFunction: sharedGradientKernel
            )
        }

        self.device = device
        self.commandQueue = device.makeCommandQueue()!
        self.configuration = configuration
        self.dog = differenceOfGaussians
        self.octaves = siftOctaves
    }

    // MARK: Keypoints

    /// Detects keypoints in `inputTexture` and returns the interpolated
    /// keypoints, one inner array per octave.
    public func getKeypoints(_ inputTexture: MTLTexture) -> [[SIFTKeypoint]] {
        findKeypoints(inputTexture: inputTexture)
        let candidateBuffers = getKeypointsFromOctaves()
        return interpolateKeypoints(keypointOctaves: candidateBuffers)
    }

    /// Encodes the difference-of-Gaussians stage followed by every octave into
    /// one command buffer, commits it and blocks until it has completed.
    private func findKeypoints(inputTexture: MTLTexture) {
        measure(name: "findKeypoints") {
            capture(commandQueue: commandQueue, capture: false) {
                let buffer = commandQueue.makeCommandBuffer()!
                buffer.label = "siftKeypointsCommandBuffer"

                dog.encode(
                    commandBuffer: buffer,
                    originalTexture: inputTexture
                )
                for siftOctave in octaves {
                    siftOctave.encode(commandBuffer: buffer)
                }

                buffer.commit()
                buffer.waitUntilCompleted()
            }
        }
    }

    /// Collects the keypoint buffer of every octave, in octave order, and logs
    /// the total number of keypoints found.
    private func getKeypointsFromOctaves() -> [Buffer] {
        var buffers = [Buffer]()
        measure(name: "getKeypointsFromOctaves") {
            for siftOctave in octaves {
                buffers.append(siftOctave.getKeypoints())
            }
        }
        var totalKeypoints = 0
        for buffer in buffers {
            totalKeypoints += buffer.count
        }
        logger.info("getKeypointsFromOctaves: Found \(totalKeypoints) keypoints")
        return buffers
    }

    /// Runs keypoint interpolation for each octave on its matching buffer.
    private func interpolateKeypoints(keypointOctaves: [Buffer]) -> [[SIFTKeypoint]] {
        var interpolated = [[SIFTKeypoint]]()
        measure(name: "interpolateKeypoints") {
            for (index, candidates) in keypointOctaves.enumerated() {
                let refined = octaves[index].interpolateKeypoints(
                    commandQueue: commandQueue,
                    keypoints: candidates
                )
                interpolated.append(refined)
            }
        }
        return interpolated
    }


    // MARK: Descriptors

    /// Computes descriptors for previously detected keypoints.
    ///
    /// - Parameter keypointOctaves: Keypoints grouped by octave; must contain
    ///   exactly one entry per octave of this instance.
    /// - Returns: Descriptors grouped by octave.
    public func getDescriptors(keypointOctaves: [[SIFTKeypoint]]) -> [[SIFTDescriptor]] {
        precondition(keypointOctaves.count == octaves.count)

        // First pass: orientations for the keypoints of every octave.
        var orientationOctaves = [[SIFTKeypointOrientations]]()
        measure(name: "getDescriptors(orientations)") {
            for (siftOctave, keypoints) in zip(octaves, keypointOctaves) {
                orientationOctaves.append(
                    siftOctave.getKeypointOrientations(
                        commandQueue: commandQueue,
                        keypoints: keypoints
                    )
                )
            }
        }

        // Second pass: descriptors from the orientations of every octave.
        var descriptorOctaves: [[SIFTDescriptor]] = []
        measure(name: "getDescriptors(descriptors)") {
            for (siftOctave, orientations) in zip(octaves, orientationOctaves) {
                descriptorOctaves.append(
                    siftOctave.getDescriptors(
                        commandQueue: commandQueue,
                        orientationOctave: orientations
                    )
                )
            }
        }
        return descriptorOctaves
    }
}
240 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/Utilities/SIFTRenderer.swift:
--------------------------------------------------------------------------------
1 | import UIKit
2 | import CoreGraphics
3 |
4 |
/// Draws debug visualisations of SIFT keypoints, descriptors and matches on
/// top of a darkened copy of the source image(s).
public final class SIFTRenderer {

    public init() {

    }

    /// Draws two sets of keypoints as circles of radius `sigma` over `sourceImage`.
    ///
    /// - Parameters:
    ///   - sourceImage: Image drawn as the background.
    ///   - overlayColor: Colour multiplied over the background.
    ///   - referenceColor: Stroke colour for `referenceKeypoints`.
    ///   - foundColor: Stroke colour for `foundKeypoints` (drawn with the screen blend mode).
    ///   - referenceKeypoints: Keypoints drawn first.
    ///   - foundKeypoints: Keypoints drawn second.
    /// - Returns: The rendered image, sized to `sourceImage`.
    public func drawKeypoints(
        sourceImage: CGImage,
        overlayColor: UIColor = UIColor.black.withAlphaComponent(0.8),
        referenceColor: UIColor = UIColor.red,
        foundColor: UIColor = UIColor.green,
        referenceKeypoints: [SIFTKeypoint],
        foundKeypoints: [SIFTKeypoint]
    ) -> UIImage {

        let bounds = CGRect(x: 0, y: 0, width: sourceImage.width, height: sourceImage.height)

        let renderer = UIGraphicsImageRenderer(size: bounds.size)
        return renderer.image { context in
            let cgContext = context.cgContext

            drawBackground(
                cgContext: cgContext,
                images: [(sourceImage, bounds)],
                canvas: bounds,
                overlayColor: overlayColor
            )

            // Reference keypoints keep the context's current blend mode.
            strokeKeypoints(
                cgContext: cgContext,
                color: referenceColor,
                blendMode: nil,
                keypoints: referenceKeypoints
            )
            strokeKeypoints(
                cgContext: cgContext,
                color: foundColor,
                blendMode: .screen,
                keypoints: foundKeypoints
            )
        }
    }


    /// Draws two sets of descriptors (circle plus orientation line) over `sourceImage`.
    ///
    /// - Returns: The rendered image, sized to `sourceImage`.
    public func drawDescriptors(
        sourceImage: CGImage,
        overlayColor: UIColor = UIColor.black.withAlphaComponent(0.8),
        referenceColor: UIColor = UIColor.green,
        foundColor: UIColor = UIColor.red,
        referenceDescriptors: [SIFTDescriptor],
        foundDescriptors: [SIFTDescriptor]
    ) -> UIImage {

        let bounds = CGRect(x: 0, y: 0, width: sourceImage.width, height: sourceImage.height)

        let renderer = UIGraphicsImageRenderer(size: bounds.size)
        return renderer.image { context in
            let cgContext = context.cgContext

            drawBackground(
                cgContext: cgContext,
                images: [(sourceImage, bounds)],
                canvas: bounds,
                overlayColor: overlayColor
            )

            drawDescriptors(cgContext: cgContext, color: referenceColor, descriptors: referenceDescriptors)
            drawDescriptors(cgContext: cgContext, color: foundColor, descriptors: foundDescriptors)
        }
    }


    /// Draws `sourceImage` and `targetImage` side by side and connects each
    /// matched pair of keypoints with a line.
    ///
    /// Both images must have the same pixel dimensions. Line colours cycle
    /// through a fixed palette; every tenth match is drawn thicker and more opaque.
    ///
    /// - Returns: An image twice as wide as `sourceImage`.
    public func drawMatches(
        sourceImage: CGImage,
        targetImage: CGImage,
        overlayColor: UIColor = UIColor.black.withAlphaComponent(0.8),
        sourceColor: UIColor = UIColor.cyan,
        targetColor: UIColor = UIColor.yellow,
        matches: [SIFTCorrespondence]
    ) -> UIImage {

        precondition(targetImage.width == sourceImage.width)
        precondition(targetImage.height == sourceImage.height)
        let imageSize = CGSize(width: sourceImage.width, height: sourceImage.height)
        let bounds = CGRect(x: 0, y: 0, width: imageSize.width * 2, height: imageSize.height)

        // Source on the left half, target on the right half.
        let sourceOffset = CGPoint(x: 0, y: 0)
        let targetOffset = CGPoint(x: imageSize.width, y: 0)
        let sourceBounds = CGRect(origin: sourceOffset, size: imageSize)
        let targetBounds = CGRect(origin: targetOffset, size: imageSize)

        let sourceDescriptors = matches.map { $0.source }
        let targetDescriptors = matches.map { $0.target }

        let colors: [UIColor] = [
            .systemRed, .systemOrange, .systemYellow, .systemGreen, .systemTeal, .systemBlue, .systemPurple, .systemIndigo
        ]

        let renderer = UIGraphicsImageRenderer(size: bounds.size)
        return renderer.image { context in
            let cgContext = context.cgContext

            drawBackground(
                cgContext: cgContext,
                images: [(sourceImage, sourceBounds), (targetImage, targetBounds)],
                canvas: bounds,
                overlayColor: overlayColor
            )

            // Descriptors
            drawDescriptors(cgContext: cgContext, at: sourceOffset, color: sourceColor, descriptors: sourceDescriptors)
            drawDescriptors(cgContext: cgContext, at: targetOffset, color: targetColor, descriptors: targetDescriptors)

            // Correspondence lines
            cgContext.saveGState()
            cgContext.setLineWidth(1)
            cgContext.setBlendMode(.screen)

            for (i, match) in matches.enumerated() {
                // Cycle through the whole palette, including its last entry.
                let color = colors[i % colors.count]

                let ks = match.source.keypoint.absoluteCoordinate
                let kt = match.target.keypoint.absoluteCoordinate
                let ps = CGPoint(
                    x: sourceOffset.x + CGFloat(ks.x),
                    y: sourceOffset.y + CGFloat(ks.y)
                )
                let pt = CGPoint(
                    x: targetOffset.x + CGFloat(kt.x),
                    y: targetOffset.y + CGFloat(kt.y)
                )

                cgContext.move(to: ps)
                cgContext.addLine(to: pt)

                if i % 10 == 0 {
                    cgContext.setLineWidth(2)
                    cgContext.setStrokeColor(color.withAlphaComponent(0.5).cgColor)
                }
                else {
                    cgContext.setLineWidth(0.5)
                    cgContext.setStrokeColor(color.withAlphaComponent(0.3).cgColor)
                }

                cgContext.strokePath()
            }

            cgContext.restoreGState()
        }
    }


    /// Draws each image into its frame with the y axis flipped about `canvas`,
    /// then multiplies `overlayColor` over the whole canvas.
    private func drawBackground(
        cgContext: CGContext,
        images: [(image: CGImage, frame: CGRect)],
        canvas: CGRect,
        overlayColor: UIColor
    ) {
        cgContext.saveGState()
        cgContext.scaleBy(x: 1, y: -1)
        cgContext.translateBy(x: 0, y: -canvas.height)
        for entry in images {
            cgContext.draw(entry.image, in: entry.frame)
        }
        cgContext.restoreGState()

        cgContext.saveGState()
        cgContext.setBlendMode(.multiply)
        cgContext.setFillColor(overlayColor.cgColor)
        cgContext.fill([canvas])
        cgContext.restoreGState()
    }


    /// Strokes one circle of radius `sigma` per keypoint, centred on its
    /// absolute coordinate. A nil `blendMode` leaves the blend mode untouched.
    private func strokeKeypoints(
        cgContext: CGContext,
        color: UIColor,
        blendMode: CGBlendMode?,
        keypoints: [SIFTKeypoint]
    ) {
        cgContext.saveGState()
        cgContext.setLineWidth(1)
        cgContext.setStrokeColor(color.cgColor)
        if let blendMode {
            cgContext.setBlendMode(blendMode)
        }
        for keypoint in keypoints {
            let radius = CGFloat(keypoint.sigma)
            let circle = CGRect(
                x: CGFloat(keypoint.absoluteCoordinate.x) - radius,
                y: CGFloat(keypoint.absoluteCoordinate.y) - radius,
                width: radius * 2,
                height: radius * 2
            )
            cgContext.addEllipse(in: circle)
        }
        cgContext.strokePath()
        cgContext.restoreGState()
    }


    /// Strokes, for every descriptor, a circle of radius `1.5 * sigma` around
    /// its keypoint and a line from the centre to the rim at angle `theta`.
    ///
    /// - Parameter offset: Translation added to every keypoint coordinate.
    private func drawDescriptors(
        cgContext: CGContext,
        at offset: CGPoint = .zero,
        color: UIColor,
        descriptors: [SIFTDescriptor]
    ) {
        cgContext.saveGState()
        cgContext.setLineWidth(0.5)
        cgContext.setStrokeColor(color.cgColor)
        cgContext.setBlendMode(.screen)
        for descriptor in descriptors {
            let keypoint = descriptor.keypoint
            let radius = 1.5 * CGFloat(keypoint.sigma)
            let center = CGPoint(
                x: offset.x + CGFloat(keypoint.absoluteCoordinate.x),
                y: offset.y + CGFloat(keypoint.absoluteCoordinate.y)
            )
            let circle = CGRect(
                x: center.x - radius,
                y: center.y - radius,
                width: radius * 2,
                height: radius * 2
            )
            cgContext.addEllipse(in: circle)

            // Primary orientation
            let angle = CGFloat(descriptor.theta)
            cgContext.move(to: center)
            cgContext.addLine(
                to: CGPoint(
                    x: center.x + cos(angle) * radius,
                    y: center.y + sin(angle) * radius
                )
            )
        }
        cgContext.strokePath()
        cgContext.restoreGState()
    }
}
257 |
--------------------------------------------------------------------------------
/Sources/SIFTMetal/SIFT/SIFTDescriptor.swift:
--------------------------------------------------------------------------------
1 | //
2 | // SIFTDescriptor.swift
3 | // SkyLight
4 | //
5 | // Created by Luke Van In on 2023/01/02.
6 | //
7 |
8 | import Foundation
9 | import simd
10 |
11 |
12 | public struct SIFTDescriptor {
13 |
14 | // Detected keypoint.
15 | public let keypoint: SIFTKeypoint
16 | // Principal orientation.
17 | public let theta: Float
18 | // Quantized features
19 | public let features: IntVector
20 |
21 | public let rawFeatures: FloatVector
22 |
23 | private let indexKey: FloatVector
24 | private let indexValue: FloatVector
25 |
/// Creates a descriptor and derives its float and index representations.
///
/// - Parameters:
///   - keypoint: Keypoint the descriptor belongs to.
///   - theta: Principal orientation.
///   - features: Quantized features; must not be empty. The reordering
///     below indexes histograms 0...15, so at least 16 groups of 8
///     components are needed.
public init(
    keypoint: SIFTKeypoint,
    theta: Float,
    features: IntVector
) {
    precondition(features.count > 0)
    self.keypoint = keypoint
    self.theta = theta
    self.features = features

    // Quantized components divided by 255.
    let normalized = FloatVector(features.components.map { Float($0) / Float(255) })
    self.rawFeatures = normalized

    // Split the components into consecutive histograms of 8 bins each.
    let binsPerHistogram = 8
    var histograms: [FloatVector] = []
    var start = 0
    while start < normalized.count {
        let end = min(start + binsPerHistogram, normalized.count)
        histograms.append(FloatVector(Array(normalized.components[start ..< end])))
        start += binsPerHistogram
    }

    // Visit order: center (5, 6, 9, 10), corners (0, 3, 12, 15), then edges.
    let visitOrder = [5, 6, 9, 10, 0, 3, 12, 15, 1, 2, 4, 7, 8, 11, 13, 14]
    let reordered = visitOrder.map { histograms[$0] }
    precondition(reordered.count == 16)

    // Index value: all components, concatenated in the new order.
    var flattened: [Float] = []
    for histogram in reordered {
        flattened.append(contentsOf: histogram.components)
    }
    self.indexValue = FloatVector(flattened)

    // Index key: one mean per reordered histogram.
    let means = reordered.map { $0.mean() }
    precondition(means.count == 16)
    self.indexKey = FloatVector(means)
}
91 |
/// Distance between the index values of two descriptors.
///
/// Both index values must hold exactly 128 components.
private static func distance(_ a: SIFTDescriptor, _ b: SIFTDescriptor) -> Float {
    let lhs = a.indexValue
    let rhs = b.indexValue
    precondition(lhs.count == 128)
    precondition(rhs.count == 128)
    return lhs.distance(to: rhs)
}
103 |
/// Matches `source` against `target` and scores the geometric consistency
/// of the resulting correspondences.
///
/// At least 7 matches are required; at most the first 80 are scored.
///
/// - Parameters:
///   - source: Descriptors to look up.
///   - target: Descriptors to search.
///   - absoluteThreshold: Forwarded to `match(source:target:absoluteThreshold:relativeThreshold:)`.
///   - relativeThreshold: Forwarded to `match(source:target:absoluteThreshold:relativeThreshold:)`.
/// - Returns: The result of `compareGeometry`, or 0 when there are too few matches.
public static func matchGeometry(
    source: [SIFTDescriptor],
    target: [SIFTDescriptor],
    absoluteThreshold: Float = 1.176,
    relativeThreshold: Float = 0.6
) -> Float {
    let minimumSampleSize = 7
    let maximumSampleSize = max(minimumSampleSize, 80)

    let matches = match(
        source: source,
        target: target,
        absoluteThreshold: absoluteThreshold,
        relativeThreshold: relativeThreshold
    )
    if matches.count < minimumSampleSize {
        print("matchGeometry: rejected: Not enough matches: \(matches.count) out of \(minimumSampleSize)")
        return 0
    }

    // Only the leading matches are scored.
    let sample = Array(matches.prefix(maximumSampleSize))
    return compareGeometry(matches: sample, minimumSampleSize: minimumSampleSize)
}
144 |
/// Coordinate used when comparing the geometry of matches: the keypoint's
/// absolute coordinate.
///
/// The generic argument is spelled out because a bare `SIMD2` is not a valid
/// return type. `Float` is what the callers require: the results are passed
/// to `simd_length` and compared against `Float` values.
private static func makeCoordinate(_ keypoint: SIFTKeypoint) -> SIMD2<Float> {
    return keypoint.absoluteCoordinate
}
158 |
/// Dot product of `a` and `b`, remapped by `* 0.5 + 0.5` and clamped to 0...1.
///
/// For unit vectors this maps the cosine range -1...1 onto 0...1. The
/// generic argument is spelled out because a bare `SIMD2` is not a valid
/// parameter type.
private static func dotProduct(_ a: SIMD2<Float>, _ b: SIMD2<Float>) -> Float {
    return simd_clamp((simd_dot(a, b) * 0.5) + 0.5, 0, 1)
}
162 |
/// Scores how consistently the matched keypoints are arranged in the source
/// and target images.
///
/// For every window of four consecutive matches, the segment between matches
/// 0 and 1 ("base") is compared with the segment between matches 2 and 3
/// ("test"), once in source coordinates and once in target coordinates. The
/// window score is the squared product of an orientation similarity and a
/// length-ratio similarity, both in 0...1. Windows containing a segment
/// shorter than 2 are skipped.
///
/// The result is the mean of all scores whose z-score is at most 2. When the
/// scores have no spread (a single sample, or all samples equal) the z-scores
/// are treated as 0, so the plain mean is returned instead of NaN.
///
/// - Parameters:
///   - matches: Correspondences to evaluate, in order.
///   - minimumSampleSize: Minimum number of scored windows required.
/// - Returns: The outlier-trimmed mean score, or 0 when fewer than
///   `minimumSampleSize` windows (or none at all) were scored.
private static func compareGeometry(
    matches: [SIFTCorrespondence],
    minimumSampleSize: Int
) -> Float {

    print("compareGeometry:", "samples", matches.count, "minimum samples", minimumSampleSize)

    let minimumLength: Float = 2

    var scores: [Float] = []
    for i in stride(from: 0, to: matches.count - 3, by: 1) {

        // Base segment: match i -> match i + 1.
        let m0 = matches[i + 0]
        let m1 = matches[i + 1]

        let sourceBase = makeCoordinate(m1.source.keypoint) - makeCoordinate(m0.source.keypoint)
        let targetBase = makeCoordinate(m1.target.keypoint) - makeCoordinate(m0.target.keypoint)

        let sourceBaseLength = simd_length(sourceBase)
        let targetBaseLength = simd_length(targetBase)

        guard sourceBaseLength >= minimumLength, targetBaseLength >= minimumLength else {
            continue
        }

        let sourceBaseNormal = simd_normalize(sourceBase)
        let targetBaseNormal = simd_normalize(targetBase)

        // Test segment: match i + 2 -> match i + 3.
        let m2 = matches[i + 2]
        let m3 = matches[i + 3]
        let sourceTest = makeCoordinate(m3.source.keypoint) - makeCoordinate(m2.source.keypoint)
        let targetTest = makeCoordinate(m3.target.keypoint) - makeCoordinate(m2.target.keypoint)

        let sourceTestLength = simd_length(sourceTest)
        let targetTestLength = simd_length(targetTest)

        guard sourceTestLength >= minimumLength, targetTestLength >= minimumLength else {
            continue
        }

        let sourceTestNormal = simd_normalize(sourceTest)
        let targetTestNormal = simd_normalize(targetTest)

        let sourceRatio = sourceTestLength / sourceBaseLength
        let targetRatio = targetTestLength / targetBaseLength

        let sourceDotProduct = dotProduct(sourceTestNormal, sourceBaseNormal)
        let targetDotProduct = dotProduct(targetTestNormal, targetBaseNormal)

        precondition(sourceDotProduct >= 0)
        precondition(sourceDotProduct <= 1)
        precondition(targetDotProduct >= 0)
        precondition(targetDotProduct <= 1)

        let orientationSimilarity = 1.0 - abs(sourceDotProduct - targetDotProduct)
        precondition(orientationSimilarity >= 0)
        precondition(orientationSimilarity <= 1)

        // Smaller ratio over larger ratio.
        let scaleSimilarity: Float
        if sourceRatio < targetRatio {
            scaleSimilarity = simd_clamp(sourceRatio / targetRatio, 0, 1)
        }
        else {
            scaleSimilarity = simd_clamp(targetRatio / sourceRatio, 0, 1)
        }
        precondition(scaleSimilarity >= 0)
        precondition(scaleSimilarity <= 1)

        let similarity = orientationSimilarity * scaleSimilarity
        scores.append(similarity * similarity)
    }

    let count = scores.count

    // No samples means no evidence; also avoids 0 / 0 below.
    guard count > 0, count >= minimumSampleSize else {
        return 0
    }

    let sum = scores.reduce(Float(0), +)
    let mean = sum / Float(count)

    // Sample variance needs at least two samples.
    var variance: Float = 0
    if count > 1 {
        var error: Float = 0
        for score in scores {
            let delta = score - mean
            error += (delta * delta)
        }
        variance = error / Float(count - 1)
    }
    let standardDeviation = sqrt(variance)

    // Without spread every score sits on the mean: z-score 0, not 0 / 0.
    let zscores: [Float] = scores.map { score in
        standardDeviation > 0 ? abs((score - mean) / standardDeviation) : 0
    }

    // Mean of the scores within two standard deviations of the mean.
    var fairMeanSum: Float = 0
    var fairMeanCount: Float = 0
    for (score, zscore) in zip(scores, zscores) where zscore <= 2 {
        fairMeanSum += score
        fairMeanCount += 1
    }
    let fairMean = fairMeanCount > 0 ? fairMeanSum / fairMeanCount : mean

    print(
        "compareGeometry",
        "count", count,
        "mean", mean,
        "fairMean", fairMean,
        "variance", variance,
        "standard deviation", standardDeviation,
        "scores", scores,
        "zscores", zscores
    )

    return fairMean
}
297 |
/// Matches each descriptor in `source` against all of `target`.
///
/// Every source descriptor is passed to the single-descriptor `match` overload;
/// descriptors for which that overload returns `nil` are dropped from the result.
///
/// - Parameters:
///   - source: Descriptors to find matches for.
///   - target: Candidate descriptors to match against.
///   - absoluteThreshold: Upper bound (exclusive) on the accepted feature distance.
///     NOTE(review): this default (1.176) differs from the 300 used by the
///     single-descriptor overload — confirm that is intentional.
///   - relativeThreshold: Ratio applied to the runner-up distance; forwarded unchanged.
/// - Returns: The accepted correspondences, in `source` order.
public static func match(
    source: [SIFTDescriptor],
    target: [SIFTDescriptor],
    absoluteThreshold: Float = 1.176,
    relativeThreshold: Float = 0.6
) -> [SIFTCorrespondence] {
    source.compactMap { candidate in
        match(
            descriptor: candidate,
            target: target,
            absoluteThreshold: absoluteThreshold,
            relativeThreshold: relativeThreshold
        )
    }
}
319 |
/// Finds the best match for `descriptor` in `target` using an exhaustive scan.
///
/// The nearest candidate is accepted only when both tests pass:
/// 1. its distance is strictly below `absoluteThreshold`, and
/// 2. its distance is strictly below `relativeThreshold` times the distance of
///    the second-nearest candidate (nearest / second-nearest ratio test).
///
/// - Parameters:
///   - descriptor: The descriptor to find a match for.
///   - target: Candidate descriptors to search.
///   - absoluteThreshold: Upper bound (exclusive) on the accepted distance.
///   - relativeThreshold: Maximum allowed ratio of best to second-best distance.
/// - Returns: The accepted correspondence, or `nil` when `target` is empty or
///   either test rejects the nearest candidate.
///
/// When `target` holds a single descriptor there is no real runner-up: the
/// second-best distance stays at the `.greatestFiniteMagnitude` sentinel, so in
/// practice only the absolute threshold applies.
public static func match(
    descriptor: SIFTDescriptor,
    target: [SIFTDescriptor],
    absoluteThreshold: Float = 300,
    relativeThreshold: Float = 0.6
) -> SIFTCorrespondence? {

    var bestMatch: SIFTDescriptor?
    var bestMatchDistance: Float = .greatestFiniteMagnitude
    var secondBestMatchDistance: Float?

    for candidate in target {
        let distance = descriptor.indexValue.distance(to: candidate.indexValue)
        if distance < bestMatchDistance {
            // New nearest neighbour: the previous best becomes the runner-up.
            bestMatch = candidate
            secondBestMatchDistance = bestMatchDistance
            bestMatchDistance = distance
        }
        else if let runnerUpDistance = secondBestMatchDistance, distance < runnerUpDistance {
            // Not the nearest, but closer than the current runner-up. Without
            // this branch a candidate seen after the best one is never recorded
            // as second-best, which lets the ratio test pass on ambiguous matches.
            secondBestMatchDistance = distance
        }
    }

    guard let bestMatch, let secondBestMatchDistance else {
        return nil
    }

    guard bestMatchDistance < absoluteThreshold else {
        return nil
    }

    guard bestMatchDistance < (secondBestMatchDistance * relativeThreshold) else {
        return nil
    }

    return SIFTCorrespondence(
        source: descriptor,
        target: bestMatch,
        featureDistance: bestMatchDistance
    )
}
/// Matches each descriptor in `source` against `target` through a `Trie` index.
///
/// All target descriptors are inserted into a trie keyed by `indexKey`, the trie
/// is linked, and every source descriptor is then looked up with the
/// single-descriptor `approximateMatch` overload. Lookups that return `nil`
/// are dropped from the result.
///
/// - Parameters:
///   - source: Descriptors to find matches for.
///   - target: Candidate descriptors used to build the index.
///   - absoluteThreshold: Upper bound (exclusive) on the accepted distance; forwarded unchanged.
///   - relativeThreshold: Ratio applied to the runner-up distance; forwarded unchanged.
/// - Returns: The accepted correspondences, in `source` order.
public static func approximateMatch(
    source: [SIFTDescriptor],
    target: [SIFTDescriptor],
    absoluteThreshold: Float = 300,
    relativeThreshold: Float = 0.6
) -> [SIFTCorrespondence] {
    // Build the lookup structure once, up front.
    let index = Trie(numberOfBins: 8)
    for entry in target {
        index.insert(key: entry.indexKey, value: entry)
    }
    index.link()

    return source.compactMap { query in
        approximateMatch(
            descriptor: query,
            target: index,
            absoluteThreshold: absoluteThreshold,
            relativeThreshold: relativeThreshold
        )
    }
}
389 |
/// Looks up the best match for `descriptor` in a prebuilt `Trie`.
///
/// Requests the two nearest entries from the trie (`radius: 10`, `k: 2`) and
/// accepts the first one only when its distance is strictly below
/// `absoluteThreshold` and strictly below `relativeThreshold` times the
/// distance of the second one.
///
/// - Parameters:
///   - descriptor: The descriptor to find a match for.
///   - target: Trie holding the candidate descriptors.
///   - absoluteThreshold: Upper bound (exclusive) on the accepted distance.
///   - relativeThreshold: Maximum allowed ratio of best to second-best distance.
/// - Returns: The accepted correspondence, or `nil` when the lookup does not
///   return exactly two entries or either test rejects the first entry.
public static func approximateMatch(
    descriptor: SIFTDescriptor,
    target: Trie,
    absoluteThreshold: Float = 300,
    relativeThreshold: Float = 0.6
) -> SIFTCorrespondence? {
    let neighbours = target.nearest(
        key: descriptor.indexKey,
        query: descriptor,
        radius: 10,
        k: 2
    )

    // The ratio test needs both a nearest entry and a runner-up.
    guard neighbours.count == 2 else {
        return nil
    }
    let nearest = neighbours[0]
    let runnerUp = neighbours[1]

    guard
        nearest.distance < absoluteThreshold,
        nearest.distance < (runnerUp.distance * relativeThreshold)
    else {
        return nil
    }

    return SIFTCorrespondence(
        source: descriptor,
        target: nearest.value,
        featureDistance: nearest.distance
    )
}
418 | }
419 |
extension SIFTDescriptor {

    /// Distance between this descriptor's `features` and those of `other`,
    /// as computed by `features.distance(to:)`.
    func distance(to other: SIFTDescriptor) -> Float {
        let separation: Float = features.distance(to: other.features)
        return separation
    }

    /// Squared distance between this descriptor's `features` and those of
    /// `other`, as computed by `features.distanceSquared(to:)` and converted
    /// to `Float`.
    func distanceSquared(to other: SIFTDescriptor) -> Float {
        return Float(features.distanceSquared(to: other.features))
    }
}
430 |
--------------------------------------------------------------------------------