├── README.md
└── se_resnet.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# se-resnet

# Squeeze-and-Excitation based ResNet architecture

The Squeeze-and-Excitation block (SE-block) was first proposed in the following paper:

https://arxiv.org/pdf/1709.01507v2.pdf

Instead of representing all channels of a given layer equally, it proposes a weighted representation in which the weight of each channel is learned inside the SE-block.
It introduces an additional hyperparameter, r (the reduction ratio), used inside the SE-block.
For c channels, the block learns a sigmoidal weight vector of size c (a 1x1xc tensor, to be exact) and multiplies it with the current tensor of the given layer.

![alt text](https://cdn-images-1.medium.com/max/1600/1*WNk-atKDUsZPvMddvYL01g.png)

Apart from ResNet, SE-blocks can also be added to other popular classification models such as Inception and DenseNet.
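To make the mechanism concrete, here is a minimal NumPy sketch of the squeeze-excite-scale computation (the 28x28x64 feature map, the ratio r = 8, and the random weights are illustrative assumptions, not values taken from the paper or from se_resnet.py, where the weights are learned):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

h, w, c, r = 28, 28, 64, 8          # feature map size and reduction ratio (assumed)
x = np.random.rand(h, w, c)         # stand-in for a convolutional feature map

# Squeeze: global average pooling collapses each channel to a single number.
s = x.mean(axis=(0, 1))             # shape (c,)

# Excite: a two-layer bottleneck (random stand-ins for learned weights),
# ReLU on the way down to c/r units, sigmoid on the way back up to c.
w1 = np.random.randn(c, c // r)
w2 = np.random.randn(c // r, c)
e = sigmoid(np.maximum(s @ w1, 0.0) @ w2)   # shape (c,), entries in (0, 1)

# Scale: multiply every channel of the original map by its learned weight.
y = x * e                           # the (c,) vector broadcasts over 1x1xc
assert y.shape == x.shape
```

With c = 64 and r = 8, the bottleneck adds only 64*8 + 8*64 = 1024 weights (plus biases) per block, which is how the ratio keeps the overhead of an SE-block small.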
--------------------------------------------------------------------------------
/se_resnet.py:
--------------------------------------------------------------------------------

# Author: Md. Ibrahim Khan

from keras import optimizers, losses
from keras.layers import *
from keras.models import Model
from keras.backend import int_shape
from keras.utils import to_categorical
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def se_block(block_input, num_filters, ratio=8):  # Squeeze-and-excitation block

    '''
    Args:
        block_input: input tensor to the squeeze-and-excitation block
        num_filters: no. of filters/channels in block_input
        ratio: hyperparameter denoting the factor by which the no. of channels
               is reduced in the excitation bottleneck

    Returns:
        scale: input tensor rescaled by the learned channel weights
    '''

    pool1 = GlobalAveragePooling2D()(block_input)                    # squeeze: (H, W, C) -> (C,)
    flat = Reshape((1, 1, num_filters))(pool1)                       # back to 1x1xC for broadcasting
    dense1 = Dense(num_filters // ratio, activation='relu')(flat)    # bottleneck: C -> C/ratio
    dense2 = Dense(num_filters, activation='sigmoid')(dense1)        # expand: C/ratio -> C, weights in (0, 1)
    scale = multiply([block_input, dense2])                          # excite: rescale every channel

    return scale


def resnet_block(block_input, num_filters):  # Single ResNet block

    '''
    Args:
        block_input: input tensor to the ResNet block
        num_filters: no. of filters/channels in the block's convolutions

    Returns:
        relu2: activated tensor after addition with the (projected) input
    '''

    # Project the shortcut with a 1x1 convolution when the channel counts
    # differ, so that the residual addition below is well-defined.
    if int_shape(block_input)[3] != num_filters:
        block_input = Conv2D(num_filters, kernel_size=(1, 1))(block_input)

    conv1 = Conv2D(num_filters, kernel_size=(3, 3), padding='same')(block_input)
    norm1 = BatchNormalization()(conv1)
    relu1 = Activation('relu')(norm1)
    conv2 = Conv2D(num_filters, kernel_size=(3, 3), padding='same')(relu1)
    norm2 = BatchNormalization()(conv2)

    se = se_block(norm2, num_filters=num_filters)

    residual = Add()([block_input, se])
    relu2 = Activation('relu')(residual)

    return relu2


def se_resnet14():

    '''
    Squeeze-and-excitation blocks applied to a 14-layer adapted version of ResNet18.
    Adapted for the MNIST dataset:
    input size is 28x28x1, representing the grayscale MNIST images;
    output size is 10, representing the digit classes.
    '''

    inputs = Input(shape=(28, 28, 1))
    conv1 = Conv2D(64, kernel_size=(7, 7), activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
    pool1 = MaxPooling2D((2, 2), strides=2)(conv1)

    block1 = resnet_block(pool1, 64)
    block2 = resnet_block(block1, 64)

    pool2 = MaxPooling2D((2, 2), strides=2)(block2)

    block3 = resnet_block(pool2, 128)
    block4 = resnet_block(block3, 128)

    pool3 = MaxPooling2D((3, 3), strides=2)(block4)

    block5 = resnet_block(pool3, 256)
    block6 = resnet_block(block5, 256)

    pool4 = MaxPooling2D((3, 3), strides=2)(block6)
    flat = Flatten()(pool4)

    output = Dense(10, activation='softmax')(flat)

    model = Model(inputs=inputs, outputs=output)
    return model


if __name__ == '__main__':

    model = se_resnet14()
    model.summary()

    # Training configuration
    model.compile(loss=losses.categorical_crossentropy,
                  optimizer=optimizers.Adam(),
                  metrics=['accuracy'])

    # Data preparation: CSVs with the label in the first column and the
    # 784 pixel values in the remaining columns
    train = pd.read_csv('mnist_train.csv')
    test = pd.read_csv('mnist_test.csv')

    X_train = np.array(train.iloc[:, 1:])
    y_train = to_categorical(np.array(train.iloc[:, 0]))

    X_test = np.array(test.iloc[:, 1:])
    y_test = to_categorical(np.array(test.iloc[:, 0]))

    # Reshape the flat pixel rows into 28x28x1 images and scale to [0, 1]
    X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32') / 255
    X_test = X_test.reshape(X_test.shape[0], 28, 28, 1).astype('float32') / 255

    # Training
    train_history = model.fit(X_train, y_train,
                              batch_size=128,
                              epochs=20,
                              verbose=1)

    # Evaluation
    result = model.evaluate(X_test, y_test, verbose=0)

    print("Test Loss:", result[0])
    print("Test Accuracy:", result[1])

    # Plotting loss and accuracy metrics; the history key is 'acc' in older
    # Keras releases and 'accuracy' in newer ones, so look up both.
    accuracy = train_history.history.get('accuracy', train_history.history.get('acc'))
    loss = train_history.history['loss']
    epochs = range(len(accuracy))

    plt.plot(epochs, accuracy, 'b', label='Training accuracy')
    plt.title('Training accuracy')
    plt.legend()
    plt.savefig('train_acc.jpg')
    plt.figure()

    plt.plot(epochs, loss, 'b', label='Training loss')
    plt.title('Training loss')
    plt.legend()
    plt.savefig('train_loss.jpg')
    plt.show()
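    # A possible next step (a sketch, not part of the original script): persist
    # the trained network and reload it for inference, e.g.
    #
    #   model.save('se_resnet14.h5')                # file name is an arbitrary choice
    #   restored = load_model('se_resnet14.h5')     # from keras.models import load_model
    #   print(restored.predict(X_test[:1]))         # ten class probabilities for one image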