├── .babelrc ├── .editorconfig ├── .eslintignore ├── .eslintrc ├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── examples ├── react-sample │ ├── .babelrc │ ├── README.md │ ├── dist │ │ └── index.html │ ├── package.json │ ├── src │ │ ├── components │ │ │ ├── ExperimentsTable.js │ │ │ ├── Value.js │ │ │ └── ValuesRow.js │ │ ├── index.js │ │ ├── pages │ │ │ └── Home.js │ │ └── withRoot.js │ └── webpack.config.js └── tiny │ ├── README.md │ ├── index.html │ ├── index.js │ └── package.json ├── package.json ├── rollup.config.js ├── src ├── base │ ├── base.js │ └── fmin.js ├── index.js ├── search │ ├── grid.js │ └── random.js └── utils │ └── RandomState.js └── test ├── __snapshots__ └── index.js.json └── index.js /.babelrc: -------------------------------------------------------------------------------- 1 | { 2 | "presets": [ 3 | [ 4 | "env", 5 | { 6 | "targets": { 7 | "browsers": [ 8 | "ie >= 11" 9 | ] 10 | }, 11 | "exclude": ["transform-async-to-generator", "transform-regenerator"], 12 | "modules": false, 13 | "loose": true 14 | } 15 | ] 16 | , "stage-0" 17 | ], 18 | "plugins": [ 19 | "transform-object-rest-spread" 20 | ], 21 | "env": { 22 | "development": { 23 | "presets": ["env", "stage-0"] 24 | }, 25 | "commonjs": { 26 | "presets": [ 27 | [ 28 | "env", 29 | { 30 | "loose": true 31 | } 32 | ] 33 | ] 34 | } 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | end_of_line = lf 5 | charset = utf-8 6 | trim_trailing_whitespace = true 7 | insert_final_newline = true 8 | indent_style = space 9 | indent_size = 2 10 | 11 | [*.md] 12 | trim_trailing_whitespace = false 13 | 14 | [{package,bower}.json] 15 | indent_style = space 16 | indent_size = 2 17 | -------------------------------------------------------------------------------- /.eslintignore: 
-------------------------------------------------------------------------------- 1 | node_modules/** 2 | coverage/** 3 | test/** 4 | lib/** 5 | .vscode/** 6 | src/random.js 7 | -------------------------------------------------------------------------------- /.eslintrc: -------------------------------------------------------------------------------- 1 | { 2 | "parser": "babel-eslint", 3 | "extends": "airbnb", 4 | "env": { 5 | "mocha": true 6 | }, 7 | "rules": { 8 | "comma-dangle": ["error", "only-multiline"] 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Logs 2 | logs 3 | *.log 4 | npm-debug.log* 5 | 6 | # Runtime data 7 | pids 8 | *.pid 9 | *.seed 10 | 11 | # Directory for instrumented libs generated by jscoverage/JSCover 12 | lib-cov 13 | 14 | # Coverage directory used by tools like istanbul 15 | coverage 16 | 17 | # nyc test coverage 18 | .nyc_output 19 | 20 | # Compiled binary addons (http://nodejs.org/api/addons.html) 21 | build/Release 22 | 23 | # Dependency directories 24 | node_modules 25 | jspm_packages 26 | 27 | # Optional npm cache directory 28 | .npm 29 | 30 | # Optional REPL history 31 | .node_repl_history 32 | 33 | # Editors 34 | .idea 35 | 36 | # Lib 37 | lib 38 | 39 | # npm package lock 40 | package-lock.json 41 | yarn.lock 42 | 43 | others 44 | .DS_Store 45 | .history -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: node_js 2 | node_js: 3 | - "6" 4 | script: 5 | - npm run test 6 | branches: 7 | only: 8 | - master 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | 
http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 
175 | 176 | END OF TERMS AND CONDITIONS 177 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ES6 hyperparameters optimization 2 | 3 | [![Build Status](https://travis-ci.org/atanasster/hyperparameters.svg?branch=master)](https://travis-ci.org/atanasster/hyperparameters) [![dependencies Status](https://david-dm.org/atanasster/hyperjs/status.svg)](https://david-dm.org/atanasster/hyperjs) [![devDependencies Status](https://david-dm.org/atanasster/hyperjs/dev-status.svg)](https://david-dm.org/atanasster/hyperjs?type=dev) [![License: Apache 2.0](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 4 | 5 | :warning: Early version, subject to change. 6 | 7 | 8 | 9 | ## Features 10 | * **written in javascript** - Use with tensorflow.js as a replacement for your python hyperparameters library 11 | * **use from cdn or npm** - Link hpjs in your html file from a cdn, or install in your project with npm 12 | * **versatile** - Utilize multiple parameters and multiple search algorithms (grid search, random, bayesian) 13 | 14 | 15 | 16 | ## Installation 17 | 18 | ``` 19 | $ npm install hyperparameters 20 | ``` 21 | 22 | 23 | ## Parameter Expressions 24 | 25 | ``` 26 | import * as hpjs from 'hyperparameters'; 27 | ``` 28 | 29 | ### hpjs.choice(options) 30 | 31 | - Randomly returns one of the options 32 | 33 | ### hpjs.randint(upper) 34 | 35 | - Returns a random integer in the range [0, upper) 36 | 37 | ### hpjs.uniform(low, high) 38 | 39 | - Returns a single value uniformly between `low` and `high` i.e.
any value between `low` and `high` has an equal probability of being selected 40 | 41 | ### hpjs.quniform(low, high, q) 42 | 43 | - Returns a quantized value of `hpjs.uniform` calculated as `round(uniform(low, high) / q) * q` 44 | 45 | ### hpjs.loguniform(low, high) 46 | 47 | - Returns a value `exp(uniform(low, high))` so the logarithm of the return value is uniformly distributed. 48 | 49 | ### hpjs.qloguniform(low, high, q) 50 | 51 | - Returns a value `round(exp(uniform(low, high)) / q) * q` 52 | 53 | ### hpjs.normal(mu, sigma) 54 | 55 | - Returns a real number that's normally distributed with mean mu and standard deviation sigma 56 | 57 | ### hpjs.qnormal(mu, sigma, q) 58 | 59 | - Returns a value `round(normal(mu, sigma) / q) * q` 60 | 61 | ### hpjs.lognormal(mu, sigma) 62 | 63 | - Returns a value `exp(normal(mu, sigma))` 64 | 65 | ### hpjs.qlognormal(mu, sigma, q) 66 | 67 | - Returns a value `round(exp(normal(mu, sigma)) / q) * q` 68 | 69 | 70 | 71 | ## Random number generator 72 | 73 | ``` 74 | import { RandomState } from 'hyperparameters'; 75 | ``` 76 | 77 | **example:** 78 | ``` 79 | const rng = new RandomState(12345); 80 | console.log(rng.randrange(0, 5, 0.5)); 81 | 82 | ``` 83 | 84 | 85 | ## Spaces 86 | 87 | ``` 88 | import { sample } from 'hyperparameters'; 89 | ``` 90 | 91 | **example:** 92 | ``` 93 | import * as hpjs from 'hyperparameters'; 94 | 95 | const space = { 96 | x: hpjs.normal(0, 2), 97 | y: hpjs.uniform(0, 1), 98 | choice: hpjs.choice([ 99 | undefined, hpjs.uniform(0, 1), 100 | ]), 101 | array: [ 102 | hpjs.normal(0, 2), hpjs.uniform(0, 3), hpjs.choice([false, true]), 103 | ], 104 | obj: { 105 | u: hpjs.uniform(0, 3), 106 | v: hpjs.uniform(0, 3), 107 | w: hpjs.uniform(-3, 0) 108 | } 109 | }; 110 | 111 | console.log(hpjs.sample.randomSample(space)); 112 | 113 | ``` 114 | ## fmin - find the arguments that minimize a function 115 | 116 | ``` 117 | import * as hpjs from 'hyperparameters'; 118 | const trials =
hpjs.fmin(optimizationFunction, space, estimator, max_estimates, options); 119 | ``` 120 | 121 | **example:** 122 | ``` 123 | import * as hpjs from 'hyperparameters'; 124 | 125 | const fn = x => ((x ** 2) - (x + 1)); 126 | const space = hpjs.uniform(-5, 5); 127 | hpjs.fmin(fn, space, hpjs.search.randomSearch, 1000, { rng: new hpjs.RandomState(123456) }) 128 | .then(trials => console.log(trials.argmin)); 129 | ``` 130 | ## Getting started with tensorflow.js 131 | 132 | ### 1. [include javascript file](https://github.com/atanasster/hyperparameters/tree/master/examples/tiny) 133 | 134 | * include (latest) version from cdn 135 | 136 | ` 14 | 15 | 16 | -------------------------------------------------------------------------------- /examples/react-sample/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "react-sample", 3 | "version": "1.0.0", 4 | "description": "", 5 | "main": "index.js", 6 | "scripts": { 7 | "start": "webpack-dev-server --config ./webpack.config.js --mode development", 8 | "test": "echo \"Error: no test specified\" && exit 1" 9 | }, 10 | "keywords": [], 11 | "author": "", 12 | "license": "ISC", 13 | "devDependencies": { 14 | "babel-core": "^6.26.3", 15 | "babel-loader": "^7.1.4", 16 | "babel-polyfill": "^6.26.0", 17 | "babel-preset-env": "^1.7.0", 18 | "babel-preset-react": "^6.24.1", 19 | "babel-preset-stage-0": "^6.24.1", 20 | "react-hot-loader": "^4.3.2", 21 | "webpack": "^4.12.0", 22 | "webpack-cli": "^3.0.6", 23 | "webpack-dev-server": "^3.1.4" 24 | }, 25 | "dependencies": { 26 | "@material-ui/core": "latest", 27 | "@tensorflow/tfjs": "^0.11.6", 28 | "prop-types": "latest", 29 | "react": "^16.4.1", 30 | "react-dom": "^16.4.1" 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /examples/react-sample/src/components/ExperimentsTable.js: -------------------------------------------------------------------------------- 1 | import React from
'react'; 2 | import PropTypes from 'prop-types'; 3 | import { withStyles } from '@material-ui/core/styles'; 4 | import Table from '@material-ui/core/Table'; 5 | import TableBody from '@material-ui/core/TableBody'; 6 | import TableCell from '@material-ui/core/TableCell'; 7 | import TableHead from '@material-ui/core/TableHead'; 8 | import TableRow from '@material-ui/core/TableRow'; 9 | 10 | const CustomTableCell = withStyles(theme => ({ 11 | head: { 12 | backgroundColor: theme.palette.common.black, 13 | color: theme.palette.common.white, 14 | fontSize: 14, 15 | }, 16 | body: { 17 | fontSize: 16, 18 | }, 19 | }))(TableCell); 20 | 21 | const STATES_MAP = ['new', 'running', 'done', 'error']; 22 | 23 | export const formatTraingTime = (date, locale = 'en-us') => ( 24 | date ? (new Date(date)).toLocaleDateString(locale, { 25 | month: '2-digit', 26 | day: '2-digit', 27 | hour: '2-digit', 28 | minute: '2-digit', 29 | }) : undefined 30 | ); 31 | 32 | export const periodToTime = (duration) => { 33 | if (!duration) { 34 | return { 35 | time: 0, 36 | units: 'ms', 37 | }; 38 | } 39 | const hours = (duration / (1000 * 60 * 60)).toFixed(0); 40 | const minutes = (duration / (1000 * 60)).toFixed(0); 41 | if (hours > 0) { 42 | return { 43 | time: `${hours}:${minutes}`, 44 | units: 'hrs', 45 | }; 46 | } 47 | const seconds = (duration / 1000).toFixed(0); 48 | if (minutes > 0) { 49 | return { 50 | time: `${minutes}:${seconds}`, 51 | units: 'min', 52 | }; 53 | } 54 | const milliseconds = (duration).toFixed(0); 55 | if (seconds > 0) { 56 | return { 57 | time: `${seconds}:${milliseconds}`, 58 | units: 'sec', 59 | }; 60 | } 61 | return { 62 | time: `${milliseconds}`, 63 | units: 'ms', 64 | }; 65 | }; 66 | 67 | const ExperimentsTable = ({ classes, experiments }) => ( 68 | 69 | 70 | 71 | # 72 | state 73 | start 74 | duration 75 | arguments 76 | results 77 | 78 | 79 | optimizer 80 | epochs 81 | status 82 | loss 83 | 84 | 85 | 86 | 87 | {experiments.map((exp) => { 88 | const duration = 
periodToTime(exp.refresh_time - exp.book_time); 89 | return ( 90 | 91 | 92 | {exp.id} 93 | 94 | {STATES_MAP[exp.state]} 95 | {formatTraingTime(exp.book_time)} 96 | {`${duration.time} ${duration.units}`} 97 | {exp.args.optimizer} 98 | {exp.args.epochs} 99 | {exp.result.status} 100 | {exp.result.loss.toFixed(5)} 101 | 102 | ); 103 | })} 104 | 105 |
106 | ); 107 | 108 | ExperimentsTable.propTypes = { 109 | classes: PropTypes.object.isRequired, 110 | experiments: PropTypes.array.isRequired, 111 | } 112 | export default ExperimentsTable; 113 | -------------------------------------------------------------------------------- /examples/react-sample/src/components/Value.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import Grid from '@material-ui/core/Grid'; 3 | import Typography from '@material-ui/core/Typography'; 4 | 5 | export default ({ value, label, size }) => ( 6 | 7 | 8 | 9 | {value} 10 | 11 | 12 | {label} 13 | 14 | 15 | 16 | ) 17 | -------------------------------------------------------------------------------- /examples/react-sample/src/components/ValuesRow.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import Grid from '@material-ui/core/Grid'; 3 | import Paper from '@material-ui/core/Paper'; 4 | import Typography from '@material-ui/core/Typography'; 5 | 6 | 7 | export default ({ classes, children, title }) => ( 8 | 9 | 10 | 11 | {title} 12 | 13 |
14 | 15 | {children} 16 | 17 |
18 |
19 | ); 20 | -------------------------------------------------------------------------------- /examples/react-sample/src/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import ReactDOM from 'react-dom'; 3 | import Home from './pages/Home'; 4 | 5 | ReactDOM.render( 6 | , 7 | document.getElementById('app') 8 | ); 9 | 10 | module.hot.accept(); 11 | -------------------------------------------------------------------------------- /examples/react-sample/src/pages/Home.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import PropTypes from 'prop-types'; 3 | import Grid from '@material-ui/core/Grid'; 4 | import Button from '@material-ui/core/Button'; 5 | import Typography from '@material-ui/core/Typography'; 6 | import * as tf from '@tensorflow/tfjs'; 7 | import * as hpjs from '../../../../src/index'; 8 | 9 | import withRoot from '../withRoot'; 10 | import ValuesRow from '../components/ValuesRow'; 11 | import Value from '../components/Value'; 12 | import ExperimentsTable from '../components/ExperimentsTable'; 13 | 14 | const trainModel = async ({ optimizer, epochs, onEpochEnd }, { xs, ys }) => { 15 | // Create a simple model. 16 | const model = tf.sequential(); 17 | model.add(tf.layers.dense({ units: 1, inputShape: xs.shape.slice(1) })); 18 | // Prepare the model for training: Specify the loss and the optimizer. 19 | model.compile({ 20 | loss: 'meanSquaredError', 21 | optimizer 22 | }); 23 | // Train the model using the data. 24 | const h = await model.fit(xs, ys, { epochs, callbacks: { onEpochEnd } }); 25 | return { model, loss: h.history.loss[h.history.loss.length - 1] }; 26 | }; 27 | 28 | // Generate some synthetic data for training. 
(y = 2x - 1) and pass to fmin as parameters 29 | const createData = () => ({ 30 | xs: tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]), 31 | ys: tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]), 32 | }); 33 | 34 | class Home extends React.Component { 35 | state = { 36 | experiments: [], 37 | epoch: undefined, 38 | experimentBegin: undefined, 39 | experimentEnd: undefined, 40 | }; 41 | 42 | onEpochEnd = async (epoch, logs) => { 43 | this.setState({ epoch: { epoch, logs } }); 44 | await tf.nextFrame(); 45 | }; 46 | 47 | onExperimentBegin = async (id, experiment) => { 48 | this.setState({ experimentBegin: { id, experiment } }); 49 | await tf.nextFrame(); 50 | }; 51 | onExperimentEnd = async (id, experiment) => { 52 | this.setState({ 53 | experimentEnd: { id, experiment }, 54 | experiments: [...this.state.experiments, experiment] 55 | }); 56 | await tf.nextFrame(); 57 | }; 58 | 59 | onRunClick = async () => { 60 | this.setState({ 61 | epoch: undefined, 62 | experimentBegin: undefined, 63 | experimentEnd: undefined, 64 | experiments: [], 65 | best: undefined, 66 | }); 67 | // fmin optmization function, retuns the loss and a STATUS_OK 68 | async function modelOpt({ optimizer, epochs }, { xs, ys, onEpochEnd }) { 69 | const { loss } = await trainModel({ optimizer, epochs, onEpochEnd }, { xs, ys }); 70 | return { loss, status: hpjs.STATUS_OK }; 71 | } 72 | 73 | // hyperparameters search space 74 | // optmizer is a choice field 75 | // epochs ia an integer value from 10 to 250 with a step of 5 76 | const space = { 77 | optimizer: hpjs.choice(['sgd', 'adam', 'adagrad', 'rmsprop']), 78 | epochs: hpjs.quniform(10, 30, 10), 79 | }; 80 | tf.ENV.engine.startScope(); 81 | // Generate some synthetic data for training. 
(y = 2x - 1) and pass to fmin as parameters 82 | // data will be passed as a parameters to the fmin 83 | const { xs, ys } = createData(); 84 | const experiments = await hpjs.fmin( 85 | modelOpt, space, hpjs.search.randomSearch, 10, 86 | { 87 | rng: new hpjs.RandomState(654321), 88 | xs, 89 | ys, 90 | onEpochEnd: this.onEpochEnd, 91 | callbacks: { 92 | onExperimentBegin: this.onExperimentBegin, 93 | onExperimentEnd: this.onExperimentEnd 94 | } 95 | } 96 | ); 97 | this.setState({ best: experiments.argmin }); 98 | tf.ENV.engine.endScope(); 99 | }; 100 | 101 | onPredictClick = async () => { 102 | const { best } = this.state; 103 | tf.ENV.engine.startScope(); 104 | const { xs, ys } = createData(); 105 | const { model } = await trainModel(best, { xs, ys }); 106 | const prediction = model.predict(tf.tensor2d([20], [1, 1])); 107 | this.setState({ prediction: prediction.dataSync() }); 108 | tf.ENV.engine.endScope(); 109 | }; 110 | 111 | render() { 112 | const { classes } = this.props; 113 | const { 114 | epoch, experimentBegin, experimentEnd, experiments, best, prediction 115 | } = this.state; 116 | 117 | const spacing = 24; 118 | return ( 119 |
120 | 121 | TensorFlow "tiny" 122 | 123 | 124 | {experimentBegin && ( 125 | 126 | 127 | {Object.keys(experimentBegin.experiment.args).map(key => ( 128 | 129 | ))} 130 | {epoch && ( 131 | 132 | 133 | 134 | 135 | )} 136 | 137 | 138 | )} 139 | {experimentEnd && ( 140 | 141 | 142 | {Object.keys(experimentEnd.experiment.args).map(key => ( 143 | 144 | ))} 145 | 146 | 147 | 148 | )} 149 | {best && ( 150 | 151 | {Object.keys(best).map(key => ( 152 | 153 | ))} 154 | 155 | 158 | 159 | {prediction && ( 160 | 161 | )} 162 | 163 | 164 | )} 165 | 166 | 169 | 170 | 171 | 172 | Experiments 173 | 174 | 175 | 176 | 177 |
178 | ); 179 | } 180 | } 181 | 182 | Home.propTypes = { 183 | classes: PropTypes.object.isRequired, 184 | }; 185 | 186 | export default withRoot(Home); 187 | -------------------------------------------------------------------------------- /examples/react-sample/src/withRoot.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { MuiThemeProvider, createMuiTheme } from '@material-ui/core/styles'; 3 | import orange from '@material-ui/core/colors/orange'; 4 | import blueGrey from '@material-ui/core/colors/blueGrey'; 5 | import CssBaseline from '@material-ui/core/CssBaseline'; 6 | import { withStyles } from "@material-ui/core/styles/index"; 7 | import AppBar from '@material-ui/core/AppBar'; 8 | import Toolbar from '@material-ui/core/Toolbar'; 9 | import Typography from '@material-ui/core/Typography'; 10 | 11 | 12 | const theme = createMuiTheme({ 13 | palette: { 14 | primary: { 15 | light: orange[300], 16 | main: orange[800], 17 | dark: orange[900], 18 | }, 19 | secondary: { 20 | light: blueGrey[300], 21 | main: blueGrey[500], 22 | dark: blueGrey[700], 23 | }, 24 | }, 25 | }); 26 | 27 | const styles = theme => ({ 28 | root: { 29 | textAlign: 'center', 30 | flexGrow: 1, 31 | padding: theme.spacing.unit * 10, 32 | }, 33 | inforow: { 34 | padding: theme.spacing.unit, 35 | }, 36 | }); 37 | 38 | function withRoot(Component) { 39 | function WithRoot(props) { 40 | return ( 41 | 42 | 43 | 44 | 45 | 46 | hpjs 47 | 48 | 49 | 50 |
51 | 52 |
53 |
54 | ); 55 | } 56 | 57 | return withStyles(styles)(WithRoot); 58 | } 59 | 60 | export default withRoot; 61 | -------------------------------------------------------------------------------- /examples/react-sample/webpack.config.js: -------------------------------------------------------------------------------- 1 | const webpack = require('webpack'); 2 | 3 | module.exports = { 4 | entry: [ 5 | 'babel-polyfill', 6 | 'react-hot-loader/patch', 7 | './src/index.js' 8 | ], 9 | module: { 10 | rules: [ 11 | { 12 | test: /\.(js|jsx)$/, 13 | exclude: /node_modules/, 14 | use: ['babel-loader'] 15 | } 16 | ] 17 | }, 18 | resolve: { 19 | extensions: ['*', '.js', '.jsx'] 20 | }, 21 | output: { 22 | path: __dirname + '/dist', 23 | publicPath: '/', 24 | filename: 'bundle.js' 25 | }, 26 | plugins: [ 27 | new webpack.HotModuleReplacementPlugin() 28 | ], 29 | devServer: { 30 | contentBase: './dist', 31 | hot: true 32 | } 33 | }; 34 | -------------------------------------------------------------------------------- /examples/tiny/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow.js Tiny Example with HyperParameters.js 2 | 3 | This minimal example loads tfjs and hpjs from a CDN, builds and trains a minimal model, 4 | and finds the optimal optimizer and number of epochs. 5 | 6 | ## Getting started 7 | 8 | * include (latest) version from cdn 9 | 10 | ` 9 | 10 | 11 | 12 |

Tiny TFJS + HPJS example

13 |
14 |       const model = tf.sequential();
15 |       model.add(tf.layers.dense({ units: 1, inputShape: [1] }));
16 |       model.compile({
17 |         loss: 'meanSquaredError',
18 |         optimizer
19 |       });
20 |       await model.fit(xs, ys, { epochs });
21 |     
22 |

Best optimizer:

23 |
24 |

Epochs:

25 |
26 | 27 |

Prediction:

28 |
29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /examples/tiny/index.js: -------------------------------------------------------------------------------- 1 | /* eslint-disable no-undef */ 2 | 3 | // Starting with the tensorflow.js Tiny example, this sampple illustrates how little code is 4 | // necessary to build / train / optimize / predict from a model 5 | // in TensorFlow.js + Hyperparameters.js 6 | 7 | // function to fit a model, given an optimizer and number of epochs 8 | // returns the model and the final loss 9 | const trainModel = async ({ optimizer, epochs }, { xs, ys }) => { 10 | // Create a simple model. 11 | const model = tf.sequential(); 12 | model.add(tf.layers.dense({ units: 1, inputShape: [1] })); 13 | // Prepare the model for training: Specify the loss and the optimizer. 14 | model.compile({ 15 | loss: 'meanSquaredError', 16 | optimizer 17 | }); 18 | // Train the model using the data. 19 | const h = await model.fit(xs, ys, { epochs }); 20 | return { model, loss: h.history.loss[h.history.loss.length - 1] }; 21 | }; 22 | 23 | // fmin optmization function, retuns the loss and a STATUS_OK 24 | const modelOpt = async ({ optimizer, epochs }, { xs, ys }) => { 25 | const { loss } = await trainModel({ optimizer, epochs }, { xs, ys }); 26 | return { loss, status: hpjs.STATUS_OK }; 27 | }; 28 | 29 | async function launchHPJS() { 30 | // hyperparameters search space 31 | // optmizer is a choice field 32 | // epochs ia an integer value from 10 to 250 with a step of 5 33 | const space = { 34 | optimizer: hpjs.choice(['sgd', 'adam', 'adagrad', 'rmsprop']), 35 | epochs: hpjs.quniform(50, 250, 50), 36 | }; 37 | // Generate some synthetic data for training. 
(y = 2x - 1) and pass to fmin as parameters 38 | // data will be passed as a parameters to the fmin 39 | const xs = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]); 40 | const ys = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]); 41 | 42 | const trials = await hpjs.fmin( 43 | modelOpt, space, hpjs.search.randomSearch, 10, 44 | { rng: new hpjs.RandomState(654321), xs, ys } 45 | ); 46 | const opt = trials.argmin; 47 | document.getElementById('optimizer_best').innerText = opt.optimizer; 48 | document.getElementById('epochs').innerText = opt.epochs; 49 | const { model } = await trainModel(opt, { xs, ys }); 50 | const prediction = model.predict(tf.tensor2d([20], [1, 1])); 51 | document.getElementById('prediction').innerText += prediction.dataSync(); 52 | } 53 | 54 | launchHPJS(); 55 | -------------------------------------------------------------------------------- /examples/tiny/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tfjs-hpjs-tiny", 3 | "version": "0.1.0", 4 | "description": "", 5 | "main": "index.js", 6 | "license": "Apache-2.0", 7 | "private": true, 8 | "engines": { 9 | "node": ">=8.9.0" 10 | }, 11 | "dependencies": { }, 12 | "scripts": { 13 | "watch": "npm run build && node_modules/http-server/bin/http-server dist -p 1234 ", 14 | "build": "mkdir -p dist/ && cp index.html dist/ && cp index.js dist/" 15 | }, 16 | "devDependencies": { 17 | "clang-format": "~1.2.2", 18 | "http-server": "~0.10.0" 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "hyperparameters", 3 | "version": "0.25.6", 4 | "description": "javascript hyperparameters search", 5 | "main": "lib/hyperparameters.js", 6 | "unpkg": "dist/hyperparameters.js", 7 | "module": "es/hyperparameters.js", 8 | "scripts": { 9 | "test": "npm run lint && npm run cover", 10 | "test:prod": "cross-env 
BABEL_ENV=production npm run test", 11 | "test:only": "mocha $NODE_DEBUG_OPTION --require babel-core/register --require babel-polyfill --recursive", 12 | "test:watch": "npm test -- --watch", 13 | "cover": "istanbul cover _mocha -- --require babel-core/register --require babel-polyfill --recursive", 14 | "lint": "eslint src test", 15 | "clean": "rimraf lib dist es coverage", 16 | "build:commonjs": "cross-env NODE_ENV=cjs rollup -c -o lib/hyperparameters.js", 17 | "build:es": "cross-env BABEL_ENV=es NODE_ENV=es rollup -c -o es/hyperparameters.js", 18 | "build:umd": "cross-env BABEL_ENV=es NODE_ENV=development rollup -c -o dist/hyperparameters.js", 19 | "build:umd:min": "cross-env BABEL_ENV=es NODE_ENV=production rollup -c -o dist/hyperparameters.min.js", 20 | "build": "npm run build:commonjs && npm run build:es && npm run build:umd && npm run build:umd:min", 21 | "prepare": "npm run clean && npm test && npm run build", 22 | "pub": "npm run build && npm publish" 23 | }, 24 | "files": [ 25 | "dist", 26 | "lib", 27 | "es", 28 | "src" 29 | ], 30 | "repository": { 31 | "type": "git", 32 | "url": "git+https://github.com/atanasster/hyperparameters.git" 33 | }, 34 | "keywords": [ 35 | "hyperparameters", 36 | "hyperopt", 37 | "tensorflow", 38 | "es6", 39 | "tfjs", 40 | "javascript" 41 | ], 42 | "contributors": [ 43 | { 44 | "name": "Atanas Stoyanov", 45 | "email": "atanasster@gmail.com", 46 | "url": "https://github.com/atanasster" 47 | }, 48 | { 49 | "name": "Martin Stoyanov", 50 | "email": "martin.a.stoyanov@gmail.com", 51 | "url": "https://github.com/martin-stoyanov" 52 | } 53 | ], 54 | "license": "Apache-2.0", 55 | "bugs": { 56 | "url": "https://github.com/atanasster/hyperparameters/issues" 57 | }, 58 | "homepage": "https://github.com/atanasster/hyperparameters#readme", 59 | "devDependencies": { 60 | "babel-cli": "^6.26.0", 61 | "babel-eslint": "^8.2.1", 62 | "babel-plugin-add-module-exports": "^0.2.1", 63 | "babel-plugin-external-helpers": "^6.22.0", 64 | 
"babel-polyfill": "^6.26.0", 65 | "babel-preset-env": "^1.7.0", 66 | "babel-preset-minify": "^0.3.0", 67 | "babel-preset-stage-0": "^6.24.1", 68 | "chai": "^4.1.2", 69 | "chai-snapshot-tests": "^0.6.0", 70 | "cross-env": "^5.1.3", 71 | "eslint": "^4.16.0", 72 | "eslint-config-airbnb": "^16.1.0", 73 | "eslint-plugin-import": "^2.7.0", 74 | "eslint-plugin-jsx-a11y": "^6.0.2", 75 | "eslint-plugin-react": "^7.4.0", 76 | "istanbul": "^1.0.0-alpha", 77 | "mocha": "^5.0.0", 78 | "pre-commit": "^1.2.2", 79 | "rimraf": "^2.6.2", 80 | "rollup": "^0.60.1", 81 | "rollup-plugin-babel": "^3.0.4", 82 | "rollup-plugin-commonjs": "^9.1.3", 83 | "rollup-plugin-node-resolve": "^3.3.0", 84 | "rollup-plugin-replace": "^2.0.0", 85 | "rollup-plugin-uglify": "^3.0.0", 86 | "rollup-watch": "^4.3.1" 87 | }, 88 | "pre-commit": [ 89 | "test" 90 | ] 91 | } 92 | -------------------------------------------------------------------------------- /rollup.config.js: -------------------------------------------------------------------------------- 1 | import nodeResolve from 'rollup-plugin-node-resolve'; 2 | import babel from 'rollup-plugin-babel'; 3 | import replace from 'rollup-plugin-replace'; 4 | import uglify from 'rollup-plugin-uglify'; 5 | 6 | const env = process.env.NODE_ENV; 7 | const config = { 8 | input: 'src/index.js', 9 | plugins: [] 10 | }; 11 | 12 | if (env === 'es' || env === 'cjs') { 13 | config.output = { format: env, indent: false }; 14 | config.external = ['symbol-observable']; 15 | config.plugins.push(babel({ 16 | plugins: ['external-helpers'], 17 | })); 18 | } 19 | 20 | if (env === 'development' || env === 'production') { 21 | config.output = { format: 'umd', name: 'hpjs', indent: false }; 22 | config.plugins.push( 23 | nodeResolve({ 24 | jsnext: true 25 | }), 26 | babel({ 27 | exclude: 'node_modules/**', 28 | plugins: ['external-helpers'], 29 | }), 30 | replace({ 31 | 'process.env.NODE_ENV': JSON.stringify(env) 32 | }) 33 | ); 34 | } 35 | 36 | if (env === 'production') { 37 | 
config.plugins.push(uglify({ 38 | compress: { 39 | pure_getters: true, 40 | unsafe: true, 41 | unsafe_comps: true, 42 | warnings: false 43 | } 44 | })); 45 | } 46 | 47 | export default config; 48 | -------------------------------------------------------------------------------- /src/base/base.js: -------------------------------------------------------------------------------- 1 | import RandomState from '../utils/RandomState'; 2 | 3 | 4 | export default class BaseSpace { 5 | eval = (expr, { rng: rState }) => { 6 | if (expr === undefined || expr === null) { 7 | return expr; 8 | } 9 | let rng = rState; 10 | if (rng === undefined) { 11 | rng = new RandomState(); 12 | } 13 | const { name, ...rest } = expr; 14 | const space = this[name]; 15 | if (typeof space !== 'function') { 16 | if (Array.isArray(expr)) { 17 | return expr.map(item => this.eval(item, { rng })); 18 | } 19 | if (typeof expr === 'object') { 20 | return Object.keys(expr) 21 | .reduce((r, key) => ({ ...r, [key]: this.eval(expr[key], { rng }) }), {}); 22 | } 23 | return expr; 24 | } 25 | return space(rest, rng); 26 | }; 27 | } 28 | 29 | export const STATUS_NEW = 'new'; 30 | export const STATUS_RUNNING = 'running'; 31 | export const STATUS_SUSPENDED = 'suspended'; 32 | export const STATUS_OK = 'ok'; 33 | export const STATUS_FAIL = 'fail'; 34 | export const STATUS_STRINGS = [ 35 | 'new', // computations have not started 36 | 'running', // computations are in prog 37 | 'suspended', // computations have been suspended, job is not finished 38 | 'ok', // computations are finished, terminated normally 39 | 'fail']; // computations are finished, terminated with error 40 | 41 | 42 | // -- named constants for job execution pipeline 43 | export const JOB_STATE_NEW = 0; 44 | export const JOB_STATE_RUNNING = 1; 45 | export const JOB_STATE_DONE = 2; 46 | export const JOB_STATE_ERROR = 3; 47 | export const JOB_STATES = [ 48 | JOB_STATE_NEW, 49 | JOB_STATE_RUNNING, 50 | JOB_STATE_DONE, 51 | JOB_STATE_ERROR]; 52 | 53 | 54 | 
export const TRIAL_KEYS = [ 55 | 'id', 56 | 'result', 57 | 'args', 58 | 'state', 59 | 'book_time', 60 | 'refresh_time', 61 | ]; 62 | 63 | export const range = (start, end) => Array.from({ length: (end - start) }, (v, k) => k + start); 64 | 65 | export class Trials { 66 | constructor(expKey = null, refresh = true) { 67 | this.ids = []; 68 | this.dynamicTrials = []; 69 | this.trials = []; 70 | this.expKey = expKey; 71 | if (refresh) { 72 | this.refresh(); 73 | } 74 | } 75 | 76 | get length() { 77 | return this.trials.length; 78 | } 79 | 80 | refresh = () => { 81 | if (this.expKey === null) { 82 | this.trials = this.dynamicTrials 83 | .filter(trial => trial.state !== JOB_STATE_ERROR); 84 | } else { 85 | this.trials = this.dynamicTrials 86 | .filter(trial => trial.state !== JOB_STATE_ERROR && trial.expKey === this.expKey); 87 | this.ids = []; 88 | } 89 | }; 90 | 91 | get results() { 92 | return this.trials.map(trial => trial.result); 93 | } 94 | 95 | get args() { 96 | return this.trials.map(trial => trial.args); 97 | } 98 | 99 | assertValidTrial = (trial) => { 100 | if (Object.keys(trial).length <= 0) { 101 | throw new Error('trial should be an object'); 102 | } 103 | const missingTrialKey = TRIAL_KEYS.find(key => trial[key] === undefined); 104 | if (missingTrialKey !== undefined) { 105 | throw new Error(`trial missing key ${missingTrialKey}`); 106 | } 107 | if (trial.expKey !== this.expKey) { 108 | throw new Error(`wrong trial expKey ${trial.expKey}, expected ${this.expKey}`); 109 | } 110 | return trial; 111 | }; 112 | 113 | internalInsertTrialDocs = (docs) => { 114 | const rval = docs.map(doc => doc.id); 115 | this.dynamicTrials = [...this.dynamicTrials, ...docs]; 116 | return rval; 117 | }; 118 | 119 | insertTrialDoc = (trial) => { 120 | const doc = this.assertValidTrial(trial); 121 | return this.internalInsertTrialDocs([doc])[0]; 122 | }; 123 | 124 | insertTrialDocs = (trials) => { 125 | const docs = trials.map(trial => this.assertValidTrial(trial)); 126 | return 
this.internalInsertTrialDocs(docs); 127 | }; 128 | 129 | newTrialIds = (N) => { 130 | const aa = this.ids.length; 131 | const rval = range(aa, aa + N); 132 | this.ids = [...this.ids, ...rval]; 133 | return rval; 134 | }; 135 | 136 | newTrialDocs = (ids, results, args) => { 137 | const rval = []; 138 | for (let i = 0; i < ids.length; i += 1) { 139 | const doc = { 140 | state: JOB_STATE_NEW, 141 | id: ids[i], 142 | result: results[i], 143 | args: args[i], 144 | }; 145 | doc.expKey = this.expKey; 146 | doc.book_time = null; 147 | doc.refresh_time = null; 148 | rval.push(doc); 149 | } 150 | return rval; 151 | }; 152 | 153 | deleteAll = () => { 154 | this.dynamicTrials = []; 155 | this.refresh(); 156 | }; 157 | 158 | countByStateSynced = (arg, trials = null) => { 159 | const vTrials = trials === null ? this.trials : trials; 160 | const vArg = Array.isArray(arg) ? arg : [arg]; 161 | const queue = vTrials.filter(doc => vArg.indexOf(doc.state) >= 0); 162 | return queue.length; 163 | }; 164 | 165 | countByStateUnsynced = (arg) => { 166 | const expTrials = this.expKey !== null ? 167 | this.dynamicTrials.map(trial => trial.expKey === this.expKey) : this.dynamicTrials; 168 | return this.countByStateSynced(arg, expTrials); 169 | }; 170 | 171 | losses = () => this.results.map(r => r.loss || r.accuracy); 172 | 173 | statuses = () => this.results.map(r => r.status); 174 | 175 | bestTrial(compare = (a, b) => 176 | (a.loss !== undefined ? a.loss < b.loss : a.accuracy > b.accuracy)) { 177 | let best = this.trials[0]; 178 | this.trials.forEach((trial) => { 179 | if (trial.result.status === STATUS_OK && compare(trial.result, best.result)) { 180 | best = trial; 181 | } 182 | }); 183 | return best; 184 | } 185 | 186 | get argmin() { 187 | const best = this.bestTrial(); 188 | return best !== undefined ? best.args : undefined; 189 | } 190 | 191 | get argmax() { 192 | const best = this.bestTrial((a, b) => 193 | (a.loss !== undefined ? 
a.loss > b.loss : a.accuracy > b.accuracy)); 194 | return best !== undefined ? best.args : undefined; 195 | } 196 | } 197 | 198 | export class Domain { 199 | constructor(fn, expr, params) { 200 | this.fn = fn; 201 | this.expr = expr; 202 | this.params = params; 203 | } 204 | 205 | evaluate = async (args) => { 206 | const rval = await this.fn(args, this.params); 207 | let result; 208 | if (typeof rval === 'number' && !Number.isNaN(rval)) { 209 | result = { loss: rval, status: STATUS_OK }; 210 | } else { 211 | result = rval; 212 | if (result === undefined) { 213 | throw new Error('Optimization function should return a loss value'); 214 | } 215 | const { status, loss, accuracy } = result; 216 | if (STATUS_STRINGS.indexOf(status) < 0) { 217 | throw new Error(`invalid status ${status}`); 218 | } 219 | if (status === STATUS_OK && loss === undefined && accuracy === undefined) { 220 | throw new Error('invalid loss and accuracy'); 221 | } 222 | } 223 | return result; 224 | }; 225 | newResult = () => ({ 226 | status: STATUS_NEW, 227 | }); 228 | } 229 | -------------------------------------------------------------------------------- /src/base/fmin.js: -------------------------------------------------------------------------------- 1 | /* eslint-disable camelcase,no-await-in-loop */ 2 | import RandomState from '../utils/RandomState'; 3 | import { Trials, Domain, JOB_STATE_NEW, JOB_STATE_RUNNING, JOB_STATE_ERROR, JOB_STATE_DONE } from './base'; 4 | 5 | const getTimeStatmp = () => new Date().getTime(); 6 | 7 | class FMinIter { 8 | constructor( 9 | algo, domain, trials, 10 | { 11 | rng, 12 | catchExceptions = false, 13 | max_queue_len = 1, 14 | max_evals = Number.MAX_VALUE, 15 | } = {}, 16 | params = {} 17 | ) { 18 | this.catchExceptions = catchExceptions; 19 | this.algo = algo; 20 | this.domain = domain; 21 | this.trials = trials; 22 | this.callbacks = params.callbacks || {}; 23 | this.max_queue_len = max_queue_len; 24 | this.max_evals = max_evals; 25 | this.rng = rng; 26 | } 27 
| async serial_evaluate(N = -1) { 28 | const { onExperimentBegin, onExperimentEnd } = this.callbacks; 29 | let n = N; 30 | let stopped = false; 31 | for (let i = 0; i < this.trials.dynamicTrials.length; i += 1) { 32 | const trial = this.trials.dynamicTrials[i]; 33 | if (trial.state === JOB_STATE_NEW) { 34 | trial.state = JOB_STATE_RUNNING; 35 | const now = getTimeStatmp(); 36 | trial.book_time = now; 37 | trial.refresh_time = now; 38 | try { 39 | if (typeof onExperimentBegin === 'function') { 40 | if (await onExperimentBegin(i, trial) === true) { 41 | stopped = true; 42 | } 43 | } 44 | // eslint-disable-next-line no-await-in-loop 45 | const result = await this.domain.evaluate(trial.args); 46 | trial.state = JOB_STATE_DONE; 47 | trial.result = result; 48 | trial.refresh_time = getTimeStatmp(); 49 | } catch (e) { 50 | trial.state = JOB_STATE_ERROR; 51 | trial.error = `${e}, ${e.message}`; 52 | trial.refresh_time = getTimeStatmp(); 53 | if (!this.catchExceptions) { 54 | this.trials.refresh(); 55 | throw e; 56 | } 57 | } 58 | if (typeof onExperimentEnd === 'function') { 59 | if (await onExperimentEnd(i, trial) === true) { 60 | stopped = true; 61 | } 62 | } 63 | } 64 | n -= 1; 65 | if (n === 0 || stopped) { 66 | break; 67 | } 68 | } 69 | this.trials.refresh(); 70 | return stopped; 71 | } 72 | 73 | run = async (N) => { 74 | const { trials, algo } = this; 75 | let n_queued = 0; 76 | 77 | const get_queue_len = () => this.trials.countByStateUnsynced(JOB_STATE_NEW); 78 | 79 | let stopped = false; 80 | while (n_queued < N) { 81 | let qlen = get_queue_len(); 82 | while (qlen < this.max_queue_len && n_queued < N) { 83 | const n_to_enqueue = Math.min(this.max_queue_len - qlen, N - n_queued); 84 | const new_ids = trials.newTrialIds(n_to_enqueue); 85 | trials.refresh(); 86 | const new_trials = algo( 87 | new_ids, this.domain, trials, 88 | this.rng.randrange(0, (2 ** 31) - 1) 89 | ); 90 | console.assert(new_ids.length >= new_trials.length); 91 | if (new_trials.length) { 92 | 
this.trials.insertTrialDocs(new_trials); 93 | this.trials.refresh(); 94 | n_queued += new_trials.length; 95 | qlen = get_queue_len(); 96 | } else { 97 | stopped = true; 98 | break; 99 | } 100 | } 101 | stopped = stopped || await this.serial_evaluate(); 102 | if (stopped) { 103 | break; 104 | } 105 | } 106 | const qlen = get_queue_len(); 107 | if (qlen) { 108 | const msg = `Exiting run, not waiting for ${qlen} jobs.`; 109 | console.error(msg); 110 | } 111 | }; 112 | 113 | exhaust = async () => { 114 | const n_done = this.trials.length; 115 | await this.run(this.max_evals - n_done); 116 | this.trials.refresh(); 117 | return this; 118 | } 119 | } 120 | 121 | export default async (fn, space, algo, max_evals, params = {}) => { 122 | const { 123 | trials: defTrials, rng: rngDefault, 124 | catchExceptions = false, 125 | } = params; 126 | 127 | let rng; 128 | if (rngDefault) { 129 | rng = rngDefault; 130 | } else { 131 | rng = new RandomState(); 132 | } 133 | let trials; 134 | if (!defTrials) { 135 | trials = new Trials(); 136 | } else { 137 | trials = defTrials; 138 | } 139 | 140 | const domain = new Domain(fn, space, params); 141 | 142 | const rval = new FMinIter( 143 | algo, domain, trials, 144 | { max_evals, rng, catchExceptions }, 145 | params 146 | ); 147 | await rval.exhaust(); 148 | return trials; 149 | }; 150 | -------------------------------------------------------------------------------- /src/index.js: -------------------------------------------------------------------------------- 1 | import fmin from './base/fmin'; 2 | import { randomSample, randomSearch } from './search/random'; 3 | import { gridSample, gridSearch } from './search/grid'; 4 | import RandomState from './utils/RandomState'; 5 | 6 | export * from './base/base'; 7 | 8 | export { fmin, RandomState }; 9 | 10 | export const choice = options => ({ name: 'choice', options }); 11 | export const randint = upper => ({ name: 'randint', upper }); 12 | export const uniform = (low, high) => ({ name: 
'uniform', low, high }); 13 | export const quniform = (low, high, q) => ({ 14 | name: 'quniform', low, high, q 15 | }); 16 | export const loguniform = (low, high) => ({ name: 'loguniform', low, high }); 17 | export const qloguniform = (low, high, q) => ({ 18 | name: 'qloguniform', low, high, q 19 | }); 20 | export const normal = (mu, sigma) => ({ name: 'normal', mu, sigma }); 21 | export const qnormal = (mu, sigma, q) => ({ 22 | name: 'qnormal', mu, sigma, q 23 | }); 24 | export const lognormal = (mu, sigma) => ({ name: 'lognormal', mu, sigma }); 25 | export const qlognormal = (mu, sigma, q) => ({ 26 | name: 'qlognormal', mu, sigma, q 27 | }); 28 | 29 | export const search = { 30 | randomSearch, 31 | gridSearch, 32 | }; 33 | 34 | export const sample = { 35 | randomSample, 36 | gridSample, 37 | }; 38 | -------------------------------------------------------------------------------- /src/search/grid.js: -------------------------------------------------------------------------------- 1 | /* eslint-disable class-methods-use-this */ 2 | import BaseSpace from '../base/base'; 3 | 4 | class GridSearchParam { 5 | constructor(params, gs) { 6 | this.params = params; 7 | this.gs = gs; 8 | } 9 | 10 | get numSamples() { 11 | return 1; 12 | } 13 | 14 | getSample = () => undefined; 15 | sample = (index) => { 16 | if (index < 0 || index >= this.numSamples) { 17 | throw new Error(`invalid sample index "${index}"`); 18 | } 19 | return this.getSample(index); 20 | } 21 | } 22 | 23 | 24 | class GridSearchNotImplemented extends GridSearchParam { 25 | get numSamples() { 26 | throw new Error(`Can not evaluate length of non-discrete parameter "${this.params.name}"`); 27 | } 28 | getSample = index => this.params.options[index]; 29 | } 30 | 31 | class GridSearchChoice extends GridSearchParam { 32 | get numSamples() { 33 | return this.params.options.reduce((r, e) => (r + this.gs.numSamples(e)), 0); 34 | } 35 | getSample = index => this.params.options[index]; 36 | } 37 | 38 | class 
GridSearchRandInt extends GridSearchParam { 39 | get numSamples() { 40 | return this.params.upper; 41 | } 42 | getSample = index => index; 43 | } 44 | 45 | class GridSearchUniform extends GridSearchParam { 46 | get numSamples() { 47 | return Math.floor((this.params.high - this.params.low) / this.params.q) + 1; 48 | } 49 | getSample = index => this.params.low + (index * this.params.q); 50 | } 51 | 52 | class GridSearchNormal extends GridSearchParam { 53 | get numSamples() { 54 | return Math.floor((4 * this.params.sigma) / this.params.q) + 1; 55 | } 56 | getSample = index => (this.params.mu - (2 * this.params.sigma)) + (index * this.params.q); 57 | } 58 | 59 | const GridSearchParamas = { 60 | choice: GridSearchChoice, 61 | randint: GridSearchRandInt, 62 | quniform: GridSearchUniform, 63 | qloguniform: GridSearchUniform, 64 | qnormal: GridSearchNormal, 65 | qlognormal: GridSearchNormal, 66 | uniform: GridSearchNotImplemented, 67 | loguniform: GridSearchNotImplemented, 68 | normal: GridSearchNotImplemented, 69 | lognormal: GridSearchNotImplemented, 70 | }; 71 | 72 | export class GridSearch extends BaseSpace { 73 | numSamples = (expr) => { 74 | const flat = this.samples(expr); 75 | return flat.reduce((r, o) => r * o.samples, 1); 76 | }; 77 | samples = (expr) => { 78 | if (!expr) { 79 | return expr; 80 | } 81 | const flat = []; 82 | const { name } = expr; 83 | const Param = GridSearchParamas[name]; 84 | if (Param === undefined) { 85 | if (Array.isArray(expr)) { 86 | expr.forEach(el => flat.push(...this.samples(el))); 87 | } 88 | if (typeof expr === 'string') { 89 | flat.push({ name, samples: 1, expr }); 90 | } 91 | if (typeof expr === 'object') { 92 | Object.keys(expr).forEach(key => flat.push(...this.samples(expr[key]))); 93 | } 94 | } else { 95 | flat.push({ name, samples: (new Param(expr, this)).numSamples, expr }); 96 | } 97 | return flat; 98 | }; 99 | } 100 | 101 | export const gridSample = (space, params = {}) => { 102 | const gs = new GridSearch(); 103 | const 
args = gs.eval(space, params); 104 | if (Object.keys(args).length === 1) { 105 | const results = Object.keys(args).map(key => args[key]); 106 | return results.length === 1 ? results[0] : results; 107 | } 108 | return args; 109 | }; 110 | 111 | export const gridSearch = (newIds, domain, trials) => { 112 | let rval = []; 113 | const gs = new GridSearch(); 114 | newIds.forEach((newId) => { 115 | const paramsEval = gs.eval(domain.expr); 116 | const result = domain.newResult(); 117 | rval = [...rval, ...trials.newTrialDocs([newId], [result], [paramsEval])]; 118 | }); 119 | return rval; 120 | }; 121 | -------------------------------------------------------------------------------- /src/search/random.js: -------------------------------------------------------------------------------- 1 | import RandomState from '../utils/RandomState'; 2 | import BaseSpace from '../base/base'; 3 | 4 | export class RandomSearch extends BaseSpace { 5 | choice = (params, rng) => { 6 | const { options } = params; 7 | const idx = rng.randrange(0, options.length, 1); 8 | const option = options[idx]; 9 | const arg = this.eval(option, { rng }); 10 | return arg; 11 | }; 12 | 13 | randint = (params, rng) => rng.randrange(0, params.upper, 1) 14 | 15 | uniform = (params, rng) => { 16 | const { low, high } = params; 17 | return rng.uniform(low, high); 18 | }; 19 | 20 | quniform = (params, rng) => { 21 | const { low, high, q } = params; 22 | return Math.round(rng.uniform(low, high) / q) * q; 23 | }; 24 | 25 | loguniform = (params, rng) => { 26 | const { low, high } = params; 27 | return Math.exp(rng.uniform(low, high)); 28 | }; 29 | 30 | qloguniform = (params, rng) => { 31 | const { low, high, q } = params; 32 | return Math.round(Math.exp(rng.uniform(low, high)) / q) * q; 33 | }; 34 | 35 | normal = (params, rng) => { 36 | const { mu, sigma } = params; 37 | return rng.gauss(mu, sigma); 38 | }; 39 | 40 | qnormal = (params, rng) => { 41 | const { mu, sigma, q } = params; 42 | return 
Math.round(rng.gauss(mu, sigma) / q) * q; 43 | }; 44 | 45 | lognormal = (params, rng) => { 46 | const { mu, sigma } = params; 47 | return Math.exp(rng.gauss(mu, sigma)); 48 | }; 49 | 50 | qlognormal = (params, rng) => { 51 | const { mu, sigma, q } = params; 52 | return Math.round(Math.exp(rng.gauss(mu, sigma)) / q) * q; 53 | }; 54 | } 55 | 56 | export const randomSample = (space, params = {}) => { 57 | const rs = new RandomSearch(); 58 | const args = rs.eval(space, params); 59 | if (Object.keys(args).length === 1) { 60 | const results = Object.keys(args).map(key => args[key]); 61 | return results.length === 1 ? results[0] : results; 62 | } 63 | return args; 64 | }; 65 | 66 | export const randomSearch = (newIds, domain, trials, seed) => { 67 | const rng = new RandomState(seed); 68 | let rval = []; 69 | const rs = new RandomSearch(); 70 | newIds.forEach((newId) => { 71 | const paramsEval = rs.eval(domain.expr, { rng }); 72 | const result = domain.newResult(); 73 | rval = [...rval, ...trials.newTrialDocs([newId], [result], [paramsEval])]; 74 | }); 75 | return rval; 76 | }; 77 | -------------------------------------------------------------------------------- /src/utils/RandomState.js: -------------------------------------------------------------------------------- 1 | /* eslint-disable no-bitwise */ 2 | // https://gist.github.com/banksean/300494 3 | // Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura, 4 | 5 | // https://github.com/jrus/random-js/blob/master/random.coffee 6 | 7 | const POW_NEG_26 = 2 ** -26; 8 | const POW_NEG_27 = 2 ** -27; 9 | const POW_32 = 2 ** 32; 10 | 11 | export default class RandomState { 12 | constructor(seed) { 13 | this.bits = {}; 14 | this.seed = seed === undefined ? 
new Date().getTime() : seed; 15 | this.N = 624; 16 | this.M = 397; 17 | this.MATRIX_A = 0x9908b0df; /* constant vector a */ 18 | this.UPPER_MASK = 0x80000000; /* most significant w-r bits */ 19 | this.LOWER_MASK = 0x7fffffff; /* least significant r bits */ 20 | 21 | this.mt = new Array(this.N); /* the array for the state vector */ 22 | this.mti = this.N + 1; /* mti==N+1 means mt[N] is not initialized */ 23 | 24 | this.initGen(this.seed); 25 | } 26 | 27 | /* initializes mt[N] with a seed */ 28 | initGen(seed) { 29 | this.mt[0] = seed >>> 0; 30 | for (this.mti = 1; this.mti < this.N; this.mti += 1) { 31 | const s = this.mt[this.mti - 1] ^ (this.mt[this.mti - 1] >>> 30); 32 | this.mt[this.mti] = 33 | // eslint-disable-next-line no-mixed-operators 34 | (((((s & 0xffff0000) >>> 16) * 1812433253) << 16) + (s & 0x0000ffff) * 1812433253) 35 | + this.mti; 36 | /* See Knuth TAOCP Vol2. 3rd Ed. P.106 for multiplier. */ 37 | /* In the previous versions, MSBs of the seed affect */ 38 | /* only MSBs of the array mt[]. 
*/ 39 | /* 2002/01/09 modified by Makoto Matsumoto */ 40 | this.mt[this.mti] >>>= 0; 41 | /* for >32 bit machines */ 42 | } 43 | this.next_gauss = null; 44 | } 45 | randint() { 46 | let y; 47 | const mag01 = [0x0, this.MATRIX_A]; 48 | /* mag01[x] = x * MATRIX_A for x=0,1 */ 49 | 50 | if (this.mti >= this.N) { /* generate N words at one time */ 51 | let kk; 52 | 53 | if (this.mti === this.N + 1) { /* if initGen() has not been called, */ 54 | this.initGen(5489); 55 | } /* a default initial seed is used */ 56 | 57 | for (kk = 0; kk < this.N - this.M; kk += 1) { 58 | y = (this.mt[kk] & this.UPPER_MASK) | (this.mt[kk + 1] & this.LOWER_MASK); 59 | this.mt[kk] = this.mt[kk + this.M] ^ (y >>> 1) ^ mag01[y & 0x1]; 60 | } 61 | for (;kk < this.N - 1; kk += 1) { 62 | y = (this.mt[kk] & this.UPPER_MASK) | (this.mt[kk + 1] & this.LOWER_MASK); 63 | this.mt[kk] = this.mt[kk + (this.M - this.N)] ^ (y >>> 1) ^ mag01[y & 0x1]; 64 | } 65 | y = (this.mt[this.N - 1] & this.UPPER_MASK) | (this.mt[0] & this.LOWER_MASK); 66 | this.mt[this.N - 1] = this.mt[this.M - 1] ^ (y >>> 1) ^ mag01[y & 0x1]; 67 | 68 | this.mti = 0; 69 | } 70 | 71 | y = this.mt[this.mti += 1]; 72 | 73 | /* Tempering */ 74 | y ^= (y >>> 11); 75 | y ^= (y << 7) & 0x9d2c5680; 76 | y ^= (y << 15) & 0xefc60000; 77 | y ^= (y >>> 18); 78 | 79 | return y >>> 0; 80 | } 81 | 82 | random() { 83 | // Return a random float in the range [0, 1), with a full 53 84 | // bits of entropy. 
85 | const val = this.randint(); 86 | const lowBits = val >>> 6; 87 | const highBits = val >>> 5; 88 | return (highBits + (lowBits * POW_NEG_26)) * POW_NEG_27; 89 | } 90 | 91 | randbelow(upperBound) { 92 | if (upperBound <= 0) { 93 | return 0; 94 | } 95 | const lg = x => (Math.LOG2E * Math.log(x + 1e-10)) >> 0; 96 | if (upperBound <= 0x100000000) { 97 | let r = upperBound; 98 | const bits = this.bits[upperBound] || 99 | (this.bits[upperBound] = (lg(upperBound - 1)) + 1); // memoize values for `bits` 100 | while (r >= upperBound) { 101 | r = this.randint() >>> (32 - bits); 102 | if (r < 0) { 103 | r += POW_32; 104 | } 105 | } 106 | return r; 107 | } 108 | return this.randint() % upperBound; 109 | } 110 | randrange(start, stop, step) { 111 | // Return a random integer N in range `[start...stop] by step` 112 | if (stop === undefined) { 113 | return this.randbelow(start); 114 | } else if (!step) { 115 | return start + this.randbelow(stop - start); 116 | } 117 | return start + (step * this.randbelow(Math.floor((stop - start) / step))); 118 | } 119 | gauss(mu = 0, sigma = 1) { 120 | // Gaussian distribution. `mu` is the mean, and `sigma` is the standard 121 | // deviation. Notes: 122 | // * uses the "polar method" 123 | // * we generate pairs; keep one in a cache for next time 124 | let z = this.next_gauss; 125 | if (z != null) { 126 | this.next_gauss = null; 127 | } else { 128 | let s; 129 | let u; 130 | let v; 131 | while (!s || !(s < 1)) { 132 | u = (2 * this.random()) - 1; 133 | v = (2 * this.random()) - 1; 134 | s = (u * u) + (v * v); 135 | } 136 | const w = Math.sqrt((-2 * (Math.log(s))) / s); 137 | z = u * w; this.next_gauss = v * w; 138 | } 139 | return mu + (z * sigma); // Alias for the `gauss` function 140 | } 141 | uniform(a, b) { 142 | // Return a random floating point number N such that a <= N <= b for 143 | // a <= b and b <= N <= a for b < a. 
144 | return a + (this.random() * (b - a)); 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /test/__snapshots__/index.js.json: -------------------------------------------------------------------------------- 1 | { 2 | "upper: negative": "0.0000000", 3 | "upper: 0": "0.0000000", 4 | "upper: 1": "0.0000000", 5 | "upper: 1000000": "933394.0000000", 6 | "uniform -1, 1": "0.7803094", 7 | "uniform -100000, -1": "-10985.4183013", 8 | "uniform -1, -10": "-9.0113925", 9 | "uniform 5, 1": "1.4393811", 10 | "uniform 1, 1000000": "890154.8283796", 11 | "uniform 1, 1": "1.0000000", 12 | "quniform -1, 1, 0.1": "0.8000000", 13 | "quniform -100000, -1, -1": "-10985.0000000", 14 | "quniform -1, -10, 0.22222": "-9.1110200", 15 | "quniform 5, 1, -0.111": "1.4430000", 16 | "quniform 1, 1000000, 50": "890150.0000000", 17 | "quniform 1, 1, 0.001": "1.0000000", 18 | "loguniform -1, 1": "2.1821474", 19 | "loguniform -100000, -1": "0.0000000", 20 | "loguniform -1, -10": "0.0001220", 21 | "loguniform 5, 1": "4.2180845", 22 | "loguniform 1, 1": "2.7182818", 23 | "qloguniform -1, 1, 0.1": "2.2000000", 24 | "qloguniform -100000, -1, -1": "0.0000000", 25 | "qloguniform -1, -10, 0.2222": "0.0000000", 26 | "qloguniform 5, 1, -0.111": "4.2180000", 27 | "qloguniform 5, 1, 0.1": "4.2000000", 28 | "qloguniform 1, 1, 0.001": "2.7180000", 29 | "normal -1, 1": "-0.3037819", 30 | "normal -100000, -1": "-100000.6962181", 31 | "normal -1, -10": "-7.9621809", 32 | "normal 5, 1": "5.6962181", 33 | "normal 1, 1": "1.6962181", 34 | "qnormal -1, 1, 0.1": "-0.3000000", 35 | "qnormal -100000, -1, -1": "-100001.0000000", 36 | "qnormal -1, -10, 0.22222": "-7.9999200", 37 | "qnormal 5, 1, -0.111": "5.6610000", 38 | "qnormal 1, 1000000, 50": "50.0000000", 39 | "qnormal 1, 1, 0.001": "1.6960000", 40 | "lognormal -1, 1": "0.7380218", 41 | "lognormal -100000, -1": "0.0000000", 42 | "lognormal -1, -10": "0.0003484", 43 | "lognormal 5, 1": "297.7392463", 44 | 
"lognormal 1, 1": "5.4532845", 45 | "qlognormal -1, 1, 0.1": "0.7000000", 46 | "qlognormal -100000, -1, -1": "0.0000000", 47 | "qlognormal -1, -10, 0.22222": "0.0000000", 48 | "qlognormal 5, 1, -0.111": "297.7020000", 49 | "qlognormal 5, 1, 0.1": "297.7000000", 50 | "qlognormal 1, 1, 0.001": "5.4530000", 51 | "sample: array": "-3.6724889", 52 | "sample: depth": { 53 | "x": "1.3924362", 54 | "y": "0.1307073", 55 | "choice": null, 56 | "array": [ 57 | -0.6553433999116457, 58 | 0.11927848398147867, 59 | false 60 | ], 61 | "obj": { 62 | "u": 2.47930839269838, 63 | "v": 1.703175088041835, 64 | "w": -1.4037662402075997 65 | } 66 | }, 67 | "FMin for x^2 - x + 1": {}, 68 | "Hyperparameters space": { 69 | "x": "-0.7975415", 70 | "y": "-0.7414791" 71 | }, 72 | "choice as array space": {}, 73 | "Deep learning space": { 74 | "learning_rate": "0.0000123", 75 | "use_double_q_learning": true, 76 | "layer1_size": "29.0000000", 77 | "layer2_size": "41.0000000", 78 | "layer3_size": "56.0000000", 79 | "future_discount_max": "0.6085554", 80 | "future_discount_increment": "0.0013308", 81 | "recall_memory_size": "8.0000000", 82 | "recall_memory_num_experiences_per_recall": "1275.0000000", 83 | "num_epochs": "2.0000000" 84 | } 85 | } -------------------------------------------------------------------------------- /test/index.js: -------------------------------------------------------------------------------- 1 | import chai, { assert } from 'chai'; 2 | import snapshots from 'chai-snapshot-tests'; 3 | import * as hpjs from '../src'; 4 | import { GridSearch } from '../src/search/grid'; 5 | 6 | chai.use(snapshots(__filename)); 7 | 8 | const seededSample = (space) => hpjs.sample.randomSample(space, { rng: new hpjs.RandomState(12345) }); 9 | 10 | const floatSeededSample = (space) => hpjs.sample.randomSample(space, { rng: new hpjs.RandomState(12345) }).toFixed(7); 11 | 12 | const objectToFixed = (obj) => Object.keys(obj).reduce((final, key) => { 13 | const value = typeof obj[key] === 'number' 
? obj[key].toFixed(7) : obj[key]; 14 | return { ...final, [key]: value}; 15 | }, {}); 16 | 17 | const randFMinSeeded = async (opt, space) => { 18 | const trials = await hpjs.fmin(opt, space, hpjs.search.randomSearch, 100, { rng: new hpjs.RandomState(12345) }); 19 | return objectToFixed(trials.argmin); 20 | } 21 | 22 | 23 | describe('hpjs.choice.', () => { 24 | it('is a string', () => { 25 | const val = seededSample(hpjs.choice(['cat', 'dog'])); 26 | assert.typeOf(val, 'string'); 27 | }); 28 | it('is one of the elements', () => { 29 | const val = seededSample(hpjs.choice(['cat', 'dog'])); 30 | assert(['cat', 'dog'].indexOf(val) >= 0, 'val was actually ' + val); 31 | }); 32 | it('picks a number', () => { 33 | const val = seededSample(hpjs.choice([1, 2, 3, 4])); 34 | assert(val === 4, 'val was actually: ' + val); 35 | }); 36 | }); 37 | 38 | describe('hpjs.randint.', () => { 39 | it('in range [0,5)', () => { 40 | const val = seededSample(hpjs.randint(5)); 41 | assert(val >= 0 && val < 5, `actual value ${val}`); 42 | }); 43 | it('Snapshot tests', () => { 44 | assert.snapshot('upper: negative', floatSeededSample(hpjs.randint(-2))); 45 | assert.snapshot('upper: 0', floatSeededSample(hpjs.randint(0))); 46 | assert.snapshot('upper: 1', floatSeededSample(hpjs.randint(1))); 47 | assert.snapshot('upper: 1000000', floatSeededSample(hpjs.randint(1000000))); 48 | }); 49 | 50 | }); 51 | 52 | describe('hpjs.uniform.', () => { 53 | it('between 0 and 1', () => { 54 | const val = seededSample(hpjs.uniform(0, 1)); 55 | assert(val >= 0 && val <= 1, `actual value ${val}`); 56 | }); 57 | it('between -1 and 1', () => { 58 | const val = seededSample(hpjs.uniform(-1, 1)); 59 | assert(val >= -1 && val <= 1, `actual value ${val}`); 60 | }); 61 | 62 | it('Snapshot tests', () => { 63 | assert.snapshot('uniform -1, 1', floatSeededSample(hpjs.uniform(-1, 1))); 64 | assert.snapshot('uniform -100000, -1', floatSeededSample(hpjs.uniform(-100000, -1))); 65 | assert.snapshot('uniform -1, -10', 
floatSeededSample(hpjs.uniform(-1, -10))); 66 | assert.snapshot('uniform 5, 1', floatSeededSample(hpjs.uniform(5, 1))); 67 | assert.snapshot('uniform 1, 1000000', floatSeededSample(hpjs.uniform(1, 1000000))); 68 | assert.snapshot('uniform 1, 1', floatSeededSample(hpjs.uniform(1, 1))); 69 | }); 70 | }); 71 | 72 | describe('hpjs.quniform.', () => { 73 | it('between 0 and 1', () => { 74 | const val = seededSample(hpjs.quniform(0, 1, 0.2)); 75 | assert(val >= 0 && val <= 1, `actual value ${val}`); 76 | }); 77 | it('between -1 and 1, step 1', () => { 78 | const val = seededSample(hpjs.quniform(-1, 1, 1)); 79 | assert(val === -1 || val === 0 || val === 1, `actual value ${val}`); 80 | }); 81 | it('Snapshot tests', () => { 82 | assert.snapshot('quniform -1, 1, 0.1', floatSeededSample(hpjs.quniform(-1, 1, 0.1))); 83 | assert.snapshot('quniform -100000, -1, -1', floatSeededSample(hpjs.quniform(-100000, -1, -1))); 84 | assert.snapshot('quniform -1, -10, 0.22222', floatSeededSample(hpjs.quniform(-1, -10, 0.22222))); 85 | assert.snapshot('quniform 5, 1, -0.111', floatSeededSample(hpjs.quniform(5, 1, -0.111))); 86 | assert.snapshot('quniform 1, 1000000, 50', floatSeededSample(hpjs.quniform(1, 1000000, 50))); 87 | assert.snapshot('quniform 1, 1, 0.001', floatSeededSample(hpjs.quniform(1, 1, 0.001))); 88 | }); 89 | }); 90 | 91 | describe('hpjs.loguniform.', () => { 92 | it('between e^0 and e^1', () => { 93 | const low = 0; 94 | const high = 1; 95 | const val = seededSample(hpjs.loguniform(low, high)); 96 | assert(val >= Math.exp(low) && val <= Math.exp(high), `actual value ${val}`); 97 | }); 98 | it('Snapshot tests', () => { 99 | assert.snapshot('loguniform -1, 1', floatSeededSample(hpjs.loguniform(-1, 1))); 100 | assert.snapshot('loguniform -100000, -1', floatSeededSample(hpjs.loguniform(-100000, -1))); 101 | assert.snapshot('loguniform -1, -10', floatSeededSample(hpjs.loguniform(-1, -10))); 102 | assert.snapshot('loguniform 5, 1', floatSeededSample(hpjs.loguniform(5, 1))); 103 | 
assert.snapshot('loguniform 5, 1', floatSeededSample(hpjs.loguniform(5, 1))); 104 | assert.snapshot('loguniform 1, 1', floatSeededSample(hpjs.loguniform(1, 1))); 105 | }); 106 | }); 107 | 108 | describe('hpjs.qloguniform.', () => { 109 | it('e^0 and e^1', () => { 110 | const low = 0; 111 | const high = 1; 112 | const val = seededSample(hpjs.qloguniform(low, high, 0.2)); 113 | assert(val >= Math.exp(low) && val <= Math.exp(high), `actual value ${val}`); 114 | }); 115 | it('Snapshot tests', () => { 116 | assert.snapshot('qloguniform -1, 1, 0.1', floatSeededSample(hpjs.qloguniform(-1, 1, 0.1))); 117 | assert.snapshot('qloguniform -100000, -1, -1', floatSeededSample(hpjs.qloguniform(-100000, -1, -1))); 118 | assert.snapshot('qloguniform -1, -10, 0.2222', floatSeededSample(hpjs.qloguniform(-1, -10, 0.22222))); 119 | assert.snapshot('qloguniform 5, 1, -0.111', floatSeededSample(hpjs.qloguniform(5, 1, -0.111))); 120 | assert.snapshot('qloguniform 5, 1, 0.1', floatSeededSample(hpjs.qloguniform(5, 1, 0.1))); 121 | assert.snapshot('qloguniform 1, 1, 0.001', floatSeededSample(hpjs.qloguniform(1, 1, 0.001))); 122 | }); 123 | }); 124 | 125 | describe('hpjs.normal.', () => { 126 | it('a number', () => { 127 | const mu = -1; 128 | const sigma = 1; 129 | const val = seededSample(hpjs.normal(mu, sigma)); 130 | assert(!isNaN(val), `actual value ${val}`); 131 | }); 132 | it('within 3 standard deviations of mean', () => { 133 | const mu = 0; 134 | const sigma = 1; 135 | const val = seededSample(hpjs.normal(mu, sigma)); 136 | assert(val >= mu - (3*sigma) && val <= mu + (3*sigma), `actual value ${val}`); 137 | }); 138 | it('Snapshot tests', () => { 139 | assert.snapshot('normal -1, 1', floatSeededSample(hpjs.normal(-1, 1))); 140 | assert.snapshot('normal -100000, -1', floatSeededSample(hpjs.normal(-100000, -1))); 141 | assert.snapshot('normal -1, -10', floatSeededSample(hpjs.normal(-1, -10))); 142 | assert.snapshot('normal 5, 1', floatSeededSample(hpjs.normal(5, 1))); 143 | 
assert.snapshot('normal 5, 1', floatSeededSample(hpjs.normal(5, 1))); 144 | assert.snapshot('normal 1, 1', floatSeededSample(hpjs.normal(1, 1))); 145 | }); 146 | }); 147 | 148 | describe('hpjs.qnormal.', () => { 149 | it('a number', () => { 150 | const mu = -1; 151 | const sigma = 1; 152 | const val = seededSample(hpjs.normal(mu, sigma)); 153 | assert(!isNaN(val), `actual value ${val}`); 154 | }); 155 | it('within 3 standard deviations of mean', () => { 156 | const mu = 0; 157 | const sigma = 1; 158 | const val = seededSample(hpjs.qnormal(mu, sigma, 0.1)); 159 | assert(val >= mu - (3*sigma) && val <= mu + (3*sigma), `actual value ${val}`); 160 | }); 161 | it('Snapshot tests', () => { 162 | assert.snapshot('qnormal -1, 1, 0.1', floatSeededSample(hpjs.qnormal(-1, 1, 0.1))); 163 | assert.snapshot('qnormal -100000, -1, -1', floatSeededSample(hpjs.qnormal(-100000, -1, -1))); 164 | assert.snapshot('qnormal -1, -10, 0.22222', floatSeededSample(hpjs.qnormal(-1, -10, 0.22222))); 165 | assert.snapshot('qnormal 5, 1, -0.111', floatSeededSample(hpjs.qnormal(5, 1, -0.111))); 166 | assert.snapshot('qnormal 1, 1000000, 50', floatSeededSample(hpjs.qnormal(1, 100, 50))); 167 | assert.snapshot('qnormal 1, 1, 0.001', floatSeededSample(hpjs.qnormal(1, 1, 0.001))); 168 | }); 169 | }); 170 | 171 | describe('hpjs.lognormal.', () => { 172 | it('positive', () => { 173 | const mu = 0; 174 | const sigma = 1; 175 | const val = seededSample(hpjs.lognormal(mu, sigma)); 176 | assert(val >= 0, `actual value ${val}`); 177 | }); 178 | it('less ~e^3 from the mean, or less than ~3 standard deviations from it', () => { 179 | const mu = 0; 180 | const sigma = 1; 181 | const val = seededSample(hpjs.lognormal(mu, sigma)); 182 | assert(val <= 50, `actual value ${val}`); 183 | }); 184 | it('Snapshot tests', () => { 185 | assert.snapshot('lognormal -1, 1', floatSeededSample(hpjs.lognormal(-1, 1))); 186 | assert.snapshot('lognormal -100000, -1', floatSeededSample(hpjs.lognormal(-100000, -1))); 187 | 
assert.snapshot('lognormal -1, -10', floatSeededSample(hpjs.lognormal(-1, -10))); 188 | assert.snapshot('lognormal 5, 1', floatSeededSample(hpjs.lognormal(5, 1))); 189 | assert.snapshot('lognormal 5, 1', floatSeededSample(hpjs.lognormal(5, 1))); 190 | assert.snapshot('lognormal 1, 1', floatSeededSample(hpjs.lognormal(1, 1))); 191 | }); 192 | }); 193 | 194 | describe('hpjs.qlognormal.', () => { 195 | it('a number', () => { 196 | const mu = -1; 197 | const sigma = 1; 198 | const val = seededSample(hpjs.qlognormal(mu, sigma, 0.2)); 199 | assert(!isNaN(val), `actual value ${val}`); 200 | }); 201 | it('within 3 standard deviations of mean', () => { 202 | const mu = 0; 203 | const sigma = 1; 204 | const val = seededSample(hpjs.qlognormal(mu, sigma, 0.1)); 205 | assert(val >= mu - (3*sigma) && val <= mu + (3*sigma), `actual value ${val}`); 206 | }); 207 | it('Snapshot tests', () => { 208 | assert.snapshot('qlognormal -1, 1, 0.1', floatSeededSample(hpjs.qlognormal(-1, 1, 0.1))); 209 | assert.snapshot('qlognormal -100000, -1, -1', floatSeededSample(hpjs.qlognormal(-100000, -1, -1))); 210 | assert.snapshot('qlognormal -1, -10, 0.22222', floatSeededSample(hpjs.qlognormal(-1, -10, 0.22222))); 211 | assert.snapshot('qlognormal 5, 1, -0.111', floatSeededSample(hpjs.qlognormal(5, 1, -0.111))); 212 | assert.snapshot('qlognormal 5, 1, 0.1', floatSeededSample(hpjs.qlognormal(5, 1, 0.1))); 213 | assert.snapshot('qlognormal 1, 1, 0.001', floatSeededSample(hpjs.qlognormal(1, 1, 0.001))); 214 | }); 215 | }); 216 | 217 | describe('random sample', () => { 218 | it('Choice as array', () => { 219 | const space = hpjs.choice( 220 | [ 221 | hpjs.lognormal(0, 1), 222 | hpjs.uniform(-10, 10) 223 | ] 224 | ); 225 | 226 | assert.snapshot('sample: array', floatSeededSample(space)); 227 | }); 228 | it('more complex space with depth', () => { 229 | const space = { 230 | x: hpjs.normal(0, 2), 231 | y: hpjs.uniform(0, 1), 232 | choice: hpjs.choice([ 233 | null, hpjs.uniform(0, 1), 234 | ]), 235 | 
array: [ 236 | hpjs.normal(0, 2), hpjs.uniform(0, 3), hpjs.choice([false, true]), 237 | ], 238 | obj: { 239 | u: hpjs.uniform(0, 3), 240 | v: hpjs.uniform(0, 3), 241 | w: hpjs.uniform(-3, 0) 242 | } 243 | }; 244 | assert.snapshot('sample: depth', objectToFixed(seededSample(space))); 245 | }); 246 | }); 247 | describe('grid search', () => { 248 | const gs = new GridSearch(); 249 | it('choice', () => { 250 | const space = hpjs.choice(['cat', 'dog']); 251 | assert(gs.numSamples(space) === 2, `choice num samples ${gs.numSamples(space)}`); 252 | }); 253 | it('randint', () => { 254 | const space = hpjs.randint(5); 255 | assert(gs.numSamples(space) === 5, `randint [0, 5) ${gs.numSamples(space)}`); 256 | }); 257 | it('quniform', () => { 258 | let space = hpjs.quniform(0, 1, 0.1); 259 | assert(gs.numSamples(space) === 11, `quniform 0,1,0.1 ${gs.numSamples(space)}`); 260 | space = hpjs.quniform(-5, 5, 1); 261 | assert(gs.numSamples(space) === 11, `quniform -5,5,1 ${gs.numSamples(space)}`); 262 | space = hpjs.quniform(1, 10, 2); 263 | assert(gs.numSamples(space) === 5, `quniform 1,10,2 ${gs.numSamples(space)}`); 264 | }); 265 | it('qloguniform', () => { 266 | let space = hpjs.qloguniform(0,1,0.1); 267 | assert(gs.numSamples(space) === 11, `qloguniform 0,1,0.1 ${gs.numSamples(space)}`); 268 | space = hpjs.qloguniform(-5,5,1); 269 | assert(gs.numSamples(space) === 11, `qloguniform -5,5,1 ${gs.numSamples(space)}`); 270 | space = hpjs.qloguniform(1,10,2); 271 | assert(gs.numSamples(space) === 5, `qloguniform 1,10,2 ${gs.numSamples(space)}`); 272 | }); 273 | 274 | it('qnormal', () => { 275 | let space = hpjs.qnormal(0, 1, 0.1); 276 | assert(gs.numSamples(space) === 41, `qnormal 0,1,0.1 ${gs.numSamples(space)}`); 277 | space = hpjs.qnormal(-5,5,1); 278 | assert(gs.numSamples(space) === 21, `qnormal -5,5,1 ${gs.numSamples(space)}`); 279 | space = hpjs.qnormal(1,10,2); 280 | assert(gs.numSamples(space) === 21, `qnormal 1,10,2 ${gs.numSamples(space)}`); 281 | 282 | }); 283 | 
it('uniform', () => { 284 | try { 285 | gs.numSamples(hpjs.uniform(0, 1)); 286 | assert(false, 'hpjs.uniform not allowed for grid search'); 287 | } catch (e) { 288 | assert(e.message === 'Can not evaluate length of non-discrete parameter "uniform"', `exception message ${e.message}`); 289 | } 290 | }); 291 | it('loguniform', () => { 292 | try { 293 | gs.numSamples(hpjs.loguniform(0, 1)); 294 | assert(false, 'hpjs.loguniform not allowed for grid search'); 295 | } catch (e) { 296 | assert(e.message === 'Can not evaluate length of non-discrete parameter "loguniform"', `exception message ${e.message}`); 297 | } 298 | }); 299 | it('normal', () => { 300 | try { 301 | gs.numSamples(hpjs.normal(0, 1)); 302 | assert(false, 'hpjs.normal not allowed for grid search'); 303 | } catch (e) { 304 | assert(e.message === 'Can not evaluate length of non-discrete parameter "normal"', `exception message ${e.message}`); 305 | } 306 | }); 307 | it('lognormal', () => { 308 | try { 309 | gs.numSamples(hpjs.lognormal(0, 1)); 310 | assert(false, 'hpjs.lognormal not allowed for grid search'); 311 | } catch (e) { 312 | assert(e.message === 'Can not evaluate length of non-discrete parameter "lognormal"', `exception message ${e.message}`); 313 | } 314 | }); 315 | it('choice grid search', () => { 316 | const space = hpjs.choice( 317 | [ 318 | hpjs.qlognormal(0, 1, 1), //5 319 | hpjs.quniform(-10, 10, 1) //21 320 | ] 321 | ); 322 | assert(gs.numSamples(space) === 26, `choice ${gs.numSamples(space)}`); 323 | }); 324 | }); 325 | 326 | describe('fmin + rand', () => { 327 | it('FMin for x^2 - x + 1', async () => { 328 | const space = hpjs.uniform(-5, 5); 329 | const opt = x => ((x ** 2) - (x + 1)); 330 | assert.snapshot('FMin for x^2 - x + 1', await randFMinSeeded(opt, space)); 331 | }); 332 | it('Hyperparameters space', async () => { 333 | const space = { 334 | x: hpjs.uniform(-5, 5), 335 | y: hpjs.uniform(-5, 5) 336 | }; 337 | const opt = ({ x, y }) => ((x ** 2) + (y ** 2)); 338 | 
assert.snapshot('Hyperparameters space', await randFMinSeeded(opt, space)); 339 | }); 340 | it('Choice selection of expressions', async () => { 341 | const space = hpjs.choice([ 342 | hpjs.lognormal(0, 1), 343 | hpjs.uniform(-10, 10) 344 | ] 345 | ); 346 | const opt = ( x ) => (x ** 2); 347 | assert.snapshot('choice as array space', await randFMinSeeded(opt, space)); 348 | }); 349 | it('Deep learning space', async () => { 350 | const space = { 351 | // Learning rate should be between 0.00001 and 1 352 | learning_rate: 353 | hpjs.loguniform(Math.log(1e-5), Math.log(1)), 354 | use_double_q_learning: 355 | hpjs.choice([true, false]), 356 | layer1_size: hpjs.quniform(10, 100, 1), 357 | layer2_size: hpjs.quniform(10, 100, 1), 358 | layer3_size: hpjs.quniform(10, 100, 1), 359 | future_discount_max: hpjs.uniform(0.5, 0.99), 360 | future_discount_increment: hpjs.loguniform(Math.log(0.001), Math.log(0.1)), 361 | recall_memory_size: hpjs.quniform(1, 100, 1), 362 | recall_memory_num_experiences_per_recall: hpjs.quniform(10, 2000, 1), 363 | num_epochs: hpjs.quniform(1, 10, 1), 364 | }; 365 | 366 | const opt = params => params.learning_rate ** 2; 367 | 368 | assert.snapshot('Deep learning space', await randFMinSeeded(opt, space)); 369 | }); 370 | }); 371 | 372 | --------------------------------------------------------------------------------