├── .gitignore ├── .travis.yml ├── README.md ├── package-lock.json ├── package.json ├── public ├── favicon.ico ├── index.html └── manifest.json └── src ├── App.css ├── App.js ├── index.css ├── index.js ├── neuralNetwork.js └── registerServiceWorker.js /.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/ignore-files/ for more about ignoring files. 2 | 3 | # dependencies 4 | /node_modules 5 | 6 | # testing 7 | /coverage 8 | 9 | # production 10 | /build 11 | 12 | # misc 13 | .DS_Store 14 | .env.local 15 | .env.development.local 16 | .env.test.local 17 | .env.production.local 18 | 19 | npm-debug.log* 20 | yarn-debug.log* 21 | yarn-error.log* 22 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: node_js 2 | 3 | node_js: 4 | - stable 5 | 6 | install: 7 | - npm install 8 | 9 | script: 10 | - npm test -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MNIST Digit Recognition Neural Network in JavaScript with Deeplearn.js 2 | 3 | [![Build Status](https://travis-ci.org/javascript-machine-learning/mnist-neural-network-deeplearnjs.svg?branch=master)](https://travis-ci.org/javascript-machine-learning/mnist-neural-network-deeplearnjs) 4 | 5 | Use Case: Recognizing handwritten digits from the [MNIST Database](https://en.wikipedia.org/wiki/MNIST_database). 6 | 7 | This example project demonstrates how neural networks may be used to solve a multi-class classification problem. It uses [deeplearn.js](https://deeplearnjs.org/) to recognize handwritten digits from the MNIST database. 
8 | 9 | ![dec-14-2017 10-56-04](https://user-images.githubusercontent.com/2479967/33973368-74405dc0-e0bd-11e7-929f-d8a8b9aab55f.gif) 10 | 11 | ## Installation 12 | 13 | * `git clone git@github.com:javascript-machine-learning/mnist-neural-network-deeplearnjs.git` 14 | * `cd mnist-neural-network-deeplearnjs` 15 | * `npm install` 16 | * `npm start` 17 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "mnist-neural-network-deeplearnjs", 3 | "version": "0.1.0", 4 | "private": true, 5 | "dependencies": { 6 | "deeplearn": "0.3.12", 7 | "mnist": "^1.1.0", 8 | "react": "^16.1.1", 9 | "react-dom": "^16.1.1", 10 | "react-scripts": "1.0.17" 11 | }, 12 | "scripts": { 13 | "start": "react-scripts start", 14 | "build": "react-scripts build", 15 | "test": "react-scripts test --env=jsdom", 16 | "eject": "react-scripts eject" 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /public/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/javascript-machine-learning/mnist-neural-network-deeplearnjs/bf3bae4da20d2aaf8830ffeff0a4aa1beb271246/public/favicon.ico -------------------------------------------------------------------------------- /public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 11 | 12 | 13 | 22 | React App 23 | 24 | 25 | 28 |
29 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /public/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "short_name": "React App", 3 | "name": "Create React App Sample", 4 | "icons": [ 5 | { 6 | "src": "favicon.ico", 7 | "sizes": "64x64 32x32 24x24 16x16", 8 | "type": "image/x-icon" 9 | } 10 | ], 11 | "start_url": "./index.html", 12 | "display": "standalone", 13 | "theme_color": "#000000", 14 | "background_color": "#ffffff" 15 | } 16 | -------------------------------------------------------------------------------- /src/App.css: -------------------------------------------------------------------------------- 1 | .app { 2 | display: flex; 3 | flex-direction: column; 4 | justify-content: center; 5 | align-items: center; 6 | text-align: center; 7 | } 8 | 9 | .test-example-list { 10 | display: flex; 11 | justify-content: center; 12 | flex-wrap: wrap; 13 | max-width: 700px; 14 | } 15 | 16 | .test-example-item { 17 | display: flex; 18 | border: 1px solid #000000; 19 | margin-right: 5px; 20 | margin-bottom: 5px; 21 | } 22 | 23 | .test-example-item-prediction { 24 | color: #FFFFFF; 25 | height: 56px; 26 | width: 56px; 27 | } 28 | 29 | .prediction-digit { 30 | margin: 3px; 31 | font-size: 28px; 32 | } 33 | 34 | .prediction-probability { 35 | font-size: 8px; 36 | } 37 | 38 | .pixel-row { 39 | display: flex; 40 | } 41 | 42 | .pixel { 43 | width: 2px; 44 | height: 2px; 45 | } -------------------------------------------------------------------------------- /src/App.js: -------------------------------------------------------------------------------- 1 | import React, { Component } from 'react'; 2 | import mnist from 'mnist'; 3 | 4 | import './App.css'; 5 | 6 | import MnistModel from './neuralNetwork'; 7 | 8 | const ITERATIONS = 750; 9 | const TRAINING_SET_SIZE = 3000; 10 | const TEST_SET_SIZE = 50; 11 | 12 | class App extends Component { 13 | 14 | testSet; 15 | 
trainingSet; 16 | mnistModel; 17 | 18 | constructor() { 19 | super(); 20 | 21 | const { training, test } = mnist.set(TRAINING_SET_SIZE, TEST_SET_SIZE); 22 | this.testSet = test; 23 | this.trainingSet = training; 24 | 25 | this.mnistModel = new MnistModel(); 26 | this.mnistModel.setupSession(this.trainingSet); 27 | 28 | this.state = { 29 | currentIteration: 0, 30 | cost: -1, 31 | }; 32 | } 33 | 34 | componentDidMount () { 35 | requestAnimationFrame(this.tick); 36 | }; 37 | 38 | tick = () => { 39 | this.setState((state) => ({ currentIteration: state.currentIteration + 1 })); 40 | 41 | if (this.state.currentIteration < ITERATIONS) { 42 | requestAnimationFrame(this.tick); 43 | 44 | let computeCost = !(this.state.currentIteration % 5); 45 | let cost = this.mnistModel.train(this.state.currentIteration, computeCost); 46 | 47 | if (cost > 0) { 48 | this.setState(() => ({ cost })); 49 | } 50 | } 51 | }; 52 | 53 | render() { 54 | const { currentIteration, cost } = this.state; 55 | return ( 56 |
57 |
58 |

Neural Network for MNIST Digit Recognition in JavaScript

59 |

Iterations: {currentIteration}

60 |

Cost: {cost.toFixed(3)}

61 |
62 | 63 | 67 |
68 | ); 69 | } 70 | } 71 | 72 | const TestExamples = ({ model, testSet }) => 73 |
74 | {Array(TEST_SET_SIZE).fill(0).map((v, i) => 75 | 81 | )} 82 |
83 | 84 | const TestExampleItem = ({ model, input, output }) => 85 |
86 | 89 | 90 | 94 |
95 | 96 | class MnistDigit extends Component { 97 | shouldComponentUpdate() { 98 | return false; 99 | } 100 | 101 | render() { 102 | const { digitInput } = this.props; 103 | return ( 104 |
105 | {fromUnrolledToPartition(digitInput, 28).map((row, i) => 106 |
107 | {row.map((p, j) => 108 |
113 | )} 114 |
115 | )} 116 |
117 | ); 118 | } 119 | } 120 | 121 | const PredictedMnistDigit = ({ digitInput, digitOutput }) => { 122 | const digit = fromClassifierToDigit(digitInput); 123 | 124 | return ( 125 |
129 |
130 | {digit.number} 131 |
132 |
133 | p(x)={digit.probability.toFixed(2)} 134 |
135 |
136 | ); 137 | } 138 | 139 | const getColor = (output, digit) => 140 | fromClassifierToDigit(output).number === digit.number 141 | ? { backgroundColor: '#55AA55' } 142 | : { backgroundColor: '#D46A6A' } 143 | 144 | const fromClassifierToDigit = (classifier) => 145 | classifier.reduce(toNumber, { number: -1, probability: -1 }); 146 | 147 | const toNumber = (result, value, key) => { 148 | if (value > result.probability) { 149 | result = { number: key, probability: value }; 150 | } 151 | return result; 152 | }; 153 | 154 | const fromUnrolledToPartition = (digit, size) => 155 | digit.reduce(toPartition(size), []); 156 | 157 | const toPartition = (size) => (result, value, key) => { 158 | if (key % size === 0) { 159 | result.push([]); 160 | } 161 | 162 | result[result.length - 1].push(value); 163 | 164 | return result; 165 | }; 166 | 167 | const denormalizeAndColorize = (p) => 168 | compose( 169 | toColor, 170 | denormalize 171 | )(p); 172 | 173 | const denormalize = (p) => 174 | (p * 255).toFixed(0); 175 | 176 | const toColor = (colorChannel) => 177 | `rgb(${colorChannel}, ${colorChannel}, ${colorChannel})`; 178 | 179 | const compose = (...fns) => 180 | fns.reduce((f, g) => (...args) => f(g(...args))); 181 | 182 | export default App; 183 | -------------------------------------------------------------------------------- /src/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | margin: 0; 3 | padding: 0; 4 | font-family: sans-serif; 5 | } 6 | -------------------------------------------------------------------------------- /src/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import ReactDOM from 'react-dom'; 3 | import './index.css'; 4 | import App from './App'; 5 | import registerServiceWorker from './registerServiceWorker'; 6 | 7 | ReactDOM.render(, document.getElementById('root')); 8 | registerServiceWorker(); 9 | 
-------------------------------------------------------------------------------- /src/neuralNetwork.js: -------------------------------------------------------------------------------- 1 | import { 2 | Array1D, 3 | InCPUMemoryShuffledInputProviderBuilder, 4 | Graph, 5 | Session, 6 | SGDOptimizer, 7 | NDArrayMathGPU, 8 | CostReduction, 9 | } from 'deeplearn'; 10 | 11 | const math = new NDArrayMathGPU(); 12 | 13 | class MnistModel { 14 | session; 15 | 16 | initialLearningRate = 0.06; 17 | optimizer; 18 | 19 | batchSize = 300; 20 | 21 | inputTensor; 22 | targetTensor; 23 | costTensor; 24 | predictionTensor; 25 | 26 | feedEntries; 27 | 28 | constructor() { 29 | this.optimizer = new SGDOptimizer(this.initialLearningRate); 30 | } 31 | 32 | setupSession(trainingSet) { 33 | const graph = new Graph(); 34 | 35 | this.inputTensor = graph.placeholder('input unrolled pixels', [784]); 36 | this.targetTensor = graph.placeholder('output digit classifier', [10]); 37 | 38 | let fullyConnectedLayer = this.createFullyConnectedLayer(graph, this.inputTensor, 0, 64); 39 | fullyConnectedLayer = this.createFullyConnectedLayer(graph, fullyConnectedLayer, 1, 32); 40 | fullyConnectedLayer = this.createFullyConnectedLayer(graph, fullyConnectedLayer, 2, 16); 41 | 42 | this.predictionTensor = this.createFullyConnectedLayer(graph, fullyConnectedLayer, 3, 10); 43 | this.costTensor = graph.meanSquaredCost(this.targetTensor, this.predictionTensor); 44 | 45 | this.session = new Session(graph, math); 46 | 47 | this.prepareTrainingSet(trainingSet); 48 | } 49 | 50 | prepareTrainingSet(trainingSet) { 51 | math.scope(() => { 52 | const inputArray = trainingSet.map(v => Array1D.new(v.input)); 53 | const targetArray = trainingSet.map(v => Array1D.new(v.output)); 54 | 55 | const shuffledInputProviderBuilder = new InCPUMemoryShuffledInputProviderBuilder([ inputArray, targetArray ]); 56 | const [ inputProvider, targetProvider ] = shuffledInputProviderBuilder.getInputProviders(); 57 | 58 | this.feedEntries = [ 
59 | { tensor: this.inputTensor, data: inputProvider }, 60 | { tensor: this.targetTensor, data: targetProvider }, 61 | ]; 62 | }); 63 | } 64 | 65 | train(step, computeCost) { 66 | let learningRate = this.initialLearningRate * Math.pow(0.90, Math.floor(step / 50)); 67 | this.optimizer.setLearningRate(learningRate); 68 | 69 | let costValue; 70 | math.scope(() => { 71 | const cost = this.session.train( 72 | this.costTensor, 73 | this.feedEntries, 74 | this.batchSize, 75 | this.optimizer, 76 | computeCost ? CostReduction.MEAN : CostReduction.NONE, 77 | ); 78 | 79 | if (computeCost) { 80 | costValue = cost.get(); 81 | } 82 | }); 83 | 84 | return costValue; 85 | } 86 | 87 | predict(pixels) { 88 | let classifier = []; 89 | 90 | math.scope(() => { 91 | const mapping = [{ 92 | tensor: this.inputTensor, 93 | data: Array1D.new(pixels), 94 | }]; 95 | 96 | classifier = this.session.eval(this.predictionTensor, mapping).getValues(); 97 | }); 98 | 99 | return [ ...classifier ]; 100 | } 101 | 102 | createFullyConnectedLayer( 103 | graph, 104 | inputLayer, 105 | layerIndex, 106 | units, 107 | activationFunction 108 | ) { 109 | return graph.layers.dense( 110 | `fully_connected_${layerIndex}`, 111 | inputLayer, 112 | units, 113 | activationFunction 114 | ? activationFunction 115 | : (x) => graph.relu(x) 116 | ); 117 | } 118 | } 119 | 120 | export default MnistModel; -------------------------------------------------------------------------------- /src/registerServiceWorker.js: -------------------------------------------------------------------------------- 1 | // In production, we register a service worker to serve assets from local cache. 2 | 3 | // This lets the app load faster on subsequent visits in production, and gives 4 | // it offline capabilities. However, it also means that developers (and users) 5 | // will only see deployed updates on the "N+1" visit to a page, since previously 6 | // cached resources are updated in the background. 
// To learn more about the benefits of this model, read https://goo.gl/KwvDNy.
// This link also includes instructions on opting out of this behavior.

// True when the page is served from a loopback host: localhost, [::1], or 127.0.0.0/8.
const isLocalhost = Boolean(
  window.location.hostname === 'localhost' ||
    // [::1] is the IPv6 localhost address.
    window.location.hostname === '[::1]' ||
    // 127.0.0.1/8 is considered localhost for IPv4.
    window.location.hostname.match(
      /^127(?:\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)){3}$/
    )
);

/**
 * Registers the service worker after the page's `load` event, but only in
 * production builds, in browsers that expose `navigator.serviceWorker`, and
 * when PUBLIC_URL is same-origin with the page. On localhost the worker
 * script is validated first (see checkValidServiceWorker).
 */
export default function register() {
  if (process.env.NODE_ENV === 'production' && 'serviceWorker' in navigator) {
    // The URL constructor is available in all browsers that support SW.
    const publicUrl = new URL(process.env.PUBLIC_URL, window.location);
    if (publicUrl.origin !== window.location.origin) {
      // Our service worker won't work if PUBLIC_URL is on a different origin
      // from what our page is served on. This might happen if a CDN is used to
      // serve assets; see https://github.com/facebookincubator/create-react-app/issues/2374
      return;
    }

    window.addEventListener('load', () => {
      const swUrl = `${process.env.PUBLIC_URL}/service-worker.js`;

      if (isLocalhost) {
        // This is running on localhost. Let's check if a service worker still exists or not.
        checkValidServiceWorker(swUrl);
      } else {
        // Not localhost: just register the service worker.
        registerValidSW(swUrl);
      }
    });
  }
}

/**
 * Registers the worker at `swUrl`. When a newly found worker reaches the
 * `installed` state, logs whether it refreshed previously cached content
 * (a controller already exists) or precached the app for offline use.
 * Registration failures are logged with console.error.
 *
 * @param {string} swUrl - URL of the service worker script.
 */
function registerValidSW(swUrl) {
  navigator.serviceWorker
    .register(swUrl)
    .then(registration => {
      registration.onupdatefound = () => {
        const installingWorker = registration.installing;
        // `installing` is nullable in the ServiceWorkerRegistration API;
        // there is no worker to observe in that case.
        if (installingWorker == null) {
          return;
        }
        installingWorker.onstatechange = () => {
          if (installingWorker.state === 'installed') {
            if (navigator.serviceWorker.controller) {
              // At this point, the old content will have been purged and
              // the fresh content will have been added to the cache.
              // It's the perfect time to display a "New content is
              // available; please refresh." message in your web app.
              console.log('New content is available; please refresh.');
            } else {
              // At this point, everything has been precached.
              // It's the perfect time to display a
              // "Content is cached for offline use." message.
              console.log('Content is cached for offline use.');
            }
          }
        };
      };
    })
    .catch(error => {
      console.error('Error during service worker registration:', error);
    });
}

/**
 * Fetches the worker script at `swUrl`. If the server returns 404, or a
 * Content-Type that does not mention "javascript", unregisters the active
 * worker and reloads the page. Otherwise registers the worker normally.
 * A failed fetch is treated as offline mode and only logged.
 *
 * @param {string} swUrl - URL of the service worker script.
 */
function checkValidServiceWorker(swUrl) {
  // Check if the service worker can be found. If it can't, reload the page.
  fetch(swUrl)
    .then(response => {
      // Headers.get() returns null when the header is absent. Calling
      // .indexOf on null would throw, and the .catch below would then
      // misreport the failure as "No internet connection found".
      const contentType = response.headers.get('content-type');
      // Ensure service worker exists, and that we really are getting a JS file.
      if (
        response.status === 404 ||
        (contentType != null && contentType.indexOf('javascript') === -1)
      ) {
        // No service worker found. Probably a different app. Reload the page.
        navigator.serviceWorker.ready.then(registration => {
          registration.unregister().then(() => {
            window.location.reload();
          });
        });
      } else {
        // Service worker found. Proceed as normal.
        registerValidSW(swUrl);
      }
    })
    .catch(() => {
      console.log(
        'No internet connection found. App is running in offline mode.'
      );
    });
}

/**
 * Unregisters the active service worker, if the browser supports them.
 */
export function unregister() {
  if ('serviceWorker' in navigator) {
    navigator.serviceWorker.ready.then(registration => {
      registration.unregister();
    });
  }
}