├── images ├── HiResNet.png └── ResNet50.jpg ├── HiResNet.py ├── README.md └── model_heads.py /images/HiResNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johnGettings/Hi-ResNet/HEAD/images/HiResNet.png -------------------------------------------------------------------------------- /images/ResNet50.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johnGettings/Hi-ResNet/HEAD/images/ResNet50.jpg -------------------------------------------------------------------------------- /HiResNet.py: -------------------------------------------------------------------------------- 1 | from keras.models import Model 2 | from keras.layers import * 3 | from keras.applications.resnet import ResNet50 4 | 5 | from model_heads import * 6 | 7 | 8 | 9 | def HiResNet(size, weights, classes): 10 | 11 | if size == 448: 12 | hi_res_head = resnet_448_head() 13 | elif size == 896: 14 | hi_res_head = resnet_896_head() 15 | elif size == 1792: 16 | hi_res_head = resnet_1792_head() 17 | else: 18 | raise ValueError('size should be an integer value of: 448, 896, or 1792') 19 | 20 | if not isinstance(classes, int): 21 | raise ValueError('classes must be an integer') 22 | 23 | if weights == "Res50": 24 | base_model = ResNet50(weights="imagenet", include_top=False, input_shape=(224, 224, 3)) 25 | elif weights == "None": 26 | base_model = ResNet50(weights=None, include_top=False, input_shape=(224, 224, 3)) 27 | else: 28 | raise ValueError('weights should be either: \"Res50\" or \"None\"') 29 | 30 | # Constructing the ResNet50 base model. 
Removing top and first seven layers 31 | 32 | truncated_model = Model(inputs = base_model.layers[7].input, outputs = base_model.layers[-1].output) 33 | 34 | #Combining HiResNet head with ResNet50 base 35 | final_model = truncated_model(hi_res_head.output) 36 | model = Model(inputs=hi_res_head.input, outputs=final_model, name='HiResnet') 37 | 38 | # adding final layer 39 | head_model = MaxPool2D(pool_size=(4, 4))(model.output) 40 | head_model = Flatten(name='flatten')(head_model) 41 | head_model = Dense(1024, activation='relu')(head_model) 42 | head_model = Dropout(0.2)(head_model) 43 | head_model = Dense(512, activation='relu')(head_model) 44 | head_model = Dropout(0.2)(head_model) 45 | head_model = Dense(classes, activation='softmax')(head_model) 46 | 47 | # final configuration 48 | return Model(model.input, head_model) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hi-ResNet 2 | 3 | Hi-ResNet is an expansion of the original ResNet50 architecture to allow for higher resolution inputs (448x448, 896x896, or 1792x1792). It was created as an alternative to image tiling and may prove useful in analyzing large images with fine details necessary for classification. The inception came from a personal coin grading project, which relies on evaluating fine details of both the front and back of the entirety of the coin. I quickly realized that a 224x224 image completely obscures fine details needed to grade the coin. The model saw ~40% accuracy increase when switching to Hi-ResNet and I believe it could even see more if I had a larger dataset. 4 | 5 | The actual efficiency and effectiveness on datasets of various sizes has not been thoroughly tested. Other composite or tiling methods may achieve higher accuracy but will come with a longer and more complex pipeline. 

# Architecture
The architecture is a direct continuation of the original ResNet paper. Below I have included an architecture diagram of the original ResNet50 as well as diagrams of the model heads for the three Hi-ResNet models. The pattern from the original paper is continued down to the required input size and number of features. The initial 7x7 conv and max-pooling layers of the original architecture are removed and replaced by the Hi-ResNet head.

## Original Architecture
![ResNet50 architecture](./images/ResNet50.jpg)
## Hi-ResNet Heads
![Hi-ResNet heads](./images/HiResNet.png)

# Training
There are three Hi-ResNet model heads to choose from, depending on your image input size. `HiResNet` takes three arguments:
1) size (Int)
(448, 896, or 1792) Size of the input image.
2) weights (Str)
("Res50" or "None") Set to "Res50" to train with the original ResNet50 ImageNet weights in the base model and randomly initialized weights in the Hi-ResNet head. "None" will randomly initialize everything.
3) classes (Int)
Sets the number of output neurons in the final layer.

If you want to tweak anything else, such as the fully connected layers, dropout, or regularization, you will need to revise the code.

# Example code
```python
model = HiResNet(size, weights, classes) #will return the Hi-ResNet model.
29 | ``` 30 | 31 | Example: 32 | ```python 33 | !git clone 'https://github.com/johnGettings/Hi-ResNet' 34 | ``` 35 | 36 | ```python 37 | %cd Hi-ResNet 38 | from HiResNet import HiResNet 39 | from tensorflow.keras.optimizers import SGD 40 | 41 | model = HiResNet(896, "None", 29) 42 | 43 | optimizer = SGD(learning_rate=0.01, momentum=0.9) 44 | model.compile(loss="categorical_crossentropy", optimizer=optimizer, metrics=['accuracy']) 45 | ``` 46 | 47 | ```python 48 | from keras.callbacks import ModelCheckpoint 49 | INITIAL_EPOCHS = 99 50 | 51 | checkpoint = ModelCheckpoint("HiResNet.hdf5",monitor='val_loss',verbose=1,mode='min',save_best_only=True,save_weights_only=True) 52 | 53 | history = model.fit(ds_train, 54 | epochs=INITIAL_EPOCHS, 55 | callbacks=[checkpoint], 56 | validation_data=ds_val) 57 | ``` 58 | -------------------------------------------------------------------------------- /model_heads.py: -------------------------------------------------------------------------------- 1 | from keras.layers import * 2 | from tensorflow.keras.regularizers import l2 3 | from keras.models import Model 4 | 5 | def res_identity(x, filters): 6 | x_skip = x 7 | f1, f2 = filters 8 | 9 | #first block 10 | x = Conv2D(f1, kernel_size=(1, 1), strides=(1, 1), padding='valid', kernel_regularizer=l2(0.001))(x) 11 | x = BatchNormalization()(x) 12 | x = Activation(activation='relu')(x) 13 | 14 | #second block 15 | x = Conv2D(f1, kernel_size=(3, 3), strides=(1, 1), padding='same', kernel_regularizer=l2(0.001))(x) 16 | x = BatchNormalization()(x) 17 | x = Activation(activation='relu')(x) 18 | 19 | # third block 20 | x = Conv2D(f2, kernel_size=(1, 1), strides=(1, 1), padding='valid', kernel_regularizer=l2(0.001))(x) 21 | x = BatchNormalization()(x) 22 | # x = Activation(activations.relu)(x) 23 | 24 | # add the input 25 | x = Add()([x, x_skip]) 26 | x = Activation(activation='relu')(x) 27 | 28 | return x 29 | 30 | def res_conv(x, s, filters): 31 | x_skip = x 32 | f1, f2 = filters 33 | 34 | # 
first block 35 | x = Conv2D(f1, kernel_size=(1, 1), strides=(s, s), padding='valid', kernel_regularizer=l2(0.001))(x) 36 | x = BatchNormalization()(x) 37 | x = Activation(activation='relu')(x) 38 | 39 | # second block 40 | x = Conv2D(f1, kernel_size=(3, 3), strides=(1, 1), padding='same', kernel_regularizer=l2(0.001))(x) 41 | x = BatchNormalization()(x) 42 | x = Activation(activation='relu')(x) 43 | 44 | #third block 45 | x = Conv2D(f2, kernel_size=(1, 1), strides=(1, 1), padding='valid', kernel_regularizer=l2(0.001))(x) 46 | x = BatchNormalization()(x) 47 | 48 | # shortcut 49 | x_skip = Conv2D(f2, kernel_size=(1, 1), strides=(s, s), padding='valid', kernel_regularizer=l2(0.001))(x_skip) 50 | x_skip = BatchNormalization()(x_skip) 51 | 52 | # add 53 | x = Add()([x, x_skip]) 54 | x = Activation(activation='relu')(x) 55 | 56 | return x 57 | 58 | def resnet_1792_head(input_shape=(1792,1792,3)): 59 | 60 | input_im = Input(shape=input_shape) 61 | x = ZeroPadding2D(padding=(3, 3))(input_im) 62 | 63 | x = Conv2D(8, kernel_size=(7, 7), strides=(2, 2))(x) #Output size 896 64 | x = BatchNormalization()(x) 65 | x = Activation(activation='relu')(x) 66 | x = ZeroPadding2D(padding=(1, 1))(x) 67 | x = MaxPooling2D((3, 3), strides=(2, 2))(x) #Output size 448 68 | 69 | x = res_conv(x, s=1, filters=(8, 32)) 70 | x = res_identity(x, filters=(8, 32)) 71 | x = res_identity(x, filters=(8, 32)) 72 | 73 | x = res_conv(x, s=2, filters=(16, 64)) #Output Size 224 74 | x = res_identity(x, filters=(16, 64)) 75 | x = res_identity(x, filters=(16, 64)) 76 | x = res_identity(x, filters=(16, 64)) 77 | 78 | x = res_conv(x, s=2, filters=(32, 128)) #Output size 112 79 | x = res_identity(x, filters=(32, 128)) 80 | x = res_identity(x, filters=(32, 128)) 81 | x = res_identity(x, filters=(32, 128)) 82 | 83 | x = Conv2D(64, kernel_size=(1, 1), strides=(2, 2), padding='valid', kernel_regularizer=l2(0.001))(x) 84 | x = BatchNormalization()(x) 85 | x = Activation(activation='relu')(x) 86 | 87 | model = 
Model(inputs=input_im, outputs=x, name='HiResnet50') 88 | 89 | return model 90 | 91 | def resnet_896_head(input_shape=(896,896,3)): 92 | 93 | input_im = Input(shape=input_shape) 94 | x = ZeroPadding2D(padding=(3, 3))(input_im) 95 | 96 | x = Conv2D(16, kernel_size=(7, 7), strides=(2, 2))(x) #Output size 448 97 | x = BatchNormalization()(x) 98 | x = Activation(activation='relu')(x) 99 | x = ZeroPadding2D(padding=(1, 1))(x) 100 | x = MaxPooling2D((3, 3), strides=(2, 2))(x) #Output size 224 101 | 102 | x = res_conv(x, s=1, filters=(16, 64)) 103 | x = res_identity(x, filters=(16, 64)) 104 | x = res_identity(x, filters=(16, 64)) 105 | 106 | x = res_conv(x, s=2, filters=(32, 128)) #Output size 112 107 | x = res_identity(x, filters=(32, 128)) 108 | x = res_identity(x, filters=(32, 128)) 109 | x = res_identity(x, filters=(32, 128)) 110 | 111 | x = Conv2D(64, kernel_size=(1, 1), strides=(2, 2), padding='valid', kernel_regularizer=l2(0.001))(x) 112 | x = BatchNormalization()(x) 113 | x = Activation(activation='relu')(x) 114 | 115 | model = Model(inputs=input_im, outputs=x, name='HiResnet50-896') 116 | 117 | return model 118 | 119 | def resnet_448_head(input_shape=(448,448,3)): 120 | 121 | input_im = Input(shape=input_shape) 122 | x = ZeroPadding2D(padding=(3, 3))(input_im) 123 | 124 | x = Conv2D(32, kernel_size=(7, 7), strides=(2, 2))(x) #Output size 224 125 | x = BatchNormalization()(x) 126 | x = Activation(activation='relu')(x) 127 | x = ZeroPadding2D(padding=(1, 1))(x) 128 | x = MaxPooling2D((3, 3), strides=(2, 2))(x) #Output size 112 129 | 130 | x = res_conv(x, s=1, filters=(32, 128)) 131 | x = res_identity(x, filters=(32, 128)) 132 | x = res_identity(x, filters=(32, 128)) 133 | x = res_identity(x, filters=(32, 128)) 134 | 135 | x = Conv2D(64, kernel_size=(1, 1), strides=(2, 2), padding='valid', kernel_regularizer=l2(0.001))(x) #Output size 56 136 | x = BatchNormalization()(x) 137 | x = Activation(activation='relu')(x) 138 | 139 | model = Model(inputs=input_im, outputs=x, 
name='HiResnet50-448') 140 | 141 | return model 142 | --------------------------------------------------------------------------------