├── .gitignore ├── DataLoader.py ├── README.md ├── config.py ├── imgs └── Fig1.png ├── main.py ├── models ├── AlexNet_3D.py ├── FastCapsNet_3D.py ├── Original_CapsNet.py ├── ResNet_3D.py ├── __init__.py ├── base_model.py └── utils │ ├── __init__.py │ ├── loss_ops.py │ ├── ops_caps.py │ └── ops_cnn.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | *.hdf5 106 | *.h5 107 | data/ 108 | .idea/ 109 | idea/ 110 | Results/ 111 | 112 | -------------------------------------------------------------------------------- /DataLoader.py: -------------------------------------------------------------------------------- 1 | import random 2 | import scipy 3 | import numpy as np 4 | import h5py 5 | import scipy.ndimage 6 | 7 | 8 | class DataLoader(object): 9 | 10 | def __init__(self, cfg): 11 | self.cfg = cfg 12 | self.mean = None 13 | self.std = None 14 | if cfg.percent == 1: 15 | self.data_path = './data/Lung_Nodule.h5' 16 | else: 17 | self.data_path = './data/Lung_Nodule_' + str(cfg.percent) + '.h5' 18 | 19 | def get_data(self, mode='train'): 20 | h5f = h5py.File(self.data_path, 'r') 21 | if mode == 'train': 22 | x_train = h5f['X_train'][:] 23 | y_train = h5f['Y_train'][:] 24 | self.x_train, self.y_train = self.prepare_data(x_train, y_train) 25 | elif mode == 'valid': 26 | x_valid = h5f['X_valid'][:] 27 | y_valid = h5f['Y_valid'][:] 28 | self.x_valid, self.y_valid = self.prepare_data(x_valid, y_valid) 29 | elif mode == 'test': 30 | x_test = h5f['X_valid'][:] 31 | y_test = h5f['Y_valid'][:] 32 | self.x_test, self.y_test = self.prepare_data(x_test, y_test) 33 | h5f.close() 34 | 35 | def prepare_data(self, x, y): 36 | if self.cfg.normalize: 37 | x = np.maximum(np.minimum(x, 4096.), 0.) 
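            # Clip voxel intensities to the [0, 4096] range, then standardize with the
            # training-set mean/std; get_stats() computes them lazily on first use (the
            # try/except below only triggers that computation while self.mean is still None).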
38 | try: 39 | _ = self.mean.shape 40 | except AttributeError: 41 | self.get_stats() 42 | x = (x - self.mean) / self.std 43 | x = x.reshape((-1, self.cfg.height, self.cfg.width, self.cfg.depth, self.cfg.channel)).astype(np.float32) 44 | if self.cfg.one_hot: 45 | y = (np.arange(self.cfg.num_cls) == y[:, None]).astype(np.float32) 46 | return x, y 47 | 48 | def next_batch(self, start=None, end=None, mode='train'): 49 | if mode == 'train': 50 | x = self.x_train[start:end] 51 | y = self.y_train[start:end] 52 | if self.cfg.data_augment: 53 | x = random_rotation_3d(x, self.cfg.max_angle) 54 | elif mode == 'valid': 55 | x = self.x_valid[start:end] 56 | y = self.y_valid[start:end] 57 | elif mode == 'test': 58 | x = self.x_test[start:end] 59 | y = self.y_test[start:end] 60 | return x, y 61 | 62 | def count_num_batch(self, batch_size, mode='train'): 63 | if mode == 'train': 64 | num_batch = int(self.y_train.shape[0] / batch_size) 65 | elif mode == 'valid': 66 | num_batch = int(self.y_valid.shape[0] / batch_size) 67 | elif mode == 'test': 68 | num_batch = int(self.y_test.shape[0] / batch_size) 69 | return num_batch 70 | 71 | def randomize(self): 72 | """ Randomizes the order of training data samples and their corresponding labels""" 73 | permutation = np.random.permutation(self.y_train.shape[0]) 74 | self.x_train = self.x_train[permutation, :, :, :] 75 | self.y_train = self.y_train[permutation] 76 | 77 | def get_stats(self): 78 | """ 79 | compute and store the mean and std of training samples (to be used for normalization) 80 | """ 81 | h5f = h5py.File(self.data_path, 'r') 82 | x_train = np.maximum(np.minimum(h5f['X_train'][:], 4096.), 0.) 83 | h5f.close() 84 | self.mean = np.mean(x_train, axis=0) 85 | self.std = np.std(x_train, axis=0) 86 | 87 | 88 | def random_rotation_3d(batch, max_angle): 89 | """ 90 | Randomly rotate an image by a random angle (-max_angle, max_angle). 91 | :param batch: batch of images of shape (batch_size, height, width, depth, channel) 92 | :param max_angle: maximum rotation angle in degree 93 | :return: array of rotated batch of images of the same shape as 'batch' 94 | """ 95 | size = batch.shape 96 | batch_rot = np.squeeze(batch) 97 | for i in range(batch.shape[0]): 98 | image = np.squeeze(batch[i]) 99 | if bool(random.getrandbits(1)): # rotate along x-axis 100 | angle = random.uniform(-max_angle, max_angle) 101 | image = scipy.ndimage.interpolation.rotate(image, angle, mode='nearest', axes=(1, 2), reshape=False) 102 | if bool(random.getrandbits(1)): # rotate along y-axis 103 | angle = random.uniform(-max_angle, max_angle) 104 | image = scipy.ndimage.interpolation.rotate(image, angle, mode='nearest', axes=(0, 2), reshape=False) 105 | if bool(random.getrandbits(1)): # rotate along z-axis 106 | angle = random.uniform(-max_angle, max_angle) 107 | image = scipy.ndimage.interpolation.rotate(image, angle, mode='nearest', axes=(0, 1), reshape=False) 108 | batch_rot[i] = image 109 | return batch_rot.reshape(size) 110 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fast Capsule Network 2 | Official TensorFlow implementation of the Fast Capsule Network proposed in 3 | the paper [Fast CapsNet for Lung Cancer Screening](https://arxiv.org/abs/1806.07416). 4 | 5 | ![FastCapsNet](imgs/Fig1.png) 6 | *Fig1. 
Fast Capsule Network architecture*
 7 | 
 8 | 
 9 | ## Dependencies
10 | - Python (2.7 preferably; also works fine with Python 3)
11 | - NumPy
12 | - [TensorFlow](https://github.com/tensorflow/tensorflow)>=1.3
13 | - Matplotlib (for saving images)
14 | 
15 | 
16 | ## How to run the code
17 | 
18 | ### 1. Prepare your data
19 | To run the code, you first need to store your data in a folder named 'data' inside the project folder.
20 | Given the current DataLoader code, it must be an HDF5 file containing the train, validation and test sets.
21 | 
22 | 
23 | ### 2. Train
24 | Most of the network parameters can be found in the ```config.py``` file. You may modify them or run with
25 | the default values, which train the 3D Fast Capsule Network proposed in the paper.
26 | 
27 | 
28 | Training the model displays the training results and saves the trained model after each epoch
29 | if an improvement is observed in the validation accuracy.
30 | - For training in the default setting: ```python main.py```
31 | - For loading a saved model and continuing training: ```python main.py --reload_epoch=epoch_num```
32 | where ```epoch_num``` determines the model number to be reloaded (e.g. if epoch_num=3,
33 | it will load the model trained and stored after 3 epochs).
34 | - For training the AlexNet network:
35 | ```python main.py --model=alexnet --loss_type=cross_entropy --add_recon_loss=False```
36 | 
37 | ### 3. Test
38 | - For running the test: ```python main.py --mode=test --reload_epoch=epoch_num```
39 | where ```epoch_num``` determines the model number to be reloaded (e.g. if epoch_num=3,
40 | it will load the model trained and stored after 3 epochs). -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf
 2 | 
 3 | flags = tf.app.flags
 4 | flags.DEFINE_string('mode', 'train', 'train, train_sequence, test, test_sequence or get_features')
 5 | flags.DEFINE_integer('reload_epoch', 0, 'model number to load (either for testing or continue training)')
 6 | flags.DEFINE_string('model', 'fast_capsule', 'alexnet, resnet, original_capsule, fast_capsule')
 7 | flags.DEFINE_string('loss_type', 'margin', 'cross_entropy, spread or margin')
 8 | flags.DEFINE_boolean('add_recon_loss', True, 'To add reconstruction loss')
 9 | flags.DEFINE_boolean('L2_reg', False, 'Adds L2-regularization to all the network weights')
10 | flags.DEFINE_float('lmbda', 5e-04, 'L2-regularization coefficient')
11 | 
12 | # Training logs
13 | flags.DEFINE_integer('max_step', 100000, '# of steps for training (only for mnist)')
14 | flags.DEFINE_integer('max_epoch', 1000, '# of epochs for training (only for nodule data)')
15 | flags.DEFINE_integer('SUMMARY_FREQ', 100, 'Number of steps between summary saves')
16 | 
17 | # For margin loss
18 | flags.DEFINE_float('m_plus', 0.9, 'm+ parameter')
19 | flags.DEFINE_float('m_minus', 0.1, 'm- parameter')
20 | flags.DEFINE_float('lambda_val', 0.5, 'Down-weighting parameter for the absent class')
21 | # For reconstruction loss
22 | flags.DEFINE_float('alpha', 0.0005, 'Regularization coefficient to scale down the reconstruction loss')
23 | # For training
24 | flags.DEFINE_integer('batch_size', 32, 'training batch size')
25 | flags.DEFINE_float('init_lr', 1e-4, 'Initial learning rate')
26 | flags.DEFINE_float('lr_min', 1e-5, 'Minimum learning rate')
27 | 
28 | # data
29 | flags.DEFINE_string('data', 'nodule', 'nodule')
30 | flags.DEFINE_integer('num_cls', 2, 'Number of output classes')
31 | flags.DEFINE_float('percent', 1,
'Percentage of training data to use') 32 | flags.DEFINE_boolean('one_hot', True, 'one-hot-encodes the labels (set to False if it is already one-hot-encoded)') 33 | flags.DEFINE_boolean('normalize', True, 'Normalizes the data (set to False if it is already normalized)') 34 | flags.DEFINE_boolean('data_augment', True, 'Adds augmentation to data') 35 | flags.DEFINE_integer('max_angle', 180, 'Maximum rotation angle along each axis; when applying augmentation') 36 | flags.DEFINE_integer('height', 32, 'Network input height size') 37 | flags.DEFINE_integer('width', 32, 'Network input width size') 38 | flags.DEFINE_integer('depth', 32, 'Network input depth size (in the case of 3D input images)') 39 | flags.DEFINE_integer('channel', 1, 'Network input channel size') 40 | 41 | # Directories 42 | flags.DEFINE_string('run_name', '01', 'Run name') 43 | flags.DEFINE_string('logdir', './Results/log_dir/', 'Logs directory') 44 | flags.DEFINE_string('modeldir', './Results/model_dir/', 'Saved models directory') 45 | flags.DEFINE_string('model_name', 'model', 'Model file name') 46 | 47 | # CapsNet architecture 48 | flags.DEFINE_integer('iter_routing', 3, 'Number of dynamic routing iterations') 49 | flags.DEFINE_integer('prim_caps_dim', 256, 'Dimension of the PrimaryCapsules') 50 | flags.DEFINE_integer('digit_caps_dim', 16, 'Dimension of the DigitCapsules') 51 | flags.DEFINE_integer('h1', 512, 'Number of hidden units of the first FC layer of the reconstruction network') 52 | flags.DEFINE_integer('h2', 1024, 'Number of hidden units of the second FC layer of the reconstruction network') 53 | 54 | # cnn architectures 55 | flags.DEFINE_float('dropout_rate', 0.2, 'Drop-out rate of the CNN models') 56 | 57 | args = tf.app.flags.FLAGS 58 | -------------------------------------------------------------------------------- /imgs/Fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amobiny/Fast_CapsNet/a2f0ea3a89733bc747342566c43f5be468dcb029/imgs/Fig1.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from config import args 3 | import os 4 | from utils import write_spec 5 | 6 | if args.model == 'original_capsule': 7 | from models.Original_CapsNet import OrigCapsNet as Model 8 | elif args.model == 'fast_capsule': 9 | from models.FastCapsNet_3D import FastCapsNet3D as Model 10 | elif args.model == 'alexnet': 11 | from models.AlexNet_3D import AlexNet3D as Model 12 | elif args.model == 'resnet': 13 | from models.ResNet_3D import ResNet3D as Model 14 | 15 | 16 | def main(_): 17 | if args.mode not in ['train', 'test']: 18 | print('invalid mode: ', args.mode) 19 | print("Please input a mode: train or test") 20 | elif args.mode == 'train' or args.mode == 'test': 21 | model = Model(tf.Session(), args) 22 | if not os.path.exists(args.modeldir+args.run_name): 23 | os.makedirs(args.modeldir+args.run_name) 24 | if not os.path.exists(args.logdir+args.run_name): 25 | os.makedirs(args.logdir+args.run_name) 26 | if args.mode == 'train': 27 | write_spec(args) 28 | model.train() 29 | elif args.mode == 'test': 30 | model.test(args.reload_epoch) 31 | 32 | 33 | if __name__ == '__main__': 34 | # configure which gpu to use 35 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 36 | tf.app.run() 37 | -------------------------------------------------------------------------------- /models/AlexNet_3D.py: 
-------------------------------------------------------------------------------- 1 | from base_model import BaseModel 2 | import tensorflow as tf 3 | from models.utils.ops_cnn import conv_layer_3d, fc_layer, dropout, max_pool_3d, relu, batch_normalization 4 | from tensorflow.contrib.layers import flatten 5 | 6 | 7 | class AlexNet3D(BaseModel): 8 | def __init__(self, sess, conf): 9 | super(AlexNet3D, self).__init__(sess, conf) 10 | self.build_network(self.x) 11 | if self.conf.mode != 'train_sequence': 12 | self.configure_network() 13 | 14 | def build_network(self, x): 15 | # Building network... 16 | with tf.variable_scope('CapsNet'): 17 | net = batch_normalization(relu(conv_layer_3d(x, kernel_size=7, stride=2, num_filters=96, 18 | add_reg=self.conf.L2_reg, layer_name='CONV1')), 19 | training=self.is_training, scope='BN1') 20 | net = max_pool_3d(net, pool_size=3, stride=2, padding='SAME', name='MaxPool1') 21 | net = batch_normalization(relu(conv_layer_3d(net, kernel_size=5, stride=2, num_filters=256, 22 | add_reg=self.conf.L2_reg, layer_name='CONV2')), 23 | training=self.is_training, scope='BN2') 24 | net = max_pool_3d(net, pool_size=3, stride=2, padding='SAME', name='MaxPool2') 25 | net = batch_normalization(relu(conv_layer_3d(net, kernel_size=3, stride=1, num_filters=384, 26 | add_reg=self.conf.L2_reg, layer_name='CONV3')), 27 | training=self.is_training, scope='BN3') 28 | net = batch_normalization(relu(conv_layer_3d(net, kernel_size=3, stride=1, num_filters=384, 29 | add_reg=self.conf.L2_reg, layer_name='CONV4')), 30 | training=self.is_training, scope='BN4') 31 | net = batch_normalization(relu(conv_layer_3d(net, kernel_size=3, stride=1, num_filters=256, 32 | add_reg=self.conf.L2_reg, layer_name='CONV5')), 33 | training=self.is_training, scope='BN5') 34 | net = max_pool_3d(net, pool_size=3, stride=2, padding='SAME', name='MaxPool3') 35 | layer_flat = flatten(net) 36 | net = relu(fc_layer(layer_flat, num_units=200, add_reg=self.conf.L2_reg, layer_name='FC1')) 37 | net = dropout(net, self.conf.dropout_rate, training=self.is_training) 38 | net = relu(fc_layer(net, num_units=75, add_reg=self.conf.L2_reg, layer_name='FC2')) 39 | net = dropout(net, self.conf.dropout_rate, training=self.is_training) 40 | self.features = net 41 | self.logits = fc_layer(net, num_units=self.conf.num_cls, add_reg=self.conf.L2_reg, layer_name='FC3') 42 | # [?, num_cls] 43 | self.probs = tf.nn.softmax(self.logits) 44 | # [?, num_cls] 45 | self.y_pred = tf.to_int32(tf.argmax(self.probs, 1)) 46 | # [?] (predicted labels) 47 | -------------------------------------------------------------------------------- /models/FastCapsNet_3D.py: -------------------------------------------------------------------------------- 1 | from models.base_model import BaseModel 2 | import tensorflow as tf 3 | from models.utils.ops_caps import * 4 | import numpy as np 5 | 6 | 7 | class FastCapsNet3D(BaseModel): 8 | def __init__(self, sess, conf): 9 | super(FastCapsNet3D, self).__init__(sess, conf) 10 | self.build_network() 11 | self.configure_network() 12 | 13 | def build_network(self): 14 | # Building network... 
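        # Fast CapsNet (3D) architecture: Conv1 (9x9x9 conv, 256 filters) -> PrimaryCaps
        # (9x9x9 conv with stride 2, reshaped into 512 capsules of 256-D) -> dynamic routing
        # (conf.iter_routing iterations) to num_cls DigitCaps of digit_caps_dim dimensions.
        # The prediction vectors u_hat are also kept so the decoder can later reconstruct
        # the input from the capsules of the predicted class.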
15 | with tf.variable_scope('CapsNet'): 16 | with tf.variable_scope('Conv1_layer'): 17 | conv1 = tf.layers.conv3d(self.x, filters=256, kernel_size=9, strides=1, 18 | padding='valid', activation=tf.nn.relu, name="conv1") 19 | # [batch_size, 24, 24, 24, 256] 20 | 21 | with tf.variable_scope('PrimaryCaps_layer'): 22 | conv2 = tf.layers.conv3d(conv1, filters=256, kernel_size=9, strides=2, 23 | padding='valid', activation=tf.nn.relu, name="conv2") 24 | # [batch_size, 8, 8, 8, 256] 25 | shape = conv2.get_shape().as_list() 26 | num_prim_caps = int(shape[1] * shape[2] * shape[3] * shape[4] / self.conf.prim_caps_dim) 27 | caps1_raw = tf.reshape(conv2, (self.conf.batch_size, num_prim_caps, 28 | self.conf.prim_caps_dim, 1), name="caps1_raw") 29 | # [batch_size, 8*8*8, 256, 1] 30 | caps1_output = squash(caps1_raw, name="caps1_output") 31 | # [batch_size, 512, 256, 1] 32 | 33 | # DigitCaps layer, return [batch_size, 10, 16, 1] 34 | with tf.variable_scope('DigitCaps_layer'): 35 | caps2_input = tf.reshape(caps1_output, 36 | shape=(self.conf.batch_size, num_prim_caps, 1, self.conf.prim_caps_dim, 1)) 37 | # [batch_size, 512, 1, 256, 1] 512 capsules of 256D 38 | b_IJ = tf.zeros([self.conf.batch_size, num_prim_caps, self.conf.num_cls, 1, 1], 39 | dtype=np.float32, name="b_ij") 40 | # [batch_size, 512, 2, 1, 1] 41 | self.caps2_output, u_hat = routing(caps2_input, b_IJ, self.conf.digit_caps_dim) 42 | # [batch_size, 2, 16, 1], [batch_size, 512, 2, 16, 1] 43 | u_hat_shape = u_hat.get_shape().as_list() 44 | self.img_s = int(round(u_hat_shape[1] ** (1. / 3))) 45 | self.u_hat = tf.reshape(u_hat, 46 | (self.conf.batch_size, self.img_s, self.img_s, self.img_s, 1, self.conf.num_cls, 47 | -1)) 48 | # [batch_size, 8, 8, 8, 1, 2, 16] 49 | 50 | epsilon = 1e-9 51 | self.v_length = tf.squeeze(tf.sqrt(tf.reduce_sum( 52 | tf.square(self.caps2_output), axis=2, keep_dims=True) + epsilon)) 53 | # [batch_size, 2] 54 | self.y_pred = tf.to_int32(tf.argmax(self.v_length, axis=1)) 55 | # [batch_size,] (predicted labels) 56 | 57 | if self.conf.add_recon_loss: 58 | self.mask() 59 | self.decoder() 60 | 61 | def mask(self): 62 | with tf.variable_scope('Masking'): 63 | y_pred_ohe = tf.one_hot(self.y_pred, depth=self.conf.num_cls) 64 | # [batch_size, 2] (one-hot-encoded predicted labels) 65 | 66 | reconst_targets = tf.cond(self.is_training, # condition 67 | lambda: self.y, # if True (Training) 68 | lambda: y_pred_ohe, # if False (Test) 69 | name="reconstruction_targets") 70 | # [batch_size, 2] 71 | reconst_targets = tf.reshape(reconst_targets, (self.conf.batch_size, 1, 1, 1, self.conf.num_cls)) 72 | # [batch_size, 1, 1, 2] 73 | reconst_targets = tf.tile(reconst_targets, (1, self.img_s, self.img_s, self.img_s, 1)) 74 | # [batch_size, 8, 8, 8, 2] 75 | indices = tf.argmax(self.v_length, axis=1) 76 | self.u_hat = tf.transpose(self.u_hat, perm=[5, 0, 1, 2, 3, 4, 6]) 77 | # u_hat: [2, batch_size, 8, 8, 8, 1, 16] 78 | u_list = tf.unstack(self.u_hat, axis=1) 79 | ind_list = tf.unstack(indices, axis=0) 80 | a = tf.stack([tf.gather_nd(mat, [[ind]]) for mat, ind in zip(u_list, ind_list)]) 81 | # [batch_size, 1, 8, 8, 8, 1, 16] 82 | feat = tf.reshape(tf.transpose(a, perm=[0, 2, 3, 4, 1, 5, 6]), 83 | (self.conf.batch_size, self.img_s, self.img_s, self.img_s, -1)) 84 | # [batch_size, 8, 8, 8, 16] 85 | self.cube = tf.concat([feat, reconst_targets], axis=-1) 86 | # [batch_size, 8, 8, 8, 18] 87 | 88 | def decoder(self): 89 | with tf.variable_scope('Decoder'): 90 | res1 = deconv3d(self.cube, [self.conf.batch_size, 16, 16, 16, 16], 91 | k_h=4, k_w=4, k_d=4, 
d_h=2, d_w=2, d_d=2, stddev=0.02, name="deconv_1") 92 | self.decoder_output = deconv3d(res1, [self.conf.batch_size, 32, 32, 32, 1], 93 | k_h=4, k_w=4, k_d=4, d_h=2, d_w=2, d_d=2, stddev=0.02, name="deconv_2") 94 | -------------------------------------------------------------------------------- /models/Original_CapsNet.py: -------------------------------------------------------------------------------- 1 | from base_model import BaseModel 2 | import tensorflow as tf 3 | from ops import * 4 | 5 | 6 | class OrigCapsNet(BaseModel): 7 | def __init__(self, sess, conf): 8 | super(OrigCapsNet, self).__init__(sess, conf) 9 | self.build_network(self.x) 10 | if self.conf.mode != 'train_sequence' and self.conf.mode != 'test_sequence': 11 | self.configure_network() 12 | 13 | def build_network(self, x): 14 | # Building network... 15 | with tf.variable_scope('CapsNet'): 16 | # Layer 1: A 2D conv layer 17 | conv1 = tf.keras.layers.Conv2D(filters=256, kernel_size=9, strides=1, trainable=self.conf.trainable, 18 | padding='valid', activation='relu', name='conv1')(x) 19 | 20 | # Layer 2: Primary Capsule Layer; simply a 2D conv + reshaping 21 | primary_caps = tf.keras.layers.Conv2D(filters=256, kernel_size=9, strides=2, trainable=self.conf.trainable, 22 | padding='valid', activation='relu', name='primary_caps')(conv1) 23 | _, H, W, dim = primary_caps.get_shape() 24 | num_caps = H.value * W.value * dim.value / self.conf.prim_caps_dim 25 | primary_caps_reshaped = tf.keras.layers.Reshape((num_caps, self.conf.prim_caps_dim))(primary_caps) 26 | caps1_output = squash(primary_caps_reshaped) 27 | 28 | # Layer 3: Digit Capsule Layer; Here is where the routing takes place 29 | self.digit_caps = FCCapsuleLayer(num_caps=self.conf.num_cls, caps_dim=self.conf.digit_caps_dim, 30 | routings=3, name='digit_caps', trainable=self.conf.trainable)(caps1_output) 31 | # [?, 2, 16] 32 | 33 | epsilon = 1e-9 34 | self.v_length = tf.sqrt(tf.reduce_sum(tf.square(self.digit_caps), axis=2, keep_dims=True) + epsilon) 35 | # [?, 2, 1] 36 | y_prob_argmax = tf.to_int32(tf.argmax(self.v_length, axis=1)) 37 | # [?, 1] 38 | self.y_pred = tf.squeeze(y_prob_argmax) 39 | # [?] 
(predicted labels) 40 | 41 | if self.conf.add_recon_loss: 42 | self.mask() 43 | self.decoder() 44 | 45 | def mask(self): # used in capsule network 46 | with tf.variable_scope('Masking'): 47 | y_pred_ohe = tf.one_hot(self.y_pred, depth=self.conf.num_cls) 48 | # [?, 10] (one-hot-encoded predicted labels) 49 | 50 | reconst_targets = tf.cond(self.is_training, # condition 51 | lambda: self.y, # if True (Training) 52 | lambda: y_pred_ohe, # if False (Test) 53 | name="reconstruction_targets") 54 | # [?, 10] 55 | self.output_masked = tf.multiply(self.digit_caps, tf.expand_dims(reconst_targets, -1)) 56 | # [?, 2, 16] 57 | 58 | def decoder(self): 59 | with tf.variable_scope('Decoder'): 60 | decoder_input = tf.reshape(self.output_masked, [-1, self.conf.num_cls * self.conf.digit_caps_dim]) 61 | # [?, 160] 62 | fc1 = tf.layers.dense(decoder_input, self.conf.h1, activation=tf.nn.relu, name="FC1", 63 | trainable=self.conf.trainable) 64 | # [?, 512] 65 | fc2 = tf.layers.dense(fc1, self.conf.h2, activation=tf.nn.relu, name="FC2", trainable=self.conf.trainable) 66 | # [?, 1024] 67 | self.decoder_output = tf.layers.dense(fc2, self.conf.width * self.conf.height * self.conf.channel, 68 | activation=tf.nn.sigmoid, name="FC3", trainable=self.conf.trainable) 69 | # [?, 784] 70 | -------------------------------------------------------------------------------- /models/ResNet_3D.py: -------------------------------------------------------------------------------- 1 | from base_model import BaseModel 2 | import tensorflow as tf 3 | from models.utils.ops_cnn import batch_normalization, relu, conv_layer_3d, dropout, fc_layer, \ 4 | max_pool_3d, average_pool_3d, flatten 5 | from collections import namedtuple 6 | 7 | 8 | class ResNet3D(BaseModel): 9 | def __init__(self, sess, conf): 10 | super(ResNet3D, self).__init__(sess, conf) 11 | # Configurations for each bottleneck group. 12 | BottleneckGroup = namedtuple('BottleneckGroup', 13 | ['num_blocks', 'bottleneck_size', 'out_filters']) 14 | self.groups = [BottleneckGroup(3, 32, 64), BottleneckGroup(4, 48, 128), 15 | BottleneckGroup(6, 64, 256), BottleneckGroup(3, 128, 512)] 16 | self.build_network(self.x) 17 | if self.conf.mode != 'train_sequence': 18 | self.configure_network() 19 | 20 | def build_network(self, x): 21 | # Building network... 22 | with tf.variable_scope('ResNet'): 23 | net = conv_layer_3d(x, num_filters=64, kernel_size=4, stride=1, add_reg=self.conf.L2_reg, 24 | layer_name='CONV0') 25 | net = relu(batch_normalization(net, training=self.is_training, scope='BN1')) 26 | # net = max_pool_3d(net, pool_size=3, stride=2, name='MaxPool0') 27 | 28 | # Create the bottleneck groups, each of which contains `num_blocks` bottleneck blocks. 29 | for group_i, group in enumerate(self.groups): 30 | first_block = True 31 | for block_i in range(group.num_blocks): 32 | block_name = 'group_%d/block_%d' % (group_i, block_i) 33 | net = self.bottleneck_block(net, group, block_name, is_first_block=first_block) 34 | first_block = False 35 | 36 | net = average_pool_3d(net, pool_size=2, stride=1, name='avg_pool') 37 | net = flatten(net) 38 | net = fc_layer(net, num_units=75, add_reg=self.conf.L2_reg, layer_name='Fc1') 39 | net = dropout(net, self.conf.dropout_rate, training=self.is_training) 40 | self.logits = fc_layer(net, num_units=self.conf.num_cls, add_reg=self.conf.L2_reg, layer_name='Fc2') 41 | # [?, num_cls] 42 | self.probs = tf.nn.softmax(self.logits) 43 | # [?, num_cls] 44 | self.y_pred = tf.to_int32(tf.argmax(self.probs, 1)) 45 | # [?] 
(predicted labels) 46 | 47 | def bottleneck_block(self, input_x, group, name, is_first_block=False): 48 | with tf.variable_scope(name): 49 | # 1x1 convolution responsible for reducing the depth 50 | with tf.variable_scope('conv_in'): 51 | stride = 2 if is_first_block else 1 52 | conv = conv_layer_3d(input_x, num_filters=group.bottleneck_size, add_reg=self.conf.L2_reg, 53 | kernel_size=1, stride=stride, layer_name='CONV') 54 | conv = relu(batch_normalization(conv, self.is_training, scope='BN')) 55 | 56 | with tf.variable_scope('conv_bottleneck'): 57 | conv = conv_layer_3d(conv, num_filters=group.bottleneck_size, kernel_size=3, 58 | add_reg=self.conf.L2_reg, layer_name='CONV') 59 | conv = relu(batch_normalization(conv, self.is_training, scope='BN')) 60 | 61 | # 1x1 convolution responsible for increasing the depth 62 | with tf.variable_scope('conv_out'): 63 | conv = conv_layer_3d(conv, num_filters=group.out_filters, kernel_size=1, 64 | add_reg=self.conf.L2_reg, layer_name='CONV') 65 | conv = batch_normalization(conv, self.is_training, scope='BN') 66 | 67 | # shortcut connections that turn the network into its counterpart 68 | # residual function (identity shortcut) 69 | with tf.variable_scope('shortcut'): 70 | if is_first_block: 71 | shortcut = conv_layer_3d(input_x, num_filters=group.out_filters, stride=2, kernel_size=1, 72 | add_reg=self.conf.L2_reg, layer_name='CONV_shortcut') 73 | shortcut = batch_normalization(shortcut, self.is_training, scope='BN_shortcut') 74 | assert (shortcut.get_shape().as_list() == conv.get_shape().as_list()), \ 75 | "Tensor sizes of the two branches are not matched!" 76 | res = shortcut + conv 77 | else: 78 | res = conv + input_x 79 | assert (input_x.get_shape().as_list() == conv.get_shape().as_list()), \ 80 | "Tensor sizes of the two branches are not matched!" 
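            # post-activation residual unit: ReLU is applied after adding the shortcut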
81 | return relu(res) 82 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amobiny/Fast_CapsNet/a2f0ea3a89733bc747342566c43f5be468dcb029/models/__init__.py -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import numpy as np 4 | from models.utils.loss_ops import margin_loss, spread_loss, cross_entropy 5 | from sklearn.metrics import confusion_matrix 6 | from DataLoader import DataLoader 7 | 8 | 9 | class BaseModel(object): 10 | def __init__(self, sess, conf): 11 | self.sess = sess 12 | self.conf = conf 13 | self.summary_list = [] 14 | self.input_shape = [conf.batch_size, conf.height, conf.width, conf.depth, conf.channel] 15 | self.output_shape = [self.conf.batch_size, self.conf.num_cls] 16 | self.global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False) 17 | self.create_placeholders() 18 | 19 | def create_placeholders(self): 20 | with tf.name_scope('Input'): 21 | self.x = tf.placeholder(tf.float32, self.input_shape, name='input') 22 | self.y = tf.placeholder(tf.float32, self.output_shape, name='annotation') 23 | self.is_training = tf.placeholder_with_default(False, shape=(), name="is_training") 24 | 25 | def loss_func(self): 26 | with tf.variable_scope('Loss'): 27 | if self.conf.loss_type == 'margin': 28 | loss = margin_loss(self.y, self.v_length, self.conf) 29 | self.summary_list.append(tf.summary.scalar('margin', loss)) 30 | elif self.conf.loss_type == 'spread': 31 | self.generate_margin() 32 | loss = spread_loss(self.y, self.act, self.margin, 'spread_loss') 33 | self.summary_list.append(tf.summary.scalar('spread_loss', loss)) 34 | elif self.conf.loss_type == 'cross_entropy': 35 | loss = cross_entropy(self.y, self.logits) 36 | tf.summary.scalar('cross_entropy', loss) 37 | if self.conf.L2_reg: 38 | with tf.name_scope('l2_loss'): 39 | l2_loss = tf.reduce_sum(self.conf.lmbda * tf.stack([tf.nn.l2_loss(v) 40 | for v in tf.get_collection('weights')])) 41 | loss += l2_loss 42 | self.summary_list.append(tf.summary.scalar('l2_loss', l2_loss)) 43 | if self.conf.add_recon_loss: 44 | with tf.variable_scope('Reconstruction_Loss'): 45 | squared = tf.square(self.decoder_output - self.x) 46 | self.recon_err = tf.reduce_mean(squared) 47 | self.total_loss = loss + self.conf.alpha * self.recon_err 48 | self.summary_list.append(tf.summary.scalar('reconstruction_loss', self.recon_err)) 49 | self.summary_list.append(tf.summary.image('reconstructed', self.decoder_output[:, :, :, 16, :])) 50 | self.summary_list.append(tf.summary.image('original', self.x[:, :, :, 16, :])) 51 | else: 52 | self.total_loss = loss 53 | self.mean_loss, self.mean_loss_op = tf.metrics.mean(self.total_loss) 54 | 55 | def accuracy_func(self): 56 | with tf.variable_scope('Accuracy'): 57 | correct_prediction = tf.equal(tf.to_int32(tf.argmax(self.y, axis=1)), self.y_pred) 58 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 59 | self.mean_accuracy, self.mean_accuracy_op = tf.metrics.mean(accuracy) 60 | 61 | def generate_margin(self): 62 | # margin schedule 63 | # margin increase from 0.2 to 0.9 after margin_schedule_epoch_achieve_max 64 | NUM_STEPS_PER_EPOCH = int(self.conf.N / self.conf.batch_size) 65 | 
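        # Spread-loss margin schedule: the margin grows in steps from 0.2 to 0.9,
        # reaching its maximum after margin_schedule_epoch_achieve_max (10) epochs.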
margin_schedule_epoch_achieve_max = 10.0 66 | self.margin = tf.train.piecewise_constant(tf.cast(self.global_step, dtype=tf.int32), 67 | boundaries=[int(NUM_STEPS_PER_EPOCH * 68 | margin_schedule_epoch_achieve_max * x / 7) 69 | for x in xrange(1, 8)], 70 | values=[x / 10.0 for x in range(2, 10)]) 71 | 72 | def configure_network(self): 73 | self.loss_func() 74 | self.accuracy_func() 75 | 76 | with tf.name_scope('Optimizer'): 77 | with tf.name_scope('Learning_rate_decay'): 78 | learning_rate = tf.train.exponential_decay(self.conf.init_lr, 79 | self.global_step, 80 | decay_steps=3000, 81 | decay_rate=0.97, 82 | staircase=True) 83 | self.learning_rate = tf.maximum(learning_rate, self.conf.lr_min) 84 | self.summary_list.append(tf.summary.scalar('learning_rate', self.learning_rate)) 85 | optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate) 86 | grads = optimizer.compute_gradients(self.total_loss) 87 | self.train_op = optimizer.apply_gradients(grads, global_step=self.global_step) 88 | self.sess.run(tf.global_variables_initializer()) 89 | trainable_vars = tf.trainable_variables() 90 | self.saver = tf.train.Saver(var_list=trainable_vars, max_to_keep=1000) 91 | self.train_writer = tf.summary.FileWriter(self.conf.logdir + self.conf.run_name + '/train/', self.sess.graph) 92 | self.valid_writer = tf.summary.FileWriter(self.conf.logdir + self.conf.run_name + '/valid/') 93 | self.configure_summary() 94 | print('*' * 50) 95 | print('Total number of trainable parameters: {}'. 96 | format(np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]))) 97 | print('*' * 50) 98 | 99 | def configure_summary(self): 100 | summary_list = [tf.summary.scalar('Loss/total_loss', self.mean_loss), 101 | tf.summary.scalar('Accuracy/average_accuracy', self.mean_accuracy)] + self.summary_list 102 | self.merged_summary = tf.summary.merge(summary_list) 103 | 104 | def save_summary(self, summary, step, mode): 105 | if mode == 'train': 106 | self.train_writer.add_summary(summary, step) 107 | elif mode == 'valid': 108 | self.valid_writer.add_summary(summary, step) 109 | self.sess.run(tf.local_variables_initializer()) 110 | 111 | def train(self): 112 | self.sess.run(tf.local_variables_initializer()) 113 | self.best_validation_accuracy = 0 114 | self.data_reader = DataLoader(self.conf) 115 | self.data_reader.get_data(mode='train') 116 | self.data_reader.get_data(mode='valid') 117 | self.train_loop() 118 | 119 | def train_loop(self): 120 | if self.conf.reload_epoch > 0: 121 | self.reload(self.conf.reload_epoch) 122 | print('*' * 50) 123 | print('----> Continue Training from step #{}'.format(self.conf.reload_epoch)) 124 | print('*' * 50) 125 | else: 126 | print('*' * 50) 127 | print('----> Start Training') 128 | print('*' * 50) 129 | self.num_val_batch = self.data_reader.count_num_batch(self.conf.batch_size, mode='valid') 130 | self.num_train_batch = self.data_reader.count_num_batch(self.conf.batch_size, mode='train') 131 | for epoch in range(self.conf.max_epoch): 132 | self.data_reader.randomize() 133 | for train_step in range(self.num_train_batch): 134 | glob_step = epoch * self.num_train_batch + train_step 135 | start = train_step * self.conf.batch_size 136 | end = (train_step + 1) * self.conf.batch_size 137 | x_batch, y_batch = self.data_reader.next_batch(start, end, mode='train') 138 | feed_dict = {self.x: x_batch, self.y: y_batch, self.is_training: True} 139 | if train_step % self.conf.SUMMARY_FREQ == 0: 140 | _, _, _, summary = self.sess.run([self.train_op, 141 | self.mean_loss_op, 142 | 
self.mean_accuracy_op, 143 | self.merged_summary], feed_dict=feed_dict) 144 | loss, acc = self.sess.run([self.mean_loss, self.mean_accuracy]) 145 | self.save_summary(summary, glob_step + self.conf.reload_epoch * self.num_train_batch, mode='train') 146 | print('step: {0:<6}, train_loss= {1:.4f}, train_acc={2:.01%}'.format(train_step, loss, acc)) 147 | else: 148 | self.sess.run([self.train_op, self.mean_loss_op, self.mean_accuracy_op], feed_dict=feed_dict) 149 | self.evaluate(glob_step, epoch) 150 | 151 | def evaluate(self, train_step, epoch): 152 | self.sess.run(tf.local_variables_initializer()) 153 | y_pred = np.zeros((self.data_reader.y_valid.shape[0])) 154 | for step in range(self.num_val_batch): 155 | start = step * self.conf.batch_size 156 | end = (step + 1) * self.conf.batch_size 157 | x_val, y_val = self.data_reader.next_batch(start, end, mode='valid') 158 | feed_dict = {self.x: x_val, self.y: y_val, self.is_training: False} 159 | yp, _, _ = self.sess.run([self.y_pred, self.mean_loss_op, self.mean_accuracy_op], feed_dict=feed_dict) 160 | y_pred[start:end] = yp 161 | summary_valid = self.sess.run(self.merged_summary, feed_dict=feed_dict) 162 | valid_loss, valid_acc = self.sess.run([self.mean_loss, self.mean_accuracy]) 163 | self.save_summary(summary_valid, train_step + self.conf.reload_epoch * self.num_train_batch, mode='valid') 164 | if valid_acc > self.best_validation_accuracy: 165 | self.best_validation_accuracy = valid_acc 166 | improved_str = '(improved)' 167 | self.save(epoch) 168 | else: 169 | improved_str = '' 170 | 171 | print('-' * 25 + 'Validation' + '-' * 25) 172 | print('After {0} training step: val_loss= {1:.4f}, val_acc={2:.01%}{3}' 173 | .format(train_step, valid_loss, valid_acc, improved_str)) 174 | print(confusion_matrix(np.argmax(self.data_reader.y_valid, axis=1), y_pred)) 175 | print('-' * 60) 176 | 177 | def test(self, step_num): 178 | self.sess.run(tf.local_variables_initializer()) 179 | self.reload(step_num) 180 | self.data_reader = DataLoader(self.conf) 181 | self.data_reader.get_data(mode='test') 182 | self.num_test_batch = self.data_reader.count_num_batch(self.conf.batch_size, mode='test') 183 | self.is_train = False 184 | self.sess.run(tf.local_variables_initializer()) 185 | y_pred = np.zeros((self.data_reader.y_test.shape[0])) 186 | img_recon = np.zeros((self.data_reader.y_test.shape[0], self.conf.height * self.conf.width)) 187 | for step in range(self.num_test_batch): 188 | start = step * self.conf.batch_size 189 | end = (step + 1) * self.conf.batch_size 190 | x_test, y_test = self.data_reader.next_batch(start, end, mode='test') 191 | feed_dict = {self.x: x_test, self.y: y_test, self.is_training: False} 192 | yp, _, _, img = self.sess.run([self.y_pred, self.mean_loss_op, self.mean_accuracy_op, self.decoder_output], 193 | feed_dict=feed_dict) 194 | y_pred[start:end] = yp 195 | img_recon[start:end] = img 196 | test_loss, test_acc = self.sess.run([self.mean_loss, self.mean_accuracy]) 197 | print('-' * 18 + 'Test Completed' + '-' * 18) 198 | print('test_loss= {0:.4f}, test_acc={1:.01%}'.format(test_loss, test_acc)) 199 | print(confusion_matrix(np.argmax(self.data_reader.y_test, axis=1), y_pred)) 200 | print('-' * 50) 201 | 202 | def save(self, epch): 203 | print('----> Saving the model at step #{0}'.format(epch)) 204 | checkpoint_path = os.path.join(self.conf.modeldir + self.conf.run_name, self.conf.model_name) 205 | self.saver.save(self.sess, checkpoint_path, global_step=epch+1) 206 | 207 | def reload(self, epch): 208 | checkpoint_path = 
os.path.join(self.conf.modeldir + self.conf.run_name, self.conf.model_name) 209 | model_path = checkpoint_path + '-' + str(epch) 210 | if not os.path.exists(model_path + '.meta'): 211 | print('----> No such checkpoint found', model_path) 212 | return 213 | print('----> Restoring the CNN model...') 214 | self.saver.restore(self.sess, model_path) 215 | print('----> CNN Model successfully restored') 216 | -------------------------------------------------------------------------------- /models/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amobiny/Fast_CapsNet/a2f0ea3a89733bc747342566c43f5be468dcb029/models/utils/__init__.py -------------------------------------------------------------------------------- /models/utils/loss_ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def spread_loss(labels, activations, margin, name): 5 | """This adds spread loss to total loss. 6 | :param labels: [N, O], where O is number of output classes, one hot vector, tf.uint8. 7 | :param activations: [N, O], activations. 8 | :param margin: margin 0.2 - 0.9 fixed schedule during training. 9 | :return: spread loss 10 | """ 11 | activations_shape = activations.get_shape().as_list() 12 | with tf.variable_scope(name): 13 | mask_t = tf.equal(labels, 1) 14 | mask_i = tf.equal(labels, 0) 15 | 16 | activations_t = tf.reshape(tf.boolean_mask(activations, mask_t), [activations_shape[0], 1]) 17 | activations_i = tf.reshape(tf.boolean_mask(activations, mask_i), 18 | [activations_shape[0], activations_shape[1] - 1]) 19 | gap_mit = tf.reduce_sum(tf.square(tf.nn.relu(margin - (activations_t - activations_i)))) 20 | return gap_mit 21 | 22 | 23 | def margin_loss(y, v_length, conf): 24 | with tf.variable_scope('Margin_Loss'): 25 | # max(0, m_plus-||v_c||)^2 26 | present_error = tf.square(tf.maximum(0., conf.m_plus - v_length)) 27 | # [?, 10, 1] 28 | # max(0, ||v_c||-m_minus)^2 29 | absent_error = tf.square(tf.maximum(0., v_length - conf.m_minus)) 30 | # [?, 10, 1] 31 | # reshape: [?, 10, 1] => [?, 10] 32 | present_error = tf.squeeze(present_error) 33 | absent_error = tf.squeeze(absent_error) 34 | T_c = y 35 | # [?, 10] 36 | L_c = T_c * present_error + conf.lambda_val * (1 - T_c) * absent_error 37 | # [?, 10] 38 | margin_loss = tf.reduce_mean(tf.reduce_sum(L_c, axis=1), name="margin_loss") 39 | return margin_loss 40 | 41 | 42 | def cross_entropy(y, logits): 43 | try: 44 | diff = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits) 45 | except: 46 | diff = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits) 47 | loss = tf.reduce_mean(diff) 48 | return loss 49 | -------------------------------------------------------------------------------- /models/utils/ops_caps.py: -------------------------------------------------------------------------------- 1 | from config import * 2 | 3 | 4 | def squash(s, epsilon=1e-7, name=None): 5 | """ 6 | Squashing function corresponding to Eq. 1 7 | :param s: A tensor with shape [batch_size, 1, num_caps, vec_len, 1] or [batch_size, num_caps, vec_len, 1]. 8 | :param epsilon: To compute norm safely 9 | :param name: 10 | :return: A tensor with the same shape as vector but squashed in 'vec_len' dimension. 
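    Implements Eq. 1 of the CapsNet paper: v = (||s||^2 / (1 + ||s||^2)) * (s / ||s||),
    computed with a small epsilon inside the square root for numerical stability.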
11 | """ 12 | with tf.name_scope(name, default_name="squash"): 13 | squared_norm = tf.reduce_sum(tf.square(s), axis=-2, keep_dims=True) 14 | safe_norm = tf.sqrt(squared_norm + epsilon) 15 | squash_factor = squared_norm / (1. + squared_norm) 16 | unit_vector = s / safe_norm 17 | return squash_factor * unit_vector 18 | 19 | 20 | def routing(inputs, b_ij, out_caps_dim): 21 | """ 22 | The routing algorithm 23 | :param inputs: A tensor with [batch_size, num_caps_in=1152, 1, in_caps_dim=8, 1] shape. 24 | num_caps_in: the number of capsule in layer l (i.e. PrimaryCaps). 25 | in_caps_dim: dimension of the output vectors of layer l (i.e. PrimaryCaps) 26 | :param b_ij: [batch_size, num_caps_in=1152, num_caps_out=10, 1, 1] 27 | num_caps_out: the number of capsule in layer l+1 (i.e. DigitCaps). 28 | :param out_caps_dim: dimension of the output vectors of layer l+1 (i.e. DigitCaps) 29 | 30 | :return: A Tensor of shape [batch_size, num_caps_out=10, out_caps_dim=16, 1] 31 | representing the vector output `v_j` in layer l+1. 32 | """ 33 | # W: [num_caps_in, num_caps_out, len_u_i, len_v_j] 34 | W = tf.get_variable('W', shape=(1, inputs.shape[1].value, b_ij.shape[2].value, inputs.shape[3].value, out_caps_dim), 35 | dtype=tf.float32, initializer=tf.random_normal_initializer(stddev=0.01)) 36 | 37 | inputs = tf.tile(inputs, [1, 1, b_ij.shape[2].value, 1, 1]) 38 | # input => [batch_size, 1152, 10, 8, 1] 39 | 40 | W = tf.tile(W, [args.batch_size, 1, 1, 1, 1]) 41 | # W => [batch_size, 1152, 10, 8, 16] 42 | 43 | u_hat = tf.matmul(W, inputs, transpose_a=True) 44 | # [batch_size, 1152, 10, 16, 1] 45 | 46 | # In forward, u_hat_stopped = u_hat; in backward, no gradient passed back from u_hat_stopped to u_hat 47 | u_hat_stopped = tf.stop_gradient(u_hat, name='stop_gradient') 48 | 49 | # For r iterations do 50 | for r_iter in range(args.iter_routing): 51 | with tf.variable_scope('iter_' + str(r_iter)): 52 | c_ij = tf.nn.softmax(b_ij, dim=2) 53 | # [batch_size, 1152, 10, 1, 1] 54 | 55 | # At last iteration, use `u_hat` in order to receive gradients from the following graph 56 | if r_iter == args.iter_routing - 1: 57 | s_j = tf.multiply(c_ij, u_hat) 58 | # [batch_size, 1152, 10, 16, 1] 59 | # then sum in the second dim 60 | s_j = tf.reduce_sum(s_j, axis=1, keep_dims=True) 61 | # [batch_size, 1, 10, 16, 1] 62 | v_j = squash(s_j) 63 | # [batch_size, 1, 10, 16, 1] 64 | 65 | elif r_iter < args.iter_routing - 1: # Inner iterations, do not apply backpropagation 66 | s_j = tf.multiply(c_ij, u_hat_stopped) 67 | s_j = tf.reduce_sum(s_j, axis=1, keep_dims=True) 68 | v_j = squash(s_j) 69 | v_j_tiled = tf.tile(v_j, [1, inputs.shape[1].value, 1, 1, 1]) 70 | # [batch_size, 1152, 10, 16, 1] 71 | 72 | # then matmul in the last two dim: [16, 1].T x [16, 1] => [1, 1] 73 | u_produce_v = tf.matmul(u_hat_stopped, v_j_tiled, transpose_a=True) 74 | # [batch_size, 1152, 10, 1, 1] 75 | 76 | b_ij += u_produce_v 77 | return tf.squeeze(v_j, axis=1), u_hat 78 | # [batch_size, 10, 16, 1] 79 | 80 | 81 | def deconv3d(input_, output_shape, 82 | k_h=4, k_w=4, k_d=4, d_h=2, d_w=2, d_d=2, stddev=0.02, 83 | name="deconv2d"): 84 | with tf.variable_scope(name): 85 | # filter : [height, width, output_channels, in_channels] 86 | w = tf.get_variable('w', [k_h, k_w, k_d, output_shape[-1], input_.get_shape()[-1]], 87 | initializer=tf.random_normal_initializer(stddev=stddev)) 88 | 89 | deconv = tf.nn.conv3d_transpose(input_, w, output_shape=output_shape, strides=[1, d_h, d_w, d_d, 1]) 90 | biases = tf.get_variable('biases', [output_shape[-1]], 
initializer=tf.constant_initializer(0.0)) 91 | deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape()) 92 | return deconv 93 | -------------------------------------------------------------------------------- /models/utils/ops_cnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2022 Department of Electrical and Computer Engineering 3 | University of Houston, TX/USA 4 | ********************************************************************************** 5 | Author: Aryan Mobiny 6 | Date: 9/1/2018 7 | Comments: Includes functions for defining the CNN layers 8 | ********************************************************************************** 9 | """ 10 | import tensorflow as tf 11 | from tflearn.layers.conv import global_avg_pool 12 | from tensorflow.contrib.framework import arg_scope 13 | from tensorflow.contrib.layers import batch_norm, flatten 14 | 15 | 16 | def conv_layer_2d(x, num_filters, kernel_size, add_reg=False, stride=1, layer_name="conv"): 17 | with tf.name_scope(layer_name): 18 | regularizer = None 19 | if add_reg: 20 | regularizer = tf.contrib.layers.l2_regularizer(scale=0.1) 21 | net = tf.layers.conv2d(inputs=x, filters=num_filters, kernel_size=kernel_size, 22 | strides=stride, padding='SAME', kernel_regularizer=regularizer) 23 | print('{}: {}'.format(layer_name, net.get_shape())) 24 | return net 25 | 26 | 27 | def conv_layer_3d(x, num_filters, kernel_size, add_reg=False, stride=1, layer_name="conv"): 28 | with tf.name_scope(layer_name): 29 | regularizer = None 30 | if add_reg: 31 | regularizer = tf.contrib.layers.l2_regularizer(scale=0.1) 32 | net = tf.layers.conv3d(inputs=x, filters=num_filters, kernel_size=kernel_size, 33 | strides=stride, padding='SAME', kernel_regularizer=regularizer) 34 | print('{}: {}'.format(layer_name, net.get_shape())) 35 | return net 36 | 37 | 38 | def fc_layer(x, num_units, add_reg, layer_name): 39 | with tf.name_scope(layer_name): 40 | regularizer = None 41 | if add_reg: 42 | regularizer = tf.contrib.layers.l2_regularizer(scale=0.1) 43 | net = tf.layers.dense(inputs=x, units=num_units, kernel_regularizer=regularizer) 44 | print('{}: {}'.format(layer_name, net.get_shape())) 45 | return net 46 | 47 | 48 | def max_pool_2d(x, pool_size, stride, name, padding='VALID'): 49 | """Create a max pooling layer.""" 50 | net = tf.layers.max_pooling2d(inputs=x, pool_size=pool_size, strides=stride, 51 | padding=padding, name=name) 52 | print('{}: {}'.format(name, net.get_shape())) 53 | return net 54 | 55 | 56 | def max_pool_3d(x, pool_size, stride, name, padding='VALID'): 57 | """Create a max pooling layer.""" 58 | net = tf.layers.max_pooling3d(inputs=x, pool_size=pool_size, strides=stride, 59 | padding=padding, name=name) 60 | print('{}: {}'.format(name, net.get_shape())) 61 | return net 62 | 63 | 64 | def average_pool_2d(x, pool_size, stride, name, padding='VALID'): 65 | """Create an average pooling layer.""" 66 | net = tf.layers.average_pooling2d(inputs=x, pool_size=pool_size, strides=stride, 67 | padding=padding, name=name) 68 | print('{}: {}'.format(name, net.get_shape())) 69 | return net 70 | 71 | 72 | def average_pool_3d(x, pool_size, stride, name, padding='VALID'): 73 | """Create an average pooling layer.""" 74 | net = tf.layers.average_pooling3d(inputs=x, pool_size=pool_size, strides=stride, 75 | padding=padding, name=name) 76 | print('{}: {}'.format(name, net.get_shape())) 77 | return net 78 | 79 | 80 | def global_average_pool(x, name='global_avg_pooling'): 81 | """ 82 | width 
= np.shape(x)[1] 83 | height = np.shape(x)[2] 84 | pool_size = [width, height] 85 | return tf.layers.average_pooling2d(inputs=x, pool_size=pool_size, strides=stride) 86 | """ 87 | net = global_avg_pool(x, name=name) 88 | print('{}: {}'.format(name, net.get_shape())) 89 | return net 90 | 91 | 92 | def dropout(x, rate, training): 93 | """Create a dropout layer.""" 94 | return tf.layers.dropout(inputs=x, rate=rate, training=training) 95 | 96 | 97 | def batch_normalization(x, training, scope): 98 | with arg_scope([batch_norm], 99 | scope=scope, 100 | updates_collections=None, 101 | decay=0.9, 102 | center=True, 103 | scale=True, 104 | zero_debias_moving_mean=True): 105 | out = tf.cond(training, 106 | lambda: batch_norm(inputs=x, is_training=training, reuse=None), 107 | lambda: batch_norm(inputs=x, is_training=training, reuse=True)) 108 | return out 109 | 110 | 111 | def lrn(inputs, depth_radius=2, alpha=0.0001, beta=0.75, bias=1.0): 112 | return tf.nn.local_response_normalization(inputs, depth_radius=depth_radius, alpha=alpha, beta=beta, bias=bias) 113 | 114 | 115 | def concatenation(layers): 116 | return tf.concat(layers, axis=3) 117 | 118 | 119 | def relu(x): 120 | return tf.nn.relu(x) 121 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def write_spec(args): 5 | config_file = open(args.modeldir + args.run_name + '/config.txt', 'w') 6 | config_file.write('run_name: ' + args.run_name + '\n') 7 | config_file.write('model: ' + args.model + '\n') 8 | config_file.write('loss_type: ' + args.loss_type + '\n') 9 | config_file.write('add_recon_loss: ' + str(args.add_recon_loss) + '\n') 10 | config_file.write('data: ' + args.data + '\n') 11 | config_file.write('height: ' + str(args.height) + '\n') 12 | config_file.write('num_cls: ' + str(args.num_cls) + '\n') 13 | config_file.write('batch_size: ' + str(args.batch_size) + '\n') 14 | config_file.write('optimizer: ' + 'Adam' + '\n') 15 | config_file.write('learning_rate: ' + str(args.init_lr) + ' : ' + str(args.lr_min) + '\n') 16 | config_file.write('data_augmentation: ' + str(args.data_augment) + '\n') 17 | config_file.write('max_angle: ' + str(args.max_angle) + '\n') 18 | if args.model == 'original_capsule': 19 | config_file.write('prim_caps_dim: ' + str(args.prim_caps_dim) + '\n') 20 | config_file.write('digit_caps_dim: ' + str(args.digit_caps_dim) + '\n') 21 | elif args.model == 'alexnet' or args.model == 'resnet': 22 | config_file.write('dropout_rate: ' + str(args.dropout_rate)) 23 | config_file.close() 24 | 25 | 26 | def weight_variable(shape): 27 | """ 28 | Create a weight variable with appropriate initialization 29 | :param shape: weight shape 30 | :return: initialized weight variable 31 | """ 32 | initer = tf.truncated_normal_initializer(stddev=0.01) 33 | return tf.get_variable('W', 34 | dtype=tf.float32, 35 | shape=shape, 36 | initializer=initer) 37 | 38 | 39 | def bias_variable(shape): 40 | """ 41 | Create a bias variable with appropriate initialization 42 | :param shape: bias variable shape 43 | :return: initialized bias variable 44 | """ 45 | initer = tf.constant(0., shape=shape, dtype=tf.float32) 46 | return tf.get_variable('b', 47 | dtype=tf.float32, 48 | initializer=initer) --------------------------------------------------------------------------------
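For reference, `DataLoader` reads the training and validation splits from an HDF5 file at `./data/Lung_Nodule.h5` (datasets `X_train`, `Y_train`, `X_valid`, `Y_valid`; the current test mode reuses the validation arrays). The snippet below is a hypothetical helper, not part of the repository, sketching that layout; the 32x32x32 volume shape matches the default height/width/depth flags, and the random toy data is only there to illustrate the expected array shapes.

```python
import h5py
import numpy as np


def save_dataset(x_train, y_train, x_valid, y_valid, path='./data/Lung_Nodule.h5'):
    """Write the splits in the layout expected by DataLoader.get_data()."""
    with h5py.File(path, 'w') as h5f:
        h5f.create_dataset('X_train', data=x_train.astype(np.float32))
        h5f.create_dataset('Y_train', data=y_train.astype(np.int32))
        h5f.create_dataset('X_valid', data=x_valid.astype(np.float32))
        h5f.create_dataset('Y_valid', data=y_valid.astype(np.int32))


if __name__ == '__main__':
    # toy example: 100 training and 20 validation volumes of 32x32x32 voxels, binary labels
    x_tr = np.random.rand(100, 32, 32, 32)
    y_tr = np.random.randint(0, 2, size=100)
    x_va = np.random.rand(20, 32, 32, 32)
    y_va = np.random.randint(0, 2, size=20)
    save_dataset(x_tr, y_tr, x_va, y_va)
```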