├── .github └── workflows │ └── swift.yml ├── .gitignore ├── .justfile ├── .pre-commit-config.yaml ├── .spi.yml ├── .swiftlint.yml ├── .swiftpm ├── configuration │ └── Package.resolved └── xcode │ ├── package.xcworkspace │ └── contents.xcworkspacedata │ └── xcshareddata │ └── xcschemes │ ├── Compute-Package.xcscheme │ ├── Compute.xcscheme │ └── Examples.xcscheme ├── .vscode └── launch.json ├── Documentation ├── Screenshot 2024-08-04 at 09.57.19.png └── compute-logo.svg ├── LICENSE ├── Package.resolved ├── Package.swift ├── README.md ├── Sources ├── Compute │ ├── Compute+Arguments.swift │ ├── Compute+Pipeline.swift │ ├── Compute+Task.swift │ ├── Compute.swift │ ├── ShaderFunction.swift │ └── Support.swift ├── Examples │ ├── .swiftlint.yml │ ├── Bundle.txt │ ├── Examples.swift │ ├── Examples │ │ ├── BareMetal.swift │ │ ├── BitonicSort │ │ │ ├── BitonicSort.metal │ │ │ └── BitonicSort.swift │ │ ├── Broken │ │ │ ├── CountingSort │ │ │ │ ├── CountingSort.metal │ │ │ │ └── CountingSort.swift │ │ │ └── RadixSort │ │ │ │ ├── RadixSort.metal │ │ │ │ ├── RadixSort.swift │ │ │ │ └── RadixSortCPU.swift │ │ ├── BufferFill.swift │ │ ├── Checkerboard.swift │ │ ├── CounterDemo.swift │ │ ├── GameOfLife │ │ │ ├── GameOfLife.metal │ │ │ └── GameOfLife.swift │ │ ├── HelloWorldDemo.swift │ │ ├── Histogram.swift │ │ ├── ImageInvert.swift │ │ ├── IsSorted.swift │ │ ├── MaxParallel.swift │ │ ├── MaxValue.swift │ │ ├── MemcopyDemo.swift │ │ ├── RandomFill.swift │ │ ├── SIMDReduce.swift │ │ └── ThreadgroupLog.swift │ ├── Resources │ │ └── Media.xcassets │ │ │ ├── Contents.json │ │ │ ├── baboon-acorn-inverted.imageset │ │ │ ├── Contents.json │ │ │ └── baboon-acorn-inverted.png │ │ │ └── baboon.imageset │ │ │ ├── Contents.json │ │ │ └── baboon.png │ └── Support.swift └── MetalSupportLite │ ├── BaseSupport.swift │ ├── MTLBuffer+Extensions.swift │ ├── MetalBasicExtensions.swift │ ├── MetalSupportLite.swift │ └── TypedMTLBuffer.swift └── Tests ├── .swiftlint.yml └── ComputeTests └── 
ComputeTests.swift /.github/workflows/swift.yml: -------------------------------------------------------------------------------- 1 | name: Swift 2 | env: 3 | XCODE_VERSION: "latest-stable" 4 | on: 5 | push: 6 | pull_request: 7 | jobs: 8 | swift-build: 9 | runs-on: macos-15 # macos-latest 10 | steps: 11 | - uses: maxim-lobanov/setup-xcode@v1 12 | with: 13 | xcode-version: ${{ env.XCODE_VERSION }} 14 | - uses: actions/checkout@v3 15 | - run: swift build -v 16 | - run: swift test -v 17 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/swift,swiftpm,macos 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=swift,swiftpm,macos 3 | 4 | ### macOS ### 5 | # General 6 | .DS_Store 7 | .AppleDouble 8 | .LSOverride 9 | 10 | # Icon must end with two \r 11 | Icon 12 | 13 | # Thumbnails 14 | ._* 15 | 16 | # Files that might appear in the root of a volume 17 | .DocumentRevisions-V100 18 | .fseventsd 19 | .Spotlight-V100 20 | .TemporaryItems 21 | .Trashes 22 | .VolumeIcon.icns 23 | .com.apple.timemachine.donotpresent 24 | 25 | # Directories potentially created on remote AFP share 26 | .AppleDB 27 | .AppleDesktop 28 | Network Trash Folder 29 | Temporary Items 30 | .apdisk 31 | 32 | ### macOS Patch ### 33 | # iCloud generated files 34 | *.icloud 35 | 36 | ### Swift ### 37 | # Xcode 38 | # 39 | # gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore 40 | 41 | ## User settings 42 | xcuserdata/ 43 | 44 | ## compatibility with Xcode 8 and earlier (ignoring not required starting Xcode 9) 45 | *.xcscmblueprint 46 | *.xccheckout 47 | 48 | ## compatibility with Xcode 3 and earlier (ignoring not required starting Xcode 4) 49 | build/ 50 | DerivedData/ 51 | *.moved-aside 52 | *.pbxuser 53 | !default.pbxuser 54 | *.mode1v3 55 | 
!default.mode1v3 56 | *.mode2v3 57 | !default.mode2v3 58 | *.perspectivev3 59 | !default.perspectivev3 60 | 61 | ## Obj-C/Swift specific 62 | *.hmap 63 | 64 | ## App packaging 65 | *.ipa 66 | *.dSYM.zip 67 | *.dSYM 68 | 69 | ## Playgrounds 70 | timeline.xctimeline 71 | playground.xcworkspace 72 | 73 | # Swift Package Manager 74 | # Add this line if you want to avoid checking in source code from Swift Package Manager dependencies. 75 | # Packages/ 76 | # Package.pins 77 | # Package.resolved 78 | # *.xcodeproj 79 | # Xcode automatically generates this directory with a .xcworkspacedata file and xcuserdata 80 | # hence it is not needed unless you have added a package configuration file to your project 81 | # .swiftpm 82 | 83 | .build/ 84 | 85 | # CocoaPods 86 | # We recommend against adding the Pods directory to your .gitignore. However 87 | # you should judge for yourself, the pros and cons are mentioned at: 88 | # https://guides.cocoapods.org/using/using-cocoapods.html#should-i-check-the-pods-directory-into-source-control 89 | # Pods/ 90 | # Add this line if you want to avoid checking in source code from the Xcode workspace 91 | # *.xcworkspace 92 | 93 | # Carthage 94 | # Add this line if you want to avoid checking in source code from Carthage dependencies. 95 | # Carthage/Checkouts 96 | 97 | Carthage/Build/ 98 | 99 | # Accio dependency management 100 | Dependencies/ 101 | .accio/ 102 | 103 | # fastlane 104 | # It is recommended to not store the screenshots in the git repo. 105 | # Instead, use fastlane to re-generate the screenshots whenever they are needed. 
106 | # For more information about the recommended setup visit: 107 | # https://docs.fastlane.tools/best-practices/source-control/#source-control 108 | 109 | fastlane/report.xml 110 | fastlane/Preview.html 111 | fastlane/screenshots/**/*.png 112 | fastlane/test_output 113 | 114 | # Code Injection 115 | # After new code Injection tools there's a generated folder /iOSInjectionProject 116 | # https://github.com/johnno1962/injectionforxcode 117 | 118 | iOSInjectionProject/ 119 | 120 | ### SwiftPM ### 121 | Packages 122 | xcuserdata 123 | *.xcodeproj 124 | 125 | 126 | # End of https://www.toptal.com/developers/gitignore/api/swift,swiftpm,macos 127 | -------------------------------------------------------------------------------- /.justfile: -------------------------------------------------------------------------------- 1 | build-docs: 2 | xcrun xcodebuild docbuild -scheme Compute -derivedDataPath /tmp/compute-docbuild -destination platform=macOS,arch=arm64 3 | cp -r /tmp/compute-docbuild/Build/Products/Debug/Compute.doccarchive ~/Desktop 4 | 5 | xcrun docc process-archive transform-for-static-hosting ~/Desktop/Compute.doccarchive --hosting-base-path / --output-path ~/Desktop/Compute-HTML/ 6 | 7 | concurrency-check: 8 | swift clean 9 | swift build -Xswiftc -strict-concurrency=complete 10 | 11 | swift-six-check: 12 | swift clean 13 | SWIFT_VERSION=6 swift build --verbose 14 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | -------------------------------------------------------------------------------- /.spi.yml: -------------------------------------------------------------------------------- 1 | version: 1 2 | builder: 3 | configs: 4 | - 
documentation_targets: [Compute] 5 | -------------------------------------------------------------------------------- /.swiftlint.yml: -------------------------------------------------------------------------------- 1 | only_rules: 2 | - accessibility_label_for_image 3 | - accessibility_trait_for_button 4 | - anonymous_argument_in_multiline_closure 5 | # - anyobject_protocol 6 | - array_init 7 | - attributes 8 | - balanced_xctest_lifecycle 9 | - blanket_disable_command 10 | - block_based_kvo 11 | - capture_variable 12 | - class_delegate_protocol 13 | - closing_brace 14 | - closure_body_length 15 | - closure_end_indentation 16 | - closure_parameter_position 17 | - closure_spacing 18 | - collection_alignment 19 | - colon 20 | - comma 21 | - comma_inheritance 22 | - comment_spacing 23 | - compiler_protocol_init 24 | - computed_accessors_order 25 | - conditional_returns_on_newline 26 | - contains_over_filter_count 27 | - contains_over_filter_is_empty 28 | - contains_over_first_not_nil 29 | - contains_over_range_nil_comparison 30 | - control_statement 31 | - convenience_type 32 | - custom_rules 33 | - cyclomatic_complexity 34 | - deployment_target 35 | - direct_return 36 | - discarded_notification_center_observer 37 | - discouraged_assert 38 | - discouraged_direct_init 39 | - discouraged_none_name 40 | - discouraged_object_literal 41 | - discouraged_optional_boolean 42 | # - discouraged_optional_collection # 1 violations 43 | - duplicate_conditions 44 | - duplicate_enum_cases 45 | - duplicate_imports 46 | - duplicated_key_in_dictionary_literal 47 | - dynamic_inline 48 | - empty_collection_literal 49 | - empty_count 50 | - empty_enum_arguments 51 | - empty_parameters 52 | - empty_parentheses_with_trailing_closure 53 | - empty_string 54 | - empty_xctest_method 55 | - enum_case_associated_values_count 56 | - expiring_todo 57 | # - explicit_acl # 24 violations 58 | - explicit_enum_raw_value 59 | # - explicit_init 60 | - explicit_self 61 | - explicit_top_level_acl 62 | # - 
explicit_type_interface 63 | - extension_access_modifier 64 | - fallthrough 65 | - fatal_error_message 66 | - file_header 67 | - file_length 68 | # - file_name # 1 violations 69 | - file_name_no_space 70 | # - file_types_order 71 | - final_test_case 72 | - first_where 73 | - flatmap_over_map_reduce 74 | - for_where 75 | - force_cast 76 | - force_try 77 | - force_unwrapping 78 | - function_body_length 79 | # - function_default_parameter_at_end 80 | - function_parameter_count 81 | - generic_type_name 82 | - ibinspectable_in_extension 83 | - identical_operands 84 | - identifier_name 85 | - implicit_getter 86 | - implicit_return 87 | - implicitly_unwrapped_optional 88 | - inclusive_language 89 | - indentation_width 90 | # - inert_defer 91 | - invalid_swiftlint_command 92 | - is_disjoint 93 | - joined_default_parameter 94 | - large_tuple 95 | - last_where 96 | - leading_whitespace 97 | - legacy_cggeometry_functions 98 | - legacy_constant 99 | - legacy_constructor 100 | - legacy_hashing 101 | - legacy_multiple 102 | - legacy_nsgeometry_functions 103 | - legacy_objc_type 104 | - legacy_random 105 | - let_var_whitespace 106 | # - line_length # 21 violations 107 | - literal_expression_end_indentation 108 | - local_doc_comment 109 | - lower_acl_than_parent 110 | - mark 111 | - missing_docs 112 | - modifier_order 113 | - multiline_arguments 114 | - multiline_arguments_brackets 115 | - multiline_function_chains 116 | - multiline_literal_brackets 117 | - multiline_parameters 118 | - multiline_parameters_brackets 119 | - multiple_closures_with_trailing_closure 120 | - nesting 121 | - nimble_operator 122 | # - no_extension_access_modifier 123 | - no_fallthrough_only 124 | - no_grouping_extension 125 | # - no_magic_numbers # 3 violations 126 | - no_space_in_method_call 127 | - non_optional_string_data_conversion 128 | - non_overridable_class_declaration 129 | - notification_center_detachment 130 | - ns_number_init_as_function_reference 131 | - nslocalizedstring_key 132 | - 
nslocalizedstring_require_bundle 133 | - nsobject_prefer_isequal 134 | - number_separator 135 | - object_literal 136 | # - one_declaration_per_file # 2 violations 137 | - opening_brace 138 | - operator_usage_whitespace 139 | - operator_whitespace 140 | - optional_enum_case_matching 141 | - orphaned_doc_comment 142 | - overridden_super_call 143 | - override_in_extension 144 | - pattern_matching_keywords 145 | - period_spacing 146 | - prefer_nimble 147 | - prefer_self_in_static_references 148 | - prefer_self_type_over_type_of_self 149 | - prefer_zero_over_explicit_init 150 | # - prefixed_toplevel_constant # 1 violations 151 | - private_action 152 | - private_outlet 153 | - private_over_fileprivate 154 | - private_subject 155 | - private_swiftui_state 156 | - private_unit_test 157 | - prohibited_interface_builder 158 | - prohibited_super_call 159 | - protocol_property_accessors_order 160 | - quick_discouraged_call 161 | - quick_discouraged_focused_test 162 | - quick_discouraged_pending_test 163 | - raw_value_for_camel_cased_codable_enum 164 | - reduce_boolean 165 | - reduce_into 166 | - redundant_discardable_let 167 | - redundant_nil_coalescing 168 | - redundant_objc_attribute 169 | - redundant_optional_initialization 170 | - redundant_self_in_closure 171 | - redundant_set_access_control 172 | - redundant_string_enum_value 173 | - redundant_type_annotation 174 | - redundant_void_return 175 | # - required_deinit 176 | - required_enum_case 177 | - return_arrow_whitespace 178 | - return_value_from_void_function 179 | - self_binding 180 | - self_in_property_initialization 181 | - shorthand_argument 182 | - shorthand_operator 183 | - shorthand_optional_binding 184 | - single_test_class 185 | - sorted_enum_cases 186 | - sorted_first_last 187 | - sorted_imports 188 | - statement_position 189 | - static_operator 190 | - static_over_final_class 191 | - strict_fileprivate 192 | - strong_iboutlet 193 | - superfluous_disable_command 194 | - superfluous_else 195 | - 
switch_case_alignment 196 | - switch_case_on_newline 197 | - syntactic_sugar 198 | - test_case_accessibility 199 | - todo 200 | - toggle_bool 201 | - trailing_closure 202 | - trailing_comma 203 | - trailing_newline 204 | - trailing_semicolon 205 | - trailing_whitespace 206 | - type_body_length 207 | # - type_contents_order # 3 violations 208 | - type_name 209 | - typesafe_array_init 210 | - unavailable_condition 211 | - unavailable_function 212 | - unhandled_throwing_task 213 | - unneeded_break_in_switch 214 | - unneeded_override 215 | - unneeded_parentheses_in_closure_argument 216 | - unneeded_synthesized_initializer 217 | - unowned_variable_capture 218 | - untyped_error_in_catch 219 | # - unused_capture_list 220 | - unused_closure_parameter 221 | - unused_control_flow_label 222 | - unused_declaration 223 | - unused_enumerated 224 | - unused_import 225 | - unused_optional_binding 226 | - unused_setter_value 227 | - valid_ibinspectable 228 | - vertical_parameter_alignment 229 | - vertical_parameter_alignment_on_call 230 | - vertical_whitespace 231 | - vertical_whitespace_between_cases 232 | - vertical_whitespace_closing_braces 233 | - vertical_whitespace_opening_braces 234 | - void_function_in_ternary 235 | - void_return 236 | - weak_delegate 237 | - xct_specific_matcher 238 | - xctfail_message 239 | - yoda_condition 240 | 241 | excluded: 242 | - .build 243 | -------------------------------------------------------------------------------- /.swiftpm/configuration/Package.resolved: -------------------------------------------------------------------------------- 1 | { 2 | "originHash" : "ec7dcd2fd869101534d790fb80b9a276d9dea53d259be4371c2e3fd99126356f", 3 | "pins" : [ 4 | { 5 | "identity" : "everything", 6 | "kind" : "remoteSourceControl", 7 | "location" : "https://github.com/schwa/Everything", 8 | "state" : { 9 | "branch" : "jwight/swift-6", 10 | "revision" : "e24a8c69a41cfb9abdae4e91a4b7f056f2e36bbb" 11 | } 12 | }, 13 | { 14 | "identity" : "metalcompilerplugin", 15 
| "kind" : "remoteSourceControl", 16 | "location" : "https://github.com/schwa/MetalCompilerPlugin", 17 | "state" : { 18 | "branch" : "jwight/logging", 19 | "revision" : "86239f9d8a6610e6ab2c335a3ef6e96fdd242e48" 20 | } 21 | }, 22 | { 23 | "identity" : "swift-algorithms", 24 | "kind" : "remoteSourceControl", 25 | "location" : "https://github.com/apple/swift-algorithms", 26 | "state" : { 27 | "revision" : "f6919dfc309e7f1b56224378b11e28bab5bccc42", 28 | "version" : "1.2.0" 29 | } 30 | }, 31 | { 32 | "identity" : "swift-argument-parser", 33 | "kind" : "remoteSourceControl", 34 | "location" : "https://github.com/apple/swift-argument-parser", 35 | "state" : { 36 | "revision" : "41982a3656a71c768319979febd796c6fd111d5c", 37 | "version" : "1.5.0" 38 | } 39 | }, 40 | { 41 | "identity" : "swift-async-algorithms", 42 | "kind" : "remoteSourceControl", 43 | "location" : "https://github.com/apple/swift-async-algorithms", 44 | "state" : { 45 | "revision" : "9cfed92b026c524674ed869a4ff2dcfdeedf8a2a", 46 | "version" : "0.1.0" 47 | } 48 | }, 49 | { 50 | "identity" : "swift-collections", 51 | "kind" : "remoteSourceControl", 52 | "location" : "https://github.com/apple/swift-collections.git", 53 | "state" : { 54 | "revision" : "9bf03ff58ce34478e66aaee630e491823326fd06", 55 | "version" : "1.1.3" 56 | } 57 | }, 58 | { 59 | "identity" : "swift-numerics", 60 | "kind" : "remoteSourceControl", 61 | "location" : "https://github.com/apple/swift-numerics.git", 62 | "state" : { 63 | "revision" : "0a5bc04095a675662cf24757cc0640aa2204253b", 64 | "version" : "1.0.2" 65 | } 66 | }, 67 | { 68 | "identity" : "swift-syntax", 69 | "kind" : "remoteSourceControl", 70 | "location" : "https://github.com/apple/swift-syntax.git", 71 | "state" : { 72 | "revision" : "515f79b522918f83483068d99c68daeb5116342d", 73 | "version" : "600.0.0-prerelease-2024-08-20" 74 | } 75 | }, 76 | { 77 | "identity" : "swiftfields", 78 | "kind" : "remoteSourceControl", 79 | "location" : "https://github.com/schwa/swiftfields", 80 | 
"state" : { 81 | "revision" : "8a9715b88509c93557833018f335397153f194bb", 82 | "version" : "0.1.3" 83 | } 84 | }, 85 | { 86 | "identity" : "swiftformats", 87 | "kind" : "remoteSourceControl", 88 | "location" : "https://github.com/schwa/swiftformats", 89 | "state" : { 90 | "revision" : "6b51ccec7fccf2f1d4f28f03def6c6b97d17c05b", 91 | "version" : "0.3.6" 92 | } 93 | }, 94 | { 95 | "identity" : "swiftgltf", 96 | "kind" : "remoteSourceControl", 97 | "location" : "https://github.com/schwa/SwiftGLTF", 98 | "state" : { 99 | "branch" : "main", 100 | "revision" : "4160ed7e89b7cfcccf5ff7a0cbfd5b055fada771" 101 | } 102 | }, 103 | { 104 | "identity" : "swiftgraphics", 105 | "kind" : "remoteSourceControl", 106 | "location" : "https://github.com/schwa/SwiftGraphics", 107 | "state" : { 108 | "branch" : "jwight/develop", 109 | "revision" : "ac3d19558f0ecb8249fd83d97a9cc87f37be24ee" 110 | } 111 | }, 112 | { 113 | "identity" : "wrappinghstack", 114 | "kind" : "remoteSourceControl", 115 | "location" : "https://github.com/ksemianov/WrappingHStack", 116 | "state" : { 117 | "revision" : "3300f68b6bf5f8a75ee7ca8a40f136a558053d10", 118 | "version" : "0.2.0" 119 | } 120 | } 121 | ], 122 | "version" : 3 123 | } 124 | -------------------------------------------------------------------------------- /.swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.swiftpm/xcode/xcshareddata/xcschemes/Compute-Package.xcscheme: -------------------------------------------------------------------------------- 1 | 2 | 5 | 9 | 10 | 16 | 22 | 23 | 24 | 30 | 36 | 37 | 38 | 44 | 50 | 51 | 52 | 53 | 54 | 60 | 61 | 63 | 69 | 70 | 71 | 72 | 73 | 83 | 84 | 90 | 91 | 92 | 93 | 99 | 100 | 106 | 107 | 108 | 109 | 111 | 112 | 115 | 116 | 117 | -------------------------------------------------------------------------------- 
/.swiftpm/xcode/xcshareddata/xcschemes/Compute.xcscheme: -------------------------------------------------------------------------------- 1 | 2 | 5 | 9 | 10 | 16 | 22 | 23 | 24 | 25 | 26 | 32 | 33 | 43 | 44 | 50 | 51 | 57 | 58 | 59 | 60 | 62 | 63 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /.swiftpm/xcode/xcshareddata/xcschemes/Examples.xcscheme: -------------------------------------------------------------------------------- 1 | 2 | 5 | 9 | 10 | 16 | 22 | 23 | 24 | 25 | 26 | 32 | 33 | 43 | 45 | 51 | 52 | 53 | 54 | 57 | 58 | 61 | 62 | 65 | 66 | 67 | 68 | 74 | 76 | 82 | 83 | 84 | 85 | 87 | 88 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "configurations": [ 3 | { 4 | "type": "lldb", 5 | "request": "launch", 6 | "sourceLanguages": [ 7 | "swift" 8 | ], 9 | "args": [], 10 | "cwd": "${workspaceFolder:Compute}", 11 | "name": "Debug Examples", 12 | "program": "${workspaceFolder:Compute}/.build/debug/Examples", 13 | "preLaunchTask": "swift: Build Debug Examples" 14 | }, 15 | { 16 | "type": "lldb", 17 | "request": "launch", 18 | "sourceLanguages": [ 19 | "swift" 20 | ], 21 | "args": [], 22 | "cwd": "${workspaceFolder:Compute}", 23 | "name": "Release Examples", 24 | "program": "${workspaceFolder:Compute}/.build/release/Examples", 25 | "preLaunchTask": "swift: Build Release Examples" 26 | } 27 | ] 28 | } 29 | -------------------------------------------------------------------------------- /Documentation/Screenshot 2024-08-04 at 09.57.19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/schwa/Compute/b1ba1e1071c912007bdbb59313e175887335fab8/Documentation/Screenshot 2024-08-04 at 09.57.19.png -------------------------------------------------------------------------------- 
/Documentation/compute-logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Jonathan Wight 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Package.resolved: -------------------------------------------------------------------------------- 1 | { 2 | "originHash" : "aec40d709f0bbfd74d34611e9ee223a92bb57eb697224186f914d2d3cd91a04c", 3 | "pins" : [ 4 | { 5 | "identity" : "metalcompilerplugin", 6 | "kind" : "remoteSourceControl", 7 | "location" : "https://github.com/schwa/MetalCompilerPlugin", 8 | "state" : { 9 | "revision" : "1db3b5dcabd648e0316f3a472646e72af24a5bde", 10 | "version" : "0.1.0" 11 | } 12 | } 13 | ], 14 | "version" : 3 15 | } 16 | -------------------------------------------------------------------------------- /Package.swift: -------------------------------------------------------------------------------- 1 | // swift-tools-version: 6.0 2 | 3 | import PackageDescription 4 | 5 | // swiftlint:disable:next explicit_top_level_acl 6 | let package = Package( 7 | name: "Compute", 8 | platforms: [.macOS(.v15), .iOS(.v18)], 9 | products: [ 10 | .library(name: "Compute", targets: ["Compute"]), 11 | ], 12 | dependencies: [ 13 | .package(url: "https://github.com/schwa/MetalCompilerPlugin", from: "0.1.0") 14 | ], 15 | targets: [ 16 | .target(name: "Compute"), 17 | .target(name: "MetalSupportLite"), 18 | .executableTarget( 19 | name: "Examples", 20 | dependencies: [ 21 | "Compute", 22 | "MetalSupportLite", 23 | ], 24 | resources: [ 25 | .copy("Bundle.txt"), 26 | .process("Resources/Media.xcassets") 27 | ], 28 | plugins: [ 29 | .plugin(name: "MetalCompilerPlugin", package: "MetalCompilerPlugin") 30 | ] 31 | ), 32 | .testTarget(name: "ComputeTests", dependencies: ["Compute"]) 33 | ] 34 | ) 35 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![alt text](Documentation/compute-logo.svg) 2 | 3 | # Compute 4 | 5 | This project provides a high-level Swift framework for working with Metal compute 
shaders. It simplifies the process of setting up and executing compute tasks on the GPU using Apple's Metal API. 6 | 7 | ## Inspiration 8 | 9 | This project draws inspiration from Apple's SwiftUI shaders introduced in iOS 17 and macOS Sonoma. SwiftUI shaders demonstrate how GPU operations can be made more accessible to developers, enabling complex visual effects with simple Swift code. Compute aims to provide a similar level of abstraction for Metal compute shaders, making it easier to perform data-parallel computations on the GPU. 10 | 11 | ## Usage Example 12 | 13 | Here's a quick example of how to use the Metal Compute Framework to perform a basic computation: 14 | 15 | ```swift 16 | import Compute 17 | import Metal 18 | 19 | // Example usage that adds two arrays of integers using Compute. 20 | 21 | // Metal shader source code 22 | let source = """ 23 | #include <metal_stdlib> 24 | using namespace metal; 25 | 26 | kernel void add(device int* inA [[buffer(0)]], 27 | device int* inB [[buffer(1)]], 28 | device int* result [[buffer(2)]], 29 | uint id [[thread_position_in_grid]]) { 30 | result[id] = inA[id] + inB[id]; 31 | } 32 | """ 33 | 34 | // Set up the compute environment 35 | let device = MTLCreateSystemDefaultDevice()! 36 | let compute = try Compute(device: device) 37 | 38 | // Create input data 39 | let count = 1000 40 | let inA = [Int32](repeating: 1, count: count) 41 | let inB = [Int32](repeating: 2, count: count) 42 | var result = [Int32](repeating: 0, count: count) 43 | 44 | // Create Metal buffers 45 | let bufferA = device.makeBuffer(bytes: inA, length: MemoryLayout<Int32>.stride * count, options: [])! 46 | let bufferB = device.makeBuffer(bytes: inB, length: MemoryLayout<Int32>.stride * count, options: [])! 47 | let bufferResult = device.makeBuffer(length: MemoryLayout<Int32>.stride * count, options: [])! 48 | 49 | // Create a shader library and function 50 | let library = ShaderLibrary.source(source) 51 | let function = library.add 52 | 53 | // Create a compute pipeline and bind arguments. 
54 | var pipeline = try compute.makePipeline(function: function) 55 | pipeline.arguments.inA = .buffer(bufferA) 56 | pipeline.arguments.inB = .buffer(bufferB) 57 | pipeline.arguments.result = .buffer(bufferResult) 58 | 59 | // Run the compute pipeline 60 | try compute.run(pipeline: pipeline, width: count) 61 | 62 | // Read back the results 63 | let bufferPointer = bufferResult.contents() 64 | let bufferRawBuffer32 = bufferPointer.bindMemory(to: Int32.self, capacity: count) 65 | let bufferBufferPointer = UnsafeBufferPointer(start: bufferRawBuffer32, count: count) 66 | 67 | // Verify the results 68 | for i in 0..<count { 69 | assert(bufferBufferPointer[i] == 3) 70 | } 71 | ``` 72 | 112 | ![Screenshot](<Documentation/Screenshot 2024-08-04 at 09.57.19.png>) 113 | 114 | ## License 115 | 116 | This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details. 117 | 118 | ## Contributing 119 | 120 | Contributions are welcome to the Compute Framework. 121 | 122 | Note: Some of our initial documentation and tests were AI-generated. 123 | 124 | ## Links 125 | 126 | > [Metal Overview - Apple Developer](https://developer.apple.com/metal/) 127 | 128 | Apple's main Metal documentation. 129 | 130 | > - [developer.apple.com/metal/Metal-Shading-Language-Specification.pdf](https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf) 131 | 132 | The Metal Shading Language Specification book. This is the definitive guide to writing shaders in Metal. 133 | 134 | > - [developer.apple.com/metal/Metal-Feature-Set-Tables.pdf](https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf) 135 | 136 | The Metal Feature Set Tables book. This is a reference for which features are available on which devices/Metal versions. 137 | 138 | > - [Metal by Example – High-performance graphics and data-parallel programming for iOS and macOS](https://metalbyexample.com) 139 | 140 | Warren Moore's blog is the single best resource for learning Metal programming. 
141 | 142 | > - [Introduction to Compute Programming in Metal – Metal by Example](https://metalbyexample.com/introduction-to-compute/) 143 | 144 | Warren has some posts on Compute programming in Metal but they're showing their age a bit. Nevertheless, they're a good starting point. 145 | 146 | > - [Shader | Apple Developer Documentation](https://developer.apple.com/documentation/swiftui/shader) 147 | 148 | SwiftUI's Shader was the primary inspiration for this project. 149 | 150 | > - [Calculating Threadgroup and Grid Sizes | Apple Developer Documentation](https://developer.apple.com/documentation/metal/compute_passes/calculating_threadgroup_and_grid_sizes) 151 | > - [Creating Threads and Threadgroups | Apple Developer Documentation](https://developer.apple.com/documentation/metal/compute_passes/creating_threads_and_threadgroups) 152 | 153 | How to calculate threadgroup and grid sizes. This is a critical concept in Metal compute programming. 154 | -------------------------------------------------------------------------------- /Sources/Compute/Compute+Arguments.swift: --------------------------------------------------------------------------------
import Metal
import simd
import SwiftUI

public extension Compute {
    /// A name-keyed collection of the arguments bound to a compute pipeline.
    ///
    /// The type is marked `@dynamicMemberLookup`, so an argument can be read or assigned
    /// with member syntax (`arguments.someName`) rather than by indexing the dictionary.
    @dynamicMemberLookup
    struct Arguments {
        /// Backing storage: one `Argument` per argument name.
        internal var arguments: [String: Argument]

        /// Reads or writes the argument stored under `name`.
        ///
        /// - Parameter name: The name of the argument.
        /// - Returns: The stored argument, or `nil` when nothing is stored under `name`.
        public subscript(dynamicMember name: String) -> Argument? {
            get {
                arguments[name]
            }
            set(argument) {
                // Forwarded straight to the dictionary subscript, so assigning `nil`
                // removes any existing entry for `name`.
                arguments[name] = argument
                // NOTE: It would be nice to use `name` as the label of buffers/textures that have no label yet.
            }
        }
    }

    /// A single value that can be handed to a compute shader.
    ///
    /// An argument is described by two closures: one that binds it on a compute command
    /// encoder, and one that records it in a set of function constant values.
    struct Argument {
        /// Binds the argument on the given encoder. The `Int` is the index to bind at.
        internal var encode: (MTLComputeCommandEncoder, Int) -> Void

        /// Records the argument in the given function constant values. The `String` is the constant's name.
        internal var constantValue: (MTLFunctionConstantValues, String) -> Void

        /// Creates an argument from its two binding closures.
        ///
        /// - Parameters:
        ///   - encode: Called with a compute command encoder and an index to bind the argument.
        ///   - constantValue: Called with a function constant values object and a name to set the argument as a constant.
        public init(encode: @escaping (MTLComputeCommandEncoder, Int) -> Void, constantValue: @escaping (MTLFunctionConstantValues, String) -> Void) {
            self.encode = encode
            self.constantValue = constantValue
        }
    }

}


public extension Compute.Argument {

    /// Creates an integer argument.
    ///
    /// - Parameter value: The integer value.
    /// - Returns: An `Argument` instance representing the integer.
static func int(_ value: some BinaryInteger) -> Self {
    Self(
        encode: { encoder, index in
            // Copy the integer's raw bytes straight into the encoder at the binding index.
            withUnsafeBytes(of: value) { bytes in
                guard let pointer = bytes.baseAddress else {
                    fatalError("Could not get baseAddress.")
                }
                encoder.setBytes(pointer, length: bytes.count, index: index)
            }
        },
        constantValue: { constants, name in
            withUnsafeBytes(of: value) { bytes in
                guard let pointer = bytes.baseAddress else {
                    fatalError("Could not get baseAddress.")
                }
                // Map the concrete Swift integer type onto the matching Metal data type.
                // Only 8-, 16- and 32-bit integers are handled; anything else traps.
                let dataType: MTLDataType
                switch value {
                case is Int8:
                    dataType = .char

                case is UInt8:
                    dataType = .uchar

                case is Int16:
                    dataType = .short

                case is UInt16:
                    dataType = .ushort

                case is Int32:
                    dataType = .int

                case is UInt32:
                    dataType = .uint

                default:
                    fatalError("Unsupported integer type.")
                }
                constants.setConstantValue(pointer, type: dataType, withName: name)
            }
        }
    )
}

/// Creates a float argument.
///
/// - Parameter value: The float value.
/// - Returns: An `Argument` instance representing the float.
static func float(_ value: Float) -> Self {
    Self(
        encode: { encoder, index in
            withUnsafeBytes(of: value) { bytes in
                guard let pointer = bytes.baseAddress else {
                    fatalError("Could not get baseAddress.")
                }
                encoder.setBytes(pointer, length: bytes.count, index: index)
            }
        },
        constantValue: { constants, name in
            withUnsafeBytes(of: value) { bytes in
                guard let pointer = bytes.baseAddress else {
                    fatalError("Could not get baseAddress.")
                }
                constants.setConstantValue(pointer, type: .float, withName: name)
            }
        }
    )
}

/// Creates a boolean argument.
///
/// The value's raw bytes are passed to the encoder, or registered as a `.bool` function constant.
///
/// - Parameter value: The boolean value.
/// - Returns: An `Argument` instance representing the boolean.
static func bool(_ value: Bool) -> Self {
    Self(
        encode: { encoder, index in
            withUnsafeBytes(of: value) { bytes in
                guard let pointer = bytes.baseAddress else {
                    fatalError("Could not get baseAddress.")
                }
                encoder.setBytes(pointer, length: bytes.count, index: index)
            }
        },
        constantValue: { constants, name in
            withUnsafeBytes(of: value) { bytes in
                guard let pointer = bytes.baseAddress else {
                    fatalError("Could not get baseAddress.")
                }
                constants.setConstantValue(pointer, type: .bool, withName: name)
            }
        }
    )
}

/// Creates a buffer argument.
///
/// Buffers can only be bound on an encoder; using one as a function constant traps.
///
/// - Parameters:
///   - buffer: The Metal buffer to be used as an argument.
///   - offset: The offset within the buffer. Defaults to 0.
/// - Returns: An `Argument` instance representing the buffer.
static func buffer(_ buffer: MTLBuffer, offset: Int = 0) -> Self {
    Self(
        encode: { encoder, index in
            encoder.setBuffer(buffer, offset: offset, index: index)
        },
        constantValue: { _, _ in
            fatalError("Unimplemented")
        }
    )
}

/// Creates a texture argument.
///
/// - Parameter texture: The Metal texture to be used as an argument.
/// - Returns: An `Argument` instance representing the texture.
static func texture(_ texture: MTLTexture) -> Self {
    Self { encoder, index in
        encoder.setTexture(texture, index: index)
    }
    constantValue: { _, _ in
        fatalError("Unimplemented")
    }
}

/// Creates an argument from a SIMD vector.
///
/// The vector's raw bytes are passed to the encoder. Use as a function constant is not implemented and traps.
///
/// - Parameter value: The vector value.
/// - Returns: An `Argument` instance representing the vector.
static func vector<V>(_ value: V) -> Self where V: SIMD {
    .init { encoder, index in
        withUnsafeBytes(of: value) { buffer in
            guard let baseAddress = buffer.baseAddress else {
                fatalError("Could not get baseAddress.")
            }
            encoder.setBytes(baseAddress, length: buffer.count, index: index)
        }
    }
    constantValue: { _, _ in
        fatalError("Unimplemented")
    }
}

/// Creates an argument from a four-component float vector.
///
/// - Parameter value: The vector value.
/// - Returns: An `Argument` instance representing the vector.
static func float4(_ value: SIMD4<Float>) -> Self {
    .vector(value)
}

/// Creates an argument from a SwiftUI color.
///
/// The color is resolved, converted to the generic linear RGB color space and passed on as a
/// four-component float vector.
///
/// - Parameter value: The color value.
/// - Returns: An `Argument` instance representing the color as a `SIMD4<Float>`.
/// - Throws: `ComputeError.resourceCreationFailure` if the color space cannot be created, the color
///   cannot be converted, or the converted color has fewer than four components.
static func color(_ value: Color) throws -> Self {
    let cgColor = value.resolve(in: .init()).cgColor
    guard let colorSpace = CGColorSpace(name: CGColorSpace.genericRGBLinear) else {
        throw ComputeError.resourceCreationFailure
    }
    guard let components = cgColor.converted(to: colorSpace, intent: .defaultIntent, options: nil)?.components,
          components.count >= 4 else {
        // Guard the indexing below instead of trapping on an unexpected component count.
        throw ComputeError.resourceCreationFailure
    }

    // TODO: This assumes the pass wants a SIMD4 - we can use reflection to work out what is needed and convert appropriately. This will mean we need to refactor Compute.Argument

    let vector = SIMD4<Float>(Float(components[0]), Float(components[1]), Float(components[2]), Float(components[3]))
    return .vector(vector)
}

/// Creates an argument from a threadgroup memory length
///
/// Threadgroup memory can only be set on an encoder; using it as a function constant traps.
///
/// - Parameter value: The length passed to `setThreadgroupMemoryLength(_:index:)`.
/// - Returns: An `Argument` instance that sets the threadgroup memory length at its binding index.
static func threadgroupMemoryLength(_ value: Int) -> Self {
    Self { encoder, index in
        encoder.setThreadgroupMemoryLength(value, index: index)
    }
    constantValue: { _, _ in
        fatalError("Unimplemented")
    }
}

/// Creates a buffer argument from an array.
///
/// A new Metal buffer holding a copy of the array's elements is created on the encoder's device
/// each time the argument is encoded.
///
/// - Parameters:
///   - array: The elements to copy into the buffer.
///   - offset: The offset within the buffer. Defaults to 0.
///   - label: An optional label assigned to the created buffer.
/// - Returns: An `Argument` instance representing the buffer.
static func buffer<T>(_ array: [T], offset: Int = 0, label: String? = nil) -> Self {
    .init { encoder, index in
        let buffer = array.withUnsafeBufferPointer { pointer -> MTLBuffer? in
            // An empty array has no bytes to copy (and may have no base address), so bind no
            // buffer rather than force-unwrapping or requesting a zero-length allocation.
            guard let baseAddress = pointer.baseAddress, !pointer.isEmpty else {
                return nil
            }
            return encoder.device.makeBuffer(bytes: baseAddress, length: pointer.count * MemoryLayout<T>.stride, options: [])
        }
        if let label {
            buffer?.label = label
        }
        encoder.setBuffer(buffer, offset: offset, index: index)
    }
    constantValue: { _, _ in
        fatalError("Unimplemented")
    }
}

}

--------------------------------------------------------------------------------
/Sources/Compute/Compute+Pipeline.swift:
--------------------------------------------------------------------------------
import Metal

public extension Compute {
    /// Represents a compute pipeline, encapsulating all the information needed to execute a compute operation.
    ///
    /// A `Pipeline` includes the shader function, its arguments, and the associated compute pipeline state.
    /// It provides the necessary context for dispatching compute operations on the GPU.
    struct Pipeline {
        /// The shader function associated with this pipeline.
        public let function: ShaderFunction

        /// A dictionary mapping argument names to their binding indices.
        internal let bindings: [String: Int]

        /// The arguments to be passed to the shader function.
public var arguments: Arguments

/// The compute pipeline state created from the shader function.
public let computePipelineState: MTLComputePipelineState

/// Builds a compute pipeline for `function` on `device`.
///
/// The function constants are applied, the library and Metal function are created, and the pipeline
/// state is built with binding reflection so argument names can later be mapped to binding indices.
///
/// - Parameters:
///   - device: The Metal device on which the compute pipeline will be executed.
///   - function: The shader function to be used in this pipeline.
///   - constants: A dictionary of constant values to be used when compiling the shader function. Defaults to an empty dictionary.
///   - arguments: A dictionary of arguments to be passed to the shader function. Defaults to an empty dictionary.
/// - Throws: `ComputeError.resourceCreationFailure` if no reflection data is returned, or any error
///   thrown while creating the library, the function or the pipeline state.
internal init(device: MTLDevice, function: ShaderFunction, constants: [String: Argument] = [:], arguments: [String: Argument] = [:]) throws {
    self.function = function

    // Let every constant register itself under its name.
    let constantValues = MTLFunctionConstantValues()
    for entry in constants {
        entry.value.constantValue(constantValues, entry.key)
    }

    let library = try function.library.makelibrary(device: device)

    let metalFunction = try library.makeFunction(name: function.name, constantValues: constantValues)
    metalFunction.label = "\(function.name)-MTLFunction"

    let descriptor = MTLComputePipelineDescriptor()
    descriptor.label = "\(function.name)-MTLComputePipelineState"
    descriptor.computeFunction = metalFunction

    // Binding info is requested so argument names can be resolved to indices in `bind(_:)`.
    let (state, reflection) = try device.makeComputePipelineState(descriptor: descriptor, options: [.bindingInfo])
    guard let reflection else {
        throw ComputeError.resourceCreationFailure
    }
    let namedIndices = reflection.bindings.map { ($0.name, $0.index) }
    bindings = Dictionary(uniqueKeysWithValues: namedIndices)

    self.computePipelineState = state
    self.arguments = Arguments(arguments: arguments)
}

/// The maximum total number of threads per threadgroup for this compute pipeline state.
public var maxTotalThreadsPerThreadgroup: Int {
    return computePipelineState.maxTotalThreadsPerThreadgroup
}

/// The thread execution width for this compute pipeline state.
public var threadExecutionWidth: Int {
    return computePipelineState.threadExecutionWidth
}

/// Encodes every argument on the given encoder at the index of the binding with the same name.
///
/// - Parameter commandEncoder: The compute command encoder to which the arguments should be bound.
/// - Throws: `ComputeError.missingBinding` if an argument's name has no binding in the pipeline.
func bind(_ commandEncoder: MTLComputeCommandEncoder) throws {
    for entry in arguments.arguments {
        guard let index = bindings[entry.key] else {
            throw ComputeError.missingBinding(entry.key)
        }
        entry.value.encode(commandEncoder, index)
    }
}
}
}

--------------------------------------------------------------------------------
/Sources/Compute/Compute+Task.swift:
--------------------------------------------------------------------------------
import Metal
import os

public extension Compute {
    /// Represents a compute task that can be executed on the GPU.
    ///
    /// A `Task` encapsulates a Metal command buffer and provides methods to execute compute operations.
    struct Task {
        /// The label for this task, used for debugging and profiling.
        let label: String?

        /// The logger used for logging information about the task execution.
        let logger: Logger?
/// The Metal command buffer associated with this task.
let commandBuffer: MTLCommandBuffer

/// Executes a block of code with a `Dispatcher`.
///
/// This method is a convenience wrapper around the `run` method.
///
/// - Parameter block: A closure that takes a `Dispatcher` and returns a result.
/// - Returns: The result of the block execution.
/// - Throws: Any error that occurs during the execution of the block.
public func callAsFunction<R>(_ block: (Dispatcher) throws -> R) throws -> R {
    try run(block)
}

/// Runs a block of code with a `Dispatcher`.
///
/// This method creates a compute command encoder and executes the provided block with a `Dispatcher`.
/// Encoding is ended when the block returns or throws.
///
/// - Parameter block: A closure that takes a `Dispatcher` and returns a result.
/// - Returns: The result of the block execution.
/// - Throws: `ComputeError.resourceCreationFailure` if unable to create a compute command encoder,
///           or any error that occurs during the execution of the block.
public func run<R>(_ block: (Dispatcher) throws -> R) throws -> R {
    guard let commandEncoder = commandBuffer.makeComputeCommandEncoder() else {
        throw ComputeError.resourceCreationFailure
    }
    commandEncoder.label = "\(label ?? "Unlabeled")-MTLComputeCommandEncoder"

    // End encoding on every exit path, including when `block` throws.
    defer {
        commandEncoder.endEncoding()
    }
    let dispatcher = Dispatcher(label: label, logger: logger, commandEncoder: commandEncoder)
    return try block(dispatcher)
}
}

/// Handles the dispatching of compute operations to the GPU.
///
/// A `Dispatcher` is responsible for setting up and executing compute operations using the provided `Pipeline`.
struct Dispatcher {
    /// The label for this dispatcher, used for debugging and profiling.
    public let label: String?

    /// The logger used for logging information about the dispatch operation.
public let logger: Logger?

/// The Metal compute command encoder used to encode compute commands.
public let commandEncoder: MTLComputeCommandEncoder

/// Encodes a dispatch of `threadgroupsPerGrid` threadgroups of `threadsPerThreadgroup` threads each.
///
/// The pipeline state is set on the encoder, the pipeline's arguments are bound, and the dispatch is encoded.
///
/// - Parameters:
///   - pipeline: The `Pipeline` containing the compute pipeline state and arguments.
///   - threadgroupsPerGrid: The number of threadgroups to dispatch in each dimension.
///   - threadsPerThreadgroup: The number of threads in each threadgroup.
/// - Throws: Any error thrown while binding the pipeline's arguments.
public func callAsFunction(pipeline: Pipeline, threadgroupsPerGrid: MTLSize, threadsPerThreadgroup: MTLSize) throws {
    let state = pipeline.computePipelineState

    logger?.info("Dispatching \(threadgroupsPerGrid.shortDescription) threadgroups per grid with \(threadsPerThreadgroup.shortDescription) threads per threadgroup. maxTotalThreadsPerThreadgroup: \(state.maxTotalThreadsPerThreadgroup), threadExecutionWidth: \(state.threadExecutionWidth).")

    // The pipeline state has to be set before the arguments are bound and the dispatch is encoded.
    commandEncoder.setComputePipelineState(state)
    try pipeline.bind(commandEncoder)
    commandEncoder.dispatchThreadgroups(threadgroupsPerGrid, threadsPerThreadgroup: threadsPerThreadgroup)
}

/// Dispatches a compute operation using a specific number of threads.
///
/// - Parameters:
///   - pipeline: The `Pipeline` containing the compute pipeline state and arguments.
///   - threads: The total number of threads to dispatch in each dimension.
///   - threadsPerThreadgroup: The number of threads in each threadgroup.
/// - Throws: Any error that occurs during the binding of arguments or dispatch.
public func callAsFunction(pipeline: Pipeline, threads: MTLSize, threadsPerThreadgroup: MTLSize) throws {
    // `dispatchThreads` needs non-uniform threadgroup size support; refuse early on devices
    // that report none of these GPU families.
    let device = commandEncoder.device
    guard device.supportsFamily(.apple4) || device.supportsFamily(.common3) || device.supportsFamily(.metal3) else {
        throw ComputeError.nonuniformThreadgroupsSizeNotSupported
    }

    logger?.info("Dispatch - threads: \(threads.shortDescription), threadsPerThreadgroup \(threadsPerThreadgroup.shortDescription), maxTotalThreadsPerThreadgroup: \(pipeline.computePipelineState.maxTotalThreadsPerThreadgroup), threadExecutionWidth: \(pipeline.computePipelineState.threadExecutionWidth).")

    // Set the pipeline state once (it was previously set twice in a row), then bind and dispatch.
    commandEncoder.setComputePipelineState(pipeline.computePipelineState)
    try pipeline.bind(commandEncoder)
    commandEncoder.dispatchThreads(threads, threadsPerThreadgroup: threadsPerThreadgroup)
}
}
}

--------------------------------------------------------------------------------
/Sources/Compute/Compute.swift:
--------------------------------------------------------------------------------
import Metal
import os

/// The main struct that encapsulates the Metal compute environment.
///
/// This struct provides the core functionality for creating and executing compute tasks on the GPU.
public struct Compute {
    /// The Metal device used for compute operations.
    public let device: MTLDevice

    /// The logger used for debugging and performance monitoring.
    let logger: Logger?

    /// The Metal command queue used for submitting command buffers.
    let commandQueue: MTLCommandQueue

    /// Initializes a new Compute instance.
    ///
    /// - Parameters:
    ///   - device: The Metal device to use for compute operations.
    ///   - logger: An optional logger for debugging and performance monitoring.
///   - useLogState: Enable capture of Metal log messages if available.
/// - Throws: `ComputeError.resourceCreationFailure` if unable to create the command queue.
public init(device: MTLDevice, logger: Logger? = nil, useLogState: Bool = false) throws {
    self.device = device
    self.logger = logger

    // Pick the queue-creation path first; validation and labelling are shared below.
    let queue: MTLCommandQueue?
    #if !targetEnvironment(simulator)
    if #available(macOS 15, iOS 18, *) {
        let queueDescriptor = MTLCommandQueueDescriptor()
        if useLogState {
            // Forward Metal log messages to the supplied logger through a log state.
            let logStateDescriptor = MTLLogStateDescriptor()
            logStateDescriptor.bufferSize = 16 * 1_024 * 1_024
            let logState = try device.makeLogState(descriptor: logStateDescriptor)
            logState.addLogHandler { _, _, _, message in
                logger?.log("\(message)")
            }
            queueDescriptor.logState = logState
        }
        queue = device.makeCommandQueue(descriptor: queueDescriptor)
    } else {
        queue = device.makeCommandQueue()
    }
    #else
    queue = device.makeCommandQueue()
    #endif

    guard let queue else {
        throw ComputeError.resourceCreationFailure
    }
    queue.label = "Compute-MTLCommandQueue"
    self.commandQueue = queue
}

/// Executes a compute task.
///
/// This method creates a command buffer and executes the provided block with a `Task` instance.
///
/// - Parameters:
///   - label: An optional label for the task, useful for debugging.
///   - block: A closure that takes a `Task` instance and returns a result.
/// - Returns: The result of the block execution.
/// - Throws: `ComputeError.resourceCreationFailure` if unable to create the command buffer,
///           any error thrown by the provided block, or the command buffer's own error if
///           execution finished with one.
public func task<R>(label: String? = nil, _ block: (Task) throws -> R) throws -> R {
    let commandBufferDescriptor = MTLCommandBufferDescriptor()
    guard let commandBuffer = commandQueue.makeCommandBuffer(descriptor: commandBufferDescriptor) else {
        throw ComputeError.resourceCreationFailure
    }
    commandBuffer.label = "\(label ?? "Unlabeled")-MTLCommandBuffer"

    let task = Task(label: label, logger: logger, commandBuffer: commandBuffer)

    // Capture the block's outcome so the buffer is committed and awaited whether or not the
    // block threw, exactly as before.
    let outcome = Result<R, Error> { try block(task) }
    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()

    // An error thrown by the block takes precedence over an execution error.
    let value = try outcome.get()
    // Surface execution failures instead of silently returning a result.
    if let error = commandBuffer.error {
        throw error
    }
    return value
}

/// Creates a compute pipeline.
///
/// This method creates a `Pipeline` instance, which encapsulates a compute pipeline state and its associated arguments.
///
/// - Parameters:
///   - function: The shader function to use for the pipeline.
///   - constants: A dictionary of constant values to be used when compiling the shader function.
///   - arguments: A dictionary of arguments to be passed to the shader function.
/// - Returns: A new `Pipeline` instance.
/// - Throws: Any error that occurs during the creation of the compute pipeline state.
public func makePipeline(function: ShaderFunction, constants: [String: Argument] = [:], arguments: [String: Argument] = [:]) throws -> Pipeline {
    try Pipeline(device: device, function: function, constants: constants, arguments: arguments)
}
}

/// Type alias for `Compute` to avoid name conflicts with module name.
public typealias Compute_ = Compute

--------------------------------------------------------------------------------
/Sources/Compute/ShaderFunction.swift:
--------------------------------------------------------------------------------
import Metal

/// Represents a Metal shader library.
@dynamicMemberLookup
public struct ShaderLibrary: Identifiable, Sendable {
    /// Identifies where a shader library comes from.
    public enum ID: Hashable, Sendable {
        /// A library loaded from the bundle at the given URL, with an optional metallib name.
        case bundle(URL, String?)
        /// A library compiled from the given source string.
        case source(String)
    }

    /// The default shader library loaded from the main bundle.
    public static let `default` = Self.bundle(.main)

    /// Creates a shader library from a bundle.
    ///
    /// The library is not loaded here; loading happens when the returned value's `make` closure is
    /// called with a device.
    ///
    /// - Parameters:
    ///   - bundle: The bundle containing the shader library.
    ///   - name: The name of the metallib file (without extension). If nil, uses the default library.
    /// - Returns: A new ShaderLibrary instance. Its `make` closure throws
    ///   `ComputeError.resourceCreationFailure` when the named metallib is not in the bundle, or any
    ///   error raised while creating the library.
    public static func bundle(_ bundle: Bundle, name: String? = nil) -> Self {
        return ShaderLibrary(id: .bundle(bundle.bundleURL, name)) { device in
            guard let name else {
                let library = try device.makeDefaultLibrary(bundle: bundle)
                library.label = "Default.MTLLibrary"
                return library
            }
            guard let url = bundle.url(forResource: name, withExtension: "metallib") else {
                // The closure already throws, so report a missing metallib as an error the
                // caller can handle rather than terminating the process.
                throw ComputeError.resourceCreationFailure
            }
            let library = try device.makeLibrary(URL: url)
            library.label = "\(name).MTLLibrary"
            return library
        }
    }

    /// Creates a shader library from source code.
    ///
    /// - Parameters:
    ///   - source: The Metal shader source code as a string.
    ///   - enableLogging: Enable Metal shader logging.
    /// - Returns: A new ShaderLibrary instance.
public static func source(_ source: String, enableLogging: Bool = false) -> Self {
    return ShaderLibrary(id: .source(source)) { device in
        let options = MTLCompileOptions()
        // Shader logging is only configured outside the simulator; there the flag has no effect.
        #if !targetEnvironment(simulator)
        if enableLogging {
            if #available(macOS 15, iOS 18, *) {
                options.enableLogging = true
            }
            else {
                fatalError("Metal logging is not available on this platform.")
            }
        }
        #endif
        return try device.makeLibrary(source: source, options: options)
    }
}

/// The identity of this library (its bundle URL and name, or its source).
public var id: ID

/// A closure that creates an MTLLibrary given an MTLDevice.
public var make: @Sendable (MTLDevice) throws -> MTLLibrary

/// Creates a ShaderFunction with the given name.
///
/// - Parameter name: The name of the shader function.
/// - Returns: A new ShaderFunction instance.
public func function(name: String) -> ShaderFunction {
    ShaderFunction(library: self, name: name)
}

/// Allows accessing shader functions using dynamic member lookup.
///
/// - Parameter name: The name of the shader function.
/// - Returns: A new ShaderFunction instance.
public subscript(dynamicMember name: String) -> ShaderFunction {
    ShaderFunction(library: self, name: name)
}

/// Creates an MTLLibrary for the given device.
///
/// - Parameter device: The Metal device to create the library for.
/// - Returns: The created MTLLibrary.
/// - Throws: An error if the library creation fails.
internal func makelibrary(device: MTLDevice) throws -> MTLLibrary {
    try make(device)
}
}

/// Represents a compute shader function within a shader library.
public struct ShaderFunction: Identifiable {
    /// A unique identifier for the shader function, composed of the library's id and the function name.
    public let id: Composite<ShaderLibrary.ID, String>

    /// The shader library containing this function.
public let library: ShaderLibrary

/// The name of the shader function.
public let name: String

/// An array of shader constants associated with this function.
public let constants: [ShaderConstant]

/// Creates a shader function description.
///
/// The identifier is built from the library's id and the function name only.
///
/// - Parameters:
///   - library: The shader library containing this function.
///   - name: The name of the shader function.
///   - constants: An array of shader constants associated with this function. Defaults to an empty array.
public init(library: ShaderLibrary, name: String, constants: [ShaderConstant] = []) {
    self.library = library
    self.name = name
    self.constants = constants
    // BUG: https://github.com/schwa/Compute/issues/17 Make shader constants part of name id
    self.id = Composite(library.id, name)
}
}

/// Represents a constant value that can be passed to a shader function.
public struct ShaderConstant {
    /// The data type of the constant.
    var dataType: MTLDataType

    /// A closure that provides access to the constant's value.
    var accessor: ((UnsafeRawPointer) -> Void) -> Void

    /// Creates a shader constant backed by an array.
    ///
    /// The accessor hands the address of the array's contiguous storage to its callback.
    ///
    /// - Parameters:
    ///   - dataType: The Metal data type of the constant.
    ///   - value: The array value of the constant.
    // BUG: https://github.com/schwa/Compute/issues/16 Get away from Any and use an enum.
    public init(dataType: MTLDataType, value: [some Any]) {
        self.dataType = dataType
        self.accessor = { visit in
            value.withUnsafeBytes { bytes in
                guard let start = bytes.baseAddress else {
                    fatalError("Could not get baseAddress.")
                }
                visit(start)
            }
        }
    }

    /// Initializes a new shader constant with a single value.
///
/// - Parameters:
///   - dataType: The Metal data type of the constant.
///   - value: The value of the constant.
// BUG: https://github.com/schwa/Compute/issues/16 Get away from Any and use an enum.
public init(dataType: MTLDataType, value: some Any) {
    self.dataType = dataType
    self.accessor = { visit in
        withUnsafeBytes(of: value) { bytes in
            guard let start = bytes.baseAddress else {
                fatalError("Could not get baseAddress.")
            }
            visit(start)
        }
    }
}

/// Registers this constant on a `MTLFunctionConstantValues` object.
///
/// - Parameters:
///   - values: The MTLFunctionConstantValues object to add the constant to.
///   - name: The name of the constant in the shader code.
public func add(to values: MTLFunctionConstantValues, name: String) {
    let type = dataType
    accessor { address in
        values.setConstantValue(address, type: type, withName: name)
    }
}
}

--------------------------------------------------------------------------------
/Sources/Compute/Support.swift:
--------------------------------------------------------------------------------
import Metal

/// Enumerates the possible errors that can occur in the Compute framework.
public enum ComputeError: Error {
    /// Indicates that a required binding for an argument is missing.
    /// - Parameter String: The name of the missing binding.
    case missingBinding(String)

    /// Indicates that a thread-count dispatch was requested on a device that supports none of the
    /// GPU families checked before calling `dispatchThreads`.
    case nonuniformThreadgroupsSizeNotSupported

    /// Indicates a failure in creating a required resource, such as a command queue or buffer.
    case resourceCreationFailure

}

public extension Compute {
    /// Dispatches a compute operation using a more convenient syntax.
    ///
    /// This method wraps the `task` method, providing a simpler interface for dispatching compute operations.
///
/// - Parameters:
///   - label: An optional label for the dispatch operation, useful for debugging.
///   - block: A closure that takes a `Dispatcher` and returns a result.
/// - Returns: The result of the block execution.
/// - Throws: Any error that occurs during the execution of the block or the underlying task.
func dispatch<R>(label: String? = nil, _ block: (Dispatcher) throws -> R) throws -> R {
    try task(label: label) { task in
        try task { dispatch in
            try block(dispatch)
        }
    }
}
}

public extension Compute {
    /// Runs a pipeline once, dispatching an exact number of threads.
    ///
    /// - Parameters:
    ///   - pipeline: The compute pipeline to run.
    ///   - arguments: Optional additional arguments; on a name clash they replace the pipeline's own.
    ///   - threads: The total number of threads to dispatch in each dimension.
    ///   - threadsPerThreadgroup: The number of threads in each threadgroup.
    /// - Throws: Any error that occurs while creating the task, binding the arguments or dispatching.
    func run(pipeline: Pipeline, arguments: [String: Argument]? = nil, threads: MTLSize, threadsPerThreadgroup: MTLSize) throws {
        // Work on a copy so the caller's pipeline keeps its original arguments.
        var pipeline = pipeline
        if let arguments {
            var existing = pipeline.arguments.arguments
            existing.merge(arguments) { $1 }
            pipeline.arguments = .init(arguments: existing)
        }
        try task { task in
            try task { dispatch in
                try dispatch(pipeline: pipeline, threads: threads, threadsPerThreadgroup: threadsPerThreadgroup)
            }
        }
    }

    /// Runs a pipeline once, dispatching a number of threadgroups.
    ///
    /// - Parameters:
    ///   - pipeline: The compute pipeline to run.
    ///   - arguments: Optional additional arguments; on a name clash they replace the pipeline's own.
    ///   - threadgroupsPerGrid: The number of threadgroups to dispatch in each dimension.
    ///   - threadsPerThreadgroup: The number of threads in each threadgroup.
    /// - Throws: Any error that occurs while creating the task, binding the arguments or dispatching.
    func run(pipeline: Pipeline, arguments: [String: Argument]? = nil, threadgroupsPerGrid: MTLSize, threadsPerThreadgroup: MTLSize) throws {
        // Work on a copy so the caller's pipeline keeps its original arguments.
        var pipeline = pipeline
        if let arguments {
            var existing = pipeline.arguments.arguments
            existing.merge(arguments) { $1 }
            pipeline.arguments = .init(arguments: existing)
        }
        try task { task in
            try task { dispatch in
                try dispatch(pipeline: pipeline, threadgroupsPerGrid: threadgroupsPerGrid, threadsPerThreadgroup: threadsPerThreadgroup)
            }
        }
    }


    /// Runs a compute pipeline with the specified arguments and thread count.
    ///
    /// This method provides a convenient way to execute a single compute pipeline with optional additional arguments.
    ///
    /// - Parameters:
    ///   - pipeline: The compute pipeline to run.
///   - arguments: Optional additional arguments to merge with the pipeline's existing arguments.
///   - width: The number of threads to dispatch.
/// - Throws: Any error that occurs during the execution of the compute pipeline.
@available(*, deprecated, message: "Deprecated")
func run(pipeline: Pipeline, arguments: [String: Argument]? = nil, width: Int) throws {
    var pipeline = pipeline
    if let arguments {
        var existing = pipeline.arguments.arguments
        existing.merge(arguments) { $1 }
        pipeline.arguments = .init(arguments: existing)
    }
    try task { task in
        try task { dispatch in
            // One-dimensional dispatch using the largest threadgroup the pipeline allows.
            let maxTotalThreadsPerThreadgroup = pipeline.computePipelineState.maxTotalThreadsPerThreadgroup
            try dispatch(pipeline: pipeline, threads: MTLSize(width: width, height: 1, depth: 1), threadsPerThreadgroup: MTLSize(width: maxTotalThreadsPerThreadgroup, height: 1, depth: 1))
        }
    }
}

/// Runs a compute pipeline over a two-dimensional grid of `width` × `height` threads.
///
/// - Parameters:
///   - pipeline: The compute pipeline to run.
///   - arguments: Optional additional arguments to merge with the pipeline's existing arguments.
///   - width: The number of threads to dispatch horizontally.
///   - height: The number of threads to dispatch vertically.
/// - Throws: Any error that occurs during the execution of the compute pipeline.
@available(*, deprecated, message: "Deprecated")
func run(pipeline: Pipeline, arguments: [String: Argument]? = nil, width: Int, height: Int) throws {
    var pipeline = pipeline
    if let arguments {
        var existing = pipeline.arguments.arguments
        existing.merge(arguments) { $1 }
        pipeline.arguments = .init(arguments: existing)
    }
    try task { task in
        try task { dispatch in
            let maxTotalThreadsPerThreadgroup = pipeline.computePipelineState.maxTotalThreadsPerThreadgroup

            // Roughly square threadgroup: width is floor(sqrt(max)) and height fills the rest,
            // so width * height never exceeds the maximum.
            let threadsPerThreadgroupWidth = Int(sqrt(Double(maxTotalThreadsPerThreadgroup)))
            let threadsPerThreadgroupHeight = maxTotalThreadsPerThreadgroup / threadsPerThreadgroupWidth

            try dispatch(pipeline: pipeline, threads: MTLSize(width: width, height: height, depth: 1), threadsPerThreadgroup: MTLSize(width: threadsPerThreadgroupWidth, height: threadsPerThreadgroupHeight, depth: 1))
        }
    }
}
}

internal extension MTLSize {
    /// A compact `[width, height, depth]` rendering used in log messages.
    var shortDescription: String {
        return "[\(width), \(height), \(depth)]"
    }
}

// MARK: -

/// A value made of an arbitrary number of child values, used as a composite identifier.
///
/// It is `Equatable`, `Hashable` and `Sendable` whenever all of its children are.
public struct Composite<each T> {
    private let children: (repeat each T)

    /// Creates a composite from the given children.
    public init(_ children: repeat each T) {
        self.children = (repeat each children)
    }
}

extension Composite: Equatable where repeat each T: Equatable {
    /// Two composites are equal when every pair of corresponding children is equal.
    public static func == (lhs: Self, rhs: Self) -> Bool {
        for (left, right) in repeat (each lhs.children, each rhs.children) {
            guard left == right else {
                return false
            }
        }
        return true
    }
}

extension Composite: Hashable where repeat each T: Hashable {
    /// Feeds every child, in order, into the hasher.
    public func hash(into hasher: inout Hasher) {
        for child in repeat (each children) {
            child.hash(into: &hasher)
        }
    }
}

extension Composite: Sendable where repeat each T: Sendable {
}

--------------------------------------------------------------------------------
/Sources/Examples/.swiftlint.yml:
/// Command-line entry point that selects and runs the example demos.
///
/// - With the argument `all`, every demo runs in list order.
/// - With any other argument, the demo whose type name equals it (compared lowercased) runs.
/// - With no argument, a numbered menu is printed and the choice is read from standard input.
@main
enum Examples {
    static func main() async throws {
        let demos: [Demo.Type] = [
            BareMetalVsCompute.self,
            BufferFill.self,
            Checkerboard.self,
            GameOfLife.self,
            HelloWorldDemo.self,
            ImageInvert.self,
            MaxValue.self,
            MemcopyDemo.self,
            RandomFill.self,
            SIMDReduce.self,
            ThreadgroupLogging.self,
            BitonicSortDemo.self,
            Histogram.self,
            CounterDemo.self,
            MaxParallel.self,
            IsSorted.self,
        ]

        let argument: String? = CommandLine.arguments.count > 1 ? CommandLine.arguments[1].lowercased() : nil
        if let argument {
            if argument == "all" {
                for demo in demos {
                    print("\(demo)")
                    try await demo.main()
                }
            }
            else {
                guard let demo = demos.first(where: { String(describing: $0).lowercased() == argument }) else {
                    fatalError("No demo with name: \(argument)")
                }
                try await demo.main()
            }
        }
        else {
            for (index, demo) in demos.enumerated() {
                print("\(index): \(String(describing: demo))")
            }
            print("Choice", terminator: ": ")
            // readLine() returns nil at end-of-input (e.g. stdin closed or piped from an empty source).
            guard let line = readLine() else {
                fatalError("No choice given (end of input)")
            }
            // Reject non-numeric and out-of-range input with a clear message instead of trapping
            // on a force-unwrap or an array index.
            guard let choice = Int(line.trimmingCharacters(in: .whitespaces)), demos.indices.contains(choice) else {
                fatalError("Invalid choice: \(line). Expected a number from 0 to \(demos.count - 1)")
            }
            let demo = demos[choice]
            print("Running \(String(describing: demo))...")
            try await demo.main()
        }
    }
}
22 | assert(UnsafeRawBufferPointer(start: buffer.contents(), count: buffer.length).allSatisfy { $0 == 0x00 }) 23 | // Run compute and confirm the output is correct 24 | try compute(device: device, buffer: buffer, value: 0x88) 25 | assert(UnsafeRawBufferPointer(start: buffer.contents(), count: buffer.length).allSatisfy { $0 == 0x88 }) 26 | // Run bareMetal and confirm the output is correct 27 | try bareMetal(device: device, buffer: buffer, value: 0xFF) 28 | assert(UnsafeRawBufferPointer(start: buffer.contents(), count: buffer.length).allSatisfy { $0 == 0xFF }) 29 | } 30 | 31 | static func bareMetal(device: MTLDevice, buffer: MTLBuffer, value: UInt8) throws { 32 | // Create shader library from source 33 | let library = try device.makeLibrary(source: source, options: .init()) 34 | let function = library.makeFunction(name: "memset") 35 | 36 | // Create compute pipeline for memset operation 37 | let computePipelineDescriptor = MTLComputePipelineDescriptor() 38 | computePipelineDescriptor.computeFunction = function 39 | let (computePipelineState, reflection) = try device.makeComputePipelineState(descriptor: computePipelineDescriptor, options: [.bindingInfo]) 40 | guard let reflection else { 41 | throw ComputeError.resourceCreationFailure 42 | } 43 | guard let outputIndex = reflection.bindings.first(where: { $0.name == "output" })?.index else { 44 | throw ComputeError.missingBinding("output") 45 | } 46 | guard let valueIndex = reflection.bindings.first(where: { $0.name == "value" })?.index else { 47 | throw ComputeError.missingBinding("value") 48 | } 49 | 50 | // Execute compute pipeline 51 | guard let commandQueue = device.makeCommandQueue() else { 52 | throw ComputeError.resourceCreationFailure 53 | } 54 | commandQueue.label = "memcpy-MTLCommandQueue" 55 | 56 | let commandBufferDescriptor = MTLCommandBufferDescriptor() 57 | guard let commandBuffer = commandQueue.makeCommandBuffer(descriptor: commandBufferDescriptor) else { 58 | throw ComputeError.resourceCreationFailure 
59 | } 60 | commandBuffer.label = "memcpy-MTLCommandBuffer" 61 | guard let computeCommandEncoder = commandBuffer.makeComputeCommandEncoder() else { 62 | throw ComputeError.resourceCreationFailure 63 | } 64 | computeCommandEncoder.label = "memcpy-MTLComputeCommandEncoder" 65 | computeCommandEncoder.setComputePipelineState(computePipelineState) 66 | computeCommandEncoder.setBuffer(buffer, offset: outputIndex, index: outputIndex) 67 | var value = value 68 | computeCommandEncoder.setBytes(&value, length: MemoryLayout.size(ofValue: value), index: valueIndex) 69 | let threadsPerGrid = MTLSize(width: buffer.length, height: 1, depth: 1) 70 | let maxTotalThreadsPerThreadgroup = computePipelineState.maxTotalThreadsPerThreadgroup 71 | let threadsPerThreadgroup = MTLSize(width: maxTotalThreadsPerThreadgroup, height: 1, depth: 1) 72 | computeCommandEncoder.dispatchThreads(threadsPerGrid, threadsPerThreadgroup: threadsPerThreadgroup) 73 | computeCommandEncoder.endEncoding() 74 | commandBuffer.commit() 75 | commandBuffer.waitUntilCompleted() 76 | } 77 | 78 | static func compute(device: MTLDevice, buffer: MTLBuffer, value: UInt8) throws { 79 | // Set up. 
// One compare-and-swap step of a bitonic sort. Each thread handles one pair of entries;
// the host re-dispatches this kernel with new groupWidth / groupHeight / stepIndex values
// for every step of every stage.
[[kernel]]
void bitonicSort(
    constant uint &numEntries [[buffer(0)]],
    constant uint &groupWidth [[buffer(1)]],
    constant uint &groupHeight [[buffer(2)]],
    constant uint &stepIndex [[buffer(3)]],
    device uint *entries [[buffer(4)]]
) {
    // Position of this thread inside its group, and which group it belongs to.
    const uint column = thread_position_in_grid & (groupWidth - 1);
    const uint group = thread_position_in_grid / groupWidth;
    const uint lower = column + (groupHeight + 1) * group;

    // Distance to the partner entry: mirrored across the group on the first step,
    // a fixed half-group stride on the following steps.
    uint distance;
    if (stepIndex == 0) {
        distance = groupHeight - 2 * column;
    } else {
        distance = (groupHeight + 1) / 2;
    }
    const uint upper = lower + distance;

    // Partner lies past the end of the data (input size is not a power of two).
    if (upper >= numEntries) {
        return;
    }

    const uint first = entries[lower];
    const uint second = entries[upper];
    // Already in ascending order: nothing to write.
    if (first <= second) {
        return;
    }
    entries[lower] = second;
    entries[upper] = first;
}
16 | let numEntries = entries.count 17 | let buffer: MTLBuffer = try entries.withUnsafeMutableBufferPointer { buffer in 18 | let buffer = UnsafeMutableRawBufferPointer(buffer) 19 | return try device.makeBufferEx(bytes: buffer.baseAddress!, length: buffer.count) 20 | } 21 | 22 | let function = ShaderLibrary.bundle(.module).bitonicSort 23 | let numStages = log2(nextPowerOfTwo(numEntries)) 24 | 25 | let compute = try Compute(device: device) 26 | 27 | var pipeline = try compute.makePipeline(function: function, arguments: [ 28 | "numEntries": .int(numEntries), 29 | "entries": .buffer(buffer), 30 | ]) 31 | 32 | print("Running \(numStages) compute stages") 33 | 34 | var threadgroupsPerGrid = (entries.count + pipeline.maxTotalThreadsPerThreadgroup - 1) / pipeline.maxTotalThreadsPerThreadgroup 35 | threadgroupsPerGrid = (threadgroupsPerGrid + pipeline.threadExecutionWidth - 1) / pipeline.threadExecutionWidth * pipeline.threadExecutionWidth 36 | 37 | try timeit("GPU") { 38 | try compute.task { task in 39 | try task { dispatch in 40 | var n = 0 41 | for stageIndex in 0 ..< numStages { 42 | for stepIndex in 0 ..< (stageIndex + 1) { 43 | let groupWidth = 1 << (stageIndex - stepIndex) 44 | let groupHeight = 2 * groupWidth - 1 45 | pipeline.arguments.groupWidth = .int(groupWidth) 46 | pipeline.arguments.groupHeight = .int(groupHeight) 47 | pipeline.arguments.stepIndex = .int(stepIndex) 48 | // print("\(n), \(stageIndex)/\(numStages), \(stepIndex)/\(stageIndex + 1), \(groupWidth), \(groupHeight)") 49 | try dispatch( 50 | pipeline: pipeline, 51 | threadgroupsPerGrid: MTLSize(width: threadgroupsPerGrid), 52 | threadsPerThreadgroup: MTLSize(width: pipeline.maxTotalThreadsPerThreadgroup) 53 | ) 54 | n += 1 55 | } 56 | } 57 | } 58 | } 59 | } 60 | 61 | timeit("CPU") { 62 | entries.sort() 63 | } 64 | 65 | let result = Array(buffer) 66 | assert(entries == result) 67 | } 68 | } 69 | -------------------------------------------------------------------------------- 
namespace CountingSort16 {
    // MARK: -

    // Counts occurrences of each 16-bit key. The grid is two-dimensional: x selects a bucket,
    // y selects an input element; a thread increments its bucket when the element equals it.
    // Keys are read as unsigned 16-bit values and the bucket is kept at full width, so keys
    // >= 256 (and >= 32768) are counted in their own bucket instead of being truncated or
    // sign-extended.
    // NOTE(review): there is no element-count argument, so the host must dispatch exactly
    // one y position per input element.
    [[kernel]]
    void histogram(
        uint2 thread_position_in_grid [[thread_position_in_grid]],
        device ushort *input [[buffer(0)]],
        device atomic_uint *histogram [[buffer(1)]],
        constant uint &histogramCount [[buffer(2)]]
    )
    {
        const uint bucket = thread_position_in_grid.x;
        if (bucket >= histogramCount) {
            return;
        }
        const uint index = thread_position_in_grid.y;
        if (bucket == uint(input[index])) {
            atomic_fetch_add_explicit(&histogram[bucket], 1, memory_order_relaxed);
        }
    }

    // MARK: -

    // In-place inclusive prefix sum over the histogram, performed serially by the calling thread.
    [[kernel]]
    void prefix_sum_inclusive(
        device uint *histogram [[buffer(0)]],
        constant uint &histogramCount [[buffer(1)]]
    )
    {
        // `<` rather than `!=`: with histogramCount == 0 the index would otherwise start past
        // the end and run until it wraps around.
        for (uint index = 1; index < histogramCount; index++) {
            histogram[index] += histogram[index - 1];
        }
    }

    // In-place exclusive prefix sum: each bucket receives the total of all buckets before it.
    [[kernel]]
    void prefix_sum_exclusive(
        device uint *histogram [[buffer(0)]],
        constant uint &histogramCount [[buffer(1)]]
    )
    {
        uint sum = 0;
        for (uint i = 0; i < histogramCount; i++) {
            const uint t = histogram[i];
            histogram[i] = sum;
            sum += t;
        }
    }

    // MARK: -

    // Scatters each input element to the next free slot of its bucket. `histogram` must hold
    // the exclusive prefix sum on entry; every thread atomically claims one slot, advancing
    // the bucket's cursor.
    [[kernel]]
    void shuffle(
        uint thread_position_in_grid [[thread_position_in_grid]],
        device atomic_uint *histogram [[buffer(0)]],
        device ushort *input [[buffer(1)]],
        device ushort *output [[buffer(2)]],
        constant uint &count [[buffer(3)]]
    )
    {
        const uint index = thread_position_in_grid;
        // Threads past the end of the input (padded dispatch) must not read or write anything.
        if (index >= count) {
            return;
        }
        const ushort value = input[index];
        const uint slot = atomic_fetch_add_explicit(&histogram[value], 1, memory_order_relaxed);
        output[slot] = value;
    }

}
// Scatter pass for an inclusive prefix sum: one thread per bucket (grid x in 0..<256) walks
// the input from back to front and places every element of its bucket at the slot just below
// the bucket's current cursor, decrementing the cursor as it goes.
[[kernel]]
void shuffle(
    uint2 thread_position_in_grid [[thread_position_in_grid]],
    device atomic_uint *histogram [[buffer(0)]],
    device uint *input [[buffer(1)]],
    device uint *output [[buffer(2)]],
    constant uint &count [[buffer(3)]],
    constant uint &shift [[buffer(4)]]
)
{
    // Only 256 buckets exist; a thread with x >= 256 would alias a low bucket after the
    // narrowing conversion below and process it a second time.
    if (thread_position_in_grid.x >= 256) {
        return;
    }
    const uchar bucket = thread_position_in_grid.x;
    // Unsigned countdown so that counts above INT_MAX are still traversed.
    for (uint index = count; index-- > 0; ) {
        if (bucket == key(input[index], shift)) {
            const uint old = atomic_fetch_sub_explicit(&histogram[bucket], 1, memory_order_relaxed);
            output[old - 1] = input[index];
        }
    }
}

// Scatter pass for an exclusive prefix sum: one thread per bucket (grid x in 0..<256) walks
// the input front to back and places every element of its bucket at the bucket's cursor,
// advancing the cursor as it goes.
[[kernel]]
void shuffle2(
    uint2 thread_position_in_grid [[thread_position_in_grid]],
    device atomic_uint *histogram [[buffer(0)]],
    device uint *input [[buffer(1)]],
    device uint *output [[buffer(2)]],
    constant uint &count [[buffer(3)]],
    constant uint &shift [[buffer(4)]]
)
{
    // See shuffle(): guard against bucket aliasing for threads beyond the 256 buckets.
    if (thread_position_in_grid.x >= 256) {
        return;
    }
    const uchar bucket = thread_position_in_grid.x;
    for (uint index = 0; index < count; index++) {
        if (bucket == key(input[index], shift)) {
            const uint old = atomic_fetch_add_explicit(&histogram[bucket], 1, memory_order_relaxed);
            output[old] = input[index];
        }
    }
}
memory_order_relaxed); 109 | output[old] = input[global_id]; 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /Sources/Examples/Examples/Broken/RadixSort/RadixSort.swift: -------------------------------------------------------------------------------- 1 | //import AppKit 2 | //import Compute 3 | //import CoreGraphics 4 | //import Foundation 5 | //import Metal 6 | // 7 | //struct RadixSort { 8 | // let device = MTLCreateSystemDefaultDevice()! 9 | // 10 | // func main() throws { 11 | // let capture = false 12 | // 13 | // let values = (0..<1_500_000).map { _ in UInt32.random(in: 0 ... 1000) } 14 | // let expectedResult = timeit("Foundation sort") { values.sorted() } 15 | // let cpuSorted = timeit("CPU Radix Sort") { radixSort(values: values) } 16 | // print("CPU Sorted?", expectedResult == cpuSorted) 17 | // 18 | // let compute = try Compute(device: device, logState: capture ? nil : try device.makeDefaultLogState()) 19 | // let library = ShaderLibrary.bundle(.module, name: "debug") 20 | // 21 | // var input = try device.makeBuffer(bytesOf: values, options: []) 22 | // var output = try device.makeBuffer(bytesOf: Array(repeating: UInt32.zero, count: values.count), options: []) 23 | // let histogram = try device.makeBuffer(bytesOf: Array(repeating: UInt32.zero, count: 256), options: []) 24 | // 25 | // var histogramPass = try compute.makePass(function: library.histogram) 26 | // var prefixSumPass = try compute.makePass(function: library.prefix_sum_exclusive) 27 | // var shufflePass = try compute.makePass(function: library.shuffle2) 28 | // 29 | // try timeit("GPU Radix Sort") { 30 | // try device.capture(enabled: capture) { 31 | // for phase in 0..<4 { 32 | // let shift = UInt32(phase * 8) 33 | // 34 | // histogramPass.arguments.histogram = .buffer(histogram) 35 | // histogramPass.arguments.shift = .int(shift) 36 | // histogramPass.arguments.input = .buffer(input) 37 | // 38 | // try compute.dispatch(label: 
postfix operator ++

extension Int {
    /// C-style post-increment: adds one to `rhs` and returns the value it held beforehand.
    static postfix func ++(rhs: inout Int) -> Int {
        defer { rhs += 1 }
        return rhs
    }
}

/// Scatters `input` into a new array using a precomputed inclusive prefix-sum `histogram`.
///
/// Elements are visited from last to first; each one is written just below its bucket's
/// current cursor, which is decremented first. The bucket is the byte of the element
/// selected by `shift`.
func shuffle(input: [UInt32], shift: Int, histogram: [UInt32]) -> [UInt32] {
    var cursors = histogram.map { Int($0) }
    var scattered = [UInt32](repeating: 0, count: input.count)
    for element in input.reversed() {
        let bucket = (Int(element) >> shift) & 0xFF
        cursors[bucket] -= 1
        scattered[cursors[bucket]] = element
    }
    return scattered
}

/// One counting-sort pass keyed on the byte of each element selected by `shift`.
///
/// Builds a 256-entry histogram, turns it into an inclusive prefix sum, then fills `output`
/// by walking `input` backwards. `output` must already have at least `input.count` elements.
func countingSort(input: [UInt32], shift: Int, output: inout [UInt32]) {
    let digit: (UInt32) -> Int = { (Int($0) >> shift) & 0xFF }

    // Histogram
    var cursors = [Int](repeating: 0, count: 256)
    for element in input {
        cursors[digit(element)] += 1
    }
    // Prefix sum (inclusive)
    for bucket in 1 ..< cursors.count {
        cursors[bucket] += cursors[bucket - 1]
    }
    // Shuffle, back to front
    for element in input.reversed() {
        let bucket = digit(element)
        cursors[bucket] -= 1
        output[cursors[bucket]] = element
    }
}

// From: "A High-Performance Implementation of Counting Sort on CUDA GPU"
/// Counting-sort pass using an exclusive prefix sum and a front-to-back scatter.
///
/// Same contract as `countingSort(input:shift:output:)`: keyed on the byte selected by
/// `shift`, writing into a pre-sized `output`.
func countingSort2(input: [UInt32], shift: Int, output: inout [UInt32]) {
    let digit: (UInt32) -> Int = { (Int($0) >> shift) & 0xFF }

    // Histogram
    var cursors = [Int](repeating: 0, count: 256)
    for element in input {
        cursors[digit(element)] += 1
    }
    // Prefix sum (exclusive): each bucket becomes the index of its first output slot.
    var running = 0
    for bucket in cursors.indices {
        let occurrences = cursors[bucket]
        cursors[bucket] = running
        running += occurrences
    }
    // Shuffle, front to back
    for element in input {
        let bucket = digit(element)
        output[cursors[bucket]] = element
        cursors[bucket] += 1
    }
}

/// Sorts `values` ascending with four counting-sort passes, one per byte from least to
/// most significant, ping-ponging between two arrays.
func radixSort(values: [UInt32]) -> [UInt32] {
    var current = values
    var scratch = [UInt32](repeating: 0, count: values.count)
    for shift in stride(from: 0, to: 32, by: 8) {
        countingSort2(input: current, shift: shift, output: &scratch)
        swap(&current, &scratch)
    }
    return current
}
| 17 | kernel void checkerboard( 18 | texture2d outputTexture [[texture(0)]], 19 | constant float4 &color1 [[buffer(0)]], 20 | constant float4 &color2 [[buffer(1)]], 21 | constant uint2 &cellSize [[buffer(2)]] 22 | ) { 23 | // Get the size of the texture 24 | uint width = outputTexture.get_width(); 25 | uint height = outputTexture.get_height(); 26 | 27 | // Check if the current thread is within the texture bounds 28 | if (gid.x >= width || gid.y >= height) { 29 | return; 30 | } 31 | 32 | // Determine which square this pixel belongs to 33 | uint squareX = gid.x / cellSize.x; 34 | uint squareY = gid.y / cellSize.y; 35 | 36 | // Choose color based on whether the sum of squareX and squareY is even or odd 37 | float4 color = ((squareX + squareY) % 2 == 0) ? color1 : color2; 38 | 39 | // Write the color to the output texture 40 | outputTexture.write(color, gid); 41 | } 42 | """# 43 | 44 | static func main() async throws { 45 | let device = MTLCreateSystemDefaultDevice()! 46 | let logger = Logger() 47 | let compute = try Compute(device: device, logger: logger) 48 | let library = ShaderLibrary.source(source) 49 | 50 | var checkerboard = try compute.makePipeline(function: library.checkerboard) 51 | 52 | let outputTextureDescriptor = MTLTextureDescriptor() 53 | outputTextureDescriptor.width = 1_024 54 | outputTextureDescriptor.height = 1_024 55 | outputTextureDescriptor.pixelFormat = .bgra8Unorm 56 | outputTextureDescriptor.usage = .shaderWrite 57 | guard let outputTexture = device.makeTexture(descriptor: outputTextureDescriptor) else { 58 | throw ComputeError.resourceCreationFailure 59 | } 60 | 61 | checkerboard.arguments.outputTexture = .texture(outputTexture) 62 | checkerboard.arguments.color1 = try .color(.gray) 63 | checkerboard.arguments.color2 = try .color(.mint) 64 | checkerboard.arguments.cellSize = .vector(SIMD2(64, 64)) 65 | try compute.run(pipeline: checkerboard, width: outputTexture.width, height: outputTexture.height) 66 | 67 | let url = URL(filePath: 
"/tmp/checkerboard.png") 68 | try outputTexture.export(to: url) 69 | #if os(macOS) 70 | NSWorkspace.shared.selectFile(url.path, inFileViewerRootedAtPath: url.deletingLastPathComponent().path) 71 | #endif 72 | 73 | // ksdiff /tmp/inverted.png ~/Projects/Compute/Sources/Examples/Resources/Media.xcassets/baboon.imageset/baboon.png 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /Sources/Examples/Examples/CounterDemo.swift: -------------------------------------------------------------------------------- 1 | import Compute 2 | import Metal 3 | import os 4 | 5 | enum CounterDemo: Demo { 6 | static let source = #""" 7 | #include 8 | #include 9 | 10 | using namespace metal; 11 | 12 | uint thread_position_in_grid [[thread_position_in_grid]]; 13 | 14 | kernel void main( 15 | device float *counters [[buffer(0)]], 16 | constant uint &count [[buffer(1)]], 17 | constant uint &step [[buffer(2)]] 18 | 19 | 20 | ) { 21 | for(uint n = 0; n != count; ++n) { 22 | counters[thread_position_in_grid] += float(n); 23 | } 24 | } 25 | """# 26 | 27 | static func main() async throws { 28 | let device = MTLCreateSystemDefaultDevice()! 29 | let logger = Logger() 30 | let compute = try Compute(device: device, logger: logger) 31 | let library = ShaderLibrary.source(source, enableLogging: true) 32 | let numberOfCounters = 100240 33 | let counters = device.makeBuffer(length: MemoryLayout.size * numberOfCounters, options: [])! 
// Game of Life transition rule: given the number of live neighbours (`count`) and whether
// the cell is currently alive, returns whether the cell is alive in the next generation.
// A cell with exactly three live neighbours is always alive next; a live cell also survives
// with exactly two. Every other combination yields a dead cell.
bool rules(uint count, bool alive) {
    return count == 3 || (alive && count == 2);
}
inputTextureSize.y) % inputTextureSize.y; 61 | } 62 | count += inputTexture.read(uint2(position)).r ? 1 : 0; 63 | } 64 | const bool alive = inputTexture.read(gid).r != 0; 65 | outputTexture.write(rules(count, alive), gid); 66 | } 67 | 68 | 69 | [[kernel]] 70 | void gameOfLife_uint( 71 | uint2 gid [[thread_position_in_grid]], 72 | texture2d inputTexture [[texture(0)]], 73 | texture2d outputTexture [[texture(1)]]) 74 | { 75 | gameOfLifeGENERIC(gid, inputTexture, outputTexture, 0, 1); 76 | } 77 | 78 | [[kernel]] 79 | void gameOfLife_float4( 80 | uint2 gid [[thread_position_in_grid]], 81 | texture2d inputTexture [[texture(0)]], 82 | texture2d outputTexture [[texture(1)]]) 83 | { 84 | gameOfLifeGENERIC(gid, inputTexture, outputTexture, float4(0, 0, 0, 0), float4(1, 1, 1, 1)); 85 | } 86 | -------------------------------------------------------------------------------- /Sources/Examples/Examples/GameOfLife/GameOfLife.swift: -------------------------------------------------------------------------------- 1 | #if os(macOS) 2 | import AVFoundation 3 | import Compute 4 | import Foundation 5 | import Metal 6 | import os 7 | 8 | enum GameOfLife: Demo { 9 | 10 | static func main() async throws { 11 | try await run() 12 | } 13 | 14 | static func run(density: Double = 0.5, width: Int = 256, height: Int = 256, frames: Int = 1_200, framesPerSecond: Int = 60) async throws { 15 | let logger: Logger? = Logger() 16 | 17 | // Calculate total number of pixels 18 | let pixelCount = width * height 19 | 20 | // Get the default Metal device 21 | let device = MTLCreateSystemDefaultDevice()! 22 | 23 | // Create and initialize buffer A with random live cells 24 | logger?.log("Creating buffers") 25 | let bufferA = device.makeBuffer(length: pixelCount * MemoryLayout.size, options: [])! 26 | bufferA.contents().withMemoryRebound(to: UInt32.self, capacity: pixelCount) { buffer in 27 | for n in 0...size)! 
39 | textureA.label = "texture-a" 40 | 41 | // Create buffer B and texture B 42 | let bufferB = device.makeBuffer(length: pixelCount * MemoryLayout.size, options: [])! 43 | let textureB = bufferB.makeTexture(descriptor: textureDescriptor, offset: 0, bytesPerRow: width * MemoryLayout.size)! 44 | textureB.label = "texture-b" 45 | 46 | logger?.log("Loading shaders") 47 | // Initialize Compute and ShaderLibrary 48 | let compute = try Compute(device: device) 49 | let library = ShaderLibrary.bundle(.module, name: "debug") 50 | 51 | // Initialize timing variables 52 | var totalComputeTime: UInt64 = 0 53 | var totalEncodeTime: UInt64 = 0 54 | 55 | // Create compute pipeline 56 | var pipeline = try compute.makePipeline(function: library.gameOfLife_float4, constants: ["wrap": .bool(true)]) 57 | 58 | // Set up video writer 59 | let url = FileManager.default.homeDirectoryForCurrentUser.appendingPathComponent("Desktop/GameOfLife.mov") 60 | let movieWriter = try TextureToVideoWriter(outputURL: url, size: CGSize(width: width, height: height)) 61 | movieWriter.start() 62 | 63 | logger?.log("Encoding") 64 | 65 | let time = CMTimeMakeWithSeconds(0, preferredTimescale: 600) 66 | movieWriter.writeFrame(texture: textureA, at: time) 67 | 68 | // Main simulation loop 69 | for frame in 0.. 8 | #include 9 | 10 | using namespace metal; 11 | 12 | kernel void hello_world() { 13 | os_log_default.log("Hello world (from Metal!)"); 14 | } 15 | """# 16 | 17 | static func main() async throws { 18 | let device = MTLCreateSystemDefaultDevice()! 
19 |         let logger = Logger()
20 |         logger.log("Hello world (from Swift!)")
21 |         let compute = try Compute(device: device, logger: logger)
22 |         let library = ShaderLibrary.source(source, enableLogging: true)
23 |         let helloWorld = try compute.makePipeline(function: library.hello_world)
24 |         try compute.run(pipeline: helloWorld, width: 1)
25 |     }
26 | }
27 |
--------------------------------------------------------------------------------
/Sources/Examples/Examples/Histogram.swift:
--------------------------------------------------------------------------------
1 | import Compute
2 | import Metal
3 | import os
4 |
5 | /// Demo: counts how often each value occurs in a buffer of random `UInt32`s on the GPU and compares with a CPU count.
6 | enum Histogram: Demo {
7 |     static let source = #"""
8 |         #include
9 |         #include
10 |
11 |         using namespace metal;
12 |
13 |         uint thread_position_in_grid [[thread_position_in_grid]];
14 |         uint thread_position_in_threadgroup [[thread_position_in_threadgroup]];
15 |
16 |         // Every threadgroup counts into its own `scratch` buckets (threadgroup memory) and
17 |         // then adds those partial counts to the global `buckets` exactly once per group.
18 |         kernel void histogram(
19 |             device uint *input [[buffer(0)]],
20 |             device atomic_uint *buckets [[buffer(1)]],
21 |             constant uint &bucketCount [[buffer(2)]],
22 |             threadgroup atomic_uint *scratch [[threadgroup(0)]]
23 |         ) {
24 |             // Threadgroup memory has undefined contents when the kernel starts, so the
25 |             // scratch buckets are cleared before any thread accumulates into them.
26 |             if (thread_position_in_threadgroup == 0) {
27 |                 for (uint n = 0; n != bucketCount; ++n) {
28 |                     atomic_store_explicit(&scratch[n], 0, memory_order_relaxed);
29 |                 }
30 |             }
31 |             threadgroup_barrier(mem_flags::mem_threadgroup);
32 |
33 |             // Values outside [0, bucketCount) are skipped instead of being used as an
34 |             // index, which would touch memory past the end of `scratch`.
35 |             const uint value = input[thread_position_in_grid];
36 |             if (value < bucketCount) {
37 |                 atomic_fetch_add_explicit(&scratch[value], 1, memory_order_relaxed);
38 |             }
39 |
40 |             // Wait until the whole group has counted, then let thread 0 publish the partial counts.
41 |             threadgroup_barrier(mem_flags::mem_threadgroup);
42 |             if (thread_position_in_threadgroup == 0) {
43 |                 for (uint n = 0; n != bucketCount; ++n) {
44 |                     const uint localCount = atomic_load_explicit(&scratch[n], memory_order_relaxed);
45 |                     atomic_fetch_add_explicit(&buckets[n], localCount, memory_order_relaxed);
46 |                 }
47 |             }
48 |         }
49 |     """#
50 |
51 |     /// Generates random input, runs the `histogram` kernel, and prints the GPU result, the CPU result and whether they match.
52 |     static func main() async throws {
53 |         let device = MTLCreateSystemDefaultDevice()!
54 |         let values: [UInt32] = timeit("Generate random input") {
55 |             (0..<1_000_000).map { n in UInt32.random(in: 0..<20) }
56 |         }
57 |         let input = try device.makeBuffer(bytesOf: values, options: [])
58 |         let bucketCount = 32
59 |         let buckets = device.makeBuffer(length: bucketCount * MemoryLayout.stride, options: [])!
60 |         let compute = try Compute(device: device, logger: Logger())
61 |         let library = ShaderLibrary.source(source, enableLogging: true)
62 |         var histogram = try compute.makePipeline(function: library.histogram)
63 |         histogram.arguments.input = .buffer(input)
64 |         histogram.arguments.buckets = .buffer(buckets)
65 |         histogram.arguments.bucketCount = .int(UInt32(bucketCount))
66 |         histogram.arguments.scratch = .threadgroupMemoryLength(bucketCount * MemoryLayout.stride) // TODO: Align 16.
67 |         try timeit("GPU") {
68 |             try compute.run(pipeline: histogram, threads: [values.count], threadsPerThreadgroup: [histogram.maxTotalThreadsPerThreadgroup])
69 |         }
70 |         let result = Array(buckets)
71 |         let expectedResult = timeit("CPU") {
72 |             values.reduce(into: Array(repeating: 0, count: bucketCount)) { partialResult, value in
73 |                 let index = Int(value)
74 |                 partialResult[index] += 1
75 |             }
76 |         }
77 |
78 |         print(result)
79 |         print(expectedResult)
80 |         print(result == expectedResult)
81 |     }
82 | }
83 |
--------------------------------------------------------------------------------
/Sources/Examples/Examples/ImageInvert.swift:
--------------------------------------------------------------------------------
1 | import Compute
2 | import CoreImage
3 | import Metal
4 | import MetalKit
5 | import os
6 | import UniformTypeIdentifiers
7 |
8 | enum ImageInvert: Demo {
9 |     static let source = #"""
10 |         #include
11 |
12 |         using namespace metal;
13 |
14 |         uint2 gid [[thread_position_in_grid]];
15 |
16 |         kernel void invertImage(
17 |             texture2d inputTexture [[texture(0)]],
18 |             texture2d outputTexture [[texture(1)]]
19 |         ) {
20 |             float4 pixel = inputTexture.read(gid);
21 |             pixel.rgb
= 1.0 - pixel.rgb; 22 | outputTexture.write(pixel, gid); 23 | } 24 | """# 25 | 26 | static func main() async throws { 27 | let device = MTLCreateSystemDefaultDevice()! 28 | let logger = Logger() 29 | let compute = try Compute(device: device, logger: logger) 30 | let library = ShaderLibrary.source(source) 31 | var invertImage = try compute.makePipeline(function: library.invertImage, constants: ["isLinear": .bool(true)]) 32 | 33 | let textureLoader = MTKTextureLoader(device: device) 34 | let inputTexture = try await textureLoader.newTexture(name: "baboon", scaleFactor: 1, bundle: .module, options: [.SRGB: false]) 35 | 36 | let outputTextureDescriptor = MTLTextureDescriptor() 37 | outputTextureDescriptor.width = inputTexture.width 38 | outputTextureDescriptor.height = inputTexture.height 39 | outputTextureDescriptor.pixelFormat = .bgra8Unorm 40 | outputTextureDescriptor.usage = .shaderWrite 41 | guard let outputTexture = device.makeTexture(descriptor: outputTextureDescriptor) else { 42 | throw ComputeError.resourceCreationFailure 43 | } 44 | 45 | invertImage.arguments.inputTexture = .texture(inputTexture) 46 | invertImage.arguments.outputTexture = .texture(outputTexture) 47 | try compute.run(pipeline: invertImage, width: inputTexture.width, height: inputTexture.height) 48 | 49 | try outputTexture.export(to: URL(filePath: "/tmp/inverted.png")) 50 | 51 | // ksdiff /tmp/inverted.png ~/Projects/Compute/Sources/Examples/Resources/Media.xcassets/baboon.imageset/baboon.png 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /Sources/Examples/Examples/IsSorted.swift: -------------------------------------------------------------------------------- 1 | import Compute 2 | import Metal 3 | import os 4 | 5 | enum IsSorted: Demo { 6 | static let source = #""" 7 | #include 8 | #include 9 | 10 | using namespace metal; 11 | 12 | uint thread_position_in_threadgroup [[thread_position_in_threadgroup]]; 13 | uint threadgroup_position_in_grid 
[[threadgroup_position_in_grid]];
14 |         uint thread_position_in_grid [[thread_position_in_grid]];
15 |         uint threads_per_threadgroup [[threads_per_threadgroup]];
16 |         // One thread per element: clears *isSorted when the pair (i, i + 1) owned by this thread is out of order.
17 |         kernel void is_sorted(
18 |             device const float* input [[buffer(0)]],
19 |             constant uint& count [[buffer(1)]],
20 |             device atomic_uint* isSorted [[buffer(2)]]
21 |         ) {
22 |             //os_log_default.log("thread_position_in_grid: %d, threads_per_grid: %f", thread_position_in_grid, input[thread_position_in_grid]);
23 |             if (thread_position_in_grid + 1 < count && input[thread_position_in_grid] > input[thread_position_in_grid + 1]) {
24 |                 //os_log_default.log("IS NOT SORTED %d %d", thread_position_in_grid, thread_position_in_grid + 1);
25 |                 atomic_store_explicit(isSorted, 0, memory_order_relaxed);
26 |             }
27 |         }
28 |
29 |         // Same check with threads addressed through their threadgroup; `isSorted` is a 32-bit flag, matching is_sorted.
30 |         kernel void is_sorted_complex(
31 |             device const float* input [[buffer(0)]],
32 |             constant uint& count [[buffer(1)]],
33 |             device atomic_uint* isSorted [[buffer(2)]]
34 |         ) {
35 |             uint groupStart = threadgroup_position_in_grid * threads_per_threadgroup;
36 |
37 |             // Start one before if not the first group, so the pair straddling the previous group is checked too
38 |             uint start = (groupStart == 0 ? 0 : groupStart - 1) + thread_position_in_threadgroup;
39 |             uint end = count == 0 ? 0 : metal::min(count - 1, groupStart + threads_per_threadgroup);
40 |
41 |             for (uint i = start; i < end; i += threads_per_threadgroup) {
42 |                 if (input[i] > input[i + 1]) {
43 |                     atomic_store_explicit(isSorted, 0, metal::memory_order_relaxed);
44 |                 }
45 |             }
46 |         }
47 |     """#
48 |     /// Runs the simple and the threadgroup-based variant one after the other.
49 |     static func main() async throws {
50 |         try await simple()
51 |         try await complex()
52 |     }
53 |     /// Runs `is_sorted` with one thread per element and prints the resulting flag (1 = sorted, 0 = not sorted).
54 |     static func simple() async throws {
55 |         let capture = false
56 |
57 |         let device = MTLCreateSystemDefaultDevice()!
58 |         try device.capture(enabled: capture) {
59 |             let logger = Logger()
60 |             let compute = try Compute(device: device, logger: logger, useLogState: !capture)
61 |             let library = ShaderLibrary.source(source, enableLogging: !capture)
62 |             var values = (0..<1_000_000).map { Float($0) }.sorted()
63 |             // Swap two neighbours so the input is deliberately NOT sorted.
64 |             values.swapAt(1, 2)
65 |             // The flag starts at 1 and the kernel only ever clears it.
66 |             let isSorted = try device.makeBuffer(bytesOf: UInt32(1), options: [.storageModeShared])
67 |             isSorted.label = "isSorted"
68 |
69 |             var pipeline = try compute.makePipeline(function: library.is_sorted)
70 |             pipeline.arguments.input = .buffer(values, label: "input")
71 |             pipeline.arguments.count = .int(values.count)
72 |             pipeline.arguments.isSorted = .buffer(isSorted)
73 |
74 |             try compute.run(pipeline: pipeline, threads: [values.count, 1, 1], threadsPerThreadgroup: [pipeline.maxTotalThreadsPerThreadgroup, 1, 1])
75 |
76 |             print(isSorted.contentsBuffer(of: UInt32.self)[0])
77 |         }
78 |
79 |     }
80 |     /// Runs `is_sorted_complex` with one thread per element and prints the resulting flag (1 = sorted, 0 = not sorted).
81 |     static func complex() async throws {
82 |         let capture = false
83 |
84 |         let device = MTLCreateSystemDefaultDevice()!
85 |         try device.capture(enabled: capture) {
86 |             let logger = Logger()
87 |             let compute = try Compute(device: device, logger: logger, useLogState: !capture)
88 |             let library = ShaderLibrary.source(source, enableLogging: !capture)
89 |             var values = (0..<2_000).map { Float($0) }.sorted()
90 |             // Swap two neighbours so the input is deliberately NOT sorted.
91 |             values.swapAt(1, 2)
92 |             // The flag starts at 1 and the kernel only ever clears it.
93 |             let isSorted = try device.makeBuffer(bytesOf: UInt32(1), options: [.storageModeShared])
94 |             isSorted.label = "isSortedComplex"
95 |
96 |             var pipeline = try compute.makePipeline(function: library.is_sorted_complex)
97 |             pipeline.arguments.input = .buffer(values, label: "input")
98 |             pipeline.arguments.count = .int(values.count)
99 |             pipeline.arguments.isSorted = .buffer(isSorted)
100 |
101 |             let threadsPerThreadgroup = pipeline.maxTotalThreadsPerThreadgroup
102 |             // `threads:` is the total thread count (as in simple()), not the number of threadgroups.
103 |             try compute.run(pipeline: pipeline, threads: [values.count, 1, 1], threadsPerThreadgroup: [threadsPerThreadgroup, 1, 1])
104 |
105 |             print(isSorted.contentsBuffer(of: UInt32.self)[0])
106 |         }
107 |
108 |     }
109 | }
110 |
--------------------------------------------------------------------------------
/Sources/Examples/Examples/MaxParallel.swift:
--------------------------------------------------------------------------------
1 | import Compute
2 | import Metal
3 | import os
4 |
5 | enum MaxParallel: Demo {
6 |     static let source = #"""
7 |         #include
8 |         #include
9 |
10 |         using namespace metal;
11 |
12 |         uint thread_position_in_grid [[thread_position_in_grid]];
13 |
14 |         kernel void kernel_main(
15 |             device atomic_uint &count [[buffer(0)]],
16 |             device atomic_uint &maximum [[buffer(1)]]
17 |         ) {
18 |             uint current = atomic_fetch_add_explicit(&count, 1, memory_order_relaxed);
19 |             atomic_fetch_max_explicit(&maximum, current, memory_order_relaxed);
20 |             atomic_fetch_add_explicit(&count, -1, memory_order_relaxed);
21 |         }
22 |     """#
23 |
24 |     static func main() async throws {
25 |         let device = MTLCreateSystemDefaultDevice()!
26 | let compute = try Compute(device: device) 27 | let library = ShaderLibrary.source(source, enableLogging: true) 28 | let count = device.makeBuffer(length: MemoryLayout.size * 1, options: [])! 29 | let maximum = device.makeBuffer(length: MemoryLayout.size * 1, options: [])! 30 | var pipeline = try compute.makePipeline(function: library.kernel_main) 31 | pipeline.arguments.count = .buffer(count) 32 | pipeline.arguments.maximum = .buffer(maximum) 33 | 34 | 35 | print("thread width,\tmaximum,\ttime (ms)") 36 | for n in 0...31 { 37 | let count = Int(pow(2, Double(n+1))) - 1 38 | let time = try timed() { 39 | try compute.run(pipeline: pipeline, width: count) 40 | } 41 | print("\(count),\t\(Array(maximum)[0]),\t\(Double(time) / Double(1_000_000))") 42 | } 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /Sources/Examples/Examples/MaxValue.swift: -------------------------------------------------------------------------------- 1 | import Compute 2 | import CoreImage 3 | import Metal 4 | import MetalKit 5 | import os 6 | import UniformTypeIdentifiers 7 | 8 | enum MaxValue: Demo { 9 | 10 | // Finds the maximum value in an array using a single-threaded Metal compute shader. 11 | // This method is very inefficient and is provided as an example of a suboptimal approach. 12 | static func badIdea(values: [Int32], expectedValue: Int32) throws { 13 | let source = #""" 14 | #include 15 | 16 | using namespace metal; 17 | 18 | kernel void maxValue( 19 | const device uint *input [[buffer(0)]], 20 | constant uint &count [[buffer(1)]], 21 | device uint &output [[buffer(2)]] 22 | ) { 23 | uint temp = 0; 24 | for (uint n = 0; n != count; ++n) { 25 | temp = max(temp, input[n]); 26 | } 27 | output = temp; 28 | } 29 | """# 30 | 31 | let device = MTLCreateSystemDefaultDevice()! 32 | let input = device.makeBuffer(bytes: values, length: MemoryLayout.stride * values.count, options: [])! 
33 | let output = device.makeBuffer(length: MemoryLayout.size)! 34 | let compute = try Compute(device: device) 35 | let library = ShaderLibrary.source(source) 36 | var maxValue = try compute.makePipeline(function: library.maxValue) 37 | maxValue.arguments.input = .buffer(input) 38 | maxValue.arguments.count = .int(UInt32(values.count)) 39 | maxValue.arguments.output = .buffer(output) 40 | try timeit(#function) { 41 | try compute.run(pipeline: maxValue, width: 1) 42 | } 43 | let result = output.contents().assumingMemoryBound(to: Int32.self)[0] 44 | assert(result == expectedValue) 45 | } 46 | 47 | // Finds the maximum value in an array using an atomic operation in a Metal compute shader. 48 | // Despite all threads fighting over a single lock this version is still extremely fast. 49 | static func simpleAtomic(values: [Int32], expectedValue: Int32) throws { 50 | let source = #""" 51 | #include 52 | 53 | using namespace metal; 54 | 55 | uint thread_position_in_grid [[thread_position_in_grid]]; 56 | 57 | kernel void maxValue( 58 | const device uint *input [[buffer(0)]], 59 | device atomic_uint *output [[buffer(1)]] 60 | ) { 61 | const uint value = input[thread_position_in_grid]; 62 | atomic_fetch_max_explicit(output, value, memory_order_relaxed); 63 | } 64 | """# 65 | let device = MTLCreateSystemDefaultDevice()! 66 | let input = device.makeBuffer(bytes: values, length: MemoryLayout.stride * values.count, options: [])! 67 | let output = device.makeBuffer(length: MemoryLayout.size)! 
68 |         let compute = try Compute(device: device)
69 |         let library = ShaderLibrary.source(source)
70 |         var maxValue = try compute.makePipeline(function: library.maxValue)
71 |         maxValue.arguments.input = .buffer(input)
72 |         maxValue.arguments.output = .buffer(output)
73 |         try timeit(#function) {
74 |             try compute.run(pipeline: maxValue, width: values.count)
75 |         }
76 |         let result = output.contents().assumingMemoryBound(to: Int32.self)[0]
77 |         assert(result == expectedValue)
78 |     }
79 |
80 |
81 |     // Finds the maximum value in an array using a multi-pass approach in a Metal compute shader.
82 |     // This method uses SIMD group operations for efficient parallel processing.
83 |     // Note: this method is destructive and intermediate values are written to the input buffer.
84 |     // NOTE(review): every pass must cover ceil(count / stride) elements; rounding that down skipped the tail of the array, the likely cause of the occasional failures noted here before. Confirm on hardware.
85 |     // NOTE: Does NOT seem to be failing on M1 Ultra. 8/30/2024.
86 |     static func multipass(values: [Int32], expectedValue: Int32) throws {
87 |         let source = #"""
88 |             #include
89 |
90 |             using namespace metal;
91 |
92 |             // Get the global thread position in the execution grid
93 |             uint thread_position_in_grid [[thread_position_in_grid]];
94 |
95 |             // Get the number of threads per SIMD group
96 |             uint threads_per_simdgroup [[threads_per_simdgroup]];
97 |
98 |             kernel void maxValue(
99 |                 device int *input [[buffer(0)]], // Input/output buffer
100 |                 constant uint &count [[buffer(1)]], // Total number of elements
101 |                 constant uint &stride [[buffer(2)]] // Stride between elements processed by each thread
102 |             ) {
103 |                 // Calculate the index for this thread
104 |                 const uint index = thread_position_in_grid * stride;
105 |
106 |                 // Get the value for this thread, or INT_MIN if out of bounds (kept signed so INT_MIN really is the smallest value)
107 |                 int localValue = index >= count ? INT_MIN : input[index];
108 |
109 |                 // Perform a parallel reduction to find the maximum value
110 |                 for (uint offset = threads_per_simdgroup >> 1; offset > 0; offset >>= 1) {
111 |                     // Get the value from another thread in the SIMD group
112 |                     const int remoteValue = simd_shuffle_down(localValue, offset);
113 |
114 |                     // Update the current value with the maximum of current and remote
115 |                     localValue = max(localValue, remoteValue);
116 |                 }
117 |
118 |                 // Only the first thread in each SIMD group writes the result, and never past the end of the buffer
119 |                 if (simd_is_first() && index < count) {
120 |                     input[index] = localValue;
121 |                 }
122 |             }
123 |         """#
124 |         let device = MTLCreateSystemDefaultDevice()!
125 |         let input = device.makeBuffer(bytes: values, length: MemoryLayout.stride * values.count, options: [])!
126 |         let compute = try Compute(device: device)
127 |         let library = ShaderLibrary.source(source)
128 |         var pipeline = try compute.makePipeline(function: library.maxValue)
129 |         pipeline.arguments.input = .buffer(input)
130 |
131 |         // This is equivalent to `threads_per_simdgroup` in MSL.
132 |         let threadExecutionWidth = pipeline.computePipelineState.threadExecutionWidth
133 |         assert(threadExecutionWidth == 32)
134 |         let maxTotalThreadsPerThreadgroup = pipeline.computePipelineState.maxTotalThreadsPerThreadgroup
135 |
136 |         try timeit(#function) {
137 |
138 |             // Initialize the stride (between elements processed by each thread) to 1
139 |             var stride = 1
140 |
141 |             try compute.task { task in
142 |                 try task { dispatch in
143 |                     // Continue looping while the stride is less than or equal to the total count of values
144 |                     while stride <= values.count {
145 |                         // Set the 'count' argument for the compute pipeline to the total number of values
146 |                         pipeline.arguments.count = .int(Int32(values.count))
147 |
148 |                         // Set the 'stride' argument for the compute pipeline
149 |                         pipeline.arguments.stride = .int(UInt32(stride))
150 |                         // One thread per remaining element (rounded up so a partial trailing stride is not dropped), padded to whole SIMD groups
151 |                         let threadCount = ((values.count + stride - 1) / stride + threadExecutionWidth - 1) / threadExecutionWidth * threadExecutionWidth
152 |                         try dispatch(
153 |                             pipeline: pipeline,
154 |                             // Padding threads past the end contribute INT_MIN in the kernel, so they never win the max
155 |                             threads: MTLSize(width: threadCount, height: 1, depth: 1),
156 |                             // Set the number of threads per threadgroup to the maximum allowed
157 |                             threadsPerThreadgroup: MTLSize(width: maxTotalThreadsPerThreadgroup, height: 1, depth: 1)
158 |                         )
159 |
160 |                         // Increase the stride by multiplying it with the thread execution width
161 |                         // This effectively reduces the number of active threads in each iteration
162 |                         stride *= threadExecutionWidth
163 |                     }
164 |                 }
165 |             }
166 |         }
167 |
168 |         let result = input.contents().assumingMemoryBound(to: Int32.self)[0]
169 |         assert(result == expectedValue)
170 |     }
171 |
172 |     static func main() async throws {
173 |         // var values = Array(Array(repeating: Int32.zero, count: 1000))
174 |         let expectedValue: Int32 = 123456789
175 |
176 |
177 |         #if os(macOS)
178 |         let count: Int32 = 1_000_000
179 |         #else
180 |         let count: Int32 = 1_000_000
181 |         #endif
182 |
183 |         let values = timeit("Generating \(count) values") {
184 |             var values = Array(Int32.zero ..<
count) 185 | values[Int.random(in: 0.. 8 | 9 | using namespace metal; 10 | 11 | // Thread position in the execution grid 12 | uint thread_position_in_grid [[thread_position_in_grid]]; 13 | 14 | // Empty kernel for baseline performance measurement 15 | kernel void empty() 16 | { 17 | } 18 | 19 | // Kernel to fill buffer with thread positions 20 | kernel void fill( 21 | device uint* output [[buffer(0)]] // Output buffer 22 | ) 23 | { 24 | output[thread_position_in_grid] = thread_position_in_grid; 25 | } 26 | 27 | // Kernel to copy data from input buffer to output buffer 28 | kernel void memcpy( 29 | const device uint* input [[buffer(0)]], // Input buffer 30 | device uint* output [[buffer(1)]] // Output buffer 31 | ) 32 | { 33 | output[thread_position_in_grid] = input[thread_position_in_grid]; 34 | } 35 | """# 36 | 37 | static func main() async throws { 38 | // Get the default Metal device 39 | let device = MTLCreateSystemDefaultDevice()! 40 | // Set count to maximum value of UInt32 41 | let count = Int(UInt32.max) 42 | // Calculate the length of the buffers in bytes 43 | let length = MemoryLayout.stride * count 44 | 45 | // Print the size of the buffers in gigabytes 46 | print(Measurement(value: Double(length), unit: .bytes).converted(to: .gibibytes)) 47 | 48 | print("# Allocating") 49 | // Create two Metal buffers of the calculated length 50 | let bufferA = device.makeBuffer(length: length)! 51 | let bufferB = device.makeBuffer(length: length)! 
52 | 53 | print("# Preparing") 54 | // Create a Compute object with the Metal device 55 | let compute = try Compute(device: device) 56 | // Create a shader library from the source code 57 | let library = ShaderLibrary.source(source) 58 | // Create compute pipelines for each kernel function 59 | let empty = try compute.makePipeline(function: library.empty) 60 | let fill = try compute.makePipeline(function: library.fill, arguments: ["output": .buffer(bufferA)]) 61 | let memcopy = try compute.makePipeline(function: library.memcpy, arguments: ["input": .buffer(bufferA), "output": .buffer(bufferB)]) 62 | 63 | print("# Empty") 64 | // Run and time the empty kernel (baseline) 65 | try timeit(length: length) { 66 | try compute.run(pipeline: empty, width: count) 67 | } 68 | 69 | print("# Filling") 70 | // Run and time the fill kernel 71 | try timeit(length: length) { 72 | try compute.run(pipeline: fill, width: count) 73 | } 74 | 75 | print("# GPU memcpy") 76 | // Run and time the GPU memcpy kernel 77 | try timeit(length: length) { 78 | try compute.run(pipeline: memcopy, width: count) 79 | } 80 | 81 | print("# CPU memcpy") 82 | // Run and time CPU memcpy for comparison 83 | timeit(length: length) { 84 | memcpy(bufferB.contents(), bufferA.contents(), length) 85 | } 86 | 87 | print("# DONE") 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /Sources/Examples/Examples/RandomFill.swift: -------------------------------------------------------------------------------- 1 | import Compute 2 | import CoreGraphics 3 | import Foundation 4 | import Metal 5 | import MetalKit 6 | 7 | // swiftlint:disable force_unwrapping 8 | 9 | struct RandomFill: Demo { 10 | static let source = #""" 11 | #include 12 | #include 13 | 14 | using namespace metal; 15 | 16 | float random(float2 p) 17 | { 18 | // We need irrationals for pseudo randomness. 19 | // Most (all?) known transcendental numbers will (generally) work. 
20 | const float2 r = float2( 21 | 23.1406926327792690, // e^pi (Gelfond's constant) 22 | 2.6651441426902251); // 2^sqrt(2) (Gelfond–Schneider constant) 23 | return fract(cos(fmod(123456789.0, 1e-7 + 256.0 * dot(p,r)))); 24 | } 25 | 26 | uint2 thread_position_in_grid [[thread_position_in_grid]]; 27 | 28 | [[kernel]] 29 | void randomFill_float(texture2d outputTexture [[texture(0)]]) 30 | { 31 | const float2 id = float2(thread_position_in_grid); 32 | 33 | float value = random(id) > 0.5 ? 1 : 0; 34 | 35 | float4 color = { value, value, value, 1 }; 36 | 37 | outputTexture.write(color, thread_position_in_grid); 38 | } 39 | """# 40 | 41 | static func main() async throws { 42 | let device = MTLCreateSystemDefaultDevice()! 43 | 44 | let outputTextureDescriptor = MTLTextureDescriptor() 45 | outputTextureDescriptor.width = 1_024 46 | outputTextureDescriptor.height = 1_024 47 | outputTextureDescriptor.pixelFormat = .bgra8Unorm 48 | outputTextureDescriptor.usage = .shaderWrite 49 | guard let outputTexture = device.makeTexture(descriptor: outputTextureDescriptor) else { 50 | throw ComputeError.resourceCreationFailure 51 | } 52 | 53 | let compute = try Compute(device: device) 54 | 55 | let library = ShaderLibrary.source(source) 56 | 57 | var pipeline = try compute.makePipeline(function: library.randomFill_float) 58 | pipeline.arguments.outputTexture = .texture(outputTexture) 59 | 60 | try compute.task { task in 61 | try task { dispatch in 62 | try dispatch(pipeline: pipeline, threadgroupsPerGrid: MTLSize(width: outputTexture.width, height: outputTexture.height, depth: 1), threadsPerThreadgroup: MTLSize(width: 1, height: 1, depth: 1)) 63 | } 64 | } 65 | 66 | let url = URL(filePath: "/tmp/randomfill.png") 67 | try outputTexture.export(to: url) 68 | #if os(macOS) 69 | NSWorkspace.shared.selectFile(url.path, inFileViewerRootedAtPath: url.deletingLastPathComponent().path) 70 | #endif 71 | 72 | 73 | 74 | } 75 | } 76 | 
-------------------------------------------------------------------------------- /Sources/Examples/Examples/SIMDReduce.swift: -------------------------------------------------------------------------------- 1 | import Compute 2 | import os 3 | import Metal 4 | 5 | let source = #""" 6 | #include 7 | #include 8 | 9 | using namespace metal; 10 | 11 | // Thread-specific attributes 12 | uint thread_position_in_grid [[thread_position_in_grid]]; // Position of the current thread in the grid 13 | uint threads_per_simdgroup [[threads_per_simdgroup]]; // Number of threads in a SIMD group 14 | uint threadgroup_position_in_grid [[threadgroup_position_in_grid]]; // Position of the current threadgroup in the grid 15 | 16 | // Kernel function for parallel reduction sum 17 | kernel void parallel_reduction_sum( 18 | constant uint* input [[buffer(0)]], // Input buffer 19 | device uint* output [[buffer(1)]] // Output buffer 20 | ) 21 | { 22 | // Get the value for the current thread 23 | uint value = input[thread_position_in_grid]; 24 | 25 | // Perform parallel reduction within SIMD group 26 | for (uint offset = threads_per_simdgroup / 2; offset > 0; offset >>= 1) { 27 | // Add the value from the thread 'offset' positions ahead 28 | value += simd_shuffle_and_fill_down(value, 0u, offset); 29 | } 30 | 31 | // Only the first thread in each SIMD group writes the result 32 | if (simd_is_first()) { 33 | output[thread_position_in_grid / threads_per_simdgroup] = value; 34 | } 35 | } 36 | """# 37 | 38 | struct SIMDReduce: Demo { 39 | 40 | static func main() async throws { 41 | try dualBuffer() 42 | } 43 | 44 | static func dualBuffer() throws { 45 | let device = MTLCreateSystemDefaultDevice()! 46 | // Create N values and sum them up in the CPU... 47 | var count = 50_000_000 48 | let values = timeit("Generating \(count) values") { 49 | (0.. 
1 84 | } 85 | } 86 | } 87 | let result = Array(output.contentsBuffer(of: UInt32.self))[0] 88 | print("Compute result:", expectedResult, result, result == expectedResult) 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /Sources/Examples/Examples/ThreadgroupLog.swift: -------------------------------------------------------------------------------- 1 | import Compute 2 | import Metal 3 | import os 4 | 5 | enum ThreadgroupLogging: Demo { 6 | static let source = #""" 7 | #include 8 | #include 9 | 10 | const uint thread_position_in_grid [[thread_position_in_grid]]; 11 | const uint threadgroup_position_in_grid [[threadgroup_position_in_grid]]; 12 | const uint thread_position_in_threadgroup [[thread_position_in_threadgroup]]; 13 | 14 | const uint threads_per_grid [[threads_per_grid]]; 15 | const uint threads_per_threadgroup [[threads_per_threadgroup]]; 16 | const uint threadgroups_per_grid [[threadgroups_per_grid]]; 17 | 18 | using namespace metal; 19 | 20 | kernel void threadgroup_test() { 21 | if (thread_position_in_grid == 0) { 22 | os_log_default.log("threads_per_grid: %d, threads_per_threadgroup: %d, threadgroups_per_grid: %d", threads_per_grid, threads_per_threadgroup, threadgroups_per_grid); 23 | } 24 | os_log_default.log("thread_position_in_grid: %d, thread_position_in_threadgroup: %d, threadgroup_position_in_grid: %d", thread_position_in_grid, thread_position_in_threadgroup, threadgroup_position_in_grid); 25 | } 26 | """# 27 | 28 | static func main() async throws { 29 | let device = MTLCreateSystemDefaultDevice()! 
30 | let compute = try Compute(device: device, logger: Logger()) 31 | let library = ShaderLibrary.source(source, enableLogging: true) 32 | let pipeline = try compute.makePipeline(function: library.threadgroup_test) 33 | try compute.run(pipeline: pipeline, threadgroupsPerGrid: MTLSize(width: 3), threadsPerThreadgroup: MTLSize(width: 2)) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /Sources/Examples/Resources/Media.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Sources/Examples/Resources/Media.xcassets/baboon-acorn-inverted.imageset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "filename" : "baboon-acorn-inverted.png", 5 | "idiom" : "universal" 6 | } 7 | ], 8 | "info" : { 9 | "author" : "xcode", 10 | "version" : 1 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /Sources/Examples/Resources/Media.xcassets/baboon-acorn-inverted.imageset/baboon-acorn-inverted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/schwa/Compute/b1ba1e1071c912007bdbb59313e175887335fab8/Sources/Examples/Resources/Media.xcassets/baboon-acorn-inverted.imageset/baboon-acorn-inverted.png -------------------------------------------------------------------------------- /Sources/Examples/Resources/Media.xcassets/baboon.imageset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "filename" : "baboon.png", 5 | "idiom" : "universal" 6 | } 7 | ], 8 | "info" : { 9 | "author" : "xcode", 10 | "version" : 1 11 | } 12 | } 13 | 
-------------------------------------------------------------------------------- /Sources/Examples/Resources/Media.xcassets/baboon.imageset/baboon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/schwa/Compute/b1ba1e1071c912007bdbb59313e175887335fab8/Sources/Examples/Resources/Media.xcassets/baboon.imageset/baboon.png -------------------------------------------------------------------------------- /Sources/Examples/Support.swift: -------------------------------------------------------------------------------- 1 | import AVFoundation 2 | import Compute 3 | import CoreGraphics 4 | import Foundation 5 | import Metal 6 | import MetalSupportLite 7 | 8 | public func getMachTimeInNanoseconds() -> UInt64 { 9 | var timebase = mach_timebase_info_data_t() 10 | mach_timebase_info(&timebase) 11 | let currentTime = mach_absolute_time() 12 | return currentTime * UInt64(timebase.numer) / UInt64(timebase.denom) 13 | } 14 | 15 | @discardableResult 16 | public func timed(_ work: () throws -> Void) rethrows -> UInt64 { 17 | let start = getMachTimeInNanoseconds() 18 | try work() 19 | let end = getMachTimeInNanoseconds() 20 | return end - start 21 | } 22 | 23 | @discardableResult 24 | public func timeit(_ work: () throws -> R, display: (UInt64) -> Void) rethrows -> R { 25 | let start = getMachTimeInNanoseconds() 26 | let result = try work() 27 | let end = getMachTimeInNanoseconds() 28 | display(end - start) 29 | return result 30 | } 31 | 32 | @discardableResult 33 | public func timeit(_ label: String? = nil, _ work: () throws -> R) rethrows -> R { 34 | try timeit(work) { delta in 35 | let measurement = Measurement(value: Double(delta), unit: UnitDuration.nanoseconds) 36 | let measurementMS = measurement.converted(to: .milliseconds) 37 | print("\(label ?? ""): \(measurementMS.formatted())") 38 | } 39 | } 40 | 41 | @discardableResult 42 | public func timeit(_ label: String? 
= nil, length: Int, _ work: () throws -> R) rethrows -> R {
43 |     try timeit(work) { delta in
44 |         let seconds = Double(delta) / 1_000_000_000
45 |         let bytesPerSecond = Double(length) / seconds
46 |         let gigabytesPerSecond = Measurement(value: bytesPerSecond, unit: UnitInformationStorage.bytes)
47 |             .converted(to: .gigabytes)
48 |         print("Time: \(Measurement(value: Double(seconds), unit: UnitDuration.seconds).converted(to: .milliseconds).formatted())")
49 |         print("Speed: \(gigabytesPerSecond.formatted(.measurement(width: .abbreviated, usage: .asProvided)))/s")
50 |     }
51 | }
52 | /// Encodes a sequence of Metal textures as HEVC frames of a QuickTime movie.
53 | class TextureToVideoWriter {
54 |     private var assetWriter: AVAssetWriter
55 |     private var writerInput: AVAssetWriterInput?
56 |     private var adaptor: AVAssetWriterInputPixelBufferAdaptor?
57 |
58 |     let outputURL: URL
59 |     let temporaryURL: URL
60 |     private let size: CGSize
61 |     private let pixelFormat = kCVPixelFormatType_32BGRA
62 |
63 |     var endTime: CMTime?
64 |     /// Creates a writer that records into a temporary file; `finish()` moves that file to `outputURL`.
65 |     init(outputURL: URL, size: CGSize) throws {
66 |         self.outputURL = outputURL
67 |         let temporaryURL = FileManager.default.temporaryDirectory.appendingPathComponent("\(UUID().uuidString).mov")
68 |         self.temporaryURL = temporaryURL
69 |         self.size = size
70 |
71 |         assetWriter = try AVAssetWriter(outputURL: temporaryURL, fileType: .mov)
72 |     }
73 |     /// Sets up the video input and pixel-buffer adaptor, then starts the writing session at time zero.
74 |     func start() {
75 |         let settings: [String: Any] = [
76 |             AVVideoCodecKey: AVVideoCodecType.hevc,
77 |             AVVideoWidthKey: size.width,
78 |             AVVideoHeightKey: size.height
79 |         ]
80 |
81 |         writerInput = AVAssetWriterInput(mediaType: .video, outputSettings: settings)
82 |         writerInput?.expectsMediaDataInRealTime = true
83 |
84 |         adaptor = AVAssetWriterInputPixelBufferAdaptor(
85 |             assetWriterInput: writerInput!,
86 |             sourcePixelBufferAttributes: [
87 |                 kCVPixelBufferPixelFormatTypeKey as String: pixelFormat,
88 |                 kCVPixelBufferWidthKey as String: Int(size.width),
89 |                 kCVPixelBufferHeightKey as String: Int(size.height)
90 |             ]
91 |         )
92 |
93 |         if assetWriter.canAdd(writerInput!) {
94 |             assetWriter.add(writerInput!)
95 |         }
96 |
97 |         assetWriter.startWriting()
98 |         assetWriter.startSession(atSourceTime: .zero)
99 |     }
100 |     /// Copies `texture` into a pooled pixel buffer and appends it as the frame at `time`.
101 |     func writeFrame(texture: MTLTexture, at time: CMTime) {
102 |         autoreleasepool {
103 |             guard let adaptor, let pixelBufferPool = adaptor.pixelBufferPool else {
104 |                 fatalError("No pixel buffer pool available; start() must be called before writing frames")
105 |             }
106 |             var pixelBuffer: CVPixelBuffer?
107 |             CVPixelBufferPoolCreatePixelBuffer(nil, pixelBufferPool, &pixelBuffer)
108 |             guard let pixelBuffer else {
109 |                 fatalError("Could not create a pixel buffer from the pool")
110 |             }
111 |             CVPixelBufferLockBaseAddress(pixelBuffer, []) // not .readOnly: getBytes below writes into the buffer
112 |             guard let baseAddress = CVPixelBufferGetBaseAddress(pixelBuffer) else {
113 |                 fatalError("Pixel buffer has no base address")
114 |             }
115 |             let region = MTLRegionMake2D(0, 0, texture.width, texture.height)
116 |             texture.getBytes(baseAddress, bytesPerRow: CVPixelBufferGetBytesPerRow(pixelBuffer), from: region, mipmapLevel: 0)
117 |             CVPixelBufferUnlockBaseAddress(pixelBuffer, []) // must use the same flags as the lock call
118 |             writeFrame(pixelBuffer: pixelBuffer, at: time)
119 |             endTime = time
120 |         }
121 |     }
122 |     /// Appends `pixelBuffer` at `time`, polling until the input accepts more data.
123 |     func writeFrame(pixelBuffer: CVPixelBuffer, at time: CMTime) {
124 |         autoreleasepool {
125 |             guard let writerInput, let adaptor else {
126 |                 fatalError("start() must be called before writing frames")
127 |             }
128 |             if writerInput.isReadyForMoreMediaData == false {
129 |                 // This isn't pretty but it works?
130 |                 while writerInput.isReadyForMoreMediaData == false {
131 |                     usleep(10 * 1_000)
132 |                 }
133 |             }
134 |             adaptor.append(pixelBuffer, withPresentationTime: time)
135 |         }
136 |     }
137 |     /// Finishes the movie and replaces any existing file at `outputURL` with it; throws if writing or moving failed.
138 |     func finish() async throws {
139 |         writerInput?.markAsFinished()
140 |         if let endTime { assetWriter.endSession(atSourceTime: endTime) } // endTime is nil when no frame was written
141 |         await assetWriter.finishWriting()
142 |         // Surface a writer failure instead of moving an unusable temporary file into place.
143 |         if let error = assetWriter.error { throw error }
144 |         let fileManager = FileManager.default
145 |         if fileManager.fileExists(atPath: outputURL.path) { try fileManager.removeItem(at: outputURL) }
146 |         try fileManager.moveItem(at: temporaryURL, to: outputURL)
147 |     }
148 | }
149 | /// PNG export for `.bgra8Unorm` textures with a depth of 1.
150 | extension MTLTexture {
151 |     func export(to url: URL) throws {
152 |         assert(pixelFormat == .bgra8Unorm)
153 |         assert(depth == 1)
154 |
155 |         let bytesPerRow = width * MemoryLayout.size * 4
156 |         guard let colorSpace = CGColorSpace(name: CGColorSpace.genericRGBLinear) else {
157 |             throw ComputeError.resourceCreationFailure
158 |         }
159 |         let bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.premultipliedFirst.rawValue)
160 |             .union(.byteOrder32Little)
161 |         guard let context = CGContext(data: nil, width: width, height: height, bitsPerComponent: 8, bytesPerRow: bytesPerRow, space: colorSpace, bitmapInfo: bitmapInfo.rawValue) else {
162 |             throw ComputeError.resourceCreationFailure
163 |         }
164 |         guard let contextData = context.data else {
165 |             throw ComputeError.resourceCreationFailure
166 |         }
167 |         getBytes(contextData, bytesPerRow: bytesPerRow, from: MTLRegion(origin: MTLOrigin(x: 0, y: 0, z: 0), size: MTLSize(width: width, height: height, depth: 1)), mipmapLevel: 0)
168 |
169 |         guard let image = context.makeImage() else {
170 |             throw ComputeError.resourceCreationFailure
171 |         }
172 |
173 |         guard let imageDestination = CGImageDestinationCreateWithURL(url as CFURL, UTType.png.identifier as CFString, 1, nil) else {
174 |             throw ComputeError.resourceCreationFailure
175 |         }
176 |         CGImageDestinationAddImage(imageDestination, image, nil)
177 |         guard CGImageDestinationFinalize(imageDestination) else { throw ComputeError.resourceCreationFailure } // false means the file was not written
178 |     }
179 | }
180 |
181 | extension MTLBuffer {
182 |     func withUnsafeBytes(_ body: (UnsafeBufferPointer) throws -> ResultType) rethrows -> ResultType {
183 |         try withUnsafeBytes { (buffer: UnsafeRawBufferPointer) in
184 |             try buffer.withMemoryRebound(to: ContentType.self, body)
185 |         }
186 |     }
187 |
188
| func withUnsafeBytes(_ body: (UnsafeRawBufferPointer) throws -> ResultType) rethrows -> ResultType { 189 | try body(UnsafeRawBufferPointer(start: contents(), count: length)) 190 | } 191 | } 192 | 193 | extension Array where Element: Equatable { 194 | 195 | struct Run { 196 | var element: Element 197 | var count: Int 198 | } 199 | 200 | func rle() -> [Run] { 201 | 202 | var lastElement: Element? 203 | var runLength = 0 204 | 205 | var runs: [Run] = [] 206 | 207 | for element in self { 208 | if element == lastElement { 209 | runLength += 1 210 | } 211 | else { 212 | if let lastElement { 213 | runs.append(.init(element: lastElement, count: runLength)) 214 | } 215 | lastElement = element 216 | runLength = 1 217 | } 218 | } 219 | 220 | if let lastElement { 221 | runs.append(.init(element: lastElement, count: runLength)) 222 | } 223 | 224 | return runs 225 | } 226 | 227 | } 228 | 229 | extension Array where Element == UInt32 { 230 | func prefixSum() -> [UInt32] { 231 | var output = Array(repeating: UInt32.zero, count: count) 232 | for j in 1.. UInt32 { 242 | let result = UInt32(31 - value.leadingZeroBitCount) 243 | return ceiling && (1 << result) < value ? result + 1 : result 244 | } 245 | 246 | func log2(_ value: Int, ceiling: Bool = false) -> Int { 247 | precondition(value > 0, "log2 is only defined for positive numbers") 248 | let result = 63 - value.leadingZeroBitCount 249 | return ceiling && (1 << result) < value ? 
result + 1 : result 250 | } 251 | 252 | // MARK: - 253 | 254 | infix operator **: MultiplicationPrecedence 255 | 256 | func ** (base: Int, exponent: Int) -> Int { 257 | return Int(pow(Double(base), Double(exponent))) 258 | } 259 | 260 | func ** (base: UInt32, exponent: UInt32) -> UInt32 { 261 | return UInt32(pow(Double(base), Double(exponent))) 262 | } 263 | 264 | extension Array { 265 | init(_ buffer: MTLBuffer) { 266 | let pointer = buffer.contents().bindMemory(to: Element.self, capacity: buffer.length / MemoryLayout.size) 267 | let buffer = UnsafeBufferPointer(start: pointer, count: buffer.length / MemoryLayout.size) 268 | self = Array(buffer) 269 | } 270 | 271 | } 272 | 273 | func ceildiv (_ x: T, _ y: T) -> T where T: BinaryInteger { 274 | (x + y - 1) / y 275 | } 276 | 277 | extension Array { 278 | init(_ buffer: TypedMTLBuffer) { 279 | self = buffer.withUnsafeMTLBuffer { buffer in 280 | Array(buffer!) 281 | } 282 | } 283 | } 284 | 285 | extension Compute.Argument { 286 | static func buffer(_ data: TypedMTLBuffer) -> Self { 287 | data.withUnsafeMTLBuffer { buffer in 288 | return .buffer(buffer!) 289 | } 290 | } 291 | } 292 | 293 | func nextPowerOfTwo(_ n: Int) -> Int { 294 | return Int(pow(2.0, Double(Int(log2(Double(n))).advanced(by: 1))) + 0.5) 295 | } 296 | 297 | extension MTLSize: @retroactive ExpressibleByArrayLiteral { 298 | public init(arrayLiteral elements: Int...) 
{ 299 | switch elements.count { 300 | case 1: 301 | self = .init(elements[0], 1, 1) 302 | case 2: 303 | self = .init(elements[0], elements[1], 1) 304 | case 3: 305 | self = .init(elements[0], elements[1], elements[2]) 306 | default: 307 | fatalError() 308 | } 309 | } 310 | 311 | } 312 | -------------------------------------------------------------------------------- /Sources/MetalSupportLite/BaseSupport.swift: -------------------------------------------------------------------------------- 1 | public enum BaseError: Error { 2 | // IDEA: Have a good going over here and clean up duplicate/vague types. 3 | case generic(String) 4 | case resourceCreationFailure 5 | case illegalValue 6 | case optionalUnwrapFailure 7 | case initializationFailure 8 | case unknown 9 | case missingValue 10 | case typeMismatch 11 | case inputOutputFailure 12 | case invalidParameter 13 | case parsingFailure 14 | case encodingFailure 15 | case missingResource 16 | case extended(Error, String) 17 | case decodingFailure 18 | case missingBinding(String) 19 | case overflow 20 | } 21 | 22 | public extension BaseError { 23 | static func error(_ error: Self) -> Self { 24 | // NOTE: Hook here to add logging or special breakpoint handling. 
25 | // logger?.error("Error: \(error)") 26 | error 27 | } 28 | } 29 | 30 | public func fatalError(_ error: Error) -> Never { 31 | fatalError("\(error)") 32 | } 33 | 34 | public func unimplemented(_ message: @autoclosure () -> String = String(), file: StaticString = #file, line: UInt = #line) -> Never { 35 | fatalError(message(), file: file, line: line) 36 | } 37 | 38 | public func temporarilyDisabled(_ message: @autoclosure () -> String = String(), file: StaticString = #file, line: UInt = #line) -> Never { 39 | fatalError(message(), file: file, line: line) 40 | } 41 | 42 | public func unreachable(_ message: @autoclosure () -> String = String(), file: StaticString = #file, line: UInt = #line) -> Never { 43 | fatalError(message(), file: file, line: line) 44 | } 45 | 46 | // MARK: - 47 | 48 | public extension Optional { 49 | func safelyUnwrap(_ error: @autoclosure () -> Error) throws -> Wrapped { 50 | // swiftlint:disable:next shorthand_optional_binding 51 | guard let self = self else { 52 | throw error() 53 | } 54 | return self 55 | } 56 | 57 | func forceUnwrap() -> Wrapped { 58 | // swiftlint:disable:next shorthand_optional_binding 59 | guard let self = self else { 60 | fatalError("Cannot unwrap nil optional.") 61 | } 62 | return self 63 | } 64 | 65 | func forceUnwrap(_ message: @autoclosure () -> String) -> Wrapped { 66 | // swiftlint:disable:next shorthand_optional_binding 67 | guard let self = self else { 68 | fatalError(message()) 69 | } 70 | return self 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /Sources/MetalSupportLite/MTLBuffer+Extensions.swift: -------------------------------------------------------------------------------- 1 | import Metal 2 | 3 | public extension MTLDevice { 4 | // TODO: Rename 5 | func makeBufferEx(bytes pointer: UnsafeRawPointer, length: Int, options: MTLResourceOptions = []) throws -> MTLBuffer { 6 | guard let buffer = makeBuffer(bytes: pointer, length: length, options: options) 
else { 7 | throw BaseError.error(.resourceCreationFailure) 8 | } 9 | return buffer 10 | } 11 | 12 | // TODO: Rename 13 | func makeBufferEx(length: Int, options: MTLResourceOptions = []) throws -> MTLBuffer { 14 | guard let buffer = makeBuffer(length: length, options: options) else { 15 | throw BaseError.error(.resourceCreationFailure) 16 | } 17 | return buffer 18 | } 19 | 20 | func makeBuffer(data: Data, options: MTLResourceOptions) throws -> MTLBuffer { 21 | try data.withUnsafeBytes { buffer in 22 | let baseAddress = buffer.baseAddress.forceUnwrap("No baseAddress.") 23 | guard let buffer = makeBuffer(bytes: baseAddress, length: buffer.count, options: options) else { 24 | throw BaseError.error(.resourceCreationFailure) 25 | } 26 | return buffer 27 | } 28 | } 29 | 30 | func makeBuffer(bytesOf content: some Any, options: MTLResourceOptions) throws -> MTLBuffer { 31 | try withUnsafeBytes(of: content) { buffer in 32 | let baseAddress = buffer.baseAddress.forceUnwrap("No baseAddress.") 33 | guard let buffer = makeBuffer(bytes: baseAddress, length: buffer.count, options: options) else { 34 | throw BaseError.error(.resourceCreationFailure) 35 | } 36 | return buffer 37 | } 38 | } 39 | 40 | func makeBuffer(bytesOf content: [some Any], options: MTLResourceOptions) throws -> MTLBuffer { 41 | try content.withUnsafeBytes { buffer in 42 | let baseAddress = buffer.baseAddress.forceUnwrap("No baseAddress.") 43 | guard let buffer = makeBuffer(bytes: baseAddress, length: buffer.count, options: options) else { 44 | throw BaseError.error(.resourceCreationFailure) 45 | } 46 | return buffer 47 | } 48 | } 49 | } 50 | 51 | public extension MTLBuffer { 52 | func data() -> Data { 53 | Data(bytes: contents(), count: length) 54 | } 55 | 56 | /// Update a MTLBuffer's contents using an inout type block 57 | func with(type: T.Type, _ block: (inout T) -> R) -> R { 58 | let value = contents().bindMemory(to: T.self, capacity: 1) 59 | return block(&value.pointee) 60 | } 61 | 62 | func withEx(type: 
T.Type, count: Int, _ block: (UnsafeMutableBufferPointer) -> R) -> R { 63 | let pointer = contents().bindMemory(to: T.self, capacity: count) 64 | let buffer = UnsafeMutableBufferPointer(start: pointer, count: count) 65 | return block(buffer) 66 | } 67 | 68 | func contentsBuffer() -> UnsafeMutableRawBufferPointer { 69 | UnsafeMutableRawBufferPointer(start: contents(), count: length) 70 | } 71 | 72 | func contentsBuffer(of type: T.Type) -> UnsafeMutableBufferPointer { 73 | contentsBuffer().bindMemory(to: type) 74 | } 75 | func labelled(_ label: String) -> MTLBuffer { 76 | self.label = label 77 | return self 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /Sources/MetalSupportLite/MetalBasicExtensions.swift: -------------------------------------------------------------------------------- 1 | import Metal 2 | 3 | public extension MTLOrigin { 4 | init(_ origin: CGPoint) { 5 | self.init(x: Int(origin.x), y: Int(origin.y), z: 0) 6 | } 7 | 8 | static var zero: MTLOrigin { 9 | MTLOrigin(x: 0, y: 0, z: 0) 10 | } 11 | } 12 | 13 | public extension MTLRegion { 14 | init(_ rect: CGRect) { 15 | self = MTLRegion(origin: MTLOrigin(rect.origin), size: MTLSize(rect.size)) 16 | } 17 | } 18 | 19 | public extension MTLSize { 20 | init(_ size: CGSize) { 21 | self.init(width: Int(size.width), height: Int(size.height), depth: 1) 22 | } 23 | 24 | init(_ width: Int, _ height: Int, _ depth: Int) { 25 | self = MTLSize(width: width, height: height, depth: depth) 26 | } 27 | 28 | init(width: Int) { 29 | self = MTLSize(width: width, height: 1, depth: 1) 30 | } 31 | 32 | init(width: Int, height: Int) { 33 | self = MTLSize(width: width, height: height, depth: 1) 34 | } 35 | } 36 | 37 | public extension SIMD4 { 38 | init(_ clearColor: MTLClearColor) { 39 | self = [clearColor.red, clearColor.green, clearColor.blue, clearColor.alpha] 40 | } 41 | } 42 | 43 | public extension MTLIndexType { 44 | var indexSize: Int { 45 | switch self { 46 | case 
.uint16: 47 | MemoryLayout.size 48 | case .uint32: 49 | MemoryLayout.size 50 | default: 51 | fatalError(BaseError.illegalValue) 52 | } 53 | } 54 | } 55 | 56 | public extension MTLPrimitiveType { 57 | var vertexCount: Int? { 58 | switch self { 59 | case .triangle: 60 | 3 61 | default: 62 | fatalError(BaseError.illegalValue) 63 | } 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /Sources/MetalSupportLite/MetalSupportLite.swift: -------------------------------------------------------------------------------- 1 | import Metal 2 | 3 | public extension MTLDevice { 4 | func capture (enabled: Bool = true, _ block: () throws -> R) throws -> R { 5 | guard enabled else { 6 | return try block() 7 | } 8 | let captureManager = MTLCaptureManager.shared() 9 | let captureScope = captureManager.makeCaptureScope(device: self) 10 | let captureDescriptor = MTLCaptureDescriptor() 11 | captureDescriptor.captureObject = captureScope 12 | try captureManager.startCapture(with: captureDescriptor) 13 | captureScope.begin() 14 | defer { 15 | captureScope.end() 16 | } 17 | return try block() 18 | } 19 | 20 | var supportsNonuniformThreadGroupSizes: Bool { 21 | let families: [MTLGPUFamily] = [.apple4, .apple5, .apple6, .apple7] 22 | return families.contains { supportsFamily($0) } 23 | } 24 | 25 | func makeComputePipelineState(function: MTLFunction, options: MTLPipelineOption) throws -> (MTLComputePipelineState, MTLComputePipelineReflection?) { 26 | var reflection: MTLComputePipelineReflection? 
27 | let pipelineState = try makeComputePipelineState(function: function, options: options, reflection: &reflection) 28 | return (pipelineState, reflection) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /Sources/MetalSupportLite/TypedMTLBuffer.swift: -------------------------------------------------------------------------------- 1 | @preconcurrency import Metal 2 | 3 | /// A type-safe wrapper around `MTLBuffer` for managing Metal buffers with a specific element type. 4 | /// 5 | /// `TypedMTLBuffer` provides a convenient way to work with Metal buffers while maintaining type safety. 6 | /// It encapsulates an `MTLBuffer` and provides methods to safely access and manipulate its contents. 7 | /// 8 | /// - Important: This type conforms to `Sendable`. However, this conformance is only valid if the 9 | /// underlying `MTLBuffer` is not reused elsewhere. Ensure that the `MTLBuffer` is uniquely owned 10 | /// by this `TypedMTLBuffer` instance to maintain thread safety when sending across concurrency domains. 11 | /// - Note: The generic type `Element` should be a POD (Plain Old Data) type. 12 | public struct TypedMTLBuffer: Sendable { 13 | public var count: Int { 14 | willSet { 15 | assert(count <= capacity) 16 | } 17 | } 18 | 19 | /// The underlying Metal buffer. 20 | private var base: MTLBuffer? 21 | 22 | /// Initializes a new `TypedMTLBuffer` with the given Metal buffer. 23 | /// 24 | /// - Parameter mtlBuffer: The Metal buffer to wrap. 25 | /// - Precondition: The generic type `Element` must be a POD type. 26 | public init(mtlBuffer: MTLBuffer?, count: Int) { 27 | assert(_isPOD(Element.self)) 28 | self.base = mtlBuffer 29 | self.count = count 30 | } 31 | 32 | public init() { 33 | self.base = nil 34 | self.count = 0 35 | } 36 | 37 | public var capacity: Int { 38 | (base?.length ?? 0) / MemoryLayout.stride 39 | } 40 | 41 | public var label: String? 
{ 42 | get { 43 | base?.label 44 | } 45 | set { 46 | base?.label = newValue 47 | } 48 | } 49 | } 50 | 51 | extension TypedMTLBuffer: Equatable { 52 | public static func == (lhs: Self, rhs: Self) -> Bool { 53 | lhs.count == rhs.count && lhs.base === rhs.base 54 | } 55 | } 56 | 57 | extension TypedMTLBuffer: CustomDebugStringConvertible { 58 | public var debugDescription: String { 59 | "TypedMTLBuffer<\(type(of: Element.self))>(count: \(count), capacity: \(capacity), base.label: \(String(describing: base?.label)), base.length: \(String(describing: base?.length)))" 60 | } 61 | } 62 | 63 | // MARK: - 64 | 65 | public extension TypedMTLBuffer { 66 | var unsafeBase: MTLBuffer? { 67 | base 68 | } 69 | 70 | /// Provides temporary access to the underlying `MTLBuffer`. 71 | /// 72 | /// - Parameter block: A closure that takes an `MTLBuffer` and returns a value. 73 | /// - Returns: The value returned by the `block`. 74 | /// - Throws: Rethrows any error thrown by the `block`. 75 | func withUnsafeMTLBuffer(_ block: (MTLBuffer?) throws -> R) rethrows -> R { 76 | try block(base) 77 | } 78 | 79 | /// Provides unsafe read-only access to the buffer's contents. 80 | /// 81 | /// - Parameter block: A closure that takes an `UnsafeBufferPointer` and returns a value. 82 | /// - Returns: The value returned by the `block`. 83 | /// - Throws: Rethrows any error thrown by the `block`. 84 | func withUnsafeBufferPointer(_ block: (UnsafeBufferPointer) throws -> R) rethrows -> R { 85 | if let base { 86 | let contents = base.contents() 87 | let pointer = contents.bindMemory(to: Element.self, capacity: count) 88 | let buffer = UnsafeBufferPointer(start: pointer, count: count) 89 | return try block(buffer) 90 | } 91 | else { 92 | return try block(UnsafeBufferPointer(start: nil, count: 0)) 93 | } 94 | } 95 | 96 | /// Provides unsafe mutable access to the buffer's contents. 97 | /// 98 | /// - Parameter block: A closure that takes an `UnsafeMutableBufferPointer` and returns a value. 
/// - Returns: The value returned by the `block`.
/// - Throws: Rethrows any error thrown by the `block`.
func withUnsafeMutableBufferPointer<R>(_ block: (UnsafeMutableBufferPointer<Element>) throws -> R) rethrows -> R {
    // Without a backing `MTLBuffer` there is nothing to bind; hand the caller an empty view.
    guard let base else {
        return try block(UnsafeMutableBufferPointer(start: nil, count: 0))
    }
    // Only the first `count` elements are exposed, not the full allocation.
    let pointer = base.contents().bindMemory(to: Element.self, capacity: count)
    return try block(UnsafeMutableBufferPointer(start: pointer, count: count))
}

/// Sets a label for the underlying Metal buffer.
///
/// - Parameter label: The label to set.
/// - Returns: The `TypedMTLBuffer` instance with the updated label.
func labelled(_ label: String) -> Self {
    base?.label = label
    return self
}
}

public extension TypedMTLBuffer {
    /// Copies `elements` into the buffer directly after the current `count` elements and
    /// increases `count` by `elements.count`.
    ///
    /// Appending an empty array is a no-op and never throws, even when there is no backing buffer.
    ///
    /// - Parameter elements: The elements to append.
    /// - Throws: `BaseError.overflow` if there is no backing buffer or if the elements do not fit
    ///   in the remaining capacity.
    mutating func append(contentsOf elements: [Element]) throws {
        // Nothing to copy: succeed without requiring a backing buffer.
        if elements.isEmpty {
            return
        }
        // A missing buffer has a capacity of zero, so both failure modes report `.overflow`.
        guard let base, count + elements.count <= capacity else {
            throw BaseError.error(.overflow)
        }
        // Byte position of the first free element slot.
        let byteOffset = count * MemoryLayout<Element>.stride
        elements.withUnsafeBytes { source in
            let destination = base.contents().advanced(by: byteOffset)
            source.copyBytes(to: .init(start: destination, count: source.count))
        }
        count += elements.count
    }
}

// MARK: -

public extension MTLDevice {
    /// Creates a `TypedMTLBuffer` from the given data.
    ///
    /// - Parameters:
    ///   - data: The data to copy into the new buffer.
    ///   - options: Options for the new buffer. Default is an empty option set.
    /// - Returns: A new `TypedMTLBuffer` containing the specified data.
    /// - Throws: `BaseError.illegalValue` if the data size is not a multiple of the stride of `Element`.
    ///           `BaseError.resourceCreationFailure` if the buffer creation fails.
func makeTypedBuffer<Element>(data: Data, options: MTLResourceOptions = []) throws -> TypedMTLBuffer<Element> {
    let stride = MemoryLayout<Element>.stride
    // A trailing partial element cannot be represented; reject it before allocating anything.
    guard data.count.isMultiple(of: stride) else {
        throw BaseError.error(.illegalValue)
    }
    // Empty input yields an empty, buffer-less value (same as the `[Element]` overload)
    // instead of failing with `resourceCreationFailure`.
    if data.isEmpty {
        return TypedMTLBuffer()
    }
    let count = data.count / stride
    return try data.withUnsafeBytes { bytes in
        guard let baseAddress = bytes.baseAddress,
              let buffer = makeBuffer(bytes: baseAddress, length: bytes.count, options: options) else {
            throw BaseError.error(.resourceCreationFailure)
        }
        return TypedMTLBuffer(mtlBuffer: buffer, count: count)
    }
}

/// Creates a `TypedMTLBuffer` from the given array.
///
/// - Parameters:
///   - data: The array to copy into the new buffer.
///   - options: Options for the new buffer. Default is an empty option set.
/// - Returns: A new `TypedMTLBuffer` containing the specified data.
/// - Throws: `BaseError.resourceCreationFailure` if the buffer creation fails.
func makeTypedBuffer<Element>(data: [Element], options: MTLResourceOptions = []) throws -> TypedMTLBuffer<Element> {
    // An empty array maps to a buffer-less value rather than a Metal allocation.
    guard !data.isEmpty else {
        return TypedMTLBuffer()
    }
    return try data.withUnsafeBytes { bytes in
        guard let baseAddress = bytes.baseAddress,
              let buffer = makeBuffer(bytes: baseAddress, length: bytes.count, options: options) else {
            throw BaseError.error(.resourceCreationFailure)
        }
        return TypedMTLBuffer(mtlBuffer: buffer, count: data.count)
    }
}

/// Creates an empty `TypedMTLBuffer` with room for `capacity` elements.
///
/// The returned buffer has a `count` of zero. A `capacity` of zero yields a buffer-less value.
///
/// - Parameters:
///   - element: The element type the buffer will hold.
///   - capacity: The number of elements to reserve space for. Must not be negative.
///   - options: Options for the new buffer. Default is an empty option set.
/// - Returns: A new, empty `TypedMTLBuffer`.
/// - Throws: `BaseError.illegalValue` if `capacity` is negative.
///           `BaseError.resourceCreationFailure` if the buffer creation fails.
func makeTypedBuffer<Element>(element: Element.Type, capacity: Int, options: MTLResourceOptions = []) throws -> TypedMTLBuffer<Element> {
    guard capacity >= 0 else {
        throw BaseError.error(.illegalValue)
    }
    if capacity == 0 {
        return TypedMTLBuffer()
    }
    guard let buffer = makeBuffer(length: MemoryLayout<Element>.stride * capacity, options: options) else {
        throw BaseError.error(.resourceCreationFailure)
    }
    // TODO: FIXME - remove this
    // Fills the storage with 0xFF bytes (presumably a debugging aid). Private-storage buffers
    // have no CPU-accessible contents, so they must be skipped.
    if buffer.storageMode != .private {
        memset(buffer.contents(), 0xFF, buffer.length)
    }

    return TypedMTLBuffer(mtlBuffer: buffer, count: 0)
}

/// Creates an empty `TypedMTLBuffer` with room for `capacity` elements, inferring the element
/// type from the call site.
///
/// Forwards to `makeTypedBuffer(element:capacity:options:)`.
func makeTypedBuffer<Element>(capacity: Int, options: MTLResourceOptions = []) throws -> TypedMTLBuffer<Element> {
    try makeTypedBuffer(element: Element.self, capacity: capacity, options: options)
}
}

// MARK: -

public extension MTLRenderCommandEncoder {
    /// Sets a vertex buffer for the render command encoder.
    ///
    /// - Parameters:
    ///   - buffer: The `TypedMTLBuffer` to set as the vertex buffer.
    ///   - offset: The offset in elements from the start of the buffer. This value is multiplied by `MemoryLayout<Element>.stride` to calculate the byte offset.
    ///   - index: The index into the buffer argument table.
func setVertexBuffer<Element>(_ buffer: TypedMTLBuffer<Element>, offset: Int, index: Int) {
    // Callers pass an element offset; Metal expects bytes.
    let byteOffset = offset * MemoryLayout<Element>.stride
    buffer.withUnsafeMTLBuffer { mtlBuffer in
        setVertexBuffer(mtlBuffer, offset: byteOffset, index: index)
    }
}

/// Sets a fragment buffer for the render command encoder.
///
/// - Parameters:
///   - buffer: The `TypedMTLBuffer` to set as the fragment buffer.
///   - offset: The offset in elements from the start of the buffer. This value is multiplied by `MemoryLayout<Element>.stride` to calculate the byte offset.
///   - index: The index into the buffer argument table.
func setFragmentBuffer<Element>(_ buffer: TypedMTLBuffer<Element>, offset: Int, index: Int) {
    // Callers pass an element offset; Metal expects bytes.
    let byteOffset = offset * MemoryLayout<Element>.stride
    buffer.withUnsafeMTLBuffer { mtlBuffer in
        setFragmentBuffer(mtlBuffer, offset: byteOffset, index: index)
    }
}
}

public extension MTLComputeCommandEncoder {
    /// Sets a buffer for the compute command encoder.
    ///
    /// - Parameters:
    ///   - buffer: The `TypedMTLBuffer` to set.
    ///   - offset: The offset in elements from the start of the buffer. This value is multiplied by `MemoryLayout<Element>.stride` to calculate the byte offset.
    ///   - index: The index into the buffer argument table.
246 | func setBuffer (_ buffer: TypedMTLBuffer, offset: Int, index: Int) { 247 | buffer.withUnsafeMTLBuffer { 248 | setBuffer($0, offset: offset * MemoryLayout.stride, index: index) 249 | } 250 | } 251 | } 252 | -------------------------------------------------------------------------------- /Tests/.swiftlint.yml: -------------------------------------------------------------------------------- 1 | disabled_rules: 2 | - force_try 3 | - force_cast 4 | - force_unwrapping 5 | - function_body_length 6 | - fatal_error_message 7 | - line_length 8 | - identifier_name 9 | - explicit_top_level_acl 10 | -------------------------------------------------------------------------------- /Tests/ComputeTests/ComputeTests.swift: -------------------------------------------------------------------------------- 1 | @testable import Compute 2 | import Metal 3 | import Testing 4 | 5 | struct ComputeTests { 6 | let device: MTLDevice 7 | let compute: Compute 8 | 9 | init() throws { 10 | let device = MTLCreateSystemDefaultDevice()! 
11 | self.device = device 12 | self.compute = try Compute(device: device) 13 | } 14 | 15 | @Test 16 | func computeInitialization() throws { 17 | #expect(compute != nil) 18 | #expect(compute.device === device) 19 | } 20 | 21 | @Test 22 | func shaderLibraryCreation() throws { 23 | let source = """ 24 | #include 25 | using namespace metal; 26 | kernel void add(device int* a [[buffer(0)]], device int* b [[buffer(1)]], device int* result [[buffer(2)]], uint id [[thread_position_in_grid]]) { 27 | result[id] = a[id] + b[id]; 28 | } 29 | """ 30 | let library = ShaderLibrary.source(source) 31 | let function = library.add 32 | #expect(function.name == "add") 33 | } 34 | 35 | @Test 36 | func pipelineCreation() throws { 37 | let source = """ 38 | #include 39 | using namespace metal; 40 | kernel void add(device int* a [[buffer(0)]], device int* b [[buffer(1)]], device int* result [[buffer(2)]], uint id [[thread_position_in_grid]]) { 41 | result[id] = a[id] + b[id]; 42 | } 43 | """ 44 | let library = ShaderLibrary.source(source) 45 | let function = library.add 46 | 47 | let pipeline = try compute.makePipeline(function: function) 48 | 49 | #expect(pipeline != nil) 50 | } 51 | 52 | @Test 53 | func simpleAddition() throws { 54 | let source = """ 55 | #include 56 | using namespace metal; 57 | kernel void add(device int* a [[buffer(0)]], device int* b [[buffer(1)]], device int* result [[buffer(2)]], uint id [[thread_position_in_grid]]) { 58 | result[id] = a[id] + b[id]; 59 | } 60 | """ 61 | let library = ShaderLibrary.source(source) 62 | let function = library.add 63 | 64 | let count = 1_000 65 | let a = [Int32](repeating: 1, count: count) 66 | let b = [Int32](repeating: 2, count: count) 67 | 68 | let bufferA = device.makeBuffer(bytes: a, length: MemoryLayout.stride * count, options: [])! 69 | let bufferB = device.makeBuffer(bytes: b, length: MemoryLayout.stride * count, options: [])! 70 | let bufferResult = device.makeBuffer(length: MemoryLayout.stride * count, options: [])! 
71 | 72 | let pipeline = try compute.makePipeline( 73 | function: function, 74 | arguments: [ 75 | "a": .buffer(bufferA), 76 | "b": .buffer(bufferB), 77 | "result": .buffer(bufferResult) 78 | ] 79 | ) 80 | 81 | try compute.run(pipeline: pipeline, width: count) 82 | 83 | bufferResult.contents().withMemoryRebound(to: Int32.self, capacity: count) { result in 84 | for i in 0..