├── .github
│   └── FUNDING.yml
├── .gitignore
├── README.md
├── blog
│   ├── own_0.png
│   ├── own_3.png
│   ├── own_4.png
│   └── own_8.png
├── cps
│   ├── checkpoint
│   ├── model.ckpt
│   └── model.ckpt.meta
├── img
│   └── test_2.png
├── input_data.py
├── input_data.pyc
├── learn
│   ├── test_2_1050_585_4.png
│   ├── test_2_1125_100_3.png
│   ├── test_2_1330_585_8.png
│   ├── test_2_1550_100_4.png
│   ├── test_2_1575_600_6.png
│   ├── test_2_1800_540_9.png
│   ├── test_2_2030_585_5.png
│   ├── test_2_2040_90_5.png
│   ├── test_2_2280_540_2.png
│   ├── test_2_2490_100_6.png
│   ├── test_2_2700_100_7.png
│   ├── test_2_300_550_1.png
│   ├── test_2_3150_100_8.png
│   ├── test_2_3600_100_9.png
│   ├── test_2_425_120_1.png
│   ├── test_2_550_550_3.png
│   ├── test_2_60_135_0.png
│   ├── test_2_750_135_2.png
│   ├── test_2_770_540_7.png
│   └── test_2_90_550_0.png
├── learn_extra.py
├── mnist.py
├── predict_interface.py
├── predict_interface_usage.py
└── requirements.txt

/.github/FUNDING.yml:
--------------------------------------------------------------------------------
# These are supported funding model platforms

github: [Wikunia]
patreon: opensources

--------------------------------------------------------------------------------

/.gitignore:
--------------------------------------------------------------------------------
.idea/
MNIST_data/

--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# Tensorflow MNIST
How do you preprocess your own images for MNIST? I didn't find any good resources on it, so I decided to research it, try it out and
explain it simply, step by step, in Python. You can find the result on my blog [OpenSourcES](http://opensourc.es/blog/tensorflow-mnist).

## Usage

This code should work with Python 3.6 if OpenCV, TensorFlow and a few other packages are installed. It uses the TensorFlow 1.x API (`tf.placeholder`, `tf.Session`), so it will not run unmodified on TensorFlow 2.x.
See `requirements.txt`.

You should be able to run `python mnist.py` and `python predict_interface_usage.py test_2`, where `test_2` is the filename (without extension) of an image in `img/`.

### mnist.py
On success, the output (test-set accuracy, the predicted digits for the images in `blog/`, and the accuracy on those images) looks something like this:
> 0.9145
>
> [8 0 4 3]
>
> 1.0

### predict_interface_usage.py
Takes a photograph of handwritten digits (from `img/`) as input.
On success, it prints the predictions to the command
prompt and generates an image with the predictions in `pro-img/`.

--------------------------------------------------------------------------------

/blog/own_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/blog/own_0.png
--------------------------------------------------------------------------------

/blog/own_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/blog/own_3.png
--------------------------------------------------------------------------------

/blog/own_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/blog/own_4.png
--------------------------------------------------------------------------------

/blog/own_8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/blog/own_8.png
--------------------------------------------------------------------------------

/cps/checkpoint:
--------------------------------------------------------------------------------
model_checkpoint_path: "model.ckpt"
all_model_checkpoint_paths: "model.ckpt"

--------------------------------------------------------------------------------

/cps/model.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/cps/model.ckpt
--------------------------------------------------------------------------------

/cps/model.ckpt.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/cps/model.ckpt.meta
--------------------------------------------------------------------------------

/img/test_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/img/test_2.png
--------------------------------------------------------------------------------

/input_data.py:
--------------------------------------------------------------------------------
"""Functions for downloading and reading MNIST data."""

import gzip
import os
import urllib.request
import numpy
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def maybe_download(filename, work_directory):
    """Download the data from Yann's website, unless it's already here."""
    if not os.path.exists(work_directory):
        os.mkdir(work_directory)
    filepath = os.path.join(work_directory, filename)
    if not os.path.exists(filepath):
        filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    return filepath
def _read32(bytestream):
    dt = numpy.dtype(numpy.uint32).newbyteorder('>')
    return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]
def extract_images(filename):
    """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
    print('Extracting', filename)
    with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2051:
            raise ValueError(
                'Invalid magic number %d in MNIST image file: %s' %
                (magic, filename))
        num_images = _read32(bytestream)
        rows = _read32(bytestream)
        cols = _read32(bytestream)
        buf = bytestream.read(rows * cols * num_images)
        data = numpy.frombuffer(buf, dtype=numpy.uint8)
        data = data.reshape(num_images, rows, cols, 1)
        return data
def dense_to_one_hot(labels_dense, num_classes=10):
    """Convert class labels from scalars to one-hot vectors."""
    num_labels = labels_dense.shape[0]
    index_offset = numpy.arange(num_labels) * num_classes
    labels_one_hot = numpy.zeros((num_labels, num_classes))
    labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
    return labels_one_hot
def extract_labels(filename, one_hot=False):
    """Extract the labels into a 1D uint8 numpy array [index]."""
    print('Extracting', filename)
    with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2049:
            raise ValueError(
                'Invalid magic number %d in MNIST label file: %s' %
                (magic, filename))
        num_items = _read32(bytestream)
        buf = bytestream.read(num_items)
        labels = numpy.frombuffer(buf, dtype=numpy.uint8)
        if one_hot:
            return dense_to_one_hot(labels)
        return labels
class DataSet(object):
    def __init__(self, images, labels, fake_data=False):
        if fake_data:
            self._num_examples = 10000
        else:
            assert images.shape[0] == labels.shape[0], (
                "images.shape: %s labels.shape: %s" % (images.shape,
                                                       labels.shape))
            self._num_examples = images.shape[0]
            # Convert shape from [num examples, rows, columns, depth]
            # to [num examples, rows*columns] (assuming depth == 1)
            assert images.shape[3] == 1
            images = images.reshape(images.shape[0],
                                    images.shape[1] * images.shape[2])
            # Convert from [0, 255] -> [0.0, 1.0].
            images = images.astype(numpy.float32)
            images = numpy.multiply(images, 1.0 / 255.0)
        self._images = images
        self._labels = labels
        self._epochs_completed = 0
        self._index_in_epoch = 0
    @property
    def images(self):
        return self._images
    @property
    def labels(self):
        return self._labels
    @property
    def num_examples(self):
        return self._num_examples
    @property
    def epochs_completed(self):
        return self._epochs_completed
    def next_batch(self, batch_size, fake_data=False):
        """Return the next `batch_size` examples from this data set."""
        if fake_data:
            fake_image = [1.0 for _ in range(784)]
            fake_label = 0
            return [fake_image for _ in range(batch_size)], [
                fake_label for _ in range(batch_size)]
        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        if self._index_in_epoch > self._num_examples:
            # Finished epoch
            self._epochs_completed += 1
            # Shuffle the data
            perm = numpy.arange(self._num_examples)
            numpy.random.shuffle(perm)
            self._images = self._images[perm]
            self._labels = self._labels[perm]
            # Start next epoch
            start = 0
            self._index_in_epoch = batch_size
            assert batch_size <= self._num_examples
        end = self._index_in_epoch
        return self._images[start:end], self._labels[start:end]
def read_data_sets(train_dir, fake_data=False, one_hot=False):
    class DataSets(object):
        pass
    data_sets = DataSets()
    if fake_data:
        data_sets.train = DataSet([], [], fake_data=True)
        data_sets.validation = DataSet([], [], fake_data=True)
        data_sets.test = DataSet([], [], fake_data=True)
        return data_sets
    TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
    TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
    TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
    TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
    VALIDATION_SIZE = 5000
    local_file = maybe_download(TRAIN_IMAGES, train_dir)
    train_images = extract_images(local_file)
    local_file = maybe_download(TRAIN_LABELS, train_dir)
    train_labels = extract_labels(local_file, one_hot=one_hot)
    local_file = maybe_download(TEST_IMAGES, train_dir)
    test_images = extract_images(local_file)
    local_file = maybe_download(TEST_LABELS, train_dir)
    test_labels = extract_labels(local_file, one_hot=one_hot)
    validation_images = train_images[:VALIDATION_SIZE]
    validation_labels = train_labels[:VALIDATION_SIZE]
    train_images = train_images[VALIDATION_SIZE:]
    train_labels = train_labels[VALIDATION_SIZE:]
    data_sets.train = DataSet(train_images, train_labels)
    data_sets.validation = DataSet(validation_images, validation_labels)
    data_sets.test = DataSet(test_images, test_labels)
    return data_sets
--------------------------------------------------------------------------------

/input_data.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/input_data.pyc
--------------------------------------------------------------------------------

/learn/test_2_1050_585_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/learn/test_2_1050_585_4.png
--------------------------------------------------------------------------------

/learn/test_2_1125_100_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/learn/test_2_1125_100_3.png
--------------------------------------------------------------------------------

/learn/test_2_1330_585_8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/learn/test_2_1330_585_8.png
--------------------------------------------------------------------------------

/learn/test_2_1550_100_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/learn/test_2_1550_100_4.png
--------------------------------------------------------------------------------

/learn/test_2_1575_600_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/learn/test_2_1575_600_6.png
--------------------------------------------------------------------------------

/learn/test_2_1800_540_9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/learn/test_2_1800_540_9.png
--------------------------------------------------------------------------------

/learn/test_2_2030_585_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/learn/test_2_2030_585_5.png
--------------------------------------------------------------------------------

/learn/test_2_2040_90_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/learn/test_2_2040_90_5.png
--------------------------------------------------------------------------------

/learn/test_2_2280_540_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/learn/test_2_2280_540_2.png
--------------------------------------------------------------------------------

/learn/test_2_2490_100_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/learn/test_2_2490_100_6.png
--------------------------------------------------------------------------------

/learn/test_2_2700_100_7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/learn/test_2_2700_100_7.png
--------------------------------------------------------------------------------

/learn/test_2_300_550_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/learn/test_2_300_550_1.png
--------------------------------------------------------------------------------

/learn/test_2_3150_100_8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/learn/test_2_3150_100_8.png
--------------------------------------------------------------------------------

/learn/test_2_3600_100_9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/learn/test_2_3600_100_9.png
--------------------------------------------------------------------------------

/learn/test_2_425_120_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/learn/test_2_425_120_1.png
--------------------------------------------------------------------------------

/learn/test_2_550_550_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/learn/test_2_550_550_3.png
--------------------------------------------------------------------------------

/learn/test_2_60_135_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/learn/test_2_60_135_0.png
--------------------------------------------------------------------------------

/learn/test_2_750_135_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/learn/test_2_750_135_2.png
--------------------------------------------------------------------------------

/learn/test_2_770_540_7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/learn/test_2_770_540_7.png
--------------------------------------------------------------------------------

/learn/test_2_90_550_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opensourcesblog/tensorflow-mnist/bbc7e11980df55a1595bc3e2144f7cd28bf1a3cd/learn/test_2_90_550_0.png
--------------------------------------------------------------------------------

/learn_extra.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import cv2
import numpy as np
from scipy import ndimage
import sys
import os
import math

def getBestShift(img):
    cy,cx = ndimage.center_of_mass(img)

    rows,cols = img.shape
    shiftx = np.round(cols/2.0-cx).astype(int)
    shifty = np.round(rows/2.0-cy).astype(int)

    return shiftx,shifty

def shift(img,sx,sy):
    rows,cols = img.shape
    M = np.float32([[1,0,sx],[0,1,sy]])
    shifted = cv2.warpAffine(img,M,(cols,rows))
    return shifted

def get_x_by_image(folder,image,reverse=False):
    # read the image
    gray = cv2.imread(folder+"/"+image, 0)

    # rescale it (inverting the colors first if requested)
    if reverse:
        gray = cv2.resize(255 - gray, (28, 28))
    else:
        gray = cv2.resize(gray, (28, 28))
    # better black and white version
    (thresh, gray) = cv2.threshold(gray, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)

    while np.sum(gray[0]) == 0:
        gray = gray[1:]

    while np.sum(gray[:, 0]) == 0:
        gray = np.delete(gray, 0, 1)

    while np.sum(gray[-1]) == 0:
        gray = gray[:-1]

    while np.sum(gray[:, -1]) == 0:
        gray = np.delete(gray, -1, 1)

    rows, cols = gray.shape

    if rows > cols:
        factor = 20.0 / rows
        rows = 20
        cols = int(round(cols * factor))
        # cv2.resize takes (width, height): first cols, then rows
        gray = cv2.resize(gray, (cols, rows))
    else:
        factor = 20.0 / cols
        cols = 20
        rows = int(round(rows * factor))
        # cv2.resize takes (width, height): first cols, then rows
        gray = cv2.resize(gray, (cols, rows))

    colsPadding = (int(math.ceil((28 - cols) / 2.0)), int(math.floor((28 - cols) / 2.0)))
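    # Worked example (illustrative numbers, not taken from a real image):
    # if the cropped digit is taller than wide, rows becomes 20 and, say,
    # cols rounds to 13. Then colsPadding = (ceil(15 / 2.0), floor(15 / 2.0))
    # = (8, 7), and 8 + 13 + 7 = 28 columns; the rows get (4, 4): 4 + 20 + 4 = 28.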
    rowsPadding = (int(math.ceil((28 - rows) / 2.0)), int(math.floor((28 - rows) / 2.0)))
    gray = np.pad(gray, (rowsPadding, colsPadding), 'constant')

    shiftx, shifty = getBestShift(gray)
    shifted = shift(gray, shiftx, shifty)
    gray = shifted

    """
    all images in the training set have a range from 0-1,
    not 0-255, so we divide our flattened image
    (a one-dimensional vector of 784 pixels) by 255
    to use the same 0-1 range
    """
    flatten = gray.flatten() / 255.0
    return flatten

def get_y_by_digit(digit):
    arr = np.zeros((10))
    arr[digit] = 1
    return arr

def get_learning_batch(folder,reverse=False):
    batch_xs = []
    batch_ys = []
    for file in os.listdir(folder):
        if file.endswith(".png"):
            # the label is the last character before ".png",
            # e.g. test_2_1050_585_4.png -> 4 (must be an int to index the array)
            digit = int(file[-5:-4])
            y = get_y_by_digit(digit)
            x = get_x_by_image(folder,file,reverse=reverse)
            batch_xs.append(x)
            batch_ys.append(y)
    return batch_xs, batch_ys

"""
a placeholder for our image data:
None stands for an unspecified number of images
784 = 28*28 pixel
"""
x = tf.placeholder("float", [None, 784])

# we need our weights for our neural net
W = tf.Variable(tf.zeros([784,10]))
# and the biases
b = tf.Variable(tf.zeros([10]))

"""
softmax provides a probability-based output
we need to multiply the image values x and the weights
and add the biases
(the normal procedure, explained in previous articles)
"""
y = tf.nn.softmax(tf.matmul(x,W) + b)

"""
y_ will be filled with the real values
which we want to train (digits 0-9)
for an undefined number of images
"""
y_ = tf.placeholder("float", [None,10])

"""
we use the cross_entropy function
which we want to minimize to improve our model
"""
cross_entropy = -tf.reduce_sum(y_*tf.log(y))

"""
use a learning rate of 0.01
to minimize the cross_entropy error
"""
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)


checkpoint_dir = "cps/"

saver = tf.train.Saver()
sess = tf.Session()
# initialize all variables and run init
sess.run(tf.global_variables_initializer())

# the folder with the labelled training images (e.g. learn/)
folder = sys.argv[1]

# Here's where you're restoring the variables W and b.
# Note that the graph is exactly as it was when the variables were
# saved in a prior training run.
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
else:
    print('No checkpoint found')
    sys.exit(1)

# any second command-line argument enables color inversion (255 - gray)
reverse = len(sys.argv) > 2

batch_xs, batch_ys = get_learning_batch(folder,reverse=reverse)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
saver.save(sess, checkpoint_dir+'model.ckpt')

--------------------------------------------------------------------------------

/mnist.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import input_data
import cv2
import numpy as np
import math
from scipy import ndimage


def getBestShift(img):
    cy,cx = ndimage.center_of_mass(img)

    rows,cols = img.shape
    shiftx = np.round(cols/2.0-cx).astype(int)
    shifty = np.round(rows/2.0-cy).astype(int)

    return shiftx,shifty


def shift(img,sx,sy):
    rows,cols = img.shape
    M = np.float32([[1,0,sx],[0,1,sy]])
    shifted = cv2.warpAffine(img,M,(cols,rows))
    return shifted


def train_and_predict(input_images):
    # create a MNIST_data folder with the MNIST dataset if necessary
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

    """
    a placeholder for our image data:
    None stands for an unspecified number of images
    784 = 28*28 pixel
    """
    x = tf.placeholder("float", [None, 784])

    # we need our weights for our neural net
    W = tf.Variable(tf.zeros([784,10]))
    # and the biases
    b = tf.Variable(tf.zeros([10]))

    """
    softmax provides a probability-based output
    we need to multiply the image values x and the weights
    and add the biases
    (the normal procedure, explained in previous articles)
    """
    y = tf.nn.softmax(tf.matmul(x,W) + b)

    """
    y_ will be filled with the real values
    which we want to train (digits 0-9)
    for an undefined number of images
    """
    y_ = tf.placeholder("float", [None,10])

    """
    we use the cross_entropy function
    which we want to minimize to improve our model
    """
    cross_entropy = -tf.reduce_sum(y_*tf.log(y))

    """
    use a learning rate of 0.01
    to minimize the cross_entropy error
    """
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

    # initialize all variables
    init = tf.global_variables_initializer()

    # create a session
    sess = tf.Session()
    sess.run(init)

    # use 1000 batches with a size of 100 each to train our net
    for i in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        # run the train_step function with the given image values (x) and the real output (y_)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

    """
    Let's get the accuracy of our model:
    our model is correct if the index with the highest y value
    is the same as in the real digit vector.
    The mean of the correct_prediction gives us the accuracy.
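    For example (illustrative values): if tf.argmax(y,1) gives [0, 8, 4, 5]
    and tf.argmax(y_,1) gives [0, 8, 4, 3], correct_prediction is
    [True, True, True, False]; cast to float that is [1, 1, 1, 0],
    so the accuracy is 3/4 = 0.75.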
    We need to run the accuracy function
    with our test set (mnist.test)
    We use the keys "images" and "labels" for x and y_
    """
    correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

    # create an array where we can store our pictures (one row per input image)
    images = np.zeros((len(input_images),784))
    # and the correct values
    correct_vals = np.zeros((len(input_images),10))

    # we want to test our own images (blog/own_<digit>.png)
    i = 0
    # for no in [8,0,4,3]:
    for no in input_images:

        # read the image
        gray = cv2.imread("blog/own_"+str(no)+".png", 0)
        # gray = cv2.imread(no, 0)

        # invert the colors and rescale it
        gray = cv2.resize(255-gray, (28, 28))
        # better black and white version
        (thresh, gray) = cv2.threshold(gray, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)

        while np.sum(gray[0]) == 0:
            gray = gray[1:]

        while np.sum(gray[:,0]) == 0:
            gray = np.delete(gray,0,1)

        while np.sum(gray[-1]) == 0:
            gray = gray[:-1]

        while np.sum(gray[:,-1]) == 0:
            gray = np.delete(gray,-1,1)

        rows,cols = gray.shape

        if rows > cols:
            factor = 20.0/rows
            rows = 20
            cols = int(round(cols*factor))
            # cv2.resize takes (width, height): first cols, then rows
            gray = cv2.resize(gray, (cols,rows))
        else:
            factor = 20.0/cols
            cols = 20
            rows = int(round(rows*factor))
            # cv2.resize takes (width, height): first cols, then rows
            gray = cv2.resize(gray, (cols, rows))

        colsPadding = (int(math.ceil((28-cols)/2.0)),int(math.floor((28-cols)/2.0)))
        rowsPadding = (int(math.ceil((28-rows)/2.0)),int(math.floor((28-rows)/2.0)))
        gray = np.pad(gray,(rowsPadding,colsPadding),'constant')

        shiftx,shifty = getBestShift(gray)
        shifted = shift(gray,shiftx,shifty)
        gray = shifted
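        # Worked example (illustrative numbers): if the center of mass of the
        # padded 28x28 digit is at (cy, cx) = (12.6, 15.2), getBestShift returns
        # shiftx = round(14 - 15.2) = -1 and shifty = round(14 - 12.6) = 1, so
        # shift() uses M = [[1, 0, -1], [0, 1, 1]] and cv2.warpAffine moves the
        # digit one pixel left and one pixel down.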

        # save the processed images
        cv2.imwrite("pro-img/image_"+str(no)+".png", gray)
        """
        all images in the training set have a range from 0-1,
        not 0-255, so we divide our flattened image
        (a one-dimensional vector of 784 pixels) by 255
        to use the same 0-1 range
        """
        flatten = gray.flatten() / 255.0
        """
        we need to store the flattened image and generate
        the correct_vals array
        correct_val for the digit 9 would be
        [0,0,0,0,0,0,0,0,0,1]
        """
        images[i] = flatten
        correct_val = np.zeros((10))
        correct_val[no] = 1
        correct_vals[i] = correct_val
        i += 1

    """
    the prediction will be an array with one value per image,
    which shows the predicted digit
    """
    prediction = tf.argmax(y,1)
    """
    we want to run the prediction and the accuracy function
    using our generated arrays (images and correct_vals)
    """
    predicted = sess.run(prediction, feed_dict={x: images, y_: correct_vals})
    print(predicted)
    print(sess.run(accuracy, feed_dict={x: images, y_: correct_vals}))
    return predicted

if __name__ == '__main__':
    train_and_predict([0, 8, 4, 3])
--------------------------------------------------------------------------------

/predict_interface.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import input_data
import cv2
import numpy as np
from scipy import ndimage
import sys
import os

class PredSet(object):
    def __init__(self, location, top_left=None, bottom_right=None, actual_w_h=None, prob_with_pred=None):
        self.location = location

        # fall back to empty lists so that every getter below always works
        if top_left is None:
            top_left = []
        self.top_left = top_left

        if bottom_right is None:
            bottom_right = []
        self.bottom_right = bottom_right

        if actual_w_h is None:
            actual_w_h = []
        self.actual_w_h = actual_w_h

        if prob_with_pred is None:
            prob_with_pred = []
        self.prob_with_pred = prob_with_pred

    def get_location(self):
        return self.location

    def get_top_left(self):
        return self.top_left

    def get_bottom_right(self):
        return self.bottom_right

    def get_actual_w_h(self):
        return self.actual_w_h

    def get_prediction(self):
        return self.prob_with_pred[1]

    def get_probability(self):
        return self.prob_with_pred[0]


def getBestShift(img):
    cy,cx = ndimage.center_of_mass(img)
    print(cy,cx)

    rows,cols = img.shape
    shiftx = np.round(cols/2.0-cx).astype(int)
    shifty = np.round(rows/2.0-cy).astype(int)

    return shiftx,shifty


def shift(img,sx,sy):
    rows,cols = img.shape
    M = np.float32([[1,0,sx],[0,1,sy]])
    shifted = cv2.warpAffine(img,M,(cols,rows))
    return shifted


def pred_from_img(image, train):
    """
    a placeholder for our image data:
    None stands for an unspecified number of images
    784 = 28*28 pixel
    """
    x = tf.placeholder("float", [None, 784])

    # we need our weights for our neural net
    W = tf.Variable(tf.zeros([784,10]))
    # and the biases
    b = tf.Variable(tf.zeros([10]))

    """
    softmax provides a probability-based output
    we need to multiply the image values x and the weights
    and add the biases
    (the normal procedure, explained in previous articles)
    """
    y = tf.nn.softmax(tf.matmul(x,W) + b)

    """
    y_ will be filled with the real values
    which we want to train (digits 0-9)
    for an undefined number of images
    """
    y_ = tf.placeholder("float", [None,10])

    """
    we use the cross_entropy function
    which we want to minimize to improve our model
    """
    cross_entropy = -tf.reduce_sum(y_*tf.log(y))

    """
    use a learning rate of 0.01
    to minimize the cross_entropy error
    """
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)


    checkpoint_dir = "cps/"

    saver = tf.train.Saver()
    sess = tf.Session()
    # initialize all variables and run init
    sess.run(tf.global_variables_initializer())
    if train:
        print("TRAIN!!!")
        # create a MNIST_data folder with the MNIST dataset if necessary
        mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

        # use 1000 batches with a size of 100 each to train our net
        for i in range(1000):
            batch_xs, batch_ys = mnist.train.next_batch(100)
            # run the train_step function with the given image values (x) and the real output (y_)
            sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
        saver.save(sess, checkpoint_dir+'model.ckpt')
        """
        Let's get the accuracy of our model:
        our model is correct if the index with the highest y value
        is the same as in the real digit vector.
        The mean of the correct_prediction gives us the accuracy.
        We need to run the accuracy function
        with our test set (mnist.test)
        We use the keys "images" and "labels" for x and y_
        """
        correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
        print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
    else:
        # Here's where you're restoring the variables W and b.
        # Note that the graph is exactly as it was when the variables were
        # saved in a prior training run.
148 |         ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
149 |         if ckpt and ckpt.model_checkpoint_path:
150 |             saver.restore(sess, ckpt.model_checkpoint_path)
151 |         else:
152 |             print('No checkpoint found')
153 |             exit(1)
154 |
155 |         mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
156 |         correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
157 |         accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
158 |         print("accuracy: ", sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
159 |
160 |
161 |
162 |     if not os.path.exists("img/" + image + ".png"):
163 |         print("File img/" + image + ".png doesn't exist")
164 |         exit(1)
165 |
166 |     # read original image
167 |     color_complete = cv2.imread("img/" + image + ".png")
168 |
169 |     print("read", "img/" + image + ".png")
170 |     # read the bw image
171 |     gray_complete = cv2.imread("img/" + image + ".png", 0)
172 |
173 |     # better black and white version
174 |     _, gray_complete = cv2.threshold(255-gray_complete, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
175 |
176 |     if not os.path.exists("pro-img"):
177 |         os.makedirs("pro-img")
178 |
179 |     cv2.imwrite("pro-img/compl.png", gray_complete)
180 |
181 |     digit_image = -np.ones(gray_complete.shape)
182 |
183 |     height, width = gray_complete.shape
184 |
185 |     predSet_ret = []
186 |
187 |     """
188 |     crop into several images
189 |     """
190 |     for cropped_width in range(100, 300, 20):
191 |         for cropped_height in range(100, 300, 20):
192 |             for shift_x in range(0, width-cropped_width, int(cropped_width/4)):
193 |                 for shift_y in range(0, height-cropped_height, int(cropped_height/4)):
194 |                     gray = gray_complete[shift_y:shift_y+cropped_height,shift_x:shift_x + cropped_width]
195 |                     if np.count_nonzero(gray) <= 20:
196 |                         continue
197 |
198 |                     if (np.sum(gray[0]) != 0) or (np.sum(gray[:,0]) != 0) or (np.sum(gray[-1]) != 0) or (np.sum(gray[:,
199 |                                                                                                             -1]) != 0):
200 |                         continue
201 |
202 |                     top_left = np.array([shift_y, shift_x])
203 |                     bottom_right = np.array([shift_y+cropped_height, shift_x + cropped_width])
204 |
205 |                     while np.sum(gray[0]) == 0:
206 |                         top_left[0] += 1
207 |                         gray = gray[1:]
208 |
209 |                     while np.sum(gray[:,0]) == 0:
210 |                         top_left[1] += 1
211 |                         gray = np.delete(gray,0,1)
212 |
213 |                     while np.sum(gray[-1]) == 0:
214 |                         bottom_right[0] -= 1
215 |                         gray = gray[:-1]
216 |
217 |                     while np.sum(gray[:,-1]) == 0:
218 |                         bottom_right[1] -= 1
219 |                         gray = np.delete(gray,-1,1)
220 |
221 |                     actual_w_h = bottom_right-top_left
222 |                     if (np.count_nonzero(digit_image[top_left[0]:bottom_right[0],top_left[1]:bottom_right[1]]+1) >
223 |                             0.2*actual_w_h[0]*actual_w_h[1]):
224 |                         continue
225 |
226 |                     print("------------------")
227 |                     print("------------------")
228 |
229 |                     rows,cols = gray.shape
230 |                     compl_dif = abs(rows-cols)
231 |                     half_Sm = int(compl_dif/2)
232 |                     half_Big = half_Sm if half_Sm*2 == compl_dif else half_Sm+1
233 |                     if rows > cols:
234 |                         gray = np.pad(gray,((0,0),(half_Sm,half_Big)),'constant')
235 |                     else:
236 |                         gray = np.pad(gray,((half_Sm,half_Big),(0,0)),'constant')
237 |
238 |                     gray = cv2.resize(gray, (20, 20))
239 |                     gray = np.pad(gray,((4,4),(4,4)),'constant')
240 |
241 |
242 |                     shiftx,shifty = getBestShift(gray)
243 |                     shifted = shift(gray,shiftx,shifty)
244 |                     gray = shifted
245 |
246 |                     cv2.imwrite("pro-img/"+image+"_"+str(shift_x)+"_"+str(shift_y)+".png", gray)
247 |
248 |                     """
249 |                     all images in the training set have a range from 0-1
250 |                     and not from 0-255, so we divide our flattened image
251 |                     (a one-dimensional vector with our 784 pixels) by 255
252 |                     to use the same 0-1 based range
253 |                     """
254 |                     flatten = gray.flatten() / 255.0
255 |
256 |
257 |                     print("Prediction for ",(shift_x, shift_y, cropped_width))
258 |                     print("Pos")
259 |                     print(top_left)
260 |                     print(bottom_right)
261 |                     print(actual_w_h)
262 |                     print(" ")
263 |                     prediction = [tf.reduce_max(y),tf.argmax(y,1)[0]]
264 |                     pred = sess.run(prediction, feed_dict={x: [flatten]})
265 |                     print(pred)
266 |
267 |                     predSet_ret.append(PredSet((shift_x, shift_y, cropped_width),
268 |                                                top_left,
269 |                                                bottom_right,
270 |                                                actual_w_h,
271 |                                                pred))
272 |
273 |
274 |                     digit_image[top_left[0]:bottom_right[0],top_left[1]:bottom_right[1]] = pred[1]
275 |
276 |                     cv2.rectangle(color_complete,tuple(top_left[::-1]),tuple(bottom_right[::-1]),color=(0,255,0),thickness=5)
277 |
278 |                     font = cv2.FONT_HERSHEY_SIMPLEX
279 |                     cv2.putText(color_complete,str(pred[1]),(top_left[1],bottom_right[0]+50),
280 |                                 font,fontScale=1.4,color=(0,255,0),thickness=4)
281 |                     cv2.putText(color_complete,format(pred[0]*100,".1f")+"%",(top_left[1]+30,bottom_right[0]+60),
282 |                                 font,fontScale=0.8,color=(0,255,0),thickness=2)
283 |
284 |
285 |     cv2.imwrite("pro-img/"+image+"_digitized_image.png", color_complete)
286 |     return predSet_ret
287 |
288 |
289 |
290 |
--------------------------------------------------------------------------------
/predict_interface_usage.py:
--------------------------------------------------------------------------------
1 | import predict_interface
2 | import numpy as np
3 | import sys
4 |
5 | """
6 | How to use the predict_interface module
7 | """
8 |
9 | # usage: python predict_interface_usage.py IMAGE_NAME [TRAIN]
10 | image = sys.argv[1]
11 | train = False if len(sys.argv) == 2 else sys.argv[2]
12 |
13 | ret_vals = predict_interface.pred_from_img(image, train)
14 |
15 | print(" ")
16 | print("----------")
17 |
18 | for pred_set in ret_vals:
19 |     print("Prediction: " + str(pred_set.get_prediction()) + ", Probability: " + str(pred_set.get_probability()))
20 |     print("Location in image: x(" + str(pred_set.get_location()[0]) + "), y(" + str(pred_set.get_location()[1]) + ")")
21 |     print("Top left corner (shifted image): x(" + str(pred_set.get_top_left()[1]) + "), y(" + str(pred_set.get_top_left()[0]) + ")")
22 |     print("Bottom right corner (shifted image): x(" + str(pred_set.get_bottom_right()[1]) + "), y(" + str(pred_set.get_bottom_right()[0]) + ")")
23 |     print("Actual width and height (cropped image): w(" + str(pred_set.get_actual_w_h()[1]) + "), h(" + str(pred_set.get_actual_w_h()[0]) + ")")
24 |     print(" ")
25 |     print("----------")
26 |
27 | print("A modified image with the predictions: pro-img/" + image + "_digitized_image.png")
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.5.0
2 | astor==0.7.1
3 | gast==0.2.0
4 | grpcio==1.15.0
5 | h5py==2.8.0
6 | Keras-Applications==1.0.6
7 | Keras-Preprocessing==1.0.5
8 | Markdown==3.0.1
9 | numpy==1.15.2
10 | opencv-python==3.4.3.18
11 | pandas==0.23.4
12 | protobuf==3.6.1
13 | python-dateutil==2.7.3
14 | pytz==2018.5
15 | scipy==1.1.0
16 | six==1.11.0
17 | tensorboard==1.11.0
18 | tensorflow>=1.12.1,<2.0
19 | termcolor==1.1.0
20 | Werkzeug==0.14.1
21 |
--------------------------------------------------------------------------------
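The per-crop preprocessing in `pred_from_img` follows a fixed sequence: trim empty borders, pad to a square, resize to 20x20, pad to 28x28, shift by the centre of mass, and scale to 0-1. That sequence can be sketched in plain NumPy. This is a minimal sketch under stated assumptions, not the repository's code. All function names are hypothetical. Nearest-neighbour resizing stands in for `cv2.resize`, which interpolates bilinearly by default. A hand-written zero-filled shift replaces `scipy.ndimage.center_of_mass` plus `cv2.warpAffine`.

```python
import numpy as np


def trim_to_bounding_box(gray):
    """Drop all-zero border rows/columns (what the while-loops in pred_from_img do)."""
    rows = np.flatnonzero(gray.any(axis=1))
    cols = np.flatnonzero(gray.any(axis=0))
    if rows.size == 0:
        raise ValueError("image contains no non-zero pixels")
    return gray[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]


def pad_to_square(gray):
    """Pad the shorter side with zeros; the extra pixel of an odd difference goes last."""
    rows, cols = gray.shape
    diff = abs(rows - cols)
    small, big = diff // 2, diff - diff // 2
    if rows > cols:
        return np.pad(gray, ((0, 0), (small, big)), "constant")
    return np.pad(gray, ((small, big), (0, 0)), "constant")


def resize_nearest(gray, size):
    """Hypothetical nearest-neighbour stand-in for cv2.resize(gray, (size, size))."""
    rows, cols = gray.shape
    r_idx = np.arange(size) * rows // size
    c_idx = np.arange(size) * cols // size
    return gray[np.ix_(r_idx, c_idx)]


def center_by_mass(img):
    """Shift so the centre of mass lands at (rows/2, cols/2); vacated pixels become 0.

    Assumes img has at least one non-zero pixel.
    """
    rows, cols = img.shape
    r, c = np.indices(img.shape)
    total = img.sum()
    cy = (r * img).sum() / total
    cx = (c * img).sum() / total
    sy = int(np.round(rows / 2.0 - cy))
    sx = int(np.round(cols / 2.0 - cx))
    out = np.zeros_like(img)
    # copy the overlapping region; unlike np.roll, nothing wraps around the edges
    dst = (slice(max(sy, 0), rows + min(sy, 0)), slice(max(sx, 0), cols + min(sx, 0)))
    src = (slice(max(-sy, 0), rows + min(-sy, 0)), slice(max(-sx, 0), cols + min(-sx, 0)))
    out[dst] = img[src]
    return out


def to_mnist_format(gray):
    """Trim, square, resize to 20x20, pad to 28x28, centre, flatten, and scale to [0, 1]."""
    digit = pad_to_square(trim_to_bounding_box(gray))
    digit = resize_nearest(digit, 20)
    digit = np.pad(digit, ((4, 4), (4, 4)), "constant")
    digit = center_by_mass(digit.astype(np.float64))
    return digit.flatten() / 255.0
```

`np.flatnonzero(gray.any(axis=...))` finds the bounding box in one pass. The `while` loops in `pred_from_img` instead delete one row or column at a time. The result is the same, but the vectorised version avoids repeated array copies.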