├── .github └── workflows │ ├── publish-docs.yml │ └── validate.yml ├── .gitignore ├── .gitmodules ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md ├── examples ├── mistral │ ├── .gitignore │ ├── Cargo.toml │ ├── README.md │ └── src │ │ ├── main.rs │ │ └── model.rs └── mnist │ ├── Cargo.toml │ ├── data │ ├── t10k-images.idx3-ubyte │ ├── t10k-labels.idx1-ubyte │ ├── train-images.idx3-ubyte │ └── train-labels.idx1-ubyte │ └── src │ ├── data.rs │ ├── main.rs │ └── mlp.rs ├── mlx-internal-macros ├── Cargo.toml └── src │ ├── derive_buildable.rs │ ├── derive_builder.rs │ ├── generate_builder.rs │ ├── generate_macro.rs │ ├── lib.rs │ └── shared.rs ├── mlx-macros ├── Cargo.toml └── src │ ├── lib.rs │ ├── module_parameters.rs │ ├── quantizable.rs │ └── util.rs ├── mlx-rs ├── CHANGELOG.md ├── Cargo.toml ├── README.md ├── examples │ ├── linear_regression.rs │ └── tutorial.rs └── src │ ├── array │ ├── element.rs │ ├── mod.rs │ ├── operators.rs │ └── safetensors.rs │ ├── builder.rs │ ├── device.rs │ ├── dtype.rs │ ├── error.rs │ ├── fast.rs │ ├── fft │ ├── fftn.rs │ ├── mod.rs │ ├── rfftn.rs │ └── utils.rs │ ├── lib.rs │ ├── linalg.rs │ ├── losses.rs │ ├── macros │ ├── array.rs │ ├── assert.rs │ ├── internal.rs │ └── mod.rs │ ├── module │ ├── mod.rs │ ├── module.rs │ └── param.rs │ ├── nested.rs │ ├── nn │ ├── activation.rs │ ├── container.rs │ ├── convolution.rs │ ├── convolution_transpose.rs │ ├── dropout.rs │ ├── embedding.rs │ ├── linear.rs │ ├── mod.rs │ ├── normalization.rs │ ├── pooling.rs │ ├── positional_encoding.rs │ ├── quantized.rs │ ├── recurrent.rs │ ├── transformer.rs │ ├── upsample.rs │ └── value_and_grad.rs │ ├── ops │ ├── arithmetic.rs │ ├── conversion.rs │ ├── convolution.rs │ ├── cumulative.rs │ ├── factory.rs │ ├── indexing │ │ ├── index_impl.rs │ │ ├── indexmut_impl.rs │ │ └── mod.rs │ ├── io.rs │ ├── logical.rs │ ├── mod.rs │ ├── other.rs │ ├── quantization.rs │ ├── reduction.rs │ ├── shapes.rs │ └── sort.rs │ ├── optimizers │ ├── adadelta.rs │ ├── adafactor.rs │ ├── adagrad.rs │ ├── adam.rs │ ├── adamax.rs │ ├── adamw.rs │ ├── lion.rs │ ├── mod.rs │ ├── rmsprop.rs │ └── sgd.rs │ ├── quantization.rs │ ├── random.rs │ ├── stream.rs │ ├── transforms │ ├── compile │ │ ├── compile.rs │ │ ├── compile_with_state.rs │ │ └── mod.rs │ ├── grad.rs │ ├── keyed_value_and_grad.rs │ ├── mod.rs │ └── value_and_grad.rs │ └── utils │ ├── guard.rs │ ├── io.rs │ └── mod.rs ├── mlx-sys ├── CHANGELOG.md ├── Cargo.toml ├── README.md ├── build.rs ├── examples │ └── is_metal_available.rs └── src │ └── lib.rs └── mlx-tests ├── Cargo.toml ├── src └── lib.rs └── tests ├── common.rs ├── test_compile_with_state.rs ├── test_disable_compile.rs ├── test_exported_macros.rs ├── test_generate_builder.rs ├── test_generate_macro.rs ├── test_internal_macros.rs ├── test_module.rs ├── test_module_parameters.rs ├── test_optimizers.rs └── test_quantizable.rs /.github/workflows/publish-docs.yml: -------------------------------------------------------------------------------- 1 | name: publish-docs 2 | on: 3 | push: 4 | branches: 5 | - main 6 | workflow_dispatch: 7 | 8 | permissions: 9 | contents: read 10 | pages: write 11 | id-token: write 12 | 13 | concurrency: 14 | group: "pages" 15 | cancel-in-progress: false 16 | 17 | jobs: 18 | build-docs: 19 | runs-on: blaze/macos-14 20 | concurrency: 21 | group: ${{ github.workflow }}-${{ github.ref }} 22 | steps: 23 | - name: Checkout 24 | uses: actions/checkout@v4 25 | with: 26 | submodules: true 27 | - name: Setup Dependencies 28 | run: brew install gnu-tar 29 | - 
name: Install Rust 30 | uses: actions-rust-lang/setup-rust-toolchain@v1 31 | - name: Build docs 32 | run: | 33 | cargo doc --no-deps 34 | echo "" > target/doc/index.html 35 | - name: Setup Pages 36 | uses: actions/configure-pages@v5 37 | - name: Upload artifact 38 | uses: actions/upload-pages-artifact@v3 39 | with: 40 | path: './target/doc' 41 | deploy-docs: 42 | environment: 43 | name: github-pages 44 | url: ${{ steps.deployment.outputs.page_url }} 45 | runs-on: ubuntu-latest 46 | needs: build-docs 47 | steps: 48 | - name: Deploy to GitHub Pages 49 | id: deployment 50 | uses: actions/deploy-pages@v4 -------------------------------------------------------------------------------- /.github/workflows/validate.yml: -------------------------------------------------------------------------------- 1 | name: validate 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | types: [opened, synchronize] 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }} 11 | cancel-in-progress: true 12 | jobs: 13 | rustfmt-check: 14 | runs-on: blaze/macos-15 15 | steps: 16 | - name: Checkout 17 | uses: actions/checkout@v4 18 | with: 19 | submodules: true 20 | - name: Setup Xcode 21 | run: sudo xcodes select 16.0 22 | - name: Install Rust 23 | uses: actions-rust-lang/setup-rust-toolchain@v1 24 | with: 25 | components: rustfmt, clippy 26 | - name: Run cargo fmt 27 | run: cargo fmt -- --check 28 | - name: Run cargo clippy 29 | run: cargo clippy -- -D warnings 30 | 31 | tests: 32 | runs-on: blaze/macos-15 33 | strategy: 34 | matrix: 35 | rust: [ stable, 1.81.0 ] 36 | include: 37 | - cache: stable 38 | rust: stable 39 | - cache: 1-81-0 40 | rust: 1.81.0 41 | steps: 42 | - name: Checkout 43 | uses: actions/checkout@v4 44 | with: 45 | submodules: true 46 | - name: Setup Xcode 47 | run: sudo xcodes select 16.0 48 | - name: Install Rust 49 | uses: actions-rust-lang/setup-rust-toolchain@v1 50 | with: 51 | cache: false 52 | toolchain: ${{ matrix.rust }} 53 | rustflags: "" # Disable when we're ready 54 | - name: Setup cache 55 | uses: Swatinem/rust-cache@v2 56 | with: 57 | key: ${{ runner.os }}-${{ matrix.cache }}-${{ matrix.backend }}-${{ hashFiles('**/Cargo.toml') }} 58 | - name: Run tests 59 | run: cargo test --all -- --test-threads=1 # MLX is not thread safe -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | debug/ 4 | target/ 5 | 6 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 7 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 8 | Cargo.lock 9 | 10 | # These are backup files generated by rustfmt 11 | **/*.rs.bk 12 | 13 | # MSVC Windows builds of rustc generate these, which store debugging information 14 | *.pdb 15 | 16 | settings.json 17 | **.DS_Store 18 | .idea 19 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "mlx-sys/src/mlx-c"] 2 | path = mlx-sys/src/mlx-c 3 | url = https://github.com/ml-explore/mlx-c.git 4 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace.package] 2 | # All but mlx-sys should follow the same version. 
mlx-sys should follow 3 | # the version of mlx-c. 4 | version = "0.25.0-alpha.1" 5 | edition = "2021" 6 | authors = [ 7 | "Minghua Wu ", 8 | "David Chavez ", 9 | ] 10 | 11 | repository = "https://github.com/oxideai/mlx-rs" 12 | keywords = ["mlx", "deep-learning", "machine-learning"] 13 | categories = ["science"] 14 | license = "MIT OR Apache-2.0" 15 | documentation = "https://oxideai.github.io/mlx-rs/mlx_rs/" 16 | 17 | [workspace] 18 | members = [ 19 | "mlx-macros", 20 | "mlx-sys", 21 | "mlx-rs", 22 | "mlx-internal-macros", 23 | "mlx-tests", 24 | "examples/*", 25 | ] 26 | 27 | resolver = "2" 28 | 29 | [workspace.dependencies] 30 | # workspace local dependencies 31 | mlx-sys = { version = "=0.2.0-alpha.2", path = "mlx-sys" } 32 | mlx-macros = { version = "0.25.0-alpha.1", path = "mlx-macros" } 33 | mlx-internal-macros = { version = "0.25.0-alpha.1", path = "mlx-internal-macros" } 34 | mlx-rs = { version = "0.25.0-alpha.1", path = "mlx-rs" } 35 | 36 | # external dependencies 37 | thiserror = "2" 38 | float_eq = "1" 39 | pretty_assertions = "1.4.0" 40 | dyn-clone = "1" 41 | half = "2" 42 | mach-sys = "0.5" 43 | num-complex = "0.4" 44 | num_enum = "0.7" 45 | num-traits = "0.2" 46 | paste = "1" 47 | smallvec = "1" 48 | strum = { version = "0.26", features = ["derive"] } 49 | libc = "0.2" 50 | parking_lot = "0.12" 51 | tempfile = "3" 52 | itertools = "0.14" 53 | syn = { version = "2", features = ["full"] } 54 | quote = "1" 55 | darling = "0.20" 56 | proc-macro2 = "1" 57 | bindgen = "0.70" 58 | cmake = "0.1" 59 | cc = "1" 60 | safetensors = "0.5" 61 | bytemuck = "1" 62 | memmap2 = "0.9" -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | Permission is hereby granted, free of charge, to any 2 | person obtaining a copy of this software and associated 3 | documentation files (the "Software"), to deal in the 4 | Software without restriction, including without 5 | limitation the rights to use, copy, modify, merge, 6 | publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software 8 | is furnished to do so, subject to the following 9 | conditions: 10 | 11 | The above copyright notice and this permission notice 12 | shall be included in all copies or substantial portions 13 | of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 16 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 17 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 18 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT 19 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 22 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 23 | DEALINGS IN THE SOFTWARE. 
24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | mlx-rs/README.md -------------------------------------------------------------------------------- /examples/mistral/.gitignore: -------------------------------------------------------------------------------- 1 | /cache/ 2 | /data/ 3 | .env -------------------------------------------------------------------------------- /examples/mistral/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "mistral" 3 | edition = "2021" 4 | version.workspace = true 5 | authors.workspace = true 6 | 7 | [dependencies] 8 | # Local dependencies 9 | mlx-rs.workspace = true 10 | 11 | # External dependencies 12 | tokenizers = "=0.21.0" # 0.21.1 uses features that went stable in 1.82 while our MSRV is 1.81 13 | thiserror = "1.0" 14 | anyhow = "1.0" 15 | hf-hub = "=0.4.1" # 0.4.2 uses features that went stable in 1.82 while our MSRV is 1.81 16 | dotenv = "0.15" 17 | serde = { version = "1", features = ["derive"] } 18 | serde_json = "1" 19 | clap = { version = "4", features = ["derive"] } 20 | safetensors.workspace = true 21 | 22 | # Fix idna-adapter version so that it works with rustc 1.81 23 | idna_adapter = "=1.2.0" -------------------------------------------------------------------------------- /examples/mistral/README.md: -------------------------------------------------------------------------------- 1 | # Mistral 2 | 3 | An example of generating text with the Mistral-7B-v0.1 model using mlx-rs. 4 | 5 | This is the Rust version of the [mlx-examples/llms/mistral](https://github.com/ml-explore/mlx-examples/tree/main/llms/mistral) example. 6 | 7 | ## Usage 8 | 9 | This example loads the safetensors version of the model from a Hugging Face repo and thus requires an internet connection the first time it is run. 10 | 11 | To run the example in release mode, execute the following command: 12 | 13 | ```bash 14 | cargo run --release 15 | ``` 16 | 17 | ### Arguments 18 | 19 | The example accepts the following optional arguments: 20 | 21 | - `--prompt: str` - The message to be processed by the model. Default: "In the beginning the Universe was created." 22 | - `--max-tokens: int` - The maximum number of tokens to generate. Default: 100 23 | - `--temp: float` - The sampling temperature. Default: 0.0 24 | - `--tokens-per-eval: int` - The batch size of tokens to generate. Default: 10 25 | - `--seed: int` - The PRNG seed. Default: 0 26 | 27 | For example, to generate text with a prompt "Hello, world!" and a seed of 1 (in release mode), run the following command: 28 | 29 | ```bash 30 | cargo run --release -- --prompt "Hello, world!" 
--seed 1 31 | ``` 32 | -------------------------------------------------------------------------------- /examples/mistral/src/main.rs: -------------------------------------------------------------------------------- 1 | use hf_hub::{ 2 | api::sync::{Api, ApiBuilder, ApiRepo}, 3 | Repo, 4 | }; 5 | use mlx_rs::{ 6 | array, 7 | module::{Module, ModuleParametersExt}, 8 | ops::indexing::{argmax_axis, IndexOp, NewAxis}, 9 | random::categorical, 10 | transforms::eval, 11 | Array, 12 | }; 13 | use tokenizers::Tokenizer; 14 | 15 | mod model; 16 | 17 | use model::{Mistral, MistralInput, MistralOutput, ModelArgs}; 18 | 19 | type Error = Box<dyn std::error::Error>; 20 | type Result<T> = std::result::Result<T, Error>; 21 | 22 | use clap::Parser; 23 | 24 | #[derive(Parser)] 25 | #[command(about = "Mistral inference example")] 26 | pub struct Cli { 27 | /// The message to be processed by the model 28 | #[clap(long, default_value = "In the beginning the Universe was created.")] 29 | prompt: String, 30 | 31 | /// Maximum number of tokens to generate 32 | #[clap(long, default_value = "100")] 33 | max_tokens: usize, 34 | 35 | /// The sampling temperature 36 | #[clap(long, default_value = "0.0")] 37 | temp: f32, 38 | 39 | /// The batch size of tokens to generate 40 | #[clap(long, default_value = "10")] 41 | tokens_per_eval: usize, 42 | 43 | /// The PRNG seed 44 | #[clap(long, default_value = "0")] 45 | seed: u64, 46 | } 47 | 48 | fn build_hf_api() -> Result<Api> { 49 | let cache_dir = std::env::var("HF_CACHE_DIR").ok(); 50 | 51 | let mut builder = ApiBuilder::new(); 52 | if let Some(cache_dir) = cache_dir { 53 | builder = builder.with_cache_dir(cache_dir.into()); 54 | } 55 | builder.build().map_err(Into::into) 56 | } 57 | 58 | fn get_tokenizer(repo: &ApiRepo) -> Result<Tokenizer> { 59 | let tokenizer_filename = repo.get("tokenizer.json")?; 60 | let t = Tokenizer::from_file(tokenizer_filename)?; 61 | 62 | Ok(t) 63 | } 64 | 65 | fn get_model_args(repo: &ApiRepo) -> Result<ModelArgs> { 66 | let model_args_filename = repo.get("params.json")?; 67 | let file = std::fs::File::open(model_args_filename)?; 68 | let model_args: ModelArgs = serde_json::from_reader(file)?; 69 | 70 | Ok(model_args) 71 | } 72 | 73 | fn load_model(repo: &ApiRepo) -> Result<Mistral> { 74 | let model_args = get_model_args(repo)?; 75 | let mut model = Mistral::new(&model_args)?; 76 | let weights_filename = repo.get("weights.safetensors")?; 77 | model.load_safetensors(weights_filename)?; 78 | 79 | Ok(model) 80 | } 81 | 82 | fn sample(logits: &Array, temp: f32) -> Result<Array> { 83 | match temp { 84 | 0.0 => argmax_axis(logits, -1, None).map_err(Into::into), 85 | _ => { 86 | let logits = logits.multiply(array!(1.0 / temp))?; 87 | categorical(logits, None, None, None).map_err(Into::into) 88 | } 89 | } 90 | } 91 | // Like `?`, but usable inside `Iterator::next`: returns `Some(Err(..))` on failure. 92 | macro_rules! 
tri { 93 | ($expr:expr) => { 94 | match $expr { 95 | Ok(val) => val, 96 | Err(e) => return Some(Err(e.into())), 97 | } 98 | }; 99 | } 100 | 101 | struct Generate<'a> { 102 | model: &'a mut Mistral, 103 | temp: f32, 104 | state: GenerateState<'a>, 105 | } 106 | 107 | enum GenerateState<'a> { 108 | Start { 109 | prompt_token: &'a Array, 110 | }, 111 | Continue { 112 | y: Array, 113 | cache: Vec>, 114 | }, 115 | } 116 | 117 | impl<'a> Generate<'a> { 118 | pub fn new(model: &'a mut Mistral, prompt_token: &'a Array, temp: f32) -> Self { 119 | Self { 120 | model, 121 | temp, 122 | state: GenerateState::Start { prompt_token }, 123 | } 124 | } 125 | } 126 | 127 | impl Iterator for Generate<'_> { 128 | type Item = Result; 129 | 130 | fn next(&mut self) -> Option { 131 | match &self.state { 132 | GenerateState::Start { prompt_token } => { 133 | let initial_cache = Vec::with_capacity(0); // This won't allocate 134 | let input = MistralInput { 135 | inputs: prompt_token, 136 | cache: &initial_cache, 137 | }; 138 | let MistralOutput { logits, cache } = tri!(self.model.forward(input)); 139 | let y = tri!(sample(&logits.index((.., -1, ..)), self.temp)); 140 | 141 | self.state = GenerateState::Continue { 142 | y: y.clone(), 143 | cache, 144 | }; 145 | 146 | Some(Ok(y)) 147 | } 148 | GenerateState::Continue { y, cache } => { 149 | let next_token = y.index((.., NewAxis)); 150 | let input = MistralInput { 151 | inputs: &next_token, 152 | cache: cache.as_slice(), 153 | }; 154 | let MistralOutput { 155 | logits, 156 | cache: new_cache, 157 | } = tri!(self.model.forward(input)); 158 | 159 | let logits = tri!(logits.squeeze_axes(&[1])); 160 | let y = tri!(sample(&logits, self.temp)); 161 | 162 | self.state = GenerateState::Continue { 163 | y: y.clone(), 164 | cache: new_cache, 165 | }; 166 | 167 | Some(Ok(y)) 168 | } 169 | } 170 | } 171 | } 172 | 173 | fn main() -> Result<()> { 174 | // If you want to manually set the cache directory, you can set the HF_CACHE_DIR 175 | // environment variable or put it in a .env file located at the root of this example 176 | // (ie. examples/mistral/.env) 177 | let _ = dotenv::dotenv(); 178 | let api = build_hf_api()?; 179 | 180 | // Parse args 181 | let cli = Cli::parse(); 182 | 183 | mlx_rs::random::seed(cli.seed)?; 184 | 185 | // The model used in the original example is converted to safetensors and 186 | // uploaded to the huggingface hub 187 | let model_id = "minghuaw/Mistral-7B-v0.1".to_string(); 188 | let repo = api.repo(Repo::new(model_id, hf_hub::RepoType::Model)); 189 | println!("[INFO] Loading model... 
"); 190 | let tokenizer = get_tokenizer(&repo)?; 191 | let mut model = load_model(&repo)?; 192 | 193 | model = mlx_rs::nn::quantize(model, None, None)?; 194 | 195 | let encoding = tokenizer.encode(&cli.prompt[..], true)?; 196 | let prompt_tokens = Array::from(encoding.get_ids()).index(NewAxis); 197 | print!("{}", cli.prompt); 198 | 199 | let generate = Generate::new(&mut model, &prompt_tokens, cli.temp); 200 | let mut tokens = Vec::with_capacity(cli.max_tokens); 201 | for (token, ntoks) in generate.zip(0..cli.max_tokens) { 202 | let token = token?; 203 | tokens.push(token); 204 | 205 | if ntoks == 0 { 206 | eval(&tokens)?; 207 | } 208 | 209 | if tokens.len() % cli.tokens_per_eval == 0 { 210 | eval(&tokens)?; 211 | let slice: Vec = tokens.drain(..).map(|t| t.item::()).collect(); 212 | let s = tokenizer.decode(&slice, true)?; 213 | print!("{}", s); 214 | } 215 | } 216 | 217 | eval(&tokens)?; 218 | let slice: Vec = tokens.drain(..).map(|t| t.item::()).collect(); 219 | let s = tokenizer.decode(&slice, true)?; 220 | println!("{}", s); 221 | 222 | println!("------"); 223 | 224 | Ok(()) 225 | } 226 | -------------------------------------------------------------------------------- /examples/mnist/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "mnist" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | mlx-rs.workspace = true 8 | mnist = "0.6" -------------------------------------------------------------------------------- /examples/mnist/data/t10k-images.idx3-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxideai/mlx-rs/093d5afe76c2747c7f2e7ac85be66a41ea602778/examples/mnist/data/t10k-images.idx3-ubyte -------------------------------------------------------------------------------- /examples/mnist/data/train-images.idx3-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxideai/mlx-rs/093d5afe76c2747c7f2e7ac85be66a41ea602778/examples/mnist/data/train-images.idx3-ubyte -------------------------------------------------------------------------------- /examples/mnist/data/train-labels.idx1-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxideai/mlx-rs/093d5afe76c2747c7f2e7ac85be66a41ea602778/examples/mnist/data/train-labels.idx1-ubyte -------------------------------------------------------------------------------- /examples/mnist/src/data.rs: -------------------------------------------------------------------------------- 1 | // TODO 2 | 3 | use mlx_rs::{error::Exception, ops::stack_axis, Array}; 4 | use mnist::{Mnist, MnistBuilder}; 5 | 6 | const IMAGE_SIZE: usize = 28 * 28; 7 | 8 | pub fn read_data() -> (Vec, Vec, Array, Array) { 9 | let Mnist { 10 | trn_img, 11 | trn_lbl, 12 | val_img: _, 13 | val_lbl: _, 14 | tst_img, 15 | tst_lbl, 16 | } = MnistBuilder::new() 17 | .label_format_digit() 18 | .base_path("data") 19 | .training_images_filename("train-images.idx3-ubyte") 20 | .training_labels_filename("train-labels.idx1-ubyte") 21 | .test_images_filename("t10k-images.idx3-ubyte") 22 | .test_labels_filename("t10k-labels.idx1-ubyte") 23 | .finalize(); 24 | 25 | // Check size 26 | assert_eq!(trn_img.len(), trn_lbl.len() * IMAGE_SIZE); 27 | assert_eq!(tst_img.len(), tst_lbl.len() * IMAGE_SIZE); 28 | 29 | // Convert to Array 30 | let train_images = trn_img 31 | .chunks_exact(IMAGE_SIZE) 32 | 
.map(|chunk| Array::from_slice(chunk, &[IMAGE_SIZE as i32])) 33 | .collect(); 34 | 35 | let test_images = tst_img 36 | .chunks_exact(IMAGE_SIZE) 37 | .map(|chunk| Array::from_slice(chunk, &[IMAGE_SIZE as i32])) 38 | .collect::<Vec<_>>(); 39 | let test_images = stack_axis(&test_images, 0).unwrap(); 40 | 41 | let test_labels = Array::from_slice(&tst_lbl, &[tst_lbl.len() as i32]); 42 | 43 | (train_images, trn_lbl, test_images, test_labels) 44 | } 45 | 46 | /// The iterator is collected to avoid repeated calls to `stack` in the training loop. 47 | pub fn iterate_data<'a>( 48 | images: &'a [Array], 49 | labels: &'a [u8], 50 | batch_size: usize, 51 | ) -> Result<Vec<(Array, Array)>, Exception> { 52 | images 53 | .chunks_exact(batch_size) 54 | .zip(labels.chunks_exact(batch_size)) 55 | .map(move |(images, labels)| { 56 | let images = stack_axis(images, 0)?; 57 | let labels = Array::from_slice(labels, &[batch_size as i32]); 58 | Ok((images, labels)) 59 | }) 60 | .collect() 61 | } 62 | -------------------------------------------------------------------------------- /examples/mnist/src/main.rs: -------------------------------------------------------------------------------- 1 | use mlx_rs::{ 2 | builder::Builder, 3 | error::Exception, 4 | losses::{CrossEntropyBuilder, LossReduction}, 5 | module::{Module, ModuleParameters}, 6 | nn, 7 | ops::{eq, indexing::argmax_axis, mean}, 8 | optimizers::{Optimizer, Sgd}, 9 | transforms::eval_params, 10 | Array, 11 | }; 12 | 13 | /// MLP model 14 | mod mlp; 15 | 16 | /// Retrieves MNIST dataset 17 | mod data; 18 | 19 | fn eval_fn(model: &mut mlp::Mlp, (x, y): (&Array, &Array)) -> Result<Array, Exception> { 20 | let y_pred = model.forward(x)?; 21 | let accuracy = mean(&eq(&argmax_axis(&y_pred, 1, None)?, y)?, None)?; 22 | Ok(accuracy) 23 | } 24 | 25 | fn main() -> Result<(), Box<dyn std::error::Error>> { 26 | let num_layers = 2; 27 | let hidden_dim = 32; 28 | let num_classes = 10; 29 | let batch_size = 256; 30 | let num_epochs = 10; 31 | let learning_rate = 1e-2; 32 | 33 | let (train_images, train_labels, test_images, test_labels) = data::read_data(); 34 | let loader = data::iterate_data(&train_images, &train_labels, batch_size)?; 35 | 36 | let input_dim = train_images[0].shape()[0]; 37 | let mut model = mlp::Mlp::new(num_layers, input_dim, hidden_dim, num_classes)?; 38 | 39 | let cross_entropy = CrossEntropyBuilder::new() 40 | .reduction(LossReduction::Mean) 41 | .build()?; 42 | let loss_fn = |model: &mut mlp::Mlp, (x, y): (&Array, &Array)| -> Result<Array, Exception> { 43 | let y_pred = model.forward(x)?; 44 | cross_entropy.apply(y_pred, y) 45 | }; 46 | let mut loss_and_grad_fn = nn::value_and_grad(loss_fn); 47 | 48 | let mut optimizer = Sgd::new(learning_rate); 49 | 50 | for e in 0..num_epochs { 51 | let now = std::time::Instant::now(); 52 | for (x, y) in &loader { 53 | let (_loss, grad) = loss_and_grad_fn(&mut model, (x, y))?; 54 | optimizer.update(&mut model, grad).unwrap(); 55 | eval_params(model.parameters())?; 56 | } 57 | 58 | // Evaluate on test set 59 | let accuracy = eval_fn(&mut model, (&test_images, &test_labels))?; 60 | let elapsed = now.elapsed(); 61 | println!( 62 | "Epoch: {}, Test accuracy: {:.2}, Time: {:.2} s", 63 | e, 64 | accuracy.item::<f32>(), 65 | elapsed.as_secs_f32() 66 | ); 67 | } 68 | 69 | Ok(()) 70 | } 71 | -------------------------------------------------------------------------------- /examples/mnist/src/mlp.rs: -------------------------------------------------------------------------------- 1 | use mlx_rs::{ 2 | error::Exception, 3 | macros::ModuleParameters, 4 | module::Module, 5 | nn::{Linear, Relu, Sequential}, 
Array, 7 | }; 8 | 9 | #[derive(Debug, ModuleParameters)] 10 | pub struct Mlp { 11 | #[param] 12 | pub layers: Sequential, 13 | } 14 | 15 | impl Module<&Array> for Mlp { 16 | type Error = Exception; 17 | type Output = Array; 18 | 19 | fn forward(&mut self, x: &Array) -> Result<Self::Output, Self::Error> { 20 | self.layers.forward(x) 21 | } 22 | 23 | fn training_mode(&mut self, mode: bool) { 24 | self.layers.training_mode(mode); 25 | } 26 | } 27 | 28 | impl Mlp { 29 | pub fn new( 30 | num_layers: usize, 31 | input_dim: i32, 32 | hidden_dim: i32, 33 | output_dim: i32, 34 | ) -> Result<Self, Exception> { 35 | let mut layers = Sequential::new(); 36 | 37 | // Add the input layer 38 | layers = layers 39 | .append(Linear::new(input_dim, hidden_dim)?) 40 | .append(Relu); 41 | 42 | // Add the hidden layers 43 | for _ in 1..num_layers { 44 | layers = layers 45 | .append(Linear::new(hidden_dim, hidden_dim)?) 46 | .append(Relu); 47 | } 48 | 49 | // Add the output layer 50 | layers = layers.append(Linear::new(hidden_dim, output_dim)?); 51 | 52 | Ok(Self { layers }) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /mlx-internal-macros/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "mlx-internal-macros" 3 | version.workspace = true 4 | authors.workspace = true 5 | edition.workspace = true 6 | repository.workspace = true 7 | keywords.workspace = true 8 | categories.workspace = true 9 | license.workspace = true 10 | documentation.workspace = true 11 | description = "Internal procedural macros for mlx-rs" 12 | 13 | [lib] 14 | proc-macro = true 15 | 16 | [dependencies] 17 | syn.workspace = true 18 | quote.workspace = true 19 | darling.workspace = true 20 | proc-macro2.workspace = true 21 | itertools.workspace = true 22 | -------------------------------------------------------------------------------- /mlx-internal-macros/src/derive_buildable.rs: -------------------------------------------------------------------------------- 1 | use darling::FromDeriveInput; 2 | use quote::quote; 3 | use syn::DeriveInput; 4 | 5 | use crate::shared::{PathOrIdent, Result}; 6 | 7 | #[derive(Debug, Clone, FromDeriveInput)] 8 | #[darling(attributes(buildable))] 9 | #[allow(dead_code)] 10 | pub(crate) struct StructProperty { 11 | pub ident: syn::Ident, 12 | 13 | /// Generate builder if None 14 | pub builder: Option<syn::Path>, 15 | 16 | /// Rename `mlx_rs` if Some(_) 17 | pub root: Option<syn::Path>, 18 | } 19 | 20 | pub(crate) fn expand_derive_buildable(input: DeriveInput) -> Result<proc_macro2::TokenStream> { 21 | let struct_prop = StructProperty::from_derive_input(&input)?; 22 | let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl(); 23 | 24 | let struct_ident = &struct_prop.ident; 25 | let builder_ident = syn::Ident::new(&format!("{}Builder", struct_ident), struct_ident.span()); 26 | let root = match struct_prop.root { 27 | Some(path) => path, 28 | None => syn::parse_quote!(::mlx_rs), 29 | }; 30 | 31 | let struct_builder_ident = match &struct_prop.builder { 32 | Some(path) => PathOrIdent::Path(path.clone()), 33 | None => PathOrIdent::Ident(builder_ident), 34 | }; 35 | 36 | let impl_buildable = quote! { 37 | impl #impl_generics #root::builder::Buildable for #struct_ident #type_generics #where_clause { 38 | type Builder = #struct_builder_ident #type_generics; 39 | } 40 | }; 41 | 42 | Ok(quote! 
{ 43 | #impl_buildable 44 | }) 45 | } 46 | -------------------------------------------------------------------------------- /mlx-internal-macros/src/derive_builder.rs: -------------------------------------------------------------------------------- 1 | use darling::FromDeriveInput; 2 | use quote::quote; 3 | use syn::{DeriveInput, Ident}; 4 | 5 | use crate::shared::{ 6 | parse_fields_from_derive_input, BuilderStructAnalyzer, BuilderStructProperty, PathOrIdent, 7 | Result, 8 | }; 9 | 10 | pub(crate) fn expand_derive_builder(input: DeriveInput) -> Result<proc_macro2::TokenStream> { 11 | let builder_struct_prop = BuilderStructProperty::from_derive_input(&input)?; 12 | let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl(); 13 | 14 | let builder_ident = &builder_struct_prop.ident; 15 | if !is_builder_struct_end_with_builder(builder_ident) { 16 | return Err("Builder struct must end with 'Builder'".into()); 17 | } 18 | let builder_ident_str = builder_ident.to_string(); 19 | let struct_ident = Ident::new( 20 | // We have already checked that the builder struct ends with 'Builder' 21 | &builder_ident_str[..builder_ident_str.len() - "Builder".len()], 22 | builder_ident.span(), 23 | ); 24 | let root = match builder_struct_prop.root { 25 | Some(path) => path, 26 | None => syn::parse_quote!(::mlx_rs), 27 | }; 28 | 29 | let builder_struct_ident = PathOrIdent::Ident(builder_ident.clone()); 30 | let (mandatory_fields, optional_fields) = parse_fields_from_derive_input(&input)?; 31 | let is_default_infallible = builder_struct_prop 32 | .default_infallible 33 | .unwrap_or_else(|| builder_struct_prop.err.is_none()); 34 | 35 | let builder_struct_analyzer = BuilderStructAnalyzer { 36 | struct_ident: &struct_ident, 37 | builder_struct_ident: &builder_struct_ident, 38 | root: &root, 39 | impl_generics: &impl_generics, 40 | type_generics: &type_generics, 41 | where_clause, 42 | mandatory_fields: &mandatory_fields, 43 | optional_fields: &optional_fields, 44 | build_with: builder_struct_prop.build_with.as_ref(), 45 | err: builder_struct_prop.err.as_ref(), 46 | }; 47 | 48 | let impl_builder = builder_struct_analyzer.impl_builder(); 49 | let impl_struct_new = builder_struct_analyzer.impl_struct_new(is_default_infallible); 50 | 51 | Ok(quote! 
{ 52 | #impl_builder 53 | #impl_struct_new 54 | }) 55 | } 56 | 57 | fn is_builder_struct_end_with_builder(ident: &Ident) -> bool { 58 | ident.to_string().ends_with("Builder") 59 | } 60 | -------------------------------------------------------------------------------- /mlx-internal-macros/src/generate_builder.rs: -------------------------------------------------------------------------------- 1 | use darling::FromDeriveInput; 2 | use proc_macro2::TokenTree; 3 | use quote::quote; 4 | use syn::DeriveInput; 5 | 6 | use crate::{ 7 | derive_buildable::StructProperty, 8 | shared::{ 9 | parse_fields_from_derive_input, BuilderStructAnalyzer, BuilderStructProperty, PathOrIdent, 10 | Result, 11 | }, 12 | }; 13 | 14 | pub(crate) fn expand_generate_builder(input: &DeriveInput) -> Result<proc_macro2::TokenStream> { 15 | // Make sure the struct does NOT have #[derive(Default)] 16 | if struct_attr_derive_default(&input.attrs) { 17 | return Err("Struct with #[derive(Default)] cannot derive Buildable".into()); 18 | } 19 | 20 | let struct_prop = StructProperty::from_derive_input(input)?; 21 | let builder_struct_prop = BuilderStructProperty::from_derive_input(input)?; 22 | let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl(); 23 | 24 | let struct_ident = &struct_prop.ident; 25 | let builder_struct_ident = 26 | syn::Ident::new(&format!("{}Builder", struct_ident), struct_ident.span()); 27 | let root = match struct_prop.root { 28 | Some(path) => path, 29 | None => syn::parse_quote!(::mlx_rs), 30 | }; 31 | 32 | let (mandatory_fields, optional_fields) = parse_fields_from_derive_input(input)?; 33 | let is_default_infallible = builder_struct_prop 34 | .default_infallible 35 | .unwrap_or_else(|| builder_struct_prop.err.is_none()); 36 | 37 | let builder_struct_ident = match &struct_prop.builder { 38 | Some(path) => PathOrIdent::Path(path.clone()), 39 | None => PathOrIdent::Ident(builder_struct_ident.clone()), 40 | }; 41 | let builder_struct_analyzer = BuilderStructAnalyzer { 42 | struct_ident, 43 | builder_struct_ident: &builder_struct_ident, 44 | root: &root, 45 | impl_generics: &impl_generics, 46 | type_generics: &type_generics, 47 | where_clause, 48 | mandatory_fields: &mandatory_fields, 49 | optional_fields: &optional_fields, 50 | build_with: builder_struct_prop.build_with.as_ref(), 51 | err: builder_struct_prop.err.as_ref(), 52 | }; 53 | let builder_struct = if struct_prop.builder.is_none() { 54 | builder_struct_analyzer.generate_builder_struct() 55 | } else { 56 | quote! {} 57 | }; 58 | let impl_builder = builder_struct_analyzer.impl_builder(); 59 | let impl_struct_new = builder_struct_analyzer.impl_struct_new(is_default_infallible); 60 | 61 | Ok(quote! 
{ 62 | #builder_struct 63 | #impl_builder 64 | #impl_struct_new 65 | }) 66 | } 67 | 68 | fn struct_attr_derive_default(attrs: &[syn::Attribute]) -> bool { 69 | attrs 70 | .iter() 71 | .filter_map(|attr| { 72 | if attr.path().is_ident("derive") { 73 | attr.meta 74 | .require_list() 75 | .map(|list| list.tokens.clone()) 76 | .ok() 77 | } else { 78 | None 79 | } 80 | }) 81 | .any(|tokens| { 82 | tokens.into_iter().any(|tree| { 83 | if let TokenTree::Ident(ident) = tree { 84 | ident == "Default" 85 | } else { 86 | false 87 | } 88 | }) 89 | }) 90 | } 91 | -------------------------------------------------------------------------------- /mlx-macros/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "mlx-macros" 3 | version.workspace = true 4 | authors.workspace = true 5 | edition.workspace = true 6 | repository.workspace = true 7 | keywords.workspace = true 8 | categories.workspace = true 9 | license.workspace = true 10 | documentation.workspace = true 11 | 12 | description = "Procedural macros for mlx-rs" 13 | 14 | [lib] 15 | proc-macro = true 16 | 17 | [dependencies] 18 | syn.workspace = true 19 | quote.workspace = true 20 | darling.workspace = true 21 | proc-macro2.workspace = true 22 | -------------------------------------------------------------------------------- /mlx-macros/src/lib.rs: -------------------------------------------------------------------------------- 1 | extern crate proc_macro; 2 | use proc_macro::TokenStream; 3 | use syn::{parse_macro_input, DeriveInput}; 4 | 5 | mod module_parameters; 6 | mod quantizable; 7 | mod util; 8 | 9 | /// Derive the `ModuleParameters` trait for a struct. Mark a field with 10 | /// `#[param]` attribute to include it in the parameters. The field type must 11 | /// implement the `mlx_rs::module::Parameter` trait. 12 | /// 13 | /// # Example 14 | /// 15 | /// ```rust, ignore 16 | /// use mlx_macros::ModuleParameters; 17 | /// use mlx_rs::module::{ModuleParameters, Param}; 18 | /// 19 | /// #[derive(ModuleParameters)] 20 | /// struct Example { 21 | ///     #[param] 22 | ///     regular: Param<Array>, 23 | /// 24 | ///     #[param] 25 | ///     optional: Param<Option<Array>>, 26 | /// 27 | ///     #[param] 28 | ///     nested: Inner, 29 | /// 30 | ///     #[param] 31 | ///     vec_nested: Vec<Inner>, 32 | /// 33 | ///     #[param] 34 | ///     trait_object: Box<dyn Module>, 35 | /// 36 | ///     #[param] 37 | ///     trait_object_vec: Vec<Box<dyn Module>>, 38 | /// } 39 | /// 40 | /// #[derive(ModuleParameters)] 41 | /// struct Inner { 42 | ///     #[param] 43 | ///     a: Param<Array>, 44 | /// } 45 | /// ``` 46 | #[proc_macro_derive(ModuleParameters, attributes(module, param))] 47 | pub fn derive_module_parameters(input: TokenStream) -> TokenStream { 48 | let input = parse_macro_input!(input as DeriveInput); 49 | let module_param_impl = module_parameters::expand_module_parameters(&input).unwrap(); 50 | TokenStream::from(module_param_impl) 51 | } 52 | 53 | /// Derive the `Quantizable` trait for a struct. Mark a field with 54 | /// `#[quantizable]` attribute to include it in the quantization process. 55 | /// Only types `M` where `M::Quantized = M` are supported. 56 | /// 57 | /// See `mlx-rs/mlx-tests/tests/test_quantizable.rs` for example usage. 58 | /// 59 | /// # Panics 60 | /// 61 | /// This macro will panic if the struct does not have any field marked with 62 | /// `#[quantizable]`. 
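///
/// # Example
///
/// A minimal sketch of the intended usage. The wrapper type `MaybeQuantized`
/// here stands for any type whose `Quantized` associated type is itself; the
/// exact helper type and import path are assumptions, so refer to the test
/// file above for the canonical usage:
///
/// ```rust, ignore
/// use mlx_macros::Quantizable;
/// use mlx_rs::nn::Linear;
/// use mlx_rs::quantization::MaybeQuantized;
///
/// #[derive(Quantizable)]
/// struct Example {
///     // Converted in place when `try_into_quantized` is called
///     #[quantizable]
///     linear: MaybeQuantized<Linear>,
///
///     // Fields without the attribute are moved over unchanged
///     scale: f32,
/// }
/// ```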
63 | #[proc_macro_derive(Quantizable, attributes(quantizable))] 64 | pub fn derive_quantizable_module(input: TokenStream) -> TokenStream { 65 | let input = parse_macro_input!(input as DeriveInput); 66 | let quantizable_module_impl = quantizable::expand_quantizable(&input).unwrap(); 67 | TokenStream::from(quantizable_module_impl) 68 | } 69 | -------------------------------------------------------------------------------- /mlx-macros/src/module_parameters.rs: -------------------------------------------------------------------------------- 1 | use darling::FromDeriveInput; 2 | use syn::{DataStruct, DeriveInput, Generics, Ident}; 3 | 4 | use crate::util::filter_fields_with_attr; 5 | 6 | #[derive(Debug, Clone, FromDeriveInput)] 7 | #[darling(attributes(module))] 8 | struct ModuleProperties { 9 | root: Option<syn::Path>, 10 | } 11 | 12 | pub(crate) fn expand_module_parameters( 13 | input: &DeriveInput, 14 | ) -> Result<proc_macro2::TokenStream, syn::Error> { 15 | let prop = ModuleProperties::from_derive_input(input)?; 16 | let struct_ident = &input.ident; 17 | let generics = &input.generics; 18 | match &input.data { 19 | syn::Data::Struct(data) => { 20 | expand_module_parameters_for_struct(struct_ident, generics, data, prop.root) 21 | } 22 | _ => Err(syn::Error::new_spanned( 23 | input, 24 | "ModuleParameters can only be derived for structs", 25 | )), 26 | } 27 | } 28 | 29 | fn expand_module_parameters_for_struct( 30 | ident: &Ident, 31 | generics: &Generics, 32 | data: &DataStruct, 33 | root: Option<syn::Path>, 34 | ) -> Result<proc_macro2::TokenStream, syn::Error> { 35 | let fields = filter_fields_with_attr(&data.fields, "param")?; 36 | 37 | Ok(impl_module_parameters_for_struct( 38 | ident, 39 | generics, 40 | fields.filtered, 41 | root, 42 | )) 43 | } 44 | 45 | fn impl_module_parameters_for_struct( 46 | ident: &Ident, 47 | generics: &Generics, 48 | fields: Vec<&syn::Field>, 49 | root: Option<syn::Path>, 50 | ) -> proc_macro2::TokenStream { 51 | let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); 52 | let field_names: Vec<_> = fields.iter().map(|field| &field.ident).collect(); 53 | 54 | // Returns None if there are no fields 55 | let default_all_frozen = match field_names.len() { 56 | 0 => quote::quote! { None }, 57 | _ => quote::quote! { Some(true) }, 58 | }; 59 | 60 | // Returns None if there are no fields 61 | let default_any_frozen = match field_names.len() { 62 | 0 => quote::quote! { None }, 63 | _ => quote::quote! { Some(false) }, 64 | }; 65 | 66 | let (extern_import, root) = match root { 67 | Some(root) => (quote::quote! {}, quote::quote! { #root }), 68 | None => ( 69 | quote::quote! { extern crate mlx_rs as _mlx_rs; }, 70 | quote::quote! { _mlx_rs }, 71 | ), 72 | }; 73 | 74 | quote::quote! 
{ 75 | const _: () = { 76 | #extern_import 77 | impl #impl_generics #root::module::ModuleParameters for #ident #ty_generics #where_clause { 78 | fn num_parameters(&self) -> usize { 79 | use #root::module::Parameter; 80 | let mut count = 0; 81 | #( 82 | count += self.#field_names.count(); 83 | )* 84 | count 85 | } 86 | 87 | fn freeze_parameters(&mut self, recursive: bool) { 88 | use #root::module::Parameter; 89 | #(self.#field_names.freeze(recursive);)* 90 | } 91 | 92 | fn unfreeze_parameters(&mut self, recursive: bool) { 93 | use #root::module::Parameter; 94 | #(self.#field_names.unfreeze(recursive);)* 95 | } 96 | 97 | fn parameters(&self) -> #root::module::ModuleParamRef<'_> { 98 | let mut parameters = #root::nested::NestedHashMap::new(); 99 | #(parameters.insert(std::rc::Rc::from(stringify!(#field_names)), #root::module::Parameter::as_nested_value(&self.#field_names));)* 100 | parameters 101 | } 102 | 103 | fn parameters_mut(&mut self) -> #root::module::ModuleParamMut<'_> { 104 | let mut parameters = #root::nested::NestedHashMap::new(); 105 | #(parameters.insert(std::rc::Rc::from(stringify!(#field_names)), #root::module::Parameter::as_nested_value_mut(&mut self.#field_names));)* 106 | parameters 107 | } 108 | 109 | fn trainable_parameters(&self) -> #root::module::ModuleParamRef<'_> { 110 | let mut parameters = #root::nested::NestedHashMap::new(); 111 | #( 112 | if let Some(field) = #root::module::Parameter::as_trainable_nested_value(&self.#field_names) { 113 | parameters.insert(std::rc::Rc::from(stringify!(#field_names)), field); 114 | } 115 | )* 116 | parameters 117 | } 118 | 119 | fn all_frozen(&self) -> Option<bool> { 120 | use #root::module::Parameter; 121 | #( 122 | if matches!(self.#field_names.is_frozen(), Some(false)) { 123 | return Some(false); 124 | } 125 | )* 126 | #default_all_frozen 127 | } 128 | 129 | fn any_frozen(&self) -> Option<bool> { 130 | use #root::module::Parameter; 131 | #( 132 | if matches!(self.#field_names.is_frozen(), Some(true)) { 133 | return Some(true); 134 | } 135 | )* 136 | #default_any_frozen 137 | } 138 | } 139 | }; 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /mlx-macros/src/quantizable.rs: -------------------------------------------------------------------------------- 1 | use darling::FromDeriveInput; 2 | use syn::{DeriveInput, Generics, Ident}; 3 | 4 | use crate::util::{filter_fields_with_attr, FilteredFields}; 5 | 6 | #[derive(Debug, Clone, FromDeriveInput)] 7 | #[darling(attributes(quantizable))] 8 | struct StructProperties { 9 | root: Option<syn::Path>, 10 | } 11 | 12 | pub(crate) fn expand_quantizable( 13 | input: &DeriveInput, 14 | ) -> Result<proc_macro2::TokenStream, syn::Error> { 15 | let prop = StructProperties::from_derive_input(input)?; 16 | let struct_ident = &input.ident; 17 | let generics = &input.generics; 18 | 19 | match &input.data { 20 | syn::Data::Struct(data) => { 21 | expand_quantizable_module_for_struct(struct_ident, generics, data, prop.root) 22 | } 23 | _ => Err(syn::Error::new_spanned( 24 | input, 25 | "Quantizable can only be derived for structs", 26 | )), 27 | } 28 | } 29 | 30 | fn expand_quantizable_module_for_struct( 31 | ident: &syn::Ident, 32 | generics: &syn::Generics, 33 | data: &syn::DataStruct, 34 | root: Option<syn::Path>, 35 | ) -> Result<proc_macro2::TokenStream, syn::Error> { 36 | // Filter fields with #[quantizable] 37 | let fields = filter_fields_with_attr(&data.fields, "quantizable")?; 38 | 39 | impl_quantizable_module_for_struct(ident, generics, fields, root) 40 | } 41 | 42 | fn impl_quantizable_module_for_struct( 43 | ident: &Ident, 44 | generics: 
&Generics, 45 | fields: FilteredFields, 46 | root: Option<syn::Path>, 47 | ) -> Result<proc_macro2::TokenStream, syn::Error> { 48 | let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); 49 | // let field_names: Vec<_> = fields.iter().map(|field| &field.ident).collect(); 50 | 51 | let filtered_field_names = fields.filtered.iter().map(|field| &field.ident); 52 | let other_field_names = fields.other_fields.iter().map(|field| &field.ident); 53 | 54 | if fields.filtered.is_empty() { 55 | return Err(syn::Error::new_spanned( 56 | ident, 57 | "At least one field must be quantizable", 58 | )); 59 | } 60 | 61 | let (extern_import, root) = match root { 62 | Some(root) => (quote::quote! {}, quote::quote! { #root }), 63 | None => ( 64 | quote::quote! { extern crate mlx_rs as _mlx_rs; }, 65 | quote::quote! { _mlx_rs }, 66 | ), 67 | }; 68 | 69 | let token = quote::quote! { 70 | const _: () = { 71 | #extern_import 72 | impl #impl_generics #root::quantization::Quantizable for #ident #ty_generics #where_clause { 73 | type Quantized = Self; // Generating new struct is not supported yet 74 | 75 | type QuantizationError = #root::error::Exception; 76 | 77 | fn try_into_quantized( 78 | self, 79 | group_size: i32, 80 | bits: i32, 81 | ) -> Result<Self::Quantized, Self::QuantizationError> { 82 | Ok(Self { 83 | #( 84 | #filtered_field_names: #root::quantization::Quantizable 85 | ::try_into_quantized(self.#filtered_field_names, group_size, bits)?, 86 | )* 87 | #( 88 | #other_field_names: self.#other_field_names, 89 | )* 90 | }) 91 | } 92 | } 93 | }; 94 | }; 95 | Ok(token) 96 | } 97 | -------------------------------------------------------------------------------- /mlx-macros/src/util.rs: -------------------------------------------------------------------------------- 1 | pub(crate) struct FilteredFields<'a> { 2 | pub filtered: Vec<&'a syn::Field>, 3 | pub other_fields: Vec<&'a syn::Field>, 4 | } 5 | 6 | pub(crate) fn filter_fields_with_attr<'a>( 7 | fields: &'a syn::Fields, 8 | attr_name: &str, 9 | ) -> Result<FilteredFields<'a>, syn::Error> { 10 | let mut filtered = Vec::new(); 11 | let mut other_fields = Vec::new(); 12 | 13 | match fields { 14 | syn::Fields::Named(fields) => { 15 | for field in &fields.named { 16 | if field 17 | .attrs 18 | .iter() 19 | .any(|attr| attr.path().is_ident(attr_name)) 20 | { 21 | filtered.push(field); 22 | } else { 23 | other_fields.push(field); 24 | } 25 | } 26 | } 27 | syn::Fields::Unit => {} 28 | syn::Fields::Unnamed(_) => { 29 | return Err(syn::Error::new_spanned( 30 | fields, 31 | "Struct with unnamed fields is not supported".to_string(), 32 | )) 33 | } 34 | } 35 | 36 | Ok(FilteredFields { 37 | filtered, 38 | other_fields, 39 | }) 40 | } 41 | -------------------------------------------------------------------------------- /mlx-rs/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## 0.25.0-alpha.1 4 | 5 | - Update `mlx-c` to version "0.2.0-alpha" and change function signatures to 6 | match the new API 7 | - Update `thiserror` to version "2" 8 | - Fix wrong states number in `compile_with_state` 9 | - Remove unnecessary evaluation in fft ops 10 | 11 | ## 0.23.0 12 | 13 | - Update `mlx-c` to "0.1.2" 14 | - Added `dilation` and `groups` parameters to the convolution layer 15 | 16 | ## 0.21.1 17 | 18 | - Fix `mlx-sys` dependency to patch version in workspace 19 | 20 | ## 0.21.0 21 | 22 | - Initial feature-complete release 23 | -------------------------------------------------------------------------------- /mlx-rs/Cargo.toml: 
-------------------------------------------------------------------------------- 1 | [package] 2 | name = "mlx-rs" 3 | version.workspace = true 4 | authors.workspace = true 5 | edition.workspace = true 6 | repository.workspace = true 7 | keywords.workspace = true 8 | categories.workspace = true 9 | license.workspace = true 10 | documentation.workspace = true 11 | description = "Unofficial rust wrapper for Apple's mlx machine learning library." 12 | readme = "README.md" 13 | 14 | [package.metadata.docs.rs] 15 | targets = [ 16 | "aarch64-apple-darwin", 17 | "aarch64-apple-ios", 18 | "aarch64-apple-ios-sim", 19 | ] 20 | 21 | [dependencies] 22 | mlx-sys.workspace = true 23 | mlx-internal-macros.workspace = true 24 | mlx-macros.workspace = true 25 | dyn-clone.workspace = true 26 | half.workspace = true 27 | mach-sys.workspace = true 28 | num-complex.workspace = true 29 | num_enum.workspace = true 30 | num-traits.workspace = true 31 | paste.workspace = true 32 | smallvec.workspace = true 33 | strum.workspace = true 34 | thiserror.workspace = true 35 | libc.workspace = true 36 | parking_lot.workspace = true 37 | itertools.workspace = true 38 | 39 | # optional dependencies 40 | safetensors = { workspace = true, optional = true } 41 | bytemuck = { workspace = true, optional = true, features = ["extern_crate_std"] } 42 | 43 | [dev-dependencies] 44 | pretty_assertions.workspace = true 45 | float_eq.workspace = true 46 | tempfile.workspace = true 47 | 48 | [features] 49 | default = ["accelerate", "metal"] 50 | 51 | accelerate = ["mlx-sys/accelerate"] 52 | metal = ["mlx-sys/metal"] 53 | 54 | # Enables conversion between `Array` and `safetensors::TensorView` 55 | safetensors = ["dep:safetensors", "dep:bytemuck"] -------------------------------------------------------------------------------- /mlx-rs/README.md: -------------------------------------------------------------------------------- 1 |
# mlx-rs
3 | 4 | Rust bindings for Apple's mlx machine learning library. 5 | 6 | [![Discord](https://img.shields.io/discord/1176807732473495552.svg?color=7289da&&logo=discord)](https://discord.gg/jZvTsxDX49) 7 | [![Current Crates.io Version](https://img.shields.io/crates/v/mlx-rs.svg)](https://crates.io/crates/mlx-rs) 8 | [![Documentation](https://img.shields.io/badge/docs-latest-blue)](https://oxideai.github.io/mlx-rs/mlx_rs/) 9 | [![Test Status](https://github.com/oxideai/mlx-rs/actions/workflows/validate.yml/badge.svg)](https://github.com/oxideai/mlx-rs/actions/workflows/validate.yml) 10 | [![Blaze](https://runblaze.dev/gh/307493885959233117281096297203102330146/badge.svg)](https://runblaze.dev) 11 | [![Rust Version](https://img.shields.io/badge/Rust-1.81.0+-blue)](https://releases.rs/docs/1.81.0) 12 | ![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue) 13 | 14 | > **⚠️ Project is in active development - contributors welcome!** 15 | 16 | --- 17 | 18 | 
_[Blaze](https://runblaze.dev) supports this project by providing ultra-fast Apple Silicon macOS Github Action Runners. Apply the discount code `AI25` at checkout to enjoy 25% off your first year._
33 | 34 | ## Documentation 35 | 36 | Due to a known limitation of docs.rs, we are hosting the documentation on GitHub Pages [here](https://oxideai.github.io/mlx-rs/mlx_rs/). 37 | 38 | ## Features 39 | 40 | MLX is an array framework for machine learning on Apple Silicon. mlx-rs provides Rust bindings for MLX, allowing you to use MLX in your Rust projects. 41 | 42 | Some key features of MLX and `mlx-rs` include: 43 | - **Performance**: MLX is optimized for Apple Silicon, providing fast performance for machine learning tasks. 44 | - **Lazy Evaluation**: MLX uses lazy evaluation to optimize performance and memory usage. Arrays are only materialized when needed. 45 | - **Dynamic Graphs**: Computation graphs in MLX are constructed dynamically, allowing for flexible and efficient computation. Changing the shapes of function arguments does not require recompilation. 46 | - **Multi-Device Support**: MLX supports running computations on any of the supported devices (for now the CPU and GPU). 47 | - **Unified memory**: MLX provides a unified memory model, meaning arrays live in the same memory space, regardless of the device they are computed on. Operations can be performed on arrays on different devices without copying data between them. 48 | 49 | `mlx-rs` is designed to be a safe and idiomatic Rust interface to MLX, providing a seamless experience for Rust developers. 50 | 51 | ## Examples 52 | The [examples](examples/) directory contains sample projects demonstrating different use cases of our library. 53 | - [mnist](examples/mnist/): Train a basic neural network on the MNIST digit dataset 54 | - [mistral](examples/mistral/): Text generation using the pre-trained Mistral model 55 | 56 | ## Installation 57 | 58 | Add this to your `Cargo.toml`: 59 | ```toml 60 | [dependencies] 61 | mlx-rs = "0.21.0" 62 | ``` 63 | 64 | ## Feature Flags 65 | 66 | * `metal` - enables metal (GPU) usage in MLX 67 | * `accelerate` - enables using the accelerate framework in MLX 68 | 69 | ## Important Notes on Automatic Differentiation 70 | 71 | When using automatic differentiation in mlx-rs, there's an important difference in how closures work compared to Python's MLX. In Python, variables are implicitly captured and properly traced in the compute graph. However, in Rust, we need to be more explicit about which arrays should be traced. 72 | 73 | ❌ This approach may cause segfaults: 74 | ```rust 75 | // Don't do this 76 | let x = random::normal::<f32>(&[num_examples, num_features], None, None, None)?; 77 | let y = x.matmul(&w_star)? 
+ eps; 78 | 79 | let loss_fn = |w: &Array| -> Result<Array, Exception> { 80 | let y_pred = x.matmul(w)?; // x and y are captured from outer scope 81 | let loss = Array::from_f32(0.5) * ops::mean(&ops::square(&(y_pred - &y))?, None)?; 82 | Ok(loss) 83 | }; 84 | 85 | let grad_fn = transforms::grad(loss_fn, &[0]); 86 | ``` 87 | 88 | ✅ Instead, pass all required arrays as inputs to ensure proper tracing: 89 | ```rust 90 | let loss_fn = |inputs: &[Array]| -> Result<Array, Exception> { 91 | let w = &inputs[0]; 92 | let x = &inputs[1]; 93 | let y = &inputs[2]; 94 | 95 | let y_pred = x.matmul(w)?; 96 | let loss = Array::from_f32(0.5) * ops::mean(&ops::square(y_pred - y)?, None)?; 97 | Ok(loss) 98 | }; 99 | let argnums = &[0]; // Specify which argument to differentiate with respect to 100 | 101 | // Pass all required arrays in the inputs slice 102 | let mut inputs = vec![w, x, y]; 103 | let grad = transforms::grad(loss_fn, argnums)(&inputs)?; 104 | ``` 105 | 106 | When using gradients in training loops, remember to update the appropriate array in your inputs: 107 | 108 | ```rust 109 | let mut inputs = vec![w, x, y]; 110 | 111 | for _ in 0..num_iterations { 112 | let grad = transforms::grad(loss_fn, argnums)(&inputs)?; 113 | inputs[0] = &inputs[0] - Array::from_f32(learning_rate) * grad; // Update the weight array 114 | inputs[0].eval()?; 115 | } 116 | ``` 117 | 118 | We are actively working on improving this API to make it more ergonomic and closer to Python's behavior. For now, explicitly passing all required arrays as shown above is the recommended approach. 119 | 120 | ## Versioning 121 | 122 | For simplicity, the main crate `mlx-rs` follows MLX’s versioning, allowing you to easily see which MLX version you’re using under the hood. The `mlx-sys` crate follows the versioning of `mlx-c`, as that is the version from which the API is generated. 123 | 124 | ## Community 125 | 126 | If you are excited about the project or want to contribute, don't hesitate to join our [Discord](https://discord.gg/jZvTsxDX49)! 127 | We try to be as welcoming as possible to everybody from any background. We're still building this out, but you can ask your questions there! 128 | 129 | ## Status 130 | 131 | mlx-rs is currently in active development and can be used to run MLX models in Rust. 132 | 133 | ## MSRV 134 | 135 | The minimum supported Rust version is 1.81.0. 136 | 137 | The MSRV is the minimum Rust version that can be used to compile each crate. 138 | 139 | ## License 140 | 141 | mlx-rs is distributed under the terms of both the MIT license and the Apache License (Version 2.0). 142 | See [LICENSE-APACHE](./LICENSE-APACHE) and [LICENSE-MIT](./LICENSE-MIT) for details. Opening a pull 143 | request is assumed to signal agreement with these licensing terms. 
144 | -------------------------------------------------------------------------------- /mlx-rs/examples/linear_regression.rs: -------------------------------------------------------------------------------- 1 | use mlx_rs::error::Exception; 2 | use mlx_rs::{ops, transforms, Array}; 3 | use std::error::Error; 4 | 5 | fn main() -> Result<(), Box<dyn Error>> { 6 | let num_features: i32 = 100; 7 | let num_examples: i32 = 1000; 8 | let num_iterations: i32 = 10000; 9 | let learning_rate: f32 = 0.01; 10 | 11 | // True weight vector 12 | // let w_star = mlx_rs::random::normal::<f32>(&[num_features], None, None, None)?; 13 | let w_star = mlx_rs::normal!(shape = &[num_features])?; 14 | 15 | // Input examples (design matrix) 16 | // let x = mlx_rs::random::normal::<f32>(&[num_examples, num_features], None, None, None)?; 17 | let x = mlx_rs::normal!(shape = &[num_examples, num_features])?; 18 | 19 | // Noisy labels 20 | // let eps = mlx_rs::random::normal::<f32>(&[num_examples], None, None, None)? * 1e-2; 21 | let eps = mlx_rs::normal!(shape = &[num_examples])? * 1e-2; 22 | let y = x.matmul(&w_star)? + eps; 23 | 24 | // Initialize random weights 25 | // let w = mlx_rs::random::normal::<f32>(&[num_features], None, None, None)? * 1e-2; 26 | let w = mlx_rs::normal!(shape = &[num_features])? * 1e-2; 27 | 28 | let loss_fn = |inputs: &[Array]| -> Result<Array, Exception> { 29 | let w = &inputs[0]; 30 | let x = &inputs[1]; 31 | let y = &inputs[2]; 32 | 33 | let y_pred = x.matmul(w)?; 34 | let loss = Array::from_f32(0.5) * ops::mean(&ops::square(y_pred - y)?, None)?; 35 | Ok(loss) 36 | }; 37 | 38 | let mut grad_fn = transforms::grad(loss_fn); 39 | 40 | let now = std::time::Instant::now(); 41 | let mut inputs = [w, x, y]; 42 | 43 | for _ in 0..num_iterations { 44 | let grad = grad_fn(&inputs)?; 45 | inputs[0] = &inputs[0] - Array::from_f32(learning_rate) * grad; 46 | inputs[0].eval()?; 47 | } 48 | 49 | let elapsed = now.elapsed(); 50 | 51 | let loss = loss_fn(&inputs)?; 52 | let error_norm = ops::sum(&ops::square(&(&inputs[0] - &w_star))?, None)?.sqrt()?; 53 | let throughput = num_iterations as f32 / elapsed.as_secs_f32(); 54 | 55 | println!( 56 | "Loss {:.5}, L2 distance: |w-w*| = {:.5}, Throughput {:.5} (it/s)", 57 | loss.item::<f32>(), 58 | error_norm.item::<f32>(), 59 | throughput 60 | ); 61 | 62 | Ok(()) 63 | } 64 | -------------------------------------------------------------------------------- /mlx-rs/examples/tutorial.rs: -------------------------------------------------------------------------------- 1 | use mlx_rs::transforms::grad; 2 | use mlx_rs::{Array, Dtype}; 3 | 4 | fn scalar_basics() { 5 | // create a scalar array 6 | let x: Array = 1.0.into(); 7 | 8 | // the datatype is .float32 9 | let dtype = x.dtype(); 10 | assert_eq!(dtype, Dtype::Float32); 11 | 12 | // get the value 13 | let s = x.item::<f32>(); 14 | assert_eq!(s, 1.0); 15 | 16 | // reading the value with a different type is a fatal error 17 | // let i = x.item::<i32>(); 18 | 19 | // scalars have a size of 1 20 | let size = x.size(); 21 | assert_eq!(size, 1); 22 | 23 | // scalars have 0 dimensions 24 | let ndim = x.ndim(); 25 | assert_eq!(ndim, 0); 26 | 27 | // scalar shapes are empty arrays 28 | let shape = x.shape(); 29 | assert!(shape.is_empty()); 30 | } 31 | 32 | #[allow(unused_variables)] 33 | fn array_basics() { 34 | // make a multidimensional array. 
35 | let x: Array = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]); 36 | 37 | // mlx is row-major by default so the first row of this array 38 | // is [1.0, 2.0] and the second row is [3.0, 4.0] 39 | 40 | // Make an array of shape {2, 2} filled with ones: 41 | let y = Array::ones::<f32>(&[2, 2]).unwrap(); 42 | 43 | // Pointwise add x and y: 44 | let z = x.add(&y); 45 | 46 | // Same thing: 47 | let mut z = &x + &y; 48 | 49 | // mlx is lazy by default. At this point `z` only 50 | // has a shape and a type but no actual data: 51 | assert_eq!(z.dtype(), Dtype::Float32); 52 | assert_eq!(z.shape(), vec![2, 2]); 53 | 54 | // To actually run the computation you must evaluate `z`. 55 | // Under the hood, mlx records operations in a graph. 56 | // The variable `z` is a node in the graph which points to its operation 57 | // and inputs. When `eval` is called on an array (or arrays), the array and 58 | // all of its dependencies are recursively evaluated to produce the result. 59 | // Once an array is evaluated, it has data and is detached from its inputs. 60 | z.eval().unwrap(); 61 | 62 | // Of course the array can still be an input to other operations. You can even 63 | // call eval on the array again, this will just be a no-op: 64 | z.eval().unwrap(); // no-op 65 | 66 | // Some functions or methods on arrays implicitly evaluate them. For example 67 | // accessing a value in an array or printing the array implicitly evaluate it: 68 | z = Array::ones::<f32>(&[1]).unwrap(); 69 | z.item::<f32>(); // implicit evaluation 70 | 71 | z = Array::ones::<f32>(&[2, 2]).unwrap(); 72 | println!("{}", z); // implicit evaluation 73 | } 74 | 75 | fn automatic_differentiation() { 76 | use mlx_rs::error::Result; 77 | 78 | fn f(x: &Array) -> Result<Array> { 79 | x.square() 80 | } 81 | 82 | fn calculate_grad(func: impl Fn(&Array) -> Result<Array>, arg: &Array) -> Result<Array> { 83 | grad(&func)(arg) 84 | } 85 | 86 | let x = Array::from(1.5); 87 | 88 | let dfdx = calculate_grad(f, &x).unwrap(); 89 | assert_eq!(dfdx.item::<f32>(), 2.0 * 1.5); 90 | 91 | let dfdx2 = calculate_grad(|args| calculate_grad(f, args), &x).unwrap(); 92 | assert_eq!(dfdx2.item::<f32>(), 2.0); 93 | } 94 | 95 | fn main() { 96 | scalar_basics(); 97 | array_basics(); 98 | automatic_differentiation(); 99 | } 100 | -------------------------------------------------------------------------------- /mlx-rs/src/array/element.rs: -------------------------------------------------------------------------------- 1 | use crate::error::Result; 2 | use crate::sealed::Sealed; 3 | use crate::{complex64, Array, Dtype}; 4 | use half::{bf16, f16}; 5 | 6 | /// A marker trait for array elements. 7 | pub trait ArrayElement: Sized + Sealed { 8 | /// The data type of the element. 9 | const DTYPE: Dtype; 10 | 11 | /// Access the value of a scalar array. Returns `Err` if the array is not scalar. 12 | fn array_item(array: &Array) -> Result<Self>; 13 | 14 | /// Access the raw data of an array. 15 | fn array_data(array: &Array) -> *const Self; 16 | } 17 | 18 | /// A marker trait for array element types that can be constructed via the 19 | /// [`Array::from_slice`] method. This trait is implemented for all array 20 | /// element types except for [`f64`]. 21 | /// 22 | /// [`f64`] is treated specially because it is not supported on GPU devices, but 23 | /// Rust defaults floating point literals to `f64`. With this trait, we can 24 | /// limit the default floating point literals to `f32` for constructor 25 | /// functions like [`Array::from_slice`] and [`Array::from_iter`], and the macro 26 | /// [`crate::array!`].
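///
/// # Example
///
/// A minimal sketch of the effect described above; the exact inference behavior is an
/// assumption based on this doc comment rather than an additional guarantee:
///
/// ```rust,ignore
/// use mlx_rs::{Array, Dtype};
///
/// // In plain Rust these literals would default to `f64`; constrained to
/// // `FromSliceElement`, they are inferred as `f32` instead.
/// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
/// assert_eq!(a.dtype(), Dtype::Float32);
/// ```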
27 | pub trait FromSliceElement: ArrayElement {} 28 | 29 | macro_rules! impl_array_element { 30 | ($type:ty, $dtype:expr, $ctype:ident) => { 31 | paste::paste! { 32 | impl Sealed for $type {} 33 | impl ArrayElement for $type { 34 | const DTYPE: Dtype = $dtype; 35 | 36 | fn array_item(array: &Array) -> Result { 37 | use crate::utils::guard::*; 38 | 39 | <$type as Guarded>::try_from_op(|ptr| unsafe { 40 | mlx_sys::[](ptr, array.as_ptr()) 41 | }) 42 | } 43 | 44 | fn array_data(array: &Array) -> *const Self { 45 | unsafe { mlx_sys::[](array.as_ptr()) as *const Self } 46 | } 47 | 48 | } 49 | } 50 | }; 51 | } 52 | 53 | impl_array_element!(bool, Dtype::Bool, bool); 54 | impl_array_element!(u8, Dtype::Uint8, uint8); 55 | impl_array_element!(u16, Dtype::Uint16, uint16); 56 | impl_array_element!(u32, Dtype::Uint32, uint32); 57 | impl_array_element!(u64, Dtype::Uint64, uint64); 58 | impl_array_element!(i8, Dtype::Int8, int8); 59 | impl_array_element!(i16, Dtype::Int16, int16); 60 | impl_array_element!(i32, Dtype::Int32, int32); 61 | impl_array_element!(i64, Dtype::Int64, int64); 62 | impl_array_element!(f64, Dtype::Float64, float64); 63 | impl_array_element!(f32, Dtype::Float32, float32); 64 | impl_array_element!(f16, Dtype::Float16, float16); 65 | impl_array_element!(bf16, Dtype::Bfloat16, bfloat16); 66 | impl_array_element!(complex64, Dtype::Complex64, complex64); 67 | 68 | macro_rules! impl_from_slice_element { 69 | ($type:ty) => { 70 | impl FromSliceElement for $type {} 71 | }; 72 | } 73 | 74 | impl_from_slice_element!(bool); 75 | impl_from_slice_element!(u8); 76 | impl_from_slice_element!(u16); 77 | impl_from_slice_element!(u32); 78 | impl_from_slice_element!(u64); 79 | impl_from_slice_element!(i8); 80 | impl_from_slice_element!(i16); 81 | impl_from_slice_element!(i32); 82 | impl_from_slice_element!(i64); 83 | impl_from_slice_element!(f32); 84 | impl_from_slice_element!(f16); 85 | impl_from_slice_element!(bf16); 86 | impl_from_slice_element!(complex64); 87 | -------------------------------------------------------------------------------- /mlx-rs/src/array/operators.rs: -------------------------------------------------------------------------------- 1 | use crate::{utils::ScalarOrArray, Array, StreamOrDevice}; 2 | use num_traits::Pow; 3 | use std::{ 4 | iter::Product, 5 | ops::{ 6 | Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Not, Rem, RemAssign, Sub, SubAssign, 7 | }, 8 | }; 9 | 10 | macro_rules! impl_binary_op { 11 | ($trait:ident, $method:ident, $c_method:ident) => { 12 | impl<'a, T> $trait for Array 13 | where 14 | T: ScalarOrArray<'a>, 15 | { 16 | type Output = Array; 17 | 18 | fn $method(self, rhs: T) -> Self::Output { 19 | paste::paste! { 20 | self.[<$c_method _device>](rhs.into_owned_or_ref_array(), StreamOrDevice::default()).unwrap() 21 | } 22 | } 23 | } 24 | 25 | impl<'a, 't: 'a, T> $trait for &'a Array 26 | where 27 | T: ScalarOrArray<'t>, 28 | { 29 | type Output = Array; 30 | 31 | fn $method(self, rhs: T) -> Self::Output { 32 | paste::paste! { 33 | self.[<$c_method _device>](rhs.into_owned_or_ref_array(), StreamOrDevice::default()).unwrap() 34 | } 35 | } 36 | } 37 | }; 38 | } 39 | 40 | macro_rules! impl_binary_op_assign { 41 | ($trait:ident, $method:ident, $c_method:ident) => { 42 | impl> $trait for Array { 43 | fn $method(&mut self, rhs: T) { 44 | let new_array = paste::paste! 
{ 45 | self.[<$c_method _device>](&rhs.into(), StreamOrDevice::default()).unwrap() 46 | }; 47 | *self = new_array; 48 | } 49 | } 50 | 51 | impl $trait<&Array> for Array { 52 | fn $method(&mut self, rhs: &Self) { 53 | let new_array = paste::paste! { 54 | self.[<$c_method _device>](rhs, StreamOrDevice::default()).unwrap() 55 | }; 56 | *self = new_array; 57 | } 58 | } 59 | }; 60 | } 61 | 62 | impl_binary_op!(Add, add, add); 63 | impl_binary_op_assign!(AddAssign, add_assign, add); 64 | impl_binary_op!(Sub, sub, subtract); 65 | impl_binary_op_assign!(SubAssign, sub_assign, subtract); 66 | impl_binary_op!(Mul, mul, multiply); 67 | impl_binary_op_assign!(MulAssign, mul_assign, multiply); 68 | impl_binary_op!(Div, div, divide); 69 | impl_binary_op_assign!(DivAssign, div_assign, divide); 70 | impl_binary_op!(Rem, rem, remainder); 71 | impl_binary_op_assign!(RemAssign, rem_assign, remainder); 72 | impl_binary_op!(Pow, pow, power); 73 | 74 | impl Neg for &Array { 75 | type Output = Array; 76 | fn neg(self) -> Self::Output { 77 | self.negative_device(StreamOrDevice::default()).unwrap() 78 | } 79 | } 80 | impl Neg for Array { 81 | type Output = Array; 82 | fn neg(self) -> Self::Output { 83 | self.negative_device(StreamOrDevice::default()).unwrap() 84 | } 85 | } 86 | 87 | impl Not for &Array { 88 | type Output = Array; 89 | fn not(self) -> Self::Output { 90 | self.logical_not_device(StreamOrDevice::default()).unwrap() 91 | } 92 | } 93 | impl Not for Array { 94 | type Output = Array; 95 | fn not(self) -> Self::Output { 96 | self.logical_not_device(StreamOrDevice::default()).unwrap() 97 | } 98 | } 99 | 100 | impl Product for Array { 101 | fn product>(iter: I) -> Self { 102 | iter.fold(1.0.into(), |acc, x| acc * x) 103 | } 104 | } 105 | 106 | impl<'a> Product<&'a Array> for Array { 107 | fn product>(iter: I) -> Self { 108 | iter.fold(1.0.into(), |acc, x| acc * x) 109 | } 110 | } 111 | 112 | #[cfg(test)] 113 | mod tests { 114 | use super::*; 115 | use pretty_assertions::assert_eq; 116 | 117 | #[test] 118 | fn test_add_assign() { 119 | let mut a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); 120 | let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); 121 | a += &b; 122 | 123 | assert_eq!(a.as_slice::(), &[5.0, 7.0, 9.0]); 124 | } 125 | 126 | #[test] 127 | fn test_sub_assign() { 128 | let mut a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); 129 | let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); 130 | a -= &b; 131 | 132 | assert_eq!(a.as_slice::(), &[-3.0, -3.0, -3.0]); 133 | } 134 | 135 | #[test] 136 | fn test_mul_assign() { 137 | let mut a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); 138 | let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); 139 | a *= &b; 140 | 141 | assert_eq!(a.as_slice::(), &[4.0, 10.0, 18.0]); 142 | } 143 | 144 | #[test] 145 | fn test_div_assign() { 146 | let mut a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); 147 | let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); 148 | a /= &b; 149 | 150 | assert_eq!(a.as_slice::(), &[0.25, 0.4, 0.5]); 151 | } 152 | } 153 | -------------------------------------------------------------------------------- /mlx-rs/src/array/safetensors.rs: -------------------------------------------------------------------------------- 1 | //! Implement conversion from safetensors TensorView to Array 2 | //! 3 | //! `F8_*` dtypes are not supported and will return an error. 
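//!
//! A minimal sketch of a round trip (hedged: the value built with `array!` is illustrative;
//! the tests at the bottom of this file cover the actually supported dtypes):
//!
//! ```rust,ignore
//! use safetensors::tensor::TensorView;
//! use mlx_rs::{array, Array};
//!
//! let arr = array!([1.0f32, 2.0, 3.0]);
//! let view = TensorView::try_from(&arr)?; // Array -> TensorView
//! let back = Array::try_from(view)?;      // TensorView -> Array
//! assert_eq!(arr, back);
//! ```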
4 | 5 | use std::{ffi::c_void, mem::transmute}; 6 | 7 | use bytemuck::cast_slice; 8 | use safetensors::tensor::TensorView; 9 | 10 | use crate::{error::ConversionError, Dtype}; 11 | 12 | use super::Array; 13 | 14 | impl<'data> TryFrom> for Array { 15 | type Error = ConversionError; 16 | 17 | fn try_from(value: TensorView<'data>) -> Result { 18 | let dtype: Dtype = value.dtype().try_into()?; 19 | let shape = value.shape() 20 | .iter() 21 | .map(|x| *x as i32) 22 | .collect::>(); 23 | 24 | let data = value.data(); 25 | 26 | unsafe { 27 | Ok(Array::from_raw_data(data.as_ptr() as *const c_void, &shape, dtype)) 28 | } 29 | } 30 | } 31 | 32 | impl<'a> TryFrom<&'a Array> for TensorView<'a> { 33 | type Error = ConversionError; 34 | 35 | fn try_from(value: &'a Array) -> Result { 36 | let dtype: safetensors::tensor::Dtype = value.dtype().try_into()?; 37 | let shape = value.shape() 38 | .iter() 39 | .map(|x| *x as usize) 40 | .collect::>(); 41 | let data: &[u8] = unsafe { 42 | match value.dtype() { 43 | Dtype::Bool => { 44 | let data = value.as_slice::(); 45 | cast_slice(data) 46 | }, 47 | Dtype::Uint8 => { 48 | let data = value.as_slice::(); 49 | cast_slice(data) 50 | }, 51 | Dtype::Uint16 => { 52 | let data = value.as_slice::(); 53 | cast_slice(data) 54 | }, 55 | Dtype::Uint32 => { 56 | let data = value.as_slice::(); 57 | cast_slice(data) 58 | }, 59 | Dtype::Uint64 => { 60 | let data = value.as_slice::(); 61 | cast_slice(data) 62 | }, 63 | Dtype::Int8 => { 64 | let data = value.as_slice::(); 65 | cast_slice(data) 66 | }, 67 | Dtype::Int16 => { 68 | let data = value.as_slice::(); 69 | cast_slice(data) 70 | }, 71 | Dtype::Int32 => { 72 | let data = value.as_slice::(); 73 | cast_slice(data) 74 | }, 75 | Dtype::Int64 => { 76 | let data = value.as_slice::(); 77 | cast_slice(data) 78 | }, 79 | Dtype::Float16 => { 80 | let data = value.as_slice::(); 81 | let bits: &[u16] = transmute(data); 82 | cast_slice(bits) 83 | }, 84 | Dtype::Float32 => { 85 | let data = value.as_slice::(); 86 | cast_slice(data) 87 | }, 88 | Dtype::Bfloat16 => { 89 | let data = value.as_slice::(); 90 | let bits: &[u16] = transmute(data); 91 | cast_slice(bits) 92 | }, 93 | Dtype::Complex64 => return Err(ConversionError::MlxDtype(Dtype::Complex64)), 94 | } 95 | }; 96 | 97 | TensorView::new(dtype, shape, data) 98 | .map_err(Into::into) 99 | } 100 | } 101 | 102 | #[cfg(test)] 103 | mod tests { 104 | use safetensors::tensor::TensorView; 105 | 106 | use crate::{array, complex64, Array}; 107 | 108 | // Helper macro to test conversion between Array and TensorView 109 | macro_rules! 
assert_conversion { 110 | ($arr:expr, $dtype:expr) => { 111 | let arr = $arr.as_dtype($dtype).unwrap(); 112 | let tensor = TensorView::try_from(&arr).unwrap(); 113 | let arr2 = Array::try_from(tensor).unwrap(); 114 | 115 | assert_eq!(arr, arr2); 116 | }; 117 | } 118 | 119 | #[test] 120 | fn test_conversion_bool() { 121 | let arr = array!([[true, false, true], [false, true, false]]); 122 | assert_conversion!(&arr, crate::Dtype::Bool); 123 | } 124 | 125 | #[test] 126 | fn test_conversion_uint8() { 127 | let arr = array!([[1, 2, 3], [4, 5, 6]]); 128 | assert_conversion!(&arr, crate::Dtype::Uint8); 129 | } 130 | 131 | #[test] 132 | fn test_conversion_uint16() { 133 | let arr = array!([[1, 2, 3], [4, 5, 6]]); 134 | assert_conversion!(&arr, crate::Dtype::Uint16); 135 | } 136 | 137 | #[test] 138 | fn test_conversion_uint32() { 139 | let arr = array!([[1, 2, 3], [4, 5, 6]]); 140 | assert_conversion!(&arr, crate::Dtype::Uint32); 141 | } 142 | 143 | #[test] 144 | fn test_conversion_uint64() { 145 | let arr = array!([[1, 2, 3], [4, 5, 6]]); 146 | assert_conversion!(&arr, crate::Dtype::Uint64); 147 | } 148 | 149 | #[test] 150 | fn test_conversion_int8() { 151 | let arr = array!([[1, 2, 3], [4, 5, 6]]); 152 | assert_conversion!(&arr, crate::Dtype::Int8); 153 | } 154 | 155 | #[test] 156 | fn test_conversion_int16() { 157 | let arr = array!([[1, 2, 3], [4, 5, 6]]); 158 | assert_conversion!(&arr, crate::Dtype::Int16); 159 | } 160 | 161 | #[test] 162 | fn test_conversion_int32() { 163 | let arr = array!([[1, 2, 3], [4, 5, 6]]); 164 | assert_conversion!(&arr, crate::Dtype::Int32); 165 | } 166 | 167 | #[test] 168 | fn test_conversion_int64() { 169 | let arr = array!([[1, 2, 3], [4, 5, 6]]); 170 | assert_conversion!(&arr, crate::Dtype::Int64); 171 | } 172 | 173 | #[test] 174 | fn test_conversion_float16() { 175 | let arr = array!([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); 176 | assert_conversion!(&arr, crate::Dtype::Float16); 177 | } 178 | 179 | #[test] 180 | fn test_conversion_float32() { 181 | let arr = array!([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); 182 | assert_conversion!(&arr, crate::Dtype::Float32); 183 | } 184 | 185 | #[test] 186 | fn test_conversion_bfloat16() { 187 | let arr = array!([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); 188 | assert_conversion!(&arr, crate::Dtype::Bfloat16); 189 | } 190 | 191 | #[test] 192 | fn test_conversion_complex64() { 193 | let arr = array!([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]).as_type::().unwrap(); 194 | let tensor = TensorView::try_from(&arr); 195 | assert!(tensor.is_err()); 196 | } 197 | } -------------------------------------------------------------------------------- /mlx-rs/src/builder.rs: -------------------------------------------------------------------------------- 1 | //! Defines helper traits for builder pattern 2 | 3 | /// Helper trait for buildable types 4 | pub trait Buildable: Sized { 5 | /// The builder type for this buildable type 6 | type Builder: Builder; 7 | } 8 | 9 | /// Helper trait for builder 10 | pub trait Builder { 11 | /// Error with building 12 | type Error: std::error::Error; 13 | 14 | /// Build the type 15 | fn build(self) -> Result; 16 | } 17 | -------------------------------------------------------------------------------- /mlx-rs/src/device.rs: -------------------------------------------------------------------------------- 1 | use std::ffi::CStr; 2 | 3 | use crate::{ 4 | error::Result, 5 | utils::{guard::Guarded, SUCCESS}, 6 | }; 7 | 8 | ///Type of device. 
9 | #[derive(num_enum::IntoPrimitive, Debug, Clone, Copy)] 10 | #[repr(u32)] 11 | pub enum DeviceType { 12 | /// CPU device 13 | Cpu = mlx_sys::mlx_device_type__MLX_CPU, 14 | 15 | /// GPU device 16 | Gpu = mlx_sys::mlx_device_type__MLX_GPU, 17 | } 18 | 19 | /// Representation of a Device in MLX. 20 | pub struct Device { 21 | pub(crate) c_device: mlx_sys::mlx_device, 22 | } 23 | 24 | impl PartialEq for Device { 25 | fn eq(&self, other: &Self) -> bool { 26 | unsafe { mlx_sys::mlx_device_equal(self.c_device, other.c_device) } 27 | } 28 | } 29 | 30 | impl Device { 31 | /// Create a new [`Device`] 32 | pub fn new(device_type: DeviceType, index: i32) -> Device { 33 | let c_device = unsafe { mlx_sys::mlx_device_new_type(device_type.into(), index) }; 34 | Device { c_device } 35 | } 36 | 37 | /// Try to get the default device. 38 | pub fn try_default() -> Result { 39 | Device::try_from_op(|res| unsafe { mlx_sys::mlx_get_default_device(res) }) 40 | } 41 | 42 | /// Create a default CPU device. 43 | pub fn cpu() -> Device { 44 | Device::new(DeviceType::Cpu, 0) 45 | } 46 | 47 | /// Create a default GPU device. 48 | pub fn gpu() -> Device { 49 | Device::new(DeviceType::Gpu, 0) 50 | } 51 | 52 | /// Get the device index 53 | pub fn get_index(&self) -> Result { 54 | i32::try_from_op(|res| unsafe { mlx_sys::mlx_device_get_index(res, self.c_device) }) 55 | } 56 | 57 | /// Get the device type 58 | pub fn get_type(&self) -> Result { 59 | DeviceType::try_from_op(|res| unsafe { mlx_sys::mlx_device_get_type(res, self.c_device) }) 60 | } 61 | 62 | /// Set the default device. 63 | /// 64 | /// # Example: 65 | /// 66 | /// ```rust 67 | /// use mlx_rs::{Device, DeviceType}; 68 | /// Device::set_default(&Device::new(DeviceType::Cpu, 1)); 69 | /// ``` 70 | /// 71 | /// By default, this is `gpu()`. 72 | pub fn set_default(device: &Device) { 73 | unsafe { mlx_sys::mlx_set_default_device(device.c_device) }; 74 | } 75 | 76 | fn describe(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { 77 | unsafe { 78 | let mut mlx_str = mlx_sys::mlx_string_new(); 79 | let result = match mlx_sys::mlx_device_tostring(&mut mlx_str as *mut _, self.c_device) { 80 | SUCCESS => { 81 | let ptr = mlx_sys::mlx_string_data(mlx_str); 82 | let c_str = CStr::from_ptr(ptr); 83 | write!(f, "{}", c_str.to_string_lossy()) 84 | } 85 | _ => Err(std::fmt::Error), 86 | }; 87 | mlx_sys::mlx_string_free(mlx_str); 88 | result 89 | } 90 | } 91 | } 92 | 93 | impl Drop for Device { 94 | fn drop(&mut self) { 95 | let status = unsafe { mlx_sys::mlx_device_free(self.c_device) }; 96 | debug_assert_eq!(status, SUCCESS); 97 | } 98 | } 99 | 100 | impl Default for Device { 101 | fn default() -> Self { 102 | Self::try_default().unwrap() 103 | } 104 | } 105 | 106 | impl std::fmt::Debug for Device { 107 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { 108 | self.describe(f) 109 | } 110 | } 111 | 112 | impl std::fmt::Display for Device { 113 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { 114 | self.describe(f) 115 | } 116 | } 117 | 118 | #[cfg(test)] 119 | mod tests { 120 | use super::*; 121 | 122 | #[test] 123 | fn test_fmt() { 124 | let device = Device::default(); 125 | let description = format!("{}", device); 126 | assert_eq!(description, "Device(gpu, 0)"); 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /mlx-rs/src/fft/mod.rs: -------------------------------------------------------------------------------- 1 | //! 
Fast Fourier Transform (FFT) and its inverse (IFFT) for one, two, and `N` dimensions. 2 | //! 3 | //! Like all other functions in `mlx-rs`, three variants are provided for each FFT function, plus 4 | //! each variant has a version that uses the default `StreamOrDevice` or takes a user-specified 5 | //! `StreamOrDevice`. 6 | //! 7 | //! The differences are explained below using `fftn` as an example: 8 | //! 9 | //! 1. `fftn_unchecked`/`fftn_device_unchecked`: This function is simply a wrapper around the C API 10 | //! and does not perform any checks on the input. It may panic or raise a fatal error that cannot 11 | //! be caught by the Rust runtime if the input is invalid. 12 | //! 2. `try_fftn`/`try_fftn_device`: This function performs checks on the input and returns a 13 | //! `Result` instead of panicking. 14 | //! 3. `fftn`/`fftn_device`: This function is a wrapper around `try_fftn` and unwraps the result. It 15 | //! panics if the input is invalid. 16 | //! 17 | //! The functions that contain `device` in their names are meant to be used with a user-specified 18 | //! `StreamOrDevice`. If you don't care about the stream, you can use the functions without `device` 19 | //! in their names. Please note that GPU device support is not yet implemented. 20 | //! 21 | //! # Examples 22 | //! 23 | //! ## One dimension 24 | //! 25 | //! ```rust 26 | //! use mlx_rs::{Dtype, Array, StreamOrDevice, complex64, fft::*}; 27 | //! 28 | //! let src = [1.0f32, 2.0, 3.0, 4.0]; 29 | //! let mut array = Array::from_slice(&src[..], &[4]); 30 | //! 31 | //! let mut fft_result = fft(&array, 4, 0).unwrap(); 32 | //! assert_eq!(fft_result.dtype(), Dtype::Complex64); 33 | //! 34 | //! let expected = &[ 35 | //! complex64::new(10.0, 0.0), 36 | //! complex64::new(-2.0, 2.0), 37 | //! complex64::new(-2.0, 0.0), 38 | //! complex64::new(-2.0, -2.0), 39 | //! ]; 40 | //! assert_eq!(fft_result.as_slice::<complex64>(), &expected[..]); 41 | //! 42 | //! let mut ifft_result = ifft(&fft_result, 4, 0).unwrap(); 43 | //! assert_eq!(ifft_result.dtype(), Dtype::Complex64); 44 | //! 45 | //! let expected = &[ 46 | //! complex64::new(1.0, 0.0), 47 | //! complex64::new(2.0, 0.0), 48 | //! complex64::new(3.0, 0.0), 49 | //! complex64::new(4.0, 0.0), 50 | //! ]; 51 | //! assert_eq!(ifft_result.as_slice::<complex64>(), &expected[..]); 52 | //! 53 | //! let mut rfft_result = rfft(&array, 4, 0).unwrap(); 54 | //! assert_eq!(rfft_result.dtype(), Dtype::Complex64); 55 | //! 56 | //! let expected = &[ 57 | //! complex64::new(10.0, 0.0), 58 | //! complex64::new(-2.0, 2.0), 59 | //! complex64::new(-2.0, 0.0), 60 | //! ]; 61 | //! assert_eq!(rfft_result.as_slice::<complex64>(), &expected[..]); 62 | //! 63 | //! let mut irfft_result = irfft(&rfft_result, 4, 0).unwrap(); 64 | //! assert_eq!(irfft_result.dtype(), Dtype::Float32); 65 | //! assert_eq!(irfft_result.as_slice::<f32>(), &src[..]); 66 | //! 67 | //! // The original array is not modified 68 | //! let data: &[f32] = array.as_slice(); 69 | //! assert_eq!(data, &src[..]); 70 | //! ``` 71 | //! 72 | //! ## Two dimensions 73 | //! 74 | //! ```rust 75 | //! use mlx_rs::{Dtype, Array, StreamOrDevice, complex64, fft::*}; 76 | //! 77 | //! let src = [1.0f32, 1.0, 1.0, 1.0]; 78 | //! let mut array = Array::from_slice(&src[..], &[2, 2]); 79 | //! 80 | //! let mut fft2_result = fft2(&array, None, None).unwrap(); 81 | //! assert_eq!(fft2_result.dtype(), Dtype::Complex64); 82 | //! let expected = &[ 83 | //! complex64::new(4.0, 0.0), 84 | //! complex64::new(0.0, 0.0), 85 | //! complex64::new(0.0, 0.0), 86 | //!
complex64::new(0.0, 0.0), 87 | //! ]; 88 | //! assert_eq!(fft2_result.as_slice::<complex64>(), &expected[..]); 89 | //! 90 | //! let mut ifft2_result = ifft2(&fft2_result, None, None).unwrap(); 91 | //! assert_eq!(ifft2_result.dtype(), Dtype::Complex64); 92 | //! 93 | //! let expected = &[ 94 | //! complex64::new(1.0, 0.0), 95 | //! complex64::new(1.0, 0.0), 96 | //! complex64::new(1.0, 0.0), 97 | //! complex64::new(1.0, 0.0), 98 | //! ]; 99 | //! assert_eq!(ifft2_result.as_slice::<complex64>(), &expected[..]); 100 | //! 101 | //! let mut rfft2_result = rfft2(&array, None, None).unwrap(); 102 | //! assert_eq!(rfft2_result.dtype(), Dtype::Complex64); 103 | //! 104 | //! let expected = &[ 105 | //! complex64::new(4.0, 0.0), 106 | //! complex64::new(0.0, 0.0), 107 | //! complex64::new(0.0, 0.0), 108 | //! complex64::new(0.0, 0.0), 109 | //! ]; 110 | //! assert_eq!(rfft2_result.as_slice::<complex64>(), &expected[..]); 111 | //! 112 | //! let mut irfft2_result = irfft2(&rfft2_result, None, None).unwrap(); 113 | //! assert_eq!(irfft2_result.dtype(), Dtype::Float32); 114 | //! assert_eq!(irfft2_result.as_slice::<f32>(), &src[..]); 115 | //! 116 | //! // The original array is not modified 117 | //! let data: &[f32] = array.as_slice(); 118 | //! assert_eq!(data, &[1.0, 1.0, 1.0, 1.0]); 119 | //! ``` 120 | //! 121 | //! ## `N` dimensions 122 | //! 123 | //! ```rust 124 | //! use mlx_rs::{Dtype, Array, StreamOrDevice, complex64, fft::*}; 125 | //! 126 | //! let mut array = Array::ones::<f32>(&[2, 2, 2]).unwrap(); 127 | //! let mut fftn_result = fftn(&array, None, None).unwrap(); 128 | //! assert_eq!(fftn_result.dtype(), Dtype::Complex64); 129 | //! 130 | //! let mut expected = [complex64::new(0.0, 0.0); 8]; 131 | //! expected[0] = complex64::new(8.0, 0.0); 132 | //! assert_eq!(fftn_result.as_slice::<complex64>(), &expected[..]); 133 | //! 134 | //! let mut ifftn_result = ifftn(&fftn_result, None, None).unwrap(); 135 | //! assert_eq!(ifftn_result.dtype(), Dtype::Complex64); 136 | //! 137 | //! let expected = [complex64::new(1.0, 0.0); 8]; 138 | //! assert_eq!(ifftn_result.as_slice::<complex64>(), &expected[..]); 139 | //! 140 | //! let mut rfftn_result = rfftn(&array, None, None).unwrap(); 141 | //! assert_eq!(rfftn_result.dtype(), Dtype::Complex64); 142 | //! 143 | //! let mut expected = [complex64::new(0.0, 0.0); 8]; 144 | //! expected[0] = complex64::new(8.0, 0.0); 145 | //! assert_eq!(rfftn_result.as_slice::<complex64>(), &expected[..]); 146 | //! 147 | //! let mut irfftn_result = irfftn(&rfftn_result, None, None).unwrap(); 148 | //! assert_eq!(irfftn_result.dtype(), Dtype::Float32); 149 | //! 150 | //! let expected = [1.0; 8]; 151 | //! assert_eq!(irfftn_result.as_slice::<f32>(), &expected[..]); 152 | //! 153 | //! // The original array is not modified 154 | //! let data: &[f32] = array.as_slice(); 155 | //! assert_eq!(data, &[1.0; 8]); 156 | //!
``` 157 | 158 | mod fftn; 159 | mod rfftn; 160 | mod utils; 161 | 162 | pub use self::{fftn::*, rfftn::*}; 163 | 164 | /* -------------------------------------------------------------------------- */ 165 | /* Helper functions */ 166 | /* -------------------------------------------------------------------------- */ 167 | -------------------------------------------------------------------------------- /mlx-rs/src/fft/utils.rs: -------------------------------------------------------------------------------- 1 | use smallvec::SmallVec; 2 | 3 | use crate::{constants::DEFAULT_STACK_VEC_LEN, utils::resolve_index_unchecked, Array}; 4 | 5 | #[inline] 6 | pub(super) fn resolve_size_and_axis_unchecked( 7 | a: &Array, 8 | n: Option, 9 | axis: Option, 10 | ) -> (i32, i32) { 11 | let axis = axis.unwrap_or(-1); 12 | let n = n.unwrap_or_else(|| { 13 | let axis_index = resolve_index_unchecked(axis, a.ndim()); 14 | a.shape()[axis_index] 15 | }); 16 | (n, axis) 17 | } 18 | 19 | // Use Cow or SmallVec? 20 | #[inline] 21 | pub(super) fn resolve_sizes_and_axes_unchecked<'a>( 22 | a: &Array, 23 | s: Option<&'a [i32]>, 24 | axes: Option<&'a [i32]>, 25 | ) -> ( 26 | SmallVec<[i32; DEFAULT_STACK_VEC_LEN]>, 27 | SmallVec<[i32; DEFAULT_STACK_VEC_LEN]>, 28 | ) { 29 | match (s, axes) { 30 | (Some(s), Some(axes)) => { 31 | let valid_s = SmallVec::<[i32; DEFAULT_STACK_VEC_LEN]>::from_slice(s); 32 | let valid_axes = SmallVec::<[i32; DEFAULT_STACK_VEC_LEN]>::from_slice(axes); 33 | (valid_s, valid_axes) 34 | } 35 | (Some(s), None) => { 36 | let valid_s = SmallVec::<[i32; DEFAULT_STACK_VEC_LEN]>::from_slice(s); 37 | let valid_axes = (-(valid_s.len() as i32)..0).collect(); 38 | (valid_s, valid_axes) 39 | } 40 | (None, Some(axes)) => { 41 | let valid_s = axes 42 | .iter() 43 | .map(|&axis| { 44 | let axis_index = resolve_index_unchecked(axis, a.ndim()); 45 | a.shape()[axis_index] 46 | }) 47 | .collect(); 48 | let valid_axes = SmallVec::<[i32; DEFAULT_STACK_VEC_LEN]>::from_slice(axes); 49 | (valid_s, valid_axes) 50 | } 51 | (None, None) => { 52 | let valid_s: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = 53 | (0..a.ndim()).map(|axis| a.shape()[axis]).collect(); 54 | let valid_axes = (-(valid_s.len() as i32)..0).collect(); 55 | (valid_s, valid_axes) 56 | } 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /mlx-rs/src/macros/array.rs: -------------------------------------------------------------------------------- 1 | //! Macros for creating arrays. 2 | 3 | /// A helper macro to create an array with up to 3 dimensions. 
4 | /// 5 | /// # Examples 6 | /// 7 | /// ```rust 8 | /// use mlx_rs::array; 9 | /// 10 | /// // Create an empty array 11 | /// // Note that an empty array defaults to f32 and one dimension 12 | /// let empty = array!(); 13 | /// 14 | /// // Create a scalar array 15 | /// let s = array!(1); 16 | /// // Scalar array has 0 dimension 17 | /// assert_eq!(s.ndim(), 0); 18 | /// 19 | /// // Create a one-element array (singleton matrix) 20 | /// let s = array!([1]); 21 | /// // Singleton array has 1 dimension 22 | /// assert!(s.ndim() == 1); 23 | /// 24 | /// // Create a 1D array 25 | /// let a1 = array!([1, 2, 3]); 26 | /// 27 | /// // Create a 2D array 28 | /// let a2 = array!([ 29 | /// [1, 2, 3], 30 | /// [4, 5, 6] 31 | /// ]); 32 | /// 33 | /// // Create a 3D array 34 | /// let a3 = array!([ 35 | /// [ 36 | /// [1, 2, 3], 37 | /// [4, 5, 6] 38 | /// ], 39 | /// [ 40 | /// [7, 8, 9], 41 | /// [10, 11, 12] 42 | /// ] 43 | /// ]); 44 | /// 45 | /// // Create a 2x2 array by specifying the shape 46 | /// let a = array!([1, 2, 3, 4], shape=[2, 2]); 47 | /// ``` 48 | #[macro_export] 49 | macro_rules! array { 50 | ([$($x:expr),*], shape=[$($s:expr),*]) => { 51 | { 52 | let data = [$($x,)*]; 53 | let shape = [$($s,)*]; 54 | $crate::Array::from_slice(&data, &shape) 55 | } 56 | }; 57 | ([$([$([$($x:expr),*]),*]),*]) => { 58 | { 59 | let arr = [$([$([$($x,)*],)*],)*]; 60 | <$crate::Array as $crate::FromNested<_>>::from_nested(arr) 61 | } 62 | }; 63 | ([$([$($x:expr),*]),*]) => { 64 | { 65 | let arr = [$([$($x,)*],)*]; 66 | <$crate::Array as $crate::FromNested<_>>::from_nested(arr) 67 | } 68 | }; 69 | ([$($x:expr),*]) => { 70 | { 71 | let arr = [$($x,)*]; 72 | <$crate::Array as $crate::FromNested<_>>::from_nested(arr) 73 | } 74 | }; 75 | ($x:expr) => { 76 | { 77 | <$crate::Array as $crate::FromScalar<_>>::from_scalar($x) 78 | } 79 | }; 80 | // Empty array default to f32 81 | () => { 82 | $crate::Array::from_slice::(&[], &[0]) 83 | }; 84 | } 85 | 86 | #[cfg(test)] 87 | mod tests { 88 | use crate::ops::indexing::IndexOp; 89 | 90 | #[test] 91 | fn test_scalar_array() { 92 | let arr = array!(1); 93 | 94 | // Scalar array has 0 dimension 95 | assert_eq!(arr.ndim(), 0); 96 | // Scalar array has empty shape 97 | assert!(arr.shape().is_empty()); 98 | assert_eq!(arr.item::(), 1); 99 | } 100 | 101 | #[test] 102 | fn test_array_1d() { 103 | let arr = array!([1, 2, 3]); 104 | 105 | // One element array has 1 dimension 106 | assert_eq!(arr.ndim(), 1); 107 | assert_eq!(arr.shape(), &[3]); 108 | assert_eq!(arr.index(0).item::(), 1); 109 | assert_eq!(arr.index(1).item::(), 2); 110 | assert_eq!(arr.index(2).item::(), 3); 111 | } 112 | 113 | #[test] 114 | fn test_array_2d() { 115 | let a = array!([[1, 2, 3], [4, 5, 6]]); 116 | 117 | assert_eq!(a.ndim(), 2); 118 | assert_eq!(a.shape(), &[2, 3]); 119 | assert_eq!(a.index((0, 0)).item::(), 1); 120 | assert_eq!(a.index((0, 1)).item::(), 2); 121 | assert_eq!(a.index((0, 2)).item::(), 3); 122 | assert_eq!(a.index((1, 0)).item::(), 4); 123 | assert_eq!(a.index((1, 1)).item::(), 5); 124 | assert_eq!(a.index((1, 2)).item::(), 6); 125 | } 126 | 127 | #[test] 128 | fn test_array_3d() { 129 | let a = array!([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]); 130 | 131 | assert!(a.ndim() == 3); 132 | assert_eq!(a.shape(), &[2, 2, 3]); 133 | assert_eq!(a.index((0, 0, 0)).item::(), 1); 134 | assert_eq!(a.index((0, 0, 1)).item::(), 2); 135 | assert_eq!(a.index((0, 0, 2)).item::(), 3); 136 | assert_eq!(a.index((0, 1, 0)).item::(), 4); 137 | assert_eq!(a.index((0, 1, 1)).item::(), 5); 
138 | assert_eq!(a.index((0, 1, 2)).item::(), 6); 139 | assert_eq!(a.index((1, 0, 0)).item::(), 7); 140 | assert_eq!(a.index((1, 0, 1)).item::(), 8); 141 | assert_eq!(a.index((1, 0, 2)).item::(), 9); 142 | assert_eq!(a.index((1, 1, 0)).item::(), 10); 143 | assert_eq!(a.index((1, 1, 1)).item::(), 11); 144 | assert_eq!(a.index((1, 1, 2)).item::(), 12); 145 | } 146 | 147 | #[test] 148 | fn test_array_with_shape() { 149 | let a = array!([1, 2, 3, 4], shape = [2, 2]); 150 | 151 | assert_eq!(a.ndim(), 2); 152 | assert_eq!(a.shape(), &[2, 2]); 153 | assert_eq!(a.index((0, 0)).item::(), 1); 154 | assert_eq!(a.index((0, 1)).item::(), 2); 155 | assert_eq!(a.index((1, 0)).item::(), 3); 156 | assert_eq!(a.index((1, 1)).item::(), 4); 157 | } 158 | } 159 | -------------------------------------------------------------------------------- /mlx-rs/src/macros/assert.rs: -------------------------------------------------------------------------------- 1 | /// Asserts that two arrays are equal. 2 | /// 3 | /// It checks that the two arrays have the same shape and that all elements are 4 | /// sufficiently close. 5 | #[macro_export] 6 | macro_rules! assert_array_eq { 7 | ($value:expr, $expected:expr) => { 8 | assert_array_eq!($value, $expected, None); 9 | }; 10 | ($value:expr, $expected:expr, $atol:expr) => { 11 | assert_eq!($value.shape(), $expected.shape(), "Shapes are not equal"); 12 | let assert = $value.all_close(&$expected, $atol, $atol, None); 13 | assert!( 14 | assert.unwrap().item::(), 15 | "Values are not sufficiently close" 16 | ); 17 | }; 18 | } 19 | -------------------------------------------------------------------------------- /mlx-rs/src/macros/internal.rs: -------------------------------------------------------------------------------- 1 | /// See `assertEqual` in the swift binding tests 2 | #[allow(unused_macros)] 3 | macro_rules! assert_array_all_close { 4 | ($a:tt, $b:tt) => { 5 | let _b: Array = $b.into(); 6 | let assert = $a.all_close(&_b, None, None, None).unwrap(); 7 | assert!(assert.item::()); 8 | }; 9 | } 10 | 11 | #[allow(unused_macros)] 12 | macro_rules! cfg_safetensors { 13 | ($($item:item)*) => { 14 | $( 15 | #[cfg(feature = "safetensors")] 16 | $item 17 | )* 18 | }; 19 | } 20 | -------------------------------------------------------------------------------- /mlx-rs/src/macros/mod.rs: -------------------------------------------------------------------------------- 1 | //! Macros for mlx-rs. 2 | 3 | #[macro_use] 4 | mod internal; 5 | 6 | mod array; 7 | mod assert; 8 | 9 | pub use mlx_macros::*; 10 | -------------------------------------------------------------------------------- /mlx-rs/src/module/mod.rs: -------------------------------------------------------------------------------- 1 | //! This mod defines the traits for neural network modules and parameters. 2 | //! 3 | //! This is to separate the trait definitions from the implementations, which are in the `mlx-nn` 4 | //! crate. This also allows using the `mlx_macros::ModuleParameters` derive macro in crates other 5 | //! than `mlx-nn`. 
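//!
//! A minimal sketch of what the derive enables (the struct and field names are illustrative;
//! see `nn/container.rs` and `nn/embedding.rs` for real uses):
//!
//! ```rust,ignore
//! use mlx_rs::{module::Param, Array};
//! use mlx_rs::macros::ModuleParameters;
//!
//! #[derive(Debug, ModuleParameters)]
//! struct MyLayer {
//!     #[param]
//!     weight: Param<Array>,
//!     #[param]
//!     bias: Param<Option<Array>>,
//! }
//! ```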
6 | 7 | #[allow(clippy::module_inception)] 8 | mod module; 9 | mod param; 10 | 11 | pub use module::*; 12 | pub use param::*; 13 | -------------------------------------------------------------------------------- /mlx-rs/src/module/param.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::HashMap, 3 | ops::{Deref, DerefMut}, 4 | rc::Rc, 5 | }; 6 | 7 | use crate::{nested::NestedValue, Array}; 8 | 9 | use super::ModuleParameters; 10 | 11 | /// Trait for a module parameter. 12 | pub trait Parameter { 13 | /// Total number of parameters in this module/parameter. 14 | fn count(&self) -> usize; 15 | 16 | /// Freeze the parameter. 17 | fn freeze(&mut self, recursive: bool); 18 | 19 | /// Unfreeze the parameter. 20 | fn unfreeze(&mut self, recursive: bool); 21 | 22 | /// Check if the parameter is frozen. Returns `None` if the parameter is a module that has no 23 | /// parameters. 24 | fn is_frozen(&self) -> Option; 25 | 26 | /// Get the parameter as a nested value. 27 | fn as_nested_value(&self) -> NestedValue, &Array>; 28 | 29 | /// Get the parameter as a mutable nested value. 30 | fn as_nested_value_mut(&mut self) -> NestedValue, &mut Array>; 31 | 32 | /// Get the parameter as a nested value if it is trainable. 33 | fn as_trainable_nested_value(&self) -> Option, &Array>>; 34 | } 35 | 36 | /// A simple wrapper for a module parameter. 37 | #[derive(Debug, Clone)] 38 | pub struct Param { 39 | /// The value of the parameter. 40 | pub value: T, 41 | 42 | /// Whether the parameter is frozen. 43 | /// 44 | /// This is no longer public because it should be accessed through the `Parameter` trait. 45 | is_frozen: bool, 46 | } 47 | 48 | impl Param { 49 | /// Create a new `Param` 50 | pub fn new(value: T) -> Self { 51 | Self { 52 | value, 53 | is_frozen: false, 54 | } 55 | } 56 | } 57 | 58 | impl From for Param { 59 | fn from(inner: T) -> Self { 60 | Self::new(inner) 61 | } 62 | } 63 | 64 | impl Deref for Param { 65 | type Target = T; 66 | 67 | fn deref(&self) -> &Self::Target { 68 | &self.value 69 | } 70 | } 71 | 72 | impl DerefMut for Param { 73 | fn deref_mut(&mut self) -> &mut Self::Target { 74 | &mut self.value 75 | } 76 | } 77 | 78 | impl AsRef for Param { 79 | fn as_ref(&self) -> &T { 80 | &self.value 81 | } 82 | } 83 | 84 | impl AsMut for Param { 85 | fn as_mut(&mut self) -> &mut T { 86 | &mut self.value 87 | } 88 | } 89 | 90 | impl Parameter for Param { 91 | fn count(&self) -> usize { 92 | 1 93 | } 94 | 95 | fn freeze(&mut self, _recursive: bool) { 96 | self.is_frozen = true; 97 | } 98 | 99 | fn unfreeze(&mut self, _recursive: bool) { 100 | self.is_frozen = false; 101 | } 102 | 103 | fn is_frozen(&self) -> Option { 104 | Some(self.is_frozen) 105 | } 106 | 107 | fn as_nested_value<'a>(&self) -> NestedValue, &Array> { 108 | NestedValue::Value(&self.value) 109 | } 110 | 111 | fn as_nested_value_mut<'a>(&mut self) -> NestedValue, &mut Array> { 112 | NestedValue::Value(&mut self.value) 113 | } 114 | 115 | fn as_trainable_nested_value<'a>(&self) -> Option, &Array>> { 116 | match self.is_frozen { 117 | true => None, 118 | false => Some(NestedValue::Value(&self.value)), 119 | } 120 | } 121 | } 122 | 123 | impl Parameter for Param> { 124 | fn count(&self) -> usize { 125 | self.value.as_ref().map_or(0, |_| 1) 126 | } 127 | 128 | fn freeze(&mut self, _recursive: bool) { 129 | self.is_frozen = true; 130 | } 131 | 132 | fn unfreeze(&mut self, _recursive: bool) { 133 | self.is_frozen = false; 134 | } 135 | 136 | fn is_frozen(&self) -> Option { 137 | 
Some(self.is_frozen) 138 | } 139 | 140 | fn as_nested_value(&self) -> NestedValue, &Array> { 141 | match &self.value { 142 | Some(array) => NestedValue::Value(array), 143 | // An empty map entry will be ignored during flattening 144 | None => NestedValue::Map(HashMap::with_capacity(0)), 145 | } 146 | } 147 | 148 | fn as_nested_value_mut(&mut self) -> NestedValue, &mut Array> { 149 | match &mut self.value { 150 | Some(array) => NestedValue::Value(array), 151 | // An empty map entry will be ignored during flattening 152 | None => NestedValue::Map(HashMap::with_capacity(0)), 153 | } 154 | } 155 | 156 | fn as_trainable_nested_value(&self) -> Option, &Array>> { 157 | match self.is_frozen { 158 | true => None, 159 | false => self.value.as_ref().map(NestedValue::Value), 160 | } 161 | } 162 | } 163 | 164 | impl Parameter for T 165 | where 166 | T: ModuleParameters, 167 | { 168 | fn count(&self) -> usize { 169 | self.num_parameters() 170 | } 171 | 172 | fn freeze(&mut self, recursive: bool) { 173 | self.freeze_parameters(recursive); 174 | } 175 | 176 | fn unfreeze(&mut self, recursive: bool) { 177 | self.unfreeze_parameters(recursive); 178 | } 179 | 180 | fn is_frozen(&self) -> Option { 181 | self.all_frozen() 182 | } 183 | 184 | fn as_nested_value(&self) -> NestedValue, &Array> { 185 | self.parameters().into() 186 | } 187 | 188 | fn as_nested_value_mut(&mut self) -> NestedValue, &mut Array> { 189 | self.parameters_mut().into() 190 | } 191 | 192 | fn as_trainable_nested_value(&self) -> Option, &Array>> { 193 | Some(self.trainable_parameters().into()) 194 | } 195 | } 196 | -------------------------------------------------------------------------------- /mlx-rs/src/nested.rs: -------------------------------------------------------------------------------- 1 | //! 
Implements a nested hashmap 2 | 3 | use std::{collections::HashMap, fmt::Display, rc::Rc}; 4 | 5 | const DELIMITER: char = '.'; 6 | 7 | /// A nested value that can be either a value or a map of nested values 8 | #[derive(Debug, Clone)] 9 | pub enum NestedValue { 10 | /// A value 11 | Value(T), 12 | 13 | /// A map of nested values 14 | Map(HashMap>), 15 | } 16 | 17 | impl NestedValue { 18 | /// Flattens the nested value into a hashmap 19 | pub fn flatten(self, prefix: &str) -> HashMap, V> 20 | where 21 | K: Display, 22 | { 23 | match self { 24 | NestedValue::Value(array) => { 25 | let mut map = HashMap::new(); 26 | map.insert(prefix.into(), array); 27 | map 28 | } 29 | NestedValue::Map(entries) => entries 30 | .into_iter() 31 | .flat_map(|(key, value)| value.flatten(&format!("{}{}{}", prefix, DELIMITER, key))) 32 | .collect(), 33 | } 34 | } 35 | } 36 | 37 | /// A nested hashmap 38 | #[derive(Debug, Clone)] 39 | pub struct NestedHashMap { 40 | /// The internal hashmap 41 | pub entries: HashMap>, 42 | } 43 | 44 | impl From> for NestedValue { 45 | fn from(map: NestedHashMap) -> Self { 46 | NestedValue::Map(map.entries) 47 | } 48 | } 49 | 50 | impl Default for NestedHashMap { 51 | fn default() -> Self { 52 | Self::new() 53 | } 54 | } 55 | 56 | impl NestedHashMap { 57 | /// Creates a new nested hashmap 58 | pub fn new() -> Self { 59 | Self { 60 | entries: HashMap::new(), 61 | } 62 | } 63 | 64 | /// Inserts a new entry into the nested hashmap 65 | pub fn insert(&mut self, key: K, value: NestedValue) 66 | where 67 | K: Eq + std::hash::Hash, 68 | { 69 | self.entries.insert(key, value); 70 | } 71 | 72 | /// Flattens the nested hashmap into a hashmap 73 | pub fn flatten(self) -> HashMap, V> 74 | where 75 | K: AsRef + Display, 76 | { 77 | self.entries 78 | .into_iter() 79 | .flat_map(|(key, value)| value.flatten(key.as_ref())) 80 | .collect() 81 | } 82 | } 83 | 84 | #[cfg(test)] 85 | mod tests { 86 | use crate::array; 87 | 88 | use super::*; 89 | 90 | #[test] 91 | fn test_flatten_nested_hash_map_of_owned_arrays() { 92 | let first_entry = NestedValue::Value(array!([1, 2, 3])); 93 | let second_entry = NestedValue::Map({ 94 | let mut map = HashMap::new(); 95 | map.insert("a", NestedValue::Value(array!([4, 5, 6]))); 96 | map.insert("b", NestedValue::Value(array!([7, 8, 9]))); 97 | map 98 | }); 99 | 100 | let map = NestedHashMap { 101 | entries: { 102 | let mut map = HashMap::new(); 103 | map.insert("first", first_entry); 104 | map.insert("second", second_entry); 105 | map 106 | }, 107 | }; 108 | 109 | let flattened = map.flatten(); 110 | 111 | assert_eq!(flattened.len(), 3); 112 | assert_eq!(flattened["first"], array!([1, 2, 3])); 113 | assert_eq!(flattened["second.a"], array!([4, 5, 6])); 114 | assert_eq!(flattened["second.b"], array!([7, 8, 9])); 115 | } 116 | 117 | #[test] 118 | fn test_flatten_nested_hash_map_of_borrowed_arrays() { 119 | let first_entry_content = array!([1, 2, 3]); 120 | let first_entry = NestedValue::Value(&first_entry_content); 121 | 122 | let second_entry_content_a = array!([4, 5, 6]); 123 | let second_entry_content_b = array!([7, 8, 9]); 124 | let second_entry = NestedValue::Map({ 125 | let mut map = HashMap::new(); 126 | map.insert("a", NestedValue::Value(&second_entry_content_a)); 127 | map.insert("b", NestedValue::Value(&second_entry_content_b)); 128 | map 129 | }); 130 | 131 | let map = NestedHashMap { 132 | entries: { 133 | let mut map = HashMap::new(); 134 | map.insert("first", first_entry); 135 | map.insert("second", second_entry); 136 | map 137 | }, 138 | }; 139 | 140 | 
let flattened = map.flatten(); 141 | 142 | assert_eq!(flattened.len(), 3); 143 | assert_eq!(flattened["first"], &first_entry_content); 144 | assert_eq!(flattened["second.a"], &second_entry_content_a); 145 | assert_eq!(flattened["second.b"], &second_entry_content_b); 146 | } 147 | 148 | #[test] 149 | fn test_flatten_nested_hash_map_of_mut_borrowed_arrays() { 150 | let mut first_entry_content = array!([1, 2, 3]); 151 | let first_entry = NestedValue::Value(&mut first_entry_content); 152 | 153 | let mut second_entry_content_a = array!([4, 5, 6]); 154 | let mut second_entry_content_b = array!([7, 8, 9]); 155 | let second_entry = NestedValue::Map({ 156 | let mut map = HashMap::new(); 157 | map.insert("a", NestedValue::Value(&mut second_entry_content_a)); 158 | map.insert("b", NestedValue::Value(&mut second_entry_content_b)); 159 | map 160 | }); 161 | 162 | let map = NestedHashMap { 163 | entries: { 164 | let mut map = HashMap::new(); 165 | map.insert("first", first_entry); 166 | map.insert("second", second_entry); 167 | map 168 | }, 169 | }; 170 | 171 | let flattened = map.flatten(); 172 | 173 | assert_eq!(flattened.len(), 3); 174 | assert_eq!(flattened["first"], &mut array!([1, 2, 3])); 175 | assert_eq!(flattened["second.a"], &mut array!([4, 5, 6])); 176 | assert_eq!(flattened["second.b"], &mut array!([7, 8, 9])); 177 | } 178 | 179 | #[test] 180 | fn test_flatten_empty_nested_hash_map() { 181 | let map = NestedHashMap::<&str, i32>::new(); 182 | let flattened = map.flatten(); 183 | 184 | assert!(flattened.is_empty()); 185 | 186 | // Insert another empty map 187 | let mut map = NestedHashMap::<&str, i32>::new(); 188 | let empty_map = NestedValue::Map(HashMap::new()); 189 | map.insert("empty", empty_map); 190 | 191 | let flattened = map.flatten(); 192 | assert!(flattened.is_empty()); 193 | } 194 | } 195 | -------------------------------------------------------------------------------- /mlx-rs/src/nn/container.rs: -------------------------------------------------------------------------------- 1 | use std::borrow::Cow; 2 | 3 | use crate::module::{Module, UnaryModule}; 4 | use crate::{error::Exception, Array}; 5 | use mlx_macros::ModuleParameters; 6 | 7 | /// Marker trait for items that can be used in a `Sequential` module. 8 | /// 9 | /// It is implemented for all types that implement [`Module`] and [`std::fmt::Debug`]. 10 | pub trait SequentialModuleItem: UnaryModule + std::fmt::Debug {} 11 | 12 | impl SequentialModuleItem for T where T: UnaryModule + std::fmt::Debug {} 13 | 14 | /// A sequential layer. 15 | /// 16 | /// It calls each layer in sequence. 17 | #[derive(Debug, ModuleParameters)] 18 | #[module(root = crate)] 19 | pub struct Sequential { 20 | /// The layers to be called in sequence. 21 | #[param] 22 | pub layers: Vec>>, 23 | } 24 | 25 | impl Module<&Array> for Sequential { 26 | type Error = Exception; 27 | type Output = Array; 28 | 29 | fn forward(&mut self, x: &Array) -> Result { 30 | let mut x = Cow::Borrowed(x); 31 | 32 | for layer in &mut self.layers { 33 | x = Cow::Owned(layer.forward(x.as_ref())?); 34 | } 35 | 36 | match x { 37 | Cow::Owned(array) => Ok(array), 38 | Cow::Borrowed(array) => Ok(array.clone()), 39 | } 40 | } 41 | 42 | fn training_mode(&mut self, mode: bool) { 43 | self.layers 44 | .iter_mut() 45 | .for_each(|layer| layer.training_mode(mode)); 46 | } 47 | } 48 | 49 | impl Default for Sequential { 50 | fn default() -> Self { 51 | Self::new() 52 | } 53 | } 54 | 55 | impl Sequential { 56 | /// Creates a new [`Sequential`] module. 
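///
/// # Example
///
/// A minimal sketch mirroring the tests below (the layer sizes are arbitrary):
///
/// ```rust,ignore
/// use mlx_rs::nn::{Linear, Sequential};
///
/// let model = Sequential::new()
///     .append(Linear::new(2, 3).unwrap())
///     .append(Linear::new(3, 1).unwrap());
/// ```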
57 | pub fn new() -> Self { 58 | Self { layers: Vec::new() } 59 | } 60 | 61 | /// Appends a layer to the sequential module. 62 | pub fn append(mut self, layer: M) -> Self 63 | where 64 | M: UnaryModule + std::fmt::Debug + 'static, 65 | { 66 | self.layers.push(Box::new(layer)); 67 | self 68 | } 69 | } 70 | 71 | #[cfg(test)] 72 | mod tests { 73 | use crate::{ 74 | array, 75 | builder::Builder, 76 | module::ModuleParameters, 77 | nn::{self, Linear}, 78 | ops::zeros, 79 | optimizers::{Optimizer, Sgd}, 80 | random::uniform, 81 | transforms::{eval, eval_params}, 82 | }; 83 | 84 | use crate::losses::{LossReduction, MseLossBuilder}; 85 | 86 | use super::*; 87 | 88 | #[test] 89 | fn test_sequential_linear_param_len() { 90 | let model = Sequential::new() 91 | .append(Linear::new(2, 3).unwrap()) 92 | .append(Linear::new(3, 1).unwrap()); 93 | 94 | let params = model.parameters().flatten(); 95 | assert_eq!(params.len(), 4); 96 | } 97 | 98 | #[test] 99 | fn test_sequential_linear_param_update() { 100 | let mut model = Sequential::new() 101 | .append(Linear::new(2, 3).unwrap()) 102 | .append(Linear::new(3, 1).unwrap()); 103 | 104 | model 105 | .trainable_parameters() 106 | .flatten() 107 | .iter() 108 | .for_each(|(key, value)| { 109 | println!("{}: {:?}", key, value); 110 | }); 111 | 112 | let mut params = model.parameters_mut().flatten(); 113 | 114 | // Check that the initial weights are not all zeros 115 | assert!( 116 | params["layers.0.weight"] 117 | .abs() 118 | .unwrap() 119 | .sum(None) 120 | .unwrap() 121 | .item::() 122 | - 0.0 123 | > 1e-6 124 | ); 125 | 126 | // Update the weight with zeros 127 | let shape = params["layers.0.weight"].shape(); 128 | let zeros = zeros::(shape).unwrap(); 129 | let value_mut = params.get_mut("layers.0.weight").unwrap(); 130 | **value_mut = zeros; 131 | 132 | // Check that the weight is now all zeros 133 | let first_layer = &model.layers[0]; 134 | let linear_params = first_layer.parameters().flatten(); 135 | let weight = linear_params["weight"]; 136 | assert!(weight.abs().unwrap().sum(None).unwrap().item::() - 0.0 < 1e-6); 137 | } 138 | 139 | #[test] 140 | fn test_sgd_update_sequential_linear_params() { 141 | let lr = 1e-2; 142 | let input_dim = 2; 143 | let hidden_dim = 3; 144 | let output_dim = 2; 145 | 146 | // Test using a simple linear equation 147 | let m = array!(0.25); 148 | let b = array!(0.75); 149 | 150 | let mut model = Sequential::new() 151 | .append(Linear::new(input_dim, hidden_dim).unwrap()) 152 | .append(Linear::new(hidden_dim, output_dim).unwrap()); 153 | 154 | let loss = MseLossBuilder::new() 155 | .reduction(LossReduction::Mean) 156 | .build() 157 | .unwrap(); 158 | let loss_fn = |model: &mut Sequential, (x, y): (&Array, &Array)| { 159 | let y_pred = model.forward(x)?; 160 | loss.apply(&y_pred, y) 161 | }; 162 | 163 | let mut lg = nn::value_and_grad(loss_fn); 164 | 165 | let mut optimizer = Sgd::new(lr); 166 | 167 | let mut losses = vec![]; 168 | for _ in 0..100 { 169 | // Generate random data 170 | let x = uniform::<_, f32>(-5.0, 5.0, &[input_dim], None).unwrap(); 171 | let y = &m * &x + &b; 172 | 173 | eval([&x, &y]).unwrap(); 174 | 175 | // Compute the loss and gradients and update the model 176 | let (loss, grads) = lg(&mut model, (&x, &y)).unwrap(); 177 | optimizer.update(&mut model, grads).unwrap(); 178 | 179 | eval_params(model.parameters()).unwrap(); 180 | 181 | losses.push(loss.item::()); 182 | } 183 | 184 | // Check that it converges 185 | assert!( 186 | losses[0] > losses[losses.len() - 1], 187 | "Not converging loss: {:?}", 188 | 
losses 189 | ); 190 | } 191 | } 192 | -------------------------------------------------------------------------------- /mlx-rs/src/nn/embedding.rs: -------------------------------------------------------------------------------- 1 | //! Embedding layer. 2 | 3 | use crate::error::Exception; 4 | use crate::module::Module; 5 | use crate::module::Param; 6 | use crate::ops::indexing::IndexOp; 7 | use crate::quantization::Quantizable; 8 | use crate::Array; 9 | use mlx_macros::ModuleParameters; 10 | 11 | use super::QuantizedEmbedding; 12 | 13 | /// Implements a simple lookup table that maps each input integer to a high-dimensional vector. 14 | /// 15 | /// Typically used to embed discrete tokens for processing by neural networks. 16 | #[derive(Debug, Clone, ModuleParameters)] 17 | #[module(root = crate)] 18 | pub struct Embedding { 19 | /// The weight of the 20 | #[param] 21 | pub weight: Param, 22 | } 23 | 24 | impl Embedding { 25 | /// Creates a new [`Embedding`] layer. 26 | /// 27 | /// # Params 28 | /// 29 | /// - `embedding_count`: How many possible discrete tokens can we embed. Usually called the vocabulary size. 30 | /// - `dimensions`: The dimensionality of the embeddings. 31 | pub fn new(embedding_count: i32, dimensions: i32) -> Result { 32 | let scale = f32::sqrt(1.0 / (dimensions as f32)); 33 | let weight = 34 | crate::random::normal::(&[embedding_count, dimensions], None, None, None)? * scale; 35 | 36 | Ok(Self { 37 | weight: Param::new(weight), 38 | }) 39 | } 40 | 41 | /// Call the embedding layer as a linear layer. 42 | /// 43 | /// Use this for example when input embedding and output projection 44 | /// weights are tied. 45 | pub fn as_linear(&self, x: &Array) -> Result { 46 | crate::ops::matmul(x, self.weight.value.t()) 47 | } 48 | } 49 | 50 | impl Quantizable for Embedding { 51 | type Quantized = QuantizedEmbedding; 52 | 53 | type QuantizationError = Exception; 54 | 55 | fn try_into_quantized( 56 | self, 57 | group_size: i32, 58 | bits: i32, 59 | ) -> Result { 60 | QuantizedEmbedding::try_from_embedding(self, group_size, bits) 61 | } 62 | } 63 | 64 | impl Module<&Array> for Embedding { 65 | type Error = Exception; 66 | type Output = Array; 67 | 68 | fn forward(&mut self, x: &Array) -> Result { 69 | Ok(self.weight.index(x)) 70 | } 71 | 72 | fn training_mode(&mut self, _mode: bool) {} 73 | } 74 | 75 | #[cfg(test)] 76 | mod tests { 77 | use super::*; 78 | use float_eq::float_eq; 79 | use pretty_assertions::assert_eq; 80 | 81 | #[test] 82 | fn test_embedding() { 83 | crate::random::seed(557).unwrap(); 84 | let a = crate::random::randint::<_, i32>(0, 10, &[2, 8, 8, 4], None).unwrap(); 85 | assert_eq!(a.shape(), &[2, 8, 8, 4]); 86 | assert_eq!(a.dtype(), crate::Dtype::Int32); 87 | float_eq!( 88 | a.mean(None).unwrap().item::(), 89 | 4.605_468_8, 90 | abs <= 0.092_109_375 91 | ); 92 | float_eq!(a.sum(None).unwrap().item::(), 2358.0, abs <= 47.16); 93 | 94 | let result = Embedding::new(10, 8).unwrap().forward(&a).unwrap(); 95 | assert_eq!(result.shape(), &[2, 8, 8, 4, 8]); 96 | assert_eq!(result.dtype(), crate::Dtype::Float32); 97 | float_eq!( 98 | result.mean(None).unwrap().item::(), 99 | -0.001_197_346_3, 100 | abs <= 2.394_692_5e-5 101 | ); 102 | float_eq!( 103 | result.sum(None).unwrap().item::(), 104 | -4.904_330_3, 105 | abs <= 0.098_086_6 106 | ); 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /mlx-rs/src/nn/mod.rs: -------------------------------------------------------------------------------- 1 | #![deny(missing_docs, 
missing_debug_implementations)] 2 | 3 | //! Neural network support for MLX 4 | //! 5 | //! All modules provide a `new()` function that take mandatory parameters and other methods 6 | //! to set optional parameters. 7 | 8 | mod activation; 9 | mod container; 10 | mod convolution; 11 | mod convolution_transpose; 12 | mod dropout; 13 | mod embedding; 14 | mod linear; 15 | mod normalization; 16 | mod pooling; 17 | mod positional_encoding; 18 | mod quantized; 19 | mod recurrent; 20 | mod transformer; 21 | mod upsample; 22 | mod value_and_grad; 23 | 24 | pub use activation::*; 25 | pub use container::*; 26 | pub use convolution::*; 27 | pub use convolution_transpose::*; 28 | pub use dropout::*; 29 | pub use embedding::*; 30 | pub use linear::*; 31 | pub use normalization::*; 32 | pub use pooling::*; 33 | pub use positional_encoding::*; 34 | pub use quantized::*; 35 | pub use recurrent::*; 36 | pub use transformer::*; 37 | pub use upsample::*; 38 | pub use value_and_grad::*; 39 | -------------------------------------------------------------------------------- /mlx-rs/src/ops/mod.rs: -------------------------------------------------------------------------------- 1 | //! Operations 2 | 3 | mod arithmetic; 4 | mod conversion; 5 | mod convolution; 6 | mod cumulative; 7 | mod factory; 8 | mod io; 9 | mod logical; 10 | mod other; 11 | mod quantization; 12 | mod reduction; 13 | mod shapes; 14 | mod sort; 15 | 16 | pub mod indexing; 17 | 18 | pub use arithmetic::*; 19 | pub use convolution::*; 20 | pub use cumulative::*; 21 | pub use factory::*; 22 | pub use logical::*; 23 | pub use other::*; 24 | pub use quantization::*; 25 | pub use reduction::*; 26 | pub use shapes::*; 27 | pub use sort::*; 28 | -------------------------------------------------------------------------------- /mlx-rs/src/ops/quantization.rs: -------------------------------------------------------------------------------- 1 | use mlx_internal_macros::{default_device, generate_macro}; 2 | 3 | use crate::{error::Result, utils::guard::Guarded, Array, Stream}; 4 | 5 | /// Quantize the matrix `w` using `bits` bits per element. 6 | /// 7 | /// Note, every `group_size` elements in a row of `w` are quantized together. Hence, number of 8 | /// columns of `w` should be divisible by `group_size`. In particular, the rows of `w` are divided 9 | /// into groups of size `group_size` which are quantized together. 10 | /// 11 | /// > `quantized` currently only supports 2D inputs with dimensions which are multiples of 32 12 | /// 13 | /// For details, please see [this 14 | /// documentation](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.quantize.html) 15 | /// 16 | /// # Params 17 | /// 18 | /// - `w`: The input matrix 19 | /// - `group_size`: The size of the group in `w` that shares a scale and bias. (default: `64`) 20 | /// - `bits`: The number of bits occupied by each element of w in the returned quantized matrix. 
21 | ///   (default: 4)
22 | #[generate_macro]
23 | #[default_device]
24 | pub fn quantize_device(
25 |     w: impl AsRef<Array>,
26 |     #[optional] group_size: impl Into<Option<i32>>,
27 |     #[optional] bits: impl Into<Option<i32>>,
28 |     #[optional] stream: impl AsRef<Stream>,
29 | ) -> Result<(Array, Array, Array)> {
30 |     let group_size = group_size.into().unwrap_or(64);
31 |     let bits = bits.into().unwrap_or(4);
32 | 
33 |     <(Array, Array, Array) as Guarded>::try_from_op(|(res0, res1, res2)| unsafe {
34 |         mlx_sys::mlx_quantize(
35 |             res0,
36 |             res1,
37 |             res2,
38 |             w.as_ref().as_ptr(),
39 |             group_size,
40 |             bits,
41 |             stream.as_ref().as_ptr(),
42 |         )
43 |     })
44 | }
45 | 
46 | /// Perform the matrix multiplication with the quantized matrix `w`. The quantization uses one
47 | /// floating point scale and bias per `group_size` of elements. Each element in `w` takes `bits`
48 | /// bits and is packed in an unsigned 32 bit integer.
49 | #[allow(clippy::too_many_arguments)]
50 | #[generate_macro]
51 | #[default_device]
52 | pub fn quantized_matmul_device(
53 |     x: impl AsRef<Array>,
54 |     w: impl AsRef<Array>,
55 |     scales: impl AsRef<Array>,
56 |     biases: impl AsRef<Array>,
57 |     #[optional] transpose: impl Into<Option<bool>>,
58 |     #[optional] group_size: impl Into<Option<i32>>,
59 |     #[optional] bits: impl Into<Option<i32>>,
60 |     #[optional] stream: impl AsRef<Stream>,
61 | ) -> Result<Array> {
62 |     let transpose = transpose.into().unwrap_or(false);
63 |     let group_size = group_size.into().unwrap_or(64);
64 |     let bits = bits.into().unwrap_or(4);
65 | 
66 |     <Array as Guarded>::try_from_op(|res| unsafe {
67 |         mlx_sys::mlx_quantized_matmul(
68 |             res,
69 |             x.as_ref().as_ptr(),
70 |             w.as_ref().as_ptr(),
71 |             scales.as_ref().as_ptr(),
72 |             biases.as_ref().as_ptr(),
73 |             transpose,
74 |             group_size,
75 |             bits,
76 |             stream.as_ref().as_ptr(),
77 |         )
78 |     })
79 | }
80 | 
81 | /// Dequantize the matrix `w` using the provided `scales` and `biases` and the `group_size` and
82 | /// `bits` configuration.
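///
/// A minimal round-trip sketch (the shape and the `128`/`4` group-size/bits
/// values are illustrative, mirroring this module's tests):
///
/// ```rust,ignore
/// use mlx_rs::ops::{quantize, dequantize};
///
/// let w = mlx_rs::Array::ones::<f32>(&[128, 512])?;
/// let (w_q, scales, biases) = quantize(&w, 128, 4)?;
/// let w_hat = dequantize(&w_q, &scales, &biases, 128, 4)?;
/// ```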
83 | ///
84 | /// For details, please see [this
85 | /// documentation](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.dequantize.html)
86 | #[generate_macro]
87 | #[default_device]
88 | pub fn dequantize_device(
89 |     w: impl AsRef<Array>,
90 |     scales: impl AsRef<Array>,
91 |     biases: impl AsRef<Array>,
92 |     #[optional] group_size: impl Into<Option<i32>>,
93 |     #[optional] bits: impl Into<Option<i32>>,
94 |     #[optional] stream: impl AsRef<Stream>,
95 | ) -> Result<Array> {
96 |     let group_size = group_size.into().unwrap_or(64);
97 |     let bits = bits.into().unwrap_or(4);
98 | 
99 |     <Array as Guarded>::try_from_op(|res| unsafe {
100 |         mlx_sys::mlx_dequantize(
101 |             res,
102 |             w.as_ref().as_ptr(),
103 |             scales.as_ref().as_ptr(),
104 |             biases.as_ref().as_ptr(),
105 |             group_size,
106 |             bits,
107 |             stream.as_ref().as_ptr(),
108 |         )
109 |     })
110 | }
111 | 
112 | #[cfg(test)]
113 | mod tests {
114 |     use crate::{
115 |         ops::{dequantize, expand_dims, quantize},
116 |         Array,
117 |     };
118 | 
119 |     #[test]
120 |     fn test_quantize_dequantize() {
121 |         let x1 = Array::ones::<f32>(&[128, 1]).unwrap();
122 |         let x2 = expand_dims(Array::arange::<_, f32>(0, 512, None).unwrap(), 0).unwrap();
123 |         let x = x1 * x2;
124 | 
125 |         for i in [2, 4, 8].iter() {
126 |             let el_per_int = 32 / i;
127 |             let (x_q, scales, biases) = quantize(&x, 128, *i).unwrap();
128 |             assert_eq!(x_q.shape(), [128, 512 / el_per_int]);
129 |             assert_eq!(scales.shape(), [128, 4]);
130 |             assert_eq!(biases.shape(), [128, 4]);
131 | 
132 |             let x_hat = dequantize(&x_q, &scales, &biases, 128, *i).unwrap();
133 |             let max_diff = ((&x - &x_hat).abs().unwrap().max(None).unwrap()).item::<f32>();
134 |             assert!(max_diff <= 127.0 / (1 << i) as f32);
135 |         }
136 |     }
137 | }
138 | 
--------------------------------------------------------------------------------
/mlx-rs/src/optimizers/adadelta.rs:
--------------------------------------------------------------------------------
1 | use std::rc::Rc;
2 | 
3 | use crate::{
4 |     array,
5 |     ops::sqrt,
6 |     utils::{get_mut_or_insert_with, Updatable},
7 |     Array,
8 | };
9 | use mlx_internal_macros::{generate_builder, Buildable};
10 | 
11 | use crate::error::AdaDeltaBuildError;
12 | 
13 | use super::*;
14 | 
15 | generate_builder! {
16 |     /// The AdaDelta optimizer with a learning rate.
17 |     ///
18 |     /// Please refer to the original paper for more details:
19 |     ///
20 |     /// [1]: Zeiler, M.D., 2012. ADADELTA: an adaptive learning rate method. arXiv preprint arXiv:1212.5701.
21 |     #[derive(Debug, Clone, Buildable)]
22 |     #[buildable(root = crate)]
23 |     #[builder(
24 |         build_with = build_adadelta,
25 |         err = AdaDeltaBuildError,
26 |         root = crate
27 |     )]
28 |     pub struct AdaDelta {
29 |         /// The learning rate
30 |         #[builder(ty_override = f32)]
31 |         pub lr: Array,
32 | 
33 |         /// The coefficient used for computing a running average of squared gradients. Default to
34 |         /// [`AdaDelta::DEFAULT_RHO`].
35 |         #[builder(optional, ty_override = f32, default = AdaDelta::DEFAULT_RHO)]
36 |         pub rho: Array,
37 | 
38 |         /// The epsilon added to the denominator to improve numerical stability. Default to
39 |         /// [`AdaDelta::DEFAULT_EPS`].
40 | #[builder(optional, ty_override = f32, default = AdaDelta::DEFAULT_EPS)] 41 | pub eps: Array, 42 | 43 | /// Inner state 44 | #[builder(ignore)] 45 | pub state: State<(Array, Array)>, 46 | } 47 | } 48 | 49 | /// Builds a new [`AdaDelta`] optimizer 50 | fn build_adadelta(builder: AdaDeltaBuilder) -> Result { 51 | let rho = builder.rho; 52 | let eps = builder.eps; 53 | 54 | if rho < 0.0 { 55 | return Err(AdaDeltaBuildError::NegativeRho); 56 | } 57 | 58 | if eps < 0.0 { 59 | return Err(AdaDeltaBuildError::NegativeEps); 60 | } 61 | 62 | Ok(AdaDelta { 63 | lr: array!(builder.lr), 64 | rho: array!(rho), 65 | eps: array!(eps), 66 | state: State::new(), 67 | }) 68 | } 69 | 70 | impl AdaDelta { 71 | /// Default value for `rho` 72 | pub const DEFAULT_RHO: f32 = 0.99; 73 | 74 | /// Default value for `eps` 75 | pub const DEFAULT_EPS: f32 = 1e-6; 76 | } 77 | 78 | impl Optimizer for AdaDelta { 79 | type State = State<(Array, Array)>; 80 | 81 | fn state(&self) -> &Self::State { 82 | &self.state 83 | } 84 | 85 | fn state_mut(&mut self) -> &mut Self::State { 86 | &mut self.state 87 | } 88 | 89 | fn update_single( 90 | &mut self, 91 | key: &Rc, 92 | gradient: &Array, 93 | parameter: &mut Array, 94 | ) -> crate::error::Result<()> { 95 | let (v, u) = get_mut_or_insert_with(&mut self.state, key, || (array!(0.0), array!(0.0))); 96 | 97 | let one_minus_rho = array!(1.0).subtract(&self.rho)?; 98 | let first_term = self.rho.multiply(&v)?; 99 | let second_term = one_minus_rho.multiply(gradient.square()?)?; 100 | let v_new = first_term.add(&second_term)?; 101 | 102 | let num = sqrt(&u.add(&self.eps)?)?; 103 | let den = sqrt(&v_new.add(&self.eps)?)?; 104 | let d = num.divide(&den)?.multiply(gradient)?; 105 | let first_term = self.rho.multiply(&u)?; 106 | let second_term = one_minus_rho.multiply(d.square()?)?; 107 | let u_new = first_term.add(&second_term)?; 108 | 109 | let param_new = parameter.subtract(self.lr.multiply(d)?)?; 110 | 111 | *parameter = param_new; 112 | 113 | *v = v_new; 114 | *u = u_new; 115 | 116 | Ok(()) 117 | } 118 | } 119 | 120 | impl Updatable for AdaDelta { 121 | fn updatable_states_len(&self) -> usize { 122 | self.state.len() * 2 123 | } 124 | 125 | fn updatable_states(&self) -> impl IntoIterator { 126 | use itertools::Itertools; 127 | 128 | self.state 129 | .iter() 130 | .sorted_by(|a, b| a.0.cmp(b.0)) 131 | .flat_map(|(_, (v, u))| [v, u]) 132 | } 133 | 134 | fn updatable_states_mut(&mut self) -> impl IntoIterator { 135 | use itertools::Itertools; 136 | 137 | self.state 138 | .iter_mut() 139 | .sorted_by(|a, b| a.0.cmp(b.0)) 140 | .flat_map(|(_, (v, u))| [v, u]) 141 | } 142 | } 143 | 144 | impl_updatable_for_mut_optimizer!(AdaDelta); 145 | -------------------------------------------------------------------------------- /mlx-rs/src/optimizers/adagrad.rs: -------------------------------------------------------------------------------- 1 | use std::{convert::Infallible, rc::Rc}; 2 | 3 | use crate::{array, ops::square, utils::Updatable, Array}; 4 | use mlx_internal_macros::{generate_builder, Buildable}; 5 | 6 | use crate::utils::get_mut_or_insert_with; 7 | 8 | use super::*; 9 | 10 | generate_builder! { 11 | /// The Adagrad optimizer. 12 | /// 13 | /// Please refer to the original paper for more details: 14 | /// 15 | /// [1]: Duchi, J., Hazan, E. and Singer, Y., 2011. Adaptive subgradient methods for online 16 | /// learning and stochastic optimization. JMLR 2011. 
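    ///
    /// In detail, for a gradient `g` the update below (see `update_single`) keeps a
    /// running sum of squared gradients and scales the step by its square root:
    ///
    /// ```text
    /// v     <- v + g^2
    /// param <- param - lr * g / (sqrt(v) + eps)
    /// ```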
17 |     #[derive(Debug, Clone, Buildable)]
18 |     #[buildable(root = crate)]
19 |     #[builder(
20 |         build_with = build_adagrad,
21 |         root = crate
22 |     )]
23 |     pub struct AdaGrad {
24 |         /// Learning rate
25 |         #[builder(ty_override = f32)]
26 |         pub lr: Array,
27 | 
28 |         /// The epsilon added to the denominator to improve numerical stability. Default to
29 |         /// [`AdaGrad::DEFAULT_EPS`].
30 |         #[builder(optional, ty_override = f32, default = AdaGrad::DEFAULT_EPS)]
31 |         pub eps: Array,
32 | 
33 |         /// Inner state
34 |         #[builder(ignore)]
35 |         pub state: State<Array>,
36 |     }
37 | }
38 | 
39 | /// Builds a new [`AdaGrad`].
40 | fn build_adagrad(builder: AdaGradBuilder) -> Result<AdaGrad, Infallible> {
41 |     let eps = array!(builder.eps);
42 | 
43 |     Ok(AdaGrad {
44 |         lr: array!(builder.lr),
45 |         eps,
46 |         state: State::new(),
47 |     })
48 | }
49 | 
50 | impl AdaGrad {
51 |     /// Default value for `eps`.
52 |     pub const DEFAULT_EPS: f32 = 1e-8;
53 | }
54 | 
55 | impl Optimizer for AdaGrad {
56 |     type State = State<Array>;
57 | 
58 |     fn state(&self) -> &Self::State {
59 |         &self.state
60 |     }
61 | 
62 |     fn state_mut(&mut self) -> &mut Self::State {
63 |         &mut self.state
64 |     }
65 | 
66 |     fn update_single(
67 |         &mut self,
68 |         key: &Rc<str>,
69 |         gradient: &Array,
70 |         parameter: &mut Array,
71 |     ) -> crate::error::Result<()> {
72 |         let state = get_mut_or_insert_with(&mut self.state, key, || array!(0.0));
73 | 
74 |         let v = state.add(square(gradient)?)?;
75 | 
76 |         let num = self.lr.multiply(gradient)?;
77 |         let den = v.sqrt()?.add(&self.eps)?;
78 |         let new_param = parameter.subtract(num.divide(&den)?)?;
79 | 
80 |         *state = v;
81 |         *parameter = new_param;
82 | 
83 |         Ok(())
84 |     }
85 | }
86 | 
87 | impl Updatable for AdaGrad {
88 |     fn updatable_states_len(&self) -> usize {
89 |         self.state.len()
90 |     }
91 | 
92 |     fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
93 |         use itertools::Itertools;
94 | 
95 |         self.state
96 |             .iter()
97 |             .sorted_by(|a, b| a.0.cmp(b.0))
98 |             .map(|(_, v)| v)
99 |     }
100 | 
101 |     fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
102 |         use itertools::Itertools;
103 | 
104 |         self.state
105 |             .iter_mut()
106 |             .sorted_by(|a, b| a.0.cmp(b.0))
107 |             .map(|(_, v)| v)
108 |     }
109 | }
110 | 
111 | impl_updatable_for_mut_optimizer!(AdaGrad);
112 | 
--------------------------------------------------------------------------------
/mlx-rs/src/optimizers/adam.rs:
--------------------------------------------------------------------------------
1 | use std::convert::Infallible;
2 | 
3 | use mlx_internal_macros::{generate_builder, Buildable};
4 | 
5 | use crate::{array, utils::get_mut_or_insert_with};
6 | 
7 | use super::*;
8 | 
9 | /// `(f32, f32)`. Type alias for betas in the Adam/AdamW/Adamax optimizer builders due to a
10 | /// limitation in the `generate_builder` macro
11 | pub type Betas = (f32, f32); // The macro right now can't handle raw tuple types
12 | 
13 | generate_builder! {
14 |     /// The Adam optimizer.
15 |     ///
16 |     /// Please refer to the original paper for more details:
17 |     ///
18 |     /// [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic optimization. ICLR 2015.
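    ///
    /// In detail, the update implemented by `adam_apply_single` below (without
    /// bias correction) is:
    ///
    /// ```text
    /// m     <- b1 * m + (1 - b1) * g
    /// v     <- b2 * v + (1 - b2) * g^2
    /// param <- param - lr * m / (sqrt(v) + eps)
    /// ```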
19 |     #[derive(Debug, Clone, Buildable)]
20 |     #[buildable(root = crate)]
21 |     #[builder(
22 |         build_with = build_adam,
23 |         root = crate
24 |     )]
25 |     pub struct Adam {
26 |         /// The learning rate
27 |         #[builder(ty_override = f32)]
28 |         pub lr: Array,
29 | 
30 |         /// The coefficients used for computing running averages of the gradient and its square
31 |         ///
32 |         /// Default to [`Adam::DEFAULT_BETAS`]
33 |         #[builder(optional, ty_override = Betas, default = Adam::DEFAULT_BETAS)]
34 |         pub betas: (Array, Array),
35 | 
36 |         /// The epsilon added to the denominator to improve numerical stability
37 |         ///
38 |         /// Default to [`Adam::DEFAULT_EPS`]
39 |         #[builder(optional, ty_override = f32, default = Adam::DEFAULT_EPS)]
40 |         pub eps: Array,
41 | 
42 |         /// Inner state
43 |         #[builder(ignore)]
44 |         pub state: State<(Array, Array)>,
45 |     }
46 | }
47 | 
48 | /// Builds a new [`Adam`].
49 | fn build_adam(builder: AdamBuilder) -> Result<Adam, Infallible> {
50 |     let lr = array!(builder.lr);
51 |     let betas = builder.betas;
52 |     let eps = array!(builder.eps);
53 | 
54 |     Ok(Adam {
55 |         lr,
56 |         betas: (array!(betas.0), array!(betas.1)),
57 |         eps,
58 |         state: State::new(),
59 |     })
60 | }
61 | 
62 | impl Adam {
63 |     /// Default values for `betas`
64 |     pub const DEFAULT_BETAS: (f32, f32) = (0.9, 0.999);
65 | 
66 |     /// Default value for `eps`
67 |     pub const DEFAULT_EPS: f32 = 1e-8;
68 | }
69 | 
70 | impl Optimizer for Adam {
71 |     type State = State<(Array, Array)>;
72 | 
73 |     fn state(&self) -> &Self::State {
74 |         &self.state
75 |     }
76 | 
77 |     fn state_mut(&mut self) -> &mut Self::State {
78 |         &mut self.state
79 |     }
80 | 
81 |     fn update_single(
82 |         &mut self,
83 |         key: &Rc<str>,
84 |         gradient: &Array,
85 |         parameter: &mut Array,
86 |     ) -> crate::error::Result<()> {
87 |         let betas = &self.betas;
88 |         let state = get_mut_or_insert_with(&mut self.state, key, || (array!(0.0), array!(0.0)));
89 | 
90 |         let (new_parameter, new_state) =
91 |             adam_apply_single(&self.lr, betas, &self.eps, gradient, parameter, state)?;
92 | 
93 |         *state = new_state;
94 |         *parameter = new_parameter;
95 | 
96 |         Ok(())
97 |     }
98 | }
99 | 
100 | // Returns (new_parameter, (new_m, new_v))
101 | pub(super) fn adam_apply_single(
102 |     lr: &Array,
103 |     betas: &(Array, Array),
104 |     eps: &Array,
105 |     gradient: &Array,
106 |     parameter: &Array,
107 |     state: &(Array, Array),
108 | ) -> crate::error::Result<(Array, (Array, Array))> {
109 |     let (b1, b2) = betas;
110 |     let (m, v) = state;
111 | 
112 |     let one_minus_b1 = array!(1.0).subtract(b1)?;
113 |     let one_minus_b2 = array!(1.0).subtract(b2)?;
114 | 
115 |     let new_m = b1.multiply(m)?.add(&one_minus_b1.multiply(gradient)?)?;
116 |     let new_v = b2
117 |         .multiply(v)?
118 | .add(&one_minus_b2.multiply(gradient.square()?)?)?; 119 | 120 | let new_parameter = 121 | parameter.subtract(&lr.multiply(&new_m.divide(&new_v.sqrt()?.add(eps)?)?)?)?; 122 | 123 | Ok((new_parameter, (new_m, new_v))) 124 | } 125 | 126 | impl Updatable for Adam { 127 | fn updatable_states_len(&self) -> usize { 128 | self.state.len() * 2 129 | } 130 | 131 | fn updatable_states(&self) -> impl IntoIterator { 132 | use itertools::Itertools; 133 | 134 | self.state 135 | .iter() 136 | .sorted_by(|a, b| a.0.cmp(b.0)) 137 | .flat_map(|(_, (v, u))| vec![v, u]) 138 | } 139 | 140 | fn updatable_states_mut(&mut self) -> impl IntoIterator { 141 | use itertools::Itertools; 142 | 143 | self.state 144 | .iter_mut() 145 | .sorted_by(|a, b| a.0.cmp(b.0)) 146 | .flat_map(|(_, (v, u))| vec![v, u]) 147 | } 148 | } 149 | 150 | impl_updatable_for_mut_optimizer!(Adam); 151 | -------------------------------------------------------------------------------- /mlx-rs/src/optimizers/adamax.rs: -------------------------------------------------------------------------------- 1 | use std::{convert::Infallible, rc::Rc}; 2 | 3 | use mlx_internal_macros::{generate_builder, Buildable}; 4 | 5 | use crate::{ 6 | array, 7 | ops::{abs, maximum}, 8 | utils::{get_mut_or_insert_with, Updatable}, 9 | Array, 10 | }; 11 | 12 | use super::*; 13 | 14 | generate_builder! { 15 | /// The Adamax optimizer, a variant of Adam based on the infinity norm [1]. 16 | /// 17 | /// Our Adam implementation follows the original paper and omits the bias 18 | /// correction in the first and second moment estimates. In detail, 19 | /// 20 | /// [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic optimization. ICLR 2015. 21 | #[derive(Debug, Clone, Buildable)] 22 | #[buildable(root = crate)] 23 | #[builder( 24 | build_with = build_adamax, 25 | root = crate 26 | )] 27 | pub struct Adamax { 28 | /// The learning rate. 29 | #[builder(ty_override = f32)] 30 | pub lr: Array, 31 | 32 | /// The beta coefficients 33 | #[builder(optional, ty_override = Betas, default = Adamax::DEFAULT_BETAS)] 34 | pub betas: (Array, Array), 35 | 36 | /// The epsilon added to the denominator to improve numerical stability. 37 | #[builder(optional, ty_override = f32, default = Adamax::DEFAULT_EPS)] 38 | pub eps: Array, 39 | 40 | /// Inner state. 41 | #[builder(ignore)] 42 | pub state: State<(Array, Array)>, 43 | } 44 | } 45 | 46 | fn build_adamax(builder: AdamaxBuilder) -> Result { 47 | let lr = builder.lr; 48 | let betas = builder.betas; 49 | let eps = builder.eps; 50 | 51 | Ok(Adamax { 52 | lr: array!(lr), 53 | betas: (array!(betas.0), array!(betas.1)), 54 | eps: array!(eps), 55 | state: State::new(), 56 | }) 57 | } 58 | 59 | impl Adamax { 60 | /// Default value for `betas`. 61 | pub const DEFAULT_BETAS: (f32, f32) = (0.9, 0.999); 62 | 63 | /// Default value for `eps`. 
64 | pub const DEFAULT_EPS: f32 = 1e-8; 65 | } 66 | 67 | impl Optimizer for Adamax { 68 | type State = State<(Array, Array)>; 69 | 70 | fn state(&self) -> &Self::State { 71 | &self.state 72 | } 73 | 74 | fn state_mut(&mut self) -> &mut Self::State { 75 | &mut self.state 76 | } 77 | 78 | fn update_single( 79 | &mut self, 80 | key: &Rc, 81 | gradient: &Array, 82 | parameter: &mut Array, 83 | ) -> crate::error::Result<()> { 84 | let (b1, b2) = &self.betas; 85 | let (m, v) = get_mut_or_insert_with(&mut self.state, key, || (array!(0.0), array!(0.0))); 86 | 87 | let one_minus_b1 = array!(1.0).subtract(b1)?; 88 | let new_m = b1.multiply(&*m)?.add(&one_minus_b1.multiply(gradient)?)?; 89 | let new_v = maximum(b2.multiply(&*v)?, abs(gradient)?)?; 90 | 91 | let new_parameter = 92 | parameter.subtract(self.lr.multiply(&new_m)?.divide(&new_v.add(&self.eps)?)?)?; 93 | 94 | *m = new_m; 95 | *v = new_v; 96 | *parameter = new_parameter; 97 | 98 | Ok(()) 99 | } 100 | } 101 | 102 | impl Updatable for Adamax { 103 | fn updatable_states_len(&self) -> usize { 104 | self.state.len() * 2 105 | } 106 | 107 | fn updatable_states(&self) -> impl IntoIterator { 108 | use itertools::Itertools; 109 | 110 | self.state 111 | .iter() 112 | .sorted_by(|a, b| a.0.cmp(b.0)) 113 | .flat_map(|(_, (v, u))| vec![v, u]) 114 | } 115 | 116 | fn updatable_states_mut(&mut self) -> impl IntoIterator { 117 | use itertools::Itertools; 118 | 119 | self.state 120 | .iter_mut() 121 | .sorted_by(|a, b| a.0.cmp(b.0)) 122 | .flat_map(|(_, (v, u))| vec![v, u]) 123 | } 124 | } 125 | 126 | impl_updatable_for_mut_optimizer!(Adamax); 127 | -------------------------------------------------------------------------------- /mlx-rs/src/optimizers/adamw.rs: -------------------------------------------------------------------------------- 1 | use std::convert::Infallible; 2 | 3 | use mlx_internal_macros::{generate_builder, Buildable}; 4 | 5 | use crate::{ 6 | array, 7 | utils::{get_mut_or_insert_with, Updatable}, 8 | Array, 9 | }; 10 | 11 | use super::*; 12 | 13 | generate_builder! { 14 | /// The AdamW optimizer [1]. 15 | /// 16 | /// Following the above convention, in contrast with [1], we do not use bias 17 | /// correction in the first and second moments for AdamW. We update the weights 18 | /// with a `weightDecay` lambda value: 19 | /// 20 | /// [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay regularization. ICLR 2019. 21 | #[derive(Debug, Clone, Buildable)] 22 | #[buildable(root = crate)] 23 | #[builder( 24 | build_with = build_adamw, 25 | root = crate 26 | )] 27 | pub struct AdamW { 28 | /// The learning rate. 29 | #[builder(ty_override = f32)] 30 | pub lr: Array, 31 | 32 | /// The coefficients used for computing running averages of the gradient and its square. 33 | /// 34 | /// Default to [`AdamW::DEFAULT_BETAS`]. 35 | #[builder(optional, ty_override = Betas, default = AdamW::DEFAULT_BETAS)] 36 | pub betas: (Array, Array), 37 | 38 | /// The epsilon added to the denominator to improve numerical stability. 39 | /// 40 | /// Default to [`AdamW::DEFAULT_EPS`]. 41 | #[builder(optional, ty_override = f32, default = AdamW::DEFAULT_EPS)] 42 | pub eps: Array, 43 | 44 | /// The weight decay 45 | /// 46 | /// Default to [`AdamW::DEFAULT_WEIGHT_DECAY`]. 47 | #[builder(optional, ty_override = f32, default = AdamW::DEFAULT_WEIGHT_DECAY)] 48 | pub weight_decay: Array, 49 | 50 | /// Inner state. 51 | #[builder(ignore)] 52 | pub state: State<(Array, Array)>, 53 | } 54 | } 55 | 56 | /// Builds a new [`AdamW`] optimizer. 
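///
/// The update itself (in `update_single` below) first decays the parameter,
/// `param <- param * (1 - lr * weight_decay)`, and then applies the regular
/// Adam step via [`adam_apply_single`].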
57 | fn build_adamw(builder: AdamWBuilder) -> Result { 58 | let lr = builder.lr; 59 | let betas = builder.betas; 60 | let eps = builder.eps; 61 | let weight_decay = builder.weight_decay; 62 | 63 | Ok(AdamW { 64 | lr: array!(lr), 65 | betas: (array!(betas.0), array!(betas.1)), 66 | eps: array!(eps), 67 | weight_decay: array!(weight_decay), 68 | state: State::new(), 69 | }) 70 | } 71 | 72 | impl AdamW { 73 | /// Default value for `betas`. 74 | pub const DEFAULT_BETAS: (f32, f32) = super::Adam::DEFAULT_BETAS; 75 | 76 | /// Default value for `eps`. 77 | pub const DEFAULT_EPS: f32 = super::Adam::DEFAULT_EPS; 78 | 79 | /// Default value for `weight_decay`. 80 | pub const DEFAULT_WEIGHT_DECAY: f32 = 0.01; 81 | } 82 | 83 | impl Optimizer for AdamW { 84 | type State = State<(Array, Array)>; 85 | 86 | fn state(&self) -> &Self::State { 87 | &self.state 88 | } 89 | 90 | fn state_mut(&mut self) -> &mut Self::State { 91 | &mut self.state 92 | } 93 | 94 | fn update_single( 95 | &mut self, 96 | key: &std::rc::Rc, 97 | gradient: &Array, 98 | parameter: &mut Array, 99 | ) -> Result<(), crate::error::Exception> { 100 | let betas = &self.betas; 101 | let state = get_mut_or_insert_with(&mut self.state, key, || (array!(0.0), array!(0.0))); 102 | 103 | // SAFETY: These are all single-element arrays and won't panic. 104 | let one_minus_lr_wd = array!(1.0) - (&self.lr * &self.weight_decay); 105 | let decayed_parameter = &*parameter * &one_minus_lr_wd; 106 | 107 | let (new_parameter, new_states) = super::adam_apply_single( 108 | &self.lr, 109 | betas, 110 | &self.eps, 111 | gradient, 112 | &decayed_parameter, 113 | state, 114 | )?; 115 | 116 | *state = new_states; 117 | *parameter = new_parameter; 118 | 119 | Ok(()) 120 | } 121 | } 122 | 123 | impl Updatable for AdamW { 124 | fn updatable_states_len(&self) -> usize { 125 | self.state.len() * 2 126 | } 127 | 128 | fn updatable_states(&self) -> impl IntoIterator { 129 | use itertools::Itertools; 130 | 131 | self.state 132 | .iter() 133 | .sorted_by(|a, b| a.0.cmp(b.0)) 134 | .flat_map(|(_, (v, u))| vec![v, u]) 135 | } 136 | 137 | fn updatable_states_mut(&mut self) -> impl IntoIterator { 138 | use itertools::Itertools; 139 | 140 | self.state 141 | .iter_mut() 142 | .sorted_by(|a, b| a.0.cmp(b.0)) 143 | .flat_map(|(_, (v, u))| vec![v, u]) 144 | } 145 | } 146 | 147 | impl_updatable_for_mut_optimizer!(AdamW); 148 | -------------------------------------------------------------------------------- /mlx-rs/src/optimizers/lion.rs: -------------------------------------------------------------------------------- 1 | use mlx_internal_macros::{generate_builder, Buildable}; 2 | 3 | use crate::{ 4 | array, 5 | utils::{get_mut_or_insert_with, Updatable}, 6 | Array, 7 | }; 8 | 9 | use super::*; 10 | 11 | generate_builder! { 12 | /// The Lion optimizer [1]. 13 | /// 14 | /// Since updates are computed through the sign operation, they tend to have larger norm than 15 | /// for other optimizers such as SGD and Adam. We recommend a learning rate that is 3-10x 16 | /// smaller than AdamW and a weight decay 3-10x larger than AdamW to maintain the strength `(lr 17 | /// * wd)`. Our Lion implementation follows the original paper. In detail, 18 | /// 19 | /// [1]: Chen, X. Symbolic Discovery of Optimization Algorithms. arXiv preprint 20 | /// arXiv:2302.06675. 21 | #[derive(Debug, Clone, Buildable)] 22 | #[buildable(root = crate)] 23 | #[builder( 24 | build_with = build_lion, 25 | root = crate 26 | )] 27 | pub struct Lion { 28 | /// The learning rate. 
29 | pub lr: f32, 30 | 31 | /// The coefficients used for computing running averages of the gradient and its square. 32 | /// Default to [`Lion::DEFAULT_BETAS`]. 33 | #[builder(optional, ty_override = Betas, default = Lion::DEFAULT_BETAS)] 34 | pub betas: (Array, Array), 35 | 36 | /// The weight decay. Default to [`Lion::DEFAULT_WEIGHT_DECAY`]. 37 | #[builder(optional, default = Lion::DEFAULT_WEIGHT_DECAY)] 38 | pub weight_decay: f32, 39 | 40 | /// Inner state. 41 | #[builder(ignore)] 42 | pub state: State, 43 | } 44 | } 45 | 46 | fn build_lion(builder: LionBuilder) -> Result { 47 | let lr = builder.lr; 48 | let betas = builder.betas; 49 | let weight_decay = builder.weight_decay; 50 | 51 | Ok(Lion { 52 | lr, 53 | betas: (array!(betas.0), array!(betas.1)), 54 | weight_decay, 55 | state: State::new(), 56 | }) 57 | } 58 | 59 | impl Lion { 60 | /// Default values for `betas` 61 | pub const DEFAULT_BETAS: (f32, f32) = (0.9, 0.999); 62 | 63 | /// Default value for `weight_decay` 64 | pub const DEFAULT_WEIGHT_DECAY: f32 = 0.0; 65 | } 66 | 67 | impl Optimizer for Lion { 68 | type State = State; 69 | 70 | fn state(&self) -> &Self::State { 71 | &self.state 72 | } 73 | 74 | fn state_mut(&mut self) -> &mut Self::State { 75 | &mut self.state 76 | } 77 | 78 | fn update_single( 79 | &mut self, 80 | key: &std::rc::Rc, 81 | gradient: &Array, 82 | parameter: &mut Array, 83 | ) -> Result<(), crate::error::Exception> { 84 | use crate::ops::sign; 85 | 86 | let (b1, b2) = &self.betas; 87 | let m = get_mut_or_insert_with(&mut self.state, key, || array!(0.0)); 88 | 89 | let one_minus_b1 = array!(1.0).subtract(b1)?; 90 | let one_minus_b2 = array!(1.0).subtract(b2)?; 91 | 92 | let c = b1.multiply(&m)?.add(&one_minus_b1.multiply(gradient)?)?; 93 | *m = b2.multiply(&m)?.add(&one_minus_b2.multiply(gradient)?)?; 94 | 95 | if self.weight_decay > 0.0 { 96 | // SAFETY: These coeffs are all single-element arrays and won't panic. 97 | *parameter = array!(1.0 - self.lr * self.weight_decay) * &*parameter; 98 | } 99 | 100 | let lr = array!(self.lr); 101 | *parameter = parameter.subtract(lr.multiply(sign(&c)?)?)?; 102 | 103 | Ok(()) 104 | } 105 | } 106 | 107 | impl Updatable for Lion { 108 | fn updatable_states_len(&self) -> usize { 109 | self.state.len() 110 | } 111 | 112 | fn updatable_states(&self) -> impl IntoIterator { 113 | use itertools::Itertools; 114 | 115 | self.state 116 | .iter() 117 | .sorted_by(|a, b| a.0.cmp(b.0)) 118 | .map(|(_, v)| v) 119 | } 120 | 121 | fn updatable_states_mut(&mut self) -> impl IntoIterator { 122 | use itertools::Itertools; 123 | 124 | self.state 125 | .iter_mut() 126 | .sorted_by(|a, b| a.0.cmp(b.0)) 127 | .map(|(_, v)| v) 128 | } 129 | } 130 | 131 | impl_updatable_for_mut_optimizer!(Lion); 132 | -------------------------------------------------------------------------------- /mlx-rs/src/optimizers/rmsprop.rs: -------------------------------------------------------------------------------- 1 | use std::rc::Rc; 2 | 3 | use crate::{ 4 | array, 5 | ops::{sqrt, square}, 6 | Array, 7 | }; 8 | use mlx_internal_macros::{generate_builder, Buildable}; 9 | 10 | use crate::{error::RmsPropBuildError, utils::get_mut_or_insert_with}; 11 | 12 | use super::*; 13 | 14 | generate_builder! { 15 | /// The RMSprop optimizer [1]. 16 | /// 17 | /// [1]: Tieleman, T. and Hinton, G. 2012. 
Lecture 6.5-rmsprop, coursera: Neural networks for
18 |     /// machine learning
19 |     #[derive(Debug, Clone, Buildable)]
20 |     #[buildable(root = crate)]
21 |     #[builder(
22 |         build_with = build_rmsprop,
23 |         err = RmsPropBuildError,
24 |         root = crate
25 |     )]
26 |     pub struct RmsProp {
27 |         /// Learning rate
28 |         #[builder(ty_override = f32)]
29 |         pub lr: Array,
30 | 
31 |         /// The smoothing constant. Default to [`RmsProp::DEFAULT_ALPHA`] if not specified.
32 |         #[builder(optional, ty_override = f32, default = RmsProp::DEFAULT_ALPHA)]
33 |         pub alpha: Array,
34 | 
35 |         /// The epsilon added to the denominator to improve numerical stability. Default to
36 |         /// [`RmsProp::DEFAULT_EPSILON`] if not specified.
37 |         #[builder(optional, ty_override = f32, default = RmsProp::DEFAULT_EPSILON)]
38 |         pub epsilon: Array,
39 | 
40 |         /// Inner state
41 |         #[builder(ignore)]
42 |         pub state: State<Array>,
43 |     }
44 | }
45 | 
46 | fn build_rmsprop(builder: RmsPropBuilder) -> Result<RmsProp, RmsPropBuildError> {
47 |     let lr = builder.lr;
48 |     let alpha = builder.alpha;
49 |     let epsilon = builder.epsilon;
50 | 
51 |     if alpha < 0.0 {
52 |         return Err(RmsPropBuildError::NegativeAlpha);
53 |     }
54 | 
55 |     if epsilon < 0.0 {
56 |         return Err(RmsPropBuildError::NegativeEpsilon);
57 |     }
58 | 
59 |     Ok(RmsProp {
60 |         lr: array!(lr),
61 |         alpha: array!(alpha),
62 |         epsilon: array!(epsilon),
63 |         state: State::new(),
64 |     })
65 | }
66 | 
67 | impl RmsProp {
68 |     /// Default alpha if not specified.
69 |     pub const DEFAULT_ALPHA: f32 = 0.99;
70 | 
71 |     /// Default epsilon if not specified.
72 |     pub const DEFAULT_EPSILON: f32 = 1e-8;
73 | }
74 | 
75 | impl Optimizer for RmsProp {
76 |     type State = State<Array>;
77 | 
78 |     fn state(&self) -> &Self::State {
79 |         &self.state
80 |     }
81 | 
82 |     fn state_mut(&mut self) -> &mut Self::State {
83 |         &mut self.state
84 |     }
85 | 
86 |     fn update_single(
87 |         &mut self,
88 |         key: &Rc<str>,
89 |         gradient: &Array,
90 |         parameter: &mut Array,
91 |     ) -> crate::error::Result<()> {
92 |         let state = get_mut_or_insert_with(&mut self.state, key, || array!(0.0));
93 | 
94 |         let lr = &self.lr;
95 |         let alpha = &self.alpha;
96 |         let eps = &self.epsilon;
97 | 
98 |         let one_minus_alpha = array!(1.0).subtract(alpha)?;
99 |         let first_term = alpha.multiply(&*state)?;
100 |         let second_term = one_minus_alpha.multiply(square(gradient)?)?;
101 |         let v = first_term.add(&second_term)?;
102 | 
103 |         let num = lr.multiply(gradient)?;
104 |         let den = sqrt(&v)?.add(eps)?;
105 |         let new_param = parameter.subtract(num.divide(&den)?)?;
106 | 
107 |         *parameter = new_param;
108 |         *state = v;
109 | 
110 |         Ok(())
111 |     }
112 | }
113 | 
114 | impl Updatable for RmsProp {
115 |     fn updatable_states_len(&self) -> usize {
116 |         self.state.len()
117 |     }
118 | 
119 |     fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
120 |         use itertools::Itertools;
121 | 
122 |         self.state
123 |             .iter()
124 |             .sorted_by(|a, b| a.0.cmp(b.0))
125 |             .map(|(_, v)| v)
126 |     }
127 | 
128 |     fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
129 |         use itertools::Itertools;
130 | 
131 |         self.state
132 |             .iter_mut()
133 |             .sorted_by(|a, b| a.0.cmp(b.0))
134 |             .map(|(_, v)| v)
135 |     }
136 | }
137 | 
138 | impl_updatable_for_mut_optimizer!(RmsProp);
139 | 
--------------------------------------------------------------------------------
/mlx-rs/src/optimizers/sgd.rs:
--------------------------------------------------------------------------------
1 | use std::{borrow::Cow, rc::Rc};
2 | 
3 | use crate::{array, utils::get_mut_or_insert_with, Array};
4 | use mlx_internal_macros::{generate_builder, Buildable};
5 | 
6 | use super::*;
7 | 
8
| generate_builder! { 9 | /// Stochastic gradient descent optimizer. 10 | #[derive(Debug, Clone, Buildable)] 11 | #[buildable(root = crate)] 12 | #[builder( 13 | build_with = build_sgd, 14 | root = crate 15 | )] 16 | pub struct Sgd { 17 | /// Learning rate 18 | pub lr: f32, 19 | 20 | /// Momentum strength. Default to [`Sgd::DEFAULT_MOMENTUM`] if not specified. 21 | #[builder(optional, default = Sgd::DEFAULT_MOMENTUM)] 22 | pub momentum: f32, 23 | 24 | /// Weight decay (L2 penalty). Default to [`Sgd::DEFAULT_WEIGHT_DECAY`] if not specified. 25 | #[builder(optional, default = Sgd::DEFAULT_WEIGHT_DECAY)] 26 | pub weight_decay: f32, 27 | 28 | /// Dampening for momentum. Default to [`Sgd::DEFAULT_DAMPENING`] if not specified. 29 | #[builder(optional, default = Sgd::DEFAULT_DAMPENING)] 30 | pub dampening: f32, 31 | 32 | /// Enables nesterov momentum. Default to [`Sgd::DEFAULT_NESTEROV`] if not specified. 33 | #[builder(optional, ty_override = bool, default = Sgd::DEFAULT_NESTEROV)] 34 | pub nesterov: bool, 35 | 36 | /// Inner state 37 | #[builder(ignore)] 38 | pub state: State, 39 | } 40 | } 41 | 42 | fn build_sgd(builder: SgdBuilder) -> Result { 43 | let lr = builder.lr; 44 | let momentum = builder.momentum; 45 | let weight_decay = builder.weight_decay; 46 | let dampening = builder.dampening; 47 | let nesterov = builder.nesterov; 48 | 49 | Ok(Sgd { 50 | lr, 51 | momentum, 52 | weight_decay, 53 | dampening, 54 | nesterov, 55 | state: State::new(), 56 | }) 57 | } 58 | 59 | impl Sgd { 60 | /// Default momentum if not specified. 61 | pub const DEFAULT_MOMENTUM: f32 = 0.0; 62 | 63 | /// Default weight decay if not specified. 64 | pub const DEFAULT_WEIGHT_DECAY: f32 = 0.0; 65 | 66 | /// Default dampening if not specified. 67 | pub const DEFAULT_DAMPENING: f32 = 0.0; 68 | 69 | /// Default nesterov if not specified. 70 | pub const DEFAULT_NESTEROV: bool = false; 71 | } 72 | 73 | impl Optimizer for Sgd { 74 | type State = State; 75 | 76 | fn state(&self) -> &Self::State { 77 | &self.state 78 | } 79 | 80 | fn state_mut(&mut self) -> &mut Self::State { 81 | &mut self.state 82 | } 83 | 84 | /// Apply SGD to a single parameter. Returns the updated parameter and the updated state. 
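    ///
    /// In pseudo-code (matching the body below):
    ///
    /// ```text
    /// g     <- g + weight_decay * param            (if weight_decay != 0)
    /// param <- param - lr * g                      (if momentum <= 0)
    /// v     <- momentum * v + (1 - dampening) * g
    /// param <- param - lr * (g + momentum * v)     (nesterov)
    /// param <- param - lr * v                      (otherwise)
    /// ```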
85 | #[inline] 86 | fn update_single( 87 | &mut self, 88 | key: &Rc, 89 | gradient: &Array, 90 | parameter: &mut Array, 91 | ) -> crate::error::Result<()> { 92 | let state = get_mut_or_insert_with(&mut self.state, key, || array!(0.0)); 93 | let mut gradient = Cow::Borrowed(gradient); 94 | 95 | if self.weight_decay != 0.0 { 96 | let weight_decay = array!(self.weight_decay); 97 | gradient = Cow::Owned(weight_decay.multiply(&*parameter)?.add(&*gradient)?); 98 | } 99 | 100 | if self.momentum <= 0.0 { 101 | let lr = array!(self.lr); 102 | *parameter = parameter.subtract(lr.multiply(gradient)?)?; 103 | return Ok(()); 104 | } 105 | 106 | let mut v = &*state * self.momentum; 107 | 108 | if self.dampening > 0.0 { 109 | let dampening = array!(self.dampening); 110 | let one_minus_dampening = array!(1.0).subtract(dampening)?; 111 | v = v.add(&one_minus_dampening.multiply(&gradient)?)?; 112 | } else { 113 | v = v.add(&gradient)?; 114 | } 115 | 116 | match self.nesterov { 117 | true => { 118 | let momentum = array!(self.momentum); 119 | let lr = array!(self.lr); 120 | let update = gradient.add(momentum.multiply(&v)?)?; 121 | *parameter = parameter.subtract(lr.multiply(&update)?)?; 122 | *state = v; 123 | } 124 | false => { 125 | let update = &v; 126 | let lr = array!(self.lr); 127 | *parameter = parameter.subtract(lr.multiply(update)?)?; 128 | *state = v; 129 | } 130 | } 131 | 132 | Ok(()) 133 | } 134 | } 135 | 136 | impl Updatable for Sgd { 137 | fn updatable_states_len(&self) -> usize { 138 | self.state.len() 139 | } 140 | 141 | fn updatable_states(&self) -> impl IntoIterator { 142 | use itertools::Itertools; 143 | 144 | self.state 145 | .iter() 146 | .sorted_by(|a, b| a.0.cmp(b.0)) 147 | .map(|(_, v)| v) 148 | } 149 | 150 | fn updatable_states_mut(&mut self) -> impl IntoIterator { 151 | use itertools::Itertools; 152 | 153 | self.state 154 | .iter_mut() 155 | .sorted_by(|a, b| a.0.cmp(b.0)) 156 | .map(|(_, v)| v) 157 | } 158 | } 159 | 160 | impl_updatable_for_mut_optimizer!(Sgd); 161 | -------------------------------------------------------------------------------- /mlx-rs/src/quantization.rs: -------------------------------------------------------------------------------- 1 | //! Traits for quantization 2 | 3 | use crate::module::{Module, ModuleParameters}; 4 | 5 | /// Trait for quantization of modules. 6 | pub trait Quantizable { 7 | /// The default group size for quantization. 8 | const DEFAULT_GROUP_SIZE: i32 = 64; 9 | 10 | /// The default number of bits for quantization. 11 | const DEFAULT_BITS: i32 = 4; 12 | 13 | /// The quantized type. 14 | type Quantized; 15 | 16 | /// The error type for quantization. 17 | type QuantizationError; 18 | 19 | /// Quantize the module with the specified group size and number of bits. 
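    ///
    /// Implementations typically delegate to a quantized counterpart type; for
    /// example, `Embedding::try_into_quantized` converts the layer into a
    /// `QuantizedEmbedding` with the given `group_size` and `bits`.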
20 | fn try_into_quantized( 21 | self, 22 | group_size: i32, 23 | bits: i32, 24 | ) -> Result; 25 | } 26 | 27 | impl Quantizable for Vec 28 | where 29 | M: Quantizable, 30 | { 31 | type Quantized = Vec; 32 | 33 | type QuantizationError = M::QuantizationError; 34 | 35 | fn try_into_quantized( 36 | self, 37 | group_size: i32, 38 | bits: i32, 39 | ) -> Result { 40 | self.into_iter() 41 | .map(|m| m.try_into_quantized(group_size, bits)) 42 | .collect() 43 | } 44 | } 45 | 46 | impl Quantizable for Box 47 | where 48 | M: Quantizable, 49 | { 50 | type Quantized = Box; 51 | 52 | type QuantizationError = M::QuantizationError; 53 | 54 | fn try_into_quantized( 55 | self, 56 | group_size: i32, 57 | bits: i32, 58 | ) -> Result { 59 | (*self).try_into_quantized(group_size, bits).map(Box::new) 60 | } 61 | } 62 | 63 | /// A wrapper for a quantizable module. 64 | #[derive(Debug, Clone)] 65 | pub enum MaybeQuantized 66 | where 67 | M: Quantizable, 68 | { 69 | /// The original module. 70 | Original(M), 71 | 72 | /// The quantized version of the module. 73 | Quantized(M::Quantized), 74 | } 75 | 76 | impl Quantizable for MaybeQuantized 77 | where 78 | M: Quantizable, 79 | { 80 | type Quantized = Self; 81 | type QuantizationError = ::QuantizationError; 82 | 83 | fn try_into_quantized( 84 | self, 85 | group_size: i32, 86 | bits: i32, 87 | ) -> Result { 88 | match self { 89 | MaybeQuantized::Original(m) => { 90 | let quantized = m.try_into_quantized(group_size, bits)?; 91 | Ok(MaybeQuantized::Quantized(quantized)) 92 | } 93 | MaybeQuantized::Quantized(q) => Ok(MaybeQuantized::Quantized(q)), 94 | } 95 | } 96 | } 97 | 98 | impl MaybeQuantized 99 | where 100 | M: Quantizable, 101 | { 102 | /// Create a new [`MaybeQuantized`] from the original module. 103 | pub fn new(module: M) -> Self { 104 | MaybeQuantized::Original(module) 105 | } 106 | 107 | /// Quantize the module with a custom quantization function. 108 | /// 109 | /// This is useful if one would like to quantize with a custom group size or bit width. 110 | pub fn quantize_with( 111 | self, 112 | op: impl FnOnce(M) -> Result, 113 | ) -> Result { 114 | match self { 115 | MaybeQuantized::Original(m) => op(m).map(MaybeQuantized::Quantized), 116 | MaybeQuantized::Quantized(q) => Ok(MaybeQuantized::Quantized(q)), 117 | } 118 | } 119 | 120 | /// Check if the module is quantized. 
121 | pub fn is_quantized(&self) -> bool { 122 | match self { 123 | MaybeQuantized::Original(_) => false, 124 | MaybeQuantized::Quantized(_) => true, 125 | } 126 | } 127 | } 128 | 129 | impl ModuleParameters for MaybeQuantized 130 | where 131 | M: Quantizable + ModuleParameters, 132 | M::Quantized: ModuleParameters, 133 | { 134 | fn num_parameters(&self) -> usize { 135 | match self { 136 | MaybeQuantized::Original(m) => m.num_parameters(), 137 | MaybeQuantized::Quantized(q) => q.num_parameters(), 138 | } 139 | } 140 | 141 | fn parameters(&self) -> crate::module::ModuleParamRef<'_> { 142 | match self { 143 | MaybeQuantized::Original(m) => m.parameters(), 144 | MaybeQuantized::Quantized(q) => q.parameters(), 145 | } 146 | } 147 | 148 | fn parameters_mut(&mut self) -> crate::module::ModuleParamMut<'_> { 149 | match self { 150 | MaybeQuantized::Original(m) => m.parameters_mut(), 151 | MaybeQuantized::Quantized(q) => q.parameters_mut(), 152 | } 153 | } 154 | 155 | fn trainable_parameters(&self) -> crate::module::ModuleParamRef<'_> { 156 | match self { 157 | MaybeQuantized::Original(m) => m.trainable_parameters(), 158 | MaybeQuantized::Quantized(q) => q.trainable_parameters(), 159 | } 160 | } 161 | 162 | fn freeze_parameters(&mut self, recursive: bool) { 163 | match self { 164 | MaybeQuantized::Original(m) => m.freeze_parameters(recursive), 165 | MaybeQuantized::Quantized(q) => q.freeze_parameters(recursive), 166 | } 167 | } 168 | 169 | fn unfreeze_parameters(&mut self, recursive: bool) { 170 | match self { 171 | MaybeQuantized::Original(m) => m.unfreeze_parameters(recursive), 172 | MaybeQuantized::Quantized(q) => q.unfreeze_parameters(recursive), 173 | } 174 | } 175 | 176 | fn all_frozen(&self) -> Option { 177 | match self { 178 | MaybeQuantized::Original(m) => m.all_frozen(), 179 | MaybeQuantized::Quantized(q) => q.all_frozen(), 180 | } 181 | } 182 | 183 | fn any_frozen(&self) -> Option { 184 | match self { 185 | MaybeQuantized::Original(m) => m.any_frozen(), 186 | MaybeQuantized::Quantized(q) => q.any_frozen(), 187 | } 188 | } 189 | } 190 | 191 | impl Module for MaybeQuantized 192 | where 193 | M: Quantizable + Module, 194 | M::Quantized: 195 | Module>::Output, Error = >::Error>, 196 | { 197 | type Output = >::Output; 198 | 199 | type Error = >::Error; 200 | 201 | fn forward(&mut self, x: Input) -> Result { 202 | match self { 203 | MaybeQuantized::Original(m) => m.forward(x), 204 | MaybeQuantized::Quantized(q) => q.forward(x), 205 | } 206 | } 207 | 208 | fn training_mode(&mut self, mode: bool) { 209 | match self { 210 | MaybeQuantized::Original(m) => m.training_mode(mode), 211 | MaybeQuantized::Quantized(q) => q.training_mode(mode), 212 | } 213 | } 214 | } 215 | 216 | #[cfg(test)] 217 | mod tests { 218 | use crate::nn::{self, Embedding, Linear}; 219 | 220 | use super::*; 221 | 222 | #[test] 223 | fn test_quantizable_linear() { 224 | let linear = Linear::new(64, 64).unwrap(); 225 | let mut qlinear = MaybeQuantized::new(linear); 226 | assert!(!qlinear.is_quantized()); 227 | 228 | qlinear = nn::quantize(qlinear, None, None).unwrap(); 229 | assert!(qlinear.is_quantized()); 230 | } 231 | 232 | #[test] 233 | fn test_quantizable_embedding() { 234 | let embedding = Embedding::new(64, 64).unwrap(); 235 | let mut qembedding = MaybeQuantized::new(embedding); 236 | assert!(!qembedding.is_quantized()); 237 | 238 | qembedding = nn::quantize(qembedding, None, None).unwrap(); 239 | assert!(qembedding.is_quantized()); 240 | } 241 | } 242 | 
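// A sketch (the values are illustrative, not from the original tests) of
// quantizing with a non-default configuration via `quantize_with`, e.g. a
// group size of 32 and 8 bits:
//
//     let linear = nn::Linear::new(64, 64).unwrap();
//     let qlinear = MaybeQuantized::new(linear)
//         .quantize_with(|m| m.try_into_quantized(32, 8))
//         .unwrap();
//     assert!(qlinear.is_quantized());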
-------------------------------------------------------------------------------- /mlx-rs/src/stream.rs: -------------------------------------------------------------------------------- 1 | use std::ffi::CStr; 2 | 3 | use crate::{ 4 | device::Device, 5 | error::Result, 6 | utils::{guard::Guarded, SUCCESS}, 7 | }; 8 | 9 | /// Parameter type for all MLX operations. 10 | /// 11 | /// Use this to control where operations are evaluated: 12 | /// 13 | /// If omitted it will use the [Default::default()], which will be [Device::gpu()] unless 14 | /// set otherwise. 15 | #[derive(PartialEq)] 16 | pub struct StreamOrDevice { 17 | pub(crate) stream: Stream, 18 | } 19 | 20 | impl StreamOrDevice { 21 | /// Create a new [`StreamOrDevice`] with a [`Stream`]. 22 | pub fn new(stream: Stream) -> StreamOrDevice { 23 | StreamOrDevice { stream } 24 | } 25 | 26 | /// Create a new [`StreamOrDevice`] with a [`Device`]. 27 | pub fn new_with_device(device: &Device) -> StreamOrDevice { 28 | StreamOrDevice { 29 | stream: Stream::new_with_device(device), 30 | } 31 | } 32 | 33 | /// Current default CPU stream. 34 | pub fn cpu() -> StreamOrDevice { 35 | StreamOrDevice { 36 | stream: Stream::cpu(), 37 | } 38 | } 39 | 40 | /// Current default GPU stream. 41 | pub fn gpu() -> StreamOrDevice { 42 | StreamOrDevice { 43 | stream: Stream::gpu(), 44 | } 45 | } 46 | } 47 | 48 | impl Default for StreamOrDevice { 49 | /// The default stream on the default device. 50 | /// 51 | /// This will be [Device::gpu()] unless [Device::set_default()] 52 | /// sets it otherwise. 53 | fn default() -> Self { 54 | Self { 55 | stream: Stream::new(), 56 | } 57 | } 58 | } 59 | 60 | impl AsRef for StreamOrDevice { 61 | fn as_ref(&self) -> &Stream { 62 | &self.stream 63 | } 64 | } 65 | 66 | impl std::fmt::Debug for StreamOrDevice { 67 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { 68 | write!(f, "{}", self.stream) 69 | } 70 | } 71 | 72 | impl std::fmt::Display for StreamOrDevice { 73 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { 74 | write!(f, "{}", self.stream) 75 | } 76 | } 77 | 78 | /// A stream of evaluation attached to a particular device. 79 | /// 80 | /// Typically, this is used via the `stream:` parameter on a method with a [StreamOrDevice]: 81 | pub struct Stream { 82 | pub(crate) c_stream: mlx_sys::mlx_stream, 83 | } 84 | 85 | impl AsRef for Stream { 86 | fn as_ref(&self) -> &Stream { 87 | self 88 | } 89 | } 90 | 91 | impl Stream { 92 | /// Create a new stream on the default device. Panics if fails. 93 | pub fn new() -> Stream { 94 | unsafe { 95 | let mut dev = mlx_sys::mlx_device_new(); 96 | // SAFETY: mlx_get_default_device internally never throws an error 97 | mlx_sys::mlx_get_default_device(&mut dev as *mut _); 98 | 99 | let mut c_stream = mlx_sys::mlx_stream_new(); 100 | // SAFETY: mlx_get_default_stream internally never throws if dev is valid 101 | mlx_sys::mlx_get_default_stream(&mut c_stream as *mut _, dev); 102 | 103 | mlx_sys::mlx_device_free(dev); 104 | Stream { c_stream } 105 | } 106 | } 107 | 108 | /// Try to get the default stream on the given device. 
109 | pub fn try_default_on_device(device: &Device) -> Result { 110 | Stream::try_from_op(|res| unsafe { mlx_sys::mlx_get_default_stream(res, device.c_device) }) 111 | } 112 | 113 | /// Create a new stream on the given device 114 | pub fn new_with_device(device: &Device) -> Stream { 115 | unsafe { 116 | let c_stream = mlx_sys::mlx_stream_new_device(device.c_device); 117 | Stream { c_stream } 118 | } 119 | } 120 | 121 | /// Get the underlying C pointer. 122 | pub fn as_ptr(&self) -> mlx_sys::mlx_stream { 123 | self.c_stream 124 | } 125 | 126 | /// Current default CPU stream. 127 | pub fn cpu() -> Self { 128 | unsafe { 129 | let c_stream = mlx_sys::mlx_default_cpu_stream_new(); 130 | Stream { c_stream } 131 | } 132 | } 133 | 134 | /// Current default GPU stream. 135 | pub fn gpu() -> Self { 136 | unsafe { 137 | let c_stream = mlx_sys::mlx_default_gpu_stream_new(); 138 | Stream { c_stream } 139 | } 140 | } 141 | 142 | /// Get the index of the stream. 143 | pub fn get_index(&self) -> Result { 144 | i32::try_from_op(|res| unsafe { mlx_sys::mlx_stream_get_index(res, self.c_stream) }) 145 | } 146 | 147 | fn describe(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { 148 | unsafe { 149 | let mut mlx_str = mlx_sys::mlx_string_new(); 150 | let result = match mlx_sys::mlx_stream_tostring(&mut mlx_str as *mut _, self.c_stream) { 151 | SUCCESS => { 152 | let ptr = mlx_sys::mlx_string_data(mlx_str); 153 | let c_str = CStr::from_ptr(ptr); 154 | write!(f, "{}", c_str.to_string_lossy()) 155 | } 156 | _ => Err(std::fmt::Error), 157 | }; 158 | mlx_sys::mlx_string_free(mlx_str); 159 | result 160 | } 161 | } 162 | } 163 | 164 | impl Drop for Stream { 165 | fn drop(&mut self) { 166 | unsafe { mlx_sys::mlx_stream_free(self.c_stream) }; 167 | } 168 | } 169 | 170 | impl Default for Stream { 171 | fn default() -> Self { 172 | Stream::new() 173 | } 174 | } 175 | 176 | impl std::fmt::Debug for Stream { 177 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { 178 | self.describe(f) 179 | } 180 | } 181 | 182 | impl std::fmt::Display for Stream { 183 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { 184 | self.describe(f) 185 | } 186 | } 187 | 188 | impl PartialEq for Stream { 189 | fn eq(&self, other: &Self) -> bool { 190 | unsafe { mlx_sys::mlx_stream_equal(self.c_stream, other.c_stream) } 191 | } 192 | } 193 | -------------------------------------------------------------------------------- /mlx-rs/src/transforms/compile/mod.rs: -------------------------------------------------------------------------------- 1 | //! Compilation of functions. 2 | //! 3 | //! See also [MLX python 4 | //! documentation](https://ml-explore.github.io/mlx/build/html/usage/compile.html). 5 | //! 6 | //! MLX has a [`compile()`] function transformation which compiles computation 7 | //! graphs. Function compilation results in smaller graphs by merging common 8 | //! work and fusing certain operations. In many cases this can lead to big 9 | //! improvements in run-time and memory use. 10 | //! 11 | //! Getting started with compile() is simple, but there are some edge cases that 12 | //! are good to be aware of for more complex graphs and advanced usage. 13 | //! 14 | //! **WARN**: Because function transforms including compilation works on the 15 | //! computation graph, the user must ensure that all `Array`s are passed as 16 | //! inputs to the function/closure. Closures with captured `Array`s may not work 17 | //! as expected and may lead to undefined behavior. 18 | //! 19 | //! # Basic usage 20 | //! 
21 | //! ```rust 22 | //! use mlx_rs::{Array, array, transforms::compile::compile, error::Exception}; 23 | //! 24 | //! let fun = |(x, y): (&Array, &Array)| -> Result { 25 | //! mlx_rs::exp!(x.negative()?)?.add(y) 26 | //! }; 27 | //! 28 | //! let x = array!(1.0); 29 | //! let y = array!(2.0); 30 | //! 31 | //! // Regular call, no compilation 32 | //! let result = fun((&x, &y)).unwrap(); 33 | //! // Prints: array(2.36788, dtype=float32) 34 | //! println!("{:?}", result); 35 | //! 36 | //! // Compile the function 37 | //! let mut compiled_fun = compile(fun, None); 38 | //! let result = compiled_fun((&x, &y)).unwrap(); 39 | //! // Prints: array(2.36788, dtype=float32) 40 | //! println!("{:?}", result); 41 | //! ``` 42 | //! 43 | //! The output of both the regular function and the compiled function is the 44 | //! same up to numerical precision. 45 | //! 46 | //! The first time you call a compiled function, MLX will build the compute 47 | //! graph, optimize it, and generate and compile code. This can be relatively 48 | //! slow. However, MLX will cache compiled functions, so calling a compiled 49 | //! function multiple times will not initiate a new compilation. This means you 50 | //! should typically compile functions that you plan to use more than once. 51 | //! 52 | //! ```rust 53 | //! use mlx_rs::{Array, array, transforms::compile::compile}; 54 | //! 55 | //! let fun = |(x, y): (&Array, &Array)| { 56 | //! mlx_rs::exp!(x.negative()?)?.add(y) 57 | //! }; 58 | //! 59 | //! let x = array!(1.0); 60 | //! let y = array!(2.0); 61 | //! 62 | //! let mut compiled_fun = compile(fun, None); 63 | //! 64 | //! // Compiled here 65 | //! let result = compiled_fun((&x, &y)).unwrap(); 66 | //! 67 | //! // Not compiled again 68 | //! let result = compiled_fun((&x, &y)).unwrap(); 69 | //! 70 | //! // Not compiled again 71 | //! let compiled_fun2 = compile(fun, None); 72 | //! ``` 73 | //! 74 | //! There are some important cases to be aware of that can cause a function to 75 | //! be recompiled: 76 | //! 77 | //! - Changing the shape or number of dimensions 78 | //! - Changing the type of any of the inputs 79 | //! - Changing the number of inputs to the function 80 | //! 81 | //! In certain cases only some of the compilation stack will be rerun (for 82 | //! example when changing the shapes) and in other cases the full compilation 83 | //! stack will be rerun (for example when changing the types). In general you 84 | //! should avoid compiling functions too frequently. 85 | //! 86 | //! Another idiom to watch out for is compiling functions which get created and 87 | //! destroyed frequently. This can happen, for example, when compiling an 88 | //! closure in a loop. 89 | //! 90 | //! # Pure Functions 91 | //! 92 | //! Compiled functions are intended to be pure; that is they should not have 93 | //! side effects. For example: 94 | //! 95 | //! ```rust,ignore 96 | //! use mlx_rs::{Array, array, transforms::compile::compile}; 97 | //! 98 | //! let mut c = array!(0.5); 99 | //! 100 | //! let fun = |(x, y): (&Array, &Array)| { 101 | //! let z = (x + y) * c; 102 | //! mlx_rs::exp!(z) 103 | //! }; 104 | //! 105 | //! let mut compiled = compile(fun, None); 106 | //! 107 | //! let x = array!(1.0); 108 | //! let y = array!(2.0); 109 | //! 110 | //! // This may lead to undefined behavior 111 | //! let result = compiled((&x, &y)).unwrap(); 112 | //! println!("{:?}", result); 113 | //! ``` 114 | //! 115 | //! Use [`compile_with_state()`] to compile functions that have side effects and 116 | //! 
pass the state as a mutable reference.
117 | //!
118 | //! ```rust
119 | //! use mlx_rs::{Array, array, transforms::compile::compile_with_state};
120 | //! let mut state = vec![];
121 | //!
122 | //! let fun = |state: &mut Vec<Array>, (x, y): (&Array, &Array)| {
123 | //!     let z = x + y;
124 | //!     let result = mlx_rs::exp!(&z);
125 | //!     state.push(z);
126 | //!     result
127 | //! };
128 | //!
129 | //! let x = array!(1.0);
130 | //! let y = array!(2.0);
131 | //!
132 | //! let mut compiled = compile_with_state(fun, None);
133 | //! let result = compiled(&mut state, (&x, &y)).unwrap();
134 | //! println!("{:?}", result);
135 | //! // println!("{:?}", state); // TODO: this currently doesn't work somehow
136 | //! ```
137 | //!
138 | //! This is particularly useful for compiling a function which includes an
139 | //! update to a container of arrays, as is commonly done when training the
140 | //! parameters of a [`crate::module::Module`].
141 | //!
142 | //! See mlx-rs/mlx-tests/tests/test_compile_with_state.rs for more examples.
143 | //!
144 | 
145 | use std::collections::hash_map::DefaultHasher;
146 | use std::hash::{Hash, Hasher};
147 | 
148 | use super::{Closure, Guarded, VectorArray};
149 | use crate::Array;
150 | 
151 | #[allow(clippy::module_inception)]
152 | mod compile;
153 | mod compile_with_state;
154 | 
155 | pub use compile::*;
156 | pub use compile_with_state::*;
157 | 
158 | /// Globally enable the compilation of functions.
159 | ///
160 | /// Default is enabled.
161 | pub fn enable_compile() {
162 |     unsafe {
163 |         mlx_sys::mlx_enable_compile();
164 |     }
165 | }
166 | 
167 | /// Globally disable the compilation of functions.
168 | ///
169 | /// Default is enabled.
170 | pub fn disable_compile() {
171 |     unsafe {
172 |         mlx_sys::mlx_disable_compile();
173 |     }
174 | }
175 | 
176 | /// Clear the memory cache.
177 | pub fn clear_cache() {
178 |     unsafe {
179 |         mlx_sys::mlx_detail_compile_clear_cache();
180 |     }
181 | }
182 | 
183 | /// A compiled function that can be called.
184 | #[derive(Debug, Clone)]
185 | pub struct Compiled<F, G> {
186 |     f_marker: std::marker::PhantomData<F>,
187 |     state: CompiledState<G>,
188 | }
189 | 
190 | #[derive(Debug, Clone)]
191 | struct CompiledState<F> {
192 |     f: F,
193 |     shapeless: bool,
194 |     id: usize,
195 | }
196 | 
197 | impl<F> Drop for CompiledState<F> {
198 |     fn drop(&mut self) {
199 |         unsafe {
200 |             // remove the compiled structure from the back end
201 |             mlx_sys::mlx_detail_compile_erase(self.id);
202 |         }
203 |     }
204 | }
205 | 
206 | fn type_id_to_usize<T>(_val: &T) -> usize
207 | where
208 |     T: 'static,
209 | {
210 |     // hash type id to usize
211 |     let type_id = std::any::TypeId::of::<T>();
212 |     let mut hasher = DefaultHasher::new();
213 |     type_id.hash(&mut hasher);
214 |     hasher.finish() as usize
215 | }
216 | 
217 | fn update_by_replace_with_ref_to_new_array(src: &mut Array, new_array: &Array) {
218 |     debug_assert_eq!(src.shape(), new_array.shape());
219 |     unsafe {
220 |         mlx_sys::mlx_array_set(&mut src.as_ptr() as *mut _, new_array.as_ptr());
221 |     }
222 | }
223 | 
--------------------------------------------------------------------------------
/mlx-rs/src/transforms/keyed_value_and_grad.rs:
--------------------------------------------------------------------------------
1 | use std::{collections::HashMap, rc::Rc};
2 | 
3 | use crate::{
4 |     error::{Exception, Result},
5 |     utils::{guard::Guarded, Closure},
6 |     Array,
7 | };
8 | 
9 | use super::{value_and_gradient, ClosureValueAndGrad};
10 | 
11 | /// Type alias for a hashmap of parameters.
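///
/// Keys are reference-counted strings (`Rc<str>`), so they can be cloned
/// cheaply when gradients are zipped back up with their parameter names.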
12 | pub type KeyedParameters<Arr> = HashMap<Rc<str>, Arr>;
13 | 
14 | /// Type alias for a hashmap of gradients.
15 | pub type KeyedGrad = KeyedParameters<Array>;
16 | 
17 | macro_rules! keyed_value_and_grad {
18 |     ($inner_ret:ty, $cls_new:ident, $f:ident, $args_ty:ty) => {
19 |         move |parameters: KeyedParameters<Arr>,
20 |               arrays: $args_ty|
21 |               -> Result<(Vec<Array>, KeyedGrad)> {
22 |             let (flattened_keys, flattened_values): (Vec<_>, Vec<_>) =
23 |                 parameters.into_iter().unzip();
24 | 
25 |             let inner = |flattened_arrays: &[Array]| -> $inner_ret {
26 |                 let parameters = flattened_keys
27 |                     .iter()
28 |                     .cloned()
29 |                     .zip(flattened_arrays.iter().cloned())
30 |                     .collect();
31 |                 ($f)(parameters, arrays.clone())
32 |             };
33 | 
34 |             let argument_numbers = (0..flattened_values.len() as i32).collect::<Vec<_>>();
35 | 
36 |             let closure = Closure::$cls_new(inner);
37 |             let cvg = ClosureValueAndGrad::try_from_op(|res| unsafe {
38 |                 mlx_sys::mlx_value_and_grad(
39 |                     res,
40 |                     closure.as_ptr(),
41 |                     argument_numbers.as_ptr(),
42 |                     argument_numbers.len(),
43 |                 )
44 |             })?;
45 | 
46 |             let (value, grads) = value_and_gradient(cvg.as_ptr(), flattened_values.into_iter())?;
47 | 
48 |             let grads_map = flattened_keys.iter().cloned().zip(grads).collect();
49 | 
50 |             Ok((value, grads_map))
51 |         }
52 |     };
53 | }
54 | 
55 | /// Similar to [`IntoValueAndGrad`] but for functions that take a hashmap of parameters.
56 | pub trait IntoKeyedValueAndGrad<'a, Arr, Args, Err>
57 | where
58 |     Arr: AsRef<Array>,
59 |     Args: Clone,
60 | {
61 |     /// Convert the function/closure into a closure that computes the value and gradient.
62 |     fn into_keyed_value_and_grad(
63 |         self,
64 |     ) -> impl FnMut(KeyedParameters<Arr>, Args) -> Result<(Vec<Array>, KeyedGrad)> + 'a;
65 | }
66 | 
67 | impl<'a, F, Arr, Args> IntoKeyedValueAndGrad<'a, Arr, Args, ()> for F
68 | where
69 |     F: FnMut(HashMap<Rc<str>, Array>, Args) -> Vec<Array> + 'a,
70 |     Arr: AsRef<Array>,
71 |     Args: Clone,
72 | {
73 |     fn into_keyed_value_and_grad(
74 |         mut self,
75 |     ) -> impl FnMut(KeyedParameters<Arr>, Args) -> Result<(Vec<Array>, KeyedGrad)> + 'a {
76 |         keyed_value_and_grad!(Vec<Array>, new, self, Args)
77 |     }
78 | }
79 | 
80 | impl<'a, F, Arr, Args> IntoKeyedValueAndGrad<'a, Arr, Args, Exception> for F
81 | where
82 |     F: FnMut(HashMap<Rc<str>, Array>, Args) -> Result<Vec<Array>> + 'a,
83 |     Arr: AsRef<Array>,
84 |     Args: Clone,
85 | {
86 |     fn into_keyed_value_and_grad(
87 |         mut self,
88 |     ) -> impl FnMut(KeyedParameters<Arr>, Args) -> Result<(Vec<Array>, KeyedGrad)> + 'a {
89 |         keyed_value_and_grad!(Result<Vec<Array>>, new_fallible, self, Args)
90 |     }
91 | }
92 | 
93 | /// Returns a function which computes the value and gradient of `f` with keyed parameters.
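///
/// A minimal sketch (mirroring the unit test below; the `transforms` path is
/// assumed to re-export `keyed_value_and_grad`):
///
/// ```rust,ignore
/// use std::{collections::HashMap, rc::Rc};
/// use mlx_rs::{array, Array, transforms::keyed_value_and_grad};
///
/// // f(x, y) = x * y, with the parameters addressed by name
/// let f = |p: HashMap<Rc<str>, Array>, _: i32| vec![&p["x"] * &p["y"]];
/// let mut vg = keyed_value_and_grad(f);
///
/// let params: HashMap<Rc<str>, Array> = [("x", array!(1.5f32)), ("y", array!(2.0f32))]
///     .into_iter()
///     .map(|(k, v)| (k.into(), v))
///     .collect();
///
/// let (value, grad) = vg(params, 0).unwrap();
/// assert_eq!(grad["x"].item::<f32>(), 2.0); // d(x*y)/dx = y
/// ```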
94 | pub fn keyed_value_and_grad<'a, F, Arr, Args, Err>( 95 | f: F, 96 | ) -> impl FnMut(KeyedParameters<Arr>, Args) -> Result<(Vec<Array>, KeyedGrad)> + 'a 97 | where 98 | F: IntoKeyedValueAndGrad<'a, Arr, Args, Err> + 'a, 99 | Arr: AsRef<Array>, 100 | Args: Clone, 101 | { 102 | f.into_keyed_value_and_grad() 103 | } 104 | 105 | #[cfg(test)] 106 | mod tests { 107 | use std::{collections::HashMap, rc::Rc}; 108 | 109 | use crate::{array, Array}; 110 | 111 | use super::*; 112 | 113 | #[test] 114 | fn test_keyed_value_and_grad() { 115 | let f = |parameters: HashMap<Rc<str>, Array>, _: i32| -> Vec<Array> { 116 | vec![&parameters["x"] * &parameters["y"]] 117 | }; 118 | 119 | let x = array!(1.5f32); 120 | let y = array!(2.0f32); 121 | let parameters = vec![("x", x), ("y", y)] 122 | .into_iter() 123 | .map(|(k, v)| (k.into(), v)) 124 | .collect(); 125 | 126 | let mut vg = keyed_value_and_grad(f); 127 | 128 | let (value, grad) = vg(parameters, 0).unwrap(); 129 | 130 | assert_eq!(value[0].item::<f32>(), 1.5 * 2.0); 131 | assert_eq!(grad["x"].item::<f32>(), 2.0); 132 | assert_eq!(grad["y"].item::<f32>(), 1.5); 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /mlx-rs/src/transforms/value_and_grad.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | error::{Exception, Result}, 3 | utils::{guard::Guarded, Closure, IntoOption}, 4 | Array, 5 | }; 6 | 7 | use super::{value_and_gradient, ClosureValueAndGrad}; 8 | 9 | fn build_value_and_gradient_inner<'a>( 10 | closure: Closure<'a>, 11 | argnums: &'a [i32], 12 | ) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>)> + 'a { 13 | move |arrays: &[Array]| unsafe { 14 | let cvg = ClosureValueAndGrad::try_from_op(|res| { 15 | mlx_sys::mlx_value_and_grad(res, closure.as_ptr(), argnums.as_ptr(), argnums.len()) 16 | })?; 17 | value_and_gradient(cvg.as_ptr(), arrays.iter()) 18 | } 19 | } 20 | 21 | fn build_value_and_gradient<'a, F>( 22 | f: F, 23 | argnums: &'a [i32], 24 | ) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>)> + 'a 25 | where 26 | F: FnMut(&[Array]) -> Vec<Array> + 'a, 27 | { 28 | let closure = Closure::new(f); 29 | build_value_and_gradient_inner(closure, argnums) 30 | } 31 | 32 | fn build_fallible_value_and_gradient<'a, F>( 33 | f: F, 34 | argnums: &'a [i32], 35 | ) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>)> + 'a 36 | where 37 | F: FnMut(&[Array]) -> Result<Vec<Array>> + 'a, 38 | { 39 | let closure = Closure::new_fallible(f); 40 | build_value_and_gradient_inner(closure, argnums) 41 | } 42 | 43 | /// Trait for functions/closures that can be converted into a closure that computes the value and 44 | /// gradient. 45 | pub trait IntoValueAndGrad<'a, Err> { 46 | /// Convert the function/closure into a closure that computes the value and gradient. 47 | fn into_value_and_grad( 48 | self, 49 | argnums: impl IntoOption<&'a [i32]>, 50 | ) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>)> + 'a; 51 | } 52 | 53 | impl<'a, F> IntoValueAndGrad<'a, ()> for F 54 | where 55 | F: FnMut(&[Array]) -> Vec<Array> + 'a, 56 | { 57 | // refining_impl_trait is fine here because we have restricted the Args and Output types 58 | // in the generics.
59 | #[allow(refining_impl_trait)] 60 | fn into_value_and_grad( 61 | self, 62 | argnums: impl IntoOption<&'a [i32]>, 63 | ) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>)> + 'a { 64 | let argnums = argnums.into_option().unwrap_or(&[0]); 65 | build_value_and_gradient(self, argnums) 66 | } 67 | } 68 | 69 | impl<'a, F> IntoValueAndGrad<'a, Exception> for F 70 | where 71 | F: FnMut(&[Array]) -> Result<Vec<Array>> + 'a, 72 | { 73 | #[allow(refining_impl_trait)] 74 | fn into_value_and_grad( 75 | self, 76 | argnums: impl IntoOption<&'a [i32]>, 77 | ) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>)> + 'a { 78 | let argnums = argnums.into_option().unwrap_or(&[0]); 79 | build_fallible_value_and_gradient(self, argnums) 80 | } 81 | } 82 | 83 | /// Returns a function which computes the value and gradient of `f` with a 84 | /// default argument number `&[0]`. 85 | /// 86 | /// See also [`value_and_grad_with_argnums`] for a version that allows 87 | /// specifying the argument numbers. 88 | pub fn value_and_grad<'a, F, Err>( 89 | f: F, 90 | ) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>)> + 'a 91 | where 92 | F: IntoValueAndGrad<'a, Err> + 'a, 93 | { 94 | f.into_value_and_grad(None) 95 | } 96 | 97 | /// Returns a function which computes the value and gradient of `f`. 98 | /// 99 | /// See also [`value_and_grad`] for a version that uses the default argument 100 | /// numbers `&[0]`. 101 | pub fn value_and_grad_with_argnums<'a, F, Err>( 102 | f: F, 103 | argnums: impl IntoOption<&'a [i32]>, 104 | ) -> impl FnMut(&[Array]) -> Result<(Vec<Array>, Vec<Array>)> + 'a 105 | where 106 | F: IntoValueAndGrad<'a, Err> + 'a, 107 | { 108 | f.into_value_and_grad(argnums) 109 | } 110 | 111 | #[cfg(test)] 112 | mod tests { 113 | 114 | use crate::{array, transforms::value_and_grad, Array}; 115 | 116 | use super::*; 117 | 118 | // The unit tests below are adapted from the mlx c++ codebase 119 | #[test] 120 | fn test_value_and_grad() { 121 | let x = &[Array::from_f32(1.0)]; 122 | let fun = |argin: &[Array]| -> Vec<Array> { vec![&argin[0] + 1.0] }; 123 | let argnums = &[0]; 124 | let (y, dfdx) = value_and_grad_with_argnums(fun, argnums)(x).unwrap(); 125 | assert_eq!(y[0].item::<f32>(), 2.0); 126 | assert_eq!(dfdx[0].item::<f32>(), 1.0); 127 | 128 | let (y, dfdx) = value_and_grad(fun)(x).unwrap(); 129 | assert_eq!(y[0].item::<f32>(), 2.0); 130 | assert_eq!(dfdx[0].item::<f32>(), 1.0); 131 | } 132 | 133 | #[test] 134 | fn test_value_and_grad_with_error() { 135 | let fun = |argin: &[Array]| -> Result<Vec<Array>> { 136 | argin[0].add(array!(1.0)).map(|res| vec![res]) 137 | }; 138 | 139 | // Success case 140 | let argnums = &[0]; 141 | let x = array!(1.0f32); 142 | let y = array!(1.0f32); 143 | let args = &[x, y]; 144 | let result = value_and_grad_with_argnums(fun, argnums)(args); 145 | assert!(result.is_ok()); 146 | let result = value_and_grad(fun)(args); 147 | assert!(result.is_ok()); 148 | 149 | // Error case 150 | // Use non-broadcastable shapes 151 | let a = array!([1.0, 2.0, 3.0]); 152 | let b = array!([4.0, 5.0]); 153 | let args = &[a, b]; 154 | let result = value_and_grad_with_argnums(fun, argnums)(args); 155 | assert!(result.is_err()); 156 | let result = value_and_grad(fun)(args); 157 | assert!(result.is_err()); 158 | 159 | // Check that the error is not just "mlx_closure returned a non-zero value" 160 | let err = result.unwrap_err(); 161 | assert!(!err.what().contains("non-zero value")) 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /mlx-rs/src/utils/io.rs:
-------------------------------------------------------------------------------- 1 | use crate::error::{Exception, IoError}; 2 | use crate::utils::SUCCESS; 3 | use crate::{Array, Stream}; 4 | use std::collections::HashMap; 5 | use std::ffi::{CStr, CString}; 6 | use std::path::Path; 7 | use std::ptr::null_mut; 8 | 9 | use super::Guarded; 10 | 11 | pub(crate) struct SafeTensors { 12 | pub(crate) c_data: mlx_sys::mlx_map_string_to_array, 13 | pub(crate) c_metadata: mlx_sys::mlx_map_string_to_string, 14 | } 15 | 16 | impl Drop for SafeTensors { 17 | fn drop(&mut self) { 18 | unsafe { 19 | mlx_sys::mlx_map_string_to_string_free(self.c_metadata); 20 | mlx_sys::mlx_map_string_to_array_free(self.c_data); 21 | } 22 | } 23 | } 24 | 25 | impl SafeTensors { 26 | pub(crate) fn load_device(path: &Path, stream: impl AsRef<Stream>) -> Result<Self, IoError> { 27 | if !path.is_file() { 28 | return Err(IoError::NotFile); 29 | } 30 | 31 | let extension = path 32 | .extension() 33 | .and_then(|ext| ext.to_str()) 34 | .ok_or(IoError::UnsupportedFormat)?; 35 | 36 | if extension != "safetensors" { 37 | return Err(IoError::UnsupportedFormat); 38 | } 39 | 40 | let path_str = path.to_str().ok_or(IoError::InvalidUtf8)?; 41 | let filepath = CString::new(path_str)?; 42 | 43 | SafeTensors::try_from_op(|(res_0, res_1)| unsafe { 44 | mlx_sys::mlx_load_safetensors(res_0, res_1, filepath.as_ptr(), stream.as_ref().as_ptr()) 45 | }) 46 | .map_err(Into::into) 47 | } 48 | 49 | pub(crate) fn data(&self) -> Result<HashMap<String, Array>, Exception> { 50 | crate::error::INIT_ERR_HANDLER 51 | .with(|init| init.call_once(crate::error::setup_mlx_error_handler)); 52 | let mut map = HashMap::new(); 53 | unsafe { 54 | let iterator = mlx_sys::mlx_map_string_to_array_iterator_new(self.c_data); 55 | 56 | loop { 57 | let mut key_ptr: *const ::std::os::raw::c_char = null_mut(); 58 | let mut value = mlx_sys::mlx_array_new(); 59 | let status = mlx_sys::mlx_map_string_to_array_iterator_next( 60 | &mut key_ptr as *mut *const _, 61 | &mut value, 62 | iterator, 63 | ); 64 | 65 | match status { 66 | SUCCESS => { 67 | let key = CStr::from_ptr(key_ptr).to_string_lossy().into_owned(); 68 | let array = Array::from_ptr(value); 69 | map.insert(key, array); 70 | } 71 | 1 => { 72 | mlx_sys::mlx_array_free(value); 73 | return Err(crate::error::get_and_clear_last_mlx_error() 74 | .expect("A non-success status was returned, but no error was set.") 75 | .into()); 76 | } 77 | 2 => { 78 | mlx_sys::mlx_array_free(value); 79 | break; 80 | } 81 | _ => unreachable!(), 82 | } 83 | } 84 | 85 | mlx_sys::mlx_map_string_to_array_iterator_free(iterator); 86 | } 87 | 88 | Ok(map) 89 | } 90 | 91 | pub(crate) fn metadata(&self) -> Result<HashMap<String, String>, Exception> { 92 | crate::error::INIT_ERR_HANDLER 93 | .with(|init| init.call_once(crate::error::setup_mlx_error_handler)); 94 | 95 | let mut map = HashMap::new(); 96 | unsafe { 97 | let iterator = mlx_sys::mlx_map_string_to_string_iterator_new(self.c_metadata); 98 | 99 | let mut key: *const ::std::os::raw::c_char = null_mut(); 100 | let mut value: *const ::std::os::raw::c_char = null_mut(); 101 | loop { 102 | let status = mlx_sys::mlx_map_string_to_string_iterator_next( 103 | &mut key as *mut *const _, 104 | &mut value as *mut *const _, 105 | iterator, 106 | ); 107 | 108 | match status { 109 | SUCCESS => { 110 | let key = CStr::from_ptr(key).to_string_lossy().into_owned(); 111 | let value = CStr::from_ptr(value).to_string_lossy().into_owned(); 112 | map.insert(key, value); 113 | } 114 | 1 => { 115 | return Err(crate::error::get_and_clear_last_mlx_error() 116 | .expect("A non-success
status was returned, but no error was set.") 117 | .into()) 118 | } 119 | 2 => break, 120 | _ => unreachable!(), 121 | } 122 | } 123 | } 124 | 125 | Ok(map) 126 | } 127 | } 128 | -------------------------------------------------------------------------------- /mlx-sys/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # CHANGELOG 2 | 3 | ## 0.1.2-release 4 | 5 | - Update generated bindings to mlx-c 0.1.2 6 | 7 | ## ~~0.1.2~~ 8 | 9 | - ~~Update generated bindings to mlx-c 0.1.2~~ 10 | - Mistakenly published 0.1.0 as 0.1.2 11 | 12 | ## 0.1.0 13 | 14 | - Update generated bindings to mlx-c 0.1.0 15 | -------------------------------------------------------------------------------- /mlx-sys/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "mlx-sys" 3 | version = "0.2.0-alpha.2" # mlx-sys version should follow that of mlx-c 4 | authors.workspace = true 5 | edition.workspace = true 6 | 7 | description = "Low-level interface and binding generation for the mlx library" 8 | repository.workspace = true 9 | keywords.workspace = true 10 | categories.workspace = true 11 | license.workspace = true 12 | readme = "README.md" 13 | 14 | [package.metadata.docs.rs] 15 | targets = [ 16 | "aarch64-apple-darwin", 17 | "aarch64-apple-ios", 18 | "aarch64-apple-ios-sim", 19 | ] 20 | 21 | [features] 22 | default = ["accelerate", "metal"] 23 | 24 | accelerate = [] 25 | metal = [] 26 | 27 | [dependencies] 28 | 29 | [build-dependencies] 30 | bindgen.workspace = true 31 | cmake.workspace = true 32 | cc.workspace = true 33 | -------------------------------------------------------------------------------- /mlx-sys/README.md: -------------------------------------------------------------------------------- 1 | # mlx-sys 2 | 3 | Rust bindings to the mlx-c API. Generated using bindgen. 
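
A minimal usage sketch (the same check as `examples/is_metal_available.rs`);
every binding is a raw `unsafe` C call:

```rust
fn main() {
    let mut is_available = false;
    let status = unsafe { mlx_sys::mlx_metal_is_available(&mut is_available as *mut bool) };
    assert_eq!(status, 0);
    println!("{:?}", is_available);
}
```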
4 | -------------------------------------------------------------------------------- /mlx-sys/build.rs: -------------------------------------------------------------------------------- 1 | extern crate cmake; 2 | 3 | use bindgen::RustTarget; 4 | use cmake::Config; 5 | use std::{env, path::PathBuf}; 6 | 7 | fn build_and_link_mlx_c() { 8 | let mut config = Config::new("src/mlx-c"); 9 | config.very_verbose(true); 10 | config.define("CMAKE_INSTALL_PREFIX", "."); 11 | 12 | #[cfg(debug_assertions)] 13 | { 14 | config.define("CMAKE_BUILD_TYPE", "Debug"); 15 | } 16 | 17 | #[cfg(not(debug_assertions))] 18 | { 19 | config.define("CMAKE_BUILD_TYPE", "Release"); 20 | } 21 | 22 | config.define("MLX_BUILD_METAL", "OFF"); 23 | config.define("MLX_BUILD_ACCELERATE", "OFF"); 24 | 25 | #[cfg(feature = "metal")] 26 | { 27 | config.define("MLX_BUILD_METAL", "ON"); 28 | } 29 | 30 | #[cfg(feature = "accelerate")] 31 | { 32 | config.define("MLX_BUILD_ACCELERATE", "ON"); 33 | } 34 | 35 | // build the mlx-c project 36 | let dst = config.build(); 37 | 38 | println!("cargo:rustc-link-search=native={}/build/lib", dst.display()); 39 | println!("cargo:rustc-link-lib=static=mlx"); 40 | println!("cargo:rustc-link-lib=static=mlxc"); 41 | 42 | println!("cargo:rustc-link-lib=c++"); 43 | println!("cargo:rustc-link-lib=dylib=objc"); 44 | println!("cargo:rustc-link-lib=framework=Foundation"); 45 | 46 | #[cfg(feature = "metal")] 47 | { 48 | println!("cargo:rustc-link-lib=framework=Metal"); 49 | } 50 | 51 | #[cfg(feature = "accelerate")] 52 | { 53 | println!("cargo:rustc-link-lib=framework=Accelerate"); 54 | } 55 | } 56 | 57 | fn main() { 58 | build_and_link_mlx_c(); 59 | 60 | // generate bindings 61 | let bindings = bindgen::Builder::default() 62 | .rust_target(RustTarget::Stable_1_73) 63 | .header("src/mlx-c/mlx/c/mlx.h") 64 | .header("src/mlx-c/mlx/c/linalg.h") 65 | .header("src/mlx-c/mlx/c/error.h") 66 | .header("src/mlx-c/mlx/c/transforms_impl.h") 67 | .clang_arg("-Isrc/mlx-c") 68 | .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) 69 | .generate() 70 | .expect("Unable to generate bindings"); 71 | 72 | // Write the bindings to the $OUT_DIR/bindings.rs file. 
73 | let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); 74 | bindings 75 | .write_to_file(out_path.join("bindings.rs")) 76 | .expect("Couldn't write bindings!"); 77 | } 78 | -------------------------------------------------------------------------------- /mlx-sys/examples/is_metal_available.rs: -------------------------------------------------------------------------------- 1 | fn main() { 2 | let mut is_available = false; 3 | let status = unsafe { mlx_sys::mlx_metal_is_available(&mut is_available as *mut bool) }; 4 | assert_eq!(status, 0); 5 | println!("{:?}", is_available); 6 | } 7 | -------------------------------------------------------------------------------- /mlx-sys/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_upper_case_globals)] 2 | #![allow(non_camel_case_types)] 3 | #![allow(non_snake_case)] 4 | #![allow(clippy::all)] 5 | 6 | include!(concat!(env!("OUT_DIR"), "/bindings.rs")); 7 | -------------------------------------------------------------------------------- /mlx-tests/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "mlx-tests" 3 | edition = "2021" 4 | version.workspace = true 5 | authors.workspace = true 6 | 7 | [dependencies] 8 | 9 | [dev-dependencies] 10 | mlx-internal-macros.workspace = true 11 | mlx-rs.workspace = true 12 | tempfile.workspace = true -------------------------------------------------------------------------------- /mlx-tests/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! An empty crate for testing purposes only. 2 | -------------------------------------------------------------------------------- /mlx-tests/tests/common.rs: -------------------------------------------------------------------------------- 1 | use mlx_rs::{ 2 | error::Exception, 3 | macros::ModuleParameters, 4 | module::{Module, Param}, 5 | random::uniform, 6 | utils::IntoOption, 7 | Array, 8 | }; 9 | 10 | /// A helper model for testing optimizers. 11 | /// 12 | /// This is adapted from the swift binding tests in `mlx-swift/Tests/MLXTests/OptimizerTests.swift`. 
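///
/// A minimal usage sketch (not part of the original tests): `new` draws `m` and
/// `b` from U(-5, 5), and `forward` computes `m * x + b`.
///
/// ```rust,ignore
/// let mut model = LinearFunctionModel::new(None)?;
/// let y = model.forward(&x)?;
/// ```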
13 | #[derive(Debug, ModuleParameters)] 14 | pub struct LinearFunctionModel { 15 | #[param] 16 | pub m: Param<Array>, 17 | 18 | #[param] 19 | pub b: Param<Array>, 20 | } 21 | 22 | impl Module<&Array> for LinearFunctionModel { 23 | type Error = Exception; 24 | type Output = Array; 25 | 26 | fn forward(&mut self, x: &Array) -> Result<Self::Output, Self::Error> { 27 | self.m.multiply(x)?.add(&self.b) 28 | } 29 | 30 | fn training_mode(&mut self, _mode: bool) {} 31 | } 32 | 33 | impl LinearFunctionModel { 34 | pub fn new<'a>(shape: impl IntoOption<&'a [i32]>) -> mlx_rs::error::Result<Self> { 35 | let shape = shape.into_option(); 36 | let m = uniform::<_, f32>(-5.0, 5.0, shape, None)?; 37 | let b = uniform::<_, f32>(-5.0, 5.0, shape, None)?; 38 | Ok(Self { 39 | m: Param::new(m), 40 | b: Param::new(b), 41 | }) 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /mlx-tests/tests/test_disable_compile.rs: -------------------------------------------------------------------------------- 1 | use mlx_rs::{ 2 | array, 3 | error::Exception, 4 | exp, negative, 5 | transforms::compile::{compile, disable_compile, enable_compile}, 6 | Array, 7 | }; 8 | 9 | #[test] 10 | fn test_disable_compile() { 11 | disable_compile(); 12 | 13 | let f = |x: &Array| -> Result<Array, Exception> { 14 | let z = negative!(x)?; 15 | 16 | // this will crash if compile is enabled 17 | println!("{:?}", z); 18 | 19 | exp!(z) 20 | }; 21 | 22 | let x = array!(10.0); 23 | let mut compiled = compile(f, None); 24 | 25 | // This will panic if compilation is enabled 26 | let _result = compiled(&x).unwrap(); 27 | 28 | // Re-enable compilation for other tests 29 | enable_compile(); 30 | } 31 | -------------------------------------------------------------------------------- /mlx-tests/tests/test_exported_macros.rs: -------------------------------------------------------------------------------- 1 | //! This contains the tests for some of the exported macros. 2 | //! 3 | //! This is mainly a sanity check to ensure that the exported macros are working as expected. 4 | 5 | use mlx_rs::{ 6 | array, complex64, 7 | ops::{arange, reshape}, 8 | Array, Dtype, StreamOrDevice, 9 | }; 10 | 11 | // Try two functions that don't have any optional arguments. 12 | 13 | #[test] 14 | fn test_ops_arithmetic_abs() { 15 | let data = array!([1i32, 2, -3, -4, -5]); 16 | let result = mlx_rs::abs!(&data).unwrap(); 17 | 18 | assert_eq!(result.as_slice::<i32>(), &[1, 2, 3, 4, 5]); 19 | 20 | let stream = StreamOrDevice::cpu(); 21 | let result = mlx_rs::abs!(data, stream = stream).unwrap(); 22 | 23 | assert_eq!(result.as_slice::<i32>(), &[1, 2, 3, 4, 5]); 24 | } 25 | 26 | #[test] 27 | fn test_ops_arithmetic_add() { 28 | let data1 = array!([1i32, 2, 3, 4, 5]); 29 | let data2 = array!([1i32, 2, 3, 4, 5]); 30 | let result = mlx_rs::add!(&data1, &data2).unwrap(); 31 | 32 | assert_eq!(result.as_slice::<i32>(), &[2, 4, 6, 8, 10]); 33 | 34 | let stream = StreamOrDevice::cpu(); 35 | let result = mlx_rs::add!(data1, data2, stream = stream).unwrap(); 36 | 37 | assert_eq!(result.as_slice::<i32>(), &[2, 4, 6, 8, 10]); 38 | } 39 | 40 | // Try a function that has optional arguments.
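// (`tensordot_axes!` contracts the listed axes of each operand: below, axes [1, 0]
// of the [3, 4, 5] array are matched against axes [0, 1] of the [4, 3, 2] array,
// so the contracted dimensions 4 and 3 drop out and a [5, 2] result remains.)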
41 | 42 | #[test] 43 | fn test_ops_arithmetic_tensordot() { 44 | let x = reshape(arange::<_, f32>(None, 60.0, None).unwrap(), &[3, 4, 5]).unwrap(); 45 | let y = reshape(arange::<_, f32>(None, 24.0, None).unwrap(), &[4, 3, 2]).unwrap(); 46 | let axes_x = [1, 0]; 47 | let axes_y = [0, 1]; 48 | let z = mlx_rs::tensordot_axes!(&x, &y, &axes_x, &axes_y).unwrap(); 49 | let expected = Array::from_slice( 50 | &[4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306], 51 | &[5, 2], 52 | ); 53 | assert_eq!(z, expected); 54 | 55 | let stream = StreamOrDevice::cpu(); 56 | let z = mlx_rs::tensordot_axes!(&x, &y, &axes_x, &axes_y, stream = stream).unwrap(); 57 | assert_eq!(z, expected); 58 | } 59 | 60 | // Test functions defined in `mlx_rs::ops` module. 61 | 62 | #[test] 63 | fn test_ops_convolution_conv1d() { 64 | let input = array!( 65 | [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 66 | shape = [1, 5, 2] 67 | ); 68 | let weight = array!( 69 | [0.5, 0.0, -0.5, 1.0, 0.0, 1.5, 2.0, 0.0, -2.0, 1.5, 0.0, 1.0], 70 | shape = [2, 3, 2] 71 | ); 72 | 73 | let result = mlx_rs::conv1d!( 74 | &input, 75 | &weight, 76 | stride = 1, 77 | padding = 0, 78 | dilation = 1, 79 | groups = 1 80 | ) 81 | .unwrap(); 82 | 83 | let expected = array!([12.0, 8.0, 17.0, 13.0, 22.0, 18.0], shape = [1, 3, 2]); 84 | assert_eq!(result, expected); 85 | } 86 | 87 | #[test] 88 | fn test_ops_factory_arange() { 89 | // Without specifying start and step 90 | let array = mlx_rs::arange!(stop = 50).unwrap(); 91 | assert_eq!(array.shape(), &[50]); 92 | assert_eq!(array.dtype(), Dtype::Float32); 93 | 94 | let data: &[f32] = array.as_slice(); 95 | let expected: Vec<f32> = (0..50).map(|x| x as f32).collect(); 96 | assert_eq!(data, expected.as_slice()); 97 | 98 | // With specifying start and step 99 | let array = mlx_rs::arange!(start = 1.0, stop = 50.0, step = 2.0).unwrap(); 100 | assert_eq!(array.shape(), &[25]); 101 | assert_eq!(array.dtype(), Dtype::Float32); 102 | 103 | let data: &[f32] = array.as_slice(); 104 | let expected: Vec<f32> = (1..50).step_by(2).map(|x| x as f32).collect(); 105 | assert_eq!(data, expected.as_slice()); 106 | 107 | let stream = StreamOrDevice::cpu(); 108 | let array = mlx_rs::arange!(start = 1.0, stop = 50.0, step = 2.0, stream = stream).unwrap(); 109 | assert_eq!(array.shape(), &[25]); 110 | assert_eq!(array.dtype(), Dtype::Float32); 111 | 112 | let data: &[f32] = array.as_slice(); 113 | let expected: Vec<f32> = (1..50).step_by(2).map(|x| x as f32).collect(); 114 | assert_eq!(data, expected.as_slice()); 115 | } 116 | 117 | // Test functions defined in `mlx_rs::fft` module. 118 | 119 | #[test] 120 | fn test_fft_fft() { 121 | const FFT_EXPECTED: &[complex64; 4] = &[ 122 | complex64::new(10.0, 0.0), 123 | complex64::new(-2.0, 2.0), 124 | complex64::new(-2.0, 0.0), 125 | complex64::new(-2.0, -2.0), 126 | ]; 127 | 128 | let data = array!([1.0, 2.0, 3.0, 4.0]); 129 | let fft = mlx_rs::fft!(&data).unwrap(); 130 | 131 | assert_eq!(fft.dtype(), Dtype::Complex64); 132 | assert_eq!(fft.as_slice::<complex64>(), FFT_EXPECTED); 133 | } 134 | 135 | // Test functions defined in `mlx_rs::linalg` module. 136 | 137 | #[test] 138 | fn test_linalg_norm() { 139 | let a = array!([1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]).unwrap(); 140 | let norm = mlx_rs::norm_l2!(&a).unwrap(); 141 | assert_eq!(norm.item::<f32>(), 5.477_226); 142 | } 143 | 144 | // Test functions defined in `mlx_rs::random` module.
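// (`uniform!` draws samples between the given `low` and `high` bounds, hence the
// [0, 1] check below; `normal!` draws from a standard normal, so the test only
// sanity-checks a loose +/-10 range.)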
145 | 146 | #[test] 147 | fn test_random_uniform() { 148 | let value = mlx_rs::uniform!(0.0, 1.0, shape = &[1]).unwrap(); 149 | assert_eq!(value.shape(), &[1]); 150 | assert!(value.item::<f32>() >= 0.0 && value.item::<f32>() <= 1.0); 151 | } 152 | 153 | #[test] 154 | fn test_random_normal() { 155 | let value = mlx_rs::normal!(shape = &[1]).unwrap(); 156 | assert_eq!(value.shape(), &[1]); 157 | assert!(value.item::<f32>() >= -10.0 && value.item::<f32>() <= 10.0); 158 | } 159 | -------------------------------------------------------------------------------- /mlx-tests/tests/test_generate_builder.rs: -------------------------------------------------------------------------------- 1 | use mlx_internal_macros::*; 2 | use mlx_rs::builder::{Buildable, Builder}; 3 | 4 | generate_builder! { 5 | /// Test struct for the builder generation. 6 | #[derive(Debug, Buildable)] 7 | #[builder(build_with = build_test_struct)] 8 | struct TestStruct { 9 | #[builder(optional, default = TestStruct::DEFAULT_OPT_FIELD_1)] 10 | opt_field_1: i32, 11 | #[builder(optional, default = TestStruct::DEFAULT_OPT_FIELD_2)] 12 | opt_field_2: i32, 13 | mandatory_field_1: i32, 14 | 15 | #[builder(ignore)] 16 | ignored_field: String, 17 | } 18 | } 19 | 20 | fn build_test_struct( 21 | builder: TestStructBuilder, 22 | ) -> std::result::Result<TestStruct, std::convert::Infallible> { 23 | Ok(TestStruct { 24 | opt_field_1: builder.opt_field_1, 25 | opt_field_2: builder.opt_field_2, 26 | mandatory_field_1: builder.mandatory_field_1, 27 | ignored_field: String::from("ignored"), 28 | }) 29 | } 30 | 31 | impl TestStruct { 32 | pub const DEFAULT_OPT_FIELD_1: i32 = 1; 33 | pub const DEFAULT_OPT_FIELD_2: i32 = 2; 34 | } 35 | 36 | #[test] 37 | fn test_generated_builder() { 38 | let test_struct = <TestStruct as Buildable>::Builder::new(4) 39 | .opt_field_1(2) 40 | .opt_field_2(3) 41 | .build() 42 | .unwrap(); 43 | 44 | assert_eq!(test_struct.opt_field_1, 2); 45 | assert_eq!(test_struct.opt_field_2, 3); 46 | assert_eq!(test_struct.mandatory_field_1, 4); 47 | assert_eq!(test_struct.ignored_field, String::from("ignored")); 48 | } 49 | -------------------------------------------------------------------------------- /mlx-tests/tests/test_generate_macro.rs: -------------------------------------------------------------------------------- 1 | #![allow(unused_variables)] 2 | 3 | use mlx_internal_macros::{default_device, generate_macro}; 4 | use mlx_rs::{Stream, StreamOrDevice}; 5 | 6 | // Test generate_macro for functions with no generic type arguments. 7 | #[generate_macro(customize(root = "$crate"))] 8 | #[default_device] 9 | fn foo_device( 10 | a: i32, // Mandatory argument 11 | b: i32, // Mandatory argument 12 | #[optional] c: Option<i32>, // Optional argument 13 | #[optional] d: impl Into<Option<i32>>, // Optional argument but impl Trait 14 | #[optional] stream: impl AsRef<Stream>, // stream always optional and placed at the end 15 | ) -> i32 { 16 | a + b + c.unwrap_or(0) + d.into().unwrap_or(0) 17 | } 18 | 19 | #[test] 20 | fn test_foo() { 21 | assert_eq!(foo!(1, 2), 3); 22 | assert_eq!(foo!(1, 2, c = Some(3)), 6); 23 | assert_eq!(foo!(1, 2, d = Some(4)), 7); 24 | assert_eq!(foo!(1, 2, c = Some(3), d = Some(4)), 10); 25 | 26 | let stream = Stream::new(); 27 | 28 | assert_eq!(foo!(1, 2, stream = &stream), 3); 29 | assert_eq!(foo!(1, 2, c = Some(3), stream = &stream), 6); 30 | assert_eq!(foo!(1, 2, d = Some(4), stream = &stream), 7); 31 | assert_eq!(foo!(1, 2, c = Some(3), d = Some(4), stream = &stream), 10); 32 | } 33 | 34 | // Test generate_macro for functions with generic type arguments.
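// (With `default_dtype = i32` in the customize list, a call like `bar!(1, 2)`
// instantiates the generic with i32; passing `dtype = i16` selects the i16
// instantiation instead, as exercised below.)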
35 | #[generate_macro(customize( 36 | root = "$crate", 37 | default_dtype = i32, 38 | ))] 39 | #[default_device] 40 | fn bar_device<T: Into<i32>>( 41 | a: T, // Mandatory argument 42 | b: T, // Mandatory argument 43 | #[optional] c: Option<T>, // Optional argument 44 | #[optional] d: impl Into<Option<T>>, // Optional argument but impl Trait 45 | #[optional] stream: impl AsRef<Stream>, // stream always optional and placed at the end 46 | ) -> i32 { 47 | let a = a.into(); 48 | let b = b.into(); 49 | let c = c.map(Into::into); 50 | let d = d.into().map(Into::into); 51 | a + b + c.unwrap_or(0) + d.unwrap_or(0) 52 | } 53 | 54 | #[test] 55 | fn test_bar() { 56 | // Without specifying dtype, the default is i32. 57 | 58 | let result = bar!(1, 2); 59 | assert_eq!(result, 3); 60 | 61 | let result = bar!(1, 2, c = Some(3)); 62 | assert_eq!(result, 6); 63 | 64 | let result = bar!(1, 2, d = Some(4)); 65 | assert_eq!(result, 7); 66 | 67 | let result = bar!(1, 2, c = Some(3), d = Some(4)); 68 | assert_eq!(result, 10); 69 | 70 | // With dtype specified as i16. 71 | 72 | let result = bar!(1, 2, dtype = i16); 73 | assert_eq!(result, 3); 74 | 75 | let result = bar!(1, 2, c = Some(3), dtype = i16); 76 | assert_eq!(result, 6); 77 | 78 | let result = bar!(1, 2, d = Some(4), dtype = i16); 79 | assert_eq!(result, 7); 80 | 81 | let result = bar!(1, 2, c = Some(3), d = Some(4), dtype = i16); 82 | assert_eq!(result, 10); 83 | 84 | // With stream specified. 85 | 86 | let stream = Stream::new(); 87 | 88 | let result = bar!(1, 2, stream = &stream); 89 | assert_eq!(result, 3); 90 | 91 | let result = bar!(1, 2, c = Some(3), stream = &stream); 92 | assert_eq!(result, 6); 93 | 94 | let result = bar!(1, 2, d = Some(4), stream = &stream); 95 | assert_eq!(result, 7); 96 | 97 | let result = bar!(1, 2, c = Some(3), d = Some(4), stream = &stream); 98 | assert_eq!(result, 10); 99 | 100 | // With dtype and stream specified. 101 | 102 | let result = bar!(1, 2, dtype = i16, stream = &stream); 103 | assert_eq!(result, 3); 104 | 105 | let result = bar!(1, 2, c = Some(3), dtype = i16, stream = &stream); 106 | assert_eq!(result, 6); 107 | 108 | let result = bar!(1, 2, d = Some(4), dtype = i16, stream = &stream); 109 | assert_eq!(result, 7); 110 | 111 | let result = bar!( 112 | 1, 113 | 2, 114 | c = Some(3), 115 | d = Some(4), 116 | dtype = i16, 117 | stream = &stream 118 | ); 119 | assert_eq!(result, 10); 120 | } 121 | 122 | // Test named mandatory arguments.
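// (A `#[named]` mandatory argument must be passed by name in the generated macro,
// which is why every `baz!` call below spells out `b = ...` while `a` and `c`
// stay optional.)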
123 | #[generate_macro(customize(root = "$crate"))] 124 | #[default_device] 125 | fn baz_device( 126 | #[optional] a: Option<i32>, // Optional argument 127 | #[named] b: i32, // Mandatory argument 128 | #[optional] c: Option<i32>, // Optional argument 129 | #[optional] stream: impl AsRef<Stream>, // stream always optional and placed at the end 130 | ) -> i32 { 131 | a.unwrap_or(0) + b + c.unwrap_or(0) 132 | } 133 | 134 | #[test] 135 | fn test_baz() { 136 | assert_eq!(baz!(b = 1), 1); 137 | assert_eq!(baz!(a = Some(2), b = 1), 3); 138 | assert_eq!(baz!(b = 1, c = Some(3)), 4); 139 | assert_eq!(baz!(a = Some(2), b = 1, c = Some(3)), 6); 140 | 141 | let stream = Stream::new(); 142 | 143 | assert_eq!(baz!(b = 1, stream = &stream), 1); 144 | assert_eq!(baz!(a = Some(2), b = 1, stream = &stream), 3); 145 | assert_eq!(baz!(b = 1, c = Some(3), stream = &stream), 4); 146 | assert_eq!(baz!(a = Some(2), b = 1, c = Some(3), stream = &stream), 6); 147 | } 148 | -------------------------------------------------------------------------------- /mlx-tests/tests/test_internal_macros.rs: -------------------------------------------------------------------------------- 1 | use mlx_internal_macros::*; 2 | use mlx_rs::builder::{Buildable, Builder}; 3 | 4 | generate_builder! { 5 | /// Test struct for the builder generation. 6 | #[derive(Debug, Buildable)] 7 | #[builder(build_with = build_test_struct)] 8 | struct TestStruct { 9 | #[builder(optional, default = TestStruct::DEFAULT_OPT_FIELD_1)] 10 | opt_field_1: i32, 11 | #[builder(optional, default = TestStruct::DEFAULT_OPT_FIELD_2)] 12 | opt_field_2: i32, 13 | mandatory_field_1: i32, 14 | 15 | #[builder(ignore)] 16 | ignored_field: String, 17 | } 18 | } 19 | 20 | fn build_test_struct( 21 | builder: TestStructBuilder, 22 | ) -> std::result::Result<TestStruct, std::convert::Infallible> { 23 | Ok(TestStruct { 24 | opt_field_1: builder.opt_field_1, 25 | opt_field_2: builder.opt_field_2, 26 | mandatory_field_1: builder.mandatory_field_1, 27 | ignored_field: String::from("ignored"), 28 | }) 29 | } 30 | 31 | impl TestStruct { 32 | pub const DEFAULT_OPT_FIELD_1: i32 = 1; 33 | pub const DEFAULT_OPT_FIELD_2: i32 = 2; 34 | } 35 | 36 | #[test] 37 | fn test_generated_builder() { 38 | let test_struct = <TestStruct as Buildable>::Builder::new(4) 39 | .opt_field_1(2) 40 | .opt_field_2(3) 41 | .build() 42 | .unwrap(); 43 | 44 | assert_eq!(test_struct.opt_field_1, 2); 45 | assert_eq!(test_struct.opt_field_2, 3); 46 | assert_eq!(test_struct.mandatory_field_1, 4); 47 | assert_eq!(test_struct.ignored_field, String::from("ignored")); 48 | } 49 | -------------------------------------------------------------------------------- /mlx-tests/tests/test_module.rs: -------------------------------------------------------------------------------- 1 | use mlx_rs::{error::Exception, macros::ModuleParameters, module::Module, nn::Linear, Array}; 2 | 3 | #[derive(Debug, ModuleParameters)] 4 | struct M { 5 | #[param] 6 | linear: Linear, 7 | } 8 | 9 | impl M { 10 | pub fn new() -> Self { 11 | Self { 12 | linear: Linear::new(5, 5).unwrap(), 13 | } 14 | } 15 | } 16 | 17 | impl Module<&Array> for M { 18 | type Error = Exception; 19 | type Output = Array; 20 | 21 | fn forward(&mut self, x: &Array) -> Result<Self::Output, Self::Error> { 22 | self.linear.forward(x) 23 | } 24 | 25 | fn training_mode(&mut self, _mode: bool) {} 26 | } 27 | 28 | #[test] 29 | fn test_nested_module() { 30 | let mut m = M::new(); 31 | let x = mlx_rs::random::uniform::<_, f32>(1.0, 2.0, &[1, 5], None).unwrap(); 32 | let y = m.forward(&x).unwrap(); 33 | assert_ne!(y.sum(None).unwrap(), mlx_rs::array!(0.0)); 34 | } 35 | 
-------------------------------------------------------------------------------- /mlx-tests/tests/test_quantizable.rs: -------------------------------------------------------------------------------- 1 | use mlx_rs::{ 2 | error::Exception, 3 | macros::{ModuleParameters, Quantizable}, 4 | module::Module, 5 | nn::Linear, 6 | quantization::MaybeQuantized, 7 | Array, 8 | }; 9 | 10 | #[derive(Debug, ModuleParameters, Quantizable)] 11 | struct QuantizableExample { 12 | #[quantizable] 13 | pub ql: MaybeQuantized<Linear>, 14 | } 15 | 16 | impl Module<&Array> for QuantizableExample { 17 | type Output = Array; 18 | 19 | type Error = Exception; 20 | 21 | fn forward(&mut self, x: &Array) -> Result<Self::Output, Self::Error> { 22 | self.ql.forward(x) 23 | } 24 | 25 | fn training_mode(&mut self, mode: bool) { 26 | self.ql.training_mode(mode) 27 | } 28 | } 29 | --------------------------------------------------------------------------------