├── .gitignore
├── Cargo.toml
├── README.md
├── benches
│   └── expm.rs
└── src
    ├── bin.rs
    ├── functions.rs
    ├── graph.rs
    ├── lib.rs
    └── tensor.rs
/.gitignore:
--------------------------------------------------------------------------------
/target
Cargo.lock
--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
[package]
name = "rust-grad"
version = "0.1.0"
authors = ["RustyBamboo"]
edition = "2018"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
ndarray = {git = "https://github.com/RustyBamboo/ndarray", branch="wgpu"}
# ndarray = { path = "../ndarray" }
futures = "*"
enum_dispatch = "0.3"

[[bin]]
name = "optimize"
path = "src/bin.rs"
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Autograd in Rust

A small automatic differentiation library for scalars and tensors, written in Rust.

It has a GPU backend via an `ndarray` [fork](https://github.com/RustyBamboo/ndarray/tree/wgpu) that provides [WebGPU](https://github.com/gfx-rs/wgpu) support.

## Supported Functions

- Element-wise addition, subtraction, and multiplication
- Matrix dot product

More to come...

## Try it out

Clone this repo
```
git clone https://github.com/RustyBamboo/rust-grad
```

Enter the directory
```
cd rust-grad
```

Compile and run the example `src/bin.rs`
```
cargo run
```

## Examples

- [Element-wise Operations](#element-wise-operation)
- [Matrix Multiply](#matrix-multiply)
- [Matrix Exponential](#matrix-exponential)
- [GPU-Backend via WGPU](#gpu-backend-via-wgpu)

### Element-wise Operation

```rust
use rust_grad::Graph;

pub fn main() {
    let graph = Graph::new();

    let x = graph.tensor(ndarray::arr1(&[1.0, 2.0]).into_dyn());
    let y = graph.tensor(ndarray::arr1(&[3.0, 4.0]).into_dyn());

    let z = (x + y) * x;

    z.forward(); // forward pass

    println!("{}", z.value());

    z.backward(ndarray::Array::ones(2)
               .into_dyn()); // backward pass


    println!("dz/dz {}", z.grad()); // dz/dz [1,1]
    println!("dz/dx {}", x.grad()); // dz/dx [5,8]
    println!("dz/dy {}", y.grad()); // dz/dy [1,2]

    println!("Graph: {:?}", graph);
}
```

#### Same Example in Torch

```python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = torch.tensor([3.0, 4.0], requires_grad=True)
z = (x + y) * x
z.backward(torch.ones_like(x))

print(x.grad) # dz/dx tensor([5., 8.])
print(y.grad) # dz/dy tensor([1., 2.])
```

### Matrix Multiply

```rust
use rust_grad::Graph;

pub fn main() {
    let graph = Graph::new();

    let x = graph.tensor(ndarray::array![[1.0, 2.0, 3.0],
                                         [4.0, 5.0, 6.0],
                                         [7.0, 8.0, 9.0]].into_dyn());
    let y = graph.tensor(ndarray::array![[1.0, 2.0, 1.0],
                                         [2.0, 3.0, 2.0],
                                         [3.0, 4.0, 3.0]].into_dyn());

    let z = x.matmul(y);

    z.forward(); // forward pass

    println!("{}", z.value());

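    // The value passed to backward() below is the seed gradient for z. Seeding with
    // a matrix of ones accumulates every element of z, so (per MatMul::backward in
    // src/functions.rs) x.grad() comes out as seed · yᵀ and y.grad() as xᵀ · seed.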
    z.backward(ndarray::Array::ones((3, 3))
               .into_dyn()); // backward pass


    println!("dz/dx {}", x.grad());
    println!("dz/dy {}", y.grad());
}
```

#### Same Example in Torch

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0],
                  [7.0, 8.0, 9.0]], requires_grad=True)
y = torch.tensor([[1.0, 2.0, 1.0],
                  [2.0, 3.0, 2.0],
                  [3.0, 4.0, 3.0]], requires_grad=True)
z = x.matmul(y)
print(z)
z.backward(torch.ones_like(x))

print(f"dz/dx {x.grad}") # dz/dx
print(f"dz/dy {y.grad}") # dz/dy
```

### Matrix Exponential

Limited to diagonal matrices.

```rust
use rust_grad::Graph;

pub fn main() {
    let graph = Graph::new();

    let x = graph.tensor(ndarray::array![[1.0, 0.0, 0.0],
                                         [0.0, 1.0, 0.0],
                                         [0.0, 0.0, 2.0]].into_dyn());

    let z = x.expm();

    z.forward(); // forward pass

    println!("{}", z.value()); // [[2.7182822, 0, 0],
                               //  [0, 2.7182822, 0],
                               //  [0, 0, 7.3890576]]

    z.backward(ndarray::Array::ones((3, 3))
               .into_dyn()); // backward pass


    println!("dz/dx {}", x.grad()); // [[2.7182822, 2.7182822, 4.67016],
                                    //  [2.7182822, 2.7182822, 4.67016],
                                    //  [4.6694736, 4.6694736, 7.3890576]]
}
```

#### Same Example in Torch

```python
import torch

x = torch.tensor([[1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0],
                  [0.0, 0.0, 2.0]], requires_grad=True)

z = torch.matrix_exp(x)

print(z) # tensor([[2.7183, 0.0000, 0.0000],
         #         [0.0000, 2.7183, 0.0000],
         #         [0.0000, 0.0000, 7.3891]])

z.backward(torch.ones_like(x))

print(f"dz/dx {x.grad}") # tensor([[2.7183, 2.7183, 4.6708],
                         #         [2.7183, 2.7183, 4.6708],
                         #         [4.6708, 4.6708, 7.3891]])

```


### GPU-Backend via WGPU

```rust
use rust_grad::Graph;

use futures::executor::block_on;

pub fn main() {
    let d = block_on(ndarray::WgpuDevice::new()).expect("No GPU");

    let graph = Graph::new();

    let x = graph.tensor(ndarray::array![[1.0, 2.0, 3.0],
                                         [4.0, 5.0, 6.0],
                                         [7.0, 8.0, 9.0]]
                         .into_dyn()
                         .into_wgpu(&d));
    let y = graph.tensor(ndarray::array![[1.0, 2.0, 1.0],
                                         [2.0, 3.0, 2.0],
                                         [3.0, 4.0, 3.0]]
                         .into_dyn()
                         .into_wgpu(&d));

    let z = x * y;

    z.forward(); // forward pass

    println!("{}", z.value());

    z.backward(ndarray::Array::ones((3, 3))
               .into_dyn()
               .into_wgpu(&d)); // backward pass

    println!("dz/dx {}", x.grad());
    println!("dz/dy {}", y.grad());
}
```

## Benchmarks

Requires the `nightly` toolchain of Rust.

```
cargo +nightly bench
```

## Goals and TODOs

- [x] Scalars
- [x] Tensors
- [ ] Many supported functions
- [x] Lazy execution (via `tensor.backward()` and `tensor.forward()`)
- [x] CPU support through `ndarray`
- [x] GPU support through WebGPU
- [x] Faster method calls using `enum_dispatch`

## License

This project is dual-licensed, as detailed below. If you do use it, I kindly ask to be credited or acknowledged (just trying to build a resume...)

Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
or the MIT license (http://opensource.org/licenses/MIT), at your option. This file may not
be copied, modified, or distributed except according to those terms.

--------------------------------------------------------------------------------
/benches/expm.rs:
--------------------------------------------------------------------------------
#![feature(test)]
extern crate test;
use rust_grad::Graph;

use test::Bencher;

#[bench]
pub fn expm_cpu(b: &mut Bencher) {
    b.iter(|| {
        let x = ndarray::array![[0.01, 0.0, 0.0], [0.0, 0.01, 0.0], [0.0, 0.0, 0.01]].into_dyn();

        let graph = Graph::new();
        let x = graph.tensor(x);

        for _ in 0..10 {
            let z = x.expm();
            z.forward(); // forward pass
        }
    });
}
--------------------------------------------------------------------------------
/src/bin.rs:
--------------------------------------------------------------------------------
use rust_grad::Graph;

use futures::executor::block_on;

pub fn main() {
    let d = block_on(ndarray::WgpuDevice::new()).expect("No GPU");

    let x = ndarray::array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 2.0]].into_dyn();
    //let x = ndarray::array![2.0].into_dyn();

    //let x = ndarray::array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]].into_dyn();
    //let y = ndarray::array![[1.0, 2.0, 1.0], [2.0, 3.0, 2.0], [3.0, 4.0, 3.0]].into_dyn();

    let ones = ndarray::Array::ones(x.shape()).into_dyn();

    let x = x.into_wgpu(&d);

    //let y = y.into_wgpu(&d);
    let ones = ones.into_wgpu(&d);

    let graph = Graph::new();
    let x = graph.tensor(x);
    //let y = graph.tensor(y);

    let z = x.expm();

    z.forward(); // forward pass

    println!("{}", z.value());

    z.backward(ones); // backward pass

    println!("dz/dz {}", z.grad());
    println!("dz/dx {}", x.grad());
    //println!("dz/dy {}", y.grad());
}
--------------------------------------------------------------------------------
/src/functions.rs:
--------------------------------------------------------------------------------
use crate::tensor::Raw;
use crate::tensor::TensorType;
use enum_dispatch::enum_dispatch;

#[enum_dispatch(OneValuedFn)]
pub enum OneValuedFnEnum<'d, T: TensorType<'d> + Clone> {
    ExpM(ExpM<'d, T>),
}

#[enum_dispatch(TwoValuedFn)]
pub enum TwoValuedFnEnum<'d, T: TensorType<'d> + Clone> {
    Add,
    Mul(Mul<'d, T>),
    MatMul(MatMul<'d, T>),
}

///
/// Enum of function types:
/// - None e.g.: let x = g.tensor(...);
/// - Single Valued e.g.: x.sin()
/// - Double Valued e.g.: x + y
///
pub enum Function<'d, T: TensorType<'d> + Clone> {
    None,
    One(OneValuedFnEnum<'d, T>),
    Two(TwoValuedFnEnum<'d, T>),
}

#[enum_dispatch]
pub trait OneValuedFn<'d, T: TensorType<'d>> {
    fn forward(&mut self, t_a: Raw<'d, T>) -> Raw<'d, T>;
    fn backward(&self, grad: Raw<'d, T>) -> [Option<Raw<'d, T>>; 2];
}

#[enum_dispatch]
pub trait TwoValuedFn<'d, T: TensorType<'d>> {
    fn forward(&mut self, t_a: Raw<'d, T>, t_b: Raw<'d, T>) -> Raw<'d, T>;
    fn backward(&self, grad: Raw<'d, T>) -> [Option<Raw<'d, T>>; 2];
}

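// A new element-wise function follows the same pattern as the ops below: implement one
// of the traits above and register it in the matching enum. As an illustration, here is
// a hypothetical `Sub` op (an editor's sketch, not part of the original crate); it only
// uses TensorType methods that already exist in src/tensor.rs, but it would still need a
// `Sub(Sub)` variant in `TwoValuedFnEnum` and a `std::ops::Sub` impl on `Tensor` to be
// reachable from user code.
pub struct Sub;
impl<'d, T: 'd + TensorType<'d>> TwoValuedFn<'d, T> for Sub {
    fn forward(&mut self, t_a: Raw<'d, T>, t_b: Raw<'d, T>) -> Raw<'d, T> {
        Raw::new(t_a.value().sub(t_b.value()))
    }
    fn backward(&self, grad: Raw<'d, T>) -> [Option<Raw<'d, T>>; 2] {
        // d(a - b)/da = 1 and d(a - b)/db = -1, so the second slot gets the negated
        // upstream gradient (built with val_like, since TensorType has no neg()).
        let neg_one = grad.value().val_like(-1.0);
        [Some(grad), Some(Raw::new(grad.value().mul(&neg_one)))]
    }
}
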
///
/// Add two tensors together element-wise
///
pub struct Add;
impl<'d, T: 'd + TensorType<'d>> TwoValuedFn<'d, T> for Add {
    fn forward(&mut self, t_a: Raw<'d, T>, t_b: Raw<'d, T>) -> Raw<'d, T> {
        let t_c = t_a.value().add(t_b.value());
        Raw::new(t_c)
    }
    fn backward(&self, grad: Raw<'d, T>) -> [Option<Raw<'d, T>>; 2] {
        [Some(grad), Some(grad)]
    }
}

///
/// Multiply two tensors element-wise
///
pub struct Mul<'d, T: 'd + TensorType<'d>> {
    pub x_ctx: Option<Raw<'d, T>>,
    pub y_ctx: Option<Raw<'d, T>>,
}
impl<'d, T: TensorType<'d>> TwoValuedFn<'d, T> for Mul<'d, T> {
    fn forward(&mut self, t_a: Raw<'d, T>, t_b: Raw<'d, T>) -> Raw<'d, T> {
        self.x_ctx = Some(t_a);
        self.y_ctx = Some(t_b);
        let t_c = t_a.value().mul(t_b.value());
        Raw::new(t_c)
    }
    fn backward(&self, grad: Raw<'d, T>) -> [Option<Raw<'d, T>>; 2] {
        let x_ctx = self.x_ctx.unwrap();
        let y_ctx = self.y_ctx.unwrap();

        let a = y_ctx.value().mul(grad.value());
        let b = x_ctx.value().mul(grad.value());

        [Some(Raw::new(a)), Some(Raw::new(b))]
    }
}

///
/// Perform a matrix product (only on 2-D)
/// TODO: support various dimensions
///
pub struct MatMul<'d, T: 'd + TensorType<'d>> {
    pub x_ctx: Option<Raw<'d, T>>,
    pub y_ctx: Option<Raw<'d, T>>,
}
impl<'d, T: TensorType<'d>> TwoValuedFn<'d, T> for MatMul<'d, T> {
    fn forward(&mut self, t_a: Raw<'d, T>, t_b: Raw<'d, T>) -> Raw<'d, T> {
        self.x_ctx = Some(t_a);
        self.y_ctx = Some(t_b);

        let t_c = t_a.value().matmul(t_b.value());
        Raw::new(t_c)
    }
    fn backward(&self, grad: Raw<'d, T>) -> [Option<Raw<'d, T>>; 2] {
        let x_ctx = self.x_ctx.unwrap().value().t();
        let y_ctx = self.y_ctx.unwrap().value().t();

        let a = grad.value().matmul(&y_ctx);
        let b = x_ctx.matmul(grad.value());

        [Some(Raw::new(a)), Some(Raw::new(b))]
    }
}

// TODO: Implement a more generic expm
// https://dl.acm.org/doi/10.1137/S0895479895283409
pub struct ExpM<'d, T: 'd + TensorType<'d>> {
    pub a: Option<Raw<'d, T>>,
    pub res: Option<Raw<'d, T>>,
}
impl<'d, T: TensorType<'d> + Clone> OneValuedFn<'d, T> for ExpM<'d, T> {
    fn forward(&mut self, t_a: Raw<'d, T>) -> Raw<'d, T> {
        self.a = Some(t_a);
        let val = t_a.value();

        let eye = val.eye_like();
        let t_out = eye.mul(&val.expm());
        let t_out = Raw::new(t_out);
        self.res = Some(t_out);
        t_out
    }
    ///
    /// The backward pass implements a truncated power series for the derivative of the
    /// exponential map of a Lie group
    ///
    /// https://en.wikipedia.org/wiki/Derivative_of_the_exponential_map
    ///
    fn backward(&self, grad: Raw<'d, T>) -> [Option<Raw<'d, T>>; 2] {
        let a = self.a.unwrap().value();
        let res = self.res.unwrap().value();
        let grad = grad.value();

        let commu = |a: &T, b: &T| a.matmul(b).sub(&b.matmul(a));

        let mut p_commu = grad.clone();
        let mut total = grad.clone();

        let mut factorial: i32 = 1;

        for o in 2..7 {
            factorial = factorial * o;
            let factor = if o % 2 == 0 { -1 } else { 1 };

            let new_commu = commu(a, &p_commu);
            p_commu = new_commu.clone();

            //TODO: create an element-wise operation which does not require creation of another
            //matrix
            let fac_mat = a.val_like((factor * factorial) as f32);
            total = total.add(&new_commu.div(&fac_mat));
        }

        let a = res.matmul(&total);
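        // `total` now holds the truncated series grad - [A, grad]/2! + [A, [A, grad]]/3! - ...
        // up through the 1/6! term, where A is the stored input; the rebound `a` above is
        // exp(A) matmul'd with that series, i.e. the gradient with respect to the input.
        // Only the first slot below is filled because expm has a single operand.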
        [Some(Raw::new(a)), None]
    }
}
--------------------------------------------------------------------------------
/src/graph.rs:
--------------------------------------------------------------------------------
use crate::functions::Function;
use crate::tensor::{Raw, Tensor, TensorType};
use std::cell::RefCell;
use std::fmt;

///
/// Represents a node in a Wengert list
///
/// The node can have at most two dependencies on other nodes
/// The Function enum indicates the function to apply to the value in a forward pass
///

pub struct Node<'d, T: TensorType<'d> + Clone> {
    pub deps: [usize; 2],
    pub func: Function<'d, T>,
    pub value: Option<Raw<'d, T>>,
    pub grad: Option<Raw<'d, T>>,
    pub ctx: [Option<Raw<'d, T>>; 2],
}

impl<'d, T: TensorType<'d> + Clone> fmt::Debug for Node<'d, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let node_string = format!("{:?}", self.deps);

        write!(f, "{}", node_string)
    }
}

///
/// The computational graph, or Wengert list
///
/// We want several instances to be able to push to the node list, hence RefCell<Vec<...>>
/// It may be possible to allow construction in several threads via a RwLock,
/// but for now we assume single-threaded construction of the graph
///
/// In addition, we have cases where we need to borrow the contents of a Node struct both mutably
/// and immutably, so we wrap it with a RefCell.
///

pub struct Graph<'d, T: TensorType<'d> + Clone> {
    pub nodes: RefCell<Vec<RefCell<Node<'d, T>>>>,
}

impl<'d, T: TensorType<'d> + Clone> Default for Graph<'d, T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<'d, T: TensorType<'d> + Clone> Graph<'d, T> {
    ///
    /// Create a new graph to store the computations
    ///
    pub fn new() -> Self {
        Graph {
            nodes: RefCell::new(Vec::new()),
        }
    }

    pub fn len(&self) -> usize {
        self.nodes.borrow().len()
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    ///
    /// Create a Tensor object which takes ownership of a TensorType
    ///
    pub fn tensor<'g>(&'g self, value: T) -> Tensor<'d, 'g, T> {
        let mut nodes = self.nodes.borrow_mut();
        let len = nodes.len();

        let value = Raw::new(value);

        nodes.push(RefCell::new(Node {
            deps: [len, len],
            func: Function::None,
            value: Some(value),
            grad: None,
            ctx: [None, None],
        }));
        Tensor {
            graph: self,
            index: len,
        }
    }
}

impl<'d, T: TensorType<'d> + Clone> fmt::Debug for Graph<'d, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut node_string = String::new();
        for node in &*self.nodes.borrow() {
            node_string.push_str(format!("{:?}, ", node.borrow().deps).as_str());
        }
        write!(f, "{}", node_string)
    }
}
--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
pub mod functions;
pub mod graph;
pub mod tensor;

pub use graph::Graph;

#[cfg(test)]
mod tests {}
--------------------------------------------------------------------------------
/src/tensor.rs:
--------------------------------------------------------------------------------
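// Overview (editor's note): a Tensor is just an index into the Graph's node list (the
// Wengert list above). Every operation pushes a Node whose `deps` point at its operands,
// so after
//     let z = (x + y) * x;
// the list holds four nodes with deps [0, 0] (x), [1, 1] (y), [0, 1] (x + y) and
// [2, 0] (the product). forward() fills in values in index order; backward() walks the
// list in reverse and accumulates gradients into the dependencies.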
use crate::functions::{Function, OneValuedFn, TwoValuedFn};
use crate::graph::{Graph, Node};
use std::cell::RefCell;
use std::marker::PhantomData;

use ndarray::{Array, Ix2, IxDyn, WgpuArray};

///
/// The base trait for Tensor objects
///
pub trait TensorType<'d> {
    fn get_value_cpu(&self) -> Array<f32, IxDyn>;
    fn tensor(&self) -> &Self;
    fn add(&self, other: &Self) -> Self;
    fn sub(&self, other: &Self) -> Self;
    fn mul(&self, other: &Self) -> Self;
    fn div(&self, other: &Self) -> Self;
    fn matmul(&self, other: &Self) -> Self;
    fn t(&self) -> Self;
    fn expm(&self) -> Self;
    fn val_like(&'d self, val: f32) -> Self;
    fn ones_like(&'d self) -> Self;
    fn eye_like(&'d self) -> Self;
}
impl<'d> TensorType<'d> for Array<f32, IxDyn> {
    fn get_value_cpu(&self) -> Array<f32, IxDyn> {
        self.clone()
    }
    fn tensor(&self) -> &Self {
        self
    }
    fn val_like(&'d self, val: f32) -> Self {
        let shape = self.shape();
        Array::ones(shape) * val
    }
    fn ones_like(&'d self) -> Self {
        let shape = self.shape();
        Array::ones(shape)
    }
    fn eye_like(&'d self) -> Self {
        let shape = self.shape();
        Array::eye(shape[0]).into_dyn()
    }
    fn add(&self, other: &Self) -> Self {
        self + other
    }
    fn sub(&self, other: &Self) -> Self {
        self - other
    }
    fn mul(&self, other: &Self) -> Self {
        self * other
    }
    fn div(&self, other: &Self) -> Self {
        self / other
    }
    fn matmul(&self, other: &Self) -> Self {
        //TODO: Remove cloning (maybe by passing Raw)
        let x = self
            .clone()
            .into_dimensionality::<Ix2>()
            .expect("Not a 2-D matrix");
        let y = other
            .clone()
            .into_dimensionality::<Ix2>()
            .expect("Not a 2-D matrix");

        (x.dot(&y)).into_dyn()
    }
    fn t(&self) -> Self {
        self.clone().reversed_axes()
    }
    fn expm(&self) -> Self {
        self.mapv(|x| x.exp())
        //Array::from_iter(self.iter().map(|x| x.exp()))
    }
}
impl<'d> TensorType<'d> for WgpuArray<'d, f32, IxDyn> {
    fn get_value_cpu(&self) -> Array<f32, IxDyn> {
        self.clone().into_cpu()
    }
    fn tensor(&self) -> &Self {
        self
    }
    fn val_like(&'d self, val: f32) -> Self {
        let d = self.get_wgpu_device();
        let shape = self.shape();
        (Array::ones(shape) * val).into_wgpu(d)
    }
    fn ones_like(&'d self) -> Self {
        let d = self.get_wgpu_device();
        let shape = self.shape();
        Array::ones(shape).into_wgpu(d)
    }
    fn eye_like(&'d self) -> Self {
        let d = self.get_wgpu_device();
        let shape = self.shape();
        Array::eye(shape[0]).into_dyn().into_wgpu(d)
    }
    fn add(&self, other: &Self) -> Self {
        self + other
    }
    fn sub(&self, other: &Self) -> Self {
        self - other
    }
    fn mul(&self, other: &Self) -> Self {
        self * other
    }
    fn div(&self, other: &Self) -> Self {
        self / other
    }
    fn matmul(&self, other: &Self) -> Self {
        //TODO: Remove cloning (maybe by passing Raw)
        let x = self
            .clone()
            .into_dimensionality::<Ix2>()
            .expect("Not a 2-D matrix");
        let y = other
            .clone()
            .into_dimensionality::<Ix2>()
            .expect("Not a 2-D matrix");

        (x.dot(&y)).into_dyn()
    }
    fn t(&self) -> Self {
        self.clone().reversed_axes()
    }
    fn expm(&self) -> Self {
        self.clone().exp()
    }
}

///
/// Rust doesn't have an easy way to deal with cyclic pointers.
/// So this Raw struct exposes unsafe code
/// TODO: Look into using Rc/Weak
///
pub struct Raw<'d, T: TensorType<'d>> {
    pub data: *mut T,
    pub _phantom: PhantomData<&'d ()>,
}

impl<'d, T: TensorType<'d>> Raw<'d, T> {
    pub fn new(data: T) -> Self {
        let data = Box::new(data);
        let data = Box::into_raw(data);
        Raw {
            data,
            _phantom: PhantomData,
        }
    }

    pub fn value(&self) -> &'d T {
        //TODO: Manually drop memory?
        unsafe { &*self.data }
    }

    pub fn get_box(&self) -> Box<T> {
        unsafe { Box::from_raw(self.data) }
    }
}

impl<'d, T: TensorType<'d>> Copy for Raw<'d, T> {}

impl<'d, T: TensorType<'d>> Clone for Raw<'d, T> {
    fn clone(&self) -> Self {
        *self
    }
}

///
/// A Tensor struct which holds a reference to the rest of the computational graph
/// as well as an index of where it is on the graph
///
/// Tensors are created through the Graph:
///
/// ```
/// let g = Graph::new();
/// let t = g.tensor(...);
/// ```
///
pub struct Tensor<'d, 'g, T: 'd + TensorType<'d> + Clone> {
    pub graph: &'g Graph<'d, T>,
    pub index: usize,
}

impl<'d, T: TensorType<'d> + Clone> Copy for Tensor<'d, '_, T> {}

impl<'d, T: TensorType<'d> + Clone> Clone for Tensor<'d, '_, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<'d, 'g, T: 'd + TensorType<'d> + Clone> Tensor<'d, 'g, T> {
    ///
    /// Returns a CPU copy of the data represented by the Tensor
    ///
    pub fn value(&self) -> ndarray::ArrayD<f32> {
        let nodes = self.graph.nodes.borrow();
        let node = nodes[self.index].borrow();
        let val = node.value.as_ref().expect("Was forward called?");

        val.value().get_value_cpu()
    }

    ///
    /// Returns a CPU copy of the gradient
    ///
    pub fn grad(&self) -> ndarray::ArrayD<f32> {
        let nodes = self.graph.nodes.borrow();
        let node = nodes[self.index].borrow();
        let val = node.grad.as_ref().expect("Was backward called?");
        val.value().get_value_cpu()
    }

    ///
    /// Do a forward pass stopping at the current node
    ///
    /// TODO: this should ideally only flow through nodes that matter
    ///
    pub fn forward(&self) {
        let nodes = self.graph.nodes.borrow_mut();

        for i in 0..self.index + 1 {
            let mut node = nodes[i].borrow_mut();
            let d_0 = node.deps[0];
            let d_1 = node.deps[1];
            match &mut node.func {
                Function::None => (),
                Function::One(f) => {
                    let n_l: Raw<T> = nodes[d_0].borrow().value.unwrap();
                    node.value = Some(f.forward(n_l));
                }
                Function::Two(f) => {
                    let n_l: Raw<T> = nodes[d_0].borrow().value.unwrap();
                    let n_r: Raw<T> = nodes[d_1].borrow().value.unwrap();
                    node.value = Some(f.forward(n_l, n_r));
                }
            }
        }
    }

    ///
    /// A backward pass to compute the gradients.
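    /// (The pass walks the node list in reverse: each node's function reports the
    /// gradients of its inputs via backward(), and these are written into the
    /// dependencies' grad slots, summing whenever a dependency already has a gradient.)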
    ///
    /// An initial gradient is required, and in typical applications is usually
    /// all ones
    ///
    pub fn backward(&self, init: T) {
        let len = self.graph.len();
        let nodes = self.graph.nodes.borrow();

        {
            let mut node = nodes[self.index].borrow_mut();
            node.grad = Some(Raw::new(init));
        }

        for i in (0..len).rev() {
            {
                let mut node = nodes[i].borrow_mut();

                match &node.func {
                    Function::None => (),
                    Function::One(f) => node.ctx = f.backward(node.grad.unwrap()),
                    Function::Two(f) => node.ctx = f.backward(node.grad.unwrap()),
                }
            }

            let node = nodes[i].borrow();

            for j in 0..2 {
                if std::ptr::eq(&*node, nodes[node.deps[j]].as_ptr()) {
                    continue;
                }
                let mut node_d = nodes[node.deps[j]].borrow_mut();

                if let Some(grad) = &node_d.grad {
                    if let Some(w) = &node.ctx[j] {
                        unsafe {
                            *grad.data = grad.value().add(w.value());
                        }
                    }
                } else if let Some(w) = &node.ctx[j] {
                    node_d.grad = Some(Raw::new(w.value().clone()));
                }
            }
        }
    }

    pub fn matmul(self, other: Tensor<'d, 'g, T>) -> Tensor<'d, 'g, T> {
        assert_eq!(
            self.graph as *const Graph<T>,
            other.graph as *const Graph<T>
        );
        let mut nodes = self.graph.nodes.borrow_mut();

        let len = nodes.len();

        use crate::functions::MatMul;
        nodes.push(RefCell::new(Node {
            deps: [self.index, other.index],
            func: Function::Two(
                MatMul {
                    x_ctx: None,
                    y_ctx: None,
                }
                .into(),
            ),
            value: None,
            grad: None,
            ctx: [None, None],
        }));
        Tensor {
            graph: self.graph,
            index: len,
        }
    }

    ///
    /// Take a matrix exponential
    ///
    /// Note: The matrix must be diagonal.
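    /// (ExpM::forward in src/functions.rs masks an element-wise exp with an identity
    /// matrix, which only matches the true matrix exponential for diagonal input.)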
    ///
    /// TODO: Repeated squaring + Pade approximation for general case
    ///
    pub fn expm(self) -> Tensor<'d, 'g, T> {
        let mut nodes = self.graph.nodes.borrow_mut();

        let len = nodes.len();
        use crate::functions::ExpM;
        nodes.push(RefCell::new(Node {
            deps: [self.index, self.index],
            func: Function::One(ExpM { a: None, res: None }.into()),
            value: None,
            grad: None,
            ctx: [None, None],
        }));
        Tensor {
            graph: self.graph,
            index: len,
        }
    }
}

impl<'d, 'g, T: TensorType<'d> + Clone> ::std::ops::Add for Tensor<'d, 'g, T> {
    type Output = Tensor<'d, 'g, T>;
    fn add(self, other: Tensor<'d, 'g, T>) -> Self::Output {
        assert_eq!(
            self.graph as *const Graph<T>,
            other.graph as *const Graph<T>
        );
        let mut nodes = self.graph.nodes.borrow_mut();

        let len = nodes.len();

        use crate::functions::Add;
        nodes.push(RefCell::new(Node {
            deps: [self.index, other.index],
            func: Function::Two(Add.into()),
            value: None,
            grad: None,
            ctx: [None, None],
        }));
        Tensor {
            graph: self.graph,
            index: len,
        }
    }
}

impl<'d, 'g, T: TensorType<'d> + Clone> ::std::ops::Mul for Tensor<'d, 'g, T> {
    type Output = Tensor<'d, 'g, T>;
    fn mul(self, other: Tensor<'d, 'g, T>) -> Self::Output {
        assert_eq!(
            self.graph as *const Graph<T>,
            other.graph as *const Graph<T>
        );
        let mut nodes = self.graph.nodes.borrow_mut();

        let len = nodes.len();

        let m: Mul<'d, T> = Mul {
            x_ctx: None,
            y_ctx: None,
        };

        let func: Function<'d, T> = Function::Two(m.into());

        use crate::functions::Mul;
        nodes.push(RefCell::new(Node {
            deps: [self.index, other.index],
            func,
            value: None,
            grad: None,
            ctx: [None, None],
        }));
        Tensor {
            graph: self.graph,
            index: len,
        }
    }
}
--------------------------------------------------------------------------------