├── .eslintrc ├── .gitignore ├── LICENSE ├── Procfile ├── README.md ├── app.json ├── gulpfile.js ├── main.py ├── mnist ├── __init__.py ├── convolutional.py ├── data │ ├── convolutional.ckpt.data-00000-of-00001 │ ├── convolutional.ckpt.index │ ├── regression.ckpt.data-00000-of-00001 │ └── regression.ckpt.index ├── model.py └── regression.py ├── package.json ├── requirements.txt ├── runtime.txt ├── src └── js │ └── main.js ├── static ├── css │ └── bootstrap.min.css └── js │ └── jquery.min.js └── templates └── index.html /.eslintrc: -------------------------------------------------------------------------------- 1 | { 2 | "rules": { 3 | "indent": [ 4 | 2, 5 | 4 6 | ], 7 | "quotes": [ 8 | 2, 9 | "single" 10 | ], 11 | "linebreak-style": [ 12 | 2, 13 | "unix" 14 | ], 15 | "semi": [ 16 | 2, 17 | "always" 18 | ] 19 | }, 20 | "env": { 21 | "es6": true, 22 | "node": true, 23 | "browser": true 24 | }, 25 | "extends": "eslint:recommended" 26 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | venv 2 | *.pyc 3 | node_modules 4 | static/js/main.js 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Yoshihiro Sugi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial 
portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Procfile: -------------------------------------------------------------------------------- 1 | web: gunicorn main:app --log-file=- 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MNIST classification by TensorFlow # 2 | 3 | - [MNIST For ML Beginners](https://www.tensorflow.org/tutorials/mnist/beginners/) 4 | - [Deep MNIST for Experts](https://www.tensorflow.org/tutorials/mnist/pros/) 5 | 6 |  7 | 8 | ### Requirement ### 9 | 10 | - Python >=2.7 or >=3.4 11 | - TensorFlow >=1.0 12 | - Node >=6.9 13 | 14 | 15 | ### How to run ### 16 | 17 | $ pip install -r requirements.txt 18 | $ npm install 19 | $ gunicorn main:app --log-file=- 20 | 21 | 22 | ### Deploy to Heroku ### 23 | 24 | $ heroku apps:create [NAME] 25 | $ heroku buildpacks:add heroku/nodejs 26 | $ heroku buildpacks:add heroku/python 27 | $ git push heroku master 28 | 29 | or Heroku Button. 
/* Carried over verbatim from the repository dump (end of README.md and app.json; not part of gulpfile.js):
30 | 31 | [](https://heroku.com/deploy) 32 | -------------------------------------------------------------------------------- /app.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tensorflow-mnist", 3 | "buildpacks": [ 4 | { "url": "https://github.com/heroku/heroku-buildpack-nodejs" }, 5 | { "url": "https://github.com/heroku/heroku-buildpack-python" } 6 | ] 7 | } 8 | -------------------------------------------------------------------------------- /gulpfile.js: --------------------------------------------------------------------------------
*/
var gulp = require('gulp');
var babel = require('gulp-babel');
var sourcemaps = require('gulp-sourcemaps');
var uglify = require('gulp-uglify');

/**
 * build: transpile every script in src/js with Babel (es2015 preset), minify it
 * with UglifyJS and write the result, with an inline source map, to static/js.
 */
gulp.task('build', function() {
    return gulp.src('src/js/*.js')
        // sourcemaps.init() has to come before every transforming plugin:
        // only the steps between init() and write() are recorded, so placing
        // it after babel() produced a map that stopped at the transpiled
        // output instead of pointing back to the ES2015 source.
        .pipe(sourcemaps.init({ loadMaps: true }))
        .pipe(babel({ presets: ['es2015'] }))
        .pipe(uglify())
        .pipe(sourcemaps.write())
        .pipe(gulp.dest('static/js'));
});

/**
 * watch: re-run the build whenever a script in src/js changes.
 */
gulp.task('watch', function() {
    gulp.watch('src/js/*.js', ['build']);
});

// Running plain `gulp` (as the package.json postinstall hook does) builds once.
gulp.task('default', ['build']);
/* Carried over verbatim from the repository dump (beginning of main.py; it continues on the next line of the dump):
-------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from flask import Flask, jsonify, render_template, request 4 | 5 | from mnist import model 6 | 7 | 8 | x = tf.placeholder("float", [None, 784]) 9 | sess = tf.Session() 10 | 11 | # restore trained data 12 | with tf.variable_scope("regression"): 13 | y1, variables = model.regression(x) 14 | saver = tf.train.Saver(variables) 15 | saver.restore(sess, "mnist/data/regression.ckpt") 16 | 17 | 18 | with tf.variable_scope("convolutional"): 19 | keep_prob = tf.placeholder("float") 20 | y2, variables = model.convolutional(x, keep_prob) 21 | saver = tf.train.Saver(variables) 22 | saver.restore(sess,
*/
"mnist/data/convolutional.ckpt") 23 | 24 | 25 | def regression(input): 26 | return sess.run(y1, feed_dict={x: input}).flatten().tolist() 27 | 28 | 29 | def convolutional(input): 30 | return sess.run(y2, feed_dict={x: input, keep_prob: 1.0}).flatten().tolist() 31 | 32 | 33 | # webapp 34 | app = Flask(__name__) 35 | 36 | 37 | @app.route('/api/mnist', methods=['POST']) 38 | def mnist(): 39 | input = ((255 - np.array(request.json, dtype=np.uint8)) / 255.0).reshape(1, 784) 40 | output1 = regression(input) 41 | output2 = convolutional(input) 42 | return jsonify(results=[output1, output2]) 43 | 44 | 45 | @app.route('/') 46 | def main(): 47 | return render_template('index.html') 48 | 49 | 50 | if __name__ == '__main__': 51 | app.run() 52 | -------------------------------------------------------------------------------- /mnist/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sugyan/tensorflow-mnist/a98b50dc584fb3f56ca9181d441cf886f2b75ee3/mnist/__init__.py -------------------------------------------------------------------------------- /mnist/convolutional.py: -------------------------------------------------------------------------------- 1 | import os 2 | import model 3 | import tensorflow as tf 4 | 5 | from tensorflow.examples.tutorials.mnist import input_data 6 | data = input_data.read_data_sets("/tmp/data/", one_hot=True) 7 | 8 | # model 9 | with tf.variable_scope("convolutional"): 10 | x = tf.placeholder(tf.float32, [None, 784]) 11 | keep_prob = tf.placeholder(tf.float32) 12 | y, variables = model.convolutional(x, keep_prob) 13 | 14 | # train 15 | y_ = tf.placeholder(tf.float32, [None, 10]) 16 | cross_entropy = -tf.reduce_sum(y_ * tf.log(y)) 17 | train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 18 | correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) 19 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 20 | 21 | saver = 
tf.train.Saver(variables) 22 | with tf.Session() as sess: 23 | sess.run(tf.global_variables_initializer()) 24 | for i in range(20000): 25 | batch = data.train.next_batch(50) 26 | if i % 100 == 0: 27 | train_accuracy = accuracy.eval(feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0}) 28 | print("step %d, training accuracy %g" % (i, train_accuracy)) 29 | sess.run(train_step, feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) 30 | 31 | print(sess.run(accuracy, feed_dict={x: data.test.images, y_: data.test.labels, keep_prob: 1.0})) 32 | 33 | path = saver.save( 34 | sess, os.path.join(os.path.dirname(__file__), 'data', 'convolutional.ckpt'), 35 | write_meta_graph=False, write_state=False) 36 | print("Saved:", path) 37 | -------------------------------------------------------------------------------- /mnist/data/convolutional.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sugyan/tensorflow-mnist/a98b50dc584fb3f56ca9181d441cf886f2b75ee3/mnist/data/convolutional.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /mnist/data/convolutional.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sugyan/tensorflow-mnist/a98b50dc584fb3f56ca9181d441cf886f2b75ee3/mnist/data/convolutional.ckpt.index -------------------------------------------------------------------------------- /mnist/data/regression.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sugyan/tensorflow-mnist/a98b50dc584fb3f56ca9181d441cf886f2b75ee3/mnist/data/regression.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /mnist/data/regression.ckpt.index: -------------------------------------------------------------------------------- 
# Carried over verbatim from the repository dump (checkpoint index entry; not part of model.py):
# https://raw.githubusercontent.com/sugyan/tensorflow-mnist/a98b50dc584fb3f56ca9181d441cf886f2b75ee3/mnist/data/regression.ckpt.index -------------------------------------------------------------------------------- /mnist/model.py: --------------------------------------------------------------------------------
import tensorflow as tf


def regression(x):
    """Softmax regression model.

    x is multiplied by a [784, 10] weight matrix, so it must have 784 columns.
    Returns the softmax output and the list of variables [W, b] to save/restore.
    """
    weights = tf.Variable(tf.zeros([784, 10]), name="W")
    bias = tf.Variable(tf.zeros([10]), name="b")
    logits = tf.matmul(x, weights) + bias
    return tf.nn.softmax(logits), [weights, bias]


def convolutional(x, keep_prob):
    """Multilayer convolutional network (two conv/pool stages, one dense layer, dropout, readout).

    x is reshaped to [-1, 28, 28, 1]; keep_prob is passed to tf.nn.dropout.
    Returns the softmax output and the eight variables in creation order.

    The variables are created without explicit names, so TensorFlow numbers
    them in creation order. Keep that order (weights before biases, layer by
    layer) or previously saved checkpoints will no longer match.
    """
    def new_weights(shape):
        return tf.Variable(tf.truncated_normal(shape, stddev=0.1))

    def new_biases(shape):
        return tf.Variable(tf.constant(0.1, shape=shape))

    def conv_relu_pool(inputs, kernel, biases):
        # stride-1 SAME convolution, ReLU, then 2x2 max pooling
        activated = tf.nn.relu(
            tf.nn.conv2d(inputs, kernel, strides=[1, 1, 1, 1], padding='SAME') + biases)
        return tf.nn.max_pool(
            activated, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

    images = tf.reshape(x, [-1, 28, 28, 1])

    # First convolutional layer
    conv1_kernel = new_weights([5, 5, 1, 32])
    conv1_bias = new_biases([32])
    pooled1 = conv_relu_pool(images, conv1_kernel, conv1_bias)

    # Second convolutional layer
    conv2_kernel = new_weights([5, 5, 32, 64])
    conv2_bias = new_biases([64])
    pooled2 = conv_relu_pool(pooled1, conv2_kernel, conv2_bias)

    # Densely connected layer
    dense_weights = new_weights([7 * 7 * 64, 1024])
    dense_bias = new_biases([1024])
    flattened = tf.reshape(pooled2, [-1, 7 * 7 * 64])
    hidden = tf.nn.relu(tf.matmul(flattened, dense_weights) + dense_bias)

    # Dropout
    dropped = tf.nn.dropout(hidden, keep_prob)

    # Readout layer
    readout_weights = new_weights([1024, 10])
    readout_bias = new_biases([10])
    probabilities = tf.nn.softmax(tf.matmul(dropped, readout_weights) + readout_bias)

    return probabilities, [
        conv1_kernel, conv1_bias,
        conv2_kernel, conv2_bias,
        dense_weights, dense_bias,
        readout_weights, readout_bias,
    ]

# Carried over verbatim from the repository dump (mnist/regression.py and the beginning of package.json; not part of model.py):
# -------------------------------------------------------------------------------- /mnist/regression.py: -------------------------------------------------------------------------------- 1 | import os 2 | import model 3 | import tensorflow as tf 4 | 5 | from tensorflow.examples.tutorials.mnist import input_data 6 | data = input_data.read_data_sets("/tmp/data/", one_hot=True) 7 | 8 | # model 9 | with tf.variable_scope("regression"): 10 | x = tf.placeholder(tf.float32, [None, 784]) 11 | y, variables = model.regression(x) 12 | 13 | # train 14 | y_ = tf.placeholder("float", [None, 10]) 15 | cross_entropy = -tf.reduce_sum(y_ * tf.log(y)) 16 | train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) 17 | correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) 18 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 19 | 20 | saver = tf.train.Saver(variables) 21 | with tf.Session() as sess: 22 | sess.run(tf.global_variables_initializer()) 23 | for _ in range(1000): 24 | batch_xs, batch_ys = data.train.next_batch(100) 25 | sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) 26 | 27 | print(sess.run(accuracy, feed_dict={x: data.test.images, y_: data.test.labels})) 28 | 29 | path = saver.save( 30 | sess, os.path.join(os.path.dirname(__file__), 'data', 'regression.ckpt'), 31 | write_meta_graph=False, write_state=False) 32 | print("Saved:", path) 33 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tensorflow-mnist", 3 | "version": "1.0.0", 4 | "description": "", 5 | "main": "index.js", 6 | "scripts": { 7 | "test": "echo \"Error: no test specified\" && exit 1", 8 | "postinstall": "gulp" 9 | }, 10 | "keywords": [], 11 | "author": "", 12 | "license":
"ISC", 13 | "repository": { 14 | "type": "git", 15 | "url": "https://github.com/sugyan/tensorflow-mnist.git" 16 | }, 17 | "engines": { 18 | "node": "6.x" 19 | }, 20 | "dependencies": { 21 | "babel-preset-es2015": "^6.1.18", 22 | "bootstrap": "^3.3.5", 23 | "gulp": "^3.9.0", 24 | "gulp-babel": "^6.1.0", 25 | "gulp-sourcemaps": "^1.6.0", 26 | "gulp-uglify": "^1.5.1", 27 | "jquery": "^3.0.0" 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | appdirs==1.4.1 2 | click==6.7 3 | Flask==0.12 4 | gunicorn==19.6.0 5 | itsdangerous==0.24 6 | Jinja2==2.9.5 7 | MarkupSafe==0.23 8 | numpy==1.12.0 9 | packaging==16.8 10 | protobuf==3.2.0 11 | pyparsing==2.1.10 12 | six==1.10.0 13 | tensorflow==1.0.0 14 | Werkzeug==0.11.15 15 | -------------------------------------------------------------------------------- /runtime.txt: -------------------------------------------------------------------------------- 1 | python-3.6.0 -------------------------------------------------------------------------------- /src/js/main.js: -------------------------------------------------------------------------------- 1 | /* global $ */ 2 | class Main { 3 | constructor() { 4 | this.canvas = document.getElementById('main'); 5 | this.input = document.getElementById('input'); 6 | this.canvas.width = 449; // 16 * 28 + 1 7 | this.canvas.height = 449; // 16 * 28 + 1 8 | this.ctx = this.canvas.getContext('2d'); 9 | this.canvas.addEventListener('mousedown', this.onMouseDown.bind(this)); 10 | this.canvas.addEventListener('mouseup', this.onMouseUp.bind(this)); 11 | this.canvas.addEventListener('mousemove', this.onMouseMove.bind(this)); 12 | this.initialize(); 13 | } 14 | initialize() { 15 | this.ctx.fillStyle = '#FFFFFF'; 16 | this.ctx.fillRect(0, 0, 449, 449); 17 | this.ctx.lineWidth = 1; 18 | this.ctx.strokeRect(0, 0, 449, 449); 19 | this.ctx.lineWidth = 0.05; 20 
| for (var i = 0; i < 27; i++) { 21 | this.ctx.beginPath(); 22 | this.ctx.moveTo((i + 1) * 16, 0); 23 | this.ctx.lineTo((i + 1) * 16, 449); 24 | this.ctx.closePath(); 25 | this.ctx.stroke(); 26 | 27 | this.ctx.beginPath(); 28 | this.ctx.moveTo( 0, (i + 1) * 16); 29 | this.ctx.lineTo(449, (i + 1) * 16); 30 | this.ctx.closePath(); 31 | this.ctx.stroke(); 32 | } 33 | this.drawInput(); 34 | $('#output td').text('').removeClass('success'); 35 | } 36 | onMouseDown(e) { 37 | this.canvas.style.cursor = 'default'; 38 | this.drawing = true; 39 | this.prev = this.getPosition(e.clientX, e.clientY); 40 | } 41 | onMouseUp() { 42 | this.drawing = false; 43 | this.drawInput(); 44 | } 45 | onMouseMove(e) { 46 | if (this.drawing) { 47 | var curr = this.getPosition(e.clientX, e.clientY); 48 | this.ctx.lineWidth = 16; 49 | this.ctx.lineCap = 'round'; 50 | this.ctx.beginPath(); 51 | this.ctx.moveTo(this.prev.x, this.prev.y); 52 | this.ctx.lineTo(curr.x, curr.y); 53 | this.ctx.stroke(); 54 | this.ctx.closePath(); 55 | this.prev = curr; 56 | } 57 | } 58 | getPosition(clientX, clientY) { 59 | var rect = this.canvas.getBoundingClientRect(); 60 | return { 61 | x: clientX - rect.left, 62 | y: clientY - rect.top 63 | }; 64 | } 65 | drawInput() { 66 | var ctx = this.input.getContext('2d'); 67 | var img = new Image(); 68 | img.onload = () => { 69 | var inputs = []; 70 | var small = document.createElement('canvas').getContext('2d'); 71 | small.drawImage(img, 0, 0, img.width, img.height, 0, 0, 28, 28); 72 | var data = small.getImageData(0, 0, 28, 28).data; 73 | for (var i = 0; i < 28; i++) { 74 | for (var j = 0; j < 28; j++) { 75 | var n = 4 * (i * 28 + j); 76 | inputs[i * 28 + j] = (data[n + 0] + data[n + 1] + data[n + 2]) / 3; 77 | ctx.fillStyle = 'rgb(' + [data[n + 0], data[n + 1], data[n + 2]].join(',') + ')'; 78 | ctx.fillRect(j * 5, i * 5, 5, 5); 79 | } 80 | } 81 | if (Math.min(...inputs) === 255) { 82 | return; 83 | } 84 | $.ajax({ 85 | url: '/api/mnist', 86 | method: 'POST', 87 | 
contentType: 'application/json', 88 | data: JSON.stringify(inputs), 89 | success: (data) => { 90 | for (let i = 0; i < 2; i++) { 91 | var max = 0; 92 | var max_index = 0; 93 | for (let j = 0; j < 10; j++) { 94 | var value = Math.round(data.results[i][j] * 1000); 95 | if (value > max) { 96 | max = value; 97 | max_index = j; 98 | } 99 | var digits = String(value).length; 100 | for (var k = 0; k < 3 - digits; k++) { 101 | value = '0' + value; 102 | } 103 | var text = '0.' + value; 104 | if (value > 999) { 105 | text = '1.000'; 106 | } 107 | $('#output tr').eq(j + 1).find('td').eq(i).text(text); 108 | } 109 | for (let j = 0; j < 10; j++) { 110 | if (j === max_index) { 111 | $('#output tr').eq(j + 1).find('td').eq(i).addClass('success'); 112 | } else { 113 | $('#output tr').eq(j + 1).find('td').eq(i).removeClass('success'); 114 | } 115 | } 116 | } 117 | } 118 | }); 119 | }; 120 | img.src = this.canvas.toDataURL(); 121 | } 122 | } 123 | 124 | $(() => { 125 | var main = new Main(); 126 | $('#clear').click(() => { 127 | main.initialize(); 128 | }); 129 | }); 130 | -------------------------------------------------------------------------------- /static/css/bootstrap.min.css: -------------------------------------------------------------------------------- 1 | ../../node_modules/bootstrap/dist/css/bootstrap.min.css -------------------------------------------------------------------------------- /static/js/jquery.min.js: -------------------------------------------------------------------------------- 1 | ../../node_modules/jquery/dist/jquery.min.js -------------------------------------------------------------------------------- /templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 |
4 |draw a digit here!
16 | 17 |18 | 19 |
20 |input:
23 | 24 |output:
26 |29 | | regression | 30 |convolutional | 31 |
---|---|---|
0 | 34 |35 | | 36 | |
1 | 39 |40 | | 41 | |
2 | 44 |45 | | 46 | |
3 | 49 |50 | | 51 | |
4 | 54 |55 | | 56 | |
5 | 59 |60 | | 61 | |
6 | 64 |65 | | 66 | |
7 | 69 |70 | | 71 | |
8 | 74 |75 | | 76 | |
9 | 79 |80 | | 81 | |