├── .gitignore
├── README.md
├── input_data.py
└── main.py

/.gitignore:
--------------------------------------------------------------------------------
.DS_Store
*.pyc
.env/
mnist/
models/
logs/
checkpoints/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Highway Networks

Starter project for training highway networks on Fomoro.

Check out the [blog post](https://medium.com/jim-fleming/highway-networks-with-tensorflow-1e6dfa667daa).

## Training

1. [Install TensorFlow](https://www.tensorflow.org/versions/r0.7/get_started/os_setup.html#pip-installation).
2. Clone the repo: `git clone https://github.com/fomorians/highway-cnn.git && cd highway-cnn`
3. Run training: `python main.py`

--------------------------------------------------------------------------------
/input_data.py:
--------------------------------------------------------------------------------
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Functions for downloading and reading MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os

import tensorflow.python.platform

import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'


def maybe_download(filename, work_directory):
  """Download the data from Yann's website, unless it's already here."""
  if not os.path.exists(work_directory):
    os.mkdir(work_directory)
  filepath = os.path.join(work_directory, filename)
  if not os.path.exists(filepath):
    filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  return filepath


def _read32(bytestream):
  dt = numpy.dtype(numpy.uint32).newbyteorder('>')
  return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]


def extract_images(filename):
  """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
  print('Extracting', filename)
  with gzip.open(filename) as bytestream:
    magic = _read32(bytestream)
    if magic != 2051:
      raise ValueError(
          'Invalid magic number %d in MNIST image file: %s' %
          (magic, filename))
    num_images = _read32(bytestream)
    rows = _read32(bytestream)
    cols = _read32(bytestream)
    buf = bytestream.read(rows * cols * num_images)
    data = numpy.frombuffer(buf, dtype=numpy.uint8)
    data = data.reshape(num_images, rows, cols, 1)
    return data


def dense_to_one_hot(labels_dense, num_classes=10):
  """Convert class labels from scalars to one-hot vectors."""
  num_labels = labels_dense.shape[0]
  index_offset = numpy.arange(num_labels) * num_classes
  labels_one_hot = numpy.zeros((num_labels, num_classes))
  labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
  return labels_one_hot
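
# Illustrative example (added for clarity; not part of the original Google
# code): for labels_dense = numpy.array([2, 0]) and num_classes=3,
# dense_to_one_hot returns
#   [[0., 0., 1.],
#    [1., 0., 0.]]
# i.e. label k becomes a row with a 1 in column k and zeros elsewhere.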


def extract_labels(filename, one_hot=False):
  """Extract the labels into a 1D uint8 numpy array [index]."""
  print('Extracting', filename)
  with gzip.open(filename) as bytestream:
    magic = _read32(bytestream)
    if magic != 2049:
      raise ValueError(
          'Invalid magic number %d in MNIST label file: %s' %
          (magic, filename))
    num_items = _read32(bytestream)
    buf = bytestream.read(num_items)
    labels = numpy.frombuffer(buf, dtype=numpy.uint8)
    if one_hot:
      return dense_to_one_hot(labels)
    return labels


class DataSet(object):

  def __init__(self, images, labels, fake_data=False, one_hot=False,
               dtype=tf.float32):
    """Construct a DataSet.

    one_hot arg is used only if fake_data is true.  `dtype` can be either
    `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
    `[0, 1]`.
    """
    dtype = tf.as_dtype(dtype).base_dtype
    if dtype not in (tf.uint8, tf.float32):
      raise TypeError('Invalid image dtype %r, expected uint8 or float32' %
                      dtype)
    if fake_data:
      self._num_examples = 10000
      self.one_hot = one_hot
    else:
      assert images.shape[0] == labels.shape[0], (
          'images.shape: %s labels.shape: %s' % (images.shape,
                                                 labels.shape))
      self._num_examples = images.shape[0]

      # Convert shape from [num examples, rows, columns, depth]
      # to [num examples, rows*columns] (assuming depth == 1)
      assert images.shape[3] == 1
      images = images.reshape(images.shape[0],
                              images.shape[1] * images.shape[2])
      if dtype == tf.float32:
        # Convert from [0, 255] -> [0.0, 1.0].
        images = images.astype(numpy.float32)
        images = numpy.multiply(images, 1.0 / 255.0)
    self._images = images
    self._labels = labels
    self._epochs_completed = 0
    self._index_in_epoch = 0
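
  # Illustrative note (added; the shapes assume the standard MNIST split
  # built in read_data_sets below, i.e. 60000 training images minus a
  # 5000-image validation set):
  #   train = DataSet(train_images, train_labels)  # images: (55000, 28, 28, 1) uint8
  #   train.images.shape  -> (55000, 784), float32 values in [0.0, 1.0]
  #   train.labels.shape  -> (55000, 10) when extracted with one_hot=True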

  @property
  def images(self):
    return self._images

  @property
  def labels(self):
    return self._labels

  @property
  def num_examples(self):
    return self._num_examples

  @property
  def epochs_completed(self):
    return self._epochs_completed

  def next_batch(self, batch_size, fake_data=False):
    """Return the next `batch_size` examples from this data set."""
    if fake_data:
      fake_image = [1] * 784
      if self.one_hot:
        fake_label = [1] + [0] * 9
      else:
        fake_label = 0
      return [fake_image for _ in xrange(batch_size)], [
          fake_label for _ in xrange(batch_size)]
    start = self._index_in_epoch
    self._index_in_epoch += batch_size
    if self._index_in_epoch > self._num_examples:
      # Finished epoch
      self._epochs_completed += 1
      # Shuffle the data
      perm = numpy.arange(self._num_examples)
      numpy.random.shuffle(perm)
      self._images = self._images[perm]
      self._labels = self._labels[perm]
      # Start next epoch
      start = 0
      self._index_in_epoch = batch_size
      assert batch_size <= self._num_examples
    end = self._index_in_epoch
    return self._images[start:end], self._labels[start:end]


def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=tf.float32):
  class DataSets(object):
    pass
  data_sets = DataSets()

  if fake_data:
    def fake():
      return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype)
    data_sets.train = fake()
    data_sets.validation = fake()
    data_sets.test = fake()
    return data_sets

  TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
  TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
  TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
  TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
  VALIDATION_SIZE = 5000

  local_file = maybe_download(TRAIN_IMAGES, train_dir)
  train_images = extract_images(local_file)

  local_file = maybe_download(TRAIN_LABELS, train_dir)
  train_labels = extract_labels(local_file, one_hot=one_hot)

  local_file = maybe_download(TEST_IMAGES, train_dir)
  test_images = extract_images(local_file)

  local_file = maybe_download(TEST_LABELS, train_dir)
  test_labels = extract_labels(local_file, one_hot=one_hot)

  validation_images = train_images[:VALIDATION_SIZE]
  validation_labels = train_labels[:VALIDATION_SIZE]
  train_images = train_images[VALIDATION_SIZE:]
  train_labels = train_labels[VALIDATION_SIZE:]

  data_sets.train = DataSet(train_images, train_labels, dtype=dtype)
  data_sets.validation = DataSet(validation_images, validation_labels,
                                 dtype=dtype)
  data_sets.test = DataSet(test_images, test_labels, dtype=dtype)

  return data_sets

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from __future__ import print_function

import os
import tensorflow as tf
import numpy as np
import input_data

flags = tf.app.flags
FLAGS = flags.FLAGS

# define flags (note that Fomoro will not pass any flags by default)
flags.DEFINE_boolean('skip-training', False, 'If true, skip training the model.')
flags.DEFINE_boolean('restore', False, 'If true, restore the model from the latest checkpoint.')

# define artifact directories where results from the session can be saved
model_path = os.environ.get('MODEL_PATH', 'models/')
checkpoint_path = os.environ.get('CHECKPOINT_PATH', 'checkpoints/')
summary_path = os.environ.get('SUMMARY_PATH', 'logs/')

mnist = input_data.read_data_sets('mnist', one_hot=True)

def weight_bias(W_shape, b_shape, bias_init=0.1):
    W = tf.Variable(tf.truncated_normal(W_shape, stddev=0.1), name='weight')
    b = tf.Variable(tf.constant(bias_init, shape=b_shape), name='bias')
    return W, b

def dense_layer(x, W_shape, b_shape, activation):
    W, b = weight_bias(W_shape, b_shape)
    return activation(tf.matmul(x, W) + b)

def conv2d_layer(x, W_shape, b_shape, strides, padding):
    W, b = weight_bias(W_shape, b_shape)
    return tf.nn.relu(tf.nn.conv2d(x, W, strides, padding) + b)
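
# Explanatory note (added; it restates what the function below computes,
# following Srivastava et al., 2015): a highway layer blends a plain conv
# activation H with its own input x using a learned transform gate T in [0, 1]:
#
#   y = H(x, W) * T(x, W_T) + x * (1 - T(x, W_T))
#
# When T is near zero the layer simply carries x through unchanged, which is
# what lets gradients flow easily through deep stacks of these layers.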
def highway_conv2d_layer(x, W_shape, b_shape, strides, padding, carry_bias=-1.0):
    W, b = weight_bias(W_shape, b_shape)
    # initialize the transform gate bias to carry_bias (negative) so that T
    # starts near zero and the layer initially just carries its input through
    W_T, b_T = weight_bias(W_shape, b_shape, carry_bias)
    H = tf.nn.relu(tf.nn.conv2d(x, W, strides, padding) + b, name='activation')
    T = tf.sigmoid(tf.nn.conv2d(x, W_T, strides, padding) + b_T, name='transform_gate')
    C = tf.sub(1.0, T, name="carry_gate")
    return tf.add(tf.mul(H, T), tf.mul(x, C), 'y')  # y = (H * T) + (x * C)

with tf.Graph().as_default(), tf.Session() as sess:
    x = tf.placeholder("float", [None, 784])
    y_ = tf.placeholder("float", [None, 10])

    carry_bias_init = -1.0

    x_image = tf.reshape(x, [-1, 28, 28, 1])  # reshape for conv

    keep_prob1 = tf.placeholder("float", name="keep_prob1")
    x_drop = tf.nn.dropout(x_image, keep_prob1)

    prev_y = conv2d_layer(x_drop, [5, 5, 1, 32], [32], [1, 1, 1, 1], 'SAME')
    prev_y = highway_conv2d_layer(prev_y, [3, 3, 32, 32], [32], [1, 1, 1, 1], 'SAME', carry_bias=carry_bias_init)
    prev_y = highway_conv2d_layer(prev_y, [3, 3, 32, 32], [32], [1, 1, 1, 1], 'SAME', carry_bias=carry_bias_init)

    prev_y = tf.nn.max_pool(prev_y, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')

    keep_prob2 = tf.placeholder("float", name="keep_prob2")
    prev_y = tf.nn.dropout(prev_y, keep_prob2)

    prev_y = highway_conv2d_layer(prev_y, [3, 3, 32, 32], [32], [1, 1, 1, 1], 'SAME', carry_bias=carry_bias_init)
    prev_y = highway_conv2d_layer(prev_y, [3, 3, 32, 32], [32], [1, 1, 1, 1], 'SAME', carry_bias=carry_bias_init)
    prev_y = highway_conv2d_layer(prev_y, [3, 3, 32, 32], [32], [1, 1, 1, 1], 'SAME', carry_bias=carry_bias_init)

    prev_y = tf.nn.max_pool(prev_y, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')

    keep_prob3 = tf.placeholder("float", name="keep_prob3")
    prev_y = tf.nn.dropout(prev_y, keep_prob3)

    prev_y = highway_conv2d_layer(prev_y, [3, 3, 32, 32], [32], [1, 1, 1, 1], 'SAME', carry_bias=carry_bias_init)
    prev_y = highway_conv2d_layer(prev_y, [3, 3, 32, 32], [32], [1, 1, 1, 1], 'SAME', carry_bias=carry_bias_init)
    prev_y = highway_conv2d_layer(prev_y, [3, 3, 32, 32], [32], [1, 1, 1, 1], 'SAME', carry_bias=carry_bias_init)

    prev_y = tf.nn.max_pool(prev_y, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

    keep_prob4 = tf.placeholder("float", name="keep_prob4")
    prev_y = tf.nn.dropout(prev_y, keep_prob4)

    prev_y = tf.reshape(prev_y, [-1, 4 * 4 * 32])
    y = dense_layer(prev_y, [4 * 4 * 32, 10], [10], tf.nn.softmax)
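
    # Shape walkthrough (added note; it follows from the ops above with SAME
    # padding): the 28x28x1 input stays 28x28x32 through the conv/highway
    # stack, then each stride-2 max pool halves the spatial size:
    # 28 -> 14 -> 7 -> 4, giving 4x4x32, so the reshape above flattens to
    # 4 * 4 * 32 = 512 features feeding the 10-way softmax layer.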

    # define training and accuracy operations
    with tf.name_scope("loss") as scope:
        loss = -tf.reduce_sum(y_ * tf.log(y))
        tf.scalar_summary("loss", loss)

    with tf.name_scope("train") as scope:
        train_step = tf.train.GradientDescentOptimizer(1e-2).minimize(loss)

    with tf.name_scope("test") as scope:
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
        tf.scalar_summary('accuracy', accuracy)

    merged_summaries = tf.merge_all_summaries()

    # create a saver instance to restore from the checkpoint
    saver = tf.train.Saver(max_to_keep=1)

    # initialize our variables
    sess.run(tf.initialize_all_variables())

    # save the graph definition as a protobuf file
    tf.train.write_graph(sess.graph_def, model_path, 'highway.pb', as_text=False)

    # restore variables
    if FLAGS.restore:
        latest_checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
        if latest_checkpoint_path:
            saver.restore(sess, latest_checkpoint_path)

    if not FLAGS.skip_training:
        summary_writer = tf.train.SummaryWriter(summary_path, sess.graph_def)

        num_steps = 5000
        checkpoint_interval = 100
        batch_size = 50

        step = 0
        for i in range(num_steps):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            if step % checkpoint_interval == 0:
                validation_accuracy, summary = sess.run([accuracy, merged_summaries], feed_dict={
                    x: mnist.validation.images,
                    y_: mnist.validation.labels,
                    keep_prob1: 1.0,
                    keep_prob2: 1.0,
                    keep_prob3: 1.0,
                    keep_prob4: 1.0,
                })
                summary_writer.add_summary(summary, step)
                saver.save(sess, checkpoint_path + 'checkpoint', global_step=step)
                print('step %d, validation accuracy %g' % (step, validation_accuracy))

            sess.run(train_step, feed_dict={
                x: batch_xs,
                y_: batch_ys,
                keep_prob1: 0.8,
                keep_prob2: 0.7,
                keep_prob3: 0.6,
                keep_prob4: 0.5,
            })

            step += 1

        summary_writer.close()

    test_accuracy = sess.run(accuracy, feed_dict={
        x: mnist.test.images,
        y_: mnist.test.labels,
        keep_prob1: 1.0,
        keep_prob2: 1.0,
        keep_prob3: 1.0,
        keep_prob4: 1.0,
    })
    print('test accuracy %g' % test_accuracy)

--------------------------------------------------------------------------------