├── Architectures
│   ├── Inception-v4.png
│   ├── Inception ResNet-v1.png
│   └── Inception ResNet-v2.png
├── README.md
├── google-cloud.sh
├── .gitignore
├── experiment.sh
├── data.py
├── train_inception_v4.py
├── train_inception_resnet_v2.py
├── inception_v4.py
└── inception_resnet_v2.py
/Architectures/Inception-v4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaspermarstal/BrainNet/HEAD/Architectures/Inception-v4.png
--------------------------------------------------------------------------------
/Architectures/Inception ResNet-v1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaspermarstal/BrainNet/HEAD/Architectures/Inception%20ResNet-v1.png
--------------------------------------------------------------------------------
/Architectures/Inception ResNet-v2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaspermarstal/BrainNet/HEAD/Architectures/Inception%20ResNet-v2.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BrainNet
2 | Adaptation of @titu1994's Inception v4 and Inception ResNet v2 architectures to MRI images of the human brain. The architectures are described in the paper ["Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning"](http://arxiv.org/pdf/1602.07261v1.pdf).
3 |
4 | ## Please note
5 | This repo serves as an example of how to run experiments on Google Cloud, not of how best to segment brain images. More efficient architectures for this kind of segmentation exist today.
6 |
7 | ## Experiment
8 | This repository contains code for training the networks to segment white matter and gray matter in MRI scans from the Open Access Series of Imaging Studies (OASIS) archive.
9 |
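The training scripts locate images and their segmentation labels with the following glob patterns relative to the data directory, so the extracted discs should keep the stock OASIS layout:

```
disc*/OAS1_*_MR1/PROCESSED/MPRAGE/T88_111/OAS1_*_MR1_mpr_n4_anon_111_t88_gfc.hdr
disc*/OAS1_*_MR1/FSL_SEG/OAS1_*_MR1_mpr_n4_anon_111_t88_masked_gfc_fseg.hdr
```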
10 | To start the experiment, clone the repository and run
11 |
12 | ```
13 | $ ./experiment.sh
14 | ```
15 |
16 | Data is downloaded, extracted and preprocessed automatically.
17 |
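Each stage can also be run on its own:

```
$ ./experiment.sh --download-data
$ ./experiment.sh --extract-data
$ ./experiment.sh --install-venv
$ ./experiment.sh --install-tensorflow-gpu
$ ./experiment.sh --run-inception-v4
$ ./experiment.sh --run-inception-resnet-v2
```
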
18 | ## Google Cloud
19 | Provision a Google Cloud CPU or GPU instance with `google-cloud.sh` using either of the following commands:
20 |
21 | ```
22 | $ ./google-cloud.sh --create-cpu-instance
23 | $ ./google-cloud.sh --create-gpu-instance
24 | ```
25 |
26 | Once the instance is up and running, SSH into it (see the command below), clone this repository, and invoke `experiment.sh` from there.
27 |
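A minimal SSH command, assuming the instance name and zone from `google-cloud.sh`:

```
$ gcloud compute ssh brainnet-gpu --zone europe-west1-b
```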
--------------------------------------------------------------------------------
/google-cloud.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [[ $1 == "--create-cpu-instance" ]]; then
4 | gcloud beta compute instances create brainnet-cpu \
5 | --zone europe-west1-b \
6 | --machine-type n1-highcpu-16 \
7 | --image-project ubuntu-os-cloud \
8 | --image-family ubuntu-1604-lts \
9 | --boot-disk-device-name=brainnet \
10 | --boot-disk-type=pd-standard \
11 | --boot-disk-size=64GB \
12 | --maintenance-policy TERMINATE --restart-on-failure
13 | fi
14 |
15 | if [[ $1 == "--create-gpu-instance" ]]; then
16 | gcloud beta compute instances create brainnet-gpu \
17 | --zone europe-west1-b \
18 | --machine-type n1-highmem-2 \
19 | --image-project ubuntu-os-cloud \
20 | --image-family ubuntu-1604-lts \
21 | --boot-disk-device-name=brainnet-gpu \
22 | --boot-disk-type=pd-standard \
23 | --boot-disk-size=64GB \
24 | --accelerator type=nvidia-tesla-k80,count=1 \
25 | --maintenance-policy TERMINATE --restart-on-failure \
26 | --metadata startup-script='#!/bin/bash
27 | echo "Checking for CUDA and installing."
28 | if ! dpkg-query -W cuda; then
29 |     curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/cuda-repo-ubuntu1604_8.0.61-1_amd64.deb
30 | dpkg -i ./cuda-repo-ubuntu1604_8.0.61-1_amd64.deb
31 | apt-get update
32 | apt-get install cuda -y
33 | fi'
34 | fi
35 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *,cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 |
93 | # Rope project settings
94 | .ropeproject
95 |
96 | # PyCharm project settings
97 | .idea/
98 |
99 | # TensorFlow
100 | TensorBoard/
--------------------------------------------------------------------------------
/experiment.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [[ $1 == "--download-data" ]]; then
4 | if [[ -d ~/BrainNet/downloads ]]; then
5 | rm -fr ~/BrainNet/downloads
6 | fi
7 | mkdir -p ~/BrainNet/downloads
8 | cd ~/BrainNet/downloads
9 | curl -OL ftp://ftp.nrg.wustl.edu/data/oasis_cross-sectional_disc{1,2,3,4,5,6,7,8,9,10,11}.tar.gz
10 | fi
11 |
12 | if [[ $1 == "--extract-data" ]]; then
13 | if [[ -d ~/BrainNet/data ]]; then
14 | rm -fr ~/BrainNet/data
15 | fi
16 | mkdir -p ~/BrainNet/data
17 |     for filename in ~/BrainNet/downloads/*.tar.gz
18 |     do
19 |         tar zxf "$filename" -C ~/BrainNet/data
20 | done
21 | fi
22 |
23 | if [[ $1 == "--install-venv" ]]; then
24 | if [[ -d ~/BrainNet/venv ]]; then
25 | rm -fr ~/BrainNet/venv
26 | fi
27 | sudo apt-get install -y virtualenv gcc python-dev
28 | virtualenv ~/BrainNet/venv
29 | source ~/BrainNet/venv/bin/activate
30 |     pip install keras SimpleITK numpy scikit-learn scikit-image tensorflow
31 | fi
32 |
33 | if [[ $1 == "--install-tensorflow-gpu" ]]; then
34 | source ~/BrainNet/venv/bin/activate
35 | pip install --upgrade tensorflow-gpu
36 | fi
37 |
38 | if [[ $1 == "--run-inception-v4" ]]; then
39 | source ~/BrainNet/venv/bin/activate
40 |     python ~/BrainNet/train_inception_v4.py --data-dir=$HOME/BrainNet/data
41 | fi
42 |
43 | if [[ $1 == "--run-inception-resnet-v2" ]]; then
44 | source ~/BrainNet/venv/bin/activate
45 |     python ~/BrainNet/train_inception_resnet_v2.py --data-dir=$HOME/BrainNet/data
46 | fi
47 |
48 | if [[ -z "$1" ]]; then
49 | ./experiment.sh --download-data
50 | ./experiment.sh --extract-data
51 | ./experiment.sh --install-venv
52 | ./experiment.sh --install-tensorflow-gpu || true
53 | ./experiment.sh --run-inception-v4
54 | fi
55 |
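# Invoked with no argument (the block above), the script runs the full pipeline.
# The GPU TensorFlow install is allowed to fail on CPU-only instances ('|| true').
# A single stage can be re-run on its own, e.g.:
#
#   ./experiment.sh --run-inception-resnet-v2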
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | # Util
2 | import time
3 |
4 | # Images
5 | import numpy as np
6 | from numpy.random import random, randint
7 | from SimpleITK import Extract, GetArrayFromImage, ReadImage
8 | from skimage.transform import resize
9 |
10 | # Keras
11 | from keras.utils.np_utils import to_categorical
12 |
13 | # The number of classes is tied to the dataset
14 | nb_classes = 4
15 |
16 | class Timer(object):
17 | def __init__(self, name=None):
18 | self.name = name
19 |
20 | def __enter__(self):
21 | self.tstart = time.time()
22 |
23 |     def __exit__(self, type, value, traceback):
24 |         if self.name:
25 |             print('[%s]' % self.name, end=' ')
26 |         print('Elapsed: %s.' % (time.time() - self.tstart))
27 |
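# Example: Timer is used as a context manager by the training scripts, e.g.
#
#   with Timer('Loading images'):
#       images = [ReadImage(f) for f in filenames]
#
# prints '[Loading images] Elapsed: <seconds>.' when the block exits.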
28 | def generator(images, labels, input_shape, patch_size=32, batch_size=32):
29 | while True:
30 | X = np.empty((batch_size, input_shape[0], input_shape[1], input_shape[2]), dtype=float)
31 | y = np.empty(batch_size, dtype=int)
32 |
33 | for i in range(0, batch_size):
34 | subject_id = randint(0, len(images))
35 | X[i, :, :], y[i] = generate_one(images[subject_id], labels[subject_id], input_shape, patch_size)
36 |
37 | y = to_categorical(y, nb_classes=nb_classes)
38 |
39 | yield X, y
40 |
41 | def generate_one(image, label, input_shape, patch_size):
42 | p = random()
43 |
44 |     # TODO: Obtain patches at arbitrary angles
45 | if p > 0.66:
46 | patch, label = axial_patch_generator(image, label, patch_size)
47 | elif p > 0.33:
48 | patch, label = coronal_patch_generator(image, label, patch_size)
49 | else:
50 | patch, label = sagittal_patch_generator(image, label, patch_size)
51 |
52 | patch = GetArrayFromImage(patch)
53 |
54 | if random() > 0.5:
55 | patch = np.fliplr(patch)
56 |
57 | if random() > 0.5:
58 | patch = np.flipud(patch)
59 |
60 | patch = resize(patch, input_shape)
61 |
62 | return patch, label
63 |
64 | def axial_patch_generator(image, label, patch_size):
65 | image_size = image.GetSize()
66 | assert(image_size == label.GetSize())
67 |
68 | point = (randint(0, image_size[0] - patch_size), randint(0, image_size[1]), randint(0, image_size[2] - patch_size))
69 | patch = Extract(image, (patch_size, 0, patch_size), point)
70 |
71 | return patch, label.GetPixel(point)
72 |
73 | def coronal_patch_generator(image, label, patch_size):
74 | image_size = image.GetSize()
75 | assert (image_size == label.GetSize())
76 |
77 | point = (randint(0, image_size[0] - patch_size), randint(0, image_size[1] - patch_size), randint(0, image_size[2]))
78 | patch = Extract(image, (patch_size, patch_size, 0), point)
79 |
80 | return patch, label.GetPixel(point)
81 |
82 | def sagittal_patch_generator(image, label, patch_size):
83 | image_size = image.GetSize()
84 | assert (image_size == label.GetSize())
85 |
86 | point = (randint(0, image_size[0]), randint(0, image_size[1] - patch_size), randint(0, image_size[2] - patch_size))
87 | patch = Extract(image, (0, patch_size, patch_size), point)
88 |
89 | return patch, label.GetPixel(point)
90 |
91 |
92 |
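# A minimal sanity check of the sampling pipeline, runnable as a script.
# The paths below are hypothetical placeholders; point them at any
# extracted OASIS image/label pair.
if __name__ == '__main__':
    image = ReadImage('/path/to/OAS1_0001_MR1_mpr_n4_anon_111_t88_gfc.hdr')
    label = ReadImage('/path/to/OAS1_0001_MR1_mpr_n4_anon_111_t88_masked_gfc_fseg.hdr')
    X, y = next(generator([image], [label], (299, 299, 1)))
    print('Patches:', X.shape, 'one-hot labels:', y.shape)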
--------------------------------------------------------------------------------
/train_inception_v4.py:
--------------------------------------------------------------------------------
1 | # Util
2 | import os
3 | import sys
4 | import argparse
5 | from glob import glob
6 |
7 | # Data
8 | from data import generator, nb_classes, ReadImage, Timer
9 | from sklearn.model_selection import train_test_split
10 |
11 | # Keras
12 | from keras.callbacks import EarlyStopping, TensorBoard
13 | from keras.optimizers import RMSprop
14 |
15 | # Model
16 | from inception_v4 import create_inception_v4, input_shape
17 |
18 | # TensorFlow
19 | from tensorflow.python.platform import app
20 |
21 | def main(argv):
22 |
23 |     print('Finding data ...')
24 | with Timer():
25 | image_filenames = glob(os.path.join(FLAGS.data_dir, 'disc*/OAS1_*_MR1/PROCESSED/MPRAGE/T88_111/OAS1_*_MR1_mpr_n4_anon_111_t88_gfc.hdr'))
26 | label_filenames = glob(os.path.join(FLAGS.data_dir, 'disc*/OAS1_*_MR1/FSL_SEG/OAS1_*_MR1_mpr_n4_anon_111_t88_masked_gfc_fseg.hdr'))
27 | assert(len(image_filenames) == len(label_filenames))
28 | print('Found %i images.' % len(image_filenames))
29 |
30 |     print('Loading images ...')
31 | with Timer():
32 | images = [ReadImage(image_filename) for image_filename in image_filenames]
33 | labels = [ReadImage(label_filename) for label_filename in label_filenames]
34 | images_train, images_test, labels_train, labels_test = train_test_split(images, labels, train_size=0.66)
35 |
36 | tensor_board = TensorBoard(log_dir='./TensorBoard')
37 | early_stopping = EarlyStopping(monitor='acc', patience=2, verbose=1)
38 |
39 | model = create_inception_v4(nb_classes=nb_classes, load_weights=False)
40 | model.compile(optimizer=RMSprop(lr=0.045, rho=0.94, epsilon=1., decay=0.9), loss='categorical_crossentropy', metrics=['acc'])
41 |     model.fit_generator(generator(images_train, labels_train, input_shape, FLAGS.patch_size, FLAGS.batch_size),
42 | samples_per_epoch=FLAGS.samples_per_epoch, nb_epoch=FLAGS.nb_epochs, callbacks=[tensor_board, early_stopping],
43 | verbose=1)
44 |
45 | if __name__ == '__main__':
46 | parser = argparse.ArgumentParser()
47 | parser.add_argument(
48 | '-d',
49 | '--data-dir',
50 | dest='data_dir',
51 | help='Path to data directory.',
52 | required=True,
53 | )
54 | parser.add_argument(
55 | '-p',
56 | '--patch-size',
57 | default=32,
58 | type=int,
59 | dest='patch_size',
60 |         help='Side length of the square patch in millimetres (mm).',
61 | )
62 | parser.add_argument(
63 | '-b',
64 | '--batch-size',
65 | default=32,
66 | type=int,
67 | dest='batch_size',
68 | help='Batch size.',
69 | )
70 | parser.add_argument(
71 | '-e',
72 | '--nb-epochs',
73 | default=8,
74 | type=int,
75 | dest='nb_epochs',
76 | help='Number of epochs.',
77 | )
78 | parser.add_argument(
79 | '-s',
80 | '--samples-per-epoch',
81 | default=1024,
82 | type=int,
83 | dest='samples_per_epoch',
84 | help='Number of samples per epoch.',
85 | )
86 |
87 | FLAGS, unparsed = parser.parse_known_args()
88 | app.run(main=main, argv=[sys.argv[0]] + unparsed)
89 |
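# Example invocation, assuming the data layout produced by experiment.sh:
#
#   python train_inception_v4.py --data-dir=$HOME/BrainNet/data --patch-size=32 --nb-epochs=8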
--------------------------------------------------------------------------------
/train_inception_resnet_v2.py:
--------------------------------------------------------------------------------
1 | # Util
2 | import os
3 | import sys
4 | import argparse
5 | from glob import glob
6 |
7 | # Data
8 | from data import generator, nb_classes, ReadImage, Timer
9 | from sklearn.model_selection import train_test_split
10 |
11 | # Keras
12 | from keras.callbacks import EarlyStopping, TensorBoard
13 | from keras.optimizers import RMSprop
14 |
15 | # Model
16 | from inception_resnet_v2 import create_inception_resnet_v2, input_shape
17 |
18 | # TensorFlow
19 | from tensorflow.python.platform import app
20 |
21 | def main(argv):
22 |
23 |     print('Finding data ...')
24 | with Timer():
25 | image_filenames = glob(os.path.join(FLAGS.data_dir, 'disc*/OAS1_*_MR1/PROCESSED/MPRAGE/T88_111/OAS1_*_MR1_mpr_n4_anon_111_t88_gfc.hdr'))
26 | label_filenames = glob(os.path.join(FLAGS.data_dir, 'disc*/OAS1_*_MR1/FSL_SEG/OAS1_*_MR1_mpr_n4_anon_111_t88_masked_gfc_fseg.hdr'))
27 | assert(len(image_filenames) == len(label_filenames))
28 | print('Found %i images.' % len(image_filenames))
29 |
30 |     print('Loading images ...')
31 | with Timer():
32 | images = [ReadImage(image_filename) for image_filename in image_filenames]
33 | labels = [ReadImage(label_filename) for label_filename in label_filenames]
34 | images_train, images_test, labels_train, labels_test = train_test_split(images, labels, train_size=0.66)
35 |
36 | tensor_board = TensorBoard(log_dir='./TensorBoard')
37 | early_stopping = EarlyStopping(monitor='acc', patience=2, verbose=1)
38 |
39 | model = create_inception_resnet_v2(nb_classes=nb_classes)
40 | model.compile(optimizer=RMSprop(lr=0.045, rho=0.94, epsilon=1., decay=0.9), loss='categorical_crossentropy', metrics=['acc'])
41 |     model.fit_generator(generator(images_train, labels_train, input_shape, FLAGS.patch_size, FLAGS.batch_size),
42 | samples_per_epoch=FLAGS.samples_per_epoch, nb_epoch=FLAGS.nb_epochs, callbacks=[tensor_board, early_stopping],
43 | verbose=1)
44 |
45 | if __name__ == '__main__':
46 | parser = argparse.ArgumentParser()
47 | parser.add_argument(
48 | '-d',
49 | '--data-dir',
50 | dest='data_dir',
51 | help='Path to data directory.',
52 | required=True,
53 | )
54 | parser.add_argument(
55 | '-p',
56 | '--patch-size',
57 | default=32,
58 | type=int,
59 | dest='patch_size',
60 |         help='Side length of the square patch in millimetres (mm).',
61 | )
62 | parser.add_argument(
63 | '-b',
64 | '--batch-size',
65 | default=32,
66 | type=int,
67 | dest='batch_size',
68 | help='Batch size.',
69 | )
70 | parser.add_argument(
71 | '-e',
72 | '--nb-epochs',
73 | default=8,
74 | type=int,
75 | dest='nb_epochs',
76 | help='Number of epochs.',
77 | )
78 | parser.add_argument(
79 | '-s',
80 | '--samples-per-epoch',
81 | default=1024,
82 | type=int,
83 | dest='samples_per_epoch',
84 | help='Number of samples per epoch.',
85 | )
86 |
87 | FLAGS, unparsed = parser.parse_known_args()
88 | app.run(main=main, argv=[sys.argv[0]] + unparsed)
89 |
90 |
--------------------------------------------------------------------------------
/inception_v4.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Input, merge, Dropout, Dense, Flatten, Activation
2 | from keras.layers.convolutional import MaxPooling2D, Convolution2D, AveragePooling2D
3 | from keras.layers.normalization import BatchNormalization
4 | from keras.models import Model
5 |
6 | from keras import backend as K
7 | from keras.utils.data_utils import get_file
8 |
9 | """
10 | Implementation of the Inception v4 network from the paper [Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning](http://arxiv.org/pdf/1602.07261v1.pdf) in Keras.
11 | """
12 |
13 | # The input shape is tied to the network
14 | if K.image_dim_ordering() == 'th':
15 | input_shape = (1, 299, 299)
16 | else:
17 | input_shape = (299, 299, 1)
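# Note: a single channel (grayscale MRI) replaces the 3-channel RGB input of the original paper.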
18 |
19 | def conv_block(x, nb_filter, nb_row, nb_col, border_mode='same', subsample=(1, 1), bias=False):
20 | if K.image_dim_ordering() == "th":
21 | channel_axis = 1
22 | else:
23 | channel_axis = -1
24 |
25 | x = Convolution2D(nb_filter, nb_row, nb_col, subsample=subsample, border_mode=border_mode, bias=bias)(x)
26 | x = BatchNormalization(axis=channel_axis)(x)
27 | x = Activation('relu')(x)
28 | return x
29 |
30 |
31 | def inception_stem(input):
32 | if K.image_dim_ordering() == "th":
33 | channel_axis = 1
34 | else:
35 | channel_axis = -1
36 |
37 |     # Input shape is 299 x 299 x 1 (tf) or 1 x 299 x 299 (th)
38 | x = conv_block(input, 32, 3, 3, subsample=(2, 2), border_mode='valid')
39 | x = conv_block(x, 32, 3, 3, border_mode='valid')
40 | x = conv_block(x, 64, 3, 3)
41 |
42 | x1 = MaxPooling2D((3, 3), strides=(2, 2), border_mode='valid')(x)
43 | x2 = conv_block(x, 96, 3, 3, subsample=(2, 2), border_mode='valid')
44 |
45 | x = merge([x1, x2], mode='concat', concat_axis=channel_axis)
46 |
47 | x1 = conv_block(x, 64, 1, 1)
48 | x1 = conv_block(x1, 96, 3, 3, border_mode='valid')
49 |
50 | x2 = conv_block(x, 64, 1, 1)
51 | x2 = conv_block(x2, 64, 1, 7)
52 | x2 = conv_block(x2, 64, 7, 1)
53 | x2 = conv_block(x2, 96, 3, 3, border_mode='valid')
54 |
55 | x = merge([x1, x2], mode='concat', concat_axis=channel_axis)
56 |
57 | x1 = conv_block(x, 192, 3, 3, subsample=(2, 2), border_mode='valid')
58 | x2 = MaxPooling2D((3, 3), strides=(2, 2), border_mode='valid')(x)
59 |
60 | x = merge([x1, x2], mode='concat', concat_axis=channel_axis)
61 | return x
62 |
63 |
64 | def inception_A(input):
65 | if K.image_dim_ordering() == "th":
66 | channel_axis = 1
67 | else:
68 | channel_axis = -1
69 |
70 | a1 = conv_block(input, 96, 1, 1)
71 |
72 | a2 = conv_block(input, 64, 1, 1)
73 | a2 = conv_block(a2, 96, 3, 3)
74 |
75 | a3 = conv_block(input, 64, 1, 1)
76 | a3 = conv_block(a3, 96, 3, 3)
77 | a3 = conv_block(a3, 96, 3, 3)
78 |
79 | a4 = AveragePooling2D((3, 3), strides=(1, 1), border_mode='same')(input)
80 | a4 = conv_block(a4, 96, 1, 1)
81 |
82 | m = merge([a1, a2, a3, a4], mode='concat', concat_axis=channel_axis)
83 | return m
84 |
85 |
86 | def inception_B(input):
87 | if K.image_dim_ordering() == "th":
88 | channel_axis = 1
89 | else:
90 | channel_axis = -1
91 |
92 | b1 = conv_block(input, 384, 1, 1)
93 |
94 | b2 = conv_block(input, 192, 1, 1)
95 | b2 = conv_block(b2, 224, 1, 7)
96 | b2 = conv_block(b2, 256, 7, 1)
97 |
98 | b3 = conv_block(input, 192, 1, 1)
99 | b3 = conv_block(b3, 192, 7, 1)
100 | b3 = conv_block(b3, 224, 1, 7)
101 | b3 = conv_block(b3, 224, 7, 1)
102 | b3 = conv_block(b3, 256, 1, 7)
103 |
104 | b4 = AveragePooling2D((3, 3), strides=(1, 1), border_mode='same')(input)
105 | b4 = conv_block(b4, 128, 1, 1)
106 |
107 | m = merge([b1, b2, b3, b4], mode='concat', concat_axis=channel_axis)
108 | return m
109 |
110 |
111 | def inception_C(input):
112 | if K.image_dim_ordering() == "th":
113 | channel_axis = 1
114 | else:
115 | channel_axis = -1
116 |
117 | c1 = conv_block(input, 256, 1, 1)
118 |
119 | c2 = conv_block(input, 384, 1, 1)
120 | c2_1 = conv_block(c2, 256, 1, 3)
121 | c2_2 = conv_block(c2, 256, 3, 1)
122 | c2 = merge([c2_1, c2_2], mode='concat', concat_axis=channel_axis)
123 |
124 | c3 = conv_block(input, 384, 1, 1)
125 | c3 = conv_block(c3, 448, 3, 1)
126 | c3 = conv_block(c3, 512, 1, 3)
127 | c3_1 = conv_block(c3, 256, 1, 3)
128 | c3_2 = conv_block(c3, 256, 3, 1)
129 | c3 = merge([c3_1, c3_2], mode='concat', concat_axis=channel_axis)
130 |
131 | c4 = AveragePooling2D((3, 3), strides=(1, 1), border_mode='same')(input)
132 | c4 = conv_block(c4, 256, 1, 1)
133 |
134 | m = merge([c1, c2, c3, c4], mode='concat', concat_axis=channel_axis)
135 | return m
136 |
137 |
138 | def reduction_A(input):
139 | if K.image_dim_ordering() == "th":
140 | channel_axis = 1
141 | else:
142 | channel_axis = -1
143 |
144 | r1 = conv_block(input, 384, 3, 3, subsample=(2, 2), border_mode='valid')
145 |
146 | r2 = conv_block(input, 192, 1, 1)
147 | r2 = conv_block(r2, 224, 3, 3)
148 | r2 = conv_block(r2, 256, 3, 3, subsample=(2, 2), border_mode='valid')
149 |
150 | r3 = MaxPooling2D((3, 3), strides=(2, 2), border_mode='valid')(input)
151 |
152 | m = merge([r1, r2, r3], mode='concat', concat_axis=channel_axis)
153 | return m
154 |
155 |
156 | def reduction_B(input):
157 | if K.image_dim_ordering() == "th":
158 | channel_axis = 1
159 | else:
160 | channel_axis = -1
161 |
162 | r1 = conv_block(input, 192, 1, 1)
163 | r1 = conv_block(r1, 192, 3, 3, subsample=(2, 2), border_mode='valid')
164 |
165 | r2 = conv_block(input, 256, 1, 1)
166 | r2 = conv_block(r2, 256, 1, 7)
167 | r2 = conv_block(r2, 320, 7, 1)
168 | r2 = conv_block(r2, 320, 3, 3, subsample=(2, 2), border_mode='valid')
169 |
170 | r3 = MaxPooling2D((3, 3), strides=(2, 2), border_mode='valid')(input)
171 |
172 | m = merge([r1, r2, r3], mode='concat', concat_axis=channel_axis)
173 | return m
174 |
175 |
176 | def create_inception_v4(nb_classes=1001, load_weights=True):
177 | '''
178 |     Creates an Inception v4 network
179 |
180 |     :param nb_classes: number of output classes
181 | :return: Keras Model with 1 input and 1 output
182 | '''
183 |
184 | # Input Shape is 299 x 299 x 1 (tf) or 1 x 299 x 299 (th)
185 | init = Input(input_shape)
186 |
187 | x = inception_stem(init)
188 |
189 | # 4 x Inception A
190 | for i in range(4):
191 | x = inception_A(x)
192 |
193 | # Reduction A
194 | x = reduction_A(x)
195 |
196 | # 7 x Inception B
197 | for i in range(7):
198 | x = inception_B(x)
199 |
200 | # Reduction B
201 | x = reduction_B(x)
202 |
203 | # 3 x Inception C
204 | for i in range(3):
205 | x = inception_C(x)
206 |
207 | # Average Pooling
208 | x = AveragePooling2D((8, 8))(x)
209 |
210 |     # Dropout (the paper specifies a keep probability of 0.8, i.e. a drop rate of 0.2 in Keras)
211 |     x = Dropout(0.2)(x)
212 | x = Flatten()(x)
213 |
214 | # Output
215 | out = Dense(output_dim=nb_classes, activation='softmax')(x)
216 |
217 | model = Model(init, out, name='Inception-v4')
218 | return model
219 |
220 |
221 | if __name__ == "__main__":
222 | # from keras.utils.visualize_util import plot
223 |
224 | inception_v4 = create_inception_v4(load_weights=True)
225 | # inception_v4.summary()
226 |
227 | # plot(inception_v4, to_file="Inception-v4.png", show_shapes=True)
228 |
--------------------------------------------------------------------------------
/inception_resnet_v2.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Input, merge, Dropout, Dense, Lambda, Flatten, Activation
2 | from keras.layers.convolutional import MaxPooling2D, Convolution2D, AveragePooling2D
3 | from keras.layers.normalization import BatchNormalization
4 | from keras.models import Model
5 |
6 | from keras import backend as K
7 |
8 | import warnings
9 | warnings.filterwarnings('ignore')
10 |
11 | """
12 | Implementation of the Inception-ResNet v2 network from the paper [Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning](http://arxiv.org/pdf/1602.07261v1.pdf) in Keras.
13 |
14 | Some additional details:
15 | [1] Each of the A, B and C blocks has a 'scale_residual' parameter.
16 |     Residual activations are scaled as described in the paper; note that scaling is turned ON by default in this adaptation.
17 | 
18 |     Setting 'scale=False' in the create_inception_resnet_v2() method disables the scaling.
19 |
20 | [2] There were minor inconsistencies with filter size in both B and C blocks.
21 |
22 | In the B blocks: 'ir_conv' nb of filters is given as 1154, however input size is 1152.
23 | This causes inconsistencies in the merge-add mode, therefore the 'ir_conv' filter size
24 | is reduced to 1152 to match input size.
25 |
26 | In the C blocks: 'ir_conv' nb of filter is given as 2048, however input size is 2144.
27 | This causes inconsistencies in the merge-add mode, therefore the 'ir_conv' filter size
28 | is increased to 2144 to match input size.
29 |
30 | Currently trying to find a proper solution with original nb of filters.
31 |
32 | [3] In the stem function, the last Convolutional2D layer has 384 filters instead of the original 256.
33 | This is to correctly match the nb of filters in 'ir_conv' of the next A blocks.
34 | """
35 |
36 | # The input shape is tied to the network
37 | if K.image_dim_ordering() == 'th':
38 | input_shape = (1, 299, 299)
39 | else:
40 | input_shape = (299, 299, 1)
41 |
42 | def inception_resnet_stem(input):
43 | if K.image_dim_ordering() == "th":
44 | channel_axis = 1
45 | else:
46 | channel_axis = -1
47 |
48 |     # Input shape is 299 x 299 x 1 (tf) or 1 x 299 x 299 (th)
49 | c = Convolution2D(32, 3, 3, activation='relu', subsample=(2, 2))(input)
50 |     c = Convolution2D(32, 3, 3, activation='relu')(c)
51 | c = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(c)
52 |
53 | c1 = MaxPooling2D((3, 3), strides=(2, 2))(c)
54 | c2 = Convolution2D(96, 3, 3, activation='relu', subsample=(2, 2))(c)
55 |
56 | m = merge([c1, c2], mode='concat', concat_axis=channel_axis)
57 |
58 | c1 = Convolution2D(64, 1, 1, activation='relu', border_mode='same')(m)
59 |     c1 = Convolution2D(96, 3, 3, activation='relu')(c1)
60 |
61 | c2 = Convolution2D(64, 1, 1, activation='relu', border_mode='same')(m)
62 | c2 = Convolution2D(64, 7, 1, activation='relu', border_mode='same')(c2)
63 | c2 = Convolution2D(64, 1, 7, activation='relu', border_mode='same')(c2)
64 | c2 = Convolution2D(96, 3, 3, activation='relu', border_mode='valid')(c2)
65 |
66 | m2 = merge([c1, c2], mode='concat', concat_axis=channel_axis)
67 |
68 |     p1 = MaxPooling2D((3, 3), strides=(2, 2))(m2)
69 | p2 = Convolution2D(192, 3, 3, activation='relu', subsample=(2, 2))(m2)
70 |
71 | m3 = merge([p1, p2], mode='concat', concat_axis=channel_axis)
72 | m3 = BatchNormalization(axis=channel_axis)(m3)
73 | m3 = Activation('relu')(m3)
74 | return m3
75 |
76 | def inception_resnet_v2_A(input, scale_residual=True):
77 | if K.image_dim_ordering() == "th":
78 | channel_axis = 1
79 | else:
80 | channel_axis = -1
81 |
82 | # Input is relu activation
83 | init = input
84 |
85 | ir1 = Convolution2D(32, 1, 1, activation='relu', border_mode='same')(input)
86 |
87 | ir2 = Convolution2D(32, 1, 1, activation='relu', border_mode='same')(input)
88 | ir2 = Convolution2D(32, 3, 3, activation='relu', border_mode='same')(ir2)
89 |
90 | ir3 = Convolution2D(32, 1, 1, activation='relu', border_mode='same')(input)
91 | ir3 = Convolution2D(48, 3, 3, activation='relu', border_mode='same')(ir3)
92 | ir3 = Convolution2D(64, 3, 3, activation='relu', border_mode='same')(ir3)
93 |
94 | ir_merge = merge([ir1, ir2, ir3], concat_axis=channel_axis, mode='concat')
95 |
96 | ir_conv = Convolution2D(384, 1, 1, activation='linear', border_mode='same')(ir_merge)
97 | if scale_residual: ir_conv = Lambda(lambda x: x * 0.1)(ir_conv)
98 |
99 | out = merge([init, ir_conv], mode='sum')
100 | out = BatchNormalization(axis=channel_axis)(out)
101 | out = Activation("relu")(out)
102 | return out
103 |
104 | def inception_resnet_v2_B(input, scale_residual=True):
105 | if K.image_dim_ordering() == "th":
106 | channel_axis = 1
107 | else:
108 | channel_axis = -1
109 |
110 | # Input is relu activation
111 | init = input
112 |
113 | ir1 = Convolution2D(192, 1, 1, activation='relu', border_mode='same')(input)
114 |
115 | ir2 = Convolution2D(128, 1, 1, activation='relu', border_mode='same')(input)
116 | ir2 = Convolution2D(160, 1, 7, activation='relu', border_mode='same')(ir2)
117 | ir2 = Convolution2D(192, 7, 1, activation='relu', border_mode='same')(ir2)
118 |
119 | ir_merge = merge([ir1, ir2], mode='concat', concat_axis=channel_axis)
120 |
121 | ir_conv = Convolution2D(1152, 1, 1, activation='linear', border_mode='same')(ir_merge)
122 | if scale_residual: ir_conv = Lambda(lambda x: x * 0.1)(ir_conv)
123 |
124 | out = merge([init, ir_conv], mode='sum')
125 | out = BatchNormalization(axis=channel_axis)(out)
126 | out = Activation("relu")(out)
127 | return out
128 |
129 | def inception_resnet_v2_C(input, scale_residual=True):
130 | if K.image_dim_ordering() == "th":
131 | channel_axis = 1
132 | else:
133 | channel_axis = -1
134 |
135 | # Input is relu activation
136 | init = input
137 |
138 | ir1 = Convolution2D(192, 1, 1, activation='relu', border_mode='same')(input)
139 |
140 | ir2 = Convolution2D(192, 1, 1, activation='relu', border_mode='same')(input)
141 | ir2 = Convolution2D(224, 1, 3, activation='relu', border_mode='same')(ir2)
142 | ir2 = Convolution2D(256, 3, 1, activation='relu', border_mode='same')(ir2)
143 |
144 | ir_merge = merge([ir1, ir2], mode='concat', concat_axis=channel_axis)
145 |
146 | ir_conv = Convolution2D(2144, 1, 1, activation='linear', border_mode='same')(ir_merge)
147 | if scale_residual: ir_conv = Lambda(lambda x: x * 0.1)(ir_conv)
148 |
149 | out = merge([init, ir_conv], mode='sum')
150 | out = BatchNormalization(axis=channel_axis)(out)
151 | out = Activation("relu")(out)
152 | return out
153 |
154 |
155 | def reduction_A(input, k=192, l=224, m=256, n=384):
156 | if K.image_dim_ordering() == "th":
157 | channel_axis = 1
158 | else:
159 | channel_axis = -1
160 |
161 | r1 = MaxPooling2D((3,3), strides=(2,2))(input)
162 |
163 | r2 = Convolution2D(n, 3, 3, activation='relu', subsample=(2,2))(input)
164 |
165 | r3 = Convolution2D(k, 1, 1, activation='relu', border_mode='same')(input)
166 | r3 = Convolution2D(l, 3, 3, activation='relu', border_mode='same')(r3)
167 | r3 = Convolution2D(m, 3, 3, activation='relu', subsample=(2,2))(r3)
168 |
169 | m = merge([r1, r2, r3], mode='concat', concat_axis=channel_axis)
170 |     m = BatchNormalization(axis=channel_axis)(m)
171 | m = Activation('relu')(m)
172 | return m
173 |
174 |
175 | def reduction_resnet_v2_B(input):
176 | if K.image_dim_ordering() == "th":
177 | channel_axis = 1
178 | else:
179 | channel_axis = -1
180 |
181 | r1 = MaxPooling2D((3,3), strides=(2,2), border_mode='valid')(input)
182 |
183 | r2 = Convolution2D(256, 1, 1, activation='relu', border_mode='same')(input)
184 | r2 = Convolution2D(384, 3, 3, activation='relu', subsample=(2,2))(r2)
185 |
186 | r3 = Convolution2D(256, 1, 1, activation='relu', border_mode='same')(input)
187 | r3 = Convolution2D(288, 3, 3, activation='relu', subsample=(2, 2))(r3)
188 |
189 | r4 = Convolution2D(256, 1, 1, activation='relu', border_mode='same')(input)
190 | r4 = Convolution2D(288, 3, 3, activation='relu', border_mode='same')(r4)
191 | r4 = Convolution2D(320, 3, 3, activation='relu', subsample=(2, 2))(r4)
192 |
193 | m = merge([r1, r2, r3, r4], concat_axis=channel_axis, mode='concat')
194 | m = BatchNormalization(axis=channel_axis)(m)
195 | m = Activation('relu')(m)
196 | return m
197 |
198 | def create_inception_resnet_v2(nb_classes=1001, scale=True):
199 | '''
200 |     Creates an Inception-ResNet v2 network
201 |
202 |     :param nb_classes: number of output classes
203 | :param scale: flag to add scaling of activations
204 |     :return: Keras Model with 1 input (299x299x1) and 2 outputs (final_output, auxiliary_output)
205 | '''
206 |
207 | # Input Shape is 299 x 299 x 1 (tf) or 1 x 299 x 299 (th)
208 |     init = Input(input_shape)
209 |
210 | x = inception_resnet_stem(init)
211 |
212 | # 10 x Inception Resnet A
213 | for i in range(10):
214 | x = inception_resnet_v2_A(x, scale_residual=scale)
215 |
216 | # Reduction A
217 | x = reduction_A(x, k=256, l=256, m=384, n=384)
218 |
219 | # 20 x Inception Resnet B
220 | for i in range(20):
221 | x = inception_resnet_v2_B(x, scale_residual=scale)
222 |
223 | # Auxiliary tower
224 | aux_out = AveragePooling2D((5, 5), strides=(3, 3))(x)
225 | aux_out = Convolution2D(128, 1, 1, border_mode='same', activation='relu')(aux_out)
226 | aux_out = Convolution2D(768, 5, 5, activation='relu')(aux_out)
227 | aux_out = Flatten()(aux_out)
228 | aux_out = Dense(nb_classes, activation='softmax')(aux_out)
229 |
230 | # Reduction Resnet B
231 | x = reduction_resnet_v2_B(x)
232 |
233 | # 10 x Inception Resnet C
234 | for i in range(10):
235 | x = inception_resnet_v2_C(x, scale_residual=scale)
236 |
237 | # Average Pooling
238 | x = AveragePooling2D((8,8))(x)
239 |
240 |     # Dropout (the paper specifies a keep probability of 0.8, i.e. a drop rate of 0.2 in Keras)
241 |     x = Dropout(0.2)(x)
242 | x = Flatten()(x)
243 |
244 | # Output
245 | out = Dense(output_dim=nb_classes, activation='softmax')(x)
246 |
247 | model = Model(init, output=[out, aux_out], name='Inception-Resnet-v2')
248 | return model
249 |
250 | if __name__ == "__main__":
251 | from keras.utils.visualize_util import plot
252 |
253 | inception_resnet_v2 = create_inception_resnet_v2()
254 | #inception_resnet_v2.summary()
255 |
256 | plot(inception_resnet_v2, to_file="Inception ResNet-v2.png", show_shapes=True)
--------------------------------------------------------------------------------