├── .gitignore
├── README.md
├── examples
│   ├── add
│   │   ├── add.metal
│   │   └── main.go
│   ├── grayscale
│   │   ├── gray-puppy.jpg
│   │   ├── grayscale.metal
│   │   ├── main.go
│   │   └── puppy-g7b38fec9b_1920.jpg
│   └── mandelbrot
│       ├── main.go
│       ├── mandelbrot.metal
│       └── mandelbrot.png
├── go.mod
├── go.sum
├── matrix_test.go
├── mtl.go
├── mtl.h
└── mtl.m


/.gitignore:
--------------------------------------------------------------------------------
1 | .ccls-cache
2 | examples/add/add
3 | examples/grayscale/grayscale
4 | examples/mandelbrot/mandelbrot
5 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # gpu
2 | 
3 | 
--------------------------------------------------------------------------------
/examples/add/add.metal:
--------------------------------------------------------------------------------
1 | #include <metal_stdlib>
2 | using namespace metal;
3 | 
4 | // Params mirrors the Params struct in mtl.h: the dimensions of the input
5 | // and output matrices, measured in elements.
6 | typedef struct Params {
7 |   int w_in, h_in, d_in;
8 |   int w_out, h_out, d_out;
9 | } Params;
10 | 
11 | // idx flattens (x, y, z) into an offset for a matrix laid out as z planes
12 | // of h rows of w elements. d is unused; it keeps the signature symmetric.
13 | int idx(int x, int y, int z, int w, int h, int d) {
14 |   int i = z * w * h;
15 |   i += y * w;
16 |   i += x;
17 |   return i;
18 | }
19 | 
20 | // process adds the first two elements of each input row and writes the sum
21 | // to column 0 of the same output row. One thread is launched per input
22 | // element (see run() in mtl.m), so all but column 0 exit immediately.
23 | kernel void process(device const Params* p,
24 |                     device const float* input,
25 |                     device float* output,
26 |                     uint3 gridSize[[threads_per_grid]],
27 |                     uint3 gid[[thread_position_in_grid]]) {
28 |   // Only process once per row of data.
29 |   if(gid.x != 0) {
30 |     return;
31 |   }
32 | 
33 |   // A row needs two values to sum, and the output needs a matching row.
34 |   if(p->w_in < 2 || gid.y >= uint(p->h_out)) {
35 |     return;
36 |   }
37 | 
38 |   // Since we know we're in the first column...
39 |   // we can process the whole row.
40 |   int input_index = idx(gid.x, gid.y, gid.z,
41 |                         p->w_in, p->h_in, p->d_in);
42 | 
43 |   float a = input[input_index];
44 |   float b = input[input_index+1];
45 | 
46 |   int output_index = idx(0, gid.y, 0,
47 |                          p->w_out, p->h_out, p->d_out);
48 | 
49 |   output[output_index] = a + b;
50 | }
51 | 
--------------------------------------------------------------------------------
/examples/add/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	_ "embed"
5 | 	"fmt"
6 | 
7 | 	"github.com/a-h/gpu"
8 | )
9 | 
10 | //go:embed add.metal
11 | var source string
12 | 
13 | func main() {
14 | 	// Compilation has to be done once.
15 | 	gpu.Compile(source)
16 | 
17 | 	// 2 across, 10 down, 1 deep.
18 | 	input := gpu.NewMatrix[float32](2, 10, 1)
19 | 	// Initialize every cell of row y to y, like:
20 | 	// 0.0 0.0
21 | 	// 1.0 1.0
22 | 	// 2.0 2.0
23 | 	// ...
24 | 	z := input.D - 1
25 | 	for y := 0; y < input.H; y++ {
26 | 		for x := 0; x < input.W; x++ {
27 | 			input.Set(x, y, z, float32(y))
28 | 		}
29 | 	}
30 | 	// 1 across, 10 down, 1 deep.
31 | 	output := gpu.NewMatrix[float32](1, 10, 1)
32 | 
33 | 	// Run code on GPU, includes copying the matrix to the GPU.
34 | 	gpu.Run(input, output)
35 | 
36 | 	// The GPU code adds the numbers in columns A and B together, so the results are:
37 | 	// 0.0
38 | 	// 2.0 (1+1)
39 | 	// 4.0 (2+2)
40 | 	// 6.0 (3+3)
41 | 	// ...
42 | 	for y := 0; y < output.H; y++ {
43 | 		fmt.Printf("Summed: %v\n", output.Get(0, y, 0))
44 | 	}
45 | }
46 | 
--------------------------------------------------------------------------------
/examples/grayscale/gray-puppy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a-h/gpu/d4b763031d4b4741658cbe32ca4aad563846fd6d/examples/grayscale/gray-puppy.jpg
--------------------------------------------------------------------------------
/examples/grayscale/grayscale.metal:
--------------------------------------------------------------------------------
1 | #include <metal_stdlib>
2 | using namespace metal;
3 | 
4 | // Params mirrors the Params struct in mtl.h: the dimensions of the input
5 | // and output matrices, measured in elements (bytes for this kernel).
6 | typedef struct Params {
7 |   int w_in, h_in, d_in;
8 |   int w_out, h_out, d_out;
9 | } Params;
10 | 
11 | // idx flattens (x, y, z) into an offset for a matrix laid out as z planes
12 | // of h rows of w elements. d is unused; it keeps the signature symmetric.
13 | int idx(int x, int y, int z, int w, int h, int d) {
14 |   int i = z * w * h;
15 |   i += y * w;
16 |   i += x;
17 |   return i;
18 | }
19 | 
20 | // process converts RGBA pixels to grayscale by averaging r, g and b and
21 | // forcing alpha to 255. One thread is launched per byte, so only the thread
22 | // at the first byte of each 4-byte pixel does any work.
23 | kernel void process(device const Params* p,
24 |                     device uint8_t* input,
25 |                     device uint8_t* output,
26 |                     uint3 gridSize[[threads_per_grid]],
27 |                     uint3 gid[[thread_position_in_grid]]) {
28 |   // Only process once per pixel of data (4 uint8_t)
29 |   if(gid.x % 4 != 0) {
30 |     return;
31 |   }
32 |   // Skip a trailing partial pixel if the row width is not a multiple of 4.
33 |   if(gid.x + 3 >= uint(p->w_in)) {
34 |     return;
35 |   }
36 | 
37 |   int input_index = idx(gid.x, gid.y, gid.z,
38 |                         p->w_in, p->h_in, p->d_in);
39 | 
40 |   uint8_t r = input[input_index+0];
41 |   uint8_t g = input[input_index+1];
42 |   uint8_t b = input[input_index+2];
43 | 
44 |   uint8_t avg = uint8_t((int(r) + int(g) + int(b)) / 3);
45 | 
46 |   // Output has the same shape as the input, so the same index is reused.
47 |   output[input_index+0] = avg;
48 |   output[input_index+1] = avg;
49 |   output[input_index+2] = avg;
50 |   output[input_index+3] = 255;
51 | }
52 | 
--------------------------------------------------------------------------------
/examples/grayscale/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	_ "embed"
5 | 	"image"
6 | 	"image/jpeg"
7 | 	"log"
8 | 	"os"
9 | 
10 | 	"github.com/a-h/gpu"
11 | )
12 | 
13 | 
//go:embed grayscale.metal
14 | var source string
15 | 
16 | func main() {
17 | 	gpu.Compile(source)
18 | 
19 | 	f, err := os.Open("puppy-g7b38fec9b_1920.jpg")
20 | 	if err != nil {
21 | 		log.Fatalf("failed to read puppy JPEG: %v", err)
22 | 	}
23 | 	defer f.Close()
24 | 	jpg, err := jpeg.Decode(f)
25 | 	if err != nil {
26 | 		log.Fatalf("failed to decode JPEG: %v", err)
27 | 	}
28 | 
29 | 	// Create a matrix to copy the data into.
30 | 	// Unfortunately, there's no backing byte array that's easy to access.
31 | 	// So load into the matrix, stride bytes (r, g, b, a) per pixel.
32 | 	bounds := jpg.Bounds()
33 | 	stride := 4
34 | 	input := gpu.NewMatrix[uint8](bounds.Dx()*stride, bounds.Dy(), 1)
35 | 	for y := 0; y < bounds.Dy(); y++ {
36 | 		for x := 0; x < bounds.Dx(); x++ {
37 | 			// Image bounds need not start at (0, 0), so offset by Min.
38 | 			r, g, b, a := jpg.At(bounds.Min.X+x, bounds.Min.Y+y).RGBA()
39 | 			// RGBA returns 16-bit channels; dividing by 257 maps 0xffff to 0xff.
40 | 			input.Set((x*stride)+0, y, 0, uint8(r/257))
41 | 			input.Set((x*stride)+1, y, 0, uint8(g/257))
42 | 			input.Set((x*stride)+2, y, 0, uint8(b/257))
43 | 			input.Set((x*stride)+3, y, 0, uint8(a/257))
44 | 		}
45 | 	}
46 | 
47 | 	// Configure the output.
48 | 	output := gpu.NewMatrix[uint8](bounds.Dx()*stride, bounds.Dy(), 1)
49 | 
50 | 	// Run the processing.
51 | 	gpu.Run(input, output)
52 | 
53 | 	// Write the output.
54 | 	fo, err := os.Create("gray-puppy.jpg")
55 | 	if err != nil {
56 | 		log.Fatalf("failed to create grayscale puppy: %v", err)
57 | 	}
58 | 	img := image.NewRGBA(bounds)
59 | 	img.Pix = output.Data
60 | 	err = jpeg.Encode(fo, img, &jpeg.Options{
61 | 		Quality: 100,
62 | 	})
63 | 	if err != nil {
64 | 		log.Fatalf("failed to write JPEG to disk: %v", err)
65 | 	}
66 | 	// Close explicitly so an error flushing to disk is reported, not dropped.
67 | 	if err := fo.Close(); err != nil {
68 | 		log.Fatalf("failed to close grayscale puppy: %v", err)
69 | 	}
70 | }
71 | 
--------------------------------------------------------------------------------
/examples/grayscale/puppy-g7b38fec9b_1920.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a-h/gpu/d4b763031d4b4741658cbe32ca4aad563846fd6d/examples/grayscale/puppy-g7b38fec9b_1920.jpg
--------------------------------------------------------------------------------
/examples/mandelbrot/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	_ "embed"
5 | 	"image"
6 | 	"image/png"
7 | 	"log"
8 | 	"os"
9 | 
10 | 	"github.com/a-h/gpu"
11 | )
12 | 
13 | //go:embed mandelbrot.metal
14 | var source string
15 | 
16 | func main() {
17 | 	gpu.Compile(source)
18 | 
19 | 	size := image.Rect(0, 0, 1920, 1080)
20 | 	// Input and output are both unpopulated; the kernel never reads input,
21 | 	// which only sizes the GPU grid.
22 | 	stride := 4 // r, g, b, and a.
23 | 	input := gpu.NewMatrix[uint8](size.Dx()*stride, size.Dy(), 1)
24 | 	output := gpu.NewMatrix[uint8](size.Dx()*stride, size.Dy(), 1)
25 | 
26 | 	// Run the processing.
27 | 	gpu.Run(input, output)
28 | 
29 | 	// Write the output.
30 | 	fo, err := os.Create("mandelbrot.png")
31 | 	if err != nil {
32 | 		log.Fatalf("failed to create mandelbrot PNG: %v", err)
33 | 	}
34 | 	img := image.NewRGBA(size)
35 | 	img.Pix = output.Data
36 | 	err = png.Encode(fo, img)
37 | 	if err != nil {
38 | 		log.Fatalf("failed to write PNG to disk: %v", err)
39 | 	}
40 | 	// Close explicitly so an error flushing to disk is reported, not dropped.
41 | 	if err := fo.Close(); err != nil {
42 | 		log.Fatalf("failed to close mandelbrot PNG: %v", err)
43 | 	}
44 | }
45 | 
--------------------------------------------------------------------------------
/examples/mandelbrot/mandelbrot.metal:
--------------------------------------------------------------------------------
1 | #include <metal_stdlib>
2 | using namespace metal;
3 | 
4 | // Code to support mandelbrot set calculations.
5 | 
6 | constant int maxRainbow = 256 * 3;  // Iteration counts are scaled into [0, maxRainbow].
7 | constant int maxIterations = 250;   // Escape-time iteration limit.
8 | 
9 | int section(int n) { // First colour index of the n-th 256-wide band.
10 |   return 256 * n;
11 | }
12 | // RGBA is one 4-byte output pixel.
13 | struct RGBA {
14 |   uint8_t r;
15 |   uint8_t g;
16 |   uint8_t b;
17 |   uint8_t a;
18 | };
19 | // colorFromIndex maps i onto five 256-wide rainbow bands, wrapping every section(5).
20 | struct RGBA colorFromIndex(int i) {
21 |   i = i % section(5);
22 |   // Red to yellow.
23 |   if(i < section(1)) {
24 |     return RGBA{255, (uint8_t)i, 0, 255};
25 |   }
26 |   // Yellow to green.
27 |   if(i < section(2)) {
28 |     return RGBA{(uint8_t)(section(2) - i - 1), 255, 0, 255};
29 |   }
30 |   // Green to light blue (blue ramps 0..255 across the band).
31 |   if(i < section(3)) {
32 |     return RGBA{0, 255, (uint8_t)(i - section(2)), 255};
33 |   }
34 |   // Light blue to dark blue.
35 |   if(i < section(4)) {
36 |     return RGBA{0, (uint8_t)(section(4) - i - 1), 255, 255};
37 |   }
38 |   // Dark blue to purple (red ramps 0..255 across the band).
39 |   return RGBA{(uint8_t)(i - section(4)), 0, 255, 255};
40 | }
41 | // scale linearly maps v from [0, fromMax] onto [toMin, toMax].
42 | float scale(float fromMax, float toMin, float toMax, float v) {
43 |   return ((v / fromMax) * (toMax - toMin)) + toMin;
44 | }
45 | 
46 | // isInSet returns 0 for numbers that are in the set, or the number of iterations taken to escape.
47 | int isInSet(float creal, float cimag) {
48 |   float zreal = creal;
49 |   float zimag = cimag;
50 |   for(int n = 0; n < maxIterations; n++) {
51 |     float zzreal = zreal;
52 |     float zzimag = zimag;
53 |     zreal = (zzreal*zzreal - zzimag*zzimag) + creal;
54 |     zimag = (zzreal*zzimag + zzimag*zzreal) + cimag;
55 |     // The orbit escapes once |z| > 2, i.e. |z|^2 > 4. Testing each
56 |     // component against +2 alone never catches orbits that diverge
57 |     // towards negative values.
58 |     if(zreal*zreal + zimag*zimag > 4.0f) {
59 |       // n is zero-based; report n + 1 so that escaping on the first
60 |       // iteration is not confused with 0 ("in the set").
61 |       return n + 1;
62 |     }
63 |   }
64 |   return 0;
65 | }
66 | 
67 | // Normal metal code from here onwards.
68 | 
69 | // Params mirrors the Params struct in mtl.h: the dimensions of the input
70 | // and output matrices, measured in elements (bytes for this kernel).
71 | typedef struct Params {
72 |   int w_in, h_in, d_in;
73 |   int w_out, h_out, d_out;
74 | } Params;
75 | 
76 | // idx flattens (x, y, z) into an offset for a matrix laid out as z planes
77 | // of h rows of w elements. d is unused; it keeps the signature symmetric.
78 | int idx(int x, int y, int z, int w, int h, int d) {
79 |   int i = z * w * h;
80 |   i += y * w;
81 |   i += x;
82 |   return i;
83 | }
84 | 
85 | // process renders one RGBA pixel of the Mandelbrot set for every 4 bytes
86 | // of output. The input buffer is never read; it only sizes the grid.
87 | kernel void process(device const Params* p,
88 |                     device uint8_t* input,
89 |                     device uint8_t* output,
90 |                     uint3 gridSize[[threads_per_grid]],
91 |                     uint3 gid[[thread_position_in_grid]]) {
92 |   // Only process once per pixel of data (4 uint8_t)
93 |   if(gid.x % 4 != 0) {
94 |     return;
95 |   }
96 | 
97 |   int x = gid.x / 4;
98 |   int y = gid.y;
99 |   // NOTE(review): w is the grid width in bytes (4 per pixel) while x is
100 |   // in pixels, so x / w only spans [0, 0.25). The visible-area constants
101 |   // below appear to be tuned around this; confirm before changing either.
102 |   int w = gridSize[0];
103 |   int h = gridSize[1];
104 | 
105 |   int index = idx(gid.x, gid.y, gid.z,
106 |                   p->w_in, p->h_in, p->d_in);
107 | 
108 |   // Parameters to define the visible area.
109 |   float min_r = -1.4;
110 |   float max_r = 3.0;
111 |   float min_i = -0.8;
112 |   float max_i = 0.8;
113 | 
114 |   // Show some numbers.
115 |   float r = scale(w, min_r, max_r, float(x));
116 |   float i = scale(h, min_i, max_i, float(y));
117 |   int n = isInSet(r, i);
118 |   if(n == 0) {
119 |     // In the set: opaque black.
120 |     output[index+0] = 0;
121 |     output[index+1] = 0;
122 |     output[index+2] = 0;
123 |     output[index+3] = 255;
124 |   } else {
125 |     float rainbowIndex = scale(float(maxIterations), 0.0, float(maxRainbow), float(n));
126 |     RGBA c = colorFromIndex(int(rainbowIndex));
127 |     output[index+0] = c.r;
128 |     output[index+1] = c.g;
129 |     output[index+2] = c.b;
130 |     output[index+3] = 255;
131 |   }
132 | }
133 | 
134 | 
--------------------------------------------------------------------------------
/examples/mandelbrot/mandelbrot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a-h/gpu/d4b763031d4b4741658cbe32ca4aad563846fd6d/examples/mandelbrot/mandelbrot.png
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/a-h/gpu
2 | 
3 | go 1.18
4 | 
5 | require github.com/google/go-cmp v0.5.7
6 | 
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o=
2 | github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
3 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
4 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
5 | 
--------------------------------------------------------------------------------
/matrix_test.go:
--------------------------------------------------------------------------------
1 | package gpu
2 | 
3 | import (
4 | 	"testing"
5 | 
6 | 	"github.com/google/go-cmp/cmp"
7 | )
8 | 
9 | func TestMatrix(t *testing.T) {
10 | 	t.Run("1D matrix", func(t *testing.T) {
11 | 		m := NewMatrix[float32](10, 1, 1)
12 | 		m.Set(0, 0, 0, 1.0)
13 | 		m.Set(9, 0, 0, 1.0)
14 | 		expected := []float32{1.0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0}
15 | 		expectMatrix(t, expected, m.Data)
16 | 		expectValue(t, m, 0, 0, 0, 1.0)
17 | 		expectValue(t, m, 9, 0, 0, 1.0)
18 | 	})
19 | 	t.Run("2D matrix", func(t *testing.T) {
20 | 		m := NewMatrix[float32](3, 3, 1)
21 | 		m.Set(0, 0, 0, 1.0)
22 | 		m.Set(1, 1, 0, 1.0)
23 | 		m.Set(2, 2, 0, 1.0)
24 | 		expected := []float32{
25 | 			1.0, 0.0, 0.0,
26 | 			0.0, 1.0, 0.0,
27 | 			0.0, 0.0, 1.0,
28 | 		}
29 | 		expectMatrix(t, expected, m.Data)
30 | 		expectValue(t, m, 2, 2, 0, 1.0)
31 | 	})
32 | 	t.Run("3D matrix", func(t *testing.T) {
33 | 		m := NewMatrix[float32](3, 3, 2)
34 | 		m.Set(0, 0, 1, 1.0)
35 | 		m.Set(1, 1, 1, 1.0)
36 | 		m.Set(2, 2, 1, 1.0)
37 | 		expected := []float32{
38 | 			0.0, 0.0, 0.0,
39 | 			0.0, 0.0, 0.0,
40 | 			0.0, 0.0, 0.0,
41 | 			1.0, 0.0, 0.0,
42 | 			0.0, 1.0, 0.0,
43 | 			0.0, 0.0, 1.0,
44 | 		}
45 | 		expectMatrix(t, expected, m.Data)
46 | 		expectValue(t, m, 2, 2, 0, 0.0)
47 | 		expectValue(t, m, 2, 2, 1, 1.0)
48 | 	})
49 | 	t.Run("Get before Set", func(t *testing.T) {
50 | 		// Regression: Get on a fresh matrix must not prevent a later Set
51 | 		// from seeing allocated Data.
52 | 		m := NewMatrix[float32](2, 2, 1)
53 | 		expectValue(t, m, 1, 1, 0, 0.0)
54 | 		m.Set(1, 1, 0, 1.0)
55 | 		expectValue(t, m, 1, 1, 0, 1.0)
56 | 	})
57 | }
58 | 
59 | // expectMatrix fails the test if got differs from want, reporting a diff.
60 | func expectMatrix[T GPUType](t *testing.T, want, got []T) {
61 | 	t.Helper()
62 | 	if diff := cmp.Diff(want, got); diff != "" {
63 | 		t.Error(diff)
64 | 	}
65 | }
66 | 
67 | // expectValue fails the test if the value of m at (x, y, z) is not expected.
68 | func expectValue[T GPUType](t *testing.T, m *Matrix[T], x, y, z int, expected T) {
69 | 	t.Helper()
70 | 	if got := m.Get(x, y, z); got != expected {
71 | 		t.Errorf("expected %v, got %v", expected, got)
72 | 	}
73 | }
74 | 
--------------------------------------------------------------------------------
/mtl.go:
--------------------------------------------------------------------------------
1 | //go:build darwin
2 | // +build darwin
3 | 
4 | package gpu
5 | 
6 | /*
7 | #cgo LDFLAGS: -framework Metal -framework CoreGraphics -framework Foundation
8 | #include <stdlib.h>
9 | #include <stdbool.h>
10 | #include "mtl.h"
11 | */
12 | import "C"
13 | import (
14 | 	"sync"
15 | 	"unsafe"
16 | )
17 | 
18 | // GPUType is the set of element types a Matrix can hold and pass to Metal.
19 | type GPUType interface {
20 | 	uint8 | uint32 | int32 | float32
21 | }
22 | 
23 | // NewMatrix returns a w x h x d matrix whose Data is allocated lazily, on
24 | // the first Set or Get (or an explicit Populate).
25 | func NewMatrix[T GPUType](w, h, d int) *Matrix[T] {
26 | 	// Store matrix like a display buffer, i.e.
27 | 	// get a whole y row, and index it with x.
28 | 	// so, d, y, x
29 | 	m := &Matrix[T]{
30 | 		W:    w,
31 | 		H:    h,
32 | 		D:    d,
33 | 		init: &sync.Once{},
34 | 	}
35 | 	return m
36 | }
37 | 
38 | // Matrix is a dense W x H x D array stored in Data as D planes of H rows
39 | // of W elements.
40 | type Matrix[T GPUType] struct {
41 | 	W, H, D int
42 | 	Data    []T
43 | 	init    *sync.Once
44 | }
45 | 
46 | // Populate allocates a zeroed Data slice of W*H*D elements unless Data is
47 | // already set. Only the first call does any work.
48 | func (m *Matrix[T]) Populate() {
49 | 	if m.init == nil {
50 | 		// A Matrix built without NewMatrix has no Once yet.
51 | 		m.init = &sync.Once{}
52 | 	}
53 | 	m.init.Do(func() {
54 | 		if m.Data == nil {
55 | 			m.Data = make([]T, m.W*m.H*m.D)
56 | 		}
57 | 	})
58 | }
59 | 
60 | // Index returns the offset of (x, y, z) in Data. It does not bounds-check.
61 | func (m Matrix[T]) Index(x, y, z int) (i int) {
62 | 	i += z * m.W * m.H
63 | 	i += y * m.W
64 | 	i += x
65 | 	return i
66 | }
67 | 
68 | // Set stores v at (x, y, z), allocating Data first if needed.
69 | func (m *Matrix[T]) Set(x, y, z int, v T) {
70 | 	m.Populate()
71 | 	m.Data[m.Index(x, y, z)] = v
72 | }
73 | 
74 | // Get returns the value at (x, y, z), allocating Data first if needed.
75 | // It needs a pointer receiver so that the allocation sticks: on a copy the
76 | // shared sync.Once would fire while only the copy received Data, and a
77 | // later Set on the original would then index a nil slice and panic.
78 | func (m *Matrix[T]) Get(x, y, z int) T {
79 | 	m.Populate()
80 | 	return m.Data[m.Index(x, y, z)]
81 | }
82 | 
83 | // Size returns the number of elements, W*H*D.
84 | func (m Matrix[T]) Size() int {
85 | 	return m.W * m.H * m.D
86 | }
87 | 
88 | // Compile the shader. Only needs to be done once.
89 | // The source must define a kernel named "process"; failures are logged by
90 | // compile in mtl.m rather than returned.
91 | func Compile(shaderCode string) {
92 | 	src := C.CString(shaderCode)
93 | 	defer C.free(unsafe.Pointer(src))
94 | 	C.compile(src)
95 | }
96 | 
97 | // dataSizeBytes returns the size in bytes of one element of type T.
98 | func dataSizeBytes[T GPUType]() int32 {
99 | 	var v T
100 | 	return int32(unsafe.Sizeof(v))
101 | }
102 | 
103 | // params matches the Params struct in mtl.h (six 32-bit ints), so the two
104 | // have the same layout and a *params can be passed to C as a *C.Params.
105 | type params struct {
106 | 	// Size of input matrix.
107 | 	WIn, HIn, DIn int32
108 | 	// Size of output matrix.
109 | 	WOut, HOut, DOut int32
110 | }
111 | 
112 | // Run copies input to the GPU, runs the compiled "process" kernel with one
113 | // thread per input element and stores the result in output.Data. Compile
114 | // must be called first. Run panics if the kernel could not be executed.
115 | func Run[TIn GPUType, TOut GPUType](input *Matrix[TIn], output *Matrix[TOut]) {
116 | 	// Setup.
117 | 	// Allocate input so the GPU buffer always holds W*H*D elements, even
118 | 	// when the caller never wrote to it.
119 | 	input.Populate()
120 | 	var in unsafe.Pointer
121 | 	var inputSize int
122 | 	if len(input.Data) > 0 {
123 | 		in = unsafe.Pointer(&input.Data[0])
124 | 		// Copy no more elements than the slice actually holds.
125 | 		inputSize = len(input.Data)
126 | 	}
127 | 	C.createBuffers(in, C.int(dataSizeBytes[TIn]()), C.int(inputSize),
128 | 		C.int(dataSizeBytes[TOut]()), C.int(output.Size()))
129 | 	// Convert the Go param struct to its C version.
130 | 	p := params{
131 | 		WIn:  int32(input.W),
132 | 		HIn:  int32(input.H),
133 | 		DIn:  int32(input.D),
134 | 		WOut: int32(output.W),
135 | 		HOut: int32(output.H),
136 | 		DOut: int32(output.D),
137 | 	}
138 | 	cp := (*C.Params)(unsafe.Pointer(&p))
139 | 	// Run.
140 | 	ptr := C.run(cp)
141 | 	size := output.Size()
142 | 	if ptr == nil && size > 0 {
143 | 		panic("gpu: failed to run the kernel; see the Metal log output")
144 | 	}
145 | 	// ptr points into the Metal output buffer, which createBuffers replaces
146 | 	// on the next Run. Copy into Go-owned memory so output.Data does not
147 | 	// depend on that buffer's lifetime.
148 | 	data := make([]TOut, size)
149 | 	if size > 0 {
150 | 		copy(data, unsafe.Slice((*TOut)(ptr), size))
151 | 	}
152 | 	output.Data = data
153 | }
154 | 
--------------------------------------------------------------------------------
/mtl.h:
--------------------------------------------------------------------------------
1 | // +build darwin
2 | 
3 | typedef unsigned long uint_t;
4 | typedef unsigned char uint8_t;
5 | typedef unsigned short uint16_t;
6 | typedef unsigned long long uint64_t;
7 | 
8 | // Matches with Params type in mtl.go: the dimensions of the input and
9 | // output matrices as six 32-bit ints.
10 | typedef struct Params {
11 |   int w_in, h_in, d_in;
12 |   int w_out, h_out, d_out;
13 | } Params;
14 | 
15 | // compile builds the Metal shader source and prepares its "process" kernel.
16 | void compile(char* source);
17 | // createBuffers replaces the input buffer (filled from in) and the output
18 | // buffer used by the next run.
19 | void createBuffers(void* in, int in_data_size_bytes, int in_array_size,
20 |                    int out_data_size_bytes, int out_array_size);
21 | // run executes the kernel, blocks until it completes and returns the output
22 | // buffer's contents, or NULL on failure.
23 | void* run(Params *params);
24 | 
--------------------------------------------------------------------------------
/mtl.m:
--------------------------------------------------------------------------------
1 | // +build darwin
2 | 
3 | #include "mtl.h"
4 | #import <Metal/Metal.h>
5 | 
6 | // Metal state shared by compile and run. compile sets all three; run
7 | // requires pipelineState and commandQueue to be non-nil.
8 | id<MTLDevice> device;
9 | id<MTLComputePipelineState> pipelineState;
10 | id<MTLCommandQueue> commandQueue;
11 | 
12 | // compile builds source into a Metal library and creates a compute pipeline
13 | // for its "process" kernel plus a command queue. Failures are logged and
14 | // leave pipelineState and/or commandQueue nil, which run checks for.
15 | void compile(char *source) {
16 |   @autoreleasepool {
17 |     // Clear state from any earlier compile so a failure here is not masked.
18 |     pipelineState = nil;
19 |     commandQueue = nil;
20 |     device = MTLCreateSystemDefaultDevice();
21 |     if (device == nil) {
22 |       NSLog(@"Failed to find a Metal device.");
23 |       return;
24 |     }
25 |     // NSLog(@"Using default device %s", [device.name UTF8String]);
26 | 
27 |     // Create library of code.
28 |     NSError *error = nil;
29 |     MTLCompileOptions *compileOptions = [[MTLCompileOptions alloc] init];
30 |     compileOptions.languageVersion = MTLLanguageVersion1_1;
31 |     NSString *ss = [NSString stringWithUTF8String:source];
32 |     id<MTLLibrary> newLibrary = [device newLibraryWithSource:ss
33 |                                                      options:compileOptions
34 |                                                        error:&error];
35 |     if (newLibrary == nil) {
36 |       NSLog(@"Failed to create new library, error %@.", error);
37 |       return;
38 |     }
39 | 
40 |     // Add the process function.
41 |     id<MTLFunction> processFunction =
42 |         [newLibrary newFunctionWithName:@"process"];
43 |     if (processFunction == nil) {
44 |       NSLog(@"Failed to find the process function.");
45 |       return;
46 |     }
47 | 
48 |     //NSLog(@"%@", [newLibrary functionNames]);
49 | 
50 |     // Create a compute pipeline state object.
51 |     pipelineState = [device newComputePipelineStateWithFunction:processFunction
52 |                                                           error:&error];
53 |     if (pipelineState == nil) {
54 |       NSLog(@"Failed to create pipeline state object, error %@.", error);
55 |       return;
56 |     }
57 | 
58 |     commandQueue = [device newCommandQueue];
59 |     if (commandQueue == nil) {
60 |       NSLog(@"Failed to find the command queue.");
61 |       return;
62 |     }
63 |   }
64 | }
65 | 
66 | // Buffers used by the next run; createBuffers replaces them each time.
67 | id<MTLBuffer> bufferInput;
68 | id<MTLBuffer> bufferOutput;
69 | 
70 | // releaseBuffer drops our reference to a buffer that is about to be
71 | // overwritten. ARC does this on assignment; cgo does not enable ARC by
72 | // default, so under manual reference counting release it explicitly.
73 | static void releaseBuffer(id<MTLBuffer> buffer) {
74 | #if !__has_feature(objc_arc)
75 |   [buffer release];
76 | #endif
77 | }
78 | 
79 | // createBuffers creates the shared-storage input buffer, filled with
80 | // in_array_size elements of in_data_size_bytes copied from in, and an
81 | // output buffer of out_array_size elements of out_data_size_bytes.
82 | void createBuffers(void *in, int in_data_size_bytes, int in_array_size,
83 |                    int out_data_size_bytes, int out_array_size) {
84 |   // Multiply as NSUInteger so large matrices cannot overflow int.
85 |   NSUInteger inLength = (NSUInteger)in_array_size * (NSUInteger)in_data_size_bytes;
86 |   NSUInteger outLength = (NSUInteger)out_array_size * (NSUInteger)out_data_size_bytes;
87 | 
88 |   releaseBuffer(bufferInput);
89 |   releaseBuffer(bufferOutput);
90 | 
91 |   // Metal rejects zero-length buffers, so always request at least one
92 |   // byte. With nothing to copy from, the input is only allocated.
93 |   if (in != NULL && inLength > 0) {
94 |     bufferInput = [device newBufferWithBytes:in
95 |                                       length:inLength
96 |                                      options:MTLResourceStorageModeShared];
97 |   } else {
98 |     bufferInput = [device newBufferWithLength:MAX(inLength, (NSUInteger)1)
99 |                                       options:MTLResourceStorageModeShared];
100 |   }
101 |   bufferOutput = [device newBufferWithLength:MAX(outLength, (NSUInteger)1)
102 |                                      options:MTLResourceStorageModeShared];
103 | }
104 | 
105 | // run encodes one dispatch of the "process" kernel over a grid of
106 | // w_in x h_in x d_in threads, waits for it to finish and returns the
107 | // output buffer's contents, or NULL on failure. The returned memory
108 | // belongs to bufferOutput and must not be used after createBuffers
109 | // replaces it.
110 | void *run(Params *params) {
111 |   @autoreleasepool {
112 |     if (pipelineState == nil || commandQueue == nil) {
113 |       NSLog(@"No compute pipeline; compile must succeed before run.");
114 |       return NULL;
115 |     }
116 |     if (bufferInput == nil || bufferOutput == nil) {
117 |       NSLog(@"No buffers; createBuffers must succeed before run.");
118 |       return NULL;
119 |     }
120 | 
121 |     // Send compute command.
122 |     id<MTLCommandBuffer> commandBuffer = [commandQueue commandBuffer];
123 |     if (commandBuffer == nil) {
124 |       NSLog(@"Failed to get the command buffer.");
125 |       return NULL;
126 |     }
127 |     // Get the compute encoder.
128 |     id<MTLComputeCommandEncoder> computeEncoder =
129 |         [commandBuffer computeCommandEncoder];
130 |     if (computeEncoder == nil) {
131 |       NSLog(@"Failed to get the compute encoder.");
132 |       return NULL;
133 |     }
134 | 
135 |     // Encode the pipeline state object and its arguments: the params
136 |     // struct at index 0, then the input and output buffers.
137 |     [computeEncoder setComputePipelineState:pipelineState];
138 |     [computeEncoder setBytes:params length:sizeof(Params) atIndex:0];
139 |     [computeEncoder setBuffer:bufferInput offset:0 atIndex:1];
140 |     [computeEncoder setBuffer:bufferOutput offset:0 atIndex:2];
141 | 
142 |     // One thread per input element.
143 |     MTLSize threadsPerGrid = MTLSizeMake(params->w_in, params->h_in, params->d_in);
144 | 
145 |     // Calculate a threadgroup size.
146 |     // https://developer.apple.com/documentation/metal/calculating_threadgroup_and_grid_sizes?language=objc
147 |     NSUInteger w = pipelineState.threadExecutionWidth;
148 |     NSUInteger h = pipelineState.maxTotalThreadsPerThreadgroup / w;
149 |     MTLSize threadsPerThreadgroup = MTLSizeMake(w, h, 1);
150 | 
151 |     // Encode the compute command.
152 |     [computeEncoder dispatchThreads:threadsPerGrid
153 |               threadsPerThreadgroup:threadsPerThreadgroup];
154 | 
155 |     // End the compute pass.
156 |     [computeEncoder endEncoding];
157 | 
158 |     // Execute the command and block until the GPU has finished.
159 |     [commandBuffer commit];
160 |     [commandBuffer waitUntilCompleted];
161 | 
162 |     // On failure the output contents cannot be trusted.
163 |     if (commandBuffer.status == MTLCommandBufferStatusError) {
164 |       NSLog(@"Command buffer failed, error %@.", commandBuffer.error);
165 |       return NULL;
166 |     }
167 | 
168 |     return bufferOutput.contents;
169 |   }
170 | }
171 | 
--------------------------------------------------------------------------------