├── .gitignore ├── LICENSE ├── README.md ├── ResNeXt.py ├── assests ├── ResNeXt.JPG ├── ResNet.JPG └── comparision.png └── cifar10.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Junho Kim (1993.01.12) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ResNeXt-Tensorflow
2 | TensorFlow implementation of [ResNeXt](https://arxiv.org/abs/1611.05431) trained on **CIFAR-10**
3 | 
4 | If you want to see the ***original author's code***, please refer to this [link](https://github.com/facebookresearch/ResNeXt)
5 | 
6 | ## Requirements
7 | * Tensorflow 1.x
8 | * Python 3.x
9 | * tflearn (install ***tflearn*** if you want an easy way to use ***global average pooling***)
10 | 
11 | ## Issue
12 | * If you do not have enough GPU memory, please edit the code as follows
13 | ```python
14 | with tf.Session() as sess : NO
15 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess : OK
16 | ```
17 | ## Compare Architecture
18 | ### ResNet
19 | ![ResNet](./assests/ResNet.JPG)
20 | 
21 | ### ResNeXt
22 | ![ResNeXt](./assests/ResNeXt.JPG)
23 | 
24 | * I implemented (b)
25 | * (b) is ***split + transform(bottleneck) + concatenate + transition + merge***
26 | 
27 | ## Idea
28 | ### What is the "split" ?
29 | ```python
30 | def split_layer(self, input_x, stride, layer_name):
31 |     with tf.name_scope(layer_name) :
32 |         layers_split = list()
33 |         for i in range(cardinality) :
34 |             splits = self.transform_layer(input_x, stride=stride, scope=layer_name + '_splitN_' + str(i))
35 |             layers_split.append(splits)
36 | 
37 |         return Concatenation(layers_split)
38 | ```
39 | * ***Cardinality*** is the number of parallel branches the input is split into.
40 | 
41 | ### What is the "transform" ?
42 | ```python
43 | def transform_layer(self, x, stride, scope):
44 |     with tf.name_scope(scope) :
45 |         x = conv_layer(x, filter=depth, kernel=[1,1], stride=stride, layer_name=scope+'_conv1')
46 |         x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1')
47 |         x = Relu(x)
48 | 
49 |         x = conv_layer(x, filter=depth, kernel=[3,3], stride=1, layer_name=scope+'_conv2')
50 |         x = Batch_Normalization(x, training=self.training, scope=scope+'_batch2')
51 |         x = Relu(x)
52 |         return x
53 | ```
54 | 
55 | ### What is the "transition" ?
56 | ```python
57 | def transition_layer(self, x, out_dim, scope):
58 |     with tf.name_scope(scope):
59 |         x = conv_layer(x, filter=out_dim, kernel=[1,1], stride=1, layer_name=scope+'_conv1')
60 |         x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1')
61 | 
62 |         return x
63 | ```
64 | 
65 | ## Compare Results (ResNet, DenseNet, ResNeXt)
66 | ![compare](./assests/comparision.png)
67 | 
68 | ## Related works
69 | * [DenseNet-Tensorflow](https://github.com/taki0112/Densenet-Tensorflow)
70 | * [SENet-Tensorflow](https://github.com/taki0112/SENet-Tensorflow)
71 | * [ResNet-Tensorflow](https://github.com/taki0112/ResNet-Tensorflow)
72 | 
73 | ## References
74 | * [Classification Datasets Results](http://rodrigob.github.io/are_we_there_yet/build/classification_datasets_results.html)
75 | 
76 | ## Author
77 | Junho Kim
78 | 
--------------------------------------------------------------------------------
/ResNeXt.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tflearn.layers.conv import global_avg_pool
3 | from tensorflow.contrib.layers import batch_norm, flatten
4 | from tensorflow.contrib.framework import arg_scope
5 | from cifar10 import *
6 | import numpy as np
7 | 
8 | weight_decay = 0.0005
9 | momentum = 0.9
10 | 
11 | init_learning_rate = 0.1
12 | cardinality = 8 # how many splits (branches) ?
13 | blocks = 3 # res_block ! (split + transition)
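# Note: each residual block splits its input into `cardinality` branches; every branch is a
# 1x1 -> 3x3 bottleneck with `depth` output channels, so the concatenated tensor that feeds
# the 1x1 transition conv is cardinality * depth = 8 * 64 = 512 channels wide before being
# projected down to out_dim (64, 128, 256 for the three CIFAR-10 stages).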
14 | 
15 | """
16 | So, the total number of layers is (3*blocks)*residual_layer_num + 2
17 | because each block = split (2 conv layers) + transition (1 conv layer) = 3 layers,
18 | plus the first conv layer (1) and the last dense layer (1).
19 | Thus, total number of layers = (3*blocks)*residual_layer_num + 2
20 | """
21 | 
22 | depth = 64 # output channels of each split branch (bottleneck width)
23 | 
24 | batch_size = 128
25 | iteration = 391
26 | # 128 * 391 ~ 50,000
27 | 
28 | test_iteration = 10
29 | 
30 | total_epochs = 300
31 | 
32 | def conv_layer(input, filter, kernel, stride, padding='SAME', layer_name="conv"):
33 |     with tf.name_scope(layer_name):
34 |         network = tf.layers.conv2d(inputs=input, use_bias=False, filters=filter, kernel_size=kernel, strides=stride, padding=padding)
35 |         return network
36 | 
37 | def Global_Average_Pooling(x):
38 |     return global_avg_pool(x, name='Global_avg_pooling')
39 | 
40 | def Average_pooling(x, pool_size=[2,2], stride=2, padding='SAME'):
41 |     return tf.layers.average_pooling2d(inputs=x, pool_size=pool_size, strides=stride, padding=padding)
42 | 
43 | def Batch_Normalization(x, training, scope):
44 |     with arg_scope([batch_norm],
45 |                    scope=scope,
46 |                    updates_collections=None,
47 |                    decay=0.9,
48 |                    center=True,
49 |                    scale=True,
50 |                    zero_debias_moving_mean=True) :
51 |         return tf.cond(training,
52 |                        lambda : batch_norm(inputs=x, is_training=training, reuse=None),
53 |                        lambda : batch_norm(inputs=x, is_training=training, reuse=True))
54 | 
55 | def Relu(x):
56 |     return tf.nn.relu(x)
57 | 
58 | def Concatenation(layers) :
59 |     return tf.concat(layers, axis=3)
60 | 
61 | def Linear(x) :
62 |     return tf.layers.dense(inputs=x, use_bias=False, units=class_num, name='linear')
63 | 
64 | def Evaluate(sess):
65 |     test_acc = 0.0
66 |     test_loss = 0.0
67 |     test_pre_index = 0
68 |     add = 1000
69 | 
70 |     for it in range(test_iteration):
71 |         test_batch_x = test_x[test_pre_index: test_pre_index + add]
72 |         test_batch_y = test_y[test_pre_index: test_pre_index + add]
73 |         test_pre_index = test_pre_index + add
74 | 
75 |         test_feed_dict = {
76 |             x: test_batch_x,
77 |             label: test_batch_y,
78 |             learning_rate: epoch_learning_rate,
79 |             training_flag: False
80 |         }
81 | 
82 |         loss_, acc_ = sess.run([cost, accuracy], feed_dict=test_feed_dict)
83 | 
84 |         test_loss += loss_
85 |         test_acc += acc_
86 | 
87 |     test_loss /= test_iteration # average loss
88 |     test_acc /= test_iteration # average accuracy
89 | 
90 |     summary = tf.Summary(value=[tf.Summary.Value(tag='test_loss', simple_value=test_loss),
91 |                                 tf.Summary.Value(tag='test_accuracy', simple_value=test_acc)])
92 | 
93 |     return test_acc, test_loss, summary
94 | 
95 | class ResNeXt():
96 |     def __init__(self, x, training):
97 |         self.training = training
98 |         self.model = self.Build_ResNext(x)
99 | 
100 |     def first_layer(self, x, scope):
101 |         with tf.name_scope(scope) :
102 |             x = conv_layer(x, filter=64, kernel=[3, 3], stride=1, layer_name=scope+'_conv1')
103 |             x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1')
104 |             x = Relu(x)
105 | 
106 |             return x
107 | 
108 |     def transform_layer(self, x, stride, scope):
109 |         with tf.name_scope(scope) :
110 |             x = conv_layer(x, filter=depth, kernel=[1,1], stride=stride, layer_name=scope+'_conv1')
111 |             x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1')
112 |             x = Relu(x)
113 | 
114 |             x = conv_layer(x, filter=depth, kernel=[3,3], stride=1, layer_name=scope+'_conv2')
115 |             x = Batch_Normalization(x, training=self.training, scope=scope+'_batch2')
116 |             x = Relu(x)
117 |             return x
118 | 
119 |     def transition_layer(self, x, out_dim, scope):
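        """1x1 conv + batch norm that projects the concatenated branches down to out_dim.
        Note that there is no ReLU here; it is applied after the residual addition in
        residual_layer (see the commented-out Relu(x) below)."""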
120 | with tf.name_scope(scope): 121 | x = conv_layer(x, filter=out_dim, kernel=[1,1], stride=1, layer_name=scope+'_conv1') 122 | x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1') 123 | # x = Relu(x) 124 | 125 | return x 126 | 127 | def split_layer(self, input_x, stride, layer_name): 128 | with tf.name_scope(layer_name) : 129 | layers_split = list() 130 | for i in range(cardinality) : 131 | splits = self.transform_layer(input_x, stride=stride, scope=layer_name + '_splitN_' + str(i)) 132 | layers_split.append(splits) 133 | 134 | return Concatenation(layers_split) 135 | 136 | def residual_layer(self, input_x, out_dim, layer_num, res_block=blocks): 137 | # split + transform(bottleneck) + transition + merge 138 | 139 | for i in range(res_block): 140 | # input_dim = input_x.get_shape().as_list()[-1] 141 | input_dim = int(np.shape(input_x)[-1]) 142 | 143 | if input_dim * 2 == out_dim: 144 | flag = True 145 | stride = 2 146 | channel = input_dim // 2 147 | else: 148 | flag = False 149 | stride = 1 150 | x = self.split_layer(input_x, stride=stride, layer_name='split_layer_'+layer_num+'_'+str(i)) 151 | x = self.transition_layer(x, out_dim=out_dim, scope='trans_layer_'+layer_num+'_'+str(i)) 152 | 153 | if flag is True : 154 | pad_input_x = Average_pooling(input_x) 155 | pad_input_x = tf.pad(pad_input_x, [[0, 0], [0, 0], [0, 0], [channel, channel]]) # [?, height, width, channel] 156 | else : 157 | pad_input_x = input_x 158 | 159 | input_x = Relu(x + pad_input_x) 160 | 161 | return input_x 162 | 163 | 164 | def Build_ResNext(self, input_x): 165 | # only cifar10 architecture 166 | 167 | input_x = self.first_layer(input_x, scope='first_layer') 168 | 169 | x = self.residual_layer(input_x, out_dim=64, layer_num='1') 170 | x = self.residual_layer(x, out_dim=128, layer_num='2') 171 | x = self.residual_layer(x, out_dim=256, layer_num='3') 172 | 173 | x = Global_Average_Pooling(x) 174 | x = flatten(x) 175 | x = Linear(x) 176 | 177 | # x = tf.reshape(x, [-1,10]) 178 | return x 179 | 180 | 181 | train_x, train_y, test_x, test_y = prepare_data() 182 | train_x, test_x = color_preprocessing(train_x, test_x) 183 | 184 | 185 | # image_size = 32, img_channels = 3, class_num = 10 in cifar10 186 | x = tf.placeholder(tf.float32, shape=[None, image_size, image_size, img_channels]) 187 | label = tf.placeholder(tf.float32, shape=[None, class_num]) 188 | 189 | training_flag = tf.placeholder(tf.bool) 190 | 191 | 192 | learning_rate = tf.placeholder(tf.float32, name='learning_rate') 193 | 194 | logits = ResNeXt(x, training=training_flag).model 195 | cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=label, logits=logits)) 196 | 197 | l2_loss = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()]) 198 | optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=momentum, use_nesterov=True) 199 | train = optimizer.minimize(cost + l2_loss * weight_decay) 200 | 201 | correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label, 1)) 202 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 203 | 204 | saver = tf.train.Saver(tf.global_variables()) 205 | 206 | with tf.Session() as sess: 207 | ckpt = tf.train.get_checkpoint_state('./model') 208 | if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path): 209 | saver.restore(sess, ckpt.model_checkpoint_path) 210 | else: 211 | sess.run(tf.global_variables_initializer()) 212 | 213 | summary_writer = tf.summary.FileWriter('./logs', sess.graph) 214 | 215 | epoch_learning_rate = 
init_learning_rate 216 | for epoch in range(1, total_epochs + 1): 217 | if epoch == (total_epochs * 0.5) or epoch == (total_epochs * 0.75): 218 | epoch_learning_rate = epoch_learning_rate / 10 219 | 220 | pre_index = 0 221 | train_acc = 0.0 222 | train_loss = 0.0 223 | 224 | for step in range(1, iteration + 1): 225 | if pre_index + batch_size < 50000: 226 | batch_x = train_x[pre_index: pre_index + batch_size] 227 | batch_y = train_y[pre_index: pre_index + batch_size] 228 | else: 229 | batch_x = train_x[pre_index:] 230 | batch_y = train_y[pre_index:] 231 | 232 | batch_x = data_augmentation(batch_x) 233 | 234 | train_feed_dict = { 235 | x: batch_x, 236 | label: batch_y, 237 | learning_rate: epoch_learning_rate, 238 | training_flag: True 239 | } 240 | 241 | _, batch_loss = sess.run([train, cost], feed_dict=train_feed_dict) 242 | batch_acc = accuracy.eval(feed_dict=train_feed_dict) 243 | 244 | train_loss += batch_loss 245 | train_acc += batch_acc 246 | pre_index += batch_size 247 | 248 | 249 | train_loss /= iteration # average loss 250 | train_acc /= iteration # average accuracy 251 | 252 | train_summary = tf.Summary(value=[tf.Summary.Value(tag='train_loss', simple_value=train_loss), 253 | tf.Summary.Value(tag='train_accuracy', simple_value=train_acc)]) 254 | 255 | test_acc, test_loss, test_summary = Evaluate(sess) 256 | 257 | summary_writer.add_summary(summary=train_summary, global_step=epoch) 258 | summary_writer.add_summary(summary=test_summary, global_step=epoch) 259 | summary_writer.flush() 260 | 261 | line = "epoch: %d/%d, train_loss: %.4f, train_acc: %.4f, test_loss: %.4f, test_acc: %.4f \n" % ( 262 | epoch, total_epochs, train_loss, train_acc, test_loss, test_acc) 263 | print(line) 264 | 265 | with open('logs.txt', 'a') as f: 266 | f.write(line) 267 | 268 | saver.save(sess=sess, save_path='./model/ResNeXt.ckpt') 269 | -------------------------------------------------------------------------------- /assests/ResNeXt.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/ResNeXt-Tensorflow/60bfd72c5c944ca960f2c906406772c8901cdcef/assests/ResNeXt.JPG -------------------------------------------------------------------------------- /assests/ResNet.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/ResNeXt-Tensorflow/60bfd72c5c944ca960f2c906406772c8901cdcef/assests/ResNet.JPG -------------------------------------------------------------------------------- /assests/comparision.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/ResNeXt-Tensorflow/60bfd72c5c944ca960f2c906406772c8901cdcef/assests/comparision.png -------------------------------------------------------------------------------- /cifar10.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | import os 4 | import sys 5 | import time 6 | import pickle 7 | import random 8 | import numpy as np 9 | 10 | class_num = 10 11 | image_size = 32 12 | img_channels = 3 13 | 14 | 15 | # ========================================================== # 16 | # ├─ prepare_data() 17 | # ├─ download training data if not exist by download_data() 18 | # ├─ load data by load_data() 19 | # └─ shuffe and return data 20 | # ========================================================== # 21 | 22 | 23 | 24 | def download_data(): 25 | dirname = 'cifar-10-batches-py' 26 | origin = 
'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' 27 | fname = 'cifar-10-python.tar.gz' 28 | fpath = './' + dirname 29 | 30 | download = False 31 | if os.path.exists(fpath) or os.path.isfile(fname): 32 | download = False 33 | print("DataSet aready exist!") 34 | else: 35 | download = True 36 | if download: 37 | print('Downloading data from', origin) 38 | import urllib.request 39 | import tarfile 40 | 41 | def reporthook(count, block_size, total_size): 42 | global start_time 43 | if count == 0: 44 | start_time = time.time() 45 | return 46 | duration = time.time() - start_time 47 | progress_size = int(count * block_size) 48 | speed = int(progress_size / (1024 * duration)) 49 | percent = min(int(count * block_size * 100 / total_size), 100) 50 | sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" % 51 | (percent, progress_size / (1024 * 1024), speed, duration)) 52 | sys.stdout.flush() 53 | 54 | urllib.request.urlretrieve(origin, fname, reporthook) 55 | print('Download finished. Start extract!', origin) 56 | if (fname.endswith("tar.gz")): 57 | tar = tarfile.open(fname, "r:gz") 58 | tar.extractall() 59 | tar.close() 60 | elif (fname.endswith("tar")): 61 | tar = tarfile.open(fname, "r:") 62 | tar.extractall() 63 | tar.close() 64 | 65 | 66 | def unpickle(file): 67 | with open(file, 'rb') as fo: 68 | dict = pickle.load(fo, encoding='bytes') 69 | return dict 70 | 71 | 72 | def load_data_one(file): 73 | batch = unpickle(file) 74 | data = batch[b'data'] 75 | labels = batch[b'labels'] 76 | print("Loading %s : %d." % (file, len(data))) 77 | return data, labels 78 | 79 | 80 | def load_data(files, data_dir, label_count): 81 | global image_size, img_channels 82 | data, labels = load_data_one(data_dir + '/' + files[0]) 83 | for f in files[1:]: 84 | data_n, labels_n = load_data_one(data_dir + '/' + f) 85 | data = np.append(data, data_n, axis=0) 86 | labels = np.append(labels, labels_n, axis=0) 87 | labels = np.array([[float(i == label) for i in range(label_count)] for label in labels]) 88 | data = data.reshape([-1, img_channels, image_size, image_size]) 89 | data = data.transpose([0, 2, 3, 1]) 90 | return data, labels 91 | 92 | 93 | def prepare_data(): 94 | print("======Loading data======") 95 | download_data() 96 | data_dir = './cifar-10-batches-py' 97 | image_dim = image_size * image_size * img_channels 98 | meta = unpickle(data_dir + '/batches.meta') 99 | 100 | label_names = meta[b'label_names'] 101 | label_count = len(label_names) 102 | train_files = ['data_batch_%d' % d for d in range(1, 6)] 103 | train_data, train_labels = load_data(train_files, data_dir, label_count) 104 | test_data, test_labels = load_data(['test_batch'], data_dir, label_count) 105 | 106 | print("Train data:", np.shape(train_data), np.shape(train_labels)) 107 | print("Test data :", np.shape(test_data), np.shape(test_labels)) 108 | print("======Load finished======") 109 | 110 | print("======Shuffling data======") 111 | indices = np.random.permutation(len(train_data)) 112 | train_data = train_data[indices] 113 | train_labels = train_labels[indices] 114 | print("======Prepare Finished======") 115 | 116 | return train_data, train_labels, test_data, test_labels 117 | 118 | 119 | # ========================================================== # 120 | # ├─ _random_crop() 121 | # ├─ _random_flip_leftright() 122 | # ├─ data_augmentation() 123 | # └─ color_preprocessing() 124 | # ========================================================== # 125 | 126 | def _random_crop(batch, crop_shape, padding=None): 127 | oshape = 
np.shape(batch[0]) 128 | 129 | if padding: 130 | oshape = (oshape[0] + 2 * padding, oshape[1] + 2 * padding) 131 | new_batch = [] 132 | npad = ((padding, padding), (padding, padding), (0, 0)) 133 | for i in range(len(batch)): 134 | new_batch.append(batch[i]) 135 | if padding: 136 | new_batch[i] = np.lib.pad(batch[i], pad_width=npad, 137 | mode='constant', constant_values=0) 138 | nh = random.randint(0, oshape[0] - crop_shape[0]) 139 | nw = random.randint(0, oshape[1] - crop_shape[1]) 140 | new_batch[i] = new_batch[i][nh:nh + crop_shape[0], 141 | nw:nw + crop_shape[1]] 142 | return new_batch 143 | 144 | 145 | def _random_flip_leftright(batch): 146 | for i in range(len(batch)): 147 | if bool(random.getrandbits(1)): 148 | batch[i] = np.fliplr(batch[i]) 149 | return batch 150 | 151 | 152 | def color_preprocessing(x_train, x_test): 153 | x_train = x_train.astype('float32') 154 | x_test = x_test.astype('float32') 155 | x_train[:, :, :, 0] = (x_train[:, :, :, 0] - np.mean(x_train[:, :, :, 0])) / np.std(x_train[:, :, :, 0]) 156 | x_train[:, :, :, 1] = (x_train[:, :, :, 1] - np.mean(x_train[:, :, :, 1])) / np.std(x_train[:, :, :, 1]) 157 | x_train[:, :, :, 2] = (x_train[:, :, :, 2] - np.mean(x_train[:, :, :, 2])) / np.std(x_train[:, :, :, 2]) 158 | 159 | x_test[:, :, :, 0] = (x_test[:, :, :, 0] - np.mean(x_test[:, :, :, 0])) / np.std(x_test[:, :, :, 0]) 160 | x_test[:, :, :, 1] = (x_test[:, :, :, 1] - np.mean(x_test[:, :, :, 1])) / np.std(x_test[:, :, :, 1]) 161 | x_test[:, :, :, 2] = (x_test[:, :, :, 2] - np.mean(x_test[:, :, :, 2])) / np.std(x_test[:, :, :, 2]) 162 | 163 | return x_train, x_test 164 | 165 | 166 | def data_augmentation(batch): 167 | batch = _random_flip_leftright(batch) 168 | batch = _random_crop(batch, [32, 32], 4) 169 | return batch --------------------------------------------------------------------------------
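
**Note on the block structure (added sketch, not part of the original repository).** ResNeXt.py implements form (b) from the README: an explicit split into `cardinality` branches, a 1x1/3x3 transform per branch, concatenation, and a 1x1 transition. The ResNeXt paper shows that the same computation can be folded into a single grouped convolution (form (c) in its Figure 3). The sketch below illustrates that grouped-convolution form; it assumes TensorFlow >= 2.3 (where `tf.keras.layers.Conv2D` accepts a `groups` argument) rather than the TF 1.x API used above, and it replaces the avg-pool + zero-padding shortcut of `residual_layer` with a 1x1 projection for brevity.

```python
import tensorflow as tf  # assumes TF >= 2.3 for Conv2D(groups=...)

cardinality = 8   # number of branches, as in ResNeXt.py
depth = 64        # per-branch bottleneck width, as in ResNeXt.py


def resnext_block_grouped(x, out_dim, stride=1, training=False):
    """Grouped-convolution equivalent of split_layer + transition_layer (paper form (c))."""
    shortcut = x

    # 1x1 bottleneck to cardinality * depth channels: one conv whose filters are the
    # stacked per-branch 1x1 filters of transform_layer.
    y = tf.keras.layers.Conv2D(cardinality * depth, 1, strides=stride, use_bias=False)(x)
    y = tf.keras.layers.BatchNormalization()(y, training=training)
    y = tf.nn.relu(y)

    # Grouped 3x3 conv: a 3x3 conv restricted to each of the `cardinality` groups is the
    # same as transforming every branch separately and concatenating the results.
    y = tf.keras.layers.Conv2D(cardinality * depth, 3, padding='same',
                               groups=cardinality, use_bias=False)(y)
    y = tf.keras.layers.BatchNormalization()(y, training=training)
    y = tf.nn.relu(y)

    # 1x1 transition back to out_dim (no ReLU here, matching transition_layer).
    y = tf.keras.layers.Conv2D(out_dim, 1, use_bias=False)(y)
    y = tf.keras.layers.BatchNormalization()(y, training=training)

    # Shortcut: a 1x1 projection when the shape changes (the original code instead uses
    # average pooling plus zero padding of the channel dimension).
    if stride != 1 or shortcut.shape[-1] != out_dim:
        shortcut = tf.keras.layers.Conv2D(out_dim, 1, strides=stride, use_bias=False)(shortcut)
        shortcut = tf.keras.layers.BatchNormalization()(shortcut, training=training)

    return tf.nn.relu(y + shortcut)


# Example: a batch of 32x32 feature maps through a first-stage block (out_dim=64).
images = tf.random.normal([4, 32, 32, 64])
out = resnext_block_grouped(images, out_dim=64)
print(out.shape)  # (4, 32, 32, 64)
```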