├── README.md
├── create_model_1.py
├── create_model_2.py
├── predict_1.py
└── predict_2.py
/README.md:
--------------------------------------------------------------------------------
1 | # TensorFlow™ MNIST predict (recognise handwriting)
2 |
3 | This repository accompanies the blog post [Using TensorFlow™ to create your own handwriting recognition engine](http://niektemme.com/2016/02/21/tensorflow-handwriting/).
4 |
5 | ## Installation & Setup
6 |
7 | ### Overview
8 | This project uses the MNIST tutorials from the TensorFlow website. The two tutorials, the beginner tutorial and the expert tutorial, use different deep learning models. The Python scripts ending with _1 use the model from the beginner tutorial. The scripts ending with _2 use the model from the expert tutorial. As expected, the scripts using the model from the expert tutorial give better results.
9 |
10 | This project consists of four scripts:
11 |
12 | 1. *create_model_1.py* – creates a model.ckpt model file based on the beginner tutorial.
13 | 2. *create_model_2.py* – creates a model2.ckpt model file based on the expert tutorial.
14 | 3. *predict_1.py* – uses the model.ckpt file (beginner tutorial) to predict the integer from a handwritten number in a .png file.
15 | 4. *predict_2.py* – uses the model2.ckpt file (expert tutorial) to predict the integer from a handwritten number in a .png file.
16 |
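For orientation, the beginner model used by the _1 scripts is a single softmax layer, `y = softmax(xW + b)`. The sketch below reproduces that forward pass in plain Python rather than TensorFlow. The helper names `softmax` and `predict_digit` and the toy input are illustrative assumptions, not part of the repository.

```python
import math


def softmax(scores):
    """Turn a list of raw scores into probabilities that sum to 1."""
    peak = max(scores)  # subtract the maximum for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def predict_digit(pixels, weights, biases):
    """Return (predicted class, probabilities) for one flattened image.

    pixels  : list of floats (784 values for a 28x28 MNIST image)
    weights : one row per pixel, each row holding one float per class (W)
    biases  : one float per class (b)
    """
    scores = [
        sum(p * row[k] for p, row in zip(pixels, weights)) + biases[k]
        for k in range(len(biases))
    ]
    probs = softmax(scores)
    return probs.index(max(probs)), probs


# Hypothetical toy input: a 4-pixel "image" and 3 classes instead of 784 and 10.
pixels = [0.0, 1.0, 1.0, 0.0]
weights = [[0.1, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 1.0, 0.5], [0.3, 0.0, 0.0]]
biases = [0.0, 0.0, 0.1]
digit, probs = predict_digit(pixels, weights, biases)
print(digit)
```

The expert model used by the _2 scripts adds two convolution and pooling layers and a fully connected layer in front of the same kind of softmax output.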
17 | ### Dependencies
18 | The following Python libraries are required.
19 |
20 | - sys – part of the Python standard library
21 | - tensorflow – [TensorFlow](https://www.tensorflow.org/)
22 | - PIL – [Pillow](http://pillow.readthedocs.org)
23 |
24 | ### Installing TensorFlow
25 | TensorFlow must be installed first. The [TensorFlow website](https://www.tensorflow.org/versions/master/get_started/index.html) has a good installation guide.
26 |
27 | ### Installing Python Imaging Library (PIL)
28 | The original Python Imaging Library (PIL) is no longer maintained. Luckily there is a good fork called Pillow. Installing is as easy as:
29 |
30 | ```sudo pip install Pillow```
31 |
32 | Or look at the [Pillow documentation](http://pillow.readthedocs.org) for other installation options.
33 |
34 | ### The python scripts
35 | The easiest way to use the scripts is to put all four scripts in the same folder. If TensorFlow is installed correctly, the images used to train the model are downloaded automatically.
36 |
37 | ## Running
38 | Running consists of three steps:
39 |
40 | 1. create the model file
41 | 2. create an image file containing a handwritten number
42 | 3. predict the integer
43 |
44 | ### 1. create the model file
45 | The easiest way is to cd to the directory where the python files are located. Then run:
46 |
47 | ```python create_model_1.py```
48 |
49 | or
50 |
51 | ```python create_model_2.py```
52 |
53 | to create the model based on the MNIST beginners tutorial (model_1) or the model based on the expert tutorial (model_2).
54 |
55 | ### 2. create an image file
56 | You have to create a PNG file that contains a handwritten number. The background has to be white and the number has to be black. Any paint program should be able to do this. The image also has to be cropped so that there is no border around the number.
57 |
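If your paint program has no auto-crop function, the crop box can be computed from the pixel values. The sketch below works on a plain list of rows of grayscale values and is only an illustration: the helper names `ink_bounding_box` and `crop` and the darkness threshold of 128 are assumptions, not part of the repository.

```python
def ink_bounding_box(rows, threshold=128):
    """Return (left, upper, right, lower) around all dark pixels.

    rows is a list of rows of grayscale values (0 = black, 255 = white).
    right and lower are exclusive. Returns None if no pixel is darker
    than the (assumed) threshold.
    """
    dark_rows = [r for r, row in enumerate(rows)
                 if any(v < threshold for v in row)]
    if not dark_rows:
        return None
    dark_cols = [c for c in range(len(rows[0]))
                 if any(row[c] < threshold for row in rows)]
    return (dark_cols[0], dark_rows[0], dark_cols[-1] + 1, dark_rows[-1] + 1)


def crop(rows, box):
    """Cut the box (left, upper, right, lower) out of the pixel rows."""
    left, upper, right, lower = box
    return [row[left:right] for row in rows[upper:lower]]


# Hypothetical 6x5 white image with a black stroke in columns 2-3, rows 1-3.
image = [[255] * 6 for _ in range(5)]
for r in (1, 2, 3):
    for c in (2, 3):
        image[r][c] = 0

box = ink_bounding_box(image)
cropped = crop(image, box)
print(box, len(cropped[0]), len(cropped))
```

The box uses the same (left, upper, right, lower) order as the box argument of Pillow's `Image.crop`, so it could be applied to an opened image in the same way.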
58 | ### 3. predict the integer
59 | Again, the easiest way is to put the image file from step 2 in the same directory as the Python scripts and cd to that directory.
60 |
61 | The predict scripts require one argument: the file location of the image file containing the handwritten number. For example, when the image file is number1.png and is in the same location as the script, run:
62 |
63 | ```python predict_1.py number1.png```
64 |
65 | or
66 |
67 | ```python predict_2.py number1.png```
68 |
69 | The first script, predict_1.py, uses the model.ckpt file created by the create_model_1.py script. The second script, predict_2.py, uses the model2.ckpt file created by the create_model_2.py script.
70 |
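Before predicting, both predict scripts rescale the input image in `imageprepare()`: the longer side becomes 20 pixels, the result is centred on a white 28x28 canvas, and the pixel values are mapped to the range 0 (white) to 1 (black). The sketch below reproduces only that arithmetic, without Pillow. The function names `fit_to_canvas` and `normalize` are illustrative assumptions.

```python
def fit_to_canvas(width, height, box=20, canvas=28):
    """Return (new_width, new_height, left, top) for the scaled image.

    The longer side is scaled to `box` pixels, the shorter side keeps the
    aspect ratio (minimum 1 pixel), and (left, top) is the paste position
    that centres the result on a canvas x canvas image.
    """
    if width > height:
        new_width = box
        new_height = max(1, int(round(box * height / float(width))))
        left = (canvas - box) // 2
        top = int(round((canvas - new_height) / 2.0))
    else:
        new_height = box
        new_width = max(1, int(round(box * width / float(height))))
        left = int(round((canvas - new_width) / 2.0))
        top = (canvas - box) // 2
    return new_width, new_height, left, top


def normalize(pixel):
    """Map a grayscale value (255 = white, 0 = black) to 0.0 = white, 1.0 = black."""
    return (255 - pixel) / 255.0


print(fit_to_canvas(100, 50))        # wide image
print(fit_to_canvas(50, 100))        # tall image
print(normalize(255), normalize(0))
```

A 100x50 image, for example, is scaled to 20x10 and pasted at position (4, 9), which leaves a margin of at least 4 pixels on every side, similar to the MNIST training images.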
71 |
72 |
--------------------------------------------------------------------------------
/create_model_1.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Niek Temme.
2 | # Adapted from the MNIST beginners tutorial by Google.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """A very simple MNIST classifier.
18 | Documentation at
19 | http://niektemme.com/ @@to do
20 |
21 | This script is based on the TensorFlow MNIST beginners tutorial.
22 | See extensive documentation for the tutorial at
23 | https://www.tensorflow.org/versions/master/tutorials/mnist/beginners/index.html
24 | """
25 |
26 | #import modules
27 | import tensorflow as tf
28 | from tensorflow.examples.tutorials.mnist import input_data
29 |
30 | #import data
31 | mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
32 |
33 | # Create the model
34 | x = tf.placeholder(tf.float32, [None, 784])
35 | W = tf.Variable(tf.zeros([784, 10]))
36 | b = tf.Variable(tf.zeros([10]))
37 | y = tf.nn.softmax(tf.matmul(x, W) + b)
38 |
39 | # Define loss and optimizer
40 | y_ = tf.placeholder(tf.float32, [None, 10])
41 | cross_entropy = -tf.reduce_sum(y_*tf.log(y))
42 | train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
43 |
44 | init_op = tf.global_variables_initializer()
45 | saver = tf.train.Saver()
46 |
47 |
48 | # Train the model and save it to disk as a model.ckpt file.
49 | # The file is stored in the directory from which this script is started.
50 | """
51 | The use of 'with tf.Session() as sess:' is taken from the TensorFlow documentation
52 | on saving and restoring variables.
53 | https://www.tensorflow.org/versions/master/how_tos/variables/index.html
54 | """
55 | with tf.Session() as sess:
56 |     sess.run(init_op)
57 |     for i in range(1000):
58 |         batch_xs, batch_ys = mnist.train.next_batch(100)
59 |         sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
60 |
61 |     save_path = saver.save(sess, "model.ckpt")
62 |     print ("Model saved in file: ", save_path)
63 |
64 |
--------------------------------------------------------------------------------
/create_model_2.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Niek Temme.
2 | # Adapted from the MNIST expert tutorial by Google.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """A very simple MNIST classifier.
18 | Documentation at
19 | http://niektemme.com/ @@to do
20 |
21 | This script is based on the TensorFlow MNIST expert tutorial.
22 | See extensive documentation for the tutorial at
23 | https://www.tensorflow.org/versions/master/tutorials/mnist/pros/index.html
24 | """
25 |
26 | #import modules
27 | import tensorflow as tf
28 | from tensorflow.examples.tutorials.mnist import input_data
29 |
30 | #import data
31 | mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
32 |
33 | sess = tf.InteractiveSession()
34 |
35 | # Create the model
36 | x = tf.placeholder(tf.float32, [None, 784])
37 | y_ = tf.placeholder(tf.float32, [None, 10])
38 | W = tf.Variable(tf.zeros([784, 10]))
39 | b = tf.Variable(tf.zeros([10]))
40 | y = tf.nn.softmax(tf.matmul(x, W) + b)
41 |
42 | def weight_variable(shape):
43 |     initial = tf.truncated_normal(shape, stddev=0.1)
44 |     return tf.Variable(initial)
45 |
46 | def bias_variable(shape):
47 |     initial = tf.constant(0.1, shape=shape)
48 |     return tf.Variable(initial)
49 |
50 |
51 | def conv2d(x, W):
52 |     return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
53 |
54 | def max_pool_2x2(x):
55 |     return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
56 |                           strides=[1, 2, 2, 1], padding='SAME')
57 |
58 |
59 | W_conv1 = weight_variable([5, 5, 1, 32])
60 | b_conv1 = bias_variable([32])
61 |
62 | x_image = tf.reshape(x, [-1,28,28,1])
63 | h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
64 | h_pool1 = max_pool_2x2(h_conv1)
65 |
66 |
67 | W_conv2 = weight_variable([5, 5, 32, 64])
68 | b_conv2 = bias_variable([64])
69 |
70 | h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
71 | h_pool2 = max_pool_2x2(h_conv2)
72 |
73 | W_fc1 = weight_variable([7 * 7 * 64, 1024])
74 | b_fc1 = bias_variable([1024])
75 |
76 | h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
77 | h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
78 |
79 | keep_prob = tf.placeholder(tf.float32)
80 | h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
81 |
82 | W_fc2 = weight_variable([1024, 10])
83 | b_fc2 = bias_variable([10])
84 |
85 | y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
86 |
87 | # Define loss and optimizer
88 | cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
89 | train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
90 | correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
91 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
92 |
93 |
94 | """
95 | Train the model and save it to disk as a model2.ckpt file.
96 | The file is stored in the directory from which this script is started.
97 |
98 | Based on the documentation at
99 | https://www.tensorflow.org/versions/master/how_tos/variables/index.html
100 | """
101 | saver = tf.train.Saver()
102 | sess.run(tf.global_variables_initializer())
103 | #with tf.Session() as sess:
104 |     #sess.run(init_op)
105 | for i in range(20000):
106 |     batch = mnist.train.next_batch(50)
107 |     if i%100 == 0:
108 |         train_accuracy = accuracy.eval(feed_dict={
109 |             x:batch[0], y_: batch[1], keep_prob: 1.0})
110 |         print("step %d, training accuracy %g"%(i, train_accuracy))
111 |     train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
112 |
113 | save_path = saver.save(sess, "model2.ckpt")
114 | print ("Model saved in file: ", save_path)
115 |
116 | print("test accuracy %g"%accuracy.eval(feed_dict={
117 |     x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
118 |
119 |
120 |
121 |
--------------------------------------------------------------------------------
/predict_1.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Niek Temme.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Predict a handwritten integer (MNIST beginners).
17 |
18 | Script requires
19 | 1) saved model (model.ckpt file) in the same location as the script is run from.
20 | (requires a model created with the MNIST beginners tutorial)
21 | 2) one argument (png file location of a handwritten integer)
22 |
23 | Documentation at:
24 | http://niektemme.com/ @@to do
25 | """
26 |
27 | #import modules
28 | import sys
29 | import tensorflow as tf
30 | from PIL import Image,ImageFilter
31 |
32 | def predictint(imvalue):
33 |     """
34 |     This function returns the predicted integer.
35 |     The input is the pixel values from the imageprepare() function.
36 |     """
37 |
38 |     # Define the model (same as when creating the model file)
39 |     x = tf.placeholder(tf.float32, [None, 784])
40 |     W = tf.Variable(tf.zeros([784, 10]))
41 |     b = tf.Variable(tf.zeros([10]))
42 |     y = tf.nn.softmax(tf.matmul(x, W) + b)
43 |
44 |     init_op = tf.global_variables_initializer()
45 |     saver = tf.train.Saver()
46 |
47 |     """
48 |     Load the model.ckpt file.
49 |     The file is read from the directory from which this script is started.
50 |     Use the model to predict the integer. The integer is returned as a list.
51 |
52 |     Based on the documentation at
53 |     https://www.tensorflow.org/versions/master/how_tos/variables/index.html
54 |     """
55 |     with tf.Session() as sess:
56 |         sess.run(init_op)
57 |         saver.restore(sess, "model.ckpt")
58 |         #print ("Model restored.")
59 |
60 |         prediction=tf.argmax(y,1)
61 |         return prediction.eval(feed_dict={x: [imvalue]}, session=sess)
62 |
63 |
64 | def imageprepare(argv):
65 |     """
66 |     This function returns the pixel values.
67 |     The input is a png file location.
68 |     """
69 |     im = Image.open(argv).convert('L')
70 |     width = float(im.size[0])
71 |     height = float(im.size[1])
72 |     newImage = Image.new('L', (28, 28), (255)) #creates white canvas of 28x28 pixels
73 |
74 |     if width > height: #check which dimension is bigger
75 |         #Width is bigger. Width becomes 20 pixels.
76 |         nheight = int(round((20.0/width*height),0)) #resize height according to ratio width
77 |         if (nheight == 0): #rare case but minimum is 1 pixel
78 |             nheight = 1
79 |         # resize and sharpen (Image.LANCZOS is the current name of the ANTIALIAS filter)
80 |         img = im.resize((20,nheight), Image.LANCZOS).filter(ImageFilter.SHARPEN)
81 |         wtop = int(round(((28 - nheight)/2),0)) #calculate vertical position
82 |         newImage.paste(img, (4, wtop)) #paste resized image on white canvas
83 |     else:
84 |         #Height is bigger. Height becomes 20 pixels.
85 |         nwidth = int(round((20.0/height*width),0)) #resize width according to ratio height
86 |         if (nwidth == 0): #rare case but minimum is 1 pixel
87 |             nwidth = 1
88 |         # resize and sharpen (Image.LANCZOS is the current name of the ANTIALIAS filter)
89 |         img = im.resize((nwidth,20), Image.LANCZOS).filter(ImageFilter.SHARPEN)
90 |         wleft = int(round(((28 - nwidth)/2),0)) #calculate horizontal position
91 |         newImage.paste(img, (wleft, 4)) #paste resized image on white canvas
92 |
93 |     #newImage.save("sample.png")
94 |
95 |     tv = list(newImage.getdata()) #get pixel values
96 |
97 |     #normalize pixels to the range 0 to 1. 0 is pure white, 1 is pure black.
98 |     tva = [ (255-x)*1.0/255.0 for x in tv]
99 |     return tva
100 |     #print(tva)
101 |
102 | def main(argv):
103 |     """
104 |     Main function.
105 |     """
106 |     imvalue = imageprepare(argv)
107 |     predint = predictint(imvalue)
108 |     print (predint[0]) #first value in list
109 |
110 | if __name__ == "__main__":
111 |     main(sys.argv[1])
112 |
--------------------------------------------------------------------------------
/predict_2.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Niek Temme.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Predict a handwritten integer (MNIST expert).
17 |
18 | Script requires
19 | 1) saved model (model2.ckpt file) in the same location as the script is run from.
20 | (requires a model created with the MNIST expert tutorial)
21 | 2) one argument (png file location of a handwritten integer)
22 |
23 | Documentation at:
24 | http://niektemme.com/ @@to do
25 | """
26 |
27 | #import modules
28 | import sys
29 | import tensorflow as tf
30 | from PIL import Image, ImageFilter
31 |
32 | def predictint(imvalue):
33 |     """
34 |     This function returns the predicted integer.
35 |     The input is the pixel values from the imageprepare() function.
36 |     """
37 |
38 |     # Define the model (same as when creating the model file)
39 |     x = tf.placeholder(tf.float32, [None, 784])
40 |     W = tf.Variable(tf.zeros([784, 10]))
41 |     b = tf.Variable(tf.zeros([10]))
42 |
43 |     def weight_variable(shape):
44 |         initial = tf.truncated_normal(shape, stddev=0.1)
45 |         return tf.Variable(initial)
46 |
47 |     def bias_variable(shape):
48 |         initial = tf.constant(0.1, shape=shape)
49 |         return tf.Variable(initial)
50 |
51 |     def conv2d(x, W):
52 |         return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
53 |
54 |     def max_pool_2x2(x):
55 |         return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
56 |
57 |     W_conv1 = weight_variable([5, 5, 1, 32])
58 |     b_conv1 = bias_variable([32])
59 |
60 |     x_image = tf.reshape(x, [-1,28,28,1])
61 |     h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
62 |     h_pool1 = max_pool_2x2(h_conv1)
63 |
64 |     W_conv2 = weight_variable([5, 5, 32, 64])
65 |     b_conv2 = bias_variable([64])
66 |
67 |     h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
68 |     h_pool2 = max_pool_2x2(h_conv2)
69 |
70 |     W_fc1 = weight_variable([7 * 7 * 64, 1024])
71 |     b_fc1 = bias_variable([1024])
72 |
73 |     h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
74 |     h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
75 |
76 |     keep_prob = tf.placeholder(tf.float32)
77 |     h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
78 |
79 |     W_fc2 = weight_variable([1024, 10])
80 |     b_fc2 = bias_variable([10])
81 |
82 |     y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
83 |
84 |     init_op = tf.global_variables_initializer()
85 |     saver = tf.train.Saver()
86 |
87 |     """
88 |     Load the model2.ckpt file.
89 |     The file is read from the directory from which this script is started.
90 |     Use the model to predict the integer. The integer is returned as a list.
91 |
92 |     Based on the documentation at
93 |     https://www.tensorflow.org/versions/master/how_tos/variables/index.html
94 |     """
95 |     with tf.Session() as sess:
96 |         sess.run(init_op)
97 |         saver.restore(sess, "model2.ckpt")
98 |         #print ("Model restored.")
99 |
100 |         prediction=tf.argmax(y_conv,1)
101 |         return prediction.eval(feed_dict={x: [imvalue],keep_prob: 1.0}, session=sess)
102 |
103 |
104 | def imageprepare(argv):
105 |     """
106 |     This function returns the pixel values.
107 |     The input is a png file location.
108 |     """
109 |     im = Image.open(argv).convert('L')
110 |     width = float(im.size[0])
111 |     height = float(im.size[1])
112 |     newImage = Image.new('L', (28, 28), (255)) #creates white canvas of 28x28 pixels
113 |
114 |     if width > height: #check which dimension is bigger
115 |         #Width is bigger. Width becomes 20 pixels.
116 |         nheight = int(round((20.0/width*height),0)) #resize height according to ratio width
117 |         if (nheight == 0): #rare case but minimum is 1 pixel
118 |             nheight = 1
119 |         # resize and sharpen (Image.LANCZOS is the current name of the ANTIALIAS filter)
120 |         img = im.resize((20,nheight), Image.LANCZOS).filter(ImageFilter.SHARPEN)
121 |         wtop = int(round(((28 - nheight)/2),0)) #calculate vertical position
122 |         newImage.paste(img, (4, wtop)) #paste resized image on white canvas
123 |     else:
124 |         #Height is bigger. Height becomes 20 pixels.
125 |         nwidth = int(round((20.0/height*width),0)) #resize width according to ratio height
126 |         if (nwidth == 0): #rare case but minimum is 1 pixel
127 |             nwidth = 1
128 |         # resize and sharpen (Image.LANCZOS is the current name of the ANTIALIAS filter)
129 |         img = im.resize((nwidth,20), Image.LANCZOS).filter(ImageFilter.SHARPEN)
130 |         wleft = int(round(((28 - nwidth)/2),0)) #calculate horizontal position
131 |         newImage.paste(img, (wleft, 4)) #paste resized image on white canvas
132 |
133 |     #newImage.save("sample.png")
134 |
135 |     tv = list(newImage.getdata()) #get pixel values
136 |
137 |     #normalize pixels to the range 0 to 1. 0 is pure white, 1 is pure black.
138 |     tva = [ (255-x)*1.0/255.0 for x in tv]
139 |     return tva
140 |     #print(tva)
141 |
142 | def main(argv):
143 |     """
144 |     Main function.
145 |     """
146 |     imvalue = imageprepare(argv)
147 |     predint = predictint(imvalue)
148 |     print (predint[0]) #first value in list
149 |
150 | if __name__ == "__main__":
151 |     main(sys.argv[1])
152 |
--------------------------------------------------------------------------------