├── .gitignore ├── src ├── kernels │ ├── saxpy.zig │ └── reduce.zig ├── saxpy.zig ├── reduce.zig └── common.zig └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | .zig-cache/ 2 | zig-out 3 | -------------------------------------------------------------------------------- /src/kernels/saxpy.zig: -------------------------------------------------------------------------------- 1 | export fn saxpy(y: [*]addrspace(.global) f32, x: [*]addrspace(.global) const f32, a: f32) callconv(.kernel) void { 2 | const gid = @workGroupId(0) * @workGroupSize(0) + @workItemId(0); 3 | y[gid] += x[gid] * a; 4 | } 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Zig OpenCL SPIR-V Demos 2 | 3 | This repository hosts a bunch of examples of writing OpenCL kernels in Zig using the Zig SPIR-V backend. 4 | 5 | To compile these examples, you often need a master build of Zig, or a preview version of the Zig branch that I'm happening to be working on. 6 | -------------------------------------------------------------------------------- /src/kernels/reduce.zig: -------------------------------------------------------------------------------- 1 | const builtin = @import("builtin"); 2 | 3 | pub const items_per_thread = 24; 4 | pub const block_dim = 256; 5 | pub const items_per_block = items_per_thread * block_dim; 6 | 7 | var shared: [block_dim]f32 addrspace(.shared) = undefined; 8 | 9 | fn syncThreads() void { 10 | asm volatile( 11 | \\OpControlBarrier %execution_scope %memory_scope %semantics 12 | :: [execution_scope] "" (@as(u32, 2)), // Workgroup scope 13 | [memory_scope] "" (@as(u32, 2)), // Workgroup scope 14 | [semantics] "" (@as(u32, 0x100 | 0x10)) // SequentiallyConsistent | WorkgroupMemory 15 | ); 16 | } 17 | 18 | fn reduce( 19 | input: [*]const addrspace(.global) f32, 20 | output: [*]addrspace(.global) f32, 21 | last_block: u32, 22 | valid_in_last_block: u32, 23 | ) callconv(.kernel) void { 24 | const bid = @workGroupId(0); 25 | const tid = @workItemId(0); 26 | const block_offset = bid * items_per_block; 27 | 28 | var total: f32 = 0; 29 | if (bid == last_block) { 30 | inline for (0..items_per_thread) |i| { 31 | const index = block_dim * i + tid; 32 | if (index < valid_in_last_block) { 33 | total += input[block_offset + block_dim * i + tid]; 34 | } 35 | } 36 | } else { 37 | inline for (0..items_per_thread) |i| { 38 | total += input[block_offset + block_dim * i + tid]; 39 | } 40 | } 41 | 42 | shared[tid] = total; 43 | 44 | syncThreads(); 45 | 46 | comptime var i: usize = 1; 47 | inline while (i < block_dim) : (i <<= 1) { 48 | if (tid % (i * 2) == 0) { 49 | shared[tid] += (&shared)[tid + i]; 50 | } 51 | syncThreads(); 52 | } 53 | 54 | if (tid == 0) { 55 | output[bid] = (&shared)[0]; 56 | } 57 | } 58 | 59 | comptime { 60 | // Only export the kernel when compiling for the device 61 | // so that if we @import this file from host it doesn't 62 | // try to reference the kernel. 63 | if (builtin.os.tag == .opencl) { 64 | @export(&reduce, .{ .name = "reduce" }); 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /src/saxpy.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const cl = @import("opencl"); 3 | const common = @import("common.zig"); 4 | 5 | pub const std_options = common.std_options; 6 | 7 | pub fn main() !void { 8 | var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator); 9 | defer arena.deinit(); 10 | const allocator = arena.allocator(); 11 | 12 | const platform, const device = try common.parseOptions(allocator); 13 | 14 | const context = try cl.createContext(&.{device}, .{ .platform = platform }); 15 | defer context.release(); 16 | 17 | const queue = try cl.createCommandQueue(context, device, .{ .profiling = true }); 18 | defer queue.release(); 19 | 20 | const program = try common.buildSpvProgram(allocator, context, device, @embedFile("saxpy-kernel")); 21 | defer program.release(); 22 | 23 | const kernel = try cl.createKernel(program, "saxpy"); 24 | defer kernel.release(); 25 | 26 | std.log.debug("generating inputs...", .{}); 27 | 28 | const size = 256 * 1024 * 1024; 29 | 30 | const y, const x = blk: { 31 | const y = try allocator.alloc(f32, size); 32 | const x = try allocator.alloc(f32, size); 33 | for (x, 0..) |*value, i| value.* = @floatFromInt(i); 34 | for (y, 0..) |*value, i| value.* = @floatFromInt(i + 1000); 35 | break :blk .{ y, x }; 36 | }; 37 | 38 | const results = try allocator.alloc(f32, size); 39 | 40 | const a: f32 = 123; 41 | 42 | const d_y = try cl.createBufferWithData(f32, context, .{ .read_write = true }, y); 43 | const d_x = try cl.createBufferWithData(f32, context, .{ .read_only = true }, x); 44 | 45 | std.log.debug("launching kernel with {} inputs per array", .{size}); 46 | 47 | try kernel.setArg(@TypeOf(d_y), 0, d_y); 48 | try kernel.setArg(@TypeOf(d_x), 1, d_x); 49 | try kernel.setArg(f32, 2, a); 50 | 51 | const saxpy_complete = try queue.enqueueNDRangeKernel( 52 | kernel, 53 | null, 54 | &.{size}, 55 | &.{256}, 56 | &.{}, 57 | ); 58 | defer saxpy_complete.release(); 59 | 60 | const read_complete = try queue.enqueueReadBuffer( 61 | f32, 62 | d_y, 63 | false, 64 | 0, 65 | results, 66 | &.{saxpy_complete}, 67 | ); 68 | defer read_complete.release(); 69 | 70 | try cl.waitForEvents(&.{read_complete}); 71 | 72 | const start = try saxpy_complete.commandStartTime(); 73 | const stop = try saxpy_complete.commandEndTime(); 74 | const runtime = stop - start; 75 | const tput = size * @sizeOf(f32) * std.time.ns_per_s / runtime; 76 | std.log.info("kernel took {d:.2} us, {Bi:.2}/s", .{runtime / std.time.ns_per_us, tput}); 77 | 78 | std.log.debug("checking results...", .{}); 79 | 80 | // Compute reference results on host 81 | for (y, x) |*yi, xi| { 82 | yi.* += xi * a; 83 | } 84 | 85 | // Check if the results are close. 86 | // y = y + a * x is 2 operations of 0.5 ulp each, 87 | // multiply by 2 for host and device side error. 88 | const max_error = std.math.floatEps(f32) * 2 * 2; 89 | for (results, y, 0..) |ri, yi, i| { 90 | if (!std.math.approxEqRel(f32, ri, yi, max_error)) { 91 | common.fail("invalid result at index {}: expected = {d}, actual = {d}", .{ i, yi, ri }); 92 | } 93 | } 94 | 95 | std.log.info("results ok", .{}); 96 | } 97 | -------------------------------------------------------------------------------- /src/reduce.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const cl = @import("opencl"); 3 | const common = @import("common.zig"); 4 | const reduce = @import("kernels/reduce.zig"); 5 | 6 | pub const std_options = common.std_options; 7 | 8 | pub fn main() !void { 9 | var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator); 10 | defer arena.deinit(); 11 | const allocator = arena.allocator(); 12 | 13 | const platform, const device = try common.parseOptions(allocator); 14 | 15 | const context = try cl.createContext(&.{device}, .{ .platform = platform }); 16 | defer context.release(); 17 | 18 | const queue = try cl.createCommandQueue(context, device, .{ .profiling = true }); 19 | defer queue.release(); 20 | 21 | const program = try common.buildSpvProgram(allocator, context, device, @embedFile("reduce-kernel")); 22 | defer program.release(); 23 | 24 | const kernel = try cl.createKernel(program, "reduce"); 25 | defer kernel.release(); 26 | 27 | std.log.debug("generating inputs...", .{}); 28 | 29 | const size = 256 * 1024 * 1024; 30 | 31 | const input = blk: { 32 | const values = try allocator.alloc(f32, size); 33 | var rng = std.Random.DefaultPrng.init(0); 34 | const random = rng.random(); 35 | for (values) |*value| value.* = random.float(f32); 36 | break :blk values; 37 | }; 38 | 39 | var d_input = try cl.createBufferWithData(f32, context, .{ .read_write = true }, input); 40 | var d_output = try cl.createBuffer(f32, context, .{ .read_write = true }, input.len); 41 | defer d_input.release(); 42 | defer d_output.release(); 43 | 44 | var maybe_event: ?cl.Event = null; 45 | var first_event: ?cl.Event = null; 46 | var remaining_size: usize = input.len; 47 | while (remaining_size != 1) { 48 | const blocks = std.math.divCeil(usize, remaining_size, reduce.items_per_block) catch unreachable; 49 | const valid_in_last_block = remaining_size % reduce.items_per_block; 50 | std.log.debug("reducing {d} items over {d} block(s)", .{ remaining_size, blocks }); 51 | 52 | try kernel.setArg(@TypeOf(d_input), 0, d_input); 53 | try kernel.setArg(@TypeOf(d_output), 1, d_output); 54 | try kernel.setArg(u32, 2, @intCast(blocks - 1)); 55 | try kernel.setArg(u32, 3, @intCast(valid_in_last_block)); 56 | 57 | maybe_event = try queue.enqueueNDRangeKernel( 58 | kernel, 59 | null, 60 | &.{blocks * reduce.block_dim}, 61 | &.{reduce.block_dim}, 62 | if (maybe_event) |event| &.{event} else &.{}, 63 | ); 64 | 65 | if (first_event == null) { 66 | first_event = maybe_event.?; 67 | } 68 | 69 | const d_tmp = d_input; 70 | d_input = d_output; 71 | d_output = d_tmp; 72 | 73 | remaining_size = blocks; 74 | } 75 | 76 | var result: f32 = undefined; 77 | const read_complete = try queue.enqueueReadBuffer( 78 | f32, 79 | d_input, 80 | false, 81 | 0, 82 | @as(*[1]f32, &result), 83 | if (maybe_event) |event| &.{event} else &.{}, 84 | ); 85 | defer read_complete.release(); 86 | try cl.waitForEvents(&.{read_complete}); 87 | 88 | if (maybe_event) |last_event| { 89 | const start = try first_event.?.commandStartTime(); 90 | const stop = try last_event.commandEndTime(); 91 | const runtime = stop - start; 92 | const tput = input.len * @sizeOf(f32) * std.time.ns_per_s / runtime; 93 | std.log.info("reduction took {d:.2} us, {Bi:.2}/s", .{runtime / std.time.ns_per_us, tput}); 94 | } 95 | 96 | // input.len * random in [0, 1) yields an average of input.len * 0.5 97 | const expected = input.len / 2; 98 | std.log.debug("result: {d}, expected: {d}", .{result, expected}); 99 | } 100 | 101 | -------------------------------------------------------------------------------- /src/common.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const Allocator = std.mem.Allocator; 3 | 4 | const cl = @import("opencl"); 5 | 6 | // Downstream demos should import this 7 | pub const std_options: std.Options = .{ 8 | .log_level = .debug, 9 | .logFn = log, 10 | }; 11 | 12 | var log_verbose: bool = false; 13 | 14 | pub fn log( 15 | comptime level: std.log.Level, 16 | comptime scope: @TypeOf(.EnumLiteral), 17 | comptime format: []const u8, 18 | args: anytype, 19 | ) void { 20 | _ = scope; 21 | if (@intFromEnum(level) <= @intFromEnum(std.log.Level.info) or log_verbose) { 22 | switch (level) { 23 | .info => std.debug.print(format ++ "\n", args), 24 | else => { 25 | const prefix = comptime level.asText(); 26 | std.debug.print(prefix ++ ": " ++ format ++ "\n", args); 27 | }, 28 | } 29 | } 30 | } 31 | 32 | pub fn fail(comptime fmt: []const u8, args: anytype) noreturn { 33 | std.log.err(fmt, args); 34 | std.process.exit(1); 35 | } 36 | 37 | const DeviceAndPlatform = struct { cl.Platform, cl.Device }; 38 | 39 | pub fn parseOptions(a: Allocator) !DeviceAndPlatform { 40 | var args = try std.process.argsWithAllocator(a); 41 | defer args.deinit(); 42 | 43 | const exe_name = args.next().?; 44 | 45 | var platform: ?[]const u8 = null; 46 | var device: ?[]const u8 = null; 47 | var help: bool = false; 48 | 49 | while (args.next()) |arg| { 50 | if (std.mem.eql(u8, arg, "--platform") or std.mem.eql(u8, arg, "-p")) { 51 | platform = args.next() orelse fail("missing argument to option {s}", .{arg}); 52 | } else if (std.mem.eql(u8, arg, "--device") or std.mem.eql(u8, arg, "-d")) { 53 | device = args.next() orelse fail("missing argument to option {s}", .{arg}); 54 | } else if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) { 55 | help = true; 56 | } else if (std.mem.eql(u8, arg, "--verbose") or std.mem.eql(u8, arg, "-v")) { 57 | log_verbose = true; 58 | } else { 59 | fail("unknown option '{s}'", .{arg}); 60 | } 61 | } 62 | 63 | if (help) { 64 | const out: std.fs.File = .stdout(); 65 | var writer = out.writer(&.{}); 66 | try writer.interface.print( 67 | \\usage: {s} [options...] 68 | \\ 69 | \\Options: 70 | \\--platform|-p OpenCL platform name to use. By default, uses the 71 | \\ first platform that has any devices available. 72 | \\--device|-d OpenCL device name to use. If --platform is left 73 | \\ unspecified, all devices of all platforms are 74 | \\ matched. By default, uses the first device of the 75 | \\ platform. 76 | \\--verbose|-v Turn on verbose logging. 77 | \\--help -h Show this message and exit. 78 | \\ 79 | , 80 | .{exe_name}, 81 | ); 82 | std.process.exit(0); 83 | } 84 | 85 | return try pickPlatformAndDevice(a, platform, device); 86 | } 87 | 88 | fn deviceSupportsSpirv(a: Allocator, device: cl.Device) !bool { 89 | // TODO: Check for OpenCL 3.0 before accessing this function? 90 | const ils = try device.getILsWithVersion(a); 91 | defer a.free(ils); 92 | 93 | for (ils) |il| { 94 | // TODO: Minimum version? 95 | if (std.mem.eql(u8, il.getName(), "SPIR-V")) { 96 | std.log.debug("Support for SPIR-V version {}.{}.{} detected", .{ 97 | il.version.major, 98 | il.version.minor, 99 | il.version.patch, 100 | }); 101 | return true; 102 | } 103 | } 104 | 105 | return false; 106 | } 107 | 108 | fn pickPlatformAndDevice( 109 | a: Allocator, 110 | maybe_platform_query: ?[]const u8, 111 | maybe_device_query: ?[]const u8, 112 | ) !DeviceAndPlatform { 113 | const platforms = try cl.getPlatforms(a); 114 | std.log.debug("{} platform(s) available", .{platforms.len}); 115 | 116 | if (platforms.len == 0) { 117 | fail("no opencl platform available", .{}); 118 | } 119 | 120 | for (platforms) |platform| { 121 | const platform_name = try platform.getName(a); 122 | defer a.free(platform_name); 123 | 124 | if (maybe_platform_query) |platform_query| { 125 | if (!std.mem.eql(u8, platform_name, platform_query)) { 126 | continue; 127 | } 128 | } 129 | 130 | std.log.debug("trying platform '{s}'", .{platform_name}); 131 | 132 | const devices = try platform.getDevices(a, cl.DeviceType.all); 133 | defer a.free(devices); 134 | 135 | if (devices.len == 0) { 136 | if (maybe_platform_query != null) { 137 | fail("platform '{s}' has no devices available", .{platform_name}); 138 | } 139 | continue; 140 | } 141 | 142 | for (devices) |device| { 143 | const device_name = try device.getName(a); 144 | defer a.free(device_name); 145 | 146 | if (maybe_device_query) |device_query| { 147 | if (!std.mem.eql(u8, device_name, device_query)) { 148 | continue; 149 | } 150 | } 151 | std.log.debug("trying device '{s}'", .{device_name}); 152 | 153 | if (!try deviceSupportsSpirv(a, device)) { 154 | if (maybe_device_query != null) { 155 | fail("device '{s}' of platform '{s}' does not support SPIR-V ingestion", .{ device_name, platform_name }); 156 | } 157 | } 158 | 159 | std.log.info("selected platform '{s}' and device '{s}'", .{ platform_name, device_name }); 160 | 161 | return .{ platform, device }; 162 | } 163 | 164 | if (maybe_device_query) |device_query| { 165 | if (maybe_platform_query != null) { 166 | fail("platform '{s}' does not have any device that matches '{s}'", .{ platform_name, device_query }); 167 | } 168 | } 169 | } 170 | 171 | // Case where both platform and device queries are not null is already handled above. 172 | 173 | if (maybe_platform_query) |platform_query| { 174 | fail("no such opencl platform '{s}'", .{platform_query}); 175 | } else if (maybe_device_query) |device_query| { 176 | fail("no such opencl device '{s}'", .{device_query}); 177 | } 178 | 179 | unreachable; 180 | } 181 | 182 | pub fn buildSpvProgram(a: Allocator, context: cl.Context, device: cl.Device, spv: []const u8) !cl.Program { 183 | std.log.debug("compiling program", .{}); 184 | 185 | const program = try cl.createProgramWithIL(context, spv); 186 | errdefer program.release(); 187 | 188 | program.build(&.{device}, "") catch |err| { 189 | if (err == error.BuildProgramFailure) { 190 | const build_log = try program.getBuildLog(a, device); 191 | defer a.free(build_log); 192 | std.log.err("failed to compile kernel:\n{s}", .{build_log}); 193 | } 194 | return err; 195 | }; 196 | 197 | return program; 198 | } 199 | --------------------------------------------------------------------------------