├── .babelrc
├── .editorconfig
├── .eslintignore
├── .eslintrc
├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
├── examples
├── react-sample
│ ├── .babelrc
│ ├── README.md
│ ├── dist
│ │ └── index.html
│ ├── package.json
│ ├── src
│ │ ├── components
│ │ │ ├── ExperimentsTable.js
│ │ │ ├── Value.js
│ │ │ └── ValuesRow.js
│ │ ├── index.js
│ │ ├── pages
│ │ │ └── Home.js
│ │ └── withRoot.js
│ └── webpack.config.js
└── tiny
│ ├── README.md
│ ├── index.html
│ ├── index.js
│ └── package.json
├── package.json
├── rollup.config.js
├── src
├── base
│ ├── base.js
│ └── fmin.js
├── index.js
├── search
│ ├── grid.js
│ └── random.js
└── utils
│ └── RandomState.js
└── test
├── __snapshots__
└── index.js.json
└── index.js
/.babelrc:
--------------------------------------------------------------------------------
1 | {
2 | "presets": [
3 | [
4 | "env",
5 | {
6 | "targets": {
7 | "browsers": [
8 | "ie >= 11"
9 | ]
10 | },
11 | "exclude": ["transform-async-to-generator", "transform-regenerator"],
12 | "modules": false,
13 | "loose": true
14 | }
15 | ]
16 | , "stage-0"
17 | ],
18 | "plugins": [
19 | "transform-object-rest-spread"
20 | ],
21 | "env": {
22 | "development": {
23 | "presets": ["env", "stage-0"]
24 | },
25 | "commonjs": {
26 | "presets": [
27 | [
28 | "env",
29 | {
30 | "loose": true
31 | }
32 | ]
33 | ]
34 | }
35 | }
36 | }
37 |
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | root = true
2 |
3 | [*]
4 | end_of_line = lf
5 | charset = utf-8
6 | trim_trailing_whitespace = true
7 | insert_final_newline = true
8 | indent_style = space
9 | indent_size = 2
10 |
11 | [*.md]
12 | trim_trailing_whitespace = false
13 |
14 | [{package,bower}.json]
15 | indent_style = space
16 | indent_size = 2
17 |
--------------------------------------------------------------------------------
/.eslintignore:
--------------------------------------------------------------------------------
1 | node_modules/**
2 | coverage/**
3 | test/**
4 | lib/**
5 | .vscode/**
6 | src/random.js
7 |
--------------------------------------------------------------------------------
/.eslintrc:
--------------------------------------------------------------------------------
1 | {
2 | "parser": "babel-eslint",
3 | "extends": "airbnb",
4 | "env": {
5 | "mocha": true
6 | },
7 | "rules": {
8 | "comma-dangle": ["error", "only-multiline"]
9 | }
10 | }
11 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Logs
2 | logs
3 | *.log
4 | npm-debug.log*
5 |
6 | # Runtime data
7 | pids
8 | *.pid
9 | *.seed
10 |
11 | # Directory for instrumented libs generated by jscoverage/JSCover
12 | lib-cov
13 |
14 | # Coverage directory used by tools like istanbul
15 | coverage
16 |
17 | # nyc test coverage
18 | .nyc_output
19 |
20 | # Compiled binary addons (http://nodejs.org/api/addons.html)
21 | build/Release
22 |
23 | # Dependency directories
24 | node_modules
25 | jspm_packages
26 |
27 | # Optional npm cache directory
28 | .npm
29 |
30 | # Optional REPL history
31 | .node_repl_history
32 |
33 | # Editors
34 | .idea
35 |
36 | # Lib
37 | lib
38 |
39 | # npm package lock
40 | package-lock.json
41 | yarn.lock
42 |
43 | others
44 | .DS_Store
45 | .history
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: node_js
2 | node_js:
3 | - "6"
4 | script:
5 | - npm run test
6 | branches:
7 | only:
8 | - master
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ES6 hyperparameters optimization
2 |
3 | [](https://travis-ci.org/atanasster/hyperparameters) [](https://david-dm.org/atanasster/hyperjs) [](https://david-dm.org/atanasster/hyperjs?type=dev) [](https://opensource.org/licenses/MIT)
4 |
5 | :warning: Early version subject to changes.
6 |
7 |
8 |
9 | ## Features
10 | * **written in javascript** - Use with tensorflow.js as a replacement to your python hyperparameters library
11 | * **use from cdn or npm** - Link hpjs in your html file from a cdn, or install in your project with npm
12 | * **versatile** - Utilize multiple parameters and multiple search algorithms (grid search, random, bayesian)
13 |
14 |
15 |
16 | ## Installation
17 |
18 | ```
19 | $ npm install hyperparameters
20 | ```
21 |
22 |
23 | ## Parameter Expressions
24 |
25 | ```
26 | import * as hpjs from 'hyperparameters';
27 | ```
28 |
29 | ### hpjs.choice(options)
30 |
31 | - Randomly returns one of the options
32 |
33 | ### hpjs.randint(upper)
34 |
35 | - Return a random integer in the range [0, upper)
36 |
37 | ### hpjs.uniform(low, high)
38 |
39 | - Returns a single value uniformly between `low` and `high` i.e. any value between `low` and `high` has an equal probability of being selected
40 |
41 | ### hpjs.quniform(low, high, q)
42 |
43 | - returns a quantized value of `hp.uniform` calculated as `round(uniform(low, high) / q) * q`
44 |
45 | ### hpjs.loguniform(low, high)
46 |
47 | - Returns a value `exp(uniform(low, high))` so the logarithm of the return value is uniformly distributed.
48 |
49 | ### hpjs.qloguniform(low, high, q)
50 |
51 | - Returns a value `round(exp(uniform(low, high)) / q) * q`
52 |
53 | ### hpjs.normal(mu, sigma)
54 |
55 | - Returns a real number that's normally-distributed with mean mu and standard deviation sigma
56 |
57 | ### hpjs.qnormal(mu, sigma, q)
58 |
59 | - Returns a value `round(normal(mu, sigma) / q) * q`
60 |
61 | ### hpjs.lognormal(mu, sigma)
62 |
63 | - Returns a value `exp(normal(mu, sigma))`
64 |
65 | ### hpjs.qlognormal(mu, sigma, q)
66 |
67 | - Returns a value `round(exp(normal(mu, sigma)) / q) * q`
68 |
69 |
70 |
71 | ## Random numbers generator
72 |
73 | ```
74 | import { RandomState } from 'hyperparameters';
75 | ```
76 |
77 | **example:**
78 | ```
79 | const rng = new RandomState(12345);
80 | console.log(rng.randrange(0, 5, 0.5));
81 |
82 | ```
83 |
84 |
85 | ## Spaces
86 |
87 | ```
88 | import { sample } from 'hyperparameters';
89 | ```
90 |
91 | **example:**
92 | ```
93 | import * as hpjs from 'hyperparameters';
94 |
95 | const space = {
96 | x: hpjs.normal(0, 2),
97 | y: hpjs.uniform(0, 1),
98 | choice: hpjs.choice([
99 | undefined, hp.uniform('float', 0, 1),
100 | ]),
101 | array: [
102 | hpjs.normal(0, 2), hpjs.uniform(0, 3), hpjs.choice([false, true]),
103 | ],
104 | obj: {
105 | u: hpjs.uniform(0, 3),
106 | v: hpjs.uniform(0, 3),
107 | w: hpjs.uniform(-3, 0)
108 | }
109 | };
110 |
111 | console.log(hpjs.sample.randomSample(space));
112 |
113 | ```
114 | ## fmin - find best value of a function over the arguments
115 |
116 | ```
117 | import * as hpjs from 'hyperparameters';
118 | const trials = hpjs.fmin(optimizationFunction, space, estimator, max_estimates, options);
119 | ```
120 |
121 | **example:**
122 | ```
123 | import * as hpjs from 'hyperparameters';
124 |
125 | const fn = x => ((x ** 2) - (x + 1));
126 | const space = hpjs.uniform(-5, 5);
127 | fmin(fn, space, hpjs.search.randomSearch, 1000, { rng: new hpjs.RandomState(123456) })
128 | .then(trials => console.log(result.argmin));
129 | ```
130 | ## Getting started with tensorflow.js
131 |
132 | ### 1. [include javascript file](https://github.com/atanasster/hyperparameters/tree/master/examples/tiny)
133 |
134 | * include (latest) version from cdn
135 |
136 | ``
137 |
138 | * create search space
139 | ```
140 | const space = {
141 | optimizer: hpjs.choice(['sgd', 'adam', 'adagrad', 'rmsprop']),
142 | epochs: hpjs.quniform(50, 250, 50),
143 | };
144 |
145 | ```
146 | * create tensorflow.js train function. Parameters are optimizer and epochs. input and output data passed as second argument
147 | ```
148 | const trainModel = async ({ optimizer, epochs }, { xs, ys }) => {
149 | // Create a simple model.
150 | const model = tf.sequential();
151 | model.add(tf.layers.dense({ units: 1, inputShape: [1] }));
152 | // Prepare the model for training: Specify the loss and the optimizer.
153 | model.compile({
154 | loss: 'meanSquaredError',
155 | optimizer
156 | });
157 | // Train the model using the data.
158 | const h = await model.fit(xs, ys, { epochs });
159 | return { model, loss: h.history.loss[h.history.loss.length - 1] };
160 | };
161 | ```
162 | * create optimization function
163 | ```
164 | const modelOpt = async ({ optimizer, epochs }, { xs, ys }) => {
165 | const { loss } = await trainModel({ optimizer, epochs }, { xs, ys });
166 | return { loss, status: hpjs.STATUS_OK };
167 | };
168 | ```
169 |
170 | * find optimal hyperparameters
171 | ```
172 | const trials = await hpjs.fmin(
173 | modelOpt, space, hpjs.search.randomSearch, 10,
174 | { rng: new hpjs.RandomState(654321), xs, ys }
175 | );
176 | const opt = trials.argmin;
177 | console.log('best optimizer',opt.optimizer);
178 | console.log('best no of epochs', opt.epochs);
179 | ```
180 |
181 | ### 2. [install with npm](https://github.com/atanasster/hyperparameters/tree/master/examples/react-sample)
182 | * install hyperparameters in your package.json
183 | ```
184 | $ npm install hyperparameters
185 | ```
186 |
187 | * import hyperparameters
188 | ```
189 | import * as tf from '@tensorflow/tfjs';
190 | import * as hpjs from 'hyperparameters';
191 | ```
192 |
193 | * create search space
194 | ```
195 | const space = {
196 | optimizer: hpjs.choice(['sgd', 'adam', 'adagrad', 'rmsprop']),
197 | epochs: hpjs.quniform(50, 250, 50),
198 | };
199 |
200 | ```
201 | * create tensorflow.js train function. Parameters are optimizer and epochs. input and output data passed as second argument
202 | ```
203 | const trainModel = async ({ optimizer, epochs }, { xs, ys }) => {
204 | // Create a simple model.
205 | const model = tf.sequential();
206 | model.add(tf.layers.dense({ units: 1, inputShape: [1] }));
207 | // Prepare the model for training: Specify the loss and the optimizer.
208 | model.compile({
209 | loss: 'meanSquaredError',
210 | optimizer
211 | });
212 | // Train the model using the data.
213 | const h = await model.fit(xs, ys, { epochs });
214 | return { model, loss: h.history.loss[h.history.loss.length - 1] };
215 | };
216 | ```
217 | * create optimization function
218 | ```
219 | const modelOpt = async ({ optimizer, epochs }, { xs, ys }) => {
220 | const { loss } = await trainModel({ optimizer, epochs }, { xs, ys });
221 | return { loss, status: hpjs.STATUS_OK };
222 | };
223 | ```
224 |
225 | * find optimal hyperparameters
226 | ```
227 | const trials = await hpjs.fmin(
228 | modelOpt, space, hpjs.search.randomSearch, 10,
229 | { rng: new hpjs.RandomState(654321), xs, ys }
230 | );
231 | const opt = trials.argmin;
232 | console.log('best optimizer',opt.optimizer);
233 | console.log('best no of epochs', opt.epochs);
234 | ```
235 |
236 |
237 | ## License
238 |
239 | MIT © Atanas Stoyanov & Martin Stoyanov
240 |
--------------------------------------------------------------------------------
/examples/react-sample/.babelrc:
--------------------------------------------------------------------------------
1 | {
2 | "presets": [
3 | "env",
4 | "react",
5 | "stage-0"
6 | ]
7 | }
8 |
--------------------------------------------------------------------------------
/examples/react-sample/README.md:
--------------------------------------------------------------------------------
1 | # TensorFlow.js Tiny Example with HyperParameters.js
2 |
3 | This minimal tfjs example builds and trains a minimal model, showing the trials history and the best optimizer and number of epochs with a react sample application
4 |
5 | ## Getting started
6 |
7 | * install hyperparameters in your package.json
8 | ```
9 | $ npm install hyperparameters
10 | ```
11 |
12 | * import hyperparameters
13 | ```
14 | import * as tf from '@tensorflow/tfjs';
15 | import * as hpjs from 'hyperparameters';
16 | ```
17 |
18 | * create search space
19 | ```
20 | const space = {
21 | optimizer: hpjs.choice(['sgd', 'adam', 'adagrad', 'rmsprop']),
22 | epochs: hpjs.quniform(50, 250, 50),
23 | };
24 |
25 | ```
26 | * create tensorflow.js train function. Parameters are optimizer and epochs. input and output data passed as second argument
27 | ```
28 | const trainModel = async ({ optimizer, epochs }, { xs, ys }) => {
29 | // Create a simple model.
30 | const model = tf.sequential();
31 | model.add(tf.layers.dense({ units: 1, inputShape: [1] }));
32 | // Prepare the model for training: Specify the loss and the optimizer.
33 | model.compile({
34 | loss: 'meanSquaredError',
35 | optimizer
36 | });
37 | // Train the model using the data.
38 | const h = await model.fit(xs, ys, { epochs });
39 | return { model, loss: h.history.loss[h.history.loss.length - 1] };
40 | };
41 | ```
42 | * create optimization function
43 | ```
44 | const modelOpt = async ({ optimizer, epochs }, { xs, ys }) => {
45 | const { loss } = await trainModel({ optimizer, epochs }, { xs, ys });
46 | return { loss, status: hpjs.STATUS_OK };
47 | };
48 | ```
49 |
50 | * find optimal hyperparameters
51 | ```
52 | const trials = await hpjs.fmin(
53 | modelOpt, space, hpjs.search.randomSearch, 10,
54 | { rng: new hpjs.RandomState(654321), xs, ys }
55 | );
56 | const opt = trials.argmin;
57 | console.log('best optimizer',opt.optimizer);
58 | console.log('best no of epochs', opt.epochs);
59 | ```
60 |
61 |
--------------------------------------------------------------------------------
/examples/react-sample/dist/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 | Tiny TFJS + HPJS example
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/examples/react-sample/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "react-sample",
3 | "version": "1.0.0",
4 | "description": "",
5 | "main": "index.js",
6 | "scripts": {
7 | "start": "webpack-dev-server --config ./webpack.config.js --mode development",
8 | "test": "echo \"Error: no test specified\" && exit 1"
9 | },
10 | "keywords": [],
11 | "author": "",
12 | "license": "ISC",
13 | "devDependencies": {
14 | "babel-core": "^6.26.3",
15 | "babel-loader": "^7.1.4",
16 | "babel-polyfill": "^6.26.0",
17 | "babel-preset-env": "^1.7.0",
18 | "babel-preset-react": "^6.24.1",
19 | "babel-preset-stage-0": "^6.24.1",
20 | "react-hot-loader": "^4.3.2",
21 | "webpack": "^4.12.0",
22 | "webpack-cli": "^3.0.6",
23 | "webpack-dev-server": "^3.1.4"
24 | },
25 | "dependencies": {
26 | "@material-ui/core": "latest",
27 | "@tensorflow/tfjs": "^0.11.6",
28 | "prop-types": "latest",
29 | "react": "^16.4.1",
30 | "react-dom": "^16.4.1"
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/examples/react-sample/src/components/ExperimentsTable.js:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import PropTypes from 'prop-types';
3 | import { withStyles } from '@material-ui/core/styles';
4 | import Table from '@material-ui/core/Table';
5 | import TableBody from '@material-ui/core/TableBody';
6 | import TableCell from '@material-ui/core/TableCell';
7 | import TableHead from '@material-ui/core/TableHead';
8 | import TableRow from '@material-ui/core/TableRow';
9 |
10 | const CustomTableCell = withStyles(theme => ({
11 | head: {
12 | backgroundColor: theme.palette.common.black,
13 | color: theme.palette.common.white,
14 | fontSize: 14,
15 | },
16 | body: {
17 | fontSize: 16,
18 | },
19 | }))(TableCell);
20 |
21 | const STATES_MAP = ['new', 'running', 'done', 'error'];
22 |
23 | export const formatTraingTime = (date, locale = 'en-us') => (
24 | date ? (new Date(date)).toLocaleDateString(locale, {
25 | month: '2-digit',
26 | day: '2-digit',
27 | hour: '2-digit',
28 | minute: '2-digit',
29 | }) : undefined
30 | );
31 |
32 | export const periodToTime = (duration) => {
33 | if (!duration) {
34 | return {
35 | time: 0,
36 | units: 'ms',
37 | };
38 | }
39 | const hours = (duration / (1000 * 60 * 60)).toFixed(0);
40 | const minutes = (duration / (1000 * 60)).toFixed(0);
41 | if (hours > 0) {
42 | return {
43 | time: `${hours}:${minutes}`,
44 | units: 'hrs',
45 | };
46 | }
47 | const seconds = (duration / 1000).toFixed(0);
48 | if (minutes > 0) {
49 | return {
50 | time: `${minutes}:${seconds}`,
51 | units: 'min',
52 | };
53 | }
54 | const milliseconds = (duration).toFixed(0);
55 | if (seconds > 0) {
56 | return {
57 | time: `${seconds}:${milliseconds}`,
58 | units: 'sec',
59 | };
60 | }
61 | return {
62 | time: `${milliseconds}`,
63 | units: 'ms',
64 | };
65 | };
66 |
67 | const ExperimentsTable = ({ classes, experiments }) => (
68 |
69 |
70 |
71 | #
72 | state
73 | start
74 | duration
75 | arguments
76 | results
77 |
78 |
79 | optimizer
80 | epochs
81 | status
82 | loss
83 |
84 |
85 |
86 |
87 | {experiments.map((exp) => {
88 | const duration = periodToTime(exp.refresh_time - exp.book_time);
89 | return (
90 |
91 |
92 | {exp.id}
93 |
94 | {STATES_MAP[exp.state]}
95 | {formatTraingTime(exp.book_time)}
96 | {`${duration.time} ${duration.units}`}
97 | {exp.args.optimizer}
98 | {exp.args.epochs}
99 | {exp.result.status}
100 | {exp.result.loss.toFixed(5)}
101 |
102 | );
103 | })}
104 |
105 |
106 | );
107 |
108 | ExperimentsTable.propTypes = {
109 | classes: PropTypes.object.isRequired,
110 | experiments: PropTypes.array.isRequired,
111 | }
112 | export default ExperimentsTable;
113 |
--------------------------------------------------------------------------------
/examples/react-sample/src/components/Value.js:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import Grid from '@material-ui/core/Grid';
3 | import Typography from '@material-ui/core/Typography';
4 |
5 | export default ({ value, label, size }) => (
6 |
7 |
8 |
9 | {value}
10 |
11 |
12 | {label}
13 |
14 |
15 |
16 | )
17 |
--------------------------------------------------------------------------------
/examples/react-sample/src/components/ValuesRow.js:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import Grid from '@material-ui/core/Grid';
3 | import Paper from '@material-ui/core/Paper';
4 | import Typography from '@material-ui/core/Typography';
5 |
6 |
7 | export default ({ classes, children, title }) => (
8 |
9 |
10 |
11 | {title}
12 |
13 |
14 |
15 | {children}
16 |
17 |
18 |
19 | );
20 |
--------------------------------------------------------------------------------
/examples/react-sample/src/index.js:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import ReactDOM from 'react-dom';
3 | import Home from './pages/Home';
4 |
5 | ReactDOM.render(
6 | ,
7 | document.getElementById('app')
8 | );
9 |
10 | module.hot.accept();
11 |
--------------------------------------------------------------------------------
/examples/react-sample/src/pages/Home.js:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import PropTypes from 'prop-types';
3 | import Grid from '@material-ui/core/Grid';
4 | import Button from '@material-ui/core/Button';
5 | import Typography from '@material-ui/core/Typography';
6 | import * as tf from '@tensorflow/tfjs';
7 | import * as hpjs from '../../../../src/index';
8 |
9 | import withRoot from '../withRoot';
10 | import ValuesRow from '../components/ValuesRow';
11 | import Value from '../components/Value';
12 | import ExperimentsTable from '../components/ExperimentsTable';
13 |
14 | const trainModel = async ({ optimizer, epochs, onEpochEnd }, { xs, ys }) => {
15 | // Create a simple model.
16 | const model = tf.sequential();
17 | model.add(tf.layers.dense({ units: 1, inputShape: xs.shape.slice(1) }));
18 | // Prepare the model for training: Specify the loss and the optimizer.
19 | model.compile({
20 | loss: 'meanSquaredError',
21 | optimizer
22 | });
23 | // Train the model using the data.
24 | const h = await model.fit(xs, ys, { epochs, callbacks: { onEpochEnd } });
25 | return { model, loss: h.history.loss[h.history.loss.length - 1] };
26 | };
27 |
28 | // Generate some synthetic data for training. (y = 2x - 1) and pass to fmin as parameters
29 | const createData = () => ({
30 | xs: tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]),
31 | ys: tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]),
32 | });
33 |
34 | class Home extends React.Component {
35 | state = {
36 | experiments: [],
37 | epoch: undefined,
38 | experimentBegin: undefined,
39 | experimentEnd: undefined,
40 | };
41 |
42 | onEpochEnd = async (epoch, logs) => {
43 | this.setState({ epoch: { epoch, logs } });
44 | await tf.nextFrame();
45 | };
46 |
47 | onExperimentBegin = async (id, experiment) => {
48 | this.setState({ experimentBegin: { id, experiment } });
49 | await tf.nextFrame();
50 | };
51 | onExperimentEnd = async (id, experiment) => {
52 | this.setState({
53 | experimentEnd: { id, experiment },
54 | experiments: [...this.state.experiments, experiment]
55 | });
56 | await tf.nextFrame();
57 | };
58 |
59 | onRunClick = async () => {
60 | this.setState({
61 | epoch: undefined,
62 | experimentBegin: undefined,
63 | experimentEnd: undefined,
64 | experiments: [],
65 | best: undefined,
66 | });
67 | // fmin optmization function, retuns the loss and a STATUS_OK
68 | async function modelOpt({ optimizer, epochs }, { xs, ys, onEpochEnd }) {
69 | const { loss } = await trainModel({ optimizer, epochs, onEpochEnd }, { xs, ys });
70 | return { loss, status: hpjs.STATUS_OK };
71 | }
72 |
73 | // hyperparameters search space
74 | // optmizer is a choice field
75 | // epochs ia an integer value from 10 to 250 with a step of 5
76 | const space = {
77 | optimizer: hpjs.choice(['sgd', 'adam', 'adagrad', 'rmsprop']),
78 | epochs: hpjs.quniform(10, 30, 10),
79 | };
80 | tf.ENV.engine.startScope();
81 | // Generate some synthetic data for training. (y = 2x - 1) and pass to fmin as parameters
82 | // data will be passed as a parameters to the fmin
83 | const { xs, ys } = createData();
84 | const experiments = await hpjs.fmin(
85 | modelOpt, space, hpjs.search.randomSearch, 10,
86 | {
87 | rng: new hpjs.RandomState(654321),
88 | xs,
89 | ys,
90 | onEpochEnd: this.onEpochEnd,
91 | callbacks: {
92 | onExperimentBegin: this.onExperimentBegin,
93 | onExperimentEnd: this.onExperimentEnd
94 | }
95 | }
96 | );
97 | this.setState({ best: experiments.argmin });
98 | tf.ENV.engine.endScope();
99 | };
100 |
101 | onPredictClick = async () => {
102 | const { best } = this.state;
103 | tf.ENV.engine.startScope();
104 | const { xs, ys } = createData();
105 | const { model } = await trainModel(best, { xs, ys });
106 | const prediction = model.predict(tf.tensor2d([20], [1, 1]));
107 | this.setState({ prediction: prediction.dataSync() });
108 | tf.ENV.engine.endScope();
109 | };
110 |
111 | render() {
112 | const { classes } = this.props;
113 | const {
114 | epoch, experimentBegin, experimentEnd, experiments, best, prediction
115 | } = this.state;
116 |
117 | const spacing = 24;
118 | return (
119 |
120 |
121 | TensorFlow "tiny"
122 |
123 |
124 | {experimentBegin && (
125 |
126 |
127 | {Object.keys(experimentBegin.experiment.args).map(key => (
128 |
129 | ))}
130 | {epoch && (
131 |
132 |
133 |
134 |
135 | )}
136 |
137 |
138 | )}
139 | {experimentEnd && (
140 |
141 |
142 | {Object.keys(experimentEnd.experiment.args).map(key => (
143 |
144 | ))}
145 |
146 |
147 |
148 | )}
149 | {best && (
150 |
151 | {Object.keys(best).map(key => (
152 |
153 | ))}
154 |
155 |
158 |
159 | {prediction && (
160 |
161 | )}
162 |
163 |
164 | )}
165 |
166 |
169 |
170 |
171 |
172 | Experiments
173 |
174 |
175 |
176 |
177 |
178 | );
179 | }
180 | }
181 |
182 | Home.propTypes = {
183 | classes: PropTypes.object.isRequired,
184 | };
185 |
186 | export default withRoot(Home);
187 |
--------------------------------------------------------------------------------
/examples/react-sample/src/withRoot.js:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import { MuiThemeProvider, createMuiTheme } from '@material-ui/core/styles';
3 | import orange from '@material-ui/core/colors/orange';
4 | import blueGrey from '@material-ui/core/colors/blueGrey';
5 | import CssBaseline from '@material-ui/core/CssBaseline';
6 | import { withStyles } from "@material-ui/core/styles/index";
7 | import AppBar from '@material-ui/core/AppBar';
8 | import Toolbar from '@material-ui/core/Toolbar';
9 | import Typography from '@material-ui/core/Typography';
10 |
11 |
12 | const theme = createMuiTheme({
13 | palette: {
14 | primary: {
15 | light: orange[300],
16 | main: orange[800],
17 | dark: orange[900],
18 | },
19 | secondary: {
20 | light: blueGrey[300],
21 | main: blueGrey[500],
22 | dark: blueGrey[700],
23 | },
24 | },
25 | });
26 |
27 | const styles = theme => ({
28 | root: {
29 | textAlign: 'center',
30 | flexGrow: 1,
31 | padding: theme.spacing.unit * 10,
32 | },
33 | inforow: {
34 | padding: theme.spacing.unit,
35 | },
36 | });
37 |
38 | function withRoot(Component) {
39 | function WithRoot(props) {
40 | return (
41 |
42 |
43 |
44 |
45 |
46 | hpjs
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 | );
55 | }
56 |
57 | return withStyles(styles)(WithRoot);
58 | }
59 |
60 | export default withRoot;
61 |
--------------------------------------------------------------------------------
/examples/react-sample/webpack.config.js:
--------------------------------------------------------------------------------
1 | const webpack = require('webpack');
2 |
3 | module.exports = {
4 | entry: [
5 | 'babel-polyfill',
6 | 'react-hot-loader/patch',
7 | './src/index.js'
8 | ],
9 | module: {
10 | rules: [
11 | {
12 | test: /\.(js|jsx)$/,
13 | exclude: /node_modules/,
14 | use: ['babel-loader']
15 | }
16 | ]
17 | },
18 | resolve: {
19 | extensions: ['*', '.js', '.jsx']
20 | },
21 | output: {
22 | path: __dirname + '/dist',
23 | publicPath: '/',
24 | filename: 'bundle.js'
25 | },
26 | plugins: [
27 | new webpack.HotModuleReplacementPlugin()
28 | ],
29 | devServer: {
30 | contentBase: './dist',
31 | hot: true
32 | }
33 | };
34 |
--------------------------------------------------------------------------------
/examples/tiny/README.md:
--------------------------------------------------------------------------------
1 | # TensorFlow.js Tiny Example with HyperParameters.js
2 |
3 | This minimal example loads tfjs and hpjs from a CDN, builds and trains a minimal model,
4 | and finds the optimal optimizer and number of epochs.
5 |
6 | ## Getting started
7 |
8 | * include (latest) version from cdn
9 |
10 | ``
11 |
12 | * create search space
13 | ```
14 | const space = {
15 | optimizer: hpjs.choice(['sgd', 'adam', 'adagrad', 'rmsprop']),
16 | epochs: hpjs.quniform(50, 250, 50),
17 | };
18 |
19 | ```
20 | * create tensorflow.js train function. Parameters are optimizer and epochs. input and output data passed as second argument
21 | ```
22 | const trainModel = async ({ optimizer, epochs }, { xs, ys }) => {
23 | // Create a simple model.
24 | const model = tf.sequential();
25 | model.add(tf.layers.dense({ units: 1, inputShape: [1] }));
26 | // Prepare the model for training: Specify the loss and the optimizer.
27 | model.compile({
28 | loss: 'meanSquaredError',
29 | optimizer
30 | });
31 | // Train the model using the data.
32 | const h = await model.fit(xs, ys, { epochs });
33 | return { model, loss: h.history.loss[h.history.loss.length - 1] };
34 | };
35 | ```
36 | * create optimization function
37 | ```
38 | const modelOpt = async ({ optimizer, epochs }, { xs, ys }) => {
39 | const { loss } = await trainModel({ optimizer, epochs }, { xs, ys });
40 | return { loss, status: hpjs.STATUS_OK };
41 | };
42 | ```
43 |
44 | * find optimal hyperparameters
45 | ```
46 | const trials = await hpjs.fmin(
47 | modelOpt, space, hpjs.search.randomSearch, 10,
48 | { rng: new hpjs.RandomState(654321), xs, ys }
49 | );
50 | const opt = trials.argmin;
51 | console.log('best optimizer',opt.optimizer);
52 | console.log('best no of epochs', opt.epochs);
53 | ```
54 |
--------------------------------------------------------------------------------
/examples/tiny/index.html:
--------------------------------------------------------------------------------
1 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 | Tiny TFJS + HPJS example
13 |
14 | const model = tf.sequential();
15 | model.add(tf.layers.dense({ units: 1, inputShape: [1] }));
16 | model.compile({
17 | loss: 'meanSquaredError',
18 | optimizer
19 | });
20 | await model.fit(xs, ys, { epochs });
21 |
22 | Best optimizer:
23 |
24 | Epochs:
25 |
26 |
27 | Prediction:
28 |
29 |
30 |
31 |
32 |
--------------------------------------------------------------------------------
/examples/tiny/index.js:
--------------------------------------------------------------------------------
1 | /* eslint-disable no-undef */
2 |
3 | // Starting with the tensorflow.js Tiny example, this sampple illustrates how little code is
4 | // necessary to build / train / optimize / predict from a model
5 | // in TensorFlow.js + Hyperparameters.js
6 |
7 | // function to fit a model, given an optimizer and number of epochs
8 | // returns the model and the final loss
9 | const trainModel = async ({ optimizer, epochs }, { xs, ys }) => {
10 | // Create a simple model.
11 | const model = tf.sequential();
12 | model.add(tf.layers.dense({ units: 1, inputShape: [1] }));
13 | // Prepare the model for training: Specify the loss and the optimizer.
14 | model.compile({
15 | loss: 'meanSquaredError',
16 | optimizer
17 | });
18 | // Train the model using the data.
19 | const h = await model.fit(xs, ys, { epochs });
20 | return { model, loss: h.history.loss[h.history.loss.length - 1] };
21 | };
22 |
23 | // fmin optmization function, retuns the loss and a STATUS_OK
24 | const modelOpt = async ({ optimizer, epochs }, { xs, ys }) => {
25 | const { loss } = await trainModel({ optimizer, epochs }, { xs, ys });
26 | return { loss, status: hpjs.STATUS_OK };
27 | };
28 |
29 | async function launchHPJS() {
30 | // hyperparameters search space
31 | // optmizer is a choice field
32 | // epochs ia an integer value from 10 to 250 with a step of 5
33 | const space = {
34 | optimizer: hpjs.choice(['sgd', 'adam', 'adagrad', 'rmsprop']),
35 | epochs: hpjs.quniform(50, 250, 50),
36 | };
37 | // Generate some synthetic data for training. (y = 2x - 1) and pass to fmin as parameters
38 | // data will be passed as a parameters to the fmin
39 | const xs = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]);
40 | const ys = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]);
41 |
42 | const trials = await hpjs.fmin(
43 | modelOpt, space, hpjs.search.randomSearch, 10,
44 | { rng: new hpjs.RandomState(654321), xs, ys }
45 | );
46 | const opt = trials.argmin;
47 | document.getElementById('optimizer_best').innerText = opt.optimizer;
48 | document.getElementById('epochs').innerText = opt.epochs;
49 | const { model } = await trainModel(opt, { xs, ys });
50 | const prediction = model.predict(tf.tensor2d([20], [1, 1]));
51 | document.getElementById('prediction').innerText += prediction.dataSync();
52 | }
53 |
54 | launchHPJS();
55 |
--------------------------------------------------------------------------------
/examples/tiny/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "tfjs-hpjs-tiny",
3 | "version": "0.1.0",
4 | "description": "",
5 | "main": "index.js",
6 | "license": "Apache-2.0",
7 | "private": true,
8 | "engines": {
9 | "node": ">=8.9.0"
10 | },
11 | "dependencies": { },
12 | "scripts": {
13 | "watch": "npm run build && node_modules/http-server/bin/http-server dist -p 1234 ",
14 | "build": "mkdir -p dist/ && cp index.html dist/ && cp index.js dist/"
15 | },
16 | "devDependencies": {
17 | "clang-format": "~1.2.2",
18 | "http-server": "~0.10.0"
19 | }
20 | }
21 |
--------------------------------------------------------------------------------
/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "hyperparameters",
3 | "version": "0.25.6",
4 | "description": "javascript hyperparameters search",
5 | "main": "lib/hyperparameters.js",
6 | "unpkg": "dist/hyperparameters.js",
7 | "module": "es/hyperparameters.js",
8 | "scripts": {
9 | "test": "npm run lint && npm run cover",
10 | "test:prod": "cross-env BABEL_ENV=production npm run test",
11 | "test:only": "mocha $NODE_DEBUG_OPTION --require babel-core/register --require babel-polyfill --recursive",
12 | "test:watch": "npm test -- --watch",
13 | "cover": "istanbul cover _mocha -- --require babel-core/register --require babel-polyfill --recursive",
14 | "lint": "eslint src test",
15 | "clean": "rimraf lib dist es coverage",
16 | "build:commonjs": "cross-env NODE_ENV=cjs rollup -c -o lib/hyperparameters.js",
17 | "build:es": "cross-env BABEL_ENV=es NODE_ENV=es rollup -c -o es/hyperparameters.js",
18 | "build:umd": "cross-env BABEL_ENV=es NODE_ENV=development rollup -c -o dist/hyperparameters.js",
19 | "build:umd:min": "cross-env BABEL_ENV=es NODE_ENV=production rollup -c -o dist/hyperparameters.min.js",
20 | "build": "npm run build:commonjs && npm run build:es && npm run build:umd && npm run build:umd:min",
21 | "prepare": "npm run clean && npm test && npm run build",
22 | "pub": "npm run build && npm publish"
23 | },
24 | "files": [
25 | "dist",
26 | "lib",
27 | "es",
28 | "src"
29 | ],
30 | "repository": {
31 | "type": "git",
32 | "url": "git+https://github.com/atanasster/hyperparameters.git"
33 | },
34 | "keywords": [
35 | "hyperparameters",
36 | "hyperopt",
37 | "tensorflow",
38 | "es6",
39 | "tfjs",
40 | "javascript"
41 | ],
42 | "contributors": [
43 | {
44 | "name": "Atanas Stoyanov",
45 | "email": "atanasster@gmail.com",
46 | "url": "https://github.com/atanasster"
47 | },
48 | {
49 | "name": "Martin Stoyanov",
50 | "email": "martin.a.stoyanov@gmail.com",
51 | "url": "https://github.com/martin-stoyanov"
52 | }
53 | ],
54 | "license": "Apache-2.0",
55 | "bugs": {
56 | "url": "https://github.com/atanasster/hyperparameters/issues"
57 | },
58 | "homepage": "https://github.com/atanasster/hyperparameters#readme",
59 | "devDependencies": {
60 | "babel-cli": "^6.26.0",
61 | "babel-eslint": "^8.2.1",
62 | "babel-plugin-add-module-exports": "^0.2.1",
63 | "babel-plugin-external-helpers": "^6.22.0",
64 | "babel-polyfill": "^6.26.0",
65 | "babel-preset-env": "^1.7.0",
66 | "babel-preset-minify": "^0.3.0",
67 | "babel-preset-stage-0": "^6.24.1",
68 | "chai": "^4.1.2",
69 | "chai-snapshot-tests": "^0.6.0",
70 | "cross-env": "^5.1.3",
71 | "eslint": "^4.16.0",
72 | "eslint-config-airbnb": "^16.1.0",
73 | "eslint-plugin-import": "^2.7.0",
74 | "eslint-plugin-jsx-a11y": "^6.0.2",
75 | "eslint-plugin-react": "^7.4.0",
76 | "istanbul": "^1.0.0-alpha",
77 | "mocha": "^5.0.0",
78 | "pre-commit": "^1.2.2",
79 | "rimraf": "^2.6.2",
80 | "rollup": "^0.60.1",
81 | "rollup-plugin-babel": "^3.0.4",
82 | "rollup-plugin-commonjs": "^9.1.3",
83 | "rollup-plugin-node-resolve": "^3.3.0",
84 | "rollup-plugin-replace": "^2.0.0",
85 | "rollup-plugin-uglify": "^3.0.0",
86 | "rollup-watch": "^4.3.1"
87 | },
88 | "pre-commit": [
89 | "test"
90 | ]
91 | }
92 |
--------------------------------------------------------------------------------
/rollup.config.js:
--------------------------------------------------------------------------------
1 | import nodeResolve from 'rollup-plugin-node-resolve';
2 | import babel from 'rollup-plugin-babel';
3 | import replace from 'rollup-plugin-replace';
4 | import uglify from 'rollup-plugin-uglify';
5 |
6 | const env = process.env.NODE_ENV;
7 | const config = {
8 | input: 'src/index.js',
9 | plugins: []
10 | };
11 |
12 | if (env === 'es' || env === 'cjs') {
13 | config.output = { format: env, indent: false };
14 | config.external = ['symbol-observable'];
15 | config.plugins.push(babel({
16 | plugins: ['external-helpers'],
17 | }));
18 | }
19 |
20 | if (env === 'development' || env === 'production') {
21 | config.output = { format: 'umd', name: 'hpjs', indent: false };
22 | config.plugins.push(
23 | nodeResolve({
24 | jsnext: true
25 | }),
26 | babel({
27 | exclude: 'node_modules/**',
28 | plugins: ['external-helpers'],
29 | }),
30 | replace({
31 | 'process.env.NODE_ENV': JSON.stringify(env)
32 | })
33 | );
34 | }
35 |
36 | if (env === 'production') {
37 | config.plugins.push(uglify({
38 | compress: {
39 | pure_getters: true,
40 | unsafe: true,
41 | unsafe_comps: true,
42 | warnings: false
43 | }
44 | }));
45 | }
46 |
47 | export default config;
48 |
--------------------------------------------------------------------------------
/src/base/base.js:
--------------------------------------------------------------------------------
1 | import RandomState from '../utils/RandomState';
2 |
3 |
4 | export default class BaseSpace {
5 | eval = (expr, { rng: rState }) => {
6 | if (expr === undefined || expr === null) {
7 | return expr;
8 | }
9 | let rng = rState;
10 | if (rng === undefined) {
11 | rng = new RandomState();
12 | }
13 | const { name, ...rest } = expr;
14 | const space = this[name];
15 | if (typeof space !== 'function') {
16 | if (Array.isArray(expr)) {
17 | return expr.map(item => this.eval(item, { rng }));
18 | }
19 | if (typeof expr === 'object') {
20 | return Object.keys(expr)
21 | .reduce((r, key) => ({ ...r, [key]: this.eval(expr[key], { rng }) }), {});
22 | }
23 | return expr;
24 | }
25 | return space(rest, rng);
26 | };
27 | }
28 |
29 | export const STATUS_NEW = 'new';
30 | export const STATUS_RUNNING = 'running';
31 | export const STATUS_SUSPENDED = 'suspended';
32 | export const STATUS_OK = 'ok';
33 | export const STATUS_FAIL = 'fail';
34 | export const STATUS_STRINGS = [
35 | 'new', // computations have not started
36 | 'running', // computations are in prog
37 | 'suspended', // computations have been suspended, job is not finished
38 | 'ok', // computations are finished, terminated normally
39 | 'fail']; // computations are finished, terminated with error
40 |
41 |
42 | // -- named constants for job execution pipeline
43 | export const JOB_STATE_NEW = 0;
44 | export const JOB_STATE_RUNNING = 1;
45 | export const JOB_STATE_DONE = 2;
46 | export const JOB_STATE_ERROR = 3;
47 | export const JOB_STATES = [
48 | JOB_STATE_NEW,
49 | JOB_STATE_RUNNING,
50 | JOB_STATE_DONE,
51 | JOB_STATE_ERROR];
52 |
53 |
54 | export const TRIAL_KEYS = [
55 | 'id',
56 | 'result',
57 | 'args',
58 | 'state',
59 | 'book_time',
60 | 'refresh_time',
61 | ];
62 |
63 | export const range = (start, end) => Array.from({ length: (end - start) }, (v, k) => k + start);
64 |
65 | export class Trials {
66 | constructor(expKey = null, refresh = true) {
67 | this.ids = [];
68 | this.dynamicTrials = [];
69 | this.trials = [];
70 | this.expKey = expKey;
71 | if (refresh) {
72 | this.refresh();
73 | }
74 | }
75 |
76 | get length() {
77 | return this.trials.length;
78 | }
79 |
80 | refresh = () => {
81 | if (this.expKey === null) {
82 | this.trials = this.dynamicTrials
83 | .filter(trial => trial.state !== JOB_STATE_ERROR);
84 | } else {
85 | this.trials = this.dynamicTrials
86 | .filter(trial => trial.state !== JOB_STATE_ERROR && trial.expKey === this.expKey);
87 | this.ids = [];
88 | }
89 | };
90 |
91 | get results() {
92 | return this.trials.map(trial => trial.result);
93 | }
94 |
95 | get args() {
96 | return this.trials.map(trial => trial.args);
97 | }
98 |
99 | assertValidTrial = (trial) => {
100 | if (Object.keys(trial).length <= 0) {
101 | throw new Error('trial should be an object');
102 | }
103 | const missingTrialKey = TRIAL_KEYS.find(key => trial[key] === undefined);
104 | if (missingTrialKey !== undefined) {
105 | throw new Error(`trial missing key ${missingTrialKey}`);
106 | }
107 | if (trial.expKey !== this.expKey) {
108 | throw new Error(`wrong trial expKey ${trial.expKey}, expected ${this.expKey}`);
109 | }
110 | return trial;
111 | };
112 |
113 | internalInsertTrialDocs = (docs) => {
114 | const rval = docs.map(doc => doc.id);
115 | this.dynamicTrials = [...this.dynamicTrials, ...docs];
116 | return rval;
117 | };
118 |
119 | insertTrialDoc = (trial) => {
120 | const doc = this.assertValidTrial(trial);
121 | return this.internalInsertTrialDocs([doc])[0];
122 | };
123 |
124 | insertTrialDocs = (trials) => {
125 | const docs = trials.map(trial => this.assertValidTrial(trial));
126 | return this.internalInsertTrialDocs(docs);
127 | };
128 |
129 | newTrialIds = (N) => {
130 | const aa = this.ids.length;
131 | const rval = range(aa, aa + N);
132 | this.ids = [...this.ids, ...rval];
133 | return rval;
134 | };
135 |
136 | newTrialDocs = (ids, results, args) => {
137 | const rval = [];
138 | for (let i = 0; i < ids.length; i += 1) {
139 | const doc = {
140 | state: JOB_STATE_NEW,
141 | id: ids[i],
142 | result: results[i],
143 | args: args[i],
144 | };
145 | doc.expKey = this.expKey;
146 | doc.book_time = null;
147 | doc.refresh_time = null;
148 | rval.push(doc);
149 | }
150 | return rval;
151 | };
152 |
153 | deleteAll = () => {
154 | this.dynamicTrials = [];
155 | this.refresh();
156 | };
157 |
158 | countByStateSynced = (arg, trials = null) => {
159 | const vTrials = trials === null ? this.trials : trials;
160 | const vArg = Array.isArray(arg) ? arg : [arg];
161 | const queue = vTrials.filter(doc => vArg.indexOf(doc.state) >= 0);
162 | return queue.length;
163 | };
164 |
165 | countByStateUnsynced = (arg) => {
166 | const expTrials = this.expKey !== null ?
167 | this.dynamicTrials.map(trial => trial.expKey === this.expKey) : this.dynamicTrials;
168 | return this.countByStateSynced(arg, expTrials);
169 | };
170 |
171 | losses = () => this.results.map(r => r.loss || r.accuracy);
172 |
173 | statuses = () => this.results.map(r => r.status);
174 |
175 | bestTrial(compare = (a, b) =>
176 | (a.loss !== undefined ? a.loss < b.loss : a.accuracy > b.accuracy)) {
177 | let best = this.trials[0];
178 | this.trials.forEach((trial) => {
179 | if (trial.result.status === STATUS_OK && compare(trial.result, best.result)) {
180 | best = trial;
181 | }
182 | });
183 | return best;
184 | }
185 |
186 | get argmin() {
187 | const best = this.bestTrial();
188 | return best !== undefined ? best.args : undefined;
189 | }
190 |
191 | get argmax() {
192 | const best = this.bestTrial((a, b) =>
193 | (a.loss !== undefined ? a.loss > b.loss : a.accuracy > b.accuracy));
194 | return best !== undefined ? best.args : undefined;
195 | }
196 | }
197 |
198 | export class Domain {
199 | constructor(fn, expr, params) {
200 | this.fn = fn;
201 | this.expr = expr;
202 | this.params = params;
203 | }
204 |
205 | evaluate = async (args) => {
206 | const rval = await this.fn(args, this.params);
207 | let result;
208 | if (typeof rval === 'number' && !Number.isNaN(rval)) {
209 | result = { loss: rval, status: STATUS_OK };
210 | } else {
211 | result = rval;
212 | if (result === undefined) {
213 | throw new Error('Optimization function should return a loss value');
214 | }
215 | const { status, loss, accuracy } = result;
216 | if (STATUS_STRINGS.indexOf(status) < 0) {
217 | throw new Error(`invalid status ${status}`);
218 | }
219 | if (status === STATUS_OK && loss === undefined && accuracy === undefined) {
220 | throw new Error('invalid loss and accuracy');
221 | }
222 | }
223 | return result;
224 | };
225 | newResult = () => ({
226 | status: STATUS_NEW,
227 | });
228 | }
229 |
--------------------------------------------------------------------------------
/src/base/fmin.js:
--------------------------------------------------------------------------------
1 | /* eslint-disable camelcase,no-await-in-loop */
2 | import RandomState from '../utils/RandomState';
3 | import { Trials, Domain, JOB_STATE_NEW, JOB_STATE_RUNNING, JOB_STATE_ERROR, JOB_STATE_DONE } from './base';
4 |
5 | const getTimeStatmp = () => new Date().getTime();
6 |
7 | class FMinIter {
8 | constructor(
9 | algo, domain, trials,
10 | {
11 | rng,
12 | catchExceptions = false,
13 | max_queue_len = 1,
14 | max_evals = Number.MAX_VALUE,
15 | } = {},
16 | params = {}
17 | ) {
18 | this.catchExceptions = catchExceptions;
19 | this.algo = algo;
20 | this.domain = domain;
21 | this.trials = trials;
22 | this.callbacks = params.callbacks || {};
23 | this.max_queue_len = max_queue_len;
24 | this.max_evals = max_evals;
25 | this.rng = rng;
26 | }
27 | async serial_evaluate(N = -1) {
28 | const { onExperimentBegin, onExperimentEnd } = this.callbacks;
29 | let n = N;
30 | let stopped = false;
31 | for (let i = 0; i < this.trials.dynamicTrials.length; i += 1) {
32 | const trial = this.trials.dynamicTrials[i];
33 | if (trial.state === JOB_STATE_NEW) {
34 | trial.state = JOB_STATE_RUNNING;
35 | const now = getTimeStatmp();
36 | trial.book_time = now;
37 | trial.refresh_time = now;
38 | try {
39 | if (typeof onExperimentBegin === 'function') {
40 | if (await onExperimentBegin(i, trial) === true) {
41 | stopped = true;
42 | }
43 | }
44 | // eslint-disable-next-line no-await-in-loop
45 | const result = await this.domain.evaluate(trial.args);
46 | trial.state = JOB_STATE_DONE;
47 | trial.result = result;
48 | trial.refresh_time = getTimeStatmp();
49 | } catch (e) {
50 | trial.state = JOB_STATE_ERROR;
51 | trial.error = `${e}, ${e.message}`;
52 | trial.refresh_time = getTimeStatmp();
53 | if (!this.catchExceptions) {
54 | this.trials.refresh();
55 | throw e;
56 | }
57 | }
58 | if (typeof onExperimentEnd === 'function') {
59 | if (await onExperimentEnd(i, trial) === true) {
60 | stopped = true;
61 | }
62 | }
63 | }
64 | n -= 1;
65 | if (n === 0 || stopped) {
66 | break;
67 | }
68 | }
69 | this.trials.refresh();
70 | return stopped;
71 | }
72 |
73 | run = async (N) => {
74 | const { trials, algo } = this;
75 | let n_queued = 0;
76 |
77 | const get_queue_len = () => this.trials.countByStateUnsynced(JOB_STATE_NEW);
78 |
79 | let stopped = false;
80 | while (n_queued < N) {
81 | let qlen = get_queue_len();
82 | while (qlen < this.max_queue_len && n_queued < N) {
83 | const n_to_enqueue = Math.min(this.max_queue_len - qlen, N - n_queued);
84 | const new_ids = trials.newTrialIds(n_to_enqueue);
85 | trials.refresh();
86 | const new_trials = algo(
87 | new_ids, this.domain, trials,
88 | this.rng.randrange(0, (2 ** 31) - 1)
89 | );
90 | console.assert(new_ids.length >= new_trials.length);
91 | if (new_trials.length) {
92 | this.trials.insertTrialDocs(new_trials);
93 | this.trials.refresh();
94 | n_queued += new_trials.length;
95 | qlen = get_queue_len();
96 | } else {
97 | stopped = true;
98 | break;
99 | }
100 | }
101 | stopped = stopped || await this.serial_evaluate();
102 | if (stopped) {
103 | break;
104 | }
105 | }
106 | const qlen = get_queue_len();
107 | if (qlen) {
108 | const msg = `Exiting run, not waiting for ${qlen} jobs.`;
109 | console.error(msg);
110 | }
111 | };
112 |
113 | exhaust = async () => {
114 | const n_done = this.trials.length;
115 | await this.run(this.max_evals - n_done);
116 | this.trials.refresh();
117 | return this;
118 | }
119 | }
120 |
121 | export default async (fn, space, algo, max_evals, params = {}) => {
122 | const {
123 | trials: defTrials, rng: rngDefault,
124 | catchExceptions = false,
125 | } = params;
126 |
127 | let rng;
128 | if (rngDefault) {
129 | rng = rngDefault;
130 | } else {
131 | rng = new RandomState();
132 | }
133 | let trials;
134 | if (!defTrials) {
135 | trials = new Trials();
136 | } else {
137 | trials = defTrials;
138 | }
139 |
140 | const domain = new Domain(fn, space, params);
141 |
142 | const rval = new FMinIter(
143 | algo, domain, trials,
144 | { max_evals, rng, catchExceptions },
145 | params
146 | );
147 | await rval.exhaust();
148 | return trials;
149 | };
150 |
--------------------------------------------------------------------------------
/src/index.js:
--------------------------------------------------------------------------------
1 | import fmin from './base/fmin';
2 | import { randomSample, randomSearch } from './search/random';
3 | import { gridSample, gridSearch } from './search/grid';
4 | import RandomState from './utils/RandomState';
5 |
6 | export * from './base/base';
7 |
8 | export { fmin, RandomState };
9 |
10 | export const choice = options => ({ name: 'choice', options });
11 | export const randint = upper => ({ name: 'randint', upper });
12 | export const uniform = (low, high) => ({ name: 'uniform', low, high });
13 | export const quniform = (low, high, q) => ({
14 | name: 'quniform', low, high, q
15 | });
16 | export const loguniform = (low, high) => ({ name: 'loguniform', low, high });
17 | export const qloguniform = (low, high, q) => ({
18 | name: 'qloguniform', low, high, q
19 | });
20 | export const normal = (mu, sigma) => ({ name: 'normal', mu, sigma });
21 | export const qnormal = (mu, sigma, q) => ({
22 | name: 'qnormal', mu, sigma, q
23 | });
24 | export const lognormal = (mu, sigma) => ({ name: 'lognormal', mu, sigma });
25 | export const qlognormal = (mu, sigma, q) => ({
26 | name: 'qlognormal', mu, sigma, q
27 | });
28 |
29 | export const search = {
30 | randomSearch,
31 | gridSearch,
32 | };
33 |
34 | export const sample = {
35 | randomSample,
36 | gridSample,
37 | };
38 |
--------------------------------------------------------------------------------
/src/search/grid.js:
--------------------------------------------------------------------------------
1 | /* eslint-disable class-methods-use-this */
2 | import BaseSpace from '../base/base';
3 |
4 | class GridSearchParam {
5 | constructor(params, gs) {
6 | this.params = params;
7 | this.gs = gs;
8 | }
9 |
10 | get numSamples() {
11 | return 1;
12 | }
13 |
14 | getSample = () => undefined;
15 | sample = (index) => {
16 | if (index < 0 || index >= this.numSamples) {
17 | throw new Error(`invalid sample index "${index}"`);
18 | }
19 | return this.getSample(index);
20 | }
21 | }
22 |
23 |
24 | class GridSearchNotImplemented extends GridSearchParam {
25 | get numSamples() {
26 | throw new Error(`Can not evaluate length of non-discrete parameter "${this.params.name}"`);
27 | }
28 | getSample = index => this.params.options[index];
29 | }
30 |
31 | class GridSearchChoice extends GridSearchParam {
32 | get numSamples() {
33 | return this.params.options.reduce((r, e) => (r + this.gs.numSamples(e)), 0);
34 | }
35 | getSample = index => this.params.options[index];
36 | }
37 |
38 | class GridSearchRandInt extends GridSearchParam {
39 | get numSamples() {
40 | return this.params.upper;
41 | }
42 | getSample = index => index;
43 | }
44 |
45 | class GridSearchUniform extends GridSearchParam {
46 | get numSamples() {
47 | return Math.floor((this.params.high - this.params.low) / this.params.q) + 1;
48 | }
49 | getSample = index => this.params.low + (index * this.params.q);
50 | }
51 |
52 | class GridSearchNormal extends GridSearchParam {
53 | get numSamples() {
54 | return Math.floor((4 * this.params.sigma) / this.params.q) + 1;
55 | }
56 | getSample = index => (this.params.mu - (2 * this.params.sigma)) + (index * this.params.q);
57 | }
58 |
59 | const GridSearchParamas = {
60 | choice: GridSearchChoice,
61 | randint: GridSearchRandInt,
62 | quniform: GridSearchUniform,
63 | qloguniform: GridSearchUniform,
64 | qnormal: GridSearchNormal,
65 | qlognormal: GridSearchNormal,
66 | uniform: GridSearchNotImplemented,
67 | loguniform: GridSearchNotImplemented,
68 | normal: GridSearchNotImplemented,
69 | lognormal: GridSearchNotImplemented,
70 | };
71 |
72 | export class GridSearch extends BaseSpace {
73 | numSamples = (expr) => {
74 | const flat = this.samples(expr);
75 | return flat.reduce((r, o) => r * o.samples, 1);
76 | };
77 | samples = (expr) => {
78 | if (!expr) {
79 | return expr;
80 | }
81 | const flat = [];
82 | const { name } = expr;
83 | const Param = GridSearchParamas[name];
84 | if (Param === undefined) {
85 | if (Array.isArray(expr)) {
86 | expr.forEach(el => flat.push(...this.samples(el)));
87 | }
88 | if (typeof expr === 'string') {
89 | flat.push({ name, samples: 1, expr });
90 | }
91 | if (typeof expr === 'object') {
92 | Object.keys(expr).forEach(key => flat.push(...this.samples(expr[key])));
93 | }
94 | } else {
95 | flat.push({ name, samples: (new Param(expr, this)).numSamples, expr });
96 | }
97 | return flat;
98 | };
99 | }
100 |
101 | export const gridSample = (space, params = {}) => {
102 | const gs = new GridSearch();
103 | const args = gs.eval(space, params);
104 | if (Object.keys(args).length === 1) {
105 | const results = Object.keys(args).map(key => args[key]);
106 | return results.length === 1 ? results[0] : results;
107 | }
108 | return args;
109 | };
110 |
111 | export const gridSearch = (newIds, domain, trials) => {
112 | let rval = [];
113 | const gs = new GridSearch();
114 | newIds.forEach((newId) => {
115 | const paramsEval = gs.eval(domain.expr);
116 | const result = domain.newResult();
117 | rval = [...rval, ...trials.newTrialDocs([newId], [result], [paramsEval])];
118 | });
119 | return rval;
120 | };
121 |
--------------------------------------------------------------------------------
/src/search/random.js:
--------------------------------------------------------------------------------
1 | import RandomState from '../utils/RandomState';
2 | import BaseSpace from '../base/base';
3 |
4 | export class RandomSearch extends BaseSpace {
5 | choice = (params, rng) => {
6 | const { options } = params;
7 | const idx = rng.randrange(0, options.length, 1);
8 | const option = options[idx];
9 | const arg = this.eval(option, { rng });
10 | return arg;
11 | };
12 |
13 | randint = (params, rng) => rng.randrange(0, params.upper, 1)
14 |
15 | uniform = (params, rng) => {
16 | const { low, high } = params;
17 | return rng.uniform(low, high);
18 | };
19 |
20 | quniform = (params, rng) => {
21 | const { low, high, q } = params;
22 | return Math.round(rng.uniform(low, high) / q) * q;
23 | };
24 |
25 | loguniform = (params, rng) => {
26 | const { low, high } = params;
27 | return Math.exp(rng.uniform(low, high));
28 | };
29 |
30 | qloguniform = (params, rng) => {
31 | const { low, high, q } = params;
32 | return Math.round(Math.exp(rng.uniform(low, high)) / q) * q;
33 | };
34 |
35 | normal = (params, rng) => {
36 | const { mu, sigma } = params;
37 | return rng.gauss(mu, sigma);
38 | };
39 |
40 | qnormal = (params, rng) => {
41 | const { mu, sigma, q } = params;
42 | return Math.round(rng.gauss(mu, sigma) / q) * q;
43 | };
44 |
45 | lognormal = (params, rng) => {
46 | const { mu, sigma } = params;
47 | return Math.exp(rng.gauss(mu, sigma));
48 | };
49 |
50 | qlognormal = (params, rng) => {
51 | const { mu, sigma, q } = params;
52 | return Math.round(Math.exp(rng.gauss(mu, sigma)) / q) * q;
53 | };
54 | }
55 |
56 | export const randomSample = (space, params = {}) => {
57 | const rs = new RandomSearch();
58 | const args = rs.eval(space, params);
59 | if (Object.keys(args).length === 1) {
60 | const results = Object.keys(args).map(key => args[key]);
61 | return results.length === 1 ? results[0] : results;
62 | }
63 | return args;
64 | };
65 |
66 | export const randomSearch = (newIds, domain, trials, seed) => {
67 | const rng = new RandomState(seed);
68 | let rval = [];
69 | const rs = new RandomSearch();
70 | newIds.forEach((newId) => {
71 | const paramsEval = rs.eval(domain.expr, { rng });
72 | const result = domain.newResult();
73 | rval = [...rval, ...trials.newTrialDocs([newId], [result], [paramsEval])];
74 | });
75 | return rval;
76 | };
77 |
--------------------------------------------------------------------------------
/src/utils/RandomState.js:
--------------------------------------------------------------------------------
1 | /* eslint-disable no-bitwise */
2 | // https://gist.github.com/banksean/300494
3 | // Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura,
4 |
5 | // https://github.com/jrus/random-js/blob/master/random.coffee
6 |
7 | const POW_NEG_26 = 2 ** -26;
8 | const POW_NEG_27 = 2 ** -27;
9 | const POW_32 = 2 ** 32;
10 |
11 | export default class RandomState {
12 | constructor(seed) {
13 | this.bits = {};
14 | this.seed = seed === undefined ? new Date().getTime() : seed;
15 | this.N = 624;
16 | this.M = 397;
17 | this.MATRIX_A = 0x9908b0df; /* constant vector a */
18 | this.UPPER_MASK = 0x80000000; /* most significant w-r bits */
19 | this.LOWER_MASK = 0x7fffffff; /* least significant r bits */
20 |
21 | this.mt = new Array(this.N); /* the array for the state vector */
22 | this.mti = this.N + 1; /* mti==N+1 means mt[N] is not initialized */
23 |
24 | this.initGen(this.seed);
25 | }
26 |
27 | /* initializes mt[N] with a seed */
28 | initGen(seed) {
29 | this.mt[0] = seed >>> 0;
30 | for (this.mti = 1; this.mti < this.N; this.mti += 1) {
31 | const s = this.mt[this.mti - 1] ^ (this.mt[this.mti - 1] >>> 30);
32 | this.mt[this.mti] =
33 | // eslint-disable-next-line no-mixed-operators
34 | (((((s & 0xffff0000) >>> 16) * 1812433253) << 16) + (s & 0x0000ffff) * 1812433253)
35 | + this.mti;
36 | /* See Knuth TAOCP Vol2. 3rd Ed. P.106 for multiplier. */
37 | /* In the previous versions, MSBs of the seed affect */
38 | /* only MSBs of the array mt[]. */
39 | /* 2002/01/09 modified by Makoto Matsumoto */
40 | this.mt[this.mti] >>>= 0;
41 | /* for >32 bit machines */
42 | }
43 | this.next_gauss = null;
44 | }
45 | randint() {
46 | let y;
47 | const mag01 = [0x0, this.MATRIX_A];
48 | /* mag01[x] = x * MATRIX_A for x=0,1 */
49 |
50 | if (this.mti >= this.N) { /* generate N words at one time */
51 | let kk;
52 |
53 | if (this.mti === this.N + 1) { /* if initGen() has not been called, */
54 | this.initGen(5489);
55 | } /* a default initial seed is used */
56 |
57 | for (kk = 0; kk < this.N - this.M; kk += 1) {
58 | y = (this.mt[kk] & this.UPPER_MASK) | (this.mt[kk + 1] & this.LOWER_MASK);
59 | this.mt[kk] = this.mt[kk + this.M] ^ (y >>> 1) ^ mag01[y & 0x1];
60 | }
61 | for (;kk < this.N - 1; kk += 1) {
62 | y = (this.mt[kk] & this.UPPER_MASK) | (this.mt[kk + 1] & this.LOWER_MASK);
63 | this.mt[kk] = this.mt[kk + (this.M - this.N)] ^ (y >>> 1) ^ mag01[y & 0x1];
64 | }
65 | y = (this.mt[this.N - 1] & this.UPPER_MASK) | (this.mt[0] & this.LOWER_MASK);
66 | this.mt[this.N - 1] = this.mt[this.M - 1] ^ (y >>> 1) ^ mag01[y & 0x1];
67 |
68 | this.mti = 0;
69 | }
70 |
71 | y = this.mt[this.mti += 1];
72 |
73 | /* Tempering */
74 | y ^= (y >>> 11);
75 | y ^= (y << 7) & 0x9d2c5680;
76 | y ^= (y << 15) & 0xefc60000;
77 | y ^= (y >>> 18);
78 |
79 | return y >>> 0;
80 | }
81 |
82 | random() {
83 | // Return a random float in the range [0, 1), with a full 53
84 | // bits of entropy.
85 | const val = this.randint();
86 | const lowBits = val >>> 6;
87 | const highBits = val >>> 5;
88 | return (highBits + (lowBits * POW_NEG_26)) * POW_NEG_27;
89 | }
90 |
91 | randbelow(upperBound) {
92 | if (upperBound <= 0) {
93 | return 0;
94 | }
95 | const lg = x => (Math.LOG2E * Math.log(x + 1e-10)) >> 0;
96 | if (upperBound <= 0x100000000) {
97 | let r = upperBound;
98 | const bits = this.bits[upperBound] ||
99 | (this.bits[upperBound] = (lg(upperBound - 1)) + 1); // memoize values for `bits`
100 | while (r >= upperBound) {
101 | r = this.randint() >>> (32 - bits);
102 | if (r < 0) {
103 | r += POW_32;
104 | }
105 | }
106 | return r;
107 | }
108 | return this.randint() % upperBound;
109 | }
110 | randrange(start, stop, step) {
111 | // Return a random integer N in range `[start...stop] by step`
112 | if (stop === undefined) {
113 | return this.randbelow(start);
114 | } else if (!step) {
115 | return start + this.randbelow(stop - start);
116 | }
117 | return start + (step * this.randbelow(Math.floor((stop - start) / step)));
118 | }
119 | gauss(mu = 0, sigma = 1) {
120 | // Gaussian distribution. `mu` is the mean, and `sigma` is the standard
121 | // deviation. Notes:
122 | // * uses the "polar method"
123 | // * we generate pairs; keep one in a cache for next time
124 | let z = this.next_gauss;
125 | if (z != null) {
126 | this.next_gauss = null;
127 | } else {
128 | let s;
129 | let u;
130 | let v;
131 | while (!s || !(s < 1)) {
132 | u = (2 * this.random()) - 1;
133 | v = (2 * this.random()) - 1;
134 | s = (u * u) + (v * v);
135 | }
136 | const w = Math.sqrt((-2 * (Math.log(s))) / s);
137 | z = u * w; this.next_gauss = v * w;
138 | }
139 | return mu + (z * sigma); // Alias for the `gauss` function
140 | }
141 | uniform(a, b) {
142 | // Return a random floating point number N such that a <= N <= b for
143 | // a <= b and b <= N <= a for b < a.
144 | return a + (this.random() * (b - a));
145 | }
146 | }
147 |
--------------------------------------------------------------------------------
/test/__snapshots__/index.js.json:
--------------------------------------------------------------------------------
1 | {
2 | "upper: negative": "0.0000000",
3 | "upper: 0": "0.0000000",
4 | "upper: 1": "0.0000000",
5 | "upper: 1000000": "933394.0000000",
6 | "uniform -1, 1": "0.7803094",
7 | "uniform -100000, -1": "-10985.4183013",
8 | "uniform -1, -10": "-9.0113925",
9 | "uniform 5, 1": "1.4393811",
10 | "uniform 1, 1000000": "890154.8283796",
11 | "uniform 1, 1": "1.0000000",
12 | "quniform -1, 1, 0.1": "0.8000000",
13 | "quniform -100000, -1, -1": "-10985.0000000",
14 | "quniform -1, -10, 0.22222": "-9.1110200",
15 | "quniform 5, 1, -0.111": "1.4430000",
16 | "quniform 1, 1000000, 50": "890150.0000000",
17 | "quniform 1, 1, 0.001": "1.0000000",
18 | "loguniform -1, 1": "2.1821474",
19 | "loguniform -100000, -1": "0.0000000",
20 | "loguniform -1, -10": "0.0001220",
21 | "loguniform 5, 1": "4.2180845",
22 | "loguniform 1, 1": "2.7182818",
23 | "qloguniform -1, 1, 0.1": "2.2000000",
24 | "qloguniform -100000, -1, -1": "0.0000000",
25 | "qloguniform -1, -10, 0.2222": "0.0000000",
26 | "qloguniform 5, 1, -0.111": "4.2180000",
27 | "qloguniform 5, 1, 0.1": "4.2000000",
28 | "qloguniform 1, 1, 0.001": "2.7180000",
29 | "normal -1, 1": "-0.3037819",
30 | "normal -100000, -1": "-100000.6962181",
31 | "normal -1, -10": "-7.9621809",
32 | "normal 5, 1": "5.6962181",
33 | "normal 1, 1": "1.6962181",
34 | "qnormal -1, 1, 0.1": "-0.3000000",
35 | "qnormal -100000, -1, -1": "-100001.0000000",
36 | "qnormal -1, -10, 0.22222": "-7.9999200",
37 | "qnormal 5, 1, -0.111": "5.6610000",
38 | "qnormal 1, 1000000, 50": "50.0000000",
39 | "qnormal 1, 1, 0.001": "1.6960000",
40 | "lognormal -1, 1": "0.7380218",
41 | "lognormal -100000, -1": "0.0000000",
42 | "lognormal -1, -10": "0.0003484",
43 | "lognormal 5, 1": "297.7392463",
44 | "lognormal 1, 1": "5.4532845",
45 | "qlognormal -1, 1, 0.1": "0.7000000",
46 | "qlognormal -100000, -1, -1": "0.0000000",
47 | "qlognormal -1, -10, 0.22222": "0.0000000",
48 | "qlognormal 5, 1, -0.111": "297.7020000",
49 | "qlognormal 5, 1, 0.1": "297.7000000",
50 | "qlognormal 1, 1, 0.001": "5.4530000",
51 | "sample: array": "-3.6724889",
52 | "sample: depth": {
53 | "x": "1.3924362",
54 | "y": "0.1307073",
55 | "choice": null,
56 | "array": [
57 | -0.6553433999116457,
58 | 0.11927848398147867,
59 | false
60 | ],
61 | "obj": {
62 | "u": 2.47930839269838,
63 | "v": 1.703175088041835,
64 | "w": -1.4037662402075997
65 | }
66 | },
67 | "FMin for x^2 - x + 1": {},
68 | "Hyperparameters space": {
69 | "x": "-0.7975415",
70 | "y": "-0.7414791"
71 | },
72 | "choice as array space": {},
73 | "Deep learning space": {
74 | "learning_rate": "0.0000123",
75 | "use_double_q_learning": true,
76 | "layer1_size": "29.0000000",
77 | "layer2_size": "41.0000000",
78 | "layer3_size": "56.0000000",
79 | "future_discount_max": "0.6085554",
80 | "future_discount_increment": "0.0013308",
81 | "recall_memory_size": "8.0000000",
82 | "recall_memory_num_experiences_per_recall": "1275.0000000",
83 | "num_epochs": "2.0000000"
84 | }
85 | }
--------------------------------------------------------------------------------
/test/index.js:
--------------------------------------------------------------------------------
1 | import chai, { assert } from 'chai';
2 | import snapshots from 'chai-snapshot-tests';
3 | import * as hpjs from '../src';
4 | import { GridSearch } from '../src/search/grid';
5 |
6 | chai.use(snapshots(__filename));
7 |
8 | const seededSample = (space) => hpjs.sample.randomSample(space, { rng: new hpjs.RandomState(12345) });
9 |
10 | const floatSeededSample = (space) => hpjs.sample.randomSample(space, { rng: new hpjs.RandomState(12345) }).toFixed(7);
11 |
12 | const objectToFixed = (obj) => Object.keys(obj).reduce((final, key) => {
13 | const value = typeof obj[key] === 'number' ? obj[key].toFixed(7) : obj[key];
14 | return { ...final, [key]: value};
15 | }, {});
16 |
17 | const randFMinSeeded = async (opt, space) => {
18 | const trials = await hpjs.fmin(opt, space, hpjs.search.randomSearch, 100, { rng: new hpjs.RandomState(12345) });
19 | return objectToFixed(trials.argmin);
20 | }
21 |
22 |
23 | describe('hpjs.choice.', () => {
24 | it('is a string', () => {
25 | const val = seededSample(hpjs.choice(['cat', 'dog']));
26 | assert.typeOf(val, 'string');
27 | });
28 | it('is one of the elements', () => {
29 | const val = seededSample(hpjs.choice(['cat', 'dog']));
30 | assert(['cat', 'dog'].indexOf(val) >= 0, 'val was actually ' + val);
31 | });
32 | it('picks a number', () => {
33 | const val = seededSample(hpjs.choice([1, 2, 3, 4]));
34 | assert(val === 4, 'val was actually: ' + val);
35 | });
36 | });
37 |
38 | describe('hpjs.randint.', () => {
39 | it('in range [0,5)', () => {
40 | const val = seededSample(hpjs.randint(5));
41 | assert(val >= 0 && val < 5, `actual value ${val}`);
42 | });
43 | it('Snapshot tests', () => {
44 | assert.snapshot('upper: negative', floatSeededSample(hpjs.randint(-2)));
45 | assert.snapshot('upper: 0', floatSeededSample(hpjs.randint(0)));
46 | assert.snapshot('upper: 1', floatSeededSample(hpjs.randint(1)));
47 | assert.snapshot('upper: 1000000', floatSeededSample(hpjs.randint(1000000)));
48 | });
49 |
50 | });
51 |
52 | describe('hpjs.uniform.', () => {
53 | it('between 0 and 1', () => {
54 | const val = seededSample(hpjs.uniform(0, 1));
55 | assert(val >= 0 && val <= 1, `actual value ${val}`);
56 | });
57 | it('between -1 and 1', () => {
58 | const val = seededSample(hpjs.uniform(-1, 1));
59 | assert(val >= -1 && val <= 1, `actual value ${val}`);
60 | });
61 |
62 | it('Snapshot tests', () => {
63 | assert.snapshot('uniform -1, 1', floatSeededSample(hpjs.uniform(-1, 1)));
64 | assert.snapshot('uniform -100000, -1', floatSeededSample(hpjs.uniform(-100000, -1)));
65 | assert.snapshot('uniform -1, -10', floatSeededSample(hpjs.uniform(-1, -10)));
66 | assert.snapshot('uniform 5, 1', floatSeededSample(hpjs.uniform(5, 1)));
67 | assert.snapshot('uniform 1, 1000000', floatSeededSample(hpjs.uniform(1, 1000000)));
68 | assert.snapshot('uniform 1, 1', floatSeededSample(hpjs.uniform(1, 1)));
69 | });
70 | });
71 |
72 | describe('hpjs.quniform.', () => {
73 | it('between 0 and 1', () => {
74 | const val = seededSample(hpjs.quniform(0, 1, 0.2));
75 | assert(val >= 0 && val <= 1, `actual value ${val}`);
76 | });
77 | it('between -1 and 1, step 1', () => {
78 | const val = seededSample(hpjs.quniform(-1, 1, 1));
79 | assert(val === -1 || val === 0 || val === 1, `actual value ${val}`);
80 | });
81 | it('Snapshot tests', () => {
82 | assert.snapshot('quniform -1, 1, 0.1', floatSeededSample(hpjs.quniform(-1, 1, 0.1)));
83 | assert.snapshot('quniform -100000, -1, -1', floatSeededSample(hpjs.quniform(-100000, -1, -1)));
84 | assert.snapshot('quniform -1, -10, 0.22222', floatSeededSample(hpjs.quniform(-1, -10, 0.22222)));
85 | assert.snapshot('quniform 5, 1, -0.111', floatSeededSample(hpjs.quniform(5, 1, -0.111)));
86 | assert.snapshot('quniform 1, 1000000, 50', floatSeededSample(hpjs.quniform(1, 1000000, 50)));
87 | assert.snapshot('quniform 1, 1, 0.001', floatSeededSample(hpjs.quniform(1, 1, 0.001)));
88 | });
89 | });
90 |
91 | describe('hpjs.loguniform.', () => {
92 | it('between e^0 and e^1', () => {
93 | const low = 0;
94 | const high = 1;
95 | const val = seededSample(hpjs.loguniform(low, high));
96 | assert(val >= Math.exp(low) && val <= Math.exp(high), `actual value ${val}`);
97 | });
98 | it('Snapshot tests', () => {
99 | assert.snapshot('loguniform -1, 1', floatSeededSample(hpjs.loguniform(-1, 1)));
100 | assert.snapshot('loguniform -100000, -1', floatSeededSample(hpjs.loguniform(-100000, -1)));
101 | assert.snapshot('loguniform -1, -10', floatSeededSample(hpjs.loguniform(-1, -10)));
102 | assert.snapshot('loguniform 5, 1', floatSeededSample(hpjs.loguniform(5, 1)));
103 | assert.snapshot('loguniform 5, 1', floatSeededSample(hpjs.loguniform(5, 1)));
104 | assert.snapshot('loguniform 1, 1', floatSeededSample(hpjs.loguniform(1, 1)));
105 | });
106 | });
107 |
108 | describe('hpjs.qloguniform.', () => {
109 | it('e^0 and e^1', () => {
110 | const low = 0;
111 | const high = 1;
112 | const val = seededSample(hpjs.qloguniform(low, high, 0.2));
113 | assert(val >= Math.exp(low) && val <= Math.exp(high), `actual value ${val}`);
114 | });
115 | it('Snapshot tests', () => {
116 | assert.snapshot('qloguniform -1, 1, 0.1', floatSeededSample(hpjs.qloguniform(-1, 1, 0.1)));
117 | assert.snapshot('qloguniform -100000, -1, -1', floatSeededSample(hpjs.qloguniform(-100000, -1, -1)));
118 | assert.snapshot('qloguniform -1, -10, 0.2222', floatSeededSample(hpjs.qloguniform(-1, -10, 0.22222)));
119 | assert.snapshot('qloguniform 5, 1, -0.111', floatSeededSample(hpjs.qloguniform(5, 1, -0.111)));
120 | assert.snapshot('qloguniform 5, 1, 0.1', floatSeededSample(hpjs.qloguniform(5, 1, 0.1)));
121 | assert.snapshot('qloguniform 1, 1, 0.001', floatSeededSample(hpjs.qloguniform(1, 1, 0.001)));
122 | });
123 | });
124 |
125 | describe('hpjs.normal.', () => {
126 | it('a number', () => {
127 | const mu = -1;
128 | const sigma = 1;
129 | const val = seededSample(hpjs.normal(mu, sigma));
130 | assert(!isNaN(val), `actual value ${val}`);
131 | });
132 | it('within 3 standard deviations of mean', () => {
133 | const mu = 0;
134 | const sigma = 1;
135 | const val = seededSample(hpjs.normal(mu, sigma));
136 | assert(val >= mu - (3*sigma) && val <= mu + (3*sigma), `actual value ${val}`);
137 | });
138 | it('Snapshot tests', () => {
139 | assert.snapshot('normal -1, 1', floatSeededSample(hpjs.normal(-1, 1)));
140 | assert.snapshot('normal -100000, -1', floatSeededSample(hpjs.normal(-100000, -1)));
141 | assert.snapshot('normal -1, -10', floatSeededSample(hpjs.normal(-1, -10)));
142 | assert.snapshot('normal 5, 1', floatSeededSample(hpjs.normal(5, 1)));
143 | assert.snapshot('normal 5, 1', floatSeededSample(hpjs.normal(5, 1)));
144 | assert.snapshot('normal 1, 1', floatSeededSample(hpjs.normal(1, 1)));
145 | });
146 | });
147 |
148 | describe('hpjs.qnormal.', () => {
149 | it('a number', () => {
150 | const mu = -1;
151 | const sigma = 1;
152 | const val = seededSample(hpjs.normal(mu, sigma));
153 | assert(!isNaN(val), `actual value ${val}`);
154 | });
155 | it('within 3 standard deviations of mean', () => {
156 | const mu = 0;
157 | const sigma = 1;
158 | const val = seededSample(hpjs.qnormal(mu, sigma, 0.1));
159 | assert(val >= mu - (3*sigma) && val <= mu + (3*sigma), `actual value ${val}`);
160 | });
161 | it('Snapshot tests', () => {
162 | assert.snapshot('qnormal -1, 1, 0.1', floatSeededSample(hpjs.qnormal(-1, 1, 0.1)));
163 | assert.snapshot('qnormal -100000, -1, -1', floatSeededSample(hpjs.qnormal(-100000, -1, -1)));
164 | assert.snapshot('qnormal -1, -10, 0.22222', floatSeededSample(hpjs.qnormal(-1, -10, 0.22222)));
165 | assert.snapshot('qnormal 5, 1, -0.111', floatSeededSample(hpjs.qnormal(5, 1, -0.111)));
166 | assert.snapshot('qnormal 1, 1000000, 50', floatSeededSample(hpjs.qnormal(1, 100, 50)));
167 | assert.snapshot('qnormal 1, 1, 0.001', floatSeededSample(hpjs.qnormal(1, 1, 0.001)));
168 | });
169 | });
170 |
171 | describe('hpjs.lognormal.', () => {
172 | it('positive', () => {
173 | const mu = 0;
174 | const sigma = 1;
175 | const val = seededSample(hpjs.lognormal(mu, sigma));
176 | assert(val >= 0, `actual value ${val}`);
177 | });
178 | it('less ~e^3 from the mean, or less than ~3 standard deviations from it', () => {
179 | const mu = 0;
180 | const sigma = 1;
181 | const val = seededSample(hpjs.lognormal(mu, sigma));
182 | assert(val <= 50, `actual value ${val}`);
183 | });
184 | it('Snapshot tests', () => {
185 | assert.snapshot('lognormal -1, 1', floatSeededSample(hpjs.lognormal(-1, 1)));
186 | assert.snapshot('lognormal -100000, -1', floatSeededSample(hpjs.lognormal(-100000, -1)));
187 | assert.snapshot('lognormal -1, -10', floatSeededSample(hpjs.lognormal(-1, -10)));
188 | assert.snapshot('lognormal 5, 1', floatSeededSample(hpjs.lognormal(5, 1)));
189 | assert.snapshot('lognormal 5, 1', floatSeededSample(hpjs.lognormal(5, 1)));
190 | assert.snapshot('lognormal 1, 1', floatSeededSample(hpjs.lognormal(1, 1)));
191 | });
192 | });
193 |
194 | describe('hpjs.qlognormal.', () => {
195 | it('a number', () => {
196 | const mu = -1;
197 | const sigma = 1;
198 | const val = seededSample(hpjs.qlognormal(mu, sigma, 0.2));
199 | assert(!isNaN(val), `actual value ${val}`);
200 | });
201 | it('within 3 standard deviations of mean', () => {
202 | const mu = 0;
203 | const sigma = 1;
204 | const val = seededSample(hpjs.qlognormal(mu, sigma, 0.1));
205 | assert(val >= mu - (3*sigma) && val <= mu + (3*sigma), `actual value ${val}`);
206 | });
207 | it('Snapshot tests', () => {
208 | assert.snapshot('qlognormal -1, 1, 0.1', floatSeededSample(hpjs.qlognormal(-1, 1, 0.1)));
209 | assert.snapshot('qlognormal -100000, -1, -1', floatSeededSample(hpjs.qlognormal(-100000, -1, -1)));
210 | assert.snapshot('qlognormal -1, -10, 0.22222', floatSeededSample(hpjs.qlognormal(-1, -10, 0.22222)));
211 | assert.snapshot('qlognormal 5, 1, -0.111', floatSeededSample(hpjs.qlognormal(5, 1, -0.111)));
212 | assert.snapshot('qlognormal 5, 1, 0.1', floatSeededSample(hpjs.qlognormal(5, 1, 0.1)));
213 | assert.snapshot('qlognormal 1, 1, 0.001', floatSeededSample(hpjs.qlognormal(1, 1, 0.001)));
214 | });
215 | });
216 |
217 | describe('random sample', () => {
218 | it('Choice as array', () => {
219 | const space = hpjs.choice(
220 | [
221 | hpjs.lognormal(0, 1),
222 | hpjs.uniform(-10, 10)
223 | ]
224 | );
225 |
226 | assert.snapshot('sample: array', floatSeededSample(space));
227 | });
228 | it('more complex space with depth', () => {
229 | const space = {
230 | x: hpjs.normal(0, 2),
231 | y: hpjs.uniform(0, 1),
232 | choice: hpjs.choice([
233 | null, hpjs.uniform(0, 1),
234 | ]),
235 | array: [
236 | hpjs.normal(0, 2), hpjs.uniform(0, 3), hpjs.choice([false, true]),
237 | ],
238 | obj: {
239 | u: hpjs.uniform(0, 3),
240 | v: hpjs.uniform(0, 3),
241 | w: hpjs.uniform(-3, 0)
242 | }
243 | };
244 | assert.snapshot('sample: depth', objectToFixed(seededSample(space)));
245 | });
246 | });
247 | describe('grid search', () => {
248 | const gs = new GridSearch();
249 | it('choice', () => {
250 | const space = hpjs.choice(['cat', 'dog']);
251 | assert(gs.numSamples(space) === 2, `choice num samples ${gs.numSamples(space)}`);
252 | });
253 | it('randint', () => {
254 | const space = hpjs.randint(5);
255 | assert(gs.numSamples(space) === 5, `randint [0, 5) ${gs.numSamples(space)}`);
256 | });
257 | it('quniform', () => {
258 | let space = hpjs.quniform(0, 1, 0.1);
259 | assert(gs.numSamples(space) === 11, `quniform 0,1,0.1 ${gs.numSamples(space)}`);
260 | space = hpjs.quniform(-5, 5, 1);
261 | assert(gs.numSamples(space) === 11, `quniform -5,5,1 ${gs.numSamples(space)}`);
262 | space = hpjs.quniform(1, 10, 2);
263 | assert(gs.numSamples(space) === 5, `quniform 1,10,2 ${gs.numSamples(space)}`);
264 | });
265 | it('qloguniform', () => {
266 | let space = hpjs.qloguniform(0,1,0.1);
267 | assert(gs.numSamples(space) === 11, `qloguniform 0,1,0.1 ${gs.numSamples(space)}`);
268 | space = hpjs.qloguniform(-5,5,1);
269 | assert(gs.numSamples(space) === 11, `qloguniform -5,5,1 ${gs.numSamples(space)}`);
270 | space = hpjs.qloguniform(1,10,2);
271 | assert(gs.numSamples(space) === 5, `qloguniform 1,10,2 ${gs.numSamples(space)}`);
272 | });
273 |
274 | it('qnormal', () => {
275 | let space = hpjs.qnormal(0, 1, 0.1);
276 | assert(gs.numSamples(space) === 41, `qnormal 0,1,0.1 ${gs.numSamples(space)}`);
277 | space = hpjs.qnormal(-5,5,1);
278 | assert(gs.numSamples(space) === 21, `qnormal -5,5,1 ${gs.numSamples(space)}`);
279 | space = hpjs.qnormal(1,10,2);
280 | assert(gs.numSamples(space) === 21, `qnormal 1,10,2 ${gs.numSamples(space)}`);
281 |
282 | });
283 | it('uniform', () => {
284 | try {
285 | gs.numSamples(hpjs.uniform(0, 1));
286 | assert(false, 'hpjs.uniform not allowed for grid search');
287 | } catch (e) {
288 | assert(e.message === 'Can not evaluate length of non-discrete parameter "uniform"', `exception message ${e.message}`);
289 | }
290 | });
291 | it('loguniform', () => {
292 | try {
293 | gs.numSamples(hpjs.loguniform(0, 1));
294 | assert(false, 'hpjs.loguniform not allowed for grid search');
295 | } catch (e) {
296 | assert(e.message === 'Can not evaluate length of non-discrete parameter "loguniform"', `exception message ${e.message}`);
297 | }
298 | });
299 | it('normal', () => {
300 | try {
301 | gs.numSamples(hpjs.normal(0, 1));
302 | assert(false, 'hpjs.normal not allowed for grid search');
303 | } catch (e) {
304 | assert(e.message === 'Can not evaluate length of non-discrete parameter "normal"', `exception message ${e.message}`);
305 | }
306 | });
307 | it('lognormal', () => {
308 | try {
309 | gs.numSamples(hpjs.lognormal(0, 1));
310 | assert(false, 'hpjs.lognormal not allowed for grid search');
311 | } catch (e) {
312 | assert(e.message === 'Can not evaluate length of non-discrete parameter "lognormal"', `exception message ${e.message}`);
313 | }
314 | });
315 | it('choice grid search', () => {
316 | const space = hpjs.choice(
317 | [
318 | hpjs.qlognormal(0, 1, 1), //5
319 | hpjs.quniform(-10, 10, 1) //21
320 | ]
321 | );
322 | assert(gs.numSamples(space) === 26, `choice ${gs.numSamples(space)}`);
323 | });
324 | });
325 |
326 | describe('fmin + rand', () => {
327 | it('FMin for x^2 - x + 1', async () => {
328 | const space = hpjs.uniform(-5, 5);
329 | const opt = x => ((x ** 2) - (x + 1));
330 | assert.snapshot('FMin for x^2 - x + 1', await randFMinSeeded(opt, space));
331 | });
332 | it('Hyperparameters space', async () => {
333 | const space = {
334 | x: hpjs.uniform(-5, 5),
335 | y: hpjs.uniform(-5, 5)
336 | };
337 | const opt = ({ x, y }) => ((x ** 2) + (y ** 2));
338 | assert.snapshot('Hyperparameters space', await randFMinSeeded(opt, space));
339 | });
340 | it('Choice selection of expressions', async () => {
341 | const space = hpjs.choice([
342 | hpjs.lognormal(0, 1),
343 | hpjs.uniform(-10, 10)
344 | ]
345 | );
346 | const opt = ( x ) => (x ** 2);
347 | assert.snapshot('choice as array space', await randFMinSeeded(opt, space));
348 | });
349 | it('Deep learning space', async () => {
350 | const space = {
351 | // Learning rate should be between 0.00001 and 1
352 | learning_rate:
353 | hpjs.loguniform(Math.log(1e-5), Math.log(1)),
354 | use_double_q_learning:
355 | hpjs.choice([true, false]),
356 | layer1_size: hpjs.quniform(10, 100, 1),
357 | layer2_size: hpjs.quniform(10, 100, 1),
358 | layer3_size: hpjs.quniform(10, 100, 1),
359 | future_discount_max: hpjs.uniform(0.5, 0.99),
360 | future_discount_increment: hpjs.loguniform(Math.log(0.001), Math.log(0.1)),
361 | recall_memory_size: hpjs.quniform(1, 100, 1),
362 | recall_memory_num_experiences_per_recall: hpjs.quniform(10, 2000, 1),
363 | num_epochs: hpjs.quniform(1, 10, 1),
364 | };
365 |
366 | const opt = params => params.learning_rate ** 2;
367 |
368 | assert.snapshot('Deep learning space', await randFMinSeeded(opt, space));
369 | });
370 | });
371 |
372 |
--------------------------------------------------------------------------------