├── .github
└── workflows
│ └── swift.yml
├── .gitignore
├── .justfile
├── .pre-commit-config.yaml
├── .spi.yml
├── .swiftlint.yml
├── .swiftpm
├── configuration
│ └── Package.resolved
└── xcode
│ ├── package.xcworkspace
│ └── contents.xcworkspacedata
│ └── xcshareddata
│ └── xcschemes
│ ├── Compute-Package.xcscheme
│ ├── Compute.xcscheme
│ └── Examples.xcscheme
├── .vscode
└── launch.json
├── Documentation
├── Screenshot 2024-08-04 at 09.57.19.png
└── compute-logo.svg
├── LICENSE
├── Package.resolved
├── Package.swift
├── README.md
├── Sources
├── Compute
│ ├── Compute+Arguments.swift
│ ├── Compute+Pipeline.swift
│ ├── Compute+Task.swift
│ ├── Compute.swift
│ ├── ShaderFunction.swift
│ └── Support.swift
├── Examples
│ ├── .swiftlint.yml
│ ├── Bundle.txt
│ ├── Examples.swift
│ ├── Examples
│ │ ├── BareMetal.swift
│ │ ├── BitonicSort
│ │ │ ├── BitonicSort.metal
│ │ │ └── BitonicSort.swift
│ │ ├── Broken
│ │ │ ├── CountingSort
│ │ │ │ ├── CountingSort.metal
│ │ │ │ └── CountingSort.swift
│ │ │ └── RadixSort
│ │ │ │ ├── RadixSort.metal
│ │ │ │ ├── RadixSort.swift
│ │ │ │ └── RadixSortCPU.swift
│ │ ├── BufferFill.swift
│ │ ├── Checkerboard.swift
│ │ ├── CounterDemo.swift
│ │ ├── GameOfLife
│ │ │ ├── GameOfLife.metal
│ │ │ └── GameOfLife.swift
│ │ ├── HelloWorldDemo.swift
│ │ ├── Histogram.swift
│ │ ├── ImageInvert.swift
│ │ ├── IsSorted.swift
│ │ ├── MaxParallel.swift
│ │ ├── MaxValue.swift
│ │ ├── MemcopyDemo.swift
│ │ ├── RandomFill.swift
│ │ ├── SIMDReduce.swift
│ │ └── ThreadgroupLog.swift
│ ├── Resources
│ │ └── Media.xcassets
│ │ │ ├── Contents.json
│ │ │ ├── baboon-acorn-inverted.imageset
│ │ │ ├── Contents.json
│ │ │ └── baboon-acorn-inverted.png
│ │ │ └── baboon.imageset
│ │ │ ├── Contents.json
│ │ │ └── baboon.png
│ └── Support.swift
└── MetalSupportLite
│ ├── BaseSupport.swift
│ ├── MTLBuffer+Extensions.swift
│ ├── MetalBasicExtensions.swift
│ ├── MetalSupportLite.swift
│ └── TypedMTLBuffer.swift
└── Tests
├── .swiftlint.yml
└── ComputeTests
└── ComputeTests.swift
/.github/workflows/swift.yml:
--------------------------------------------------------------------------------
1 | name: Swift
2 | env:
3 | XCODE_VERSION: "latest-stable"
4 | on:
5 | push:
6 | pull_request:
7 | jobs:
8 | swift-build:
9 | runs-on: macos-15 # macos-latest
10 | steps:
11 | - uses: maxim-lobanov/setup-xcode@v1
12 | with:
13 | xcode-version: ${{ env.XCODE_VERSION }}
14 | - uses: actions/checkout@v3
15 | - run: swift build -v
16 | - run: swift test -v
17 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.toptal.com/developers/gitignore/api/swift,swiftpm,macos
2 | # Edit at https://www.toptal.com/developers/gitignore?templates=swift,swiftpm,macos
3 |
4 | ### macOS ###
5 | # General
6 | .DS_Store
7 | .AppleDouble
8 | .LSOverride
9 |
10 | # Icon must end with two \r
11 | Icon
12 |
13 | # Thumbnails
14 | ._*
15 |
16 | # Files that might appear in the root of a volume
17 | .DocumentRevisions-V100
18 | .fseventsd
19 | .Spotlight-V100
20 | .TemporaryItems
21 | .Trashes
22 | .VolumeIcon.icns
23 | .com.apple.timemachine.donotpresent
24 |
25 | # Directories potentially created on remote AFP share
26 | .AppleDB
27 | .AppleDesktop
28 | Network Trash Folder
29 | Temporary Items
30 | .apdisk
31 |
32 | ### macOS Patch ###
33 | # iCloud generated files
34 | *.icloud
35 |
36 | ### Swift ###
37 | # Xcode
38 | #
39 | # gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore
40 |
41 | ## User settings
42 | xcuserdata/
43 |
44 | ## compatibility with Xcode 8 and earlier (ignoring not required starting Xcode 9)
45 | *.xcscmblueprint
46 | *.xccheckout
47 |
48 | ## compatibility with Xcode 3 and earlier (ignoring not required starting Xcode 4)
49 | build/
50 | DerivedData/
51 | *.moved-aside
52 | *.pbxuser
53 | !default.pbxuser
54 | *.mode1v3
55 | !default.mode1v3
56 | *.mode2v3
57 | !default.mode2v3
58 | *.perspectivev3
59 | !default.perspectivev3
60 |
61 | ## Obj-C/Swift specific
62 | *.hmap
63 |
64 | ## App packaging
65 | *.ipa
66 | *.dSYM.zip
67 | *.dSYM
68 |
69 | ## Playgrounds
70 | timeline.xctimeline
71 | playground.xcworkspace
72 |
73 | # Swift Package Manager
74 | # Add this line if you want to avoid checking in source code from Swift Package Manager dependencies.
75 | # Packages/
76 | # Package.pins
77 | # Package.resolved
78 | # *.xcodeproj
79 | # Xcode automatically generates this directory with a .xcworkspacedata file and xcuserdata
80 | # hence it is not needed unless you have added a package configuration file to your project
81 | # .swiftpm
82 |
83 | .build/
84 |
85 | # CocoaPods
86 | # We recommend against adding the Pods directory to your .gitignore. However
87 | # you should judge for yourself, the pros and cons are mentioned at:
88 | # https://guides.cocoapods.org/using/using-cocoapods.html#should-i-check-the-pods-directory-into-source-control
89 | # Pods/
90 | # Add this line if you want to avoid checking in source code from the Xcode workspace
91 | # *.xcworkspace
92 |
93 | # Carthage
94 | # Add this line if you want to avoid checking in source code from Carthage dependencies.
95 | # Carthage/Checkouts
96 |
97 | Carthage/Build/
98 |
99 | # Accio dependency management
100 | Dependencies/
101 | .accio/
102 |
103 | # fastlane
104 | # It is recommended to not store the screenshots in the git repo.
105 | # Instead, use fastlane to re-generate the screenshots whenever they are needed.
106 | # For more information about the recommended setup visit:
107 | # https://docs.fastlane.tools/best-practices/source-control/#source-control
108 |
109 | fastlane/report.xml
110 | fastlane/Preview.html
111 | fastlane/screenshots/**/*.png
112 | fastlane/test_output
113 |
114 | # Code Injection
115 | # After new code Injection tools there's a generated folder /iOSInjectionProject
116 | # https://github.com/johnno1962/injectionforxcode
117 |
118 | iOSInjectionProject/
119 |
120 | ### SwiftPM ###
121 | Packages
122 | xcuserdata
123 | *.xcodeproj
124 |
125 |
126 | # End of https://www.toptal.com/developers/gitignore/api/swift,swiftpm,macos
127 |
--------------------------------------------------------------------------------
/.justfile:
--------------------------------------------------------------------------------
1 | build-docs:
2 | xcrun xcodebuild docbuild -scheme Compute -derivedDataPath /tmp/compute-docbuild -destination platform=macOS,arch=arm64
3 | cp -r /tmp/compute-docbuild/Build/Products/Debug/Compute.doccarchive ~/Desktop
4 |
5 | xcrun docc process-archive transform-for-static-hosting ~/Desktop/Compute.doccarchive --hosting-base-path / --output-path ~/Desktop/Compute-HTML/
6 |
7 | concurrency-check:
8 | swift clean
9 | swift build -Xswiftc -strict-concurrency=complete
10 |
11 | swift-six-check:
12 | swift clean
13 | SWIFT_VERSION=6 swift build --verbose
14 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v2.3.0
4 | hooks:
5 | - id: check-yaml
6 | - id: end-of-file-fixer
7 | - id: trailing-whitespace
8 |
--------------------------------------------------------------------------------
/.spi.yml:
--------------------------------------------------------------------------------
1 | version: 1
2 | builder:
3 | configs:
4 | - documentation_targets: [Compute]
5 |
--------------------------------------------------------------------------------
/.swiftlint.yml:
--------------------------------------------------------------------------------
1 | only_rules:
2 | - accessibility_label_for_image
3 | - accessibility_trait_for_button
4 | - anonymous_argument_in_multiline_closure
5 | # - anyobject_protocol
6 | - array_init
7 | - attributes
8 | - balanced_xctest_lifecycle
9 | - blanket_disable_command
10 | - block_based_kvo
11 | - capture_variable
12 | - class_delegate_protocol
13 | - closing_brace
14 | - closure_body_length
15 | - closure_end_indentation
16 | - closure_parameter_position
17 | - closure_spacing
18 | - collection_alignment
19 | - colon
20 | - comma
21 | - comma_inheritance
22 | - comment_spacing
23 | - compiler_protocol_init
24 | - computed_accessors_order
25 | - conditional_returns_on_newline
26 | - contains_over_filter_count
27 | - contains_over_filter_is_empty
28 | - contains_over_first_not_nil
29 | - contains_over_range_nil_comparison
30 | - control_statement
31 | - convenience_type
32 | - custom_rules
33 | - cyclomatic_complexity
34 | - deployment_target
35 | - direct_return
36 | - discarded_notification_center_observer
37 | - discouraged_assert
38 | - discouraged_direct_init
39 | - discouraged_none_name
40 | - discouraged_object_literal
41 | - discouraged_optional_boolean
42 | # - discouraged_optional_collection # 1 violations
43 | - duplicate_conditions
44 | - duplicate_enum_cases
45 | - duplicate_imports
46 | - duplicated_key_in_dictionary_literal
47 | - dynamic_inline
48 | - empty_collection_literal
49 | - empty_count
50 | - empty_enum_arguments
51 | - empty_parameters
52 | - empty_parentheses_with_trailing_closure
53 | - empty_string
54 | - empty_xctest_method
55 | - enum_case_associated_values_count
56 | - expiring_todo
57 | # - explicit_acl # 24 violations
58 | - explicit_enum_raw_value
59 | # - explicit_init
60 | - explicit_self
61 | - explicit_top_level_acl
62 | # - explicit_type_interface
63 | - extension_access_modifier
64 | - fallthrough
65 | - fatal_error_message
66 | - file_header
67 | - file_length
68 | # - file_name # 1 violations
69 | - file_name_no_space
70 | # - file_types_order
71 | - final_test_case
72 | - first_where
73 | - flatmap_over_map_reduce
74 | - for_where
75 | - force_cast
76 | - force_try
77 | - force_unwrapping
78 | - function_body_length
79 | # - function_default_parameter_at_end
80 | - function_parameter_count
81 | - generic_type_name
82 | - ibinspectable_in_extension
83 | - identical_operands
84 | - identifier_name
85 | - implicit_getter
86 | - implicit_return
87 | - implicitly_unwrapped_optional
88 | - inclusive_language
89 | - indentation_width
90 | # - inert_defer
91 | - invalid_swiftlint_command
92 | - is_disjoint
93 | - joined_default_parameter
94 | - large_tuple
95 | - last_where
96 | - leading_whitespace
97 | - legacy_cggeometry_functions
98 | - legacy_constant
99 | - legacy_constructor
100 | - legacy_hashing
101 | - legacy_multiple
102 | - legacy_nsgeometry_functions
103 | - legacy_objc_type
104 | - legacy_random
105 | - let_var_whitespace
106 | # - line_length # 21 violations
107 | - literal_expression_end_indentation
108 | - local_doc_comment
109 | - lower_acl_than_parent
110 | - mark
111 | - missing_docs
112 | - modifier_order
113 | - multiline_arguments
114 | - multiline_arguments_brackets
115 | - multiline_function_chains
116 | - multiline_literal_brackets
117 | - multiline_parameters
118 | - multiline_parameters_brackets
119 | - multiple_closures_with_trailing_closure
120 | - nesting
121 | - nimble_operator
122 | # - no_extension_access_modifier
123 | - no_fallthrough_only
124 | - no_grouping_extension
125 | # - no_magic_numbers # 3 violations
126 | - no_space_in_method_call
127 | - non_optional_string_data_conversion
128 | - non_overridable_class_declaration
129 | - notification_center_detachment
130 | - ns_number_init_as_function_reference
131 | - nslocalizedstring_key
132 | - nslocalizedstring_require_bundle
133 | - nsobject_prefer_isequal
134 | - number_separator
135 | - object_literal
136 | # - one_declaration_per_file # 2 violations
137 | - opening_brace
138 | - operator_usage_whitespace
139 | - operator_whitespace
140 | - optional_enum_case_matching
141 | - orphaned_doc_comment
142 | - overridden_super_call
143 | - override_in_extension
144 | - pattern_matching_keywords
145 | - period_spacing
146 | - prefer_nimble
147 | - prefer_self_in_static_references
148 | - prefer_self_type_over_type_of_self
149 | - prefer_zero_over_explicit_init
150 | # - prefixed_toplevel_constant # 1 violations
151 | - private_action
152 | - private_outlet
153 | - private_over_fileprivate
154 | - private_subject
155 | - private_swiftui_state
156 | - private_unit_test
157 | - prohibited_interface_builder
158 | - prohibited_super_call
159 | - protocol_property_accessors_order
160 | - quick_discouraged_call
161 | - quick_discouraged_focused_test
162 | - quick_discouraged_pending_test
163 | - raw_value_for_camel_cased_codable_enum
164 | - reduce_boolean
165 | - reduce_into
166 | - redundant_discardable_let
167 | - redundant_nil_coalescing
168 | - redundant_objc_attribute
169 | - redundant_optional_initialization
170 | - redundant_self_in_closure
171 | - redundant_set_access_control
172 | - redundant_string_enum_value
173 | - redundant_type_annotation
174 | - redundant_void_return
175 | # - required_deinit
176 | - required_enum_case
177 | - return_arrow_whitespace
178 | - return_value_from_void_function
179 | - self_binding
180 | - self_in_property_initialization
181 | - shorthand_argument
182 | - shorthand_operator
183 | - shorthand_optional_binding
184 | - single_test_class
185 | - sorted_enum_cases
186 | - sorted_first_last
187 | - sorted_imports
188 | - statement_position
189 | - static_operator
190 | - static_over_final_class
191 | - strict_fileprivate
192 | - strong_iboutlet
193 | - superfluous_disable_command
194 | - superfluous_else
195 | - switch_case_alignment
196 | - switch_case_on_newline
197 | - syntactic_sugar
198 | - test_case_accessibility
199 | - todo
200 | - toggle_bool
201 | - trailing_closure
202 | - trailing_comma
203 | - trailing_newline
204 | - trailing_semicolon
205 | - trailing_whitespace
206 | - type_body_length
207 | # - type_contents_order # 3 violations
208 | - type_name
209 | - typesafe_array_init
210 | - unavailable_condition
211 | - unavailable_function
212 | - unhandled_throwing_task
213 | - unneeded_break_in_switch
214 | - unneeded_override
215 | - unneeded_parentheses_in_closure_argument
216 | - unneeded_synthesized_initializer
217 | - unowned_variable_capture
218 | - untyped_error_in_catch
219 | # - unused_capture_list
220 | - unused_closure_parameter
221 | - unused_control_flow_label
222 | - unused_declaration
223 | - unused_enumerated
224 | - unused_import
225 | - unused_optional_binding
226 | - unused_setter_value
227 | - valid_ibinspectable
228 | - vertical_parameter_alignment
229 | - vertical_parameter_alignment_on_call
230 | - vertical_whitespace
231 | - vertical_whitespace_between_cases
232 | - vertical_whitespace_closing_braces
233 | - vertical_whitespace_opening_braces
234 | - void_function_in_ternary
235 | - void_return
236 | - weak_delegate
237 | - xct_specific_matcher
238 | - xctfail_message
239 | - yoda_condition
240 |
241 | excluded:
242 | - .build
243 |
--------------------------------------------------------------------------------
/.swiftpm/configuration/Package.resolved:
--------------------------------------------------------------------------------
1 | {
2 | "originHash" : "ec7dcd2fd869101534d790fb80b9a276d9dea53d259be4371c2e3fd99126356f",
3 | "pins" : [
4 | {
5 | "identity" : "everything",
6 | "kind" : "remoteSourceControl",
7 | "location" : "https://github.com/schwa/Everything",
8 | "state" : {
9 | "branch" : "jwight/swift-6",
10 | "revision" : "e24a8c69a41cfb9abdae4e91a4b7f056f2e36bbb"
11 | }
12 | },
13 | {
14 | "identity" : "metalcompilerplugin",
15 | "kind" : "remoteSourceControl",
16 | "location" : "https://github.com/schwa/MetalCompilerPlugin",
17 | "state" : {
18 | "branch" : "jwight/logging",
19 | "revision" : "86239f9d8a6610e6ab2c335a3ef6e96fdd242e48"
20 | }
21 | },
22 | {
23 | "identity" : "swift-algorithms",
24 | "kind" : "remoteSourceControl",
25 | "location" : "https://github.com/apple/swift-algorithms",
26 | "state" : {
27 | "revision" : "f6919dfc309e7f1b56224378b11e28bab5bccc42",
28 | "version" : "1.2.0"
29 | }
30 | },
31 | {
32 | "identity" : "swift-argument-parser",
33 | "kind" : "remoteSourceControl",
34 | "location" : "https://github.com/apple/swift-argument-parser",
35 | "state" : {
36 | "revision" : "41982a3656a71c768319979febd796c6fd111d5c",
37 | "version" : "1.5.0"
38 | }
39 | },
40 | {
41 | "identity" : "swift-async-algorithms",
42 | "kind" : "remoteSourceControl",
43 | "location" : "https://github.com/apple/swift-async-algorithms",
44 | "state" : {
45 | "revision" : "9cfed92b026c524674ed869a4ff2dcfdeedf8a2a",
46 | "version" : "0.1.0"
47 | }
48 | },
49 | {
50 | "identity" : "swift-collections",
51 | "kind" : "remoteSourceControl",
52 | "location" : "https://github.com/apple/swift-collections.git",
53 | "state" : {
54 | "revision" : "9bf03ff58ce34478e66aaee630e491823326fd06",
55 | "version" : "1.1.3"
56 | }
57 | },
58 | {
59 | "identity" : "swift-numerics",
60 | "kind" : "remoteSourceControl",
61 | "location" : "https://github.com/apple/swift-numerics.git",
62 | "state" : {
63 | "revision" : "0a5bc04095a675662cf24757cc0640aa2204253b",
64 | "version" : "1.0.2"
65 | }
66 | },
67 | {
68 | "identity" : "swift-syntax",
69 | "kind" : "remoteSourceControl",
70 | "location" : "https://github.com/apple/swift-syntax.git",
71 | "state" : {
72 | "revision" : "515f79b522918f83483068d99c68daeb5116342d",
73 | "version" : "600.0.0-prerelease-2024-08-20"
74 | }
75 | },
76 | {
77 | "identity" : "swiftfields",
78 | "kind" : "remoteSourceControl",
79 | "location" : "https://github.com/schwa/swiftfields",
80 | "state" : {
81 | "revision" : "8a9715b88509c93557833018f335397153f194bb",
82 | "version" : "0.1.3"
83 | }
84 | },
85 | {
86 | "identity" : "swiftformats",
87 | "kind" : "remoteSourceControl",
88 | "location" : "https://github.com/schwa/swiftformats",
89 | "state" : {
90 | "revision" : "6b51ccec7fccf2f1d4f28f03def6c6b97d17c05b",
91 | "version" : "0.3.6"
92 | }
93 | },
94 | {
95 | "identity" : "swiftgltf",
96 | "kind" : "remoteSourceControl",
97 | "location" : "https://github.com/schwa/SwiftGLTF",
98 | "state" : {
99 | "branch" : "main",
100 | "revision" : "4160ed7e89b7cfcccf5ff7a0cbfd5b055fada771"
101 | }
102 | },
103 | {
104 | "identity" : "swiftgraphics",
105 | "kind" : "remoteSourceControl",
106 | "location" : "https://github.com/schwa/SwiftGraphics",
107 | "state" : {
108 | "branch" : "jwight/develop",
109 | "revision" : "ac3d19558f0ecb8249fd83d97a9cc87f37be24ee"
110 | }
111 | },
112 | {
113 | "identity" : "wrappinghstack",
114 | "kind" : "remoteSourceControl",
115 | "location" : "https://github.com/ksemianov/WrappingHStack",
116 | "state" : {
117 | "revision" : "3300f68b6bf5f8a75ee7ca8a40f136a558053d10",
118 | "version" : "0.2.0"
119 | }
120 | }
121 | ],
122 | "version" : 3
123 | }
124 |
--------------------------------------------------------------------------------
/.swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata:
--------------------------------------------------------------------------------
1 |
2 |
4 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.swiftpm/xcode/xcshareddata/xcschemes/Compute-Package.xcscheme:
--------------------------------------------------------------------------------
1 |
2 |
5 |
9 |
10 |
16 |
22 |
23 |
24 |
30 |
36 |
37 |
38 |
44 |
50 |
51 |
52 |
53 |
54 |
60 |
61 |
63 |
69 |
70 |
71 |
72 |
73 |
83 |
84 |
90 |
91 |
92 |
93 |
99 |
100 |
106 |
107 |
108 |
109 |
111 |
112 |
115 |
116 |
117 |
--------------------------------------------------------------------------------
/.swiftpm/xcode/xcshareddata/xcschemes/Compute.xcscheme:
--------------------------------------------------------------------------------
1 |
2 |
5 |
9 |
10 |
16 |
22 |
23 |
24 |
25 |
26 |
32 |
33 |
43 |
44 |
50 |
51 |
57 |
58 |
59 |
60 |
62 |
63 |
66 |
67 |
68 |
--------------------------------------------------------------------------------
/.swiftpm/xcode/xcshareddata/xcschemes/Examples.xcscheme:
--------------------------------------------------------------------------------
1 |
2 |
5 |
9 |
10 |
16 |
22 |
23 |
24 |
25 |
26 |
32 |
33 |
43 |
45 |
51 |
52 |
53 |
54 |
57 |
58 |
61 |
62 |
65 |
66 |
67 |
68 |
74 |
76 |
82 |
83 |
84 |
85 |
87 |
88 |
91 |
92 |
93 |
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | "configurations": [
3 | {
4 | "type": "lldb",
5 | "request": "launch",
6 | "sourceLanguages": [
7 | "swift"
8 | ],
9 | "args": [],
10 | "cwd": "${workspaceFolder:Compute}",
11 | "name": "Debug Examples",
12 | "program": "${workspaceFolder:Compute}/.build/debug/Examples",
13 | "preLaunchTask": "swift: Build Debug Examples"
14 | },
15 | {
16 | "type": "lldb",
17 | "request": "launch",
18 | "sourceLanguages": [
19 | "swift"
20 | ],
21 | "args": [],
22 | "cwd": "${workspaceFolder:Compute}",
23 | "name": "Release Examples",
24 | "program": "${workspaceFolder:Compute}/.build/release/Examples",
25 | "preLaunchTask": "swift: Build Release Examples"
26 | }
27 | ]
28 | }
29 |
--------------------------------------------------------------------------------
/Documentation/Screenshot 2024-08-04 at 09.57.19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/schwa/Compute/b1ba1e1071c912007bdbb59313e175887335fab8/Documentation/Screenshot 2024-08-04 at 09.57.19.png
--------------------------------------------------------------------------------
/Documentation/compute-logo.svg:
--------------------------------------------------------------------------------
1 |
78 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Jonathan Wight
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Package.resolved:
--------------------------------------------------------------------------------
1 | {
2 | "originHash" : "aec40d709f0bbfd74d34611e9ee223a92bb57eb697224186f914d2d3cd91a04c",
3 | "pins" : [
4 | {
5 | "identity" : "metalcompilerplugin",
6 | "kind" : "remoteSourceControl",
7 | "location" : "https://github.com/schwa/MetalCompilerPlugin",
8 | "state" : {
9 | "revision" : "1db3b5dcabd648e0316f3a472646e72af24a5bde",
10 | "version" : "0.1.0"
11 | }
12 | }
13 | ],
14 | "version" : 3
15 | }
16 |
--------------------------------------------------------------------------------
/Package.swift:
--------------------------------------------------------------------------------
1 | // swift-tools-version: 6.0
2 |
3 | import PackageDescription
4 |
5 | // swiftlint:disable:next explicit_top_level_acl
6 | let package = Package(
7 | name: "Compute",
8 | platforms: [.macOS(.v15), .iOS(.v18)],
9 | products: [
10 | .library(name: "Compute", targets: ["Compute"]),
11 | ],
12 | dependencies: [
13 | .package(url: "https://github.com/schwa/MetalCompilerPlugin", from: "0.1.0")
14 | ],
15 | targets: [
16 | .target(name: "Compute"),
17 | .target(name: "MetalSupportLite"),
18 | .executableTarget(
19 | name: "Examples",
20 | dependencies: [
21 | "Compute",
22 | "MetalSupportLite",
23 | ],
24 | resources: [
25 | .copy("Bundle.txt"),
26 | .process("Resources/Media.xcassets")
27 | ],
28 | plugins: [
29 | .plugin(name: "MetalCompilerPlugin", package: "MetalCompilerPlugin")
30 | ]
31 | ),
32 | .testTarget(name: "ComputeTests", dependencies: ["Compute"])
33 | ]
34 | )
35 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | # Compute
4 |
5 | This project provides a high-level Swift framework for working with Metal compute shaders. It simplifies the process of setting up and executing compute tasks on GPU using Apple's Metal API.
6 |
7 | ## Inspiration
8 |
9 | This project draws inspiration from Apple's SwiftUI shaders introduced in iOS 17 and macOS Sonoma. SwiftUI shaders demonstrate how GPU operations can be made more accessible to developers, enabling complex visual effects with simple Swift code. Compute aims to provide a similar level of abstraction for Metal compute shaders, making it easier to perform data-parallel computations on the GPU.
10 |
11 | ## Usage Example
12 |
13 | Here's a quick example of how to use the Metal Compute Framework to perform a basic computation:
14 |
15 | ```swift
16 | import Compute
17 | import Metal
18 |
19 | // Example usage that adds two arrays of integers using Compute.
20 |
21 | // Metal shader source code
22 | let source = """
23 | #include <metal_stdlib>
24 | using namespace metal;
25 |
26 | kernel void add(device int* inA [[buffer(0)]],
27 | device int* inB [[buffer(1)]],
28 | device int* result [[buffer(2)]],
29 | uint id [[thread_position_in_grid]]) {
30 | result[id] = inA[id] + inB[id];
31 | }
32 | """
33 |
34 | // Set up the compute environment
35 | let device = MTLCreateSystemDefaultDevice()!
36 | let compute = try Computer(device: device)
37 |
38 | // Create input data
39 | let count = 1000
40 | let inA = [Int32](repeating: 1, count: count)
41 | let inB = [Int32](repeating: 2, count: count)
42 | var result = [Int32](repeating: 0, count: count)
43 |
44 | // Create Metal buffers
45 | let bufferA = device.makeBuffer(bytes: inA, length: MemoryLayout<Int32>.stride * count, options: [])!
46 | let bufferB = device.makeBuffer(bytes: inB, length: MemoryLayout<Int32>.stride * count, options: [])!
47 | let bufferResult = device.makeBuffer(length: MemoryLayout<Int32>.stride * count, options: [])!
48 |
49 | // Create a shader library and function
50 | let library = ShaderLibrary.source(source)
51 | let function = library.add
52 |
53 | // Create a compute pipeline and bind arguments.
54 | var pipeline = try compute.makePipeline(function: function)
55 | pipeline.arguments.inA = .buffer(bufferA)
56 | pipeline.arguments.inB = .buffer(bufferB)
57 | pipeline.arguments.result = .buffer(bufferResult)
58 |
59 | // Run the compute pipeline
60 | try compute.run(pipeline: pipeline, width: count)
61 |
62 | // Read back the results
63 | let bufferPointer = bufferResult.contents()
64 | let bufferRawBuffer32 = bufferPointer.bindMemory(to: Int32.self, capacity: count)
65 | let bufferBufferPointer = UnsafeBufferPointer(start: bufferRawBuffer32, count: count)
66 |
67 | // Verify the results
68 | for i in 0..<count {
69 |     assert(bufferBufferPointer[i] == inA[i] + inB[i])
70 | }
71 | ```
113 |
114 | ## License
115 |
116 | This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
117 |
118 | ## Contributing
119 |
120 | Contributions to the Compute framework are welcome.
121 |
122 | Note: Some of our initial documentation and tests were AI-generated.
123 |
124 | ## Links
125 |
126 | > [Metal Overview - Apple Developer](https://developer.apple.com/metal/)
127 |
128 | Apple's main Metal documentation.
129 |
130 | > - [developer.apple.com/metal/Metal-Shading-Language-Specification.pdf](https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf)
131 |
132 | The Metal Shading Language Specification book. This is the definitive guide to writing shaders in Metal.
133 |
134 | > - [developer.apple.com/metal/Metal-Feature-Set-Tables.pdf](https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf)
135 |
136 | The Metal Feature Set Tables book. This is a reference for which features are available on which devices/Metal versions.
137 |
138 | > - [Metal by Example – High-performance graphics and data-parallel programming for iOS and macOS](https://metalbyexample.com)
139 |
140 | Warren Moore's blog is the single best resource for learning Metal programming.
141 |
142 | > - [Introduction to Compute Programming in Metal – Metal by Example](https://metalbyexample.com/introduction-to-compute/)
143 |
144 | Warren has some posts on Compute programming in Metal but they're showing their age a bit. Nevertheless, they're a good starting point.
145 |
146 | > - [Shader | Apple Developer Documentation](https://developer.apple.com/documentation/swiftui/shader)
147 |
148 | SwiftUI's Shader was the primary inspiration for this project.
149 |
150 | > - [Calculating Threadgroup and Grid Sizes | Apple Developer Documentation](https://developer.apple.com/documentation/metal/compute_passes/calculating_threadgroup_and_grid_sizes)
151 | > - [Creating Threads and Threadgroups | Apple Developer Documentation](https://developer.apple.com/documentation/metal/compute_passes/creating_threads_and_threadgroups)
152 |
153 | How to calculate threadgroup and grid sizes. This is a critical concept in Metal compute programming.
154 |
--------------------------------------------------------------------------------
/Sources/Compute/Compute+Arguments.swift:
--------------------------------------------------------------------------------
1 | import Metal
2 | import simd
3 | import SwiftUI
4 |
5 | public extension Compute {
6 |     /// A set of named arguments for a compute pipeline, addressable through dynamic member lookup.
7 |     @dynamicMemberLookup
8 |     struct Arguments {
9 |         /// Backing storage mapping each argument name to its value.
10 |         internal var arguments: [String: Argument]
11 |
12 |         /// Reads or writes the argument stored under `name`; assigning `nil` removes the entry.
13 |         ///
14 |         /// - Parameter name: The name of the argument.
15 |         /// - Returns: The stored argument, or `nil` when nothing is stored under `name`.
16 |         public subscript(dynamicMember name: String) -> Argument? {
17 |             get {
18 |                 arguments[name]
19 |             }
20 |             set(replacement) {
21 |                 // NOTE: It would be nice to assign name as a label to buffers/textures that have no name.
22 |                 arguments[name] = replacement
23 |             }
24 |         }
25 |     }
26 |
27 |     /// A value that can be passed to a compute shader.
28 |     ///
29 |     /// It is made of two closures: one that encodes the value on a compute command encoder, and one that
30 |     /// stores it as a constant in a function constant values object.
31 |     struct Argument {
32 |         /// Encodes the value on the given compute command encoder, using the given `Int`.
33 |         internal var encode: (MTLComputeCommandEncoder, Int) -> Void
34 |
35 |         /// Stores the value in the given function constant values object, using the given `String`.
36 |         internal var constantValue: (MTLFunctionConstantValues, String) -> Void
37 |         /// Creates an argument from an encoding closure and a function-constant closure.
38 |         public init(
39 |             encode: @escaping (MTLComputeCommandEncoder, Int) -> Void,
40 |             constantValue: @escaping (MTLFunctionConstantValues, String) -> Void
41 |         ) {
42 |             self.encode = encode
43 |             self.constantValue = constantValue
44 |         }
45 |     }
46 | }
47 |
48 |
49 | public extension Compute.Argument {
50 |
51 | /// Builds an argument from an integer value.
52 | ///
53 | /// Encoding copies the raw bytes of `value` to the encoder at the requested index. When used as a function
54 | /// constant, only `Int8`, `UInt8`, `Int16`, `UInt16`, `Int32` and `UInt32` are accepted; any other type traps.
55 | ///
56 | /// - Parameter value: The integer value.
57 | /// - Returns: An `Argument` instance wrapping the integer.
58 | static func int(_ value: some BinaryInteger) -> Self {
59 |     Self(
60 |         encode: { commandEncoder, slot in
61 |             withUnsafeBytes(of: value) { bytes in
62 |                 guard let start = bytes.baseAddress else { fatalError("Could not get baseAddress.") }
63 |                 commandEncoder.setBytes(start, length: bytes.count, index: slot)
64 |             }
65 |         },
66 |         constantValue: { constantValues, constantName in
67 |             withUnsafeBytes(of: value) { bytes in
68 |                 guard let start = bytes.baseAddress else { fatalError("Could not get baseAddress.") }
69 |                 // Metal data type for every integer type accepted as a function constant.
70 |                 let metalTypes: [ObjectIdentifier: MTLDataType] = [
71 |                     ObjectIdentifier(Int8.self): .char,
72 |                     ObjectIdentifier(UInt8.self): .uchar,
73 |                     ObjectIdentifier(Int16.self): .short,
74 |                     ObjectIdentifier(UInt16.self): .ushort,
75 |                     ObjectIdentifier(Int32.self): .int,
76 |                     ObjectIdentifier(UInt32.self): .uint
77 |                 ]
78 |                 guard let metalType = metalTypes[ObjectIdentifier(type(of: value))] else {
79 |                     fatalError("Unsupported integer type.")
80 |                 }
81 |                 constantValues.setConstantValue(start, type: metalType, withName: constantName)
82 |             }
83 |         }
84 |     )
85 | }
94 |
95 | /// Creates a float argument.
96 | ///
97 | /// - Parameter value: The float value.
98 | /// - Returns: An `Argument` instance representing the float.
99 | static func float(_ value: Float) -> Self {
100 | .init { encoder, index in
101 | withUnsafeBytes(of: value) { buffer in
102 | guard let baseAddress = buffer.baseAddress else {
103 | fatalError("Could not get baseAddress.")
104 | }
105 | encoder.setBytes(baseAddress, length: buffer.count, index: index)
106 | }
107 | }
108 | constantValue: { constants, name in
109 | withUnsafeBytes(of: value) { buffer in
110 | guard let baseAddress = buffer.baseAddress else {
111 | fatalError("Could not get baseAddress.")
112 | }
113 | constants.setConstantValue(baseAddress, type: .float, withName: name)
114 | }
115 | }
116 | }
117 |
118 | /// Creates a boolean argument.
119 | ///
120 | /// - Parameter value: The boolean value.
121 | /// - Returns: An `Argument` instance representing the boolean.
122 | static func bool(_ value: Bool) -> Self {
123 | .init { encoder, index in
124 | withUnsafeBytes(of: value) { buffer in
125 | guard let baseAddress = buffer.baseAddress else {
126 | fatalError("Could not get baseAddress.")
127 | }
128 | encoder.setBytes(baseAddress, length: buffer.count, index: index)
129 | }
130 | }
131 | constantValue: { constants, name in
132 | withUnsafeBytes(of: value) { buffer in
133 | guard let baseAddress = buffer.baseAddress else {
134 | fatalError("Could not get baseAddress.")
135 | }
136 | constants.setConstantValue(baseAddress, type: .bool, withName: name)
137 | }
138 | }
139 | }
140 |
141 | /// Creates a buffer argument.
142 | ///
143 | /// - Parameters:
144 | /// - buffer: The Metal buffer to be used as an argument.
145 | /// - offset: The offset within the buffer. Defaults to 0.
146 | /// - Returns: An `Argument` instance representing the buffer.
147 | static func buffer(_ buffer: MTLBuffer, offset: Int = 0) -> Self {
148 | .init { encoder, index in
149 | encoder.setBuffer(buffer, offset: offset, index: index)
150 | }
151 | constantValue: { _, _ in
152 | fatalError("Unimplemented")
153 | }
154 | }
155 |
156 | /// Creates a texture argument.
157 | ///
158 | /// - Parameter texture: The Metal texture to be used as an argument.
159 | /// - Returns: An `Argument` instance representing the texture.
160 | static func texture(_ texture: MTLTexture) -> Self {
161 | Self { encoder, index in
162 | encoder.setTexture(texture, index: index)
163 | }
164 | constantValue: { _, _ in
165 | fatalError("Unimplemented")
166 | }
167 | }
168 |
169 | /// Creates an argument from a simd vector
170 | ///
171 | /// - Parameter value: The vector value.
172 | /// - Returns: An `Argument` instance representing the boolean.
173 | static func vector(_ value: V) -> Self where V: SIMD {
174 | .init { encoder, index in
175 | withUnsafeBytes(of: value) { buffer in
176 | guard let baseAddress = buffer.baseAddress else {
177 | fatalError("Could not get baseAddress.")
178 | }
179 | encoder.setBytes(baseAddress, length: buffer.count, index: index)
180 | }
181 | }
182 | constantValue: { _, _ in
183 | fatalError("Unimplemented")
184 | }
185 | }
186 |
187 | /// Creates an argument from a simd vector
188 | ///
189 | /// - Parameter value: The vector value.
190 | /// - Returns: An `Argument` instance representing the boolean.
191 | static func float4(_ value: SIMD4) -> Self {
192 | .vector(value)
193 | }
194 |
195 | /// Creates an argument from a simd vector
196 | ///
197 | /// - Parameter value: The vector value.
198 | /// - Returns: An `Argument` instance representing the boolean.
199 | static func color(_ value: Color) throws -> Self {
200 | let cgColor = value.resolve(in: .init()).cgColor
201 | guard let colorSpace = CGColorSpace(name: CGColorSpace.genericRGBLinear) else {
202 | throw ComputeError.resourceCreationFailure
203 | }
204 | guard let components = cgColor.converted(to: colorSpace, intent: .defaultIntent, options: nil)?.components else {
205 | throw ComputeError.resourceCreationFailure
206 | }
207 |
208 | // TODO: This assumes the pass wants a SIMD4 - we can use reflection to work out what is needed and convert appropriately. This will mean we need to refactor Compute.Argument
209 |
210 | let vector = SIMD4([Float(components[0]), Float(components[1]), Float(components[2]), Float(components[3])])
211 | return .vector(vector)
212 | }
213 |
214 | /// Creates an argument from a threadgroup memory length
215 | static func threadgroupMemoryLength(_ value: Int) -> Self {
216 | Self { encoder, index in
217 | encoder.setThreadgroupMemoryLength(value, index: index)
218 | }
219 | constantValue: { _, _ in
220 | fatalError("Unimplemented")
221 | }
222 | }
223 |
224 | static func buffer(_ array: [T], offset: Int = 0, label: String? = nil) -> Self {
225 | .init { encoder, index in
226 | let buffer = array.withUnsafeBufferPointer { buffer in
227 | return encoder.device.makeBuffer(bytes: buffer.baseAddress!, length: buffer.count * MemoryLayout.stride, options: [])
228 | }
229 | if let label {
230 | buffer?.label = label
231 | }
232 | encoder.setBuffer(buffer, offset: offset, index: index)
233 | }
234 | constantValue: { _, _ in
235 | fatalError("Unimplemented")
236 | }
237 | }
238 |
239 | }
240 |
--------------------------------------------------------------------------------
/Sources/Compute/Compute+Pipeline.swift:
--------------------------------------------------------------------------------
1 | import Metal
2 |
public extension Compute {
    /// A compute pipeline: a shader function, the binding indices of its arguments, the argument
    /// values themselves and the compiled compute pipeline state.
    struct Pipeline {
        /// The shader function this pipeline was built from.
        public let function: ShaderFunction

        /// Binding index for each argument name, taken from the pipeline reflection.
        internal let bindings: [String: Int]

        /// The arguments that `bind(_:)` encodes.
        public var arguments: Arguments

        /// The compute pipeline state created from the shader function.
        public let computePipelineState: MTLComputePipelineState

        /// Builds the pipeline state for `function` and records its argument bindings.
        ///
        /// - Parameters:
        ///   - device: The Metal device used to build the library and pipeline state.
        ///   - function: The shader function to be used in this pipeline.
        ///   - constants: Function constants applied when the function is made. Defaults to empty.
        ///   - arguments: Initial arguments for the pipeline. Defaults to empty.
        /// - Throws: Any error from making the library, function or pipeline state, or
        ///   `ComputeError.resourceCreationFailure` when no reflection data is returned.
        internal init(device: MTLDevice, function: ShaderFunction, constants: [String: Argument] = [:], arguments: [String: Argument] = [:]) throws {
            // Collect the function constants before the function is specialized.
            let constantValues = MTLFunctionConstantValues()
            for (constantName, constant) in constants {
                constant.constantValue(constantValues, constantName)
            }

            let library = try function.library.makelibrary(device: device)
            let metalFunction = try library.makeFunction(name: function.name, constantValues: constantValues)
            metalFunction.label = "\(metalFunction.name)-MTLFunction"

            let descriptor = MTLComputePipelineDescriptor()
            descriptor.label = "\(metalFunction.name)-MTLComputePipelineState"
            descriptor.computeFunction = metalFunction

            // `.bindingInfo` asks for the reflection data the name-to-index table is built from.
            let (state, reflection) = try device.makeComputePipelineState(descriptor: descriptor, options: [.bindingInfo])
            guard let reflection else {
                throw ComputeError.resourceCreationFailure
            }

            self.function = function
            self.bindings = Dictionary(uniqueKeysWithValues: reflection.bindings.map { ($0.name, $0.index) })
            self.computePipelineState = state
            self.arguments = Arguments(arguments: arguments)
        }

        /// The maximum total number of threads per threadgroup for this compute pipeline state.
        public var maxTotalThreadsPerThreadgroup: Int {
            computePipelineState.maxTotalThreadsPerThreadgroup
        }

        /// The thread execution width for this compute pipeline state.
        public var threadExecutionWidth: Int {
            computePipelineState.threadExecutionWidth
        }

        /// Encodes every argument on `commandEncoder` at its binding index.
        ///
        /// - Parameter commandEncoder: The compute command encoder to which the arguments should be bound.
        /// - Throws: `ComputeError.missingBinding` when an argument name has no binding index.
        func bind(_ commandEncoder: MTLComputeCommandEncoder) throws {
            for (argumentName, argument) in arguments.arguments {
                guard let bindingIndex = bindings[argumentName] else {
                    throw ComputeError.missingBinding(argumentName)
                }
                argument.encode(commandEncoder, bindingIndex)
            }
        }
    }
}
84 |
--------------------------------------------------------------------------------
/Sources/Compute/Compute+Task.swift:
--------------------------------------------------------------------------------
1 | import Metal
2 | import os
3 |
public extension Compute {
    /// Represents a compute task that can be executed on the GPU.
    ///
    /// A `Task` wraps a Metal command buffer and hands out a `Dispatcher` backed by a compute
    /// command encoder made from that buffer.
    struct Task {
        /// The label for this task, used for debugging and profiling.
        let label: String?

        /// The logger passed on to the dispatcher.
        let logger: Logger?

        /// The Metal command buffer associated with this task.
        let commandBuffer: MTLCommandBuffer

        /// Executes a block of code with a `Dispatcher`.
        ///
        /// This method is a convenience wrapper around the `run` method.
        ///
        /// - Parameter block: A closure that takes a `Dispatcher` and returns a result.
        /// - Returns: The result of the block execution.
        /// - Throws: Any error that occurs during the execution of the block.
        public func callAsFunction<R>(_ block: (Dispatcher) throws -> R) throws -> R {
            try run(block)
        }

        /// Runs a block of code with a `Dispatcher`.
        ///
        /// A compute command encoder is made from the command buffer, wrapped in a `Dispatcher`
        /// and passed to `block`. Encoding is ended when the block returns or throws.
        ///
        /// - Parameter block: A closure that takes a `Dispatcher` and returns a result.
        /// - Returns: The result of the block execution.
        /// - Throws: `ComputeError.resourceCreationFailure` if unable to create a compute command encoder,
        ///           or any error that occurs during the execution of the block.
        public func run<R>(_ block: (Dispatcher) throws -> R) throws -> R {
            guard let commandEncoder = commandBuffer.makeComputeCommandEncoder() else {
                throw ComputeError.resourceCreationFailure
            }
            commandEncoder.label = "\(label ?? "Unlabeled")-MTLComputeCommandEncoder"

            // End encoding on every exit path, including when the block throws.
            defer {
                commandEncoder.endEncoding()
            }
            let dispatcher = Dispatcher(label: label, logger: logger, commandEncoder: commandEncoder)
            return try block(dispatcher)
        }
    }

    /// Handles the dispatching of compute operations to the GPU.
    ///
    /// A `Dispatcher` sets the pipeline state, binds the pipeline's arguments and encodes a dispatch
    /// on its compute command encoder.
    struct Dispatcher {
        /// The label for this dispatcher, used for debugging and profiling.
        public let label: String?

        /// The logger used for logging information about the dispatch operation.
        public let logger: Logger?

        /// The Metal compute command encoder used to encode compute commands.
        public let commandEncoder: MTLComputeCommandEncoder

        /// Dispatches a compute operation using threadgroups.
        ///
        /// - Parameters:
        ///   - pipeline: The `Pipeline` containing the compute pipeline state and arguments.
        ///   - threadgroupsPerGrid: The number of threadgroups to dispatch in each dimension.
        ///   - threadsPerThreadgroup: The number of threads in each threadgroup.
        /// - Throws: Any error that occurs during the binding of arguments.
        public func callAsFunction(pipeline: Pipeline, threadgroupsPerGrid: MTLSize, threadsPerThreadgroup: MTLSize) throws {
            logger?.info("Dispatching \(threadgroupsPerGrid.shortDescription) threadgroups per grid with \(threadsPerThreadgroup.shortDescription) threads per threadgroup. maxTotalThreadsPerThreadgroup: \(pipeline.computePipelineState.maxTotalThreadsPerThreadgroup), threadExecutionWidth: \(pipeline.computePipelineState.threadExecutionWidth).")

            commandEncoder.setComputePipelineState(pipeline.computePipelineState)
            try pipeline.bind(commandEncoder)
            commandEncoder.dispatchThreadgroups(threadgroupsPerGrid, threadsPerThreadgroup: threadsPerThreadgroup)
        }

        /// Dispatches a compute operation using a specific number of threads.
        ///
        /// - Parameters:
        ///   - pipeline: The `Pipeline` containing the compute pipeline state and arguments.
        ///   - threads: The total number of threads to dispatch in each dimension.
        ///   - threadsPerThreadgroup: The number of threads in each threadgroup.
        /// - Throws: `ComputeError.nonuniformThreadgroupsSizeNotSupported` when the device supports none of
        ///           the `.apple4`, `.common3` or `.metal3` families, or any error that occurs during the
        ///           binding of arguments.
        public func callAsFunction(pipeline: Pipeline, threads: MTLSize, threadsPerThreadgroup: MTLSize) throws {
            let device = commandEncoder.device
            guard device.supportsFamily(.apple4) || device.supportsFamily(.common3) || device.supportsFamily(.metal3) else {
                throw ComputeError.nonuniformThreadgroupsSizeNotSupported
            }

            logger?.info("Dispatch - threads: \(threads.shortDescription), threadsPerThreadgroup \(threadsPerThreadgroup.shortDescription), maxTotalThreadsPerThreadgroup: \(pipeline.computePipelineState.maxTotalThreadsPerThreadgroup), threadExecutionWidth: \(pipeline.computePipelineState.threadExecutionWidth).")

            // The pipeline state only needs to be set once per dispatch.
            commandEncoder.setComputePipelineState(pipeline.computePipelineState)
            try pipeline.bind(commandEncoder)
            commandEncoder.dispatchThreads(threads, threadsPerThreadgroup: threadsPerThreadgroup)
        }
    }
}
101 |
--------------------------------------------------------------------------------
/Sources/Compute/Compute.swift:
--------------------------------------------------------------------------------
1 | import Metal
2 | import os
3 |
/// The main struct that encapsulates the Metal compute environment.
///
/// It owns a device and a command queue, and creates pipelines and runs tasks on them.
public struct Compute {
    /// The Metal device used for compute operations.
    public let device: MTLDevice

    /// The logger passed on to tasks and, when log state is enabled, fed with Metal log messages.
    let logger: Logger?

    /// The Metal command queue used for submitting command buffers.
    let commandQueue: MTLCommandQueue

    /// Initializes a new Compute instance.
    ///
    /// - Parameters:
    ///   - device: The Metal device to use for compute operations.
    ///   - logger: An optional logger for debugging and performance monitoring.
    ///   - useLogState: Enable capture of Metal log messages if available. Only has an effect outside
    ///     the simulator on macOS 15 / iOS 18 or later.
    /// - Throws: `ComputeError.resourceCreationFailure` if unable to create the command queue, or any
    ///           error thrown while creating the log state.
    public init(device: MTLDevice, logger: Logger? = nil, useLogState: Bool = false) throws {
        self.device = device
        self.logger = logger

        // Pick the queue-creation path first; the nil check, label and assignment are shared below.
        let queue: MTLCommandQueue?
        #if !targetEnvironment(simulator)
        if #available(macOS 15, iOS 18, *) {
            let commandQueueDescriptor = MTLCommandQueueDescriptor()
            if useLogState {
                let logStateDescriptor = MTLLogStateDescriptor()
                logStateDescriptor.bufferSize = 16 * 1_024 * 1_024
                let logState = try device.makeLogState(descriptor: logStateDescriptor)
                logState.addLogHandler { _, _, _, message in
                    logger?.log("\(message)")
                }
                commandQueueDescriptor.logState = logState
            }
            queue = device.makeCommandQueue(descriptor: commandQueueDescriptor)
        } else {
            queue = device.makeCommandQueue()
        }
        #else
        queue = device.makeCommandQueue()
        #endif

        guard let queue else {
            throw ComputeError.resourceCreationFailure
        }
        queue.label = "Compute-MTLCommandQueue"
        commandQueue = queue
    }

    /// Executes a compute task.
    ///
    /// A command buffer is created and wrapped in a `Task` that is passed to `block`. When the block
    /// returns or throws, the command buffer is committed and this method waits for it to complete.
    ///
    /// - Parameters:
    ///   - label: An optional label for the task, useful for debugging.
    ///   - block: A closure that takes a `Task` instance and returns a result.
    /// - Returns: The result of the block execution.
    /// - Throws: `ComputeError.resourceCreationFailure` if unable to create the command buffer,
    ///           or any error thrown by the provided block.
    public func task<R>(label: String? = nil, _ block: (Task) throws -> R) throws -> R {
        let commandBufferDescriptor = MTLCommandBufferDescriptor()
        guard let commandBuffer = commandQueue.makeCommandBuffer(descriptor: commandBufferDescriptor) else {
            throw ComputeError.resourceCreationFailure
        }
        commandBuffer.label = "\(label ?? "Unlabeled")-MTLCommandBuffer"
        // Commit and wait on every exit path, including when the block throws.
        defer {
            commandBuffer.commit()
            commandBuffer.waitUntilCompleted()
        }
        let task = Task(label: label, logger: logger, commandBuffer: commandBuffer)
        return try block(task)
    }

    /// Creates a compute pipeline.
    ///
    /// - Parameters:
    ///   - function: The shader function to use for the pipeline.
    ///   - constants: A dictionary of constant values to be used when compiling the shader function.
    ///   - arguments: A dictionary of arguments to be passed to the shader function.
    /// - Returns: A new `Pipeline` instance.
    /// - Throws: Any error that occurs during the creation of the compute pipeline state.
    public func makePipeline(function: ShaderFunction, constants: [String: Argument] = [:], arguments: [String: Argument] = [:]) throws -> Pipeline {
        try Pipeline(device: device, function: function, constants: constants, arguments: arguments)
    }
}

/// Type alias for `Compute` to avoid name conflicts with module name.
public typealias Compute_ = Compute
101 |
--------------------------------------------------------------------------------
/Sources/Compute/ShaderFunction.swift:
--------------------------------------------------------------------------------
1 | import Metal
2 |
/// A description of where a Metal shader library comes from, plus a closure that builds it.
@dynamicMemberLookup
public struct ShaderLibrary: Identifiable, Sendable {
    /// Identifies a library by its origin: a bundle URL with an optional metallib name, or source text.
    public enum ID: Hashable, Sendable {
        case bundle(URL, String?)
        case source(String)
    }

    /// The default shader library loaded from the main bundle.
    public static let `default` = Self.bundle(.main)

    /// Creates a shader library from a bundle.
    ///
    /// - Parameters:
    ///   - bundle: The bundle containing the shader library.
    ///   - name: The name of the metallib file (without extension). If nil, uses the bundle's default library.
    /// - Returns: A new ShaderLibrary instance. Building it traps if the named metallib is not in the bundle.
    public static func bundle(_ bundle: Bundle, name: String? = nil) -> Self {
        ShaderLibrary(id: .bundle(bundle.bundleURL, name)) { device in
            // No name given: fall back to the bundle's default library.
            guard let name else {
                let defaultLibrary = try device.makeDefaultLibrary(bundle: bundle)
                defaultLibrary.label = "Default.MTLLibrary"
                return defaultLibrary
            }
            guard let metallibURL = bundle.url(forResource: name, withExtension: "metallib") else {
                fatalError("Could not load metallib.")
            }
            let namedLibrary = try device.makeLibrary(URL: metallibURL)
            namedLibrary.label = "\(name).MTLLibrary"
            return namedLibrary
        }
    }

    /// Creates a shader library from source code.
    ///
    /// - Parameters:
    ///   - source: The Metal shader source code as a string.
    ///   - enableLogging: Enable Metal shader logging. Outside the simulator this traps when the
    ///     platform is older than macOS 15 / iOS 18.
    /// - Returns: A new ShaderLibrary instance.
    public static func source(_ source: String, enableLogging: Bool = false) -> Self {
        ShaderLibrary(id: .source(source)) { device in
            let compileOptions = MTLCompileOptions()
            #if !targetEnvironment(simulator)
            if enableLogging {
                guard #available(macOS 15, iOS 18, *) else {
                    fatalError("Metal logging is not available on this platform.")
                }
                compileOptions.enableLogging = true
            }
            #endif
            return try device.makeLibrary(source: source, options: compileOptions)
        }
    }

    /// The identity of this library, derived from its origin.
    public var id: ID

    /// A closure that creates an MTLLibrary given an MTLDevice.
    public var make: @Sendable (MTLDevice) throws -> MTLLibrary

    /// Creates a ShaderFunction with the given name.
    ///
    /// - Parameter name: The name of the shader function.
    /// - Returns: A new ShaderFunction instance.
    public func function(name: String) -> ShaderFunction {
        ShaderFunction(library: self, name: name)
    }

    /// Allows accessing shader functions using dynamic member lookup.
    ///
    /// - Parameter name: The name of the shader function.
    /// - Returns: A new ShaderFunction instance.
    public subscript(dynamicMember name: String) -> ShaderFunction {
        function(name: name)
    }

    /// Creates an MTLLibrary for the given device by calling `make`.
    ///
    /// - Parameter device: The Metal device to create the library for.
    /// - Returns: The created MTLLibrary.
    /// - Throws: An error if the library creation fails.
    internal func makelibrary(device: MTLDevice) throws -> MTLLibrary {
        try make(device)
    }
}
89 |
/// Represents a compute shader function within a shader library.
public struct ShaderFunction: Identifiable {
    /// A unique identifier for the shader function, composed of the library's id and the function name.
    public let id: Composite<ShaderLibrary.ID, String>

    /// The shader library containing this function.
    public let library: ShaderLibrary

    /// The name of the shader function.
    public let name: String

    /// An array of shader constants associated with this function.
    public let constants: [ShaderConstant]

    /// Initializes a new shader function.
    ///
    /// - Parameters:
    ///   - library: The shader library containing this function.
    ///   - name: The name of the shader function.
    ///   - constants: An array of shader constants associated with this function. Defaults to an empty array.
    public init(library: ShaderLibrary, name: String, constants: [ShaderConstant] = []) {
        // BUG: https://github.com/schwa/Compute/issues/17 Make shader constants part of name id
        // Until then, two functions that differ only in their constants share the same id.
        self.id = Composite(library.id, name)
        self.library = library
        self.name = name
        self.constants = constants
    }
}
118 |
/// A constant value, tagged with its Metal data type, that can be passed to a shader function.
public struct ShaderConstant {
    /// The data type of the constant.
    var dataType: MTLDataType

    /// Calls the supplied callback with a pointer to the constant's bytes.
    var accessor: ((UnsafeRawPointer) -> Void) -> Void

    /// Initializes a new shader constant with an array value.
    ///
    /// Traps when a base address for the array's bytes cannot be obtained.
    ///
    /// - Parameters:
    ///   - dataType: The Metal data type of the constant.
    ///   - value: The array value of the constant.
    // BUG: https://github.com/schwa/Compute/issues/16 Get away from Any and use an enum.
    public init(dataType: MTLDataType, value: [some Any]) {
        self.dataType = dataType
        self.accessor = { callback in
            value.withUnsafeBytes { bytes in
                guard let start = bytes.baseAddress else {
                    fatalError("Could not get baseAddress.")
                }
                callback(start)
            }
        }
    }

    /// Initializes a new shader constant with a single value.
    ///
    /// - Parameters:
    ///   - dataType: The Metal data type of the constant.
    ///   - value: The value of the constant.
    // BUG: https://github.com/schwa/Compute/issues/16 Get away from Any and use an enum.
    public init(dataType: MTLDataType, value: some Any) {
        self.dataType = dataType
        self.accessor = { callback in
            withUnsafeBytes(of: value) { bytes in
                guard let start = bytes.baseAddress else {
                    fatalError("Could not get baseAddress.")
                }
                callback(start)
            }
        }
    }

    /// Adds the constant value to a MTLFunctionConstantValues object.
    ///
    /// - Parameters:
    ///   - values: The MTLFunctionConstantValues object to add the constant to.
    ///   - name: The name of the constant in the shader code.
    public func add(to values: MTLFunctionConstantValues, name: String) {
        accessor { bytes in
            values.setConstantValue(bytes, type: dataType, withName: name)
        }
    }
}
174 |
--------------------------------------------------------------------------------
/Sources/Compute/Support.swift:
--------------------------------------------------------------------------------
1 | import Metal
2 |
/// The errors thrown by the Compute framework.
public enum ComputeError: Error {
    /// An argument was supplied whose name has no binding index in the pipeline.
    /// - Parameter String: The name of the missing binding.
    case missingBinding(String)

    /// A thread-count dispatch was requested on a device that supports none of the
    /// `.apple4`, `.common3` or `.metal3` GPU families.
    case nonuniformThreadgroupsSizeNotSupported

    /// A required resource, such as a command queue or buffer, could not be created.
    case resourceCreationFailure
}
15 |
public extension Compute {
    /// Dispatches a compute operation using a more convenient syntax.
    ///
    /// This wraps `task(label:_:)` and the task's `run`, so the caller receives a `Dispatcher`
    /// directly instead of nesting two closures.
    ///
    /// - Parameters:
    ///   - label: An optional label for the dispatch operation, useful for debugging.
    ///   - block: A closure that takes a `Dispatcher` and returns a result.
    /// - Returns: The result of the block execution.
    /// - Throws: Any error that occurs during the execution of the block or the underlying task.
    func dispatch<R>(label: String? = nil, _ block: (Dispatcher) throws -> R) throws -> R {
        try task(label: label) { task in
            try task { dispatch in
                try block(dispatch)
            }
        }
    }
}
34 |
public extension Compute {
    /// Returns a copy of `pipeline` whose arguments are overlaid with `overrides`.
    ///
    /// On a name clash the override wins. The pipeline is returned unchanged when `overrides` is nil.
    private func merging(_ overrides: [String: Argument]?, into pipeline: Pipeline) -> Pipeline {
        guard let overrides else {
            return pipeline
        }
        var merged = pipeline
        merged.arguments.arguments.merge(overrides) { _, override in override }
        return merged
    }

    /// Runs a pipeline once with an explicit thread count.
    ///
    /// - Parameters:
    ///   - pipeline: The compute pipeline to run.
    ///   - arguments: Optional additional arguments to merge with the pipeline's existing arguments.
    ///   - threads: The total number of threads to dispatch in each dimension.
    ///   - threadsPerThreadgroup: The number of threads in each threadgroup.
    /// - Throws: Any error that occurs during the execution of the compute pipeline.
    func run(pipeline: Pipeline, arguments: [String: Argument]? = nil, threads: MTLSize, threadsPerThreadgroup: MTLSize) throws {
        let pipeline = merging(arguments, into: pipeline)
        try task { task in
            try task.run { dispatcher in
                try dispatcher(pipeline: pipeline, threads: threads, threadsPerThreadgroup: threadsPerThreadgroup)
            }
        }
    }

    /// Runs a pipeline once with an explicit threadgroup count.
    ///
    /// - Parameters:
    ///   - pipeline: The compute pipeline to run.
    ///   - arguments: Optional additional arguments to merge with the pipeline's existing arguments.
    ///   - threadgroupsPerGrid: The number of threadgroups to dispatch in each dimension.
    ///   - threadsPerThreadgroup: The number of threads in each threadgroup.
    /// - Throws: Any error that occurs during the execution of the compute pipeline.
    func run(pipeline: Pipeline, arguments: [String: Argument]? = nil, threadgroupsPerGrid: MTLSize, threadsPerThreadgroup: MTLSize) throws {
        let pipeline = merging(arguments, into: pipeline)
        try task { task in
            try task.run { dispatcher in
                try dispatcher(pipeline: pipeline, threadgroupsPerGrid: threadgroupsPerGrid, threadsPerThreadgroup: threadsPerThreadgroup)
            }
        }
    }

    /// Runs a compute pipeline over a one-dimensional range of threads.
    ///
    /// The threadgroup width is the pipeline state's `maxTotalThreadsPerThreadgroup`.
    ///
    /// - Parameters:
    ///   - pipeline: The compute pipeline to run.
    ///   - arguments: Optional additional arguments to merge with the pipeline's existing arguments.
    ///   - width: The number of threads to dispatch.
    /// - Throws: Any error that occurs during the execution of the compute pipeline.
    @available(*, deprecated, message: "Deprecated")
    func run(pipeline: Pipeline, arguments: [String: Argument]? = nil, width: Int) throws {
        let pipeline = merging(arguments, into: pipeline)
        try task { task in
            try task.run { dispatcher in
                let threadgroupWidth = pipeline.computePipelineState.maxTotalThreadsPerThreadgroup
                let grid = MTLSize(width: width, height: 1, depth: 1)
                let threadgroup = MTLSize(width: threadgroupWidth, height: 1, depth: 1)
                try dispatcher(pipeline: pipeline, threads: grid, threadsPerThreadgroup: threadgroup)
            }
        }
    }

    /// Runs a compute pipeline over a two-dimensional range of threads.
    ///
    /// The threadgroup width is the integer square root of the pipeline state's
    /// `maxTotalThreadsPerThreadgroup`; the height is that maximum divided by the width.
    ///
    /// - Parameters:
    ///   - pipeline: The compute pipeline to run.
    ///   - arguments: Optional additional arguments to merge with the pipeline's existing arguments.
    ///   - width: The number of threads to dispatch horizontally.
    ///   - height: The number of threads to dispatch vertically.
    /// - Throws: Any error that occurs during the execution of the compute pipeline.
    @available(*, deprecated, message: "Deprecated")
    func run(pipeline: Pipeline, arguments: [String: Argument]? = nil, width: Int, height: Int) throws {
        let pipeline = merging(arguments, into: pipeline)
        try task { task in
            try task.run { dispatcher in
                let threadLimit = pipeline.computePipelineState.maxTotalThreadsPerThreadgroup
                let threadgroupWidth = Int(Double(threadLimit).squareRoot())
                let threadgroupHeight = threadLimit / threadgroupWidth

                let grid = MTLSize(width: width, height: height, depth: 1)
                let threadgroup = MTLSize(width: threadgroupWidth, height: threadgroupHeight, depth: 1)
                try dispatcher(pipeline: pipeline, threads: grid, threadsPerThreadgroup: threadgroup)
            }
        }
    }
}
110 |
internal extension MTLSize {
    /// The size formatted as `[width, height, depth]`, used in log messages.
    var shortDescription: String {
        let dimensions = [width, height, depth].map { String($0) }
        return "[" + dimensions.joined(separator: ", ") + "]"
    }
}
116 |
117 | // MARK: -
118 |
/// A value that stores a variadic pack of child values as a single tuple.
///
/// The generic parameter pack `each T` must be declared on the type; without it the
/// `repeat each T` expansions used by the stored property and the initializer do not resolve.
public struct Composite<each T> {
    /// The wrapped child values, in the order they were passed to `init`.
    private let children: (repeat each T)

    /// Creates a composite from the given child values.
    /// - Parameter children: The values to store, one per element of the pack.
    public init(_ children: repeat each T) {
        self.children = (repeat each children)
    }
}
126 |
extension Composite: Equatable where repeat each T: Equatable {
    /// Returns `true` only when every pair of corresponding children compares equal.
    public static func == (lhs: Self, rhs: Self) -> Bool {
        for (lhsChild, rhsChild) in repeat (each lhs.children, each rhs.children) {
            if lhsChild != rhsChild {
                return false
            }
        }
        return true
    }
}
137 |
extension Composite: Hashable where repeat each T: Hashable {
    /// Feeds each child, in pack order, into `hasher`.
    public func hash(into hasher: inout Hasher) {
        for element in repeat (each children) {
            hasher.combine(element)
        }
    }
}
145 |
/// `Composite` is `Sendable` whenever every element type in its pack is.
extension Composite: Sendable where repeat each T: Sendable {}
148 |
149 |
--------------------------------------------------------------------------------
/Sources/Examples/.swiftlint.yml:
--------------------------------------------------------------------------------
1 | disabled_rules:
2 | - explicit_top_level_acl
3 | - fatal_error_message
4 | - force_cast
5 | - force_try
6 | - force_unwrapping
7 | - function_body_length
8 | - identifier_name
9 | - line_length
10 | - missing_docs
11 |
--------------------------------------------------------------------------------
/Sources/Examples/Bundle.txt:
--------------------------------------------------------------------------------
1 | This file is kept here to act as a resource for SPM so that we get a `Bundle.module`.
2 |
--------------------------------------------------------------------------------
/Sources/Examples/Examples.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 |
/// A runnable example. Conforming types expose one static, asynchronous, throwing entry point.
protocol Demo {
    /// Runs the demo.
    static func main() async throws
}
6 |
/// Command-line entry point for the example demos.
///
/// - With no argument, prints an indexed menu and reads the index to run from standard input.
/// - With the argument `all`, runs every demo in list order.
/// - With any other argument, runs the demo whose type name matches it case-insensitively.
@main
enum Examples {
    static func main() async throws {
        let demos: [Demo.Type] = [
            BareMetalVsCompute.self,
            BufferFill.self,
            Checkerboard.self,
            GameOfLife.self,
            HelloWorldDemo.self,
            ImageInvert.self,
            MaxValue.self,
            MemcopyDemo.self,
            RandomFill.self,
            SIMDReduce.self,
            ThreadgroupLogging.self,
            BitonicSortDemo.self,
            Histogram.self,
            CounterDemo.self,
            MaxParallel.self,
            IsSorted.self,
        ]

        let argument: String? = CommandLine.arguments.count > 1 ? CommandLine.arguments[1].lowercased() : nil
        if let argument {
            if argument == "all" {
                for demo in demos {
                    print("\(demo)")
                    try await demo.main()
                }
            }
            else {
                guard let demo = demos.first(where: { String(describing: $0).lowercased() == argument }) else {
                    fatalError("No demo with name: \(argument)")
                }
                try await demo.main()
            }
        }
        else {
            for (index, demo) in demos.enumerated() {
                print("\(index): \(String(describing: demo))")
            }
            print("Choice", terminator: ": ")
            // Validate the input instead of force-unwrapping: `readLine()` returns nil at end of
            // input, the text may not be a number, and the number may not be a valid index.
            guard let line = readLine(),
                  let choice = Int(line.trimmingCharacters(in: .whitespaces)),
                  demos.indices.contains(choice)
            else {
                print("Invalid choice; expected a number from 0 to \(demos.count - 1).")
                return
            }
            let demo = demos[choice]
            print("Running \(String(describing: demo))...")
            try await demo.main()
        }
    }
}
56 |
--------------------------------------------------------------------------------
/Sources/Examples/Examples/BareMetal.swift:
--------------------------------------------------------------------------------
1 | import Compute
2 | import Metal
3 |
/// Fills one buffer twice — first through the `Compute` wrapper, then through direct Metal
/// calls — and asserts after each pass that every byte holds the expected value.
enum BareMetalVsCompute: Demo {
    /// Metal source for a kernel that writes `value` into the byte of `output` addressed by the
    /// thread's position in the grid.
    static let source = #"""
        #include <metal_stdlib>

        using namespace metal;

        uint thread_position_in_grid [[thread_position_in_grid]];

        kernel void memset(device uchar* output [[buffer(0)]], constant uchar &value [[buffer(1)]]) {
            output[thread_position_in_grid] = value;
        }
        """#

    static func main() async throws {
        // Get the default Metal device
        let device = MTLCreateSystemDefaultDevice()!
        // Create a buffer and confirm it is zeroed
        let buffer = device.makeBuffer(length: 16_384)!
        assert(UnsafeRawBufferPointer(start: buffer.contents(), count: buffer.length).allSatisfy { $0 == 0x00 })
        // Run compute and confirm the output is correct
        try compute(device: device, buffer: buffer, value: 0x88)
        assert(UnsafeRawBufferPointer(start: buffer.contents(), count: buffer.length).allSatisfy { $0 == 0x88 })
        // Run bareMetal and confirm the output is correct
        try bareMetal(device: device, buffer: buffer, value: 0xFF)
        assert(UnsafeRawBufferPointer(start: buffer.contents(), count: buffer.length).allSatisfy { $0 == 0xFF })
    }

    /// Fills `buffer` with `value` using Metal API calls directly.
    /// - Throws: `ComputeError.resourceCreationFailure` when a Metal object cannot be created,
    ///   `ComputeError.missingBinding` when a kernel argument is absent from the reflection data,
    ///   or any error thrown by library or pipeline creation.
    static func bareMetal(device: MTLDevice, buffer: MTLBuffer, value: UInt8) throws {
        // Create shader library from source
        let library = try device.makeLibrary(source: source, options: .init())
        guard let function = library.makeFunction(name: "memset") else {
            throw ComputeError.resourceCreationFailure
        }

        // Create compute pipeline for memset operation
        let computePipelineDescriptor = MTLComputePipelineDescriptor()
        computePipelineDescriptor.computeFunction = function
        let (computePipelineState, reflection) = try device.makeComputePipelineState(descriptor: computePipelineDescriptor, options: [.bindingInfo])
        guard let reflection else {
            throw ComputeError.resourceCreationFailure
        }
        guard let outputIndex = reflection.bindings.first(where: { $0.name == "output" })?.index else {
            throw ComputeError.missingBinding("output")
        }
        guard let valueIndex = reflection.bindings.first(where: { $0.name == "value" })?.index else {
            throw ComputeError.missingBinding("value")
        }

        // Execute compute pipeline
        guard let commandQueue = device.makeCommandQueue() else {
            throw ComputeError.resourceCreationFailure
        }
        commandQueue.label = "memcpy-MTLCommandQueue"

        let commandBufferDescriptor = MTLCommandBufferDescriptor()
        guard let commandBuffer = commandQueue.makeCommandBuffer(descriptor: commandBufferDescriptor) else {
            throw ComputeError.resourceCreationFailure
        }
        commandBuffer.label = "memcpy-MTLCommandBuffer"
        guard let computeCommandEncoder = commandBuffer.makeComputeCommandEncoder() else {
            throw ComputeError.resourceCreationFailure
        }
        computeCommandEncoder.label = "memcpy-MTLComputeCommandEncoder"
        computeCommandEncoder.setComputePipelineState(computePipelineState)
        // `offset` is a byte offset into the buffer, not a binding index: bind from the start.
        computeCommandEncoder.setBuffer(buffer, offset: 0, index: outputIndex)
        var value = value
        computeCommandEncoder.setBytes(&value, length: MemoryLayout.size(ofValue: value), index: valueIndex)
        // One thread per byte of the buffer.
        let threadsPerGrid = MTLSize(width: buffer.length, height: 1, depth: 1)
        let maxTotalThreadsPerThreadgroup = computePipelineState.maxTotalThreadsPerThreadgroup
        let threadsPerThreadgroup = MTLSize(width: maxTotalThreadsPerThreadgroup, height: 1, depth: 1)
        computeCommandEncoder.dispatchThreads(threadsPerGrid, threadsPerThreadgroup: threadsPerThreadgroup)
        computeCommandEncoder.endEncoding()
        commandBuffer.commit()
        commandBuffer.waitUntilCompleted()
    }

    /// Fills `buffer` with `value` using the `Compute` wrapper.
    static func compute(device: MTLDevice, buffer: MTLBuffer, value: UInt8) throws {
        // Set up.
        let compute = try Compute(device: device)

        // Create shader library from source
        let library = ShaderLibrary.source(source)

        // Create compute pipeline for memset operation
        var fill = try compute.makePipeline(function: library.memset)

        // Set buffer and fill value arguments
        fill.arguments.output = .buffer(buffer)
        fill.arguments.value = .int(value)

        // Execute compute pipeline
        try compute.run(pipeline: fill, width: buffer.length)
    }
}
96 |
--------------------------------------------------------------------------------
/Sources/Examples/Examples/BitonicSort/BitonicSort.metal:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 |
4 | using namespace metal;
5 |
6 | uint thread_position_in_grid [[thread_position_in_grid]];
7 |
// Compares one pair of entries per thread and swaps them when the lower-indexed value is
// larger. The pair's indices are derived from the thread position, `groupWidth`,
// `groupHeight` and `stepIndex`.
[[kernel]]
void bitonicSort(
    constant uint &numEntries [[buffer(0)]],
    constant uint &groupWidth [[buffer(1)]],
    constant uint &groupHeight [[buffer(2)]],
    constant uint &stepIndex [[buffer(3)]],
    device uint *entries [[buffer(4)]]
) {
    const uint tid = thread_position_in_grid;
    const uint column = tid & (groupWidth - 1);
    const uint group = tid / groupWidth;
    const uint lower = column + (groupHeight + 1) * group;

    uint distance;
    if (stepIndex == 0) {
        distance = groupHeight - 2 * column;
    } else {
        distance = (groupHeight + 1) / 2;
    }
    const uint upper = lower + distance;

    // Only touch pairs that lie fully inside the array (for non-power of 2 input sizes)
    if (upper < numEntries) {
        const uint lowerValue = entries[lower];
        const uint upperValue = entries[upper];
        // Swap entries if value is descending
        if (lowerValue > upperValue) {
            entries[lower] = upperValue;
            entries[upper] = lowerValue;
        }
    }
}
32 |
--------------------------------------------------------------------------------
/Sources/Examples/Examples/BitonicSort/BitonicSort.swift:
--------------------------------------------------------------------------------
1 | import Compute
2 | import Metal
3 | import os
4 |
5 | // swiftlint:disable force_unwrapping
6 |
/// Sorts 1.5 million shuffled `UInt32` values on the GPU with the `bitonicSort` kernel and
/// asserts that the result matches a CPU sort of the same data.
public struct BitonicSortDemo: Demo {

    public static func main() async throws {

        var entries: [UInt32] = timeit("Creating entries") {
            (0 ..< 1_500_000).shuffled()
        }

        let device = MTLCreateSystemDefaultDevice()!
        let numEntries = entries.count
        let buffer: MTLBuffer = try entries.withUnsafeMutableBufferPointer { buffer in
            let buffer = UnsafeMutableRawBufferPointer(buffer)
            return try device.makeBufferEx(bytes: buffer.baseAddress!, length: buffer.count)
        }

        let function = ShaderLibrary.bundle(.module).bitonicSort
        let numStages = log2(nextPowerOfTwo(numEntries))

        let compute = try Compute(device: device)

        var pipeline = try compute.makePipeline(function: function, arguments: [
            "numEntries": .int(numEntries),
            "entries": .buffer(buffer),
        ])

        print("Running \(numStages) compute stages")

        // Enough threadgroups to cover every entry, rounded up to a multiple of the execution width.
        var threadgroupsPerGrid = (entries.count + pipeline.maxTotalThreadsPerThreadgroup - 1) / pipeline.maxTotalThreadsPerThreadgroup
        threadgroupsPerGrid = (threadgroupsPerGrid + pipeline.threadExecutionWidth - 1) / pipeline.threadExecutionWidth * pipeline.threadExecutionWidth

        try timeit("GPU") {
            try compute.task { task in
                try task { dispatch in
                    // Stage s runs s + 1 steps; each step is one dispatch with its own group geometry.
                    for stageIndex in 0 ..< numStages {
                        for stepIndex in 0 ..< (stageIndex + 1) {
                            let groupWidth = 1 << (stageIndex - stepIndex)
                            let groupHeight = 2 * groupWidth - 1
                            pipeline.arguments.groupWidth = .int(groupWidth)
                            pipeline.arguments.groupHeight = .int(groupHeight)
                            pipeline.arguments.stepIndex = .int(stepIndex)
                            try dispatch(
                                pipeline: pipeline,
                                threadgroupsPerGrid: MTLSize(width: threadgroupsPerGrid),
                                threadsPerThreadgroup: MTLSize(width: pipeline.maxTotalThreadsPerThreadgroup)
                            )
                        }
                    }
                }
            }
        }

        timeit("CPU") {
            entries.sort()
        }

        // The element type must be spelled out; nothing else in this statement fixes it.
        let result = Array<UInt32>(buffer)
        assert(entries == result)
    }
}
69 |
--------------------------------------------------------------------------------
/Sources/Examples/Examples/Broken/CountingSort/CountingSort.metal:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 |
5 | using namespace metal;
6 |
namespace CountingSort16 {
    // MARK: -

    // Builds a histogram over a 2D grid: x selects the bucket, y selects the input element.
    // A thread increments its bucket when the element's value equals the bucket index.
    // The bucket is kept as `uint` (a `uchar` would truncate buckets above 255) and the input
    // is read as `ushort` so 16-bit keys at or above 0x8000 are not sign-extended.
    [[kernel]]
    void histogram(
        uint2 thread_position_in_grid [[thread_position_in_grid]],
        device ushort *input [[buffer(0)]],
        device atomic_uint *histogram [[buffer(1)]],
        constant uint &histogramCount [[buffer(2)]]
    )
    {
        const uint bucket = thread_position_in_grid.x;
        if (bucket >= histogramCount) {
            return;
        }
        const uint index = thread_position_in_grid.y;
        if (bucket == uint(input[index])) {
            atomic_fetch_add_explicit(&histogram[bucket], 1, memory_order_relaxed);
        }
    }

    // MARK: -

    // In-place inclusive prefix sum over `histogramCount` entries, done serially by each thread
    // that runs it.
    [[kernel]]
    void prefix_sum_inclusive(
        device uint *histogram [[buffer(0)]],
        constant uint &histogramCount [[buffer(1)]]
    )
    {
        for (uint index = 1; index < histogramCount; index++) {
            histogram[index] += histogram[index - 1];
        }
    }

    // In-place exclusive prefix sum over `histogramCount` entries, done serially by each thread
    // that runs it.
    [[kernel]]
    void prefix_sum_exclusive(
        device uint *histogram [[buffer(0)]],
        constant uint &histogramCount [[buffer(1)]]
    )
    {
        uint sum = 0;
        for (uint i = 0; i != histogramCount; i++) {
            uint t = histogram[i];
            histogram[i] = sum;
            sum += t;
        }
    }

    // MARK: -

    // Scatters each input element to the next free slot of its bucket; the slot is claimed by
    // atomically incrementing the bucket's entry in `histogram`.
    [[kernel]]
    void shuffle(
        uint thread_position_in_grid [[thread_position_in_grid]],
        device atomic_uint *histogram [[buffer(0)]],
        device ushort *input [[buffer(1)]],
        device ushort *output [[buffer(2)]],
        constant uint &count [[buffer(3)]]
    )
    {
        const uint index = thread_position_in_grid;
        // Threads past the end of the input must not read or write anything.
        if (index >= count) {
            return;
        }
        const ushort value = input[index];
        const uint slot = atomic_fetch_add_explicit(&histogram[value], 1, memory_order_relaxed);
        output[slot] = value;
    }

}
77 |
--------------------------------------------------------------------------------
/Sources/Examples/Examples/Broken/CountingSort/CountingSort.swift:
--------------------------------------------------------------------------------
1 | //import AppKit
2 | //import Compute
3 | //import CoreGraphics
4 | //import Foundation
5 | //import Metal
6 | //
7 | //struct CountingSort {
8 | // let device = MTLCreateSystemDefaultDevice()!
9 | //
10 | // func main() throws {
11 | // let capture = false
12 | //
13 | // let maxValue: UInt16 = 65535
14 | //
15 | // let values = (0..<1).map { _ in UInt16.random(in: 0 ... maxValue - 1) }
16 | // // let values: [UInt16] = [3, 1, 4, 2, 5, 6, 7, 8, 9, 0]
17 | // // print("Values", values)
18 | // let expectedResult = timeit("Foundation sort") { values.sorted() }
19 | // // print("Expected", expectedResult)
20 | // let cpuHistograms = Array(values.histogram().map(UInt32.init)[..
2 | #include
3 | #include
4 |
5 | using namespace metal;
6 |
// Returns the 8 bits of `value` that start at bit position `shift`.
inline uchar key(uint value, uint shift) {
    const uint shifted = value >> shift;
    return shifted & 0xFF;
}
10 |
11 | // MARK: -
12 |
// 2D grid: x selects the bucket, y selects the input element. A thread increments its bucket
// when the element's key at `shift` equals the bucket.
[[kernel]]
void histogram(
    uint2 thread_position_in_grid [[thread_position_in_grid]],
    device uint *input [[buffer(0)]],
    constant uint &shift [[buffer(2)]],
    device atomic_uint *histogram [[buffer(3)]]
)
{
    const uchar digit = thread_position_in_grid.x;
    const uint element = thread_position_in_grid.y;
    if (key(input[element], shift) != digit) {
        return;
    }
    atomic_fetch_add_explicit(&histogram[digit], 1, memory_order_relaxed);
}
27 |
28 | // MARK: -
29 |
// Replaces each of the 256 histogram entries with the sum of itself and all earlier entries.
[[kernel]]
void prefix_sum_inclusive(
    device uint *histogram [[buffer(0)]]
)
{
    uint running = histogram[0];
    for (int bucket = 1; bucket < 256; ++bucket) {
        running += histogram[bucket];
        histogram[bucket] = running;
    }
}
39 |
// Replaces each of the 256 histogram entries with the sum of all entries before it.
[[kernel]]
void prefix_sum_exclusive(
    device uint *histogram [[buffer(0)]]
)
{
    uint running = 0;
    for (int bucket = 0; bucket < 256; ++bucket) {
        const uint occurrences = histogram[bucket];
        histogram[bucket] = running;
        running += occurrences;
    }
}
52 |
53 | // MARK: -
54 |
// Each thread owns the bucket given by its x position. It walks the input from last to first
// and, for every element whose key matches, decrements the bucket's histogram entry and
// stores the element at the decremented position.
[[kernel]]
void shuffle(
    uint2 thread_position_in_grid [[thread_position_in_grid]],
    device atomic_uint *histogram [[buffer(0)]],
    device uint *input [[buffer(1)]],
    device uint *output [[buffer(2)]],
    constant uint &count [[buffer(3)]],
    constant uint &shift [[buffer(4)]]
)
{
    const uchar digit = thread_position_in_grid.x;
    for (int index = count - 1; index >= 0; --index) {
        const uint element = input[index];
        if (key(element, shift) != digit) {
            continue;
        }
        const auto previous = atomic_fetch_add_explicit(&histogram[digit], -1, memory_order_relaxed);
        output[previous - 1] = element;
    }
}
73 |
// Each thread owns the bucket given by its x position. It walks the input from first to last
// and stores every element whose key matches at the bucket's current histogram entry, which it
// then increments.
[[kernel]]
void shuffle2(
    uint2 thread_position_in_grid [[thread_position_in_grid]],
    device atomic_uint *histogram [[buffer(0)]],
    device uint *input [[buffer(1)]],
    device uint *output [[buffer(2)]],
    constant uint &count [[buffer(3)]],
    constant uint &shift [[buffer(4)]]
)
{
    const uchar digit = thread_position_in_grid.x;
    for (uint index = 0; index != count; index++) {
        const uint element = input[index];
        if (key(element, shift) != digit) {
            continue;
        }
        const auto slot = atomic_fetch_add_explicit(&histogram[digit], 1, memory_order_relaxed);
        output[slot] = element;
    }
}
92 |
// One thread per input element: the thread claims a slot by atomically incrementing the
// histogram entry of the element's key and writes the element there. Threads at or beyond
// `count` do nothing.
[[kernel]]
void shuffle3(
    uint thread_position_in_grid [[thread_position_in_grid]],
    uint threads_per_grid [[threads_per_grid]],
    device atomic_uint *histogram [[buffer(0)]],
    device uint *input [[buffer(1)]],
    device uint *output [[buffer(2)]],
    constant uint &count [[buffer(3)]],
    constant uint &shift [[buffer(4)]]
)
{
    const uint element_index = thread_position_in_grid;
    if (element_index >= count) {
        return;
    }
    const uint element = input[element_index];
    const uchar digit = key(element, shift);
    const auto slot = atomic_fetch_add_explicit(&histogram[digit], 1, memory_order_relaxed);
    output[slot] = element;
}
112 |
--------------------------------------------------------------------------------
/Sources/Examples/Examples/Broken/RadixSort/RadixSort.swift:
--------------------------------------------------------------------------------
1 | //import AppKit
2 | //import Compute
3 | //import CoreGraphics
4 | //import Foundation
5 | //import Metal
6 | //
7 | //struct RadixSort {
8 | // let device = MTLCreateSystemDefaultDevice()!
9 | //
10 | // func main() throws {
11 | // let capture = false
12 | //
13 | // let values = (0..<1_500_000).map { _ in UInt32.random(in: 0 ... 1000) }
14 | // let expectedResult = timeit("Foundation sort") { values.sorted() }
15 | // let cpuSorted = timeit("CPU Radix Sort") { radixSort(values: values) }
16 | // print("CPU Sorted?", expectedResult == cpuSorted)
17 | //
18 | // let compute = try Compute(device: device, logState: capture ? nil : try device.makeDefaultLogState())
19 | // let library = ShaderLibrary.bundle(.module, name: "debug")
20 | //
21 | // var input = try device.makeBuffer(bytesOf: values, options: [])
22 | // var output = try device.makeBuffer(bytesOf: Array(repeating: UInt32.zero, count: values.count), options: [])
23 | // let histogram = try device.makeBuffer(bytesOf: Array(repeating: UInt32.zero, count: 256), options: [])
24 | //
25 | // var histogramPass = try compute.makePass(function: library.histogram)
26 | // var prefixSumPass = try compute.makePass(function: library.prefix_sum_exclusive)
27 | // var shufflePass = try compute.makePass(function: library.shuffle2)
28 | //
29 | // try timeit("GPU Radix Sort") {
30 | // try device.capture(enabled: capture) {
31 | // for phase in 0..<4 {
32 | // let shift = UInt32(phase * 8)
33 | //
34 | // histogramPass.arguments.histogram = .buffer(histogram)
35 | // histogramPass.arguments.shift = .int(shift)
36 | // histogramPass.arguments.input = .buffer(input)
37 | //
38 | // try compute.dispatch(label: "Histogram") { dispatch in
39 | // let maxTotalThreadsPerThreadgroup = histogramPass.maxTotalThreadsPerThreadgroup
40 | // try dispatch(pass: histogramPass, threads: MTLSize(width: 256, height: values.count), threadsPerThreadgroup: MTLSize(width: maxTotalThreadsPerThreadgroup, height: 1))
41 | // }
42 | //
43 | // // Prefix sum.
44 | // prefixSumPass.arguments.histogram = .buffer(histogram)
45 | // try compute.dispatch(label: "PrefixSum") { dispatch in
46 | // try dispatch(pass: prefixSumPass, threads: MTLSize(width: 1), threadsPerThreadgroup: MTLSize(width: 1))
47 | // }
48 | //
49 | // // Shuffle.
50 | // shufflePass.arguments.histogram = .buffer(histogram)
51 | // shufflePass.arguments.input = .buffer(input)
52 | // shufflePass.arguments.output = .buffer(output)
53 | // shufflePass.arguments.count = .int(values.count)
54 | // shufflePass.arguments.shift = .int(shift)
55 | // try compute.dispatch(label: "Shuffle") { dispatch in
56 | // let maxTotalThreadsPerThreadgroup = shufflePass.maxTotalThreadsPerThreadgroup
57 | //
58 | // try dispatch(pass: shufflePass, threads: MTLSize(width: 256), threadsPerThreadgroup: MTLSize(width: maxTotalThreadsPerThreadgroup))
59 | // }
60 | //
61 | // swap(&input, &output)
62 | //
63 | // histogram.clear()
64 | // }
65 | // }
66 | // }
67 | // print("GPU Sorted?", expectedResult == input.as(UInt32.self))
68 | // }
69 | //}
70 |
--------------------------------------------------------------------------------
/Sources/Examples/Examples/Broken/RadixSort/RadixSortCPU.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 |
postfix operator ++

extension Int {
    /// Post-increment: adds one to `rhs` in place and returns the value it held beforehand.
    static postfix func ++(rhs: inout Int) -> Int {
        defer { rhs += 1 }
        return rhs
    }
}
12 |
/// Scatters `input` into a new array according to the 8-bit digit found at bit offset `shift`.
///
/// The input is walked from last to first; for each element the digit's entry in a private
/// copy of `histogram` is decremented and the element is stored at the decremented position.
/// - Parameters:
///   - input: The values to scatter.
///   - shift: Bit offset of the digit used as the bucket index.
///   - histogram: Starting position per digit, indexed by digit value (0...255).
/// - Returns: An array of the same length as `input` holding the scattered values.
func shuffle(input: [UInt32], shift: Int, histogram: [UInt32]) -> [UInt32] {
    var cursors: [Int] = histogram.map { Int($0) }
    var scattered = [UInt32](repeating: 0, count: input.count)
    for element in input.reversed() {
        let digit = (Int(element) >> shift) & 0xFF
        cursors[digit] -= 1
        scattered[cursors[digit]] = element
    }
    return scattered
}
26 |
/// Counting sort of `input` by the 8-bit digit at bit offset `shift`, written into `output`.
///
/// Builds a 256-entry histogram of the digits, turns it into an inclusive prefix sum, then
/// walks the input from last to first, pre-decrementing each digit's entry to find its slot.
/// - Parameters:
///   - input: The values to sort.
///   - shift: Bit offset of the digit used as the sort key.
///   - output: Destination array; it is indexed up to `input.count - 1`.
func countingSort(input: [UInt32], shift: Int, output: inout [UInt32]) {
    let digits = input.map { (Int($0) >> shift) & 0xFF }

    // Histogram
    var offsets = [Int](repeating: 0, count: 256)
    for digit in digits {
        offsets[digit] += 1
    }

    // Prefix Sum
    for bucket in 1 ..< offsets.count {
        offsets[bucket] += offsets[bucket - 1]
    }

    // Shuffle
    for index in input.indices.reversed() {
        let digit = digits[index]
        offsets[digit] -= 1
        output[offsets[digit]] = input[index]
    }
}
46 |
// From: "A High-Performance Implementation of Counting Sort on CUDA GPU"
/// Counting sort of `input` by the 8-bit digit at bit offset `shift`, written into `output`.
///
/// Builds a 256-entry histogram of the digits, turns it into an exclusive prefix sum, then
/// walks the input from first to last, storing each element at its digit's current offset and
/// advancing that offset by one.
/// - Parameters:
///   - input: The values to sort.
///   - shift: Bit offset of the digit used as the sort key.
///   - output: Destination array; it is indexed up to `input.count - 1`.
func countingSort2(input: [UInt32], shift: Int, output: inout [UInt32]) {
    let digit: (UInt32) -> Int = { (Int($0) >> shift) & 0xFF }

    // Histogram
    var offsets = [Int](repeating: 0, count: 256)
    for element in input {
        offsets[digit(element)] += 1
    }

    // Prefix Sum (exclusive)
    var running = 0
    for bucket in offsets.indices {
        let occurrences = offsets[bucket]
        offsets[bucket] = running
        running += occurrences
    }

    // Shuffle
    for element in input {
        let bucket = digit(element)
        output[offsets[bucket]] = element
        offsets[bucket] += 1
    }
}
71 |
/// Sorts `values` with four passes of `countingSort2`, one per byte, lowest byte first.
/// - Parameter values: The values to sort.
/// - Returns: The contents of the working array after the fourth pass.
func radixSort(values: [UInt32]) -> [UInt32] {
    var source = values
    var scratch = [UInt32](repeating: 0, count: values.count)
    for shift in stride(from: 0, to: 32, by: 8) {
        countingSort2(input: source, shift: shift, output: &scratch)
        // The freshly written array becomes the input of the next pass.
        swap(&source, &scratch)
    }
    return source
}
82 |
--------------------------------------------------------------------------------
/Sources/Examples/Examples/BufferFill.swift:
--------------------------------------------------------------------------------
1 | import Compute
2 | import Metal
3 | import os
4 |
/// Repeatedly dispatches a kernel that writes the grid size into every element it covers,
/// halving the grid each time, and asserts that the buffer matches a CPU-side model.
enum BufferFill: Demo {
    /// Metal source for a kernel in which each thread stores `threads_per_grid` at its own
    /// grid position in `data`.
    static let source = #"""
        #include <metal_stdlib>
        #include <metal_logging>

        using namespace metal;

        uint thread_position_in_grid [[thread_position_in_grid]];
        uint threads_per_grid [[threads_per_grid]];

        kernel void buffer_fill(
            device uint *data [[buffer(0)]]
        ) {
            //if (thread_position_in_grid == 0) {
            //    os_log_default.log("threads_per_grid: %d", threads_per_grid);
            //}
            data[thread_position_in_grid] = threads_per_grid;
        }
        """#

    static func main() async throws {
        let device = MTLCreateSystemDefaultDevice()!
        let count = 2 ** 24
        // The kernel writes `uint` elements, so size and read the buffer as UInt32.
        let data = device.makeBuffer(length: count * MemoryLayout<UInt32>.stride, options: .storageModeShared)!
        let compute = try Compute(device: device, logger: Logger())
        let library = ShaderLibrary.source(source, enableLogging: true)
        var bufferFill = try compute.makePipeline(function: library.buffer_fill)
        bufferFill.arguments.data = .buffer(data)
        var n = count
        var values = Array<UInt32>(repeating: 0, count: count)
        while n > 0 {
            try compute.run(pipeline: bufferFill, width: n)
            // Mirror the dispatch on the CPU: the first n elements now hold the grid size n.
            values.replaceSubrange(0 ..< n, with: repeatElement(UInt32(n), count: n))
            n >>= 1
        }
        assert(Array<UInt32>(data) == values)
    }
}
43 |
--------------------------------------------------------------------------------
/Sources/Examples/Examples/Checkerboard.swift:
--------------------------------------------------------------------------------
1 | import Compute
2 | import CoreImage
3 | import Metal
4 | import MetalKit
5 | import os
6 | import SwiftUI
7 | import UniformTypeIdentifiers
8 |
/// Renders a two-colour checkerboard into a 1024×1024 texture on the GPU and exports it to
/// `/tmp/checkerboard.png`.
enum Checkerboard: Demo {
    /// Metal source for the checkerboard kernel. The output texture is written with `float4`
    /// colours, so it is declared as a write-access float texture.
    static let source = #"""
        #include <metal_stdlib>

        using namespace metal;

        uint2 gid [[thread_position_in_grid]];

        kernel void checkerboard(
            texture2d<float, access::write> outputTexture [[texture(0)]],
            constant float4 &color1 [[buffer(0)]],
            constant float4 &color2 [[buffer(1)]],
            constant uint2 &cellSize [[buffer(2)]]
        ) {
            // Get the size of the texture
            uint width = outputTexture.get_width();
            uint height = outputTexture.get_height();

            // Check if the current thread is within the texture bounds
            if (gid.x >= width || gid.y >= height) {
                return;
            }

            // Determine which square this pixel belongs to
            uint squareX = gid.x / cellSize.x;
            uint squareY = gid.y / cellSize.y;

            // Choose color based on whether the sum of squareX and squareY is even or odd
            float4 color = ((squareX + squareY) % 2 == 0) ? color1 : color2;

            // Write the color to the output texture
            outputTexture.write(color, gid);
        }
        """#

    static func main() async throws {
        let device = MTLCreateSystemDefaultDevice()!
        let logger = Logger()
        let compute = try Compute(device: device, logger: logger)
        let library = ShaderLibrary.source(source)

        var checkerboard = try compute.makePipeline(function: library.checkerboard)

        let outputTextureDescriptor = MTLTextureDescriptor()
        outputTextureDescriptor.width = 1_024
        outputTextureDescriptor.height = 1_024
        outputTextureDescriptor.pixelFormat = .bgra8Unorm
        outputTextureDescriptor.usage = .shaderWrite
        guard let outputTexture = device.makeTexture(descriptor: outputTextureDescriptor) else {
            throw ComputeError.resourceCreationFailure
        }

        checkerboard.arguments.outputTexture = .texture(outputTexture)
        checkerboard.arguments.color1 = try .color(.gray)
        checkerboard.arguments.color2 = try .color(.mint)
        checkerboard.arguments.cellSize = .vector(SIMD2<UInt32>(64, 64))
        // One thread per texel.
        try compute.run(pipeline: checkerboard, width: outputTexture.width, height: outputTexture.height)

        let url = URL(filePath: "/tmp/checkerboard.png")
        try outputTexture.export(to: url)
        #if os(macOS)
        NSWorkspace.shared.selectFile(url.path, inFileViewerRootedAtPath: url.deletingLastPathComponent().path)
        #endif

        // ksdiff /tmp/inverted.png ~/Projects/Compute/Sources/Examples/Resources/Media.xcassets/baboon.imageset/baboon.png
    }
}
76 |
--------------------------------------------------------------------------------
/Sources/Examples/Examples/CounterDemo.swift:
--------------------------------------------------------------------------------
1 | import Compute
2 | import Metal
3 | import os
4 |
/// Times a kernel in which every thread adds the values `0 ..< count` to its own float counter.
enum CounterDemo: Demo {
    /// Metal source for the counter kernel. The `step` argument is declared but not read by
    /// the kernel body.
    static let source = #"""
        #include <metal_stdlib>
        #include <metal_logging>

        using namespace metal;

        uint thread_position_in_grid [[thread_position_in_grid]];

        kernel void main(
            device float *counters [[buffer(0)]],
            constant uint &count [[buffer(1)]],
            constant uint &step [[buffer(2)]]
        ) {
            for(uint n = 0; n != count; ++n) {
                counters[thread_position_in_grid] += float(n);
            }
        }
        """#

    static func main() async throws {
        let device = MTLCreateSystemDefaultDevice()!
        let logger = Logger()
        let compute = try Compute(device: device, logger: logger)
        let library = ShaderLibrary.source(source, enableLogging: true)
        let numberOfCounters = 100_240
        // The kernel indexes `counters` as `float`, one element per thread.
        let counters = device.makeBuffer(length: MemoryLayout<Float>.size * numberOfCounters, options: [])!
        var pipeline = try compute.makePipeline(function: library.main)
        pipeline.arguments.counters = .buffer(counters)
        pipeline.arguments.count = .int(100_000_000)
        pipeline.arguments.step = .int(1)
        try timeit {
            try compute.run(pipeline: pipeline, width: numberOfCounters)
        }
        // print(Array(counters))

    }
}
45 |
--------------------------------------------------------------------------------
/Sources/Examples/Examples/GameOfLife/GameOfLife.metal:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 |
4 | using namespace metal;
5 |
6 | constant bool wrap [[function_constant(0)]];
7 |
8 | static constant int2 positions[] = {
9 | int2(-1, -1),
10 | int2( 0, -1),
11 | int2(+1, -1),
12 | int2(-1, 0),
13 | int2(+1, 0),
14 | int2(-1, +1),
15 | int2( 0, +1),
16 | int2(+1, +1),
17 | };
18 |
// Returns whether a cell is alive in the next generation: a live cell stays alive with 2 or 3
// live neighbours, a dead cell comes alive with exactly 3, and every other case is dead.
bool rules(uint count, bool alive) {
    if (alive) {
        return count == 2 || count == 3;
    }
    return count == 3;
}
33 |
// Shared body of the Game of Life kernels. Counts the live neighbours of the texel at `gid`
// in `inputTexture` (wrapping around the edges when the `wrap` function constant is set,
// skipping out-of-range neighbours otherwise) and writes the result of `rules` to
// `outputTexture`. `T` is the texel component type of both textures; `V` is the type of the
// `clear`/`set` values, which the body does not currently read.
template <typename T, typename V> void gameOfLifeGENERIC(
    uint2 gid,
    texture2d<T, access::read> inputTexture,
    texture2d<T, access::write> outputTexture,
    V clear,
    V set
)
{
    const int2 sgid = int2(gid);
    const int2 inputTextureSize = int2(inputTexture.get_width(), inputTexture.get_height());
    // Threads outside the texture have nothing to do.
    if (sgid.x >= inputTextureSize.x || sgid.y >= inputTextureSize.y) {
        return;
    }
    uint count = 0;
    for (int N = 0; N != 8; ++N) {
        int2 position = sgid + positions[N];
        if (!wrap) {
            if (position.x < 0 || position.x >= inputTextureSize.x) {
                continue;
            }
            else if (position.y < 0 || position.y >= inputTextureSize.y) {
                continue;
            }
        }
        else {
            // Adding the size before the modulo keeps the -1 offsets non-negative.
            position.x = (position.x + inputTextureSize.x) % inputTextureSize.x;
            position.y = (position.y + inputTextureSize.y) % inputTextureSize.y;
        }
        count += inputTexture.read(uint2(position)).r ? 1 : 0;
    }
    const bool alive = inputTexture.read(gid).r != 0;
    outputTexture.write(rules(count, alive), gid);
}
67 |
68 |
// Game of Life step for unsigned-integer textures; forwards to gameOfLifeGENERIC.
[[kernel]]
void gameOfLife_uint(
    uint2 gid [[thread_position_in_grid]],
    texture2d<uint, access::read> inputTexture [[texture(0)]],
    texture2d<uint, access::write> outputTexture [[texture(1)]])
{
    gameOfLifeGENERIC(gid, inputTexture, outputTexture, 0, 1);
}
77 |
// Game of Life step for float textures; forwards to gameOfLifeGENERIC.
[[kernel]]
void gameOfLife_float4(
    uint2 gid [[thread_position_in_grid]],
    texture2d<float, access::read> inputTexture [[texture(0)]],
    texture2d<float, access::write> outputTexture [[texture(1)]])
{
    gameOfLifeGENERIC(gid, inputTexture, outputTexture, float4(0, 0, 0, 0), float4(1, 1, 1, 1));
}
86 |
--------------------------------------------------------------------------------
/Sources/Examples/Examples/GameOfLife/GameOfLife.swift:
--------------------------------------------------------------------------------
1 | #if os(macOS)
2 | import AVFoundation
3 | import Compute
4 | import Foundation
5 | import Metal
6 | import os
7 |
8 | enum GameOfLife: Demo {
9 |
10 | static func main() async throws {
11 | try await run()
12 | }
13 |
14 | static func run(density: Double = 0.5, width: Int = 256, height: Int = 256, frames: Int = 1_200, framesPerSecond: Int = 60) async throws {
15 | let logger: Logger? = Logger()
16 |
17 | // Calculate total number of pixels
18 | let pixelCount = width * height
19 |
20 | // Get the default Metal device
21 | let device = MTLCreateSystemDefaultDevice()!
22 |
23 | // Create and initialize buffer A with random live cells
24 | logger?.log("Creating buffers")
25 | let bufferA = device.makeBuffer(length: pixelCount * MemoryLayout.size, options: [])!
26 | bufferA.contents().withMemoryRebound(to: UInt32.self, capacity: pixelCount) { buffer in
27 | for n in 0...size)!
39 | textureA.label = "texture-a"
40 |
41 | // Create buffer B and texture B
42 | let bufferB = device.makeBuffer(length: pixelCount * MemoryLayout.size, options: [])!
43 | let textureB = bufferB.makeTexture(descriptor: textureDescriptor, offset: 0, bytesPerRow: width * MemoryLayout.size)!
44 | textureB.label = "texture-b"
45 |
46 | logger?.log("Loading shaders")
47 | // Initialize Compute and ShaderLibrary
48 | let compute = try Compute(device: device)
49 | let library = ShaderLibrary.bundle(.module, name: "debug")
50 |
51 | // Initialize timing variables
52 | var totalComputeTime: UInt64 = 0
53 | var totalEncodeTime: UInt64 = 0
54 |
55 | // Create compute pipeline
56 | var pipeline = try compute.makePipeline(function: library.gameOfLife_float4, constants: ["wrap": .bool(true)])
57 |
58 | // Set up video writer
59 | let url = FileManager.default.homeDirectoryForCurrentUser.appendingPathComponent("Desktop/GameOfLife.mov")
60 | let movieWriter = try TextureToVideoWriter(outputURL: url, size: CGSize(width: width, height: height))
61 | movieWriter.start()
62 |
63 | logger?.log("Encoding")
64 |
65 | let time = CMTimeMakeWithSeconds(0, preferredTimescale: 600)
66 | movieWriter.writeFrame(texture: textureA, at: time)
67 |
68 | // Main simulation loop
69 | for frame in 0..
8 | #include
9 |
10 | using namespace metal;
11 |
12 | kernel void hello_world() {
13 | os_log_default.log("Hello world (from Metal!)");
14 | }
15 | """#
16 |
17 | static func main() async throws { // Logs a greeting from Swift, then compiles `source` and runs its `hello_world` kernel once.
18 | let device = MTLCreateSystemDefaultDevice()! // crashes if no Metal device is available
19 | let logger = Logger()
20 | logger.log("Hello world (from Swift!)")
21 | let compute = try Compute(device: device, logger: logger) // the same logger is handed to Compute
22 | let library = ShaderLibrary.source(source, enableLogging: true) // logging is enabled because the kernel calls os_log_default.log
23 | let helloWorld = try compute.makePipeline(function: library.hello_world)
24 | try compute.run(pipeline: helloWorld, width: 1) // width 1: presumably a single thread, so a single log line - confirm against Compute.run
25 | }
26 | }
27 |
--------------------------------------------------------------------------------
/Sources/Examples/Examples/Histogram.swift:
--------------------------------------------------------------------------------
1 | import Compute
2 | import Metal
3 | import os
4 |
5 | enum Histogram: Demo { // Demo: GPU histogram of random UInt32 values, accumulated per threadgroup and checked against a CPU reference.
6 | static let source = #"""
7 | #include
8 | #include
9 |
10 | using namespace metal;
11 |
12 | uint thread_position_in_grid [[thread_position_in_grid]];
13 | uint thread_position_in_threadgroup [[thread_position_in_threadgroup]];
14 |
15 | kernel void histogram(
16 | device uint *input [[buffer(0)]],
17 | device atomic_uint *buckets [[buffer(1)]],
18 | constant uint &bucketCount [[buffer(2)]],
19 | threadgroup atomic_uint *scratch [[threadgroup(0)]]
20 | ) {
   | // Threadgroup memory is not zero-initialised: clear the per-threadgroup counters
   | // before any thread accumulates into them.
   | if (thread_position_in_threadgroup == 0) {
   | for (uint n = 0; n != bucketCount; ++n) {
   | atomic_store_explicit(&scratch[n], 0, memory_order_relaxed);
   | }
   | }
   | threadgroup_barrier(mem_flags::mem_threadgroup);
   |
21 | const uint value = input[thread_position_in_grid];
   | // Check the range before indexing so an out-of-range value never touches scratch.
22 | if (value < bucketCount) { atomic_fetch_add_explicit(&scratch[value], 1, memory_order_relaxed); }
23 |
24 | threadgroup_barrier(mem_flags::mem_threadgroup);
   | // One thread per threadgroup folds the local counters into the global buckets.
25 | if (thread_position_in_threadgroup == 0) {
26 | for (uint n = 0; n != bucketCount; ++n) {
27 | const uint value = atomic_load_explicit(&scratch[n], memory_order_relaxed);
28 | atomic_fetch_add_explicit(&buckets[n], value, memory_order_relaxed);
29 | }
30 | }
31 | }
32 | """#
33 |
34 | static func main() async throws { // Builds the input, runs the kernel, then prints the GPU result, the CPU result and whether they match.
35 | let device = MTLCreateSystemDefaultDevice()!
36 | let values: [UInt32] = timeit("Generate random input") {
37 | (0..<1_000_000).map { n in UInt32.random(in: 0..<20) }
38 | }
39 | let input = try device.makeBuffer(bytesOf: values, options: [])
40 | let bucketCount = 32
41 | let buckets = device.makeBuffer(length: bucketCount * MemoryLayout.stride, options: [])!
42 | let compute = try Compute(device: device, logger: Logger())
43 | let library = ShaderLibrary.source(source, enableLogging: true)
44 | var histogram = try compute.makePipeline(function: library.histogram)
45 | histogram.arguments.input = .buffer(input)
46 | histogram.arguments.buckets = .buffer(buckets)
47 | histogram.arguments.bucketCount = .int(UInt32(bucketCount))
48 | histogram.arguments.scratch = .threadgroupMemoryLength(bucketCount * MemoryLayout.stride) // TODO: Align 16.
49 | try timeit("GPU") {
50 | try compute.run(pipeline: histogram, threads: [values.count], threadsPerThreadgroup: [histogram.maxTotalThreadsPerThreadgroup])
51 | }
52 | let result = Array(buckets)
53 | let expectedResult = timeit("CPU") {
54 | values.reduce(into: Array(repeating: 0, count: bucketCount)) { partialResult, value in
55 | let index = Int(value)
56 | partialResult[index] += 1
57 | }
58 | }
59 |
60 | print(result)
61 | print(expectedResult)
62 | print(result == expectedResult)
63 | }
64 | }
65 |
--------------------------------------------------------------------------------
/Sources/Examples/Examples/ImageInvert.swift:
--------------------------------------------------------------------------------
1 | import Compute
2 | import CoreImage
3 | import Metal
4 | import MetalKit
5 | import os
6 | import UniformTypeIdentifiers
7 |
8 | enum ImageInvert: Demo { // Demo: inverts the RGB channels of the bundled "baboon" image on the GPU and writes the result to /tmp/inverted.png.
9 | static let source = #"""
10 | #include
11 |
12 | using namespace metal;
13 |
14 | uint2 gid [[thread_position_in_grid]];
15 |
16 | kernel void invertImage(
17 | texture2d inputTexture [[texture(0)]],
18 | texture2d outputTexture [[texture(1)]]
19 | ) {
20 | float4 pixel = inputTexture.read(gid);
21 | pixel.rgb = 1.0 - pixel.rgb;
22 | outputTexture.write(pixel, gid);
23 | }
24 | """# // The kernel reads and writes at `gid` without a bounds check, so the dispatch must not exceed the texture size.
25 |
26 | static func main() async throws { // Loads the input texture, creates a same-sized output texture, runs the kernel and exports the output.
27 | let device = MTLCreateSystemDefaultDevice()!
28 | let logger = Logger()
29 | let compute = try Compute(device: device, logger: logger)
30 | let library = ShaderLibrary.source(source)
31 | var invertImage = try compute.makePipeline(function: library.invertImage, constants: ["isLinear": .bool(true)]) // NOTE(review): the shader source above declares no `isLinear` function constant - confirm this entry is needed
32 |
33 | let textureLoader = MTKTextureLoader(device: device)
34 | let inputTexture = try await textureLoader.newTexture(name: "baboon", scaleFactor: 1, bundle: .module, options: [.SRGB: false]) // loaded with the SRGB option disabled
35 |
36 | let outputTextureDescriptor = MTLTextureDescriptor()
37 | outputTextureDescriptor.width = inputTexture.width
38 | outputTextureDescriptor.height = inputTexture.height
39 | outputTextureDescriptor.pixelFormat = .bgra8Unorm
40 | outputTextureDescriptor.usage = .shaderWrite // the kernel only writes to this texture
41 | guard let outputTexture = device.makeTexture(descriptor: outputTextureDescriptor) else {
42 | throw ComputeError.resourceCreationFailure
43 | }
44 |
45 | invertImage.arguments.inputTexture = .texture(inputTexture)
46 | invertImage.arguments.outputTexture = .texture(outputTexture)
47 | try compute.run(pipeline: invertImage, width: inputTexture.width, height: inputTexture.height) // one grid position per input pixel
48 |
49 | try outputTexture.export(to: URL(filePath: "/tmp/inverted.png"))
50 |
51 | // ksdiff /tmp/inverted.png ~/Projects/Compute/Sources/Examples/Resources/Media.xcassets/baboon.imageset/baboon.png
52 | }
53 | }
54 |
--------------------------------------------------------------------------------
/Sources/Examples/Examples/IsSorted.swift:
--------------------------------------------------------------------------------
1 | import Compute
2 | import Metal
3 | import os
4 |
5 | enum IsSorted: Demo { // Demo: two kernels that clear a flag (initially 1) when any adjacent pair of floats is out of order.
6 | static let source = #"""
7 | #include
8 | #include
9 |
10 | using namespace metal;
11 |
12 | uint thread_position_in_threadgroup [[thread_position_in_threadgroup]];
13 | uint threadgroup_position_in_grid [[threadgroup_position_in_grid]];
14 | uint thread_position_in_grid [[thread_position_in_grid]];
15 | uint threads_per_threadgroup [[threads_per_threadgroup]];
16 |
17 | kernel void is_sorted(
18 | device const float* input [[buffer(0)]],
19 | constant uint& count [[buffer(1)]],
20 | device atomic_uint* isSorted [[buffer(2)]]
21 | ) {
22 | //os_log_default.log("thread_position_in_grid: %d, threads_per_grid: %f", thread_position_in_grid, input[thread_position_in_grid]);
   | // Written as `+ 1 < count` so that a count of 0 cannot wrap around the way `count - 1` does.
23 | if (thread_position_in_grid + 1 < count && input[thread_position_in_grid] > input[thread_position_in_grid + 1]) {
24 | //os_log_default.log("IS NOT SORTED %d %d", thread_position_in_grid, thread_position_in_grid + 1);
25 | atomic_store_explicit(isSorted, 0, memory_order_relaxed);
26 | }
27 | }
28 |
29 |
30 | kernel void is_sorted_complex(
31 | device const float* input [[buffer(0)]],
32 | constant uint& count [[buffer(1)]],
33 | device atomic_uint* isSorted [[buffer(2)]]
34 | ) {
   | // Fewer than two elements: nothing to compare, and `count - 1` below must not wrap around.
   | if (count < 2) {
   | return;
   | }
35 | uint groupStart = threadgroup_position_in_grid * threads_per_threadgroup;
36 |
37 | // Start one before if not the first group
38 | uint start = (groupStart == 0) ? groupStart + thread_position_in_threadgroup : groupStart - 1 + thread_position_in_threadgroup;
39 | uint end = metal::min(count - 1, groupStart + threads_per_threadgroup);
40 |
41 | for (uint i = start; i < end; i += threads_per_threadgroup) {
42 | if (input[i] > input[i + 1]) {
43 | atomic_store_explicit(isSorted, 0, metal::memory_order_relaxed);
44 | }
45 | }
46 | }
47 | """#
48 |
49 | static func main() async throws { // Runs both variants in turn.
50 | try await simple()
51 | try await complex()
52 | }
53 |
54 | static func simple() async throws { // One thread per element; prints the flag (0 is expected because of the swap below).
55 | let capture = false
56 |
57 | let device = MTLCreateSystemDefaultDevice()!
58 | try device.capture(enabled: capture) {
59 | let logger = Logger()
60 | let compute = try Compute(device: device, logger: logger, useLogState: !capture)
61 | let library = ShaderLibrary.source(source, enableLogging: !capture)
62 | var values = (0..<1_000_000).map { Float($0) }.sorted()
63 | //
64 | values.swapAt(1, 2)
65 | //
66 | let isSorted = try device.makeBuffer(bytesOf: UInt32(1), options: [.storageModeShared])
67 | isSorted.label = "isSorted"
68 |
69 | var pipeline = try compute.makePipeline(function: library.is_sorted)
70 | pipeline.arguments.input = .buffer(values, label: "input")
71 | pipeline.arguments.count = .int(values.count)
72 | pipeline.arguments.isSorted = .buffer(isSorted)
73 |
74 | try compute.run(pipeline: pipeline, threads: [values.count, 1, 1], threadsPerThreadgroup: [pipeline.maxTotalThreadsPerThreadgroup, 1, 1])
75 |
76 | print(isSorted.contentsBuffer(of: UInt32.self)[0])
77 | }
78 |
79 | }
80 |
81 | static func complex() async throws { // Threadgroup-based variant; prints the flag (0 is expected because of the swap below).
82 | let capture = false
83 |
84 | let device = MTLCreateSystemDefaultDevice()!
85 | try device.capture(enabled: capture) {
86 | let logger = Logger()
87 | let compute = try Compute(device: device, logger: logger, useLogState: !capture)
88 | let library = ShaderLibrary.source(source, enableLogging: !capture)
89 | var values = (0..<2_000).map { Float($0) }.sorted()
90 | //
91 | values.swapAt(1, 2)
92 | //
93 | let isSorted = try device.makeBuffer(bytesOf: UInt32(1), options: [.storageModeShared])
94 | isSorted.label = "isSortedComplex"
95 |
96 | var pipeline = try compute.makePipeline(function: library.is_sorted_complex)
97 | pipeline.arguments.input = .buffer(values, label: "input")
98 | pipeline.arguments.count = .int(values.count)
99 | pipeline.arguments.isSorted = .buffer(isSorted)
100 |
101 | let threadsPerThreadgroup = pipeline.maxTotalThreadsPerThreadgroup
    | // Enough full threadgroups to cover every element. This is a threadgroup count, so it is
    | // dispatched as threadgroups; passing it as `threads:` would launch only that many threads.
    | let threadgroups = (values.count + threadsPerThreadgroup - 1) / threadsPerThreadgroup
102 |
103 | try compute.run(pipeline: pipeline, threadgroupsPerGrid: MTLSize(width: threadgroups), threadsPerThreadgroup: MTLSize(width: threadsPerThreadgroup))
104 |
105 | print(isSorted.contentsBuffer(of: UInt32.self)[0])
106 | }
107 |
108 | }
109 | }
110 |
--------------------------------------------------------------------------------
/Sources/Examples/Examples/MaxParallel.swift:
--------------------------------------------------------------------------------
1 | import Compute
2 | import Metal
3 | import os
4 |
5 | enum MaxParallel: Demo { // Demo: estimates how many kernel invocations are in flight at once for increasing grid widths.
6 | static let source = #"""
7 | #include
8 | #include
9 |
10 | using namespace metal;
11 |
12 | uint thread_position_in_grid [[thread_position_in_grid]];
13 |
14 | kernel void kernel_main(
15 | device atomic_uint &count [[buffer(0)]],
16 | device atomic_uint &maximum [[buffer(1)]]
17 | ) {
18 | uint current = atomic_fetch_add_explicit(&count, 1, memory_order_relaxed);
   | // fetch_add returns the value before the increment, so the number of
   | // invocations in flight, including this one, is current + 1.
19 | atomic_fetch_max_explicit(&maximum, current + 1, memory_order_relaxed);
20 | atomic_fetch_add_explicit(&count, -1, memory_order_relaxed);
21 | }
22 | """#
23 |
24 | static func main() async throws { // Prints one CSV-style row (width, observed maximum, milliseconds) per grid width.
25 | let device = MTLCreateSystemDefaultDevice()!
26 | let compute = try Compute(device: device)
27 | let library = ShaderLibrary.source(source, enableLogging: true)
28 | let count = device.makeBuffer(length: MemoryLayout.size * 1, options: [])!
29 | let maximum = device.makeBuffer(length: MemoryLayout.size * 1, options: [])!
30 | var pipeline = try compute.makePipeline(function: library.kernel_main)
31 | pipeline.arguments.count = .buffer(count)
32 | pipeline.arguments.maximum = .buffer(maximum)
33 |
34 |
35 | print("thread width,\tmaximum,\ttime (ms)")
36 | for n in 0...31 {
   | // Widths 1, 3, 7, ... 2^32 - 1, computed with an integer shift instead of floating-point pow.
37 | let width = (1 << (n + 1)) - 1
   | // Start every row from zero so the printed maximum belongs to this width only.
   | maximum.contents().storeBytes(of: 0, as: UInt32.self)
38 | let time = try timed() {
39 | try compute.run(pipeline: pipeline, width: width)
40 | }
41 | print("\(width),\t\(Array(maximum)[0]),\t\(Double(time) / Double(1_000_000))")
42 | }
43 | }
44 | }
45 |
--------------------------------------------------------------------------------
/Sources/Examples/Examples/MaxValue.swift:
--------------------------------------------------------------------------------
1 | import Compute
2 | import CoreImage
3 | import Metal
4 | import MetalKit
5 | import os
6 | import UniformTypeIdentifiers
7 |
8 | enum MaxValue: Demo {
9 |
10 | // Finds the maximum value in an array using a single-threaded Metal compute shader.
11 | // This method is very inefficient and is provided as an example of a suboptimal approach.
12 | static func badIdea(values: [Int32], expectedValue: Int32) throws { // Single GPU thread scans all `values`; asserts the maximum equals `expectedValue`.
13 | let source = #"""
14 | #include
15 |
16 | using namespace metal;
17 |
18 | kernel void maxValue(
19 | const device int *input [[buffer(0)]],
20 | constant uint &count [[buffer(1)]],
21 | device int &output [[buffer(2)]]
22 | ) {
   | // The host data is Int32, so compare as signed and seed with the smallest int;
   | // an unsigned comparison seeded with 0 ranks negative inputs above positive ones.
23 | int temp = INT_MIN;
24 | for (uint n = 0; n != count; ++n) {
25 | temp = max(temp, input[n]);
26 | }
27 | output = temp;
28 | }
29 | """#
30 |
31 | let device = MTLCreateSystemDefaultDevice()!
32 | let input = device.makeBuffer(bytes: values, length: MemoryLayout.stride * values.count, options: [])!
33 | let output = device.makeBuffer(length: MemoryLayout.size)!
34 | let compute = try Compute(device: device)
35 | let library = ShaderLibrary.source(source)
36 | var maxValue = try compute.makePipeline(function: library.maxValue)
37 | maxValue.arguments.input = .buffer(input)
38 | maxValue.arguments.count = .int(UInt32(values.count))
39 | maxValue.arguments.output = .buffer(output)
40 | try timeit(#function) {
41 | try compute.run(pipeline: maxValue, width: 1) // a single thread does all the work
42 | }
43 | let result = output.contents().assumingMemoryBound(to: Int32.self)[0]
44 | assert(result == expectedValue)
45 | }
46 |
47 | // Finds the maximum value in an array using an atomic operation in a Metal compute shader.
48 | // Despite all threads fighting over a single lock this version is still extremely fast.
49 | static func simpleAtomic(values: [Int32], expectedValue: Int32) throws { // One GPU thread per value, all folding into a single atomic maximum; asserts it equals `expectedValue`.
50 | let source = #"""
51 | #include
52 |
53 | using namespace metal;
54 |
55 | uint thread_position_in_grid [[thread_position_in_grid]];
56 |
57 | kernel void maxValue(
58 | const device int *input [[buffer(0)]],
59 | device atomic_int *output [[buffer(1)]]
60 | ) {
   | // Signed types match the Int32 host data, so negative inputs order correctly.
61 | const int value = input[thread_position_in_grid];
62 | atomic_fetch_max_explicit(output, value, memory_order_relaxed);
63 | }
64 | """#
65 | let device = MTLCreateSystemDefaultDevice()!
66 | let input = device.makeBuffer(bytes: values, length: MemoryLayout.stride * values.count, options: [])!
67 | let output = device.makeBuffer(length: MemoryLayout.size)!
   | // Seed the accumulator with the smallest Int32 so an all-negative input still yields its true maximum.
   | output.contents().storeBytes(of: Int32.min, as: Int32.self)
68 | let compute = try Compute(device: device)
69 | let library = ShaderLibrary.source(source)
70 | var maxValue = try compute.makePipeline(function: library.maxValue)
71 | maxValue.arguments.input = .buffer(input)
72 | maxValue.arguments.output = .buffer(output)
73 | try timeit(#function) {
74 | try compute.run(pipeline: maxValue, width: values.count) // one grid position per input value
75 | }
76 | let result = output.contents().assumingMemoryBound(to: Int32.self)[0]
77 | assert(result == expectedValue)
78 | }
79 |
80 |
81 | // Finds the maximum value in an array using a multi-pass approach in a Metal compute shader.
82 | // This method uses SIMD group operations for efficient parallel processing.
83 | // Note: this method is destructive and intermediate values are written to the input buffer.
84 | // TODO: This method may occasionally fail for reasons that are currently unclear.
85 | // NOTE: Does NOT seem to be failing on M1 Ultra. 8/30/2024.
86 | static func multipass(values: [Int32], expectedValue: Int32) throws { // Repeated SIMD-group reductions, each pass leaving partial maxima `stride` elements apart; asserts input[0] ends up equal to `expectedValue`.
87 | let source = #"""
88 | #include
89 |
90 | using namespace metal;
91 |
92 | // Get the global thread position in the execution grid
93 | uint thread_position_in_grid [[thread_position_in_grid]];
94 |
95 | // Get the number of threads per SIMD group
96 | uint threads_per_simdgroup [[threads_per_simdgroup]];
97 |
98 | kernel void maxValue(
99 | device int *input [[buffer(0)]], // Input/output buffer
100 | constant uint &count [[buffer(1)]], // Total number of elements
101 | constant uint &stride [[buffer(2)]] // Stride between elements processed by each thread
102 | ) {
103 | // Calculate the index for this thread
104 | const uint index = thread_position_in_grid * stride;
105 |
106 | // Get the value for this thread, or INT_MIN if out of bounds
    | // (kept signed: stored in a uint, INT_MIN would compare as the largest value)
107 | int localValue = index >= count ? INT_MIN : input[index];
108 |
109 | // Perform a parallel reduction to find the maximum value
110 | for (uint offset = threads_per_simdgroup >> 1; offset > 0; offset >>= 1) {
111 | // Get the value from another thread in the SIMD group
112 | const int remoteValue = simd_shuffle_down(localValue, offset);
113 |
114 | // Update the current value with the maximum of current and remote
115 | localValue = max(localValue, remoteValue);
116 | }
117 |
118 | // Only the first thread in each SIMD group writes the result, and only when it maps to a real element
119 | if (simd_is_first() && index < count) {
120 | input[index] = localValue;
121 | }
122 | }
123 | """#
124 | let device = MTLCreateSystemDefaultDevice()!
125 | let input = device.makeBuffer(bytes: values, length: MemoryLayout.stride * values.count, options: [])!
126 | let compute = try Compute(device: device)
127 | let library = ShaderLibrary.source(source)
128 | var pipeline = try compute.makePipeline(function: library.maxValue)
129 | pipeline.arguments.input = .buffer(input)
130 |
131 | // This is equivalent `threads_per_simdgroup` in MSL.
132 | let threadExecutionWidth = pipeline.computePipelineState.threadExecutionWidth
133 | assert(threadExecutionWidth == 32)
134 | let maxTotalThreadsPerThreadgroup = pipeline.computePipelineState.maxTotalThreadsPerThreadgroup
135 |
136 | try timeit(#function) {
137 |
138 | // Initialize the stride (between elements processed by each thread) to 1
139 | var stride = 1
140 |
141 | try compute.task { task in
142 | try task { dispatch in
143 | // Continue looping while the stride is less than or equal to the total count of values
144 | while stride <= values.count {
145 | // Set the 'count' argument for the compute pipeline to the total number of values
146 | pipeline.arguments.count = .int(UInt32(values.count))
147 |
148 | // Set the 'stride' argument for the compute pipeline
149 | pipeline.arguments.stride = .int(UInt32(stride))
150 |
    | // One thread per surviving element, rounded UP so a trailing partial result is not dropped...
    | let activeThreads = (values.count + stride - 1) / stride
    | // ...then padded to whole SIMD groups so no shuffle reads an inactive lane;
    | // the padding lanes fall out of bounds in the kernel and contribute INT_MIN.
    | let paddedThreads = (activeThreads + threadExecutionWidth - 1) / threadExecutionWidth * threadExecutionWidth
    |
151 | // Dispatch the compute pipeline
152 | try dispatch(
153 | pipeline: pipeline,
154 | // Set the total number of threads to process all values
155 | threads: MTLSize(width: paddedThreads, height: 1, depth: 1),
156 | // Set the number of threads per threadgroup to the maximum allowed
157 | threadsPerThreadgroup: MTLSize(width: maxTotalThreadsPerThreadgroup, height: 1, depth: 1)
158 | )
159 |
160 | // Increase the stride by multiplying it with the thread execution width
161 | // This effectively reduces the number of active threads in each iteration
162 | stride *= threadExecutionWidth
163 | }
164 | }
165 | }
166 | }
167 |
168 | let result = input.contents().assumingMemoryBound(to: Int32.self)[0]
169 | assert(result == expectedValue)
170 | }
171 |
172 | static func main() async throws {
173 | // var values = Array(Array(repeating: Int32.zero, count: 1000))
174 | let expectedValue: Int32 = 123456789
175 |
176 |
177 | #if os(macOS)
178 | let count: Int32 = 1_000_000
179 | #else
180 | let count: Int32 = 1_000_000
181 | #endif
182 |
183 | let values = timeit("Generating \(count) values") {
184 | var values = Array(Int32.zero ..< count)
185 | values[Int.random(in: 0..
8 |
9 | using namespace metal;
10 |
11 | // Thread position in the execution grid
12 | uint thread_position_in_grid [[thread_position_in_grid]];
13 |
14 | // Empty kernel for baseline performance measurement
15 | kernel void empty()
16 | {
17 | }
18 |
19 | // Kernel to fill buffer with thread positions
20 | kernel void fill(
21 | device uint* output [[buffer(0)]] // Output buffer
22 | )
23 | {
24 | output[thread_position_in_grid] = thread_position_in_grid;
25 | }
26 |
27 | // Kernel to copy data from input buffer to output buffer
28 | kernel void memcpy(
29 | const device uint* input [[buffer(0)]], // Input buffer
30 | device uint* output [[buffer(1)]] // Output buffer
31 | )
32 | {
33 | output[thread_position_in_grid] = input[thread_position_in_grid];
34 | }
35 | """#
36 |
37 | static func main() async throws { // Allocates two large buffers and times an empty kernel, a fill kernel, a GPU copy and a CPU memcpy over them.
38 | // Get the default Metal device
39 | let device = MTLCreateSystemDefaultDevice()!
40 | // Set count to maximum value of UInt32
41 | let count = Int(UInt32.max)
42 | // Calculate the length of the buffers in bytes
43 | let length = MemoryLayout.stride * count
44 |
45 | // Print the size of the buffers in gigabytes
46 | print(Measurement(value: Double(length), unit: .bytes).converted(to: .gibibytes))
47 |
48 | print("# Allocating")
49 | // Create two Metal buffers of the calculated length
50 | let bufferA = device.makeBuffer(length: length)! // force-unwrapped: crashes if the device cannot allocate a buffer this large
51 | let bufferB = device.makeBuffer(length: length)! // same caveat; both buffers are alive at the same time
52 |
53 | print("# Preparing")
54 | // Create a Compute object with the Metal device
55 | let compute = try Compute(device: device)
56 | // Create a shader library from the source code
57 | let library = ShaderLibrary.source(source)
58 | // Create compute pipelines for each kernel function
59 | let empty = try compute.makePipeline(function: library.empty)
60 | let fill = try compute.makePipeline(function: library.fill, arguments: ["output": .buffer(bufferA)])
61 | let memcopy = try compute.makePipeline(function: library.memcpy, arguments: ["input": .buffer(bufferA), "output": .buffer(bufferB)])
62 |
63 | print("# Empty")
64 | // Run and time the empty kernel (baseline)
65 | try timeit(length: length) {
66 | try compute.run(pipeline: empty, width: count)
67 | }
68 |
69 | print("# Filling")
70 | // Run and time the fill kernel
71 | try timeit(length: length) {
72 | try compute.run(pipeline: fill, width: count)
73 | }
74 |
75 | print("# GPU memcpy")
76 | // Run and time the GPU memcpy kernel
77 | try timeit(length: length) {
78 | try compute.run(pipeline: memcopy, width: count)
79 | }
80 |
81 | print("# CPU memcpy")
82 | // Run and time CPU memcpy for comparison
83 | timeit(length: length) {
84 | memcpy(bufferB.contents(), bufferA.contents(), length) // CPU copy in the same direction as the GPU kernel: A into B
85 | }
86 |
87 | print("# DONE")
88 | }
89 | }
90 |
--------------------------------------------------------------------------------
/Sources/Examples/Examples/RandomFill.swift:
--------------------------------------------------------------------------------
1 | import Compute
2 | import CoreGraphics
3 | import Foundation
4 | import Metal
5 | import MetalKit
6 |
7 | // swiftlint:disable force_unwrapping
8 |
9 | struct RandomFill: Demo { // Demo: fills a 1024x1024 texture with pseudo-random black/white pixels on the GPU and exports it to /tmp/randomfill.png.
10 | static let source = #"""
11 | #include
12 | #include
13 |
14 | using namespace metal;
15 |
16 | float random(float2 p)
17 | {
18 | // We need irrationals for pseudo randomness.
19 | // Most (all?) known transcendental numbers will (generally) work.
20 | const float2 r = float2(
21 | 23.1406926327792690, // e^pi (Gelfond's constant)
22 | 2.6651441426902251); // 2^sqrt(2) (Gelfond–Schneider constant)
23 | return fract(cos(fmod(123456789.0, 1e-7 + 256.0 * dot(p,r))));
24 | }
25 |
26 | uint2 thread_position_in_grid [[thread_position_in_grid]];
27 |
28 | [[kernel]]
29 | void randomFill_float(texture2d outputTexture [[texture(0)]])
30 | {
31 | const float2 id = float2(thread_position_in_grid);
32 |
33 | float value = random(id) > 0.5 ? 1 : 0;
34 |
35 | float4 color = { value, value, value, 1 };
36 |
37 | outputTexture.write(color, thread_position_in_grid);
38 | }
39 | """# // The kernel thresholds random(position) at 0.5 and writes an opaque grey value of 0 or 1 per pixel.
40 |
41 | static func main() async throws { // Creates the output texture, dispatches the kernel over it, exports a PNG and (macOS only) reveals it in Finder.
42 | let device = MTLCreateSystemDefaultDevice()!
43 |
44 | let outputTextureDescriptor = MTLTextureDescriptor()
45 | outputTextureDescriptor.width = 1_024
46 | outputTextureDescriptor.height = 1_024
47 | outputTextureDescriptor.pixelFormat = .bgra8Unorm
48 | outputTextureDescriptor.usage = .shaderWrite // the kernel only writes to this texture
49 | guard let outputTexture = device.makeTexture(descriptor: outputTextureDescriptor) else {
50 | throw ComputeError.resourceCreationFailure
51 | }
52 |
53 | let compute = try Compute(device: device)
54 |
55 | let library = ShaderLibrary.source(source)
56 |
57 | var pipeline = try compute.makePipeline(function: library.randomFill_float)
58 | pipeline.arguments.outputTexture = .texture(outputTexture)
59 |
60 | try compute.task { task in
61 | try task { dispatch in
62 | try dispatch(pipeline: pipeline, threadgroupsPerGrid: MTLSize(width: outputTexture.width, height: outputTexture.height, depth: 1), threadsPerThreadgroup: MTLSize(width: 1, height: 1, depth: 1)) // NOTE(review): one threadgroup per pixel with a single thread each - correct, but presumably far slower than larger threadgroups; confirm this is intentional
63 | }
64 | }
65 |
66 | let url = URL(filePath: "/tmp/randomfill.png")
67 | try outputTexture.export(to: url)
68 | #if os(macOS)
69 | NSWorkspace.shared.selectFile(url.path, inFileViewerRootedAtPath: url.deletingLastPathComponent().path)
70 | #endif
71 |
72 |
73 |
74 | }
75 | }
76 |
--------------------------------------------------------------------------------
/Sources/Examples/Examples/SIMDReduce.swift:
--------------------------------------------------------------------------------
1 | import Compute
2 | import os
3 | import Metal
4 |
5 | let source = #"""
6 | #include
7 | #include
8 |
9 | using namespace metal;
10 |
11 | // Thread-specific attributes
12 | uint thread_position_in_grid [[thread_position_in_grid]]; // Position of the current thread in the grid
13 | uint threads_per_simdgroup [[threads_per_simdgroup]]; // Number of threads in a SIMD group
14 | uint threadgroup_position_in_grid [[threadgroup_position_in_grid]]; // Position of the current threadgroup in the grid
15 |
16 | // Kernel function for parallel reduction sum
17 | kernel void parallel_reduction_sum(
18 | constant uint* input [[buffer(0)]], // Input buffer
19 | device uint* output [[buffer(1)]] // Output buffer
20 | )
21 | {
22 | // Get the value for the current thread
23 | uint value = input[thread_position_in_grid];
24 |
25 | // Perform parallel reduction within SIMD group
26 | for (uint offset = threads_per_simdgroup / 2; offset > 0; offset >>= 1) {
27 | // Add the value from the thread 'offset' positions ahead
28 | value += simd_shuffle_and_fill_down(value, 0u, offset);
29 | }
30 |
31 | // Only the first thread in each SIMD group writes the result
32 | if (simd_is_first()) {
33 | output[thread_position_in_grid / threads_per_simdgroup] = value;
34 | }
35 | }
36 | """# // File-scope Metal source for `parallel_reduction_sum`: each SIMD group sums its inputs and writes one value to output[thread_position_in_grid / threads_per_simdgroup]. The kernel does no bounds check on `input`.
37 |
38 | struct SIMDReduce: Demo {
39 |
40 | static func main() async throws {
41 | try dualBuffer()
42 | }
43 |
44 | static func dualBuffer() throws {
45 | let device = MTLCreateSystemDefaultDevice()!
46 | // Create N values and sum them up in the CPU...
47 | var count = 50_000_000
48 | let values = timeit("Generating \(count) values") {
49 | (0.. 1
84 | }
85 | }
86 | }
87 | let result = Array(output.contentsBuffer(of: UInt32.self))[0]
88 | print("Compute result:", expectedResult, result, result == expectedResult)
89 | }
90 | }
91 |
--------------------------------------------------------------------------------
/Sources/Examples/Examples/ThreadgroupLog.swift:
--------------------------------------------------------------------------------
1 | import Compute
2 | import Metal
3 | import os
4 |
5 | enum ThreadgroupLogging: Demo { // Demo: logs the grid/threadgroup geometry from inside a kernel dispatched as 3 threadgroups of 2 threads.
6 | static let source = #"""
7 | #include
8 | #include
9 |
10 | const uint thread_position_in_grid [[thread_position_in_grid]];
11 | const uint threadgroup_position_in_grid [[threadgroup_position_in_grid]];
12 | const uint thread_position_in_threadgroup [[thread_position_in_threadgroup]];
13 |
14 | const uint threads_per_grid [[threads_per_grid]];
15 | const uint threads_per_threadgroup [[threads_per_threadgroup]];
16 | const uint threadgroups_per_grid [[threadgroups_per_grid]];
17 |
18 | using namespace metal;
19 |
20 | kernel void threadgroup_test() {
21 | if (thread_position_in_grid == 0) {
22 | os_log_default.log("threads_per_grid: %d, threads_per_threadgroup: %d, threadgroups_per_grid: %d", threads_per_grid, threads_per_threadgroup, threadgroups_per_grid);
23 | }
24 | os_log_default.log("thread_position_in_grid: %d, thread_position_in_threadgroup: %d, threadgroup_position_in_grid: %d", thread_position_in_grid, thread_position_in_threadgroup, threadgroup_position_in_grid);
25 | }
26 | """# // Thread 0 logs the grid-wide sizes once; every thread then logs its own three positions. NOTE(review): the values are uint but formatted with %d.
27 |
28 | static func main() async throws { // Compiles the source with logging enabled and runs the kernel.
29 | let device = MTLCreateSystemDefaultDevice()!
30 | let compute = try Compute(device: device, logger: Logger())
31 | let library = ShaderLibrary.source(source, enableLogging: true)
32 | let pipeline = try compute.makePipeline(function: library.threadgroup_test)
33 | try compute.run(pipeline: pipeline, threadgroupsPerGrid: MTLSize(width: 3), threadsPerThreadgroup: MTLSize(width: 2)) // 3 x 2 = 6 threads in total
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/Sources/Examples/Resources/Media.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/Sources/Examples/Resources/Media.xcassets/baboon-acorn-inverted.imageset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "images" : [
3 | {
4 | "filename" : "baboon-acorn-inverted.png",
5 | "idiom" : "universal"
6 | }
7 | ],
8 | "info" : {
9 | "author" : "xcode",
10 | "version" : 1
11 | }
12 | }
13 |
--------------------------------------------------------------------------------
/Sources/Examples/Resources/Media.xcassets/baboon-acorn-inverted.imageset/baboon-acorn-inverted.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/schwa/Compute/b1ba1e1071c912007bdbb59313e175887335fab8/Sources/Examples/Resources/Media.xcassets/baboon-acorn-inverted.imageset/baboon-acorn-inverted.png
--------------------------------------------------------------------------------
/Sources/Examples/Resources/Media.xcassets/baboon.imageset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "images" : [
3 | {
4 | "filename" : "baboon.png",
5 | "idiom" : "universal"
6 | }
7 | ],
8 | "info" : {
9 | "author" : "xcode",
10 | "version" : 1
11 | }
12 | }
13 |
--------------------------------------------------------------------------------
/Sources/Examples/Resources/Media.xcassets/baboon.imageset/baboon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/schwa/Compute/b1ba1e1071c912007bdbb59313e175887335fab8/Sources/Examples/Resources/Media.xcassets/baboon.imageset/baboon.png
--------------------------------------------------------------------------------
/Sources/Examples/Support.swift:
--------------------------------------------------------------------------------
1 | import AVFoundation
2 | import Compute
3 | import CoreGraphics
4 | import Foundation
5 | import Metal
6 | import MetalSupportLite
7 |
/// Returns the current Mach absolute time converted to nanoseconds.
///
/// The raw tick count from `mach_absolute_time()` is scaled by the
/// `numer / denom` ratio reported by `mach_timebase_info`.
public func getMachTimeInNanoseconds() -> UInt64 {
    var info = mach_timebase_info_data_t()
    mach_timebase_info(&info)
    let numerator = UInt64(info.numer)
    let denominator = UInt64(info.denom)
    let ticks = mach_absolute_time()
    return ticks * numerator / denominator
}
14 |
/// Runs `work` once and reports how long it took.
///
/// - Parameter work: The closure to measure.
/// - Returns: Elapsed time in nanoseconds, as measured by `getMachTimeInNanoseconds()`.
/// - Throws: Rethrows whatever `work` throws; no duration is returned in that case.
@discardableResult
public func timed(_ work: () throws -> Void) rethrows -> UInt64 {
    let began = getMachTimeInNanoseconds()
    try work()
    return getMachTimeInNanoseconds() - began
}
22 |
/// Runs `work`, hands the elapsed time to `display`, and returns the result of `work`.
///
/// - Parameters:
///   - work: The closure to measure.
///   - display: Receives the elapsed time in nanoseconds. Not called if `work` throws.
/// - Returns: The value produced by `work`.
/// - Throws: Rethrows whatever `work` throws.
@discardableResult
public func timeit<R>(_ work: () throws -> R, display: (UInt64) -> Void) rethrows -> R {
    let start = getMachTimeInNanoseconds()
    let result = try work()
    let end = getMachTimeInNanoseconds()
    display(end - start)
    return result
}
31 |
/// Runs `work` and prints the elapsed time in milliseconds, prefixed by `label`.
///
/// - Parameters:
///   - label: Text printed before the measurement; an empty prefix is used when `nil`.
///   - work: The closure to measure.
/// - Returns: The value produced by `work`.
/// - Throws: Rethrows whatever `work` throws.
@discardableResult
public func timeit<R>(_ label: String? = nil, _ work: () throws -> R) rethrows -> R {
    try timeit(work) { delta in
        let measurement = Measurement(value: Double(delta), unit: UnitDuration.nanoseconds)
        let measurementMS = measurement.converted(to: .milliseconds)
        print("\(label ?? ""): \(measurementMS.formatted())")
    }
}
40 |
/// Runs `work` and prints both the elapsed time and the throughput for `length` bytes.
///
/// - Parameters:
///   - label: Optional text printed in front of each output line. Nothing is prepended when `nil`.
///   - length: Number of bytes processed by `work`; used to compute the gigabytes-per-second figure.
///   - work: The closure to measure.
/// - Returns: The value produced by `work`.
/// - Throws: Rethrows whatever `work` throws.
@discardableResult
public func timeit<R>(_ label: String? = nil, length: Int, _ work: () throws -> R) rethrows -> R {
    try timeit(work) { delta in
        // `label` used to be accepted but ignored; prefix the output with it when supplied.
        let prefix = label.map { "\($0) " } ?? ""
        let seconds = Double(delta) / 1_000_000_000
        let bytesPerSecond = Double(length) / seconds
        let gigabytesPerSecond = Measurement(value: bytesPerSecond, unit: UnitInformationStorage.bytes)
            .converted(to: .gigabytes)
        let elapsed = Measurement(value: seconds, unit: UnitDuration.seconds).converted(to: .milliseconds)
        print("\(prefix)Time: \(elapsed.formatted())")
        print("\(prefix)Speed: \(gigabytesPerSecond.formatted(.measurement(width: .abbreviated, usage: .asProvided)))/s")
    }
}
52 |
/// Encodes a sequence of Metal textures into an HEVC movie using `AVAssetWriter`.
///
/// Frames are written to a uniquely named file in the temporary directory; `finish()`
/// moves that file to `outputURL`. Call `start()` once, then `writeFrame(...)` per frame,
/// then `finish()`.
class TextureToVideoWriter {
    private var assetWriter: AVAssetWriter
    private var writerInput: AVAssetWriterInput?
    private var adaptor: AVAssetWriterInputPixelBufferAdaptor?

    /// Final location of the movie once `finish()` has completed.
    let outputURL: URL
    /// Scratch location the asset writer writes to while encoding.
    let temporaryURL: URL
    private let size: CGSize
    private let pixelFormat = kCVPixelFormatType_32BGRA

    /// Presentation time of the most recently written texture frame, if any.
    var endTime: CMTime?

    /// Creates a writer that will eventually place its movie at `outputURL`.
    ///
    /// - Parameters:
    ///   - outputURL: Destination of the finished movie.
    ///   - size: Frame size in pixels.
    /// - Throws: Any error raised while creating the underlying `AVAssetWriter`.
    init(outputURL: URL, size: CGSize) throws {
        self.outputURL = outputURL
        // NOTE(review): the scratch file is named ".mp4" although the container type is `.mov`.
        let temporaryURL = FileManager.default.temporaryDirectory.appendingPathComponent("\(UUID().uuidString).mp4")
        self.temporaryURL = temporaryURL
        self.size = size

        assetWriter = try AVAssetWriter(outputURL: temporaryURL, fileType: .mov)
    }

    /// Configures the video input and pixel buffer adaptor and opens the writing session at time zero.
    ///
    /// Traps if the input cannot be added to the writer or if writing cannot start, because
    /// every later call would fail anyway.
    func start() {
        let settings: [String: Any] = [
            AVVideoCodecKey: AVVideoCodecType.hevc,
            AVVideoWidthKey: size.width,
            AVVideoHeightKey: size.height
        ]

        let input = AVAssetWriterInput(mediaType: .video, outputSettings: settings)
        input.expectsMediaDataInRealTime = true

        let pixelBufferAdaptor = AVAssetWriterInputPixelBufferAdaptor(
            assetWriterInput: input,
            sourcePixelBufferAttributes: [
                kCVPixelBufferPixelFormatTypeKey as String: pixelFormat,
                kCVPixelBufferWidthKey as String: Int(size.width),
                kCVPixelBufferHeightKey as String: Int(size.height)
            ]
        )

        guard assetWriter.canAdd(input) else {
            fatalError("AVAssetWriter cannot accept the configured video input.")
        }
        assetWriter.add(input)
        writerInput = input
        adaptor = pixelBufferAdaptor

        guard assetWriter.startWriting() else {
            fatalError("AVAssetWriter failed to start writing: \(String(describing: assetWriter.error))")
        }
        assetWriter.startSession(atSourceTime: .zero)
    }

    /// Copies the contents of `texture` into a pooled pixel buffer and appends it as a frame.
    ///
    /// - Parameters:
    ///   - texture: Source texture; its base mip level is read with `getBytes`.
    ///   - time: Presentation time of the frame. Also recorded as `endTime`.
    func writeFrame(texture: MTLTexture, at time: CMTime) {
        autoreleasepool {
            guard let adaptor, let pixelBufferPool = adaptor.pixelBufferPool else {
                fatalError("No pixel buffer pool available; was start() called?")
            }
            var pixelBuffer: CVPixelBuffer?
            CVPixelBufferPoolCreatePixelBuffer(nil, pixelBufferPool, &pixelBuffer)
            guard let pixelBuffer else {
                fatalError("Could not create a pixel buffer from the pool.")
            }
            // `getBytes` copies `texture.height` rows into the buffer, so the texture must fit.
            precondition(
                texture.width <= CVPixelBufferGetWidth(pixelBuffer) && texture.height <= CVPixelBufferGetHeight(pixelBuffer),
                "Texture is larger than the video frame size."
            )
            // The buffer is written to below, so it must not be locked read-only.
            CVPixelBufferLockBaseAddress(pixelBuffer, [])
            guard let baseAddress = CVPixelBufferGetBaseAddress(pixelBuffer) else {
                fatalError("Pixel buffer has no base address.")
            }
            let region = MTLRegionMake2D(0, 0, texture.width, texture.height)
            texture.getBytes(baseAddress, bytesPerRow: CVPixelBufferGetBytesPerRow(pixelBuffer), from: region, mipmapLevel: 0)
            CVPixelBufferUnlockBaseAddress(pixelBuffer, [])
            writeFrame(pixelBuffer: pixelBuffer, at: time)
            endTime = time
        }
    }

    /// Appends an already-filled pixel buffer as a frame at `time`.
    ///
    /// Blocks the calling thread (polling every 10 ms) until the input accepts more data.
    func writeFrame(pixelBuffer: CVPixelBuffer, at time: CMTime) {
        autoreleasepool {
            guard let writerInput, let adaptor else {
                fatalError("Writer input is not configured; was start() called?")
            }
            // This isn't pretty but it works?
            while writerInput.isReadyForMoreMediaData == false {
                usleep(10 * 1_000)
            }
            let appended = adaptor.append(pixelBuffer, withPresentationTime: time)
            assert(appended, "Failed to append pixel buffer: \(String(describing: assetWriter.error))")
        }
    }

    /// Finishes encoding and moves the movie from `temporaryURL` to `outputURL`,
    /// replacing any file already at the destination.
    ///
    /// - Throws: The asset writer's error if writing failed, or any file-system error
    ///   raised while removing the old file or moving the new one.
    func finish() async throws {
        writerInput?.markAsFinished()
        // `endTime` is only set once a texture frame has been written; do not force-unwrap it.
        if let endTime {
            assetWriter.endSession(atSourceTime: endTime)
        }
        await assetWriter.finishWriting()
        if assetWriter.status == .failed, let error = assetWriter.error {
            throw error
        }
        let fileManager = FileManager.default
        if fileManager.fileExists(atPath: outputURL.path) {
            try fileManager.removeItem(at: outputURL)
        }
        try fileManager.moveItem(at: temporaryURL, to: outputURL)
    }
}
149 |
extension MTLTexture {
    /// Writes the base mip level of this texture to `url` as a PNG image.
    ///
    /// The texture is expected to be a `.bgra8Unorm` texture with a depth of 1 (checked with `assert`).
    ///
    /// - Parameter url: Destination file URL.
    /// - Throws: `ComputeError.resourceCreationFailure` if any Core Graphics / Image I/O
    ///   object cannot be created or the image cannot be finalized.
    func export(to url: URL) throws {
        assert(pixelFormat == .bgra8Unorm)
        assert(depth == 1)

        // Four 8-bit channels per pixel.
        let bytesPerRow = width * MemoryLayout<UInt8>.size * 4
        guard let colorSpace = CGColorSpace(name: CGColorSpace.genericRGBLinear) else {
            throw ComputeError.resourceCreationFailure
        }
        let bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.premultipliedFirst.rawValue)
            .union(.byteOrder32Little)
        guard let context = CGContext(data: nil, width: width, height: height, bitsPerComponent: 8, bytesPerRow: bytesPerRow, space: colorSpace, bitmapInfo: bitmapInfo.rawValue) else {
            throw ComputeError.resourceCreationFailure
        }
        guard let contextData = context.data else {
            throw ComputeError.resourceCreationFailure
        }
        getBytes(contextData, bytesPerRow: bytesPerRow, from: MTLRegion(origin: MTLOrigin(x: 0, y: 0, z: 0), size: MTLSize(width: width, height: height, depth: 1)), mipmapLevel: 0)

        guard let image = context.makeImage() else {
            throw ComputeError.resourceCreationFailure
        }

        guard let imageDestination = CGImageDestinationCreateWithURL(url as CFURL, UTType.png.identifier as CFString, 1, nil) else {
            throw ComputeError.resourceCreationFailure
        }
        CGImageDestinationAddImage(imageDestination, image, nil)
        // Finalize reports whether the image was actually written; surface a failure to the caller.
        guard CGImageDestinationFinalize(imageDestination) else {
            throw ComputeError.resourceCreationFailure
        }
    }
}
180 |
extension MTLBuffer {
    /// Calls `body` with the buffer's contents viewed as a typed buffer of `ContentType`.
    ///
    /// - Parameter body: Receives the rebound buffer pointer; it must not escape the closure.
    /// - Returns: The value returned by `body`.
    /// - Throws: Rethrows whatever `body` throws.
    func withUnsafeBytes<ContentType, ResultType>(_ body: (UnsafeBufferPointer<ContentType>) throws -> ResultType) rethrows -> ResultType {
        try withUnsafeBytes { (buffer: UnsafeRawBufferPointer) in
            try buffer.withMemoryRebound(to: ContentType.self, body)
        }
    }

    /// Calls `body` with a raw view of `length` bytes starting at `contents()`.
    ///
    /// - Parameter body: Receives the raw buffer pointer; it must not escape the closure.
    /// - Returns: The value returned by `body`.
    /// - Throws: Rethrows whatever `body` throws.
    func withUnsafeBytes<ResultType>(_ body: (UnsafeRawBufferPointer) throws -> ResultType) rethrows -> ResultType {
        try body(UnsafeRawBufferPointer(start: contents(), count: length))
    }
}
192 |
extension Array where Element: Equatable {
    /// One run of identical, adjacent elements.
    struct Run {
        var element: Element
        var count: Int
    }

    /// Run-length encodes the array.
    ///
    /// - Returns: One `Run` per maximal stretch of equal adjacent elements, in order.
    ///   An empty array yields no runs.
    func rle() -> [Run] {
        var runs: [Run] = []
        for element in self {
            if let last = runs.indices.last, runs[last].element == element {
                // Same value as the run in progress: extend it.
                runs[last].count += 1
            } else {
                runs.append(Run(element: element, count: 1))
            }
        }
        return runs
    }
}
228 |
extension Array where Element == UInt32 {
    /// Returns the exclusive prefix sum of the array: element `j` of the result is the
    /// sum of the first `j` input elements, so the first output element is always zero.
    ///
    /// NOTE(review): the loop body was lost in extraction; it is reconstructed here from the
    /// zero-initialised output and the loop starting at index 1 — confirm against the repository.
    func prefixSum() -> [UInt32] {
        var output = Array(repeating: UInt32.zero, count: count)
        // `indices.dropFirst()` is empty for an empty array, whereas `1..<count` would trap.
        for j in indices.dropFirst() {
            output[j] = output[j - 1] + self[j - 1]
        }
        return output
    }
}

/// Integer base-2 logarithm of a positive 32-bit value.
///
/// - Parameters:
///   - value: Must be greater than zero.
///   - ceiling: When `true`, rounds up for values that are not exact powers of two;
///     otherwise the result is rounded down.
/// - Returns: The rounded logarithm.
func log2(_ value: UInt32, ceiling: Bool = false) -> UInt32 {
    precondition(value > 0, "log2 is only defined for positive numbers")
    let result = UInt32(31 - value.leadingZeroBitCount)
    return ceiling && (UInt32(1) << result) < value ? result + 1 : result
}
245 |
/// Integer base-2 logarithm of a positive `Int`.
///
/// - Parameters:
///   - value: Must be greater than zero.
///   - ceiling: When `true`, rounds up for values that are not exact powers of two;
///     otherwise the result is rounded down.
/// - Returns: The rounded logarithm.
func log2(_ value: Int, ceiling: Bool = false) -> Int {
    precondition(value > 0, "log2 is only defined for positive numbers")
    // Index of the highest set bit; derived from `bitWidth` so it also holds where `Int` is not 64 bits.
    let result = (Int.bitWidth - 1) - value.leadingZeroBitCount
    return ceiling && (1 << result) < value ? result + 1 : result
}
251 |
252 | // MARK: -
253 |
infix operator **: MultiplicationPrecedence

/// Raises `base` to `exponent` by repeated squaring, using exact integer arithmetic.
/// Traps on overflow, like any other integer multiplication.
private func integerPower<T: FixedWidthInteger>(_ base: T, _ exponent: T) -> T {
    var result: T = 1
    var factor = base
    var remaining = exponent
    while remaining > 0 {
        if remaining & 1 == 1 {
            result *= factor
        }
        remaining >>= 1
        // Only square while another bit remains, so the last step cannot overflow needlessly.
        if remaining > 0 {
            factor *= factor
        }
    }
    return result
}

/// Integer exponentiation.
///
/// Non-negative exponents are computed exactly in integer arithmetic; going through
/// `Double` loses precision once the result exceeds 2^53. Negative exponents keep the
/// previous floating-point behaviour (truncation toward zero).
func ** (base: Int, exponent: Int) -> Int {
    guard exponent >= 0 else {
        return Int(pow(Double(base), Double(exponent)))
    }
    return integerPower(base, exponent)
}

/// Integer exponentiation for `UInt32`, computed exactly. Traps if the result does not fit.
func ** (base: UInt32, exponent: UInt32) -> UInt32 {
    integerPower(base, exponent)
}
263 |
extension Array {
    /// Copies the entire contents of `buffer` into a new array of `Element`.
    ///
    /// The element count is `buffer.length` divided by the element stride, because
    /// consecutive elements in memory are spaced by stride rather than by size.
    init(_ buffer: MTLBuffer) {
        let elementCount = buffer.length / MemoryLayout<Element>.stride
        let pointer = buffer.contents().bindMemory(to: Element.self, capacity: elementCount)
        self = Array(UnsafeBufferPointer(start: pointer, count: elementCount))
    }
}
272 |
/// Divides `x` by `y`, rounding the quotient up.
///
/// Computed as `(x + y - 1) / y`, so it is intended for a non-negative `x` and a positive `y`.
func ceildiv<T>(_ x: T, _ y: T) -> T where T: BinaryInteger {
    (x + y - 1) / y
}
276 |
extension Array {
    /// Copies the full contents of the Metal buffer backing `buffer` into a new array.
    ///
    /// A typed buffer with no underlying `MTLBuffer` produces an empty array instead of trapping.
    init(_ buffer: TypedMTLBuffer<Element>) {
        self = buffer.withUnsafeMTLBuffer { base in
            guard let base else {
                return []
            }
            return Array(base)
        }
    }
}
284 |
extension Compute.Argument {
    /// Wraps the Metal buffer backing `data` as a buffer argument.
    ///
    /// Traps if `data` has no underlying `MTLBuffer` (for example a typed buffer created empty).
    static func buffer<T>(_ data: TypedMTLBuffer<T>) -> Self {
        data.withUnsafeMTLBuffer { buffer in
            guard let buffer else {
                fatalError("Cannot create a buffer argument from a TypedMTLBuffer without an underlying MTLBuffer.")
            }
            return .buffer(buffer)
        }
    }
}
292 |
/// Returns the smallest power of two that is strictly greater than `n`.
///
/// Computed with integer bit operations; the earlier floating-point route rounded large
/// inputs when converting them to `Double` and could return a power of two that was too big.
///
/// - Parameter n: Must be greater than zero.
/// - Returns: A power of two greater than `n` (so an exact power of two maps to the next one up).
func nextPowerOfTwo(_ n: Int) -> Int {
    precondition(n > 0, "nextPowerOfTwo is only defined for positive numbers")
    // Position one above the highest set bit of `n`.
    let shift = Int.bitWidth - n.leadingZeroBitCount
    precondition(shift < Int.bitWidth - 1, "nextPowerOfTwo result does not fit in Int")
    return 1 << shift
}
296 |
extension MTLSize: @retroactive ExpressibleByArrayLiteral {
    /// Builds a size from one, two or three integers: `[width]`, `[width, height]`
    /// or `[width, height, depth]`. Omitted dimensions are 1; any other element count traps.
    public init(arrayLiteral elements: Int...) {
        guard (1...3).contains(elements.count) else {
            fatalError()
        }
        let height = elements.count > 1 ? elements[1] : 1
        let depth = elements.count > 2 ? elements[2] : 1
        self.init(width: elements[0], height: height, depth: depth)
    }
}
312 |
--------------------------------------------------------------------------------
/Sources/MetalSupportLite/BaseSupport.swift:
--------------------------------------------------------------------------------
/// General-purpose error cases shared by the support code.
public enum BaseError: Error {
    // IDEA: Have a good going over here and clean up duplicate/vague types.

    // Cases carrying extra context.
    case generic(String)
    case extended(Error, String)
    case missingBinding(String)

    // Creation and initialisation failures.
    case resourceCreationFailure, initializationFailure, optionalUnwrapFailure

    // Invalid or missing values.
    case illegalValue, missingValue, typeMismatch, invalidParameter, missingResource, overflow

    // I/O and (de)serialisation failures.
    case inputOutputFailure, parsingFailure, encodingFailure, decodingFailure

    case unknown
}
21 |
public extension BaseError {
    /// Pass-through used at throw sites (`throw BaseError.error(.someCase)`).
    ///
    /// Returns its argument unchanged; it exists as a single place to hook in
    /// logging or a breakpoint for every error that is raised this way.
    static func error(_ error: Self) -> Self {
        // NOTE: Hook here to add logging or special breakpoint handling.
        // logger?.error("Error: \(error)")
        return error
    }
}
29 |
/// Terminates the program, using the textual description of `error` as the message.
public func fatalError(_ error: Error) -> Never {
    Swift.fatalError(String(describing: error))
}
33 |
/// Marks a code path that has not been written yet. Always traps.
///
/// - Parameters:
///   - message: Text reported with the trap; empty by default.
///   - file: Reported source file; defaults to the caller's file.
///   - line: Reported source line; defaults to the caller's line.
public func unimplemented(_ message: @autoclosure () -> String = String(), file: StaticString = #file, line: UInt = #line) -> Never {
    Swift.fatalError(message(), file: file, line: line)
}

/// Marks a code path that has been switched off for now. Always traps.
///
/// - Parameters:
///   - message: Text reported with the trap; empty by default.
///   - file: Reported source file; defaults to the caller's file.
///   - line: Reported source line; defaults to the caller's line.
public func temporarilyDisabled(_ message: @autoclosure () -> String = String(), file: StaticString = #file, line: UInt = #line) -> Never {
    Swift.fatalError(message(), file: file, line: line)
}

/// Marks a code path that should never execute. Always traps.
///
/// - Parameters:
///   - message: Text reported with the trap; empty by default.
///   - file: Reported source file; defaults to the caller's file.
///   - line: Reported source line; defaults to the caller's line.
public func unreachable(_ message: @autoclosure () -> String = String(), file: StaticString = #file, line: UInt = #line) -> Never {
    Swift.fatalError(message(), file: file, line: line)
}
45 |
46 | // MARK: -
47 |
public extension Optional {
    /// Returns the wrapped value, or throws `error` when the optional is `nil`.
    ///
    /// - Parameter error: Evaluated only when the optional is `nil`.
    func safelyUnwrap(_ error: @autoclosure () -> Error) throws -> Wrapped {
        switch self {
        case .some(let wrapped):
            return wrapped
        case .none:
            throw error()
        }
    }

    /// Returns the wrapped value, trapping with a fixed message when the optional is `nil`.
    func forceUnwrap() -> Wrapped {
        switch self {
        case .some(let wrapped):
            return wrapped
        case .none:
            fatalError("Cannot unwrap nil optional.")
        }
    }

    /// Returns the wrapped value, trapping with `message` when the optional is `nil`.
    ///
    /// - Parameter message: Evaluated only when the optional is `nil`.
    func forceUnwrap(_ message: @autoclosure () -> String) -> Wrapped {
        switch self {
        case .some(let wrapped):
            return wrapped
        case .none:
            fatalError(message())
        }
    }
}
73 |
--------------------------------------------------------------------------------
/Sources/MetalSupportLite/MTLBuffer+Extensions.swift:
--------------------------------------------------------------------------------
1 | import Metal
2 |
public extension MTLDevice {
    // TODO: Rename
    /// Creates a buffer initialised with `length` bytes copied from `pointer`.
    ///
    /// - Throws: `BaseError.resourceCreationFailure` if the device returns no buffer.
    func makeBufferEx(bytes pointer: UnsafeRawPointer, length: Int, options: MTLResourceOptions = []) throws -> MTLBuffer {
        guard let buffer = makeBuffer(bytes: pointer, length: length, options: options) else {
            throw BaseError.error(.resourceCreationFailure)
        }
        return buffer
    }

    // TODO: Rename
    /// Creates a buffer of `length` bytes.
    ///
    /// - Throws: `BaseError.resourceCreationFailure` if the device returns no buffer.
    func makeBufferEx(length: Int, options: MTLResourceOptions = []) throws -> MTLBuffer {
        guard let buffer = makeBuffer(length: length, options: options) else {
            throw BaseError.error(.resourceCreationFailure)
        }
        return buffer
    }

    /// Creates a buffer containing a copy of the bytes in `data`.
    ///
    /// - Throws: `BaseError.resourceCreationFailure` if `data` exposes no base address
    ///   or the device returns no buffer.
    func makeBuffer(data: Data, options: MTLResourceOptions) throws -> MTLBuffer {
        try data.withUnsafeBytes { bytes in
            // These functions already throw, so report a missing base address instead of trapping.
            guard let baseAddress = bytes.baseAddress else {
                throw BaseError.error(.resourceCreationFailure)
            }
            return try makeBufferEx(bytes: baseAddress, length: bytes.count, options: options)
        }
    }

    /// Creates a buffer containing a copy of the in-memory bytes of `content`.
    ///
    /// - Throws: `BaseError.resourceCreationFailure` if `content` exposes no base address
    ///   or the device returns no buffer.
    func makeBuffer(bytesOf content: some Any, options: MTLResourceOptions) throws -> MTLBuffer {
        try withUnsafeBytes(of: content) { bytes in
            guard let baseAddress = bytes.baseAddress else {
                throw BaseError.error(.resourceCreationFailure)
            }
            return try makeBufferEx(bytes: baseAddress, length: bytes.count, options: options)
        }
    }

    /// Creates a buffer containing a copy of the array's element storage.
    ///
    /// - Throws: `BaseError.resourceCreationFailure` if the array exposes no base address
    ///   or the device returns no buffer.
    func makeBuffer(bytesOf content: [some Any], options: MTLResourceOptions) throws -> MTLBuffer {
        try content.withUnsafeBytes { bytes in
            guard let baseAddress = bytes.baseAddress else {
                throw BaseError.error(.resourceCreationFailure)
            }
            return try makeBufferEx(bytes: baseAddress, length: bytes.count, options: options)
        }
    }
}
50 |
public extension MTLBuffer {
    /// Returns a `Data` copy of the buffer's `length` bytes.
    func data() -> Data {
        Data(bytes: contents(), count: length)
    }

    /// Update a MTLBuffer's contents using an inout type block
    ///
    /// The start of the buffer is bound to a single value of `T`, which is passed to `block` by reference.
    func with<T, R>(type: T.Type, _ block: (inout T) -> R) -> R {
        let value = contents().bindMemory(to: T.self, capacity: 1)
        return block(&value.pointee)
    }

    /// Calls `block` with the start of the buffer viewed as `count` mutable values of `T`.
    func withEx<T, R>(type: T.Type, count: Int, _ block: (UnsafeMutableBufferPointer<T>) -> R) -> R {
        let pointer = contents().bindMemory(to: T.self, capacity: count)
        let buffer = UnsafeMutableBufferPointer(start: pointer, count: count)
        return block(buffer)
    }

    /// Returns a raw mutable view over all `length` bytes of the buffer.
    func contentsBuffer() -> UnsafeMutableRawBufferPointer {
        UnsafeMutableRawBufferPointer(start: contents(), count: length)
    }

    /// Returns the whole buffer bound to elements of `type`.
    func contentsBuffer<T>(of type: T.Type) -> UnsafeMutableBufferPointer<T> {
        contentsBuffer().bindMemory(to: type)
    }

    /// Sets the buffer's label and returns the same buffer, for call chaining.
    func labelled(_ label: String) -> MTLBuffer {
        self.label = label
        return self
    }
}
80 |
--------------------------------------------------------------------------------
/Sources/MetalSupportLite/MetalBasicExtensions.swift:
--------------------------------------------------------------------------------
1 | import Metal
2 |
public extension MTLOrigin {
    /// Creates an origin from a 2D point, truncating each coordinate to `Int`; `z` is 0.
    init(_ origin: CGPoint) {
        let x = Int(origin.x)
        let y = Int(origin.y)
        self = MTLOriginMake(x, y, 0)
    }

    /// The origin at (0, 0, 0).
    static var zero: MTLOrigin {
        MTLOriginMake(0, 0, 0)
    }
}
12 |
public extension MTLRegion {
    /// Creates a 2D region from a rectangle: origin and size components are truncated
    /// to `Int`, with `z` set to 0 and `depth` set to 1.
    init(_ rect: CGRect) {
        let origin = MTLOrigin(x: Int(rect.origin.x), y: Int(rect.origin.y), z: 0)
        let size = MTLSize(width: Int(rect.size.width), height: Int(rect.size.height), depth: 1)
        self.init(origin: origin, size: size)
    }
}
18 |
public extension MTLSize {
    /// Creates a size from a `CGSize`, truncating each dimension to `Int`; depth is 1.
    init(_ size: CGSize) {
        let width = Int(size.width)
        let height = Int(size.height)
        self.init(width: width, height: height, depth: 1)
    }

    /// Creates a size from unlabelled width, height and depth.
    init(_ width: Int, _ height: Int, _ depth: Int) {
        self.init(width: width, height: height, depth: depth)
    }

    /// Creates a one-dimensional size; height and depth are 1.
    init(width: Int) {
        self.init(width: width, height: 1, depth: 1)
    }

    /// Creates a two-dimensional size; depth is 1.
    init(width: Int, height: Int) {
        self.init(width: width, height: height, depth: 1)
    }
}
36 |
public extension SIMD4<Double> {
    /// Creates a vector holding the red, green, blue and alpha components of `clearColor`, in that order.
    init(_ clearColor: MTLClearColor) {
        self = [clearColor.red, clearColor.green, clearColor.blue, clearColor.alpha]
    }
}
42 |
public extension MTLIndexType {
    /// Size in bytes of a single index of this type. Traps for an index type this code does not know.
    var indexSize: Int {
        switch self {
        case .uint16:
            return MemoryLayout<UInt16>.size
        case .uint32:
            return MemoryLayout<UInt32>.size
        @unknown default:
            fatalError(BaseError.illegalValue)
        }
    }
}
55 |
public extension MTLPrimitiveType {
    /// Number of vertices that make up one primitive of this type.
    ///
    /// Returns `nil` for strip primitives, where vertices are shared between neighbouring
    /// primitives and no fixed per-primitive count applies, and for unknown future cases.
    /// Previously every case other than `.triangle` trapped despite the optional return type.
    var vertexCount: Int? {
        switch self {
        case .point:
            return 1
        case .line:
            return 2
        case .triangle:
            return 3
        case .lineStrip, .triangleStrip:
            return nil
        @unknown default:
            return nil
        }
    }
}
66 |
--------------------------------------------------------------------------------
/Sources/MetalSupportLite/MetalSupportLite.swift:
--------------------------------------------------------------------------------
1 | import Metal
2 |
public extension MTLDevice {
    /// Runs `block` inside a Metal capture scope created for this device.
    ///
    /// - Parameters:
    ///   - enabled: When `false`, `block` is simply executed without any capture.
    ///   - block: The work to capture.
    /// - Returns: The value returned by `block`.
    /// - Throws: Any error from `startCapture(with:)` or from `block`.
    func capture<R>(enabled: Bool = true, _ block: () throws -> R) throws -> R {
        guard enabled else {
            return try block()
        }
        let captureManager = MTLCaptureManager.shared()
        let captureScope = captureManager.makeCaptureScope(device: self)
        let captureDescriptor = MTLCaptureDescriptor()
        captureDescriptor.captureObject = captureScope
        try captureManager.startCapture(with: captureDescriptor)
        captureScope.begin()
        // End the scope on every exit path, including when `block` throws.
        defer {
            captureScope.end()
        }
        return try block()
    }

    /// Whether the device reports support for any of the Apple GPU families 4 through 7.
    ///
    /// NOTE(review): only Apple families are queried here; confirm against the Metal feature
    /// set tables whether Mac families should also be included.
    var supportsNonuniformThreadGroupSizes: Bool {
        let families: [MTLGPUFamily] = [.apple4, .apple5, .apple6, .apple7]
        return families.contains { supportsFamily($0) }
    }

    /// Creates a compute pipeline state and returns it together with its reflection data.
    ///
    /// - Throws: Any error raised by the underlying pipeline creation call.
    func makeComputePipelineState(function: MTLFunction, options: MTLPipelineOption) throws -> (MTLComputePipelineState, MTLComputePipelineReflection?) {
        var reflection: MTLComputePipelineReflection?
        let pipelineState = try makeComputePipelineState(function: function, options: options, reflection: &reflection)
        return (pipelineState, reflection)
    }
}
31 |
--------------------------------------------------------------------------------
/Sources/MetalSupportLite/TypedMTLBuffer.swift:
--------------------------------------------------------------------------------
1 | @preconcurrency import Metal
2 |
/// A type-safe wrapper around `MTLBuffer` for managing Metal buffers with a specific element type.
///
/// `TypedMTLBuffer` provides a convenient way to work with Metal buffers while maintaining type safety.
/// It encapsulates an `MTLBuffer` and provides methods to safely access and manipulate its contents.
///
/// - Important: This type conforms to `Sendable`. However, this conformance is only valid if the
///   underlying `MTLBuffer` is not reused elsewhere. Ensure that the `MTLBuffer` is uniquely owned
///   by this `TypedMTLBuffer` instance to maintain thread safety when sending across concurrency domains.
/// - Note: The generic type `Element` should be a POD (Plain Old Data) type.
public struct TypedMTLBuffer<Element>: Sendable {
    /// Number of elements currently considered valid. Must not exceed `capacity`.
    public var count: Int {
        willSet {
            // Validate the incoming value; checking the current `count` would never catch an overflow.
            assert(newValue <= capacity)
        }
    }

    /// The underlying Metal buffer.
    private var base: MTLBuffer?

    /// Initializes a new `TypedMTLBuffer` with the given Metal buffer.
    ///
    /// - Parameters:
    ///   - mtlBuffer: The Metal buffer to wrap.
    ///   - count: The number of valid elements in the buffer.
    /// - Precondition: The generic type `Element` must be a POD type.
    public init(mtlBuffer: MTLBuffer?, count: Int) {
        assert(_isPOD(Element.self))
        self.base = mtlBuffer
        self.count = count
    }

    /// Creates an empty typed buffer with no underlying Metal buffer.
    public init() {
        self.base = nil
        self.count = 0
    }

    /// Number of elements the underlying buffer can hold (0 when there is no buffer).
    public var capacity: Int {
        (base?.length ?? 0) / MemoryLayout<Element>.stride
    }

    /// The label of the underlying Metal buffer, if there is one.
    public var label: String? {
        get {
            base?.label
        }
        set {
            base?.label = newValue
        }
    }
}
50 |
extension TypedMTLBuffer: Equatable {
    /// Two typed buffers are equal when they wrap the very same `MTLBuffer` instance
    /// (or both wrap none) and report the same element count. Contents are not compared.
    public static func == (lhs: Self, rhs: Self) -> Bool {
        guard lhs.base === rhs.base else {
            return false
        }
        return lhs.count == rhs.count
    }
}
56 |
extension TypedMTLBuffer: CustomDebugStringConvertible {
    /// Summary of the element type, count, capacity and the underlying buffer's label and length.
    public var debugDescription: String {
        // `Element.self` prints the element type itself; `type(of: Element.self)` printed its metatype.
        "TypedMTLBuffer<\(Element.self)>(count: \(count), capacity: \(capacity), base.label: \(String(describing: base?.label)), base.length: \(String(describing: base?.length)))"
    }
}
62 |
63 | // MARK: -
64 |
public extension TypedMTLBuffer {
    /// The underlying Metal buffer, exposed without any type safety.
    var unsafeBase: MTLBuffer? {
        base
    }

    /// Provides temporary access to the underlying `MTLBuffer`.
    ///
    /// - Parameter block: A closure that takes the `MTLBuffer` (or `nil` if there is none) and returns a value.
    /// - Returns: The value returned by the `block`.
    /// - Throws: Rethrows any error thrown by the `block`.
    func withUnsafeMTLBuffer<R>(_ block: (MTLBuffer?) throws -> R) rethrows -> R {
        try block(base)
    }

    /// Provides unsafe read-only access to the buffer's contents.
    ///
    /// The pointer covers `count` elements; it is empty when there is no underlying buffer.
    ///
    /// - Parameter block: A closure that takes an `UnsafeBufferPointer<Element>` and returns a value.
    /// - Returns: The value returned by the `block`.
    /// - Throws: Rethrows any error thrown by the `block`.
    func withUnsafeBufferPointer<R>(_ block: (UnsafeBufferPointer<Element>) throws -> R) rethrows -> R {
        if let base {
            let contents = base.contents()
            let pointer = contents.bindMemory(to: Element.self, capacity: count)
            let buffer = UnsafeBufferPointer(start: pointer, count: count)
            return try block(buffer)
        }
        else {
            return try block(UnsafeBufferPointer<Element>(start: nil, count: 0))
        }
    }

    /// Provides unsafe mutable access to the buffer's contents.
    ///
    /// The pointer covers `count` elements; it is empty when there is no underlying buffer.
    ///
    /// - Parameter block: A closure that takes an `UnsafeMutableBufferPointer<Element>` and returns a value.
    /// - Returns: The value returned by the `block`.
    /// - Throws: Rethrows any error thrown by the `block`.
    func withUnsafeMutableBufferPointer<R>(_ block: (UnsafeMutableBufferPointer<Element>) throws -> R) rethrows -> R {
        if let base {
            let contents = base.contents()
            let pointer = contents.bindMemory(to: Element.self, capacity: count)
            let buffer = UnsafeMutableBufferPointer(start: pointer, count: count)
            return try block(buffer)
        }
        else {
            return try block(UnsafeMutableBufferPointer<Element>(start: nil, count: 0))
        }
    }

    /// Sets a label for the underlying Metal buffer.
    ///
    /// - Parameter label: The label to set.
    /// - Returns: The `TypedMTLBuffer` instance with the updated label.
    func labelled(_ label: String) -> Self {
        base?.label = label
        return self
    }
}
122 |
public extension TypedMTLBuffer {
    /// Copies `elements` into the buffer directly after the current contents and grows `count`.
    ///
    /// - Parameter elements: The elements to append. Appending an empty array is a no-op.
    /// - Throws: `BaseError.overflow` if the elements do not fit within `capacity`
    ///   or there is no underlying buffer to write into.
    mutating func append(contentsOf elements: [Element]) throws {
        // Nothing to copy; this must also succeed for a buffer with no backing storage.
        if elements.isEmpty {
            return
        }
        if count + elements.count > capacity {
            throw BaseError.error(.overflow)
        }
        guard let base else {
            throw BaseError.error(.overflow)
        }

        elements.withUnsafeBytes { bytes in
            let destination = base.contents().advanced(by: count * MemoryLayout<Element>.stride)
            UnsafeMutableRawBufferPointer(start: destination, count: bytes.count).copyMemory(from: bytes)
        }
        count += elements.count
    }
}
139 |
140 | // MARK: -
141 |
public extension MTLDevice {
    /// Creates a `TypedMTLBuffer` from the given data.
    ///
    /// - Parameters:
    ///   - data: The data to copy into the new buffer.
    ///   - options: Options for the new buffer. Default is an empty option set.
    /// - Returns: A new `TypedMTLBuffer` containing the specified data. Empty data yields an
    ///   empty typed buffer with no underlying Metal buffer, matching the array overload.
    /// - Throws: `BaseError.illegalValue` if the data size is not a multiple of the stride of `Element`.
    ///           `BaseError.resourceCreationFailure` if the buffer creation fails.
    func makeTypedBuffer<Element>(data: Data, options: MTLResourceOptions = []) throws -> TypedMTLBuffer<Element> {
        if !data.count.isMultiple(of: MemoryLayout<Element>.stride) {
            throw BaseError.error(.illegalValue)
        }
        if data.isEmpty {
            return TypedMTLBuffer()
        }
        let count = data.count / MemoryLayout<Element>.stride
        return try data.withUnsafeBytes { buffer in
            guard let baseAddress = buffer.baseAddress else {
                throw BaseError.error(.resourceCreationFailure)
            }
            guard let buffer = makeBuffer(bytes: baseAddress, length: buffer.count, options: options) else {
                throw BaseError.error(.resourceCreationFailure)
            }
            return TypedMTLBuffer(mtlBuffer: buffer, count: count)
        }
    }

    /// Creates a `TypedMTLBuffer` from the given array.
    ///
    /// - Parameters:
    ///   - data: The array to copy into the new buffer.
    ///   - options: Options for the new buffer. Default is an empty option set.
    /// - Returns: A new `TypedMTLBuffer` containing the specified data. An empty array yields
    ///   an empty typed buffer with no underlying Metal buffer.
    /// - Throws: `BaseError.resourceCreationFailure` if the buffer creation fails.
    func makeTypedBuffer<Element>(data: [Element], options: MTLResourceOptions = []) throws -> TypedMTLBuffer<Element> {
        if data.isEmpty {
            return TypedMTLBuffer()
        }
        else {
            return try data.withUnsafeBytes { buffer in
                guard let baseAddress = buffer.baseAddress else {
                    throw BaseError.error(.resourceCreationFailure)
                }
                guard let buffer = makeBuffer(bytes: baseAddress, length: buffer.count, options: options) else {
                    throw BaseError.error(.resourceCreationFailure)
                }
                return TypedMTLBuffer(mtlBuffer: buffer, count: data.count)
            }
        }
    }

    /// Creates a `TypedMTLBuffer` with room for `capacity` elements and a `count` of zero.
    ///
    /// - Parameters:
    ///   - element: The element type of the buffer.
    ///   - capacity: Number of elements to reserve. Zero yields a typed buffer with no underlying Metal buffer.
    ///   - options: Options for the new buffer. Default is an empty option set.
    /// - Throws: `BaseError.resourceCreationFailure` if the buffer creation fails.
    func makeTypedBuffer<Element>(element: Element.Type, capacity: Int, options: MTLResourceOptions = []) throws -> TypedMTLBuffer<Element> {
        if capacity == 0 {
            return TypedMTLBuffer()
        }
        else {
            guard let buffer = makeBuffer(length: MemoryLayout<Element>.stride * capacity, options: options) else {
                throw BaseError.error(.resourceCreationFailure)
            }
            // TODO: FIXME - remove this
            // Fills the new buffer with 0xFF bytes.
            memset(buffer.contents(), 0xFF, buffer.length)

            return TypedMTLBuffer(mtlBuffer: buffer, count: 0)
        }
    }

    /// Creates a `TypedMTLBuffer` with room for `capacity` elements, inferring the element type from context.
    ///
    /// - Throws: `BaseError.resourceCreationFailure` if the buffer creation fails.
    func makeTypedBuffer<Element>(capacity: Int, options: MTLResourceOptions = []) throws -> TypedMTLBuffer<Element> {
        try makeTypedBuffer(element: Element.self, capacity: capacity, options: options)
    }
}
210 |
211 | // MARK: -
212 |
public extension MTLRenderCommandEncoder {
    /// Sets a vertex buffer for the render command encoder.
    ///
    /// - Parameters:
    ///   - buffer: The `TypedMTLBuffer` to set as the vertex buffer.
    ///   - offset: The offset in elements from the start of the buffer. This value is multiplied by `MemoryLayout<T>.stride` to calculate the byte offset.
    ///   - index: The index into the buffer argument table.
    func setVertexBuffer<T>(_ buffer: TypedMTLBuffer<T>, offset: Int, index: Int) {
        buffer.withUnsafeMTLBuffer {
            setVertexBuffer($0, offset: offset * MemoryLayout<T>.stride, index: index)
        }
    }

    /// Sets a fragment buffer for the render command encoder.
    ///
    /// - Parameters:
    ///   - buffer: The `TypedMTLBuffer` to set as the fragment buffer.
    ///   - offset: The offset in elements from the start of the buffer. This value is multiplied by `MemoryLayout<T>.stride` to calculate the byte offset.
    ///   - index: The index into the buffer argument table.
    func setFragmentBuffer<T>(_ buffer: TypedMTLBuffer<T>, offset: Int, index: Int) {
        buffer.withUnsafeMTLBuffer {
            setFragmentBuffer($0, offset: offset * MemoryLayout<T>.stride, index: index)
        }
    }
}
238 |
public extension MTLComputeCommandEncoder {
    /// Sets a buffer for the compute command encoder.
    ///
    /// - Parameters:
    ///   - buffer: The `TypedMTLBuffer` to set.
    ///   - offset: The offset in elements from the start of the buffer. This value is multiplied by `MemoryLayout<T>.stride` to calculate the byte offset.
    ///   - index: The index into the buffer argument table.
    func setBuffer<T>(_ buffer: TypedMTLBuffer<T>, offset: Int, index: Int) {
        buffer.withUnsafeMTLBuffer {
            setBuffer($0, offset: offset * MemoryLayout<T>.stride, index: index)
        }
    }
}
252 |
--------------------------------------------------------------------------------
/Tests/.swiftlint.yml:
--------------------------------------------------------------------------------
1 | disabled_rules:
2 | - force_try
3 | - force_cast
4 | - force_unwrapping
5 | - function_body_length
6 | - fatal_error_message
7 | - line_length
8 | - identifier_name
9 | - explicit_top_level_acl
10 |
--------------------------------------------------------------------------------
/Tests/ComputeTests/ComputeTests.swift:
--------------------------------------------------------------------------------
1 | @testable import Compute
2 | import Metal
3 | import Testing
4 |
5 | struct ComputeTests {
6 | let device: MTLDevice
7 | let compute: Compute
8 |
/// Creates the shared device and `Compute` instance used by every test.
///
/// Fails the test (rather than crashing the process) when no Metal device is available.
init() throws {
    let device = try #require(MTLCreateSystemDefaultDevice())
    self.device = device
    self.compute = try Compute(device: device)
}
14 |
/// The `Compute` instance must hold on to the device it was created with.
@Test
func computeInitialization() throws {
    // `compute` is non-optional, so comparing it to nil was always true; check the device identity only.
    #expect(compute.device === device)
}
20 |
/// A function looked up on a source-backed library reports the name it was looked up by.
@Test
func shaderLibraryCreation() throws {
    let source = """
    #include <metal_stdlib>
    using namespace metal;
    kernel void add(device int* a [[buffer(0)]], device int* b [[buffer(1)]], device int* result [[buffer(2)]], uint id [[thread_position_in_grid]]) {
    result[id] = a[id] + b[id];
    }
    """
    let library = ShaderLibrary.source(source)
    let function = library.add
    #expect(function.name == "add")
}
34 |
/// Building a pipeline from a valid kernel must not throw.
@Test
func pipelineCreation() throws {
    let source = """
    #include <metal_stdlib>
    using namespace metal;
    kernel void add(device int* a [[buffer(0)]], device int* b [[buffer(1)]], device int* result [[buffer(2)]], uint id [[thread_position_in_grid]]) {
    result[id] = a[id] + b[id];
    }
    """
    let library = ShaderLibrary.source(source)
    let function = library.add

    // A thrown error fails the test; the returned pipeline is used non-optionally elsewhere
    // in this suite, so a `!= nil` expectation would be a tautology.
    _ = try compute.makePipeline(function: function)
}
51 |
52 | @Test
53 | func simpleAddition() throws {
54 | let source = """
55 | #include
56 | using namespace metal;
57 | kernel void add(device int* a [[buffer(0)]], device int* b [[buffer(1)]], device int* result [[buffer(2)]], uint id [[thread_position_in_grid]]) {
58 | result[id] = a[id] + b[id];
59 | }
60 | """
61 | let library = ShaderLibrary.source(source)
62 | let function = library.add
63 |
64 | let count = 1_000
65 | let a = [Int32](repeating: 1, count: count)
66 | let b = [Int32](repeating: 2, count: count)
67 |
68 | let bufferA = device.makeBuffer(bytes: a, length: MemoryLayout.stride * count, options: [])!
69 | let bufferB = device.makeBuffer(bytes: b, length: MemoryLayout.stride * count, options: [])!
70 | let bufferResult = device.makeBuffer(length: MemoryLayout.stride * count, options: [])!
71 |
72 | let pipeline = try compute.makePipeline(
73 | function: function,
74 | arguments: [
75 | "a": .buffer(bufferA),
76 | "b": .buffer(bufferB),
77 | "result": .buffer(bufferResult)
78 | ]
79 | )
80 |
81 | try compute.run(pipeline: pipeline, width: count)
82 |
83 | bufferResult.contents().withMemoryRebound(to: Int32.self, capacity: count) { result in
84 | for i in 0..