├── Architectures
│   ├── Inception-v4.png
│   ├── Inception ResNet-v1.png
│   └── Inception ResNet-v2.png
├── README.md
├── google-cloud.sh
├── .gitignore
├── experiment.sh
├── data.py
├── train_inception_v4.py
├── train_inception_resnet_v2.py
├── inception_v4.py
└── inception_resnet_v2.py
--------------------------------------------------------------------------------
/Architectures/Inception-v4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaspermarstal/BrainNet/HEAD/Architectures/Inception-v4.png
--------------------------------------------------------------------------------
/Architectures/Inception ResNet-v1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaspermarstal/BrainNet/HEAD/Architectures/Inception ResNet-v1.png
--------------------------------------------------------------------------------
/Architectures/Inception ResNet-v2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaspermarstal/BrainNet/HEAD/Architectures/Inception ResNet-v2.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# BrainNet
Adaptation of @titu1994's Inception v4 and Inception ResNet v2 architectures to MRI images of the human brain. Both architectures are described in the paper "Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning".

## Please note
This repo serves as an example of how to run experiments on Google Cloud, not of how to segment brain images. More efficient architectures for this kind of segmentation exist today.

## Experiment
This repository contains code for training the networks to segment white matter and gray matter in MRI scans from the Open Access Series of Imaging Studies (OASIS) archive.

To start the experiment, clone the repository and run

```
$ ./experiment.sh
```

Data is downloaded, extracted and preprocessed automatically.

## Google Cloud
Provision a Google Cloud CPU or GPU instance with `google-cloud.sh` using either of the following commands:

```
$ ./google-cloud.sh --create-cpu-instance
$ ./google-cloud.sh --create-gpu-instance
```

SSH into the instance once it is up and running, clone this repository, and invoke `experiment.sh` from there.
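For example, for the GPU instance created above (instance name and zone are the ones set in `google-cloud.sh`; note that `experiment.sh` expects the repository checked out at `~/BrainNet/BrainNet`):

```
$ gcloud compute ssh brainnet-gpu --zone europe-west1-b
$ mkdir -p ~/BrainNet && cd ~/BrainNet
$ git clone https://github.com/kaspermarstal/BrainNet.git
$ cd BrainNet && ./experiment.sh
```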
--------------------------------------------------------------------------------
/google-cloud.sh:
--------------------------------------------------------------------------------
#!/bin/bash

if [[ $1 == "--create-cpu-instance" ]]; then
  gcloud beta compute instances create brainnet-cpu \
    --zone europe-west1-b \
    --machine-type n1-highcpu-16 \
    --image-project ubuntu-os-cloud \
    --image-family ubuntu-1604-lts \
    --boot-disk-device-name=brainnet \
    --boot-disk-type=pd-standard \
    --boot-disk-size=64GB \
    --maintenance-policy TERMINATE --restart-on-failure
fi

if [[ $1 == "--create-gpu-instance" ]]; then
  gcloud beta compute instances create brainnet-gpu \
    --zone europe-west1-b \
    --machine-type n1-highmem-2 \
    --image-project ubuntu-os-cloud \
    --image-family ubuntu-1604-lts \
    --boot-disk-device-name=brainnet-gpu \
    --boot-disk-type=pd-standard \
    --boot-disk-size=64GB \
    --accelerator type=nvidia-tesla-k80,count=1 \
    --maintenance-policy TERMINATE --restart-on-failure \
    --metadata startup-script='#!/bin/bash
      echo "Checking for CUDA and installing."
      if ! dpkg-query -W cuda; then
        curl -O http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/cuda-repo-ubuntu1604_8.0.61-1_amd64.deb
        dpkg -i ./cuda-repo-ubuntu1604_8.0.61-1_amd64.deb
        apt-get update
        apt-get install cuda -y
      fi'
fi
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

# PyCharm project settings
.idea/

# TensorFlow
TensorBoard/
--------------------------------------------------------------------------------
/experiment.sh:
--------------------------------------------------------------------------------
#!/bin/bash

if [[ $1 == "--download-data" ]]; then
  if [[ -d ~/BrainNet/downloads ]]; then
    rm -fr ~/BrainNet/downloads
  fi
  mkdir -p ~/BrainNet/downloads
  cd ~/BrainNet/downloads
  curl -OL ftp://ftp.nrg.wustl.edu/data/oasis_cross-sectional_disc{1,2,3,4,5,6,7,8,9,10,11}.tar.gz
fi

if [[ $1 == "--extract-data" ]]; then
  if [[ -d ~/BrainNet/data ]]; then
    rm -fr ~/BrainNet/data
  fi
  mkdir -p ~/BrainNet/data
  for filename in ~/BrainNet/downloads/*.tar.gz
  do
    tar zxf "$filename" -C ~/BrainNet/data
  done
fi

if [[ $1 == "--install-venv" ]]; then
  if [[ -d ~/BrainNet/venv ]]; then
    rm -fr ~/BrainNet/venv
  fi
  sudo apt-get install -y virtualenv gcc python-dev
  virtualenv ~/BrainNet/venv
  source ~/BrainNet/venv/bin/activate
  pip install keras SimpleITK numpy sklearn scikit-image tensorflow
fi

if [[ $1 == "--install-tensorflow-gpu" ]]; then
  source ~/BrainNet/venv/bin/activate
  pip install --upgrade tensorflow-gpu
fi

if [[ $1 == "--run-inception-v4" ]]; then
  source ~/BrainNet/venv/bin/activate
  python ~/BrainNet/BrainNet/train_inception_v4.py --data-dir=$HOME/BrainNet/data
fi

if [[ $1 == "--run-inception-resnet-v2" ]]; then
  source ~/BrainNet/venv/bin/activate
  python ~/BrainNet/BrainNet/train_inception_resnet_v2.py --data-dir=$HOME/BrainNet/data
fi

if [ -z "$1" ]; then
  ./experiment.sh --download-data
  ./experiment.sh --extract-data
  ./experiment.sh --install-venv
  ./experiment.sh --install-tensorflow-gpu || true
  ./experiment.sh --run-inception-v4
fi
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
# Util
import time

# Images
import numpy as np
from numpy.random import random, randint
from SimpleITK import Extract, GetArrayFromImage, ReadImage
from skimage.transform import resize

# Keras
from keras.utils.np_utils import to_categorical

# The number of classes is tied to the dataset
nb_classes = 4
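# The OASIS FSL_SEG label maps are assumed to follow FSL FAST's convention:
# 0 = background, 1 = cerebrospinal fluid, 2 = gray matter, 3 = white matter.
# This mapping is an assumption; it is not stated by the dataset files themselves.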
class Timer(object):
    def __init__(self, name=None):
        self.name = name

    def __enter__(self):
        self.tstart = time.time()
        return self

    def __exit__(self, type, value, traceback):
        if self.name:
            print('[%s] Elapsed: %s.' % (self.name, time.time() - self.tstart))
        else:
            print('Elapsed: %s.' % (time.time() - self.tstart))

def generator(images, labels, input_shape, patch_size=32, batch_size=32):
    while True:
        X = np.empty((batch_size, input_shape[0], input_shape[1], input_shape[2]), dtype=float)
        y = np.empty(batch_size, dtype=int)

        for i in range(0, batch_size):
            subject_id = randint(0, len(images))
            X[i], y[i] = generate_one(images[subject_id], labels[subject_id], input_shape, patch_size)

        y = to_categorical(y, nb_classes=nb_classes)

        yield X, y

def generate_one(image, label, input_shape, patch_size):
    p = random()

    # TODO: Obtain patches at arbitrary angles
    if p > 0.66:
        patch, label = axial_patch_generator(image, label, patch_size)
    elif p > 0.33:
        patch, label = coronal_patch_generator(image, label, patch_size)
    else:
        patch, label = sagittal_patch_generator(image, label, patch_size)

    patch = GetArrayFromImage(patch)

    if random() > 0.5:
        patch = np.fliplr(patch)

    if random() > 0.5:
        patch = np.flipud(patch)

    patch = resize(patch, input_shape)

    return patch, label

def axial_patch_generator(image, label, patch_size):
    image_size = image.GetSize()
    assert(image_size == label.GetSize())

    point = (randint(0, image_size[0] - patch_size), randint(0, image_size[1]), randint(0, image_size[2] - patch_size))
    patch = Extract(image, (patch_size, 0, patch_size), point)

    return patch, label.GetPixel(point)

def coronal_patch_generator(image, label, patch_size):
    image_size = image.GetSize()
    assert(image_size == label.GetSize())

    point = (randint(0, image_size[0] - patch_size), randint(0, image_size[1] - patch_size), randint(0, image_size[2]))
    patch = Extract(image, (patch_size, patch_size, 0), point)

    return patch, label.GetPixel(point)

def sagittal_patch_generator(image, label, patch_size):
    image_size = image.GetSize()
    assert(image_size == label.GetSize())

    point = (randint(0, image_size[0]), randint(0, image_size[1] - patch_size), randint(0, image_size[2] - patch_size))
    patch = Extract(image, (0, patch_size, patch_size), point)

    return patch, label.GetPixel(point)
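
# Minimal usage sketch (the filenames are hypothetical; any pair of
# equally-sized SimpleITK image/label volumes will do):
#
#   image = ReadImage('OAS1_0001_MR1_mpr_n4_anon_111_t88_gfc.hdr')
#   label = ReadImage('OAS1_0001_MR1_mpr_n4_anon_111_t88_masked_gfc_fseg.hdr')
#   batches = generator([image], [label], input_shape=(299, 299, 1))
#   X, y = next(batches)  # X: (32, 299, 299, 1) floats, y: (32, 4) one-hot labels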
--------------------------------------------------------------------------------
/train_inception_v4.py:
--------------------------------------------------------------------------------
# Util
import os
import sys
import argparse
from glob import glob

# Data
from data import generator, nb_classes, ReadImage, Timer
from sklearn.model_selection import train_test_split

# Keras
from keras.callbacks import EarlyStopping, TensorBoard
from keras.optimizers import RMSprop

# Model
from inception_v4 import create_inception_v4, input_shape

# TensorFlow
from tensorflow.python.platform import app

def main(argv):

    print('Finding data ...'),
    with Timer():
        image_filenames = glob(os.path.join(FLAGS.data_dir, 'disc*/OAS1_*_MR1/PROCESSED/MPRAGE/T88_111/OAS1_*_MR1_mpr_n4_anon_111_t88_gfc.hdr'))
        label_filenames = glob(os.path.join(FLAGS.data_dir, 'disc*/OAS1_*_MR1/FSL_SEG/OAS1_*_MR1_mpr_n4_anon_111_t88_masked_gfc_fseg.hdr'))
        assert(len(image_filenames) == len(label_filenames))
        print('Found %i images.' % len(image_filenames))

    print('Loading images ...'),
    with Timer():
        images = [ReadImage(image_filename) for image_filename in image_filenames]
        labels = [ReadImage(label_filename) for label_filename in label_filenames]
        images_train, images_test, labels_train, labels_test = train_test_split(images, labels, train_size=0.66)

    tensor_board = TensorBoard(log_dir='./TensorBoard')
    early_stopping = EarlyStopping(monitor='acc', patience=2, verbose=1)

    model = create_inception_v4(nb_classes=nb_classes, load_weights=False)
    model.compile(optimizer=RMSprop(lr=0.045, rho=0.94, epsilon=1., decay=0.9), loss='categorical_crossentropy', metrics=['acc'])
    model.fit_generator(generator(images_train, labels_train, input_shape, FLAGS.patch_size, FLAGS.batch_size),
                        samples_per_epoch=FLAGS.samples_per_epoch, nb_epoch=FLAGS.nb_epochs, callbacks=[tensor_board, early_stopping],
                        verbose=1)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-d',
        '--data-dir',
        dest='data_dir',
        help='Path to data directory.',
        required=True,
    )
    parser.add_argument(
        '-p',
        '--patch-size',
        default=32,
        type=int,
        dest='patch_size',
        help='Size of the p-by-p patch in millimetres (mm).',
    )
    parser.add_argument(
        '-b',
        '--batch-size',
        default=32,
        type=int,
        dest='batch_size',
        help='Batch size.',
    )
    parser.add_argument(
        '-e',
        '--nb-epochs',
        default=8,
        type=int,
        dest='nb_epochs',
        help='Number of epochs.',
    )
    parser.add_argument(
        '-s',
        '--samples-per-epoch',
        default=1024,
        type=int,
        dest='samples_per_epoch',
        help='Number of samples per epoch.',
    )

    FLAGS, unparsed = parser.parse_known_args()
    app.run(main=main, argv=[sys.argv[0]] + unparsed)
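
# Example invocation, assuming the directory layout produced by experiment.sh:
#
#   python train_inception_v4.py --data-dir=$HOME/BrainNet/data --batch-size 32 --nb-epochs 8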
--------------------------------------------------------------------------------
/train_inception_resnet_v2.py:
--------------------------------------------------------------------------------
# Util
import os
import sys
import argparse
from glob import glob

# Data
from data import generator, nb_classes, ReadImage, Timer
from sklearn.model_selection import train_test_split

# Keras
from keras.callbacks import EarlyStopping, TensorBoard
from keras.optimizers import RMSprop

# Model
from inception_resnet_v2 import create_inception_resnet_v2, input_shape

# TensorFlow
from tensorflow.python.platform import app

def main(argv):

    print('Finding data ...'),
    with Timer():
        image_filenames = glob(os.path.join(FLAGS.data_dir, 'disc*/OAS1_*_MR1/PROCESSED/MPRAGE/T88_111/OAS1_*_MR1_mpr_n4_anon_111_t88_gfc.hdr'))
        label_filenames = glob(os.path.join(FLAGS.data_dir, 'disc*/OAS1_*_MR1/FSL_SEG/OAS1_*_MR1_mpr_n4_anon_111_t88_masked_gfc_fseg.hdr'))
        assert(len(image_filenames) == len(label_filenames))
        print('Found %i images.' % len(image_filenames))

    print('Loading images ...'),
    with Timer():
        images = [ReadImage(image_filename) for image_filename in image_filenames]
        labels = [ReadImage(label_filename) for label_filename in label_filenames]
        images_train, images_test, labels_train, labels_test = train_test_split(images, labels, train_size=0.66)

    tensor_board = TensorBoard(log_dir='./TensorBoard')
    early_stopping = EarlyStopping(monitor='acc', patience=2, verbose=1)

    model = create_inception_resnet_v2(nb_classes=nb_classes)
    model.compile(optimizer=RMSprop(lr=0.045, rho=0.94, epsilon=1., decay=0.9), loss='categorical_crossentropy', metrics=['acc'])
    # The network has a main and an auxiliary classifier, so the same labels are fed to both outputs
    batches = ((X, [y, y]) for X, y in generator(images_train, labels_train, input_shape, FLAGS.patch_size, FLAGS.batch_size))
    model.fit_generator(batches,
                        samples_per_epoch=FLAGS.samples_per_epoch, nb_epoch=FLAGS.nb_epochs, callbacks=[tensor_board, early_stopping],
                        verbose=1)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-d',
        '--data-dir',
        dest='data_dir',
        help='Path to data directory.',
        required=True,
    )
    parser.add_argument(
        '-p',
        '--patch-size',
        default=32,
        type=int,
        dest='patch_size',
        help='Size of the p-by-p patch in millimetres (mm).',
    )
    parser.add_argument(
        '-b',
        '--batch-size',
        default=32,
        type=int,
        dest='batch_size',
        help='Batch size.',
    )
    parser.add_argument(
        '-e',
        '--nb-epochs',
        default=8,
        type=int,
        dest='nb_epochs',
        help='Number of epochs.',
    )
    parser.add_argument(
        '-s',
        '--samples-per-epoch',
        default=1024,
        type=int,
        dest='samples_per_epoch',
        help='Number of samples per epoch.',
    )

    FLAGS, unparsed = parser.parse_known_args()
    app.run(main=main, argv=[sys.argv[0]] + unparsed)
--------------------------------------------------------------------------------
/inception_v4.py:
--------------------------------------------------------------------------------
from keras.layers import Input, merge, Dropout, Dense, Flatten, Activation
from keras.layers.convolutional import MaxPooling2D, Convolution2D, AveragePooling2D
from keras.layers.normalization import BatchNormalization
from keras.models import Model

from keras import backend as K
from keras.utils.data_utils import get_file

"""
Implementation of Inception Network v4 [Inception Network v4 Paper](http://arxiv.org/pdf/1602.07261v1.pdf) in Keras.
"""

# The input shape is tied to the network
if K.image_dim_ordering() == 'th':
    input_shape = (1, 299, 299)
else:
    input_shape = (299, 299, 1)

def conv_block(x, nb_filter, nb_row, nb_col, border_mode='same', subsample=(1, 1), bias=False):
    if K.image_dim_ordering() == "th":
        channel_axis = 1
    else:
        channel_axis = -1

    x = Convolution2D(nb_filter, nb_row, nb_col, subsample=subsample, border_mode=border_mode, bias=bias)(x)
    x = BatchNormalization(axis=channel_axis)(x)
    x = Activation('relu')(x)
    return x


def inception_stem(input):
    if K.image_dim_ordering() == "th":
        channel_axis = 1
    else:
        channel_axis = -1

    # Input shape is 299 x 299 x 1 (tf) or 1 x 299 x 299 (th)
    x = conv_block(input, 32, 3, 3, subsample=(2, 2), border_mode='valid')
    x = conv_block(x, 32, 3, 3, border_mode='valid')
    x = conv_block(x, 64, 3, 3)

    x1 = MaxPooling2D((3, 3), strides=(2, 2), border_mode='valid')(x)
    x2 = conv_block(x, 96, 3, 3, subsample=(2, 2), border_mode='valid')

    x = merge([x1, x2], mode='concat', concat_axis=channel_axis)

    x1 = conv_block(x, 64, 1, 1)
    x1 = conv_block(x1, 96, 3, 3, border_mode='valid')

    x2 = conv_block(x, 64, 1, 1)
    x2 = conv_block(x2, 64, 1, 7)
    x2 = conv_block(x2, 64, 7, 1)
    x2 = conv_block(x2, 96, 3, 3, border_mode='valid')

    x = merge([x1, x2], mode='concat', concat_axis=channel_axis)

    x1 = conv_block(x, 192, 3, 3, subsample=(2, 2), border_mode='valid')
    x2 = MaxPooling2D((3, 3), strides=(2, 2), border_mode='valid')(x)

    x = merge([x1, x2], mode='concat', concat_axis=channel_axis)
    return x


def inception_A(input):
    if K.image_dim_ordering() == "th":
        channel_axis = 1
    else:
        channel_axis = -1

    a1 = conv_block(input, 96, 1, 1)

    a2 = conv_block(input, 64, 1, 1)
    a2 = conv_block(a2, 96, 3, 3)

    a3 = conv_block(input, 64, 1, 1)
    a3 = conv_block(a3, 96, 3, 3)
    a3 = conv_block(a3, 96, 3, 3)

    a4 = AveragePooling2D((3, 3), strides=(1, 1), border_mode='same')(input)
    a4 = conv_block(a4, 96, 1, 1)

    m = merge([a1, a2, a3, a4], mode='concat', concat_axis=channel_axis)
    return m


def inception_B(input):
    if K.image_dim_ordering() == "th":
        channel_axis = 1
    else:
        channel_axis = -1

    b1 = conv_block(input, 384, 1, 1)

    b2 = conv_block(input, 192, 1, 1)
    b2 = conv_block(b2, 224, 1, 7)
    b2 = conv_block(b2, 256, 7, 1)

    b3 = conv_block(input, 192, 1, 1)
    b3 = conv_block(b3, 192, 7, 1)
    b3 = conv_block(b3, 224, 1, 7)
    b3 = conv_block(b3, 224, 7, 1)
    b3 = conv_block(b3, 256, 1, 7)

    b4 = AveragePooling2D((3, 3), strides=(1, 1), border_mode='same')(input)
    b4 = conv_block(b4, 128, 1, 1)

    m = merge([b1, b2, b3, b4], mode='concat', concat_axis=channel_axis)
    return m
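
# The Inception-A, -B and -C blocks operate on the 35 x 35, 17 x 17 and 8 x 8
# grids respectively, with the reduction blocks in between halving the grid
# size (Szegedy et al., 2016).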

def inception_C(input):
    if K.image_dim_ordering() == "th":
        channel_axis = 1
    else:
        channel_axis = -1

    c1 = conv_block(input, 256, 1, 1)

    c2 = conv_block(input, 384, 1, 1)
    c2_1 = conv_block(c2, 256, 1, 3)
    c2_2 = conv_block(c2, 256, 3, 1)
    c2 = merge([c2_1, c2_2], mode='concat', concat_axis=channel_axis)

    c3 = conv_block(input, 384, 1, 1)
    c3 = conv_block(c3, 448, 3, 1)
    c3 = conv_block(c3, 512, 1, 3)
    c3_1 = conv_block(c3, 256, 1, 3)
    c3_2 = conv_block(c3, 256, 3, 1)
    c3 = merge([c3_1, c3_2], mode='concat', concat_axis=channel_axis)

    c4 = AveragePooling2D((3, 3), strides=(1, 1), border_mode='same')(input)
    c4 = conv_block(c4, 256, 1, 1)

    m = merge([c1, c2, c3, c4], mode='concat', concat_axis=channel_axis)
    return m


def reduction_A(input):
    if K.image_dim_ordering() == "th":
        channel_axis = 1
    else:
        channel_axis = -1

    r1 = conv_block(input, 384, 3, 3, subsample=(2, 2), border_mode='valid')

    r2 = conv_block(input, 192, 1, 1)
    r2 = conv_block(r2, 224, 3, 3)
    r2 = conv_block(r2, 256, 3, 3, subsample=(2, 2), border_mode='valid')

    r3 = MaxPooling2D((3, 3), strides=(2, 2), border_mode='valid')(input)

    m = merge([r1, r2, r3], mode='concat', concat_axis=channel_axis)
    return m


def reduction_B(input):
    if K.image_dim_ordering() == "th":
        channel_axis = 1
    else:
        channel_axis = -1

    r1 = conv_block(input, 192, 1, 1)
    r1 = conv_block(r1, 192, 3, 3, subsample=(2, 2), border_mode='valid')

    r2 = conv_block(input, 256, 1, 1)
    r2 = conv_block(r2, 256, 1, 7)
    r2 = conv_block(r2, 320, 7, 1)
    r2 = conv_block(r2, 320, 3, 3, subsample=(2, 2), border_mode='valid')

    r3 = MaxPooling2D((3, 3), strides=(2, 2), border_mode='valid')(input)

    m = merge([r1, r2, r3], mode='concat', concat_axis=channel_axis)
    return m


def create_inception_v4(nb_classes=1001, load_weights=True):
    '''
    Creates an Inception v4 network

    :param nb_classes: number of output classes
    :param load_weights: unused placeholder; no pretrained weights ship with this repository
    :return: Keras Model with 1 input and 1 output
    '''

    # Input shape is 299 x 299 x 1 (tf) or 1 x 299 x 299 (th)
    init = Input(input_shape)

    x = inception_stem(init)

    # 4 x Inception A
    for i in range(4):
        x = inception_A(x)

    # Reduction A
    x = reduction_A(x)

    # 7 x Inception B
    for i in range(7):
        x = inception_B(x)

    # Reduction B
    x = reduction_B(x)

    # 3 x Inception C
    for i in range(3):
        x = inception_C(x)

    # Average Pooling
    x = AveragePooling2D((8, 8))(x)

    # Dropout (the paper specifies a keep probability of 0.8, i.e. a drop fraction of 0.2)
    x = Dropout(0.2)(x)
    x = Flatten()(x)

    # Output
    out = Dense(output_dim=nb_classes, activation='softmax')(x)

    model = Model(init, out, name='Inception-v4')
    return model


if __name__ == "__main__":
    # from keras.utils.visualize_util import plot

    inception_v4 = create_inception_v4(load_weights=True)
    # inception_v4.summary()

    # plot(inception_v4, to_file="Inception-v4.png", show_shapes=True)
--------------------------------------------------------------------------------
/inception_resnet_v2.py:
--------------------------------------------------------------------------------
from keras.layers import Input, merge, Dropout, Dense, Lambda, Flatten, Activation
from keras.layers.convolutional import MaxPooling2D, Convolution2D, AveragePooling2D
from keras.layers.normalization import BatchNormalization
from keras.models import Model

from keras import backend as K

import warnings
warnings.filterwarnings('ignore')
"""
Implementation of Inception-ResNet v2 [Inception Network v4 Paper](http://arxiv.org/pdf/1602.07261v1.pdf) in Keras.

Some additional details:
[1] Each of the A, B and C blocks has a 'scale_residual' parameter.
    Scaling the residuals follows the paper. It is turned ON by default here;
    pass 'scale=False' to the create_inception_resnet_v2() method to disable it.

[2] There were minor inconsistencies with filter size in both B and C blocks.

    In the B blocks: 'ir_conv' nb of filters is given as 1154, however input size is 1152.
    This causes inconsistencies in the merge-add mode, therefore the 'ir_conv' filter size
    is reduced to 1152 to match input size.

    In the C blocks: 'ir_conv' nb of filters is given as 2048, however input size is 2144.
    This causes inconsistencies in the merge-add mode, therefore the 'ir_conv' filter size
    is increased to 2144 to match input size.

    Currently trying to find a proper solution with original nb of filters.

[3] In the stem function, the last Convolution2D layer has 384 filters instead of the original 256.
    This is to correctly match the nb of filters in 'ir_conv' of the next A blocks.
"""

# The input shape is tied to the network
if K.image_dim_ordering() == 'th':
    input_shape = (1, 299, 299)
else:
    input_shape = (299, 299, 1)

def inception_resnet_stem(input):
    if K.image_dim_ordering() == "th":
        channel_axis = 1
    else:
        channel_axis = -1

    # Input shape is 299 x 299 x 1 (tf) or 1 x 299 x 299 (th)
    c = Convolution2D(32, 3, 3, activation='relu', subsample=(2, 2))(input)
    c = Convolution2D(32, 3, 3, activation='relu')(c)
    c = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(c)

    c1 = MaxPooling2D((3, 3), strides=(2, 2))(c)
    c2 = Convolution2D(96, 3, 3, activation='relu', subsample=(2, 2))(c)

    m = merge([c1, c2], mode='concat', concat_axis=channel_axis)

    c1 = Convolution2D(64, 1, 1, activation='relu', border_mode='same')(m)
    c1 = Convolution2D(96, 3, 3, activation='relu')(c1)

    c2 = Convolution2D(64, 1, 1, activation='relu', border_mode='same')(m)
    c2 = Convolution2D(64, 7, 1, activation='relu', border_mode='same')(c2)
    c2 = Convolution2D(64, 1, 7, activation='relu', border_mode='same')(c2)
    c2 = Convolution2D(96, 3, 3, activation='relu', border_mode='valid')(c2)

    m2 = merge([c1, c2], mode='concat', concat_axis=channel_axis)

    p1 = MaxPooling2D((3, 3), strides=(2, 2))(m2)
    p2 = Convolution2D(192, 3, 3, activation='relu', subsample=(2, 2))(m2)

    m3 = merge([p1, p2], mode='concat', concat_axis=channel_axis)
    m3 = BatchNormalization(axis=channel_axis)(m3)
    m3 = Activation('relu')(m3)
    return m3
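
# Residual scaling: for very deep or wide Inception-ResNets the paper found
# that scaling the residuals by a factor between 0.1 and 0.3 before the
# addition stabilises training; the blocks below use 0.1.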

def inception_resnet_v2_A(input, scale_residual=True):
    if K.image_dim_ordering() == "th":
        channel_axis = 1
    else:
        channel_axis = -1

    # Input is relu activation
    init = input

    ir1 = Convolution2D(32, 1, 1, activation='relu', border_mode='same')(input)

    ir2 = Convolution2D(32, 1, 1, activation='relu', border_mode='same')(input)
    ir2 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(ir2)

    ir3 = Convolution2D(32, 1, 1, activation='relu', border_mode='same')(input)
    ir3 = Convolution2D(48, 3, 3, activation='relu', border_mode='same')(ir3)
    ir3 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(ir3)

    ir_merge = merge([ir1, ir2, ir3], mode='concat', concat_axis=channel_axis)

    ir_conv = Convolution2D(384, 1, 1, activation='linear', border_mode='same')(ir_merge)
    if scale_residual: ir_conv = Lambda(lambda x: x * 0.1)(ir_conv)

    out = merge([init, ir_conv], mode='sum')
    out = BatchNormalization(axis=channel_axis)(out)
    out = Activation("relu")(out)
    return out

def inception_resnet_v2_B(input, scale_residual=True):
    if K.image_dim_ordering() == "th":
        channel_axis = 1
    else:
        channel_axis = -1

    # Input is relu activation
    init = input

    ir1 = Convolution2D(192, 1, 1, activation='relu', border_mode='same')(input)

    ir2 = Convolution2D(128, 1, 1, activation='relu', border_mode='same')(input)
    ir2 = Convolution2D(160, 1, 7, activation='relu', border_mode='same')(ir2)
    ir2 = Convolution2D(192, 7, 1, activation='relu', border_mode='same')(ir2)

    ir_merge = merge([ir1, ir2], mode='concat', concat_axis=channel_axis)

    ir_conv = Convolution2D(1152, 1, 1, activation='linear', border_mode='same')(ir_merge)
    if scale_residual: ir_conv = Lambda(lambda x: x * 0.1)(ir_conv)

    out = merge([init, ir_conv], mode='sum')
    out = BatchNormalization(axis=channel_axis)(out)
    out = Activation("relu")(out)
    return out

def inception_resnet_v2_C(input, scale_residual=True):
    if K.image_dim_ordering() == "th":
        channel_axis = 1
    else:
        channel_axis = -1

    # Input is relu activation
    init = input

    ir1 = Convolution2D(192, 1, 1, activation='relu', border_mode='same')(input)

    ir2 = Convolution2D(192, 1, 1, activation='relu', border_mode='same')(input)
    ir2 = Convolution2D(224, 1, 3, activation='relu', border_mode='same')(ir2)
    ir2 = Convolution2D(256, 3, 1, activation='relu', border_mode='same')(ir2)

    ir_merge = merge([ir1, ir2], mode='concat', concat_axis=channel_axis)

    ir_conv = Convolution2D(2144, 1, 1, activation='linear', border_mode='same')(ir_merge)
    if scale_residual: ir_conv = Lambda(lambda x: x * 0.1)(ir_conv)

    out = merge([init, ir_conv], mode='sum')
    out = BatchNormalization(axis=channel_axis)(out)
    out = Activation("relu")(out)
    return out


def reduction_A(input, k=192, l=224, m=256, n=384):
    if K.image_dim_ordering() == "th":
        channel_axis = 1
    else:
        channel_axis = -1

    r1 = MaxPooling2D((3, 3), strides=(2, 2))(input)

    r2 = Convolution2D(n, 3, 3, activation='relu', subsample=(2, 2))(input)

    r3 = Convolution2D(k, 1, 1, activation='relu', border_mode='same')(input)
    r3 = Convolution2D(l, 3, 3, activation='relu', border_mode='same')(r3)
    r3 = Convolution2D(m, 3, 3, activation='relu', subsample=(2, 2))(r3)

    # A new name for the merged tensor avoids shadowing the filter-count parameter 'm'
    out = merge([r1, r2, r3], mode='concat', concat_axis=channel_axis)
    out = BatchNormalization(axis=channel_axis)(out)
    out = Activation('relu')(out)
    return out
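
# The (k, l, m, n) filter counts follow Table 1 of the paper; for
# Inception-ResNet-v2 they are (256, 256, 384, 384), which is what
# create_inception_resnet_v2() passes below.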

def reduction_resnet_v2_B(input):
    if K.image_dim_ordering() == "th":
        channel_axis = 1
    else:
        channel_axis = -1

    r1 = MaxPooling2D((3, 3), strides=(2, 2), border_mode='valid')(input)

    r2 = Convolution2D(256, 1, 1, activation='relu', border_mode='same')(input)
    r2 = Convolution2D(384, 3, 3, activation='relu', subsample=(2, 2))(r2)

    r3 = Convolution2D(256, 1, 1, activation='relu', border_mode='same')(input)
    r3 = Convolution2D(288, 3, 3, activation='relu', subsample=(2, 2))(r3)

    r4 = Convolution2D(256, 1, 1, activation='relu', border_mode='same')(input)
    r4 = Convolution2D(288, 3, 3, activation='relu', border_mode='same')(r4)
    r4 = Convolution2D(320, 3, 3, activation='relu', subsample=(2, 2))(r4)

    m = merge([r1, r2, r3, r4], mode='concat', concat_axis=channel_axis)
    m = BatchNormalization(axis=channel_axis)(m)
    m = Activation('relu')(m)
    return m

def create_inception_resnet_v2(nb_classes=1001, scale=True):
    '''
    Creates an Inception-ResNet v2 network

    :param nb_classes: number of output classes
    :param scale: flag to add scaling of activations
    :return: Keras Model with 1 input (299 x 299 x 1) and 2 outputs (final_output, auxiliary_output)
    '''

    # Input shape is 299 x 299 x 1 (tf) or 1 x 299 x 299 (th)
    init = Input(input_shape)

    x = inception_resnet_stem(init)

    # 10 x Inception Resnet A
    for i in range(10):
        x = inception_resnet_v2_A(x, scale_residual=scale)

    # Reduction A
    x = reduction_A(x, k=256, l=256, m=384, n=384)

    # 20 x Inception Resnet B
    for i in range(20):
        x = inception_resnet_v2_B(x, scale_residual=scale)

    # Auxiliary tower
    aux_out = AveragePooling2D((5, 5), strides=(3, 3))(x)
    aux_out = Convolution2D(128, 1, 1, border_mode='same', activation='relu')(aux_out)
    aux_out = Convolution2D(768, 5, 5, activation='relu')(aux_out)
    aux_out = Flatten()(aux_out)
    aux_out = Dense(nb_classes, activation='softmax')(aux_out)

    # Reduction Resnet B
    x = reduction_resnet_v2_B(x)

    # 10 x Inception Resnet C
    for i in range(10):
        x = inception_resnet_v2_C(x, scale_residual=scale)

    # Average Pooling
    x = AveragePooling2D((8, 8))(x)

    # Dropout (the paper specifies a keep probability of 0.8, i.e. a drop fraction of 0.2)
    x = Dropout(0.2)(x)
    x = Flatten()(x)

    # Output
    out = Dense(output_dim=nb_classes, activation='softmax')(x)

    model = Model(init, output=[out, aux_out], name='Inception-Resnet-v2')
    return model

if __name__ == "__main__":
    from keras.utils.visualize_util import plot

    inception_resnet_v2 = create_inception_resnet_v2()
    # inception_resnet_v2.summary()

    plot(inception_resnet_v2, to_file="Inception ResNet-v2.png", show_shapes=True)
--------------------------------------------------------------------------------