├── .gitignore ├── Mnasnet.py ├── MnasnetEager.py ├── README.md ├── mnasnet.png ├── train.py └── train_eager.py /.gitignore: -------------------------------------------------------------------------------- 1 | weights 2 | weights/* 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | -------------------------------------------------------------------------------- /Mnasnet.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras import optimizers, layers, models, callbacks, utils, preprocessing, regularizers 2 | from tensorflow.keras import backend as K 3 | import tensorflow as tf 4 | import numpy as np 5 | 6 | 7 | 8 | 9 | def MnasNet(n_classes=1000, input_shape=(224, 224, 3), alpha=1): 10 | inputs = layers.Input(shape=input_shape) 11 | 12 | x = conv_bn(inputs, 32*alpha, 3, strides=2) 13 | x = sepConv_bn_noskip(x, 16*alpha, 3, strides=1) 14 | # MBConv3 3x3 15 | x = MBConv_idskip(x, filters=24, kernel_size=3, strides=2, filters_multiplier=3, alpha=alpha) 16 | x = MBConv_idskip(x, filters=24, kernel_size=3, strides=1, filters_multiplier=3, alpha=alpha) 17 | x = MBConv_idskip(x, filters=24, kernel_size=3, strides=1, filters_multiplier=3, alpha=alpha) 18 | # MBConv3 5x5 19 | x = MBConv_idskip(x, filters=40, kernel_size=5, strides=2, filters_multiplier=3, alpha=alpha) 20 | x = MBConv_idskip(x, filters=40, kernel_size=5, strides=1, filters_multiplier=3, alpha=alpha) 21 | x = MBConv_idskip(x, filters=40, kernel_size=5, strides=1, filters_multiplier=3, alpha=alpha) 22 | # MBConv6 5x5 23 | x = MBConv_idskip(x, filters=80, kernel_size=5, strides=2, filters_multiplier=6, alpha=alpha) 24 | x = MBConv_idskip(x, filters=80, kernel_size=5, strides=1, filters_multiplier=6, alpha=alpha) 25 | x = MBConv_idskip(x, filters=80, kernel_size=5, strides=1, filters_multiplier=6, alpha=alpha) 26 | # MBConv6 3x3 27 | x = 
MBConv_idskip(x, filters=96, kernel_size=3, strides=1, filters_multiplier=6, alpha=alpha) 28 | x = MBConv_idskip(x, filters=96, kernel_size=3, strides=1, filters_multiplier=6, alpha=alpha) 29 | # MBConv6 5x5 30 | x = MBConv_idskip(x, filters=192, kernel_size=5, strides=2, filters_multiplier=6, alpha=alpha) 31 | x = MBConv_idskip(x, filters=192, kernel_size=5, strides=1, filters_multiplier=6, alpha=alpha) 32 | x = MBConv_idskip(x, filters=192, kernel_size=5, strides=1, filters_multiplier=6, alpha=alpha) 33 | x = MBConv_idskip(x, filters=192, kernel_size=5, strides=1, filters_multiplier=6, alpha=alpha) 34 | # MBConv6 3x3 35 | x = MBConv_idskip(x, filters=320, kernel_size=3, strides=1, filters_multiplier=6, alpha=alpha) 36 | 37 | # FC + POOL 38 | x = conv_bn(x, filters=1152*alpha, kernel_size=1, strides=1) 39 | x = layers.GlobalAveragePooling2D()(x) 40 | predictions = layers.Dense(n_classes, activation='softmax')(x) 41 | 42 | return models.Model(inputs=inputs, outputs=predictions) 43 | 44 | 45 | 46 | 47 | # Convolution with batch normalization 48 | def conv_bn(x, filters, kernel_size, strides=1, alpha=1, activation=True): 49 | """Convolution Block 50 | This function defines a 2D convolution operation with BN and relu6. 51 | # Arguments 52 | x: Tensor, input tensor of conv layer. 53 | filters: Integer, the dimensionality of the output space. 54 | kernel_size: An integer or tuple/list of 2 integers, specifying the 55 | width and height of the 2D convolution window. 56 | strides: An integer or tuple/list of 2 integers, 57 | specifying the strides of the convolution along the width and height. 58 | Can be a single integer to specify the same value for 59 | all spatial dimensions. 60 | alpha: An integer which multiplies the filters dimensionality 61 | activation: A boolean which indicates whether to have an activation after the normalization 62 | # Returns 63 | Output tensor. 64 | """ 65 | filters = _make_divisible(filters * alpha) 66 | x = layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding='same', 67 | use_bias=False, kernel_regularizer=regularizers.l2(l=0.0003))(x) 68 | x = layers.BatchNormalization(epsilon=1e-3, momentum=0.999)(x) 69 | if activation: 70 | x = layers.ReLU(max_value=6)(x) 71 | return x 72 | 73 | # Depth-wise Separable Convolution with batch normalization 74 | def depthwiseConv_bn(x, depth_multiplier, kernel_size, strides=1): 75 | """ Depthwise convolution 76 | The DepthwiseConv2D is just the first step of the Depthwise Separable convolution (without the pointwise step). 77 | Depthwise Separable convolutions consists in performing just the first step in a depthwise spatial convolution 78 | (which acts on each input channel separately). 79 | 80 | This function defines a 2D Depthwise separable convolution operation with BN and relu6. 81 | # Arguments 82 | x: Tensor, input tensor of conv layer. 83 | filters: Integer, the dimensionality of the output space. 84 | kernel_size: An integer or tuple/list of 2 integers, specifying the 85 | width and height of the 2D convolution window. 86 | strides: An integer or tuple/list of 2 integers, 87 | specifying the strides of the convolution along the width and height. 88 | Can be a single integer to specify the same value for 89 | all spatial dimensions. 90 | # Returns 91 | Output tensor. 
92 | """ 93 | 94 | x = layers.DepthwiseConv2D(kernel_size=kernel_size, strides=strides, depth_multiplier=depth_multiplier, 95 | padding='same', use_bias=False, kernel_regularizer=regularizers.l2(l=0.0003))(x) 96 | x = layers.BatchNormalization(epsilon=1e-3, momentum=0.999)(x) 97 | x = layers.ReLU(max_value=6)(x) 98 | return x 99 | 100 | def sepConv_bn_noskip(x, filters, kernel_size, strides=1): 101 | """ Separable convolution block (Block F of MNasNet paper https://arxiv.org/pdf/1807.11626.pdf) 102 | 103 | # Arguments 104 | x: Tensor, input tensor of conv layer. 105 | filters: Integer, the dimensionality of the output space. 106 | kernel_size: An integer or tuple/list of 2 integers, specifying the 107 | width and height of the 2D convolution window. 108 | strides: An integer or tuple/list of 2 integers, 109 | specifying the strides of the convolution along the width and height. 110 | Can be a single integer to specify the same value for 111 | all spatial dimensions. 112 | # Returns 113 | Output tensor. 114 | """ 115 | 116 | x = depthwiseConv_bn(x, depth_multiplier=1, kernel_size=kernel_size, strides=strides) 117 | x = conv_bn(x, filters=filters, kernel_size=1, strides=1) 118 | 119 | return x 120 | 121 | # Inverted bottleneck block with identity skip connection 122 | def MBConv_idskip(x_input, filters, kernel_size, strides=1, filters_multiplier=1, alpha=1): 123 | """ Mobile inverted bottleneck convolution (Block b, c, d, e of MNasNet paper https://arxiv.org/pdf/1807.11626.pdf) 124 | 125 | # Arguments 126 | x: Tensor, input tensor of conv layer. 127 | filters: Integer, the dimensionality of the output space. 128 | kernel_size: An integer or tuple/list of 2 integers, specifying the 129 | width and height of the 2D convolution window. 130 | strides: An integer or tuple/list of 2 integers, 131 | specifying the strides of the convolution along the width and height. 132 | Can be a single integer to specify the same value for 133 | all spatial dimensions. 134 | alpha: An integer which multiplies the filters dimensionality 135 | 136 | # Returns 137 | Output tensor. 138 | """ 139 | 140 | depthwise_conv_filters = _make_divisible(x_input.shape[3].value) 141 | pointwise_conv_filters = _make_divisible(filters * alpha) 142 | 143 | x = conv_bn(x_input, filters=depthwise_conv_filters * filters_multiplier, kernel_size=1, strides=1) 144 | x = depthwiseConv_bn(x, depth_multiplier=1, kernel_size=kernel_size, strides=strides) 145 | x = conv_bn(x, filters=pointwise_conv_filters, kernel_size=1, strides=1, activation=False) 146 | 147 | # Residual connection if possible 148 | if strides==1 and x.shape[3] == x_input.shape[3]: 149 | return layers.add([x_input, x]) 150 | else: 151 | return x 152 | 153 | 154 | # This function is taken from the original tf repo. 155 | # It ensures that all layers have a channel number that is divisible by 8 156 | # It can be seen here: 157 | # https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 158 | def _make_divisible(v, divisor=8, min_value=None): 159 | if min_value is None: 160 | min_value = divisor 161 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 162 | # Make sure that round down does not go down by more than 10%. 
163 | if new_v < 0.9 * v: 164 | new_v += divisor 165 | return new_v 166 | 167 | 168 | 169 | if __name__ == "__main__": 170 | 171 | model = MnasNet() 172 | model.compile(optimizer='adam') 173 | model.summary() 174 | -------------------------------------------------------------------------------- /MnasnetEager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tensorflow as tf 4 | from tensorflow.keras import layers, regularizers, activations 5 | 6 | 7 | 8 | class Mnasnet(tf.keras.Model): 9 | def __init__(self, num_classes, alpha=1, **kwargs): 10 | super(Mnasnet, self).__init__(**kwargs) 11 | self.blocks = [] 12 | 13 | self.conv_bn_initial = Conv_BN(filters=32*alpha, kernel_size=3, strides=2) 14 | 15 | # First block (non-identity): DepthwiseConv + Conv 16 | self.conv1_block1 = depthwiseConv(depth_multiplier=1, kernel_size=3, strides=1) 17 | self.bn1_block1 = layers.BatchNormalization(epsilon=1e-3, momentum=0.999) 18 | self.relu1_block1 = layers.ReLU(max_value=6) 19 | 20 | self.conv_bn_block_1 = Conv_BN(filters=16*alpha, kernel_size=1, strides=1) 21 | 22 | # MBConv3 3x3 23 | self.blocks.append(MBConv_idskip(input_filters=16*alpha, filters=24, kernel_size=3, strides=2, 24 | filters_multiplier=3, alpha=alpha)) 25 | self.blocks.append(MBConv_idskip(input_filters=24*alpha, filters=24, kernel_size=3, strides=1, 26 | filters_multiplier=3, alpha=alpha)) 27 | self.blocks.append(MBConv_idskip(input_filters=24*alpha, filters=24, kernel_size=3, strides=1, 28 | filters_multiplier=3, alpha=alpha)) 29 | 30 | # MBConv3 5x5 31 | self.blocks.append(MBConv_idskip(input_filters=24*alpha, filters=40, kernel_size=5, strides=2, 32 | filters_multiplier=3, alpha=alpha)) 33 | self.blocks.append(MBConv_idskip(input_filters=40*alpha, filters=40, kernel_size=5, strides=1, 34 | filters_multiplier=3, alpha=alpha)) 35 | self.blocks.append(MBConv_idskip(input_filters=40*alpha, filters=40, kernel_size=5, strides=1, 36 | filters_multiplier=3, alpha=alpha)) 37 | # MBConv6 5x5 38 | self.blocks.append(MBConv_idskip(input_filters=40*alpha, filters=80, kernel_size=5, strides=2, 39 | filters_multiplier=6, alpha=alpha)) 40 | self.blocks.append(MBConv_idskip(input_filters=80*alpha, filters=80, kernel_size=5, strides=1, 41 | filters_multiplier=6, alpha=alpha)) 42 | self.blocks.append(MBConv_idskip(input_filters=80*alpha, filters=80, kernel_size=5, strides=1, 43 | filters_multiplier=6, alpha=alpha)) 44 | 45 | # MBConv6 3x3 46 | self.blocks.append(MBConv_idskip(input_filters=80*alpha, filters=96, kernel_size=3, strides=1, 47 | filters_multiplier=6, alpha=alpha)) 48 | self.blocks.append(MBConv_idskip(input_filters=96*alpha, filters=96, kernel_size=3, strides=1, 49 | filters_multiplier=6, alpha=alpha)) 50 | 51 | # MBConv6 5x5 52 | self.blocks.append(MBConv_idskip(input_filters=96*alpha, filters=192, kernel_size=5, strides=2, 53 | filters_multiplier=6, alpha=alpha)) 54 | self.blocks.append(MBConv_idskip(input_filters=192*alpha, filters=192, kernel_size=5, strides=1, 55 | filters_multiplier=6, alpha=alpha)) 56 | self.blocks.append(MBConv_idskip(input_filters=192*alpha, filters=192, kernel_size=5, strides=1, 57 | filters_multiplier=6, alpha=alpha)) 58 | self.blocks.append(MBConv_idskip(input_filters=192*alpha, filters=192, kernel_size=5, strides=1, 59 | filters_multiplier=6, alpha=alpha)) 60 | # MBConv6 3x3 61 | self.blocks.append(MBConv_idskip(input_filters=192*alpha, filters=320, kernel_size=3, strides=1, 62 | filters_multiplier=6, alpha=alpha)) 63 | 64 | # 
Last convolution 65 | self.conv_bn_last = Conv_BN(filters=1152*alpha, kernel_size=1, strides=1) 66 | 67 | # Pool + FC 68 | self.avg_pool = layers.GlobalAveragePooling2D() 69 | self.fc = layers.Dense(num_classes) 70 | 71 | 72 | def call(self, inputs, training=None, mask=None): 73 | out = self.conv_bn_initial(inputs, training=training) 74 | 75 | 76 | out = self.conv1_block1(out) 77 | out = self.bn1_block1(out, training=training) 78 | out = self.relu1_block1(out) 79 | 80 | out = self.conv_bn_block_1(out, training=training) 81 | 82 | # forward pass through all the blocks 83 | for block in self.blocks: 84 | out = block(out, training=training) 85 | 86 | out = self.conv_bn_last(out, training=training) 87 | 88 | out = self.avg_pool(out) 89 | out = self.fc(out) 90 | 91 | ''' 92 | You could return several outputs, even intermediate outputs 93 | ''' 94 | return out 95 | 96 | 97 | 98 | class MBConv_idskip(tf.keras.Model): 99 | 100 | def __init__(self, input_filters, filters, kernel_size, strides=1, filters_multiplier=1, alpha=1): 101 | super(MBConv_idskip, self).__init__() 102 | 103 | self.filters = filters 104 | self.kernel_size = kernel_size 105 | self.strides = strides 106 | self.filters_multiplier = filters_multiplier 107 | self.alpha = alpha 108 | 109 | self.depthwise_conv_filters = _make_divisible(input_filters) 110 | self.pointwise_conv_filters = _make_divisible(filters * alpha) 111 | 112 | #conv1 113 | self.conv_bn1 = Conv_BN(filters=self.depthwise_conv_filters*filters_multiplier, kernel_size=1, strides=1) 114 | 115 | #depthwiseconv2 116 | self.depthwise_conv = depthwiseConv(depth_multiplier=1, kernel_size=kernel_size, strides=strides) 117 | self.bn = layers.BatchNormalization(epsilon=1e-3, momentum=0.999) 118 | self.relu = layers.ReLU(max_value=6) 119 | 120 | #conv3 121 | self.conv_bn2 = Conv_BN(filters=self.pointwise_conv_filters, kernel_size=1, strides=1) 122 | 123 | 124 | 125 | def call(self, inputs, training=None): 126 | 127 | x = self.conv_bn1(inputs, training=training) 128 | 129 | x = self.depthwise_conv(x) 130 | x = self.bn(x, training=training) 131 | x = self.relu(x) 132 | 133 | x = self.conv_bn2(x, training=training, activation=False) 134 | 135 | 136 | # Residual/Identity connection if possible 137 | if self.strides==1 and x.shape[3] == inputs.shape[3]: 138 | return layers.add([inputs, x]) 139 | else: 140 | return x 141 | 142 | 143 | class Conv_BN(tf.keras.Model): 144 | 145 | def __init__(self, filters, kernel_size, strides=1): 146 | super(Conv_BN, self).__init__() 147 | 148 | self.filters = filters 149 | self.kernel_size = kernel_size 150 | self.strides = strides 151 | 152 | self.conv = conv(filters=filters, kernel_size=kernel_size, strides=strides) 153 | self.bn = layers.BatchNormalization(epsilon=1e-3, momentum=0.999) 154 | self.relu = layers.ReLU(max_value=6) 155 | 156 | 157 | def call(self, inputs, training=None, activation=True): 158 | 159 | x = self.conv(inputs) 160 | x = self.bn(x, training=training) 161 | if activation: 162 | x = self.relu(x) 163 | 164 | return x 165 | 166 | # convolution 167 | def conv(filters, kernel_size, strides=1, dilation_rate=1, use_bias=False): 168 | return layers.Conv2D(filters, kernel_size, strides=strides, padding='same', use_bias=use_bias, 169 | kernel_regularizer=regularizers.l2(l=0.0003), dilation_rate=dilation_rate) 170 | 171 | 172 | # Depthwise convolution 173 | def depthwiseConv(kernel_size, strides=1, depth_multiplier=1, dilation_rate=1, use_bias=False): 174 | return layers.DepthwiseConv2D(kernel_size, strides=strides, 
depth_multiplier=depth_multiplier, 175 | padding='same', use_bias=use_bias, kernel_regularizer=regularizers.l2(l=0.0003), 176 | dilation_rate=dilation_rate) 177 | 178 | 179 | 180 | 181 | # This function is taken from the original tf repo. 182 | # It ensures that all layers have a channel number that is divisible by 8 183 | # It can be seen here: 184 | # https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 185 | def _make_divisible(v, divisor=8, min_value=None): 186 | if min_value is None: 187 | min_value = divisor 188 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 189 | # Make sure that round down does not go down by more than 10%. 190 | if new_v < 0.9 * v: 191 | new_v += divisor 192 | return new_v 193 | 194 | 195 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MNasNet 2 | [Keras (TensorFlow) implementation](https://github.com/Shathe/MNasNet-Keras-Tensorflow/blob/master/Mnasnet.py) of MnasNet, with an example for training and evaluating it on the MNIST dataset. 3 | Also check the [eager execution implementation](https://github.com/Shathe/MNasNet-Keras-Tensorflow/blob/master/MnasnetEager.py). 4 | 5 | Based on the paper: [MnasNet: Platform-Aware Neural Architecture Search for Mobile](https://arxiv.org/pdf/1807.11626.pdf) 6 | 7 | ## Requirements 8 | * Python 2.7+ 9 | * Tensorflow-gpu 1.10 10 | 11 | ## Train it 12 | To train the [MnasNet model](https://github.com/Shathe/MNasNet-Keras-Tensorflow/blob/master/Mnasnet.py) on the MNIST dataset, just execute: 13 | ``` 14 | python train.py 15 | ``` 16 | To build and inspect the MnasNet model described in the paper, execute: 17 | ``` 18 | python Mnasnet.py 19 | ``` 20 | 21 | 22 | 23 | ## Train it with eager execution 24 | To train the [MnasNet (eager) model](https://github.com/Shathe/MNasNet-Keras-Tensorflow/blob/master/MnasnetEager.py) on the MNIST dataset, just execute: 25 | 26 | ``` 27 | python train_eager.py 28 | ``` 29 | 30 | The eager execution implementation also writes logs for TensorBoard. To visualize them: 31 | ``` 32 | tensorboard --logdir=train_log:./logs/train, test_log:./logs/test 33 | ``` 34 | 35 | ## MnasNet for... Semantic Segmentation! 36 | In another repository, [FC-Mnasnet](https://github.com/Shathe/Semantic-Segmentation-Tensorflow-Eager), I added a decoder on top of the MnasNet architecture to turn it into a semantic segmentation model; a minimal sketch of the idea is shown below. 
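A minimal, hypothetical sketch of that idea (not the actual FC-Mnasnet code), reusing the Keras `MnasNet` from this repo as a backbone; the layer index, decoder shape and class count are illustrative assumptions:
```
# Hypothetical sketch, not the FC-Mnasnet implementation: take the feature map
# that feeds the global pooling layer and add a tiny decoder on top of it.
from tensorflow.keras import layers, models
import Mnasnet

n_seg_classes = 21  # assumption, e.g. Pascal VOC classes
backbone = Mnasnet.MnasNet(input_shape=(224, 224, 3))
features = backbone.layers[-3].output  # ReLU output just before GlobalAveragePooling2D
x = layers.Conv2D(256, 3, padding='same', activation='relu')(features)
x = layers.UpSampling2D(size=32)(x)  # 7x7 features back to the 224x224 input resolution
masks = layers.Conv2D(n_seg_classes, 1, activation='softmax')(x)  # per-pixel class scores
seg_model = models.Model(inputs=backbone.input, outputs=masks)
seg_model.summary()
```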
37 | 38 | 39 | 40 | ![alt text](https://github.com/Shathe/MNasNet-Keras-Tensorflow/raw/master/mnasnet.png) 41 | -------------------------------------------------------------------------------- /mnasnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shathe/MNasNet-Keras-Tensorflow/502f9df7d9a837c4d4e23a5e76851f733914a342/mnasnet.png -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | mnist = tf.keras.datasets.mnist 3 | import Mnasnet 4 | import numpy as np 5 | 6 | # Load MNIST data 7 | (x_train, y_train),(x_test, y_test) = mnist.load_data() 8 | # Preprocess data: scale pixels to [-1, 1] and add a channel dimension 9 | x_train, x_test = x_train / 127.5 - 1, x_test / 127.5 - 1 10 | x_train = np.expand_dims(x_train, axis=3) 11 | x_test = np.expand_dims(x_test, axis=3) 12 | 13 | # Load model (MNIST has 10 classes) 14 | model = Mnasnet.MnasNet(n_classes=10, input_shape=(28, 28, 1)) 15 | 16 | model.compile(optimizer='adam', 17 | loss='sparse_categorical_crossentropy', 18 | metrics=['accuracy']) 19 | model.summary() 20 | # Train it 21 | model.fit(x_train, y_train, epochs=20) 22 | # Evaluate it 23 | loss, acc = model.evaluate(x_test, y_test) 24 | print('Accuracy of: ' + str(acc*100.) + '%') -------------------------------------------------------------------------------- /train_eager.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import tensorflow.contrib.eager as tfe 4 | import MnasnetEager 5 | 6 | # enable eager mode 7 | tf.enable_eager_execution() 8 | tf.set_random_seed(0) 9 | np.random.seed(0) 10 | 11 | 12 | # Define the loss function 13 | def loss_function(model, x, y, training=True): 14 | y_ = model(x, training=training) 15 | loss = tf.losses.softmax_cross_entropy(y, y_) 16 | print(loss) 17 | return loss 18 | 19 | # Prints the number of parameters of a model 20 | def get_params(model): 21 | total_parameters = 0 22 | for variable in model.variables: 23 | # shape is an array of tf.Dimension 24 | shape = variable.get_shape() 25 | variable_parameters = 1 26 | 27 | for dim in shape: 28 | variable_parameters *= dim.value 29 | total_parameters += variable_parameters 30 | print("Total parameters of the net: " + str(total_parameters)+ " == " + str(total_parameters/1000000.0) + "M") 31 | 32 | # Returns a pretrained model (can be used in eager execution) 33 | def get_pretrained_model(num_classes, input_shape=(224, 224, 3)): 34 | model = tf.keras.applications.ResNet50(input_shape=input_shape, include_top=False, weights='imagenet') 35 | logits = tf.keras.layers.Dense(num_classes, name='fc')(model.output) 36 | model = tf.keras.models.Model(model.inputs, logits) 37 | return model 38 | 39 | # Writes a summary given a tensor 40 | def write_summary(tensor, writer, name): 41 | with tf.contrib.summary.always_record_summaries(): # record_summaries_every_n_global_steps(1) 42 | writer.set_as_default() 43 | tf.contrib.summary.scalar(name, tensor) 44 | 45 | 46 | # Trains the model for a given number of epochs on a dataset 47 | def train(dset_train, dset_test, model, epochs=5, show_loss=False): 48 | # Define summary writers and global step for logging 49 | writer_train = tf.contrib.summary.create_file_writer('./logs/train') 50 | writer_test = tf.contrib.summary.create_file_writer('./logs/test') 51 | global_step=tf.train.get_or_create_global_step() # return global step var 52 | 53 | for epoch in 
range(epochs): 54 | print('epoch: '+ str(epoch)) 55 | for x, y in dset_train: # for every batch 56 | global_step.assign_add(1) # add one step per iteration 57 | 58 | with tf.GradientTape() as g: 59 | y_ = model(x, training=True) 60 | loss = tf.losses.softmax_cross_entropy(y, y_) 61 | write_summary(loss, writer_train, 'loss') 62 | if show_loss: print('Training loss: ' + str(loss.numpy())) 63 | 64 | # Gets gradients and applies them 65 | grads = g.gradient(loss, model.variables) 66 | optimizer.apply_gradients(zip(grads, model.variables)) 67 | 68 | # Get accuracies 69 | train_acc = get_accuracy(dset_train, model, training=True) 70 | test_acc = get_accuracy(dset_test, model, writer=writer_test) 71 | # write summaries and print 72 | write_summary(train_acc, writer_train, 'accuracy') 73 | write_summary(test_acc, writer_test, 'accuracy') 74 | print('Train accuracy: ' + str(train_acc.numpy())) 75 | print('Test accuracy: ' + str(test_acc.numpy())) 76 | 77 | 78 | # Tests the model on a dataset 79 | def get_accuracy(dset_test, model, training=False, writer=None): 80 | accuracy = tfe.metrics.Accuracy() 81 | if writer: loss = [0, 0] 82 | 83 | for x, y in dset_test: # for every batch 84 | y_ = model(x, training=training) 85 | accuracy(tf.argmax(y, 1), tf.argmax(y_, 1)) 86 | 87 | if writer: 88 | loss[0] += tf.losses.softmax_cross_entropy(y, y_) 89 | loss[1] += 1. 90 | 91 | if writer: 92 | write_summary(tf.convert_to_tensor(loss[0]/loss[1]), writer, 'loss') 93 | 94 | return accuracy.result() 95 | 96 | 97 | def restore_state(saver, checkpoint): 98 | try: 99 | saver.restore(checkpoint) 100 | print('Model loaded') 101 | except Exception: 102 | print('Model not loaded') 103 | 104 | 105 | def init_model(model, input_shape): 106 | model._set_inputs(np.zeros(input_shape)) 107 | 108 | if __name__ == "__main__": 109 | 110 | 111 | # constants 112 | image_size = 28 113 | batch_size = 128 114 | epochs = 20 115 | num_classes = 10 116 | channels= 1 117 | 118 | # Get dataset 119 | (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() 120 | 121 | # Reshape images 122 | x_train = x_train.reshape(-1, image_size, image_size, channels).astype('float32') 123 | x_test = x_test.reshape(-1, image_size, image_size, channels).astype('float32') 124 | 125 | # We are normalizing the images to the range of [-1, 1] 126 | x_train = x_train / 127.5 - 1 127 | x_test = x_test / 127.5 - 1 128 | 129 | # One-hot encode the labels: from (N,) to (N, n_classes) 130 | y_train_ohe = tf.one_hot(y_train, depth=num_classes).numpy() 131 | y_test_ohe = tf.one_hot(y_test, depth=num_classes).numpy() 132 | 133 | 134 | print('x train', x_train.shape) 135 | print('y train', y_train_ohe.shape) 136 | print('x test', x_test.shape) 137 | print('y test', y_test_ohe.shape) 138 | 139 | # Creates the tf.Dataset 140 | n_elements_train = x_train.shape[0] 141 | n_elements_test = x_test.shape[0] 142 | dset_train = tf.data.Dataset.from_tensor_slices((x_train, y_train_ohe)).shuffle(n_elements_train).batch(batch_size) 143 | dset_test = tf.data.Dataset.from_tensor_slices((x_test, y_test_ohe)).shuffle(n_elements_test).batch(batch_size) 144 | 145 | # build model and optimizer 146 | model = MnasnetEager.Mnasnet(num_classes=10) 147 | 148 | 149 | # optimizer 150 | optimizer = tf.train.AdamOptimizer(0.001) 151 | 152 | # Init model (variables and input shape) 153 | init_model(model, input_shape=(batch_size, image_size, image_size, channels)) 154 | 155 | # Show the number of parameters of the model 156 | get_params(model) 157 | 158 | # Init saver 159 | saver_model = 
tfe.Saver(var_list=model.variables) # alternatively: ckpt = tfe.Checkpoint(model=model) 160 | 161 | restore_state(saver_model, 'weights/last_saver') 162 | 163 | train(dset_train=dset_train, dset_test=dset_test, model=model, epochs=epochs) 164 | 165 | saver_model.save('weights/last_saver') 166 | 167 | 168 | 169 | ''' 170 | tensorboard --logdir=train_log:./logs/train 171 | 172 | You can also optimize with just: 173 | optimizer.minimize(lambda: loss_function(model, x, y, training=True)) 174 | 175 | Or you can compile and train it as a Keras model: 176 | model.compile(optimizer=tf.train.AdamOptimizer(0.001), loss='categorical_crossentropy', metrics=['accuracy']) 177 | ''' 178 | 179 | --------------------------------------------------------------------------------
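As the closing comment of `train_eager.py` hints, the manual `GradientTape` loop can also be driven by handing a loss callable straight to the optimizer. A minimal, hypothetical sketch of that variant (random dummy batch, same TF 1.x eager setup assumed as in `train_eager.py`; not part of the original repo):
```
# Sketch only: optimizer-driven variant of the training step from train_eager.py.
# The random batch below is an illustrative assumption, not real data.
import numpy as np
import tensorflow as tf
import MnasnetEager

tf.enable_eager_execution()

model = MnasnetEager.Mnasnet(num_classes=10)
optimizer = tf.train.AdamOptimizer(0.001)

# Random MNIST-shaped batch: 8 grayscale 28x28 images with one-hot labels.
x = tf.constant(np.random.rand(8, 28, 28, 1).astype('float32'))
y = tf.one_hot(np.random.randint(0, 10, size=8), depth=10)

# In eager mode the loss is passed as a callable; the forward pass runs inside
# minimize(), which records it and applies gradients to the model variables.
for step in range(3):
    optimizer.minimize(lambda: tf.losses.softmax_cross_entropy(y, model(x, training=True)))
```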