├── .github └── workflows │ └── rust.yml ├── .gitignore ├── .vscode └── settings.json ├── Cargo.toml ├── LICENSE ├── Makefile ├── README.md └── src ├── bit_attention.rs ├── bit_dropout.rs ├── bit_ffn.rs ├── bit_linear.rs ├── bit_transformer.rs ├── config.rs ├── embedding.rs ├── inference.rs ├── main.rs ├── rms_norm.rs ├── test_data ├── bit_attention_test.safetensors └── scaled_dot_product_gqa.safetensors ├── training.rs └── utils_tensor.rs /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v3 19 | - name: Build 20 | run: cargo build --verbose 21 | - name: Run tests 22 | run: cargo test --verbose 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | debug/ 4 | target/ 5 | 6 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 7 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 8 | Cargo.lock 9 | 10 | # These are backup files generated by rustfmt 11 | **/*.rs.bk 12 | 13 | # MSVC Windows builds of rustc generate these, which store debugging information 14 | *.pdb 15 | 16 | # ignore checkpoint files 17 | *.safetensors 18 | 19 | # ignore generated traces 20 | trace-*.json -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "rust-analyzer.showUnlinkedFileNotification": false 3 | } 4 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "bitnet-rs" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | accelerate-src = { version = "0.3.2", optional = true } 8 | anyhow = "1.0.80" 9 | candle-core = { git = "https://github.com/huggingface/candle.git" } 10 | candle-datasets = { git = "https://github.com/huggingface/candle.git" } 11 | candle-einops = { git = "https://github.com/tomsanbear/candle-einops", branch = "latest-candle" } 12 | candle-flash-attn = { git = "https://github.com/huggingface/candle.git", optional = true } 13 | candle-nn = { git = "https://github.com/huggingface/candle.git" } 14 | candle-transformers = { git = "https://github.com/huggingface/candle.git" } 15 | clap = "4.5.1" 16 | crossterm = "0.27.0" 17 | cudarc = { version = "0.10.0", optional = true } 18 | half = "2.4.0" 19 | hf-hub = "0.3.2" 20 | intel-mkl-src = { version = "0.8.1", optional = true } 21 | kdam = "0.5.1" 22 | num_cpus = "1.16.0" 23 | pretty_trace = "0.5.23" 24 | rand = "0.8.5" 25 | ratatui = "0.26.1" 26 | serde = "1.0.197" 27 | tokenizers = "0.15.2" 28 | tracing = "0.1.40" 29 | tracing-chrome = "0.7.1" 30 | tracing-subscriber = "0.3.18" 31 | 32 | [build-dependencies] 33 | bindgen_cuda = { version = "0.1.1", optional = true } 34 | 35 | 36 | [features] 37 | default = [] 38 | accelerate = [ 39 | "dep:accelerate-src", 40 | "candle-core/accelerate", 41 | "candle-nn/accelerate", 42 | "candle-transformers/accelerate", 43 | ] 44 | cuda = 
[ 45 | "candle-core/cuda", 46 | "candle-nn/cuda", 47 | "candle-transformers/cuda", 48 | "dep:bindgen_cuda", 49 | ] 50 | cudnn = ["candle-core/cudnn"] 51 | flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"] 52 | mkl = [ 53 | "dep:intel-mkl-src", 54 | "candle-core/mkl", 55 | "candle-nn/mkl", 56 | "candle-transformers/mkl", 57 | ] 58 | nccl = ["cuda", "cudarc/nccl"] 59 | metal = ["candle-core/metal", "candle-nn/metal"] 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Thomas Santerre 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | build: 2 | cargo build 3 | 4 | tests = test-cpu test-metal test-cuda test-flash-attn 5 | 6 | test-cpu: 7 | cargo test 8 | 9 | test-metal: 10 | cargo test --features "metal" 11 | 12 | test-cuda: 13 | cargo test --features "cuda" 14 | 15 | test-flash-attn: 16 | cargo test --features "flash_attn" 17 | 18 | .PHONY: $(tests) 19 | test-all: $(tests) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # bitnet-rs: Bitnet Transformer in Rust! 2 | 3 | Implementation of the Bitnet transformer using [Candle](https://github.com/huggingface/candle). Implementation is based on the pytorch implementation here: [kyegomez/BitNet](https://github.com/kyegomez/BitNet) 4 | 5 | ## About 6 | 7 | I started this project in order to better understand what goes into making a transformer model in a ML library from scratch, rather than re-implement an existing model I wanted to try doing this from a less known and unimplemented model. In addition, I'm curious about non pytorch based models in order to push performance for models, as such learning to use Candle was a big part of this! 8 | 9 | ## Building 10 | 11 | ### CPU 12 | 13 | `cargo build --release` 14 | 15 | ### Metal 16 | 17 | `cargo build --release --features "metal,accelerate"` 18 | 19 | ### CUDA 20 | 21 | `cargo build --release --features "cuda"` 22 | 23 | ## Training 24 | 25 | First, build the binary according to the instructions above, then run the command below. 
26 | 27 | `./target/release/bitnet-rs train --dataset ""` 28 | 29 | Replace `` with the directory location of the dataset you are training from. These must be precompiled datasets. I would recommend using the same dataset that has been used for validation: [karpathy/llama2.c](https://github.com/karpathy/llama2.c?tab=readme-ov-file#training). Please follow the instructions in that repository for generating the pretokenized dataset. 30 | 31 | For example, on my machine the training command is this: `./target/release/bitnet-rs train --dataset "../../karpathy/llama2.c/data/TinyStories_all_data"`. 32 | 33 | ## Inference 34 | 35 | First, build the binary according to the instructions above, then run the command below. 36 | 37 | `./target/release/bitnet-rs inference` 38 | 39 | If you want to provide a prompt, provide the `--prompt` flag. 40 | 41 | `./target/release/bitnet-rs inference --prompt "Once upon a time "` 42 | 43 | If you want to specify a specific model to use for the inference, use the `--pretrained-model-path` flag. 44 | 45 | `./target/release/bitnet-rs inference --pretrained-model-path "./checkpoint.safetensors"`. 46 | 47 | ## Known Issues 48 | 49 | I'm still testing this out but I am getting semi coherent output with models I've trained. Definitely not useful for any task right now until I can get loss down. 50 | 51 | ## Contributing 52 | 53 | If you have an interest in contributing please feel free! I'm still learning and would appreciate any suggestions from others. -------------------------------------------------------------------------------- /src/bit_attention.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | bit_linear::{Bitlinear, BitlinearCfg}, 3 | utils_tensor::scaled_dot_product_attention, 4 | }; 5 | use anyhow::{anyhow, Ok, Result}; 6 | use candle_core::{DType, Device, Tensor}; 7 | use candle_einops::einops; 8 | use candle_nn::{rotary_emb::rope_i, VarBuilder}; 9 | use tracing::instrument; 10 | 11 | #[derive(Debug, Clone, Copy)] 12 | pub struct BitAttentionCfg { 13 | pub dim: usize, 14 | pub n_heads: usize, 15 | pub n_kv_heads: usize, 16 | pub dropout: f32, 17 | pub eps: f64, 18 | pub max_seq_len: usize, 19 | } 20 | 21 | #[derive(Debug)] 22 | pub struct BitAttention { 23 | q_proj: Bitlinear, 24 | k_proj: Bitlinear, 25 | v_proj: Bitlinear, 26 | o_proj: Bitlinear, 27 | dropout: f32, 28 | n_heads: usize, 29 | n_kv_heads: usize, 30 | cos: Tensor, 31 | sin: Tensor, 32 | kv_cache: Option<(Tensor, Tensor)>, 33 | } 34 | 35 | fn precompute_freqs_cis( 36 | head_dim: usize, 37 | freq_base: f32, 38 | max_seq_len: usize, 39 | device: &Device, 40 | ) -> Result<(Tensor, Tensor)> { 41 | let theta: Vec<_> = (0..head_dim) 42 | .step_by(2) 43 | .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32)) 44 | .collect(); 45 | let theta = Tensor::new(theta.as_slice(), device)?; 46 | let idx_theta = Tensor::arange(0, max_seq_len as u32, device)? 47 | .to_dtype(DType::F32)? 48 | .reshape((max_seq_len, 1))? 
49 | .matmul(&theta.reshape((1, theta.elem_count()))?)?; 50 | let cos = idx_theta.cos()?; 51 | let sin = idx_theta.sin()?; 52 | Ok((cos, sin)) 53 | } 54 | 55 | impl BitAttention { 56 | pub fn load(cfg: BitAttentionCfg, vb: VarBuilder) -> Result { 57 | let head_dim = cfg.dim / cfg.n_heads; 58 | if cfg.n_heads % cfg.n_kv_heads != 0 { 59 | return Err(anyhow!( 60 | "query_heads must be divisible by kv_heads, got: {} and {}", 61 | cfg.n_heads, 62 | cfg.n_kv_heads 63 | )); 64 | } 65 | if (cfg.dim % cfg.n_heads) != 0 { 66 | return Err(anyhow!( 67 | "dim must be divisible by query_heads, got: {} and {}", 68 | cfg.n_heads, 69 | cfg.dim 70 | )); 71 | } 72 | if (cfg.dim % cfg.n_heads) != 0 { 73 | return Err(anyhow!( 74 | "dim must be divisible by n_kv_heads, got: {} and {}", 75 | cfg.n_heads, 76 | cfg.dim 77 | )); 78 | } 79 | if head_dim % 8 != 0 { 80 | return Err(anyhow!( 81 | "head_dim must be divisible by 8, got: {}", 82 | head_dim 83 | )); 84 | } 85 | if head_dim > 128 { 86 | return Err(anyhow!("head_dim must be less than or equal to 128")); 87 | } 88 | 89 | let q_proj = Bitlinear::load( 90 | BitlinearCfg { 91 | in_features: cfg.dim, 92 | out_features: cfg.n_heads * head_dim, 93 | eps: cfg.eps, 94 | }, 95 | vb.pp("q_proj"), 96 | )?; 97 | let k_proj = Bitlinear::load( 98 | BitlinearCfg { 99 | in_features: cfg.dim, 100 | out_features: cfg.n_kv_heads * head_dim, 101 | eps: cfg.eps, 102 | }, 103 | vb.pp("k_proj"), 104 | )?; 105 | let v_proj = Bitlinear::load( 106 | BitlinearCfg { 107 | in_features: cfg.dim, 108 | out_features: cfg.n_kv_heads * head_dim, 109 | eps: cfg.eps, 110 | }, 111 | vb.pp("v_proj"), 112 | )?; 113 | 114 | let o_proj = Bitlinear::load( 115 | BitlinearCfg { 116 | in_features: cfg.dim, 117 | out_features: cfg.dim, 118 | eps: cfg.eps, 119 | }, 120 | vb.pp("o_proj"), 121 | )?; 122 | 123 | let (cos, sin) = precompute_freqs_cis(head_dim, 10000., cfg.max_seq_len, vb.device())?; 124 | 125 | Ok(BitAttention { 126 | q_proj, 127 | k_proj, 128 | v_proj, 129 | o_proj, 130 | n_heads: cfg.n_heads, 131 | n_kv_heads: cfg.n_kv_heads, 132 | dropout: cfg.dropout, 133 | cos, 134 | sin, 135 | kv_cache: None, 136 | }) 137 | } 138 | 139 | #[instrument] 140 | fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result { 141 | let dtype = x.dtype(); 142 | let x = x.to_dtype(DType::F32)?; 143 | let (_b_sz, _n_head, seq_len, _n_embd) = x.dims4()?; 144 | let cos = self.cos.narrow(0, index_pos, seq_len)?; 145 | let sin = self.sin.narrow(0, index_pos, seq_len)?; 146 | let x = rope_i(&x.contiguous()?, &cos, &sin)?; 147 | let x = x.to_dtype(dtype)?; 148 | Ok(x) 149 | } 150 | 151 | #[instrument] 152 | pub fn forward(&mut self, x: &Tensor, is_causal: bool, index_pos: usize) -> Result { 153 | let q = self.q_proj.forward(x)?; 154 | let k = self.k_proj.forward(x)?; 155 | let v = self.v_proj.forward(x)?; 156 | 157 | let q = einops!("b n ({self.n_heads} d) -> b n {self.n_heads} d", q); 158 | let k = einops!("b n ({self.n_kv_heads} d) -> b n {self.n_kv_heads} d", k); 159 | let v = einops!("b n ({self.n_kv_heads} d) -> b n {self.n_kv_heads} d", v); 160 | 161 | let q = self.apply_rotary_emb(&q, index_pos)?; 162 | let k = self.apply_rotary_emb(&k, index_pos)?; 163 | let (k, v) = match &self.kv_cache { 164 | None => (k, v), 165 | Some((k_cache, v_cache)) => { 166 | if index_pos == 0 { 167 | (k, v) 168 | } else { 169 | let k = Tensor::cat(&[k_cache, &k], 2)?; 170 | let v = Tensor::cat(&[v_cache, &v], 2)?; 171 | (k, v) 172 | } 173 | } 174 | }; 175 | self.kv_cache = Some((k.clone(), v.clone())); 176 | 177 | let scale = 
1.0 / (q.dims4()?.3 as f64).sqrt(); // 1/sqrt(head_dim), matching the default used in scaled_dot_product_attention 178 | let x = scaled_dot_product_attention( 179 | &q, 180 | &k, 181 | &v, 182 | None, 183 | Some(self.dropout), 184 | Some(is_causal), 185 | Some(scale), 186 | )?; 187 | 188 | let x = einops!("b n h d -> b n (h d)", x); 189 | let x = self.o_proj.forward(&x)?; 190 | Ok(x) 191 | } 192 | } 193 | 194 | #[cfg(test)] 195 | mod bit_attention_tests { 196 | use crate::{ 197 | bit_attention::{BitAttention, BitAttentionCfg}, 198 | utils_tensor::device, 199 | }; 200 | use candle_core::Tensor; 201 | use candle_nn::VarBuilder; 202 | 203 | #[test] 204 | fn forward_produces_expected_shape_f32() -> anyhow::Result<()> { 205 | let device = device(true).unwrap(); 206 | let vb = VarBuilder::zeros(candle_core::DType::F32, &device); 207 | 208 | let input_tensor = Tensor::randn(0.0f32, 1.0f32, (2, 8, 64), &device)?; 209 | let mut bit_attention = BitAttention::load( 210 | BitAttentionCfg { 211 | dim: 64, 212 | n_heads: 8, 213 | n_kv_heads: 8, 214 | dropout: 0.1, 215 | eps: 1e-6, 216 | max_seq_len: 64, 217 | }, 218 | vb, 219 | )?; 220 | 221 | let output_tensor = bit_attention.forward(&input_tensor, true, 0).unwrap(); 222 | 223 | assert_eq!(output_tensor.shape().dims(), &[2, 8, 64]); 224 | 225 | Ok(()) 226 | } 227 | } 228 | -------------------------------------------------------------------------------- /src/bit_dropout.rs: -------------------------------------------------------------------------------- 1 | use candle_core::Tensor; 2 | use tracing::instrument; 3 | 4 | #[derive(Debug)] 5 | pub struct DropoutCfg { 6 | pub p: f32, 7 | pub is_training: bool, 8 | } 9 | 10 | #[derive(Debug)] 11 | pub struct Dropout { 12 | drop_p: f32, 13 | is_training: bool, 14 | } 15 | 16 | impl Dropout { 17 | #[instrument] 18 | pub fn load(cfg: DropoutCfg) -> anyhow::Result<Self> { 19 | Ok(Self { 20 | drop_p: cfg.p, 21 | is_training: cfg.is_training, 22 | }) 23 | } 24 | 25 | #[instrument] 26 | pub fn forward(&self, x: &Tensor) -> anyhow::Result<Tensor> { 27 | if !self.is_training { 28 | return Ok(x.clone()); 29 | } 30 | 31 | if !(0. ..1.).contains(&self.drop_p) { 32 | anyhow::bail!( 33 | "dropout probability has to be in [0, 1), got {:?}", 34 | self.drop_p 35 | ) 36 | } 37 | let rand = Tensor::rand(0f32, 1f32, x.shape(), x.device())?; 38 | let scale = 1.0 / (1.0 - self.drop_p as f64); 39 | let drop_p = Tensor::new(self.drop_p, x.device())?.broadcast_as(x.shape())?; 40 | // Metal doesn't support contiguous affine operation so we need to cast to f32 41 | let mask = match x.device() { 42 | candle_core::Device::Metal(_) => { 43 | (rand.ge(&drop_p)?.to_dtype(candle_core::DType::F32)? * scale)?.to_dtype(x.dtype()) 44 | } 45 | _ => (rand.ge(&drop_p)? * scale)?.to_dtype(x.dtype()), 46 | }?; 47 | 48 | Ok((x * mask)?)
49 | } 50 | } 51 | -------------------------------------------------------------------------------- /src/bit_ffn.rs: -------------------------------------------------------------------------------- 1 | use crate::bit_dropout::{Dropout, DropoutCfg}; 2 | use crate::bit_linear::{Bitlinear, BitlinearCfg}; 3 | use crate::rms_norm::RmsNorm; 4 | use candle_core::{Module, Tensor}; 5 | use candle_nn::{Activation, VarBuilder}; 6 | use tracing::instrument; 7 | 8 | pub struct BitFeedForwardCfg { 9 | pub dim: usize, 10 | pub ff_mult: usize, 11 | pub dropout: f32, 12 | pub train: bool, 13 | pub eps: f64, 14 | } 15 | 16 | #[derive(Debug)] 17 | pub struct BitFeedForward { 18 | proj_in: Bitlinear, 19 | activation: Activation, 20 | post_act_norm: RmsNorm, 21 | dropout: Dropout, 22 | proj_out: Bitlinear, 23 | } 24 | 25 | impl BitFeedForward { 26 | pub fn load(cfg: BitFeedForwardCfg, vb: VarBuilder) -> anyhow::Result { 27 | // Setup internal parameters 28 | let inner_dim = cfg.dim * cfg.ff_mult; 29 | 30 | // Use swiglu from 1.58 paper 31 | let activation = Activation::Swiglu; 32 | 33 | // Dropout layer, if train is passed then this is skipped 34 | let dropout = Dropout::load(DropoutCfg { 35 | p: cfg.dropout, 36 | is_training: cfg.train, 37 | })?; 38 | 39 | // Post activation normalization 40 | let post_act_norm = RmsNorm::load(cfg.eps, inner_dim, vb.pp("norm"))?; 41 | 42 | // Input linear layer 43 | let proj_in = Bitlinear::load( 44 | BitlinearCfg { 45 | in_features: cfg.dim, 46 | out_features: inner_dim * 2, 47 | eps: cfg.eps, 48 | }, 49 | vb.pp("proj"), 50 | )?; 51 | 52 | // Linear layer 53 | let proj_out = Bitlinear::load( 54 | BitlinearCfg { 55 | in_features: inner_dim, 56 | out_features: cfg.dim, 57 | eps: cfg.eps, 58 | }, 59 | vb.pp("linear"), 60 | )?; 61 | 62 | // Return the layer as a sequential module 63 | Ok(Self { 64 | proj_in, 65 | activation, 66 | post_act_norm, 67 | dropout, 68 | proj_out, 69 | }) 70 | } 71 | 72 | #[instrument] 73 | pub fn forward(&self, x: &Tensor) -> anyhow::Result { 74 | let x = self.proj_in.forward(x)?; 75 | let x = self.activation.forward(&x)?; 76 | let x = self.post_act_norm.forward(&x)?; 77 | let x = self.dropout.forward(&x)?; 78 | let x = self.proj_out.forward(&x)?; 79 | Ok(x) 80 | } 81 | } 82 | 83 | #[cfg(test)] 84 | mod bitffn_tests { 85 | use super::BitFeedForward; 86 | use crate::{bit_ffn::BitFeedForwardCfg, utils_tensor::device}; 87 | use candle_core::{DType, Device, Tensor}; 88 | use candle_nn::VarBuilder; 89 | 90 | #[test] 91 | fn it_applies_forward_pass_dim_2() -> anyhow::Result<()> { 92 | let device: Device = device(true).unwrap(); 93 | let vb = VarBuilder::zeros(DType::F32, &device); 94 | let dim = 128; 95 | let input: Tensor = Tensor::randn(0f32, 1.0, (10, dim), &device)?; 96 | let bff = BitFeedForward::load( 97 | BitFeedForwardCfg { 98 | dim, 99 | ff_mult: 4, 100 | dropout: 0.1, 101 | train: true, 102 | eps: 1e-6, 103 | }, 104 | vb, 105 | )?; 106 | let output = bff.forward(&input).unwrap(); 107 | let output_shape = output.shape().dims2()?; 108 | assert_eq!(output_shape, (10, dim)); 109 | Ok(()) 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /src/bit_linear.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Ok; 2 | use candle_core::{Tensor, D}; 3 | use candle_nn::{Init, Module, VarBuilder}; 4 | use candle_transformers::models::with_tracing::{Linear, RmsNorm}; 5 | use tracing::instrument; 6 | 7 | #[derive(Debug, Clone, Copy)] 8 | pub struct BitlinearCfg { 9 
| pub in_features: usize, 10 | pub out_features: usize, 11 | pub eps: f64, 12 | } 13 | 14 | #[derive(Debug)] 15 | pub struct Bitlinear { 16 | weight: Tensor, 17 | layer_norm: RmsNorm, 18 | } 19 | 20 | impl Bitlinear { 21 | pub fn load(cfg: BitlinearCfg, vb: VarBuilder) -> anyhow::Result { 22 | let weight = vb.get_with_hints( 23 | (cfg.out_features, cfg.in_features), 24 | "weight", 25 | Init::Randn { 26 | mean: 0.0, 27 | stdev: 1.0, 28 | }, 29 | )?; 30 | let layer_norm = RmsNorm::new(cfg.in_features, cfg.eps, vb.pp("rms_norm"))?; 31 | Ok(Self { weight, layer_norm }) 32 | } 33 | 34 | #[instrument] 35 | pub fn forward(&self, x: &Tensor) -> anyhow::Result { 36 | fn activation_quant(x: &Tensor) -> anyhow::Result { 37 | let scale = (127.0 38 | / x.abs()? 39 | .max(D::Minus1)? 40 | .max(D::Minus1)? 41 | .clamp(1e-5, f32::INFINITY)?)?; 42 | let y = x 43 | .broadcast_mul(&scale.unsqueeze(D::Minus1)?.unsqueeze(D::Minus1)?)? 44 | .clamp(-128.0, 127.0)?; 45 | Ok(y) 46 | } 47 | 48 | fn weight_quant(x: &Tensor) -> anyhow::Result { 49 | let scale = x.abs()?.mean_all()?; 50 | let e = x.mean_all()?; 51 | let u = x.broadcast_sub(&e)?.sign()?.broadcast_mul(&scale)?; 52 | Ok(u) 53 | } 54 | 55 | let weight = self.weight.clone(); 56 | 57 | let x_norm = self.layer_norm.forward(x)?; 58 | 59 | let x_quant = (x_norm.clone() + (activation_quant(&x_norm)? - x_norm)?.detach())?; 60 | 61 | let w_quant = (weight.clone() + (weight_quant(&weight)? - weight)?.detach())?; 62 | 63 | let y = Linear::from_weights(w_quant, None).forward(&x_quant)?; 64 | 65 | Ok(y) 66 | } 67 | } 68 | 69 | #[cfg(test)] 70 | mod bitlinear_tests { 71 | use super::Bitlinear; 72 | use crate::{bit_linear::BitlinearCfg, utils_tensor::device}; 73 | use candle_core::{DType, Tensor}; 74 | use candle_nn::var_builder::VarBuilderArgs; 75 | 76 | #[test] 77 | fn it_applies_forward_pass() -> anyhow::Result<()> { 78 | let device = device(true).unwrap(); 79 | let vb = VarBuilderArgs::zeros(DType::F32, &device.clone()); 80 | let in_features = 64; 81 | let out_features = 64; 82 | let bl = Bitlinear::load( 83 | BitlinearCfg { 84 | in_features, 85 | out_features, 86 | eps: 1e-6, 87 | }, 88 | vb, 89 | )?; 90 | let input: Tensor = Tensor::randn(0.0f32, 1.0f32, (1, 64), &device.clone())?; 91 | let output = bl.forward(&input)?; 92 | assert_eq!(output.shape().dims2()?, (1, 64)); 93 | Ok(()) 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /src/bit_transformer.rs: -------------------------------------------------------------------------------- 1 | use crate::bit_attention::{BitAttention, BitAttentionCfg}; 2 | use crate::bit_ffn::{BitFeedForward, BitFeedForwardCfg}; 3 | use crate::config::Config; 4 | use crate::embedding::Embedding; 5 | use crate::rms_norm::RmsNorm; 6 | use anyhow::Result; 7 | use candle_core::{Module, Tensor}; 8 | use candle_nn::VarBuilder; 9 | use candle_transformers::models::with_tracing::{linear, Linear}; 10 | use tracing::instrument; 11 | 12 | #[derive(Debug)] 13 | pub struct BitTransformer { 14 | embedding: Embedding, 15 | blocks: Vec<(BitAttention, BitFeedForward)>, 16 | rms_norm: RmsNorm, 17 | logits_linear: Linear, 18 | } 19 | 20 | impl BitTransformer { 21 | pub fn load(cfg: Config, vb: VarBuilder, train: bool) -> Result { 22 | let embedding = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("embedding"))?; 23 | let blocks: Vec<_> = (0..(cfg.depth)) 24 | .map(|i| { 25 | ( 26 | BitAttention::load( 27 | BitAttentionCfg { 28 | dim: cfg.dim, 29 | n_heads: cfg.heads, 30 | n_kv_heads: 8, 31 | dropout: 0.1, 
32 | eps: cfg.eps, 33 | max_seq_len: cfg.seq_len, 34 | }, 35 | vb.pp(&format!("attn.{i}")), 36 | ) 37 | .unwrap(), 38 | BitFeedForward::load( 39 | BitFeedForwardCfg { 40 | dim: cfg.dim, 41 | ff_mult: cfg.ff_mult, 42 | dropout: cfg.ff_dropout, 43 | train, 44 | eps: cfg.eps, 45 | }, 46 | vb.pp(&format!("ffn.{i}")), 47 | ) 48 | .unwrap(), 49 | ) 50 | }) 51 | .collect(); 52 | 53 | let rms_norm = RmsNorm::load(cfg.eps, cfg.dim, vb.pp("rms_norm"))?; 54 | let logits_linear = linear(cfg.dim, cfg.vocab_size, vb.pp("logits_linear"))?; 55 | 56 | Ok(Self { 57 | blocks, 58 | rms_norm, 59 | logits_linear, 60 | embedding, 61 | }) 62 | } 63 | 64 | #[instrument] 65 | pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> { 66 | // Run the embedding layer 67 | let x_embed = self.embedding.forward(x)?; 68 | 69 | // Fold each block forward 70 | let mut x = x_embed.clone(); 71 | for (attn, ffn) in self.blocks.iter_mut() { 72 | x = attn.forward(&x, true, index_pos)?.add(&x_embed)?; 73 | let residual = x.clone(); // keep the feed-forward input for its residual connection 74 | x = ffn.forward(&x)?; 75 | x = x.add(&residual)?; 76 | } 77 | 78 | // Convert to logits 79 | let x = self.rms_norm.forward(&x)?; 80 | let x = self.logits_linear.forward(&x)?; 81 | Ok(x) 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /src/config.rs: -------------------------------------------------------------------------------- 1 | #[derive(Debug, Copy, Clone, PartialEq)] 2 | pub struct Config { 3 | pub(crate) dim: usize, 4 | pub(crate) depth: usize, 5 | pub(crate) vocab_size: usize, 6 | pub(crate) heads: usize, 7 | pub(crate) ff_mult: usize, 8 | pub(crate) eps: f64, 9 | pub(crate) ff_dropout: f32, 10 | pub(crate) seq_len: usize, 11 | } 12 | 13 | impl Config { 14 | // Default configuration for initial evaluation, will add larger configs later after confirming valid output 15 | pub fn default() -> Self { 16 | Self { 17 | dim: 256, 18 | depth: 12, 19 | vocab_size: 32000, 20 | heads: 8, 21 | ff_mult: 12, 22 | eps: 1e-6, 23 | ff_dropout: 0.1, 24 | seq_len: 100, 25 | } 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/embedding.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use candle_core::{DType, Tensor}; 3 | use candle_nn::{Init, VarBuilder}; 4 | 5 | #[derive(Clone, Debug)] 6 | pub struct Embedding { 7 | embeddings: Tensor, 8 | hidden_size: usize, 9 | forward_dtype: DType, 10 | } 11 | 12 | impl Embedding { 13 | pub fn new(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Self> { 14 | let embeddings = vb 15 | .get_with_hints( 16 | (vocab_size, hidden_size), 17 | "embedding.weight", 18 | Init::Randn { 19 | mean: 0., 20 | stdev: 1., 21 | }, 22 | )? 
23 | .to_dtype(DType::F32)?; 24 | Ok(Self { 25 | embeddings, 26 | hidden_size, 27 | forward_dtype: vb.dtype(), 28 | }) 29 | } 30 | 31 | pub fn forward(&self, indexes: &Tensor) -> Result { 32 | let mut final_dims = indexes.dims().to_vec(); 33 | final_dims.push(self.hidden_size); 34 | let indexes = indexes.flatten_all()?; 35 | let values = self.embeddings.index_select(&indexes, 0)?; 36 | let values = values.reshape(final_dims)?.to_dtype(self.forward_dtype)?; 37 | Ok(values) 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /src/inference.rs: -------------------------------------------------------------------------------- 1 | use std::io::Write; 2 | 3 | use anyhow::Result; 4 | use candle_transformers::generation::LogitsProcessor; 5 | use rand::Rng; 6 | use tokenizers::Tokenizer; 7 | 8 | use crate::bit_transformer::BitTransformer; 9 | use crate::config::Config; 10 | use crate::{utils_tensor::device, Args, InferenceCmd}; 11 | use candle_core::{safetensors, DType, IndexOp, Tensor}; 12 | use candle_nn::VarBuilder; 13 | 14 | pub fn run(args: &InferenceCmd, common_args: &Args) -> Result<()> { 15 | let mut rng = rand::thread_rng(); 16 | 17 | let tokenizer = { 18 | let api = hf_hub::api::sync::Api::new()?; 19 | let api = api.model("hf-internal-testing/llama-tokenizer".to_string()); 20 | let tokenizer_path = api.get("tokenizer.json")?; 21 | Tokenizer::from_file(tokenizer_path).unwrap() 22 | }; 23 | 24 | let device = device(common_args.cpu)?; 25 | 26 | let safetensors = safetensors::load("./checkpoint.safetensors", &device)?; 27 | let vb = VarBuilder::from_tensors(safetensors, DType::F32, &device); 28 | 29 | let config = Config::default(); 30 | let mut model = BitTransformer::load(config, vb, false)?; 31 | 32 | println!("starting the inference loop"); 33 | let mut logits_processor = LogitsProcessor::new(rng.gen(), args.temperature, Some(args.top_p)); 34 | 35 | print!("{}", args.prompt); 36 | let mut tokens = tokenizer 37 | .encode(args.prompt.clone(), true) 38 | .unwrap() 39 | .get_ids() 40 | .to_vec(); 41 | let mut generated_tokens = 0usize; 42 | let eos_token = match tokenizer.get_vocab(true).get("") { 43 | Some(token) => *token, 44 | None => anyhow::bail!("cannot find the endoftext token"), 45 | }; 46 | 47 | let start_gen = std::time::Instant::now(); 48 | for index in 0.. { 49 | if tokens.len() >= config.seq_len { 50 | break; 51 | } 52 | let context_size = if index > 0 { 1 } else { tokens.len() }; 53 | let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; 54 | let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; 55 | let logits = model.forward(&input, index)?; 56 | let logits = logits.i((0, logits.dim(1)? - 1))?; 57 | let logits = if args.repeat_penalty == 1. || tokens.is_empty() { 58 | logits 59 | } else { 60 | let start_at = tokens.len().saturating_sub(args.repeat_last_n); 61 | candle_transformers::utils::apply_repeat_penalty( 62 | &logits, 63 | args.repeat_penalty, 64 | &tokens[start_at..], 65 | )? 
66 | }; 67 | 68 | let next_token = logits_processor.sample(&logits)?; 69 | tokens.push(next_token); 70 | generated_tokens += 1; 71 | if next_token == eos_token { 72 | break; 73 | } 74 | let token = tokenizer.decode(&[next_token], true).unwrap(); 75 | print!("{token}"); 76 | std::io::stdout().flush()?; 77 | } 78 | let dt = start_gen.elapsed(); 79 | println!( 80 | "\n{generated_tokens} tokens generated ({:.2} token/s)", 81 | generated_tokens as f64 / dt.as_secs_f64(), 82 | ); 83 | Ok(()) 84 | } 85 | -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | mod bit_attention; 2 | mod bit_dropout; 3 | mod bit_ffn; 4 | mod bit_linear; 5 | mod bit_transformer; 6 | mod config; 7 | mod embedding; 8 | mod inference; 9 | mod rms_norm; 10 | mod training; 11 | mod utils_tensor; 12 | 13 | use anyhow::{anyhow, Result}; 14 | use clap::{Parser, Subcommand}; 15 | 16 | #[cfg(feature = "accelerate")] 17 | extern crate accelerate_src; 18 | 19 | #[cfg(feature = "mkl")] 20 | extern crate intel_mkl_src; 21 | 22 | #[derive(Parser, Debug, Clone)] 23 | struct InferenceCmd { 24 | /// Pretrained model path, only safetensors are supported 25 | #[arg(long, default_value = "./checkpoint.safetensors")] 26 | pretrained_model_path: String, 27 | 28 | /// The temperature used to generate samples. 29 | #[arg(long)] 30 | temperature: Option, 31 | 32 | /// Nucleus sampling probability cutoff. 33 | #[arg(long, default_value = "0.9")] 34 | top_p: f64, 35 | 36 | /// The repeat penalty to use 37 | #[arg(long, default_value = "1.1")] 38 | repeat_penalty: f32, 39 | 40 | /// The prompt for generation 41 | #[arg(long, default_value = "")] 42 | prompt: String, 43 | 44 | /// The number of tokens to repeat 45 | /// This is used to penalize repeating tokens 46 | #[arg(long, default_value = "10")] 47 | repeat_last_n: usize, 48 | } 49 | 50 | #[derive(Parser, Debug, Clone)] 51 | pub struct TrainingCmd { 52 | /// The data type for the weights, due to the implementation, we should theoretically be able to use a single bit, but we need candle to support this or contribute this 53 | /// For now, this can only be: f32 54 | #[arg(long, default_value = "f32")] 55 | dtype: String, 56 | 57 | /// The path to the dataset. 58 | #[arg( 59 | long, 60 | default_value = "../../karpathy/llama2.c/data/TinyStories_all_data" 61 | )] 62 | dataset: String, 63 | 64 | /// The batch size to use 65 | #[arg(long, default_value = "1")] 66 | batch_size: usize, 67 | 68 | /// The learning rate to use 69 | #[arg(long, default_value = "8e-4")] 70 | learning_rate: f64, 71 | 72 | /// The sequence length to use 73 | #[arg(long, default_value = "100")] 74 | seq_len: usize, 75 | 76 | /// The number of tokens in the vocabulary 77 | #[arg(long, default_value = "32000")] 78 | num_tokens: usize, 79 | 80 | /// The checkpoint file to continue from 81 | #[arg(long)] 82 | checkpoint: Option, 83 | 84 | /// The number of epochs to train for 85 | #[arg(long, default_value = "1")] 86 | epochs: usize, 87 | 88 | /// Max number of steps 89 | #[arg(long, default_value = "100000")] 90 | max_steps: usize, 91 | } 92 | 93 | #[derive(Subcommand, Debug, Clone)] 94 | enum Task { 95 | Inference(InferenceCmd), 96 | Train(TrainingCmd), 97 | } 98 | 99 | #[derive(Parser, Debug)] 100 | #[command(author, version, about, long_about = None)] 101 | struct Args { 102 | /// The task to be performed, inference, training or evaluation. 
103 | #[command(subcommand)] 104 | task: Option, 105 | 106 | /// Enable tracing (generates a trace-timestamp.json file). 107 | #[arg(long)] 108 | tracing: bool, 109 | 110 | /// Run on CPU rather than on GPU. 111 | #[arg(long)] 112 | cpu: bool, 113 | } 114 | 115 | fn main() -> Result<()> { 116 | use tracing_chrome::ChromeLayerBuilder; 117 | use tracing_subscriber::prelude::*; 118 | 119 | let args = Args::parse(); 120 | 121 | // Setup tracing if enabled 122 | let _guard = if args.tracing { 123 | let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); 124 | tracing_subscriber::registry().with(chrome_layer).init(); 125 | Some(guard) 126 | } else { 127 | None 128 | }; 129 | 130 | match &args.task { 131 | Some(Task::Inference(cmd)) => inference::run(cmd, &args)?, 132 | Some(Task::Train(cmd)) => training::run(cmd, &args)?, 133 | _ => return Err(anyhow!("No task specified")), 134 | } 135 | Ok(()) 136 | } 137 | -------------------------------------------------------------------------------- /src/rms_norm.rs: -------------------------------------------------------------------------------- 1 | use candle_core::{Module, Result, Tensor}; 2 | use candle_nn::VarBuilder; 3 | use tracing::instrument; 4 | 5 | #[derive(Debug, Clone)] 6 | pub struct RmsNorm { 7 | inner: candle_nn::RmsNorm, 8 | } 9 | 10 | impl RmsNorm { 11 | pub fn load(rms_norm_eps: f64, size: usize, vb: VarBuilder) -> Result { 12 | let inner = candle_nn::rms_norm(size, rms_norm_eps, vb)?; 13 | Ok(Self { inner }) 14 | } 15 | } 16 | 17 | impl Module for RmsNorm { 18 | #[instrument] 19 | fn forward(&self, x: &Tensor) -> Result { 20 | self.inner.forward(x) 21 | } 22 | } 23 | 24 | #[cfg(test)] 25 | mod rmsnorm_tests { 26 | 27 | use super::RmsNorm; 28 | use candle_core::{DType, Device, Module, Result, Tensor}; 29 | use candle_nn::VarBuilder; 30 | 31 | #[test] 32 | fn it_loads() -> Result<()> { 33 | let vb = VarBuilder::zeros(DType::F64, &Device::Cpu); 34 | RmsNorm::load(1e-6, 512, vb)?; 35 | Ok(()) 36 | } 37 | 38 | #[test] 39 | fn it_applies_forward_pass() -> Result<()> { 40 | let vb = VarBuilder::zeros(DType::F32, &Device::Cpu); 41 | let rmsnorm = RmsNorm::load(1e-6, 512, vb)?; 42 | let input = Tensor::ones((1, 512), DType::F32, &Device::Cpu)?; 43 | let output = rmsnorm.forward(&input).unwrap(); 44 | assert_eq!(output.shape().dims(), &[1, 512]); 45 | Ok(()) 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /src/test_data/bit_attention_test.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomsanbear/bitnet-rs/2adf025bda13a7e579c63f940d36fd21423b4b13/src/test_data/bit_attention_test.safetensors -------------------------------------------------------------------------------- /src/test_data/scaled_dot_product_gqa.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomsanbear/bitnet-rs/2adf025bda13a7e579c63f940d36fd21423b4b13/src/test_data/scaled_dot_product_gqa.safetensors -------------------------------------------------------------------------------- /src/training.rs: -------------------------------------------------------------------------------- 1 | use crate::config::Config; 2 | use crate::utils_tensor::cross_entropy; 3 | use crate::{bit_transformer::BitTransformer, utils_tensor::device, Args, TrainingCmd}; 4 | use anyhow::Result; 5 | use candle_core::{DType, Device}; 6 | use candle_datasets::nlp::tinystories::{Dataset, 
DatasetRandomIter}; 7 | use candle_datasets::Batcher; 8 | use candle_nn::{AdamW, Optimizer, VarBuilder, VarMap}; 9 | use kdam::{tqdm, BarExt}; 10 | use std::time::{SystemTime, UNIX_EPOCH}; 11 | use tracing::span; 12 | 13 | fn valid_loss( 14 | seq_len: usize, 15 | batch_size: usize, 16 | dataset: &Dataset, 17 | model: &mut BitTransformer, 18 | device: &Device, 19 | ) -> Result { 20 | let span = span!(tracing::Level::TRACE, "validate-loss"); 21 | let _enter = span.enter(); 22 | 23 | let iter = DatasetRandomIter::new(dataset, true, seq_len, device.clone()); 24 | let batch_iter = Batcher::new_r2(iter).batch_size(batch_size); 25 | let mut sum_ce = 0f64; 26 | let mut cnt = 0usize; 27 | let batch_count = 10; 28 | 29 | for inp_tgt in batch_iter.take(batch_count) { 30 | let span = span!(tracing::Level::TRACE, "validate-loss-iter"); 31 | let _enter = span.enter(); 32 | let (inp, tgt) = inp_tgt?; 33 | let logits = model.forward(&inp, 0)?; 34 | let loss = cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?; 35 | sum_ce += match loss.dtype() { 36 | DType::F32 => f64::from(loss.to_vec0::()?), 37 | DType::F16 => f64::from(loss.to_vec0::()?), 38 | DType::BF16 => f64::from(loss.to_vec0::()?), 39 | _ => panic!("Invalid dtype"), 40 | }; 41 | cnt += 1; 42 | } 43 | Ok(sum_ce / cnt as f64) 44 | } 45 | 46 | pub fn run(args: &TrainingCmd, common_args: &Args) -> Result<()> { 47 | let span = span!(tracing::Level::TRACE, "training"); 48 | let _enter = span.enter(); 49 | 50 | // Setup device 51 | let device = device(common_args.cpu)?; 52 | 53 | // Get the underlying data type to use for the model 54 | let dtype = match args.dtype.as_str() { 55 | "f32" => candle_core::DType::F32, 56 | "f16" => candle_core::DType::F16, 57 | "bf16" => candle_core::DType::BF16, 58 | _ => panic!("Invalid dtype"), 59 | }; 60 | 61 | // Setup varbuilder 62 | let mut varmap = VarMap::new(); 63 | 64 | // Load vars if checkpoint was provided 65 | if let Some(checkpoint) = &args.checkpoint { 66 | println!("Loading checkpoint: {:?}", checkpoint); 67 | varmap.load(checkpoint)?; 68 | } 69 | 70 | // Setup varbuilder 71 | let vb = VarBuilder::from_varmap(&varmap, dtype, &device); 72 | 73 | // Get the datasets 74 | let dataset = { Dataset::new(&args.dataset)? 
}; 75 | 76 | // Setup the model 77 | let config = Config { 78 | seq_len: args.seq_len, 79 | ..Config::default() 80 | }; 81 | let mut model = BitTransformer::load(config, vb, true)?; 82 | 83 | // Setup the optimizer 84 | let mut opt = AdamW::new_lr(varmap.all_vars(), args.learning_rate)?; 85 | 86 | // Setup the dataset, currently using tinystories to replicate the llama example from candle 87 | let iter = DatasetRandomIter::new(&dataset, true, args.seq_len, device.clone()); 88 | let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size); 89 | 90 | // Training loop 91 | let mut training_loss = 0f64; 92 | let mut validation_loss = 0f64; 93 | let mut pb = tqdm!(total = args.max_steps, desc = "Training", position = 0); 94 | for (batch_index, batch) in batch_iter.enumerate() { 95 | pb.update(1)?; 96 | pb.set_postfix(&format!( 97 | "Training Loss: {:.2}, Validation Loss: {:.2}", 98 | training_loss, validation_loss 99 | )); 100 | 101 | if batch_index > args.max_steps { 102 | break; 103 | } 104 | 105 | let span = span!(tracing::Level::TRACE, "training-iteration"); 106 | let _enter = span.enter(); 107 | 108 | let (inp, tgt) = batch?; 109 | let logits = model.forward(&inp, 0)?; 110 | let loss = cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?; 111 | training_loss = match dtype { 112 | candle_core::DType::F32 => f64::from(loss.to_vec0::()?), 113 | candle_core::DType::F16 => f64::from(loss.to_vec0::()?), 114 | candle_core::DType::BF16 => f64::from(loss.to_vec0::()?), 115 | _ => panic!("Invalid dtype"), 116 | }; 117 | opt.backward_step(&loss)?; 118 | 119 | if batch_index > 0 && batch_index % 100 == 0 { 120 | validation_loss = 121 | valid_loss(args.seq_len, args.batch_size, &dataset, &mut model, &device)?; 122 | if batch_index % 10000 == 0 { 123 | let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs(); 124 | let checkpoint_file_name = format!("checkpoint-{:?}.safetensors", timestamp); 125 | varmap.save(checkpoint_file_name)?; 126 | } 127 | varmap.save("checkpoint.safetensors")?; 128 | } 129 | } 130 | 131 | Ok(()) 132 | } 133 | -------------------------------------------------------------------------------- /src/utils_tensor.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use candle_core::utils::{cuda_is_available, metal_is_available}; 3 | use candle_core::{DType, Device, Shape, Tensor, WithDType, D}; 4 | use candle_nn::ops::{self}; 5 | use tracing::instrument; 6 | 7 | // Get the device to use for the tensor operations, only really used for tests 8 | // Originally from: https://github.com/huggingface/candle/blob/314630638d8f6886c07d73211d6c35f8cf05d56a/candle-examples/src/lib.rs#L9 9 | pub fn device(cpu: bool) -> Result { 10 | if cpu { 11 | Ok(Device::Cpu) 12 | } else if cuda_is_available() { 13 | Ok(Device::new_cuda(0)?) 14 | } else if metal_is_available() { 15 | Ok(Device::new_metal(0)?) 
16 | } else { 17 | #[cfg(all(target_os = "macos", target_arch = "aarch64"))] 18 | { 19 | println!( 20 | "Running on CPU, to run on GPU(metal), build this example with `--features metal`" 21 | ); 22 | } 23 | #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))] 24 | { 25 | println!("Running on CPU, to run on GPU, build this example with `--features cuda`"); 26 | } 27 | Ok(Device::Cpu) 28 | } 29 | } 30 | 31 | #[instrument] 32 | pub fn full + std::fmt::Debug, D: WithDType + std::fmt::Debug>( 33 | shape: S, 34 | fill_value: D, 35 | dtype: DType, 36 | device: &Device, 37 | ) -> candle_core::Result { 38 | Tensor::new(&[fill_value], device)? 39 | .to_dtype(dtype)? 40 | .broadcast_as(shape) 41 | } 42 | 43 | #[instrument] 44 | pub fn full_like( 45 | input: &Tensor, 46 | fill_value: D, 47 | ) -> candle_core::Result { 48 | full(input.shape(), fill_value, input.dtype(), input.device()) 49 | } 50 | 51 | #[instrument] 52 | pub fn masked_fill( 53 | xs: &Tensor, 54 | mask: &Tensor, 55 | value: D, 56 | ) -> candle_core::Result { 57 | let on_true = full_like(xs, value)?; 58 | let on_false = xs; 59 | mask.broadcast_as(xs.shape())? 60 | .where_cond(&on_true, on_false) 61 | } 62 | 63 | #[instrument] 64 | fn apply_triangular(xs: &Tensor, diagonal: isize, upper: bool) -> candle_core::Result { 65 | let device = xs.device(); 66 | let (l, s) = xs.dims2()?; 67 | let mut xs_tri = vec![]; 68 | for i in 0..l as isize { 69 | for j in 0..s as isize { 70 | let cond = if upper { 71 | i + diagonal > j 72 | } else { 73 | i + diagonal < j 74 | }; 75 | xs_tri.push(if cond { 0u8 } else { 1u8 }); 76 | } 77 | } 78 | xs * Tensor::from_vec(xs_tri, (l, s), device)?.to_dtype(xs.dtype())? 79 | } 80 | 81 | #[instrument] 82 | pub fn logical_not(xs: &Tensor) -> Result { 83 | let out = xs.where_cond(&xs.zeros_like()?, &xs.ones_like()?)?; 84 | Ok(out) 85 | } 86 | 87 | #[instrument] 88 | pub fn dropout(xs: &Tensor, drop_p: f32) -> candle_core::Result { 89 | // This implementation is inefficient as it stores the full mask for the backward pass. 90 | // Instead we could just store the seed and have a specialized kernel that would both 91 | // generate the random mask and apply it. 92 | // Another easier optimization would be to be able to generate boolean mask using just a bit of 93 | // entropy per element rather than generating a full float per element. 94 | if !(0. ..1.).contains(&drop_p) { 95 | candle_core::bail!("dropout probability has to be in [0, 1), got {drop_p}") 96 | } 97 | let rand = Tensor::rand(0f32, 1f32, xs.shape(), xs.device())?; 98 | let scale = 1.0 / (1.0 - drop_p as f64); 99 | let drop_p = Tensor::new(drop_p, xs.device())?.broadcast_as(xs.shape())?; 100 | let mask = (rand.ge(&drop_p)?.to_dtype(DType::F32)? 
* scale)?.to_dtype(xs.dtype())?; 101 | xs * mask 102 | } 103 | 104 | #[instrument] 105 | pub fn scaled_dot_product_attention( 106 | query: &Tensor, 107 | key: &Tensor, 108 | value: &Tensor, 109 | attn_mask: Option<&Tensor>, 110 | dropout_p: Option, 111 | is_causal: Option, 112 | scale: Option, 113 | ) -> Result { 114 | let device = query.device(); 115 | let l = query.dim(D::Minus2)?; 116 | let s = key.dim(D::Minus2)?; 117 | let dim = query.dim(D::Minus1)?; 118 | 119 | let scale_factor = if let Some(scale) = scale { 120 | scale 121 | } else { 122 | 1.0 / (dim as f64).sqrt() 123 | }; 124 | 125 | let mut attn_bias = Tensor::zeros((l, s), query.dtype(), device)?; 126 | 127 | if matches!(is_causal, Some(true)) { 128 | assert!(attn_mask.is_none(), "scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True"); 129 | let mask = apply_triangular(&Tensor::ones((l, s), DType::U8, device)?, 0, false)?; 130 | attn_bias = masked_fill(&attn_bias, &logical_not(&mask)?, f32::NEG_INFINITY)?; 131 | } 132 | 133 | if let Some(attn_mask) = attn_mask { 134 | if attn_mask.rank() > attn_bias.rank() { 135 | attn_bias = attn_bias.broadcast_as(attn_mask.shape())?; 136 | } 137 | if attn_mask.dtype() == DType::U8 { 138 | // bool 139 | attn_bias = masked_fill(&attn_bias, &logical_not(attn_mask)?, f32::NEG_INFINITY)?; 140 | } else { 141 | attn_bias = (&attn_bias 142 | + attn_mask 143 | .to_dtype(attn_bias.dtype())? 144 | .broadcast_as(attn_bias.shape())?)?; 145 | } 146 | } 147 | 148 | let mut attn_weights = 149 | (query.matmul(&key.transpose(D::Minus2, D::Minus1)?.contiguous()?)? * scale_factor)?; 150 | 151 | attn_weights = (&attn_weights + attn_bias.broadcast_as(attn_weights.shape())?)?; 152 | attn_weights = ops::softmax_last_dim(&attn_weights)?; 153 | if let Some(drop_p) = dropout_p { 154 | attn_weights = if attn_weights.device().is_metal() { 155 | dropout(&attn_weights, drop_p) 156 | } else { 157 | ops::dropout(&attn_weights, drop_p) 158 | }?; 159 | } 160 | let out = attn_weights.matmul(value)?; 161 | Ok(out) 162 | } 163 | 164 | #[instrument] 165 | pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> candle_core::Result { 166 | candle_nn::loss::cross_entropy(inp, target) 167 | } 168 | --------------------------------------------------------------------------------
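For reference, the core of the BitNet scheme implemented above lives in src/bit_linear.rs: weights are quantized with an absmean ternary rule, activations are rescaled into the [-128, 127] range with an absmax rule, and a straight-through estimator (x + (quant(x) - x).detach()) keeps training differentiable. The sketch below isolates those two quantizers so they can be tried outside the crate. It is illustrative only and mirrors the helpers inside Bitlinear::forward; the main function, the tensor shapes, and the 2-D (per-row) activation handling are assumptions rather than code from the repository.

// Standalone, illustrative sketch of the quantizers used in src/bit_linear.rs.
// Not part of the crate: main(), the shapes, and the 2-D handling are assumed here.
use candle_core::{Device, Result, Tensor, D};

// Absmean ternary weight quantization: center the weights, take their sign,
// then rescale by the mean absolute weight.
fn weight_quant(w: &Tensor) -> Result<Tensor> {
    let scale = w.abs()?.mean_all()?;
    let mean = w.mean_all()?;
    w.broadcast_sub(&mean)?.sign()?.broadcast_mul(&scale)
}

// Absmax activation quantization: rescale each row so its largest magnitude
// maps to 127, then clamp into the int8 range.
fn activation_quant(x: &Tensor) -> Result<Tensor> {
    let scale = (127.0 / x.abs()?.max(D::Minus1)?.clamp(1e-5, f32::INFINITY)?)?;
    x.broadcast_mul(&scale.unsqueeze(D::Minus1)?)?.clamp(-128.0, 127.0)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    let w = Tensor::randn(0f32, 1f32, (4, 4), &device)?;
    let x = Tensor::randn(0f32, 1f32, (2, 4), &device)?;
    println!("ternary weights:\n{}", weight_quant(&w)?);
    println!("quantized activations:\n{}", activation_quant(&x)?);
    Ok(())
}

In the crate itself these transforms are applied inside the straight-through estimator before the matmul in Bitlinear::forward, so gradients flow through the full-precision values while the forward pass sees the quantized ones.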