├── .python-version ├── rust-toolchain.toml ├── requirements.txt ├── metal_dump.py ├── .gitignore ├── src ├── metadata.rs ├── data.rs ├── workload.rs ├── strides.rs ├── shape.rs ├── handle.rs ├── dtype.rs ├── storage.rs ├── bench.rs ├── quant.rs ├── lib.rs └── tensor.rs ├── README.md ├── kernels ├── sgemv │ ├── sgemv_1v_2.wgsl │ ├── sgemv_1v.wgsl │ ├── sgemv_1.wgsl │ ├── sgemv_3.wgsl │ ├── sgemv_2v.wgsl │ ├── sgemv_2.wgsl │ └── mlx_sgemv.wgsl ├── sgemm │ ├── slow.wgsl │ ├── gemm_vectorized.wgsl │ └── gemm_scalar.wgsl ├── rope │ ├── rope_cp.wgsl │ └── rope.wgsl ├── layernorm │ ├── onepass_scalar.wgsl │ ├── onepass_vec4.wgsl │ ├── naive_scalar.wgsl │ ├── naive_vec4.wgsl │ ├── welford_scalar.wgsl │ └── welford_vec4.wgsl ├── qgemv │ ├── sgemv_2v.wgsl │ └── mlx-qgemv.wgsl └── qgemm │ ├── slow.wgsl │ ├── tfjs2.wgsl │ └── tfjs.wgsl ├── Cargo.toml ├── scratch ├── scratchy └── benches ├── layernorm ├── naive_onepass.rs ├── naive_vectorized.rs ├── naive_vectorized_onepass.rs ├── welford_scalar.rs ├── welford_vectorized.rs └── naive.rs ├── mlx-gemv └── gemv.rs ├── rope └── rope.rs ├── mlx-qgemv └── gemv.rs ├── qgemv └── gemv.rs ├── qgemm └── tfjs.rs ├── sgemv └── gemv.rs └── sgemm └── tfjs.rs /.python-version: -------------------------------------------------------------------------------- 1 | 3.11.7 2 | -------------------------------------------------------------------------------- /rust-toolchain.toml: -------------------------------------------------------------------------------- 1 | [toolchain] 2 | channel = "nightly" 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cpu 2 | numpy==1.24.3 3 | torch==2.0.1 4 | rotary-embedding-torch==0.5.3 5 | -------------------------------------------------------------------------------- /metal_dump.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | 3 | a = mx.random.uniform(shape=(16384, 3072)) 4 | b = mx.random.uniform(shape=(1, 3072)) 5 | mx.eval(a, b) 6 | 7 | trace_file = "mlx_trace.gputrace" 8 | 9 | # Make sure to run with MTL_CAPTURE_ENABLED=1 and 10 | # that the path trace_file does not already exist. 
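# For example, from the repo root: MTL_CAPTURE_ENABLED=1 python metal_dump.py
# The captured mlx_trace.gputrace bundle can then be opened in Xcode's Metal debugger.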
11 | mx.metal.start_capture(trace_file) 12 | for i in range(10): 13 | c = a @ b 14 | mx.eval(c) 15 | mx.metal.stop_capture() 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | debug/ 4 | target/ 5 | 6 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 7 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 8 | Cargo.lock 9 | 10 | # These are backup files generated by rustfmt 11 | **/*.rs.bk 12 | 13 | # MSVC Windows builds of rustc generate these, which store debugging information 14 | *.pdb 15 | 16 | 17 | # Added by cargo 18 | 19 | /target 20 | -------------------------------------------------------------------------------- /src/metadata.rs: -------------------------------------------------------------------------------- 1 | use encase::{private::WriteInto, ShaderType, UniformBuffer}; 2 | 3 | use crate::{GPUBuffer, GPUHandle}; 4 | 5 | pub const UNIFORM_ALIGN: usize = 256; 6 | pub const STORAGE_BUFFER_ALIGN: usize = 256; 7 | pub const MIN_STORAGE_BUFFER_SIZE: usize = 16; 8 | 9 | pub trait OpMetadata: Sized + ShaderType + WriteInto + std::fmt::Debug { 10 | fn into_buffer(&self, handle: &GPUHandle) -> GPUBuffer { 11 | let size: usize = self.size().get() as _; 12 | let aligned_size = size + (UNIFORM_ALIGN - size % UNIFORM_ALIGN); 13 | 14 | let mut uniform = UniformBuffer::new(Vec::with_capacity(aligned_size)); 15 | uniform.write(self).unwrap(); 16 | 17 | let buffer = handle.device().create_buffer(&wgpu::BufferDescriptor { 18 | label: None, 19 | size: aligned_size as u64, 20 | usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST, 21 | mapped_at_creation: false, 22 | }); 23 | 24 | handle 25 | .queue() 26 | .write_buffer(&buffer, 0, bytemuck::cast_slice(&uniform.into_inner())); 27 | buffer.into() 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # wgpu-bench 2 | 3 | Benchmark any WebGPU kernel. 4 | 5 | Check out `/benches` for an example: simply implement the Kernel trait and boom! 6 | 7 | Provide a Python snippet to ensure that your kernel is correct! 8 | 9 | ## Optimizing a LayerNorm Kernel 10 | 11 | Reproduce: 12 | ```bash 13 | cat Cargo.toml 14 | cargo bench --bench welford_vectorized # or any other [[bench]] target from Cargo.toml 15 | ``` 16 | Results on a 14-core M3 Max: 17 | ```bash 18 | Naive Onepass (precision FAIL) 19 | time: [88680.6021 ns 92554.9586 ns 95756.6363 ns] 20 | thrpt: [40.7935 GiB/s 42.2047 GiB/s 44.0485 GiB/s] 21 | 22 | Naive 23 | time: [115698.3828 ns 116433.8667 ns 117343.1857 ns] 24 | thrpt: [33.2891 GiB/s 33.5491 GiB/s 33.7624 GiB/s] 25 | 26 | Naive Vectorized 27 | time: [113990.1512 ns 114341.8859 ns 114775.5896 ns] 28 | thrpt: [34.0338 GiB/s 34.1629 GiB/s 34.2683 GiB/s] 29 | 30 | Welford Scalar 31 | time: [74209.2818 ns 74668.5137 ns 75306.7653 ns] 32 | thrpt: [51.8712 GiB/s 52.3146 GiB/s 52.6383 GiB/s] 33 | 34 | Welford Vectorized 35 | time: [48744.7028 ns 48831.9797 ns 48933.0603 ns] 36 | thrpt: [79.8284 GiB/s 79.9937 GiB/s 80.1369 GiB/s] 37 | ``` 38 | 39 | ## TODO 40 | 41 | - [x] Add throughput measurements 42 | - [ ] Encode more commands into a single command buffer (https://github.com/philipturner/metal-flash-attention/issues/12#issuecomment-1850300198) 43 | - [ ] Benchmark comparisons?
Shared code between similar kernels? 44 | - [ ] Simplify Kernel trait 45 | - [ ] Cleaning & Polishing 🧽 46 | -------------------------------------------------------------------------------- /kernels/sgemv/sgemv_1v_2.wgsl: -------------------------------------------------------------------------------- 1 | //https://www.bealto.com/gpu-gemv_v1.html 2 | var localId: vec3; 3 | var globalId: vec3; 4 | var workgroupId: vec3; 5 | 6 | @group(0) @binding(0) var A: array>; 7 | @group(0) @binding(1) var X: array>; 8 | @group(0) @binding(2) var result: array; 9 | @group(1) @binding(0) var metadata: Meta; 10 | 11 | 12 | struct Meta { 13 | aShape: vec3, 14 | aStrides: vec3, 15 | bShape: vec3, 16 | bStrides: vec3, 17 | outShape: vec3, 18 | outShapeStrides: vec3, 19 | dimAOuter: i32, 20 | dimBOuter: i32, 21 | dimInner: i32, 22 | } 23 | 24 | @compute @workgroup_size({{workgroup_size_x}},{{workgroup_size_y}},{{workgroup_size_z}}) 25 | fn main(@builtin(global_invocation_id) globalId : vec3) { 26 | var sum0 = vec4(0.0); 27 | var sum1 = vec4(0.0); 28 | 29 | let row = i32(globalId.x) * 2; 30 | 31 | let aIndex0 = metadata.aStrides.y * row / 4; 32 | let aIndex1 = metadata.aStrides.y * (row + 1) / 4; 33 | 34 | for (var k = 0; k < metadata.dimInner / 4; k+=1) { 35 | sum0 = fma(A[aIndex0 + k], X[k], sum0); 36 | sum1 = fma(A[aIndex1 + k], X[k], sum1); 37 | } 38 | let outIndex0 = metadata.outShapeStrides.y * row; 39 | let outIndex1 = metadata.outShapeStrides.y * (row + 1); 40 | result[outIndex0] = dot(sum0, vec4(1.0)); 41 | result[outIndex1] = dot(sum1, vec4(1.0)); 42 | } 43 | -------------------------------------------------------------------------------- /kernels/sgemv/sgemv_1v.wgsl: -------------------------------------------------------------------------------- 1 | //https://www.bealto.com/gpu-gemv_v1.html 2 | var localId: vec3; 3 | var globalId: vec3; 4 | var workgroupId: vec3; 5 | 6 | @group(0) @binding(0) var A: array>; 7 | @group(0) @binding(1) var X: array>; 8 | @group(0) @binding(2) var result: array; 9 | @group(1) @binding(0) var metadata: Meta; 10 | 11 | 12 | struct Meta { 13 | aShape: vec3, 14 | aStrides: vec3, 15 | bShape: vec3, 16 | bStrides: vec3, 17 | outShape: vec3, 18 | outShapeStrides: vec3, 19 | dimAOuter: i32, 20 | dimBOuter: i32, 21 | dimInner: i32, 22 | } 23 | 24 | @compute @workgroup_size({{workgroup_size_x}},{{workgroup_size_y}},{{workgroup_size_z}}) 25 | fn main(@builtin(global_invocation_id) globalId : vec3) { 26 | let batch = i32(globalId.z); 27 | let batchA = batch % metadata.aShape[0]; 28 | let batchB = batch % metadata.bShape[0]; 29 | 30 | let aOffset = metadata.aStrides.x * batchA / 4; 31 | let bOffset = metadata.bStrides.x * batchB / 4; 32 | let outOffset = metadata.outShapeStrides.x * batch; 33 | 34 | var sum = vec4(0.0); 35 | let row = i32(globalId.x); 36 | 37 | let aIndex = aOffset + metadata.aStrides.y * row / 4; 38 | for (var k = 0; k < metadata.dimInner / 4; k+=1) { 39 | sum = fma(A[aIndex + k], X[bOffset + k], sum); 40 | } 41 | let outIndex = outOffset + metadata.outShapeStrides.y * row; 42 | result[outIndex] = dot(sum, vec4(1.0)); 43 | } 44 | -------------------------------------------------------------------------------- /kernels/sgemv/sgemv_1.wgsl: -------------------------------------------------------------------------------- 1 | //https://www.bealto.com/gpu-gemv_v1.html 2 | var localId: vec3; 3 | var globalId: vec3; 4 | var workgroupId: vec3; 5 | 6 | @group(0) @binding(0) var A: array; 7 | @group(0) @binding(1) var X: array; 8 | @group(0) @binding(2) var result: array; 9 
| @group(1) @binding(0) var metadata: Meta; 10 | 11 | 12 | struct Meta { 13 | aShape: vec3, 14 | aStrides: vec3, 15 | bShape: vec3, 16 | bStrides: vec3, 17 | outShape: vec3, 18 | outShapeStrides: vec3, 19 | dimAOuter: i32, 20 | dimBOuter: i32, 21 | dimInner: i32, 22 | } 23 | 24 | @compute @workgroup_size({{workgroup_size_x}},{{workgroup_size_y}},{{workgroup_size_z}}) 25 | fn main(@builtin(local_invocation_id) localId : vec3, 26 | @builtin(global_invocation_id) globalId : vec3, 27 | @builtin(workgroup_id) workgroupId : vec3) { 28 | let batch = i32(globalId.z); 29 | let batchA = batch % metadata.aShape[0]; 30 | let batchB = batch % metadata.bShape[0]; 31 | 32 | let aOffset = metadata.aStrides.x * batchA; 33 | let bOffset = metadata.bStrides.x * batchB; 34 | let outOffset = metadata.outShapeStrides.x * batch; 35 | 36 | var sum = 0.0; 37 | let row = i32(globalId.x); 38 | let aIndex = aOffset + metadata.aStrides.y * row; 39 | for (var k = 0; k < metadata.dimInner; k+=1) { 40 | sum = fma(A[aIndex + k], X[bOffset + k], sum); 41 | } 42 | let outIndex = outOffset + metadata.outShapeStrides.y * row; 43 | result[outIndex] = sum; 44 | } 45 | -------------------------------------------------------------------------------- /kernels/sgemm/slow.wgsl: -------------------------------------------------------------------------------- 1 | //Unoptimized, only gets 500GFLOP 2 | @group(0) @binding(0) 3 | var A: array>; 4 | 5 | @group(0) @binding(1) 6 | var B: array>; 7 | 8 | @group(0) @binding(2) 9 | var C: array>; 10 | 11 | struct Meta { 12 | aShape: vec3, 13 | aStrides: vec3, 14 | bShape: vec3, 15 | bStrides: vec3, 16 | outShape: vec3, 17 | outStrides: vec3, 18 | dimInner: i32, 19 | } 20 | 21 | @group(1) @binding(0) 22 | var metadata: Meta; 23 | 24 | @compute @workgroup_size(8,8,1) 25 | fn main( 26 | @builtin(global_invocation_id) global_id: vec3 27 | ) { 28 | let M = u32(metadata.aShape.y); 29 | let N = u32(metadata.bShape.z); 30 | let K = u32(metadata.aShape.z); 31 | 32 | let a_offset = global_id.z * u32(metadata.aStrides.x); 33 | let b_offset = global_id.z * u32(metadata.bStrides.x); 34 | let c_offset = global_id.z * u32(metadata.outStrides.x); 35 | 36 | let cRow = global_id.x; 37 | let cCol = global_id.y; 38 | if (cRow < M && cCol < N / 4u) { 39 | var tmp = vec4(); 40 | for (var k = 0u; k < K / 4u; k++) { 41 | let a = A[a_offset + (cRow * (K / 4u) + k)]; 42 | let b_step = k * N + cCol; //4 rows per iter 43 | let b_stride = N / 4u; 44 | 45 | tmp = fma(vec4(a.x), B[b_offset + b_step], tmp); 46 | tmp = fma(vec4(a.y), B[b_offset + (b_step + b_stride)], tmp); 47 | tmp = fma(vec4(a.z), B[b_offset + (b_step + (2u * b_stride))], tmp); 48 | tmp = fma(vec4(a.w), B[b_offset + (b_step + (3u * b_stride))], tmp); 49 | } 50 | C[c_offset + (cRow * (N / 4u) + cCol)] = tmp; 51 | } 52 | } 53 | 54 | -------------------------------------------------------------------------------- /src/data.rs: -------------------------------------------------------------------------------- 1 | use num_traits::Float; 2 | use rand::distributions::{uniform::SampleUniform, Distribution, Standard, Uniform}; 3 | use wgpu::util::DeviceExt; 4 | 5 | use crate::GPUHandle; 6 | 7 | pub fn generate_weight_data(elements: usize) -> Vec 8 | where 9 | Standard: Distribution, 10 | F: SampleUniform, 11 | { 12 | let mut rng = rand::thread_rng(); 13 | let dist = Uniform::from(F::from(-10.0).unwrap()..F::from(10.0).unwrap()); 14 | let x: Vec = (0..elements).map(|_| dist.sample(&mut rng)).collect(); 15 | x 16 | } 17 | 18 | pub fn empty_buffer( 19 | device: 
&wgpu::Device, 20 | elements: usize, 21 | ) -> wgpu::Buffer 22 | where 23 | Standard: Distribution, 24 | F: SampleUniform, 25 | { 26 | let data = vec![F::zero(); elements]; 27 | device.create_buffer_init(&wgpu::util::BufferInitDescriptor { 28 | label: None, 29 | contents: bytemuck::cast_slice(&data), 30 | usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, 31 | }) 32 | } 33 | 34 | pub fn rand_gpu_buffer( 35 | handle: &GPUHandle, 36 | elements: usize, 37 | ) -> wgpu::Buffer 38 | where 39 | Standard: Distribution, 40 | F: SampleUniform, 41 | { 42 | let data = generate_weight_data::(elements); 43 | let buffer = handle 44 | .device() 45 | .create_buffer_init(&wgpu::util::BufferInitDescriptor { 46 | label: None, 47 | contents: bytemuck::cast_slice(&data), 48 | usage: wgpu::BufferUsages::STORAGE 49 | | wgpu::BufferUsages::COPY_SRC 50 | | wgpu::BufferUsages::COPY_DST, 51 | }); 52 | handle.queue().submit(None); 53 | handle.device().poll(wgpu::Maintain::Wait); 54 | buffer 55 | } 56 | -------------------------------------------------------------------------------- /kernels/rope/rope_cp.wgsl: -------------------------------------------------------------------------------- 1 | // Kernel by Carson Poole 2 | 3 | @group(0) @binding(0) 4 | var X: array; 5 | 6 | @group(0) @binding(1) 7 | var Y: array; 8 | 9 | struct Meta { 10 | in_strides: vec4, 11 | out_strides: vec4, 12 | offset: u32, 13 | base: f32, 14 | rotary_dim: u32, 15 | } 16 | 17 | @group(1) @binding(0) 18 | var metadata: Meta; 19 | 20 | @compute @workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }}) 21 | fn main( 22 | @builtin(local_invocation_id) local_id: vec3, 23 | @builtin(global_invocation_id) global_id: vec3, 24 | @builtin(subgroup_id) subgroup_id: u32, 25 | @builtin(subgroup_size) subgroup_size: u32, 26 | @builtin(num_workgroups) groups: vec3, 27 | @builtin(workgroup_id) group_id: vec3, 28 | ) { 29 | let tid = local_id.x; 30 | let batch_idx = group_id.x; 31 | let head_idx = group_id.y; 32 | let tok_idx = group_id.z; 33 | let half_rotary_dim = metadata.rotary_dim / 2u; 34 | 35 | let hx = tid / half_rotary_dim; 36 | let hy = tid % half_rotary_dim; 37 | let rot_sign = select(-1.0, 1.0, hx == 0u); 38 | 39 | let global_offset = dot(vec4(batch_idx, head_idx, tok_idx, 1), metadata.in_strides) + tid; 40 | let global_offset_rot = dot(vec4(batch_idx, head_idx, tok_idx, 1), metadata.in_strides) + (1u-hx) * half_rotary_dim + hy; 41 | 42 | let x = X[global_offset]; 43 | 44 | let x_rot = rot_sign * X[global_offset_rot]; 45 | 46 | let ar = f32(hy) * 2.0; 47 | let inv_freq = f32(tok_idx + metadata.offset) * (1.0 / pow(metadata.base, ar / f32(metadata.rotary_dim))); 48 | 49 | let sin = sin(inv_freq); 50 | let cos = cos(inv_freq); 51 | 52 | workgroupBarrier(); 53 | Y[global_offset] = x * cos + x_rot * sin; 54 | } 55 | 56 | 57 | -------------------------------------------------------------------------------- /kernels/sgemv/sgemv_3.wgsl: -------------------------------------------------------------------------------- 1 | //https://www.bealto.com/gpu-gemv_v1.html 2 | var localId: vec3; 3 | var globalId: vec3; 4 | var workgroupId: vec3; 5 | 6 | @group(0) @binding(0) var A: array; 7 | @group(0) @binding(1) var X: array; 8 | @group(0) @binding(2) var result: array; 9 | @group(1) @binding(0) var metadata: Meta; 10 | 11 | 12 | struct Meta { 13 | aShape: vec3, 14 | aStrides: vec3, 15 | bShape: vec3, 16 | bStrides: vec3, 17 | outShape: vec3, 18 | outShapeStrides: vec3, 19 | dimAOuter: i32, 20 | dimBOuter: i32, 21 | 
dimInner: i32, 22 | } 23 | 24 | var work: array; 25 | 26 | @compute @workgroup_size({{workgroup_size_x}},{{workgroup_size_y}},{{workgroup_size_z}}) 27 | fn main(@builtin(local_invocation_id) localId : vec3, 28 | @builtin(global_invocation_id) globalId : vec3, 29 | @builtin(workgroup_id) workgroupId : vec3) { 30 | 31 | var sum = 0.0; 32 | let row = i32(globalId.x); 33 | let aIndex = row * metadata.aStrides.y; 34 | 35 | for (var k = i32(globalId.y); k < metadata.dimInner; k+={{workgroup_size_y}}) { 36 | sum = fma(A[aIndex + k], X[k], sum); 37 | } 38 | 39 | let rows = {{workgroup_size_x}}u; 40 | let cols = {{workgroup_size_y}}u; 41 | let ii = u32(localId.x); 42 | let jj = u32(localId.y); 43 | work[ii + rows * jj] = sum; 44 | workgroupBarrier(); 45 | 46 | // Reduce sums in log2(cols) steps 47 | for (var s = u32(cols) / 2u; s > 0u; s >>= 1u) { 48 | if (jj < s) { 49 | work[ii + rows * jj] += work[ii + rows * (jj + s)]; 50 | } 51 | workgroupBarrier(); 52 | } 53 | 54 | if (jj == 0u) { 55 | result[row] = work[ii]; 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /src/workload.rs: -------------------------------------------------------------------------------- 1 | #[derive(Debug, derive_new::new)] 2 | pub struct WorkgroupCount(pub u32, pub u32, pub u32); //Analogous to gridDim in CUDA 3 | 4 | impl WorkgroupCount { 5 | pub fn as_tuple(&self) -> (u32, u32, u32) { 6 | (self.0, self.1, self.2) 7 | } 8 | } 9 | 10 | #[macro_export] 11 | macro_rules! wgc { 12 | ($x:expr, $y:expr, $z:expr) => { 13 | $crate::WorkgroupCount::new($x, $y, $z) 14 | }; 15 | } 16 | 17 | #[derive(Debug, derive_new::new)] 18 | pub struct WorkgroupSize(pub u32, pub u32, pub u32); //Analogous to blockDim in CUDA 19 | 20 | impl WorkgroupSize { 21 | pub fn total(&self) -> u32 { 22 | self.0 * self.1 * self.2 23 | } 24 | } 25 | 26 | #[macro_export] 27 | macro_rules! wgs { 28 | ($x:expr, $y:expr, $z:expr) => { 29 | $crate::WorkgroupSize::new($x, $y, $z) 30 | }; 31 | } 32 | 33 | ///The Workload represents the entire piece of work.
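///A Workload pairs a WorkgroupSize (threads per workgroup, the blockDim analogue) with a
///WorkgroupCount (number of workgroups dispatched, the gridDim analogue).
///For example, the LayerNorm benches build `Workload::new(wgs![128, 1, 1], wgc![M as _, 1, 1])`
///to launch one 128-thread workgroup per row.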
34 | ///For more read: https://surma.dev/things/webgpu/ 35 | #[derive(Debug)] 36 | pub struct Workload { 37 | size: WorkgroupSize, 38 | count: WorkgroupCount, 39 | } 40 | 41 | impl Workload { 42 | pub fn new(size: WorkgroupSize, count: WorkgroupCount) -> Self { 43 | Self { size, count } 44 | } 45 | 46 | pub fn count(&self) -> &WorkgroupCount { 47 | &self.count 48 | } 49 | 50 | pub fn size(&self) -> &WorkgroupSize { 51 | &self.size 52 | } 53 | } 54 | 55 | ///Used to determine which limit applies 56 | #[derive(Debug, Clone)] 57 | pub enum WorkloadDim { 58 | X, 59 | Y, 60 | Z, 61 | } 62 | 63 | impl Workload { 64 | pub const MAX_WORKGROUP_SIZE_X: usize = 256; 65 | pub const MAX_WORKGROUP_SIZE_Y: usize = 256; 66 | pub const MAX_WORKGROUP_SIZE_Z: usize = 64; 67 | pub const MAX_COMPUTE_WORKGROUPS_PER_DIMENSION: usize = 65535; 68 | 69 | pub fn ceil(num: usize, div: usize) -> usize { 70 | (num + div - 1) / div 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /kernels/sgemv/sgemv_2v.wgsl: -------------------------------------------------------------------------------- 1 | //https://www.bealto.com/gpu-gemv_v1.html 2 | var localId: vec3; 3 | var globalId: vec3; 4 | var workgroupId: vec3; 5 | 6 | @group(0) @binding(0) var A: array>; 7 | @group(0) @binding(1) var X: array>; 8 | @group(0) @binding(2) var result: array; 9 | @group(1) @binding(0) var metadata: Meta; 10 | 11 | 12 | struct Meta { 13 | aShape: vec3, 14 | aStrides: vec3, 15 | bShape: vec3, 16 | bStrides: vec3, 17 | outShape: vec3, 18 | outShapeStrides: vec3, 19 | dimAOuter: i32, 20 | dimBOuter: i32, 21 | dimInner: i32, 22 | } 23 | 24 | var work: array, {{workgroup_size_x * workgroup_size_y / 4}}>; 25 | 26 | @compute @workgroup_size({{workgroup_size_x}},{{workgroup_size_y}},{{workgroup_size_z}}) 27 | fn main(@builtin(local_invocation_id) localId : vec3, 28 | @builtin(global_invocation_id) globalId : vec3, 29 | @builtin(workgroup_id) workgroupId : vec3) { 30 | 31 | var sum = vec4(0.0); 32 | let row = i32(globalId.x); 33 | let aIndex = row * metadata.aStrides.y / 4; 34 | 35 | for (var k = i32(globalId.y); k < metadata.dimInner / 4; k+={{workgroup_size_y / 4}}) { 36 | sum = fma(A[aIndex + k], X[k], sum); 37 | } 38 | 39 | let rows = {{workgroup_size_x}}u; 40 | let cols = {{workgroup_size_y / 4}}u; 41 | let ii = u32(localId.x); 42 | let jj = u32(localId.y); 43 | work[ii + rows * jj] = sum; 44 | workgroupBarrier(); 45 | 46 | // Reduce sums in log2(cols) steps 47 | for (var s = u32(cols) / 2u; s > 0u; s >>= 1u) { 48 | if (jj < s) { 49 | work[ii + rows * jj] += work[ii + rows * (jj + s)]; 50 | } 51 | workgroupBarrier(); 52 | } 53 | 54 | if (jj == 0u) { 55 | result[row] = dot(work[ii], vec4(1.0)); 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /kernels/layernorm/onepass_scalar.wgsl: -------------------------------------------------------------------------------- 1 | @group(0) @binding(0) 2 | var X: array; 3 | 4 | @group(0) @binding(1) 5 | var S: array; 6 | 7 | @group(0) @binding(2) 8 | var B: array; 9 | 10 | @group(0) @binding(3) 11 | var Y: array; 12 | 13 | struct Meta { 14 | M: u32, 15 | N: u32, 16 | ND4: u32, 17 | eps: f32, 18 | } 19 | 20 | @group(1) @binding(0) 21 | var metadata: Meta; 22 | 23 | const BLOCK_SIZE: u32 = 128u; 24 | 25 | var sum: array; 26 | var sq_sum: array; 27 | var mean: f32; 28 | var sigma: f32; 29 | 30 | fn block_reduce(index: u32, stride: u32) { 31 | if index < stride { 32 | sum[index] += sum[index + stride]; 33 | 
sq_sum[index] += sq_sum[index + stride]; 34 | } 35 | workgroupBarrier(); 36 | } 37 | 38 | @compute @workgroup_size(128, 1, 1) 39 | fn main( 40 | @builtin(global_invocation_id) global_id: vec3, 41 | @builtin(local_invocation_id) local_id: vec3, 42 | @builtin(workgroup_id) group_id: vec3 43 | ) { 44 | let anchor = (group_id.y * metadata.M * metadata.N) + group_id.x * metadata.N; 45 | 46 | for (var i = local_id.x; i < metadata.N; i += BLOCK_SIZE) { 47 | let val = X[anchor + i]; 48 | sum[local_id.x] += val; 49 | sq_sum[local_id.x] += val * val; 50 | } 51 | workgroupBarrier(); 52 | 53 | block_reduce(local_id.x, 64u); 54 | block_reduce(local_id.x, 32u); 55 | block_reduce(local_id.x, 16u); 56 | block_reduce(local_id.x, 8u); 57 | block_reduce(local_id.x, 4u); 58 | block_reduce(local_id.x, 2u); 59 | block_reduce(local_id.x, 1u); 60 | 61 | if local_id.x == 0u { 62 | mean = sum[0] / f32(metadata.N); 63 | sigma = inverseSqrt(sq_sum[0] / f32(metadata.N) - mean * mean + metadata.eps); 64 | } 65 | workgroupBarrier(); 66 | 67 | for (var i = local_id.x; i < metadata.N; i += BLOCK_SIZE) { 68 | let val = (X[anchor + i] - mean) * sigma; 69 | Y[anchor + i] = fma(val, S[i], B[i]); 70 | } 71 | } 72 | 73 | -------------------------------------------------------------------------------- /kernels/qgemv/sgemv_2v.wgsl: -------------------------------------------------------------------------------- 1 | //https://www.bealto.com/gpu-gemv_v1.html 2 | var localId: vec3; 3 | var globalId: vec3; 4 | var workgroupId: vec3; 5 | 6 | @group(0) @binding(0) var A: array; 7 | @group(0) @binding(1) var scale: array; 8 | @group(0) @binding(2) var X: array>; 9 | @group(0) @binding(3) var result: array; 10 | @group(1) @binding(0) var metadata: Meta; 11 | 12 | 13 | struct Meta { 14 | aShape: vec3, 15 | aStrides: vec3, 16 | bShape: vec3, 17 | bStrides: vec3, 18 | outShape: vec3, 19 | outShapeStrides: vec3, 20 | dimAOuter: i32, 21 | dimBOuter: i32, 22 | dimInner: i32, 23 | } 24 | 25 | var work: array, {{workgroup_size_x * workgroup_size_y / 4}}>; 26 | 27 | @compute @workgroup_size({{workgroup_size_x}},{{workgroup_size_y}},{{workgroup_size_z}}) 28 | fn main(@builtin(local_invocation_id) localId : vec3, 29 | @builtin(global_invocation_id) globalId : vec3, 30 | @builtin(workgroup_id) workgroupId : vec3) { 31 | 32 | var sum = vec4(0.0); 33 | let row = i32(globalId.x); 34 | let aIndex = row * metadata.aStrides.y / 4; 35 | let sIndex = row * metadata.aStrides.y / 16; 36 | 37 | for (var k = i32(globalId.y); k < metadata.dimInner / 4; k+={{workgroup_size_y / 4}}) { 38 | sum = fma(unpack4x8snorm(A[aIndex + k]) * scale[sIndex + (k/4)], X[k], sum); 39 | } 40 | 41 | let rows = {{workgroup_size_x}}u; 42 | let cols = {{workgroup_size_y / 4}}u; 43 | let ii = u32(localId.x); 44 | let jj = u32(localId.y); 45 | work[ii + rows * jj] = sum; 46 | workgroupBarrier(); 47 | 48 | // Reduce sums in log2(cols) steps 49 | for (var s = u32(cols) / 2u; s > 0u; s >>= 1u) { 50 | if (jj < s) { 51 | work[ii + rows * jj] += work[ii + rows * (jj + s)]; 52 | } 53 | workgroupBarrier(); 54 | } 55 | 56 | if (jj == 0u) { 57 | result[row] = dot(work[ii], vec4(1.0)); 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /kernels/qgemm/slow.wgsl: -------------------------------------------------------------------------------- 1 | @group(0) @binding(0) 2 | var A: array>; 3 | 4 | @group(0) @binding(1) 5 | var B: array; 6 | 7 | @group(0) @binding(2) 8 | var absmax: array; 9 | 10 | @group(0) @binding(3) 11 | var C: array>; 12 | 13 | 
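//Assumed data layout, inferred from the indexing in main() below: A holds vec4<f32>
//activations, B holds signed 8-bit weights packed four per u32 and decoded with
//unpack4x8snorm, and absmax holds the per-block scales used to dequantize them.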
struct Meta { 14 | M: u32, 15 | N: u32, 16 | K: u32, 17 | MD2: u32, 18 | ND2: u32, 19 | KD2: u32, 20 | MD4: u32, 21 | ND4: u32, 22 | KD4: u32, 23 | A_OFFSET: u32, 24 | B_OFFSET: u32, 25 | C_OFFSET: u32, 26 | } 27 | 28 | @group(1) @binding(0) 29 | var metadata: Meta; 30 | 31 | @compute @workgroup_size(8,8,1) 32 | fn main( 33 | @builtin(global_invocation_id) global_id: vec3 34 | ) { 35 | let a_offset = global_id.z * metadata.A_OFFSET; 36 | let b_offset = global_id.z * metadata.B_OFFSET; 37 | let c_offset = global_id.z * metadata.C_OFFSET; 38 | 39 | let cRow = global_id.x; 40 | let cCol = global_id.y; 41 | 42 | let absmax_stride = metadata.N / 16u; 43 | let b_stride = metadata.N; //Solve 4 per iter a.k.a metadata.ND4 * 4u 44 | 45 | if (cRow < metadata.M && cCol < metadata.ND4) { 46 | var tmp = vec4(0.0); 47 | for (var k = 0u; k < metadata.KD4; k++) { 48 | let a = A[a_offset + cRow * metadata.KD4 + k]; 49 | 50 | let bidx = b_offset + (k * b_stride) + cCol; 51 | let absidx = (k * 4u) * absmax_stride + (cCol / 4u); 52 | 53 | let b0 = unpack4x8snorm(B[bidx]) * absmax[absidx]; 54 | let b1 = unpack4x8snorm(B[bidx + (1u * metadata.ND4)]) * absmax[absidx + absmax_stride]; 55 | let b2 = unpack4x8snorm(B[bidx + (2u * metadata.ND4)]) * absmax[absidx + (2u * absmax_stride)]; 56 | let b3 = unpack4x8snorm(B[bidx + (3u * metadata.ND4)]) * absmax[absidx + (3u * absmax_stride)]; 57 | 58 | tmp = fma(vec4(a.x), b0, tmp); 59 | tmp = fma(vec4(a.y), b1, tmp); 60 | tmp = fma(vec4(a.z), b2, tmp); 61 | tmp = fma(vec4(a.w), b3, tmp); 62 | } 63 | C[c_offset + (cRow * metadata.ND4 + cCol)] = tmp; 64 | } 65 | } 66 | 67 | -------------------------------------------------------------------------------- /kernels/layernorm/onepass_vec4.wgsl: -------------------------------------------------------------------------------- 1 | @group(0) @binding(0) 2 | var X: array>; 3 | 4 | @group(0) @binding(1) 5 | var S: array>; 6 | 7 | @group(0) @binding(2) 8 | var B: array>; 9 | 10 | @group(0) @binding(3) 11 | var Y: array>; 12 | 13 | struct Meta { 14 | M: u32, 15 | N: u32, 16 | ND4: u32, 17 | eps: f32, 18 | } 19 | 20 | @group(1) @binding(0) 21 | var metadata: Meta; 22 | 23 | const BLOCK_SIZE: u32 = 128u; 24 | 25 | var sum: array, BLOCK_SIZE>; 26 | var sq_sum: array, BLOCK_SIZE>; 27 | var mean: f32; 28 | var sigma: f32; 29 | 30 | fn block_reduce(index: u32, stride: u32) { 31 | if index < stride { 32 | sum[index] += sum[index + stride]; 33 | sq_sum[index] += sq_sum[index + stride]; 34 | } 35 | workgroupBarrier(); 36 | } 37 | 38 | @compute @workgroup_size(128, 1, 1) 39 | fn main( 40 | @builtin(global_invocation_id) global_id: vec3, 41 | @builtin(local_invocation_id) local_id: vec3, 42 | @builtin(workgroup_id) group_id: vec3 43 | ) { 44 | let anchor = (group_id.y * metadata.M * metadata.ND4) + group_id.x * metadata.ND4; 45 | 46 | for (var i = local_id.x; i < metadata.ND4; i += BLOCK_SIZE) { 47 | let val = X[anchor + i]; 48 | sum[local_id.x] += val; 49 | sq_sum[local_id.x] += val * val; 50 | } 51 | workgroupBarrier(); 52 | 53 | block_reduce(local_id.x, 64u); 54 | block_reduce(local_id.x, 32u); 55 | block_reduce(local_id.x, 16u); 56 | block_reduce(local_id.x, 8u); 57 | block_reduce(local_id.x, 4u); 58 | block_reduce(local_id.x, 2u); 59 | block_reduce(local_id.x, 1u); 60 | 61 | if local_id.x == 0u { 62 | mean = dot(sum[0], vec4(1.0)) / f32(metadata.N); 63 | sigma = inverseSqrt(dot(sq_sum[0], vec4(1.0)) / f32(metadata.N) - mean * mean + metadata.eps); 64 | } 65 | workgroupBarrier(); 66 | 67 | for (var i = local_id.x; i < metadata.ND4; i += 
BLOCK_SIZE) { 68 | let val = (X[anchor + i] - mean) * sigma; 69 | Y[anchor + i] = fma(val, S[i], B[i]); 70 | } 71 | } 72 | 73 | -------------------------------------------------------------------------------- /kernels/sgemv/sgemv_2.wgsl: -------------------------------------------------------------------------------- 1 | //https://www.bealto.com/gpu-gemv_v1.html 2 | var localId: vec3; 3 | var globalId: vec3; 4 | var workgroupId: vec3; 5 | 6 | @group(0) @binding(0) var A: array; 7 | @group(0) @binding(1) var X: array; 8 | @group(0) @binding(2) var result: array; 9 | @group(1) @binding(0) var metadata: Meta; 10 | 11 | 12 | struct Meta { 13 | aShape: vec3, 14 | aStrides: vec3, 15 | bShape: vec3, 16 | bStrides: vec3, 17 | outShape: vec3, 18 | outShapeStrides: vec3, 19 | dimAOuter: i32, 20 | dimBOuter: i32, 21 | dimInner: i32, 22 | } 23 | 24 | var work: array; 25 | 26 | @compute @workgroup_size({{workgroup_size_x}},{{workgroup_size_y}},{{workgroup_size_z}}) 27 | fn main(@builtin(local_invocation_id) localId : vec3, 28 | @builtin(global_invocation_id) globalId : vec3, 29 | @builtin(workgroup_id) workgroupId : vec3) { 30 | let batch = i32(globalId.z); 31 | let batchA = batch % metadata.aShape[0]; 32 | let batchB = batch % metadata.bShape[0]; 33 | 34 | let aOffset = metadata.aStrides.x * batchA; 35 | let bOffset = metadata.bStrides.x * batchB; 36 | let outOffset = metadata.outShapeStrides.x * batch; 37 | 38 | var sum = 0.0; 39 | let row = i32(globalId.x); 40 | let aIndex = aOffset + row * metadata.aStrides.y; 41 | 42 | for (var k = i32(globalId.y); k < metadata.dimInner; k+={{workgroup_size_y}}) { 43 | sum = fma(A[aIndex + k], X[bOffset + k], sum); 44 | } 45 | 46 | let rows = {{workgroup_size_x}}u; 47 | let cols = {{workgroup_size_y}}u; 48 | let ii = u32(localId.x); 49 | let jj = u32(localId.y); 50 | work[ii + rows * jj] = sum; 51 | workgroupBarrier(); 52 | 53 | // Reduce sums in log2(cols) steps 54 | for (var s = u32(cols) / 2u; s > 0u; s >>= 1u) { 55 | if (jj < s) { 56 | work[ii + rows * jj] += work[ii + rows * (jj + s)]; 57 | } 58 | workgroupBarrier(); 59 | } 60 | 61 | if (jj == 0u) { 62 | result[outOffset + row] = work[ii]; 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "wgpu-bencher" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | [[bench]] 8 | name = "naive" 9 | path = "benches/layernorm/naive.rs" 10 | harness = false 11 | 12 | [[bench]] 13 | name = "naive_vectorized" 14 | path = "benches/layernorm/naive_vectorized.rs" 15 | harness = false 16 | 17 | [[bench]] 18 | name = "naive_onepass" 19 | path = "benches/layernorm/naive_onepass.rs" 20 | harness = false 21 | 22 | [[bench]] 23 | name = "naive_vectorized_onepass" 24 | path = "benches/layernorm/naive_vectorized_onepass.rs" 25 | harness = false 26 | 27 | [[bench]] 28 | name = "welford_scalar" 29 | path = "benches/layernorm/welford_scalar.rs" 30 | harness = false 31 | 32 | [[bench]] 33 | name = "welford_vectorized" 34 | path = "benches/layernorm/welford_vectorized.rs" 35 | harness = false 36 | 37 | [[bench]] 38 | name = "sgemm" 39 | path = "benches/sgemm/tfjs.rs" 40 | harness = false 41 | 42 | [[bench]] 43 | name = "sgemv" 44 | path = "benches/sgemv/gemv.rs" 45 | harness = false 46 | 47 | [[bench]] 48 | name = "qgemm" 49 | path = "benches/qgemm/tfjs.rs" 50 | harness = false 
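# Each [[bench]] target here is standalone; run one with e.g. `cargo bench --bench qgemv`.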
51 | 52 | [[bench]] 53 | name = "qgemv" 54 | path = "benches/qgemv/gemv.rs" 55 | harness = false 56 | 57 | [[bench]] 58 | name = "rope" 59 | path = "benches/rope/rope.rs" 60 | harness = false 61 | 62 | [[bench]] 63 | name = "mlx-gemv" 64 | path = "benches/mlx-gemv/gemv.rs" 65 | harness = false 66 | 67 | [[bench]] 68 | name = "mlx-qgemv" 69 | path = "benches/mlx-qgemv/gemv.rs" 70 | harness = false 71 | 72 | 73 | [dependencies] 74 | anyhow = "1.0.75" 75 | bytemuck = "1.14.0" 76 | log = "0.4.20" 77 | num-traits = "0.2.17" 78 | rand = {version="0.8.5", features=["small_rng"]} 79 | smallvec = "1.11.2" 80 | tabled = "0.14.0" 81 | criterion = "0.5.1" 82 | wgpu = { git = "https://github.com/FL33TW00D/wgpu", branch = "feature/multi-dim-compute-subgroups" } 83 | pollster = "0.3.0" 84 | lazy_static = "1.4.0" 85 | glam = "0.25.0" 86 | encase = { version = "0.7", features=["glam"] } 87 | derive-new = "0.6.0" 88 | tera = "1.19.1" 89 | inline-python = { version = "0.12.0"} 90 | numpy = { version = "0.19.0"} 91 | pyo3 = { version = "0.19.1"} 92 | npyz = "0.8.1" 93 | ndarray = "0.15.6" 94 | rand_distr = "0.4.3" 95 | env_logger = "0.11.3" 96 | half = { version = "2.4.0", features=["num-traits", "bytemuck"]} 97 | num = "0.4.1" 98 | -------------------------------------------------------------------------------- /scratch: -------------------------------------------------------------------------------- 1 | // language: metal2.4 2 | #include 3 | #include 4 | 5 | using metal::uint; 6 | 7 | struct _mslBufferSizes { 8 | uint size0; 9 | uint size1; 10 | }; 11 | 12 | typedef float type_1[1]; 13 | struct Meta { 14 | metal::uint3 in_strides; 15 | metal::packed_uint3 out_strides; 16 | uint offset; 17 | float base; 18 | float scale; 19 | }; 20 | 21 | struct main_Input { 22 | }; 23 | kernel void main_( 24 | metal::uint3 local_id [[thread_position_in_threadgroup]] 25 | , metal::uint3 pos [[thread_position_in_grid]] 26 | , uint subgroup_id [[simdgroup_index_in_threadgroup]] 27 | , uint subgroup_size [[threads_per_simdgroup]] 28 | , metal::uint3 groups [[threadgroups_per_grid]] 29 | , device type_1 const& in [[buffer(0)]] 30 | , device type_1& out [[buffer(1)]] 31 | , constant Meta& metadata [[buffer(2)]] 32 | , constant _mslBufferSizes& _buffer_sizes [[buffer(3)]] 33 | ) { 34 | uint in_index_1_ = 0u; 35 | uint in_index_2_ = 0u; 36 | uint out_index_1_ = 0u; 37 | uint out_index_2_ = 0u; 38 | metal::uint3 grid = metal::uint3(groups.x * 16u, groups.y * 8u, groups.z * 8u); 39 | uint _e27 = metadata.out_strides[2]; 40 | uint _e33 = metadata.out_strides[1]; 41 | uint _e40 = metadata.out_strides[0]; 42 | out_index_1_ = ((pos.x * _e27) + (pos.y * _e33)) + (pos.z * _e40); 43 | uint _e43 = out_index_1_; 44 | uint _e48 = metadata.out_strides[2]; 45 | out_index_2_ = _e43 + (grid.x * _e48); 46 | uint _e55 = metadata.in_strides.z; 47 | uint _e61 = metadata.in_strides.y; 48 | uint _e68 = metadata.in_strides.x; 49 | in_index_1_ = ((pos.x * _e55) + (pos.y * _e61)) + (pos.z * _e68); 50 | uint _e71 = in_index_1_; 51 | uint _e76 = metadata.in_strides.z; 52 | in_index_2_ = _e71 + (grid.x * _e76); 53 | float _e81 = metadata.scale; 54 | uint _e85 = metadata.offset; 55 | float L = _e81 * static_cast(pos.y + _e85); 56 | float d = static_cast(pos.x) / static_cast(grid.x); 57 | float _e97 = metadata.base; 58 | float theta = L * metal::exp2(-(d) * _e97); 59 | float costheta = metal::cos(theta); 60 | float sintheta = metal::sin(theta); 61 | uint _e104 = in_index_1_; 62 | float _e106 = in[_e104]; 63 | float x1_ = static_cast(_e106); 64 | uint _e109 = 
in_index_2_; 65 | float _e111 = in[_e109]; 66 | float x2_ = static_cast(_e111); 67 | float rx1_ = (x1_ * costheta) - (x2_ * sintheta); 68 | float rx2_ = (x1_ * sintheta) + (x2_ * costheta); 69 | uint _e120 = out_index_1_; 70 | out[_e120] = static_cast(rx1_); 71 | uint _e124 = out_index_2_; 72 | out[_e124] = static_cast(rx2_); 73 | return; 74 | } 75 | -------------------------------------------------------------------------------- /src/strides.rs: -------------------------------------------------------------------------------- 1 | use crate::Shape; 2 | use encase::impl_wrapper; 3 | 4 | #[derive(Clone, PartialEq, Eq, Default, Hash)] 5 | pub struct Strides(Vec); 6 | 7 | impl_wrapper!(Strides; using); 8 | 9 | impl Strides { 10 | pub fn inner(self) -> Vec { 11 | self.0 12 | } 13 | } 14 | 15 | impl std::fmt::Debug for Strides { 16 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 17 | let mut shape = format!("[{}", self.0.first().unwrap_or(&0)); 18 | for dim in self.0.iter().skip(1) { 19 | shape.push_str(&format!("x{}", dim)); 20 | } 21 | write!(f, "{}]", shape) 22 | } 23 | } 24 | 25 | impl From<&Shape> for Strides { 26 | fn from(shape: &Shape) -> Self { 27 | let mut strides = vec![]; 28 | let mut stride = 1; 29 | for size in shape.to_vec().iter().rev() { 30 | strides.push(stride); 31 | stride *= *size as isize; 32 | } 33 | strides.reverse(); 34 | Self(strides) 35 | } 36 | } 37 | 38 | impl From<&Strides> for [u32; 4] { 39 | fn from(strides: &Strides) -> Self { 40 | assert!(strides.0.len() <= 4); 41 | let mut array = [0; 4]; 42 | for (i, &stride) in strides.0.iter().enumerate() { 43 | array[i] = stride as u32; 44 | } 45 | array 46 | } 47 | } 48 | 49 | impl From<&Strides> for [u32; 3] { 50 | fn from(strides: &Strides) -> Self { 51 | assert!(strides.0.len() <= 3); 52 | let mut array = [0; 3]; 53 | for (i, &stride) in strides.0.iter().enumerate() { 54 | array[i] = stride as u32; 55 | } 56 | array 57 | } 58 | } 59 | 60 | impl From<&Strides> for glam::UVec4 { 61 | fn from(strides: &Strides) -> Self { 62 | let array: [u32; 4] = strides.into(); 63 | glam::UVec4::from(array) 64 | } 65 | } 66 | 67 | impl From<&Strides> for glam::UVec3 { 68 | fn from(strides: &Strides) -> Self { 69 | let array: [u32; 3] = strides.into(); 70 | glam::UVec3::from(array) 71 | } 72 | } 73 | 74 | impl From for glam::IVec3 { 75 | fn from(strides: Strides) -> Self { 76 | (&strides).into() 77 | } 78 | } 79 | 80 | impl From<&Strides> for glam::IVec3 { 81 | fn from(strides: &Strides) -> Self { 82 | glam::IVec3::new(strides.0[0] as _, strides.0[1] as _, strides.0[2] as _) 83 | } 84 | } 85 | 86 | #[cfg(test)] 87 | mod tests { 88 | use crate::shape; 89 | 90 | #[test] 91 | fn test_strides() { 92 | use super::*; 93 | let shape = shape![2, 3, 4]; 94 | let strides = Strides::from(&shape); 95 | assert_eq!(strides.inner(), vec![12, 4, 1]); 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /src/shape.rs: -------------------------------------------------------------------------------- 1 | use std::ops::RangeTo; 2 | 3 | use smallvec::SmallVec; 4 | 5 | #[derive(Clone, PartialEq, Eq)] 6 | pub struct Shape(SmallVec<[usize; 4]>); 7 | 8 | impl Shape { 9 | pub fn new(shape: SmallVec<[usize; 4]>) -> Self { 10 | Shape(shape) 11 | } 12 | 13 | pub fn rank(&self) -> usize { 14 | self.0.len() 15 | } 16 | 17 | pub fn numel(&self) -> usize { 18 | self.0.iter().product() 19 | } 20 | 21 | pub fn to_vec(&self) -> Vec { 22 | self.0.clone().into_vec() 23 | } 24 | 25 | pub fn remove(&mut self, 
index: usize) -> usize { 26 | self.0.remove(index) 27 | } 28 | } 29 | 30 | impl std::fmt::Debug for Shape { 31 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 32 | let mut shape = String::from("["); 33 | for (i, dim) in self.0.iter().enumerate() { 34 | if i == 0 { 35 | shape.push_str(&format!("{}", dim)); 36 | } else { 37 | shape.push_str(&format!("x{}", dim)); 38 | } 39 | } 40 | write!(f, "{}]", shape) 41 | } 42 | } 43 | 44 | impl std::fmt::Display for Shape { 45 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 46 | write!(f, "{:?}", self.0) 47 | } 48 | } 49 | 50 | impl std::ops::Index for Shape { 51 | type Output = usize; 52 | 53 | fn index(&self, index: usize) -> &Self::Output { 54 | &self.0[index] 55 | } 56 | } 57 | 58 | impl std::ops::IndexMut for Shape { 59 | fn index_mut(&mut self, index: usize) -> &mut Self::Output { 60 | &mut self.0[index] 61 | } 62 | } 63 | 64 | impl std::ops::Index> for Shape { 65 | type Output = [usize]; 66 | 67 | fn index(&self, index: RangeTo) -> &Self::Output { 68 | &self.0[index] 69 | } 70 | } 71 | 72 | impl From<&[usize]> for Shape { 73 | fn from(slice: &[usize]) -> Self { 74 | Shape(slice.into()) 75 | } 76 | } 77 | 78 | macro_rules! impl_try_into { 79 | ($($n:literal),*) => { 80 | $( 81 | impl TryInto<[usize; $n]> for &Shape { 82 | type Error = &'static str; 83 | 84 | fn try_into(self) -> Result<[usize; $n], Self::Error> { 85 | if self.0.len() != $n { 86 | Err(concat!("Shape must have rank ", stringify!($n))) 87 | } else { 88 | let mut shape = [1; $n]; 89 | shape[..self.0.len()].copy_from_slice(&self.0); 90 | Ok(shape) 91 | } 92 | } 93 | } 94 | )* 95 | }; 96 | } 97 | 98 | impl_try_into!(1, 2, 3, 4); 99 | -------------------------------------------------------------------------------- /kernels/layernorm/naive_scalar.wgsl: -------------------------------------------------------------------------------- 1 | @group(0) @binding(0) 2 | var X: array; 3 | 4 | @group(0) @binding(1) 5 | var S: array; 6 | 7 | @group(0) @binding(2) 8 | var B: array; 9 | 10 | @group(0) @binding(3) 11 | var Y: array; 12 | 13 | struct Meta { 14 | M: u32, 15 | N: u32, 16 | ND4: u32, 17 | eps: f32, 18 | } 19 | 20 | @group(1) @binding(0) 21 | var metadata: Meta; 22 | 23 | const BLOCK_SIZE: u32 = 128u; 24 | 25 | var smem: array; //max 16kb 26 | 27 | fn block_sum(index: u32, stride: u32) { 28 | if index < stride { 29 | smem[index] += smem[index + stride]; 30 | } 31 | workgroupBarrier(); 32 | } 33 | 34 | fn mu(local_id: vec3, anchor: u32) -> f32 { 35 | var threadSum = 0f; 36 | for (var i: u32 = local_id.x; i < metadata.N; i += BLOCK_SIZE) { 37 | threadSum += X[anchor + i]; 38 | } 39 | smem[local_id.x] = threadSum; 40 | workgroupBarrier(); 41 | 42 | block_sum(local_id.x, 64u); 43 | block_sum(local_id.x, 32u); 44 | block_sum(local_id.x, 16u); 45 | block_sum(local_id.x, 8u); 46 | block_sum(local_id.x, 4u); 47 | block_sum(local_id.x, 2u); 48 | block_sum(local_id.x, 1u); 49 | 50 | return smem[0] / f32(metadata.N); 51 | } 52 | 53 | fn sigma(local_id: vec3, anchor: u32, mu: f32) -> f32 { 54 | var threadSum = 0f; 55 | //Compute σ 56 | for (var i: u32 = local_id.x; i < metadata.N; i += BLOCK_SIZE) { 57 | let val = X[anchor + i] - mu; 58 | threadSum += (val * val); 59 | } 60 | smem[local_id.x] = threadSum; 61 | workgroupBarrier(); 62 | 63 | block_sum(local_id.x, 64u); 64 | block_sum(local_id.x, 32u); 65 | block_sum(local_id.x, 16u); 66 | block_sum(local_id.x, 8u); 67 | block_sum(local_id.x, 4u); 68 | block_sum(local_id.x, 2u); 69 | block_sum(local_id.x, 1u); 
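    //smem[0] now holds the block-wide sum of squared deviations; dividing by N below
    //yields the biased variance that LayerNorm normalizes with.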
70 | 71 | return smem[0] / (f32(metadata.N)); 72 | } 73 | 74 | @compute @workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }}) 75 | fn main( 76 | @builtin(local_invocation_id) local_id: vec3, 77 | @builtin(workgroup_id) group_id: vec3, 78 | @builtin(global_invocation_id) global_id: vec3 79 | ) { 80 | let anchor = (group_id.y * metadata.M * metadata.N) + group_id.x * metadata.N; 81 | let mu = mu(local_id, anchor); 82 | let sigma = sigma(local_id, anchor, mu); 83 | 84 | let denom = inverseSqrt(sigma + metadata.eps); 85 | 86 | for(var i: u32 = local_id.x; i < metadata.N; i += BLOCK_SIZE) { 87 | let core = (X[anchor + i] - mu) * denom; 88 | Y[anchor + i] = fma(core, S[i], B[i]); 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /kernels/layernorm/naive_vec4.wgsl: -------------------------------------------------------------------------------- 1 | @group(0) @binding(0) 2 | var X: array>; 3 | 4 | @group(0) @binding(1) 5 | var S: array>; 6 | 7 | @group(0) @binding(2) 8 | var B: array>; 9 | 10 | @group(0) @binding(3) 11 | var Y: array>; 12 | 13 | struct Meta { 14 | M: u32, 15 | N: u32, 16 | ND4: u32, 17 | eps: f32, 18 | } 19 | 20 | @group(1) @binding(0) 21 | var metadata: Meta; 22 | 23 | const BLOCK_SIZE: u32 = 128u; 24 | 25 | var smem: array, BLOCK_SIZE>; //max 16kb 26 | 27 | fn block_sum(index: u32, stride: u32) { 28 | if index < stride { 29 | smem[index] += smem[index + stride]; 30 | } 31 | workgroupBarrier(); 32 | } 33 | 34 | fn mu(local_id: vec3, anchor: u32) -> f32 { 35 | var threadSum = vec4(0.0); 36 | for (var i: u32 = local_id.x; i < metadata.ND4; i += BLOCK_SIZE) { 37 | threadSum += X[anchor + i]; 38 | } 39 | smem[local_id.x] = threadSum; 40 | workgroupBarrier(); 41 | 42 | block_sum(local_id.x, 64u); 43 | block_sum(local_id.x, 32u); 44 | block_sum(local_id.x, 16u); 45 | block_sum(local_id.x, 8u); 46 | block_sum(local_id.x, 4u); 47 | block_sum(local_id.x, 2u); 48 | block_sum(local_id.x, 1u); 49 | 50 | return dot(smem[0], vec4(1.0)) / f32(metadata.N); 51 | } 52 | 53 | fn sigma(local_id: vec3, anchor: u32, mu: f32) -> f32 { 54 | var threadSum = vec4(0.0); 55 | //Compute σ 56 | for (var i: u32 = local_id.x; i < metadata.ND4; i += BLOCK_SIZE) { 57 | let val = X[anchor + i] - mu; 58 | threadSum = fma(val, val, threadSum); 59 | } 60 | smem[local_id.x] = threadSum; 61 | workgroupBarrier(); 62 | 63 | block_sum(local_id.x, 64u); 64 | block_sum(local_id.x, 32u); 65 | block_sum(local_id.x, 16u); 66 | block_sum(local_id.x, 8u); 67 | block_sum(local_id.x, 4u); 68 | block_sum(local_id.x, 2u); 69 | block_sum(local_id.x, 1u); 70 | 71 | return dot(smem[0], vec4(1.0)) / (f32(metadata.N)); 72 | } 73 | 74 | @compute @workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }}) 75 | fn main( 76 | @builtin(local_invocation_id) local_id: vec3, 77 | @builtin(workgroup_id) group_id: vec3, 78 | @builtin(global_invocation_id) global_id: vec3 79 | ) { 80 | let anchor = (group_id.y * metadata.M * metadata.ND4) + group_id.x * metadata.ND4; 81 | let mu = mu(local_id, anchor); 82 | let sigma = sigma(local_id, anchor, mu); 83 | 84 | let denom = inverseSqrt(sigma + vec4(metadata.eps)); 85 | 86 | for(var i: u32 = local_id.x; i < metadata.ND4; i += BLOCK_SIZE) { 87 | let val = (X[anchor + i] - mu) * denom; 88 | Y[anchor + i] = fma(val, S[i], B[i]); 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /kernels/rope/rope.wgsl: 
-------------------------------------------------------------------------------- 1 | //Translated from: https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/rope.metal 2 | //Reading materials 3 | //https://blog.eleuther.ai/rotary-embeddings/ 4 | // 5 | // RoPE summary: 6 | // 1. In language, we don't care about the absolute position of words(tokens). I don't care that cat is in position 500. 7 | // However, I do care about the relative position of words. "You shall know a word by the company it keeps (Firth, 1957)" 8 | // 2. RoPE gives us a way to encode relative positions into our hidden states for q & k. 9 | // 3. Key insight of rope: Use rotations to encode relative positions, unaffected by translation (absolute position). 10 | // 4. Rotations are easiest to work with in complex space. 11 | // 5. Complex numbers suck for computers, so we do it in the reals. 12 | // 6. We pair up components of our q & k vectors, to form 2D coords in the complex plain. 13 | // This can be done in 2 ways: 1. q = (q1, q2, q3, q4) -> q = (q1 + iq2, q3 + iq4) 14 | // 2. q = (q1, q2 ... qd/2, qd/2+1) -> q = (q1 + iqd/2, q2 + iqd/2+1) 15 | // 7. We then rotate these 2D coords by a fixed angle, theta, to encode relative positions. 16 | 17 | @group(0) @binding(0) 18 | var in: array; 19 | 20 | @group(0) @binding(1) 21 | var out: array; 22 | 23 | struct Meta { 24 | in_strides: vec3, 25 | out_strides: vec3, 26 | offset: u32, 27 | base: f32, 28 | scale: f32, 29 | } 30 | 31 | @group(1) @binding(0) 32 | var metadata: Meta; 33 | 34 | @compute @workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }}) 35 | fn main( 36 | @builtin(local_invocation_id) local_id: vec3, 37 | @builtin(global_invocation_id) pos: vec3, 38 | @builtin(subgroup_id) subgroup_id: u32, 39 | @builtin(subgroup_size) subgroup_size: u32, 40 | @builtin(num_workgroups) groups: vec3, 41 | ) { 42 | let grid = vec3(groups.x * {{ workgroup_size_x }}u, groups.y * {{ workgroup_size_y }}u, groups.z * {{ workgroup_size_z }}u); 43 | let out_index_1 = dot(pos, vec3(metadata.out_strides[2], metadata.out_strides[1], metadata.out_strides[0])); 44 | let out_index_2 = out_index_1 + grid.x * metadata.out_strides[2]; 45 | 46 | let in_index_1 = dot(pos, vec3(metadata.in_strides[2], metadata.in_strides[1], metadata.in_strides[0])); 47 | let in_index_2 = in_index_1 + grid.x * metadata.in_strides[2]; 48 | 49 | let L = metadata.scale * f32(pos.y + metadata.offset); 50 | let d = f32(pos.x) / f32(grid.x); 51 | 52 | let theta = L * exp2(-d * metadata.base); 53 | let costheta = cos(theta); 54 | let sintheta = sin(theta); 55 | 56 | let x1 = in[in_index_1]; 57 | let x2 = in[in_index_2]; 58 | 59 | let rx1 = x1 * costheta - x2 * sintheta; 60 | let rx2 = x1 * sintheta + x2 * costheta; 61 | 62 | out[out_index_1] = rx1; 63 | out[out_index_2] = rx2; 64 | } 65 | -------------------------------------------------------------------------------- /scratchy: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 
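// Rotation applied by the forward kernels below: for each pair (x1, x2),
//   rx1 = x1 * cos(theta) - x2 * sin(theta)
//   rx2 = x1 * sin(theta) + x2 * cos(theta)
// with theta = scale * (pos.y + offset) * exp2(-d * base) and d = pos.x / grid.x.
// The `traditional` variants pair adjacent elements; the others pair elements half the
// rotary dimension apart.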
2 | 3 | #include 4 | 5 | #include "mlx/backend/metal/kernels/bf16.h" 6 | #include "mlx/backend/metal/kernels/utils.h" 7 | 8 | template 9 | [[kernel]] void rope( 10 | const device T *in [[buffer(0)]], 11 | device T * out [[buffer(1)]], 12 | constant const size_t strides[3], 13 | constant const size_t out_strides[3], 14 | constant const int& offset, 15 | constant const float& base, 16 | constant const float& scale, 17 | uint3 pos [[thread_position_in_grid]], 18 | uint3 grid [[threads_per_grid]]) { 19 | // Compute the input and output indices 20 | uint in_index_1, in_index_2; 21 | uint out_index_1, out_index_2; 22 | if (traditional) { 23 | out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0]; 24 | out_index_2 = out_index_1 + 1; 25 | in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0]; 26 | in_index_2 = in_index_1 + strides[2]; 27 | } else { 28 | out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0]; 29 | out_index_2 = out_index_1 + grid.x * out_strides[2]; 30 | in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0]; 31 | in_index_2 = in_index_1 + grid.x * strides[2]; 32 | } 33 | 34 | // Figure out L and d. 35 | float L = scale * static_cast(pos.y + offset); 36 | float d = static_cast(pos.x) / static_cast(grid.x); 37 | 38 | // Compute costheta, sintheta 39 | float theta = L * metal::exp2(-d * base); 40 | float costheta = metal::fast::cos(theta); 41 | float sintheta = metal::fast::sin(theta); 42 | 43 | // Read and write the output 44 | float x1 = static_cast(in[in_index_1]); 45 | float x2 = static_cast(in[in_index_2]); 46 | float rx1; 47 | float rx2; 48 | if (forward) { 49 | rx1 = x1 * costheta - x2 * sintheta; 50 | rx2 = x1 * sintheta + x2 * costheta; 51 | } else { 52 | rx1 = x2 * sintheta + x1 * costheta; 53 | rx2 = x2 * costheta - x1 * sintheta; 54 | } 55 | out[out_index_1] = static_cast(rx1); 56 | out[out_index_2] = static_cast(rx2); 57 | } 58 | 59 | #define instantiate_rope(name, type, traditional, forward) \ 60 | template [[host_name("rope_" #name)]] \ 61 | [[kernel]] void rope( \ 62 | const device type* in [[buffer(0)]], \ 63 | device type* out [[buffer(1)]], \ 64 | constant const size_t strides[3], \ 65 | constant const size_t out_strides[3], \ 66 | constant const int& offset, \ 67 | constant const float& base, \ 68 | constant const float& scale, \ 69 | uint3 pos [[thread_position_in_grid]], \ 70 | uint3 grid [[threads_per_grid]]); 71 | 72 | instantiate_rope(traditional_float16, half, true, true) 73 | instantiate_rope(traditional_bfloat16, bfloat16_t, true, true) 74 | instantiate_rope(traditional_float32, float, true, true) 75 | instantiate_rope(float16, half, false, true) 76 | instantiate_rope(bfloat16, bfloat16_t, false, true) 77 | instantiate_rope(float32, float, false, true) 78 | instantiate_rope(vjp_traditional_float16, half, true, false) 79 | instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false) 80 | instantiate_rope(vjp_traditional_float32, float, true, false) 81 | instantiate_rope(vjp_float16, half, false, false) 82 | instantiate_rope(vjp_bfloat16, bfloat16_t, false, false) 83 | instantiate_rope(vjp_float32, float, false, false) 84 | -------------------------------------------------------------------------------- /kernels/layernorm/welford_scalar.wgsl: -------------------------------------------------------------------------------- 1 | @group(0) @binding(0) 2 | var X: array; 3 | 4 | @group(0) @binding(1) 5 | var S: array; 6 | 7 | @group(0) @binding(2) 
8 | var B: array; 9 | 10 | @group(0) @binding(3) 11 | var Y: array; 12 | 13 | struct Meta { 14 | M: u32, 15 | N: u32, 16 | ND4: u32, 17 | eps: f32, 18 | } 19 | 20 | @group(1) @binding(0) 21 | var metadata: Meta; 22 | 23 | var mu: f32; 24 | var sigma: f32; 25 | var subgrp_size: u32; 26 | 27 | fn welford_combine(val: f32, mean: ptr, m2: ptr, count: ptr) { 28 | *count += 1.0; 29 | let delta1 = val - *mean; 30 | *mean += delta1 / *count; 31 | let delta2 = val - *mean; 32 | *m2 += delta1 * delta2; 33 | } 34 | 35 | fn block_welford_combine(b_mean: f32, b_m2: f32, b_count: f32, mean: ptr, m2: ptr, count: ptr) { 36 | if (b_count == 0.0) { 37 | return; 38 | } 39 | let new_count = *count + b_count; 40 | let nb_over_n = b_count / new_count; 41 | let delta = b_mean - *mean; 42 | *mean += delta * nb_over_n; 43 | *m2 += b_m2 + delta * delta * (*count) * nb_over_n; 44 | *count = new_count; 45 | } 46 | 47 | fn welford_warp_reduce(thread_mean: f32, thread_m2: f32, thread_count: f32, mean: ptr, m2: ptr, count: ptr) { 48 | *mean = thread_mean; 49 | *m2 = thread_m2; 50 | *count = thread_count; 51 | for (var offset = subgrp_size >> 1u; offset > 0u; offset >>= 1u) { 52 | let b_mean = subgroupShuffleDown(*mean, offset); 53 | let b_m2 = subgroupShuffleDown(*m2, offset); 54 | let b_count = subgroupShuffleDown(*count, offset); 55 | block_welford_combine(b_mean, b_m2, b_count, mean, m2, count); 56 | } 57 | } 58 | 59 | fn welford_warp_all_reduce(thread_mean: f32, thread_m2: f32, thread_count: f32, mean: ptr, m2: ptr, count: ptr) { 60 | welford_warp_reduce(thread_mean, thread_m2, thread_count, mean, m2, count); 61 | 62 | *mean = subgroupBroadcast(*mean, 0u); 63 | *m2 = subgroupBroadcast(*m2, 0u); 64 | *count = subgroupBroadcast(*count, 0u); 65 | } 66 | 67 | 68 | @compute @workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }}) 69 | fn main( 70 | @builtin(local_invocation_id) local_id: vec3, 71 | @builtin(workgroup_id) group_id: vec3, 72 | @builtin(global_invocation_id) global_id: vec3, 73 | @builtin(subgroup_id) subgroup_id: u32, 74 | @builtin(subgroup_size) subgroup_size: u32, 75 | ) { 76 | subgrp_size = subgroup_size; 77 | let anchor = (group_id.y * metadata.M * metadata.N) + group_id.x * metadata.N; 78 | var threadVar = 0f; 79 | var threadMean = 0f; 80 | var threadCount = 0f; 81 | for (var i = local_id.x; i < metadata.N; i+= {{ workgroup_size_x }}u) { 82 | welford_combine(X[anchor + i], &threadMean, &threadVar, &threadCount); 83 | } 84 | 85 | var mean = 0f; 86 | var m2 = 0f; 87 | var count = 0f; 88 | welford_warp_all_reduce(threadMean, threadVar, threadCount, &mean, &m2, &count); 89 | 90 | if (subgroup_id == 0u) { 91 | mu = mean; 92 | sigma = inverseSqrt(m2 / count + metadata.eps); 93 | } 94 | subgroupBarrier(); 95 | for (var i = local_id.x; i < metadata.N; i+= {{ workgroup_size_x }}u) { 96 | let val = X[anchor + i]; 97 | let normalized = (val - mu) * sigma; 98 | Y[anchor + i] = fma(normalized, S[i], B[i]); 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /src/handle.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use wgpu::Adapter; 4 | use wgpu::DeviceType; 5 | 6 | use wgpu::Limits; 7 | 8 | /// # GPUHandle 9 | /// 10 | /// A reference counted handle to a GPU device and queue. 
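///
/// Typical usage (mirroring the benches): `let handle = pollster::block_on(async { GPUHandle::new().await.unwrap() });`
/// then `handle.device()` and `handle.queue()` give access to the underlying wgpu device and queue.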
11 | #[derive(Debug, Clone)] 12 | pub struct GPUHandle { 13 | inner: Arc, 14 | } 15 | 16 | #[derive(Debug)] 17 | pub struct Inner { 18 | device: wgpu::Device, 19 | queue: wgpu::Queue, 20 | } 21 | 22 | impl std::ops::Deref for GPUHandle { 23 | type Target = Inner; 24 | 25 | fn deref(&self) -> &Self::Target { 26 | &self.inner 27 | } 28 | } 29 | 30 | impl GPUHandle { 31 | fn get_features() -> wgpu::Features { 32 | wgpu::Features::default() | wgpu::Features::TIMESTAMP_QUERY | wgpu::Features::SUBGROUP 33 | } 34 | 35 | pub async fn new() -> Result { 36 | let adapter = Self::select_adapter(); 37 | 38 | let mut device_descriptor = wgpu::DeviceDescriptor { 39 | label: Some("rumble"), 40 | required_features: Self::get_features(), 41 | required_limits: Limits { 42 | max_buffer_size: (2 << 29) - 1, 43 | max_storage_buffer_binding_size: (2 << 29) - 1, 44 | max_compute_invocations_per_workgroup: 1024, 45 | ..Default::default() 46 | }, 47 | }; 48 | let device_request = adapter.request_device(&device_descriptor, None).await; 49 | let (device, queue) = if let Err(e) = device_request { 50 | log::warn!("Failed to create device with error: {:?}", e); 51 | log::warn!("Trying again with reduced limits"); 52 | device_descriptor.required_limits = adapter.limits(); 53 | let device_request = adapter.request_device(&device_descriptor, None).await; 54 | device_request.unwrap() 55 | } else { 56 | device_request.unwrap() 57 | }; 58 | 59 | Ok(Self { 60 | inner: Arc::new(Inner { device, queue }), 61 | }) 62 | } 63 | 64 | pub fn device(&self) -> &wgpu::Device { 65 | &self.device 66 | } 67 | 68 | pub fn queue(&self) -> &wgpu::Queue { 69 | &self.queue 70 | } 71 | 72 | fn select_adapter() -> Adapter { 73 | let instance = wgpu::Instance::new(wgpu::InstanceDescriptor { 74 | dx12_shader_compiler: wgpu::util::dx12_shader_compiler_from_env().unwrap_or_default(), 75 | ..Default::default() 76 | }); 77 | let backends = wgpu::util::backend_bits_from_env().unwrap_or(wgpu::Backends::PRIMARY); 78 | 79 | let adapter = { 80 | let mut most_performant_adapter = None; 81 | let mut current_score = -1; 82 | 83 | instance 84 | .enumerate_adapters(backends) 85 | .into_iter() 86 | .for_each(|adapter| { 87 | let info = adapter.get_info(); 88 | let score = match info.device_type { 89 | DeviceType::DiscreteGpu => 5, 90 | DeviceType::Other => 4, //Other is usually discrete 91 | DeviceType::IntegratedGpu => 3, 92 | DeviceType::VirtualGpu => 2, 93 | DeviceType::Cpu => 1, 94 | }; 95 | 96 | if score > current_score { 97 | most_performant_adapter = Some(adapter); 98 | current_score = score; 99 | } 100 | }); 101 | 102 | if let Some(adapter) = most_performant_adapter { 103 | adapter 104 | } else { 105 | panic!("No adapter found, please check if your GPU is supported"); 106 | } 107 | }; 108 | log::info!("Using adapter {:?}", adapter.get_info()); 109 | adapter 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /benches/layernorm/naive_onepass.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_snake_case)] 2 | use encase::ShaderType; 3 | use inline_python::{python, Context}; 4 | use numpy::PyArrayDyn; 5 | use pyo3::Python; 6 | use smallvec::smallvec; 7 | 8 | use criterion::{criterion_group, criterion_main, Criterion, Throughput}; 9 | use wgpu_bencher::{ 10 | dispatch_validate, shape, wgc, wgs, CPUTensor, GPUHandle, KernelBench, KernelContextExt, 11 | OpMetadata, WgpuTimer, Workload, 12 | }; 13 | 14 | lazy_static::lazy_static! 
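// TIMER is the process-wide measurement backend handed to Criterion through
// `.with_measurement(&*TIMER)` at the bottom of each bench. It owns the
// GPUHandle and, given the TIMESTAMP_QUERY feature requested in
// `GPUHandle::get_features`, presumably times dispatches with GPU timestamp
// queries rather than wall-clock time.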
{ 15 | pub static ref TIMER: WgpuTimer = WgpuTimer::new(pollster::block_on(async { 16 | GPUHandle::new().await.unwrap() 17 | })); 18 | } 19 | 20 | #[derive(ShaderType, derive_new::new, Debug)] 21 | pub struct LayerNormMeta { 22 | M: u32, 23 | N: u32, 24 | ND4: u32, 25 | eps: f32, 26 | } 27 | 28 | impl OpMetadata for LayerNormMeta {} 29 | 30 | #[derive(derive_new::new, Debug)] 31 | pub struct LayerNorm { 32 | eps: f32, 33 | } 34 | 35 | const PROB_M: usize = 2048; 36 | const PROB_N: usize = 512; 37 | 38 | impl KernelBench for LayerNorm { 39 | type Metadata = LayerNormMeta; 40 | 41 | fn name() -> &'static str { 42 | "LayerNormOnePass" 43 | } 44 | 45 | fn source(&self, workload: &Workload) -> String { 46 | let mut tera = tera::Tera::default(); 47 | let mut context = tera::Context::new(); 48 | tera.add_raw_template( 49 | Self::name(), 50 | include_str!("../../kernels/layernorm/onepass_scalar.wgsl"), 51 | ) 52 | .unwrap(); 53 | context.insert_workload(workload); 54 | tera.render(Self::name(), &context).unwrap() 55 | } 56 | 57 | fn tensors(&self) -> Vec { 58 | let input = CPUTensor::randn::(shape![1, PROB_M, PROB_N]); 59 | let scale = CPUTensor::randn::(shape![PROB_N]); 60 | let bias = CPUTensor::randn::(shape![PROB_N]); 61 | let output = CPUTensor::zeros::(shape![1, PROB_M, PROB_N]); 62 | vec![input, scale, bias, output] 63 | } 64 | 65 | fn workload(&self, tensors: &[CPUTensor]) -> Workload { 66 | let input = &tensors[0]; 67 | let [_B, M, _N] = input.shape().try_into().unwrap(); 68 | Workload::new(wgs![128, 1, 1], wgc![M as _, 1, 1]) 69 | } 70 | 71 | fn metadata(&self, tensors: &[CPUTensor]) -> Self::Metadata { 72 | let input = &tensors[0]; 73 | let [_B, M, N] = input.shape().try_into().unwrap(); 74 | LayerNormMeta::new(M as _, N as _, (N / 4) as _, self.eps) 75 | } 76 | 77 | fn validate(&self, tensors: &[CPUTensor]) { 78 | let (input, scale, bias) = (&tensors[0], &tensors[1], &tensors[2]); 79 | let ground = Python::with_gil(|py| { 80 | let (py_input, py_scale, py_bias) = ( 81 | input.to_py::(&py), 82 | scale.to_py::(&py), 83 | bias.to_py::(&py), 84 | ); 85 | let result: Context = python! 
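// `python! { ... }` is the `inline_python` macro: a leading single quote
// (e.g. `'py_input`) splices the surrounding Rust variable into the embedded
// Python scope, and `Context::get_with_gil` reads `result` back out afterwards.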
{ 86 | import torch 87 | import torch.nn.functional as F 88 | 89 | (input, scale, bias) = (torch.from_numpy('py_input), torch.from_numpy('py_scale), torch.from_numpy('py_bias)) 90 | result = F.layer_norm(input, (input.shape[-1],), weight=scale, bias=bias).numpy() 91 | }; 92 | CPUTensor::from(result.get_with_gil::<&PyArrayDyn>(py, "result")) 93 | }); 94 | let mut gpu_tensors = dispatch_validate(TIMER.handle(), self, tensors); 95 | let cpu_result = gpu_tensors.remove(3).into_cpu(TIMER.handle()).unwrap(); 96 | ground.all_close(&cpu_result, 1e-4, 1e-4).unwrap(); 97 | } 98 | } 99 | 100 | pub fn benchmark(c: &mut Criterion<&WgpuTimer>) { 101 | let throughput = Throughput::Elements((PROB_M * PROB_N) as u64); 102 | wgpu_bencher::benchmark(c, &TIMER, LayerNorm::new(1e-5), throughput) 103 | } 104 | 105 | criterion_group!( 106 | name = bench; 107 | config = Criterion::default().with_measurement(&*TIMER); 108 | targets = benchmark 109 | ); 110 | criterion_main!(bench); 111 | -------------------------------------------------------------------------------- /benches/layernorm/naive_vectorized.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_snake_case)] 2 | use encase::ShaderType; 3 | use inline_python::{python, Context}; 4 | use numpy::PyArrayDyn; 5 | use pyo3::Python; 6 | use smallvec::smallvec; 7 | 8 | use criterion::{criterion_group, criterion_main, Criterion, Throughput}; 9 | use wgpu_bencher::{ 10 | dispatch_validate, shape, wgc, wgs, CPUTensor, GPUHandle, KernelBench, KernelContextExt, 11 | OpMetadata, WgpuTimer, Workload, 12 | }; 13 | 14 | lazy_static::lazy_static! { 15 | pub static ref TIMER: WgpuTimer = WgpuTimer::new(pollster::block_on(async { 16 | GPUHandle::new().await.unwrap() 17 | })); 18 | } 19 | 20 | #[derive(ShaderType, derive_new::new, Debug)] 21 | pub struct LayerNormMeta { 22 | M: u32, 23 | N: u32, 24 | ND4: u32, 25 | eps: f32, 26 | } 27 | 28 | impl OpMetadata for LayerNormMeta {} 29 | 30 | #[derive(derive_new::new, Debug)] 31 | pub struct LayerNorm { 32 | eps: f32, 33 | } 34 | 35 | const PROB_M: usize = 2048; 36 | const PROB_N: usize = 512; 37 | 38 | impl KernelBench for LayerNorm { 39 | type Metadata = LayerNormMeta; 40 | 41 | fn name() -> &'static str { 42 | "LayerNormVectorized" 43 | } 44 | 45 | fn source(&self, workload: &Workload) -> String { 46 | let mut tera = tera::Tera::default(); 47 | let mut context = tera::Context::new(); 48 | tera.add_raw_template( 49 | Self::name(), 50 | include_str!("../../kernels/layernorm/naive_vec4.wgsl"), 51 | ) 52 | .unwrap(); 53 | context.insert_workload(workload); 54 | tera.render(Self::name(), &context).unwrap() 55 | } 56 | 57 | fn tensors(&self) -> Vec { 58 | let input = CPUTensor::randn::(shape![1, PROB_M, PROB_N]); 59 | let scale = CPUTensor::randn::(shape![PROB_N]); 60 | let bias = CPUTensor::randn::(shape![PROB_N]); 61 | let output = CPUTensor::zeros::(shape![1, PROB_M, PROB_N]); 62 | vec![input, scale, bias, output] 63 | } 64 | 65 | fn workload(&self, tensors: &[CPUTensor]) -> Workload { 66 | let input = &tensors[0]; 67 | let [_B, M, _N] = input.shape().try_into().unwrap(); 68 | Workload::new(wgs![128, 1, 1], wgc![M as _, 1, 1]) 69 | } 70 | 71 | fn metadata(&self, tensors: &[CPUTensor]) -> Self::Metadata { 72 | let input = &tensors[0]; 73 | let [_B, M, N] = input.shape().try_into().unwrap(); 74 | LayerNormMeta::new(M as _, N as _, (N / 4) as _, self.eps) 75 | } 76 | 77 | fn validate(&self, tensors: &[CPUTensor]) { 78 | let (input, scale, bias) = (&tensors[0], &tensors[1], &tensors[2]); 
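// Validation flow shared by the layernorm benches: the embedded PyTorch snippet
// below computes the ground truth with `F.layer_norm` over the last dimension
// using the same scale and bias, the kernel is dispatched once through
// `dispatch_validate`, and the GPU output (tensor index 3) is read back and
// checked with `all_close` under the given tolerances.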
79 | let ground = Python::with_gil(|py| { 80 | let (py_input, py_scale, py_bias) = ( 81 | input.to_py::(&py), 82 | scale.to_py::(&py), 83 | bias.to_py::(&py), 84 | ); 85 | let result: Context = python! { 86 | import torch 87 | import torch.nn.functional as F 88 | 89 | (input, scale, bias) = (torch.from_numpy('py_input), torch.from_numpy('py_scale), torch.from_numpy('py_bias)) 90 | result = F.layer_norm(input, (input.shape[-1],), weight=scale, bias=bias).numpy() 91 | }; 92 | CPUTensor::from(result.get_with_gil::<&PyArrayDyn>(py, "result")) 93 | }); 94 | let mut gpu_tensors = dispatch_validate(TIMER.handle(), self, tensors); 95 | let cpu_result = gpu_tensors.remove(3).into_cpu(TIMER.handle()).unwrap(); 96 | ground.all_close(&cpu_result, 1e-5, 1e-5).unwrap(); 97 | } 98 | } 99 | 100 | pub fn benchmark(c: &mut Criterion<&WgpuTimer>) { 101 | let throughput = Throughput::Elements((PROB_M * PROB_N) as u64); 102 | wgpu_bencher::benchmark(c, &TIMER, LayerNorm::new(1e-5), throughput) 103 | } 104 | 105 | criterion_group!( 106 | name = bench; 107 | config = Criterion::default().with_measurement(&*TIMER); 108 | targets = benchmark 109 | ); 110 | criterion_main!(bench); 111 | -------------------------------------------------------------------------------- /benches/layernorm/naive_vectorized_onepass.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_snake_case)] 2 | use encase::ShaderType; 3 | use inline_python::{python, Context}; 4 | use numpy::PyArrayDyn; 5 | use pyo3::Python; 6 | use smallvec::smallvec; 7 | 8 | use criterion::{criterion_group, criterion_main, Criterion, Throughput}; 9 | use wgpu_bencher::{ 10 | dispatch_validate, shape, wgc, wgs, CPUTensor, GPUHandle, KernelBench, KernelContextExt, 11 | OpMetadata, WgpuTimer, Workload, 12 | }; 13 | 14 | lazy_static::lazy_static! 
{ 15 | pub static ref TIMER: WgpuTimer = WgpuTimer::new(pollster::block_on(async { 16 | GPUHandle::new().await.unwrap() 17 | })); 18 | } 19 | 20 | #[derive(ShaderType, derive_new::new, Debug)] 21 | pub struct LayerNormMeta { 22 | M: u32, 23 | N: u32, 24 | ND4: u32, 25 | eps: f32, 26 | } 27 | 28 | impl OpMetadata for LayerNormMeta {} 29 | 30 | #[derive(derive_new::new, Debug)] 31 | pub struct LayerNorm { 32 | eps: f32, 33 | } 34 | 35 | const PROB_M: usize = 2048; 36 | const PROB_N: usize = 512; 37 | 38 | impl KernelBench for LayerNorm { 39 | type Metadata = LayerNormMeta; 40 | 41 | fn name() -> &'static str { 42 | "LayerNormVectorizedOnePass" 43 | } 44 | 45 | fn source(&self, workload: &Workload) -> String { 46 | let mut tera = tera::Tera::default(); 47 | let mut context = tera::Context::new(); 48 | tera.add_raw_template( 49 | Self::name(), 50 | include_str!("../../kernels/layernorm/onepass_vec4.wgsl"), 51 | ) 52 | .unwrap(); 53 | context.insert_workload(workload); 54 | tera.render(Self::name(), &context).unwrap() 55 | } 56 | 57 | fn tensors(&self) -> Vec { 58 | let input = CPUTensor::randn::(shape![1, PROB_M, PROB_N]); 59 | let scale = CPUTensor::randn::(shape![PROB_N]); 60 | let bias = CPUTensor::randn::(shape![PROB_N]); 61 | let output = CPUTensor::zeros::(shape![1, PROB_M, PROB_N]); 62 | vec![input, scale, bias, output] 63 | } 64 | 65 | fn workload(&self, tensors: &[CPUTensor]) -> Workload { 66 | let input = &tensors[0]; 67 | let [_B, M, _N] = input.shape().try_into().unwrap(); 68 | Workload::new(wgs![128, 1, 1], wgc![M as _, 1, 1]) 69 | } 70 | 71 | fn metadata(&self, tensors: &[CPUTensor]) -> Self::Metadata { 72 | let input = &tensors[0]; 73 | let [_B, M, N] = input.shape().try_into().unwrap(); 74 | LayerNormMeta::new(M as _, N as _, (N / 4) as _, self.eps) 75 | } 76 | 77 | fn validate(&self, tensors: &[CPUTensor]) { 78 | let (input, scale, bias) = (&tensors[0], &tensors[1], &tensors[2]); 79 | let ground = Python::with_gil(|py| { 80 | let (py_input, py_scale, py_bias) = ( 81 | input.to_py::(&py), 82 | scale.to_py::(&py), 83 | bias.to_py::(&py), 84 | ); 85 | let result: Context = python! 
{ 86 | import torch 87 | import torch.nn.functional as F 88 | 89 | (input, scale, bias) = (torch.from_numpy('py_input), torch.from_numpy('py_scale), torch.from_numpy('py_bias)) 90 | result = F.layer_norm(input, (input.shape[-1],), weight=scale, bias=bias).numpy() 91 | }; 92 | CPUTensor::from(result.get_with_gil::<&PyArrayDyn>(py, "result")) 93 | }); 94 | let mut gpu_tensors = dispatch_validate(TIMER.handle(), self, tensors); 95 | let cpu_result = gpu_tensors.remove(3).into_cpu(TIMER.handle()).unwrap(); 96 | ground.all_close(&cpu_result, 1e-5, 1e-5).unwrap(); 97 | } 98 | } 99 | 100 | pub fn benchmark(c: &mut Criterion<&WgpuTimer>) { 101 | let throughput = Throughput::Elements((PROB_M * PROB_N) as u64); 102 | wgpu_bencher::benchmark(c, &TIMER, LayerNorm::new(1e-5), throughput) 103 | } 104 | 105 | criterion_group!( 106 | name = bench; 107 | config = Criterion::default().with_measurement(&*TIMER); 108 | targets = benchmark 109 | ); 110 | criterion_main!(bench); 111 | -------------------------------------------------------------------------------- /benches/layernorm/welford_scalar.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_snake_case)] 2 | use encase::ShaderType; 3 | use inline_python::{python, Context}; 4 | use numpy::PyArrayDyn; 5 | use pyo3::Python; 6 | use smallvec::smallvec; 7 | 8 | use criterion::{criterion_group, criterion_main, Criterion, Throughput}; 9 | use wgpu_bencher::{ 10 | dispatch_validate, shape, wgc, wgs, CPUTensor, GPUHandle, KernelBench, KernelContextExt, 11 | OpMetadata, WgpuTimer, Workload, 12 | }; 13 | 14 | lazy_static::lazy_static! { 15 | pub static ref TIMER: WgpuTimer = WgpuTimer::new(pollster::block_on(async { 16 | GPUHandle::new().await.unwrap() 17 | })); 18 | } 19 | 20 | #[derive(ShaderType, derive_new::new, Debug)] 21 | pub struct LayerNormMeta { 22 | M: u32, 23 | N: u32, 24 | ND4: u32, 25 | eps: f32, 26 | } 27 | 28 | impl OpMetadata for LayerNormMeta {} 29 | 30 | #[derive(derive_new::new, Debug)] 31 | pub struct LayerNorm { 32 | eps: f32, 33 | } 34 | 35 | const PROB_M: usize = 2048; 36 | const PROB_N: usize = 512; 37 | const WARP_SIZE: usize = 32; 38 | 39 | impl KernelBench for LayerNorm { 40 | type Metadata = LayerNormMeta; 41 | 42 | fn name() -> &'static str { 43 | "WelfordScalar" 44 | } 45 | 46 | fn source(&self, workload: &Workload) -> String { 47 | let mut tera = tera::Tera::default(); 48 | let mut context = tera::Context::new(); 49 | tera.add_raw_template( 50 | Self::name(), 51 | include_str!("../../kernels/layernorm/welford_scalar.wgsl"), 52 | ) 53 | .unwrap(); 54 | context.insert_workload(workload); 55 | tera.render(Self::name(), &context).unwrap() 56 | } 57 | 58 | fn tensors(&self) -> Vec { 59 | let input = CPUTensor::randn::(shape![1, PROB_M, PROB_N]); 60 | let scale = CPUTensor::randn::(shape![PROB_N]); 61 | let bias = CPUTensor::randn::(shape![PROB_N]); 62 | let output = CPUTensor::zeros::(shape![1, PROB_M, PROB_N]); 63 | vec![input, scale, bias, output] 64 | } 65 | 66 | fn workload(&self, tensors: &[CPUTensor]) -> Workload { 67 | let input = &tensors[0]; 68 | let [_B, M, _N] = input.shape().try_into().unwrap(); 69 | Workload::new(wgs![WARP_SIZE as _, 1, 1], wgc![M as _, 1, 1]) 70 | } 71 | 72 | fn metadata(&self, tensors: &[CPUTensor]) -> Self::Metadata { 73 | let input = &tensors[0]; 74 | let [_B, M, N] = input.shape().try_into().unwrap(); 75 | LayerNormMeta::new(M as _, N as _, (N / 4) as _, self.eps) 76 | } 77 | 78 | fn validate(&self, tensors: &[CPUTensor]) { 79 | let (input, scale, bias) = 
(&tensors[0], &tensors[1], &tensors[2]); 80 | let ground = Python::with_gil(|py| { 81 | let (py_input, py_scale, py_bias) = ( 82 | input.to_py::(&py), 83 | scale.to_py::(&py), 84 | bias.to_py::(&py), 85 | ); 86 | let result: Context = python! { 87 | import torch 88 | import torch.nn.functional as F 89 | 90 | (input, scale, bias) = (torch.from_numpy('py_input), torch.from_numpy('py_scale), torch.from_numpy('py_bias)) 91 | result = F.layer_norm(input, (input.shape[-1],), weight=scale, bias=bias).numpy() 92 | }; 93 | CPUTensor::from(result.get_with_gil::<&PyArrayDyn>(py, "result")) 94 | }); 95 | let mut gpu_tensors = dispatch_validate(TIMER.handle(), self, tensors); 96 | let cpu_result = gpu_tensors.remove(3).into_cpu(TIMER.handle()).unwrap(); 97 | ground.all_close(&cpu_result, 1e-5, 1e-5).unwrap(); 98 | } 99 | } 100 | 101 | fn benchmark(c: &mut Criterion<&WgpuTimer>) { 102 | let throughput = Throughput::Elements((PROB_M * PROB_N) as u64); 103 | wgpu_bencher::benchmark(c, &TIMER, LayerNorm::new(1e-5), throughput) 104 | } 105 | 106 | criterion_group!( 107 | name = bench; 108 | config = Criterion::default().with_measurement(&*TIMER); 109 | targets = benchmark 110 | ); 111 | criterion_main!(bench); 112 | -------------------------------------------------------------------------------- /benches/layernorm/welford_vectorized.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_snake_case)] 2 | use encase::ShaderType; 3 | use inline_python::{python, Context}; 4 | use numpy::PyArrayDyn; 5 | use pyo3::Python; 6 | use smallvec::smallvec; 7 | 8 | use criterion::{criterion_group, criterion_main, Criterion, Throughput}; 9 | use wgpu_bencher::{ 10 | dispatch_validate, shape, wgc, wgs, CPUTensor, GPUHandle, KernelBench, KernelContextExt, 11 | OpMetadata, WgpuTimer, Workload, 12 | }; 13 | 14 | lazy_static::lazy_static! 
{ 15 | pub static ref TIMER: WgpuTimer = WgpuTimer::new(pollster::block_on(async { 16 | GPUHandle::new().await.unwrap() 17 | })); 18 | } 19 | 20 | #[derive(ShaderType, derive_new::new, Debug)] 21 | pub struct LayerNormMeta { 22 | M: u32, 23 | N: u32, 24 | ND4: u32, 25 | eps: f32, 26 | } 27 | 28 | impl OpMetadata for LayerNormMeta {} 29 | 30 | #[derive(derive_new::new, Debug)] 31 | pub struct LayerNorm { 32 | eps: f32, 33 | } 34 | 35 | const PROB_M: usize = 2048; 36 | const PROB_N: usize = 512; 37 | const WARP_SIZE: usize = 32; //M1 warp size 38 | 39 | impl KernelBench for LayerNorm { 40 | type Metadata = LayerNormMeta; 41 | 42 | fn name() -> &'static str { 43 | "WelfordVectorized" 44 | } 45 | 46 | fn source(&self, workload: &Workload) -> String { 47 | let mut tera = tera::Tera::default(); 48 | let mut context = tera::Context::new(); 49 | tera.add_raw_template( 50 | Self::name(), 51 | include_str!("../../kernels/layernorm/welford_vec4.wgsl"), 52 | ) 53 | .unwrap(); 54 | context.insert_workload(workload); 55 | tera.render(Self::name(), &context).unwrap() 56 | } 57 | 58 | fn tensors(&self) -> Vec { 59 | let input = CPUTensor::randn::(shape![1, PROB_M, PROB_N]); 60 | let scale = CPUTensor::randn::(shape![PROB_N]); 61 | let bias = CPUTensor::randn::(shape![PROB_N]); 62 | let output = CPUTensor::zeros::(shape![1, PROB_M, PROB_N]); 63 | vec![input, scale, bias, output] 64 | } 65 | 66 | fn workload(&self, tensors: &[CPUTensor]) -> Workload { 67 | let input = &tensors[0]; 68 | let [_B, M, _N] = input.shape().try_into().unwrap(); 69 | Workload::new(wgs![WARP_SIZE as _, 1, 1], wgc![M as _, 1, 1]) 70 | } 71 | 72 | fn metadata(&self, tensors: &[CPUTensor]) -> Self::Metadata { 73 | let input = &tensors[0]; 74 | let [_B, M, N] = input.shape().try_into().unwrap(); 75 | LayerNormMeta::new(M as _, N as _, (N / 4) as _, self.eps) 76 | } 77 | 78 | fn validate(&self, tensors: &[CPUTensor]) { 79 | let (input, scale, bias) = (&tensors[0], &tensors[1], &tensors[2]); 80 | let ground = Python::with_gil(|py| { 81 | let (py_input, py_scale, py_bias) = ( 82 | input.to_py::(&py), 83 | scale.to_py::(&py), 84 | bias.to_py::(&py), 85 | ); 86 | let result: Context = python! 
{ 87 | import torch 88 | import torch.nn.functional as F 89 | 90 | (input, scale, bias) = (torch.from_numpy('py_input), torch.from_numpy('py_scale), torch.from_numpy('py_bias)) 91 | result = F.layer_norm(input, (input.shape[-1],), weight=scale, bias=bias).numpy() 92 | }; 93 | CPUTensor::from(result.get_with_gil::<&PyArrayDyn>(py, "result")) 94 | }); 95 | let mut gpu_tensors = dispatch_validate(TIMER.handle(), self, tensors); 96 | let cpu_result = gpu_tensors.remove(3).into_cpu(TIMER.handle()).unwrap(); 97 | ground.all_close(&cpu_result, 1e-5, 1e-5).unwrap(); 98 | } 99 | } 100 | 101 | pub fn benchmark(c: &mut Criterion<&WgpuTimer>) { 102 | let throughput = Throughput::Elements((PROB_M * PROB_N) as u64); 103 | wgpu_bencher::benchmark(c, &TIMER, LayerNorm::new(1e-5), throughput) 104 | } 105 | 106 | criterion_group!( 107 | name = bench; 108 | config = Criterion::default().with_measurement(&*TIMER); 109 | targets = benchmark 110 | ); 111 | criterion_main!(bench); 112 | -------------------------------------------------------------------------------- /kernels/layernorm/welford_vec4.wgsl: -------------------------------------------------------------------------------- 1 | @group(0) @binding(0) 2 | var X: array>; 3 | 4 | @group(0) @binding(1) 5 | var S: array>; 6 | 7 | @group(0) @binding(2) 8 | var B: array>; 9 | 10 | @group(0) @binding(3) 11 | var Y: array>; 12 | 13 | struct Meta { 14 | M: u32, 15 | N: u32, 16 | ND4: u32, 17 | eps: f32, 18 | } 19 | 20 | @group(1) @binding(0) 21 | var metadata: Meta; 22 | 23 | var mu: f32; 24 | var sigma: f32; 25 | var subgrp_size: u32; 26 | 27 | fn welford_vcombine(val: vec4, mean: ptr>, m2: ptr>, count: ptr>) { 28 | *count += 1.0; 29 | let delta1 = val - *mean; 30 | *mean += delta1 / *count; 31 | let delta2 = val - *mean; 32 | *m2 += delta1 * delta2; 33 | } 34 | 35 | fn block_welford_combine(b_mean: f32, b_m2: f32, b_count: f32, mean: ptr, m2: ptr, count: ptr) { 36 | if (b_count == 0.0) { 37 | return; 38 | } 39 | let new_count = *count + b_count; 40 | let nb_over_n = b_count / new_count; 41 | let delta = b_mean - *mean; 42 | *mean += delta * nb_over_n; 43 | *m2 += b_m2 + delta * delta * (*count) * nb_over_n; 44 | *count = new_count; 45 | } 46 | 47 | fn welford_warp_reduce(thread_mean: f32, thread_m2: f32, thread_count: f32, mean: ptr, m2: ptr, count: ptr) { 48 | *mean = thread_mean; 49 | *m2 = thread_m2; 50 | *count = thread_count; 51 | for (var offset = subgrp_size >> 1u; offset > 0u; offset >>= 1u) { 52 | let b_mean = subgroupShuffleDown(*mean, offset); 53 | let b_m2 = subgroupShuffleDown(*m2, offset); 54 | let b_count = subgroupShuffleDown(*count, offset); 55 | block_welford_combine(b_mean, b_m2, b_count, mean, m2, count); 56 | } 57 | } 58 | 59 | fn welford_warp_all_reduce(thread_mean: f32, thread_m2: f32, thread_count: f32, mean: ptr, m2: ptr, count: ptr) { 60 | welford_warp_reduce(thread_mean, thread_m2, thread_count, mean, m2, count); 61 | 62 | *mean = subgroupBroadcast(*mean, 0u); 63 | *m2 = subgroupBroadcast(*m2, 0u); 64 | *count = subgroupBroadcast(*count, 0u); 65 | } 66 | 67 | 68 | @compute @workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }}) 69 | fn main( 70 | @builtin(local_invocation_id) local_id: vec3, 71 | @builtin(workgroup_id) group_id: vec3, 72 | @builtin(global_invocation_id) global_id: vec3, 73 | @builtin(subgroup_id) subgroup_id: u32, 74 | @builtin(subgroup_size) subgroup_size: u32, 75 | ) { 76 | subgrp_size = subgroup_size; 77 | let anchor = (group_id.y * metadata.M * metadata.ND4) + group_id.x * 
metadata.ND4; 78 | var threadMean = vec4(0.0); 79 | var threadM2 = vec4(0.0); 80 | var threadCount = vec4(0.0); 81 | for (var i = local_id.x; i < metadata.ND4; i+= {{ workgroup_size_x }}u) { 82 | welford_vcombine(X[anchor + i], &threadMean, &threadM2, &threadCount); 83 | } 84 | var finalMean = threadMean.x; 85 | var finalM2 = threadM2.x; 86 | var finalCount = threadCount.x; 87 | block_welford_combine(threadMean.y, threadM2.y, threadCount.y, &finalMean, &finalM2, &finalCount); 88 | block_welford_combine(threadMean.z, threadM2.z, threadCount.z, &finalMean, &finalM2, &finalCount); 89 | block_welford_combine(threadMean.w, threadM2.w, threadCount.w, &finalMean, &finalM2, &finalCount); 90 | 91 | var mean = 0f; 92 | var m2 = 0f; 93 | var count = 0f; 94 | welford_warp_all_reduce(finalMean, finalM2, finalCount, &mean, &m2, &count); 95 | 96 | if (subgroup_id == 0u) { 97 | mu = mean; 98 | sigma = inverseSqrt(m2 / count + metadata.eps); 99 | } 100 | subgroupBarrier(); 101 | for (var i = local_id.x; i < metadata.ND4; i+= {{ workgroup_size_x }}u) { 102 | let val = X[anchor + i]; 103 | let normalized = (val - vec4(mu)) * vec4(sigma); 104 | Y[anchor + i] = fma(normalized, S[i], B[i]); 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /benches/layernorm/naive.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_snake_case)] 2 | use encase::ShaderType; 3 | use inline_python::{python, Context}; 4 | use numpy::PyArrayDyn; 5 | use pyo3::Python; 6 | use smallvec::smallvec; 7 | 8 | use criterion::{criterion_group, criterion_main, Criterion, Throughput}; 9 | use wgpu_bencher::{ 10 | dispatch_validate, shape, wgc, wgs, CPUTensor, GPUHandle, KernelBench, KernelContextExt, 11 | OpMetadata, WgpuTimer, Workload, 12 | }; 13 | 14 | lazy_static::lazy_static! 
{ 15 | pub static ref TIMER: WgpuTimer = WgpuTimer::new(pollster::block_on(async { 16 | GPUHandle::new().await.unwrap() 17 | })); 18 | } 19 | 20 | #[derive(ShaderType, derive_new::new, Debug)] 21 | pub struct LayerNormMeta { 22 | M: u32, 23 | N: u32, 24 | ND4: u32, 25 | eps: f32, 26 | } 27 | 28 | impl OpMetadata for LayerNormMeta {} 29 | 30 | #[derive(derive_new::new, Debug)] 31 | pub struct LayerNormBench { 32 | M: usize, 33 | N: usize, 34 | eps: f32, 35 | } 36 | 37 | impl KernelBench for LayerNormBench { 38 | type Metadata = LayerNormMeta; 39 | 40 | fn name() -> &'static str { 41 | "LayerNorm" 42 | } 43 | 44 | fn source(&self, workload: &Workload) -> String { 45 | let mut tera = tera::Tera::default(); 46 | let mut context = tera::Context::new(); 47 | tera.add_raw_template( 48 | Self::name(), 49 | include_str!("../../kernels/layernorm/naive_scalar.wgsl"), 50 | ) 51 | .unwrap(); 52 | context.insert_workload(workload); 53 | tera.render(Self::name(), &context).unwrap() 54 | } 55 | 56 | fn tensors(&self) -> Vec { 57 | let (M, N) = (self.M, self.N); 58 | let input = CPUTensor::randn::(shape![1, M, N]); 59 | let scale = CPUTensor::randn::(shape![N]); 60 | let bias = CPUTensor::randn::(shape![N]); 61 | let output = CPUTensor::zeros::(shape![1, M, N]); 62 | vec![input, scale, bias, output] 63 | } 64 | 65 | fn workload(&self, tensors: &[CPUTensor]) -> Workload { 66 | let input = &tensors[0]; 67 | let [_B, M, _N] = input.shape().try_into().unwrap(); 68 | Workload::new(wgs![128, 1, 1], wgc![M as _, 1, 1]) 69 | } 70 | 71 | fn metadata(&self, tensors: &[CPUTensor]) -> Self::Metadata { 72 | let input = &tensors[0]; 73 | let [_B, M, N] = input.shape().try_into().unwrap(); 74 | LayerNormMeta::new(M as _, N as _, (N / 4) as _, self.eps) 75 | } 76 | 77 | fn validate(&self, tensors: &[CPUTensor]) { 78 | let (input, scale, bias) = (&tensors[0], &tensors[1], &tensors[2]); 79 | let ground = Python::with_gil(|py| { 80 | let (py_input, py_scale, py_bias) = ( 81 | input.to_py::(&py), 82 | scale.to_py::(&py), 83 | bias.to_py::(&py), 84 | ); 85 | let result: Context = python! 
{ 86 | import torch 87 | import torch.nn.functional as F 88 | 89 | (input, scale, bias) = (torch.from_numpy('py_input), torch.from_numpy('py_scale), torch.from_numpy('py_bias)) 90 | print("Input: ", input) 91 | print("Scale: ", scale) 92 | print("Bias: ", bias) 93 | result = F.layer_norm(input, (input.shape[-1],), weight=scale, bias=bias).numpy() 94 | }; 95 | CPUTensor::from(result.get_with_gil::<&PyArrayDyn>(py, "result")) 96 | }); 97 | let mut gpu_tensors = dispatch_validate(TIMER.handle(), self, tensors); 98 | let cpu_result = gpu_tensors.remove(3).into_cpu(TIMER.handle()).unwrap(); 99 | ground.all_close(&cpu_result, 1e-5, 1e-5).unwrap(); 100 | } 101 | } 102 | 103 | pub fn benchmark(c: &mut Criterion<&WgpuTimer>) { 104 | let M = 2048; 105 | let N = 2048; 106 | let bytes_per_iter = M * N * std::mem::size_of::(); 107 | let tp = Throughput::Bytes(bytes_per_iter as u64); 108 | wgpu_bencher::benchmark(c, &TIMER, LayerNormBench::new(M, N, 1e-5), tp) 109 | } 110 | 111 | criterion_group!( 112 | name = bench; 113 | config = Criterion::default().with_measurement(&*TIMER); 114 | targets = benchmark 115 | ); 116 | criterion_main!(bench); 117 | -------------------------------------------------------------------------------- /benches/mlx-gemv/gemv.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_snake_case)] 2 | use encase::ShaderType; 3 | use inline_python::{python, Context}; 4 | use numpy::PyArrayDyn; 5 | use pyo3::Python; 6 | use smallvec::smallvec; 7 | 8 | use criterion::{criterion_group, criterion_main, Criterion, Throughput}; 9 | use wgpu_bencher::{ 10 | dispatch_validate, shape, wgc, wgs, CPUTensor, GPUHandle, KernelBench, KernelContextExt, 11 | OpMetadata, WgpuTimer, Workload, 12 | }; 13 | 14 | lazy_static::lazy_static! 
{ 15 | pub static ref TIMER: WgpuTimer = WgpuTimer::new(pollster::block_on(async { 16 | GPUHandle::new().await.unwrap() 17 | })); 18 | } 19 | 20 | #[derive(ShaderType, Debug)] 21 | pub struct MLXGEMVMeta { 22 | M: i32, 23 | N: i32, 24 | } 25 | 26 | impl OpMetadata for MLXGEMVMeta {} 27 | 28 | #[derive(derive_new::new, Debug)] 29 | pub struct MLXGEMVBenchmark { 30 | B: usize, 31 | M: usize, 32 | N: usize, 33 | K: usize, 34 | } 35 | 36 | //#1 gemv_float32_bm8_bn32_tm4_tn4_nc0_axpby0 = [0x600002d2a480 newFunctionWithName:"gemv_float32_bm8_bn32_tm4_tn4_nc0_axpby0"] 37 | 38 | impl KernelBench for MLXGEMVBenchmark { 39 | type Metadata = MLXGEMVMeta; 40 | 41 | fn name() -> &'static str { 42 | "MLXGEMVBenchmark" 43 | } 44 | 45 | fn source(&self, workload: &Workload) -> String { 46 | let mut tera = tera::Tera::default(); 47 | let mut context = tera::Context::new(); 48 | 49 | let template = include_str!("../../kernels/sgemv/mlx_sgemv.wgsl"); 50 | tera.add_raw_template(Self::name(), template).unwrap(); 51 | 52 | context.insert("TM", &4); 53 | context.insert("TN", &4); 54 | context.insert("BM", &8); 55 | context.insert("BN", &32); 56 | 57 | context.insert_workload(workload); 58 | let kernel = tera.render(Self::name(), &context).unwrap(); 59 | println!("{}", kernel); 60 | kernel 61 | } 62 | 63 | fn tensors(&self) -> Vec { 64 | let (B, M, N, K) = (self.B, self.M, self.N, self.K); 65 | let a = CPUTensor::randn::(shape![B, M, K]); 66 | let b = CPUTensor::randn::(shape![B, K, N]); 67 | let output = CPUTensor::zeros::(shape![B, M, N]); 68 | vec![a, b, output] 69 | } 70 | 71 | //[dispatchThreadgroups:{512, 1, 1} threadsPerThreadgroup:{32, 8, 1}] 72 | fn workload(&self, _: &[CPUTensor]) -> Workload { 73 | let workgroup_size = wgs![32, 8, 1]; 74 | let workgroup_count = wgc![(self.M / 32) as _, 1, self.B as _]; 75 | let dispatch = Workload::new(workgroup_size, workgroup_count); 76 | println!("DISPATCH: {:?}", dispatch); 77 | dispatch 78 | } 79 | 80 | fn metadata(&self, _: &[CPUTensor]) -> Self::Metadata { 81 | let (B, M, N, K) = (self.B, self.M, self.N, self.K); 82 | 83 | let meta = MLXGEMVMeta { M: 16384, N: 3072 }; 84 | println!("META: {:?}", meta); 85 | meta 86 | } 87 | 88 | fn validate(&self, tensors: &[CPUTensor]) { 89 | let (a, b) = (&tensors[0], &tensors[1]); 90 | let ground = Python::with_gil(|py| { 91 | let (py_a, py_b) = (a.to_py::(&py), b.to_py::(&py)); 92 | let result: Context = python! 
{ 93 | import torch 94 | (a, b) = (torch.from_numpy('py_a), torch.from_numpy('py_b)) 95 | result = (a @ b).numpy() 96 | }; 97 | CPUTensor::from(result.get_with_gil::<&PyArrayDyn>(py, "result")) 98 | }); 99 | let mut gpu_tensors = dispatch_validate(TIMER.handle(), self, tensors); 100 | let cpu_result = gpu_tensors.remove(2).into_cpu(TIMER.handle()).unwrap(); 101 | println!("GROUND: {}", ground); 102 | println!("OURS: {}", cpu_result); 103 | ground.all_close(&cpu_result, 5e-4, 5e-4).unwrap(); 104 | } 105 | } 106 | 107 | pub fn benchmark(c: &mut Criterion<&WgpuTimer>) { 108 | let B = 1; 109 | let M = 16384; 110 | let N = 1; 111 | let K = 3072; 112 | 113 | let bench = MLXGEMVBenchmark::new(B, M, N, K); 114 | let throughput = Throughput::Elements(2 * (B * M * N * K) as u64); 115 | wgpu_bencher::benchmark(c, &TIMER, bench, throughput) 116 | } 117 | 118 | criterion_group!( 119 | name = bench; 120 | config = Criterion::default().with_measurement(&*TIMER); 121 | targets = benchmark 122 | ); 123 | criterion_main!(bench); 124 | -------------------------------------------------------------------------------- /kernels/qgemv/mlx-qgemv.wgsl: -------------------------------------------------------------------------------- 1 | const SIMD_SIZE = 32i; 2 | 3 | // - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up 4 | // into blocks of (BM * TM, BN * TN) divided among threadgroups 5 | // - Every thread works on a block of (TM, TN) 6 | // - We assume each thead group is launched with (BN, BM, 1) threads 7 | // 8 | // 1. A thread loads TN elements each from mat along TM contiguous rows 9 | // and the corresponding scalar from the vector 10 | // 2. The thread then multiplies and adds to accumulate its local result for the block 11 | // 3. At the end, each thread has accumulated results over all blocks across the rows 12 | // These are then summed up across the threadgroup 13 | // 4. 
Each threadgroup writes its accumulated BN * TN outputs 14 | // 15 | // Edge case handling: 16 | // - The threadgroup with the largest group_id will have blocks that exceed the matrix 17 | // * The blocks that start outside the matrix are never read (thread results remain zero) 18 | // * The last thread that partially overlaps with the matrix is shifted inwards 19 | // such that the thread block fits exactly in the matrix 20 | 21 | struct Meta { 22 | M: i32, //out_vec_size 23 | N: i32, //in_vec_size 24 | } 25 | 26 | @group(0) @binding(0) var mat: array; 27 | @group(0) @binding(1) var scale: array; 28 | @group(0) @binding(2) var inVec: array; 29 | @group(0) @binding(3) var outVec: array; 30 | @group(1) @binding(0) var metadata: Meta; 31 | 32 | var tgpMemory: array; 33 | 34 | @compute @workgroup_size(SIMD_SIZE, {{BM}}, 1) 35 | fn main( 36 | @builtin(global_invocation_id) global_id: vec3, 37 | @builtin(local_invocation_id) local_id: vec3, 38 | @builtin(local_invocation_index) local_index: u32, 39 | @builtin(workgroup_id) group_id: vec3, 40 | @builtin(num_workgroups) num_groups: vec3, 41 | @builtin(subgroup_size) subgroup_size: u32, 42 | @builtin(subgroup_invocation_id) simd_lid: u32 43 | ) { 44 | let simd_gid = local_index / subgroup_size; 45 | 46 | // Threadgroup in_vec cache 47 | let inVecBlockOffset = i32(simd_lid * {{TN}} * 2); 48 | 49 | // Thread local accumulation results 50 | var result: array; 51 | var inter = vec4(0.0); 52 | var vCoeff = vec4(0.0); 53 | 54 | // Block position 55 | var outRow = i32((group_id.x * {{BM}} + simd_gid) * {{TM}}); 56 | 57 | // Exit simdgroup if rows out of bound 58 | if (outRow >= metadata.M) { 59 | return; 60 | } 61 | 62 | // Adjust tail simdgroup to ensure in bound reads 63 | outRow = select(metadata.M - {{TM}}, outRow, outRow + {{TM}} <= metadata.M); 64 | 65 | // Advance matrix 66 | let matOffset = outRow * metadata.N; 67 | 68 | // Loop over in_vec in blocks of SIMD_SIZE * {{TN}} 69 | for (var bn = i32(simd_lid * {{TN}}); bn < i32(metadata.N); bn += {{BN * TN}}) { 70 | workgroupBarrier(); 71 | 72 | // Prefetch in_vector for threadgroup use 73 | if (simd_gid == 0u) { 74 | // Main load loop 75 | if (bn + {{TN}} <= i32(metadata.N)) { 76 | {% for tn in range(end=TN) %} 77 | tgpMemory[inVecBlockOffset + {{tn}}] = inVec[bn + {{tn}}]; 78 | {% endfor %} 79 | } else { // Edgecase 80 | {% for tn in range(end=TN) %} 81 | tgpMemory[inVecBlockOffset + {{tn}}] = select(inVec[bn + {{tn}}], 0.0, bn + {{tn}} < metadata.N); 82 | {% endfor %} 83 | } 84 | } 85 | 86 | workgroupBarrier(); 87 | 88 | // Load for all rows 89 | vCoeff = vec4(tgpMemory[inVecBlockOffset], tgpMemory[inVecBlockOffset + 1], tgpMemory[inVecBlockOffset + 2], tgpMemory[inVecBlockOffset + 3]); 90 | 91 | // Per thread work loop 92 | for (var tm = 0; tm < {{TM}}; tm++) { 93 | // Load for the row 94 | let matIdx = matOffset + tm * metadata.N + bn; 95 | inter = unpack4x8snorm(mat[matIdx / 4]) * scale[matIdx / 16]; 96 | 97 | // Accumulate results 98 | {% for tn in range(end=TN) %} 99 | result[tm] = fma(inter[{{tn}}], vCoeff[{{tn}}], result[tm]); 100 | {% endfor %} 101 | } 102 | } 103 | 104 | for (var tm = 0; tm < {{TM}}; tm++) { 105 | result[tm] = subgroupAdd(result[tm]); 106 | } 107 | 108 | // Write outputs 109 | if (simd_lid == 0u) { 110 | {% for tm in range(end=TM) %} 111 | outVec[outRow + {{tm}}] = result[{{tm}}]; 112 | {% endfor %} 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /benches/rope/rope.rs: 
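// Orientation for the RoPE bench below: the Metal reference kernel included
// above rotates each pair (x1, x2) at column i of a head of width HD and
// position p by
//
//   theta    = p * exp2(-(i / (HD / 2)) * base)   ( = p * 10000^(-2 * i / HD) when base = log2(10000) )
//   (y1, y2) = (x1 * cos(theta) - x2 * sin(theta), x1 * sin(theta) + x2 * cos(theta))
//
// The WGSL port exercised here presumably mirrors that, which is why
// `metadata` below passes `f32::log2(10000.0)` as `base`, with scale = 1.0 and
// offset = 0, and validates against `mlx.nn.RoPE(128)`.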
-------------------------------------------------------------------------------- 1 | #![allow(non_snake_case)] 2 | use encase::ShaderType; 3 | use inline_python::{python, Context}; 4 | use numpy::PyArrayDyn; 5 | use pyo3::Python; 6 | use smallvec::smallvec; 7 | 8 | use criterion::{criterion_group, criterion_main, Criterion, Throughput}; 9 | use wgpu_bencher::{ 10 | dispatch_validate, shape, wgc, wgs, CPUTensor, GPUHandle, KernelBench, KernelContextExt, 11 | OpMetadata, Strides, WgpuTimer, Workload, 12 | }; 13 | 14 | lazy_static::lazy_static! { 15 | pub static ref TIMER: WgpuTimer = WgpuTimer::new(pollster::block_on(async { 16 | GPUHandle::new().await.unwrap() 17 | })); 18 | } 19 | 20 | #[derive(ShaderType, derive_new::new, Debug)] 21 | pub struct RopeMeta { 22 | in_strides: glam::UVec3, 23 | out_strides: glam::UVec3, 24 | offset: u32, 25 | base: f32, 26 | scale: f32, 27 | } 28 | 29 | impl OpMetadata for RopeMeta {} 30 | 31 | #[derive(Debug)] 32 | pub struct Rope {} 33 | 34 | impl KernelBench for Rope { 35 | type Metadata = RopeMeta; 36 | 37 | fn name() -> &'static str { 38 | "RoPE" 39 | } 40 | 41 | fn source(&self, workload: &Workload) -> String { 42 | let mut tera = tera::Tera::default(); 43 | let mut context = tera::Context::new(); 44 | tera.add_raw_template(Self::name(), include_str!("../../kernels/rope/rope.wgsl")) 45 | .unwrap(); 46 | context.insert_workload(workload); 47 | tera.render(Self::name(), &context).unwrap() 48 | } 49 | 50 | // [batch_size, num_heads, seq_len, head_dim] 51 | fn tensors(&self) -> Vec { 52 | let input = CPUTensor::randn::(shape![2, 16, 64, 128]); 53 | let output = CPUTensor::zeros::(shape![2, 16, 64, 128]); 54 | vec![input, output] 55 | } 56 | 57 | fn workload(&self, tensors: &[CPUTensor]) -> Workload { 58 | let input = &tensors[0]; 59 | let [BS, NH, SL, HD] = input.shape().try_into().unwrap(); 60 | 61 | let total_x = 128 / 2; 62 | let total_y = SL; 63 | println!("INPUT: {:?}", input.shape()); 64 | println!("SL * HD: {}", SL * HD); 65 | let total_z = input.shape().numel() / (SL * HD); 66 | 67 | let wgsx = 16; 68 | let wgsy = 8; 69 | let wgsz = 8; 70 | 71 | let wgcx = total_x / wgsx; 72 | let wgcy = total_y / wgsy; 73 | let wgcz = total_z / wgsz; 74 | 75 | let wl = Workload::new( 76 | wgs![wgsx as _, wgsy as _, wgsz as _], 77 | wgc![wgcx as _, wgcy as _, wgcz as _], 78 | ); 79 | println!("{:?}", wl); 80 | wl 81 | } 82 | 83 | fn metadata(&self, tensors: &[CPUTensor]) -> Self::Metadata { 84 | let input = &tensors[0]; 85 | let out = &tensors[1]; 86 | let mut input_shape = input.shape().clone(); 87 | let mut out_shape = out.shape().clone(); 88 | input_shape.remove(0); 89 | out_shape.remove(0); 90 | let in_strides = Strides::from(&input_shape); 91 | let out_strides = Strides::from(&out_shape); 92 | let meta = RopeMeta::new( 93 | (&in_strides).into(), 94 | (&out_strides).into(), 95 | 0, 96 | f32::log2(10000.0), 97 | 1.0, 98 | ); 99 | println!("{:?}", meta); 100 | meta 101 | } 102 | 103 | fn validate(&self, tensors: &[CPUTensor]) { 104 | let input = &tensors[0]; 105 | let ground = Python::with_gil(|py| { 106 | let py_input = input.to_py::(&py); 107 | let result: Context = python! 
{ 108 | import mlx.core as mx 109 | import mlx.nn as nn 110 | import numpy as np 111 | 112 | rope = nn.RoPE(128) 113 | mx_input = mx.array('py_input) 114 | y = rope(mx_input) 115 | mx.eval(y) 116 | result = np.array(y) 117 | }; 118 | CPUTensor::from(result.get_with_gil::<&PyArrayDyn>(py, "result")) 119 | }); 120 | let mut gpu_tensors = dispatch_validate(TIMER.handle(), self, tensors); 121 | let cpu_result = gpu_tensors.remove(1).into_cpu(TIMER.handle()).unwrap(); 122 | println!("MLX: {}\n", ground); 123 | println!("US: {}", cpu_result); 124 | ground.all_close(&cpu_result, 1e-5, 1e-5).unwrap(); 125 | } 126 | } 127 | 128 | fn benchmark(c: &mut Criterion<&WgpuTimer>) { 129 | let throughput = Throughput::Elements(16 * 64 * 128); 130 | wgpu_bencher::benchmark(c, &TIMER, Rope {}, throughput) 131 | } 132 | 133 | criterion_group!( 134 | name = bench; 135 | config = Criterion::default().with_measurement(&*TIMER); 136 | targets = benchmark 137 | ); 138 | criterion_main!(bench); 139 | -------------------------------------------------------------------------------- /benches/mlx-qgemv/gemv.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_snake_case)] 2 | use encase::ShaderType; 3 | use inline_python::{python, Context}; 4 | use numpy::PyArrayDyn; 5 | use pyo3::Python; 6 | use smallvec::smallvec; 7 | 8 | use criterion::{criterion_group, criterion_main, Criterion, Throughput}; 9 | use wgpu_bencher::{ 10 | dispatch_validate, shape, wgc, wgs, CPUTensor, GPUHandle, KernelBench, KernelContextExt, 11 | OpMetadata, Quantization, Quantizer, WgpuTimer, Workload, 12 | }; 13 | 14 | lazy_static::lazy_static! { 15 | pub static ref TIMER: WgpuTimer = WgpuTimer::new(pollster::block_on(async { 16 | GPUHandle::new().await.unwrap() 17 | })); 18 | } 19 | 20 | #[derive(ShaderType, Debug)] 21 | pub struct MLXQGEMVMeta { 22 | M: i32, 23 | N: i32, 24 | } 25 | 26 | impl OpMetadata for MLXQGEMVMeta {} 27 | 28 | #[derive(derive_new::new, Debug)] 29 | pub struct MLXQGEMVBenchmark { 30 | B: usize, 31 | M: usize, 32 | N: usize, 33 | K: usize, 34 | } 35 | 36 | //#1 gemv_float32_bm8_bn32_tm4_tn4_nc0_axpby0 = [0x600002d2a480 newFunctionWithName:"gemv_float32_bm8_bn32_tm4_tn4_nc0_axpby0"] 37 | 38 | impl KernelBench for MLXQGEMVBenchmark { 39 | type Metadata = MLXQGEMVMeta; 40 | 41 | fn name() -> &'static str { 42 | "MLXQGEMVBenchmark" 43 | } 44 | 45 | fn source(&self, workload: &Workload) -> String { 46 | let mut tera = tera::Tera::default(); 47 | let mut context = tera::Context::new(); 48 | 49 | let template = include_str!("../../kernels/qgemv/mlx-qgemv.wgsl"); 50 | tera.add_raw_template(Self::name(), template).unwrap(); 51 | 52 | context.insert("TM", &4); 53 | context.insert("TN", &4); 54 | context.insert("BM", &8); 55 | context.insert("BN", &32); 56 | 57 | context.insert_workload(workload); 58 | let kernel = tera.render(Self::name(), &context).unwrap(); 59 | println!("{}", kernel); 60 | kernel 61 | } 62 | 63 | fn tensors(&self) -> Vec { 64 | let (B, M, N, K) = (self.B, self.M, self.N, self.K); 65 | let a_unquant = CPUTensor::randn::(shape![B, M, K]); 66 | 67 | let quantizer = Quantizer::new(Quantization::SInt8); 68 | let quantized_a = quantizer.quantize(a_unquant.clone()); 69 | 70 | let b = CPUTensor::randn::(shape![B, K, N]); 71 | let output = CPUTensor::zeros::(shape![B, M, N]); 72 | vec![quantized_a, b, output] 73 | } 74 | 75 | //[dispatchThreadgroups:{512, 1, 1} threadsPerThreadgroup:{32, 8, 1}] 76 | fn workload(&self, _: &[CPUTensor]) -> Workload { 77 | let workgroup_size = 
wgs![32, 8, 1]; 78 | let workgroup_count = wgc![(self.M / 32) as _, 1, self.B as _]; 79 | let dispatch = Workload::new(workgroup_size, workgroup_count); 80 | println!("DISPATCH: {:?}", dispatch); 81 | dispatch 82 | } 83 | 84 | fn metadata(&self, _: &[CPUTensor]) -> Self::Metadata { 85 | let (B, M, N, K) = (self.B, self.M, self.N, self.K); 86 | 87 | let meta = MLXQGEMVMeta { M: 16384, N: 3072 }; 88 | println!("META: {:?}", meta); 89 | meta 90 | } 91 | 92 | fn validate(&self, tensors: &[CPUTensor]) { 93 | let (aquant, b) = (&tensors[0], &tensors[1]); 94 | let dequantized = Quantizer::new(Quantization::SInt8).dequantize(aquant.clone()); 95 | let ground = Python::with_gil(|py| { 96 | let (py_a, py_b) = (dequantized.to_py::(&py), b.to_py::(&py)); 97 | let result: Context = python! { 98 | import torch 99 | (a, b) = (torch.from_numpy('py_a), torch.from_numpy('py_b)) 100 | print("A: ", a) 101 | print("B: ", b) 102 | result = (a @ b).numpy() 103 | }; 104 | CPUTensor::from(result.get_with_gil::<&PyArrayDyn>(py, "result")) 105 | }); 106 | let mut gpu_tensors = dispatch_validate(TIMER.handle(), self, tensors); 107 | let cpu_result = gpu_tensors.remove(2).into_cpu(TIMER.handle()).unwrap(); 108 | println!("OURS: {}", cpu_result); 109 | println!("GROUND: {}", ground); 110 | ground.all_close(&cpu_result, 1e-2, 1e-2).unwrap(); 111 | } 112 | } 113 | 114 | pub fn benchmark(c: &mut Criterion<&WgpuTimer>) { 115 | let B = 1; 116 | let M = 16384; 117 | let N = 1; 118 | let K = 3072; 119 | 120 | let bench = MLXQGEMVBenchmark::new(B, M, N, K); 121 | let throughput = Throughput::Elements(2 * (B * M * N * K) as u64); 122 | wgpu_bencher::benchmark(c, &TIMER, bench, throughput) 123 | } 124 | 125 | criterion_group!( 126 | name = bench; 127 | config = Criterion::default().with_measurement(&*TIMER); 128 | targets = benchmark 129 | ); 130 | criterion_main!(bench); 131 | -------------------------------------------------------------------------------- /src/dtype.rs: -------------------------------------------------------------------------------- 1 | use std::{cmp::max, num::NonZeroU64}; 2 | 3 | use half::{bf16, f16}; 4 | use wgpu::{BufferAddress, BufferSize}; 5 | 6 | use crate::{MIN_STORAGE_BUFFER_SIZE, STORAGE_BUFFER_ALIGN}; 7 | 8 | #[derive(Debug, Copy, Clone, PartialEq, Eq, Default, Hash)] 9 | pub enum DType { 10 | Q8, 11 | F16, 12 | BF16, 13 | #[default] 14 | F32, 15 | I32, 16 | U32, 17 | WQ8, //Packed Q8 (|--4xQ8(u32)--| |--f32--|) 18 | } 19 | 20 | impl DType { 21 | pub fn to_u32(self) -> u32 { 22 | match self { 23 | DType::F32 => 0, 24 | DType::F16 => 1, 25 | DType::WQ8 => 64, 26 | _ => unimplemented!(), 27 | } 28 | } 29 | 30 | /// Returns the size of the type in bytes. 
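///
/// For the packed `WQ8` type this is the size of one packed `u32` word (holding
/// four signed 8-bit weights), not of a single logical element. A rough sketch
/// of the layout that `segments` below assumes, for a hypothetical tensor of
/// `numel = 1024` elements:
///
/// ```text
/// weights: 1024 / 4  = 256 x u32 = 1024 bytes
/// absmax : 1024 / 16 =  64 x f32 =  256 bytes
/// ```
///
/// Each region is padded up to the 256-byte storage-buffer alignment; both are
/// already aligned in this example, so the buffer is 1280 bytes in total.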
31 | pub fn size_of(self) -> usize { 32 | match self { 33 | DType::Q8 => 1, 34 | DType::F16 => 2, 35 | DType::BF16 => 2, 36 | DType::F32 => 4, 37 | DType::I32 => 4, 38 | DType::U32 => 4, 39 | DType::WQ8 => 4, 40 | } 41 | } 42 | 43 | pub fn segments(&self, numel: usize, buffer_bytes: usize) -> Vec { 44 | match self { 45 | DType::WQ8 => { 46 | let aligner = |numel: usize, size_t: usize| -> usize { 47 | let nbytes = numel * size_t; 48 | 49 | if nbytes % STORAGE_BUFFER_ALIGN != 0 { 50 | nbytes + STORAGE_BUFFER_ALIGN - nbytes % STORAGE_BUFFER_ALIGN 51 | } else { 52 | nbytes 53 | } 54 | }; 55 | let weight_size = aligner(numel / 4, std::mem::size_of::()); 56 | let absmax_size = aligner(numel / 16, std::mem::size_of::()); 57 | assert_eq!(weight_size + absmax_size, buffer_bytes); 58 | 59 | let weights = BufferSegment::new(0, Some(weight_size as u64), true); 60 | let absmax = BufferSegment::new(weight_size as u64, Some(absmax_size as u64), true); 61 | vec![weights, absmax] 62 | } 63 | _ => { 64 | let mut total_bytes = numel * self.size_of(); 65 | total_bytes = max(total_bytes, MIN_STORAGE_BUFFER_SIZE); 66 | 67 | vec![BufferSegment::new(0, Some(total_bytes as u64), false)] 68 | } 69 | } 70 | } 71 | } 72 | 73 | impl DType { 74 | fn handle_type_str(ts: npyz::TypeStr) -> DType { 75 | match ts.endianness() { 76 | npyz::Endianness::Little => match (ts.type_char(), ts.size_field()) { 77 | (npyz::TypeChar::Float, 4) => DType::F32, 78 | (npyz::TypeChar::Int, 4) => DType::I32, 79 | (npyz::TypeChar::Uint, 4) => DType::U32, 80 | (t, s) => unimplemented!("{} {}", t, s), 81 | }, 82 | _ => unimplemented!(), 83 | } 84 | } 85 | } 86 | 87 | impl From for DType { 88 | fn from(dtype: npyz::DType) -> Self { 89 | match dtype { 90 | npyz::DType::Plain(ts) => Self::handle_type_str(ts), 91 | _ => unimplemented!(), 92 | } 93 | } 94 | } 95 | 96 | #[derive(Debug)] 97 | pub struct BufferSegment { 98 | pub offset: BufferAddress, 99 | pub size: Option, 100 | } 101 | 102 | impl BufferSegment { 103 | pub fn new(offset: BufferAddress, size: Option, aligned: bool) -> Self { 104 | if let Some(size) = size { 105 | if aligned { 106 | assert!(size % 256 == 0); //storage buffer alignment 107 | } 108 | } 109 | let size = size.map(NonZeroU64::new).unwrap(); 110 | Self { offset, size } 111 | } 112 | } 113 | 114 | pub trait DataType: 115 | Clone + std::fmt::Debug + PartialEq + 'static + num_traits::Zero + Send + Sync + bytemuck::Pod 116 | { 117 | fn dt() -> DType; 118 | 119 | fn one() -> Self; 120 | } 121 | 122 | macro_rules! map_type { 123 | ($t:ty, $v:ident) => { 124 | impl DataType for $t { 125 | fn dt() -> DType { 126 | DType::$v 127 | } 128 | 129 | fn one() -> Self { 130 | 1 as Self 131 | } 132 | } 133 | }; 134 | } 135 | 136 | macro_rules! 
map_half_type { 137 | ($t:ty, $v:ident) => { 138 | impl DataType for $t { 139 | fn dt() -> DType { 140 | DType::$v 141 | } 142 | 143 | fn one() -> Self { 144 | Self::ONE 145 | } 146 | } 147 | }; 148 | } 149 | 150 | map_type!(f32, F32); 151 | map_type!(i32, I32); 152 | map_type!(u32, U32); 153 | map_half_type!(f16, F16); 154 | map_half_type!(bf16, BF16); 155 | -------------------------------------------------------------------------------- /kernels/sgemv/mlx_sgemv.wgsl: -------------------------------------------------------------------------------- 1 | const SIMD_SIZE = 32i; 2 | 3 | // - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up 4 | // into blocks of (BM * TM, BN * TN) divided among threadgroups 5 | // - Every thread works on a block of (TM, TN) 6 | // - We assume each thead group is launched with (BN, BM, 1) threads 7 | // 8 | // 1. A thread loads TN elements each from mat along TM contiguous rows 9 | // and the corresponding scalar from the vector 10 | // 2. The thread then multiplies and adds to accumulate its local result for the block 11 | // 3. At the end, each thread has accumulated results over all blocks across the rows 12 | // These are then summed up across the threadgroup 13 | // 4. Each threadgroup writes its accumulated BN * TN outputs 14 | // 15 | // Edge case handling: 16 | // - The threadgroup with the largest group_id will have blocks that exceed the matrix 17 | // * The blocks that start outside the matrix are never read (thread results remain zero) 18 | // * The last thread that partially overlaps with the matrix is shifted inwards 19 | // such that the thread block fits exactly in the matrix 20 | 21 | struct Meta { 22 | M: i32, //out_vec_size 23 | N: i32, //in_vec_size 24 | } 25 | 26 | @group(0) @binding(0) var mat: array; 27 | @group(0) @binding(1) var inVec: array; 28 | @group(0) @binding(2) var outVec: array; 29 | @group(1) @binding(0) var metadata: Meta; 30 | 31 | var tgpMemory: array; 32 | 33 | @compute @workgroup_size(SIMD_SIZE, {{BM}}, 1) 34 | fn main( 35 | @builtin(global_invocation_id) global_id: vec3, 36 | @builtin(local_invocation_id) local_id: vec3, 37 | @builtin(local_invocation_index) local_index: u32, 38 | @builtin(workgroup_id) group_id: vec3, 39 | @builtin(num_workgroups) num_groups: vec3, 40 | @builtin(subgroup_size) subgroup_size: u32, 41 | @builtin(subgroup_invocation_id) simd_lid: u32 42 | ) { 43 | let simd_gid = local_index / subgroup_size; 44 | 45 | // Threadgroup in_vec cache 46 | let inVecBlockOffset = i32(simd_lid * {{TN}} * 2); 47 | 48 | // Thread local accumulation results 49 | var result: array; 50 | var inter: array; 51 | var vCoeff: array; 52 | 53 | // Block position 54 | var outRow = i32((group_id.x * {{BM}} + simd_gid) * {{TM}}); 55 | 56 | // Exit simdgroup if rows out of bound 57 | if (outRow >= metadata.M) { 58 | return; 59 | } 60 | 61 | // Adjust tail simdgroup to ensure in bound reads 62 | outRow = select(metadata.M - {{TM}}, outRow, outRow + {{TM}} <= metadata.M); 63 | 64 | // Advance matrix 65 | let matOffset = outRow * metadata.N; 66 | 67 | // Loop over in_vec in blocks of SIMD_SIZE * {{TN}} 68 | for (var bn = i32(simd_lid * {{TN}}); bn < i32(metadata.N); bn += {{BN * TN}}) { 69 | workgroupBarrier(); 70 | 71 | // Prefetch in_vector for threadgroup use 72 | if (simd_gid == 0u) { 73 | // Main load loop 74 | if (bn + {{TN}} <= i32(metadata.N)) { 75 | {% for tn in range(end=TN) %} 76 | tgpMemory[inVecBlockOffset + {{tn}}] = inVec[bn + {{tn}}]; 77 | {% endfor %} 78 | } else { // Edgecase 79 | {% for tn in 
range(end=TN) %} 80 | tgpMemory[inVecBlockOffset + {{tn}}] = select(inVec[bn + {{tn}}], 0.0, bn + {{tn}} < metadata.N); 81 | {% endfor %} 82 | } 83 | } 84 | 85 | workgroupBarrier(); 86 | 87 | // Load for all rows 88 | {% for tn in range(end=TN) %} 89 | vCoeff[{{tn}}] = tgpMemory[inVecBlockOffset + {{tn}}]; 90 | {% endfor %} 91 | 92 | // Per thread work loop 93 | for (var tm = 0; tm < {{TM}}; tm++) { 94 | // Load for the row 95 | if (bn + {{TN}} <= metadata.N) { 96 | {% for tn in range(end=TN) %} 97 | inter[{{tn}}] = mat[matOffset + tm * metadata.N + bn + {{tn}}]; 98 | {% endfor %} 99 | } else { // Edgecase 100 | {% for tn in range(end=TN) %} 101 | inter[{{tn}}] = mat[matOffset + tm * metadata.N + select(metadata.N - 1, bn + {{tn}}, bn + {{tn}} < metadata.N)]; 102 | {% endfor %} 103 | } 104 | 105 | // Accumulate results 106 | {% for tn in range(end=TN) %} 107 | result[tm] = fma(inter[{{tn}}], vCoeff[{{tn}}], result[tm]); 108 | {% endfor %} 109 | } 110 | } 111 | 112 | for (var tm = 0; tm < {{TM}}; tm++) { 113 | result[tm] = subgroupAdd(result[tm]); 114 | } 115 | 116 | // Write outputs 117 | if (simd_lid == 0u) { 118 | {% for tm in range(end=TM) %} 119 | outVec[outRow + {{tm}}] = result[{{tm}}]; 120 | {% endfor %} 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /src/storage.rs: -------------------------------------------------------------------------------- 1 | use std::{alloc::Layout, ops::RangeBounds, sync::Arc}; 2 | use wgpu::{util::DeviceExt, Buffer, BufferAddress, BufferSlice, BufferUsages}; 3 | 4 | use crate::GPUHandle; 5 | 6 | // Caution: no pooling of buffers is done for benchmarking 7 | // long running benchmarks could OOM 8 | pub trait Storage: std::fmt::Debug + Clone + 'static { 9 | fn to_gpu(self, handle: &GPUHandle) -> GPUStorage; 10 | fn to_cpu(self) -> CPUStorage; 11 | fn n_bytes(&self) -> usize; 12 | } 13 | 14 | #[derive(derive_new::new, Debug, PartialEq, Eq)] 15 | pub struct CPUStorage(*mut u8, Layout); 16 | 17 | impl CPUStorage { 18 | pub fn inner(&self) -> (*mut u8, Layout) { 19 | (self.0, self.1) 20 | } 21 | 22 | pub fn as_bytes_mut(&mut self) -> &mut [u8] { 23 | unsafe { std::slice::from_raw_parts_mut(self.0, self.1.size()) } 24 | } 25 | 26 | pub fn as_bytes(&self) -> &[u8] { 27 | unsafe { std::slice::from_raw_parts(self.0, self.1.size()) } 28 | } 29 | } 30 | 31 | impl Clone for CPUStorage { 32 | fn clone(&self) -> Self { 33 | let (ptr, layout) = self.inner(); 34 | let alloc = unsafe { std::alloc::alloc(layout) }; 35 | unsafe { ptr.copy_to_nonoverlapping(alloc, layout.size()) }; 36 | 37 | Self(alloc, layout) 38 | } 39 | } 40 | 41 | impl Drop for CPUStorage { 42 | fn drop(&mut self) { 43 | if !self.0.is_null() && self.1.size() > 0 { 44 | unsafe { std::alloc::dealloc(self.0, self.1) } 45 | } 46 | } 47 | } 48 | 49 | impl Storage for CPUStorage { 50 | //No allocations are pooled here because we don't care 51 | fn to_gpu(self, handle: &GPUHandle) -> GPUStorage { 52 | let mut min_bytes = [0; 16]; 53 | let bytes = if self.as_bytes().len() < 16 { 54 | min_bytes[..self.as_bytes().len()].copy_from_slice(self.as_bytes()); 55 | &min_bytes //&[u8] 56 | } else { 57 | self.as_bytes() //&[u8] 58 | }; 59 | 60 | let buffer = handle 61 | .device() 62 | .create_buffer_init(&wgpu::util::BufferInitDescriptor { 63 | label: None, 64 | contents: bytes, 65 | usage: BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC, 66 | }); 67 | //These should be batched up 68 | handle.queue().submit(None); 69 | 
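// The empty submit above plus the blocking poll below force the device to
// finish all outstanding work before the storage is handed out, so the upload
// has completed by the time a kernel is dispatched against it (presumably to
// keep the later kernel timings free of transfer cost).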
handle.device().poll(wgpu::Maintain::Wait); 70 | GPUStorage(buffer.into()) 71 | } 72 | 73 | fn to_cpu(self) -> CPUStorage { 74 | self 75 | } 76 | 77 | fn n_bytes(&self) -> usize { 78 | self.1.size() 79 | } 80 | } 81 | 82 | #[derive(Debug, Clone)] 83 | pub struct GPUBuffer(Arc); 84 | 85 | impl std::ops::Deref for GPUBuffer { 86 | type Target = wgpu::Buffer; 87 | 88 | fn deref(&self) -> &Self::Target { 89 | &self.0 90 | } 91 | } 92 | 93 | impl From for GPUBuffer { 94 | fn from(b: wgpu::Buffer) -> Self { 95 | Self(Arc::new(b)) 96 | } 97 | } 98 | 99 | #[derive(Clone)] 100 | pub struct GPUStorage(GPUBuffer); 101 | 102 | impl From for GPUStorage { 103 | fn from(b: GPUBuffer) -> Self { 104 | Self(b) 105 | } 106 | } 107 | 108 | impl std::fmt::Debug for GPUStorage { 109 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 110 | f.debug_struct("GPUStorage") 111 | .field("buffer", &self.0.global_id()) 112 | .field("size", &self.0.size()) 113 | .field("usage", &self.0.usage()) 114 | .finish() 115 | } 116 | } 117 | 118 | impl PartialEq for GPUStorage { 119 | fn eq(&self, other: &Self) -> bool { 120 | self.0.global_id() == other.0.global_id() 121 | } 122 | } 123 | 124 | impl GPUStorage { 125 | pub fn new(buffer: GPUBuffer) -> Self { 126 | Self(buffer) 127 | } 128 | 129 | pub fn inner(&self) -> &GPUBuffer { 130 | &self.0 131 | } 132 | 133 | pub fn set_inner(&mut self, b: GPUBuffer) { 134 | self.0 = b; 135 | } 136 | 137 | pub fn as_entire_binding(&self) -> wgpu::BindingResource { 138 | self.0.as_entire_binding() 139 | } 140 | 141 | pub fn usage(&self) -> wgpu::BufferUsages { 142 | self.0.usage() 143 | } 144 | 145 | pub fn slice>(&self, bounds: S) -> BufferSlice { 146 | self.0.slice(bounds) 147 | } 148 | 149 | pub fn unmap(&self) { 150 | self.0.unmap(); 151 | } 152 | 153 | pub fn buffer_id(&self) -> wgpu::Id { 154 | self.0.global_id() 155 | } 156 | 157 | pub fn size(&self) -> BufferAddress { 158 | self.0.size() 159 | } 160 | } 161 | 162 | impl Storage for GPUStorage { 163 | fn to_gpu(self, _h: &GPUHandle) -> GPUStorage { 164 | self 165 | } 166 | 167 | fn to_cpu(self) -> CPUStorage { 168 | todo!() 169 | } 170 | 171 | fn n_bytes(&self) -> usize { 172 | self.0.size() as usize 173 | } 174 | } 175 | -------------------------------------------------------------------------------- /benches/qgemv/gemv.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_snake_case)] 2 | use encase::ShaderType; 3 | use inline_python::{python, Context}; 4 | use numpy::PyArrayDyn; 5 | use pyo3::{IntoPy, Python}; 6 | use smallvec::smallvec; 7 | 8 | use criterion::{criterion_group, criterion_main, Criterion, Throughput}; 9 | use wgpu_bencher::{ 10 | dispatch_validate, shape, wgc, wgs, CPUTensor, GPUHandle, KernelBench, KernelContextExt, 11 | OpMetadata, Quantization, Quantizer, WgpuTimer, Workload, 12 | }; 13 | 14 | lazy_static::lazy_static! 
{ 15 | pub static ref TIMER: WgpuTimer = WgpuTimer::new(pollster::block_on(async { 16 | GPUHandle::new().await.unwrap() 17 | })); 18 | } 19 | 20 | #[derive(ShaderType, Debug)] 21 | pub struct QGEMVMeta { 22 | aShape: glam::IVec3, 23 | aStrides: glam::IVec3, 24 | bShape: glam::IVec3, 25 | bStrides: glam::IVec3, 26 | outShape: glam::IVec3, 27 | outStrides: glam::IVec3, 28 | dimAOuter: i32, 29 | dimBOuter: i32, 30 | dimInner: i32, 31 | } 32 | 33 | impl OpMetadata for QGEMVMeta {} 34 | 35 | #[derive(derive_new::new, Debug)] 36 | pub struct QGEMVBenchmark { 37 | B: usize, 38 | M: usize, 39 | N: usize, 40 | K: usize, 41 | TILE_DIM: usize, 42 | ROW_PER_THREAD: usize, 43 | trans_a: bool, 44 | trans_b: bool, 45 | } 46 | 47 | impl KernelBench for QGEMVBenchmark { 48 | type Metadata = QGEMVMeta; 49 | 50 | fn name() -> &'static str { 51 | "QGEMVBenchmark" 52 | } 53 | 54 | fn source(&self, workload: &Workload) -> String { 55 | let mut tera = tera::Tera::default(); 56 | let mut context = tera::Context::new(); 57 | 58 | let template = include_str!("../../kernels/qgemv/sgemv_2v.wgsl"); 59 | tera.add_raw_template(Self::name(), template).unwrap(); 60 | 61 | context.insert("TILE_DIM", &self.TILE_DIM); 62 | context.insert("ROW_PER_THREAD", &self.ROW_PER_THREAD); 63 | context.insert_workload(workload); 64 | let kernel = tera.render(Self::name(), &context).unwrap(); 65 | println!("{}", kernel); 66 | kernel 67 | } 68 | 69 | fn tensors(&self) -> Vec { 70 | let (B, M, N, K) = (self.B, self.M, self.N, self.K); 71 | println!("B: {}, M: {}, N: {}, K: {}", B, M, N, K); 72 | let a_unquant = CPUTensor::randn::(shape![B, M, K]); 73 | let b = CPUTensor::randn::(shape![B, K, N]); 74 | let quantizer = Quantizer::new(Quantization::SInt8); 75 | let quantized_a = quantizer.quantize(a_unquant.clone()); 76 | let output = CPUTensor::zeros::(shape![B, M, N]); 77 | vec![quantized_a, b, output] 78 | } 79 | 80 | fn workload(&self, _: &[CPUTensor]) -> Workload { 81 | let wgsx: usize = 8; 82 | let workgroup_size = wgs![wgsx as _, 8, 1]; 83 | let workgroup_count = wgc![(self.M / wgsx) as _, 1, self.B as _]; 84 | let dispatch = Workload::new(workgroup_size, workgroup_count); 85 | println!("DISPATCH: {:?}", dispatch); 86 | dispatch 87 | } 88 | 89 | fn metadata(&self, _: &[CPUTensor]) -> Self::Metadata { 90 | let (B, M, N, K) = (self.B as i32, self.M as i32, self.N as i32, self.K as i32); 91 | 92 | let aShape = glam::IVec3::new(B, M, K); 93 | let aStrides = glam::IVec3::new(M * K, K, 1); 94 | let bShape = glam::IVec3::new(B, K, N); 95 | let bStrides = glam::IVec3::new(K * N, N, 1); 96 | let outShape = glam::IVec3::new(B, M, N); 97 | let outStrides = glam::IVec3::new(M * N, N, 1); 98 | 99 | let dimAOuter = if self.trans_a { K } else { M }; 100 | let dimBOuter = if self.trans_b { K } else { N }; 101 | let dimInner = if self.trans_a { M } else { K }; 102 | 103 | let meta = QGEMVMeta { 104 | aShape, 105 | aStrides, 106 | bShape, 107 | bStrides, 108 | outShape, 109 | outStrides, 110 | dimAOuter, 111 | dimBOuter, 112 | dimInner, 113 | }; 114 | println!("META: {:?}", meta); 115 | meta 116 | } 117 | 118 | fn validate(&self, tensors: &[CPUTensor]) { 119 | let (aquant, b) = (&tensors[0], &tensors[1]); 120 | let dequantized = Quantizer::new(Quantization::SInt8).dequantize(aquant.clone()); 121 | let ground = Python::with_gil(|py| { 122 | let (py_a, py_b) = (dequantized.to_py::(&py), b.to_py::(&py)); 123 | let result: Context = python! 
{ 124 | import torch 125 | (a, b) = (torch.from_numpy('py_a), torch.from_numpy('py_b)) 126 | print("A: ", a) 127 | print("B: ", b) 128 | result = (a @ b).numpy() 129 | }; 130 | CPUTensor::from(result.get_with_gil::<&PyArrayDyn>(py, "result")) 131 | }); 132 | let mut gpu_tensors = dispatch_validate(TIMER.handle(), self, tensors); 133 | let cpu_result = gpu_tensors.remove(2).into_cpu(TIMER.handle()).unwrap(); 134 | println!("OURS: {}", cpu_result); 135 | println!("GROUND: {}", ground); 136 | ground.all_close(&cpu_result, 1e-2, 1e-2).unwrap(); 137 | } 138 | } 139 | 140 | pub fn benchmark(c: &mut Criterion<&WgpuTimer>) { 141 | let B = 1; 142 | let M = 16384; 143 | let N = 1; 144 | let K = 3072; 145 | let TILE_DIM = 32; 146 | let ROW_PER_THREAD = 4; 147 | 148 | let trans_a = false; 149 | let trans_b = false; 150 | 151 | let bench = QGEMVBenchmark::new(B, M, N, K, TILE_DIM, ROW_PER_THREAD, trans_a, trans_b); 152 | let throughput = Throughput::Elements(2 * (B * M * N * K) as u64); 153 | wgpu_bencher::benchmark(c, &TIMER, bench, throughput) 154 | } 155 | 156 | criterion_group!( 157 | name = bench; 158 | config = Criterion::default().with_measurement(&*TIMER); 159 | targets = benchmark 160 | ); 161 | criterion_main!(bench); 162 | -------------------------------------------------------------------------------- /benches/qgemm/tfjs.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_snake_case)] 2 | use encase::ShaderType; 3 | use inline_python::{python, Context}; 4 | use numpy::PyArrayDyn; 5 | use pyo3::Python; 6 | use smallvec::smallvec; 7 | 8 | use criterion::{criterion_group, criterion_main, Criterion, Throughput}; 9 | use wgpu_bencher::{ 10 | dispatch_validate, shape, wgc, wgs, CPUTensor, GPUHandle, KernelBench, KernelContextExt, 11 | OpMetadata, Quantization, Quantizer, WgpuTimer, Workload, 12 | }; 13 | 14 | lazy_static::lazy_static! 
{ 15 | pub static ref TIMER: WgpuTimer = WgpuTimer::new(pollster::block_on(async { 16 | GPUHandle::new().await.unwrap() 17 | })); 18 | } 19 | 20 | #[derive(ShaderType, derive_new::new, Debug)] 21 | pub struct QGEMMMeta { 22 | aShape: glam::IVec3, 23 | aStrides: glam::IVec3, 24 | bShape: glam::IVec3, 25 | bStrides: glam::IVec3, 26 | outShape: glam::IVec3, 27 | outStrides: glam::IVec3, 28 | dimInner: i32, 29 | } 30 | 31 | impl OpMetadata for QGEMMMeta {} 32 | 33 | #[derive(derive_new::new, Debug)] 34 | pub struct QGEMMBenchmark { 35 | B: usize, 36 | M: usize, 37 | N: usize, 38 | K: usize, 39 | TILE_DIM: usize, 40 | ROW_PER_THREAD: usize, 41 | } 42 | 43 | impl QGEMMBenchmark { 44 | fn shape_fit(&self) -> [bool; 3] { 45 | let aOuter = self.M; 46 | let bOuter = self.N; 47 | let dimInner = self.K; 48 | 49 | let mut shape_fit = [false; 3]; 50 | shape_fit[0] = aOuter % self.TILE_DIM == 0; 51 | shape_fit[1] = bOuter % self.TILE_DIM == 0; 52 | shape_fit[2] = dimInner % self.TILE_DIM == 0; 53 | shape_fit 54 | } 55 | } 56 | 57 | impl KernelBench for QGEMMBenchmark { 58 | type Metadata = QGEMMMeta; 59 | 60 | fn name() -> &'static str { 61 | "QGEMMBenchmark" 62 | } 63 | 64 | fn source(&self, workload: &Workload) -> String { 65 | let mut tera = tera::Tera::default(); 66 | let mut context = tera::Context::new(); 67 | tera.add_raw_template(Self::name(), include_str!("../../kernels/qgemm/tfjs.wgsl")) 68 | .unwrap(); 69 | let shape_fit = self.shape_fit(); 70 | context.insert("A_FIT", &shape_fit[0]); 71 | context.insert("B_FIT", &shape_fit[1]); 72 | context.insert("INNER_FIT", &shape_fit[2]); 73 | 74 | context.insert("TILE_DIM", &self.TILE_DIM); 75 | context.insert("ROW_PER_THREAD", &self.ROW_PER_THREAD); 76 | context.insert_workload(workload); 77 | tera.render(Self::name(), &context).unwrap() 78 | } 79 | 80 | fn tensors(&self) -> Vec { 81 | let (B, M, N, K) = (self.B, self.M, self.N, self.K); 82 | println!("B: {}, M: {}, N: {}, K: {}", B, M, N, K); 83 | let a_unquant = CPUTensor::randn::(shape![B, M, K]); 84 | let b = CPUTensor::randn::(shape![B, K, N]); 85 | let quantizer = Quantizer::new(Quantization::SInt8); 86 | let quantized_a = quantizer.quantize(a_unquant.clone()); 87 | let output = CPUTensor::zeros::(shape![B, M, N]); 88 | vec![quantized_a, b, output] 89 | } 90 | 91 | fn workload(&self, _: &[CPUTensor]) -> Workload { 92 | let (TILE_DIM, ROW_PER_THREAD) = (self.TILE_DIM, self.ROW_PER_THREAD); 93 | let workgroup_size = wgs![(TILE_DIM / 4) as _, (TILE_DIM / ROW_PER_THREAD) as _, 1]; 94 | let group_x = Workload::ceil(self.N, TILE_DIM); 95 | let group_y = Workload::ceil(self.M, TILE_DIM); 96 | let workgroup_count = wgc![group_x as _, group_y as _, self.B as u32]; 97 | Workload::new(workgroup_size, workgroup_count) 98 | } 99 | 100 | fn metadata(&self, _: &[CPUTensor]) -> Self::Metadata { 101 | let (B, M, N, K) = (self.B as i32, self.M as i32, self.N as i32, self.K as i32); 102 | 103 | let aShape = glam::IVec3::new(B, M, K); 104 | let aStrides = glam::IVec3::new(M * K, K, 1); 105 | let bShape = glam::IVec3::new(B, K, N); 106 | let bStrides = glam::IVec3::new(K * N, N, 1); 107 | let outShape = glam::IVec3::new(B, M, N); 108 | let outStrides = glam::IVec3::new(M * N, N, 1); 109 | 110 | let meta = QGEMMMeta::new(aShape, aStrides, bShape, bStrides, outShape, outStrides, K); 111 | println!("META: {:?}", meta); 112 | meta 113 | } 114 | 115 | fn validate(&self, tensors: &[CPUTensor]) { 116 | let (aquant, b) = (&tensors[0], &tensors[1]); 117 | let dequantized = 
Quantizer::new(Quantization::SInt8).dequantize(aquant.clone()); 118 | let ground = Python::with_gil(|py| { 119 | let (py_a, py_b) = (dequantized.to_py::(&py), b.to_py::(&py)); 120 | let result: Context = python! { 121 | import torch 122 | (a, b) = (torch.from_numpy('py_a), torch.from_numpy('py_b)) 123 | print("A: ", a) 124 | print("B: ", b) 125 | result = (a @ b).numpy() 126 | }; 127 | CPUTensor::from(result.get_with_gil::<&PyArrayDyn>(py, "result")) 128 | }); 129 | let mut gpu_tensors = dispatch_validate(TIMER.handle(), self, tensors); 130 | let cpu_result = gpu_tensors.remove(2).into_cpu(TIMER.handle()).unwrap(); 131 | println!("OURS: {}", cpu_result); 132 | println!("GROUND: {}", ground); 133 | ground.all_close(&cpu_result, 1e-2, 1e-2).unwrap(); 134 | } 135 | } 136 | 137 | pub fn benchmark(c: &mut Criterion<&WgpuTimer>) { 138 | let B = 1; 139 | let M = 2048; 140 | let N = 2048; 141 | let K = 2048; 142 | let TILE_DIM = 32; 143 | let ROW_PER_THREAD = 8; 144 | let bench = QGEMMBenchmark::new(B, M, N, K, TILE_DIM, ROW_PER_THREAD); 145 | let throughput = Throughput::Elements(2 * (B * M * N * K) as u64); 146 | wgpu_bencher::benchmark(c, &TIMER, bench, throughput) 147 | } 148 | 149 | criterion_group!( 150 | name = bench; 151 | config = Criterion::default().with_measurement(&*TIMER); 152 | targets = benchmark 153 | ); 154 | criterion_main!(bench); 155 | -------------------------------------------------------------------------------- /kernels/sgemm/gemm_vectorized.wgsl: -------------------------------------------------------------------------------- 1 | fn getAIndexFromCoords3D(coords : vec3) -> i32 { 2 | return dot(coords, metadata.aStrides); 3 | } 4 | 5 | fn getBIndexFromCoords3D(coords : vec3) -> i32 { 6 | return dot(coords, metadata.bStrides); 7 | } 8 | 9 | fn getOutputIndexFromCoords(coords : vec3) -> i32 { 10 | return dot(coords, metadata.outStrides); 11 | } 12 | 13 | fn setOutputAtIndex(flatIndex : i32, value : vec4) { 14 | result[flatIndex] = vec4(value); 15 | } 16 | 17 | fn setOutputAtCoords(d0 : i32, d1 : i32, d2 : i32, value : vec4) { 18 | let flatIndex = getOutputIndexFromCoords(vec3(d0, d1, d2)); 19 | setOutputAtIndex(flatIndex / 4, value); 20 | } 21 | 22 | fn getA(d0 : i32, d1 : i32, d2 : i32) -> vec4 { 23 | return vec4(A[getAIndexFromCoords3D(vec3(d0,d1,d2)) / 4]); 24 | } 25 | 26 | fn getB(d0 : i32, d1 : i32, d2 : i32) -> vec4 { 27 | return vec4(B[getBIndexFromCoords3D(vec3(d0,d1,d2)) / 4]); 28 | } 29 | 30 | {% if FIT_A_OUTER and FIT_INNER %} 31 | fn mm_readA(batch: i32, row: i32, col: i32) -> vec4 { 32 | var value = vec4(0.0); 33 | value = getA(batch, row, col); 34 | return value; 35 | } 36 | {% else %} 37 | fn mm_readA(batch: i32, row: i32, col: i32) -> vec4 { 38 | var value = vec4(0.0); 39 | if (row < metadata.aShape.y && col < metadata.aShape.z) { 40 | value = getA(batch, row, col); 41 | } 42 | return value; 43 | } 44 | {% endif %} 45 | 46 | fn mm_readB(batch: i32, row: i32, col: i32) -> vec4 { 47 | var value = vec4(0.0); 48 | value = getB(batch, row, col); 49 | return value; 50 | } 51 | 52 | fn mm_write(batch: i32, row: i32, col: i32, valueIn: vec4) { 53 | {% if FIT_A_OUTER and FIT_B_OUTER %} 54 | var value = valueIn; 55 | let coords = vec3(batch, row, col); 56 | setOutputAtCoords(coords[0], coords[1], coords[2], value); 57 | {% else %} 58 | if (row < metadata.dimAOuter && col < metadata.dimBOuter) { 59 | var value = valueIn; 60 | let coords = vec3(batch, row, col); 61 | setOutputAtCoords(coords[0], coords[1], coords[2], valueIn); 62 | } 63 | {% endif %} 64 | } 65 | 66 | 67 | 
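// Bindings and tiling for the vectorized GEMM: group 0 holds A, B, bias and result as
// vec4<f32> storage buffers (every access moves four floats at once), group 1 holds the
// uniform Meta struct with shapes, strides and the GEMM dimensions. Each workgroup stages
// one TILE_DIM x TILE_DIM tile of A and of B in workgroup memory (mm_Asub / mm_Bsub) per
// step of the k loop, and each thread produces ROW_PER_THREAD rows x 4 columns of output.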
var<private> localId: vec3<u32>; 68 | var<private> globalId: vec3<u32>; 69 | var<private> workgroupId: vec3<u32>; 70 | 71 | @group(0) @binding(0) var<storage, read> A: array<vec4<f32>>; 72 | 73 | @group(0) @binding(1) var<storage, read> B: array<vec4<f32>>; 74 | 75 | @group(0) @binding(2) var<storage, read> bias: array<vec4<f32>>; 76 | 77 | @group(0) @binding(3) var<storage, read_write> result: array<vec4<f32>>; 78 | 79 | struct Meta { 80 | aShape: vec3<i32>, 81 | aStrides: vec3<i32>, 82 | bShape: vec3<i32>, 83 | bStrides: vec3<i32>, 84 | outShape: vec3<i32>, 85 | outStrides: vec3<i32>, 86 | dimAOuter: i32, 87 | dimBOuter: i32, 88 | dimInner: i32, 89 | } 90 | 91 | @group(1) @binding(0) 92 | var<uniform> metadata: Meta; 93 | 94 | var<workgroup> mm_Asub : array<array<vec4<f32>, {{ TILE_DIM / 4 }}>, {{ TILE_DIM }}>; 95 | var<workgroup> mm_Bsub : array<array<vec4<f32>, {{ TILE_DIM / 4 }}>, {{ TILE_DIM }}>; 96 | 97 | @compute @workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }},1) 98 | fn main(@builtin(local_invocation_id) localId : vec3<u32>, 99 | @builtin(global_invocation_id) globalId : vec3<u32>, 100 | @builtin(workgroup_id) workgroupId : vec3<u32>) { 101 | let batch = i32(globalId.z); 102 | let batchA = batch % metadata.aShape[0]; 103 | let batchB = batch % metadata.bShape[0]; 104 | 105 | let localRow = i32(localId.y); 106 | let tileRow = localRow * {{ ROW_PER_THREAD }}; 107 | let tileCol = i32(localId.x); 108 | 109 | let globalRow = i32(globalId.y) * {{ ROW_PER_THREAD }}; 110 | let globalCol = i32(globalId.x) * 4; 111 | 112 | let numTiles = (metadata.dimInner - 1) / {{ TILE_DIM }} + 1; 113 | var kStart = 0; 114 | 115 | var acc: array<vec4<f32>, {{ ROW_PER_THREAD }}>; 116 | 117 | // Loop over shared dimension. 118 | let tileRowB = localRow * {{ ROW_PER_THREAD }}; 119 | for (var t = 0; t < numTiles; t++) { 120 | // Load one tile of A into local memory. 121 | for (var innerRow = 0; innerRow < {{ ROW_PER_THREAD }}; innerRow++) { 122 | let inputRow = tileRow + innerRow; 123 | let inputCol = tileCol; 124 | 125 | mm_Asub[inputRow][inputCol] = mm_readA(batchA, globalRow + innerRow, kStart + inputCol * 4); 126 | } 127 | 128 | // Load one tile of B into local memory. 129 | for (var innerRow = 0; innerRow < {{ ROW_PER_THREAD }}; innerRow++) { 130 | let inputRow = tileRowB + innerRow; 131 | let inputCol = tileCol; 132 | mm_Bsub[inputRow][inputCol] = mm_readB(batchB, kStart + inputRow, globalCol); 133 | } 134 | kStart = kStart + {{ TILE_DIM }}; 135 | workgroupBarrier(); 136 | 137 | // Compute acc values for a single thread. 
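    // Each k iteration pulls four rows of the B tile (BCached0..3) into registers and one
    // vec4 of the A tile per output row; the fma chain below therefore accumulates a full
    // ROW_PER_THREAD x 4 block of the result per thread without touching global memory.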
138 | for (var k = 0; k < {{ TILE_DIM / 4 }}; k++) { 139 | let bidx = k * 4; 140 | let BCached0 = mm_Bsub[bidx][tileCol]; 141 | let BCached1 = mm_Bsub[bidx + 1][tileCol]; 142 | let BCached2 = mm_Bsub[bidx + 2][tileCol]; 143 | let BCached3 = mm_Bsub[bidx + 3][tileCol]; 144 | for (var i = 0; i < {{ ROW_PER_THREAD }}; i++) { 145 | let ACached = mm_Asub[tileRow + i][k]; 146 | acc[i] = fma(BCached0, vec4(ACached[0]), acc[i]); 147 | acc[i] = fma(BCached1, vec4(ACached[1]), acc[i]); 148 | acc[i] = fma(BCached2, vec4(ACached[2]), acc[i]); 149 | acc[i] = fma(BCached3, vec4(ACached[3]), acc[i]); 150 | } 151 | } 152 | workgroupBarrier(); 153 | } 154 | 155 | {% for innerRow in range(end=ROW_PER_THREAD) %} 156 | mm_write(batch, globalRow + {{ innerRow }}, globalCol, acc[{{ innerRow }}] + bias[globalCol / 4]); 157 | {% endfor %} 158 | } 159 | -------------------------------------------------------------------------------- /kernels/qgemm/tfjs2.wgsl: -------------------------------------------------------------------------------- 1 | fn getAIndexFromCoords3D(coords : vec3) -> i32 { 2 | return dot(coords, metadata.aStrides); 3 | } 4 | 5 | fn getBIndexFromCoords3D(coords : vec3) -> i32 { 6 | return dot(coords, metadata.bStrides); 7 | } 8 | 9 | fn getOutputIndexFromCoords(coords : vec3) -> i32 { 10 | return dot(coords, metadata.outShapeStrides); 11 | } 12 | 13 | fn setOutputAtIndex(flatIndex : i32, value : vec4) { 14 | result[flatIndex] = vec4(value); 15 | } 16 | 17 | fn setOutputAtCoords(d0 : i32, d1 : i32, d2 : i32, value : vec4) { 18 | let flatIndex = getOutputIndexFromCoords(vec3(d0, d1, d2)); 19 | setOutputAtIndex(flatIndex / 4, value); 20 | } 21 | 22 | fn getA(d0 : i32, d1 : i32, d2 : i32) -> vec4 { 23 | return vec4(A[getAIndexFromCoords3D(vec3(d0,d1,d2)) / 4]); 24 | } 25 | 26 | fn getB(d0 : i32, d1 : i32, d2 : i32) -> vec4 { 27 | return unpack4x8snorm(B[getBIndexFromCoords3D(vec3(d0,d1,d2)) / 4]); 28 | } 29 | 30 | fn getAbsMax(d0 : i32, d1 : i32, d2 : i32) -> f32 { 31 | let abs_index = getBIndexFromCoords3D(vec3(d0,d1,d2)) / 16; 32 | return absmax[abs_index]; 33 | } 34 | 35 | {% if FIT_A_OUTER and FIT_INNER %} 36 | fn mm_readA(batch: i32, row: i32, col: i32) -> vec4 { 37 | var value = vec4(0.0); 38 | value = getA(batch, row, col); 39 | return value; 40 | } 41 | {% else %} 42 | fn mm_readA(batch: i32, row: i32, col: i32) -> vec4 { 43 | var value = vec4(0.0); 44 | if (row < metadata.aShape.y && col < metadata.aShape.z) { 45 | value = getA(batch, row, col); 46 | } 47 | return value; 48 | } 49 | {% endif %} 50 | 51 | fn mm_readB(batch: i32, row: i32, col: i32) -> vec4 { 52 | var value = vec4(0.0); 53 | value = getB(batch, row, col); 54 | return value; 55 | } 56 | 57 | fn mm_write(batch: i32, row: i32, col: i32, valueIn: vec4) { 58 | {% if FIT_A_OUTER and FIT_B_OUTER %} 59 | var value = valueIn; 60 | let coords = vec3(batch, row, col); 61 | setOutputAtCoords(coords[0], coords[1], coords[2], value); 62 | {% else %} 63 | if (row < metadata.dimAOuter && col < metadata.dimBOuter) { 64 | var value = valueIn; 65 | let coords = vec3(batch, row, col); 66 | setOutputAtCoords(coords[0], coords[1], coords[2], valueIn); 67 | } 68 | {% endif %} 69 | } 70 | 71 | 72 | var localId: vec3; 73 | var globalId: vec3; 74 | var workgroupId: vec3; 75 | 76 | @group(0) @binding(0) var A: array>; 77 | 78 | @group(0) @binding(1) var B: array; 79 | 80 | @group(0) @binding(2) var absmax: array; 81 | 82 | @group(0) @binding(3) var result: array>; 83 | 84 | struct Meta { 85 | aShape: vec3, 86 | aStrides: vec3, 87 | bShape: vec3, 88 | 
bStrides: vec3, 89 | outShape: vec3, 90 | outShapeStrides: vec3, 91 | dimInner: i32, 92 | } 93 | 94 | @group(1) @binding(0) 95 | var metadata: Meta; 96 | 97 | var mm_Asub : array, {{ TILE_DIM / 4 }}>, {{ TILE_DIM }}>; 98 | var mm_Bsub : array, {{ TILE_DIM / 4 }}>, {{ TILE_DIM }}>; 99 | 100 | @compute @workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1) 101 | fn main(@builtin(local_invocation_id) localId : vec3, 102 | @builtin(global_invocation_id) globalId : vec3, 103 | @builtin(workgroup_id) workgroupId : vec3) { 104 | let localRow = i32(localId.y); 105 | let tileRow = localRow * {{ ROW_PER_THREAD }}; 106 | let tileCol = i32(localId.x); 107 | 108 | let globalRow = i32(globalId.y) * {{ ROW_PER_THREAD }}; 109 | let globalCol = i32(globalId.x) * 4; 110 | let batch = i32(globalId.z); 111 | let batchA = batch % metadata.aShape.x; 112 | let batchB = batch % metadata.bShape.x; 113 | 114 | let numTiles = (metadata.dimInner - 1) / {{ TILE_DIM }} + 1; 115 | var kStart = 0; 116 | 117 | var acc: array, {{ ROW_PER_THREAD }}>; 118 | 119 | // Loop over shared dimension. 120 | let tileRowB = localRow * {{ ROW_PER_THREAD }}; 121 | for (var t = 0; t < numTiles; t++) { 122 | // Load one tile of A into local memory. 123 | for (var innerRow = 0; innerRow < {{ ROW_PER_THREAD }}; innerRow++) { 124 | let inputRow = tileRow + innerRow; 125 | let inputCol = tileCol; 126 | 127 | mm_Asub[inputRow][inputCol] = mm_readA(batchA, globalRow + innerRow, kStart + inputCol * 4); 128 | } 129 | 130 | // Load one tile of B into local memory. 131 | for (var innerRow = 0; innerRow < {{ ROW_PER_THREAD }}; innerRow++) { 132 | let inputRow = tileRowB + innerRow; 133 | let inputCol = tileCol; 134 | let absmax = getAbsMax(batchB, kStart + inputRow, globalCol); 135 | mm_Bsub[inputRow][inputCol] = mm_readB(batchB, kStart + inputRow, globalCol) * absmax; 136 | } 137 | kStart = kStart + {{ TILE_DIM }}; 138 | workgroupBarrier(); 139 | 140 | // Compute acc values for a single thread. 
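    // Same inner product as the f32 kernel: by this point mm_Bsub already holds
    // dequantized values, because the B tile load above unpacked each packed sint8 word
    // with unpack4x8snorm and rescaled it by its per-16-element absmax.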
141 | for (var k = 0; k < {{ TILE_DIM / 4 }}; k++) { 142 | let bidx = k * 4; 143 | let BCached0 = mm_Bsub[bidx][tileCol]; 144 | let BCached1 = mm_Bsub[bidx + 1][tileCol]; 145 | let BCached2 = mm_Bsub[bidx + 2][tileCol]; 146 | let BCached3 = mm_Bsub[bidx + 3][tileCol]; 147 | for (var i = 0; i < {{ ROW_PER_THREAD }}; i++) { 148 | let ACached = mm_Asub[tileRow + i][k]; 149 | acc[i] = fma(BCached0, vec4(ACached[0]), acc[i]); 150 | acc[i] = fma(BCached1, vec4(ACached[1]), acc[i]); 151 | acc[i] = fma(BCached2, vec4(ACached[2]), acc[i]); 152 | acc[i] = fma(BCached3, vec4(ACached[3]), acc[i]); 153 | } 154 | } 155 | workgroupBarrier(); 156 | } 157 | 158 | {% for innerRow in range(end=ROW_PER_THREAD) -%} 159 | mm_write(batch, globalRow + {{ innerRow }}, globalCol, acc[{{ innerRow }}]); 160 | {% endfor %} 161 | } 162 | -------------------------------------------------------------------------------- /src/bench.rs: -------------------------------------------------------------------------------- 1 | use std::borrow::Cow; 2 | 3 | use criterion::{BenchmarkId, Criterion, Throughput}; 4 | 5 | use crate::{CPUTensor, GPUBuffer, GPUHandle, GPUTensor, OpMetadata, WgpuTimer, Workload}; 6 | 7 | pub trait KernelContextExt { 8 | fn insert_workload(&mut self, workload: &Workload); 9 | } 10 | 11 | impl KernelContextExt for tera::Context { 12 | fn insert_workload(&mut self, workload: &Workload) { 13 | self.insert("workgroup_size_x", &workload.size().0); 14 | self.insert("workgroup_size_y", &workload.size().1); 15 | self.insert("workgroup_size_z", &workload.size().2); 16 | } 17 | } 18 | 19 | pub trait KernelBench: std::fmt::Debug { 20 | type Metadata: OpMetadata; 21 | fn name() -> &'static str; 22 | fn source(&self, workload: &Workload) -> String; 23 | fn tensors(&self) -> Vec; 24 | fn workload(&self, tensors: &[CPUTensor]) -> Workload; 25 | fn metadata(&self, tensors: &[CPUTensor]) -> Self::Metadata; 26 | fn validate(&self, tensors: &[CPUTensor]); 27 | } 28 | 29 | pub fn dispatch_validate( 30 | handle: &GPUHandle, 31 | kernel: &K, 32 | tensors: &[CPUTensor], 33 | ) -> Vec { 34 | let _ = env_logger::builder().is_test(true).try_init(); 35 | let workload = kernel.workload(&tensors); 36 | log::debug!("Workload: {:?}", workload); 37 | let source = kernel.source(&workload); 38 | log::debug!("Source: {}", source); 39 | let pipeline = source_to_pipeline(handle, &source); 40 | let uniform_buffer = kernel.metadata(&tensors).into_buffer(handle); 41 | let gpu_tensors = tensors 42 | .into_iter() 43 | .cloned() 44 | .map(|t| t.into_gpu(handle)) 45 | .collect::>(); 46 | let bind_groups = tensors_to_bind_groups(handle, &gpu_tensors, uniform_buffer, &pipeline); 47 | dispatch(handle, &workload, &bind_groups, &pipeline, None); 48 | gpu_tensors 49 | } 50 | 51 | #[inline(always)] 52 | pub fn dispatch( 53 | handle: &GPUHandle, 54 | workload: &Workload, 55 | bind_groups: &[wgpu::BindGroup], 56 | pipeline: &wgpu::ComputePipeline, 57 | timestamp_writes: Option, 58 | ) { 59 | let mut encoder = handle 60 | .device() 61 | .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); 62 | { 63 | let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { 64 | label: None, 65 | timestamp_writes, 66 | }); 67 | for (i, bind_group) in bind_groups.iter().enumerate() { 68 | cpass.set_bind_group(i as _, bind_group, &[]); 69 | } 70 | cpass.set_pipeline(pipeline); 71 | let (x, y, z) = workload.count().as_tuple(); 72 | for _ in 0..WgpuTimer::COMPUTE_PER_QUERY { 73 | cpass.dispatch_workgroups(x, y, z); 74 | } 75 | } 76 | 
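    // The compute pass (scoped above) records COMPUTE_PER_QUERY back-to-back dispatches,
    // so any timestamp writes bracket the whole batch rather than a single dispatch. The
    // pass must be dropped before the encoder can be finished; poll(Wait) then blocks
    // until the GPU has drained the submitted command buffer.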
handle.queue().submit(Some(encoder.finish())); 77 | handle.device().poll(wgpu::Maintain::Wait); 78 | } 79 | 80 | pub fn source_to_pipeline(handle: &GPUHandle, source: &str) -> wgpu::ComputePipeline { 81 | let shader_module = unsafe { 82 | handle 83 | .device() 84 | .create_shader_module_unchecked(wgpu::ShaderModuleDescriptor { 85 | label: None, 86 | source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), 87 | }) 88 | }; 89 | 90 | handle 91 | .device() 92 | .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { 93 | label: None, 94 | layout: None, 95 | module: &shader_module, 96 | entry_point: "main", 97 | compilation_options: wgpu::PipelineCompilationOptions { 98 | zero_initialize_workgroup_memory: false, 99 | ..Default::default() 100 | }, 101 | }) 102 | } 103 | 104 | pub fn tensors_to_bind_groups( 105 | handle: &GPUHandle, 106 | tensors: &[GPUTensor], 107 | uniform_buffer: GPUBuffer, 108 | pipeline: &wgpu::ComputePipeline, 109 | ) -> Vec { 110 | let mut bind_group_entries = vec![]; 111 | 112 | for tensor in tensors { 113 | bind_group_entries.append(&mut tensor.bindings(bind_group_entries.len())); 114 | } 115 | 116 | let mut standard_bind_groups = bind_group_entries 117 | .chunks(4) 118 | .enumerate() 119 | .map(|(i, entries)| { 120 | handle 121 | .device() 122 | .create_bind_group(&wgpu::BindGroupDescriptor { 123 | label: None, 124 | layout: &pipeline.get_bind_group_layout(i as _), 125 | entries, 126 | }) 127 | }) 128 | .collect::>(); 129 | 130 | let uniform_bind_group = handle 131 | .device() 132 | .create_bind_group(&wgpu::BindGroupDescriptor { 133 | label: None, 134 | layout: &pipeline.get_bind_group_layout(standard_bind_groups.len() as _), 135 | entries: &[wgpu::BindGroupEntry { 136 | binding: 0, 137 | resource: uniform_buffer.as_entire_binding(), 138 | }], 139 | }); 140 | standard_bind_groups.push(uniform_bind_group); 141 | standard_bind_groups 142 | } 143 | 144 | pub fn benchmark( 145 | c: &mut Criterion<&WgpuTimer>, 146 | timer: &WgpuTimer, 147 | kernel: K, 148 | throughput: Throughput, 149 | ) { 150 | let handle = timer.handle(); 151 | let tensors = kernel.tensors(); 152 | kernel.validate(&tensors); 153 | let workload = kernel.workload(&tensors); 154 | let source = kernel.source(&workload); 155 | let pipeline = source_to_pipeline(handle, &source); 156 | let uniform_buffer = kernel.metadata(&tensors).into_buffer(handle); 157 | 158 | let gpu_tensors = tensors 159 | .into_iter() 160 | .map(|t| t.into_gpu(handle)) 161 | .collect::>(); 162 | let bind_groups = tensors_to_bind_groups(handle, &gpu_tensors, uniform_buffer, &pipeline); 163 | 164 | let mut group = c.benchmark_group(K::name()); 165 | group.throughput(throughput); 166 | group.bench_function(BenchmarkId::new(K::name(), 0), |b| { 167 | b.iter(|| { 168 | let tsw = timer.timestamp_writes(); 169 | dispatch(handle, &workload, &bind_groups, &pipeline, Some(tsw)); 170 | timer.increment_query(); 171 | }); 172 | }); 173 | group.finish() 174 | } 175 | -------------------------------------------------------------------------------- /kernels/qgemm/tfjs.wgsl: -------------------------------------------------------------------------------- 1 | fn getAIndexFromCoords3D(coords : vec3) -> i32 { 2 | return dot(coords, metadata.aStrides); 3 | } 4 | 5 | fn getBIndexFromCoords3D(coords : vec3) -> i32 { 6 | return dot(coords, metadata.bStrides); 7 | } 8 | 9 | fn getOutputIndexFromCoords(coords : vec3) -> i32 { 10 | return dot(coords, metadata.outShapeStrides); 11 | } 12 | 13 | fn setOutputAtIndex(flatIndex : i32, value : vec4) { 14 | 
result[flatIndex] = vec4(value); 15 | } 16 | 17 | fn setOutputAtCoords(d0 : i32, d1 : i32, d2 : i32, value : vec4) { 18 | let flatIndex = getOutputIndexFromCoords(vec3(d0, d1, d2)); 19 | setOutputAtIndex(flatIndex / 4, value); 20 | } 21 | 22 | fn getA(d0 : i32, d1 : i32, d2 : i32) -> vec4 { 23 | return unpack4x8snorm(A[getAIndexFromCoords3D(vec3(d0,d1,d2)) / 4]); 24 | } 25 | 26 | fn getB(d0 : i32, d1 : i32, d2 : i32) -> vec4 { 27 | return B[getBIndexFromCoords3D(vec3(d0,d1,d2)) / 4]; 28 | } 29 | 30 | fn getAbsMax(d0 : i32, d1 : i32, d2 : i32) -> f32 { 31 | let abs_index = getAIndexFromCoords3D(vec3(d0,d1,d2)) / 16; 32 | return absmax[abs_index]; 33 | } 34 | 35 | {% if A_FIT %} 36 | fn mm_readA(batch: i32, row: i32, col: i32) -> vec4 { 37 | var value = vec4(0.0); 38 | value = getA(batch, row, col); 39 | return value; 40 | } 41 | {% else %} 42 | fn mm_readA(batch: i32, row: i32, col: i32) -> vec4 { 43 | var value = vec4(0.0); 44 | if (row < metadata.aShape.y && col < metadata.aShape.z) { 45 | value = getA(batch, row, col); 46 | } 47 | return value; 48 | } 49 | {% endif %} 50 | 51 | {% if B_FIT %} 52 | fn mm_readB(batch: i32, row: i32, col: i32) -> vec4 { 53 | var value = vec4(0.0); 54 | value = getB(batch, row, col); 55 | return value; 56 | } 57 | {% else %} 58 | fn mm_readB(batch: i32, row: i32, col: i32) -> vec4 { 59 | var value = vec4(0.0); 60 | if (row < metadata.bShape.y && col < metadata.bShape.z) { 61 | value = getB(batch, row, col); 62 | } 63 | return value; 64 | } 65 | {% endif %} 66 | 67 | fn mm_write(batch: i32, row: i32, col: i32, valueIn: vec4) { 68 | {% if OUT_FIT %} 69 | var value = valueIn; 70 | let coords = vec3(batch, row, col); 71 | setOutputAtCoords(coords[0], coords[1], coords[2], value); 72 | {% else %} 73 | if (row < metadata.outShape.y && col < metadata.outShape.z) { 74 | var value = valueIn; 75 | let coords = vec3(batch, row, col); 76 | setOutputAtCoords(coords[0], coords[1], coords[2], valueIn); 77 | } 78 | {% endif %} 79 | } 80 | 81 | 82 | var localId: vec3; 83 | var globalId: vec3; 84 | var workgroupId: vec3; 85 | 86 | @group(0) @binding(0) var A: array; 87 | 88 | @group(0) @binding(1) var absmax: array; 89 | 90 | @group(0) @binding(2) var B: array>; 91 | 92 | @group(0) @binding(3) var result: array>; 93 | 94 | struct Meta { 95 | aShape: vec3, 96 | aStrides: vec3, 97 | bShape: vec3, 98 | bStrides: vec3, 99 | outShape: vec3, 100 | outShapeStrides: vec3, 101 | dimInner: i32, 102 | } 103 | 104 | @group(1) @binding(0) 105 | var metadata: Meta; 106 | 107 | var mm_Asub : array, {{ TILE_DIM / 4 }}>, {{ TILE_DIM }}>; 108 | var mm_Bsub : array, {{ TILE_DIM / 4 }}>, {{ TILE_DIM }}>; 109 | 110 | @compute @workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1) 111 | fn main(@builtin(local_invocation_id) localId : vec3, 112 | @builtin(global_invocation_id) globalId : vec3, 113 | @builtin(workgroup_id) workgroupId : vec3) { 114 | let localRow = i32(localId.y); 115 | let tileRow = localRow * {{ ROW_PER_THREAD }}; 116 | let tileCol = i32(localId.x); 117 | 118 | let globalRow = i32(globalId.y) * {{ ROW_PER_THREAD }}; 119 | let globalCol = i32(globalId.x) * 4; 120 | let batch = i32(globalId.z); 121 | let batchA = batch % metadata.aShape.x; 122 | let batchB = batch % metadata.bShape.x; 123 | 124 | let numTiles = (metadata.dimInner - 1) / {{ TILE_DIM }} + 1; 125 | var kStart = 0; 126 | 127 | var acc: array, {{ ROW_PER_THREAD }}>; 128 | 129 | // Loop over shared dimension. 
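    // A is the quantized operand in this kernel: each u32 of A packs four sint8 values,
    // so the A tile load below unpacks them with unpack4x8snorm and multiplies by the
    // block absmax (one f32 per 16 elements of A) to reconstruct f32 values in mm_Asub;
    // B stays plain f32.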
130 | let tileRowB = localRow * {{ ROW_PER_THREAD }}; 131 | for (var t = 0; t < numTiles; t++) { 132 | // Load one tile of A into local memory. 133 | for (var innerRow = 0; innerRow < {{ ROW_PER_THREAD }}; innerRow++) { 134 | let inputRow = tileRow + innerRow; 135 | let inputCol = tileCol; 136 | 137 | let curRow = globalRow + innerRow; 138 | let curCol = kStart + inputCol * 4; 139 | 140 | let absmax = getAbsMax(batchA, curRow, curCol); 141 | mm_Asub[inputRow][inputCol] = mm_readA(batchA, curRow, curCol) * absmax; 142 | } 143 | 144 | // Load one tile of B into local memory. 145 | for (var innerRow = 0; innerRow < {{ ROW_PER_THREAD }}; innerRow++) { 146 | let inputRow = tileRowB + innerRow; 147 | let inputCol = tileCol; 148 | mm_Bsub[inputRow][inputCol] = mm_readB(batchB, kStart + inputRow, globalCol); 149 | } 150 | kStart = kStart + {{ TILE_DIM }}; 151 | workgroupBarrier(); 152 | 153 | // Compute acc values for a single thread. 154 | for (var k = 0; k < {{ TILE_DIM / 4 }}; k++) { 155 | let bidx = k * 4; 156 | let BCached0 = mm_Bsub[bidx][tileCol]; 157 | let BCached1 = mm_Bsub[bidx + 1][tileCol]; 158 | let BCached2 = mm_Bsub[bidx + 2][tileCol]; 159 | let BCached3 = mm_Bsub[bidx + 3][tileCol]; 160 | for (var i = 0; i < {{ ROW_PER_THREAD }}; i++) { 161 | let ACached = mm_Asub[tileRow + i][k]; 162 | acc[i] = fma(BCached0, vec4(ACached[0]), acc[i]); 163 | acc[i] = fma(BCached1, vec4(ACached[1]), acc[i]); 164 | acc[i] = fma(BCached2, vec4(ACached[2]), acc[i]); 165 | acc[i] = fma(BCached3, vec4(ACached[3]), acc[i]); 166 | } 167 | } 168 | workgroupBarrier(); 169 | } 170 | 171 | {% for innerRow in range(end=ROW_PER_THREAD) -%} 172 | mm_write(batch, globalRow + {{ innerRow }}, globalCol, acc[{{ innerRow }}]); 173 | {% endfor %} 174 | } 175 | -------------------------------------------------------------------------------- /benches/sgemv/gemv.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_snake_case)] 2 | use encase::ShaderType; 3 | use inline_python::{python, Context}; 4 | use numpy::PyArrayDyn; 5 | use pyo3::{IntoPy, Python}; 6 | use smallvec::smallvec; 7 | 8 | use criterion::{criterion_group, criterion_main, Criterion, Throughput}; 9 | use wgpu_bencher::{ 10 | dispatch_validate, shape, wgc, wgs, CPUTensor, GPUHandle, KernelBench, KernelContextExt, 11 | OpMetadata, WgpuTimer, Workload, 12 | }; 13 | 14 | lazy_static::lazy_static! 
{ 15 | pub static ref TIMER: WgpuTimer = WgpuTimer::new(pollster::block_on(async { 16 | GPUHandle::new().await.unwrap() 17 | })); 18 | } 19 | 20 | #[derive(ShaderType, Debug)] 21 | pub struct SGEMVMeta { 22 | aShape: glam::IVec3, 23 | aStrides: glam::IVec3, 24 | bShape: glam::IVec3, 25 | bStrides: glam::IVec3, 26 | outShape: glam::IVec3, 27 | outStrides: glam::IVec3, 28 | dimAOuter: i32, 29 | dimBOuter: i32, 30 | dimInner: i32, 31 | } 32 | 33 | impl OpMetadata for SGEMVMeta {} 34 | 35 | #[derive(derive_new::new, Debug)] 36 | pub struct SGEMVBenchmark { 37 | B: usize, 38 | M: usize, 39 | N: usize, 40 | K: usize, 41 | TILE_DIM: usize, 42 | ROW_PER_THREAD: usize, 43 | trans_a: bool, 44 | trans_b: bool, 45 | } 46 | 47 | impl SGEMVBenchmark { 48 | fn shape_fit(&self) -> [bool; 3] { 49 | let aOuter = if self.trans_a { self.K } else { self.M }; 50 | let bOuter = if self.trans_b { self.K } else { self.N }; 51 | let dimInner = if self.trans_a { self.M } else { self.K }; 52 | 53 | let mut shape_fit = [false; 3]; 54 | shape_fit[0] = aOuter % self.TILE_DIM == 0; 55 | shape_fit[1] = bOuter % self.TILE_DIM == 0; 56 | shape_fit[2] = dimInner % self.TILE_DIM == 0; 57 | println!("SHAPE FIT: {:?}", shape_fit); 58 | shape_fit 59 | } 60 | } 61 | 62 | impl KernelBench for SGEMVBenchmark { 63 | type Metadata = SGEMVMeta; 64 | 65 | fn name() -> &'static str { 66 | "SGEMVBenchmark" 67 | } 68 | 69 | fn source(&self, workload: &Workload) -> String { 70 | let mut tera = tera::Tera::default(); 71 | let mut context = tera::Context::new(); 72 | 73 | let is_vec4 = !self.trans_a 74 | && !self.trans_b 75 | && (self.M % 4 == 0) 76 | && (self.N % 4 == 0) 77 | && (self.K % 4 == 0); 78 | let template = include_str!("../../kernels/sgemv/sgemv_2.wgsl"); 79 | tera.add_raw_template(Self::name(), template).unwrap(); 80 | let shape_fit = self.shape_fit(); 81 | context.insert("FIT_A_OUTER", &shape_fit[0]); 82 | context.insert("FIT_B_OUTER", &shape_fit[1]); 83 | context.insert("FIT_INNER", &shape_fit[2]); 84 | context.insert("TRANS_A", &self.trans_a); 85 | context.insert("TRANS_B", &self.trans_b); 86 | 87 | context.insert("TILE_DIM", &self.TILE_DIM); 88 | context.insert("ROW_PER_THREAD", &self.ROW_PER_THREAD); 89 | context.insert_workload(workload); 90 | let kernel = tera.render(Self::name(), &context).unwrap(); 91 | println!("{}", kernel); 92 | kernel 93 | } 94 | 95 | fn tensors(&self) -> Vec { 96 | let (B, M, N, K) = (self.B, self.M, self.N, self.K); 97 | let a = CPUTensor::randn::(shape![B, M, K]); 98 | let b = CPUTensor::randn::(shape![B, K, N]); 99 | let output = CPUTensor::zeros::(shape![B, M, N]); 100 | vec![a, b, output] 101 | } 102 | 103 | fn workload(&self, _: &[CPUTensor]) -> Workload { 104 | let workgroup_size = wgs![32, 4, 1]; 105 | let workgroup_count = wgc![(self.M / 32) as _, 1, self.B as _]; 106 | let dispatch = Workload::new(workgroup_size, workgroup_count); 107 | println!("DISPATCH: {:?}", dispatch); 108 | dispatch 109 | } 110 | 111 | fn metadata(&self, _: &[CPUTensor]) -> Self::Metadata { 112 | let (B, M, N, K) = (self.B as i32, self.M as i32, self.N as i32, self.K as i32); 113 | 114 | let aShape = glam::IVec3::new(B, M, K); 115 | let aStrides = glam::IVec3::new(M * K, K, 1); 116 | let bShape = glam::IVec3::new(B, K, N); 117 | let bStrides = glam::IVec3::new(K * N, N, 1); 118 | let outShape = glam::IVec3::new(B, M, N); 119 | let outStrides = glam::IVec3::new(M * N, N, 1); 120 | 121 | let dimAOuter = if self.trans_a { K } else { M }; 122 | let dimBOuter = if self.trans_b { K } else { N }; 123 | let dimInner = if 
self.trans_a { M } else { K }; 124 | 125 | let meta = SGEMVMeta { 126 | aShape, 127 | aStrides, 128 | bShape, 129 | bStrides, 130 | outShape, 131 | outStrides, 132 | dimAOuter, 133 | dimBOuter, 134 | dimInner, 135 | }; 136 | println!("META: {:?}", meta); 137 | meta 138 | } 139 | 140 | fn validate(&self, tensors: &[CPUTensor]) { 141 | let (a, b) = (&tensors[0], &tensors[1]); 142 | let (trans_a, trans_b) = (self.trans_a, self.trans_b); 143 | let ground = Python::with_gil(|py| { 144 | let (py_a, py_b) = (a.to_py::(&py), b.to_py::(&py)); 145 | let (py_trans_a, py_trans_b) = (trans_a.into_py(py), trans_b.into_py(py)); 146 | let result: Context = python! { 147 | import torch 148 | (a, b) = (torch.from_numpy('py_a), torch.from_numpy('py_b)) 149 | if 'py_trans_a: 150 | print("Transposing A in Python") 151 | a = torch.permute(a, (0, 2, 1)) 152 | if 'py_trans_b: 153 | print("Transposing B in Python") 154 | b = torch.permute(b, (0, 2, 1)) 155 | result = (a @ b).numpy() 156 | }; 157 | CPUTensor::from(result.get_with_gil::<&PyArrayDyn>(py, "result")) 158 | }); 159 | let mut gpu_tensors = dispatch_validate(TIMER.handle(), self, tensors); 160 | let cpu_result = gpu_tensors.remove(2).into_cpu(TIMER.handle()).unwrap(); 161 | println!("GROUND: {}", ground); 162 | println!("OURS: {}", cpu_result); 163 | ground.all_close(&cpu_result, 5e-4, 5e-4).unwrap(); 164 | } 165 | } 166 | 167 | pub fn benchmark(c: &mut Criterion<&WgpuTimer>) { 168 | let B = 1; 169 | let M = 16384; 170 | let N = 1; 171 | let K = 3072; 172 | let TILE_DIM = 32; 173 | let ROW_PER_THREAD = 4; 174 | 175 | let trans_a = false; 176 | let trans_b = false; 177 | 178 | let bench = SGEMVBenchmark::new(B, M, N, K, TILE_DIM, ROW_PER_THREAD, trans_a, trans_b); 179 | let throughput = Throughput::Elements(2 * (B * M * N * K) as u64); 180 | wgpu_bencher::benchmark(c, &TIMER, bench, throughput) 181 | } 182 | 183 | criterion_group!( 184 | name = bench; 185 | config = Criterion::default().with_measurement(&*TIMER); 186 | targets = benchmark 187 | ); 188 | criterion_main!(bench); 189 | -------------------------------------------------------------------------------- /src/quant.rs: -------------------------------------------------------------------------------- 1 | use crate::{CPUTensor, DType, STORAGE_BUFFER_ALIGN}; 2 | use num::integer::div_floor; 3 | use std::fmt::Debug; 4 | 5 | /// Quantizer 6 | /// 7 | /// Packs weights into our custom quantization formats. 8 | #[derive(Debug, derive_new::new)] 9 | pub struct Quantizer { 10 | format: Quantization, 11 | } 12 | 13 | impl Quantizer { 14 | pub fn quantize(&self, tensor: CPUTensor) -> CPUTensor { 15 | match self.format { 16 | Quantization::None => tensor, 17 | Quantization::SInt8 => self.sint8_quantize(tensor), 18 | Quantization::SInt4 => todo!(), 19 | } 20 | } 21 | 22 | pub fn dequantize(&self, tensor: CPUTensor) -> CPUTensor { 23 | match self.format { 24 | Quantization::None => tensor, 25 | Quantization::SInt8 => self.sint8_dequantize(tensor), 26 | Quantization::SInt4 => todo!(), 27 | } 28 | } 29 | 30 | /// Quantizes a float 32 tensor into a packed uint32 tensor. 31 | /// This is the rust equivalent of: https://www.w3.org/TR/WGSL/#pack4x8snorm-builtin 32 | /// This allows us to call `unpack4x8snorm` in the shader. 33 | /// It's a pretty naive quantization scheme, more to come. 
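    /// Concretely: every 16 consecutive f32 values share one absmax scale, and every 4
    /// values are packed into a single u32 as signed bytes q = round(x / absmax * 127),
    /// with element i in the low byte through element i+3 in the high byte. For example,
    /// with a block absmax of 1.2, the value 0.6 packs to round(0.6 / 1.2 * 127) = 64
    /// (0x40). The packed words are padded out to the 256-byte storage buffer alignment
    /// and the f32 absmax values are appended after them.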
34 | pub fn sint8_quantize(&self, tensor: CPUTensor) -> CPUTensor { 35 | let numel = tensor.shape().numel(); 36 | assert!(numel % 4 == 0 && numel % 16 == 0); 37 | assert!(tensor.dt() == DType::F32); //TODO: f16, bf16 38 | //TODO: check if tensor is contiguous 39 | let pack_size = self.format.pack_size(); 40 | let group_size = self.format.group_size(); 41 | 42 | let qmatrix_len = numel / pack_size; 43 | let amatrix_len = numel / group_size; 44 | 45 | //returns the aligned number of ELEMENTS 46 | let aligner = |numel: usize, size_t: usize| -> usize { 47 | let nbytes = numel * size_t; 48 | let aligned = if nbytes % STORAGE_BUFFER_ALIGN != 0 { 49 | nbytes + STORAGE_BUFFER_ALIGN - nbytes % STORAGE_BUFFER_ALIGN 50 | } else { 51 | nbytes 52 | }; 53 | aligned / size_t 54 | }; 55 | 56 | let mut quantized_matrix = vec![0u32; aligner(qmatrix_len, std::mem::size_of::())]; 57 | let mut absmax_matrix = vec![0f32; aligner(amatrix_len, std::mem::size_of::())]; 58 | 59 | let sf = 127.0f32; 60 | let mut block_absmax = f32::NEG_INFINITY; 61 | 62 | let matrix = tensor.to_vec::().unwrap(); 63 | 64 | for i in (0..numel).step_by(pack_size) { 65 | if i % group_size == 0 { 66 | block_absmax = matrix[i..i + group_size] 67 | .iter() 68 | .fold(f32::NEG_INFINITY, |acc, &x| acc.max(x.abs())); 69 | } 70 | let packed_value: i32 = ((matrix[i] / block_absmax * sf).round() as i32 & 0xFF) 71 | | (((matrix[i + 1] / block_absmax * sf).round() as i32 & 0xFF) << 8) 72 | | (((matrix[i + 2] / block_absmax * sf).round() as i32 & 0xFF) << 16) 73 | | (((matrix[i + 3] / block_absmax * sf).round() as i32 & 0xFF) << 24); 74 | quantized_matrix[i / pack_size] = packed_value as u32; 75 | absmax_matrix[i / group_size] = block_absmax; 76 | } 77 | quantized_matrix.append(&mut unsafe { std::mem::transmute(absmax_matrix) }); 78 | unsafe { CPUTensor::from_quantized(quantized_matrix, tensor.shape().clone(), DType::WQ8) } 79 | } 80 | 81 | pub fn sint8_dequantize(&self, quantized: CPUTensor) -> CPUTensor { 82 | assert!(quantized.dt() == DType::WQ8); 83 | let numel = quantized.shape().numel(); 84 | 85 | let aligner = |numel: usize, size_t: usize| -> usize { 86 | let nbytes = numel * size_t; 87 | let aligned = if nbytes % STORAGE_BUFFER_ALIGN != 0 { 88 | nbytes + STORAGE_BUFFER_ALIGN - nbytes % STORAGE_BUFFER_ALIGN 89 | } else { 90 | nbytes 91 | }; 92 | aligned 93 | }; 94 | 95 | let pack_size = self.format.pack_size(); 96 | let group_size = self.format.group_size(); 97 | 98 | let num_q = numel / pack_size; 99 | let num_q_bytes = num_q * std::mem::size_of::(); 100 | let aligned_q_bytes = aligner(num_q, std::mem::size_of::()); 101 | 102 | let num_absmax = numel / group_size; 103 | let num_absmax_bytes = num_absmax * std::mem::size_of::(); 104 | 105 | let raw_bytes = quantized.storage().as_bytes(); 106 | 107 | let quantized_matrix = bytemuck::cast_slice::(&raw_bytes[..num_q_bytes]); 108 | let absmax_matrix = bytemuck::cast_slice::( 109 | &raw_bytes[aligned_q_bytes..aligned_q_bytes + num_absmax_bytes], 110 | ); 111 | 112 | let mut dequantized = vec![0.0f32; numel]; 113 | 114 | for i in (0..numel).step_by(pack_size) { 115 | let block_absmax = absmax_matrix[div_floor(i, group_size)]; 116 | let packed_value = quantized_matrix[div_floor(i, pack_size)] as i32; 117 | dequantized[i] = ((packed_value << 24) >> 24) as f32 / 127.0 * block_absmax; 118 | dequantized[i + 1] = ((packed_value << 16) >> 24) as f32 / 127.0 * block_absmax; 119 | dequantized[i + 2] = ((packed_value << 8) >> 24) as f32 / 127.0 * block_absmax; 120 | dequantized[i + 3] = (packed_value >> 24) 
as f32 / 127.0 * block_absmax; 121 | } 122 | 123 | CPUTensor::from_slice(&dequantized, quantized.shape().clone()) 124 | } 125 | } 126 | 127 | #[derive(Debug, Clone, Copy)] 128 | pub enum Quantization { 129 | None, 130 | SInt8, 131 | SInt4, 132 | } 133 | 134 | impl Quantization { 135 | pub fn pack_size(&self) -> usize { 136 | match self { 137 | Quantization::None => 1, 138 | Quantization::SInt8 => 4, 139 | Quantization::SInt4 => 8, 140 | } 141 | } 142 | 143 | pub fn group_size(&self) -> usize { 144 | match self { 145 | Quantization::None => 1, 146 | Quantization::SInt8 => 16, 147 | Quantization::SInt4 => 8, 148 | } 149 | } 150 | } 151 | 152 | #[cfg(test)] 153 | mod tests { 154 | use crate::shape; 155 | 156 | #[test] 157 | pub fn sint8_qdq() { 158 | use crate::CPUTensor; 159 | use crate::Quantization; 160 | use crate::Quantizer; 161 | let tensor = CPUTensor::from_slice( 162 | &[ 163 | 0.1, -0.1, 0.5, -0.5, 1.0, -1.0, 1.2, -1.2, 0.1, -0.1, 0.5, -0.5, 1.0, -1.0, 1.2, 164 | -1.2, 0.1, -0.1, 0.5, -0.5, 1.0, -1.0, 1.2, -1.2, 0.1, -0.1, 0.5, -0.5, 1.0, -1.0, 165 | 1.2, -1.2, 166 | ], 167 | shape![4, 8], 168 | ); 169 | println!("{}", tensor); 170 | let quantizer = Quantizer::new(Quantization::SInt8); 171 | let quantized = quantizer.quantize(tensor.clone()); 172 | let dequantized = quantizer.dequantize(quantized); 173 | println!("{}", dequantized); 174 | 175 | dequantized.all_close(&tensor, 1e-2, 1e-2).unwrap(); 176 | } 177 | } 178 | -------------------------------------------------------------------------------- /kernels/sgemm/gemm_scalar.wgsl: -------------------------------------------------------------------------------- 1 | fn getAIndexFromCoords3D(coords : vec3) -> i32 { 2 | return dot(coords, metadata.aStrides); 3 | } 4 | 5 | fn getBIndexFromCoords3D(coords : vec3) -> i32 { 6 | return dot(coords, metadata.bStrides); 7 | } 8 | 9 | fn getOutputIndexFromCoords(coords : vec3) -> i32 { 10 | return dot(coords, metadata.outShapeStrides); 11 | } 12 | 13 | fn setOutputAtIndex(flatIndex: i32, value: f32) { 14 | result[flatIndex] = f32(value); 15 | } 16 | 17 | fn setOutputAtCoords(d0: i32, d1: i32, d2: i32, value: f32) { 18 | let flatIndex = getOutputIndexFromCoords(vec3(d0, d1, d2)); 19 | setOutputAtIndex(flatIndex, value); 20 | } 21 | 22 | fn getA(d0: i32, d1: i32, d2: i32) -> f32 { 23 | return f32(A[getAIndexFromCoords3D(vec3(d0, d1, d2))]); 24 | } 25 | 26 | fn getB(d0: i32, d1: i32, d2: i32) -> f32 { 27 | return f32(B[getBIndexFromCoords3D(vec3(d0, d1, d2))]); 28 | } 29 | 30 | {% if FIT_A_OUTER and FIT_INNER %} 31 | fn mm_readA(batch: i32, row: i32, col: i32) -> f32 { 32 | var value = f32(0.0); 33 | {% if TRANS_A %} 34 | value = getA(batch, col, row); 35 | {% else %} 36 | value = getA(batch, row, col); 37 | {% endif %} 38 | return value; 39 | } 40 | {% else %} 41 | fn mm_readA(batch: i32, row: i32, col: i32) -> f32 { 42 | var value = f32(0.0); 43 | {% if TRANS_A %} 44 | if (row < metadata.aShape.z && col < metadata.aShape.y) { 45 | value = getA(batch, col, row); 46 | } 47 | {% else %} 48 | if (row < metadata.aShape.y && col < metadata.aShape.z) { 49 | value = getA(batch, row, col); 50 | } 51 | {% endif %} 52 | return value; 53 | } 54 | {% endif %} 55 | 56 | {% if FIT_B_OUTER and FIT_INNER %} 57 | fn mm_readB(batch: i32, row: i32, col: i32) -> f32 { 58 | var value = f32(0.0); 59 | {% if TRANS_B %} 60 | value = getB(batch, col, row); 61 | {% else %} 62 | value = getB(batch, row, col); 63 | {% endif %} 64 | return value; 65 | } 66 | {% else %} 67 | fn mm_readB(batch: i32, row: i32, col: i32) -> f32 { 
68 | var value = f32(0.0); 69 | {% if TRANS_B %} 70 | if (row < metadata.bShape.z && col < metadata.bShape.y) { 71 | value = getB(batch, col, row); 72 | } 73 | {% else %} 74 | if (row < metadata.bShape.y && col < metadata.bShape.z) { 75 | value = getB(batch, row, col); 76 | } 77 | {% endif %} 78 | return value; 79 | } 80 | {% endif %} 81 | 82 | fn mm_write(batch: i32, row: i32, col: i32, valueIn: f32) { 83 | {% if FIT_A_OUTER and FIT_B_OUTER %} 84 | var value = valueIn; 85 | let coords = vec3(batch, row, col); 86 | setOutputAtCoords(coords[0], coords[1], coords[2], value); 87 | {% else %} 88 | if (row < metadata.dimAOuter && col < metadata.dimBOuter) { 89 | var value = valueIn; 90 | let coords = vec3(batch, row, col); 91 | setOutputAtCoords(coords[0], coords[1], coords[2], valueIn); 92 | } 93 | {% endif %} 94 | } 95 | 96 | var localId: vec3; 97 | var globalId: vec3; 98 | var workgroupId: vec3; 99 | 100 | @group(0) @binding(0) var A: array; 101 | @group(0) @binding(1) var B: array; 102 | @group(0) @binding(2) var bias: array; 103 | @group(0) @binding(3) var result: array; 104 | @group(1) @binding(0) var metadata: Meta; 105 | 106 | 107 | struct Meta { 108 | aShape: vec3, 109 | aStrides: vec3, 110 | bShape: vec3, 111 | bStrides: vec3, 112 | outShape: vec3, 113 | outShapeStrides: vec3, 114 | dimAOuter: i32, 115 | dimBOuter: i32, 116 | dimInner: i32, 117 | } 118 | 119 | var mm_Asub : array, {{ TILE_DIM }}>; 120 | var mm_Bsub : array, {{ TILE_DIM }}>; 121 | 122 | @compute @workgroup_size({{ TILE_DIM / 4 }}, {{ TILE_DIM / ROW_PER_THREAD }},1) 123 | fn main(@builtin(local_invocation_id) localId : vec3, 124 | @builtin(global_invocation_id) globalId : vec3, 125 | @builtin(workgroup_id) workgroupId : vec3) { 126 | let batch = i32(globalId.z); 127 | let batchA = batch % metadata.aShape[0]; 128 | let batchB = batch % metadata.bShape[0]; 129 | 130 | let tileRow = i32(localId.y) * 4; 131 | let tileCol = i32(localId.x) * 4; 132 | 133 | let globalRowStart = i32(workgroupId.y) * {{ TILE_DIM }}; 134 | let globalRow = i32(globalId.y) * 4; 135 | let globalCol = i32(globalId.x) * 4; 136 | 137 | let numTiles = (metadata.dimInner - 1) / {{ TILE_DIM }} + 1; 138 | var kStart = 0; 139 | 140 | var acc: array, 4>; 141 | 142 | let tileRowA = i32(localId.y) * 4; 143 | let tileColA = i32(localId.x) * 4; 144 | let tileRowB = i32(localId.y) * 4; 145 | // Loop over shared dimension. 146 | for (var t = 0; t < numTiles; t++) { 147 | // Load one tile of A into local memory. 148 | for (var innerRow = 0; innerRow < 4; innerRow++) { 149 | for (var innerCol = 0; innerCol < 4; innerCol++) { 150 | let inputRow = tileRowA + innerRow; 151 | let inputCol = tileColA + innerCol; 152 | 153 | mm_Asub[inputRow][inputCol] = mm_readA(batchA, 154 | globalRowStart + inputRow, 155 | kStart + inputCol); 156 | } 157 | } 158 | 159 | // Load one tile of B into local memory. 
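        // Scalar counterpart of the vectorized kernel: each thread copies a 4x4 patch of B
        // element by element, mirroring the A load above, before the barrier publishes
        // both tiles to the whole workgroup.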
160 | for (var innerRow = 0; innerRow < 4; innerRow++) { 161 | for (var innerCol = 0; innerCol < 4; innerCol++) { 162 | let inputRow = tileRowB + innerRow; 163 | let inputCol = tileCol + innerCol; 164 | mm_Bsub[inputRow][inputCol] = mm_readB(batchB, 165 | kStart + inputRow, 166 | globalCol + innerCol); 167 | } 168 | } 169 | kStart = kStart + {{ TILE_DIM }}; 170 | workgroupBarrier(); 171 | 172 | for (var k = 0; k < {{ TILE_DIM }}; k++) { 173 | let BCached0 = mm_Bsub[k][tileCol + 0]; 174 | let BCached1 = mm_Bsub[k][tileCol + 1]; 175 | let BCached2 = mm_Bsub[k][tileCol + 2]; 176 | let BCached3 = mm_Bsub[k][tileCol + 3]; 177 | 178 | for (var innerRow = 0; innerRow < 4; innerRow++) { 179 | let ACached = mm_Asub[tileRow + innerRow][k]; 180 | acc[innerRow][0] = fma(ACached, BCached0, acc[innerRow][0]); 181 | acc[innerRow][1] = fma(ACached, BCached1, acc[innerRow][1]); 182 | acc[innerRow][2] = fma(ACached, BCached2, acc[innerRow][2]); 183 | acc[innerRow][3] = fma(ACached, BCached3, acc[innerRow][3]); 184 | } 185 | } 186 | 187 | workgroupBarrier(); 188 | } 189 | 190 | for (var innerRow = 0; innerRow < 4; innerRow++) { 191 | for (var innerCol = 0; innerCol < 4; innerCol++) { 192 | let val = acc[innerRow][innerCol] + bias[globalCol + innerCol]; 193 | mm_write(batch, globalRow + innerRow, globalCol + innerCol, val); 194 | } 195 | } 196 | } 197 | -------------------------------------------------------------------------------- /benches/sgemm/tfjs.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_snake_case)] 2 | use encase::ShaderType; 3 | use inline_python::{python, Context}; 4 | use numpy::PyArrayDyn; 5 | use pyo3::{IntoPy, Python}; 6 | use smallvec::smallvec; 7 | 8 | use criterion::{criterion_group, criterion_main, Criterion, Throughput}; 9 | use wgpu_bencher::{ 10 | dispatch_validate, shape, wgc, wgs, CPUTensor, GPUHandle, KernelBench, KernelContextExt, 11 | OpMetadata, WgpuTimer, Workload, 12 | }; 13 | 14 | lazy_static::lazy_static! 
{ 15 | pub static ref TIMER: WgpuTimer = WgpuTimer::new(pollster::block_on(async { 16 | GPUHandle::new().await.unwrap() 17 | })); 18 | } 19 | 20 | #[derive(ShaderType, Debug)] 21 | pub struct SGEMMMeta { 22 | aShape: glam::IVec3, 23 | aStrides: glam::IVec3, 24 | bShape: glam::IVec3, 25 | bStrides: glam::IVec3, 26 | outShape: glam::IVec3, 27 | outStrides: glam::IVec3, 28 | dimAOuter: i32, 29 | dimBOuter: i32, 30 | dimInner: i32, 31 | } 32 | 33 | impl OpMetadata for SGEMMMeta {} 34 | 35 | #[derive(derive_new::new, Debug)] 36 | pub struct SGEMMBenchmark { 37 | B: usize, 38 | M: usize, 39 | N: usize, 40 | K: usize, 41 | TILE_DIM: usize, 42 | ROW_PER_THREAD: usize, 43 | trans_a: bool, 44 | trans_b: bool, 45 | } 46 | 47 | impl SGEMMBenchmark { 48 | fn shape_fit(&self) -> [bool; 3] { 49 | let aOuter = if self.trans_a { self.K } else { self.M }; 50 | let bOuter = if self.trans_b { self.K } else { self.N }; 51 | let dimInner = if self.trans_a { self.M } else { self.K }; 52 | 53 | let mut shape_fit = [false; 3]; 54 | shape_fit[0] = aOuter % self.TILE_DIM == 0; 55 | shape_fit[1] = bOuter % self.TILE_DIM == 0; 56 | shape_fit[2] = dimInner % self.TILE_DIM == 0; 57 | println!("SHAPE FIT: {:?}", shape_fit); 58 | shape_fit 59 | } 60 | } 61 | 62 | impl KernelBench for SGEMMBenchmark { 63 | type Metadata = SGEMMMeta; 64 | 65 | fn name() -> &'static str { 66 | "SGEMMBenchmark" 67 | } 68 | 69 | fn source(&self, workload: &Workload) -> String { 70 | let mut tera = tera::Tera::default(); 71 | let mut context = tera::Context::new(); 72 | 73 | let is_vec4 = !self.trans_a && !self.trans_b && (self.N % 4 == 0) && (self.K % 4 == 0); 74 | let template = if is_vec4 { 75 | include_str!("../../kernels/sgemm/gemm_vectorized.wgsl") 76 | } else { 77 | include_str!("../../kernels/sgemm/gemm_scalar.wgsl") 78 | }; 79 | tera.add_raw_template(Self::name(), template).unwrap(); 80 | let shape_fit = self.shape_fit(); 81 | context.insert("FIT_A_OUTER", &shape_fit[0]); 82 | context.insert("FIT_B_OUTER", &shape_fit[1]); 83 | context.insert("FIT_INNER", &shape_fit[2]); 84 | context.insert("TRANS_A", &self.trans_a); 85 | context.insert("TRANS_B", &self.trans_b); 86 | 87 | context.insert("TILE_DIM", &self.TILE_DIM); 88 | context.insert("ROW_PER_THREAD", &self.ROW_PER_THREAD); 89 | context.insert_workload(workload); 90 | let kernel = tera.render(Self::name(), &context).unwrap(); 91 | println!("{}", kernel); 92 | kernel 93 | } 94 | 95 | fn tensors(&self) -> Vec { 96 | let (B, M, N, K) = (self.B, self.M, self.N, self.K); 97 | let a = CPUTensor::randn::(shape![B, M, K]); 98 | let b = CPUTensor::randn::(shape![B, K, N]); 99 | let bias = CPUTensor::randn::(shape![N]); 100 | let output = CPUTensor::zeros::(shape![B, M, N]); 101 | vec![a, b, bias, output] 102 | } 103 | 104 | fn workload(&self, _: &[CPUTensor]) -> Workload { 105 | let (TILE_DIM, ROW_PER_THREAD) = (self.TILE_DIM, self.ROW_PER_THREAD); 106 | let workgroup_size = wgs![(TILE_DIM / 4) as _, (TILE_DIM / ROW_PER_THREAD) as _, 1]; 107 | let dimA = if self.trans_a { self.K } else { self.M }; 108 | let dimB = if self.trans_b { self.K } else { self.N }; 109 | let group_x = Workload::ceil(dimB, TILE_DIM); 110 | let group_y = Workload::ceil(dimA, TILE_DIM); 111 | let workgroup_count = wgc![group_x as _, group_y as _, self.B as u32]; 112 | let dispatch = Workload::new(workgroup_size, workgroup_count); 113 | println!("DISPATCH: {:?}", dispatch); 114 | dispatch 115 | } 116 | 117 | fn metadata(&self, _: &[CPUTensor]) -> Self::Metadata { 118 | let (B, M, N, K) = (self.B as i32, self.M as i32, 
119 | 
120 |         let aShape = glam::IVec3::new(B, M, K);
121 |         let aStrides = glam::IVec3::new(M * K, K, 1);
122 |         let bShape = glam::IVec3::new(B, K, N);
123 |         let bStrides = glam::IVec3::new(K * N, N, 1);
124 |         let outShape = glam::IVec3::new(B, M, N);
125 |         let outStrides = glam::IVec3::new(M * N, N, 1);
126 | 
127 |         let dimAOuter = if self.trans_a { K } else { M };
128 |         let dimBOuter = if self.trans_b { K } else { N };
129 |         let dimInner = if self.trans_a { M } else { K };
130 | 
131 |         let meta = SGEMMMeta {
132 |             aShape,
133 |             aStrides,
134 |             bShape,
135 |             bStrides,
136 |             outShape,
137 |             outStrides,
138 |             dimAOuter,
139 |             dimBOuter,
140 |             dimInner,
141 |         };
142 |         println!("META: {:?}", meta);
143 |         meta
144 |     }
145 | 
146 |     fn validate(&self, tensors: &[CPUTensor]) {
147 |         let (a, b, bias) = (&tensors[0], &tensors[1], &tensors[2]);
148 |         let (trans_a, trans_b) = (self.trans_a, self.trans_b);
149 |         let ground = Python::with_gil(|py| {
150 |             let (py_a, py_b, py_bias) = (
151 |                 a.to_py::<f32>(&py),
152 |                 b.to_py::<f32>(&py),
153 |                 bias.to_py::<f32>(&py),
154 |             );
155 |             let (py_trans_a, py_trans_b) = (trans_a.into_py(py), trans_b.into_py(py));
156 |             let result: Context = python! {
157 |                 import torch
158 |                 import numpy as np
159 |                 (a, b) = (torch.from_numpy('py_a), torch.from_numpy('py_b))
160 |                 bias = torch.from_numpy('py_bias)
161 |                 if 'py_trans_a:
162 |                     print("Transposing A in Python")
163 |                     a = torch.permute(a, (0, 2, 1))
164 |                 if 'py_trans_b:
165 |                     print("Transposing B in Python")
166 |                     b = torch.permute(b, (0, 2, 1))
167 | 
168 |                 result = ((a @ b) + bias).numpy()
169 |             };
170 |             CPUTensor::from(result.get_with_gil::<&PyArrayDyn<f32>>(py, "result"))
171 |         });
172 |         let mut gpu_tensors = dispatch_validate(TIMER.handle(), self, tensors);
173 |         let cpu_result = gpu_tensors.remove(3).into_cpu(TIMER.handle()).unwrap();
174 |         println!("GROUND: {}", ground);
175 |         println!("OURS: {}", cpu_result);
176 |         ground.all_close(&cpu_result, 1e-3, 1e-3).unwrap();
177 |     }
178 | }
179 | 
180 | pub fn benchmark(c: &mut Criterion<&WgpuTimer>) {
181 |     let B = 1;
182 |     let M = 256;
183 |     let N = 256;
184 |     let K = 256;
185 |     let TILE_DIM = 32;
186 |     let ROW_PER_THREAD = 4;
187 | 
188 |     let trans_a = false;
189 |     let trans_b = false;
190 | 
191 |     let bench = SGEMMBenchmark::new(B, M, N, K, TILE_DIM, ROW_PER_THREAD, trans_a, trans_b);
192 |     let throughput = Throughput::Elements(2 * (B * M * N * K) as u64);
193 |     wgpu_bencher::benchmark(c, &TIMER, bench, throughput)
194 | }
195 | 
196 | criterion_group!(
197 |     name = bench;
198 |     config = Criterion::default().with_measurement(&*TIMER);
199 |     targets = benchmark
200 | );
201 | criterion_main!(bench);
202 | 
--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
1 | #![feature(int_roundings)]
2 | mod bench;
3 | mod data;
4 | mod dtype;
5 | mod handle;
6 | mod metadata;
7 | mod quant;
8 | mod shape;
9 | mod storage;
10 | mod strides;
11 | mod tensor;
12 | mod workload;
13 | 
14 | use std::{cell::Cell, ops::Range};
15 | 
16 | pub use bench::*;
17 | pub use data::*;
18 | pub use dtype::*;
19 | pub use handle::*;
20 | pub use metadata::*;
21 | pub use quant::*;
22 | pub use shape::*;
23 | pub use storage::*;
24 | pub use strides::*;
25 | pub use tensor::*;
26 | pub use workload::*;
27 | 
28 | use criterion::{
29 |     measurement::{Measurement, ValueFormatter},
30 |     Throughput,
31 | };
32 | use wgpu::QuerySet;
33 | 
34 | pub const MAX_QUERIES: u32 = 4096;
35 | 
36 | /// Start and end index in the counter sample buffer
37 | #[derive(Debug, Clone, Copy)]
38 | pub struct QueryPair {
39 |     pub start: u32,
40 |     pub end: u32,
41 | }
42 | 
43 | impl QueryPair {
44 |     pub fn first() -> Self {
45 |         Self { start: 0, end: 1 }
46 |     }
47 | 
48 |     pub fn size(&self) -> wgpu::BufferAddress {
49 |         ((self.end - self.start + 1) as usize * std::mem::size_of::<u64>()) as wgpu::BufferAddress
50 |     }
51 | 
52 |     pub fn start_address(&self) -> wgpu::BufferAddress {
53 |         (self.start as usize * std::mem::size_of::<u64>()) as wgpu::BufferAddress
54 |     }
55 | 
56 |     pub fn end_address(&self) -> wgpu::BufferAddress {
57 |         ((self.end + 1) as usize * std::mem::size_of::<u64>()) as wgpu::BufferAddress
58 |     }
59 | }
60 | 
61 | impl From<QueryPair> for Range<u32> {
62 |     fn from(val: QueryPair) -> Self {
63 |         val.start..val.end + 1
64 |     }
65 | }
66 | 
67 | pub struct WgpuTimer {
68 |     handle: GPUHandle,
69 |     query_set: QuerySet,
70 |     resolve_buffer: wgpu::Buffer,
71 |     destination_buffer: wgpu::Buffer,
72 |     current_query: Cell<QueryPair>,
73 | }
74 | 
75 | //TODO: dumb
76 | unsafe impl Send for WgpuTimer {}
77 | unsafe impl Sync for WgpuTimer {}
78 | 
79 | impl WgpuTimer {
80 |     pub const COMPUTE_PER_QUERY: u64 = 100;
81 | 
82 |     pub fn new(handle: GPUHandle) -> Self {
83 |         let query_set = handle.device().create_query_set(&wgpu::QuerySetDescriptor {
84 |             count: MAX_QUERIES,
85 |             ty: wgpu::QueryType::Timestamp,
86 |             label: None,
87 |         });
88 | 
89 |         let size = MAX_QUERIES as u64 * std::mem::size_of::<u64>() as u64;
90 | 
91 |         let resolve_buffer = handle.device().create_buffer(&wgpu::BufferDescriptor {
92 |             label: None,
93 |             size,
94 |             usage: wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::QUERY_RESOLVE,
95 |             mapped_at_creation: false,
96 |         });
97 | 
98 |         let destination_buffer = handle.device().create_buffer(&wgpu::BufferDescriptor {
99 |             label: None,
100 |             size,
101 |             usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
102 |             mapped_at_creation: false,
103 |         });
104 | 
105 |         Self {
106 |             handle,
107 |             query_set,
108 |             resolve_buffer,
109 |             destination_buffer,
110 |             current_query: QueryPair::first().into(),
111 |         }
112 |     }
113 | 
114 |     pub fn resolve_pass(&self, encoder: &mut wgpu::CommandEncoder, pass_query: QueryPair) {
115 |         let resolution_range = pass_query.into();
116 |         log::trace!("Resolution range: {:?}", resolution_range);
117 |         encoder.resolve_query_set(&self.query_set, resolution_range, &self.resolve_buffer, 0);
118 |         let size = pass_query.size();
119 |         log::trace!("Resolution size in bytes: {:?}", size);
120 |         encoder.copy_buffer_to_buffer(&self.resolve_buffer, 0, &self.destination_buffer, 0, size);
121 |     }
122 | 
123 |     pub fn handle(&self) -> &GPUHandle {
124 |         &self.handle
125 |     }
126 | 
127 |     pub fn query_set(&self) -> &QuerySet {
128 |         &self.query_set
129 |     }
130 | 
131 |     pub fn increment_query(&self) {
132 |         let pair = self.current_query.get();
133 |         if pair.end + 2 >= MAX_QUERIES {
134 |             panic!("Number of queries exceeds MAX_QUERIES, reduce duration of benchmark");
135 |         }
136 |         self.current_query.set(QueryPair {
137 |             start: pair.start + 2,
138 |             end: pair.end + 2,
139 |         });
140 |     }
141 | 
142 |     pub fn current_query(&self) -> QueryPair {
143 |         self.current_query.get()
144 |     }
145 | 
146 |     //Fetches the current query as ComputePassTimestampWrites
147 |     pub fn timestamp_writes(&self) -> wgpu::ComputePassTimestampWrites {
148 |         wgpu::ComputePassTimestampWrites {
149 |             query_set: &self.query_set,
150 |             beginning_of_pass_write_index: Some(self.current_query().start),
151 |             end_of_pass_write_index: Some(self.current_query().end),
152 |         }
153 |     }
154 | 
155 |     pub fn hardware_elapsed(&self, timestamps: &[u64]) -> u64 {
156 |         assert!(timestamps.len() % 2 == 0);
157 |         let mut elapsed = 0;
158 |         for i in (0..timestamps.len()).step_by(2) {
159 |             elapsed += timestamps[i + 1] - timestamps[i];
160 |         }
161 |         elapsed
162 |     }
163 | }
164 | 
165 | impl Measurement for &WgpuTimer {
166 |     type Intermediate = u32; // Index of the start query
167 | 
168 |     type Value = u64; // Raw unscaled GPU counter
169 |                       // Must be multiplied by the timestamp period to get nanoseconds
170 | 
171 |     fn start(&self) -> Self::Intermediate {
172 |         log::trace!("\nQuery at start of pass: {:?}", self.current_query());
173 |         0
174 |     }
175 | 
176 |     fn end(&self, start_index: Self::Intermediate) -> Self::Value {
177 |         log::trace!("\nQuery at end of pass: {:?}", self.current_query());
178 |         let mut encoder = self
179 |             .handle
180 |             .device()
181 |             .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
182 | 
183 |         //Large window, eg 0..512
184 |         let pass_query = QueryPair {
185 |             start: start_index,
186 |             end: self.current_query().end - 2, //decrement here to counteract last iter
187 |         };
188 |         log::trace!("Pass range: {:?}", pass_query);
189 | 
190 |         self.resolve_pass(&mut encoder, pass_query);
191 |         self.handle().queue().submit(Some(encoder.finish()));
192 |         self.handle.device().poll(wgpu::Maintain::Wait);
193 | 
194 |         self.destination_buffer
195 |             .slice(..)
196 |             .map_async(wgpu::MapMode::Read, |_| ());
197 |         self.handle.device().poll(wgpu::Maintain::Wait);
198 |         let timestamps: Vec<u64> = {
199 |             let byte_range = pass_query.start_address()..pass_query.end_address();
200 |             let timestamp_view = self.destination_buffer.slice(byte_range).get_mapped_range();
201 |             (*bytemuck::cast_slice(&timestamp_view)).to_vec()
202 |         };
203 |         log::trace!("Timestamps: {:?}", timestamps);
204 |         self.destination_buffer.unmap();
205 |         self.current_query.set(QueryPair::first());
206 |         self.hardware_elapsed(&timestamps) / WgpuTimer::COMPUTE_PER_QUERY
207 |     }
208 | 
209 |     fn add(&self, v1: &Self::Value, v2: &Self::Value) -> Self::Value {
210 |         v1 + v2
211 |     }
212 | 
213 |     fn zero(&self) -> Self::Value {
214 |         0
215 |     }
216 | 
217 |     fn to_f64(&self, value: &Self::Value) -> f64 {
218 |         (self.handle.queue().get_timestamp_period() as f64) * (*value as f64)
219 |     }
220 | 
221 |     fn formatter(&self) -> &dyn ValueFormatter {
222 |         &WgpuTimerFormatter
223 |     }
224 | }
225 | 
226 | struct WgpuTimerFormatter;
227 | 
228 | impl ValueFormatter for WgpuTimerFormatter {
229 |     fn format_value(&self, value: f64) -> String {
230 |         format!("{:.4} ns", value)
231 |     }
232 | 
233 |     fn format_throughput(&self, throughput: &Throughput, value: f64) -> String {
234 |         match throughput {
235 |             Throughput::Bytes(b) => format!(
236 |                 "{:.4} GiB/s",
237 |                 (*b as f64) / (1024.0 * 1024.0 * 1024.0) / (value * 1e-9)
238 |             ),
239 |             Throughput::Elements(e) => {
240 |                 let gflop = (*e as f64) / 1e9;
241 |                 let seconds = value * 1e-9;
242 |                 let gigaflop_per_second = gflop / seconds;
243 |                 format!("{:.4} GFLOP/s", gigaflop_per_second)
244 |             }
245 |             _ => unreachable!(),
246 |         }
247 |     }
248 | 
249 |     fn scale_values(&self, _typical_value: f64, _values: &mut [f64]) -> &'static str {
250 |         "ns"
251 |     }
252 | 
253 |     /// TODO!
254 |     fn scale_throughputs(
255 |         &self,
256 |         _typical_value: f64,
257 |         throughput: &Throughput,
258 |         _values: &mut [f64],
259 |     ) -> &'static str {
260 |         match throughput {
261 |             Throughput::Bytes(_) => "GiB/s",
262 |             Throughput::Elements(_) => "elements/s",
263 |             _ => unreachable!(),
264 |         }
265 |     }
266 | 
267 |     fn scale_for_machines(&self, _values: &mut [f64]) -> &'static str {
268 |         "ns"
269 |     }
270 | }
271 | 
272 | #[macro_export]
273 | macro_rules! shape {
274 |     ($($x:expr),*$(,)*) => ({
275 |         use smallvec::smallvec;
276 |         $crate::Shape::new(smallvec![$($x,)*])
277 |     });
278 | }
279 | 
280 | #[cfg(test)]
281 | mod tests {
282 |     use crate::*;
283 | 
284 |     #[test]
285 |     pub fn pair_size() {
286 |         let query = QueryPair::first();
287 |         assert_eq!(query.size(), 16);
288 |     }
289 | }
290 | 
--------------------------------------------------------------------------------
/src/tensor.rs:
--------------------------------------------------------------------------------
1 | use bytemuck::NoUninit;
2 | use ndarray::Dimension;
3 | use numpy::ndarray::{ArrayD, ArrayViewD};
4 | use rand::{distributions::uniform::SampleUniform, prelude::SeedableRng, rngs::SmallRng};
5 | use rand_distr::{Distribution, StandardNormal};
6 | 
7 | use numpy::PyArrayDyn;
8 | use wgpu::{BindGroupEntry, BindingResource, BufferUsages};
9 | 
10 | use crate::storage::{CPUStorage, GPUStorage};
11 | use crate::DType;
12 | use crate::DataType;
13 | use crate::GPUHandle;
14 | use crate::{Shape, Storage};
15 | 
16 | #[derive(Clone)]
17 | pub struct Tensor<S: Storage> {
18 |     dt: DType,
19 |     shape: Shape,
20 |     storage: S,
21 | }
22 | 
23 | unsafe impl<S: Storage> Send for Tensor<S> {}
24 | unsafe impl<S: Storage> Sync for Tensor<S> {}
25 | 
26 | impl<S: Storage> Tensor<S> {
27 |     pub fn new(dt: DType, shape: Shape, storage: S) -> Self {
28 |         Self { dt, shape, storage }
29 |     }
30 | 
31 |     pub fn dt(&self) -> DType {
32 |         self.dt
33 |     }
34 | 
35 |     pub fn shape(&self) -> &Shape {
36 |         &self.shape
37 |     }
38 | 
39 |     pub fn storage(&self) -> &S {
40 |         &self.storage
41 |     }
42 | 
43 |     pub fn storage_mut(&mut self) -> &mut S {
44 |         &mut self.storage
45 |     }
46 | 
47 |     pub fn n_bytes(&self) -> usize {
48 |         self.shape().numel() * self.dt().size_of()
49 |     }
50 | 
51 |     pub fn into_inner(self) -> (DType, Shape, S) {
52 |         let Self { dt, shape, storage } = self;
53 |         (dt, shape, storage)
54 |     }
55 | }
56 | 
57 | pub type CPUTensor = Tensor<CPUStorage>;
58 | 
59 | impl CPUTensor {
60 |     pub unsafe fn uninitialized(dt: DType, shape: Shape, alignment: usize) -> anyhow::Result<Self> {
61 |         let bytes = shape.numel() * dt.size_of();
62 |         let layout = std::alloc::Layout::from_size_align(bytes, alignment)?;
63 |         let data = if bytes == 0 {
64 |             std::ptr::null()
65 |         } else {
66 |             let ptr = std::alloc::alloc(layout);
67 |             assert!(!ptr.is_null());
68 |             ptr
69 |         } as *mut u8;
70 |         let storage = CPUStorage::new(data, layout);
71 |         Ok(Tensor::new(dt, shape, storage))
72 |     }
73 | 
74 |     pub fn to_vec<T: DataType>(&self) -> anyhow::Result<Vec<T>> {
75 |         let bytes = self.storage().as_bytes();
76 |         let data = bytemuck::cast_slice(bytes);
77 |         Ok(data.to_vec())
78 |     }
79 | 
80 |     pub fn from_slice<T: DataType>(data: &[T], shape: Shape) -> Self {
81 |         assert_eq!(data.len(), shape.numel(), "from_slice data length mismatch");
82 |         let bytes: &[u8] = bytemuck::cast_slice(data);
83 |         let mut tensor =
84 |             unsafe { Tensor::uninitialized(T::dt(), shape, T::dt().size_of()).unwrap() };
85 |         tensor.storage_mut().as_bytes_mut().copy_from_slice(bytes);
86 |         tensor
87 |     }
88 | 
89 |     pub unsafe fn from_quantized<T: NoUninit, U: AsRef<[T]>>(
90 |         data: U,
91 |         shape: Shape,
92 |         dt: DType,
93 |     ) -> CPUTensor {
94 |         let raw_data = data.as_ref();
95 |         let data_bytes: &[u8] = bytemuck::cast_slice(raw_data);
96 |         let n_bytes = data_bytes.len();
97 | 
98 |         let layout = std::alloc::Layout::from_size_align(n_bytes, dt.size_of()).unwrap();
99 |         let data = if n_bytes == 0 {
100 |             std::ptr::null()
101 |         } else {
102 |             let ptr = std::alloc::alloc(layout);
103 |             assert!(!ptr.is_null());
104 |             ptr
105 |         } as *mut u8;
106 |         let storage = CPUStorage::new(data, layout);
107 |         let mut tensor = Tensor::new(dt, shape, storage);
108 |         tensor
109 |             .storage_mut()
110 |             .as_bytes_mut()
111 |             .copy_from_slice(data_bytes);
112 |         tensor
113 |     }
114 | 
115 |     pub fn randn<T: DataType>(shape: Shape) -> Self {
116 |         let mut rng = SmallRng::from_entropy();
117 |         let data = (0..shape.numel())
118 |             .map(|_| {
119 |                 let sample: f32 = StandardNormal.sample(&mut rng);
120 |                 T::from(sample).expect("Failed to convert sample")
121 |             })
122 |             .collect::<Vec<T>>();
123 |         Self::from_slice(&data, shape)
124 |     }
125 | 
126 |     pub fn zeros<D: DataType>(shape: Shape) -> Self {
127 |         let data = vec![D::zero(); shape.numel()];
128 |         Self::from_slice(&data, shape)
129 |     }
130 | 
131 |     pub fn into_gpu(self, handle: &GPUHandle) -> GPUTensor {
132 |         let storage = self.storage.to_gpu(handle);
133 |         GPUTensor::new(self.dt, self.shape.clone(), storage)
134 |     }
135 | 
136 |     pub unsafe fn into_array_unchecked<T: DataType>(self) -> ArrayD<T> {
137 |         self.to_array_view_unchecked::<T>().to_owned()
138 |     }
139 | 
140 |     pub unsafe fn to_array_view_unchecked<T: DataType>(&self) -> ArrayViewD<T> {
141 |         let inner = self.storage().inner();
142 |         if self.n_bytes() != 0 {
143 |             ArrayViewD::from_shape_ptr(self.shape().to_vec(), inner.0 as *const T)
144 |         } else {
145 |             ArrayViewD::from_shape(self.shape().to_vec(), &[]).unwrap()
146 |         }
147 |     }
148 | 
149 |     pub fn to_py<'s, 'p: 's, T: DataType + numpy::Element>(
150 |         &'s self,
151 |         py: &'p pyo3::Python<'p>,
152 |     ) -> &PyArrayDyn<T> {
153 |         use numpy::PyArray;
154 |         PyArray::from_owned_array(*py, unsafe { self.clone().into_array_unchecked::<T>() })
155 |     }
156 | 
157 |     pub fn fmt(&self) -> String {
158 |         format!("{}", unsafe { self.to_array_view_unchecked::<f32>() })
159 |     }
160 | 
161 |     pub fn debug_fmt(&self) -> String {
162 |         format!("{:?}", unsafe { self.to_array_view_unchecked::<f32>() })
163 |     }
164 | 
165 |     pub fn all_close(&self, other: &Self, atol: f32, rtol: f32) -> anyhow::Result<()> {
166 |         if self.shape() != other.shape() {
167 |             anyhow::bail!("Shape mismatch {:?} != {:?}", self.shape(), other.shape())
168 |         }
169 |         let ma = unsafe { self.to_array_view_unchecked::<f32>() };
170 |         let mb = unsafe { other.to_array_view_unchecked::<f32>() };
171 |         let mut elem_cnt = 0;
172 |         let mut fail_cnt = 0;
173 |         let mut total_error = 0f32;
174 |         let mut mae = -1f32;
175 |         let mut mae_idxs = Default::default();
176 |         ndarray::indices_of(&ma).into_iter().try_for_each(|idxs| {
177 |             let (a, b) = (ma[&idxs], mb[&idxs]);
178 |             let abs_diff = (a - b).abs();
179 |             let cur_mae = mae.max(abs_diff);
180 |             if cur_mae > mae {
181 |                 mae = cur_mae;
182 |                 mae_idxs = idxs.clone();
183 |             }
184 |             total_error += abs_diff;
185 |             elem_cnt += 1;
186 | 
187 |             if !((a.is_nan() && b.is_nan())
188 |                 || (a.is_infinite() && b.is_infinite() && a.signum() == b.signum())
189 |                 || abs_diff <= atol + rtol * b.abs())
190 |             {
191 |                 let slice = idxs.slice();
192 |                 log::trace!(
193 |                     "Mismatch at {:?}: {:?} != {:?} (atol={}, rtol={})",
194 |                     slice,
195 |                     a,
196 |                     b,
197 |                     atol,
198 |                     rtol
199 |                 );
200 |                 fail_cnt += 1;
201 |             }
202 |             Ok::<(), anyhow::Error>(())
203 |         })?;
204 |         let avg_error = total_error / elem_cnt as f32;
205 |         let slice = mae_idxs.slice();
206 |         if fail_cnt > 0 {
207 |             anyhow::bail!(
208 |                 "{} samples not close - AVGE={} MAE={} at {:?}",
209 |                 fail_cnt,
210 |                 avg_error,
211 |                 mae,
212 |                 slice,
213 |             );
214 |         } else {
215 |             println!("All close - AVGE={} MAE={} at {:?}", avg_error, mae, slice,);
216 |             Ok(())
217 |         }
218 |     }
219 | }
220 | 
221 | impl std::fmt::Debug for CPUTensor {
222 |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
223 |         f.debug_struct("CPUTensor")
224 |             .field("dt", &self.dt)
225 |             .field("shape", &self.shape)
226 |             .field("storage", &self.storage)
227 |             .finish()
228 |     }
229 | }
230 | 
231 | impl std::fmt::Display for CPUTensor {
232 |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233 |         f.write_str(&self.fmt())
234 |     }
235 | }
236 | 
237 | impl<T: DataType + numpy::Element> From<&PyArrayDyn<T>> for CPUTensor {
238 |     fn from(array: &PyArrayDyn<T>) -> Self {
239 |         Self::from(array.to_owned_array())
240 |     }
241 | }
242 | 
243 | impl<T: DataType + numpy::Element> From<PyArrayDyn<T>> for CPUTensor {
244 |     fn from(array: PyArrayDyn<T>) -> Self {
245 |         Self::from(array.to_owned_array())
246 |     }
247 | }
248 | 
249 | impl<T: DataType> From<ArrayD<T>> for CPUTensor {
250 |     fn from(it: ArrayD<T>) -> Self {
251 |         if it.as_slice().is_some() {
252 |             let layout = std::alloc::Layout::from_size_align(
253 |                 it.len() * std::mem::size_of::<T>(),
254 |                 std::mem::align_of::<T>(),
255 |             )
256 |             .unwrap();
257 |             let shape = it.shape().into();
258 |             let vec = it.into_raw_vec().into_boxed_slice();
259 |             let data = Box::into_raw(vec) as *mut u8;
260 | 
261 |             Tensor::new(T::dt(), shape, CPUStorage::new(data, layout))
262 |         } else {
263 |             panic!("Cannot convert numpy array with non-contiguous memory layout to tensor");
264 |         }
265 |     }
266 | }
267 | 
268 | pub type GPUTensor = Tensor<GPUStorage>;
269 | 
270 | impl GPUTensor {
271 |     /// # Bindings
272 |     ///
273 |     /// Only applicable to GPU tensors.
274 |     /// Generates the bind group entries required to bind the tensor to a kernel.
275 |     /// Quantized tensors may use multiple bind groups.
276 |     /// Unquantized tensors should only use a single bind group.
277 |     pub(crate) fn bindings(&self, current_binding: usize) -> Vec<BindGroupEntry> {
278 |         let buf = self.storage().inner();
279 |         let numel = self.shape().numel();
280 |         let segments = self.dt().segments(numel, buf.size() as usize);
281 | 
282 |         let mut entries = vec![];
283 |         for (idx, seg) in segments.iter().enumerate() {
284 |             let (offset, size) = (seg.offset, seg.size);
285 |             entries.push(BindGroupEntry {
286 |                 binding: ((current_binding + idx) % 4) as _,
287 |                 resource: BindingResource::Buffer(wgpu::BufferBinding {
288 |                     buffer: buf,
289 |                     offset,
290 |                     size,
291 |                 }),
292 |             });
293 |         }
294 |         entries
295 |     }
296 | 
297 |     fn read_to_host<A: NoUninit>(shape: Shape, dt: DType, bytes: &[A]) -> CPUTensor {
298 |         match dt {
299 |             DType::F32 => CPUTensor::from_slice::<f32>(bytemuck::cast_slice(bytes), shape),
300 |             DType::I32 => CPUTensor::from_slice::<i32>(bytemuck::cast_slice(bytes), shape),
301 |             DType::U32 => CPUTensor::from_slice::<u32>(bytemuck::cast_slice(bytes), shape),
302 |             _ => panic!("Unsupported dtype"),
303 |         }
304 |     }
305 | 
306 |     fn into_cpu_inner(self, handle: &GPUHandle) -> anyhow::Result<CPUTensor> {
307 |         let (dt, shape, storage) = self.into_inner();
308 |         if !storage.usage().contains(BufferUsages::COPY_SRC) {
309 |             panic!("Attempted to read GPU tensor to host without COPY_SRC usage")
310 |         }
311 |         let buffer_slice = storage.slice(..);
312 |         let (tx, rx) = std::sync::mpsc::channel();
313 | 
314 |         wgpu::util::DownloadBuffer::read_buffer(
315 |             handle.device(),
316 |             handle.queue(),
317 |             &buffer_slice,
318 |             move |buffer| {
319 |                 // Called on download completed
320 |                 tx.send(match buffer {
321 |                     Ok(db) => Ok(Self::read_to_host(shape, dt, &db)),
322 |                     Err(error) => panic!("Failed to read GPU tensor to host: {:?}", error),
323 |                 })
324 |                 .unwrap();
325 |             },
326 |         );
327 |         handle.queue().submit(None);
328 |         handle.device().poll(wgpu::Maintain::Wait);
329 |         rx.recv().unwrap()
330 |     }
331 | 
332 |     /// Consumes the GPU tensor and returns a CPU tensor.
333 |     pub fn into_cpu(self, handle: &GPUHandle) -> anyhow::Result<CPUTensor> {
334 |         self.into_cpu_inner(handle)
335 |     }
336 | }
337 | 
--------------------------------------------------------------------------------
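For reference, the pieces above compose as follows: `CPUTensor` owns host memory, `into_gpu` uploads it through a `GPUHandle`, and `into_cpu` reads it back via `wgpu::util::DownloadBuffer` so results can be compared with `all_close`. Below is a minimal round-trip sketch that uses only the API shown in this listing; it is not part of the repository, and it assumes the `wgpu_bencher` crate name used by the benches, an f32 tensor, and that `CPUStorage::to_gpu` creates the buffer with `COPY_SRC` usage.

```rust
use wgpu_bencher::{shape, CPUTensor, GPUHandle};

fn main() -> anyhow::Result<()> {
    // Acquire a device/queue pair, as the benches do with pollster.
    let handle = pollster::block_on(GPUHandle::new()).unwrap();

    // Random host tensor (f32 assumed), uploaded and then read straight back.
    let cpu = CPUTensor::randn::<f32>(shape![1, 256, 256]);
    let gpu = cpu.clone().into_gpu(&handle);
    let roundtrip = gpu.into_cpu(&handle)?;

    // The round trip should preserve the bytes, so tight tolerances are fine.
    cpu.all_close(&roundtrip, 1e-6, 1e-6)?;
    Ok(())
}
```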