├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
├── cifar10.py
├── images
│   ├── architecture.png
│   ├── convergence.png
│   └── residual_block.png
├── resnet.py
└── tests
    └── test_resnet.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # test-related
2 | .coverage
3 | .cache
4 | 
5 | # dev envs
6 | .idea/
7 | *.iml
8 | 
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | sudo: required
2 | dist: trusty
3 | language: python
4 | matrix:
5 |   include:
6 |     - python: 2.7
7 |       env: KERAS_BACKEND=theano
8 |     - python: 2.7
9 |       env: KERAS_BACKEND=tensorflow
10 |     - python: 3.4
11 |       env: KERAS_BACKEND=theano
12 |     - python: 3.4
13 |       env: KERAS_BACKEND=tensorflow
14 | 
15 | install:
16 |   # code below is taken from http://conda.pydata.org/docs/travis.html
17 |   # We do this conditionally because it saves us some downloading if the
18 |   # version is the same.
19 |   - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then
20 |       wget https://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh;
21 |     else
22 |       wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
23 |     fi
24 |   - bash miniconda.sh -b -p $HOME/miniconda
25 |   - export PATH="$HOME/miniconda/bin:$PATH"
26 |   - hash -r
27 |   - conda config --set always_yes yes --set changeps1 no
28 |   - conda update -q conda
29 |   # Useful for debugging any issues with conda
30 |   - conda info -a
31 | 
32 |   - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION numpy scipy matplotlib pandas pytest h5py
33 |   - source activate test-environment
34 |   - pip install git+git://github.com/Theano/Theano.git
35 |   - pip install keras
36 | 
37 |   # install PIL for preprocessing tests
38 |   - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then
39 |       conda install pil;
40 |     elif [[ "$TRAVIS_PYTHON_VERSION" == "3.4" ]]; then
41 |       conda install Pillow;
42 |     fi
43 | 
44 |   # install TensorFlow
45 |   - pip install tensorflow
46 | 
47 | script:
48 |   # run keras backend init to initialize backend config
49 |   - python -c "import keras.backend"
50 |   # set up keras backend
51 |   - sed -i -e 's/"backend":[[:space:]]*"[^"]*/"backend":\ "'$KERAS_BACKEND'/g' ~/.keras/keras.json;
52 |   - echo -e "Running tests with the following config:\n$(cat ~/.keras/keras.json)"
53 |   - PYTHONPATH=../$PWD:$PYTHONPATH py.test tests/;
54 | after_success:
55 |   - coveralls
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | COPYRIGHT
2 | 
3 | All contributions by Raghavendra Kotikalapudi:
4 | Copyright (c) 2016, Raghavendra Kotikalapudi.
5 | All rights reserved.
6 | 
7 | All other contributions:
8 | Copyright (c) 2016, the respective contributors.
9 | All rights reserved.
10 | 
11 | Each contributor holds copyright over their respective contributions.
12 | The project versioning (Git) records all such contribution source information.
13 | 
14 | LICENSE
15 | 
16 | The MIT License (MIT)
17 | 
18 | Permission is hereby granted, free of charge, to any person obtaining a copy
19 | of this software and associated documentation files (the "Software"), to deal
20 | in the Software without restriction, including without limitation the rights
21 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
22 | copies of the Software, and to permit persons to whom the Software is
23 | furnished to do so, subject to the following conditions:
24 | 
25 | The above copyright notice and this permission notice shall be included in all
26 | copies or substantial portions of the Software.
27 | 
28 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
29 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
30 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
31 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
32 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
33 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
34 | SOFTWARE.
35 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # keras-resnet
2 | [![Build Status](https://travis-ci.org/raghakot/keras-resnet.svg?branch=master)](https://travis-ci.org/raghakot/keras-resnet)
3 | [![license](https://img.shields.io/github/license/mashape/apistatus.svg?maxAge=2592000)](https://github.com/raghakot/keras-resnet/blob/master/LICENSE)
4 | 
5 | A residual network implementation using the Keras functional API. It works with
6 | both the Theano and TensorFlow backends and with both 'th' and 'tf' image dim orderings.
7 | 
8 | ### The original articles
9 | * [Deep Residual Learning for Image Recognition](http://arxiv.org/abs/1512.03385) (the 2015 ImageNet competition winner)
10 | * [Identity Mappings in Deep Residual Networks](http://arxiv.org/abs/1603.05027)
11 | 
12 | ### Residual blocks
13 | The residual blocks are based on the improved scheme proposed in [Identity Mappings in Deep Residual Networks](http://arxiv.org/abs/1603.05027), as shown in figure (b):
14 | 
15 | ![Residual Block Scheme](images/residual_block.png?raw=true "Residual Block Scheme")
16 | 
17 | Both bottleneck and basic residual blocks are supported. To switch between them, simply pass the desired block function [here](https://github.com/raghakot/keras-resnet/blob/master/resnet.py#L109).
18 | 
19 | ### Code Walkthrough
20 | The architecture is based on the 50-layer sample (snippet from the paper):
21 | 
22 | ![Architecture Reference](images/architecture.png?raw=true "Architecture Reference")
23 | 
24 | There are two key aspects to note here:
25 | 
26 | 1. conv2_1 has a stride of (1, 1), while the remaining conv stages use a stride of (2, 2) at the beginning of the block. This is expressed in the following [lines](https://github.com/raghakot/keras-resnet/blob/master/resnet.py#L63-L65).
27 | 2. At the end of the first skip connection of a block, the merge layer sees a mismatch in num_filters, width and height. This is addressed in [`_shortcut`](https://github.com/raghakot/keras-resnet/blob/master/resnet.py#L41) by using a `1x1 conv` with an appropriate stride.
28 | In the remaining cases, the input is merged with the residual output directly as an identity shortcut (see the sketch below).
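
In condensed form, the shortcut logic amounts to the following simplified sketch (written here for the 'tf', channels-last ordering; `shortcut_sketch` is only illustrative — the actual implementation is `_shortcut` in `resnet.py`, which derives the axes from the configured image dim ordering):

```python
from keras import backend as K
from keras.layers.convolutional import Conv2D
from keras.layers.merge import add


def shortcut_sketch(input, residual):
    # Shapes are (batch, rows, cols, channels) under the 'tf' ordering.
    input_shape, residual_shape = K.int_shape(input), K.int_shape(residual)
    # A stride > 1 means the residual branch downsampled the feature map.
    stride = int(round(input_shape[1] / residual_shape[1]))
    if stride > 1 or input_shape[3] != residual_shape[3]:
        # Project the input with a 1x1 conv so shapes match before the add.
        input = Conv2D(filters=residual_shape[3], kernel_size=(1, 1),
                       strides=(stride, stride), padding="valid")(input)
    return add([input, residual])
```

Because `_shortcut` computes the row, column and channel axes from the active dim ordering, the same code path works for both 'th' and 'tf'.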
29 | 
30 | ### ResnetBuilder factory
31 | - Use the `ResnetBuilder` [build](https://github.com/raghakot/keras-resnet/blob/master/resnet.py#L135-L153) methods to build the standard ResNet architectures with your own input shape. They automatically calculate the paddings and the final average-pooling size for you.
32 | - Use the generic [build](https://github.com/raghakot/keras-resnet/blob/master/resnet.py#L99) method to set up your own architecture.
33 | 
34 | ### Cifar10 Example
35 | 
36 | Includes a CIFAR-10 training example (`cifar10.py`). It achieves ~86% accuracy using the ResNet-18 model.
37 | 
38 | ![cifar10_convergence](images/convergence.png?raw=true "Convergence on cifar10")
39 | 
40 | Note that ResNet-18 as implemented isn't really appropriate for CIFAR-10: the repeated downsampling (strided convolutions plus the
41 | initial max-pooling) shrinks the 32x32 input so far that the last two residual stages operate on 2x2 and 1x1 feature maps, where the
42 | 3x3 convolutions effectively degenerate into 1x1 ones. This is worse for the deeper variants. A smaller, modified ResNet-like architecture achieves ~92% accuracy (see [gist](https://gist.github.com/JefferyRPrice/c1ecc3d67068c8d9b3120475baba1d7e)).
--------------------------------------------------------------------------------
/cifar10.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from keras example cifar10_cnn.py
3 | Train ResNet-18 on the CIFAR10 small images dataset.
4 | 
5 | GPU run command with Theano backend (with TensorFlow, the GPU is automatically used):
6 | THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python cifar10.py
7 | """
8 | from __future__ import print_function
9 | from keras.datasets import cifar10
10 | from keras.preprocessing.image import ImageDataGenerator
11 | from keras.utils import np_utils
12 | from keras.callbacks import ReduceLROnPlateau, CSVLogger, EarlyStopping
13 | 
14 | import numpy as np
15 | import resnet
16 | 
17 | 
18 | lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1), cooldown=0, patience=5, min_lr=0.5e-6)
19 | early_stopper = EarlyStopping(min_delta=0.001, patience=10)
20 | csv_logger = CSVLogger('resnet18_cifar10.csv')
21 | 
22 | batch_size = 32
23 | nb_classes = 10
24 | nb_epoch = 200
25 | data_augmentation = True
26 | 
27 | # input image dimensions
28 | img_rows, img_cols = 32, 32
29 | # The CIFAR10 images are RGB.
30 | img_channels = 3
31 | 
32 | # The data, shuffled and split between train and test sets:
33 | (X_train, y_train), (X_test, y_test) = cifar10.load_data()
34 | 
35 | # Convert class vectors to binary class matrices.
36 | Y_train = np_utils.to_categorical(y_train, nb_classes)
37 | Y_test = np_utils.to_categorical(y_test, nb_classes)
38 | 
39 | X_train = X_train.astype('float32')
40 | X_test = X_test.astype('float32')
41 | 
42 | # subtract mean and normalize
43 | mean_image = np.mean(X_train, axis=0)
44 | X_train -= mean_image
45 | X_test -= mean_image
46 | X_train /= 128.
47 | X_test /= 128.
48 | 
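# ResnetBuilder expects the input shape as (nb_channels, nb_rows, nb_cols)
# regardless of backend; resnet.ResnetBuilder.build permutes the dimensions
# internally when the 'tf' (channels-last) image dim ordering is active.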
49 | model = resnet.ResnetBuilder.build_resnet_18((img_channels, img_rows, img_cols), nb_classes)
50 | model.compile(loss='categorical_crossentropy',
51 |               optimizer='adam',
52 |               metrics=['accuracy'])
53 | 
54 | if not data_augmentation:
55 |     print('Not using data augmentation.')
56 |     model.fit(X_train, Y_train,
57 |               batch_size=batch_size,
58 |               epochs=nb_epoch,
59 |               validation_data=(X_test, Y_test),
60 |               shuffle=True,
61 |               callbacks=[lr_reducer, early_stopper, csv_logger])
62 | else:
63 |     print('Using real-time data augmentation.')
64 |     # This will do preprocessing and realtime data augmentation:
65 |     datagen = ImageDataGenerator(
66 |         featurewise_center=False,  # set input mean to 0 over the dataset
67 |         samplewise_center=False,  # set each sample mean to 0
68 |         featurewise_std_normalization=False,  # divide inputs by std of the dataset
69 |         samplewise_std_normalization=False,  # divide each input by its std
70 |         zca_whitening=False,  # apply ZCA whitening
71 |         rotation_range=0,  # randomly rotate images in the range (degrees, 0 to 180)
72 |         width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
73 |         height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
74 |         horizontal_flip=True,  # randomly flip images
75 |         vertical_flip=False)  # randomly flip images
76 | 
77 |     # Compute quantities required for featurewise normalization
78 |     # (std, mean, and principal components if ZCA whitening is applied).
79 |     datagen.fit(X_train)
80 | 
81 |     # Fit the model on the batches generated by datagen.flow().
82 |     model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size),
83 |                         steps_per_epoch=X_train.shape[0] // batch_size,
84 |                         validation_data=(X_test, Y_test),
85 |                         epochs=nb_epoch, verbose=1, max_q_size=100,
86 |                         callbacks=[lr_reducer, early_stopper, csv_logger])
87 | 
--------------------------------------------------------------------------------
/images/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raghakot/keras-resnet/5e9bcca7e467f7baf3459d809ef16bb75e53f115/images/architecture.png
--------------------------------------------------------------------------------
/images/convergence.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raghakot/keras-resnet/5e9bcca7e467f7baf3459d809ef16bb75e53f115/images/convergence.png
--------------------------------------------------------------------------------
/images/residual_block.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raghakot/keras-resnet/5e9bcca7e467f7baf3459d809ef16bb75e53f115/images/residual_block.png
--------------------------------------------------------------------------------
/resnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | 
3 | import six
4 | from keras.models import Model
5 | from keras.layers import (
6 |     Input,
7 |     Activation,
8 |     Dense,
9 |     Flatten
10 | )
11 | from keras.layers.convolutional import (
12 |     Conv2D,
13 |     MaxPooling2D,
14 |     AveragePooling2D
15 | )
16 | from keras.layers.merge import add
17 | from keras.layers.normalization import BatchNormalization
18 | from keras.regularizers import l2
19 | from keras import backend as K
20 | 
21 | 
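# ROW_AXIS, COL_AXIS and CHANNEL_AXIS are module-level globals set by
# _handle_dim_ordering() (called from ResnetBuilder.build) according to the
# active image dim ordering ('th' puts channels first, 'tf' puts them last).
# The helper blocks below rely on these globals, so they are only usable
# after one of the builder methods has been invoked.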
22 | def _bn_relu(input):
23 |     """Helper to build a BN -> relu block
24 |     """
25 |     norm = BatchNormalization(axis=CHANNEL_AXIS)(input)
26 |     return Activation("relu")(norm)
27 | 
28 | 
29 | def _conv_bn_relu(**conv_params):
30 |     """Helper to build a conv -> BN -> relu block
31 |     """
32 |     filters = conv_params["filters"]
33 |     kernel_size = conv_params["kernel_size"]
34 |     strides = conv_params.setdefault("strides", (1, 1))
35 |     kernel_initializer = conv_params.setdefault("kernel_initializer", "he_normal")
36 |     padding = conv_params.setdefault("padding", "same")
37 |     kernel_regularizer = conv_params.setdefault("kernel_regularizer", l2(1.e-4))
38 | 
39 |     def f(input):
40 |         conv = Conv2D(filters=filters, kernel_size=kernel_size,
41 |                       strides=strides, padding=padding,
42 |                       kernel_initializer=kernel_initializer,
43 |                       kernel_regularizer=kernel_regularizer)(input)
44 |         return _bn_relu(conv)
45 | 
46 |     return f
47 | 
48 | 
49 | def _bn_relu_conv(**conv_params):
50 |     """Helper to build a BN -> relu -> conv block.
51 |     This is an improved scheme proposed in http://arxiv.org/pdf/1603.05027v2.pdf
52 |     """
53 |     filters = conv_params["filters"]
54 |     kernel_size = conv_params["kernel_size"]
55 |     strides = conv_params.setdefault("strides", (1, 1))
56 |     kernel_initializer = conv_params.setdefault("kernel_initializer", "he_normal")
57 |     padding = conv_params.setdefault("padding", "same")
58 |     kernel_regularizer = conv_params.setdefault("kernel_regularizer", l2(1.e-4))
59 | 
60 |     def f(input):
61 |         activation = _bn_relu(input)
62 |         return Conv2D(filters=filters, kernel_size=kernel_size,
63 |                       strides=strides, padding=padding,
64 |                       kernel_initializer=kernel_initializer,
65 |                       kernel_regularizer=kernel_regularizer)(activation)
66 | 
67 |     return f
68 | 
69 | 
70 | def _shortcut(input, residual):
71 |     """Adds a shortcut between input and residual block and merges them with "sum"
72 |     """
73 |     # Expand channels of shortcut to match residual.
74 |     # Stride appropriately to match residual (width, height)
75 |     # Should be int if network architecture is correctly configured.
76 |     input_shape = K.int_shape(input)
77 |     residual_shape = K.int_shape(residual)
78 |     stride_width = int(round(input_shape[ROW_AXIS] / residual_shape[ROW_AXIS]))
79 |     stride_height = int(round(input_shape[COL_AXIS] / residual_shape[COL_AXIS]))
80 |     equal_channels = input_shape[CHANNEL_AXIS] == residual_shape[CHANNEL_AXIS]
81 | 
82 |     shortcut = input
83 |     # 1 X 1 conv if shape is different. Else identity.
84 |     if stride_width > 1 or stride_height > 1 or not equal_channels:
85 |         shortcut = Conv2D(filters=residual_shape[CHANNEL_AXIS],
86 |                           kernel_size=(1, 1),
87 |                           strides=(stride_width, stride_height),
88 |                           padding="valid",
89 |                           kernel_initializer="he_normal",
90 |                           kernel_regularizer=l2(0.0001))(input)
91 | 
92 |     return add([shortcut, residual])
93 | 
94 | 
95 | def _residual_block(block_function, filters, repetitions, is_first_layer=False):
96 |     """Builds a residual stage with repeating basic or bottleneck blocks.
97 |     """
98 |     def f(input):
99 |         for i in range(repetitions):
100 |             init_strides = (1, 1)
101 |             if i == 0 and not is_first_layer:
102 |                 init_strides = (2, 2)
103 |             input = block_function(filters=filters, init_strides=init_strides,
104 |                                    is_first_block_of_first_layer=(is_first_layer and i == 0))(input)
105 |         return input
106 | 
107 |     return f
108 | 
109 | 
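# The two block functions below are what `repetitions` gets applied to. For
# example, ResnetBuilder.build_resnet_18 stacks basic_block with
# repetitions=[2, 2, 2, 2]: four stages of two blocks each, with 64, 128, 256
# and 512 filters, every stage after the first halving the spatial resolution.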
110 | def basic_block(filters, init_strides=(1, 1), is_first_block_of_first_layer=False):
111 |     """Basic 3 X 3 convolution blocks for use on resnets with layers <= 34.
112 |     Follows the improved scheme proposed in http://arxiv.org/pdf/1603.05027v2.pdf
113 |     """
114 |     def f(input):
115 | 
116 |         if is_first_block_of_first_layer:
117 |             # don't repeat bn->relu since we just did bn->relu->maxpool
118 |             conv1 = Conv2D(filters=filters, kernel_size=(3, 3),
119 |                            strides=init_strides,
120 |                            padding="same",
121 |                            kernel_initializer="he_normal",
122 |                            kernel_regularizer=l2(1e-4))(input)
123 |         else:
124 |             conv1 = _bn_relu_conv(filters=filters, kernel_size=(3, 3),
125 |                                   strides=init_strides)(input)
126 | 
127 |         residual = _bn_relu_conv(filters=filters, kernel_size=(3, 3))(conv1)
128 |         return _shortcut(input, residual)
129 | 
130 |     return f
131 | 
132 | 
133 | def bottleneck(filters, init_strides=(1, 1), is_first_block_of_first_layer=False):
134 |     """Bottleneck architecture for > 34 layer resnet.
135 |     Follows the improved scheme proposed in http://arxiv.org/pdf/1603.05027v2.pdf
136 | 
137 |     Returns:
138 |         A final conv layer of filters * 4
139 |     """
140 |     def f(input):
141 | 
142 |         if is_first_block_of_first_layer:
143 |             # don't repeat bn->relu since we just did bn->relu->maxpool
144 |             conv_1_1 = Conv2D(filters=filters, kernel_size=(1, 1),
145 |                               strides=init_strides,
146 |                               padding="same",
147 |                               kernel_initializer="he_normal",
148 |                               kernel_regularizer=l2(1e-4))(input)
149 |         else:
150 |             conv_1_1 = _bn_relu_conv(filters=filters, kernel_size=(1, 1),
151 |                                      strides=init_strides)(input)
152 | 
153 |         conv_3_3 = _bn_relu_conv(filters=filters, kernel_size=(3, 3))(conv_1_1)
154 |         residual = _bn_relu_conv(filters=filters * 4, kernel_size=(1, 1))(conv_3_3)
155 |         return _shortcut(input, residual)
156 | 
157 |     return f
158 | 
159 | 
160 | def _handle_dim_ordering():
161 |     global ROW_AXIS
162 |     global COL_AXIS
163 |     global CHANNEL_AXIS
164 |     if K.image_dim_ordering() == 'tf':
165 |         ROW_AXIS = 1
166 |         COL_AXIS = 2
167 |         CHANNEL_AXIS = 3
168 |     else:
169 |         CHANNEL_AXIS = 1
170 |         ROW_AXIS = 2
171 |         COL_AXIS = 3
172 | 
173 | 
174 | def _get_block(identifier):
175 |     if isinstance(identifier, six.string_types):
176 |         res = globals().get(identifier)
177 |         if not res:
178 |             raise ValueError('Invalid {}'.format(identifier))
179 |         return res
180 |     return identifier
181 | 
182 | 
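# Typical usage (see cifar10.py and tests/test_resnet.py):
#     model = ResnetBuilder.build_resnet_18((3, 32, 32), 10)
#     model.compile(loss="categorical_crossentropy", optimizer="sgd")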
183 | class ResnetBuilder(object):
184 |     @staticmethod
185 |     def build(input_shape, num_outputs, block_fn, repetitions):
186 |         """Builds a custom ResNet-like architecture.
187 | 
188 |         Args:
189 |             input_shape: The input shape in the form (nb_channels, nb_rows, nb_cols)
190 |             num_outputs: The number of outputs at final softmax layer
191 |             block_fn: The block function to use. This is either `basic_block` or `bottleneck`.
192 |                 The original paper used basic_block for layers < 50
193 |             repetitions: Number of repetitions of various block units.
194 |                 At each block unit, the number of filters is doubled and the input size is halved
195 | 
196 |         Returns:
197 |             The keras `Model`.
198 |         """
199 |         _handle_dim_ordering()
200 |         if len(input_shape) != 3:
201 |             raise Exception("Input shape should be a tuple (nb_channels, nb_rows, nb_cols)")
202 | 
203 |         # Permute dimension order if necessary
204 |         if K.image_dim_ordering() == 'tf':
205 |             input_shape = (input_shape[1], input_shape[2], input_shape[0])
206 | 
207 |         # Load function from str if needed.
208 |         block_fn = _get_block(block_fn)
209 | 
210 |         input = Input(shape=input_shape)
211 |         conv1 = _conv_bn_relu(filters=64, kernel_size=(7, 7), strides=(2, 2))(input)
212 |         pool1 = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding="same")(conv1)
213 | 
214 |         block = pool1
215 |         filters = 64
216 |         for i, r in enumerate(repetitions):
217 |             block = _residual_block(block_fn, filters=filters, repetitions=r, is_first_layer=(i == 0))(block)
218 |             filters *= 2
219 | 
220 |         # Last activation
221 |         block = _bn_relu(block)
222 | 
223 |         # Classifier block
224 |         block_shape = K.int_shape(block)
225 |         pool2 = AveragePooling2D(pool_size=(block_shape[ROW_AXIS], block_shape[COL_AXIS]),
226 |                                  strides=(1, 1))(block)
227 |         flatten1 = Flatten()(pool2)
228 |         dense = Dense(units=num_outputs, kernel_initializer="he_normal",
229 |                       activation="softmax")(flatten1)
230 | 
231 |         model = Model(inputs=input, outputs=dense)
232 |         return model
233 | 
234 |     @staticmethod
235 |     def build_resnet_18(input_shape, num_outputs):
236 |         return ResnetBuilder.build(input_shape, num_outputs, basic_block, [2, 2, 2, 2])
237 | 
238 |     @staticmethod
239 |     def build_resnet_34(input_shape, num_outputs):
240 |         return ResnetBuilder.build(input_shape, num_outputs, basic_block, [3, 4, 6, 3])
241 | 
242 |     @staticmethod
243 |     def build_resnet_50(input_shape, num_outputs):
244 |         return ResnetBuilder.build(input_shape, num_outputs, bottleneck, [3, 4, 6, 3])
245 | 
246 |     @staticmethod
247 |     def build_resnet_101(input_shape, num_outputs):
248 |         return ResnetBuilder.build(input_shape, num_outputs, bottleneck, [3, 4, 23, 3])
249 | 
250 |     @staticmethod
251 |     def build_resnet_152(input_shape, num_outputs):
252 |         return ResnetBuilder.build(input_shape, num_outputs, bottleneck, [3, 8, 36, 3])
253 | 
--------------------------------------------------------------------------------
/tests/test_resnet.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from keras import backend as K
3 | from resnet import ResnetBuilder
4 | 
5 | 
6 | DIM_ORDERING = {'th', 'tf'}
7 | 
8 | 
9 | def _test_model_compile(model):
10 |     for ordering in DIM_ORDERING:
11 |         K.set_image_dim_ordering(ordering)
12 |         model.compile(loss="categorical_crossentropy", optimizer="sgd")
13 |         assert True, "Failed to compile with '{}' dim ordering".format(ordering)
14 | 
15 | 
16 | def test_resnet18():
17 |     model = ResnetBuilder.build_resnet_18((3, 224, 224), 100)
18 |     _test_model_compile(model)
19 | 
20 | 
21 | def test_resnet34():
22 |     model = ResnetBuilder.build_resnet_34((3, 224, 224), 100)
23 |     _test_model_compile(model)
24 | 
25 | 
26 | def test_resnet50():
27 |     model = ResnetBuilder.build_resnet_50((3, 224, 224), 100)
28 |     _test_model_compile(model)
29 | 
30 | 
31 | def test_resnet101():
32 |     model = ResnetBuilder.build_resnet_101((3, 224, 224), 100)
33 |     _test_model_compile(model)
34 | 
35 | 
36 | def test_resnet152():
37 |     model = ResnetBuilder.build_resnet_152((3, 224, 224), 100)
38 |     _test_model_compile(model)
39 | 
40 | 
41 | def test_custom1():
42 |     """ https://github.com/raghakot/keras-resnet/issues/34
43 |     """
44 |     model = ResnetBuilder.build_resnet_152((3, 300, 300), 100)
45 |     _test_model_compile(model)
46 | 
47 | 
48 | def test_custom2():
49 |     """ https://github.com/raghakot/keras-resnet/issues/34
50 |     """
51 |     model = ResnetBuilder.build_resnet_152((3, 512, 512), 2)
52 |     _test_model_compile(model)
53 | 
54 | 
55 | if __name__ == '__main__':
56 |     pytest.main([__file__])
57 | 
--------------------------------------------------------------------------------