├── .cargo
│   └── config.toml
├── .gitignore
├── .github
│   ├── dependabot.yml
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   └── feature_request.md
│   └── workflows
│       ├── lints.yml
│       └── rust-ci.yml
├── katex-header.html
├── LICENSE
├── examples
│   └── mnist
│       ├── parse_cli.rs
│       ├── models.rs
│       ├── training.rs
│       ├── main.rs
│       └── optim.rs
├── paper
│   ├── paper.md
│   └── paper.bib
├── Cargo.toml
├── README.md
├── Changelog.md
├── benches
│   ├── mnist_bench.rs
│   └── training.rs
├── tests
│   ├── adadelta_tests.rs
│   ├── radam-tests.rs
│   ├── adamax_tests.rs
│   ├── nadam_tests.rs
│   ├── adagrad_tests.rs
│   ├── lbfgs_tests.rs
│   ├── adam_tests.rs
│   ├── rmsprop-tests.rs
│   └── esgd_tests.rs
└── src
    ├── lib.rs
    ├── adagrad.rs
    ├── adamax.rs
    ├── adadelta.rs
    ├── nadam.rs
    ├── radam.rs
    └── lbfgs
        └── strong_wolfe.rs
/.cargo/config.toml:
--------------------------------------------------------------------------------
1 | [build]
2 | rustdocflags = [ "--html-in-header", "./katex-header.html" ]
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /target
2 | /Cargo.lock
3 | notes.txt
4 | out.txt
5 | lbfgs_testbed.ipynb
6 | test*.ipynb
7 | References.md
8 | pseudo/*
9 | paper/paper.pdf
10 | paper/paper.jats
11 | *.profraw
12 | .vscode/
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | # To get started with Dependabot version updates, you'll need to specify which
2 | # package ecosystems to update and where the package manifests are located.
3 | # Please see the documentation for all configuration options:
4 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file
5 |
6 | version: 2
7 | updates:
8 | - package-ecosystem: "cargo"
9 | directory: "/" # Location of package manifests
10 | schedule:
11 | interval: "weekly"
12 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Code example to reproduce behaviour
15 |
16 | **Expected behavior**
17 | A clear and concise description of what you expected to happen.
18 |
19 | **Desktop (please complete the following information):**
20 |
21 | * OS: [e.g. iOS]
22 | * Features in use [e.g. Cuda]
23 | * Version
24 |
25 | **Additional context**
26 | Add any other context about the problem here.
27 |
--------------------------------------------------------------------------------
/.github/workflows/lints.yml:
--------------------------------------------------------------------------------
1 | on:
2 | push:
3 | branches:
4 | - master
5 | pull_request:
6 |
7 | name: Clippy
8 |
9 | jobs:
10 | fmt:
11 | name: Rustfmt
12 | runs-on: ubuntu-latest
13 | steps:
14 | - uses: actions/checkout@v4
15 | - uses: dtolnay/rust-toolchain@stable
16 | - run: rustup component add rustfmt
17 | - run: cargo fmt --all -- --check
18 |
19 | clippy:
20 | name: Clippy
21 | runs-on: ubuntu-latest
22 | steps:
23 | - uses: actions/checkout@v4
24 | - uses: dtolnay/rust-toolchain@stable
25 | - run: rustup component add clippy
26 | - run: cargo clippy --workspace --tests --examples -- -D warnings
27 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | Is the feature an additional optimisation algorithm? Is it a change to the documentation?
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen: for additional optimisation algorithms, it would be ideal to link to other implementations or a reference where possible (though I am very happy to accept novel optimisers).
15 |
16 | **Additional context**
17 | Add any other context about the feature request here.
18 |
--------------------------------------------------------------------------------
/.github/workflows/rust-ci.yml:
--------------------------------------------------------------------------------
1 | on:
2 | push:
3 | branches:
4 | - master
5 | pull_request:
6 |
7 | name: CI
8 |
9 | jobs:
10 | check:
11 | name: Check
12 | runs-on: ${{ matrix.os }}
13 | strategy:
14 | matrix:
15 | os: [ubuntu-latest, windows-latest, macOS-latest]
16 | rust: [stable, beta, nightly]
17 | steps:
18 | - uses: actions/checkout@v4
19 | - uses: dtolnay/rust-toolchain@master
20 | with:
21 | toolchain: ${{ matrix.rust }}
22 | - run: cargo check --workspace
23 |
24 | test:
25 | name: Test Suite
26 | runs-on: ${{ matrix.os }}
27 | strategy:
28 | matrix:
29 | os: [ubuntu-latest, windows-latest, macOS-latest]
30 | rust: [stable, beta, nightly]
31 | steps:
32 | - uses: actions/checkout@v4
33 | - uses: dtolnay/rust-toolchain@master
34 | with:
35 | toolchain: ${{ matrix.rust }}
36 | - run: cargo test --workspace
--------------------------------------------------------------------------------
/katex-header.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) [2023] [Kirpal Grewal]
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/examples/mnist/parse_cli.rs:
--------------------------------------------------------------------------------
1 | use clap::{Parser, ValueEnum};
2 |
3 | pub struct TrainingArgs {
4 | pub learning_rate: f64,
5 | pub load: Option<String>,
6 | pub save: Option<String>,
7 | pub epochs: usize,
8 | }
9 |
10 | #[derive(ValueEnum, Clone)]
11 | pub enum WhichModel {
12 | Linear,
13 | Mlp,
14 | }
15 |
16 | #[derive(ValueEnum, Clone)]
17 | pub enum WhichOptim {
18 | Adadelta,
19 | Adagrad,
20 | Adamax,
21 | Sgd,
22 | NAdam,
23 | RAdam,
24 | Rms,
25 | Adam,
26 | }
27 |
28 | #[derive(Parser)]
29 | pub struct Args {
30 | #[clap(value_enum, default_value_t = WhichModel::Linear)]
31 | pub model: WhichModel,
32 |
33 | #[arg(long, value_enum, default_value_t = WhichOptim::Adadelta)]
34 | pub optim: WhichOptim,
35 |
36 | #[arg(long)]
37 | pub learning_rate: Option<f64>,
38 |
39 | #[arg(long, default_value_t = 200)]
40 | pub epochs: usize,
41 |
42 | /// The file where to save the trained weights, in safetensors format.
43 | #[arg(long)]
44 | pub save: Option<String>,
45 |
46 | /// The file where to load the trained weights from, in safetensors format.
47 | #[arg(long)]
48 | pub load: Option<String>,
49 |
50 | /// The directory where to load the dataset from, in ubyte format.
51 | #[arg(long)]
52 | pub local_mnist: Option<String>,
53 | }
54 |
--------------------------------------------------------------------------------
/examples/mnist/models.rs:
--------------------------------------------------------------------------------
1 | use candle_core::{Result, Tensor};
2 | use candle_nn::{Linear, Module, VarBuilder};
3 |
4 | const IMAGE_DIM: usize = 784;
5 | const LABELS: usize = 10;
6 |
7 | fn linear_z(in_dim: usize, out_dim: usize, vs: &VarBuilder) -> Result<Linear> {
8 | let ws = vs.get_with_hints((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
9 | let bs = vs.get_with_hints(out_dim, "bias", candle_nn::init::ZERO)?;
10 | Ok(Linear::new(ws, Some(bs)))
11 | }
12 |
13 | pub trait Model: Sized {
14 | fn new(vs: VarBuilder) -> Result<Self>;
15 | fn forward(&self, xs: &Tensor) -> Result<Tensor>;
16 | }
17 |
18 | pub struct LinearModel {
19 | linear: Linear,
20 | }
21 |
22 | impl Model for LinearModel {
23 | fn new(vs: VarBuilder) -> Result<Self> {
24 | let linear = linear_z(IMAGE_DIM, LABELS, &vs)?;
25 | Ok(Self { linear })
26 | }
27 |
28 | fn forward(&self, xs: &Tensor) -> Result<Tensor> {
29 | self.linear.forward(xs)
30 | }
31 | }
32 |
33 | pub struct Mlp {
34 | ln1: Linear,
35 | ln2: Linear,
36 | }
37 |
38 | impl Model for Mlp {
39 | fn new(vs: VarBuilder) -> Result<Self> {
40 | let ln1 = candle_nn::linear(IMAGE_DIM, 100, vs.pp("ln1"))?;
41 | let ln2 = candle_nn::linear(100, LABELS, vs.pp("ln2"))?;
42 | Ok(Self { ln1, ln2 })
43 | }
44 |
45 | fn forward(&self, xs: &Tensor) -> Result<Tensor> {
46 | let xs = self.ln1.forward(xs)?;
47 | let xs = xs.relu()?;
48 | self.ln2.forward(&xs)
49 | }
50 | }
51 |
--------------------------------------------------------------------------------
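As a minimal sketch of how the `Model` trait above is meant to be driven (assuming `Mlp` and the `Model` trait from the file above are in scope, and `candle-core`/`candle-nn` as in `examples/mnist/training.rs`; the CPU device and zero-filled batch are purely illustrative):

```rust
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{VarBuilder, VarMap};

// Build an `Mlp` from a fresh `VarMap` and run a single forward pass on a
// dummy batch of one flattened 28x28 image (784 features in, 10 logits out).
fn forward_once() -> Result<Tensor> {
    let dev = Device::Cpu;
    let varmap = VarMap::new();
    let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let model = Mlp::new(vs)?; // `new` and `forward` come from the `Model` trait above
    let xs = Tensor::zeros((1, 784), DType::F32, &dev)?;
    model.forward(&xs)
}
```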
/paper/paper.md:
--------------------------------------------------------------------------------
1 | ---
2 | title: 'Candle Optimisers: A Rust crate for optimisation algorithms'
3 | tags:
4 | - Rust
5 | - optimisation
6 | - optimization
7 | - machine learning
8 | authors:
9 | - name: Kirpal Grewal
10 | orcid: 0009-0001-7923-9975
11 | affiliation: 1
12 | affiliations:
13 | - name: Yusuf Hamied Department of Chemistry, University of Cambridge
14 | index: 1
15 | date: 20 December 2023
16 | bibliography: paper.bib
17 | ---
18 |
19 | # Summary
20 |
21 | `candle-optimisers` is a crate of optimisers written in Rust for use with candle (@candle), a lightweight machine learning framework. The crate offers a set of
22 | optimisers for training neural networks.
23 |
24 | # Statement of need
25 |
26 | Rust provides the opportunity for the development of high-performance machine learning libraries with a leaner runtime. However, there is a lack of optimisation algorithms implemented in Rust,
27 | with machine learning libraries currently implementing only some combination of Adam, AdamW, SGD and RMSProp.
28 | This crate aims to provide a complete set of optimisation algorithms for use with candle.
29 | This will allow Rust to be used more easily for the training of models.
30 |
31 | # Features
32 |
33 | This library implements the following optimisation algorithms:
34 |
35 | * SGD (including momentum and Nesterov momentum (@nmomentum))
36 |
37 | * AdaDelta (@adadelta)
38 |
39 | * AdaGrad (@adagrad)
40 |
41 | * AdaMax (@adam)
42 |
43 | * Adam (@adam) including AMSGrad (@amsgrad)
44 |
45 | * AdamW (@weightdecay) (as decoupled weight decay of Adam)
46 |
47 | * NAdam (@nadam)
48 |
49 | * RAdam (@radam)
50 |
51 | * RMSProp (@rmsprop)
52 |
53 | * LBFGS (@LBFGS)
54 |
55 | Furthermore, decoupled weight decay (@weightdecay) is implemented for all of the adaptive methods listed as well as for SGD,
56 | allowing the technique to be used beyond AdamW alone.
57 |
58 | # References
59 |
--------------------------------------------------------------------------------
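The decoupled weight decay mentioned above is selected through the crate's `Decay` enum rather than a separate optimiser type; a minimal sketch using the AdaDelta parameters that also appear in the test suite below (the 0.004 learning rate and 0.8 decay factor are illustrative):

```rust
use candle_core::{Result, Var};
use candle_nn::Optimizer;
use candle_optimisers::{
    adadelta::{Adadelta, ParamsAdaDelta},
    Decay,
};

// Build an AdaDelta optimiser with AdamW-style decoupled weight decay;
// swap in `Decay::WeightDecay(0.8)` for the standard L2 penalty instead.
fn build_adadelta(vars: Vec<Var>) -> Result<Adadelta> {
    let params = ParamsAdaDelta {
        lr: 0.004,
        weight_decay: Some(Decay::DecoupledWeightDecay(0.8)),
        ..Default::default()
    };
    Adadelta::new(vars, params)
}
```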
/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "candle-optimisers"
3 | version = "0.10.0-alpha.2"
4 | edition = "2021"
5 | readme = "README.md"
6 | license = "MIT"
7 | keywords = ["optimisers", "candle", "tensor", "machine-learning"]
8 | categories = ["science"]
9 | description = "Optimisers for use with candle, the minimalist ML framework"
10 | repository = "https://github.com/KGrewal1/optimisers"
11 | exclude = ["*.ipynb"]
12 |
13 |
14 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
15 |
16 | [dependencies]
17 |
18 | # Using candle 0.9.2-alpha.1 for CUDA 13.0 support via cudarc 0.17.1+
19 | candle-core = "0.9.2-alpha.1"
20 | candle-nn = "0.9.2-alpha.1"
21 | log = "0.4.20"
22 |
23 |
24 | [dev-dependencies]
25 | anyhow = { version = "1", features = ["backtrace"] }
26 | assert_approx_eq = "1.1.0"
27 | candle-datasets = "0.9.2-alpha.1"
28 | clap = { version = "4.4.6", features = ["derive"] }
29 | criterion = { version = "0.7.0", features = ["html_reports"] }
30 |
31 | [[bench]]
32 | name = "mnist_bench"
33 | harness = false
34 |
35 | [features]
36 | default = []
37 | cuda = ["candle-core/cuda", "candle-nn/cuda"]
38 |
39 | [profile.bench]
40 | lto = true # maximal LTO optimisation
41 |
42 | [lints.clippy]
43 | pedantic = { level = "warn", priority = -1 }
44 | suspicious = { level = "warn", priority = -1 }
45 | perf = { level = "warn", priority = -1 }
46 | complexity = { level = "warn", priority = -1 }
47 | style = { level = "warn", priority = -1 }
48 | cargo = { level = "warn", priority = -1 }
49 | imprecise_flops = "warn"
50 | missing_errors_doc = { level = "allow", priority = 1 }
51 | uninlined_format_args = { level = "allow", priority = 1 }
52 | similar_names = { level = "allow", priority = 1 }
53 | float_cmp = { level = "allow", priority = 1 } # as internally rounded before the comparison
54 | doc_markdown = { level = "allow", priority = 1 } # otherwise names get flagged
55 | multiple_crate_versions = { level = "allow", priority = 1 } # for candle dep graph
56 |
57 | [package.metadata.docs.rs]
58 | rustdoc-args = ["--html-in-header", "./katex-header.html"]
59 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Candle Optimisers
2 |
3 | [](https://opensource.org/licenses/MIT)
4 | [](https://codecov.io/gh/KGrewal1/optimisers)
5 | 
6 | 
7 | [](https://crates.io/crates/candle-optimisers)
8 | [](https://docs.rs/candle-optimisers)
9 |
10 | A crate of optimisers for use with [candle](https://github.com/huggingface/candle), the minimalist ML framework.
11 |
12 | Optimisers implemented are:
13 |
14 | * SGD (including momentum and weight decay)
15 |
16 | * RMSprop
17 |
18 | Adaptive methods:
19 |
20 | * AdaDelta
21 |
22 | * AdaGrad
23 |
24 | * AdaMax
25 |
26 | * Adam
27 |
28 | * AdamW (included with Adam as `decoupled_weight_decay`)
29 |
30 | * NAdam
31 |
32 | * RAdam
33 |
34 | These are all checked against their PyTorch implementations (see pytorch_test.ipynb) and should implement the same functionality (though without some input checking).
35 |
36 | Additionally, all of the adaptive methods listed and SGD implement decoupled weight decay as described in [Decoupled Weight Decay Regularization](https://arxiv.org/pdf/1711.05101.pdf), in addition to the standard weight decay as implemented in pytorch.
37 |
38 | Pseudosecond order methods:
39 |
40 | * LBFGS
41 |
42 | This is not implemented identically to pytorch, but is checked against the 2D Rosenbrock function.
43 |
44 | ## Examples
45 |
46 | There is an MNIST toy program along with a simple example of adagrad. Whilst the parameters of each method aren't tuned (all defaults, with a user-input learning rate), the following converges quite nicely:
47 |
48 | ```cli
49 | cargo r -r --example mnist mlp --optim r-adam --epochs 2000 --learning-rate 0.025
50 | ```
51 |
52 | For even faster training try:
53 |
54 | ```cli
55 | cargo r -r --features cuda --example mnist mlp --optim r-adam --epochs 2000 --learning-rate 0.025
56 | ```
57 |
58 | to use the cuda backend.
59 |
60 | ## Usage
61 |
62 | ```cli
63 | cargo add --git https://github.com/KGrewal1/optimisers.git candle-optimisers
64 | ```
65 |
66 | ## Documentation
67 |
68 | Documentation is available on [docs.rs](https://docs.rs/candle-optimisers).
69 |
--------------------------------------------------------------------------------
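For a sense of what the API described in the README looks like in code, here is a minimal sketch of fitting the toy linear problem used throughout the test suite with AdaDelta (the other optimisers follow the same `Params*`/`new`/`backward_step` pattern; the learning rate and step count are illustrative):

```rust
use candle_core::{Device, Result, Tensor, Var};
use candle_nn::{Linear, Module, Optimizer};
use candle_optimisers::adadelta::{Adadelta, ParamsAdaDelta};

fn main() -> Result<()> {
    // Trainable parameters of a toy linear model y = w.x + b.
    let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
    let b = Var::new(0f32, &Device::Cpu)?;
    let model = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));

    // Samples drawn from y = 3.x1 + x2 - 2, as in the tests.
    let xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
    let ys = Tensor::new(&[[5f32], [23.], [-2.], [21.]], &Device::Cpu)?;

    let params = ParamsAdaDelta {
        lr: 0.004,
        ..Default::default()
    };
    let mut optimiser = Adadelta::new(vec![w.clone(), b.clone()], params)?;

    // Each step backpropagates the squared-error loss and updates w and b in place.
    for _ in 0..100 {
        let loss = model.forward(&xs)?.sub(&ys)?.sqr()?.sum_all()?;
        optimiser.backward_step(&loss)?;
    }
    Ok(())
}
```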
/Changelog.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | ## v0.10.0-alpha.2 (2025-11-15)
4 |
5 | * Bump candle requirement to 0.9.2-alpha.1: this adds support for CUDA 13.0 via cudarc 0.17.1+
6 | * Use published alpha versions from crates.io instead of git revisions, allowing this version to be published to crates.io
7 |
8 | ## v0.10.0-alpha.1 (2025-08-12)
9 |
10 | * Prerelease for candle 0.10.0 compatibility
11 | * Remove failing codecov
12 | * Clippy lints
13 |
14 | ## v0.9.0 (2024-11-18)
15 |
16 | * Bump candle requirement to 0.9.0: this is considered a breaking change due to the reliance of this library on candle-core and candle-nn
17 |
18 | ## v0.8.0 (2024-11-18)
19 |
20 | * Bump candle requirement to 0.8.0: this is considered a breaking change due to the reliance of this library on candle-core and candle-nn
21 |
22 | ## v0.5.0 (2024-02-28)
23 |
24 | * Bump candle requirement to 0.5.0: this is considered a breaking change due to the reliance of this library on candle-core and candle-nn
25 | * Internal changes for LBFGS line search
26 |
27 | ## v0.4.0 (2024-02-28)
28 |
29 | * Bump candle requirement to 0.4.0: this is considered a breaking change due to the reliance of this library on candle-core and candle-nn
30 | * Explicit reliance on the candle crates hosted on crates.io: as cargo does not support git dependencies in published crates, this library now points only to the crates.io releases (previously cargo would default to crates.io rather than the git repo anyway; if the git repo is specifically desired, this can be obtained by patching `Cargo.toml` to point at the candle repo)
31 | * Remove the intel-mkl feature: features in this library are mainly used for running the examples, so any code that uses this library should instead enable the features directly on the candle crates
32 |
33 | ## v0.3.2 (2024-01-07)
34 |
35 | * Move directional evaluation into strong Wolfe line search
36 | * Fix strong Wolfe condition when used with weight decay
37 |
38 | ## v0.3.1 (2023-12-20)
39 |
40 | * Improved Documentation
41 | * Add ability to set more optimiser parameters (see issue regarding LR schedulers in `candle`)
42 | * All params are now `Clone`, `PartialEq` and `PartialOrd`
43 |
44 | ## v0.3.0 (2023-12-07)
45 |
46 | * Renamed to candle-optimisers for release on crates.io
47 | * Added fuller documentation
48 | * Added decoupled weight decay for SGD
49 |
50 | ## v0.2.1 (2023-12-06)
51 |
52 | * Added decoupled weight decay for all adaptive methods
53 |
54 | ## v0.2.0 (2023-12-06)
55 |
56 | * changed weight decay to `Option` type as opposed to checking for 0
57 | * made `enum` for decoupled weight decay and for momentum
58 | * added weight decay to LBFGS
59 |
60 | ## v0.1.x
61 |
62 | * Initial release and addition of features
63 |
--------------------------------------------------------------------------------
/benches/mnist_bench.rs:
--------------------------------------------------------------------------------
1 | use candle_core::Result as CResult;
2 | use candle_datasets::vision::Dataset;
3 | use candle_optimisers::{
4 | adadelta::Adadelta, adagrad::Adagrad, adam::Adam, adamax::Adamax, esgd::SGD, nadam::NAdam,
5 | radam::RAdam, rmsprop::RMSprop,
6 | };
7 | use criterion::{criterion_group, criterion_main, Criterion};
8 | use training::Mlp;
9 |
10 | // mod models;
11 | // mod optim;
12 | mod training;
13 |
14 | fn load_data() -> CResult<Dataset> {
15 | candle_datasets::vision::mnist::load()
16 | }
17 |
18 | #[allow(clippy::missing_panics_doc)]
19 | pub fn criterion_benchmark_std(c: &mut Criterion) {
20 | let mut group = c.benchmark_group("std-optimisers");
21 | let m = &load_data().expect("Failed to load data");
22 | // let m = Rc::new(m);
23 |
24 | group.significance_level(0.1).sample_size(100);
25 | group.bench_function("adadelta", |b| {
26 | b.iter(|| {
27 | training::run_training::<Mlp, Adadelta>(m).expect("Failed to setup training");
28 | });
29 | });
30 | group.bench_function("adagrad", |b| {
31 | b.iter(|| {
32 | training::run_training::<Mlp, Adagrad>(m).expect("Failed to setup training");
33 | });
34 | });
35 | group.bench_function("adam", |b| {
36 | b.iter(|| training::run_training::<Mlp, Adam>(m).expect("Failed to setup training"));
37 | });
38 | group.bench_function("adamax", |b| {
39 | b.iter(|| {
40 | training::run_training::<Mlp, Adamax>(m).expect("Failed to setup training");
41 | });
42 | });
43 | group.bench_function("esgd", |b| {
44 | b.iter(|| {
45 | training::run_training::<Mlp, SGD>(m).expect("Failed to setup training");
46 | });
47 | });
48 | group.bench_function("nadam", |b| {
49 | b.iter(|| {
50 | training::run_training::<Mlp, NAdam>(m).expect("Failed to setup training");
51 | });
52 | });
53 | group.bench_function("radam", |b| {
54 | b.iter(|| {
55 | training::run_training::<Mlp, RAdam>(m).expect("Failed to setup training");
56 | });
57 | });
58 | group.bench_function("rmsprop", |b| {
59 | b.iter(|| {
60 | training::run_training::<Mlp, RMSprop>(m).expect("Failed to setup training");
61 | });
62 | });
63 |
64 | group.finish();
65 | }
66 |
67 | #[allow(clippy::missing_panics_doc)]
68 | pub fn criterion_benchmark_lbfgs(c: &mut Criterion) {
69 | let mut group = c.benchmark_group("lbfgs-optimser");
70 | let m = load_data().expect("Failed to load data");
71 | // let m = Rc::new(m);
72 |
73 | group.significance_level(0.1).sample_size(10);
74 |
75 | group.bench_function("lbfgs", |b| {
76 | b.iter(|| training::run_lbfgs_training::<Mlp>(&m).expect("Failed to setup training"));
77 | });
78 |
79 | group.finish();
80 | }
81 |
82 | criterion_group!(benches, criterion_benchmark_std, criterion_benchmark_lbfgs);
83 | criterion_main!(benches);
84 |
--------------------------------------------------------------------------------
/examples/mnist/training.rs:
--------------------------------------------------------------------------------
1 | use crate::{models::Model, optim::Optim, parse_cli::TrainingArgs};
2 | use candle_core::{DType, D};
3 | use candle_nn::{loss, ops, VarBuilder, VarMap};
4 |
5 | #[allow(clippy::module_name_repetitions)]
6 | pub fn training_loop<M: Model, O: Optim>(
7 | m: candle_datasets::vision::Dataset,
8 | args: &TrainingArgs,
9 | ) -> anyhow::Result<()> {
10 | // check to see if a cuda device is available
11 | let dev = candle_core::Device::cuda_if_available(0)?;
12 | println!("Training on device {:?}", dev);
13 |
14 | // get the labels from the dataset
15 | let train_labels = m.train_labels;
16 | // get the input from the dataset and put on device
17 | let train_images = m.train_images.to_device(&dev)?;
18 | // get the training labels on the device
19 | let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
20 |
21 | // create a new variable store
22 | let mut varmap = VarMap::new();
23 | // create a new variable builder
24 | let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
25 | // create model from variables
26 | let model = M::new(vs.clone())?;
27 |
28 | // see if there are pretrained weights to load
29 | if let Some(load) = &args.load {
30 | println!("loading weights from {load}");
31 | varmap.load(load)?;
32 | }
33 |
34 | // create an optimiser
35 | let mut optimiser = O::new(varmap.all_vars(), args.learning_rate)?;
36 | // candle_nn::SGD::new(varmap.all_vars(), args.learning_rate)?;
37 | // load the test images
38 | let test_images = m.test_images.to_device(&dev)?;
39 | // load the test labels
40 | let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
41 | // loop for model optimisation
42 | for epoch in 0..args.epochs {
43 | // get log probabilities of results
44 | let logits = model.forward(&train_images)?;
45 | // softmax the log probabilities
46 | let log_sm = ops::log_softmax(&logits, D::Minus1)?;
47 | // get the loss
48 | let loss = loss::nll(&log_sm, &train_labels)?;
49 | // step the tensors by backpropagating the loss
50 | optimiser.back_step(&loss)?;
51 |
52 | // get the log probabilities of the test images
53 | let test_logits = model.forward(&test_images)?;
54 | // get the sum of the correct predictions
55 | let sum_ok = test_logits
56 | .argmax(D::Minus1)?
57 | .eq(&test_labels)?
58 | .to_dtype(DType::F32)?
59 | .sum_all()?
60 | .to_scalar::<f32>()?;
61 | // get the accuracy on the test set
62 | #[allow(clippy::cast_precision_loss)]
63 | let test_accuracy = sum_ok / test_labels.dims1()? as f32;
64 | println!(
65 | "{:4} train loss: {:8.5} test acc: {:5.2}%",
66 | epoch + 1,
67 | loss.to_scalar::<f32>()?,
68 | 100. * test_accuracy
69 | );
70 | }
71 | if let Some(save) = &args.save {
72 | println!("saving trained weights in {save}");
73 | varmap.save(save)?;
74 | }
75 | Ok(())
76 | }
77 |
--------------------------------------------------------------------------------
/examples/mnist/main.rs:
--------------------------------------------------------------------------------
1 | use clap::Parser;
2 |
3 | mod models;
4 | mod optim;
5 | mod parse_cli;
6 | mod training;
7 |
8 | use models::{LinearModel, Mlp};
9 |
10 | use candle_optimisers::adagrad::Adagrad;
11 | use candle_optimisers::adamax::Adamax;
12 | use candle_optimisers::esgd::SGD;
13 | use candle_optimisers::nadam::NAdam;
14 | use candle_optimisers::radam::RAdam;
15 | use candle_optimisers::rmsprop::RMSprop;
16 | use candle_optimisers::{adadelta::Adadelta, adam::Adam};
17 |
18 | use parse_cli::{Args, TrainingArgs, WhichModel, WhichOptim};
19 | use training::training_loop;
20 | pub fn main() -> anyhow::Result<()> {
21 | let args = Args::parse();
22 | // Load the dataset
23 | let m = if let Some(directory) = args.local_mnist {
24 | candle_datasets::vision::mnist::load_dir(directory)?
25 | } else {
26 | candle_datasets::vision::mnist::load()?
27 | };
28 | println!("train-images: {:?}", m.train_images.shape());
29 | println!("train-labels: {:?}", m.train_labels.shape());
30 | println!("test-images: {:?}", m.test_images.shape());
31 | println!("test-labels: {:?}", m.test_labels.shape());
32 |
33 | let default_learning_rate = match args.model {
34 | WhichModel::Linear => 1.,
35 | WhichModel::Mlp => 0.05,
36 | };
37 | let training_args = TrainingArgs {
38 | epochs: args.epochs,
39 | learning_rate: args.learning_rate.unwrap_or(default_learning_rate),
40 | load: args.load,
41 | save: args.save,
42 | };
43 |
44 | match args.optim {
45 | WhichOptim::Adadelta => match args.model {
46 | WhichModel::Linear => training_loop::<LinearModel, Adadelta>(m, &training_args),
47 | WhichModel::Mlp => training_loop::<Mlp, Adadelta>(m, &training_args),
48 | },
49 | WhichOptim::Adagrad => match args.model {
50 | WhichModel::Linear => training_loop::<LinearModel, Adagrad>(m, &training_args),
51 | WhichModel::Mlp => training_loop::<Mlp, Adagrad>(m, &training_args),
52 | },
53 | WhichOptim::Adamax => match args.model {
54 | WhichModel::Linear => training_loop::<LinearModel, Adamax>(m, &training_args),
55 | WhichModel::Mlp => training_loop::<Mlp, Adamax>(m, &training_args),
56 | },
57 | WhichOptim::Sgd => match args.model {
58 | WhichModel::Linear => training_loop::<LinearModel, SGD>(m, &training_args),
59 | WhichModel::Mlp => training_loop::<Mlp, SGD>(m, &training_args),
60 | },
61 | WhichOptim::NAdam => match args.model {
62 | WhichModel::Linear => training_loop::<LinearModel, NAdam>(m, &training_args),
63 | WhichModel::Mlp => training_loop::<Mlp, NAdam>(m, &training_args),
64 | },
65 | WhichOptim::RAdam => match args.model {
66 | WhichModel::Linear => training_loop::<LinearModel, RAdam>(m, &training_args),
67 | WhichModel::Mlp => training_loop::<Mlp, RAdam>(m, &training_args),
68 | },
69 | WhichOptim::Rms => match args.model {
70 | WhichModel::Linear => training_loop::<LinearModel, RMSprop>(m, &training_args),
71 | WhichModel::Mlp => training_loop::<Mlp, RMSprop>(m, &training_args),
72 | },
73 | WhichOptim::Adam => match args.model {
74 | WhichModel::Linear => training_loop::<LinearModel, Adam>(m, &training_args),
75 | WhichModel::Mlp => training_loop::<Mlp, Adam>(m, &training_args),
76 | },
77 | }
78 | }
79 |
--------------------------------------------------------------------------------
/examples/mnist/optim.rs:
--------------------------------------------------------------------------------
1 | use candle_core::{Result, Tensor, Var};
2 | use candle_nn::Optimizer;
3 | use candle_optimisers::{
4 | adadelta::{Adadelta, ParamsAdaDelta},
5 | adagrad::{Adagrad, ParamsAdaGrad},
6 | adam::{Adam, ParamsAdam},
7 | adamax::{Adamax, ParamsAdaMax},
8 | esgd::{ParamsSGD, SGD},
9 | nadam::{NAdam, ParamsNAdam},
10 | radam::{ParamsRAdam, RAdam},
11 | rmsprop::{ParamsRMSprop, RMSprop},
12 | };
13 |
14 | pub trait Optim: Sized {
15 | fn new(vars: Vec<Var>, lr: f64) -> Result<Self>;
16 | fn back_step(&mut self, loss: &Tensor) -> Result<()>;
17 | }
18 |
19 | impl Optim for Adadelta {
20 | fn new(vars: Vec<Var>, lr: f64) -> Result<Self> {
21 | <Adadelta as Optimizer>::new(
22 | vars,
23 | ParamsAdaDelta {
24 | lr,
25 | ..Default::default()
26 | },
27 | )
28 | }
29 |
30 | fn back_step(&mut self, loss: &Tensor) -> Result<()> {
31 | self.backward_step(loss)
32 | }
33 | }
34 |
35 | impl Optim for Adagrad {
36 | fn new(vars: Vec<Var>, lr: f64) -> Result<Self> {
37 | <Adagrad as Optimizer>::new(
38 | vars,
39 | ParamsAdaGrad {
40 | lr,
41 | ..Default::default()
42 | },
43 | )
44 | }
45 |
46 | fn back_step(&mut self, loss: &Tensor) -> Result<()> {
47 | self.backward_step(loss)
48 | }
49 | }
50 |
51 | impl Optim for Adamax {
52 | fn new(vars: Vec<Var>, lr: f64) -> Result<Self> {
53 | <Adamax as Optimizer>::new(
54 | vars,
55 | ParamsAdaMax {
56 | lr,
57 | ..Default::default()
58 | },
59 | )
60 | }
61 |
62 | fn back_step(&mut self, loss: &Tensor) -> Result<()> {
63 | self.backward_step(loss)
64 | }
65 | }
66 |
67 | impl Optim for SGD {
68 | fn new(vars: Vec<Var>, lr: f64) -> Result<Self> {
69 | <SGD as Optimizer>::new(
70 | vars,
71 | ParamsSGD {
72 | lr,
73 | ..Default::default()
74 | },
75 | )
76 | }
77 |
78 | fn back_step(&mut self, loss: &Tensor) -> Result<()> {
79 | self.backward_step(loss)
80 | }
81 | }
82 |
83 | impl Optim for NAdam {
84 | fn new(vars: Vec<Var>, lr: f64) -> Result<Self> {
85 | <NAdam as Optimizer>::new(
86 | vars,
87 | ParamsNAdam {
88 | lr,
89 | ..Default::default()
90 | },
91 | )
92 | }
93 |
94 | fn back_step(&mut self, loss: &Tensor) -> Result<()> {
95 | self.backward_step(loss)
96 | }
97 | }
98 |
99 | impl Optim for RAdam {
100 | fn new(vars: Vec<Var>, lr: f64) -> Result<Self> {
101 | <RAdam as Optimizer>::new(
102 | vars,
103 | ParamsRAdam {
104 | lr,
105 | ..Default::default()
106 | },
107 | )
108 | }
109 |
110 | fn back_step(&mut self, loss: &Tensor) -> Result<()> {
111 | self.backward_step(loss)
112 | }
113 | }
114 |
115 | impl Optim for RMSprop {
116 | fn new(vars: Vec<Var>, lr: f64) -> Result<Self> {
117 | <RMSprop as Optimizer>::new(
118 | vars,
119 | ParamsRMSprop {
120 | lr,
121 | ..Default::default()
122 | },
123 | )
124 | }
125 |
126 | fn back_step(&mut self, loss: &Tensor) -> Result<()> {
127 | self.backward_step(loss)
128 | }
129 | }
130 |
131 | impl Optim for Adam {
132 | fn new(vars: Vec<Var>, lr: f64) -> Result<Self> {
133 | <Adam as Optimizer>::new(
134 | vars,
135 | ParamsAdam {
136 | lr,
137 | ..Default::default()
138 | },
139 | )
140 | }
141 |
142 | fn back_step(&mut self, loss: &Tensor) -> Result<()> {
143 | self.backward_step(loss)
144 | }
145 | }
146 |
--------------------------------------------------------------------------------
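For reference, a minimal sketch of how the `Optim` trait above is consumed generically: this is the same shape as `training_loop` in `examples/mnist/training.rs`, reduced to the optimiser-facing part (the helper name and closure-based loss are illustrative, not part of the crate):

```rust
use candle_core::{Result, Tensor, Var};

// Run `n` optimisation steps for any optimiser implementing the local `Optim`
// trait, given a closure that recomputes the loss from the current variables.
fn optimise_n_steps<O: Optim>(
    vars: Vec<Var>,
    lr: f64,
    n: usize,
    mut loss_fn: impl FnMut() -> Result<Tensor>,
) -> Result<()> {
    let mut optimiser = O::new(vars, lr)?;
    for _ in 0..n {
        let loss = loss_fn()?;
        optimiser.back_step(&loss)?;
    }
    Ok(())
}
```

Called as, for example, `optimise_n_steps::<Adadelta>(varmap.all_vars(), 0.004, 100, || { /* forward pass and loss */ })`, mirroring how `main.rs` dispatches over the `WhichOptim` variants.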
/tests/adadelta_tests.rs:
--------------------------------------------------------------------------------
1 | use candle_core::test_utils::{to_vec0_round, to_vec2_round};
2 |
3 | use anyhow::Result;
4 | use candle_core::{Device, Tensor, Var};
5 | use candle_nn::{Linear, Module, Optimizer};
6 | use candle_optimisers::{
7 | adadelta::{Adadelta, ParamsAdaDelta},
8 | Decay,
9 | };
10 |
11 | /* The results of this test have been checked against the following PyTorch code.
12 | import torch
13 | from torch import optim
14 |
15 | w_gen = torch.tensor([[3., 1.]])
16 | b_gen = torch.tensor([-2.])
17 |
18 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
19 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
20 |
21 | m = torch.nn.Linear(2, 1)
22 | with torch.no_grad():
23 | m.weight.zero_()
24 | m.bias.zero_()
25 | optimiser = optim.Adadelta(m.parameters(), lr=0.004)
26 | # optimiser.zero_grad()
27 | for _step in range(100):
28 | optimiser.zero_grad()
29 | ys = m(sample_xs)
30 | loss = ((ys - sample_ys)**2).sum()
31 | loss.backward()
32 | optimiser.step()
33 | # print("Optimizer state begin")
34 | # print(optimiser.state)
35 | # print("Optimizer state end")
36 | print(m.weight)
37 | print(m.bias)
38 | */
39 | #[test]
40 | fn adadelta_test() -> Result<()> {
41 | // Generate some linear data, y = 3.x1 + x2 - 2.
42 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
43 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
44 | let gen = Linear::new(w_gen, Some(b_gen));
45 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
46 | let sample_ys = gen.forward(&sample_xs)?;
47 |
48 | let params = ParamsAdaDelta {
49 | lr: 0.004,
50 | rho: 0.9,
51 | weight_decay: None,
52 | eps: 1e-6,
53 | };
54 | // Now use backprop to run a linear regression between samples and get the coefficients back.
55 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
56 | let b = Var::new(0f32, &Device::Cpu)?;
57 | let mut n_sgd = Adadelta::new(vec![w.clone(), b.clone()], params)?;
58 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
59 | for _step in 0..100 {
60 | let ys = lin.forward(&sample_xs)?;
61 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
62 | n_sgd.backward_step(&loss)?;
63 | }
64 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.0016, 0.0016]]);
65 | assert_eq!(to_vec0_round(&b, 4)?, 0.0016);
66 | Ok(())
67 | }
68 |
69 | #[test]
70 | fn adadelta_weight_decay_test() -> Result<()> {
71 | // Generate some linear data, y = 3.x1 + x2 - 2.
72 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
73 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
74 | let gen = Linear::new(w_gen, Some(b_gen));
75 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
76 | let sample_ys = gen.forward(&sample_xs)?;
77 |
78 | let params = ParamsAdaDelta {
79 | lr: 0.004,
80 | rho: 0.9,
81 | weight_decay: Some(Decay::WeightDecay(0.8)),
82 | eps: 1e-6,
83 | };
84 | // Now use backprop to run a linear regression between samples and get the coefficients back.
85 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
86 | let b = Var::new(0f32, &Device::Cpu)?;
87 | let mut n_sgd = Adadelta::new(vec![w.clone(), b.clone()], params)?;
88 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
89 | for _step in 0..100 {
90 | let ys = lin.forward(&sample_xs)?;
91 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
92 | n_sgd.backward_step(&loss)?;
93 | }
94 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.0016, 0.0016]]);
95 | assert_eq!(to_vec0_round(&b, 4)?, 0.0016);
96 | Ok(())
97 | }
98 |
99 | //-------------------------------------------------------------------------
100 | // THIS IS NOT TESTED AGAINST PYTORCH
101 | // AS PYTORCH DOES NOT HAVE DECOUPLED WEIGHT DECAY FOR ADADELTA
102 | // ------------------------------------------------------------------------
103 | #[test]
104 | fn adadelta_decoupled_weight_decay_test() -> Result<()> {
105 | // Generate some linear data, y = 3.x1 + x2 - 2.
106 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
107 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
108 | let gen = Linear::new(w_gen, Some(b_gen));
109 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
110 | let sample_ys = gen.forward(&sample_xs)?;
111 |
112 | let params = ParamsAdaDelta {
113 | lr: 0.004,
114 | rho: 0.9,
115 | weight_decay: Some(Decay::DecoupledWeightDecay(0.8)),
116 | eps: 1e-6,
117 | };
118 | // Now use backprop to run a linear regression between samples and get the coefficients back.
119 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
120 | let b = Var::new(0f32, &Device::Cpu)?;
121 | let mut n_sgd = Adadelta::new(vec![w.clone(), b.clone()], params)?;
122 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
123 | for _step in 0..100 {
124 | let ys = lin.forward(&sample_xs)?;
125 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
126 | n_sgd.backward_step(&loss)?;
127 | }
128 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.0014, 0.0014]]);
129 | assert_eq!(to_vec0_round(&b, 4)?, 0.0014);
130 | Ok(())
131 | }
132 |
--------------------------------------------------------------------------------
/tests/radam-tests.rs:
--------------------------------------------------------------------------------
1 | use candle_core::test_utils::{to_vec0_round, to_vec2_round};
2 |
3 | use anyhow::Result;
4 | use candle_core::{Device, Tensor, Var};
5 | use candle_nn::{Linear, Module, Optimizer};
6 | use candle_optimisers::{
7 | radam::{ParamsRAdam, RAdam},
8 | Decay,
9 | };
10 |
11 | /* The results of this test have been checked against the following PyTorch code.
12 | import torch
13 | from torch import optim
14 |
15 | w_gen = torch.tensor([[3., 1.]])
16 | b_gen = torch.tensor([-2.])
17 |
18 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
19 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
20 |
21 | m = torch.nn.Linear(2, 1)
22 | with torch.no_grad():
23 | m.weight.zero_()
24 | m.bias.zero_()
25 | optimiser = optim.RAdam(m.parameters())
26 | # optimiser.zero_grad()
27 | for _step in range(100):
28 | optimiser.zero_grad()
29 | ys = m(sample_xs)
30 | loss = ((ys - sample_ys)**2).sum()
31 | loss.backward()
32 | optimiser.step()
33 | # print("Optimizer state begin")
34 | # print(optimiser.state)
35 | # print("Optimizer state end")
36 | print(m.weight)
37 | print(m.bias)
38 | */
39 | #[test]
40 | fn radam_test() -> Result<()> {
41 | // Generate some linear data, y = 3.x1 + x2 - 2.
42 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
43 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
44 | let gen = Linear::new(w_gen, Some(b_gen));
45 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
46 | let sample_ys = gen.forward(&sample_xs)?;
47 |
48 | let params = ParamsRAdam::default();
49 | // Now use backprop to run a linear regression between samples and get the coefficients back.
50 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
51 | let b = Var::new(0f32, &Device::Cpu)?;
52 | let mut n_sgd = RAdam::new(vec![w.clone(), b.clone()], params)?;
53 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
54 | for _step in 0..100 {
55 | let ys = lin.forward(&sample_xs)?;
56 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
57 | n_sgd.backward_step(&loss)?;
58 | }
59 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.2128, 1.2819]]);
60 | assert_eq!(to_vec0_round(&b, 4)?, 0.2923);
61 | Ok(())
62 | }
63 |
64 | /* The results of this test have been checked against the following PyTorch code.
65 | import torch
66 | from torch import optim
67 |
68 | w_gen = torch.tensor([[3., 1.]])
69 | b_gen = torch.tensor([-2.])
70 |
71 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
72 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
73 |
74 | m = torch.nn.Linear(2, 1)
75 | with torch.no_grad():
76 | m.weight.zero_()
77 | m.bias.zero_()
78 | optimiser = optim.RAdam(m.parameters(), weight_decay = 0.4)
79 | # optimiser.zero_grad()
80 | for _step in range(100):
81 | optimiser.zero_grad()
82 | ys = m(sample_xs)
83 | loss = ((ys - sample_ys)**2).sum()
84 | loss.backward()
85 | optimiser.step()
86 | # print("Optimizer state begin")
87 | # print(optimiser.state)
88 | # print("Optimizer state end")
89 | print(m.weight)
90 | print(m.bias)
91 | */
92 | #[test]
93 | fn radam_weight_decay_test() -> Result<()> {
94 | // Generate some linear data, y = 3.x1 + x2 - 2.
95 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
96 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
97 | let gen = Linear::new(w_gen, Some(b_gen));
98 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
99 | let sample_ys = gen.forward(&sample_xs)?;
100 |
101 | let params = ParamsRAdam {
102 | weight_decay: Some(Decay::WeightDecay(0.4)),
103 | ..Default::default()
104 | };
105 | // Now use backprop to run a linear regression between samples and get the coefficients back.
106 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
107 | let b = Var::new(0f32, &Device::Cpu)?;
108 | let mut n_sgd = RAdam::new(vec![w.clone(), b.clone()], params)?;
109 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
110 | for _step in 0..100 {
111 | let ys = lin.forward(&sample_xs)?;
112 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
113 | n_sgd.backward_step(&loss)?;
114 | }
115 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.2117, 1.2812]]);
116 | assert_eq!(to_vec0_round(&b, 4)?, 0.2921);
117 | Ok(())
118 | }
119 |
120 | //-------------------------------------------------------------------------
121 | // THIS IS NOT TESTED AGAINST PYTORCH
122 | // AS PYTORCH DOES NOT HAVE DECOUPLED WEIGHT DECAY FOR RADAM
123 | // ------------------------------------------------------------------------
124 | #[test]
125 | fn radam_decoupled_weight_decay_test() -> Result<()> {
126 | // Generate some linear data, y = 3.x1 + x2 - 2.
127 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
128 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
129 | let gen = Linear::new(w_gen, Some(b_gen));
130 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
131 | let sample_ys = gen.forward(&sample_xs)?;
132 |
133 | let params = ParamsRAdam {
134 | weight_decay: Some(Decay::DecoupledWeightDecay(0.4)),
135 | ..Default::default()
136 | };
137 | // Now use backprop to run a linear regression between samples and get the coefficients back.
138 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
139 | let b = Var::new(0f32, &Device::Cpu)?;
140 | let mut n_sgd = RAdam::new(vec![w.clone(), b.clone()], params)?;
141 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
142 | for _step in 0..100 {
143 | let ys = lin.forward(&sample_xs)?;
144 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
145 | n_sgd.backward_step(&loss)?;
146 | }
147 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.1294, 1.2331]]);
148 | assert_eq!(to_vec0_round(&b, 4)?, 0.2818);
149 | Ok(())
150 | }
151 |
--------------------------------------------------------------------------------
/tests/adamax_tests.rs:
--------------------------------------------------------------------------------
1 | use candle_core::test_utils::{to_vec0_round, to_vec2_round};
2 |
3 | use anyhow::Result;
4 | use candle_core::{Device, Tensor, Var};
5 | use candle_nn::{Linear, Module, Optimizer};
6 | use candle_optimisers::adamax::{Adamax, ParamsAdaMax};
7 |
8 | /* The results of this test have been checked against the following PyTorch code.
9 | import torch
10 | from torch import optim
11 |
12 | w_gen = torch.tensor([[3., 1.]])
13 | b_gen = torch.tensor([-2.])
14 |
15 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
16 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
17 |
18 | m = torch.nn.Linear(2, 1)
19 | with torch.no_grad():
20 | m.weight.zero_()
21 | m.bias.zero_()
22 | optimiser = optim.Adamax(m.parameters(), lr=0.004)
23 | # optimiser.zero_grad()
24 | for _step in range(100):
25 | optimiser.zero_grad()
26 | ys = m(sample_xs)
27 | loss = ((ys - sample_ys)**2).sum()
28 | loss.backward()
29 | optimiser.step()
30 | # print("Optimizer state begin")
31 | # print(optimiser.state)
32 | # print("Optimizer state end")
33 | print(m.weight)
34 | print(m.bias)
35 | */
36 | #[test]
37 | fn adamax_test() -> Result<()> {
38 | // Generate some linear data, y = 3.x1 + x2 - 2.
39 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
40 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
41 | let gen = Linear::new(w_gen, Some(b_gen));
42 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
43 | let sample_ys = gen.forward(&sample_xs)?;
44 |
45 | let params = ParamsAdaMax {
46 | lr: 0.004,
47 | ..Default::default()
48 | };
49 | // Now use backprop to run a linear regression between samples and get the coefficients back.
50 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
51 | let b = Var::new(0f32, &Device::Cpu)?;
52 | let mut n_sgd = Adamax::new(vec![w.clone(), b.clone()], params)?;
53 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
54 | for _step in 0..100 {
55 | let ys = lin.forward(&sample_xs)?;
56 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
57 | n_sgd.backward_step(&loss)?;
58 | }
59 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.3895, 0.3450]]);
60 | assert_eq!(to_vec0_round(&b, 4)?, 0.3643);
61 | Ok(())
62 | }
63 |
64 | /* The results of this test have been checked against the following PyTorch code.
65 | import torch
66 | from torch import optim
67 |
68 | w_gen = torch.tensor([[3., 1.]])
69 | b_gen = torch.tensor([-2.])
70 |
71 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
72 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
73 |
74 | m = torch.nn.Linear(2, 1)
75 | with torch.no_grad():
76 | m.weight.zero_()
77 | m.bias.zero_()
78 | optimiser = optim.Adamax(m.parameters(), lr=0.004, weight_decay = 0.6)
79 | # optimiser.zero_grad()
80 | for _step in range(100):
81 | optimiser.zero_grad()
82 | ys = m(sample_xs)
83 | loss = ((ys - sample_ys)**2).sum()
84 | loss.backward()
85 | optimiser.step()
86 | # print("Optimizer state begin")
87 | # print(optimiser.state)
88 | # print("Optimizer state end")
89 | print(m.weight)
90 | print(m.bias)
91 | */
92 | #[test]
93 | fn adamax_weight_decay_test() -> Result<()> {
94 | // Generate some linear data, y = 3.x1 + x2 - 2.
95 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
96 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
97 | let gen = Linear::new(w_gen, Some(b_gen));
98 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
99 | let sample_ys = gen.forward(&sample_xs)?;
100 |
101 | let params = ParamsAdaMax {
102 | lr: 0.004,
103 | weight_decay: Some(candle_optimisers::Decay::WeightDecay(0.6)),
104 | ..Default::default()
105 | };
106 | // Now use backprop to run a linear regression between samples and get the coefficients back.
107 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
108 | let b = Var::new(0f32, &Device::Cpu)?;
109 | let mut n_sgd = Adamax::new(vec![w.clone(), b.clone()], params)?;
110 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
111 | for _step in 0..100 {
112 | let ys = lin.forward(&sample_xs)?;
113 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
114 | n_sgd.backward_step(&loss)?;
115 | }
116 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.3894, 0.3450]]);
117 | assert_eq!(to_vec0_round(&b, 4)?, 0.3639);
118 | Ok(())
119 | }
120 |
121 | //-------------------------------------------------------------------------
122 | // THIS IS NOT TESTED AGAINST PYTORCH
123 | // AS PYTORCH DOES NOT HAVE DECOUPLED WEIGHT DECAY FOR ADAMAX
124 | // ------------------------------------------------------------------------
125 | #[test]
126 | fn adamax_decoupled_weight_decay_test() -> Result<()> {
127 | // Generate some linear data, y = 3.x1 + x2 - 2.
128 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
129 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
130 | let gen = Linear::new(w_gen, Some(b_gen));
131 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
132 | let sample_ys = gen.forward(&sample_xs)?;
133 |
134 | let params = ParamsAdaMax {
135 | lr: 0.004,
136 | weight_decay: Some(candle_optimisers::Decay::DecoupledWeightDecay(0.6)),
137 | ..Default::default()
138 | };
139 | // Now use backprop to run a linear regression between samples and get the coefficients back.
140 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
141 | let b = Var::new(0f32, &Device::Cpu)?;
142 | let mut n_sgd = Adamax::new(vec![w.clone(), b.clone()], params)?;
143 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
144 | for _step in 0..100 {
145 | let ys = lin.forward(&sample_xs)?;
146 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
147 | n_sgd.backward_step(&loss)?;
148 | }
149 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.3481, 0.3095]]);
150 | assert_eq!(to_vec0_round(&b, 4)?, 0.3263);
151 | Ok(())
152 | }
153 |
--------------------------------------------------------------------------------
/paper/paper.bib:
--------------------------------------------------------------------------------
1 | @article{adadelta,
2 | author = {Matthew D. Zeiler},
3 | title = {{ADADELTA:} An Adaptive Learning Rate Method},
4 | journal = {CoRR},
5 | volume = {abs/1212.5701},
6 | year = {2012},
7 | url = {http://arxiv.org/abs/1212.5701},
8 | eprinttype = {arXiv},
9 | eprint = {1212.5701},
10 | timestamp = {Mon, 13 Aug 2018 16:45:57 +0200},
11 | biburl = {https://dblp.org/rec/journals/corr/abs-1212-5701.bib},
12 | bibsource = {dblp computer science bibliography, https://dblp.org},
13 | doi = {10.48550/arXiv.1212.5701}
14 | }
15 |
16 | @article{adagrad,
17 | author = {John Duchi and Elad Hazan and Yoram Singer},
18 | title = {Adaptive Subgradient Methods for Online Learning and Stochastic Optimization},
19 | journal = {Journal of Machine Learning Research},
20 | year = {2011},
21 | volume = {12},
22 | number = {61},
23 | pages = {2121--2159},
24 | url = {http://jmlr.org/papers/v12/duchi11a.html},
25 |
26 | }
27 |
28 | @inproceedings{adam,
29 | author = {Diederik P. Kingma and
30 | Jimmy Ba},
31 | editor = {Yoshua Bengio and
32 | Yann LeCun},
33 | title = {Adam: {A} Method for Stochastic Optimization},
34 | booktitle = {3rd International Conference on Learning Representations, {ICLR} 2015,
35 | San Diego, CA, USA, May 7-9, 2015, Conference Track Proceedings},
36 | year = {2015},
37 | url = {http://arxiv.org/abs/1412.6980},
38 | timestamp = {Thu, 25 Jul 2019 14:25:37 +0200},
39 | biburl = {https://dblp.org/rec/journals/corr/KingmaB14.bib},
40 | bibsource = {dblp computer science bibliography, https://dblp.org},
41 | doi = {10.48550/arXiv.1412.6980}
42 | }
43 |
44 | @article{weightdecay,
45 | author = {Ilya Loshchilov and
46 | Frank Hutter},
47 | title = {Fixing Weight Decay Regularization in Adam},
48 | journal = {CoRR},
49 | volume = {abs/1711.05101},
50 | year = {2017},
51 | url = {http://arxiv.org/abs/1711.05101},
52 | eprinttype = {arXiv},
53 | eprint = {1711.05101},
54 | timestamp = {Mon, 13 Aug 2018 16:48:18 +0200},
55 | biburl = {https://dblp.org/rec/journals/corr/abs-1711-05101.bib},
56 | bibsource = {dblp computer science bibliography, https://dblp.org},
57 | doi = {10.48550/arXiv.1711.05101}
58 | }
59 |
60 | @inproceedings{amsgrad,
61 | title={On the Convergence of Adam and Beyond},
62 | author={Sashank J. Reddi and Satyen Kale and Sanjiv Kumar},
63 | booktitle={International Conference on Learning Representations},
64 | year={2018},
65 | url={https://openreview.net/forum?id=ryQu7f-RZ}
66 | }
67 |
68 | @inproceedings{nmomentum,
69 | title = {On the importance of initialization and momentum in deep learning},
70 | author = {Sutskever, Ilya and Martens, James and Dahl, George and Hinton, Geoffrey},
71 | booktitle = {Proceedings of the 30th International Conference on Machine Learning},
72 | pages = {1139--1147},
73 | year = {2013},
74 | editor = {Dasgupta, Sanjoy and McAllester, David},
75 | volume = {28},
76 | number = {3},
77 | series = {Proceedings of Machine Learning Research},
78 | address = {Atlanta, Georgia, USA},
79 | month = {17--19 Jun},
80 | publisher = {PMLR},
81 | pdf = {http://proceedings.mlr.press/v28/sutskever13.pdf},
82 | url = {https://proceedings.mlr.press/v28/sutskever13.html},
83 | abstract = {Deep and recurrent neural networks (DNNs and RNNs respectively) are powerful models that were considered to be almost impossible to train using stochastic gradient descent with momentum. In this paper, we show that when stochastic gradient descent with momentum uses a well-designed random initialization and a particular type of slowly increasing schedule for the momentum parameter, it can train both DNNs and RNNs (on datasets with long-term dependencies) to levels of performance that were previously achievable only with Hessian-Free optimization. We find that both the initialization and the momentum are crucial since poorly initialized networks cannot be trained with momentum and well-initialized networks perform markedly worse when the momentum is absent or poorly tuned. Our success training these models suggests that previous attempts to train deep and recurrent neural networks from random initializations have likely failed due to poor initialization schemes. Furthermore, carefully tuned momentum methods suffice for dealing with the curvature issues in deep and recurrent network training objectives without the need for sophisticated second-order methods. }
84 | }
85 |
86 |
87 | @article{LBFGS,
88 | title={On the limited memory BFGS method for large scale optimization},
89 | author={Liu, Dong C and Nocedal, Jorge},
90 | journal={Mathematical programming},
91 | volume={45},
92 | number={1-3},
93 | pages={503--528},
94 | year={1989},
95 | publisher={Springer},
96 | doi = {10.1007/BF01589116}
97 | }
98 |
99 | @inproceedings{nadam,
100 | title = {Incorporating {Nesterov Momentum into Adam}},
101 | author = {Dozat, Timothy},
102 | booktitle = {Proceedings of the 4th International Conference on Learning Representations},
103 | pages = {1--4},
104 | date = 2016,
105 | url = {https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ},
106 | biburl = {https://bibbase.org/network/publication/dozat-incorporatingnesterovmomentumintoadam}
107 | }
108 |
109 | @article{radam,
110 | author = {Liyuan Liu and
111 | Haoming Jiang and
112 | Pengcheng He and
113 | Weizhu Chen and
114 | Xiaodong Liu and
115 | Jianfeng Gao and
116 | Jiawei Han},
117 | title = {On the Variance of the Adaptive Learning Rate and Beyond},
118 | journal = {CoRR},
119 | volume = {abs/1908.03265},
120 | year = {2019},
121 | url = {http://arxiv.org/abs/1908.03265},
122 | eprinttype = {arXiv},
123 | eprint = {1908.03265},
124 | timestamp = {Mon, 30 May 2022 13:48:56 +0200},
125 | biburl = {https://dblp.org/rec/journals/corr/abs-1908-03265.bib},
126 | bibsource = {dblp computer science bibliography, https://dblp.org},
127 | doi = {10.48550/arXiv.1908.03265}
128 | }
129 |
130 | @misc{rmsprop,
131 | author = {Geoffrey Hinton and
132 | Nitish Srivastava and
133 | Kevin Swersky},
134 | title = {Neural Networks for Machine Learning},
135 | year = {2012},
136 | url = {https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf}
137 | }
138 |
139 | @misc{candle,
140 | title={Candle: a Minimalist ML Framework for Rust},
141 | author={Laurent Mazar\'e and others},
142 | year={2023},
143 | url={https://github.com/huggingface/candle/},
144 | }
--------------------------------------------------------------------------------
/tests/nadam_tests.rs:
--------------------------------------------------------------------------------
1 | use candle_core::test_utils::{to_vec0_round, to_vec2_round};
2 |
3 | use anyhow::Result;
4 | use candle_core::{Device, Tensor, Var};
5 | use candle_nn::{Linear, Module, Optimizer};
6 | use candle_optimisers::{
7 | nadam::{NAdam, ParamsNAdam},
8 | Decay,
9 | };
10 |
11 | /* The results of this test have been checked against the following PyTorch code.
12 | import torch
13 | from torch import optim
14 |
15 | w_gen = torch.tensor([[3., 1.]])
16 | b_gen = torch.tensor([-2.])
17 |
18 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
19 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
20 |
21 | m = torch.nn.Linear(2, 1)
22 | with torch.no_grad():
23 | m.weight.zero_()
24 | m.bias.zero_()
25 | optimiser = optim.NAdam(m.parameters())
26 | # optimiser.zero_grad()
27 | for _step in range(100):
28 | optimiser.zero_grad()
29 | ys = m(sample_xs)
30 | loss = ((ys - sample_ys)**2).sum()
31 | loss.backward()
32 | optimiser.step()
33 | # print("Optimizer state begin")
34 | # print(optimiser.state)
35 | # print("Optimizer state end")
36 | print(m.weight)
37 | print(m.bias)
38 | */
39 | #[test]
40 | fn nadam_test() -> Result<()> {
41 | // Generate some linear data, y = 3.x1 + x2 - 2.
42 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
43 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
44 | let gen = Linear::new(w_gen, Some(b_gen));
45 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
46 | let sample_ys = gen.forward(&sample_xs)?;
47 |
48 | let params = ParamsNAdam::default();
49 | // Now use backprop to run a linear regression between samples and get the coefficients back.
50 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
51 | let b = Var::new(0f32, &Device::Cpu)?;
52 | let mut n_sgd = NAdam::new(vec![w.clone(), b.clone()], params)?;
53 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
54 | for _step in 0..100 {
55 | let ys = lin.forward(&sample_xs)?;
56 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
57 | n_sgd.backward_step(&loss)?;
58 | }
59 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.1897, 0.1837]]);
60 | assert_eq!(to_vec0_round(&b, 4)?, 0.1864);
61 | Ok(())
62 | }
63 |
64 | /* The results of this test have been checked against the following PyTorch code.
65 | import torch
66 | from torch import optim
67 |
68 | w_gen = torch.tensor([[3., 1.]])
69 | b_gen = torch.tensor([-2.])
70 |
71 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
72 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
73 |
74 | m = torch.nn.Linear(2, 1)
75 | with torch.no_grad():
76 | m.weight.zero_()
77 | m.bias.zero_()
78 | optimiser = optim.NAdam(m.parameters(), weight_decay = 0.6)
79 | # optimiser.zero_grad()
80 | for _step in range(100):
81 | optimiser.zero_grad()
82 | ys = m(sample_xs)
83 | loss = ((ys - sample_ys)**2).sum()
84 | loss.backward()
85 | optimiser.step()
86 | # print("Optimizer state begin")
87 | # print(optimiser.state)
88 | # print("Optimizer state end")
89 | print(m.weight)
90 | print(m.bias)
91 | */
92 | #[test]
93 | fn nadam_weight_decay_test() -> Result<()> {
94 | // Generate some linear data, y = 3.x1 + x2 - 2.
95 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
96 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
97 | let gen = Linear::new(w_gen, Some(b_gen));
98 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
99 | let sample_ys = gen.forward(&sample_xs)?;
100 |
101 | let params = ParamsNAdam {
102 | weight_decay: Some(Decay::WeightDecay(0.6)),
103 | ..Default::default()
104 | };
105 | // Now use backprop to run a linear regression between samples and get the coefficients back.
106 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
107 | let b = Var::new(0f32, &Device::Cpu)?;
108 | let mut n_sgd = NAdam::new(vec![w.clone(), b.clone()], params)?;
109 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
110 | for _step in 0..100 {
111 | let ys = lin.forward(&sample_xs)?;
112 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
113 | n_sgd.backward_step(&loss)?;
114 | }
115 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.1897, 0.1837]]);
116 | assert_eq!(to_vec0_round(&b, 4)?, 0.1863);
117 | Ok(())
118 | }
119 |
120 | /* The results of this test have been checked against the following PyTorch code.
121 | import torch
122 | from torch import optim
123 |
124 | w_gen = torch.tensor([[3., 1.]])
125 | b_gen = torch.tensor([-2.])
126 |
127 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
128 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
129 |
130 | m = torch.nn.Linear(2, 1)
131 | with torch.no_grad():
132 | m.weight.zero_()
133 | m.bias.zero_()
134 | optimiser = optim.NAdam(m.parameters(), weight_decay = 0.6, decoupled_weight_decay=True)
135 | # optimiser.zero_grad()
136 | for _step in range(100):
137 | optimiser.zero_grad()
138 | ys = m(sample_xs)
139 | loss = ((ys - sample_ys)**2).sum()
140 | loss.backward()
141 | optimiser.step()
142 | # print("Optimizer state begin")
143 | # print(optimiser.state)
144 | # print("Optimizer state end")
145 | print(m.weight)
146 | print(m.bias)
147 | */
148 | #[test]
149 | fn nadam_decoupled_weight_decay_test() -> Result<()> {
150 | // Generate some linear data, y = 3.x1 + x2 - 2.
151 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
152 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
153 | let gen = Linear::new(w_gen, Some(b_gen));
154 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
155 | let sample_ys = gen.forward(&sample_xs)?;
156 |
157 | let params = ParamsNAdam {
158 | weight_decay: Some(Decay::DecoupledWeightDecay(0.6)),
159 | ..Default::default()
160 | };
161 | // Now use backprop to run a linear regression between samples and get the coefficients back.
162 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
163 | let b = Var::new(0f32, &Device::Cpu)?;
164 | let mut n_sgd = NAdam::new(vec![w.clone(), b.clone()], params)?;
165 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
166 | for _step in 0..100 {
167 | let ys = lin.forward(&sample_xs)?;
168 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
169 | n_sgd.backward_step(&loss)?;
170 | }
171 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.1792, 0.1737]]);
172 | assert_eq!(to_vec0_round(&b, 4)?, 0.1762);
173 | Ok(())
174 | }
175 |
--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
1 | /*!
2 | Optimisers for use with the [candle](https://github.com/huggingface/candle) framework for lightweight machine learning.
3 | Apart from LBFGS, these all implement the [`candle_nn::optim::Optimizer`] trait from candle-nn.
4 |
5 | # Example
6 |
7 | Training an MNIST model using the Adam optimiser
8 |
9 | ```
10 | # use candle_core::{Result, Tensor};
11 | # use candle_core::{DType, D};
12 | # use candle_nn::{loss, ops, VarBuilder, VarMap, optim::Optimizer};
13 | # use candle_optimisers::{
14 | # adam::{Adam, ParamsAdam}
15 | # };
16 | #
17 | # pub trait Model: Sized {
18 | #     fn new(vs: VarBuilder) -> Result<Self>;
19 | #     fn forward(&self, xs: &Tensor) -> Result<Tensor>;
20 | # }
21 | #
22 | # pub fn training_loop<M: Model>(
23 | # m: candle_datasets::vision::Dataset,
24 | # varmap: &VarMap,
25 | # model: M,
26 | # ) -> anyhow::Result<()> {
27 | #     // check to see if cuda device available
28 | # let dev = candle_core::Device::cuda_if_available(0)?;
29 | # // get the input from the dataset and put on device
30 | # let train_images = m.train_images.to_device(&dev)?;
31 | # // get the training labels on the device
32 | # let train_labels = m.train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
33 | #
34 | #
35 | # // load the test images
36 | # let test_images = m.test_images.to_device(&dev)?;
37 | # // load the test labels
38 | # let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
39 | #
40 | // create the Adam optimiser
41 |
42 | // set the learning rate to 0.004 and use the default parameters for everything else
43 | let params = ParamsAdam {
44 | lr: 0.004,
45 | ..Default::default()
46 | };
47 | // create the optimiser by passing in the variable to be optimised and the parameters
48 | let mut optimiser = Adam::new(varmap.all_vars(), params)?;
49 |
50 | // loop for model optimisation
51 | for epoch in 0..100 {
52 | // run the model forwards
53 | // get log probabilities of results
54 | let logits = model.forward(&train_images)?;
55 | // softmax the log probabilities
56 | let log_sm = ops::log_softmax(&logits, D::Minus1)?;
57 | // get the loss
58 | let loss = loss::nll(&log_sm, &train_labels)?;
59 | // step the tensors by backpropagating the loss
60 | optimiser.backward_step(&loss)?;
61 |
62 | # // get the log probabilities of the test images
63 | # let test_logits = model.forward(&test_images)?;
64 | # // get the sum of the correct predictions
65 | # let sum_ok = test_logits
66 | # .argmax(D::Minus1)?
67 | # .eq(&test_labels)?
68 | # .to_dtype(DType::F32)?
69 | # .sum_all()?
70 | #         .to_scalar::<f32>()?;
71 | # // get the accuracy on the test set
72 | # #[allow(clippy::cast_precision_loss)]
73 | # let test_accuracy = sum_ok / test_labels.dims1()? as f32;
74 | # println!(
75 | # "{:4} train loss: {:8.5} test acc: {:5.2}%",
76 | # epoch + 1,
77 | #         loss.to_scalar::<f32>()?,
78 | # 100. * test_accuracy
79 | # );
80 | }
81 | Ok(())
82 | # }
83 | ```
84 | */
85 |
86 | use std::fmt::Debug;
87 |
88 | use candle_core::Result as CResult;
89 | use candle_core::Tensor;
90 | use candle_core::Var;
91 | pub mod adadelta;
92 | pub mod adagrad;
93 | pub mod adam;
94 | pub mod adamax;
95 | pub mod esgd;
96 | pub mod lbfgs;
97 | pub mod nadam;
98 | pub mod radam;
99 | pub mod rmsprop;
100 |
101 | /// Trait for optimisers to expose their parameters
102 | pub trait OptimParams: candle_nn::optim::Optimizer {
103 | /// get the current parameters of the Optimiser
104 | fn params(&self) -> &Self::Config;
105 | /// set the current parameters of the Optimiser
106 | fn set_params(&mut self, config: Self::Config);
107 | }
108 |
109 | /// Trait for Models: this is needed for optimisers that require the ability to calculate the loss
110 | /// such as LBFGS
111 | pub trait Model: Sized {
112 | /// get the loss of the model
113 |     fn loss(&self) -> CResult<Tensor>; //, xs: &Tensor, ys: &Tensor
114 | }
115 |
116 | /// trait for optimisers like LBFGS that need the ability to calculate the loss
117 | /// and its gradient
118 | pub trait LossOptimizer<'a, M: Model>: Sized {
119 | /// type of the optimiser configuration
120 | type Config: Sized;
121 | /// create a new optimiser from a Vec of variables, setup parameters and a model
122 |     fn new(vs: Vec<Var>, params: Self::Config, model: &'a M) -> CResult<Self>;
123 | /// take a step of the optimiser
124 |     fn backward_step(&mut self, loss: &Tensor) -> CResult<ModelOutcome>; //, xs: &Tensor, ys: &Tensor
125 | /// get the current learning rate
126 | fn learning_rate(&self) -> f64;
127 | /// set the learning rate
128 | fn set_learning_rate(&mut self, lr: f64);
129 |     /// get a vec of the variables being optimised
130 |     fn into_inner(self) -> Vec<Var>;
131 | /// create a new optimiser from a slice of variables, setup parameters and a model
132 |     fn from_slice(vars: &[&Var], config: Self::Config, model: &'a M) -> CResult<Self> {
133 | let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
134 | Self::new(vars, config, model)
135 | }
136 | }
137 |
138 | /// Outcomes of an optimiser step for methods such as LBFGS
139 | #[derive(Debug)]
140 | pub enum ModelOutcome {
141 | /// The model took a step and the loss decreased
142 | /// contains next loss and the number of func evals
143 | Stepped(Tensor, usize),
144 | /// The model has converged and the loss has not changed
145 | /// contains loss and the number of func evals
146 | Converged(Tensor, usize),
147 | }
148 |
149 | /// Method of weight decay to use
150 | #[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
151 | pub enum Decay {
152 | /// Weight decay regularisation to penalise large weights
153 | ///
154 | /// The gradient is transformed as
155 | /// $$ g_{t} \\gets g_{t} + \\lambda \\theta_{t-1}$$
156 | ///
157 |     /// This is equivalent to an L2 regularisation term in the loss adding $\\frac{\\lambda}{2}||\\theta||_{2}^{2}$ but avoids autodifferentiation
158 | /// of the L2 term
159 | WeightDecay(f64),
160 | /// Decoupled weight decay as described in [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101)
161 | ///
162 | /// This directly decays the weights as
163 | ///
164 | /// $$ \\theta_{t} \\gets (1 - \\eta \\lambda) \\theta_{t-1}$$
165 | ///
166 | /// This is equivalent to regularisation, only for SGD without momentum, but is different for adaptive gradient methods
167 | DecoupledWeightDecay(f64),
168 | }
169 |
170 | /// Type of momentum to use
171 | #[derive(Copy, Clone, Debug, PartialEq, PartialOrd)]
172 | pub enum Momentum {
173 | /// classical momentum
174 | Classical(f64),
175 | /// nesterov momentum
176 | Nesterov(f64),
177 | }
178 |
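179 | // Illustrative sketch, not part of the crate proper: it shows how the `Decay` enum and the
180 | // `OptimParams` trait defined above are intended to be combined with an optimiser's parameter
181 | // struct. Only items defined in this crate are used; the numeric values are arbitrary.
182 | #[cfg(test)]
183 | mod usage_sketch {
184 |     use candle_core::{Device, Result, Var};
185 |     use candle_nn::optim::Optimizer;
186 | 
187 |     use crate::adagrad::{Adagrad, ParamsAdaGrad};
188 |     use crate::{Decay, OptimParams};
189 | 
190 |     #[test]
191 |     fn decay_and_param_update() -> Result<()> {
192 |         let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
193 |         // request L2-style weight decay through the shared `Decay` enum
194 |         let params = ParamsAdaGrad {
195 |             lr: 0.01,
196 |             weight_decay: Some(Decay::WeightDecay(0.1)),
197 |             ..Default::default()
198 |         };
199 |         let mut optim = Adagrad::new(vec![w], params)?;
200 |         // `OptimParams` exposes the whole config, so a schedule can swap it mid-training
201 |         let mut new_params = optim.params().clone();
202 |         new_params.lr = 0.001;
203 |         optim.set_params(new_params);
204 |         assert_eq!(optim.params().weight_decay, Some(Decay::WeightDecay(0.1)));
205 |         Ok(())
206 |     }
207 | }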
--------------------------------------------------------------------------------
/benches/training.rs:
--------------------------------------------------------------------------------
1 | use candle_core::{DType, Result, Tensor, Var, D};
2 | use candle_nn::{loss, ops, Linear, Module, Optimizer, VarBuilder, VarMap};
3 |
4 | use candle_optimisers::{
5 | adadelta::{Adadelta, ParamsAdaDelta},
6 | adagrad::{Adagrad, ParamsAdaGrad},
7 | adam::{Adam, ParamsAdam},
8 | adamax::{Adamax, ParamsAdaMax},
9 | esgd::{ParamsSGD, SGD},
10 | lbfgs::{Lbfgs, LineSearch, ParamsLBFGS},
11 | nadam::{NAdam, ParamsNAdam},
12 | radam::{ParamsRAdam, RAdam},
13 | rmsprop::{ParamsRMSprop, RMSprop},
14 | LossOptimizer, Model,
15 | };
16 |
17 | pub trait Optim: Sized {
18 |     fn new(vars: Vec<Var>) -> Result<Self>;
19 | fn back_step(&mut self, loss: &Tensor) -> Result<()>;
20 | }
21 |
22 | impl Optim for Adadelta {
23 |     fn new(vars: Vec<Var>) -> Result<Self> {
24 |         <Adadelta as Optimizer>::new(vars, ParamsAdaDelta::default())
25 | }
26 |
27 | fn back_step(&mut self, loss: &Tensor) -> Result<()> {
28 | self.backward_step(loss)
29 | }
30 | }
31 |
32 | impl Optim for Adagrad {
33 |     fn new(vars: Vec<Var>) -> Result<Self> {
34 |         <Adagrad as Optimizer>::new(vars, ParamsAdaGrad::default())
35 | }
36 |
37 | fn back_step(&mut self, loss: &Tensor) -> Result<()> {
38 | self.backward_step(loss)
39 | }
40 | }
41 |
42 | impl Optim for Adamax {
43 |     fn new(vars: Vec<Var>) -> Result<Self> {
44 |         <Adamax as Optimizer>::new(vars, ParamsAdaMax::default())
45 | }
46 |
47 | fn back_step(&mut self, loss: &Tensor) -> Result<()> {
48 | self.backward_step(loss)
49 | }
50 | }
51 |
52 | impl Optim for SGD {
53 |     fn new(vars: Vec<Var>) -> Result<Self> {
54 |         <SGD as Optimizer>::new(vars, ParamsSGD::default())
55 | }
56 |
57 | fn back_step(&mut self, loss: &Tensor) -> Result<()> {
58 | self.backward_step(loss)
59 | }
60 | }
61 |
62 | impl Optim for NAdam {
63 |     fn new(vars: Vec<Var>) -> Result<Self> {
64 |         <NAdam as Optimizer>::new(vars, ParamsNAdam::default())
65 | }
66 |
67 | fn back_step(&mut self, loss: &Tensor) -> Result<()> {
68 | self.backward_step(loss)
69 | }
70 | }
71 |
72 | impl Optim for RAdam {
73 |     fn new(vars: Vec<Var>) -> Result<Self> {
74 |         <RAdam as Optimizer>::new(vars, ParamsRAdam::default())
75 | }
76 |
77 | fn back_step(&mut self, loss: &Tensor) -> Result<()> {
78 | self.backward_step(loss)
79 | }
80 | }
81 |
82 | impl Optim for RMSprop {
83 |     fn new(vars: Vec<Var>) -> Result<Self> {
84 |         <RMSprop as Optimizer>::new(vars, ParamsRMSprop::default())
85 | }
86 |
87 | fn back_step(&mut self, loss: &Tensor) -> Result<()> {
88 | self.backward_step(loss)
89 | }
90 | }
91 |
92 | impl Optim for Adam {
93 |     fn new(vars: Vec<Var>) -> Result<Self> {
94 |         <Adam as Optimizer>::new(vars, ParamsAdam::default())
95 | }
96 |
97 | fn back_step(&mut self, loss: &Tensor) -> Result<()> {
98 | self.backward_step(loss)
99 | }
100 | }
101 |
102 | const IMAGE_DIM: usize = 784;
103 | const LABELS: usize = 10;
104 |
105 | pub trait SimpleModel: Sized {
106 |     fn new(vs: VarBuilder, train_data: Tensor, train_labels: Tensor) -> Result<Self>;
107 |     fn forward(&self) -> Result<Tensor>;
108 | }
109 |
110 | pub struct Mlp {
111 | ln1: Linear,
112 | ln2: Linear,
113 | train_data: Tensor,
114 | train_labels: Tensor,
115 | }
116 |
117 | impl SimpleModel for Mlp {
118 |     fn new(vs: VarBuilder, train_data: Tensor, train_labels: Tensor) -> Result<Self> {
119 | let ln1 = candle_nn::linear(IMAGE_DIM, 100, vs.pp("ln1"))?;
120 | let ln2 = candle_nn::linear(100, LABELS, vs.pp("ln2"))?;
121 | Ok(Self {
122 | ln1,
123 | ln2,
124 | train_data,
125 | train_labels,
126 | })
127 | }
128 |
129 |     fn forward(&self) -> Result<Tensor> {
130 | let xs = self.ln1.forward(&self.train_data)?;
131 | let xs = xs.relu()?;
132 | self.ln2.forward(&xs)
133 | }
134 | }
135 |
136 | impl Model for Mlp {
137 |     fn loss(&self) -> Result<Tensor> {
138 | let logits = self.forward()?;
139 | // softmax the log probabilities
140 | let log_sm = ops::log_softmax(&logits, D::Minus1)?;
141 | // get the loss
142 | loss::nll(&log_sm, &self.train_labels)
143 | }
144 | }
145 |
146 | #[allow(clippy::module_name_repetitions)]
147 | pub fn run_training<M: SimpleModel + Model, O: Optim>(
148 | m: &candle_datasets::vision::Dataset,
149 | ) -> anyhow::Result<()> {
150 |     // check to see if cuda device available
151 | let dev = candle_core::Device::cuda_if_available(0)?;
152 |
153 | // get the labels from the dataset
154 | let train_labels = m.train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
155 | // get the input from the dataset and put on device
156 | let train_images = m.train_images.to_device(&dev)?;
157 |
158 | // create a new variable store
159 | let varmap = VarMap::new();
160 | // create a new variable builder
161 | let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
162 | // create model from variables
163 | let model = M::new(vs.clone(), train_images, train_labels)?;
164 |
165 | // create an optimiser
166 | let mut optimiser = O::new(varmap.all_vars())?;
167 | // load the test images
168 | let _test_images = m.test_images.to_device(&dev)?;
169 | // load the test labels
170 | let _test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
171 |
172 | for _epoch in 0..100 {
173 | // get the loss
174 | let loss = model.loss()?;
175 | // step the tensors by backpropagating the loss
176 | optimiser.back_step(&loss)?;
177 | }
178 | Ok(())
179 | }
180 |
181 | #[allow(clippy::module_name_repetitions)]
182 | pub fn run_lbfgs_training<M: SimpleModel + Model>(
183 | m: &candle_datasets::vision::Dataset,
184 | ) -> anyhow::Result<()> {
185 |     // check to see if cuda device available
186 | let dev = candle_core::Device::cuda_if_available(0)?;
187 |
188 | // get the labels from the dataset
189 | let train_labels = m.train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
190 | // get the input from the dataset and put on device
191 | let train_images = m.train_images.to_device(&dev)?;
192 |
193 | // create a new variable store
194 | let varmap = VarMap::new();
195 | // create a new variable builder
196 | let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
197 | // create model from variables
198 | let model = M::new(vs.clone(), train_images, train_labels)?;
199 |
200 | let params = ParamsLBFGS {
201 | lr: 1.,
202 | history_size: 4,
203 | line_search: Some(LineSearch::StrongWolfe(1e-4, 0.9, 1e-9)),
204 | ..Default::default()
205 | };
206 |
207 | let mut loss = model.loss()?;
208 |
209 | // create an optimiser
210 | let mut optimiser = Lbfgs::new(varmap.all_vars(), params, &model)?;
211 | // load the test images
212 | let _test_images = m.test_images.to_device(&dev)?;
213 | // load the test labels
214 | let _test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
215 |
216 | for _epoch in 0..100 {
217 | // get the loss
218 |
219 | // step the tensors by backpropagating the loss
220 | let res = optimiser.backward_step(&loss)?;
221 | match res {
222 | candle_optimisers::ModelOutcome::Converged(_, _) => break,
223 | candle_optimisers::ModelOutcome::Stepped(new_loss, _) => loss = new_loss,
224 | // _ => panic!("unexpected outcome"),
225 | }
226 | }
227 | Ok(())
228 | }
229 |
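230 | // Illustrative usage sketch, not part of the original benchmark code: it shows how the generic
231 | // drivers above are meant to be instantiated per optimiser. The MNIST loader
232 | // `candle_datasets::vision::mnist::load()` is assumed here; the actual benchmarks are driven
233 | // from `mnist_bench.rs`.
234 | #[allow(dead_code)]
235 | fn example_invocations() -> anyhow::Result<()> {
236 |     // load the MNIST dataset once and reuse it for each optimiser
237 |     let m = candle_datasets::vision::mnist::load()?;
238 |     // the same MLP trained with two of the first-order optimisers
239 |     run_training::<Mlp, SGD>(&m)?;
240 |     run_training::<Mlp, Adam>(&m)?;
241 |     // LBFGS goes through the loss-aware driver instead
242 |     run_lbfgs_training::<Mlp>(&m)
243 | }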
--------------------------------------------------------------------------------
/tests/adagrad_tests.rs:
--------------------------------------------------------------------------------
1 | use candle_core::test_utils::{to_vec0_round, to_vec2_round};
2 |
3 | use anyhow::Result;
4 | use candle_core::{Device, Tensor, Var};
5 | use candle_nn::{Linear, Module, Optimizer};
6 | use candle_optimisers::adagrad::{Adagrad, ParamsAdaGrad};
7 |
8 | /* The results of this test have been checked against the following PyTorch code.
9 | import torch
10 | from torch import optim
11 |
12 | w_gen = torch.tensor([[3., 1.]])
13 | b_gen = torch.tensor([-2.])
14 |
15 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
16 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
17 |
18 | m = torch.nn.Linear(2, 1)
19 | with torch.no_grad():
20 | m.weight.zero_()
21 | m.bias.zero_()
22 | optimiser = optim.Adagrad(m.parameters(), lr=0.004, weight_decay=0.00)
23 | # optimiser.zero_grad()
24 | for _step in range(1000):
25 | optimiser.zero_grad()
26 | ys = m(sample_xs)
27 | loss = ((ys - sample_ys)**2).sum()
28 | loss.backward()
29 | optimiser.step()
30 | # print("Optimizer state begin")
31 | # print(optimiser.state)
32 | # print("Optimizer state end")
33 | print(m.weight)
34 | print(m.bias)
35 | */
36 | #[test]
37 | fn adagrad_test() -> Result<()> {
38 | // Generate some linear data, y = 3.x1 + x2 - 2.
39 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
40 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
41 | let gen = Linear::new(w_gen, Some(b_gen));
42 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
43 | let sample_ys = gen.forward(&sample_xs)?;
44 |
45 | let params = ParamsAdaGrad {
46 | lr: 0.004,
47 | lr_decay: 0.0,
48 | weight_decay: None,
49 | initial_acc: 0.0,
50 | eps: 1e-10,
51 | };
52 | // Now use backprop to run a linear regression between samples and get the coefficients back.
53 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
54 | let b = Var::new(0f32, &Device::Cpu)?;
55 | let mut n_sgd = Adagrad::new(vec![w.clone(), b.clone()], params)?;
56 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
57 | for _step in 0..1000 {
58 | let ys = lin.forward(&sample_xs)?;
59 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
60 | n_sgd.backward_step(&loss)?;
61 | }
62 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.2424, 0.2341]]);
63 | assert_eq!(to_vec0_round(&b, 4)?, 0.2379);
64 | Ok(())
65 | }
66 |
67 | /* The results of this test have been checked against the following PyTorch code.
68 | import torch
69 | from torch import optim
70 |
71 | w_gen = torch.tensor([[3., 1.]])
72 | b_gen = torch.tensor([-2.])
73 |
74 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
75 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
76 |
77 | m = torch.nn.Linear(2, 1)
78 | with torch.no_grad():
79 | m.weight.zero_()
80 | m.bias.zero_()
81 | optimiser = optim.Adagrad(m.parameters(), lr=0.004, lr_decay=0.2)
82 | # optimiser.zero_grad()
83 | for _step in range(1000):
84 | optimiser.zero_grad()
85 | ys = m(sample_xs)
86 | loss = ((ys - sample_ys)**2).sum()
87 | loss.backward()
88 | optimiser.step()
89 | # print("Optimizer state begin")
90 | # print(optimiser.state)
91 | # print("Optimizer state end")
92 | print(m.weight)
93 | print(m.bias)
94 | */
95 | #[test]
96 | fn adagrad_lr_decay_test() -> Result<()> {
97 | // Generate some linear data, y = 3.x1 + x2 - 2.
98 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
99 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
100 | let gen = Linear::new(w_gen, Some(b_gen));
101 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
102 | let sample_ys = gen.forward(&sample_xs)?;
103 |
104 | let params = ParamsAdaGrad {
105 | lr: 0.004,
106 | lr_decay: 0.2,
107 | weight_decay: None,
108 | initial_acc: 0.0,
109 | eps: 1e-10,
110 | };
111 | // Now use backprop to run a linear regression between samples and get the coefficients back.
112 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
113 | let b = Var::new(0f32, &Device::Cpu)?;
114 | let mut n_sgd = Adagrad::new(vec![w.clone(), b.clone()], params)?;
115 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
116 | for _step in 0..1000 {
117 | let ys = lin.forward(&sample_xs)?;
118 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
119 | n_sgd.backward_step(&loss)?;
120 | }
121 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.0231, 0.0230]]);
122 | assert_eq!(to_vec0_round(&b, 4)?, 0.0230);
123 | Ok(())
124 | }
125 |
126 | /* The results of this test have been checked against the following PyTorch code.
127 | import torch
128 | from torch import optim
129 |
130 | w_gen = torch.tensor([[3., 1.]])
131 | b_gen = torch.tensor([-2.])
132 |
133 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
134 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
135 |
136 | m = torch.nn.Linear(2, 1)
137 | with torch.no_grad():
138 | m.weight.zero_()
139 | m.bias.zero_()
140 | optimiser = optim.Adagrad(m.parameters(), lr=0.004, weight_decay=0.2)
141 | # optimiser.zero_grad()
142 | for _step in range(1000):
143 | optimiser.zero_grad()
144 | ys = m(sample_xs)
145 | loss = ((ys - sample_ys)**2).sum()
146 | loss.backward()
147 | optimiser.step()
148 | # print("Optimizer state begin")
149 | # print(optimiser.state)
150 | # print("Optimizer state end")
151 | print(m.weight)
152 | print(m.bias)
153 | */
154 | #[test]
155 | fn adagrad_weight_decay_test() -> Result<()> {
156 | // Generate some linear data, y = 3.x1 + x2 - 2.
157 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
158 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
159 | let gen = Linear::new(w_gen, Some(b_gen));
160 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
161 | let sample_ys = gen.forward(&sample_xs)?;
162 |
163 | let params = ParamsAdaGrad {
164 | lr: 0.004,
165 | lr_decay: 0.0,
166 | weight_decay: Some(candle_optimisers::Decay::WeightDecay(0.2)),
167 | initial_acc: 0.0,
168 | eps: 1e-10,
169 | };
170 | // Now use backprop to run a linear regression between samples and get the coefficients back.
171 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
172 | let b = Var::new(0f32, &Device::Cpu)?;
173 | let mut n_sgd = Adagrad::new(vec![w.clone(), b.clone()], params)?;
174 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
175 | for _step in 0..1000 {
176 | let ys = lin.forward(&sample_xs)?;
177 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
178 | n_sgd.backward_step(&loss)?;
179 | }
180 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.2424, 0.2341]]);
181 | assert_eq!(to_vec0_round(&b, 4)?, 0.2378);
182 | Ok(())
183 | }
184 |
185 | //-------------------------------------------------------------------------
186 | // THIS IS NOT TESTED AGAINST PYTORCH
187 | // AS PYTORCH DOES NOT HAVE DECOUPLED WEIGHT DECAY FOR ADAGRAD
188 | // ------------------------------------------------------------------------
189 | #[test]
190 | fn adagrad_decoupled_weight_decay_test() -> Result<()> {
191 | // Generate some linear data, y = 3.x1 + x2 - 2.
192 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
193 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
194 | let gen = Linear::new(w_gen, Some(b_gen));
195 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
196 | let sample_ys = gen.forward(&sample_xs)?;
197 |
198 | let params = ParamsAdaGrad {
199 | lr: 0.004,
200 | lr_decay: 0.0,
201 | weight_decay: Some(candle_optimisers::Decay::DecoupledWeightDecay(0.2)),
202 | initial_acc: 0.0,
203 | eps: 1e-10,
204 | };
205 | // Now use backprop to run a linear regression between samples and get the coefficients back.
206 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
207 | let b = Var::new(0f32, &Device::Cpu)?;
208 | let mut n_sgd = Adagrad::new(vec![w.clone(), b.clone()], params)?;
209 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
210 | for _step in 0..1000 {
211 | let ys = lin.forward(&sample_xs)?;
212 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
213 | n_sgd.backward_step(&loss)?;
214 | }
215 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.1483, 0.1450]]);
216 | assert_eq!(to_vec0_round(&b, 4)?, 0.1465);
217 | Ok(())
218 | }
219 |
--------------------------------------------------------------------------------
/tests/lbfgs_tests.rs:
--------------------------------------------------------------------------------
1 | use anyhow::Result;
2 | use candle_core::test_utils::to_vec2_round;
3 | use candle_core::{DType, Device, Result as CResult, Tensor};
4 | use candle_optimisers::lbfgs::{GradConv, Lbfgs, LineSearch, ParamsLBFGS, StepConv};
5 | use candle_optimisers::{LossOptimizer, Model, ModelOutcome};
6 |
7 | /*
8 | These tests all use the 2D Rosenbrock function as a test function for the optimisers. This has minimum 0 at (1, 1)
9 | */
10 |
11 | #[derive(Debug, Clone)]
12 | pub struct RosenbrockModel {
13 | x_pos: candle_core::Var,
14 | y_pos: candle_core::Var,
15 | }
16 |
17 | impl Model for RosenbrockModel {
18 |     fn loss(&self) -> CResult<Tensor> {
19 | //, xs: &Tensor, ys: &Tensor
20 | self.forward()?.squeeze(1)?.squeeze(0)
21 | }
22 | }
23 |
24 | impl RosenbrockModel {
25 |     fn new() -> CResult<Self> {
26 | let x_pos = candle_core::Var::from_tensor(
27 | &(10. * Tensor::ones((1, 1), DType::F64, &Device::Cpu)?)?,
28 | )?;
29 | let y_pos = candle_core::Var::from_tensor(
30 | &(10. * Tensor::ones((1, 1), DType::F64, &Device::Cpu)?)?,
31 | )?;
32 | Ok(Self { x_pos, y_pos })
33 | }
34 |     fn vars(&self) -> Vec<candle_core::Var> {
35 | vec![self.x_pos.clone(), self.y_pos.clone()]
36 | }
37 |
38 |     fn forward(&self) -> CResult<Tensor> {
39 | //, xs: &Tensor
40 | (1. - self.x_pos.as_tensor())?.powf(2.)?
41 | + 100. * (self.y_pos.as_tensor() - self.x_pos.as_tensor().powf(2.)?)?.powf(2.)?
42 | }
43 | }
44 |
45 | #[test]
46 | fn lbfgs_test() -> Result<()> {
47 | let params = ParamsLBFGS {
48 | lr: 1.,
49 | ..Default::default()
50 | };
51 |
52 | let model = RosenbrockModel::new()?;
53 |
54 | let mut lbfgs = Lbfgs::new(model.vars(), params, &model)?;
55 | let mut loss = model.loss()?;
56 |
57 | for _step in 0..500 {
58 | // println!("\nstart step {}", step);
59 | // for v in model.vars() {
60 | // println!("{}", v);
61 | // }
62 | let res = lbfgs.backward_step(&loss)?; //&sample_xs, &sample_ys
63 | // println!("end step {}", _step);
64 | match res {
65 | ModelOutcome::Converged(_, _) => break,
66 | ModelOutcome::Stepped(new_loss, _) => loss = new_loss,
67 | // _ => panic!("unexpected outcome"),
68 | }
69 | }
70 |
71 | for v in model.vars() {
72 | // println!("{}", v);
73 | assert_eq!(to_vec2_round(&v.to_dtype(DType::F32)?, 4)?, &[[1.0000]]);
74 | }
75 |
76 | // println!("{:?}", lbfgs);
77 | // panic!("deliberate panic");
78 |
79 | Ok(())
80 | }
81 |
82 | #[test]
83 | fn lbfgs_test_strong_wolfe() -> Result<()> {
84 | let params = ParamsLBFGS {
85 | lr: 1.,
86 | line_search: Some(LineSearch::StrongWolfe(1e-4, 0.9, 1e-9)),
87 | ..Default::default()
88 | };
89 |
90 | let model = RosenbrockModel::new()?;
91 |
92 | let mut lbfgs = Lbfgs::new(model.vars(), params, &model)?;
93 | let mut loss = model.loss()?;
94 |
95 | for _step in 0..500 {
96 | // println!("\nstart step {}", step);
97 | // for v in model.vars() {
98 | // println!("{}", v);
99 | // }
100 | let res = lbfgs.backward_step(&loss)?; //&sample_xs, &sample_ys
101 | // println!("end step {}", _step);
102 | match res {
103 | ModelOutcome::Converged(_, _) => break,
104 | ModelOutcome::Stepped(new_loss, _) => loss = new_loss,
105 | // _ => panic!("unexpected outcome"),
106 | }
107 | }
108 |
109 | for v in model.vars() {
110 | // println!("{}", v);
111 | assert_eq!(to_vec2_round(&v.to_dtype(DType::F32)?, 4)?, &[[1.0000]]);
112 | }
113 |
114 | // println!("{:?}", lbfgs);
115 | // panic!("deliberate panic");
116 |
117 | Ok(())
118 | }
119 |
120 | #[test]
121 | fn lbfgs_rms_grad_test() -> Result<()> {
122 | let params = ParamsLBFGS {
123 | lr: 1.,
124 | grad_conv: GradConv::RMSForce(1e-6),
125 | ..Default::default()
126 | };
127 |
128 | let model = RosenbrockModel::new()?;
129 |
130 | let mut lbfgs = Lbfgs::new(model.vars(), params, &model)?;
131 | let mut loss = model.loss()?;
132 |
133 | for _step in 0..500 {
134 | // println!("\nstart step {}", step);
135 | // for v in model.vars() {
136 | // println!("{}", v);
137 | // }
138 | let res = lbfgs.backward_step(&loss)?; //&sample_xs, &sample_ys
139 | // println!("end step {}", _step);
140 | match res {
141 | ModelOutcome::Converged(_, _) => break,
142 | ModelOutcome::Stepped(new_loss, _) => loss = new_loss,
143 | // _ => panic!("unexpected outcome"),
144 | }
145 | }
146 |
147 | for v in model.vars() {
148 | // println!("{}", v);
149 | assert_eq!(to_vec2_round(&v.to_dtype(DType::F32)?, 4)?, &[[1.0000]]);
150 | }
151 |
152 | // println!("{:?}", lbfgs);
153 | // panic!("deliberate panic");
154 |
155 | Ok(())
156 | }
157 |
158 | #[test]
159 | fn lbfgs_rms_step_test() -> Result<()> {
160 | let params = ParamsLBFGS {
161 | lr: 1.,
162 | grad_conv: GradConv::RMSForce(0.),
163 | step_conv: StepConv::RMSStep(1e-7),
164 | ..Default::default()
165 | };
166 |
167 | let model = RosenbrockModel::new()?;
168 |
169 | let mut lbfgs = Lbfgs::new(model.vars(), params, &model)?;
170 | let mut loss = model.loss()?;
171 |
172 | for _step in 0..500 {
173 | // println!("\nstart step {}", step);
174 | // for v in model.vars() {
175 | // println!("{}", v);
176 | // }
177 | let res = lbfgs.backward_step(&loss)?; //&sample_xs, &sample_ys
178 | // println!("end step {}", _step);
179 | match res {
180 | ModelOutcome::Converged(_, _) => break,
181 | ModelOutcome::Stepped(new_loss, _) => loss = new_loss,
182 | // _ => panic!("unexpected outcome"),
183 | }
184 | }
185 |
186 | for v in model.vars() {
187 | // println!("{}", v);
188 | assert_eq!(to_vec2_round(&v.to_dtype(DType::F32)?, 4)?, &[[1.0000]]);
189 | }
190 |
191 | // println!("{:?}", lbfgs);
192 | // panic!("deliberate panic");
193 |
194 | Ok(())
195 | }
196 |
197 | #[test]
198 | fn lbfgs_test_strong_wolfe_weight_decay() -> Result<()> {
199 | let params = ParamsLBFGS {
200 | lr: 1.,
201 | line_search: Some(LineSearch::StrongWolfe(1e-4, 0.9, 1e-9)),
202 | weight_decay: Some(0.1),
203 | ..Default::default()
204 | };
205 |
206 | let model = RosenbrockModel::new()?;
207 |
208 | let mut lbfgs = Lbfgs::new(model.vars(), params, &model)?;
209 | let mut loss = model.loss()?;
210 |
211 | for _step in 0..500 {
212 | // println!("\nstart step {}", step);
213 | // for v in model.vars() {
214 | // println!("{}", v);
215 | // }
216 | let res = lbfgs.backward_step(&loss)?; //&sample_xs, &sample_ys
217 | // println!("end step {}", _step);
218 | match res {
219 | ModelOutcome::Converged(_, _) => break,
220 | ModelOutcome::Stepped(new_loss, _) => loss = new_loss,
221 | // _ => panic!("unexpected outcome"),
222 | }
223 | }
224 |
225 | let expected = [0.8861, 0.7849]; // this should be properly checked
226 | for (v, e) in model.vars().iter().zip(expected) {
227 | // println!("{}", v);
228 | assert_eq!(to_vec2_round(&v.to_dtype(DType::F32)?, 4)?, &[[e]]);
229 | }
230 |
231 | // println!("{:?}", lbfgs);
232 | // panic!("deliberate panic");
233 |
234 | Ok(())
235 | }
236 |
237 | #[test]
238 | fn lbfgs_test_weight_decay() -> Result<()> {
239 | let params = ParamsLBFGS {
240 | lr: 1.,
241 | weight_decay: Some(0.1),
242 | ..Default::default()
243 | };
244 |
245 | let model = RosenbrockModel::new()?;
246 |
247 | let mut lbfgs = Lbfgs::new(model.vars(), params, &model)?;
248 | let mut loss = model.loss()?;
249 |
250 | for _step in 0..500 {
251 | // println!("\nstart step {}", step);
252 | // for v in model.vars() {
253 | // println!("{}", v);
254 | // }
255 | let res = lbfgs.backward_step(&loss)?; //&sample_xs, &sample_ys
256 | // println!("end step {}", _step);
257 | match res {
258 | ModelOutcome::Converged(_, _) => break,
259 | ModelOutcome::Stepped(new_loss, _) => loss = new_loss,
260 | // _ => panic!("unexpected outcome"),
261 | }
262 | }
263 |
264 | let expected = [0.8861, 0.7849]; // this should be properly checked
265 | for (v, e) in model.vars().iter().zip(expected) {
266 | // println!("{}", v);
267 | assert_eq!(to_vec2_round(&v.to_dtype(DType::F32)?, 4)?, &[[e]]);
268 | }
269 |
270 | // println!("{:?}", lbfgs);
271 | // panic!("deliberate panic");
272 |
273 | Ok(())
274 | }
275 |
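276 | // Illustrative check, not part of the original suite: the Rosenbrock loss defined above should
277 | // be exactly zero at the analytic minimum (1, 1), which is the point the optimiser tests expect
278 | // to converge to.
279 | #[test]
280 | fn rosenbrock_minimum_is_zero() -> Result<()> {
281 |     let model = RosenbrockModel::new()?;
282 |     // move both coordinates from the starting point (10, 10) to the minimum (1, 1)
283 |     model.x_pos.set(&Tensor::ones((1, 1), DType::F64, &Device::Cpu)?)?;
284 |     model.y_pos.set(&Tensor::ones((1, 1), DType::F64, &Device::Cpu)?)?;
285 |     let loss = model.loss()?;
286 |     assert_eq!(loss.to_vec0::<f64>()?, 0.);
287 |     Ok(())
288 | }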
--------------------------------------------------------------------------------
/src/adagrad.rs:
--------------------------------------------------------------------------------
1 | /*!
2 | Adagrad optimiser
3 |
4 | Described in [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](https://jmlr.org/papers/v12/duchi11a.html)
5 |
6 | Pseudocode (including decoupling of weight decay):
7 |
8 | $$
9 | \\begin{aligned}
10 | &\\rule{110mm}{0.4pt} \\\\
11 | &\\textbf{input} : \\gamma \\text{ (lr)}, \\: \\theta_0 \\text{ (params)}, \\: f(\\theta)
12 | \\text{ (objective)}, \\: \\lambda \\text{ (weight decay)}, \\\\
13 | &\\hspace{12mm} \\tau \\text{ (initial accumulator value)}, \\: \\eta\\text{ (lr decay)}\\\\
14 | &\\textbf{initialize} : statesum_0 \\leftarrow 0 \\\\[-1.ex]
15 | &\\rule{110mm}{0.4pt} \\\\
16 | &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\
17 | &\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\
18 | &\\hspace{5mm} \\tilde{\\gamma} \\leftarrow \\gamma / (1 +(t-1) \\eta) \\\\
19 | &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\
20 | &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled} \\\\
21 | &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1} \\\\
22 | &\\hspace{10mm}\\textbf{else} \\\\
23 | &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\
24 | &\\hspace{5mm}statesum_t \\leftarrow statesum_{t-1} + g^2_t \\\\
25 | &\\hspace{5mm}\\theta_t \\leftarrow
26 | \\theta_{t-1}- \\tilde{\\gamma} \\frac{g_t}{\\sqrt{statesum_t}+\\epsilon} \\\\
27 | &\\rule{110mm}{0.4pt} \\\\[-1.ex]
28 | &\\bf{return} \\: \\theta_t \\\\[-1.ex]
29 | &\\rule{110mm}{0.4pt} \\\\[-1.ex]
30 | \\end{aligned}
31 | $$
32 |
33 |
34 |
35 | */
36 |
37 | use candle_core::{Result, Var};
38 | use candle_nn::optim::Optimizer;
39 |
40 | use crate::{Decay, OptimParams};
41 |
42 | /// Adagrad optimiser
43 | ///
44 | /// Described in [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](https://jmlr.org/papers/v12/duchi11a.html)
45 | #[derive(Debug)]
46 | pub struct Adagrad {
47 |     vars: Vec<VarAdaGrad>,
48 | params: ParamsAdaGrad,
49 | t: f64,
50 | }
51 |
52 | #[derive(Debug)]
53 | struct VarAdaGrad {
54 | theta: Var,
55 | sum: Var,
56 | }
57 |
58 | /// Parameters for the Adagrad optimiser
59 | #[derive(Clone, Debug, PartialEq, PartialOrd)]
60 | pub struct ParamsAdaGrad {
61 | /// Learning rate
62 | pub lr: f64,
63 | /// Learning rate decay
64 | pub lr_decay: f64,
65 | /// Initial value of accumulator
66 | pub initial_acc: f64,
67 | /// weight decay
68 |     pub weight_decay: Option<Decay>,
69 | /// term added to the denominator to improve numerical stability
70 | pub eps: f64,
71 | }
72 |
73 | impl Default for ParamsAdaGrad {
74 | fn default() -> Self {
75 | Self {
76 | lr: 0.01,
77 | lr_decay: 0.0,
78 | initial_acc: 0.0,
79 | weight_decay: None,
80 | eps: 1e-10,
81 | }
82 | }
83 | }
84 |
85 | impl Optimizer for Adagrad {
86 | type Config = ParamsAdaGrad;
87 |
88 |     fn new(vars: Vec<Var>, params: ParamsAdaGrad) -> Result<Self> {
89 | let vars = vars
90 | .into_iter()
91 | .filter(|var| var.dtype().is_float())
92 | .map(|var| {
93 | let dtype = var.dtype();
94 | let shape = var.shape();
95 | let device = var.device();
96 | let sum = Var::zeros(shape, dtype, device)?;
97 | Ok(VarAdaGrad { theta: var, sum })
98 | })
99 |             .collect::<Result<Vec<VarAdaGrad>>>()?;
100 | // // Err(SGDError::NoMomentum)?;
101 | // let mut params = params;
102 | // params.t = 0;
103 | Ok(Self {
104 | vars,
105 | t: 0.,
106 | params,
107 | })
108 | }
109 |
110 | fn learning_rate(&self) -> f64 {
111 | self.params.lr
112 | }
113 |
114 | fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> {
115 | if let Some(decay) = self.params.weight_decay {
116 | match decay {
117 | Decay::WeightDecay(decay) => {
118 | for var in &self.vars {
119 | let theta = &var.theta;
120 | let sum = &var.sum;
121 | if let Some(grad) = grads.get(theta) {
122 | let gamma_tilde =
123 | self.params.lr / self.t.mul_add(self.params.lr_decay, 1.);
124 | let grad = &(grad + (decay * theta.as_tensor())?)?;
125 | let current_sum = (sum.as_tensor() + grad.powf(2.)?)?;
126 | let change = (gamma_tilde
127 | * (grad.div(&(current_sum.powf(0.5)? + self.params.eps)?))?)?;
128 | sum.set(¤t_sum)?;
129 | theta.set(&theta.sub(&change)?)?;
130 | }
131 | }
132 | }
133 | Decay::DecoupledWeightDecay(decay) => {
134 | for var in &self.vars {
135 | let theta = &var.theta;
136 | let sum = &var.sum;
137 | if let Some(grad) = grads.get(theta) {
138 | // decoupled weight decay step
139 | theta
140 | .set(&(theta.as_tensor() * self.params.lr.mul_add(-decay, 1.))?)?;
141 | let gamma_tilde =
142 | self.params.lr / self.t.mul_add(self.params.lr_decay, 1.);
143 | let current_sum = (sum.as_tensor() + grad.powf(2.)?)?;
144 | let change = (gamma_tilde
145 | * (grad.div(&(current_sum.powf(0.5)? + self.params.eps)?))?)?;
146 | sum.set(¤t_sum)?;
147 | theta.set(&theta.sub(&change)?)?;
148 | }
149 | }
150 | }
151 | }
152 | } else {
153 | for var in &self.vars {
154 | let theta = &var.theta;
155 | let sum = &var.sum;
156 | if let Some(grad) = grads.get(theta) {
157 | let gamma_tilde = self.params.lr / self.t.mul_add(self.params.lr_decay, 1.);
158 | let current_sum = (sum.as_tensor() + grad.powf(2.)?)?;
159 | let change =
160 | (gamma_tilde * (grad.div(&(current_sum.powf(0.5)? + self.params.eps)?))?)?;
161 | sum.set(¤t_sum)?;
162 | theta.set(&theta.sub(&change)?)?;
163 | }
164 | }
165 | }
166 | self.t += 1.;
167 | Ok(())
168 | }
169 |
170 | fn set_learning_rate(&mut self, lr: f64) {
171 | self.params.lr = lr;
172 | }
173 | }
174 |
175 | impl OptimParams for Adagrad {
176 | fn params(&self) -> &Self::Config {
177 | &self.params
178 | }
179 |
180 | fn set_params(&mut self, config: Self::Config) {
181 | self.params = config;
182 | }
183 | }
184 |
185 | impl Adagrad {
186 | /// Return the vars being optimised
187 | #[must_use]
188 | pub fn into_inner(self) -> Vec {
189 | self.vars.into_iter().map(|v| v.theta).collect()
190 | }
191 |
192 | // pub fn push(&mut self, var: &Var) {
193 | // self.vars.push(var.clone());
194 | // }
195 | }
196 |
197 | #[cfg(test)]
198 | mod tests {
199 | // use candle_core::test_utils::{to_vec0_round, to_vec2_round};
200 |
201 | use anyhow::Result;
202 | use assert_approx_eq::assert_approx_eq;
203 | use candle_core::{Device, Var};
204 | use candle_nn::Optimizer;
205 |
206 | use super::*;
207 | #[test]
208 | fn lr_test() -> Result<()> {
209 | let params = ParamsAdaGrad {
210 | lr: 0.004,
211 | ..Default::default()
212 | };
213 | // Now use backprop to run a linear regression between samples and get the coefficients back.
214 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
215 | let b = Var::new(0f32, &Device::Cpu)?;
216 | let mut optim = Adagrad::new(vec![w.clone(), b.clone()], params)?;
217 | assert_approx_eq!(0.004, optim.learning_rate());
218 | optim.set_learning_rate(0.002);
219 | assert_approx_eq!(0.002, optim.learning_rate());
220 | Ok(())
221 | }
222 |
223 | #[test]
224 | fn into_inner_test() -> Result<()> {
225 | let params = ParamsAdaGrad::default();
226 | let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;
227 | let b = Var::new(-2f32, &Device::Cpu)?;
228 | let optim = Adagrad::new(vec![w.clone(), b.clone()], params)?;
229 | let inner = optim.into_inner();
230 |         assert_eq!(inner[0].as_tensor().to_vec2::<f32>()?, &[[3f32, 1.]]);
231 |         assert_approx_eq!(inner[1].as_tensor().to_vec0::<f32>()?, -2_f32);
232 | Ok(())
233 | }
234 |
235 | #[test]
236 | fn params_test() -> Result<()> {
237 | let params = ParamsAdaGrad {
238 | lr: 0.004,
239 | ..Default::default()
240 | };
241 | // Now use backprop to run a linear regression between samples and get the coefficients back.
242 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
243 | let b = Var::new(0f32, &Device::Cpu)?;
244 | let mut optim = Adagrad::new(vec![w.clone(), b.clone()], params.clone())?;
245 | assert_eq!(params, optim.params().clone());
246 | let new_params = ParamsAdaGrad {
247 | lr: 0.002,
248 | ..Default::default()
249 | };
250 | optim.set_params(new_params.clone());
251 | assert_eq!(new_params, optim.params().clone());
252 | Ok(())
253 | }
254 | }
255 |
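256 | // Illustrative single-step check, not part of the original tests: with a single scalar
257 | // parameter theta = 1, loss = theta^2, lr = 1 and no decay, the pseudocode above gives
258 | // g = 2, statesum = 4 and a step of lr * g / (sqrt(statesum) + eps) ~= 1, so theta should
259 | // land at (approximately) zero after one step.
260 | #[cfg(test)]
261 | mod single_step_sketch {
262 |     use anyhow::Result;
263 |     use candle_core::test_utils::to_vec0_round;
264 |     use candle_core::{Device, Var};
265 |     use candle_nn::Optimizer;
266 | 
267 |     use super::*;
268 | 
269 |     #[test]
270 |     fn one_adagrad_step() -> Result<()> {
271 |         let theta = Var::new(1f32, &Device::Cpu)?;
272 |         let params = ParamsAdaGrad {
273 |             lr: 1.0,
274 |             ..Default::default()
275 |         };
276 |         let mut optim = Adagrad::new(vec![theta.clone()], params)?;
277 |         // loss = theta^2, so the gradient at theta = 1 is 2
278 |         let loss = theta.as_tensor().sqr()?;
279 |         optim.backward_step(&loss)?;
280 |         assert_eq!(to_vec0_round(&theta, 4)?, 0.0);
281 |         Ok(())
282 |     }
283 | }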
--------------------------------------------------------------------------------
/src/adamax.rs:
--------------------------------------------------------------------------------
1 | /*!
2 | Adamax optimiser
3 |
4 | An Adam optimiser based on infinity norm, described in [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980)
5 |
6 | Pseudocode (including decoupling of weight decay):
7 |
8 | $$
9 | \\begin{aligned}
10 | &\\rule{110mm}{0.4pt} \\\\
11 | &\\textbf{input} : \\gamma \\text{ (lr)}, \\beta_1, \\beta_2
12 | \\text{ (betas)},\\theta_0 \\text{ (params)},f(\\theta) \\text{ (objective)},
13 | \\: \\lambda \\text{ (weight decay)}, \\\\
14 | &\\hspace{13mm} \\epsilon \\text{ (epsilon)} \\\\
15 | &\\textbf{initialize} : m_0 \\leftarrow 0 \\text{ ( first moment)},
16 | u_0 \\leftarrow 0 \\text{ ( infinity norm)} \\\\[-1.ex]
17 | &\\rule{110mm}{0.4pt} \\\\
18 | &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\
19 | &\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\
20 | &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\
21 | &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled} \\\\
22 | &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1} \\\\
23 | &\\hspace{10mm}\\textbf{else} \\\\
24 | &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\
25 | &\\hspace{5mm}m_t \\leftarrow \\beta_1 m_{t-1} + (1 - \\beta_1) g_t \\\\
26 | &\\hspace{5mm}u_t \\leftarrow \\mathrm{max}(\\beta_2 u_{t-1}, |g_{t}|+\\epsilon) \\\\
27 | &\\hspace{5mm}\\theta_t \\leftarrow \\theta_{t-1} - \\frac{\\gamma m_t}{(1-\\beta^t_1) u_t} \\\\
28 | &\\rule{110mm}{0.4pt} \\\\[-1.ex]
29 | &\\bf{return} \\: \\theta_t \\\\[-1.ex]
30 | &\\rule{110mm}{0.4pt} \\\\[-1.ex]
31 | \\end{aligned}
32 | $$
33 | */
34 |
35 | use candle_core::{Result, Var};
36 | use candle_nn::optim::Optimizer;
37 |
38 | use crate::{Decay, OptimParams};
39 |
40 | /// Adamax optimiser
41 | ///
42 | /// An Adam optimiser based on infinity norm, described in [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980)
43 |
44 | #[derive(Debug)]
45 | pub struct Adamax {
46 |     vars: Vec<VarAdaMax>,
47 | params: ParamsAdaMax,
48 | t: f64,
49 | }
50 |
51 | #[derive(Debug)]
52 | struct VarAdaMax {
53 | theta: Var,
54 | m: Var,
55 | u: Var,
56 | }
57 |
58 | /// Parameters for the Adamax optimiser
59 | #[derive(Clone, Debug, PartialEq, PartialOrd)]
60 | pub struct ParamsAdaMax {
61 | /// Learning rate
62 | pub lr: f64,
63 | /// Coefficient for moving average of first moment
64 | pub beta_1: f64,
65 | /// Coefficient for moving average of second moment
66 | pub beta_2: f64,
67 | /// Weight decay
68 |     pub weight_decay: Option<Decay>,
69 | /// Term added to denominator to improve numerical stability
70 | pub eps: f64,
71 | }
72 |
73 | impl Default for ParamsAdaMax {
74 | fn default() -> Self {
75 | Self {
76 | lr: 1.0,
77 | beta_1: 0.9,
78 | beta_2: 0.999,
79 | weight_decay: None,
80 | eps: 1e-8,
81 | }
82 | }
83 | }
84 |
85 | impl Optimizer for Adamax {
86 | type Config = ParamsAdaMax;
87 |
88 |     fn new(vars: Vec<Var>, params: ParamsAdaMax) -> Result<Self> {
89 | let vars = vars
90 | .into_iter()
91 | .filter(|var| var.dtype().is_float())
92 | .map(|var| {
93 | let dtype = var.dtype();
94 | let shape = var.shape();
95 | let device = var.device();
96 | let m = Var::zeros(shape, dtype, device)?;
97 | let u = Var::zeros(shape, dtype, device)?;
98 | Ok(VarAdaMax { theta: var, m, u })
99 | })
100 |             .collect::<Result<Vec<VarAdaMax>>>()?;
101 | // // Err(SGDError::NoMomentum)?;
102 | // let mut params = params;
103 | // params.t = 0;
104 | Ok(Self {
105 | vars,
106 | params,
107 | t: 1.,
108 | })
109 | }
110 |
111 | fn learning_rate(&self) -> f64 {
112 | self.params.lr
113 | }
114 |
115 | fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> {
116 | if let Some(decay) = self.params.weight_decay {
117 | match decay {
118 | Decay::WeightDecay(decay) => {
119 | for var in &self.vars {
120 | let theta = &var.theta;
121 | let m = &var.m;
122 | let u = &var.u;
123 | if let Some(grad) = grads.get(theta) {
124 | let grad = &(grad + (decay * theta.as_tensor())?)?;
125 | let m_next = ((self.params.beta_1 * m.as_tensor())?
126 | + (1. - self.params.beta_1) * grad)?;
127 | let u_next = (self.params.beta_2 * u.as_tensor())?
128 | .maximum(&(grad.abs()? + self.params.eps)?)?;
129 | let delta = (&m_next * self.params.lr)?
130 | .div(&(&u_next * (1. - self.params.beta_1.powf(self.t)))?)?;
131 | theta.set(&theta.sub(&(delta))?)?;
132 | m.set(&m_next)?;
133 | u.set(&u_next)?;
134 | }
135 | }
136 | }
137 | Decay::DecoupledWeightDecay(decay) => {
138 | for var in &self.vars {
139 | let theta = &var.theta;
140 | let m = &var.m;
141 | let u = &var.u;
142 | if let Some(grad) = grads.get(theta) {
143 | // decoupled weight decay step
144 | theta
145 | .set(&(theta.as_tensor() * self.params.lr.mul_add(-decay, 1.))?)?;
146 | let m_next = ((self.params.beta_1 * m.as_tensor())?
147 | + (1. - self.params.beta_1) * grad)?;
148 | let u_next = (self.params.beta_2 * u.as_tensor())?
149 | .maximum(&(grad.abs()? + self.params.eps)?)?;
150 | let delta = (&m_next * self.params.lr)?
151 | .div(&(&u_next * (1. - self.params.beta_1.powf(self.t)))?)?;
152 | theta.set(&theta.sub(&(delta))?)?;
153 | m.set(&m_next)?;
154 | u.set(&u_next)?;
155 | }
156 | }
157 | }
158 | }
159 | } else {
160 | for var in &self.vars {
161 | let theta = &var.theta;
162 | let m = &var.m;
163 | let u = &var.u;
164 | if let Some(grad) = grads.get(theta) {
165 | let m_next =
166 | ((self.params.beta_1 * m.as_tensor())? + (1. - self.params.beta_1) * grad)?;
167 | let u_next = (self.params.beta_2 * u.as_tensor())?
168 | .maximum(&(grad.abs()? + self.params.eps)?)?;
169 | let delta = (&m_next * self.params.lr)?
170 | .div(&(&u_next * (1. - self.params.beta_1.powf(self.t)))?)?;
171 | theta.set(&theta.sub(&(delta))?)?;
172 | m.set(&m_next)?;
173 | u.set(&u_next)?;
174 | }
175 | }
176 | }
177 | self.t += 1.;
178 | Ok(())
179 | }
180 |
181 | fn set_learning_rate(&mut self, lr: f64) {
182 | self.params.lr = lr;
183 | }
184 | }
185 |
186 | impl OptimParams for Adamax {
187 | fn params(&self) -> &Self::Config {
188 | &self.params
189 | }
190 |
191 | fn set_params(&mut self, config: Self::Config) {
192 | self.params = config;
193 | }
194 | }
195 |
196 | impl Adamax {
197 | /// Return the vars being optimised
198 | #[must_use]
199 | pub fn into_inner(self) -> Vec {
200 | self.vars.into_iter().map(|v| v.theta).collect()
201 | }
202 |
203 | // pub fn push(&mut self, var: &Var) {
204 | // self.vars.push(var.clone());
205 | // }
206 | }
207 |
208 | #[cfg(test)]
209 | mod tests {
210 | // use candle_core::test_utils::{to_vec0_round, to_vec2_round};
211 |
212 | use anyhow::Result;
213 | use assert_approx_eq::assert_approx_eq;
214 | use candle_core::{Device, Var};
215 | use candle_nn::Optimizer;
216 |
217 | use super::*;
218 | #[test]
219 | fn lr_test() -> Result<()> {
220 | let params = ParamsAdaMax {
221 | lr: 0.004,
222 | ..Default::default()
223 | };
224 | // Now use backprop to run a linear regression between samples and get the coefficients back.
225 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
226 | let b = Var::new(0f32, &Device::Cpu)?;
227 | let mut optim = Adamax::new(vec![w.clone(), b.clone()], params)?;
228 | assert_approx_eq!(0.004, optim.learning_rate());
229 | optim.set_learning_rate(0.002);
230 | assert_approx_eq!(0.002, optim.learning_rate());
231 | Ok(())
232 | }
233 |
234 | #[test]
235 | fn into_inner_test() -> Result<()> {
236 | let params = ParamsAdaMax::default();
237 | let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;
238 | let b = Var::new(-2f32, &Device::Cpu)?;
239 | let optim = Adamax::new(vec![w.clone(), b.clone()], params)?;
240 | let inner = optim.into_inner();
241 |         assert_eq!(inner[0].as_tensor().to_vec2::<f32>()?, &[[3f32, 1.]]);
242 |         assert_approx_eq!(inner[1].as_tensor().to_vec0::<f32>()?, -2_f32);
243 | Ok(())
244 | }
245 |
246 | #[test]
247 | fn params_test() -> Result<()> {
248 | let params = ParamsAdaMax {
249 | lr: 0.004,
250 | ..Default::default()
251 | };
252 | // Now use backprop to run a linear regression between samples and get the coefficients back.
253 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
254 | let b = Var::new(0f32, &Device::Cpu)?;
255 | let mut optim = Adamax::new(vec![w.clone(), b.clone()], params.clone())?;
256 | assert_eq!(params, optim.params().clone());
257 | let new_params = ParamsAdaMax {
258 | lr: 0.002,
259 | ..Default::default()
260 | };
261 | optim.set_params(new_params.clone());
262 | assert_eq!(new_params, optim.params().clone());
263 | Ok(())
264 | }
265 | }
266 |
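267 | // Illustrative single-step check, not part of the original tests: with a single scalar
268 | // parameter theta = 1, loss = theta^2 and the default parameters (lr = 1), the pseudocode
269 | // above gives m_1 = 0.2, u_1 = |g| = 2 and a step of lr * m_1 / ((1 - beta_1) * u_1) = 1,
270 | // so theta should land at (approximately) zero after one step.
271 | #[cfg(test)]
272 | mod single_step_sketch {
273 |     use anyhow::Result;
274 |     use candle_core::test_utils::to_vec0_round;
275 |     use candle_core::{Device, Var};
276 |     use candle_nn::Optimizer;
277 | 
278 |     use super::*;
279 | 
280 |     #[test]
281 |     fn one_adamax_step() -> Result<()> {
282 |         let theta = Var::new(1f32, &Device::Cpu)?;
283 |         let mut optim = Adamax::new(vec![theta.clone()], ParamsAdaMax::default())?;
284 |         // loss = theta^2, so the gradient at theta = 1 is 2
285 |         let loss = theta.as_tensor().sqr()?;
286 |         optim.backward_step(&loss)?;
287 |         assert_eq!(to_vec0_round(&theta, 4)?, 0.0);
288 |         Ok(())
289 |     }
290 | }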
--------------------------------------------------------------------------------
/src/adadelta.rs:
--------------------------------------------------------------------------------
1 | /*!
2 | Adadelta optimiser
3 |
4 | Described in [ADADELTA: An Adaptive Learning Rate Method](https://arxiv.org/abs/1212.5701)
5 |
6 | Pseudocode (including decoupling of weight decay):
7 | $$
8 | \\begin{aligned}
9 | &\\rule{110mm}{0.4pt} \\\\
10 | &\\textbf{input} : \\gamma \\text{ (lr)}, \\: \\theta_0 \\text{ (params)},
11 | \\: f(\\theta) \\text{ (objective)}, \\: \\rho \\text{ (decay)},
12 | \\: \\lambda \\text{ (weight decay)} \\\\
13 | &\\textbf{initialize} : v_0 \\leftarrow 0 \\: \\text{ (square avg)},
14 | \\: u_0 \\leftarrow 0 \\: \\text{ (accumulate variables)} \\\\[-1.ex]
15 | &\\rule{110mm}{0.4pt} \\\\
16 | &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\
17 | &\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\
18 | &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\
19 | &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled} \\\\
20 | &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1} \\\\
21 | &\\hspace{10mm}\\textbf{else} \\\\
22 | &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\
23 | &\\hspace{5mm} v_t \\leftarrow v_{t-1} \\rho + g^2_t (1 - \\rho) \\\\
24 | &\\hspace{5mm}\\Delta x_t \\leftarrow \\frac{\\sqrt{u_{t-1} +
25 | \\epsilon }}{ \\sqrt{v_t + \\epsilon} }g_t \\hspace{21mm} \\\\
26 | &\\hspace{5mm} u_t \\leftarrow u_{t-1} \\rho +
27 | \\Delta x^2_t (1 - \\rho) \\\\
28 | &\\hspace{5mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\Delta x_t \\\\
29 | &\\rule{110mm}{0.4pt} \\\\[-1.ex]
30 | &\\bf{return} \\: \\theta_t \\\\[-1.ex]
31 | &\\rule{110mm}{0.4pt} \\\\[-1.ex]
32 | \\end{aligned}
33 | $$
34 | */
35 |
36 | use candle_core::{Result, Var};
37 | use candle_nn::optim::Optimizer;
38 |
39 | use crate::{Decay, OptimParams};
40 |
41 | /// Adadelta optimiser
42 | ///
43 | /// Described in [ADADELTA: An Adaptive Learning Rate Method](https://arxiv.org/abs/1212.5701)
44 | #[derive(Debug)]
45 | pub struct Adadelta {
46 |     vars: Vec<VarAdaDelta>,
47 | params: ParamsAdaDelta,
48 | // avg_acc: HashMap,
49 | }
50 |
51 | #[derive(Debug)]
52 | struct VarAdaDelta {
53 | theta: Var,
54 | v: Var,
55 | u: Var,
56 | }
57 |
58 | /// Parameters for the Adadelta optimiser
59 | #[derive(Clone, Debug, PartialEq, PartialOrd)]
60 | pub struct ParamsAdaDelta {
61 | /// Learning rate
62 | pub lr: f64,
63 | /// Decay
64 | pub rho: f64,
65 | /// Term added to the denominator to improve numerical stability
66 | pub eps: f64,
67 | /// Weight decay
68 |     pub weight_decay: Option<Decay>,
69 | }
70 |
71 | impl Default for ParamsAdaDelta {
72 | fn default() -> Self {
73 | Self {
74 | lr: 1.0,
75 | rho: 0.9,
76 | weight_decay: None,
77 | eps: 1e-6,
78 | }
79 | }
80 | }
81 |
82 | impl Optimizer for Adadelta {
83 | type Config = ParamsAdaDelta;
84 |
85 |     fn new(vars: Vec<Var>, params: ParamsAdaDelta) -> Result<Self> {
86 | let vars = vars
87 | .into_iter()
88 | .filter(|var| var.dtype().is_float())
89 | .map(|var| {
90 | let dtype = var.dtype();
91 | let shape = var.shape();
92 | let device = var.device();
93 | let v = Var::zeros(shape, dtype, device)?;
94 | let u = Var::zeros(shape, dtype, device)?;
95 | Ok(VarAdaDelta { theta: var, v, u })
96 | })
97 |             .collect::<Result<Vec<VarAdaDelta>>>()?;
98 | // // Err(SGDError::NoMomentum)?;
99 | // let mut params = params;
100 | // params.t = 0;
101 | Ok(Self {
102 | vars,
103 | params,
104 | // avg_acc: HashMap::new(),
105 | })
106 | }
107 |
108 | fn learning_rate(&self) -> f64 {
109 | self.params.lr
110 | }
111 |
112 | fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> {
113 | if let Some(decay) = self.params.weight_decay {
114 | match decay {
115 | Decay::WeightDecay(decay) => {
116 | for var in &self.vars {
117 | let theta = &var.theta;
118 | let v = &var.v;
119 | let u = &var.u;
120 | if let Some(grad) = grads.get(theta) {
121 | let grad = &(grad + (decay * theta.as_tensor())?)?;
122 | let v_next = ((v.as_tensor() * self.params.rho)?
123 | + (1. - self.params.rho) * grad.powf(2.)?)?;
124 | let delta_x = (((u.as_tensor() + self.params.eps)?.powf(0.5)?)
125 | .div(&((&v_next + self.params.eps)?.powf(0.5)?))?
126 | * grad)?;
127 | let u_next = ((u.as_tensor() * self.params.rho)?
128 | + (1. - self.params.rho) * delta_x.powf(2.)?)?;
129 | theta.set(&theta.sub(&(delta_x * self.params.lr)?)?)?;
130 | v.set(&v_next)?;
131 | u.set(&u_next)?;
132 | }
133 | }
134 | }
135 | Decay::DecoupledWeightDecay(decay) => {
136 | for var in &self.vars {
137 | let theta = &var.theta;
138 | let v = &var.v;
139 | let u = &var.u;
140 | if let Some(grad) = grads.get(theta) {
141 | // decoupled weight decay step
142 | theta
143 | .set(&(theta.as_tensor() * self.params.lr.mul_add(-decay, 1.))?)?;
144 | let v_next = ((v.as_tensor() * self.params.rho)?
145 | + (1. - self.params.rho) * grad.powf(2.)?)?;
146 | let delta_x = (((u.as_tensor() + self.params.eps)?.powf(0.5)?)
147 | .div(&((&v_next + self.params.eps)?.powf(0.5)?))?
148 | * grad)?;
149 | let u_next = ((u.as_tensor() * self.params.rho)?
150 | + (1. - self.params.rho) * delta_x.powf(2.)?)?;
151 | theta.set(&theta.sub(&(delta_x * self.params.lr)?)?)?;
152 | v.set(&v_next)?;
153 | u.set(&u_next)?;
154 | }
155 | }
156 | }
157 | }
158 | } else {
159 | for var in &self.vars {
160 | let theta = &var.theta;
161 | let v = &var.v;
162 | let u = &var.u;
163 | if let Some(grad) = grads.get(theta) {
164 | let v_next = ((v.as_tensor() * self.params.rho)?
165 | + (1. - self.params.rho) * grad.powf(2.)?)?;
166 | let delta_x = (((u.as_tensor() + self.params.eps)?.powf(0.5)?)
167 | .div(&((&v_next + self.params.eps)?.powf(0.5)?))?
168 | * grad)?;
169 | let u_next = ((u.as_tensor() * self.params.rho)?
170 | + (1. - self.params.rho) * delta_x.powf(2.)?)?;
171 | theta.set(&theta.sub(&(delta_x * self.params.lr)?)?)?;
172 | v.set(&v_next)?;
173 | u.set(&u_next)?;
174 | }
175 | }
176 | }
177 |
178 | Ok(())
179 | }
180 |
181 | fn set_learning_rate(&mut self, lr: f64) {
182 | self.params.lr = lr;
183 | }
184 | }
185 |
186 | impl OptimParams for Adadelta {
187 | fn params(&self) -> &Self::Config {
188 | &self.params
189 | }
190 |
191 | fn set_params(&mut self, config: Self::Config) {
192 | self.params = config;
193 | }
194 | }
195 |
196 | impl Adadelta {
197 | /// Return the vars being optimised
198 | #[must_use]
199 | pub fn into_inner(self) -> Vec<Var> {
200 | self.vars.into_iter().map(|v| v.theta).collect()
201 | }
202 |
203 | // pub fn push(&mut self, var: &Var) {
204 | // self.vars.push(var.clone());
205 | // }
206 | }
207 |
208 | #[cfg(test)]
209 | mod tests {
210 | // use candle_core::test_utils::{to_vec0_round, to_vec2_round};
211 |
212 | use anyhow::Result;
213 | use assert_approx_eq::assert_approx_eq;
214 | use candle_core::{Device, Var};
215 | use candle_nn::Optimizer;
216 |
217 | use super::*;
218 | #[test]
219 | fn lr_test() -> Result<()> {
220 | let params = ParamsAdaDelta {
221 | lr: 0.004,
222 | ..Default::default()
223 | };
224 | // Now use backprop to run a linear regression between samples and get the coefficients back.
225 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
226 | let b = Var::new(0f32, &Device::Cpu)?;
227 | let mut optim = Adadelta::new(vec![w.clone(), b.clone()], params)?;
228 | assert_approx_eq!(0.004, optim.learning_rate());
229 | optim.set_learning_rate(0.002);
230 | assert_approx_eq!(0.002, optim.learning_rate());
231 | Ok(())
232 | }
233 |
234 | #[test]
235 | fn into_inner_test() -> Result<()> {
236 | let params = ParamsAdaDelta::default();
237 | let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;
238 | let b = Var::new(-2f32, &Device::Cpu)?;
239 | let optim = Adadelta::new(vec![w.clone(), b.clone()], params)?;
240 | let inner = optim.into_inner();
241 | assert_eq!(inner[0].as_tensor().to_vec2::<f32>()?, &[[3f32, 1.]]);
242 | assert_approx_eq!(inner[1].as_tensor().to_vec0::<f32>()?, -2_f32);
243 | Ok(())
244 | }
245 |
246 | #[test]
247 | fn params_test() -> Result<()> {
248 | let params = ParamsAdaDelta {
249 | lr: 0.004,
250 | ..Default::default()
251 | };
252 | // Now use backprop to run a linear regression between samples and get the coefficients back.
253 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
254 | let b = Var::new(0f32, &Device::Cpu)?;
255 | let mut optim = Adadelta::new(vec![w.clone(), b.clone()], params.clone())?;
256 | assert_eq!(params, optim.params().clone());
257 | let new_params = ParamsAdaDelta {
258 | lr: 0.002,
259 | ..Default::default()
260 | };
261 | optim.set_params(new_params.clone());
262 | assert_eq!(new_params, optim.params().clone());
263 | Ok(())
264 | }
265 | }
266 |
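As a quick orientation for the file above: `Adadelta` keeps one `(theta, v, u)` triple per optimised `Var`, and training code drives it through the `candle_nn::Optimizer` trait. The following is a minimal usage sketch, not taken from the crate itself; it follows the pattern of the test files in this repository and assumes the module is exposed as `candle_optimisers::adadelta`, mirroring the other optimiser modules.

```rust
use candle_core::{Device, Result, Tensor, Var};
use candle_nn::{Linear, Module, Optimizer};
use candle_optimisers::adadelta::{Adadelta, ParamsAdaDelta};

fn main() -> Result<()> {
    // Toy regression y = 3*x1 + x2 - 2, as in the crate's test files.
    let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
    let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
    let gen = Linear::new(w_gen, Some(b_gen));
    let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
    let sample_ys = gen.forward(&sample_xs)?;

    // Zero-initialised parameters to recover; the Linear shares their tensors.
    let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
    let b = Var::new(0f32, &Device::Cpu)?;
    let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));

    // Defaults: lr = 1.0, rho = 0.9, eps = 1e-6, no weight decay,
    // so the plain branch in the `else` arm of `step` runs.
    let mut optim = Adadelta::new(vec![w.clone(), b.clone()], ParamsAdaDelta::default())?;
    for _step in 0..100 {
        let loss = lin.forward(&sample_xs)?.sub(&sample_ys)?.sqr()?.sum_all()?;
        optim.backward_step(&loss)?; // computes grads, then calls `step`
    }
    println!("w = {:?}, b = {:?}", w.as_tensor(), b.as_tensor());
    Ok(())
}
```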
--------------------------------------------------------------------------------
/src/nadam.rs:
--------------------------------------------------------------------------------
1 | /*!
2 | NAdam optimiser: Adam with Nesterov momentum
3 |
4 | Described in [Incorporating Nesterov Momentum into Adam](https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ)
5 |
6 | Pseudocode (including decoupling of weight decay):
7 |
8 | $$
9 | \\begin{aligned}
10 | &\\rule{110mm}{0.4pt} \\\\
11 | &\\textbf{input} : \\gamma_t \\text{ (lr)}, \\: \\beta_1,\\beta_2 \\text{ (betas)},
12 | \\: \\theta_0 \\text{ (params)}, \\: f(\\theta) \\text{ (objective)} \\\\
13 | &\\hspace{12mm} \\: \\lambda \\text{ (weight decay)}, \\:\\psi \\text{ (momentum decay)} \\\\
14 | &\\textbf{initialize} : m_0 \\leftarrow 0 \\text{ ( first moment)},
15 | v_0 \\leftarrow 0 \\text{ ( second moment)} \\\\[-1.ex]
16 | &\\rule{110mm}{0.4pt} \\\\
17 | &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\
18 | &\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\
19 | &\\hspace{5mm} \\theta_t \\leftarrow \\theta_{t-1} \\\\
20 | &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\
21 | &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled} \\\\
22 | &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1} \\\\
23 | &\\hspace{10mm}\\textbf{else} \\\\
24 | &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\
25 | &\\hspace{5mm} \\mu_t \\leftarrow \\beta_1 \\big(1 - \\frac{1}{2} 0.96^{t \\psi} \\big) \\\\
26 | &\\hspace{5mm} \\mu_{t+1} \\leftarrow \\beta_1 \\big(1 - \\frac{1}{2} 0.96^{(t+1)\\psi}\\big)\\\\
27 | &\\hspace{5mm}m_t \\leftarrow \\beta_1 m_{t-1} + (1 - \\beta_1) g_t \\\\
28 | &\\hspace{5mm}v_t \\leftarrow \\beta_2 v_{t-1} + (1-\\beta_2) g^2_t \\\\
29 | &\\hspace{5mm}\\widehat{m_t} \\leftarrow \\mu_{t+1} m_t/(1-\\prod_{i=1}^{t+1}\\mu_i)\\\\[-1.ex]
30 | & \\hspace{11mm} + (1-\\mu_t) g_t /(1-\\prod_{i=1}^{t} \\mu_{i}) \\\\
31 | &\\hspace{5mm}\\widehat{v_t} \\leftarrow v_t/\\big(1-\\beta_2^t \\big) \\\\
32 | &\\hspace{5mm}\\theta_t \\leftarrow \\theta_t - \\gamma \\widehat{m_t}/
33 | \\big(\\sqrt{\\widehat{v_t}} + \\epsilon \\big) \\\\
34 | &\\rule{110mm}{0.4pt} \\\\[-1.ex]
35 | &\\bf{return} \\: \\theta_t \\\\[-1.ex]
36 | &\\rule{110mm}{0.4pt} \\\\[-1.ex]
37 | \\end{aligned}
38 | $$
39 | */
40 |
41 | use candle_core::{Result, Var};
42 | use candle_nn::optim::Optimizer;
43 |
44 | use crate::{Decay, OptimParams};
45 |
46 | /// Adam optimiser with Nesterov momentum
47 | ///
48 | /// Described in [Incorporating Nesterov Momentum into Adam](https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ)
49 | #[derive(Debug)]
50 | pub struct NAdam {
51 | vars: Vec<VarNAdam>,
52 | params: ParamsNAdam,
53 | mu_t: f64,
54 | mu_t2: f64,
55 | prod: f64,
56 | prod2: f64,
57 | t: f64,
58 | }
59 |
60 | #[derive(Debug)]
61 | struct VarNAdam {
62 | theta: Var,
63 | m: Var,
64 | v: Var,
65 | }
66 |
67 | /// Parameters for the NAdam optimiser
68 | #[derive(Clone, Debug, PartialEq, PartialOrd)]
69 | pub struct ParamsNAdam {
70 | /// Learning rate
71 | pub lr: f64,
72 | /// Coefficient for moving average of first moment
73 | pub beta_1: f64,
74 | /// Coefficient for moving average of second moment
75 | pub beta_2: f64,
76 | /// Term added to denominator to improve numerical stability
77 | pub eps: f64,
78 | /// Weight decay
79 | pub weight_decay: Option<Decay>,
80 | /// Momentum decay
81 | pub momentum_decay: f64,
82 | }
83 |
84 | impl Default for ParamsNAdam {
85 | fn default() -> Self {
86 | Self {
87 | lr: 0.002,
88 | beta_1: 0.9,
89 | beta_2: 0.999,
90 | eps: 1e-8,
91 | weight_decay: None,
92 | momentum_decay: 0.004,
93 | }
94 | }
95 | }
96 |
97 | impl Optimizer for NAdam {
98 | type Config = ParamsNAdam;
99 |
100 | fn new(vars: Vec<Var>, params: ParamsNAdam) -> Result<Self> {
101 | let vars = vars
102 | .into_iter()
103 | .filter(|var| var.dtype().is_float())
104 | .map(|var| {
105 | let dtype = var.dtype();
106 | let shape = var.shape();
107 | let device = var.device();
108 | let m = Var::zeros(shape, dtype, device)?;
109 | let v = Var::zeros(shape, dtype, device)?;
110 | Ok(VarNAdam { theta: var, m, v })
111 | })
112 | .collect::<Result<Vec<VarNAdam>>>()?;
113 | // // Err(SGDError::NoMomentum)?;
114 | // let mut params = params;
115 | // params.t = 0;
116 | let t = 1.;
117 | let mu_t2 = params.beta_1 * 0.5f64.mul_add(-(0.96_f64.powf(t * params.momentum_decay)), 1.);
118 | Ok(Self {
119 | vars,
120 | params,
121 | t: 1.,
122 | mu_t: 1.,
123 | mu_t2,
124 | prod: 1.,
125 | prod2: mu_t2,
126 | })
127 | }
128 |
129 | fn learning_rate(&self) -> f64 {
130 | self.params.lr
131 | }
132 |
133 | fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> {
134 | let mu_t = self.mu_t2;
135 | let mu_t2 = self.params.beta_1
136 | * 0.5f64.mul_add(
137 | -(0.96_f64.powf((self.t + 1.) * self.params.momentum_decay)),
138 | 1.,
139 | );
140 | let prod = self.prod2;
141 | let prod2 = prod * mu_t2;
142 | self.mu_t = mu_t;
143 | self.mu_t2 = mu_t2;
144 | self.prod = prod;
145 | self.prod2 = prod2;
146 | // println!("prod {}", prod);
147 |
148 | if let Some(decay) = self.params.weight_decay {
149 | match decay {
150 | Decay::WeightDecay(decay) => {
151 | for var in &self.vars {
152 | let theta = &var.theta;
153 | let m = &var.m;
154 | let v = &var.v;
155 | if let Some(grad) = grads.get(theta) {
156 | let grad = &(grad + (decay * theta.as_tensor())?)?;
157 | let m_next = ((self.params.beta_1 * m.as_tensor())?
158 | + ((1. - self.params.beta_1) * grad)?)?;
159 | let v_next = ((self.params.beta_2 * v.as_tensor())?
160 | + ((1. - self.params.beta_2) * grad.powf(2.)?)?)?;
161 | let m_hat = (((mu_t2 / (1. - prod2)) * &m_next)?
162 | + (((1. - mu_t) / (1. - prod)) * grad)?)?;
163 | let v_hat = (&v_next / (1. - self.params.beta_2.powf(self.t)))?;
164 | let delta = (m_hat * self.params.lr)?
165 | .div(&(v_hat.powf(0.5)? + self.params.eps)?)?;
166 | theta.set(&theta.sub(&(delta))?)?;
167 | m.set(&m_next)?;
168 | v.set(&v_next)?;
169 | }
170 | }
171 | }
172 | Decay::DecoupledWeightDecay(decay) => {
173 | for var in &self.vars {
174 | let theta = &var.theta;
175 | let m = &var.m;
176 | let v = &var.v;
177 | if let Some(grad) = grads.get(theta) {
178 | theta
179 | .set(&(theta.as_tensor() * self.params.lr.mul_add(-decay, 1.))?)?;
180 | let m_next = ((self.params.beta_1 * m.as_tensor())?
181 | + ((1. - self.params.beta_1) * grad)?)?;
182 | let v_next = ((self.params.beta_2 * v.as_tensor())?
183 | + ((1. - self.params.beta_2) * grad.powf(2.)?)?)?;
184 | let m_hat = (((mu_t2 / (1. - prod2)) * &m_next)?
185 | + (((1. - mu_t) / (1. - prod)) * grad)?)?;
186 | let v_hat = (&v_next / (1. - self.params.beta_2.powf(self.t)))?;
187 | let delta = (m_hat * self.params.lr)?
188 | .div(&(v_hat.powf(0.5)? + self.params.eps)?)?;
189 | theta.set(&theta.sub(&(delta))?)?;
190 | m.set(&m_next)?;
191 | v.set(&v_next)?;
192 | }
193 | }
194 | }
195 | }
196 | } else {
197 | for var in &self.vars {
198 | let theta = &var.theta;
199 | let m = &var.m;
200 | let v = &var.v;
201 | if let Some(grad) = grads.get(theta) {
202 | let m_next = ((self.params.beta_1 * m.as_tensor())?
203 | + ((1. - self.params.beta_1) * grad)?)?;
204 | let v_next = ((self.params.beta_2 * v.as_tensor())?
205 | + ((1. - self.params.beta_2) * grad.powf(2.)?)?)?;
206 | let m_hat = (((mu_t2 / (1. - prod2)) * &m_next)?
207 | + (((1. - mu_t) / (1. - prod)) * grad)?)?;
208 | let v_hat = (&v_next / (1. - self.params.beta_2.powf(self.t)))?;
209 | let delta =
210 | (m_hat * self.params.lr)?.div(&(v_hat.powf(0.5)? + self.params.eps)?)?;
211 | theta.set(&theta.sub(&(delta))?)?;
212 | m.set(&m_next)?;
213 | v.set(&v_next)?;
214 | }
215 | }
216 | }
217 |
218 | self.t += 1.;
219 | Ok(())
220 | }
221 |
222 | fn set_learning_rate(&mut self, lr: f64) {
223 | self.params.lr = lr;
224 | }
225 | }
226 |
227 | impl OptimParams for NAdam {
228 | fn params(&self) -> &Self::Config {
229 | &self.params
230 | }
231 |
232 | fn set_params(&mut self, config: Self::Config) {
233 | self.params = config;
234 | }
235 | }
236 |
237 | impl NAdam {
238 | /// Return the vars being optimised
239 | #[must_use]
240 | pub fn into_inner(self) -> Vec<Var> {
241 | self.vars.into_iter().map(|v| v.theta).collect()
242 | }
243 |
244 | // pub fn push(&mut self, var: &Var) {
245 | // self.vars.push(var.clone());
246 | // }
247 | }
248 |
249 | #[cfg(test)]
250 | mod tests {
251 | // use candle_core::test_utils::{to_vec0_round, to_vec2_round};
252 |
253 | use anyhow::Result;
254 | use assert_approx_eq::assert_approx_eq;
255 | use candle_core::{Device, Var};
256 | use candle_nn::Optimizer;
257 |
258 | use super::*;
259 | #[test]
260 | fn lr_test() -> Result<()> {
261 | let params = ParamsNAdam {
262 | lr: 0.004,
263 | ..Default::default()
264 | };
265 | // Now use backprop to run a linear regression between samples and get the coefficients back.
266 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
267 | let b = Var::new(0f32, &Device::Cpu)?;
268 | let mut optim = NAdam::new(vec![w.clone(), b.clone()], params)?;
269 | assert_approx_eq!(0.004, optim.learning_rate());
270 | optim.set_learning_rate(0.002);
271 | assert_approx_eq!(0.002, optim.learning_rate());
272 | Ok(())
273 | }
274 |
275 | #[test]
276 | fn into_inner_test() -> Result<()> {
277 | let params = ParamsNAdam::default();
278 | let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;
279 | let b = Var::new(-2f32, &Device::Cpu)?;
280 | let optim = NAdam::new(vec![w.clone(), b.clone()], params)?;
281 | let inner = optim.into_inner();
282 | assert_eq!(inner[0].as_tensor().to_vec2::<f32>()?, &[[3f32, 1.]]);
283 | assert_approx_eq!(inner[1].as_tensor().to_vec0::<f32>()?, -2_f32);
284 | Ok(())
285 | }
286 |
287 | #[test]
288 | fn params_test() -> Result<()> {
289 | let params = ParamsNAdam {
290 | lr: 0.004,
291 | ..Default::default()
292 | };
293 | // Now use backprop to run a linear regression between samples and get the coefficients back.
294 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
295 | let b = Var::new(0f32, &Device::Cpu)?;
296 | let mut optim = NAdam::new(vec![w.clone(), b.clone()], params.clone())?;
297 | assert_eq!(params, optim.params().clone());
298 | let new_params = ParamsNAdam {
299 | lr: 0.002,
300 | ..Default::default()
301 | };
302 | optim.set_params(new_params.clone());
303 | assert_eq!(new_params, optim.params().clone());
304 | Ok(())
305 | }
306 | }
307 |
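The `weight_decay: Option<Decay>` field is what selects between the three branches of `step` above: `None` (no decay), `Decay::WeightDecay` (decay folded into the gradient) and `Decay::DecoupledWeightDecay` (an NAdamW-style shrink of `theta` before the update). Below is a short configuration sketch, not from the crate itself; it assumes the module path `candle_optimisers::nadam`, while the `Decay` import path matches the one used in the test files.

```rust
use candle_core::{Device, Result, Var};
use candle_nn::Optimizer;
use candle_optimisers::{
    nadam::{NAdam, ParamsNAdam},
    Decay,
};

fn main() -> Result<()> {
    let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
    let b = Var::new(0f32, &Device::Cpu)?;

    // DecoupledWeightDecay(0.6): theta is first scaled by (1 - lr * 0.6),
    // then the usual NAdam update is applied.
    // Decay::WeightDecay(0.6) would instead add 0.6 * theta to the gradient.
    let params = ParamsNAdam {
        weight_decay: Some(Decay::DecoupledWeightDecay(0.6)),
        ..Default::default()
    };
    let mut optim = NAdam::new(vec![w.clone(), b.clone()], params)?;
    optim.set_learning_rate(1e-3); // lr can also be adjusted after construction
    Ok(())
}
```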
--------------------------------------------------------------------------------
/tests/adam_tests.rs:
--------------------------------------------------------------------------------
1 | use candle_core::test_utils::{to_vec0_round, to_vec2_round};
2 |
3 | use anyhow::Result;
4 | use candle_core::{Device, Tensor, Var};
5 | use candle_nn::{Linear, Module, Optimizer};
6 | use candle_optimisers::{
7 | adam::{Adam, ParamsAdam},
8 | Decay,
9 | };
10 |
11 | /* The results of this test have been checked against the following PyTorch code.
12 | import torch
13 | from torch import optim
14 |
15 | w_gen = torch.tensor([[3., 1.]])
16 | b_gen = torch.tensor([-2.])
17 |
18 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
19 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
20 |
21 | m = torch.nn.Linear(2, 1)
22 | with torch.no_grad():
23 | m.weight.zero_()
24 | m.bias.zero_()
25 | optimiser = optim.Adam(m.parameters())
26 | # optimiser.zero_grad()
27 | for _step in range(1000):
28 | optimiser.zero_grad()
29 | ys = m(sample_xs)
30 | loss = ((ys - sample_ys)**2).sum()
31 | loss.backward()
32 | optimiser.step()
33 | # print("Optimizer state begin")
34 | # print(optimiser.state)
35 | # print("Optimizer state end")
36 | print(m.weight)
37 | print(m.bias)
38 | */
39 | #[test]
40 | fn adam_test() -> Result<()> {
41 | // Generate some linear data, y = 3.x1 + x2 - 2.
42 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
43 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
44 | let gen = Linear::new(w_gen, Some(b_gen));
45 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
46 | let sample_ys = gen.forward(&sample_xs)?;
47 |
48 | let params = ParamsAdam::default();
49 | // Now use backprop to run a linear regression between samples and get the coefficients back.
50 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
51 | let b = Var::new(0f32, &Device::Cpu)?;
52 | let mut n_sgd = Adam::new(vec![w.clone(), b.clone()], params)?;
53 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
54 | for _step in 0..1000 {
55 | let ys = lin.forward(&sample_xs)?;
56 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
57 | n_sgd.backward_step(&loss)?;
58 | }
59 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.9000, 0.6967]]);
60 | assert_eq!(to_vec0_round(&b, 4)?, 0.7996);
61 | Ok(())
62 | }
63 |
64 | /* The results of this test have been checked against the following PyTorch code.
65 | import torch
66 | from torch import optim
67 |
68 | w_gen = torch.tensor([[3., 1.]])
69 | b_gen = torch.tensor([-2.])
70 |
71 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
72 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
73 |
74 | m = torch.nn.Linear(2, 1)
75 | with torch.no_grad():
76 | m.weight.zero_()
77 | m.bias.zero_()
78 | optimiser = optim.Adam(m.parameters(), weight_decay = 0.6)
79 | # optimiser.zero_grad()
80 | for _step in range(1000):
81 | optimiser.zero_grad()
82 | ys = m(sample_xs)
83 | loss = ((ys - sample_ys)**2).sum()
84 | loss.backward()
85 | optimiser.step()
86 | # print("Optimizer state begin")
87 | # print(optimiser.state)
88 | # print("Optimizer state end")
89 | print(m.weight)
90 | print(m.bias)
91 | */
92 | #[test]
93 | fn adam_weight_decay_test() -> Result<()> {
94 | // Generate some linear data, y = 3.x1 + x2 - 2.
95 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
96 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
97 | let gen = Linear::new(w_gen, Some(b_gen));
98 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
99 | let sample_ys = gen.forward(&sample_xs)?;
100 |
101 | let params = ParamsAdam {
102 | weight_decay: Some(Decay::WeightDecay(0.6)),
103 | ..Default::default()
104 | };
105 | // Now use backprop to run a linear regression between samples and get the coefficients back.
106 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
107 | let b = Var::new(0f32, &Device::Cpu)?;
108 | let mut n_sgd = Adam::new(vec![w.clone(), b.clone()], params)?;
109 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
110 | for _step in 0..1000 {
111 | let ys = lin.forward(&sample_xs)?;
112 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
113 | n_sgd.backward_step(&loss)?;
114 | }
115 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.8997, 0.6964]]);
116 | assert_eq!(to_vec0_round(&b, 4)?, 0.7975);
117 | Ok(())
118 | }
119 |
120 | /* The results of this test have been checked against the following PyTorch code.
121 | import torch
122 | from torch import optim
123 |
124 | w_gen = torch.tensor([[3., 1.]])
125 | b_gen = torch.tensor([-2.])
126 |
127 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
128 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
129 |
130 | m = torch.nn.Linear(2, 1)
131 | with torch.no_grad():
132 | m.weight.zero_()
133 | m.bias.zero_()
134 | optimiser = optim.AdamW(m.parameters(), weight_decay = 0.6)
135 | # optimiser.zero_grad()
136 | for _step in range(1000):
137 | optimiser.zero_grad()
138 | ys = m(sample_xs)
139 | loss = ((ys - sample_ys)**2).sum()
140 | loss.backward()
141 | optimiser.step()
142 | # print("Optimizer state begin")
143 | # print(optimiser.state)
144 | # print("Optimizer state end")
145 | print(m.weight)
146 | print(m.bias)
147 | */
148 | #[test]
149 | fn adamw_weight_decay_test() -> Result<()> {
150 | // Generate some linear data, y = 3.x1 + x2 - 2.
151 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
152 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
153 | let gen = Linear::new(w_gen, Some(b_gen));
154 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
155 | let sample_ys = gen.forward(&sample_xs)?;
156 |
157 | let params = ParamsAdam {
158 | weight_decay: Some(Decay::DecoupledWeightDecay(0.6)),
159 | // decoupled_weight_decay: true,
160 | ..Default::default()
161 | };
162 | // Now use backprop to run a linear regression between samples and get the coefficients back.
163 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
164 | let b = Var::new(0f32, &Device::Cpu)?;
165 | let mut n_sgd = Adam::new(vec![w.clone(), b.clone()], params)?;
166 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
167 | for _step in 0..1000 {
168 | let ys = lin.forward(&sample_xs)?;
169 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
170 | n_sgd.backward_step(&loss)?;
171 | }
172 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.6901, 0.5677]]);
173 | assert_eq!(to_vec0_round(&b, 4)?, 0.6287);
174 | Ok(())
175 | }
176 |
177 | /* The results of this test have been checked against the following PyTorch code.
178 | import torch
179 | from torch import optim
180 |
181 | w_gen = torch.tensor([[3., 1.]])
182 | b_gen = torch.tensor([-2.])
183 |
184 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
185 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
186 |
187 | m = torch.nn.Linear(2, 1)
188 | with torch.no_grad():
189 | m.weight.zero_()
190 | m.bias.zero_()
191 | optimiser = optim.Adam(m.parameters(), amsgrad=True)
192 | # optimiser.zero_grad()
193 | for _step in range(1000):
194 | optimiser.zero_grad()
195 | ys = m(sample_xs)
196 | loss = ((ys - sample_ys)**2).sum()
197 | loss.backward()
198 | optimiser.step()
199 | # print("Optimizer state begin")
200 | # print(optimiser.state)
201 | # print("Optimizer state end")
202 | print(m.weight)
203 | print(m.bias)
204 | */
205 | #[test]
206 | fn adam_amsgrad_test() -> Result<()> {
207 | // Generate some linear data, y = 3.x1 + x2 - 2.
208 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
209 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
210 | let gen = Linear::new(w_gen, Some(b_gen));
211 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
212 | let sample_ys = gen.forward(&sample_xs)?;
213 |
214 | let params = ParamsAdam {
215 | amsgrad: true,
216 | ..Default::default()
217 | };
218 | // Now use backprop to run a linear regression between samples and get the coefficients back.
219 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
220 | let b = Var::new(0f32, &Device::Cpu)?;
221 | let mut n_sgd = Adam::new(vec![w.clone(), b.clone()], params)?;
222 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
223 | for _step in 0..1000 {
224 | let ys = lin.forward(&sample_xs)?;
225 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
226 | n_sgd.backward_step(&loss)?;
227 | }
228 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.9001, 0.6904]]);
229 | assert_eq!(to_vec0_round(&b, 4)?, 0.7978);
230 | Ok(())
231 | }
232 |
233 | /* The results of this test have been checked against the following PyTorch code.
234 | import torch
235 | from torch import optim
236 |
237 | w_gen = torch.tensor([[3., 1.]])
238 | b_gen = torch.tensor([-2.])
239 |
240 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
241 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
242 |
243 | m = torch.nn.Linear(2, 1)
244 | with torch.no_grad():
245 | m.weight.zero_()
246 | m.bias.zero_()
247 | optimiser = optim.Adam(m.parameters(), amsgrad=True, weight_decay = 0.6)
248 | # optimiser.zero_grad()
249 | for _step in range(1000):
250 | optimiser.zero_grad()
251 | ys = m(sample_xs)
252 | loss = ((ys - sample_ys)**2).sum()
253 | loss.backward()
254 | optimiser.step()
255 | # print("Optimizer state begin")
256 | # print(optimiser.state)
257 | # print("Optimizer state end")
258 | print(m.weight)
259 | print(m.bias)
260 | */
261 | #[test]
262 | fn adam_amsgrad_decay_test() -> Result<()> {
263 | // Generate some linear data, y = 3.x1 + x2 - 2.
264 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
265 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
266 | let gen = Linear::new(w_gen, Some(b_gen));
267 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
268 | let sample_ys = gen.forward(&sample_xs)?;
269 |
270 | let params = ParamsAdam {
271 | amsgrad: true,
272 | weight_decay: Some(Decay::WeightDecay(0.6)),
273 | ..Default::default()
274 | };
275 | // Now use backprop to run a linear regression between samples and get the coefficients back.
276 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
277 | let b = Var::new(0f32, &Device::Cpu)?;
278 | let mut n_sgd = Adam::new(vec![w.clone(), b.clone()], params)?;
279 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
280 | for _step in 0..1000 {
281 | let ys = lin.forward(&sample_xs)?;
282 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
283 | n_sgd.backward_step(&loss)?;
284 | }
285 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.8998, 0.6901]]);
286 | assert_eq!(to_vec0_round(&b, 4)?, 0.7955);
287 | Ok(())
288 | }
289 |
290 | /* The results of this test have been checked against the following PyTorch code.
291 | import torch
292 | from torch import optim
293 |
294 | w_gen = torch.tensor([[3., 1.]])
295 | b_gen = torch.tensor([-2.])
296 |
297 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
298 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
299 |
300 | m = torch.nn.Linear(2, 1)
301 | with torch.no_grad():
302 | m.weight.zero_()
303 | m.bias.zero_()
304 | optimiser = optim.AdamW(m.parameters(), amsgrad=True, weight_decay = 0.6)
305 | # optimiser.zero_grad()
306 | for _step in range(1000):
307 | optimiser.zero_grad()
308 | ys = m(sample_xs)
309 | loss = ((ys - sample_ys)**2).sum()
310 | loss.backward()
311 | optimiser.step()
312 | # print("Optimizer state begin")
313 | # print(optimiser.state)
314 | # print("Optimizer state end")
315 | print(m.weight)
316 | print(m.bias)
317 | */
318 | #[test]
319 | fn adamw_amsgrad_decay_test() -> Result<()> {
320 | // Generate some linear data, y = 3.x1 + x2 - 2.
321 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
322 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
323 | let gen = Linear::new(w_gen, Some(b_gen));
324 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
325 | let sample_ys = gen.forward(&sample_xs)?;
326 |
327 | let params = ParamsAdam {
328 | weight_decay: Some(Decay::DecoupledWeightDecay(0.6)),
329 | amsgrad: true,
330 | // decoupled_weight_decay: true,
331 | ..Default::default()
332 | };
333 | // Now use backprop to run a linear regression between samples and get the coefficients back.
334 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
335 | let b = Var::new(0f32, &Device::Cpu)?;
336 | let mut n_sgd = Adam::new(vec![w.clone(), b.clone()], params)?;
337 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
338 | for _step in 0..1000 {
339 | let ys = lin.forward(&sample_xs)?;
340 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
341 | n_sgd.backward_step(&loss)?;
342 | }
343 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.6901, 0.5648]]);
344 | assert_eq!(to_vec0_round(&b, 4)?, 0.6287);
345 | Ok(())
346 | }
347 |
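Each test above repeats the same scaffolding: build the toy regression y = 3·x1 + x2 − 2, zero-initialise a `Linear`, run 1000 `backward_step` calls, then compare against the PyTorch reference printed in the preceding comment. A hedged sketch of how that shared setup could be factored into helpers is shown below; the helper names `toy_problem` and `run` are inventions for illustration, not part of the crate.

```rust
use anyhow::Result;
use candle_core::{Device, Tensor, Var};
use candle_nn::{Linear, Module, Optimizer};

// Hypothetical helper (not in the crate): builds the shared toy regression
// problem y = 3*x1 + x2 - 2 used throughout these tests.
fn toy_problem() -> Result<(Tensor, Tensor)> {
    let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
    let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
    let gen = Linear::new(w_gen, Some(b_gen));
    let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
    let sample_ys = gen.forward(&sample_xs)?;
    Ok((sample_xs, sample_ys))
}

// Hypothetical helper: runs `steps` iterations of the sum-of-squares regression
// with any optimiser already holding `w` and `b`, returning them for assertions.
fn run<O: Optimizer>(mut optim: O, w: Var, b: Var, steps: usize) -> Result<(Var, Var)> {
    let (sample_xs, sample_ys) = toy_problem()?;
    let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
    for _ in 0..steps {
        let loss = lin.forward(&sample_xs)?.sub(&sample_ys)?.sqr()?.sum_all()?;
        optim.backward_step(&loss)?;
    }
    Ok((w, b))
}
```

A test body would then reduce to constructing the optimiser, e.g. `let (w, b) = run(Adam::new(vec![w.clone(), b.clone()], params)?, w, b, 1000)?;`, followed by the rounded assertions.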
--------------------------------------------------------------------------------
/src/radam.rs:
--------------------------------------------------------------------------------
1 | /*!
2 | RAdam optimiser
3 |
4 | Described in [On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265)
5 |
6 | As decoupled weight decay is implemented, this can be used equivalent to the paper (which uses decoupled weight decay),
7 | or the PyTorch implementation (which does not)
8 |
9 | Pseudocode (including decoupling of weight decay):
10 |
11 | $$
12 | \\begin{aligned}
13 | &\\rule{110mm}{0.4pt} \\\\
14 | &\\textbf{input} : \\gamma \\text{ (lr)}, \\: \\beta_1, \\beta_2
15 | \\text{ (betas)}, \\: \\theta_0 \\text{ (params)}, \\:f(\\theta) \\text{ (objective)}, \\:
16 | \\lambda \\text{ (weightdecay)}, \\\\
17 | &\\hspace{13mm} \\epsilon \\text{ (epsilon)} \\\\
18 | &\\textbf{initialize} : m_0 \\leftarrow 0 \\text{ ( first moment)},
19 | v_0 \\leftarrow 0 \\text{ ( second moment)}, \\\\
20 | &\\hspace{18mm} \\rho_{\\infty} \\leftarrow 2/(1-\\beta_2) -1 \\\\[-1.ex]
21 | &\\rule{110mm}{0.4pt} \\\\
22 | &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\
23 | &\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\
24 | &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\
25 | &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled} \\\\
26 | &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1} \\\\
27 | &\\hspace{10mm}\\textbf{else} \\\\
28 | &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\
29 | &\\hspace{5mm}m_t \\leftarrow \\beta_1 m_{t-1} + (1 - \\beta_1) g_t \\\\
30 | &\\hspace{5mm}v_t \\leftarrow \\beta_2 v_{t-1} + (1-\\beta_2) g^2_t \\\\
31 | &\\hspace{5mm}\\widehat{m_t} \\leftarrow m_t/\\big(1-\\beta_1^t \\big) \\\\
32 | &\\hspace{5mm}\\rho_t \\leftarrow \\rho_{\\infty} -
33 | 2 t \\beta^t_2 /\\big(1-\\beta_2^t \\big) \\\\[0.1ex]
34 | &\\hspace{5mm}\\textbf{if} \\: \\rho_t > 5 \\\\
35 | &\\hspace{10mm} l_t \\leftarrow \\frac{\\sqrt{ (1-\\beta^t_2) }}{ \\sqrt{v_t} +\\epsilon } \\\\
36 | &\\hspace{10mm} r_t \\leftarrow
37 | \\sqrt{\\frac{(\\rho_t-4)(\\rho_t-2)\\rho_{\\infty}}{(\\rho_{\\infty}-4)(\\rho_{\\infty}-2) \\rho_t}} \\\\
38 | &\\hspace{10mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\widehat{m_t} r_t l_t \\\\
39 | &\\hspace{5mm}\\textbf{else} \\\\
40 | &\\hspace{10mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\widehat{m_t} \\\\
41 | &\\rule{110mm}{0.4pt} \\\\[-1.ex]
42 | &\\bf{return} \\: \\theta_t \\\\[-1.ex]
43 | &\\rule{110mm}{0.4pt} \\\\[-1.ex]
44 | \\end{aligned}
45 | $$
46 | */
47 |
48 | use candle_core::{Result, Var};
49 | use candle_nn::optim::Optimizer;
50 |
51 | use crate::{Decay, OptimParams};
52 |
53 | /// RAdam optimiser
54 | ///
55 | /// Described in [On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265)
56 |
57 | #[derive(Debug)]
58 | pub struct RAdam {
59 | vars: Vec<VarRAdam>,
60 | params: ParamsRAdam,
61 | rho_inf: f64,
62 | t: f64,
63 | }
64 |
65 | #[derive(Debug)]
66 | struct VarRAdam {
67 | theta: Var,
68 | m: Var,
69 | v: Var,
70 | }
71 |
72 | /// Parameters for the RAdam optimiser
73 | #[derive(Clone, Debug, PartialEq, PartialOrd)]
74 | pub struct ParamsRAdam {
75 | /// Learning rate
76 | pub lr: f64,
77 | /// Coefficient for moving average of first moment
78 | pub beta_1: f64,
79 | /// Coefficient for moving average of second moment
80 | pub beta_2: f64,
81 | /// Weight decay
82 | pub weight_decay: Option<Decay>,
83 | /// Term added to denominator to improve numerical stability
84 | pub eps: f64,
85 | }
86 |
87 | impl Default for ParamsRAdam {
88 | fn default() -> Self {
89 | Self {
90 | lr: 0.001,
91 | beta_1: 0.9,
92 | beta_2: 0.999,
93 | eps: 1e-8,
94 | weight_decay: None,
95 | }
96 | }
97 | }
98 |
99 | impl Optimizer for RAdam {
100 | type Config = ParamsRAdam;
101 |
102 | fn new(vars: Vec<Var>, params: ParamsRAdam) -> Result<Self> {
103 | let vars = vars
104 | .into_iter()
105 | .filter(|var| var.dtype().is_float())
106 | .map(|var| {
107 | let dtype = var.dtype();
108 | let shape = var.shape();
109 | let device = var.device();
110 | let m = Var::zeros(shape, dtype, device)?;
111 | let v = Var::zeros(shape, dtype, device)?;
112 | Ok(VarRAdam { theta: var, m, v })
113 | })
114 | .collect::<Result<Vec<VarRAdam>>>()?;
115 | // // Err(SGDError::NoMomentum)?;
116 | // let mut params = params;
117 | // params.t = 0;
118 | let rho_inf = 2. / (1. - params.beta_2) - 1.;
119 | Ok(Self {
120 | vars,
121 | params,
122 | rho_inf,
123 | t: 1.,
124 | })
125 | }
126 |
127 | fn learning_rate(&self) -> f64 {
128 | self.params.lr
129 | }
130 |
131 | fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> {
132 | // println!("prod {}", prod);
133 | let rho_t = self.rho_inf
134 | - 2. * self.t * self.params.beta_2.powf(self.t)
135 | / (1. - self.params.beta_2.powf(self.t));
136 |
137 | if let Some(wd) = self.params.weight_decay {
138 | match wd {
139 | Decay::WeightDecay(wd) => {
140 | for var in &self.vars {
141 | let theta = &var.theta;
142 | let m = &var.m;
143 | let v = &var.v;
144 | if let Some(grad) = grads.get(theta) {
145 | let grad = &(grad + (wd * theta.as_tensor())?)?;
146 | let m_next = ((self.params.beta_1 * m.as_tensor())?
147 | + ((1. - self.params.beta_1) * grad)?)?;
148 | let v_next = ((self.params.beta_2 * v.as_tensor())?
149 | + ((1. - self.params.beta_2) * grad.powf(2.)?)?)?;
150 | let m_hat = (&m_next / (1. - self.params.beta_1.powf(self.t)))?;
151 |
152 | let delta = if rho_t > 5. {
153 | let l = ((1. - self.params.beta_2.powf(self.t)).sqrt()
154 | / (&v_next.sqrt()? + self.params.eps)?)?;
155 | let r = ((rho_t - 4.) * (rho_t - 2.) * self.rho_inf
156 | / ((self.rho_inf - 4.) * (self.rho_inf - 2.) * rho_t))
157 | .sqrt();
158 | (self.params.lr * r * (l * m_hat)?)?
159 | } else {
160 | (self.params.lr * m_hat)?
161 | };
162 | theta.set(&theta.sub(&(delta))?)?;
163 | m.set(&m_next)?;
164 | v.set(&v_next)?;
165 | }
166 | }
167 | }
168 | Decay::DecoupledWeightDecay(decay) => {
169 | for var in &self.vars {
170 | let theta = &var.theta;
171 | let m = &var.m;
172 | let v = &var.v;
173 | if let Some(grad) = grads.get(theta) {
174 | // decoupled weight decay step
175 | theta
176 | .set(&(theta.as_tensor() * self.params.lr.mul_add(-decay, 1.))?)?;
177 | let m_next = ((self.params.beta_1 * m.as_tensor())?
178 | + ((1. - self.params.beta_1) * grad)?)?;
179 | let v_next = ((self.params.beta_2 * v.as_tensor())?
180 | + ((1. - self.params.beta_2) * grad.powf(2.)?)?)?;
181 | let m_hat = (&m_next / (1. - self.params.beta_1.powf(self.t)))?;
182 |
183 | let delta = if rho_t > 5. {
184 | let l = ((1. - self.params.beta_2.powf(self.t)).sqrt()
185 | / (&v_next.sqrt()? + self.params.eps)?)?;
186 | let r = ((rho_t - 4.) * (rho_t - 2.) * self.rho_inf
187 | / ((self.rho_inf - 4.) * (self.rho_inf - 2.) * rho_t))
188 | .sqrt();
189 | (self.params.lr * r * (l * m_hat)?)?
190 | } else {
191 | (self.params.lr * m_hat)?
192 | };
193 | theta.set(&theta.sub(&(delta))?)?;
194 | m.set(&m_next)?;
195 | v.set(&v_next)?;
196 | }
197 | }
198 | }
199 | }
200 | } else {
201 | for var in &self.vars {
202 | let theta = &var.theta;
203 | let m = &var.m;
204 | let v = &var.v;
205 | if let Some(grad) = grads.get(theta) {
206 | let m_next = ((self.params.beta_1 * m.as_tensor())?
207 | + ((1. - self.params.beta_1) * grad)?)?;
208 | let v_next = ((self.params.beta_2 * v.as_tensor())?
209 | + ((1. - self.params.beta_2) * grad.powf(2.)?)?)?;
210 | let m_hat = (&m_next / (1. - self.params.beta_1.powf(self.t)))?;
211 |
212 | let delta = if rho_t > 5. {
213 | let l = ((1. - self.params.beta_2.powf(self.t)).sqrt()
214 | / (&v_next.sqrt()? + self.params.eps)?)?;
215 | let r = ((rho_t - 4.) * (rho_t - 2.) * self.rho_inf
216 | / ((self.rho_inf - 4.) * (self.rho_inf - 2.) * rho_t))
217 | .sqrt();
218 | (self.params.lr * r * (l * m_hat)?)?
219 | } else {
220 | (self.params.lr * m_hat)?
221 | };
222 | theta.set(&theta.sub(&(delta))?)?;
223 | m.set(&m_next)?;
224 | v.set(&v_next)?;
225 | }
226 | }
227 | }
228 |
229 | self.t += 1.;
230 | Ok(())
231 | }
232 |
233 | fn set_learning_rate(&mut self, lr: f64) {
234 | self.params.lr = lr;
235 | }
236 | }
237 |
238 | impl OptimParams for RAdam {
239 | fn params(&self) -> &Self::Config {
240 | &self.params
241 | }
242 |
243 | fn set_params(&mut self, config: Self::Config) {
244 | self.params = config;
245 | }
246 | }
247 |
248 | impl RAdam {
249 | /// Return the vars being optimised
250 | #[must_use]
251 | pub fn into_inner(self) -> Vec<Var> {
252 | self.vars.into_iter().map(|v| v.theta).collect()
253 | }
254 |
255 | // pub fn push(&mut self, var: &Var) {
256 | // self.vars.push(var.clone());
257 | // }
258 | }
259 |
260 | #[cfg(test)]
261 | mod tests {
262 | // use candle_core::test_utils::{to_vec0_round, to_vec2_round};
263 |
264 | use anyhow::Result;
265 | use assert_approx_eq::assert_approx_eq;
266 | use candle_core::{Device, Var};
267 | use candle_nn::Optimizer;
268 |
269 | use super::*;
270 | #[test]
271 | fn lr_test() -> Result<()> {
272 | let params = ParamsRAdam {
273 | lr: 0.004,
274 | ..Default::default()
275 | };
276 | // Now use backprop to run a linear regression between samples and get the coefficients back.
277 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
278 | let b = Var::new(0f32, &Device::Cpu)?;
279 | let mut optim = RAdam::new(vec![w.clone(), b.clone()], params)?;
280 | assert_approx_eq!(0.004, optim.learning_rate());
281 | optim.set_learning_rate(0.002);
282 | assert_approx_eq!(0.002, optim.learning_rate());
283 | Ok(())
284 | }
285 |
286 | #[test]
287 | fn into_inner_test() -> Result<()> {
288 | let params = ParamsRAdam::default();
289 | let w = Var::new(&[[3f32, 1.]], &Device::Cpu)?;
290 | let b = Var::new(-2f32, &Device::Cpu)?;
291 | let optim = RAdam::new(vec![w.clone(), b.clone()], params)?;
292 | let inner = optim.into_inner();
293 | assert_eq!(inner[0].as_tensor().to_vec2::<f32>()?, &[[3f32, 1.]]);
294 | assert_approx_eq!(inner[1].as_tensor().to_vec0::<f32>()?, -2_f32);
295 | Ok(())
296 | }
297 |
298 | #[test]
299 | fn params_test() -> Result<()> {
300 | let params = ParamsRAdam {
301 | lr: 0.004,
302 | ..Default::default()
303 | };
304 | // Now use backprop to run a linear regression between samples and get the coefficients back.
305 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
306 | let b = Var::new(0f32, &Device::Cpu)?;
307 | let mut optim = RAdam::new(vec![w.clone(), b.clone()], params.clone())?;
308 | assert_eq!(params, optim.params().clone());
309 | let new_params = ParamsRAdam {
310 | lr: 0.002,
311 | ..Default::default()
312 | };
313 | optim.set_params(new_params.clone());
314 | assert_eq!(new_params, optim.params().clone());
315 | Ok(())
316 | }
317 | }
318 |
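To make the rectification threshold in the pseudocode concrete: with the default `beta_2 = 0.999`, `rho_inf = 2/(1 - beta_2) - 1 = 1999`, and `rho_t` grows by roughly one per step, so the variance-rectified branch (`rho_t > 5`) is first taken around the sixth step; before that the update is simply `lr * m_hat`. The standalone sketch below illustrates that schedule with plain `f64` arithmetic and does not use the crate's types.

```rust
// Illustration of the rho_t schedule from the RAdam pseudocode above.
fn main() {
    let beta_2: f64 = 0.999; // ParamsRAdam::default()
    let rho_inf = 2. / (1. - beta_2) - 1.; // = 1999 for beta_2 = 0.999

    for t in 1..=8_u32 {
        let t = f64::from(t);
        let rho_t = rho_inf - 2. * t * beta_2.powf(t) / (1. - beta_2.powf(t));
        // rho_t is approximately t for small t, so the rectified branch is
        // first taken at around t = 6; earlier steps use theta <- theta - lr * m_hat.
        println!("t = {t}: rho_t = {rho_t:.3}, rectified = {}", rho_t > 5.);
    }
}
```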
--------------------------------------------------------------------------------
/tests/rmsprop-tests.rs:
--------------------------------------------------------------------------------
1 | use candle_core::test_utils::{to_vec0_round, to_vec2_round};
2 |
3 | use anyhow::Result;
4 | use candle_core::{Device, Tensor, Var};
5 | use candle_nn::{Linear, Module, Optimizer};
6 | use candle_optimisers::rmsprop::{ParamsRMSprop, RMSprop};
7 |
8 | /* The results of this test have been checked against the following PyTorch code.
9 | import torch
10 | from torch import optim
11 |
12 | w_gen = torch.tensor([[3., 1.]])
13 | b_gen = torch.tensor([-2.])
14 |
15 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
16 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
17 |
18 | m = torch.nn.Linear(2, 1)
19 | with torch.no_grad():
20 | m.weight.zero_()
21 | m.bias.zero_()
22 | optimiser = optim.RMSprop(m.parameters())
23 | # optimiser.zero_grad()
24 | for _step in range(100):
25 | optimiser.zero_grad()
26 | ys = m(sample_xs)
27 | loss = ((ys - sample_ys)**2).sum()
28 | loss.backward()
29 | optimiser.step()
30 | # print("Optimizer state begin")
31 | # print(optimiser.state)
32 | # print("Optimizer state end")
33 | print(m.weight)
34 | print(m.bias)
35 | */
36 | #[test]
37 | fn rmsprop_test() -> Result<()> {
38 | // Generate some linear data, y = 3.x1 + x2 - 2.
39 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
40 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
41 | let gen = Linear::new(w_gen, Some(b_gen));
42 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
43 | let sample_ys = gen.forward(&sample_xs)?;
44 |
45 | let params = ParamsRMSprop::default();
46 | // Now use backprop to run a linear regression between samples and get the coefficients back.
47 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
48 | let b = Var::new(0f32, &Device::Cpu)?;
49 | let mut n_sgd = RMSprop::new(vec![w.clone(), b.clone()], params)?;
50 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
51 | for _step in 0..100 {
52 | let ys = lin.forward(&sample_xs)?;
53 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
54 | n_sgd.backward_step(&loss)?;
55 | }
56 | assert_eq!(to_vec2_round(&w, 4)?, &[[1.6650, 0.7867]]);
57 | assert_eq!(to_vec0_round(&b, 4)?, 1.3012);
58 | Ok(())
59 | }
60 |
61 | /* The results of this test have been checked against the following PyTorch code.
62 | import torch
63 | from torch import optim
64 |
65 | w_gen = torch.tensor([[3., 1.]])
66 | b_gen = torch.tensor([-2.])
67 |
68 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
69 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
70 |
71 | m = torch.nn.Linear(2, 1)
72 | with torch.no_grad():
73 | m.weight.zero_()
74 | m.bias.zero_()
75 | optimiser = optim.RMSprop(m.parameters(), weight_decay = 0.4)
76 | # optimiser.zero_grad()
77 | for _step in range(100):
78 | optimiser.zero_grad()
79 | ys = m(sample_xs)
80 | loss = ((ys - sample_ys)**2).sum()
81 | loss.backward()
82 | optimiser.step()
83 | # print("Optimizer state begin")
84 | # print(optimiser.state)
85 | # print("Optimizer state end")
86 | print(m.weight)
87 | print(m.bias)
88 | */
89 | #[test]
90 | fn rmsprop_weight_decay_test() -> Result<()> {
91 | // Generate some linear data, y = 3.x1 + x2 - 2.
92 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
93 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
94 | let gen = Linear::new(w_gen, Some(b_gen));
95 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
96 | let sample_ys = gen.forward(&sample_xs)?;
97 |
98 | let params = ParamsRMSprop {
99 | weight_decay: Some(0.4),
100 | ..Default::default()
101 | };
102 | // Now use backprop to run a linear regression between samples and get the coefficients back.
103 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
104 | let b = Var::new(0f32, &Device::Cpu)?;
105 | let mut n_sgd = RMSprop::new(vec![w.clone(), b.clone()], params)?;
106 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
107 | for _step in 0..100 {
108 | let ys = lin.forward(&sample_xs)?;
109 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
110 | n_sgd.backward_step(&loss)?;
111 | }
112 | assert_eq!(to_vec2_round(&w, 4)?, &[[1.6643, 0.7867]]);
113 | assert_eq!(to_vec0_round(&b, 4)?, 1.2926);
114 | Ok(())
115 | }
116 |
117 | /* The results of this test have been checked against the following PyTorch code.
118 | import torch
119 | from torch import optim
120 |
121 | w_gen = torch.tensor([[3., 1.]])
122 | b_gen = torch.tensor([-2.])
123 |
124 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
125 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
126 |
127 | m = torch.nn.Linear(2, 1)
128 | with torch.no_grad():
129 | m.weight.zero_()
130 | m.bias.zero_()
131 | optimiser = optim.RMSprop(m.parameters(), centered = True)
132 | # optimiser.zero_grad()
133 | for _step in range(100):
134 | optimiser.zero_grad()
135 | ys = m(sample_xs)
136 | loss = ((ys - sample_ys)**2).sum()
137 | loss.backward()
138 | optimiser.step()
139 | # print("Optimizer state begin")
140 | # print(optimiser.state)
141 | # print("Optimizer state end")
142 | print(m.weight)
143 | print(m.bias)
144 | */
145 | #[test]
146 | fn rmsprop_centered_test() -> Result<()> {
147 | // Generate some linear data, y = 3.x1 + x2 - 2.
148 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
149 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
150 | let gen = Linear::new(w_gen, Some(b_gen));
151 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
152 | let sample_ys = gen.forward(&sample_xs)?;
153 |
154 | let params = ParamsRMSprop {
155 | centered: true,
156 | ..Default::default()
157 | };
158 | // Now use backprop to run a linear regression between samples and get the coefficients back.
159 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
160 | let b = Var::new(0f32, &Device::Cpu)?;
161 | let mut n_sgd = RMSprop::new(vec![w.clone(), b.clone()], params)?;
162 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
163 | for _step in 0..100 {
164 | let ys = lin.forward(&sample_xs)?;
165 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
166 | n_sgd.backward_step(&loss)?;
167 | }
168 | assert_eq!(to_vec2_round(&w, 4)?, &[[1.8892, 0.7617]]);
169 | assert_eq!(to_vec0_round(&b, 4)?, 1.3688);
170 | Ok(())
171 | }
172 |
173 | /* The results of this test have been checked against the following PyTorch code.
174 | import torch
175 | from torch import optim
176 |
177 | w_gen = torch.tensor([[3., 1.]])
178 | b_gen = torch.tensor([-2.])
179 |
180 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
181 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
182 |
183 | m = torch.nn.Linear(2, 1)
184 | with torch.no_grad():
185 | m.weight.zero_()
186 | m.bias.zero_()
187 | optimiser = optim.RMSprop(m.parameters(), centered = True, weight_decay = 0.4)
188 | # optimiser.zero_grad()
189 | for _step in range(100):
190 | optimiser.zero_grad()
191 | ys = m(sample_xs)
192 | loss = ((ys - sample_ys)**2).sum()
193 | loss.backward()
194 | optimiser.step()
195 | # print("Optimizer state begin")
196 | # print(optimiser.state)
197 | # print("Optimizer state end")
198 | print(m.weight)
199 | print(m.bias)
200 | */
201 | #[test]
202 | fn rmsprop_centered_decay_test() -> Result<()> {
203 | // Generate some linear data, y = 3.x1 + x2 - 2.
204 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
205 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
206 | let gen = Linear::new(w_gen, Some(b_gen));
207 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
208 | let sample_ys = gen.forward(&sample_xs)?;
209 |
210 | let params = ParamsRMSprop {
211 | centered: true,
212 | weight_decay: Some(0.4),
213 | ..Default::default()
214 | };
215 | // Now use backprop to run a linear regression between samples and get the coefficients back.
216 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
217 | let b = Var::new(0f32, &Device::Cpu)?;
218 | let mut n_sgd = RMSprop::new(vec![w.clone(), b.clone()], params)?;
219 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
220 | for _step in 0..100 {
221 | let ys = lin.forward(&sample_xs)?;
222 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
223 | n_sgd.backward_step(&loss)?;
224 | }
225 | assert_eq!(to_vec2_round(&w, 4)?, &[[1.8883, 0.7621]]);
226 | assert_eq!(to_vec0_round(&b, 4)?, 1.3558);
227 | Ok(())
228 | }
229 |
230 | /* The results of this test have been checked against the following PyTorch code.
231 | import torch
232 | from torch import optim
233 |
234 | w_gen = torch.tensor([[3., 1.]])
235 | b_gen = torch.tensor([-2.])
236 |
237 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
238 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
239 |
240 | m = torch.nn.Linear(2, 1)
241 | with torch.no_grad():
242 | m.weight.zero_()
243 | m.bias.zero_()
244 | optimiser = optim.RMSprop(m.parameters(), momentum = 0.4)
245 | # optimiser.zero_grad()
246 | for _step in range(100):
247 | optimiser.zero_grad()
248 | ys = m(sample_xs)
249 | loss = ((ys - sample_ys)**2).sum()
250 | loss.backward()
251 | optimiser.step()
252 | # print("Optimizer state begin")
253 | # print(optimiser.state)
254 | # print("Optimizer state end")
255 | print(m.weight)
256 | print(m.bias)
257 | */
258 | #[test]
259 | fn rmsprop_momentum_test() -> Result<()> {
260 | // Generate some linear data, y = 3.x1 + x2 - 2.
261 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
262 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
263 | let gen = Linear::new(w_gen, Some(b_gen));
264 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
265 | let sample_ys = gen.forward(&sample_xs)?;
266 |
267 | let params = ParamsRMSprop {
268 | momentum: Some(0.4),
269 | ..Default::default()
270 | };
271 | // Now use backprop to run a linear regression between samples and get the coefficients back.
272 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
273 | let b = Var::new(0f32, &Device::Cpu)?;
274 | let mut n_sgd = RMSprop::new(vec![w.clone(), b.clone()], params)?;
275 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
276 | for _step in 0..100 {
277 | let ys = lin.forward(&sample_xs)?;
278 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
279 | n_sgd.backward_step(&loss)?;
280 | }
281 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.3042, 0.6835]]);
282 | assert_eq!(to_vec0_round(&b, 4)?, 1.5441);
283 | Ok(())
284 | }
285 |
286 | /* The results of this test have been checked against the following PyTorch code.
287 | import torch
288 | from torch import optim
289 |
290 | w_gen = torch.tensor([[3., 1.]])
291 | b_gen = torch.tensor([-2.])
292 |
293 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
294 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
295 |
296 | m = torch.nn.Linear(2, 1)
297 | with torch.no_grad():
298 | m.weight.zero_()
299 | m.bias.zero_()
300 | optimiser = optim.RMSprop(m.parameters(), momentum = 0.4, weight_decay = 0.4)
301 | # optimiser.zero_grad()
302 | for _step in range(100):
303 | optimiser.zero_grad()
304 | ys = m(sample_xs)
305 | loss = ((ys - sample_ys)**2).sum()
306 | loss.backward()
307 | optimiser.step()
308 | # print("Optimizer state begin")
309 | # print(optimiser.state)
310 | # print("Optimizer state end")
311 | print(m.weight)
312 | print(m.bias)
313 | */
314 | #[test]
315 | fn rmsprop_momentum_decay_test() -> Result<()> {
316 | // Generate some linear data, y = 3.x1 + x2 - 2.
317 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
318 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
319 | let gen = Linear::new(w_gen, Some(b_gen));
320 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
321 | let sample_ys = gen.forward(&sample_xs)?;
322 |
323 | let params = ParamsRMSprop {
324 | momentum: Some(0.4),
325 | weight_decay: Some(0.4),
326 | ..Default::default()
327 | };
328 | // Now use backprop to run a linear regression between samples and get the coefficients back.
329 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
330 | let b = Var::new(0f32, &Device::Cpu)?;
331 | let mut n_sgd = RMSprop::new(vec![w.clone(), b.clone()], params)?;
332 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
333 | for _step in 0..100 {
334 | let ys = lin.forward(&sample_xs)?;
335 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
336 | n_sgd.backward_step(&loss)?;
337 | }
338 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.3028, 0.6858]]);
339 | assert_eq!(to_vec0_round(&b, 4)?, 1.5149);
340 | Ok(())
341 | }
342 |
343 | /* The results of this test have been checked against the following PyTorch code.
344 | import torch
345 | from torch import optim
346 |
347 | w_gen = torch.tensor([[3., 1.]])
348 | b_gen = torch.tensor([-2.])
349 |
350 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
351 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
352 |
353 | m = torch.nn.Linear(2, 1)
354 | with torch.no_grad():
355 | m.weight.zero_()
356 | m.bias.zero_()
357 | optimiser = optim.RMSprop(m.parameters(), centered = True, momentum = 0.4)
358 | # optimiser.zero_grad()
359 | for _step in range(100):
360 | optimiser.zero_grad()
361 | ys = m(sample_xs)
362 | loss = ((ys - sample_ys)**2).sum()
363 | loss.backward()
364 | optimiser.step()
365 | # print("Optimizer state begin")
366 | # print(optimiser.state)
367 | # print("Optimizer state end")
368 | print(m.weight)
369 | print(m.bias)
370 | */
371 | #[test]
372 | fn rmsprop_centered_momentum_test() -> Result<()> {
373 | // Generate some linear data, y = 3.x1 + x2 - 2.
374 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
375 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
376 | let gen = Linear::new(w_gen, Some(b_gen));
377 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
378 | let sample_ys = gen.forward(&sample_xs)?;
379 |
380 | let params = ParamsRMSprop {
381 | centered: true,
382 | momentum: Some(0.4),
383 | ..Default::default()
384 | };
385 | // Now use backprop to run a linear regression between samples and get the coefficients back.
386 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
387 | let b = Var::new(0f32, &Device::Cpu)?;
388 | let mut n_sgd = RMSprop::new(vec![w.clone(), b.clone()], params)?;
389 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
390 | for _step in 0..100 {
391 | let ys = lin.forward(&sample_xs)?;
392 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
393 | n_sgd.backward_step(&loss)?;
394 | }
395 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.4486, 0.6715]]);
396 | assert_eq!(to_vec0_round(&b, 4)?, 1.5045);
397 | Ok(())
398 | }
399 |
400 | /* The results of this test have been checked against the following PyTorch code.
401 | import torch
402 | from torch import optim
403 |
404 | w_gen = torch.tensor([[3., 1.]])
405 | b_gen = torch.tensor([-2.])
406 |
407 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
408 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
409 |
410 | m = torch.nn.Linear(2, 1)
411 | with torch.no_grad():
412 | m.weight.zero_()
413 | m.bias.zero_()
414 | optimiser = optim.RMSprop(m.parameters(), centered = True, momentum = 0.4, weight_decay = 0.4)
415 | # optimiser.zero_grad()
416 | for _step in range(100):
417 | optimiser.zero_grad()
418 | ys = m(sample_xs)
419 | loss = ((ys - sample_ys)**2).sum()
420 | loss.backward()
421 | optimiser.step()
422 | # print("Optimizer state begin")
423 | # print(optimiser.state)
424 | # print("Optimizer state end")
425 | print(m.weight)
426 | print(m.bias)
427 | */
428 | #[test]
429 | fn rmsprop_centered_momentum_decay_test() -> Result<()> {
430 | // Generate some linear data, y = 3.x1 + x2 - 2.
431 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
432 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
433 | let gen = Linear::new(w_gen, Some(b_gen));
434 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
435 | let sample_ys = gen.forward(&sample_xs)?;
436 |
437 | let params = ParamsRMSprop {
438 | centered: true,
439 | momentum: Some(0.4),
440 | weight_decay: Some(0.4),
441 | ..Default::default()
442 | };
443 | // Now use backprop to run a linear regression between samples and get the coefficients back.
444 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
445 | let b = Var::new(0f32, &Device::Cpu)?;
446 | let mut n_sgd = RMSprop::new(vec![w.clone(), b.clone()], params)?;
447 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
448 | for _step in 0..100 {
449 | let ys = lin.forward(&sample_xs)?;
450 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
451 | n_sgd.backward_step(&loss)?;
452 | }
453 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.4468, 0.6744]]);
454 | assert_eq!(to_vec0_round(&b, 4)?, 1.4695);
455 | Ok(())
456 | }
457 |
--------------------------------------------------------------------------------
/src/lbfgs/strong_wolfe.rs:
--------------------------------------------------------------------------------
1 | use crate::Model;
2 | use candle_core::Result as CResult;
3 | use candle_core::{Tensor, Var};
4 |
5 | use super::{add_grad, flat_grads, set_vs, Lbfgs};
6 |
7 | /// Ported from PyTorch's `torch/optim/lbfgs.py`, itself ported from <https://github.com/torch/optim/blob/master/lswolfe.lua>
8 | fn cubic_interpolate(
9 | // position 1
10 | x1: f64,
11 | // f(x1)
12 | f1: f64,
13 | // f'(x1)
14 | g1: f64,
15 | // position 2
16 | x2: f64,
17 | // f(x2)
18 | f2: f64,
19 | // f'(x2)
20 | g2: f64,
21 | bounds: Option<(f64, f64)>,
22 | ) -> f64 {
23 | let (xmin_bound, xmax_bound) = if let Some(bound) = bounds {
24 | bound
25 | } else if x1 < x2 {
26 | (x1, x2)
27 | } else {
28 | (x2, x1)
29 | };
30 | let d1 = g1 + g2 - 3. * (f1 - f2) / (x1 - x2);
31 | let d2_square = d1.powi(2) - g1 * g2;
32 | if d2_square >= 0. {
33 | let d2 = d2_square.sqrt();
34 | let min_pos = if x1 <= x2 {
35 | x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2. * d2))
36 | } else {
37 | x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2. * d2))
38 | };
39 | (min_pos.max(xmin_bound)).min(xmax_bound)
40 | } else {
41 | xmin_bound.midpoint(xmax_bound)
42 | }
43 | }
44 |
45 | impl<M: Model> Lbfgs<'_, M> {
46 | /// Strong Wolfe line search
47 | ///
48 | /// # Arguments
49 | ///
50 | /// step size
51 | ///
52 | /// direction
53 | ///
54 | /// initial loss
55 | ///
56 | /// initial grad
57 | ///
58 | /// initial directional grad
59 | ///
60 | /// c1 coefficient for wolfe condition
61 | ///
62 | /// c2 coefficient for wolfe condition
63 | ///
64 | /// minimum allowed progress
65 | ///
66 | /// maximum number of iterations
67 | ///
68 | /// # Returns
69 | ///
70 | /// (`f_new`, `g_new`, t, `ls_func_evals`)
71 | #[allow(clippy::too_many_arguments, clippy::too_many_lines)]
72 | pub(super) fn strong_wolfe(
73 | &mut self,
74 | mut step_size: f64, // step size
75 | direction: &Tensor, // direction
76 | loss: &Tensor, // initial loss
77 | grad: &Tensor, // initial grad
78 | directional_grad: f64, // initial directional grad
79 | c1: f64, // c1 coefficient for wolfe condition
80 | c2: f64, // c2 coefficient for wolfe condition
81 | tolerance_change: f64, // minimum allowed progress
82 | max_ls: usize, // maximum number of iterations
83 | ) -> CResult<(Tensor, Tensor, f64, usize)> {
84 | // ported from https://github.com/torch/optim/blob/master/lswolfe.lua
85 |
86 | let dtype = loss.dtype();
87 | let shape = loss.shape();
88 | let dev = loss.device();
89 |
90 | let d_norm = &direction
91 | .abs()?
92 | .max(0)?
93 | .to_dtype(candle_core::DType::F64)?
94 | .to_scalar::<f64>()?;
95 |
96 | // evaluate objective and gradient using initial step
97 | let (f_new, g_new, mut l2_new) = self.directional_evaluate(step_size, direction)?;
98 | let g_new = Var::from_tensor(&g_new)?;
99 | let mut f_new = f_new
100 | .to_dtype(candle_core::DType::F64)?
101 | .to_scalar::<f64>()?;
102 | let mut ls_func_evals = 1;
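// directional derivative gtd_new = g_new . direction, computed as a (1 x n) * (n x 1) matmul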
103 | let mut gtd_new = g_new
104 | .unsqueeze(0)?
105 | .matmul(&(direction.unsqueeze(1)?))?
106 | .to_dtype(candle_core::DType::F64)?
107 | .squeeze(1)?
108 | .squeeze(0)?
109 | .to_scalar::<f64>()?;
110 |
111 | // bracket an interval containing a point satisfying the Wolfe criteria
112 | let grad_det = grad.copy()?;
113 | let g_prev = Var::from_tensor(&grad_det)?;
114 | let scalar_loss = loss.to_dtype(candle_core::DType::F64)?.to_scalar::<f64>()?;
115 | let mut f_prev = scalar_loss;
116 | let l2_init = self.l2_reg()?;
117 | let mut l2_prev = l2_init;
118 | let (mut t_prev, mut gtd_prev) = (0., directional_grad);
119 | let mut done = false;
120 | let mut ls_iter = 0;
121 |
122 | let mut bracket_gtd;
123 | let mut bracket_l2;
124 | let mut bracket_f;
125 | let (mut bracket, bracket_g) = loop {
126 | // check conditions
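// the bracketing loop exits when one of three things happens:
//  (1) the objective (loss + l2) did not decrease below the previous trial point -> bracket found,
//  (2) the strong Wolfe curvature condition already holds at this step -> done,
//  (3) the directional derivative became non-negative -> bracket found;
// otherwise the step is extrapolated via cubic interpolation (or the loop stops at max_ls)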
127 | if f_new + l2_new >= f_prev + l2_prev {
128 | bracket_gtd = [gtd_prev, gtd_new];
129 | bracket_l2 = [l2_prev, l2_new];
130 | bracket_f = [f_prev, f_new];
131 | break (
132 | [t_prev, step_size],
133 | [g_prev, Var::from_tensor(g_new.as_tensor())?],
134 | );
135 | }
136 |
137 | if gtd_new.abs() <= -c2 * directional_grad {
138 | done = true;
139 | bracket_gtd = [gtd_prev, gtd_new];
140 | bracket_l2 = [l2_prev, l2_new];
141 | bracket_f = [f_new, f_new];
142 | break (
143 | [step_size, step_size],
144 | [
145 | Var::from_tensor(&g_new.as_tensor().copy()?)?,
146 | Var::from_tensor(g_new.as_tensor())?,
147 | ],
148 | );
149 | }
150 |
151 | if gtd_new >= 0. {
152 | bracket_gtd = [gtd_prev, gtd_new];
153 | bracket_l2 = [l2_prev, l2_new];
154 | bracket_f = [f_prev, f_new];
155 | break (
156 | [t_prev, step_size],
157 | [g_prev, Var::from_tensor(g_new.as_tensor())?],
158 | );
159 | }
160 |
161 | // interpolate
162 | let min_step = step_size + 0.01 * (step_size - t_prev);
163 | let max_step = step_size * 10.;
164 | let tmp = step_size;
165 | step_size = cubic_interpolate(
166 | t_prev,
167 | f_prev + l2_prev,
168 | gtd_prev,
169 | step_size,
170 | f_new + l2_new,
171 | gtd_new,
172 | Some((min_step, max_step)),
173 | );
174 |
175 | // next step
176 | t_prev = tmp;
177 | f_prev = f_new;
178 | g_prev.set(g_new.as_tensor())?;
179 | l2_prev = l2_new;
180 | gtd_prev = gtd_new;
181 | // assign to temp vars:
182 | let (next_f, next_g, next_l2) = self.directional_evaluate(step_size, direction)?;
183 |
184 | // overwrite
185 | f_new = next_f
186 | .to_dtype(candle_core::DType::F64)?
187 | .to_scalar::<f64>()?;
188 | g_new.set(&next_g)?;
189 | l2_new = next_l2;
190 |
191 | ls_func_evals += 1;
192 |
193 | gtd_new = g_new
194 | .unsqueeze(0)?
195 | .matmul(&(direction.unsqueeze(1)?))?
196 | .to_dtype(candle_core::DType::F64)?
197 | .squeeze(1)?
198 | .squeeze(0)?
199 | .to_scalar::<f64>()?;
200 | ls_iter += 1;
201 |
202 | // reached max number of iterations?
203 | if ls_iter == max_ls {
204 | bracket_gtd = [gtd_prev, gtd_new];
205 | bracket_l2 = [l2_prev, l2_new];
206 | bracket_f = [scalar_loss, f_new];
207 | break (
208 | [0., step_size],
209 | [
210 | Var::from_tensor(grad)?,
211 | Var::from_tensor(g_new.as_tensor())?,
212 | ],
213 | );
214 | }
215 | };
216 |
217 | // zoom phase: we now have a point satisfying the criteria, or
218 | // a bracket around it. We refine the bracket until we find the
219 | // exact point satisfying the criteria
220 | let mut insuf_progress = false;
221 | // find high and low points in bracket
222 | let (mut low_pos, mut high_pos) =
223 | if bracket_f[0] + bracket_l2[0] <= bracket_f[1] + bracket_l2[1] {
224 | (0, 1)
225 | } else {
226 | (1, 0)
227 | };
228 | while !done && ls_iter < max_ls {
229 | // line-search bracket is so small
230 | if (bracket[1] - bracket[0]).abs() * d_norm < tolerance_change {
231 | break;
232 | }
233 |
234 | // compute new trial value
235 | step_size = cubic_interpolate(
236 | bracket[0],
237 | bracket_f[0] + bracket_l2[0],
238 | bracket_gtd[0],
239 | bracket[1],
240 | bracket_f[1] + bracket_l2[1],
241 | bracket_gtd[1],
242 | None,
243 | );
244 |
245 | // test that we are making sufficient progress:
246 | // in case `t` is so close to boundary, we mark that we are making
247 | // insufficient progress, and if
248 | // + we have made insufficient progress in the last step, or
249 | // + `t` is at one of the boundary,
250 | // we will move `t` to a position which is `0.1 * len(bracket)`
251 | // away from the nearest boundary point.
252 | let max_bracket = bracket[0].max(bracket[1]);
253 | let min_bracket = bracket[0].min(bracket[1]);
254 | let eps = 0.1 * (max_bracket - min_bracket);
255 | if (max_bracket - step_size).min(step_size - min_bracket) < eps {
256 | // interpolation close to boundary
257 | if insuf_progress || step_size >= max_bracket || step_size <= min_bracket {
258 | // evaluate at 0.1 away from boundary
259 | if (step_size - max_bracket).abs() < (step_size - min_bracket).abs() {
260 | step_size = max_bracket - eps;
261 | } else {
262 | step_size = min_bracket + eps;
263 | }
264 | insuf_progress = false;
265 | } else {
266 | insuf_progress = true;
267 | }
268 | } else {
269 | insuf_progress = false;
270 | }
271 |
272 | // Evaluate new point
273 | // assign to temp vars:
274 | let (next_f, next_g, next_l2) = self.directional_evaluate(step_size, direction)?;
275 | // overwrite
276 | f_new = next_f
277 | .to_dtype(candle_core::DType::F64)?
278 | .to_scalar::<f64>()?;
279 |
280 | l2_new = next_l2;
281 | ls_func_evals += 1;
282 |
283 | gtd_new = next_g
284 | .unsqueeze(0)?
285 | .matmul(&(direction.unsqueeze(1)?))?
286 | .to_dtype(candle_core::DType::F64)?
287 | .squeeze(1)?
288 | .squeeze(0)?
289 | .to_scalar::<f64>()?;
290 | ls_iter += 1;
291 |
292 | if f_new + l2_new > (scalar_loss + l2_init + c1 * step_size * directional_grad)
293 | || f_new + l2_new >= bracket_f[low_pos] + bracket_l2[low_pos]
294 | {
295 | // Armijo condition not satisfied or not lower than lowest point
296 | bracket[high_pos] = step_size;
297 | bracket_f[high_pos] = f_new;
298 | bracket_g[high_pos].set(&next_g)?;
299 | bracket_l2[high_pos] = l2_new;
300 | bracket_gtd[high_pos] = gtd_new;
301 |
302 | (low_pos, high_pos) =
303 | if bracket_f[0] + bracket_l2[0] <= bracket_f[1] + bracket_l2[1] {
304 | (0, 1)
305 | } else {
306 | (1, 0)
307 | };
308 | } else {
309 | if gtd_new.abs() <= -c2 * directional_grad {
310 | // Wolfe conditions satisfied
311 | done = true;
312 | } else if gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0. {
313 | // old low becomes new high
314 | bracket[high_pos] = bracket[low_pos];
315 | bracket_f[high_pos] = bracket_f[low_pos];
316 | bracket_g[high_pos].set(bracket_g[low_pos].as_tensor())?;
317 | bracket_gtd[high_pos] = bracket_gtd[low_pos];
318 | bracket_l2[high_pos] = bracket_l2[low_pos];
319 | }
320 |
321 | // new point becomes new low
322 | bracket[low_pos] = step_size;
323 | bracket_f[low_pos] = f_new;
324 | bracket_g[low_pos].set(&next_g)?;
325 | bracket_gtd[low_pos] = gtd_new;
326 | bracket_l2[low_pos] = l2_new;
327 | }
328 | }
329 |
330 | // return new value, new grad, line-search value, nb of function evals
331 | step_size = bracket[low_pos];
332 | let [g0, g1] = bracket_g;
333 | let [f0, f1] = bracket_f;
334 | if low_pos == 1 {
335 | // if b is the lower value set a to b, else a should be returned
336 | Ok((
337 | Tensor::from_slice(&[f1], shape, dev)?.to_dtype(dtype)?,
338 | g1.into_inner(),
339 | step_size,
340 | ls_func_evals,
341 | ))
342 | } else {
343 | Ok((
344 | Tensor::from_slice(&[f0], shape, dev)?.to_dtype(dtype)?,
345 | g0.into_inner(),
346 | step_size,
347 | ls_func_evals,
348 | ))
349 | }
350 | }
351 |
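/// Evaluates the loss, the flattened gradient and the L2 penalty at
/// `vars + mag * direction`, restoring the original variable values afterwards
/// so that repeated line-search evaluations do not drift.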
352 | fn directional_evaluate(
353 | &mut self,
354 | mag: f64,
355 | direction: &Tensor,
356 | ) -> CResult<(Tensor, Tensor, f64)> {
357 | // need to cache the original result
358 | // Otherwise leads to drift over line search evals
359 | let original = self
360 | .vars
361 | .iter()
362 | .map(|v| v.as_tensor().copy())
363 | .collect::<CResult<Vec<Tensor>>>()?;
364 |
365 | add_grad(&mut self.vars, &(mag * direction)?)?;
366 | let loss = self.model.loss()?;
367 | let grad = flat_grads(&self.vars, &loss, self.params.weight_decay)?;
368 | let l2_reg = if let Some(wd) = self.params.weight_decay {
369 | 0.5 * wd
370 | * self
371 | .vars
372 | .iter()
373 | .map(|v| -> CResult<f64> {
374 | v.as_tensor()
375 | .sqr()?
376 | .sum_all()?
377 | .to_dtype(candle_core::DType::F64)?
378 | .to_scalar::<f64>()
379 | })
380 | .sum::<CResult<f64>>()?
381 | } else {
382 | 0.
383 | };
384 |
385 | set_vs(&mut self.vars, &original)?;
386 | // add_grad(&mut self.vars, &(-mag * direction)?)?;
387 | Ok((
388 | loss, //.to_dtype(candle_core::DType::F64)?.to_scalar::<f64>()?
389 | grad, l2_reg,
390 | ))
391 | }
392 |
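/// L2 penalty currently added to the loss: `0.5 * weight_decay * sum ||v||^2`,
/// or 0 when no weight decay is configured.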
393 | fn l2_reg(&self) -> CResult<f64> {
394 | if let Some(wd) = self.params.weight_decay {
395 | Ok(0.5
396 | * wd
397 | * self
398 | .vars
399 | .iter()
400 | .map(|v| -> CResult<f64> {
401 | v.as_tensor()
402 | .sqr()?
403 | .sum_all()?
404 | .to_dtype(candle_core::DType::F64)?
405 | .to_scalar::<f64>()
406 | })
407 | .sum::<CResult<f64>>()?)
408 | } else {
409 | Ok(0.)
410 | }
411 | }
412 | }
413 |
414 | #[cfg(test)]
415 | mod tests {
416 | // use candle_core::test_utils::{to_vec0_round, to_vec2_round};
417 |
418 | use crate::lbfgs::ParamsLBFGS;
419 | use crate::{LossOptimizer, Model};
420 | use anyhow::Result;
421 | use assert_approx_eq::assert_approx_eq;
422 | use candle_core::Device;
423 | use candle_core::{Module, Result as CResult};
424 | pub struct LinearModel {
425 | linear: candle_nn::Linear,
426 | xs: Tensor,
427 | ys: Tensor,
428 | }
429 |
430 | impl Model for LinearModel {
431 | fn loss(&self) -> CResult<Tensor> {
432 | let preds = self.forward(&self.xs)?;
433 | let loss = candle_nn::loss::mse(&preds, &self.ys)?;
434 | Ok(loss)
435 | }
436 | }
437 |
438 | impl LinearModel {
439 | fn new() -> CResult<(Self, Vec<Var>)> {
440 | let weight = Var::from_tensor(&Tensor::new(&[3f64, 1.], &Device::Cpu)?)?;
441 | let bias = Var::from_tensor(&Tensor::new(-2f64, &Device::Cpu)?)?;
442 |
443 | let linear =
444 | candle_nn::Linear::new(weight.as_tensor().clone(), Some(bias.as_tensor().clone()));
445 |
446 | Ok((
447 | Self {
448 | linear,
449 | xs: Tensor::new(&[[2f64, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?,
450 | ys: Tensor::new(&[[7f64], [26.], [0.], [27.]], &Device::Cpu)?,
451 | },
452 | vec![weight, bias],
453 | ))
454 | }
455 |
456 | fn forward(&self, xs: &Tensor) -> CResult<Tensor> {
457 | self.linear.forward(xs)
458 | }
459 | }
460 |
461 | use super::*;
462 | #[test]
463 | fn l2_test() -> Result<()> {
464 | let params = ParamsLBFGS {
465 | lr: 0.004,
466 | ..Default::default()
467 | };
468 | let (model, vars) = LinearModel::new()?;
469 | let lbfgs = Lbfgs::new(vars, params, &model)?;
470 | let l2 = lbfgs.l2_reg()?;
471 | assert_approx_eq!(0.0, l2);
472 |
473 | let params = ParamsLBFGS {
474 | lr: 0.004,
475 | weight_decay: Some(1.0),
476 | ..Default::default()
477 | };
478 | let (model, vars) = LinearModel::new()?;
479 | let lbfgs = Lbfgs::new(vars, params, &model)?;
480 | let l2 = lbfgs.l2_reg()?;
481 | assert_approx_eq!(7.0, l2); // 0.5 *(3^2 +1^2 + (-2)^2)
482 | Ok(())
483 | }
484 | }
485 |
--------------------------------------------------------------------------------
/tests/esgd_tests.rs:
--------------------------------------------------------------------------------
1 | use candle_core::test_utils::{to_vec0_round, to_vec2_round};
2 |
3 | use anyhow::Result;
4 | use candle_core::{Device, Tensor, Var};
5 | use candle_nn::{Linear, Module, Optimizer};
6 | use candle_optimisers::{
7 | esgd::{ParamsSGD, SGD},
8 | Decay, Momentum,
9 | };
10 |
11 | /* The results of this test have been checked against the following PyTorch code.
12 | import torch
13 | from torch import optim
14 |
15 | w_gen = torch.tensor([[3., 1.]])
16 | b_gen = torch.tensor([-2.])
17 |
18 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
19 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
20 |
21 | m = torch.nn.Linear(2, 1)
22 | with torch.no_grad():
23 | m.weight.zero_()
24 | m.bias.zero_()
25 | optimiser = optim.SGD(m.parameters(), lr=0.004, momentum=0.1, nesterov=True)
26 | for _step in range(100):
27 | optimiser.zero_grad()
28 | ys = m(sample_xs)
29 | loss = ((ys - sample_ys)**2).sum()
30 | loss.backward()
31 | optimiser.step()
32 | print(m.weight)
33 | print(m.bias)
34 | */
35 | #[test]
36 | fn nesterov_sgd_test() -> Result<()> {
37 | // Generate some linear data, y = 3.x1 + x2 - 2.
38 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
39 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
40 | let gen = Linear::new(w_gen, Some(b_gen));
41 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
42 | let sample_ys = gen.forward(&sample_xs)?;
43 |
44 | let params = ParamsSGD {
45 | lr: 0.004,
46 | weight_decay: None,
47 | momentum: Some(Momentum::Nesterov(0.1)),
48 | dampening: 0.0,
49 | // nesterov: true,
50 | };
51 | // Now use backprop to run a linear regression between samples and get the coefficients back.
52 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
53 | let b = Var::new(0f32, &Device::Cpu)?;
54 | let mut n_sgd = SGD::new(vec![w.clone(), b.clone()], params)?;
55 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
56 | for _step in 0..100 {
57 | let ys = lin.forward(&sample_xs)?;
58 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
59 | n_sgd.backward_step(&loss)?;
60 | }
61 | if cfg!(target_os = "macos") {
62 | assert_eq!(to_vec2_round(&w, 3)?, &[[1.075, -9.904]]);
63 | assert_eq!(to_vec0_round(&b, 3)?, -1.896);
64 | } else {
65 | assert_eq!(to_vec2_round(&w, 4)?, &[[1.0750, -9.9042]]);
66 | assert_eq!(to_vec0_round(&b, 4)?, -1.8961);
67 | }
68 |
69 | Ok(())
70 | }
71 |
72 | /* The results of this test have been checked against the following PyTorch code.
73 | import torch
74 | from torch import optim
75 |
76 | w_gen = torch.tensor([[3., 1.]])
77 | b_gen = torch.tensor([-2.])
78 |
79 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
80 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
81 |
82 | m = torch.nn.Linear(2, 1)
83 | with torch.no_grad():
84 | m.weight.zero_()
85 | m.bias.zero_()
86 | optimiser = optim.SGD(m.parameters(), lr=0.004, momentum=0.1, nesterov=True, weight_decay = 0.1)
87 | # optimiser.zero_grad()
88 | for _step in range(100):
89 | optimiser.zero_grad()
90 | ys = m(sample_xs)
91 | loss = ((ys - sample_ys)**2).sum()
92 | loss.backward()
93 | optimiser.step()
94 | # print("Optimizer state begin")
95 | # print(optimiser.state)
96 | # print("Optimizer state end")
97 | print(m.weight)
98 | print(m.bias)
99 | */
100 | #[test]
101 | fn nesterov_decay_sgd_test() -> Result<()> {
102 | // Generate some linear data, y = 3.x1 + x2 - 2.
103 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
104 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
105 | let gen = Linear::new(w_gen, Some(b_gen));
106 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
107 | let sample_ys = gen.forward(&sample_xs)?;
108 |
109 | let params = ParamsSGD {
110 | lr: 0.004,
111 | weight_decay: Some(Decay::WeightDecay(0.1)),
112 | momentum: Some(Momentum::Nesterov(0.1)),
113 | dampening: 0.0,
114 | // nesterov: true,
115 | };
116 | // Now use backprop to run a linear regression between samples and get the coefficients back.
117 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
118 | let b = Var::new(0f32, &Device::Cpu)?;
119 | let mut n_sgd = SGD::new(vec![w.clone(), b.clone()], params)?;
120 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
121 | for _step in 0..100 {
122 | let ys = lin.forward(&sample_xs)?;
123 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
124 | n_sgd.backward_step(&loss)?;
125 | }
126 |
127 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.9921, -10.3803]]);
128 | assert_eq!(to_vec0_round(&b, 4)?, -1.9331);
129 | Ok(())
130 | }
131 |
132 | /* The results of this test have been checked against the following PyTorch code.
133 | import torch
134 | from torch import optim
135 |
136 | w_gen = torch.tensor([[3., 1.]])
137 | b_gen = torch.tensor([-2.])
138 |
139 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
140 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
141 |
142 | m = torch.nn.Linear(2, 1)
143 | with torch.no_grad():
144 | m.weight.zero_()
145 | m.bias.zero_()
146 | optimiser = optim.SGD(m.parameters(), lr=0.004, momentum=0.1, nesterov=False, weight_decay = 0.0)
147 | # optimiser.zero_grad()
148 | for _step in range(100):
149 | optimiser.zero_grad()
150 | ys = m(sample_xs)
151 | loss = ((ys - sample_ys)**2).sum()
152 | loss.backward()
153 | optimiser.step()
154 | # print("Optimizer state begin")
155 | # print(optimiser.state)
156 | # print("Optimizer state end")
157 | print(m.weight)
158 | print(m.bias)
159 | */
160 | #[test]
161 | fn momentum_sgd_test() -> Result<()> {
162 | // Generate some linear data, y = 3.x1 + x2 - 2.
163 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
164 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
165 | let gen = Linear::new(w_gen, Some(b_gen));
166 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
167 | let sample_ys = gen.forward(&sample_xs)?;
168 |
169 | let params = ParamsSGD {
170 | lr: 0.004,
171 | weight_decay: None,
172 | momentum: Some(Momentum::Classical(0.1)),
173 | dampening: 0.0,
174 | // nesterov: false,
175 | };
176 | // Now use backprop to run a linear regression between samples and get the coefficients back.
177 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
178 | let b = Var::new(0f32, &Device::Cpu)?;
179 | let mut n_sgd = SGD::new(vec![w.clone(), b.clone()], params)?;
180 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
181 | for _step in 0..100 {
182 | let ys = lin.forward(&sample_xs)?;
183 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
184 | n_sgd.backward_step(&loss)?;
185 | }
186 |
187 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.8870, 0.8589]]);
188 | assert_eq!(to_vec0_round(&b, 4)?, -0.6341);
189 | Ok(())
190 | }
191 |
192 | /* The results of this test have been checked against the following PyTorch code.
193 | import torch
194 | from torch import optim
195 |
196 | w_gen = torch.tensor([[3., 1.]])
197 | b_gen = torch.tensor([-2.])
198 |
199 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
200 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
201 |
202 | m = torch.nn.Linear(2, 1)
203 | with torch.no_grad():
204 | m.weight.zero_()
205 | m.bias.zero_()
206 | optimiser = optim.SGD(m.parameters(), lr=0.004, momentum=0.1, nesterov=False, weight_decay = 0.4)
207 | # optimiser.zero_grad()
208 | for _step in range(100):
209 | optimiser.zero_grad()
210 | ys = m(sample_xs)
211 | loss = ((ys - sample_ys)**2).sum()
212 | loss.backward()
213 | optimiser.step()
214 | # print("Optimizer state begin")
215 | # print(optimiser.state)
216 | # print("Optimizer state end")
217 | print(m.weight)
218 | print(m.bias)
219 | */
220 | #[test]
221 | fn momentum_sgd_decay_test() -> Result<()> {
222 | // Generate some linear data, y = 3.x1 + x2 - 2.
223 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
224 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
225 | let gen = Linear::new(w_gen, Some(b_gen));
226 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
227 | let sample_ys = gen.forward(&sample_xs)?;
228 |
229 | let params = ParamsSGD {
230 | lr: 0.004,
231 | weight_decay: Some(Decay::WeightDecay(0.4)),
232 | momentum: Some(Momentum::Classical(0.1)),
233 | dampening: 0.0,
234 | // nesterov: false,
235 | };
236 | // Now use backprop to run a linear regression between samples and get the coefficients back.
237 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
238 | let b = Var::new(0f32, &Device::Cpu)?;
239 | let mut n_sgd = SGD::new(vec![w.clone(), b.clone()], params)?;
240 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
241 | for _step in 0..100 {
242 | let ys = lin.forward(&sample_xs)?;
243 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
244 | n_sgd.backward_step(&loss)?;
245 | }
246 |
247 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.8751, 0.8514]]);
248 | assert_eq!(to_vec0_round(&b, 4)?, -0.5626);
249 | Ok(())
250 | }
251 |
252 | /* The results of this test have been checked against the following PyTorch code.
253 | import torch
254 | from torch import optim
255 |
256 | w_gen = torch.tensor([[3., 1.]])
257 | b_gen = torch.tensor([-2.])
258 |
259 | sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
260 | sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
261 |
262 | m = torch.nn.Linear(2, 1)
263 | with torch.no_grad():
264 | m.weight.zero_()
265 | m.bias.zero_()
266 | optimiser = optim.SGD(m.parameters(), lr=0.004, momentum=0.1, nesterov=False, weight_decay = 0.0, dampening = 0.2)
267 | # optimiser.zero_grad()
268 | for _step in range(100):
269 | optimiser.zero_grad()
270 | ys = m(sample_xs)
271 | loss = ((ys - sample_ys)**2).sum()
272 | loss.backward()
273 | optimiser.step()
274 | # print("Optimizer state begin")
275 | # print(optimiser.state)
276 | # print("Optimizer state end")
277 | print(m.weight)
278 | print(m.bias)
279 | */
280 | #[test]
281 | fn momentum_sgd_dampened_test() -> Result<()> {
282 | // Generate some linear data, y = 3.x1 + x2 - 2.
283 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
284 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
285 | let gen = Linear::new(w_gen, Some(b_gen));
286 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
287 | let sample_ys = gen.forward(&sample_xs)?;
288 |
289 | let params = ParamsSGD {
290 | lr: 0.004,
291 | weight_decay: None,
292 | momentum: Some(Momentum::Classical(0.1)),
293 | dampening: 0.2,
294 | // nesterov: false,
295 | };
296 | // Now use backprop to run a linear regression between samples and get the coefficients back.
297 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
298 | let b = Var::new(0f32, &Device::Cpu)?;
299 | let mut n_sgd = SGD::new(vec![w.clone(), b.clone()], params)?;
300 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
301 | for _step in 0..100 {
302 | let ys = lin.forward(&sample_xs)?;
303 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
304 | n_sgd.backward_step(&loss)?;
305 | }
306 |
307 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.8746, 0.8434]]);
308 | assert_eq!(to_vec0_round(&b, 4)?, -0.4838);
309 | Ok(())
310 | }
311 |
312 | #[test]
313 | fn sgd_test() -> Result<()> {
314 | // Generate some linear data, y = 3.x1 + x2 - 2.
315 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
316 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
317 | let gen = Linear::new(w_gen, Some(b_gen));
318 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
319 | let sample_ys = gen.forward(&sample_xs)?;
320 |
321 | let params = ParamsSGD {
322 | lr: 0.004,
323 | weight_decay: None,
324 | momentum: None,
325 | dampening: 0.0,
326 | // nesterov: false,
327 | };
328 | // Now use backprop to run a linear regression between samples and get the coefficients back.
329 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
330 | let b = Var::new(0f32, &Device::Cpu)?;
331 | let mut n_sgd = SGD::new(vec![w.clone(), b.clone()], params)?;
332 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
333 | for _step in 0..100 {
334 | let ys = lin.forward(&sample_xs)?;
335 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
336 | n_sgd.backward_step(&loss)?;
337 | }
338 |
339 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.8809, 0.8513]]);
340 | assert_eq!(to_vec0_round(&b, 4)?, -0.5606);
341 | Ok(())
342 | }
343 |
344 | #[test]
345 | fn sgd_decay_test() -> Result<()> {
346 | // Generate some linear data, y = 3.x1 + x2 - 2.
347 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
348 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
349 | let gen = Linear::new(w_gen, Some(b_gen));
350 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
351 | let sample_ys = gen.forward(&sample_xs)?;
352 |
353 | let params = ParamsSGD {
354 | lr: 0.004,
355 | weight_decay: Some(Decay::WeightDecay(0.4)),
356 | momentum: None,
357 | dampening: 0.0,
358 | // nesterov: false,
359 | };
360 | // Now use backprop to run a linear regression between samples and get the coefficients back.
361 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
362 | let b = Var::new(0f32, &Device::Cpu)?;
363 | let mut n_sgd = SGD::new(vec![w.clone(), b.clone()], params)?;
364 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
365 | for _step in 0..100 {
366 | let ys = lin.forward(&sample_xs)?;
367 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
368 | n_sgd.backward_step(&loss)?;
369 | }
370 |
371 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.8700, 0.8450]]);
372 | assert_eq!(to_vec0_round(&b, 4)?, -0.5003);
373 | Ok(())
374 | }
375 |
376 | // The following are not tested against torch
377 | // As torch has no implementation of SGDW
378 |
379 | // This should be the same (as without momentum, decoupling is equivalent)
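// (coupled decay updates theta <- theta - lr*(grad + wd*theta); decoupled decay updates
//  theta <- theta - lr*grad - lr*wd*theta; without momentum the two are identical)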
380 | #[test]
381 | fn sgdw_decay_test() -> Result<()> {
382 | // Generate some linear data, y = 3.x1 + x2 - 2.
383 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
384 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
385 | let gen = Linear::new(w_gen, Some(b_gen));
386 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
387 | let sample_ys = gen.forward(&sample_xs)?;
388 |
389 | let params = ParamsSGD {
390 | lr: 0.004,
391 | weight_decay: Some(Decay::DecoupledWeightDecay(0.4)),
392 | momentum: None,
393 | dampening: 0.0,
394 | // nesterov: false,
395 | };
396 | // Now use backprop to run a linear regression between samples and get the coefficients back.
397 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
398 | let b = Var::new(0f32, &Device::Cpu)?;
399 | let mut n_sgd = SGD::new(vec![w.clone(), b.clone()], params)?;
400 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
401 | for _step in 0..100 {
402 | let ys = lin.forward(&sample_xs)?;
403 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
404 | n_sgd.backward_step(&loss)?;
405 | }
406 |
407 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.8700, 0.8450]]);
408 | assert_eq!(to_vec0_round(&b, 4)?, -0.5003);
409 | Ok(())
410 | }
411 |
412 | #[test]
413 | fn momentum_sgdw_decay_test() -> Result<()> {
414 | // Generate some linear data, y = 3.x1 + x2 - 2.
415 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
416 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
417 | let gen = Linear::new(w_gen, Some(b_gen));
418 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
419 | let sample_ys = gen.forward(&sample_xs)?;
420 |
421 | let params = ParamsSGD {
422 | lr: 0.004,
423 | weight_decay: Some(Decay::DecoupledWeightDecay(0.4)),
424 | momentum: Some(Momentum::Classical(0.1)),
425 | dampening: 0.0,
426 | // nesterov: false,
427 | };
428 | // Now use backprop to run a linear regression between samples and get the coefficients back.
429 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
430 | let b = Var::new(0f32, &Device::Cpu)?;
431 | let mut n_sgd = SGD::new(vec![w.clone(), b.clone()], params)?;
432 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
433 | for _step in 0..100 {
434 | let ys = lin.forward(&sample_xs)?;
435 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
436 | n_sgd.backward_step(&loss)?;
437 | }
438 |
439 | assert_eq!(to_vec2_round(&w, 4)?, &[[2.8763, 0.8521]]);
440 | assert_eq!(to_vec0_round(&b, 4)?, -0.5693);
441 | Ok(())
442 | }
443 |
444 | #[test]
445 | fn nesterov_decay_sgdw_test() -> Result<()> {
446 | // Generate some linear data, y = 3.x1 + x2 - 2.
447 | let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
448 | let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
449 | let gen = Linear::new(w_gen, Some(b_gen));
450 | let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
451 | let sample_ys = gen.forward(&sample_xs)?;
452 |
453 | let params = ParamsSGD {
454 | lr: 0.004,
455 | weight_decay: Some(candle_optimisers::Decay::DecoupledWeightDecay(0.1)),
456 | momentum: Some(Momentum::Nesterov(0.1)),
457 | dampening: 0.0,
458 | // nesterov: true,
459 | };
460 | // Now use backprop to run a linear regression between samples and get the coefficients back.
461 | let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
462 | let b = Var::new(0f32, &Device::Cpu)?;
463 | let mut n_sgd = SGD::new(vec![w.clone(), b.clone()], params)?;
464 | let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
465 | for _step in 0..100 {
466 | let ys = lin.forward(&sample_xs)?;
467 | let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
468 | n_sgd.backward_step(&loss)?;
469 | }
470 |
471 | assert_eq!(to_vec2_round(&w, 4)?, &[[0.9992, -10.3397]]);
472 | assert_eq!(to_vec0_round(&b, 4)?, -1.9302);
473 | Ok(())
474 | }
475 |
--------------------------------------------------------------------------------