├── .babelrc ├── .eslintignore ├── .eslintrc ├── .gitignore ├── .prettierignore ├── .vscode ├── launch.json └── settings.json ├── README.md ├── netlify.toml ├── package.json ├── public ├── classifiers │ ├── group1-shard1of1 │ └── model.json └── images │ ├── desktop.gif │ └── phone.gif ├── src ├── actions │ ├── chart.js │ ├── drawing.js │ ├── mnist.js │ └── pipeline.js ├── classifiers │ └── handwriting-digits-classifier.js ├── components │ ├── app.jsx │ ├── bounding-box.jsx │ ├── button-space.jsx │ ├── center-col.jsx │ ├── chart-section.jsx │ ├── confirm-retrain-button.jsx │ ├── display.jsx │ ├── drawing.jsx │ ├── image-pipeline.jsx │ ├── line-chart.jsx │ ├── mnist-command.jsx │ ├── mnist-content.jsx │ ├── mnist-drawing.jsx │ ├── mnist-footer.jsx │ ├── mnist-header.jsx │ ├── mnist.jsx │ └── output.jsx ├── containers │ ├── accuracy-chart.jsx │ ├── bounding-box.jsx │ ├── centered-box.jsx │ ├── clear-button.jsx │ ├── confirm-retrain-button.jsx │ ├── cropped-box.jsx │ ├── drawing.jsx │ ├── loss-chart.jsx │ ├── mnist-drawing.jsx │ ├── normalized-box.jsx │ ├── output.jsx │ └── retrain-button.jsx ├── data │ └── mnist-data.js ├── main.jsx ├── reducers │ ├── chart.js │ ├── drawing.js │ ├── index.js │ ├── mnist.js │ └── pipeline.js ├── sagas │ ├── index.js │ ├── mnist.js │ └── pipeline.js ├── store │ └── configureStore.js └── utils │ ├── classifier.js │ ├── image-processing.js │ └── rect.js ├── templates └── index.html ├── webpack.config.js ├── webpack ├── utils.js ├── webpack.analyze.js ├── webpack.common.js ├── webpack.dev.js └── webpack.prod.js └── yarn.lock /.babelrc: -------------------------------------------------------------------------------- 1 | { 2 | "presets": [ 3 | "react" 4 | ], 5 | "plugins": [ 6 | "transform-class-properties", 7 | "transform-object-rest-spread", 8 | "babel-plugin-styled-components", 9 | [ 10 | "import", 11 | { 12 | "libraryName": "antd", 13 | "libraryDirectory": "es", 14 | "style": "css" 15 | } 16 | ] 17 | ], 18 | "env": { 19 | "development": 
{ 20 | "plugins": [ 21 | "react-hot-loader/babel" 22 | ] 23 | }, 24 | "production": { 25 | "presets": ["react-optimize"] 26 | } 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /.eslintignore: -------------------------------------------------------------------------------- 1 | webpack.*.js 2 | node_modules 3 | -------------------------------------------------------------------------------- /.eslintrc: -------------------------------------------------------------------------------- 1 | { 2 | "parser": "babel-eslint", 3 | "extends": "airbnb", 4 | "env": { 5 | "browser": true 6 | }, 7 | "plugins": [ 8 | "react", 9 | "jsx-a11y", 10 | "import" 11 | ] 12 | } 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | node_modules/ 2 | dist/ 3 | stats.json 4 | -------------------------------------------------------------------------------- /.prettierignore: -------------------------------------------------------------------------------- 1 | .babelrc 2 | .eslintrc 3 | node_modules 4 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "type": "chrome", 9 | "request": "launch", 10 | "name": "Launch Chrome against localhost", 11 | "url": "http://localhost:9000", 12 | "webRoot": "${workspaceFolder}" 13 | } 14 | ] 15 | } 16 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "editor.formatOnSave": true, 3 | "javascript.format.enable": false, 4 | "editor.tabSize": 2, 5 | "editor.detectIndentation": false, 6 | "prettier.eslintIntegration": true, 7 | "editor.codeActionsOnSave": { 8 | "source.organizeImports": true 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Handwritten digit recognition 2 | 3 | Digit recognition built with TensorFlow.js, the MNIST dataset, React, Redux, Redux-Saga, Babel, Webpack, Styled-components, ESLint, Prettier and Ant Design. 4 | 5 | A demo is available at this location: [https://digit-recognition.ixartz.com](https://digit-recognition.ixartz.com).
6 | 7 | ### Videos 8 | 9 | Phone (iOS and Android) version: 10 | 11 | ![Demo on phone](https://digit-recognition.ixartz.com/images/phone.gif) 12 | 13 | Desktop version: 14 | 15 | ![Demo on desktop](https://digit-recognition.ixartz.com/images/desktop.gif) 16 | 17 | ## Setup environment 18 | 19 | This project is based on a JavaScript environment; install the dependencies with Yarn or npm: 20 | 21 | $ yarn install 22 | 23 | ## Launch locally 24 | 25 | $ yarn start 26 | $ Open https://localhost:9000 with your favorite browser 27 | 28 | ## Build for production 29 | 30 | $ yarn build 31 | 32 | ## Author 33 | 34 | [Ixartz's technical blog](https://blog.ixartz.com/) 35 | -------------------------------------------------------------------------------- /netlify.toml: -------------------------------------------------------------------------------- 1 | [build] 2 | publish = 'dist' 3 | command = 'yarn build' 4 | 5 | [[redirects]] 6 | from = "http://digit-recognition.netlify.com/*" 7 | to = "https://digit-recognition.ixartz.com/:splat" 8 | status = 301 9 | force = true 10 | 11 | [[redirects]] 12 | from = "https://digit-recognition.netlify.com/*" 13 | to = "https://digit-recognition.ixartz.com/:splat" 14 | status = 301 15 | force = true 16 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tensorflow.js", 3 | "version": "1.0.0", 4 | "description": "", 5 | "main": "index.js", 6 | "scripts": { 7 | "start": "webpack-dev-server --hot --env.WEBPACK_CONFIG=dev --progress", 8 | "build": "webpack --env.WEBPACK_CONFIG=prod --progress", 9 | "http-server": "http-server -p 9000 ./dist", 10 | "serve": "yarn build && yarn http-server", 11 | "serve-analyze": "webpack --env.WEBPACK_CONFIG=analyze --progress && yarn http-server", 12 | "build-stats": "webpack --env.WEBPACK_CONFIG=prod --progress --profile --json > stats.json", 13 | "format":
"prettier-eslint src/**/*.jsx src/**/*.js --write", 14 | "test": "eslint --ext .jsx,.js src/**" 15 | }, 16 | "husky": { 17 | "hooks": { 18 | "pre-commit": "lint-staged" 19 | } 20 | }, 21 | "lint-staged": { 22 | "*.{js,jsx}": [ 23 | "prettier-eslint --write", 24 | "git add" 25 | ] 26 | }, 27 | "keywords": [], 28 | "author": "", 29 | "license": "ISC", 30 | "dependencies": { 31 | "@tensorflow/tfjs": "^0.12.5", 32 | "antd": "^3.6.6", 33 | "bizcharts": "^3.2.1-beta.2", 34 | "prop-types": "^15.6.2", 35 | "react": "^16.4.2", 36 | "react-dom": "^16.4.2", 37 | "react-hot-loader": "^4.3.4", 38 | "react-redux": "^5.0.7", 39 | "redux": "^4.0.0", 40 | "redux-devtools-extension": "^2.13.5", 41 | "redux-saga": "^0.16.0", 42 | "styled-components": "^3.3.3" 43 | }, 44 | "devDependencies": { 45 | "async-stylesheet-webpack-plugin": "^0.4.1", 46 | "babel-core": "^6.26.0", 47 | "babel-eslint": "^8.2.5", 48 | "babel-loader": "^7.1.4", 49 | "babel-plugin-import": "^1.8.0", 50 | "babel-plugin-styled-components": "^1.5.1", 51 | "babel-plugin-transform-class-properties": "^6.24.1", 52 | "babel-plugin-transform-object-rest-spread": "^6.26.0", 53 | "babel-preset-react": "^6.24.1", 54 | "babel-preset-react-optimize": "^1.0.1", 55 | "clean-webpack-plugin": "^0.1.19", 56 | "copy-webpack-plugin": "^4.5.2", 57 | "css-loader": "^1.0.0", 58 | "eslint": "^5.0.1", 59 | "eslint-config-airbnb": "^17.0.0", 60 | "eslint-loader": "^2.1.0", 61 | "eslint-plugin-import": "^2.13.0", 62 | "eslint-plugin-jsx-a11y": "^6.1.0", 63 | "eslint-plugin-react": "^7.10.0", 64 | "html-webpack-inline-source-plugin": "^0.0.10", 65 | "html-webpack-plugin": "^3.2.0", 66 | "http-server": "^0.11.1", 67 | "husky": "^1.0.0-rc.13", 68 | "lint-staged": "^7.2.2", 69 | "mini-css-extract-plugin": "^0.4.1", 70 | "optimize-css-assets-webpack-plugin": "^5.0.0", 71 | "prettier-eslint": "^8.8.2", 72 | "prettier-eslint-cli": "^4.7.1", 73 | "style-loader": "^0.22.1", 74 | "uglifyjs-webpack-plugin": "^1.3.0", 75 | "webpack": "^4.5.0", 76 | 
"webpack-bundle-analyzer": "^3.3.2", 77 | "webpack-cdn-plugin": "^3.1.4", 78 | "webpack-cli": "^3.1.0", 79 | "webpack-dev-server": "^3.1.3", 80 | "webpack-merge": "^4.1.4" 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /public/classifiers/group1-shard1of1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ixartz/handwritten-digit-recognition-tensorflowjs/20864a2b544a547b8e8a4a961f12b9f5af89b584/public/classifiers/group1-shard1of1 -------------------------------------------------------------------------------- /public/classifiers/model.json: -------------------------------------------------------------------------------- 1 | {"modelTopology": {"keras_version": "2.1.6", "backend": "tensorflow", "model_config": {"class_name": "Sequential", "config": [{"class_name": "Conv2D", "config": {"name": "conv2d_21", "trainable": true, "batch_input_shape": [null, 28, 28, 1], "dtype": "float32", "filters": 32, "kernel_size": [5, 5], "strides": [1, 1], "padding": "same", "data_format": "channels_last", "dilation_rate": [1, 1], "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "Conv2D", "config": {"name": "conv2d_22", "trainable": true, "filters": 32, "kernel_size": [5, 5], "strides": [1, 1], "padding": "same", "data_format": "channels_last", "dilation_rate": [1, 1], "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": 
null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "MaxPooling2D", "config": {"name": "max_pooling2d_11", "trainable": true, "pool_size": [2, 2], "padding": "valid", "strides": [2, 2], "data_format": "channels_last"}}, {"class_name": "Dropout", "config": {"name": "dropout_16", "trainable": true, "rate": 0.25, "noise_shape": null, "seed": null}}, {"class_name": "Conv2D", "config": {"name": "conv2d_23", "trainable": true, "filters": 64, "kernel_size": [3, 3], "strides": [1, 1], "padding": "same", "data_format": "channels_last", "dilation_rate": [1, 1], "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "Conv2D", "config": {"name": "conv2d_24", "trainable": true, "filters": 64, "kernel_size": [3, 3], "strides": [1, 1], "padding": "same", "data_format": "channels_last", "dilation_rate": [1, 1], "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "MaxPooling2D", "config": {"name": "max_pooling2d_12", "trainable": true, "pool_size": [2, 2], "padding": "valid", "strides": [2, 2], "data_format": "channels_last"}}, {"class_name": "Dropout", "config": {"name": "dropout_17", "trainable": true, "rate": 0.25, "noise_shape": null, "seed": null}}, {"class_name": "Flatten", "config": {"name": "flatten_6", "trainable": true, 
"data_format": "channels_last"}}, {"class_name": "Dense", "config": {"name": "dense_11", "trainable": true, "units": 256, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "Dropout", "config": {"name": "dropout_18", "trainable": true, "rate": 0.5, "noise_shape": null, "seed": null}}, {"class_name": "Dense", "config": {"name": "dense_12", "trainable": true, "units": 10, "activation": "softmax", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}]}, "training_config": {"optimizer_config": {"class_name": "RMSprop", "config": {"lr": 6.25000029685907e-05, "rho": 0.8999999761581421, "decay": 0.0, "epsilon": 1e-08}}, "loss": "categorical_crossentropy", "metrics": ["accuracy"], "sample_weight_mode": null, "loss_weights": null}}, "weightsManifest": [{"paths": ["group1-shard1of1"], "weights": [{"name": "conv2d_21/kernel", "shape": [5, 5, 1, 32], "dtype": "float32"}, {"name": "conv2d_21/bias", "shape": [32], "dtype": "float32"}, {"name": "conv2d_22/kernel", "shape": [5, 5, 32, 32], "dtype": "float32"}, {"name": "conv2d_22/bias", "shape": [32], "dtype": "float32"}, {"name": "conv2d_23/kernel", "shape": [3, 3, 32, 64], "dtype": "float32"}, {"name": "conv2d_23/bias", "shape": [64], "dtype": "float32"}, {"name": "conv2d_24/kernel", "shape": [3, 3, 64, 64], "dtype": "float32"}, {"name": "conv2d_24/bias", "shape": [64], "dtype": "float32"}, 
{"name": "dense_11/kernel", "shape": [3136, 256], "dtype": "float32"}, {"name": "dense_11/bias", "shape": [256], "dtype": "float32"}, {"name": "dense_12/kernel", "shape": [256, 10], "dtype": "float32"}, {"name": "dense_12/bias", "shape": [10], "dtype": "float32"}]}]} -------------------------------------------------------------------------------- /public/images/desktop.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ixartz/handwritten-digit-recognition-tensorflowjs/20864a2b544a547b8e8a4a961f12b9f5af89b584/public/images/desktop.gif -------------------------------------------------------------------------------- /public/images/phone.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ixartz/handwritten-digit-recognition-tensorflowjs/20864a2b544a547b8e8a4a961f12b9f5af89b584/public/images/phone.gif -------------------------------------------------------------------------------- /src/actions/chart.js: -------------------------------------------------------------------------------- 1 | export const ChartAction = { 2 | ADD_LOSS_POINT: 'ADD_LOSS_POINT', 3 | ADD_ACCURACY_POINT: 'ADD_ACCURACY_POINT', 4 | RESET_CHART: 'RESET_CHART', 5 | }; 6 | 7 | export const addLossPoint = (batch, loss) => ({ 8 | type: ChartAction.ADD_LOSS_POINT, 9 | pt: { 10 | batch, 11 | loss, 12 | }, 13 | }); 14 | 15 | export const addAccuracyPoint = (batch, accuracy) => ({ 16 | type: ChartAction.ADD_ACCURACY_POINT, 17 | pt: { 18 | batch, 19 | accuracy, 20 | }, 21 | }); 22 | 23 | export const resetChart = () => ({ 24 | type: ChartAction.RESET_CHART, 25 | }); 26 | -------------------------------------------------------------------------------- /src/actions/drawing.js: -------------------------------------------------------------------------------- 1 | export const StrokeAction = { 2 | ADD_STROKE: 'ADD_STROKE', 3 | ADD_STROKE_POS: 'ADD_STROKE_POS', 4 | END_STROKE: 
'END_STROKE', 5 | RESET_DRAWING: 'RESET_DRAWING', 6 | }; 7 | 8 | export const addStroke = pos => ({ 9 | type: StrokeAction.ADD_STROKE, 10 | pos, 11 | }); 12 | 13 | export const addStrokePos = pos => ({ 14 | type: StrokeAction.ADD_STROKE_POS, 15 | pos, 16 | }); 17 | 18 | export const endStroke = () => ({ 19 | type: StrokeAction.END_STROKE, 20 | }); 21 | 22 | export const resetDrawing = () => ({ 23 | type: StrokeAction.RESET_DRAWING, 24 | }); 25 | -------------------------------------------------------------------------------- /src/actions/mnist.js: -------------------------------------------------------------------------------- 1 | export const MnistAction = { 2 | INIT: 'INIT', 3 | LOAD_PRETRAINED_MODEL_SUCCEEDED: 'LOAD_PRETRAINED_MODEL_SUCCEEDED', 4 | LOAD_AND_TRAIN_MNIST_REQUESTED: 'LOAD_AND_TRAIN_MNIST_REQUESTED', 5 | LOADING_MNIST: 'LOADING_MNIST', 6 | TRAINING_MNIST: 'TRAINING_MNIST', 7 | LOAD_AND_TRAIN_MNIST_SUCCEEDED: 'LOAD_AND_TRAIN_MNIST_SUCCEEDED', 8 | PREDICT_REQUESTED: 'PREDICT_REQUESTED', 9 | PREDICT_SUCCEEDED: 'PREDICT_SUCCEEDED', 10 | }; 11 | 12 | export const loadPretrainedModelSucceeded = () => ({ 13 | type: MnistAction.LOAD_PRETRAINED_MODEL_SUCCEEDED, 14 | }); 15 | 16 | export const loadAndTrainMnist = () => ({ 17 | type: MnistAction.LOAD_AND_TRAIN_MNIST_REQUESTED, 18 | }); 19 | 20 | export const loadingMnist = () => ({ 21 | type: MnistAction.LOADING_MNIST, 22 | }); 23 | 24 | export const trainingMnist = () => ({ 25 | type: MnistAction.TRAINING_MNIST, 26 | }); 27 | 28 | export const loadAndTrainMnistSucceeded = () => ({ 29 | type: MnistAction.LOAD_AND_TRAIN_MNIST_SUCCEEDED, 30 | }); 31 | 32 | export const requestPredict = image => ({ 33 | type: MnistAction.PREDICT_REQUESTED, 34 | image, 35 | }); 36 | 37 | export const predictSucceeded = prediction => ({ 38 | type: MnistAction.PREDICT_SUCCEEDED, 39 | prediction, 40 | }); 41 | -------------------------------------------------------------------------------- /src/actions/pipeline.js: 
-------------------------------------------------------------------------------- 1 | export const PipelineAction = { 2 | INPUT: 'INPUT', 3 | DISPLAY_BOUNDING_BOX: 'DISPLAY_BOUNDING_BOX', 4 | DISPLAY_CROPPED_BOX: 'DISPLAY_CROPPED_BOX', 5 | DISPLAY_CENTERED_BOX: 'DISPLAY_CENTERED_BOX', 6 | DISPLAY_NORMALIZED_BOX: 'DISPLAY_NORMALIZED_BOX', 7 | }; 8 | 9 | export const addInput = image => ({ 10 | type: PipelineAction.INPUT, 11 | image, 12 | }); 13 | 14 | export const displayBoundingBox = (imageUrl, rect) => ({ 15 | type: PipelineAction.DISPLAY_BOUNDING_BOX, 16 | imageUrl, 17 | rect, 18 | }); 19 | 20 | export const displayCroppedBox = croppedUrl => ({ 21 | type: PipelineAction.DISPLAY_CROPPED_BOX, 22 | croppedUrl, 23 | }); 24 | 25 | export const displayCenteredBox = centeredUrl => ({ 26 | type: PipelineAction.DISPLAY_CENTERED_BOX, 27 | centeredUrl, 28 | }); 29 | 30 | export const displayNormalizedBox = normalizedUrl => ({ 31 | type: PipelineAction.DISPLAY_NORMALIZED_BOX, 32 | normalizedUrl, 33 | }); 34 | -------------------------------------------------------------------------------- /src/classifiers/handwriting-digits-classifier.js: -------------------------------------------------------------------------------- 1 | import * as tf from '@tensorflow/tfjs'; 2 | import { put } from 'redux-saga/effects'; 3 | import { addAccuracyPoint, addLossPoint } from '../actions/chart'; 4 | 5 | export default class HandwritingDigitsClassifier { 6 | static TRAIN_BATCHES = 150; 7 | 8 | static BATCH_SIZE = 64; 9 | 10 | static TEST_ITERATION_FREQUENCY = 5; 11 | 12 | static TEST_BATCH_SIZE = 1000; 13 | 14 | static LEARNING_RATE = 0.15; 15 | 16 | static CLASSIFIER_FOLDER = 'classifiers'; 17 | 18 | static CLASSIFIER_NAME = 'model'; 19 | 20 | initializeModel(data) { 21 | this.data = data; 22 | 23 | this.model = tf.sequential(); 24 | 25 | this.model.add( 26 | tf.layers.conv2d({ 27 | inputShape: [28, 28, 1], 28 | kernelSize: 5, 29 | filters: 8, 30 | strides: 1, 31 | activation: 'relu', 32 | 
kernelInitializer: 'varianceScaling', 33 | }), 34 | ); 35 | 36 | this.model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] })); 37 | 38 | this.model.add( 39 | tf.layers.conv2d({ 40 | kernelSize: 5, 41 | filters: 16, 42 | strides: 1, 43 | activation: 'relu', 44 | kernelInitializer: 'varianceScaling', 45 | }), 46 | ); 47 | 48 | this.model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] })); 49 | 50 | this.model.add(tf.layers.flatten()); 51 | 52 | this.model.add( 53 | tf.layers.dense({ units: 10, kernelInitializer: 'varianceScaling', activation: 'softmax' }), 54 | ); 55 | 56 | this.compile(); 57 | } 58 | 59 | compile() { 60 | const optimizer = tf.train.sgd(HandwritingDigitsClassifier.LEARNING_RATE); 61 | 62 | this.model.compile({ 63 | optimizer, 64 | loss: 'categoricalCrossentropy', 65 | metrics: ['accuracy'], 66 | }); 67 | } 68 | 69 | * loadModel() { 70 | const { host, protocol } = window.location; 71 | 72 | this.model = yield tf.loadModel( 73 | `${protocol}//${host}/${HandwritingDigitsClassifier.CLASSIFIER_FOLDER}/${ 74 | HandwritingDigitsClassifier.CLASSIFIER_NAME 75 | }.json`, 76 | ); 77 | this.compile(); 78 | } 79 | 80 | getTrainBatch(i) { 81 | const batch = this.data.nextTrainBatch(HandwritingDigitsClassifier.BATCH_SIZE); 82 | batch.xs = batch.xs.reshape([HandwritingDigitsClassifier.BATCH_SIZE, 28, 28, 1]); 83 | 84 | let validationData; 85 | 86 | if (i % HandwritingDigitsClassifier.TEST_ITERATION_FREQUENCY === 0) { 87 | const testBatch = this.data.nextTrainBatch(HandwritingDigitsClassifier.TEST_BATCH_SIZE); 88 | 89 | validationData = [ 90 | testBatch.xs.reshape([HandwritingDigitsClassifier.TEST_BATCH_SIZE, 28, 28, 1]), 91 | testBatch.labels, 92 | ]; 93 | } 94 | 95 | return [batch, validationData]; 96 | } 97 | 98 | * train(save = false) { 99 | for (let i = 0; i < HandwritingDigitsClassifier.TRAIN_BATCHES; i += 1) { 100 | const [batch, validationData] = tf.tidy(() => this.getTrainBatch(i)); 101 | 102 | const history = yield 
this.model.fit(batch.xs, batch.labels, { 103 | batchSize: HandwritingDigitsClassifier.BATCH_SIZE, 104 | validationData, 105 | epochs: 1, 106 | }); 107 | 108 | const loss = history.history.loss[0]; 109 | const accuracy = history.history.acc[0]; 110 | 111 | yield put(addLossPoint(i, loss)); 112 | 113 | if (validationData != null) { 114 | yield put(addAccuracyPoint(i, accuracy)); 115 | } 116 | 117 | tf.dispose([batch, validationData]); 118 | 119 | yield tf.nextFrame(); 120 | } 121 | 122 | if (save) { 123 | yield this.model.save(`downloads://${HandwritingDigitsClassifier.CLASSIFIER_NAME}`); 124 | } 125 | } 126 | 127 | predict(dataTensor) { 128 | return tf.tidy(() => { 129 | const output = this.model.predict(dataTensor); 130 | 131 | const axis = 1; 132 | const predictions = Array.from(output.argMax(axis).dataSync()); 133 | 134 | return predictions[0]; 135 | }); 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /src/components/app.jsx: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { hot } from 'react-hot-loader'; 3 | import { Provider } from 'react-redux'; 4 | import configureStore from '../store/configureStore'; 5 | import Mnist from './mnist'; 6 | 7 | window.BizCharts.track(false); 8 | 9 | const App = class App extends React.Component { 10 | constructor(props) { 11 | super(props); 12 | 13 | this.store = configureStore(); 14 | } 15 | 16 | render() { 17 | return ( 18 | 19 | 20 | 21 | ); 22 | } 23 | }; 24 | 25 | export default hot(module)(App); 26 | -------------------------------------------------------------------------------- /src/components/bounding-box.jsx: -------------------------------------------------------------------------------- 1 | import PropTypes from 'prop-types'; 2 | import React from 'react'; 3 | import Rect from '../utils/rect'; 4 | import { Display140 } from './display'; 5 | 6 | export default class BoundingBox extends 
React.Component { 7 | static propTypes = { 8 | imageUrl: PropTypes.string, 9 | rect: PropTypes.instanceOf(Rect), 10 | }; 11 | 12 | static defaultProps = { 13 | imageUrl: null, 14 | rect: null, 15 | }; 16 | 17 | componentDidMount() { 18 | this.initContext(); 19 | } 20 | 21 | componentDidUpdate() { 22 | this.initContext(); 23 | } 24 | 25 | setDisplay = (elt) => { 26 | this.display = elt; 27 | }; 28 | 29 | initContext() { 30 | const { imageUrl, rect } = this.props; 31 | 32 | if (imageUrl) { 33 | const image = new Image(); 34 | image.onload = () => { 35 | this.display.ctx.drawImage(image, 0, 0); 36 | 37 | this.display.ctx.strokeStyle = 'red'; 38 | this.display.ctx.strokeRect( 39 | rect.xmin, 40 | rect.ymin, 41 | rect.computeWidth(), 42 | rect.computeHeight(), 43 | ); 44 | }; 45 | image.src = imageUrl; 46 | } 47 | } 48 | 49 | render() { 50 | return ; 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /src/components/button-space.jsx: -------------------------------------------------------------------------------- 1 | import { Button } from 'antd'; 2 | import styled from 'styled-components'; 3 | 4 | const ButtonSpace = styled(Button)` 5 | margin: 0 2px; 6 | `; 7 | 8 | export default ButtonSpace; 9 | -------------------------------------------------------------------------------- /src/components/center-col.jsx: -------------------------------------------------------------------------------- 1 | import { Col } from 'antd'; 2 | import styled from 'styled-components'; 3 | 4 | const CenterCol = styled(Col)` 5 | display: flex; 6 | align-items: center; 7 | justify-content: center; 8 | `; 9 | 10 | export default CenterCol; 11 | -------------------------------------------------------------------------------- /src/components/chart-section.jsx: -------------------------------------------------------------------------------- 1 | import { Col, Row } from 'antd'; 2 | import React from 'react'; 3 | import AccuracyChart from 
'../containers/accuracy-chart'; 4 | import LossChart from '../containers/loss-chart'; 5 | 6 | const ChartSection = () => ( 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | ); 16 | 17 | export default ChartSection; 18 | -------------------------------------------------------------------------------- /src/components/confirm-retrain-button.jsx: -------------------------------------------------------------------------------- 1 | import { Popconfirm } from 'antd'; 2 | import PropTypes from 'prop-types'; 3 | import React from 'react'; 4 | import styled from 'styled-components'; 5 | import RetrainButton from '../containers/retrain-button'; 6 | 7 | const PWarning = styled.p` 8 | color: red; 9 | `; 10 | 11 | const content = ( 12 |
13 |

14 | The system has already loaded a pretrained model with 0.99700 accuracy. If you want to train a 15 | new model, please click on continue button. 16 |

17 | 18 | It is unrecommended to run the training on your phone. The process will drain your battery. 19 | 20 |
21 | ); 22 | 23 | const ConfirmRetrainButton = ({ onConfirm }) => ( 24 | 25 | 26 | 27 | ); 28 | 29 | ConfirmRetrainButton.propTypes = { 30 | onConfirm: PropTypes.func.isRequired, 31 | }; 32 | 33 | export default ConfirmRetrainButton; 34 | -------------------------------------------------------------------------------- /src/components/display.jsx: -------------------------------------------------------------------------------- 1 | import { Row } from 'antd'; 2 | import PropTypes from 'prop-types'; 3 | import React from 'react'; 4 | import styled from 'styled-components'; 5 | import CenterCol from './center-col'; 6 | 7 | const DisplayRow = styled(Row)` 8 | margin-bottom: 10px; 9 | `; 10 | 11 | const H3 = styled.h3` 12 | margin: 0; 13 | `; 14 | 15 | const Canvas = styled.canvas` 16 | display: block; 17 | `; 18 | 19 | export class Display extends React.Component { 20 | static propTypes = { 21 | imageUrl: PropTypes.string, 22 | className: PropTypes.string, 23 | width: PropTypes.number, 24 | height: PropTypes.number, 25 | title: PropTypes.string.isRequired, 26 | }; 27 | 28 | static defaultProps = { 29 | imageUrl: null, 30 | className: '', 31 | width: 280, 32 | height: 280, 33 | }; 34 | 35 | constructor(props) { 36 | super(props); 37 | this.canvas = null; 38 | this.ctx = null; 39 | } 40 | 41 | componentDidMount() { 42 | this.initContext(); 43 | } 44 | 45 | componentDidUpdate() { 46 | this.initContext(); 47 | } 48 | 49 | setCanvasRef = (elt) => { 50 | this.canvas = elt; 51 | }; 52 | 53 | initContext() { 54 | this.ctx = this.canvas.getContext('2d'); 55 | 56 | this.clearCanvas(); 57 | this.loadImage(); 58 | } 59 | 60 | clearCanvas() { 61 | this.ctx.fillRect(0, 0, this.canvas.width, this.canvas.height); 62 | } 63 | 64 | loadImage() { 65 | const { imageUrl } = this.props; 66 | 67 | if (imageUrl) { 68 | const image = new Image(); 69 | image.onload = () => { 70 | this.ctx.drawImage(image, 0, 0); 71 | }; 72 | image.src = imageUrl; 73 | } 74 | } 75 | 76 | render() { 77 | const { 
78 | className, width, height, title, 79 | } = this.props; 80 | 81 | return ( 82 | 83 | 84 |

{title}

85 |
86 | 87 | 93 | 94 |
95 | ); 96 | } 97 | } 98 | 99 | export const Display140 = styled(Display)` 100 | width: 140px; 101 | height: 140px; 102 | `; 103 | 104 | export const Display200 = styled(Display)` 105 | width: 200px; 106 | height: 200px; 107 | `; 108 | 109 | export const Display100 = styled(Display)` 110 | width: 100px; 111 | height: 100px; 112 | `; 113 | 114 | export const Display28 = styled(Display)` 115 | width: 28px; 116 | height: 28px; 117 | `; 118 | -------------------------------------------------------------------------------- /src/components/drawing.jsx: -------------------------------------------------------------------------------- 1 | import PropTypes from 'prop-types'; 2 | import React from 'react'; 3 | import styled from 'styled-components'; 4 | 5 | const Canvas = styled.canvas` 6 | width: 300px; 7 | height: 300px; 8 | border-color: dodgerblue; 9 | border-width: 5px; 10 | border-style: solid; 11 | display: block; 12 | touch-action: none; 13 | 14 | &:hover { 15 | border-color: deepskyblue; 16 | } 17 | `; 18 | 19 | export default class Drawing extends React.Component { 20 | static propTypes = { 21 | isDrawing: PropTypes.bool.isRequired, 22 | isEndStroke: PropTypes.bool.isRequired, 23 | strokes: PropTypes.arrayOf( 24 | PropTypes.arrayOf(PropTypes.shape({ x: PropTypes.number, y: PropTypes.number })), 25 | ).isRequired, 26 | addStroke: PropTypes.func.isRequired, 27 | addStrokePos: PropTypes.func.isRequired, 28 | endStroke: PropTypes.func.isRequired, 29 | addInput: PropTypes.func.isRequired, 30 | }; 31 | 32 | constructor(props) { 33 | super(props); 34 | this.canvas = null; 35 | this.ctx = null; 36 | } 37 | 38 | componentDidMount() { 39 | this.initContext(); 40 | } 41 | 42 | componentDidUpdate() { 43 | this.initContext(); 44 | this.drawStrokes(); 45 | 46 | const { isEndStroke, addInput } = this.props; 47 | 48 | if (isEndStroke) { 49 | addInput(this.ctx.getImageData(0, 0, 280, 280)); 50 | } 51 | } 52 | 53 | onMouseDown = (e) => { 54 | const { addStroke } = this.props; 55 | 
addStroke(this.computeMousePos(e)); 56 | }; 57 | 58 | onMouseMove = (e) => { 59 | const { isDrawing, addStrokePos } = this.props; 60 | 61 | if (!isDrawing) { 62 | return; 63 | } 64 | 65 | addStrokePos(this.computeMousePos(e)); 66 | }; 67 | 68 | onStrokeEnd = () => { 69 | const { isDrawing, endStroke } = this.props; 70 | 71 | if (isDrawing) { 72 | endStroke(); 73 | } 74 | }; 75 | 76 | setCanvasRef = (elt) => { 77 | this.canvas = elt; 78 | }; 79 | 80 | initContext() { 81 | this.ctx = this.canvas.getContext('2d'); 82 | this.ctx.lineWidth = 10; 83 | this.ctx.lineJoin = 'round'; 84 | this.ctx.lineCap = 'round'; 85 | this.ctx.strokeStyle = 'white'; 86 | 87 | this.ctx.fillStyle = 'dark'; 88 | 89 | this.clearCanvas(); 90 | } 91 | 92 | clearCanvas() { 93 | this.ctx.fillRect(0, 0, this.canvas.width, this.canvas.height); 94 | } 95 | 96 | computeMousePos(e) { 97 | return { 98 | x: this.computeMousePosX(e), 99 | y: this.computeMousePosY(e), 100 | }; 101 | } 102 | 103 | computeMousePosX(e) { 104 | const rect = this.canvas.getBoundingClientRect(); 105 | const scaleX = this.canvas.width / rect.width; 106 | 107 | return (e.clientX - rect.left) * scaleX; 108 | } 109 | 110 | computeMousePosY(e) { 111 | const rect = this.canvas.getBoundingClientRect(); 112 | const scaleY = this.canvas.height / rect.height; 113 | 114 | return (e.clientY - rect.top) * scaleY; 115 | } 116 | 117 | drawStrokes() { 118 | const { strokes } = this.props; 119 | 120 | for (let j = 0; j < strokes.length; j += 1) { 121 | const points = strokes[j]; 122 | 123 | this.ctx.beginPath(); 124 | this.ctx.moveTo(points[0].x, points[0].y); 125 | 126 | for (let i = 1; i < points.length; i += 1) { 127 | this.ctx.lineTo(points[i].x, points[i].y); 128 | } 129 | this.ctx.stroke(); 130 | } 131 | } 132 | 133 | render() { 134 | return ( 135 |
136 | 145 |
146 | ); 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /src/components/image-pipeline.jsx: -------------------------------------------------------------------------------- 1 | import { Row } from 'antd'; 2 | import React from 'react'; 3 | import styled from 'styled-components'; 4 | import BoundingBox from '../containers/bounding-box'; 5 | import CenteredBox from '../containers/centered-box'; 6 | import CroppedBox from '../containers/cropped-box'; 7 | import NormalizedBox from '../containers/normalized-box'; 8 | import CenterCol from './center-col'; 9 | 10 | const ImagePipelineRow = styled(Row)` 11 | margin-top: 20px; 12 | `; 13 | 14 | const ImagePipeline = () => ( 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | ); 30 | 31 | export default ImagePipeline; 32 | -------------------------------------------------------------------------------- /src/components/line-chart.jsx: -------------------------------------------------------------------------------- 1 | import { 2 | Axis, Chart, Geom, Tooltip, 3 | } from 'bizcharts'; 4 | import PropTypes from 'prop-types'; 5 | import React from 'react'; 6 | 7 | export default class LineChart extends React.Component { 8 | static propTypes = { 9 | data: PropTypes.oneOfType([PropTypes.arrayOf(PropTypes.object), PropTypes.object]), 10 | cols: PropTypes.oneOfType([PropTypes.object, PropTypes.array]), 11 | }; 12 | 13 | static defaultProps = { 14 | data: null, 15 | cols: null, 16 | }; 17 | 18 | static getObjectKeyName(obj, i) { 19 | return Object.keys(obj)[i]; 20 | } 21 | 22 | render() { 23 | const { data, cols } = this.props; 24 | const firstAxis = LineChart.getObjectKeyName(cols, 0); 25 | const secondAxis = LineChart.getObjectKeyName(cols, 1); 26 | const position = `${firstAxis}*${secondAxis}`; 27 | 28 | if (data.length > 0) { 29 | return ( 30 | 37 | 38 | 39 | 40 | 41 | 48 | 49 | ); 50 | } 51 | 52 | return null; 53 | } 54 | } 55 | 
-------------------------------------------------------------------------------- /src/components/mnist-command.jsx: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import ClearButton from '../containers/clear-button'; 3 | import ConfirmRetrainButton from '../containers/confirm-retrain-button'; 4 | // Command row combining a "Clear" button and the retrain confirmation (the JSX markup is missing from this dump). 5 | const MnistCommand = () => ( 6 |
7 | Clear 8 | 9 |
10 | ); 11 | 12 | export default MnistCommand; 13 | -------------------------------------------------------------------------------- /src/components/mnist-content.jsx: -------------------------------------------------------------------------------- 1 | import { Layout, Row } from 'antd'; 2 | import React from 'react'; 3 | import styled from 'styled-components'; 4 | import MnistDrawing from '../containers/mnist-drawing'; 5 | import ChartSection from './chart-section'; 6 | import ImagePipeline from './image-pipeline'; 7 | 8 | const { Content } = Layout; 9 | 10 | const ContentContainer = styled(Content)` 11 | max-width: 880px; 12 | margin: 0 auto; 13 | `; 14 | // Page body, centered and capped at 880px wide; composes MnistDrawing, ImagePipeline and ChartSection (markup missing from this dump). 15 | const MnistContent = () => ( 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | ); 24 | 25 | export default MnistContent; 26 | -------------------------------------------------------------------------------- /src/components/mnist-drawing.jsx: -------------------------------------------------------------------------------- 1 | import { Spin } from 'antd'; 2 | import PropTypes from 'prop-types'; 3 | import React from 'react'; 4 | import Drawing from '../containers/drawing'; 5 | import Output from '../containers/output'; 6 | import CenterCol from './center-col'; 7 | import MnistCommand from './mnist-command'; 8 | // Drawing canvas, prediction output and command row, presumably wrapped in an antd Spin driven by the spinning prop (markup missing from this dump). 9 | const MnistDrawing = ({ spinning }) => ( 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | ); 22 | 23 | MnistDrawing.propTypes = { 24 | spinning: PropTypes.bool, 25 | }; 26 | // spinning defaults to true (loading) when the parent supplies no value. 27 | MnistDrawing.defaultProps = { 28 | spinning: true, 29 | }; 30 | 31 | export default MnistDrawing; 32 | -------------------------------------------------------------------------------- /src/components/mnist-footer.jsx: -------------------------------------------------------------------------------- 1 | import { Icon, Layout } from 'antd'; 2 | import React from 'react'; 3 | import styled from 'styled-components'; 4 | 5 | const { Footer } = Layout; 6 | 7 | const IconFooter = styled(Icon)` 8 | margin-right: 3px; 9 |
font-size: 15px; 10 | `; 11 | 12 | const PFooter = styled.p` 13 | font-size: 8px; 14 | text-align: center; 15 | `; 16 | 17 | const AFooter = styled.a` 18 | color: inherit; 19 | `; 20 | // NOTE(review): the footer's JSX (source lines 22-39) is missing from this dump. 21 | const MnistFooter = () => ( 22 | 40 | ); 41 | 42 | export default MnistFooter; 43 | -------------------------------------------------------------------------------- /src/components/mnist-header.jsx: -------------------------------------------------------------------------------- 1 | import { Layout } from 'antd'; 2 | import React from 'react'; 3 | import styled from 'styled-components'; 4 | 5 | const { Header } = Layout; 6 | 7 | const H1 = styled.h1` 8 | color: white; 9 | overflow: hidden; 10 | text-overflow: ellipsis; 11 | white-space: nowrap; 12 | `; 13 | 14 | const WhiteHeader = styled(Header)` 15 | background: dodgerblue; 16 | margin-bottom: 20px; 17 | `; 18 | // Blue header bar whose white title is clipped to a single line with an ellipsis. 19 | const MnistHeader = () => ( 20 | 21 |

Digit Recognition with Tensorflow.js and React

22 |
23 | ); 24 | 25 | export default MnistHeader; 26 | -------------------------------------------------------------------------------- /src/components/mnist.jsx: -------------------------------------------------------------------------------- 1 | import { Layout } from 'antd'; 2 | import React from 'react'; 3 | import MnistContent from './mnist-content'; 4 | import MnistFooter from './mnist-footer'; 5 | import MnistHeader from './mnist-header'; 6 | // Page shell composing MnistHeader, MnistContent and MnistFooter, presumably inside an antd Layout (markup missing from this dump). 7 | const Mnist = () => ( 8 |
9 | 10 | 11 | 12 | 13 | 14 |
15 | ); 16 | 17 | export default Mnist; 18 | -------------------------------------------------------------------------------- /src/components/output.jsx: -------------------------------------------------------------------------------- 1 | import PropTypes from 'prop-types'; 2 | import React from 'react'; 3 | import styled from 'styled-components'; 4 | 5 | const H2 = styled.h2` 6 | text-align: center; 7 | `; 8 | // Shows "Prediction: N" once a prediction exists (neither null nor undefined), otherwise a prompt to draw a digit. 9 | export default class Output extends React.PureComponent { 10 | static propTypes = { 11 | prediction: PropTypes.number, 12 | }; 13 | 14 | static defaultProps = { 15 | prediction: null, 16 | }; 17 | 18 | render() { 19 | const { prediction } = this.props; 20 | 21 | return ( 22 |

23 | {prediction != null 24 | ? `Prediction: ${prediction}` 25 | : 'Draw a number (0-9) in the black box above'} 26 |

27 | ); 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/containers/accuracy-chart.jsx: -------------------------------------------------------------------------------- 1 | import { connect } from 'react-redux'; 2 | import LineChart from '../components/line-chart'; 3 | // Feeds the accumulated accuracy points (state.chart.accuracy) to LineChart, plotted as batch vs. accuracy. 4 | const mapStateToProps = state => ({ 5 | data: state.chart.accuracy, 6 | cols: { 7 | batch: { range: [0, 1], alias: 'Batch' }, 8 | accuracy: { min: 0, alias: 'Accuracy' }, 9 | }, 10 | }); 11 | 12 | export default connect(mapStateToProps)(LineChart); 13 | -------------------------------------------------------------------------------- /src/containers/bounding-box.jsx: -------------------------------------------------------------------------------- 1 | import { connect } from 'react-redux'; 2 | import BoundingBox from '../components/bounding-box'; 3 | // Passes the rendered input image and its bounding rect from the pipeline reducer. 4 | const mapStateToProps = state => ({ 5 | imageUrl: state.pipeline.imageUrl, 6 | rect: state.pipeline.rect, 7 | }); 8 | 9 | export default connect(mapStateToProps)(BoundingBox); 10 | -------------------------------------------------------------------------------- /src/containers/centered-box.jsx: -------------------------------------------------------------------------------- 1 | import { connect } from 'react-redux'; 2 | import { Display140 } from '../components/display'; 3 | // Centered 280x280 pipeline image, displayed at 140px by Display140. 4 | const mapStateToProps = state => ({ 5 | imageUrl: state.pipeline.centeredUrl, 6 | width: 280, 7 | height: 280, 8 | title: 'Centered image', 9 | }); 10 | 11 | export default connect(mapStateToProps)(Display140); 12 | -------------------------------------------------------------------------------- /src/containers/clear-button.jsx: -------------------------------------------------------------------------------- 1 | import { connect } from 'react-redux'; 2 | import { resetDrawing } from '../actions/drawing'; 3 | import ButtonSpace from '../components/button-space'; 4 | // Clear stays disabled until a first prediction exists; clicking it resets the drawing. 5 | const mapStateToProps = state => ({ 6 | disabled:
state.mnist.prediction == null, 7 | }); 8 | 9 | const mapDispatchToProps = dispatch => ({ 10 | onClick: () => dispatch(resetDrawing()), 11 | }); 12 | 13 | export default connect( 14 | mapStateToProps, 15 | mapDispatchToProps, 16 | )(ButtonSpace); 17 | -------------------------------------------------------------------------------- /src/containers/confirm-retrain-button.jsx: -------------------------------------------------------------------------------- 1 | import { connect } from 'react-redux'; 2 | import { loadAndTrainMnist } from '../actions/mnist'; 3 | import ConfirmRetrainButton from '../components/confirm-retrain-button'; 4 | // On confirm, dispatches loadAndTrainMnist to retrain the classifier on MNIST. 5 | const mapDispatchToProps = dispatch => ({ 6 | onConfirm: () => dispatch(loadAndTrainMnist()), 7 | }); 8 | 9 | export default connect( 10 | null, 11 | mapDispatchToProps, 12 | )(ConfirmRetrainButton); 13 | -------------------------------------------------------------------------------- /src/containers/cropped-box.jsx: -------------------------------------------------------------------------------- 1 | import { connect } from 'react-redux'; 2 | import { Display100 } from '../components/display'; 3 | // Cropped 200x200 pipeline image, displayed at 100px by Display100. 4 | const mapStateToProps = state => ({ 5 | imageUrl: state.pipeline.croppedUrl, 6 | width: 200, 7 | height: 200, 8 | title: 'Cropped image', 9 | }); 10 | 11 | export default connect(mapStateToProps)(Display100); 12 | -------------------------------------------------------------------------------- /src/containers/drawing.jsx: -------------------------------------------------------------------------------- 1 | import { connect } from 'react-redux'; 2 | import { addStroke, addStrokePos, endStroke } from '../actions/drawing'; 3 | import { addInput } from '../actions/pipeline'; 4 | import Drawing from '../components/drawing'; 5 | // Wires the Drawing canvas to the drawing reducer and forwards finished-stroke pixels to the pipeline via addInput. 6 | const mapStateToProps = state => ({ 7 | isDrawing: state.drawing.isDrawing, 8 | isEndStroke: state.drawing.isEndStroke, 9 | strokes: state.drawing.strokes, 10 | }); 11 | 12 | const mapDispatchToProps = dispatch =>
({ 13 | addStroke: pos => dispatch(addStroke(pos)), 14 | addStrokePos: pos => dispatch(addStrokePos(pos)), 15 | endStroke: () => dispatch(endStroke()), 16 | addInput: image => dispatch(addInput(image)), 17 | }); 18 | 19 | export default connect( 20 | mapStateToProps, 21 | mapDispatchToProps, 22 | )(Drawing); 23 | -------------------------------------------------------------------------------- /src/containers/loss-chart.jsx: -------------------------------------------------------------------------------- 1 | import { connect } from 'react-redux'; 2 | import LineChart from '../components/line-chart'; 3 | // Feeds the accumulated loss points (state.chart.loss) to LineChart, plotted as batch vs. loss. 4 | const mapStateToProps = state => ({ 5 | data: state.chart.loss, 6 | cols: { 7 | batch: { range: [0, 1], alias: 'Batch' }, 8 | loss: { min: 0, alias: 'Loss' }, 9 | }, 10 | }); 11 | 12 | export default connect(mapStateToProps)(LineChart); 13 | -------------------------------------------------------------------------------- /src/containers/mnist-drawing.jsx: -------------------------------------------------------------------------------- 1 | import { connect } from 'react-redux'; 2 | import { MnistAction } from '../actions/mnist'; 3 | import MnistDrawing from '../components/mnist-drawing'; 4 | import isLoadingClassifier from '../utils/classifier'; 5 | // Spin while the pretrained model has not loaded yet (status INIT) or while a retrain is in progress. 6 | const mapStateToProps = state => ({ 7 | spinning: state.mnist.status === MnistAction.INIT || isLoadingClassifier(state), 8 | }); 9 | 10 | export default connect(mapStateToProps)(MnistDrawing); 11 | -------------------------------------------------------------------------------- /src/containers/normalized-box.jsx: -------------------------------------------------------------------------------- 1 | import { connect } from 'react-redux'; 2 | import { Display28 } from '../components/display'; 3 | // Normalized 28x28 image (the size fed to the classifier), displayed at 28px. 4 | const mapStateToProps = state => ({ 5 | imageUrl: state.pipeline.normalizedUrl, 6 | width: 28, 7 | height: 28, 8 | title: 'Normalized image', 9 | }); 10 | 11 | export default connect(mapStateToProps)(Display28); 12 |
-------------------------------------------------------------------------------- /src/containers/output.jsx: -------------------------------------------------------------------------------- 1 | import { connect } from 'react-redux'; 2 | import Output from '../components/output'; 3 | 4 | const mapStateToProps = state => ({ 5 | prediction: state.mnist.prediction, 6 | }); 7 | 8 | export default connect(mapStateToProps)(Output); 9 | -------------------------------------------------------------------------------- /src/containers/retrain-button.jsx: -------------------------------------------------------------------------------- 1 | import { connect } from 'react-redux'; 2 | import { MnistAction } from '../actions/mnist'; 3 | import ButtonSpace from '../components/button-space'; 4 | import isLoadingClassifier from '../utils/classifier'; 5 | 6 | const text = (state) => { 7 | switch (state.mnist.retrainStatus) { 8 | case MnistAction.LOADING_MNIST: 9 | return 'Loading MNIST database'; 10 | case MnistAction.TRAINING_MNIST: 11 | return 'Training new model'; 12 | default: 13 | return 'Retrain model'; 14 | } 15 | }; 16 | 17 | const mapStateToProps = state => ({ 18 | loading: isLoadingClassifier(state), 19 | disabled: state.mnist.prediction == null, 20 | children: text(state), 21 | }); 22 | 23 | // Do delete mapDispatchToProps, keep it empty. Otherwise, it will create an error on the console 24 | const mapDispatchToProps = () => ({}); 25 | 26 | export default connect( 27 | mapStateToProps, 28 | mapDispatchToProps, 29 | )(ButtonSpace); 30 | -------------------------------------------------------------------------------- /src/data/mnist-data.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | /* eslint-disable */ 19 | 20 | import * as tf from '@tensorflow/tfjs'; 21 | 22 | const IMAGE_SIZE = 784; 23 | const NUM_CLASSES = 10; 24 | const NUM_DATASET_ELEMENTS = 65000; 25 | 26 | const NUM_TRAIN_ELEMENTS = 55000; 27 | const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS; 28 | 29 | const MNIST_IMAGES_SPRITE_PATH = 30 | 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png'; 31 | const MNIST_LABELS_PATH = 32 | 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8'; 33 | 34 | /** 35 | * A class that fetches the sprited MNIST dataset and returns shuffled batches. 36 | * 37 | * NOTE: This will get much easier. For now, we do data fetching and 38 | * manipulation manually. 39 | */ 40 | export default class MnistData { 41 | constructor() { 42 | this.shuffledTrainIndex = 0; 43 | this.shuffledTestIndex = 0; 44 | } 45 | 46 | load = async () => { 47 | // Make a request for the MNIST sprited image. 
48 | const img = new Image(); 49 | const canvas = document.createElement('canvas'); 50 | const ctx = canvas.getContext('2d'); 51 | const imgRequest = new Promise((resolve, reject) => { 52 | img.crossOrigin = ''; 53 | img.onload = () => { 54 | img.width = img.naturalWidth; 55 | img.height = img.naturalHeight; 56 | 57 | const datasetBytesBuffer = new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4); 58 | 59 | const chunkSize = 5000; 60 | canvas.width = img.width; 61 | canvas.height = chunkSize; 62 | 63 | for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) { 64 | const datasetBytesView = new Float32Array( 65 | datasetBytesBuffer, 66 | i * IMAGE_SIZE * chunkSize * 4, 67 | IMAGE_SIZE * chunkSize, 68 | ); 69 | ctx.drawImage(img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width, chunkSize); 70 | 71 | const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height); 72 | 73 | for (let j = 0; j < imageData.data.length / 4; j++) { 74 | // All channels hold an equal value since the image is grayscale, so 75 | // just read the red channel. 76 | datasetBytesView[j] = imageData.data[j * 4] / 255; 77 | } 78 | } 79 | this.datasetImages = new Float32Array(datasetBytesBuffer); 80 | 81 | resolve(); 82 | }; 83 | img.src = MNIST_IMAGES_SPRITE_PATH; 84 | }); 85 | 86 | const labelsRequest = fetch(MNIST_LABELS_PATH); 87 | const [imgResponse, labelsResponse] = await Promise.all([imgRequest, labelsRequest]); 88 | 89 | this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer()); 90 | 91 | // Create shuffled indices into the train/test set for when we select a 92 | // random dataset element for training / validation. 93 | this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS); 94 | this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS); 95 | 96 | // Slice the the images and labels into train and test sets. 
97 | this.trainImages = this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS); 98 | this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS); 99 | this.trainLabels = this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS); 100 | this.testLabels = this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS); 101 | }; 102 | 103 | nextTrainBatch(batchSize) { 104 | return this.nextBatch(batchSize, [this.trainImages, this.trainLabels], () => { 105 | this.shuffledTrainIndex = (this.shuffledTrainIndex + 1) % this.trainIndices.length; 106 | return this.trainIndices[this.shuffledTrainIndex]; 107 | }); 108 | } 109 | 110 | nextTestBatch(batchSize) { 111 | return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => { 112 | this.shuffledTestIndex = (this.shuffledTestIndex + 1) % this.testIndices.length; 113 | return this.testIndices[this.shuffledTestIndex]; 114 | }); 115 | } 116 | 117 | nextBatch(batchSize, data, index) { 118 | const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE); 119 | const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES); 120 | 121 | for (let i = 0; i < batchSize; i++) { 122 | const idx = index(); 123 | 124 | const image = data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE); 125 | batchImagesArray.set(image, i * IMAGE_SIZE); 126 | 127 | const label = data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES); 128 | batchLabelsArray.set(label, i * NUM_CLASSES); 129 | } 130 | 131 | const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]); 132 | const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]); 133 | 134 | return { xs, labels }; 135 | } 136 | } 137 | -------------------------------------------------------------------------------- /src/main.jsx: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import ReactDom from 'react-dom'; 3 | import App from './components/app'; 4 | 5 | 
ReactDom.render(, document.getElementById('app')); 6 | -------------------------------------------------------------------------------- /src/reducers/chart.js: -------------------------------------------------------------------------------- 1 | import { ChartAction } from '../actions/chart'; 2 | 3 | const initialState = { 4 | loss: [], 5 | accuracy: [], 6 | }; 7 | 8 | const chart = (state = initialState, action) => { 9 | switch (action.type) { 10 | case ChartAction.ADD_LOSS_POINT: 11 | return { 12 | ...state, 13 | loss: [...state.loss, action.pt], 14 | }; 15 | case ChartAction.ADD_ACCURACY_POINT: 16 | return { 17 | ...state, 18 | accuracy: [...state.accuracy, action.pt], 19 | }; 20 | case ChartAction.RESET_CHART: 21 | return initialState; 22 | default: 23 | return state; 24 | } 25 | }; 26 | 27 | export default chart; 28 | -------------------------------------------------------------------------------- /src/reducers/drawing.js: -------------------------------------------------------------------------------- 1 | import { StrokeAction } from '../actions/drawing'; 2 | 3 | const initialState = { 4 | isDrawing: false, 5 | isEndStroke: false, 6 | strokes: [], 7 | }; 8 | 9 | const drawing = (state = initialState, action) => { 10 | switch (action.type) { 11 | case StrokeAction.ADD_STROKE: 12 | return { 13 | ...state, 14 | isDrawing: true, 15 | isEndStroke: false, 16 | strokes: [...state.strokes, [action.pos]], 17 | }; 18 | case StrokeAction.ADD_STROKE_POS: 19 | return { 20 | ...state, 21 | strokes: [ 22 | ...state.strokes.slice(0, state.strokes.length - 1), 23 | [...state.strokes[state.strokes.length - 1], action.pos], 24 | ], 25 | }; 26 | case StrokeAction.END_STROKE: 27 | return { 28 | ...state, 29 | isDrawing: false, 30 | isEndStroke: true, 31 | }; 32 | case StrokeAction.RESET_DRAWING: 33 | return initialState; 34 | default: 35 | return state; 36 | } 37 | }; 38 | 39 | export default drawing; 40 | 
-------------------------------------------------------------------------------- /src/reducers/index.js: -------------------------------------------------------------------------------- 1 | import { combineReducers } from 'redux'; 2 | import chart from './chart'; 3 | import drawing from './drawing'; 4 | import mnist from './mnist'; 5 | import pipeline from './pipeline'; 6 | 7 | export default combineReducers({ 8 | drawing, 9 | pipeline, 10 | mnist, 11 | chart, 12 | }); 13 | -------------------------------------------------------------------------------- /src/reducers/mnist.js: -------------------------------------------------------------------------------- 1 | import { MnistAction } from '../actions/mnist'; 2 | 3 | const initialState = { 4 | status: MnistAction.INIT, 5 | retrainStatus: MnistAction.INIT, 6 | }; 7 | 8 | const mnist = (state = initialState, action) => { 9 | switch (action.type) { 10 | case MnistAction.LOAD_PRETRAINED_MODEL_SUCCEEDED: 11 | return { 12 | ...state, 13 | status: MnistAction.LOAD_PRETRAINED_MODEL_SUCCEEDED, 14 | }; 15 | case MnistAction.LOADING_MNIST: 16 | case MnistAction.TRAINING_MNIST: 17 | case MnistAction.LOAD_AND_TRAIN_MNIST_SUCCEEDED: 18 | return { 19 | ...state, 20 | retrainStatus: action.type, 21 | }; 22 | case MnistAction.PREDICT_SUCCEEDED: 23 | return { 24 | ...state, 25 | prediction: action.prediction, 26 | }; 27 | default: 28 | return state; 29 | } 30 | }; 31 | 32 | export default mnist; 33 | -------------------------------------------------------------------------------- /src/reducers/pipeline.js: -------------------------------------------------------------------------------- 1 | import { PipelineAction } from '../actions/pipeline'; 2 | 3 | const pipeline = (state = {}, action) => { 4 | switch (action.type) { 5 | case PipelineAction.DISPLAY_BOUNDING_BOX: 6 | return { 7 | ...state, 8 | imageUrl: action.imageUrl, 9 | rect: action.rect, 10 | }; 11 | case PipelineAction.DISPLAY_CROPPED_BOX: 12 | return { 13 | ...state, 14 | 
croppedUrl: action.croppedUrl, 15 | }; 16 | case PipelineAction.DISPLAY_CENTERED_BOX: 17 | return { 18 | ...state, 19 | centeredUrl: action.centeredUrl, 20 | }; 21 | case PipelineAction.DISPLAY_NORMALIZED_BOX: 22 | return { 23 | ...state, 24 | normalizedUrl: action.normalizedUrl, 25 | }; 26 | default: 27 | return state; 28 | } 29 | }; 30 | 31 | export default pipeline; 32 | -------------------------------------------------------------------------------- /src/sagas/index.js: -------------------------------------------------------------------------------- 1 | import { all, fork } from 'redux-saga/effects'; 2 | import watchMnist from './mnist'; 3 | import watchPipeline from './pipeline'; 4 | 5 | export default function* rootSaga() { 6 | yield all([fork(watchMnist), fork(watchPipeline)]); 7 | } 8 | -------------------------------------------------------------------------------- /src/sagas/mnist.js: -------------------------------------------------------------------------------- 1 | import * as tf from '@tensorflow/tfjs'; 2 | import { 3 | apply, call, put, take, 4 | } from 'redux-saga/effects'; 5 | import { resetChart } from '../actions/chart'; 6 | import { resetDrawing } from '../actions/drawing'; 7 | import { 8 | loadAndTrainMnistSucceeded, 9 | loadingMnist, 10 | loadPretrainedModelSucceeded, 11 | MnistAction, 12 | predictSucceeded, 13 | trainingMnist, 14 | } from '../actions/mnist'; 15 | import HandwritingDigitsClassifier from '../classifiers/handwriting-digits-classifier'; 16 | import MnistData from '../data/mnist-data'; 17 | import { convertToGrayscale } from '../utils/image-processing'; 18 | 19 | function* trainMnist(mnistData) { 20 | yield put(trainingMnist()); 21 | const handwritingDigitsClassifier = new HandwritingDigitsClassifier(); 22 | handwritingDigitsClassifier.initializeModel(mnistData); 23 | yield apply(handwritingDigitsClassifier, handwritingDigitsClassifier.train); 24 | return handwritingDigitsClassifier; 25 | } 26 | 27 | function* loadAndTrainMnist() 
{ 28 | yield put(loadingMnist()); 29 | const mnistData = new MnistData(); 30 | yield call(mnistData.load); 31 | const handwritingDigitsClassifier = yield call(trainMnist, mnistData); 32 | return handwritingDigitsClassifier; 33 | } 34 | 35 | function* loadPretrainedModel() { 36 | const pretrainedHandwritingDigitsClassifier = new HandwritingDigitsClassifier(); 37 | yield apply( 38 | pretrainedHandwritingDigitsClassifier, 39 | pretrainedHandwritingDigitsClassifier.loadModel, 40 | ); 41 | yield put(loadPretrainedModelSucceeded()); 42 | return pretrainedHandwritingDigitsClassifier; 43 | } 44 | 45 | function* predict(handwritingDigitsClassifier, image) { 46 | const dataGrayscale = convertToGrayscale(image); 47 | const dataTensor = tf.tensor(dataGrayscale, [1, 28, 28, 1]); 48 | const prediction = handwritingDigitsClassifier.predict(dataTensor); 49 | yield put(predictSucceeded(prediction)); 50 | } 51 | 52 | function* request(pretrainedHandwritingDigitsClassifier) { 53 | let handwritingDigitsClassifier = pretrainedHandwritingDigitsClassifier; 54 | 55 | while (true) { 56 | const action = yield take([ 57 | MnistAction.PREDICT_REQUESTED, 58 | MnistAction.LOAD_AND_TRAIN_MNIST_REQUESTED, 59 | ]); 60 | 61 | switch (action.type) { 62 | case MnistAction.PREDICT_REQUESTED: 63 | yield call(predict, handwritingDigitsClassifier, action.image); 64 | break; 65 | case MnistAction.LOAD_AND_TRAIN_MNIST_REQUESTED: 66 | yield put(resetChart()); 67 | yield put(resetDrawing()); 68 | 69 | if (handwritingDigitsClassifier.data) { 70 | handwritingDigitsClassifier = yield call(trainMnist, handwritingDigitsClassifier.data); 71 | } else { 72 | handwritingDigitsClassifier = yield call(loadAndTrainMnist); 73 | } 74 | 75 | yield put(loadAndTrainMnistSucceeded()); 76 | break; 77 | default: 78 | break; 79 | } 80 | } 81 | } 82 | 83 | function* watchMnist() { 84 | const pretrainedHandwritingDigitsClassifier = yield call(loadPretrainedModel); 85 | yield call(request, pretrainedHandwritingDigitsClassifier); 86 
| } 87 | 88 | export default watchMnist; 89 | -------------------------------------------------------------------------------- /src/sagas/pipeline.js: -------------------------------------------------------------------------------- 1 | import { call, put, takeLatest } from 'redux-saga/effects'; 2 | import { requestPredict } from '../actions/mnist'; 3 | import { 4 | displayBoundingBox, 5 | displayCenteredBox, 6 | displayCroppedBox, 7 | displayNormalizedBox, 8 | PipelineAction, 9 | } from '../actions/pipeline'; 10 | import { computeBoundingRect } from '../utils/image-processing'; 11 | 12 | function* boundingBoxTask(action) { 13 | const canvas = document.createElement('canvas'); 14 | canvas.width = 280; 15 | canvas.height = 280; 16 | const ctx = canvas.getContext('2d'); 17 | ctx.putImageData(action.image, 0, 0); 18 | 19 | const rect = computeBoundingRect(action.image); 20 | 21 | yield put(displayBoundingBox(canvas.toDataURL(), rect)); 22 | 23 | return { 24 | canvas, 25 | rect, 26 | }; 27 | } 28 | 29 | function* croppedBoxTask(canvas, rect) { 30 | const croppedCanvas = document.createElement('canvas'); 31 | croppedCanvas.width = 200; 32 | croppedCanvas.height = 200; 33 | const croppedCtx = croppedCanvas.getContext('2d'); 34 | 35 | const rectWidth = rect.computeWidth(); 36 | const rectHeight = rect.computeHeight(); 37 | const scalingFactor = 200 / Math.max(rectWidth, rectHeight); 38 | const croppedRectSize = { 39 | width: rectWidth * scalingFactor, 40 | height: rectHeight * scalingFactor, 41 | }; 42 | 43 | croppedCtx.drawImage( 44 | canvas, 45 | rect.xmin, 46 | rect.ymin, 47 | rectWidth, 48 | rectHeight, 49 | 0, 50 | 0, 51 | croppedRectSize.width, 52 | croppedRectSize.height, 53 | ); 54 | 55 | yield put(displayCroppedBox(croppedCanvas.toDataURL())); 56 | 57 | return { 58 | croppedCanvas, 59 | croppedRectSize, 60 | }; 61 | } 62 | 63 | function* centeredBoxTask(croppedCanvas, croppedRectSize) { 64 | const centeredCanvas = document.createElement('canvas'); 65 | 
centeredCanvas.width = 280; 66 | centeredCanvas.height = 280; 67 | const centeredCtx = centeredCanvas.getContext('2d'); 68 | 69 | centeredCtx.drawImage( 70 | croppedCanvas, 71 | centeredCanvas.width / 2 - croppedRectSize.width / 2, 72 | centeredCanvas.height / 2 - croppedRectSize.height / 2, 73 | ); 74 | 75 | yield put(displayCenteredBox(centeredCanvas.toDataURL())); 76 | 77 | return centeredCanvas; 78 | } 79 | 80 | function* normalizedTask(centeredCanvas) { 81 | const normalizedCanvas = document.createElement('canvas'); 82 | normalizedCanvas.width = 28; 83 | normalizedCanvas.height = 28; 84 | const normalizedCtx = normalizedCanvas.getContext('2d'); 85 | 86 | normalizedCtx.drawImage( 87 | centeredCanvas, 88 | 0, 89 | 0, 90 | centeredCanvas.width, 91 | centeredCanvas.height, 92 | 0, 93 | 0, 94 | normalizedCanvas.width, 95 | normalizedCanvas.height, 96 | ); 97 | 98 | yield put(displayNormalizedBox(normalizedCanvas.toDataURL())); 99 | 100 | return normalizedCanvas; 101 | } 102 | 103 | function* requestPredictTask(normalizedCanvas) { 104 | const normalizedCtx = normalizedCanvas.getContext('2d'); 105 | yield put( 106 | requestPredict( 107 | normalizedCtx.getImageData(0, 0, normalizedCanvas.width, normalizedCanvas.height), 108 | ), 109 | ); 110 | } 111 | 112 | function* pipelineImgProcessing(action) { 113 | const { canvas, rect } = yield call(boundingBoxTask, action); 114 | const { croppedCanvas, croppedRectSize } = yield call(croppedBoxTask, canvas, rect); 115 | const centeredCanvas = yield call(centeredBoxTask, croppedCanvas, croppedRectSize); 116 | const normalizedCanvas = yield call(normalizedTask, centeredCanvas); 117 | yield call(requestPredictTask, normalizedCanvas); 118 | } 119 | 120 | function* watchPipeline() { 121 | yield takeLatest(PipelineAction.INPUT, pipelineImgProcessing); 122 | } 123 | 124 | export default watchPipeline; 125 | -------------------------------------------------------------------------------- /src/store/configureStore.js: 
-------------------------------------------------------------------------------- 1 | import { applyMiddleware, createStore } from 'redux'; 2 | import { composeWithDevTools } from 'redux-devtools-extension/logOnlyInProduction'; 3 | import createSagaMiddleware from 'redux-saga'; 4 | import reducer from '../reducers'; 5 | import rootSaga from '../sagas'; 6 | 7 | export default function configureStore(preloadedState) { 8 | const composeEnhancers = composeWithDevTools({}); 9 | 10 | const sagaMiddleware = createSagaMiddleware(); 11 | const middleware = [sagaMiddleware]; 12 | 13 | const store = createStore( 14 | reducer, 15 | preloadedState, 16 | composeEnhancers(applyMiddleware(...middleware)), 17 | ); 18 | sagaMiddleware.run(rootSaga); 19 | 20 | if (module.hot) { 21 | module.hot.accept('../reducers', () => store.replaceReducer(require('../reducers').default)); // eslint-disable-line global-require 22 | 23 | module.hot.decline('../sagas'); 24 | } 25 | 26 | return store; 27 | } 28 | -------------------------------------------------------------------------------- /src/utils/classifier.js: -------------------------------------------------------------------------------- 1 | import { MnistAction } from '../actions/mnist'; 2 | 3 | export default function isLoadingClassifier(state) { 4 | return ( 5 | state.mnist.retrainStatus !== MnistAction.INIT 6 | && state.mnist.retrainStatus !== MnistAction.LOAD_AND_TRAIN_MNIST_SUCCEEDED 7 | ); 8 | } 9 | -------------------------------------------------------------------------------- /src/utils/image-processing.js: -------------------------------------------------------------------------------- 1 | import Rect from './rect'; 2 | 3 | export function computeBoundingRect(image) { 4 | const rect = new Rect(image); 5 | 6 | for (let i = 0; i < image.width * image.height; i += 1) { 7 | const j = i * 4; 8 | 9 | if (image.data[j + 0] > 0 || image.data[j + 1] > 0 || image.data[j + 2] > 0) { 10 | const x = i % image.width; 11 | const y = Math.floor(i 
/ image.width); 12 | 13 | rect.xmin = Math.min(x, rect.xmin); 14 | rect.xmax = Math.max(x, rect.xmax); 15 | rect.ymin = Math.min(y, rect.ymin); 16 | rect.ymax = Math.max(y, rect.ymax); 17 | } 18 | } 19 | 20 | return rect; 21 | } 22 | 23 | export function convertToGrayscale(image) { 24 | const dataGrayscale = []; 25 | const { data } = image; 26 | 27 | for (let i = 0; i < image.width * image.height; i += 1) { 28 | const j = i * 4; 29 | const avg = (data[j + 0] + data[j + 1] + data[j + 2]) / 3; 30 | const normalized = avg / 255.0; 31 | dataGrayscale.push(normalized); 32 | } 33 | 34 | return dataGrayscale; 35 | } 36 | -------------------------------------------------------------------------------- /src/utils/rect.js: -------------------------------------------------------------------------------- 1 | export default class Rect { 2 | constructor(image) { 3 | this.xmin = image.width; 4 | this.xmax = -1; 5 | this.ymin = image.height; 6 | this.ymax = -1; 7 | } 8 | 9 | computeWidth() { 10 | return this.xmax - this.xmin + 1; 11 | } 12 | 13 | computeHeight() { 14 | return this.ymax - this.ymin + 1; 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | <%= htmlWebpackPlugin.options.title %> 9 | 10 | 11 | 12 | 19 | 20 | 21 | 22 |
Loading application...
23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /webpack.config.js: -------------------------------------------------------------------------------- 1 | module.exports = (env) => { 2 | if (env.WEBPACK_CONFIG === 'dev') { 3 | return require('./webpack/webpack.dev.js'); 4 | } 5 | if (env.WEBPACK_CONFIG === 'prod') { 6 | return require('./webpack/webpack.prod.js'); 7 | } 8 | if (env.WEBPACK_CONFIG === 'analyze') { 9 | return require('./webpack/webpack.analyze.js'); 10 | } 11 | }; 12 | -------------------------------------------------------------------------------- /webpack/utils.js: -------------------------------------------------------------------------------- 1 | const path = require('path'); 2 | 3 | const ROOT_DIR = path.resolve(__dirname, '../'); 4 | const PUBLIC_DIR = path.resolve(ROOT_DIR, 'public'); 5 | const DIST_DIR = path.resolve(ROOT_DIR, 'dist'); 6 | 7 | function generateHtmlWebpackPluginConfig() { 8 | return { 9 | title: 'Live digit recognition on phone and desktop', 10 | template: path.resolve(ROOT_DIR, 'templates/index.html'), 11 | }; 12 | } 13 | 14 | module.exports = { 15 | ROOT_DIR, 16 | PUBLIC_DIR, 17 | DIST_DIR, 18 | generateHtmlWebpackPluginConfig, 19 | }; 20 | -------------------------------------------------------------------------------- /webpack/webpack.analyze.js: -------------------------------------------------------------------------------- 1 | const BundleAnalyzerPlugin = require('webpack-bundle-analyzer').BundleAnalyzerPlugin; 2 | const merge = require('webpack-merge'); 3 | const Prod = require('./webpack.prod.js'); 4 | 5 | module.exports = merge.smartStrategy({ 6 | 'module.rules.use': 'prepend', 7 | })(Prod, { 8 | plugins: [ 9 | new BundleAnalyzerPlugin({ 10 | openAnalyzer: false, 11 | analyzerMode: 'static', 12 | }), 13 | ], 14 | }); 15 | -------------------------------------------------------------------------------- /webpack/webpack.common.js: 
-------------------------------------------------------------------------------- 1 | const CleanWebpackPlugin = require('clean-webpack-plugin'); 2 | const HtmlWebpackPlugin = require('html-webpack-plugin'); 3 | const path = require('path'); 4 | const Utils = require('./utils'); 5 | 6 | const config = { 7 | entry: path.resolve(Utils.ROOT_DIR, 'src/main.jsx'), 8 | output: { 9 | path: Utils.DIST_DIR, 10 | filename: 'scripts/[name].js', 11 | publicPath: '/', 12 | }, 13 | resolve: { 14 | extensions: ['.js', '.jsx'], 15 | }, 16 | module: { 17 | rules: [ 18 | { 19 | enforce: 'pre', 20 | test: /\.(js|jsx)$/, 21 | include: /src/, 22 | use: [ 23 | { 24 | loader: 'eslint-loader', 25 | options: { 26 | cache: true, 27 | emitWarning: true, 28 | }, 29 | }, 30 | ], 31 | }, 32 | { 33 | test: /\.(js|jsx)$/, 34 | include: /src/, 35 | use: [ 36 | { 37 | loader: 'babel-loader', 38 | options: { 39 | cacheDirectory: true, 40 | }, 41 | }, 42 | ], 43 | }, 44 | { 45 | test: /\.css$/, 46 | use: [ 47 | // Environment config will prepend loader here 48 | { 49 | loader: 'css-loader', 50 | }, 51 | ], 52 | }, 53 | ], 54 | }, 55 | plugins: [ 56 | new CleanWebpackPlugin([Utils.DIST_DIR], { root: Utils.ROOT_DIR }), 57 | new HtmlWebpackPlugin({ 58 | ...Utils.generateHtmlWebpackPluginConfig(), 59 | }), 60 | ], 61 | optimization: { 62 | runtimeChunk: 'single', 63 | splitChunks: { 64 | cacheGroups: { 65 | vendor: { 66 | test: /[\\/]node_modules[\\/]/, 67 | name: 'vendors', 68 | chunks: 'all', 69 | }, 70 | }, 71 | }, 72 | }, 73 | }; 74 | 75 | module.exports = config; 76 | -------------------------------------------------------------------------------- /webpack/webpack.dev.js: -------------------------------------------------------------------------------- 1 | const BundleAnalyzerPlugin = require('webpack-bundle-analyzer').BundleAnalyzerPlugin; 2 | const WebpackCdnPlugin = require('webpack-cdn-plugin'); 3 | const merge = require('webpack-merge'); 4 | const path = require('path'); 5 | const Common = 
require('./webpack.common.js'); 6 | const Utils = require('./utils'); 7 | 8 | module.exports = merge.smartStrategy({ 9 | 'module.rules.use': 'prepend', 10 | })(Common, { 11 | mode: 'development', 12 | devtool: 'source-map', 13 | devServer: { 14 | contentBase: [Utils.PUBLIC_DIR, path.resolve(Utils.ROOT_DIR, 'node_modules')], 15 | overlay: { 16 | warnings: false, 17 | errors: true, 18 | }, 19 | port: 9000, 20 | }, 21 | module: { 22 | rules: [ 23 | { 24 | test: /\.css$/, 25 | use: [ 26 | { 27 | loader: 'style-loader', 28 | options: { 29 | sourceMap: true, 30 | }, 31 | }, 32 | { 33 | loader: 'css-loader', 34 | options: { 35 | sourceMap: true, 36 | }, 37 | }, 38 | ], 39 | }, 40 | ], 41 | }, 42 | plugins: [ 43 | new WebpackCdnPlugin({ 44 | modules: [ 45 | { name: 'react', var: 'React', path: 'umd/react.development.js' }, 46 | { name: 'react-dom', var: 'ReactDOM', path: 'umd/react-dom.development.js' }, 47 | { name: 'react-redux', var: 'ReactRedux', path: 'dist/react-redux.js' }, 48 | { name: 'redux-saga', var: 'ReduxSaga', path: 'dist/redux-saga.js' }, 49 | { name: '@tensorflow/tfjs', var: 'tf', path: 'dist/tf.js' }, 50 | { name: 'bizcharts', var: 'BizCharts', path: 'umd/BizCharts.js' }, 51 | ], 52 | prod: false, 53 | }), 54 | new BundleAnalyzerPlugin({ 55 | openAnalyzer: false, 56 | }), 57 | ], 58 | }); 59 | -------------------------------------------------------------------------------- /webpack/webpack.prod.js: -------------------------------------------------------------------------------- 1 | const AsyncStylesheetWebpackPlugin = require('async-stylesheet-webpack-plugin'); 2 | const CopyWebpackPlugin = require('copy-webpack-plugin'); 3 | const HtmlWebpackPlugin = require('html-webpack-plugin'); 4 | const InlineSourcePlugin = require('html-webpack-inline-source-plugin'); 5 | const MiniCssExtractPlugin = require('mini-css-extract-plugin'); 6 | const OptimizeCSSAssetsPlugin = require('optimize-css-assets-webpack-plugin'); 7 | const UglifyJsPlugin = 
require('uglifyjs-webpack-plugin'); 8 | const WebpackCdnPlugin = require('webpack-cdn-plugin'); 9 | const merge = require('webpack-merge'); 10 | const webpack = require('webpack'); 11 | const Common = require('./webpack.common.js'); 12 | const Utils = require('./utils'); 13 | 14 | module.exports = merge.smartStrategy({ 15 | 'module.rules.use': 'prepend', 16 | })(Common, { 17 | mode: 'production', 18 | devtool: false, 19 | output: { 20 | filename: 'assets/[name].[contenthash].js', 21 | }, 22 | module: { 23 | rules: [ 24 | { 25 | test: /\.(js|jsx)$/, 26 | include: /src/, 27 | use: [ 28 | { 29 | loader: 'babel-loader', 30 | options: { 31 | forceEnv: 'production', 32 | }, 33 | }, 34 | ], 35 | }, 36 | { 37 | test: /\.css$/, 38 | use: [ 39 | { 40 | loader: MiniCssExtractPlugin.loader, 41 | }, 42 | ], 43 | }, 44 | ], 45 | }, 46 | plugins: [ 47 | new webpack.HashedModuleIdsPlugin(), 48 | new CopyWebpackPlugin([{ from: Utils.PUBLIC_DIR, to: Utils.DIST_DIR }]), 49 | new MiniCssExtractPlugin({ 50 | filename: 'styles/[name].[contenthash].css', 51 | }), 52 | new HtmlWebpackPlugin({ 53 | ...Utils.generateHtmlWebpackPluginConfig(), 54 | inlineSource: 'runtime.+\\.js', 55 | }), 56 | new InlineSourcePlugin(), 57 | new AsyncStylesheetWebpackPlugin({ 58 | preloadPolyfill: true, 59 | }), 60 | new WebpackCdnPlugin({ 61 | modules: [ 62 | { name: 'react', var: 'React', path: 'umd/react.production.min.js' }, 63 | { name: 'react-dom', var: 'ReactDOM', path: 'umd/react-dom.production.min.js' }, 64 | { name: 'styled-components', var: 'styled', path: 'dist/styled-components.min.js' }, 65 | { name: 'redux', var: 'Redux', path: 'dist/redux.min.js' }, 66 | { name: 'react-redux', var: 'ReactRedux', path: 'dist/react-redux.min.js' }, 67 | { name: 'redux-saga', var: 'ReduxSaga', path: 'dist/redux-saga.min.js' }, 68 | { name: '@tensorflow/tfjs', var: 'tf', path: 'dist/tf.min.js' }, 69 | { name: 'bizcharts', var: 'BizCharts', path: 'umd/BizCharts.min.js' }, 70 | ], 71 | prod: true, 72 | }), 73 | ], 
74 | optimization: { 75 | noEmitOnErrors: true, 76 | minimizer: [ 77 | new UglifyJsPlugin({ 78 | cache: true, 79 | parallel: true, 80 | uglifyOptions: { 81 | output: { 82 | comments: false, 83 | }, 84 | }, 85 | }), 86 | new OptimizeCSSAssetsPlugin({}), 87 | ], 88 | }, 89 | performance: { 90 | maxEntrypointSize: 350000, 91 | }, 92 | }); 93 | --------------------------------------------------------------------------------