├── .babelrc ├── .editorconfig ├── .eslintrc ├── .gitignore ├── .npmignore ├── LICENSE ├── README.md ├── app ├── index.html ├── index.js ├── iris-custom │ ├── data.js │ └── iris.js ├── iris │ ├── data.js │ └── iris.js ├── mnist-conv │ └── mnist.js ├── mnist │ ├── data.js │ └── mnist.js ├── models │ └── mnist-dense │ │ ├── group1-shard1of1 │ │ ├── group2-shard1of1 │ │ ├── group3-shard1of1 │ │ └── model.json └── tiny │ └── tiny.js ├── lib ├── tfjs-model-view.js └── tfjs-model-view.min.js ├── package-lock.json ├── package.json ├── src ├── default.config.js ├── index.js ├── model-parser.js └── renderers │ ├── abstract.renderer.js │ └── canvas.renderer.js ├── stats.json ├── test └── index.spec.js ├── webpack.config.js └── webpack.dev.config.js /.babelrc: -------------------------------------------------------------------------------- 1 | { 2 | "presets": ["env"], 3 | "plugins": ["babel-plugin-add-module-exports", "transform-object-rest-spread"] 4 | } 5 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | indent_style = space 5 | indent_size = 2 6 | end_of_line = LF 7 | charset = utf-8 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | 11 | [*.md] 12 | trim_trailing_whitespace = false 13 | -------------------------------------------------------------------------------- /.eslintrc: -------------------------------------------------------------------------------- 1 | { 2 | "env": { 3 | "browser": true, 4 | "node": true, 5 | "es6": true, 6 | "mocha": true 7 | }, 8 | "parserOptions": { 9 | "ecmaVersion": 6, 10 | "sourceType": "module", 11 | "ecmaFeatures": { 12 | "impliedStrict": true, 13 | "experimentalObjectRestSpread": true 14 | } 15 | }, 16 | "globals": {}, 17 | "rules": { 18 | "no-cond-assign": 2, 19 | "no-constant-condition": 2, 20 | "no-dupe-args": 2, 21 | "no-dupe-keys": 2, 22 | 
"no-duplicate-case": 2, 23 | "no-empty-character-class": 2, 24 | "no-empty": [ 25 | 1, 26 | { 27 | "allowEmptyCatch": true 28 | } 29 | ], 30 | "no-ex-assign": 2, 31 | "no-extra-boolean-cast": 2, 32 | "no-extra-parens": 2, 33 | "no-extra-semi": 2, 34 | "no-func-assign": 2, 35 | "no-inner-declarations": [ 36 | 2, 37 | "both" 38 | ], 39 | "no-invalid-regexp": 2, 40 | "no-irregular-whitespace": [ 41 | 2, 42 | { 43 | "skipStrings": true, 44 | "skipComments": true, 45 | "skipRegExps": true, 46 | "skipTemplates": true 47 | } 48 | ], 49 | "no-negated-in-lhs": 2, 50 | "no-obj-calls": 2, 51 | "no-prototype-builtins": 2, 52 | "no-regex-spaces": 2, 53 | "no-sparse-arrays": 1, 54 | "no-unexpected-multiline": 1, 55 | "no-unreachable": 2, 56 | "no-unsafe-finally": 1, 57 | "use-isnan": 2, 58 | "valid-jsdoc": 0, 59 | "valid-typeof": 2, 60 | "accessor-pairs": 1, 61 | "array-callback-return": 1, 62 | "block-scoped-var": 1, 63 | "curly": 1, 64 | "default-case": 1, 65 | "dot-notation": 1, 66 | "eqeqeq": [ 67 | 2, 68 | "allow-null" 69 | ], 70 | "guard-for-in": 1, 71 | "no-alert": 1, 72 | "no-caller": 2, 73 | "no-eval": 2, 74 | "no-extend-native": 1, 75 | "no-extra-bind": 1, 76 | 77 | "no-floating-decimal": 2, 78 | 79 | "no-implicit-globals": 1, 80 | 81 | "no-implied-eval": 2, 82 | 83 | "no-lone-blocks": 1, 84 | 85 | "no-loop-func": 1, 86 | 87 | "no-multi-spaces": 1, 88 | 89 | "no-redeclare": [ 90 | 2, 91 | { 92 | "builtinGlobals": true 93 | } 94 | ], 95 | 96 | "no-sequences": 1, 97 | 98 | "vars-on-top": 1, 99 | 100 | "wrap-iife": 1, 101 | 102 | "yoda": 1, 103 | "strict": [ 104 | 1, 105 | "function" 106 | ], 107 | 108 | "no-delete-var": 2, 109 | 110 | "no-restricted-globals": 2, 111 | 112 | "no-unused-vars": 1, 113 | 114 | "no-use-before-define": 2, 115 | 116 | "no-undef": 2, 117 | 118 | "no-undef-init": 1, 119 | 120 | "no-undefined": 1, 121 | 122 | "no-shadow": [ 123 | 2, 124 | { 125 | "hoist": "functions" 126 | } 127 | ], 128 | 129 | "array-bracket-spacing": [ 130 | 1, 131 | "never" 
132 | ], 133 | 134 | "block-spacing": [ 135 | 1, 136 | "never" 137 | ], 138 | 139 | "brace-style": [ 140 | 1, 141 | "1tbs", 142 | { 143 | "allowSingleLine": true 144 | } 145 | ], 146 | 147 | "comma-dangle": 1, 148 | 149 | "comma-spacing": [ 150 | 1, 151 | { 152 | "before": false, 153 | "after": true 154 | } 155 | ], 156 | 157 | "comma-style": [ 158 | 1, 159 | "last" 160 | ], 161 | 162 | "computed-property-spacing": [ 163 | 1, 164 | "never" 165 | ], 166 | 167 | "consistent-this": [ 168 | 1, 169 | "that" 170 | ], 171 | 172 | "eol-last": 1, 173 | 174 | "id-blacklist": 1, 175 | 176 | 177 | "key-spacing": [ 178 | 1, 179 | { 180 | "afterColon": true 181 | } 182 | ], 183 | 184 | "keyword-spacing": 1, 185 | 186 | 187 | 188 | "new-cap": 1, 189 | 190 | "new-parens": 1, 191 | 192 | "no-mixed-spaces-and-tabs": 1, 193 | 194 | "no-multiple-empty-lines": [ 195 | 1, 196 | { 197 | "max": 1, 198 | "maxBOF": 0, 199 | "maxEOF": 1 200 | } 201 | ], 202 | 203 | "no-spaced-func": 1, 204 | 205 | 206 | "no-unneeded-ternary": 1, 207 | 208 | "no-whitespace-before-property": 1, 209 | 210 | "object-curly-spacing": [ 211 | 0 212 | ], 213 | 214 | "quotes": [ 215 | 1, 216 | "single", 217 | { 218 | "avoidEscape": true, 219 | "allowTemplateLiterals": true 220 | } 221 | ], 222 | 223 | "semi-spacing": [ 224 | 1, 225 | { 226 | "before": false, 227 | "after": true 228 | } 229 | ], 230 | 231 | 232 | "space-before-blocks": 1, 233 | 234 | "space-before-function-paren": [ 235 | 1, 236 | { 237 | "anonymous": "always", 238 | "named": "never" 239 | } 240 | ], 241 | 242 | "space-in-parens": [ 243 | 1, 244 | "never" 245 | ], 246 | 247 | "space-infix-ops": 1, 248 | 249 | "spaced-comment": [ 250 | 1, 251 | "always" 252 | ], 253 | 254 | "arrow-body-style": [ 255 | 1, 256 | "as-needed" 257 | ], 258 | 259 | "arrow-parens": [ 260 | 1, 261 | "as-needed" 262 | ], 263 | 264 | "arrow-spacing": [ 265 | 1, 266 | { 267 | "before": true, 268 | "after": true 269 | } 270 | ], 271 | 272 | "constructor-super": 2, 273 | 274 | 
"generator-star-spacing": [ 275 | 1, 276 | { 277 | "before": true, 278 | "after": false 279 | } 280 | ], 281 | 282 | "no-class-assign": 2, 283 | 284 | "no-confusing-arrow": 0, 285 | 286 | "no-const-assign": 2, 287 | 288 | "no-dupe-class-members": 1, 289 | 290 | "no-duplicate-imports": [ 291 | 1, 292 | { 293 | "includeExports": true 294 | } 295 | ], 296 | 297 | "no-new-symbol": 2, 298 | 299 | "no-restricted-imports": [ 300 | 2, 301 | "assert", 302 | "buffer", 303 | "child_process", 304 | "cluster", 305 | "crypto", 306 | "dgram", 307 | "dns", 308 | "domain", 309 | "events", 310 | "freelist", 311 | "fs", 312 | "http", 313 | "https", 314 | "module", 315 | "net", 316 | "os", 317 | "path", 318 | "punycode", 319 | "querystring", 320 | "readline", 321 | "repl", 322 | "smalloc", 323 | "stream", 324 | "string_decoder", 325 | "sys", 326 | "timers", 327 | "tls", 328 | "tracing", 329 | "tty", 330 | "url", 331 | "util", 332 | "vm", 333 | "zlib" 334 | ], 335 | 336 | "no-this-before-super": 2, 337 | 338 | "no-useless-computed-key": 1, 339 | 340 | "no-useless-constructor": 1, 341 | 342 | "no-useless-rename": 1, 343 | 344 | "no-var": 1, 345 | 346 | "prefer-rest-params": 1, 347 | 348 | "prefer-template": 1, 349 | 350 | "require-yield": 2, 351 | 352 | "rest-spread-spacing": [ 353 | 1, 354 | "never" 355 | ], 356 | 357 | "template-curly-spacing": [ 358 | 1, 359 | "never" 360 | ], 361 | 362 | "yield-star-spacing": [ 363 | 1, 364 | { 365 | "before": true, 366 | "after": false 367 | } 368 | ] 369 | } 370 | } 371 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Logs 2 | logs 3 | *.log 4 | 5 | ignore 6 | 7 | # Runtime data 8 | pids 9 | *.pid 10 | *.seed 11 | 12 | # Directory for instrumented libs generated by jscoverage/JSCover 13 | lib-cov 14 | 15 | # Coverage directory used by tools like istanbul 16 | coverage 17 | .nyc_output 18 | 19 | # Grunt intermediate 
storage (http://gruntjs.com/creating-plugins#storing-task-files) 20 | .grunt 21 | 22 | # node-waf configuration 23 | .lock-wscript 24 | 25 | # Compiled binary addons (http://nodejs.org/api/addons.html) 26 | build/Release 27 | 28 | # Dependency directory 29 | # https://www.npmjs.org/doc/misc/npm-faq.html#should-i-check-my-node_modules-folder-into-git 30 | node_modules 31 | 32 | # Remove some common IDE working directories 33 | .idea 34 | .vscode 35 | 36 | .DS_Store 37 | -------------------------------------------------------------------------------- /.npmignore: -------------------------------------------------------------------------------- 1 | app 2 | ignore 3 | test 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Cornel Stefanache 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tfjs-model-view 2 | 3 | __tfjs-model-view__ is a library for _in-browser_ visualization of neural networks, intended for use with TensorFlow.js. 4 | 5 | Features: 6 | 7 | * Automatic rendering of the neural net 8 | * Automatic updates of weights/biases/values 9 | 10 | The library also aims to be flexible and easy to incorporate into your project. 11 | 12 | ## Demos 13 | 14 | - [Movielens recommendation using Tensorflow.js](https://beta.observablehq.com/@cstefanache/movielens-recommendation-using-tensorflow-js) 15 | - [Iris Prediction with Custom Node Renderer](https://beta.observablehq.com/@cstefanache/tensorflow-js-model-viewer-iris) 16 | - [MNIST Prediction](https://beta.observablehq.com/@cstefanache/mnist-tensorflow-js-network-view-tfjs-model-view) 17 | - [Multiple Input Rendering](https://beta.observablehq.com/@cstefanache/tensorflow-js-model-view-multiple-input-test) 18 | 19 | ## Sample rendering output 20 | 21 |  22 | 23 | 24 | ## Usage 25 | 26 | Simple: 27 | ``` 28 | new ModelView(model) 29 | ``` 30 | 31 | Customized: 32 | ``` 33 | new ModelView(model, { 34 | printStats: true, 35 | radius: 25, 36 | renderLinks: true, 37 | xOffset: 100, 38 | renderNode(ctx, node) { 39 | const { x, y, value } = node; 40 | ctx.font = '10px Arial'; 41 | ctx.fillStyle = '#000'; 42 | ctx.textAlign = 'center'; 43 | ctx.textBaseline = 'middle'; 44 | ctx.fillText(Math.round(value * 100) / 100, x, y); 45 | }, 46 | onBeginRender: renderer => { 47 | const { renderContext } = renderer; 48 | renderContext.fillStyle = '#000'; 49 |
renderContext.textAlign = 'end'; 50 | renderContext.font = '12px Arial'; 51 | renderContext.fillText('Sepal Length (cm)', 110, 110); 52 | renderContext.fillText('Sepal Width (cm)', 110, 136); 53 | renderContext.fillText('Petal Length (cm)', 110, 163); 54 | renderContext.fillText('Petal Width (cm)', 110, 190); 55 | 56 | renderContext.textAlign = 'start'; 57 | renderContext.fillText('Setosa', renderer.width - 60, 95); 58 | renderContext.fillText('Versicolor', renderer.width - 60, 150); 59 | renderContext.fillText('Virginica', renderer.width - 60, 205); 60 | }, 61 | layer: { 62 | 'dense_Dense1_input': { 63 | domain: [0, 8], 64 | color: [165, 130, 180] 65 | }, 66 | 'dense_Dense1/dense_Dense1': { 67 | color: [125, 125, 125] 68 | }, 69 | 'dense_Dense2/dense_Dense2': { 70 | color: [125, 125, 125] 71 | }, 72 | 'dense_Dense3/dense_Dense3': { 73 | nodePadding: 30 74 | } 75 | } 76 | }); 77 | ``` 78 | 79 | Customizing: 80 | ``` 81 | new ModelView(model, { 82 | /** renders the list of layers **/ 83 | printStats: true, 84 | 85 | /** Default domain for color intensity **/ 86 | domain: [0, 1], 87 | 88 | /** Default node radius **/ 89 | radius: 6, 90 | 91 | /** Default node padding **/ 92 | nodePadding: 2, 93 | 94 | /** Default layer padding **/ 95 | layerPadding: 20, 96 | 97 | /** Default group padding **/ 98 | groupPadding: 1, 99 | 100 | /** Horizontal padding **/ 101 | xPadding: 10, 102 | 103 | /** Vertical padding **/ 104 | yPadding: 10, 105 | 106 | /** Render links between layers **/ 107 | renderLinks: false, 108 | 109 | /** Stroke node outer circle **/ 110 | nodeStroke: true, 111 | 112 | /** custom render node function **/ 113 | renderNode: (ctx, node, nodeIdx) => {...}, 114 | 115 | /** If present will be executed before node rendering **/ 116 | onBeginRender: renderer => { ... }, 117 | 118 | /** If present will be executed after all node rendering is finished **/ 119 | onEndRender: renderer => { ... 
}, 120 | 121 | /** Personalized layer configuration **/ 122 | /** All defaults can be overridden for each layer individually **/ 123 | layer: { 124 | 'layerName': { 125 | /** Any property mentioned above **/ 126 | 127 | /** Reshape layer to another [cols, rows, groups] layout **/ 128 | reshape: [4, 4, 8] 129 | } 130 | } 131 | }); 132 | ``` 133 | 134 | ## Installation 135 | 136 | You can install this using npm with 137 | 138 | ``` 139 | npm install tfjs-model-view 140 | ``` 141 | 142 | or using yarn with 143 | 144 | ``` 145 | yarn add tfjs-model-view 146 | ``` 147 | 148 | ## Building from source 149 | 150 | To build the library, you need to have Node.js installed. We use `yarn` 151 | instead of `npm` but you can use either. 152 | 153 | First install dependencies with 154 | 155 | ``` 156 | yarn 157 | ``` 158 | 159 | or 160 | 161 | ``` 162 | npm install 163 | ``` 164 | 165 | You can start the dev environment using 166 | 167 | ``` 168 | yarn dev 169 | ``` 170 | 171 | or 172 | 173 | ``` 174 | npm run dev 175 | ``` 176 | 177 | 178 | ## Sample Usage 179 | 180 | 181 | ## Issues 182 | 183 | Found a bug or have a feature request? Please file an [issue](https://github.com/cstefanache/tfjs-model-view/issues/new) 184 | -------------------------------------------------------------------------------- /app/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 |
4 |${text}
`;
92 | }
93 | }
94 |
// Populate the demo navigation before anything is started.
prepareMenu();

// Start the currently selected demo, if there is one.
if (runner) {
  const { executor: selectedExecutor } = runner;
  load(selectedExecutor, runner);
}
100 |
--------------------------------------------------------------------------------
/app/iris-custom/data.js:
--------------------------------------------------------------------------------
1 | import * as tf from '@tensorflow/tfjs';
2 |
/**
 * Iris species names, indexed by the numeric class label stored in the last
 * column of each IRIS_DATA row (the label `getIrisData` buckets examples by).
 */
export const IRIS_CLASSES = [
  'Iris-setosa',
  'Iris-versicolor',
  'Iris-virginica'
];

/** Number of Iris species; always equal to `IRIS_CLASSES.length`. */
export const IRIS_NUM_CLASSES = IRIS_CLASSES.length;
5 |
6 | // Iris flowers data: one row per sample, four feature values (presumably sepal length, sepal width, petal length, petal width in cm, matching the input labels drawn in iris.js) followed by the class index (0, 1 or 2) in the last column, which getIrisData uses to bucket examples. NOTE(review): rows 35 and 38 are identical to row 10 ([4.9, 3.1, 1.5, 0.1]); UCI's iris.names reportedly lists these two samples as transcription errors - confirm before relying on exact values. Source:
7 | // https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
8 | export const IRIS_DATA = [
9 | [5.1, 3.5, 1.4, 0.2, 0],
10 | [4.9, 3.0, 1.4, 0.2, 0],
11 | [4.7, 3.2, 1.3, 0.2, 0],
12 | [4.6, 3.1, 1.5, 0.2, 0],
13 | [5.0, 3.6, 1.4, 0.2, 0],
14 | [5.4, 3.9, 1.7, 0.4, 0],
15 | [4.6, 3.4, 1.4, 0.3, 0],
16 | [5.0, 3.4, 1.5, 0.2, 0],
17 | [4.4, 2.9, 1.4, 0.2, 0],
18 | [4.9, 3.1, 1.5, 0.1, 0],
19 | [5.4, 3.7, 1.5, 0.2, 0],
20 | [4.8, 3.4, 1.6, 0.2, 0],
21 | [4.8, 3.0, 1.4, 0.1, 0],
22 | [4.3, 3.0, 1.1, 0.1, 0],
23 | [5.8, 4.0, 1.2, 0.2, 0],
24 | [5.7, 4.4, 1.5, 0.4, 0],
25 | [5.4, 3.9, 1.3, 0.4, 0],
26 | [5.1, 3.5, 1.4, 0.3, 0],
27 | [5.7, 3.8, 1.7, 0.3, 0],
28 | [5.1, 3.8, 1.5, 0.3, 0],
29 | [5.4, 3.4, 1.7, 0.2, 0],
30 | [5.1, 3.7, 1.5, 0.4, 0],
31 | [4.6, 3.6, 1.0, 0.2, 0],
32 | [5.1, 3.3, 1.7, 0.5, 0],
33 | [4.8, 3.4, 1.9, 0.2, 0],
34 | [5.0, 3.0, 1.6, 0.2, 0],
35 | [5.0, 3.4, 1.6, 0.4, 0],
36 | [5.2, 3.5, 1.5, 0.2, 0],
37 | [5.2, 3.4, 1.4, 0.2, 0],
38 | [4.7, 3.2, 1.6, 0.2, 0],
39 | [4.8, 3.1, 1.6, 0.2, 0],
40 | [5.4, 3.4, 1.5, 0.4, 0],
41 | [5.2, 4.1, 1.5, 0.1, 0],
42 | [5.5, 4.2, 1.4, 0.2, 0],
43 | [4.9, 3.1, 1.5, 0.1, 0],
44 | [5.0, 3.2, 1.2, 0.2, 0],
45 | [5.5, 3.5, 1.3, 0.2, 0],
46 | [4.9, 3.1, 1.5, 0.1, 0],
47 | [4.4, 3.0, 1.3, 0.2, 0],
48 | [5.1, 3.4, 1.5, 0.2, 0],
49 | [5.0, 3.5, 1.3, 0.3, 0],
50 | [4.5, 2.3, 1.3, 0.3, 0],
51 | [4.4, 3.2, 1.3, 0.2, 0],
52 | [5.0, 3.5, 1.6, 0.6, 0],
53 | [5.1, 3.8, 1.9, 0.4, 0],
54 | [4.8, 3.0, 1.4, 0.3, 0],
55 | [5.1, 3.8, 1.6, 0.2, 0],
56 | [4.6, 3.2, 1.4, 0.2, 0],
57 | [5.3, 3.7, 1.5, 0.2, 0],
58 | [5.0, 3.3, 1.4, 0.2, 0],
59 | [7.0, 3.2, 4.7, 1.4, 1],
60 | [6.4, 3.2, 4.5, 1.5, 1],
61 | [6.9, 3.1, 4.9, 1.5, 1],
62 | [5.5, 2.3, 4.0, 1.3, 1],
63 | [6.5, 2.8, 4.6, 1.5, 1],
64 | [5.7, 2.8, 4.5, 1.3, 1],
65 | [6.3, 3.3, 4.7, 1.6, 1],
66 | [4.9, 2.4, 3.3, 1.0, 1],
67 | [6.6, 2.9, 4.6, 1.3, 1],
68 | [5.2, 2.7, 3.9, 1.4, 1],
69 | [5.0, 2.0, 3.5, 1.0, 1],
70 | [5.9, 3.0, 4.2, 1.5, 1],
71 | [6.0, 2.2, 4.0, 1.0, 1],
72 | [6.1, 2.9, 4.7, 1.4, 1],
73 | [5.6, 2.9, 3.6, 1.3, 1],
74 | [6.7, 3.1, 4.4, 1.4, 1],
75 | [5.6, 3.0, 4.5, 1.5, 1],
76 | [5.8, 2.7, 4.1, 1.0, 1],
77 | [6.2, 2.2, 4.5, 1.5, 1],
78 | [5.6, 2.5, 3.9, 1.1, 1],
79 | [5.9, 3.2, 4.8, 1.8, 1],
80 | [6.1, 2.8, 4.0, 1.3, 1],
81 | [6.3, 2.5, 4.9, 1.5, 1],
82 | [6.1, 2.8, 4.7, 1.2, 1],
83 | [6.4, 2.9, 4.3, 1.3, 1],
84 | [6.6, 3.0, 4.4, 1.4, 1],
85 | [6.8, 2.8, 4.8, 1.4, 1],
86 | [6.7, 3.0, 5.0, 1.7, 1],
87 | [6.0, 2.9, 4.5, 1.5, 1],
88 | [5.7, 2.6, 3.5, 1.0, 1],
89 | [5.5, 2.4, 3.8, 1.1, 1],
90 | [5.5, 2.4, 3.7, 1.0, 1],
91 | [5.8, 2.7, 3.9, 1.2, 1],
92 | [6.0, 2.7, 5.1, 1.6, 1],
93 | [5.4, 3.0, 4.5, 1.5, 1],
94 | [6.0, 3.4, 4.5, 1.6, 1],
95 | [6.7, 3.1, 4.7, 1.5, 1],
96 | [6.3, 2.3, 4.4, 1.3, 1],
97 | [5.6, 3.0, 4.1, 1.3, 1],
98 | [5.5, 2.5, 4.0, 1.3, 1],
99 | [5.5, 2.6, 4.4, 1.2, 1],
100 | [6.1, 3.0, 4.6, 1.4, 1],
101 | [5.8, 2.6, 4.0, 1.2, 1],
102 | [5.0, 2.3, 3.3, 1.0, 1],
103 | [5.6, 2.7, 4.2, 1.3, 1],
104 | [5.7, 3.0, 4.2, 1.2, 1],
105 | [5.7, 2.9, 4.2, 1.3, 1],
106 | [6.2, 2.9, 4.3, 1.3, 1],
107 | [5.1, 2.5, 3.0, 1.1, 1],
108 | [5.7, 2.8, 4.1, 1.3, 1],
109 | [6.3, 3.3, 6.0, 2.5, 2],
110 | [5.8, 2.7, 5.1, 1.9, 2],
111 | [7.1, 3.0, 5.9, 2.1, 2],
112 | [6.3, 2.9, 5.6, 1.8, 2],
113 | [6.5, 3.0, 5.8, 2.2, 2],
114 | [7.6, 3.0, 6.6, 2.1, 2],
115 | [4.9, 2.5, 4.5, 1.7, 2],
116 | [7.3, 2.9, 6.3, 1.8, 2],
117 | [6.7, 2.5, 5.8, 1.8, 2],
118 | [7.2, 3.6, 6.1, 2.5, 2],
119 | [6.5, 3.2, 5.1, 2.0, 2],
120 | [6.4, 2.7, 5.3, 1.9, 2],
121 | [6.8, 3.0, 5.5, 2.1, 2],
122 | [5.7, 2.5, 5.0, 2.0, 2],
123 | [5.8, 2.8, 5.1, 2.4, 2],
124 | [6.4, 3.2, 5.3, 2.3, 2],
125 | [6.5, 3.0, 5.5, 1.8, 2],
126 | [7.7, 3.8, 6.7, 2.2, 2],
127 | [7.7, 2.6, 6.9, 2.3, 2],
128 | [6.0, 2.2, 5.0, 1.5, 2],
129 | [6.9, 3.2, 5.7, 2.3, 2],
130 | [5.6, 2.8, 4.9, 2.0, 2],
131 | [7.7, 2.8, 6.7, 2.0, 2],
132 | [6.3, 2.7, 4.9, 1.8, 2],
133 | [6.7, 3.3, 5.7, 2.1, 2],
134 | [7.2, 3.2, 6.0, 1.8, 2],
135 | [6.2, 2.8, 4.8, 1.8, 2],
136 | [6.1, 3.0, 4.9, 1.8, 2],
137 | [6.4, 2.8, 5.6, 2.1, 2],
138 | [7.2, 3.0, 5.8, 1.6, 2],
139 | [7.4, 2.8, 6.1, 1.9, 2],
140 | [7.9, 3.8, 6.4, 2.0, 2],
141 | [6.4, 2.8, 5.6, 2.2, 2],
142 | [6.3, 2.8, 5.1, 1.5, 2],
143 | [6.1, 2.6, 5.6, 1.4, 2],
144 | [7.7, 3.0, 6.1, 2.3, 2],
145 | [6.3, 3.4, 5.6, 2.4, 2],
146 | [6.4, 3.1, 5.5, 1.8, 2],
147 | [6.0, 3.0, 4.8, 1.8, 2],
148 | [6.9, 3.1, 5.4, 2.1, 2],
149 | [6.7, 3.1, 5.6, 2.4, 2],
150 | [6.9, 3.1, 5.1, 2.3, 2],
151 | [5.8, 2.7, 5.1, 1.9, 2],
152 | [6.8, 3.2, 5.9, 2.3, 2],
153 | [6.7, 3.3, 5.7, 2.5, 2],
154 | [6.7, 3.0, 5.2, 2.3, 2],
155 | [6.3, 2.5, 5.0, 1.9, 2],
156 | [6.5, 3.0, 5.2, 2.0, 2],
157 | [6.2, 3.4, 5.4, 2.3, 2],
158 | [5.9, 3.0, 5.1, 1.8, 2]
159 | ];
160 |
/**
 * Convert Iris data arrays to `tf.Tensor`s.
 *
 * The examples are shuffled with one shared permutation, so each feature row
 * stays paired with its label. The last `testSplit` fraction of the shuffled
 * examples is then held out as test data.
 *
 * @param data The Iris input feature data, an `Array` of `Array`s, each element
 *   of which is assumed to be a length-4 `Array` of features in the column
 *   order of `IRIS_DATA` (sepal length, sepal width, petal length, petal width).
 * @param targets An `Array` of numbers, with values from the set {0, 1, 2},
 *   representing the true category of the Iris flower. Must have the same
 *   array length as `data`.
 * @param testSplit Fraction of the (shuffled) data at the end to split as test
 *   data: a number between 0 and 1.
 * @return A length-4 `Array`, with
 *   - training data as `tf.Tensor` of shape [numTrainExamples, 4].
 *   - training one-hot labels as a `tf.Tensor` of shape [numTrainExamples, 3].
 *   - test data as `tf.Tensor` of shape [numTestExamples, 4].
 *   - test one-hot labels as a `tf.Tensor` of shape [numTestExamples, 3].
 * @throws {Error} If `data` and `targets` have different lengths.
 */
function convertToTensors(data, targets, testSplit) {
  const numExamples = data.length;
  if (numExamples !== targets.length) {
    throw new Error('data and targets have different numbers of examples');
  }

  // Randomly shuffle `data` and `targets` using the same index permutation.
  const indices = Array.from({ length: numExamples }, (_, i) => i);
  tf.util.shuffle(indices);
  const shuffledData = indices.map(i => data[i]);
  const shuffledTargets = indices.map(i => targets[i]);

  // Split the data into a training set and a test set, based on `testSplit`.
  const numTestExamples = Math.round(numExamples * testSplit);
  const numTrainExamples = numExamples - numTestExamples;

  const xDims = shuffledData[0].length;

  // Create a 2D `tf.Tensor` to hold the feature data.
  const xs = tf.tensor2d(shuffledData, [numExamples, xDims]);

  // Create a 1D `tf.Tensor` to hold the labels, and convert the number label
  // from the set {0, 1, 2} into one-hot encoding (e.g., 0 --> [1, 0, 0]).
  const ys = tf.oneHot(tf.tensor1d(shuffledTargets).toInt(), IRIS_NUM_CLASSES);

  // Split features and labels at the same row so test labels line up with
  // test features: rows [0, numTrainExamples) train, the remainder test.
  const xTrain = xs.slice([0, 0], [numTrainExamples, xDims]);
  const xTest = xs.slice([numTrainExamples, 0], [numTestExamples, xDims]);
  const yTrain = ys.slice([0, 0], [numTrainExamples, IRIS_NUM_CLASSES]);
  const yTest = ys.slice([numTrainExamples, 0], [numTestExamples, IRIS_NUM_CLASSES]);
  return [xTrain, yTrain, xTest, yTest];
}
218 |
/**
 * Obtains Iris data, split into training and test sets.
 *
 * Examples are first bucketed by class. Each class is then shuffled and split
 * separately by `convertToTensors`, so every class contributes the same
 * fraction to the test set. The per-class pieces are concatenated along the
 * example axis (axis 0).
 *
 * All intermediate tensors are created inside `tf.tidy`. Only the four
 * returned tensors survive it, and the caller is responsible for them.
 *
 * @param testSplit Fraction of each class's examples to hold out as test
 *   data: a number between 0 and 1.
 * @return A length-4 `Array` of `tf.Tensor`s:
 *   - training features, shape [numTrainExamples, 4].
 *   - training one-hot labels, shape [numTrainExamples, IRIS_NUM_CLASSES].
 *   - test features, shape [numTestExamples, 4].
 *   - test one-hot labels, shape [numTestExamples, IRIS_NUM_CLASSES].
 */
export function getIrisData(testSplit) {
  return tf.tidy(() => {
    // One bucket per class; the last column of each row is its class index.
    const dataByClass = IRIS_CLASSES.map(() => []);
    const targetsByClass = IRIS_CLASSES.map(() => []);
    for (const example of IRIS_DATA) {
      const target = example[example.length - 1];
      const data = example.slice(0, example.length - 1);
      dataByClass[target].push(data);
      targetsByClass[target].push(target);
    }

    const xTrains = [];
    const yTrains = [];
    const xTests = [];
    const yTests = [];
    for (let i = 0; i < IRIS_CLASSES.length; ++i) {
      const [xTrain, yTrain, xTest, yTest] =
        convertToTensors(dataByClass[i], targetsByClass[i], testSplit);
      xTrains.push(xTrain);
      yTrains.push(yTrain);
      xTests.push(xTest);
      yTests.push(yTest);
    }

    const concatAxis = 0;
    return [
      tf.concat(xTrains, concatAxis), tf.concat(yTrains, concatAxis),
      tf.concat(xTests, concatAxis), tf.concat(yTests, concatAxis)
    ];
  });
}
270 |
--------------------------------------------------------------------------------
/app/iris-custom/iris.js:
--------------------------------------------------------------------------------
1 | import * as tf from '@tensorflow/tfjs';
2 | import ModelView from '../../src';
3 |
4 | import {
5 | IRIS_DATA,
6 | getIrisData
7 | } from './data';
8 |
/**
 * Builds, compiles and trains a small dense Iris classifier, attaching a
 * ModelView to it before training starts. After training it predicts on a
 * random IRIS_DATA sample every second.
 *
 * @param xTrain Training features; `xTrain.shape[1]` sets the input width.
 * @param yTrain Training one-hot labels.
 * @param xTest Validation features.
 * @param yTest Validation one-hot labels.
 * @return {Promise} Resolves to the trained model once `fit` completes.
 */
async function trainModel(xTrain, yTrain, xTest, yTest) {
  // Topology: two 10-unit sigmoid dense layers followed by a 3-way softmax.
  const model = tf.sequential();
  const denseConfigs = [
    { units: 10, activation: 'sigmoid', inputShape: [xTrain.shape[1]] },
    { units: 10, activation: 'sigmoid', inputShape: [10] },
    { units: 3, activation: 'softmax' }
  ];
  denseConfigs.forEach(config => model.add(tf.layers.dense(config)));

  model.summary();

  model.compile({
    optimizer: tf.train.adam(0.02),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy']
  });

  // Canvas annotations as [text, y]: input labels at x = 110 (right-aligned),
  // class labels 60px from the right edge (left-aligned).
  const inputLabels = [
    ['Sepal Length (cm)', 110],
    ['Sepal Width (cm)', 136],
    ['Petal Length (cm)', 163],
    ['Petal Width (cm)', 190]
  ];
  const outputLabels = [
    ['Setosa', 95],
    ['Versicolor', 150],
    ['Virginica', 205]
  ];

  new ModelView(model, {
    printStats: true,
    radius: 25,
    renderLinks: true,
    xOffset: 100,
    renderNode(ctx, { x, y, value }) {
      // Print each node's value rounded to two decimals at its centre.
      ctx.font = '10px Arial';
      ctx.fillStyle = '#000';
      ctx.textAlign = 'center';
      ctx.textBaseline = 'middle';
      ctx.fillText(Math.round(value * 100) / 100, x, y);
    },
    onBeginRender: renderer => {
      const ctx = renderer.renderContext;
      ctx.fillStyle = '#000';
      ctx.textAlign = 'end';
      ctx.font = '12px Arial';
      inputLabels.forEach(([text, top]) => ctx.fillText(text, 110, top));

      ctx.textAlign = 'start';
      const labelX = renderer.width - 60;
      outputLabels.forEach(([text, top]) => ctx.fillText(text, labelX, top));
    },
    layer: {
      'dense_Dense1_input': {
        domain: [0, 8],
        color: [165, 130, 180]
      },
      'dense_Dense1/dense_Dense1': {
        color: [125, 125, 125]
      },
      'dense_Dense2/dense_Dense2': {
        color: [125, 125, 125]
      },
      'dense_Dense3/dense_Dense3': {
        nodePadding: 30
      }
    }
  });

  await model.fit(xTrain, yTrain, {
    epochs: 100,
    validationData: [xTest, yTest]
  });

  // Feed a random sample's four features through the model every second.
  // NOTE(review): neither the input tensor nor the prediction is disposed;
  // confirm whether ModelView reads them asynchronously before wrapping this
  // in tf.tidy.
  const randomFeatures = () =>
    IRIS_DATA[Math.floor(Math.random() * IRIS_DATA.length)].slice(0, 4);
  setInterval(() => {
    model.predict(tf.tensor([randomFeatures()]));
  }, 1000);

  return model;
}
95 |
const [xTrain, yTrain, xTest, yTest] = getIrisData(0.15);

/**
 * Demo entry point: trains the Iris model on the module-level split (15% of
 * each class held out for validation) and attaches its ModelView.
 *
 * The training promise is returned rather than dropped, so callers can await
 * completion and training failures reach the caller instead of surfacing as
 * unhandled rejections.
 *
 * @return {Promise} Resolves to the trained model.
 */
export default async () => trainModel(xTrain, yTrain, xTest, yTest);
101 |
--------------------------------------------------------------------------------
/app/iris/data.js:
--------------------------------------------------------------------------------
1 | import * as tf from '@tensorflow/tfjs';
2 |
/**
 * Iris species names, presumably indexed by the numeric class label (0, 1 or
 * 2) stored in the last column of each IRIS_DATA row.
 */
export const IRIS_CLASSES = [
  'Iris-setosa',
  'Iris-versicolor',
  'Iris-virginica'
];

/** Number of Iris species; always equal to `IRIS_CLASSES.length`. */
export const IRIS_NUM_CLASSES = IRIS_CLASSES.length;
5 |
6 | // Iris flowers data: one row per sample, four feature values (presumably sepal length, sepal width, petal length, petal width in cm, as in the UCI file) followed by the class index (0, 1 or 2) in the last column. NOTE(review): rows 35 and 38 are identical to row 10 ([4.9, 3.1, 1.5, 0.1]); UCI's iris.names reportedly lists these two samples as transcription errors - confirm before relying on exact values. Source:
7 | // https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
8 | export const IRIS_DATA = [
9 | [5.1, 3.5, 1.4, 0.2, 0],
10 | [4.9, 3.0, 1.4, 0.2, 0],
11 | [4.7, 3.2, 1.3, 0.2, 0],
12 | [4.6, 3.1, 1.5, 0.2, 0],
13 | [5.0, 3.6, 1.4, 0.2, 0],
14 | [5.4, 3.9, 1.7, 0.4, 0],
15 | [4.6, 3.4, 1.4, 0.3, 0],
16 | [5.0, 3.4, 1.5, 0.2, 0],
17 | [4.4, 2.9, 1.4, 0.2, 0],
18 | [4.9, 3.1, 1.5, 0.1, 0],
19 | [5.4, 3.7, 1.5, 0.2, 0],
20 | [4.8, 3.4, 1.6, 0.2, 0],
21 | [4.8, 3.0, 1.4, 0.1, 0],
22 | [4.3, 3.0, 1.1, 0.1, 0],
23 | [5.8, 4.0, 1.2, 0.2, 0],
24 | [5.7, 4.4, 1.5, 0.4, 0],
25 | [5.4, 3.9, 1.3, 0.4, 0],
26 | [5.1, 3.5, 1.4, 0.3, 0],
27 | [5.7, 3.8, 1.7, 0.3, 0],
28 | [5.1, 3.8, 1.5, 0.3, 0],
29 | [5.4, 3.4, 1.7, 0.2, 0],
30 | [5.1, 3.7, 1.5, 0.4, 0],
31 | [4.6, 3.6, 1.0, 0.2, 0],
32 | [5.1, 3.3, 1.7, 0.5, 0],
33 | [4.8, 3.4, 1.9, 0.2, 0],
34 | [5.0, 3.0, 1.6, 0.2, 0],
35 | [5.0, 3.4, 1.6, 0.4, 0],
36 | [5.2, 3.5, 1.5, 0.2, 0],
37 | [5.2, 3.4, 1.4, 0.2, 0],
38 | [4.7, 3.2, 1.6, 0.2, 0],
39 | [4.8, 3.1, 1.6, 0.2, 0],
40 | [5.4, 3.4, 1.5, 0.4, 0],
41 | [5.2, 4.1, 1.5, 0.1, 0],
42 | [5.5, 4.2, 1.4, 0.2, 0],
43 | [4.9, 3.1, 1.5, 0.1, 0],
44 | [5.0, 3.2, 1.2, 0.2, 0],
45 | [5.5, 3.5, 1.3, 0.2, 0],
46 | [4.9, 3.1, 1.5, 0.1, 0],
47 | [4.4, 3.0, 1.3, 0.2, 0],
48 | [5.1, 3.4, 1.5, 0.2, 0],
49 | [5.0, 3.5, 1.3, 0.3, 0],
50 | [4.5, 2.3, 1.3, 0.3, 0],
51 | [4.4, 3.2, 1.3, 0.2, 0],
52 | [5.0, 3.5, 1.6, 0.6, 0],
53 | [5.1, 3.8, 1.9, 0.4, 0],
54 | [4.8, 3.0, 1.4, 0.3, 0],
55 | [5.1, 3.8, 1.6, 0.2, 0],
56 | [4.6, 3.2, 1.4, 0.2, 0],
57 | [5.3, 3.7, 1.5, 0.2, 0],
58 | [5.0, 3.3, 1.4, 0.2, 0],
59 | [7.0, 3.2, 4.7, 1.4, 1],
60 | [6.4, 3.2, 4.5, 1.5, 1],
61 | [6.9, 3.1, 4.9, 1.5, 1],
62 | [5.5, 2.3, 4.0, 1.3, 1],
63 | [6.5, 2.8, 4.6, 1.5, 1],
64 | [5.7, 2.8, 4.5, 1.3, 1],
65 | [6.3, 3.3, 4.7, 1.6, 1],
66 | [4.9, 2.4, 3.3, 1.0, 1],
67 | [6.6, 2.9, 4.6, 1.3, 1],
68 | [5.2, 2.7, 3.9, 1.4, 1],
69 | [5.0, 2.0, 3.5, 1.0, 1],
70 | [5.9, 3.0, 4.2, 1.5, 1],
71 | [6.0, 2.2, 4.0, 1.0, 1],
72 | [6.1, 2.9, 4.7, 1.4, 1],
73 | [5.6, 2.9, 3.6, 1.3, 1],
74 | [6.7, 3.1, 4.4, 1.4, 1],
75 | [5.6, 3.0, 4.5, 1.5, 1],
76 | [5.8, 2.7, 4.1, 1.0, 1],
77 | [6.2, 2.2, 4.5, 1.5, 1],
78 | [5.6, 2.5, 3.9, 1.1, 1],
79 | [5.9, 3.2, 4.8, 1.8, 1],
80 | [6.1, 2.8, 4.0, 1.3, 1],
81 | [6.3, 2.5, 4.9, 1.5, 1],
82 | [6.1, 2.8, 4.7, 1.2, 1],
83 | [6.4, 2.9, 4.3, 1.3, 1],
84 | [6.6, 3.0, 4.4, 1.4, 1],
85 | [6.8, 2.8, 4.8, 1.4, 1],
86 | [6.7, 3.0, 5.0, 1.7, 1],
87 | [6.0, 2.9, 4.5, 1.5, 1],
88 | [5.7, 2.6, 3.5, 1.0, 1],
89 | [5.5, 2.4, 3.8, 1.1, 1],
90 | [5.5, 2.4, 3.7, 1.0, 1],
91 | [5.8, 2.7, 3.9, 1.2, 1],
92 | [6.0, 2.7, 5.1, 1.6, 1],
93 | [5.4, 3.0, 4.5, 1.5, 1],
94 | [6.0, 3.4, 4.5, 1.6, 1],
95 | [6.7, 3.1, 4.7, 1.5, 1],
96 | [6.3, 2.3, 4.4, 1.3, 1],
97 | [5.6, 3.0, 4.1, 1.3, 1],
98 | [5.5, 2.5, 4.0, 1.3, 1],
99 | [5.5, 2.6, 4.4, 1.2, 1],
100 | [6.1, 3.0, 4.6, 1.4, 1],
101 | [5.8, 2.6, 4.0, 1.2, 1],
102 | [5.0, 2.3, 3.3, 1.0, 1],
103 | [5.6, 2.7, 4.2, 1.3, 1],
104 | [5.7, 3.0, 4.2, 1.2, 1],
105 | [5.7, 2.9, 4.2, 1.3, 1],
106 | [6.2, 2.9, 4.3, 1.3, 1],
107 | [5.1, 2.5, 3.0, 1.1, 1],
108 | [5.7, 2.8, 4.1, 1.3, 1],
109 | [6.3, 3.3, 6.0, 2.5, 2],
110 | [5.8, 2.7, 5.1, 1.9, 2],
111 | [7.1, 3.0, 5.9, 2.1, 2],
112 | [6.3, 2.9, 5.6, 1.8, 2],
113 | [6.5, 3.0, 5.8, 2.2, 2],
114 | [7.6, 3.0, 6.6, 2.1, 2],
115 | [4.9, 2.5, 4.5, 1.7, 2],
116 | [7.3, 2.9, 6.3, 1.8, 2],
117 | [6.7, 2.5, 5.8, 1.8, 2],
118 | [7.2, 3.6, 6.1, 2.5, 2],
119 | [6.5, 3.2, 5.1, 2.0, 2],
120 | [6.4, 2.7, 5.3, 1.9, 2],
121 | [6.8, 3.0, 5.5, 2.1, 2],
122 | [5.7, 2.5, 5.0, 2.0, 2],
123 | [5.8, 2.8, 5.1, 2.4, 2],
124 | [6.4, 3.2, 5.3, 2.3, 2],
125 | [6.5, 3.0, 5.5, 1.8, 2],
126 | [7.7, 3.8, 6.7, 2.2, 2],
127 | [7.7, 2.6, 6.9, 2.3, 2],
128 | [6.0, 2.2, 5.0, 1.5, 2],
129 | [6.9, 3.2, 5.7, 2.3, 2],
130 | [5.6, 2.8, 4.9, 2.0, 2],
131 | [7.7, 2.8, 6.7, 2.0, 2],
132 | [6.3, 2.7, 4.9, 1.8, 2],
133 | [6.7, 3.3, 5.7, 2.1, 2],
134 | [7.2, 3.2, 6.0, 1.8, 2],
135 | [6.2, 2.8, 4.8, 1.8, 2],
136 | [6.1, 3.0, 4.9, 1.8, 2],
137 | [6.4, 2.8, 5.6, 2.1, 2],
138 | [7.2, 3.0, 5.8, 1.6, 2],
139 | [7.4, 2.8, 6.1, 1.9, 2],
140 | [7.9, 3.8, 6.4, 2.0, 2],
141 | [6.4, 2.8, 5.6, 2.2, 2],
142 | [6.3, 2.8, 5.1, 1.5, 2],
143 | [6.1, 2.6, 5.6, 1.4, 2],
144 | [7.7, 3.0, 6.1, 2.3, 2],
145 | [6.3, 3.4, 5.6, 2.4, 2],
146 | [6.4, 3.1, 5.5, 1.8, 2],
147 | [6.0, 3.0, 4.8, 1.8, 2],
148 | [6.9, 3.1, 5.4, 2.1, 2],
149 | [6.7, 3.1, 5.6, 2.4, 2],
150 | [6.9, 3.1, 5.1, 2.3, 2],
151 | [5.8, 2.7, 5.1, 1.9, 2],
152 | [6.8, 3.2, 5.9, 2.3, 2],
153 | [6.7, 3.3, 5.7, 2.5, 2],
154 | [6.7, 3.0, 5.2, 2.3, 2],
155 | [6.3, 2.5, 5.0, 1.9, 2],
156 | [6.5, 3.0, 5.2, 2.0, 2],
157 | [6.2, 3.4, 5.4, 2.3, 2],
158 | [5.9, 3.0, 5.1, 1.8, 2]
159 | ];
160 |
/**
 * Convert Iris data arrays to `tf.Tensor`s.
 *
 * The examples are shuffled together with their targets, then the last
 * `testSplit` fraction of the shuffled rows becomes the test set.
 *
 * Intermediate tensors (the full feature/label tensors) are not disposed
 * here; call this inside `tf.tidy` (as `getIrisData` does) to release them.
 *
 * @param data The Iris input feature data, an `Array` of `Array`s, each element
 *   of which is assumed to be a length-4 `Array` (for sepal length, sepal
 *   width, petal length, petal width).
 * @param targets An `Array` of numbers, with values from the set {0, 1, 2},
 *   representing the true category of the Iris flower. Must have the same
 *   array length as `data`.
 * @param testSplit Fraction of the data at the end to split as test data: a
 *   number between 0 and 1.
 * @return A length-4 `Array`, with
 *   - training data as a `tf.Tensor` of shape [numTrainExamples, 4].
 *   - training one-hot labels as a `tf.Tensor` of shape
 *     [numTrainExamples, IRIS_NUM_CLASSES].
 *   - test data as a `tf.Tensor` of shape [numTestExamples, 4].
 *   - test one-hot labels as a `tf.Tensor` of shape
 *     [numTestExamples, IRIS_NUM_CLASSES].
 * @throws {Error} If `data` and `targets` differ in length, or `data` is empty.
 */
function convertToTensors(data, targets, testSplit) {
  const numExamples = data.length;
  if (numExamples !== targets.length) {
    throw new Error('data and targets have different numbers of examples');
  }
  if (numExamples === 0) {
    // Without this guard `shuffledData[0].length` below fails with an opaque
    // TypeError.
    throw new Error('data must contain at least one example');
  }

  // Randomly shuffle `data` and `targets` together by shuffling an index list
  // in place and reading both arrays through it.
  const indices = Array.from({ length: numExamples }, (_, i) => i);
  tf.util.shuffle(indices);
  const shuffledData = indices.map((i) => data[i]);
  const shuffledTargets = indices.map((i) => targets[i]);

  // Split the data into a training set and a test set, based on `testSplit`.
  const numTestExamples = Math.round(numExamples * testSplit);
  const numTrainExamples = numExamples - numTestExamples;

  const xDims = shuffledData[0].length;

  // Create a 2D `tf.Tensor` to hold the feature data.
  const xs = tf.tensor2d(shuffledData, [numExamples, xDims]);

  // Create a 1D `tf.Tensor` to hold the labels, and convert the number label
  // from the set {0, 1, 2} into one-hot encoding (e.g., 0 --> [1, 0, 0]).
  const ys = tf.oneHot(tf.tensor1d(shuffledTargets).toInt(), IRIS_NUM_CLASSES);

  // Rows [0, numTrainExamples) are training data and the remaining rows are
  // test data. Labels must be sliced at the same row offsets as the features
  // so that every label stays paired with its example.
  const xTrain = xs.slice([0, 0], [numTrainExamples, xDims]);
  const xTest = xs.slice([numTrainExamples, 0], [numTestExamples, xDims]);
  const yTrain = ys.slice([0, 0], [numTrainExamples, IRIS_NUM_CLASSES]);
  const yTest = ys.slice(
    [numTrainExamples, 0], [numTestExamples, IRIS_NUM_CLASSES]);
  return [xTrain, yTrain, xTest, yTest];
}
218 |
/**
 * Obtains the Iris data as tensors, split into training and test sets.
 *
 * Every class is split separately (via `convertToTensors`) and the per-class
 * pieces are then concatenated. This keeps the class proportions of the
 * training and test sets the same, up to rounding.
 *
 * @param testSplit Fraction of each class's examples to hold out as test
 *   data: a number between 0 and 1.
 * @return A length-4 `Array` of `tf.Tensor`s, `[xTrain, yTrain, xTest, yTest]`:
 *   - feature tensors (x) with one row per example.
 *   - one-hot label tensors (y), as produced by `convertToTensors`.
 *   Intermediate tensors created along the way are released by `tf.tidy`.
 */
export function getIrisData(testSplit) {
  return tf.tidy(() => {
    // Bucket feature rows and labels by class; the label is the last column.
    const featuresByClass = IRIS_CLASSES.map(() => []);
    const labelsByClass = IRIS_CLASSES.map(() => []);
    IRIS_DATA.forEach((row) => {
      const label = row[row.length - 1];
      featuresByClass[label].push(row.slice(0, row.length - 1));
      labelsByClass[label].push(label);
    });

    // Split each class on its own.
    const perClassSplits = featuresByClass.map((features, classIndex) =>
      convertToTensors(features, labelsByClass[classIndex], testSplit));

    // Stack the matching piece of every class along the example axis (0).
    const stack = (part) =>
      tf.concat(perClassSplits.map((split) => split[part]), 0);
    return [stack(0), stack(1), stack(2), stack(3)];
  });
}
270 |
--------------------------------------------------------------------------------
/app/iris/iris.js:
--------------------------------------------------------------------------------
1 | import * as tf from '@tensorflow/tfjs';
2 | import ModelView from '../../src';
3 |
4 | import {
5 | IRIS_DATA,
6 | getIrisData
7 | } from './data';
8 |
/**
 * Builds, compiles and trains a small dense classifier on the Iris data.
 *
 * A ModelView visualisation is attached before training starts. After
 * training, a random Iris example is run through the model once per second
 * so the view keeps updating.
 *
 * @param xTrain Training features; a 2-D `tf.Tensor` whose second dimension
 *   is used as the model's input size.
 * @param yTrain Training one-hot labels.
 * @param xTest Validation features.
 * @param yTest Validation one-hot labels.
 * @returns {Promise<tf.Sequential>} The trained model.
 */
async function trainModel(xTrain, yTrain, xTest, yTest) {

  // Define the topology of the model: two dense layers.
  const model = tf.sequential();
  model.add(tf.layers.dense({
    units: 10,
    activation: 'sigmoid',
    inputShape: [xTrain.shape[1]]
  }));

  model.add(tf.layers.dense({
    units: 3,
    activation: 'softmax'
  }));

  model.summary();

  const optimizer = tf.train.adam(0.02);
  model.compile({
    optimizer,
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy']
  });

  new ModelView(model, {
    printStats: true,
    radius: 25,
    renderLinks: true,
    layer: {
      'dense_Dense1_input': {
        domain: [0, 7]
      },
      'dense_Dense2/dense_Dense2': {
        nodePadding: 30
      }
    }
  });

  await model.fit(xTrain, yTrain, {
    epochs: 100,
    validationData: [xTest, yTest]
  });

  setInterval(() => {
    // Every tick creates an input tensor and a prediction tensor; tf.tidy
    // disposes both when the callback returns, otherwise memory would grow
    // for as long as the page is open.
    tf.tidy(() => {
      const example = IRIS_DATA[Math.floor(Math.random() * IRIS_DATA.length)];
      // The first four columns are the features; the last is the class label.
      model.predict(tf.tensor([example.slice(0, 4)]));
    });
  }, 1000);

  return model;
}
58 |
const [xTrain, yTrain, xTest, yTest] = getIrisData(0.15);

/**
 * Entry point for the Iris demo: starts training on the data split above.
 *
 * Training runs in the background; the returned promise does not wait for it.
 */
export default async () => {
  // The training promise is not awaited, so handle its failure explicitly
  // instead of leaving an unhandled rejection.
  trainModel(xTrain, yTrain, xTest, yTest).catch((err) => {
    console.error('Iris model training failed:', err);
  });
};
64 |
--------------------------------------------------------------------------------
/app/mnist/data.js:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import * as tf from '@tensorflow/tfjs';
19 |
// Size of one MNIST digit image, in pixels.
export const IMAGE_H = 28;
export const IMAGE_W = 28;
const IMAGE_SIZE = IMAGE_W * IMAGE_H;
// Number of label classes (width of each one-hot label).
const NUM_CLASSES = 10;

// Total examples in the sprite, and how many of them go to the training set;
// the remainder forms the test set.
const NUM_DATASET_ELEMENTS = 65000;
const NUM_TRAIN_ELEMENTS = 55000;
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

// Remote locations of the sprited images and the label bytes.
const MNIST_DATA_ROOT =
    'https://storage.googleapis.com/learnjs-data/model-builder/';
const MNIST_IMAGES_SPRITE_PATH = `${MNIST_DATA_ROOT}mnist_images.png`;
const MNIST_LABELS_PATH = `${MNIST_DATA_ROOT}mnist_labels_uint8`;
33 |
/**
 * A class that fetches the sprited MNIST dataset and provides the data as
 * `tf.Tensor`s.
 *
 * Await `load()` before calling `getTrainData()` or `getTestData()`.
 */
export class MnistData {
  /**
   * Downloads the MNIST image sprite and the label file. It decodes the sprite
   * into grayscale floats in [0, 1] and splits images and labels into training
   * and test portions.
   *
   * @returns {Promise<void>} Resolves once `trainImages`, `testImages`,
   *   `trainLabels` and `testLabels` are populated.
   * @throws {Error} If the sprite fails to load or decode, or the labels
   *   request returns a non-OK HTTP status.
   */
  async load() {
    // Make a request for the MNIST sprited image.
    const img = new Image();
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');
    const imgRequest = new Promise((resolve, reject) => {
      img.crossOrigin = '';
      img.onload = () => {
        // An exception thrown inside an event handler does not reject this
        // promise by itself (e.g. a SecurityError from getImageData on a
        // tainted canvas), so forward it explicitly.
        try {
          img.width = img.naturalWidth;
          img.height = img.naturalHeight;

          // 4 bytes per float32 pixel value.
          const datasetBytesBuffer =
              new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

          // Decode the sprite in strips of `chunkSize` rows. Each strip's view
          // holds IMAGE_SIZE * chunkSize values, i.e. the sprite is assumed to
          // be IMAGE_SIZE pixels wide (one image per row).
          const chunkSize = 5000;
          canvas.width = img.width;
          canvas.height = chunkSize;

          for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
            const datasetBytesView = new Float32Array(
                datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
                IMAGE_SIZE * chunkSize);
            ctx.drawImage(
                img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
                chunkSize);

            const imageData =
                ctx.getImageData(0, 0, canvas.width, canvas.height);

            for (let j = 0; j < imageData.data.length / 4; j++) {
              // All channels hold an equal value since the image is
              // grayscale, so just read the red channel.
              datasetBytesView[j] = imageData.data[j * 4] / 255;
            }
          }
          this.datasetImages = new Float32Array(datasetBytesBuffer);

          resolve();
        } catch (err) {
          reject(err);
        }
      };
      // Without an error handler the promise would never settle when the
      // sprite cannot be loaded, and `load()` would hang forever.
      img.onerror = () => {
        reject(new Error(
            `Failed to load MNIST image sprite: ${MNIST_IMAGES_SPRITE_PATH}`));
      };
      img.src = MNIST_IMAGES_SPRITE_PATH;
    });

    const labelsRequest = fetch(MNIST_LABELS_PATH);
    const [, labelsResponse] = await Promise.all([imgRequest, labelsRequest]);
    if (!labelsResponse.ok) {
      throw new Error(
          `Failed to fetch MNIST labels: HTTP ${labelsResponse.status}`);
    }

    // Labels are stored as NUM_CLASSES bytes (one-hot) per example.
    this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

    // Slice the images and labels into train and test sets.
    this.trainImages =
        this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.trainLabels =
        this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    this.testLabels =
        this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
  }

  /**
   * Get all training data as a data tensor and a labels tensor.
   *
   * @returns
   *   xs: The data tensor, of shape `[numTrainExamples, 28, 28, 1]`.
   *   labels: The one-hot encoded labels tensor, of shape
   *     `[numTrainExamples, 10]`.
   */
  getTrainData() {
    const xs = tf.tensor4d(
        this.trainImages,
        [this.trainImages.length / IMAGE_SIZE, IMAGE_H, IMAGE_W, 1]);
    const labels = tf.tensor2d(
        this.trainLabels, [this.trainLabels.length / NUM_CLASSES, NUM_CLASSES]);
    return {xs, labels};
  }

  /**
   * Get all test data as a data tensor and a labels tensor.
   *
   * @param {number} numExamples Optional number of examples to get. If not
   *   provided, all test examples will be returned.
   * @returns
   *   xs: The data tensor, of shape `[numTestExamples, 28, 28, 1]`.
   *   labels: The one-hot encoded labels tensor, of shape
   *     `[numTestExamples, 10]`.
   */
  getTestData(numExamples) {
    let xs = tf.tensor4d(
        this.testImages,
        [this.testImages.length / IMAGE_SIZE, IMAGE_H, IMAGE_W, 1]);
    let labels = tf.tensor2d(
        this.testLabels, [this.testLabels.length / NUM_CLASSES, NUM_CLASSES]);

    if (numExamples != null) {
      // `slice` returns new tensors, so the full-size originals must be
      // disposed explicitly or their memory is never released.
      const allXs = xs;
      const allLabels = labels;
      try {
        xs = allXs.slice([0, 0, 0, 0], [numExamples, IMAGE_H, IMAGE_W, 1]);
        labels = allLabels.slice([0, 0], [numExamples, NUM_CLASSES]);
      } finally {
        allXs.dispose();
        allLabels.dispose();
      }
    }
    return {xs, labels};
  }
}
140 |
--------------------------------------------------------------------------------
/app/models/mnist-dense/group1-shard1of1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cstefanache/tfjs-model-view/860fed91c837fe28b3119e18c1f38d5298a6eddc/app/models/mnist-dense/group1-shard1of1
--------------------------------------------------------------------------------
/app/models/mnist-dense/group2-shard1of1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cstefanache/tfjs-model-view/860fed91c837fe28b3119e18c1f38d5298a6eddc/app/models/mnist-dense/group2-shard1of1
--------------------------------------------------------------------------------
/app/models/mnist-dense/group3-shard1of1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cstefanache/tfjs-model-view/860fed91c837fe28b3119e18c1f38d5298a6eddc/app/models/mnist-dense/group3-shard1of1
--------------------------------------------------------------------------------
/app/models/mnist-dense/model.json:
--------------------------------------------------------------------------------
1 | {"modelTopology": {"keras_version": "2.1.6", "backend": "tensorflow", "model_config": {"class_name": "Sequential", "config": [{"class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "batch_input_shape": [null, 784], "dtype": "float32", "units": 512, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "Dropout", "config": {"name": "dropout_1", "trainable": true, "rate": 0.2, "noise_shape": null, "seed": null}}, {"class_name": "Dense", "config": {"name": "dense_2", "trainable": true, "units": 512, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "Dropout", "config": {"name": "dropout_2", "trainable": true, "rate": 0.2, "noise_shape": null, "seed": null}}, {"class_name": "Dense", "config": {"name": "dense_3", "trainable": true, "units": 10, "activation": "softmax", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}]}, "training_config": {"optimizer_config": {"class_name": "RMSprop", "config": {"lr": 0.0010000000474974513, "rho": 0.8999999761581421, "decay": 0.0, 
"epsilon": 1e-07}}, "loss": "categorical_crossentropy", "metrics": ["accuracy"], "sample_weight_mode": null, "loss_weights": null}}, "weightsManifest": [{"paths": ["group1-shard1of1"], "weights": [{"name": "dense_1/kernel", "shape": [784, 512], "dtype": "float32"}, {"name": "dense_1/bias", "shape": [512], "dtype": "float32"}]}, {"paths": ["group2-shard1of1"], "weights": [{"name": "dense_2/kernel", "shape": [512, 512], "dtype": "float32"}, {"name": "dense_2/bias", "shape": [512], "dtype": "float32"}]}, {"paths": ["group3-shard1of1"], "weights": [{"name": "dense_3/kernel", "shape": [512, 10], "dtype": "float32"}, {"name": "dense_3/bias", "shape": [10], "dtype": "float32"}]}]}
--------------------------------------------------------------------------------
/app/tiny/tiny.js:
--------------------------------------------------------------------------------
1 | import * as tf from '@tensorflow/tfjs';
2 | import ModelView from '../../src';
3 |
/**
 * Builds a tiny four-layer dense model and attaches a ModelView visualisation
 * to it. Shortly afterwards it runs one prediction on the input [[1]], so the
 * view has values to draw.
 *
 * @returns {Promise<tf.Sequential>} The (untrained) model.
 */
export default async () => {

  // Define the topology of the model: four dense layers. Only the first layer
  // needs `inputShape`; later layers take their input shape from the
  // previous layer's output.
  const model = tf.sequential();
  model.add(tf.layers.dense({
    units: 2,
    activation: 'tanh',
    inputShape: [1]
  }));

  model.add(tf.layers.dense({
    units: 2,
    activation: 'relu'
  }));

  model.add(tf.layers.dense({
    units: 3,
    activation: 'softplus'
  }));

  model.add(tf.layers.dense({
    units: 1,
    activation: 'softsign'
  }));

  model.summary();

  new ModelView(model, {
    radius: 25,
    renderLinks: true
  });

  setTimeout(() => {
    // The values are copied out synchronously with dataSync(); tf.tidy then
    // disposes the input and prediction tensors so they do not leak.
    tf.tidy(() => {
      const data = model.predict(tf.tensor([[1]])).dataSync();
      console.log(data);
    });
  }, 10);
  return model;
};
44 |
--------------------------------------------------------------------------------
/lib/tfjs-model-view.js:
--------------------------------------------------------------------------------
1 | (function webpackUniversalModuleDefinition(root, factory) {
2 | if(typeof exports === 'object' && typeof module === 'object')
3 | module.exports = factory();
4 | else if(typeof define === 'function' && define.amd)
5 | define("tfjs-model-view", [], factory);
6 | else if(typeof exports === 'object')
7 | exports["tfjs-model-view"] = factory();
8 | else
9 | root["tfjs-model-view"] = factory();
10 | })(typeof self !== 'undefined' ? self : this, function() {
11 | return /******/ (function(modules) { // webpackBootstrap
12 | /******/ // The module cache
13 | /******/ var installedModules = {};
14 | /******/
15 | /******/ // The require function
16 | /******/ function __webpack_require__(moduleId) {
17 | /******/
18 | /******/ // Check if module is in cache
19 | /******/ if(installedModules[moduleId]) {
20 | /******/ return installedModules[moduleId].exports;
21 | /******/ }
22 | /******/ // Create a new module (and put it into the cache)
23 | /******/ var module = installedModules[moduleId] = {
24 | /******/ i: moduleId,
25 | /******/ l: false,
26 | /******/ exports: {}
27 | /******/ };
28 | /******/
29 | /******/ // Execute the module function
30 | /******/ modules[moduleId].call(module.exports, module, module.exports, __webpack_require__);
31 | /******/
32 | /******/ // Flag the module as loaded
33 | /******/ module.l = true;
34 | /******/
35 | /******/ // Return the exports of the module
36 | /******/ return module.exports;
37 | /******/ }
38 | /******/
39 | /******/
40 | /******/ // expose the modules object (__webpack_modules__)
41 | /******/ __webpack_require__.m = modules;
42 | /******/
43 | /******/ // expose the module cache
44 | /******/ __webpack_require__.c = installedModules;
45 | /******/
46 | /******/ // define getter function for harmony exports
47 | /******/ __webpack_require__.d = function(exports, name, getter) {
48 | /******/ if(!__webpack_require__.o(exports, name)) {
49 | /******/ Object.defineProperty(exports, name, { enumerable: true, get: getter });
50 | /******/ }
51 | /******/ };
52 | /******/
53 | /******/ // define __esModule on exports
54 | /******/ __webpack_require__.r = function(exports) {
55 | /******/ if(typeof Symbol !== 'undefined' && Symbol.toStringTag) {
56 | /******/ Object.defineProperty(exports, Symbol.toStringTag, { value: 'Module' });
57 | /******/ }
58 | /******/ Object.defineProperty(exports, '__esModule', { value: true });
59 | /******/ };
60 | /******/
61 | /******/ // create a fake namespace object
62 | /******/ // mode & 1: value is a module id, require it
63 | /******/ // mode & 2: merge all properties of value into the ns
64 | /******/ // mode & 4: return value when already ns object
65 | /******/ // mode & 8|1: behave like require
66 | /******/ __webpack_require__.t = function(value, mode) {
67 | /******/ if(mode & 1) value = __webpack_require__(value);
68 | /******/ if(mode & 8) return value;
69 | /******/ if((mode & 4) && typeof value === 'object' && value && value.__esModule) return value;
70 | /******/ var ns = Object.create(null);
71 | /******/ __webpack_require__.r(ns);
72 | /******/ Object.defineProperty(ns, 'default', { enumerable: true, value: value });
73 | /******/ if(mode & 2 && typeof value != 'string') for(var key in value) __webpack_require__.d(ns, key, function(key) { return value[key]; }.bind(null, key));
74 | /******/ return ns;
75 | /******/ };
76 | /******/
77 | /******/ // getDefaultExport function for compatibility with non-harmony modules
78 | /******/ __webpack_require__.n = function(module) {
79 | /******/ var getter = module && module.__esModule ?
80 | /******/ function getDefault() { return module['default']; } :
81 | /******/ function getModuleExports() { return module; };
82 | /******/ __webpack_require__.d(getter, 'a', getter);
83 | /******/ return getter;
84 | /******/ };
85 | /******/
86 | /******/ // Object.prototype.hasOwnProperty.call
87 | /******/ __webpack_require__.o = function(object, property) { return Object.prototype.hasOwnProperty.call(object, property); };
88 | /******/
89 | /******/ // __webpack_public_path__
90 | /******/ __webpack_require__.p = "";
91 | /******/
92 | /******/
93 | /******/ // Load entry module and return exports
94 | /******/ return __webpack_require__(__webpack_require__.s = "./src/index.js");
95 | /******/ })
96 | /************************************************************************/
97 | /******/ ({
98 |
99 | /***/ "./src/default.config.js":
100 | /*!*******************************!*\
101 | !*** ./src/default.config.js ***!
102 | \*******************************/
103 | /*! no static exports found */
104 | /***/ (function(module, exports, __webpack_require__) {
105 |
106 | "use strict";
107 |
108 |
109 | Object.defineProperty(exports, "__esModule", {
110 | value: true
111 | });
112 | exports.default = {
113 | renderer: 'canvas',
114 |
115 | radius: 6,
116 | nodePadding: 2,
117 | layerPadding: 20,
118 | groupPadding: 1,
119 |
120 | xPadding: 10,
121 | yPadding: 10,
122 |
123 | renderLinks: false,
124 | plotActivations: false,
125 | nodeStroke: true,
126 |
127 | onRendererInitialized: function onRendererInitialized(renderer) {
128 | document.body.appendChild(renderer.canvas);
129 | }
130 | };
131 | module.exports = exports['default'];
132 |
133 | /***/ }),
134 |
135 | /***/ "./src/index.js":
136 | /*!**********************!*\
137 | !*** ./src/index.js ***!
138 | \**********************/
139 | /*! no static exports found */
140 | /***/ (function(module, exports, __webpack_require__) {
141 |
142 | "use strict";
143 |
144 |
145 | Object.defineProperty(exports, "__esModule", {
146 | value: true
147 | });
148 |
149 | var _modelParser = __webpack_require__(/*! ./model-parser */ "./src/model-parser.js");
150 |
151 | var _modelParser2 = _interopRequireDefault(_modelParser);
152 |
153 | var _canvas = __webpack_require__(/*! ./renderers/canvas.renderer */ "./src/renderers/canvas.renderer.js");
154 |
155 | var _canvas2 = _interopRequireDefault(_canvas);
156 |
157 | var _default = __webpack_require__(/*! ./default.config */ "./src/default.config.js");
158 |
159 | var _default2 = _interopRequireDefault(_default);
160 |
161 | function _interopRequireDefault(obj) { return obj && obj.__esModule ? obj : { default: obj }; }
162 |
163 | function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } }
164 |
165 | var ModelView = function ModelView(model, customConfig) {
166 | _classCallCheck(this, ModelView);
167 |
168 | var config = Object.assign({}, _default2.default, customConfig);
169 | var onRendererInitialized = config.onRendererInitialized;
170 |
171 | var renderer = void 0;
172 |
173 | config.predictCallback = function (input) {
174 | if (renderer) {
175 | renderer.update(model, input);
176 | renderer.render();
177 | }
178 | };
179 |
180 | config.hookCallback = function (layer) {
181 | if (renderer) {
182 | renderer.updateValues(layer);
183 | renderer.render();
184 | }
185 | };
186 |
187 | (0, _modelParser2.default)(model, config).then(function (res) {
188 | renderer = new _canvas2.default(config, res);
189 | if (onRendererInitialized) {
190 | onRendererInitialized(renderer);
191 | }
192 | });
193 | };
194 |
195 | exports.default = ModelView;
196 | module.exports = exports['default'];
197 |
198 | /***/ }),
199 |
200 | /***/ "./src/model-parser.js":
201 | /*!*****************************!*\
202 | !*** ./src/model-parser.js ***!
203 | \*****************************/
204 | /*! no static exports found */
205 | /***/ (function(module, exports, __webpack_require__) {
206 |
207 | "use strict";
208 |
209 |
210 | Object.defineProperty(exports, "__esModule", {
211 | value: true
212 | });
213 |
214 | var _extends = Object.assign || function (target) { for (var i = 1; i < arguments.length; i++) { var source = arguments[i]; for (var key in source) { if (Object.prototype.hasOwnProperty.call(source, key)) { target[key] = source[key]; } } } return target; };
215 |
216 | var parseModel = function () {
217 | var _ref = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3(model, options) {
218 | var parsed, parserConfig, parseLayer, predict, layerArr;
219 | return regeneratorRuntime.wrap(function _callee3$(_context3) {
220 | while (1) {
221 | switch (_context3.prev = _context3.next) {
222 | case 0:
223 | parseLayer = function parseLayer(layer, nextColumn) {
224 | var _this = this;
225 |
226 | var name = layer.name,
227 | input = layer.input,
228 | inputs = layer.inputs,
229 | shape = layer.shape,
230 | sourceLayer = layer.sourceLayer;
231 |
232 | var _ref2 = sourceLayer || {},
233 | getWeights = _ref2.getWeights,
234 | setCallHook = _ref2.setCallHook,
235 | activation = _ref2.activation;
236 |
237 | var currentLayer = {
238 | previousColumn: [],
239 | name: name,
240 | shape: shape,
241 | weights: {},
242 | getWeights: noop,
243 | mapPosition: Object.keys(parsed.layerMap).length
244 | };
245 |
246 | parsed.layerMap[name] = currentLayer;
247 | parsed.layerArr.unshift(currentLayer);
248 |
249 | if (activation) {
250 | var className = activation.getClassName();
251 | currentLayer.activation = {
252 | name: className
253 | };
254 | }
255 |
256 | if (setCallHook) {
257 | sourceLayer.setCallHook(function () {
258 | var _ref3 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(layerInput) {
259 | var i;
260 | return regeneratorRuntime.wrap(function _callee$(_context) {
261 | while (1) {
262 | switch (_context.prev = _context.next) {
263 | case 0:
264 | currentLayer.getWeights();
265 | currentLayer.activations = [];
266 | for (i = 0; i < layerInput.length; i++) {
267 | currentLayer.activations.push(layerInput[i].dataSync());
268 | }
269 | parserConfig.hookCallback(currentLayer);
270 |
271 | case 4:
272 | case 'end':
273 | return _context.stop();
274 | }
275 | }
276 | }, _callee, _this);
277 | }));
278 |
279 | return function (_x3) {
280 | return _ref3.apply(this, arguments);
281 | };
282 | }());
283 | }
284 |
285 | if (getWeights) {
286 |
287 | currentLayer.getWeights = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2() {
288 | var weights, i, weight, rankType, weightName;
289 | return regeneratorRuntime.wrap(function _callee2$(_context2) {
290 | while (1) {
291 | switch (_context2.prev = _context2.next) {
292 | case 0:
293 | _context2.next = 2;
294 | return sourceLayer.getWeights();
295 |
296 | case 2:
297 | weights = _context2.sent;
298 | i = 0;
299 |
300 | case 4:
301 | if (!(i < weights.length)) {
302 | _context2.next = 16;
303 | break;
304 | }
305 |
306 | weight = weights[i];
307 | rankType = weight.rankType, weightName = weight.name;
308 |
309 | currentLayer.hasWeights = true;
310 | _context2.t0 = weightName;
311 | _context2.next = 11;
312 | return weights[i].dataSync();
313 |
314 | case 11:
315 | _context2.t1 = _context2.sent;
316 | currentLayer.weights[rankType] = {
317 | name: _context2.t0,
318 | values: _context2.t1
319 | };
320 |
321 | case 13:
322 | i++;
323 | _context2.next = 4;
324 | break;
325 |
326 | case 16:
327 | case 'end':
328 | return _context2.stop();
329 | }
330 | }
331 | }, _callee2, _this);
332 | }));
333 |
334 | currentLayer.getWeights();
335 | }
336 |
337 | if (inputs) {
338 | inputs.forEach(function (inp) {
339 | parseLayer(inp, currentLayer.previousColumn);
340 | });
341 | } else {
342 | parseLayer(input, currentLayer.previousColumn);
343 | }
344 |
345 | if (nextColumn) {
346 | nextColumn.push(currentLayer);
347 | }
348 |
349 | return currentLayer;
350 | };
351 |
352 | parsed = {
353 | layerMap: {},
354 | layerArr: []
355 | };
356 | parserConfig = _extends({
357 | predictCallback: noop,
358 | hookCallback: noop
359 | }, options);
360 | predict = model.predict;
361 |
362 |
363 | model.predict = function () {
364 | for (var _len = arguments.length, args = Array(_len), _key = 0; _key < _len; _key++) {
365 | args[_key] = arguments[_key];
366 | }
367 |
368 | var result = predict.apply(model, args);
369 | model.outputData = result.dataSync();
370 | parserConfig.predictCallback(args);
371 | return result;
372 | };
373 |
374 | _context3.next = 7;
375 | return parseLayer(model.layers[model.layers.length - 1].output);
376 |
377 | case 7:
378 | parsed.model = _context3.sent;
379 |
380 |
381 | if (options.printStats) {
382 | layerArr = parsed.layerArr;
383 |
384 | console.log(new Array(10).join('-'));
385 | layerArr.forEach(function (layer) {
386 | console.log('Layer: ' + layer.name);
387 | });
388 | }
389 |
390 | return _context3.abrupt('return', parsed);
391 |
392 | case 10:
393 | case 'end':
394 | return _context3.stop();
395 | }
396 | }
397 | }, _callee3, this);
398 | }));
399 |
400 | return function parseModel(_x, _x2) {
401 | return _ref.apply(this, arguments);
402 | };
403 | }();
404 |
405 | function _asyncToGenerator(fn) { return function () { var gen = fn.apply(this, arguments); return new Promise(function (resolve, reject) { function step(key, arg) { try { var info = gen[key](arg); var value = info.value; } catch (error) { reject(error); return; } if (info.done) { resolve(value); } else { return Promise.resolve(value).then(function (value) { step("next", value); }, function (err) { step("throw", err); }); } } return step("next"); }); }; }
406 |
407 | function noop() {}
408 |
409 | exports.default = parseModel;
410 | module.exports = exports['default'];
411 |
412 | /***/ }),
413 |
414 | /***/ "./src/renderers/abstract.renderer.js":
415 | /*!********************************************!*\
416 | !*** ./src/renderers/abstract.renderer.js ***!
417 | \********************************************/
418 | /*! no static exports found */
419 | /***/ (function(module, exports, __webpack_require__) {
420 |
421 | "use strict";
422 |
423 |
// Babel interop marker: flags this CommonJS module as transpiled from an ES module.
Object.defineProperty(exports, "__esModule", {
  value: true
});

// Babel helper for array destructuring: returns `arr` itself when it is an array;
// otherwise collects up to `i` items from its iterator (closing the iterator when
// stopping early) and throws a TypeError for non-iterable values.
var _slicedToArray = function () { function sliceIterator(arr, i) { var _arr = []; var _n = true; var _d = false; var _e = undefined; try { for (var _i = arr[Symbol.iterator](), _s; !(_n = (_s = _i.next()).done); _n = true) { _arr.push(_s.value); if (i && _arr.length === i) break; } } catch (err) { _d = true; _e = err; } finally { try { if (!_n && _i["return"]) _i["return"](); } finally { if (_d) throw _e; } } return _arr; } return function (arr, i) { if (Array.isArray(arr)) { return arr; } else if (Symbol.iterator in Object(arr)) { return sliceIterator(arr, i); } else { throw new TypeError("Invalid attempt to destructure non-iterable instance"); } }; }();

// Babel helper behind `class` syntax: defines the given prototype / static members
// as configurable properties, non-enumerable unless the descriptor says otherwise,
// and writable when the descriptor carries a `value`.
var _createClass = function () { function defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if ("value" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } } return function (Constructor, protoProps, staticProps) { if (protoProps) defineProperties(Constructor.prototype, protoProps); if (staticProps) defineProperties(Constructor, staticProps); return Constructor; }; }();

// Babel helper: throws a TypeError when a transpiled class constructor is called without `new`.
function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } }
433 |
// Default [r, g, b] palette indexed by a layer's position in the layer list.
// A layer's own `color` option takes precedence; positions past the end of
// this list are drawn in black.
var colors = [
  [6, 57, 143],
  [0, 107, 92],
  [216, 139, 0],
  [180, 0, 85],
  [106, 2, 143],
  [216, 109, 0],
  [2, 105, 134],
  [0, 142, 103],
  [201, 0, 39],
  [139, 11, 215],
  [171, 141, 0]
];
435 |
/**
 * Base renderer: lays out every layer of a parsed model as circular nodes and
 * keeps each node's `value` in sync with model activity. Subclasses (e.g. the
 * canvas renderer) draw from `this.layers`, `this.width` and `this.height`.
 *
 * Layout: layers are placed left to right in `initData.layerArr` order. A layer
 * whose shape without its first (presumably batch) dimension is
 * [rows, cols, groups] (missing dimensions default to 1) becomes `cols` columns
 * of `rows` nodes, with the `groups` stacked vertically and separated by
 * `groupPadding`. Each layer is finally centred vertically.
 *
 * @param {Object} config - renderer options. Reads `xPadding`, `yPadding`,
 *   `xOffset` and `layer`, an optional map of per-layer overrides keyed by layer
 *   name. Per-layer options (`radius`, `nodePadding`, `layerPadding`,
 *   `groupPadding`, `domain` (default [0, 1]), `renderLinks`, `renderNode`,
 *   `nodeStroke`, `reshape`, `color`) come from `config` merged with that
 *   layer's override.
 * @param {Object} initData - parser output; `initData.layerArr` is the ordered
 *   layer list, each entry having `name`, `shape` and `previousColumn` (the
 *   layers feeding it). The last entry is treated as the output layer. Layer
 *   objects are mutated in place with their computed layout.
 * @throws {Error} if a layer's `reshape` does not preserve its node count.
 */
var AbstractRenderer = function () {
  function AbstractRenderer(config, initData) {
    _classCallCheck(this, AbstractRenderer);

    var xPadding = config.xPadding,
        yPadding = config.yPadding,
        xOffset = config.xOffset,
        _config$layer = config.layer,
        layer = _config$layer === undefined ? {} : _config$layer;
    var layerArr = initData.layerArr;

    // Minimum height is twice the vertical padding (1 when unset or 0); it grows
    // below to fit the tallest layer.
    var maxHeight = (yPadding || 1) * 2;
    // Running x cursor: each layer is placed to the right of the previous one.
    var cx = (xPadding || 0) + (xOffset || 0);

    // Store on every layer its column index counted back from the output layer
    // (output = 0) along `previousColumn` links; when a layer is reachable via
    // several paths, the last path visited wins.
    function processColumn(lyr) {
      var col = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 0;

      lyr.column = col;
      lyr.previousColumn.forEach(function (l) {
        processColumn(l, col + 1);
      });
    }

    processColumn(layerArr[layerArr.length - 1]);

    layerArr.forEach(function (l, lindex) {
      var name = l.name,
          shape = l.shape,
          previousColumn = l.previousColumn;

      // Per-layer overrides (`config.layer[name]`) win over the global options.
      var customConfig = layer[name] || {};

      var layerConfig = Object.assign({}, config, customConfig);
      var radius = layerConfig.radius,
          nodePadding = layerConfig.nodePadding,
          layerPadding = layerConfig.layerPadding,
          groupPadding = layerConfig.groupPadding,
          _layerConfig$domain = layerConfig.domain,
          domain = _layerConfig$domain === undefined ? [0, 1] : _layerConfig$domain,
          renderLinks = layerConfig.renderLinks,
          renderNode = layerConfig.renderNode,
          nodeStroke = layerConfig.nodeStroke,
          reshape = layerConfig.reshape;

      // Explicit colour, else the palette entry for this position, else black.
      var color = layerConfig.color || (lindex < colors.length ? colors[lindex] : [0, 0, 0]);

      var _Object$assign = Object.assign([1, 1, 1], shape.slice(1)),
          _Object$assign2 = _slicedToArray(_Object$assign, 3),
          rows = _Object$assign2[0],
          cols = _Object$assign2[1],
          groups = _Object$assign2[2];

      var totalNodes = rows * cols * groups;

      // An optional reshape re-arranges the same nodes into a different grid.
      if (reshape) {
        var _Object$assign3 = Object.assign([1, 1, 1], reshape),
            _Object$assign4 = _slicedToArray(_Object$assign3, 3),
            nr = _Object$assign4[0],
            nc = _Object$assign4[1],
            ng = _Object$assign4[2];

        if (nr * nc * ng !== totalNodes) {
          throw new Error("Unable to reshape from [" + rows + ", " + cols + ", " + groups + "] to [" + nr + ", " + nc + ", " + ng + "]");
        }

        rows = nr;
        cols = nc;
        groups = ng;
      }

      cx += layerPadding;

      var step = radius + nodePadding;
      var width = layerPadding + cols * step;
      var nodes = [];
      var height = 0;

      // Nodes are emitted row-major, then column, then group; the last node
      // pushed is the lowest one, so `height` ends at the maximum y.
      for (var row = 0; row < rows; row++) {
        for (var col = 0; col < cols; col++) {
          for (var group = 0; group < groups; group++) {
            var y = radius + row * step + group * rows * step + group * groupPadding;
            nodes.push({
              x: cx + col * step,
              y: y,
              radius: radius
            });
            height = y;
          }
        }
      }

      height += groupPadding + radius;
      maxHeight = Math.max(maxHeight, height);

      Object.assign(l, {
        name: name,
        x: cx,
        layerWidth: width,
        layerHeight: height,
        radius: radius,
        nodes: nodes,
        domain: domain,
        renderLinks: renderLinks,
        renderNode: renderNode,
        nodeStroke: nodeStroke,
        color: color,
        previousLayers: previousColumn.map(function (lyr) {
          return lyr.name;
        })
      });

      cx += width;
    });

    // The last entry of the layer list is the model output.
    this.outputLayer = layerArr[layerArr.length - 1];

    cx += xPadding || 0;

    // Centre every layer vertically within the tallest one.
    layerArr.forEach(function (l) {
      var offsetY = Math.floor((maxHeight - l.layerHeight) / 2);
      l.nodes.forEach(function (nd) {
        return nd.y += offsetY;
      });
    });

    Object.assign(this, { width: cx, height: maxHeight });

    this.layers = layerArr;
    this.layersMap = layerArr.reduce(function (memo, item) {
      memo[item.name] = item;
      return memo;
    }, {});
  }

  _createClass(AbstractRenderer, [{
    key: "update",

    /**
     * Copies prediction data into the layout.
     *
     * @param {Object} model - reads `model.inputs` (objects with a `name`) and
     *   `model.outputData` (values for the output layer's nodes).
     * @param {Object|Object[]} [input] - either a single input tensor or an array
     *   with one tensor per model input; each is assumed to expose `dataSync()`
     *   (TF.js Tensor API).
     */
    value: function update(model, input) {
      var _this2 = this;

      if (input) {
        // `predict` accepts a lone tensor as well as an array of tensors.
        var inputs = Array.isArray(input) ? input : [input];
        model.inputs.forEach(function (inputLayer, index) {
          var syntheticLayer = _this2.layersMap[inputLayer.name];
          var tensor = inputs[index];
          if (syntheticLayer && tensor) {
            _this2.updateLayerValues(syntheticLayer, tensor.dataSync());
          }
        });
      }

      this.updateLayerValues(this.outputLayer, model.outputData);
    }
  }, {
    key: "updateLayerValues",

    /**
     * Assigns `data[i]` to the `value` of the layer's i-th node.
     * Does nothing when the layer or the data is missing.
     */
    value: function updateLayerValues(layer, data) {
      if (!layer || !data) {
        return;
      }
      for (var i = 0; i < layer.nodes.length; i++) {
        layer.nodes[i].value = data[i];
      }
    }
  }, {
    key: "updateValues",

    /**
     * Applies data reported for one layer: stores its `weights` on the matching
     * layout layer and writes `activations[idx]` into the idx-th layer feeding it.
     * Layers that are not part of the layout are ignored.
     */
    value: function updateValues(layer) {
      var _this3 = this;

      var syntheticLayer = this.layersMap[layer.name];
      if (!syntheticLayer) {
        return;
      }
      var activations = layer.activations || [];
      syntheticLayer.weights = layer.weights;
      syntheticLayer.previousColumn.forEach(function (col, idx) {
        _this3.updateLayerValues(col, activations[idx]);
      });
    }
  }]);

  return AbstractRenderer;
}();
609 |
// Default export; `module.exports` is replaced by the constructor itself so that
// plain CommonJS `require` returns AbstractRenderer directly.
exports.default = AbstractRenderer;
module.exports = exports["default"];
612 |
613 | /***/ }),
614 |
615 | /***/ "./src/renderers/canvas.renderer.js":
616 | /*!******************************************!*\
617 | !*** ./src/renderers/canvas.renderer.js ***!
618 | \******************************************/
619 | /*! no static exports found */
620 | /***/ (function(module, exports, __webpack_require__) {
621 |
622 | "use strict";
623 |
624 |
// Babel interop marker: flags this CommonJS module as transpiled from an ES module.
Object.defineProperty(exports, "__esModule", {
  value: true
});

// Babel helper for array destructuring: returns `arr` itself when it is an array;
// otherwise collects up to `i` items from its iterator (closing the iterator when
// stopping early) and throws a TypeError for non-iterable values.
var _slicedToArray = function () { function sliceIterator(arr, i) { var _arr = []; var _n = true; var _d = false; var _e = undefined; try { for (var _i = arr[Symbol.iterator](), _s; !(_n = (_s = _i.next()).done); _n = true) { _arr.push(_s.value); if (i && _arr.length === i) break; } } catch (err) { _d = true; _e = err; } finally { try { if (!_n && _i["return"]) _i["return"](); } finally { if (_d) throw _e; } } return _arr; } return function (arr, i) { if (Array.isArray(arr)) { return arr; } else if (Symbol.iterator in Object(arr)) { return sliceIterator(arr, i); } else { throw new TypeError("Invalid attempt to destructure non-iterable instance"); } }; }();

// Babel helper behind `class` syntax: defines the given prototype / static members
// as configurable properties, non-enumerable unless the descriptor says otherwise,
// and writable when the descriptor carries a `value`.
var _createClass = function () { function defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if ("value" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } } return function (Constructor, protoProps, staticProps) { if (protoProps) defineProperties(Constructor.prototype, protoProps); if (staticProps) defineProperties(Constructor, staticProps); return Constructor; }; }();

// Load the base class from the sibling webpack module.
var _abstract = __webpack_require__(/*! ./abstract.renderer */ "./src/renderers/abstract.renderer.js");

// That module's `module.exports` is the AbstractRenderer constructor itself (no
// `__esModule` flag), so this wraps it as `{ default: AbstractRenderer }`.
var _abstract2 = _interopRequireDefault(_abstract);

// Babel helper: returns `obj` when it is flagged as an ES module, otherwise `{ default: obj }`.
function _interopRequireDefault(obj) { return obj && obj.__esModule ? obj : { default: obj }; }

// Babel helper: throws a TypeError when a transpiled class constructor is called without `new`.
function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } }

// Babel helper for derived constructors: throws if `super()` never produced `self`;
// otherwise returns the super call's result when it is an object or function, else `self`.
function _possibleConstructorReturn(self, call) { if (!self) { throw new ReferenceError("this hasn't been initialised - super() hasn't been called"); } return call && (typeof call === "object" || typeof call === "function") ? call : self; }

// Babel helper for `class Sub extends Super`: chains Sub.prototype to Super.prototype
// (with a non-enumerable `constructor`) and makes Sub inherit Super's static members;
// throws a TypeError unless Super is a function or null.
function _inherits(subClass, superClass) { if (typeof superClass !== "function" && superClass !== null) { throw new TypeError("Super expression must either be null or a function, not " + typeof superClass); } subClass.prototype = Object.create(superClass && superClass.prototype, { constructor: { value: subClass, enumerable: false, writable: true, configurable: true } }); if (superClass) Object.setPrototypeOf ? Object.setPrototypeOf(subClass, superClass) : subClass.__proto__ = superClass; }
644 |
/**
 * Renderer that draws the AbstractRenderer layout onto a <canvas> element,
 * exposed as `renderElement`. `render()` redraws on the next animation frame.
 *
 * @param {Object} config - renderer options (see AbstractRenderer); also reads
 *   optional `onBeginRender(renderer)` / `onEndRender(renderer)` hooks called at
 *   the start and end of every frame.
 * @param {Object} initData - parser output forwarded to AbstractRenderer.
 */
var CanvasRenderer = function (_AbstractRenderer) {
  _inherits(CanvasRenderer, _AbstractRenderer);

  function CanvasRenderer(config, initData) {
    _classCallCheck(this, CanvasRenderer);

    var _this = _possibleConstructorReturn(this, (CanvasRenderer.__proto__ || Object.getPrototypeOf(CanvasRenderer)).call(this, config, initData));

    var canvas = document.createElement('canvas');

    var onBeginRender = config.onBeginRender,
        onEndRender = config.onEndRender;

    Object.assign(_this, { canvas: canvas, onBeginRender: onBeginRender, onEndRender: onEndRender });

    // Size the canvas to the layout computed by the base class.
    canvas.setAttribute('width', _this.width);
    canvas.setAttribute('height', _this.height);

    _this.renderContext = canvas.getContext('2d');
    _this.renderElement = canvas;
    return _this;
  }

  /**
   * Maps `value` from the [min, max] domain onto an opacity in [0, 1].
   * Returns null when no opacity can be derived (missing value, min === max).
   */
  function domainAlpha(value, min, max) {
    var alpha = (value - min) / (max - min);
    if (!Number.isFinite(alpha)) {
      return null;
    }
    return Math.min(1, Math.max(0, alpha));
  }

  _createClass(CanvasRenderer, [{
    key: 'render',

    /**
     * Schedules a full redraw: clears the canvas, calls `onBeginRender`, draws
     * every layer (optional links to the previous layers, then the nodes) and
     * finally calls `onEndRender`.
     */
    value: function render() {
      var _this2 = this;

      window.requestAnimationFrame(function () {
        var onBeginRender = _this2.onBeginRender,
            onEndRender = _this2.onEndRender;

        _this2.renderContext.clearRect(0, 0, _this2.width, _this2.height);

        if (onBeginRender) {
          onBeginRender(_this2);
        }

        var ctx = _this2.renderContext;

        _this2.layers.forEach(function (layer) {
          var radius = layer.radius,
              nodes = layer.nodes,
              _layer$domain = _slicedToArray(layer.domain, 2),
              min = _layer$domain[0],
              max = _layer$domain[1],
              previousColumn = layer.previousColumn,
              renderLinks = layer.renderLinks,
              renderNode = layer.renderNode,
              weights = layer.weights,
              _layer$color = _slicedToArray(layer.color, 3),
              r = _layer$color[0],
              g = _layer$color[1],
              b = _layer$color[2],
              nodeStroke = layer.nodeStroke;

          // `weights` is assigned by AbstractRenderer#updateValues, so it can still
          // be unset here. NOTE(review): entry 2 is taken as the kernel — confirm
          // against whatever fills `layer.weights`.
          var kernel = weights ? weights[2] : undefined;
          var hasWeight = Boolean(kernel && kernel.values);

          // All nodes of the feeding layers, concatenated in `previousColumn` order.
          var leftSideNodes = renderLinks ? previousColumn.reduce(function (memo, prevLayer) {
            return memo.concat(prevLayer.nodes);
          }, []) : [];

          nodes.forEach(function (node, index) {
            var nx = node.x,
                ny = node.y,
                value = node.value;

            if (renderLinks) {
              leftSideNodes.forEach(function (leftNode, leftIdx) {
                ctx.beginPath();
                if (hasWeight) {
                  // Assumes a row-major [inputs, units] kernel (TF.js Dense layout):
                  // the weight from input `leftIdx` to unit `index` is at
                  // leftIdx * units + index.
                  var weightVal = kernel.values[leftIdx * nodes.length + index];
                  // Blue for positive, red for negative weights; opacity is |weight|.
                  ctx.strokeStyle = weightVal > 0 ? 'rgba(0, 0, 255, ' + weightVal + ')' : 'rgba(255, 0, 0, ' + Math.abs(weightVal) + ')';
                } else {
                  ctx.strokeStyle = 'rgba(0,0,0,.5)';
                }
                ctx.moveTo(leftNode.x + leftNode.radius / 2, leftNode.y);
                ctx.lineTo(node.x - node.radius / 2, node.y);
                ctx.stroke();
              });
            }

            ctx.strokeStyle = 'rgb(' + r + ', ' + g + ', ' + b + ')';
            // Node fill opacity reflects where its value sits inside the layer's domain;
            // nodes without a usable value are drawn white.
            var alpha = domainAlpha(value, min, max);
            ctx.fillStyle = alpha === null ? '#FFF' : 'rgba(' + r + ', ' + g + ', ' + b + ', ' + alpha + ')';
            ctx.beginPath();
            ctx.arc(nx, ny, radius / 2, 0, 2 * Math.PI);
            if (radius > 3 && nodeStroke) {
              ctx.stroke();
            }
            ctx.fill();

            if (renderNode) {
              renderNode(ctx, node, index);
            }
          });
        });

        if (onEndRender) {
          onEndRender(_this2);
        }
      });
    }
  }]);

  return CanvasRenderer;
}(_abstract2.default);
760 |
// Default export; `module.exports` is replaced by the constructor itself so that
// plain CommonJS `require` returns CanvasRenderer directly.
exports.default = CanvasRenderer;
module.exports = exports['default'];
763 |
764 | /***/ })
765 |
766 | /******/ });
767 | });
768 | //# sourceMappingURL=data:application/json;charset=utf-8;base64,
--------------------------------------------------------------------------------
/lib/tfjs-model-view.min.js:
--------------------------------------------------------------------------------
1 | !function(e,t){"object"==typeof exports&&"object"==typeof module?module.exports=t():"function"==typeof define&&define.amd?define("tfjs-model-view",[],t):"object"==typeof exports?exports["tfjs-model-view"]=t():e["tfjs-model-view"]=t()}("undefined"!=typeof self?self:this,(function(){return function(e){var t={};function r(n){if(t[n])return t[n].exports;var o=t[n]={i:n,l:!1,exports:{}};return e[n].call(o.exports,o,o.exports,r),o.l=!0,o.exports}return r.m=e,r.c=t,r.d=function(e,t,n){r.o(e,t)||Object.defineProperty(e,t,{enumerable:!0,get:n})},r.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},r.t=function(e,t){if(1&t&&(e=r(e)),8&t)return e;if(4&t&&"object"==typeof e&&e&&e.__esModule)return e;var n=Object.create(null);if(r.r(n),Object.defineProperty(n,"default",{enumerable:!0,value:e}),2&t&&"string"!=typeof e)for(var o in e)r.d(n,o,function(t){return e[t]}.bind(null,o));return n},r.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return r.d(t,"a",t),t},r.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r.p="",r(r.s=0)}([function(e,t,r){"use strict";Object.defineProperty(t,"__esModule",{value:!0});var n=i(r(1)),o=i(r(2)),a=i(r(4));function i(e){return e&&e.__esModule?e:{default:e}}t.default=function e(t,r){!function(e,t){if(!(e instanceof t))throw new TypeError("Cannot call a class as a function")}(this,e);var i=Object.assign({},a.default,r),u=i.onRendererInitialized,s=void 0;i.predictCallback=function(e){s&&(s.update(t,e),s.render())},i.hookCallback=function(e){s&&(s.updateValues(e),s.render())},(0,n.default)(t,i).then((function(e){s=new o.default(i,e),u&&u(s)}))},e.exports=t.default},function(e,t,r){"use strict";Object.defineProperty(t,"__esModule",{value:!0});var n,o=Object.assign||function(e){for(var t=1;t