├── .clog.toml
├── .gitignore
├── .travis.yml
├── CHANGELOG.md
├── Cargo.lock
├── Cargo.toml
├── LICENSE
├── README.md
├── examples
├── line_fitting.rs
└── minimal.rs
└── src
├── gd.rs
├── lib.rs
├── line_search.rs
├── numeric.rs
├── problems.rs
├── sgd.rs
├── types.rs
└── utils.rs
/.clog.toml:
--------------------------------------------------------------------------------
1 | [clog]
2 | repository = "https://github.com/b52/optimization-rust"
3 | changelog = "CHANGELOG.md"
4 | from-latest-tag = true
5 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .*
2 | target/
3 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | sudo: false
2 | addons:
3 | apt:
4 | packages:
5 | - libcurl4-openssl-dev
6 | - libelf-dev
7 | - libdw-dev
8 | - binutils-dev
9 |
10 | language: rust
11 | rust:
12 | - beta
13 | - stable
14 |
15 | before_script:
16 | - |
17 | pip install 'travis-cargo<0.2' --user &&
18 | export PATH=$HOME/.local/bin:$PATH &&
19 | rustup component add clippy
20 |
21 | script:
22 | - |
23 | travis-cargo build &&
24 | cargo clippy --all-targets --all-features -- -D warnings &&
25 | travis-cargo test &&
26 | travis-cargo bench &&
27 | travis-cargo --only stable doc
28 |
29 | after_success:
30 | - travis-cargo --only stable doc-upload
31 | - travis-cargo coveralls --no-sudo --verify
32 |
33 | env:
34 | global:
35 | - secure: "ySW1tffk3YRjoP5oi7yHgZyu/yiEPVppLK943c2csym05Scb8dLfgJKWY45kKpddf3hC8ZVVSwVIIuW7Jl6cnlAXv87awGvq/GG4HTaRbXeE7tNKBbyC8Pb2M7I1bICCqEV4EM/BxYFx/AB+BEfmcKlvi1by2fjFlKSeCGCnozZFjRJ6ai6a+IpP5T79IlfcAJytok6jFPkpSBez70TXQ02gCBhaLornS5Tw/X4RweQy9rVesw2kXMiIwg7DvEKgjviFEEKogKg/j5h9ik0ZOUk90jdDxF9glVZigOcAuXy3kAKQvJM9V5egLlCXAiz1nVrA5wF5hiYcqPCKB0VGv866qgnx7T+s2dqQqOFLbqYNB80m/REMfjV0pgpjh6O3+WQho1UDf5nLdPLTmuDt/F8BK9VUH/WQKX8yACxFKCWAu3zh/h/e8zWn0FYEIRBW36aQC3kPmR/nXxOwHCGzqqcBANiKKK38PInV4hXvyhWf+HRAO/PgQ3MxKxLjc4XIgNsXd670cbi7HR6eBFtTHFw0nQUrvP0V41XrhNqPcVWJMLzc0XyWODthzAjU44ubm32kDVMFImQBuYNfwBgXmIyCXhc8O+DlYbZURrD8Eih22t90u1bLR/Box1y44THJDRqUrSrDvgup1bP3/hIkzBYOmEh4MGh5BCtnO4dThE8="
36 |
37 | notifications:
38 | email: false
39 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 |
2 | ## 0.2.0 (2020-04-11)
3 |
4 | * Update dependencies
5 | * **BREAKING:** `StochasticGradientDescent::seed()` takes an `u64` now
6 |
7 |
8 | ## 0.1.0 (2016-06-06)
9 |
10 |
11 | #### Breaking Changes
12 |
13 | * **minimizer:** Add stochastic gradient descent ([b761a5db](https://github.com/b52/optimization-rust/commit/b761a5db504130df4219f369ae73dcacfd50d448), breaks [#](https://github.com/b52/optimization-rust/issues/))
14 |
15 | #### Features
16 |
17 | * **minimizer:** Add stochastic gradient descent ([b761a5db](https://github.com/b52/optimization-rust/commit/b761a5db504130df4219f369ae73dcacfd50d448), breaks [#](https://github.com/b52/optimization-rust/issues/))
18 |
19 |
20 |
21 |
22 | ## 0.0.1 (2016-04-26)
23 |
24 | * **Birth!**
25 |
--------------------------------------------------------------------------------
/Cargo.lock:
--------------------------------------------------------------------------------
1 | # This file is automatically @generated by Cargo.
2 | # It is not intended for manual editing.
3 | [[package]]
4 | name = "aho-corasick"
5 | version = "0.7.10"
6 | source = "registry+https://github.com/rust-lang/crates.io-index"
7 | dependencies = [
8 | "memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)",
9 | ]
10 |
11 | [[package]]
12 | name = "atty"
13 | version = "0.2.14"
14 | source = "registry+https://github.com/rust-lang/crates.io-index"
15 | dependencies = [
16 | "hermit-abi 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
17 | "libc 0.2.68 (registry+https://github.com/rust-lang/crates.io-index)",
18 | "winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)",
19 | ]
20 |
21 | [[package]]
22 | name = "cfg-if"
23 | version = "0.1.10"
24 | source = "registry+https://github.com/rust-lang/crates.io-index"
25 |
26 | [[package]]
27 | name = "env_logger"
28 | version = "0.7.1"
29 | source = "registry+https://github.com/rust-lang/crates.io-index"
30 | dependencies = [
31 | "atty 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)",
32 | "humantime 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
33 | "log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)",
34 | "regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)",
35 | "termcolor 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
36 | ]
37 |
38 | [[package]]
39 | name = "getrandom"
40 | version = "0.1.14"
41 | source = "registry+https://github.com/rust-lang/crates.io-index"
42 | dependencies = [
43 | "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
44 | "libc 0.2.68 (registry+https://github.com/rust-lang/crates.io-index)",
45 | "wasi 0.9.0+wasi-snapshot-preview1 (registry+https://github.com/rust-lang/crates.io-index)",
46 | ]
47 |
48 | [[package]]
49 | name = "hermit-abi"
50 | version = "0.1.10"
51 | source = "registry+https://github.com/rust-lang/crates.io-index"
52 | dependencies = [
53 | "libc 0.2.68 (registry+https://github.com/rust-lang/crates.io-index)",
54 | ]
55 |
56 | [[package]]
57 | name = "humantime"
58 | version = "1.3.0"
59 | source = "registry+https://github.com/rust-lang/crates.io-index"
60 | dependencies = [
61 | "quick-error 1.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
62 | ]
63 |
64 | [[package]]
65 | name = "lazy_static"
66 | version = "1.4.0"
67 | source = "registry+https://github.com/rust-lang/crates.io-index"
68 |
69 | [[package]]
70 | name = "libc"
71 | version = "0.2.68"
72 | source = "registry+https://github.com/rust-lang/crates.io-index"
73 |
74 | [[package]]
75 | name = "log"
76 | version = "0.4.8"
77 | source = "registry+https://github.com/rust-lang/crates.io-index"
78 | dependencies = [
79 | "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
80 | ]
81 |
82 | [[package]]
83 | name = "memchr"
84 | version = "2.3.3"
85 | source = "registry+https://github.com/rust-lang/crates.io-index"
86 |
87 | [[package]]
88 | name = "optimization"
89 | version = "0.2.0"
90 | dependencies = [
91 | "env_logger 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
92 | "log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)",
93 | "rand 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)",
94 | "rand_distr 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)",
95 | "rand_pcg 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
96 | ]
97 |
98 | [[package]]
99 | name = "ppv-lite86"
100 | version = "0.2.6"
101 | source = "registry+https://github.com/rust-lang/crates.io-index"
102 |
103 | [[package]]
104 | name = "quick-error"
105 | version = "1.2.3"
106 | source = "registry+https://github.com/rust-lang/crates.io-index"
107 |
108 | [[package]]
109 | name = "rand"
110 | version = "0.7.3"
111 | source = "registry+https://github.com/rust-lang/crates.io-index"
112 | dependencies = [
113 | "getrandom 0.1.14 (registry+https://github.com/rust-lang/crates.io-index)",
114 | "libc 0.2.68 (registry+https://github.com/rust-lang/crates.io-index)",
115 | "rand_chacha 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)",
116 | "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
117 | "rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
118 | ]
119 |
120 | [[package]]
121 | name = "rand_chacha"
122 | version = "0.2.2"
123 | source = "registry+https://github.com/rust-lang/crates.io-index"
124 | dependencies = [
125 | "ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)",
126 | "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
127 | ]
128 |
129 | [[package]]
130 | name = "rand_core"
131 | version = "0.5.1"
132 | source = "registry+https://github.com/rust-lang/crates.io-index"
133 | dependencies = [
134 | "getrandom 0.1.14 (registry+https://github.com/rust-lang/crates.io-index)",
135 | ]
136 |
137 | [[package]]
138 | name = "rand_distr"
139 | version = "0.2.2"
140 | source = "registry+https://github.com/rust-lang/crates.io-index"
141 | dependencies = [
142 | "rand 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)",
143 | ]
144 |
145 | [[package]]
146 | name = "rand_hc"
147 | version = "0.2.0"
148 | source = "registry+https://github.com/rust-lang/crates.io-index"
149 | dependencies = [
150 | "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
151 | ]
152 |
153 | [[package]]
154 | name = "rand_pcg"
155 | version = "0.2.1"
156 | source = "registry+https://github.com/rust-lang/crates.io-index"
157 | dependencies = [
158 | "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
159 | ]
160 |
161 | [[package]]
162 | name = "regex"
163 | version = "1.3.6"
164 | source = "registry+https://github.com/rust-lang/crates.io-index"
165 | dependencies = [
166 | "aho-corasick 0.7.10 (registry+https://github.com/rust-lang/crates.io-index)",
167 | "memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)",
168 | "regex-syntax 0.6.17 (registry+https://github.com/rust-lang/crates.io-index)",
169 | "thread_local 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)",
170 | ]
171 |
172 | [[package]]
173 | name = "regex-syntax"
174 | version = "0.6.17"
175 | source = "registry+https://github.com/rust-lang/crates.io-index"
176 |
177 | [[package]]
178 | name = "termcolor"
179 | version = "1.1.0"
180 | source = "registry+https://github.com/rust-lang/crates.io-index"
181 | dependencies = [
182 | "winapi-util 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
183 | ]
184 |
185 | [[package]]
186 | name = "thread_local"
187 | version = "1.0.1"
188 | source = "registry+https://github.com/rust-lang/crates.io-index"
189 | dependencies = [
190 | "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
191 | ]
192 |
193 | [[package]]
194 | name = "wasi"
195 | version = "0.9.0+wasi-snapshot-preview1"
196 | source = "registry+https://github.com/rust-lang/crates.io-index"
197 |
198 | [[package]]
199 | name = "winapi"
200 | version = "0.3.8"
201 | source = "registry+https://github.com/rust-lang/crates.io-index"
202 | dependencies = [
203 | "winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
204 | "winapi-x86_64-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
205 | ]
206 |
207 | [[package]]
208 | name = "winapi-i686-pc-windows-gnu"
209 | version = "0.4.0"
210 | source = "registry+https://github.com/rust-lang/crates.io-index"
211 |
212 | [[package]]
213 | name = "winapi-util"
214 | version = "0.1.4"
215 | source = "registry+https://github.com/rust-lang/crates.io-index"
216 | dependencies = [
217 | "winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)",
218 | ]
219 |
220 | [[package]]
221 | name = "winapi-x86_64-pc-windows-gnu"
222 | version = "0.4.0"
223 | source = "registry+https://github.com/rust-lang/crates.io-index"
224 |
225 | [metadata]
226 | "checksum aho-corasick 0.7.10 (registry+https://github.com/rust-lang/crates.io-index)" = "8716408b8bc624ed7f65d223ddb9ac2d044c0547b6fa4b0d554f3a9540496ada"
227 | "checksum atty 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)" = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
228 | "checksum cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)" = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
229 | "checksum env_logger 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "44533bbbb3bb3c1fa17d9f2e4e38bbbaf8396ba82193c4cb1b6445d711445d36"
230 | "checksum getrandom 0.1.14 (registry+https://github.com/rust-lang/crates.io-index)" = "7abc8dd8451921606d809ba32e95b6111925cd2906060d2dcc29c070220503eb"
231 | "checksum hermit-abi 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)" = "725cf19794cf90aa94e65050cb4191ff5d8fa87a498383774c47b332e3af952e"
232 | "checksum humantime 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "df004cfca50ef23c36850aaaa59ad52cc70d0e90243c3c7737a4dd32dc7a3c4f"
233 | "checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
234 | "checksum libc 0.2.68 (registry+https://github.com/rust-lang/crates.io-index)" = "dea0c0405123bba743ee3f91f49b1c7cfb684eef0da0a50110f758ccf24cdff0"
235 | "checksum log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)" = "14b6052be84e6b71ab17edffc2eeabf5c2c3ae1fdb464aae35ac50c67a44e1f7"
236 | "checksum memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3728d817d99e5ac407411fa471ff9800a778d88a24685968b36824eaf4bee400"
237 | "checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b"
238 | "checksum quick-error 1.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
239 | "checksum rand 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)" = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03"
240 | "checksum rand_chacha 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402"
241 | "checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
242 | "checksum rand_distr 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "96977acbdd3a6576fb1d27391900035bf3863d4a16422973a409b488cf29ffb2"
243 | "checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
244 | "checksum rand_pcg 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "16abd0c1b639e9eb4d7c50c0b8100b0d0f849be2349829c740fe8e6eb4816429"
245 | "checksum regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "7f6946991529684867e47d86474e3a6d0c0ab9b82d5821e314b1ede31fa3a4b3"
246 | "checksum regex-syntax 0.6.17 (registry+https://github.com/rust-lang/crates.io-index)" = "7fe5bd57d1d7414c6b5ed48563a2c855d995ff777729dcd91c369ec7fea395ae"
247 | "checksum termcolor 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "bb6bfa289a4d7c5766392812c0a1f4c1ba45afa1ad47803c11e1f407d846d75f"
248 | "checksum thread_local 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "d40c6d1b69745a6ec6fb1ca717914848da4b44ae29d9b3080cbee91d72a69b14"
249 | "checksum wasi 0.9.0+wasi-snapshot-preview1 (registry+https://github.com/rust-lang/crates.io-index)" = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519"
250 | "checksum winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)" = "8093091eeb260906a183e6ae1abdba2ef5ef2257a21801128899c3fc699229c6"
251 | "checksum winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
252 | "checksum winapi-util 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "fa515c5163a99cc82bab70fd3bfdd36d827be85de63737b40fcef2ce084a436e"
253 | "checksum winapi-x86_64-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
254 |
--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "optimization"
3 | version = "0.2.0"
4 | authors = ["Oliver Mader "]
5 | license = "MIT"
6 | description = "Collection of optimization algorithms"
7 | repository = "https://github.com/b52/optimization-rust"
8 | readme = "README.md"
9 | keywords = ["optimization", "minimization", "numeric"]
10 |
11 | [dependencies]
12 | log = "0.4"
13 | rand = "0.7"
14 | rand_distr = "0.2"
15 | rand_pcg = "0.2"
16 |
17 | [dev-dependencies]
18 | env_logger = "0.7"
19 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2016 Oliver Mader
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # optimization [](https://travis-ci.org/b52/optimization-rust) [](https://coveralls.io/github/b52/optimization-rust?branch=master) [](https://crates.io/crates/optimization)
2 |
3 | Collection of optimization algorithms and strategies.
4 |
5 | ## Usage
6 |
7 | ```rust
8 | extern crate optimization;
9 |
10 | use optimization::{Minimizer, GradientDescent, NumericalDifferentiation, Func};
11 |
12 | // numeric version of the Rosenbrock function
13 | let function = NumericalDifferentiation::new(Func(|x: &[f64]| {
14 | (1.0 - x[0]).powi(2) + 100.0*(x[1] - x[0].powi(2)).powi(2)
15 | }));
16 |
17 | // we use a simple gradient descent scheme
18 | let minimizer = GradientDescent::new();
19 |
20 | // perform the actual minimization, depending on the task this may
21 | // take some time, it may be useful to install a log sink to see
22 | // what's going on
23 | let solution = minimizer.minimize(&function, vec![-3.0, -4.0]);
24 |
25 | println!("Found solution for Rosenbrock function at f({:?}) = {:?}",
26 | solution.position, solution.value);
27 | ```
28 |
29 | ## Installation
30 |
31 | Simply add it as a `Cargo` dependency:
32 |
33 | ```toml
34 | [dependencies]
35 | optimization = "*"
36 | ```
37 |
38 | ## Documentation
39 |
40 | For an exhaustive documentation head over to the [API docs].
41 |
42 | ## Development
43 |
44 | Simply download this crate, add your stuff, write some tests and create a pull request. Pretty simple! :)
45 |
46 | ```shell
47 | $ cargo test
48 | $ cargo clippy
49 | ```
50 |
51 | ## License
52 |
53 | This software is licensed under the terms of the MIT license. Please see the
54 | [LICENSE](LICENSE) for full details.
55 |
56 | [API docs]: https://docs.rs/optimization
57 |
--------------------------------------------------------------------------------
/examples/line_fitting.rs:
--------------------------------------------------------------------------------
1 | //! Illustration of fitting a linear regression model using stochastic gradient descent
2 | //! given a few noisy sample observations.
3 | //!
4 | //! Run with `cargo run --example line_fitting`.
5 |
6 |
7 | #![allow(clippy::many_single_char_names)]
8 |
9 |
10 | extern crate env_logger;
11 | extern crate rand;
12 | extern crate rand_distr;
13 |
14 | extern crate optimization;
15 |
16 |
17 | use std::f64::consts::PI;
18 | use rand::prelude::*;
19 | use rand_distr::StandardNormal;
20 |
21 | use optimization::*;
22 |
23 |
24 | fn main() {
25 | env_logger::init();
26 |
27 | // the true coefficients of our linear model
28 | let true_coefficients = &[13.37, -4.2, PI];
29 |
30 | println!("Trying to approximate the true linear regression coefficients {:?} using SGD \
31 | given 100 noisy samples", true_coefficients);
32 |
33 | let noisy_observations = (0..100).map(|_| {
34 | let x = random::<[f64; 2]>();
35 | let noise: f64 = thread_rng().sample(StandardNormal);
36 | let y = linear_regression(true_coefficients, &x) + noise;
37 |
38 | (x.to_vec(), y)
39 | }).collect();
40 |
41 |
42 | // the actual function we want to minimize, which in our case corresponds to the
43 | // sum squared error
44 | let sse = SSE {
45 | observations: noisy_observations
46 | };
47 |
48 | let solution = StochasticGradientDescent::new()
49 | .max_iterations(Some(1000))
50 | .minimize(&sse, vec![1.0; true_coefficients.len()]);
51 |
52 | println!("Found coefficients {:?} with a SSE = {:?}", solution.position, solution.value);
53 | }
54 |
55 |
56 | // the sum squared error measure we want to minimize over a set of observations
57 | struct SSE {
58 | observations: Vec<(Vec<f64>, f64)>
59 | }
60 |
61 | impl Summation for SSE {
62 | fn terms(&self) -> usize {
63 | self.observations.len()
64 | }
65 |
66 | fn term_value(&self, w: &[f64], i: usize) -> f64 {
67 | let (ref x, y) = self.observations[i];
68 |
69 | 0.5 * (y - linear_regression(w, x)).powi(2)
70 | }
71 | }
72 |
73 | impl Summation1 for SSE {
74 | fn term_gradient(&self, w: &[f64], i: usize) -> Vec<f64> {
75 | let (ref x, y) = self.observations[i];
76 |
77 | let e = y - linear_regression(w, x);
78 |
79 | let mut gradient = vec![e * -1.0];
80 |
81 | for x in x {
82 | gradient.push(e * -x);
83 | }
84 |
85 | gradient
86 | }
87 | }
88 |
89 |
90 | // a simple linear regression model, i.e., f(x) = w_0 + w_1*x_1 + w_2*x_2 + ...
91 | fn linear_regression(w: &[f64], x: &[f64]) -> f64 {
92 | let mut y = w[0];
93 |
94 | for (w, x) in w[1..].iter().zip(x) {
95 | y += w * x;
96 | }
97 |
98 | y
99 | }
100 |
--------------------------------------------------------------------------------
/examples/minimal.rs:
--------------------------------------------------------------------------------
1 | //! Minimizing the Rosenbrock function using Gradient Descent by applying
2 | //! numerical differentiation.
3 | //!
4 | //! Run with `cargo run --example minimal`.
5 |
6 |
7 | extern crate env_logger;
8 |
9 | extern crate optimization;
10 |
11 |
12 | use optimization::{Minimizer, GradientDescent, NumericalDifferentiation, Func};
13 |
14 |
15 | pub fn main() {
16 | env_logger::init();
17 |
18 | // the target function we want to minimize, for educational reasons we use
19 | // the Rosenbrock function
20 | let function = NumericalDifferentiation::new(Func(|x: &[f64]| {
21 | (1.0 - x[0]).powi(2) + 100.0*(x[1] - x[0].powi(2)).powi(2)
22 | }));
23 |
24 | // we use a simple gradient descent scheme
25 | let minimizer = GradientDescent::new();
26 |
27 | // perform the actual minimization, depending on the task this may take some time
28 | // it may be useful to install a log sink to see what's going on
29 | let solution = minimizer.minimize(&function, vec![-3.0, -4.0]);
30 |
31 | println!("Found solution for Rosenbrock function at f({:?}) = {:?}",
32 | solution.position, solution.value);
33 | }
34 |
--------------------------------------------------------------------------------
/src/gd.rs:
--------------------------------------------------------------------------------
1 | use log::Level::Trace;
2 |
3 | use types::{Function1, Minimizer, Solution};
4 | use line_search::{LineSearch, ArmijoLineSearch};
5 | use utils::is_saddle_point;
6 |
7 |
8 | /// A simple Gradient Descent optimizer.
9 | #[derive(Default)]
10 | pub struct GradientDescent<T> {
11 | line_search: T,
12 | gradient_tolerance: f64,
13 | max_iterations: Option<u64>
14 | }
15 |
16 | impl GradientDescent<ArmijoLineSearch> {
17 | /// Creates a new `GradientDescent` optimizer using the following defaults:
18 | ///
19 | /// - **`line_search`** = `ArmijoLineSearch(0.5, 1.0, 0.5)`
20 | /// - **`gradient_tolerance`** = `1e-4`
21 | /// - **`max_iterations`** = `None`
22 | pub fn new() -> GradientDescent<ArmijoLineSearch> {
23 | GradientDescent {
24 | line_search: ArmijoLineSearch::new(0.5, 1.0, 0.5),
25 | gradient_tolerance: 1.0e-4,
26 | max_iterations: None
27 | }
28 | }
29 | }
30 |
31 | impl<T: LineSearch> GradientDescent<T> {
32 | /// Specifies the line search method to use.
33 | pub fn line_search<S: LineSearch>(self, line_search: S) -> GradientDescent<S> {
34 | GradientDescent {
35 | line_search,
36 | gradient_tolerance: self.gradient_tolerance,
37 | max_iterations: self.max_iterations
38 | }
39 | }
40 |
41 | /// Adjusts the gradient tolerance which is used as abort criterion to decide
42 | /// whether we reached a plateau.
43 | pub fn gradient_tolerance(mut self, gradient_tolerance: f64) -> Self {
44 | assert!(gradient_tolerance > 0.0);
45 |
46 | self.gradient_tolerance = gradient_tolerance;
47 | self
48 | }
49 |
50 | /// Adjusts the number of maximally run iterations. A value of `None` instructs the
51 | /// optimizer to ignore the number of iterations.
52 | pub fn max_iterations(mut self, max_iterations: Option<u64>) -> Self {
53 | assert!(max_iterations.map_or(true, |max_iterations| max_iterations > 0));
54 |
55 | self.max_iterations = max_iterations;
56 | self
57 | }
58 | }
59 |
60 | impl<F: Function1, T: LineSearch> Minimizer<F> for GradientDescent<T>
61 | {
62 | type Solution = Solution;
63 |
64 | fn minimize(&self, function: &F, initial_position: Vec<f64>) -> Solution {
65 | info!("Starting gradient descent minimization: gradient_tolerance = {:?},
66 | max_iterations = {:?}, line_search = {:?}",
67 | self.gradient_tolerance, self.max_iterations, self.line_search);
68 |
69 | let mut position = initial_position;
70 | let mut value = function.value(&position);
71 |
72 | if log_enabled!(Trace) {
73 | info!("Starting with y = {:?} for x = {:?}", value, position);
74 | } else {
75 | info!("Starting with y = {:?}", value);
76 | }
77 |
78 | let mut iteration = 0;
79 |
80 | loop {
81 | let gradient = function.gradient(&position);
82 |
83 | if is_saddle_point(&gradient, self.gradient_tolerance) {
84 | info!("Gradient to small, stopping optimization");
85 |
86 | return Solution::new(position, value);
87 | }
88 |
89 | let direction: Vec<_> = gradient.into_iter().map(|g| -g).collect();
90 |
91 | let iter_xs = self.line_search.search(function, &position, &direction);
92 |
93 | position = iter_xs;
94 | value = function.value(&position);
95 |
96 | iteration += 1;
97 |
98 | if log_enabled!(Trace) {
99 | debug!("Iteration {:6}: y = {:?}, x = {:?}", iteration, value, position);
100 | } else {
101 | debug!("Iteration {:6}: y = {:?}", iteration, value);
102 | }
103 |
104 | let reached_max_iterations = self.max_iterations.map_or(false,
105 | |max_iterations| iteration == max_iterations);
106 |
107 | if reached_max_iterations {
108 | info!("Reached maximal number of iterations, stopping optimization");
109 |
110 | return Solution::new(position, value);
111 | }
112 | }
113 | }
114 | }
115 |
116 |
117 | #[cfg(test)]
118 | mod tests {
119 | use problems::{Sphere, Rosenbrock};
120 |
121 | use super::GradientDescent;
122 |
123 | test_minimizer!{GradientDescent::new(),
124 | sphere => Sphere::default(),
125 | rosenbrock => Rosenbrock::default()}
126 | }
127 |
--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
1 | //! Collection of various optimization algorithms and strategies.
2 | //!
3 | //! # Building Blocks
4 | //!
5 | //! Each central primitive is specified by a trait:
6 | //!
7 | //! - **`Function`** - Specifies a function that can be minimized
8 | //! - **`Function1`** - Extends a `Function` by its first derivative
9 | //! - **`Summation`** - Represents a summation of functions, exploited, e.g., by SGD
10 | //! - **`Summation1`** - Analogous to `Function` and `Function1` but for `Summation`
11 | //! - **`Minimizer`** - A minimization algorithm
12 | //! - **`Evaluation`** - A function evaluation `f(x) = y` that is returned by a `Minimizer`
13 | //! - **`Func`** - A new-type wrapper for the `Function` trait
14 | //! - **`NumericalDifferentiation`** - Provides numerical differentiation for arbitrary `Function`s
15 | //!
16 | //! # Algorithms
17 | //!
18 | //! Currently, the following algorithms are implemented. This list is not final and being
19 | //! expanded over time.
20 | //!
21 | //! - **`GradientDescent`** - Iterative gradient descent minimization, supporting various line
22 | //! search methods:
23 | //! - *`FixedStepWidth`* - No line search is performed, but a fixed step width is used
24 | //! - *`ExactLineSearch`* - Exhaustive line search over a set of step widths
25 | //! - *`ArmijoLineSearch`* - Backtracking line search using the Armijo rule as stopping
26 | //! criterion
27 | //! - **`StochasticGradientDescent`** - Iterative stochastic gradient descent minimization,
28 | //! currently using a fixed step width
29 |
30 |
31 | #[macro_use]
32 | extern crate log;
33 | extern crate rand;
34 | extern crate rand_pcg;
35 |
36 |
37 | #[macro_use]
38 | pub mod problems;
39 |
40 | mod types;
41 | mod utils;
42 | mod numeric;
43 | mod line_search;
44 | mod gd;
45 | mod sgd;
46 |
47 |
48 | pub use types::{Function, Function1, Func, Minimizer, Evaluation, Summation, Summation1};
49 | pub use numeric::NumericalDifferentiation;
50 | pub use line_search::{LineSearch, FixedStepWidth, ExactLineSearch, ArmijoLineSearch};
51 | pub use gd::GradientDescent;
52 | pub use sgd::StochasticGradientDescent;
53 |
--------------------------------------------------------------------------------
/src/line_search.rs:
--------------------------------------------------------------------------------
1 | use std::fmt::Debug;
2 | use std::ops::Add;
3 |
4 | use types::{Function, Function1};
5 |
6 |
7 | /// Define a line search method, i.e., choosing an appropriate step width.
8 | pub trait LineSearch: Debug {
9 | /// Performs the actual line search given the current `position` `x` and a `direction` to go to.
10 | /// Returns the new position.
11 | fn search<F>(&self, function: &F, initial_position: &[f64], direction: &[f64]) -> Vec<f64>
12 | where F: Function1;
13 | }
14 |
15 |
16 | /// Uses a fixed step width `γ` in each iteration instead of performing an actual line search.
17 | #[derive(Debug, Copy, Clone)]
18 | pub struct FixedStepWidth {
19 | fixed_step_width: f64
20 | }
21 |
22 | impl FixedStepWidth {
23 | /// Creates a new `FixedStepWidth` given the static step width.
24 | pub fn new(fixed_step_width: f64) -> FixedStepWidth {
25 | assert!(fixed_step_width > 0.0 && fixed_step_width.is_finite(),
26 | "fixed_step_width must be greater than 0 and finite");
27 |
28 | FixedStepWidth {
29 | fixed_step_width
30 | }
31 | }
32 | }
33 |
34 | impl LineSearch for FixedStepWidth {
35 | fn search<F>(&self, _function: &F, initial_position: &[f64], direction: &[f64]) -> Vec<f64>
36 | where F: Function
37 | {
38 | initial_position.iter().cloned().zip(direction).map(|(x, d)| {
39 | x + self.fixed_step_width * d
40 | }).collect()
41 | }
42 | }
43 |
44 |
45 | /// Brute-force line search minimizing the objective function over a set of
46 | /// step width candidates, also known as exact line search.
47 | #[derive(Debug, Copy, Clone)]
48 | pub struct ExactLineSearch {
49 | start_step_width: f64,
50 | stop_step_width: f64,
51 | increase_factor: f64
52 | }
53 |
54 | impl ExactLineSearch {
55 | /// Creates a new `ExactLineSearch` given the `start_step_width`, the `stop_step_width`
56 | /// and the `increase_factor`. The set of evaluated step widths `γ` is specified as
57 | /// `{ γ | γ = start_step_width · increase_factorⁱ, i ∈ N, γ <= stop_step_width }`,
58 | /// assuming that `start_step_width` < `stop_step_width` and `increase_factor` > 1.
59 | pub fn new(start_step_width: f64, stop_step_width: f64, increase_factor: f64) ->
60 | ExactLineSearch
61 | {
62 | assert!(start_step_width > 0.0 && start_step_width.is_finite(),
63 | "start_step_width must be greater than 0 and finite");
64 | assert!(stop_step_width > start_step_width && stop_step_width.is_finite(),
65 | "stop_step_width must be greater than start_step_width");
66 | assert!(increase_factor > 1.0 && increase_factor.is_finite(),
67 | "increase_factor must be greater than 1 and finite");
68 |
69 | ExactLineSearch {
70 | start_step_width,
71 | stop_step_width,
72 | increase_factor
73 | }
74 | }
75 | }
76 |
77 | impl LineSearch for ExactLineSearch {
78 | fn search(&self, function: &F, initial_position: &[f64], direction: &[f64]) -> Vec
79 | where F: Function1
80 | {
81 | let mut min_position = initial_position.to_vec();
82 | let mut min_value = function.value(initial_position);
83 |
84 | let mut step_width = self.start_step_width;
85 |
86 | loop {
87 | let position: Vec<_> = initial_position.iter().cloned().zip(direction).map(|(x, d)| {
88 | x + step_width * d
89 | }).collect();
90 | let value = function.value(&position);
91 |
92 | if value < min_value {
93 | min_position = position;
94 | min_value = value;
95 | }
96 |
97 | step_width *= self.increase_factor;
98 |
99 | if step_width >= self.stop_step_width {
100 | break;
101 | }
102 | }
103 |
104 | min_position
105 | }
106 | }
107 |
108 |
109 | /// Backtracking line search evaluating the Armijo rule at each step width.
110 | #[derive(Debug, Copy, Clone)]
111 | pub struct ArmijoLineSearch {
112 | control_parameter: f64,
113 | initial_step_width: f64,
114 | decay_factor: f64
115 | }
116 |
117 | impl ArmijoLineSearch {
118 | /// Creates a new `ArmijoLineSearch` given the `control_parameter` ∈ (0, 1), the
119 | /// `initial_step_width` > 0 and the `decay_factor` ∈ (0, 1).
120 | ///
121 | /// Armijo used in his paper the values 0.5, 1.0 and 0.5, respectively.
122 | pub fn new(control_parameter: f64, initial_step_width: f64, decay_factor: f64) ->
123 | ArmijoLineSearch
124 | {
125 | assert!(control_parameter > 0.0 && control_parameter < 1.0,
126 | "control_parameter must be in range (0, 1)");
127 | assert!(initial_step_width > 0.0 && initial_step_width.is_finite(),
128 | "initial_step_width must be > 0 and finite");
129 | assert!(decay_factor > 0.0 && decay_factor < 1.0, "decay_factor must be in range (0, 1)");
130 |
131 | ArmijoLineSearch {
132 | control_parameter,
133 | initial_step_width,
134 | decay_factor
135 | }
136 | }
137 | }
138 |
139 | impl LineSearch for ArmijoLineSearch {
140 | fn search(&self, function: &F, initial_position: &[f64], direction: &[f64]) -> Vec
141 | where F: Function1
142 | {
143 | let initial_value = function.value(initial_position);
144 | let gradient = function.gradient(initial_position);
145 |
146 | let m = gradient.iter().zip(direction).map(|(g, d)| g * d).fold(0.0, Add::add);
147 | let t = -self.control_parameter * m;
148 |
149 | assert!(t > 0.0);
150 |
151 | let mut step_width = self.initial_step_width;
152 |
153 | loop {
154 | let position: Vec<_> = initial_position.iter().cloned().zip(direction).map(|(x, d)| {
155 | x + step_width * d
156 | }).collect();
157 | let value = function.value(&position);
158 |
159 | if value <= initial_value - step_width * t {
160 | return position;
161 | }
162 |
163 | step_width *= self.decay_factor;
164 | }
165 | }
166 | }
167 |
--------------------------------------------------------------------------------
/src/numeric.rs:
--------------------------------------------------------------------------------
1 | use std::f64::EPSILON;
2 |
3 | use problems::Problem;
4 | use types::{Function, Function1};
5 |
6 |
7 | /// Wraps a function for which to provide numeric differentiation.
8 | ///
9 | /// Uses simple one step forward finite difference with step width `h = √εx`.
10 | ///
11 | /// # Examples
12 | ///
13 | /// ```
14 | /// # use self::optimization::*;
15 | /// let square = NumericalDifferentiation::new(Func(|x: &[f64]| {
16 | /// x[0] * x[0]
17 | /// }));
18 | ///
19 | /// assert!(square.gradient(&[0.0])[0] < 1.0e-3);
20 | /// assert!(square.gradient(&[1.0])[0] > 1.0);
21 | /// assert!(square.gradient(&[-1.0])[0] < 1.0);
22 | /// ```
23 | pub struct NumericalDifferentiation {
24 | function: F
25 | }
26 |
27 | impl NumericalDifferentiation {
28 | /// Creates a new differentiable function by using the supplied `function` in
29 | /// combination with numeric differentiation to find the derivatives.
30 | pub fn new(function: F) -> Self {
31 | NumericalDifferentiation {
32 | function
33 | }
34 | }
35 | }
36 |
37 | impl Function for NumericalDifferentiation {
38 | fn value(&self, position: &[f64]) -> f64 {
39 | self.function.value(position)
40 | }
41 | }
42 |
43 | impl Function1 for NumericalDifferentiation {
44 | fn gradient(&self, position: &[f64]) -> Vec {
45 | let mut x: Vec<_> = position.to_vec();
46 |
47 | let current = self.value(&x);
48 |
49 | position.iter().cloned().enumerate().map(|(i, x_i)| {
50 | let h = if x_i == 0.0 {
51 | EPSILON * 1.0e10
52 | } else {
53 | (EPSILON * x_i.abs()).sqrt()
54 | };
55 |
56 | assert!(h.is_finite());
57 |
58 | x[i] = x_i + h;
59 |
60 | let forward = self.function.value(&x);
61 |
62 | x[i] = x_i;
63 |
64 | let d_i = (forward - current) / h;
65 |
66 | assert!(d_i.is_finite());
67 |
68 | d_i
69 | }).collect()
70 | }
71 | }
72 |
73 | impl Default for NumericalDifferentiation {
74 | fn default() -> Self {
75 | NumericalDifferentiation::new(F::default())
76 | }
77 | }
78 |
79 | impl Problem for NumericalDifferentiation {
80 | fn dimensions(&self) -> usize {
81 | self.function.dimensions()
82 | }
83 |
84 | fn domain(&self) -> Vec<(f64, f64)> {
85 | self.function.domain()
86 | }
87 |
88 | fn minimum(&self) -> (Vec, f64) {
89 | self.function.minimum()
90 | }
91 |
92 | fn random_start(&self) -> Vec {
93 | self.function.random_start()
94 | }
95 | }
96 |
97 |
#[cfg(test)]
mod tests {
    use types::Function1;
    use problems::{Problem, Sphere, Rosenbrock};
    use utils::are_close;
    use gd::GradientDescent;

    use super::NumericalDifferentiation;

    // Compares the numeric gradient against the analytical gradient of known
    // test problems at many random positions.
    #[test]
    fn test_accuracy() {
        //let a = Sphere::default();
        let b = Rosenbrock::default();

        // FIXME: How to iterate over different problems?
        let problems = vec![b];

        for analytical_problem in problems {
            // The problem types derive `Copy`, so the original remains usable
            // after being handed to the numeric wrapper here.
            let numerical_problem = NumericalDifferentiation::new(analytical_problem);

            for _ in 0..1000 {
                let position = analytical_problem.random_start();

                let analytical_gradient = analytical_problem.gradient(&position);
                let numerical_gradient = numerical_problem.gradient(&position);

                assert_eq!(analytical_gradient.len(), numerical_gradient.len());

                // Every component must be finite and within a relative error of 1e-3.
                assert!(analytical_gradient.into_iter().zip(numerical_gradient).all(|(a, n)|
                    a.is_finite() && n.is_finite() && are_close(a, n, 1.0e-3)
                ));
            }
        }
    }

    // Gradient descent must still converge when it only sees numeric gradients.
    test_minimizer!{GradientDescent::new(),
        test_gd_sphere => NumericalDifferentiation::new(Sphere::default()),
        test_gd_rosenbrock => NumericalDifferentiation::new(Rosenbrock::default())}
}
137 |
--------------------------------------------------------------------------------
/src/problems.rs:
--------------------------------------------------------------------------------
1 | //! Common optimization problems for testing purposes.
2 | //!
3 | //! Currently, the following [optimization test functions] are implemented.
4 | //!
5 | //! ## Bowl-Shaped
6 | //!
7 | //! * [`Sphere`](http://www.sfu.ca/~ssurjano/spheref.html)
8 | //!
9 | //! ## Valley-Shaped
10 | //!
11 | //! * [`Rosenbrock`](http://www.sfu.ca/~ssurjano/rosen.html)
12 | //!
13 | //! [optimization test functions]: http://www.sfu.ca/~ssurjano/optimization.html
14 |
15 | use rand::random;
16 | use std::f64::INFINITY;
17 | use std::ops::Add;
18 |
19 | use types::{Function, Function1};
20 |
21 |
22 | /// Specifies a well known optimization problem.
23 | pub trait Problem: Function + Default {
24 | /// Returns the dimensionality of the input domain.
25 | fn dimensions(&self) -> usize;
26 |
27 | /// Returns the input domain of the function in terms of upper and lower,
28 | /// respectively, for each input dimension.
29 | fn domain(&self) -> Vec<(f64, f64)>;
30 |
31 | /// Returns the position as well as the value of the global minimum.
32 | fn minimum(&self) -> (Vec, f64);
33 |
34 | /// Generates a random and **feasible** position to start a minimization.
35 | fn random_start(&self) -> Vec;
36 |
37 | /// Tests whether the supplied position is legal for this function.
38 | fn is_legal_position(&self, position: &[f64]) -> bool {
39 | position.len() == self.dimensions() &&
40 | position.iter().zip(self.domain()).all(|(&x, (lower, upper))| {
41 | lower < x && x < upper
42 | })
43 | }
44 | }
45 |
46 |
// Implements `Default`, `Function`, `Function1` and `Problem` for a test problem
// type from a compact description of its properties.
macro_rules! define_problem {
    ( $name:ident: $this:ident,
      default: $def:expr,
      dimensions: $dims:expr,
      domain: $domain:expr,
      minimum: $miny:expr,
      at: $minx:expr,
      start: $start:expr,
      value: $x1:ident => $value:expr,
      gradient: $x2:ident => $gradient:expr ) =>
    {
        impl Default for $name {
            fn default() -> Self {
                $def
            }
        }

        impl Function for $name {
            fn value(&$this, $x1: &[f64]) -> f64 {
                // Guard against out-of-domain evaluations in tests.
                assert!($this.is_legal_position($x1));

                $value
            }
        }

        impl Function1 for $name {
            fn gradient(&$this, $x2: &[f64]) -> Vec<f64> {
                assert!($this.is_legal_position($x2));

                $gradient
            }
        }

        impl Problem for $name {
            fn dimensions(&$this) -> usize {
                $dims
            }

            fn domain(&$this) -> Vec<(f64, f64)> {
                $domain
            }

            fn minimum(&$this) -> (Vec<f64>, f64) {
                ($minx, $miny)
            }

            fn random_start(&$this) -> Vec<f64> {
                $start
            }
        }
    };
}
99 |
100 |
101 | /// n-dimensional Sphere function.
102 | ///
103 | /// It is continuous, convex and unimodal:
104 | ///
105 | /// > f(x) = ∑ᵢ xᵢ²
106 | ///
107 | /// *Global minimum*: `f(0,...,0) = 0`
108 | #[derive(Debug, Copy, Clone)]
109 | pub struct Sphere {
110 | dimensions: usize
111 | }
112 |
113 | impl Sphere {
114 | pub fn new(dimensions: usize) -> Sphere {
115 | assert!(dimensions > 0, "dimensions must be larger than 1");
116 |
117 | Sphere {
118 | dimensions
119 | }
120 | }
121 | }
122 |
123 | define_problem!{Sphere: self,
124 | default: Sphere::new(2),
125 | dimensions: self.dimensions,
126 | domain: (0..self.dimensions).map(|_| (-INFINITY, INFINITY)).collect(),
127 | minimum: 0.0,
128 | at: (0..self.dimensions).map(|_| 0.0).collect(),
129 | start: (0..self.dimensions).map(|_| random::() * 10.24 - 5.12).collect(),
130 | value: x => x.iter().map(|x| x.powi(2)).fold(0.0, Add::add),
131 | gradient: x => x.iter().map(|x| 2.0 * x).collect()
132 | }
133 |
134 |
135 | /// Two-dimensional Rosenbrock function.
136 | ///
137 | /// A non-convex function with its global minimum inside a long, narrow, parabolic
138 | /// shaped flat valley:
139 | ///
140 | /// > f(x, y) = (a - x)² + b (y - x²)²
141 | ///
142 | /// *Global minimum*: `f(a, a²) = 0`
143 | #[derive(Debug, Copy, Clone)]
144 | pub struct Rosenbrock {
145 | a: f64,
146 | b: f64
147 | }
148 |
149 | impl Rosenbrock {
150 | /// Creates a new `Rosenbrock` function given `a` and `b`, commonly definied
151 | /// with 1 and 100, respectively, which also corresponds to the `default`.
152 | pub fn new(a: f64, b: f64) -> Rosenbrock {
153 | Rosenbrock {
154 | a,
155 | b
156 | }
157 | }
158 | }
159 |
160 | define_problem!{Rosenbrock: self,
161 | default: Rosenbrock::new(1.0, 100.0),
162 | dimensions: 2,
163 | domain: vec![(-INFINITY, INFINITY), (-INFINITY, INFINITY)],
164 | minimum: 0.0,
165 | at: vec![self.a, self.a * self.a],
166 | start: (0..2).map(|_| random::() * 4.096 - 2.048).collect(),
167 | value: x => (self.a - x[0]).powi(2) + self.b * (x[1] - x[0].powi(2)).powi(2),
168 | gradient: x => vec![-2.0 * self.a + 4.0 * self.b * x[0].powi(3) - 4.0 * self.b * x[0] * x[1] + 2.0 * x[0],
169 | 2.0 * self.b * (x[1] - x[0].powi(2))]
170 | }
171 |
172 |
173 | /*
174 | pub struct McCormick;
175 |
176 | impl McCormick {
177 | pub fn new() -> McCormick {
178 | McCormick
179 | }
180 | }
181 |
182 | define_problem!{McCormick: self,
183 | default: McCormick::new(),
184 | dimensions: 2,
185 | domain: vec![(-INFINITY, INFINITY), (-INFINITY, INFINITY)],
186 | minimum: -1.9133,
187 | at: vec![-0.54719, -1.54719],
188 | start: vec![random::() * 5.5 - 1.5, random::() * 7.0 - 3.0],
189 | value: x => (x[0] + x[1]).sin() + (x[0] - x[1]).powi(2) - 1.5 * x[0] + 2.5 * x[1] + 1.0,
190 | gradient: x => vec![(x[0] + x[1]).cos() + 2.0 * (x[0] - x[1]) - 1.5,
191 | (x[0] + x[1]).cos() - 2.0 * (x[0] - x[1]) + 2.5]
192 | }
193 | */
194 |
195 |
// Generates one `#[test]` per `name => problem` pair: each test minimizes the
// problem from 100 random start positions and requires the Euclidean distance
// to the known global minimum to end up below 1e-2.
#[cfg(test)]
macro_rules! test_minimizer {
    ( $minimizer:expr, $( $name:ident => $problem:expr ),* ) => {
        $(
            #[test]
            fn $name() {
                let minimizer = $minimizer;
                let problem = $problem;

                for _ in 0..100 {
                    let position = $crate::problems::Problem::random_start(&problem);

                    let solution = $crate::Minimizer::minimize(&minimizer,
                        &problem, position);

                    // Euclidean distance between the found and the true minimizer.
                    let distance = $crate::Evaluation::position(&solution).iter()
                        .zip($crate::problems::Problem::minimum(&problem).0)
                        .map(|(a, b)| (a - b).powi(2))
                        .fold(0.0, ::std::ops::Add::add)
                        .sqrt();

                    assert!(distance < 1.0e-2);
                }
            }
        )*
    };
}
223 |
--------------------------------------------------------------------------------
/src/sgd.rs:
--------------------------------------------------------------------------------
1 | use log::Level::Trace;
2 | use rand::{SeedableRng, random};
3 | use rand::seq::SliceRandom;
4 | use rand_pcg::Pcg64Mcg;
5 |
6 | use types::{Minimizer, Solution, Summation1};
7 |
8 |
9 | /// Provides _stochastic_ Gradient Descent optimization.
10 | pub struct StochasticGradientDescent {
11 | rng: Pcg64Mcg,
12 | max_iterations: Option,
13 | mini_batch: usize,
14 | step_width: f64
15 | }
16 |
17 | impl StochasticGradientDescent {
18 | /// Creates a new `StochasticGradientDescent` optimizer using the following defaults:
19 | ///
20 | /// - **`step_width`** = `0.01`
21 | /// - **`mini_batch`** = `1`
22 | /// - **`max_iterations`** = `1000`
23 | ///
24 | /// The used random number generator is randomly seeded.
25 | pub fn new() -> StochasticGradientDescent {
26 | StochasticGradientDescent {
27 | rng: Pcg64Mcg::new(random()),
28 | max_iterations: None,
29 | mini_batch: 1,
30 | step_width: 0.01
31 | }
32 | }
33 |
34 | /// Seeds the random number generator using the supplied `seed`.
35 | ///
36 | /// This is useful to create re-producable results.
37 | pub fn seed(&mut self, seed: u64) -> &mut Self {
38 | self.rng = Pcg64Mcg::seed_from_u64(seed);
39 | self
40 | }
41 |
42 | /// Adjusts the number of maximally run iterations. A value of `None` instructs the
43 | /// optimizer to ignore the nubmer of iterations.
44 | pub fn max_iterations(&mut self, max_iterations: Option) -> &mut Self {
45 | assert!(max_iterations.map_or(true, |max_iterations| max_iterations > 0));
46 |
47 | self.max_iterations = max_iterations;
48 | self
49 | }
50 |
51 | /// Adjusts the mini batch size, i.e., how many terms are considered in one step at most.
52 | pub fn mini_batch(&mut self, mini_batch: usize) -> &mut Self {
53 | assert!(mini_batch > 0);
54 |
55 | self.mini_batch = mini_batch;
56 | self
57 | }
58 |
59 | /// Adjusts the step size applied for each mini batch.
60 | pub fn step_width(&mut self, step_width: f64) -> &mut Self {
61 | assert!(step_width > 0.0);
62 |
63 | self.step_width = step_width;
64 | self
65 | }
66 | }
67 |
68 | impl Default for StochasticGradientDescent {
69 | fn default() -> Self {
70 | Self::new()
71 | }
72 | }
73 |
74 | impl Minimizer for StochasticGradientDescent {
75 | type Solution = Solution;
76 |
77 | fn minimize(&self, function: &F, initial_position: Vec) -> Solution {
78 | let mut position = initial_position;
79 | let mut value = function.value(&position);
80 |
81 | if log_enabled!(Trace) {
82 | info!("Starting with y = {:?} for x = {:?}", value, position);
83 | } else {
84 | info!("Starting with y = {:?}", value);
85 | }
86 |
87 | let mut iteration = 0;
88 | let mut terms: Vec<_> = (0..function.terms()).collect();
89 | let mut rng = self.rng.clone();
90 |
91 | loop {
92 | // ensure that we don't run into cycles
93 | terms.shuffle(&mut rng);
94 |
95 | for batch in terms.chunks(self.mini_batch) {
96 | let gradient = function.partial_gradient(&position, batch);
97 |
98 | // step into the direction of the negative gradient
99 | for (x, g) in position.iter_mut().zip(gradient) {
100 | *x -= self.step_width * g;
101 | }
102 | }
103 |
104 | value = function.value(&position);
105 |
106 | iteration += 1;
107 |
108 | if log_enabled!(Trace) {
109 | debug!("Iteration {:6}: y = {:?}, x = {:?}", iteration, value, position);
110 | } else {
111 | debug!("Iteration {:6}: y = {:?}", iteration, value);
112 | }
113 |
114 | let reached_max_iterations = self.max_iterations.map_or(false,
115 | |max_iterations| iteration == max_iterations);
116 |
117 | if reached_max_iterations {
118 | info!("Reached maximal number of iterations, stopping optimization");
119 | return Solution::new(position, value);
120 | }
121 | }
122 | }
123 | }
124 |
--------------------------------------------------------------------------------
/src/types.rs:
--------------------------------------------------------------------------------
1 | use std::borrow::Borrow;
2 |
3 |
/// Defines an objective function `f` that is subject to minimization.
///
/// For convenience every function with the same signature as `value()` qualifies as
/// an objective function, e.g., minimizing a closure is perfectly fine.
pub trait Function {
    /// Computes the objective function at a given `position` `x`, i.e., `f(x) = y`.
    fn value(&self, position: &[f64]) -> f64;
}


/// New-type to support optimization of arbitrary functions without requiring
/// to implement a trait.
pub struct Func<F: Fn(&[f64]) -> f64>(pub F);

impl<F: Fn(&[f64]) -> f64> Function for Func<F> {
    /// Evaluates the wrapped closure at `position`.
    fn value(&self, position: &[f64]) -> f64 {
        self.0(position)
    }
}
23 |
24 |
25 | /// Defines an objective function `f` that is able to compute the first derivative
26 | /// `f'(x)`.
27 | pub trait Function1: Function {
28 | /// Computes the gradient of the objective function at a given `position` `x`,
29 | /// i.e., `∀ᵢ ∂/∂xᵢ f(x) = ∇f(x)`.
30 | fn gradient(&self, position: &[f64]) -> Vec;
31 | }
32 |
33 |
34 | /// Defines a summation of individual functions, i.e., f(x) = ∑ᵢ fᵢ(x).
35 | pub trait Summation: Function {
36 | /// Returns the number of individual functions that are terms of the summation.
37 | fn terms(&self) -> usize;
38 |
39 | /// Comptues the value of one individual function indentified by its index `term`,
40 | /// given the `position` `x`.
41 | fn term_value(&self, position: &[f64], term: usize) -> f64;
42 |
43 | /// Computes the partial sum over a set of individual functions identified by `terms`.
44 | fn partial_value, I: Borrow>(&self, position: &[f64], terms: T) -> f64 {
45 | let mut value = 0.0;
46 |
47 | for term in terms {
48 | value += self.term_value(position, *term.borrow());
49 | }
50 |
51 | value
52 | }
53 | }
54 |
55 | impl Function for S {
56 | fn value(&self, position: &[f64]) -> f64 {
57 | self.partial_value(position, 0..self.terms())
58 | }
59 | }
60 |
61 |
62 | /// Defines a summation of individual functions `fᵢ(x)`, assuming that each function has a first
63 | /// derivative.
64 | pub trait Summation1: Summation + Function1 {
65 | /// Computes the gradient of one individual function identified by `term` at the given
66 | /// `position`.
67 | fn term_gradient(&self, position: &[f64], term: usize) -> Vec;
68 |
69 | /// Computes the partial gradient over a set of `terms` at the given `position`.
70 | fn partial_gradient, I: Borrow>(&self, position: &[f64], terms: T) -> Vec {
71 | let mut gradient = vec![0.0; position.len()];
72 |
73 | for term in terms {
74 | for (g, gi) in gradient.iter_mut().zip(self.term_gradient(position, *term.borrow())) {
75 | *g += gi;
76 | }
77 | }
78 |
79 | gradient
80 | }
81 | }
82 |
83 | impl Function1 for S {
84 | fn gradient(&self, position: &[f64]) -> Vec {
85 | self.partial_gradient(position, 0..self.terms())
86 | }
87 | }
88 |
89 |
/// Defines an optimizer that is able to minimize a given objective function `F`.
pub trait Minimizer<F> {
    /// Type of the solution the `Minimizer` returns.
    type Solution: Evaluation;

    /// Performs the actual minimization and returns a solution that
    /// might be better than the initially provided one.
    fn minimize(&self, function: &F, initial_position: Vec<f64>) -> Self::Solution;
}


/// Captures the essence of a function evaluation.
pub trait Evaluation {
    /// Position `x` with the lowest corresponding value `f(x)`.
    fn position(&self) -> &[f64];

    /// The actual value `f(x)`.
    fn value(&self) -> f64;
}
109 |
110 |
111 | /// A solution of a minimization run providing only the minimal information.
112 | ///
113 | /// Each `Minimizer` might yield different types of solution structs which provide more
114 | /// information.
115 | #[derive(Debug, Clone)]
116 | pub struct Solution {
117 | /// Position `x` of the lowest corresponding value `f(x)` that has been found.
118 | pub position: Vec,
119 | /// The actual value `f(x)`.
120 | pub value: f64
121 | }
122 |
123 | impl Solution {
124 | /// Creates a new `Solution` given the `position` as well as the corresponding `value`.
125 | pub fn new(position: Vec, value: f64) -> Solution {
126 | Solution {
127 | position,
128 | value
129 | }
130 | }
131 | }
132 |
133 | impl Evaluation for Solution {
134 | fn position(&self) -> &[f64] {
135 | &self.position
136 | }
137 |
138 | fn value(&self) -> f64 {
139 | self.value
140 | }
141 | }
142 |
--------------------------------------------------------------------------------
/src/utils.rs:
--------------------------------------------------------------------------------
1 | #[cfg(test)]
2 | use std::f64::{MIN_POSITIVE, MAX};
3 |
4 |
/// Tests whether we reached a flat area, i.e., tests if all absolute gradient
/// components lie within the `tolerance`.
pub fn is_saddle_point(gradient: &[f64], tolerance: f64) -> bool {
    for component in gradient {
        // Negated comparison so NaN components also fail the check, exactly as
        // with `all(|dx| dx.abs() <= tolerance)`.
        if !(component.abs() <= tolerance) {
            return false;
        }
    }

    true
}
10 |
11 |
/// Tests whether two floating point numbers are close using the relative error
/// and handling special cases like infinity etc.
#[cfg(test)]
#[allow(clippy::float_cmp)]
pub fn are_close(a: f64, b: f64, eps: f64) -> bool {
    assert!(eps.is_finite());

    let d = (a - b).abs();

    // identical, e.g., infinity
    a == b

    // a or b is zero or both are extremely close to it
    // relative error is less meaningful here
    || ((a == 0.0 || b == 0.0 || d < MIN_POSITIVE) &&
        d < eps * MIN_POSITIVE)

    // finally, use the relative error; the magnitudes are combined via `abs()`
    // so two negative inputs cannot yield a negative (always-passing) ratio,
    // and the denominator is capped at MAX to avoid overflowing to infinity
    || d / (a.abs() + b.abs()).min(MAX) < eps
}
32 |
33 |
#[cfg(test)]
mod tests {
    use std::f64::{INFINITY, NAN};

    use super::{is_saddle_point, are_close};

    #[test]
    fn test_is_saddle_point() {
        assert!(is_saddle_point(&[1.0, 2.0], 2.0));
        assert!(is_saddle_point(&[1.0, -2.0], 2.0));
        assert!(!is_saddle_point(&[1.0, 2.1], 2.0));
        assert!(!is_saddle_point(&[1.0, -2.1], 2.0));
    }

    #[test]
    fn test_are_close() {
        assert!(are_close(1.0, 1.0, 0.00001));
        assert!(are_close(INFINITY, INFINITY, 0.00001));
        // NOTE(review): `1.0e-1000` underflows to 0.0, so this exercises the
        // near-zero branch rather than a genuine subnormal comparison.
        assert!(are_close(1.0e-1000, 0.0, 0.1));
        assert!(!are_close(1.0e-40, 0.0, 0.000_001));
        assert!(!are_close(2.0, 1.0, 0.00001));
        // NaN never compares equal, so it must never count as close.
        assert!(!are_close(NAN, NAN, 0.00001));
    }
}
58 |
--------------------------------------------------------------------------------