├── .gitignore ├── elephant.jpg ├── requirements.txt ├── load_weights.py ├── extract_weights.py ├── test_inception_resnet_v2.py ├── README.md ├── evaluate_imagenet.py └── inception_resnet_v2.py /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__ 2 | weights 3 | models 4 | -------------------------------------------------------------------------------- /elephant.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuyang-huang/keras-inception-resnet-v2/HEAD/elephant.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py==2.10.0 2 | Keras==2.2.5 3 | numpy==1.18.1 4 | Pillow==8.1.1 5 | tensorflow-gpu==1.15.4 6 | tf-slim==1.1.0 7 | tqdm==4.47.0 8 | -------------------------------------------------------------------------------- /load_weights.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import numpy as np 5 | from tqdm import tqdm 6 | from keras.models import Model 7 | from inception_resnet_v2 import InceptionResNetV2 8 | 9 | 10 | WEIGHTS_DIR = './weights' 11 | MODEL_DIR = './models' 12 | OUTPUT_WEIGHT_FILENAME = 'inception_resnet_v2_weights_tf_dim_ordering_tf_kernels.h5' 13 | OUTPUT_WEIGHT_FILENAME_NOTOP = 'inception_resnet_v2_weights_tf_dim_ordering_tf_kernels_notop.h5' 14 | 15 | 16 | print('Instantiating an empty InceptionResNetV2 model...') 17 | model = InceptionResNetV2(weights=None, input_shape=(299, 299, 3)) 18 | 19 | print('Loading weights from', WEIGHTS_DIR) 20 | for layer in tqdm(model.layers): 21 | if layer.weights: 22 | weights = [] 23 | for w in layer.weights: 24 | weight_name = os.path.basename(w.name).replace(':0', '') 25 | weight_file = layer.name + '_' + weight_name + '.npy' 26 | weight_arr = np.load(os.path.join(WEIGHTS_DIR, weight_file)) 27 | 28 | # remove the "background class" 29 | if weight_file.startswith('Logits_bias'): 30 | weight_arr = weight_arr[1:] 31 | elif weight_file.startswith('Logits_kernel'): 32 | weight_arr = weight_arr[:, 1:] 33 | 34 | weights.append(weight_arr) 35 | layer.set_weights(weights) 36 | 37 | 38 | print('Saving model weights...') 39 | if not os.path.exists(MODEL_DIR): 40 | os.makedirs(MODEL_DIR) 41 | model.save_weights(os.path.join(MODEL_DIR, OUTPUT_WEIGHT_FILENAME)) 42 | 43 | print('Saving model weights (no top)...') 44 | model_notop = Model(model.inputs, model.get_layer('Conv2d_7b_1x1_Activation').output) 45 | model_notop.save_weights(os.path.join(MODEL_DIR, OUTPUT_WEIGHT_FILENAME_NOTOP)) 46 | -------------------------------------------------------------------------------- /extract_weights.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import re 5 | from glob import glob 6 | import numpy as np 7 | import tensorflow as tf 8 | from keras.utils.data_utils import get_file 9 | 10 | 11 | # regex for renaming the tensors to their corresponding Keras counterpart 12 | re_repeat = re.compile(r'Repeat_[0-9_]*b') 13 | re_block8 = re.compile(r'Block8_[A-Za-z]') 14 | 15 | 16 | def get_filename(key): 17 | """Rename tensor name to the corresponding Keras layer weight name. 
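    For example, the checkpoint scope name
    'InceptionResnetV2/Logits/Logits/weights' maps to 'Logits_kernel.npy'
    (scope name shown for illustration; actual names are read from the
    TF-slim checkpoint).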
18 | 
19 |     # Arguments
20 |         key: tensor name in TF (determined by tf.variable_scope)
21 |     """
22 |     filename = str(key)
23 |     filename = filename.replace('/', '_')
24 |     filename = filename.replace('InceptionResnetV2_', '')
25 | 
26 |     # remove "Repeat" scope from filename
27 |     filename = re_repeat.sub('B', filename)
28 | 
29 |     if re_block8.match(filename):
30 |         # the last Block8 has a different name from the previous 9 occurrences
31 |         filename = filename.replace('Block8', 'Block8_10')
32 |     elif filename.startswith('Logits'):
33 |         # remove duplicate "Logits" scope
34 |         filename = filename.replace('Logits_', '', 1)
35 | 
36 |     # from TF to Keras naming
37 |     filename = filename.replace('_weights', '_kernel')
38 |     filename = filename.replace('_biases', '_bias')
39 | 
40 |     return filename + '.npy'
41 | 
42 | 
43 | def extract_tensors_from_checkpoint_file(filename, output_folder='weights'):
44 |     """Extract tensors from a TF checkpoint file.
45 | 
46 |     # Arguments
47 |         filename: TF checkpoint file
48 |         output_folder: where to save the output numpy array files
49 |     """
50 |     if not os.path.exists(output_folder):
51 |         os.makedirs(output_folder)
52 | 
53 |     reader = tf.train.NewCheckpointReader(filename)
54 | 
55 |     for key in reader.get_variable_to_shape_map():
56 |         # not saving the following tensors
57 |         if key == 'global_step':
58 |             continue
59 |         if 'AuxLogit' in key:
60 |             continue
61 | 
62 |         # convert tensor name into the corresponding Keras layer weight name and save
63 |         path = os.path.join(output_folder, get_filename(key))
64 |         arr = reader.get_tensor(key)
65 |         np.save(path, arr)
66 |         print("tensor_name: ", key)
67 | 
68 | 
69 | # download TF-slim checkpoint for Inception-ResNet v2 and extract
70 | CKPT_URL = 'http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz'
71 | MODEL_DIR = './models'
72 | 
73 | os.makedirs(MODEL_DIR, exist_ok=True)
74 | checkpoint_tar = get_file(
75 |     'inception_resnet_v2_2016_08_30.tar.gz',
76 |     CKPT_URL,
77 |     file_hash='9e0f18e1259acf943e30690460d96123',
78 |     hash_algorithm='md5',
79 |     extract=True,
80 |     cache_subdir='',
81 |     cache_dir=MODEL_DIR)
82 | 
83 | checkpoint_file = glob(os.path.join(MODEL_DIR, 'inception_resnet_v2_*.ckpt'))[0]
84 | extract_tensors_from_checkpoint_file(checkpoint_file)
85 | 
--------------------------------------------------------------------------------
/test_inception_resnet_v2.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | 
3 | import os
4 | import tensorflow as tf
5 | import tf_slim
6 | import numpy as np
7 | from PIL import Image
8 | from nets import inception_resnet_v2 as slim_irv2  # PYTHONPATH should contain the research/slim/ directory in the tensorflow/models repo.
9 | import inception_resnet_v2 as keras_irv2 10 | 11 | 12 | IMAGES = ['elephant.jpg'] 13 | MODEL_DIR = './models' 14 | SLIM_CKPT = os.path.join(MODEL_DIR, 'inception_resnet_v2_2016_08_30.ckpt') 15 | KERAS_CKPT = os.path.join(MODEL_DIR, 'inception_resnet_v2_weights_tf_dim_ordering_tf_kernels.h5') 16 | ATOL = 1e-5 17 | VERBOSE = True 18 | 19 | 20 | def predict_slim(sample_images, print_func=print): 21 | """ 22 | Code modified from here: [https://github.com/tensorflow/models/issues/429] 23 | """ 24 | # Setup preprocessing 25 | input_tensor = tf.placeholder(tf.float32, shape=(None, 299, 299, 3), name='input_image') 26 | scaled_input_tensor = tf.scalar_mul((1.0 / 255), input_tensor) 27 | scaled_input_tensor = tf.subtract(scaled_input_tensor, 0.5) 28 | scaled_input_tensor = tf.multiply(scaled_input_tensor, 2.0) 29 | 30 | # Setup session 31 | sess = tf.Session() 32 | arg_scope = slim_irv2.inception_resnet_v2_arg_scope() 33 | with tf_slim.arg_scope(arg_scope): 34 | _, end_points = slim_irv2.inception_resnet_v2(scaled_input_tensor, is_training=False) 35 | 36 | # Load the model 37 | print_func("Loading TF-slim checkpoint...") 38 | saver = tf.train.Saver() 39 | saver.restore(sess, SLIM_CKPT) 40 | 41 | # Make prediction 42 | predict_values = [] 43 | for image in sample_images: 44 | im = Image.open(image).resize((299, 299)) 45 | arr = np.expand_dims(np.array(im), axis=0) 46 | y_pred = sess.run([end_points['Predictions']], feed_dict={input_tensor: arr}) 47 | y_pred = y_pred[0].ravel() 48 | 49 | y_pred = y_pred[1:] / y_pred[1:].sum() # remove background class and renormalize 50 | print_func("{} class={} prob={}".format(image, np.argmax(y_pred), np.max(y_pred))) 51 | predict_values.append(y_pred) 52 | 53 | return predict_values 54 | 55 | 56 | def predict_keras(sample_images, print_func=print): 57 | # Load the model 58 | print_func("Loading Keras checkpoint...") 59 | model = keras_irv2.InceptionResNetV2(weights=None) 60 | model.load_weights(KERAS_CKPT) 61 | 62 | # Make prediction 63 | predict_values = [] 64 | for image in sample_images: 65 | im = Image.open(image).resize((299, 299)) 66 | arr = np.expand_dims(np.array(im), axis=0) 67 | y_pred = model.predict(keras_irv2.preprocess_input(arr.astype('float32'))) 68 | y_pred = y_pred.ravel() 69 | print_func("{} class={} prob={}".format(image, np.argmax(y_pred), np.max(y_pred))) 70 | predict_values.append(y_pred) 71 | 72 | return predict_values 73 | 74 | 75 | # test whether Keras implementation gives the same result as TF-slim implementation 76 | verboseprint = print if VERBOSE else lambda *a, **k: None 77 | slim_predictions = predict_slim(IMAGES, verboseprint) 78 | keras_predictions = predict_keras(IMAGES, verboseprint) 79 | 80 | for filename, y_slim, y_keras in zip(IMAGES, slim_predictions, keras_predictions): 81 | np.testing.assert_allclose(y_slim, y_keras, atol=ATOL, err_msg=filename) 82 | verboseprint('{} passed test. (tolerance={})'.format(filename, ATOL)) 83 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # keras-inception-resnet-v2 2 | The Inception-ResNet v2 model using Keras (with weight files) 3 | 4 | Tested with `tensorflow-gpu==1.15.3` and `Keras==2.2.5` under Python 3.6 5 | (although there are lots of deprecation warnings since this code was written way before TF 1.15). 
6 | 
7 | Layers and naming follow the TF-slim implementation:
8 | https://github.com/tensorflow/models/blob/master/research/slim/nets/inception_resnet_v2.py
9 | 
10 | 
11 | ## News
12 | 
13 | This implementation has been merged into the `keras.applications` module!
14 | 
15 | Install the latest version of Keras from GitHub and import it with:
16 | ```python
17 | from keras.applications.inception_resnet_v2 import InceptionResNetV2, preprocess_input
18 | ```
19 | 
20 | 
21 | ## Usage
22 | Basically the same as the `keras.applications.InceptionV3` model.
23 | ```python
24 | from keras.layers import Dense
25 | from keras.models import Model
26 | from inception_resnet_v2 import InceptionResNetV2
27 | 
28 | # ImageNet classification
29 | model = InceptionResNetV2()
30 | model.predict(...)
31 | 
32 | # Finetuning on another 100-class dataset
33 | base_model = InceptionResNetV2(include_top=False, pooling='avg')
34 | outputs = Dense(100, activation='softmax')(base_model.output)
35 | model = Model(base_model.inputs, outputs)
36 | model.compile(...)
37 | model.fit(...)
38 | ```
39 | 
40 | 
41 | ### Extract layer weights from TF checkpoint
42 | ```
43 | python extract_weights.py
44 | ```
45 | By default, the TF checkpoint file will be downloaded to the `./models` folder, and the layer weights (`.npy` files) will be saved to the `./weights` folder.
46 | 
47 | 
48 | ### Load NumPy weight files and save to a Keras HDF5 weights file
49 | ```
50 | python load_weights.py
51 | ```
52 | The following weight files:
53 | - models/inception_resnet_v2_weights_tf_dim_ordering_tf_kernels.h5
54 | - models/inception_resnet_v2_weights_tf_dim_ordering_tf_kernels_notop.h5
55 | 
56 | will be generated.
57 | 
58 | 
59 | ### Test model prediction on a single image
60 | To test whether this implementation gives the same prediction as the TF-slim implementation:
61 | ```
62 | PYTHONPATH=../tensorflow-models/research/slim python test_inception_resnet_v2.py
63 | ```
64 | `PYTHONPATH` should point to the `research/slim` folder under the https://github.com/tensorflow/models repo.
65 | 
66 | The image file `elephant.jpg` (and basically the entire idea of converting weights from TF-slim to Keras) comes from:
67 | https://github.com/kentsommer/keras-inception-resnetV2
68 | 
69 | 
70 | ### Evaluate the model on the ImageNet 2012 dataset
71 | First, follow the
72 | [instructions](https://github.com/tensorflow/models/tree/master/research/slim#an-automated-script-for-processing-imagenet-data)
73 | from TF-slim to download and process the data.
74 | 
75 | Supposing the dataset is saved to the `imagenet_2012` directory, evaluate with:
76 | ```
77 | PYTHONPATH=../tensorflow-models/research/slim python evaluate_imagenet.py ../tensorflow-models/research/slim/datasets/imagenet_2012 --verbose
78 | ```
79 | 
80 | The script should print out top-1 and top-5 accuracy on the validation set:
81 | 
82 | Implementation | Top-1 Accuracy | Top-5 Accuracy
83 | --- | --- | ---
84 | [TF-slim](https://github.com/tensorflow/models/tree/master/research/slim) | 80.4 | 95.3
85 | This repo | 80.4 | 95.3
86 | 
87 | 
88 | ## Current status
89 | - [X] Extract weights from TF-slim
90 | - [X] Convert weights to HDF5 files
91 | - [X] Test weight loading and image prediction (`elephant.jpg`)
92 | - [X] Release weight files
93 | - [X] Evaluate accuracy on ImageNet benchmark dataset
94 | 
--------------------------------------------------------------------------------
/evaluate_imagenet.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | 
4 | import argparse
5 | import math
6 | import tensorflow as tf; slim = tf.contrib.slim
7 | from keras import backend as K
8 | 
9 | # PYTHONPATH should contain the research/slim/ directory in the tensorflow/models repo.
10 | from datasets import dataset_factory
11 | from preprocessing import preprocessing_factory
12 | from inception_resnet_v2 import InceptionResNetV2
13 | 
14 | 
15 | def prepare_data(imagenet_dir, batch_size, num_threads):
16 |     # setup image loading
17 |     dataset = dataset_factory.get_dataset('imagenet', 'validation', imagenet_dir)
18 |     provider = slim.dataset_data_provider.DatasetDataProvider(dataset,
19 |                                                               shuffle=False,
20 |                                                               common_queue_capacity=batch_size * 5,
21 |                                                               common_queue_min=batch_size)
22 |     image, label = provider.get(['image', 'label'])
23 | 
24 |     # preprocess images and split into batches
25 |     preprocess_input = preprocessing_factory.get_preprocessing('inception_resnet_v2',
26 |                                                                is_training=False)
27 |     image = preprocess_input(image, 299, 299)
28 |     images, labels = tf.train.batch([image, label],
29 |                                     batch_size=batch_size,
30 |                                     num_threads=num_threads,
31 |                                     capacity=batch_size * 5)
32 | 
33 |     # Keras labels are different from TF's
34 |     labels = labels - 1  # remove the "background class"
35 |     labels = K.cast(K.expand_dims(labels, -1), K.floatx())  # Keras labels are 2D float tensors
36 |     return images, labels, dataset.num_samples
37 | 
38 | 
39 | def evaluate(imagenet_dir, batch_size=100, steps=None, num_threads=4, verbose=False):
40 |     with K.get_session().as_default():
41 |         # setup data tensors
42 |         images, labels, num_samples = prepare_data(imagenet_dir, batch_size, num_threads)
43 |         tf.train.start_queue_runners(coord=tf.train.Coordinator())
44 | 
45 |         # compile model in order to provide `metrics` and `target_tensors`
46 |         model = InceptionResNetV2(input_tensor=images)
47 |         model.compile(optimizer='adam',
48 |                       loss='sparse_categorical_crossentropy',
49 |                       metrics=['sparse_categorical_accuracy', 'sparse_top_k_categorical_accuracy'],
50 |                       target_tensors=[labels])
51 | 
52 |         # start evaluation
53 |         if steps is None:
54 |             steps = int(math.ceil(num_samples / batch_size))
55 |         _, acc1, acc5 = model.evaluate(x=None, y=None, steps=steps, verbose=int(verbose))
56 |         print()
57 |         print('Top-1 Accuracy {:.1%}'.format(acc1))
58 |         print('Top-5 Accuracy {:.1%}'.format(acc5))
59 | 
60 | 
61 | if __name__ == '__main__':
62 |     parser = argparse.ArgumentParser()
63 |     parser.add_argument("imagenet_dir", type=str, help="where ImageNet data is located (i.e. the output of `download_and_convert_imagenet.sh` from TF-slim)")
64 |     parser.add_argument("--batch_size", type=int, default=100, help="batch size when evaluating, set this number according to your GPU memory")
65 |     parser.add_argument("--steps", type=int, default=None, help="maximum number of batches to evaluate, if not specified, will go through the entire validation set by default")
66 |     parser.add_argument("--num_threads", type=int, default=4, help="number of threads to use for data loading, default 4")
67 |     parser.add_argument("--verbose", action='store_true', help="if specified, print the progress bar")
68 |     args = parser.parse_args()
69 |     evaluate(**vars(args))
70 | 
--------------------------------------------------------------------------------
/inception_resnet_v2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Inception-ResNet V2 model for Keras.
3 | 
4 | Model naming and structure follow the TF-slim implementation (which has some
5 | additional layers and a different number of filters than the original arXiv paper):
6 | https://github.com/tensorflow/models/blob/master/research/slim/nets/inception_resnet_v2.py
7 | 
8 | Pre-trained ImageNet weights are also converted from TF-slim, and can be found at:
9 | https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models
10 | 
11 | # Reference
12 | - [Inception-v4, Inception-ResNet and the Impact of
13 |    Residual Connections on Learning](https://arxiv.org/abs/1602.07261)
14 | 
15 | """
16 | from __future__ import print_function
17 | from __future__ import absolute_import
18 | 
19 | import warnings
20 | from functools import partial
21 | 
22 | from keras.models import Model
23 | from keras.layers import Activation
24 | from keras.layers import AveragePooling2D
25 | from keras.layers import BatchNormalization
26 | from keras.layers import Concatenate
27 | from keras.layers import Conv2D
28 | from keras.layers import Dense
29 | from keras.layers import Dropout
30 | from keras.layers import GlobalAveragePooling2D
31 | from keras.layers import GlobalMaxPooling2D
32 | from keras.layers import Input
33 | from keras.layers import Lambda
34 | from keras.layers import MaxPooling2D
35 | from keras.utils.data_utils import get_file
36 | from keras.engine.topology import get_source_inputs
37 | from keras_applications.imagenet_utils import _obtain_input_shape
38 | from keras import backend as K
39 | 
40 | 
41 | BASE_WEIGHT_URL = 'https://github.com/myutwo150/keras-inception-resnet-v2/releases/download/v0.1/'
42 | 
43 | 
44 | def preprocess_input(x):
45 |     """Preprocesses a numpy array encoding a batch of images.
46 | 
47 |     This function applies the "Inception" preprocessing which converts
48 |     the RGB values from [0, 255] to [-1, 1]. Note that this preprocessing
49 |     function is different from `imagenet_utils.preprocess_input()`.
50 | 
51 |     # Arguments
52 |         x: a 4D numpy array consisting of RGB values within [0, 255].
53 | 
54 |     # Returns
55 |         Preprocessed array.
56 |     """
57 |     x /= 255.
58 |     x -= 0.5
59 |     x *= 2.
60 |     return x
61 | 
62 | 
63 | def conv2d_bn(x,
64 |               filters,
65 |               kernel_size,
66 |               strides=1,
67 |               padding='same',
68 |               activation='relu',
69 |               use_bias=False,
70 |               name=None):
71 |     """Utility function to apply conv + BN.
72 | 
73 |     # Arguments
74 |         x: input tensor.
75 |         filters: filters in `Conv2D`.
76 |         kernel_size: kernel size as in `Conv2D`.
77 |         padding: padding mode in `Conv2D`.
78 |         activation: activation in `Conv2D`.
79 |         strides: strides in `Conv2D`.
80 |         name: name of the ops; will become `name + '_Activation'`
81 |             for the activation and `name + '_BatchNorm'` for the
82 |             batch norm layer.
83 | 
84 |     # Returns
85 |         Output tensor after applying `Conv2D` and `BatchNormalization`.
86 |     """
87 |     x = Conv2D(filters,
88 |                kernel_size,
89 |                strides=strides,
90 |                padding=padding,
91 |                use_bias=use_bias,
92 |                name=name)(x)
93 |     if not use_bias:
94 |         bn_axis = 1 if K.image_data_format() == 'channels_first' else 3
95 |         bn_name = _generate_layer_name('BatchNorm', prefix=name)
96 |         x = BatchNormalization(axis=bn_axis, scale=False, name=bn_name)(x)
97 |     if activation is not None:
98 |         ac_name = _generate_layer_name('Activation', prefix=name)
99 |         x = Activation(activation, name=ac_name)(x)
100 |     return x
101 | 
102 | 
103 | def _generate_layer_name(name, branch_idx=None, prefix=None):
104 |     """Utility function for generating layer names.
105 | 
106 |     If `prefix` is `None`, returns `None` to use default automatic layer names.
107 |     Otherwise, the returned layer name is:
108 |     - PREFIX_NAME if `branch_idx` is not given.
109 |     - PREFIX_Branch_0_NAME if e.g. `branch_idx=0` is given.
110 | 
111 |     # Arguments
112 |         name: base layer name string, e.g. `'Concatenate'` or `'Conv2d_1x1'`.
113 |         branch_idx: an `int`. If given, will add e.g. `'Branch_0'`
114 |             after `prefix` and in front of `name` in order to identify
115 |             layers in the same block but in different branches.
116 |         prefix: string prefix that will be added in front of `name` to make
117 |             all layer names unique (e.g. which block this layer belongs to).
118 | 
119 |     # Returns
120 |         The layer name.
121 |     """
122 |     if prefix is None:
123 |         return None
124 |     if branch_idx is None:
125 |         return '_'.join((prefix, name))
126 |     return '_'.join((prefix, 'Branch', str(branch_idx), name))
127 | 
128 | 
129 | def _inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):
130 |     """Adds an Inception-ResNet block.
131 | 
132 |     This function builds 3 types of Inception-ResNet blocks mentioned
133 |     in the paper, controlled by the `block_type` argument (which is the
134 |     block name used in the official TF-slim implementation):
135 |     - Inception-ResNet-A: `block_type='Block35'`
136 |     - Inception-ResNet-B: `block_type='Block17'`
137 |     - Inception-ResNet-C: `block_type='Block8'`
138 | 
139 |     # Arguments
140 |         x: input tensor.
141 |         scale: scaling factor to scale the residuals before adding
142 |             them to the shortcut branch.
143 |         block_type: `'Block35'`, `'Block17'` or `'Block8'`, determines
144 |             the network structure in the residual branch.
145 |         block_idx: used for generating layer names.
146 |         activation: name of the activation function to use at the end
147 |             of the block (see [activations](../activations.md)).
148 |             When `activation=None`, no activation is applied
149 |             (i.e., "linear" activation: `a(x) = x`).
150 | 
151 |     # Returns
152 |         Output tensor for the block.
153 | 
154 |     # Raises
155 |         ValueError: if `block_type` is not one of `'Block35'`,
156 |             `'Block17'` or `'Block8'`.
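
    # Example
        Stacking one Inception-ResNet-A block, as done repeatedly in
        `InceptionResNetV2` below (shown for illustration):

        >>> x = _inception_resnet_block(x, scale=0.17,
        ...                             block_type='Block35', block_idx=1)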
157 | """ 158 | channel_axis = 1 if K.image_data_format() == 'channels_first' else 3 159 | if block_idx is None: 160 | prefix = None 161 | else: 162 | prefix = '_'.join((block_type, str(block_idx))) 163 | name_fmt = partial(_generate_layer_name, prefix=prefix) 164 | 165 | if block_type == 'Block35': 166 | branch_0 = conv2d_bn(x, 32, 1, name=name_fmt('Conv2d_1x1', 0)) 167 | branch_1 = conv2d_bn(x, 32, 1, name=name_fmt('Conv2d_0a_1x1', 1)) 168 | branch_1 = conv2d_bn(branch_1, 32, 3, name=name_fmt('Conv2d_0b_3x3', 1)) 169 | branch_2 = conv2d_bn(x, 32, 1, name=name_fmt('Conv2d_0a_1x1', 2)) 170 | branch_2 = conv2d_bn(branch_2, 48, 3, name=name_fmt('Conv2d_0b_3x3', 2)) 171 | branch_2 = conv2d_bn(branch_2, 64, 3, name=name_fmt('Conv2d_0c_3x3', 2)) 172 | branches = [branch_0, branch_1, branch_2] 173 | elif block_type == 'Block17': 174 | branch_0 = conv2d_bn(x, 192, 1, name=name_fmt('Conv2d_1x1', 0)) 175 | branch_1 = conv2d_bn(x, 128, 1, name=name_fmt('Conv2d_0a_1x1', 1)) 176 | branch_1 = conv2d_bn(branch_1, 160, [1, 7], name=name_fmt('Conv2d_0b_1x7', 1)) 177 | branch_1 = conv2d_bn(branch_1, 192, [7, 1], name=name_fmt('Conv2d_0c_7x1', 1)) 178 | branches = [branch_0, branch_1] 179 | elif block_type == 'Block8': 180 | branch_0 = conv2d_bn(x, 192, 1, name=name_fmt('Conv2d_1x1', 0)) 181 | branch_1 = conv2d_bn(x, 192, 1, name=name_fmt('Conv2d_0a_1x1', 1)) 182 | branch_1 = conv2d_bn(branch_1, 224, [1, 3], name=name_fmt('Conv2d_0b_1x3', 1)) 183 | branch_1 = conv2d_bn(branch_1, 256, [3, 1], name=name_fmt('Conv2d_0c_3x1', 1)) 184 | branches = [branch_0, branch_1] 185 | else: 186 | raise ValueError('Unknown Inception-ResNet block type. ' 187 | 'Expects "Block35", "Block17" or "Block8", ' 188 | 'but got: ' + str(block_type)) 189 | 190 | mixed = Concatenate(axis=channel_axis, name=name_fmt('Concatenate'))(branches) 191 | up = conv2d_bn(mixed, 192 | K.int_shape(x)[channel_axis], 193 | 1, 194 | activation=None, 195 | use_bias=True, 196 | name=name_fmt('Conv2d_1x1')) 197 | x = Lambda(lambda inputs, scale: inputs[0] + inputs[1] * scale, 198 | output_shape=K.int_shape(x)[1:], 199 | arguments={'scale': scale}, 200 | name=name_fmt('ScaleSum'))([x, up]) 201 | if activation is not None: 202 | x = Activation(activation, name=name_fmt('Activation'))(x) 203 | return x 204 | 205 | 206 | def InceptionResNetV2(include_top=True, 207 | weights='imagenet', 208 | input_tensor=None, 209 | input_shape=None, 210 | pooling=None, 211 | classes=1000, 212 | dropout_keep_prob=0.8): 213 | """Instantiates the Inception-ResNet v2 architecture. 214 | 215 | Optionally loads weights pre-trained on ImageNet. 216 | Note that when using TensorFlow, for best performance you should 217 | set `"image_data_format": "channels_last"` in your Keras config 218 | at `~/.keras/keras.json`. 219 | 220 | The model and the weights are compatible with both TensorFlow and Theano. 221 | The data format convention used by the model is the one specified in your 222 | Keras config file. 223 | 224 | Note that the default input image size for this model is 299x299, instead 225 | of 224x224 as in the VGG16 and ResNet models. Also, the input preprocessing 226 | function is different (i.e., do not use `imagenet_utils.preprocess_input()` 227 | with this model. Use `preprocess_input()` defined in this module instead). 228 | 229 | # Arguments 230 | include_top: whether to include the fully-connected 231 | layer at the top of the network. 232 | weights: one of `None` (random initialization) 233 | or `'imagenet'` (pre-training on ImageNet). 
234 |         input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
235 |             to use as image input for the model.
236 |         input_shape: optional shape tuple, only to be specified
237 |             if `include_top` is `False` (otherwise the input shape
238 |             has to be `(299, 299, 3)` (with `channels_last` data format)
239 |             or `(3, 299, 299)` (with `channels_first` data format)).
240 |             It should have exactly 3 input channels,
241 |             and width and height should be no smaller than 139.
242 |             E.g. `(150, 150, 3)` would be one valid value.
243 |         pooling: Optional pooling mode for feature extraction
244 |             when `include_top` is `False`.
245 |             - `None` means that the output of the model will be
246 |                 the 4D tensor output of the last convolutional layer.
247 |             - `'avg'` means that global average pooling
248 |                 will be applied to the output of the
249 |                 last convolutional layer, and thus
250 |                 the output of the model will be a 2D tensor.
251 |             - `'max'` means that global max pooling will be applied.
252 |         classes: optional number of classes to classify images
253 |             into, only to be specified if `include_top` is `True`, and
254 |             if no `weights` argument is specified.
255 |         dropout_keep_prob: dropout keep rate after pooling and before the
256 |             classification layer, only to be specified if `include_top` is `True`.
257 | 
258 |     # Returns
259 |         A Keras `Model` instance.
260 | 
261 |     # Raises
262 |         ValueError: in case of invalid argument for `weights`,
263 |             or invalid input shape.
264 |     """
265 |     if weights not in {'imagenet', None}:
266 |         raise ValueError('The `weights` argument should be either '
267 |                          '`None` (random initialization) or `imagenet` '
268 |                          '(pre-training on ImageNet).')
269 | 
270 |     if weights == 'imagenet' and include_top and classes != 1000:
271 |         raise ValueError('If using `weights` as imagenet with `include_top`'
272 |                          ' as true, `classes` should be 1000')
273 | 
274 |     # Determine proper input shape
275 |     input_shape = _obtain_input_shape(
276 |         input_shape,
277 |         default_size=299,
278 |         min_size=139,
279 |         data_format=K.image_data_format(),
280 |         require_flatten=False,
281 |         weights=weights)
282 | 
283 |     if input_tensor is None:
284 |         img_input = Input(shape=input_shape)
285 |     else:
286 |         if not K.is_keras_tensor(input_tensor):
287 |             img_input = Input(tensor=input_tensor, shape=input_shape)
288 |         else:
289 |             img_input = input_tensor
290 | 
291 |     # Stem block: 35 x 35 x 192
292 |     x = conv2d_bn(img_input, 32, 3, strides=2, padding='valid', name='Conv2d_1a_3x3')
293 |     x = conv2d_bn(x, 32, 3, padding='valid', name='Conv2d_2a_3x3')
294 |     x = conv2d_bn(x, 64, 3, name='Conv2d_2b_3x3')
295 |     x = MaxPooling2D(3, strides=2, name='MaxPool_3a_3x3')(x)
296 |     x = conv2d_bn(x, 80, 1, padding='valid', name='Conv2d_3b_1x1')
297 |     x = conv2d_bn(x, 192, 3, padding='valid', name='Conv2d_4a_3x3')
298 |     x = MaxPooling2D(3, strides=2, name='MaxPool_5a_3x3')(x)
299 | 
300 |     # Mixed 5b (Inception-A block): 35 x 35 x 320
301 |     channel_axis = 1 if K.image_data_format() == 'channels_first' else 3
302 |     name_fmt = partial(_generate_layer_name, prefix='Mixed_5b')
303 |     branch_0 = conv2d_bn(x, 96, 1, name=name_fmt('Conv2d_1x1', 0))
304 |     branch_1 = conv2d_bn(x, 48, 1, name=name_fmt('Conv2d_0a_1x1', 1))
305 |     branch_1 = conv2d_bn(branch_1, 64, 5, name=name_fmt('Conv2d_0b_5x5', 1))
306 |     branch_2 = conv2d_bn(x, 64, 1, name=name_fmt('Conv2d_0a_1x1', 2))
307 |     branch_2 = conv2d_bn(branch_2, 96, 3, name=name_fmt('Conv2d_0b_3x3', 2))
308 |     branch_2 = conv2d_bn(branch_2, 96, 3, name=name_fmt('Conv2d_0c_3x3', 2))
309 |     branch_pool = 
AveragePooling2D(3, 310 | strides=1, 311 | padding='same', 312 | name=name_fmt('AvgPool_0a_3x3', 3))(x) 313 | branch_pool = conv2d_bn(branch_pool, 64, 1, name=name_fmt('Conv2d_0b_1x1', 3)) 314 | branches = [branch_0, branch_1, branch_2, branch_pool] 315 | x = Concatenate(axis=channel_axis, name='Mixed_5b')(branches) 316 | 317 | # 10x Block35 (Inception-ResNet-A block): 35 x 35 x 320 318 | for block_idx in range(1, 11): 319 | x = _inception_resnet_block(x, 320 | scale=0.17, 321 | block_type='Block35', 322 | block_idx=block_idx) 323 | 324 | # Mixed 6a (Reduction-A block): 17 x 17 x 1088 325 | name_fmt = partial(_generate_layer_name, prefix='Mixed_6a') 326 | branch_0 = conv2d_bn(x, 327 | 384, 328 | 3, 329 | strides=2, 330 | padding='valid', 331 | name=name_fmt('Conv2d_1a_3x3', 0)) 332 | branch_1 = conv2d_bn(x, 256, 1, name=name_fmt('Conv2d_0a_1x1', 1)) 333 | branch_1 = conv2d_bn(branch_1, 256, 3, name=name_fmt('Conv2d_0b_3x3', 1)) 334 | branch_1 = conv2d_bn(branch_1, 335 | 384, 336 | 3, 337 | strides=2, 338 | padding='valid', 339 | name=name_fmt('Conv2d_1a_3x3', 1)) 340 | branch_pool = MaxPooling2D(3, 341 | strides=2, 342 | padding='valid', 343 | name=name_fmt('MaxPool_1a_3x3', 2))(x) 344 | branches = [branch_0, branch_1, branch_pool] 345 | x = Concatenate(axis=channel_axis, name='Mixed_6a')(branches) 346 | 347 | # 20x Block17 (Inception-ResNet-B block): 17 x 17 x 1088 348 | for block_idx in range(1, 21): 349 | x = _inception_resnet_block(x, 350 | scale=0.1, 351 | block_type='Block17', 352 | block_idx=block_idx) 353 | 354 | # Mixed 7a (Reduction-B block): 8 x 8 x 2080 355 | name_fmt = partial(_generate_layer_name, prefix='Mixed_7a') 356 | branch_0 = conv2d_bn(x, 256, 1, name=name_fmt('Conv2d_0a_1x1', 0)) 357 | branch_0 = conv2d_bn(branch_0, 358 | 384, 359 | 3, 360 | strides=2, 361 | padding='valid', 362 | name=name_fmt('Conv2d_1a_3x3', 0)) 363 | branch_1 = conv2d_bn(x, 256, 1, name=name_fmt('Conv2d_0a_1x1', 1)) 364 | branch_1 = conv2d_bn(branch_1, 365 | 288, 366 | 3, 367 | strides=2, 368 | padding='valid', 369 | name=name_fmt('Conv2d_1a_3x3', 1)) 370 | branch_2 = conv2d_bn(x, 256, 1, name=name_fmt('Conv2d_0a_1x1', 2)) 371 | branch_2 = conv2d_bn(branch_2, 288, 3, name=name_fmt('Conv2d_0b_3x3', 2)) 372 | branch_2 = conv2d_bn(branch_2, 373 | 320, 374 | 3, 375 | strides=2, 376 | padding='valid', 377 | name=name_fmt('Conv2d_1a_3x3', 2)) 378 | branch_pool = MaxPooling2D(3, 379 | strides=2, 380 | padding='valid', 381 | name=name_fmt('MaxPool_1a_3x3', 3))(x) 382 | branches = [branch_0, branch_1, branch_2, branch_pool] 383 | x = Concatenate(axis=channel_axis, name='Mixed_7a')(branches) 384 | 385 | # 10x Block8 (Inception-ResNet-C block): 8 x 8 x 2080 386 | for block_idx in range(1, 10): 387 | x = _inception_resnet_block(x, 388 | scale=0.2, 389 | block_type='Block8', 390 | block_idx=block_idx) 391 | x = _inception_resnet_block(x, 392 | scale=1., 393 | activation=None, 394 | block_type='Block8', 395 | block_idx=10) 396 | 397 | # Final convolution block 398 | x = conv2d_bn(x, 1536, 1, name='Conv2d_7b_1x1') 399 | 400 | if include_top: 401 | # Classification block 402 | x = GlobalAveragePooling2D(name='AvgPool')(x) 403 | x = Dropout(1.0 - dropout_keep_prob, name='Dropout')(x) 404 | x = Dense(classes, name='Logits')(x) 405 | x = Activation('softmax', name='Predictions')(x) 406 | else: 407 | if pooling == 'avg': 408 | x = GlobalAveragePooling2D(name='AvgPool')(x) 409 | elif pooling == 'max': 410 | x = GlobalMaxPooling2D(name='MaxPool')(x) 411 | 412 | # Ensure that the model takes into account 413 | # any 
potential predecessors of `input_tensor` 414 | if input_tensor is not None: 415 | inputs = get_source_inputs(input_tensor) 416 | else: 417 | inputs = img_input 418 | 419 | # Create model 420 | model = Model(inputs, x, name='inception_resnet_v2') 421 | 422 | # Load weights 423 | if weights == 'imagenet': 424 | if K.image_data_format() == 'channels_first': 425 | if K.backend() == 'tensorflow': 426 | warnings.warn('You are using the TensorFlow backend, yet you ' 427 | 'are using the Theano ' 428 | 'image data format convention ' 429 | '(`image_data_format="channels_first"`). ' 430 | 'For best performance, set ' 431 | '`image_data_format="channels_last"` in ' 432 | 'your Keras config ' 433 | 'at ~/.keras/keras.json.') 434 | if include_top: 435 | weights_filename = 'inception_resnet_v2_weights_tf_dim_ordering_tf_kernels.h5' 436 | weights_path = get_file(weights_filename, 437 | BASE_WEIGHT_URL + weights_filename, 438 | cache_subdir='models', 439 | md5_hash='e693bd0210a403b3192acc6073ad2e96') 440 | else: 441 | weights_filename = 'inception_resnet_v2_weights_tf_dim_ordering_tf_kernels_notop.h5' 442 | weights_path = get_file(weights_filename, 443 | BASE_WEIGHT_URL + weights_filename, 444 | cache_subdir='models', 445 | md5_hash='d19885ff4a710c122648d3b5c3b684e4') 446 | model.load_weights(weights_path) 447 | 448 | return model 449 | --------------------------------------------------------------------------------
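
For reference, a minimal end-to-end sketch of classifying a single image with the converted weights, mirroring `test_inception_resnet_v2.py` (assumes the `.h5` file has already been generated by `load_weights.py` and that `elephant.jpg` is present):

```python
import numpy as np
from PIL import Image
from inception_resnet_v2 import InceptionResNetV2, preprocess_input

# build the architecture without downloading weights, then load the converted file
model = InceptionResNetV2(weights=None)
model.load_weights('models/inception_resnet_v2_weights_tf_dim_ordering_tf_kernels.h5')

# Inception-style preprocessing: resize to 299x299 and scale RGB to [-1, 1]
im = Image.open('elephant.jpg').resize((299, 299))
x = preprocess_input(np.expand_dims(np.array(im), axis=0).astype('float32'))

# predicted index follows the 1000-class Keras convention (background class removed)
probs = model.predict(x).ravel()
print('class={} prob={:.4f}'.format(np.argmax(probs), np.max(probs)))
```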