├── requirements.cpu.txt
├── requirements.gpu.txt
├── .gitignore
├── readme.md
├── effnet.py
├── normalizer.py
└── train.py

/requirements.cpu.txt:
--------------------------------------------------------------------------------
absl-py==0.7.0
astor==0.7.1
certifi==2018.11.29
decorator==4.4.0
gast==0.2.2
google-pasta==0.1.7
grpcio==1.18.0
h5py==2.9.0
imageio==2.5.0
intel-openmp==2019.0
Keras==2.2.4
Keras-Applications==1.0.7
Keras-Preprocessing==1.0.9
Markdown==3.0.1
mkl==2019.0
networkx==2.3
numpy==1.16.4
opencv-contrib-python==4.0.0.21
pandas==0.25.0
Pillow==6.2.0
protobuf==3.6.1
python-dateutil==2.8.0
pytz==2019.2
PyWavelets==1.0.3
PyYAML==3.13
scikit-image==0.15.0
scipy==1.3.0
six==1.12.0
tb-nightly==1.14.0a20190603
tensorboard==1.12.2
tensorflow==2.0.0b1
termcolor==1.1.0
tf-estimator-nightly==1.14.0.dev2019060501
Werkzeug==0.15.3
wincertstore==0.2
wrapt==1.11.2
--------------------------------------------------------------------------------
/requirements.gpu.txt:
--------------------------------------------------------------------------------
absl-py==0.7.0
astor==0.7.1
certifi==2018.11.29
decorator==4.4.0
gast==0.2.2
google-pasta==0.1.7
grpcio==1.18.0
h5py==2.9.0
imageio==2.5.0
intel-openmp==2019.0
Keras==2.2.4
Keras-Applications==1.0.7
Keras-Preprocessing==1.0.9
Markdown==3.0.1
mkl==2019.0
networkx==2.3
numpy==1.16.4
opencv-contrib-python==4.0.0.21
pandas==0.25.0
Pillow==6.2.0
protobuf==3.6.1
python-dateutil==2.8.0
pytz==2019.2
PyWavelets==1.0.3
PyYAML==3.13
scikit-image==0.15.0
scipy==1.3.0
six==1.12.0
tb-nightly==1.14.0a20190603
tensorboard==1.12.2
tensorflow-gpu==2.0.0b1
termcolor==1.1.0
tf-estimator-nightly==1.14.0.dev2019060501
Werkzeug==0.15.3
wincertstore==0.2
wrapt==1.11.2
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
data/
model/
.vscode/
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
# efficientnet-tf2
A TensorFlow 2.0 implementation of [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946), aka EfficientNet.

## Motivation

EfficientNet is still one of the most efficient architectures for image classification. Considering that TensorFlow 2.0 has already hit beta1, I think a flexible and reusable implementation of EfficientNet in TF 2.0 might be useful for practitioners.

## Implementation

I implemented the running mean and standard deviation calculation with [Welford's algorithm](https://www.johndcook.com/blog/standard_deviation/), which avoids loading the whole dataset into memory. The `Normalizer` class, which computes the mean and standard deviation, is also used as the `preprocessing_function` argument to `tf.keras.preprocessing.image.ImageDataGenerator`.
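For example, here is a condensed sketch of how the pieces fit together (the directory path is a placeholder; `train.py` wires this up with the full set of augmentation options):

```python
# Condensed sketch of the normalizer.py + train.py wiring; 'data/train' is a placeholder path.
import tensorflow as tf
from normalizer import Normalizer

normalizer = Normalizer()
datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=normalizer)
train_gen = datagen.flow_from_directory(
    'data/train', target_size=(224, 224), class_mode='categorical')

# One pass over the training files updates the running mean/std with
# Welford's algorithm, so the whole dataset never sits in memory at once.
normalizer.get_stats('data/train', train_gen.filenames, (224, 224))
```
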
## Install

1. `conda create -n effnet python=3.6.8`
2. `conda activate effnet`
3. `git clone https://github.com/monatis/effnet-tf2.git`
4. `cd effnet-tf2`
5. `python -m pip install -r requirements.gpu.txt` # Use `requirements.cpu.txt` instead if you are not training on a GPU.

## Usage

The `train_dir` and `validation_dir` directories should contain one subdirectory per class in the dataset. Then run:

- `python train.py --train_dir /path/to/training/images --validation_dir /path/to/validation/images`
- See the `model/` directory for the training output.

Run `python train.py --help` to see all the options.
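Training exports a TensorFlow SavedModel to `model/output`, along with the class index mapping in `model/labels.csv` and the dataset statistics in `model/stats.txt`. A minimal inference sketch, assuming your TensorFlow version can load the SavedModel directory directly (the image path and the mean/std values below are placeholders; use the statistics printed to `model/stats.txt` during training):

```python
# Minimal inference sketch; 'example.jpg' and the stats below are placeholders.
import numpy as np
import tensorflow as tf
from normalizer import Normalizer

model = tf.keras.models.load_model('model/output')

# Reuse the dataset statistics computed during training (see model/stats.txt).
normalizer = Normalizer()
normalizer.set_stats([123.7, 116.3, 103.5], [58.4, 57.1, 57.4])

img = tf.keras.preprocessing.image.load_img('example.jpg', target_size=(224, 224))
x = normalizer(tf.keras.preprocessing.image.img_to_array(img))
probs = model.predict(np.expand_dims(x, axis=0))[0]
print(int(np.argmax(probs)))  # class index; look it up in model/labels.csv
```
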
## Roadmap

- [x] Share the model architecture and a training script.
- [x] Implement export to SavedModel.
- [x] Implement command line arguments to configure data augmentation.
- [ ] Share an inference script.
- [x] Implement mean and std normalization.
- [ ] Implement a confusion matrix.
- [ ] Implement export to TFLite for model inference.
- [ ] Share an example Android app using the exported TFLite model.

## License

MIT
--------------------------------------------------------------------------------
/effnet.py:
--------------------------------------------------------------------------------
"""
# Reference
- [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks]
(https://arxiv.org/abs/1905.11946)
"""

import tensorflow as tf


def get_top(x_input):
    """Block top operations.
    This function applies batch normalization and a Leaky ReLU activation to its input.

    # Arguments:
        x_input: Tensor, input to apply BN and activation to.

    # Returns:
        Output tensor.
    """

    x = tf.keras.layers.BatchNormalization()(x_input)
    x = tf.keras.layers.LeakyReLU()(x)
    return x


def get_block(x_input, input_channels, output_channels):
    """MBConv block.
    This function defines a mobile inverted residual bottleneck block with BN and Leaky ReLU.

    # Arguments
        x_input: Tensor, input tensor of the conv layer.
        input_channels: Integer, the dimensionality of the input space.
        output_channels: Integer, the dimensionality of the output space.

    # Returns
        Output tensor.
    """

    x = tf.keras.layers.Conv2D(input_channels, kernel_size=(1, 1), padding='same', use_bias=False)(x_input)
    x = get_top(x)
    x = tf.keras.layers.DepthwiseConv2D(kernel_size=(1, 3), padding='same', use_bias=False)(x)
    x = get_top(x)
    x = tf.keras.layers.MaxPooling2D(pool_size=(2, 1), strides=(2, 1))(x)
    x = tf.keras.layers.DepthwiseConv2D(kernel_size=(3, 1), padding='same', use_bias=False)(x)
    x = get_top(x)
    x = tf.keras.layers.Conv2D(output_channels, kernel_size=(2, 1), strides=(1, 2), padding='same', use_bias=False)(x)
    return x


def EffNet(input_shape, num_classes, plot_model=False):
    """EffNet
    This function defines the EfficientNet architecture.

    # Arguments
        input_shape: An integer or tuple/list of 3 integers, shape of the input tensor.
        num_classes: Integer, number of classes.
        plot_model: Boolean, whether to plot the model architecture or not.

    # Returns
        EfficientNet model.
    """
    x_input = tf.keras.layers.Input(shape=input_shape)
    x = get_block(x_input, 32, 64)
    x = get_block(x, 64, 128)
    x = get_block(x, 128, 256)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(num_classes, activation='softmax')(x)
    model = tf.keras.models.Model(inputs=x_input, outputs=x)

    if plot_model:
        tf.keras.utils.plot_model(model, to_file='model.png', show_shapes=True)

    return model
--------------------------------------------------------------------------------
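A quick way to sanity-check the architecture above is to build it standalone and print a summary, e.g.:

```python
# Build the model for 224x224 RGB inputs and a hypothetical 10-class problem.
from effnet import EffNet

model = EffNet((224, 224, 3), num_classes=10)
model.summary()
```
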
/normalizer.py:
--------------------------------------------------------------------------------
import os

import numpy as np
import tensorflow as tf


def get_mean_std(base_dir, filenames, target_size):
    """Compute per-channel mean and standard deviation with Welford's online algorithm."""
    n = 0
    r_mean, g_mean, b_mean = 0.0, 0.0, 0.0
    r_M2, g_M2, b_M2 = 0.0, 0.0, 0.0

    for z, filename in enumerate(filenames):
        if z % 1000 == 0:
            print("Processing image {}/{}".format(z + 1, len(filenames)))

        img = tf.keras.preprocessing.image.load_img(os.path.join(base_dir, filename), target_size=target_size)
        x = tf.keras.preprocessing.image.img_to_array(img)
        r = x[:, :, 0].flatten().tolist()
        g = x[:, :, 1].flatten().tolist()
        b = x[:, :, 2].flatten().tolist()

        for (xr, xg, xb) in zip(r, g, b):
            n = n + 1

            r_delta = xr - r_mean
            g_delta = xg - g_mean
            b_delta = xb - b_mean

            r_mean = r_mean + r_delta / n
            g_mean = g_mean + g_delta / n
            b_mean = b_mean + b_delta / n

            r_M2 = r_M2 + r_delta * (xr - r_mean)
            g_M2 = g_M2 + g_delta * (xg - g_mean)
            b_M2 = b_M2 + b_delta * (xb - b_mean)

    r_variance = r_M2 / (n - 1)
    g_variance = g_M2 / (n - 1)
    b_variance = b_M2 / (n - 1)

    r_std = np.sqrt(r_variance)
    g_std = np.sqrt(g_variance)
    b_std = np.sqrt(b_variance)

    return np.array([r_mean, g_mean, b_mean]), np.array([r_std, g_std, b_std])


class Normalizer:
    """Centers and scales images; usable as an ImageDataGenerator preprocessing_function."""

    def __init__(self, mean=None, std=None):
        self.mean = mean
        self.std = std

    def __call__(self, img):
        if self.mean is not None:
            img = self.center(img)
        if self.std is not None:
            img = self.scale(img)
        return img

    def center(self, img):
        return img - self.mean

    def scale(self, img):
        return img / self.std

    def set_stats(self, mean, std):
        self.mean = np.array(mean).reshape(1, 1, 3)
        self.std = np.array(std).reshape(1, 1, 3)

    def get_stats(self, base_dir, filenames, target_size, calc_mean=True, calc_std=True):
        print("Calculating mean and standard deviation with shape: ", target_size)
        m, s = get_mean_std(base_dir, filenames, target_size)
        if calc_mean:
            self.mean = m.reshape(1, 1, 3)
            print("Dataset mean [r, g, b] = {}".format(m.tolist()))
        if calc_std:
            self.std = s.reshape(1, 1, 3)
            print("Dataset std [r, g, b] = {}".format(s.tolist()))

        return str(m.tolist()), str(s.tolist())
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
"""
Train the EffNet model.
"""
import os
import argparse
import csv

import pandas as pd
import tensorflow as tf

from effnet import EffNet
from normalizer import Normalizer


def generate(args):
    """Data generation and augmentation.

    # Arguments
        args: Namespace, parsed command line arguments.

    # Returns
        train_generator: train set generator
        validation_generator: validation set generator
        num_training: Integer, number of images in the train split.
        num_validation: Integer, number of images in the validation split.
        num_classes: Integer, number of classes in the dataset.
    """

    # Use data augmentation on the training data.
    normalizer = Normalizer()

    train_aug = tf.keras.preprocessing.image.ImageDataGenerator(
        # rescale=1. / 255.,
        shear_range=args.shear_range,
        zoom_range=args.zoom_range,
        rotation_range=args.rotation_range,
        width_shift_range=args.width_shift_range,
        height_shift_range=args.height_shift_range,
        horizontal_flip=args.horizontal_flip,
        vertical_flip=args.vertical_flip,
        preprocessing_function=normalizer)

    validation_aug = tf.keras.preprocessing.image.ImageDataGenerator(preprocessing_function=normalizer)

    train_generator = train_aug.flow_from_directory(
        args.train_dir,
        target_size=(args.input_size, args.input_size),
        batch_size=args.batch_size,
        class_mode='categorical',
        shuffle=True)

    # Either compute the dataset statistics with Welford's algorithm or reuse the values passed on the command line.
    if args.mean is None or args.std is None:
        mean, std = normalizer.get_stats(args.train_dir, train_generator.filenames, (args.input_size, args.input_size))
    else:
        mean = [float(m.strip()) for m in args.mean.split(',')]
        std = [float(s.strip()) for s in args.std.split(',')]
        normalizer.set_stats(mean, std)

    if not os.path.exists('model'):
        os.makedirs('model')
    with open('model/stats.txt', 'w') as stats:
        stats.write("Dataset mean [r, g, b] = {}\n".format(mean))
        stats.write("Dataset std [r, g, b] = {}\n".format(std))

    label_map = train_generator.class_indices
    label_map = dict((v, k) for k, v in label_map.items())

    with open('model/labels.csv', 'w') as csv_file:
        csv_writer = csv.writer(csv_file, lineterminator='\n')
        csv_writer.writerows(label_map.items())

    validation_generator = validation_aug.flow_from_directory(
        args.validation_dir,
        target_size=(args.input_size, args.input_size),
        batch_size=args.batch_size,
        class_mode='categorical')

    return train_generator, validation_generator, train_generator.samples, validation_generator.samples, len(label_map)

def train(args):
    """Train the model.

    # Arguments
        args: Namespace, parsed command line arguments."""

    train_generator, validation_generator, num_training, num_validation, num_classes = generate(args)
    print("{} classes found".format(num_classes))

    model = EffNet((args.input_size, args.input_size, 3), num_classes, args.plot_model)

    opt = tf.keras.optimizers.Adam()
    earlystop = tf.keras.callbacks.EarlyStopping(monitor='val_acc', patience=30, verbose=1, mode='auto')
    model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['acc'])

    hist = model.fit_generator(
        train_generator,
        validation_data=validation_generator,
        steps_per_epoch=num_training // args.batch_size,
        validation_steps=num_validation // args.batch_size,
        epochs=args.epochs,
        callbacks=[earlystop])

    if not os.path.exists('model'):
        os.makedirs('model')

    df = pd.DataFrame.from_dict(hist.history)
    df.to_csv('model/hist.csv', encoding='utf-8', index=False)
    if not os.path.exists('model/output'):
        os.makedirs('model/output')
    model.save('model/output')


def str2bool(value):
    """Parse boolean command line flags; argparse's type=bool would treat any non-empty string as True."""
    return str(value).lower() in ('true', 't', 'yes', '1')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Required arguments.
    parser.add_argument(
        "-t",
        "--train_dir",
        required=True,
        help="Path to directory containing training images")
    parser.add_argument(
        "-v",
        "--validation_dir",
        required=True,
        help="Path to directory containing validation images")
    # Optional arguments.
    parser.add_argument(
        "-s",
        "--input_size",
        type=int,
        default=224,
        help="Input image size.")
    parser.add_argument(
        "-b",
        "--batch_size",
        type=int,
        default=32,
        help="Number of images in a training batch.")
    parser.add_argument(
        "-e",
        "--epochs",
        type=int,
        default=50,
        help="Number of training epochs.")
    parser.add_argument(
        "-p",
        "--plot_model",
        type=str2bool,
        default=False,
        help="Whether or not to plot the model architecture.")
    parser.add_argument(
        "--shear_range",
        type=float,
        default=0.2,
        help="Shear range value for data augmentation.")
    parser.add_argument(
        "--zoom_range",
        type=float,
        default=0.2,
        help="Zoom range value for data augmentation.")
    parser.add_argument(
        "--rotation_range",
        type=int,
        default=90,
        help="Rotation range value for data augmentation.")
    parser.add_argument(
        "--width_shift_range",
        type=float,
        default=0.2,
        help="Width shift range value for data augmentation.")
    parser.add_argument(
        "--height_shift_range",
        type=float,
        default=0.2,
        help="Height shift range value for data augmentation.")
    parser.add_argument(
        "--horizontal_flip",
        type=str2bool,
        default=True,
        help="Whether or not to flip horizontally for data augmentation.")
    parser.add_argument(
        "--vertical_flip",
        type=str2bool,
        default=False,
        help="Whether or not to flip vertically for data augmentation.")
    parser.add_argument(
        "--mean",
        default=None,
        help="Dataset mean values for the r, g, b channels, separated by commas")
    parser.add_argument(
        "--std",
        default=None,
        help="Dataset std values for the r, g, b channels, separated by commas")

    args = parser.parse_args()
    train(args)
--------------------------------------------------------------------------------