├── .gitignore ├── README.md └── tfjs-vis ├── .npmignore ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── cloudbuild.yml ├── demos ├── api │ ├── .babelrc │ ├── README.md │ ├── index.css │ ├── index.html │ ├── index.js │ ├── package.json │ └── yarn.lock ├── mnist │ ├── .babelrc │ ├── README.md │ ├── data.js │ ├── index.html │ ├── index.js │ ├── model.js │ ├── package.json │ ├── tufte.css │ └── yarn.lock └── mnist_internals │ ├── .babelrc │ ├── README.md │ ├── data.js │ ├── index.html │ ├── index.js │ ├── model.js │ ├── package.json │ ├── tufte.css │ └── yarn.lock ├── docs └── visor-usage.png ├── package.json ├── scripts ├── build-npm.sh ├── deploy.sh ├── publish-npm.sh ├── tag-version.js └── test-ci.sh ├── src ├── components │ ├── surface.tsx │ ├── tabs.tsx │ ├── visor.tsx │ └── visor_test.tsx ├── index.ts ├── render │ ├── barchart.ts │ ├── barchart_test.ts │ ├── confusion_matrix.ts │ ├── confusion_matrix_test.ts │ ├── heatmap.ts │ ├── heatmap_test.ts │ ├── histogram.ts │ ├── histogram_test.ts │ ├── linechart.ts │ ├── linechart_test.ts │ ├── render_utils.ts │ ├── render_utils_test.ts │ ├── scatterplot.ts │ ├── scatterplot_tests.ts │ ├── table.ts │ └── table_test.ts ├── show │ ├── history.ts │ ├── history_test.ts │ ├── model.ts │ ├── model_test.ts │ ├── quality.ts │ ├── quality_test.ts │ ├── tensor.ts │ └── tensor_test.ts ├── types.ts ├── types │ └── glamor-tachyons │ │ └── index.d.ts ├── util │ ├── dom.ts │ ├── math.ts │ ├── math_test.ts │ └── utils.ts ├── visor.ts └── visor_test.ts ├── tsconfig.json ├── tslint.json ├── webpack.config.js └── yarn.lock /.gitignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | coverage/ 3 | npm-debug.log 4 | yarn-error.log 5 | .DS_Store 6 | dist/ 7 | .idea/ 8 | *.tgz 9 | .yalc/ 10 | yalc.lock 11 | .rpt2_cache/ 12 | package/ 13 | .cache 14 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # This repository has been archived in favor of tensorflow/tfjs. 2 | 3 | This repo will remain around for some time to keep history but all future PRs should be sent to [tensorflow/tfjs](https://github.com/tensorflow/tfjs) inside the [tfjs-vis](https://github.com/tensorflow/tfjs/tree/master/tfjs-vis) folder. 4 | 5 | All history and contributions have been preserved in the monorepo. 6 | -------------------------------------------------------------------------------- /tfjs-vis/.npmignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .rpt2_cache/ 3 | demos/ 4 | scripts/ 5 | models/ 6 | coverage/ 7 | package/ 8 | **/node_modules/ 9 | karma.conf.js 10 | *.tgz 11 | *.log 12 | cloudbuild.yml 13 | CONTRIBUTING.md 14 | tslint.json 15 | yarn.lock 16 | DEVELOPMENT.md 17 | ISSUE_TEMPLATE.md 18 | PULL_REQUEST_TEMPLATE.md 19 | webpack.config.js 20 | tsconfig.json 21 | docs/ 22 | .yalc/ 23 | -------------------------------------------------------------------------------- /tfjs-vis/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## [1.0.0] 2019-03-06 4 | Major version upgrade including a number of **breaking changes**. 5 | 6 | ### Changed 7 | All `render.*` functions now take a surface as their first parameter. This unifies the convention used in these functions and `show.*` functions. 8 | 9 | In `heatmap` and `confusionMatrix` params, `xLabels/yLabels` has been renamed to `xTickLabels/yTickLabels` to more clearly distinguish it from the axis label (`xLabel/yLabel`). 10 | 11 | ### Removed 12 | `show.confusionMatrix` has been removed in favor of using `render.confusionMatrix` 13 | 14 | ## [0.5.0] 2018-1-31 15 | ### Added 16 | Support for rendering heatmaps with `render.heatmap`. 
17 | 18 | Sourcemaps now include sourcesContent to improve editor experience, particularly for typescript users. 19 | 20 | ## [0.4.0] 2018-12-07 21 | 22 | ### Changed 23 | `confusionMatrix` now shades the diagonal by default. The chart also has improved 24 | contrast between the text and the chart cells. 25 | 26 | bugfixes for `metrics.confusionMatrix` on Safari 27 | 28 | Improvements to chart rendering options to help prevent situations where the container 29 | div grows on each render if dimensions were not specified in code or constrained by css. 30 | 31 | 32 | ## [0.3.0] 2018-11-06 33 | ### Added 34 | `fontSize` can now be passed into render.* methods 35 | `zoomToFit` and `yAxisDomain` are new options that linecharts and scatterplots 36 | take to allow finer control over the display of the yAxis 37 | `xAxisDomain` option added to scatterplots. 38 | 39 | ### Changed 40 | `show.history` and `show.fitCallbacks` now take an optional `opts` parameter. 41 | These allow passing configuration to the underlying charts as well as overriding 42 | which callbacks get generated by `show.fitCallbacks`. 43 | 44 | `show.history` and `show.fitCallbacks` will now automatically group a metric with 45 | its corresponding validation metric and display them on the same chart. For example 46 | if you have `['acc', 'val_acc', 'loss', 'val_loss']` as your metrics, these will 47 | be rendered on two charts, one for the loss metrics and one for the accuracy metrics. 48 | -------------------------------------------------------------------------------- /tfjs-vis/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We are happy to accept patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. 
You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to <https://cla.developers.google.com/> to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Adding functionality 19 | 20 | One way to ensure that your PR will be accepted is to add functionality that 21 | has been requested in GitHub issues. If there is something you think is 22 | important and we're missing it but does not show up in GitHub issues, it would 23 | be good to file an issue there first so we can have the discussion before 24 | sending us a PR. 25 | 26 | ## Code reviews 27 | 28 | All submissions, including submissions by project members, require review. We 29 | use GitHub pull requests for this purpose. Consult 30 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 31 | information on using pull requests. 32 | 33 | We require unit tests for most code, instructions for running our unit test 34 | suites are in the documentation. 35 | -------------------------------------------------------------------------------- /tfjs-vis/README.md: -------------------------------------------------------------------------------- 1 | # tfjs-vis 2 | 3 | __tfjs-vis__ is a small library for _in browser_ visualization intended for use 4 | with TensorFlow.js. 
5 | 6 | Its main features are: 7 | 8 | * A set of visualizations useful for visualizing model behaviour 9 | * A set of high level functions for visualizing objects specific to TensorFlow.js 10 | * A way to organize visualizations of model behaviour that won't interfere with your web application 11 | 12 | The library also aims to be flexible and make it easy for you to incorporate 13 | custom visualizations using tools of your choosing, such as d3, Chart.js or plotly.js. 14 | 15 | ## Demos 16 | 17 | - [Visualizing Training with tfjs-vis](https://storage.googleapis.com/tfjs-vis/mnist/dist/index.html) 18 | - [Looking inside a digit recognizer](https://storage.googleapis.com/tfjs-vis/mnist_internals/dist/index.html) 19 | 20 | ## Installation 21 | 22 | You can install this using npm with 23 | 24 | ``` 25 | npm install @tensorflow/tfjs-vis 26 | ``` 27 | 28 | or using yarn with 29 | 30 | ``` 31 | yarn add @tensorflow/tfjs-vis 32 | ``` 33 | 34 | You can also load it via script tag using the following tag, however you need 35 | to have TensorFlow.js also loaded on the page to work. Including both is shown 36 | below. 37 | 38 | ``` 39 | 40 | 41 | ``` 42 | 43 | 44 | ## Building from source 45 | 46 | To build the library, you need to have node.js installed. We use `yarn` 47 | instead of `npm` but you can use either. 48 | 49 | First install dependencies with 50 | 51 | ``` 52 | yarn 53 | ``` 54 | 55 | or 56 | 57 | ``` 58 | npm install 59 | ``` 60 | 61 | Then do a build with 62 | 63 | ``` 64 | yarn build 65 | ``` 66 | 67 | or 68 | 69 | ``` 70 | npm run build 71 | ``` 72 | 73 | This should produce a `tfjs-vis.umd.min.js` file in the `dist` folder that you can 74 | use. 
75 | 76 | ## Sample Usage 77 | 78 | ```js 79 | const data = [ 80 | { index: 0, value: 50 }, 81 | { index: 1, value: 100 }, 82 | { index: 2, value: 150 }, 83 | ]; 84 | 85 | // Get a surface 86 | const surface = tfvis.visor().surface({ name: 'Barchart', tab: 'Charts' }); 87 | 88 | // Render a barchart on that surface 89 | tfvis.render.barchart(surface, data, {}); 90 | ``` 91 | 92 | This should show something like the following 93 | 94 | ![visor screenshot with barchart](./docs/visor-usage.png) 95 | 96 | ## Issues 97 | 98 | Found a bug or have a feature request? Please file an [issue](https://github.com/tensorflow/tfjs/issues/new) on the main [TensorFlow.js repository](https://github.com/tensorflow/tfjs/issues) 99 | 100 | ## API 101 | 102 | See https://js.tensorflow.org/api_vis/latest/ for interactive API documentation. 103 | -------------------------------------------------------------------------------- /tfjs-vis/cloudbuild.yml: -------------------------------------------------------------------------------- 1 | steps: 2 | - name: 'node:10' 3 | entrypoint: 'yarn' 4 | args: ['test-ci'] 5 | env: ['BROWSERSTACK_USERNAME=deeplearnjs1'] 6 | secretEnv: ['BROWSERSTACK_KEY'] 7 | secrets: 8 | - kmsKeyName: projects/learnjs-174218/locations/global/keyRings/tfjs/cryptoKeys/enc 9 | secretEnv: 10 | BROWSERSTACK_KEY: CiQAkwyoIW0LcnxymzotLwaH4udVTQFBEN4AEA5CA+a3+yflL2ASPQAD8BdZnGARf78MhH5T9rQqyz9HNODwVjVIj64CTkFlUCGrP1B2HX9LXHWHLmtKutEGTeFFX9XhuBzNExA= 11 | timeout: 1800s 12 | logsBucket: 'gs://tfjs-build-logs' 13 | options: 14 | logStreamingOption: 'STREAM_ON' 15 | substitution_option: 'ALLOW_LOOSE' 16 | -------------------------------------------------------------------------------- /tfjs-vis/demos/api/.babelrc: -------------------------------------------------------------------------------- 1 | { 2 | "presets": [ 3 | [ 4 | "env", 5 | { 6 | "esmodules": false, 7 | "targets": { 8 | "browsers": [ 9 | "> 3%" 10 | ] 11 | } 12 | } 13 | ] 14 | ], 15 | "plugins": [ 16 | 
"@babel/plugin-transform-runtime" 17 | ] 18 | } 19 | -------------------------------------------------------------------------------- /tfjs-vis/demos/api/README.md: -------------------------------------------------------------------------------- 1 | # tfjs-vis renderers demo 2 | 3 | A page showing examples of all the tfjs-vis render.* functions. 4 | 5 | See a live version of this [here](https://storage.googleapis.com/tfjs-vis/api/dist/index.html). 6 | 7 | To run this locally run `yarn watch-api` from the `demos` folder 8 | -------------------------------------------------------------------------------- /tfjs-vis/demos/api/index.css: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * ============================================================================= 16 | */ 17 | 18 | /* Test CSS overrides here */ 19 | -------------------------------------------------------------------------------- /tfjs-vis/demos/api/index.js: -------------------------------------------------------------------------------- 1 | import * as tf from '@tensorflow/tfjs' 2 | import * as tfvis from '@tensorflow/tfjs-vis'; 3 | window.tf = tf; 4 | window.tfvis = tfvis; 5 | -------------------------------------------------------------------------------- /tfjs-vis/demos/api/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tfjs-vis-api-demo", 3 | "private": true, 4 | "version": "0.0.1", 5 | "description": "", 6 | "main": "index.js", 7 | "scripts": { 8 | "watch": "NODE_ENV=development parcel index.html --no-hmr --open --no-autoinstall", 9 | "build": "NODE_ENV=production parcel build index.html --no-minify --public-url ./" 10 | }, 11 | "author": "", 12 | "license": "Apache-2.0", 13 | "devDependencies": { 14 | "@babel/core": "^7.0.0-0", 15 | "@babel/plugin-transform-runtime": "^7.1.0", 16 | "babel-preset-env": "~1.6.1", 17 | "parcel-bundler": "1.12.3", 18 | "rimraf": "^2.6.2" 19 | }, 20 | "dependencies": { 21 | "@tensorflow/tfjs": "1.0.0", 22 | "@tensorflow/tfjs-vis": "1.0.3" 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /tfjs-vis/demos/mnist/.babelrc: -------------------------------------------------------------------------------- 1 | { 2 | "presets": [ 3 | [ 4 | "env", 5 | { 6 | "esmodules": false, 7 | "targets": { 8 | "browsers": [ 9 | "> 3%" 10 | ] 11 | } 12 | } 13 | ] 14 | ], 15 | "plugins": [ 16 | "@babel/plugin-transform-runtime" 17 | ] 18 | } 19 | -------------------------------------------------------------------------------- /tfjs-vis/demos/mnist/README.md: -------------------------------------------------------------------------------- 1 | # 
tfjs-vis mnist training demo 2 | 3 | A demonstration of using tfjs-vis to look at training progress and perform model evaluation. 4 | 5 | See a live version of this [here](https://storage.googleapis.com/tfjs-vis/mnist/dist/index.html). 6 | 7 | To run this locally run `yarn watch-api` from the `demos` folder 8 | -------------------------------------------------------------------------------- /tfjs-vis/demos/mnist/data.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | import * as tf from '@tensorflow/tfjs'; 19 | 20 | const IMAGE_SIZE = 784; 21 | const NUM_CLASSES = 10; 22 | const NUM_DATASET_ELEMENTS = 65000; 23 | 24 | const NUM_TRAIN_ELEMENTS = 55000; 25 | const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS; 26 | 27 | const MNIST_IMAGES_SPRITE_PATH = 28 | 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png'; 29 | const MNIST_LABELS_PATH = 30 | 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8'; 31 | 32 | /** 33 | * A class that fetches the sprited MNIST dataset and returns shuffled batches. 34 | * 35 | * NOTE: This will get much easier. For now, we do data fetching and 36 | * manipulation manually. 
37 | */ 38 | export class MnistData { 39 | constructor() { 40 | this.shuffledTrainIndex = 0; 41 | this.shuffledTestIndex = 0; 42 | } 43 | 44 | async load() { 45 | // Make a request for the MNIST sprited image. 46 | const img = new Image(); 47 | const canvas = document.createElement('canvas'); 48 | const ctx = canvas.getContext('2d'); 49 | const imgRequest = new Promise((resolve, reject) => { 50 | img.crossOrigin = ''; 51 | img.onload = () => { 52 | img.width = img.naturalWidth; 53 | img.height = img.naturalHeight; 54 | 55 | const datasetBytesBuffer = 56 | new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4); 57 | 58 | const chunkSize = 5000; 59 | canvas.width = img.width; 60 | canvas.height = chunkSize; 61 | 62 | for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) { 63 | const datasetBytesView = new Float32Array( 64 | datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4, 65 | IMAGE_SIZE * chunkSize); 66 | ctx.drawImage( 67 | img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width, 68 | chunkSize); 69 | 70 | const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height); 71 | 72 | for (let j = 0; j < imageData.data.length / 4; j++) { 73 | // All channels hold an equal value since the image is grayscale, so 74 | // just read the red channel. 75 | datasetBytesView[j] = imageData.data[j * 4] / 255; 76 | } 77 | } 78 | this.datasetImages = new Float32Array(datasetBytesBuffer); 79 | 80 | resolve(); 81 | }; 82 | img.src = MNIST_IMAGES_SPRITE_PATH; 83 | }); 84 | 85 | const labelsRequest = fetch(MNIST_LABELS_PATH); 86 | const [imgResponse, labelsResponse] = 87 | await Promise.all([imgRequest, labelsRequest]); 88 | 89 | this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer()); 90 | 91 | // Create shuffled indices into the train/test set for when we select a 92 | // random dataset element for training / validation. 
93 | this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS); 94 | this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS); 95 | 96 | // Slice the the images and labels into train and test sets. 97 | this.trainImages = 98 | this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS); 99 | this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS); 100 | this.trainLabels = 101 | this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS); 102 | this.testLabels = 103 | this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS); 104 | } 105 | 106 | nextTrainBatch(batchSize) { 107 | return this.nextBatch( 108 | batchSize, [this.trainImages, this.trainLabels], () => { 109 | this.shuffledTrainIndex = 110 | (this.shuffledTrainIndex + 1) % this.trainIndices.length; 111 | return this.trainIndices[this.shuffledTrainIndex]; 112 | }); 113 | } 114 | 115 | nextTestBatch(batchSize) { 116 | return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => { 117 | this.shuffledTestIndex = 118 | (this.shuffledTestIndex + 1) % this.testIndices.length; 119 | return this.testIndices[this.shuffledTestIndex]; 120 | }); 121 | } 122 | 123 | nextBatch(batchSize, data, index) { 124 | const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE); 125 | const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES); 126 | 127 | for (let i = 0; i < batchSize; i++) { 128 | const idx = index(); 129 | 130 | const image = 131 | data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE); 132 | batchImagesArray.set(image, i * IMAGE_SIZE); 133 | 134 | const label = 135 | data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES); 136 | batchLabelsArray.set(label, i * NUM_CLASSES); 137 | } 138 | 139 | const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]); 140 | const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]); 141 | 142 | return {xs, labels}; 143 | } 144 | } 145 | 
-------------------------------------------------------------------------------- /tfjs-vis/demos/mnist/index.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | import * as tf from '@tensorflow/tfjs'; 19 | import * as tfvis from '@tensorflow/tfjs-vis' 20 | import {getModel, loadData} from './model'; 21 | 22 | window.tf = tf; 23 | window.tfvis = tfvis; 24 | 25 | window.data; 26 | window.model; 27 | 28 | async function initData() { 29 | window.data = await loadData(); 30 | } 31 | 32 | function initModel() { 33 | window.model = getModel(); 34 | } 35 | 36 | function setupListeners() { 37 | document.querySelector('#show-visor').addEventListener('click', () => { 38 | const visorInstance = tfvis.visor(); 39 | if (!visorInstance.isOpen()) { 40 | visorInstance.toggle(); 41 | } 42 | }); 43 | 44 | document.querySelector('#make-first-surface') 45 | .addEventListener('click', () => { 46 | tfvis.visor().surface({name: 'My First Surface', tab: 'Input Data'}); 47 | }); 48 | 49 | document.querySelector('#load-data').addEventListener('click', async (e) => { 50 | await initData(); 51 | document.querySelector('#show-examples').disabled = false; 52 | 
document.querySelector('#start-training-1').disabled = false; 53 | document.querySelector('#start-training-2').disabled = false; 54 | e.target.disabled = true; 55 | }); 56 | } 57 | 58 | document.addEventListener('DOMContentLoaded', function() { 59 | initModel(); 60 | setupListeners(); 61 | }); 62 | -------------------------------------------------------------------------------- /tfjs-vis/demos/mnist/model.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * ============================================================================= 16 | */ 17 | 18 | import * as tf from '@tensorflow/tfjs'; 19 | import {MnistData} from './data'; 20 | 21 | export function getModel() { 22 | const model = tf.sequential(); 23 | 24 | model.add(tf.layers.conv2d({ 25 | inputShape: [28, 28, 1], 26 | kernelSize: 5, 27 | filters: 8, 28 | strides: 1, 29 | activation: 'relu', 30 | kernelInitializer: 'varianceScaling' 31 | })); 32 | 33 | model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]})); 34 | model.add(tf.layers.conv2d({ 35 | kernelSize: 5, 36 | filters: 16, 37 | strides: 1, 38 | activation: 'relu', 39 | kernelInitializer: 'varianceScaling' 40 | })); 41 | 42 | model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]})); 43 | model.add(tf.layers.flatten()); 44 | 45 | model.add(tf.layers.dense({ 46 | units: 10, 47 | kernelInitializer: 'varianceScaling', 48 | activation: 'softmax' 49 | })); 50 | 51 | const LEARNING_RATE = 0.15; 52 | const optimizer = tf.train.sgd(LEARNING_RATE); 53 | 54 | model.compile({ 55 | optimizer: optimizer, 56 | loss: 'categoricalCrossentropy', 57 | metrics: ['accuracy'], 58 | }); 59 | 60 | return model; 61 | } 62 | 63 | export async function loadData() { 64 | const data = new MnistData(); 65 | await data.load(); 66 | return data; 67 | } 68 | -------------------------------------------------------------------------------- /tfjs-vis/demos/mnist/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "mnist-vis-demo", 3 | "private": true, 4 | "version": "0.0.1", 5 | "description": "", 6 | "main": "index.js", 7 | "scripts": { 8 | "watch": "NODE_ENV=development parcel index.html --no-hmr --open --no-autoinstall", 9 | "build": "NODE_ENV=production parcel build index.html --no-minify --detailed-report --public-url ./" 10 | }, 11 | "author": "", 12 | "license": "Apache-2.0", 13 | "devDependencies": { 14 | "@babel/core": "^7.0.0-0", 15 | 
"@babel/plugin-transform-runtime": "^7.1.0", 16 | "babel-preset-env": "~1.6.1", 17 | "parcel-bundler": "1.12.3", 18 | "rimraf": "^2.6.2" 19 | }, 20 | "dependencies": { 21 | "@tensorflow/tfjs": "1.0.0", 22 | "@tensorflow/tfjs-vis": "1.0.3" 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /tfjs-vis/demos/mnist/tufte.css: -------------------------------------------------------------------------------- 1 | /* 2 | From https://github.com/edwardtufte/tufte-css 3 | https://github.com/edwardtufte/tufte-css/blob/gh-pages/LICENSE 4 | */ 5 | 6 | @charset "UTF-8"; 7 | 8 | /* Tufte CSS styles */ 9 | html { font-size: 15px; } 10 | 11 | body { width: 87.5%; 12 | margin-left: auto; 13 | margin-right: auto; 14 | padding-left: 12.5%; 15 | font-family: et-book, Palatino, "Palatino Linotype", "Palatino LT STD", "Book Antiqua", Georgia, serif; 16 | background-color: #fffff8; 17 | color: #111; 18 | max-width: 1400px; 19 | counter-reset: sidenote-counter; } 20 | 21 | h1 { font-weight: 400; 22 | margin-top: 4rem; 23 | margin-bottom: 1.5rem; 24 | font-size: 3.2rem; 25 | line-height: 1; } 26 | 27 | h2 { font-style: italic; 28 | font-weight: 400; 29 | margin-top: 2.1rem; 30 | margin-bottom: 1.4rem; 31 | font-size: 2.2rem; 32 | line-height: 1; } 33 | 34 | h3 { font-style: italic; 35 | font-weight: 400; 36 | font-size: 1.7rem; 37 | margin-top: 2rem; 38 | margin-bottom: 1.4rem; 39 | line-height: 1; } 40 | 41 | hr { display: block; 42 | height: 1px; 43 | width: 55%; 44 | border: 0; 45 | border-top: 1px solid #ccc; 46 | margin: 1em 0; 47 | padding: 0; } 48 | 49 | p.subtitle { font-style: italic; 50 | margin-top: 1rem; 51 | margin-bottom: 1rem; 52 | font-size: 1.8rem; 53 | display: block; 54 | line-height: 1; } 55 | 56 | .numeral { font-family: et-book-roman-old-style; } 57 | 58 | .danger { color: red; } 59 | 60 | article { position: relative; 61 | padding: 5rem 0rem; } 62 | 63 | section { padding-top: 1rem; 64 | padding-bottom: 1rem; } 65 | 66 | 
p, ol, ul { font-size: 1.4rem; 67 | line-height: 2rem; } 68 | 69 | p { margin-top: 1.4rem; 70 | margin-bottom: 1.4rem; 71 | padding-right: 0; 72 | vertical-align: baseline; } 73 | 74 | /* Chapter Epigraphs */ 75 | div.epigraph { margin: 5em 0; } 76 | 77 | div.epigraph > blockquote { margin-top: 3em; 78 | margin-bottom: 3em; } 79 | 80 | div.epigraph > blockquote, div.epigraph > blockquote > p { font-style: italic; } 81 | 82 | div.epigraph > blockquote > footer { font-style: normal; } 83 | 84 | div.epigraph > blockquote > footer > cite { font-style: italic; } 85 | /* end chapter epigraphs styles */ 86 | 87 | blockquote { font-size: 1.4rem; } 88 | 89 | blockquote p { width: 55%; 90 | margin-right: 40px; } 91 | 92 | blockquote footer { width: 55%; 93 | font-size: 1.1rem; 94 | text-align: right; } 95 | 96 | section > p, section > footer, section > table { width: 55%; } 97 | 98 | /* 50 + 5 == 55, to be the same width as paragraph */ 99 | section > ol, section > ul { width: 50%; 100 | -webkit-padding-start: 5%; } 101 | 102 | li:not(:first-child) { margin-top: 0.25rem; } 103 | 104 | figure { padding: 0; 105 | border: 0; 106 | font-size: 100%; 107 | font: inherit; 108 | vertical-align: baseline; 109 | max-width: 55%; 110 | -webkit-margin-start: 0; 111 | -webkit-margin-end: 0; 112 | margin: 0 0 3em 0; } 113 | 114 | figcaption { float: right; 115 | clear: right; 116 | margin-top: 0; 117 | margin-bottom: 0; 118 | font-size: 1.1rem; 119 | line-height: 1.6; 120 | vertical-align: baseline; 121 | position: relative; 122 | max-width: 40%; } 123 | 124 | figure.fullwidth figcaption { margin-right: 24%; } 125 | 126 | /* Links: replicate underline that clears descenders */ 127 | a:link, a:visited { color: inherit; } 128 | 129 | a:link { text-decoration: none; 130 | background: -webkit-linear-gradient(#fffff8, #fffff8), -webkit-linear-gradient(#fffff8, #fffff8), -webkit-linear-gradient(#333, #333); 131 | background: linear-gradient(#fffff8, #fffff8), linear-gradient(#fffff8, #fffff8), 
linear-gradient(#333, #333); 132 | -webkit-background-size: 0.05em 1px, 0.05em 1px, 1px 1px; 133 | -moz-background-size: 0.05em 1px, 0.05em 1px, 1px 1px; 134 | background-size: 0.05em 1px, 0.05em 1px, 1px 1px; 135 | background-repeat: no-repeat, no-repeat, repeat-x; 136 | text-shadow: 0.03em 0 #fffff8, -0.03em 0 #fffff8, 0 0.03em #fffff8, 0 -0.03em #fffff8, 0.06em 0 #fffff8, -0.06em 0 #fffff8, 0.09em 0 #fffff8, -0.09em 0 #fffff8, 0.12em 0 #fffff8, -0.12em 0 #fffff8, 0.15em 0 #fffff8, -0.15em 0 #fffff8; 137 | background-position: 0% 93%, 100% 93%, 0% 93%; } 138 | 139 | @media screen and (-webkit-min-device-pixel-ratio: 0) { a:link { background-position-y: 87%, 87%, 87%; } } 140 | 141 | a:link::selection { text-shadow: 0.03em 0 #b4d5fe, -0.03em 0 #b4d5fe, 0 0.03em #b4d5fe, 0 -0.03em #b4d5fe, 0.06em 0 #b4d5fe, -0.06em 0 #b4d5fe, 0.09em 0 #b4d5fe, -0.09em 0 #b4d5fe, 0.12em 0 #b4d5fe, -0.12em 0 #b4d5fe, 0.15em 0 #b4d5fe, -0.15em 0 #b4d5fe; 142 | background: #b4d5fe; } 143 | 144 | a:link::-moz-selection { text-shadow: 0.03em 0 #b4d5fe, -0.03em 0 #b4d5fe, 0 0.03em #b4d5fe, 0 -0.03em #b4d5fe, 0.06em 0 #b4d5fe, -0.06em 0 #b4d5fe, 0.09em 0 #b4d5fe, -0.09em 0 #b4d5fe, 0.12em 0 #b4d5fe, -0.12em 0 #b4d5fe, 0.15em 0 #b4d5fe, -0.15em 0 #b4d5fe; 145 | background: #b4d5fe; } 146 | 147 | /* Sidenotes, margin notes, figures, captions */ 148 | img { max-width: 100%; } 149 | 150 | .sidenote, .marginnote { float: right; 151 | clear: right; 152 | margin-right: -60%; 153 | width: 50%; 154 | margin-top: 0; 155 | margin-bottom: 0; 156 | font-size: 1.1rem; 157 | line-height: 1.3; 158 | vertical-align: baseline; 159 | position: relative; } 160 | 161 | .sidenote-number { counter-increment: sidenote-counter; } 162 | 163 | .sidenote-number:after, .sidenote:before { font-family: et-book-roman-old-style; 164 | position: relative; 165 | vertical-align: baseline; } 166 | 167 | .sidenote-number:after { content: counter(sidenote-counter); 168 | font-size: 1rem; 169 | top: -0.5rem; 170 | left: 0.1rem; 
} 171 | 172 | .sidenote:before { content: counter(sidenote-counter) " "; 173 | top: -0.5rem; } 174 | 175 | blockquote .sidenote, blockquote .marginnote { margin-right: -82%; 176 | min-width: 59%; 177 | text-align: left; } 178 | 179 | div.fullwidth, table.fullwidth { width: 100%; } 180 | 181 | div.table-wrapper { overflow-x: auto; 182 | font-family: "Trebuchet MS", "Gill Sans", "Gill Sans MT", sans-serif; } 183 | 184 | .sans { font-family: "Gill Sans", "Gill Sans MT", Calibri, sans-serif; 185 | letter-spacing: .03em; } 186 | 187 | code { font-family: Consolas, "Liberation Mono", Menlo, Courier, monospace; 188 | font-size: 1.0rem; 189 | line-height: 1.42; } 190 | 191 | .sans > code { font-size: 1.2rem; } 192 | 193 | h1 > code, h2 > code, h3 > code { font-size: 0.80em; } 194 | 195 | .marginnote > code, .sidenote > code { font-size: 1rem; } 196 | 197 | pre.code { font-size: 0.9rem; 198 | width: 52.5%; 199 | margin-left: 2.5%; 200 | overflow-x: auto; } 201 | 202 | pre.code.fullwidth { width: 90%; } 203 | 204 | .fullwidth { max-width: 90%; 205 | clear:both; } 206 | 207 | span.newthought { font-variant: small-caps; 208 | font-size: 1.2em; } 209 | 210 | input.margin-toggle { display: none; } 211 | 212 | label.sidenote-number { display: inline; } 213 | 214 | label.margin-toggle:not(.sidenote-number) { display: none; } 215 | 216 | .iframe-wrapper { position: relative; 217 | padding-bottom: 56.25%; /* 16:9 */ 218 | padding-top: 25px; 219 | height: 0; } 220 | 221 | .iframe-wrapper iframe { position: absolute; 222 | top: 0; 223 | left: 0; 224 | width: 100%; 225 | height: 100%; } 226 | 227 | @media (max-width: 760px) { body { width: 84%; 228 | padding-left: 8%; 229 | padding-right: 8%; } 230 | hr, section > p, section > footer, section > table { width: 100%; } 231 | pre.code { width: 97%; } 232 | section > ol { width: 90%; } 233 | section > ul { width: 90%; } 234 | figure { max-width: 90%; } 235 | figcaption, figure.fullwidth figcaption { margin-right: 0%; 236 | max-width: none; 
} 237 | blockquote { margin-left: 1.5em; 238 | margin-right: 0em; } 239 | blockquote p, blockquote footer { width: 100%; } 240 | label.margin-toggle:not(.sidenote-number) { display: inline; } 241 | .sidenote, .marginnote { display: none; } 242 | .margin-toggle:checked + .sidenote, 243 | .margin-toggle:checked + .marginnote { display: block; 244 | float: left; 245 | left: 1rem; 246 | clear: both; 247 | width: 95%; 248 | margin: 1rem 2.5%; 249 | vertical-align: baseline; 250 | position: relative; } 251 | label { cursor: pointer; } 252 | div.table-wrapper, table { width: 85%; } 253 | img { width: 100%; } } 254 | -------------------------------------------------------------------------------- /tfjs-vis/demos/mnist_internals/.babelrc: -------------------------------------------------------------------------------- 1 | { 2 | "presets": [ 3 | [ 4 | "env", 5 | { 6 | "esmodules": false, 7 | "targets": { 8 | "browsers": [ 9 | "> 3%" 10 | ] 11 | } 12 | } 13 | ] 14 | ], 15 | "plugins": [ 16 | "@babel/plugin-transform-runtime" 17 | ] 18 | } 19 | -------------------------------------------------------------------------------- /tfjs-vis/demos/mnist_internals/README.md: -------------------------------------------------------------------------------- 1 | # tfjs-vis mnist internals demo 2 | 3 | A demonstration of using tfjs-vis to look at model internals such as summaries and 4 | activations 5 | 6 | See a live version of this [here](https://storage.googleapis.com/tfjs-vis/mnist_internals/dist/index.html). 7 | 8 | To run this locally run `yarn watch-api` from the `demos` folder 9 | -------------------------------------------------------------------------------- /tfjs-vis/demos/mnist_internals/data.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 
import * as tf from '@tensorflow/tfjs';

const IMAGE_SIZE = 784;
const NUM_CLASSES = 10;
const NUM_DATASET_ELEMENTS = 65000;

const NUM_TRAIN_ELEMENTS = 55000;
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

const MNIST_IMAGES_SPRITE_PATH =
    'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
const MNIST_LABELS_PATH =
    'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';

// Number of sprite rows decoded per canvas pass. Decoding in chunks keeps the
// scratch canvas (and the ImageData read back from it) small.
const SPRITE_CHUNK_ROWS = 5000;

/**
 * Loads an image element from `url`.
 *
 * @param {string} url Location of the image.
 * @returns {Promise<HTMLImageElement>} Resolves with the loaded image and
 *     rejects if the browser reports a load error, so callers never wait
 *     forever on a failed download.
 */
function loadImage(url) {
  return new Promise((resolve, reject) => {
    const img = new Image();
    img.crossOrigin = '';
    img.onload = () => resolve(img);
    img.onerror = () =>
        reject(new Error(`Failed to load MNIST sprite image from ${url}`));
    img.src = url;
  });
}

/**
 * Decodes the sprited MNIST image into one flat Float32Array holding
 * NUM_DATASET_ELEMENTS * IMAGE_SIZE pixel intensities scaled to [0, 1].
 *
 * @param {HTMLImageElement} img A fully loaded sprite image.
 * @returns {Float32Array} The decoded pixel values.
 */
function decodeSprite(img) {
  const canvas = document.createElement('canvas');
  const ctx = canvas.getContext('2d');
  const width = img.naturalWidth;

  const datasetBytesBuffer =
      new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

  canvas.width = width;
  canvas.height = SPRITE_CHUNK_ROWS;

  const numChunks = NUM_DATASET_ELEMENTS / SPRITE_CHUNK_ROWS;
  for (let i = 0; i < numChunks; i++) {
    // View onto the part of the shared buffer that this chunk fills.
    const datasetBytesView = new Float32Array(
        datasetBytesBuffer, i * IMAGE_SIZE * SPRITE_CHUNK_ROWS * 4,
        IMAGE_SIZE * SPRITE_CHUNK_ROWS);
    ctx.drawImage(
        img, 0, i * SPRITE_CHUNK_ROWS, width, SPRITE_CHUNK_ROWS, 0, 0, width,
        SPRITE_CHUNK_ROWS);

    const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
    const numPixels = imageData.data.length / 4;

    for (let j = 0; j < numPixels; j++) {
      // All channels hold an equal value since the image is grayscale, so
      // just read the red channel.
      datasetBytesView[j] = imageData.data[j * 4] / 255;
    }
  }

  return new Float32Array(datasetBytesBuffer);
}

/**
 * A class that fetches the sprited MNIST dataset and returns shuffled batches.
 *
 * NOTE: This will get much easier. For now, we do data fetching and
 * manipulation manually.
 */
export class MnistData {
  constructor() {
    this.shuffledTrainIndex = 0;
    this.shuffledTestIndex = 0;
  }

  /**
   * Downloads the image sprite and the labels in parallel, then splits both
   * into train and test sets. Must complete before requesting batches.
   *
   * @returns {Promise<void>}
   * @throws {Error} If the sprite image fails to load or the labels request
   *     does not return a successful HTTP status.
   */
  async load() {
    const [img, labelsResponse] = await Promise.all(
        [loadImage(MNIST_IMAGES_SPRITE_PATH), fetch(MNIST_LABELS_PATH)]);

    if (!labelsResponse.ok) {
      throw new Error(
          `Failed to fetch MNIST labels from ${MNIST_LABELS_PATH}: ` +
          `HTTP ${labelsResponse.status}`);
    }

    this.datasetImages = decodeSprite(img);
    this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

    // Create shuffled indices into the train/test set for when we select a
    // random dataset element for training / validation.
    this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
    this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

    // Slice the images and labels into train and test sets.
    this.trainImages =
        this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.trainLabels =
        this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    this.testLabels =
        this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
  }

  /**
   * Returns the next batch drawn from the training set.
   *
   * @param {number} batchSize Number of examples in the batch.
   * @returns {{xs: tf.Tensor2D, labels: tf.Tensor2D}}
   */
  nextTrainBatch(batchSize) {
    return this.nextBatch(
        batchSize, [this.trainImages, this.trainLabels], () => {
          this.shuffledTrainIndex =
              (this.shuffledTrainIndex + 1) % this.trainIndices.length;
          return this.trainIndices[this.shuffledTrainIndex];
        });
  }

  /**
   * Returns the next batch drawn from the test set.
   *
   * @param {number} batchSize Number of examples in the batch.
   * @returns {{xs: tf.Tensor2D, labels: tf.Tensor2D}}
   */
  nextTestBatch(batchSize) {
    return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
      this.shuffledTestIndex =
          (this.shuffledTestIndex + 1) % this.testIndices.length;
      return this.testIndices[this.shuffledTestIndex];
    });
  }

  /**
   * Assembles a batch by repeatedly asking `index` for an example position
   * and copying that example's pixels and label row into the batch arrays.
   *
   * @param {number} batchSize Number of examples in the batch.
   * @param {Array} data Two-element array: [images, labels] typed arrays.
   * @param {function(): number} index Returns the position of the next
   *     example to include.
   * @returns {{xs: tf.Tensor2D, labels: tf.Tensor2D}} `xs` has shape
   *     [batchSize, IMAGE_SIZE] and `labels` has shape
   *     [batchSize, NUM_CLASSES]. The caller should dispose of them when done.
   */
  nextBatch(batchSize, data, index) {
    const [images, labelRows] = data;
    const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
    const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

    for (let i = 0; i < batchSize; i++) {
      const idx = index();

      // subarray() creates a view rather than a copy; set() then copies the
      // values straight into the batch buffer.
      batchImagesArray.set(
          images.subarray(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE),
          i * IMAGE_SIZE);
      batchLabelsArray.set(
          labelRows.subarray(
              idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES),
          i * NUM_CLASSES);
    }

    const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
    const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);

    return {xs, labels};
  }
}
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import {getModel, loadData} from './model';

// Expose the libraries as globals on the page.
window.tf = tf;
window.tfvis = tfvis;

// window.data and window.model are populated by initData() and initModel().

/**
 * Loads the MNIST data, stores it (plus a small batch of test examples) on
 * `window`, and renders a grid of sample digits into #mnist-examples.
 *
 * @returns {Promise<void>} Resolves once the sample digits have been drawn.
 */
async function initData() {
  window.data = await loadData();
  window.examples = window.data.nextTestBatch(10);

  // Awaited so that drawing failures surface to the caller instead of being
  // lost in a floating promise.
  await showExamples(document.querySelector('#mnist-examples'), 200);
}

/** Creates the model and stores it on `window.model`. */
function initModel() {
  window.model = getModel();
}

/**
 * Draws `numExamples` test digits as 28x28 canvases appended to `drawArea`.
 *
 * @param {Element} drawArea Container the canvases are appended to.
 * @param {number} numExamples Number of digits to draw.
 * @returns {Promise<void>} Resolves once every canvas has been painted.
 */
async function showExamples(drawArea, numExamples) {
  // Get the examples
  const examples = window.data.nextTestBatch(numExamples);
  const imageTensors = [];

  try {
    const drawPromises = [];
    for (let i = 0; i < numExamples; i++) {
      const imageTensor = tf.tidy(() => {
        return examples.xs.slice([i, 0], [1, examples.xs.shape[1]]).reshape([
          28, 28, 1
        ]);
      });
      imageTensors.push(imageTensor);

      // Create a canvas element to render each example
      const canvas = document.createElement('canvas');
      canvas.width = 28;
      canvas.height = 28;
      canvas.style.margin = '4px';
      drawPromises.push(tf.browser.toPixels(imageTensor, canvas));
      drawArea.appendChild(canvas);
    }

    await Promise.all(drawPromises);
  } finally {
    // Release the per-image slices and the batch tensors themselves, even if
    // drawing failed; the batch is not used after this function returns.
    tf.dispose(imageTensors);
    tf.dispose([examples.xs, examples.labels]);
  }
}

/**
 * Wires up the "load data" button: loads the data, enables the first training
 * button and disables the load button once loading has finished.
 */
function setupListeners() {
  document.querySelector('#load-data').addEventListener('click', async (e) => {
    await initData();
    document.querySelector('#start-training-1').disabled = false;
    e.target.disabled = true;
  });
}

document.addEventListener('DOMContentLoaded', function() {
  initModel();
  setupListeners();
});
import * as tf from '@tensorflow/tfjs';
import {MnistData} from './data';

/**
 * Builds and compiles the convolutional network used by the demo: two
 * conv + max-pool stages followed by a 10-unit softmax classifier, trained
 * with SGD on categorical cross-entropy.
 *
 * @returns {tf.Sequential} The compiled model.
 */
export function getModel() {
  const net = tf.sequential();

  // Stage 1: 8 filters over the 28x28x1 input image.
  const firstConvConfig = {
    inputShape: [28, 28, 1],
    kernelSize: 5,
    filters: 8,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling'
  };
  net.add(tf.layers.conv2d(firstConvConfig));
  net.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));

  // Stage 2: 16 filters; the input shape is inferred from the previous layer.
  const secondConvConfig = {
    kernelSize: 5,
    filters: 16,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling'
  };
  net.add(tf.layers.conv2d(secondConvConfig));
  net.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));

  // Classifier head.
  net.add(tf.layers.flatten());
  const classifierConfig = {
    units: 10,
    kernelInitializer: 'varianceScaling',
    activation: 'softmax'
  };
  net.add(tf.layers.dense(classifierConfig));

  const LEARNING_RATE = 0.15;
  net.compile({
    optimizer: tf.train.sgd(LEARNING_RATE),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });

  return net;
}

/**
 * Creates an MnistData instance and waits for it to finish loading.
 *
 * @returns {Promise<MnistData>} The loaded dataset.
 */
export async function loadData() {
  const mnist = new MnistData();
  await mnist.load();
  return mnist;
}
"^7.0.0-0", 15 | "@babel/plugin-transform-runtime": "^7.1.0", 16 | "babel-preset-env": "~1.6.1", 17 | "parcel-bundler": "1.12.3", 18 | "rimraf": "^2.6.2" 19 | }, 20 | "dependencies": { 21 | "@tensorflow/tfjs": "1.0.0", 22 | "@tensorflow/tfjs-vis": "1.0.3" 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /tfjs-vis/demos/mnist_internals/tufte.css: -------------------------------------------------------------------------------- 1 | /* 2 | From https://github.com/edwardtufte/tufte-css 3 | https://github.com/edwardtufte/tufte-css/blob/gh-pages/LICENSE 4 | */ 5 | 6 | @charset "UTF-8"; 7 | 8 | /* Tufte CSS styles */ 9 | html { font-size: 15px; } 10 | 11 | body { width: 87.5%; 12 | margin-left: auto; 13 | margin-right: auto; 14 | padding-left: 12.5%; 15 | font-family: et-book, Palatino, "Palatino Linotype", "Palatino LT STD", "Book Antiqua", Georgia, serif; 16 | background-color: #fffff8; 17 | color: #111; 18 | max-width: 1400px; 19 | counter-reset: sidenote-counter; } 20 | 21 | h1 { font-weight: 400; 22 | margin-top: 4rem; 23 | margin-bottom: 1.5rem; 24 | font-size: 3.2rem; 25 | line-height: 1; } 26 | 27 | h2 { font-style: italic; 28 | font-weight: 400; 29 | margin-top: 2.1rem; 30 | margin-bottom: 1.4rem; 31 | font-size: 2.2rem; 32 | line-height: 1; } 33 | 34 | h3 { font-style: italic; 35 | font-weight: 400; 36 | font-size: 1.7rem; 37 | margin-top: 2rem; 38 | margin-bottom: 1.4rem; 39 | line-height: 1; } 40 | 41 | hr { display: block; 42 | height: 1px; 43 | width: 55%; 44 | border: 0; 45 | border-top: 1px solid #ccc; 46 | margin: 1em 0; 47 | padding: 0; } 48 | 49 | p.subtitle { font-style: italic; 50 | margin-top: 1rem; 51 | margin-bottom: 1rem; 52 | font-size: 1.8rem; 53 | display: block; 54 | line-height: 1; } 55 | 56 | .numeral { font-family: et-book-roman-old-style; } 57 | 58 | .danger { color: red; } 59 | 60 | article { position: relative; 61 | padding: 5rem 0rem; } 62 | 63 | section { padding-top: 1rem; 64 | 
padding-bottom: 1rem; } 65 | 66 | p, ol, ul { font-size: 1.4rem; 67 | line-height: 2rem; } 68 | 69 | p { margin-top: 1.4rem; 70 | margin-bottom: 1.4rem; 71 | padding-right: 0; 72 | vertical-align: baseline; } 73 | 74 | /* Chapter Epigraphs */ 75 | div.epigraph { margin: 5em 0; } 76 | 77 | div.epigraph > blockquote { margin-top: 3em; 78 | margin-bottom: 3em; } 79 | 80 | div.epigraph > blockquote, div.epigraph > blockquote > p { font-style: italic; } 81 | 82 | div.epigraph > blockquote > footer { font-style: normal; } 83 | 84 | div.epigraph > blockquote > footer > cite { font-style: italic; } 85 | /* end chapter epigraphs styles */ 86 | 87 | blockquote { font-size: 1.4rem; } 88 | 89 | blockquote p { width: 55%; 90 | margin-right: 40px; } 91 | 92 | blockquote footer { width: 55%; 93 | font-size: 1.1rem; 94 | text-align: right; } 95 | 96 | section > p, section > footer, section > table { width: 55%; } 97 | 98 | /* 50 + 5 == 55, to be the same width as paragraph */ 99 | section > ol, section > ul { width: 50%; 100 | -webkit-padding-start: 5%; } 101 | 102 | li:not(:first-child) { margin-top: 0.25rem; } 103 | 104 | figure { padding: 0; 105 | border: 0; 106 | font-size: 100%; 107 | font: inherit; 108 | vertical-align: baseline; 109 | max-width: 55%; 110 | -webkit-margin-start: 0; 111 | -webkit-margin-end: 0; 112 | margin: 0 0 3em 0; } 113 | 114 | figcaption { float: right; 115 | clear: right; 116 | margin-top: 0; 117 | margin-bottom: 0; 118 | font-size: 1.1rem; 119 | line-height: 1.6; 120 | vertical-align: baseline; 121 | position: relative; 122 | max-width: 40%; } 123 | 124 | figure.fullwidth figcaption { margin-right: 24%; } 125 | 126 | /* Links: replicate underline that clears descenders */ 127 | a:link, a:visited { color: inherit; } 128 | 129 | a:link { text-decoration: none; 130 | background: -webkit-linear-gradient(#fffff8, #fffff8), -webkit-linear-gradient(#fffff8, #fffff8), -webkit-linear-gradient(#333, #333); 131 | background: linear-gradient(#fffff8, #fffff8), 
linear-gradient(#fffff8, #fffff8), linear-gradient(#333, #333); 132 | -webkit-background-size: 0.05em 1px, 0.05em 1px, 1px 1px; 133 | -moz-background-size: 0.05em 1px, 0.05em 1px, 1px 1px; 134 | background-size: 0.05em 1px, 0.05em 1px, 1px 1px; 135 | background-repeat: no-repeat, no-repeat, repeat-x; 136 | text-shadow: 0.03em 0 #fffff8, -0.03em 0 #fffff8, 0 0.03em #fffff8, 0 -0.03em #fffff8, 0.06em 0 #fffff8, -0.06em 0 #fffff8, 0.09em 0 #fffff8, -0.09em 0 #fffff8, 0.12em 0 #fffff8, -0.12em 0 #fffff8, 0.15em 0 #fffff8, -0.15em 0 #fffff8; 137 | background-position: 0% 93%, 100% 93%, 0% 93%; } 138 | 139 | @media screen and (-webkit-min-device-pixel-ratio: 0) { a:link { background-position-y: 87%, 87%, 87%; } } 140 | 141 | a:link::selection { text-shadow: 0.03em 0 #b4d5fe, -0.03em 0 #b4d5fe, 0 0.03em #b4d5fe, 0 -0.03em #b4d5fe, 0.06em 0 #b4d5fe, -0.06em 0 #b4d5fe, 0.09em 0 #b4d5fe, -0.09em 0 #b4d5fe, 0.12em 0 #b4d5fe, -0.12em 0 #b4d5fe, 0.15em 0 #b4d5fe, -0.15em 0 #b4d5fe; 142 | background: #b4d5fe; } 143 | 144 | a:link::-moz-selection { text-shadow: 0.03em 0 #b4d5fe, -0.03em 0 #b4d5fe, 0 0.03em #b4d5fe, 0 -0.03em #b4d5fe, 0.06em 0 #b4d5fe, -0.06em 0 #b4d5fe, 0.09em 0 #b4d5fe, -0.09em 0 #b4d5fe, 0.12em 0 #b4d5fe, -0.12em 0 #b4d5fe, 0.15em 0 #b4d5fe, -0.15em 0 #b4d5fe; 145 | background: #b4d5fe; } 146 | 147 | /* Sidenotes, margin notes, figures, captions */ 148 | img { max-width: 100%; } 149 | 150 | .sidenote, .marginnote { float: right; 151 | clear: right; 152 | margin-right: -60%; 153 | width: 50%; 154 | margin-top: 0; 155 | margin-bottom: 0; 156 | font-size: 1.1rem; 157 | line-height: 1.3; 158 | vertical-align: baseline; 159 | position: relative; } 160 | 161 | .sidenote-number { counter-increment: sidenote-counter; } 162 | 163 | .sidenote-number:after, .sidenote:before { font-family: et-book-roman-old-style; 164 | position: relative; 165 | vertical-align: baseline; } 166 | 167 | .sidenote-number:after { content: counter(sidenote-counter); 168 | font-size: 1rem; 169 | 
top: -0.5rem; 170 | left: 0.1rem; } 171 | 172 | .sidenote:before { content: counter(sidenote-counter) " "; 173 | top: -0.5rem; } 174 | 175 | blockquote .sidenote, blockquote .marginnote { margin-right: -82%; 176 | min-width: 59%; 177 | text-align: left; } 178 | 179 | div.fullwidth, table.fullwidth { width: 100%; } 180 | 181 | div.table-wrapper { overflow-x: auto; 182 | font-family: "Trebuchet MS", "Gill Sans", "Gill Sans MT", sans-serif; } 183 | 184 | .sans { font-family: "Gill Sans", "Gill Sans MT", Calibri, sans-serif; 185 | letter-spacing: .03em; } 186 | 187 | code { font-family: Consolas, "Liberation Mono", Menlo, Courier, monospace; 188 | font-size: 1.0rem; 189 | line-height: 1.42; } 190 | 191 | .sans > code { font-size: 1.2rem; } 192 | 193 | h1 > code, h2 > code, h3 > code { font-size: 0.80em; } 194 | 195 | .marginnote > code, .sidenote > code { font-size: 1rem; } 196 | 197 | pre.code { font-size: 0.9rem; 198 | width: 52.5%; 199 | margin-left: 2.5%; 200 | overflow-x: auto; } 201 | 202 | pre.code.fullwidth { width: 90%; } 203 | 204 | .fullwidth { max-width: 90%; 205 | clear:both; } 206 | 207 | span.newthought { font-variant: small-caps; 208 | font-size: 1.2em; } 209 | 210 | input.margin-toggle { display: none; } 211 | 212 | label.sidenote-number { display: inline; } 213 | 214 | label.margin-toggle:not(.sidenote-number) { display: none; } 215 | 216 | .iframe-wrapper { position: relative; 217 | padding-bottom: 56.25%; /* 16:9 */ 218 | padding-top: 25px; 219 | height: 0; } 220 | 221 | .iframe-wrapper iframe { position: absolute; 222 | top: 0; 223 | left: 0; 224 | width: 100%; 225 | height: 100%; } 226 | 227 | @media (max-width: 760px) { body { width: 84%; 228 | padding-left: 8%; 229 | padding-right: 8%; } 230 | hr, section > p, section > footer, section > table { width: 100%; } 231 | pre.code { width: 97%; } 232 | section > ol { width: 90%; } 233 | section > ul { width: 90%; } 234 | figure { max-width: 90%; } 235 | figcaption, figure.fullwidth figcaption { 
margin-right: 0%; 236 | max-width: none; } 237 | blockquote { margin-left: 1.5em; 238 | margin-right: 0em; } 239 | blockquote p, blockquote footer { width: 100%; } 240 | label.margin-toggle:not(.sidenote-number) { display: inline; } 241 | .sidenote, .marginnote { display: none; } 242 | .margin-toggle:checked + .sidenote, 243 | .margin-toggle:checked + .marginnote { display: block; 244 | float: left; 245 | left: 1rem; 246 | clear: both; 247 | width: 95%; 248 | margin: 1rem 2.5%; 249 | vertical-align: baseline; 250 | position: relative; } 251 | label { cursor: pointer; } 252 | div.table-wrapper, table { width: 85%; } 253 | img { width: 100%; } } 254 | -------------------------------------------------------------------------------- /tfjs-vis/docs/visor-usage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/tfjs-vis/c394fe8b19f25d7dd7ecdf187b3d7d8d8e5c319a/tfjs-vis/docs/visor-usage.png -------------------------------------------------------------------------------- /tfjs-vis/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "@tensorflow/tfjs-vis", 3 | "version": "1.1.0", 4 | "description": "Utilities for in browser visualization with TensorFlow.js", 5 | "repository": "https://github.com/tensorflow/tfjs-vis", 6 | "license": "Apache-2.0", 7 | "private": false, 8 | "main": "dist/index.js", 9 | "jsdelivr": "dist/tfjs-vis.umd.min.js", 10 | "unpkg": "dist/tfjs-vis.umd.min.js", 11 | "types": "dist/index.d.ts", 12 | "scripts": { 13 | "build": "tsc && NODE_ENV=production webpack", 14 | "lint": "tslint -p . 
-t verbose", 15 | "test": "karma start", 16 | "run-browserstack": "karma start --singleRun --reporters='dots,karma-typescript,BrowserStack' --hostname='bs-local.com'", 17 | "test-ci": "./scripts/test-ci.sh", 18 | "build-npm": "./scripts/build-npm.sh", 19 | "link-local": "yalc link", 20 | "publish-local": "rimraf dist/ && yarn build && yalc push" 21 | }, 22 | "dependencies": { 23 | "d3-format": "^1.3.0", 24 | "d3-selection": "^1.3.0", 25 | "glamor": "^2.20.40", 26 | "glamor-tachyons": "^1.0.0-alpha.1", 27 | "preact": "^8.2.9", 28 | "vega-embed": "3.30.0" 29 | }, 30 | "devDependencies": { 31 | "@tensorflow/tfjs": "~1.0.0", 32 | "@types/d3-format": "^1.3.0", 33 | "@types/d3-selection": "^1.3.2", 34 | "@types/jasmine": "^2.8.8", 35 | "@types/json-stable-stringify": "^1.0.32", 36 | "clang-format": "~1.2.2", 37 | "jasmine": "^3.2.0", 38 | "jasmine-core": "^3.2.0", 39 | "karma": "~4.0.1", 40 | "karma-browserstack-launcher": "^1.3.0", 41 | "karma-chrome-launcher": "^2.2.0", 42 | "karma-firefox-launcher": "^1.1.0", 43 | "karma-jasmine": "^2.0.0", 44 | "karma-safari-launcher": "^1.0.0", 45 | "karma-typescript": "~4.0.0", 46 | "npm-run-all": "^4.1.5", 47 | "preact-render-spy": "^1.3.0", 48 | "rimraf": "^2.6.2", 49 | "tslint": "^5.11.0", 50 | "tslint-no-circular-imports": "^0.5.0", 51 | "typescript": "3.3.3333", 52 | "webpack": "^4.16.3", 53 | "webpack-cli": "^3.1.0", 54 | "yalc": "~1.0.0-pre.21" 55 | }, 56 | "peerDependencies": { 57 | "@tensorflow/tfjs": ">= 1.0.0" 58 | }, 59 | "alias": { 60 | "react": "preact", 61 | "react-dom": "preact" 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /tfjs-vis/scripts/build-npm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
# Builds, lints and packs the npm package into a tarball.

# Abort on the first failing step so a broken build is never packed.
set -e

# Use plain `rm` instead of `rimraf`: rimraf is a devDependency, so it is not
# available until `yarn` (below) has installed node_modules.
rm -rf dist/
yarn
yarn build
yarn lint
npm pack
# Example directories should have:
#  - package.json
#  - `yarn build` script which generates a dist/ folder in the example directory.

if [ -z "$1" ]
then
  EXAMPLES="api mnist mnist_internals"
else
  EXAMPLES=$1
  # The examples live under demos/ (we cd into it below), so validate the
  # argument against that directory rather than the current one.
  if [ ! -d "demos/$EXAMPLES" ]; then
    echo "Error: Could not find example $1" >&2
    echo "Make sure the first argument to this script matches the example dir" >&2
    exit 1
  fi
fi

# Bail out if a directory change fails; otherwise the `rm -rf` below would run
# in the wrong directory.
cd demos || exit 1
for i in $EXAMPLES; do
  cd "${i}" || exit 1
  # Strip any trailing slashes.
  EXAMPLE_NAME=${i%/}

  echo "building ${EXAMPLE_NAME}..."
  yarn
  rm -rf dist .cache
  yarn build
  # Remove files in the example directory (but not sub-directories).
  gsutil -m rm "gs://tfjs-vis/$EXAMPLE_NAME/dist/*"
  # Gzip and copy all the dist files.
  # The trailing slash is important so we get $EXAMPLE_NAME/dist/.
  gsutil -m cp -Z -r dist "gs://tfjs-vis/$EXAMPLE_NAME/"
  cd ..
done
# Before you run this script, do this:
# 1) Update the version in package.json
# 2) Run ./scripts/build-npm from the base dir of the project.

# Then:
# 3) Checkout the master branch of this repo.
# 4) Run this script as `./scripts/publish-npm.sh` from the project base dir.

set -e

BRANCH=$(git rev-parse --abbrev-ref HEAD)

if [ "$BRANCH" != "master" ]; then
  echo "Error: Switch to the master branch before publishing." >&2
  # Exit non-zero so callers can tell that nothing was published; a bare
  # `exit` would report success here.
  exit 1
fi

yarn build-npm
npm publish
echo 'Yay! Published a new package to npm.'
// Tags the current commit with the version from package.json (as `v<version>`)
// and pushes all tags.
//
// Run this script from the base directory (not the script directory):
//   ./scripts/tag-version

// tslint:disable-next-line:no-require-imports
const fs = require('fs');
// tslint:disable-next-line:no-require-imports
const {execFileSync} = require('child_process');

const version = JSON.parse(fs.readFileSync('package.json', 'utf8')).version;
const tag = `v${version}`;

// The two git commands must run strictly one after the other: pushing before
// the tag has been created would silently push nothing. The synchronous API
// guarantees that ordering, and execFile passes the tag as an argument rather
// than interpolating it into a shell command line.
try {
  execFileSync('git', ['tag', tag], {stdio: 'pipe'});
} catch (err) {
  throw new Error(`Could not git tag with ${tag}: ${err.message}.`);
}
console.log(`Successfully tagged with ${tag}.`);

try {
  execFileSync('git', ['push', '--tags'], {stdio: 'pipe'});
} catch (err) {
  throw new Error(`Could not push git tags: ${err.message}.`);
}
console.log(`Successfully pushed tags.`);
# CI entry point: install, build, lint, then run the karma suites on
# BrowserStack.

# Stop at the first failing command.
set -o errexit

yarn
yarn build
yarn lint

# The first karma run goes alone so that it can download the BrowserStack
# binary without conflicting with the other runs.
yarn run-browserstack --browsers=bs_chrome_mac

# The remaining browsers run in parallel and reuse the binary downloaded above.
parallel_browsers=(bs_firefox_mac bs_safari_mac)
parallel_tasks=()
for browser in "${parallel_browsers[@]}"; do
  parallel_tasks+=("run-browserstack --browsers=${browser}")
done

npm-run-all -p -c --aggregate-output "${parallel_tasks[@]}"
15 | * ============================================================================= 16 | */ 17 | 18 | import { h, Component } from 'preact'; 19 | import { css } from 'glamor'; 20 | import { tachyons as tac } from 'glamor-tachyons'; 21 | import { SurfaceInfoStrict, StyleOptions } from '../types'; 22 | 23 | // Internal Props 24 | interface SurfaceProps extends SurfaceInfoStrict { 25 | visible: boolean; 26 | registerSurface: (name: string, tab: string, surface: SurfaceComponent) 27 | => void; 28 | } 29 | 30 | /** 31 | * A surface is container for visualizations and other rendered thigns. 32 | * It consists of a containing DOM Element, a label and an empty drawArea. 33 | */ 34 | export class SurfaceComponent extends Component { 35 | 36 | static defaultStyles: Partial = { 37 | maxWidth: '580px', 38 | maxHeight: '580px', 39 | height: 'auto', 40 | width: 'auto', 41 | }; 42 | 43 | container: HTMLElement; 44 | label: HTMLElement; 45 | drawArea: HTMLElement; 46 | 47 | componentDidMount() { 48 | const { name, tab } = this.props; 49 | this.props.registerSurface(name, tab, this); 50 | } 51 | 52 | componentDidUpdate() { 53 | // Prevent re-rendering of this component as it 54 | // is primarily controlled outside of this class 55 | return false; 56 | } 57 | 58 | render() { 59 | const { name, visible, styles } = this.props; 60 | const finalStyles = { 61 | ...SurfaceComponent.defaultStyles, 62 | ...styles, 63 | }; 64 | 65 | const { width, height, } = finalStyles; 66 | let { maxHeight, maxWidth, } = finalStyles; 67 | maxHeight = height === SurfaceComponent.defaultStyles.height ? 68 | maxHeight : height; 69 | maxWidth = width === SurfaceComponent.defaultStyles.width ? 70 | maxWidth : width; 71 | 72 | const surfaceStyle = css({ 73 | display: visible ? 
'block' : 'none', 74 | backgroundColor: 'white', 75 | marginTop: '10px', 76 | marginBottom: '10px', 77 | boxShadow: '0 0 6px -3px #777', 78 | padding: '10px !important', 79 | height, 80 | width, 81 | maxHeight, 82 | maxWidth, 83 | overflow: 'auto', 84 | }); 85 | 86 | const labelStyle = css({ 87 | backgroundColor: 'white', 88 | boxSizing: 'border-box', 89 | borderBottom: '1px solid #357EDD', 90 | lineHeight: '2em', 91 | marginBottom: '20px', 92 | ...tac('fw6 tc') 93 | }); 94 | 95 | const drawAreaStyle = css({ 96 | boxSizing: 'border-box', 97 | }); 98 | 99 | return ( 100 |
this.container = r!} 103 | data-visible={visible} 104 | > 105 |
this.label = r!}> 106 | {name} 107 |
108 | 109 |
this.drawArea = r!} 112 | /> 113 |
114 | ); 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /tfjs-vis/src/components/tabs.tsx: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | import { h, Component } from 'preact'; 19 | import { css } from 'glamor'; 20 | import { tachyons as tac } from 'glamor-tachyons'; 21 | 22 | interface TabsProps { 23 | tabNames: string[]; 24 | activeTab: string | null; 25 | handleClick: (tabName: string) => void; 26 | } 27 | 28 | /** 29 | * Renders a container for tab links 30 | */ 31 | export class Tabs extends Component { 32 | render() { 33 | const { tabNames, activeTab, handleClick } = this.props; 34 | 35 | const tabs = tabNames.length > 0 ? 36 | tabNames.map((name) => ( 37 | 41 | {name} 42 | 43 | )) 44 | : null; 45 | 46 | const tabStyle = css({ 47 | overflowX: 'scroll', 48 | overflowY: 'hidden', 49 | whiteSpace: 'nowrap', 50 | ...tac('bb b--light-gray pb3 mt3') 51 | }); 52 | 53 | return ( 54 |
55 | {tabs} 56 |
57 | ); 58 | } 59 | } 60 | 61 | interface TabProps { 62 | id: string; 63 | isActive: boolean; 64 | handleClick: (tabName: string) => void; 65 | } 66 | 67 | /** 68 | * A link representing a tab. Note that the component does not contain the 69 | * tab content 70 | */ 71 | class Tab extends Component { 72 | 73 | render() { 74 | const { children, isActive, handleClick, id } = this.props; 75 | 76 | const tabStyle = css({ 77 | borderBottomColor: isActive ? '#357EDD' : '#AAAAAA', 78 | borderBottomWidth: '1px', 79 | borderBottomStyle: 'solid', 80 | cursor: 'pointer', 81 | ':hover': { 82 | color: '#357EDD' 83 | }, 84 | display: 'inline-block', 85 | ...tac('b f5 mr3 pa2') 86 | }); 87 | 88 | return ( 89 | handleClick(id)} 92 | > 93 | {children} 94 | 95 | ); 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /tfjs-vis/src/components/visor_test.tsx: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * ============================================================================= 16 | */ 17 | 18 | import { h } from 'preact'; 19 | import { render } from 'preact-render-spy'; 20 | 21 | import { VisorComponent } from './visor'; 22 | import { SurfaceInfoStrict } from '../types'; 23 | 24 | afterEach(() => { 25 | document.body.innerHTML = ''; 26 | }); 27 | 28 | describe('Visor Component', () => { 29 | it('renders an empty visor', () => { 30 | const wrapper = render( 31 | 32 | ); 33 | 34 | expect(wrapper.find('.visor').length).toBe(1); 35 | expect(wrapper.find('.visor-surfaces').length).toBe(1); 36 | expect(wrapper.find('.tf-surface').length).toBe(0); 37 | expect(wrapper.state().isOpen).toBe(true); 38 | expect(wrapper.state().isFullscreen).toBe(false); 39 | }); 40 | 41 | it('renders an empty and closed visor', () => { 42 | const wrapper = render( 43 | 47 | ); 48 | 49 | expect(wrapper.find('.visor').length).toBe(1); 50 | expect(wrapper.state().isOpen).toBe(false); 51 | expect(wrapper.state().isFullscreen).toBe(false); 52 | }); 53 | 54 | it('renders a surface', () => { 55 | const surfaceList: SurfaceInfoStrict[] = [ 56 | { name: 'surface 1', tab: 'tab 1' }, 57 | ]; 58 | 59 | const wrapper = render( 60 | 61 | ); 62 | 63 | expect(wrapper.find('.tf-surface').length).toBe(1); 64 | expect(wrapper.find('.tf-surface').text()).toMatch('surface 1'); 65 | expect(wrapper.find('.tf-tab').length).toBe(1); 66 | expect(wrapper.find('.tf-tab').text()).toMatch('tab 1'); 67 | }); 68 | 69 | it('switches tabs on click', () => { 70 | const surfaceList: SurfaceInfoStrict[] = [ 71 | { name: 'surface 1', tab: 'tab 1' }, 72 | { name: 'surface 2', tab: 'tab 2' }, 73 | ]; 74 | 75 | const wrapper = render( 76 | 77 | ); 78 | 79 | expect(wrapper.find('.tf-tab').length).toBe(2); 80 | expect(wrapper.state().activeTab).toEqual('tab 2'); 81 | 82 | // Clicks 83 | wrapper.find('.tf-tab').at(0).simulate('click'); 84 | expect(wrapper.state().activeTab).toEqual('tab 1'); 85 | 
expect(wrapper.find('.tf-tab').at(0).attr('data-isactive' as never)) 86 | .toEqual(true); 87 | expect(wrapper.find('.tf-tab').at(1).attr('data-isactive' as never)) 88 | .toEqual(false); 89 | 90 | expect(wrapper.find('.tf-surface').at(0).attr('data-visible' as never)) 91 | .toEqual(true); 92 | expect(wrapper.find('.tf-surface').at(1).attr('data-visible' as never)) 93 | .toEqual(false); 94 | 95 | wrapper.find('.tf-tab').at(1).simulate('click'); 96 | expect(wrapper.state().activeTab).toEqual('tab 2'); 97 | expect(wrapper.find('.tf-tab').at(0).attr('data-isactive' as never)) 98 | .toEqual(false); 99 | expect(wrapper.find('.tf-tab').at(1).attr('data-isactive' as never)) 100 | .toEqual(true); 101 | 102 | expect(wrapper.find('.tf-surface').at(0).attr('data-visible' as never)) 103 | .toEqual(false); 104 | expect(wrapper.find('.tf-surface').at(1).attr('data-visible' as never)) 105 | .toEqual(true); 106 | }); 107 | 108 | it('hides on close button click', () => { 109 | const surfaceList: SurfaceInfoStrict[] = []; 110 | 111 | const wrapper = render( 112 | 113 | ); 114 | 115 | expect(wrapper.state().isOpen).toEqual(true); 116 | 117 | const hideButton = wrapper.find('.visor-controls').children().at(1); 118 | expect(hideButton.text()).toEqual('Hide'); 119 | 120 | hideButton.simulate('click'); 121 | expect(wrapper.state().isOpen).toEqual(false); 122 | }); 123 | 124 | it('maximises and minimizes', () => { 125 | const surfaceList: SurfaceInfoStrict[] = []; 126 | 127 | const wrapper = render( 128 | 129 | ); 130 | 131 | expect(wrapper.state().isOpen).toEqual(true); 132 | 133 | let toggleButton; 134 | toggleButton = wrapper.find('.visor-controls').children().at(0); 135 | expect(toggleButton.text()).toEqual('Maximize'); 136 | expect(wrapper.state().isFullscreen).toEqual(false); 137 | expect(wrapper.find('.visor').at(0).attr('data-isfullscreen' as never)) 138 | .toEqual(false); 139 | 140 | toggleButton.simulate('click'); 141 | toggleButton = 
wrapper.find('.visor-controls').children().at(0); 142 | expect(toggleButton.text()).toEqual('Minimize'); 143 | expect(wrapper.state().isFullscreen).toEqual(true); 144 | expect(wrapper.find('.visor').at(0).attr('data-isfullscreen' as never)) 145 | .toEqual(true); 146 | }); 147 | }); 148 | -------------------------------------------------------------------------------- /tfjs-vis/src/index.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * ============================================================================= 16 | */ 17 | 18 | import {barchart} from './render/barchart'; 19 | import {confusionMatrix} from './render/confusion_matrix'; 20 | import {heatmap} from './render/heatmap'; 21 | import {histogram} from './render/histogram'; 22 | import {linechart} from './render/linechart'; 23 | import {scatterplot} from './render/scatterplot'; 24 | import {table} from './render/table'; 25 | import {fitCallbacks, history} from './show/history'; 26 | import {layer, modelSummary} from './show/model'; 27 | import {showPerClassAccuracy} from './show/quality'; 28 | import {valuesDistribution} from './show/tensor'; 29 | import {accuracy, confusionMatrix as metricsConfusionMatrix, perClassAccuracy} from './util/math'; 30 | 31 | const render = { 32 | barchart, 33 | table, 34 | histogram, 35 | linechart, 36 | scatterplot, 37 | confusionMatrix, 38 | heatmap, 39 | }; 40 | 41 | const metrics = { 42 | accuracy, 43 | perClassAccuracy, 44 | confusionMatrix: metricsConfusionMatrix, 45 | }; 46 | 47 | const show = { 48 | history, 49 | fitCallbacks, 50 | perClassAccuracy: showPerClassAccuracy, 51 | valuesDistribution, 52 | layer, 53 | modelSummary, 54 | }; 55 | 56 | export {visor} from './visor'; 57 | export {render}; 58 | export {metrics}; 59 | export {show}; 60 | 61 | export * from './types'; 62 | -------------------------------------------------------------------------------- /tfjs-vis/src/render/barchart.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | import embed, {Mode, Result as EmbedRes, VisualizationSpec} from 'vega-embed'; 19 | 20 | import {Drawable, VisOptions} from '../types'; 21 | 22 | import {getDrawArea, nextFrame, shallowEquals} from './render_utils'; 23 | 24 | /** 25 | * Renders a barchart. 26 | * 27 | * ```js 28 | * const data = [ 29 | * { index: 0, value: 50 }, 30 | * { index: 1, value: 100 }, 31 | * { index: 2, value: 150 }, 32 | * ]; 33 | * 34 | * // Render to visor 35 | * const surface = { name: 'Bar chart', tab: 'Charts' }; 36 | * tfvis.render.barchart(surface, data); 37 | * ``` 38 | * 39 | * @param data Data in the following format, (an array of objects) 40 | * `[ {index: number, value: number} ... ]` 41 | * 42 | * @returns Promise - indicates completion of rendering 43 | */ 44 | /** @doc {heading: 'Charts', namespace: 'render'} */ 45 | export async function barchart( 46 | container: Drawable, data: Array<{index: number; value: number;}>, 47 | opts: VisOptions = {}): Promise { 48 | const drawArea = getDrawArea(container); 49 | const values = data; 50 | const options = Object.assign({}, defaultOpts, opts); 51 | 52 | // If we have rendered this chart before with the same options we can do a 53 | // data only update, else we do a regular re-render. 
54 | if (instances.has(drawArea)) { 55 | const instanceInfo = instances.get(drawArea)!; 56 | if (shallowEquals(options, instanceInfo.lastOptions)) { 57 | await nextFrame(); 58 | const view = instanceInfo.view; 59 | const changes = view.changeset().remove(() => true).insert(values); 60 | await view.change('values', changes).runAsync(); 61 | return; 62 | } 63 | } 64 | 65 | const {xLabel, yLabel, xType, yType} = options; 66 | 67 | let xAxis: {}|null = null; 68 | if (xLabel != null) { 69 | xAxis = {title: xLabel}; 70 | } 71 | 72 | let yAxis: {}|null = null; 73 | if (yLabel != null) { 74 | yAxis = {title: yLabel}; 75 | } 76 | 77 | const embedOpts = { 78 | actions: false, 79 | mode: 'vega-lite' as Mode, 80 | defaultStyle: false, 81 | }; 82 | 83 | const spec: VisualizationSpec = { 84 | 'width': options.width || drawArea.clientWidth, 85 | 'height': options.height || drawArea.clientHeight, 86 | 'padding': 0, 87 | 'autosize': { 88 | 'type': 'fit', 89 | 'contains': 'padding', 90 | 'resize': true, 91 | }, 92 | 'config': { 93 | 'axis': { 94 | 'labelFontSize': options.fontSize, 95 | 'titleFontSize': options.fontSize, 96 | }, 97 | 'text': {'fontSize': options.fontSize}, 98 | 'legend': { 99 | 'labelFontSize': options.fontSize, 100 | 'titleFontSize': options.fontSize, 101 | } 102 | }, 103 | 'data': {'values': values, 'name': 'values'}, 104 | 'mark': 'bar', 105 | 'encoding': { 106 | 'x': {'field': 'index', 'type': xType, 'axis': xAxis}, 107 | 'y': {'field': 'value', 'type': yType, 'axis': yAxis} 108 | } 109 | }; 110 | 111 | await nextFrame(); 112 | const embedRes = await embed(drawArea, spec, embedOpts); 113 | instances.set(drawArea, { 114 | view: embedRes.view, 115 | lastOptions: options, 116 | }); 117 | } 118 | 119 | const defaultOpts = { 120 | xLabel: '', 121 | yLabel: '', 122 | xType: 'ordinal', 123 | yType: 'quantitative', 124 | fontSize: 11, 125 | }; 126 | 127 | // We keep a map of containers to chart instances in order to reuse the instance 128 | // where possible. 
129 | const instances: Map = 130 | new Map(); 131 | 132 | interface InstanceInfo { 133 | // tslint:disable-next-line:no-any 134 | view: any; 135 | lastOptions: VisOptions; 136 | } 137 | -------------------------------------------------------------------------------- /tfjs-vis/src/render/barchart_test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | import {barchart} from './barchart'; 19 | 20 | describe('renderBarChart', () => { 21 | let pixelRatio: number; 22 | 23 | beforeEach(() => { 24 | document.body.innerHTML = '
'; 25 | pixelRatio = window.devicePixelRatio; 26 | }); 27 | 28 | it('renders a bar chart', async () => { 29 | const data = [ 30 | {index: 0, value: 50}, 31 | {index: 1, value: 100}, 32 | {index: 2, value: 230}, 33 | ]; 34 | 35 | const container = document.getElementById('container') as HTMLElement; 36 | await barchart(container, data); 37 | 38 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 39 | }); 40 | 41 | it('re-renders a bar chart', async () => { 42 | const data = [ 43 | {index: 0, value: 50}, 44 | {index: 1, value: 100}, 45 | {index: 2, value: 230}, 46 | ]; 47 | 48 | const container = document.getElementById('container') as HTMLElement; 49 | 50 | await barchart(container, data); 51 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 52 | 53 | await barchart(container, data); 54 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 55 | }); 56 | 57 | it('updates a bar chart', async () => { 58 | let data = [ 59 | {index: 0, value: 50}, 60 | {index: 1, value: 100}, 61 | {index: 2, value: 150}, 62 | ]; 63 | 64 | const container = document.getElementById('container') as HTMLElement; 65 | 66 | await barchart(container, data); 67 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 68 | 69 | data = [ 70 | {index: 0, value: 50}, 71 | {index: 1, value: 100}, 72 | {index: 2, value: 150}, 73 | ]; 74 | 75 | await barchart(container, data); 76 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 77 | }); 78 | 79 | it('sets width of chart', async () => { 80 | const data = [ 81 | {index: 0, value: 50}, 82 | {index: 1, value: 100}, 83 | {index: 2, value: 230}, 84 | ]; 85 | 86 | const container = document.getElementById('container') as HTMLElement; 87 | await barchart(container, data, {width: 400}); 88 | 89 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 90 | expect(document.querySelectorAll('canvas').length).toBe(1); 91 | expect(document.querySelector('canvas')!.width).toBe(400 * 
pixelRatio); 92 | }); 93 | 94 | it('sets height of chart', async () => { 95 | const data = [ 96 | {index: 0, value: 50}, 97 | {index: 1, value: 100}, 98 | {index: 2, value: 230}, 99 | ]; 100 | 101 | const container = document.getElementById('container') as HTMLElement; 102 | await barchart(container, data, {height: 200}); 103 | 104 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 105 | expect(document.querySelectorAll('canvas').length).toBe(1); 106 | expect(document.querySelector('canvas')!.height).toBe(200 * pixelRatio); 107 | }); 108 | }); 109 | -------------------------------------------------------------------------------- /tfjs-vis/src/render/confusion_matrix.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | import embed, {Mode, VisualizationSpec} from 'vega-embed'; 19 | 20 | import {ConfusionMatrixData, ConfusionMatrixOptions, Drawable,} from '../types'; 21 | 22 | import {getDrawArea} from './render_utils'; 23 | 24 | /** 25 | * Renders a confusion matrix. 26 | * 27 | * Can optionally exclude the diagonal from being shaded if one wants the visual 28 | * focus to be on the incorrect classifications. Note that if the classification 29 | * is perfect (i.e. 
only the diagonal has values) then the diagonal will always
 * be shaded.
 *
 * ```js
 * const rows = 5;
 * const cols = 5;
 * const values = [];
 * for (let i = 0; i < rows; i++) {
 *   const row = []
 *   for (let j = 0; j < cols; j++) {
 *     row.push(Math.round(Math.random() * 50));
 *   }
 *   values.push(row);
 * }
 * const data = { values };
 *
 * // Render to visor
 * const surface = { name: 'Confusion Matrix', tab: 'Charts' };
 * tfvis.render.confusionMatrix(surface, data);
 * ```
 *
 * ```js
 * // The diagonal can be excluded from shading.
 *
 * const data = {
 *   values: [[4, 2, 8], [1, 7, 2], [3, 3, 20]],
 * }
 *
 * // Render to visor
 * const surface = {
 *  name: 'Confusion Matrix with Excluded Diagonal', tab: 'Charts'
 * };
 *
 * tfvis.render.confusionMatrix(surface, data, {
 *   shadeDiagonal: false
 * });
 * ```
 *
 * @param container The drawable to render into; it is resolved to a draw area
 *     with `getDrawArea`.
 * @param data `values` holds the matrix as an array of rows. `tickLabels` is
 *     optional; when it is missing or empty, labels of the form `Class <i>`
 *     are generated. The passed in `data` object is not modified.
 * @param opts Optional settings merged over the module defaults. The fields
 *     read here are `shadeDiagonal`, `showTextOverlay`, `fontSize`, `width`
 *     and `height`.
 *
 * @returns Promise - resolves once the chart has been embedded
 */
/** @doc {heading: 'Charts', namespace: 'render'} */
export async function confusionMatrix(
    container: Drawable, data: ConfusionMatrixData,
    opts: ConfusionMatrixOptions = {}): Promise<void> {
  const options = Object.assign({}, defaultOpts, opts);
  const drawArea = getDrawArea(container);

  // Format data for vega spec; an array of objects, one for each cell
  // in the matrix.
  const values: MatrixEntry[] = [];

  const inputArray = data.values;
  const providedLabels = data.tickLabels || [];
  const generateLabels = providedLabels.length === 0;
  // Generated labels are collected in a fresh array. Pushing them into
  // `data.tickLabels` would modify the caller's object whenever an empty
  // array was passed in.
  const tickLabels: string[] = generateLabels ? [] : providedLabels;

  let nonDiagonalIsAllZeroes = true;
  for (let i = 0; i < inputArray.length; i++) {
    const label = generateLabels ? `Class ${i}` : tickLabels[i];

    if (generateLabels) {
      tickLabels.push(label);
    }

    for (let j = 0; j < inputArray[i].length; j++) {
      const prediction = generateLabels ? `Class ${j}` : tickLabels[j];

      const count = inputArray[i][j];
      if (i === j && !options.shadeDiagonal) {
        // Diagonal cells that must not be shaded carry their value in
        // `diagCount` so that it stays out of the fill color scale.
        values.push({
          label,
          prediction,
          diagCount: count,
          noFill: true,
        });
      } else {
        values.push({
          label,
          prediction,
          count,
        });
        // When not shading the diagonal we want to check if there is a non
        // zero value. If all values are zero we will not color them as the
        // scale will be invalid.
        if (count !== 0) {
          nonDiagonalIsAllZeroes = false;
        }
      }
    }
  }

  if (!options.shadeDiagonal && nonDiagonalIsAllZeroes) {
    // User has requested not to shade the diagonal but all the other
    // values are zero. We have two choices, don't shade anything or only
    // shade the diagonal. We choose to shade the diagonal as that is likely
    // more helpful even if it is not what the user specified.
    for (const val of values) {
      if (val.noFill === true) {
        val.noFill = false;
        val.count = val.diagCount;
      }
    }
  }

  const embedOpts = {
    actions: false,
    mode: 'vega-lite' as Mode,
    defaultStyle: false,
  };

  const spec: VisualizationSpec = {
    'width': options.width || drawArea.clientWidth,
    'height': options.height || drawArea.clientHeight,
    'padding': 0,
    'autosize': {
      'type': 'fit',
      'contains': 'padding',
      'resize': true,
    },
    'config': {
      'axis': {
        'labelFontSize': options.fontSize,
        'titleFontSize': options.fontSize,
      },
      'text': {'fontSize': options.fontSize},
      'legend': {
        'labelFontSize': options.fontSize,
        'titleFontSize': options.fontSize,
      }
    },
    'data': {'values': values},
    'encoding': {
      'x': {
        'field': 'prediction',
        'type': 'ordinal',
        // Maintain sort order of the axis if labels is passed in
        'scale': {'domain': tickLabels},
      },
      'y': {
        'field': 'label',
        'type': 'ordinal',
        // Maintain sort order of the axis if labels is passed in
        'scale': {'domain': tickLabels},
      },
    },
    'layer': [
      {
        // The matrix
        'mark': {
          'type': 'rect',
        },
        'encoding': {
          'fill': {
            'condition': {
              'test': 'datum["noFill"] == true',
              'value': 'white',
            },
            'field': 'count',
            'type': 'quantitative',
            'scale': {'range': ['#f7fbff', '#4292c6']},
          },
          'tooltip': {
            'condition': {
              'test': 'datum["noFill"] == true',
              'field': 'diagCount',
              'type': 'nominal',
            },
            'field': 'count',
            'type': 'nominal',
          }
        },

      },
    ]
  };

  if (options.showTextOverlay) {
    spec.layer.push({
      // The text labels
      'mark': {'type': 'text', 'baseline': 'middle'},
      'encoding': {
        'text': {
          'condition': {
            'test': 'datum["noFill"] == true',
            'field': 'diagCount',
            'type': 'nominal',
          },
          'field': 'count',
          'type': 'nominal',
        },
      }
    });
  }

  await embed(drawArea, spec, embedOpts);
}

// Defaults applied underneath the caller supplied options.
const defaultOpts = {
  xLabel: null,
  yLabel: null,
  xType: 'nominal',
  yType: 'nominal',
  shadeDiagonal: true,
  fontSize: 12,
  showTextOverlay: true,
  height: 400,
};

// One cell of the matrix in the form handed to the vega spec. Unshaded
// diagonal cells use `diagCount` and `noFill` instead of `count`.
interface MatrixEntry {
  label: string;
  prediction: string;
  count?: number;
  diagCount?: number;
  noFill?: boolean;
}
-------------------------------------------------------------------------------- /tfjs-vis/src/render/confusion_matrix_test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | import {ConfusionMatrixData} from '../types'; 19 | 20 | import {confusionMatrix} from './confusion_matrix'; 21 | 22 | describe('renderConfusionMatrix', () => { 23 | let pixelRatio: number; 24 | 25 | beforeEach(() => { 26 | document.body.innerHTML = '
'; 27 | pixelRatio = window.devicePixelRatio; 28 | }); 29 | 30 | it('renders a chart', async () => { 31 | const data: ConfusionMatrixData = { 32 | values: [[4, 2, 8], [1, 7, 2], [3, 3, 20]], 33 | tickLabels: ['cheese', 'pig', 'font'], 34 | }; 35 | 36 | const container = document.getElementById('container') as HTMLElement; 37 | await confusionMatrix(container, data); 38 | 39 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 40 | }); 41 | 42 | it('renders a chart with shaded diagonal', async () => { 43 | const data: ConfusionMatrixData = { 44 | values: [[4, 2, 8], [1, 7, 2], [3, 3, 20]], 45 | tickLabels: ['cheese', 'pig', 'font'], 46 | }; 47 | 48 | const container = document.getElementById('container') as HTMLElement; 49 | await confusionMatrix(container, data, {shadeDiagonal: true}); 50 | 51 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 52 | }); 53 | 54 | it('renders the chart with generated labels', async () => { 55 | const data: ConfusionMatrixData = { 56 | values: [[4, 2, 8], [1, 7, 2], [3, 3, 20]], 57 | }; 58 | 59 | const container = document.getElementById('container') as HTMLElement; 60 | 61 | await confusionMatrix(container, data); 62 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 63 | }); 64 | 65 | it('updates the chart', async () => { 66 | let data: ConfusionMatrixData = { 67 | values: [[4, 2, 8], [1, 7, 2], [3, 3, 20]], 68 | tickLabels: ['cheese', 'pig', 'font'], 69 | }; 70 | 71 | const container = document.getElementById('container') as HTMLElement; 72 | 73 | await confusionMatrix(container, data); 74 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 75 | 76 | data = { 77 | values: [[43, 2, 8], [1, 7, 2], [3, 3, 20]], 78 | tickLabels: ['cheese', 'pig', 'font'], 79 | }; 80 | 81 | await confusionMatrix(container, data); 82 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 83 | }); 84 | 85 | it('sets width of chart', async () => { 86 | const data: ConfusionMatrixData 
= { 87 | values: [[4, 2, 8], [1, 7, 2], [3, 3, 20]], 88 | tickLabels: ['cheese', 'pig', 'font'], 89 | }; 90 | 91 | const container = document.getElementById('container') as HTMLElement; 92 | await confusionMatrix(container, data, {width: 400}); 93 | 94 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 95 | expect(document.querySelectorAll('canvas').length).toBe(1); 96 | expect(document.querySelector('canvas')!.width).toBe(400 * pixelRatio); 97 | }); 98 | 99 | it('sets height of chart', async () => { 100 | const data: ConfusionMatrixData = { 101 | values: [[4, 2, 8], [1, 7, 2], [3, 3, 20]], 102 | tickLabels: ['cheese', 'pig', 'font'], 103 | }; 104 | 105 | const container = document.getElementById('container') as HTMLElement; 106 | await confusionMatrix(container, data, {height: 200}); 107 | 108 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 109 | expect(document.querySelectorAll('canvas').length).toBe(1); 110 | expect(document.querySelector('canvas')!.height).toBe(200 * pixelRatio); 111 | }); 112 | }); 113 | -------------------------------------------------------------------------------- /tfjs-vis/src/render/heatmap.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
* =============================================================================
 */

import * as tf from '@tensorflow/tfjs';
import embed, {Mode, VisualizationSpec} from 'vega-embed';

import {Drawable, HeatmapData, HeatmapOptions} from '../types';
import {assert} from '../util/utils';

import {getDrawArea} from './render_utils';

/**
 * Renders a heatmap.
 *
 * Every entry `values[i][j]` becomes one cell whose x position is taken from
 * the outer index `i` and whose y position is taken from the inner index `j`.
 * When `opts.rowMajor` is set the values are transposed first, and the tick
 * label lengths are checked against the transposed values.
 *
 * ```js
 * const cols = 50;
 * const rows = 20;
 * const values = [];
 * for (let i = 0; i < cols; i++) {
 *   const col = []
 *   for (let j = 0; j < rows; j++) {
 *     col.push(i * j)
 *   }
 *   values.push(col);
 * }
 * const data = { values };
 *
 * // Render to visor
 * const surface = { name: 'Heatmap', tab: 'Charts' };
 * tfvis.render.heatmap(surface, data);
 * ```
 *
 * ```js
 * const data = {
 *   values: [[4, 2, 8, 20], [1, 7, 2, 10], [3, 3, 20, 13]],
 *   xTickLabels: ['cheese', 'pig', 'font'],
 *   yTickLabels: ['speed', 'smoothness', 'dexterity', 'mana'],
 * }
 *
 * // Render to visor
 * const surface = { name: 'Heatmap w Custom Labels', tab: 'Charts' };
 * tfvis.render.heatmap(surface, data);
 * ```
 *
 * @param container The surface to draw on; it is resolved to a DOM element
 *     with `getDrawArea`.
 * @param data The chart data. `data.values` is a 2d array or a 2d tensor.
 *     `data.xTickLabels` / `data.yTickLabels` are optional and, when given,
 *     must have as many entries as the first / second dimension of the
 *     values that get plotted.
 * @param opts Rendering options, merged over `defaultOpts`. `width` and
 *     `height` fall back to the draw area's client size. `colorMap` accepts
 *     'blues' and 'greyscale'; anything else renders with 'viridis'.
 *     `domain`, when set, is used as the domain of the fill scale.
 * @returns A promise that resolves once the chart has been embedded.
 *     A failed `assert` (non 2d input, mismatched tick labels) surfaces as a
 *     rejection of this promise.
 */
/** @doc {heading: 'Charts', namespace: 'render'} */
export async function heatmap(
    container: Drawable, data: HeatmapData,
    opts: HeatmapOptions = {}): Promise<void> {
  const options = Object.assign({}, defaultOpts, opts);
  const drawArea = getDrawArea(container);

  let inputValues = data.values;
  if (options.rowMajor) {
    let originalShape: number[];
    let transposed: tf.Tensor2D;
    if (inputValues instanceof tf.Tensor) {
      originalShape = inputValues.shape;
      transposed = inputValues.transpose();
    } else {
      originalShape = [inputValues.length, inputValues[0].length];
      transposed =
          tf.tidy(() => tf.tensor2d(inputValues as number[][]).transpose());
    }

    // Validates, downloads and disposes the intermediate tensor. The tensor
    // is released even when validation or the download fails.
    inputValues = await downloadAndDispose(transposed);

    const transposedShape = [inputValues.length, inputValues[0].length];
    assert(
        originalShape[0] === transposedShape[1] &&
            originalShape[1] === transposedShape[0],
        `Unexpected transposed shape. Original ${originalShape} : ` +
            `Transposed ${transposedShape}`);
  }

  // Format data for vega spec; an array of objects, one for each cell
  // in the matrix.
  const values: MatrixEntry[] = [];
  const {xTickLabels, yTickLabels} = data;

  // These two branches are very similar but we want to do the test once
  // rather than on every element access
  if (inputValues instanceof tf.Tensor) {
    assert(
        inputValues.rank === 2,
        'Input to renderHeatmap must be a 2d array or Tensor2d');

    const shape = inputValues.shape;
    if (xTickLabels) {
      assert(
          shape[0] === xTickLabels.length,
          `Length of xTickLabels (${xTickLabels.length}) must match ` +
              `number of rows (${shape[0]})`);
    }

    if (yTickLabels) {
      assert(
          shape[1] === yTickLabels.length,
          `Length of yTickLabels (${yTickLabels.length}) must match ` +
              `number of columns (${shape[1]})`);
    }

    // This is a slightly specialized version of TensorBuffer.get, inlining it
    // avoids the overhead of a function call per data element access and is
    // specialized to only deal with the 2d case.
    const inputArray = await inputValues.data();
    const [numRows, numCols] = shape;

    for (let row = 0; row < numRows; row++) {
      const x = xTickLabels ? xTickLabels[row] : row;
      for (let col = 0; col < numCols; col++) {
        const y = yTickLabels ? yTickLabels[col] : col;

        // Flat offset of (row, col) in the downloaded buffer.
        const index = (row * numCols) + col;
        const value = inputArray[index];

        values.push({x, y, value});
      }
    }
  } else {
    if (xTickLabels) {
      assert(
          inputValues.length === xTickLabels.length,
          `Number of rows (${inputValues.length}) must match ` +
              `number of xTickLabels (${xTickLabels.length})`);
    }

    const inputArray = inputValues as number[][];
    for (let row = 0; row < inputArray.length; row++) {
      const x = xTickLabels ? xTickLabels[row] : row;
      if (yTickLabels) {
        // Checked per row because nested arrays may be ragged.
        assert(
            inputValues[row].length === yTickLabels.length,
            `Number of columns in row ${row} (${inputValues[row].length}) ` +
                `must match length of yTickLabels (${yTickLabels.length})`);
      }
      for (let col = 0; col < inputArray[row].length; col++) {
        const y = yTickLabels ? yTickLabels[col] : col;
        const value = inputArray[row][col];
        values.push({x, y, value});
      }
    }
  }

  const embedOpts = {
    actions: false,
    mode: 'vega-lite' as Mode,
    defaultStyle: false,
  };

  const spec: VisualizationSpec = {
    'width': options.width || drawArea.clientWidth,
    'height': options.height || drawArea.clientHeight,
    'padding': 0,
    'autosize': {
      'type': 'fit',
      'contains': 'padding',
      'resize': true,
    },
    'config': {
      'axis': {
        'labelFontSize': options.fontSize,
        'titleFontSize': options.fontSize,
      },
      'text': {'fontSize': options.fontSize},
      'legend': {
        'labelFontSize': options.fontSize,
        'titleFontSize': options.fontSize,
      },
      'scale': {'bandPaddingInner': 0, 'bandPaddingOuter': 0},
    },
    'data': {'values': values},
    'mark': 'rect',
    'encoding': {
      'x': {
        'field': 'x',
        'type': options.xType,
        // Maintain sort order of the axis if labels is passed in
        'scale': {'domain': xTickLabels},
        'title': options.xLabel,
      },
      'y': {
        'field': 'y',
        'type': options.yType,
        // Maintain sort order of the axis if labels is passed in
        'scale': {'domain': yTickLabels},
        'title': options.yLabel,
      },
      'fill': {
        'field': 'value',
        'type': 'quantitative',
      },
    }
  };

  let colorRange: string[]|string;
  switch (options.colorMap) {
    case 'blues':
      colorRange = ['#f7fbff', '#4292c6'];
      break;
    case 'greyscale':
      colorRange = ['#000000', '#ffffff'];
      break;
    case 'viridis':
    default:
      colorRange = 'viridis';
      break;
  }

  // No explicit range is written for 'viridis'; the fill scale is only
  // overridden for the other colour maps.
  if (colorRange !== 'viridis') {
    const fill = spec.encoding!.fill;
    // @ts-ignore
    fill.scale = {'range': colorRange};
  }

  if (options.domain) {
    const fill = spec.encoding!.fill;
    // Merge with any range set above; assigning from an undefined scale
    // simply yields {'domain': ...}.
    // @ts-ignore
    fill.scale = Object.assign({}, fill.scale, {'domain': options.domain});
  }

  await embed(drawArea, spec, embedOpts);
}

/**
 * Checks that `tensor` is 2d, downloads its values as a nested array and
 * disposes it. The tensor is disposed on every path, including a failed
 * assertion or a failed download, so callers never have to clean it up.
 */
async function downloadAndDispose(tensor: tf.Tensor2D): Promise<number[][]> {
  try {
    assert(
        tensor.rank === 2,
        'Input to renderHeatmap must be a 2d array or Tensor2d');
    return await tensor.array();
  } finally {
    tensor.dispose();
  }
}

// Defaults applied to every heatmap; caller supplied options win.
const defaultOpts = {
  xLabel: null,
  yLabel: null,
  xType: 'ordinal',
  yType: 'ordinal',
  colorMap: 'viridis',
  fontSize: 12,
  domain: null,
  rowMajor: false,
};

// One cell of the heatmap in the form handed to the vega-lite spec.
interface MatrixEntry {
  x: string|number;
  y: string|number;
  value: number;
}

--------------------------------------------------------------------------------
/tfjs-vis/src/render/heatmap_test.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import * as tf from '@tensorflow/tfjs';

import {HeatmapData} from '../types';

import {heatmap} from './heatmap';

// Tests for render.heatmap. Every test renders into a fresh #container
// element and inspects the DOM that vega-embed produces.
describe('renderHeatmap', () => {
  let pixelRatio: number;

  beforeEach(() => {
    // Reset the page so each test starts with a single empty container.
    // NOTE(review): the markup was stripped from the dump this file was
    // recovered from; it is reconstructed from the getElementById('container')
    // lookups below.
    document.body.innerHTML = '<div id="container"></div>';
    pixelRatio = window.devicePixelRatio;
  });

  it('renders a chart', async () => {
    const data: HeatmapData = {
      values: [[4, 2, 8], [1, 7, 2], [3, 3, 20]],
    };

    const container = document.getElementById('container') as HTMLElement;
    await heatmap(container, data);

    expect(document.querySelectorAll('.vega-embed').length).toBe(1);
  });

  it('renders a chart with rowMajor=true', async () => {
    const data: HeatmapData = {
      values: [[4, 2, 8], [1, 7, 2], [3, 3, 20], [8, 2, 8]],
    };

    // The transpose done for rowMajor must not leave tensors behind.
    const numTensorsBefore = tf.memory().numTensors;

    const container = document.getElementById('container') as HTMLElement;
    await heatmap(container, data, {rowMajor: true});

    const numTensorsAfter = tf.memory().numTensors;

    expect(document.querySelectorAll('.vega-embed').length).toBe(1);
    expect(numTensorsAfter).toEqual(numTensorsBefore);
  });

  it('renders a chart with rowMajor=true and custom labels', async () => {
    // 4x3 input: after the transpose there are 3 x entries and 4 y entries.
    const data: HeatmapData = {
      values: [[4, 2, 8], [1, 7, 2], [3, 3, 20], [8, 2, 8]],
      xTickLabels: ['alpha', 'beta', 'gamma'],
      yTickLabels: ['first', 'second', 'third', 'fourth'],
    };

    const numTensorsBefore = tf.memory().numTensors;

    const container = document.getElementById('container') as HTMLElement;
    await heatmap(container, data, {rowMajor: true});

    const numTensorsAfter = tf.memory().numTensors;

    expect(document.querySelectorAll('.vega-embed').length).toBe(1);
    expect(numTensorsAfter).toEqual(numTensorsBefore);
  });

  it('renders a chart with a tensor', async () => {
    const values = tf.tensor2d([[4, 2, 8], [1, 7, 2], [3, 3, 20]]);
    const data: HeatmapData = {
      values,
    };

    const container = document.getElementById('container') as HTMLElement;
    await heatmap(container, data);

    expect(document.querySelectorAll('.vega-embed').length).toBe(1);

    values.dispose();
  });

  it('throws an exception with a non 2d tensor', async () => {
    const values = tf.tensor1d([4, 2, 8, 1, 7, 2, 3, 3, 20]);
    const data = {
      values,
    };

    const container = document.getElementById('container') as HTMLElement;

    let threw = false;
    try {
      // The container goes first so that the failure comes from the rank
      // check on the data rather than from an invalid drawing surface.
      // @ts-ignore — passing in the wrong datatype
      await heatmap(container, data);
    } catch (e) {
      threw = true;
    } finally {
      values.dispose();
    }
    expect(threw).toBe(true);
  });

  it('renders a chart with custom colormap', async () => {
    const data: HeatmapData = {
      values: [[4, 2, 8], [1, 7, 2], [3, 3, 20]],
    };

    const container = document.getElementById('container') as HTMLElement;
    await heatmap(container, data, {colorMap: 'greyscale'});

    expect(document.querySelectorAll('.vega-embed').length).toBe(1);
  });

  it('renders a chart with custom domain', async () => {
    const data: HeatmapData = {
      values: [[4, 2, 8], [1, 7, 2], [3, 3, 20]],
    };

    const container = document.getElementById('container') as HTMLElement;
    await heatmap(container, data, {domain: [0, 30]});

    expect(document.querySelectorAll('.vega-embed').length).toBe(1);
  });

  it('renders a chart with custom labels', async () => {
    const data: HeatmapData = {
      values: [[4, 2, 8], [1, 7, 2], [3, 3, 20]],
      xTickLabels: ['cheese', 'pig', 'font'],
      yTickLabels: ['speed', 'dexterity', 'roundness'],
    };

    const container = document.getElementById('container') as HTMLElement;
    await heatmap(container, data);

    expect(document.querySelectorAll('.vega-embed').length).toBe(1);
  });

  it('updates the chart', async () => {
    let data: HeatmapData = {
      values: [[4, 2, 8], [1, 7, 2], [3, 3, 20]],
    };

    const container = document.getElementById('container') as HTMLElement;

    await heatmap(container, data);
    expect(document.querySelectorAll('.vega-embed').length).toBe(1);

    data = {
      values: [[43, 2, 8], [1, 7, 2], [3, 3, 20]],
    };

    // Rendering again into the same container must replace, not append.
    await heatmap(container, data);
    expect(document.querySelectorAll('.vega-embed').length).toBe(1);
  });

  it('sets width of chart', async () => {
    const data: HeatmapData = {
      values: [[4, 2, 8], [1, 7, 2], [3, 3, 20]],
    };

    const container = document.getElementById('container') as HTMLElement;
    await heatmap(container, data, {width: 400});

    expect(document.querySelectorAll('.vega-embed').length).toBe(1);
    expect(document.querySelectorAll('canvas').length).toBe(1);
    // The canvas backing store is scaled by the device pixel ratio.
    expect(document.querySelector('canvas')!.width).toBe(400 * pixelRatio);
  });

  it('sets height of chart', async () => {
    const data: HeatmapData = {
      values: [[4, 2, 8], [1, 7, 2], [3, 3, 20]],
    };

    const container = document.getElementById('container') as HTMLElement;
    await heatmap(container, data, {height: 200});

    expect(document.querySelectorAll('.vega-embed').length).toBe(1);
    expect(document.querySelectorAll('canvas').length).toBe(1);
    // The canvas backing store is scaled by the device pixel ratio.
    expect(document.querySelector('canvas')!.height).toBe(200 * pixelRatio);
  });
});

--------------------------------------------------------------------------------
/tfjs-vis/src/render/histogram.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {format as d3Format} from 'd3-format';
import embed, {Mode, VisualizationSpec} from 'vega-embed';

import {Drawable, HistogramOpts, HistogramStats, TypedArray} from '../types';
import {subSurface} from '../util/dom';
import {arrayStats} from '../util/math';

import {table} from './table';

// Defaults applied to every histogram; caller supplied options win.
const defaultOpts = {
  maxBins: 12,
  fontSize: 11,
};

/**
 * Renders a histogram of values
 *
 * Unless `opts.stats` is `false`, a one-row table of summary statistics is
 * rendered in a sub-surface placed before the chart. The statistics are
 * taken from `opts.stats` when provided and computed with `arrayStats`
 * otherwise. NaN, infinite and null values are counted in the statistics but
 * removed before the bars are drawn.
 *
 * ```js
 * const data = Array(100).fill(0)
 *   .map(x => Math.random() * 100 - (Math.random() * 50))
 *
 * // Push some special values for the stats table.
 * data.push(Infinity);
 * data.push(NaN);
 * data.push(0);
 *
 * const surface = { name: 'Histogram', tab: 'Charts' };
 * tfvis.render.histogram(surface, data);
 * ```
 *
 * @param container The surface to draw on; the chart and the stats table
 *     are each rendered into their own sub-surface of it.
 * @param data The values to plot: plain numbers, a typed array, or objects
 *     with a numeric `value` field.
 * @param opts Rendering options, merged over `defaultOpts`. `width` and
 *     `height` fall back to the client size of the histogram sub-surface.
 * @returns The promise returned by `embed` for the histogram chart.
 * @throws Error if `data` has no `length`, or if its first element is an
 *     object without a `value` field.
 */
/** @doc {heading: 'Charts', namespace: 'render'} */
export async function histogram(
    container: Drawable, data: Array<{value: number}>|number[]|TypedArray,
    opts: HistogramOpts = {}) {
  const values = prepareData(data);

  const options = Object.assign({}, defaultOpts, opts);

  const embedOpts = {
    actions: false,
    mode: 'vega-lite' as Mode,
    defaultStyle: false,
  };

  const histogramContainer = subSurface(container, 'histogram');
  if (opts.stats !== false) {
    // Prepended so the stats table sits ahead of the chart sub-surface.
    const statsContainer = subSurface(container, 'stats', {
      prepend: true,
    });
    let stats: HistogramStats;

    if (opts.stats) {
      stats = opts.stats;
    } else {
      stats = arrayStats(values.map(x => x.value));
    }
    renderStats(stats, statsContainer, {fontSize: options.fontSize});
  }

  // Now that we have rendered stats we need to remove any NaNs and Infinities
  // before rendering the histogram
  const filtered =
      values.filter(entry => entry.value != null && isFinite(entry.value));

  const histogramSpec: VisualizationSpec = {

    'width': options.width || histogramContainer.clientWidth,
    'height': options.height || histogramContainer.clientHeight,
    'padding': 0,
    'autosize': {
      'type': 'fit',
      'contains': 'padding',
      'resize': true,
    },
    'data': {'values': filtered},
    'mark': 'bar',
    'config': {
      'axis': {
        'labelFontSize': options.fontSize,
        'titleFontSize': options.fontSize,
      },
      'text': {'fontSize': options.fontSize},
      'legend': {
        'labelFontSize': options.fontSize,
        'titleFontSize': options.fontSize,
      }
    },
    'encoding': {
      'x': {
        'bin': {'maxbins': options.maxBins},
        'field': 'value',
        'type': 'quantitative',
      },
      'y': {
        'aggregate': 'count',
        'type': 'quantitative',
      },
      'color': {
        // TODO extract to theme?
        'value': '#001B44',
      }
    }
  };

  return embed(histogramContainer, histogramSpec, embedOpts);
}

/**
 * Renders the summary statistics as a table with a single row.
 *
 * Only the statistics present on `stats` get a column. Zero, NaN and
 * Infinity counts are followed by their share of `numVals` when both the
 * count and `numVals` are non-zero.
 */
function renderStats(
    stats: HistogramStats, container: HTMLElement, opts: {fontSize: number}) {
  const format = d3Format(',.4~f');
  const pctFormat = d3Format('.4~p');

  const headers: string[] = [];
  const vals: string[] = [];

  // Adds a column holding a plain formatted number.
  const addValue = (header: string, value: number|undefined) => {
    if (value != null) {
      headers.push(header);
      vals.push(format(value));
    }
  };

  // Adds a column holding a count followed by its percentage of numVals.
  // The separating space is emitted even when there is no percentage, which
  // keeps the cell text identical to what was rendered before.
  const addCount = (header: string, count: number|undefined) => {
    if (count == null) {
      return;
    }
    headers.push(header);
    let pct = '';
    if (stats.numVals) {
      pct = count > 0 ? `(${pctFormat(count / stats.numVals)})` : '';
    }
    vals.push(`${format(count)} ${pct}`);
  };

  addValue('Num Vals', stats.numVals);
  addValue('Min', stats.min);
  addValue('Max', stats.max);
  addCount('# Zeros', stats.numZeros);
  addCount('# NaNs', stats.numNans);
  addCount('# Infinity', stats.numInfs);

  table(container, {headers, values: [vals]}, opts);
}

/**
 * Formats data to the internal format used by this chart.
 *
 * Numbers are wrapped as `{value}` objects; arrays of objects are returned
 * as they are. Only the first element is inspected to decide which of the
 * two forms the input has.
 *
 * @throws Error if `data` has no `length`, or if the first element is an
 *     object (or null) that carries no `value` field.
 */
function prepareData(data: Array<{value: number}>|number[]|
                     TypedArray): Array<{value: number}> {
  if (data.length == null) {
    throw new Error('input data must be an array');
  }

  if (data.length === 0) {
    return [];
  } else if (typeof data[0] === 'object') {
    // typeof null is also 'object', so guard against a null element before
    // reading its value field.
    const first = data[0] as {value: number} | null;
    if (first == null || first.value == null) {
      throw new Error('input data must have a value field');
    } else {
      return data as Array<{value: number}>;
    }
  } else {
    const ret = Array(data.length);
    for (let i = 0; i < data.length; i++) {
      ret[i] = {value: data[i]};
    }
    return ret;
  }
}

--------------------------------------------------------------------------------
/tfjs-vis/src/render/histogram_test.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
* =============================================================================
 */

import {HistogramStats} from '../types';

import {histogram} from './histogram';

// Tests for render.histogram. Every test renders into a fresh #container
// element and inspects the chart and the stats table in the DOM.
describe('renderHistogram', () => {
  let pixelRatio: number;

  beforeEach(() => {
    // Reset the page so each test starts with a single empty container.
    // NOTE(review): the markup was stripped from the dump this file was
    // recovered from; it is reconstructed from the getElementById('container')
    // lookups below.
    document.body.innerHTML = '<div id="container"></div>';
    pixelRatio = window.devicePixelRatio;
  });

  it('renders a histogram', async () => {
    const data = [
      {value: 50},
      {value: 100},
      {value: 100},
    ];

    const container = document.getElementById('container') as HTMLElement;
    await histogram(container, data);

    expect(document.querySelectorAll('.vega-embed').length).toBe(1);
    expect(document.querySelectorAll('table').length).toBe(1);
    expect(document.querySelectorAll('table thead tr').length).toBe(1);
    expect(document.querySelectorAll('table thead th').length).toBe(6);
    expect(document.querySelectorAll('table tbody tr').length).toBe(1);
    expect(document.querySelectorAll('table tbody td').length).toBe(6);
  });

  it('renders a histogram with number array', async () => {
    const data = [50, 100, 100];

    const container = document.getElementById('container') as HTMLElement;
    await histogram(container, data);

    expect(document.querySelectorAll('.vega-embed').length).toBe(1);
    expect(document.querySelectorAll('table').length).toBe(1);
    expect(document.querySelectorAll('table thead tr').length).toBe(1);
    expect(document.querySelectorAll('table thead th').length).toBe(6);
    expect(document.querySelectorAll('table tbody tr').length).toBe(1);
    expect(document.querySelectorAll('table tbody td').length).toBe(6);
  });

  it('renders a histogram with typed array', async () => {
    const data = new Int32Array([50, 100, 100]);

    const container = document.getElementById('container') as HTMLElement;
    await histogram(container, data);

    expect(document.querySelectorAll('.vega-embed').length).toBe(1);
    expect(document.querySelectorAll('table').length).toBe(1);
    expect(document.querySelectorAll('table thead tr').length).toBe(1);
    expect(document.querySelectorAll('table thead th').length).toBe(6);
    expect(document.querySelectorAll('table tbody tr').length).toBe(1);
    expect(document.querySelectorAll('table tbody td').length).toBe(6);
  });

  it('re-renders a histogram', async () => {
    const data = [
      {value: 50},
      {value: 100},
      {value: 100},
    ];

    const container = document.getElementById('container') as HTMLElement;

    await histogram(container, data);
    expect(document.querySelectorAll('.vega-embed').length).toBe(1);

    // Rendering again into the same container must replace, not append.
    await histogram(container, data);
    expect(document.querySelectorAll('.vega-embed').length).toBe(1);
  });

  it('updates a histogram chart', async () => {
    let data = [
      {value: 50},
      {value: 100},
      {value: 100},
    ];

    const container = document.getElementById('container') as HTMLElement;

    await histogram(container, data);
    expect(document.querySelectorAll('.vega-embed').length).toBe(1);

    data = [
      {value: 150},
      {value: 100},
      {value: 150},
    ];

    await histogram(container, data);
    expect(document.querySelectorAll('.vega-embed').length).toBe(1);
  });

  it('renders correct stats', async () => {
    const data = [
      {value: 50},
      {value: -100},
      {value: 200},
      {value: 0},
      {value: 0},
      {value: NaN},
      {value: NaN},
      {value: NaN},
    ];

    const container = document.getElementById('container') as HTMLElement;
    await histogram(container, data);

    expect(document.querySelectorAll('.vega-embed').length).toBe(1);
    expect(document.querySelectorAll('table').length).toBe(1);
    expect(document.querySelectorAll('table tbody tr').length).toBe(1);

    // Cells in order: num vals, min, max, zeros, NaNs, infinities.
    const statsEls = document.querySelectorAll('table tbody td');
    expect(statsEls.length).toBe(6);
    expect(statsEls[0].textContent).toEqual('8');
    expect(statsEls[1].textContent).toEqual('-100');
    expect(statsEls[2].textContent).toEqual('200');
    expect(statsEls[3].textContent).toEqual('2 (25%)');
    expect(statsEls[4].textContent).toEqual('3 (37.5%)');
  });

  it('does not throw on empty data', async () => {
    const data: Array<{value: number}> = [];

    const container = document.getElementById('container') as HTMLElement;

    // histogram is async, so a failure shows up as a rejected promise rather
    // than a synchronous throw. Await it so that a rejection is observed and
    // so that the DOM checks below run after rendering has finished.
    let threw = false;
    try {
      await histogram(container, data);
    } catch (e) {
      threw = true;
    }
    expect(threw).toBe(false);

    expect(document.querySelectorAll('.vega-embed').length).toBe(1);
    expect(document.querySelectorAll('table').length).toBe(1);
    expect(document.querySelectorAll('table thead tr').length).toBe(1);
    expect(document.querySelectorAll('table thead th').length).toBe(3);
  });

  it('renders custom stats', async () => {
    const data = [
      {value: 50},
    ];

    // Deliberately unrelated to the data: the table must show these values.
    const stats: HistogramStats = {
      numVals: 200,
      min: -30,
      max: 140,
      numZeros: 2,
      numNans: 5,
    };

    const container = document.getElementById('container') as HTMLElement;
    await histogram(container, data, {stats});

    expect(document.querySelectorAll('.vega-embed').length).toBe(1);
    expect(document.querySelectorAll('table').length).toBe(1);
    expect(document.querySelectorAll('table tbody tr').length).toBe(1);

    const statsEls = document.querySelectorAll('table tbody td');
    expect(statsEls.length).toBe(5);
    expect(statsEls[0].textContent).toEqual('200');
    expect(statsEls[1].textContent).toEqual('-30');
    expect(statsEls[2].textContent).toEqual('140');
    expect(statsEls[3].textContent).toEqual('2 (1%)');
    expect(statsEls[4].textContent).toEqual('5 (2.5%)');
  });

  it('sets width of chart', async () => {
    const data = [
      {value: 50},
      {value: 100},
      {value: 230},
    ];

    const container = document.getElementById('container') as HTMLElement;
    await histogram(container, data, {width: 400});

    expect(document.querySelectorAll('.vega-embed').length).toBe(1);
    expect(document.querySelectorAll('canvas').length).toBe(1);
    // The canvas backing store is scaled by the device pixel ratio.
    expect(document.querySelector('canvas')!.width).toBe(400 * pixelRatio);
  });

  it('sets height of chart', async () => {
    const data = [
      {value: 50},
      {value: 100},
      {value: 230},
    ];

    const container = document.getElementById('container') as HTMLElement;
    await histogram(container, data, {height: 200});

    expect(document.querySelectorAll('.vega-embed').length).toBe(1);
    expect(document.querySelectorAll('canvas').length).toBe(1);
    // The canvas backing store is scaled by the device pixel ratio.
    expect(document.querySelector('canvas')!.height).toBe(200 * pixelRatio);
  });
});

--------------------------------------------------------------------------------
/tfjs-vis/src/render/linechart.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
/**
 * Renders a line chart
 *
 * ```js
 * const series1 = Array(100).fill(0)
 *   .map(y => Math.random() * 100 - (Math.random() * 50))
 *   .map((y, x) => ({ x, y, }));
 *
 * const series2 = Array(100).fill(0)
 *   .map(y => Math.random() * 100 - (Math.random() * 150))
 *   .map((y, x) => ({ x, y, }));
 *
 * const series = ['First', 'Second'];
 * const data = { values: [series1, series2], series }
 *
 * const surface = { name: 'Line chart', tab: 'Charts' };
 * tfvis.render.linechart(surface, data);
 * ```
 *
 * ```js
 * const series1 = Array(100).fill(0)
 *   .map(y => Math.random() * 100 + 50)
 *   .map((y, x) => ({ x, y, }));
 *
 * const data = { values: [series1] }
 *
 * // Render to visor
 * const surface = { name: 'Zoomed Line Chart', tab: 'Charts' };
 * tfvis.render.linechart(surface, data, { zoomToFit: true });
 * ```
 *
 * @param container The drawable to render into; it is resolved to a DOM
 *     element with `getDrawArea`.
 * @param data The points to plot. `data.values` is either a single series
 *     (an array of points) or an array of series. `data.series` optionally
 *     names each series; unnamed series are labelled `Series N` (1-based).
 * @param opts Overrides for the defaults in `defaultOpts`, plus the optional
 *     `width`, `height` and `yAxisDomain` settings read below.
 * @returns A promise that resolves once vega-embed has finished rendering.
 */
/** @doc {heading: 'Charts', namespace: 'render'} */
export async function linechart(
    container: Drawable, data: XYPlotData,
    opts: XYPlotOptions = {}): Promise<void> {
  const seriesLabels = data.series == null ? [] : data.series;

  // Nest data if necessary before further processing
  const nestedValues =
      (Array.isArray(data.values[0]) ? data.values : [data.values]) as
      Point2D[][];

  // Flatten every series into one list of points, tagging each point with the
  // name of the series it belongs to so vega-lite can group/colour by it.
  const values: Point2D[] = [];
  const seriesNames = new Set<string>();
  nestedValues.forEach((seriesData, i) => {
    const seriesName: string =
        seriesLabels[i] != null ? seriesLabels[i] : `Series ${i + 1}`;
    seriesNames.add(seriesName);
    // Push point by point: spreading a whole series into push() can exceed
    // the engine's argument limit for very long series.
    seriesData.forEach(v => {
      values.push(Object.assign({}, v, {series: seriesName}));
    });
  });

  const drawArea = getDrawArea(container);
  const options = Object.assign({}, defaultOpts, opts);

  const embedOpts = {
    actions: false,
    mode: 'vega-lite' as Mode,
    defaultStyle: false,
  };

  // zoomToFit takes precedence over an explicit yAxisDomain.
  const yScale = (): {}|undefined => {
    if (options.zoomToFit) {
      return {'zero': false};
    } else if (options.yAxisDomain != null) {
      return {'domain': options.yAxisDomain};
    }
    return undefined;
  };

  // tslint:disable-next-line:no-any
  const encodings: any = {
    'x': {
      'field': 'x',
      'type': options.xType,
      'title': options.xLabel,
    },
    'y': {
      'field': 'y',
      'type': options.yType,
      'title': options.yLabel,
      'scale': yScale(),
    },
    'color': {
      'field': 'series',
      'type': 'nominal',
      'legend': {'values': Array.from(seriesNames)}
    },
  };

  // tslint:disable-next-line:no-any
  let domainFilter: any;
  if (options.yAxisDomain != null) {
    domainFilter = {'filter': {'field': 'y', 'range': options.yAxisDomain}};
  }

  const spec: VisualizationSpec = {

    'width': options.width || drawArea.clientWidth,
    'height': options.height || drawArea.clientHeight,
    'padding': 0,
    'autosize': {
      'type': 'fit',
      'contains': 'padding',
      'resize': true,
    },
    'config': {
      'axis': {
        'labelFontSize': options.fontSize,
        'titleFontSize': options.fontSize,
      },
      'text': {'fontSize': options.fontSize},
      'legend': {
        'labelFontSize': options.fontSize,
        'titleFontSize': options.fontSize,
      }
    },
    'data': {'values': values},
    'layer': [
      {
        // Render the main line chart
        'mark': {
          'type': 'line',
          'clip': true,
        },
        'encoding': encodings,
      },
      {
        // Render invisible points for all the data to make selections
        // easier
        'mark': {'type': 'point'},
        // If a custom domain is set, filter out the values that will not
        // fit. We do this on the points and not the line so that the line
        // still appears clipped for values outside the domain but we can
        // still operate on an unclipped set of points.
        'transform': domainFilter != null ? [domainFilter] : undefined,
        'selection': {
          'nearestPoint': {
            'type': 'single',
            'on': 'mouseover',
            'nearest': true,
            'empty': 'none',
            'encodings': ['x'],
          },
        },
        'encoding': Object.assign({}, encodings, {
          'opacity': {
            'value': 0,
            'condition': {
              'selection': 'nearestPoint',
              'value': 1,
            },
          }
        }),
      },
      {
        // Render a tooltip where the selection is
        'transform': [
          {'filter': {'selection': 'nearestPoint'}},
          domainFilter
        ].filter(Boolean),  // remove undefineds from array
        'mark': {
          'type': 'text',
          'align': 'left',
          'dx': 5,
          'dy': -5,
          'color': 'black',
        },
        'encoding': Object.assign({}, encodings, {
          'text': {
            // The tooltip displays the y value, so it must use the y type.
            'type': options.yType,
            'field': 'y',
            'format': '.6f',
          },
          // Unset text color to improve readability
          'color': undefined,
        }),
      },
      {
        // Draw a vertical line where the selection is
        'transform': [{'filter': {'selection': 'nearestPoint'}}],
        'mark': {'type': 'rule', 'color': 'gray'},
        'encoding': {
          'x': {
            'type': options.xType,
            'field': 'x',
          }
        }
      },
    ],
  };

  await embed(drawArea, spec, embedOpts);
}

// Defaults merged underneath the caller-supplied options in linechart().
const defaultOpts = {
  xLabel: 'x',
  yLabel: 'y',
  xType: 'quantitative',
  yType: 'quantitative',
  zoomToFit: false,
  fontSize: 11,
};
'; 25 | pixelRatio = window.devicePixelRatio; 26 | }); 27 | 28 | it('renders a line chart', async () => { 29 | const data = { 30 | values: [ 31 | {x: 0, y: 50}, 32 | {x: 1, y: 100}, 33 | {x: 2, y: 230}, 34 | ] 35 | }; 36 | 37 | const container = document.getElementById('container') as HTMLElement; 38 | await linechart(container, data); 39 | 40 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 41 | }); 42 | 43 | it('renders a line chart with multiple series', async () => { 44 | const data = { 45 | values: [ 46 | [ 47 | {x: 0, y: 50}, 48 | {x: 1, y: 100}, 49 | {x: 2, y: 230}, 50 | ], 51 | [ 52 | {x: 0, y: 20}, 53 | {x: 1, y: 300}, 54 | {x: 2, y: 630}, 55 | ], 56 | ] 57 | }; 58 | 59 | const container = document.getElementById('container') as HTMLElement; 60 | 61 | await linechart(container, data); 62 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 63 | }); 64 | 65 | it('renders a line chart with multiple series custom names', async () => { 66 | const data = { 67 | values: [ 68 | [ 69 | {x: 0, y: 50}, 70 | {x: 1, y: 100}, 71 | {x: 2, y: 230}, 72 | ], 73 | [ 74 | {x: 0, y: 20}, 75 | {x: 1, y: 300}, 76 | {x: 2, y: 630}, 77 | ], 78 | ], 79 | series: ['First', 'Second'], 80 | }; 81 | 82 | const container = document.getElementById('container') as HTMLElement; 83 | 84 | await linechart(container, data); 85 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 86 | }); 87 | 88 | it('updates a line chart', async () => { 89 | let data = { 90 | values: [ 91 | [ 92 | {x: 0, y: 50}, 93 | {x: 1, y: 100}, 94 | {x: 2, y: 230}, 95 | ], 96 | [ 97 | {x: 0, y: 20}, 98 | {x: 1, y: 300}, 99 | {x: 2, y: 630}, 100 | ], 101 | ], 102 | series: ['First', 'Second'], 103 | }; 104 | 105 | const container = document.getElementById('container') as HTMLElement; 106 | 107 | await linechart(container, data); 108 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 109 | 110 | data = { 111 | values: [ 112 | [ 113 | {x: 0, y: 50}, 114 | {x: 
1, y: 100}, 115 | {x: 2, y: 230}, 116 | ], 117 | [ 118 | {x: 0, y: 20}, 119 | {x: 1, y: 300}, 120 | {x: 2, y: 630}, 121 | {x: 3, y: 530}, 122 | {x: 4, y: 230}, 123 | ], 124 | ], 125 | series: ['First', 'Second'], 126 | }; 127 | 128 | await linechart(container, data); 129 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 130 | }); 131 | 132 | it('sets width of chart', async () => { 133 | const data = { 134 | values: [ 135 | {x: 0, y: 50}, 136 | {x: 1, y: 100}, 137 | {x: 2, y: 230}, 138 | ] 139 | }; 140 | 141 | const container = document.getElementById('container') as HTMLElement; 142 | await linechart(container, data, {width: 400}); 143 | 144 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 145 | expect(document.querySelectorAll('canvas').length).toBe(1); 146 | expect(document.querySelector('canvas')!.width).toBe(400 * pixelRatio); 147 | }); 148 | 149 | it('sets height of chart', async () => { 150 | const data = { 151 | values: [ 152 | {x: 0, y: 50}, 153 | {x: 1, y: 100}, 154 | {x: 2, y: 230}, 155 | ] 156 | }; 157 | 158 | const container = document.getElementById('container') as HTMLElement; 159 | await linechart(container, data, {height: 200}); 160 | 161 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 162 | expect(document.querySelectorAll('canvas').length).toBe(1); 163 | expect(document.querySelector('canvas')!.height).toBe(200 * pixelRatio); 164 | }); 165 | }); 166 | -------------------------------------------------------------------------------- /tfjs-vis/src/render/render_utils.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
/**
 * Resolves a drawable to the DOM element that charts should be rendered into.
 *
 * @param drawable An `HTMLElement` (returned as is), a surface (its
 *     `drawArea` is returned), or surface info (`name`, `tab`, `styles`), in
 *     which case the surface is obtained from the visor and its `drawArea`
 *     is returned.
 * @returns The element to draw into.
 * @throws Error('Not a drawable') if the argument is none of the above.
 */
export function getDrawArea(drawable: Drawable): HTMLElement {
  if (drawable instanceof HTMLElement) {
    return drawable;
  } else if (isSurface(drawable)) {
    return drawable.drawArea;
  } else if (isSurfaceInfo(drawable)) {
    const surface = visor().surface(
        {name: drawable.name, tab: drawable.tab, styles: drawable.styles});
    return surface.drawArea;
  } else {
    throw new Error('Not a drawable');
  }
}

/**
 * Shallowly compares two objects: they are equal when they have the same own
 * property names and each property value is strictly equal (`===`).
 * Property order does not matter; nested objects are compared by reference.
 *
 * @param a First object.
 * @param b Second object.
 * @returns `true` if the objects are shallowly equal, `false` otherwise.
 */
export function shallowEquals(
    // tslint:disable-next-line:no-any
    a: {[key: string]: any}, b: {[key: string]: any}) {
  const aProps = Object.getOwnPropertyNames(a);
  const bProps = Object.getOwnPropertyNames(b);

  if (aProps.length !== bProps.length) {
    return false;
  }

  for (let i = 0; i < aProps.length; i++) {
    const prop = aProps[i];
    // The key must actually exist on b. Without this check a property whose
    // value is undefined on `a` would match a *missing* property on `b`.
    if (!Object.prototype.hasOwnProperty.call(b, prop)) {
      return false;
    }
    if (a[prop] !== b[prop]) {
      return false;
    }
  }

  return true;
}

/**
 * Returns a promise that resolves on the next animation frame
 * (via `requestAnimationFrame`).
 */
export async function nextFrame() {
  await new Promise(r => requestAnimationFrame(r));
}
// Tests for shallowEquals: property order must not matter, values are
// compared with strict equality.
describe('shallowEqual', () => {
  beforeEach(() => {
    document.body.innerHTML = '<div id="container"></div>';
  });

  it('returns true for similar objects', () => {
    const a = {
      stringProp: 'astring',
      numProp: 55,
      boolProp: true,
    };

    const b = {
      stringProp: 'astring',
      boolProp: true,
      numProp: 55,
    };

    expect(shallowEquals(a, b)).toBe(true);
  });

  it('returns false for different objects', () => {
    const a = {
      stringProp: 'astring',
      numProp: 55,
      boolProp: false,
    };

    const b = {
      stringProp: 'astring',
      numProp: 55,
      boolProp: true,
    };

    expect(shallowEquals(a, b)).toBe(false);
  });

  it('returns true for similar objects (array ref)', () => {
    // Both objects hold the *same* array instance, so === succeeds.
    // tslint:disable-next-line:no-any
    const ref: any[] = [];

    const a = {
      stringProp: 'astring',
      numProp: 55,
      refProp: ref,
    };

    const b = {
      numProp: 55,
      stringProp: 'astring',
      refProp: ref,
    };

    expect(shallowEquals(a, b)).toBe(true);
  });
});

// Tests for getDrawArea: accepts elements and surfaces, rejects anything else.
describe('getDrawArea', () => {
  beforeEach(() => {
    document.body.innerHTML = '<div id="container"></div>';
  });

  it('works with HTMLElement', () => {
    const el = document.getElementById('container') as HTMLElement;
    expect(getDrawArea(el)).toEqual(el);
  });

  it('works with a surface', () => {
    const surface = visor().surface({name: 'test'});
    expect(getDrawArea(surface)).toEqual(surface.drawArea);
  });

  it('fails with other stuff', () => {
    //@ts-ignore
    expect(() => getDrawArea('not-a-surface')).toThrow();
  });
});
/**
 * Renders a scatter plot
 *
 * ```js
 * const series1 = Array(100).fill(0)
 *   .map(y => Math.random() * 100 - (Math.random() * 50))
 *   .map((y, x) => ({ x, y, }));
 *
 * const series2 = Array(100).fill(0)
 *   .map(y => Math.random() * 100 - (Math.random() * 150))
 *   .map((y, x) => ({ x, y, }));
 *
 * const series = ['First', 'Second'];
 * const data = { values: [series1, series2], series }
 *
 * const surface = { name: 'Scatterplot', tab: 'Charts' };
 * tfvis.render.scatterplot(surface, data);
 * ```
 *
 * `data.values` may hold one series or an array of series; series without a
 * name in `data.series` are labelled `Series N` (1-based). When
 * `opts.zoomToFit` is set it takes precedence over `xAxisDomain` /
 * `yAxisDomain`.
 */
/** @doc {heading: 'Charts', namespace: 'render'} */
export async function scatterplot(
    container: Drawable, data: XYPlotData,
    opts: XYPlotOptions = {}): Promise<void> {
  const labels = data.series == null ? [] : data.series;

  // Wrap a single series so that everything below deals with Point2D[][].
  const nested =
      (Array.isArray(data.values[0]) ? data.values : [data.values]) as
      Point2D[][];

  // One flat list of points, each tagged with its series name.
  const values: Point2D[] = [];
  nested.forEach((points, idx) => {
    const label: string =
        labels[idx] != null ? labels[idx] : `Series ${idx + 1}`;
    const tagged = points.map(p => Object.assign({}, p, {series: label}));
    values.push(...tagged);
  });

  const drawArea = getDrawArea(container);
  const options = Object.assign({}, defaultOpts, opts);

  const embedOpts = {
    actions: false,
    mode: 'vega-lite' as Mode,
    defaultStyle: false,
  };

  // Scale settings for one axis: zoomToFit wins, then an explicit domain,
  // otherwise leave the scale to vega-lite's defaults.
  const scaleFor = (domain: {}|undefined): {}|undefined => {
    if (options.zoomToFit) {
      return {'zero': false};
    }
    if (domain != null) {
      return {'domain': domain};
    }
    return undefined;
  };

  const spec: VisualizationSpec = {
    'width': options.width || drawArea.clientWidth,
    'height': options.height || drawArea.clientHeight,
    'padding': 0,
    'autosize': {
      'type': 'fit',
      'contains': 'padding',
      'resize': true,
    },
    'config': {
      'axis': {
        'labelFontSize': options.fontSize,
        'titleFontSize': options.fontSize,
      },
      'text': {'fontSize': options.fontSize},
      'legend': {
        'labelFontSize': options.fontSize,
        'titleFontSize': options.fontSize,
      }
    },
    'data': {
      'values': values,
    },
    'mark': {
      'type': 'point',
      'clip': true,
      'tooltip': {'content': 'data'},
    },
    'encoding': {
      'x': {
        'field': 'x',
        'type': options.xType,
        'title': options.xLabel,
        'scale': scaleFor(options.xAxisDomain),
      },
      'y': {
        'field': 'y',
        'type': options.yType,
        'title': options.yLabel,
        'scale': scaleFor(options.yAxisDomain),
      },
      'color': {
        'field': 'series',
        'type': 'nominal',
      },
      'shape': {
        'field': 'series',
        'type': 'nominal',
      }
    },
  };

  await embed(drawArea, spec, embedOpts);
}

// Defaults merged underneath the caller-supplied options in scatterplot().
const defaultOpts = {
  xLabel: 'x',
  yLabel: 'y',
  xType: 'quantitative',
  yType: 'quantitative',
  zoomToFit: false,
  fontSize: 11,
};
'; 25 | pixelRatio = window.devicePixelRatio; 26 | }); 27 | 28 | it('renders a scatterplot', async () => { 29 | const data = { 30 | values: [ 31 | {x: 0, y: 50}, 32 | {x: 1, y: 100}, 33 | {x: 2, y: 230}, 34 | ] 35 | }; 36 | 37 | const container = document.getElementById('container') as HTMLElement; 38 | await scatterplot(container, data); 39 | 40 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 41 | }); 42 | 43 | it('renders the chart with multiple series', async () => { 44 | const data = { 45 | values: [ 46 | [ 47 | {x: 0, y: 50}, 48 | {x: 1, y: 100}, 49 | {x: 2, y: 230}, 50 | ], 51 | [ 52 | {x: 0, y: 20}, 53 | {x: 1, y: 300}, 54 | {x: 2, y: 630}, 55 | ], 56 | ] 57 | }; 58 | 59 | const container = document.getElementById('container') as HTMLElement; 60 | 61 | await scatterplot(container, data); 62 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 63 | 64 | await scatterplot(container, data); 65 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 66 | }); 67 | 68 | it('renders a line chart with multiple series custom names', async () => { 69 | const data = { 70 | values: [ 71 | [ 72 | {x: 0, y: 50}, 73 | {x: 1, y: 100}, 74 | {x: 2, y: 230}, 75 | ], 76 | [ 77 | {x: 0, y: 20}, 78 | {x: 1, y: 300}, 79 | {x: 2, y: 630}, 80 | ], 81 | ], 82 | series: ['First', 'Second'], 83 | }; 84 | 85 | const container = document.getElementById('container') as HTMLElement; 86 | 87 | await scatterplot(container, data); 88 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 89 | }); 90 | 91 | it('updates the chart', async () => { 92 | let data = { 93 | values: [ 94 | [ 95 | {x: 0, y: 50}, 96 | {x: 1, y: 100}, 97 | {x: 2, y: 230}, 98 | ], 99 | [ 100 | {x: 0, y: 20}, 101 | {x: 1, y: 300}, 102 | {x: 2, y: 630}, 103 | ], 104 | ], 105 | series: ['First', 'Second'], 106 | }; 107 | 108 | const container = document.getElementById('container') as HTMLElement; 109 | 110 | await scatterplot(container, data); 111 | 
expect(document.querySelectorAll('.vega-embed').length).toBe(1); 112 | 113 | data = { 114 | values: [ 115 | [ 116 | {x: 0, y: 50}, 117 | {x: 1, y: 100}, 118 | {x: 2, y: 230}, 119 | ], 120 | [ 121 | {x: 0, y: 20}, 122 | {x: 1, y: 300}, 123 | {x: 2, y: 630}, 124 | {x: 3, y: 530}, 125 | {x: 4, y: 230}, 126 | ], 127 | ], 128 | series: ['First', 'Second'], 129 | }; 130 | 131 | await scatterplot(container, data); 132 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 133 | }); 134 | 135 | it('sets width of chart', async () => { 136 | const data = { 137 | values: [ 138 | {x: 0, y: 50}, 139 | {x: 1, y: 100}, 140 | {x: 2, y: 230}, 141 | ] 142 | }; 143 | 144 | const container = document.getElementById('container') as HTMLElement; 145 | await scatterplot(container, data, {width: 400}); 146 | 147 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 148 | expect(document.querySelectorAll('canvas').length).toBe(1); 149 | expect(document.querySelector('canvas')!.width).toBe(400 * pixelRatio); 150 | }); 151 | 152 | it('sets height of chart', async () => { 153 | const data = { 154 | values: [ 155 | {x: 0, y: 50}, 156 | {x: 1, y: 100}, 157 | {x: 2, y: 230}, 158 | ] 159 | }; 160 | 161 | const container = document.getElementById('container') as HTMLElement; 162 | await scatterplot(container, data, {height: 200}); 163 | 164 | expect(document.querySelectorAll('.vega-embed').length).toBe(1); 165 | expect(document.querySelectorAll('canvas').length).toBe(1); 166 | expect(document.querySelector('canvas')!.height).toBe(200 * pixelRatio); 167 | }); 168 | }); 169 | -------------------------------------------------------------------------------- /tfjs-vis/src/render/table.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 
4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | import {format as d3Format} from 'd3-format'; 19 | import {select as d3Select} from 'd3-selection'; 20 | import {css} from 'glamor'; 21 | import {tachyons as tac} from 'glamor-tachyons'; 22 | import {Drawable, TableData} from '../types'; 23 | import {getDrawArea} from './render_utils'; 24 | 25 | /** 26 | * Renders a table 27 | * 28 | * ```js 29 | * const headers = [ 30 | * 'Col 1', 31 | * 'Col 2', 32 | * 'Col 3', 33 | * ]; 34 | * 35 | * const values = [ 36 | * [1, 2, 3], 37 | * ['4', '5', '6'], 38 | * ['strong>7', true, false], 39 | * ]; 40 | * 41 | * const surface = { name: 'Table', tab: 'Charts' }; 42 | * tfvis.render.table(surface, { headers, values }); 43 | * ``` 44 | * 45 | * @param opts.fontSize fontSize in pixels for text in the chart. 
/**
 * Renders a table
 *
 * ```js
 * const headers = [
 *  'Col 1',
 *  'Col 2',
 *  'Col 3',
 * ];
 *
 * const values = [
 *  [1, 2, 3],
 *  ['4', '5', '6'],
 *  ['<strong>7</strong>', true, false],
 * ];
 *
 * const surface = { name: 'Table', tab: 'Charts' };
 * tfvis.render.table(surface, { headers, values });
 * ```
 *
 * Calling this again on the same drawable updates the existing table in
 * place rather than adding a second one.
 *
 * NOTE: header and cell contents are inserted with `.html()`, i.e. as raw
 * HTML. Do not pass untrusted strings without sanitising them first.
 *
 * @param container The drawable to render into.
 * @param data Object with a `headers` array and a `values` array of rows.
 * @param opts.fontSize fontSize in pixels for text in the chart
 *     (defaults to 14).
 * @throws Error if `data` is missing or lacks `headers` / `values`.
 */
/** @doc {heading: 'Charts', namespace: 'render'} */
export function table(
    container: Drawable, data: TableData, opts: {fontSize?: number} = {}) {
  // Validate before touching the DOM so a bad call leaves no half-built table.
  if (data == null || data.headers == null) {
    throw new Error('Data to render must have a "headers" property');
  }

  if (data.values == null) {
    throw new Error('Data to render must have a "values" property');
  }

  const drawArea = getDrawArea(container);

  const options = Object.assign({}, defaultOpts, opts);

  let tableSel = d3Select(drawArea).select('table.tf-table');

  const tableStyle = css({
    ...tac('f6 w-100 mw8 center'),
    fontSize: options.fontSize,
  });

  // If a table is not already present on this element add one
  if (tableSel.size() === 0) {
    tableSel = d3Select(drawArea).append('table');

    tableSel.append('thead').append('tr');
    tableSel.append('tbody');
  }

  if (tableSel.size() !== 1) {
    throw new Error('Error inserting table');
  }

  // (Re)apply the class on every call so that a changed fontSize takes effect
  // when an existing table is updated, not only when it is first created.
  tableSel.attr('class', ` ${tableStyle} tf-table`);

  //
  // Add the header row
  //
  const headerRowStyle =
      css({...tac('fw6 bb b--black-20 tl pb3 pr3 bg-white')});
  const headers =
      tableSel.select('thead').select('tr').selectAll('th').data(data.headers);
  const headersEnter =
      headers.enter().append('th').attr('class', `${headerRowStyle}`);
  headers.merge(headersEnter).html(d => d);

  headers.exit().remove();

  //
  // Add the data rows
  //
  const format = d3Format(',.4~f');

  const rows = tableSel.select('tbody').selectAll('tr').data(data.values);
  const rowsEnter = rows.enter().append('tr');

  // Nested selection to add individual cells
  const cellStyle = css({...tac('pa1 bb b--black-20')});
  const cells = rows.merge(rowsEnter).selectAll('td').data(d => d);
  const cellsEnter = cells.enter().append('td').attr('class', `${cellStyle}`);
  // Numbers are formatted; every other value is inserted as is.
  cells.merge(cellsEnter).html(d => typeof d === 'number' ? format(d) : d);

  cells.exit().remove();
  rows.exit().remove();
}

// Defaults merged underneath the caller-supplied options in table().
const defaultOpts = {
  fontSize: 14,
};
'; 31 | }); 32 | 33 | it('renders a table', () => { 34 | const headers = [ 35 | 'Col1', 36 | 'Col 2', 37 | 'Column 3', 38 | ]; 39 | 40 | const values = [ 41 | [1, 2, 3], 42 | ['4', '5', '6'], 43 | ['7', true, false], 44 | ]; 45 | 46 | const container = document.getElementById('container') as HTMLElement; 47 | table(container, {headers, values}); 48 | 49 | expect(document.querySelectorAll('.tf-table').length).toBe(1); 50 | expect(document.querySelectorAll('.tf-table thead tr').length).toBe(1); 51 | 52 | const headerEl = document.querySelectorAll('.tf-table thead tr th'); 53 | expect(headerEl[0].innerHTML).toEqual('Col1'); 54 | expect(headerEl[1].innerHTML).toEqual('Col 2'); 55 | expect(headerEl[2].innerHTML).toEqual('Column 3'); 56 | expect(headerEl[2].textContent).toEqual('Column 3'); 57 | 58 | expect(document.querySelectorAll('.tf-table tbody tr').length).toBe(3); 59 | 60 | const rows = document.querySelectorAll('.tf-table tbody tr'); 61 | expect(getRowHTML(rows[0])).toEqual(['1', '2', '3']); 62 | expect(getRowHTML(rows[1])).toEqual(['4', '5', '6']); 63 | expect(getRowHTML(rows[2])).toEqual([ 64 | '7', 'true', 'false' 65 | ]); 66 | expect(getRowText(rows[2])).toEqual(['7', 'true', 'false']); 67 | }); 68 | 69 | it('requires necessary param', () => { 70 | const container = document.getElementById('container') as HTMLElement; 71 | 72 | // @ts-ignore 73 | expect(() => table({headers: []}, container)).toThrow(); 74 | // @ts-ignore 75 | expect(() => table({values: [[]]}, container)).toThrow(); 76 | // @ts-ignore 77 | expect(() => table({}, container)).toThrow(); 78 | }); 79 | 80 | it('should not throw on empty table', () => { 81 | const container = document.getElementById('container') as HTMLElement; 82 | const headers: string[] = []; 83 | const values: string[][] = []; 84 | 85 | expect(() => table(container, {headers, values})).not.toThrow(); 86 | 87 | expect(document.querySelectorAll('.tf-table').length).toBe(1); 88 | expect(document.querySelectorAll('.tf-table thead 
// Number of elements currently in the document that match `selector`.
function countMatches(selector: string): number {
  return document.querySelectorAll(selector).length;
}

// fitCallbacks should hand back the requested callbacks, and each callback
// should render one chart per metric that has appeared in the logs so far.
describe('fitCallbacks', () => {
  beforeEach(() => {
    document.body.innerHTML = '<div id="container"></div>';
  });

  it('returns two callbacks', async () => {
    const container = {name: 'Test'};
    const callbacks = fitCallbacks(container, ['loss', 'acc']);

    expect(typeof (callbacks.onEpochEnd)).toEqual('function');
    expect(typeof (callbacks.onBatchEnd)).toEqual('function');
  });

  it('returns one callback', async () => {
    const container = {name: 'Test'};
    const callbacks = fitCallbacks(container, ['loss', 'acc'], {
      callbacks: ['onBatchEnd'],
    });

    expect(callbacks.onEpochEnd).toEqual(undefined);
    expect(typeof (callbacks.onBatchEnd)).toEqual('function');
  });

  it('onEpochEnd callback can render logs', async () => {
    const container = {name: 'Test'};
    const callbacks =
        fitCallbacks(container, ['loss', 'val_loss', 'acc', 'val_acc']);

    const l1 = {loss: 0.5, 'val_loss': 0.7};
    const l2 = {loss: 0.2, acc: 0.6, 'val_loss': 0.5, 'val_acc': 0.3};

    await callbacks.onEpochEnd(0, l1);
    expect(countMatches('.vega-embed')).toBe(1);
    expect(countMatches('div[data-name="loss"]')).toBe(1);

    await callbacks.onEpochEnd(1, l2);
    expect(countMatches('.vega-embed')).toBe(2);
    expect(countMatches('div[data-name="loss"]')).toBe(1);
    expect(countMatches('div[data-name="acc"]')).toBe(1);
  });

  it('onBatchEnd callback can render logs', async () => {
    const container = {name: 'Test'};
    const callbacks = fitCallbacks(container, ['loss', 'acc']);

    const l1 = {loss: 0.5};
    const l2 = {loss: 0.2, acc: 0.6};

    await callbacks.onBatchEnd(0, l1);
    expect(countMatches('.vega-embed')).toBe(1);
    expect(countMatches('div[data-name="loss"]')).toBe(1);

    await callbacks.onBatchEnd(1, l2);
    expect(countMatches('.vega-embed')).toBe(2);
    expect(countMatches('div[data-name="loss"]')).toBe(1);
    expect(countMatches('div[data-name="acc"]')).toBe(1);
  });
});

// history() should accept both an array of logs and a History-like object and
// render one chart per requested metric.
describe('history', () => {
  beforeEach(() => {
    document.body.innerHTML = '<div id="container"></div>';
  });

  it('renders a logs[]', async () => {
    const container = {name: 'Test'};
    const logs = [{loss: 0.5}, {loss: 0.3}];
    const metrics = ['loss'];
    await history(container, logs, metrics);

    expect(countMatches('.vega-embed')).toBe(1);
  });

  it('renders a logs object with multiple metrics', async () => {
    const container = {name: 'Test'};
    const logs = [{loss: 0.2, acc: 0.6}, {loss: 0.1, acc: 0.65}];
    const metrics = ['loss', 'acc'];
    await history(container, logs, metrics);

    expect(countMatches('.vega-embed')).toBe(2);
  });

  it('renders a history object with multiple metrics', async () => {
    const container = {name: 'Test'};
    const hist = {
      history: {
        'loss': [0.7, 0.3, 0.2],
        'acc': [0.2, 0.3, 0.21],
      }
    };
    const metrics = ['loss', 'acc'];
    await history(container, hist, metrics);

    expect(countMatches('.vega-embed')).toBe(2);
  });

  it('can render multiple history objects', async () => {
    const container = {name: 'Test'};
    const container2 = {name: 'Other Test'};
    const hist = {
      history: {
        'loss': [0.7, 0.3, 0.2],
        'acc': [0.2, 0.3, 0.21],
      }
    };

    await history(container, hist, ['loss']);
    await history(container2, hist, ['acc']);

    expect(countMatches('.vega-embed')).toBe(2);
  });
});
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | import * as tf from '@tensorflow/tfjs'; 19 | import {Layer} from '@tensorflow/tfjs-layers/dist/engine/topology'; 20 | 21 | import {histogram} from '../render/histogram'; 22 | import {getDrawArea} from '../render/render_utils'; 23 | import {table} from '../render/table'; 24 | import {Drawable, HistogramStats} from '../types'; 25 | import {subSurface} from '../util/dom'; 26 | import {tensorStats} from '../util/math'; 27 | 28 | /** 29 | * Renders a summary of a tf.Model. Displays a table with layer information. 
/**
 * Renders a summary of a tf.Model as a table with one row per layer.
 *
 * Every row lists the layer name, its formatted output shape, its parameter
 * count and whether the layer is trainable.
 *
 * ```js
 * const model = tf.sequential({
 *  layers: [
 *    tf.layers.dense({inputShape: [784], units: 32, activation: 'relu'}),
 *    tf.layers.dense({units: 10, activation: 'softmax'}),
 *  ]
 * });
 *
 * const surface = { name: 'Model Summary', tab: 'Model Inspection'};
 * tfvis.show.modelSummary(surface, model);
 * ```
 *
 * @param container The drawable to render the table into.
 * @param model The layers model whose layers are summarised.
 */
/**
 * @doc {
 *  heading: 'Models & Tensors',
 *  subheading: 'Model Inspection',
 *  namespace: 'show'
 * }
 */
export async function modelSummary(container: Drawable, model: tf.LayersModel) {
  // Resolve the target first so an invalid drawable fails before any model
  // inspection happens.
  const drawArea = getDrawArea(container);
  const {layers} = getModelSummary(model);

  const rows = layers.map(info => {
    const {name, outputShape, parameters, trainable} = info;
    return [name, outputShape, parameters, trainable];
  });
  const columnTitles =
      ['Layer Name', 'Output Shape', '# Of Params', 'Trainable'];

  table(drawArea, {headers: columnTitles, values: rows});
}
/**
 * Renders summary information about a layer and a histogram of parameters in
 * that layer.
 *
 * A table lists every weight of the layer (name, shape, min, max, size and
 * counts of zeros / NaNs / infinities). Below it a button + dropdown lets the
 * user pick one weight and plot the distribution of its values.
 *
 * ```js
 * const model = tf.sequential({
 *  layers: [
 *    tf.layers.dense({inputShape: [784], units: 32, activation: 'relu'}),
 *    tf.layers.dense({units: 10, activation: 'softmax'}),
 *  ]
 * });
 *
 * const surface = { name: 'Layer Summary', tab: 'Model Inspection'};
 * tfvis.show.layer(surface, model.getLayer(undefined, 1));
 * ```
 *
 * @param container The drawable to render into.
 * @param layer The layer to inspect.
 */
/**
 * @doc {
 *  heading: 'Models & Tensors',
 *  subheading: 'Model Inspection',
 *  namespace: 'show'
 * }
 */
export async function layer(container: Drawable, layer: Layer) {
  const drawArea = getDrawArea(container);
  const details = await getLayerDetails(layer);

  const headers = [
    'Weight Name',
    'Shape',
    'Min',
    'Max',
    '# Params',
    '# Zeros',
    '# NaNs',
    '# Infinity',
  ];

  // Show layer summary
  const weightsInfoSurface = subSurface(drawArea, 'layer-weights-info');
  const detailValues = details.map(
      d =>
          [d.name, d.shape, d.stats.min, d.stats.max, d.weight.size,
           d.stats.numZeros, d.stats.numNans, d.stats.numInfs]);

  table(weightsInfoSurface, {headers, values: detailValues});

  const histogramSelectorSurface = subSurface(drawArea, 'select-layer');
  const layerValuesHistogram = subSurface(drawArea, 'param-distribution');

  const handleSelection = async (weightName: string) => {
    const selected = details.find(d => d.name === weightName);
    if (selected == null) {
      // Nothing matches the selection; leave the current histogram untouched
      // rather than throwing inside a DOM event handler.
      return;
    }
    const weights = await selected.weight.data();

    histogram(
        layerValuesHistogram, weights, {height: 150, width: 460, stats: false});
  };

  addHistogramSelector(
      details.map(d => d.name), histogramSelectorSurface, handleSelection);
}

//
// Helper functions
//

/*
 * Collects the per-layer summaries for every layer of a model.
 */
function getModelSummary(model: tf.LayersModel) {
  return {
    layers: model.layers.map(getLayerSummary),
  };
}

/*
 * Gets summary information/metadata about a layer.
 */
function getLayerSummary(layer: Layer): LayerSummary {
  let outputShape: string;
  if (Array.isArray(layer.outputShape[0])) {
    // Multiple outputs: format each shape and wrap the list in brackets.
    const shapes = (layer.outputShape as number[][]).map(s => formatShape(s));
    outputShape = `[${shapes.join(', ')}]`;
  } else {
    outputShape = formatShape(layer.outputShape as number[]);
  }

  return {
    name: layer.name,
    trainable: layer.trainable,
    parameters: layer.countParams(),
    outputShape,
  };
}

interface LayerSummary {
  name: string;
  trainable: boolean;
  parameters: number;
  outputShape: string;
}

/*
 * Gets summary stats and shape for all weights in a layer.
 *
 * Names are read from `layer.weights` and tensors from `layer.getWeights()`;
 * the two lists are paired by index.
 */
async function getLayerDetails(layer: Layer): Promise<Array<{
  name: string,
  stats: HistogramStats,
  shape: string,
  weight: tf.Tensor,
}>> {
  const weights = layer.getWeights();
  const layerVariables = layer.weights;
  const statsPromises = weights.map(tensorStats);
  const stats = await Promise.all(statsPromises);
  return weights.map((weight, i) => ({
                       name: layerVariables[i].name,
                       stats: stats[i],
                       shape: formatShape(weight.shape),
                       weight,
                     }));
}

/*
 * Formats a shape for display: 'Scalar' for an empty shape, otherwise a
 * bracketed comma separated list where a null leading dimension is shown as
 * 'batch'.
 */
function formatShape(shape: number[]): string {
  const oShape: Array<number|string> = shape.slice();
  if (oShape.length === 0) {
    return 'Scalar';
  }
  if (oShape[0] === null) {
    oShape[0] = 'batch';
  }
  return `[${oShape.join(',')}]`;
}

/*
 * Replaces the content of `parent` with a button and a dropdown listing
 * `items`. Clicking the button calls `selectionHandler` with the currently
 * selected item.
 *
 * The controls are built with DOM APIs instead of an HTML string so that
 * item names are never interpreted as markup.
 */
function addHistogramSelector(
    items: string[], parent: HTMLElement,
    // tslint:disable-next-line:no-any
    selectionHandler: (item: string) => any) {
  const wrapper = document.createElement('div');

  const buttonEl = document.createElement('button');
  // NOTE(review): the original button caption was lost in this copy of the
  // file; confirm the wording.
  buttonEl.textContent = 'Show Values Distribution for:';

  const selectEl = document.createElement('select');
  items.forEach(item => {
    const optionEl = document.createElement('option');
    optionEl.value = item;
    optionEl.textContent = item;
    selectEl.appendChild(optionEl);
  });

  wrapper.appendChild(buttonEl);
  wrapper.appendChild(selectEl);

  // Drop any previously rendered selector before inserting the new one.
  parent.innerHTML = '';
  parent.appendChild(wrapper);

  // Add listeners
  buttonEl.addEventListener('click', () => {
    if (selectEl.selectedIndex < 0) {
      // Empty dropdown: there is nothing to show.
      return;
    }
    selectionHandler(selectEl.value);
  });
}
describe('modelSummary', () => {
  // Start every spec from a clean page so element counts are exact.
  // NOTE(review): the reset markup was damaged in this copy of the file; a
  // single empty container div is assumed.
  beforeEach(() => {
    document.body.innerHTML = '<div id="container"></div>';
  });

  it('renders a model summary', async () => {
    const container = {name: 'Test'};
    const model = tf.sequential();
    model.add(tf.layers.dense({units: 1, inputShape: [1]}));
    await modelSummary(container, model);

    // One table: a header row plus one row for the single dense layer.
    expect(document.querySelectorAll('table').length).toBe(1);
    expect(document.querySelectorAll('tr').length).toBe(2);
  });
});

describe('layer', () => {
  beforeEach(() => {
    document.body.innerHTML = '<div id="container"></div>';
  });

  it('renders a layer summary', async () => {
    const container = {name: 'Test'};
    const model = tf.sequential();
    const dense = tf.layers.dense({units: 1, inputShape: [1]});
    model.add(dense);
    model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
    await layer(container, dense);

    // One table: a header row plus two rows for the dense layer's weights.
    expect(document.querySelectorAll('table').length).toBe(1);
    expect(document.querySelectorAll('tr').length).toBe(3);
  });
});
/**
 * Renders a per class accuracy table for classification task evaluation
 *
 * ```js
 * const labels = tf.tensor1d([0, 0, 1, 2, 2, 2]);
 * const predictions = tf.tensor1d([0, 0, 0, 2, 1, 1]);
 *
 * const result = await tfvis.metrics.perClassAccuracy(labels, predictions);
 * console.log(result)
 *
 * const container = {name: 'Per Class Accuracy', tab: 'Evaluation'};
 * const categories = ['cat', 'dog', 'mouse'];
 * await tfvis.show.perClassAccuracy(container, result, categories);
 * ```
 *
 * @param container A `{name: string, tab?: string}` object specifying which
 *  surface to render to.
 * @param classAccuracy An `Array<{accuracy: number, count: number}>` array with
 *  the accuracy data. See metrics.perClassAccuracy for details on how to
 *  generate this object.
 * @param classLabels An array of string labels for the classes in
 *  `classAccuracy`. Optional. Classes without a label (no array given, or
 *  the array is shorter than `classAccuracy`) are labelled with their index.
 *
 * @returns Whatever `table` returns for the rendered rows.
 */
export async function showPerClassAccuracy(
    container: Drawable,
    classAccuracy: Array<{accuracy: number, count: number}>,
    classLabels?: string[]) {
  const drawArea = getDrawArea(container);

  const headers = [
    'Class',
    'Accuracy',
    '# Samples',
  ];
  const values: Array<Array<string|number>> = [];

  for (let i = 0; i < classAccuracy.length; i++) {
    // Fall back to the class index when no label exists for this class.
    const hasLabel = classLabels != null && classLabels[i] != null;
    const label = hasLabel ? classLabels![i] : i.toString();
    const classAcc = classAccuracy[i];
    values.push([label, classAcc.accuracy, classAcc.count]);
  }

  return table(drawArea, {headers, values});
}
describe('perClassAccuracy', () => {
  // Start every spec from a clean page so the table count is exact.
  // NOTE(review): the reset markup was damaged in this copy of the file; a
  // single empty container div is assumed.
  beforeEach(() => {
    document.body.innerHTML = '<div id="container"></div>';
  });

  it('renders perClassAccuracy', async () => {
    const container = {name: 'Test'};
    const acc = [
      {accuracy: 0.5, count: 10},
      {accuracy: 0.8, count: 10},
    ];

    const labels = ['cat', 'dog'];
    await showPerClassAccuracy(container, acc, labels);
    expect(document.querySelectorAll('table').length).toBe(1);
  });

  it('renders perClassAccuracy without explicit labels', async () => {
    const container = {name: 'Test'};
    const acc = [
      {accuracy: 0.5, count: 10},
      {accuracy: 0.8, count: 10},
    ];
    await showPerClassAccuracy(container, acc);
    expect(document.querySelectorAll('table').length).toBe(1);
  });
});
/**
 * Shows a histogram with the distribution of all values in a given tensor.
 *
 * Summary statistics are computed from the tensor with `tensorStats` and
 * handed to the histogram together with the tensor's values.
 *
 * ```js
 * const tensor = tf.tensor1d([0, 0, 0, 0, 2, 3, 4]);
 *
 * const surface = {name: 'Values Distribution', tab: 'Model Inspection'};
 * await tfvis.show.valuesDistribution(surface, tensor);
 * ```
 *
 * @param container The drawable to render the histogram into.
 * @param tensor The tensor whose values are plotted.
 */
/**
 * @doc {heading: 'Models & Tensors', subheading: 'Model Inspection', namespace:
 * 'show'}
 */
export async function valuesDistribution(container: Drawable, tensor: Tensor) {
  const target = getDrawArea(container);
  // Stats are awaited before the raw values, matching the original order.
  const summary = await tensorStats(tensor);
  const data = await tensor.data();
  const histogramOptions = {height: 150, stats: summary};
  histogram(target, data, histogramOptions);
}
// This suite exercises `valuesDistribution`; it was previously labelled
// 'perClassAccuracy', which made failures point at the wrong module.
describe('valuesDistribution', () => {
  // Start every spec from a clean page so element counts are exact.
  // NOTE(review): the reset markup was damaged in this copy of the file; a
  // single empty container div is assumed.
  beforeEach(() => {
    document.body.innerHTML = '<div id="container"></div>';
  });

  it('renders histogram', async () => {
    const container = {name: 'Test'};
    const tensor = tf.tensor1d([0, 0, 0, 0, 2, 3, 4]);

    await valuesDistribution(container, tensor);
    // One stats table plus one chart.
    expect(document.querySelectorAll('table').length).toBe(1);
    expect(document.querySelectorAll('.vega-embed').length).toBe(1);
  });
});
// Types shared across the project and that users will commonly interact with

/**
 * The public api of a 'surface': the three elements that make it up.
 */
export interface Surface {
  /**
   * Outermost HTML element of the surface.
   */
  container: HTMLElement;

  /**
   * Element holding the surface's textual label.
   */
  label: HTMLElement;

  /**
   * Element that plots and other renderings are drawn into.
   */
  drawArea: HTMLElement;
}

/**
 * Options used to specify a surface.
 *
 * name and tab are also used for retrieval of a surface instance.
 */
export interface SurfaceInfo {
  /**
   * Name / label of the surface.
   */
  name: string;

  /**
   * Name of the tab the surface should appear on.
   */
  tab?: string;

  /**
   * Display styles for the surface.
   */
  styles?: StyleOptions;
}

/**
 * Same as SurfaceInfo but with a mandatory tab: internally all surfaces must
 * have one.
 */
export interface SurfaceInfoStrict extends SurfaceInfo {
  name: string;
  tab: string;
  styles?: StyleOptions;
}

/**
 * Style properties. All are optional as components generally specify
 * defaults.
 */
export interface StyleOptions {
  width?: string;
  height?: string;
  maxWidth?: string;
  maxHeight?: string;
}

/**
 * Anything that can be rendered to.
 *
 * @docalias HTMLElement|{name: string, tab?: string}|Surface|{drawArea:
 * HTMLElement}
 */
export type Drawable = HTMLElement|Surface|SurfaceInfo|{
  drawArea: HTMLElement;
};

/**
 * Type guard: true when the drawable has a non-null `name` property.
 */
export function isSurfaceInfo(drawable: Drawable): drawable is SurfaceInfo {
  const candidate = drawable as SurfaceInfo;
  return candidate.name != null;
}

/**
 * Type guard: true when the drawable's `drawArea` is an HTMLElement.
 */
export function isSurface(drawable: Drawable): drawable is Surface {
  const candidate = drawable as Surface;
  return candidate.drawArea instanceof HTMLElement;
}
/**
 * Common visualisation options for '.render' functions.
 */
export interface VisOptions {
  /**
   * Chart width, in px.
   */
  width?: number;
  /**
   * Chart height, in px.
   */
  height?: number;
  /**
   * Label for the x axis.
   */
  xLabel?: string;
  /**
   * Label for the y axis.
   */
  yLabel?: string;
  /**
   * Font size, in px.
   */
  fontSize?: number;
  /**
   * Scale type of the x axis. Will be set automatically.
   */
  xType?: 'quantitative'|'ordinal'|'nominal';
  /**
   * Scale type of the y axis. Will be set automatically.
   */
  yType?: 'quantitative'|'ordinal'|'nominal';
}

/**
 * Options for XY plots
 */
export interface XYPlotOptions extends VisOptions {
  /**
   * Domain of the x axis. Overridden by zoomToFit.
   */
  xAxisDomain?: [number, number];
  /**
   * Domain of the y axis. Overridden by zoomToFit.
   */
  yAxisDomain?: [number, number];
  /**
   * Set the chart bounds to just fit the data. This may modify the axis scales
   * but allows fitting more data into view.
   */
  zoomToFit?: boolean;
}

/**
 * Data format for XY plots
 */
export interface XYPlotData {
  /**
   * An array of {x, y} points for a single series, or a nested array with one
   * array of points per series.
   */
  values: Point2D[][]|Point2D[];
  /**
   * Series names/labels
   */
  series?: string[];
}

/**
 * Histogram options.
 */
export interface HistogramOpts extends VisOptions {
  /**
   * By default a histogram will also compute and display summary statistics.
   * If stats is set to false then summary statistics will not be displayed.
   *
   * Pre computed stats can also be passed in and should have the following
   * format:
   *  {
   *    numVals?: number,
   *    min?: number,
   *    max?: number,
   *    numNans?: number,
   *    numZeros?: number,
   *    numInfs?: number,
   *  }
   */
  stats?: HistogramStats|false;

  /**
   * Maximum number of bins in histogram.
   */
  maxBins?: number;
}

/**
 * Summary statistics for histogram. Every field is optional.
 */
export interface HistogramStats {
  numVals?: number;
  min?: number;
  max?: number;
  numNans?: number;
  numZeros?: number;
  numInfs?: number;
}

/**
 * Type alias for typed arrays
 */
export type TypedArray = Int8Array|Uint8Array|Int16Array|Uint16Array|Int32Array|
    Uint32Array|Uint8ClampedArray|Float32Array|Float64Array;

/**
 * An object with a 'values' property and an optional 'tickLabels' property.
 */
export interface ConfusionMatrixData {
  /**
   * a square matrix of numbers representing counts for each (label, prediction)
   * pair
   */
  values: number[][];

  /**
   * Human readable labels for each class in the matrix. Optional
   */
  tickLabels?: string[];
}

export interface ConfusionMatrixOptions extends VisOptions {
  /**
   * Color cells on the diagonal. Defaults to true
   */
  shadeDiagonal?: boolean;
  /**
   * render the values of each cell as text. Defaults to true
   */
  showTextOverlay?: boolean;
}

/**
 * Datum format for scatter and line plots
 */
export interface Point2D {
  x: number;
  y: number;
}

/**
 * An object with a 'values' property and optional 'xTickLabels' /
 * 'yTickLabels' properties.
 */
export interface HeatmapData {
  /**
   * Matrix of values in column-major order.
   *
   * Row major order is supported by setting a boolean in options.
   */
  values: number[][]|Tensor2D;
  /**
   * x axis tick labels
   */
  xTickLabels?: string[];
  /**
   * y axis tick labels
   */
  yTickLabels?: string[];
}

/**
 * Color map names.
 */
/** @docinline */
export type NamedColorMap = 'greyscale'|'viridis'|'blues';

/**
 * Visualization options for Heatmap
 */
export interface HeatmapOptions extends VisOptions {
  /**
   * Defaults to viridis
   */
  colorMap?: NamedColorMap;

  /**
   * Custom output domain for the color scale.
   * Useful if you want to plot multiple heatmaps using the same scale.
   */
  domain?: number[];

  /**
   * Pass in data values in row-major order.
   *
   * Internally this will transpose the data values before rendering.
   */
  rowMajor?: boolean;
}

/**
 * Data format for render.table
 */
export interface TableData {
  /**
   * Column names
   */
  headers: string[];

  /**
   * An array of arrays (one for each row). The inner
   * array length usually matches the length of data.headers.
   *
   * Typically the values are numbers or strings.
   */
  // tslint:disable-next-line:no-any
  values: any[][];
}
// Minimal typings for glamor-tachyons.
//
// Only the `tachyons` export is declared here.

declare module 'glamor-tachyons' {
  // Takes a string or an object and returns an object.
  export function tachyons(input: string|{}): {};
}
const DEFAULT_SUBSURFACE_OPTS = {
  prepend: false,
};

/**
 * Utility function to create/retrieve divs within an HTMLElement|Surface
 *
 * Looks inside the parent's draw area for a `div` whose `data-name` equals
 * `name` and returns it. If none exists a new div is created (optionally with
 * a title element), inserted and returned.
 *
 * @param parent The drawable to look in / insert into.
 * @param name Identifier stored in the div's `data-name` attribute. Any
 *  string is accepted; it is quoted and escaped before being used in the
 *  lookup selector.
 * @param opts `prepend` inserts a newly created div before the existing
 *  children instead of after them (default false). `title` adds a title
 *  element to a newly created div. Neither affects a div that already exists.
 * @returns The existing or newly created div.
 */
export function subSurface(parent: Drawable, name: string, opts: Options = {}) {
  const container = getDrawArea(parent);
  const style = css({
    '& canvas': {
      display: 'block',
    },
    ...tac('mv2'),
  });
  const titleStyle = css({
    backgroundColor: 'white',
    display: 'inline-block',
    boxSizing: 'border-box',
    borderBottom: '1px solid #357EDD',
    lineHeight: '2em',
    padding: '0 10px 0 10px',
    marginBottom: '20px',
    ...tac('fw6 tl')
  });
  const options = Object.assign({}, DEFAULT_SUBSURFACE_OPTS, opts);

  // Quote the attribute value and escape quotes/backslashes: an unquoted
  // value is only a valid selector when the name is a plain identifier, so
  // names containing spaces, brackets or quotes would otherwise throw.
  const escapedName = name.replace(/["\\]/g, '\\$&');
  let sub: HTMLElement|null =
      container.querySelector(`div[data-name="${escapedName}"]`);
  if (!sub) {
    sub = document.createElement('div');
    sub.setAttribute('class', `${style}`);
    sub.dataset.name = name;

    if (options.title) {
      const title = document.createElement('div');
      title.setAttribute('class', `subsurface-title ${titleStyle}`);
      title.innerText = options.title;
      sub.appendChild(title);
    }

    if (options.prepend) {
      container.insertBefore(sub, container.firstChild);
    } else {
      container.appendChild(sub);
    }
  }
  return sub;
}

interface Options {
  prepend?: boolean;
  title?: string;
}
@license 5 | * Copyright 2018 Google LLC. All Rights Reserved. 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | * ============================================================================= 18 | */ 19 | 20 | /** 21 | * Tests a boolean expression and throws a message if false. 22 | */ 23 | export function assert(expr: boolean, msg: string|(() => string)) { 24 | if (!expr) { 25 | throw new Error(typeof msg === 'string' ? msg : msg()); 26 | } 27 | } 28 | 29 | export function assertShapesMatch( 30 | shapeA: number[], shapeB: number[], errorMessagePrefix = ''): void { 31 | assert( 32 | arraysEqual(shapeA, shapeB), 33 | errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`); 34 | } 35 | 36 | export function arraysEqual(n1: number[]|TypedArray, n2: number[]|TypedArray) { 37 | if (n1.length !== n2.length) { 38 | return false; 39 | } 40 | for (let i = 0; i < n1.length; i++) { 41 | if (n1[i] !== n2[i]) { 42 | return false; 43 | } 44 | } 45 | return true; 46 | } 47 | 48 | // Number of decimal places to when checking float similarity 49 | export const DECIMAL_PLACES_TO_CHECK = 4; 50 | -------------------------------------------------------------------------------- /tfjs-vis/src/visor.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 
4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | import {VisorComponent} from './components/visor'; 19 | import {SurfaceInfo, SurfaceInfoStrict} from './types'; 20 | 21 | let visorSingleton: Visor; 22 | const DEFAULT_TAB = 'Visor'; 23 | const VISOR_CONTAINER_ID = 'tfjs-visor-container'; 24 | 25 | /** 26 | * The primary interface to the visor is the visor() function. 27 | * 28 | * This returns a singleton instance of the Visor class. The 29 | * singleton object will be replaced if the visor is removed from the DOM for 30 | * some reason. 31 | * 32 | * ```js 33 | * // Show the visor 34 | * tfvis.visor(); 35 | * ``` 36 | * 37 | */ 38 | /** @doc {heading: 'Visor & Surfaces'} */ 39 | export function visor(): Visor { 40 | if (typeof document === 'undefined') { 41 | throw new Error( 42 | 'No document defined. 
This library needs a browser/dom to work'); 43 | } 44 | 45 | if (document.getElementById(VISOR_CONTAINER_ID) && visorSingleton != null) { 46 | return visorSingleton; 47 | } 48 | 49 | // Create the container 50 | let visorEl = document.getElementById(VISOR_CONTAINER_ID); 51 | 52 | if (visorEl == null) { 53 | visorEl = document.createElement('div'); 54 | visorEl.id = VISOR_CONTAINER_ID; 55 | document.body.appendChild(visorEl); 56 | } 57 | 58 | let renderRoot: Element; 59 | function renderVisor( 60 | domNode: HTMLElement, 61 | surfaceList: Map): VisorComponent { 62 | let visorInstance: VisorComponent; 63 | renderRoot = VisorComponent.render(domNode, renderRoot, { 64 | ref: (r: VisorComponent) => visorInstance = r, 65 | surfaceList: Array.from(surfaceList.values()), 66 | }); 67 | // Side effect of VisorComponent.render() is to assign visorInstance 68 | return visorInstance!; 69 | } 70 | 71 | // TODO: consider changing this type. Possibly lift into top level state 72 | // object 73 | const surfaceList: Map = new Map(); 74 | const visorComponentInstance: VisorComponent = 75 | renderVisor(visorEl, surfaceList); 76 | 77 | visorSingleton = 78 | new Visor(visorComponentInstance, visorEl, surfaceList, renderVisor); 79 | 80 | return visorSingleton; 81 | } 82 | 83 | /** 84 | * An instance of the visor. An instance of this class is created using the 85 | * `visor()` function. 86 | */ 87 | /** @doc {heading: 'Visor & Surfaces', subheading: 'Visor Methods'} */ 88 | export class Visor { 89 | private visorComponent: VisorComponent; 90 | private surfaceList: Map; 91 | private renderVisor: 92 | (domNode: HTMLElement, 93 | surfaceList: Map) => VisorComponent; 94 | 95 | /** 96 | * The underlying html element of the visor. 
97 | */ 98 | public el: HTMLElement; 99 | 100 | constructor( 101 | visorComponent: VisorComponent, visorEl: HTMLElement, 102 | surfaceList: Map, 103 | renderVisor: 104 | (domNode: HTMLElement, 105 | surfaceList: Map) => VisorComponent) { 106 | this.visorComponent = visorComponent; 107 | this.el = visorEl; 108 | this.surfaceList = surfaceList; 109 | this.renderVisor = renderVisor; 110 | } 111 | 112 | /** 113 | * Creates a surface on the visor 114 | * 115 | * Most methods in tfjs-vis that take a surface also take a SurfaceInfo 116 | * so you rarely need to call this method unless you want to make a custom 117 | * plot. 118 | * 119 | * ```js 120 | * // Create a surface on a tab 121 | * tfvis.visor().surface({name: 'My Surface', tab: 'My Tab'}); 122 | * ``` 123 | * 124 | * ```js 125 | * // Create a surface and specify its height 126 | * tfvis.visor().surface({name: 'Custom Height', tab: 'My Tab', styles: { 127 | * height: 500 128 | * }}) 129 | * ``` 130 | * 131 | * @param options 132 | */ 133 | /** @doc {heading: 'Visor & Surfaces', subheading: 'Visor Methods'} */ 134 | surface(options: SurfaceInfo) { 135 | const {name} = options; 136 | const tab = options.tab == null ? 
DEFAULT_TAB : options.tab; 137 | 138 | if (name == null || 139 | // tslint:disable-next-line 140 | !(typeof name === 'string' || name as any instanceof String)) { 141 | throw new Error( 142 | // tslint:disable-next-line 143 | 'You must pass a config object with a \'name\' property to create or retrieve a surface'); 144 | } 145 | 146 | const finalOptions: SurfaceInfoStrict = { 147 | ...options, 148 | tab, 149 | }; 150 | 151 | const key = `${name}-${tab}`; 152 | if (!this.surfaceList.has(key)) { 153 | this.surfaceList.set(key, finalOptions); 154 | } 155 | 156 | this.renderVisor(this.el as HTMLElement, this.surfaceList); 157 | return this.visorComponent.getSurface(name, tab); 158 | } 159 | 160 | /** 161 | * Returns a boolean indicating if the visor is in 'fullscreen' mode 162 | */ 163 | /** @doc {heading: 'Visor & Surfaces', subheading: 'Visor Methods'} */ 164 | isFullscreen() { 165 | return this.visorComponent.isFullscreen(); 166 | } 167 | 168 | /** 169 | * Returns a boolean indicating if the visor is open 170 | */ 171 | /** @doc {heading: 'Visor & Surfaces', subheading: 'Visor Methods'} */ 172 | isOpen() { 173 | return this.visorComponent.isOpen(); 174 | } 175 | 176 | /** 177 | * Closes the visor. 178 | */ 179 | /** @doc {heading: 'Visor & Surfaces', subheading: 'Visor Methods'} */ 180 | close() { 181 | return this.visorComponent.close(); 182 | } 183 | 184 | /** 185 | * Opens the visor. 186 | */ 187 | /** @doc {heading: 'Visor & Surfaces', subheading: 'Visor Methods'} */ 188 | open() { 189 | return this.visorComponent.open(); 190 | } 191 | 192 | /** 193 | * Toggles the visor (closed vs open). 
194 | */ 195 | /** @doc {heading: 'Visor & Surfaces', subheading: 'Visor Methods'} */ 196 | toggle() { 197 | return this.visorComponent.toggle(); 198 | } 199 | 200 | /** @doc {heading: 'Visor & Surfaces', subheading: 'Visor Methods'} */ 201 | toggleFullScreen() { 202 | return this.visorComponent.toggleFullScreen(); 203 | } 204 | 205 | /** 206 | * Binds the ~ (tilde) key to toggle the visor. 207 | * 208 | * This is called by default when the visor is initially created. 209 | */ 210 | /** @doc {heading: 'Visor & Surfaces', subheading: 'Visor Methods'} */ 211 | bindKeys() { 212 | return this.visorComponent.bindKeys(); 213 | } 214 | 215 | /** 216 | * Unbinds the keyboard control to toggle the visor. 217 | */ 218 | /** @doc {heading: 'Visor & Surfaces', subheading: 'Visor Methods'} */ 219 | unbindKeys() { 220 | return this.visorComponent.unbindKeys(); 221 | } 222 | 223 | /** 224 | * Sets the active tab for the visor. 225 | */ 226 | /** @doc {heading: 'Visor & Surfaces', subheading: 'Visor Methods'} */ 227 | setActiveTab(tabName: string) { 228 | const tabs = this.visorComponent.state.tabs; 229 | if (!tabs.has(tabName)) { 230 | throw new Error(`Tab '${tabName}' does not exist`); 231 | } 232 | this.visorComponent.setState({activeTab: tabName}); 233 | } 234 | } 235 | -------------------------------------------------------------------------------- /tfjs-vis/src/visor_test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {visor} from './index';

// Resolves after `ms` milliseconds. The async tests below await it after
// changing visor state and before inspecting the DOM.
const tick = (ms = 1) => new Promise(resolve => setTimeout(resolve, ms));

// Tests for the visor() singleton accessor and the Visor methods it exposes.
describe('Visor Singleton', () => {
  // Clear the document body after each test so that the next visor() call
  // starts without a visor container in the DOM.
  afterEach(() => {
    document.body.innerHTML = '';
  });

  it('renders an empty visor', () => {
    visor();
    expect(document.querySelectorAll('.visor').length).toBe(1);
  });

  it('visor.el is an HTMLElement', () => {
    const visorInstance = visor();
    expect(visorInstance.el instanceof HTMLElement).toBe(true);
  });

  // Repeated visor() calls must not add further visors to the DOM.
  it('renders only one visor', () => {
    const v1 = visor();
    const v2 = visor();
    const v3 = visor();
    expect(document.querySelectorAll('.visor').length).toBe(1);
    expect(v1).toEqual(v2);
    expect(v1).toEqual(v3);
  });

  it('adds a surface', () => {
    const visorInstance = visor();
    visorInstance.surface({name: 'surface 1', tab: 'tab 1'});
    expect(document.querySelectorAll('.tf-surface').length).toBe(1);
    expect(document.querySelector('.tf-surface')!.textContent)
        .toEqual('surface 1');

    expect(document.querySelectorAll('.tf-tab').length).toBe(1);
    expect(document.querySelector('.tf-tab')!.textContent).toEqual('tab 1');
  });

  // surface() must throw for a missing argument, a non-object argument and
  // an object without a `name` property.
  it('requires a surface name', () => {
    const visorInstance = visor();
    expect(() => {
      // @ts-ignore
      visorInstance.surface();
    }).toThrow();

    expect(() => {
      // @ts-ignore
      visorInstance.surface('Incorrect Name Param');
    }).toThrow();

    expect(() => {
      // @ts-ignore
      visorInstance.surface({notName: 'Incorrect Name Param'});
    }).toThrow();
  });

  // Requesting the same name/tab pair twice must not create a second surface.
  it('retrieves a surface', () => {
    const visorInstance = visor();
    const s1 = visorInstance.surface({name: 'surface 1', tab: 'tab 1'});
    expect(document.querySelectorAll('.tf-surface').length).toBe(1);
    expect(document.querySelector('.tf-surface')!.textContent)
        .toEqual('surface 1');

    const s2 = visorInstance.surface({name: 'surface 1', tab: 'tab 1'});
    expect(document.querySelectorAll('.tf-surface').length).toBe(1);
    expect(document.querySelector('.tf-surface')!.textContent)
        .toEqual('surface 1');

    expect(s1).toEqual(s2);
  });

  // Without an explicit tab the surface lands on the tab labelled 'Visor'.
  it('adds a surface with the default tab', () => {
    const visorInstance = visor();
    visorInstance.surface({name: 'surface1'});

    expect(document.querySelectorAll('.tf-tab').length).toBe(1);
    expect(document.querySelector('.tf-tab')!.textContent).toEqual('Visor');
  });

  it('adds two surfaces', () => {
    const visorInstance = visor();
    const s1 = visorInstance.surface({name: 'surface 1', tab: 'tab 1'});
    const s2 = visorInstance.surface({name: 'surface 2', tab: 'tab 1'});

    expect(s1).not.toEqual(s2);

    const surfaces = document.querySelectorAll('.tf-surface');
    expect(surfaces.length).toBe(2);
    expect(document.querySelectorAll('.tf-tab').length).toBe(1);

    expect(surfaces[0].textContent).toEqual('surface 1');
    expect(surfaces[1].textContent).toEqual('surface 2');
  });

  // Adding a surface on a new tab is expected to make that tab the active
  // one and to deactivate the previously active tabs.
  it('switches tabs on surface addition', () => {
    let tabs;
    const visorInstance = visor();

    visorInstance.surface({name: 'surface 1', tab: 'tab 1'});
    tabs = document.querySelectorAll('.tf-tab');
    expect(tabs[0].getAttribute('data-isactive')).toEqual('true');

    visorInstance.surface({name: 'surface 2', tab: 'tab 2'});
    tabs = document.querySelectorAll('.tf-tab');
    expect(tabs[1].getAttribute('data-isactive')).toEqual('true');
    expect(tabs[0].getAttribute('data-isactive')).toBeFalsy();

    visorInstance.surface({name: 'surface 3', tab: 'tab 3'});
    tabs = document.querySelectorAll('.tf-tab');
    expect(tabs[2].getAttribute('data-isactive')).toEqual('true');
    expect(tabs[0].getAttribute('data-isactive')).toBeFalsy();
    expect(tabs[1].getAttribute('data-isactive')).toBeFalsy();
  });

  // The visor starts open; both the `data-isopen` attribute and isOpen()
  // are checked after each state change.
  it('closes/opens', async () => {
    const visorInstance = visor();

    expect(document.querySelector('.visor')!.getAttribute('data-isopen'))
        .toBe('true');
    expect(visorInstance.isOpen()).toBe(true);

    visorInstance.close();
    await tick();
    expect(document.querySelector('.visor')!.getAttribute('data-isopen'))
        .toBeFalsy();
    expect(visorInstance.isOpen()).toBe(false);

    visorInstance.open();
    await tick();
    expect(document.querySelector('.visor')!.getAttribute('data-isopen'))
        .toBe('true');
    expect(visorInstance.isOpen()).toBe(true);
  });

  it('toggles', async () => {
    const visorInstance = visor();

    expect(document.querySelector('.visor')!.getAttribute('data-isopen'))
        .toBe('true');
    expect(visorInstance.isOpen()).toBe(true);

    visorInstance.toggle();
    await tick();
    expect(document.querySelector('.visor')!.getAttribute('data-isopen'))
        .toBeFalsy();
    expect(visorInstance.isOpen()).toBe(false);

    visorInstance.toggle();
    await tick();
    expect(document.querySelector('.visor')!.getAttribute('data-isopen'))
        .toBe('true');
    expect(visorInstance.isOpen()).toBe(true);
  });

  // Fullscreen state is observed through the `data-isfullscreen` attribute.
  it('fullscreen toggles', async () => {
    const visorInstance = visor();
    expect(visorInstance.isOpen()).toBe(true);

    expect(document.querySelector('.visor')!.getAttribute('data-isfullscreen'))
        .toBeFalsy();

    visorInstance.toggleFullScreen();
    await tick();
    expect(document.querySelector('.visor')!.getAttribute('data-isfullscreen'))
        .toBe('true');

    visorInstance.toggleFullScreen();
    await tick();
    expect(document.querySelector('.visor')!.getAttribute('data-isfullscreen'))
        .toBeFalsy();
  });

  it('sets the active tab', async () => {
    let tabs;
    const visorInstance = visor();

    visorInstance.surface({name: 'surface 1', tab: 'tab 1'});
    visorInstance.surface({name: 'surface 2', tab: 'tab 2'});
    // NOTE(review): the name 'surface 2' is reused here on a different tab;
    // 'surface 3' may have been intended - confirm.
    visorInstance.surface({name: 'surface 2', tab: 'tab 3'});

    // The most recently added tab is the active one before any explicit
    // setActiveTab() call.
    tabs = document.querySelectorAll('.tf-tab');
    expect(tabs[2].getAttribute('data-isactive')).toEqual('true');

    visorInstance.setActiveTab('tab 2');
    await tick();
    tabs = document.querySelectorAll('.tf-tab');
    expect(tabs[1].getAttribute('data-isactive')).toEqual('true');

    visorInstance.setActiveTab('tab 1');
    await tick();
    tabs = document.querySelectorAll('.tf-tab');
    expect(tabs[0].getAttribute('data-isactive')).toEqual('true');

    visorInstance.setActiveTab('tab 3');
    await tick();
    tabs = document.querySelectorAll('.tf-tab');
    expect(tabs[2].getAttribute('data-isactive')).toEqual('true');
  });

  it('throws error if tab does not exist', () => {
    const visorInstance = visor();

    visorInstance.surface({name: 'surface 1', tab: 'tab 1'});
    visorInstance.surface({name: 'surface 2', tab: 'tab 2'});

    expect(() => {
      visorInstance.setActiveTab('not present');
    }).toThrow();
  });

  it('unbinds keyboard handler', () => {
    const visorInstance = visor();

    // Build a synthetic keydown event carrying keyCode 192.
    const BACKTICK_KEY = 192;
    const event = document.createEvent('Event');
    event.initEvent('keydown', true, true);
    // @ts-ignore
    event['keyCode'] = BACKTICK_KEY;

    // While the keys are bound, each dispatch flips the open state.
    document.dispatchEvent(event);
    expect(visorInstance.isOpen()).toBe(false);
    document.dispatchEvent(event);
    expect(visorInstance.isOpen()).toBe(true);

    // Unbind keys
    visorInstance.unbindKeys();
    document.dispatchEvent(event);
    expect(visorInstance.isOpen()).toBe(true);
    document.dispatchEvent(event);
    expect(visorInstance.isOpen()).toBe(true);
  });

  it('rebinds keyboard handler', () => {
    const visorInstance = visor();

    // Build a synthetic keydown event carrying keyCode 192.
    const BACKTICK_KEY = 192;
    const event = document.createEvent('Event');
    event.initEvent('keydown', true, true);
    // @ts-ignore
    event['keyCode'] = BACKTICK_KEY;

    // Unbind keys
    visorInstance.unbindKeys();
    document.dispatchEvent(event);
    expect(visorInstance.isOpen()).toBe(true);
    document.dispatchEvent(event);
    expect(visorInstance.isOpen()).toBe(true);

    // rebind keys
    visorInstance.bindKeys();
    document.dispatchEvent(event);
    expect(visorInstance.isOpen()).toBe(false);
    document.dispatchEvent(event);
    expect(visorInstance.isOpen()).toBe(true);
  });
});

--------------------------------------------------------------------------------
/tfjs-vis/tsconfig.json:
--------------------------------------------------------------------------------
{
  "compilerOptions": {
    "module": "commonjs",
    "moduleResolution": "node",
    "noImplicitAny": true,
    "sourceMap": true,
    "inlineSources": true,
    "removeComments": true,
    "preserveConstEnums": true,
    "declaration": true,
    "target": "es2015",
    "lib": [
      "es2015",
      "dom"
    ],
    "outDir": "./dist",
    "noUnusedLocals": false,
    "noImplicitReturns": true,
    "noImplicitThis": true,
    "alwaysStrict":
true, 21 | "strictNullChecks": true, 22 | "noUnusedParameters": true, 23 | "pretty": true, 24 | "noFallthroughCasesInSwitch": true, 25 | "allowUnreachableCode": false, 26 | "experimentalDecorators": true, 27 | "jsx": "react", 28 | "jsxFactory": "h", 29 | "esModuleInterop": true, 30 | "skipLibCheck": true 31 | }, 32 | "include": [ 33 | "src" 34 | ] 35 | } 36 | -------------------------------------------------------------------------------- /tfjs-vis/tslint.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": [ 3 | "tslint-no-circular-imports" 4 | ], 5 | "rules": { 6 | "array-type": [ 7 | true, 8 | "array-simple" 9 | ], 10 | "arrow-return-shorthand": true, 11 | "ban": [ 12 | true, 13 | [ 14 | "fit" 15 | ], 16 | [ 17 | "fdescribe" 18 | ], 19 | [ 20 | "xit" 21 | ], 22 | [ 23 | "xdescribe" 24 | ], 25 | [ 26 | "fitAsync" 27 | ], 28 | [ 29 | "xitAsync" 30 | ], 31 | [ 32 | "fitFakeAsync" 33 | ], 34 | [ 35 | "xitFakeAsync" 36 | ], 37 | { 38 | "name": [ 39 | "*", 40 | "reduce" 41 | ], 42 | "message": "Use forEach or a regular for loop instead." 43 | } 44 | ], 45 | "ban-types": [ 46 | true, 47 | [ 48 | "Object", 49 | "Use {} instead." 50 | ], 51 | [ 52 | "String", 53 | "Use 'string' instead." 54 | ], 55 | [ 56 | "Number", 57 | "Use 'number' instead." 58 | ], 59 | [ 60 | "Boolean", 61 | "Use 'boolean' instead." 
62 | ] 63 | ], 64 | "class-name": true, 65 | "curly": true, 66 | "interface-name": [ 67 | true, 68 | "never-prefix" 69 | ], 70 | "jsdoc-format": true, 71 | "forin": false, 72 | "label-position": true, 73 | "max-line-length": { 74 | "options": { 75 | "limit": 80, 76 | "ignore-pattern": "^import |^export " 77 | } 78 | }, 79 | "new-parens": true, 80 | "no-angle-bracket-type-assertion": true, 81 | "no-any": true, 82 | "no-construct": true, 83 | "no-consecutive-blank-lines": true, 84 | "no-debugger": true, 85 | "no-default-export": true, 86 | "no-inferrable-types": true, 87 | "no-namespace": [ 88 | true, 89 | "allow-declarations" 90 | ], 91 | "no-reference": true, 92 | "no-require-imports": true, 93 | "no-string-throw": true, 94 | "no-unused-expression": true, 95 | "no-var-keyword": true, 96 | "object-literal-shorthand": true, 97 | "only-arrow-functions": [ 98 | true, 99 | "allow-declarations", 100 | "allow-named-functions" 101 | ], 102 | "prefer-const": true, 103 | "quotemark": [ 104 | true, 105 | "single" 106 | ], 107 | "radix": true, 108 | "restrict-plus-operands": true, 109 | "semicolon": [ 110 | true, 111 | "always", 112 | "ignore-bound-class-methods" 113 | ], 114 | "switch-default": true, 115 | "triple-equals": [ 116 | true, 117 | "allow-null-check" 118 | ], 119 | "use-isnan": true, 120 | "variable-name": [ 121 | true, 122 | "check-format", 123 | "ban-keywords", 124 | "allow-leading-underscore", 125 | "allow-trailing-underscore" 126 | ] 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /tfjs-vis/webpack.config.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | var path = require('path'); 19 | 20 | module.exports = { 21 | mode: 'production', 22 | entry: './dist/index.js', 23 | devtool: 'source-map', 24 | output: { 25 | path: path.resolve(__dirname, 'dist'), 26 | filename: 'tfjs-vis.umd.min.js', 27 | libraryTarget: 'umd', 28 | library: 'tfvis', 29 | }, 30 | externals: { 31 | '@tensorflow/tfjs': 'tf', 32 | } 33 | }; 34 | --------------------------------------------------------------------------------