├── .cargo
│   └── config.toml
├── .gitignore
├── .github
│   ├── dependabot.yml
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   └── feature_request.md
│   └── workflows
│       ├── lints.yml
│       └── rust-ci.yml
├── katex-header.html
├── LICENSE
├── examples
│   └── mnist
│       ├── parse_cli.rs
│       ├── models.rs
│       ├── training.rs
│       ├── main.rs
│       └── optim.rs
├── paper
│   ├── paper.md
│   └── paper.bib
├── Cargo.toml
├── README.md
├── Changelog.md
├── benches
│   ├── mnist_bench.rs
│   └── training.rs
├── tests
│   ├── adadelta_tests.rs
│   ├── radam-tests.rs
│   ├── adamax_tests.rs
│   ├── nadam_tests.rs
│   ├── adagrad_tests.rs
│   ├── lbfgs_tests.rs
│   ├── adam_tests.rs
│   ├── rmsprop-tests.rs
│   └── esgd_tests.rs
└── src
    ├── lib.rs
    ├── adagrad.rs
    ├── adamax.rs
    ├── adadelta.rs
    ├── nadam.rs
    ├── radam.rs
    └── lbfgs
        └── strong_wolfe.rs
/.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [build] 2 | rustdocflags = [ "--html-in-header", "./katex-header.html" ] -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | notes.txt 4 | out.txt 5 | lbfgs_testbed.ipynb 6 | test*.ipynb 7 | References.md 8 | pseudo/* 9 | paper/paper.pdf 10 | paper/paper.jats 11 | *.profraw 12 | .vscode/ -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "cargo" 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "weekly" 12 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Code example to reproduce behaviour 15 | 16 | **Expected behavior** 17 | A clear and concise description of what you expected to happen. 18 | 19 | **Desktop (please complete the following information):** 20 | 21 | * OS: [e.g. iOS] 22 | * Features in use [e.g. Cuda] 23 | * version 24 | 25 | **Additional context** 26 | Add any other context about the problem here. 
27 | -------------------------------------------------------------------------------- /.github/workflows/lints.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - master 5 | pull_request: 6 | 7 | name: Clippy 8 | 9 | jobs: 10 | fmt: 11 | name: Rustfmt 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4 15 | - uses: dtolnay/rust-toolchain@stable 16 | - run: rustup component add rustfmt 17 | - run: cargo fmt --all -- --check 18 | 19 | clippy: 20 | name: Clippy 21 | runs-on: ubuntu-latest 22 | steps: 23 | - uses: actions/checkout@v4 24 | - uses: dtolnay/rust-toolchain@stable 25 | - run: rustup component add rustfmt 26 | - run: cargo clippy --workspace --tests --examples -- -D warnings 27 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | Is the feature an additional optimisation algorithm? Is it a change to the documentation? 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen: for additional optimisation algorithms it would be ideal if possible to see if there are other implementations or a reference (though am very happy to accept novel optimisers). 15 | 16 | **Additional context** 17 | Add any other context about the feature request here. 18 | -------------------------------------------------------------------------------- /.github/workflows/rust-ci.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - master 5 | pull_request: 6 | 7 | name: CI 8 | 9 | jobs: 10 | check: 11 | name: Check 12 | runs-on: ${{ matrix.os }} 13 | strategy: 14 | matrix: 15 | os: [ubuntu-latest, windows-latest, macOS-latest] 16 | rust: [stable, beta, nightly] 17 | steps: 18 | - uses: actions/checkout@v4 19 | - uses: dtolnay/rust-toolchain@master 20 | with: 21 | toolchain: ${{ matrix.rust }} 22 | - run: cargo check --workspace 23 | 24 | test: 25 | name: Test Suite 26 | runs-on: ${{ matrix.os }} 27 | strategy: 28 | matrix: 29 | os: [ubuntu-latest, windows-latest, macOS-latest] 30 | rust: [stable, beta, nightly] 31 | steps: 32 | - uses: actions/checkout@v4 33 | - uses: dtolnay/rust-toolchain@master 34 | with: 35 | toolchain: ${{ matrix.rust }} 36 | - run: cargo test --workspace -------------------------------------------------------------------------------- /katex-header.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [2023] [Kirpal Grewal] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the 
following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /examples/mnist/parse_cli.rs: -------------------------------------------------------------------------------- 1 | use clap::{Parser, ValueEnum}; 2 | 3 | pub struct TrainingArgs { 4 | pub learning_rate: f64, 5 | pub load: Option<String>, 6 | pub save: Option<String>, 7 | pub epochs: usize, 8 | } 9 | 10 | #[derive(ValueEnum, Clone)] 11 | pub enum WhichModel { 12 | Linear, 13 | Mlp, 14 | } 15 | 16 | #[derive(ValueEnum, Clone)] 17 | pub enum WhichOptim { 18 | Adadelta, 19 | Adagrad, 20 | Adamax, 21 | Sgd, 22 | NAdam, 23 | RAdam, 24 | Rms, 25 | Adam, 26 | } 27 | 28 | #[derive(Parser)] 29 | pub struct Args { 30 | #[clap(value_enum, default_value_t = WhichModel::Linear)] 31 | pub model: WhichModel, 32 | 33 | #[arg(long, value_enum, default_value_t = WhichOptim::Adadelta)] 34 | pub optim: WhichOptim, 35 | 36 | #[arg(long)] 37 | pub learning_rate: Option<f64>, 38 | 39 | #[arg(long, default_value_t = 200)] 40 | pub epochs: usize, 41 | 42 | /// The file where to save the trained weights, in safetensors format. 43 | #[arg(long)] 44 | pub save: Option<String>, 45 | 46 | /// The file where to load the trained weights from, in safetensors format. 47 | #[arg(long)] 48 | pub load: Option<String>, 49 | 50 | /// The directory where to load the dataset from, in ubyte format. 
51 | #[arg(long)] 52 | pub local_mnist: Option<String>, 53 | } 54 | -------------------------------------------------------------------------------- /examples/mnist/models.rs: -------------------------------------------------------------------------------- 1 | use candle_core::{Result, Tensor}; 2 | use candle_nn::{Linear, Module, VarBuilder}; 3 | 4 | const IMAGE_DIM: usize = 784; 5 | const LABELS: usize = 10; 6 | 7 | fn linear_z(in_dim: usize, out_dim: usize, vs: &VarBuilder) -> Result<Linear> { 8 | let ws = vs.get_with_hints((out_dim, in_dim), "weight", candle_nn::init::ZERO)?; 9 | let bs = vs.get_with_hints(out_dim, "bias", candle_nn::init::ZERO)?; 10 | Ok(Linear::new(ws, Some(bs))) 11 | } 12 | 13 | pub trait Model: Sized { 14 | fn new(vs: VarBuilder) -> Result<Self>; 15 | fn forward(&self, xs: &Tensor) -> Result<Tensor>; 16 | } 17 | 18 | pub struct LinearModel { 19 | linear: Linear, 20 | } 21 | 22 | impl Model for LinearModel { 23 | fn new(vs: VarBuilder) -> Result<Self> { 24 | let linear = linear_z(IMAGE_DIM, LABELS, &vs)?; 25 | Ok(Self { linear }) 26 | } 27 | 28 | fn forward(&self, xs: &Tensor) -> Result<Tensor> { 29 | self.linear.forward(xs) 30 | } 31 | } 32 | 33 | pub struct Mlp { 34 | ln1: Linear, 35 | ln2: Linear, 36 | } 37 | 38 | impl Model for Mlp { 39 | fn new(vs: VarBuilder) -> Result<Self> { 40 | let ln1 = candle_nn::linear(IMAGE_DIM, 100, vs.pp("ln1"))?; 41 | let ln2 = candle_nn::linear(100, LABELS, vs.pp("ln2"))?; 42 | Ok(Self { ln1, ln2 }) 43 | } 44 | 45 | fn forward(&self, xs: &Tensor) -> Result<Tensor> { 46 | let xs = self.ln1.forward(xs)?; 47 | let xs = xs.relu()?; 48 | self.ln2.forward(&xs) 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /paper/paper.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: 'Candle Optimisers: A Rust crate for optimisation algorithms' 3 | tags: 4 | - Rust 5 | - optimisation 6 | - optimization 7 | - machine learning 8 | authors: 9 | - name: Kirpal Grewal 10 | orcid: 0009-0001-7923-9975 11 | affiliation: 1 12 | affiliations: 13 | - name: Yusuf Hamied Department of Chemistry, University of Cambridge 14 | index: 1 15 | date: 20 December 2023 16 | bibliography: paper.bib 17 | --- 18 | 19 | # Summary 20 | 21 | `candle-optimisers` is a crate for optimisers written in Rust for use with candle (@candle), a lightweight machine learning framework. The crate offers a set of 22 | optimisers for training neural networks. 23 | 24 | # Statement of need 25 | 26 | Rust provides the opportunity for the development of high performance machine learning libraries, with a leaner runtime. However, there is a lack of optimisation algorithms implemented in Rust, 27 | with machine learning libraries currently implementing only some combination of Adam, AdamW, SGD and RMSProp. 28 | This crate aims to provide a more complete set of optimisation algorithms for use with candle. 29 | This will allow Rust to be used for the training of models more easily.
30 | 31 | # Features 32 | 33 | This library implements the following optimisation algorithms: 34 | 35 | * SGD (including momentum and Nesterov momentum (@nmomentum)) 36 | 37 | * AdaDelta (@adadelta) 38 | 39 | * AdaGrad (@adagrad) 40 | 41 | * AdaMax (@adam) 42 | 43 | * Adam (@adam) including AMSGrad (@amsgrad) 44 | 45 | * AdamW (@weightdecay) (as decoupled weight decay of Adam) 46 | 47 | * NAdam (@nadam) 48 | 49 | * RAdam (@radam) 50 | 51 | * RMSProp (@rmsprop) 52 | 53 | * LBFGS (@LBFGS) 54 | 55 | Furthermore, decoupled weight decay (@weightdecay) is implemented for all of the adaptive methods listed and SGD, 56 | allowing for use of the method beyond solely AdamW. 57 | 58 | # References 59 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "candle-optimisers" 3 | version = "0.10.0-alpha.2" 4 | edition = "2021" 5 | readme = "README.md" 6 | license = "MIT" 7 | keywords = ["optimisers", "candle", "tensor", "machine-learning"] 8 | categories = ["science"] 9 | description = "Optimisers for use with candle, the minimalist ML framework" 10 | repository = "https://github.com/KGrewal1/optimisers" 11 | exclude = ["*.ipynb"] 12 | 13 | 14 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 15 | 16 | [dependencies] 17 | 18 | # Using candle 0.9.2-alpha.1 for CUDA 13.0 support via cudarc 0.17.1+ 19 | candle-core = "0.9.2-alpha.1" 20 | candle-nn = "0.9.2-alpha.1" 21 | log = "0.4.20" 22 | 23 | 24 | [dev-dependencies] 25 | anyhow = { version = "1", features = ["backtrace"] } 26 | assert_approx_eq = "1.1.0" 27 | candle-datasets = "0.9.2-alpha.1" 28 | clap = { version = "4.4.6", features = ["derive"] } 29 | criterion = { version = "0.7.0", features = ["html_reports"] } 30 | 31 | [[bench]] 32 | name = "mnist_bench" 33 | harness = false 34 | 35 | [features] 36 | default = [] 37 | cuda = ["candle-core/cuda", "candle-nn/cuda"] 38 | 39 | [profile.bench] 40 | lto = true # maximal LTO optimisaiton 41 | 42 | [lints.clippy] 43 | pedantic = { level = "warn", priority = -1 } 44 | suspicious = { level = "warn", priority = -1 } 45 | perf = { level = "warn", priority = -1 } 46 | complexity = { level = "warn", priority = -1 } 47 | style = { level = "warn", priority = -1 } 48 | cargo = { level = "warn", priority = -1 } 49 | imprecise_flops = "warn" 50 | missing_errors_doc = { level = "allow", priority = 1 } 51 | uninlined_format_args = { level = "allow", priority = 1 } 52 | similar_names = { level = "allow", priority = 1 } 53 | float_cmp = { level = "allow", priority = 1 } # as internaly rounded before the comparison 54 | doc_markdown = { level = "allow", priority = 1 } # otherwise names get flagged 55 | multiple_crate_versions = { level = "allow", priority = 1 } # for candle dep graph 56 | 57 | [package.metadata.docs.rs] 58 | rustdoc-args = ["--html-in-header", "./katex-header.html"] 59 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Candle Optimisers 2 | 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 4 | [![codecov](https://codecov.io/gh/KGrewal1/optimisers/graph/badge.svg?token=6AFTLS6DFO)](https://codecov.io/gh/KGrewal1/optimisers) 5 | ![Tests](https://github.com/KGrewal1/optimisers/actions/workflows/rust-ci.yml/badge.svg) 6 | 
![Lints](https://github.com/KGrewal1/optimisers/actions/workflows/lints.yml/badge.svg) 7 | [![Latest version](https://img.shields.io/crates/v/candle-optimisers.svg)](https://crates.io/crates/candle-optimisers) 8 | [![Documentation](https://docs.rs/candle-optimisers/badge.svg)](https://docs.rs/candle-optimisers) 9 | 10 | A crate for optimisers for use with [candle](https://github.com/huggingface/candle), the minimalist ML framework. 11 | 12 | Optimisers implemented are: 13 | 14 | * SGD (including momentum and weight decay) 15 | 16 | * RMSprop 17 | 18 | Adaptive methods: 19 | 20 | * AdaDelta 21 | 22 | * AdaGrad 23 | 24 | * AdaMax 25 | 26 | * Adam 27 | 28 | * AdamW (included with Adam as `decoupled_weight_decay`) 29 | 30 | * NAdam 31 | 32 | * RAdam 33 | 34 | These are all checked against their pytorch implementations (see pytorch_test.ipynb) and should implement the same functionality (though without some input checking). 35 | 36 | Additionally, all of the adaptive methods listed and SGD implement decoupled weight decay as described in [Decoupled Weight Decay Regularization](https://arxiv.org/pdf/1711.05101.pdf), in addition to the standard weight decay as implemented in pytorch. 37 | 38 | Pseudo-second order methods: 39 | 40 | * LBFGS 41 | 42 | This is not implemented to be exactly equivalent to the pytorch implementation, but is checked against the 2D Rosenbrock function. 43 | 44 | ## Examples 45 | 46 | There is an mnist toy program along with a simple example of adagrad. Whilst the parameters of each method aren't tuned (all are left at their defaults apart from the user-supplied learning rate), the following converges quite nicely: 47 | 48 | ```cli 49 | cargo r -r --example mnist mlp --optim r-adam --epochs 2000 --learning-rate 0.025 50 | ``` 51 | 52 | For even faster training try: 53 | 54 | ```cli 55 | cargo r -r --features cuda --example mnist mlp --optim r-adam --epochs 2000 --learning-rate 0.025 56 | ``` 57 | 58 | to use the cuda backend.
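For library use, the snippet below is a minimal sketch of the optimiser API (mirroring `tests/adadelta_tests.rs`; the toy data and `ParamsAdaDelta` settings are purely illustrative), fitting a single `candle_nn::Linear` layer with AdaDelta:

```rust
use candle_core::{Device, Result, Tensor, Var};
use candle_nn::{Linear, Module, Optimizer};
use candle_optimisers::adadelta::{Adadelta, ParamsAdaDelta};

fn main() -> Result<()> {
    // trainable variables for a single linear layer
    let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
    let b = Var::new(0f32, &Device::Cpu)?;
    let model = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));

    // build the optimiser over the trainable variables
    let params = ParamsAdaDelta {
        lr: 0.004,
        ..Default::default()
    };
    let mut optimiser = Adadelta::new(vec![w.clone(), b.clone()], params)?;

    // toy data generated from y = 3.x1 + x2 - 2
    let xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
    let ys = Tensor::new(&[[5f32], [23.], [-2.], [21.]], &Device::Cpu)?;

    for _step in 0..100 {
        let preds = model.forward(&xs)?;
        let loss = preds.sub(&ys)?.sqr()?.sum_all()?;
        // compute gradients of the loss and update the variables in place
        optimiser.backward_step(&loss)?;
    }
    Ok(())
}
```

The same pattern applies to the other optimisers: each exposes a `Params*` struct and implements the `candle_nn::Optimizer` trait, so only the construction line changes.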
59 | 60 | ## Usage 61 | 62 | ```cli 63 | cargo add --git https://github.com/KGrewal1/optimisers.git candle-optimisers 64 | ``` 65 | 66 | ## Documentation 67 | 68 | Documentation is available on the Rust docs site, [docs.rs](https://docs.rs/candle-optimisers). 69 | -------------------------------------------------------------------------------- /Changelog.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## v0.10.0-alpha.2 (2025-11-15) 4 | 5 | * Bump candle requirement to 0.9.2-alpha.1: this adds support for CUDA 13.0 via cudarc 0.17.1+ 6 | * Use published alpha versions from crates.io instead of git revisions, allowing this version to be published to crates.io 7 | 8 | ## v0.10.0-alpha.1 (2025-08-12) 9 | 10 | * Prerelease for candle 0.10.0 compatibility 11 | * Remove failing codecov 12 | * Clippy lints 13 | 14 | ## v0.9.0 (2024-11-18) 15 | 16 | * Bump candle requirement to 0.9.0: this is considered a breaking change due to the reliance of this library on candle-core and candle-nn 17 | 18 | ## v0.8.0 (2024-11-18) 19 | 20 | * Bump candle requirement to 0.8.0: this is considered a breaking change due to the reliance of this library on candle-core and candle-nn 21 | 22 | ## v0.5.0 (2024-02-28) 23 | 24 | * Bump candle requirement to 0.5.0: this is considered a breaking change due to the reliance of this library on candle-core and candle-nn 25 | * Internal changes for LBFGS line search 26 | 27 | ## v0.4.0 (2024-02-28) 28 | 29 | * Bump candle requirement to 0.4.0: this is considered a breaking change due to the reliance of this library on candle-core and candle-nn 30 | * Explicit reliance on the candle crates hosted on crates.io: as cargo does not support git dependencies in published crates, this library now points only to the crates.io releases (previously cargo would default to crates.io instead of the git repo anyway: if the git repo is specifically desired this can be obtained by patching the `Cargo.toml` to point at the candle repo) 31 | * Remove intel-mkl feature: features in this library are mainly used for running examples: any code that uses this library should instead use the features directly from the candle crates 32 | 33 | ## v0.3.2 (2024-01-07) 34 | 35 | * move directional evaluation into strong wolfe 36 | * fix strong wolfe condition when used with weight decay 37 | 38 | ## v0.3.1 (2023-12-20) 39 | 40 | * Improved Documentation 41 | * Add ability to set more optimiser parameters (see issue regarding LR schedulers in `candle`) 42 | * All params are now `Clone`, `PartialEq` and `PartialOrd` 43 | 44 | ## v0.3.0 (2023-12-07) 45 | 46 | * Renamed to candle-optimisers for release on crates.io 47 | * Added fuller documentation 48 | * Added decoupled weight decay for SGD 49 | 50 | ## v0.2.1 (2023-12-06) 51 | 52 | * Added decoupled weight decay for all adaptive methods 53 | 54 | ## v0.2.0 (2023-12-06) 55 | 56 | * changed weight decay to `Option` type as opposed to checking for 0 57 | * made `enum` for decoupled weight decay and for momentum 58 | * added weight decay to LBFGS 59 | 60 | ## v0.1.x 61 | 62 | * Initial release and addition of features 63 | -------------------------------------------------------------------------------- /benches/mnist_bench.rs: -------------------------------------------------------------------------------- 1 | use candle_core::Result as CResult; 2 | use candle_datasets::vision::Dataset; 3 | use candle_optimisers::{ 4 | adadelta::Adadelta, adagrad::Adagrad, adam::Adam, adamax::Adamax, esgd::SGD, nadam::NAdam, 5 | radam::RAdam, rmsprop::RMSprop, 6
| }; 7 | use criterion::{criterion_group, criterion_main, Criterion}; 8 | use training::Mlp; 9 | 10 | // mod models; 11 | // mod optim; 12 | mod training; 13 | 14 | fn load_data() -> CResult { 15 | candle_datasets::vision::mnist::load() 16 | } 17 | 18 | #[allow(clippy::missing_panics_doc)] 19 | pub fn criterion_benchmark_std(c: &mut Criterion) { 20 | let mut group = c.benchmark_group("std-optimisers"); 21 | let m = &load_data().expect("Failed to load data"); 22 | // let m = Rc::new(m); 23 | 24 | group.significance_level(0.1).sample_size(100); 25 | group.bench_function("adadelta", |b| { 26 | b.iter(|| { 27 | training::run_training::(m).expect("Failed to setup training"); 28 | }); 29 | }); 30 | group.bench_function("adagrad", |b| { 31 | b.iter(|| { 32 | training::run_training::(m).expect("Failed to setup training"); 33 | }); 34 | }); 35 | group.bench_function("adam", |b| { 36 | b.iter(|| training::run_training::(m).expect("Failed to setup training")); 37 | }); 38 | group.bench_function("adamax", |b| { 39 | b.iter(|| { 40 | training::run_training::(m).expect("Failed to setup training"); 41 | }); 42 | }); 43 | group.bench_function("esgd", |b| { 44 | b.iter(|| { 45 | training::run_training::(m).expect("Failed to setup training"); 46 | }); 47 | }); 48 | group.bench_function("nadam", |b| { 49 | b.iter(|| { 50 | training::run_training::(m).expect("Failed to setup training"); 51 | }); 52 | }); 53 | group.bench_function("radam", |b| { 54 | b.iter(|| { 55 | training::run_training::(m).expect("Failed to setup training"); 56 | }); 57 | }); 58 | group.bench_function("rmsprop", |b| { 59 | b.iter(|| { 60 | training::run_training::(m).expect("Failed to setup training"); 61 | }); 62 | }); 63 | 64 | group.finish(); 65 | } 66 | 67 | #[allow(clippy::missing_panics_doc)] 68 | pub fn criterion_benchmark_lbfgs(c: &mut Criterion) { 69 | let mut group = c.benchmark_group("lbfgs-optimser"); 70 | let m = load_data().expect("Failed to load data"); 71 | // let m = Rc::new(m); 72 | 73 | group.significance_level(0.1).sample_size(10); 74 | 75 | group.bench_function("lbfgs", |b| { 76 | b.iter(|| training::run_lbfgs_training::(&m).expect("Failed to setup training")); 77 | }); 78 | 79 | group.finish(); 80 | } 81 | 82 | criterion_group!(benches, criterion_benchmark_std, criterion_benchmark_lbfgs); 83 | criterion_main!(benches); 84 | -------------------------------------------------------------------------------- /examples/mnist/training.rs: -------------------------------------------------------------------------------- 1 | use crate::{models::Model, optim::Optim, parse_cli::TrainingArgs}; 2 | use candle_core::{DType, D}; 3 | use candle_nn::{loss, ops, VarBuilder, VarMap}; 4 | 5 | #[allow(clippy::module_name_repetitions)] 6 | pub fn training_loop( 7 | m: candle_datasets::vision::Dataset, 8 | args: &TrainingArgs, 9 | ) -> anyhow::Result<()> { 10 | // check to see if cuda device availabke 11 | let dev = candle_core::Device::cuda_if_available(0)?; 12 | println!("Training on device {:?}", dev); 13 | 14 | // get the labels from the dataset 15 | let train_labels = m.train_labels; 16 | // get the input from the dataset and put on device 17 | let train_images = m.train_images.to_device(&dev)?; 18 | // get the training labels on the device 19 | let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?; 20 | 21 | // create a new variable store 22 | let mut varmap = VarMap::new(); 23 | // create a new variable builder 24 | let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev); 25 | // create model from variables 26 | 
let model = M::new(vs.clone())?; 27 | 28 | // see if there are pretrained weights to load 29 | if let Some(load) = &args.load { 30 | println!("loading weights from {load}"); 31 | varmap.load(load)?; 32 | } 33 | 34 | // create an optimiser 35 | let mut optimiser = O::new(varmap.all_vars(), args.learning_rate)?; 36 | // candle_nn::SGD::new(varmap.all_vars(), args.learning_rate)?; 37 | // load the test images 38 | let test_images = m.test_images.to_device(&dev)?; 39 | // load the test labels 40 | let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?; 41 | // loop for model optimisation 42 | for epoch in 0..args.epochs { 43 | // get log probabilities of results 44 | let logits = model.forward(&train_images)?; 45 | // softmax the log probabilities 46 | let log_sm = ops::log_softmax(&logits, D::Minus1)?; 47 | // get the loss 48 | let loss = loss::nll(&log_sm, &train_labels)?; 49 | // step the tensors by backpropagating the loss 50 | optimiser.back_step(&loss)?; 51 | 52 | // get the log probabilities of the test images 53 | let test_logits = model.forward(&test_images)?; 54 | // get the sum of the correct predictions 55 | let sum_ok = test_logits 56 | .argmax(D::Minus1)? 57 | .eq(&test_labels)? 58 | .to_dtype(DType::F32)? 59 | .sum_all()? 60 | .to_scalar::()?; 61 | // get the accuracy on the test set 62 | #[allow(clippy::cast_precision_loss)] 63 | let test_accuracy = sum_ok / test_labels.dims1()? as f32; 64 | println!( 65 | "{:4} train loss: {:8.5} test acc: {:5.2}%", 66 | epoch + 1, 67 | loss.to_scalar::()?, 68 | 100. * test_accuracy 69 | ); 70 | } 71 | if let Some(save) = &args.save { 72 | println!("saving trained weights in {save}"); 73 | varmap.save(save)?; 74 | } 75 | Ok(()) 76 | } 77 | -------------------------------------------------------------------------------- /examples/mnist/main.rs: -------------------------------------------------------------------------------- 1 | use clap::Parser; 2 | 3 | mod models; 4 | mod optim; 5 | mod parse_cli; 6 | mod training; 7 | 8 | use models::{LinearModel, Mlp}; 9 | 10 | use candle_optimisers::adagrad::Adagrad; 11 | use candle_optimisers::adamax::Adamax; 12 | use candle_optimisers::esgd::SGD; 13 | use candle_optimisers::nadam::NAdam; 14 | use candle_optimisers::radam::RAdam; 15 | use candle_optimisers::rmsprop::RMSprop; 16 | use candle_optimisers::{adadelta::Adadelta, adam::Adam}; 17 | 18 | use parse_cli::{Args, TrainingArgs, WhichModel, WhichOptim}; 19 | use training::training_loop; 20 | pub fn main() -> anyhow::Result<()> { 21 | let args = Args::parse(); 22 | // Load the dataset 23 | let m = if let Some(directory) = args.local_mnist { 24 | candle_datasets::vision::mnist::load_dir(directory)? 25 | } else { 26 | candle_datasets::vision::mnist::load()? 
27 | }; 28 | println!("train-images: {:?}", m.train_images.shape()); 29 | println!("train-labels: {:?}", m.train_labels.shape()); 30 | println!("test-images: {:?}", m.test_images.shape()); 31 | println!("test-labels: {:?}", m.test_labels.shape()); 32 | 33 | let default_learning_rate = match args.model { 34 | WhichModel::Linear => 1., 35 | WhichModel::Mlp => 0.05, 36 | }; 37 | let training_args = TrainingArgs { 38 | epochs: args.epochs, 39 | learning_rate: args.learning_rate.unwrap_or(default_learning_rate), 40 | load: args.load, 41 | save: args.save, 42 | }; 43 | 44 | match args.optim { 45 | WhichOptim::Adadelta => match args.model { 46 | WhichModel::Linear => training_loop::(m, &training_args), 47 | WhichModel::Mlp => training_loop::(m, &training_args), 48 | }, 49 | WhichOptim::Adagrad => match args.model { 50 | WhichModel::Linear => training_loop::(m, &training_args), 51 | WhichModel::Mlp => training_loop::(m, &training_args), 52 | }, 53 | WhichOptim::Adamax => match args.model { 54 | WhichModel::Linear => training_loop::(m, &training_args), 55 | WhichModel::Mlp => training_loop::(m, &training_args), 56 | }, 57 | WhichOptim::Sgd => match args.model { 58 | WhichModel::Linear => training_loop::(m, &training_args), 59 | WhichModel::Mlp => training_loop::(m, &training_args), 60 | }, 61 | WhichOptim::NAdam => match args.model { 62 | WhichModel::Linear => training_loop::(m, &training_args), 63 | WhichModel::Mlp => training_loop::(m, &training_args), 64 | }, 65 | WhichOptim::RAdam => match args.model { 66 | WhichModel::Linear => training_loop::(m, &training_args), 67 | WhichModel::Mlp => training_loop::(m, &training_args), 68 | }, 69 | WhichOptim::Rms => match args.model { 70 | WhichModel::Linear => training_loop::(m, &training_args), 71 | WhichModel::Mlp => training_loop::(m, &training_args), 72 | }, 73 | WhichOptim::Adam => match args.model { 74 | WhichModel::Linear => training_loop::(m, &training_args), 75 | WhichModel::Mlp => training_loop::(m, &training_args), 76 | }, 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /examples/mnist/optim.rs: -------------------------------------------------------------------------------- 1 | use candle_core::{Result, Tensor, Var}; 2 | use candle_nn::Optimizer; 3 | use candle_optimisers::{ 4 | adadelta::{Adadelta, ParamsAdaDelta}, 5 | adagrad::{Adagrad, ParamsAdaGrad}, 6 | adam::{Adam, ParamsAdam}, 7 | adamax::{Adamax, ParamsAdaMax}, 8 | esgd::{ParamsSGD, SGD}, 9 | nadam::{NAdam, ParamsNAdam}, 10 | radam::{ParamsRAdam, RAdam}, 11 | rmsprop::{ParamsRMSprop, RMSprop}, 12 | }; 13 | 14 | pub trait Optim: Sized { 15 | fn new(vars: Vec, lr: f64) -> Result; 16 | fn back_step(&mut self, loss: &Tensor) -> Result<()>; 17 | } 18 | 19 | impl Optim for Adadelta { 20 | fn new(vars: Vec, lr: f64) -> Result { 21 | ::new( 22 | vars, 23 | ParamsAdaDelta { 24 | lr, 25 | ..Default::default() 26 | }, 27 | ) 28 | } 29 | 30 | fn back_step(&mut self, loss: &Tensor) -> Result<()> { 31 | self.backward_step(loss) 32 | } 33 | } 34 | 35 | impl Optim for Adagrad { 36 | fn new(vars: Vec, lr: f64) -> Result { 37 | ::new( 38 | vars, 39 | ParamsAdaGrad { 40 | lr, 41 | ..Default::default() 42 | }, 43 | ) 44 | } 45 | 46 | fn back_step(&mut self, loss: &Tensor) -> Result<()> { 47 | self.backward_step(loss) 48 | } 49 | } 50 | 51 | impl Optim for Adamax { 52 | fn new(vars: Vec, lr: f64) -> Result { 53 | ::new( 54 | vars, 55 | ParamsAdaMax { 56 | lr, 57 | ..Default::default() 58 | }, 59 | ) 60 | } 61 | 62 | fn back_step(&mut self, loss: &Tensor) 
-> Result<()> { 63 | self.backward_step(loss) 64 | } 65 | } 66 | 67 | impl Optim for SGD { 68 | fn new(vars: Vec, lr: f64) -> Result { 69 | ::new( 70 | vars, 71 | ParamsSGD { 72 | lr, 73 | ..Default::default() 74 | }, 75 | ) 76 | } 77 | 78 | fn back_step(&mut self, loss: &Tensor) -> Result<()> { 79 | self.backward_step(loss) 80 | } 81 | } 82 | 83 | impl Optim for NAdam { 84 | fn new(vars: Vec, lr: f64) -> Result { 85 | ::new( 86 | vars, 87 | ParamsNAdam { 88 | lr, 89 | ..Default::default() 90 | }, 91 | ) 92 | } 93 | 94 | fn back_step(&mut self, loss: &Tensor) -> Result<()> { 95 | self.backward_step(loss) 96 | } 97 | } 98 | 99 | impl Optim for RAdam { 100 | fn new(vars: Vec, lr: f64) -> Result { 101 | ::new( 102 | vars, 103 | ParamsRAdam { 104 | lr, 105 | ..Default::default() 106 | }, 107 | ) 108 | } 109 | 110 | fn back_step(&mut self, loss: &Tensor) -> Result<()> { 111 | self.backward_step(loss) 112 | } 113 | } 114 | 115 | impl Optim for RMSprop { 116 | fn new(vars: Vec, lr: f64) -> Result { 117 | ::new( 118 | vars, 119 | ParamsRMSprop { 120 | lr, 121 | ..Default::default() 122 | }, 123 | ) 124 | } 125 | 126 | fn back_step(&mut self, loss: &Tensor) -> Result<()> { 127 | self.backward_step(loss) 128 | } 129 | } 130 | 131 | impl Optim for Adam { 132 | fn new(vars: Vec, lr: f64) -> Result { 133 | ::new( 134 | vars, 135 | ParamsAdam { 136 | lr, 137 | ..Default::default() 138 | }, 139 | ) 140 | } 141 | 142 | fn back_step(&mut self, loss: &Tensor) -> Result<()> { 143 | self.backward_step(loss) 144 | } 145 | } 146 | -------------------------------------------------------------------------------- /tests/adadelta_tests.rs: -------------------------------------------------------------------------------- 1 | use candle_core::test_utils::{to_vec0_round, to_vec2_round}; 2 | 3 | use anyhow::Result; 4 | use candle_core::{Device, Tensor, Var}; 5 | use candle_nn::{Linear, Module, Optimizer}; 6 | use candle_optimisers::{ 7 | adadelta::{Adadelta, ParamsAdaDelta}, 8 | Decay, 9 | }; 10 | 11 | /* The results of this test have been checked against the following PyTorch code. 12 | import torch 13 | from torch import optim 14 | 15 | w_gen = torch.tensor([[3., 1.]]) 16 | b_gen = torch.tensor([-2.]) 17 | 18 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 19 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 20 | 21 | m = torch.nn.Linear(2, 1) 22 | with torch.no_grad(): 23 | m.weight.zero_() 24 | m.bias.zero_() 25 | optimiser = optim.Adadelta(m.parameters(), lr=0.004) 26 | # optimiser.zero_grad() 27 | for _step in range(100): 28 | optimiser.zero_grad() 29 | ys = m(sample_xs) 30 | loss = ((ys - sample_ys)**2).sum() 31 | loss.backward() 32 | optimiser.step() 33 | # print("Optimizer state begin") 34 | # print(optimiser.state) 35 | # print("Optimizer state end") 36 | print(m.weight) 37 | print(m.bias) 38 | */ 39 | #[test] 40 | fn adadelta_test() -> Result<()> { 41 | // Generate some linear data, y = 3.x1 + x2 - 2. 42 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 43 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 44 | let gen = Linear::new(w_gen, Some(b_gen)); 45 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 46 | let sample_ys = gen.forward(&sample_xs)?; 47 | 48 | let params = ParamsAdaDelta { 49 | lr: 0.004, 50 | rho: 0.9, 51 | weight_decay: None, 52 | eps: 1e-6, 53 | }; 54 | // Now use backprop to run a linear regression between samples and get the coefficients back. 
55 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 56 | let b = Var::new(0f32, &Device::Cpu)?; 57 | let mut n_sgd = Adadelta::new(vec![w.clone(), b.clone()], params)?; 58 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 59 | for _step in 0..100 { 60 | let ys = lin.forward(&sample_xs)?; 61 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 62 | n_sgd.backward_step(&loss)?; 63 | } 64 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.0016, 0.0016]]); 65 | assert_eq!(to_vec0_round(&b, 4)?, 0.0016); 66 | Ok(()) 67 | } 68 | 69 | #[test] 70 | fn adadelta_weight_decay_test() -> Result<()> { 71 | // Generate some linear data, y = 3.x1 + x2 - 2. 72 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 73 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 74 | let gen = Linear::new(w_gen, Some(b_gen)); 75 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 76 | let sample_ys = gen.forward(&sample_xs)?; 77 | 78 | let params = ParamsAdaDelta { 79 | lr: 0.004, 80 | rho: 0.9, 81 | weight_decay: Some(Decay::WeightDecay(0.8)), 82 | eps: 1e-6, 83 | }; 84 | // Now use backprop to run a linear regression between samples and get the coefficients back. 85 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 86 | let b = Var::new(0f32, &Device::Cpu)?; 87 | let mut n_sgd = Adadelta::new(vec![w.clone(), b.clone()], params)?; 88 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 89 | for _step in 0..100 { 90 | let ys = lin.forward(&sample_xs)?; 91 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 92 | n_sgd.backward_step(&loss)?; 93 | } 94 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.0016, 0.0016]]); 95 | assert_eq!(to_vec0_round(&b, 4)?, 0.0016); 96 | Ok(()) 97 | } 98 | 99 | //------------------------------------------------------------------------- 100 | // THIS IS NOT TESTED AGAINST PYTORCH 101 | // AS PYTORCH DOES NOT HAVE DECOUPLED WEIGHT DECAY FOR ADADELTA 102 | // ------------------------------------------------------------------------ 103 | #[test] 104 | fn adadelta_decoupled_weight_decay_test() -> Result<()> { 105 | // Generate some linear data, y = 3.x1 + x2 - 2. 106 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 107 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 108 | let gen = Linear::new(w_gen, Some(b_gen)); 109 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 110 | let sample_ys = gen.forward(&sample_xs)?; 111 | 112 | let params = ParamsAdaDelta { 113 | lr: 0.004, 114 | rho: 0.9, 115 | weight_decay: Some(Decay::DecoupledWeightDecay(0.8)), 116 | eps: 1e-6, 117 | }; 118 | // Now use backprop to run a linear regression between samples and get the coefficients back. 
119 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 120 | let b = Var::new(0f32, &Device::Cpu)?; 121 | let mut n_sgd = Adadelta::new(vec![w.clone(), b.clone()], params)?; 122 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 123 | for _step in 0..100 { 124 | let ys = lin.forward(&sample_xs)?; 125 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 126 | n_sgd.backward_step(&loss)?; 127 | } 128 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.0014, 0.0014]]); 129 | assert_eq!(to_vec0_round(&b, 4)?, 0.0014); 130 | Ok(()) 131 | } 132 | -------------------------------------------------------------------------------- /tests/radam-tests.rs: -------------------------------------------------------------------------------- 1 | use candle_core::test_utils::{to_vec0_round, to_vec2_round}; 2 | 3 | use anyhow::Result; 4 | use candle_core::{Device, Tensor, Var}; 5 | use candle_nn::{Linear, Module, Optimizer}; 6 | use candle_optimisers::{ 7 | radam::{ParamsRAdam, RAdam}, 8 | Decay, 9 | }; 10 | 11 | /* The results of this test have been checked against the following PyTorch code. 12 | import torch 13 | from torch import optim 14 | 15 | w_gen = torch.tensor([[3., 1.]]) 16 | b_gen = torch.tensor([-2.]) 17 | 18 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 19 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 20 | 21 | m = torch.nn.Linear(2, 1) 22 | with torch.no_grad(): 23 | m.weight.zero_() 24 | m.bias.zero_() 25 | optimiser = optim.RAdam(m.parameters()) 26 | # optimiser.zero_grad() 27 | for _step in range(100): 28 | optimiser.zero_grad() 29 | ys = m(sample_xs) 30 | loss = ((ys - sample_ys)**2).sum() 31 | loss.backward() 32 | optimiser.step() 33 | # print("Optimizer state begin") 34 | # print(optimiser.state) 35 | # print("Optimizer state end") 36 | print(m.weight) 37 | print(m.bias) 38 | */ 39 | #[test] 40 | fn radam_test() -> Result<()> { 41 | // Generate some linear data, y = 3.x1 + x2 - 2. 42 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 43 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 44 | let gen = Linear::new(w_gen, Some(b_gen)); 45 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 46 | let sample_ys = gen.forward(&sample_xs)?; 47 | 48 | let params = ParamsRAdam::default(); 49 | // Now use backprop to run a linear regression between samples and get the coefficients back. 50 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 51 | let b = Var::new(0f32, &Device::Cpu)?; 52 | let mut n_sgd = RAdam::new(vec![w.clone(), b.clone()], params)?; 53 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 54 | for _step in 0..100 { 55 | let ys = lin.forward(&sample_xs)?; 56 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 57 | n_sgd.backward_step(&loss)?; 58 | } 59 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.2128, 1.2819]]); 60 | assert_eq!(to_vec0_round(&b, 4)?, 0.2923); 61 | Ok(()) 62 | } 63 | 64 | /* The results of this test have been checked against the following PyTorch code. 
65 | import torch 66 | from torch import optim 67 | 68 | w_gen = torch.tensor([[3., 1.]]) 69 | b_gen = torch.tensor([-2.]) 70 | 71 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 72 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 73 | 74 | m = torch.nn.Linear(2, 1) 75 | with torch.no_grad(): 76 | m.weight.zero_() 77 | m.bias.zero_() 78 | optimiser = optim.RAdam(m.parameters(), weight_decay = 0.4) 79 | # optimiser.zero_grad() 80 | for _step in range(100): 81 | optimiser.zero_grad() 82 | ys = m(sample_xs) 83 | loss = ((ys - sample_ys)**2).sum() 84 | loss.backward() 85 | optimiser.step() 86 | # print("Optimizer state begin") 87 | # print(optimiser.state) 88 | # print("Optimizer state end") 89 | print(m.weight) 90 | print(m.bias) 91 | */ 92 | #[test] 93 | fn radam_weight_decay_test() -> Result<()> { 94 | // Generate some linear data, y = 3.x1 + x2 - 2. 95 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 96 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 97 | let gen = Linear::new(w_gen, Some(b_gen)); 98 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 99 | let sample_ys = gen.forward(&sample_xs)?; 100 | 101 | let params = ParamsRAdam { 102 | weight_decay: Some(Decay::WeightDecay(0.4)), 103 | ..Default::default() 104 | }; 105 | // Now use backprop to run a linear regression between samples and get the coefficients back. 106 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 107 | let b = Var::new(0f32, &Device::Cpu)?; 108 | let mut n_sgd = RAdam::new(vec![w.clone(), b.clone()], params)?; 109 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 110 | for _step in 0..100 { 111 | let ys = lin.forward(&sample_xs)?; 112 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 113 | n_sgd.backward_step(&loss)?; 114 | } 115 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.2117, 1.2812]]); 116 | assert_eq!(to_vec0_round(&b, 4)?, 0.2921); 117 | Ok(()) 118 | } 119 | 120 | //------------------------------------------------------------------------- 121 | // THIS IS NOT TESTED AGAINST PYTORCH 122 | // AS PYTORCH DOES NOT HAVE DECOUPLED WEIGHT DECAY FOR RADAM 123 | // ------------------------------------------------------------------------ 124 | #[test] 125 | fn radam_decoupled_weight_decay_test() -> Result<()> { 126 | // Generate some linear data, y = 3.x1 + x2 - 2. 127 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 128 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 129 | let gen = Linear::new(w_gen, Some(b_gen)); 130 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 131 | let sample_ys = gen.forward(&sample_xs)?; 132 | 133 | let params = ParamsRAdam { 134 | weight_decay: Some(Decay::DecoupledWeightDecay(0.4)), 135 | ..Default::default() 136 | }; 137 | // Now use backprop to run a linear regression between samples and get the coefficients back. 
138 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 139 | let b = Var::new(0f32, &Device::Cpu)?; 140 | let mut n_sgd = RAdam::new(vec![w.clone(), b.clone()], params)?; 141 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 142 | for _step in 0..100 { 143 | let ys = lin.forward(&sample_xs)?; 144 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 145 | n_sgd.backward_step(&loss)?; 146 | } 147 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.1294, 1.2331]]); 148 | assert_eq!(to_vec0_round(&b, 4)?, 0.2818); 149 | Ok(()) 150 | } 151 | -------------------------------------------------------------------------------- /tests/adamax_tests.rs: -------------------------------------------------------------------------------- 1 | use candle_core::test_utils::{to_vec0_round, to_vec2_round}; 2 | 3 | use anyhow::Result; 4 | use candle_core::{Device, Tensor, Var}; 5 | use candle_nn::{Linear, Module, Optimizer}; 6 | use candle_optimisers::adamax::{Adamax, ParamsAdaMax}; 7 | 8 | /* The results of this test have been checked against the following PyTorch code. 9 | import torch 10 | from torch import optim 11 | 12 | w_gen = torch.tensor([[3., 1.]]) 13 | b_gen = torch.tensor([-2.]) 14 | 15 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 16 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 17 | 18 | m = torch.nn.Linear(2, 1) 19 | with torch.no_grad(): 20 | m.weight.zero_() 21 | m.bias.zero_() 22 | optimiser = optim.Adamax(m.parameters(), lr=0.004) 23 | # optimiser.zero_grad() 24 | for _step in range(100): 25 | optimiser.zero_grad() 26 | ys = m(sample_xs) 27 | loss = ((ys - sample_ys)**2).sum() 28 | loss.backward() 29 | optimiser.step() 30 | # print("Optimizer state begin") 31 | # print(optimiser.state) 32 | # print("Optimizer state end") 33 | print(m.weight) 34 | print(m.bias) 35 | */ 36 | #[test] 37 | fn adamax_test() -> Result<()> { 38 | // Generate some linear data, y = 3.x1 + x2 - 2. 39 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 40 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 41 | let gen = Linear::new(w_gen, Some(b_gen)); 42 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 43 | let sample_ys = gen.forward(&sample_xs)?; 44 | 45 | let params = ParamsAdaMax { 46 | lr: 0.004, 47 | ..Default::default() 48 | }; 49 | // Now use backprop to run a linear regression between samples and get the coefficients back. 50 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 51 | let b = Var::new(0f32, &Device::Cpu)?; 52 | let mut n_sgd = Adamax::new(vec![w.clone(), b.clone()], params)?; 53 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 54 | for _step in 0..100 { 55 | let ys = lin.forward(&sample_xs)?; 56 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 57 | n_sgd.backward_step(&loss)?; 58 | } 59 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.3895, 0.3450]]); 60 | assert_eq!(to_vec0_round(&b, 4)?, 0.3643); 61 | Ok(()) 62 | } 63 | 64 | /* The results of this test have been checked against the following PyTorch code. 
65 | import torch 66 | from torch import optim 67 | 68 | w_gen = torch.tensor([[3., 1.]]) 69 | b_gen = torch.tensor([-2.]) 70 | 71 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 72 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 73 | 74 | m = torch.nn.Linear(2, 1) 75 | with torch.no_grad(): 76 | m.weight.zero_() 77 | m.bias.zero_() 78 | optimiser = optim.Adamax(m.parameters(), lr=0.004, weight_decay = 0.6) 79 | # optimiser.zero_grad() 80 | for _step in range(100): 81 | optimiser.zero_grad() 82 | ys = m(sample_xs) 83 | loss = ((ys - sample_ys)**2).sum() 84 | loss.backward() 85 | optimiser.step() 86 | # print("Optimizer state begin") 87 | # print(optimiser.state) 88 | # print("Optimizer state end") 89 | print(m.weight) 90 | print(m.bias) 91 | */ 92 | #[test] 93 | fn adamax_weight_decay_test() -> Result<()> { 94 | // Generate some linear data, y = 3.x1 + x2 - 2. 95 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 96 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 97 | let gen = Linear::new(w_gen, Some(b_gen)); 98 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 99 | let sample_ys = gen.forward(&sample_xs)?; 100 | 101 | let params = ParamsAdaMax { 102 | lr: 0.004, 103 | weight_decay: Some(candle_optimisers::Decay::WeightDecay(0.6)), 104 | ..Default::default() 105 | }; 106 | // Now use backprop to run a linear regression between samples and get the coefficients back. 107 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 108 | let b = Var::new(0f32, &Device::Cpu)?; 109 | let mut n_sgd = Adamax::new(vec![w.clone(), b.clone()], params)?; 110 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 111 | for _step in 0..100 { 112 | let ys = lin.forward(&sample_xs)?; 113 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 114 | n_sgd.backward_step(&loss)?; 115 | } 116 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.3894, 0.3450]]); 117 | assert_eq!(to_vec0_round(&b, 4)?, 0.3639); 118 | Ok(()) 119 | } 120 | 121 | //------------------------------------------------------------------------- 122 | // THIS IS NOT TESTED AGAINST PYTORCH 123 | // AS PYTORCH DOES NOT HAVE DECOUPLED WEIGHT DECAY FOR ADAMAX 124 | // ------------------------------------------------------------------------ 125 | #[test] 126 | fn adamax_decoupled_weight_decay_test() -> Result<()> { 127 | // Generate some linear data, y = 3.x1 + x2 - 2. 128 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 129 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 130 | let gen = Linear::new(w_gen, Some(b_gen)); 131 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 132 | let sample_ys = gen.forward(&sample_xs)?; 133 | 134 | let params = ParamsAdaMax { 135 | lr: 0.004, 136 | weight_decay: Some(candle_optimisers::Decay::DecoupledWeightDecay(0.6)), 137 | ..Default::default() 138 | }; 139 | // Now use backprop to run a linear regression between samples and get the coefficients back. 
140 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 141 | let b = Var::new(0f32, &Device::Cpu)?; 142 | let mut n_sgd = Adamax::new(vec![w.clone(), b.clone()], params)?; 143 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 144 | for _step in 0..100 { 145 | let ys = lin.forward(&sample_xs)?; 146 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 147 | n_sgd.backward_step(&loss)?; 148 | } 149 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.3481, 0.3095]]); 150 | assert_eq!(to_vec0_round(&b, 4)?, 0.3263); 151 | Ok(()) 152 | } 153 | -------------------------------------------------------------------------------- /paper/paper.bib: -------------------------------------------------------------------------------- 1 | @article{adadelta, 2 | author = {Matthew D. Zeiler}, 3 | title = {{ADADELTA:} An Adaptive Learning Rate Method}, 4 | journal = {CoRR}, 5 | volume = {abs/1212.5701}, 6 | year = {2012}, 7 | url = {http://arxiv.org/abs/1212.5701}, 8 | eprinttype = {arXiv}, 9 | eprint = {1212.5701}, 10 | timestamp = {Mon, 13 Aug 2018 16:45:57 +0200}, 11 | biburl = {https://dblp.org/rec/journals/corr/abs-1212-5701.bib}, 12 | bibsource = {dblp computer science bibliography, https://dblp.org}, 13 | doi = {10.48550/arXiv.1212.5701} 14 | } 15 | 16 | @article{adagrad, 17 | author = {John Duchi and Elad Hazan and Yoram Singer}, 18 | title = {Adaptive Subgradient Methods for Online Learning and Stochastic Optimization}, 19 | journal = {Journal of Machine Learning Research}, 20 | year = {2011}, 21 | volume = {12}, 22 | number = {61}, 23 | pages = {2121--2159}, 24 | url = {http://jmlr.org/papers/v12/duchi11a.html}, 25 | 26 | } 27 | 28 | @inproceedings{adam, 29 | author = {Diederik P. Kingma and 30 | Jimmy Ba}, 31 | editor = {Yoshua Bengio and 32 | Yann LeCun}, 33 | title = {Adam: {A} Method for Stochastic Optimization}, 34 | booktitle = {3rd International Conference on Learning Representations, {ICLR} 2015, 35 | San Diego, CA, USA, May 7-9, 2015, Conference Track Proceedings}, 36 | year = {2015}, 37 | url = {http://arxiv.org/abs/1412.6980}, 38 | timestamp = {Thu, 25 Jul 2019 14:25:37 +0200}, 39 | biburl = {https://dblp.org/rec/journals/corr/KingmaB14.bib}, 40 | bibsource = {dblp computer science bibliography, https://dblp.org}, 41 | doi = {10.48550/arXiv.1412.6980} 42 | } 43 | 44 | @article{weightdecay, 45 | author = {Ilya Loshchilov and 46 | Frank Hutter}, 47 | title = {Fixing Weight Decay Regularization in Adam}, 48 | journal = {CoRR}, 49 | volume = {abs/1711.05101}, 50 | year = {2017}, 51 | url = {http://arxiv.org/abs/1711.05101}, 52 | eprinttype = {arXiv}, 53 | eprint = {1711.05101}, 54 | timestamp = {Mon, 13 Aug 2018 16:48:18 +0200}, 55 | biburl = {https://dblp.org/rec/journals/corr/abs-1711-05101.bib}, 56 | bibsource = {dblp computer science bibliography, https://dblp.org}, 57 | doi = {10.48550/arXiv.1711.05101} 58 | } 59 | 60 | @inproceedings{amsgrad, 61 | title={On the Convergence of Adam and Beyond}, 62 | author={Sashank J. 
Reddi and Satyen Kale and Sanjiv Kumar}, 63 | booktitle={International Conference on Learning Representations}, 64 | year={2018}, 65 | url={https://openreview.net/forum?id=ryQu7f-RZ} 66 | } 67 | 68 | @inproceedings{nmomentum, 69 | title = {On the importance of initialization and momentum in deep learning}, 70 | author = {Sutskever, Ilya and Martens, James and Dahl, George and Hinton, Geoffrey}, 71 | booktitle = {Proceedings of the 30th International Conference on Machine Learning}, 72 | pages = {1139--1147}, 73 | year = {2013}, 74 | editor = {Dasgupta, Sanjoy and McAllester, David}, 75 | volume = {28}, 76 | number = {3}, 77 | series = {Proceedings of Machine Learning Research}, 78 | address = {Atlanta, Georgia, USA}, 79 | month = {17--19 Jun}, 80 | publisher = {PMLR}, 81 | pdf = {http://proceedings.mlr.press/v28/sutskever13.pdf}, 82 | url = {https://proceedings.mlr.press/v28/sutskever13.html}, 83 | abstract = {Deep and recurrent neural networks (DNNs and RNNs respectively) are powerful models that were considered to be almost impossible to train using stochastic gradient descent with momentum. In this paper, we show that when stochastic gradient descent with momentum uses a well-designed random initialization and a particular type of slowly increasing schedule for the momentum parameter, it can train both DNNs and RNNs (on datasets with long-term dependencies) to levels of performance that were previously achievable only with Hessian-Free optimization. We find that both the initialization and the momentum are crucial since poorly initialized networks cannot be trained with momentum and well-initialized networks perform markedly worse when the momentum is absent or poorly tuned. Our success training these models suggests that previous attempts to train deep and recurrent neural networks from random initializations have likely failed due to poor initialization schemes. Furthermore, carefully tuned momentum methods suffice for dealing with the curvature issues in deep and recurrent network training objectives without the need for sophisticated second-order methods. 
} 84 | } 85 | 86 | 87 | @article{LBFGS, 88 | title={On the limited memory BFGS method for large scale optimization}, 89 | author={Liu, Dong C and Nocedal, Jorge}, 90 | journal={Mathematical programming}, 91 | volume={45}, 92 | number={1-3}, 93 | pages={503--528}, 94 | year={1989}, 95 | publisher={Springer}, 96 | doi = {10.1007/BF01589116} 97 | } 98 | 99 | @inproceedings{nadam, 100 | title = {Incorporating {Nesterov Momentum into Adam}}, 101 | author = {Dozat, Timothy}, 102 | booktitle = {Proceedings of the 4th International Conference on Learning Representations}, 103 | pages = {1--4}, 104 | date = 2016, 105 | url = {https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ}, 106 | biburl = {https://bibbase.org/network/publication/dozat-incorporatingnesterovmomentumintoadam} 107 | } 108 | 109 | @article{radam, 110 | author = {Liyuan Liu and 111 | Haoming Jiang and 112 | Pengcheng He and 113 | Weizhu Chen and 114 | Xiaodong Liu and 115 | Jianfeng Gao and 116 | Jiawei Han}, 117 | title = {On the Variance of the Adaptive Learning Rate and Beyond}, 118 | journal = {CoRR}, 119 | volume = {abs/1908.03265}, 120 | year = {2019}, 121 | url = {http://arxiv.org/abs/1908.03265}, 122 | eprinttype = {arXiv}, 123 | eprint = {1908.03265}, 124 | timestamp = {Mon, 30 May 2022 13:48:56 +0200}, 125 | biburl = {https://dblp.org/rec/journals/corr/abs-1908-03265.bib}, 126 | bibsource = {dblp computer science bibliography, https://dblp.org}, 127 | doi = {10.48550/arXiv.1908.03265} 128 | } 129 | 130 | @misc{rmsprop, 131 | author = {Geoffrey Hinton and 132 | Nitish Srivastava and 133 | Kevin Swersky}, 134 | title = {Neural Networks for Machine Learning}, 135 | year = {2012}, 136 | url = {https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf} 137 | } 138 | 139 | @misc{candle, 140 | title={Candle: a Minimalist ML Framework for Rust}, 141 | author={Laurent Mazar\'e and others}, 142 | year={2023}, 143 | url={https://github.com/huggingface/candle/}, 144 | } -------------------------------------------------------------------------------- /tests/nadam_tests.rs: -------------------------------------------------------------------------------- 1 | use candle_core::test_utils::{to_vec0_round, to_vec2_round}; 2 | 3 | use anyhow::Result; 4 | use candle_core::{Device, Tensor, Var}; 5 | use candle_nn::{Linear, Module, Optimizer}; 6 | use candle_optimisers::{ 7 | nadam::{NAdam, ParamsNAdam}, 8 | Decay, 9 | }; 10 | 11 | /* The results of this test have been checked against the following PyTorch code. 12 | import torch 13 | from torch import optim 14 | 15 | w_gen = torch.tensor([[3., 1.]]) 16 | b_gen = torch.tensor([-2.]) 17 | 18 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 19 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 20 | 21 | m = torch.nn.Linear(2, 1) 22 | with torch.no_grad(): 23 | m.weight.zero_() 24 | m.bias.zero_() 25 | optimiser = optim.NAdam(m.parameters()) 26 | # optimiser.zero_grad() 27 | for _step in range(100): 28 | optimiser.zero_grad() 29 | ys = m(sample_xs) 30 | loss = ((ys - sample_ys)**2).sum() 31 | loss.backward() 32 | optimiser.step() 33 | # print("Optimizer state begin") 34 | # print(optimiser.state) 35 | # print("Optimizer state end") 36 | print(m.weight) 37 | print(m.bias) 38 | */ 39 | #[test] 40 | fn nadam_test() -> Result<()> { 41 | // Generate some linear data, y = 3.x1 + x2 - 2. 
42 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 43 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 44 | let gen = Linear::new(w_gen, Some(b_gen)); 45 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 46 | let sample_ys = gen.forward(&sample_xs)?; 47 | 48 | let params = ParamsNAdam::default(); 49 | // Now use backprop to run a linear regression between samples and get the coefficients back. 50 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 51 | let b = Var::new(0f32, &Device::Cpu)?; 52 | let mut n_sgd = NAdam::new(vec![w.clone(), b.clone()], params)?; 53 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 54 | for _step in 0..100 { 55 | let ys = lin.forward(&sample_xs)?; 56 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 57 | n_sgd.backward_step(&loss)?; 58 | } 59 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.1897, 0.1837]]); 60 | assert_eq!(to_vec0_round(&b, 4)?, 0.1864); 61 | Ok(()) 62 | } 63 | 64 | /* The results of this test have been checked against the following PyTorch code. 65 | import torch 66 | from torch import optim 67 | 68 | w_gen = torch.tensor([[3., 1.]]) 69 | b_gen = torch.tensor([-2.]) 70 | 71 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 72 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 73 | 74 | m = torch.nn.Linear(2, 1) 75 | with torch.no_grad(): 76 | m.weight.zero_() 77 | m.bias.zero_() 78 | optimiser = optim.NAdam(m.parameters(), weight_decay = 0.6) 79 | # optimiser.zero_grad() 80 | for _step in range(100): 81 | optimiser.zero_grad() 82 | ys = m(sample_xs) 83 | loss = ((ys - sample_ys)**2).sum() 84 | loss.backward() 85 | optimiser.step() 86 | # print("Optimizer state begin") 87 | # print(optimiser.state) 88 | # print("Optimizer state end") 89 | print(m.weight) 90 | print(m.bias) 91 | */ 92 | #[test] 93 | fn nadam_weight_decay_test() -> Result<()> { 94 | // Generate some linear data, y = 3.x1 + x2 - 2. 95 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 96 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 97 | let gen = Linear::new(w_gen, Some(b_gen)); 98 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 99 | let sample_ys = gen.forward(&sample_xs)?; 100 | 101 | let params = ParamsNAdam { 102 | weight_decay: Some(Decay::WeightDecay(0.6)), 103 | ..Default::default() 104 | }; 105 | // Now use backprop to run a linear regression between samples and get the coefficients back. 106 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 107 | let b = Var::new(0f32, &Device::Cpu)?; 108 | let mut n_sgd = NAdam::new(vec![w.clone(), b.clone()], params)?; 109 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 110 | for _step in 0..100 { 111 | let ys = lin.forward(&sample_xs)?; 112 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 113 | n_sgd.backward_step(&loss)?; 114 | } 115 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.1897, 0.1837]]); 116 | assert_eq!(to_vec0_round(&b, 4)?, 0.1863); 117 | Ok(()) 118 | } 119 | 120 | /* The results of this test have been checked against the following PyTorch code. 
121 | import torch 122 | from torch import optim 123 | 124 | w_gen = torch.tensor([[3., 1.]]) 125 | b_gen = torch.tensor([-2.]) 126 | 127 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 128 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 129 | 130 | m = torch.nn.Linear(2, 1) 131 | with torch.no_grad(): 132 | m.weight.zero_() 133 | m.bias.zero_() 134 | optimiser = optim.NAdam(m.parameters(), weight_decay = 0.6, decoupled_weight_decay=True) 135 | # optimiser.zero_grad() 136 | for _step in range(100): 137 | optimiser.zero_grad() 138 | ys = m(sample_xs) 139 | loss = ((ys - sample_ys)**2).sum() 140 | loss.backward() 141 | optimiser.step() 142 | # print("Optimizer state begin") 143 | # print(optimiser.state) 144 | # print("Optimizer state end") 145 | print(m.weight) 146 | print(m.bias) 147 | */ 148 | #[test] 149 | fn nadam_decoupled_weight_decay_test() -> Result<()> { 150 | // Generate some linear data, y = 3.x1 + x2 - 2. 151 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 152 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 153 | let gen = Linear::new(w_gen, Some(b_gen)); 154 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 155 | let sample_ys = gen.forward(&sample_xs)?; 156 | 157 | let params = ParamsNAdam { 158 | weight_decay: Some(Decay::DecoupledWeightDecay(0.6)), 159 | ..Default::default() 160 | }; 161 | // Now use backprop to run a linear regression between samples and get the coefficients back. 162 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 163 | let b = Var::new(0f32, &Device::Cpu)?; 164 | let mut n_sgd = NAdam::new(vec![w.clone(), b.clone()], params)?; 165 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 166 | for _step in 0..100 { 167 | let ys = lin.forward(&sample_xs)?; 168 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 169 | n_sgd.backward_step(&loss)?; 170 | } 171 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.1792, 0.1737]]); 172 | assert_eq!(to_vec0_round(&b, 4)?, 0.1762); 173 | Ok(()) 174 | } 175 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | /*! 2 | Optimisers for use with the [candle](https://github.com/huggingface/candle) framework for lightweight machine learning. 
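Most of the optimisers follow the same basic pattern: build a parameter struct, construct the
optimiser from the `Var`s to be optimised, then call `backward_step` with the loss. A minimal
sketch (assuming a single variable and Adadelta with its default hyperparameters; the other
optimisers are constructed in the same way):

```
# use candle_core::{Device, Result, Var};
# use candle_nn::optim::Optimizer;
# use candle_optimisers::adadelta::{Adadelta, ParamsAdaDelta};
# fn sketch() -> Result<()> {
// the variable to be optimised
let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
// optimiser over that variable, using the default hyperparameters
let mut optimiser = Adadelta::new(vec![w.clone()], ParamsAdaDelta::default())?;
// given some scalar `loss` computed from `w`:
// optimiser.backward_step(&loss)?;
# Ok(())
# }
```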
3 | Apart from LBFGS, these all implement the [`candle_nn::optim::Optimizer`] trait from candle-nn 4 | 5 | # Example 6 | 7 | Training an MNIST model using the Adam optimiser 8 | 9 | ``` 10 | # use candle_core::{Result, Tensor}; 11 | # use candle_core::{DType, D}; 12 | # use candle_nn::{loss, ops, VarBuilder, VarMap, optim::Optimizer}; 13 | # use candle_optimisers::{ 14 | # adam::{Adam, ParamsAdam} 15 | # }; 16 | # 17 | # pub trait Model: Sized { 18 | # fn new(vs: VarBuilder) -> Result; 19 | # fn forward(&self, xs: &Tensor) -> Result; 20 | # } 21 | # 22 | # pub fn training_loop( 23 | # m: candle_datasets::vision::Dataset, 24 | # varmap: &VarMap, 25 | # model: M, 26 | # ) -> anyhow::Result<()> { 27 | # // check to see if cuda device availabke 28 | # let dev = candle_core::Device::cuda_if_available(0)?; 29 | # // get the input from the dataset and put on device 30 | # let train_images = m.train_images.to_device(&dev)?; 31 | # // get the training labels on the device 32 | # let train_labels = m.train_labels.to_dtype(DType::U32)?.to_device(&dev)?; 33 | # 34 | # 35 | # // load the test images 36 | # let test_images = m.test_images.to_device(&dev)?; 37 | # // load the test labels 38 | # let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?; 39 | # 40 | // create the Adam optimiser 41 | 42 | // set the learning rate to 0.004 and use the default parameters for everything else 43 | let params = ParamsAdam { 44 | lr: 0.004, 45 | ..Default::default() 46 | }; 47 | // create the optimiser by passing in the variable to be optimised and the parameters 48 | let mut optimiser = Adam::new(varmap.all_vars(), params)?; 49 | 50 | // loop for model optimisation 51 | for epoch in 0..100 { 52 | // run the model forwards 53 | // get log probabilities of results 54 | let logits = model.forward(&train_images)?; 55 | // softmax the log probabilities 56 | let log_sm = ops::log_softmax(&logits, D::Minus1)?; 57 | // get the loss 58 | let loss = loss::nll(&log_sm, &train_labels)?; 59 | // step the tensors by backpropagating the loss 60 | optimiser.backward_step(&loss)?; 61 | 62 | # // get the log probabilities of the test images 63 | # let test_logits = model.forward(&test_images)?; 64 | # // get the sum of the correct predictions 65 | # let sum_ok = test_logits 66 | # .argmax(D::Minus1)? 67 | # .eq(&test_labels)? 68 | # .to_dtype(DType::F32)? 69 | # .sum_all()? 70 | # .to_scalar::()?; 71 | # // get the accuracy on the test set 72 | # #[allow(clippy::cast_precision_loss)] 73 | # let test_accuracy = sum_ok / test_labels.dims1()? as f32; 74 | # println!( 75 | # "{:4} train loss: {:8.5} test acc: {:5.2}%", 76 | # epoch + 1, 77 | # loss.to_scalar::()?, 78 | # 100. 
* test_accuracy 79 | # ); 80 | } 81 | Ok(()) 82 | # } 83 | ``` 84 | */ 85 | 86 | use std::fmt::Debug; 87 | 88 | use candle_core::Result as CResult; 89 | use candle_core::Tensor; 90 | use candle_core::Var; 91 | pub mod adadelta; 92 | pub mod adagrad; 93 | pub mod adam; 94 | pub mod adamax; 95 | pub mod esgd; 96 | pub mod lbfgs; 97 | pub mod nadam; 98 | pub mod radam; 99 | pub mod rmsprop; 100 | 101 | /// Trait for optimisers to expose their parameters 102 | pub trait OptimParams: candle_nn::optim::Optimizer { 103 | /// get the current parameters of the Optimiser 104 | fn params(&self) -> &Self::Config; 105 | /// set the current parameters of the Optimiser 106 | fn set_params(&mut self, config: Self::Config); 107 | } 108 | 109 | /// Trait for Models: this is needed for optimisers that require the ability to calculate the loss 110 | /// such as LBFGS 111 | pub trait Model: Sized { 112 | /// get the loss of the model 113 | fn loss(&self) -> CResult; //, xs: &Tensor, ys: &Tensor 114 | } 115 | 116 | /// trait for optimisers like LBFGS that need the ability to calculate the loss 117 | /// and its gradient 118 | pub trait LossOptimizer<'a, M: Model>: Sized { 119 | /// type of the optimiser configuration 120 | type Config: Sized; 121 | /// create a new optimiser from a Vec of variables, setup parameters and a model 122 | fn new(vs: Vec, params: Self::Config, model: &'a M) -> CResult; 123 | /// take a step of the optimiser 124 | fn backward_step(&mut self, loss: &Tensor) -> CResult; //, xs: &Tensor, ys: &Tensor 125 | /// get the current learning rate 126 | fn learning_rate(&self) -> f64; 127 | /// set the learning rate 128 | fn set_learning_rate(&mut self, lr: f64); 129 | /// get the a vec of the variables being optimised 130 | fn into_inner(self) -> Vec; 131 | /// create a new optimiser from a slice of variables, setup parameters and a model 132 | fn from_slice(vars: &[&Var], config: Self::Config, model: &'a M) -> CResult { 133 | let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect(); 134 | Self::new(vars, config, model) 135 | } 136 | } 137 | 138 | /// Outcomes of an optimiser step for methods such as LBFGS 139 | #[derive(Debug)] 140 | pub enum ModelOutcome { 141 | /// The model took a step and the loss decreased 142 | /// contains next loss and the number of func evals 143 | Stepped(Tensor, usize), 144 | /// The model has converged and the loss has not changed 145 | /// contains loss and the number of func evals 146 | Converged(Tensor, usize), 147 | } 148 | 149 | /// Method of weight decay to use 150 | #[derive(Clone, Copy, Debug, PartialEq, PartialOrd)] 151 | pub enum Decay { 152 | /// Weight decay regularisation to penalise large weights 153 | /// 154 | /// The gradient is transformed as 155 | /// $$ g_{t} \\gets g_{t} + \\lambda \\theta_{t-1}$$ 156 | /// 157 | /// This is equivalent to an L2 regularisation term in the loss adding $\\frac{\\lambda}{2}||\theta||_{2}^{2}$ but avoids autodifferentiation 158 | /// of the L2 term 159 | WeightDecay(f64), 160 | /// Decoupled weight decay as described in [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101) 161 | /// 162 | /// This directly decays the weights as 163 | /// 164 | /// $$ \\theta_{t} \\gets (1 - \\eta \\lambda) \\theta_{t-1}$$ 165 | /// 166 | /// This is equivalent to regularisation, only for SGD without momentum, but is different for adaptive gradient methods 167 | DecoupledWeightDecay(f64), 168 | } 169 | 170 | /// Type of momentum to use 171 | #[derive(Copy, Clone, Debug, PartialEq, PartialOrd)] 172 | pub enum 
Momentum { 173 | /// classical momentum 174 | Classical(f64), 175 | /// nesterov momentum 176 | Nesterov(f64), 177 | } 178 | -------------------------------------------------------------------------------- /benches/training.rs: -------------------------------------------------------------------------------- 1 | use candle_core::{DType, Result, Tensor, Var, D}; 2 | use candle_nn::{loss, ops, Linear, Module, Optimizer, VarBuilder, VarMap}; 3 | 4 | use candle_optimisers::{ 5 | adadelta::{Adadelta, ParamsAdaDelta}, 6 | adagrad::{Adagrad, ParamsAdaGrad}, 7 | adam::{Adam, ParamsAdam}, 8 | adamax::{Adamax, ParamsAdaMax}, 9 | esgd::{ParamsSGD, SGD}, 10 | lbfgs::{Lbfgs, LineSearch, ParamsLBFGS}, 11 | nadam::{NAdam, ParamsNAdam}, 12 | radam::{ParamsRAdam, RAdam}, 13 | rmsprop::{ParamsRMSprop, RMSprop}, 14 | LossOptimizer, Model, 15 | }; 16 | 17 | pub trait Optim: Sized { 18 | fn new(vars: Vec) -> Result; 19 | fn back_step(&mut self, loss: &Tensor) -> Result<()>; 20 | } 21 | 22 | impl Optim for Adadelta { 23 | fn new(vars: Vec) -> Result { 24 | ::new(vars, ParamsAdaDelta::default()) 25 | } 26 | 27 | fn back_step(&mut self, loss: &Tensor) -> Result<()> { 28 | self.backward_step(loss) 29 | } 30 | } 31 | 32 | impl Optim for Adagrad { 33 | fn new(vars: Vec) -> Result { 34 | ::new(vars, ParamsAdaGrad::default()) 35 | } 36 | 37 | fn back_step(&mut self, loss: &Tensor) -> Result<()> { 38 | self.backward_step(loss) 39 | } 40 | } 41 | 42 | impl Optim for Adamax { 43 | fn new(vars: Vec) -> Result { 44 | ::new(vars, ParamsAdaMax::default()) 45 | } 46 | 47 | fn back_step(&mut self, loss: &Tensor) -> Result<()> { 48 | self.backward_step(loss) 49 | } 50 | } 51 | 52 | impl Optim for SGD { 53 | fn new(vars: Vec) -> Result { 54 | ::new(vars, ParamsSGD::default()) 55 | } 56 | 57 | fn back_step(&mut self, loss: &Tensor) -> Result<()> { 58 | self.backward_step(loss) 59 | } 60 | } 61 | 62 | impl Optim for NAdam { 63 | fn new(vars: Vec) -> Result { 64 | ::new(vars, ParamsNAdam::default()) 65 | } 66 | 67 | fn back_step(&mut self, loss: &Tensor) -> Result<()> { 68 | self.backward_step(loss) 69 | } 70 | } 71 | 72 | impl Optim for RAdam { 73 | fn new(vars: Vec) -> Result { 74 | ::new(vars, ParamsRAdam::default()) 75 | } 76 | 77 | fn back_step(&mut self, loss: &Tensor) -> Result<()> { 78 | self.backward_step(loss) 79 | } 80 | } 81 | 82 | impl Optim for RMSprop { 83 | fn new(vars: Vec) -> Result { 84 | ::new(vars, ParamsRMSprop::default()) 85 | } 86 | 87 | fn back_step(&mut self, loss: &Tensor) -> Result<()> { 88 | self.backward_step(loss) 89 | } 90 | } 91 | 92 | impl Optim for Adam { 93 | fn new(vars: Vec) -> Result { 94 | ::new(vars, ParamsAdam::default()) 95 | } 96 | 97 | fn back_step(&mut self, loss: &Tensor) -> Result<()> { 98 | self.backward_step(loss) 99 | } 100 | } 101 | 102 | const IMAGE_DIM: usize = 784; 103 | const LABELS: usize = 10; 104 | 105 | pub trait SimpleModel: Sized { 106 | fn new(vs: VarBuilder, train_data: Tensor, train_labels: Tensor) -> Result; 107 | fn forward(&self) -> Result; 108 | } 109 | 110 | pub struct Mlp { 111 | ln1: Linear, 112 | ln2: Linear, 113 | train_data: Tensor, 114 | train_labels: Tensor, 115 | } 116 | 117 | impl SimpleModel for Mlp { 118 | fn new(vs: VarBuilder, train_data: Tensor, train_labels: Tensor) -> Result { 119 | let ln1 = candle_nn::linear(IMAGE_DIM, 100, vs.pp("ln1"))?; 120 | let ln2 = candle_nn::linear(100, LABELS, vs.pp("ln2"))?; 121 | Ok(Self { 122 | ln1, 123 | ln2, 124 | train_data, 125 | train_labels, 126 | }) 127 | } 128 | 129 | fn forward(&self) -> Result { 130 | let xs 
= self.ln1.forward(&self.train_data)?; 131 | let xs = xs.relu()?; 132 | self.ln2.forward(&xs) 133 | } 134 | } 135 | 136 | impl Model for Mlp { 137 | fn loss(&self) -> Result { 138 | let logits = self.forward()?; 139 | // softmax the log probabilities 140 | let log_sm = ops::log_softmax(&logits, D::Minus1)?; 141 | // get the loss 142 | loss::nll(&log_sm, &self.train_labels) 143 | } 144 | } 145 | 146 | #[allow(clippy::module_name_repetitions)] 147 | pub fn run_training( 148 | m: &candle_datasets::vision::Dataset, 149 | ) -> anyhow::Result<()> { 150 | // check to see if cuda device availabke 151 | let dev = candle_core::Device::cuda_if_available(0)?; 152 | 153 | // get the labels from the dataset 154 | let train_labels = m.train_labels.to_dtype(DType::U32)?.to_device(&dev)?; 155 | // get the input from the dataset and put on device 156 | let train_images = m.train_images.to_device(&dev)?; 157 | 158 | // create a new variable store 159 | let varmap = VarMap::new(); 160 | // create a new variable builder 161 | let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev); 162 | // create model from variables 163 | let model = M::new(vs.clone(), train_images, train_labels)?; 164 | 165 | // create an optimiser 166 | let mut optimiser = O::new(varmap.all_vars())?; 167 | // load the test images 168 | let _test_images = m.test_images.to_device(&dev)?; 169 | // load the test labels 170 | let _test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?; 171 | 172 | for _epoch in 0..100 { 173 | // get the loss 174 | let loss = model.loss()?; 175 | // step the tensors by backpropagating the loss 176 | optimiser.back_step(&loss)?; 177 | } 178 | Ok(()) 179 | } 180 | 181 | #[allow(clippy::module_name_repetitions)] 182 | pub fn run_lbfgs_training( 183 | m: &candle_datasets::vision::Dataset, 184 | ) -> anyhow::Result<()> { 185 | // check to see if cuda device availabke 186 | let dev = candle_core::Device::cuda_if_available(0)?; 187 | 188 | // get the labels from the dataset 189 | let train_labels = m.train_labels.to_dtype(DType::U32)?.to_device(&dev)?; 190 | // get the input from the dataset and put on device 191 | let train_images = m.train_images.to_device(&dev)?; 192 | 193 | // create a new variable store 194 | let varmap = VarMap::new(); 195 | // create a new variable builder 196 | let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev); 197 | // create model from variables 198 | let model = M::new(vs.clone(), train_images, train_labels)?; 199 | 200 | let params = ParamsLBFGS { 201 | lr: 1., 202 | history_size: 4, 203 | line_search: Some(LineSearch::StrongWolfe(1e-4, 0.9, 1e-9)), 204 | ..Default::default() 205 | }; 206 | 207 | let mut loss = model.loss()?; 208 | 209 | // create an optimiser 210 | let mut optimiser = Lbfgs::new(varmap.all_vars(), params, &model)?; 211 | // load the test images 212 | let _test_images = m.test_images.to_device(&dev)?; 213 | // load the test labels 214 | let _test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?; 215 | 216 | for _epoch in 0..100 { 217 | // get the loss 218 | 219 | // step the tensors by backpropagating the loss 220 | let res = optimiser.backward_step(&loss)?; 221 | match res { 222 | candle_optimisers::ModelOutcome::Converged(_, _) => break, 223 | candle_optimisers::ModelOutcome::Stepped(new_loss, _) => loss = new_loss, 224 | // _ => panic!("unexpected outcome"), 225 | } 226 | } 227 | Ok(()) 228 | } 229 | -------------------------------------------------------------------------------- /tests/adagrad_tests.rs: 
-------------------------------------------------------------------------------- 1 | use candle_core::test_utils::{to_vec0_round, to_vec2_round}; 2 | 3 | use anyhow::Result; 4 | use candle_core::{Device, Tensor, Var}; 5 | use candle_nn::{Linear, Module, Optimizer}; 6 | use candle_optimisers::adagrad::{Adagrad, ParamsAdaGrad}; 7 | 8 | /* The results of this test have been checked against the following PyTorch code. 9 | import torch 10 | from torch import optim 11 | 12 | w_gen = torch.tensor([[3., 1.]]) 13 | b_gen = torch.tensor([-2.]) 14 | 15 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 16 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 17 | 18 | m = torch.nn.Linear(2, 1) 19 | with torch.no_grad(): 20 | m.weight.zero_() 21 | m.bias.zero_() 22 | optimiser = optim.Adagrad(m.parameters(), lr=0.004, weight_decay=0.00) 23 | # optimiser.zero_grad() 24 | for _step in range(1000): 25 | optimiser.zero_grad() 26 | ys = m(sample_xs) 27 | loss = ((ys - sample_ys)**2).sum() 28 | loss.backward() 29 | optimiser.step() 30 | # print("Optimizer state begin") 31 | # print(optimiser.state) 32 | # print("Optimizer state end") 33 | print(m.weight) 34 | print(m.bias) 35 | */ 36 | #[test] 37 | fn adagrad_test() -> Result<()> { 38 | // Generate some linear data, y = 3.x1 + x2 - 2. 39 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 40 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 41 | let gen = Linear::new(w_gen, Some(b_gen)); 42 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 43 | let sample_ys = gen.forward(&sample_xs)?; 44 | 45 | let params = ParamsAdaGrad { 46 | lr: 0.004, 47 | lr_decay: 0.0, 48 | weight_decay: None, 49 | initial_acc: 0.0, 50 | eps: 1e-10, 51 | }; 52 | // Now use backprop to run a linear regression between samples and get the coefficients back. 53 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 54 | let b = Var::new(0f32, &Device::Cpu)?; 55 | let mut n_sgd = Adagrad::new(vec![w.clone(), b.clone()], params)?; 56 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 57 | for _step in 0..1000 { 58 | let ys = lin.forward(&sample_xs)?; 59 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 60 | n_sgd.backward_step(&loss)?; 61 | } 62 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.2424, 0.2341]]); 63 | assert_eq!(to_vec0_round(&b, 4)?, 0.2379); 64 | Ok(()) 65 | } 66 | 67 | /* The results of this test have been checked against the following PyTorch code. 68 | import torch 69 | from torch import optim 70 | 71 | w_gen = torch.tensor([[3., 1.]]) 72 | b_gen = torch.tensor([-2.]) 73 | 74 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 75 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 76 | 77 | m = torch.nn.Linear(2, 1) 78 | with torch.no_grad(): 79 | m.weight.zero_() 80 | m.bias.zero_() 81 | optimiser = optim.Adagrad(m.parameters(), lr=0.004, lr_decay=0.2) 82 | # optimiser.zero_grad() 83 | for _step in range(1000): 84 | optimiser.zero_grad() 85 | ys = m(sample_xs) 86 | loss = ((ys - sample_ys)**2).sum() 87 | loss.backward() 88 | optimiser.step() 89 | # print("Optimizer state begin") 90 | # print(optimiser.state) 91 | # print("Optimizer state end") 92 | print(m.weight) 93 | print(m.bias) 94 | */ 95 | #[test] 96 | fn adagrad_lr_decay_test() -> Result<()> { 97 | // Generate some linear data, y = 3.x1 + x2 - 2. 
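    // Note: with lr = 0.004 and lr_decay = 0.2, the effective step size follows the schedule
    // from `src/adagrad.rs`, gamma_tilde = lr / (1 + (t - 1) * lr_decay), so by the last of
    // the 1000 steps it has shrunk to about 0.004 / 200.8 (roughly 2e-5). This is consistent
    // with the much smaller fitted weights asserted below compared with the plain Adagrad
    // test above.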
98 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 99 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 100 | let gen = Linear::new(w_gen, Some(b_gen)); 101 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 102 | let sample_ys = gen.forward(&sample_xs)?; 103 | 104 | let params = ParamsAdaGrad { 105 | lr: 0.004, 106 | lr_decay: 0.2, 107 | weight_decay: None, 108 | initial_acc: 0.0, 109 | eps: 1e-10, 110 | }; 111 | // Now use backprop to run a linear regression between samples and get the coefficients back. 112 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 113 | let b = Var::new(0f32, &Device::Cpu)?; 114 | let mut n_sgd = Adagrad::new(vec![w.clone(), b.clone()], params)?; 115 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 116 | for _step in 0..1000 { 117 | let ys = lin.forward(&sample_xs)?; 118 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 119 | n_sgd.backward_step(&loss)?; 120 | } 121 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.0231, 0.0230]]); 122 | assert_eq!(to_vec0_round(&b, 4)?, 0.0230); 123 | Ok(()) 124 | } 125 | 126 | /* The results of this test have been checked against the following PyTorch code. 127 | import torch 128 | from torch import optim 129 | 130 | w_gen = torch.tensor([[3., 1.]]) 131 | b_gen = torch.tensor([-2.]) 132 | 133 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 134 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 135 | 136 | m = torch.nn.Linear(2, 1) 137 | with torch.no_grad(): 138 | m.weight.zero_() 139 | m.bias.zero_() 140 | optimiser = optim.Adagrad(m.parameters(), lr=0.004, weight_decay=0.2) 141 | # optimiser.zero_grad() 142 | for _step in range(1000): 143 | optimiser.zero_grad() 144 | ys = m(sample_xs) 145 | loss = ((ys - sample_ys)**2).sum() 146 | loss.backward() 147 | optimiser.step() 148 | # print("Optimizer state begin") 149 | # print(optimiser.state) 150 | # print("Optimizer state end") 151 | print(m.weight) 152 | print(m.bias) 153 | */ 154 | #[test] 155 | fn adagrad_weight_decay_test() -> Result<()> { 156 | // Generate some linear data, y = 3.x1 + x2 - 2. 157 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 158 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 159 | let gen = Linear::new(w_gen, Some(b_gen)); 160 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 161 | let sample_ys = gen.forward(&sample_xs)?; 162 | 163 | let params = ParamsAdaGrad { 164 | lr: 0.004, 165 | lr_decay: 0.0, 166 | weight_decay: Some(candle_optimisers::Decay::WeightDecay(0.2)), 167 | initial_acc: 0.0, 168 | eps: 1e-10, 169 | }; 170 | // Now use backprop to run a linear regression between samples and get the coefficients back. 
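    // `Decay::WeightDecay(0.2)` adds 0.2 * theta to the gradient (the L2-style penalty,
    // matching PyTorch's `weight_decay` argument in the script above). The decoupled
    // alternative instead shrinks the weights directly by lr * 0.2 each step and would be
    // selected with
    //     weight_decay: Some(candle_optimisers::Decay::DecoupledWeightDecay(0.2)),
    // as in the decoupled test below (which is not checked against PyTorch).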
171 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 172 | let b = Var::new(0f32, &Device::Cpu)?; 173 | let mut n_sgd = Adagrad::new(vec![w.clone(), b.clone()], params)?; 174 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 175 | for _step in 0..1000 { 176 | let ys = lin.forward(&sample_xs)?; 177 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 178 | n_sgd.backward_step(&loss)?; 179 | } 180 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.2424, 0.2341]]); 181 | assert_eq!(to_vec0_round(&b, 4)?, 0.2378); 182 | Ok(()) 183 | } 184 | 185 | //------------------------------------------------------------------------- 186 | // THIS IS NOT TESTED AGAINST PYTORCH 187 | // AS PYTORCH DOES NOT HAVE DECOUPLED WEIGHT DECAY FOR ADAGRAD 188 | // ------------------------------------------------------------------------ 189 | #[test] 190 | fn adagrad_decoupled_weight_decay_test() -> Result<()> { 191 | // Generate some linear data, y = 3.x1 + x2 - 2. 192 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 193 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 194 | let gen = Linear::new(w_gen, Some(b_gen)); 195 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 196 | let sample_ys = gen.forward(&sample_xs)?; 197 | 198 | let params = ParamsAdaGrad { 199 | lr: 0.004, 200 | lr_decay: 0.0, 201 | weight_decay: Some(candle_optimisers::Decay::DecoupledWeightDecay(0.2)), 202 | initial_acc: 0.0, 203 | eps: 1e-10, 204 | }; 205 | // Now use backprop to run a linear regression between samples and get the coefficients back. 206 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 207 | let b = Var::new(0f32, &Device::Cpu)?; 208 | let mut n_sgd = Adagrad::new(vec![w.clone(), b.clone()], params)?; 209 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 210 | for _step in 0..1000 { 211 | let ys = lin.forward(&sample_xs)?; 212 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 213 | n_sgd.backward_step(&loss)?; 214 | } 215 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.1483, 0.1450]]); 216 | assert_eq!(to_vec0_round(&b, 4)?, 0.1465); 217 | Ok(()) 218 | } 219 | -------------------------------------------------------------------------------- /tests/lbfgs_tests.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use candle_core::test_utils::to_vec2_round; 3 | use candle_core::{DType, Device, Result as CResult, Tensor}; 4 | use candle_optimisers::lbfgs::{GradConv, Lbfgs, LineSearch, ParamsLBFGS, StepConv}; 5 | use candle_optimisers::{LossOptimizer, Model, ModelOutcome}; 6 | 7 | /* 8 | These tests all use the 2D Rosenbrock function as a test function for the optimisers. This has minimum 0 at (1, 1) 9 | */ 10 | 11 | #[derive(Debug, Clone)] 12 | pub struct RosenbrockModel { 13 | x_pos: candle_core::Var, 14 | y_pos: candle_core::Var, 15 | } 16 | 17 | impl Model for RosenbrockModel { 18 | fn loss(&self) -> CResult { 19 | //, xs: &Tensor, ys: &Tensor 20 | self.forward()?.squeeze(1)?.squeeze(0) 21 | } 22 | } 23 | 24 | impl RosenbrockModel { 25 | fn new() -> CResult { 26 | let x_pos = candle_core::Var::from_tensor( 27 | &(10. * Tensor::ones((1, 1), DType::F64, &Device::Cpu)?)?, 28 | )?; 29 | let y_pos = candle_core::Var::from_tensor( 30 | &(10. 
* Tensor::ones((1, 1), DType::F64, &Device::Cpu)?)?, 31 | )?; 32 | Ok(Self { x_pos, y_pos }) 33 | } 34 | fn vars(&self) -> Vec { 35 | vec![self.x_pos.clone(), self.y_pos.clone()] 36 | } 37 | 38 | fn forward(&self) -> CResult { 39 | //, xs: &Tensor 40 | (1. - self.x_pos.as_tensor())?.powf(2.)? 41 | + 100. * (self.y_pos.as_tensor() - self.x_pos.as_tensor().powf(2.)?)?.powf(2.)? 42 | } 43 | } 44 | 45 | #[test] 46 | fn lbfgs_test() -> Result<()> { 47 | let params = ParamsLBFGS { 48 | lr: 1., 49 | ..Default::default() 50 | }; 51 | 52 | let model = RosenbrockModel::new()?; 53 | 54 | let mut lbfgs = Lbfgs::new(model.vars(), params, &model)?; 55 | let mut loss = model.loss()?; 56 | 57 | for _step in 0..500 { 58 | // println!("\nstart step {}", step); 59 | // for v in model.vars() { 60 | // println!("{}", v); 61 | // } 62 | let res = lbfgs.backward_step(&loss)?; //&sample_xs, &sample_ys 63 | // println!("end step {}", _step); 64 | match res { 65 | ModelOutcome::Converged(_, _) => break, 66 | ModelOutcome::Stepped(new_loss, _) => loss = new_loss, 67 | // _ => panic!("unexpected outcome"), 68 | } 69 | } 70 | 71 | for v in model.vars() { 72 | // println!("{}", v); 73 | assert_eq!(to_vec2_round(&v.to_dtype(DType::F32)?, 4)?, &[[1.0000]]); 74 | } 75 | 76 | // println!("{:?}", lbfgs); 77 | // panic!("deliberate panic"); 78 | 79 | Ok(()) 80 | } 81 | 82 | #[test] 83 | fn lbfgs_test_strong_wolfe() -> Result<()> { 84 | let params = ParamsLBFGS { 85 | lr: 1., 86 | line_search: Some(LineSearch::StrongWolfe(1e-4, 0.9, 1e-9)), 87 | ..Default::default() 88 | }; 89 | 90 | let model = RosenbrockModel::new()?; 91 | 92 | let mut lbfgs = Lbfgs::new(model.vars(), params, &model)?; 93 | let mut loss = model.loss()?; 94 | 95 | for _step in 0..500 { 96 | // println!("\nstart step {}", step); 97 | // for v in model.vars() { 98 | // println!("{}", v); 99 | // } 100 | let res = lbfgs.backward_step(&loss)?; //&sample_xs, &sample_ys 101 | // println!("end step {}", _step); 102 | match res { 103 | ModelOutcome::Converged(_, _) => break, 104 | ModelOutcome::Stepped(new_loss, _) => loss = new_loss, 105 | // _ => panic!("unexpected outcome"), 106 | } 107 | } 108 | 109 | for v in model.vars() { 110 | // println!("{}", v); 111 | assert_eq!(to_vec2_round(&v.to_dtype(DType::F32)?, 4)?, &[[1.0000]]); 112 | } 113 | 114 | // println!("{:?}", lbfgs); 115 | // panic!("deliberate panic"); 116 | 117 | Ok(()) 118 | } 119 | 120 | #[test] 121 | fn lbfgs_rms_grad_test() -> Result<()> { 122 | let params = ParamsLBFGS { 123 | lr: 1., 124 | grad_conv: GradConv::RMSForce(1e-6), 125 | ..Default::default() 126 | }; 127 | 128 | let model = RosenbrockModel::new()?; 129 | 130 | let mut lbfgs = Lbfgs::new(model.vars(), params, &model)?; 131 | let mut loss = model.loss()?; 132 | 133 | for _step in 0..500 { 134 | // println!("\nstart step {}", step); 135 | // for v in model.vars() { 136 | // println!("{}", v); 137 | // } 138 | let res = lbfgs.backward_step(&loss)?; //&sample_xs, &sample_ys 139 | // println!("end step {}", _step); 140 | match res { 141 | ModelOutcome::Converged(_, _) => break, 142 | ModelOutcome::Stepped(new_loss, _) => loss = new_loss, 143 | // _ => panic!("unexpected outcome"), 144 | } 145 | } 146 | 147 | for v in model.vars() { 148 | // println!("{}", v); 149 | assert_eq!(to_vec2_round(&v.to_dtype(DType::F32)?, 4)?, &[[1.0000]]); 150 | } 151 | 152 | // println!("{:?}", lbfgs); 153 | // panic!("deliberate panic"); 154 | 155 | Ok(()) 156 | } 157 | 158 | #[test] 159 | fn lbfgs_rms_step_test() -> Result<()> { 160 | let params = ParamsLBFGS { 
161 | lr: 1., 162 | grad_conv: GradConv::RMSForce(0.), 163 | step_conv: StepConv::RMSStep(1e-7), 164 | ..Default::default() 165 | }; 166 | 167 | let model = RosenbrockModel::new()?; 168 | 169 | let mut lbfgs = Lbfgs::new(model.vars(), params, &model)?; 170 | let mut loss = model.loss()?; 171 | 172 | for _step in 0..500 { 173 | // println!("\nstart step {}", step); 174 | // for v in model.vars() { 175 | // println!("{}", v); 176 | // } 177 | let res = lbfgs.backward_step(&loss)?; //&sample_xs, &sample_ys 178 | // println!("end step {}", _step); 179 | match res { 180 | ModelOutcome::Converged(_, _) => break, 181 | ModelOutcome::Stepped(new_loss, _) => loss = new_loss, 182 | // _ => panic!("unexpected outcome"), 183 | } 184 | } 185 | 186 | for v in model.vars() { 187 | // println!("{}", v); 188 | assert_eq!(to_vec2_round(&v.to_dtype(DType::F32)?, 4)?, &[[1.0000]]); 189 | } 190 | 191 | // println!("{:?}", lbfgs); 192 | // panic!("deliberate panic"); 193 | 194 | Ok(()) 195 | } 196 | 197 | #[test] 198 | fn lbfgs_test_strong_wolfe_weight_decay() -> Result<()> { 199 | let params = ParamsLBFGS { 200 | lr: 1., 201 | line_search: Some(LineSearch::StrongWolfe(1e-4, 0.9, 1e-9)), 202 | weight_decay: Some(0.1), 203 | ..Default::default() 204 | }; 205 | 206 | let model = RosenbrockModel::new()?; 207 | 208 | let mut lbfgs = Lbfgs::new(model.vars(), params, &model)?; 209 | let mut loss = model.loss()?; 210 | 211 | for _step in 0..500 { 212 | // println!("\nstart step {}", step); 213 | // for v in model.vars() { 214 | // println!("{}", v); 215 | // } 216 | let res = lbfgs.backward_step(&loss)?; //&sample_xs, &sample_ys 217 | // println!("end step {}", _step); 218 | match res { 219 | ModelOutcome::Converged(_, _) => break, 220 | ModelOutcome::Stepped(new_loss, _) => loss = new_loss, 221 | // _ => panic!("unexpected outcome"), 222 | } 223 | } 224 | 225 | let expected = [0.8861, 0.7849]; // this should be properly checked 226 | for (v, e) in model.vars().iter().zip(expected) { 227 | // println!("{}", v); 228 | assert_eq!(to_vec2_round(&v.to_dtype(DType::F32)?, 4)?, &[[e]]); 229 | } 230 | 231 | // println!("{:?}", lbfgs); 232 | // panic!("deliberate panic"); 233 | 234 | Ok(()) 235 | } 236 | 237 | #[test] 238 | fn lbfgs_test_weight_decay() -> Result<()> { 239 | let params = ParamsLBFGS { 240 | lr: 1., 241 | weight_decay: Some(0.1), 242 | ..Default::default() 243 | }; 244 | 245 | let model = RosenbrockModel::new()?; 246 | 247 | let mut lbfgs = Lbfgs::new(model.vars(), params, &model)?; 248 | let mut loss = model.loss()?; 249 | 250 | for _step in 0..500 { 251 | // println!("\nstart step {}", step); 252 | // for v in model.vars() { 253 | // println!("{}", v); 254 | // } 255 | let res = lbfgs.backward_step(&loss)?; //&sample_xs, &sample_ys 256 | // println!("end step {}", _step); 257 | match res { 258 | ModelOutcome::Converged(_, _) => break, 259 | ModelOutcome::Stepped(new_loss, _) => loss = new_loss, 260 | // _ => panic!("unexpected outcome"), 261 | } 262 | } 263 | 264 | let expected = [0.8861, 0.7849]; // this should be properly checked 265 | for (v, e) in model.vars().iter().zip(expected) { 266 | // println!("{}", v); 267 | assert_eq!(to_vec2_round(&v.to_dtype(DType::F32)?, 4)?, &[[e]]); 268 | } 269 | 270 | // println!("{:?}", lbfgs); 271 | // panic!("deliberate panic"); 272 | 273 | Ok(()) 274 | } 275 | -------------------------------------------------------------------------------- /src/adagrad.rs: -------------------------------------------------------------------------------- 1 | /*! 
2 | Adagrad optimiser 3 | 4 | Described in [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](https://jmlr.org/papers/v12/duchi11a.html) 5 | 6 | Pseudocode (including decoupling of weight decay): 7 | 8 | $$ 9 | \\begin{aligned} 10 | &\\rule{110mm}{0.4pt} \\\\ 11 | &\\textbf{input} : \\gamma \\text{ (lr)}, \\: \\theta_0 \\text{ (params)}, \\: f(\\theta) 12 | \\text{ (objective)}, \\: \\lambda \\text{ (weight decay)}, \\\\ 13 | &\\hspace{12mm} \\tau \\text{ (initial accumulator value)}, \\: \\eta\\text{ (lr decay)}\\\\ 14 | &\\textbf{initialize} : statesum_0 \\leftarrow 0 \\\\[-1.ex] 15 | &\\rule{110mm}{0.4pt} \\\\ 16 | &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\ 17 | &\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\ 18 | &\\hspace{5mm} \\tilde{\\gamma} \\leftarrow \\gamma / (1 +(t-1) \\eta) \\\\ 19 | &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\ 20 | &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled} \\\\ 21 | &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1} \\\\ 22 | &\\hspace{10mm}\\textbf{else} \\\\ 23 | &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\ 24 | &\\hspace{5mm}statesum_t \\leftarrow statesum_{t-1} + g^2_t \\\\ 25 | &\\hspace{5mm}\\theta_t \\leftarrow 26 | \\theta_{t-1}- \\tilde{\\gamma} \\frac{g_t}{\\sqrt{statesum_t}+\\epsilon} \\\\ 27 | &\\rule{110mm}{0.4pt} \\\\[-1.ex] 28 | &\\bf{return} \\: \\theta_t \\\\[-1.ex] 29 | &\\rule{110mm}{0.4pt} \\\\[-1.ex] 30 | \\end{aligned} 31 | $$ 32 | 33 | 34 | 35 | */ 36 | 37 | use candle_core::{Result, Var}; 38 | use candle_nn::optim::Optimizer; 39 | 40 | use crate::{Decay, OptimParams}; 41 | 42 | /// Adagrad optimiser 43 | /// 44 | /// Described in [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](https://jmlr.org/papers/v12/duchi11a.html) 45 | #[derive(Debug)] 46 | pub struct Adagrad { 47 | vars: Vec, 48 | params: ParamsAdaGrad, 49 | t: f64, 50 | } 51 | 52 | #[derive(Debug)] 53 | struct VarAdaGrad { 54 | theta: Var, 55 | sum: Var, 56 | } 57 | 58 | /// Parameters for the Adagrad optimiser 59 | #[derive(Clone, Debug, PartialEq, PartialOrd)] 60 | pub struct ParamsAdaGrad { 61 | /// Learning rate 62 | pub lr: f64, 63 | /// Learning rate decay 64 | pub lr_decay: f64, 65 | /// Initial value of accumulator 66 | pub initial_acc: f64, 67 | /// weight decay 68 | pub weight_decay: Option, 69 | /// term added to the denominator to improve numerical stability 70 | pub eps: f64, 71 | } 72 | 73 | impl Default for ParamsAdaGrad { 74 | fn default() -> Self { 75 | Self { 76 | lr: 0.01, 77 | lr_decay: 0.0, 78 | initial_acc: 0.0, 79 | weight_decay: None, 80 | eps: 1e-10, 81 | } 82 | } 83 | } 84 | 85 | impl Optimizer for Adagrad { 86 | type Config = ParamsAdaGrad; 87 | 88 | fn new(vars: Vec, params: ParamsAdaGrad) -> Result { 89 | let vars = vars 90 | .into_iter() 91 | .filter(|var| var.dtype().is_float()) 92 | .map(|var| { 93 | let dtype = var.dtype(); 94 | let shape = var.shape(); 95 | let device = var.device(); 96 | let sum = Var::zeros(shape, dtype, device)?; 97 | Ok(VarAdaGrad { theta: var, sum }) 98 | }) 99 | .collect::>>()?; 100 | // // Err(SGDError::NoMomentum)?; 101 | // let mut params = params; 102 | // params.t = 0; 103 | Ok(Self { 104 | vars, 105 | t: 0., 106 | params, 107 | }) 108 | } 109 | 110 | fn learning_rate(&self) -> f64 { 111 | self.params.lr 112 | } 113 | 114 | fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> { 115 | if let 
Some(decay) = self.params.weight_decay { 116 | match decay { 117 | Decay::WeightDecay(decay) => { 118 | for var in &self.vars { 119 | let theta = &var.theta; 120 | let sum = &var.sum; 121 | if let Some(grad) = grads.get(theta) { 122 | let gamma_tilde = 123 | self.params.lr / self.t.mul_add(self.params.lr_decay, 1.); 124 | let grad = &(grad + (decay * theta.as_tensor())?)?; 125 | let current_sum = (sum.as_tensor() + grad.powf(2.)?)?; 126 | let change = (gamma_tilde 127 | * (grad.div(&(current_sum.powf(0.5)? + self.params.eps)?))?)?; 128 | sum.set(¤t_sum)?; 129 | theta.set(&theta.sub(&change)?)?; 130 | } 131 | } 132 | } 133 | Decay::DecoupledWeightDecay(decay) => { 134 | for var in &self.vars { 135 | let theta = &var.theta; 136 | let sum = &var.sum; 137 | if let Some(grad) = grads.get(theta) { 138 | // decoupled weight decay step 139 | theta 140 | .set(&(theta.as_tensor() * self.params.lr.mul_add(-decay, 1.))?)?; 141 | let gamma_tilde = 142 | self.params.lr / self.t.mul_add(self.params.lr_decay, 1.); 143 | let current_sum = (sum.as_tensor() + grad.powf(2.)?)?; 144 | let change = (gamma_tilde 145 | * (grad.div(&(current_sum.powf(0.5)? + self.params.eps)?))?)?; 146 | sum.set(¤t_sum)?; 147 | theta.set(&theta.sub(&change)?)?; 148 | } 149 | } 150 | } 151 | } 152 | } else { 153 | for var in &self.vars { 154 | let theta = &var.theta; 155 | let sum = &var.sum; 156 | if let Some(grad) = grads.get(theta) { 157 | let gamma_tilde = self.params.lr / self.t.mul_add(self.params.lr_decay, 1.); 158 | let current_sum = (sum.as_tensor() + grad.powf(2.)?)?; 159 | let change = 160 | (gamma_tilde * (grad.div(&(current_sum.powf(0.5)? + self.params.eps)?))?)?; 161 | sum.set(¤t_sum)?; 162 | theta.set(&theta.sub(&change)?)?; 163 | } 164 | } 165 | } 166 | self.t += 1.; 167 | Ok(()) 168 | } 169 | 170 | fn set_learning_rate(&mut self, lr: f64) { 171 | self.params.lr = lr; 172 | } 173 | } 174 | 175 | impl OptimParams for Adagrad { 176 | fn params(&self) -> &Self::Config { 177 | &self.params 178 | } 179 | 180 | fn set_params(&mut self, config: Self::Config) { 181 | self.params = config; 182 | } 183 | } 184 | 185 | impl Adagrad { 186 | /// Return the vars being optimised 187 | #[must_use] 188 | pub fn into_inner(self) -> Vec { 189 | self.vars.into_iter().map(|v| v.theta).collect() 190 | } 191 | 192 | // pub fn push(&mut self, var: &Var) { 193 | // self.vars.push(var.clone()); 194 | // } 195 | } 196 | 197 | #[cfg(test)] 198 | mod tests { 199 | // use candle_core::test_utils::{to_vec0_round, to_vec2_round}; 200 | 201 | use anyhow::Result; 202 | use assert_approx_eq::assert_approx_eq; 203 | use candle_core::{Device, Var}; 204 | use candle_nn::Optimizer; 205 | 206 | use super::*; 207 | #[test] 208 | fn lr_test() -> Result<()> { 209 | let params = ParamsAdaGrad { 210 | lr: 0.004, 211 | ..Default::default() 212 | }; 213 | // Now use backprop to run a linear regression between samples and get the coefficients back. 
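        // Illustrative sketch of the accessors exercised below: `learning_rate` and
        // `set_learning_rate` come from the `Optimizer` trait and allow a simple manual
        // schedule, e.g. halving the step size partway through training:
        //
        //     let current = optim.learning_rate();
        //     optim.set_learning_rate(current * 0.5);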
214 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 215 | let b = Var::new(0f32, &Device::Cpu)?; 216 | let mut optim = Adagrad::new(vec![w.clone(), b.clone()], params)?; 217 | assert_approx_eq!(0.004, optim.learning_rate()); 218 | optim.set_learning_rate(0.002); 219 | assert_approx_eq!(0.002, optim.learning_rate()); 220 | Ok(()) 221 | } 222 | 223 | #[test] 224 | fn into_inner_test() -> Result<()> { 225 | let params = ParamsAdaGrad::default(); 226 | let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?; 227 | let b = Var::new(-2f32, &Device::Cpu)?; 228 | let optim = Adagrad::new(vec![w.clone(), b.clone()], params)?; 229 | let inner = optim.into_inner(); 230 | assert_eq!(inner[0].as_tensor().to_vec2::()?, &[[3f32, 1.]]); 231 | assert_approx_eq!(inner[1].as_tensor().to_vec0::()?, -2_f32); 232 | Ok(()) 233 | } 234 | 235 | #[test] 236 | fn params_test() -> Result<()> { 237 | let params = ParamsAdaGrad { 238 | lr: 0.004, 239 | ..Default::default() 240 | }; 241 | // Now use backprop to run a linear regression between samples and get the coefficients back. 242 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 243 | let b = Var::new(0f32, &Device::Cpu)?; 244 | let mut optim = Adagrad::new(vec![w.clone(), b.clone()], params.clone())?; 245 | assert_eq!(params, optim.params().clone()); 246 | let new_params = ParamsAdaGrad { 247 | lr: 0.002, 248 | ..Default::default() 249 | }; 250 | optim.set_params(new_params.clone()); 251 | assert_eq!(new_params, optim.params().clone()); 252 | Ok(()) 253 | } 254 | } 255 | -------------------------------------------------------------------------------- /src/adamax.rs: -------------------------------------------------------------------------------- 1 | /*! 2 | Adamax optimiser 3 | 4 | An Adam optimiser based on infinity norm, described in [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980) 5 | 6 | Pseudocode (including decoupling of weight decay): 7 | 8 | $$ 9 | \\begin{aligned} 10 | &\\rule{110mm}{0.4pt} \\\\ 11 | &\\textbf{input} : \\gamma \\text{ (lr)}, \\beta_1, \\beta_2 12 | \\text{ (betas)},\\theta_0 \\text{ (params)},f(\\theta) \\text{ (objective)}, 13 | \\: \\lambda \\text{ (weight decay)}, \\\\ 14 | &\\hspace{13mm} \\epsilon \\text{ (epsilon)} \\\\ 15 | &\\textbf{initialize} : m_0 \\leftarrow 0 \\text{ ( first moment)}, 16 | u_0 \\leftarrow 0 \\text{ ( infinity norm)} \\\\[-1.ex] 17 | &\\rule{110mm}{0.4pt} \\\\ 18 | &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\ 19 | &\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\ 20 | &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\ 21 | &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled} \\\\ 22 | &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1} \\\\ 23 | &\\hspace{10mm}\\textbf{else} \\\\ 24 | &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\ 25 | &\\hspace{5mm}m_t \\leftarrow \\beta_1 m_{t-1} + (1 - \\beta_1) g_t \\\\ 26 | &\\hspace{5mm}u_t \\leftarrow \\mathrm{max}(\\beta_2 u_{t-1}, |g_{t}|+\\epsilon) \\\\ 27 | &\\hspace{5mm}\\theta_t \\leftarrow \\theta_{t-1} - \\frac{\\gamma m_t}{(1-\\beta^t_1) u_t} \\\\ 28 | &\\rule{110mm}{0.4pt} \\\\[-1.ex] 29 | &\\bf{return} \\: \\theta_t \\\\[-1.ex] 30 | &\\rule{110mm}{0.4pt} \\\\[-1.ex] 31 | \\end{aligned} 32 | $$ 33 | */ 34 | 35 | use candle_core::{Result, Var}; 36 | use candle_nn::optim::Optimizer; 37 | 38 | use crate::{Decay, OptimParams}; 39 | 40 | /// Adamax optimiser 41 | /// 42 | /// An Adam optimiser based on infinity norm, 
described in [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980) 43 | 44 | #[derive(Debug)] 45 | pub struct Adamax { 46 | vars: Vec, 47 | params: ParamsAdaMax, 48 | t: f64, 49 | } 50 | 51 | #[derive(Debug)] 52 | struct VarAdaMax { 53 | theta: Var, 54 | m: Var, 55 | u: Var, 56 | } 57 | 58 | /// Parameters for the Adamax optimiser 59 | #[derive(Clone, Debug, PartialEq, PartialOrd)] 60 | pub struct ParamsAdaMax { 61 | /// Learning rate 62 | pub lr: f64, 63 | /// Coefficient for moving average of first moment 64 | pub beta_1: f64, 65 | /// Coefficient for moving average of second moment 66 | pub beta_2: f64, 67 | /// Weight decay 68 | pub weight_decay: Option, 69 | /// Term added to denominator to improve numerical stability 70 | pub eps: f64, 71 | } 72 | 73 | impl Default for ParamsAdaMax { 74 | fn default() -> Self { 75 | Self { 76 | lr: 1.0, 77 | beta_1: 0.9, 78 | beta_2: 0.999, 79 | weight_decay: None, 80 | eps: 1e-8, 81 | } 82 | } 83 | } 84 | 85 | impl Optimizer for Adamax { 86 | type Config = ParamsAdaMax; 87 | 88 | fn new(vars: Vec, params: ParamsAdaMax) -> Result { 89 | let vars = vars 90 | .into_iter() 91 | .filter(|var| var.dtype().is_float()) 92 | .map(|var| { 93 | let dtype = var.dtype(); 94 | let shape = var.shape(); 95 | let device = var.device(); 96 | let m = Var::zeros(shape, dtype, device)?; 97 | let u = Var::zeros(shape, dtype, device)?; 98 | Ok(VarAdaMax { theta: var, m, u }) 99 | }) 100 | .collect::>>()?; 101 | // // Err(SGDError::NoMomentum)?; 102 | // let mut params = params; 103 | // params.t = 0; 104 | Ok(Self { 105 | vars, 106 | params, 107 | t: 1., 108 | }) 109 | } 110 | 111 | fn learning_rate(&self) -> f64 { 112 | self.params.lr 113 | } 114 | 115 | fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> { 116 | if let Some(decay) = self.params.weight_decay { 117 | match decay { 118 | Decay::WeightDecay(decay) => { 119 | for var in &self.vars { 120 | let theta = &var.theta; 121 | let m = &var.m; 122 | let u = &var.u; 123 | if let Some(grad) = grads.get(theta) { 124 | let grad = &(grad + (decay * theta.as_tensor())?)?; 125 | let m_next = ((self.params.beta_1 * m.as_tensor())? 126 | + (1. - self.params.beta_1) * grad)?; 127 | let u_next = (self.params.beta_2 * u.as_tensor())? 128 | .maximum(&(grad.abs()? + self.params.eps)?)?; 129 | let delta = (&m_next * self.params.lr)? 130 | .div(&(&u_next * (1. - self.params.beta_1.powf(self.t)))?)?; 131 | theta.set(&theta.sub(&(delta))?)?; 132 | m.set(&m_next)?; 133 | u.set(&u_next)?; 134 | } 135 | } 136 | } 137 | Decay::DecoupledWeightDecay(decay) => { 138 | for var in &self.vars { 139 | let theta = &var.theta; 140 | let m = &var.m; 141 | let u = &var.u; 142 | if let Some(grad) = grads.get(theta) { 143 | // decoupled weight decay step 144 | theta 145 | .set(&(theta.as_tensor() * self.params.lr.mul_add(-decay, 1.))?)?; 146 | let m_next = ((self.params.beta_1 * m.as_tensor())? 147 | + (1. - self.params.beta_1) * grad)?; 148 | let u_next = (self.params.beta_2 * u.as_tensor())? 149 | .maximum(&(grad.abs()? + self.params.eps)?)?; 150 | let delta = (&m_next * self.params.lr)? 151 | .div(&(&u_next * (1. - self.params.beta_1.powf(self.t)))?)?; 152 | theta.set(&theta.sub(&(delta))?)?; 153 | m.set(&m_next)?; 154 | u.set(&u_next)?; 155 | } 156 | } 157 | } 158 | } 159 | } else { 160 | for var in &self.vars { 161 | let theta = &var.theta; 162 | let m = &var.m; 163 | let u = &var.u; 164 | if let Some(grad) = grads.get(theta) { 165 | let m_next = 166 | ((self.params.beta_1 * m.as_tensor())? + (1. 
- self.params.beta_1) * grad)?; 167 | let u_next = (self.params.beta_2 * u.as_tensor())? 168 | .maximum(&(grad.abs()? + self.params.eps)?)?; 169 | let delta = (&m_next * self.params.lr)? 170 | .div(&(&u_next * (1. - self.params.beta_1.powf(self.t)))?)?; 171 | theta.set(&theta.sub(&(delta))?)?; 172 | m.set(&m_next)?; 173 | u.set(&u_next)?; 174 | } 175 | } 176 | } 177 | self.t += 1.; 178 | Ok(()) 179 | } 180 | 181 | fn set_learning_rate(&mut self, lr: f64) { 182 | self.params.lr = lr; 183 | } 184 | } 185 | 186 | impl OptimParams for Adamax { 187 | fn params(&self) -> &Self::Config { 188 | &self.params 189 | } 190 | 191 | fn set_params(&mut self, config: Self::Config) { 192 | self.params = config; 193 | } 194 | } 195 | 196 | impl Adamax { 197 | /// Return the vars being optimised 198 | #[must_use] 199 | pub fn into_inner(self) -> Vec { 200 | self.vars.into_iter().map(|v| v.theta).collect() 201 | } 202 | 203 | // pub fn push(&mut self, var: &Var) { 204 | // self.vars.push(var.clone()); 205 | // } 206 | } 207 | 208 | #[cfg(test)] 209 | mod tests { 210 | // use candle_core::test_utils::{to_vec0_round, to_vec2_round}; 211 | 212 | use anyhow::Result; 213 | use assert_approx_eq::assert_approx_eq; 214 | use candle_core::{Device, Var}; 215 | use candle_nn::Optimizer; 216 | 217 | use super::*; 218 | #[test] 219 | fn lr_test() -> Result<()> { 220 | let params = ParamsAdaMax { 221 | lr: 0.004, 222 | ..Default::default() 223 | }; 224 | // Now use backprop to run a linear regression between samples and get the coefficients back. 225 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 226 | let b = Var::new(0f32, &Device::Cpu)?; 227 | let mut optim = Adamax::new(vec![w.clone(), b.clone()], params)?; 228 | assert_approx_eq!(0.004, optim.learning_rate()); 229 | optim.set_learning_rate(0.002); 230 | assert_approx_eq!(0.002, optim.learning_rate()); 231 | Ok(()) 232 | } 233 | 234 | #[test] 235 | fn into_inner_test() -> Result<()> { 236 | let params = ParamsAdaMax::default(); 237 | let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?; 238 | let b = Var::new(-2f32, &Device::Cpu)?; 239 | let optim = Adamax::new(vec![w.clone(), b.clone()], params)?; 240 | let inner = optim.into_inner(); 241 | assert_eq!(inner[0].as_tensor().to_vec2::()?, &[[3f32, 1.]]); 242 | assert_approx_eq!(inner[1].as_tensor().to_vec0::()?, -2_f32); 243 | Ok(()) 244 | } 245 | 246 | #[test] 247 | fn params_test() -> Result<()> { 248 | let params = ParamsAdaMax { 249 | lr: 0.004, 250 | ..Default::default() 251 | }; 252 | // Now use backprop to run a linear regression between samples and get the coefficients back. 253 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 254 | let b = Var::new(0f32, &Device::Cpu)?; 255 | let mut optim = Adamax::new(vec![w.clone(), b.clone()], params.clone())?; 256 | assert_eq!(params, optim.params().clone()); 257 | let new_params = ParamsAdaMax { 258 | lr: 0.002, 259 | ..Default::default() 260 | }; 261 | optim.set_params(new_params.clone()); 262 | assert_eq!(new_params, optim.params().clone()); 263 | Ok(()) 264 | } 265 | } 266 | -------------------------------------------------------------------------------- /src/adadelta.rs: -------------------------------------------------------------------------------- 1 | /*! 
2 | Adadelta optimiser 3 | 4 | Described in [ADADELTA: An Adaptive Learning Rate Method](https://arxiv.org/abs/1212.5701) 5 | 6 | Pseudocode (including decoupling of weight decay): 7 | $$ 8 | \\begin{aligned} 9 | &\\rule{110mm}{0.4pt} \\\\ 10 | &\\textbf{input} : \\gamma \\text{ (lr)}, \\: \\theta_0 \\text{ (params)}, 11 | \\: f(\\theta) \\text{ (objective)}, \\: \\rho \\text{ (decay)}, 12 | \\: \\lambda \\text{ (weight decay)} \\\\ 13 | &\\textbf{initialize} : v_0 \\leftarrow 0 \\: \\text{ (square avg)}, 14 | \\: u_0 \\leftarrow 0 \\: \\text{ (accumulate variables)} \\\\[-1.ex] 15 | &\\rule{110mm}{0.4pt} \\\\ 16 | &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\ 17 | &\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\ 18 | &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\ 19 | &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled} \\\\ 20 | &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1} \\\\ 21 | &\\hspace{10mm}\\textbf{else} \\\\ 22 | &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\ 23 | &\\hspace{5mm} v_t \\leftarrow v_{t-1} \\rho + g^2_t (1 - \\rho) \\\\ 24 | &\\hspace{5mm}\\Delta x_t \\leftarrow \\frac{\\sqrt{u_{t-1} + 25 | \\epsilon }}{ \\sqrt{v_t + \\epsilon} }g_t \\hspace{21mm} \\\\ 26 | &\\hspace{5mm} u_t \\leftarrow u_{t-1} \\rho + 27 | \\Delta x^2_t (1 - \\rho) \\\\ 28 | &\\hspace{5mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\Delta x_t \\\\ 29 | &\\rule{110mm}{0.4pt} \\\\[-1.ex] 30 | &\\bf{return} \\: \\theta_t \\\\[-1.ex] 31 | &\\rule{110mm}{0.4pt} \\\\[-1.ex] 32 | \\end{aligned} 33 | $$ 34 | */ 35 | 36 | use candle_core::{Result, Var}; 37 | use candle_nn::optim::Optimizer; 38 | 39 | use crate::{Decay, OptimParams}; 40 | 41 | /// Adadelta optimiser 42 | /// 43 | /// Described in [ADADELTA: An Adaptive Learning Rate Method](https://arxiv.org/abs/1212.5701) 44 | #[derive(Debug)] 45 | pub struct Adadelta { 46 | vars: Vec, 47 | params: ParamsAdaDelta, 48 | // avg_acc: HashMap, 49 | } 50 | 51 | #[derive(Debug)] 52 | struct VarAdaDelta { 53 | theta: Var, 54 | v: Var, 55 | u: Var, 56 | } 57 | 58 | /// Parameters for the Adadelta optimiser 59 | #[derive(Clone, Debug, PartialEq, PartialOrd)] 60 | pub struct ParamsAdaDelta { 61 | /// Learning rate 62 | pub lr: f64, 63 | /// Decay 64 | pub rho: f64, 65 | /// Term added to the denominator to improve numerical stability 66 | pub eps: f64, 67 | /// Weight decay 68 | pub weight_decay: Option, 69 | } 70 | 71 | impl Default for ParamsAdaDelta { 72 | fn default() -> Self { 73 | Self { 74 | lr: 1.0, 75 | rho: 0.9, 76 | weight_decay: None, 77 | eps: 1e-6, 78 | } 79 | } 80 | } 81 | 82 | impl Optimizer for Adadelta { 83 | type Config = ParamsAdaDelta; 84 | 85 | fn new(vars: Vec, params: ParamsAdaDelta) -> Result { 86 | let vars = vars 87 | .into_iter() 88 | .filter(|var| var.dtype().is_float()) 89 | .map(|var| { 90 | let dtype = var.dtype(); 91 | let shape = var.shape(); 92 | let device = var.device(); 93 | let v = Var::zeros(shape, dtype, device)?; 94 | let u = Var::zeros(shape, dtype, device)?; 95 | Ok(VarAdaDelta { theta: var, v, u }) 96 | }) 97 | .collect::>>()?; 98 | // // Err(SGDError::NoMomentum)?; 99 | // let mut params = params; 100 | // params.t = 0; 101 | Ok(Self { 102 | vars, 103 | params, 104 | // avg_acc: HashMap::new(), 105 | }) 106 | } 107 | 108 | fn learning_rate(&self) -> f64 { 109 | self.params.lr 110 | } 111 | 112 | fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> { 113 | if let 
Some(decay) = self.params.weight_decay { 114 | match decay { 115 | Decay::WeightDecay(decay) => { 116 | for var in &self.vars { 117 | let theta = &var.theta; 118 | let v = &var.v; 119 | let u = &var.u; 120 | if let Some(grad) = grads.get(theta) { 121 | let grad = &(grad + (decay * theta.as_tensor())?)?; 122 | let v_next = ((v.as_tensor() * self.params.rho)? 123 | + (1. - self.params.rho) * grad.powf(2.)?)?; 124 | let delta_x = (((u.as_tensor() + self.params.eps)?.powf(0.5)?) 125 | .div(&((&v_next + self.params.eps)?.powf(0.5)?))? 126 | * grad)?; 127 | let u_next = ((u.as_tensor() * self.params.rho)? 128 | + (1. - self.params.rho) * delta_x.powf(2.)?)?; 129 | theta.set(&theta.sub(&(delta_x * self.params.lr)?)?)?; 130 | v.set(&v_next)?; 131 | u.set(&u_next)?; 132 | } 133 | } 134 | } 135 | Decay::DecoupledWeightDecay(decay) => { 136 | for var in &self.vars { 137 | let theta = &var.theta; 138 | let v = &var.v; 139 | let u = &var.u; 140 | if let Some(grad) = grads.get(theta) { 141 | // decoupled weight decay step 142 | theta 143 | .set(&(theta.as_tensor() * self.params.lr.mul_add(-decay, 1.))?)?; 144 | let v_next = ((v.as_tensor() * self.params.rho)? 145 | + (1. - self.params.rho) * grad.powf(2.)?)?; 146 | let delta_x = (((u.as_tensor() + self.params.eps)?.powf(0.5)?) 147 | .div(&((&v_next + self.params.eps)?.powf(0.5)?))? 148 | * grad)?; 149 | let u_next = ((u.as_tensor() * self.params.rho)? 150 | + (1. - self.params.rho) * delta_x.powf(2.)?)?; 151 | theta.set(&theta.sub(&(delta_x * self.params.lr)?)?)?; 152 | v.set(&v_next)?; 153 | u.set(&u_next)?; 154 | } 155 | } 156 | } 157 | } 158 | } else { 159 | for var in &self.vars { 160 | let theta = &var.theta; 161 | let v = &var.v; 162 | let u = &var.u; 163 | if let Some(grad) = grads.get(theta) { 164 | let v_next = ((v.as_tensor() * self.params.rho)? 165 | + (1. - self.params.rho) * grad.powf(2.)?)?; 166 | let delta_x = (((u.as_tensor() + self.params.eps)?.powf(0.5)?) 167 | .div(&((&v_next + self.params.eps)?.powf(0.5)?))? 168 | * grad)?; 169 | let u_next = ((u.as_tensor() * self.params.rho)? 170 | + (1. - self.params.rho) * delta_x.powf(2.)?)?; 171 | theta.set(&theta.sub(&(delta_x * self.params.lr)?)?)?; 172 | v.set(&v_next)?; 173 | u.set(&u_next)?; 174 | } 175 | } 176 | } 177 | 178 | Ok(()) 179 | } 180 | 181 | fn set_learning_rate(&mut self, lr: f64) { 182 | self.params.lr = lr; 183 | } 184 | } 185 | 186 | impl OptimParams for Adadelta { 187 | fn params(&self) -> &Self::Config { 188 | &self.params 189 | } 190 | 191 | fn set_params(&mut self, config: Self::Config) { 192 | self.params = config; 193 | } 194 | } 195 | 196 | impl Adadelta { 197 | /// Return the vars being optimised 198 | #[must_use] 199 | pub fn into_inner(self) -> Vec { 200 | self.vars.into_iter().map(|v| v.theta).collect() 201 | } 202 | 203 | // pub fn push(&mut self, var: &Var) { 204 | // self.vars.push(var.clone()); 205 | // } 206 | } 207 | 208 | #[cfg(test)] 209 | mod tests { 210 | // use candle_core::test_utils::{to_vec0_round, to_vec2_round}; 211 | 212 | use anyhow::Result; 213 | use assert_approx_eq::assert_approx_eq; 214 | use candle_core::{Device, Var}; 215 | use candle_nn::Optimizer; 216 | 217 | use super::*; 218 | #[test] 219 | fn lr_test() -> Result<()> { 220 | let params = ParamsAdaDelta { 221 | lr: 0.004, 222 | ..Default::default() 223 | }; 224 | // Now use backprop to run a linear regression between samples and get the coefficients back. 
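        // Illustrative sketch: besides the learning rate checked here, `ParamsAdaDelta`
        // defaults to rho = 0.9, eps = 1e-6 and no weight decay; a decayed variant could be
        // requested with, for example,
        //
        //     let params = ParamsAdaDelta {
        //         lr: 0.004,
        //         weight_decay: Some(Decay::WeightDecay(0.1)),
        //         ..Default::default()
        //     };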
225 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 226 | let b = Var::new(0f32, &Device::Cpu)?; 227 | let mut optim = Adadelta::new(vec![w.clone(), b.clone()], params)?; 228 | assert_approx_eq!(0.004, optim.learning_rate()); 229 | optim.set_learning_rate(0.002); 230 | assert_approx_eq!(0.002, optim.learning_rate()); 231 | Ok(()) 232 | } 233 | 234 | #[test] 235 | fn into_inner_test() -> Result<()> { 236 | let params = ParamsAdaDelta::default(); 237 | let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?; 238 | let b = Var::new(-2f32, &Device::Cpu)?; 239 | let optim = Adadelta::new(vec![w.clone(), b.clone()], params)?; 240 | let inner = optim.into_inner(); 241 | assert_eq!(inner[0].as_tensor().to_vec2::()?, &[[3f32, 1.]]); 242 | assert_approx_eq!(inner[1].as_tensor().to_vec0::()?, -2_f32); 243 | Ok(()) 244 | } 245 | 246 | #[test] 247 | fn params_test() -> Result<()> { 248 | let params = ParamsAdaDelta { 249 | lr: 0.004, 250 | ..Default::default() 251 | }; 252 | // Now use backprop to run a linear regression between samples and get the coefficients back. 253 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 254 | let b = Var::new(0f32, &Device::Cpu)?; 255 | let mut optim = Adadelta::new(vec![w.clone(), b.clone()], params.clone())?; 256 | assert_eq!(params, optim.params().clone()); 257 | let new_params = ParamsAdaDelta { 258 | lr: 0.002, 259 | ..Default::default() 260 | }; 261 | optim.set_params(new_params.clone()); 262 | assert_eq!(new_params, optim.params().clone()); 263 | Ok(()) 264 | } 265 | } 266 | -------------------------------------------------------------------------------- /src/nadam.rs: -------------------------------------------------------------------------------- 1 | /*! 2 | NAdam optimiser: Adam with Nesterov momentum 3 | 4 | Described in [Incorporating Nesterov Momentum into Adam](https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ) 5 | 6 | Pseudocode (including decoupling of weight decay): 7 | 8 | $$ 9 | \\begin{aligned} 10 | &\\rule{110mm}{0.4pt} \\\\ 11 | &\\textbf{input} : \\gamma_t \\text{ (lr)}, \\: \\beta_1,\\beta_2 \\text{ (betas)}, 12 | \\: \\theta_0 \\text{ (params)}, \\: f(\\theta) \\text{ (objective)} \\\\ 13 | &\\hspace{12mm} \\: \\lambda \\text{ (weight decay)}, \\:\\psi \\text{ (momentum decay)} \\\\ 14 | &\\textbf{initialize} : m_0 \\leftarrow 0 \\text{ ( first moment)}, 15 | v_0 \\leftarrow 0 \\text{ ( second moment)} \\\\[-1.ex] 16 | &\\rule{110mm}{0.4pt} \\\\ 17 | &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\ 18 | &\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\ 19 | &\\hspace{5mm} \\theta_t \\leftarrow \\theta_{t-1} \\\\ 20 | &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\ 21 | &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled} \\\\ 22 | &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1} \\\\ 23 | &\\hspace{10mm}\\textbf{else} \\\\ 24 | &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\ 25 | &\\hspace{5mm} \\mu_t \\leftarrow \\beta_1 \\big(1 - \\frac{1}{2} 0.96^{t \\psi} \\big) \\\\ 26 | &\\hspace{5mm} \\mu_{t+1} \\leftarrow \\beta_1 \\big(1 - \\frac{1}{2} 0.96^{(t+1)\\psi}\\big)\\\\ 27 | &\\hspace{5mm}m_t \\leftarrow \\beta_1 m_{t-1} + (1 - \\beta_1) g_t \\\\ 28 | &\\hspace{5mm}v_t \\leftarrow \\beta_2 v_{t-1} + (1-\\beta_2) g^2_t \\\\ 29 | &\\hspace{5mm}\\widehat{m_t} \\leftarrow \\mu_{t+1} m_t/(1-\\prod_{i=1}^{t+1}\\mu_i)\\\\[-1.ex] 30 | & \\hspace{11mm} + (1-\\mu_t) g_t /(1-\\prod_{i=1}^{t} \\mu_{i}) \\\\ 31 | 
&\\hspace{5mm}\\widehat{v_t} \\leftarrow v_t/\\big(1-\\beta_2^t \\big) \\\\ 32 | &\\hspace{5mm}\\theta_t \\leftarrow \\theta_t - \\gamma \\widehat{m_t}/ 33 | \\big(\\sqrt{\\widehat{v_t}} + \\epsilon \\big) \\\\ 34 | &\\rule{110mm}{0.4pt} \\\\[-1.ex] 35 | &\\bf{return} \\: \\theta_t \\\\[-1.ex] 36 | &\\rule{110mm}{0.4pt} \\\\[-1.ex] 37 | \\end{aligned} 38 | $$ 39 | */ 40 | 41 | use candle_core::{Result, Var}; 42 | use candle_nn::optim::Optimizer; 43 | 44 | use crate::{Decay, OptimParams}; 45 | 46 | /// Adam optimiser with Nesterov momentum 47 | /// 48 | /// Described in 49 | #[derive(Debug)] 50 | pub struct NAdam { 51 | vars: Vec, 52 | params: ParamsNAdam, 53 | mu_t: f64, 54 | mu_t2: f64, 55 | prod: f64, 56 | prod2: f64, 57 | t: f64, 58 | } 59 | 60 | #[derive(Debug)] 61 | struct VarNAdam { 62 | theta: Var, 63 | m: Var, 64 | v: Var, 65 | } 66 | 67 | /// Parameters for The NAdam optimiser 68 | #[derive(Clone, Debug, PartialEq, PartialOrd)] 69 | pub struct ParamsNAdam { 70 | /// Learning rate 71 | pub lr: f64, 72 | /// Coefficient for moving average of first moment 73 | pub beta_1: f64, 74 | /// Coefficient for moving average of second moment 75 | pub beta_2: f64, 76 | /// Term added to denominator to improve numerical stability 77 | pub eps: f64, 78 | /// Weight decay 79 | pub weight_decay: Option, 80 | /// Momentum decay 81 | pub momentum_decay: f64, 82 | } 83 | 84 | impl Default for ParamsNAdam { 85 | fn default() -> Self { 86 | Self { 87 | lr: 0.002, 88 | beta_1: 0.9, 89 | beta_2: 0.999, 90 | eps: 1e-8, 91 | weight_decay: None, 92 | momentum_decay: 0.004, 93 | } 94 | } 95 | } 96 | 97 | impl Optimizer for NAdam { 98 | type Config = ParamsNAdam; 99 | 100 | fn new(vars: Vec, params: ParamsNAdam) -> Result { 101 | let vars = vars 102 | .into_iter() 103 | .filter(|var| var.dtype().is_float()) 104 | .map(|var| { 105 | let dtype = var.dtype(); 106 | let shape = var.shape(); 107 | let device = var.device(); 108 | let m = Var::zeros(shape, dtype, device)?; 109 | let v = Var::zeros(shape, dtype, device)?; 110 | Ok(VarNAdam { theta: var, m, v }) 111 | }) 112 | .collect::>>()?; 113 | // // Err(SGDError::NoMomentum)?; 114 | // let mut params = params; 115 | // params.t = 0; 116 | let t = 1.; 117 | let mu_t2 = params.beta_1 * 0.5f64.mul_add(-(0.96_f64.powf(t * params.momentum_decay)), 1.); 118 | Ok(Self { 119 | vars, 120 | params, 121 | t: 1., 122 | mu_t: 1., 123 | mu_t2, 124 | prod: 1., 125 | prod2: mu_t2, 126 | }) 127 | } 128 | 129 | fn learning_rate(&self) -> f64 { 130 | self.params.lr 131 | } 132 | 133 | fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> { 134 | let mu_t = self.mu_t2; 135 | let mu_t2 = self.params.beta_1 136 | * 0.5f64.mul_add( 137 | -(0.96_f64.powf((self.t + 1.) * self.params.momentum_decay)), 138 | 1., 139 | ); 140 | let prod = self.prod2; 141 | let prod2 = prod * mu_t2; 142 | self.mu_t = mu_t; 143 | self.mu_t2 = mu_t2; 144 | self.prod = prod; 145 | self.prod2 = prod2; 146 | // println!("prod {}", prod); 147 | 148 | if let Some(decay) = self.params.weight_decay { 149 | match decay { 150 | Decay::WeightDecay(decay) => { 151 | for var in &self.vars { 152 | let theta = &var.theta; 153 | let m = &var.m; 154 | let v = &var.v; 155 | if let Some(grad) = grads.get(theta) { 156 | let grad = &(grad + (decay * theta.as_tensor())?)?; 157 | let m_next = ((self.params.beta_1 * m.as_tensor())? 158 | + ((1. - self.params.beta_1) * grad)?)?; 159 | let v_next = ((self.params.beta_2 * v.as_tensor())? 160 | + ((1. 
- self.params.beta_2) * grad.powf(2.)?)?)?; 161 | let m_hat = (((mu_t2 / (1. - prod2)) * &m_next)? 162 | + (((1. - mu_t) / (1. - prod)) * grad)?)?; 163 | let v_hat = (&v_next / (1. - self.params.beta_2.powf(self.t)))?; 164 | let delta = (m_hat * self.params.lr)? 165 | .div(&(v_hat.powf(0.5)? + self.params.eps)?)?; 166 | theta.set(&theta.sub(&(delta))?)?; 167 | m.set(&m_next)?; 168 | v.set(&v_next)?; 169 | } 170 | } 171 | } 172 | Decay::DecoupledWeightDecay(decay) => { 173 | for var in &self.vars { 174 | let theta = &var.theta; 175 | let m = &var.m; 176 | let v = &var.v; 177 | if let Some(grad) = grads.get(theta) { 178 | theta 179 | .set(&(theta.as_tensor() * self.params.lr.mul_add(-decay, 1.))?)?; 180 | let m_next = ((self.params.beta_1 * m.as_tensor())? 181 | + ((1. - self.params.beta_1) * grad)?)?; 182 | let v_next = ((self.params.beta_2 * v.as_tensor())? 183 | + ((1. - self.params.beta_2) * grad.powf(2.)?)?)?; 184 | let m_hat = (((mu_t2 / (1. - prod2)) * &m_next)? 185 | + (((1. - mu_t) / (1. - prod)) * grad)?)?; 186 | let v_hat = (&v_next / (1. - self.params.beta_2.powf(self.t)))?; 187 | let delta = (m_hat * self.params.lr)? 188 | .div(&(v_hat.powf(0.5)? + self.params.eps)?)?; 189 | theta.set(&theta.sub(&(delta))?)?; 190 | m.set(&m_next)?; 191 | v.set(&v_next)?; 192 | } 193 | } 194 | } 195 | } 196 | } else { 197 | for var in &self.vars { 198 | let theta = &var.theta; 199 | let m = &var.m; 200 | let v = &var.v; 201 | if let Some(grad) = grads.get(theta) { 202 | let m_next = ((self.params.beta_1 * m.as_tensor())? 203 | + ((1. - self.params.beta_1) * grad)?)?; 204 | let v_next = ((self.params.beta_2 * v.as_tensor())? 205 | + ((1. - self.params.beta_2) * grad.powf(2.)?)?)?; 206 | let m_hat = (((mu_t2 / (1. - prod2)) * &m_next)? 207 | + (((1. - mu_t) / (1. - prod)) * grad)?)?; 208 | let v_hat = (&v_next / (1. - self.params.beta_2.powf(self.t)))?; 209 | let delta = 210 | (m_hat * self.params.lr)?.div(&(v_hat.powf(0.5)? + self.params.eps)?)?; 211 | theta.set(&theta.sub(&(delta))?)?; 212 | m.set(&m_next)?; 213 | v.set(&v_next)?; 214 | } 215 | } 216 | } 217 | 218 | self.t += 1.; 219 | Ok(()) 220 | } 221 | 222 | fn set_learning_rate(&mut self, lr: f64) { 223 | self.params.lr = lr; 224 | } 225 | } 226 | 227 | impl OptimParams for NAdam { 228 | fn params(&self) -> &Self::Config { 229 | &self.params 230 | } 231 | 232 | fn set_params(&mut self, config: Self::Config) { 233 | self.params = config; 234 | } 235 | } 236 | 237 | impl NAdam { 238 | /// Return the vars being optimised 239 | #[must_use] 240 | pub fn into_inner(self) -> Vec { 241 | self.vars.into_iter().map(|v| v.theta).collect() 242 | } 243 | 244 | // pub fn push(&mut self, var: &Var) { 245 | // self.vars.push(var.clone()); 246 | // } 247 | } 248 | 249 | #[cfg(test)] 250 | mod tests { 251 | // use candle_core::test_utils::{to_vec0_round, to_vec2_round}; 252 | 253 | use anyhow::Result; 254 | use assert_approx_eq::assert_approx_eq; 255 | use candle_core::{Device, Var}; 256 | use candle_nn::Optimizer; 257 | 258 | use super::*; 259 | #[test] 260 | fn lr_test() -> Result<()> { 261 | let params = ParamsNAdam { 262 | lr: 0.004, 263 | ..Default::default() 264 | }; 265 | // Now use backprop to run a linear regression between samples and get the coefficients back. 
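// Note: with the defaults beta_1 = 0.9 and momentum_decay (psi) = 0.004, the Nesterov
// coefficient mu_t = beta_1 * (1 - 0.5 * 0.96^(t * psi)) starts at roughly 0.45 at t = 1
// and ramps slowly towards beta_1 as t grows.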
266 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 267 | let b = Var::new(0f32, &Device::Cpu)?; 268 | let mut optim = NAdam::new(vec![w.clone(), b.clone()], params)?; 269 | assert_approx_eq!(0.004, optim.learning_rate()); 270 | optim.set_learning_rate(0.002); 271 | assert_approx_eq!(0.002, optim.learning_rate()); 272 | Ok(()) 273 | } 274 | 275 | #[test] 276 | fn into_inner_test() -> Result<()> { 277 | let params = ParamsNAdam::default(); 278 | let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?; 279 | let b = Var::new(-2f32, &Device::Cpu)?; 280 | let optim = NAdam::new(vec![w.clone(), b.clone()], params)?; 281 | let inner = optim.into_inner(); 282 | assert_eq!(inner[0].as_tensor().to_vec2::()?, &[[3f32, 1.]]); 283 | assert_approx_eq!(inner[1].as_tensor().to_vec0::()?, -2_f32); 284 | Ok(()) 285 | } 286 | 287 | #[test] 288 | fn params_test() -> Result<()> { 289 | let params = ParamsNAdam { 290 | lr: 0.004, 291 | ..Default::default() 292 | }; 293 | // Now use backprop to run a linear regression between samples and get the coefficients back. 294 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 295 | let b = Var::new(0f32, &Device::Cpu)?; 296 | let mut optim = NAdam::new(vec![w.clone(), b.clone()], params.clone())?; 297 | assert_eq!(params, optim.params().clone()); 298 | let new_params = ParamsNAdam { 299 | lr: 0.002, 300 | ..Default::default() 301 | }; 302 | optim.set_params(new_params.clone()); 303 | assert_eq!(new_params, optim.params().clone()); 304 | Ok(()) 305 | } 306 | } 307 | -------------------------------------------------------------------------------- /tests/adam_tests.rs: -------------------------------------------------------------------------------- 1 | use candle_core::test_utils::{to_vec0_round, to_vec2_round}; 2 | 3 | use anyhow::Result; 4 | use candle_core::{Device, Tensor, Var}; 5 | use candle_nn::{Linear, Module, Optimizer}; 6 | use candle_optimisers::{ 7 | adam::{Adam, ParamsAdam}, 8 | Decay, 9 | }; 10 | 11 | /* The results of this test have been checked against the following PyTorch code. 12 | import torch 13 | from torch import optim 14 | 15 | w_gen = torch.tensor([[3., 1.]]) 16 | b_gen = torch.tensor([-2.]) 17 | 18 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 19 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 20 | 21 | m = torch.nn.Linear(2, 1) 22 | with torch.no_grad(): 23 | m.weight.zero_() 24 | m.bias.zero_() 25 | optimiser = optim.Adam(m.parameters()) 26 | # optimiser.zero_grad() 27 | for _step in range(1000): 28 | optimiser.zero_grad() 29 | ys = m(sample_xs) 30 | loss = ((ys - sample_ys)**2).sum() 31 | loss.backward() 32 | optimiser.step() 33 | # print("Optimizer state begin") 34 | # print(optimiser.state) 35 | # print("Optimizer state end") 36 | print(m.weight) 37 | print(m.bias) 38 | */ 39 | #[test] 40 | fn adam_test() -> Result<()> { 41 | // Generate some linear data, y = 3.x1 + x2 - 2. 42 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 43 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 44 | let gen = Linear::new(w_gen, Some(b_gen)); 45 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 46 | let sample_ys = gen.forward(&sample_xs)?; 47 | 48 | let params = ParamsAdam::default(); 49 | // Now use backprop to run a linear regression between samples and get the coefficients back. 
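// A minimal sketch of the lower-level API, assuming only the imports at the top of this
// file: an optimiser can be fed an explicit GradStore via `step` instead of using
// `backward_step`. The names `p` and `manual` are illustrative placeholders.
{
    let p = Var::new(&[1f32, 2.], &Device::Cpu)?;
    let mut manual = Adam::new(vec![p.clone()], ParamsAdam::default())?;
    // one hand-rolled step on the scalar loss |p|^2
    let loss = p.as_tensor().sqr()?.sum_all()?;
    let grads = loss.backward()?;
    manual.step(&grads)?;
}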
50 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 51 | let b = Var::new(0f32, &Device::Cpu)?; 52 | let mut n_sgd = Adam::new(vec![w.clone(), b.clone()], params)?; 53 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 54 | for _step in 0..1000 { 55 | let ys = lin.forward(&sample_xs)?; 56 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 57 | n_sgd.backward_step(&loss)?; 58 | } 59 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.9000, 0.6967]]); 60 | assert_eq!(to_vec0_round(&b, 4)?, 0.7996); 61 | Ok(()) 62 | } 63 | 64 | /* The results of this test have been checked against the following PyTorch code. 65 | import torch 66 | from torch import optim 67 | 68 | w_gen = torch.tensor([[3., 1.]]) 69 | b_gen = torch.tensor([-2.]) 70 | 71 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 72 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 73 | 74 | m = torch.nn.Linear(2, 1) 75 | with torch.no_grad(): 76 | m.weight.zero_() 77 | m.bias.zero_() 78 | optimiser = optim.Adam(m.parameters(), weight_decay = 0.6) 79 | # optimiser.zero_grad() 80 | for _step in range(1000): 81 | optimiser.zero_grad() 82 | ys = m(sample_xs) 83 | loss = ((ys - sample_ys)**2).sum() 84 | loss.backward() 85 | optimiser.step() 86 | # print("Optimizer state begin") 87 | # print(optimiser.state) 88 | # print("Optimizer state end") 89 | print(m.weight) 90 | print(m.bias) 91 | */ 92 | #[test] 93 | fn adam_weight_decay_test() -> Result<()> { 94 | // Generate some linear data, y = 3.x1 + x2 - 2. 95 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 96 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 97 | let gen = Linear::new(w_gen, Some(b_gen)); 98 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 99 | let sample_ys = gen.forward(&sample_xs)?; 100 | 101 | let params = ParamsAdam { 102 | weight_decay: Some(Decay::WeightDecay(0.6)), 103 | ..Default::default() 104 | }; 105 | // Now use backprop to run a linear regression between samples and get the coefficients back. 106 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 107 | let b = Var::new(0f32, &Device::Cpu)?; 108 | let mut n_sgd = Adam::new(vec![w.clone(), b.clone()], params)?; 109 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 110 | for _step in 0..1000 { 111 | let ys = lin.forward(&sample_xs)?; 112 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 113 | n_sgd.backward_step(&loss)?; 114 | } 115 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.8997, 0.6964]]); 116 | assert_eq!(to_vec0_round(&b, 4)?, 0.7975); 117 | Ok(()) 118 | } 119 | 120 | /* The results of this test have been checked against the following PyTorch code. 
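(Unlike the plain Adam run above, optim.AdamW decouples the weight decay from the adaptive
update: the decay is applied directly to the parameters rather than added to the gradient.
The Rust test below mirrors this by passing Decay::DecoupledWeightDecay(0.6).)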
121 | import torch 122 | from torch import optim 123 | 124 | w_gen = torch.tensor([[3., 1.]]) 125 | b_gen = torch.tensor([-2.]) 126 | 127 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 128 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 129 | 130 | m = torch.nn.Linear(2, 1) 131 | with torch.no_grad(): 132 | m.weight.zero_() 133 | m.bias.zero_() 134 | optimiser = optim.AdamW(m.parameters(), weight_decay = 0.6) 135 | # optimiser.zero_grad() 136 | for _step in range(1000): 137 | optimiser.zero_grad() 138 | ys = m(sample_xs) 139 | loss = ((ys - sample_ys)**2).sum() 140 | loss.backward() 141 | optimiser.step() 142 | # print("Optimizer state begin") 143 | # print(optimiser.state) 144 | # print("Optimizer state end") 145 | print(m.weight) 146 | print(m.bias) 147 | */ 148 | #[test] 149 | fn adamw_weight_decay_test() -> Result<()> { 150 | // Generate some linear data, y = 3.x1 + x2 - 2. 151 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 152 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 153 | let gen = Linear::new(w_gen, Some(b_gen)); 154 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 155 | let sample_ys = gen.forward(&sample_xs)?; 156 | 157 | let params = ParamsAdam { 158 | weight_decay: Some(Decay::DecoupledWeightDecay(0.6)), 159 | // decoupled_weight_decay: true, 160 | ..Default::default() 161 | }; 162 | // Now use backprop to run a linear regression between samples and get the coefficients back. 163 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 164 | let b = Var::new(0f32, &Device::Cpu)?; 165 | let mut n_sgd = Adam::new(vec![w.clone(), b.clone()], params)?; 166 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 167 | for _step in 0..1000 { 168 | let ys = lin.forward(&sample_xs)?; 169 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 170 | n_sgd.backward_step(&loss)?; 171 | } 172 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.6901, 0.5677]]); 173 | assert_eq!(to_vec0_round(&b, 4)?, 0.6287); 174 | Ok(()) 175 | } 176 | 177 | /* The results of this test have been checked against the following PyTorch code. 178 | import torch 179 | from torch import optim 180 | 181 | w_gen = torch.tensor([[3., 1.]]) 182 | b_gen = torch.tensor([-2.]) 183 | 184 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 185 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 186 | 187 | m = torch.nn.Linear(2, 1) 188 | with torch.no_grad(): 189 | m.weight.zero_() 190 | m.bias.zero_() 191 | optimiser = optim.Adam(m.parameters(), amsgrad=True) 192 | # optimiser.zero_grad() 193 | for _step in range(1000): 194 | optimiser.zero_grad() 195 | ys = m(sample_xs) 196 | loss = ((ys - sample_ys)**2).sum() 197 | loss.backward() 198 | optimiser.step() 199 | # print("Optimizer state begin") 200 | # print(optimiser.state) 201 | # print("Optimizer state end") 202 | print(m.weight) 203 | print(m.bias) 204 | */ 205 | #[test] 206 | fn adam_amsgrad_test() -> Result<()> { 207 | // Generate some linear data, y = 3.x1 + x2 - 2. 208 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 209 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 210 | let gen = Linear::new(w_gen, Some(b_gen)); 211 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 212 | let sample_ys = gen.forward(&sample_xs)?; 213 | 214 | let params = ParamsAdam { 215 | amsgrad: true, 216 | ..Default::default() 217 | }; 218 | // Now use backprop to run a linear regression between samples and get the coefficients back. 
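// Note: with amsgrad = true the denominator uses the running element-wise maximum of the
// second-moment estimate rather than the current estimate, so the per-coordinate step size
// can only shrink over time (Reddi et al., "On the Convergence of Adam and Beyond").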
219 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 220 | let b = Var::new(0f32, &Device::Cpu)?; 221 | let mut n_sgd = Adam::new(vec![w.clone(), b.clone()], params)?; 222 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 223 | for _step in 0..1000 { 224 | let ys = lin.forward(&sample_xs)?; 225 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 226 | n_sgd.backward_step(&loss)?; 227 | } 228 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.9001, 0.6904]]); 229 | assert_eq!(to_vec0_round(&b, 4)?, 0.7978); 230 | Ok(()) 231 | } 232 | 233 | /* The results of this test have been checked against the following PyTorch code. 234 | import torch 235 | from torch import optim 236 | 237 | w_gen = torch.tensor([[3., 1.]]) 238 | b_gen = torch.tensor([-2.]) 239 | 240 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 241 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 242 | 243 | m = torch.nn.Linear(2, 1) 244 | with torch.no_grad(): 245 | m.weight.zero_() 246 | m.bias.zero_() 247 | optimiser = optim.Adam(m.parameters(), amsgrad=True, weight_decay = 0.6) 248 | # optimiser.zero_grad() 249 | for _step in range(1000): 250 | optimiser.zero_grad() 251 | ys = m(sample_xs) 252 | loss = ((ys - sample_ys)**2).sum() 253 | loss.backward() 254 | optimiser.step() 255 | # print("Optimizer state begin") 256 | # print(optimiser.state) 257 | # print("Optimizer state end") 258 | print(m.weight) 259 | print(m.bias) 260 | */ 261 | #[test] 262 | fn adam_amsgrad_decay_test() -> Result<()> { 263 | // Generate some linear data, y = 3.x1 + x2 - 2. 264 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 265 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 266 | let gen = Linear::new(w_gen, Some(b_gen)); 267 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 268 | let sample_ys = gen.forward(&sample_xs)?; 269 | 270 | let params = ParamsAdam { 271 | amsgrad: true, 272 | weight_decay: Some(Decay::WeightDecay(0.6)), 273 | ..Default::default() 274 | }; 275 | // Now use backprop to run a linear regression between samples and get the coefficients back. 276 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 277 | let b = Var::new(0f32, &Device::Cpu)?; 278 | let mut n_sgd = Adam::new(vec![w.clone(), b.clone()], params)?; 279 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 280 | for _step in 0..1000 { 281 | let ys = lin.forward(&sample_xs)?; 282 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 283 | n_sgd.backward_step(&loss)?; 284 | } 285 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.8998, 0.6901]]); 286 | assert_eq!(to_vec0_round(&b, 4)?, 0.7955); 287 | Ok(()) 288 | } 289 | 290 | /* The results of this test have been checked against the following PyTorch code. 
291 | import torch 292 | from torch import optim 293 | 294 | w_gen = torch.tensor([[3., 1.]]) 295 | b_gen = torch.tensor([-2.]) 296 | 297 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 298 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 299 | 300 | m = torch.nn.Linear(2, 1) 301 | with torch.no_grad(): 302 | m.weight.zero_() 303 | m.bias.zero_() 304 | optimiser = optim.AdamW(m.parameters(), amsgrad=True, weight_decay = 0.6) 305 | # optimiser.zero_grad() 306 | for _step in range(1000): 307 | optimiser.zero_grad() 308 | ys = m(sample_xs) 309 | loss = ((ys - sample_ys)**2).sum() 310 | loss.backward() 311 | optimiser.step() 312 | # print("Optimizer state begin") 313 | # print(optimiser.state) 314 | # print("Optimizer state end") 315 | print(m.weight) 316 | print(m.bias) 317 | */ 318 | #[test] 319 | fn adamw_amsgrad_decay_test() -> Result<()> { 320 | // Generate some linear data, y = 3.x1 + x2 - 2. 321 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 322 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 323 | let gen = Linear::new(w_gen, Some(b_gen)); 324 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 325 | let sample_ys = gen.forward(&sample_xs)?; 326 | 327 | let params = ParamsAdam { 328 | weight_decay: Some(Decay::DecoupledWeightDecay(0.6)), 329 | amsgrad: true, 330 | // decoupled_weight_decay: true, 331 | ..Default::default() 332 | }; 333 | // Now use backprop to run a linear regression between samples and get the coefficients back. 334 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 335 | let b = Var::new(0f32, &Device::Cpu)?; 336 | let mut n_sgd = Adam::new(vec![w.clone(), b.clone()], params)?; 337 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 338 | for _step in 0..1000 { 339 | let ys = lin.forward(&sample_xs)?; 340 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 341 | n_sgd.backward_step(&loss)?; 342 | } 343 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.6901, 0.5648]]); 344 | assert_eq!(to_vec0_round(&b, 4)?, 0.6287); 345 | Ok(()) 346 | } 347 | -------------------------------------------------------------------------------- /src/radam.rs: -------------------------------------------------------------------------------- 1 | /*! 
2 | RAdam optimiser 3 | 4 | Described in [On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265) 5 | 6 | As decoupled weight decay is implemented, this can be used equivalent to the paper (which uses decoupled weight decay), 7 | or the PyTorch implementation (which does not) 8 | 9 | Pseudocode (including decoupling of weight decay): 10 | 11 | $$ 12 | \\begin{aligned} 13 | &\\rule{110mm}{0.4pt} \\\\ 14 | &\\textbf{input} : \\gamma \\text{ (lr)}, \\: \\beta_1, \\beta_2 15 | \\text{ (betas)}, \\: \\theta_0 \\text{ (params)}, \\:f(\\theta) \\text{ (objective)}, \\: 16 | \\lambda \\text{ (weightdecay)}, \\\\ 17 | &\\hspace{13mm} \\epsilon \\text{ (epsilon)} \\\\ 18 | &\\textbf{initialize} : m_0 \\leftarrow 0 \\text{ ( first moment)}, 19 | v_0 \\leftarrow 0 \\text{ ( second moment)}, \\\\ 20 | &\\hspace{18mm} \\rho_{\\infty} \\leftarrow 2/(1-\\beta_2) -1 \\\\[-1.ex] 21 | &\\rule{110mm}{0.4pt} \\\\ 22 | &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\ 23 | &\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\ 24 | &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\ 25 | &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled} \\\\ 26 | &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1} \\\\ 27 | &\\hspace{10mm}\\textbf{else} \\\\ 28 | &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\ 29 | &\\hspace{5mm}m_t \\leftarrow \\beta_1 m_{t-1} + (1 - \\beta_1) g_t \\\\ 30 | &\\hspace{5mm}v_t \\leftarrow \\beta_2 v_{t-1} + (1-\\beta_2) g^2_t \\\\ 31 | &\\hspace{5mm}\\widehat{m_t} \\leftarrow m_t/\\big(1-\\beta_1^t \\big) \\\\ 32 | &\\hspace{5mm}\\rho_t \\leftarrow \\rho_{\\infty} - 33 | 2 t \\beta^t_2 /\\big(1-\\beta_2^t \\big) \\\\[0.1.ex] 34 | &\\hspace{5mm}\\textbf{if} \\: \\rho_t > 5 \\\\ 35 | &\\hspace{10mm} l_t \\leftarrow \\frac{\\sqrt{ (1-\\beta^t_2) }}{ \\sqrt{v_t} +\\epsilon } \\\\ 36 | &\\hspace{10mm} r_t \\leftarrow 37 | \\sqrt{\\frac{(\\rho_t-4)(\\rho_t-2)\\rho_{\\infty}}{(\\rho_{\\infty}-4)(\\rho_{\\infty}-2) \\rho_t}} \\\\ 38 | &\\hspace{10mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\widehat{m_t} r_t l_t \\\\ 39 | &\\hspace{5mm}\\textbf{else} \\\\ 40 | &\\hspace{10mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\widehat{m_t} \\\\ 41 | &\\rule{110mm}{0.4pt} \\\\[-1.ex] 42 | &\\bf{return} \\: \\theta_t \\\\[-1.ex] 43 | &\\rule{110mm}{0.4pt} \\\\[-1.ex] 44 | \\end{aligned} 45 | $$ 46 | */ 47 | 48 | use candle_core::{Result, Var}; 49 | use candle_nn::optim::Optimizer; 50 | 51 | use crate::{Decay, OptimParams}; 52 | 53 | /// R Adam optimiser 54 | /// 55 | /// Described in [On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265) 56 | 57 | #[derive(Debug)] 58 | pub struct RAdam { 59 | vars: Vec, 60 | params: ParamsRAdam, 61 | rho_inf: f64, 62 | t: f64, 63 | } 64 | 65 | #[derive(Debug)] 66 | struct VarRAdam { 67 | theta: Var, 68 | m: Var, 69 | v: Var, 70 | } 71 | 72 | /// Parameters for the RAdam optimiser 73 | #[derive(Clone, Debug, PartialEq, PartialOrd)] 74 | pub struct ParamsRAdam { 75 | /// Learning rate 76 | pub lr: f64, 77 | /// Coefficient for moving average of first moment 78 | pub beta_1: f64, 79 | /// Coefficient for moving average of second moment 80 | pub beta_2: f64, 81 | /// Weight decay 82 | pub weight_decay: Option, 83 | /// Term added to denominator to improve numerical stability 84 | pub eps: f64, 85 | } 86 | 87 | impl Default for ParamsRAdam { 88 | fn default() -> Self { 89 | Self { 90 | lr: 
0.001, 91 | beta_1: 0.9, 92 | beta_2: 0.999, 93 | eps: 1e-8, 94 | weight_decay: None, 95 | } 96 | } 97 | } 98 | 99 | impl Optimizer for RAdam { 100 | type Config = ParamsRAdam; 101 | 102 | fn new(vars: Vec, params: ParamsRAdam) -> Result { 103 | let vars = vars 104 | .into_iter() 105 | .filter(|var| var.dtype().is_float()) 106 | .map(|var| { 107 | let dtype = var.dtype(); 108 | let shape = var.shape(); 109 | let device = var.device(); 110 | let m = Var::zeros(shape, dtype, device)?; 111 | let v = Var::zeros(shape, dtype, device)?; 112 | Ok(VarRAdam { theta: var, m, v }) 113 | }) 114 | .collect::>>()?; 115 | // // Err(SGDError::NoMomentum)?; 116 | // let mut params = params; 117 | // params.t = 0; 118 | let rho_inf = 2. / (1. - params.beta_2) - 1.; 119 | Ok(Self { 120 | vars, 121 | params, 122 | rho_inf, 123 | t: 1., 124 | }) 125 | } 126 | 127 | fn learning_rate(&self) -> f64 { 128 | self.params.lr 129 | } 130 | 131 | fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> { 132 | // println!("prod {}", prod); 133 | let rho_t = self.rho_inf 134 | - 2. * self.t * self.params.beta_2.powf(self.t) 135 | / (1. - self.params.beta_2.powf(self.t)); 136 | 137 | if let Some(wd) = self.params.weight_decay { 138 | match wd { 139 | Decay::WeightDecay(wd) => { 140 | for var in &self.vars { 141 | let theta = &var.theta; 142 | let m = &var.m; 143 | let v = &var.v; 144 | if let Some(grad) = grads.get(theta) { 145 | let grad = &(grad + (wd * theta.as_tensor())?)?; 146 | let m_next = ((self.params.beta_1 * m.as_tensor())? 147 | + ((1. - self.params.beta_1) * grad)?)?; 148 | let v_next = ((self.params.beta_2 * v.as_tensor())? 149 | + ((1. - self.params.beta_2) * grad.powf(2.)?)?)?; 150 | let m_hat = (&m_next / (1. - self.params.beta_1.powf(self.t)))?; 151 | 152 | let delta = if rho_t > 5. { 153 | let l = ((1. - self.params.beta_2.powf(self.t)).sqrt() 154 | / (&v_next.sqrt()? + self.params.eps)?)?; 155 | let r = ((rho_t - 4.) * (rho_t - 2.) * self.rho_inf 156 | / ((self.rho_inf - 4.) * (self.rho_inf - 2.) * rho_t)) 157 | .sqrt(); 158 | (self.params.lr * r * (l * m_hat)?)? 159 | } else { 160 | (self.params.lr * m_hat)? 161 | }; 162 | theta.set(&theta.sub(&(delta))?)?; 163 | m.set(&m_next)?; 164 | v.set(&v_next)?; 165 | } 166 | } 167 | } 168 | Decay::DecoupledWeightDecay(decay) => { 169 | for var in &self.vars { 170 | let theta = &var.theta; 171 | let m = &var.m; 172 | let v = &var.v; 173 | if let Some(grad) = grads.get(theta) { 174 | // decoupled weight decay step 175 | theta 176 | .set(&(theta.as_tensor() * self.params.lr.mul_add(-decay, 1.))?)?; 177 | let m_next = ((self.params.beta_1 * m.as_tensor())? 178 | + ((1. - self.params.beta_1) * grad)?)?; 179 | let v_next = ((self.params.beta_2 * v.as_tensor())? 180 | + ((1. - self.params.beta_2) * grad.powf(2.)?)?)?; 181 | let m_hat = (&m_next / (1. - self.params.beta_1.powf(self.t)))?; 182 | 183 | let delta = if rho_t > 5. { 184 | let l = ((1. - self.params.beta_2.powf(self.t)).sqrt() 185 | / (&v_next.sqrt()? + self.params.eps)?)?; 186 | let r = ((rho_t - 4.) * (rho_t - 2.) * self.rho_inf 187 | / ((self.rho_inf - 4.) * (self.rho_inf - 2.) * rho_t)) 188 | .sqrt(); 189 | (self.params.lr * r * (l * m_hat)?)? 190 | } else { 191 | (self.params.lr * m_hat)? 
192 | }; 193 | theta.set(&theta.sub(&(delta))?)?; 194 | m.set(&m_next)?; 195 | v.set(&v_next)?; 196 | } 197 | } 198 | } 199 | } 200 | } else { 201 | for var in &self.vars { 202 | let theta = &var.theta; 203 | let m = &var.m; 204 | let v = &var.v; 205 | if let Some(grad) = grads.get(theta) { 206 | let m_next = ((self.params.beta_1 * m.as_tensor())? 207 | + ((1. - self.params.beta_1) * grad)?)?; 208 | let v_next = ((self.params.beta_2 * v.as_tensor())? 209 | + ((1. - self.params.beta_2) * grad.powf(2.)?)?)?; 210 | let m_hat = (&m_next / (1. - self.params.beta_1.powf(self.t)))?; 211 | 212 | let delta = if rho_t > 5. { 213 | let l = ((1. - self.params.beta_2.powf(self.t)).sqrt() 214 | / (&v_next.sqrt()? + self.params.eps)?)?; 215 | let r = ((rho_t - 4.) * (rho_t - 2.) * self.rho_inf 216 | / ((self.rho_inf - 4.) * (self.rho_inf - 2.) * rho_t)) 217 | .sqrt(); 218 | (self.params.lr * r * (l * m_hat)?)? 219 | } else { 220 | (self.params.lr * m_hat)? 221 | }; 222 | theta.set(&theta.sub(&(delta))?)?; 223 | m.set(&m_next)?; 224 | v.set(&v_next)?; 225 | } 226 | } 227 | } 228 | 229 | self.t += 1.; 230 | Ok(()) 231 | } 232 | 233 | fn set_learning_rate(&mut self, lr: f64) { 234 | self.params.lr = lr; 235 | } 236 | } 237 | 238 | impl OptimParams for RAdam { 239 | fn params(&self) -> &Self::Config { 240 | &self.params 241 | } 242 | 243 | fn set_params(&mut self, config: Self::Config) { 244 | self.params = config; 245 | } 246 | } 247 | 248 | impl RAdam { 249 | /// Return the vars being optimised 250 | #[must_use] 251 | pub fn into_inner(self) -> Vec { 252 | self.vars.into_iter().map(|v| v.theta).collect() 253 | } 254 | 255 | // pub fn push(&mut self, var: &Var) { 256 | // self.vars.push(var.clone()); 257 | // } 258 | } 259 | 260 | #[cfg(test)] 261 | mod tests { 262 | // use candle_core::test_utils::{to_vec0_round, to_vec2_round}; 263 | 264 | use anyhow::Result; 265 | use assert_approx_eq::assert_approx_eq; 266 | use candle_core::{Device, Var}; 267 | use candle_nn::Optimizer; 268 | 269 | use super::*; 270 | #[test] 271 | fn lr_test() -> Result<()> { 272 | let params = ParamsRAdam { 273 | lr: 0.004, 274 | ..Default::default() 275 | }; 276 | // Now use backprop to run a linear regression between samples and get the coefficients back. 277 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 278 | let b = Var::new(0f32, &Device::Cpu)?; 279 | let mut optim = RAdam::new(vec![w.clone(), b.clone()], params)?; 280 | assert_approx_eq!(0.004, optim.learning_rate()); 281 | optim.set_learning_rate(0.002); 282 | assert_approx_eq!(0.002, optim.learning_rate()); 283 | Ok(()) 284 | } 285 | 286 | #[test] 287 | fn into_inner_test() -> Result<()> { 288 | let params = ParamsRAdam::default(); 289 | let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?; 290 | let b = Var::new(-2f32, &Device::Cpu)?; 291 | let optim = RAdam::new(vec![w.clone(), b.clone()], params)?; 292 | let inner = optim.into_inner(); 293 | assert_eq!(inner[0].as_tensor().to_vec2::()?, &[[3f32, 1.]]); 294 | assert_approx_eq!(inner[1].as_tensor().to_vec0::()?, -2_f32); 295 | Ok(()) 296 | } 297 | 298 | #[test] 299 | fn params_test() -> Result<()> { 300 | let params = ParamsRAdam { 301 | lr: 0.004, 302 | ..Default::default() 303 | }; 304 | // Now use backprop to run a linear regression between samples and get the coefficients back. 
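// Note: with the default beta_2 = 0.999 the rectification threshold rho_t > 5 in `step`
// above is only crossed after roughly the first five steps; until then the update falls
// back to plain bias-corrected momentum (lr * m_hat) with no division by sqrt(v_t).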
305 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 306 | let b = Var::new(0f32, &Device::Cpu)?; 307 | let mut optim = RAdam::new(vec![w.clone(), b.clone()], params.clone())?; 308 | assert_eq!(params, optim.params().clone()); 309 | let new_params = ParamsRAdam { 310 | lr: 0.002, 311 | ..Default::default() 312 | }; 313 | optim.set_params(new_params.clone()); 314 | assert_eq!(new_params, optim.params().clone()); 315 | Ok(()) 316 | } 317 | } 318 | -------------------------------------------------------------------------------- /tests/rmsprop-tests.rs: -------------------------------------------------------------------------------- 1 | use candle_core::test_utils::{to_vec0_round, to_vec2_round}; 2 | 3 | use anyhow::Result; 4 | use candle_core::{Device, Tensor, Var}; 5 | use candle_nn::{Linear, Module, Optimizer}; 6 | use candle_optimisers::rmsprop::{ParamsRMSprop, RMSprop}; 7 | 8 | /* The results of this test have been checked against the following PyTorch code. 9 | import torch 10 | from torch import optim 11 | 12 | w_gen = torch.tensor([[3., 1.]]) 13 | b_gen = torch.tensor([-2.]) 14 | 15 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 16 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 17 | 18 | m = torch.nn.Linear(2, 1) 19 | with torch.no_grad(): 20 | m.weight.zero_() 21 | m.bias.zero_() 22 | optimiser = optim.RMSprop(m.parameters()) 23 | # optimiser.zero_grad() 24 | for _step in range(100): 25 | optimiser.zero_grad() 26 | ys = m(sample_xs) 27 | loss = ((ys - sample_ys)**2).sum() 28 | loss.backward() 29 | optimiser.step() 30 | # print("Optimizer state begin") 31 | # print(optimiser.state) 32 | # print("Optimizer state end") 33 | print(m.weight) 34 | print(m.bias) 35 | */ 36 | #[test] 37 | fn rmsprop_test() -> Result<()> { 38 | // Generate some linear data, y = 3.x1 + x2 - 2. 39 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 40 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 41 | let gen = Linear::new(w_gen, Some(b_gen)); 42 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 43 | let sample_ys = gen.forward(&sample_xs)?; 44 | 45 | let params = ParamsRMSprop::default(); 46 | // Now use backprop to run a linear regression between samples and get the coefficients back. 47 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 48 | let b = Var::new(0f32, &Device::Cpu)?; 49 | let mut n_sgd = RMSprop::new(vec![w.clone(), b.clone()], params)?; 50 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 51 | for _step in 0..100 { 52 | let ys = lin.forward(&sample_xs)?; 53 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 54 | n_sgd.backward_step(&loss)?; 55 | } 56 | assert_eq!(to_vec2_round(&w, 4)?, &[[1.6650, 0.7867]]); 57 | assert_eq!(to_vec0_round(&b, 4)?, 1.3012); 58 | Ok(()) 59 | } 60 | 61 | /* The results of this test have been checked against the following PyTorch code. 
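(Here weight_decay is the classic coupled form: PyTorch adds weight_decay * param to the
gradient before the RMSprop update, and the Rust test below passes Some(0.4) to match.)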
62 | import torch 63 | from torch import optim 64 | 65 | w_gen = torch.tensor([[3., 1.]]) 66 | b_gen = torch.tensor([-2.]) 67 | 68 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 69 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 70 | 71 | m = torch.nn.Linear(2, 1) 72 | with torch.no_grad(): 73 | m.weight.zero_() 74 | m.bias.zero_() 75 | optimiser = optim.RMSprop(m.parameters(), weight_decay = 0.4) 76 | # optimiser.zero_grad() 77 | for _step in range(100): 78 | optimiser.zero_grad() 79 | ys = m(sample_xs) 80 | loss = ((ys - sample_ys)**2).sum() 81 | loss.backward() 82 | optimiser.step() 83 | # print("Optimizer state begin") 84 | # print(optimiser.state) 85 | # print("Optimizer state end") 86 | print(m.weight) 87 | print(m.bias) 88 | */ 89 | #[test] 90 | fn rmsprop_weight_decay_test() -> Result<()> { 91 | // Generate some linear data, y = 3.x1 + x2 - 2. 92 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 93 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 94 | let gen = Linear::new(w_gen, Some(b_gen)); 95 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 96 | let sample_ys = gen.forward(&sample_xs)?; 97 | 98 | let params = ParamsRMSprop { 99 | weight_decay: Some(0.4), 100 | ..Default::default() 101 | }; 102 | // Now use backprop to run a linear regression between samples and get the coefficients back. 103 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 104 | let b = Var::new(0f32, &Device::Cpu)?; 105 | let mut n_sgd = RMSprop::new(vec![w.clone(), b.clone()], params)?; 106 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 107 | for _step in 0..100 { 108 | let ys = lin.forward(&sample_xs)?; 109 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 110 | n_sgd.backward_step(&loss)?; 111 | } 112 | assert_eq!(to_vec2_round(&w, 4)?, &[[1.6643, 0.7867]]); 113 | assert_eq!(to_vec0_round(&b, 4)?, 1.2926); 114 | Ok(()) 115 | } 116 | 117 | /* The results of this test have been checked against the following PyTorch code. 118 | import torch 119 | from torch import optim 120 | 121 | w_gen = torch.tensor([[3., 1.]]) 122 | b_gen = torch.tensor([-2.]) 123 | 124 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 125 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 126 | 127 | m = torch.nn.Linear(2, 1) 128 | with torch.no_grad(): 129 | m.weight.zero_() 130 | m.bias.zero_() 131 | optimiser = optim.RMSprop(m.parameters(), centered = True) 132 | # optimiser.zero_grad() 133 | for _step in range(100): 134 | optimiser.zero_grad() 135 | ys = m(sample_xs) 136 | loss = ((ys - sample_ys)**2).sum() 137 | loss.backward() 138 | optimiser.step() 139 | # print("Optimizer state begin") 140 | # print(optimiser.state) 141 | # print("Optimizer state end") 142 | print(m.weight) 143 | print(m.bias) 144 | */ 145 | #[test] 146 | fn rmsprop_centered_test() -> Result<()> { 147 | // Generate some linear data, y = 3.x1 + x2 - 2. 148 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 149 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 150 | let gen = Linear::new(w_gen, Some(b_gen)); 151 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 152 | let sample_ys = gen.forward(&sample_xs)?; 153 | 154 | let params = ParamsRMSprop { 155 | centered: true, 156 | ..Default::default() 157 | }; 158 | // Now use backprop to run a linear regression between samples and get the coefficients back. 
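// Note: with centered = true the running mean of the gradient is subtracted from the
// running second moment, so the denominator estimates the gradient's variance rather than
// its raw second moment as in the uncentered test above.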
159 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 160 | let b = Var::new(0f32, &Device::Cpu)?; 161 | let mut n_sgd = RMSprop::new(vec![w.clone(), b.clone()], params)?; 162 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 163 | for _step in 0..100 { 164 | let ys = lin.forward(&sample_xs)?; 165 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 166 | n_sgd.backward_step(&loss)?; 167 | } 168 | assert_eq!(to_vec2_round(&w, 4)?, &[[1.8892, 0.7617]]); 169 | assert_eq!(to_vec0_round(&b, 4)?, 1.3688); 170 | Ok(()) 171 | } 172 | 173 | /* The results of this test have been checked against the following PyTorch code. 174 | import torch 175 | from torch import optim 176 | 177 | w_gen = torch.tensor([[3., 1.]]) 178 | b_gen = torch.tensor([-2.]) 179 | 180 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 181 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 182 | 183 | m = torch.nn.Linear(2, 1) 184 | with torch.no_grad(): 185 | m.weight.zero_() 186 | m.bias.zero_() 187 | optimiser = optim.RMSprop(m.parameters(), centered = True) 188 | # optimiser.zero_grad() 189 | for _step in range(100): 190 | optimiser.zero_grad() 191 | ys = m(sample_xs) 192 | loss = ((ys - sample_ys)**2).sum() 193 | loss.backward() 194 | optimiser.step() 195 | # print("Optimizer state begin") 196 | # print(optimiser.state) 197 | # print("Optimizer state end") 198 | print(m.weight) 199 | print(m.bias) 200 | */ 201 | #[test] 202 | fn rmsprop_centered_decay_test() -> Result<()> { 203 | // Generate some linear data, y = 3.x1 + x2 - 2. 204 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 205 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 206 | let gen = Linear::new(w_gen, Some(b_gen)); 207 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 208 | let sample_ys = gen.forward(&sample_xs)?; 209 | 210 | let params = ParamsRMSprop { 211 | centered: true, 212 | weight_decay: Some(0.4), 213 | ..Default::default() 214 | }; 215 | // Now use backprop to run a linear regression between samples and get the coefficients back. 216 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 217 | let b = Var::new(0f32, &Device::Cpu)?; 218 | let mut n_sgd = RMSprop::new(vec![w.clone(), b.clone()], params)?; 219 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 220 | for _step in 0..100 { 221 | let ys = lin.forward(&sample_xs)?; 222 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 223 | n_sgd.backward_step(&loss)?; 224 | } 225 | assert_eq!(to_vec2_round(&w, 4)?, &[[1.8883, 0.7621]]); 226 | assert_eq!(to_vec0_round(&b, 4)?, 1.3558); 227 | Ok(()) 228 | } 229 | 230 | /* The results of this test have been checked against the following PyTorch code. 
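(momentum here is classical heavy-ball momentum applied on top of the RMSprop-scaled step:
PyTorch keeps a buffer b <- momentum * b + g / (sqrt(v) + eps) and steps the parameters by
lr * b; the Rust test below passes momentum: Some(0.4) to reproduce this.)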
231 | import torch 232 | from torch import optim 233 | 234 | w_gen = torch.tensor([[3., 1.]]) 235 | b_gen = torch.tensor([-2.]) 236 | 237 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 238 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 239 | 240 | m = torch.nn.Linear(2, 1) 241 | with torch.no_grad(): 242 | m.weight.zero_() 243 | m.bias.zero_() 244 | optimiser = optim.RMSprop(m.parameters(), momentum = 0.4) 245 | # optimiser.zero_grad() 246 | for _step in range(100): 247 | optimiser.zero_grad() 248 | ys = m(sample_xs) 249 | loss = ((ys - sample_ys)**2).sum() 250 | loss.backward() 251 | optimiser.step() 252 | # print("Optimizer state begin") 253 | # print(optimiser.state) 254 | # print("Optimizer state end") 255 | print(m.weight) 256 | print(m.bias) 257 | */ 258 | #[test] 259 | fn rmsprop_momentum_test() -> Result<()> { 260 | // Generate some linear data, y = 3.x1 + x2 - 2. 261 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 262 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 263 | let gen = Linear::new(w_gen, Some(b_gen)); 264 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 265 | let sample_ys = gen.forward(&sample_xs)?; 266 | 267 | let params = ParamsRMSprop { 268 | momentum: Some(0.4), 269 | ..Default::default() 270 | }; 271 | // Now use backprop to run a linear regression between samples and get the coefficients back. 272 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 273 | let b = Var::new(0f32, &Device::Cpu)?; 274 | let mut n_sgd = RMSprop::new(vec![w.clone(), b.clone()], params)?; 275 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 276 | for _step in 0..100 { 277 | let ys = lin.forward(&sample_xs)?; 278 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 279 | n_sgd.backward_step(&loss)?; 280 | } 281 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.3042, 0.6835]]); 282 | assert_eq!(to_vec0_round(&b, 4)?, 1.5441); 283 | Ok(()) 284 | } 285 | 286 | /* The results of this test have been checked against the following PyTorch code. 287 | import torch 288 | from torch import optim 289 | 290 | w_gen = torch.tensor([[3., 1.]]) 291 | b_gen = torch.tensor([-2.]) 292 | 293 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 294 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 295 | 296 | m = torch.nn.Linear(2, 1) 297 | with torch.no_grad(): 298 | m.weight.zero_() 299 | m.bias.zero_() 300 | optimiser = optim.RMSprop(m.parameters(), momentum = 0.4, weight_decay = 0.4) 301 | # optimiser.zero_grad() 302 | for _step in range(100): 303 | optimiser.zero_grad() 304 | ys = m(sample_xs) 305 | loss = ((ys - sample_ys)**2).sum() 306 | loss.backward() 307 | optimiser.step() 308 | # print("Optimizer state begin") 309 | # print(optimiser.state) 310 | # print("Optimizer state end") 311 | print(m.weight) 312 | print(m.bias) 313 | */ 314 | #[test] 315 | fn rmsprop_momentum_decay_test() -> Result<()> { 316 | // Generate some linear data, y = 3.x1 + x2 - 2. 317 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 318 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 319 | let gen = Linear::new(w_gen, Some(b_gen)); 320 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 321 | let sample_ys = gen.forward(&sample_xs)?; 322 | 323 | let params = ParamsRMSprop { 324 | momentum: Some(0.4), 325 | weight_decay: Some(0.4), 326 | ..Default::default() 327 | }; 328 | // Now use backprop to run a linear regression between samples and get the coefficients back. 
329 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 330 | let b = Var::new(0f32, &Device::Cpu)?; 331 | let mut n_sgd = RMSprop::new(vec![w.clone(), b.clone()], params)?; 332 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 333 | for _step in 0..100 { 334 | let ys = lin.forward(&sample_xs)?; 335 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 336 | n_sgd.backward_step(&loss)?; 337 | } 338 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.3028, 0.6858]]); 339 | assert_eq!(to_vec0_round(&b, 4)?, 1.5149); 340 | Ok(()) 341 | } 342 | 343 | /* The results of this test have been checked against the following PyTorch code. 344 | import torch 345 | from torch import optim 346 | 347 | w_gen = torch.tensor([[3., 1.]]) 348 | b_gen = torch.tensor([-2.]) 349 | 350 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 351 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 352 | 353 | m = torch.nn.Linear(2, 1) 354 | with torch.no_grad(): 355 | m.weight.zero_() 356 | m.bias.zero_() 357 | optimiser = optim.RMSprop(m.parameters(), centered = True, momentum = 0.4) 358 | # optimiser.zero_grad() 359 | for _step in range(100): 360 | optimiser.zero_grad() 361 | ys = m(sample_xs) 362 | loss = ((ys - sample_ys)**2).sum() 363 | loss.backward() 364 | optimiser.step() 365 | # print("Optimizer state begin") 366 | # print(optimiser.state) 367 | # print("Optimizer state end") 368 | print(m.weight) 369 | print(m.bias) 370 | */ 371 | #[test] 372 | fn rmsprop_centered_momentum_test() -> Result<()> { 373 | // Generate some linear data, y = 3.x1 + x2 - 2. 374 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 375 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 376 | let gen = Linear::new(w_gen, Some(b_gen)); 377 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 378 | let sample_ys = gen.forward(&sample_xs)?; 379 | 380 | let params = ParamsRMSprop { 381 | centered: true, 382 | momentum: Some(0.4), 383 | ..Default::default() 384 | }; 385 | // Now use backprop to run a linear regression between samples and get the coefficients back. 386 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 387 | let b = Var::new(0f32, &Device::Cpu)?; 388 | let mut n_sgd = RMSprop::new(vec![w.clone(), b.clone()], params)?; 389 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 390 | for _step in 0..100 { 391 | let ys = lin.forward(&sample_xs)?; 392 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 393 | n_sgd.backward_step(&loss)?; 394 | } 395 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.4486, 0.6715]]); 396 | assert_eq!(to_vec0_round(&b, 4)?, 1.5045); 397 | Ok(()) 398 | } 399 | 400 | /* The results of this test have been checked against the following PyTorch code. 
401 | import torch 402 | from torch import optim 403 | 404 | w_gen = torch.tensor([[3., 1.]]) 405 | b_gen = torch.tensor([-2.]) 406 | 407 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 408 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 409 | 410 | m = torch.nn.Linear(2, 1) 411 | with torch.no_grad(): 412 | m.weight.zero_() 413 | m.bias.zero_() 414 | optimiser = optim.RMSprop(m.parameters(), centered = True, momentum = 0.4, weight_decay = 0.4) 415 | # optimiser.zero_grad() 416 | for _step in range(100): 417 | optimiser.zero_grad() 418 | ys = m(sample_xs) 419 | loss = ((ys - sample_ys)**2).sum() 420 | loss.backward() 421 | optimiser.step() 422 | # print("Optimizer state begin") 423 | # print(optimiser.state) 424 | # print("Optimizer state end") 425 | print(m.weight) 426 | print(m.bias) 427 | */ 428 | #[test] 429 | fn rmsprop_centered_momentum_decay_test() -> Result<()> { 430 | // Generate some linear data, y = 3.x1 + x2 - 2. 431 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 432 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 433 | let gen = Linear::new(w_gen, Some(b_gen)); 434 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 435 | let sample_ys = gen.forward(&sample_xs)?; 436 | 437 | let params = ParamsRMSprop { 438 | centered: true, 439 | momentum: Some(0.4), 440 | weight_decay: Some(0.4), 441 | ..Default::default() 442 | }; 443 | // Now use backprop to run a linear regression between samples and get the coefficients back. 444 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 445 | let b = Var::new(0f32, &Device::Cpu)?; 446 | let mut n_sgd = RMSprop::new(vec![w.clone(), b.clone()], params)?; 447 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 448 | for _step in 0..100 { 449 | let ys = lin.forward(&sample_xs)?; 450 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 451 | n_sgd.backward_step(&loss)?; 452 | } 453 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.4468, 0.6744]]); 454 | assert_eq!(to_vec0_round(&b, 4)?, 1.4695); 455 | Ok(()) 456 | } 457 | -------------------------------------------------------------------------------- /src/lbfgs/strong_wolfe.rs: -------------------------------------------------------------------------------- 1 | use crate::Model; 2 | use candle_core::Result as CResult; 3 | use candle_core::{Tensor, Var}; 4 | 5 | use super::{add_grad, flat_grads, set_vs, Lbfgs}; 6 | 7 | /// ported from pytorch torch/optim/lbfgs.py ported from 8 | fn cubic_interpolate( 9 | // position 1 10 | x1: f64, 11 | // f(x1) 12 | f1: f64, 13 | // f'(x1) 14 | g1: f64, 15 | // position 2 16 | x2: f64, 17 | // f(x2) 18 | f2: f64, 19 | // f'(x2) 20 | g2: f64, 21 | bounds: Option<(f64, f64)>, 22 | ) -> f64 { 23 | let (xmin_bound, xmax_bound) = if let Some(bound) = bounds { 24 | bound 25 | } else if x1 < x2 { 26 | (x1, x2) 27 | } else { 28 | (x2, x1) 29 | }; 30 | let d1 = g1 + g2 - 3. * (f1 - f2) / (x1 - x2); 31 | let d2_square = d1.powi(2) - g1 * g2; 32 | if d2_square >= 0. { 33 | let d2 = d2_square.sqrt(); 34 | let min_pos = if x1 <= x2 { 35 | x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2. * d2)) 36 | } else { 37 | x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2. 
* d2)) 38 | }; 39 | (min_pos.max(xmin_bound)).min(xmax_bound) 40 | } else { 41 | xmin_bound.midpoint(xmax_bound) 42 | } 43 | } 44 | 45 | impl Lbfgs<'_, M> { 46 | /// Strong Wolfe line search 47 | /// 48 | /// # Arguments 49 | /// 50 | /// step size 51 | /// 52 | /// direction 53 | /// 54 | /// initial loss 55 | /// 56 | /// initial grad 57 | /// 58 | /// initial directional grad 59 | /// 60 | /// c1 coefficient for wolfe condition 61 | /// 62 | /// c2 coefficient for wolfe condition 63 | /// 64 | /// minimum allowed progress 65 | /// 66 | /// maximum number of iterations 67 | /// 68 | /// # Returns 69 | /// 70 | /// (`f_new`, `g_new`, t, `ls_func_evals`) 71 | #[allow(clippy::too_many_arguments, clippy::too_many_lines)] 72 | pub(super) fn strong_wolfe( 73 | &mut self, 74 | mut step_size: f64, // step size 75 | direction: &Tensor, // direction 76 | loss: &Tensor, // initial loss 77 | grad: &Tensor, // initial grad 78 | directional_grad: f64, // initial directional grad 79 | c1: f64, // c1 coefficient for wolfe condition 80 | c2: f64, // c2 coefficient for wolfe condition 81 | tolerance_change: f64, // minimum allowed progress 82 | max_ls: usize, // maximum number of iterations 83 | ) -> CResult<(Tensor, Tensor, f64, usize)> { 84 | // ported from https://github.com/torch/optim/blob/master/lswolfe.lua 85 | 86 | let dtype = loss.dtype(); 87 | let shape = loss.shape(); 88 | let dev = loss.device(); 89 | 90 | let d_norm = &direction 91 | .abs()? 92 | .max(0)? 93 | .to_dtype(candle_core::DType::F64)? 94 | .to_scalar::()?; 95 | 96 | // evaluate objective and gradient using initial step 97 | let (f_new, g_new, mut l2_new) = self.directional_evaluate(step_size, direction)?; 98 | let g_new = Var::from_tensor(&g_new)?; 99 | let mut f_new = f_new 100 | .to_dtype(candle_core::DType::F64)? 101 | .to_scalar::()?; 102 | let mut ls_func_evals = 1; 103 | let mut gtd_new = g_new 104 | .unsqueeze(0)? 105 | .matmul(&(direction.unsqueeze(1)?))? 106 | .to_dtype(candle_core::DType::F64)? 107 | .squeeze(1)? 108 | .squeeze(0)? 109 | .to_scalar::()?; 110 | 111 | // bracket an interval containing a point satisfying the Wolfe criteria 112 | let grad_det = grad.copy()?; 113 | let g_prev = Var::from_tensor(&grad_det)?; 114 | let scalar_loss = loss.to_dtype(candle_core::DType::F64)?.to_scalar::()?; 115 | let mut f_prev = scalar_loss; 116 | let l2_init = self.l2_reg()?; 117 | let mut l2_prev = l2_init; 118 | let (mut t_prev, mut gtd_prev) = (0., directional_grad); 119 | let mut done = false; 120 | let mut ls_iter = 0; 121 | 122 | let mut bracket_gtd; 123 | let mut bracket_l2; 124 | let mut bracket_f; 125 | let (mut bracket, bracket_g) = loop { 126 | // check conditions 127 | if f_new + l2_new >= f_prev + l2_prev { 128 | bracket_gtd = [gtd_prev, gtd_new]; 129 | bracket_l2 = [l2_prev, l2_new]; 130 | bracket_f = [f_prev, f_new]; 131 | break ( 132 | [t_prev, step_size], 133 | [g_prev, Var::from_tensor(g_new.as_tensor())?], 134 | ); 135 | } 136 | 137 | if gtd_new.abs() <= -c2 * directional_grad { 138 | done = true; 139 | bracket_gtd = [gtd_prev, gtd_new]; 140 | bracket_l2 = [l2_prev, l2_new]; 141 | bracket_f = [f_new, f_new]; 142 | break ( 143 | [step_size, step_size], 144 | [ 145 | Var::from_tensor(&g_new.as_tensor().copy()?)?, 146 | Var::from_tensor(g_new.as_tensor())?, 147 | ], 148 | ); 149 | } 150 | 151 | if gtd_new >= 0. 
{ 152 | bracket_gtd = [gtd_prev, gtd_new]; 153 | bracket_l2 = [l2_prev, l2_new]; 154 | bracket_f = [f_prev, f_new]; 155 | break ( 156 | [t_prev, step_size], 157 | [g_prev, Var::from_tensor(g_new.as_tensor())?], 158 | ); 159 | } 160 | 161 | // interpolate 162 | let min_step = step_size + 0.01 * (step_size - t_prev); 163 | let max_step = step_size * 10.; 164 | let tmp = step_size; 165 | step_size = cubic_interpolate( 166 | t_prev, 167 | f_prev + l2_prev, 168 | gtd_prev, 169 | step_size, 170 | f_new + l2_new, 171 | gtd_new, 172 | Some((min_step, max_step)), 173 | ); 174 | 175 | // next step 176 | t_prev = tmp; 177 | f_prev = f_new; 178 | g_prev.set(g_new.as_tensor())?; 179 | l2_prev = l2_new; 180 | gtd_prev = gtd_new; 181 | // assign to temp vars: 182 | let (next_f, next_g, next_l2) = self.directional_evaluate(step_size, direction)?; 183 | 184 | // overwrite 185 | f_new = next_f 186 | .to_dtype(candle_core::DType::F64)? 187 | .to_scalar::<f64>()?; 188 | g_new.set(&next_g)?; 189 | l2_new = next_l2; 190 | 191 | ls_func_evals += 1; 192 | 193 | gtd_new = g_new 194 | .unsqueeze(0)? 195 | .matmul(&(direction.unsqueeze(1)?))? 196 | .to_dtype(candle_core::DType::F64)? 197 | .squeeze(1)? 198 | .squeeze(0)? 199 | .to_scalar::<f64>()?; 200 | ls_iter += 1; 201 | 202 | // reached max number of iterations? 203 | if ls_iter == max_ls { 204 | bracket_gtd = [gtd_prev, gtd_new]; 205 | bracket_l2 = [l2_prev, l2_new]; 206 | bracket_f = [scalar_loss, f_new]; 207 | break ( 208 | [0., step_size], 209 | [ 210 | Var::from_tensor(grad)?, 211 | Var::from_tensor(g_new.as_tensor())?, 212 | ], 213 | ); 214 | } 215 | }; 216 | 217 | // zoom phase: we now have a point satisfying the criteria, or 218 | // a bracket around it. We refine the bracket until we find the 219 | // exact point satisfying the criteria 220 | let mut insuf_progress = false; 221 | // find high and low points in bracket 222 | let (mut low_pos, mut high_pos) = 223 | if bracket_f[0] + bracket_l2[0] <= bracket_f[1] + bracket_l2[1] { 224 | (0, 1) 225 | } else { 226 | (1, 0) 227 | }; 228 | while !done && ls_iter < max_ls { 229 | // line-search bracket is so small 230 | if (bracket[1] - bracket[0]).abs() * d_norm < tolerance_change { 231 | break; 232 | } 233 | 234 | // compute new trial value 235 | step_size = cubic_interpolate( 236 | bracket[0], 237 | bracket_f[0] + bracket_l2[0], 238 | bracket_gtd[0], 239 | bracket[1], 240 | bracket_f[1] + bracket_l2[1], 241 | bracket_gtd[1], 242 | None, 243 | ); 244 | 245 | // test that we are making sufficient progress: 246 | // in case `t` is so close to boundary, we mark that we are making 247 | // insufficient progress, and if 248 | // + we have made insufficient progress in the last step, or 249 | // + `t` is at one of the boundary, 250 | // we will move `t` to a position which is `0.1 * len(bracket)` 251 | // away from the nearest boundary point.
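// For example, with bracket = [0.0, 1.0] the margin below is eps = 0.1: a trial step of 0.995 is only flagged as insufficient progress the first time, and on a repeat (or if the trial lands on the boundary itself) it is pulled back to max_bracket - eps = 0.9.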
252 | let max_bracket = bracket[0].max(bracket[1]); 253 | let min_bracket = bracket[0].min(bracket[1]); 254 | let eps = 0.1 * (max_bracket - min_bracket); 255 | if (max_bracket - step_size).min(step_size - min_bracket) < eps { 256 | // interpolation close to boundary 257 | if insuf_progress || step_size >= max_bracket || step_size <= min_bracket { 258 | // evaluate at 0.1 away from boundary 259 | if (step_size - max_bracket).abs() < (step_size - min_bracket).abs() { 260 | step_size = max_bracket - eps; 261 | } else { 262 | step_size = min_bracket + eps; 263 | } 264 | insuf_progress = false; 265 | } else { 266 | insuf_progress = true; 267 | } 268 | } else { 269 | insuf_progress = false; 270 | } 271 | 272 | // Evaluate new point 273 | // assign to temp vars: 274 | let (next_f, next_g, next_l2) = self.directional_evaluate(step_size, direction)?; 275 | // overwrite 276 | f_new = next_f 277 | .to_dtype(candle_core::DType::F64)? 278 | .to_scalar::<f64>()?; 279 | 280 | l2_new = next_l2; 281 | ls_func_evals += 1; 282 | 283 | gtd_new = next_g 284 | .unsqueeze(0)? 285 | .matmul(&(direction.unsqueeze(1)?))? 286 | .to_dtype(candle_core::DType::F64)? 287 | .squeeze(1)? 288 | .squeeze(0)? 289 | .to_scalar::<f64>()?; 290 | ls_iter += 1; 291 | 292 | if f_new + l2_new > (scalar_loss + l2_init + c1 * step_size * directional_grad) 293 | || f_new + l2_new >= bracket_f[low_pos] + bracket_l2[low_pos] 294 | { 295 | // Armijo condition not satisfied or not lower than lowest point 296 | bracket[high_pos] = step_size; 297 | bracket_f[high_pos] = f_new; 298 | bracket_g[high_pos].set(&next_g)?; 299 | bracket_l2[high_pos] = l2_new; 300 | bracket_gtd[high_pos] = gtd_new; 301 | 302 | (low_pos, high_pos) = 303 | if bracket_f[0] + bracket_l2[0] <= bracket_f[1] + bracket_l2[1] { 304 | (0, 1) 305 | } else { 306 | (1, 0) 307 | }; 308 | } else { 309 | if gtd_new.abs() <= -c2 * directional_grad { 310 | // Wolfe conditions satisfied 311 | done = true; 312 | } else if gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0.
{ 313 | // old low becomes new high 314 | bracket[high_pos] = bracket[low_pos]; 315 | bracket_f[high_pos] = bracket_f[low_pos]; 316 | bracket_g[high_pos].set(bracket_g[low_pos].as_tensor())?; 317 | bracket_gtd[high_pos] = bracket_gtd[low_pos]; 318 | bracket_l2[high_pos] = bracket_l2[low_pos]; 319 | } 320 | 321 | // new point becomes new low 322 | bracket[low_pos] = step_size; 323 | bracket_f[low_pos] = f_new; 324 | bracket_g[low_pos].set(&next_g)?; 325 | bracket_gtd[low_pos] = gtd_new; 326 | bracket_l2[low_pos] = l2_new; 327 | } 328 | } 329 | 330 | // return new value, new grad, line-search value, nb of function evals 331 | step_size = bracket[low_pos]; 332 | let [g0, g1] = bracket_g; 333 | let [f0, f1] = bracket_f; 334 | if low_pos == 1 { 335 | // if b is the lower value set a to b, else a should be returned 336 | Ok(( 337 | Tensor::from_slice(&[f1], shape, dev)?.to_dtype(dtype)?, 338 | g1.into_inner(), 339 | step_size, 340 | ls_func_evals, 341 | )) 342 | } else { 343 | Ok(( 344 | Tensor::from_slice(&[f0], shape, dev)?.to_dtype(dtype)?, 345 | g0.into_inner(), 346 | step_size, 347 | ls_func_evals, 348 | )) 349 | } 350 | } 351 | 352 | fn directional_evaluate( 353 | &mut self, 354 | mag: f64, 355 | direction: &Tensor, 356 | ) -> CResult<(Tensor, Tensor, f64)> { 357 | // need to cache the original result 358 | // Otherwise leads to drift over line search evals 359 | let original = self 360 | .vars 361 | .iter() 362 | .map(|v| v.as_tensor().copy()) 363 | .collect::<CResult<Vec<Tensor>>>()?; 364 | 365 | add_grad(&mut self.vars, &(mag * direction)?)?; 366 | let loss = self.model.loss()?; 367 | let grad = flat_grads(&self.vars, &loss, self.params.weight_decay)?; 368 | let l2_reg = if let Some(wd) = self.params.weight_decay { 369 | 0.5 * wd 370 | * self 371 | .vars 372 | .iter() 373 | .map(|v| -> CResult<f64> { 374 | v.as_tensor() 375 | .sqr()? 376 | .sum_all()? 377 | .to_dtype(candle_core::DType::F64)? 378 | .to_scalar::<f64>() 379 | }) 380 | .sum::<CResult<f64>>()? 381 | } else { 382 | 0. 383 | }; 384 | 385 | set_vs(&mut self.vars, &original)?; 386 | // add_grad(&mut self.vars, &(-mag * direction)?)?; 387 | Ok(( 388 | loss, //.to_dtype(candle_core::DType::F64)?.to_scalar::<f64>()? 389 | grad, l2_reg, 390 | )) 391 | } 392 | 393 | fn l2_reg(&self) -> CResult<f64> { 394 | if let Some(wd) = self.params.weight_decay { 395 | Ok(0.5 396 | * wd 397 | * self 398 | .vars 399 | .iter() 400 | .map(|v| -> CResult<f64> { 401 | v.as_tensor() 402 | .sqr()? 403 | .sum_all()? 404 | .to_dtype(candle_core::DType::F64)? 405 | .to_scalar::<f64>() 406 | }) 407 | .sum::<CResult<f64>>()?) 408 | } else { 409 | Ok(0.)
410 | } 411 | } 412 | } 413 | 414 | #[cfg(test)] 415 | mod tests { 416 | // use candle_core::test_utils::{to_vec0_round, to_vec2_round}; 417 | 418 | use crate::lbfgs::ParamsLBFGS; 419 | use crate::{LossOptimizer, Model}; 420 | use anyhow::Result; 421 | use assert_approx_eq::assert_approx_eq; 422 | use candle_core::Device; 423 | use candle_core::{Module, Result as CResult}; 424 | pub struct LinearModel { 425 | linear: candle_nn::Linear, 426 | xs: Tensor, 427 | ys: Tensor, 428 | } 429 | 430 | impl Model for LinearModel { 431 | fn loss(&self) -> CResult<Tensor> { 432 | let preds = self.forward(&self.xs)?; 433 | let loss = candle_nn::loss::mse(&preds, &self.ys)?; 434 | Ok(loss) 435 | } 436 | } 437 | 438 | impl LinearModel { 439 | fn new() -> CResult<(Self, Vec<Var>)> { 440 | let weight = Var::from_tensor(&Tensor::new(&[3f64, 1.], &Device::Cpu)?)?; 441 | let bias = Var::from_tensor(&Tensor::new(-2f64, &Device::Cpu)?)?; 442 | 443 | let linear = 444 | candle_nn::Linear::new(weight.as_tensor().clone(), Some(bias.as_tensor().clone())); 445 | 446 | Ok(( 447 | Self { 448 | linear, 449 | xs: Tensor::new(&[[2f64, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?, 450 | ys: Tensor::new(&[[7f64], [26.], [0.], [27.]], &Device::Cpu)?, 451 | }, 452 | vec![weight, bias], 453 | )) 454 | } 455 | 456 | fn forward(&self, xs: &Tensor) -> CResult<Tensor> { 457 | self.linear.forward(xs) 458 | } 459 | } 460 | 461 | use super::*; 462 | #[test] 463 | fn l2_test() -> Result<()> { 464 | let params = ParamsLBFGS { 465 | lr: 0.004, 466 | ..Default::default() 467 | }; 468 | let (model, vars) = LinearModel::new()?; 469 | let lbfgs = Lbfgs::new(vars, params, &model)?; 470 | let l2 = lbfgs.l2_reg()?; 471 | assert_approx_eq!(0.0, l2); 472 | 473 | let params = ParamsLBFGS { 474 | lr: 0.004, 475 | weight_decay: Some(1.0), 476 | ..Default::default() 477 | }; 478 | let (model, vars) = LinearModel::new()?; 479 | let lbfgs = Lbfgs::new(vars, params, &model)?; 480 | let l2 = lbfgs.l2_reg()?; 481 | assert_approx_eq!(7.0, l2); // 0.5 *(3^2 +1^2 + (-2)^2) 482 | Ok(()) 483 | } 484 | } 485 | -------------------------------------------------------------------------------- /tests/esgd_tests.rs: -------------------------------------------------------------------------------- 1 | use candle_core::test_utils::{to_vec0_round, to_vec2_round}; 2 | 3 | use anyhow::Result; 4 | use candle_core::{Device, Tensor, Var}; 5 | use candle_nn::{Linear, Module, Optimizer}; 6 | use candle_optimisers::{ 7 | esgd::{ParamsSGD, SGD}, 8 | Decay, Momentum, 9 | }; 10 | 11 | /* The results of this test have been checked against the following PyTorch code. 12 | import torch 13 | from torch import optim 14 | 15 | w_gen = torch.tensor([[3., 1.]]) 16 | b_gen = torch.tensor([-2.]) 17 | 18 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 19 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 20 | 21 | m = torch.nn.Linear(2, 1) 22 | with torch.no_grad(): 23 | m.weight.zero_() 24 | m.bias.zero_() 25 | optimiser = optim.SGD(m.parameters(), lr=0.004, momentum=0.1, nesterov=True) 26 | for _step in range(100): 27 | optimiser.zero_grad() 28 | ys = m(sample_xs) 29 | loss = ((ys - sample_ys)**2).sum() 30 | loss.backward() 31 | optimiser.step() 32 | print(m.weight) 33 | print(m.bias) 34 | */ 35 | #[test] 36 | fn nesterov_sgd_test() -> Result<()> { 37 | // Generate some linear data, y = 3.x1 + x2 - 2.
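// (For the sample_xs defined below, these targets work out to [5., 23., -2., 21.].)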
38 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 39 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 40 | let gen = Linear::new(w_gen, Some(b_gen)); 41 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 42 | let sample_ys = gen.forward(&sample_xs)?; 43 | 44 | let params = ParamsSGD { 45 | lr: 0.004, 46 | weight_decay: None, 47 | momentum: Some(Momentum::Nesterov(0.1)), 48 | dampening: 0.0, 49 | // nesterov: true, 50 | }; 51 | // Now use backprop to run a linear regression between samples and get the coefficients back. 52 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 53 | let b = Var::new(0f32, &Device::Cpu)?; 54 | let mut n_sgd = SGD::new(vec![w.clone(), b.clone()], params)?; 55 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 56 | for _step in 0..100 { 57 | let ys = lin.forward(&sample_xs)?; 58 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 59 | n_sgd.backward_step(&loss)?; 60 | } 61 | if cfg!(target_os = "macos") { 62 | assert_eq!(to_vec2_round(&w, 3)?, &[[1.075, -9.904]]); 63 | assert_eq!(to_vec0_round(&b, 3)?, -1.896); 64 | } else { 65 | assert_eq!(to_vec2_round(&w, 4)?, &[[1.0750, -9.9042]]); 66 | assert_eq!(to_vec0_round(&b, 4)?, -1.8961); 67 | } 68 | 69 | Ok(()) 70 | } 71 | 72 | /* The results of this test have been checked against the following PyTorch code. 73 | import torch 74 | from torch import optim 75 | 76 | w_gen = torch.tensor([[3., 1.]]) 77 | b_gen = torch.tensor([-2.]) 78 | 79 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 80 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 81 | 82 | m = torch.nn.Linear(2, 1) 83 | with torch.no_grad(): 84 | m.weight.zero_() 85 | m.bias.zero_() 86 | optimiser = optim.SGD(m.parameters(), lr=0.004, momentum=0.1, nesterov=True, weight_decay = 0.1) 87 | # optimiser.zero_grad() 88 | for _step in range(100): 89 | optimiser.zero_grad() 90 | ys = m(sample_xs) 91 | loss = ((ys - sample_ys)**2).sum() 92 | loss.backward() 93 | optimiser.step() 94 | # print("Optimizer state begin") 95 | # print(optimiser.state) 96 | # print("Optimizer state end") 97 | print(m.weight) 98 | print(m.bias) 99 | */ 100 | #[test] 101 | fn nesterov_decay_sgd_test() -> Result<()> { 102 | // Generate some linear data, y = 3.x1 + x2 - 2. 103 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 104 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 105 | let gen = Linear::new(w_gen, Some(b_gen)); 106 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 107 | let sample_ys = gen.forward(&sample_xs)?; 108 | 109 | let params = ParamsSGD { 110 | lr: 0.004, 111 | weight_decay: Some(Decay::WeightDecay(0.1)), 112 | momentum: Some(Momentum::Nesterov(0.1)), 113 | dampening: 0.0, 114 | // nesterov: true, 115 | }; 116 | // Now use backprop to run a linear regression between samples and get the coefficients back. 
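// Note: Decay::WeightDecay is the coupled penalty (added to the gradient, matching PyTorch's weight_decay in the reference script above); the decoupled SGDW-style variant is exercised in the *_sgdw_* tests further down.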
117 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 118 | let b = Var::new(0f32, &Device::Cpu)?; 119 | let mut n_sgd = SGD::new(vec![w.clone(), b.clone()], params)?; 120 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 121 | for _step in 0..100 { 122 | let ys = lin.forward(&sample_xs)?; 123 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 124 | n_sgd.backward_step(&loss)?; 125 | } 126 | 127 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.9921, -10.3803]]); 128 | assert_eq!(to_vec0_round(&b, 4)?, -1.9331); 129 | Ok(()) 130 | } 131 | 132 | /* The results of this test have been checked against the following PyTorch code. 133 | import torch 134 | from torch import optim 135 | 136 | w_gen = torch.tensor([[3., 1.]]) 137 | b_gen = torch.tensor([-2.]) 138 | 139 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 140 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 141 | 142 | m = torch.nn.Linear(2, 1) 143 | with torch.no_grad(): 144 | m.weight.zero_() 145 | m.bias.zero_() 146 | optimiser = optim.SGD(m.parameters(), lr=0.004, momentum=0.1, nesterov=False, weight_decay = 0.0) 147 | # optimiser.zero_grad() 148 | for _step in range(100): 149 | optimiser.zero_grad() 150 | ys = m(sample_xs) 151 | loss = ((ys - sample_ys)**2).sum() 152 | loss.backward() 153 | optimiser.step() 154 | # print("Optimizer state begin") 155 | # print(optimiser.state) 156 | # print("Optimizer state end") 157 | print(m.weight) 158 | print(m.bias) 159 | */ 160 | #[test] 161 | fn momentum_sgd_test() -> Result<()> { 162 | // Generate some linear data, y = 3.x1 + x2 - 2. 163 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 164 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 165 | let gen = Linear::new(w_gen, Some(b_gen)); 166 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 167 | let sample_ys = gen.forward(&sample_xs)?; 168 | 169 | let params = ParamsSGD { 170 | lr: 0.004, 171 | weight_decay: None, 172 | momentum: Some(Momentum::Classical(0.1)), 173 | dampening: 0.0, 174 | // nesterov: false,s 175 | }; 176 | // Now use backprop to run a linear regression between samples and get the coefficients back. 177 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 178 | let b = Var::new(0f32, &Device::Cpu)?; 179 | let mut n_sgd = SGD::new(vec![w.clone(), b.clone()], params)?; 180 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 181 | for _step in 0..100 { 182 | let ys = lin.forward(&sample_xs)?; 183 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 184 | n_sgd.backward_step(&loss)?; 185 | } 186 | 187 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.8870, 0.8589]]); 188 | assert_eq!(to_vec0_round(&b, 4)?, -0.6341); 189 | Ok(()) 190 | } 191 | 192 | /* The results of this test have been checked against the following PyTorch code. 
193 | import torch 194 | from torch import optim 195 | 196 | w_gen = torch.tensor([[3., 1.]]) 197 | b_gen = torch.tensor([-2.]) 198 | 199 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 200 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 201 | 202 | m = torch.nn.Linear(2, 1) 203 | with torch.no_grad(): 204 | m.weight.zero_() 205 | m.bias.zero_() 206 | optimiser = optim.SGD(m.parameters(), lr=0.004, momentum=0.1, nesterov=False, weight_decay = 0.4) 207 | # optimiser.zero_grad() 208 | for _step in range(100): 209 | optimiser.zero_grad() 210 | ys = m(sample_xs) 211 | loss = ((ys - sample_ys)**2).sum() 212 | loss.backward() 213 | optimiser.step() 214 | # print("Optimizer state begin") 215 | # print(optimiser.state) 216 | # print("Optimizer state end") 217 | print(m.weight) 218 | print(m.bias) 219 | */ 220 | #[test] 221 | fn momentum_sgd_decay_test() -> Result<()> { 222 | // Generate some linear data, y = 3.x1 + x2 - 2. 223 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 224 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 225 | let gen = Linear::new(w_gen, Some(b_gen)); 226 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 227 | let sample_ys = gen.forward(&sample_xs)?; 228 | 229 | let params = ParamsSGD { 230 | lr: 0.004, 231 | weight_decay: Some(Decay::WeightDecay(0.4)), 232 | momentum: Some(Momentum::Classical(0.1)), 233 | dampening: 0.0, 234 | // nesterov: false, 235 | }; 236 | // Now use backprop to run a linear regression between samples and get the coefficients back. 237 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 238 | let b = Var::new(0f32, &Device::Cpu)?; 239 | let mut n_sgd = SGD::new(vec![w.clone(), b.clone()], params)?; 240 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 241 | for _step in 0..100 { 242 | let ys = lin.forward(&sample_xs)?; 243 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 244 | n_sgd.backward_step(&loss)?; 245 | } 246 | 247 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.8751, 0.8514]]); 248 | assert_eq!(to_vec0_round(&b, 4)?, -0.5626); 249 | Ok(()) 250 | } 251 | 252 | /* The results of this test have been checked against the following PyTorch code. 253 | import torch 254 | from torch import optim 255 | 256 | w_gen = torch.tensor([[3., 1.]]) 257 | b_gen = torch.tensor([-2.]) 258 | 259 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) 260 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen 261 | 262 | m = torch.nn.Linear(2, 1) 263 | with torch.no_grad(): 264 | m.weight.zero_() 265 | m.bias.zero_() 266 | optimiser = optim.SGD(m.parameters(), lr=0.004, momentum=0.1, nesterov=False, weight_decay = 0.0, dampening = 0.2) 267 | # optimiser.zero_grad() 268 | for _step in range(100): 269 | optimiser.zero_grad() 270 | ys = m(sample_xs) 271 | loss = ((ys - sample_ys)**2).sum() 272 | loss.backward() 273 | optimiser.step() 274 | # print("Optimizer state begin") 275 | # print(optimiser.state) 276 | # print("Optimizer state end") 277 | print(m.weight) 278 | print(m.bias) 279 | */ 280 | #[test] 281 | fn momentum_sgd_dampened_test() -> Result<()> { 282 | // Generate some linear data, y = 3.x1 + x2 - 2. 
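// This test adds dampening, which scales the gradient's contribution to the momentum buffer (b <- mu * b + (1 - dampening) * g in the PyTorch SGD used as the reference above).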
283 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 284 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 285 | let gen = Linear::new(w_gen, Some(b_gen)); 286 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 287 | let sample_ys = gen.forward(&sample_xs)?; 288 | 289 | let params = ParamsSGD { 290 | lr: 0.004, 291 | weight_decay: None, 292 | momentum: Some(Momentum::Classical(0.1)), 293 | dampening: 0.2, 294 | // nesterov: false, 295 | }; 296 | // Now use backprop to run a linear regression between samples and get the coefficients back. 297 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 298 | let b = Var::new(0f32, &Device::Cpu)?; 299 | let mut n_sgd = SGD::new(vec![w.clone(), b.clone()], params)?; 300 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 301 | for _step in 0..100 { 302 | let ys = lin.forward(&sample_xs)?; 303 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 304 | n_sgd.backward_step(&loss)?; 305 | } 306 | 307 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.8746, 0.8434]]); 308 | assert_eq!(to_vec0_round(&b, 4)?, -0.4838); 309 | Ok(()) 310 | } 311 | 312 | #[test] 313 | fn sgd_test() -> Result<()> { 314 | // Generate some linear data, y = 3.x1 + x2 - 2. 315 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 316 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 317 | let gen = Linear::new(w_gen, Some(b_gen)); 318 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 319 | let sample_ys = gen.forward(&sample_xs)?; 320 | 321 | let params = ParamsSGD { 322 | lr: 0.004, 323 | weight_decay: None, 324 | momentum: None, 325 | dampening: 0.0, 326 | // nesterov: false, 327 | }; 328 | // Now use backprop to run a linear regression between samples and get the coefficients back. 329 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 330 | let b = Var::new(0f32, &Device::Cpu)?; 331 | let mut n_sgd = SGD::new(vec![w.clone(), b.clone()], params)?; 332 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 333 | for _step in 0..100 { 334 | let ys = lin.forward(&sample_xs)?; 335 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 336 | n_sgd.backward_step(&loss)?; 337 | } 338 | 339 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.8809, 0.8513]]); 340 | assert_eq!(to_vec0_round(&b, 4)?, -0.5606); 341 | Ok(()) 342 | } 343 | 344 | #[test] 345 | fn sgd_decay_test() -> Result<()> { 346 | // Generate some linear data, y = 3.x1 + x2 - 2. 347 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 348 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 349 | let gen = Linear::new(w_gen, Some(b_gen)); 350 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 351 | let sample_ys = gen.forward(&sample_xs)?; 352 | 353 | let params = ParamsSGD { 354 | lr: 0.004, 355 | weight_decay: Some(Decay::WeightDecay(0.4)), 356 | momentum: None, 357 | dampening: 0.0, 358 | // nesterov: false, 359 | }; 360 | // Now use backprop to run a linear regression between samples and get the coefficients back. 
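// Without momentum, coupled and decoupled weight decay are equivalent, so sgdw_decay_test below expects exactly the same coefficients as this test.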
361 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 362 | let b = Var::new(0f32, &Device::Cpu)?; 363 | let mut n_sgd = SGD::new(vec![w.clone(), b.clone()], params)?; 364 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 365 | for _step in 0..100 { 366 | let ys = lin.forward(&sample_xs)?; 367 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 368 | n_sgd.backward_step(&loss)?; 369 | } 370 | 371 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.8700, 0.8450]]); 372 | assert_eq!(to_vec0_round(&b, 4)?, -0.5003); 373 | Ok(()) 374 | } 375 | 376 | // The following are not tested against torch 377 | // As torch has no implementation of SGDW 378 | 379 | // This should be the same (as without momentum, decoupling is equivalent) 380 | #[test] 381 | fn sgdw_decay_test() -> Result<()> { 382 | // Generate some linear data, y = 3.x1 + x2 - 2. 383 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 384 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 385 | let gen = Linear::new(w_gen, Some(b_gen)); 386 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 387 | let sample_ys = gen.forward(&sample_xs)?; 388 | 389 | let params = ParamsSGD { 390 | lr: 0.004, 391 | weight_decay: Some(Decay::DecoupledWeightDecay(0.4)), 392 | momentum: None, 393 | dampening: 0.0, 394 | // nesterov: false, 395 | }; 396 | // Now use backprop to run a linear regression between samples and get the coefficients back. 397 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 398 | let b = Var::new(0f32, &Device::Cpu)?; 399 | let mut n_sgd = SGD::new(vec![w.clone(), b.clone()], params)?; 400 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 401 | for _step in 0..100 { 402 | let ys = lin.forward(&sample_xs)?; 403 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 404 | n_sgd.backward_step(&loss)?; 405 | } 406 | 407 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.8700, 0.8450]]); 408 | assert_eq!(to_vec0_round(&b, 4)?, -0.5003); 409 | Ok(()) 410 | } 411 | 412 | #[test] 413 | fn momentum_sgdw_decay_test() -> Result<()> { 414 | // Generate some linear data, y = 3.x1 + x2 - 2. 415 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 416 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 417 | let gen = Linear::new(w_gen, Some(b_gen)); 418 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 419 | let sample_ys = gen.forward(&sample_xs)?; 420 | 421 | let params = ParamsSGD { 422 | lr: 0.004, 423 | weight_decay: Some(Decay::DecoupledWeightDecay(0.4)), 424 | momentum: Some(Momentum::Classical(0.1)), 425 | dampening: 0.0, 426 | // nesterov: false, 427 | }; 428 | // Now use backprop to run a linear regression between samples and get the coefficients back. 429 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 430 | let b = Var::new(0f32, &Device::Cpu)?; 431 | let mut n_sgd = SGD::new(vec![w.clone(), b.clone()], params)?; 432 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 433 | for _step in 0..100 { 434 | let ys = lin.forward(&sample_xs)?; 435 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 436 | n_sgd.backward_step(&loss)?; 437 | } 438 | 439 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.8763, 0.8521]]); 440 | assert_eq!(to_vec0_round(&b, 4)?, -0.5693); 441 | Ok(()) 442 | } 443 | 444 | #[test] 445 | fn nesterov_decay_sgdw_test() -> Result<()> { 446 | // Generate some linear data, y = 3.x1 + x2 - 2. 
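// With Nesterov momentum, decoupled decay no longer matches the coupled run in nesterov_decay_sgd_test above, hence the slightly different expected coefficients; as noted, there is no PyTorch SGDW to check against.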
447 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; 448 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?; 449 | let gen = Linear::new(w_gen, Some(b_gen)); 450 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; 451 | let sample_ys = gen.forward(&sample_xs)?; 452 | 453 | let params = ParamsSGD { 454 | lr: 0.004, 455 | weight_decay: Some(candle_optimisers::Decay::DecoupledWeightDecay(0.1)), 456 | momentum: Some(Momentum::Nesterov(0.1)), 457 | dampening: 0.0, 458 | // nesterov: true, 459 | }; 460 | // Now use backprop to run a linear regression between samples and get the coefficients back. 461 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; 462 | let b = Var::new(0f32, &Device::Cpu)?; 463 | let mut n_sgd = SGD::new(vec![w.clone(), b.clone()], params)?; 464 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); 465 | for _step in 0..100 { 466 | let ys = lin.forward(&sample_xs)?; 467 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; 468 | n_sgd.backward_step(&loss)?; 469 | } 470 | 471 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.9992, -10.3397]]); 472 | assert_eq!(to_vec0_round(&b, 4)?, -1.9302); 473 | Ok(()) 474 | } 475 | --------------------------------------------------------------------------------