├── .babelrc
├── .eslintignore
├── .eslintrc
├── .gitignore
├── .prettierignore
├── .vscode
├── launch.json
└── settings.json
├── README.md
├── netlify.toml
├── package.json
├── public
├── classifiers
│ ├── group1-shard1of1
│ └── model.json
└── images
│ ├── desktop.gif
│ └── phone.gif
├── src
├── actions
│ ├── chart.js
│ ├── drawing.js
│ ├── mnist.js
│ └── pipeline.js
├── classifiers
│ └── handwriting-digits-classifier.js
├── components
│ ├── app.jsx
│ ├── bounding-box.jsx
│ ├── button-space.jsx
│ ├── center-col.jsx
│ ├── chart-section.jsx
│ ├── confirm-retrain-button.jsx
│ ├── display.jsx
│ ├── drawing.jsx
│ ├── image-pipeline.jsx
│ ├── line-chart.jsx
│ ├── mnist-command.jsx
│ ├── mnist-content.jsx
│ ├── mnist-drawing.jsx
│ ├── mnist-footer.jsx
│ ├── mnist-header.jsx
│ ├── mnist.jsx
│ └── output.jsx
├── containers
│ ├── accuracy-chart.jsx
│ ├── bounding-box.jsx
│ ├── centered-box.jsx
│ ├── clear-button.jsx
│ ├── confirm-retrain-button.jsx
│ ├── cropped-box.jsx
│ ├── drawing.jsx
│ ├── loss-chart.jsx
│ ├── mnist-drawing.jsx
│ ├── normalized-box.jsx
│ ├── output.jsx
│ └── retrain-button.jsx
├── data
│ └── mnist-data.js
├── main.jsx
├── reducers
│ ├── chart.js
│ ├── drawing.js
│ ├── index.js
│ ├── mnist.js
│ └── pipeline.js
├── sagas
│ ├── index.js
│ ├── mnist.js
│ └── pipeline.js
├── store
│ └── configureStore.js
└── utils
│ ├── classifier.js
│ ├── image-processing.js
│ └── rect.js
├── templates
└── index.html
├── webpack.config.js
├── webpack
├── utils.js
├── webpack.analyze.js
├── webpack.common.js
├── webpack.dev.js
└── webpack.prod.js
└── yarn.lock
/.babelrc:
--------------------------------------------------------------------------------
1 | {
2 | "presets": [
3 | "react"
4 | ],
5 | "plugins": [
6 | "transform-class-properties",
7 | "transform-object-rest-spread",
8 | "babel-plugin-styled-components",
9 | [
10 | "import",
11 | {
12 | "libraryName": "antd",
13 | "libraryDirectory": "es",
14 | "style": "css"
15 | }
16 | ]
17 | ],
18 | "env": {
19 | "development": {
20 | "plugins": [
21 | "react-hot-loader/babel"
22 | ]
23 | },
24 | "production": {
25 | "presets": ["react-optimize"]
26 | }
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/.eslintignore:
--------------------------------------------------------------------------------
1 | webpack.*.js
2 | node_modules
3 |
--------------------------------------------------------------------------------
/.eslintrc:
--------------------------------------------------------------------------------
1 | {
2 | "parser": "babel-eslint",
3 | "extends": "airbnb",
4 | "env": {
5 | "browser": true
6 | },
7 | "plugins": [
8 | "react",
9 | "jsx-a11y",
10 | "import"
11 | ]
12 | }
13 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | node_modules/
2 | dist/
3 | stats.json
4 |
--------------------------------------------------------------------------------
/.prettierignore:
--------------------------------------------------------------------------------
1 | .babelrc
2 | .eslintrc
3 | node_modules
4 |
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 | {
8 | "type": "chrome",
9 | "request": "launch",
10 | "name": "Launch Chrome against localhost",
11 | "url": "http://localhost:9000",
12 | "webRoot": "${workspaceFolder}"
13 | }
14 | ]
15 | }
16 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "editor.formatOnSave": true,
3 | "javascript.format.enable": false,
4 | "editor.tabSize": 2,
5 | "editor.detectIndentation": false,
6 | "prettier.eslintIntegration": true,
7 | "editor.codeActionsOnSave": {
8 | "source.organizeImports": true
9 | }
10 | }
11 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Handwritten digit recognition
2 |
3 | Digit recognition built with TensorFlow.js, the MNIST dataset, React, Redux, Redux-Saga, Babel, Webpack, styled-components, ESLint, Prettier and Ant Design.
4 |
5 | A demo is available at this location: [https://digit-recognition.ixartz.com](https://digit-recognition.ixartz.com).
6 |
7 | ### Videos
8 |
9 | Phone (iOS and Android) version:
10 |
11 | 
12 |
13 | Desktop version:
14 |
15 | 
16 |
17 | ## Setup environment
18 |
19 | This project runs in a JavaScript (Node.js) environment. Install the dependencies with Yarn or npm:
20 |
21 | $ yarn install
22 |
23 | ## Launch locally
24 |
25 | $ yarn start
26 | $ Open https://localhost:9000 with your favorite browser
27 |
28 | ## Build for production
29 |
30 | $ yarn build
31 |
32 | ## Author
33 |
34 | [Ixartz's technical blog](https://blog.ixartz.com/)
35 |
--------------------------------------------------------------------------------
/netlify.toml:
--------------------------------------------------------------------------------
1 | [build]
2 | publish = 'dist'
3 | command = 'yarn build'
4 |
5 | [[redirects]]
6 | from = "http://digit-recognition.netlify.com/*"
7 | to = "https://digit-recognition.ixartz.com/:splat"
8 | status = 301
9 | force = true
10 |
11 | [[redirects]]
12 | from = "https://digit-recognition.netlify.com/*"
13 | to = "https://digit-recognition.ixartz.com/:splat"
14 | status = 301
15 | force = true
16 |
--------------------------------------------------------------------------------
/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "tensorflow.js",
3 | "version": "1.0.0",
4 | "description": "",
5 | "main": "index.js",
6 | "scripts": {
7 | "start": "webpack-dev-server --hot --env.WEBPACK_CONFIG=dev --progress",
8 | "build": "webpack --env.WEBPACK_CONFIG=prod --progress",
9 | "http-server": "http-server -p 9000 ./dist",
10 | "serve": "yarn build && yarn http-server",
11 | "serve-analyze": "webpack --env.WEBPACK_CONFIG=analyze --progress && yarn http-server",
12 | "build-stats": "webpack --env.WEBPACK_CONFIG=prod --progress --profile --json > stats.json",
13 | "format": "prettier-eslint src/**/*.jsx src/**/*.js --write",
14 | "test": "eslint --ext .jsx,.js src/**"
15 | },
16 | "husky": {
17 | "hooks": {
18 | "pre-commit": "lint-staged"
19 | }
20 | },
21 | "lint-staged": {
22 | "*.{js,jsx}": [
23 | "prettier-eslint --write",
24 | "git add"
25 | ]
26 | },
27 | "keywords": [],
28 | "author": "",
29 | "license": "ISC",
30 | "dependencies": {
31 | "@tensorflow/tfjs": "^0.12.5",
32 | "antd": "^3.6.6",
33 | "bizcharts": "^3.2.1-beta.2",
34 | "prop-types": "^15.6.2",
35 | "react": "^16.4.2",
36 | "react-dom": "^16.4.2",
37 | "react-hot-loader": "^4.3.4",
38 | "react-redux": "^5.0.7",
39 | "redux": "^4.0.0",
40 | "redux-devtools-extension": "^2.13.5",
41 | "redux-saga": "^0.16.0",
42 | "styled-components": "^3.3.3"
43 | },
44 | "devDependencies": {
45 | "async-stylesheet-webpack-plugin": "^0.4.1",
46 | "babel-core": "^6.26.0",
47 | "babel-eslint": "^8.2.5",
48 | "babel-loader": "^7.1.4",
49 | "babel-plugin-import": "^1.8.0",
50 | "babel-plugin-styled-components": "^1.5.1",
51 | "babel-plugin-transform-class-properties": "^6.24.1",
52 | "babel-plugin-transform-object-rest-spread": "^6.26.0",
53 | "babel-preset-react": "^6.24.1",
54 | "babel-preset-react-optimize": "^1.0.1",
55 | "clean-webpack-plugin": "^0.1.19",
56 | "copy-webpack-plugin": "^4.5.2",
57 | "css-loader": "^1.0.0",
58 | "eslint": "^5.0.1",
59 | "eslint-config-airbnb": "^17.0.0",
60 | "eslint-loader": "^2.1.0",
61 | "eslint-plugin-import": "^2.13.0",
62 | "eslint-plugin-jsx-a11y": "^6.1.0",
63 | "eslint-plugin-react": "^7.10.0",
64 | "html-webpack-inline-source-plugin": "^0.0.10",
65 | "html-webpack-plugin": "^3.2.0",
66 | "http-server": "^0.11.1",
67 | "husky": "^1.0.0-rc.13",
68 | "lint-staged": "^7.2.2",
69 | "mini-css-extract-plugin": "^0.4.1",
70 | "optimize-css-assets-webpack-plugin": "^5.0.0",
71 | "prettier-eslint": "^8.8.2",
72 | "prettier-eslint-cli": "^4.7.1",
73 | "style-loader": "^0.22.1",
74 | "uglifyjs-webpack-plugin": "^1.3.0",
75 | "webpack": "^4.5.0",
76 | "webpack-bundle-analyzer": "^3.3.2",
77 | "webpack-cdn-plugin": "^3.1.4",
78 | "webpack-cli": "^3.1.0",
79 | "webpack-dev-server": "^3.1.3",
80 | "webpack-merge": "^4.1.4"
81 | }
82 | }
83 |
--------------------------------------------------------------------------------
/public/classifiers/group1-shard1of1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ixartz/handwritten-digit-recognition-tensorflowjs/20864a2b544a547b8e8a4a961f12b9f5af89b584/public/classifiers/group1-shard1of1
--------------------------------------------------------------------------------
/public/classifiers/model.json:
--------------------------------------------------------------------------------
1 | {"modelTopology": {"keras_version": "2.1.6", "backend": "tensorflow", "model_config": {"class_name": "Sequential", "config": [{"class_name": "Conv2D", "config": {"name": "conv2d_21", "trainable": true, "batch_input_shape": [null, 28, 28, 1], "dtype": "float32", "filters": 32, "kernel_size": [5, 5], "strides": [1, 1], "padding": "same", "data_format": "channels_last", "dilation_rate": [1, 1], "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "Conv2D", "config": {"name": "conv2d_22", "trainable": true, "filters": 32, "kernel_size": [5, 5], "strides": [1, 1], "padding": "same", "data_format": "channels_last", "dilation_rate": [1, 1], "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "MaxPooling2D", "config": {"name": "max_pooling2d_11", "trainable": true, "pool_size": [2, 2], "padding": "valid", "strides": [2, 2], "data_format": "channels_last"}}, {"class_name": "Dropout", "config": {"name": "dropout_16", "trainable": true, "rate": 0.25, "noise_shape": null, "seed": null}}, {"class_name": "Conv2D", "config": {"name": "conv2d_23", "trainable": true, "filters": 64, "kernel_size": [3, 3], "strides": [1, 1], "padding": "same", "data_format": "channels_last", "dilation_rate": [1, 1], "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": 
{"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "Conv2D", "config": {"name": "conv2d_24", "trainable": true, "filters": 64, "kernel_size": [3, 3], "strides": [1, 1], "padding": "same", "data_format": "channels_last", "dilation_rate": [1, 1], "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "MaxPooling2D", "config": {"name": "max_pooling2d_12", "trainable": true, "pool_size": [2, 2], "padding": "valid", "strides": [2, 2], "data_format": "channels_last"}}, {"class_name": "Dropout", "config": {"name": "dropout_17", "trainable": true, "rate": 0.25, "noise_shape": null, "seed": null}}, {"class_name": "Flatten", "config": {"name": "flatten_6", "trainable": true, "data_format": "channels_last"}}, {"class_name": "Dense", "config": {"name": "dense_11", "trainable": true, "units": 256, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "Dropout", "config": {"name": "dropout_18", "trainable": true, "rate": 0.5, "noise_shape": null, "seed": null}}, {"class_name": "Dense", "config": {"name": "dense_12", "trainable": true, "units": 10, "activation": "softmax", 
"use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}]}, "training_config": {"optimizer_config": {"class_name": "RMSprop", "config": {"lr": 6.25000029685907e-05, "rho": 0.8999999761581421, "decay": 0.0, "epsilon": 1e-08}}, "loss": "categorical_crossentropy", "metrics": ["accuracy"], "sample_weight_mode": null, "loss_weights": null}}, "weightsManifest": [{"paths": ["group1-shard1of1"], "weights": [{"name": "conv2d_21/kernel", "shape": [5, 5, 1, 32], "dtype": "float32"}, {"name": "conv2d_21/bias", "shape": [32], "dtype": "float32"}, {"name": "conv2d_22/kernel", "shape": [5, 5, 32, 32], "dtype": "float32"}, {"name": "conv2d_22/bias", "shape": [32], "dtype": "float32"}, {"name": "conv2d_23/kernel", "shape": [3, 3, 32, 64], "dtype": "float32"}, {"name": "conv2d_23/bias", "shape": [64], "dtype": "float32"}, {"name": "conv2d_24/kernel", "shape": [3, 3, 64, 64], "dtype": "float32"}, {"name": "conv2d_24/bias", "shape": [64], "dtype": "float32"}, {"name": "dense_11/kernel", "shape": [3136, 256], "dtype": "float32"}, {"name": "dense_11/bias", "shape": [256], "dtype": "float32"}, {"name": "dense_12/kernel", "shape": [256, 10], "dtype": "float32"}, {"name": "dense_12/bias", "shape": [10], "dtype": "float32"}]}]}
--------------------------------------------------------------------------------
/public/images/desktop.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ixartz/handwritten-digit-recognition-tensorflowjs/20864a2b544a547b8e8a4a961f12b9f5af89b584/public/images/desktop.gif
--------------------------------------------------------------------------------
/public/images/phone.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ixartz/handwritten-digit-recognition-tensorflowjs/20864a2b544a547b8e8a4a961f12b9f5af89b584/public/images/phone.gif
--------------------------------------------------------------------------------
/src/actions/chart.js:
--------------------------------------------------------------------------------
/** Redux action types for the training loss/accuracy charts. */
export const ChartAction = {
  ADD_LOSS_POINT: 'ADD_LOSS_POINT',
  ADD_ACCURACY_POINT: 'ADD_ACCURACY_POINT',
  RESET_CHART: 'RESET_CHART',
};

/**
 * Builds an action that appends one point to the loss series.
 *
 * @param {number} batch - Index of the training batch the point belongs to.
 * @param {number} loss - Loss value reported for that batch.
 * @returns {{ type: string, pt: { batch: number, loss: number } }}
 */
export function addLossPoint(batch, loss) {
  const pt = { batch, loss };
  return { type: ChartAction.ADD_LOSS_POINT, pt };
}

/**
 * Builds an action that appends one point to the accuracy series.
 *
 * @param {number} batch - Index of the training batch the point belongs to.
 * @param {number} accuracy - Accuracy value reported for that batch.
 * @returns {{ type: string, pt: { batch: number, accuracy: number } }}
 */
export function addAccuracyPoint(batch, accuracy) {
  const pt = { batch, accuracy };
  return { type: ChartAction.ADD_ACCURACY_POINT, pt };
}

/**
 * Builds an action that clears both chart series.
 *
 * @returns {{ type: string }}
 */
export function resetChart() {
  return { type: ChartAction.RESET_CHART };
}
26 |
--------------------------------------------------------------------------------
/src/actions/drawing.js:
--------------------------------------------------------------------------------
/** Redux action types for the freehand drawing canvas. */
export const StrokeAction = {
  ADD_STROKE: 'ADD_STROKE',
  ADD_STROKE_POS: 'ADD_STROKE_POS',
  END_STROKE: 'END_STROKE',
  RESET_DRAWING: 'RESET_DRAWING',
};

/**
 * Starts a new stroke at the given position.
 *
 * @param {{x: number, y: number}} pos - Pointer position; presumably canvas
 *   coordinates as computed by the drawing component — confirm with the caller.
 * @returns {{ type: string, pos: object }}
 */
export function addStroke(pos) {
  return { type: StrokeAction.ADD_STROKE, pos };
}

/**
 * Appends a position to the stroke currently being drawn.
 *
 * @param {{x: number, y: number}} pos - Pointer position (same convention as addStroke).
 * @returns {{ type: string, pos: object }}
 */
export function addStrokePos(pos) {
  return { type: StrokeAction.ADD_STROKE_POS, pos };
}

/**
 * Marks the current stroke as finished.
 *
 * @returns {{ type: string }}
 */
export function endStroke() {
  return { type: StrokeAction.END_STROKE };
}

/**
 * Clears every stroke from the drawing.
 *
 * @returns {{ type: string }}
 */
export function resetDrawing() {
  return { type: StrokeAction.RESET_DRAWING };
}
25 |
--------------------------------------------------------------------------------
/src/actions/mnist.js:
--------------------------------------------------------------------------------
/** Redux action types for model loading, training and prediction. */
export const MnistAction = {
  INIT: 'INIT',
  LOAD_PRETRAINED_MODEL_SUCCEEDED: 'LOAD_PRETRAINED_MODEL_SUCCEEDED',
  LOAD_AND_TRAIN_MNIST_REQUESTED: 'LOAD_AND_TRAIN_MNIST_REQUESTED',
  LOADING_MNIST: 'LOADING_MNIST',
  TRAINING_MNIST: 'TRAINING_MNIST',
  LOAD_AND_TRAIN_MNIST_SUCCEEDED: 'LOAD_AND_TRAIN_MNIST_SUCCEEDED',
  PREDICT_REQUESTED: 'PREDICT_REQUESTED',
  PREDICT_SUCCEEDED: 'PREDICT_SUCCEEDED',
};

/** @returns {{ type: string }} Signals that the pretrained model finished loading. */
export function loadPretrainedModelSucceeded() {
  return { type: MnistAction.LOAD_PRETRAINED_MODEL_SUCCEEDED };
}

/** @returns {{ type: string }} Requests loading the dataset and training a new model. */
export function loadAndTrainMnist() {
  return { type: MnistAction.LOAD_AND_TRAIN_MNIST_REQUESTED };
}

/** @returns {{ type: string }} Signals that the dataset is being loaded. */
export function loadingMnist() {
  return { type: MnistAction.LOADING_MNIST };
}

/** @returns {{ type: string }} Signals that training is in progress. */
export function trainingMnist() {
  return { type: MnistAction.TRAINING_MNIST };
}

/** @returns {{ type: string }} Signals that loading and training completed. */
export function loadAndTrainMnistSucceeded() {
  return { type: MnistAction.LOAD_AND_TRAIN_MNIST_SUCCEEDED };
}

/**
 * Requests a prediction for an image.
 *
 * @param {*} image - Image payload passed through unchanged; its exact type is
 *   defined by the saga that handles PREDICT_REQUESTED — confirm there.
 * @returns {{ type: string, image: * }}
 */
export function requestPredict(image) {
  return { type: MnistAction.PREDICT_REQUESTED, image };
}

/**
 * Publishes a prediction result.
 *
 * @param {*} prediction - Prediction value passed through unchanged.
 * @returns {{ type: string, prediction: * }}
 */
export function predictSucceeded(prediction) {
  return { type: MnistAction.PREDICT_SUCCEEDED, prediction };
}
41 |
--------------------------------------------------------------------------------
/src/actions/pipeline.js:
--------------------------------------------------------------------------------
/** Redux action types for the image-preprocessing pipeline display. */
export const PipelineAction = {
  INPUT: 'INPUT',
  DISPLAY_BOUNDING_BOX: 'DISPLAY_BOUNDING_BOX',
  DISPLAY_CROPPED_BOX: 'DISPLAY_CROPPED_BOX',
  DISPLAY_CENTERED_BOX: 'DISPLAY_CENTERED_BOX',
  DISPLAY_NORMALIZED_BOX: 'DISPLAY_NORMALIZED_BOX',
};

/**
 * Feeds a new raw image into the pipeline.
 *
 * @param {*} image - Raw input (the drawing component passes canvas ImageData).
 * @returns {{ type: string, image: * }}
 */
export function addInput(image) {
  return { type: PipelineAction.INPUT, image };
}

/**
 * Shows the input image together with its bounding rectangle.
 *
 * @param {string} imageUrl - URL of the image to display.
 * @param {*} rect - Bounding rectangle to outline.
 * @returns {{ type: string, imageUrl: string, rect: * }}
 */
export function displayBoundingBox(imageUrl, rect) {
  return { type: PipelineAction.DISPLAY_BOUNDING_BOX, imageUrl, rect };
}

/**
 * @param {string} croppedUrl - URL of the cropped image.
 * @returns {{ type: string, croppedUrl: string }}
 */
export function displayCroppedBox(croppedUrl) {
  return { type: PipelineAction.DISPLAY_CROPPED_BOX, croppedUrl };
}

/**
 * @param {string} centeredUrl - URL of the centered image.
 * @returns {{ type: string, centeredUrl: string }}
 */
export function displayCenteredBox(centeredUrl) {
  return { type: PipelineAction.DISPLAY_CENTERED_BOX, centeredUrl };
}

/**
 * @param {string} normalizedUrl - URL of the normalized image.
 * @returns {{ type: string, normalizedUrl: string }}
 */
export function displayNormalizedBox(normalizedUrl) {
  return { type: PipelineAction.DISPLAY_NORMALIZED_BOX, normalizedUrl };
}
34 |
--------------------------------------------------------------------------------
/src/classifiers/handwriting-digits-classifier.js:
--------------------------------------------------------------------------------
1 | import * as tf from '@tensorflow/tfjs';
2 | import { put } from 'redux-saga/effects';
3 | import { addAccuracyPoint, addLossPoint } from '../actions/chart';
4 |
/**
 * Convolutional digit classifier built on TensorFlow.js.
 *
 * Either build and train a small CNN from scratch (`initializeModel` + `train`) or
 * load the pretrained model served from the public folder (`loadModel`).
 * `loadModel` and `train` are generator methods intended to be driven by
 * redux-saga: they yield promises and `put` effects.
 */
export default class HandwritingDigitsClassifier {
  // Number of mini-batches processed by `train()`.
  static TRAIN_BATCHES = 150;

  // Examples per training mini-batch.
  static BATCH_SIZE = 64;

  // A validation pass runs on every batch index divisible by this (including 0).
  static TEST_ITERATION_FREQUENCY = 5;

  // Examples in each validation batch.
  static TEST_BATCH_SIZE = 1000;

  // Learning rate of the SGD optimizer created in `compile()`.
  static LEARNING_RATE = 0.15;

  // Folder, relative to the site root, holding the pretrained model files.
  static CLASSIFIER_FOLDER = 'classifiers';

  // Base name of the pretrained model JSON and of models saved by `train(true)`.
  static CLASSIFIER_NAME = 'model';

  /**
   * Builds (but does not train) a fresh CNN and compiles it.
   *
   * Layers: conv 5x5 (8 filters) -> max-pool 2x2 -> conv 5x5 (16 filters) ->
   * max-pool 2x2 -> flatten -> dense(10, softmax), on 28x28x1 inputs.
   *
   * @param {object} data - Batch source used by `train()`. It must provide
   *   `nextTrainBatch(size)` and should provide `nextTestBatch(size)`; each returns
   *   `{ xs, labels }` where `xs` can be reshaped to `[size, 28, 28, 1]`.
   */
  initializeModel(data) {
    this.data = data;

    this.model = tf.sequential();

    this.model.add(
      tf.layers.conv2d({
        inputShape: [28, 28, 1],
        kernelSize: 5,
        filters: 8,
        strides: 1,
        activation: 'relu',
        kernelInitializer: 'varianceScaling',
      }),
    );

    this.model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));

    this.model.add(
      tf.layers.conv2d({
        kernelSize: 5,
        filters: 16,
        strides: 1,
        activation: 'relu',
        kernelInitializer: 'varianceScaling',
      }),
    );

    this.model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));

    this.model.add(tf.layers.flatten());

    this.model.add(
      tf.layers.dense({ units: 10, kernelInitializer: 'varianceScaling', activation: 'softmax' }),
    );

    this.compile();
  }

  /** Compiles `this.model` with SGD, categorical cross-entropy and an accuracy metric. */
  compile() {
    const optimizer = tf.train.sgd(HandwritingDigitsClassifier.LEARNING_RATE);

    this.model.compile({
      optimizer,
      loss: 'categoricalCrossentropy',
      metrics: ['accuracy'],
    });
  }

  /**
   * Loads the pretrained model from `<protocol>//<host>/classifiers/model.json`,
   * then compiles it with the same settings as a freshly built model.
   * Generator: yields the `tf.loadModel` promise.
   */
  * loadModel() {
    const { host, protocol } = window.location;
    const { CLASSIFIER_FOLDER, CLASSIFIER_NAME } = HandwritingDigitsClassifier;

    this.model = yield tf.loadModel(
      `${protocol}//${host}/${CLASSIFIER_FOLDER}/${CLASSIFIER_NAME}.json`,
    );
    this.compile();
  }

  /**
   * Fetches the training batch for iteration `i` and, every
   * TEST_ITERATION_FREQUENCY iterations, a validation batch.
   *
   * @param {number} i - Zero-based training iteration.
   * @returns {Array} `[batch, validationData]`. `batch.xs` is reshaped to
   *   `[BATCH_SIZE, 28, 28, 1]`. `validationData` is `[xs, labels]` or `undefined`
   *   when no validation runs this iteration.
   */
  getTrainBatch(i) {
    const { BATCH_SIZE, TEST_BATCH_SIZE, TEST_ITERATION_FREQUENCY } = HandwritingDigitsClassifier;

    const batch = this.data.nextTrainBatch(BATCH_SIZE);
    batch.xs = batch.xs.reshape([BATCH_SIZE, 28, 28, 1]);

    let validationData;

    if (i % TEST_ITERATION_FREQUENCY === 0) {
      // Validate on held-out examples: validating on training examples overstates
      // accuracy. Fall back to training batches only if no test split is exposed.
      const testBatch = typeof this.data.nextTestBatch === 'function'
        ? this.data.nextTestBatch(TEST_BATCH_SIZE)
        : this.data.nextTrainBatch(TEST_BATCH_SIZE);

      validationData = [testBatch.xs.reshape([TEST_BATCH_SIZE, 28, 28, 1]), testBatch.labels];
    }

    return [batch, validationData];
  }

  /**
   * Trains for TRAIN_BATCHES single-epoch mini-batches.
   *
   * It dispatches a loss point for every batch and a validation-accuracy point
   * whenever a validation pass ran. Batch tensors are disposed even if `fit`
   * throws. Generator: yields `fit`/`nextFrame` promises and `put` effects.
   *
   * @param {boolean} [save=false] - When true, download the trained model afterwards.
   */
  * train(save = false) {
    for (let i = 0; i < HandwritingDigitsClassifier.TRAIN_BATCHES; i += 1) {
      // tidy() frees the intermediate (pre-reshape) tensors and keeps the returned ones.
      const [batch, validationData] = tf.tidy(() => this.getTrainBatch(i));

      try {
        const history = yield this.model.fit(batch.xs, batch.labels, {
          batchSize: HandwritingDigitsClassifier.BATCH_SIZE,
          validationData,
          epochs: 1,
        });

        yield put(addLossPoint(i, history.history.loss[0]));

        if (validationData != null) {
          // `acc` is measured on the training batch and `val_acc` on the validation
          // batch; plot the latter, falling back to `acc` if it is absent.
          const { acc, val_acc: valAcc } = history.history;
          yield put(addAccuracyPoint(i, (valAcc || acc)[0]));
        }
      } finally {
        tf.dispose([batch, validationData]);
      }

      // Yield to the browser so the UI stays responsive between batches.
      yield tf.nextFrame();
    }

    if (save) {
      yield this.model.save(`downloads://${HandwritingDigitsClassifier.CLASSIFIER_NAME}`);
    }
  }

  /**
   * Predicts the digit for the first example in `dataTensor`.
   *
   * @param {tf.Tensor} dataTensor - Model input; presumably shaped `[n, 28, 28, 1]`
   *   to match the model — confirm with the caller.
   * @returns {number} Index of the highest-scoring class for the first example.
   */
  predict(dataTensor) {
    return tf.tidy(() => {
      const output = this.model.predict(dataTensor);

      const axis = 1;
      const predictions = Array.from(output.argMax(axis).dataSync());

      return predictions[0];
    });
  }
}
138 |
--------------------------------------------------------------------------------
/src/components/app.jsx:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import { hot } from 'react-hot-loader';
3 | import { Provider } from 'react-redux';
4 | import configureStore from '../store/configureStore';
5 | import Mnist from './mnist';
6 |
// Calls BizCharts' `track(false)` at module load, presumably to disable its usage tracking.
// NOTE(review): `window.BizCharts` is not imported by this module; it is presumably a
// global provided before this runs (e.g. by a CDN script) — confirm, otherwise this
// line throws at import time.
7 | window.BizCharts.track(false);
8 |
// Root component: owns the Redux store and renders the application.
// NOTE(review): the JSX returned by render() is missing from this listing; given the
// imports it presumably renders <Mnist /> inside a react-redux <Provider> bound to
// this.store — confirm against the repository.
9 | const App = class App extends React.Component {
10 | constructor(props) {
11 | super(props);
12 |
// Created once per App instance, so re-renders keep using the same store.
13 | this.store = configureStore();
14 | }
15 |
16 | render() {
17 | return (
18 |
19 |
20 |
21 | );
22 | }
23 | };
24 |
// Wrapped with react-hot-loader's `hot(module)` to enable hot module replacement.
25 | export default hot(module)(App);
26 |
--------------------------------------------------------------------------------
/src/components/bounding-box.jsx:
--------------------------------------------------------------------------------
1 | import PropTypes from 'prop-types';
2 | import React from 'react';
3 | import Rect from '../utils/rect';
4 | import { Display140 } from './display';
5 |
// Draws a pipeline image and outlines `rect` on top of it in red.
// Painting goes through `this.display.ctx`, the 2D context of the child captured by
// setDisplay. NOTE(review): render()'s JSX is missing from this listing; it presumably
// renders <Display140 ref={this.setDisplay} ... /> (imported above) — confirm.
6 | export default class BoundingBox extends React.Component {
7 | static propTypes = {
// URL of the image to paint; nothing is drawn while it is falsy.
8 | imageUrl: PropTypes.string,
// Rectangle to outline; read via xmin, ymin, computeWidth() and computeHeight().
9 | rect: PropTypes.instanceOf(Rect),
10 | };
11 |
12 | static defaultProps = {
13 | imageUrl: null,
14 | rect: null,
15 | };
16 |
// Repaint after mount and after every update.
17 | componentDidMount() {
18 | this.initContext();
19 | }
20 |
21 | componentDidUpdate() {
22 | this.initContext();
23 | }
24 |
// Ref callback storing the child display component (React passes null on unmount).
25 | setDisplay = (elt) => {
26 | this.display = elt;
27 | };
28 |
// Loads imageUrl asynchronously; once loaded, paints it at (0, 0) and strokes `rect`.
// NOTE(review): `rect` is not null-checked, so this assumes callers always pass a rect
// together with imageUrl. The onload callback can also fire after unmount, when
// this.display may be null; consider guarding both.
29 | initContext() {
30 | const { imageUrl, rect } = this.props;
31 |
32 | if (imageUrl) {
33 | const image = new Image();
34 | image.onload = () => {
35 | this.display.ctx.drawImage(image, 0, 0);
36 |
37 | this.display.ctx.strokeStyle = 'red';
38 | this.display.ctx.strokeRect(
39 | rect.xmin,
40 | rect.ymin,
41 | rect.computeWidth(),
42 | rect.computeHeight(),
43 | );
44 | };
45 | image.src = imageUrl;
46 | }
47 | }
48 |
49 | render() {
50 | return ;
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/src/components/button-space.jsx:
--------------------------------------------------------------------------------
1 | import { Button } from 'antd';
2 | import styled from 'styled-components';
3 |
// antd Button with a 2px horizontal margin so adjacent buttons do not touch.
// Kept as a named binding (rather than an inline default export) so the component
// keeps a readable name in dev tools.
4 | const ButtonSpace = styled(Button)`
5 | margin: 0 2px;
6 | `;
7 |
8 | export default ButtonSpace;
9 |
--------------------------------------------------------------------------------
/src/components/center-col.jsx:
--------------------------------------------------------------------------------
1 | import { Col } from 'antd';
2 | import styled from 'styled-components';
3 |
// antd Col that centers its children horizontally and vertically using flexbox.
4 | const CenterCol = styled(Col)`
5 | display: flex;
6 | align-items: center;
7 | justify-content: center;
8 | `;
9 |
10 | export default CenterCol;
11 |
--------------------------------------------------------------------------------
/src/components/chart-section.jsx:
--------------------------------------------------------------------------------
1 | import { Col, Row } from 'antd';
2 | import React from 'react';
3 | import AccuracyChart from '../containers/accuracy-chart';
4 | import LossChart from '../containers/loss-chart';
5 |
// Stateless layout for the training charts.
// NOTE(review): the JSX body is missing from this listing; from the imports it
// presumably places <LossChart /> and <AccuracyChart /> in antd Row/Col cells — confirm.
6 | const ChartSection = () => (
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | );
16 |
17 | export default ChartSection;
18 |
--------------------------------------------------------------------------------
/src/components/confirm-retrain-button.jsx:
--------------------------------------------------------------------------------
1 | import { Popconfirm } from 'antd';
2 | import PropTypes from 'prop-types';
3 | import React from 'react';
4 | import styled from 'styled-components';
5 | import RetrainButton from '../containers/retrain-button';
6 |
// Paragraph styled in red, presumably used for the phone/battery warning below.
7 | const PWarning = styled.p`
8 | color: red;
9 | `;
10 |
// Popover body shown before retraining. The JSX tags are missing from this listing;
// only the text survives.
// NOTE(review): the copy hard-codes the pretrained accuracy (0.99700) and uses
// "unrecommended"; consider "not recommended". Both are user-facing strings.
11 | const content = (
12 |
13 |
14 | The system has already loaded a pretrained model with 0.99700 accuracy. If you want to train a
15 | new model, please click on continue button.
16 |
17 |
18 | It is unrecommended to run the training on your phone. The process will drain your battery.
19 |
20 |
21 | );
22 |
// Retrain button guarded by a confirmation popover; `onConfirm` is presumably invoked
// when the user confirms. NOTE(review): JSX missing from this listing — presumably an
// antd <Popconfirm> (imported above) wrapping <RetrainButton /> — confirm.
23 | const ConfirmRetrainButton = ({ onConfirm }) => (
24 |
25 |
26 |
27 | );
28 |
29 | ConfirmRetrainButton.propTypes = {
30 | onConfirm: PropTypes.func.isRequired,
31 | };
32 |
33 | export default ConfirmRetrainButton;
34 |
--------------------------------------------------------------------------------
/src/components/display.jsx:
--------------------------------------------------------------------------------
1 | import { Row } from 'antd';
2 | import PropTypes from 'prop-types';
3 | import React from 'react';
4 | import styled from 'styled-components';
5 | import CenterCol from './center-col';
6 |
// Layout helpers used by Display.render(): a spaced row, a margin-less heading and a
// block-level canvas.
7 | const DisplayRow = styled(Row)`
8 | margin-bottom: 10px;
9 | `;
10 |
11 | const H3 = styled.h3`
12 | margin: 0;
13 | `;
14 |
15 | const Canvas = styled.canvas`
16 | display: block;
17 | `;
18 |
// Titled canvas that fills itself and then paints `imageUrl` (if any) at (0, 0).
// NOTE(review): render()'s JSX is missing from this listing; it presumably renders
// the title in H3 and a Canvas with ref={this.setCanvasRef}, width/height and
// className — confirm.
19 | export class Display extends React.Component {
20 | static propTypes = {
21 | imageUrl: PropTypes.string,
22 | className: PropTypes.string,
23 | width: PropTypes.number,
24 | height: PropTypes.number,
25 | title: PropTypes.string.isRequired,
26 | };
27 |
28 | static defaultProps = {
29 | imageUrl: null,
30 | className: '',
31 | width: 280,
32 | height: 280,
33 | };
34 |
35 | constructor(props) {
36 | super(props);
// Canvas element (set by the ref callback) and its 2D context (set in initContext).
37 | this.canvas = null;
38 | this.ctx = null;
39 | }
40 |
// Repaint after mount and after every update.
41 | componentDidMount() {
42 | this.initContext();
43 | }
44 |
45 | componentDidUpdate() {
46 | this.initContext();
47 | }
48 |
49 | setCanvasRef = (elt) => {
50 | this.canvas = elt;
51 | };
52 |
// Re-acquires the 2D context, clears the canvas, then starts loading the image.
53 | initContext() {
54 | this.ctx = this.canvas.getContext('2d');
55 |
56 | this.clearCanvas();
57 | this.loadImage();
58 | }
59 |
// Fills the whole canvas with the context's current fillStyle. This class never sets
// fillStyle, so it is presumably the canvas default (black) unless changed elsewhere.
60 | clearCanvas() {
61 | this.ctx.fillRect(0, 0, this.canvas.width, this.canvas.height);
62 | }
63 |
// Draws imageUrl at (0, 0) once it finishes loading (asynchronously).
// NOTE(review): a slow load from an earlier update can finish after a newer one and
// overwrite it; there is no guard against stale images.
64 | loadImage() {
65 | const { imageUrl } = this.props;
66 |
67 | if (imageUrl) {
68 | const image = new Image();
69 | image.onload = () => {
70 | this.ctx.drawImage(image, 0, 0);
71 | };
72 | image.src = imageUrl;
73 | }
74 | }
75 |
76 | render() {
77 | const {
78 | className, width, height, title,
79 | } = this.props;
80 |
81 | return (
82 |
83 |
84 | {title}
85 |
86 |
87 |
93 |
94 |
95 | );
96 | }
97 | }
98 |
// Fixed-CSS-size variants of Display. styled() passes the generated className to
// Display; the CSS size presumably scales the canvas, whose pixel size still comes
// from the width/height props.
99 | export const Display140 = styled(Display)`
100 | width: 140px;
101 | height: 140px;
102 | `;
103 |
104 | export const Display200 = styled(Display)`
105 | width: 200px;
106 | height: 200px;
107 | `;
108 |
109 | export const Display100 = styled(Display)`
110 | width: 100px;
111 | height: 100px;
112 | `;
113 |
114 | export const Display28 = styled(Display)`
115 | width: 28px;
116 | height: 28px;
117 | `;
118 |
--------------------------------------------------------------------------------
/src/components/drawing.jsx:
--------------------------------------------------------------------------------
1 | import PropTypes from 'prop-types';
2 | import React from 'react';
3 | import styled from 'styled-components';
4 |
// 300x300 CSS-sized drawing surface with a blue border. `touch-action: none` stops
// touch gestures (scroll/zoom) from interfering with drawing.
5 | const Canvas = styled.canvas`
6 | width: 300px;
7 | height: 300px;
8 | border-color: dodgerblue;
9 | border-width: 5px;
10 | border-style: solid;
11 | display: block;
12 | touch-action: none;
13 |
14 | &:hover {
15 | border-color: deepskyblue;
16 | }
17 | `;
18 |
// Freehand canvas. Pointer events are turned into stroke actions, and on every update
// all strokes from props are repainted. When a stroke ends, the canvas pixels are
// passed to addInput.
// NOTE(review): render()'s JSX is missing from this listing (lines 137-144 dropped);
// it presumably wires the mouse/touch handlers and setCanvasRef onto Canvas — confirm.
19 | export default class Drawing extends React.Component {
20 | static propTypes = {
21 | isDrawing: PropTypes.bool.isRequired,
22 | isEndStroke: PropTypes.bool.isRequired,
// Each stroke is an array of { x, y } points in canvas pixel coordinates.
23 | strokes: PropTypes.arrayOf(
24 | PropTypes.arrayOf(PropTypes.shape({ x: PropTypes.number, y: PropTypes.number })),
25 | ).isRequired,
26 | addStroke: PropTypes.func.isRequired,
27 | addStrokePos: PropTypes.func.isRequired,
28 | endStroke: PropTypes.func.isRequired,
29 | addInput: PropTypes.func.isRequired,
30 | };
31 |
32 | constructor(props) {
33 | super(props);
34 | this.canvas = null;
35 | this.ctx = null;
36 | }
37 |
38 | componentDidMount() {
39 | this.initContext();
40 | }
41 |
// Clears and repaints every stroke; when a stroke has just ended, sends the pixels on.
// NOTE(review): the 280x280 read region is hard-coded; it presumably matches the
// canvas width/height attributes set in render() — confirm.
42 | componentDidUpdate() {
43 | this.initContext();
44 | this.drawStrokes();
45 |
46 | const { isEndStroke, addInput } = this.props;
47 |
48 | if (isEndStroke) {
49 | addInput(this.ctx.getImageData(0, 0, 280, 280));
50 | }
51 | }
52 |
// Pointer down: start a new stroke at the pointer position.
53 | onMouseDown = (e) => {
54 | const { addStroke } = this.props;
55 | addStroke(this.computeMousePos(e));
56 | };
57 |
// Pointer move: extend the current stroke, but only while drawing.
58 | onMouseMove = (e) => {
59 | const { isDrawing, addStrokePos } = this.props;
60 |
61 | if (!isDrawing) {
62 | return;
63 | }
64 |
65 | addStrokePos(this.computeMousePos(e));
66 | };
67 |
// Pointer up/leave: finish the stroke if one is in progress.
68 | onStrokeEnd = () => {
69 | const { isDrawing, endStroke } = this.props;
70 |
71 | if (isDrawing) {
72 | endStroke();
73 | }
74 | };
75 |
76 | setCanvasRef = (elt) => {
77 | this.canvas = elt;
78 | };
79 |
// Configures round white 10px strokes and clears the canvas.
// NOTE(review): 'dark' is not a valid CSS color. Canvas ignores invalid fillStyle
// values, so the fill stays the default black; use 'black' to make the intent explicit.
80 | initContext() {
81 | this.ctx = this.canvas.getContext('2d');
82 | this.ctx.lineWidth = 10;
83 | this.ctx.lineJoin = 'round';
84 | this.ctx.lineCap = 'round';
85 | this.ctx.strokeStyle = 'white';
86 |
87 | this.ctx.fillStyle = 'dark';
88 |
89 | this.clearCanvas();
90 | }
91 |
92 | clearCanvas() {
93 | this.ctx.fillRect(0, 0, this.canvas.width, this.canvas.height);
94 | }
95 |
// Converts the event's client coordinates into canvas pixel coordinates.
96 | computeMousePos(e) {
97 | return {
98 | x: this.computeMousePosX(e),
99 | y: this.computeMousePosY(e),
100 | };
101 | }
102 |
// Scales by canvas.width / displayed width, because the CSS size differs from the
// canvas's pixel size.
103 | computeMousePosX(e) {
104 | const rect = this.canvas.getBoundingClientRect();
105 | const scaleX = this.canvas.width / rect.width;
106 |
107 | return (e.clientX - rect.left) * scaleX;
108 | }
109 |
110 | computeMousePosY(e) {
111 | const rect = this.canvas.getBoundingClientRect();
112 | const scaleY = this.canvas.height / rect.height;
113 |
114 | return (e.clientY - rect.top) * scaleY;
115 | }
116 |
// Strokes each polyline. Assumes every stroke has at least one point (points[0]).
117 | drawStrokes() {
118 | const { strokes } = this.props;
119 |
120 | for (let j = 0; j < strokes.length; j += 1) {
121 | const points = strokes[j];
122 |
123 | this.ctx.beginPath();
124 | this.ctx.moveTo(points[0].x, points[0].y);
125 |
126 | for (let i = 1; i < points.length; i += 1) {
127 | this.ctx.lineTo(points[i].x, points[i].y);
128 | }
129 | this.ctx.stroke();
130 | }
131 | }
132 |
133 | render() {
134 | return (
135 |
136 |
145 |
146 | );
147 | }
148 | }
149 |
--------------------------------------------------------------------------------
/src/components/image-pipeline.jsx:
--------------------------------------------------------------------------------
1 | import { Row } from 'antd';
2 | import React from 'react';
3 | import styled from 'styled-components';
4 | import BoundingBox from '../containers/bounding-box';
5 | import CenteredBox from '../containers/centered-box';
6 | import CroppedBox from '../containers/cropped-box';
7 | import NormalizedBox from '../containers/normalized-box';
8 | import CenterCol from './center-col';
9 |
// antd Row with a 20px top margin, used to lay out the pipeline previews.
const ImagePipelineRow = styled(Row)`
margin-top: 20px;
`;

// Row of previews for the image-processing stages. Per the imports these are
// BoundingBox, CroppedBox, CenteredBox and NormalizedBox, presumably each in a
// CenterCol.
// NOTE(review): this component's JSX markup is missing from this copy of the
// file; the description above is inferred from the imports. Confirm it.
const ImagePipeline = () => (

);

export default ImagePipeline;
32 |
--------------------------------------------------------------------------------
/src/components/line-chart.jsx:
--------------------------------------------------------------------------------
1 | import {
2 | Axis, Chart, Geom, Tooltip,
3 | } from 'bizcharts';
4 | import PropTypes from 'prop-types';
5 | import React from 'react';
6 |
7 | export default class LineChart extends React.Component {
8 | static propTypes = {
9 | data: PropTypes.oneOfType([PropTypes.arrayOf(PropTypes.object), PropTypes.object]),
10 | cols: PropTypes.oneOfType([PropTypes.object, PropTypes.array]),
11 | };
12 |
13 | static defaultProps = {
14 | data: null,
15 | cols: null,
16 | };
17 |
18 | static getObjectKeyName(obj, i) {
19 | return Object.keys(obj)[i];
20 | }
21 |
22 | render() {
23 | const { data, cols } = this.props;
24 | const firstAxis = LineChart.getObjectKeyName(cols, 0);
25 | const secondAxis = LineChart.getObjectKeyName(cols, 1);
26 | const position = `${firstAxis}*${secondAxis}`;
27 |
28 | if (data.length > 0) {
29 | return (
30 |
37 |
38 |
39 |
40 |
41 |
48 |
49 | );
50 | }
51 |
52 | return null;
53 | }
54 | }
55 |
--------------------------------------------------------------------------------
/src/components/mnist-command.jsx:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import ClearButton from '../containers/clear-button';
3 | import ConfirmRetrainButton from '../containers/confirm-retrain-button';
4 |
// Command buttons under the drawing. Per the imports these are ClearButton
// (labelled "Clear") and ConfirmRetrainButton.
// NOTE(review): the JSX tags are missing from this copy; only the "Clear" text
// node survives. Confirm the layout against the original source.
const MnistCommand = () => (

Clear


);

export default MnistCommand;
13 |
--------------------------------------------------------------------------------
/src/components/mnist-content.jsx:
--------------------------------------------------------------------------------
1 | import { Layout, Row } from 'antd';
2 | import React from 'react';
3 | import styled from 'styled-components';
4 | import MnistDrawing from '../containers/mnist-drawing';
5 | import ChartSection from './chart-section';
6 | import ImagePipeline from './image-pipeline';
7 |
const { Content } = Layout;

// antd Layout.Content capped at 880px wide and centered horizontally.
const ContentContainer = styled(Content)`
max-width: 880px;
margin: 0 auto;
`;

// Main page body. Per the imports it presumably holds MnistDrawing,
// ImagePipeline and ChartSection inside ContentContainer / antd Rows.
// NOTE(review): the JSX markup is missing from this copy of the file. Confirm it.
const MnistContent = () => (

);

export default MnistContent;
26 |
--------------------------------------------------------------------------------
/src/components/mnist-drawing.jsx:
--------------------------------------------------------------------------------
1 | import { Spin } from 'antd';
2 | import PropTypes from 'prop-types';
3 | import React from 'react';
4 | import Drawing from '../containers/drawing';
5 | import Output from '../containers/output';
6 | import CenterCol from './center-col';
7 | import MnistCommand from './mnist-command';
8 |
// Drawing area. Per the imports it presumably shows the Drawing canvas, the
// Output text and the MnistCommand buttons inside an antd Spin driven by
// `spinning`.
// NOTE(review): the JSX markup is missing from this copy of the file. Confirm it.
const MnistDrawing = ({ spinning }) => (

);

MnistDrawing.propTypes = {
  // Whether the loading spinner is active.
  spinning: PropTypes.bool,
};

MnistDrawing.defaultProps = {
  // Spin unless told otherwise (the container computes this from the store).
  spinning: true,
};

export default MnistDrawing;
32 |
--------------------------------------------------------------------------------
/src/components/mnist-footer.jsx:
--------------------------------------------------------------------------------
1 | import { Icon, Layout } from 'antd';
2 | import React from 'react';
3 | import styled from 'styled-components';
4 |
const { Footer } = Layout;

// antd Icon with a 3px right margin and 15px size (presumably footer links).
const IconFooter = styled(Icon)`
margin-right: 3px;
font-size: 15px;
`;

// Small centered paragraph for the credits line.
const PFooter = styled.p`
font-size: 8px;
text-align: center;
`;

// Link that keeps the surrounding text color.
const AFooter = styled.a`
color: inherit;
`;

// Page footer with the credits text and the author link.
// NOTE(review): most JSX tags are missing from this copy; the surviving text
// nodes are kept as-is. Confirm against the original source.
const MnistFooter = () => (


This page is made with Tensorflow.js, Mnist dataset, React, Redux, Redux-Saga, Babel, Webpack,
Styled-components, Eslint, Prettier and Ant Design - Author:
{' '}
{/* eslint-disable-next-line react/jsx-no-target-blank */}

Ixartz


);

export default MnistFooter;
43 |
--------------------------------------------------------------------------------
/src/components/mnist-header.jsx:
--------------------------------------------------------------------------------
1 | import { Layout } from 'antd';
2 | import React from 'react';
3 | import styled from 'styled-components';
4 |
const { Header } = Layout;

// White, single-line title; text that does not fit is cut off with an ellipsis.
const H1 = styled.h1`
color: white;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
`;

// antd Layout.Header with a dodgerblue background and a 20px bottom margin.
// (Despite its name the background is blue; white is only the title color.)
const WhiteHeader = styled(Header)`
background: dodgerblue;
margin-bottom: 20px;
`;

// Page header showing the application title.
// NOTE(review): the JSX tags are missing from this copy; only the title text
// survives. Confirm against the original source.
const MnistHeader = () => (

Digit Recognition with Tensorflow.js and React

);

export default MnistHeader;
26 |
--------------------------------------------------------------------------------
/src/components/mnist.jsx:
--------------------------------------------------------------------------------
1 | import { Layout } from 'antd';
2 | import React from 'react';
3 | import MnistContent from './mnist-content';
4 | import MnistFooter from './mnist-footer';
5 | import MnistHeader from './mnist-header';
6 |
// Root page component. Per the imports it presumably stacks MnistHeader,
// MnistContent and MnistFooter inside an antd Layout.
// NOTE(review): the JSX markup is missing from this copy of the file. Confirm it.
const Mnist = () => (

);

export default Mnist;
18 |
--------------------------------------------------------------------------------
/src/components/output.jsx:
--------------------------------------------------------------------------------
1 | import PropTypes from 'prop-types';
2 | import React from 'react';
3 | import styled from 'styled-components';
4 |
// Centered h2 used for the prediction text.
const H2 = styled.h2`
text-align: center;
`;

/**
 * Shows the latest prediction, or a prompt to draw when there is none yet
 * (`prediction` is null or undefined).
 * NOTE(review): the element wrapping the text (presumably H2) is missing from
 * this copy of the file. Confirm against the original source.
 */
export default class Output extends React.PureComponent {
  static propTypes = {
    prediction: PropTypes.number,
  };

  static defaultProps = {
    prediction: null,
  };

  render() {
    const { prediction } = this.props;

    return (

      {prediction != null
        ? `Prediction: ${prediction}`
        : 'Draw a number (0-9) in the black box above'}

    );
  }
}
30 |
--------------------------------------------------------------------------------
/src/containers/accuracy-chart.jsx:
--------------------------------------------------------------------------------
1 | import { connect } from 'react-redux';
2 | import LineChart from '../components/line-chart';
3 |
// Scale definition for the accuracy chart. LineChart uses the first key as the
// x field and the second as the y field. It is defined once at module level so
// mapStateToProps returns the same `cols` reference on every call. A fresh
// object literal fails connect()'s shallow prop comparison and re-renders the
// chart on every store update (e.g. every stroke point while drawing).
const ACCURACY_COLS = {
  batch: { range: [0, 1], alias: 'Batch' },
  accuracy: { min: 0, alias: 'Accuracy' },
};

/** Props for the accuracy chart: the recorded points plus the fixed scale. */
const mapStateToProps = state => ({
  data: state.chart.accuracy,
  cols: ACCURACY_COLS,
});
11 |
// LineChart connected to the store's accuracy series.
export default connect(mapStateToProps)(LineChart);
13 |
--------------------------------------------------------------------------------
/src/containers/bounding-box.jsx:
--------------------------------------------------------------------------------
1 | import { connect } from 'react-redux';
2 | import BoundingBox from '../components/bounding-box';
3 |
/**
 * Bounding-box preview props: the drawn image's data URL and the rectangle the
 * pipeline saga computed for it.
 */
const mapStateToProps = ({ pipeline }) => ({
  imageUrl: pipeline.imageUrl,
  rect: pipeline.rect,
});

export default connect(mapStateToProps)(BoundingBox);
10 |
--------------------------------------------------------------------------------
/src/containers/centered-box.jsx:
--------------------------------------------------------------------------------
1 | import { connect } from 'react-redux';
2 | import { Display140 } from '../components/display';
3 |
/** Props for the centered-image preview (the 280x280 centered canvas). */
const mapStateToProps = (state) => {
  const { centeredUrl } = state.pipeline;

  return {
    imageUrl: centeredUrl,
    width: 280,
    height: 280,
    title: 'Centered image',
  };
};

const CenteredBox = connect(mapStateToProps)(Display140);

export default CenteredBox;
12 |
--------------------------------------------------------------------------------
/src/containers/clear-button.jsx:
--------------------------------------------------------------------------------
1 | import { connect } from 'react-redux';
2 | import { resetDrawing } from '../actions/drawing';
3 | import ButtonSpace from '../components/button-space';
4 |
// The Clear button stays disabled until a prediction exists.
const mapStateToProps = ({ mnist }) => ({
  disabled: mnist.prediction == null,
});

// Clicking it wipes the drawing.
const mapDispatchToProps = (dispatch) => {
  const onClick = () => dispatch(resetDrawing());
  return { onClick };
};

export default connect(mapStateToProps, mapDispatchToProps)(ButtonSpace);
17 |
--------------------------------------------------------------------------------
/src/containers/confirm-retrain-button.jsx:
--------------------------------------------------------------------------------
1 | import { connect } from 'react-redux';
2 | import { loadAndTrainMnist } from '../actions/mnist';
3 | import ConfirmRetrainButton from '../components/confirm-retrain-button';
4 |
// Confirming dispatches the load-and-train request handled by the mnist saga.
const mapDispatchToProps = dispatch => ({
  onConfirm() {
    return dispatch(loadAndTrainMnist());
  },
});

const ConnectedConfirmRetrainButton = connect(null, mapDispatchToProps)(ConfirmRetrainButton);

export default ConnectedConfirmRetrainButton;
13 |
--------------------------------------------------------------------------------
/src/containers/cropped-box.jsx:
--------------------------------------------------------------------------------
1 | import { connect } from 'react-redux';
2 | import { Display100 } from '../components/display';
3 |
/** Props for the cropped-image preview (the 200x200 cropped canvas). */
const mapStateToProps = (state) => {
  const { croppedUrl } = state.pipeline;

  return {
    imageUrl: croppedUrl,
    width: 200,
    height: 200,
    title: 'Cropped image',
  };
};

const CroppedBox = connect(mapStateToProps)(Display100);

export default CroppedBox;
12 |
--------------------------------------------------------------------------------
/src/containers/drawing.jsx:
--------------------------------------------------------------------------------
1 | import { connect } from 'react-redux';
2 | import { addStroke, addStrokePos, endStroke } from '../actions/drawing';
3 | import { addInput } from '../actions/pipeline';
4 | import Drawing from '../components/drawing';
5 |
/** Stroke state of the drawing slice, passed straight to the canvas. */
const mapStateToProps = ({ drawing }) => ({
  isDrawing: drawing.isDrawing,
  isEndStroke: drawing.isEndStroke,
  strokes: drawing.strokes,
});

// Object shorthand: connect() binds each action creator to dispatch under the
// same prop name.
const mapDispatchToProps = {
  addStroke,
  addStrokePos,
  endStroke,
  addInput,
};

export default connect(mapStateToProps, mapDispatchToProps)(Drawing);
23 |
--------------------------------------------------------------------------------
/src/containers/loss-chart.jsx:
--------------------------------------------------------------------------------
1 | import { connect } from 'react-redux';
2 | import LineChart from '../components/line-chart';
3 |
// Scale definition for the loss chart. LineChart uses the first key as the x
// field and the second as the y field. It is defined once at module level so
// mapStateToProps returns the same `cols` reference on every call. A fresh
// object literal fails connect()'s shallow prop comparison and re-renders the
// chart on every store update (e.g. every stroke point while drawing).
const LOSS_COLS = {
  batch: { range: [0, 1], alias: 'Batch' },
  loss: { min: 0, alias: 'Loss' },
};

/** Props for the loss chart: the recorded points plus the fixed scale. */
const mapStateToProps = state => ({
  data: state.chart.loss,
  cols: LOSS_COLS,
});
11 |
// LineChart connected to the store's loss series.
export default connect(mapStateToProps)(LineChart);
13 |
--------------------------------------------------------------------------------
/src/containers/mnist-drawing.jsx:
--------------------------------------------------------------------------------
1 | import { connect } from 'react-redux';
2 | import { MnistAction } from '../actions/mnist';
3 | import MnistDrawing from '../components/mnist-drawing';
4 | import isLoadingClassifier from '../utils/classifier';
5 |
/**
 * Spin while the pretrained model has not loaded yet (status still INIT) or
 * while isLoadingClassifier reports a classifier load in progress.
 */
const mapStateToProps = (state) => {
  const stillInitializing = state.mnist.status === MnistAction.INIT;
  return { spinning: stillInitializing || isLoadingClassifier(state) };
};

const ConnectedMnistDrawing = connect(mapStateToProps)(MnistDrawing);

export default ConnectedMnistDrawing;
11 |
--------------------------------------------------------------------------------
/src/containers/normalized-box.jsx:
--------------------------------------------------------------------------------
1 | import { connect } from 'react-redux';
2 | import { Display28 } from '../components/display';
3 |
/** Props for the normalized-image preview (the 28x28 canvas fed to the model). */
const mapStateToProps = (state) => {
  const { normalizedUrl } = state.pipeline;

  return {
    imageUrl: normalizedUrl,
    width: 28,
    height: 28,
    title: 'Normalized image',
  };
};

const NormalizedBox = connect(mapStateToProps)(Display28);

export default NormalizedBox;
12 |
--------------------------------------------------------------------------------
/src/containers/output.jsx:
--------------------------------------------------------------------------------
1 | import { connect } from 'react-redux';
2 | import Output from '../components/output';
3 |
/** Passes the latest prediction (if any) to the Output component. */
const mapStateToProps = ({ mnist: { prediction } }) => ({ prediction });

const ConnectedOutput = connect(mapStateToProps)(Output);

export default ConnectedOutput;
9 |
--------------------------------------------------------------------------------
/src/containers/retrain-button.jsx:
--------------------------------------------------------------------------------
1 | import { connect } from 'react-redux';
2 | import { MnistAction } from '../actions/mnist';
3 | import ButtonSpace from '../components/button-space';
4 | import isLoadingClassifier from '../utils/classifier';
5 |
/**
 * Label of the retrain button for the current retraining phase.
 *
 * @param {object} state - Redux store state
 * @returns {string} the button text
 */
const text = (state) => {
  switch (state.mnist.retrainStatus) {
    case MnistAction.LOADING_MNIST:
      return 'Loading MNIST database';
    case MnistAction.TRAINING_MNIST:
      return 'Training new model';
    default:
      return 'Retrain model';
  }
};

const mapStateToProps = state => ({
  // Presumably true while a classifier is being loaded or trained.
  loading: isLoadingClassifier(state),
  // Same rule as the Clear button: disabled until a prediction exists.
  disabled: state.mnist.prediction == null,
  children: text(state),
});

// Do NOT delete mapDispatchToProps; keep it returning an empty object. When
// it is omitted, connect() passes a `dispatch` prop to ButtonSpace, which
// (per the original author) causes an error on the console, presumably
// because ButtonSpace forwards unknown props to the DOM.
const mapDispatchToProps = () => ({});

export default connect(
  mapStateToProps,
  mapDispatchToProps,
)(ButtonSpace);
30 |
--------------------------------------------------------------------------------
/src/data/mnist-data.js:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | /* eslint-disable */
19 |
20 | import * as tf from '@tensorflow/tfjs';
21 |
const IMAGE_SIZE = 784;
const NUM_CLASSES = 10;
const NUM_DATASET_ELEMENTS = 65000;

const NUM_TRAIN_ELEMENTS = 55000;
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

const MNIST_IMAGES_SPRITE_PATH =
  'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
const MNIST_LABELS_PATH =
  'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';

/**
 * A class that fetches the sprited MNIST dataset and returns shuffled batches.
 *
 * NOTE: This will get much easier. For now, we do data fetching and
 * manipulation manually.
 */
export default class MnistData {
  constructor() {
    this.shuffledTrainIndex = 0;
    this.shuffledTestIndex = 0;
  }

  /**
   * Downloads the image sprite and the label file, builds shuffled train/test
   * index lists and splits images and labels into train and test sets.
   *
   * Declared as an arrow-function property so it stays bound when passed
   * around on its own (the mnist saga runs it as `call(mnistData.load)`).
   *
   * @returns {Promise<void>} Rejects if the sprite image fails to load or the
   *   labels request does not answer with a 2xx status. Previously an image
   *   error left this promise pending forever.
   */
  load = async () => {
    // Make a request for the MNIST sprited image.
    const img = new Image();
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');
    const imgRequest = new Promise((resolve, reject) => {
      // Anonymous CORS request, so the cross-origin sprite does not taint the
      // canvas and getImageData remains allowed.
      img.crossOrigin = '';
      img.onerror = () => {
        reject(new Error(`Failed to load MNIST image sprite: ${MNIST_IMAGES_SPRITE_PATH}`));
      };
      img.onload = () => {
        img.width = img.naturalWidth;
        img.height = img.naturalHeight;

        // One 4-byte float per pixel for every image in the dataset.
        const datasetBytesBuffer = new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

        // Decode the sprite in bands of chunkSize rows. Each band fills
        // IMAGE_SIZE * chunkSize floats, which presumably means the sprite is
        // IMAGE_SIZE pixels wide with one image per row.
        const chunkSize = 5000;
        canvas.width = img.width;
        canvas.height = chunkSize;

        for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
          const datasetBytesView = new Float32Array(
            datasetBytesBuffer,
            i * IMAGE_SIZE * chunkSize * 4,
            IMAGE_SIZE * chunkSize,
          );
          ctx.drawImage(img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width, chunkSize);

          const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

          for (let j = 0; j < imageData.data.length / 4; j++) {
            // All channels hold an equal value since the image is grayscale, so
            // just read the red channel.
            datasetBytesView[j] = imageData.data[j * 4] / 255;
          }
        }
        this.datasetImages = new Float32Array(datasetBytesBuffer);

        resolve();
      };
      img.src = MNIST_IMAGES_SPRITE_PATH;
    });

    const labelsRequest = fetch(MNIST_LABELS_PATH);
    // Only the labels response is needed; the image promise resolves empty.
    const [, labelsResponse] = await Promise.all([imgRequest, labelsRequest]);

    if (!labelsResponse.ok) {
      throw new Error(`Failed to fetch MNIST labels: HTTP ${labelsResponse.status}`);
    }
    this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

    // Create shuffled indices into the train/test set for when we select a
    // random dataset element for training / validation.
    this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
    this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

    // Slice the images and labels into train and test sets.
    this.trainImages = this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.trainLabels = this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    this.testLabels = this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
  };

  /**
   * Next training batch. Walks trainIndices cyclically; the cursor is advanced
   * before each read.
   */
  nextTrainBatch(batchSize) {
    return this.nextBatch(batchSize, [this.trainImages, this.trainLabels], () => {
      this.shuffledTrainIndex = (this.shuffledTrainIndex + 1) % this.trainIndices.length;
      return this.trainIndices[this.shuffledTrainIndex];
    });
  }

  /** Next test batch, walking testIndices the same way as nextTrainBatch. */
  nextTestBatch(batchSize) {
    return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
      this.shuffledTestIndex = (this.shuffledTestIndex + 1) % this.testIndices.length;
      return this.testIndices[this.shuffledTestIndex];
    });
  }

  /**
   * Assembles a batch by copying `batchSize` samples chosen by `index()`.
   *
   * @param {number} batchSize - number of samples in the batch
   * @param {Array} data - [images, labels]: flat typed arrays holding
   *   IMAGE_SIZE floats and NUM_CLASSES label bytes (presumably one-hot) per
   *   sample
   * @param {function(): number} index - returns the next sample index to copy
   * @returns {{xs: object, labels: object}} tf tensors of shape
   *   [batchSize, IMAGE_SIZE] and [batchSize, NUM_CLASSES]
   */
  nextBatch(batchSize, data, index) {
    const [images, labels] = data;
    const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
    const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

    for (let i = 0; i < batchSize; i++) {
      const idx = index();

      // subarray() returns a view, so each sample is copied once by set()
      // instead of being duplicated by slice() first.
      batchImagesArray.set(
        images.subarray(idx * IMAGE_SIZE, (idx + 1) * IMAGE_SIZE),
        i * IMAGE_SIZE,
      );
      batchLabelsArray.set(
        labels.subarray(idx * NUM_CLASSES, (idx + 1) * NUM_CLASSES),
        i * NUM_CLASSES,
      );
    }

    const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
    const labelsTensor = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);

    return { xs, labels: labelsTensor };
  }
}
137 |
--------------------------------------------------------------------------------
/src/main.jsx:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import ReactDom from 'react-dom';
3 | import App from './components/app';
4 |
// Mounts the application into the element with id "app".
// NOTE(review): the JSX argument (presumably <App />) is missing from this copy
// of the file; restore it from the original source.
ReactDom.render( , document.getElementById('app'));
6 |
--------------------------------------------------------------------------------
/src/reducers/chart.js:
--------------------------------------------------------------------------------
1 | import { ChartAction } from '../actions/chart';
2 |
const initialState = {
  loss: [],
  accuracy: [],
};

/**
 * Training-chart reducer.
 *
 * ADD_LOSS_POINT / ADD_ACCURACY_POINT append `action.pt` to the matching series
 * without mutating the previous array. RESET_CHART restores the empty initial
 * state, and any other action leaves the state untouched.
 */
const chart = (state = initialState, action) => {
  const { type } = action;

  if (type === ChartAction.ADD_LOSS_POINT) {
    return { ...state, loss: state.loss.concat([action.pt]) };
  }

  if (type === ChartAction.ADD_ACCURACY_POINT) {
    return { ...state, accuracy: state.accuracy.concat([action.pt]) };
  }

  return type === ChartAction.RESET_CHART ? initialState : state;
};

export default chart;
28 |
--------------------------------------------------------------------------------
/src/reducers/drawing.js:
--------------------------------------------------------------------------------
1 | import { StrokeAction } from '../actions/drawing';
2 |
const initialState = {
  isDrawing: false,
  isEndStroke: false,
  strokes: [],
};

/**
 * Drawing-canvas reducer. `strokes` is a list of strokes, each a list of
 * `{ x, y }` points.
 *
 * - ADD_STROKE starts a new stroke at `action.pos` and enters drawing mode.
 * - ADD_STROKE_POS appends `action.pos` to the last stroke. It is ignored when
 *   no stroke exists.
 * - END_STROKE leaves drawing mode and sets `isEndStroke`, which the Drawing
 *   component uses to submit the canvas image.
 * - RESET_DRAWING restores the empty initial state.
 */
const drawing = (state = initialState, action) => {
  switch (action.type) {
    case StrokeAction.ADD_STROKE:
      return {
        ...state,
        isDrawing: true,
        isEndStroke: false,
        strokes: [...state.strokes, [action.pos]],
      };
    case StrokeAction.ADD_STROKE_POS: {
      const lastIndex = state.strokes.length - 1;

      // With no stroke started there is nothing to extend. Spreading the
      // missing last stroke used to throw a TypeError here.
      if (lastIndex < 0) {
        return state;
      }

      return {
        ...state,
        strokes: [
          ...state.strokes.slice(0, lastIndex),
          [...state.strokes[lastIndex], action.pos],
        ],
      };
    }
    case StrokeAction.END_STROKE:
      return {
        ...state,
        isDrawing: false,
        isEndStroke: true,
      };
    case StrokeAction.RESET_DRAWING:
      return initialState;
    default:
      return state;
  }
};

export default drawing;
40 |
--------------------------------------------------------------------------------
/src/reducers/index.js:
--------------------------------------------------------------------------------
1 | import { combineReducers } from 'redux';
2 | import chart from './chart';
3 | import drawing from './drawing';
4 | import mnist from './mnist';
5 | import pipeline from './pipeline';
6 |
/**
 * Root reducer. Each top-level key of the store state is owned by the reducer
 * of the same name.
 */
const rootReducer = combineReducers({
  drawing,
  pipeline,
  mnist,
  chart,
});

export default rootReducer;
13 |
--------------------------------------------------------------------------------
/src/reducers/mnist.js:
--------------------------------------------------------------------------------
1 | import { MnistAction } from '../actions/mnist';
2 |
const initialState = {
  status: MnistAction.INIT,
  retrainStatus: MnistAction.INIT,
};

/**
 * Classifier lifecycle reducer.
 *
 * - status: INIT until LOAD_PRETRAINED_MODEL_SUCCEEDED is received.
 * - retrainStatus: the latest of LOADING_MNIST / TRAINING_MNIST /
 *   LOAD_AND_TRAIN_MNIST_SUCCEEDED (INIT before any retrain).
 * - prediction: value from the latest PREDICT_SUCCEEDED (absent before).
 */
const mnist = (state = initialState, action) => {
  const { type } = action;

  if (type === MnistAction.LOAD_PRETRAINED_MODEL_SUCCEEDED) {
    return { ...state, status: type };
  }

  const isRetrainPhase = type === MnistAction.LOADING_MNIST
    || type === MnistAction.TRAINING_MNIST
    || type === MnistAction.LOAD_AND_TRAIN_MNIST_SUCCEEDED;
  if (isRetrainPhase) {
    return { ...state, retrainStatus: type };
  }

  if (type === MnistAction.PREDICT_SUCCEEDED) {
    return { ...state, prediction: action.prediction };
  }

  return state;
};

export default mnist;
33 |
--------------------------------------------------------------------------------
/src/reducers/pipeline.js:
--------------------------------------------------------------------------------
1 | import { PipelineAction } from '../actions/pipeline';
2 |
/**
 * Image-pipeline reducer. It stores what each stage of the pipeline saga
 * publishes (data URLs, plus the bounding rect) so the preview components can
 * show them. Unrelated actions return the state unchanged.
 */
const pipeline = (state = {}, action) => {
  let changes;

  switch (action.type) {
    case PipelineAction.DISPLAY_BOUNDING_BOX:
      changes = { imageUrl: action.imageUrl, rect: action.rect };
      break;
    case PipelineAction.DISPLAY_CROPPED_BOX:
      changes = { croppedUrl: action.croppedUrl };
      break;
    case PipelineAction.DISPLAY_CENTERED_BOX:
      changes = { centeredUrl: action.centeredUrl };
      break;
    case PipelineAction.DISPLAY_NORMALIZED_BOX:
      changes = { normalizedUrl: action.normalizedUrl };
      break;
    default:
      return state;
  }

  return { ...state, ...changes };
};

export default pipeline;
32 |
--------------------------------------------------------------------------------
/src/sagas/index.js:
--------------------------------------------------------------------------------
1 | import { all, fork } from 'redux-saga/effects';
2 | import watchMnist from './mnist';
3 | import watchPipeline from './pipeline';
4 |
/** Root saga: starts the MNIST and image-pipeline watchers side by side. */
export default function* rootSaga() {
  const watchers = [watchMnist, watchPipeline];
  yield all(watchers.map(watcher => fork(watcher)));
}
8 |
--------------------------------------------------------------------------------
/src/sagas/mnist.js:
--------------------------------------------------------------------------------
1 | import * as tf from '@tensorflow/tfjs';
2 | import {
3 | apply, call, put, take,
4 | } from 'redux-saga/effects';
5 | import { resetChart } from '../actions/chart';
6 | import { resetDrawing } from '../actions/drawing';
7 | import {
8 | loadAndTrainMnistSucceeded,
9 | loadingMnist,
10 | loadPretrainedModelSucceeded,
11 | MnistAction,
12 | predictSucceeded,
13 | trainingMnist,
14 | } from '../actions/mnist';
15 | import HandwritingDigitsClassifier from '../classifiers/handwriting-digits-classifier';
16 | import MnistData from '../data/mnist-data';
17 | import { convertToGrayscale } from '../utils/image-processing';
18 |
/**
 * Builds a fresh classifier on `mnistData` and trains it. The training phase
 * is announced first so the UI can reflect it.
 *
 * @returns the newly trained classifier
 */
function* trainMnist(mnistData) {
  yield put(trainingMnist());

  const classifier = new HandwritingDigitsClassifier();
  classifier.initializeModel(mnistData);
  yield apply(classifier, classifier.train);

  return classifier;
}
26 |
/**
 * Downloads the MNIST dataset, then trains a new classifier on it.
 *
 * @returns the classifier produced by trainMnist
 */
function* loadAndTrainMnist() {
  yield put(loadingMnist());

  const dataset = new MnistData();
  yield call(dataset.load);

  return yield call(trainMnist, dataset);
}
34 |
/**
 * Loads the bundled pretrained classifier and reports success.
 *
 * @returns the loaded classifier
 */
function* loadPretrainedModel() {
  const classifier = new HandwritingDigitsClassifier();

  yield apply(classifier, classifier.loadModel);
  yield put(loadPretrainedModelSucceeded());

  return classifier;
}
44 |
/**
 * Runs the classifier on one preprocessed image and publishes the result.
 *
 * @param handwritingDigitsClassifier - classifier whose `predict` synchronously
 *   returns the value stored as the prediction
 * @param image - presumably the ImageData of the 28x28 normalized canvas
 *   (reshaped here to [1, 28, 28, 1])
 */
function* predict(handwritingDigitsClassifier, image) {
  const dataGrayscale = convertToGrayscale(image);
  const dataTensor = tf.tensor(dataGrayscale, [1, 28, 28, 1]);

  let prediction;
  try {
    prediction = handwritingDigitsClassifier.predict(dataTensor);
  } finally {
    // tfjs does not garbage-collect tensor memory (e.g. WebGL textures), so
    // the input tensor must be released explicitly, or every prediction leaks
    // one. dispose() returns early for an already-disposed tensor, so this is
    // safe even if the classifier disposes its input itself.
    dataTensor.dispose();
  }

  yield put(predictSucceeded(prediction));
}
51 |
/**
 * Serves prediction and retraining requests forever, one at a time.
 *
 * Starts with the pretrained classifier. A retrain replaces the current
 * classifier. If that classifier has a `data` property (presumably the dataset
 * it was trained on), the dataset is reused; otherwise MNIST is downloaded
 * first.
 */
function* request(pretrainedHandwritingDigitsClassifier) {
  let classifier = pretrainedHandwritingDigitsClassifier;

  for (;;) {
    const action = yield take([
      MnistAction.PREDICT_REQUESTED,
      MnistAction.LOAD_AND_TRAIN_MNIST_REQUESTED,
    ]);

    if (action.type === MnistAction.PREDICT_REQUESTED) {
      yield call(predict, classifier, action.image);
    } else if (action.type === MnistAction.LOAD_AND_TRAIN_MNIST_REQUESTED) {
      // Clear the previous training curves and the drawing before retraining.
      yield put(resetChart());
      yield put(resetDrawing());

      const retrainEffect = classifier.data
        ? call(trainMnist, classifier.data)
        : call(loadAndTrainMnist);
      classifier = yield retrainEffect;

      yield put(loadAndTrainMnistSucceeded());
    }
  }
}
82 |
/**
 * MNIST watcher: loads the pretrained classifier, then serves requests with it
 * for the lifetime of the app.
 */
export default function* watchMnist() {
  const pretrained = yield call(loadPretrainedModel);
  yield call(request, pretrained);
}
89 |
--------------------------------------------------------------------------------
/src/sagas/pipeline.js:
--------------------------------------------------------------------------------
1 | import { call, put, takeLatest } from 'redux-saga/effects';
2 | import { requestPredict } from '../actions/mnist';
3 | import {
4 | displayBoundingBox,
5 | displayCenteredBox,
6 | displayCroppedBox,
7 | displayNormalizedBox,
8 | PipelineAction,
9 | } from '../actions/pipeline';
10 | import { computeBoundingRect } from '../utils/image-processing';
11 |
/**
 * Pipeline stage 1: paints the drawn ImageData onto a 280x280 canvas, computes
 * its bounding rectangle with computeBoundingRect, and publishes both for the
 * bounding-box preview.
 *
 * @returns {{canvas: HTMLCanvasElement, rect: object}} for the crop stage
 */
function* boundingBoxTask(action) {
  const { image } = action;

  const canvas = document.createElement('canvas');
  Object.assign(canvas, { width: 280, height: 280 });
  canvas.getContext('2d').putImageData(image, 0, 0);

  const rect = computeBoundingRect(image);
  yield put(displayBoundingBox(canvas.toDataURL(), rect));

  return { canvas, rect };
}
28 |
/**
 * Pipeline stage 2: copies the bounding rectangle of `canvas` into the
 * top-left corner of a 200x200 canvas. The copy is scaled uniformly so the
 * rectangle's longer side spans the full 200px.
 *
 * @returns {{croppedCanvas: HTMLCanvasElement, croppedRectSize: {width: number, height: number}}}
 *   the cropped canvas and the scaled size of the copied region
 */
function* croppedBoxTask(canvas, rect) {
  const croppedCanvas = document.createElement('canvas');
  croppedCanvas.width = 200;
  croppedCanvas.height = 200;
  const ctx = croppedCanvas.getContext('2d');

  const sourceWidth = rect.computeWidth();
  const sourceHeight = rect.computeHeight();
  const scale = 200 / Math.max(sourceWidth, sourceHeight);
  const croppedRectSize = {
    width: sourceWidth * scale,
    height: sourceHeight * scale,
  };

  const { width: targetWidth, height: targetHeight } = croppedRectSize;
  ctx.drawImage(
    canvas,
    rect.xmin, rect.ymin, sourceWidth, sourceHeight,
    0, 0, targetWidth, targetHeight,
  );

  yield put(displayCroppedBox(croppedCanvas.toDataURL()));

  return { croppedCanvas, croppedRectSize };
}
62 |
/**
 * Pipeline stage 3: draws the cropped canvas onto a 280x280 canvas, offset so
 * that the occupied region (`croppedRectSize`, at the cropped canvas' top-left
 * corner) ends up centered.
 *
 * @returns {HTMLCanvasElement} the centered canvas
 */
function* centeredBoxTask(croppedCanvas, croppedRectSize) {
  const centeredCanvas = document.createElement('canvas');
  centeredCanvas.width = 280;
  centeredCanvas.height = 280;
  const ctx = centeredCanvas.getContext('2d');

  const offsetX = centeredCanvas.width / 2 - croppedRectSize.width / 2;
  const offsetY = centeredCanvas.height / 2 - croppedRectSize.height / 2;
  ctx.drawImage(croppedCanvas, offsetX, offsetY);

  yield put(displayCenteredBox(centeredCanvas.toDataURL()));

  return centeredCanvas;
}
79 |
/**
 * Fourth pipeline stage: downscale the whole centred canvas to 28x28 and
 * publish a preview.
 *
 * @param {HTMLCanvasElement} centeredCanvas - output canvas of the centring stage.
 * @returns {HTMLCanvasElement} the 28x28 canvas holding the downscaled image.
 */
function* normalizedTask(centeredCanvas) {
  const normalizedCanvas = document.createElement('canvas');
  normalizedCanvas.width = 28;
  normalizedCanvas.height = 28;

  const { width: srcWidth, height: srcHeight } = centeredCanvas;
  const { width: dstWidth, height: dstHeight } = normalizedCanvas;

  // Full source rectangle stretched onto the full destination rectangle.
  const ctx = normalizedCanvas.getContext('2d');
  ctx.drawImage(centeredCanvas, 0, 0, srcWidth, srcHeight, 0, 0, dstWidth, dstHeight);

  const preview = normalizedCanvas.toDataURL();
  yield put(displayNormalizedBox(preview));

  return normalizedCanvas;
}
102 |
/**
 * Final pipeline stage: read back every pixel of the normalized canvas and
 * dispatch a prediction request carrying that ImageData.
 *
 * @param {HTMLCanvasElement} normalizedCanvas - output canvas of the normalize stage.
 */
function* requestPredictTask(normalizedCanvas) {
  const { width, height } = normalizedCanvas;
  const pixels = normalizedCanvas.getContext('2d').getImageData(0, 0, width, height);
  yield put(requestPredict(pixels));
}
111 |
/**
 * Run the image pipeline for one input action. Each stage's result feeds the
 * next: bounding box -> crop -> centre -> 28x28 normalize -> prediction request.
 *
 * @param {object} action - the pipeline input action; forwarded unchanged to
 *   the bounding-box stage.
 */
function* pipelineImgProcessing(action) {
  const bounding = yield call(boundingBoxTask, action);
  const cropped = yield call(croppedBoxTask, bounding.canvas, bounding.rect);
  const centered = yield call(centeredBoxTask, cropped.croppedCanvas, cropped.croppedRectSize);
  const normalized = yield call(normalizedTask, centered);
  yield call(requestPredictTask, normalized);
}
119 |
/**
 * Root watcher for the image pipeline: runs `pipelineImgProcessing` for each
 * PipelineAction.INPUT, using takeLatest so a newer input supersedes any
 * still-running pipeline.
 */
function* watchPipeline() {
  const { INPUT } = PipelineAction;
  yield takeLatest(INPUT, pipelineImgProcessing);
}

export default watchPipeline;
125 |
--------------------------------------------------------------------------------
/src/store/configureStore.js:
--------------------------------------------------------------------------------
1 | import { applyMiddleware, createStore } from 'redux';
2 | import { composeWithDevTools } from 'redux-devtools-extension/logOnlyInProduction';
3 | import createSagaMiddleware from 'redux-saga';
4 | import reducer from '../reducers';
5 | import rootSaga from '../sagas';
6 |
/**
 * Create the Redux store with redux-saga middleware and Redux DevTools
 * composition, then start the root saga.
 *
 * @param {object} [preloadedState] - initial state forwarded to createStore.
 * @returns {object} the configured Redux store.
 */
export default function configureStore(preloadedState) {
  const withDevTools = composeWithDevTools({});
  const sagaMiddleware = createSagaMiddleware();
  const enhancer = withDevTools(applyMiddleware(sagaMiddleware));

  const store = createStore(reducer, preloadedState, enhancer);
  sagaMiddleware.run(rootSaga);

  // Keep the literal `module.hot.accept(...)` / `module.hot.decline(...)` call
  // forms: webpack recognises them statically to resolve the dependency paths.
  if (module.hot) {
    // Swap in edited reducers without rebuilding the store.
    module.hot.accept('../reducers', () => store.replaceReducer(require('../reducers').default)); // eslint-disable-line global-require

    // Sagas are not hot-swapped; updates to them are declined.
    module.hot.decline('../sagas');
  }

  return store;
}
28 |
--------------------------------------------------------------------------------
/src/utils/classifier.js:
--------------------------------------------------------------------------------
1 | import { MnistAction } from '../actions/mnist';
2 |
/**
 * Report whether the classifier is busy.
 *
 * @param {object} state - root Redux state; only `state.mnist.retrainStatus` is read.
 * @returns {boolean} false when retrainStatus is MnistAction.INIT or
 *   MnistAction.LOAD_AND_TRAIN_MNIST_SUCCEEDED, true for any other status.
 */
export default function isLoadingClassifier(state) {
  const { retrainStatus } = state.mnist;
  const isInitial = retrainStatus === MnistAction.INIT;
  const isTrained = retrainStatus === MnistAction.LOAD_AND_TRAIN_MNIST_SUCCEEDED;
  return !(isInitial || isTrained);
}
9 |
--------------------------------------------------------------------------------
/src/utils/image-processing.js:
--------------------------------------------------------------------------------
1 | import Rect from './rect';
2 |
/**
 * Find the smallest inclusive rectangle enclosing every pixel whose red, green
 * or blue channel is non-zero. The alpha channel is not consulted.
 *
 * @param {{ width: number, height: number, data: ArrayLike<number> }} image -
 *   RGBA pixels, 4 entries per pixel in row-major order.
 * @returns {Rect} the bounds. If no pixel qualifies, the Rect keeps its
 *   initial values (xmin = width, xmax = -1, ymin = height, ymax = -1).
 */
export function computeBoundingRect(image) {
  const { width, height, data } = image;
  const bounds = new Rect(image);

  for (let y = 0; y < height; y += 1) {
    for (let x = 0; x < width; x += 1) {
      const offset = (y * width + x) * 4;
      const hasInk = data[offset] > 0 || data[offset + 1] > 0 || data[offset + 2] > 0;

      if (hasInk) {
        bounds.xmin = Math.min(x, bounds.xmin);
        bounds.xmax = Math.max(x, bounds.xmax);
        bounds.ymin = Math.min(y, bounds.ymin);
        bounds.ymax = Math.max(y, bounds.ymax);
      }
    }
  }

  return bounds;
}
22 |
/**
 * Convert RGBA pixels to one grayscale value per pixel: the plain mean of R, G
 * and B, scaled by 1/255. Alpha is ignored.
 *
 * @param {{ width: number, height: number, data: ArrayLike<number> }} image -
 *   RGBA pixels, 4 entries per pixel.
 * @returns {number[]} width * height values in row-major order.
 */
export function convertToGrayscale(image) {
  const { data } = image;
  const end = image.width * image.height * 4;
  const grayscale = [];

  for (let offset = 0; offset < end; offset += 4) {
    const red = data[offset + 0];
    const green = data[offset + 1];
    const blue = data[offset + 2];
    grayscale.push((red + green + blue) / 3 / 255.0);
  }

  return grayscale;
}
36 |
--------------------------------------------------------------------------------
/src/utils/rect.js:
--------------------------------------------------------------------------------
/**
 * Axis-aligned pixel rectangle with inclusive [min, max] bounds on each axis.
 *
 * A new Rect starts "inverted" (min beyond max on both axes), so the first
 * coordinate folded in with Math.min / Math.max sets both bounds. While it is
 * still inverted, computeWidth() and computeHeight() return non-positive values.
 */
export default class Rect {
  /**
   * @param {{ width: number, height: number }} image - only its dimensions are read.
   */
  constructor({ width, height }) {
    // min starts at the far edge and max at -1 so any real coordinate replaces both.
    this.xmin = width;
    this.xmax = -1;
    this.ymin = height;
    this.ymax = -1;
  }

  /** @returns {number} inclusive span along x (xmax - xmin + 1). */
  computeWidth() {
    const { xmin, xmax } = this;
    return xmax - xmin + 1;
  }

  /** @returns {number} inclusive span along y (ymax - ymin + 1). */
  computeHeight() {
    const { ymin, ymax } = this;
    return ymax - ymin + 1;
  }
}
17 |
--------------------------------------------------------------------------------
/templates/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 | <%= htmlWebpackPlugin.options.title %>
9 |
10 |
11 |
12 |
19 |
20 |
21 |
22 | Loading application...
23 |
24 |
25 |
26 |
27 |
28 |
--------------------------------------------------------------------------------
/webpack.config.js:
--------------------------------------------------------------------------------
1 | module.exports = (env) => {
2 | if (env.WEBPACK_CONFIG === 'dev') {
3 | return require('./webpack/webpack.dev.js');
4 | }
5 | if (env.WEBPACK_CONFIG === 'prod') {
6 | return require('./webpack/webpack.prod.js');
7 | }
8 | if (env.WEBPACK_CONFIG === 'analyze') {
9 | return require('./webpack/webpack.analyze.js');
10 | }
11 | };
12 |
--------------------------------------------------------------------------------
/webpack/utils.js:
--------------------------------------------------------------------------------
1 | const path = require('path');
2 |
// Absolute project paths shared by the webpack configs. ROOT_DIR is the parent
// of this file's directory; the other two paths are derived from it.
const ROOT_DIR = path.resolve(__dirname, '..');
const PUBLIC_DIR = path.join(ROOT_DIR, 'public');
const DIST_DIR = path.join(ROOT_DIR, 'dist');
6 |
/**
 * Build the HtmlWebpackPlugin options shared by the webpack configs: the page
 * title and the absolute path of templates/index.html.
 *
 * A new object is returned on every call, so callers can spread or extend it
 * without affecting each other.
 *
 * @returns {{ title: string, template: string }}
 */
function generateHtmlWebpackPluginConfig() {
  const template = path.resolve(ROOT_DIR, 'templates/index.html');
  const title = 'Live digit recognition on phone and desktop';
  return { title, template };
}
13 |
// Public surface of this helper module: the shared project paths and the
// HtmlWebpackPlugin options factory.
const utilsExports = {
  ROOT_DIR,
  PUBLIC_DIR,
  DIST_DIR,
  generateHtmlWebpackPluginConfig,
};

module.exports = utilsExports;
20 |
--------------------------------------------------------------------------------
/webpack/webpack.analyze.js:
--------------------------------------------------------------------------------
1 | const BundleAnalyzerPlugin = require('webpack-bundle-analyzer').BundleAnalyzerPlugin;
2 | const merge = require('webpack-merge');
3 | const Prod = require('./webpack.prod.js');
4 |
5 | module.exports = merge.smartStrategy({
6 | 'module.rules.use': 'prepend',
7 | })(Prod, {
8 | plugins: [
9 | new BundleAnalyzerPlugin({
10 | openAnalyzer: false,
11 | analyzerMode: 'static',
12 | }),
13 | ],
14 | });
15 |
--------------------------------------------------------------------------------
/webpack/webpack.common.js:
--------------------------------------------------------------------------------
1 | const CleanWebpackPlugin = require('clean-webpack-plugin');
2 | const HtmlWebpackPlugin = require('html-webpack-plugin');
3 | const path = require('path');
4 | const Utils = require('./utils');
5 |
/**
 * Matcher for first-party JS/JSX modules. It is a factory so each rule gets its
 * own RegExp instances, as the separate literals in each rule did before.
 *
 * NOTE(review): `include: /src/` matches any absolute path containing "src",
 * which could include node_modules packages that ship a src/ directory.
 * webpack.prod.js repeats the same `include` for its babel rule, and the merge
 * presumably relies on the two matching, so change both together if tightening it.
 */
const appScriptFiles = () => ({
  test: /\.(js|jsx)$/,
  include: /src/,
});

// Lint first-party sources before any other loader runs; lint problems are reported as warnings.
const lintRule = {
  enforce: 'pre',
  ...appScriptFiles(),
  use: [
    {
      loader: 'eslint-loader',
      options: { cache: true, emitWarning: true },
    },
  ],
};

// Transpile first-party sources with Babel, caching results on disk.
const transpileRule = {
  ...appScriptFiles(),
  use: [
    {
      loader: 'babel-loader',
      options: { cacheDirectory: true },
    },
  ],
};

const cssRule = {
  test: /\.css$/,
  use: [
    // Environment config will prepend loader here
    { loader: 'css-loader' },
  ],
};

/** Base webpack configuration shared by the dev and prod builds. */
const config = {
  entry: path.resolve(Utils.ROOT_DIR, 'src/main.jsx'),
  output: {
    path: Utils.DIST_DIR,
    filename: 'scripts/[name].js',
    publicPath: '/',
  },
  resolve: {
    extensions: ['.js', '.jsx'],
  },
  module: {
    rules: [lintRule, transpileRule, cssRule],
  },
  plugins: [
    new CleanWebpackPlugin([Utils.DIST_DIR], { root: Utils.ROOT_DIR }),
    new HtmlWebpackPlugin({
      ...Utils.generateHtmlWebpackPluginConfig(),
    }),
  ],
  optimization: {
    // One shared runtime chunk, plus every node_modules import grouped into "vendors".
    runtimeChunk: 'single',
    splitChunks: {
      cacheGroups: {
        vendor: {
          test: /[\\/]node_modules[\\/]/,
          name: 'vendors',
          chunks: 'all',
        },
      },
    },
  },
};

module.exports = config;
76 |
--------------------------------------------------------------------------------
/webpack/webpack.dev.js:
--------------------------------------------------------------------------------
1 | const BundleAnalyzerPlugin = require('webpack-bundle-analyzer').BundleAnalyzerPlugin;
2 | const WebpackCdnPlugin = require('webpack-cdn-plugin');
3 | const merge = require('webpack-merge');
4 | const path = require('path');
5 | const Common = require('./webpack.common.js');
6 | const Utils = require('./utils');
7 |
8 | module.exports = merge.smartStrategy({
9 | 'module.rules.use': 'prepend',
10 | })(Common, {
11 | mode: 'development',
12 | devtool: 'source-map',
13 | devServer: {
14 | contentBase: [Utils.PUBLIC_DIR, path.resolve(Utils.ROOT_DIR, 'node_modules')],
15 | overlay: {
16 | warnings: false,
17 | errors: true,
18 | },
19 | port: 9000,
20 | },
21 | module: {
22 | rules: [
23 | {
24 | test: /\.css$/,
25 | use: [
26 | {
27 | loader: 'style-loader',
28 | options: {
29 | sourceMap: true,
30 | },
31 | },
32 | {
33 | loader: 'css-loader',
34 | options: {
35 | sourceMap: true,
36 | },
37 | },
38 | ],
39 | },
40 | ],
41 | },
42 | plugins: [
43 | new WebpackCdnPlugin({
44 | modules: [
45 | { name: 'react', var: 'React', path: 'umd/react.development.js' },
46 | { name: 'react-dom', var: 'ReactDOM', path: 'umd/react-dom.development.js' },
47 | { name: 'react-redux', var: 'ReactRedux', path: 'dist/react-redux.js' },
48 | { name: 'redux-saga', var: 'ReduxSaga', path: 'dist/redux-saga.js' },
49 | { name: '@tensorflow/tfjs', var: 'tf', path: 'dist/tf.js' },
50 | { name: 'bizcharts', var: 'BizCharts', path: 'umd/BizCharts.js' },
51 | ],
52 | prod: false,
53 | }),
54 | new BundleAnalyzerPlugin({
55 | openAnalyzer: false,
56 | }),
57 | ],
58 | });
59 |
--------------------------------------------------------------------------------
/webpack/webpack.prod.js:
--------------------------------------------------------------------------------
1 | const AsyncStylesheetWebpackPlugin = require('async-stylesheet-webpack-plugin');
2 | const CopyWebpackPlugin = require('copy-webpack-plugin');
3 | const HtmlWebpackPlugin = require('html-webpack-plugin');
4 | const InlineSourcePlugin = require('html-webpack-inline-source-plugin');
5 | const MiniCssExtractPlugin = require('mini-css-extract-plugin');
6 | const OptimizeCSSAssetsPlugin = require('optimize-css-assets-webpack-plugin');
7 | const UglifyJsPlugin = require('uglifyjs-webpack-plugin');
8 | const WebpackCdnPlugin = require('webpack-cdn-plugin');
9 | const merge = require('webpack-merge');
10 | const webpack = require('webpack');
11 | const Common = require('./webpack.common.js');
12 | const Utils = require('./utils');
13 |
14 | module.exports = merge.smartStrategy({
15 | 'module.rules.use': 'prepend',
16 | })(Common, {
17 | mode: 'production',
18 | devtool: false,
19 | output: {
20 | filename: 'assets/[name].[contenthash].js',
21 | },
22 | module: {
23 | rules: [
24 | {
25 | test: /\.(js|jsx)$/,
26 | include: /src/,
27 | use: [
28 | {
29 | loader: 'babel-loader',
30 | options: {
31 | forceEnv: 'production',
32 | },
33 | },
34 | ],
35 | },
36 | {
37 | test: /\.css$/,
38 | use: [
39 | {
40 | loader: MiniCssExtractPlugin.loader,
41 | },
42 | ],
43 | },
44 | ],
45 | },
46 | plugins: [
47 | new webpack.HashedModuleIdsPlugin(),
48 | new CopyWebpackPlugin([{ from: Utils.PUBLIC_DIR, to: Utils.DIST_DIR }]),
49 | new MiniCssExtractPlugin({
50 | filename: 'styles/[name].[contenthash].css',
51 | }),
52 | new HtmlWebpackPlugin({
53 | ...Utils.generateHtmlWebpackPluginConfig(),
54 | inlineSource: 'runtime.+\\.js',
55 | }),
56 | new InlineSourcePlugin(),
57 | new AsyncStylesheetWebpackPlugin({
58 | preloadPolyfill: true,
59 | }),
60 | new WebpackCdnPlugin({
61 | modules: [
62 | { name: 'react', var: 'React', path: 'umd/react.production.min.js' },
63 | { name: 'react-dom', var: 'ReactDOM', path: 'umd/react-dom.production.min.js' },
64 | { name: 'styled-components', var: 'styled', path: 'dist/styled-components.min.js' },
65 | { name: 'redux', var: 'Redux', path: 'dist/redux.min.js' },
66 | { name: 'react-redux', var: 'ReactRedux', path: 'dist/react-redux.min.js' },
67 | { name: 'redux-saga', var: 'ReduxSaga', path: 'dist/redux-saga.min.js' },
68 | { name: '@tensorflow/tfjs', var: 'tf', path: 'dist/tf.min.js' },
69 | { name: 'bizcharts', var: 'BizCharts', path: 'umd/BizCharts.min.js' },
70 | ],
71 | prod: true,
72 | }),
73 | ],
74 | optimization: {
75 | noEmitOnErrors: true,
76 | minimizer: [
77 | new UglifyJsPlugin({
78 | cache: true,
79 | parallel: true,
80 | uglifyOptions: {
81 | output: {
82 | comments: false,
83 | },
84 | },
85 | }),
86 | new OptimizeCSSAssetsPlugin({}),
87 | ],
88 | },
89 | performance: {
90 | maxEntrypointSize: 350000,
91 | },
92 | });
93 |
--------------------------------------------------------------------------------