├── .babelrc ├── .editorconfig ├── .eslintrc ├── .gitignore ├── .npmignore ├── LICENSE ├── README.md ├── app ├── index.html ├── index.js ├── iris-custom │ ├── data.js │ └── iris.js ├── iris │ ├── data.js │ └── iris.js ├── mnist-conv │ └── mnist.js ├── mnist │ ├── data.js │ └── mnist.js ├── models │ └── mnist-dense │ │ ├── group1-shard1of1 │ │ ├── group2-shard1of1 │ │ ├── group3-shard1of1 │ │ └── model.json └── tiny │ └── tiny.js ├── lib ├── tfjs-model-view.js └── tfjs-model-view.min.js ├── package-lock.json ├── package.json ├── src ├── default.config.js ├── index.js ├── model-parser.js └── renderers │ ├── abstract.renderer.js │ └── canvas.renderer.js ├── stats.json ├── test └── index.spec.js ├── webpack.config.js └── webpack.dev.config.js /.babelrc: -------------------------------------------------------------------------------- 1 | { 2 | "presets": ["env"], 3 | "plugins": ["babel-plugin-add-module-exports", "transform-object-rest-spread"] 4 | } 5 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | indent_style = space 5 | indent_size = 2 6 | end_of_line = LF 7 | charset = utf-8 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | 11 | [*.md] 12 | trim_trailing_whitespace = false 13 | -------------------------------------------------------------------------------- /.eslintrc: -------------------------------------------------------------------------------- 1 | { 2 | "env": { 3 | "browser": true, 4 | "node": true, 5 | "es6": true, 6 | "mocha": true 7 | }, 8 | "parserOptions": { 9 | "ecmaVersion": 6, 10 | "sourceType": "module", 11 | "ecmaFeatures": { 12 | "impliedStrict": true, 13 | "experimentalObjectRestSpread": true 14 | } 15 | }, 16 | "globals": {}, 17 | "rules": { 18 | "no-cond-assign": 2, 19 | "no-constant-condition": 2, 20 | "no-dupe-args": 2, 21 | "no-dupe-keys": 2, 22 | 
"no-duplicate-case": 2, 23 | "no-empty-character-class": 2, 24 | "no-empty": [ 25 | 1, 26 | { 27 | "allowEmptyCatch": true 28 | } 29 | ], 30 | "no-ex-assign": 2, 31 | "no-extra-boolean-cast": 2, 32 | "no-extra-parens": 2, 33 | "no-extra-semi": 2, 34 | "no-func-assign": 2, 35 | "no-inner-declarations": [ 36 | 2, 37 | "both" 38 | ], 39 | "no-invalid-regexp": 2, 40 | "no-irregular-whitespace": [ 41 | 2, 42 | { 43 | "skipStrings": true, 44 | "skipComments": true, 45 | "skipRegExps": true, 46 | "skipTemplates": true 47 | } 48 | ], 49 | "no-negated-in-lhs": 2, 50 | "no-obj-calls": 2, 51 | "no-prototype-builtins": 2, 52 | "no-regex-spaces": 2, 53 | "no-sparse-arrays": 1, 54 | "no-unexpected-multiline": 1, 55 | "no-unreachable": 2, 56 | "no-unsafe-finally": 1, 57 | "use-isnan": 2, 58 | "valid-jsdoc": 0, 59 | "valid-typeof": 2, 60 | "accessor-pairs": 1, 61 | "array-callback-return": 1, 62 | "block-scoped-var": 1, 63 | "curly": 1, 64 | "default-case": 1, 65 | "dot-notation": 1, 66 | "eqeqeq": [ 67 | 2, 68 | "allow-null" 69 | ], 70 | "guard-for-in": 1, 71 | "no-alert": 1, 72 | "no-caller": 2, 73 | "no-eval": 2, 74 | "no-extend-native": 1, 75 | "no-extra-bind": 1, 76 | 77 | "no-floating-decimal": 2, 78 | 79 | "no-implicit-globals": 1, 80 | 81 | "no-implied-eval": 2, 82 | 83 | "no-lone-blocks": 1, 84 | 85 | "no-loop-func": 1, 86 | 87 | "no-multi-spaces": 1, 88 | 89 | "no-redeclare": [ 90 | 2, 91 | { 92 | "builtinGlobals": true 93 | } 94 | ], 95 | 96 | "no-sequences": 1, 97 | 98 | "vars-on-top": 1, 99 | 100 | "wrap-iife": 1, 101 | 102 | "yoda": 1, 103 | "strict": [ 104 | 1, 105 | "function" 106 | ], 107 | 108 | "no-delete-var": 2, 109 | 110 | "no-restricted-globals": 2, 111 | 112 | "no-unused-vars": 1, 113 | 114 | "no-use-before-define": 2, 115 | 116 | "no-undef": 2, 117 | 118 | "no-undef-init": 1, 119 | 120 | "no-undefined": 1, 121 | 122 | "no-shadow": [ 123 | 2, 124 | { 125 | "hoist": "functions" 126 | } 127 | ], 128 | 129 | "array-bracket-spacing": [ 130 | 1, 131 | "never" 
132 | ], 133 | 134 | "block-spacing": [ 135 | 1, 136 | "never" 137 | ], 138 | 139 | "brace-style": [ 140 | 1, 141 | "1tbs", 142 | { 143 | "allowSingleLine": true 144 | } 145 | ], 146 | 147 | "comma-dangle": 1, 148 | 149 | "comma-spacing": [ 150 | 1, 151 | { 152 | "before": false, 153 | "after": true 154 | } 155 | ], 156 | 157 | "comma-style": [ 158 | 1, 159 | "last" 160 | ], 161 | 162 | "computed-property-spacing": [ 163 | 1, 164 | "never" 165 | ], 166 | 167 | "consistent-this": [ 168 | 1, 169 | "that" 170 | ], 171 | 172 | "eol-last": 1, 173 | 174 | "id-blacklist": 1, 175 | 176 | 177 | "key-spacing": [ 178 | 1, 179 | { 180 | "afterColon": true 181 | } 182 | ], 183 | 184 | "keyword-spacing": 1, 185 | 186 | 187 | 188 | "new-cap": 1, 189 | 190 | "new-parens": 1, 191 | 192 | "no-mixed-spaces-and-tabs": 1, 193 | 194 | "no-multiple-empty-lines": [ 195 | 1, 196 | { 197 | "max": 1, 198 | "maxBOF": 0, 199 | "maxEOF": 1 200 | } 201 | ], 202 | 203 | "no-spaced-func": 1, 204 | 205 | 206 | "no-unneeded-ternary": 1, 207 | 208 | "no-whitespace-before-property": 1, 209 | 210 | "object-curly-spacing": [ 211 | 0 212 | ], 213 | 214 | "quotes": [ 215 | 1, 216 | "single", 217 | { 218 | "avoidEscape": true, 219 | "allowTemplateLiterals": true 220 | } 221 | ], 222 | 223 | "semi-spacing": [ 224 | 1, 225 | { 226 | "before": false, 227 | "after": true 228 | } 229 | ], 230 | 231 | 232 | "space-before-blocks": 1, 233 | 234 | "space-before-function-paren": [ 235 | 1, 236 | { 237 | "anonymous": "always", 238 | "named": "never" 239 | } 240 | ], 241 | 242 | "space-in-parens": [ 243 | 1, 244 | "never" 245 | ], 246 | 247 | "space-infix-ops": 1, 248 | 249 | "spaced-comment": [ 250 | 1, 251 | "always" 252 | ], 253 | 254 | "arrow-body-style": [ 255 | 1, 256 | "as-needed" 257 | ], 258 | 259 | "arrow-parens": [ 260 | 1, 261 | "as-needed" 262 | ], 263 | 264 | "arrow-spacing": [ 265 | 1, 266 | { 267 | "before": true, 268 | "after": true 269 | } 270 | ], 271 | 272 | "constructor-super": 2, 273 | 274 | 
"generator-star-spacing": [ 275 | 1, 276 | { 277 | "before": true, 278 | "after": false 279 | } 280 | ], 281 | 282 | "no-class-assign": 2, 283 | 284 | "no-confusing-arrow": 0, 285 | 286 | "no-const-assign": 2, 287 | 288 | "no-dupe-class-members": 1, 289 | 290 | "no-duplicate-imports": [ 291 | 1, 292 | { 293 | "includeExports": true 294 | } 295 | ], 296 | 297 | "no-new-symbol": 2, 298 | 299 | "no-restricted-imports": [ 300 | 2, 301 | "assert", 302 | "buffer", 303 | "child_process", 304 | "cluster", 305 | "crypto", 306 | "dgram", 307 | "dns", 308 | "domain", 309 | "events", 310 | "freelist", 311 | "fs", 312 | "http", 313 | "https", 314 | "module", 315 | "net", 316 | "os", 317 | "path", 318 | "punycode", 319 | "querystring", 320 | "readline", 321 | "repl", 322 | "smalloc", 323 | "stream", 324 | "string_decoder", 325 | "sys", 326 | "timers", 327 | "tls", 328 | "tracing", 329 | "tty", 330 | "url", 331 | "util", 332 | "vm", 333 | "zlib" 334 | ], 335 | 336 | "no-this-before-super": 2, 337 | 338 | "no-useless-computed-key": 1, 339 | 340 | "no-useless-constructor": 1, 341 | 342 | "no-useless-rename": 1, 343 | 344 | "no-var": 1, 345 | 346 | "prefer-rest-params": 1, 347 | 348 | "prefer-template": 1, 349 | 350 | "require-yield": 2, 351 | 352 | "rest-spread-spacing": [ 353 | 1, 354 | "never" 355 | ], 356 | 357 | "template-curly-spacing": [ 358 | 1, 359 | "never" 360 | ], 361 | 362 | "yield-star-spacing": [ 363 | 1, 364 | { 365 | "before": true, 366 | "after": false 367 | } 368 | ] 369 | } 370 | } 371 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Logs 2 | logs 3 | *.log 4 | 5 | ignore 6 | 7 | # Runtime data 8 | pids 9 | *.pid 10 | *.seed 11 | 12 | # Directory for instrumented libs generated by jscoverage/JSCover 13 | lib-cov 14 | 15 | # Coverage directory used by tools like istanbul 16 | coverage 17 | .nyc_output 18 | 19 | # Grunt intermediate 
storage (http://gruntjs.com/creating-plugins#storing-task-files) 20 | .grunt 21 | 22 | # node-waf configuration 23 | .lock-wscript 24 | 25 | # Compiled binary addons (http://nodejs.org/api/addons.html) 26 | build/Release 27 | 28 | # Dependency directory 29 | # https://www.npmjs.org/doc/misc/npm-faq.html#should-i-check-my-node_modules-folder-into-git 30 | node_modules 31 | 32 | # Remove some common IDE working directories 33 | .idea 34 | .vscode 35 | 36 | .DS_Store 37 | -------------------------------------------------------------------------------- /.npmignore: -------------------------------------------------------------------------------- 1 | app 2 | ignore 3 | test 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Cornel Stefanache 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tfjs-model-view 2 | 3 | __tfjs-model-view__ is a library for _in-browser_ visualization of neural networks, intended for use with TensorFlow.js. 4 | 5 | Features: 6 | 7 | * Automatic rendering of the neural net 8 | * Automatic updates of weights/biases/values 9 | 10 | The library also aims to be flexible and easy to incorporate into your project. 11 | 12 | ## Demos 13 | 14 | - [Movielens recommendation using Tensorflow.js](https://beta.observablehq.com/@cstefanache/movielens-recommendation-using-tensorflow-js) 15 | - [Iris Prediction with Custom Node Renderer](https://beta.observablehq.com/@cstefanache/tensorflow-js-model-viewer-iris) 16 | - [MNIST Prediction](https://beta.observablehq.com/@cstefanache/mnist-tensorflow-js-network-view-tfjs-model-view) 17 | - [Multiple Input Rendering](https://beta.observablehq.com/@cstefanache/tensorflow-js-model-view-multiple-input-test) 18 | 19 | ## Sample rendering output 20 | 21 | ![Samples](https://raw.githubusercontent.com/cstefanache/cstefanache.github.io/master/media/img/net2.png "Samples") 22 | 23 | 24 | ## Usage 25 | 26 | Simple: 27 | ``` 28 | new ModelView(model) 29 | ``` 30 | 31 | Customized: 32 | ``` 33 | new ModelView(model, { 34 | printStats: true, 35 | radius: 25, 36 | renderLinks: true, 37 | xOffset: 100, 38 | renderNode(ctx, node) { 39 | const { x, y, value } = node; 40 | ctx.font = '10px Arial'; 41 | ctx.fillStyle = '#000'; 42 | ctx.textAlign = 'center'; 43 | ctx.textBaseline = 'middle'; 44 | ctx.fillText(Math.round(value * 100) / 100, x, y); 45 | }, 46 |
onBeginRender: renderer => { 47 | const { renderContext } = renderer; 48 | renderContext.fillStyle = '#000'; 49 | renderContext.textAlign = 'end'; 50 | renderContext.font = '12px Arial'; 51 | renderContext.fillText('Sepal Length (cm)', 110, 110); 52 | renderContext.fillText('Sepal Width (cm)', 110, 136); 53 | renderContext.fillText('Petal Length (cm)', 110, 163); 54 | renderContext.fillText('Petal Width (cm)', 110, 190); 55 | 56 | renderContext.textAlign = 'start'; 57 | renderContext.fillText('Setosa', renderer.width - 60, 95); 58 | renderContext.fillText('Versicolor', renderer.width - 60, 150); 59 | renderContext.fillText('Virginica', renderer.width - 60, 205); 60 | }, 61 | layer: { 62 | 'dense_Dense1_input': { 63 | domain: [0, 8], 64 | color: [165, 130, 180] 65 | }, 66 | 'dense_Dense1/dense_Dense1': { 67 | color: [125, 125, 125] 68 | }, 69 | 'dense_Dense2/dense_Dense2': { 70 | color: [125, 125, 125] 71 | }, 72 | 'dense_Dense3/dense_Dense3': { 73 | nodePadding: 30 74 | } 75 | } 76 | }); 77 | ``` 78 | 79 | Customizing: 80 | ``` 81 | new ModelView(model, { 82 | /** renders the list of layers **/ 83 | printStats: true, 84 | 85 | /** Default domain for color intensity **/ 86 | domain: [0, 1], 87 | 88 | /** Default node radius **/ 89 | radius: 6, 90 | 91 | /** Default node padding **/ 92 | nodePadding: 2, 93 | 94 | /** Default layer padding **/ 95 | layerPadding: 20, 96 | 97 | /** Default group padding **/ 98 | groupPadding: 1, 99 | 100 | /** Horizontal padding **/ 101 | xPadding: 10, 102 | 103 | /** Vertical padding **/ 104 | yPadding: 10, 105 | 106 | /** Render links between layers **/ 107 | renderLinks: false, 108 | 109 | /** Stroke node outer circle **/ 110 | nodeStroke: true, 111 | 112 | /** custom render node function **/ 113 | renderNode: (ctx, node, nodeIdx) => {...}, 114 | 115 | /** If present will be executed before node rendering **/ 116 | onBeginRender: renderer => { ... 
}, 117 | 118 | /** If present, will be executed after all node rendering is finished **/ 119 | onEndRender: renderer => { ... }, 120 | 121 | /** Personalized layer configuration **/ 122 | /** All defaults can be overridden for each layer individually **/ 123 | layer: { 124 | 'layerName': { 125 | /** Any property mentioned above **/ 126 | 127 | /** Reshape layer to another [cols, rows, groups] layout **/ 128 | reshape: [4, 4, 8] 129 | } 130 | } 131 | }); 132 | ``` 133 | 134 | ## Installation 135 | 136 | You can install this using npm with 137 | 138 | ``` 139 | npm install tfjs-model-view 140 | ``` 141 | 142 | or using yarn with 143 | 144 | ``` 145 | yarn add tfjs-model-view 146 | ``` 147 | 148 | ## Building from source 149 | 150 | To build the library, you need to have Node.js installed. We use `yarn` 151 | instead of `npm` but you can use either. 152 | 153 | First install dependencies with 154 | 155 | ``` 156 | yarn 157 | ``` 158 | 159 | or 160 | 161 | ``` 162 | npm install 163 | ``` 164 | 165 | You can start the dev environment using 166 | 167 | ``` 168 | yarn dev 169 | ``` 170 | 171 | or 172 | 173 | ``` 174 | npm run dev 175 | ``` 176 | 177 | 178 | ## Sample Usage 179 | 180 | 181 | ## Issues 182 | 183 | Found a bug or have a feature request?
Please file an [issue](https://github.com/cstefanache/tfjs-model-view/issues/new) 184 | -------------------------------------------------------------------------------- /app/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | TFJS sample 5 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /app/index.js: -------------------------------------------------------------------------------- 1 | import 'babel-polyfill'; 2 | 3 | import runIris from './iris/iris'; 4 | import runIrisText from '!!raw-loader!././iris/iris.js'; 5 | 6 | import runIrisCustom from './iris-custom/iris'; 7 | import runIrisCustomText from '!!raw-loader!./iris-custom/iris.js'; 8 | 9 | import runMnist from './mnist/mnist'; 10 | import runMnistText from '!!raw-loader!./mnist/mnist.js'; 11 | 12 | import runMnistConv from './mnist-conv/mnist'; 13 | import runMnistConvText from '!!raw-loader!./mnist-conv/mnist.js'; 14 | 15 | import tiny from './tiny/tiny'; 16 | import tinyText from '!!raw-loader!./tiny/tiny.js'; 17 | 18 | const samples = { 19 | tiny: { 20 | name: 'Tiny', 21 | link: 'tiny', 22 | executor: tiny, 23 | text: tinyText 24 | }, 25 | iris: { 26 | name: 'Iris', 27 | link: 'iris', 28 | executor: runIris, 29 | text: runIrisText 30 | }, 31 | irisc: { 32 | name: 'Iris Custom', 33 | link: 'irisc', 34 | executor: runIrisCustom, 35 | text: runIrisCustomText 36 | }, 37 | mnist: { 38 | name: 'Mnist Dense', 39 | link: 'mnist', 40 | executor: runMnist, 41 | text: runMnistText 42 | }, 43 | mnistc: { 44 | name: 'Mnist Conv', 45 | link: 'mnistc', 46 | executor: runMnistConv, 47 | text: runMnistConvText 48 | } 49 | }; 50 | 51 | function getUrlParameter(name) { 52 | name = name.replace(/[\[]/, '\\[').replace(/[\]]/, '\\]'); 53 | const regex = new RegExp(`[\\?&]${name}=([^&#]*)`); 54 | const results = regex.exec(location.search); 55 | return results === null ? 
'' : decodeURIComponent(results[1].replace(/\+/g, ' ')); 56 | } 57 | const runner = samples[getUrlParameter('sample')]; 58 | let load, contentElem; 59 | 60 | function prepareMenu() { 61 | document.body.innerHTML = '

Tensorflow.js model viewer samples:

'; 62 | const menuBar = document.createElement('div'); 63 | contentElem = document.querySelector('#content'); 64 | menuBar.classList.add('menu-content'); 65 | contentElem.appendChild(menuBar); 66 | 67 | Object.values(samples).forEach(sample => { 68 | 69 | const menuItem = document.createElement('a'); 70 | menuItem.setAttribute('href', `http://localhost:4500?sample=${sample.link.toLowerCase()}${sample.append ? sample.append : ''}`) 71 | menuItem.innerHTML = sample.name; 72 | menuItem.classList.add('menu-item') 73 | if (runner === sample) { 74 | menuItem.classList.add('selected') 75 | } 76 | menuItem.addEventListener('click', () => { 77 | load(sample.executor, sample); 78 | 79 | }); 80 | menuBar.appendChild(menuItem); 81 | 82 | }) 83 | } 84 | 85 | load = (executor, sample) => { 86 | if (executor) { 87 | executor(); 88 | let text = sample.text; 89 | text = text.substr(text.indexOf(' new ModelView')) 90 | text = text.substr(0, text.indexOf(' });') + 4) 91 | document.body.innerHTML += `
${text}
`; 92 | } 93 | } 94 | 95 | prepareMenu(); 96 | 97 | if (runner) { 98 | load(runner.executor, runner); 99 | } 100 | -------------------------------------------------------------------------------- /app/iris-custom/data.js: -------------------------------------------------------------------------------- 1 | import * as tf from '@tensorflow/tfjs'; 2 | 3 | export const IRIS_CLASSES = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']; 4 | export const IRIS_NUM_CLASSES = IRIS_CLASSES.length; 5 | 6 | // Iris flowers data. Source: 7 | // https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data 8 | export const IRIS_DATA = [ 9 | [5.1, 3.5, 1.4, 0.2, 0], 10 | [4.9, 3.0, 1.4, 0.2, 0], 11 | [4.7, 3.2, 1.3, 0.2, 0], 12 | [4.6, 3.1, 1.5, 0.2, 0], 13 | [5.0, 3.6, 1.4, 0.2, 0], 14 | [5.4, 3.9, 1.7, 0.4, 0], 15 | [4.6, 3.4, 1.4, 0.3, 0], 16 | [5.0, 3.4, 1.5, 0.2, 0], 17 | [4.4, 2.9, 1.4, 0.2, 0], 18 | [4.9, 3.1, 1.5, 0.1, 0], 19 | [5.4, 3.7, 1.5, 0.2, 0], 20 | [4.8, 3.4, 1.6, 0.2, 0], 21 | [4.8, 3.0, 1.4, 0.1, 0], 22 | [4.3, 3.0, 1.1, 0.1, 0], 23 | [5.8, 4.0, 1.2, 0.2, 0], 24 | [5.7, 4.4, 1.5, 0.4, 0], 25 | [5.4, 3.9, 1.3, 0.4, 0], 26 | [5.1, 3.5, 1.4, 0.3, 0], 27 | [5.7, 3.8, 1.7, 0.3, 0], 28 | [5.1, 3.8, 1.5, 0.3, 0], 29 | [5.4, 3.4, 1.7, 0.2, 0], 30 | [5.1, 3.7, 1.5, 0.4, 0], 31 | [4.6, 3.6, 1.0, 0.2, 0], 32 | [5.1, 3.3, 1.7, 0.5, 0], 33 | [4.8, 3.4, 1.9, 0.2, 0], 34 | [5.0, 3.0, 1.6, 0.2, 0], 35 | [5.0, 3.4, 1.6, 0.4, 0], 36 | [5.2, 3.5, 1.5, 0.2, 0], 37 | [5.2, 3.4, 1.4, 0.2, 0], 38 | [4.7, 3.2, 1.6, 0.2, 0], 39 | [4.8, 3.1, 1.6, 0.2, 0], 40 | [5.4, 3.4, 1.5, 0.4, 0], 41 | [5.2, 4.1, 1.5, 0.1, 0], 42 | [5.5, 4.2, 1.4, 0.2, 0], 43 | [4.9, 3.1, 1.5, 0.1, 0], 44 | [5.0, 3.2, 1.2, 0.2, 0], 45 | [5.5, 3.5, 1.3, 0.2, 0], 46 | [4.9, 3.1, 1.5, 0.1, 0], 47 | [4.4, 3.0, 1.3, 0.2, 0], 48 | [5.1, 3.4, 1.5, 0.2, 0], 49 | [5.0, 3.5, 1.3, 0.3, 0], 50 | [4.5, 2.3, 1.3, 0.3, 0], 51 | [4.4, 3.2, 1.3, 0.2, 0], 52 | [5.0, 3.5, 1.6, 0.6, 0], 53 | [5.1, 3.8, 1.9, 0.4, 0], 54 | 
[4.8, 3.0, 1.4, 0.3, 0], 55 | [5.1, 3.8, 1.6, 0.2, 0], 56 | [4.6, 3.2, 1.4, 0.2, 0], 57 | [5.3, 3.7, 1.5, 0.2, 0], 58 | [5.0, 3.3, 1.4, 0.2, 0], 59 | [7.0, 3.2, 4.7, 1.4, 1], 60 | [6.4, 3.2, 4.5, 1.5, 1], 61 | [6.9, 3.1, 4.9, 1.5, 1], 62 | [5.5, 2.3, 4.0, 1.3, 1], 63 | [6.5, 2.8, 4.6, 1.5, 1], 64 | [5.7, 2.8, 4.5, 1.3, 1], 65 | [6.3, 3.3, 4.7, 1.6, 1], 66 | [4.9, 2.4, 3.3, 1.0, 1], 67 | [6.6, 2.9, 4.6, 1.3, 1], 68 | [5.2, 2.7, 3.9, 1.4, 1], 69 | [5.0, 2.0, 3.5, 1.0, 1], 70 | [5.9, 3.0, 4.2, 1.5, 1], 71 | [6.0, 2.2, 4.0, 1.0, 1], 72 | [6.1, 2.9, 4.7, 1.4, 1], 73 | [5.6, 2.9, 3.6, 1.3, 1], 74 | [6.7, 3.1, 4.4, 1.4, 1], 75 | [5.6, 3.0, 4.5, 1.5, 1], 76 | [5.8, 2.7, 4.1, 1.0, 1], 77 | [6.2, 2.2, 4.5, 1.5, 1], 78 | [5.6, 2.5, 3.9, 1.1, 1], 79 | [5.9, 3.2, 4.8, 1.8, 1], 80 | [6.1, 2.8, 4.0, 1.3, 1], 81 | [6.3, 2.5, 4.9, 1.5, 1], 82 | [6.1, 2.8, 4.7, 1.2, 1], 83 | [6.4, 2.9, 4.3, 1.3, 1], 84 | [6.6, 3.0, 4.4, 1.4, 1], 85 | [6.8, 2.8, 4.8, 1.4, 1], 86 | [6.7, 3.0, 5.0, 1.7, 1], 87 | [6.0, 2.9, 4.5, 1.5, 1], 88 | [5.7, 2.6, 3.5, 1.0, 1], 89 | [5.5, 2.4, 3.8, 1.1, 1], 90 | [5.5, 2.4, 3.7, 1.0, 1], 91 | [5.8, 2.7, 3.9, 1.2, 1], 92 | [6.0, 2.7, 5.1, 1.6, 1], 93 | [5.4, 3.0, 4.5, 1.5, 1], 94 | [6.0, 3.4, 4.5, 1.6, 1], 95 | [6.7, 3.1, 4.7, 1.5, 1], 96 | [6.3, 2.3, 4.4, 1.3, 1], 97 | [5.6, 3.0, 4.1, 1.3, 1], 98 | [5.5, 2.5, 4.0, 1.3, 1], 99 | [5.5, 2.6, 4.4, 1.2, 1], 100 | [6.1, 3.0, 4.6, 1.4, 1], 101 | [5.8, 2.6, 4.0, 1.2, 1], 102 | [5.0, 2.3, 3.3, 1.0, 1], 103 | [5.6, 2.7, 4.2, 1.3, 1], 104 | [5.7, 3.0, 4.2, 1.2, 1], 105 | [5.7, 2.9, 4.2, 1.3, 1], 106 | [6.2, 2.9, 4.3, 1.3, 1], 107 | [5.1, 2.5, 3.0, 1.1, 1], 108 | [5.7, 2.8, 4.1, 1.3, 1], 109 | [6.3, 3.3, 6.0, 2.5, 2], 110 | [5.8, 2.7, 5.1, 1.9, 2], 111 | [7.1, 3.0, 5.9, 2.1, 2], 112 | [6.3, 2.9, 5.6, 1.8, 2], 113 | [6.5, 3.0, 5.8, 2.2, 2], 114 | [7.6, 3.0, 6.6, 2.1, 2], 115 | [4.9, 2.5, 4.5, 1.7, 2], 116 | [7.3, 2.9, 6.3, 1.8, 2], 117 | [6.7, 2.5, 5.8, 1.8, 2], 118 | [7.2, 3.6, 6.1, 2.5, 2], 119 | [6.5, 3.2, 5.1, 2.0, 2], 120 
| [6.4, 2.7, 5.3, 1.9, 2], 121 | [6.8, 3.0, 5.5, 2.1, 2], 122 | [5.7, 2.5, 5.0, 2.0, 2], 123 | [5.8, 2.8, 5.1, 2.4, 2], 124 | [6.4, 3.2, 5.3, 2.3, 2], 125 | [6.5, 3.0, 5.5, 1.8, 2], 126 | [7.7, 3.8, 6.7, 2.2, 2], 127 | [7.7, 2.6, 6.9, 2.3, 2], 128 | [6.0, 2.2, 5.0, 1.5, 2], 129 | [6.9, 3.2, 5.7, 2.3, 2], 130 | [5.6, 2.8, 4.9, 2.0, 2], 131 | [7.7, 2.8, 6.7, 2.0, 2], 132 | [6.3, 2.7, 4.9, 1.8, 2], 133 | [6.7, 3.3, 5.7, 2.1, 2], 134 | [7.2, 3.2, 6.0, 1.8, 2], 135 | [6.2, 2.8, 4.8, 1.8, 2], 136 | [6.1, 3.0, 4.9, 1.8, 2], 137 | [6.4, 2.8, 5.6, 2.1, 2], 138 | [7.2, 3.0, 5.8, 1.6, 2], 139 | [7.4, 2.8, 6.1, 1.9, 2], 140 | [7.9, 3.8, 6.4, 2.0, 2], 141 | [6.4, 2.8, 5.6, 2.2, 2], 142 | [6.3, 2.8, 5.1, 1.5, 2], 143 | [6.1, 2.6, 5.6, 1.4, 2], 144 | [7.7, 3.0, 6.1, 2.3, 2], 145 | [6.3, 3.4, 5.6, 2.4, 2], 146 | [6.4, 3.1, 5.5, 1.8, 2], 147 | [6.0, 3.0, 4.8, 1.8, 2], 148 | [6.9, 3.1, 5.4, 2.1, 2], 149 | [6.7, 3.1, 5.6, 2.4, 2], 150 | [6.9, 3.1, 5.1, 2.3, 2], 151 | [5.8, 2.7, 5.1, 1.9, 2], 152 | [6.8, 3.2, 5.9, 2.3, 2], 153 | [6.7, 3.3, 5.7, 2.5, 2], 154 | [6.7, 3.0, 5.2, 2.3, 2], 155 | [6.3, 2.5, 5.0, 1.9, 2], 156 | [6.5, 3.0, 5.2, 2.0, 2], 157 | [6.2, 3.4, 5.4, 2.3, 2], 158 | [5.9, 3.0, 5.1, 1.8, 2] 159 | ]; 160 | 161 | /** 162 | * Convert Iris data arrays to `tf.Tensor`s. 163 | * 164 | * @param data The Iris input feature data, an `Array` of `Array`s, each element 165 | * of which is assumed to be a length-4 `Array` (for petal length, petal 166 | * width, sepal length, sepal width). 167 | * @param targets An `Array` of numbers, with values from the set {0, 1, 2}: 168 | * representing the true category of the Iris flower. Assumed to have the same 169 | * array length as `data`. 170 | * @param testSplit Fraction of the data at the end to split as test data: a 171 | * number between 0 and 1. 172 | * @return A length-4 `Array`, with 173 | * - training data as `tf.Tensor` of shape [numTrainExapmles, 4]. 
174 | * - training one-hot labels as a `tf.Tensor` of shape [numTrainExamples, 3] 175 | * - test data as `tf.Tensor` of shape [numTestExamples, 4]. 176 | * - test one-hot labels as a `tf.Tensor` of shape [numTestExamples, 3] 177 | */ 178 | function convertToTensors(data, targets, testSplit) { 179 | const numExamples = data.length; 180 | if (numExamples !== targets.length) { 181 | throw new Error('data and split have different numbers of examples'); 182 | } 183 | 184 | // Randomly shuffle `data` and `targets`. 185 | const indices = []; 186 | for (let i = 0; i < numExamples; ++i) { 187 | indices.push(i); 188 | } 189 | tf.util.shuffle(indices); 190 | 191 | const shuffledData = []; 192 | const shuffledTargets = []; 193 | for (let i = 0; i < numExamples; ++i) { 194 | shuffledData.push(data[indices[i]]); 195 | shuffledTargets.push(targets[indices[i]]); 196 | } 197 | 198 | // Split the data into a training set and a tet set, based on `testSplit`. 199 | const numTestExamples = Math.round(numExamples * testSplit); 200 | const numTrainExamples = numExamples - numTestExamples; 201 | 202 | const xDims = shuffledData[0].length; 203 | 204 | // Create a 2D `tf.Tensor` to hold the feature data. 205 | const xs = tf.tensor2d(shuffledData, [numExamples, xDims]); 206 | 207 | // Create a 1D `tf.Tensor` to hold the labels, and convert the number label 208 | // from the set {0, 1, 2} into one-hot encoding (.e.g., 0 --> [1, 0, 0]). 209 | const ys = tf.oneHot(tf.tensor1d(shuffledTargets).toInt(), IRIS_NUM_CLASSES); 210 | 211 | // Split the data into training and test sets, using `slice`. 
212 | const xTrain = xs.slice([0, 0], [numTrainExamples, xDims]); 213 | const xTest = xs.slice([numTrainExamples, 0], [numTestExamples, xDims]); 214 | const yTrain = ys.slice([0, 0], [numTrainExamples, IRIS_NUM_CLASSES]); 215 | const yTest = ys.slice([0, 0], [numTestExamples, IRIS_NUM_CLASSES]); 216 | return [xTrain, yTrain, xTest, yTest]; 217 | } 218 | 219 | /** 220 | * Obtains Iris data, split into training and test sets. 221 | * 222 | * @param testSplit Fraction of the data at the end to split as test data: a 223 | * number between 0 and 1. 224 | * 225 | * @param return A length-4 `Array`, with 226 | * - training data as an `Array` of length-4 `Array` of numbers. 227 | * - training labels as an `Array` of numbers, with the same length as the 228 | * return training data above. Each element of the `Array` is from the set 229 | * {0, 1, 2}. 230 | * - test data as an `Array` of length-4 `Array` of numbers. 231 | * - test labels as an `Array` of numbers, with the same length as the 232 | * return test data above. Each element of the `Array` is from the set 233 | * {0, 1, 2}. 
234 | */ 235 | export function getIrisData(testSplit) { 236 | return tf.tidy(() => { 237 | const dataByClass = []; 238 | const targetsByClass = []; 239 | for (let i = 0; i < IRIS_CLASSES.length; ++i) { 240 | dataByClass.push([]); 241 | targetsByClass.push([]); 242 | } 243 | for (const example of IRIS_DATA) { 244 | const target = example[example.length - 1]; 245 | const data = example.slice(0, example.length - 1); 246 | dataByClass[target].push(data); 247 | targetsByClass[target].push(target); 248 | } 249 | 250 | const xTrains = []; 251 | const yTrains = []; 252 | const xTests = []; 253 | const yTests = []; 254 | for (let i = 0; i < IRIS_CLASSES.length; ++i) { 255 | const [xTrain, yTrain, xTest, yTest] = 256 | convertToTensors(dataByClass[i], targetsByClass[i], testSplit); 257 | xTrains.push(xTrain); 258 | yTrains.push(yTrain); 259 | xTests.push(xTest); 260 | yTests.push(yTest); 261 | } 262 | 263 | const concatAxis = 0; 264 | return [ 265 | tf.concat(xTrains, concatAxis), tf.concat(yTrains, concatAxis), 266 | tf.concat(xTests, concatAxis), tf.concat(yTests, concatAxis) 267 | ]; 268 | }); 269 | } 270 | -------------------------------------------------------------------------------- /app/iris-custom/iris.js: -------------------------------------------------------------------------------- 1 | import * as tf from '@tensorflow/tfjs'; 2 | import ModelView from '../../src'; 3 | 4 | import { 5 | IRIS_DATA, 6 | getIrisData 7 | } from './data'; 8 | 9 | async function trainModel(xTrain, yTrain, xTest, yTest) { 10 | 11 | // Define the topology of the model: two dense layers. 
12 | const model = tf.sequential(); 13 | model.add(tf.layers.dense({ 14 | units: 10, 15 | activation: 'sigmoid', 16 | inputShape: [xTrain.shape[1]] 17 | })); 18 | 19 | model.add(tf.layers.dense({ 20 | units: 10, 21 | activation: 'sigmoid', 22 | inputShape: [10] 23 | })); 24 | 25 | model.add(tf.layers.dense({ 26 | units: 3, 27 | activation: 'softmax' 28 | })); 29 | 30 | model.summary(); 31 | 32 | const optimizer = tf.train.adam(0.02); 33 | model.compile({ 34 | optimizer: optimizer, 35 | loss: 'categoricalCrossentropy', 36 | metrics: ['accuracy'] 37 | }); 38 | 39 | new ModelView(model, { 40 | printStats: true, 41 | radius: 25, 42 | renderLinks: true, 43 | xOffset: 100, 44 | renderNode(ctx, node) { 45 | const { x, y, value } = node; 46 | ctx.font = '10px Arial'; 47 | ctx.fillStyle = '#000'; 48 | ctx.textAlign = 'center'; 49 | ctx.textBaseline = 'middle'; 50 | ctx.fillText(Math.round(value * 100) / 100, x, y); 51 | }, 52 | onBeginRender: renderer => { 53 | const { renderContext } = renderer; 54 | renderContext.fillStyle = '#000'; 55 | renderContext.textAlign = 'end'; 56 | renderContext.font = '12px Arial'; 57 | renderContext.fillText('Sepal Length (cm)', 110, 110); 58 | renderContext.fillText('Sepal Width (cm)', 110, 136); 59 | renderContext.fillText('Petal Length (cm)', 110, 163); 60 | renderContext.fillText('Petal Width (cm)', 110, 190); 61 | 62 | renderContext.textAlign = 'start'; 63 | renderContext.fillText('Setosa', renderer.width - 60, 95); 64 | renderContext.fillText('Versicolor', renderer.width - 60, 150); 65 | renderContext.fillText('Virginica', renderer.width - 60, 205); 66 | }, 67 | layer: { 68 | 'dense_Dense1_input': { 69 | domain: [0, 8], 70 | color: [165, 130, 180] 71 | }, 72 | 'dense_Dense1/dense_Dense1': { 73 | color: [125, 125, 125] 74 | }, 75 | 'dense_Dense2/dense_Dense2': { 76 | color: [125, 125, 125] 77 | }, 78 | 'dense_Dense3/dense_Dense3': { 79 | nodePadding: 30 80 | } 81 | } 82 | }); 83 | 84 | await model.fit(xTrain, yTrain, { 85 | epochs: 100, 
86 | validationData: [xTest, yTest] 87 | }); 88 | 89 | setInterval(() => { 90 | model.predict(tf.tensor([IRIS_DATA[Math.floor(Math.random() * IRIS_DATA.length)].slice(0, 4)])); 91 | }, 1000); 92 | 93 | return model; 94 | } 95 | 96 | const [xTrain, yTrain, xTest, yTest] = getIrisData(0.15); 97 | 98 | export default async () => { 99 | trainModel(xTrain, yTrain, xTest, yTest); 100 | } 101 | -------------------------------------------------------------------------------- /app/iris/data.js: -------------------------------------------------------------------------------- 1 | import * as tf from '@tensorflow/tfjs'; 2 | 3 | export const IRIS_CLASSES = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']; 4 | export const IRIS_NUM_CLASSES = IRIS_CLASSES.length; 5 | 6 | // Iris flowers data. Source: 7 | // https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data 8 | export const IRIS_DATA = [ 9 | [5.1, 3.5, 1.4, 0.2, 0], 10 | [4.9, 3.0, 1.4, 0.2, 0], 11 | [4.7, 3.2, 1.3, 0.2, 0], 12 | [4.6, 3.1, 1.5, 0.2, 0], 13 | [5.0, 3.6, 1.4, 0.2, 0], 14 | [5.4, 3.9, 1.7, 0.4, 0], 15 | [4.6, 3.4, 1.4, 0.3, 0], 16 | [5.0, 3.4, 1.5, 0.2, 0], 17 | [4.4, 2.9, 1.4, 0.2, 0], 18 | [4.9, 3.1, 1.5, 0.1, 0], 19 | [5.4, 3.7, 1.5, 0.2, 0], 20 | [4.8, 3.4, 1.6, 0.2, 0], 21 | [4.8, 3.0, 1.4, 0.1, 0], 22 | [4.3, 3.0, 1.1, 0.1, 0], 23 | [5.8, 4.0, 1.2, 0.2, 0], 24 | [5.7, 4.4, 1.5, 0.4, 0], 25 | [5.4, 3.9, 1.3, 0.4, 0], 26 | [5.1, 3.5, 1.4, 0.3, 0], 27 | [5.7, 3.8, 1.7, 0.3, 0], 28 | [5.1, 3.8, 1.5, 0.3, 0], 29 | [5.4, 3.4, 1.7, 0.2, 0], 30 | [5.1, 3.7, 1.5, 0.4, 0], 31 | [4.6, 3.6, 1.0, 0.2, 0], 32 | [5.1, 3.3, 1.7, 0.5, 0], 33 | [4.8, 3.4, 1.9, 0.2, 0], 34 | [5.0, 3.0, 1.6, 0.2, 0], 35 | [5.0, 3.4, 1.6, 0.4, 0], 36 | [5.2, 3.5, 1.5, 0.2, 0], 37 | [5.2, 3.4, 1.4, 0.2, 0], 38 | [4.7, 3.2, 1.6, 0.2, 0], 39 | [4.8, 3.1, 1.6, 0.2, 0], 40 | [5.4, 3.4, 1.5, 0.4, 0], 41 | [5.2, 4.1, 1.5, 0.1, 0], 42 | [5.5, 4.2, 1.4, 0.2, 0], 43 | [4.9, 3.1, 1.5, 0.1, 0], 44 | [5.0, 3.2, 1.2, 0.2, 0], 45 | 
[5.5, 3.5, 1.3, 0.2, 0], 46 | [4.9, 3.1, 1.5, 0.1, 0], 47 | [4.4, 3.0, 1.3, 0.2, 0], 48 | [5.1, 3.4, 1.5, 0.2, 0], 49 | [5.0, 3.5, 1.3, 0.3, 0], 50 | [4.5, 2.3, 1.3, 0.3, 0], 51 | [4.4, 3.2, 1.3, 0.2, 0], 52 | [5.0, 3.5, 1.6, 0.6, 0], 53 | [5.1, 3.8, 1.9, 0.4, 0], 54 | [4.8, 3.0, 1.4, 0.3, 0], 55 | [5.1, 3.8, 1.6, 0.2, 0], 56 | [4.6, 3.2, 1.4, 0.2, 0], 57 | [5.3, 3.7, 1.5, 0.2, 0], 58 | [5.0, 3.3, 1.4, 0.2, 0], 59 | [7.0, 3.2, 4.7, 1.4, 1], 60 | [6.4, 3.2, 4.5, 1.5, 1], 61 | [6.9, 3.1, 4.9, 1.5, 1], 62 | [5.5, 2.3, 4.0, 1.3, 1], 63 | [6.5, 2.8, 4.6, 1.5, 1], 64 | [5.7, 2.8, 4.5, 1.3, 1], 65 | [6.3, 3.3, 4.7, 1.6, 1], 66 | [4.9, 2.4, 3.3, 1.0, 1], 67 | [6.6, 2.9, 4.6, 1.3, 1], 68 | [5.2, 2.7, 3.9, 1.4, 1], 69 | [5.0, 2.0, 3.5, 1.0, 1], 70 | [5.9, 3.0, 4.2, 1.5, 1], 71 | [6.0, 2.2, 4.0, 1.0, 1], 72 | [6.1, 2.9, 4.7, 1.4, 1], 73 | [5.6, 2.9, 3.6, 1.3, 1], 74 | [6.7, 3.1, 4.4, 1.4, 1], 75 | [5.6, 3.0, 4.5, 1.5, 1], 76 | [5.8, 2.7, 4.1, 1.0, 1], 77 | [6.2, 2.2, 4.5, 1.5, 1], 78 | [5.6, 2.5, 3.9, 1.1, 1], 79 | [5.9, 3.2, 4.8, 1.8, 1], 80 | [6.1, 2.8, 4.0, 1.3, 1], 81 | [6.3, 2.5, 4.9, 1.5, 1], 82 | [6.1, 2.8, 4.7, 1.2, 1], 83 | [6.4, 2.9, 4.3, 1.3, 1], 84 | [6.6, 3.0, 4.4, 1.4, 1], 85 | [6.8, 2.8, 4.8, 1.4, 1], 86 | [6.7, 3.0, 5.0, 1.7, 1], 87 | [6.0, 2.9, 4.5, 1.5, 1], 88 | [5.7, 2.6, 3.5, 1.0, 1], 89 | [5.5, 2.4, 3.8, 1.1, 1], 90 | [5.5, 2.4, 3.7, 1.0, 1], 91 | [5.8, 2.7, 3.9, 1.2, 1], 92 | [6.0, 2.7, 5.1, 1.6, 1], 93 | [5.4, 3.0, 4.5, 1.5, 1], 94 | [6.0, 3.4, 4.5, 1.6, 1], 95 | [6.7, 3.1, 4.7, 1.5, 1], 96 | [6.3, 2.3, 4.4, 1.3, 1], 97 | [5.6, 3.0, 4.1, 1.3, 1], 98 | [5.5, 2.5, 4.0, 1.3, 1], 99 | [5.5, 2.6, 4.4, 1.2, 1], 100 | [6.1, 3.0, 4.6, 1.4, 1], 101 | [5.8, 2.6, 4.0, 1.2, 1], 102 | [5.0, 2.3, 3.3, 1.0, 1], 103 | [5.6, 2.7, 4.2, 1.3, 1], 104 | [5.7, 3.0, 4.2, 1.2, 1], 105 | [5.7, 2.9, 4.2, 1.3, 1], 106 | [6.2, 2.9, 4.3, 1.3, 1], 107 | [5.1, 2.5, 3.0, 1.1, 1], 108 | [5.7, 2.8, 4.1, 1.3, 1], 109 | [6.3, 3.3, 6.0, 2.5, 2], 110 | [5.8, 2.7, 5.1, 1.9, 2], 111 | [7.1, 
3.0, 5.9, 2.1, 2], 112 | [6.3, 2.9, 5.6, 1.8, 2], 113 | [6.5, 3.0, 5.8, 2.2, 2], 114 | [7.6, 3.0, 6.6, 2.1, 2], 115 | [4.9, 2.5, 4.5, 1.7, 2], 116 | [7.3, 2.9, 6.3, 1.8, 2], 117 | [6.7, 2.5, 5.8, 1.8, 2], 118 | [7.2, 3.6, 6.1, 2.5, 2], 119 | [6.5, 3.2, 5.1, 2.0, 2], 120 | [6.4, 2.7, 5.3, 1.9, 2], 121 | [6.8, 3.0, 5.5, 2.1, 2], 122 | [5.7, 2.5, 5.0, 2.0, 2], 123 | [5.8, 2.8, 5.1, 2.4, 2], 124 | [6.4, 3.2, 5.3, 2.3, 2], 125 | [6.5, 3.0, 5.5, 1.8, 2], 126 | [7.7, 3.8, 6.7, 2.2, 2], 127 | [7.7, 2.6, 6.9, 2.3, 2], 128 | [6.0, 2.2, 5.0, 1.5, 2], 129 | [6.9, 3.2, 5.7, 2.3, 2], 130 | [5.6, 2.8, 4.9, 2.0, 2], 131 | [7.7, 2.8, 6.7, 2.0, 2], 132 | [6.3, 2.7, 4.9, 1.8, 2], 133 | [6.7, 3.3, 5.7, 2.1, 2], 134 | [7.2, 3.2, 6.0, 1.8, 2], 135 | [6.2, 2.8, 4.8, 1.8, 2], 136 | [6.1, 3.0, 4.9, 1.8, 2], 137 | [6.4, 2.8, 5.6, 2.1, 2], 138 | [7.2, 3.0, 5.8, 1.6, 2], 139 | [7.4, 2.8, 6.1, 1.9, 2], 140 | [7.9, 3.8, 6.4, 2.0, 2], 141 | [6.4, 2.8, 5.6, 2.2, 2], 142 | [6.3, 2.8, 5.1, 1.5, 2], 143 | [6.1, 2.6, 5.6, 1.4, 2], 144 | [7.7, 3.0, 6.1, 2.3, 2], 145 | [6.3, 3.4, 5.6, 2.4, 2], 146 | [6.4, 3.1, 5.5, 1.8, 2], 147 | [6.0, 3.0, 4.8, 1.8, 2], 148 | [6.9, 3.1, 5.4, 2.1, 2], 149 | [6.7, 3.1, 5.6, 2.4, 2], 150 | [6.9, 3.1, 5.1, 2.3, 2], 151 | [5.8, 2.7, 5.1, 1.9, 2], 152 | [6.8, 3.2, 5.9, 2.3, 2], 153 | [6.7, 3.3, 5.7, 2.5, 2], 154 | [6.7, 3.0, 5.2, 2.3, 2], 155 | [6.3, 2.5, 5.0, 1.9, 2], 156 | [6.5, 3.0, 5.2, 2.0, 2], 157 | [6.2, 3.4, 5.4, 2.3, 2], 158 | [5.9, 3.0, 5.1, 1.8, 2] 159 | ]; 160 | 161 | /** 162 | * Convert Iris data arrays to `tf.Tensor`s. 163 | * 164 | * @param data The Iris input feature data, an `Array` of `Array`s, each element 165 | * of which is assumed to be a length-4 `Array` (for petal length, petal 166 | * width, sepal length, sepal width). 167 | * @param targets An `Array` of numbers, with values from the set {0, 1, 2}: 168 | * representing the true category of the Iris flower. Assumed to have the same 169 | * array length as `data`. 
170 | * @param testSplit Fraction of the data at the end to split as test data: a 171 | * number between 0 and 1. 172 | * @return A length-4 `Array`, with 173 | * - training data as `tf.Tensor` of shape [numTrainExapmles, 4]. 174 | * - training one-hot labels as a `tf.Tensor` of shape [numTrainExamples, 3] 175 | * - test data as `tf.Tensor` of shape [numTestExamples, 4]. 176 | * - test one-hot labels as a `tf.Tensor` of shape [numTestExamples, 3] 177 | */ 178 | function convertToTensors(data, targets, testSplit) { 179 | const numExamples = data.length; 180 | if (numExamples !== targets.length) { 181 | throw new Error('data and split have different numbers of examples'); 182 | } 183 | 184 | // Randomly shuffle `data` and `targets`. 185 | const indices = []; 186 | for (let i = 0; i < numExamples; ++i) { 187 | indices.push(i); 188 | } 189 | tf.util.shuffle(indices); 190 | 191 | const shuffledData = []; 192 | const shuffledTargets = []; 193 | for (let i = 0; i < numExamples; ++i) { 194 | shuffledData.push(data[indices[i]]); 195 | shuffledTargets.push(targets[indices[i]]); 196 | } 197 | 198 | // Split the data into a training set and a tet set, based on `testSplit`. 199 | const numTestExamples = Math.round(numExamples * testSplit); 200 | const numTrainExamples = numExamples - numTestExamples; 201 | 202 | const xDims = shuffledData[0].length; 203 | 204 | // Create a 2D `tf.Tensor` to hold the feature data. 205 | const xs = tf.tensor2d(shuffledData, [numExamples, xDims]); 206 | 207 | // Create a 1D `tf.Tensor` to hold the labels, and convert the number label 208 | // from the set {0, 1, 2} into one-hot encoding (.e.g., 0 --> [1, 0, 0]). 209 | const ys = tf.oneHot(tf.tensor1d(shuffledTargets).toInt(), IRIS_NUM_CLASSES); 210 | 211 | // Split the data into training and test sets, using `slice`. 
212 | const xTrain = xs.slice([0, 0], [numTrainExamples, xDims]); 213 | const xTest = xs.slice([numTrainExamples, 0], [numTestExamples, xDims]); 214 | const yTrain = ys.slice([0, 0], [numTrainExamples, IRIS_NUM_CLASSES]); 215 | const yTest = ys.slice([0, 0], [numTestExamples, IRIS_NUM_CLASSES]); 216 | return [xTrain, yTrain, xTest, yTest]; 217 | } 218 | 219 | /** 220 | * Obtains Iris data, split into training and test sets. 221 | * 222 | * @param testSplit Fraction of the data at the end to split as test data: a 223 | * number between 0 and 1. 224 | * 225 | * @param return A length-4 `Array`, with 226 | * - training data as an `Array` of length-4 `Array` of numbers. 227 | * - training labels as an `Array` of numbers, with the same length as the 228 | * return training data above. Each element of the `Array` is from the set 229 | * {0, 1, 2}. 230 | * - test data as an `Array` of length-4 `Array` of numbers. 231 | * - test labels as an `Array` of numbers, with the same length as the 232 | * return test data above. Each element of the `Array` is from the set 233 | * {0, 1, 2}. 
234 | */ 235 | export function getIrisData(testSplit) { 236 | return tf.tidy(() => { /* tf.tidy disposes the tensors created inside this callback (including the per-class splits) except the ones in the returned array. */ 237 | const dataByClass = []; 238 | const targetsByClass = []; 239 | for (let i = 0; i < IRIS_CLASSES.length; ++i) { 240 | dataByClass.push([]); 241 | targetsByClass.push([]); 242 | } 243 | for (const example of IRIS_DATA) { /* The last column is the class index; the preceding columns are the features. */ 244 | const target = example[example.length - 1]; 245 | const data = example.slice(0, example.length - 1); 246 | dataByClass[target].push(data); 247 | targetsByClass[target].push(target); 248 | } 249 | /* Shuffle and split each class on its own, so every class is split with the same `testSplit` ratio. */ 250 | const xTrains = []; 251 | const yTrains = []; 252 | const xTests = []; 253 | const yTests = []; 254 | for (let i = 0; i < IRIS_CLASSES.length; ++i) { 255 | const [xTrain, yTrain, xTest, yTest] = 256 | convertToTensors(dataByClass[i], targetsByClass[i], testSplit); 257 | xTrains.push(xTrain); 258 | yTrains.push(yTrain); 259 | xTests.push(xTest); 260 | yTests.push(yTest); 261 | } 262 | 263 | const concatAxis = 0; /* Row-wise concatenation: class 0 rows first, then class 1, then class 2. */ 264 | return [ 265 | tf.concat(xTrains, concatAxis), tf.concat(yTrains, concatAxis), 266 | tf.concat(xTests, concatAxis), tf.concat(yTests, concatAxis) 267 | ]; 268 | }); 269 | } 270 | -------------------------------------------------------------------------------- /app/iris/iris.js: -------------------------------------------------------------------------------- 1 | import * as tf from '@tensorflow/tfjs'; 2 | import ModelView from '../../src'; 3 | 4 | import { 5 | IRIS_DATA, 6 | getIrisData 7 | } from './data'; 8 | /** Builds a small dense classifier (10 sigmoid units, then a 3-way softmax) for the Iris features, attaches a ModelView to it, trains it for 100 epochs and then runs a prediction on a random Iris row every second. Resolves with the trained model. */ 9 | async function trainModel(xTrain, yTrain, xTest, yTest) { 10 | 11 | // Define the topology of the model: two dense layers.
12 | const model = tf.sequential(); 13 | model.add(tf.layers.dense({ 14 | units: 10, 15 | activation: 'sigmoid', 16 | inputShape: [xTrain.shape[1]] 17 | })); 18 | 19 | model.add(tf.layers.dense({ 20 | units: 3, 21 | activation: 'softmax' 22 | })); 23 | 24 | model.summary(); 25 | 26 | const optimizer = tf.train.adam(0.02); 27 | model.compile({ 28 | optimizer: optimizer, 29 | loss: 'categoricalCrossentropy', 30 | metrics: ['accuracy'] 31 | }); 32 | 33 | new ModelView(model, { 34 | printStats: true, 35 | radius: 25, 36 | renderLinks: true, 37 | layer: { 38 | 'dense_Dense1_input': { 39 | domain: [0, 7] 40 | }, 41 | 'dense_Dense2/dense_Dense2': { 42 | nodePadding: 30 43 | } 44 | } 45 | }); 46 | 47 | await model.fit(xTrain, yTrain, { 48 | epochs: 100, 49 | validationData: [xTest, yTest] 50 | }); 51 | /* Every second, predict on a random Iris row (its first 4 columns are the features), presumably so the attached ModelView redraws with fresh activations. NOTE(review): neither the input tensor nor the returned prediction is ever disposed, so tensors accumulate on every tick; confirm ModelView no longer needs them, then dispose both. */ 52 | setInterval(() => { 53 | model.predict(tf.tensor([IRIS_DATA[Math.floor(Math.random() * IRIS_DATA.length)].slice(0, 4)])); 54 | }, 1000); 55 | 56 | return model; 57 | } 58 | /* Hold out 15% of each class for validation; this runs once, when the module is first imported. */ 59 | const [xTrain, yTrain, xTest, yTest] = getIrisData(0.15); 60 | /* NOTE(review): the trainModel promise is neither returned nor awaited, so this function resolves immediately and training errors surface only as unhandled rejections. */ 61 | export default async () => { 62 | trainModel(xTrain, yTrain, xTest, yTest); 63 | } 64 | -------------------------------------------------------------------------------- /app/mnist/data.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
15 | * ============================================================================= 16 | */ 17 | 18 | import * as tf from '@tensorflow/tfjs'; 19 | 20 | export const IMAGE_H = 28; 21 | export const IMAGE_W = 28; 22 | const IMAGE_SIZE = IMAGE_H * IMAGE_W; 23 | const NUM_CLASSES = 10; 24 | const NUM_DATASET_ELEMENTS = 65000; 25 | 26 | const NUM_TRAIN_ELEMENTS = 55000; 27 | const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS; 28 | 29 | const MNIST_IMAGES_SPRITE_PATH = 30 | 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png'; 31 | const MNIST_LABELS_PATH = 32 | 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8'; 33 | 34 | /** 35 | * A class that fetches the sprited MNIST dataset and provide data as 36 | * tf.Tensors. 37 | */ 38 | export class MnistData { 39 | constructor() {} 40 | 41 | async load() { 42 | // Make a request for the MNIST sprited image. 43 | const img = new Image(); 44 | const canvas = document.createElement('canvas'); 45 | const ctx = canvas.getContext('2d'); 46 | const imgRequest = new Promise((resolve, reject) => { 47 | img.crossOrigin = ''; 48 | img.onload = () => { 49 | img.width = img.naturalWidth; 50 | img.height = img.naturalHeight; 51 | 52 | const datasetBytesBuffer = 53 | new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4); 54 | 55 | const chunkSize = 5000; 56 | canvas.width = img.width; 57 | canvas.height = chunkSize; 58 | 59 | for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) { 60 | const datasetBytesView = new Float32Array( 61 | datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4, 62 | IMAGE_SIZE * chunkSize); 63 | ctx.drawImage( 64 | img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width, 65 | chunkSize); 66 | 67 | const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height); 68 | 69 | for (let j = 0; j < imageData.data.length / 4; j++) { 70 | // All channels hold an equal value since the image is grayscale, so 71 | // just read the red channel. 
72 | datasetBytesView[j] = imageData.data[j * 4] / 255; 73 | } 74 | } 75 | this.datasetImages = new Float32Array(datasetBytesBuffer); 76 | 77 | resolve(); 78 | }; 79 | img.src = MNIST_IMAGES_SPRITE_PATH; 80 | }); 81 | 82 | const labelsRequest = fetch(MNIST_LABELS_PATH); 83 | const [imgResponse, labelsResponse] = 84 | await Promise.all([imgRequest, labelsRequest]); 85 | 86 | this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer()); 87 | 88 | // Slice the the images and labels into train and test sets. 89 | this.trainImages = 90 | this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS); 91 | this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS); 92 | this.trainLabels = 93 | this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS); 94 | this.testLabels = 95 | this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS); 96 | } 97 | 98 | /** 99 | * Get all training data as a data tensor and a labels tensor. 100 | * 101 | * @returns 102 | * xs: The data tensor, of shape `[numTrainExamples, 28, 28, 1]`. 103 | * labels: The one-hot encoded labels tensor, of shape 104 | * `[numTrainExamples, 10]`. 105 | */ 106 | getTrainData() { 107 | const xs = tf.tensor4d( 108 | this.trainImages, 109 | [this.trainImages.length / IMAGE_SIZE, IMAGE_H, IMAGE_W, 1]); 110 | const labels = tf.tensor2d( 111 | this.trainLabels, [this.trainLabels.length / NUM_CLASSES, NUM_CLASSES]); 112 | return {xs, labels}; 113 | } 114 | 115 | /** 116 | * Get all test data as a data tensor a a labels tensor. 117 | * 118 | * @param {number} numExamples Optional number of examples to get. If not 119 | * provided, 120 | * all test examples will be returned. 121 | * @returns 122 | * xs: The data tensor, of shape `[numTestExamples, 28, 28, 1]`. 123 | * labels: The one-hot encoded labels tensor, of shape 124 | * `[numTestExamples, 10]`. 
125 | */ 126 | getTestData(numExamples) { 127 | let xs = tf.tensor4d( 128 | this.testImages, 129 | [this.testImages.length / IMAGE_SIZE, IMAGE_H, IMAGE_W, 1]); 130 | let labels = tf.tensor2d( 131 | this.testLabels, [this.testLabels.length / NUM_CLASSES, NUM_CLASSES]); 132 | 133 | if (numExamples != null) { 134 | xs = xs.slice([0, 0, 0, 0], [numExamples, IMAGE_H, IMAGE_W, 1]); 135 | labels = labels.slice([0, 0], [numExamples, NUM_CLASSES]); 136 | } 137 | return {xs, labels}; 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /app/models/mnist-dense/group1-shard1of1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cstefanache/tfjs-model-view/860fed91c837fe28b3119e18c1f38d5298a6eddc/app/models/mnist-dense/group1-shard1of1 -------------------------------------------------------------------------------- /app/models/mnist-dense/group2-shard1of1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cstefanache/tfjs-model-view/860fed91c837fe28b3119e18c1f38d5298a6eddc/app/models/mnist-dense/group2-shard1of1 -------------------------------------------------------------------------------- /app/models/mnist-dense/group3-shard1of1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cstefanache/tfjs-model-view/860fed91c837fe28b3119e18c1f38d5298a6eddc/app/models/mnist-dense/group3-shard1of1 -------------------------------------------------------------------------------- /app/models/mnist-dense/model.json: -------------------------------------------------------------------------------- 1 | {"modelTopology": {"keras_version": "2.1.6", "backend": "tensorflow", "model_config": {"class_name": "Sequential", "config": [{"class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "batch_input_shape": [null, 784], "dtype": "float32", 
"units": 512, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "Dropout", "config": {"name": "dropout_1", "trainable": true, "rate": 0.2, "noise_shape": null, "seed": null}}, {"class_name": "Dense", "config": {"name": "dense_2", "trainable": true, "units": 512, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "Dropout", "config": {"name": "dropout_2", "trainable": true, "rate": 0.2, "noise_shape": null, "seed": null}}, {"class_name": "Dense", "config": {"name": "dense_3", "trainable": true, "units": 10, "activation": "softmax", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}]}, "training_config": {"optimizer_config": {"class_name": "RMSprop", "config": {"lr": 0.0010000000474974513, "rho": 0.8999999761581421, "decay": 0.0, "epsilon": 1e-07}}, "loss": "categorical_crossentropy", "metrics": ["accuracy"], "sample_weight_mode": null, "loss_weights": null}}, "weightsManifest": [{"paths": ["group1-shard1of1"], "weights": [{"name": "dense_1/kernel", "shape": [784, 512], "dtype": "float32"}, 
{"name": "dense_1/bias", "shape": [512], "dtype": "float32"}]}, {"paths": ["group2-shard1of1"], "weights": [{"name": "dense_2/kernel", "shape": [512, 512], "dtype": "float32"}, {"name": "dense_2/bias", "shape": [512], "dtype": "float32"}]}, {"paths": ["group3-shard1of1"], "weights": [{"name": "dense_3/kernel", "shape": [512, 10], "dtype": "float32"}, {"name": "dense_3/bias", "shape": [10], "dtype": "float32"}]}]} -------------------------------------------------------------------------------- /app/tiny/tiny.js: -------------------------------------------------------------------------------- 1 | import * as tf from '@tensorflow/tfjs'; 2 | import ModelView from '../../src'; 3 | 4 | export default async () => { 5 | 6 | // Define the topology of the model: two dense layers. 7 | const model = tf.sequential(); 8 | model.add(tf.layers.dense({ 9 | units: 2, 10 | activation: 'tanh', 11 | inputShape: [1] 12 | })); 13 | 14 | model.add(tf.layers.dense({ 15 | units: 2, 16 | activation: 'relu', 17 | inputShape: [2] 18 | })); 19 | 20 | model.add(tf.layers.dense({ 21 | units: 3, 22 | activation: 'softplus', 23 | inputShape: [2] 24 | })); 25 | 26 | model.add(tf.layers.dense({ 27 | units: 1, 28 | activation: 'softsign' 29 | })); 30 | 31 | model.summary(); 32 | 33 | new ModelView(model, { 34 | radius: 25, 35 | renderLinks: true 36 | }); 37 | 38 | setTimeout(() => { 39 | const data = model.predict(tf.tensor([[1]])).dataSync(); 40 | console.log(data) 41 | }, 10); 42 | return model 43 | } 44 | -------------------------------------------------------------------------------- /lib/tfjs-model-view.js: -------------------------------------------------------------------------------- 1 | (function webpackUniversalModuleDefinition(root, factory) { 2 | if(typeof exports === 'object' && typeof module === 'object') 3 | module.exports = factory(); 4 | else if(typeof define === 'function' && define.amd) 5 | define("tfjs-model-view", [], factory); 6 | else if(typeof exports === 'object') 7 | 
exports["tfjs-model-view"] = factory(); 8 | else 9 | root["tfjs-model-view"] = factory(); 10 | })(typeof self !== 'undefined' ? self : this, function() { 11 | return /******/ (function(modules) { // webpackBootstrap 12 | /******/ // The module cache 13 | /******/ var installedModules = {}; 14 | /******/ 15 | /******/ // The require function 16 | /******/ function __webpack_require__(moduleId) { 17 | /******/ 18 | /******/ // Check if module is in cache 19 | /******/ if(installedModules[moduleId]) { 20 | /******/ return installedModules[moduleId].exports; 21 | /******/ } 22 | /******/ // Create a new module (and put it into the cache) 23 | /******/ var module = installedModules[moduleId] = { 24 | /******/ i: moduleId, 25 | /******/ l: false, 26 | /******/ exports: {} 27 | /******/ }; 28 | /******/ 29 | /******/ // Execute the module function 30 | /******/ modules[moduleId].call(module.exports, module, module.exports, __webpack_require__); 31 | /******/ 32 | /******/ // Flag the module as loaded 33 | /******/ module.l = true; 34 | /******/ 35 | /******/ // Return the exports of the module 36 | /******/ return module.exports; 37 | /******/ } 38 | /******/ 39 | /******/ 40 | /******/ // expose the modules object (__webpack_modules__) 41 | /******/ __webpack_require__.m = modules; 42 | /******/ 43 | /******/ // expose the module cache 44 | /******/ __webpack_require__.c = installedModules; 45 | /******/ 46 | /******/ // define getter function for harmony exports 47 | /******/ __webpack_require__.d = function(exports, name, getter) { 48 | /******/ if(!__webpack_require__.o(exports, name)) { 49 | /******/ Object.defineProperty(exports, name, { enumerable: true, get: getter }); 50 | /******/ } 51 | /******/ }; 52 | /******/ 53 | /******/ // define __esModule on exports 54 | /******/ __webpack_require__.r = function(exports) { 55 | /******/ if(typeof Symbol !== 'undefined' && Symbol.toStringTag) { 56 | /******/ Object.defineProperty(exports, Symbol.toStringTag, { value: 
'Module' }); 57 | /******/ } 58 | /******/ Object.defineProperty(exports, '__esModule', { value: true }); 59 | /******/ }; 60 | /******/ 61 | /******/ // create a fake namespace object 62 | /******/ // mode & 1: value is a module id, require it 63 | /******/ // mode & 2: merge all properties of value into the ns 64 | /******/ // mode & 4: return value when already ns object 65 | /******/ // mode & 8|1: behave like require 66 | /******/ __webpack_require__.t = function(value, mode) { 67 | /******/ if(mode & 1) value = __webpack_require__(value); 68 | /******/ if(mode & 8) return value; 69 | /******/ if((mode & 4) && typeof value === 'object' && value && value.__esModule) return value; 70 | /******/ var ns = Object.create(null); 71 | /******/ __webpack_require__.r(ns); 72 | /******/ Object.defineProperty(ns, 'default', { enumerable: true, value: value }); 73 | /******/ if(mode & 2 && typeof value != 'string') for(var key in value) __webpack_require__.d(ns, key, function(key) { return value[key]; }.bind(null, key)); 74 | /******/ return ns; 75 | /******/ }; 76 | /******/ 77 | /******/ // getDefaultExport function for compatibility with non-harmony modules 78 | /******/ __webpack_require__.n = function(module) { 79 | /******/ var getter = module && module.__esModule ? 
80 | /******/ function getDefault() { return module['default']; } : 81 | /******/ function getModuleExports() { return module; }; 82 | /******/ __webpack_require__.d(getter, 'a', getter); 83 | /******/ return getter; 84 | /******/ }; 85 | /******/ 86 | /******/ // Object.prototype.hasOwnProperty.call 87 | /******/ __webpack_require__.o = function(object, property) { return Object.prototype.hasOwnProperty.call(object, property); }; 88 | /******/ 89 | /******/ // __webpack_public_path__ 90 | /******/ __webpack_require__.p = ""; 91 | /******/ 92 | /******/ 93 | /******/ // Load entry module and return exports 94 | /******/ return __webpack_require__(__webpack_require__.s = "./src/index.js"); 95 | /******/ }) 96 | /************************************************************************/ 97 | /******/ ({ 98 | 99 | /***/ "./src/default.config.js": 100 | /*!*******************************!*\ 101 | !*** ./src/default.config.js ***! 102 | \*******************************/ 103 | /*! no static exports found */ 104 | /***/ (function(module, exports, __webpack_require__) { 105 | 106 | "use strict"; 107 | 108 | 109 | Object.defineProperty(exports, "__esModule", { 110 | value: true 111 | }); 112 | exports.default = { 113 | renderer: 'canvas', 114 | 115 | radius: 6, 116 | nodePadding: 2, 117 | layerPadding: 20, 118 | groupPadding: 1, 119 | 120 | xPadding: 10, 121 | yPadding: 10, 122 | 123 | renderLinks: false, 124 | plotActivations: false, 125 | nodeStroke: true, 126 | 127 | onRendererInitialized: function onRendererInitialized(renderer) { 128 | document.body.appendChild(renderer.canvas); 129 | } 130 | }; 131 | module.exports = exports['default']; 132 | 133 | /***/ }), 134 | 135 | /***/ "./src/index.js": 136 | /*!**********************!*\ 137 | !*** ./src/index.js ***! 138 | \**********************/ 139 | /*! 
no static exports found */ 140 | /***/ (function(module, exports, __webpack_require__) { 141 | 142 | "use strict"; 143 | 144 | 145 | Object.defineProperty(exports, "__esModule", { 146 | value: true 147 | }); 148 | 149 | var _modelParser = __webpack_require__(/*! ./model-parser */ "./src/model-parser.js"); 150 | 151 | var _modelParser2 = _interopRequireDefault(_modelParser); 152 | 153 | var _canvas = __webpack_require__(/*! ./renderers/canvas.renderer */ "./src/renderers/canvas.renderer.js"); 154 | 155 | var _canvas2 = _interopRequireDefault(_canvas); 156 | 157 | var _default = __webpack_require__(/*! ./default.config */ "./src/default.config.js"); 158 | 159 | var _default2 = _interopRequireDefault(_default); 160 | 161 | function _interopRequireDefault(obj) { return obj && obj.__esModule ? obj : { default: obj }; } 162 | 163 | function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } } 164 | 165 | var ModelView = function ModelView(model, customConfig) { 166 | _classCallCheck(this, ModelView); 167 | 168 | var config = Object.assign({}, _default2.default, customConfig); 169 | var onRendererInitialized = config.onRendererInitialized; 170 | 171 | var renderer = void 0; 172 | 173 | config.predictCallback = function (input) { 174 | if (renderer) { 175 | renderer.update(model, input); 176 | renderer.render(); 177 | } 178 | }; 179 | 180 | config.hookCallback = function (layer) { 181 | if (renderer) { 182 | renderer.updateValues(layer); 183 | renderer.render(); 184 | } 185 | }; 186 | 187 | (0, _modelParser2.default)(model, config).then(function (res) { 188 | renderer = new _canvas2.default(config, res); 189 | if (onRendererInitialized) { 190 | onRendererInitialized(renderer); 191 | } 192 | }); 193 | }; 194 | 195 | exports.default = ModelView; 196 | module.exports = exports['default']; 197 | 198 | /***/ }), 199 | 200 | /***/ "./src/model-parser.js": 201 | 
/*!*****************************!*\ 202 | !*** ./src/model-parser.js ***! 203 | \*****************************/ 204 | /*! no static exports found */ 205 | /***/ (function(module, exports, __webpack_require__) { 206 | 207 | "use strict"; 208 | 209 | 210 | Object.defineProperty(exports, "__esModule", { 211 | value: true 212 | }); 213 | 214 | var _extends = Object.assign || function (target) { for (var i = 1; i < arguments.length; i++) { var source = arguments[i]; for (var key in source) { if (Object.prototype.hasOwnProperty.call(source, key)) { target[key] = source[key]; } } } return target; }; 215 | 216 | var parseModel = function () { 217 | var _ref = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3(model, options) { 218 | var parsed, parserConfig, parseLayer, predict, layerArr; 219 | return regeneratorRuntime.wrap(function _callee3$(_context3) { 220 | while (1) { 221 | switch (_context3.prev = _context3.next) { 222 | case 0: 223 | parseLayer = function parseLayer(layer, nextColumn) { 224 | var _this = this; 225 | 226 | var name = layer.name, 227 | input = layer.input, 228 | inputs = layer.inputs, 229 | shape = layer.shape, 230 | sourceLayer = layer.sourceLayer; 231 | 232 | var _ref2 = sourceLayer || {}, 233 | getWeights = _ref2.getWeights, 234 | setCallHook = _ref2.setCallHook, 235 | activation = _ref2.activation; 236 | 237 | var currentLayer = { 238 | previousColumn: [], 239 | name: name, 240 | shape: shape, 241 | weights: {}, 242 | getWeights: noop, 243 | mapPosition: Object.keys(parsed.layerMap).length 244 | }; 245 | 246 | parsed.layerMap[name] = currentLayer; 247 | parsed.layerArr.unshift(currentLayer); 248 | 249 | if (activation) { 250 | var className = activation.getClassName(); 251 | currentLayer.activation = { 252 | name: className 253 | }; 254 | } 255 | 256 | if (setCallHook) { 257 | sourceLayer.setCallHook(function () { 258 | var _ref3 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(layerInput) { 259 | 
var i; 260 | return regeneratorRuntime.wrap(function _callee$(_context) { 261 | while (1) { 262 | switch (_context.prev = _context.next) { 263 | case 0: 264 | currentLayer.getWeights(); 265 | currentLayer.activations = []; 266 | for (i = 0; i < layerInput.length; i++) { 267 | currentLayer.activations.push(layerInput[i].dataSync()); 268 | } 269 | parserConfig.hookCallback(currentLayer); 270 | 271 | case 4: 272 | case 'end': 273 | return _context.stop(); 274 | } 275 | } 276 | }, _callee, _this); 277 | })); 278 | 279 | return function (_x3) { 280 | return _ref3.apply(this, arguments); 281 | }; 282 | }()); 283 | } 284 | 285 | if (getWeights) { 286 | 287 | currentLayer.getWeights = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2() { 288 | var weights, i, weight, rankType, weightName; 289 | return regeneratorRuntime.wrap(function _callee2$(_context2) { 290 | while (1) { 291 | switch (_context2.prev = _context2.next) { 292 | case 0: 293 | _context2.next = 2; 294 | return sourceLayer.getWeights(); 295 | 296 | case 2: 297 | weights = _context2.sent; 298 | i = 0; 299 | 300 | case 4: 301 | if (!(i < weights.length)) { 302 | _context2.next = 16; 303 | break; 304 | } 305 | 306 | weight = weights[i]; 307 | rankType = weight.rankType, weightName = weight.name; 308 | 309 | currentLayer.hasWeights = true; 310 | _context2.t0 = weightName; 311 | _context2.next = 11; 312 | return weights[i].dataSync(); 313 | 314 | case 11: 315 | _context2.t1 = _context2.sent; 316 | currentLayer.weights[rankType] = { 317 | name: _context2.t0, 318 | values: _context2.t1 319 | }; 320 | 321 | case 13: 322 | i++; 323 | _context2.next = 4; 324 | break; 325 | 326 | case 16: 327 | case 'end': 328 | return _context2.stop(); 329 | } 330 | } 331 | }, _callee2, _this); 332 | })); 333 | 334 | currentLayer.getWeights(); 335 | } 336 | 337 | if (inputs) { 338 | inputs.forEach(function (inp) { 339 | parseLayer(inp, currentLayer.previousColumn); 340 | }); 341 | } else { 342 | parseLayer(input, 
currentLayer.previousColumn); 343 | } 344 | 345 | if (nextColumn) { 346 | nextColumn.push(currentLayer); 347 | } 348 | 349 | return currentLayer; 350 | }; 351 | 352 | parsed = { 353 | layerMap: {}, 354 | layerArr: [] 355 | }; 356 | parserConfig = _extends({ 357 | predictCallback: noop, 358 | hookCallback: noop 359 | }, options); 360 | predict = model.predict; 361 | 362 | 363 | model.predict = function () { 364 | for (var _len = arguments.length, args = Array(_len), _key = 0; _key < _len; _key++) { 365 | args[_key] = arguments[_key]; 366 | } 367 | 368 | var result = predict.apply(model, args); 369 | model.outputData = result.dataSync(); 370 | parserConfig.predictCallback(args); 371 | return result; 372 | }; 373 | 374 | _context3.next = 7; 375 | return parseLayer(model.layers[model.layers.length - 1].output); 376 | 377 | case 7: 378 | parsed.model = _context3.sent; 379 | 380 | 381 | if (options.printStats) { 382 | layerArr = parsed.layerArr; 383 | 384 | console.log(new Array(10).join('-')); 385 | layerArr.forEach(function (layer) { 386 | console.log('Layer: ' + layer.name); 387 | }); 388 | } 389 | 390 | return _context3.abrupt('return', parsed); 391 | 392 | case 10: 393 | case 'end': 394 | return _context3.stop(); 395 | } 396 | } 397 | }, _callee3, this); 398 | })); 399 | 400 | return function parseModel(_x, _x2) { 401 | return _ref.apply(this, arguments); 402 | }; 403 | }(); 404 | 405 | function _asyncToGenerator(fn) { return function () { var gen = fn.apply(this, arguments); return new Promise(function (resolve, reject) { function step(key, arg) { try { var info = gen[key](arg); var value = info.value; } catch (error) { reject(error); return; } if (info.done) { resolve(value); } else { return Promise.resolve(value).then(function (value) { step("next", value); }, function (err) { step("throw", err); }); } } return step("next"); }); }; } 406 | 407 | function noop() {} 408 | 409 | exports.default = parseModel; 410 | module.exports = exports['default']; 411 | 412 | 
/***/ }), 413 | 414 | /***/ "./src/renderers/abstract.renderer.js": 415 | /*!********************************************!*\ 416 | !*** ./src/renderers/abstract.renderer.js ***! 417 | \********************************************/ 418 | /*! no static exports found */ 419 | /***/ (function(module, exports, __webpack_require__) { 420 | 421 | "use strict"; 422 | 423 | 424 | Object.defineProperty(exports, "__esModule", { 425 | value: true 426 | }); 427 | 428 | var _slicedToArray = function () { function sliceIterator(arr, i) { var _arr = []; var _n = true; var _d = false; var _e = undefined; try { for (var _i = arr[Symbol.iterator](), _s; !(_n = (_s = _i.next()).done); _n = true) { _arr.push(_s.value); if (i && _arr.length === i) break; } } catch (err) { _d = true; _e = err; } finally { try { if (!_n && _i["return"]) _i["return"](); } finally { if (_d) throw _e; } } return _arr; } return function (arr, i) { if (Array.isArray(arr)) { return arr; } else if (Symbol.iterator in Object(arr)) { return sliceIterator(arr, i); } else { throw new TypeError("Invalid attempt to destructure non-iterable instance"); } }; }(); 429 | 430 | var _createClass = function () { function defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if ("value" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } } return function (Constructor, protoProps, staticProps) { if (protoProps) defineProperties(Constructor.prototype, protoProps); if (staticProps) defineProperties(Constructor, staticProps); return Constructor; }; }(); 431 | 432 | function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } } 433 | 434 | var colors = [[6, 57, 143], [0, 107, 92], [216, 139, 0], [180, 0, 85], [106, 2, 143], [216, 109, 0], [2, 105, 134], 
[0, 142, 103], [201, 0, 39], [139, 11, 215], [171, 141, 0]]; 435 | 436 | var AbstractRenderer = function () { 437 | function AbstractRenderer(config, initData) { 438 | var _this = this; 439 | 440 | _classCallCheck(this, AbstractRenderer); 441 | 442 | var xPadding = config.xPadding, 443 | yPadding = config.yPadding, 444 | xOffset = config.xOffset, 445 | _config$layer = config.layer, 446 | layer = _config$layer === undefined ? {} : _config$layer; 447 | var layerArr = initData.layerArr; 448 | 449 | 450 | var maxHeight = (yPadding || 1) * 2; 451 | var cx = (xPadding || 0) + (xOffset || 0); 452 | 453 | function processColumn(lyr) { 454 | var col = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : 0; 455 | 456 | lyr.column = col; 457 | lyr.previousColumn.forEach(function (l) { 458 | processColumn(l, col + 1); 459 | }); 460 | } 461 | 462 | processColumn(layerArr[layerArr.length - 1]); 463 | 464 | layerArr.forEach(function (l, lindex) { 465 | var name = l.name, 466 | shape = l.shape, 467 | previousColumn = l.previousColumn; 468 | 469 | _this.outputLayer = l; 470 | var customConfig = layer[name] || {}; 471 | 472 | var layerConfig = Object.assign({}, config, customConfig); 473 | var radius = layerConfig.radius, 474 | nodePadding = layerConfig.nodePadding, 475 | layerPadding = layerConfig.layerPadding, 476 | groupPadding = layerConfig.groupPadding, 477 | _layerConfig$domain = layerConfig.domain, 478 | domain = _layerConfig$domain === undefined ? [0, 1] : _layerConfig$domain, 479 | renderLinks = layerConfig.renderLinks, 480 | renderNode = layerConfig.renderNode, 481 | nodeStroke = layerConfig.nodeStroke, 482 | reshape = layerConfig.reshape; 483 | 484 | 485 | var color = layerConfig.color || (lindex < colors.length ? 
colors[lindex] : [0, 0, 0]); 486 | 487 | var _Object$assign = Object.assign([1, 1, 1], shape.slice(1)), 488 | _Object$assign2 = _slicedToArray(_Object$assign, 3), 489 | rows = _Object$assign2[0], 490 | cols = _Object$assign2[1], 491 | groups = _Object$assign2[2]; 492 | 493 | var totalNodes = rows * cols * groups; 494 | 495 | if (reshape) { 496 | var _Object$assign3 = Object.assign([1, 1, 1], reshape), 497 | _Object$assign4 = _slicedToArray(_Object$assign3, 3), 498 | nr = _Object$assign4[0], 499 | nc = _Object$assign4[1], 500 | ng = _Object$assign4[2]; 501 | 502 | if (nr * nc * ng !== totalNodes) { 503 | throw new Error("Unable to reshape from [" + rows + ", " + cols + ", " + groups + "] to [" + nr + ", " + nc + ", " + ng + "]"); 504 | } 505 | 506 | rows = nr; 507 | cols = nc; 508 | groups = ng; 509 | } 510 | 511 | cx += layerPadding; 512 | 513 | var step = radius + nodePadding; 514 | var width = layerPadding + cols * step; 515 | var nodes = []; 516 | var height = 0; 517 | 518 | for (var row = 0; row < rows; row++) { 519 | for (var col = 0; col < cols; col++) { 520 | for (var group = 0; group < groups; group++) { 521 | var y = radius + row * step + group * rows * step + group * groupPadding; 522 | nodes.push({ 523 | x: cx + col * step, 524 | y: y, 525 | radius: radius 526 | }); 527 | height = y; 528 | } 529 | } 530 | } 531 | 532 | height += groupPadding + radius; 533 | maxHeight = Math.max(maxHeight, height); 534 | 535 | Object.assign(l, { 536 | name: name, 537 | x: cx, 538 | layerWidth: width, 539 | layerHeight: height, 540 | radius: radius, 541 | nodes: nodes, 542 | domain: domain, 543 | renderLinks: renderLinks, 544 | renderNode: renderNode, 545 | nodeStroke: nodeStroke, 546 | color: color, 547 | previousLayers: previousColumn.map(function (lyr) { 548 | return lyr.name; 549 | }) 550 | }); 551 | 552 | cx += width; 553 | }); 554 | 555 | cx += xPadding || 0; 556 | 557 | layerArr.forEach(function (l) { 558 | var offsetY = Math.floor((maxHeight - l.layerHeight) / 2); 
559 | l.nodes.forEach(function (nd) { 560 | return nd.y += offsetY; 561 | }); 562 | }); 563 | 564 | Object.assign(this, { width: cx, height: maxHeight }); 565 | 566 | this.layers = layerArr; 567 | this.layersMap = layerArr.reduce(function (memo, item) { 568 | memo[item.name] = item; 569 | return memo; 570 | }, {}); 571 | } 572 | 573 | _createClass(AbstractRenderer, [{ 574 | key: "update", 575 | value: function update(model, input) { 576 | var _this2 = this; 577 | 578 | if (input) { 579 | model.inputs.forEach(function (inputLayer, index) { 580 | var syntheticLayer = _this2.layersMap[inputLayer.name]; 581 | _this2.updateLayerValues(syntheticLayer, input[index].dataSync()); 582 | }); 583 | } 584 | 585 | this.updateLayerValues(this.outputLayer, model.outputData); 586 | } 587 | }, { 588 | key: "updateLayerValues", 589 | value: function updateLayerValues(layer, data) { 590 | for (var i = 0; i < layer.nodes.length; i++) { 591 | layer.nodes[i].value = data[i]; 592 | } 593 | } 594 | }, { 595 | key: "updateValues", 596 | value: function updateValues(layer) { 597 | var _this3 = this; 598 | 599 | var syntheticLayer = this.layersMap[layer.name]; 600 | syntheticLayer.weights = layer.weights; 601 | syntheticLayer.previousColumn.forEach(function (col, idx) { 602 | _this3.updateLayerValues(col, layer.activations[idx]); 603 | }); 604 | } 605 | }]); 606 | 607 | return AbstractRenderer; 608 | }(); 609 | 610 | exports.default = AbstractRenderer; 611 | module.exports = exports["default"]; 612 | 613 | /***/ }), 614 | 615 | /***/ "./src/renderers/canvas.renderer.js": 616 | /*!******************************************!*\ 617 | !*** ./src/renderers/canvas.renderer.js ***! 618 | \******************************************/ 619 | /*! 
no static exports found */ 620 | /***/ (function(module, exports, __webpack_require__) { 621 | 622 | "use strict"; 623 | 624 | 625 | Object.defineProperty(exports, "__esModule", { 626 | value: true 627 | }); 628 | 629 | var _slicedToArray = function () { function sliceIterator(arr, i) { var _arr = []; var _n = true; var _d = false; var _e = undefined; try { for (var _i = arr[Symbol.iterator](), _s; !(_n = (_s = _i.next()).done); _n = true) { _arr.push(_s.value); if (i && _arr.length === i) break; } } catch (err) { _d = true; _e = err; } finally { try { if (!_n && _i["return"]) _i["return"](); } finally { if (_d) throw _e; } } return _arr; } return function (arr, i) { if (Array.isArray(arr)) { return arr; } else if (Symbol.iterator in Object(arr)) { return sliceIterator(arr, i); } else { throw new TypeError("Invalid attempt to destructure non-iterable instance"); } }; }(); 630 | 631 | var _createClass = function () { function defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if ("value" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } } return function (Constructor, protoProps, staticProps) { if (protoProps) defineProperties(Constructor.prototype, protoProps); if (staticProps) defineProperties(Constructor, staticProps); return Constructor; }; }(); 632 | 633 | var _abstract = __webpack_require__(/*! ./abstract.renderer */ "./src/renderers/abstract.renderer.js"); 634 | 635 | var _abstract2 = _interopRequireDefault(_abstract); 636 | 637 | function _interopRequireDefault(obj) { return obj && obj.__esModule ? 
obj : { default: obj }; } 638 | 639 | function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } } 640 | 641 | function _possibleConstructorReturn(self, call) { if (!self) { throw new ReferenceError("this hasn't been initialised - super() hasn't been called"); } return call && (typeof call === "object" || typeof call === "function") ? call : self; } 642 | 643 | function _inherits(subClass, superClass) { if (typeof superClass !== "function" && superClass !== null) { throw new TypeError("Super expression must either be null or a function, not " + typeof superClass); } subClass.prototype = Object.create(superClass && superClass.prototype, { constructor: { value: subClass, enumerable: false, writable: true, configurable: true } }); if (superClass) Object.setPrototypeOf ? Object.setPrototypeOf(subClass, superClass) : subClass.__proto__ = superClass; } 644 | 645 | var CanvasRenderer = function (_AbstractRenderer) { 646 | _inherits(CanvasRenderer, _AbstractRenderer); 647 | 648 | function CanvasRenderer(config, initData) { 649 | _classCallCheck(this, CanvasRenderer); 650 | 651 | var _this = _possibleConstructorReturn(this, (CanvasRenderer.__proto__ || Object.getPrototypeOf(CanvasRenderer)).call(this, config, initData)); 652 | 653 | var canvas = document.createElement('canvas'); 654 | 655 | var onBeginRender = config.onBeginRender, 656 | onEndRender = config.onEndRender; 657 | 658 | 659 | Object.assign(_this, { canvas: canvas, onBeginRender: onBeginRender, onEndRender: onEndRender }); 660 | 661 | canvas.setAttribute('width', _this.width); 662 | canvas.setAttribute('height', _this.height); 663 | 664 | _this.renderContext = canvas.getContext('2d'); 665 | _this.renderElement = canvas; 666 | return _this; 667 | } 668 | 669 | _createClass(CanvasRenderer, [{ 670 | key: 'render', 671 | value: function render() { 672 | var _this2 = this; 673 | 674 | window.requestAnimationFrame(function () 
{ 675 | var onBeginRender = _this2.onBeginRender, 676 | onEndRender = _this2.onEndRender; 677 | 678 | _this2.renderContext.clearRect(0, 0, _this2.width, _this2.height); 679 | 680 | if (onBeginRender) { 681 | onBeginRender(_this2); 682 | } 683 | 684 | _this2.layers.forEach(function (layer) { 685 | var radius = layer.radius, 686 | nodes = layer.nodes, 687 | _layer$domain = _slicedToArray(layer.domain, 2), 688 | min = _layer$domain[0], 689 | max = _layer$domain[1], 690 | previousColumn = layer.previousColumn, 691 | renderLinks = layer.renderLinks, 692 | renderNode = layer.renderNode, 693 | weights = layer.weights, 694 | _layer$color = _slicedToArray(layer.color, 3), 695 | r = _layer$color[0], 696 | g = _layer$color[1], 697 | b = _layer$color[2], 698 | nodeStroke = layer.nodeStroke; 699 | 700 | var kernel = weights[2]; 701 | 702 | var leftSideNodes = void 0; 703 | 704 | if (renderLinks) { 705 | leftSideNodes = previousColumn.reduce(function (memo, prevLayer) { 706 | return memo.concat(prevLayer.nodes); 707 | }, []); 708 | } 709 | 710 | nodes.forEach(function (node, index) { 711 | var nx = node.x, 712 | ny = node.y, 713 | value = node.value; 714 | 715 | 716 | if (renderLinks) { 717 | leftSideNodes.forEach(function (leftNode, leftIdx) { 718 | _this2.renderContext.beginPath(); 719 | var hasWeight = kernel && kernel.values; 720 | var weightVal = hasWeight ? kernel.values[index * leftIdx] : 0.5; 721 | if (hasWeight) { 722 | _this2.renderContext.strokeStyle = weightVal > 0 ? 
'rgb(0, 0, 255, ' + weightVal + ')' : 'rgb(255, 0, 0, ' + Math.abs(weightVal) + ')'; 723 | } else { 724 | _this2.renderContext.strokeStyle = 'rgba(0,0,0,.5)'; 725 | } 726 | _this2.renderContext.moveTo(leftNode.x + leftNode.radius / 2, leftNode.y); 727 | _this2.renderContext.lineTo(node.x - node.radius / 2, node.y); 728 | _this2.renderContext.stroke(); 729 | }); 730 | } 731 | _this2.renderContext.strokeStyle = 'rgb(' + r + ', ' + g + ', ' + b + ')'; 732 | var domainValue = value / (max + min); 733 | if (!isNaN(domainValue)) { 734 | _this2.renderContext.fillStyle = 'rgba(' + r + ', ' + g + ', ' + b + ', ' + domainValue + ')'; 735 | } else { 736 | _this2.renderContext.fillStyle = '#FFF'; 737 | } 738 | _this2.renderContext.beginPath(); 739 | _this2.renderContext.arc(nx, ny, radius / 2, 0, 2 * Math.PI); 740 | if (radius > 3 && nodeStroke) { 741 | _this2.renderContext.stroke(); 742 | } 743 | _this2.renderContext.fill(); 744 | 745 | if (renderNode) { 746 | renderNode(_this2.renderContext, node, index); 747 | } 748 | }); 749 | }); 750 | 751 | if (onEndRender) { 752 | onEndRender(_this2); 753 | } 754 | }); 755 | } 756 | }]); 757 | 758 | return CanvasRenderer; 759 | }(_abstract2.default); 760 | 761 | exports.default = CanvasRenderer; 762 | module.exports = exports['default']; 763 | 764 | /***/ }) 765 | 766 | /******/ }); 767 | }); 768 | //# sourceMappingURL=data:application/json;charset=utf-8;base64, -------------------------------------------------------------------------------- /lib/tfjs-model-view.min.js: -------------------------------------------------------------------------------- 1 | !function(e,t){"object"==typeof exports&&"object"==typeof module?module.exports=t():"function"==typeof define&&define.amd?define("tfjs-model-view",[],t):"object"==typeof exports?exports["tfjs-model-view"]=t():e["tfjs-model-view"]=t()}("undefined"!=typeof self?self:this,(function(){return function(e){var t={};function r(n){if(t[n])return t[n].exports;var 
o=t[n]={i:n,l:!1,exports:{}};return e[n].call(o.exports,o,o.exports,r),o.l=!0,o.exports}return r.m=e,r.c=t,r.d=function(e,t,n){r.o(e,t)||Object.defineProperty(e,t,{enumerable:!0,get:n})},r.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},r.t=function(e,t){if(1&t&&(e=r(e)),8&t)return e;if(4&t&&"object"==typeof e&&e&&e.__esModule)return e;var n=Object.create(null);if(r.r(n),Object.defineProperty(n,"default",{enumerable:!0,value:e}),2&t&&"string"!=typeof e)for(var o in e)r.d(n,o,function(t){return e[t]}.bind(null,o));return n},r.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return r.d(t,"a",t),t},r.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r.p="",r(r.s=0)}([function(e,t,r){"use strict";Object.defineProperty(t,"__esModule",{value:!0});var n=i(r(1)),o=i(r(2)),a=i(r(4));function i(e){return e&&e.__esModule?e:{default:e}}t.default=function e(t,r){!function(e,t){if(!(e instanceof t))throw new TypeError("Cannot call a class as a function")}(this,e);var i=Object.assign({},a.default,r),u=i.onRendererInitialized,s=void 0;i.predictCallback=function(e){s&&(s.update(t,e),s.render())},i.hookCallback=function(e){s&&(s.updateValues(e),s.render())},(0,n.default)(t,i).then((function(e){s=new o.default(i,e),u&&u(s)}))},e.exports=t.default},function(e,t,r){"use strict";Object.defineProperty(t,"__esModule",{value:!0});var n,o=Object.assign||function(e){for(var t=1;t0?"rgb(0, 0, 255, "+i+")":"rgb(255, 0, 0, "+Math.abs(i)+")":"rgba(0,0,0,.5)",e.renderContext.moveTo(r.x+r.radius/2,r.y),e.renderContext.lineTo(t.x-t.radius/2,t.y),e.renderContext.stroke()})),e.renderContext.strokeStyle="rgb("+p+", "+h+", "+y+")";var f=s/(u+i);isNaN(f)?e.renderContext.fillStyle="#FFF":e.renderContext.fillStyle="rgba("+p+", "+h+", "+y+", 
"+f+")",e.renderContext.beginPath(),e.renderContext.arc(o,a,r/2,0,2*Math.PI),r>3&&v&&e.renderContext.stroke(),e.renderContext.fill(),l&&l(e.renderContext,t,n)}))})),r&&r(e)}))}}]),t}(((n=i)&&n.__esModule?n:{default:n}).default);t.default=u,e.exports=t.default},function(e,t,r){"use strict";Object.defineProperty(t,"__esModule",{value:!0});var n=function(e,t){if(Array.isArray(e))return e;if(Symbol.iterator in Object(e))return function(e,t){var r=[],n=!0,o=!1,a=void 0;try{for(var i,u=e[Symbol.iterator]();!(n=(i=u.next()).done)&&(r.push(i.value),!t||r.length!==t);n=!0);}catch(e){o=!0,a=e}finally{try{!n&&u.return&&u.return()}finally{if(o)throw a}}return r}(e,t);throw new TypeError("Invalid attempt to destructure non-iterable instance")},o=function(){function e(e,t){for(var r=0;r1&&void 0!==arguments[1]?arguments[1]:0;t.column=r,t.previousColumn.forEach((function(t){e(t,r+1)}))}(f[f.length-1]),f.forEach((function(e,r){var i=e.name,u=e.shape,s=e.previousColumn;o.outputLayer=e;var c=l[i]||{},f=Object.assign({},t,c),h=f.radius,y=f.nodePadding,v=f.layerPadding,g=f.groupPadding,b=f.domain,m=void 0===b?[0,1]:b,x=f.renderLinks,w=f.renderNode,j=f.nodeStroke,k=f.reshape,C=f.color||(r stats.json && webpack-bundle-analyzer ./stats.json" 13 | }, 14 | "repository": { 15 | "type": "git", 16 | "url": "https://github.com/cstefanache/tfjs-model-view" 17 | }, 18 | "keywords": [ 19 | "tensorflow", 20 | "tfjs", 21 | "webpack", 22 | "es6", 23 | "library", 24 | "universal", 25 | "umd", 26 | "commonjs" 27 | ], 28 | "author": "Cornel Stefanache", 29 | "license": "MIT", 30 | "bugs": { 31 | "url": "https://github.com/cstefanache/tfjs-model-view/issues" 32 | }, 33 | "devDependencies": { 34 | "babel-cli": "^6.26.0", 35 | "babel-core": "^6.26.3", 36 | "babel-eslint": "^8.2.3", 37 | "babel-loader": "^7.1.4", 38 | "babel-plugin-add-module-exports": "^0.2.1", 39 | "babel-plugin-transform-es2015-destructuring": "^6.23.0", 40 | "babel-plugin-transform-object-rest-spread": "^6.26.0", 41 | 
"babel-preset-env": "^1.7.0", 42 | "chai": "^4.1.2", 43 | "cross-env": "^5.2.0", 44 | "eslint": "^5.0.1", 45 | "eslint-loader": "^2.0.0", 46 | "html-webpack-plugin": "^3.2.0", 47 | "jsdom": "11.11.0", 48 | "jsdom-global": "3.0.2", 49 | "mocha": "^4.0.1", 50 | "nyc": "^13.1.0", 51 | "raw-loader": "^4.0.0", 52 | "uglifyjs-webpack-plugin": "^1.2.7", 53 | "webpack": "^4.12.2", 54 | "webpack-cli": "^3.0.8", 55 | "webpack-dev-server": "^3.1.14", 56 | "yargs": "^10.0.3" 57 | }, 58 | "nyc": { 59 | "sourceMap": false, 60 | "instrument": false 61 | }, 62 | "dependencies": { 63 | "@tensorflow/tfjs": "^0.15.0", 64 | "webpack-bundle-analyzer": "^3.0.4" 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /src/default.config.js: -------------------------------------------------------------------------------- 1 | export default { 2 | renderer: 'canvas', 3 | 4 | radius: 6, 5 | nodePadding: 2, 6 | layerPadding: 20, 7 | groupPadding: 1, 8 | 9 | xPadding: 10, 10 | yPadding: 10, 11 | 12 | renderLinks: false, 13 | plotActivations: false, 14 | nodeStroke: true, 15 | 16 | onRendererInitialized: renderer => { 17 | document.body.appendChild(renderer.canvas); 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /src/index.js: -------------------------------------------------------------------------------- 1 | import parseModel from './model-parser'; 2 | import CanvasRenderer from './renderers/canvas.renderer'; 3 | import defaultConfig from './default.config'; 4 | 5 | export default class ModelView { 6 | 7 | constructor(model, customConfig) { 8 | const config = Object.assign({}, defaultConfig, customConfig); 9 | const { onRendererInitialized } = config; 10 | let renderer; 11 | 12 | config.predictCallback = input => { 13 | if (renderer) { 14 | renderer.update(model, input); 15 | renderer.render(); 16 | } 17 | } 18 | 19 | config.hookCallback = layer => { 20 | if (renderer) { 21 | 
renderer.updateValues(layer); 22 | renderer.render(); 23 | } 24 | } 25 | 26 | parseModel(model, config).then(res => { 27 | renderer = new CanvasRenderer(config, res); 28 | if (onRendererInitialized) { 29 | onRendererInitialized(renderer); 30 | } 31 | }); 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /src/model-parser.js: -------------------------------------------------------------------------------- 1 | function noop() { } 2 | 3 | async function parseModel(model, options) { 4 | 5 | const parsed = { 6 | layerMap: {}, 7 | layerArr: [] 8 | }; 9 | 10 | const parserConfig = { 11 | predictCallback: noop, 12 | hookCallback: noop, 13 | ...options 14 | } 15 | 16 | function parseLayer(layer, nextColumn) { 17 | 18 | const { 19 | name, 20 | input, 21 | inputs, 22 | shape, 23 | sourceLayer 24 | } = layer; 25 | 26 | const { 27 | getWeights, 28 | setCallHook, 29 | activation 30 | } = sourceLayer || {}; 31 | 32 | const currentLayer = { 33 | previousColumn: [], 34 | name, 35 | shape, 36 | weights: {}, 37 | getWeights: noop, 38 | mapPosition: Object.keys(parsed.layerMap).length 39 | }; 40 | 41 | parsed.layerMap[name] = currentLayer; 42 | parsed.layerArr.unshift(currentLayer); 43 | 44 | if (activation) { 45 | let className = activation.getClassName(); 46 | currentLayer.activation = { 47 | name: className 48 | } 49 | } 50 | 51 | if (setCallHook) { 52 | sourceLayer.setCallHook(async layerInput => { 53 | currentLayer.getWeights(); 54 | currentLayer.activations = [] 55 | for (let i = 0; i < layerInput.length; i++) { 56 | currentLayer.activations.push(layerInput[i].dataSync()) 57 | } 58 | parserConfig.hookCallback(currentLayer); 59 | }); 60 | } 61 | 62 | if (getWeights) { 63 | 64 | currentLayer.getWeights = async () => { 65 | const weights = await sourceLayer.getWeights(); 66 | 67 | for (let i = 0; i < weights.length; i++) { 68 | const weight = weights[i]; 69 | const { 70 | rankType, 71 | name: weightName 72 | } = weight; 73 | 
currentLayer.hasWeights = true; 74 | currentLayer.weights[rankType] = { 75 | name: weightName, 76 | values: await weights[i].dataSync() 77 | } 78 | } 79 | } 80 | 81 | currentLayer.getWeights(); 82 | } 83 | 84 | if (inputs) { 85 | inputs.forEach(inp => { 86 | parseLayer(inp, currentLayer.previousColumn); 87 | }) 88 | } else { 89 | parseLayer(input, currentLayer.previousColumn); 90 | } 91 | 92 | if (nextColumn) { 93 | nextColumn.push(currentLayer); 94 | } 95 | 96 | return currentLayer; 97 | } 98 | 99 | const predict = model.predict; 100 | 101 | model.predict = (...args) => { 102 | const result = predict.apply(model, args); 103 | model.outputData = result.dataSync(); 104 | parserConfig.predictCallback(args); 105 | return result; 106 | }; 107 | 108 | parsed.model = await parseLayer(model.layers[model.layers.length - 1].output); 109 | 110 | if (options.printStats) { 111 | const { 112 | layerArr 113 | } = parsed; 114 | console.log(new Array(10).join('-')); 115 | layerArr.forEach(layer => { 116 | console.log(`Layer: ${layer.name}`); 117 | }); 118 | } 119 | 120 | return parsed; 121 | } 122 | 123 | export default parseModel; 124 | -------------------------------------------------------------------------------- /src/renderers/abstract.renderer.js: -------------------------------------------------------------------------------- 1 | const colors = [ 2 | [6, 57, 143], 3 | [0, 107, 92], 4 | [216, 139, 0], 5 | [180, 0, 85], 6 | [106, 2, 143], 7 | [216, 109, 0], 8 | [2, 105, 134], 9 | [0, 142, 103], 10 | [201, 0, 39], 11 | [139, 11, 215], 12 | [171, 141, 0] 13 | ] 14 | 15 | export default class AbstractRenderer { 16 | 17 | constructor(config, initData) { 18 | const { 19 | xPadding, 20 | yPadding, 21 | xOffset, 22 | layer = {} 23 | } = config; 24 | const { layerArr } = initData; 25 | 26 | let maxHeight = (yPadding || 1) * 2; 27 | let cx = (xPadding || 0) + (xOffset || 0) 28 | 29 | function processColumn(lyr, col = 0) { 30 | lyr.column = col; 31 | lyr.previousColumn.forEach(l => { 
32 | processColumn(l, col + 1); 33 | }); 34 | } 35 | 36 | processColumn(layerArr[layerArr.length - 1]); 37 | 38 | layerArr.forEach((l, lindex) => { 39 | const { name, shape, previousColumn } = l; 40 | this.outputLayer = l; 41 | const customConfig = layer[name] || {}; 42 | 43 | const layerConfig = Object.assign({}, config, customConfig) 44 | const { 45 | radius, 46 | nodePadding, 47 | layerPadding, 48 | groupPadding, 49 | domain = [0, 1], 50 | renderLinks, 51 | renderNode, 52 | nodeStroke, 53 | reshape } = layerConfig; 54 | 55 | const color = layerConfig.color || (lindex < colors.length ? colors[lindex] : [0, 0, 0]); 56 | let [rows, cols, groups] = Object.assign([1, 1, 1], shape.slice(1)); 57 | const totalNodes = rows * cols * groups; 58 | 59 | if (reshape) { 60 | const [nr, nc, ng] = Object.assign([1, 1, 1], reshape); 61 | if (nr * nc * ng !== totalNodes) { 62 | throw new Error(`Unable to reshape from [${rows}, ${cols}, ${groups}] to [${nr}, ${nc}, ${ng}]`) 63 | } 64 | 65 | rows = nr; 66 | cols = nc; 67 | groups = ng; 68 | } 69 | 70 | cx += layerPadding; 71 | 72 | const step = radius + nodePadding; 73 | const width = layerPadding + cols * step 74 | const nodes = []; 75 | let height = 0; 76 | 77 | for (let row = 0; row < rows; row++) { 78 | for (let col = 0; col < cols; col++) { 79 | for (let group = 0; group < groups; group++) { 80 | const y = radius + row * step + group * rows * step + group * groupPadding 81 | nodes.push({ 82 | x: cx + col * step, 83 | y, 84 | radius 85 | }); 86 | height = y; 87 | } 88 | } 89 | } 90 | 91 | height += groupPadding + radius; 92 | maxHeight = Math.max(maxHeight, height); 93 | 94 | Object.assign(l, { 95 | name, 96 | x: cx, 97 | layerWidth: width, 98 | layerHeight: height, 99 | radius, 100 | nodes, 101 | domain, 102 | renderLinks, 103 | renderNode, 104 | nodeStroke, 105 | color, 106 | previousLayers: previousColumn.map(lyr => lyr.name) 107 | }) 108 | 109 | cx += width; 110 | }); 111 | 112 | cx += xPadding || 0; 113 | 114 | 
layerArr.forEach(l => { 115 | const offsetY = Math.floor((maxHeight - l.layerHeight) / 2); 116 | l.nodes.forEach(nd => nd.y += offsetY); 117 | }); 118 | 119 | Object.assign(this, { width: cx, height: maxHeight }); 120 | 121 | this.layers = layerArr; 122 | this.layersMap = layerArr.reduce((memo, item) => { 123 | memo[item.name] = item; 124 | return memo; 125 | }, {}); 126 | } 127 | 128 | update(model, input) { 129 | if (input) { 130 | model.inputs.forEach((inputLayer, index) => { 131 | const syntheticLayer = this.layersMap[inputLayer.name]; 132 | this.updateLayerValues(syntheticLayer, input[index].dataSync()); 133 | }); 134 | } 135 | 136 | this.updateLayerValues(this.outputLayer, model.outputData); 137 | } 138 | 139 | updateLayerValues(layer, data) { 140 | for (let i = 0; i < layer.nodes.length; i++) { 141 | layer.nodes[i].value = data[i]; 142 | } 143 | } 144 | 145 | updateValues(layer) { 146 | const syntheticLayer = this.layersMap[layer.name]; 147 | syntheticLayer.weights = layer.weights; 148 | syntheticLayer.previousColumn.forEach((col, idx) => { 149 | this.updateLayerValues(col, layer.activations[idx]); 150 | }) 151 | 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /src/renderers/canvas.renderer.js: -------------------------------------------------------------------------------- 1 | import AbstractRenderer from './abstract.renderer'; 2 | 3 | export default class CanvasRenderer extends AbstractRenderer { 4 | constructor(config, initData) { 5 | super(config, initData); 6 | const canvas = document.createElement('canvas'); 7 | 8 | const { onBeginRender, onEndRender } = config; 9 | 10 | Object.assign(this, { canvas, onBeginRender, onEndRender }) 11 | 12 | canvas.setAttribute('width', this.width); 13 | canvas.setAttribute('height', this.height); 14 | 15 | this.renderContext = canvas.getContext('2d') 16 | this.renderElement = canvas; 17 | } 18 | 19 | render() { 20 | window.requestAnimationFrame(() => { 21 | const { 
onBeginRender, onEndRender } = this; 22 | this.renderContext.clearRect(0, 0, this.width, this.height); 23 | 24 | if (onBeginRender) { 25 | onBeginRender(this); 26 | } 27 | 28 | this.layers.forEach(layer => { 29 | const { 30 | radius, 31 | nodes, 32 | domain: [min, max], 33 | previousColumn, 34 | renderLinks, 35 | renderNode, 36 | weights, 37 | color: [r, g, b], 38 | nodeStroke 39 | } = layer; 40 | 41 | let { 2: kernel } = weights; 42 | let leftSideNodes; 43 | 44 | if (renderLinks) { 45 | leftSideNodes = previousColumn.reduce((memo, prevLayer) => memo.concat(prevLayer.nodes), []) 46 | } 47 | 48 | nodes.forEach((node, index) => { 49 | const { x: nx, y: ny, value } = node; 50 | 51 | if (renderLinks) { 52 | leftSideNodes.forEach((leftNode, leftIdx) => { 53 | this.renderContext.beginPath(); 54 | let hasWeight = kernel && kernel.values; 55 | const weightVal = hasWeight ? kernel.values[index * leftIdx] : 0.5 56 | if (hasWeight) { 57 | this.renderContext.strokeStyle = weightVal > 0 ? 58 | `rgb(0, 0, 255, ${weightVal})` : 59 | `rgb(255, 0, 0, ${Math.abs(weightVal)})`; 60 | } else { 61 | this.renderContext.strokeStyle = 'rgba(0,0,0,.5)'; 62 | } 63 | this.renderContext.moveTo(leftNode.x + leftNode.radius / 2, leftNode.y); 64 | this.renderContext.lineTo(node.x - node.radius / 2, node.y); 65 | this.renderContext.stroke(); 66 | }) 67 | } 68 | this.renderContext.strokeStyle = `rgb(${r}, ${g}, ${b})`; 69 | const domainValue = value / (max + min) 70 | if (!isNaN(domainValue)) { 71 | this.renderContext.fillStyle = `rgba(${r}, ${g}, ${b}, ${domainValue})` 72 | } else { 73 | this.renderContext.fillStyle = '#FFF'; 74 | } 75 | this.renderContext.beginPath(); 76 | this.renderContext.arc(nx, ny, radius / 2, 0, 2 * Math.PI) 77 | if (radius > 3 && nodeStroke) { 78 | this.renderContext.stroke(); 79 | } 80 | this.renderContext.fill(); 81 | 82 | if (renderNode) { 83 | renderNode(this.renderContext, node, index); 84 | } 85 | }); 86 | }) 87 | 88 | if (onEndRender) { 89 | onEndRender(this); 90 | 
} 91 | }); 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /test/index.spec.js: -------------------------------------------------------------------------------- 1 | import chai from 'chai'; 2 | 3 | chai.expect(); 4 | 5 | const expect = chai.expect; 6 | 7 | let lib; 8 | 9 | describe('Nothing', () => { 10 | before(() => { 11 | }); 12 | describe('when I need the name', () => { 13 | it('should return the name', () => { 14 | }); 15 | }); 16 | }); 17 | -------------------------------------------------------------------------------- /webpack.config.js: -------------------------------------------------------------------------------- 1 | /* global __dirname, require, module*/ 2 | 3 | const webpack = require('webpack'); 4 | const path = require('path'); 5 | const env = require('yargs').argv.env; // use --env with webpack 2 6 | const pkg = require('./package.json'); 7 | 8 | let libraryName = pkg.name; 9 | 10 | let outputFile, mode; 11 | 12 | if (env === 'build') { 13 | mode = 'production'; 14 | outputFile = `${libraryName}.min.js`; 15 | } else { 16 | mode = 'development'; 17 | outputFile = `${libraryName}.js`; 18 | } 19 | 20 | const config = { 21 | mode: mode, 22 | entry: `${__dirname}/src/index.js`, 23 | devtool: 'inline-source-map', 24 | output: { 25 | path: `${__dirname}/lib`, 26 | filename: outputFile, 27 | library: libraryName, 28 | libraryTarget: 'umd', 29 | umdNamedDefine: true, 30 | globalObject: "typeof self !== 'undefined' ? 
self : this" 31 | }, 32 | module: { 33 | rules: [{ 34 | test: /(\.jsx|\.js)$/, 35 | loader: 'babel-loader', 36 | exclude: /(node_modules|bower_components)/ 37 | }, 38 | { 39 | test: /(\.jsx|\.js)$/, 40 | loader: 'eslint-loader', 41 | exclude: /node_modules/ 42 | } 43 | ] 44 | }, 45 | resolve: { 46 | modules: [path.resolve('./node_modules'), path.resolve('./src')], 47 | extensions: ['.json', '.js'] 48 | } 49 | }; 50 | 51 | module.exports = config; 52 | -------------------------------------------------------------------------------- /webpack.dev.config.js: -------------------------------------------------------------------------------- 1 | const config = require('./webpack.config'); 2 | const HtmlWebpackPlugin = require('html-webpack-plugin'); 3 | 4 | let plugins = [new HtmlWebpackPlugin({ 5 | template: './app/index.html' 6 | })]; 7 | 8 | module.exports = Object.assign(config, { 9 | entry: `${__dirname}/app/index.js`, 10 | plugins, 11 | node: { 12 | fs: 'empty' 13 | }, 14 | devServer: { 15 | hot: true, 16 | inline: true, 17 | clientLogLevel: 'error', 18 | stats: { 19 | colors: true 20 | }, 21 | proxy: { 22 | '/static': { 23 | target: 'http://localhost:4500', 24 | pathRewrite: { 25 | '^/static': './app/models' 26 | } 27 | } 28 | }, 29 | host: '0.0.0.0', 30 | port: 4500 31 | } 32 | }); 33 | --------------------------------------------------------------------------------