├── .clog.toml ├── .gitignore ├── .travis.yml ├── CHANGELOG.md ├── Cargo.lock ├── Cargo.toml ├── LICENSE ├── README.md ├── examples ├── line_fitting.rs └── minimal.rs └── src ├── gd.rs ├── lib.rs ├── line_search.rs ├── numeric.rs ├── problems.rs ├── sgd.rs ├── types.rs └── utils.rs /.clog.toml: -------------------------------------------------------------------------------- 1 | [clog] 2 | repository = "https://github.com/b52/optimization-rust" 3 | changelog = "CHANGELOG.md" 4 | from-latest-tag = true 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .* 2 | target/ 3 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | addons: 3 | apt: 4 | packages: 5 | - libcurl4-openssl-dev 6 | - libelf-dev 7 | - libdw-dev 8 | - binutils-dev 9 | 10 | language: rust 11 | rust: 12 | - beta 13 | - stable 14 | 15 | before_script: 16 | - | 17 | pip install 'travis-cargo<0.2' --user && 18 | export PATH=$HOME/.local/bin:$PATH && 19 | rustup component add clippy 20 | 21 | script: 22 | - | 23 | travis-cargo build && 24 | cargo clippy --all-targets --all-features -- -D warnings && 25 | travis-cargo test && 26 | travis-cargo bench && 27 | travis-cargo --only stable doc 28 | 29 | after_success: 30 | - travis-cargo --only stable doc-upload 31 | - travis-cargo coveralls --no-sudo --verify 32 | 33 | env: 34 | global: 35 | - secure: 
"ySW1tffk3YRjoP5oi7yHgZyu/yiEPVppLK943c2csym05Scb8dLfgJKWY45kKpddf3hC8ZVVSwVIIuW7Jl6cnlAXv87awGvq/GG4HTaRbXeE7tNKBbyC8Pb2M7I1bICCqEV4EM/BxYFx/AB+BEfmcKlvi1by2fjFlKSeCGCnozZFjRJ6ai6a+IpP5T79IlfcAJytok6jFPkpSBez70TXQ02gCBhaLornS5Tw/X4RweQy9rVesw2kXMiIwg7DvEKgjviFEEKogKg/j5h9ik0ZOUk90jdDxF9glVZigOcAuXy3kAKQvJM9V5egLlCXAiz1nVrA5wF5hiYcqPCKB0VGv866qgnx7T+s2dqQqOFLbqYNB80m/REMfjV0pgpjh6O3+WQho1UDf5nLdPLTmuDt/F8BK9VUH/WQKX8yACxFKCWAu3zh/h/e8zWn0FYEIRBW36aQC3kPmR/nXxOwHCGzqqcBANiKKK38PInV4hXvyhWf+HRAO/PgQ3MxKxLjc4XIgNsXd670cbi7HR6eBFtTHFw0nQUrvP0V41XrhNqPcVWJMLzc0XyWODthzAjU44ubm32kDVMFImQBuYNfwBgXmIyCXhc8O+DlYbZURrD8Eih22t90u1bLR/Box1y44THJDRqUrSrDvgup1bP3/hIkzBYOmEh4MGh5BCtnO4dThE8=" 36 | 37 | notifications: 38 | email: false 39 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | 2 | ## 0.2.0 (2020-04-11) 3 | 4 | * Update dependencies 5 | * **BREAKING:** `StochasticGradientDescent::seed()` takes an `u64` now 6 | 7 | 8 | ## 0.1.0 (2016-06-06) 9 | 10 | 11 | #### Breaking Changes 12 | 13 | * **minimizer:** Add stochastic gradient descent ([b761a5db](https://github.com/b52/optimization-rust/commit/b761a5db504130df4219f369ae73dcacfd50d448), breaks [#](https://github.com/b52/optimization-rust/issues/)) 14 | 15 | #### Features 16 | 17 | * **minimizer:** Add stochastic gradient descent ([b761a5db](https://github.com/b52/optimization-rust/commit/b761a5db504130df4219f369ae73dcacfd50d448), breaks [#](https://github.com/b52/optimization-rust/issues/)) 18 | 19 | 20 | 21 | 22 | ## 0.0.1 (2016-04-26) 23 | 24 | * **Birth!** 25 | -------------------------------------------------------------------------------- /Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 
3 | [[package]] 4 | name = "aho-corasick" 5 | version = "0.7.10" 6 | source = "registry+https://github.com/rust-lang/crates.io-index" 7 | dependencies = [ 8 | "memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)", 9 | ] 10 | 11 | [[package]] 12 | name = "atty" 13 | version = "0.2.14" 14 | source = "registry+https://github.com/rust-lang/crates.io-index" 15 | dependencies = [ 16 | "hermit-abi 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", 17 | "libc 0.2.68 (registry+https://github.com/rust-lang/crates.io-index)", 18 | "winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)", 19 | ] 20 | 21 | [[package]] 22 | name = "cfg-if" 23 | version = "0.1.10" 24 | source = "registry+https://github.com/rust-lang/crates.io-index" 25 | 26 | [[package]] 27 | name = "env_logger" 28 | version = "0.7.1" 29 | source = "registry+https://github.com/rust-lang/crates.io-index" 30 | dependencies = [ 31 | "atty 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)", 32 | "humantime 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", 33 | "log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)", 34 | "regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)", 35 | "termcolor 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)", 36 | ] 37 | 38 | [[package]] 39 | name = "getrandom" 40 | version = "0.1.14" 41 | source = "registry+https://github.com/rust-lang/crates.io-index" 42 | dependencies = [ 43 | "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", 44 | "libc 0.2.68 (registry+https://github.com/rust-lang/crates.io-index)", 45 | "wasi 0.9.0+wasi-snapshot-preview1 (registry+https://github.com/rust-lang/crates.io-index)", 46 | ] 47 | 48 | [[package]] 49 | name = "hermit-abi" 50 | version = "0.1.10" 51 | source = "registry+https://github.com/rust-lang/crates.io-index" 52 | dependencies = [ 53 | "libc 0.2.68 (registry+https://github.com/rust-lang/crates.io-index)", 54 | ] 55 | 
56 | [[package]] 57 | name = "humantime" 58 | version = "1.3.0" 59 | source = "registry+https://github.com/rust-lang/crates.io-index" 60 | dependencies = [ 61 | "quick-error 1.2.3 (registry+https://github.com/rust-lang/crates.io-index)", 62 | ] 63 | 64 | [[package]] 65 | name = "lazy_static" 66 | version = "1.4.0" 67 | source = "registry+https://github.com/rust-lang/crates.io-index" 68 | 69 | [[package]] 70 | name = "libc" 71 | version = "0.2.68" 72 | source = "registry+https://github.com/rust-lang/crates.io-index" 73 | 74 | [[package]] 75 | name = "log" 76 | version = "0.4.8" 77 | source = "registry+https://github.com/rust-lang/crates.io-index" 78 | dependencies = [ 79 | "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", 80 | ] 81 | 82 | [[package]] 83 | name = "memchr" 84 | version = "2.3.3" 85 | source = "registry+https://github.com/rust-lang/crates.io-index" 86 | 87 | [[package]] 88 | name = "optimization" 89 | version = "0.2.0" 90 | dependencies = [ 91 | "env_logger 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)", 92 | "log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)", 93 | "rand 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)", 94 | "rand_distr 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)", 95 | "rand_pcg 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", 96 | ] 97 | 98 | [[package]] 99 | name = "ppv-lite86" 100 | version = "0.2.6" 101 | source = "registry+https://github.com/rust-lang/crates.io-index" 102 | 103 | [[package]] 104 | name = "quick-error" 105 | version = "1.2.3" 106 | source = "registry+https://github.com/rust-lang/crates.io-index" 107 | 108 | [[package]] 109 | name = "rand" 110 | version = "0.7.3" 111 | source = "registry+https://github.com/rust-lang/crates.io-index" 112 | dependencies = [ 113 | "getrandom 0.1.14 (registry+https://github.com/rust-lang/crates.io-index)", 114 | "libc 0.2.68 (registry+https://github.com/rust-lang/crates.io-index)", 
115 | "rand_chacha 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)", 116 | "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", 117 | "rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", 118 | ] 119 | 120 | [[package]] 121 | name = "rand_chacha" 122 | version = "0.2.2" 123 | source = "registry+https://github.com/rust-lang/crates.io-index" 124 | dependencies = [ 125 | "ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)", 126 | "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", 127 | ] 128 | 129 | [[package]] 130 | name = "rand_core" 131 | version = "0.5.1" 132 | source = "registry+https://github.com/rust-lang/crates.io-index" 133 | dependencies = [ 134 | "getrandom 0.1.14 (registry+https://github.com/rust-lang/crates.io-index)", 135 | ] 136 | 137 | [[package]] 138 | name = "rand_distr" 139 | version = "0.2.2" 140 | source = "registry+https://github.com/rust-lang/crates.io-index" 141 | dependencies = [ 142 | "rand 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)", 143 | ] 144 | 145 | [[package]] 146 | name = "rand_hc" 147 | version = "0.2.0" 148 | source = "registry+https://github.com/rust-lang/crates.io-index" 149 | dependencies = [ 150 | "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", 151 | ] 152 | 153 | [[package]] 154 | name = "rand_pcg" 155 | version = "0.2.1" 156 | source = "registry+https://github.com/rust-lang/crates.io-index" 157 | dependencies = [ 158 | "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", 159 | ] 160 | 161 | [[package]] 162 | name = "regex" 163 | version = "1.3.6" 164 | source = "registry+https://github.com/rust-lang/crates.io-index" 165 | dependencies = [ 166 | "aho-corasick 0.7.10 (registry+https://github.com/rust-lang/crates.io-index)", 167 | "memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)", 168 | "regex-syntax 0.6.17 
(registry+https://github.com/rust-lang/crates.io-index)", 169 | "thread_local 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", 170 | ] 171 | 172 | [[package]] 173 | name = "regex-syntax" 174 | version = "0.6.17" 175 | source = "registry+https://github.com/rust-lang/crates.io-index" 176 | 177 | [[package]] 178 | name = "termcolor" 179 | version = "1.1.0" 180 | source = "registry+https://github.com/rust-lang/crates.io-index" 181 | dependencies = [ 182 | "winapi-util 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)", 183 | ] 184 | 185 | [[package]] 186 | name = "thread_local" 187 | version = "1.0.1" 188 | source = "registry+https://github.com/rust-lang/crates.io-index" 189 | dependencies = [ 190 | "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", 191 | ] 192 | 193 | [[package]] 194 | name = "wasi" 195 | version = "0.9.0+wasi-snapshot-preview1" 196 | source = "registry+https://github.com/rust-lang/crates.io-index" 197 | 198 | [[package]] 199 | name = "winapi" 200 | version = "0.3.8" 201 | source = "registry+https://github.com/rust-lang/crates.io-index" 202 | dependencies = [ 203 | "winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", 204 | "winapi-x86_64-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", 205 | ] 206 | 207 | [[package]] 208 | name = "winapi-i686-pc-windows-gnu" 209 | version = "0.4.0" 210 | source = "registry+https://github.com/rust-lang/crates.io-index" 211 | 212 | [[package]] 213 | name = "winapi-util" 214 | version = "0.1.4" 215 | source = "registry+https://github.com/rust-lang/crates.io-index" 216 | dependencies = [ 217 | "winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)", 218 | ] 219 | 220 | [[package]] 221 | name = "winapi-x86_64-pc-windows-gnu" 222 | version = "0.4.0" 223 | source = "registry+https://github.com/rust-lang/crates.io-index" 224 | 225 | [metadata] 226 | "checksum aho-corasick 0.7.10 
(registry+https://github.com/rust-lang/crates.io-index)" = "8716408b8bc624ed7f65d223ddb9ac2d044c0547b6fa4b0d554f3a9540496ada" 227 | "checksum atty 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)" = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" 228 | "checksum cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)" = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822" 229 | "checksum env_logger 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "44533bbbb3bb3c1fa17d9f2e4e38bbbaf8396ba82193c4cb1b6445d711445d36" 230 | "checksum getrandom 0.1.14 (registry+https://github.com/rust-lang/crates.io-index)" = "7abc8dd8451921606d809ba32e95b6111925cd2906060d2dcc29c070220503eb" 231 | "checksum hermit-abi 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)" = "725cf19794cf90aa94e65050cb4191ff5d8fa87a498383774c47b332e3af952e" 232 | "checksum humantime 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "df004cfca50ef23c36850aaaa59ad52cc70d0e90243c3c7737a4dd32dc7a3c4f" 233 | "checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" 234 | "checksum libc 0.2.68 (registry+https://github.com/rust-lang/crates.io-index)" = "dea0c0405123bba743ee3f91f49b1c7cfb684eef0da0a50110f758ccf24cdff0" 235 | "checksum log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)" = "14b6052be84e6b71ab17edffc2eeabf5c2c3ae1fdb464aae35ac50c67a44e1f7" 236 | "checksum memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3728d817d99e5ac407411fa471ff9800a778d88a24685968b36824eaf4bee400" 237 | "checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b" 238 | "checksum quick-error 1.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = 
"a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" 239 | "checksum rand 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)" = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" 240 | "checksum rand_chacha 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402" 241 | "checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" 242 | "checksum rand_distr 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "96977acbdd3a6576fb1d27391900035bf3863d4a16422973a409b488cf29ffb2" 243 | "checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" 244 | "checksum rand_pcg 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "16abd0c1b639e9eb4d7c50c0b8100b0d0f849be2349829c740fe8e6eb4816429" 245 | "checksum regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "7f6946991529684867e47d86474e3a6d0c0ab9b82d5821e314b1ede31fa3a4b3" 246 | "checksum regex-syntax 0.6.17 (registry+https://github.com/rust-lang/crates.io-index)" = "7fe5bd57d1d7414c6b5ed48563a2c855d995ff777729dcd91c369ec7fea395ae" 247 | "checksum termcolor 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "bb6bfa289a4d7c5766392812c0a1f4c1ba45afa1ad47803c11e1f407d846d75f" 248 | "checksum thread_local 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "d40c6d1b69745a6ec6fb1ca717914848da4b44ae29d9b3080cbee91d72a69b14" 249 | "checksum wasi 0.9.0+wasi-snapshot-preview1 (registry+https://github.com/rust-lang/crates.io-index)" = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" 250 | "checksum winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)" = "8093091eeb260906a183e6ae1abdba2ef5ef2257a21801128899c3fc699229c6" 251 | "checksum 
winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" 252 | "checksum winapi-util 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "fa515c5163a99cc82bab70fd3bfdd36d827be85de63737b40fcef2ce084a436e" 253 | "checksum winapi-x86_64-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" 254 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "optimization" 3 | version = "0.2.0" 4 | authors = ["Oliver Mader "] 5 | license = "MIT" 6 | description = "Collection of optimization algorithms" 7 | repository = "https://github.com/b52/optimization-rust" 8 | readme = "README.md" 9 | keywords = ["optimization", "minimization", "numeric"] 10 | 11 | [dependencies] 12 | log = "0.4" 13 | rand = "0.7" 14 | rand_distr = "0.2" 15 | rand_pcg = "0.2" 16 | 17 | [dev-dependencies] 18 | env_logger = "0.7" 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016 Oliver Mader 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 
12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # optimization [![Build Status](https://travis-ci.org/b52/optimization-rust.svg?branch=master)](https://travis-ci.org/b52/optimization-rust) [![Coverage Status](https://coveralls.io/repos/b52/optimization-rust/badge.svg?branch=master&service=github)](https://coveralls.io/github/b52/optimization-rust?branch=master) [![crates.io version](http://meritbadge.herokuapp.com/optimization)](https://crates.io/crates/optimization) 2 | 3 | Collection of optimization algorithms and strategies. 
4 | 5 | ## Usage 6 | 7 | ```rust 8 | extern crate optimization; 9 | 10 | use optimization::{Minimizer, GradientDescent, NumericalDifferentiation, Func}; 11 | 12 | // numeric version of the Rosenbrock function 13 | let function = NumericalDifferentiation::new(Func(|x: &[f64]| { 14 | (1.0 - x[0]).powi(2) + 100.0*(x[1] - x[0].powi(2)).powi(2) 15 | })); 16 | 17 | // we use a simple gradient descent scheme 18 | let minimizer = GradientDescent::new(); 19 | 20 | // perform the actual minimization, depending on the task this may 21 | // take some time, it may be useful to install a log sink to see 22 | // what's going on 23 | let solution = minimizer.minimize(&function, vec![-3.0, -4.0]); 24 | 25 | println!("Found solution for Rosenbrock function at f({:?}) = {:?}", 26 | solution.position, solution.value); 27 | ``` 28 | 29 | ## Installation 30 | 31 | Simply add it as a `Cargo` dependency: 32 | 33 | ```toml 34 | [dependencies] 35 | optimization = "*" 36 | ``` 37 | 38 | ## Documentation 39 | 40 | For exhaustive documentation, head over to the [API docs]. 41 | 42 | ## Development 43 | 44 | Simply download this crate, add your stuff, write some tests and create a pull request. Pretty simple! :) 45 | 46 | ```shell 47 | $ cargo test 48 | $ cargo clippy 49 | ``` 50 | 51 | ## License 52 | 53 | This software is licensed under the terms of the MIT license. Please see the 54 | [LICENSE](LICENSE) for full details. 55 | 56 | [API docs]: https://docs.rs/optimization 57 | -------------------------------------------------------------------------------- /examples/line_fitting.rs: -------------------------------------------------------------------------------- 1 | //! Illustration of fitting a linear regression model using stochastic gradient descent 2 | //! given a few noisy sample observations. 3 | //! 4 | //! Run with `cargo run --example line_fitting`.
5 | 6 | 7 | #![allow(clippy::many_single_char_names)] 8 | 9 | 10 | extern crate env_logger; 11 | extern crate rand; 12 | extern crate rand_distr; 13 | 14 | extern crate optimization; 15 | 16 | 17 | use std::f64::consts::PI; 18 | use rand::prelude::*; 19 | use rand_distr::StandardNormal; 20 | 21 | use optimization::*; 22 | 23 | 24 | fn main() { 25 | env_logger::init(); 26 | 27 | // the true coefficients of our linear model 28 | let true_coefficients = &[13.37, -4.2, PI]; 29 | 30 | println!("Trying to approximate the true linear regression coefficients {:?} using SGD \ 31 | given 100 noisy samples", true_coefficients); 32 | 33 | let noisy_observations = (0..100).map(|_| { 34 | let x = random::<[f64; 2]>(); 35 | let noise: f64 = thread_rng().sample(StandardNormal); 36 | let y = linear_regression(true_coefficients, &x) + noise; 37 | 38 | (x.to_vec(), y) 39 | }).collect(); 40 | 41 | 42 | // the actual function we want to minimize, which in our case corresponds to the 43 | // sum squared error 44 | let sse = SSE { 45 | observations: noisy_observations 46 | }; 47 | 48 | let solution = StochasticGradientDescent::new() 49 | .max_iterations(Some(1000)) 50 | .minimize(&sse, vec![1.0; true_coefficients.len()]); 51 | 52 | println!("Found coefficients {:?} with a SSE = {:?}", solution.position, solution.value); 53 | } 54 | 55 | 56 | // the sum squared error measure we want to minimize over a set of observations 57 | struct SSE { 58 | observations: Vec<(Vec, f64)> 59 | } 60 | 61 | impl Summation for SSE { 62 | fn terms(&self) -> usize { 63 | self.observations.len() 64 | } 65 | 66 | fn term_value(&self, w: &[f64], i: usize) -> f64 { 67 | let (ref x, y) = self.observations[i]; 68 | 69 | 0.5 * (y - linear_regression(w, x)).powi(2) 70 | } 71 | } 72 | 73 | impl Summation1 for SSE { 74 | fn term_gradient(&self, w: &[f64], i: usize) -> Vec { 75 | let (ref x, y) = self.observations[i]; 76 | 77 | let e = y - linear_regression(w, x); 78 | 79 | let mut gradient = vec![e * -1.0]; 80 | 81 | 
for x in x { 82 | gradient.push(e * -x); 83 | } 84 | 85 | gradient 86 | } 87 | } 88 | 89 | 90 | // a simple linear regression model, i.e., f(x) = w_0 + w_1*x_1 + w_2*x_2 + ... 91 | fn linear_regression(w: &[f64], x: &[f64]) -> f64 { 92 | let mut y = w[0]; 93 | 94 | for (w, x) in w[1..].iter().zip(x) { 95 | y += w * x; 96 | } 97 | 98 | y 99 | } 100 | -------------------------------------------------------------------------------- /examples/minimal.rs: -------------------------------------------------------------------------------- 1 | //! Minimizing the Rosenbrock function using Gradient Descent by applying 2 | //! numerical differentiation. 3 | //! 4 | //! Run with `cargo run --example minimal`. 5 | 6 | 7 | extern crate env_logger; 8 | 9 | extern crate optimization; 10 | 11 | 12 | use optimization::{Minimizer, GradientDescent, NumericalDifferentiation, Func}; 13 | 14 | 15 | pub fn main() { 16 | env_logger::init(); 17 | 18 | // the target function we want to minimize, for educational reasons we use 19 | // the Rosenbrock function 20 | let function = NumericalDifferentiation::new(Func(|x: &[f64]| { 21 | (1.0 - x[0]).powi(2) + 100.0*(x[1] - x[0].powi(2)).powi(2) 22 | })); 23 | 24 | // we use a simple gradient descent scheme 25 | let minimizer = GradientDescent::new(); 26 | 27 | // perform the actual minimization, depending on the task this may take some time 28 | // it may be useful to install a log sink to see what's going on 29 | let solution = minimizer.minimize(&function, vec![-3.0, -4.0]); 30 | 31 | println!("Found solution for Rosenbrock function at f({:?}) = {:?}", 32 | solution.position, solution.value); 33 | } 34 | -------------------------------------------------------------------------------- /src/gd.rs: -------------------------------------------------------------------------------- 1 | use log::Level::Trace; 2 | 3 | use types::{Function1, Minimizer, Solution}; 4 | use line_search::{LineSearch, ArmijoLineSearch}; 5 | use utils::is_saddle_point; 6 | 7 | 8 |
/// A simple Gradient Descent optimizer. 9 | #[derive(Default)] 10 | pub struct GradientDescent { 11 | line_search: T, 12 | gradient_tolerance: f64, 13 | max_iterations: Option 14 | } 15 | 16 | impl GradientDescent { 17 | /// Creates a new `GradientDescent` optimizer using the following defaults: 18 | /// 19 | /// - **`line_search`** = `ArmijoLineSearch(0.5, 1.0, 0.5)` 20 | /// - **`gradient_tolerance`** = `1e-4` 21 | /// - **`max_iterations`** = `None` 22 | pub fn new() -> GradientDescent { 23 | GradientDescent { 24 | line_search: ArmijoLineSearch::new(0.5, 1.0, 0.5), 25 | gradient_tolerance: 1.0e-4, 26 | max_iterations: None 27 | } 28 | } 29 | } 30 | 31 | impl GradientDescent { 32 | /// Specifies the line search method to use. 33 | pub fn line_search(self, line_search: S) -> GradientDescent { 34 | GradientDescent { 35 | line_search, 36 | gradient_tolerance: self.gradient_tolerance, 37 | max_iterations: self.max_iterations 38 | } 39 | } 40 | 41 | /// Adjusts the gradient tolerance which is used as abort criterion to decide 42 | /// whether we reached a plateau. 43 | pub fn gradient_tolerance(mut self, gradient_tolerance: f64) -> Self { 44 | assert!(gradient_tolerance > 0.0); 45 | 46 | self.gradient_tolerance = gradient_tolerance; 47 | self 48 | } 49 | 50 | /// Adjusts the number of maximally run iterations. A value of `None` instructs the 51 | /// optimizer to ignore the number of iterations.
52 | pub fn max_iterations(mut self, max_iterations: Option) -> Self { 53 | assert!(max_iterations.map_or(true, |max_iterations| max_iterations > 0)); 54 | 55 | self.max_iterations = max_iterations; 56 | self 57 | } 58 | } 59 | 60 | impl Minimizer for GradientDescent 61 | { 62 | type Solution = Solution; 63 | 64 | fn minimize(&self, function: &F, initial_position: Vec) -> Solution { 65 | info!("Starting gradient descent minimization: gradient_tolerance = {:?}, 66 | max_iterations = {:?}, line_search = {:?}", 67 | self.gradient_tolerance, self.max_iterations, self.line_search); 68 | 69 | let mut position = initial_position; 70 | let mut value = function.value(&position); 71 | 72 | if log_enabled!(Trace) { 73 | info!("Starting with y = {:?} for x = {:?}", value, position); 74 | } else { 75 | info!("Starting with y = {:?}", value); 76 | } 77 | 78 | let mut iteration = 0; 79 | 80 | loop { 81 | let gradient = function.gradient(&position); 82 | 83 | if is_saddle_point(&gradient, self.gradient_tolerance) { 84 | info!("Gradient to small, stopping optimization"); 85 | 86 | return Solution::new(position, value); 87 | } 88 | 89 | let direction: Vec<_> = gradient.into_iter().map(|g| -g).collect(); 90 | 91 | let iter_xs = self.line_search.search(function, &position, &direction); 92 | 93 | position = iter_xs; 94 | value = function.value(&position); 95 | 96 | iteration += 1; 97 | 98 | if log_enabled!(Trace) { 99 | debug!("Iteration {:6}: y = {:?}, x = {:?}", iteration, value, position); 100 | } else { 101 | debug!("Iteration {:6}: y = {:?}", iteration, value); 102 | } 103 | 104 | let reached_max_iterations = self.max_iterations.map_or(false, 105 | |max_iterations| iteration == max_iterations); 106 | 107 | if reached_max_iterations { 108 | info!("Reached maximal number of iterations, stopping optimization"); 109 | 110 | return Solution::new(position, value); 111 | } 112 | } 113 | } 114 | } 115 | 116 | 117 | #[cfg(test)] 118 | mod tests { 119 | use problems::{Sphere, Rosenbrock}; 
120 | 121 | use super::GradientDescent; 122 | 123 | test_minimizer!{GradientDescent::new(), 124 | sphere => Sphere::default(), 125 | rosenbrock => Rosenbrock::default()} 126 | } 127 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! Collection of various optimization algorithms and strategies. 2 | //! 3 | //! # Building Blocks 4 | //! 5 | //! Each central primitive is specified by a trait: 6 | //! 7 | //! - **`Function`** - Specifies a function that can be minimized 8 | //! - **`Function1`** - Extends a `Function` by its first derivative 9 | //! - **`Summation`** - Represents a summation of functions, exploited, e.g., by SGD 10 | //! - **`Summation1`** - Analogous to `Function` and `Function1` but for `Summation` 11 | //! - **`Minimizer`** - A minimization algorithm 12 | //! - **`Evaluation`** - A function evaluation `f(x) = y` that is returned by a `Minimizer` 13 | //! - **`Func`** - A new-type wrapper for the `Function` trait 14 | //! - **`NumericalDifferentiation`** - Provides numerical differentiation for arbitrary `Function`s 15 | //! 16 | //! # Algorithms 17 | //! 18 | //! Currently, the following algorithms are implemented. This list is not final and is being 19 | //! expanded over time. 20 | //! 21 | //! - **`GradientDescent`** - Iterative gradient descent minimization, supporting various line 22 | //! search methods: 23 | //! - *`FixedStepWidth`* - No line search is performed, but a fixed step width is used 24 | //! - *`ExactLineSearch`* - Exhaustive line search over a set of step widths 25 | //! - *`ArmijoLineSearch`* - Backtracking line search using the Armijo rule as stopping 26 | //! criterion 27 | //! - **`StochasticGradientDescent`** - Iterative stochastic gradient descent minimization, 28 |
currently using a fixed step width 29 | 30 | 31 | #[macro_use] 32 | extern crate log; 33 | extern crate rand; 34 | extern crate rand_pcg; 35 | 36 | 37 | #[macro_use] 38 | pub mod problems; 39 | 40 | mod types; 41 | mod utils; 42 | mod numeric; 43 | mod line_search; 44 | mod gd; 45 | mod sgd; 46 | 47 | 48 | pub use types::{Function, Function1, Func, Minimizer, Evaluation, Summation, Summation1}; 49 | pub use numeric::NumericalDifferentiation; 50 | pub use line_search::{LineSearch, FixedStepWidth, ExactLineSearch, ArmijoLineSearch}; 51 | pub use gd::GradientDescent; 52 | pub use sgd::StochasticGradientDescent; 53 | -------------------------------------------------------------------------------- /src/line_search.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Debug; 2 | use std::ops::Add; 3 | 4 | use types::{Function, Function1}; 5 | 6 | 7 | /// Define a line search method, i.e., choosing an appropriate step width. 8 | pub trait LineSearch: Debug { 9 | /// Performs the actual line search given the current `position` `x` and a `direction` to go to. 10 | /// Returns the new position. 11 | fn search(&self, function: &F, initial_position: &[f64], direction: &[f64]) -> Vec 12 | where F: Function1; 13 | } 14 | 15 | 16 | /// Uses a fixed step width `γ` in each iteration instead of performing an actual line search. 17 | #[derive(Debug, Copy, Clone)] 18 | pub struct FixedStepWidth { 19 | fixed_step_width: f64 20 | } 21 | 22 | impl FixedStepWidth { 23 | /// Creates a new `FixedStepWidth` given the static step width. 
24 | pub fn new(fixed_step_width: f64) -> FixedStepWidth { 25 | assert!(fixed_step_width > 0.0 && fixed_step_width.is_finite(), 26 | "fixed_step_width must be greater than 0 and finite"); 27 | 28 | FixedStepWidth { 29 | fixed_step_width 30 | } 31 | } 32 | } 33 | 34 | impl LineSearch for FixedStepWidth { 35 | fn search(&self, _function: &F, initial_position: &[f64], direction: &[f64]) -> Vec 36 | where F: Function 37 | { 38 | initial_position.iter().cloned().zip(direction).map(|(x, d)| { 39 | x + self.fixed_step_width * d 40 | }).collect() 41 | } 42 | } 43 | 44 | 45 | /// Brute-force line search minimizing the objective function over a set of 46 | /// step width candidates, also known as exact line search. 47 | #[derive(Debug, Copy, Clone)] 48 | pub struct ExactLineSearch { 49 | start_step_width: f64, 50 | stop_step_width: f64, 51 | increase_factor: f64 52 | } 53 | 54 | impl ExactLineSearch { 55 | /// Creates a new `ExactLineSearch` given the `start_step_width`, the `stop_step_width` 56 | /// and the `increase_factor`. The set of evaluated step widths `γ` is specified as 57 | /// `{ γ | γ = start_step_width · increase_factorⁱ, i ∈ N, γ <= stop_step_width }`, 58 | /// assuming that `start_step_width` < `stop_step_width` and `increase_factor` > 1. 
59 | pub fn new(start_step_width: f64, stop_step_width: f64, increase_factor: f64) -> 60 | ExactLineSearch 61 | { 62 | assert!(start_step_width > 0.0 && start_step_width.is_finite(), 63 | "start_step_width must be greater than 0 and finite"); 64 | assert!(stop_step_width > start_step_width && stop_step_width.is_finite(), 65 | "stop_step_width must be greater than start_step_width"); 66 | assert!(increase_factor > 1.0 && increase_factor.is_finite(), 67 | "increase_factor must be greater than 1 and finite"); 68 | 69 | ExactLineSearch { 70 | start_step_width, 71 | stop_step_width, 72 | increase_factor 73 | } 74 | } 75 | } 76 | 77 | impl LineSearch for ExactLineSearch { 78 | fn search(&self, function: &F, initial_position: &[f64], direction: &[f64]) -> Vec 79 | where F: Function1 80 | { 81 | let mut min_position = initial_position.to_vec(); 82 | let mut min_value = function.value(initial_position); 83 | 84 | let mut step_width = self.start_step_width; 85 | 86 | loop { 87 | let position: Vec<_> = initial_position.iter().cloned().zip(direction).map(|(x, d)| { 88 | x + step_width * d 89 | }).collect(); 90 | let value = function.value(&position); 91 | 92 | if value < min_value { 93 | min_position = position; 94 | min_value = value; 95 | } 96 | 97 | step_width *= self.increase_factor; 98 | 99 | if step_width >= self.stop_step_width { 100 | break; 101 | } 102 | } 103 | 104 | min_position 105 | } 106 | } 107 | 108 | 109 | /// Backtracking line search evaluating the Armijo rule at each step width. 110 | #[derive(Debug, Copy, Clone)] 111 | pub struct ArmijoLineSearch { 112 | control_parameter: f64, 113 | initial_step_width: f64, 114 | decay_factor: f64 115 | } 116 | 117 | impl ArmijoLineSearch { 118 | /// Creates a new `ArmijoLineSearch` given the `control_parameter` ∈ (0, 1), the 119 | /// `initial_step_width` > 0 and the `decay_factor` ∈ (0, 1). 120 | /// 121 | /// Armijo used in his paper the values 0.5, 1.0 and 0.5, respectively. 
122 | pub fn new(control_parameter: f64, initial_step_width: f64, decay_factor: f64) -> 123 | ArmijoLineSearch 124 | { 125 | assert!(control_parameter > 0.0 && control_parameter < 1.0, 126 | "control_parameter must be in range (0, 1)"); 127 | assert!(initial_step_width > 0.0 && initial_step_width.is_finite(), 128 | "initial_step_width must be > 0 and finite"); 129 | assert!(decay_factor > 0.0 && decay_factor < 1.0, "decay_factor must be in range (0, 1)"); 130 | 131 | ArmijoLineSearch { 132 | control_parameter, 133 | initial_step_width, 134 | decay_factor 135 | } 136 | } 137 | } 138 | 139 | impl LineSearch for ArmijoLineSearch { 140 | fn search(&self, function: &F, initial_position: &[f64], direction: &[f64]) -> Vec 141 | where F: Function1 142 | { 143 | let initial_value = function.value(initial_position); 144 | let gradient = function.gradient(initial_position); 145 | 146 | let m = gradient.iter().zip(direction).map(|(g, d)| g * d).fold(0.0, Add::add); 147 | let t = -self.control_parameter * m; 148 | 149 | assert!(t > 0.0); 150 | 151 | let mut step_width = self.initial_step_width; 152 | 153 | loop { 154 | let position: Vec<_> = initial_position.iter().cloned().zip(direction).map(|(x, d)| { 155 | x + step_width * d 156 | }).collect(); 157 | let value = function.value(&position); 158 | 159 | if value <= initial_value - step_width * t { 160 | return position; 161 | } 162 | 163 | step_width *= self.decay_factor; 164 | } 165 | } 166 | } 167 | -------------------------------------------------------------------------------- /src/numeric.rs: -------------------------------------------------------------------------------- 1 | use std::f64::EPSILON; 2 | 3 | use problems::Problem; 4 | use types::{Function, Function1}; 5 | 6 | 7 | /// Wraps a function for which to provide numeric differentiation. 8 | /// 9 | /// Uses simple one step forward finite difference with step width `h = √εx`. 
10 | /// 11 | /// # Examples 12 | /// 13 | /// ``` 14 | /// # use self::optimization::*; 15 | /// let square = NumericalDifferentiation::new(Func(|x: &[f64]| { 16 | /// x[0] * x[0] 17 | /// })); 18 | /// 19 | /// assert!(square.gradient(&[0.0])[0] < 1.0e-3); 20 | /// assert!(square.gradient(&[1.0])[0] > 1.0); 21 | /// assert!(square.gradient(&[-1.0])[0] < 1.0); 22 | /// ``` 23 | pub struct NumericalDifferentiation { 24 | function: F 25 | } 26 | 27 | impl NumericalDifferentiation { 28 | /// Creates a new differentiable function by using the supplied `function` in 29 | /// combination with numeric differentiation to find the derivatives. 30 | pub fn new(function: F) -> Self { 31 | NumericalDifferentiation { 32 | function 33 | } 34 | } 35 | } 36 | 37 | impl Function for NumericalDifferentiation { 38 | fn value(&self, position: &[f64]) -> f64 { 39 | self.function.value(position) 40 | } 41 | } 42 | 43 | impl Function1 for NumericalDifferentiation { 44 | fn gradient(&self, position: &[f64]) -> Vec { 45 | let mut x: Vec<_> = position.to_vec(); 46 | 47 | let current = self.value(&x); 48 | 49 | position.iter().cloned().enumerate().map(|(i, x_i)| { 50 | let h = if x_i == 0.0 { 51 | EPSILON * 1.0e10 52 | } else { 53 | (EPSILON * x_i.abs()).sqrt() 54 | }; 55 | 56 | assert!(h.is_finite()); 57 | 58 | x[i] = x_i + h; 59 | 60 | let forward = self.function.value(&x); 61 | 62 | x[i] = x_i; 63 | 64 | let d_i = (forward - current) / h; 65 | 66 | assert!(d_i.is_finite()); 67 | 68 | d_i 69 | }).collect() 70 | } 71 | } 72 | 73 | impl Default for NumericalDifferentiation { 74 | fn default() -> Self { 75 | NumericalDifferentiation::new(F::default()) 76 | } 77 | } 78 | 79 | impl Problem for NumericalDifferentiation { 80 | fn dimensions(&self) -> usize { 81 | self.function.dimensions() 82 | } 83 | 84 | fn domain(&self) -> Vec<(f64, f64)> { 85 | self.function.domain() 86 | } 87 | 88 | fn minimum(&self) -> (Vec, f64) { 89 | self.function.minimum() 90 | } 91 | 92 | fn random_start(&self) -> Vec 
{ 93 | self.function.random_start() 94 | } 95 | } 96 | 97 | 98 | #[cfg(test)] 99 | mod tests { 100 | use types::Function1; 101 | use problems::{Problem, Sphere, Rosenbrock}; 102 | use utils::are_close; 103 | use gd::GradientDescent; 104 | 105 | use super::NumericalDifferentiation; 106 | 107 | #[test] 108 | fn test_accuracy() { 109 | //let a = Sphere::default(); 110 | let b = Rosenbrock::default(); 111 | 112 | // FIXME: How to iterate over different problems? 113 | let problems = vec![b]; 114 | 115 | for analytical_problem in problems { 116 | let numerical_problem = NumericalDifferentiation::new(analytical_problem); 117 | 118 | for _ in 0..1000 { 119 | let position = analytical_problem.random_start(); 120 | 121 | let analytical_gradient = analytical_problem.gradient(&position); 122 | let numerical_gradient = numerical_problem.gradient(&position); 123 | 124 | assert_eq!(analytical_gradient.len(), numerical_gradient.len()); 125 | 126 | assert!(analytical_gradient.into_iter().zip(numerical_gradient).all(|(a, n)| 127 | a.is_finite() && n.is_finite() && are_close(a, n, 1.0e-3) 128 | )); 129 | } 130 | } 131 | } 132 | 133 | test_minimizer!{GradientDescent::new(), 134 | test_gd_sphere => NumericalDifferentiation::new(Sphere::default()), 135 | test_gd_rosenbrock => NumericalDifferentiation::new(Rosenbrock::default())} 136 | } 137 | -------------------------------------------------------------------------------- /src/problems.rs: -------------------------------------------------------------------------------- 1 | //! Common optimization problems for testing purposes. 2 | //! 3 | //! Currently, the following [optimization test functions] are implemented. 4 | //! 5 | //! ## Bowl-Shaped 6 | //! 7 | //! * [`Sphere`](http://www.sfu.ca/~ssurjano/spheref.html) 8 | //! 9 | //! ## Valley-Shaped 10 | //! 11 | //! * [`Rosenbrock`](http://www.sfu.ca/~ssurjano/rosen.html) 12 | //! 13 | //! 
[optimization test functions]: http://www.sfu.ca/~ssurjano/optimization.html 14 | 15 | use rand::random; 16 | use std::f64::INFINITY; 17 | use std::ops::Add; 18 | 19 | use types::{Function, Function1}; 20 | 21 | 22 | /// Specifies a well known optimization problem. 23 | pub trait Problem: Function + Default { 24 | /// Returns the dimensionality of the input domain. 25 | fn dimensions(&self) -> usize; 26 | 27 | /// Returns the input domain of the function in terms of upper and lower, 28 | /// respectively, for each input dimension. 29 | fn domain(&self) -> Vec<(f64, f64)>; 30 | 31 | /// Returns the position as well as the value of the global minimum. 32 | fn minimum(&self) -> (Vec, f64); 33 | 34 | /// Generates a random and **feasible** position to start a minimization. 35 | fn random_start(&self) -> Vec; 36 | 37 | /// Tests whether the supplied position is legal for this function. 38 | fn is_legal_position(&self, position: &[f64]) -> bool { 39 | position.len() == self.dimensions() && 40 | position.iter().zip(self.domain()).all(|(&x, (lower, upper))| { 41 | lower < x && x < upper 42 | }) 43 | } 44 | } 45 | 46 | 47 | macro_rules! 
define_problem { 48 | ( $name:ident: $this:ident, 49 | default: $def:expr, 50 | dimensions: $dims:expr, 51 | domain: $domain:expr, 52 | minimum: $miny:expr, 53 | at: $minx:expr, 54 | start: $start:expr, 55 | value: $x1:ident => $value:expr, 56 | gradient: $x2:ident => $gradient:expr ) => 57 | { 58 | impl Default for $name { 59 | fn default() -> Self { 60 | $def 61 | } 62 | } 63 | 64 | impl Function for $name { 65 | fn value(&$this, $x1: &[f64]) -> f64 { 66 | assert!($this.is_legal_position($x1)); 67 | 68 | $value 69 | } 70 | } 71 | 72 | impl Function1 for $name { 73 | fn gradient(&$this, $x2: &[f64]) -> Vec { 74 | assert!($this.is_legal_position($x2)); 75 | 76 | $gradient 77 | } 78 | } 79 | 80 | impl Problem for $name { 81 | fn dimensions(&$this) -> usize { 82 | $dims 83 | } 84 | 85 | fn domain(&$this) -> Vec<(f64, f64)> { 86 | $domain 87 | } 88 | 89 | fn minimum(&$this) -> (Vec, f64) { 90 | ($minx, $miny) 91 | } 92 | 93 | fn random_start(&$this) -> Vec { 94 | $start 95 | } 96 | } 97 | }; 98 | } 99 | 100 | 101 | /// n-dimensional Sphere function. 
102 | /// 103 | /// It is continuous, convex and unimodal: 104 | /// 105 | /// > f(x) = ∑ᵢ xᵢ² 106 | /// 107 | /// *Global minimum*: `f(0,...,0) = 0` 108 | #[derive(Debug, Copy, Clone)] 109 | pub struct Sphere { 110 | dimensions: usize 111 | } 112 | 113 | impl Sphere { 114 | pub fn new(dimensions: usize) -> Sphere { 115 | assert!(dimensions > 0, "dimensions must be larger than 1"); 116 | 117 | Sphere { 118 | dimensions 119 | } 120 | } 121 | } 122 | 123 | define_problem!{Sphere: self, 124 | default: Sphere::new(2), 125 | dimensions: self.dimensions, 126 | domain: (0..self.dimensions).map(|_| (-INFINITY, INFINITY)).collect(), 127 | minimum: 0.0, 128 | at: (0..self.dimensions).map(|_| 0.0).collect(), 129 | start: (0..self.dimensions).map(|_| random::() * 10.24 - 5.12).collect(), 130 | value: x => x.iter().map(|x| x.powi(2)).fold(0.0, Add::add), 131 | gradient: x => x.iter().map(|x| 2.0 * x).collect() 132 | } 133 | 134 | 135 | /// Two-dimensional Rosenbrock function. 136 | /// 137 | /// A non-convex function with its global minimum inside a long, narrow, parabolic 138 | /// shaped flat valley: 139 | /// 140 | /// > f(x, y) = (a - x)² + b (y - x²)² 141 | /// 142 | /// *Global minimum*: `f(a, a²) = 0` 143 | #[derive(Debug, Copy, Clone)] 144 | pub struct Rosenbrock { 145 | a: f64, 146 | b: f64 147 | } 148 | 149 | impl Rosenbrock { 150 | /// Creates a new `Rosenbrock` function given `a` and `b`, commonly definied 151 | /// with 1 and 100, respectively, which also corresponds to the `default`. 
152 | pub fn new(a: f64, b: f64) -> Rosenbrock { 153 | Rosenbrock { 154 | a, 155 | b 156 | } 157 | } 158 | } 159 | 160 | define_problem!{Rosenbrock: self, 161 | default: Rosenbrock::new(1.0, 100.0), 162 | dimensions: 2, 163 | domain: vec![(-INFINITY, INFINITY), (-INFINITY, INFINITY)], 164 | minimum: 0.0, 165 | at: vec![self.a, self.a * self.a], 166 | start: (0..2).map(|_| random::() * 4.096 - 2.048).collect(), 167 | value: x => (self.a - x[0]).powi(2) + self.b * (x[1] - x[0].powi(2)).powi(2), 168 | gradient: x => vec![-2.0 * self.a + 4.0 * self.b * x[0].powi(3) - 4.0 * self.b * x[0] * x[1] + 2.0 * x[0], 169 | 2.0 * self.b * (x[1] - x[0].powi(2))] 170 | } 171 | 172 | 173 | /* 174 | pub struct McCormick; 175 | 176 | impl McCormick { 177 | pub fn new() -> McCormick { 178 | McCormick 179 | } 180 | } 181 | 182 | define_problem!{McCormick: self, 183 | default: McCormick::new(), 184 | dimensions: 2, 185 | domain: vec![(-INFINITY, INFINITY), (-INFINITY, INFINITY)], 186 | minimum: -1.9133, 187 | at: vec![-0.54719, -1.54719], 188 | start: vec![random::() * 5.5 - 1.5, random::() * 7.0 - 3.0], 189 | value: x => (x[0] + x[1]).sin() + (x[0] - x[1]).powi(2) - 1.5 * x[0] + 2.5 * x[1] + 1.0, 190 | gradient: x => vec![(x[0] + x[1]).cos() + 2.0 * (x[0] - x[1]) - 1.5, 191 | (x[0] + x[1]).cos() - 2.0 * (x[0] - x[1]) + 2.5] 192 | } 193 | */ 194 | 195 | 196 | #[cfg(test)] 197 | macro_rules! 
test_minimizer { 198 | ( $minimizer:expr, $( $name:ident => $problem:expr ),* ) => { 199 | $( 200 | #[test] 201 | fn $name() { 202 | let minimizer = $minimizer; 203 | let problem = $problem; 204 | 205 | for _ in 0..100 { 206 | let position = $crate::problems::Problem::random_start(&problem); 207 | 208 | let solution = $crate::Minimizer::minimize(&minimizer, 209 | &problem, position); 210 | 211 | let distance = $crate::Evaluation::position(&solution).iter() 212 | .zip($crate::problems::Problem::minimum(&problem).0) 213 | .map(|(a, b)| (a - b).powi(2)) 214 | .fold(0.0, ::std::ops::Add::add) 215 | .sqrt(); 216 | 217 | assert!(distance < 1.0e-2); 218 | } 219 | } 220 | )* 221 | }; 222 | } 223 | -------------------------------------------------------------------------------- /src/sgd.rs: -------------------------------------------------------------------------------- 1 | use log::Level::Trace; 2 | use rand::{SeedableRng, random}; 3 | use rand::seq::SliceRandom; 4 | use rand_pcg::Pcg64Mcg; 5 | 6 | use types::{Minimizer, Solution, Summation1}; 7 | 8 | 9 | /// Provides _stochastic_ Gradient Descent optimization. 10 | pub struct StochasticGradientDescent { 11 | rng: Pcg64Mcg, 12 | max_iterations: Option, 13 | mini_batch: usize, 14 | step_width: f64 15 | } 16 | 17 | impl StochasticGradientDescent { 18 | /// Creates a new `StochasticGradientDescent` optimizer using the following defaults: 19 | /// 20 | /// - **`step_width`** = `0.01` 21 | /// - **`mini_batch`** = `1` 22 | /// - **`max_iterations`** = `1000` 23 | /// 24 | /// The used random number generator is randomly seeded. 25 | pub fn new() -> StochasticGradientDescent { 26 | StochasticGradientDescent { 27 | rng: Pcg64Mcg::new(random()), 28 | max_iterations: None, 29 | mini_batch: 1, 30 | step_width: 0.01 31 | } 32 | } 33 | 34 | /// Seeds the random number generator using the supplied `seed`. 35 | /// 36 | /// This is useful to create re-producable results. 
37 | pub fn seed(&mut self, seed: u64) -> &mut Self { 38 | self.rng = Pcg64Mcg::seed_from_u64(seed); 39 | self 40 | } 41 | 42 | /// Adjusts the number of maximally run iterations. A value of `None` instructs the 43 | /// optimizer to ignore the nubmer of iterations. 44 | pub fn max_iterations(&mut self, max_iterations: Option) -> &mut Self { 45 | assert!(max_iterations.map_or(true, |max_iterations| max_iterations > 0)); 46 | 47 | self.max_iterations = max_iterations; 48 | self 49 | } 50 | 51 | /// Adjusts the mini batch size, i.e., how many terms are considered in one step at most. 52 | pub fn mini_batch(&mut self, mini_batch: usize) -> &mut Self { 53 | assert!(mini_batch > 0); 54 | 55 | self.mini_batch = mini_batch; 56 | self 57 | } 58 | 59 | /// Adjusts the step size applied for each mini batch. 60 | pub fn step_width(&mut self, step_width: f64) -> &mut Self { 61 | assert!(step_width > 0.0); 62 | 63 | self.step_width = step_width; 64 | self 65 | } 66 | } 67 | 68 | impl Default for StochasticGradientDescent { 69 | fn default() -> Self { 70 | Self::new() 71 | } 72 | } 73 | 74 | impl Minimizer for StochasticGradientDescent { 75 | type Solution = Solution; 76 | 77 | fn minimize(&self, function: &F, initial_position: Vec) -> Solution { 78 | let mut position = initial_position; 79 | let mut value = function.value(&position); 80 | 81 | if log_enabled!(Trace) { 82 | info!("Starting with y = {:?} for x = {:?}", value, position); 83 | } else { 84 | info!("Starting with y = {:?}", value); 85 | } 86 | 87 | let mut iteration = 0; 88 | let mut terms: Vec<_> = (0..function.terms()).collect(); 89 | let mut rng = self.rng.clone(); 90 | 91 | loop { 92 | // ensure that we don't run into cycles 93 | terms.shuffle(&mut rng); 94 | 95 | for batch in terms.chunks(self.mini_batch) { 96 | let gradient = function.partial_gradient(&position, batch); 97 | 98 | // step into the direction of the negative gradient 99 | for (x, g) in position.iter_mut().zip(gradient) { 100 | *x -= 
self.step_width * g; 101 | } 102 | } 103 | 104 | value = function.value(&position); 105 | 106 | iteration += 1; 107 | 108 | if log_enabled!(Trace) { 109 | debug!("Iteration {:6}: y = {:?}, x = {:?}", iteration, value, position); 110 | } else { 111 | debug!("Iteration {:6}: y = {:?}", iteration, value); 112 | } 113 | 114 | let reached_max_iterations = self.max_iterations.map_or(false, 115 | |max_iterations| iteration == max_iterations); 116 | 117 | if reached_max_iterations { 118 | info!("Reached maximal number of iterations, stopping optimization"); 119 | return Solution::new(position, value); 120 | } 121 | } 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /src/types.rs: -------------------------------------------------------------------------------- 1 | use std::borrow::Borrow; 2 | 3 | 4 | /// Defines an objective function `f` that is subject to minimization. 5 | /// 6 | /// For convenience every function with the same signature as `value()` qualifies as 7 | /// an objective function, e.g., minimizing a closure is perfectly fine. 8 | pub trait Function { 9 | /// Computes the objective function at a given `position` `x`, i.e., `f(x) = y`. 10 | fn value(&self, position: &[f64]) -> f64; 11 | } 12 | 13 | 14 | /// New-type to support optimization of arbitrary functions without requiring 15 | /// to implement a trait. 16 | pub struct Func f64>(pub F); 17 | 18 | impl f64> Function for Func { 19 | fn value(&self, position: &[f64]) -> f64 { 20 | self.0(position) 21 | } 22 | } 23 | 24 | 25 | /// Defines an objective function `f` that is able to compute the first derivative 26 | /// `f'(x)`. 27 | pub trait Function1: Function { 28 | /// Computes the gradient of the objective function at a given `position` `x`, 29 | /// i.e., `∀ᵢ ∂/∂xᵢ f(x) = ∇f(x)`. 30 | fn gradient(&self, position: &[f64]) -> Vec; 31 | } 32 | 33 | 34 | /// Defines a summation of individual functions, i.e., f(x) = ∑ᵢ fᵢ(x). 
35 | pub trait Summation: Function { 36 | /// Returns the number of individual functions that are terms of the summation. 37 | fn terms(&self) -> usize; 38 | 39 | /// Comptues the value of one individual function indentified by its index `term`, 40 | /// given the `position` `x`. 41 | fn term_value(&self, position: &[f64], term: usize) -> f64; 42 | 43 | /// Computes the partial sum over a set of individual functions identified by `terms`. 44 | fn partial_value, I: Borrow>(&self, position: &[f64], terms: T) -> f64 { 45 | let mut value = 0.0; 46 | 47 | for term in terms { 48 | value += self.term_value(position, *term.borrow()); 49 | } 50 | 51 | value 52 | } 53 | } 54 | 55 | impl Function for S { 56 | fn value(&self, position: &[f64]) -> f64 { 57 | self.partial_value(position, 0..self.terms()) 58 | } 59 | } 60 | 61 | 62 | /// Defines a summation of individual functions `fᵢ(x)`, assuming that each function has a first 63 | /// derivative. 64 | pub trait Summation1: Summation + Function1 { 65 | /// Computes the gradient of one individual function identified by `term` at the given 66 | /// `position`. 67 | fn term_gradient(&self, position: &[f64], term: usize) -> Vec; 68 | 69 | /// Computes the partial gradient over a set of `terms` at the given `position`. 70 | fn partial_gradient, I: Borrow>(&self, position: &[f64], terms: T) -> Vec { 71 | let mut gradient = vec![0.0; position.len()]; 72 | 73 | for term in terms { 74 | for (g, gi) in gradient.iter_mut().zip(self.term_gradient(position, *term.borrow())) { 75 | *g += gi; 76 | } 77 | } 78 | 79 | gradient 80 | } 81 | } 82 | 83 | impl Function1 for S { 84 | fn gradient(&self, position: &[f64]) -> Vec { 85 | self.partial_gradient(position, 0..self.terms()) 86 | } 87 | } 88 | 89 | 90 | /// Defines an optimizer that is able to minimize a given objective function `F`. 91 | pub trait Minimizer { 92 | /// Type of the solution the `Minimizer` returns. 
93 | type Solution: Evaluation; 94 | 95 | /// Performs the actual minimization and returns a solution that 96 | /// might be better than the initially provided one. 97 | fn minimize(&self, function: &F, initial_position: Vec) -> Self::Solution; 98 | } 99 | 100 | 101 | /// Captures the essence of a function evaluation. 102 | pub trait Evaluation { 103 | /// Position `x` with the lowest corresponding value `f(x)`. 104 | fn position(&self) -> &[f64]; 105 | 106 | /// The actual value `f(x)`. 107 | fn value(&self) -> f64; 108 | } 109 | 110 | 111 | /// A solution of a minimization run providing only the minimal information. 112 | /// 113 | /// Each `Minimizer` might yield different types of solution structs which provide more 114 | /// information. 115 | #[derive(Debug, Clone)] 116 | pub struct Solution { 117 | /// Position `x` of the lowest corresponding value `f(x)` that has been found. 118 | pub position: Vec, 119 | /// The actual value `f(x)`. 120 | pub value: f64 121 | } 122 | 123 | impl Solution { 124 | /// Creates a new `Solution` given the `position` as well as the corresponding `value`. 125 | pub fn new(position: Vec, value: f64) -> Solution { 126 | Solution { 127 | position, 128 | value 129 | } 130 | } 131 | } 132 | 133 | impl Evaluation for Solution { 134 | fn position(&self) -> &[f64] { 135 | &self.position 136 | } 137 | 138 | fn value(&self) -> f64 { 139 | self.value 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /src/utils.rs: -------------------------------------------------------------------------------- 1 | #[cfg(test)] 2 | use std::f64::{MIN_POSITIVE, MAX}; 3 | 4 | 5 | /// Tests whether we reached a flat area, i.e., tests if all absolute gradient component 6 | /// lie within the `tolerance`. 
pub fn is_saddle_point(gradient: &[f64], tolerance: f64) -> bool {
    // The area is flat when no partial derivative exceeds the tolerance
    // in magnitude.
    !gradient.iter().any(|dx| dx.abs() > tolerance)
}


/// Tests whether two floating point numbers agree up to the relative error
/// `eps`, handling special cases like infinity and values near zero.
#[cfg(test)]
#[allow(clippy::float_cmp)]
pub fn are_close(a: f64, b: f64, eps: f64) -> bool {
    assert!(eps.is_finite());

    let difference = (a - b).abs();

    // bitwise-identical values, which also covers matching infinities
    if a == b {
        return true;
    }

    // at least one operand is zero or both are extremely close to it;
    // a relative error is less meaningful here, so compare against the
    // smallest positive normal value instead
    if (a == 0.0 || b == 0.0 || difference < MIN_POSITIVE)
        && difference < eps * MIN_POSITIVE {
        return true;
    }

    // finally, fall back to the relative error
    difference / (a + b).min(MAX) < eps
}


#[cfg(test)]
mod tests {
    use std::f64::{INFINITY, NAN};

    use super::{is_saddle_point, are_close};

    #[test]
    fn test_is_saddle_point() {
        assert!(is_saddle_point(&[1.0, 2.0], 2.0));
        assert!(is_saddle_point(&[1.0, -2.0], 2.0));
        assert!(!is_saddle_point(&[1.0, 2.1], 2.0));
        assert!(!is_saddle_point(&[1.0, -2.1], 2.0));
    }

    #[test]
    fn test_are_close() {
        assert!(are_close(1.0, 1.0, 0.00001));
        assert!(are_close(INFINITY, INFINITY, 0.00001));
        assert!(are_close(1.0e-1000, 0.0, 0.1));
        assert!(!are_close(1.0e-40, 0.0, 0.000_001));
        assert!(!are_close(2.0, 1.0, 0.00001));
        assert!(!are_close(NAN, NAN, 0.00001));
    }
}