├── .gitignore ├── assets └── img.png ├── src ├── gpu │ ├── mod.rs │ ├── wgsl │ │ ├── matmul.wgsl │ │ ├── slice.wgsl │ │ ├── binop.wgsl │ │ └── reduce_to_scalar.wgsl │ ├── context.rs │ ├── op_type.rs │ └── gpu_array.rs ├── optim.rs ├── lib.rs ├── error.rs ├── nn │ ├── functions │ │ └── mod.rs │ └── mod.rs ├── traits.rs ├── array.rs └── tensor.rs ├── .idea ├── vcs.xml ├── .gitignore ├── modules.xml └── ss_tensoria.iml ├── Cargo.toml └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | -------------------------------------------------------------------------------- /assets/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ariaghora/tensoria/HEAD/assets/img.png -------------------------------------------------------------------------------- /src/gpu/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod context; 2 | pub mod gpu_array; 3 | pub mod op_type; 4 | -------------------------------------------------------------------------------- /src/optim.rs: -------------------------------------------------------------------------------- 1 | use std::ops::{Add, Sub}; 2 | 3 | pub fn step(_params: Vec, _grads: Vec) {} 4 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | pub use std::ops::{Add, Div, Mul, Sub}; 2 | 3 | pub mod array; 4 | pub mod error; 5 | pub mod gpu; 6 | pub mod nn; 7 | pub mod optim; 8 | pub mod tensor; 9 | pub mod traits; 10 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/ss_tensoria.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::{Display, Formatter}; 2 | 3 | #[derive(Debug)] 4 | pub enum TensoriaError { 5 | CannotReshapeError, 6 | AccessingMismatchedType, 7 | DeviceNotCreated, 8 | BackwardOnTensorWithNoGrad, 9 | AlreadyGPUTensor, 10 | } 11 | 12 | impl Display for TensoriaError { 13 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 14 | write!(f, "{}", self.to_string()) 15 | } 16 | } 17 | 18 | impl std::error::Error for TensoriaError {} 19 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | 
name = "tensoria" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [profile.release] 7 | strip = true 8 | 9 | [dependencies] 10 | uuid = { version = "1.6.1", features = ["v4"] } 11 | ndarray = { version = "0.15.6" } 12 | wgpu = { version = "0.18.0", features = ["wgsl", "vulkan-portability"] } 13 | bytemuck = "1.14.0" 14 | pollster = "0.3.0" 15 | flume = "0.11.0" 16 | include_dir = "0.7.3" 17 | tera = "1.19.1" 18 | rand = "0.8.5" 19 | lazy_static = "1.4.0" 20 | num-traits = "0.2.17" 21 | num-integer = "0.1.45" -------------------------------------------------------------------------------- /src/nn/functions/mod.rs: -------------------------------------------------------------------------------- 1 | use crate::gpu::gpu_array::GetType; 2 | use crate::tensor::Tensor; 3 | use crate::traits::TensoriaOps; 4 | 5 | pub fn softmax_unstable(x: &Tensor, axis: usize) -> Tensor 6 | where 7 | Vec: GetType, 8 | { 9 | let nom = x.exp(); 10 | &nom / &nom.sum(Some(axis), true) 11 | } 12 | 13 | #[cfg(test)] 14 | mod test { 15 | use super::softmax_unstable; 16 | use crate::{error::TensoriaError, tensor::Tensor}; 17 | 18 | #[test] 19 | fn softmax() -> Result<(), TensoriaError> { 20 | let mut x = Tensor::new([3, 2], vec![2., 2., 4., 4., 6., 6.])?; 21 | x.set_requires_grad(true); 22 | 23 | let res = softmax_unstable(&x, 1); 24 | assert_eq!(res.to_vec(), vec![0.5; 6]); 25 | res.backward()?; 26 | 27 | assert_eq!(x.grad_vec().unwrap(), vec![0.; 6]); 28 | Ok(()) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/traits.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Debug; 2 | 3 | use bytemuck::Pod; 4 | use num_traits::{FromPrimitive, Num, NumCast, NumOps, Zero}; 5 | 6 | use crate::gpu::gpu_array::GetType; 7 | 8 | pub trait TensoriaOps: 9 | Sized 10 | + Clone 11 | + Num 12 | + NumCast 13 | + NumOps 14 | + PartialOrd 15 | + Default 16 | + Zero 17 | + FromPrimitive 18 | + Clone 19 | + Pod 20 | + Default 21 | + Debug 22 | { 23 | } 24 | 25 | impl TensoriaOps for T 26 | where 27 | T: Sized 28 | + Clone 29 | + Num 30 | + NumCast 31 | + NumOps 32 | + PartialOrd 33 | + Default 34 | + Zero 35 | + FromPrimitive 36 | + Clone 37 | + Pod 38 | + Default 39 | + Debug, 40 | Vec: GetType, 41 | { 42 | } 43 | 44 | pub trait GPUType: Clone + Pod + Default + Debug {} 45 | 46 | impl GPUType for T 47 | where 48 | T: Clone + Pod + Default + Debug, 49 | Vec: GetType, 50 | { 51 | } 52 | -------------------------------------------------------------------------------- /src/gpu/wgsl/matmul.wgsl: -------------------------------------------------------------------------------- 1 | @group(0) @binding(0) 2 | var input_0: array<{{input_0_type}}>; 3 | 4 | @group(0) @binding(1) 5 | var input_1: array<{{input_1_type}}>; 6 | 7 | @group(0) @binding(2) 8 | var output_0: array<{{output_0_type}}>; 9 | 10 | @compute @workgroup_size(16, 16) 11 | fn main(@builtin(global_invocation_id) gid: vec3) { 12 | var gx: u32 = gid.x; 13 | var gy: u32 = gid.y; 14 | var M: u32 = {{M}}u; 15 | var N: u32 = {{N}}u; 16 | var K: u32 = {{K}}u; 17 | 18 | if (gx >= N || gy >= M) { return; } 19 | 20 | var sum: {{output_0_type}} = {{output_0_type}}(0); 21 | for (var k: u32 = 0u; k < K; k += 4u) { 22 | sum += input_0[gy * K + k] * input_1[k * N + gx]+ 23 | input_0[gy * K + (k + 1u)] * input_1[(k + 1u) * N + gx] + 24 | input_0[gy * K + (k + 2u)] * input_1[(k + 2u) * N + gx] + 25 | input_0[gy * K + (k + 3u)] * input_1[(k + 3u) * N + gx]; 26 | } 27 | output_0[gy * N + gx] = sum; 
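    // Note: the manually unrolled loop above consumes K in steps of four, which
    // appears to assume K is a multiple of 4; other sizes would need a remainder
    // loop or zero-padding of the inputs.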
28 | } -------------------------------------------------------------------------------- /src/gpu/wgsl/slice.wgsl: -------------------------------------------------------------------------------- 1 | @group(0) @binding(0) 2 | var input: array<{{input_type}}>; 3 | 4 | @group(0) @binding(1) 5 | var indices: array<{{indices_type}}>; 6 | 7 | @group(0) @binding(2) 8 | var output: array<{{output_type}}>; 9 | 10 | 11 | @compute @workgroup_size(256) 12 | fn main(@builtin(global_invocation_id) gid: vec3) { 13 | var input_shape = array( {{input_shape_csv}} ); 14 | var input_strides = array( {{input_strides_csv}} ); 15 | var output_shape = array( {{output_shape_csv}} ); 16 | var output_strides = array( {{output_strides_csv}} ); 17 | var slicing_axis = {{slicing_axis}}; 18 | 19 | var offset = i32(gid.x); 20 | if (offset >= {{output_len}}) { return; } 21 | 22 | var nd_index = array( {{nd_index_init}} ); 23 | for (var i = {{input_ndim}} - 1; i >= 0; i--) { 24 | if (i == slicing_axis) { 25 | nd_index[i] = indices[offset % input_shape[i]]; 26 | } else { 27 | nd_index[i] = offset % input_shape[i]; 28 | } 29 | offset /= input_shape[i]; 30 | } 31 | 32 | var logical_offset = 0; 33 | for (var i = 0; i < {{input_ndim}}; i++) { 34 | logical_offset += nd_index[i] * input_strides[i]; 35 | } 36 | 37 | output[gid.x] = input[logical_offset]; 38 | } -------------------------------------------------------------------------------- /src/gpu/wgsl/binop.wgsl: -------------------------------------------------------------------------------- 1 | @group(0) @binding(0) 2 | var input_0: array<{{input_0_type}}>; 3 | 4 | @group(0) @binding(1) 5 | var input_1: array<{{input_1_type}}>; 6 | 7 | @group(0) @binding(2) 8 | var output_0: array<{{output_0_type}}>; 9 | 10 | const input_0_shape = array( {{input_0_shape_csv}} ); 11 | const input_0_strides = array( {{input_0_strides_csv}} ); 12 | const input_1_shape = array( {{input_1_shape_csv}} ); 13 | const input_1_strides = array( {{input_1_strides_csv}} ); 14 | const output_shape = array( {{output_shape_csv}} ); 15 | 16 | @compute @workgroup_size(16, 16) 17 | fn main(@builtin(global_invocation_id) gid: vec3) { 18 | let numel_x = u32(output_shape[0]); 19 | var idx: u32 = gid.y * numel_x + gid.x; 20 | 21 | if idx >= {{output_len}}u { return; } 22 | 23 | {% if left_broadcast -%} 24 | var idx0 = 0u; 25 | {{idx0_code}} 26 | {% else -%} 27 | var idx0 = idx; 28 | {% endif %} 29 | 30 | {%- if right_broadcast -%} 31 | var idx1 = 0u; 32 | {{idx1_code}} 33 | {%- else -%} 34 | var idx1 = idx; 35 | {% endif %} 36 | 37 | let lhs = input_0[idx0]; 38 | let rhs = input_1[idx1]; 39 | var out = {{output_0_type}}(0); 40 | 41 | // The `binop_stmt` is a placeholder in which we put the actual calculation 42 | // of the output `out`. Each implementation is done in the host code `src/gpu/op_type.rs`. 
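    // For example, the host substitutes `out = lhs + rhs;` for Add, `out = lhs * rhs;`
    // for Mul, `out = lhs - rhs;` for Sub, and `out = lhs / rhs;` for Div
    // (see the `define_elementwise_binop!` invocations in `src/gpu/op_type.rs`).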
43 | {{ binop_stmt }} 44 | 45 | output_0[idx] = out; 46 | } -------------------------------------------------------------------------------- /src/gpu/wgsl/reduce_to_scalar.wgsl: -------------------------------------------------------------------------------- 1 | alias input_type = {{input_type}}; 2 | alias output_type = {{output_type}}; 3 | 4 | @group(0) @binding(0) 5 | var input: array; 6 | 7 | @group(0) @binding(1) 8 | var output: array; 9 | 10 | // Local memory for partial sums 11 | var sums: array; 12 | 13 | @compute @workgroup_size(256) 14 | fn main(@builtin(global_invocation_id) global_id : vec3) { 15 | var local_sum: input_type = input_type(0); 16 | let group_size = 256u; 17 | let input_len = {{input_len}}u; 18 | let num_groups = (input_len + group_size - 1u) / group_size; 19 | let group_id = global_id.x / group_size; 20 | let local_id = global_id.x % group_size; 21 | 22 | // Each thread in the workgroup adds up its subset of the input 23 | for (var i = local_id; i < input_len; i += group_size) { 24 | let lhs = local_sum; 25 | let rhs = input[i]; 26 | let out = lhs + rhs; 27 | local_sum = out; 28 | } 29 | 30 | // Store the sum in shared memory 31 | sums[local_id] = local_sum; 32 | workgroupBarrier(); 33 | 34 | // Reduction in shared memory 35 | for (var stride = group_size / 2u; stride > 0u; stride /= 2u) { 36 | if (local_id < stride) { 37 | let lhs = sums[local_id]; 38 | let rhs = sums[local_id + stride]; 39 | var out: output_type; 40 | {{reduction_stmt}}; 41 | sums[local_id] = out; 42 | } 43 | workgroupBarrier(); 44 | } 45 | 46 | // Write the result to the output array 47 | if (local_id == 0u) { 48 | var out: output_type; 49 | out = sums[0]; 50 | {{postproc_stmt}}; 51 | output[group_id] = out; 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | T E N S O R I A 3 |

4 | 5 |

6 | ᕕ(⌐■_■)ᕗ ♪♬ 7 |

8 | 9 | --- 10 | 11 |

12 | An ergonomic tensor manipulation library running on the GPU, self-contained, in pure Rust 13 | 

14 | 15 | > At this moment, this library is meant to be the foundation for one of my research works. Only a **_very limited_** set of operations is supported. You may consider using [burn-rs](https://burn.dev/) for a more complete library, or even the [Rust bindings for PyTorch](https://github.com/LaurentMazare/tch-rs). 16 | 17 | ## Features 18 | 19 | - Supports GPU with CPU fallback. 20 | - Provides automatic gradient computation (autograd). 21 | - Allows creation of tensors with arbitrary dimensions at runtime. 22 | - Offers an ergonomic API. 23 | 24 | ## Note 25 | 26 | - As a trade-off for the easy API, shape checking for tensor operations occurs at runtime, potentially 27 | causing panics on incompatible shapes. 28 | - The internal implementation is not yet thread-safe, so please refrain from using this library in multithreaded programs. 29 | Consequently, when running `cargo test`, you need to pass `-- --test-threads=1`. 30 | 31 | ## Example 32 | 33 | ```rust 34 | use std::error::Error; 35 | use tensoria::tensor::Tensor; 36 | 37 | fn main() -> Result<(), Box<dyn Error>> { 38 | let x = Tensor::new([1, 2], vec![1., 2.])?; 39 | let y = Tensor::new([1, 2], vec![3., 4.])?; 40 | let res = &x + &y; 41 | assert_eq!(res.to_vec(), vec![4., 6.]); 42 | 43 | // Or use the GPU (via wgpu) if you wish, by calling `.to_gpu()`. 44 | // The tensor will then operate on a GPU array while maintaining 45 | // the same user-facing API. 46 | let x = Tensor::new([1, 2], vec![1., 2.])?.to_gpu()?; 47 | let y = Tensor::new([1, 2], vec![3., 4.])?.to_gpu()?; 48 | let res = &x + &y; 49 | assert_eq!(res.to_vec(), vec![4., 6.]); 50 | 51 | // Autograd... 52 | let mut x = Tensor::new([2, 2], vec![1, 2, 3, 4])?.to_gpu()?; 53 | x.set_requires_grad(true); 54 | 55 | let res = &(&x * &x) * &x; 56 | res.backward()?; 57 | assert_eq!(x.grad_vec().unwrap(), vec![3, 12, 27, 48]); 58 | 59 | Ok(()) 60 | } 61 | ``` -------------------------------------------------------------------------------- /src/gpu/context.rs: -------------------------------------------------------------------------------- 1 | use std::sync::{Arc, RwLock}; 2 | 3 | use uuid::Uuid; 4 | 5 | #[derive(Clone)] 6 | pub struct GPUContext { 7 | pub(crate) id: Uuid, 8 | pub(crate) executor: Arc<RwLock<Executor>>, 9 | } 10 | 11 | impl GPUContext { 12 | pub fn new() -> Self { 13 | Self { 14 | id: Uuid::new_v4(), 15 | executor: Arc::new(RwLock::new(Executor::new())), 16 | } 17 | } 18 | } 19 | 20 | pub struct Executor { 21 | pub(crate) synced: bool, 22 | pub(crate) device: wgpu::Device, 23 | pub(crate) encoder: wgpu::CommandEncoder, 24 | pub(crate) queue: wgpu::Queue, 25 | } 26 | 27 | impl Executor { 28 | pub fn new() -> Self { 29 | let (device, queue) = pollster::block_on(Self::create_device()); 30 | let encoder = 31 | device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); 32 | Self { 33 | synced: false, 34 | device, 35 | queue, 36 | encoder, 37 | } 38 | } 39 | 40 | async fn create_device() -> (wgpu::Device, wgpu::Queue) { 41 | let instance = wgpu::Instance::default(); 42 | 43 | let adapter = instance 44 | .request_adapter(&wgpu::RequestAdapterOptions::default()) 45 | .await 46 | .unwrap(); 47 | 48 | let mut limits = wgpu::Limits::default(); 49 | // limits.max_buffer_size = 256 << 25; 50 | // limits.max_storage_buffer_binding_size = 256 << 21; 51 | 52 | let features = adapter.features(); 53 | let (device, queue) = adapter 54 | .request_device( 55 | &wgpu::DeviceDescriptor { 56 | label: None, 57 | features, 58 | limits, 59 | }, 60 | None, 61 | ) 62 | .await 63 | .unwrap(); 64 | (device, queue) 65 | } 66 | 67 | pub(crate) fn 
sync(&mut self) { 68 | if self.synced { 69 | return; 70 | } 71 | 72 | // Actually poll GPU here 73 | let current_encoder = std::mem::replace( 74 | &mut self.encoder, 75 | self.device 76 | .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }), 77 | ); 78 | self.queue.submit(Some(current_encoder.finish())); 79 | self.device.poll(wgpu::Maintain::Wait); 80 | 81 | // update state 82 | self.synced = true; 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /src/nn/mod.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Debug; 2 | use std::marker::PhantomData; 3 | use std::ops::Add; 4 | use std::sync::{Arc, RwLock}; 5 | 6 | use bytemuck::Pod; 7 | use rand::distributions::{Distribution, Uniform}; 8 | 9 | use crate::error::TensoriaError; 10 | use crate::gpu::gpu_array::GetType; 11 | use crate::tensor::Tensor; 12 | use crate::traits::TensoriaOps; 13 | 14 | pub mod functions; 15 | 16 | pub trait Module { 17 | fn forward(&self, x: &Tensor) -> Tensor; 18 | fn to_gpu(&self) -> Result 19 | where 20 | Self: Sized; 21 | fn parameters(&self) -> Vec>>>; 22 | fn zero_grad(&mut self); 23 | } 24 | 25 | pub struct Linear { 26 | w: Arc>>, 27 | b: Arc>>, 28 | } 29 | 30 | impl Linear 31 | where 32 | T: TensoriaOps + Clone + Pod + Default + Debug, 33 | Vec: GetType, 34 | { 35 | pub fn new(in_size: usize, out_size: usize) -> Result { 36 | let mut rng = rand::thread_rng(); 37 | 38 | let xavier_limit = 6.0f32 / ((in_size + out_size) as f32).sqrt(); 39 | let uniform = Uniform::new(-xavier_limit, xavier_limit); 40 | 41 | let w_val = (0..in_size * out_size) 42 | .map(|_| T::from(uniform.sample(&mut rng)).unwrap()) 43 | .collect(); 44 | 45 | let mut w = Tensor::new([in_size, out_size], w_val)?; 46 | let mut b = Tensor::new([out_size], vec![T::from(0.0).unwrap(); out_size])?; 47 | w.set_requires_grad(true); 48 | b.set_requires_grad(true); 49 | Ok(Self { 50 | w: Arc::new(RwLock::new(w)), 51 | b: Arc::new(RwLock::new(b)), 52 | }) 53 | } 54 | } 55 | 56 | impl Module for Linear 57 | where 58 | T: TensoriaOps + Clone + Pod + Default + Debug, 59 | Vec: GetType, 60 | { 61 | fn forward(&self, x: &Tensor) -> Tensor { 62 | let w = self.w.read().unwrap(); 63 | let b = self.b.read().unwrap(); 64 | x.matmul(&w).add(&b) 65 | } 66 | 67 | fn to_gpu(&self) -> Result { 68 | Ok(Self { 69 | w: Arc::new(RwLock::new(self.w.read().unwrap().to_gpu()?)), 70 | b: Arc::new(RwLock::new(self.b.read().unwrap().to_gpu()?)), 71 | }) 72 | } 73 | 74 | fn parameters(&self) -> Vec>>> { 75 | vec![self.w.clone(), self.b.clone()] 76 | } 77 | 78 | fn zero_grad(&mut self) { 79 | self.w.write().unwrap().zero_grad(); 80 | self.b.write().unwrap().zero_grad(); 81 | } 82 | } 83 | 84 | pub struct Sequential { 85 | modules: Vec, 86 | _p: PhantomData, 87 | } 88 | 89 | impl Sequential 90 | where 91 | T: TensoriaOps + Clone + Pod + Default + Debug, 92 | M: Module, 93 | Vec: GetType, 94 | { 95 | pub fn new(modules: Vec) -> Self { 96 | let mut self_modules = vec![]; 97 | for m in modules { 98 | self_modules.push(m); 99 | } 100 | Self { 101 | modules: self_modules, 102 | _p: PhantomData, 103 | } 104 | } 105 | } 106 | 107 | impl Module for Sequential 108 | where 109 | M: Module, 110 | T: TensoriaOps + Clone + Pod + Default + Debug, 111 | Vec: GetType, 112 | { 113 | fn forward(&self, x: &Tensor) -> Tensor { 114 | let mut res = self.modules[0].forward(x); 115 | for m in &self.modules[1..] 
{ 116 | res = m.forward(&res); 117 | } 118 | res 119 | } 120 | 121 | fn to_gpu(&self) -> Result 122 | where 123 | Self: Sized, 124 | { 125 | todo!() 126 | } 127 | 128 | fn parameters(&self) -> Vec>>> { 129 | let mut params = vec![]; 130 | for module in &self.modules { 131 | for param in &module.parameters() { 132 | params.push(param.clone()) 133 | } 134 | } 135 | params 136 | } 137 | 138 | fn zero_grad(&mut self) { 139 | for module in &mut self.modules { 140 | module.zero_grad() 141 | } 142 | } 143 | } 144 | 145 | #[cfg(test)] 146 | mod test { 147 | use crate::error::TensoriaError; 148 | use crate::nn::{Linear, Module, Sequential}; 149 | use crate::tensor::Tensor; 150 | 151 | #[test] 152 | fn linear() -> Result<(), TensoriaError> { 153 | let x = Tensor::new([10, 10], vec![1.0; 100])?; 154 | let linear = Linear::new(10, 2)?; 155 | let res = linear.forward(&x); 156 | 157 | assert_eq!(res.shape(), vec![10, 2]); 158 | 159 | let linear = linear.to_gpu()?; 160 | let res = linear.forward(&x.to_gpu()?); 161 | assert_eq!(res.shape(), vec![10, 2]); 162 | 163 | Ok(()) 164 | } 165 | 166 | #[test] 167 | fn seq() -> Result<(), TensoriaError> { 168 | let x = Tensor::new([1, 2], vec![1.0, 1.0])?; 169 | let seq = Sequential::new(vec![Linear::new(2, 2)?, Linear::new(2, 1)?]); 170 | _ = seq.forward(&x); 171 | Ok(()) 172 | } 173 | } 174 | -------------------------------------------------------------------------------- /src/gpu/op_type.rs: -------------------------------------------------------------------------------- 1 | use tera::Context; 2 | 3 | use crate::gpu::gpu_array::{compute_broadcasted_shape_and_strides, GPUArray}; 4 | 5 | pub trait Shader { 6 | fn shader_path(&self) -> String; 7 | fn prepare( 8 | &self, 9 | operands: Vec<&GPUArray>, 10 | output: &GPUArray, 11 | params: &mut Context, 12 | ) -> (u32, u32, u32); 13 | } 14 | 15 | macro_rules! define_elementwise_binop { 16 | ($struct_name: ident, $binop_stmt: expr) => { 17 | pub struct $struct_name {} 18 | impl Shader for $struct_name { 19 | fn shader_path(&self) -> String { 20 | "binop.wgsl".into() 21 | } 22 | fn prepare( 23 | &self, 24 | operands: Vec<&GPUArray>, 25 | output: &GPUArray, 26 | params: &mut Context, 27 | ) -> (u32, u32, u32) { 28 | prepare_binop_broadcast_shader(operands, output, params, $binop_stmt) 29 | } 30 | } 31 | }; 32 | } 33 | 34 | macro_rules! 
define_reduce_to_scalar { 35 | ($struct_name: ident, $reduction_stmt:expr, $postproc_stmt:expr) => { 36 | pub struct $struct_name {} 37 | impl Shader for $struct_name { 38 | fn shader_path(&self) -> String { 39 | "reduce_to_scalar.wgsl".into() 40 | } 41 | fn prepare( 42 | &self, 43 | operands: Vec<&GPUArray>, 44 | output: &GPUArray, 45 | params: &mut Context, 46 | ) -> (u32, u32, u32) { 47 | prepare_reduction_shader(operands, output, params, $reduction_stmt, $postproc_stmt) 48 | } 49 | } 50 | }; 51 | } 52 | 53 | define_elementwise_binop!(Add, "out = lhs + rhs;"); 54 | define_elementwise_binop!(Mul, "out = lhs * rhs;"); 55 | define_elementwise_binop!(Sub, "out = lhs - rhs;"); 56 | define_elementwise_binop!(Div, "out = lhs / rhs;"); 57 | 58 | define_reduce_to_scalar!( 59 | Mean, 60 | "out = lhs + rhs;", 61 | "out = out / output_type(input_len);" 62 | ); 63 | define_reduce_to_scalar!(Sum, "out = lhs + rhs;", ""); 64 | 65 | pub struct MatMul {} 66 | 67 | impl Shader for MatMul { 68 | fn shader_path(&self) -> String { 69 | "matmul.wgsl".into() 70 | } 71 | 72 | fn prepare( 73 | &self, 74 | operands: Vec<&GPUArray>, 75 | output: &GPUArray, 76 | params: &mut Context, 77 | ) -> (u32, u32, u32) { 78 | let m = operands[0].shape[0]; 79 | let n = operands[1].shape[1]; 80 | let k = operands[1].shape[0]; 81 | 82 | params.insert("input_0_type", &operands[0].data_type.wgsl_type()); 83 | params.insert("input_1_type", &operands[1].data_type.wgsl_type()); 84 | params.insert("output_0_type", &output.data_type.wgsl_type()); 85 | params.insert("M", &m); 86 | params.insert("N", &n); 87 | params.insert("K", &k); 88 | 89 | let local_size_x_y = 16; 90 | let num_workgroups_x = (n + local_size_x_y - 1) / local_size_x_y; 91 | let num_workgroups_y = (m + local_size_x_y - 1) / local_size_x_y; 92 | let wg = (num_workgroups_x as u32, num_workgroups_y as u32, 1); 93 | return wg; 94 | } 95 | } 96 | 97 | pub struct Slice { 98 | slice_axis: i32, 99 | } 100 | 101 | impl Slice { 102 | pub fn new(axis: i32) -> Self { 103 | Slice { slice_axis: axis } 104 | } 105 | } 106 | 107 | impl Shader for Slice { 108 | fn shader_path(&self) -> String { 109 | "slice.wgsl".into() 110 | } 111 | 112 | fn prepare( 113 | &self, 114 | operands: Vec<&GPUArray>, 115 | output: &GPUArray, 116 | params: &mut Context, 117 | ) -> (u32, u32, u32) { 118 | params.insert("input_type", &operands[0].data_type.wgsl_type()); 119 | params.insert("indices_type", &operands[1].data_type.wgsl_type()); 120 | params.insert("output_type", &output.data_type.wgsl_type()); 121 | // Input shape now adjusted to be after slice, i.e., similar to that of the output 122 | params.insert("input_shape_csv", &vec_to_csv(&output.shape)); 123 | params.insert("input_strides_csv", &vec_to_csv(&operands[0].strides)); 124 | params.insert("input_ndim", &operands[0].shape.len()); 125 | params.insert( 126 | "indices_len", 127 | &operands[1].shape.iter().fold(1, |x, y| x * y), 128 | ); 129 | params.insert("output_shape_csv", &vec_to_csv(&output.shape)); 130 | params.insert("output_strides_csv", &vec_to_csv(&output.strides)); 131 | params.insert("output_ndim", &output.shape.len()); 132 | params.insert("output_len", &output.shape.iter().fold(1, |x, y| x * y)); 133 | params.insert("slicing_axis", &self.slice_axis); 134 | params.insert( 135 | "nd_index_init", 136 | &vec_to_csv(&vec![0; operands[0].shape.len()]), 137 | ); 138 | 139 | let local_size_x = 256; 140 | let out_shape = &output.shape; 141 | let num_elements = out_shape.iter().fold(1, |x, y| x * y); 142 | let num_workgroups_x = (num_elements 
+ local_size_x - 1) / local_size_x; 143 | (num_workgroups_x as u32, 1, 1) 144 | } 145 | } 146 | 147 | fn vec_to_csv(shape: &Vec) -> String { 148 | shape 149 | .iter() 150 | .map(|v| v.to_string()) 151 | .collect::>() 152 | .join(",") 153 | } 154 | 155 | fn generate_idx_code( 156 | idx_var_name: &str, 157 | shape: &Vec, 158 | adjusted_strides: &Vec, 159 | ) -> String { 160 | let mut code = String::new(); 161 | let mut division_products = vec![1; shape.len()]; 162 | 163 | // Precompute division products in reverse order 164 | for i in (0..shape.len() - 1).rev() { 165 | division_products[i] = division_products[i + 1] * shape[i + 1]; 166 | } 167 | 168 | let mut terms = Vec::new(); 169 | for i in 0..shape.len() { 170 | let term = format!( 171 | "((idx / {}u) % {}u) * {}u", 172 | division_products[i], shape[i], adjusted_strides[i] 173 | ); 174 | terms.push(term); 175 | } 176 | 177 | let compiled_terms = terms.join(" + "); 178 | code.push_str(&format!("{} = {};\n", idx_var_name, compiled_terms)); 179 | 180 | code 181 | } 182 | 183 | /// TODO: Handle specific case: 184 | /// - tensor-scalar 185 | /// - scalar-tensor 186 | fn prepare_binop_broadcast_shader( 187 | operands: Vec<&GPUArray>, 188 | output: &GPUArray, 189 | params: &mut Context, 190 | binop_stmt: &str, 191 | ) -> (u32, u32, u32) { 192 | params.insert("input_0_type", &operands[0].data_type.wgsl_type()); 193 | params.insert("input_1_type", &operands[1].data_type.wgsl_type()); 194 | params.insert("output_0_type", &output.data_type.wgsl_type()); 195 | 196 | let (shape0, shape1) = (&operands[0].shape, &operands[1].shape); 197 | let (strides1, strides2) = (&operands[0].strides, &operands[1].strides); 198 | 199 | let (adj_shape0, adj_shape1, adj_strides0, adj_strides1) = if shape0 == shape1 { 200 | ( 201 | shape0.clone(), 202 | shape1.clone(), 203 | strides1.clone(), 204 | strides2.clone(), 205 | ) 206 | } else { 207 | compute_broadcasted_shape_and_strides(&shape0, &shape1, &strides1, &strides2) 208 | }; 209 | 210 | let left_broadcast = shape0 != &adj_shape0; 211 | let right_broadcast = shape1 != &adj_shape1; 212 | if left_broadcast { 213 | params.insert("left_broadcast", &true); 214 | let left_numel = shape1.iter().fold(1, |x, y| x * y); 215 | // if lhs element length is 1 (scalar), then don't generate anything 216 | let idx0_code = if left_numel == 0 { 217 | "".into() 218 | } else { 219 | generate_idx_code("idx0", &adj_shape0, &adj_strides0) 220 | }; 221 | params.insert("idx0_code", &idx0_code); 222 | } 223 | if right_broadcast { 224 | params.insert("right_broadcast", &true); 225 | let right_numel = shape1.iter().fold(1, |x, y| x * y); 226 | // if rhs element length is 1 (scalar), then don't generate anything 227 | let idx1_code = if right_numel == 0 { 228 | "".into() 229 | } else { 230 | generate_idx_code("idx1", &adj_shape1, &adj_strides1) 231 | }; 232 | params.insert("idx1_code", &idx1_code); 233 | } 234 | 235 | params.insert("input_0_shape_csv", &vec_to_csv(&adj_shape0)); 236 | params.insert("input_0_strides_csv", &vec_to_csv(&adj_strides0)); 237 | params.insert("input_0_ndim", &adj_shape0.len()); 238 | params.insert("input_1_shape_csv", &vec_to_csv(&adj_shape1)); 239 | params.insert("input_1_strides_csv", &vec_to_csv(&adj_strides1)); 240 | params.insert("input_1_ndim", &adj_shape1.len()); 241 | params.insert("output_shape_csv", &vec_to_csv(&output.shape)); 242 | params.insert("output_ndim", &output.shape.len()); 243 | params.insert("output_len", &output.shape.iter().fold(1, |x, y| x * y)); 244 | 245 | params.insert("binop_stmt", 
binop_stmt); 246 | 247 | let local_size_x = 16; // Use a smaller workgroup size for better distribution 248 | let local_size_y = 16; 249 | 250 | let out_shape = &output.shape; 251 | let num_elements = out_shape.iter().fold(1, |x, y| x * y); 252 | 253 | // Calculate the number of workgroups needed in each dimension 254 | let num_workgroups_x = ((out_shape[0] + local_size_x - 1) / local_size_x) as u32; 255 | let num_workgroups_y = ((num_elements / out_shape[0] + local_size_y - 1) / local_size_y) as u32; 256 | 257 | (num_workgroups_x, num_workgroups_y, 1) 258 | } 259 | 260 | fn prepare_reduction_shader( 261 | operands: Vec<&GPUArray>, 262 | output: &GPUArray, 263 | params: &mut Context, 264 | reduction_stmt: &str, 265 | postproc_stmt: &str, 266 | ) -> (u32, u32, u32) { 267 | let input_shape = &operands[0].shape; 268 | let input_len = input_shape.iter().fold(1, |x, y| x * y); 269 | params.insert("input_type", &operands[0].data_type.wgsl_type()); 270 | params.insert("output_type", &output.data_type.wgsl_type()); 271 | params.insert("input_len", &input_len); 272 | params.insert("reduction_stmt", &reduction_stmt); 273 | params.insert("postproc_stmt", &postproc_stmt); 274 | 275 | (1, 1, 1) 276 | } 277 | -------------------------------------------------------------------------------- /src/array.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Debug; 2 | use std::ops::{Div, Mul, Sub}; 3 | 4 | use bytemuck::Pod; 5 | use ndarray::{ArrayD, Axis, Ix2}; 6 | 7 | use crate::error::TensoriaError; 8 | use crate::gpu::gpu_array::{GPUArray, GetType}; 9 | use crate::traits::TensoriaOps; 10 | 11 | #[derive(PartialEq, Debug)] 12 | pub enum Device { 13 | CPU, 14 | GPU, 15 | } 16 | 17 | #[derive(Debug)] 18 | pub enum ArrayData { 19 | CPUArray(ArrayD), 20 | GPUArray(GPUArray), 21 | } 22 | 23 | impl ArrayData { 24 | pub(crate) fn device(&self) -> Device { 25 | match self { 26 | ArrayData::CPUArray(_) => Device::CPU, 27 | ArrayData::GPUArray(_) => Device::GPU, 28 | } 29 | } 30 | } 31 | 32 | impl ArrayData 33 | where 34 | EType: TensoriaOps + Clone + Pod + Default + Debug, 35 | Vec: GetType, 36 | { 37 | pub fn new_cpu>( 38 | shape: S, 39 | data: Vec, 40 | ) -> Result, TensoriaError> { 41 | Ok(ArrayData::CPUArray( 42 | ArrayD::from_shape_vec(shape.as_ref(), data) 43 | .map_err(|_| TensoriaError::CannotReshapeError {})?, 44 | )) 45 | } 46 | 47 | pub fn new_gpu>( 48 | shape: S, 49 | data: Vec, 50 | ) -> Result, TensoriaError> { 51 | let len = shape.as_ref().iter().product::(); 52 | if len != data.len() { 53 | return Err(TensoriaError::CannotReshapeError {}); 54 | } 55 | Ok(ArrayData::GPUArray(GPUArray::new( 56 | data, 57 | shape.as_ref().to_vec(), 58 | ))) 59 | } 60 | 61 | pub fn clone(&self) -> Self { 62 | match self { 63 | ArrayData::CPUArray(data) => Self::CPUArray(data.clone()), 64 | ArrayData::GPUArray(data) => Self::GPUArray(data.clone()), 65 | } 66 | } 67 | 68 | pub fn ndim(&self) -> usize { 69 | match self { 70 | ArrayData::CPUArray(data) => data.ndim(), 71 | ArrayData::GPUArray(data) => data.shape.len(), 72 | } 73 | } 74 | 75 | pub fn shape(&self) -> Vec { 76 | match self { 77 | ArrayData::CPUArray(data) => data.shape().to_vec(), 78 | ArrayData::GPUArray(data) => data.shape.clone(), 79 | } 80 | } 81 | } 82 | 83 | /// Following set of implementations are related to public arithmetic functions 84 | impl ArrayData 85 | where 86 | EType: TensoriaOps + Clone + Pod + Default + Debug, 87 | Vec: GetType, 88 | { 89 | fn arr_add(&self, other: &ArrayData) -> ArrayData { 
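        // Dispatch on the device of both operands: CPU arrays delegate to ndarray's
        // element-wise `+` (which broadcasts), GPU arrays to the `Add` impl of
        // `GPUArray`, and mixing devices is treated as a programming error and panics.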
90 | match (self, other) { 91 | (ArrayData::CPUArray(ldata), ArrayData::CPUArray(rdata)) => { 92 | ArrayData::CPUArray(ldata + rdata) 93 | } 94 | (ArrayData::GPUArray(ldata), ArrayData::GPUArray(rdata)) => { 95 | ArrayData::GPUArray(ldata + rdata) 96 | } 97 | _ => panic!("cannot add tensors from different device"), 98 | } 99 | } 100 | 101 | fn arr_div(&self, other: &ArrayData) -> ArrayData { 102 | match (self, other) { 103 | (ArrayData::CPUArray(ldata), ArrayData::CPUArray(rdata)) => { 104 | ArrayData::CPUArray(ldata / rdata) 105 | } 106 | (ArrayData::GPUArray(_ldata), ArrayData::GPUArray(_rdata)) => { 107 | todo!(); 108 | // ArrayData::GPUArray(ldata / rdata) 109 | } 110 | _ => panic!("cannot add tensors from different device"), 111 | } 112 | } 113 | fn arr_mul(&self, other: &ArrayData) -> ArrayData { 114 | match (self, other) { 115 | (ArrayData::CPUArray(ldata), ArrayData::CPUArray(rdata)) => { 116 | ArrayData::CPUArray(ldata.mul(rdata)) 117 | } 118 | (ArrayData::GPUArray(ldata), ArrayData::GPUArray(rdata)) => { 119 | ArrayData::GPUArray(ldata.mul(rdata)) 120 | } 121 | _ => panic!("cannot add tensors from different device"), 122 | } 123 | } 124 | fn arr_sub(&self, other: &ArrayData) -> ArrayData { 125 | match (self, other) { 126 | (ArrayData::CPUArray(ldata), ArrayData::CPUArray(rdata)) => { 127 | ArrayData::CPUArray(ldata.sub(rdata)) 128 | } 129 | (ArrayData::GPUArray(ldata), ArrayData::GPUArray(rdata)) => { 130 | ArrayData::GPUArray(ldata.sub(rdata)) 131 | } 132 | _ => panic!("cannot add tensors from different device"), 133 | } 134 | } 135 | 136 | pub fn div_scalar_f32(&self, other: f32) -> ArrayData { 137 | match self { 138 | ArrayData::CPUArray(data) => { 139 | let data_scaled = data 140 | .map(|v| EType::from(v.to_f32().unwrap() / other).unwrap()) 141 | .to_owned() 142 | .into_raw_vec(); 143 | ArrayData::new_cpu(data.shape(), data_scaled).unwrap() 144 | } 145 | ArrayData::GPUArray(data) => ArrayData::GPUArray(data.div(&GPUArray::new_with_ctx( 146 | &data.context, 147 | vec![EType::from(other).unwrap()], 148 | vec![1], 149 | ))), 150 | } 151 | } 152 | 153 | pub fn exp(&self) -> ArrayData { 154 | match self { 155 | ArrayData::CPUArray(data) => ArrayData::CPUArray( 156 | data.mapv(|v| { 157 | let vf = EType::from(EType::to_f32(&v).unwrap().exp()).unwrap(); 158 | vf 159 | }) 160 | .into_dyn(), 161 | ), 162 | ArrayData::GPUArray(_) => todo!(), 163 | } 164 | } 165 | pub fn powi(&self, exp: i32) -> ArrayData { 166 | match self { 167 | ArrayData::CPUArray(data) => ArrayData::CPUArray(data.mapv(|v| { 168 | let vf = v.to_f32().unwrap().powi(exp); 169 | EType::from(vf).unwrap() 170 | })), 171 | ArrayData::GPUArray(_) => todo!(), 172 | } 173 | } 174 | 175 | pub fn matmul(&self, other: &ArrayData) -> ArrayData { 176 | let l_ndim = self.ndim(); 177 | let r_ndim = other.ndim(); 178 | if (l_ndim != 2) && (r_ndim != 2) { 179 | panic!( 180 | "Both tensors must be of rank-2, but got rank-{} and rank-{} tensors", 181 | l_ndim, r_ndim 182 | ); 183 | } 184 | 185 | let (l_shape, r_shape) = (self.shape(), other.shape()); 186 | if l_shape[1] != r_shape[0] { 187 | panic!("Incompatible shape: {:?} and {:?}", l_shape, r_shape); 188 | } 189 | 190 | match (self, other) { 191 | (ArrayData::CPUArray(ldata), ArrayData::CPUArray(rdata)) => { 192 | let ldata_2d = ldata.to_owned().into_dimensionality::().unwrap(); 193 | let rdata_2d = rdata.to_owned().into_dimensionality::().unwrap(); 194 | 195 | ArrayData::CPUArray(ldata_2d.dot(&rdata_2d).into_dyn()) 196 | } 197 | (ArrayData::GPUArray(ldata), 
ArrayData::GPUArray(rdata)) => { 198 | ArrayData::GPUArray(ldata.matmul(rdata)) 199 | } 200 | _ => panic!("cannot add tensors from different device"), 201 | } 202 | } 203 | 204 | pub fn mean(&self, axis: Option, keep_dim: bool) -> ArrayData { 205 | match self { 206 | ArrayData::CPUArray(data) => { 207 | match axis { 208 | None => { 209 | // TODO: handle scalar rank-0 "array" 210 | let mu = data.mean().unwrap(); 211 | ArrayData::new_cpu([1], vec![mu]).unwrap() 212 | } 213 | Some(axis) => { 214 | let mut mu = data.mean_axis(Axis(axis)).unwrap(); 215 | if keep_dim { 216 | mu = mu.insert_axis(Axis(axis)); 217 | } 218 | ArrayData::CPUArray(mu) 219 | } 220 | } 221 | } 222 | ArrayData::GPUArray(data) => match axis { 223 | None => ArrayData::GPUArray(data.mean()), 224 | Some(axis) => ArrayData::GPUArray(data.mean_axis(axis as i32, keep_dim)), 225 | }, 226 | } 227 | } 228 | 229 | pub fn scale(&self, v: EType) -> ArrayData { 230 | match self { 231 | ArrayData::CPUArray(data) => Self::CPUArray(data.map(|item| *item * v)), 232 | ArrayData::GPUArray(_) => { 233 | todo!() 234 | } 235 | } 236 | } 237 | 238 | pub fn sum(&self, axis: Option, keep_dim: bool) -> ArrayData { 239 | match self { 240 | ArrayData::CPUArray(data) => { 241 | match axis { 242 | None => { 243 | // TODO: handle scalar rank-0 "array" 244 | let sum = data.sum(); 245 | ArrayData::new_cpu([1], vec![sum]).unwrap() 246 | } 247 | Some(axis) => { 248 | let mut sum = data.sum_axis(Axis(axis)); 249 | if keep_dim { 250 | sum = sum.insert_axis(Axis(axis)); 251 | } 252 | ArrayData::CPUArray(sum) 253 | } 254 | } 255 | } 256 | ArrayData::GPUArray(data) => match axis { 257 | None => ArrayData::GPUArray(data.sum()), 258 | Some(axis) => ArrayData::GPUArray(data.sum_axis(axis as i32, keep_dim)), 259 | }, 260 | } 261 | } 262 | pub fn t(&self) -> ArrayData { 263 | if self.ndim() != 2 { 264 | panic!( 265 | "Can only transpose a rank-2 tensor, got rank-{}", 266 | self.ndim() 267 | ); 268 | } 269 | match self { 270 | ArrayData::CPUArray(data) => ArrayData::CPUArray( 271 | data.to_owned() 272 | .into_dimensionality::() 273 | .unwrap() 274 | .t() 275 | .to_owned() 276 | .into_dyn(), 277 | ), 278 | ArrayData::GPUArray(_) => { 279 | todo!("Transpose GPUArray is not implemented yet") 280 | } 281 | } 282 | } 283 | } 284 | 285 | macro_rules! 
impl_bin_op { 286 | ($trait:ident, $trait_method:ident, $array_method:ident) => { 287 | impl std::ops::$trait<&Self> for ArrayData 288 | where 289 | EType: TensoriaOps + Clone + Pod + Default + Debug, 290 | Vec: GetType, 291 | { 292 | type Output = ArrayData; 293 | 294 | fn $trait_method(self, rhs: &Self) -> Self::Output { 295 | self.$array_method(rhs) 296 | } 297 | } 298 | 299 | impl std::ops::$trait for &ArrayData 300 | where 301 | EType: TensoriaOps + Clone + Pod + Default + Debug, 302 | Vec: GetType, 303 | { 304 | type Output = ArrayData; 305 | 306 | fn $trait_method(self, rhs: Self) -> Self::Output { 307 | self.$array_method(rhs) 308 | } 309 | } 310 | 311 | impl std::ops::$trait> for &ArrayData 312 | where 313 | EType: TensoriaOps + Clone + Pod + Default + Debug, 314 | Vec: GetType, 315 | { 316 | type Output = ArrayData; 317 | 318 | fn $trait_method(self, rhs: ArrayData) -> Self::Output { 319 | self.$array_method(&rhs) 320 | } 321 | } 322 | 323 | impl std::ops::$trait for ArrayData 324 | where 325 | EType: TensoriaOps + Clone + Pod + Default + Debug, 326 | Vec: GetType, 327 | { 328 | type Output = ArrayData; 329 | 330 | fn $trait_method(self, rhs: Self) -> Self::Output { 331 | self.$array_method(&rhs) 332 | } 333 | } 334 | 335 | impl std::ops::$trait for &ArrayData 336 | where 337 | EType: TensoriaOps + Clone + Pod + Default + Debug, 338 | Vec: GetType, 339 | { 340 | type Output = ArrayData; 341 | 342 | fn $trait_method(self, rhs: EType) -> Self::Output { 343 | let scalar = match self { 344 | ArrayData::CPUArray(_) => ArrayData::new_cpu([1], vec![rhs]).unwrap(), 345 | ArrayData::GPUArray(_) => ArrayData::new_gpu([1], vec![rhs]).unwrap(), 346 | }; 347 | self.$array_method(&scalar) 348 | } 349 | } 350 | 351 | impl std::ops::$trait for ArrayData 352 | where 353 | EType: TensoriaOps + Clone + Pod + Default + Debug, 354 | Vec: GetType, 355 | { 356 | type Output = ArrayData; 357 | 358 | fn $trait_method(self, rhs: EType) -> Self::Output { 359 | let scalar = match self { 360 | ArrayData::CPUArray(_) => ArrayData::new_cpu([1], vec![rhs]).unwrap(), 361 | ArrayData::GPUArray(_) => ArrayData::new_gpu([1], vec![rhs]).unwrap(), 362 | }; 363 | self.$array_method(&scalar) 364 | } 365 | } 366 | }; 367 | } 368 | 369 | impl_bin_op!(Add, add, arr_add); 370 | impl_bin_op!(Div, div, arr_div); 371 | impl_bin_op!(Mul, mul, arr_mul); 372 | impl_bin_op!(Sub, sub, arr_sub); 373 | -------------------------------------------------------------------------------- /src/tensor.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Debug; 2 | use std::ops::{Add, Div, Mul, Sub}; 3 | use std::sync::RwLockReadGuard; 4 | use std::sync::{Arc, RwLock}; 5 | 6 | use bytemuck::Pod; 7 | use uuid::Uuid; 8 | 9 | use crate::array::{ArrayData, Device}; 10 | use crate::error::TensoriaError; 11 | use crate::gpu::gpu_array::GetType; 12 | use crate::traits::TensoriaOps; 13 | 14 | pub struct TensorPointer { 15 | pub(crate) data: ArrayData, 16 | grad: Option>, 17 | deps: Vec>>>, 18 | grad_fn: Option>, 19 | } 20 | 21 | pub struct Tensor { 22 | // for debugging purpose 23 | id: Uuid, 24 | 25 | tp: Arc>>, 26 | requires_grad: bool, 27 | } 28 | 29 | pub(crate) type UnOpFn = Box) -> ArrayData>; 30 | pub(crate) type BinOpFn = 31 | Box, &ArrayData) -> ArrayData>; 32 | pub(crate) type GradFn = fn( 33 | old_grad: &ArrayData, 34 | parent_grad: &ArrayData, 35 | parent: &Arc>>, 36 | ) -> ArrayData; 37 | 38 | impl Tensor 39 | where 40 | EType: TensoriaOps + Clone + Pod + Default + Debug, 41 | Vec: 
GetType, 42 | { 43 | pub fn new>( 44 | shape: S, 45 | data: Vec, 46 | ) -> Result, TensoriaError> { 47 | let tp = Arc::new(RwLock::new(TensorPointer { 48 | data: ArrayData::new_cpu(shape, data)?, 49 | grad: None, 50 | grad_fn: None, 51 | deps: Default::default(), 52 | })); 53 | return Ok(Self { 54 | id: Uuid::new_v4(), 55 | tp, 56 | requires_grad: false, 57 | }); 58 | } 59 | 60 | pub fn new_gpu>( 61 | shape: S, 62 | data: Vec, 63 | ) -> Result, TensoriaError> { 64 | let tp = Arc::new(RwLock::new(TensorPointer { 65 | data: ArrayData::new_gpu(shape, data)?, 66 | grad: None, 67 | grad_fn: None, 68 | deps: Default::default(), 69 | })); 70 | return Ok(Self { 71 | id: Uuid::new_v4(), 72 | tp, 73 | requires_grad: false, 74 | }); 75 | } 76 | 77 | pub fn backward(&self) -> Result<(), TensoriaError> { 78 | if !self.requires_grad { 79 | return Err(TensoriaError::BackwardOnTensorWithNoGrad); 80 | } 81 | 82 | let shape = self.shape(); 83 | let num_el = shape.iter().fold(1, |x, y| x * y); 84 | let initial_grad_vec = vec![EType::from(1.).unwrap(); num_el]; 85 | let initial_grad = match self.device() { 86 | Device::CPU => ArrayData::new_cpu(shape, initial_grad_vec)?, 87 | Device::GPU => ArrayData::new_gpu(shape, initial_grad_vec)?, 88 | }; 89 | self.tp.write().unwrap().grad = Some(initial_grad); 90 | 91 | self.backward_from_tp(&self.tp); 92 | 93 | Ok(()) 94 | } 95 | 96 | fn backward_from_tp(&self, tp: &Arc>>) { 97 | // recursively invoke backward on each 98 | let parent_grad_opt = &tp.read().unwrap().grad; 99 | 100 | for dep in &tp.read().unwrap().deps { 101 | let new_grad = { 102 | let grad_fn_opt = &dep.read().unwrap().grad_fn; 103 | if let (Some(old_grad), Some(parent_grad), Some(grad_fn)) = 104 | (&dep.read().unwrap().grad, parent_grad_opt, grad_fn_opt) 105 | { 106 | // calculate new grad for dep 107 | let new_grad = grad_fn(old_grad, parent_grad, tp); 108 | Some(new_grad) 109 | } else { 110 | None 111 | } 112 | }; 113 | dep.write().unwrap().grad = new_grad; 114 | self.backward_from_tp(dep); 115 | } 116 | } 117 | 118 | pub fn shape(&self) -> Vec { 119 | match &self.tp.read().unwrap().data { 120 | ArrayData::CPUArray(arr) => arr.shape().to_vec(), 121 | ArrayData::GPUArray(arr) => arr.shape.to_vec(), 122 | } 123 | } 124 | 125 | pub fn zero_grad(&mut self) { 126 | if self.requires_grad { 127 | self.set_requires_grad(true) 128 | } 129 | } 130 | 131 | pub fn set_requires_grad(&mut self, val: bool) { 132 | self.requires_grad = val; 133 | if val { 134 | let shape = self.shape(); 135 | let numel = shape.iter().fold(1, |x, y| x * y); 136 | let zeros = vec![EType::default(); numel]; 137 | let zero_grad = Some(match &self.tp.read().unwrap().data { 138 | ArrayData::CPUArray(_) => ArrayData::new_cpu(&shape, zeros).unwrap(), 139 | ArrayData::GPUArray(_) => ArrayData::new_gpu(shape, zeros).unwrap(), 140 | }); 141 | 142 | self.tp.write().unwrap().grad = zero_grad; 143 | } else { 144 | self.tp.write().unwrap().grad = None; 145 | } 146 | } 147 | 148 | /// Fetch ArrayData with device type corresponding to Self. This method 149 | /// performs data copy. 
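    /// Prefer `to_vec()` when only the raw element values are needed, since it
    /// returns a plain `Vec<EType>` for both CPU and GPU tensors:
    ///
    /// ```ignore
    /// let t = Tensor::new([2], vec![1.0f32, 2.0]).unwrap();
    /// let arr = t.data();   // full ArrayData copy (CPU or GPU variant)
    /// let raw = t.to_vec(); // just the values: vec![1.0, 2.0]
    /// ```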
150 | pub fn data(&self) -> ArrayData { 151 | self.tp.read().unwrap().data.clone() 152 | } 153 | 154 | pub fn to_vec(&self) -> Vec { 155 | let data_ref = &self.tp.read().unwrap().data; 156 | match data_ref { 157 | ArrayData::CPUArray(data) => data.to_owned().into_raw_vec(), 158 | ArrayData::GPUArray(data) => data.to_vec(), 159 | } 160 | } 161 | 162 | pub fn device(&self) -> Device { 163 | self.tp.read().unwrap().data.device() 164 | } 165 | 166 | pub fn grad(&self) -> Option> { 167 | match &self.tp.read().unwrap().grad { 168 | None => None, 169 | Some(data) => Some(data.clone()), 170 | } 171 | } 172 | 173 | pub fn grad_vec(&self) -> Option> { 174 | let data_ref = &self.tp.read().unwrap().grad; 175 | match data_ref { 176 | None => None, 177 | Some(arr) => match arr { 178 | ArrayData::CPUArray(val) => Some(val.to_owned().into_raw_vec()), 179 | ArrayData::GPUArray(val) => Some(val.to_vec()), 180 | }, 181 | } 182 | } 183 | 184 | /// This will return a Tensor with data located in the GPU. 185 | /// This operation will detach the tensor from the graph. 186 | pub fn to_gpu(&self) -> Result { 187 | let data = &self.tp.read().unwrap().data; 188 | let res = match data { 189 | ArrayData::CPUArray(arr) => { 190 | let mut new_arr = Self { 191 | id: self.id, 192 | tp: Arc::new(RwLock::new(TensorPointer { 193 | data: ArrayData::new_gpu( 194 | arr.shape().to_vec(), 195 | arr.as_standard_layout().as_slice().unwrap().to_vec(), 196 | ) 197 | .unwrap(), 198 | grad: None, 199 | grad_fn: self.tp.read().unwrap().grad_fn, 200 | deps: vec![], 201 | })), 202 | requires_grad: false, 203 | }; 204 | 205 | new_arr.set_requires_grad(self.requires_grad); 206 | Ok(new_arr) 207 | } 208 | ArrayData::GPUArray(_) => Err(TensoriaError::AlreadyGPUTensor {}), 209 | }; 210 | res 211 | } 212 | 213 | pub fn update_data) -> ArrayData>(&self, update_fn: F) { 214 | let new_data = update_fn(&self.tp.read().unwrap().data); 215 | self.tp.write().unwrap().data = new_data; 216 | } 217 | } 218 | 219 | impl Tensor 220 | where 221 | EType: TensoriaOps + Clone + Pod + Default + Debug, 222 | Vec: GetType, 223 | { 224 | pub fn zeros>(_shape: Shape) -> Self { 225 | todo!() 226 | } 227 | } 228 | 229 | impl Tensor 230 | where 231 | EType: TensoriaOps + Clone + Pod + Default + Debug, 232 | Vec: GetType, 233 | { 234 | pub fn matmul(&self, other: &Tensor) -> Self { 235 | let lgf: Option> = Some(|lg, og, parent| { 236 | let parent = &parent.read().unwrap(); 237 | let rhs = &parent.deps[1].read().unwrap().data; 238 | lg.add(&og.matmul(&rhs.t())) 239 | }); 240 | let rgf: Option> = Some(|rg, og, parent| { 241 | let parent = &parent.read().unwrap(); 242 | let lhs = &parent.deps[0].read().unwrap().data; 243 | rg.add(&lhs.t().matmul(&og)) 244 | }); 245 | let add_fn: BinOpFn = Box::new(|a, b| a.matmul(&b)); 246 | self.tensor_binop(other, add_fn, lgf, rgf) 247 | } 248 | pub fn tensor_add(&self, other: &Tensor) -> Self { 249 | let lgf: Option> = Some(|lg, og, _| { 250 | let og = gradient_broadcasting(lg, og); 251 | lg.add(og) 252 | }); 253 | let rgf: Option> = Some(|rg, og, _| { 254 | let og = gradient_broadcasting(rg, og); 255 | rg.add(&og) 256 | }); 257 | let add_fn: BinOpFn = Box::new(|a, b| a.add(b)); 258 | self.tensor_binop(other, add_fn, lgf, rgf) 259 | } 260 | 261 | pub fn tensor_mul(&self, other: &Tensor) -> Self { 262 | let lgf: Option> = Some(|lg, og, parent| { 263 | let og = gradient_broadcasting(lg, og); 264 | let parent = &parent.read().unwrap(); 265 | let rhs = &parent.deps[1].read().unwrap().data; 266 | lg.add(&rhs.mul(&og)) 267 | }); 268 | let 
rgf: Option> = Some(|rg, og, parent| { 269 | let og = gradient_broadcasting(rg, og); 270 | let parent = &parent.read().unwrap(); 271 | let lhs = &parent.deps[0].read().unwrap().data; 272 | rg.add(&lhs.mul(&og)) 273 | }); 274 | let mul_fn: BinOpFn = Box::new(|a, b| a.mul(b)); 275 | self.tensor_binop(other, mul_fn, lgf, rgf) 276 | } 277 | 278 | pub fn tensor_div(&self, other: &Tensor) -> Self { 279 | let lgf: Option> = Some(|lg, og, parent| { 280 | let parent = &parent.read().unwrap(); 281 | let rhs = &parent.deps[1].read().unwrap().data; 282 | let local_grad = &og.div(rhs); 283 | lg.add(gradient_broadcasting(&lg, local_grad)) 284 | }); 285 | let rgf: Option> = Some(|rg, og, parent| { 286 | let parent = &parent.read().unwrap(); 287 | let lhs = &parent.deps[0].read().unwrap().data; 288 | let rhs = &parent.deps[1].read().unwrap().data; 289 | let local_grad = &lhs 290 | .mul(og.scale(EType::from(-1).unwrap())) 291 | .div(&rhs.powi(2)); 292 | rg.add(gradient_broadcasting(&rg, local_grad)) 293 | }); 294 | let div_fn: BinOpFn = Box::new(|a, b| a.div(b)); 295 | self.tensor_binop(other, div_fn, lgf, rgf) 296 | } 297 | 298 | pub fn tensor_sub(&self, other: &Tensor) -> Self { 299 | let lgf: Option> = Some(|lg, og, _| { 300 | let og = gradient_broadcasting(lg, og); 301 | lg.add(&og) 302 | }); 303 | let rgf: Option> = Some(|rg, og, _| { 304 | let og = gradient_broadcasting(rg, og); 305 | rg.sub(&og) 306 | }); 307 | let sub_fn: BinOpFn = Box::new(|a, b| a.sub(b)); 308 | self.tensor_binop(other, sub_fn, lgf, rgf) 309 | } 310 | 311 | pub fn exp(&self) -> Self { 312 | let exp_fn: UnOpFn = Box::new(move |data| data.exp()); 313 | let gf: Option> = Some(|g, og, parent| { 314 | let parent = &parent.read().unwrap(); 315 | let x = &parent.deps[0].read().unwrap().data; 316 | g.add(&og.mul(x.exp())) 317 | }); 318 | self.tensor_unop(exp_fn, gf) 319 | } 320 | 321 | pub fn mean(&self, axis: Option, keep_dim: bool) -> Self { 322 | let mean_fn: UnOpFn = Box::new(move |data| data.mean(axis, keep_dim)); 323 | 324 | let gf: Option> = Some(move |g, og, parent| { 325 | let shape = g.shape(); 326 | 327 | // Get axis data, that is the second dependency of the parent, i.e., deps[1]. 328 | // it will be -1 if no axis is specified and greater than or equals to zero otherwise. 329 | let axis = match &parent.read().unwrap().deps[1].read().unwrap().data { 330 | ArrayData::CPUArray(ax) => ax.first().unwrap().to_owned(), 331 | ArrayData::GPUArray(ax) => ax.to_vec()[0], 332 | }; 333 | 334 | let axis_is_specified = axis > EType::from(-1).unwrap(); 335 | let numel = if axis_is_specified { 336 | shape[axis.to_usize().unwrap()] 337 | } else { 338 | g.shape().iter().fold(1, |x, y| x * y) 339 | }; 340 | 341 | let og_mean = og.div_scalar_f32(numel as f32); 342 | if axis_is_specified { 343 | let og_mean_broadcast = gradient_broadcasting(g, &og_mean); 344 | return g.add(&og_mean_broadcast); 345 | } else { 346 | return g.add(&og_mean); 347 | } 348 | }); 349 | let res = self.tensor_unop(mean_fn, gf); 350 | 351 | // Hacky way to pass axis info to tensor's dependency, since GradFn is a function (not a closure) 352 | // and we want to avoid capturing outer scope of `gf`. 
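        // The axis is encoded as a one-element tensor appended to `deps`: -1 means
        // "reduce over all elements", while any value >= 0 names the reduced axis.
        // `gf` reads it back from `deps[1]` during `backward()`.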
353 | let t_axis = match axis { 354 | None => Self::new([1], vec![EType::from(-1).unwrap()]), 355 | Some(axis) => Self::new([1], vec![EType::from(axis).unwrap()]), 356 | } 357 | .unwrap(); 358 | res.tp.write().unwrap().deps.push(t_axis.tp); 359 | res 360 | } 361 | 362 | pub fn sum(&self, axis: Option, keep_dim: bool) -> Self { 363 | let sum_fn: UnOpFn = Box::new(move |data| data.sum(axis, keep_dim)); 364 | 365 | let gf: Option> = Some(|g, og, parent| { 366 | let axis = match &parent.read().unwrap().deps[1].read().unwrap().data { 367 | ArrayData::CPUArray(ax) => ax.first().unwrap().to_owned(), 368 | ArrayData::GPUArray(ax) => ax.to_vec()[0], 369 | }; 370 | let axis_is_specified = axis > EType::from(-1).unwrap(); 371 | if axis_is_specified { 372 | // The gradient of sum is broadcasting the gradient back to the shape of 'g'. 373 | let og_broadcast = gradient_broadcasting(g, &og); 374 | return g.add(&og_broadcast); 375 | } else { 376 | return g.add(og); 377 | } 378 | }); 379 | let res = self.tensor_unop(sum_fn, gf); 380 | 381 | // Hacky way to pass axis info to tensor's dependency, since GradFn is a function (not a closure) 382 | // and we want to avoid capturing outer scope of `gf`. 383 | let t_axis = match axis { 384 | None => Self::new([1], vec![EType::from(-1).unwrap()]), 385 | Some(axis) => Self::new([1], vec![EType::from(axis).unwrap()]), 386 | } 387 | .unwrap(); 388 | res.tp.write().unwrap().deps.push(t_axis.tp); 389 | res 390 | } 391 | 392 | pub(crate) fn tensor_binop( 393 | &self, 394 | other: &Tensor, 395 | binop_fn: BinOpFn, 396 | l_grad_fn: Option>, 397 | r_grad_fn: Option>, 398 | ) -> Self { 399 | self.tp.write().unwrap().grad_fn = l_grad_fn; 400 | other.tp.write().unwrap().grad_fn = r_grad_fn; 401 | 402 | let ldata = &self.tp.read().unwrap().data; 403 | let rdata = &other.tp.read().unwrap().data; 404 | let res_data = binop_fn(ldata, rdata); 405 | 406 | let requires_grad = self.requires_grad || other.requires_grad; 407 | 408 | let tp = Arc::new(RwLock::new(TensorPointer { 409 | data: res_data, 410 | deps: vec![self.tp.clone(), other.tp.clone()], 411 | grad: None, 412 | grad_fn: None, 413 | })); 414 | 415 | let mut res = Self { 416 | id: Uuid::new_v4(), 417 | tp, 418 | requires_grad, 419 | }; 420 | res.set_requires_grad(requires_grad); 421 | res 422 | } 423 | 424 | pub(crate) fn tensor_unop( 425 | &self, 426 | unop_fn: UnOpFn, 427 | grad_fn: Option>, 428 | ) -> Self { 429 | self.tp.write().unwrap().grad_fn = grad_fn; 430 | 431 | let data = &self.tp.read().unwrap().data; 432 | let res_data = unop_fn(data); 433 | 434 | let requires_grad = self.requires_grad; 435 | 436 | let tp = Arc::new(RwLock::new(TensorPointer { 437 | data: res_data, 438 | deps: vec![self.tp.clone()], 439 | grad: None, 440 | grad_fn: None, 441 | })); 442 | 443 | let mut res = Self { 444 | id: Uuid::new_v4(), 445 | tp, 446 | requires_grad, 447 | }; 448 | res.set_requires_grad(requires_grad); 449 | res 450 | } 451 | } 452 | 453 | fn gradient_broadcasting( 454 | self_grad: &ArrayData, 455 | out_grad: &ArrayData, 456 | ) -> ArrayData 457 | where 458 | EType: TensoriaOps + Clone + Pod + Default + Debug, 459 | Vec: GetType, 460 | { 461 | // Sum out added dims 462 | let mut out_grad = out_grad.clone(); 463 | let ndims_added = out_grad.ndim() as i32 - self_grad.ndim() as i32; 464 | for _ in 0..ndims_added { 465 | out_grad = out_grad.sum(Some(0), false); 466 | } 467 | 468 | // Sum across broadcasted but non-added dims 469 | for (i, dim) in self_grad.shape().iter().enumerate() { 470 | if dim == &1 { 471 | out_grad = 
out_grad.sum(Some(i), true); 472 | } 473 | } 474 | out_grad 475 | } 476 | 477 | /// Macro for several binary operators, so it is easier to implement it for both 478 | /// op(Self, &Self) and op(&Self, &Self) 479 | macro_rules! impl_bin_op { 480 | ($trait:ident, $trait_method:ident, $tensor_method:ident) => { 481 | impl std::ops::$trait<&Self> for Tensor 482 | where 483 | EType: TensoriaOps + Clone + Pod + Default + Debug, 484 | Vec: GetType, 485 | { 486 | type Output = Tensor; 487 | 488 | fn $trait_method(self, rhs: &Self) -> Self::Output { 489 | self.$tensor_method(rhs) 490 | } 491 | } 492 | 493 | impl std::ops::$trait for &Tensor 494 | where 495 | EType: TensoriaOps + Clone + Pod + Default + Debug, 496 | Vec: GetType, 497 | { 498 | type Output = Tensor; 499 | 500 | fn $trait_method(self, rhs: Self) -> Self::Output { 501 | self.$tensor_method(rhs) 502 | } 503 | } 504 | 505 | impl std::ops::$trait> for &Tensor 506 | where 507 | EType: TensoriaOps + Clone + Pod + Default + Debug, 508 | Vec: GetType, 509 | { 510 | type Output = Tensor; 511 | 512 | fn $trait_method(self, rhs: Tensor) -> Self::Output { 513 | self.$tensor_method(&rhs) 514 | } 515 | } 516 | 517 | impl std::ops::$trait for Tensor 518 | where 519 | EType: TensoriaOps + Clone + Pod + Default + Debug, 520 | Vec: GetType, 521 | { 522 | type Output = Tensor; 523 | 524 | fn $trait_method(self, rhs: Self) -> Self::Output { 525 | self.$tensor_method(&rhs) 526 | } 527 | } 528 | 529 | impl<'a, EType> std::ops::$trait> for RwLockReadGuard<'a, Tensor> 530 | where 531 | EType: TensoriaOps + Clone + Pod + Default + Debug, 532 | Vec: GetType, 533 | { 534 | type Output = Tensor; 535 | 536 | fn $trait_method(self, rhs: Tensor) -> Self::Output { 537 | self.$tensor_method(&rhs) 538 | } 539 | } 540 | 541 | impl<'a, EType> std::ops::$trait<&RwLockReadGuard<'a, Tensor>> for Tensor 542 | where 543 | EType: TensoriaOps + Clone + Pod + Default + Debug, 544 | Vec: GetType, 545 | { 546 | type Output = Tensor; 547 | 548 | fn $trait_method(self, rhs: &RwLockReadGuard<'a, Tensor>) -> Self::Output { 549 | self.$tensor_method(&rhs) 550 | } 551 | } 552 | }; 553 | } 554 | 555 | impl_bin_op!(Add, add, tensor_add); 556 | impl_bin_op!(Div, div, tensor_div); 557 | impl_bin_op!(Mul, mul, tensor_mul); 558 | impl_bin_op!(Sub, sub, tensor_sub); 559 | 560 | /// The suites are solely to test tensor's autograd mechanism 561 | #[cfg(test)] 562 | mod test { 563 | use std::ops::Sub; 564 | 565 | use crate::error::TensoriaError; 566 | use crate::tensor::{Device, Tensor}; 567 | 568 | #[test] 569 | fn add() -> Result<(), TensoriaError> { 570 | let x = Tensor::new([2, 2], vec![1., 2., 3., 4.])?; 571 | let mut y = Tensor::new([2], vec![1., 1.])?; 572 | y.set_requires_grad(true); 573 | (&x + &y).backward()?; 574 | assert_eq!(y.grad_vec(), Some(vec![2., 2.])); 575 | Ok(()) 576 | } 577 | 578 | #[test] 579 | fn add_gpu() -> Result<(), TensoriaError> { 580 | let x = Tensor::new([1, 2], vec![1., 2.])?.to_gpu()?; 581 | let mut y = Tensor::new([1, 2], vec![3., 4.])?.to_gpu()?; 582 | y.set_requires_grad(true); 583 | 584 | let res = &x + &y; 585 | assert_eq!(res.to_vec(), vec![4., 6.]); 586 | res.backward()?; 587 | assert_eq!(y.grad_vec(), Some(vec![1., 1.])); 588 | Ok(()) 589 | } 590 | 591 | #[test] 592 | fn div() -> Result<(), TensoriaError> { 593 | let mut x = Tensor::new([2, 2], vec![1., 2., 3., 4.])?; 594 | let mut y = Tensor::new([2], vec![1., 2.])?; 595 | x.set_requires_grad(true); 596 | y.set_requires_grad(true); 597 | let res = &x / &y; 598 | assert_eq!(res.to_vec(), vec![1., 1., 3., 
2.]); 599 | res.backward()?; 600 | assert_eq!(x.grad_vec(), Some(vec![1., 0.5, 1., 0.5])); 601 | assert_eq!(y.grad_vec(), Some(vec![-4., -1.5])); 602 | Ok(()) 603 | } 604 | 605 | #[test] 606 | fn sub() -> Result<(), TensoriaError> { 607 | let x = Tensor::new([1, 2], vec![1., 2.])?.to_gpu()?; 608 | let y = Tensor::new([1, 2], vec![3., 4.])?.to_gpu()?; 609 | let res = &x - &y; 610 | assert_eq!(res.to_vec(), vec![-2., -2.]); 611 | 612 | let x = Tensor::new([2], vec![1., 2.])?; 613 | let y = Tensor::new([2, 1], vec![1., 2.])?; 614 | let res_cpu = (&x - &y).to_vec(); 615 | assert_eq!(x.device(), Device::CPU); 616 | 617 | let x = Tensor::new([2], vec![1., 2.])?.to_gpu()?; 618 | let y = Tensor::new([2, 1], vec![1., 2.])?.to_gpu()?; 619 | let res_gpu = (&x - &y).to_vec(); 620 | assert_eq!(x.device(), Device::GPU); 621 | assert_eq!(res_cpu, res_gpu); 622 | 623 | Ok(()) 624 | } 625 | 626 | #[test] 627 | fn sub_gpu() -> Result<(), TensoriaError> { 628 | let x = Tensor::new([2], vec![1., 2.])?.to_gpu()?; 629 | let mut y = Tensor::new([2, 1], vec![1., 2.])?.to_gpu()?; 630 | y.set_requires_grad(true); 631 | let res_gpu = &x - &y; 632 | res_gpu.backward()?; 633 | assert_eq!(x.device(), Device::GPU); 634 | assert_eq!(y.grad_vec(), Some(vec![-2., -2.])); 635 | 636 | let x = Tensor::new([2, 2], vec![1., 2., 3., 4.])?.to_gpu()?; 637 | let mut y = Tensor::new([2, 2], vec![1., 2., 3., 4.])?.to_gpu()?; 638 | y.set_requires_grad(true); 639 | (x.sub(&y)).backward()?; 640 | assert_eq!(y.grad_vec(), Some(vec![-1.; 4])); 641 | 642 | Ok(()) 643 | } 644 | 645 | #[test] 646 | fn mul() -> Result<(), TensoriaError> { 647 | let mut x = Tensor::new([2, 2], vec![1, 2, 3, 4])?; 648 | x.set_requires_grad(true); 649 | 650 | let y = Tensor::new([2, 2], vec![2, 3, 4, 5])?; 651 | let res = &x * y; 652 | 653 | res.backward()?; 654 | assert_eq!(x.grad_vec().unwrap(), vec![2, 3, 4, 5]); 655 | Ok(()) 656 | } 657 | 658 | #[test] 659 | fn mul_gpu() -> Result<(), TensoriaError> { 660 | let mut x = Tensor::new([2, 2], vec![1, 2, 3, 4])?; 661 | x.set_requires_grad(true); 662 | (&(&x * &x) * &x).backward()?; 663 | assert_eq!(x.grad_vec().unwrap(), vec![3, 12, 27, 48]); 664 | 665 | let mut x = Tensor::new([2, 2], vec![1, 2, 3, 4])?.to_gpu()?; 666 | x.set_requires_grad(true); 667 | (&(&x * &x) * &x).backward()?; 668 | assert_eq!(x.grad_vec().unwrap(), vec![3, 12, 27, 48]); 669 | Ok(()) 670 | } 671 | 672 | #[test] 673 | fn matmul() -> Result<(), TensoriaError> { 674 | let mut x = Tensor::new([2, 2], vec![1, 2, 3, 4])?; 675 | x.set_requires_grad(true); 676 | 677 | let mut y = Tensor::new([2, 1], vec![1, 2])?; 678 | y.set_requires_grad(true); 679 | 680 | let res = x.matmul(&y); 681 | assert_eq!(res.to_vec(), vec![5, 11]); 682 | 683 | res.backward()?; 684 | assert_eq!(x.grad_vec(), Some(vec![1, 2, 1, 2])); 685 | assert_eq!(y.grad_vec(), Some(vec![4, 6])); 686 | Ok(()) 687 | } 688 | 689 | #[test] 690 | fn mean() -> Result<(), TensoriaError> { 691 | let mut x = Tensor::new([3], vec![4., 2., 3.])?; 692 | x.set_requires_grad(true); 693 | let res = x.mean(None, false); 694 | res.backward()?; 695 | assert_eq!(res.to_vec(), vec![3.]); 696 | assert_eq!(x.grad_vec(), Some(vec![1. / 3., 1. / 3., 1. / 3.])); 697 | 698 | let mut x = Tensor::new([2, 3], vec![1.; 6])?; 699 | x.set_requires_grad(true); 700 | let res = x.mean(None, false); 701 | res.backward()?; 702 | assert_eq!(res.to_vec(), vec![1.]); 703 | assert_eq!(x.grad_vec(), Some(vec![1. 
/ 6.; 6])); 704 | 705 | let mut x = Tensor::new([2, 3], vec![2., 2., 2., 2., 2., 2.])?; 706 | x.set_requires_grad(true); 707 | let res = x.mean(Some(0), true); 708 | res.backward()?; 709 | assert_eq!(res.shape(), vec![1, 3]); 710 | assert_eq!(x.grad_vec(), Some(vec![0.5; 6])); 711 | 712 | x.zero_grad(); 713 | let res = x.mean(Some(1), true); 714 | res.backward()?; 715 | assert_eq!(x.grad_vec(), Some(vec![1. / 3.; 6])); 716 | 717 | Ok(()) 718 | } 719 | 720 | #[test] 721 | fn mean_gpu() -> Result<(), TensoriaError> { 722 | let mut x = Tensor::new([3], vec![4., 2., 3.])?.to_gpu()?; 723 | x.set_requires_grad(true); 724 | let res = x.mean(None, false); 725 | res.backward()?; 726 | assert_eq!(res.to_vec(), vec![3.]); 727 | assert_eq!(x.grad_vec(), Some(vec![1. / 3., 1. / 3., 1. / 3.])); 728 | 729 | let mut x = Tensor::new([2, 3], vec![2., 2., 2., 2., 2., 2.])?.to_gpu()?; 730 | x.set_requires_grad(true); 731 | let res = x.mean(Some(0), true); 732 | res.backward()?; 733 | assert_eq!(res.shape(), vec![1, 3]); 734 | assert_eq!(x.grad_vec(), Some(vec![0.5; 6])); 735 | 736 | Ok(()) 737 | } 738 | 739 | #[test] 740 | fn sum() -> Result<(), TensoriaError> { 741 | let mut x = Tensor::new([1, 3], vec![1., 2., 3.])?; 742 | x.set_requires_grad(true); 743 | let res = x.sum(None, false); 744 | res.backward()?; 745 | assert_eq!(res.shape(), vec![1]); 746 | assert_eq!(x.grad_vec(), Some(vec![1.; 3])); 747 | 748 | let mut x = Tensor::new([1, 3], vec![1., 2., 3.])?; 749 | x.set_requires_grad(true); 750 | let res = x.sum(Some(1), false); 751 | res.backward()?; 752 | assert_eq!(res.shape(), vec![1]); 753 | assert_eq!(x.grad_vec(), Some(vec![1.; 3])); 754 | Ok(()) 755 | } 756 | 757 | #[test] 758 | fn sum_gpu() -> Result<(), TensoriaError> { 759 | let mut x = Tensor::new([1, 3], vec![1., 2., 3.])?.to_gpu()?; 760 | x.set_requires_grad(true); 761 | let res = x.sum(None, false); 762 | res.backward()?; 763 | assert_eq!(res.shape(), vec![1]); 764 | assert_eq!(x.grad_vec(), Some(vec![1.; 3])); 765 | 766 | let mut x = Tensor::new([1, 3], vec![1., 2., 3.])?.to_gpu()?; 767 | x.set_requires_grad(true); 768 | let res = x.sum(Some(1), false); 769 | res.backward()?; 770 | assert_eq!(res.shape(), vec![1]); 771 | assert_eq!(x.grad_vec(), Some(vec![1.; 3])); 772 | Ok(()) 773 | } 774 | 775 | #[test] 776 | fn exp() -> Result<(), TensoriaError> { 777 | let mut x = Tensor::new([3, 2], vec![2., 2., 4., 4., 6., 6.])?; 778 | x.set_requires_grad(true); 779 | 780 | let res = x.exp(); 781 | assert_eq!( 782 | res.to_vec(), 783 | x.to_vec().iter().map(|v| v.exp()).collect::<Vec<_>>() 784 | ); 785 | res.backward()?; 786 | 787 | assert_eq!( 788 | x.grad_vec().unwrap(), 789 | x.to_vec().iter().map(|v| v.exp()).collect::<Vec<_>>() 790 | ); 791 | Ok(()) 792 | } 793 | } 794 | -------------------------------------------------------------------------------- /src/gpu/gpu_array.rs: -------------------------------------------------------------------------------- 1 | use std::borrow::Cow; 2 | use std::fmt::{Debug, Formatter}; 3 | use std::ops::Add; 4 | use std::sync::{Arc, RwLock}; 5 | 6 | use bytemuck::Pod; 7 | use include_dir::{include_dir, Dir}; 8 | use lazy_static::lazy_static; 9 | use num_traits::{Num, NumCast}; 10 | use uuid::Uuid; 11 | use wgpu::util::DeviceExt; 12 | use wgpu::{BindGroupEntry, ComputePipeline}; 13 | 14 | use crate::gpu::context::{Executor, GPUContext}; 15 | use crate::gpu::op_type::{MatMul, Mean, Shader, Slice, Sum}; 16 | 17 | static PROJECT_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/src/gpu/wgsl/"); 18 |
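// The WGSL sources under src/gpu/wgsl/ are embedded into the binary at compile time via
// `include_dir`; `bin_op` and `un_op` below look them up in PROJECT_DIR and render them as
// tera templates before compiling the shader module.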
19 | lazy_static! { 20 | static ref GLOBAL_CTX: GPUContext = GPUContext::new(); 21 | } 22 | 23 | #[derive(Clone)] 24 | pub enum GPUDataType { 25 | F32, 26 | I32, 27 | } 28 | 29 | pub trait GetType { 30 | fn get_type(&self) -> GPUDataType; 31 | } 32 | 33 | impl GetType for Vec<f32> { 34 | fn get_type(&self) -> GPUDataType { 35 | GPUDataType::F32 36 | } 37 | } 38 | 39 | impl GetType for Vec<i32> { 40 | fn get_type(&self) -> GPUDataType { 41 | GPUDataType::I32 42 | } 43 | } 44 | 45 | impl GPUDataType { 46 | pub fn wgsl_type(&self) -> String { 47 | let dtype = match self { 48 | GPUDataType::F32 => "f32", 49 | GPUDataType::I32 => "i32", 50 | }; 51 | dtype.into() 52 | } 53 | } 54 | 55 | pub struct GPUArray<T> { 56 | pub id: String, 57 | pub data_type: GPUDataType, 58 | pub shape: Vec<usize>, 59 | pub strides: Vec<usize>, 60 | init_data: Option<Vec<T>>, 61 | initializer: bool, 62 | context_id: Uuid, 63 | pub(crate) context: GPUContext, 64 | executor: Arc<RwLock<Executor>>, 65 | main_buffer: wgpu::Buffer, 66 | staging_buffer: wgpu::Buffer, 67 | } 68 | 69 | impl<T: Num + NumCast + Clone + Pod + Default + Debug> Clone for GPUArray<T> 70 | where 71 | Vec<T>: GetType, 72 | { 73 | fn clone(&self) -> Self { 74 | if let Some(init_data) = &self.init_data { 75 | let mut arr = GPUArray::new(init_data.clone(), self.shape.clone()); 76 | arr.id = self.id.clone(); 77 | arr.data_type = self.data_type.clone(); 78 | arr.initializer = self.initializer; 79 | arr.context_id = self.context_id; 80 | arr.executor = self.executor.clone(); 81 | return arr; 82 | } else { 83 | let init_data = self.to_vec(); 84 | let mut arr = GPUArray::new(init_data.clone(), self.shape.clone()); 85 | arr.id = self.id.clone(); 86 | arr.data_type = self.data_type.clone(); 87 | arr.initializer = self.initializer; 88 | arr.context_id = self.context_id; 89 | arr.executor = self.executor.clone(); 90 | return arr; 91 | } 92 | } 93 | } 94 | 95 | impl<T> Debug for GPUArray<T> { 96 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 97 | write!(f, "GPUArray(id={})", self.id) 98 | } 99 | } 100 | 101 | impl<T> Drop for GPUArray<T> { 102 | fn drop(&mut self) { 103 | // we shall sync upon drop 104 | let synced = self.executor.read().unwrap().synced; 105 | if !synced { 106 | self.executor.write().unwrap().sync(); 107 | } 108 | } 109 | } 110 | 111 | fn shape_to_strides(shape: &Vec<usize>) -> Vec<usize> { 112 | let mut strides = Vec::with_capacity(shape.len()); 113 | let mut stride = 1; 114 | for &dim in shape.iter().rev() { 115 | strides.push(stride); 116 | stride *= dim; 117 | } 118 | strides.reverse(); 119 | strides 120 | } 121 | 122 | pub fn compute_broadcasted_shape_and_strides( 123 | shape1: &Vec<usize>, 124 | shape2: &Vec<usize>, 125 | strides1: &Vec<usize>, 126 | strides2: &Vec<usize>, 127 | ) -> (Vec<usize>, Vec<usize>, Vec<usize>, Vec<usize>) { 128 | let length = shape1.len().max(shape2.len()); 129 | let mut adjusted_shape1 = vec![1; length]; 130 | let mut adjusted_shape2 = vec![1; length]; 131 | let mut adjusted_strides1 = vec![0; length]; 132 | let mut adjusted_strides2 = vec![0; length]; 133 | 134 | for (i, &dim) in shape1.iter().rev().enumerate() { 135 | adjusted_shape1[length - 1 - i] = dim; 136 | adjusted_strides1[length - 1 - i] = if dim == 1 { 137 | 0 138 | } else { 139 | strides1[shape1.len() - 1 - i] 140 | }; 141 | } 142 | 143 | for (i, &dim) in shape2.iter().rev().enumerate() { 144 | adjusted_shape2[length - 1 - i] = dim; 145 | adjusted_strides2[length - 1 - i] = if dim == 1 { 146 | 0 147 | } else { 148 | strides2[shape2.len() - 1 - i] 149 | }; 150 | } 151 | 152 | for i in 0..length { 153 | if adjusted_shape1[i] != adjusted_shape2[i] { 154 | if adjusted_shape1[i] == 1 { 155 | adjusted_shape1[i] = adjusted_shape2[i]; 156 | adjusted_strides1[i] = 0;
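// At this point dim i of the first operand was a singleton: it has been widened to match the
// other operand while its stride stays 0, so every index along that axis reads the same
// element (e.g. broadcasting [2, 1] against [2, 3] gives [2, 3] with a zero stride on the
// widened axis of the first operand).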
157 | } else if adjusted_shape2[i] == 1 { 158 | adjusted_shape2[i] = adjusted_shape1[i]; 159 | adjusted_strides2[i] = 0; 160 | } else { 161 | panic!("Shapes are not broadcastable"); 162 | } 163 | } 164 | } 165 | 166 | ( 167 | adjusted_shape1, 168 | adjusted_shape2, 169 | adjusted_strides1, 170 | adjusted_strides2, 171 | ) 172 | } 173 | 174 | impl<T: Num + NumCast + Clone + Pod + Default + Debug> GPUArray<T> 175 | where 176 | Vec<T>: GetType, 177 | { 178 | pub fn new(data: Vec<T>, shape: Vec<usize>) -> Self { 179 | Self::new_with_ctx(&GLOBAL_CTX, data, shape) 180 | } 181 | 182 | pub fn new_with_ctx(context: &GPUContext, data: Vec<T>, shape: Vec<usize>) -> Self { 183 | Self::new_with_name(context, Uuid::new_v4().to_string().as_str(), data, shape) 184 | } 185 | 186 | fn new_with_name(context: &GPUContext, id: &str, data: Vec<T>, shape: Vec<usize>) -> Self { 187 | let storage_buf = create_storage_buf( 188 | &context.executor.read().unwrap().device, 189 | &id, 190 | Some(&data), 191 | &shape, 192 | ); 193 | let staging_buf = 194 | create_staging_buf::<T>(&context.executor.read().unwrap().device, &id, &None, &shape); 195 | 196 | let dtype = data.get_type(); 197 | Self { 198 | id: id.to_string(), 199 | initializer: true, 200 | init_data: Some(data), 201 | context_id: context.id, 202 | context: context.clone(), 203 | data_type: dtype, 204 | shape: shape.clone(), 205 | strides: shape_to_strides(&shape), 206 | executor: context.executor.clone(), 207 | main_buffer: storage_buf, 208 | staging_buffer: staging_buf, 209 | } 210 | } 211 | 212 | pub fn mean(&self) -> GPUArray<T> { 213 | self.un_op(vec![1], Mean {}) 214 | } 215 | 216 | pub fn mean_axis(&self, axis: i32, keep_dim: bool) -> GPUArray<T> { 217 | let res = self.sum_axis(axis, keep_dim); 218 | let numel = self.shape[axis as usize]; 219 | let numel_arr = 220 | &Self::new_with_ctx(&self.context, vec![NumCast::from(numel).unwrap()], vec![1]); 221 | res / numel_arr 222 | } 223 | 224 | pub fn matmul(&self, other: &GPUArray<T>) -> GPUArray<T> { 225 | self.bin_op(other, vec![self.shape[0], other.shape[1]], MatMul {}) 226 | } 227 | 228 | pub fn slice_axis<I: AsRef<[i32]>>(&self, axis: i32, indices: I) -> GPUArray<T> { 229 | let mut out_shape = self.shape.clone(); 230 | out_shape[axis as usize] = indices.as_ref().len(); 231 | 232 | // Hackity hack to force the indices buffer to be of type i32 233 | let mut indices_arr = Self::new(vec![T::default()], vec![1]); 234 | let indices_shape = vec![indices.as_ref().len()]; 235 | let indices_values = indices.as_ref().to_vec(); 236 | let idx_id = &indices_arr.id; 237 | indices_arr.data_type = GPUDataType::I32; 238 | indices_arr.executor = self.executor.clone(); 239 | indices_arr.context_id = self.context_id; 240 | indices_arr.shape = indices_shape; 241 | indices_arr.strides = shape_to_strides(&indices_arr.shape); 242 | indices_arr.main_buffer = create_storage_buf::<i32>( 243 | &indices_arr.executor.read().unwrap().device, 244 | idx_id, 245 | Some(&indices_values), 246 | &indices_arr.shape, 247 | ); 248 | indices_arr.staging_buffer = create_staging_buf::<i32>( 249 | &indices_arr.executor.read().unwrap().device, 250 | idx_id, 251 | &None, 252 | &indices_arr.shape, 253 | ); 254 | 255 | self.bin_op(&indices_arr, out_shape, Slice::new(axis)) 256 | } 257 | 258 | pub fn sum(&self) -> GPUArray<T> { 259 | self.un_op(vec![1], Sum {}) 260 | } 261 | 262 | /// Sum along an axis and squeeze the singleton axis when keep_dim is false. 263 | /// TODO: efficiently manage resources. At the moment we perform a buffer-to-buffer copy 264 | /// after each slice addition.
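/// For example, summing a [3, 2] array with `sum_axis(0, true)` yields shape [1, 2], while
/// `sum_axis(1, false)` yields shape [3] (see the `sum` test at the bottom of this file).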
265 | pub fn sum_axis(&self, axis: i32, keep_dim: bool) -> GPUArray<T> { 266 | let axis_len = self.shape[axis as usize]; 267 | if axis_len <= 0 { 268 | panic!("Sum axis must be positive non-zero"); 269 | } 270 | 271 | let mut res: Option<GPUArray<T>> = None; 272 | for idx in 0..axis_len { 273 | let slice = self.slice_axis(axis, [idx as i32]); 274 | res = Some(match res { 275 | None => slice, 276 | Some(r) => r.add(&slice), 277 | }) 278 | } 279 | let mut res = res.unwrap(); 280 | let new_shape = if keep_dim { 281 | res.shape.clone() 282 | } else { 283 | res.shape 284 | .clone() 285 | .into_iter() 286 | .filter(|v| *v != 1) 287 | .collect::<Vec<usize>>() 288 | }; 289 | res.shape = new_shape; 290 | res.strides = shape_to_strides(&res.shape); 291 | res 292 | } 293 | 294 | /// General binary operation 295 | pub fn bin_op<S: Shader>( 296 | &self, 297 | other: &GPUArray<T>, 298 | out_shape: Vec<usize>, 299 | op_type: S, 300 | ) -> GPUArray<T> { 301 | if self.context_id != other.context_id { 302 | panic!("cannot do operations on GPUArray from different execution context") 303 | } 304 | 305 | let res_id = Uuid::new_v4().to_string(); 306 | self.executor.write().unwrap().synced = false; 307 | let out_strides = shape_to_strides(&out_shape); 308 | 309 | let (res_storage_buf, staging_buf) = match &self.data_type { 310 | GPUDataType::F32 => { 311 | let storage_buf = create_storage_buf::<f32>( 312 | &self.executor.read().unwrap().device, 313 | &res_id, 314 | None, 315 | &out_shape, 316 | ); 317 | let staging_buf = create_staging_buf::<f32>( 318 | &self.executor.read().unwrap().device, 319 | &res_id, 320 | &None, 321 | &out_shape, 322 | ); 323 | (storage_buf, staging_buf) 324 | } 325 | GPUDataType::I32 => { 326 | let storage_buf = create_storage_buf::<i32>( 327 | &self.executor.read().unwrap().device, 328 | &res_id, 329 | None, 330 | &out_shape, 331 | ); 332 | let staging_buf = create_staging_buf::<i32>( 333 | &self.executor.read().unwrap().device, 334 | &res_id, 335 | &None, 336 | &out_shape, 337 | ); 338 | (storage_buf, staging_buf) 339 | } 340 | }; 341 | 342 | let res_gpu = Self { 343 | id: res_id.clone(), 344 | initializer: false, 345 | init_data: None, 346 | context_id: self.context_id, 347 | context: self.context.clone(), 348 | data_type: self.data_type.clone(), 349 | shape: out_shape, 350 | strides: out_strides, 351 | executor: Arc::clone(&self.executor), 352 | main_buffer: res_storage_buf, 353 | staging_buffer: staging_buf, 354 | }; 355 | 356 | let buf_binding_0 = &self.main_buffer; 357 | let buf_binding_1 = &other.main_buffer; 358 | let buf_binding_2 = &res_gpu.main_buffer; 359 | let buffers = vec![buf_binding_0, buf_binding_1, buf_binding_2]; 360 | 361 | let shader_template = PROJECT_DIR 362 | .get_file(op_type.shader_path()) 363 | .unwrap() 364 | .contents_utf8() 365 | .unwrap(); 366 | let mut templ = tera::Tera::default(); 367 | let mut params = tera::Context::new(); 368 | templ 369 | .add_raw_template(&op_type.shader_path(), shader_template) 370 | .unwrap(); 371 | 372 | let operands = vec![self, other]; 373 | let workgroup_sizes = op_type.prepare(operands, &res_gpu, &mut params); 374 | 375 | let shader_source = templ.render(&op_type.shader_path(), &params).unwrap(); 376 | 377 | self.dispatch_compute(buffers, &shader_source, workgroup_sizes); 378 | 379 | let encoder = &mut self.executor.write().unwrap().encoder; 380 | encoder.copy_buffer_to_buffer( 381 | &res_gpu.main_buffer, 382 | 0, 383 | &res_gpu.staging_buffer, 384 | 0, 385 | res_gpu.staging_buffer.size(), 386 | ); 387 | res_gpu
388 | } 389 | 390 | /// General unary operation 391 | pub fn un_op<S: Shader>(&self, out_shape: Vec<usize>, op_type: S) -> GPUArray<T> { 392 | let res_id = Uuid::new_v4().to_string(); 393 | self.executor.write().unwrap().synced = false; 394 | let out_strides = shape_to_strides(&out_shape); 395 | 396 | let (res_storage_buf, staging_buf) = match &self.data_type { 397 | GPUDataType::F32 => { 398 | let storage_buf = create_storage_buf::<f32>( 399 | &self.executor.read().unwrap().device, 400 | &res_id, 401 | None, 402 | &out_shape, 403 | ); 404 | let staging_buf = create_staging_buf::<f32>( 405 | &self.executor.read().unwrap().device, 406 | &res_id, 407 | &None, 408 | &out_shape, 409 | ); 410 | (storage_buf, staging_buf) 411 | } 412 | GPUDataType::I32 => { 413 | let storage_buf = create_storage_buf::<i32>( 414 | &self.executor.read().unwrap().device, 415 | &res_id, 416 | None, 417 | &out_shape, 418 | ); 419 | let staging_buf = create_staging_buf::<i32>( 420 | &self.executor.read().unwrap().device, 421 | &res_id, 422 | &None, 423 | &out_shape, 424 | ); 425 | (storage_buf, staging_buf) 426 | } 427 | }; 428 | 429 | let res_gpu = Self { 430 | id: res_id.clone(), 431 | initializer: false, 432 | init_data: None, 433 | context_id: self.context_id, 434 | context: self.context.clone(), 435 | data_type: self.data_type.clone(), 436 | shape: out_shape, 437 | strides: out_strides, 438 | executor: Arc::clone(&self.executor), 439 | main_buffer: res_storage_buf, 440 | staging_buffer: staging_buf, 441 | }; 442 | 443 | let buf_binding_0 = &self.main_buffer; 444 | let buf_binding_1 = &res_gpu.main_buffer; 445 | let buffers = vec![buf_binding_0, buf_binding_1]; 446 | 447 | let shader_template = PROJECT_DIR 448 | .get_file(op_type.shader_path()) 449 | .unwrap() 450 | .contents_utf8() 451 | .unwrap(); 452 | let mut templ = tera::Tera::default(); 453 | let mut params = tera::Context::new(); 454 | templ 455 | .add_raw_template(&op_type.shader_path(), shader_template) 456 | .unwrap(); 457 | 458 | let operands = vec![self]; 459 | let workgroup_sizes = op_type.prepare(operands, &res_gpu, &mut params); 460 | 461 | let shader_source = templ.render(&op_type.shader_path(), &params).unwrap(); 462 | 463 | self.dispatch_compute(buffers, &shader_source, workgroup_sizes); 464 | let encoder = &mut self.executor.write().unwrap().encoder; 465 | encoder.copy_buffer_to_buffer( 466 | &res_gpu.main_buffer, 467 | 0, 468 | &res_gpu.staging_buffer, 469 | 0, 470 | res_gpu.staging_buffer.size(), 471 | ); 472 | res_gpu 473 | } 474 | 475 | /// Binary operation involving shape broadcasting 476 | pub fn bin_op_broadcast<S: Shader>(&self, other: &GPUArray<T>, op_type: S) -> GPUArray<T> { 477 | let (res_shape, _, _, _) = compute_broadcasted_shape_and_strides( 478 | &self.shape, 479 | &other.shape, 480 | &self.strides, 481 | &other.strides, 482 | ); 483 | 484 | return self.bin_op(other, res_shape, op_type); 485 | } 486 | 487 | /// Run a compute pass from an already-prepared pipeline. This is useful for running the same 488 | /// pipeline multiple times without recompiling its shader module.
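/// `dispatch_compute` below is the single-use variant: it builds the shader module and the
/// pipeline from WGSL source and then delegates to this method.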
489 | fn dispatch_compute_shader_pipeline( 490 | &self, 491 | buffers: Vec<&wgpu::Buffer>, 492 | pipeline: &ComputePipeline, 493 | wg_sizes: (u32, u32, u32), 494 | ) { 495 | let bind_group_layout = pipeline.get_bind_group_layout(0); 496 | 497 | let bind_group_entries: Vec<BindGroupEntry> = buffers 498 | .iter() 499 | .enumerate() 500 | .map(|(i, buf)| wgpu::BindGroupEntry { 501 | binding: i as u32, 502 | resource: buf.as_entire_binding(), 503 | }) 504 | .collect(); 505 | 506 | let bind_group = { 507 | let dev = &self.executor.read().unwrap().device; 508 | dev.create_bind_group(&wgpu::BindGroupDescriptor { 509 | label: None, 510 | layout: &bind_group_layout, 511 | entries: &bind_group_entries, 512 | }) 513 | }; 514 | 515 | { 516 | let encoder = &mut self.executor.write().unwrap().encoder; 517 | let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { 518 | label: None, 519 | timestamp_writes: None, 520 | }); 521 | cpass.set_pipeline(&pipeline); 522 | cpass.set_bind_group(0, &bind_group, &[]); 523 | 524 | let (x, y, z) = wg_sizes; 525 | cpass.dispatch_workgroups(x, y, z); 526 | } 527 | } 528 | 529 | /// Directly run a compute pipeline from a &str shader source 530 | fn dispatch_compute( 531 | &self, 532 | buffers: Vec<&wgpu::Buffer>, 533 | shader_source: &str, 534 | wg_sizes: (u32, u32, u32), 535 | ) { 536 | let shader_module = { 537 | let dev = &self.executor.read().unwrap().device; 538 | let module = dev.create_shader_module(wgpu::ShaderModuleDescriptor { 539 | label: None, 540 | source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(shader_source)), 541 | }); 542 | module 543 | }; 544 | 545 | let compute_pipeline = { 546 | let dev = &self.executor.read().unwrap().device; 547 | dev.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { 548 | label: None, 549 | layout: None, 550 | module: &shader_module, 551 | entry_point: "main", 552 | }) 553 | }; 554 | self.dispatch_compute_shader_pipeline(buffers, &compute_pipeline, wg_sizes); 555 | } 556 | 557 | /// This method copies the actual data from the GPU, wraps it in a Vec, then 558 | /// returns it. It should be used sparingly since frequent GPU <-> CPU data 559 | /// transfer is costly. 560 | pub fn to_vec(&self) -> Vec<T> { 561 | pollster::block_on(self.fetch()) 562 | } 563 | 564 | async fn fetch(&self) -> Vec<T> { 565 | // if this array is an initializer, we first copy the data from the main buffer to 566 | // the staging buffer since we didn't do that by default to conserve GPU memory. 567 | if self.initializer { 568 | self.executor 569 | .write() 570 | .unwrap() 571 | .encoder 572 | .copy_buffer_to_buffer( 573 | &self.main_buffer, 574 | 0, 575 | &self.staging_buffer, 576 | 0, 577 | self.staging_buffer.size(), 578 | ); 579 | } 580 | 581 | // ensure sync 582 | self.executor.write().unwrap().sync(); 583 | 584 | let staging_buf = &self.staging_buffer; 585 | let staging_slice = staging_buf.slice(..); 586 | let (sender, receiver) = flume::bounded(1); 587 | staging_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap()); 588 | 589 | // TODO: is it proper to call device.poll() again here? 590 | self.executor 591 | .read() 592 | .unwrap() 593 | .device 594 | .poll(wgpu::Maintain::Wait); 595 | 596 | if let Ok(Ok(())) = receiver.recv_async().await { 597 | let data = staging_slice.get_mapped_range(); 598 | let vec_data = bytemuck::cast_slice(&data).to_vec(); 599 | 600 | drop(data); 601 | staging_buf.unmap(); 602 | return vec_data; 603 | } else { 604 | panic!("Cannot run on GPU") 605 | } 606 | } 607 | } 608 |
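// The macro below mirrors `impl_bin_op` in src/tensor.rs: each std::ops trait is forwarded
// to `bin_op_broadcast` with the corresponding op_type shader (Add, Sub, Mul, Div).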
609 | macro_rules! impl_bin_op { 610 | ($trait:ident, $method:ident, $op:expr) => { 611 | impl<T: Num + NumCast + Clone + Pod + Default + Debug> std::ops::$trait<&Self> 612 | for GPUArray<T> 613 | where 614 | Vec<T>: GetType, 615 | { 616 | type Output = GPUArray<T>; 617 | 618 | fn $method(self, rhs: &Self) -> Self::Output { 619 | self.bin_op_broadcast(rhs, $op) 620 | } 621 | } 622 | 623 | impl<T: Num + NumCast + Clone + Pod + Default + Debug> std::ops::$trait for &GPUArray<T> 624 | where 625 | Vec<T>: GetType, 626 | { 627 | type Output = GPUArray<T>; 628 | 629 | fn $method(self, rhs: Self) -> Self::Output { 630 | self.bin_op_broadcast(rhs, $op) 631 | } 632 | } 633 | }; 634 | } 635 | 636 | impl_bin_op!(Mul, mul, crate::gpu::op_type::Mul {}); 637 | impl_bin_op!(Add, add, crate::gpu::op_type::Add {}); 638 | impl_bin_op!(Div, div, crate::gpu::op_type::Div {}); 639 | impl_bin_op!(Sub, sub, crate::gpu::op_type::Sub {}); 640 | 641 | pub fn create_storage_buf<'a, T: bytemuck::Pod + Default + Debug>( 642 | device: &wgpu::Device, 643 | buf_label: &str, 644 | values: Option<&'a Vec<T>>, 645 | shape: &Vec<usize>, 646 | ) -> wgpu::Buffer { 647 | let mut n_items = shape.iter().fold(1, |x, y| x * y) as usize; 648 | // TODO: proper handling of 0-sized dims, or non-empty shapes that contain a 0-length dim 649 | if n_items == 0 { 650 | n_items = 1; 651 | } 652 | let vals: Cow<'a, Vec<T>> = match values { 653 | Some(v) => Cow::Borrowed(v), 654 | None => Cow::Owned(vec![T::default(); n_items]), 655 | }; 656 | 657 | // Some models provide tensors with empty data, i.e., with shape [0]. WGPU does not 658 | // allow zero-sized buffer bindings, so we trick it by using a "dummy" buffer binding with 659 | // a size of 4 (the minimum allowed) 660 | let tensor_has_data = vals.len() > 0; 661 | let data = if tensor_has_data { 662 | // We create a buffer initialized with the tensor's original data 663 | device.create_buffer_init(&wgpu::util::BufferInitDescriptor { 664 | label: Some(format!("{}.storage", buf_label).as_str()), 665 | contents: bytemuck::cast_slice(&vals), 666 | usage: wgpu::BufferUsages::STORAGE 667 | | wgpu::BufferUsages::COPY_DST 668 | | wgpu::BufferUsages::COPY_SRC, 669 | }) 670 | } else { 671 | // The dummy buffer 672 | device.create_buffer(&wgpu::BufferDescriptor { 673 | label: Some(format!("{}.storage", buf_label).as_str()), 674 | size: 4, 675 | usage: wgpu::BufferUsages::STORAGE 676 | | wgpu::BufferUsages::COPY_DST 677 | | wgpu::BufferUsages::COPY_SRC, 678 | mapped_at_creation: true, 679 | }) 680 | }; 681 | data 682 | } 683 | 684 | pub fn create_staging_buf<'a, T: bytemuck::Pod + Default + Debug>( 685 | device: &wgpu::Device, 686 | buf_label: &str, 687 | values: &'a Option<Vec<T>>, 688 | shape: &Vec<usize>, 689 | ) -> wgpu::Buffer { 690 | let n_items = shape.iter().fold(1, |x, y| x * y) as usize; 691 | let vals: Cow<'a, Vec<T>> = match values { 692 | Some(v) => Cow::Borrowed(v), 693 | None => Cow::Owned(vec![T::default(); n_items]), 694 | }; 695 | let data = device.create_buffer_init(&wgpu::util::BufferInitDescriptor { 696 | label: Some(format!("{}.staging", buf_label).as_str()), 697 | contents: bytemuck::cast_slice(&vals), 698 | usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, 699 | }); 700 | data 701 | } 702 | 703 | #[allow(unused_imports)] 704 | mod test { 705 | use crate::gpu::context::GPUContext; 706 | use crate::gpu::gpu_array::GPUArray; 707 | 708 | #[test] 709 | fn simple_add() { 710 | let ctx = GPUContext::new(); 711 | let x = GPUArray::new_with_ctx(&ctx, vec![1., 2., 3.], vec![1, 3]); 712 | let y = GPUArray::new_with_ctx(&ctx, vec![2., 3., 4.], vec![1, 3]); 713 | let res = &x + &y; 714 | 715 | assert_eq!(x.to_vec(), vec![1., 2., 3.]); 716 |
assert_eq!(res.to_vec(), vec![3., 5., 7.]); 717 | } 718 | 719 | #[test] 720 | fn add_bcast() { 721 | let ctx = GPUContext::new(); 722 | let x = GPUArray::new_with_ctx(&ctx, vec![1., 2., 3., 4.], vec![2, 2]); 723 | let y = GPUArray::new_with_ctx(&ctx, vec![10., 10.], vec![2]); 724 | let res = x + &y; 725 | assert_eq!(res.to_vec(), vec![11., 12., 13., 14.]); 726 | 727 | let x = GPUArray::new_with_ctx(&ctx, vec![1., 2., 3., 4.], vec![2, 2]); 728 | let y = GPUArray::new_with_ctx(&ctx, vec![10., 10.], vec![2, 1]); 729 | let res = x + &y; 730 | assert_eq!(res.to_vec(), vec![11., 12., 13., 14.]); 731 | } 732 | 733 | #[test] 734 | fn add_bcast_bidirection() { 735 | let ctx = GPUContext::new(); 736 | let x = GPUArray::new_with_ctx(&ctx, vec![1., 2., 3.], vec![3]); 737 | let y = GPUArray::new_with_ctx(&ctx, vec![1., 2., 3.], vec![3, 1]); 738 | let res = x + &y; 739 | assert_eq!( 740 | res.to_vec(), 741 | vec![2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 4.0, 5.0, 6.0] 742 | ); 743 | } 744 | 745 | #[test] 746 | fn matmul() { 747 | let ctx = GPUContext::new(); 748 | let x = GPUArray::new_with_ctx(&ctx, vec![1., 2., 3., 4.], vec![2, 2]); 749 | let y = GPUArray::new_with_ctx(&ctx, vec![2., 2., 2., 2.], vec![2, 2]); 750 | let res = x.matmul(&y); 751 | assert_eq!(res.to_vec(), vec![6., 6., 14., 14.]); 752 | } 753 | 754 | #[test] 755 | fn slice() { 756 | let ctx = GPUContext::new(); 757 | let x = GPUArray::new_with_ctx(&ctx, vec![1., 2., 3., 4., 5., 6., 7., 8., 9.], vec![3, 3]); 758 | let res = x.slice_axis(0, [1]); 759 | assert_eq!(res.to_vec(), vec![4., 5., 6.]); 760 | 761 | let ctx = GPUContext::new(); 762 | let x = GPUArray::new_with_ctx(&ctx, vec![1., 2., 3., 4., 5., 6., 7., 8., 9.], vec![3, 3]); 763 | let res = x.slice_axis(0, [0, 2]); 764 | assert_eq!(res.to_vec(), vec![1., 2., 3., 7., 8., 9.]); 765 | 766 | let ctx = GPUContext::new(); 767 | let x = GPUArray::new_with_ctx(&ctx, vec![1., 2., 3., 4., 5., 6., 7., 8., 9.], vec![3, 3]); 768 | let res = x.slice_axis(1, [0, 2]); 769 | assert_eq!(res.to_vec(), vec![1., 3., 4., 6., 7., 9.]); 770 | 771 | let ctx = GPUContext::new(); 772 | let x = GPUArray::new_with_ctx(&ctx, vec![1., 2., 3., 4., 5., 6.], vec![3, 2]); 773 | let res = x.slice_axis(1, [1]); 774 | assert_eq!(res.to_vec(), vec![2., 4., 6.]); 775 | 776 | let ctx = GPUContext::new(); 777 | let x = GPUArray::new_with_ctx(&ctx, vec![2, 2, 2, 2, 1, 1, 1, 1], vec![2, 2, 2]); 778 | let res = x.slice_axis(0, [1]); 779 | assert_eq!(res.to_vec(), vec![1, 1, 1, 1]); 780 | let res = x.slice_axis(1, [1]); 781 | assert_eq!(res.to_vec(), vec![2, 2, 1, 1]); 782 | } 783 | 784 | #[test] 785 | fn sum() { 786 | let ctx = GPUContext::new(); 787 | let x = GPUArray::new_with_ctx(&ctx, vec![1., 2., 3., 4., 5., 6.], vec![3, 2]); 788 | let res = x.sum_axis(0, true); 789 | assert_eq!(res.to_vec(), vec![9., 12.]); 790 | assert_eq!(res.shape, vec![1, 2]); 791 | let res = x.sum_axis(1, false); 792 | assert_eq!(res.to_vec(), vec![3., 7., 11.]); 793 | assert_eq!(res.shape, vec![3]); 794 | let x = GPUArray::new_with_ctx(&ctx, vec![2, 2, 2, 2, 1, 1, 1, 1], vec![2, 2, 2]); 795 | let res = x.sum_axis(0, false); 796 | assert_eq!(res.to_vec(), vec![3, 3, 3, 3]); 797 | assert_eq!(res.shape, vec![2, 2]); 798 | let res = x.sum_axis(1, true); 799 | assert_eq!(res.to_vec(), vec![4, 4, 2, 2]); 800 | assert_eq!(res.shape, vec![2, 1, 2]); 801 | 802 | // all-axis sum 803 | let x = GPUArray::new_with_ctx(&ctx, vec![1., 2., 3., 4.], vec![4]); 804 | let res = x.sum(); 805 | assert_eq!(res.to_vec(), vec![10.]); 806 | assert_eq!(res.shape, vec![1]); 807 | 808 | 
let x = GPUArray::new_with_ctx(&ctx, vec![1; 1000000], vec![1000000]); 809 | let res = x.sum(); 810 | assert_eq!(res.to_vec(), vec![1000000]); 811 | } 812 | } 813 | --------------------------------------------------------------------------------
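A minimal end-to-end usage sketch, assembled from the tests above (the crate name `tensoria` and the `main` wrapper are assumptions; `Tensor::new`, `to_gpu`, `set_requires_grad`, `backward`, and `grad_vec` mirror the `add` and `add_gpu` tests in src/tensor.rs):

use tensoria::error::TensoriaError;
use tensoria::tensor::Tensor;

fn main() -> Result<(), TensoriaError> {
    // CPU: broadcasted add with autograd, as in the `add` test.
    let x = Tensor::new([2, 2], vec![1., 2., 3., 4.])?;
    let mut y = Tensor::new([2], vec![1., 1.])?;
    y.set_requires_grad(true);
    (&x + &y).backward()?;
    assert_eq!(y.grad_vec(), Some(vec![2., 2.]));

    // GPU: the same API after moving both tensors to the GPU, as in `add_gpu`.
    let x = Tensor::new([1, 2], vec![1., 2.])?.to_gpu()?;
    let mut y = Tensor::new([1, 2], vec![3., 4.])?.to_gpu()?;
    y.set_requires_grad(true);
    let res = &x + &y;
    assert_eq!(res.to_vec(), vec![4., 6.]);
    res.backward()?;
    assert_eq!(y.grad_vec(), Some(vec![1., 1.]));
    Ok(())
}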