├── .gitignore ├── .github └── FUNDING.yml ├── LICENSE ├── predict.py ├── include ├── model.py └── data.py ├── train.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | data_set 3 | tensorboard 4 | include/__pycache__ -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: exelban 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Serhiy Mytrovtsiy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorflow.compat import v1 as tf 3 | 4 | from include.data import get_data_set 5 | from include.model import model 6 | 7 | 8 | test_x, test_y = get_data_set("test") 9 | x, y, output, y_pred_cls, global_step, learning_rate = model() 10 | 11 | 12 | _BATCH_SIZE = 128 13 | _CLASS_SIZE = 10 14 | _SAVE_PATH = "./tensorboard/cifar-10-v1.0.0/" 15 | 16 | 17 | saver = tf.train.Saver() 18 | sess = tf.Session() 19 | 20 | 21 | try: 22 | print("\nTrying to restore last checkpoint ...") 23 | last_chk_path = tf.train.latest_checkpoint(checkpoint_dir=_SAVE_PATH) 24 | saver.restore(sess, save_path=last_chk_path) 25 | print("Restored checkpoint from:", last_chk_path) 26 | except ValueError: 27 | print("\nFailed to restore checkpoint. Initializing variables instead.") 28 | sess.run(tf.global_variables_initializer()) 29 | 30 | 31 | def main(): 32 | i = 0 33 | predicted_class = np.zeros(shape=len(test_x), dtype=np.int) 34 | while i < len(test_x): 35 | j = min(i + _BATCH_SIZE, len(test_x)) 36 | batch_xs = test_x[i:j, :] 37 | batch_ys = test_y[i:j, :] 38 | predicted_class[i:j] = sess.run(y_pred_cls, feed_dict={x: batch_xs, y: batch_ys}) 39 | i = j 40 | 41 | correct = (np.argmax(test_y, axis=1) == predicted_class) 42 | acc = correct.mean() * 100 43 | correct_numbers = correct.sum() 44 | print() 45 | print("Accuracy on Test-Set: {0:.2f}% ({1} / {2})".format(acc, correct_numbers, len(test_x))) 46 | 47 | 48 | if __name__ == "__main__": 49 | main() 50 | 51 | 52 | sess.close() 53 | -------------------------------------------------------------------------------- /include/model.py: -------------------------------------------------------------------------------- 1 | from tensorflow.compat import v1 as tf 2 | 3 | 4 | def model(): 5 | _IMAGE_SIZE = 32 6 | _IMAGE_CHANNELS = 3 7 | 
_NUM_CLASSES = 10 8 | 9 | with tf.name_scope('main_params'): 10 | x = tf.placeholder(tf.float32, shape=[None, _IMAGE_SIZE * _IMAGE_SIZE * _IMAGE_CHANNELS], name='Input') 11 | y = tf.placeholder(tf.float32, shape=[None, _NUM_CLASSES], name='Output') 12 | x_image = tf.reshape(x, [-1, _IMAGE_SIZE, _IMAGE_SIZE, _IMAGE_CHANNELS], name='images') 13 | 14 | global_step = tf.Variable(initial_value=0, trainable=False, name='global_step') 15 | learning_rate = tf.placeholder(tf.float32, shape=[], name='learning_rate') 16 | 17 | with tf.variable_scope('conv1') as scope: 18 | conv = tf.layers.conv2d( 19 | inputs=x_image, 20 | filters=32, 21 | kernel_size=[3, 3], 22 | padding='SAME', 23 | activation=tf.nn.relu 24 | ) 25 | conv = tf.layers.conv2d( 26 | inputs=conv, 27 | filters=64, 28 | kernel_size=[3, 3], 29 | padding='SAME', 30 | activation=tf.nn.relu 31 | ) 32 | pool = tf.layers.max_pooling2d(conv, pool_size=[2, 2], strides=2, padding='SAME') 33 | drop = tf.layers.dropout(pool, rate=0.25, name=scope.name) 34 | 35 | with tf.variable_scope('conv2') as scope: 36 | conv = tf.layers.conv2d( 37 | inputs=drop, 38 | filters=128, 39 | kernel_size=[3, 3], 40 | padding='SAME', 41 | activation=tf.nn.relu 42 | ) 43 | pool = tf.layers.max_pooling2d(conv, pool_size=[2, 2], strides=2, padding='SAME') 44 | conv = tf.layers.conv2d( 45 | inputs=pool, 46 | filters=128, 47 | kernel_size=[2, 2], 48 | padding='SAME', 49 | activation=tf.nn.relu 50 | ) 51 | pool = tf.layers.max_pooling2d(conv, pool_size=[2, 2], strides=2, padding='SAME') 52 | drop = tf.layers.dropout(pool, rate=0.25, name=scope.name) 53 | 54 | with tf.variable_scope('fully_connected') as scope: 55 | flat = tf.reshape(drop, [-1, 4 * 4 * 128]) 56 | 57 | fc = tf.layers.dense(inputs=flat, units=1500, activation=tf.nn.relu) 58 | drop = tf.layers.dropout(fc, rate=0.5) 59 | softmax = tf.layers.dense(inputs=drop, units=_NUM_CLASSES, name=scope.name) 60 | 61 | y_pred_cls = tf.argmax(softmax, axis=1) 62 | 63 | return x, y, softmax, y_pred_cls, 
global_step, learning_rate 64 | 65 | 66 | def lr(epoch): 67 | learning_rate = 1e-3 68 | if epoch > 80: 69 | learning_rate *= 0.5e-3 70 | elif epoch > 60: 71 | learning_rate *= 1e-3 72 | elif epoch > 40: 73 | learning_rate *= 1e-2 74 | elif epoch > 20: 75 | learning_rate *= 1e-1 76 | return learning_rate 77 | -------------------------------------------------------------------------------- /include/data.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import os 4 | from urllib.request import urlretrieve 5 | import tarfile 6 | import zipfile 7 | import sys 8 | 9 | 10 | def get_data_set(name="train"): 11 | x = None 12 | y = None 13 | 14 | maybe_download_and_extract() 15 | 16 | folder_name = "cifar_10" 17 | 18 | f = open('./data_set/'+folder_name+'/batches.meta', 'rb') 19 | f.close() 20 | 21 | if name is "train": 22 | for i in range(5): 23 | f = open('./data_set/'+folder_name+'/data_batch_' + str(i + 1), 'rb') 24 | datadict = pickle.load(f, encoding='latin1') 25 | f.close() 26 | 27 | _X = datadict["data"] 28 | _Y = datadict['labels'] 29 | 30 | _X = np.array(_X, dtype=float) / 255.0 31 | _X = _X.reshape([-1, 3, 32, 32]) 32 | _X = _X.transpose([0, 2, 3, 1]) 33 | _X = _X.reshape(-1, 32*32*3) 34 | 35 | if x is None: 36 | x = _X 37 | y = _Y 38 | else: 39 | x = np.concatenate((x, _X), axis=0) 40 | y = np.concatenate((y, _Y), axis=0) 41 | 42 | elif name is "test": 43 | f = open('./data_set/'+folder_name+'/test_batch', 'rb') 44 | datadict = pickle.load(f, encoding='latin1') 45 | f.close() 46 | 47 | x = datadict["data"] 48 | y = np.array(datadict['labels']) 49 | 50 | x = np.array(x, dtype=float) / 255.0 51 | x = x.reshape([-1, 3, 32, 32]) 52 | x = x.transpose([0, 2, 3, 1]) 53 | x = x.reshape(-1, 32*32*3) 54 | 55 | return x, dense_to_one_hot(y) 56 | 57 | 58 | def dense_to_one_hot(labels_dense, num_classes=10): 59 | num_labels = labels_dense.shape[0] 60 | index_offset = np.arange(num_labels) * 
num_classes 61 | labels_one_hot = np.zeros((num_labels, num_classes)) 62 | labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 63 | 64 | return labels_one_hot 65 | 66 | 67 | def _print_download_progress(count, block_size, total_size): 68 | pct_complete = float(count * block_size) / total_size 69 | msg = "\r- Download progress: {0:.1%}".format(pct_complete) 70 | sys.stdout.write(msg) 71 | sys.stdout.flush() 72 | 73 | 74 | def maybe_download_and_extract(): 75 | main_directory = "./data_set/" 76 | cifar_10_directory = main_directory+"cifar_10/" 77 | if not os.path.exists(main_directory): 78 | os.makedirs(main_directory) 79 | 80 | url = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 81 | filename = url.split('/')[-1] 82 | file_path = os.path.join(main_directory, filename) 83 | zip_cifar_10 = file_path 84 | file_path, _ = urlretrieve(url=url, filename=file_path, reporthook=_print_download_progress) 85 | 86 | print() 87 | print("Download finished. Extracting files.") 88 | if file_path.endswith(".zip"): 89 | zipfile.ZipFile(file=file_path, mode="r").extractall(main_directory) 90 | elif file_path.endswith((".tar.gz", ".tgz")): 91 | tarfile.open(name=file_path, mode="r:gz").extractall(main_directory) 92 | print("Done.") 93 | 94 | os.rename(main_directory+"./cifar-10-batches-py", cifar_10_directory) 95 | os.remove(zip_cifar_10) 96 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from time import time 4 | import math 5 | 6 | 7 | from include.data import get_data_set 8 | from include.model import model, lr 9 | 10 | 11 | train_x, train_y = get_data_set("train") 12 | test_x, test_y = get_data_set("test") 13 | tf.set_random_seed(21) 14 | x, y, output, y_pred_cls, global_step, learning_rate = model() 15 | global_accuracy = 0 16 | epoch_start = 0 17 | 18 | 19 | # PARAMS 20 | 
_BATCH_SIZE = 128 21 | _EPOCH = 60 22 | _SAVE_PATH = "./tensorboard/cifar-10-v1.0.0/" 23 | 24 | 25 | # LOSS AND OPTIMIZER 26 | loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=output, labels=y)) 27 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, 28 | beta1=0.9, 29 | beta2=0.999, 30 | epsilon=1e-08).minimize(loss, global_step=global_step) 31 | 32 | 33 | # PREDICTION AND ACCURACY CALCULATION 34 | correct_prediction = tf.equal(y_pred_cls, tf.argmax(y, axis=1)) 35 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 36 | 37 | 38 | # SAVER 39 | merged = tf.summary.merge_all() 40 | saver = tf.train.Saver() 41 | sess = tf.Session() 42 | train_writer = tf.summary.FileWriter(_SAVE_PATH, sess.graph) 43 | 44 | 45 | try: 46 | print("\nTrying to restore last checkpoint ...") 47 | last_chk_path = tf.train.latest_checkpoint(checkpoint_dir=_SAVE_PATH) 48 | saver.restore(sess, save_path=last_chk_path) 49 | print("Restored checkpoint from:", last_chk_path) 50 | except ValueError: 51 | print("\nFailed to restore checkpoint. 
Initializing variables instead.") 52 | sess.run(tf.global_variables_initializer()) 53 | 54 | 55 | def train(epoch): 56 | global epoch_start 57 | epoch_start = time() 58 | batch_size = int(math.ceil(len(train_x) / _BATCH_SIZE)) 59 | i_global = 0 60 | 61 | for s in range(batch_size): 62 | batch_xs = train_x[s*_BATCH_SIZE: (s+1)*_BATCH_SIZE] 63 | batch_ys = train_y[s*_BATCH_SIZE: (s+1)*_BATCH_SIZE] 64 | 65 | start_time = time() 66 | i_global, _, batch_loss, batch_acc = sess.run( 67 | [global_step, optimizer, loss, accuracy], 68 | feed_dict={x: batch_xs, y: batch_ys, learning_rate: lr(epoch)}) 69 | duration = time() - start_time 70 | 71 | if s % 10 == 0: 72 | percentage = int(round((s/batch_size)*100)) 73 | 74 | bar_len = 29 75 | filled_len = int((bar_len*int(percentage))/100) 76 | bar = '=' * filled_len + '>' + '-' * (bar_len - filled_len) 77 | 78 | msg = "Global step: {:>5} - [{}] {:>3}% - acc: {:.4f} - loss: {:.4f} - {:.1f} sample/sec" 79 | print(msg.format(i_global, bar, percentage, batch_acc, batch_loss, _BATCH_SIZE / duration)) 80 | 81 | test_and_save(i_global, epoch) 82 | 83 | 84 | def test_and_save(_global_step, epoch): 85 | global global_accuracy 86 | global epoch_start 87 | 88 | i = 0 89 | predicted_class = np.zeros(shape=len(test_x), dtype=np.int) 90 | while i < len(test_x): 91 | j = min(i + _BATCH_SIZE, len(test_x)) 92 | batch_xs = test_x[i:j, :] 93 | batch_ys = test_y[i:j, :] 94 | predicted_class[i:j] = sess.run( 95 | y_pred_cls, 96 | feed_dict={x: batch_xs, y: batch_ys, learning_rate: lr(epoch)} 97 | ) 98 | i = j 99 | 100 | correct = (np.argmax(test_y, axis=1) == predicted_class) 101 | acc = correct.mean()*100 102 | correct_numbers = correct.sum() 103 | 104 | hours, rem = divmod(time() - epoch_start, 3600) 105 | minutes, seconds = divmod(rem, 60) 106 | mes = "\nEpoch {} - accuracy: {:.2f}% ({}/{}) - time: {:0>2}:{:0>2}:{:05.2f}" 107 | print(mes.format((epoch+1), acc, correct_numbers, len(test_x), int(hours), int(minutes), seconds)) 108 | 109 | if 
global_accuracy != 0 and global_accuracy < acc: 110 | 111 | summary = tf.Summary(value=[ 112 | tf.Summary.Value(tag="Accuracy/test", simple_value=acc), 113 | ]) 114 | train_writer.add_summary(summary, _global_step) 115 | 116 | saver.save(sess, save_path=_SAVE_PATH, global_step=_global_step) 117 | 118 | mes = "This epoch receive better accuracy: {:.2f} > {:.2f}. Saving session..." 119 | print(mes.format(acc, global_accuracy)) 120 | global_accuracy = acc 121 | 122 | elif global_accuracy == 0: 123 | global_accuracy = acc 124 | 125 | print("###########################################################################################################") 126 | 127 | 128 | def main(): 129 | train_start = time() 130 | 131 | for i in range(_EPOCH): 132 | print("\nEpoch: {}/{}\n".format((i+1), _EPOCH)) 133 | train(i) 134 | 135 | hours, rem = divmod(time() - train_start, 3600) 136 | minutes, seconds = divmod(rem, 60) 137 | mes = "Best accuracy pre session: {:.2f}, time: {:0>2}:{:0>2}:{:05.2f}" 138 | print(mes.format(global_accuracy, int(hours), int(minutes), seconds)) 139 | 140 | 141 | if __name__ == "__main__": 142 | main() 143 | 144 | 145 | sess.close() 146 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tensorflow-cifar-10 2 | Cifar-10 convolutional network implementation example using TensorFlow library. 3 | ![](https://s3.eu-central-1.amazonaws.com/serhiy/Github_repo/tensorflow-cifar-10/v1.0.0/plot.png?v1) 4 | 5 | ## Requirement 6 | **Library** | **Version** 7 | --- | --- 8 | **Python** | **^3.6.5** 9 | **Tensorflow** | **^1.6.0** 10 | **Numpy** | **^1.14.2** 11 | **Pickle** | **^4.0** 12 | 13 | ## Accuracy 14 | Best accurancy what I receive was ```79.12%``` on test data set. You must to understand that network cant always learn with the same accuracy. But almost always accuracy more than ```78%```. 
15 | 16 | This repository is just an example implementation of a convolutional neural network. Here I implement a simple neural network for image recognition with good accuracy. 17 | 18 | If you want to get more than 80% accuracy, you need to implement more complex models (such as [ResNet](https://arxiv.org/abs/1512.03385), [GoogLeNet](https://www.cs.unc.edu/~wliu/papers/GoogLeNet.pdf), [MobileNetV2](https://arxiv.org/abs/1801.04381), etc.). 19 | 20 | 21 | ## Usage 22 | ### Download code 23 | ```sh 24 | git clone https://github.com/exelban/tensorflow-cifar-10 25 | 26 | cd tensorflow-cifar-10 27 | ``` 28 | 29 | ### Check that you have the necessary packages (`pickle` ships with Python and needs no installation) 30 | ```sh 31 | pip3 install numpy tensorflow 32 | ``` 33 | 34 | 35 | ### Train network 36 | By default the network is trained for 60 epochs (60 passes over the whole training set). 37 | You can change that by editing ```_EPOCH``` in ```train.py```. 38 | 39 | By default it also processes 128 images per step. 40 | If you are training on a CPU or on a GPU weaker than a GTX 1060 6GB, set ```_BATCH_SIZE``` in ```train.py``` to a smaller value.
41 | 42 | 43 | ```sh 44 | python3 train.py 45 | ``` 46 | 47 | Sample output: 48 | ```sh 49 | Epoch: 60/60 50 | 51 | Global step: 23070 - [>-----------------------------] 0% - acc: 0.9531 - loss: 1.5081 - 7045.4 sample/sec 52 | Global step: 23080 - [>-----------------------------] 3% - acc: 0.9453 - loss: 1.5159 - 7147.6 sample/sec 53 | Global step: 23090 - [=>----------------------------] 5% - acc: 0.9844 - loss: 1.4764 - 7154.6 sample/sec 54 | Global step: 23100 - [==>---------------------------] 8% - acc: 0.9297 - loss: 1.5307 - 7104.4 sample/sec 55 | Global step: 23110 - [==>---------------------------] 10% - acc: 0.9141 - loss: 1.5462 - 7091.4 sample/sec 56 | Global step: 23120 - [===>--------------------------] 13% - acc: 0.9297 - loss: 1.5314 - 7162.9 sample/sec 57 | Global step: 23130 - [====>-------------------------] 15% - acc: 0.9297 - loss: 1.5307 - 7174.8 sample/sec 58 | Global step: 23140 - [=====>------------------------] 18% - acc: 0.9375 - loss: 1.5231 - 7140.0 sample/sec 59 | Global step: 23150 - [=====>------------------------] 20% - acc: 0.9297 - loss: 1.5301 - 7152.8 sample/sec 60 | Global step: 23160 - [======>-----------------------] 23% - acc: 0.9531 - loss: 1.5080 - 7112.3 sample/sec 61 | Global step: 23170 - [=======>----------------------] 26% - acc: 0.9609 - loss: 1.5000 - 7154.0 sample/sec 62 | Global step: 23180 - [========>---------------------] 28% - acc: 0.9531 - loss: 1.5074 - 6862.2 sample/sec 63 | Global step: 23190 - [========>---------------------] 31% - acc: 0.9609 - loss: 1.4993 - 7134.5 sample/sec 64 | Global step: 23200 - [=========>--------------------] 33% - acc: 0.9609 - loss: 1.4995 - 7166.0 sample/sec 65 | Global step: 23210 - [==========>-------------------] 36% - acc: 0.9375 - loss: 1.5231 - 7116.7 sample/sec 66 | Global step: 23220 - [===========>------------------] 38% - acc: 0.9453 - loss: 1.5153 - 7134.1 sample/sec 67 | Global step: 23230 - [===========>------------------] 41% - acc: 0.9375 - loss: 1.5233 - 7074.5
sample/sec 68 | Global step: 23240 - [============>-----------------] 43% - acc: 0.9219 - loss: 1.5387 - 7176.9 sample/sec 69 | Global step: 23250 - [=============>----------------] 46% - acc: 0.8828 - loss: 1.5769 - 7144.1 sample/sec 70 | Global step: 23260 - [==============>---------------] 49% - acc: 0.9219 - loss: 1.5383 - 7059.7 sample/sec 71 | Global step: 23270 - [==============>---------------] 51% - acc: 0.8984 - loss: 1.5618 - 6638.6 sample/sec 72 | Global step: 23280 - [===============>--------------] 54% - acc: 0.9453 - loss: 1.5151 - 7035.7 sample/sec 73 | Global step: 23290 - [================>-------------] 56% - acc: 0.9609 - loss: 1.4996 - 7129.0 sample/sec 74 | Global step: 23300 - [=================>------------] 59% - acc: 0.9609 - loss: 1.4997 - 7075.4 sample/sec 75 | Global step: 23310 - [=================>------------] 61% - acc: 0.8750 - loss: 1.5842 - 7117.8 sample/sec 76 | Global step: 23320 - [==================>-----------] 64% - acc: 0.9141 - loss: 1.5463 - 7157.2 sample/sec 77 | Global step: 23330 - [===================>----------] 66% - acc: 0.9062 - loss: 1.5549 - 7169.3 sample/sec 78 | Global step: 23340 - [====================>---------] 69% - acc: 0.9219 - loss: 1.5389 - 7164.4 sample/sec 79 | Global step: 23350 - [====================>---------] 72% - acc: 0.9609 - loss: 1.5002 - 7135.4 sample/sec 80 | Global step: 23360 - [=====================>--------] 74% - acc: 0.9766 - loss: 1.4842 - 7124.2 sample/sec 81 | Global step: 23370 - [======================>-------] 77% - acc: 0.9375 - loss: 1.5231 - 7168.5 sample/sec 82 | Global step: 23380 - [======================>-------] 79% - acc: 0.8906 - loss: 1.5695 - 7175.2 sample/sec 83 | Global step: 23390 - [=======================>------] 82% - acc: 0.9375 - loss: 1.5225 - 7132.1 sample/sec 84 | Global step: 23400 - [========================>-----] 84% - acc: 0.9844 - loss: 1.4768 - 7100.1 sample/sec 85 | Global step: 23410 - [=========================>----] 87% - acc: 0.9766 - loss: 
1.4840 - 7172.0 sample/sec 86 | Global step: 23420 - [==========================>---] 90% - acc: 0.9062 - loss: 1.5542 - 7122.1 sample/sec 87 | Global step: 23430 - [==========================>---] 92% - acc: 0.9297 - loss: 1.5313 - 7145.3 sample/sec 88 | Global step: 23440 - [===========================>--] 95% - acc: 0.9297 - loss: 1.5301 - 7133.3 sample/sec 89 | Global step: 23450 - [============================>-] 97% - acc: 0.9375 - loss: 1.5231 - 7135.7 sample/sec 90 | Global step: 23460 - [=============================>] 100% - acc: 0.9250 - loss: 1.5362 - 10297.5 sample/sec 91 | 92 | Epoch 60 - accuracy: 78.81% (7881/10000) 93 | This epoch receive better accuracy: 78.81 > 78.78. Saving session... 94 | ########################################################################################################### 95 | ``` 96 | 97 | 98 | ### Run the network on the test set 99 | ```sh 100 | python3 predict.py 101 | ``` 102 | 103 | Sample output: 104 | ```sh 105 | Trying to restore last checkpoint ... 106 | Restored checkpoint from: ./tensorboard/cifar-10-v1.0.0/-23460 107 | 108 | Accuracy on Test-Set: 78.81% (7881 / 10000) 109 | ``` 110 | 111 | 112 | ## Training time 113 | Time needed to train for 60 epochs: 114 | 115 | **Device** | **Batch size** | **Time** | **Accuracy [%]** 116 | --- | --- | --- | --- 117 | **NVidia GTX 1070** | **128** | **8m 4s** | **79.12** 118 | **Intel i7 7700HQ** | **128** | **3h 30m** | **78.91** 119 | 120 | Please send me your time and accuracy (or open an issue with them), and I will add them to the list.
121 | 122 | ## Model 123 | 124 | ## What's new 125 | 126 | ### v1.0.1 127 | - Set random seed 128 | - Added more information about elapsed time per epoch and for the full training run 129 | 130 | ### v1.0.0 131 | - Removed all references to CIFAR-100 132 | - Small fixes in data functions 133 | - Almost fully rewrote train.py 134 | - Simplified the CNN model 135 | - Changed optimizer to AdamOptimizer 136 | - Changed license to MIT 137 | - Removed confusion matrix (to avoid unnecessary dependencies) 138 | - Improved accuracy on the test set (up to 79%) 139 | - Small fixes in train.py 140 | - Changed saver functions (the session is now saved only if its accuracy is better than the last saved one) 141 | - Updated packages 142 | 143 | ### v0.0.1 144 | - Ran tests on AWS instances 145 | - Model fixes 146 | - Removed CIFAR-100 dataset 147 | 148 | 149 | ### v0.0.0 150 | - First release 151 | 152 | ## License 153 | [MIT License](https://github.com/exelban/tensorflow-cifar-10/blob/master/LICENSE) 154 | --------------------------------------------------------------------------------