├── .gitignore
├── .travis.yml
├── README.md
├── package-lock.json
├── package.json
├── public
├── favicon.ico
├── index.html
└── manifest.json
└── src
├── App.css
├── App.js
├── index.css
├── index.js
├── neuralNetwork.js
└── registerServiceWorker.js
/.gitignore:
--------------------------------------------------------------------------------
1 | # See https://help.github.com/ignore-files/ for more about ignoring files.
2 |
3 | # dependencies
4 | /node_modules
5 |
6 | # testing
7 | /coverage
8 |
9 | # production
10 | /build
11 |
12 | # misc
13 | .DS_Store
14 | .env.local
15 | .env.development.local
16 | .env.test.local
17 | .env.production.local
18 |
19 | npm-debug.log*
20 | yarn-debug.log*
21 | yarn-error.log*
22 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: node_js
2 |
3 | node_js:
4 | - stable
5 |
6 | install:
7 | - npm install
8 |
9 | script:
10 | - npm test
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MNIST Digit Recognition Neural Network in JavaScript with Deeplearn.js
2 |
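3 | [![Build Status](https://travis-ci.org/javascript-machine-learning/mnist-neural-network-deeplearnjs.svg?branch=master)](https://travis-ci.org/javascript-machine-learning/mnist-neural-network-deeplearnjs)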
4 |
5 | Use Case: Recognizing handwritten digits from the [MNIST Database](https://en.wikipedia.org/wiki/MNIST_database).
6 |
7 | This example project demonstrates how neural networks may be used to solve a multi-class classification problem. It uses [deeplearn.js](https://deeplearnjs.org/) to recognize handwritten digits from the MNIST database.
8 |
9 | 
10 |
11 | ## Installation
12 |
13 | * `git clone git@github.com:javascript-machine-learning/mnist-neural-network-deeplearnjs.git`
14 | * `cd mnist-neural-network-deeplearnjs`
15 | * `npm install`
16 | * `npm start`
17 |
--------------------------------------------------------------------------------
/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "mnist-neural-network-deeplearnjs",
3 | "version": "0.1.0",
4 | "private": true,
5 | "dependencies": {
6 | "deeplearn": "0.3.12",
7 | "mnist": "^1.1.0",
8 | "react": "^16.1.1",
9 | "react-dom": "^16.1.1",
10 | "react-scripts": "1.0.17"
11 | },
12 | "scripts": {
13 | "start": "react-scripts start",
14 | "build": "react-scripts build",
15 | "test": "react-scripts test --env=jsdom",
16 | "eject": "react-scripts eject"
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/public/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/javascript-machine-learning/mnist-neural-network-deeplearnjs/bf3bae4da20d2aaf8830ffeff0a4aa1beb271246/public/favicon.ico
--------------------------------------------------------------------------------
/public/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
11 |
12 |
13 |
22 | React App
23 |
24 |
25 |
28 |
29 |
39 |
40 |
41 |
--------------------------------------------------------------------------------
/public/manifest.json:
--------------------------------------------------------------------------------
1 | {
2 | "short_name": "React App",
3 | "name": "Create React App Sample",
4 | "icons": [
5 | {
6 | "src": "favicon.ico",
7 | "sizes": "64x64 32x32 24x24 16x16",
8 | "type": "image/x-icon"
9 | }
10 | ],
11 | "start_url": "./index.html",
12 | "display": "standalone",
13 | "theme_color": "#000000",
14 | "background_color": "#ffffff"
15 | }
16 |
--------------------------------------------------------------------------------
/src/App.css:
--------------------------------------------------------------------------------
1 | .app {
2 | display: flex;
3 | flex-direction: column;
4 | justify-content: center;
5 | align-items: center;
6 | text-align: center;
7 | }
8 |
9 | .test-example-list {
10 | display: flex;
11 | justify-content: center;
12 | flex-wrap: wrap;
13 | max-width: 700px;
14 | }
15 |
16 | .test-example-item {
17 | display: flex;
18 | border: 1px solid #000000;
19 | margin-right: 5px;
20 | margin-bottom: 5px;
21 | }
22 |
23 | .test-example-item-prediction {
24 | color: #FFFFFF;
25 | height: 56px;
26 | width: 56px;
27 | }
28 |
29 | .prediction-digit {
30 | margin: 3px;
31 | font-size: 28px;
32 | }
33 |
34 | .prediction-probability {
35 | font-size: 8px;
36 | }
37 |
38 | .pixel-row {
39 | display: flex;
40 | }
41 |
42 | .pixel {
43 | width: 2px;
44 | height: 2px;
45 | }
--------------------------------------------------------------------------------
/src/App.js:
--------------------------------------------------------------------------------
1 | import React, { Component } from 'react';
2 | import mnist from 'mnist';
3 |
4 | import './App.css';
5 |
6 | import MnistModel from './neuralNetwork';
7 |
8 | const ITERATIONS = 750;
9 | const TRAINING_SET_SIZE = 3000;
10 | const TEST_SET_SIZE = 50;
11 |
12 | class App extends Component {
13 |
14 | testSet;
15 | trainingSet;
16 | mnistModel;
17 |
18 | constructor() {
19 | super();
20 |
21 | const { training, test } = mnist.set(TRAINING_SET_SIZE, TEST_SET_SIZE);
22 | this.testSet = test;
23 | this.trainingSet = training;
24 |
25 | this.mnistModel = new MnistModel();
26 | this.mnistModel.setupSession(this.trainingSet);
27 |
28 | this.state = {
29 | currentIteration: 0,
30 | cost: -1,
31 | };
32 | }
33 |
34 | componentDidMount () {
35 | requestAnimationFrame(this.tick);
36 | };
37 |
38 | tick = () => {
39 | this.setState((state) => ({ currentIteration: state.currentIteration + 1 }));
40 |
41 | if (this.state.currentIteration < ITERATIONS) {
42 | requestAnimationFrame(this.tick);
43 |
44 | let computeCost = !(this.state.currentIteration % 5);
45 | let cost = this.mnistModel.train(this.state.currentIteration, computeCost);
46 |
47 | if (cost > 0) {
48 | this.setState(() => ({ cost }));
49 | }
50 | }
51 | };
52 |
53 | render() {
54 | const { currentIteration, cost } = this.state;
55 | return (
56 |
57 |
58 |
Neural Network for MNIST Digit Recognition in JavaScript
59 |
Iterations: {currentIteration}
60 |
Cost: {cost.toFixed(3)}
61 |
62 |
63 |
67 |
68 | );
69 | }
70 | }
71 |
72 | const TestExamples = ({ model, testSet }) =>
73 |
74 | {Array(TEST_SET_SIZE).fill(0).map((v, i) =>
75 |
81 | )}
82 |
83 |
84 | const TestExampleItem = ({ model, input, output }) =>
85 |
95 |
96 | class MnistDigit extends Component {
97 | shouldComponentUpdate() {
98 | return false;
99 | }
100 |
101 | render() {
102 | const { digitInput } = this.props;
103 | return (
104 |
105 | {fromUnrolledToPartition(digitInput, 28).map((row, i) =>
106 |
107 | {row.map((p, j) =>
108 |
113 | )}
114 |
115 | )}
116 |
117 | );
118 | }
119 | }
120 |
121 | const PredictedMnistDigit = ({ digitInput, digitOutput }) => {
122 | const digit = fromClassifierToDigit(digitInput);
123 |
124 | return (
125 |
129 |
130 | {digit.number}
131 |
132 |
133 | p(x)={digit.probability.toFixed(2)}
134 |
135 |
136 | );
137 | }
138 |
/**
 * Picks the tile background: green when the argmax of `output` equals the
 * given digit's number, red otherwise.
 *
 * @param {ArrayLike<number>} output - Score vector whose argmax is the
 *   reference digit (presumably the example's one-hot label — confirm at call site).
 * @param {{number: number}} digit - Digit to compare against.
 * @returns {{backgroundColor: string}} Inline style object.
 */
const getColor = (output, digit) =>
  fromClassifierToDigit(output).number === digit.number
    ? { backgroundColor: '#55AA55' }
    : { backgroundColor: '#D46A6A' };

/**
 * Returns the index (`number`) and value (`probability`) of the largest
 * score in `classifier`. Ties resolve to the lowest index.
 *
 * An empty or all-NaN input yields the sentinel `{ number: -1, probability: -1 }`.
 *
 * @param {ArrayLike<number>} classifier - Score vector (array or typed array).
 * @returns {{number: number, probability: number}}
 */
const fromClassifierToDigit = (classifier) => {
  // Seed with -Infinity so scores at or below -1 still win the comparison.
  // A -1 seed silently reports "no digit" for such inputs.
  const best = classifier.reduce(toNumber, { number: -1, probability: -Infinity });
  return best.number === -1 ? { number: -1, probability: -1 } : best;
};

/**
 * Argmax reducer step: keeps `result` unless `value` is strictly larger, in
 * which case `key` becomes the new best index.
 */
const toNumber = (result, value, key) =>
  value > result.probability ? { number: key, probability: value } : result;
153 |
/**
 * Folds a flat, row-major pixel sequence into consecutive rows of `size`
 * values. For example, it turns a 784-value vector into 28 rows of 28 when
 * `size` is 28. A trailing partial row is kept as-is.
 *
 * @param {ArrayLike<number>} digit - Flat pixel values.
 * @param {number} size - Number of values per row.
 * @returns {Array<Array<number>>} Rows of values.
 */
const fromUnrolledToPartition = (digit, size) => {
  const rows = [];
  const appendPixel = toPartition(size);
  digit.forEach((pixel, index) => {
    appendPixel(rows, pixel, index);
  });
  return rows;
};

/**
 * Builds a reducer-shaped step that appends `value` to the last row of
 * `rows`. A fresh row is opened whenever `index` is a multiple of `size`.
 * Returns the same `rows` array.
 */
const toPartition = (size) => (rows, value, index) => {
  const opensNewRow = index % size === 0;
  if (opensNewRow) {
    rows.push([]);
  }
  const lastRow = rows[rows.length - 1];
  lastRow.push(value);
  return rows;
};
166 |
/**
 * Maps a normalized intensity (nominally 0..1) to a grey CSS `rgb(...)` color.
 */
const denormalizeAndColorize = (p) => toColor(denormalize(p));

/**
 * Scales a normalized intensity to the 0..255 range. Returns the value as a
 * string rounded to an integer via `toFixed(0)`. No clamping is applied.
 */
const denormalize = (p) => {
  const scaled = p * 255;
  return scaled.toFixed(0);
};

/**
 * Builds a CSS grey color using `colorChannel` for red, green and blue.
 */
const toColor = (colorChannel) =>
  `rgb(${colorChannel}, ${colorChannel}, ${colorChannel})`;

/**
 * Right-to-left function composition: `compose(f, g)(x) === f(g(x))`.
 * A single function is returned unchanged. Calling it with no functions
 * throws a TypeError (reduce of an empty array).
 */
const compose = (...fns) =>
  fns.reduce((outer, inner) => function composed(...args) {
    return outer(inner(...args));
  });
182 | export default App;
183 |
--------------------------------------------------------------------------------
/src/index.css:
--------------------------------------------------------------------------------
1 | body {
2 | margin: 0;
3 | padding: 0;
4 | font-family: sans-serif;
5 | }
6 |
--------------------------------------------------------------------------------
/src/index.js:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import ReactDOM from 'react-dom';
3 | import './index.css';
4 | import App from './App';
5 | import registerServiceWorker from './registerServiceWorker';
6 |
7 | ReactDOM.render(, document.getElementById('root'));
8 | registerServiceWorker();
9 |
--------------------------------------------------------------------------------
/src/neuralNetwork.js:
--------------------------------------------------------------------------------
1 | import {
2 | Array1D,
3 | InCPUMemoryShuffledInputProviderBuilder,
4 | Graph,
5 | Session,
6 | SGDOptimizer,
7 | NDArrayMathGPU,
8 | CostReduction,
9 | } from 'deeplearn';
10 |
11 | const math = new NDArrayMathGPU();
12 |
/**
 * Fully connected MNIST classifier built on the deeplearn.js Graph/Session API.
 *
 * Topology: 784 input values -> dense 64 -> dense 32 -> dense 16 -> dense 10.
 * Every dense layer uses ReLU unless another activation is supplied. Training
 * uses SGD on a mean-squared-error cost.
 */
class MnistModel {
  // deeplearn.js Session over the graph; created in setupSession().
  session;

  // Starting SGD learning rate; train() multiplies it by 0.9 every 50 steps.
  initialLearningRate = 0.06;
  optimizer;

  // Examples consumed per Session.train() call.
  batchSize = 300;

  // Graph tensors wired up in setupSession().
  inputTensor;
  targetTensor;
  costTensor;
  predictionTensor;

  // Placeholder/provider pairs fed to Session.train(); built in prepareTrainingSet().
  feedEntries;

  constructor() {
    this.optimizer = new SGDOptimizer(this.initialLearningRate);
  }

  /**
   * Builds the graph, its cost tensor and the session, then prepares the
   * shuffled training feed.
   *
   * @param {Array<{input: *, output: *}>} trainingSet - Items whose `input`
   *   and `output` are passed to Array1D.new.
   */
  setupSession(trainingSet) {
    const graph = new Graph();

    this.inputTensor = graph.placeholder('input unrolled pixels', [784]);
    this.targetTensor = graph.placeholder('output digit classifier', [10]);

    // Hidden layers are named fully_connected_0..2 in creation order.
    const hiddenUnits = [64, 32, 16];
    const lastHiddenLayer = hiddenUnits.reduce(
      (previousLayer, units, layerIndex) =>
        this.createFullyConnectedLayer(graph, previousLayer, layerIndex, units),
      this.inputTensor,
    );

    this.predictionTensor = this.createFullyConnectedLayer(
      graph,
      lastHiddenLayer,
      hiddenUnits.length,
      10,
    );
    this.costTensor = graph.meanSquaredCost(this.targetTensor, this.predictionTensor);

    this.session = new Session(graph, math);

    this.prepareTrainingSet(trainingSet);
  }

  /**
   * Wraps every training example in Array1D values and builds the shuffled
   * input/target providers used as feed entries for training.
   */
  prepareTrainingSet(trainingSet) {
    math.scope(() => {
      const inputs = trainingSet.map(({ input }) => Array1D.new(input));
      const targets = trainingSet.map(({ output }) => Array1D.new(output));

      const providerBuilder = new InCPUMemoryShuffledInputProviderBuilder([inputs, targets]);
      const [inputProvider, targetProvider] = providerBuilder.getInputProviders();

      this.feedEntries = [
        { tensor: this.inputTensor, data: inputProvider },
        { tensor: this.targetTensor, data: targetProvider },
      ];
    });
  }

  /**
   * Runs one training batch.
   *
   * @param {number} step - Iteration counter used for learning-rate decay.
   * @param {boolean} computeCost - Whether to reduce and read back the cost.
   * @returns {number|undefined} Mean cost when `computeCost` is truthy,
   *   otherwise undefined.
   */
  train(step, computeCost) {
    // Step decay: scale by 0.9 once per completed block of 50 steps.
    const decayBlocks = Math.floor(step / 50);
    this.optimizer.setLearningRate(this.initialLearningRate * Math.pow(0.90, decayBlocks));

    const reduction = computeCost ? CostReduction.MEAN : CostReduction.NONE;

    let costValue;
    math.scope(() => {
      const batchCost = this.session.train(
        this.costTensor,
        this.feedEntries,
        this.batchSize,
        this.optimizer,
        reduction,
      );
      if (computeCost) {
        costValue = batchCost.get();
      }
    });
    return costValue;
  }

  /**
   * Evaluates the prediction tensor for one input.
   *
   * @param {*} pixels - Value passed to Array1D.new as the input placeholder data.
   * @returns {number[]} Plain-array copy of the output values.
   */
  predict(pixels) {
    let scores = [];

    math.scope(() => {
      const feed = [{ tensor: this.inputTensor, data: Array1D.new(pixels) }];
      scores = this.session.eval(this.predictionTensor, feed).getValues();
    });

    return [...scores];
  }

  /**
   * Adds a dense layer named `fully_connected_<layerIndex>` to `graph`.
   * The activation falls back to ReLU when `activationFunction` is falsy.
   */
  createFullyConnectedLayer(graph, inputLayer, layerIndex, units, activationFunction) {
    const activation = activationFunction || ((x) => graph.relu(x));
    return graph.layers.dense(`fully_connected_${layerIndex}`, inputLayer, units, activation);
  }
}
119 |
120 | export default MnistModel;
--------------------------------------------------------------------------------
/src/registerServiceWorker.js:
--------------------------------------------------------------------------------
1 | // In production, we register a service worker to serve assets from local cache.
2 |
3 | // This lets the app load faster on subsequent visits in production, and gives
4 | // it offline capabilities. However, it also means that developers (and users)
5 | // will only see deployed updates on the "N+1" visit to a page, since previously
6 | // cached resources are updated in the background.
7 |
8 | // To learn more about the benefits of this model, read https://goo.gl/KwvDNy.
9 | // This link also includes instructions on opting out of this behavior.
10 |
// True when the page's hostname is "localhost", the IPv6 loopback "[::1]",
// or an IPv4 address in 127.0.0.0/8. Evaluated once at module load.
const isLocalhost = (() => {
  const { hostname } = window.location;
  if (hostname === 'localhost' || hostname === '[::1]') {
    return true;
  }
  return /^127(?:\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)){3}$/.test(hostname);
})();
20 |
/**
 * Registers the app's service worker, but only in production builds on
 * browsers that support service workers, and only when PUBLIC_URL shares
 * the page's origin.
 *
 * The work is deferred to the window `load` event. On localhost the worker
 * script is validated first; elsewhere it is registered directly.
 */
export default function register() {
  const canUseServiceWorker =
    process.env.NODE_ENV === 'production' && 'serviceWorker' in navigator;
  if (!canUseServiceWorker) {
    return;
  }

  // The URL constructor is available in all browsers that support SW.
  const publicUrl = new URL(process.env.PUBLIC_URL, window.location);
  const sameOrigin = publicUrl.origin === window.location.origin;
  if (!sameOrigin) {
    // Our service worker won't work if PUBLIC_URL is on a different origin
    // from what our page is served on. This might happen if a CDN is used to
    // serve assets; see https://github.com/facebookincubator/create-react-app/issues/2374
    return;
  }

  window.addEventListener('load', () => {
    const swUrl = `${process.env.PUBLIC_URL}/service-worker.js`;
    // On localhost, check the worker script still exists before registering.
    const setUpWorker = isLocalhost ? checkValidServiceWorker : registerValidSW;
    setUpWorker(swUrl);
  });
}
45 |
/**
 * Registers the service worker script at `swUrl`.
 *
 * When a newly installed worker reaches the "installed" state, it logs
 * whether fresh content is waiting (a controller already exists) or whether
 * everything was precached for offline use. Registration failures are logged.
 *
 * @param {string} swUrl - URL of the service worker script.
 */
function registerValidSW(swUrl) {
  navigator.serviceWorker
    .register(swUrl)
    .then(registration => {
      registration.onupdatefound = () => {
        const installingWorker = registration.installing;
        // `installing` may be null by the time this handler runs; without
        // this guard the assignment below throws a TypeError.
        if (installingWorker == null) {
          return;
        }
        installingWorker.onstatechange = () => {
          if (installingWorker.state !== 'installed') {
            return;
          }
          if (navigator.serviceWorker.controller) {
            // At this point, the old content will have been purged and
            // the fresh content will have been added to the cache.
            // It's the perfect time to display a "New content is
            // available; please refresh." message in your web app.
            console.log('New content is available; please refresh.');
          } else {
            // At this point, everything has been precached.
            // It's the perfect time to display a
            // "Content is cached for offline use." message.
            console.log('Content is cached for offline use.');
          }
        };
      };
    })
    .catch(error => {
      console.error('Error during service worker registration:', error);
    });
}
74 |
/**
 * Fetches the service worker script to confirm it still exists and is JavaScript.
 *
 * - If the fetch returns 404, or a Content-Type header is present but is not
 *   JavaScript, the current worker is unregistered and the page reloaded.
 * - Otherwise the worker is registered via registerValidSW().
 * - A failed fetch is logged as offline mode.
 *
 * @param {string} swUrl - URL of the service worker script.
 */
function checkValidServiceWorker(swUrl) {
  // Check if the service worker can be found. If it can't, reload the page.
  fetch(swUrl)
    .then(response => {
      // Headers.get() returns null when the header is absent. Calling a string
      // method on it would throw and be misreported below as "offline".
      const contentType = response.headers.get('content-type');
      const isNotJavaScript =
        contentType != null && !contentType.includes('javascript');

      // Ensure service worker exists, and that we really are getting a JS file.
      if (response.status === 404 || isNotJavaScript) {
        // No service worker found. Probably a different app. Reload the page.
        navigator.serviceWorker.ready.then(registration => {
          registration.unregister().then(() => {
            window.location.reload();
          });
        });
      } else {
        // Service worker found. Proceed as normal.
        registerValidSW(swUrl);
      }
    })
    .catch(() => {
      console.log(
        'No internet connection found. App is running in offline mode.'
      );
    });
}
101 |
/**
 * Unregisters the ready service worker registration, when the browser
 * supports service workers at all.
 */
export function unregister() {
  const supportsServiceWorker = 'serviceWorker' in navigator;
  if (!supportsServiceWorker) {
    return;
  }

  const { ready } = navigator.serviceWorker;
  ready.then(activeRegistration => {
    activeRegistration.unregister();
  });
}
109 |
--------------------------------------------------------------------------------