├── data ├── bigram │ ├── names_shuffle.txt │ ├── tensor_b2.json │ ├── tensor_b1.json │ ├── tensor_C.json │ └── weight_file.json ├── .DS_Store └── batchnorm │ ├── b2.json │ ├── bngain.json │ ├── bnbias.json │ └── C.json ├── examples ├── linear_layer_test.rs ├── 2_autograd_backprop.rs ├── 1_tensor_basics.rs ├── linear_and_adam_optimizer.rs ├── 3_modules_and_models.rs ├── mlp.rs └── wavenet.rs ├── .DS_Store ├── .gitignore ├── icon.webp ├── key_bias.json ├── query_bias.json ├── value_bias.json ├── src ├── nn │ ├── mod.rs │ ├── model │ │ ├── mod.rs │ │ ├── model.rs │ │ └── transfomer.rs │ └── layers │ │ ├── module.rs │ │ ├── tanh.rs │ │ ├── mod.rs │ │ ├── dropout.rs │ │ ├── embedding.rs │ │ ├── batch_norm.rs │ │ ├── attention.rs │ │ ├── linear.rs │ │ ├── mlp.rs │ │ ├── layer_norm.rs │ │ └── casual_self_attention.rs ├── optimizers │ ├── optimizer.rs │ ├── mod.rs │ ├── adamw.rs │ └── adam.rs ├── central │ ├── grad_control.rs │ ├── mod.rs │ ├── internal_tensor.rs │ ├── operation.rs │ ├── indexable.rs │ ├── add_op.rs │ ├── mul_op.rs │ ├── shape.rs │ ├── view.rs │ └── matmul_op.rs └── lib.rs ├── key.json ├── query.json ├── value.json ├── linear_check.py ├── Cargo.toml ├── double_check.py ├── README.md ├── gpt_check.py └── embedding.json /data/bigram/names_shuffle.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/linear_layer_test.rs: -------------------------------------------------------------------------------- 1 | fn main() { 2 | 3 | } -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TuckerBMorgan/poro/HEAD/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | *.bin 3 | *.json 4 | *.txt 5 | *.log 6 | */data/* -------------------------------------------------------------------------------- /icon.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TuckerBMorgan/poro/HEAD/icon.webp -------------------------------------------------------------------------------- /key_bias.json: -------------------------------------------------------------------------------- 1 | { 2 | "data":[-0.2869], 3 | "shape":[1, 1] 4 | } -------------------------------------------------------------------------------- /query_bias.json: -------------------------------------------------------------------------------- 1 | { 2 | "data":[0.0731], 3 | "shape":[1, 1] 4 | } -------------------------------------------------------------------------------- /value_bias.json: -------------------------------------------------------------------------------- 1 | { 2 | "data":[-0.0508], 3 | "shape":[1, 1] 4 | } -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TuckerBMorgan/poro/HEAD/data/.DS_Store -------------------------------------------------------------------------------- /src/nn/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod layers; 2 | pub mod model; 3 | 4 | pub use layers::*; 5 | pub use model::*; 6 | 
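The small JSON weight files in this repository (key_bias.json, query_bias.json, and value_bias.json above, as well as key.json, query.json, value.json, embedding.json, and the files under data/) all share one layout: a flat "data" array of values plus a "shape" array giving the tensor dimensions. The loader the crate itself uses for these files is not shown in this dump, so the sketch below is only one possible way to read them, using the serde_json dependency from Cargo.toml together with Tensor::from_vec; the load_tensor helper name is invented for illustration.

```rust
use poro::Tensor;

// Reads a {"data": [...], "shape": [...]} file and builds a Tensor from it.
// Hypothetical helper for illustration; the crate's own loading path may differ.
fn load_tensor(path: &str) -> Tensor {
    let text = std::fs::read_to_string(path).expect("could not read tensor file");
    let json: serde_json::Value = serde_json::from_str(&text).expect("malformed tensor json");
    let data: Vec<f32> = json["data"]
        .as_array()
        .expect("missing data array")
        .iter()
        .map(|v| v.as_f64().unwrap() as f32)
        .collect();
    let shape: Vec<usize> = json["shape"]
        .as_array()
        .expect("missing shape array")
        .iter()
        .map(|v| v.as_u64().unwrap() as usize)
        .collect();
    // Tensor::from_vec takes the flat data plus a Shape converted from the index list.
    Tensor::from_vec(data, shape.into())
}

fn main() {
    let key_bias = load_tensor("key_bias.json");
    println!("{:?}", key_bias.item());
}
```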
-------------------------------------------------------------------------------- /key.json: -------------------------------------------------------------------------------- 1 | { 2 | "data":[-0.0881, -0.1676, 0.2705, -0.2251, 0.1758, -0.2274, -0.0687, -0.1183, 0.1939, 0.0342], 3 | "shape":[10, 1] 4 | } -------------------------------------------------------------------------------- /query.json: -------------------------------------------------------------------------------- 1 | { 2 | "data":[-0.2163, -0.2105, 0.0501, -0.1863, 0.2465, 0.3135, -0.2091, -0.0489, 0.2379, -0.0757], 3 | "shape":[10, 1] 4 | } -------------------------------------------------------------------------------- /value.json: -------------------------------------------------------------------------------- 1 | { 2 | "data":[-0.1852, -0.1070, -0.1148, -0.0180, -0.1549, -0.0210, 0.0651, 0.0044, 0.0598, -0.2795], 3 | "shape":[10, 1] 4 | } -------------------------------------------------------------------------------- /src/nn/model/mod.rs: -------------------------------------------------------------------------------- 1 | mod model; 2 | mod transfomer; 3 | 4 | pub use model::{Model, Sequential}; 5 | pub use transfomer::{DecoderLayer, DecoderOnlyTransformer, PositionalEncoding}; 6 | -------------------------------------------------------------------------------- /src/optimizers/optimizer.rs: -------------------------------------------------------------------------------- 1 | use crate::nn::Module; 2 | 3 | pub trait Optimizer { 4 | fn record_parameters(&mut self, model: &dyn Module); 5 | fn step(&mut self, model: &mut dyn Module); 6 | } -------------------------------------------------------------------------------- /src/optimizers/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod optimizer; 2 | pub mod adam; 3 | pub mod adamw; 4 | 5 | pub use optimizer::Optimizer; 6 | pub use adam::AdamBuilder; 7 | pub use adam::Adam; 8 | pub use adamw::AdamW; 9 | pub use adamw::AdamWBuilder; -------------------------------------------------------------------------------- /src/central/grad_control.rs: -------------------------------------------------------------------------------- 1 | pub struct NoGrad {} 2 | 3 | impl NoGrad { 4 | pub fn new() -> NoGrad { 5 | let mut equation = super::get_equation(); 6 | equation.disable_grad(); 7 | NoGrad {} 8 | } 9 | } 10 | 11 | impl Drop for NoGrad { 12 | fn drop(&mut self) { 13 | let mut equation = super::get_equation(); 14 | equation.enable_grad(); 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/nn/layers/module.rs: -------------------------------------------------------------------------------- 1 | 2 | // This could maybe be its own lib, along with model.rs 3 | use crate::central::Tensor; 4 | 5 | 6 | pub trait Module { 7 | fn forward(&mut self, x: &Tensor) -> Tensor; 8 | fn get_parameters(&self) -> Vec; 9 | fn set_requires_grad(&mut self, requires_grad: bool) { 10 | for mut parameter in self.get_parameters() { 11 | parameter.set_requires_grad(requires_grad); 12 | } 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /src/nn/layers/tanh.rs: -------------------------------------------------------------------------------- 1 | use super::Module; 2 | use crate::central::Tensor; 3 | 4 | pub struct Tanh {} 5 | 6 | impl Tanh { 7 | #[allow(unused)] 8 | pub fn new() -> Tanh { 9 | Tanh {} 10 | } 11 | } 12 | 13 | impl Module for Tanh { 14 | fn forward(&mut 
self, x: &Tensor) -> Tensor { 15 | x.tanh() 16 | } 17 | 18 | fn get_parameters(&self) -> Vec { 19 | Vec::new() 20 | } 21 | } 22 | 23 | impl From for Box { 24 | fn from(layer: Tanh) -> Box { 25 | Box::new(layer) 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /data/batchnorm/b2.json: -------------------------------------------------------------------------------- 1 | {"data": [0.6553500890731812, -0.9524102807044983, -0.28236764669418335, 1.4239068031311035, -1.7644256353378296, -1.4027190208435059, 0.17747284471988678, 0.7741606831550598, -1.0762460231781006, 0.39529699087142944, 0.10165512561798096, 0.5887088775634766, 0.9897605776786804, -0.43960240483283997, 0.42898663878440857, -0.4913828372955322, 2.642697811126709, 1.1716099977493286, -0.1792488992214203, -0.8400498628616333, -0.9447303414344788, 0.9816269874572754, -0.15779687464237213, -0.34383338689804077, -1.4714620113372803, -1.0330257415771484, 1.2502636909484863], "shape": [27]} -------------------------------------------------------------------------------- /data/bigram/tensor_b2.json: -------------------------------------------------------------------------------- 1 | {"shape": [1, 27], "data": [0.13731364905834198, 1.2041877508163452, 1.9441583156585693, 0.30936285853385925, 2.019181251525879, 0.008597055450081825, -0.8007318377494812, 0.6119608283042908, 1.789738416671753, -0.6736906170845032, -1.2737611532211304, -1.128475308418274, 0.2866993844509125, 0.1687377542257309, 1.123108983039856, 0.6549684405326843, -2.7028048038482666, 1.2879042625427246, 0.009431934915482998, -0.6419469714164734, 1.3549457788467407, -1.9539179801940918, 1.1203961372375488, 0.023704156279563904, 0.8101799488067627, -0.79859858751297, -0.2513209283351898]} -------------------------------------------------------------------------------- /src/nn/layers/mod.rs: -------------------------------------------------------------------------------- 1 | mod attention; 2 | mod batch_norm; 3 | mod module; 4 | mod tanh; 5 | mod embedding; 6 | mod dropout; 7 | mod linear; 8 | mod mlp; 9 | mod layer_norm; 10 | mod casual_self_attention; 11 | 12 | pub use batch_norm::BatchNorm1d; 13 | pub use module::Module; 14 | pub use tanh::Tanh; 15 | pub use embedding::Embedding; 16 | pub use attention::*; 17 | pub use dropout::Dropout; 18 | pub use linear::{LinearLayer, LinearLayerConfig}; 19 | pub use mlp::{MLP, NewGLU}; 20 | pub use layer_norm::LayerNorm; 21 | pub use casual_self_attention::{CasualSelfAttention, CasualSelfAttentionConfig}; -------------------------------------------------------------------------------- /linear_check.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | def main(): 4 | # I want a linear layer of shape 2, 2, with weights 1, 2, 3 and 4 5 | linear = nn.Linear(2, 2) 6 | linear.weight.data = torch.tensor([[1, 2], [3, 4]], dtype=torch.float) 7 | linear.bias.data = torch.tensor([0, 0], dtype=torch.float) 8 | print(linear.weight) 9 | 10 | test_input = torch.tensor([[1, 2]], dtype=torch.float) 11 | 12 | ouput = linear(test_input) 13 | 14 | print(test_input @ linear.weight.data.t()) 15 | print(ouput) 16 | 17 | 18 | 19 | 20 | 21 | main() -------------------------------------------------------------------------------- /src/nn/layers/dropout.rs: -------------------------------------------------------------------------------- 1 | use crate::nn::Module; 2 | use crate::Tensor; 3 | 4 | pub struct Dropout { 5 | p: f32 6 | } 7 | 8 
| impl Dropout { 9 | pub fn new(p: f32) -> Self { 10 | Dropout { p } 11 | } 12 | } 13 | 14 | impl Module for Dropout { 15 | fn forward(&mut self, input: &Tensor) -> Tensor { 16 | // TODO: This is really slow, as it allocates new tensors for each operation 17 | let mask = Tensor::randn(input.shape); 18 | let mask = mask * self.p; 19 | let mask = mask / (1.0 - self.p); 20 | let mask = mask * *input; 21 | return mask; 22 | } 23 | 24 | fn get_parameters(&self) -> Vec { 25 | Vec::new() 26 | } 27 | } -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "poro" 3 | version = "0.1.2" 4 | edition = "2021" 5 | authors = ["tucker morgan tucker.bull.morgan@gmail.com"] 6 | description = "A simple toy neural network library" 7 | license = "MIT" 8 | repository = "https://github.com/tuckerbmorgan/hermes" 9 | 10 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 11 | 12 | [dependencies] 13 | lazy_static = "1.4.0" 14 | ndarray = { version = "0.15.6", features = ["rayon"]} 15 | rand = "0.8.5" 16 | rand_distr = "0.4.3" 17 | serde = "1.0.201" 18 | serde_json = "1.0.117" 19 | log = "0.4.22" 20 | simplelog = "0.11" 21 | tokenizers = {version = "0.20.3", features = ["http"]} 22 | 23 | [target.'cfg(windows)'.dependencies] 24 | cudarc = { version="0.11.1", features = [ 25 | "cuda-version-from-build-system", 26 | ]} 27 | 28 | [target.'cfg(macos)'.dependencies] 29 | metal = "0.30.0" -------------------------------------------------------------------------------- /examples/2_autograd_backprop.rs: -------------------------------------------------------------------------------- 1 | use poro::{update_parameters, Tensor}; 2 | 3 | fn main() { 4 | // You allocate tensors just like you would in PyTorch 5 | // Then automatically calculate the gradients 6 | // you do need to call set_requires_grad(true) on the tensors you want to calculate the gradients for 7 | let mut a = Tensor::randn(vec![1].into()); 8 | a.set_requires_grad(true); 9 | let mut b = Tensor::randn(vec![1].into()); 10 | b.set_requires_grad(true); 11 | 12 | let c = a + b; 13 | 14 | println!( 15 | "a {:?} + b {:?} = {:?}", 16 | a.item().to_string(), 17 | b.item().to_string(), 18 | c.item().to_string() 19 | ); 20 | 21 | // And then simply call backward() on the result tensor 22 | c.backward(); 23 | println!("a.grad: {:?}", a.grad().to_string()); 24 | // And then you need to call update_parameters() to update the parameters 25 | update_parameters(-0.01); 26 | } 27 | -------------------------------------------------------------------------------- /data/batchnorm/bngain.json: -------------------------------------------------------------------------------- 1 | {"data": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "shape": [200]} -------------------------------------------------------------------------------- /data/batchnorm/bnbias.json: -------------------------------------------------------------------------------- 1 | {"data": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "shape": [1, 200]} -------------------------------------------------------------------------------- /src/central/mod.rs: -------------------------------------------------------------------------------- 1 | use lazy_static::lazy_static; 2 | 3 | use std::sync::{Mutex, MutexGuard}; 4 | mod add_op; 5 | mod equation; 6 | mod grad_control; 7 | mod indexable; 8 | mod internal_tensor; 9 | mod matmul_op; 10 | mod mul_op; 11 | mod operation; 12 | mod shape; 13 | mod tensor; 14 | mod view; 15 | 16 | pub use add_op::backward; 17 | pub use equation::{BackpropagationPacket, Equation}; 18 | pub use grad_control::NoGrad; 19 | pub use indexable::Indexable; 20 | pub use shape::Shape; 21 | pub use tensor::{Tensor, TensorID}; 22 | 23 | lazy_static! 
{ 24 | static ref SINGLETON_INSTANCE: Mutex = Mutex::new(Equation::new()); 25 | } 26 | 27 | pub fn get_equation() -> MutexGuard<'static, Equation> { 28 | SINGLETON_INSTANCE.lock().unwrap() 29 | } 30 | 31 | pub fn zero_all_grads() { 32 | let mut equation = get_equation(); 33 | equation.zero_all_grads(); 34 | } 35 | 36 | pub fn update_parameters(learning_rate: f32) { 37 | let mut equation = get_equation(); 38 | equation.update_parameters(learning_rate); 39 | } 40 | -------------------------------------------------------------------------------- /examples/1_tensor_basics.rs: -------------------------------------------------------------------------------- 1 | use poro::Tensor; 2 | 3 | fn main() { 4 | // You allocate tensors just like you would in PyTorch 5 | let a = Tensor::randn(vec![1].into()); 6 | let b = Tensor::randn(vec![1].into()); 7 | print!( 8 | "You can add tensors together a {:?} + b {:?}", 9 | a.item().to_string(), 10 | b.item().to_string() 11 | ); 12 | // And then you can treat them like you would in PyTorch 13 | let c = a + b; 14 | 15 | println!(" = {:?}", c.item().to_string()); 16 | 17 | // They support most basic math operations 18 | let d = Tensor::randn(vec![1].into()); 19 | let e = Tensor::randn(vec![1].into()); 20 | print!( 21 | "You can multiply them as well d {:?} + e {:?}", 22 | d.item().to_string(), 23 | e.item().to_string() 24 | ); 25 | let f = d * e; 26 | println!(" = {:?}", f.item().to_string()); 27 | 28 | // And you can chain them together 29 | let g = Tensor::randn(vec![1].into()); 30 | let h = Tensor::randn(vec![1].into()); 31 | let i = Tensor::randn(vec![1].into()); 32 | print!( 33 | "You can chain them together g {:?} + h {:?} * i {:?}", 34 | g.item().to_string(), 35 | h.item().to_string(), 36 | i.item().to_string() 37 | ); 38 | let j = g + h * i; 39 | 40 | println!(" = {:?}", j.item().to_string()); 41 | } 42 | -------------------------------------------------------------------------------- /double_check.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import struct 5 | import math 6 | 7 | 8 | def write_floats_to_file(path: str, data) -> None: 9 | with open(path, 'a') as file: 10 | for value in data: 11 | file.write(f"{value}\n") 12 | 13 | def append_string_to_file(path: str, content: str) -> None: 14 | with open(path, 'a') as file: 15 | file.write(f"{content}\n") 16 | 17 | def write_fp32(tensor, file): 18 | # first write the length of the tensor's shape 19 | shape = torch.tensor(tensor.size(), dtype=torch.int32) 20 | # write the number of dimensions 21 | file.write(struct.pack(" InternalTensor { 20 | InternalTensor { 21 | tensor_id, 22 | shape, 23 | operation, 24 | data_start_index, 25 | grad_start_index, 26 | requires_grad: false, 27 | } 28 | } 29 | 30 | pub fn dependencies(&self) -> Vec { 31 | match self.operation { 32 | Operation::Nop => vec![], 33 | Operation::Add(a, b) => vec![a, b], 34 | Operation::Mul(a, b) => vec![a, b], 35 | Operation::Exp(a) => vec![a], 36 | Operation::Pow(base, power) => vec![base, power], 37 | Operation::MatMul(a, b) => vec![a, b], 38 | Operation::Sum(a, _) => vec![a], 39 | Operation::Broadcast(a, _) => vec![a], 40 | Operation::Log(a) => vec![a], 41 | Operation::View(a, _index) => vec![a], 42 | Operation::Mean(a, _axes) => vec![a], 43 | Operation::Concat(a, b) => vec![a, b], 44 | Operation::Reshape(a, _) => vec![a], 45 | Operation::Tanh(a) => vec![a], 46 | Operation::Transpose(a, _, _) => vec![a], 47 | 
Operation::Sin(a) => vec![a], 48 | Operation::Cos(a) => vec![a], 49 | Operation::MaskedFill(a, _, _) => vec![a], 50 | Operation::Embedding(a, _) => vec![a], 51 | } 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/nn/layers/embedding.rs: -------------------------------------------------------------------------------- 1 | use crate::{Shape, Tensor}; 2 | 3 | use super::Module; 4 | 5 | pub struct Embedding { 6 | pub tensor: Tensor, 7 | vocab_size: usize, 8 | model_dimension: usize, 9 | } 10 | 11 | impl Embedding { 12 | pub fn new(vocab_size: usize, model_dimension: usize) -> Embedding { 13 | let tensor = Tensor::randn(Shape::new(vec![vocab_size, model_dimension])); 14 | Embedding { tensor, model_dimension, vocab_size } 15 | } 16 | 17 | pub fn from_pretrained(pretrained: Tensor) -> Embedding { 18 | Embedding { tensor: pretrained, model_dimension: pretrained.shape.indices[1], vocab_size: pretrained.shape.indices[0] } 19 | } 20 | 21 | pub fn from_tensor(tensor: Tensor) -> Embedding { 22 | Embedding { tensor, model_dimension: tensor.shape.indices[1], vocab_size: tensor.shape.indices[0] } 23 | } 24 | } 25 | 26 | impl Module for Embedding { 27 | fn forward(&mut self, input: &Tensor) -> Tensor { 28 | 29 | return self.tensor.embbeding(input, self.model_dimension); 30 | 31 | let mut test_index_tensor = Tensor::zeroes(Shape::new(vec![input.shape.indices[0], input.shape.indices[1], self.model_dimension])); 32 | let data = self.tensor.item(); 33 | 34 | for b in 0..input.shape.indices[0] { 35 | for t in 0..input.shape.indices[1] { 36 | let view = input.view([b, t].into()); 37 | let index = view.item()[[0]] as usize; 38 | for i in 0..self.model_dimension { 39 | let datum = data[[index, i]]; 40 | test_index_tensor.set_index( 41 | [b,t, i].into(), 42 | vec![datum] 43 | ); 44 | } 45 | } 46 | } 47 | return test_index_tensor; 48 | } 49 | 50 | fn get_parameters(&self) -> Vec { 51 | vec![self.tensor.clone()] 52 | } 53 | } -------------------------------------------------------------------------------- /src/nn/layers/batch_norm.rs: -------------------------------------------------------------------------------- 1 | use super::Module; 2 | use crate::{central::Tensor, Shape}; 3 | 4 | /// This struct represents a Batch Normalization layer for 1D tensors. 5 | /// It is used to normalize the activations of the previous layer at each batch. 6 | pub struct BatchNorm1d { 7 | /// The gain tensor of the BatchNorm1d layer. 8 | gain: Tensor, 9 | /// The bias tensor of the BatchNorm1d layer. 
10 | bias: Tensor, 11 | } 12 | 13 | impl BatchNorm1d { 14 | #[allow(unused)] 15 | pub fn new(number_of_weights: usize) -> BatchNorm1d { 16 | // Initialize the gain and bias tensors 17 | let gain = Tensor::ones(Shape::new(vec![number_of_weights])); 18 | let bias = Tensor::zeroes(Shape::new(vec![number_of_weights])); 19 | BatchNorm1d { gain, bias } 20 | } 21 | } 22 | 23 | impl Module for BatchNorm1d { 24 | fn forward(&mut self, x: &Tensor) -> Tensor { 25 | // Normalize the input, then scale by the gain and shift by the bias 26 | if x.shape.number_of_indices == 2 { 27 | let bnmeani = x.mean(vec![0]); 28 | let bnvari = x.std(vec![0]); 29 | let offset = *x - bnmeani; 30 | let numer = offset * self.gain; 31 | let hpreact = numer / bnvari + self.bias; 32 | return hpreact; 33 | } else if x.shape.number_of_indices == 3 { 34 | let bnmeani = x.mean(vec![0, 1]); 35 | let bnvari = x.std(vec![0, 1]); 36 | let offset = *x - bnmeani; 37 | let numer = offset * self.gain; 38 | let hpreact = numer / bnvari + self.bias; 39 | return hpreact; 40 | } else { 41 | panic!("BatchNorm1d only supports 2D and 3D tensors"); 42 | } 43 | } 44 | 45 | fn get_parameters(&self) -> Vec<Tensor> { 46 | vec![self.gain.clone(), self.bias.clone()] 47 | } 48 | } 49 | 50 | impl From<BatchNorm1d> for Box<dyn Module> { 51 | fn from(layer: BatchNorm1d) -> Box<dyn Module> { 52 | Box::new(layer) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /src/nn/layers/attention.rs: -------------------------------------------------------------------------------- 1 | use crate::nn::layers::LinearLayer; 2 | use crate::nn::Module; 3 | use crate::Tensor; 4 | use crate::nn::layers::linear::LinearLayerConfig; 5 | 6 | pub struct AttentionHead { 7 | pub q: LinearLayer, 8 | pub k: LinearLayer, 9 | pub v: LinearLayer, 10 | pub mask: Option<Tensor> 11 | } 12 | 13 | impl AttentionHead { 14 | pub fn new(d_model: usize, heads: usize) -> Self { 15 | let common_config = LinearLayerConfig { 16 | number_of_inputs: d_model, 17 | number_of_weights: heads 18 | }; 19 | AttentionHead { 20 | q: LinearLayer::new(common_config), 21 | k: LinearLayer::new(common_config), 22 | v: LinearLayer::new(common_config), 23 | mask: None 24 | } 25 | } 26 | 27 | pub fn from_pretrained(q: LinearLayer, k: LinearLayer, v: LinearLayer) -> Self { 28 | AttentionHead { 29 | q, 30 | k, 31 | v, 32 | mask: None 33 | } 34 | } 35 | 36 | pub fn set_mask(&mut self, mask: Tensor) { 37 | self.mask = Some(mask); 38 | } 39 | } 40 | 41 | impl Module for AttentionHead { 42 | fn forward(&mut self, input: &Tensor) -> Tensor { 43 | 44 | let q = self.q.forward(input); 45 | let k = self.k.forward(input); 46 | let v = self.v.forward(input); 47 | let attention = q << k.tranpose_with_provided_axis(1, 0); 48 | let mut attention = attention / (k.shape.indices[1] as f32).sqrt(); 49 | 50 | if self.mask.is_some() { 51 | attention = attention * self.mask.unwrap(); 52 | } 53 | println!("{:?}", attention.shape); 54 | let attention = attention.softmax(attention.shape.number_of_indices - 1); 55 | 56 | let attention = attention << v; 57 | 58 | return attention; 59 | } 60 | 61 | fn get_parameters(&self) -> Vec<Tensor> { 62 | let mut parameters = Vec::new(); 63 | parameters.extend(self.q.get_parameters()); 64 | parameters.extend(self.k.get_parameters()); 65 | parameters.extend(self.v.get_parameters()); 66 | return parameters; 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NO LONGER MAINTAINED GO TO 
[Cant](https://www.github.com/TuckerBMorgan/can-t) 2 | # Poro 3 | [![Current Crates.io Version](https://img.shields.io/crates/v/poro.svg?style=for-the-badge&logo=rust)](https://crates.io/crates/poro) 4 | ![Professor Poro](./icon.webp) 5 | 6 | Poro is a library I am writing to help me better understand modern ML frameworks like [PyTorch](https://pytorch.org/) and [TensorFlow](https://www.tensorflow.org/). It is mostly based on Karpathy's micrograd series of [lectures](https://www.youtube.com/watch?v=VMj-3S1tku0). It is written in Rust because I enjoy its ease of setup. 7 | 8 | ## Features 9 | - Basic neural network operations 10 | - Frictionless autograd 11 | - CUDA support (limited but growing!) 12 | - Support for custom layers and operations 13 | - Lightweight, with a focus on being easy for others to understand 14 | 15 | ## Notes on usage 16 | I work on this library alongside my job, and it is done for the enjoyment of learning, so it is probably best not to use it for professional work, as I am unlikely to get to any issues you might have in a timely manner :) 17 | 18 | If you run `cargo test`, some tests might fail. This is not the tests actually failing: 19 | due to the way the equation works as a singleton, it is possible for a 20 | test to grab the lock at the wrong time and fail. If you know how to fix this, it would make a great first PR :). You can always run the tests one at a time and see that they pass that way. 21 | 22 | # Getting Started 23 | ```bash 24 | cargo test --release 25 | ``` 26 | 27 | ### Installation 28 | 29 | To use Poro, add the following to your `Cargo.toml`: 30 | 31 | ```toml 32 | [dependencies] 33 | poro = "0.1.2" 34 | ``` 35 | 36 | ### Usage 37 | 38 | Here is a simple example to get you started with Poro: 39 | 40 | ```rust 41 | use poro::{Shape, Tensor}; 42 | use ndarray::prelude::*; 43 | 44 | fn main() { 45 | let a = Tensor::ones(Shape::new(vec![2, 2])); 46 | let b = Tensor::zeroes(Shape::new(vec![2, 2])); 47 | let c = a + b; 48 | let result = c.item(); 49 | assert!(result == arr2(&[[1.0, 1.0], [1.0, 1.0]]).into_dyn()); 50 | } 51 | ``` 52 | 53 | ## Planned features 54 | 55 | - Optimizer module 56 | - Data loader module 57 | - Working with Metal 58 | - Transformer layer 59 | - Conv layer 60 | - Model/Module configuration 61 | -------------------------------------------------------------------------------- /src/nn/model/model.rs: -------------------------------------------------------------------------------- 1 | use crate::central::*; 2 | use crate::nn::layers::*; 3 | 4 | pub trait Model { 5 | fn forward(&mut self, x: &Tensor) -> Tensor; 6 | fn get_parameters(&self) -> Vec<Tensor>; 7 | } 8 | 9 | pub struct Sequential { 10 | pub layers: Vec<Box<dyn Module>>, 11 | } 12 | 13 | impl Sequential { 14 | #[allow(unused)] 15 | pub fn new(layers: Vec<Box<dyn Module>>) -> Self { 16 | Sequential { layers } 17 | } 18 | 19 | #[allow(unused)] 20 | pub fn set_requires_grad(&mut self, requires_grad: bool) { 21 | for layer in &mut self.layers { 22 | layer.set_requires_grad(requires_grad); 23 | } 24 | } 25 | } 26 | 27 | impl Model for Sequential { 28 | fn forward(&mut self, x: &Tensor) -> Tensor { 29 | let mut output = x.clone(); 30 | for layer in &mut self.layers { 31 | output = layer.forward(&output); 32 | } 33 | output 34 | } 35 | 36 | fn get_parameters(&self) -> Vec<Tensor> { 37 | let mut parameters = Vec::new(); 38 | for layer in &self.layers { 39 | parameters.extend(layer.get_parameters()); 40 | } 41 | parameters 42 | } 43 | } 44 | 45 | impl From<Vec<Box<dyn Module>>> for Sequential { 46 | fn from(modules: Vec<Box<dyn Module>>) -> 
Sequential { 47 | Sequential::new(modules) 48 | } 49 | } 50 | 51 | #[test] 52 | fn linear_model() { 53 | let mut linear_layer_config = LinearLayerConfig::default(); 54 | linear_layer_config.number_of_inputs = 3; 55 | linear_layer_config.number_of_weights = 1; 56 | 57 | let mut linear_model = Sequential::new(vec![Box::new(LinearLayer::new(linear_layer_config))]); 58 | 59 | let inputs = vec![ 60 | vec![2.0f32, 3.0, -1.0], 61 | vec![3.0, -1.0, 0.5], 62 | vec![0.5, 1.0, 1.0], 63 | vec![1.0, 1.0, -1.0], 64 | ]; 65 | 66 | let inputs_as_tensor = Tensor::from_vec( 67 | inputs.iter().flatten().map(|x| *x).collect(), 68 | vec![4, 3].into(), 69 | ); 70 | 71 | let outputs = vec![1.0f32, -1.0, -1.0, 1.0]; 72 | 73 | let outputs_as_tensor = 74 | Tensor::from_vec(outputs.iter().map(|x| *x).collect(), vec![4, 1].into()); 75 | 76 | for _ in 0..50 { 77 | zero_all_grads(); 78 | let prediction = linear_model.forward(&inputs_as_tensor); 79 | let loss = (prediction - outputs_as_tensor).pow(2.0); 80 | loss.backward(); 81 | update_parameters(-0.01); 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /gpt_check.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def read_file(filepath): 4 | """Reads a file and returns its lines as a list""" 5 | with open(filepath, 'r') as file: 6 | return file.readlines() 7 | 8 | def calculate_percent_error(value1, value2): 9 | """Calculates the percent error between two numbers""" 10 | if value1 == value2 == 0: 11 | return 0.0 12 | # handle masking of -inf values 13 | if value1 == float('-inf') or value2 == float('-inf'): 14 | return 0.0 15 | try: 16 | return abs((value1 - value2) / value1) * 100 17 | except ZeroDivisionError: 18 | return np.inf 19 | 20 | def process_files(rust_file, python_file): 21 | """Processes two files, calculates percentage error per zone""" 22 | rust_lines = read_file(rust_file) 23 | python_lines = read_file(python_file) 24 | 25 | if len(rust_lines) != len(python_lines): 26 | print("Error: Files have different numbers of lines.") 27 | return 28 | 29 | current_zone = None 30 | zone_errors = {} 31 | 32 | for rust_line, python_line in zip(rust_lines, python_lines): 33 | rust_line = rust_line.strip() 34 | python_line = python_line.strip() 35 | 36 | # Track zone if line starts with $ 37 | if rust_line.startswith('$'): 38 | current_zone = rust_line 39 | zone_errors[current_zone] = [] 40 | continue 41 | 42 | if python_line.startswith('$'): 43 | continue # Ignore python zone markers, just rely on rust_file 44 | 45 | try: 46 | rust_value = float(rust_line) 47 | python_value = float(python_line) 48 | except ValueError: 49 | # Skip lines that are not numerical 50 | continue 51 | 52 | # Calculate percent error 53 | percent_error = calculate_percent_error(rust_value, python_value) 54 | zone_errors[current_zone].append(percent_error) 55 | 56 | # Calculate total percent error for each zone 57 | for zone, errors in zone_errors.items(): 58 | if errors: 59 | avg_error = sum(errors) / len(errors) 60 | else: 61 | avg_error = 0 62 | print(f"Zone: {zone}, Total % Error: {avg_error:.2f}%") 63 | 64 | if __name__ == "__main__": 65 | rust_file = "rust_checkfile.txt" # Replace with your rust file path 66 | python_file = "python_checkfile.txt" # Replace with your python file path 67 | 68 | process_files(rust_file, python_file) 69 | -------------------------------------------------------------------------------- /examples/3_modules_and_models.rs: 
-------------------------------------------------------------------------------- 1 | use log::info; 2 | use poro::central::*; 3 | use poro::nn::layers::LinearLayer; 4 | use poro::nn::model::Sequential; 5 | use poro::nn::Model; 6 | use poro::nn::LinearLayerConfig; 7 | use poro::nn::Module; 8 | use poro::Array2; 9 | 10 | use std::fs::File; 11 | use std::path::Path; 12 | 13 | use simplelog::*; 14 | 15 | fn main() { 16 | 17 | WriteLogger::init( 18 | LevelFilter::Info, // Set the log level 19 | Config::default(), // Use the default configuration 20 | File::create(Path::new("modules_and_models.log")).unwrap(), // Create or open the log file 21 | ).unwrap(); 22 | 23 | let mut tensor_from = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2].into()); 24 | let mut bias = Tensor::from_vec(vec![0.0, 0.0], vec![2].into()); 25 | let mut linear_from = LinearLayer::from_weights_and_bias(tensor_from, bias); 26 | let input = Tensor::from_vec(vec![1.0, 2.0], vec![1, 2].into()); 27 | 28 | let array_2d = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap(); 29 | let array_input = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap(); 30 | let array_output = array_input.dot(&array_2d.t()); 31 | 32 | info!("array_output: {:?}", array_output); 33 | 34 | let output = linear_from.forward(&input); 35 | info!("lnear_weights {:?}", linear_from.weights.item()); 36 | info!("Output: {:?}", output.item()); 37 | info!("non linear layer test: {:?}", (input << tensor_from).item()); 38 | return; 39 | 40 | let linear_layer_config = LinearLayerConfig::new(3, 1); 41 | // A layer is single module that can be used in a model 42 | let layer = LinearLayer::new(linear_layer_config); 43 | 44 | // A model is a sequence of layers 45 | // You can create a model by calling "into" on a vector of layers 46 | let mut linear_model: Sequential = vec![layer.into()].into(); 47 | 48 | let inputs = vec![ 49 | vec![2.0f32, 3.0, -1.0], 50 | vec![3.0, -1.0, 0.5], 51 | vec![0.5, 1.0, 1.0], 52 | vec![1.0, 1.0, -1.0], 53 | ]; 54 | 55 | let inputs_as_tensor = Tensor::from_vec( 56 | inputs.iter().flatten().map(|x| *x).collect(), 57 | vec![4, 3].into(), 58 | ); 59 | 60 | let outputs = vec![1.0f32, -1.0, -1.0, 1.0]; 61 | 62 | let outputs_as_tensor = 63 | Tensor::from_vec(outputs.iter().map(|x| *x).collect(), vec![4, 1].into()); 64 | 65 | // This is your training loop 66 | for _ in 0..50 { 67 | zero_all_grads(); 68 | let prediction = linear_model.forward(&inputs_as_tensor); 69 | let loss = (prediction - outputs_as_tensor).pow(2.0); 70 | info!("Loss: {:?}", loss.item()); 71 | loss.backward(); 72 | update_parameters(-0.01); 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /src/optimizers/adamw.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::hash::Hash; 3 | 4 | use ndarray::ArrayD; 5 | 6 | use crate::nn::Module; 7 | use crate::optimizers::optimizer::Optimizer; 8 | use crate::{get_equation, TensorID}; 9 | 10 | pub struct AdamWBuilder { 11 | pub learning_rate: f32, 12 | pub beta1: f32, 13 | pub beta2: f32, 14 | pub epsilon: f32, 15 | } 16 | 17 | impl Default for AdamWBuilder { 18 | fn default() -> Self { 19 | Self { 20 | learning_rate: 0.01, 21 | beta1: 0.9, 22 | beta2: 0.999, 23 | epsilon: 1e-08, 24 | } 25 | } 26 | } 27 | 28 | pub struct AdamW { 29 | pub learning_rate: f32, 30 | pub beta1: f32, 31 | pub beta2: f32, 32 | pub epsilon: f32, 33 | pub m: HashMap>, 34 | pub v: HashMap>, 35 | pub t: usize, 36 | } 37 | 
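// In the impls that follow, `m` and `v` hold the running first-moment (mean) and
// second-moment (uncentered variance) estimates for each parameter's gradient, keyed by
// TensorID, and `t` counts optimizer steps so the bias corrections 1 - beta1^t and
// 1 - beta2^t can be applied in `step()`. Note that, as written, `step()` performs the
// same update as `Adam`; it does not yet add the decoupled weight-decay term that
// normally distinguishes AdamW.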
38 | impl AdamW { 39 | pub fn new(adam_builder: AdamWBuilder) -> Self { 40 | Self { 41 | learning_rate: adam_builder.learning_rate, 42 | beta1: adam_builder.beta1, 43 | beta2: adam_builder.beta2, 44 | epsilon: adam_builder.epsilon, 45 | m: HashMap::new(), 46 | v: HashMap::new(), 47 | t: 0, 48 | } 49 | } 50 | } 51 | 52 | impl Optimizer for AdamW { 53 | fn record_parameters(&mut self, model: &dyn Module) { 54 | self.m.clear(); 55 | self.v.clear(); 56 | for param in model.get_parameters() { 57 | let test = param.item().len(); 58 | let m1s: Vec = vec![0.0; test]; 59 | let v1s: Vec = vec![0.0; test]; 60 | self.m.insert(param.tensor_id, m1s); 61 | self.v.insert(param.tensor_id, v1s); 62 | } 63 | } 64 | 65 | fn step(&mut self, model: &mut dyn Module) { 66 | self.t += 1; 67 | for tensor in model.get_parameters() { 68 | let grad = tensor.grad(); 69 | let m = self.m.get_mut(&tensor.tensor_id).unwrap(); 70 | let v = self.v.get_mut(&tensor.tensor_id).unwrap(); 71 | 72 | for (i, g) in grad.iter().enumerate() { 73 | m[i] = self.beta1 * m[i] + (1.0 - self.beta1) * g; 74 | v[i] = self.beta2 * v[i] + (1.0 - self.beta2) * g.powi(2); 75 | } 76 | 77 | let m_hat = m.iter().map(|x| x / (1.0 - self.beta1.powi(self.t as i32))).collect::>(); 78 | let v_hat = v.iter().map(|x| x / (1.0 - self.beta2.powi(self.t as i32))).collect::>(); 79 | let mut new_weights = tensor.item().into_raw_vec(); 80 | for (i, w) in new_weights.iter_mut().enumerate() { 81 | *w -= self.learning_rate * m_hat[i] / (v_hat[i].sqrt() + self.epsilon); 82 | } 83 | 84 | let new_weights = ArrayD::from_shape_vec(tensor.item().shape(), new_weights).unwrap(); 85 | get_equation().set_tensor_data(tensor.tensor_id, new_weights); 86 | } 87 | } 88 | } -------------------------------------------------------------------------------- /src/central/operation.rs: -------------------------------------------------------------------------------- 1 | use crate::{Indexable, Shape}; 2 | 3 | use super::tensor::TensorID; 4 | use std::fmt; 5 | 6 | #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] 7 | pub enum Operation { 8 | /// No operation, this will not pass any gradient 9 | Nop, 10 | Add(TensorID, TensorID), 11 | Mul(TensorID, TensorID), 12 | Exp(TensorID), 13 | Pow(TensorID, TensorID), 14 | MatMul(TensorID, TensorID), 15 | Sum(TensorID, usize), 16 | Broadcast(TensorID, Shape), 17 | Log(TensorID), 18 | View(TensorID, Indexable), 19 | Mean(TensorID, usize), 20 | Concat(TensorID, TensorID), 21 | Reshape(TensorID, Shape), 22 | Tanh(TensorID), 23 | Transpose(TensorID, usize, usize), 24 | Sin(TensorID), 25 | Cos(TensorID), 26 | MaskedFill(TensorID, TensorID, isize), 27 | Embedding(TensorID, TensorID), 28 | } 29 | 30 | impl Operation { 31 | pub fn get_tensor_id(&self) -> Option { 32 | match self { 33 | Operation::Nop => None, 34 | Operation::Add(a, _) => Some(*a), 35 | Operation::Mul(a, _) => Some(*a), 36 | Operation::Exp(a) => Some(*a), 37 | Operation::Pow(a, _) => Some(*a), 38 | Operation::MatMul(a, _) => Some(*a), 39 | Operation::Sum(a, _) => Some(*a), 40 | Operation::Broadcast(a, _) => Some(*a), 41 | Operation::Log(a) => Some(*a), 42 | Operation::View(a, _index) => Some(*a), 43 | Operation::Mean(a, _) => Some(*a), 44 | Operation::Concat(a, _) => Some(*a), 45 | Operation::Reshape(a, _) => Some(*a), 46 | Operation::Tanh(a) => Some(*a), 47 | Operation::Transpose(a, _, _) => Some(*a), 48 | Operation::Sin(a) => Some(*a), 49 | Operation::Cos(a) => Some(*a), 50 | Operation::MaskedFill(a, _, _) => Some(*a), 51 | Operation::Embedding(a, _) => Some(*a), 52 | } 53 | } 54 | } 55 | 
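// `get_tensor_id` above only reports the first operand of an operation; the Display impl
// below prints a short human-readable tag per operation (operand TensorIDs are omitted),
// which is useful when printing or logging the computation graph.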
56 | impl fmt::Display for Operation { 57 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 58 | match self { 59 | Operation::Nop => write!(f, "Nop"), 60 | Operation::Add(_a, _b) => write!(f, "Add()"), 61 | Operation::Mul(_a, _b) => write!(f, "Mul()"), 62 | Operation::Exp(_a) => write!(f, "Exp()"), 63 | Operation::Pow(_a, _b) => write!(f, "Pow()"), 64 | Operation::MatMul(_a, _b) => write!(f, "MatMul()"), 65 | Operation::Sum(_a, _) => write!(f, "Sum()"), 66 | Operation::Broadcast(_a, _shape) => write!(f, "Broadcast()"), 67 | Operation::Log(_a) => write!(f, "Log()"), 68 | Operation::View(_a, _index) => write!(f, "View()"), 69 | Operation::Mean(_a, _) => write!(f, "Mean()"), 70 | Operation::Concat(_a, _b) => write!(f, "Concat()"), 71 | Operation::Reshape(_a, _shape) => write!(f, "Reshape()"), 72 | Operation::Tanh(_a) => write!(f, "Tanh()"), 73 | Operation::Transpose(_a, _, _) => write!(f, "Transpose()"), 74 | Operation::Sin(_a) => write!(f, "Sin()"), 75 | Operation::Cos(_a) => write!(f, "Cos()"), 76 | Operation::MaskedFill(_a, _b, _c) => write!(f, "MaskedFill()"), 77 | Operation::Embedding(_a, _b) => write!(f, "Embedding()"), 78 | } 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /src/optimizers/adam.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::hash::Hash; 3 | 4 | use ndarray::ArrayD; 5 | 6 | use crate::nn::Module; 7 | use crate::optimizers::optimizer::Optimizer; 8 | use crate::{get_equation, TensorID}; 9 | 10 | pub struct AdamBuilder { 11 | pub learning_rate: f32, 12 | pub beta1: f32, 13 | pub beta2: f32, 14 | pub epsilon: f32, 15 | } 16 | 17 | impl Default for AdamBuilder { 18 | fn default() -> Self { 19 | Self { 20 | learning_rate: 0.01, 21 | beta1: 0.9, 22 | beta2: 0.999, 23 | epsilon: 1e-08, 24 | } 25 | } 26 | } 27 | 28 | pub struct Adam { 29 | pub learning_rate: f32, 30 | pub beta1: f32, 31 | pub beta2: f32, 32 | pub epsilon: f32, 33 | pub m: HashMap>, 34 | pub v: HashMap>, 35 | pub t: usize, 36 | } 37 | 38 | 39 | 40 | impl Adam { 41 | pub fn new(adam_builder: AdamBuilder) -> Self { 42 | Self { 43 | learning_rate: adam_builder.learning_rate, 44 | beta1: adam_builder.beta1, 45 | beta2: adam_builder.beta2, 46 | epsilon: adam_builder.epsilon, 47 | m: HashMap::new(), 48 | v: HashMap::new(), 49 | t: 0, 50 | } 51 | } 52 | } 53 | 54 | impl Optimizer for Adam { 55 | fn record_parameters(&mut self, model: &dyn Module) { 56 | self.m.clear(); 57 | self.v.clear(); 58 | for param in model.get_parameters() { 59 | let test = param.item().len(); 60 | let m1s: Vec = vec![0.0; test]; 61 | let v1s: Vec = vec![0.0; test]; 62 | self.m.insert(param.tensor_id, m1s); // Assuming Tensor has an id() method 63 | self.v.insert(param.tensor_id, v1s); // Assuming Tensor has an id() method 64 | } 65 | } 66 | 67 | fn step(&mut self, model: &mut dyn Module) { 68 | self.t += 1; 69 | for tensor in model.get_parameters() { 70 | 71 | let grad = tensor.grad(); // Assuming Tensor has a grad() method 72 | let m = self.m.get_mut(&tensor.tensor_id).unwrap(); 73 | let v = self.v.get_mut(&tensor.tensor_id).unwrap(); 74 | 75 | // Update biased first moment estimate 76 | for (m_i, g_i) in m.iter_mut().zip(grad.iter()) { 77 | *m_i = self.beta1 * *m_i + (1.0 - self.beta1) * g_i; 78 | } 79 | 80 | // Update biased second raw moment estimate 81 | for (v_i, g_i) in v.iter_mut().zip(grad.iter()) { 82 | *v_i = self.beta2 * *v_i + (1.0 - self.beta2) * g_i.powi(2); 83 | } 84 | 85 | let m_hat: Vec 
= m.iter().map(|x| x / (1.0 - self.beta1.powi(self.t as i32))).collect(); 86 | let v_hat: Vec = v.iter().map(|x| x / (1.0 - self.beta2.powi(self.t as i32))).collect(); 87 | 88 | let mut data = tensor.item().into_raw_vec(); 89 | 90 | for i in 0..data.len() { 91 | data[i] -= self.learning_rate * m_hat[i] / (v_hat[i].sqrt() + self.epsilon); 92 | } 93 | let data_as_array = ArrayD::from_shape_vec(tensor.item().shape(), data).unwrap(); 94 | get_equation().set_tensor_data(tensor.tensor_id, data_as_array); // Assuming Tensor has a tensor_id() method 95 | } 96 | 97 | } 98 | } -------------------------------------------------------------------------------- /embedding.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": [ 3 | -0.817333996295929, 4 | -0.555568516254425, 5 | -0.8266895413398743, 6 | -1.2969543933868408, 7 | -0.19737878441810608, 8 | -0.9643335342407227, 9 | -0.5132989287376404, 10 | 2.6277847290039062, 11 | -0.74649578332901, 12 | 1.0050938129425049, 13 | -0.25683921575546265, 14 | 0.47649285197257996, 15 | -0.6652101278305054, 16 | -0.3626655638217926, 17 | -1.4503511190414429, 18 | -0.24958686530590057, 19 | 0.8297731876373291, 20 | 1.120947241783142, 21 | 0.9999132752418518, 22 | 1.116694450378418, 23 | 1.0762836933135986, 24 | -0.06622858345508575, 25 | 0.13152696192264557, 26 | 0.16809339821338654, 27 | 0.056165486574172974, 28 | 0.24557670950889587, 29 | 0.9534983038902283, 30 | 0.35533979535102844, 31 | 0.2120935618877411, 32 | -0.3377707004547119, 33 | -0.3536095917224884, 34 | -0.2729077935218811, 35 | 1.3290292024612427, 36 | 1.1973766088485718, 37 | -1.2645164728164673, 38 | -0.43443864583969116, 39 | -2.1805989742279053, 40 | -1.1093748807907104, 41 | -2.0410380363464355, 42 | 0.03336522728204727, 43 | -1.080498218536377, 44 | -0.8832487463951111, 45 | 0.9741044640541077, 46 | 0.5631665587425232, 47 | -1.1151163578033447, 48 | 0.1490076184272766, 49 | -1.0923006534576416, 50 | -1.4551283121109009, 51 | -0.4810960590839386, 52 | -1.0822992324829102, 53 | 0.3177033066749573, 54 | 2.747004747390747, 55 | 0.36901962757110596, 56 | 1.3373314142227173, 57 | -0.9179764986038208, 58 | -0.9615129232406616, 59 | -0.0050771813839674, 60 | 0.35382238030433655, 61 | -0.35596030950546265, 62 | -0.766556978225708, 63 | 0.43009135127067566, 64 | 0.14965298771858215, 65 | -0.24599970877170563, 66 | -1.463648796081543, 67 | 0.5876042246818542, 68 | -1.1602551937103271, 69 | 1.0045193433761597, 70 | -0.5749108791351318, 71 | 0.585305392742157, 72 | -1.3658185005187988, 73 | -1.3333778381347656, 74 | 0.14503853023052216, 75 | 1.2486594915390015, 76 | -2.277423143386841, 77 | 1.2525312900543213, 78 | 0.5466712117195129, 79 | 0.020079230889678, 80 | -1.5847541093826294, 81 | 2.443178176879883, 82 | 0.767185628414154, 83 | 0.6273192763328552, 84 | -0.29026904702186584, 85 | 0.6038818955421448, 86 | 1.2172517776489258, 87 | -0.5778941512107849, 88 | -1.6192350387573242, 89 | -0.2753530740737915, 90 | -0.3702923655509949, 91 | 0.3324183523654938, 92 | 0.5222781896591187, 93 | 0.6027423143386841, 94 | -0.54026198387146, 95 | 1.0014485120773315, 96 | 0.7306378483772278, 97 | -2.264442205429077, 98 | -1.3223894834518433, 99 | -0.5860189199447632, 100 | 1.610308051109314, 101 | -1.4181116819381714, 102 | 1.4855009317398071 103 | ], 104 | "shape": [ 105 | 10, 106 | 10 107 | ] 108 | } -------------------------------------------------------------------------------- /data/bigram/tensor_b1.json: 
-------------------------------------------------------------------------------- 1 | {"shape": [1, 200], "data": [-2.0704636573791504, -0.9859631061553955, 0.03212256729602814, 1.9678773880004883, -0.14864158630371094, 0.9788134694099426, -0.3514842689037323, -1.0125799179077148, -0.8775776028633118, 0.8043980598449707, 1.2247507572174072, -0.7079716324806213, 0.14367274940013885, 2.2885613441467285, 1.1220331192016602, -0.18222437798976898, -0.15387427806854248, -1.1546859741210938, -1.4097756147384644, 1.189675211906433, -0.4227447211742401, 1.2084898948669434, 0.6531090140342712, -0.3381667733192444, 1.2731729745864868, 1.4809616804122925, 1.1152701377868652, -0.11543188989162445, 0.8132458329200745, -0.7410956025123596, 0.2077830731868744, 0.23032301664352417, 0.42633846402168274, -1.5452055931091309, -0.9454129338264465, -0.9093910455703735, -0.36922216415405273, 0.14301390945911407, 1.0620536804199219, 1.5949292182922363, -0.0468718521296978, -0.4184325337409973, 1.3012486696243286, -1.8453333377838135, -0.35152557492256165, -0.43061456084251404, 0.03578567132353783, -1.0245014429092407, 0.8731839656829834, -1.2807856798171997, -0.3654491901397705, -0.9221362471580505, -0.9590191841125488, 0.41448816657066345, -0.9840542078018188, -0.24361710250377655, 1.2110004425048828, 1.2867810726165771, 0.7162692546844482, 1.6830796003341675, 0.042002856731414795, -0.9180684685707092, -0.5576449632644653, -0.6978939175605774, 0.2516961693763733, -0.5281932353973389, 0.6851141452789307, -0.27623251080513, 0.6756979823112488, -0.10577193647623062, 1.0515114068984985, -1.3378746509552002, 0.11863823235034943, -0.5300424695014954, 0.07706673443317413, -1.0999878644943237, 0.17150814831256866, -0.4509272575378418, 0.02892194129526615, -0.19723659753799438, 1.2722715139389038, 0.08613698929548264, -1.7036510705947876, -0.03748498857021332, 0.2756001353263855, -0.2625517249107361, 0.39931365847587585, -0.19695349037647247, -1.2899240255355835, 0.6199979782104492, 2.1286604404449463, 0.9751921892166138, 0.28824594616889954, -0.25219184160232544, 0.96076500415802, -0.6834419369697571, -0.7115420699119568, -0.8633403182029724, 0.23946379125118256, -2.839179039001465, -0.21991638839244843, 0.9078301191329956, -0.0686650425195694, 1.9234635829925537, 1.0784916877746582, 2.6284561157226562, -1.8466873168945312, -0.9865627884864807, 1.9641478061676025, -1.9937317371368408, 0.9653881192207336, 0.4336015284061432, 0.9796942472457886, -1.1946985721588135, 0.9899058938026428, -1.659302830696106, 0.09641622006893158, 0.5479996800422668, -1.677019476890564, -1.092279314994812, 0.618103563785553, -0.05023413896560669, -0.6701115965843201, 1.6990615129470825, -0.17047294974327087, -0.841547966003418, -0.24554899334907532, -0.8508386611938477, -0.4476163983345032, 0.21367408335208893, -0.31191039085388184, -0.04090249538421631, -0.6381316184997559, -0.17707806825637817, 1.8937699794769287, 0.4485393464565277, 0.3465205729007721, -0.47058725357055664, 0.5921720266342163, -0.6766966581344604, -1.7564834356307983, -1.1173162460327148, -1.1631474494934082, -2.100975751876831, 1.0832459926605225, -0.8473730087280273, -0.5742049217224121, 2.2467257976531982, -0.24249079823493958, 0.9446818828582764, 0.17126603424549103, -0.1568722426891327, -1.912213683128357, 0.5803057551383972, 0.7614110112190247, -0.8251919150352478, 1.2223150730133057, -0.0866578221321106, 0.18618211150169373, -0.7550662755966187, 2.0444159507751465, 0.23962320387363434, -1.4920696020126343, 0.531319260597229, 0.8681632876396179, 0.7268648147583008, 
1.2205647230148315, 0.7936658263206482, 0.35465940833091736, -0.7616128921508789, 0.49878817796707153, 0.3572254478931427, -0.11707241833209991, -0.9430463910102844, 0.6465049982070923, -0.6936981678009033, 3.3714704513549805, -0.8634629845619202, 0.6228163242340088, -1.0213602781295776, 1.3811545372009277, -0.728112518787384, -1.0282363891601562, 0.7540144920349121, 0.5894933342933655, 0.9577163457870483, 0.7062502503395081, -0.9797568917274475, -0.5547504425048828, -1.074957251548767, -1.3799283504486084, 0.6060267090797424, 1.6270167827606201, 0.7891387939453125, 0.28836753964424133, -0.8174278736114502, 0.02726033329963684, 0.30390098690986633, -1.2483158111572266, 0.5486917495727539]} -------------------------------------------------------------------------------- /src/central/indexable.rs: -------------------------------------------------------------------------------- 1 | use crate::Shape; 2 | 3 | use super::tensor::TensorID; 4 | /// Represents an indexable value. 5 | /// This is used to index into a tensor. 6 | /// It can be a single index, a double index, or an index from a tensor. 7 | #[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq, Hash)] 8 | pub enum Indexable { 9 | Single(usize), 10 | Double(usize, usize), 11 | Triple(usize, usize, usize), 12 | FromTensor(TensorID), 13 | } 14 | 15 | impl Indexable { 16 | /// Returns the number of indices in the indexable. 17 | pub fn len(&self) -> usize { 18 | match self { 19 | Indexable::Single(_) => 1, 20 | Indexable::Double(_, _) => 2, 21 | Indexable::Triple(_, _, _) => 3, 22 | Indexable::FromTensor(_) => 1, 23 | } 24 | } 25 | 26 | /// Returns the index at the given position. 27 | /// # Arguments 28 | /// * `shape` - The shape to get the index from. 29 | /// # Returns 30 | /// The index at the given position. 31 | /// # Panics 32 | /// This function will panic if the index is out of range. 33 | pub fn get_index(&self, shape: Shape) -> usize { 34 | match self { 35 | Indexable::Single(i) => { 36 | assert!(*i < shape.number_of_indices); 37 | shape.indices[*i] 38 | } 39 | Indexable::Double(_a, _b) => { 40 | panic!("Double index not implemented"); 41 | }, 42 | Indexable::Triple(_a, _b, _c) => { 43 | panic!("Triple index not implemented"); 44 | }, 45 | Indexable::FromTensor(_a) => { 46 | panic!("Mixed index not implemented"); 47 | } 48 | } 49 | } 50 | } 51 | 52 | /// Convert an indexable into a shape. 53 | /// This will return a shape with the indexable as the indices. 54 | /// # Arguments 55 | /// * `indexable` - The indexable to convert. 56 | /// # Returns 57 | /// A shape with the indexable as the indices. 58 | /// # Panics 59 | /// This function will panic if the indexable is a double index or a tensor index. 60 | impl From for Shape { 61 | fn from(indexable: Indexable) -> Shape { 62 | match indexable { 63 | Indexable::Single(a) => vec![a].into(), 64 | Indexable::Double(a, b) => vec![a, b].into(), 65 | Indexable::Triple(a, b, c) => vec![a, b, c].into(), 66 | Indexable::FromTensor(_) => panic!("Mixed index not implemented"), 67 | } 68 | } 69 | } 70 | 71 | /// Convert an indexable into a vector of indices. 72 | /// This will return a vector of indices with the indexable as the indices. 73 | /// # Arguments 74 | /// * `indexable` - The indexable to convert. 75 | /// # Returns 76 | /// A vector of indices with the indexable as the indices. 
77 | impl From for Vec { 78 | fn from(indexable: Indexable) -> Vec { 79 | match indexable { 80 | Indexable::Single(index) => vec![index], 81 | Indexable::Double(a, b) => vec![a, b], 82 | Indexable::Triple(a, b, c) => vec![a, b, c], 83 | Indexable::FromTensor(_) => vec![1], 84 | } 85 | } 86 | } 87 | 88 | /// Convert a usize into an indexable. 89 | /// This will return a single index indexable. 90 | /// # Arguments 91 | /// * `index` - The index to convert. 92 | /// # Returns 93 | /// A single index indexable. 94 | impl From for Indexable { 95 | fn from(index: usize) -> Indexable { 96 | Indexable::Single(index) 97 | } 98 | } 99 | 100 | /// Convert a tuple of two usizes into an indexable. 101 | /// This will return a double index indexable. 102 | /// # Arguments 103 | /// * `index` - The index to convert. 104 | /// # Returns 105 | /// A double index indexable. 106 | impl From<[usize; 2]> for Indexable { 107 | fn from([a, b]: [usize; 2]) -> Indexable { 108 | Indexable::Double(a, b) 109 | } 110 | } 111 | 112 | impl From<[usize;3]> for Indexable { 113 | fn from([a, b, c]: [usize; 3]) -> Indexable { 114 | Indexable::Triple(a, b, c) 115 | } 116 | } -------------------------------------------------------------------------------- /data/batchnorm/C.json: -------------------------------------------------------------------------------- 1 | {"data": [1.5673619508743286, -0.2372923195362091, -0.027384605258703232, -1.100779414176941, 0.2858814597129822, -0.029643338173627853, -1.547059178352356, 0.6048919558525085, 0.0791362002491951, 0.9046238660812378, -0.4712531864643097, 0.7868220210075378, -0.32843494415283203, -0.43297016620635986, 1.372930884361267, 2.9333672523498535, 1.561784267425537, -1.6260507106781006, 0.677162230014801, -0.8403948545455933, 0.9848796725273132, -0.14837250113487244, -1.4795262813568115, 0.44830015301704407, -0.07072983682155609, 2.4967896938323975, 2.444791793823242, -0.6700625419616699, -1.2198841571807861, 0.30314409732818604, -1.0725021362304688, 0.7276178002357483, 0.05111394077539444, 1.309485912322998, -0.8021991848945618, -0.8504244685173035, -1.8067610263824463, 1.2523078918457031, -1.2255810499191284, 1.2164969444274902, -0.9647807478904724, -0.23210863769054413, -0.3476170599460602, 0.3324427902698517, -1.3262643814086914, 1.1224371194839478, 0.5964093804359436, 0.45845651626586914, 0.054010938853025436, -1.7400306463241577, 0.11559642851352692, 0.8031929135322571, 0.5410798192024231, -1.1646243333816528, 0.14756156504154205, -1.0006153583526611, 0.38011646270751953, 0.4732840359210968, -0.9102707505226135, -0.7830496430397034, 0.13505524396896362, -0.21160592138767242, -1.0405985116958618, -1.5367215871810913, 0.9374262690544128, -0.8830345273017883, 1.745694637298584, 2.1346139907836914, -0.8561445474624634, 0.5408160090446472, 0.6168996691703796, 1.5159580707550049, -1.0447320938110352, -0.6641396284103394, -0.7239009141921997, 1.7507048845291138, 0.17530231177806854, 0.9927966594696045, -0.6278737187385559, 0.07702291756868362, -1.1640658378601074, 1.2472796440124512, -0.2706054151058197, -1.3635324239730835, 1.3065752983093262, 0.3230685293674469, 1.0357621908187866, -0.862492561340332, -1.2575125694274902, 0.9417989253997803, -1.3256750106811523, 0.14669643342494965, 0.1691337674856186, -1.5396591424942017, -0.7275874614715576, 1.1490541696548462, -0.8746175169944763, -0.29770681262016296, -1.3706518411636353, 0.11499851942062378, -1.0188171863555908, -0.8377675414085388, -2.1056742668151855, -0.2604387104511261, -1.714870810508728, -0.33787330985069275, 
-1.8262721300125122, -0.8389687538146973, -1.5722687244415283, 0.45795297622680664, -0.5653309226036072, 0.542813241481781, 0.17549338936805725, -2.2901439666748047, -0.7092764973640442, -0.29283446073532104, -2.180312156677246, 0.07931055873632431, 0.9018704891204834, 1.2027945518493652, -0.5614364743232727, -0.1375267058610916, -0.13798850774765015, -2.0976502895355225, -0.7923803925514221, 0.6068918108940125, -1.477674126625061, -0.5102899670600891, 0.564214825630188, 0.9683837890625, -0.3111408054828644, -0.30602651834487915, -1.7495094537734985, -1.633530616760254, 0.38760754466056824, 0.4723603427410126, 1.4829586744308472, 0.3174806535243988, 1.058820128440857, 2.3981854915618896, 0.46826571226119995, -0.656498908996582, 0.6166205406188965, -0.6219748854637146, 0.5100740194320679, 1.3562920093536377, 0.23445184528827667, -0.45584994554519653, -0.0013132077874615788, -0.5116108655929565, 0.5556988716125488, 0.47458043694496155, -1.3866722583770752, 1.6229392290115356, 0.1719702035188675, 0.988463282585144, 0.5065655708312988, 1.0198405981063843, -1.9062488079071045, -0.42753297090530396, -2.1258766651153564, 0.9604060649871826, 1.2481971979141235, 0.2534126937389374, 2.81878399848938, -0.33982959389686584, 0.7031070590019226, 0.40716081857681274, -0.19018201529979706, -0.6965201497077942, 1.703926920890808, 0.742035448551178, 0.9737022519111633, 0.30027997493743896, -0.28971123695373535, -0.31565725803375244, -0.878982424736023, 0.10660803318023682, 1.8598307371139526, 0.05575166642665863, 1.2814979553222656, -0.6318161487579346, -1.246392011642456, 0.6830465197563171, -0.3945547640323639, 0.014387708157300949, 0.5721626281738281, 0.8672562837600708, 0.631492018699646, -1.2230440378189087, -0.21286392211914062, 0.5095030069351196, 0.3271303176879883, 1.9660793542861938, -0.24091431498527527, -0.7951505184173584, 0.2719773054122925, -1.1100248098373413, -0.45284605026245117, -0.4957810044288635, 1.2647724151611328, 1.4624544382095337, 1.119904637336731, 0.9953904151916504, -1.2353113889694214, 0.7381800413131714, 0.8141525983810425, -0.738063633441925, 0.567144513130188, -1.46007239818573, -0.2478037029504776, 0.8828239440917969, -0.08100391179323196, -0.9529945850372314, -0.4883781969547272, -0.7371233701705933, 0.7060908675193787, -0.1929508000612259, 1.2347642183303833, 0.3330754041671753, 1.328259825706482, -1.0921270847320557, -0.8395169973373413, 0.19097915291786194, -0.7174965143203735, -0.38667669892311096, -1.2541816234588623, 1.206780195236206, -1.7102421522140503, -0.47700679302215576, -1.052660346031189, -0.14367036521434784, -0.2773723900318146, 1.1634039878845215, -0.6690982580184937, 0.6491793990135193, 0.5824344158172607, 1.9263932704925537, -0.3784555196762085, 0.007957705296576023, 0.5106770992279053, 0.7592651844024658, -1.6086313724517822, -0.16065196692943573, 1.3783751726150513, -0.27803653478622437, 0.20709708333015442, 1.0032601356506348, -0.597723662853241, -0.39770567417144775, -1.2801213264465332, 0.09244456887245178, 0.10526478290557861, -0.3907223045825958, 0.031722817569971085, -0.5475257039070129, 0.8182698488235474, -0.8162779211997986, -0.39242666959762573, -0.7452054619789124, -0.9464942812919617, -0.1594141125679016, -0.19336439669132233, -0.3765956163406372, -0.04915774613618851, 0.09386590868234634, -0.6453335881233215, 1.2107977867126465, -0.7819821834564209, 0.38448527455329895], "shape": [27, 10]} -------------------------------------------------------------------------------- /src/nn/layers/linear.rs: 
-------------------------------------------------------------------------------- 1 | use crate::{central::Tensor, Shape, nn::layers::module::Module}; 2 | 3 | #[derive(Debug, Default, Copy, Clone)] 4 | pub struct LinearLayerConfig { 5 | pub number_of_inputs: usize, 6 | pub number_of_weights: usize, 7 | } 8 | 9 | impl LinearLayerConfig { 10 | pub fn new(number_of_inputs: usize, number_of_weights: usize) -> LinearLayerConfig { 11 | LinearLayerConfig { number_of_inputs, number_of_weights } 12 | } 13 | } 14 | pub struct LinearLayer { 15 | pub weights: Tensor, 16 | pub bias: Tensor, 17 | } 18 | 19 | impl LinearLayer { 20 | 21 | pub fn new(config: LinearLayerConfig) -> LinearLayer { 22 | let weights = Tensor::randn(Shape::new(vec![config.number_of_inputs, config.number_of_weights])); 23 | let bias = Tensor::ones(Shape::new(vec![config.number_of_weights])); 24 | LinearLayer { weights, bias } 25 | } 26 | 27 | pub fn from_weights_and_bias(weights: Tensor, bias: Tensor) -> LinearLayer { 28 | LinearLayer { weights, bias } 29 | } 30 | } 31 | 32 | impl Module for LinearLayer { 33 | 34 | fn forward(&mut self, x: &Tensor) -> Tensor { 35 | let weight_transpose = self.weights.transpose(); 36 | if self.bias.shape.number_of_indices == 1 && self.bias.shape.indices[0] == 0 { 37 | return *x << weight_transpose; 38 | } 39 | (*x << weight_transpose) + self.bias 40 | } 41 | 42 | fn get_parameters(&self) -> Vec { 43 | vec![self.weights.clone(), self.bias.clone()] 44 | } 45 | } 46 | 47 | impl From for Box { 48 | fn from(layer: LinearLayer) -> Box { 49 | Box::new(layer) 50 | } 51 | } 52 | 53 | 54 | mod tests { 55 | use crate::nn::layers::linear::{LinearLayer, LinearLayerConfig}; 56 | use crate::nn::layers::module::Module; 57 | use crate::central::Tensor; 58 | use crate::Shape; 59 | 60 | #[test] 61 | pub fn test_linear_layer() { 62 | let config = LinearLayerConfig::new(3, 2); 63 | let mut linear_layer = LinearLayer::new(config); 64 | } 65 | 66 | #[test] 67 | pub fn from_python_weights_and_bias() { 68 | 69 | use std::fs::File; 70 | let weights_path = "data/tests/linear/linear_weights.txt"; 71 | let bias_path = "data/tests/linear/linear_bias.txt"; 72 | let input_file = "data/tests/linear/linear_input.txt"; 73 | let output_file = "data/tests/linear/linear_output.txt"; 74 | let fake_target = "data/tests/linear/linear_fake_target.txt"; 75 | let weight_grad_path = "data/tests/linear/linear_weight_grad.txt"; 76 | let bias_grad_path = "data/tests/linear/linear_bias_grad.txt"; 77 | let loss_path = "data/tests/linear/linear_loss.txt"; 78 | 79 | let mut weights_file = File::open(weights_path).unwrap(); 80 | let mut bias_file = File::open(bias_path).unwrap(); 81 | let mut input_file = File::open(input_file).unwrap(); 82 | let mut output_file = File::open(output_file).unwrap(); 83 | let mut fake_target_file = File::open(fake_target).unwrap(); 84 | let mut loss_file = File::open(loss_path).unwrap(); 85 | let mut weight_grad_file = File::open(weight_grad_path).unwrap(); 86 | let mut bias_grad_file = File::open(bias_grad_path).unwrap(); 87 | 88 | let weights = Tensor::from_bytestream(&mut weights_file, false).unwrap(); 89 | let bias = Tensor::from_bytestream(&mut bias_file, false).unwrap(); 90 | let input = Tensor::from_bytestream(&mut input_file, false).unwrap(); 91 | let expected_output = Tensor::from_bytestream(&mut output_file, false).unwrap(); 92 | let fake_target = Tensor::from_bytestream(&mut fake_target_file, false).unwrap(); 93 | let expected_loss = Tensor::from_bytestream(&mut loss_file, false).unwrap(); 94 | let 
expected_weight_grad = Tensor::from_bytestream(&mut weight_grad_file, false).unwrap(); 95 | let expected_bias_grad = Tensor::from_bytestream(&mut bias_grad_file, false).unwrap(); 96 | 97 | 98 | let mut linear_layer = LinearLayer::from_weights_and_bias(weights, bias); 99 | let output = linear_layer.forward(&input); 100 | 101 | for i in 0..output.shape.size() { 102 | assert!((output.item()[i] - expected_output.item()[i]).abs() < 1e-6); 103 | } 104 | 105 | let mse_loss = (output - fake_target).pow(2.0).mean(vec![0]); 106 | 107 | for i in 0..mse_loss.shape.size() { 108 | assert!((mse_loss.item()[i] - expected_loss.item()[i]).abs() < 1e-6); 109 | } 110 | 111 | // weight grad check 112 | mse_loss.backward(); 113 | 114 | let weight_bias = linear_layer.bias.grad(); 115 | 116 | for i in 0..expected_bias_grad.shape.size() { 117 | let left = weight_bias[i]; 118 | let right = expected_bias_grad.item()[[i]]; 119 | assert!((left - right).abs() < 1e-6); 120 | } 121 | 122 | 123 | let weight_grad = linear_layer.weights.grad(); 124 | for x in 0..5 { 125 | for y in 0..5 { 126 | let left = weight_grad[[x, y]]; 127 | let right = expected_weight_grad.item()[[x, y]]; 128 | assert!((left - right).abs() < 1e-6); 129 | } 130 | } 131 | 132 | 133 | } 134 | 135 | } -------------------------------------------------------------------------------- /data/bigram/tensor_C.json: -------------------------------------------------------------------------------- 1 | {"shape": [27, 10], "data": [1.5673619508743286, -0.2372923195362091, -0.027384605258703232, -1.100779414176941, 0.2858814597129822, -0.029643338173627853, -1.547059178352356, 0.6048919558525085, 0.0791362002491951, 0.9046238660812378, -0.4712531864643097, 0.7868220210075378, -0.32843494415283203, -0.43297016620635986, 1.372930884361267, 2.9333672523498535, 1.561784267425537, -1.6260507106781006, 0.677162230014801, -0.8403948545455933, 0.9848796725273132, -0.14837250113487244, -1.4795262813568115, 0.44830015301704407, -0.07072983682155609, 2.4967896938323975, 2.444791793823242, -0.6700625419616699, -1.2198841571807861, 0.30314409732818604, -1.0725021362304688, 0.7276178002357483, 0.05111394077539444, 1.309485912322998, -0.8021991848945618, -0.8504244685173035, -1.8067610263824463, 1.2523078918457031, -1.2255810499191284, 1.2164969444274902, -0.9647807478904724, -0.23210863769054413, -0.3476170599460602, 0.3324427902698517, -1.3262643814086914, 1.1224371194839478, 0.5964093804359436, 0.45845651626586914, 0.054010938853025436, -1.7400306463241577, 0.11559642851352692, 0.8031929135322571, 0.5410798192024231, -1.1646243333816528, 0.14756156504154205, -1.0006153583526611, 0.38011646270751953, 0.4732840359210968, -0.9102707505226135, -0.7830496430397034, 0.13505524396896362, -0.21160592138767242, -1.0405985116958618, -1.5367215871810913, 0.9374262690544128, -0.8830345273017883, 1.745694637298584, 2.1346139907836914, -0.8561445474624634, 0.5408160090446472, 0.6168996691703796, 1.5159580707550049, -1.0447320938110352, -0.6641396284103394, -0.7239009141921997, 1.7507048845291138, 0.17530231177806854, 0.9927966594696045, -0.6278737187385559, 0.07702291756868362, -1.1640658378601074, 1.2472796440124512, -0.2706054151058197, -1.3635324239730835, 1.3065752983093262, 0.3230685293674469, 1.0357621908187866, -0.862492561340332, -1.2575125694274902, 0.9417989253997803, -1.3256750106811523, 0.14669643342494965, 0.1691337674856186, -1.5396591424942017, -0.7275874614715576, 1.1490541696548462, -0.8746175169944763, -0.29770681262016296, -1.3706518411636353, 0.11499851942062378, 
-1.0188171863555908, -0.8377675414085388, -2.1056742668151855, -0.2604387104511261, -1.714870810508728, -0.33787330985069275, -1.8262721300125122, -0.8389687538146973, -1.5722687244415283, 0.45795297622680664, -0.5653309226036072, 0.542813241481781, 0.17549338936805725, -2.2901439666748047, -0.7092764973640442, -0.29283446073532104, -2.180312156677246, 0.07931055873632431, 0.9018704891204834, 1.2027945518493652, -0.5614364743232727, -0.1375267058610916, -0.13798850774765015, -2.0976502895355225, -0.7923803925514221, 0.6068918108940125, -1.477674126625061, -0.5102899670600891, 0.564214825630188, 0.9683837890625, -0.3111408054828644, -0.30602651834487915, -1.7495094537734985, -1.633530616760254, 0.38760754466056824, 0.4723603427410126, 1.4829586744308472, 0.3174806535243988, 1.058820128440857, 2.3981854915618896, 0.46826571226119995, -0.656498908996582, 0.6166205406188965, -0.6219748854637146, 0.5100740194320679, 1.3562920093536377, 0.23445184528827667, -0.45584994554519653, -0.0013132077874615788, -0.5116108655929565, 0.5556988716125488, 0.47458043694496155, -1.3866722583770752, 1.6229392290115356, 0.1719702035188675, 0.988463282585144, 0.5065655708312988, 1.0198405981063843, -1.9062488079071045, -0.42753297090530396, -2.1258766651153564, 0.9604060649871826, 1.2481971979141235, 0.2534126937389374, 2.81878399848938, -0.33982959389686584, 0.7031070590019226, 0.40716081857681274, -0.19018201529979706, -0.6965201497077942, 1.703926920890808, 0.742035448551178, 0.9737022519111633, 0.30027997493743896, -0.28971123695373535, -0.31565725803375244, -0.878982424736023, 0.10660803318023682, 1.8598307371139526, 0.05575166642665863, 1.2814979553222656, -0.6318161487579346, -1.246392011642456, 0.6830465197563171, -0.3945547640323639, 0.014387708157300949, 0.5721626281738281, 0.8672562837600708, 0.631492018699646, -1.2230440378189087, -0.21286392211914062, 0.5095030069351196, 0.3271303176879883, 1.9660793542861938, -0.24091431498527527, -0.7951505184173584, 0.2719773054122925, -1.1100248098373413, -0.45284605026245117, -0.4957810044288635, 1.2647724151611328, 1.4624544382095337, 1.119904637336731, 0.9953904151916504, -1.2353113889694214, 0.7381800413131714, 0.8141525983810425, -0.738063633441925, 0.567144513130188, -1.46007239818573, -0.2478037029504776, 0.8828239440917969, -0.08100391179323196, -0.9529945850372314, -0.4883781969547272, -0.7371233701705933, 0.7060908675193787, -0.1929508000612259, 1.2347642183303833, 0.3330754041671753, 1.328259825706482, -1.0921270847320557, -0.8395169973373413, 0.19097915291786194, -0.7174965143203735, -0.38667669892311096, -1.2541816234588623, 1.206780195236206, -1.7102421522140503, -0.47700679302215576, -1.052660346031189, -0.14367036521434784, -0.2773723900318146, 1.1634039878845215, -0.6690982580184937, 0.6491793990135193, 0.5824344158172607, 1.9263932704925537, -0.3784555196762085, 0.007957705296576023, 0.5106770992279053, 0.7592651844024658, -1.6086313724517822, -0.16065196692943573, 1.3783751726150513, -0.27803653478622437, 0.20709708333015442, 1.0032601356506348, -0.597723662853241, -0.39770567417144775, -1.2801213264465332, 0.09244456887245178, 0.10526478290557861, -0.3907223045825958, 0.031722817569971085, -0.5475257039070129, 0.8182698488235474, -0.8162779211997986, -0.39242666959762573, -0.7452054619789124, -0.9464942812919617, -0.1594141125679016, -0.19336439669132233, -0.3765956163406372, -0.04915774613618851, 0.09386590868234634, -0.6453335881233215, 1.2107977867126465, -0.7819821834564209, 0.38448527455329895]} 
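The bigram and batchnorm weight files above all share one layout: a JSON object with a `shape` array and a flat `data` array that is read row-major against that shape (27 × 10 = 270 values here). The crate consumes them through `Tensor::load_from_weight_file`, as in examples/mlp.rs below. As a rough standalone sketch of the format only — assuming `serde`/`serde_json` are available as dependencies, and with `WeightFile` and `read_weight_file` being illustrative names rather than crate APIs — parsing reduces to:

// Minimal sketch of reading the {"shape", "data"} weight-file format shown above.
// Assumes serde + serde_json; the crate's own loader is Tensor::load_from_weight_file.
use serde::Deserialize;

#[derive(Deserialize)]
struct WeightFile {
    shape: Vec<usize>,
    data: Vec<f32>,
}

fn read_weight_file(path: &str) -> WeightFile {
    let text = std::fs::read_to_string(path).expect("weight file should be readable");
    let parsed: WeightFile =
        serde_json::from_str(&text).expect("weight file should be valid JSON");
    // The flat data is interpreted row-major, so its length must match the shape.
    assert_eq!(
        parsed.data.len(),
        parsed.shape.iter().product::<usize>(),
        "data length must equal the product of the shape"
    );
    parsed
}

fn main() {
    let c = read_weight_file("./data/bigram/tensor_C.json");
    println!("loaded tensor with shape {:?} ({} values)", c.shape, c.data.len());
}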
-------------------------------------------------------------------------------- /examples/mlp.rs: -------------------------------------------------------------------------------- 1 | use poro::central::{get_equation, Indexable, Shape, Tensor}; 2 | use std::collections::HashMap; 3 | use std::fs::read_to_string; 4 | use std::time::Instant; 5 | 6 | fn read_lines(filename: &str) -> Vec { 7 | let mut result = Vec::new(); 8 | 9 | for line in read_to_string(filename).unwrap().lines() { 10 | result.push(line.to_string()) 11 | } 12 | 13 | result 14 | } 15 | 16 | fn build_dataset_from_subset( 17 | words: &[String], 18 | stoi: &HashMap, 19 | ) -> (Vec<[usize; 3]>, Vec) { 20 | let mut xs = vec![]; 21 | let mut ys = vec![]; 22 | for word in words { 23 | let fixed = String::from("...") + word + "."; 24 | let chars: Vec = fixed.chars().collect(); 25 | for i in 0..chars.len() - 3 { 26 | let pair = (chars[i], chars[i + 1], chars[i + 2], chars[i + 3]); 27 | xs.push([stoi[&pair.0], stoi[&pair.1], stoi[&pair.2]]); 28 | ys.push(stoi[&pair.3]); 29 | } 30 | } 31 | (xs, ys) 32 | } 33 | 34 | fn add_time(timings: &mut HashMap, operation: &str, now: Instant) { 35 | if !timings.contains_key(&operation.to_string()) { 36 | timings.insert(operation.to_string(), 0); 37 | } 38 | let elapsed = now.elapsed().as_micros(); 39 | let current_time = timings.get(&operation.to_string()).unwrap(); 40 | timings.insert(operation.to_string(), current_time + elapsed); 41 | } 42 | 43 | fn main() { 44 | // let mut times = HashMap::new(); 45 | 46 | let names = read_lines("./data/bigram/names.txt"); 47 | 48 | let mut stoi = HashMap::new(); 49 | let mut itos = HashMap::new(); 50 | let mut i = 0; 51 | for c in ".abcdefghijklmnopqrstuvwxyz".chars() { 52 | stoi.insert(c, i); 53 | itos.insert(i, c); 54 | i += 1; 55 | } 56 | //1. Copy the weights and try it in an isolated test 57 | //2. Try offsetting before I push them into the gpu 58 | //3. 
upgrade my cuda to 12.5 59 | let n1 = (names.len() as f32 * 0.8f32) as usize; 60 | let n2 = (names.len() as f32 * 0.9f32) as usize; 61 | 62 | let (xtr, ytr) = build_dataset_from_subset(&names[..n1], &stoi); 63 | let (_xdev, _ydev) = build_dataset_from_subset(&names[n1..n2], &stoi); 64 | let (_cte, _yte) = build_dataset_from_subset(&names[n2..], &stoi); 65 | 66 | let mut c = Tensor::load_from_weight_file("./data/bigram/tensor_C.json"); 67 | 68 | c.set_requires_grad(true); 69 | let mut w1 = Tensor::load_from_weight_file("./data/bigram/tensor_W1.json"); 70 | w1.set_requires_grad(true); 71 | let mut b1 = Tensor::load_from_weight_file("./data/bigram/tensor_b1.json"); 72 | b1.set_requires_grad(true); 73 | let mut w2 = Tensor::load_from_weight_file("./data/bigram/tensor_W2.json"); 74 | w2.set_requires_grad(true); 75 | let mut b2 = Tensor::load_from_weight_file("./data/bigram/tensor_b2.json"); 76 | b2.set_requires_grad(true); 77 | 78 | const EPOCH_COUNT: usize = 25; 79 | let batch_size: usize = xtr.len(); 80 | let mut test_index_tensor = Tensor::zeroes(Shape::new(vec![batch_size, 3])); 81 | 82 | let mut timings = HashMap::new(); 83 | 84 | for epoch in 0..EPOCH_COUNT { 85 | let now = Instant::now(); 86 | let time_frame = Instant::now(); 87 | println!("Epoch: {:?}", epoch); 88 | { 89 | let mut singleton = get_equation(); 90 | singleton.zero_all_grads(); 91 | } 92 | add_time(&mut timings, "Zero", now); 93 | let now = Instant::now(); 94 | for b in 0..batch_size { 95 | test_index_tensor.set_index([b, 0].into(), vec![xtr[b][0] as f32].into()); 96 | test_index_tensor.set_index([b, 1].into(), vec![xtr[b][1] as f32].into()); 97 | test_index_tensor.set_index([b, 2].into(), vec![xtr[b][2] as f32].into()); 98 | } 99 | 100 | let test = c.view(Indexable::FromTensor(test_index_tensor.tensor_id)); 101 | let reshape = test.reshape(Shape::new(vec![batch_size, 30])); 102 | add_time(&mut timings, "Data Fill", now); 103 | let now = Instant::now(); 104 | let test_mult = reshape << w1; 105 | add_time(&mut timings, "Matmul", now); 106 | 107 | let now = Instant::now(); 108 | let test_add = test_mult + b1; 109 | add_time(&mut timings, "Add", now); 110 | let now = Instant::now(); 111 | let test_tanh = test_add.tanh_mapped(); 112 | add_time(&mut timings, "Tanh", now); 113 | 114 | let now = Instant::now(); 115 | let test_output_ = test_tanh << w2; 116 | add_time(&mut timings, "Matmul", now); 117 | let now = Instant::now(); 118 | let test_output = test_output_ + b2; 119 | add_time(&mut timings, "Add", now); 120 | 121 | let now = Instant::now(); 122 | let test_max = test_output.max(1); 123 | let test_counts = (test_output - test_max).exp(); 124 | let test_counts_sum = test_counts.sum(1); 125 | let test_counts_sum_inverted = test_counts_sum.pow(-1.0); 126 | let test_probabilities = test_counts * test_counts_sum_inverted; 127 | add_time(&mut timings, "Softmax", now); 128 | 129 | let now = Instant::now(); 130 | let mut test_ytrue_onehot = Tensor::element(Shape::new(vec![batch_size, 27]), 0.0); 131 | for b in 0..batch_size { 132 | test_ytrue_onehot.set_index([b, ytr[b]].into(), vec![1.0].into()); 133 | } 134 | 135 | let test_prob_log = test_probabilities.log(); 136 | let test_presum = test_ytrue_onehot * test_prob_log; 137 | let test_sum = (-test_presum).sum(1); 138 | let test_mean = test_sum.mean(vec![0]); 139 | add_time(&mut timings, "Loss", now); 140 | 141 | println!("Loss: {:?}", test_mean.item()); 142 | let now = Instant::now(); 143 | test_mean.backward(); 144 | add_time(&mut timings, "Backward", now); 145 | let now = 
Instant::now(); 146 | { 147 | let mut singleton = get_equation(); 148 | singleton.update_parameters(-0.1); 149 | } 150 | add_time(&mut timings, "Update", now); 151 | println!("Time Frame: {:?}", time_frame.elapsed().as_micros()); 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /examples/wavenet.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use poro::central::{update_parameters, zero_all_grads, Indexable, Shape, Tensor}; 4 | use poro::nn::layers::{BatchNorm1d, LinearLayer, Module, Tanh}; 5 | use poro::nn::model::{Model, Sequential}; 6 | use std::fs::read_to_string; 7 | use poro::nn::LinearLayerConfig; 8 | 9 | struct EmbeddingLayer { 10 | weight: Tensor, 11 | } 12 | 13 | impl EmbeddingLayer { 14 | pub fn new(number_of_embeddings: usize, embedding_dims: usize) -> Self { 15 | EmbeddingLayer { 16 | weight: Tensor::randn(Shape::new(vec![number_of_embeddings, embedding_dims])), 17 | } 18 | } 19 | 20 | pub fn from_tensor(weights: Tensor) -> Self { 21 | EmbeddingLayer { weight: weights } 22 | } 23 | } 24 | 25 | impl Module for EmbeddingLayer { 26 | fn forward(&mut self, input: &Tensor) -> Tensor { 27 | let output_tensor = self.weight.view(Indexable::FromTensor(input.tensor_id)); 28 | return output_tensor; 29 | } 30 | 31 | fn get_parameters(&self) -> Vec { 32 | vec![self.weight.clone()] 33 | } 34 | } 35 | 36 | impl From for Box { 37 | fn from(layer: EmbeddingLayer) -> Box { 38 | Box::new(layer) 39 | } 40 | } 41 | 42 | struct FlattenConsecutive { 43 | block_size: usize, 44 | } 45 | 46 | impl FlattenConsecutive { 47 | pub fn new(block_size: usize) -> Self { 48 | FlattenConsecutive { block_size } 49 | } 50 | } 51 | 52 | impl Module for FlattenConsecutive { 53 | fn forward(&mut self, input: &Tensor) -> Tensor { 54 | let b = input.shape.indices[0]; 55 | let t = input.shape.indices[1]; 56 | let c = input.shape.indices[2]; 57 | let new_middle = t / self.block_size; 58 | let new_end = self.block_size * c; 59 | 60 | let new_input = input.reshape(Shape::new(vec![b, new_middle, new_end])); 61 | 62 | if new_input.shape.indices[1] == 1 { 63 | return new_input.squeeze(1); 64 | } 65 | 66 | return new_input; 67 | } 68 | 69 | fn get_parameters(&self) -> Vec { 70 | vec![] 71 | } 72 | } 73 | 74 | impl From for Box { 75 | fn from(layer: FlattenConsecutive) -> Box { 76 | Box::new(layer) 77 | } 78 | } 79 | 80 | fn build_wavenet_dataset_from_subset( 81 | words: &[String], 82 | stoi: &HashMap, 83 | ) -> (Vec<[usize; 8]>, Vec) { 84 | let mut xs = vec![]; 85 | let mut ys = vec![]; 86 | for word in words { 87 | let fixed = String::from("........") + word + "."; 88 | let chars: Vec = fixed.chars().collect(); 89 | for i in 0..chars.len() - 8 { 90 | let pair = ( 91 | chars[i], 92 | chars[i + 1], 93 | chars[i + 2], 94 | chars[i + 3], 95 | chars[i + 4], 96 | chars[i + 5], 97 | chars[i + 6], 98 | chars[i + 7], 99 | chars[i + 8], 100 | ); 101 | xs.push([ 102 | stoi[&pair.0], 103 | stoi[&pair.1], 104 | stoi[&pair.2], 105 | stoi[&pair.3], 106 | stoi[&pair.4], 107 | stoi[&pair.5], 108 | stoi[&pair.6], 109 | stoi[&pair.7], 110 | ]); 111 | ys.push(stoi[&pair.8]); 112 | } 113 | } 114 | (xs, ys) 115 | } 116 | 117 | fn read_lines(filename: &str) -> Vec { 118 | let mut result = Vec::new(); 119 | 120 | for line in read_to_string(filename).unwrap().lines() { 121 | result.push(line.to_string()) 122 | } 123 | 124 | result 125 | } 126 | 127 | fn main() { 128 | let n_embd = 24; 129 | let n_hidden = 128; 130 | let 
block_size = 8; 131 | 132 | let names = read_lines("./data/bigram/names.txt"); 133 | 134 | let mut stoi = HashMap::new(); 135 | let mut itos = HashMap::new(); 136 | let mut i = 0; 137 | for c in ".abcdefghijklmnopqrstuvwxyz".chars() { 138 | stoi.insert(c, i); 139 | itos.insert(i, c); 140 | i += 1; 141 | } 142 | let n1 = (names.len() as f32 * 0.8f32) as usize; 143 | let (xtr, ytr) = build_wavenet_dataset_from_subset(&names[..n1], &stoi); 144 | let linear_layer_config_1 = LinearLayerConfig::new(n_embd * 2, n_hidden); 145 | let linear_layer_config_2 = LinearLayerConfig::new(n_hidden * 2, n_hidden); 146 | let linear_layer_config_3 = LinearLayerConfig::new(n_hidden * 2, n_hidden); 147 | let linear_layer_config_4 = LinearLayerConfig::new(n_hidden, 27); 148 | let mut model: Sequential = vec![ 149 | EmbeddingLayer::new(27, n_embd).into(), 150 | FlattenConsecutive::new(2).into(), 151 | LinearLayer::new(linear_layer_config_1).into(), 152 | BatchNorm1d::new(n_hidden).into(), 153 | Tanh::new().into(), 154 | FlattenConsecutive::new(2).into(), 155 | LinearLayer::new(linear_layer_config_2).into(), 156 | BatchNorm1d::new(n_hidden).into(), 157 | Tanh::new().into(), 158 | FlattenConsecutive::new(2).into(), 159 | LinearLayer::new(linear_layer_config_3).into(), 160 | BatchNorm1d::new(n_hidden).into(), 161 | Tanh::new().into(), 162 | LinearLayer::new(linear_layer_config_4).into(), 163 | ] 164 | .into(); 165 | 166 | model.set_requires_grad(true); 167 | 168 | let max_steps = 10; 169 | let batch_size = 32; 170 | 171 | for _i in 0..max_steps { 172 | zero_all_grads(); 173 | let mut test_index_tensor = Tensor::zeroes(Shape::new(vec![batch_size, 8])); 174 | for b in 0..batch_size { 175 | for b_index in 0..block_size { 176 | test_index_tensor 177 | .set_index([b, b_index].into(), vec![xtr[b][b_index] as f32].into()); 178 | } 179 | } 180 | let test = model.forward(&test_index_tensor); 181 | let mut test_ytrue_onehot = Tensor::element(Shape::new(vec![batch_size, 27]), 0.0); 182 | for b in 0..batch_size { 183 | test_ytrue_onehot.set_index([b, ytr[b]].into(), vec![1.0].into()); 184 | } 185 | let loss = test.cross_entropy_loss(test_ytrue_onehot); 186 | println!("Loss: {}", loss.item()); 187 | loss.backward(); 188 | update_parameters(0.01); 189 | } 190 | } 191 | -------------------------------------------------------------------------------- /src/nn/model/transfomer.rs: -------------------------------------------------------------------------------- 1 | use crate::central::*; 2 | use crate::nn::*; 3 | use crate::nn::layers::*; 4 | 5 | pub struct PositionalEncoding { 6 | pub encoding: Tensor, 7 | } 8 | 9 | impl PositionalEncoding { 10 | pub fn new(d_model: usize, max_len: usize) -> Self { 11 | let mut encoding = Tensor::zeroes(vec![max_len, d_model].into()); 12 | let position = Tensor::arange(0, max_len, 1).reshape(vec![max_len, 1].into()); 13 | let mut position_real = Tensor::element(vec![10, 5].into(), 1.0); 14 | let position_data_as_vec = position.item().into_raw_vec(); 15 | for x in 0..10 { 16 | for y in 0..5 { 17 | position_real.set_index(Indexable::Single(x * 5 + y), vec![position_data_as_vec[x]]); 18 | } 19 | 20 | } 21 | 22 | let po = Tensor::arange(0, d_model, 2) * (-(10000.0f32).ln() / d_model as f32); 23 | let div_term = Tensor::exp(&po).reshape(vec![d_model / 2].into()); 24 | 25 | let mut div_term_vec = vec![0.0;50]; 26 | 27 | for y in 0..10 { 28 | for x in 0..5 { 29 | div_term_vec[y * 5 + x] = div_term.item().into_raw_vec()[x]; 30 | } 31 | } 32 | 33 | let div_term_real = Tensor::from_vec(div_term_vec, vec![10, 
5].into()); 34 | 35 | let test = div_term_real * position_real; 36 | 37 | let sin = test.sin(); 38 | let cos = test.cos(); 39 | 40 | for i in 0..max_len { 41 | for j in 0..d_model { 42 | if j % 2 == 0 { 43 | encoding.set_index(Indexable::Double(i, j), vec![sin.view(Indexable::Double(i, j / 2)).item()[0]]); 44 | } else { 45 | encoding.set_index(Indexable::Double(i, j), vec![cos.view(Indexable::Double(i, j / 2)).item()[0]]); 46 | } 47 | } 48 | } 49 | 50 | Self { encoding } 51 | } 52 | } 53 | 54 | impl Module for PositionalEncoding { 55 | fn forward(&mut self, x: &Tensor) -> Tensor { 56 | let indexes_as_tensor = Tensor::arange(0, x.shape.indices[1], 1); 57 | *x + self.encoding.view(Indexable::FromTensor(indexes_as_tensor.tensor_id)) 58 | } 59 | 60 | fn get_parameters(&self) -> Vec { 61 | vec![self.encoding.clone()] 62 | } 63 | } 64 | 65 | // One day it would be great to get this working in Rust, but for now, here is a Python implementation of a Transformer model using PyTorch: 66 | /* 67 | */ 68 | 69 | pub struct DecoderLayer { 70 | self_attention: AttentionHead, 71 | feed_forward: LinearLayer, 72 | 73 | norm1: BatchNorm1d, 74 | norm2: BatchNorm1d, 75 | 76 | dropout1: Dropout, 77 | dropout2: Dropout, 78 | } 79 | 80 | impl DecoderLayer { 81 | pub fn new(d_model: usize, num_heads: usize, d_ff: usize) -> Self { 82 | let self_attention = AttentionHead::new(d_model, num_heads); 83 | let linear_layer_config = LinearLayerConfig { number_of_inputs: d_model, number_of_weights: d_ff }; 84 | let feed_forward = LinearLayer::new(linear_layer_config); 85 | let norm1 = BatchNorm1d::new(d_model); 86 | let norm2 = BatchNorm1d::new(d_model); 87 | let dropout1 = Dropout::new(0.1); 88 | let dropout2 = Dropout::new(0.1); 89 | Self { self_attention, feed_forward, norm1, norm2, dropout1, dropout2 } 90 | } 91 | } 92 | 93 | impl Module for DecoderLayer { 94 | fn forward(&mut self, x: &Tensor) -> Tensor { 95 | let attn_output = self.self_attention.forward(x); 96 | let x = *x + self.dropout1.forward(&attn_output); 97 | let x = self.norm1.forward(&x); 98 | let ff_output = self.feed_forward.forward(&x); 99 | let x = x + self.dropout2.forward(&ff_output); 100 | let x = self.norm2.forward(&x); 101 | x 102 | } 103 | 104 | fn get_parameters(&self) -> Vec { 105 | let mut parameters = Vec::new(); 106 | parameters.append(&mut self.self_attention.get_parameters()); 107 | parameters.append(&mut self.feed_forward.get_parameters()); 108 | parameters.append(&mut self.norm1.get_parameters()); 109 | parameters.append(&mut self.norm2.get_parameters()); 110 | parameters.append(&mut self.dropout1.get_parameters()); 111 | parameters.append(&mut self.dropout2.get_parameters()); 112 | parameters 113 | } 114 | 115 | } 116 | /* 117 | # Example usage 118 | vocab_size = 10000 # Example vocabulary size 119 | d_model = 512 # Dimensionality of the model 120 | num_layers = 6 # Number of decoder layers 121 | num_heads = 8 # Number of attention heads 122 | d_ff = 2048 # Dimensionality of the feed-forward network 123 | max_len = 5000 # Maximum sequence length 124 | 125 | model = DecoderOnlyTransformer(vocab_size, d_model, num_layers, num_heads, d_ff, max_len) 126 | input_sequence = torch.randint(0, vocab_size, (32, 100)) # Example input: batch of 32 sequences of length 100 127 | output = model(input_sequence) 128 | 129 | print(output.shape) # Should output: (32, 100, vocab_size) 130 | */ 131 | 132 | 133 | pub struct DecoderOnlyTransformer { 134 | embedding: Embedding, 135 | positional_encoding: PositionalEncoding, 136 | layers: Vec, 137 | fc_out: 
LinearLayer, 138 | } 139 | 140 | impl DecoderOnlyTransformer { 141 | pub fn new(vocab_size: usize, d_model: usize, num_layers: usize, num_heads: usize, d_ff: usize, max_len: usize) -> Self { 142 | let embedding = Embedding::new(vocab_size, d_model); 143 | let positional_encoding = PositionalEncoding::new(d_model, max_len); 144 | let layers = (0..num_layers).map(|_| DecoderLayer::new(d_model, num_heads, d_ff)).collect(); 145 | let linear_layer_config = LinearLayerConfig { number_of_inputs: d_model, number_of_weights: vocab_size }; 146 | let fc_out = LinearLayer::new(linear_layer_config); 147 | Self { embedding, positional_encoding, layers, fc_out } 148 | } 149 | } 150 | 151 | impl Model for DecoderOnlyTransformer { 152 | fn forward(&mut self, x: &Tensor) -> Tensor { 153 | let x = self.embedding.forward(x); 154 | let mut x = x + self.positional_encoding.forward(&x); 155 | for layer in &mut self.layers { 156 | x = layer.forward(&x); 157 | } 158 | self.fc_out.forward(&x) 159 | } 160 | 161 | fn get_parameters(&self) -> Vec { 162 | let mut parameters = Vec::new(); 163 | parameters.append(&mut self.embedding.get_parameters()); 164 | for layer in &self.layers { 165 | parameters.append(&mut layer.get_parameters()); 166 | } 167 | parameters.append(&mut self.fc_out.get_parameters()); 168 | parameters 169 | } 170 | } -------------------------------------------------------------------------------- /src/central/add_op.rs: -------------------------------------------------------------------------------- 1 | use crate::central::get_equation; 2 | use crate::central::operation::Operation; 3 | #[allow(unused_imports)] 4 | use crate::central::shape::Shape; 5 | use crate::central::tensor::Tensor; 6 | use crate::central::BackpropagationPacket; 7 | pub use ndarray::prelude::*; 8 | use std::ops::{Add, Sub}; 9 | 10 | use super::tensor::NAME_LENGTH; 11 | 12 | pub fn backward(backprop_packet: BackpropagationPacket) { 13 | if let Operation::Add(a, b) = backprop_packet.operation { 14 | // Get each of current tensor's gradient 15 | let left_hand_grad = backprop_packet.equation.get_tensor_grad(a); 16 | let right_hand_grad = backprop_packet.equation.get_tensor_grad(b); 17 | 18 | // derivative of a + b is a' + b' * global_grad 19 | backprop_packet 20 | .equation 21 | .set_tensor_grad(a, left_hand_grad + backprop_packet.grad.clone()); 22 | backprop_packet 23 | .equation 24 | .set_tensor_grad(b, right_hand_grad + backprop_packet.grad); 25 | } else { 26 | panic!("Invalid operation type for backward pass"); 27 | } 28 | } 29 | 30 | /// Overload the add operator for the Tensor struct 31 | /// This will allow us to add two tensors together 32 | /// If the tensors are not the same shape, we will broadcast the right hand side tensor 33 | /// to the shape of the left hand side tensor 34 | /// If the tensors are the same shape, we will just add them together 35 | impl Add for Tensor { 36 | type Output = Self; 37 | fn add(self, rhs: Self) -> Self::Output { 38 | if self.shape != rhs.shape { 39 | // check if we need to broadcast the tensors, and then do so 40 | // will only broadcast the right hand side tensor 41 | let right_hand_broadcasted = rhs.broadcast(self.shape); 42 | let mut singleton = get_equation(); 43 | 44 | let result_data = 45 | singleton.element_wise_add(self.tensor_id, right_hand_broadcasted.tensor_id); 46 | let tensor_id = singleton.allocate_tensor_from_operation( 47 | self.shape.clone(), 48 | result_data, 49 | Operation::Add(self.tensor_id, right_hand_broadcasted.tensor_id), 50 | ); 51 | Tensor { 52 | tensor_id, 53 | 
shape: self.shape, 54 | operation: Operation::Add(self.tensor_id, right_hand_broadcasted.tensor_id), 55 | name: ['a'; NAME_LENGTH], 56 | } 57 | } else { 58 | let mut singleton = get_equation(); 59 | // If they are the same size, preform the add and then return the result tensor 60 | let result_data = singleton.element_wise_add(self.tensor_id, rhs.tensor_id); 61 | 62 | let tensor_id = singleton.allocate_tensor_from_operation( 63 | self.shape.clone(), 64 | result_data, 65 | Operation::Add(self.tensor_id, rhs.tensor_id), 66 | ); 67 | 68 | Tensor { 69 | tensor_id, 70 | shape: self.shape, 71 | operation: Operation::Add(self.tensor_id, rhs.tensor_id), 72 | name: ['a'; NAME_LENGTH], 73 | } 74 | } 75 | } 76 | } 77 | 78 | /// Overload the add operator for the Tensor struct 79 | /// This will allow us to add a tensor and a f32 together 80 | /// it will turn the f32 into a tensor and then add them together 81 | impl Add for Tensor { 82 | type Output = Self; 83 | fn add(self, rhs: f32) -> Self::Output { 84 | let right_hand_as_tesnor = Tensor::element(self.shape.clone(), rhs); 85 | self + right_hand_as_tesnor 86 | } 87 | } 88 | 89 | /// Overload the sub operator for the Tensor struct 90 | /// This will allow us to subtract two tensors together 91 | /// it does it by negating the right hand side tensor and then adding them together 92 | impl Sub for Tensor { 93 | type Output = Self; 94 | fn sub(self, rhs: Self) -> Self::Output { 95 | self + -rhs 96 | } 97 | } 98 | 99 | // Overload the sub operator for the Tensor struct 100 | // This will allow us to subtract a tensor and a f32 together 101 | // it will turn the f32 into a tensor and then subtract them together 102 | impl Sub for f32 { 103 | type Output = Tensor; 104 | fn sub(self, rhs: Tensor) -> Self::Output { 105 | let right_hand_as_tesnor = Tensor::element(rhs.shape.clone(), self); 106 | right_hand_as_tesnor - rhs 107 | } 108 | } 109 | 110 | // Overload the sub operator for the Tensor struct 111 | // This will allow us to subtract a tensor and a f32 together 112 | // it will turn the f32 into a tensor and then subtract them together 113 | impl Sub for Tensor { 114 | type Output = Tensor; 115 | fn sub(self, rhs: f32) -> Self::Output { 116 | let right_hand_as_tesnor = Tensor::element(self.shape.clone(), rhs); 117 | self - right_hand_as_tesnor 118 | } 119 | } 120 | 121 | // Overload the add operator for the f32 struct 122 | // This will allow us to add a f32 and a tensor together 123 | // it will turn the f32 into a tensor and then add them together 124 | impl Add for f32 { 125 | type Output = Tensor; 126 | fn add(self, rhs: Tensor) -> Self::Output { 127 | rhs + self 128 | } 129 | } 130 | 131 | #[test] 132 | fn add_test() { 133 | // Test the add operation 134 | let a = Tensor::ones(Shape::new(vec![2, 2])); 135 | let b = Tensor::zeroes(Shape::new(vec![2, 2])); 136 | let c = a + b; 137 | let result = c.item(); 138 | assert!(result == arr2(&[[1.0, 1.0], [1.0, 1.0]]).into_dyn()); 139 | } 140 | 141 | #[test] 142 | fn add_test_2() { 143 | let a = Tensor::zeroes(Shape::new(vec![2, 2])); 144 | let b = Tensor::zeroes(Shape::new(vec![2, 2])); 145 | let c = a + b; 146 | 147 | // Test using the result of the add operation 148 | let d = c.clone() + c; 149 | let result = d.item(); 150 | assert!(result == arr2(&[[0.0, 0.0], [0.0, 0.0]]).into_dyn()); 151 | } 152 | 153 | #[test] 154 | fn backward_add_test() { 155 | // Basic verification of the backward pass 156 | let a = Tensor::ones(Shape::new(vec![1, 1])); 157 | let b = Tensor::element(Shape::new(vec![1, 1]), 2.0); 
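        // For c = a + b the local derivative with respect to each input is 1, so after
        // c.backward() the gradients of a, b, and c checked below are all the same all-ones array.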
158 | let c = a + b; 159 | c.backward(); 160 | let result = c.grad(); 161 | assert!(result == arr2(&[[1.0]]).into_dyn()); 162 | let result = a.grad(); 163 | assert!(result == arr2(&[[1.0]]).into_dyn()); 164 | let result = b.grad(); 165 | assert!(result == arr2(&[[1.0]]).into_dyn()); 166 | } 167 | 168 | #[test] 169 | fn sub_test() { 170 | // sub is the same as add, but with a negative sign 171 | let a = Tensor::ones(Shape::new(vec![2, 2])); 172 | let b = Tensor::zeroes(Shape::new(vec![2, 2])); 173 | let c = a - b; 174 | let result = c.item(); 175 | assert!(result == arr2(&[[1.0, 1.0], [1.0, 1.0]]).into_dyn()); 176 | } 177 | 178 | #[test] 179 | fn sub_test_2() { 180 | let a = Tensor::zeroes(Shape::new(vec![2, 2])); 181 | let b = Tensor::zeroes(Shape::new(vec![2, 2])); 182 | let c = a - b; 183 | let d = c.clone() - c; 184 | let result = d.item(); 185 | assert!(result == arr2(&[[0.0, 0.0], [0.0, 0.0]]).into_dyn()); 186 | } 187 | 188 | #[test] 189 | fn backward_sub_test() { 190 | let a = Tensor::ones(Shape::new(vec![1, 1])); 191 | let b = Tensor::element(Shape::new(vec![1, 1]), 2.0); 192 | let c = a - b; 193 | c.backward(); 194 | let result = c.grad(); 195 | assert!(result == arr2(&[[1.0]]).into_dyn()); 196 | let result = a.grad(); 197 | assert!(result == arr2(&[[1.0]]).into_dyn()); 198 | let result = b.grad(); 199 | assert!(result == arr2(&[[-1.0]]).into_dyn()); 200 | } 201 | -------------------------------------------------------------------------------- /src/nn/layers/mlp.rs: -------------------------------------------------------------------------------- 1 | use crate::central::Tensor; 2 | use crate::nn::layers::linear::{LinearLayer, LinearLayerConfig}; 3 | use crate::nn::layers::module::Module; 4 | use std::f32::consts::PI; 5 | 6 | pub struct NewGLU { 7 | 8 | } 9 | 10 | impl Module for NewGLU { 11 | fn forward(&mut self, x: &Tensor) -> Tensor { 12 | let x_pow = x.pow(3.0); 13 | let why = 1.0 + (((2.0 / PI).sqrt() * (*x + 0.044715 * x_pow)).tanh_mapped()); 14 | return 0.5 * *x * why; 15 | } 16 | 17 | fn get_parameters(&self) -> Vec { 18 | Vec::new() 19 | } 20 | } 21 | 22 | 23 | struct MLPConfig { 24 | embedding_dim: usize, 25 | } 26 | 27 | pub struct MLP { 28 | pub c_fc: LinearLayer, 29 | pub c_proj: LinearLayer, 30 | gelu: NewGLU, 31 | } 32 | 33 | impl MLP { 34 | fn new(config: MLPConfig) -> Self { 35 | let c_fc = LinearLayer::new(LinearLayerConfig { 36 | number_of_inputs: config.embedding_dim, 37 | number_of_weights: 4 * config.embedding_dim, 38 | }); 39 | 40 | let c_proj = LinearLayer::new(LinearLayerConfig { 41 | number_of_inputs: 4 * config.embedding_dim, 42 | number_of_weights: config.embedding_dim, 43 | }); 44 | 45 | MLP { 46 | c_fc, 47 | c_proj, 48 | gelu: NewGLU {} 49 | } 50 | } 51 | 52 | pub fn from_weights_and_bias(c_fc_weights: Tensor, c_fc_bias: Tensor, c_proj_weights: Tensor, c_proj_bias: Tensor) -> Self { 53 | let c_fc = LinearLayer::from_weights_and_bias(c_fc_weights, c_fc_bias); 54 | let c_proj = LinearLayer::from_weights_and_bias(c_proj_weights, c_proj_bias); 55 | MLP { 56 | c_fc, 57 | c_proj, 58 | gelu: NewGLU {} 59 | } 60 | } 61 | } 62 | 63 | impl Module for MLP { 64 | fn forward(&mut self, x: &Tensor) -> Tensor { 65 | let x = self.c_fc.forward(x); 66 | let x = self.gelu.forward(&x); 67 | let x = self.c_proj.forward(&x); 68 | x 69 | } 70 | 71 | fn get_parameters(&self) -> Vec { 72 | let mut parameters = Vec::new(); 73 | parameters.extend(self.c_fc.get_parameters()); 74 | parameters.extend(self.c_proj.get_parameters()); 75 | parameters 76 | } 77 | } 78 | 79 | mod tests { 
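    // The test below loads weights, an input batch, and reference outputs/gradients from
    // the data/tests/mlp/*.txt files, runs the MLP forward pass plus an MSE loss, and
    // compares every value against the reference to within 1e-4.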
80 | use crate::nn::layers::linear; 81 | use crate::nn::layers::mlp::MLP; 82 | use crate::nn::layers::module::Module; 83 | use crate::central::Tensor; 84 | use crate::Shape; 85 | 86 | #[test] 87 | pub fn from_python_weights_and_bias() { 88 | 89 | use std::fs::File; 90 | let linear_1_weights_path = "data/tests/mlp/linear_1_weights.txt"; 91 | let linear_1_bias_path = "data/tests/mlp/linear_1_bias.txt"; 92 | let linear_2_weights_path = "data/tests/mlp/linear_2_weights.txt"; 93 | let linear_2_bias_path = "data/tests/mlp/linear_2_bias.txt"; 94 | let test_input_path = "data/tests/mlp/test_input.txt"; 95 | let expected_output_path = "data/tests/mlp/output.txt"; 96 | let fake_target = "data/tests/mlp/fake_target.txt"; 97 | let expected_loss = "data/tests/mlp/expected_loss.txt"; 98 | let linear_1_weight_grad_path = "data/tests/mlp/linear_1_weight_grad.txt"; 99 | let linear_1_bias_grad_path = "data/tests/mlp/linear_1_bias_grad.txt"; 100 | let linear_2_weight_grad_path = "data/tests/mlp/linear_2_weight_grad.txt"; 101 | let linear_2_bias_grad_path = "data/tests/mlp/linear_2_bias_grad.txt"; 102 | 103 | let mut linear_1_weight_file = File::open(linear_1_weights_path).unwrap(); 104 | let mut linear_1_bias_file = File::open(linear_1_bias_path).unwrap(); 105 | let mut linear_2_weight_file = File::open(linear_2_weights_path).unwrap(); 106 | let mut linear_2_bias_file = File::open(linear_2_bias_path).unwrap(); 107 | let mut test_input_file = File::open(test_input_path).unwrap(); 108 | let mut expected_output_file = File::open(expected_output_path).unwrap(); 109 | let mut fake_target_file = File::open(fake_target).unwrap(); 110 | let mut expected_loss_file = File::open(expected_loss).unwrap(); 111 | let mut linear_1_weight_grad_file = File::open(linear_1_weight_grad_path).unwrap(); 112 | let mut linear_1_bias_grad_file = File::open(linear_1_bias_grad_path).unwrap(); 113 | let mut linear_2_weight_grad_file = File::open(linear_2_weight_grad_path).unwrap(); 114 | let mut linear_2_bias_grad_file = File::open(linear_2_bias_grad_path).unwrap(); 115 | 116 | 117 | let linear_1_weights = Tensor::from_bytestream(&mut linear_1_weight_file, false).unwrap(); 118 | let linear_1_bias = Tensor::from_bytestream(&mut linear_1_bias_file, false).unwrap(); 119 | let linear_2_weights = Tensor::from_bytestream(&mut linear_2_weight_file, false).unwrap(); 120 | let linear_2_bias = Tensor::from_bytestream(&mut linear_2_bias_file, false).unwrap(); 121 | let test_input = Tensor::from_bytestream(&mut test_input_file, false).unwrap(); 122 | let expected_output = Tensor::from_bytestream(&mut expected_output_file, false).unwrap(); 123 | let fake_target = Tensor::from_bytestream(&mut fake_target_file, false).unwrap(); 124 | let expected_loss = Tensor::from_bytestream(&mut expected_loss_file, false).unwrap(); 125 | 126 | let expected_linear_1_weight_grad = Tensor::from_bytestream(&mut linear_1_weight_grad_file, false).unwrap(); 127 | let expected_linear_1_bias_grad = Tensor::from_bytestream(&mut linear_1_bias_grad_file, false).unwrap(); 128 | let expected_linear_2_weight_grad = Tensor::from_bytestream(&mut linear_2_weight_grad_file, false).unwrap(); 129 | let expected_linear_2_bias_grad = Tensor::from_bytestream(&mut linear_2_bias_grad_file, false).unwrap(); 130 | 131 | let mut mlp = MLP::from_weights_and_bias(linear_1_weights, linear_1_bias, linear_2_weights, linear_2_bias); 132 | let output = mlp.forward(&test_input); 133 | 134 | let output_as_flat_array = output.item().iter().map(|x| x.clone()).collect::>(); 135 | let 
expected_output_as_flat_array = expected_output.item().iter().map(|x| x.clone()).collect::>(); 136 | for (i, (a, b)) in output_as_flat_array.iter().zip(expected_output_as_flat_array.iter()).enumerate() { 137 | assert!((a - b).abs() < 1e-4, "Mismatch at index {}: {} != {}", i, a, b); 138 | } 139 | 140 | let mse_loss = (output - fake_target).pow(2.0).mean(vec![1]); 141 | 142 | for i in 0..mse_loss.shape.size() { 143 | assert!((mse_loss.item()[[0, i]] - expected_loss.item()[[0, i]]).abs() < 1e-4); 144 | } 145 | 146 | // weight grad check 147 | mse_loss.backward(); 148 | 149 | let linear_1_weight_bias = mlp.c_fc.bias.grad(); 150 | let data = expected_linear_1_bias_grad.item(); 151 | for i in 0..expected_linear_1_bias_grad.shape.size() { 152 | let left = linear_1_weight_bias[i]; 153 | let right = data[[i]]; 154 | assert!((left - right).abs() < 1e-4); 155 | } 156 | let linear_1_weight_grad = mlp.c_fc.weights.grad(); 157 | let shape = linear_1_weight_grad.shape(); 158 | let data = expected_linear_1_weight_grad.item(); 159 | for x in 0..shape[0] { 160 | for y in 0..shape[1] { 161 | let left = linear_1_weight_grad[[x, y]]; 162 | let right = data[[x, y]]; 163 | assert!((left - right).abs() < 1e-4); 164 | } 165 | } 166 | let linear_2_weight_bias = mlp.c_proj.bias.grad(); 167 | let data = expected_linear_2_bias_grad.item(); 168 | for i in 0..expected_linear_2_bias_grad.shape.size() { 169 | let left = linear_2_weight_bias[i]; 170 | let right = data[[i]]; 171 | assert!((left - right).abs() < 1e-4); 172 | } 173 | let linear_2_weight_grad = mlp.c_proj.weights.grad(); 174 | let shape = linear_2_weight_grad.shape(); 175 | let data = expected_linear_2_weight_grad.item(); 176 | for x in 0..shape[0] { 177 | for y in 0..shape[1] { 178 | let left = linear_2_weight_grad[[x, y]]; 179 | let right = data[[x, y]]; 180 | assert!((left - right).abs() < 1e-4); 181 | } 182 | } 183 | 184 | 185 | 186 | 187 | } 188 | 189 | } -------------------------------------------------------------------------------- /src/central/mul_op.rs: -------------------------------------------------------------------------------- 1 | use crate::central::get_equation; 2 | use crate::central::operation::Operation; 3 | use crate::central::tensor::Tensor; 4 | use crate::central::BackpropagationPacket; 5 | use std::ops::{Div, Mul, Neg}; 6 | 7 | use super::tensor::NAME_LENGTH; 8 | 9 | /// This function is used to perform the backward pass for the multiplication operation. 10 | /// It takes in a `BackpropagationPacket` and then sets the gradients of the left and right hand side tensors. 11 | /// # Arguments 12 | /// 13 | /// * `backprop_packet` - A `BackpropagationPacket` that contains the information needed to perform the backward pass. 14 | /// 15 | /// # Panics 16 | /// This function will panic if the operation in the `BackpropagationPacket` is not a multiplication operation. 
17 | pub fn backward(backprop_packet: BackpropagationPacket) { 18 | if let Operation::Mul(a, b) = backprop_packet.operation { 19 | if backprop_packet.advanced_logging == true { 20 | println!("A: {:?}", a); 21 | println!("B: {:?}", b); 22 | } 23 | 24 | // Get the gradients of the left and right hand side tensors 25 | let left_hand_grad = backprop_packet.equation.get_tensor_grad(a); 26 | let right_hand_grad = backprop_packet.equation.get_tensor_grad(b); 27 | if backprop_packet.advanced_logging == true { 28 | println!("Left Hand Grad: {:?}", left_hand_grad); 29 | println!("Right Hand Grad: {:?}", right_hand_grad); 30 | } 31 | 32 | // Get the data of the left and right hand side tensors 33 | let left_hand_data = backprop_packet.equation.get_tensor_data(a); 34 | let right_hand_data = backprop_packet.equation.get_tensor_data(b); 35 | if backprop_packet.advanced_logging == true { 36 | println!("Left Hand Data: {:?}", left_hand_data); 37 | println!("Right Hand Data: {:?}", right_hand_data); 38 | } 39 | 40 | // The derivative of a * b is a' * b + a * b' 41 | let new_left_hand_grad = right_hand_data * backprop_packet.grad.clone(); 42 | let new_right_hand_grad = left_hand_data * backprop_packet.grad; 43 | if backprop_packet.advanced_logging == true { 44 | println!("New Left Hand Grad: {:?}", new_left_hand_grad); 45 | println!("New Right Hand Grad: {:?}", new_right_hand_grad); 46 | } 47 | 48 | // Add the new gradients to the old gradients 49 | // and then set the new gradients 50 | let right_hand_grad = right_hand_grad + new_right_hand_grad; 51 | let left_hand_grad = left_hand_grad + new_left_hand_grad; 52 | if backprop_packet.advanced_logging == true { 53 | println!("Right Hand Grad: {:?}", right_hand_grad); 54 | println!("Left Hand Grad: {:?}", left_hand_grad); 55 | } 56 | 57 | // Set the new gradients 58 | backprop_packet.equation.set_tensor_grad(a, left_hand_grad); 59 | backprop_packet.equation.set_tensor_grad(b, right_hand_grad); 60 | } else { 61 | panic!("Invalid operation type for backward pass"); 62 | } 63 | } 64 | 65 | /// Overload the multiplication operator for the Tensor struct 66 | /// This will allow us to multiply two tensors together 67 | /// 68 | /// If the tensors are not the same shape, we will broadcast the right hand side tensor 69 | /// to the shape of the left hand side tensor 70 | impl Mul for Tensor { 71 | type Output = Self; 72 | fn mul(self, rhs: Self) -> Self::Output { 73 | if self.shape != rhs.shape { 74 | // we need to broadcast the tensors 75 | let broaded_casted_rhs = rhs.broadcast(self.shape); 76 | let result_data = self.item() * broaded_casted_rhs.item(); 77 | 78 | let mut singleton = get_equation(); 79 | 80 | let data_as_vec = result_data.into_iter().collect(); 81 | let tensor_id = singleton.allocate_tensor_from_operation( 82 | self.shape.clone(), 83 | data_as_vec, 84 | Operation::Mul(self.tensor_id, broaded_casted_rhs.tensor_id), 85 | ); 86 | 87 | Tensor { 88 | tensor_id, 89 | shape: self.shape, 90 | operation: Operation::Mul(self.tensor_id, broaded_casted_rhs.tensor_id), 91 | name: ['a'; NAME_LENGTH], 92 | } 93 | } else { 94 | let result_data = self.item() * rhs.item(); 95 | 96 | let mut singleton = get_equation(); 97 | 98 | let data_as_vec = result_data.into_iter().collect(); 99 | let tensor_id = singleton.allocate_tensor_from_operation( 100 | self.shape.clone(), 101 | data_as_vec, 102 | Operation::Mul(self.tensor_id, rhs.tensor_id), 103 | ); 104 | 105 | Tensor { 106 | tensor_id, 107 | shape: self.shape, 108 | operation: Operation::Mul(self.tensor_id, 
rhs.tensor_id), 109 | name: ['a'; NAME_LENGTH], 110 | } 111 | } 112 | } 113 | } 114 | 115 | /// Overload the multiplication operator for the Tensor struct 116 | /// This will allow us to multiply a tensor by a scalar 117 | /// we will convert the scalar into a tensor and then multiply the two tensors together 118 | impl Mul for Tensor { 119 | type Output = Self; 120 | fn mul(self, rhs: f32) -> Self::Output { 121 | let right_hand_as_tesnor = Tensor::element(self.shape.clone(), rhs); 122 | self * right_hand_as_tesnor 123 | } 124 | } 125 | 126 | /// Convinence function to allow us to handle the element wise negation of a tensor 127 | impl Neg for Tensor { 128 | type Output = Self; 129 | fn neg(self) -> Self::Output { 130 | self * -1.0 131 | } 132 | } 133 | 134 | /// Overload the Division operator for the Tensor struct 135 | /// we are going to use the fact that a/b = a * b^-1 136 | impl Div for Tensor { 137 | type Output = Self; 138 | fn div(self, rhs: Self) -> Self::Output { 139 | // we take advantage of the fact that a/b = a * b^-1 140 | // to let us keep the code simplier 141 | let intermidiate = rhs.pow(-1.0); 142 | self * intermidiate 143 | } 144 | } 145 | 146 | /// Overload the Division operator for the Tensor struct 147 | /// This will allow us to divide a tensor by a scalar 148 | /// we will convert the scalar into a tensor and then divide the two tensors together 149 | impl Div for Tensor { 150 | type Output = Self; 151 | fn div(self, rhs: f32) -> Self::Output { 152 | let right_hand_as_tesnor = Tensor::element(self.shape.clone(), rhs); 153 | self / right_hand_as_tesnor 154 | } 155 | } 156 | 157 | /// Overload the Multiplication operator for the f32 struct 158 | /// This will allow us to multiply a scalar by a tensor 159 | /// we will convert the scalar into a tensor and then multiply the two tensors together 160 | impl Mul for f32 { 161 | type Output = Tensor; 162 | fn mul(self, rhs: Tensor) -> Self::Output { 163 | rhs * self 164 | } 165 | } 166 | 167 | mod test { 168 | #[allow(unused_imports)] 169 | use crate::central::shape::Shape; 170 | #[allow(unused_imports)] 171 | use crate::central::tensor::Tensor; 172 | #[allow(unused_imports)] 173 | use ndarray::prelude::*; 174 | 175 | #[test] 176 | fn mul_test() { 177 | let a = Tensor::ones(Shape::new(vec![2, 2])); 178 | let b = Tensor::ones(Shape::new(vec![2, 2])); 179 | let c = a * b; 180 | let result = c.item(); 181 | assert!(result == arr2(&[[1.0, 1.0], [1.0, 1.0]]).into_dyn()); 182 | } 183 | 184 | #[test] 185 | fn mul_test_2() { 186 | let a = Tensor::ones(Shape::new(vec![2, 2])); 187 | let b = Tensor::element(Shape::new(vec![2, 2]), 2.0); 188 | let c = a * b; 189 | let result = c.item(); 190 | assert!(result == arr2(&[[2.0, 2.0], [2.0, 2.0]]).into_dyn()); 191 | } 192 | 193 | #[test] 194 | fn backward_mul_test() { 195 | let a = Tensor::ones(Shape::new(vec![1, 1])); 196 | let b = Tensor::element(Shape::new(vec![1, 1]), 2.0); 197 | let c = a * b; 198 | c.backward(); 199 | let result = c.grad(); 200 | assert!(result == arr2(&[[1.0]]).into_dyn()); 201 | let result = a.grad(); 202 | assert!(result == arr2(&[[2.0]]).into_dyn()); 203 | let result = b.grad(); 204 | assert!(result == arr2(&[[1.0]]).into_dyn()); 205 | } 206 | 207 | #[test] 208 | fn basic_div_test() { 209 | let a = Tensor::element(Shape::new(vec![1, 1]), 2.0); 210 | let b = a / 2.0; 211 | let result = b.item(); 212 | assert!(result == arr2(&[[1.0]]).into_dyn()); 213 | } 214 | 215 | #[test] 216 | fn backward_div_test() { 217 | let a = Tensor::element(Shape::new(vec![1, 1]), 2.0); 
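        // Division is implemented as multiplication by rhs.pow(-1.0), so for b = a / 2.0
        // the gradient flowing back to a is 0.5, which is what the final assertion checks.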
218 | let b = a / 2.0; 219 | b.backward(); 220 | let result = b.grad(); 221 | assert!(result == arr2(&[[1.0]]).into_dyn()); 222 | let result = a.grad(); 223 | assert!(result == arr2(&[[0.5]]).into_dyn()); 224 | } 225 | } 226 | -------------------------------------------------------------------------------- /src/central/shape.rs: -------------------------------------------------------------------------------- 1 | use crate::Indexable; 2 | 3 | pub const MAX_NUMBER_OF_INDICES: usize = 10; 4 | 5 | /// Represents the shape of a tensor. 6 | #[derive(Hash, PartialEq, Eq, Copy, Clone, Debug)] 7 | pub struct Shape { 8 | pub number_of_indices: usize, 9 | pub indices: [usize; MAX_NUMBER_OF_INDICES], 10 | } 11 | 12 | impl Shape { 13 | /// Creates a new `Shape` instance with the given indices. 14 | /// 15 | /// # Arguments 16 | /// 17 | /// * `indices` - A vector of indices representing the shape. 18 | /// 19 | /// # Panics 20 | /// 21 | /// This function will panic if the number of indices exceeds the maximum number of indices. 22 | pub fn new(indices: Vec) -> Shape { 23 | assert!(indices.len() <= MAX_NUMBER_OF_INDICES); 24 | let mut local_indices = [0; MAX_NUMBER_OF_INDICES]; 25 | for (i, index) in indices.iter().enumerate() { 26 | local_indices[i] = *index; 27 | } 28 | 29 | Shape { 30 | number_of_indices: indices.len(), 31 | indices: local_indices, 32 | } 33 | } 34 | 35 | /// Returns the total size of the shape. 36 | pub fn total_size(&self) -> usize { 37 | let mut total = 1; 38 | for i in 0..self.number_of_indices { 39 | total *= self.indices[i]; 40 | } 41 | total 42 | } 43 | 44 | pub fn remove(&self, index: usize) -> Shape { 45 | let mut new_indices = Vec::new(); 46 | for i in 0..self.number_of_indices { 47 | if i != index { 48 | new_indices.push(self.indices[i]); 49 | } 50 | } 51 | Shape::new(new_indices) 52 | } 53 | 54 | /// Returns the shape as a vector of usize values. 55 | pub fn as_ndarray_shape(&self) -> Vec { 56 | let mut shape = Vec::new(); 57 | for i in 0..self.number_of_indices { 58 | shape.push(self.indices[i]); 59 | } 60 | shape 61 | } 62 | 63 | /// Returns the shape resulting from matrix multiplication with another shape. 64 | /// 65 | /// # Arguments 66 | /// 67 | /// * `other` - The shape to multiply with. 68 | /// 69 | /// # Panics 70 | /// 71 | /// This function will panic if the shapes are not compatible for matrix multiplication. 
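    /// Supported operand ranks are 1×2, 2×1, 2×2, 3×2, 3×3, 3×1, and 4×4; any other
    /// combination currently panics with "Not implemented".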
72 | pub fn matmul_shape(&self, other: &Shape) -> Shape { 73 | 74 | if self.number_of_indices == 1 && other.number_of_indices == 2 { 75 | assert!(self.indices[0] == other.indices[0]); 76 | return Shape::new(vec![other.indices[1]]); 77 | } 78 | 79 | if self.number_of_indices == 2 && other.number_of_indices == 1 { 80 | assert!(self.indices[1] == other.indices[0], "{} != {}", self.indices[1], other.indices[0]); 81 | return Shape::new(vec![self.indices[0]]); 82 | } 83 | if self.number_of_indices == 2 && other.number_of_indices == 2 { 84 | assert!(self.indices[1] == other.indices[0], "{} != {}", self.indices[1], other.indices[0]); 85 | let mut new_indices = Vec::new(); 86 | new_indices.push(self.indices[0]); 87 | new_indices.push(other.indices[1]); 88 | return Shape::new(new_indices); 89 | } 90 | if self.number_of_indices == 3 && other.number_of_indices == 2 { 91 | assert!(self.indices[2] == other.indices[0]); 92 | let mut new_indices = Vec::new(); 93 | new_indices.push(self.indices[0]); 94 | new_indices.push(self.indices[1]); 95 | new_indices.push(other.indices[1]); 96 | return Shape::new(new_indices); 97 | } 98 | if self.number_of_indices == 3 && other.number_of_indices == 3 { 99 | assert!(self.indices[2] == other.indices[1]); 100 | let mut new_indices = Vec::new(); 101 | new_indices.push(self.indices[0]); 102 | new_indices.push(self.indices[1]); 103 | new_indices.push(other.indices[2]); 104 | return Shape::new(new_indices); 105 | } 106 | if self.number_of_indices == 3 && other.number_of_indices == 1 { 107 | assert!(self.indices[2] == other.indices[0]); 108 | let mut new_indices = Vec::new(); 109 | new_indices.push(self.indices[0]); 110 | new_indices.push(self.indices[1]); 111 | return Shape::new(new_indices); 112 | } 113 | 114 | if self.number_of_indices == 4 && other.number_of_indices == 4 { 115 | assert!(self.indices[3] == other.indices[2]); 116 | let mut new_indices = Vec::new(); 117 | new_indices.push(self.indices[0]); 118 | new_indices.push(self.indices[1]); 119 | new_indices.push(self.indices[2]); 120 | new_indices.push(other.indices[3]); 121 | return Shape::new(new_indices); 122 | } 123 | 124 | 125 | panic!("Not implemented"); 126 | } 127 | 128 | 129 | pub fn matmul_shape_generic(&self, other: &Shape) -> Shape { 130 | // Check if both shapes have at least 1 axis 131 | assert!(self.number_of_indices >= 1 && other.number_of_indices >= 1); 132 | 133 | // Matrices must satisfy matrix multiplication rules: 134 | // The last dimension of self (left) should match the first dimension of other (right). 
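        // Note: the broadcast loop below compares self.indices[i] against other.indices[i + 1],
        // so even a plain [m, k] x [k, n] product only passes its assert when m == n or one of
        // them is 1.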
135 | assert!(self.indices[self.number_of_indices - 1] == other.indices[0]); 136 | 137 | // Build the resulting shape 138 | let mut new_indices = Vec::new(); 139 | 140 | // Broadcasting the leading dimensions (all dimensions except the last for `self` 141 | // and all dimensions except the first for `other`) 142 | for i in 0..(self.number_of_indices - 1).max(other.number_of_indices - 1) { 143 | if i < self.number_of_indices - 1 && i < other.number_of_indices - 1 { 144 | // Both matrices have this dimension, so we need to broadcast them 145 | assert!(self.indices[i] == 1 || other.indices[i + 1] == 1 || self.indices[i] == other.indices[i + 1]); 146 | new_indices.push(self.indices[i].max(other.indices[i + 1])); 147 | } else if i < self.number_of_indices - 1 { 148 | // Only `self` has this dimension 149 | new_indices.push(self.indices[i]); 150 | } else { 151 | // Only `other` has this dimension 152 | new_indices.push(other.indices[i + 1]); 153 | } 154 | } 155 | 156 | // Append the final dimension from the right matrix 157 | new_indices.push(other.indices[other.number_of_indices - 1]); 158 | 159 | Shape::new(new_indices) 160 | } 161 | 162 | 163 | /// Returns the size of the shape. 164 | pub fn size(&self) -> usize { 165 | let mut size = 1; 166 | for i in 0..self.number_of_indices { 167 | size *= self.indices[i]; 168 | } 169 | size 170 | } 171 | 172 | /// Returns the index at the specified position. 173 | /// 174 | /// # Arguments 175 | /// 176 | /// * `index` - The index position. 177 | /// 178 | /// # Panics 179 | /// 180 | /// This function will panic if the index is out of range. 181 | pub fn get_index(&self, index: Indexable) -> usize { 182 | match index { 183 | Indexable::Single(i) => { 184 | assert!(i < self.number_of_indices); 185 | self.indices[i] 186 | } 187 | Indexable::Double(_a, _b) => { 188 | panic!("Not implemented"); 189 | } 190 | Indexable::Triple(_a, _b, _c) => { 191 | panic!("Not implemented"); 192 | } 193 | Indexable::FromTensor(_) => { 194 | panic!("Not implemented"); 195 | } 196 | } 197 | } 198 | 199 | /// Returns a subshape of the current shape based on the given index. 200 | /// 201 | /// # Arguments 202 | /// 203 | /// * `index` - The index to create the subshape from. 
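    /// # Examples
    ///
    /// Illustrative sketch, not from the original docs:
    ///
    /// ```ignore
    /// let s = Shape::new(vec![4, 3]);
    /// // A single index drops the leading axis...
    /// assert_eq!(s.subshape_from_indexable(Indexable::Single(0)).as_ndarray_shape(), vec![3]);
    /// // ...while a double index into a rank-2 shape collapses it to a scalar-like [1].
    /// assert_eq!(s.subshape_from_indexable(Indexable::Double(1, 2)).as_ndarray_shape(), vec![1]);
    /// ```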
204 | pub fn subshape_from_indexable(&self, index: Indexable) -> Shape { 205 | match index { 206 | Indexable::Single(_i) => { 207 | assert!(self.indices.len() > 1); 208 | let mut new_indices = vec![]; 209 | for j in 1..self.number_of_indices { 210 | new_indices.push(self.indices[j]); 211 | } 212 | Shape::new(new_indices) 213 | } 214 | Indexable::Double(_a, _b) => { 215 | if self.number_of_indices == 2 { 216 | return Shape::new(vec![1]); 217 | } 218 | 219 | let mut new_indices = vec![]; 220 | for j in 1..self.number_of_indices { 221 | new_indices.push(self.indices[j]); 222 | } 223 | Shape::new(new_indices) 224 | } 225 | Indexable::Triple(_a, _b, _c) => { 226 | if self.number_of_indices == 3 { 227 | return Shape::new(vec![1]); 228 | } 229 | 230 | let mut new_indices = vec![]; 231 | for j in 1..self.number_of_indices { 232 | new_indices.push(self.indices[j]); 233 | } 234 | Shape::new(new_indices) 235 | } 236 | Indexable::FromTensor(_tensor) => { 237 | return Shape::new(self.indices.to_vec()); 238 | } 239 | } 240 | } 241 | } 242 | 243 | impl From> for Shape { 244 | fn from(indices: Vec) -> Shape { 245 | Shape::new(indices) 246 | } 247 | } 248 | -------------------------------------------------------------------------------- /src/nn/layers/layer_norm.rs: -------------------------------------------------------------------------------- 1 | use core::panic; 2 | 3 | use crate::central::Tensor; 4 | use crate::nn::layers::module::Module; 5 | 6 | pub struct LayerNorm { 7 | pub weight: Tensor, 8 | bias: Tensor, 9 | length_of_normalized_shape: usize, 10 | normalized_input: Option, 11 | var: Option, 12 | mean: Option, 13 | input_minus_mean: Option, 14 | } 15 | 16 | impl LayerNorm { 17 | pub fn new(embedding_dim: usize) -> Self { 18 | let weight = Tensor::randn(vec![embedding_dim].into()); 19 | let bias = Tensor::randn(vec![embedding_dim].into()); 20 | LayerNorm { 21 | weight, 22 | bias, 23 | length_of_normalized_shape: weight.shape.number_of_indices, 24 | normalized_input: None, 25 | var: None, 26 | mean: None, 27 | input_minus_mean: None 28 | } 29 | } 30 | 31 | pub fn from_weights_and_bias(weight: Tensor, bias: Tensor) -> Self { 32 | LayerNorm { 33 | weight, 34 | bias, 35 | length_of_normalized_shape: weight.shape.number_of_indices, 36 | normalized_input: None, 37 | var: None, 38 | mean: None, 39 | input_minus_mean: None 40 | } 41 | } 42 | } 43 | 44 | impl Module for LayerNorm { 45 | fn forward(&mut self, x: &Tensor) -> Tensor { 46 | let mean_indices = vec![2]; 47 | 48 | let mean = x.mean(mean_indices.clone()); 49 | 50 | self.mean = Some(mean.clone()); 51 | 52 | let input_minus_mean = *x - mean; 53 | 54 | self.input_minus_mean = Some(input_minus_mean.clone()); 55 | 56 | let var: Tensor = x.var(mean_indices);;//(input_minus_mean - input_minus_mean).pow(2.0).mean(mean_indices.clone()); 57 | // 58 | self.var = Some(var.clone()); 59 | 60 | let std_inv = (var + 1e-5).pow(0.5); 61 | 62 | let normalized_input = input_minus_mean / std_inv; 63 | 64 | self.normalized_input = Some(normalized_input.clone()); 65 | 66 | let output = normalized_input * self.weight + self.bias; 67 | 68 | 69 | output 70 | } 71 | 72 | fn get_parameters(&self) -> Vec { 73 | vec![self.weight.clone(), self.bias.clone()] 74 | } 75 | } 76 | mod tests { 77 | use core::panic; 78 | 79 | use ndarray::prelude::*; 80 | #[test] 81 | fn mean_tests() { 82 | let array = ArrayD::from_shape_simple_fn(vec![5, 5, 5], ||{return 5.0}); 83 | let test = Tensor::from_vec(array.into_iter().collect(), vec![5, 5, 5].into()); 84 | let test = test.mean(vec![1, 2]); 
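        // Added comment: averaging a constant [5, 5, 5] tensor over axes 1 and 2 should
        // collapse those axes; the println below is only a manual check of the resulting shape.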
85 | println!("{:?}", test.shape); 86 | } 87 | use ndarray::iter::Axes; 88 | use ndarray::Axis; 89 | 90 | use crate::central::Tensor; 91 | use crate::nn::layers::layer_norm::LayerNorm; 92 | use crate::nn::layers::module::Module; 93 | use simplelog::*; 94 | use log::info; 95 | use std::path::Path; 96 | use std::fs::File; 97 | 98 | #[test] 99 | fn simple_test() { 100 | 101 | WriteLogger::init( 102 | LevelFilter::Info, // Set the log level 103 | Config::default(), // Use the default configuration 104 | File::create(Path::new("./LayerNormTest.log")).unwrap(), // Create or open the log file 105 | ).unwrap(); 106 | 107 | let test_input_a = Tensor::from_vec(vec![1.0, 2., 3., 4., 5.0, 6.0], vec![2, 2, 2].into()); 108 | let test_input_b = Tensor::from_vec(vec![7., 8., 9., 10.0, 11.0, 12.], vec![2, 2, 2].into()); 109 | let test_input = test_input_a + test_input_b; 110 | let test_mean = test_input.mean(vec![2]); 111 | let fake_output = Tensor::from_vec(vec![5.0, 10., 15.], vec![3, 1].into()); 112 | let error = test_mean - fake_output; 113 | let error_mean = error.mean(vec![0]); 114 | println!("{:?}", error_mean.item()); 115 | error_mean.backward(); 116 | panic!("test_input {:?}", test_input.get_id()); 117 | 118 | } 119 | 120 | #[test] 121 | fn test_layer_norm() { 122 | 123 | WriteLogger::init( 124 | LevelFilter::Info, // Set the log level 125 | Config::default(), // Use the default configuration 126 | File::create(Path::new("./LayerNormTest.log")).unwrap(), // Create or open the log file 127 | ).unwrap(); 128 | 129 | 130 | 131 | let layer_norm_weights_path = "data/tests/layer_norm/layer_norm_weights.txt"; 132 | let layer_norm_bias_path = "data/tests/layer_norm/layer_norm_bias.txt"; 133 | let test_input_path = "data/tests/layer_norm/test_input.txt"; 134 | let test_input_a_path = "data/tests/layer_norm/test_input_a.txt"; 135 | let test_input_b_path = "data/tests/layer_norm/test_input_b.txt"; 136 | let expected_output_path = "data/tests/layer_norm/expected_output.txt"; 137 | let fake_target = "data/tests/layer_norm/fake_target.txt"; 138 | let expected_loss = "data/tests/layer_norm/expected_loss.txt"; 139 | 140 | let mut layer_norm_weights_file = File::open(layer_norm_weights_path).unwrap(); 141 | let mut layer_norm_bias_file = File::open(layer_norm_bias_path).unwrap(); 142 | let mut test_input_file = File::open(test_input_path).unwrap(); 143 | let mut test_input_a_file = File::open(test_input_a_path).unwrap(); 144 | let mut test_input_b_file = File::open(test_input_b_path).unwrap(); 145 | 146 | 147 | let mut expected_output_file = File::open(expected_output_path).unwrap(); 148 | let mut fake_target_file = File::open(fake_target).unwrap(); 149 | let mut expected_loss_file = File::open(expected_loss).unwrap(); 150 | 151 | let mut layer_norm_weight_grad_file = File::open("data/tests/layer_norm/layer_norm_weights_grad.txt").unwrap(); 152 | let mut layer_norm_bias_grad_file = File::open("data/tests/layer_norm/layer_norm_bias_grad.txt").unwrap(); 153 | 154 | // Load weights and bias from files 155 | let layer_norm_weight= Tensor::from_bytestream(&mut layer_norm_weights_file, false).unwrap(); 156 | let layer_norm_bias = Tensor::from_bytestream(&mut layer_norm_bias_file, false).unwrap(); 157 | let test_input = Tensor::from_bytestream(&mut test_input_file, false).unwrap(); 158 | let test_input_a = Tensor::from_bytestream(&mut test_input_a_file, false).unwrap(); 159 | let test_input_b = Tensor::from_bytestream(&mut test_input_b_file, false).unwrap(); 160 | let expected_output = Tensor::from_bytestream(&mut 
expected_output_file, false).unwrap(); 161 | let fake_target = Tensor::from_bytestream(&mut fake_target_file, false).unwrap(); 162 | let expected_loss = Tensor::from_bytestream(&mut expected_loss_file, false).unwrap(); 163 | let expected_layer_norm_weight_grad = Tensor::from_bytestream(&mut layer_norm_weight_grad_file, false).unwrap(); 164 | let layer_norm_bias_grad = Tensor::from_bytestream(&mut layer_norm_bias_grad_file, false).unwrap(); 165 | 166 | 167 | 168 | // Create LayerNorm instance 169 | let mut layer_norm = LayerNorm::from_weights_and_bias(layer_norm_weight, layer_norm_bias); 170 | let real_test_input = test_input_a + test_input_b; 171 | // Perform forward pass 172 | let output = layer_norm.forward(&real_test_input); 173 | 174 | 175 | let ouput_as_flat_array = output.item().iter().map(|x| x.clone()).collect::>(); 176 | let expected_as_flat_array = expected_output.item().iter().map(|x| x.clone()).collect::>(); 177 | 178 | // Check if the output is approximately equal to the expected output 179 | for (o, e) in ouput_as_flat_array.iter().zip(expected_as_flat_array.iter()) { 180 | // assert!((o - e).abs() < 1e-4, "Output {} is not approximately equal to expected {}", o, e); 181 | } 182 | 183 | let diff = output - fake_target; 184 | 185 | 186 | let mse_loss = diff.pow(2.0).reshape(vec![4 * 64 * 768].into()).mean(vec![0]); 187 | 188 | let expected_mse_loss = expected_loss.item(); 189 | for (loss, expected) in mse_loss.item().iter().zip(expected_mse_loss.iter()) { 190 | // assert!((loss - expected).abs() < 1e-2, "Loss {} is not approximately equal to expected {}", loss, expected); 191 | } 192 | 193 | mse_loss.backward(); 194 | //println!("{:?}", output.item()); 195 | println!("{:?}",real_test_input.grad()); 196 | //println!("{:?}",layer_norm.mean.unwrap().grad()); 197 | //println!("{:?}", mse_loss.item()); 198 | panic!("---"); 199 | 200 | let expected_layer_norm_weight_grad_output_flatten = expected_layer_norm_weight_grad.item().iter().map(|x| x.clone()).collect::>(); 201 | let layer_norm_weight_grad_output_flatten = layer_norm.weight.grad().iter().map(|x| x.clone()).collect::>(); 202 | 203 | for (g, eg) in layer_norm_weight_grad_output_flatten.iter().zip(expected_layer_norm_weight_grad_output_flatten.iter()) { 204 | assert!((g - eg).abs() < 1e-2, "Gradient {} is not approximately equal to expected gradient {}", g, eg); 205 | } 206 | 207 | let expected_layer_norm_bias_grad_output_flatten = layer_norm_bias_grad.item().iter().map(|x| x.clone()).collect::>(); 208 | let layer_norm_bias_grad_output_flatten = layer_norm.bias.grad().iter().map(|x| x.clone()).collect::>(); 209 | 210 | for (g, eg) in layer_norm_bias_grad_output_flatten.iter().zip(expected_layer_norm_bias_grad_output_flatten.iter()) { 211 | assert!((g - eg).abs() < 1e-2, "Gradient {} is not approximately equal to expected gradient {}", g, eg); 212 | } 213 | 214 | 215 | 216 | } 217 | } -------------------------------------------------------------------------------- /data/bigram/weight_file.json: -------------------------------------------------------------------------------- 1 | { 2 | "shape":[27, 27], 3 | "data":[1.5674e+00, -2.3729e-01, -2.7385e-02, -1.1008e+00, 2.8588e-01, 4 | -2.9643e-02, -1.5471e+00, 6.0489e-01, 7.9136e-02, 9.0462e-01, 5 | -4.7125e-01, 7.8682e-01, -3.2843e-01, -4.3297e-01, 1.3729e+00, 6 | 2.9334e+00, 1.5618e+00, -1.6261e+00, 6.7716e-01, -8.4039e-01, 7 | 9.8488e-01, -1.4837e-01, -1.4795e+00, 4.4830e-01, -7.0730e-02, 8 | 2.4968e+00, 2.4448e+00, 9 | -6.7006e-01, -1.2199e+00, 3.0314e-01, -1.0725e+00, 7.2762e-01, 
10 | 5.1114e-02, 1.3095e+00, -8.0220e-01, -8.5042e-01, -1.8068e+00, 11 | 1.2523e+00, -1.2256e+00, 1.2165e+00, -9.6478e-01, -2.3211e-01, 12 | -3.4762e-01, 3.3244e-01, -1.3263e+00, 1.1224e+00, 5.9641e-01, 13 | 4.5846e-01, 5.4011e-02, -1.7400e+00, 1.1560e-01, 8.0319e-01, 14 | 5.4108e-01, -1.1646e+00, 15 | 1.4756e-01, -1.0006e+00, 3.8012e-01, 4.7328e-01, -9.1027e-01, 16 | -7.8305e-01, 1.3506e-01, -2.1161e-01, -1.0406e+00, -1.5367e+00, 17 | 9.3743e-01, -8.8303e-01, 1.7457e+00, 2.1346e+00, -8.5614e-01, 18 | 5.4082e-01, 6.1690e-01, 1.5160e+00, -1.0447e+00, -6.6414e-01, 19 | -7.2390e-01, 1.7507e+00, 1.7530e-01, 9.9280e-01, -6.2787e-01, 20 | 7.7023e-02, -1.1641e+00, 21 | 1.2473e+00, -2.7061e-01, -1.3635e+00, 1.3066e+00, 3.2307e-01, 22 | 1.0358e+00, -8.6249e-01, -1.2575e+00, 9.4180e-01, -1.3257e+00, 23 | 1.4670e-01, 1.6913e-01, -1.5397e+00, -7.2759e-01, 1.1491e+00, 24 | -8.7462e-01, -2.9771e-01, -1.3707e+00, 1.1500e-01, -1.0188e+00, 25 | -8.3777e-01, -2.1057e+00, -2.6044e-01, -1.7149e+00, -3.3787e-01, 26 | -1.8263e+00, -8.3897e-01, 27 | -1.5723e+00, 4.5795e-01, -5.6533e-01, 5.4281e-01, 1.7549e-01, 28 | -2.2901e+00, -7.0928e-01, -2.9283e-01, -2.1803e+00, 7.9311e-02, 29 | 9.0187e-01, 1.2028e+00, -5.6144e-01, -1.3753e-01, -1.3799e-01, 30 | -2.0977e+00, -7.9238e-01, 6.0689e-01, -1.4777e+00, -5.1029e-01, 31 | 5.6421e-01, 9.6838e-01, -3.1114e-01, -3.0603e-01, -1.7495e+00, 32 | -1.6335e+00, 3.8761e-01, 33 | 4.7236e-01, 1.4830e+00, 3.1748e-01, 1.0588e+00, 2.3982e+00, 34 | 4.6827e-01, -6.5650e-01, 6.1662e-01, -6.2197e-01, 5.1007e-01, 35 | 1.3563e+00, 2.3445e-01, -4.5585e-01, -1.3132e-03, -5.1161e-01, 36 | 5.5570e-01, 4.7458e-01, -1.3867e+00, 1.6229e+00, 1.7197e-01, 37 | 9.8846e-01, 5.0657e-01, 1.0198e+00, -1.9062e+00, -4.2753e-01, 38 | -2.1259e+00, 9.6041e-01, 39 | 1.2482e+00, 2.5341e-01, 2.8188e+00, -3.3983e-01, 7.0311e-01, 40 | 4.0716e-01, -1.9018e-01, -6.9652e-01, 1.7039e+00, 7.4204e-01, 41 | 9.7370e-01, 3.0028e-01, -2.8971e-01, -3.1566e-01, -8.7898e-01, 42 | 1.0661e-01, 1.8598e+00, 5.5752e-02, 1.2815e+00, -6.3182e-01, 43 | -1.2464e+00, 6.8305e-01, -3.9455e-01, 1.4388e-02, 5.7216e-01, 44 | 8.6726e-01, 6.3149e-01, 45 | -1.2230e+00, -2.1286e-01, 5.0950e-01, 3.2713e-01, 1.9661e+00, 46 | -2.4091e-01, -7.9515e-01, 2.7198e-01, -1.1100e+00, -4.5285e-01, 47 | -4.9578e-01, 1.2648e+00, 1.4625e+00, 1.1199e+00, 9.9539e-01, 48 | -1.2353e+00, 7.3818e-01, 8.1415e-01, -7.3806e-01, 5.6714e-01, 49 | -1.4601e+00, -2.4780e-01, 8.8282e-01, -8.1004e-02, -9.5299e-01, 50 | -4.8838e-01, -7.3712e-01, 51 | 7.0609e-01, -1.9295e-01, 1.2348e+00, 3.3308e-01, 1.3283e+00, 52 | -1.0921e+00, -8.3952e-01, 1.9098e-01, -7.1750e-01, -3.8668e-01, 53 | -1.2542e+00, 1.2068e+00, -1.7102e+00, -4.7701e-01, -1.0527e+00, 54 | -1.4367e-01, -2.7737e-01, 1.1634e+00, -6.6910e-01, 6.4918e-01, 55 | 5.8243e-01, 1.9264e+00, -3.7846e-01, 7.9577e-03, 5.1068e-01, 56 | 7.5927e-01, -1.6086e+00, 57 | -1.6065e-01, 1.3784e+00, -2.7804e-01, 2.0710e-01, 1.0033e+00, 58 | -5.9772e-01, -3.9771e-01, -1.2801e+00, 9.2445e-02, 1.0526e-01, 59 | -3.9072e-01, -4.0091e-01, 5.6533e-01, -1.5065e+00, 1.2898e+00, 60 | -1.5100e+00, 1.0930e+00, 1.0797e+00, -8.6681e-02, 1.3423e+00, 61 | 1.5184e-01, 2.4687e-01, 3.1895e-01, -9.8614e-01, -2.1382e-01, 62 | -6.4308e-02, -8.5528e-01, 63 | 1.6113e-01, 4.4925e-01, 8.1827e-01, -8.1628e-01, -3.9243e-01, 64 | -7.4521e-01, -9.4649e-01, -1.5941e-01, -1.5047e+00, 8.4682e-01, 65 | -4.9158e-02, 9.3866e-02, -6.4533e-01, 1.2108e+00, -7.8198e-01, 66 | 3.8449e-01, -8.5259e-01, 1.0464e+00, -1.8493e+00, 9.1092e-01, 67 | -9.9360e-01, 6.0195e-01, -1.0890e-01, 
5.2587e-01, -9.4046e-01, 68 | -1.2773e-01, -2.5679e-01, 69 | -1.5437e+00, 3.7950e-01, -1.7705e+00, -1.2085e+00, 9.4773e-01, 70 | -9.1355e-01, 7.1023e-01, 7.9512e-01, 5.7662e-01, -7.3778e-01, 71 | -1.5264e+00, 7.1173e-01, 1.4056e+00, -4.0636e-01, -7.4648e-01, 72 | 4.9790e-01, 1.1298e-01, -4.1854e-01, 1.7905e-01, 2.3483e-01, 73 | 7.3510e-01, -6.1577e-01, 7.0467e-01, 1.1630e-01, 2.8365e-01, 74 | -2.5043e+00, -5.1931e-01, 75 | -5.9134e-01, -1.1059e-01, 8.3416e-01, -1.0505e+00, 3.6345e-01, 76 | 1.8195e-01, -4.8045e-01, 5.3309e-01, 6.7869e-01, -3.5974e-01, 77 | -1.3270e+00, -8.2526e-01, 6.3614e-01, 1.9110e-01, 7.5476e-01, 78 | 4.0538e-01, 2.2565e+00, 1.3655e+00, -5.6192e-01, -3.0423e-01, 79 | 2.9894e-01, 1.8784e+00, 5.5958e-01, 1.3388e+00, 4.1606e-01, 80 | 6.8491e-01, -1.4790e-01, 81 | 1.9359e-01, 1.0532e+00, 6.3393e-01, 2.5786e-01, 9.6408e-01, 82 | -2.4855e-01, 2.4756e-02, -3.0404e-02, 1.5622e+00, -4.4852e-01, 83 | -1.2345e+00, 1.1220e+00, -6.7381e-01, 3.7882e-02, -5.5881e-01, 84 | -8.2709e-01, 8.2253e-01, -7.5100e-01, 9.2778e-01, -1.4849e+00, 85 | -2.1293e-01, -1.1860e+00, -6.6092e-01, -2.3348e-01, 1.5447e+00, 86 | 6.0061e-01, -7.0909e-01, 87 | 1.9217e+00, -1.8182e-01, 1.5220e+00, 5.4644e-01, 4.0858e-01, 88 | -1.9692e+00, -8.9185e-01, 3.2961e-01, -2.5128e-01, 5.5030e-01, 89 | -7.5171e-01, -6.5783e-03, -6.3108e-01, 1.3431e+00, 3.8010e-02, 90 | -7.1654e-01, 1.7206e+00, -5.2149e-01, -2.3248e-01, 1.0774e+00, 91 | -7.6019e-01, 9.0109e-03, -7.9219e-01, 1.2307e+00, -5.2760e-01, 92 | -1.3207e+00, -7.0654e-01, 93 | -7.7861e-01, 1.2910e+00, -1.5094e+00, 7.4593e-01, 4.8990e-01, 94 | -1.0034e+00, 9.6407e-01, 2.0990e+00, -3.9870e-01, -7.6635e-01, 95 | -2.1007e+00, 1.2331e+00, 7.7481e-01, 2.4311e-01, -2.1322e-01, 96 | -6.9877e-01, 2.0889e-01, -6.2477e-01, -1.0825e-01, -2.1964e+00, 97 | 2.7083e-01, 6.1047e-01, -5.8162e-01, -1.7025e+00, -8.0672e-01, 98 | -2.4174e-01, 1.5490e+00, 99 | -3.4593e-01, 5.4714e-01, 3.1755e-02, 8.1375e-01, 2.6200e-01, 100 | -6.7101e-01, 2.0656e-02, 7.1300e-01, -4.3997e-02, -5.1944e-01, 101 | 1.1241e-01, -3.9770e-01, -2.7829e-01, -1.5364e-01, -2.5424e+00, 102 | 2.5033e-01, 1.1056e-01, -2.0366e+00, -9.2735e-01, -6.9350e-01, 103 | -5.2788e-01, -8.7438e-01, -1.0102e+00, -1.0522e+00, 1.2348e+00, 104 | 2.5907e-02, -9.6676e-01, 105 | 1.0904e+00, 5.3966e-01, 6.6741e-01, -2.2316e+00, -1.1603e+00, 106 | -4.2560e-01, 5.9547e-01, -1.0887e+00, 2.4324e-01, -2.1021e+00, 107 | -2.9289e-01, -7.0682e-01, 9.5190e-01, -1.1583e+00, -1.2844e+00, 108 | 1.0193e+00, 1.6851e+00, 8.3422e-01, 1.7113e+00, 4.4456e-01, 109 | -7.1861e-01, -7.0343e-01, -7.1332e-01, 9.9760e-01, -6.1980e-01, 110 | 1.9522e+00, 1.4311e-01, 111 | 1.8765e-01, 7.5974e-01, -2.6387e-01, -7.3048e-01, 6.1955e-01, 112 | 3.5577e-02, -7.6459e-02, -1.2306e+00, 1.3419e+00, 1.1878e+00, 113 | -1.0672e+00, -2.1507e+00, 6.7082e-01, 1.1614e+00, -2.4155e-01, 114 | 9.5907e-01, 3.8262e-02, 3.9877e-02, -7.7180e-01, 2.9251e-01, 115 | -6.0606e-01, -1.5136e+00, -2.7143e+00, -4.1164e-01, -1.2273e+00, 116 | -4.1746e-01, 1.5021e+00, 117 | -6.2849e-01, -4.4247e-01, 5.6885e-01, 1.2803e+00, -5.5397e-01, 118 | 1.1179e+00, -6.0053e-01, -5.8619e-01, -2.8277e-01, 5.3390e-01, 119 | -9.9388e-01, -1.6996e+00, 1.8362e+00, 4.2016e-01, -6.8729e-01, 120 | -3.5060e-01, 7.5598e-01, -9.3632e-01, -8.4109e-02, -1.6361e+00, 121 | 1.0224e+00, 1.0733e+00, -5.7453e-01, 4.9668e-02, 7.2379e-01, 122 | 5.9746e-01, 2.6966e+00, 123 | 2.7930e+00, -2.2745e+00, -2.3912e-01, 8.7498e-02, 1.4967e+00, 124 | -5.7016e-01, -5.7248e-01, 1.9909e+00, -7.4416e-01, 7.2960e-01, 125 | 6.4083e-01, 1.6075e+00, 
-8.8810e-01, 2.7359e-01, -1.3257e-01, 126 | 1.2710e+00, 1.7234e+00, 1.1180e-01, 2.6952e-01, 1.1835e+00, 127 | 1.2575e+00, 1.3969e-01, 4.7259e-01, 7.9025e-01, 1.0811e+00, 128 | -9.1965e-01, -4.0503e-01, 129 | 4.5696e-01, -5.4184e-01, -2.3025e+00, 2.0127e+00, -4.6452e-01, 130 | -5.8270e-01, 2.0863e+00, -4.7729e-02, -4.4920e-01, 9.5566e-01, 131 | -1.4708e-01, -1.2532e+00, -1.1850e+00, 3.6583e-01, -1.4049e-01, 132 | 3.5252e-01, -5.2400e-01, -6.2844e-01, -9.3792e-01, 1.6772e+00, 133 | 3.8554e-03, -7.3685e-01, -9.3514e-01, 1.0465e-01, -4.6464e-01, 134 | 1.6676e+00, 1.3931e+00, 135 | 6.5398e-01, -2.2449e-01, 1.2831e+00, -9.1787e-01, -3.3916e-01, 136 | -1.8058e+00, 6.0518e-01, -5.6252e-01, -7.8933e-01, 1.2767e+00, 137 | -1.0143e+00, 4.1611e-01, -7.5348e-01, 1.7128e+00, -8.7554e-01, 138 | 3.9714e-01, 8.4326e-01, 3.7988e-01, -1.1670e+00, 5.5228e-01, 139 | -1.0279e+00, -3.9554e-01, -7.1410e-01, -8.7456e-02, -3.3361e-01, 140 | -1.8798e-01, -1.2647e+00, 141 | 2.0021e+00, -2.3470e-01, -1.3765e+00, 9.3426e-01, 1.0880e+00, 142 | 1.9179e-01, 3.0114e-01, 8.9896e-01, -8.4454e-01, 2.3267e-01, 143 | -3.9205e-01, -2.5081e-01, 8.7124e-02, 1.3769e+00, -8.3358e-01, 144 | -8.9400e-01, 1.1744e+00, -6.0779e-01, -1.1493e-01, -7.8077e-01, 145 | 1.9660e+00, 6.1175e-01, 3.6039e-01, -1.0274e+00, 1.1495e+00, 146 | 4.5111e-01, 6.4420e-01, 147 | 2.1635e-01, -7.8731e-01, -3.3005e-01, 3.2877e-01, -1.6332e+00, 148 | 1.0807e+00, 3.3638e-01, 1.1536e-01, 3.2834e-01, 5.3447e-02, 149 | 1.4224e+00, -8.3957e-01, -2.4956e-01, -8.9778e-01, -8.6583e-01, 150 | -1.0786e+00, -1.8384e-01, 7.1622e-01, 1.8175e-01, 1.1053e+00, 151 | 1.7003e+00, -1.6965e-01, 1.6293e-01, 1.3413e+00, -2.6301e-01, 152 | -7.5521e-01, 8.1911e-01, 153 | 7.4140e-01, -5.8787e-01, -4.6505e-01, 5.3112e-02, 2.2190e+00, 154 | -3.5158e-01, 3.6381e-01, 2.5769e+00, 1.4544e+00, -6.1003e-01, 155 | -5.9961e-01, -5.8392e-01, -1.8104e-02, -9.5177e-01, -9.6400e-01, 156 | -2.8183e-01, 1.0597e+00, -7.2370e-01, 1.4755e-01, -3.2667e-01, 157 | 2.4958e+00, 1.1088e+00, -8.5476e-01, 1.8443e+00, -1.3881e-01, 158 | 1.3096e+00, -2.5802e-01, 159 | 1.0669e+00, 2.1363e-01, -7.6603e-01, -1.6977e+00, -1.5023e-01, 160 | -5.2150e-01, -6.3730e-01, 2.6214e-01, 7.6539e-03, 1.3067e+00, 161 | -6.3482e-01, -1.1042e-04, -6.6158e-01, 1.4723e-01, -6.6036e-02, 162 | 5.2851e-01, 5.7950e-01, 2.1438e-01, 9.2200e-01, 5.2919e-01, 163 | 7.7070e-01, 4.2899e-01, 3.4330e-01, 2.0698e+00, 1.3405e+00, 164 | -2.1746e-01, 8.6273e-01] 165 | } -------------------------------------------------------------------------------- /src/central/view.rs: -------------------------------------------------------------------------------- 1 | use crate::central::get_equation; 2 | use crate::central::indexable::Indexable; 3 | use crate::central::operation::Operation; 4 | use crate::central::shape::Shape; 5 | use crate::central::tensor::Tensor; 6 | use crate::central::BackpropagationPacket; 7 | use ndarray::prelude::*; 8 | use ndarray::ArrayD; 9 | 10 | use super::tensor::NAME_LENGTH; 11 | 12 | /// Backward pass for the view operation 13 | /// This function will take in a `BackpropagationPacket` and then set the gradients of the source tensor. 14 | /// # Arguments 15 | /// * `backprop_packet` - A `BackpropagationPacket` that contains the information needed to perform the backward pass. 16 | /// 17 | /// # Panics 18 | /// This function will panic if the operation in the `BackpropagationPacket` is not a view operation. 19 | /// This function will panic if the number of dimensions of the source tensor is not 1 or 2. 
20 | /// This function will panic if the number of dimensions of the source tensor is greater than 2. 21 | /// This function will panic if the number of dimensions of the view tensor is not 1 or 2. 22 | pub fn backward(backprop_packet: BackpropagationPacket) { 23 | if let Operation::View(source_tensor, origin_index) = backprop_packet.operation { 24 | // All this should do is take the grad and put it in the right place 25 | let source_grad = backprop_packet.equation.get_tensor_grad(source_tensor); 26 | // View grad is not the same shape as source grad 27 | // so we want to allocate a zero tensor of the same shape as source grad 28 | // and then copy the view grad into the right place 29 | let source_shape = backprop_packet 30 | .equation 31 | .internal_tensor_store 32 | .get(&source_tensor) 33 | .unwrap() 34 | .shape; 35 | let mut new_view_grad = ArrayD::zeros(source_shape.as_ndarray_shape()); 36 | // Indexable::Single and Indexable::Double are simple versions that let us use just usize to index 37 | // Indexable::FromTensor is a more complex version that lets us use a tensor to index 38 | match origin_index { 39 | Indexable::Single(i) => { 40 | let number_of_dimensions = new_view_grad.shape().len(); 41 | if number_of_dimensions == 1 { 42 | new_view_grad[[i]] = backprop_packet.grad[0]; 43 | } else if number_of_dimensions == 2 { 44 | new_view_grad 45 | .slice_mut(s![i, ..]) 46 | .assign(&backprop_packet.grad.clone()); 47 | } else { 48 | panic!("Not implemented"); 49 | } 50 | } 51 | Indexable::Double(i, j) => { 52 | new_view_grad[[i, j]] = backprop_packet.grad[0]; 53 | } 54 | Indexable::Triple(i, j, k) => { 55 | new_view_grad[[i, j, k]] = backprop_packet.grad[0]; 56 | } 57 | Indexable::FromTensor(tensor) => { 58 | // Get the indices from the tensor 59 | let indices = backprop_packet.equation.get_tensor_data(tensor); 60 | 61 | // This is a hack to get the shape of the source tensor 62 | // This is because the source tensor is not stored in the backprop packet 63 | // and so we need to get the shape of the source tensor from the equation 64 | let this_shape = source_grad.shape(); 65 | let other_shape = indices.shape(); 66 | let mut new_shape_dims = Vec::new(); 67 | for i in 0..other_shape.len() { 68 | new_shape_dims.push(other_shape[i]); 69 | } 70 | // HACK: this is to get this to work for tha 2 indexing 2 case 71 | new_shape_dims.push(this_shape[source_grad.ndim() - 1]); 72 | let new_shape = Shape::new(new_shape_dims); 73 | // REcreate the shape of the output 74 | 75 | // This is just running the operation in reverse 76 | for i in 0..new_shape.indices[0] { 77 | for j in 0..new_shape.indices[1] { 78 | for k in 0..new_shape.indices[2] { 79 | new_view_grad[[indices[[i, j]] as usize, k]] = 80 | backprop_packet.grad[[i, j, k]]; 81 | } 82 | } 83 | } 84 | } 85 | } 86 | 87 | if backprop_packet.equation.advanced_logging { 88 | println!("New View Grad: {:?}", new_view_grad); 89 | println!("Source Grad: {:?}", source_grad); 90 | } 91 | backprop_packet 92 | .equation 93 | .set_tensor_grad(source_tensor, source_grad + new_view_grad); 94 | } else { 95 | panic!("Invalid operation type for backward pass"); 96 | } 97 | } 98 | 99 | /// View operation for the tensor struct 100 | /// This function will take in an indexable and then return a new tensor that is a view of the original tensor 101 | /// # Arguments 102 | /// * `index` - An indexable that will be used to create the view tensor. 103 | /// # Returns 104 | /// A new tensor that is a view of the original tensor. 
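/// # Examples
///
/// A minimal sketch, not part of the original documentation, using the `Tensor` and
/// `Indexable` APIs defined in this crate:
///
/// ```ignore
/// let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3].into());
/// // Row 1 of a [2, 3] tensor comes back as a [3]-shaped view; during backward its
/// // gradient is scattered back into that row of `t`.
/// let row = t.view(Indexable::Single(1));
/// assert_eq!(row.item().into_raw_vec(), vec![4.0, 5.0, 6.0]);
/// ```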
105 | impl Tensor { 106 | pub fn view(&self, index: Indexable) -> Tensor { 107 | // Allocate a new tensor the size of the view 108 | // and then set the data of the new tensor to the data of the old tensor 109 | let mut singleton = get_equation(); 110 | let data: Vec = singleton.get_item(self.tensor_id); 111 | // I now need to get the index subset of data from the old tensor 112 | let new_shape = self.shape.subshape_from_indexable(index); 113 | 114 | match index { 115 | Indexable::Single(i) => { 116 | // If the number of indices is 1, then we can just take the data from the old tensor 117 | // and then allocate a new tensor with the data from the old tensor 118 | if self.shape.number_of_indices == 1 { 119 | let data = data[i]; 120 | let tensor_id = singleton.allocate_element_tensor( 121 | new_shape, 122 | data, 123 | Operation::View(self.tensor_id, index), 124 | ); 125 | return Tensor { 126 | tensor_id, 127 | shape: new_shape, 128 | operation: Operation::View(self.tensor_id, index), 129 | name: ['a'; NAME_LENGTH], 130 | }; 131 | // If the number of indices is 2, then we need to take a slice of the data 132 | // and then allocate a new tensor with the data from the old tensor 133 | } else if self.shape.number_of_indices == 2 { 134 | let offset = i * self.shape.indices[1]; 135 | let data = data[offset..offset + self.shape.indices[1]].to_vec(); 136 | let tensor_id = singleton.allocate_tensor_from_operation( 137 | new_shape, 138 | data, 139 | Operation::View(self.tensor_id, index), 140 | ); 141 | return Tensor { 142 | tensor_id, 143 | shape: new_shape, 144 | operation: Operation::View(self.tensor_id, index), 145 | name: ['a'; NAME_LENGTH], 146 | }; 147 | } else { 148 | panic!("Indexing not supported for tensors with more than 2 dimensions"); 149 | } 150 | } 151 | Indexable::Double(a, b) => { 152 | // If the number of indices is 1, then we can just take the data from the old tensor 153 | // and then allocate a new tensor with the data from the old tensor 154 | let offset = a * self.shape.indices[1] + b; 155 | let data = data[offset]; 156 | let tensor_id = singleton.allocate_element_tensor( 157 | new_shape, 158 | data, 159 | Operation::View(self.tensor_id, index), 160 | ); 161 | return Tensor { 162 | tensor_id, 163 | shape: new_shape, 164 | operation: Operation::View(self.tensor_id, index), 165 | name: ['a'; NAME_LENGTH], 166 | }; 167 | } 168 | Indexable::Triple(a, b, c) => { 169 | // If the number of indices is 1, then we can just take the data from the old tensor 170 | // and then allocate a new tensor with the data from the old tensor 171 | let offset = a * self.shape.indices[1] * self.shape.indices[2] 172 | + b * self.shape.indices[2] 173 | + c; 174 | let data = data[offset]; 175 | let tensor_id = singleton.allocate_element_tensor( 176 | new_shape, 177 | data, 178 | Operation::View(self.tensor_id, index), 179 | ); 180 | return Tensor { 181 | tensor_id, 182 | shape: new_shape, 183 | operation: Operation::View(self.tensor_id, index), 184 | name: ['a'; NAME_LENGTH], 185 | }; 186 | } 187 | Indexable::FromTensor(a) => { 188 | 189 | let indices = singleton.get_tensor_data(a); 190 | 191 | let this_shape = self.shape.clone().indices; 192 | let other_shape = indices.shape(); 193 | let mut new_shape_dims = Vec::new(); 194 | for i in 0..other_shape.len() { 195 | new_shape_dims.push(other_shape[i]); 196 | } 197 | // HACK: this is to get this to work for tha 2 indexing 2 case 198 | new_shape_dims.push(this_shape[self.shape.number_of_indices - 1]); 199 | let new_shape = Shape::new(new_shape_dims); 200 | 
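                // Added note: for example, indexing a [27, 10] embedding table with a
                // [32, 3] index tensor produces a [32, 3, 10] view here, the same gather
                // pattern exercised by the batch-norm tests in lib.rs.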
201 | assert!(indices.ndim() <= self.shape.number_of_indices); 202 | 203 | let data = singleton.get_item(self.tensor_id).clone(); 204 | let data = data.as_slice(); 205 | let data_as_array = 206 | ArrayD::from_shape_vec(self.shape.as_ndarray_shape(), data.to_vec()).unwrap(); 207 | let mut return_tensor = ArrayD::::zeros(new_shape.as_ndarray_shape()); 208 | 209 | let return_shape = return_tensor.shape().to_vec(); 210 | 211 | if return_shape.len() == 2 { 212 | for i in 0..return_shape[0] { 213 | for j in 0..return_shape[1] { 214 | return_tensor[[i, j]] = data_as_array[[indices[[i]] as usize, j]]; 215 | } 216 | } 217 | } 218 | else if return_shape.len() == 3 { 219 | for i in 0..return_shape[0] { 220 | for j in 0..return_shape[1] { 221 | for k in 0..return_shape[2] { 222 | return_tensor[[i, j, k]] = data_as_array[[indices[[i, j]] as usize, k]]; 223 | } 224 | } 225 | } 226 | } 227 | 228 | let tensor_id = singleton.allocate_tensor_from_operation( 229 | new_shape.clone().into(), 230 | return_tensor.into_raw_vec(), 231 | Operation::View(self.tensor_id, index), 232 | ); 233 | return Tensor { 234 | tensor_id, 235 | shape: new_shape, 236 | operation: Operation::View(self.tensor_id, index), 237 | name: ['a'; NAME_LENGTH], 238 | }; 239 | } 240 | } 241 | } 242 | } 243 | -------------------------------------------------------------------------------- /src/central/matmul_op.rs: -------------------------------------------------------------------------------- 1 | use crate::central::get_equation; 2 | use crate::central::operation::Operation; 3 | use crate::central::tensor::Tensor; 4 | use crate::central::BackpropagationPacket; 5 | use ndarray::prelude::*; 6 | use std::ops::Shl; 7 | 8 | use super::tensor::name_from_string; 9 | 10 | pub fn backward(backprop_packet: BackpropagationPacket) { 11 | if let Operation::MatMul(a, b) = backprop_packet.operation { 12 | // Handle the case when the gradient is a 2D matrix 13 | if backprop_packet.grad.ndim() == 1 { 14 | // Convert the gradient to a 2D matrix 15 | let out_grad = backprop_packet.grad.clone(); 16 | // Get the data of the right-hand operand of the MatMul operation 17 | let right_hand_data = backprop_packet.equation.get_tensor_data(b); 18 | let right_hand_data_tranpose = right_hand_data.t(); 19 | 20 | // Transpose the right-hand data 21 | 22 | // Get the gradient of the left-hand operand of the MatMul operation 23 | let left_hand_grad = backprop_packet.equation.get_tensor_grad(a); 24 | // Update the gradient of the left-hand operand 25 | // Gradient of A in A*B with respect to the loss is given by (dL/dZ) * B^T 26 | let hold = left_hand_grad 27 | + backprop_packet 28 | .equation 29 | .matmul(&out_grad, &right_hand_data_tranpose.to_owned()); 30 | //let hold = left_hand_grad + out_grad.clone().dot(&other).into_dyn(); 31 | backprop_packet.equation.set_tensor_grad(a, hold); 32 | 33 | // Get the data of the left-hand operand of the MatMul operation 34 | let left_hand_data = backprop_packet.equation.get_tensor_data(a); 35 | 36 | // Get the gradient of the right-hand operand of the MatMul operation 37 | let right_hand_grad = backprop_packet.equation.get_tensor_grad(b); 38 | // Transpose the left-hand data 39 | let other = left_hand_data.t(); 40 | // Update the gradient of the right-hand operand 41 | // Gradient of B in A*B with respect to the loss is given by A^T * (dL/dZ) 42 | 43 | let other_len = other.len(); 44 | let out_grad_len = out_grad.len(); 45 | let other_reshape = other.into_shape((other_len, 1)).unwrap().to_owned().into_dyn(); 46 | let out_grad_reshape = 
out_grad.into_shape((1, out_grad_len)).unwrap().to_owned().into_dyn(); 47 | 48 | let temp = right_hand_grad 49 | + backprop_packet 50 | .equation 51 | .matmul(&other_reshape, &out_grad_reshape); 52 | 53 | backprop_packet.equation.set_tensor_grad(b, temp); 54 | } 55 | else if backprop_packet.grad.ndim() == 2 { 56 | // Convert the gradient to a 2D matrix 57 | let out_grad = backprop_packet.grad.clone(); 58 | // Get the data of the right-hand operand of the MatMul operation 59 | let right_hand_data = backprop_packet.equation.get_tensor_data(b); 60 | let right_hand_data_tranpose = right_hand_data.t(); 61 | 62 | // Transpose the right-hand data 63 | 64 | // Get the gradient of the left-hand operand of the MatMul operation 65 | let left_hand_grad = backprop_packet.equation.get_tensor_grad(a); 66 | // Update the gradient of the left-hand operand 67 | // Gradient of A in A*B with respect to the loss is given by (dL/dZ) * B^T 68 | let hold = left_hand_grad 69 | + backprop_packet 70 | .equation 71 | .matmul(&out_grad, &right_hand_data_tranpose.to_owned()); 72 | //let hold = left_hand_grad + out_grad.clone().dot(&other).into_dyn(); 73 | backprop_packet.equation.set_tensor_grad(a, hold); 74 | 75 | // Get the data of the left-hand operand of the MatMul operation 76 | let left_hand_data = backprop_packet.equation.get_tensor_data(a); 77 | // Get the gradient of the right-hand operand of the MatMul operation 78 | let right_hand_grad = backprop_packet.equation.get_tensor_grad(b); 79 | // Transpose the left-hand data 80 | let other = left_hand_data.t(); 81 | // Update the gradient of the right-hand operand 82 | // Gradient of B in A*B with respect to the loss is given by A^T * (dL/dZ) 83 | let temp = right_hand_grad 84 | + backprop_packet 85 | .equation 86 | .matmul(&other.to_owned(), &out_grad); 87 | // let temp = right_hand_grad + other.dot(&out_grad).into_dyn(); 88 | backprop_packet.equation.set_tensor_grad(b, temp); 89 | } else if backprop_packet.grad.ndim() == 3 { 90 | // Handle the case when the gradient is a 3D tensor 91 | // Convert the gradient to a 3D tensor 92 | let out_grad = backprop_packet 93 | .grad 94 | .clone() 95 | .into_dimensionality::() 96 | .unwrap(); 97 | // Get the data of the right-hand operand of the MatMul operation 98 | let right_hand_data = backprop_packet.equation.get_tensor_data(b); 99 | let right_hand_data_tranpose = right_hand_data.t(); 100 | 101 | // Transpose the right-hand data 102 | let other = right_hand_data_tranpose 103 | .into_dimensionality::() 104 | .unwrap(); 105 | 106 | // Get the gradient of the left-hand operand of the MatMul operation 107 | let left_hand_grad = backprop_packet.equation.get_tensor_grad(a); 108 | let mut result = Array3::zeros(( 109 | left_hand_grad.shape()[0], 110 | left_hand_grad.shape()[1], 111 | left_hand_grad.shape()[2], 112 | )); 113 | // Update the gradient of the left-hand operand for each slice 114 | for i in 0..out_grad.shape()[0] { 115 | let hold = out_grad 116 | .slice(s![i, .., ..]) 117 | .dot(&other.slice(s![.., ..])) 118 | .into_dyn(); 119 | result.slice_mut(s![i, .., ..]).assign(&hold); 120 | } 121 | 122 | backprop_packet 123 | .equation 124 | .set_tensor_grad(a, result.into_dyn() + left_hand_grad); 125 | 126 | // Get the data of the left-hand operand of the MatMul operation 127 | let left_hand_data = backprop_packet.equation.get_tensor_data(a); 128 | // Get the gradient of the right-hand operand of the MatMul operation 129 | let right_hand_grad = backprop_packet.equation.get_tensor_grad(b); 130 | // Convert the left-hand data to 
a 3D tensor 131 | let other = left_hand_data.into_dimensionality::().unwrap(); 132 | let mut result = 133 | Array2::zeros((right_hand_grad.shape()[0], right_hand_grad.shape()[1])).into_dyn(); 134 | 135 | // Update the gradient of the right-hand operand for each slice 136 | for i in 0..out_grad.shape()[0] { 137 | let hold = other 138 | .slice(s![i, .., ..]) 139 | .t() 140 | .dot(&out_grad.slice(s![i, .., ..])) 141 | .into_dyn(); 142 | result = result + hold; 143 | } 144 | backprop_packet 145 | .equation 146 | .set_tensor_grad(b, result.into_dyn() + right_hand_grad); 147 | } 148 | else if backprop_packet.grad.ndim() == 4 { 149 | // Convert the gradient to a 2D matrix 150 | let out_grad = backprop_packet.grad.clone(); 151 | // Get the data of the right-hand operand of the MatMul operation 152 | let mut right_hand_data = backprop_packet.equation.get_tensor_data(b); 153 | right_hand_data.swap_axes(2, 3); 154 | 155 | // Transpose the right-hand data 156 | 157 | // Get the gradient of the left-hand operand of the MatMul operation 158 | let left_hand_grad = backprop_packet.equation.get_tensor_grad(a); 159 | // Update the gradient of the left-hand operand 160 | // Gradient of A in A*B with respect to the loss is given by (dL/dZ) * B^T 161 | let hold = left_hand_grad 162 | + backprop_packet 163 | .equation 164 | .matmul(&out_grad, &right_hand_data.to_owned()); 165 | //let hold = left_hand_grad + out_grad.clone().dot(&other).into_dyn(); 166 | backprop_packet.equation.set_tensor_grad(a, hold); 167 | 168 | // Get the data of the left-hand operand of the MatMul operation 169 | let mut left_hand_data = backprop_packet.equation.get_tensor_data(a); 170 | 171 | // Get the gradient of the right-hand operand of the MatMul operation 172 | let right_hand_grad = backprop_packet.equation.get_tensor_grad(b); 173 | // Transpose the left-hand data 174 | left_hand_data.swap_axes(2, 3); 175 | // Update the gradient of the right-hand operand 176 | // Gradient of B in A*B with respect to the loss is given by A^T * (dL/dZ) 177 | 178 | 179 | let temp = right_hand_grad 180 | + backprop_packet 181 | .equation 182 | .matmul(&left_hand_data.to_owned(), &out_grad.to_owned()); 183 | 184 | backprop_packet.equation.set_tensor_grad(b, temp); 185 | } 186 | else { 187 | 188 | panic!("Not implementerd for dim {}", backprop_packet.grad.ndim()); 189 | } 190 | } else { 191 | panic!("Invalid operation type for backward pass"); 192 | } 193 | } 194 | 195 | /* 196 | pub fn backward(backprop_packet: BackpropagationPacket) { 197 | if let Operation::MatMul(a, b) = backprop_packet.operation { 198 | 199 | 200 | if backprop_packet.grad.ndim() == 2 { 201 | let out_grad = backprop_packet.grad.clone().into_dimensionality::().unwrap(); 202 | let right_hand_data = backprop_packet.equation.get_tensor_data(b); 203 | let right_hand_data_tranpose = right_hand_data.t(); 204 | 205 | let other = right_hand_data_tranpose 206 | .into_dimensionality::() 207 | .unwrap(); 208 | let left_hand_grad = backprop_packet.equation.get_tensor_grad(a); 209 | let hold = left_hand_grad + out_grad.clone().dot(&other).into_dyn(); 210 | backprop_packet.equation.set_tensor_grad(a, hold); 211 | 212 | let left_hand_data = backprop_packet.equation.get_tensor_data(a); 213 | let right_hand_grad = backprop_packet.equation.get_tensor_grad(b); 214 | let other = left_hand_data.t().into_dimensionality::().unwrap(); 215 | let temp = right_hand_grad + other.dot(&out_grad).into_dyn(); 216 | backprop_packet.equation.set_tensor_grad(b, temp); 217 | } 218 | else if backprop_packet.grad.ndim() 
== 3 { 219 | let out_grad = backprop_packet.grad.clone().into_dimensionality::().unwrap(); 220 | let right_hand_data = backprop_packet.equation.get_tensor_data(b); 221 | let right_hand_data_tranpose = right_hand_data.t(); 222 | 223 | 224 | let other = right_hand_data_tranpose 225 | .into_dimensionality::() 226 | .unwrap(); 227 | 228 | let left_hand_grad = backprop_packet.equation.get_tensor_grad(a); 229 | let mut result = Array3::zeros((left_hand_grad.shape()[0], left_hand_grad.shape()[1], left_hand_grad.shape()[2])); 230 | for i in 0..out_grad.shape()[0] { 231 | let hold = out_grad.slice(s![i, .., ..]).dot(&other.slice(s![.., ..])).into_dyn(); 232 | result.slice_mut(s![i, .., ..]).assign(&hold); 233 | } 234 | 235 | 236 | 237 | backprop_packet.equation.set_tensor_grad(a, result.into_dyn() + left_hand_grad); 238 | 239 | let left_hand_data = backprop_packet.equation.get_tensor_data(a); 240 | let right_hand_grad = backprop_packet.equation.get_tensor_grad(b); 241 | let other = left_hand_data.into_dimensionality::().unwrap(); 242 | let mut result = Array2::zeros((right_hand_grad.shape()[0], right_hand_grad.shape()[1])).into_dyn(); 243 | 244 | for i in 0..out_grad.shape()[0] { 245 | let hold = other.slice(s![i,..,..]).t().dot(&out_grad.slice(s![i, .., ..])).into_dyn(); 246 | result = result + hold; 247 | } 248 | backprop_packet.equation.set_tensor_grad(b, result.into_dyn() + right_hand_grad); 249 | } 250 | else { 251 | panic!("Not implemented"); 252 | } 253 | 254 | } 255 | else { 256 | panic!("Invalid operation type for backward pass"); 257 | } 258 | } 259 | */ 260 | // SIN: reusing the Shl opeartor to do the matmul operations 261 | impl Shl for Tensor { 262 | type Output = Tensor; 263 | fn shl(self, rhs: Self) -> Self::Output { 264 | let mut singleton = get_equation(); 265 | let a_data = singleton.get_tensor_data(self.tensor_id); 266 | let b_data = singleton.get_tensor_data(rhs.tensor_id); 267 | let result_data = singleton.matmul(&a_data, &b_data); 268 | 269 | let resultant_shape = self.shape.matmul_shape(&rhs.shape); 270 | let tensor_id = singleton.allocate_tensor_from_operation( 271 | resultant_shape, 272 | result_data.into_raw_vec(), 273 | Operation::MatMul(self.tensor_id, rhs.tensor_id), 274 | ); 275 | let matmul_shape = self.shape.matmul_shape(&rhs.shape); 276 | 277 | Tensor { 278 | tensor_id, 279 | shape: matmul_shape, 280 | operation: Operation::MatMul(self.tensor_id, rhs.tensor_id), 281 | name: name_from_string("MatMul"), 282 | } 283 | } 284 | } 285 | 286 | mod tests { 287 | #[allow(unused_imports)] 288 | use crate::central::shape::Shape; 289 | #[allow(unused_imports)] 290 | use crate::central::tensor::Tensor; 291 | #[test] 292 | fn two_dimension_matmul_test() { 293 | let a = Tensor::ones(Shape::new(vec![2, 2])); 294 | let b = Tensor::ones(Shape::new(vec![2, 2])); 295 | let c = a << b; 296 | let result = c.item(); 297 | println!("{:?}", result); 298 | } 299 | 300 | #[test] 301 | fn three_dimension_matmul_test() { 302 | let a = Tensor::randn(Shape::new(vec![3, 2, 2])); 303 | let b = Tensor::randn(Shape::new(vec![2, 2])); 304 | let c = a << b; 305 | let result = c.item(); 306 | println!("{:?}", result); 307 | } 308 | } 309 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![feature(unboxed_closures)] 2 | #![feature(fn_traits)] 3 | 4 | pub mod central; 5 | pub mod nn; 6 | pub mod optimizers; 7 | pub use optimizers::*; 8 | pub use central::*; 9 | pub use 
ndarray::prelude::*; 10 | 11 | #[cfg(test)] 12 | mod tests { 13 | 14 | use crate::nn::*; 15 | use crate::*; 16 | use rand::Rng; 17 | fn approx_equal(a: f32, b: f32, epsilon: f32) -> bool { 18 | (a - b).abs() < epsilon 19 | } 20 | 21 | #[test] 22 | fn basic_pow_test() { 23 | let a = Tensor::element(Shape::new(vec![1, 1]), 2.0); 24 | let b = a.pow(2.0); 25 | let result = b.item(); 26 | assert!(result == arr2(&[[4.0]]).into_dyn()); 27 | } 28 | 29 | #[test] 30 | fn backward_pow_test() { 31 | let a = Tensor::element(Shape::new(vec![1, 1]), 2.0); 32 | let b = a.pow(2.0); 33 | b.backward(); 34 | let result = b.grad(); 35 | assert!(result == arr2(&[[1.0]]).into_dyn()); 36 | let result = a.grad(); 37 | assert!(result == arr2(&[[4.0]]).into_dyn()); 38 | } 39 | 40 | 41 | #[test] 42 | fn basic_exp_test() { 43 | let a = Tensor::element(Shape::new(vec![1, 1]), 2.0); 44 | let b = a.exp(); 45 | let result = b.item(); 46 | assert!(result == arr2(&[[2.0f32.exp()]]).into_dyn()); 47 | } 48 | 49 | #[test] 50 | fn backward_exp_test() { 51 | let a = Tensor::element(Shape::new(vec![1, 1]), 2.0); 52 | let b = a.exp(); 53 | b.backward(); 54 | let result = b.grad(); 55 | assert!(result == arr2(&[[1.0]]).into_dyn()); 56 | let result = a.grad(); 57 | assert!(result == arr2(&[[2.0f32.exp()]]).into_dyn()); 58 | } 59 | 60 | #[test] 61 | fn transpose_test() { 62 | let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3].into()); 63 | let a_array = ArrayD::from_shape_vec(vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) 64 | .unwrap(); 65 | let a_array_t = a_array.t(); 66 | println!("{:?}", a_array_t); 67 | println!("{:?}", a.item()); 68 | let b = a.transpose(); 69 | let result = b.item(); 70 | println!("{:?}", result); 71 | assert!(result == arr2(&[[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]).into_dyn()); 72 | } 73 | 74 | #[test] 75 | fn micrograd_copy_test() { 76 | let x1 = Tensor::element(Shape::new(vec![1]), 2.0); 77 | let x2 = Tensor::element(Shape::new(vec![1]), 0.0); 78 | 79 | let w1 = Tensor::element(Shape::new(vec![1]), -3.0); 80 | let w2 = Tensor::element(Shape::new(vec![1]), 1.0); 81 | 82 | let b = Tensor::element(Shape::new(vec![1]), 6.8813735870195432); 83 | 84 | let x1w1 = x1 * w1; 85 | let x2w2 = x2 * w2; 86 | let x1w1x2w2 = x1w1 + x2w2; 87 | let n = x1w1x2w2 + b; 88 | let l = 2.0f32 * n; 89 | let e = l.exp(); 90 | let o_1 = e - 1.0; 91 | let o_2 = e + 1.0; 92 | let o = o_1 / o_2; 93 | o.backward(); 94 | 95 | let n_result = n.item(); 96 | assert!(approx_equal(n_result[[0]], 0.8813734, 1e-6)); 97 | let n_grad = n.grad(); 98 | assert!(approx_equal(n_grad[[0]], 0.5, 1e-6)); 99 | 100 | let b_result = b.item(); 101 | assert!(approx_equal(b_result[[0]], 6.8813735870195432, 1e-6)); 102 | let b_grad = b.grad(); 103 | assert!(approx_equal(b_grad[[0]], 0.5, 1e-6)); 104 | 105 | let x1_result = x1.item(); 106 | assert!(approx_equal(x1_result[[0]], 2.0, 1e-6)); 107 | let x1_grad = x1.grad(); 108 | assert!(approx_equal(x1_grad[[0]], -1.5, 1e-6)); 109 | 110 | let x1w1x2w2_result = x1w1x2w2.item(); 111 | assert!(approx_equal(x1w1x2w2_result[[0]], -6.0, 1e-6)); 112 | let x1w1x2w2_grad = x1w1x2w2.grad(); 113 | assert!(approx_equal(x1w1x2w2_grad[[0]], 0.5, 1e-6)); 114 | 115 | let x2w2_result = x2w2.item(); 116 | assert!(approx_equal(x2w2_result[[0]], 0.0, 1e-6)); 117 | let x2w2_grad = x2w2.grad(); 118 | assert!(approx_equal(x2w2_grad[[0]], 0.5, 1e-6)); 119 | 120 | let x1w1_result = x1w1.item(); 121 | assert!(approx_equal(x1w1_result[[0]], -6.0, 1e-6)); 122 | let x1w1_grad = x1w1.grad(); 123 | 
assert!(approx_equal(x1w1_grad[[0]], 0.5, 1e-6)); 124 | 125 | let w2_result = w2.item(); 126 | assert!(approx_equal(w2_result[[0]], 1.0, 1e-6)); 127 | let w2_grad = w2.grad(); 128 | assert!(w2_grad[[0]] == 0.0); 129 | 130 | let x2_result = x2.item(); 131 | assert!(approx_equal(x2_result[[0]], 0.0, 1e-6)); 132 | let x2_grad = x2.grad(); 133 | assert!(approx_equal(x2_grad[[0]], 0.5, 1e-6)); 134 | 135 | let w1_result = w1.item(); 136 | assert!(approx_equal(w1_result[[0]], -3.0, 1e-6)); 137 | let w1_grad = w1.grad(); 138 | assert!(approx_equal(w1_grad[[0]], 1.0, 1e-6)); 139 | } 140 | 141 | use core::f32; 142 | 143 | use std::collections::HashMap; 144 | use std::fs::read_to_string; 145 | 146 | fn read_lines(filename: &str) -> Vec { 147 | let mut result = Vec::new(); 148 | 149 | for line in read_to_string(filename).unwrap().lines() { 150 | result.push(line.to_string()) 151 | } 152 | 153 | result 154 | } 155 | 156 | #[test] 157 | fn something_up_with_cuda() { 158 | let a = Tensor::randn(Shape::new(vec![32, 27])); 159 | let b = Tensor::randn(Shape::new(vec![27, 200])); 160 | let c = a << b; 161 | 162 | let a_as_nd = a.item().into_dimensionality::().unwrap(); 163 | let b_as_nd = b.item().into_dimensionality::().unwrap(); 164 | let c_dot = a_as_nd.dot(&b_as_nd); 165 | 166 | println!("{:?}", c.item()); 167 | println!("{:?}", c_dot); 168 | } 169 | 170 | #[test] 171 | fn array_indexing_testbed() { 172 | let mut a = ArrayD::from_elem(vec![27, 10], 0.0); 173 | let mut b = ArrayD::from_elem(vec![32, 3], 0.0); 174 | let mut c = ArrayD::from_elem(vec![32, 3, 10], 0.0); 175 | // Ok now I want to index A with b 176 | // roughly in the same way I would do in numpy 177 | // Fill A with some random values 178 | 179 | for i in 0..27 { 180 | for j in 0..10 { 181 | a[[i, j]] = i as f32 + j as f32; 182 | } 183 | } 184 | 185 | // Fill B with some random values, between 0 and 27 186 | let mut rng = rand::thread_rng(); 187 | for i in 0..32 { 188 | for j in 0..3 { 189 | b[[i, j]] = rng.gen_range(0..27) as f32; 190 | } 191 | } 192 | 193 | // 194 | for i in 0..32 { 195 | for j in 0..3 { 196 | for k in 0..10 { 197 | let test = [i, j, k]; 198 | c[test] = a[[b[[i, j]] as usize, k]]; 199 | } 200 | } 201 | } 202 | 203 | println!("{:?}", c); 204 | } 205 | 206 | #[test] 207 | fn linear_module() { 208 | let linear_layer_config = LinearLayerConfig::new(3, 3); 209 | let mut linear = LinearLayer::new(linear_layer_config); 210 | 211 | let inputs = vec![ 212 | vec![2.0f32, 3.0, -1.0], 213 | vec![3.0, -1.0, 0.5], 214 | vec![0.5, 1.0, 1.0], 215 | vec![1.0, 1.0, -1.0], 216 | ]; 217 | 218 | let inputs_as_tensor = Tensor::from_vec( 219 | inputs.iter().flatten().map(|x| *x).collect(), 220 | vec![4, 3].into(), 221 | ); 222 | 223 | let outputs = vec![1.0f32, -1.0, -1.0, 1.0]; 224 | let outputs_as_tensor = 225 | Tensor::from_vec(outputs.iter().map(|x| *x).collect(), vec![4, 1].into()); 226 | 227 | for _ in 0..50 { 228 | zero_all_grads(); 229 | let prediction = linear.forward(&inputs_as_tensor); 230 | let loss = (prediction - outputs_as_tensor).pow(2.0); 231 | loss.backward(); 232 | update_parameters(-0.01); 233 | } 234 | } 235 | 236 | fn build_batch_norm_dataset_from_subset( 237 | words: &[String], 238 | stoi: &HashMap, 239 | ) -> (Vec<[usize; 3]>, Vec) { 240 | let mut xs = vec![]; 241 | let mut ys = vec![]; 242 | for word in words { 243 | let fixed = String::from("...") + word + "."; 244 | let chars: Vec = fixed.chars().collect(); 245 | for i in 0..chars.len() - 3 { 246 | let pair = (chars[i], chars[i + 1], chars[i + 2], chars[i + 3]); 247 | 
xs.push([stoi[&pair.0], stoi[&pair.1], stoi[&pair.2]]); 248 | ys.push(stoi[&pair.3]); 249 | } 250 | } 251 | (xs, ys) 252 | } 253 | 254 | fn build_dataset_from_subset( 255 | words: &[String], 256 | stoi: &HashMap, 257 | ) -> (Vec<[usize; 3]>, Vec) { 258 | let mut xs = vec![]; 259 | let mut ys = vec![]; 260 | for word in words { 261 | let fixed = String::from("...") + word + "."; 262 | let chars: Vec = fixed.chars().collect(); 263 | for i in 0..chars.len() - 3 { 264 | let pair = (chars[i], chars[i + 1], chars[i + 2], chars[i + 3]); 265 | xs.push([stoi[&pair.0], stoi[&pair.1], stoi[&pair.2]]); 266 | ys.push(stoi[&pair.3]); 267 | } 268 | } 269 | (xs, ys) 270 | } 271 | 272 | #[test] 273 | fn batch_norm_simple_test() { 274 | let n_hidden = 200; 275 | 276 | const BATCH_SIZE: usize = 32; 277 | let names = read_lines("./data/bigram/names.txt"); 278 | 279 | let mut stoi = HashMap::new(); 280 | let mut itos = HashMap::new(); 281 | let mut i = 0; 282 | for c in ".abcdefghijklmnopqrstuvwxyz".chars() { 283 | stoi.insert(c, i); 284 | itos.insert(i, c); 285 | i += 1; 286 | } 287 | let n1 = (names.len() as f32 * 0.8f32) as usize; 288 | let n2 = (names.len() as f32 * 0.9f32) as usize; 289 | let (xtr, ytr) = build_batch_norm_dataset_from_subset(&names[..n1], &stoi); 290 | let (_xdev, _ydev) = build_batch_norm_dataset_from_subset(&names[n1..n2], &stoi); 291 | let (_cte, _yte) = build_batch_norm_dataset_from_subset(&names[n2..], &stoi); 292 | 293 | let mut c = Tensor::load_from_weight_file("./data/batchnorm/C.json"); 294 | c.set_requires_grad(true); 295 | let mut w1 = Tensor::load_from_weight_file("./data/batchnorm/W1.json"); 296 | w1.set_requires_grad(true); 297 | let mut w2 = Tensor::load_from_weight_file("./data/batchnorm/W2.json"); 298 | w2.set_requires_grad(true); 299 | let mut b2 = Tensor::load_from_weight_file("./data/batchnorm/b2.json"); 300 | b2.set_requires_grad(true); 301 | 302 | let mut bngain = Tensor::load_from_weight_file("./data/batchnorm/bngain.json"); 303 | bngain.set_requires_grad(true); 304 | let mut bnbiases = Tensor::load_from_weight_file("./data/batchnorm/bnbias.json"); 305 | bnbiases.set_requires_grad(true); 306 | 307 | let mut bnmean_running = Tensor::zeroes(Shape::new(vec![1, n_hidden])); 308 | bnmean_running.set_requires_grad(true); 309 | let mut bnvar_running = Tensor::ones(Shape::new(vec![1, n_hidden])); 310 | bnvar_running.set_requires_grad(true); 311 | 312 | let max_steps = 2; 313 | 314 | for _i in 0..max_steps { 315 | zero_all_grads(); 316 | let mut test_index_tensor = Tensor::zeroes(Shape::new(vec![BATCH_SIZE, 3])); 317 | for b in 0..BATCH_SIZE { 318 | test_index_tensor.set_index([b, 0].into(), vec![xtr[b][0] as f32].into()); 319 | test_index_tensor.set_index([b, 1].into(), vec![xtr[b][1] as f32].into()); 320 | test_index_tensor.set_index([b, 2].into(), vec![xtr[b][2] as f32].into()); 321 | } 322 | let test = c.view(Indexable::FromTensor(test_index_tensor.tensor_id)); 323 | let reshape = test.reshape(Shape::new(vec![BATCH_SIZE, 30])); 324 | let hpreact = reshape << w1; 325 | 326 | let bnmeani = hpreact.mean(vec![0]); 327 | let bnvari = hpreact.std(vec![0]); 328 | let offset = hpreact - bnmeani; 329 | let numer = offset * bngain; 330 | let hpreact = numer / bnvari + bnbiases; 331 | 332 | let h = hpreact.tanh(); 333 | let logits = (h << w2) + b2; 334 | 335 | let mut test_ytrue_onehot = Tensor::element(Shape::new(vec![BATCH_SIZE, 27]), 0.0); 336 | for b in 0..BATCH_SIZE { 337 | test_ytrue_onehot.set_index([b, ytr[b]].into(), vec![1.0].into()); 338 | } 339 | 340 | let loss = 
logits.cross_entropy_loss(test_ytrue_onehot); 341 | println!("Loss: {}", loss.item()); 342 | 343 | loss.backward(); 344 | update_parameters(-0.01); 345 | } 346 | println!("w1 grad {:?}", w1.grad()); 347 | } 348 | 349 | use crate::nn::model::{Model, Sequential}; 350 | #[test] 351 | fn batch_norm_test() { 352 | let batch_size = 32; 353 | let block_size = 3; 354 | let vocab_size = 100; 355 | let n_embd = 10; 356 | let n_hidden = 100; 357 | let names = read_lines("./data/bigram/names.txt"); 358 | 359 | let mut stoi = HashMap::new(); 360 | let mut itos = HashMap::new(); 361 | let mut i = 0; 362 | for c in ".abcdefghijklmnopqrstuvwxyz".chars() { 363 | stoi.insert(c, i); 364 | itos.insert(i, c); 365 | i += 1; 366 | } 367 | let n1 = (names.len() as f32 * 0.8f32) as usize; 368 | 369 | let (xtr, _ytr) = build_dataset_from_subset(&names[..n1], &stoi); 370 | 371 | let mut test_index_tensor = Tensor::zeroes(Shape::new(vec![batch_size, 3])); 372 | for b in 0..batch_size { 373 | test_index_tensor.set_index([b, 0].into(), vec![xtr[b][0] as f32].into()); 374 | test_index_tensor.set_index([b, 1].into(), vec![xtr[b][1] as f32].into()); 375 | test_index_tensor.set_index([b, 2].into(), vec![xtr[b][2] as f32].into()); 376 | } 377 | 378 | let c = Tensor::randn(Shape::new(vec![vocab_size, n_embd])); 379 | let linear_layer_config = LinearLayerConfig::new(n_embd * block_size, n_hidden); 380 | let linear_layer_common_config = LinearLayerConfig::new(n_hidden, n_hidden); 381 | let mut linear_model: Sequential = vec![ 382 | LinearLayer::new(linear_layer_config).into(), 383 | BatchNorm1d::new(n_hidden).into(), 384 | Tanh::new().into(), 385 | LinearLayer::new(linear_layer_common_config).into(), 386 | BatchNorm1d::new(n_hidden).into(), 387 | Tanh::new().into(), 388 | LinearLayer::new(linear_layer_common_config).into(), 389 | BatchNorm1d::new(n_hidden).into(), 390 | Tanh::new().into(), 391 | LinearLayer::new(linear_layer_common_config).into(), 392 | BatchNorm1d::new(n_hidden).into(), 393 | Tanh::new().into(), 394 | LinearLayer::new(linear_layer_common_config).into(), 395 | BatchNorm1d::new(n_hidden).into(), 396 | Tanh::new().into(), 397 | LinearLayer::new(linear_layer_common_config).into(), 398 | BatchNorm1d::new(vocab_size).into(), 399 | ] 400 | .into(); 401 | 402 | let test = c.view(Indexable::FromTensor(test_index_tensor.tensor_id)); 403 | let reshape = test.reshape(Shape::new(vec![32, 30])); 404 | 405 | let output = linear_model.forward(&reshape); 406 | output.backward(); 407 | update_parameters(-0.01); 408 | } 409 | } 410 | -------------------------------------------------------------------------------- /src/nn/layers/casual_self_attention.rs: -------------------------------------------------------------------------------- 1 | use crate::nn::{Module, LinearLayer, LinearLayerConfig}; 2 | use std::io::{BufWriter, Write}; 3 | use std::fs::OpenOptions; 4 | use crate::central::Tensor; 5 | use log::info; 6 | 7 | fn write_f32_vector_to_file(path: &str, data: &[f32]) -> std::io::Result<()> { 8 | // Create or open the file 9 | let file = OpenOptions::new() 10 | .create(true) 11 | .append(true) 12 | .open(path)?; 13 | 14 | // Create a buffered writer to improve performance 15 | let mut writer = BufWriter::new(file); 16 | 17 | // Write each f32 value to the file 18 | for &value in data { 19 | // Write the float value as a string 20 | writeln!(writer, "{}", value)?; 21 | } 22 | 23 | // Ensure all data is flushed to the file 24 | writer.flush()?; 25 | 26 | Ok(()) 27 | } 28 | 29 | fn write_string_vector_to_file(path: &str, data: 
&str) -> std::io::Result<()> { 30 | // Create or open the file 31 | let file = OpenOptions::new() 32 | .create(true) 33 | .append(true) 34 | .open(path)?; 35 | 36 | // Create a buffered writer to improve performance 37 | let mut writer = BufWriter::new(file); 38 | 39 | // Write the float value as a string 40 | writeln!(writer, "{}", data)?; 41 | 42 | // Ensure all data is flushed to the file 43 | writer.flush()?; 44 | 45 | Ok(()) 46 | } 47 | pub struct CasualSelfAttentionConfig { 48 | embedding_dim: usize, 49 | } 50 | 51 | pub struct CasualSelfAttention { 52 | pub query_attention: LinearLayer, 53 | pub key_attention: LinearLayer, 54 | pub value_attention: LinearLayer, 55 | pub c_proj: LinearLayer, 56 | pub y: Option<Tensor>, 57 | pub attn_weights: Option<Tensor>, 58 | pub filled: Option<Tensor>, 59 | pub mask: Option<Tensor>, 60 | } 61 | 62 | impl CasualSelfAttention { 63 | pub fn new(config: CasualSelfAttentionConfig) -> Self { 64 | let query_attention = LinearLayer::new(LinearLayerConfig { 65 | number_of_inputs: config.embedding_dim, 66 | number_of_weights: config.embedding_dim, 67 | }); 68 | 69 | let key_attention = LinearLayer::new(LinearLayerConfig { 70 | number_of_inputs: config.embedding_dim, 71 | number_of_weights: config.embedding_dim, 72 | }); 73 | 74 | let value_attention = LinearLayer::new(LinearLayerConfig { 75 | number_of_inputs: config.embedding_dim, 76 | number_of_weights: config.embedding_dim, 77 | }); 78 | 79 | let c_proj = LinearLayer::new(LinearLayerConfig { 80 | number_of_inputs: config.embedding_dim, 81 | number_of_weights: config.embedding_dim, 82 | }); 83 | 84 | CasualSelfAttention { 85 | query_attention, 86 | key_attention, 87 | value_attention, 88 | c_proj, 89 | y: None, 90 | attn_weights: None, 91 | filled: None, 92 | mask: None, 93 | } 94 | } 95 | 96 | pub fn from_weights_and_bias(query_weights: Tensor, query_bias: Tensor, key_weights: Tensor, key_bias: Tensor, value_weights: Tensor, value_bias: Tensor, c_proj_weights: Tensor, c_proj_bias: Tensor) -> Self { 97 | let query_attention = LinearLayer::from_weights_and_bias(query_weights, query_bias); 98 | let key_attention = LinearLayer::from_weights_and_bias(key_weights, key_bias); 99 | let value_attention = LinearLayer::from_weights_and_bias(value_weights, value_bias); 100 | let c_proj = LinearLayer::from_weights_and_bias(c_proj_weights, c_proj_bias); 101 | 102 | CasualSelfAttention { 103 | query_attention, 104 | key_attention, 105 | value_attention, 106 | c_proj, 107 | y: None, 108 | attn_weights: None, 109 | filled: None, 110 | mask: None, 111 | } 112 | } 113 | } 114 | 115 | impl Module for CasualSelfAttention { 116 | fn forward(&mut self, x: &Tensor) -> Tensor { 117 | info!("CasualSelfAttention forward"); 118 | info!("Query start"); 119 | let b = x.shape.indices[0]; 120 | let t = x.shape.indices[1]; 121 | let c = x.shape.indices[2]; 122 | println!("{:?}", x.shape); 123 | 124 | 125 | let num_heads = 12; 126 | 127 | let query = self.query_attention.forward(&x); 128 | let key = self.key_attention.forward(&x); 129 | let value = self.value_attention.forward(&x); 130 | 131 | 132 | 133 | 134 | //let _ = write_string_vector_to_file("./rust_checkfile.txt", "$Query"); 135 | //let _ = write_f32_vector_to_file("./rust_checkfile.txt", &query.item().into_raw_vec()); 136 | 137 | //let _ = write_string_vector_to_file("./rust_checkfile.txt", "$Key"); 138 | //let _ = write_f32_vector_to_file("./rust_checkfile.txt", &key.item().into_raw_vec()); 139 | 140 | //let _ = write_string_vector_to_file("./rust_checkfile.txt", "$Value"); 141 | //let _ = 
write_f32_vector_to_file("./rust_checkfile.txt", &value.item().into_raw_vec()); 142 | 143 | 144 | let query = query.reshape(vec![b, t, num_heads, c / num_heads].into()).tranpose_with_provided_axis(1, 2); 145 | // let _ = write_string_vector_to_file("./rust_checkfile.txt", "$QueryT"); 146 | // let _ = write_f32_vector_to_file("./rust_checkfile.txt", &query.item().into_raw_vec()); 147 | let key = key.reshape(vec![b, t, num_heads, c / num_heads].into()).tranpose_with_provided_axis(1, 2); 148 | // let _ = write_string_vector_to_file("./rust_checkfile.txt", "$KeyT"); 149 | // let _ = write_f32_vector_to_file("./rust_checkfile.txt", &key.item().into_raw_vec()); 150 | let value = value.reshape(vec![b, t, num_heads, c / num_heads].into()).tranpose_with_provided_axis(1, 2); 151 | // let _ = write_string_vector_to_file("./rust_checkfile.txt", "$ValueT"); 152 | // let _ = write_f32_vector_to_file("./rust_checkfile.txt", &value.item().into_raw_vec()); 153 | 154 | let key_super_tranposed = key.tranpose_with_provided_axis(2, 3); 155 | // let _ = write_string_vector_to_file("./rust_checkfile.txt", "$KeyST"); 156 | // let _ = write_f32_vector_to_file("./rust_checkfile.txt", &key_super_tranposed.item().into_raw_vec()); 157 | let query_key = query << key_super_tranposed; 158 | // let _ = write_string_vector_to_file("./rust_checkfile.txt", "$QueryKey"); 159 | // let _ = write_f32_vector_to_file("./rust_checkfile.txt", &query_key.item().into_raw_vec()); 160 | let denom = 1.0 / (key.shape.indices[key.shape.number_of_indices - 1] as f32).sqrt(); 161 | let attn_weights = query_key * denom; 162 | self.attn_weights = Some(attn_weights.clone()); 163 | 164 | // let _ = write_string_vector_to_file("./rust_checkfile.txt", "$AttnWeights"); 165 | // let _ = write_f32_vector_to_file("./rust_checkfile.txt", &attn_weights.item().into_raw_vec()); 166 | let mask = Tensor::tril(vec![t, t].into()).reshape(vec![1, 1, t, t].into()); 167 | self.mask = Some(mask.clone()); 168 | // let _ = write_string_vector_to_file("./rust_checkfile.txt", "$Premask"); 169 | // let _ = write_f32_vector_to_file("./rust_checkfile.txt", &mask.item().into_raw_vec()); 170 | 171 | let mask_broadcasted = mask.broadcast(vec![b, num_heads, t, t].into()); 172 | let filled = attn_weights.masked_fill(&mask_broadcasted, std::f32::NEG_INFINITY); 173 | self.filled = Some(filled.clone()); 174 | // let _ = write_string_vector_to_file("./rust_checkfile.txt", "$Filled"); 175 | // let _ = write_f32_vector_to_file("./rust_checkfile.txt", &filled.item().into_raw_vec()); 176 | let attn_weights = filled.softmax(attn_weights.shape.number_of_indices - 1); 177 | // let _ = write_string_vector_to_file("./rust_checkfile.txt", "$Softmax"); 178 | // let _ = write_f32_vector_to_file("./rust_checkfile.txt", &attn_weights.item().into_raw_vec()); 179 | let attn_output = attn_weights << value; 180 | self.y = Some(attn_output.clone()); 181 | // let _ = write_string_vector_to_file("./rust_checkfile.txt", "$AttnOutput"); 182 | // let _ = write_f32_vector_to_file("./rust_checkfile.txt", &attn_output.item().into_raw_vec()); 183 | let attn_output = attn_output.tranpose_with_provided_axis(1, 2).reshape(vec![b, t, c].into()); 184 | // let _ = write_string_vector_to_file("./rust_checkfile.txt", "$AttnOutputReshape"); 185 | // let _ = write_f32_vector_to_file("./rust_checkfile.txt", &attn_output.item().into_raw_vec()); 186 | let x = self.c_proj.forward(&attn_output); 187 | // let _ = write_string_vector_to_file("./rust_checkfile.txt", "$CProj"); 188 | // let _ = 
write_f32_vector_to_file("./rust_checkfile.txt", &x.item().into_raw_vec()); 189 | x 190 | } 191 | 192 | fn get_parameters(&self) -> Vec { 193 | let mut parameters = Vec::new(); 194 | parameters.extend(self.query_attention.get_parameters()); 195 | parameters.extend(self.key_attention.get_parameters()); 196 | parameters.extend(self.value_attention.get_parameters()); 197 | parameters.extend(self.c_proj.get_parameters()); 198 | parameters 199 | } 200 | } 201 | 202 | mod tests { 203 | use std::vec; 204 | use crate::nn::layers::linear::LinearLayerConfig; 205 | use crate::nn::layers::module::Module; 206 | use crate::central::Tensor; 207 | use crate::Shape; 208 | 209 | #[test] 210 | fn test_casual_self_attention() { 211 | let file_names = vec![ 212 | "./data/tests/causal_self_attention/causal_self_attention_query_weights.txt", 213 | "./data/tests/causal_self_attention/causal_self_attention_query_bias.txt", 214 | "./data/tests/causal_self_attention/causal_self_attention_key_weights.txt", 215 | "./data/tests/causal_self_attention/causal_self_attention_key_bias.txt", 216 | "./data/tests/causal_self_attention/causal_self_attention_value_weights.txt", 217 | "./data/tests/causal_self_attention/causal_self_attention_value_bias.txt", 218 | "./data/tests/causal_self_attention/causal_self_attention_c_proj_weights.txt", 219 | "./data/tests/causal_self_attention/causal_self_attention_c_proj_bias.txt", 220 | "./data/tests/causal_self_attention/fake_output.txt", 221 | "./data/tests/causal_self_attention/loss.txt", 222 | "./data/tests/causal_self_attention/fake_input.txt", 223 | "./data/tests/causal_self_attention/expected_output.txt", 224 | "./data/tests/causal_self_attention/causal_self_attention_c_proj_weights_grad.txt", 225 | "./data/tests/causal_self_attention/casual_self_attention_c_proj_bias_grad.txt", 226 | "./data/tests/causal_self_attention/causal_self_attention_query_weights_grad.txt", 227 | "./data/tests/causal_self_attention/causal_self_attention_query_bias_grad.txt", 228 | "./data/tests/causal_self_attention/causal_self_attention_key_weights_grad.txt", 229 | "./data/tests/causal_self_attention/causal_self_attention_key_bias_grad.txt", 230 | "./data/tests/causal_self_attention/causal_self_attention_value_weights_grad.txt", 231 | "./data/tests/causal_self_attention/causal_self_attention_value_bias_grad.txt" 232 | ]; 233 | 234 | let mut tenors = vec![]; 235 | 236 | for file_name in file_names.iter() { 237 | println!("Reading file {}", file_name); 238 | let mut file = std::fs::File::open(file_name).unwrap(); 239 | let t = crate::central::Tensor::from_bytestream(&mut file, false).unwrap(); 240 | tenors.push(t); 241 | } 242 | 243 | let mut casual_self_attention = super::CasualSelfAttention::from_weights_and_bias( 244 | tenors[0].clone(), 245 | tenors[1].clone(), 246 | tenors[2].clone(), 247 | tenors[3].clone(), 248 | tenors[4].clone(), 249 | tenors[5].clone(), 250 | tenors[6].clone(), 251 | tenors[7].clone(), 252 | ); 253 | 254 | let input = tenors[10].clone(); 255 | let output = casual_self_attention.forward(&input); 256 | 257 | let output_flatten = output.item().into_raw_vec(); 258 | let expected_output_flatten = tenors[11].item().into_raw_vec(); 259 | 260 | for (i, (o, e)) in output_flatten.iter().zip(expected_output_flatten.iter()).enumerate() { 261 | assert!((o - e).abs() < 1e-6, "Mismatch at index {}", i); 262 | } 263 | 264 | let mse_loss = (output - tenors[8]).pow(2.0).reshape(vec![2 * 8 * 768].into()).mean(vec![0]); 265 | 266 | let expected_loss_flatten = tenors[9].item().into_raw_vec(); 267 | 
let loss_flatten = mse_loss.item().into_raw_vec(); 268 | 269 | 270 | for (i, (l, e)) in loss_flatten.iter().zip(expected_loss_flatten.iter()).enumerate() { 271 | assert!((l - e).abs() < 1e-4, "Mismatch at index {}", i); 272 | } 273 | mse_loss.backward(); 274 | 275 | 276 | let grad = casual_self_attention.c_proj.weights.grad(); 277 | let expected_grad = tenors[12].clone(); 278 | let grad_flatten = grad.into_raw_vec(); 279 | let expected_grad_flatten = expected_grad.item().into_raw_vec(); 280 | const ERR : f32 = 1e-3; 281 | for (i, (g, e)) in grad_flatten.iter().zip(expected_grad_flatten.iter()).enumerate() { 282 | assert!((g - e).abs() < ERR, "Mismatch at index {} a {} b {}", i, g, e); 283 | } 284 | 285 | let grad = casual_self_attention.c_proj.bias.grad(); 286 | let expected_grad = tenors[13].clone(); 287 | let grad_flatten = grad.into_raw_vec(); 288 | let expected_grad_flatten = expected_grad.item().into_raw_vec(); 289 | 290 | for (i, (g, e)) in grad_flatten.iter().zip(expected_grad_flatten.iter()).enumerate() { 291 | assert!((g - e).abs() < ERR, "Mismatch at index {} a {} b {}", i, g, e); 292 | } 293 | 294 | let grad = casual_self_attention.query_attention.weights.grad(); 295 | let expected_grad = tenors[14].clone(); 296 | let grad_flatten = grad.into_raw_vec(); 297 | let expected_grad_flatten = expected_grad.item().into_raw_vec(); 298 | 299 | for (i, (g, e)) in grad_flatten.iter().zip(expected_grad_flatten.iter()).enumerate() { 300 | assert!((g - e).abs() < ERR, "Mismatch at index {} a {} b {}", i, g, e); 301 | } 302 | 303 | let grad = casual_self_attention.query_attention.bias.grad(); 304 | let expected_grad = tenors[15].clone(); 305 | let grad_flatten = grad.into_raw_vec(); 306 | let expected_grad_flatten = expected_grad.item().into_raw_vec(); 307 | 308 | for (i, (g, e)) in grad_flatten.iter().zip(expected_grad_flatten.iter()).enumerate() { 309 | assert!((g - e).abs() < ERR, "Mismatch at index {} a {} b {}", i, g, e); 310 | } 311 | 312 | let grad = casual_self_attention.key_attention.weights.grad(); 313 | let expected_grad = tenors[16].clone(); 314 | let grad_flatten = grad.into_raw_vec(); 315 | let expected_grad_flatten = expected_grad.item().into_raw_vec(); 316 | 317 | for (i, (g, e)) in grad_flatten.iter().zip(expected_grad_flatten.iter()).enumerate() { 318 | assert!((g - e).abs() < ERR, "Mismatch at index {} a {} b {}", i, g, e); 319 | } 320 | 321 | let grad = casual_self_attention.key_attention.bias.grad(); 322 | let expected_grad = tenors[17].clone(); 323 | let grad_flatten = grad.into_raw_vec(); 324 | let expected_grad_flatten = expected_grad.item().into_raw_vec(); 325 | 326 | for (i, (g, e)) in grad_flatten.iter().zip(expected_grad_flatten.iter()).enumerate() { 327 | assert!((g - e).abs() < ERR, "Mismatch at index {} a {} b {}", i, g, e); 328 | } 329 | 330 | let grad = casual_self_attention.value_attention.weights.grad(); 331 | let expected_grad = tenors[18].clone(); 332 | let grad_flatten = grad.into_raw_vec(); 333 | let expected_grad_flatten = expected_grad.item().into_raw_vec(); 334 | 335 | for (i, (g, e)) in grad_flatten.iter().zip(expected_grad_flatten.iter()).enumerate() { 336 | assert!((g - e).abs() < ERR, "Mismatch at index {} a {} b {}", i, g, e); 337 | } 338 | 339 | let grad = casual_self_attention.value_attention.bias.grad(); 340 | let expected_grad = tenors[19].clone(); 341 | let grad_flatten = grad.into_raw_vec(); 342 | let expected_grad_flatten = expected_grad.item().into_raw_vec(); 343 | 344 | for (i, (g, e)) in 
grad_flatten.iter().zip(expected_grad_flatten.iter()).enumerate() { 345 | assert!((g - e).abs() < ERR, "Mismatch at index {} a {} b {}", i, g, e); 346 | } 347 | 348 | 349 | } 350 | } --------------------------------------------------------------------------------