├── .gitignore ├── Makefile ├── README.md ├── __init__.py ├── datasets ├── __init__.py ├── cifar_10 │ ├── cifar10_pylearn2_gca_to_tfrecords.py │ ├── cifar10_to_tfrecord.py │ └── cifar10_utils.py ├── mnist │ ├── download_mnist_dataset.py │ └── mnist_utils.py └── utils.py ├── environment.yml ├── models ├── __init__.py └── binaryconnect.py ├── optimisers.py ├── run_with_args.py ├── run_with_yaml.py └── train_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.tfrecord* 2 | _MNIST_DATA/* 3 | _sessions/ 4 | _pinned_sessions/ 5 | *tfevents* 6 | *.meta 7 | *.index 8 | *checkpoint* 9 | log.txt 10 | tmux-* 11 | ._* 12 | .DS_Store 13 | *.pyc 14 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: help 2 | help: 3 | @echo "Make targets are:" 4 | @echo " make help - shows this message" 5 | @echo " make download-cifar10 - Downloads the pylearn2 GCA-whitened CIFAR-10 dataset." 6 | @echo " make clean-last-run - removes last run's intermediate files." 7 | @echo " make clean - removes intermediate files." 8 | 9 | .PHONY: download-cifar10 10 | download-cifar10: 11 | @echo Downloading ... 12 | @(curl -o archive.zip -L 'https://www.dropbox.com/sh/ygbx8ckcjwk0jlj/AAAtNQojrcx0A5Fc8FLZ3iTia?dl=1') 13 | @echo Extracting ... 14 | @(unzip archive.zip -d datasets/cifar_10/pylearn2_gca_whitened -x / && rm archive.zip) 15 | @echo Checking the MD5 sums ... 16 | @(cd datasets/cifar_10 && md5sum --check pylearn2_gca_whitened.md5 && cd - > /dev/null) 17 | 18 | .PHONY: create-conda-env 19 | create-conda-env: 20 | @(conda env create --file environment.yml) 21 | 22 | .PHONY: clean-last-run 23 | clean-last-run: 24 | @(ls -d -t -1 ./_sessions/** | head -n 1 | xargs rm -rf) 25 | @(rm ./_sessions/latest) 26 | 27 | .PHONY: clean 28 | clean: 29 | @(rm -rf _sessions) 30 | @(rm -rf _MNIST_DATA) 31 | @(find . -type d -name __pycache__ | xargs rm -rf) 32 | @(find . -type d -name .ipynb_checkpoints | xargs rm -rf) 33 | @(rm -f tmux*) 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## An Empirical study of Binary Neural Networks' Optimisation 2 | 3 | The source code used for experiments in the paper "[An Empirical study of Binary Neural Networks' Optimisation](https://openreview.net/forum?id=rJfUCoR5KX)". 4 | 5 | The code grew organically as we tweaked more and more hyperparameters. Had I been more familiar with class-based declarations in TensorFlow (or embraced PyTorch sooner), the code would have been more elegant. 6 | 7 | ### Environment 8 | This code has only been tested with TensorFlow 1.8.0 and Python 3.5.4. The exact environment can be replicated with: 9 | 10 | `$ conda env create -f environment.yml` 11 | 12 | This creates a conda environment called `studying-bnns`.
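The same environment can also be created with the `create-conda-env` target in the Makefile. For the CIFAR-10 experiments, the Makefile can additionally download the pylearn2 GCA-whitened CIFAR-10 data, extract it into `datasets/cifar_10/pylearn2_gca_whitened` and check the MD5 sums:

```bash
# Both targets are defined in the Makefile at the repository root
$ make create-conda-env
$ make download-cifar10
```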
13 | 14 | ### Usage 15 | 16 | ```bash 17 | $ conda activate studying-bnns 18 | 19 | # Run an experiment by passing args 20 | $ python run_with_args.py --model binary_connect_mlp \ 21 | --dataset mnist --epochs 250 --batch-size 100 \ 22 | --binarization deterministic-binary 23 | 24 | # Run an experiment defined in a YAML file 25 | python run_with_yaml.py some_experiment.yaml 26 | ``` 27 | 28 | An example experiment definition in a YAML file: 29 | 30 | 31 | ```yaml 32 | experiment-name: some_experiment 33 | 34 | model: binary_connect_cnn 35 | dataset: cifar10 36 | epochs: 500 37 | batch_size: 50 38 | 39 | binarization: deterministic-binary 40 | 41 | learning_rate: 42 | type: exponential-decay 43 | start: 3e-3 44 | finish: 2e-6 45 | 46 | loss: 'square_hinge_loss' 47 | 48 | optimiser: 49 | function: tf.train.AdamOptimizer 50 | kwargs: "{'beta1': 0.9, 'beta2': 0.999}" 51 | ``` 52 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mil-ad/studying-binary-neural-networks/6a59d0e7f1f83eb906e7163e90f8ec46791c452e/__init__.py -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from datasets.cifar_10.cifar10_utils import CIFAR10, CIFAR10_GCN_WHITENED 2 | from datasets.mnist.mnist_utils import MNIST 3 | -------------------------------------------------------------------------------- /datasets/cifar_10/cifar10_pylearn2_gca_to_tfrecords.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, absolute_import 2 | 3 | import os 4 | import sys 5 | import re 6 | import random 7 | from glob import glob 8 | import argparse 9 | 10 | import tensorflow as tf 11 | import numpy as np 12 | 13 | from pylearn2.datasets.zca_dataset import ZCA_Dataset 14 | from pylearn2.utils import serial 15 | 16 | 17 | def _float_feature(value): 18 | """Wrapper for inserting float features into Example proto.""" 19 | # if not isinstance(value, list): 20 | # value = [value] 21 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 22 | 23 | 24 | def _int64_feature(value): 25 | """Wrapper for inserting int64 features into Example proto.""" 26 | if not isinstance(value, list): 27 | value = [value] 28 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 29 | 30 | 31 | def _bytes_feature(value): 32 | """Wrapper for inserting byte features into Example proto.""" 33 | if not isinstance(value, list): 34 | value = [value] 35 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) 36 | 37 | 38 | def _convert_to_example_proto(label, image): 39 | """ 40 | Build an Example proto for an example.
41 | """ 42 | example = tf.train.Example(features=tf.train.Features(feature={ 43 | 'label': _int64_feature(label), 44 | 'image': _bytes_feature(image)})) 45 | 46 | return example 47 | 48 | 49 | def create_tfrecords(name, dataset, output_dir): 50 | 51 | output_filename = os.path.join(output_dir, '{}.tfrecords'.format(name)) 52 | 53 | with tf.python_io.TFRecordWriter(output_filename) as writer: 54 | 55 | for item in zip(dataset.y, dataset.X): 56 | 57 | example = _convert_to_example_proto(np.squeeze(item[0]), 58 | item[1].tobytes()) 59 | 60 | writer.write(example.SerializeToString()) 61 | 62 | 63 | if __name__ == '__main__': 64 | 65 | print("Generating .tfrecords files ...") 66 | 67 | preprocessor = serial.load("/datasets/pylearn2_gcn_whitened/preprocessor.pkl") 68 | 69 | train_set = ZCA_Dataset( 70 | preprocessed_dataset=serial.load("/datasets/pylearn2_gcn_whitened/train.pkl"), 71 | preprocessor=preprocessor, 72 | start=0, stop=45000) 73 | valid_set = ZCA_Dataset( 74 | preprocessed_dataset=serial.load("/datasets/pylearn2_gcn_whitened/train.pkl"), 75 | preprocessor=preprocessor, 76 | start=45000, stop=50000) 77 | test_set = ZCA_Dataset( 78 | preprocessed_dataset=serial.load("/datasets/pylearn2_gcn_whitened/test.pkl"), 79 | preprocessor=preprocessor) 80 | 81 | output_dir = '/datasets/cifar_10/pylearn2_tfrecords' 82 | 83 | create_tfrecords('train', train_set, output_dir) 84 | create_tfrecords('val', valid_set, output_dir) 85 | create_tfrecords('test', test_set, output_dir) 86 | -------------------------------------------------------------------------------- /datasets/cifar_10/cifar10_to_tfrecord.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import glob 4 | import sys 5 | 6 | from six.moves import urllib 7 | import tarfile 8 | import shutil 9 | 10 | _DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' 11 | _FILE_NAME = 'cifar-10-python.tar.gz' 12 | _FOLDER_NAME_AFTER_UNTAR = 'cifar-10-batches-py' 13 | 14 | FLAGS = tf.app.flags.FLAGS 15 | tf.app.flags.DEFINE_string( 16 | 'dataset_dir', '.', 17 | 'The directory where you want the dataset to be downloaded to and converted.') 18 | 19 | 20 | def unpickle(file): 21 | import cPickle 22 | with open(file, 'rb') as fo: 23 | dict = cPickle.load(fo) 24 | return dict 25 | 26 | 27 | def loadData(path_to_downloaded_dataset): 28 | 29 | print("Loading data...") 30 | num_train_data_batches = 5 31 | 32 | for i in range(num_train_data_batches): 33 | data_dict = unpickle(path_to_downloaded_dataset + 'data_batch_' + str(i+1)) 34 | if i == 0: 35 | images_train = data_dict['data'] 36 | labels_train = np.asarray(data_dict['labels']) 37 | else: 38 | images_train = np.append(images_train, data_dict['data'], axis=0) 39 | labels_train = np.append(labels_train, data_dict['labels']) 40 | 41 | print("Image_train data shape:", images_train.shape) 42 | print("Labels_train data shape:", labels_train.shape) 43 | 44 | # now we do the same for the test set 45 | data_dict = unpickle(path_to_downloaded_dataset + 'test_batch') 46 | images_test = data_dict['data'] 47 | labels_test = np.asarray(data_dict['labels']) 48 | print("Image_test data shape:", images_test.shape) 49 | print("Labels_test data shape:", labels_test.shape) 50 | 51 | return images_train, labels_train, images_test, labels_test 52 | 53 | 54 | def _int64_feature(value): 55 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 56 | 57 | 58 | def _bytes_feature(value): 59 | return 
tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 60 | 61 | 62 | def reshapeImage(img_flat): 63 | img_R = img_flat[0:1024].reshape((32, 32)) 64 | img_G = img_flat[1024:2048].reshape((32, 32)) 65 | img_B = img_flat[2048:3072].reshape((32, 32)) 66 | return np.dstack((img_R, img_G, img_B)) 67 | 68 | 69 | def generateTfRecordFile(fileName, images, labels): 70 | 71 | print("Generating Tfrecord file...") 72 | 73 | if images.shape[0] != labels.shape[0]: 74 | print(" %d and %d" % (images.shape[0], labels.shape[0])) 75 | raise ValueError("dimensions mismatch!!") 76 | 77 | writer = tf.python_io.TFRecordWriter(fileName) 78 | 79 | for i in range(images.shape[0]): 80 | 81 | img = reshapeImage(images[i].astype('f')) 82 | 83 | feature = {'label': _int64_feature(labels[i]), 84 | 'image': _bytes_feature(tf.compat.as_bytes(img.tostring()))} 85 | 86 | example = tf.train.Example(features=tf.train.Features(feature=feature)) 87 | 88 | writer.write(example.SerializeToString()) 89 | 90 | writer.close() 91 | sys.stdout.flush() 92 | print("Generated: %s" % fileName) 93 | 94 | 95 | def download_dataset(path): 96 | 97 | file_path = path + _FILE_NAME 98 | 99 | def _progress(count, block_size, total_size): 100 | sys.stdout.write('\r>> Downloading %.1f%%' % (float(count * block_size) / float(total_size) * 100.0)) 101 | sys.stdout.flush() 102 | 103 | print(_DATA_URL) 104 | print(file_path) 105 | filepath, _ = urllib.request.urlretrieve(_DATA_URL, file_path, _progress) 106 | print() 107 | 108 | with tf.gfile.GFile(filepath) as f: 109 | size = f.size() 110 | print('Successfully downloaded', _FILE_NAME, size, 'bytes.') 111 | 112 | 113 | def convert(path_to_dataset): 114 | 115 | if not tf.gfile.Exists(path_to_dataset): 116 | tf.gfile.MakeDirs(path_to_dataset) 117 | 118 | train_filename = path_to_dataset + 'train.tfrecord' 119 | test_filename = path_to_dataset + 'test.tfrecord' 120 | 121 | if tf.gfile.Exists(train_filename) and tf.gfile.Exists(test_filename): 122 | print('Dataset files already exist. 
Exiting without re-creating them.') 123 | return 124 | 125 | file_name = path_to_dataset + _FILE_NAME 126 | if not tf.gfile.Exists(file_name): 127 | download_dataset(path_to_dataset) 128 | print("Uncompressing dataset file...") 129 | tarfile.open(file_name, 'r:gz').extractall(path_to_dataset) 130 | else: 131 | print("Downloaded dataset found: %s" % file_name) 132 | print("Uncompressing dataset file...") 133 | tarfile.open(file_name, 'r:gz').extractall(path_to_dataset) 134 | 135 | uncompressed_data_dir = path_to_dataset+_FOLDER_NAME_AFTER_UNTAR+'/' 136 | images_train, labels_train, images_test, labels_test = loadData(uncompressed_data_dir) 137 | 138 | generateTfRecordFile(train_filename, images_train, labels_train) 139 | generateTfRecordFile(test_filename, images_test, labels_test) 140 | 141 | print("Deleting directory with uncompressed dataset from URL: %s" % uncompressed_data_dir) 142 | shutil.rmtree(uncompressed_data_dir) 143 | 144 | 145 | if __name__ == '__main__': 146 | print("Downloading and converting CIFAR-10 dataset to/in: %s" % FLAGS.dataset_dir) 147 | convert(FLAGS.dataset_dir) 148 | -------------------------------------------------------------------------------- /datasets/cifar_10/cifar10_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | import tensorflow as tf 6 | 7 | from datasets.utils import DatasetBase 8 | 9 | NUM_THREADS = 28 10 | 11 | 12 | class CIFAR10_GCN_WHITENED(DatasetBase): 13 | def __init__(self, num_epochs, batch_size): 14 | 15 | DatasetBase.__init__(self, 16 | name='CIFAR-10-GCN-WHITENED', 17 | num_classes=10, 18 | num_train_samples=45000, 19 | num_val_samples=5000, 20 | num_test_samples=10000, 21 | image_size=32, 22 | channels=3) 23 | 24 | train_dataset = self._prepare_dataset( 25 | '/datasets/cifar10_gca_whitened/train.tfrecords', 26 | batch_size, num_epochs) 27 | 28 | val_dataset = self._prepare_dataset( 29 | '/datasets/cifar10_gca_whitened/val.tfrecords', batch_size) 30 | 31 | test_dataset = self._prepare_dataset( 32 | '/datasets/cifar10_gca_whitened/test.tfrecords', batch_size) 33 | 34 | self._create_iterators(train_dataset, val_dataset, test_dataset) 35 | 36 | def _prepare_dataset(self, tfrecord_path, batch_size, repeat=1): 37 | 38 | dataset = tf.data.TFRecordDataset(tfrecord_path) 39 | 40 | dataset = dataset.apply( 41 | tf.contrib.data.shuffle_and_repeat(10000, repeat)) 42 | 43 | dataset = dataset.apply( 44 | tf.contrib.data.map_and_batch(self._cifar10_mapper, 45 | batch_size, 46 | num_parallel_batches=NUM_THREADS)) 47 | 48 | dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE) 49 | 50 | return dataset 51 | 52 | @staticmethod 53 | def _cifar10_mapper(dataset): 54 | 55 | feature_map = { 56 | 'image': tf.FixedLenFeature([], dtype=tf.string), 57 | 'label': tf.FixedLenFeature([], dtype=tf.int64) 58 | } 59 | 60 | parsed_features = tf.parse_single_example(dataset, feature_map) 61 | 62 | image = tf.decode_raw(parsed_features['image'], tf.float64) 63 | image = tf.reshape(image, [3, 32, 32]) 64 | image = tf.transpose(image, [1, 2, 0]) 65 | 66 | label = tf.cast(parsed_features['label'], tf.int32) 67 | one_hot_label = tf.one_hot(label, 10) 68 | one_hot_label = tf.squeeze(one_hot_label) 69 | 70 | # This is due to a bug in TF: 71 | # https://github.com/tensorflow/tensorflow/issues/18355 72 | return tf.convert_to_tensor(''), image, one_hot_label 73 | 74 | 75 | class CIFAR10(DatasetBase): 76 | def 
__init__(self, num_epochs, batch_size, validation_size=0): 77 | 78 | DatasetBase.__init__(self, 79 | name='CIFAR10', 80 | num_classes=10, 81 | num_train_samples=50000, 82 | num_val_samples=0, 83 | num_test_samples=10000, 84 | image_size=32, 85 | channels=3) 86 | 87 | train_dataset = tf.data.TFRecordDataset( 88 | './datasets/cifar_10/by_javier/train.tfrecord') 89 | test_dataset = tf.data.TFRecordDataset( 90 | './datasets/cifar_10/by_javier/test.tfrecord') 91 | 92 | train_dataset, val_dataset = self._split_for_validation( 93 | train_dataset, validation_size) 94 | 95 | train_dataset = self._prepare_dataset(train_dataset) 96 | test_dataset = self._prepare_dataset(test_dataset) 97 | 98 | train_dataset = train_dataset.repeat(num_epochs) 99 | 100 | train_dataset = train_dataset.batch(batch_size) 101 | val_dataset = val_dataset.batch(batch_size) 102 | test_dataset = test_dataset.batch(batch_size) 103 | 104 | self._create_iterators(train_dataset, val_dataset, test_dataset) 105 | 106 | def _prepare_dataset(self, dataset): 107 | 108 | dataset = dataset.map(self._cifar10_mapper) 109 | 110 | dataset = dataset.shuffle(buffer_size=100) 111 | 112 | return dataset 113 | 114 | @staticmethod 115 | def _cifar10_mapper(dataset): 116 | 117 | feature = {'image': tf.FixedLenFeature([], tf.string), 118 | 'label': tf.FixedLenFeature([], tf.int64)} 119 | 120 | parsed_features = tf.parse_single_example(dataset, feature) 121 | 122 | image = tf.decode_raw(parsed_features['image'], tf.float32) 123 | image = tf.reshape(image, [32, 32, 3]) 124 | 125 | label = tf.cast(parsed_features['label'], tf.int32) 126 | one_hot_label = tf.one_hot(label, 10) 127 | one_hot_label = tf.squeeze(one_hot_label) 128 | 129 | return '', image, one_hot_label 130 | -------------------------------------------------------------------------------- /datasets/mnist/download_mnist_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """tf.data.Dataset interface to the MNIST dataset.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import gzip 21 | import os 22 | import shutil 23 | import tempfile 24 | 25 | import numpy as np 26 | from six.moves import urllib 27 | import tensorflow as tf 28 | 29 | 30 | def read32(bytestream): 31 | """Read 4 bytes from bytestream as an unsigned 32-bit integer.""" 32 | dt = np.dtype(np.uint32).newbyteorder('>') 33 | return np.frombuffer(bytestream.read(4), dtype=dt)[0] 34 | 35 | 36 | def check_image_file_header(filename): 37 | """Validate that filename corresponds to images for the MNIST dataset.""" 38 | with tf.gfile.Open(filename, 'rb') as f: 39 | magic = read32(f) 40 | read32(f) # num_images, unused 41 | rows = read32(f) 42 | cols = read32(f) 43 | if magic != 2051: 44 | raise ValueError('Invalid magic number %d in MNIST file %s' % (magic, 45 | f.name)) 46 | if rows != 28 or cols != 28: 47 | raise ValueError( 48 | 'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' % 49 | (f.name, rows, cols)) 50 | 51 | 52 | def check_labels_file_header(filename): 53 | """Validate that filename corresponds to labels for the MNIST dataset.""" 54 | with tf.gfile.Open(filename, 'rb') as f: 55 | magic = read32(f) 56 | read32(f) # num_items, unused 57 | if magic != 2049: 58 | raise ValueError('Invalid magic number %d in MNIST file %s' % (magic, 59 | f.name)) 60 | 61 | 62 | def download(directory, filename): 63 | """Download (and unzip) a file from the MNIST dataset if not already done.""" 64 | filepath = os.path.join(directory, filename) 65 | if tf.gfile.Exists(filepath): 66 | return filepath 67 | if not tf.gfile.Exists(directory): 68 | tf.gfile.MakeDirs(directory) 69 | # CVDF mirror of http://yann.lecun.com/exdb/mnist/ 70 | url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz' 71 | _, zipped_filepath = tempfile.mkstemp(suffix='.gz') 72 | print('Downloading %s to %s' % (url, zipped_filepath)) 73 | urllib.request.urlretrieve(url, zipped_filepath) 74 | with gzip.open(zipped_filepath, 'rb') as f_in, \ 75 | tf.gfile.Open(filepath, 'wb') as f_out: 76 | shutil.copyfileobj(f_in, f_out) 77 | os.remove(zipped_filepath) 78 | return filepath 79 | 80 | 81 | def dataset(directory, images_file, labels_file): 82 | """Download and parse MNIST dataset.""" 83 | 84 | images_file = download(directory, images_file) 85 | labels_file = download(directory, labels_file) 86 | 87 | check_image_file_header(images_file) 88 | check_labels_file_header(labels_file) 89 | 90 | def decode_image(image): 91 | # Normalize from [0, 255] to [0.0, 1.0] 92 | image = tf.decode_raw(image, tf.uint8) 93 | image = tf.cast(image, tf.float32) 94 | image = tf.reshape(image, [784]) 95 | return image / 255.0 96 | 97 | def decode_label(label): 98 | label = tf.decode_raw(label, tf.uint8) # tf.string -> [tf.uint8] 99 | label = tf.reshape(label, []) # label is a scalar 100 | return tf.to_int32(label) 101 | 102 | images = tf.data.FixedLengthRecordDataset( 103 | images_file, 28 * 28, header_bytes=16).map(decode_image) 104 | labels = tf.data.FixedLengthRecordDataset( 105 | labels_file, 1, header_bytes=8).map(decode_label) 106 | return tf.data.Dataset.zip((images, labels)) 107 | 108 | 109 | def train(directory): 110 | """tf.data.Dataset object for MNIST training data.""" 111 | return dataset(directory, 'train-images-idx3-ubyte', 112 | 'train-labels-idx1-ubyte') 113 | 114 | 115 | def test(directory): 116 | """tf.data.Dataset object for 
MNIST test data.""" 117 | return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte') 118 | -------------------------------------------------------------------------------- /datasets/mnist/mnist_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | import sys 6 | 7 | import tensorflow as tf 8 | from tensorflow.examples.tutorials.mnist import input_data as mnist_data 9 | 10 | # This is a copy from TF's official models. It's supposed to become available 11 | # as a pip package at some point : 12 | # See [this](https://github.com/tensorflow/models/issues/917) 13 | from . import download_mnist_dataset 14 | from datasets.utils import DatasetBase 15 | 16 | NUM_THREADS = 28 17 | 18 | 19 | class HiddenPrints: 20 | """Used to disable TF's annoying prints when reading MNIST data""" 21 | def __enter__(self): 22 | self._stdout_old = sys.stdout 23 | sys.stdout = None 24 | 25 | def __exit__(self, exc_type, exc_val, exc_tb): 26 | sys.stdout = self._stdout_old 27 | 28 | 29 | class MNIST(DatasetBase): 30 | def __init__(self, num_epochs, batch_size): 31 | 32 | DatasetBase.__init__(self, 33 | name='MNIST', 34 | num_classes=10, 35 | num_train_samples=60000, 36 | num_val_samples=0, 37 | num_test_samples=10000, 38 | image_size=28, 39 | channels=1) 40 | 41 | with HiddenPrints(): 42 | train_dataset = download_mnist_dataset.train('./_MNIST_DATA/') 43 | test_dataset = download_mnist_dataset.test('./_MNIST_DATA/') 44 | 45 | # Set aside 5000 images for validation 46 | train_dataset, val_dataset = self._split_for_validation( 47 | train_dataset, 5000) 48 | 49 | train_dataset = self._prepare_dataset( 50 | train_dataset, batch_size, num_epochs) 51 | val_dataset = self._prepare_dataset(val_dataset, batch_size) 52 | test_dataset = self._prepare_dataset(test_dataset, batch_size) 53 | 54 | self._create_iterators(train_dataset, val_dataset, test_dataset) 55 | 56 | def _prepare_dataset(self, dataset, batch_size, repeat=1): 57 | 58 | dataset = dataset.apply( 59 | tf.contrib.data.shuffle_and_repeat(10000, repeat)) 60 | 61 | dataset = dataset.apply( 62 | tf.contrib.data.map_and_batch(self._mapper, 63 | batch_size, 64 | num_parallel_batches=NUM_THREADS)) 65 | 66 | return dataset 67 | 68 | @staticmethod 69 | def _mapper(image, label): 70 | image = tf.reshape(image, [28, 28, 1]) 71 | image = tf.image.per_image_standardization(image) 72 | label = tf.cast(label, tf.int32) 73 | one_hot_label = tf.one_hot(label, 10) 74 | one_hot_label = tf.squeeze(one_hot_label) 75 | 76 | return tf.convert_to_tensor(''), image, one_hot_label 77 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division, print_function, absolute_import 3 | 4 | import tensorflow as tf 5 | 6 | 7 | class DatasetBase: 8 | """ 9 | Usage: 10 | 11 | Before a dataset object can be used it needs to evaluate dataset 12 | handles inside session: 13 | 14 | dataset.evaluate_handles(sess) 15 | 16 | Once the handles are evaluated you can get elements from the training 17 | dataset by evaluating dataset.next_element and passing the dataset 18 | selector: 19 | 20 | _, images, labels = sess.run( 21 | dataset.next_element, 22 | feed_dict=dataset.from_training_set()) 23 | 24 | The validation and testing datasets are 
re-initialisable so that you 25 | can go through them multiple times during training. In order to switch 26 | to validation dataset to run some evaluation first initialize it and 27 | then get elements: 28 | 29 | dataset.initialize_validation(sess) 30 | _, images, labels = sess.run( 31 | dataset.next_element, 32 | feed_dict=dataset.from_validation_set()) 33 | 34 | To switch back to the next element from training dataset simply use 35 | from_training_set() selector. """ 36 | 37 | def __init__(self, name, num_classes, num_train_samples, 38 | num_val_samples, num_test_samples, image_size, channels): 39 | self._name = name 40 | self._num_classes = num_classes 41 | self._num_train_samples = num_train_samples 42 | self._num_val_samples = num_val_samples 43 | self._num_test_samples = num_test_samples 44 | self._image_size = image_size 45 | self._channels = channels 46 | 47 | def _split_for_validation(self, train_dataset, val_size): 48 | 49 | assert(self._num_val_samples == 0) 50 | 51 | self._num_train_samples -= val_size 52 | self._num_val_samples = val_size 53 | 54 | val_dataset = train_dataset.skip(self.num_train_samples) 55 | train_dataset = train_dataset.take(self.num_train_samples) 56 | 57 | return train_dataset, val_dataset 58 | 59 | def _create_iterators(self, train_dataset, val_dataset, test_dataset): 60 | 61 | self._handle = tf.placeholder(tf.string, shape=[]) 62 | iterator = tf.data.Iterator.from_string_handle( 63 | self._handle, 64 | train_dataset.output_types, 65 | train_dataset.output_shapes) 66 | self._next_element = iterator.get_next() 67 | 68 | self._training_iterator = train_dataset.make_one_shot_iterator() 69 | self._validation_iterator = val_dataset.make_initializable_iterator() 70 | self._testing_iterator = test_dataset.make_initializable_iterator() 71 | 72 | def evaluate_handles(self, sess): 73 | """ 74 | Evaluate the tensors that are used to feed the `handle` placeholder to 75 | switch between datasets. 
76 | """ 77 | self._training_handle = sess.run( 78 | self._training_iterator.string_handle()) 79 | self._validation_handle = sess.run( 80 | self._validation_iterator.string_handle()) 81 | self._testing_handle = sess.run( 82 | self._testing_iterator.string_handle()) 83 | 84 | def initialize_validation(self, sess): 85 | sess.run(self._validation_iterator.initializer) 86 | 87 | def initialize_testing(self, sess): 88 | sess.run(self._testing_iterator.initializer) 89 | 90 | def from_training_set(self): 91 | return {self._handle: self._training_handle} 92 | 93 | def from_validation_set(self): 94 | return {self._handle: self._validation_handle} 95 | 96 | def from_testing_set(self): 97 | return {self._handle: self._testing_handle} 98 | 99 | @property 100 | def name(self): 101 | return self._name 102 | 103 | @property 104 | def num_classes(self): 105 | return self._num_classes 106 | 107 | @property 108 | def num_train_samples(self): 109 | return self._num_train_samples 110 | 111 | @property 112 | def num_val_samples(self): 113 | return self._num_val_samples 114 | 115 | @property 116 | def num_test_samples(self): 117 | return self._num_test_samples 118 | 119 | @property 120 | def image_size(self): 121 | return self._image_size 122 | 123 | @property 124 | def channels(self): 125 | return self._channels 126 | 127 | @property 128 | def next_element(self): 129 | return self._next_element 130 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | # platform: linux-64 2 | 3 | name: studying-bnns 4 | dependencies: 5 | - tensorflow-gpu=1.8.0 6 | - pyyaml 7 | - python=3.5.4 8 | - pip 9 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.alexnet import AlexNet 2 | -------------------------------------------------------------------------------- /models/binaryconnect.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Courbariaux, Matthieu, Yoshua Bengio, and Jean-Pierre David. "Binaryconnect: 5 | Training deep neural networks with binary weights during propagations." 6 | Advances in neural information processing systems. 2015. 7 | """ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from __future__ import print_function 12 | 13 | import tensorflow as tf 14 | 15 | import math 16 | 17 | 18 | def lr_mult(alpha): 19 | @tf.custom_gradient 20 | def _lr_mult(x): 21 | def grad(dy): 22 | return dy * alpha * tf.ones_like(x) 23 | return x, grad 24 | return _lr_mult 25 | 26 | 27 | class BinaryConnect(object): 28 | """The base class defining core binary operations used in BinaryConnect. 
29 | """ 30 | def __init__(self, is_training, BN_momentum, binary=True, stochastic=False, 31 | weight_decay=0.0, 32 | disable_weight_constraint=False, disable_gradient_clipping=False, 33 | enable_glorot_scaling=False): 34 | 35 | self.regularizer = tf.contrib.layers.l2_regularizer(weight_decay) 36 | self.initializer = tf.contrib.layers.xavier_initializer() 37 | 38 | self.is_binary = binary 39 | self.is_stochastic = stochastic 40 | self.is_training = is_training 41 | self.BN_momentum = BN_momentum 42 | self.enable_glorot_scaling = enable_glorot_scaling 43 | 44 | if disable_weight_constraint or not binary: 45 | self._weight_constraint = None 46 | else: 47 | self._weight_constraint = lambda x: tf.clip_by_value(x, -1, 1) 48 | 49 | self.disable_gradient_clipping = disable_gradient_clipping 50 | 51 | @property 52 | def output(self): 53 | return self._output 54 | 55 | @staticmethod 56 | def glorot_LR_scale(x): 57 | 58 | shape = x.get_shape().as_list() 59 | 60 | fan_in = float(shape[-2]) if len(shape) > 1 else float(shape[-1]) 61 | fan_out = float(shape[-1]) 62 | n = (fan_in + fan_out) / 2.0 63 | limit = math.sqrt(1.5 / n) 64 | 65 | return lr_mult(1.0/limit) 66 | 67 | @staticmethod 68 | def deterministic_binary_op(input_op): 69 | g = tf.get_default_graph() 70 | with g.gradient_override_map({"Sign": "Identity"}): 71 | x = tf.clip_by_value(input_op, -1.0, 1.0) 72 | return tf.sign(x) 73 | 74 | @staticmethod 75 | def deterministic_binary_op_pure_ste(input_op): 76 | """No gradient clipping""" 77 | g = tf.get_default_graph() 78 | with g.gradient_override_map({"Sign": "Identity"}): 79 | return tf.sign(input_op) 80 | 81 | @staticmethod 82 | def stochastic_binary_op(input_op): 83 | p = tf.clip_by_value((input_op+1.0)/2, 0, 1) # Hard sigmoid 84 | 85 | forward_path = (2. * tf.cast( 86 | tf.greater(p, tf.random_uniform(tf.shape(p))), tf.float32)) - 1. 87 | 88 | backward_path = tf.clip_by_value((input_op), -1.0, 1.0) 89 | 90 | return backward_path + tf.stop_gradient(forward_path - backward_path) 91 | 92 | def binarize(self, input_op): 93 | """Binarizes weights in the forward pass and uses Straight-Through 94 | Estimator in the backwards pass (with hard limits to stop gradients 95 | flowing backwards when the input is too large) 96 | """ 97 | assert (self.is_binary is True) # sanity check 98 | 99 | if self.is_stochastic: 100 | # In stochastic scenario, BinaryConnect only uses binarization to 101 | # achieve faster training and during test time the real-valued 102 | # weights are used. 
103 | out_op = tf.where(self.is_training, self.stochastic_binary_op(input_op), input_op) 104 | elif self.disable_gradient_clipping: 105 | out_op = self.deterministic_binary_op_pure_ste(input_op) 106 | else: 107 | out_op = self.deterministic_binary_op(input_op) 108 | 109 | if self.enable_glorot_scaling: 110 | return self.glorot_LR_scale(out_op)(out_op) 111 | else: 112 | return out_op 113 | 114 | def _get_weights(self, shape, name): 115 | 116 | w_full = tf.get_variable(name=name, 117 | shape=shape, 118 | initializer=self.initializer, 119 | regularizer=self.regularizer, 120 | constraint=self._weight_constraint) 121 | 122 | # tf.summary.histogram('weights full-precision', w_full) 123 | 124 | if self.is_binary: 125 | w_bin = self.binarize(w_full) 126 | # tf.summary.histogram('weights binary', w_bin) 127 | return w_bin 128 | else: 129 | return w_full 130 | 131 | def _batch_norm_layer(self, x): 132 | if self.BN_momentum is None: 133 | return tf.identity(x, 'batch_norm_bypass') 134 | else: 135 | return tf.layers.batch_normalization( 136 | x, axis=-1, epsilon=1e-4, center=True, scale=True, 137 | momentum=self.BN_momentum, training=self.is_training) 138 | 139 | 140 | class MLP(BinaryConnect): 141 | """Multi-Layer Perceptron used for MNIST. No convolution layers. 142 | """ 143 | def __init__(self, input_op, is_training, keep_prob, num_classes, 144 | binary, stochastic, units_per_layer=2048, weight_decay=0.0, 145 | BN_momentum=0.85, 146 | disable_weight_constraint=False, disable_gradient_clipping=False, 147 | enable_glorot_scaling=False): 148 | 149 | assert(keep_prob == 1.0) 150 | 151 | BinaryConnect.__init__(self, is_training, BN_momentum, binary, 152 | stochastic, weight_decay, 153 | disable_weight_constraint, disable_gradient_clipping, 154 | enable_glorot_scaling) 155 | 156 | self._output = self._build_model(input_op, units_per_layer, 157 | num_classes, keep_prob) 158 | 159 | @property 160 | def name(self): 161 | return 'BinaryConnect_MLP' 162 | 163 | def _build_model(self, input_op, units_per_layer, num_classes, keep_prob): 164 | 165 | model = tf.layers.flatten(input_op) # preserves the batch axis 166 | 167 | model = tf.nn.dropout(model, keep_prob) 168 | 169 | model = self._dense("fc_layer1", model, units_per_layer, keep_prob) 170 | model = self._dense("fc_layer2", model, units_per_layer, keep_prob) 171 | model = self._dense("fc_layer3", model, units_per_layer, keep_prob) 172 | 173 | model = self._last_layer(model, num_classes) 174 | 175 | model = tf.identity(model, 'model_output') # just to give it a name 176 | 177 | return model 178 | 179 | def _dense(self, name, input_op, num_units, keep_prob): 180 | 181 | input_dim = input_op.get_shape().as_list()[-1] 182 | 183 | with tf.variable_scope(name) as scope: 184 | 185 | layer = self._get_weights([input_dim, num_units], 'weights') 186 | # tf.summary.histogram('weights binary', w_binary) 187 | layer = tf.matmul(input_op, layer) 188 | layer = self._batch_norm_layer(layer) 189 | layer = tf.nn.relu(layer) 190 | layer = tf.nn.dropout(layer, keep_prob) 191 | 192 | return layer 193 | 194 | def _last_layer(self, input_op, num_classes): 195 | 196 | input_dim = input_op.get_shape().as_list()[-1] 197 | 198 | with tf.variable_scope('fc_final') as scope: 199 | 200 | weights = self._get_weights( 201 | [input_dim, num_classes], name='weights') 202 | layer = tf.matmul(input_op, weights) 203 | layer = self._batch_norm_layer(layer) 204 | 205 | return layer 206 | 207 | 208 | class CNN(BinaryConnect): 209 | def __init__(self, input_op, is_training, num_classes, 210 | 
binary, stochastic, weight_decay=0.0, BN_momentum=0.9, 211 | disable_weight_constraint=False, disable_gradient_clipping=False, 212 | enable_glorot_scaling=False): 213 | 214 | BinaryConnect.__init__(self, 215 | is_training, 216 | BN_momentum, 217 | binary, 218 | stochastic, 219 | weight_decay, 220 | disable_weight_constraint, 221 | disable_gradient_clipping, 222 | enable_glorot_scaling) 223 | 224 | self._output = self._build_model(input_op, num_classes) 225 | 226 | @property 227 | def name(self): 228 | return 'BinaryConnect_CNN' 229 | 230 | def _build_model(self, input_op, num_classes): 231 | 232 | model = self._conv_layer('conv_1', input_op, 128) 233 | model = self._conv_layer('conv_2', model, 128, pool=True) 234 | model = self._conv_layer('conv_3', model, 256) 235 | model = self._conv_layer('conv_4', model, 256, pool=True) 236 | model = self._conv_layer('conv_5', model, 512) 237 | model = self._conv_layer('conv_6', model, 512, pool=True) 238 | 239 | model = tf.layers.flatten(model) # preserves the batch axis 240 | 241 | model = self._dense('fc1', model, 1024) 242 | model = self._dense('fc2', model, 1024) 243 | model = self._dense('fc3', model, num_classes, relu=False) 244 | 245 | model = tf.identity(model, 'model_output') # Just to give it a name 246 | 247 | # tf.summary.histogram('final_activations', model) 248 | 249 | return model 250 | 251 | def _conv_layer(self, name, input_op, out_channels, pool=False): 252 | 253 | in_channels = input_op.get_shape().as_list()[-1] 254 | 255 | with tf.variable_scope(name) as scope: 256 | 257 | weights = self._get_weights([3, 3, in_channels, out_channels], 258 | 'weights') 259 | 260 | layer = tf.nn.conv2d(input=input_op, 261 | filter=weights, 262 | strides=[1, 1, 1, 1], 263 | padding='SAME') 264 | 265 | if pool: 266 | layer = tf.nn.max_pool(layer, 267 | ksize=[1, 2, 2, 1], 268 | strides=[1, 2, 2, 1], 269 | padding='VALID') 270 | 271 | layer = self._batch_norm_layer(layer) 272 | 273 | layer = tf.nn.relu(layer) 274 | 275 | return layer 276 | 277 | def _dense(self, name, input_op, num_units, relu=True): 278 | 279 | input_dim = input_op.get_shape().as_list()[-1] 280 | 281 | with tf.variable_scope(name) as scope: 282 | 283 | layer = self._get_weights([input_dim, num_units], 'weights') 284 | # tf.summary.histogram('weights_binary', w_binary) 285 | layer = tf.matmul(input_op, layer) 286 | layer = self._batch_norm_layer(layer) 287 | 288 | if relu: 289 | layer = tf.nn.relu(layer) 290 | 291 | return layer 292 | -------------------------------------------------------------------------------- /optimisers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | 4 | 5 | """ 6 | from __future__ import division, print_function, absolute_import 7 | import tensorflow as tf 8 | 9 | 10 | def square_hinge_loss(labels, predictions): 11 | """TF has builtin hinge loss function but not the squared version. 12 | There are multiple definitions for multi-class hinge loss; this one is 13 | based on the implementation in BinaryConnect/BinaryNets papers. 
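Concretely, the one-hot labels are first mapped from {0, 1} to {-1, +1}; for polar labels y and raw model outputs p the loss implemented below is mean(square(maximum(0, 1 - y * p))), averaged over both the class and batch dimensions.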
14 | """ 15 | polar_labels = tf.cast((labels*2)-1, tf.float32) # [0,1] -> [-1,1] 16 | 17 | hinge_loss = tf.maximum(0.0, 1.0-tf.multiply(polar_labels, predictions)) 18 | 19 | return tf.reduce_mean(tf.square(hinge_loss)) 20 | 21 | 22 | def binary_connect_optimiser(global_step_op, NUM_EPOCHS, steps_per_epoch, 23 | labels, model_output, start_lr, finish_lr): 24 | 25 | # Apply exponential LR decay at the end of each epoch 26 | decaye_rate = (finish_lr/start_lr)**(1.0/NUM_EPOCHS) 27 | learning_rate = tf.train.exponential_decay(start_lr, global_step_op, 28 | decay_steps=steps_per_epoch, 29 | decay_rate=decaye_rate, 30 | staircase=True) 31 | 32 | # No explicit weight regularization so the only loss is square hinge loss 33 | total_loss = square_hinge_loss(labels, model_output) 34 | 35 | # This is necessary because of tf.layers.batch_normalization() 36 | # See https://tensorflow.org/api_docs/python/tf/layers/batch_normalization 37 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 38 | with tf.control_dependencies(update_ops): 39 | train_op = tf.train.AdamOptimizer(learning_rate).minimize(total_loss, global_step_op) 40 | 41 | tf.summary.scalar('Total loss', total_loss) 42 | tf.summary.scalar('Learning Rate', learning_rate) 43 | 44 | return train_op, total_loss 45 | 46 | 47 | def alexnet_optimiser(global_step_op, labels, model_output): 48 | # NUM_EPOCHS = 90 49 | # BATCH_SIZE = 128 50 | # AlexNet: Divide learning rate by 10 when the validation error rate stops 51 | # improving. 52 | # TODO: We're not following above here because I don't know when validation 53 | # error rate stops imporving. For now let's just decay by %25 every 100k steps 54 | INITIAL_LEARNING_RATE = 0.01 55 | learning_rate = tf.train.exponential_decay( 56 | INITIAL_LEARNING_RATE, global_step_op, 57 | decay_steps=100000, decay_rate=0.75, staircase=True) 58 | 59 | MOMENTUM = 0.9 60 | WEIGHT_DECAY = 0.0005 61 | loss = tf.losses.softmax_cross_entropy(labels, model_output) 62 | reg_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) 63 | total_loss = loss + reg_loss 64 | 65 | optimizer = tf.train.MomentumOptimizer(learning_rate, MOMENTUM) 66 | train_op = optimizer.minimize(total_loss, global_step_op) 67 | 68 | tf.summary.scalar('classifier loss', loss) 69 | tf.summary.scalar('reg loss', reg_loss) 70 | tf.summary.scalar('total loss', total_loss) 71 | tf.summary.scalar('Learning Rate', learning_rate) 72 | 73 | return train_op, total_loss 74 | -------------------------------------------------------------------------------- /run_with_args.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Top-level file. Parses the passed arguments to select the model, dataset and 4 | optimisation and passes them to the training routine. 
5 | """ 6 | from __future__ import division, print_function, absolute_import 7 | 8 | import sys 9 | import argparse 10 | import tensorflow as tf 11 | 12 | from datasets import CIFAR10, CIFAR10_GCN_WHITENED 13 | from datasets import MNIST 14 | 15 | from models import AlexNet 16 | import models.binaryconnect as binaryconnect 17 | 18 | from optimisers import binary_connect_optimiser, alexnet_optimiser 19 | 20 | from train_utils import train 21 | 22 | _RANDOM_SEED = 1234 23 | 24 | 25 | def get_dataset(dataset_name, num_epochs, batch_size): 26 | 27 | if dataset_name == 'mnist': 28 | dataset = MNIST(num_epochs, batch_size) 29 | elif dataset_name == 'cifar10': 30 | # dataset = CIFAR10(num_epochs, batch_size, validation_size=5000) 31 | dataset = CIFAR10_GCN_WHITENED(num_epochs, batch_size) 32 | else: 33 | raise ValueError('Dataset option not valid.') 34 | 35 | return dataset 36 | 37 | 38 | def get_model_fn(model_name, binarization): 39 | 40 | if binarization == 'deterministic-binary': 41 | binary = True 42 | stochastic = False 43 | elif binarization == 'stochastic-binary': 44 | binary = True 45 | stochastic = True 46 | elif binarization == 'disabled': 47 | binary = False 48 | stochastic = False 49 | else: 50 | print('ERROR!') # TODO 51 | 52 | def model_fn(input_images, num_classes, is_training, keep_prob): 53 | if model_name == 'alexnet': 54 | model = AlexNet(input_images, keep_prob, num_classes, 55 | weight_decay=0.0005) 56 | 57 | elif model_name == 'binary_connect_mlp': 58 | model = binaryconnect.MLP(input_images, is_training, 59 | 1.0, num_classes, 60 | binary, stochastic) 61 | 62 | elif model_name == 'binary_connect_cnn': 63 | # Paper settings: 500 epochs, batch size 50 64 | model = binaryconnect.CNN(input_images, is_training, num_classes, 65 | binary, stochastic) 66 | 67 | return model 68 | 69 | return model_fn 70 | 71 | 72 | def get_optimiser_fn(model_name, num_epochs, batch_size, dataset): 73 | 74 | steps_per_epoch = dataset.num_train_samples // batch_size 75 | 76 | def optimiser_fn(labels, model_output): 77 | 78 | global_step_op = tf.train.get_global_step() 79 | 80 | if model_name == 'alexnet': 81 | train_op, loss = alexnet_optimiser( 82 | global_step_op, labels, model_output) 83 | 84 | elif model_name in ['binary_connect_mlp', 'binarynet_mlp']: 85 | train_op, loss = binary_connect_optimiser( 86 | global_step_op, num_epochs, steps_per_epoch, labels, 87 | model_output, 1e-3, 3e-6 88 | ) 89 | 90 | elif model_name == 'binary_connect_cnn': 91 | train_op, loss = binary_connect_optimiser( 92 | global_step_op, num_epochs, steps_per_epoch, labels, 93 | model_output, 3e-3, 2e-6 94 | ) 95 | 96 | else: 97 | print("Error!") 98 | 99 | return train_op, loss 100 | 101 | return optimiser_fn 102 | 103 | 104 | def args_parser(args): 105 | parser = argparse.ArgumentParser(description='TODO') 106 | 107 | parser.add_argument('-m', '--model', choices=[ 108 | 'alexnet', 'xnornet', 109 | 'binary_connect_mlp', 'binary_connect_cnn', 110 | 'binarynet_mlp'], 111 | action='store', required=True, help='TODO') 112 | 113 | parser.add_argument('-d', '--dataset', choices=['mnist', 'cifar10', 114 | 'cifar100', 'imagenet'], 115 | action='store', required=True, help='TODO') 116 | 117 | parser.add_argument('-e', '--epochs', action='store', default=250, 118 | type=int, help='Number of Epochs (Default: 250)') 119 | 120 | parser.add_argument('-b', '--batch-size', action='store', default=100, 121 | type=int, help='Batch Size (Default: 100') 122 | 123 | parser.add_argument('-r', '--resume-from-latest-checkpoint', 124 | 
action='store_true', required=False, help='TODO') 125 | 126 | parser.add_argument('-t', '--tag', action='store', required=False, 127 | help='Set a tag for the test run. Overrides default unique name') 128 | 129 | parser.add_argument('-f', '--freeze', action='store_true', required=False, 130 | help='Freeze the model after training.') 131 | 132 | parser.add_argument('--binarization', 133 | choices=['deterministic-binary', 134 | 'stochastic-binary', 135 | 'disabled'], 136 | action='store', 137 | required=False, 138 | default='deterministic-binary', 139 | help='binarization mode') 140 | 141 | return parser.parse_args() 142 | 143 | 144 | if __name__ == '__main__': 145 | 146 | tf.set_random_seed(_RANDOM_SEED) 147 | 148 | parsed_args = args_parser(sys.argv) 149 | 150 | dataset = get_dataset(parsed_args.dataset, parsed_args.epochs, 151 | parsed_args.batch_size) 152 | 153 | train(parsed_args.epochs, 154 | parsed_args.batch_size, 155 | dataset, 156 | get_model_fn(parsed_args.model, parsed_args.binarization), 157 | get_optimiser_fn(parsed_args.model, parsed_args.epochs, parsed_args.batch_size, dataset), 158 | parsed_args.resume_from_latest_checkpoint, 159 | parsed_args.tag, 160 | parsed_args.freeze) 161 | -------------------------------------------------------------------------------- /run_with_yaml.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, absolute_import 2 | 3 | import yaml 4 | import sys 5 | 6 | import tensorflow as tf 7 | 8 | from datasets import CIFAR10, CIFAR10_GCN_WHITENED 9 | from datasets import MNIST 10 | 11 | from models import AlexNet 12 | import models.binaryconnect as binaryconnect 13 | 14 | from train_utils import train 15 | 16 | _RANDOM_SEED = 1234 17 | 18 | 19 | def get_dataset(dataset_name, num_epochs, batch_size): 20 | 21 | if dataset_name == 'mnist': 22 | dataset = MNIST(num_epochs, batch_size) 23 | elif dataset_name == 'cifar10': 24 | # dataset = CIFAR10(num_epochs, batch_size, validation_size=5000) 25 | dataset = CIFAR10_GCN_WHITENED(num_epochs, batch_size) 26 | else: 27 | raise ValueError('Dataset option not valid.') 28 | 29 | return dataset 30 | 31 | 32 | def get_model_fn(model_name, binarization, disable_batch_norm, disable_weight_constraint, disable_gradient_clipping, 33 | enable_glorot_scaling): 34 | 35 | if binarization == 'deterministic-binary': 36 | binary = True 37 | stochastic = False 38 | elif binarization == 'stochastic-binary': 39 | binary = True 40 | stochastic = True 41 | elif binarization == 'disabled': 42 | binary = False 43 | stochastic = False 44 | else: 45 | print('ERROR!') # TODO 46 | 47 | kwargs = {} 48 | 49 | if disable_batch_norm is True: 50 | kwargs['BN_momentum'] = None 51 | 52 | kwargs['disable_weight_constraint'] = disable_weight_constraint 53 | kwargs['disable_gradient_clipping'] = disable_gradient_clipping 54 | kwargs['enable_glorot_scaling'] = enable_glorot_scaling 55 | 56 | def model_fn(input_images, num_classes, is_training, keep_prob): 57 | if model_name == 'alexnet': 58 | model = AlexNet(input_images, keep_prob, num_classes, 59 | weight_decay=0.0005) 60 | 61 | elif model_name == 'binary_connect_mlp': 62 | model = binaryconnect.MLP(input_images, is_training, 63 | 1.0, num_classes, 64 | binary, stochastic, **kwargs) 65 | 66 | elif model_name == 'binary_connect_cnn': 67 | # Paper settings: 500 epochs, batch size 50 68 | model = binaryconnect.CNN(input_images, is_training, num_classes, 69 | binary, stochastic, **kwargs) 70 | 71 | return model 72 | 73 | 
return model_fn 74 | 75 | 76 | def get_learning_rate_fn(config, dataset): 77 | 78 | if config['learning_rate']['type'] == 'exponential-decay': 79 | 80 | start_lr = float(config['learning_rate']['start']) 81 | finish_lr = float(config['learning_rate']['finish']) 82 | num_epochs = int(config['epochs']) 83 | steps_per_epoch = dataset.num_train_samples // config['batch_size'] 84 | 85 | def exponential_decay_lr(): 86 | 87 | global_step_op = tf.train.get_global_step() 88 | 89 | decaye_rate = (finish_lr/start_lr)**(1.0/num_epochs) 90 | 91 | learning_rate = tf.train.exponential_decay(start_lr, global_step_op, 92 | decay_steps=steps_per_epoch, 93 | decay_rate=decaye_rate, 94 | staircase=True) 95 | 96 | return(learning_rate) 97 | 98 | return exponential_decay_lr 99 | 100 | elif config['learning_rate']['type'] == 'piecewise_constant': 101 | 102 | steps_per_epoch = dataset.num_train_samples // config['batch_size'] 103 | 104 | def piecewise_lr(): 105 | 106 | global_step_op = tf.train.get_global_step() 107 | 108 | epoch_boundaries = eval(config['learning_rate']['epoch_boundaries']) 109 | step_boundaries = [epoch * steps_per_epoch for epoch in epoch_boundaries] 110 | 111 | values = eval(config['learning_rate']['values']) 112 | 113 | learning_rate = tf.train.piecewise_constant(global_step_op, step_boundaries, values) 114 | 115 | return learning_rate 116 | 117 | return piecewise_lr 118 | 119 | else: 120 | assert(0) 121 | 122 | 123 | def get_loss_fn(config): 124 | 125 | # Make sure it gives loss operator the name: total_loss 126 | 127 | if config['loss'] == 'square_hinge_loss': 128 | 129 | def square_hinge_loss(labels, predictions): 130 | """TF has builtin hinge loss function but not the squared version. 131 | There are multiple definitions for multi-class hinge loss; this one is 132 | based on the implementation in BinaryConnect/BinaryNets papers. 
133 | """ 134 | polar_labels = tf.cast((labels*2)-1, tf.float32) # [0,1] -> [-1,1] 135 | 136 | hinge_loss = tf.maximum(0.0, 1.0-tf.multiply(polar_labels, predictions)) 137 | 138 | return tf.reduce_mean(tf.square(hinge_loss), name='total_loss') 139 | 140 | return square_hinge_loss 141 | 142 | elif config['loss'] == 'softmax_cross_entropy': 143 | 144 | def loss_fn(labels, predictions): 145 | return tf.losses.softmax_cross_entropy(labels, predictions) 146 | 147 | return loss_fn 148 | 149 | else: 150 | assert(0) 151 | 152 | 153 | def get_optimiser_fn(config, lr_fn, loss_fn): 154 | 155 | def optimiser_fn(actual_labels, model_output): 156 | 157 | learning_rate = lr_fn() 158 | total_loss = loss_fn(actual_labels, model_output) 159 | optimiser_op = eval(config['optimiser']['function'])(learning_rate, **eval(config['optimiser']['kwargs'])) 160 | 161 | # This is necessary because of tf.layers.batch_normalization() 162 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 163 | with tf.control_dependencies(update_ops): 164 | train_op = optimiser_op.minimize(total_loss, tf.train.get_global_step()) 165 | 166 | tf.summary.scalar('Total loss', total_loss) 167 | tf.summary.scalar('Learning Rate', learning_rate) 168 | 169 | return train_op, total_loss 170 | 171 | return optimiser_fn 172 | 173 | 174 | if __name__ == '__main__': 175 | 176 | with open(sys.argv[1], 'r', ) as f: 177 | config = yaml.load(f) 178 | 179 | if config.get('fixed_seed', True): 180 | tf.set_random_seed(_RANDOM_SEED) 181 | 182 | dataset = get_dataset(config['dataset'], config['epochs'], config['batch_size']) 183 | model_fn = get_model_fn( 184 | config['model'], config['binarization'], 185 | config.get('disable_batch_norm', False), 186 | config.get('disable_weight_constraint', False), 187 | config.get('disable_gradient_clipping', False), 188 | config.get('enable_glorot_scaling', False), 189 | ) 190 | 191 | lr_fn = get_learning_rate_fn(config, dataset) 192 | loss_fn = get_loss_fn(config) 193 | optimiser_fn = get_optimiser_fn(config, lr_fn, loss_fn) 194 | 195 | train(config['epochs'], 196 | config['batch_size'], 197 | dataset, 198 | model_fn, 199 | optimiser_fn, 200 | False, 201 | config['experiment-name'], 202 | False) 203 | -------------------------------------------------------------------------------- /train_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division, print_function, absolute_import 3 | 4 | import os 5 | import sys 6 | import subprocess 7 | from datetime import datetime 8 | from pprint import pprint, pformat 9 | import logging 10 | from glob import glob 11 | from collections import namedtuple 12 | 13 | import numpy as np 14 | import tensorflow as tf 15 | 16 | _DEFAULT_SESSIONS_PATH = './_sessions/' 17 | _PRINT_SUMMARY_FREQ = 50 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | ############################################################################### 23 | # Some helper functions 24 | ############################################################################### 25 | def compute_accuracy(oneshot_labels, predictions, k): 26 | """Computes Top-k accuracy. 
Note that the behaviour of in_top_k differs from 27 | the top_k op in its handling of ties; if multiple classes have the same 28 | prediction value and straddle the top-k boundary, all of those classes are 29 | considered to be in the top k.""" 30 | correct_mask = tf.nn.in_top_k(predictions, tf.argmax(oneshot_labels, 1), 31 | k, name="top_{}_correct_mask".format(k)) 32 | return tf.reduce_mean(tf.cast(correct_mask, tf.float32)) * 100.00 33 | 34 | 35 | def get_num_trainable_params(): 36 | return sum(np.prod(p.get_shape().as_list()) 37 | for p in tf.trainable_variables()) 38 | 39 | 40 | def get_timedelta(start_time): 41 | delta = datetime.now() - start_time 42 | 43 | # timedelta objects don't have hours and seconds! 44 | hours, remainder = divmod(delta.seconds, 3600) 45 | minutes, seconds = divmod(remainder, 60) 46 | 47 | return('[{:01d} day(s), {:02d} hr(s), {:02d} min(s)]'. 48 | format(delta.days, hours, minutes)) 49 | 50 | 51 | def time_per_step(start_time, num_steps): 52 | delta = datetime.now() - start_time 53 | 54 | return (delta / num_steps).total_seconds() * 1000 55 | 56 | 57 | def get_resettable_mean_metric(values, scope): 58 | """Used for validation and test datasets where we're interested in computing 59 | accuracy over the entire dataset but due to their size still have to loop 60 | through them via batches. Since we do this multiple times during training 61 | we'd like the metric to be resettable.""" 62 | 63 | with tf.variable_scope(scope) as s: 64 | mean_op, update_mean_op = tf.metrics.mean(values) 65 | 66 | # The .* regex makes filtering work in nested scopes 67 | variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, '.*'+scope) 68 | 69 | if (len(variables) == 0): 70 | logger.error("Couldn't collect resettable mean \"{}\"".format(scope)) 71 | sys.exit() 72 | 73 | reset_mean_op = tf.variables_initializer(variables) 74 | 75 | return {'value_op': mean_op, 76 | 'update_op': update_mean_op, 77 | 'reset_op': reset_mean_op} 78 | 79 | 80 | def find_latest_chkpt(search_dir): 81 | try: 82 | 83 | sessions = glob(os.path.join(search_dir, '*/')) 84 | sessions.sort(key=os.path.getmtime, reverse=True) 85 | 86 | # We pick up index 2 from the sorted list instead of 0 or 1 because 87 | # we've already created the session directory for this run plus the 88 | # 'latest' symlink 89 | latest_session = sessions[2] 90 | 91 | latest_checkpoint = tf.train.latest_checkpoint( 92 | os.path.join(latest_session, 'train_checkpoints')) 93 | 94 | logger.info('Latest session found: {}'.format(latest_session)) 95 | logger.info('Resuming from checkpoint: {}'.format(latest_checkpoint)) 96 | 97 | return latest_checkpoint 98 | except IndexError: 99 | logger.error("Couldn't find the checkpoint.") 100 | sys.exit() 101 | 102 | 103 | def make_symbolic_link(sessions_dir, session_name): 104 | latest_symlink = os.path.join(sessions_dir, 'latest') 105 | 106 | try: 107 | os.symlink(session_name, latest_symlink) 108 | except FileExistsError: 109 | logger.debug("Replaced old symbolic link.") 110 | os.unlink(latest_symlink) 111 | os.symlink(session_name, latest_symlink) 112 | 113 | 114 | def save_frozen_model(sess, output_node_names, output_path, 115 | placeholder_values=None): 116 | """Args: 117 | placeholder_values: dictionary for injecting values into placeholders 118 | """ 119 | vars_removed_graph_def = tf.graph_util.convert_variables_to_constants( 120 | sess, tf.get_default_graph().as_graph_def(), output_node_names) 121 | 122 | # We'd like to convert unrelated placeholders to consts at this stage. 
123 | # Examples of such placeholders are "keep_prob" for Dropout layers or 124 | # "is_training" for BatchNorm layers. Unfortunately there is not an easy way 125 | # to do this when freezing the model. The only official API I'm aware of is 126 | # the "map_dict" argument in "tf.import_graph_def" method which means 127 | # placeholders values can be provided when loading the frozen models. That's 128 | # why we're doing a dummy load here to use that API and will store the 129 | # result. The downside is that calling import_graph_def() adds an extra 130 | # prefix which cannot be removed. 131 | 132 | with tf.Graph().as_default() as output_graph: 133 | tf.import_graph_def( 134 | vars_removed_graph_def, 135 | input_map=placeholder_values, 136 | name="frozen") 137 | 138 | tf.train.write_graph( 139 | output_graph, 140 | output_path, 141 | 'frozen_model.pb', 142 | as_text=False) 143 | 144 | 145 | def load_frozen_graph(frozen_graph): 146 | 147 | with open(frozen_graph, "rb") as f: 148 | restored_graph_def = tf.GraphDef() 149 | restored_graph_def.ParseFromString(f.read()) 150 | 151 | # temporarily override the current default graph 152 | with tf.Graph().as_default() as graph: 153 | tf.import_graph_def(restored_graph_def, name="") 154 | 155 | return graph 156 | 157 | 158 | def configure_logger(log_file): 159 | 160 | # Make sure we're not adding duplicate handlers if train() is called 161 | # multiple times. 162 | if not logger.handlers: 163 | 164 | logger.setLevel(logging.DEBUG) 165 | 166 | file_log_format = logging.Formatter( 167 | fmt='%(asctime)s [%(levelname)s] %(message)s', 168 | datefmt="%Y-%m-%d %H:%M:%S") 169 | 170 | console_log_format = file_log_format 171 | 172 | console_log_handler = logging.StreamHandler() 173 | console_log_handler.setFormatter(console_log_format) 174 | console_log_handler.setLevel(logging.INFO) 175 | 176 | logger.addHandler(console_log_handler) 177 | 178 | # Also store logs to a file 179 | file_log_handler = logging.FileHandler(log_file, 'w') 180 | file_log_handler.setFormatter(file_log_format) 181 | file_log_handler.setLevel(logging.DEBUG) 182 | logger.addHandler(file_log_handler) 183 | 184 | 185 | def train(num_epochs, batch_size, dataset, model_fn, optimiser_fn, 186 | resume_from_latest_checkpoint=False, tag=None, freeze_model=False): 187 | 188 | # Hopefully the caller has already set the the graph-level random seed for 189 | # reproducibility. 190 | 191 | # Generate a unique name for this run. Used to name directories for 192 | # TensorBoard and stored checkpoints. 
193 | this_session_name = datetime.now().strftime("%a-%d-%b-%I%M%p") 194 | if tag is not None: 195 | this_session_name = tag 196 | 197 | this_session_path = os.path.join(_DEFAULT_SESSIONS_PATH, this_session_name) 198 | os.makedirs(this_session_path) 199 | 200 | # Create a 'latest' symlink when necessary 201 | if len(os.listdir(_DEFAULT_SESSIONS_PATH)) > 1: 202 | make_symbolic_link(_DEFAULT_SESSIONS_PATH, this_session_name) 203 | 204 | SUMMARIES_PATH = os.path.join(this_session_path, 'summaries') 205 | TRAIN_CHECKPOINTS_PATH = os.path.join( 206 | this_session_path, 'train_checkpoints/train_checkpoint') 207 | BEST_VAL_CHECKPOINT_PATH = os.path.join( 208 | this_session_path, 'val_checkpoints/best_val_checkpoint') 209 | 210 | configure_logger(os.path.join(this_session_path, 'log.txt')) 211 | 212 | ########################################################################### 213 | # Placeholders 214 | ########################################################################### 215 | batch_images = tf.placeholder( 216 | tf.float32, 217 | [None, dataset.image_size, dataset.image_size, dataset.channels], 218 | 'model_input') 219 | batch_labels = tf.placeholder(tf.float32, [None, dataset.num_classes]) 220 | 221 | tf.summary.image('images', batch_images, max_outputs=6) 222 | 223 | is_training = tf.placeholder(tf.bool, name='is_training') 224 | keep_prob = tf.placeholder(tf.float32, name='keep_prob') 225 | 226 | ########################################################################### 227 | # The Model 228 | ########################################################################### 229 | model = model_fn(batch_images, dataset.num_classes, is_training, keep_prob) 230 | 231 | ########################################################################### 232 | # Accuracy 233 | ########################################################################### 234 | batch_top1_accuracy_op = compute_accuracy(batch_labels, model.output, 1) 235 | batch_top5_accuracy_op = compute_accuracy(batch_labels, model.output, 5) 236 | 237 | top1_mean_acc_ops = get_resettable_mean_metric(batch_top1_accuracy_op, 238 | 'val_top1_mean') 239 | top5_mean_acc_ops = get_resettable_mean_metric(batch_top5_accuracy_op, 240 | 'val_top5_mean') 241 | 242 | tf.summary.scalar('Top-1 Training Accuracy', batch_top1_accuracy_op) 243 | tf.summary.scalar('Top-5 Training Accuracy', batch_top5_accuracy_op) 244 | 245 | # Validation-related summaries are added to a separate collection so that 246 | # they can be evaluated at separate time from training summaries. 
247 | tf.summary.scalar('Validation Top-1 Accuracy', 248 | top1_mean_acc_ops['value_op'], 249 | collections=['VALIDATION_SUMMARIES']) 250 | 251 | ########################################################################### 252 | # Optimiser 253 | ########################################################################### 254 | global_step_op = tf.train.create_global_step() 255 | steps_per_epoch = dataset.num_train_samples // batch_size 256 | 257 | train_op, total_loss_op = optimiser_fn(batch_labels, model.output) 258 | 259 | ########################################################################### 260 | # TensorBoard Summaries/Checkpoints 261 | ########################################################################### 262 | train_summary_op = tf.summary.merge_all() 263 | val_summary_op = tf.summary.merge_all('VALIDATION_SUMMARIES') 264 | 265 | train_chkpt_op = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2) 266 | best_val_chkpt_op = tf.train.Saver(max_to_keep=1) 267 | 268 | ########################################################################### 269 | # Session 270 | ########################################################################### 271 | 272 | def evaluate(dataset_init_fn, dataset_select_fn): 273 | """A helper function used to compute accuracy on validation and 274 | training dataset. Ideally we'd like to do this in one go but often the 275 | entire validation set is too big to fit in memory and therefore we 276 | still need to use batches for validation as well. 277 | 278 | Args: 279 | 280 | Returns: 281 | Top 1% and Top 5% accuracies 282 | """ 283 | dataset_init_fn(sess) 284 | sess.run([top1_mean_acc_ops['reset_op'], 285 | top5_mean_acc_ops['reset_op']]) 286 | 287 | while True: 288 | try: 289 | _, images, labels = sess.run(dataset.next_element, 290 | feed_dict=dataset_select_fn()) 291 | 292 | sess.run([top1_mean_acc_ops['update_op'], 293 | top5_mean_acc_ops['update_op']], 294 | feed_dict={batch_images: images, 295 | batch_labels: labels, 296 | is_training: False, 297 | keep_prob: 1.0}) 298 | 299 | except tf.errors.OutOfRangeError: 300 | top1, top5 = sess.run([top1_mean_acc_ops['value_op'], 301 | top5_mean_acc_ops['value_op']]) 302 | break 303 | 304 | return top1, top5 305 | 306 | def main_training_loop(sess, starting_epoch, num_epochs): 307 | 308 | ################################# 309 | # The Actual Training Loop 310 | ################################# 311 | summary_writer = tf.summary.FileWriter(SUMMARIES_PATH, sess.graph) 312 | 313 | dataset.evaluate_handles(sess) 314 | 315 | train_start_time = summary_start_time = datetime.now() 316 | logger.info('Training start time: {}'.format(train_start_time)) 317 | 318 | current_step = starting_epoch * steps_per_epoch 319 | best_epoch = 0 320 | best_validation_acc = 0.0 321 | testing_acc = 0.0 322 | 323 | # TODO: Put these in the right place 324 | # --------8<---------- 325 | epoch_loss = tf.Variable(0, trainable=False, dtype=tf.float32) 326 | update_epoch_loss = tf.assign_add(epoch_loss, total_loss_op) 327 | reset_epoch_loss = tf.variables_initializer([epoch_loss]) 328 | tf.summary.scalar('Epoch Training Loss', epoch_loss, collections=['EPOCH_SUMMARIES']) 329 | epoch_summary_op = tf.summary.merge_all('EPOCH_SUMMARIES') 330 | # --------8<---------- 331 | 332 | for epoch_num in range(starting_epoch, num_epochs): 333 | 334 | sess.run(reset_epoch_loss) 335 | 336 | for batch_num in range(steps_per_epoch): 337 | 338 | _, images, labels = sess.run( 339 | dataset.next_element, 340 | feed_dict=dataset.from_training_set()) 
341 | 342 | sess.run([train_op, update_epoch_loss], feed_dict={batch_images: images, 343 | batch_labels: labels, 344 | is_training: True, 345 | keep_prob: 0.5}) 346 | 347 | if batch_num % _PRINT_SUMMARY_FREQ == 0: 348 | 349 | top1, top5, batch_summary = sess.run( 350 | [batch_top1_accuracy_op, batch_top5_accuracy_op, 351 | train_summary_op], 352 | feed_dict={batch_images: images, 353 | batch_labels: labels, 354 | is_training: False, 355 | keep_prob: 1.0}) 356 | 357 | logger.debug('time/step: {:.2f} ms'. 358 | format(time_per_step(summary_start_time, 359 | _PRINT_SUMMARY_FREQ))) 360 | 361 | logger.info( 362 | '{} Epoch {:>3} - Batch {:>4} - ' 363 | 'Batch accuracy: top-1 {:5.2f}% - top-5: {:5.2f}%' 364 | .format(get_timedelta(train_start_time), epoch_num, 365 | batch_num, top1, top5)) 366 | 367 | summary_writer.add_summary(batch_summary, current_step) 368 | summary_start_time = datetime.now() 369 | 370 | current_step += 1 371 | 372 | ######################## 373 | # End of Epoch 374 | ######################## 375 | train_chkpt_op.save(sess, TRAIN_CHECKPOINTS_PATH, global_step_op) 376 | 377 | # Evaluate accuracy on the validation dataset (and testing dataset 378 | # if observed validation is the best seen) 379 | val_acc, _ = evaluate(dataset.initialize_validation, 380 | dataset.from_validation_set) 381 | 382 | summary_writer.add_summary(sess.run(val_summary_op), current_step) 383 | summary_writer.add_summary(sess.run(epoch_summary_op), current_step) 384 | 385 | logger.info('Epoch training loss {} - Validation accuracy: top-1 {:5.2f}%'.format( 386 | epoch_loss.eval(), val_acc)) 387 | 388 | if val_acc > best_validation_acc: 389 | logger.info('New Best Validation!') 390 | 391 | best_validation_acc = val_acc 392 | best_epoch = epoch_num 393 | 394 | testing_acc, _ = evaluate(dataset.initialize_testing, 395 | dataset.from_testing_set) 396 | 397 | best_val_chkpt_op.save(sess, BEST_VAL_CHECKPOINT_PATH, global_step_op) 398 | 399 | logger.info( 400 | 'Testing accuracy to report: {:5.2f}% (error: {:5.2f}%) - ' 401 | 'Seen in epoch {}' 402 | .format(testing_acc, 100.0 - testing_acc, best_epoch)) 403 | 404 | ######################## 405 | # End of training! 406 | ######################## 407 | total_train_time = datetime.now() - train_start_time 408 | 409 | logger.info('End of training! Overall training time: {}' 410 | .format(total_train_time)) 411 | 412 | logger.info('Best observed validation accuracy: {:5.2f}%' 413 | .format(best_validation_acc)) 414 | logger.info('Testing accuracy to report: {:5.2f}% (error: {:5.2f}%) - ' 415 | 'Seen in epoch {}' 416 | .format(testing_acc, 100-testing_acc, best_epoch)) 417 | 418 | summary_writer.close() 419 | 420 | return testing_acc, best_epoch 421 | 422 | def _print_env_details(): 423 | logger.info("=====================================") 424 | """ A helper function to print env details. 
""" 425 | logger.info("Timestamp: {}".format(datetime.now())) 426 | logger.info("TensorFlow Version: {}".format(tf.VERSION)) 427 | logger.info('Session Name: "{}"'.format(this_session_name)) 428 | logger.info("Total number of trainable parameters: {:,}" 429 | .format(get_num_trainable_params())) 430 | logger.info("{} Epochs - Batch Size {}".format(num_epochs, batch_size)) 431 | logger.debug(pformat(globals())) 432 | logger.debug(pformat(locals())) 433 | logger.debug("All Trainable Parameters:") 434 | for var in tf.trainable_variables(): 435 | logger.debug('{} {}'.format(var.name, var.shape)) 436 | logger.info("=====================================") 437 | 438 | _print_env_details() 439 | 440 | with tf.Session() as sess: 441 | 442 | sess.run(tf.global_variables_initializer()) 443 | 444 | if resume_from_latest_checkpoint: 445 | train_chkpt_op.restore(sess, find_latest_chkpt(_DEFAULT_SESSIONS_PATH)) 446 | starting_epoch = global_step_op.eval() // steps_per_epoch 447 | # TODO: Restore best validation details 448 | else: 449 | starting_epoch = 0 450 | 451 | ####################################################################### 452 | # Let's train! 453 | ####################################################################### 454 | acc, best_epoch = main_training_loop(sess, starting_epoch, num_epochs) 455 | 456 | if freeze_model: 457 | logger.info("Feezing the model ...") 458 | save_frozen_model(sess, ['model_output'], this_session_path, 459 | {'is_training': False}) 460 | logger.info("Frozen model saved.") 461 | 462 | sess.close() 463 | 464 | TestResults = namedtuple('TestResults', ['session_name', 'accuracy', 465 | 'best_epoch', 'trained_model']) 466 | 467 | return TestResults(this_session_name, acc, best_epoch, model) 468 | --------------------------------------------------------------------------------