├── img80.jpg
├── result.png
├── utl
│   ├── __init__.py
│   ├── dataset.py
│   ├── metrics.py
│   ├── data_aug_op.py
│   ├── Cell_Net.py
│   ├── DataGenerator.py
│   └── custom_layers.py
├── README.md
└── main.py

/img80.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utayao/Atten_Deep_MIL/HEAD/img80.jpg
--------------------------------------------------------------------------------
/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utayao/Atten_Deep_MIL/HEAD/result.png
--------------------------------------------------------------------------------
/utl/__init__.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from . import dataset
from . import Cell_Net
from . import custom_layers
from . import metrics
--------------------------------------------------------------------------------
/utl/dataset.py:
--------------------------------------------------------------------------------
import glob
from sklearn.model_selection import KFold

def load_dataset(dataset_path, n_folds, rand_state):
    """Split the bag paths into folds for K-fold cross-validation.

    Parameters
    --------------------
    :param dataset_path: root folder with subfolders '1' (positive) and '0'
        (negative), each containing one 'img*' directory per bag
    :param n_folds: number of cross-validation folds
    :param rand_state: random seed passed to KFold, for reproducibility
    :return: list
        List of dicts, one per fold, each with 'train' and 'test' path lists
    """
    # collect the bag directories of each class
    pos_path = glob.glob(dataset_path + '/1/img*')
    neg_path = glob.glob(dataset_path + '/0/img*')

    all_path = pos_path + neg_path

    kf = KFold(n_splits=n_folds, shuffle=True, random_state=rand_state)
    datasets = []
    for train_idx, test_idx in kf.split(all_path):
        dataset = {}
        dataset['train'] = [all_path[ibag] for ibag in train_idx]
        dataset['test'] = [all_path[ibag] for ibag in test_idx]
        datasets.append(dataset)
    return datasets
--------------------------------------------------------------------------------
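# Example (not part of the original file): a minimal sketch of consuming
# load_dataset. The root path and fold count mirror the defaults in main.py;
# the seed is illustrative.
from utl.dataset import load_dataset

folds = load_dataset(dataset_path='../data/Patches', n_folds=10, rand_state=0)
for ifold, fold in enumerate(folds):
    print(ifold, len(fold['train']), 'train bags,', len(fold['test']), 'test bags')
--------------------------------------------------------------------------------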
/utl/metrics.py:
--------------------------------------------------------------------------------
from keras import backend as K

def bag_accuracy(y_true, y_pred):
    """Compute accuracy of one bag.

    Parameters
    ---------------------
    y_true : Tensor (N x 1)
        Ground truth of the bag, replicated for each of the N instances.
    y_pred : Tensor (1 x 1)
        Prediction score of the bag.

    Return
    ---------------------
    acc : Tensor (1 x 1)
        Accuracy of the bag label prediction.
    """
    y_true = K.mean(y_true, axis=0, keepdims=False)
    y_pred = K.mean(y_pred, axis=0, keepdims=False)
    acc = K.mean(K.equal(y_true, K.round(y_pred)))
    return acc


def bag_loss(y_true, y_pred):
    """Compute binary cross-entropy loss of the bag prediction.

    Parameters
    ---------------------
    y_true : Tensor (N x 1)
        Ground truth of the bag, replicated for each of the N instances.
    y_pred : Tensor (1 x 1)
        Prediction score of the bag.

    Return
    ---------------------
    loss : Tensor (1 x 1)
        Binary cross-entropy loss of predicting the bag label.
    """
    y_true = K.mean(y_true, axis=0, keepdims=False)
    y_pred = K.mean(y_pred, axis=0, keepdims=False)
    loss = K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1)
    return loss
--------------------------------------------------------------------------------
/utl/data_aug_op.py:
--------------------------------------------------------------------------------
import random
import numpy as np
import cv2

def random_flip_img(img, horizontal_chance=0, vertical_chance=0):
    """Randomly flip an image (or a list of images) with the given chances."""
    flip_horizontal = random.random() < horizontal_chance
    flip_vertical = random.random() < vertical_chance

    if not flip_horizontal and not flip_vertical:
        return img

    # cv2.flip codes: 0 = around x-axis (vertical flip),
    # 1 = around y-axis (horizontal flip), -1 = both
    flip_val = 1
    if flip_vertical:
        flip_val = -1 if flip_horizontal else 0

    if not isinstance(img, list):
        res = cv2.flip(img, flip_val)
    else:
        res = [cv2.flip(img_item, flip_val) for img_item in img]
    return res

def random_rotate_img(images):
    """Rotate an image by a random multiple of 90 degrees."""
    angle = 90 * np.random.randint(4)
    # cv2 expects the rotation center as (x, y), i.e. (width / 2, height / 2)
    center = (images.shape[1] / 2, images.shape[0] / 2)
    rot_matrix = cv2.getRotationMatrix2D(center, angle, scale=1.0)

    # dsize is (width, height)
    img_inst = cv2.warpAffine(images, rot_matrix,
                              dsize=(images.shape[1], images.shape[0]),
                              borderMode=cv2.BORDER_CONSTANT)
    return img_inst

def random_crop(image, crop_size=(400, 400)):
    """Return a random crop of the given size, or None if the image is too small."""
    height, width = image.shape[:-1]
    dy, dx = crop_size
    if width < dx or height < dy:
        return None
    X = np.copy(image)
    x = np.random.randint(0, width - dx + 1)
    y = np.random.randint(0, height - dy + 1)
    return X[y:(y + dy), x:(x + dx), :]
--------------------------------------------------------------------------------
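# Example (not part of the original file): a small sketch of the augmentation
# ops on a dummy patch; the 27x27 size matches input_dim in main.py.
import numpy as np
from utl.data_aug_op import random_flip_img, random_rotate_img, random_crop

patch = np.random.rand(27, 27, 3).astype(np.float32)
aug = random_flip_img(patch, horizontal_chance=0.5, vertical_chance=0.5)
aug = random_rotate_img(aug)
print(aug.shape)                   # (27, 27, 3)
print(random_crop(patch) is None)  # True: patch is smaller than the default 400x400 crop
--------------------------------------------------------------------------------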
/utl/Cell_Net.py:
--------------------------------------------------------------------------------
import numpy as np

from keras.utils import multi_gpu_model
from keras.models import Model
from keras.optimizers import Adam
from keras.regularizers import l2
from keras.layers import Input, Dense, Dropout, Conv2D, MaxPooling2D, Flatten, multiply
from .metrics import bag_accuracy, bag_loss
from .custom_layers import Mil_Attention, Last_Sigmoid

def cell_net(input_dim, args, useMulGpu=False):

    lr = args.init_lr
    weight_decay = args.weight_decay  # was args.init_lr, a copy-paste bug

    # instance-level feature extractor: all patches of one bag form the "batch"
    data_input = Input(shape=input_dim, dtype='float32', name='input')
    conv1 = Conv2D(36, kernel_size=(4, 4), kernel_regularizer=l2(weight_decay), activation='relu')(data_input)
    conv1 = MaxPooling2D((2, 2))(conv1)

    conv2 = Conv2D(48, kernel_size=(3, 3), kernel_regularizer=l2(weight_decay), activation='relu')(conv1)
    conv2 = MaxPooling2D((2, 2))(conv2)
    x = Flatten()(conv2)

    fc1 = Dense(512, activation='relu', kernel_regularizer=l2(weight_decay), name='fc1')(x)
    fc1 = Dropout(0.5)(fc1)
    fc2 = Dense(512, activation='relu', kernel_regularizer=l2(weight_decay), name='fc2')(fc1)
    fc2 = Dropout(0.5)(fc2)

    # attention weights over the instances, followed by
    # attention-weighted MIL pooling of the instance features
    alpha = Mil_Attention(L_dim=128, output_dim=1, kernel_regularizer=l2(weight_decay),
                          name='alpha', use_gated=args.useGated)(fc2)
    x_mul = multiply([alpha, fc2])

    out = Last_Sigmoid(output_dim=1, name='FC1_sigmoid')(x_mul)

    model = Model(inputs=[data_input], outputs=[out])

    if useMulGpu:
        parallel_model = multi_gpu_model(model, gpus=2)
        parallel_model.compile(optimizer=Adam(lr=lr, beta_1=0.9, beta_2=0.999),
                               loss=bag_loss, metrics=[bag_accuracy])
    else:
        model.compile(optimizer=Adam(lr=lr, beta_1=0.9, beta_2=0.999),
                      loss=bag_loss, metrics=[bag_accuracy])
        parallel_model = model

    return parallel_model
--------------------------------------------------------------------------------
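# Example (not part of the original file): a minimal sketch of building the
# model outside main.py. The Namespace fields mirror parse_args() in main.py,
# and (27, 27, 3) is the patch size used there.
from argparse import Namespace
from utl.Cell_Net import cell_net

args = Namespace(init_lr=1e-4, weight_decay=0.0005, useGated=False)
model = cell_net(input_dim=(27, 27, 3), args=args, useMulGpu=False)
model.summary()  # one bag of N patches is fed as a batch of N inputs
--------------------------------------------------------------------------------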
/utl/DataGenerator.py:
--------------------------------------------------------------------------------
import numpy as np
import random
import threading
from .data_aug_op import random_flip_img, random_rotate_img

class threadsafe_iter(object):
    """
    Takes an iterator/generator and makes it thread-safe by
    serializing calls to the `next` method of the given iterator/generator.
    """
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return next(self.it)

    next = __next__  # Python 2 compatibility

def threadsafe_generator(f):
    """
    A decorator that takes a generator function and makes it thread-safe.
    """
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))
    return g

class DataGenerator(object):
    def __init__(self, batch_size=32, shuffle=True):
        self.shuffle = shuffle
        self.batch_size = batch_size

    def __get_exploration_order(self, list_patient, shuffle):
        indexes = np.arange(len(list_patient))
        if shuffle:
            random.shuffle(indexes)
        return indexes

    def __data_generation(self, batch_train):
        bag_batch = []
        bag_label = []

        for ibatch, batch in enumerate(batch_train):
            aug_batch = []
            img_data = batch[0]
            for i in range(img_data.shape[0]):
                ori_img = img_data[i, :, :, :]
                if self.shuffle:  # augment only in training mode
                    img = random_flip_img(ori_img, horizontal_chance=0.5, vertical_chance=0.5)
                    img = random_rotate_img(img)
                else:
                    img = ori_img
                aug_batch.append(np.expand_dims(img, 0))
            bag_batch.append(np.concatenate(aug_batch))
            bag_label.append(batch[1])

        return bag_batch, bag_label

    def generate(self, train_set):
        flag_train = self.shuffle

        while 1:
            indexes = self.__get_exploration_order(train_set, shuffle=flag_train)

            # generate bag-level batches
            imax = int(len(indexes) / self.batch_size)
            for i in range(imax):
                batch_train_set = [train_set[k] for k in indexes[i * self.batch_size:(i + 1) * self.batch_size]]
                X, y = self.__data_generation(batch_train_set)
                yield X, y
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Attention-based Deep Multiple Instance Learning

Attention-based deep multiple instance learning can be applied to a wide range of medical imaging problems. Supported by the project "[Deep Learning for Survival Prediction](http://ranger.uta.edu/~huang/R_Survival.htm)" @ [UTA-SMILE](http://ranger.uta.edu/~huang/), I wrote a **Keras** version of the ICML 2018 paper "Attention-based Deep Multiple Instance Learning" (https://arxiv.org/pdf/1802.04712.pdf) and share the solution in this repo for Keras users.

The official PyTorch implementation can be found [here](https://github.com/AMLab-Amsterdam/AttentionDeepMIL). I built this version with **Keras** on the TensorFlow backend, implemented the attention layers described in the paper, and ran experiments on colon cancer images with 10-fold cross-validation. The average accuracy is very close to the one reported in the paper, and the visualization results are shown below. Parts of the code are adapted from https://github.com/yanyongluan/MINNs.

When training the model, we only use the image-level label (0 or 1, indicating whether it is a cancer image). The attention layer provides an interpretation of the decision by highlighting only a small subset of positive patches.

---
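To reproduce training, something like the following should work (a sketch: the flags mirror `parse_args()` in `main.py`, and the data path `../data/Patches` is currently hardcoded there):

```
python main.py --lr 1e-4 --decay 0.0005 --momentum 0.9 --epoch 100 --useGated 0
```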
### Results from my implementation

![attention visualization](img80.jpg)

![results](result.png)

### Dataset
- Colon cancer dataset [[Data]](https://warwick.ac.uk/fac/sci/dcs/research/tia/data/crchistolabelednucleihe/)
- Processed patches [[Google Drive]](https://drive.google.com/file/d/1RcNlwg0TwaZoaFO0uMXHFtAo_DCVPE6z/view?usp=sharing)

I put my processed data here; you can also set it up yourself according to the paper, following the directory layout sketched below. If you have any problems, please feel free to contact me.
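The data loaders assume this layout (inferred from the glob patterns in `utl/dataset.py` and `generate_batch` in `main.py`):

```
../data/Patches/
├── 0/          # negative bags
│   └── img*/   # one folder per bag, holding that image's .bmp patches
└── 1/          # positive bags
    └── img*/
```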

---
### Applications

#### The first entry is our own recent work.
|Year|Author list|Title|Conference/Journal|
|---|---|---|---|
|2020|[Jiawen Yao](https://utayao.github.io/), Xinliang Zhu, Jitendra Jonnagaddala, Nicholas Hawkins, [Junzhou Huang](https://ranger.uta.edu/~huang/)|Whole slide images based cancer survival prediction using attention guided deep multiple instance learning networks. [[Pytorch]](https://github.com/uta-smile/DeepAttnMISL)|Medical Image Analysis, 101789, 2020, [[PDF]](https://www.sciencedirect.com/science/article/abs/pii/S1361841520301535?dgcid=rss_sd_all), [[arxiv]](https://arxiv.org/pdf/2009.11169.pdf)|

#### Other important work using multiple-instance learning in medical imaging (this list will be updated frequently)

|Year|Author list|Title|Conference/Journal|
|---|---|---|---|
|2021|Ming Y. Lu, Tiffany Y. Chen, Drew F. K. Williamson, Melissa Zhao, Maha Shady, Jana Lipkova & [Faisal Mahmood](https://faisal.ai/)|AI-based pathology predicts origins for cancers of unknown primary. [[Pytorch]](https://github.com/mahmoodlab/TOAD)|[Nature](https://www.nature.com/articles/s41586-021-03512-4), [arxiv](https://arxiv.org/abs/2006.13932)|
|2021|Ming Y. Lu, Drew F. K. Williamson, Tiffany Y. Chen, Richard J. Chen, Matteo Barbieri & [Faisal Mahmood](https://faisal.ai/)|Data-efficient and weakly supervised computational pathology on whole-slide images. [[Pytorch]](https://github.com/mahmoodlab/CLAM)|[Nature Biomedical Engineering](https://www.nature.com/articles/s41551-020-00682-w), [arxiv](https://arxiv.org/pdf/2004.09666.pdf)|
|2021|Jianan Chen, Helen M. C. Cheung, Laurent Milot and [Anne L. Martel](http://martellab.com/)|AMINN: Autoencoder-based Multiple Instance Neural Network Improves Outcome Prediction of Multifocal Liver Metastases. [[Keras]](https://github.com/martellab-sri/AMINN)|MICCAI 2021, [arxiv](https://arxiv.org/pdf/2012.06875.pdf)|
|2020|Ole-Johan Skrede et al.|Deep learning for prediction of colorectal cancer outcome: a discovery and validation study|[Lancet](https://www.thelancet.com/journals/lancet/article/PIIS0140-6736(19)32998-8/fulltext)|
|2019|[Shujun Wang](https://emma-sjwang.github.io/), Yaxi Zhu, et al.|RMDL: Recalibrated multi-instance deep learning for whole slide gastric image classification. [[Keras]](https://github.com/EmmaW8/RMDL)|Medical Image Analysis, [arxiv](https://arxiv.org/abs/2010.06440)|


---
### Contact
If you have any questions about this code, I am happy to answer via issues or email (yjiaweneecs@gmail.com).

I plan to review recent work using deep MIL techniques in medical imaging, and your suggestions are very welcome!

### Acknowledgments
--------------------
The work conducted by [Jiawen Yao](https://utayao.github.io/) was funded by grants from the [UTA-SMILE Lab](https://github.com/uta-smile).

--------------------------------------------------------------------------------
/utl/custom_layers.py:
--------------------------------------------------------------------------------
from keras.layers import Layer
from keras import backend as K
from keras import initializers, regularizers

class Mil_Attention(Layer):
    """
    MIL attention mechanism.

    Computes one attention weight per instance of the bag, where the bag's
    instances are laid out along the batch dimension.

    # Input Shape
        2D tensor with shape: (n_instances, input_dim)

    # Output Shape
        2D tensor with shape: (n_instances, output_dim)
    """

    def __init__(self, L_dim, output_dim, kernel_initializer='glorot_uniform', kernel_regularizer=None,
                 use_bias=True, use_gated=False, **kwargs):
        self.L_dim = L_dim
        self.output_dim = output_dim
        self.use_bias = use_bias  # accepted for API symmetry, not used by this layer
        self.use_gated = use_gated

        self.v_init = initializers.get(kernel_initializer)
        self.w_init = initializers.get(kernel_initializer)
        self.u_init = initializers.get(kernel_initializer)

        self.v_regularizer = regularizers.get(kernel_regularizer)
        self.w_regularizer = regularizers.get(kernel_regularizer)
        self.u_regularizer = regularizers.get(kernel_regularizer)

        super(Mil_Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        assert len(input_shape) == 2
        input_dim = input_shape[1]

        self.V = self.add_weight(shape=(input_dim, self.L_dim),
                                 initializer=self.v_init,
                                 name='v',
                                 regularizer=self.v_regularizer,
                                 trainable=True)

        self.w = self.add_weight(shape=(self.L_dim, 1),
                                 initializer=self.w_init,
                                 name='w',
                                 regularizer=self.w_regularizer,
                                 trainable=True)

        if self.use_gated:
            self.U = self.add_weight(shape=(input_dim, self.L_dim),
                                     initializer=self.u_init,
                                     name='U',
                                     regularizer=self.u_regularizer,
                                     trainable=True)
        else:
            self.U = None

        super(Mil_Attention, self).build(input_shape)

    def call(self, x, mask=None):
        ori_x = x
        # tanh(V h_k^T): (n_instances, L_dim)
        x = K.tanh(K.dot(x, self.V))

        if self.use_gated:
            # gated attention: tanh(V h) * sigm(U h)
            gate_x = K.sigmoid(K.dot(ori_x, self.U))
            ac_x = x * gate_x
        else:
            ac_x = x

        # w^T (.): (n_instances, L_dim) x (L_dim, 1) -> (n_instances, 1)
        soft_x = K.dot(ac_x, self.w)
        # softmax over the instances of the bag
        alpha = K.softmax(K.transpose(soft_x))
        alpha = K.transpose(alpha)
        return alpha

    def compute_output_shape(self, input_shape):
        shape = list(input_shape)
        assert len(shape) == 2
        shape[1] = self.output_dim
        return tuple(shape)

    def get_config(self):
        config = {
            'L_dim': self.L_dim,
            'output_dim': self.output_dim,
            'v_initializer': initializers.serialize(self.v_init),
            'w_initializer': initializers.serialize(self.w_init),
            'v_regularizer': regularizers.serialize(self.v_regularizer),
            'w_regularizer': regularizers.serialize(self.w_regularizer),
            'use_bias': self.use_bias,
            'use_gated': self.use_gated
        }
        base_config = super(Mil_Attention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class Last_Sigmoid(Layer):
    """
    MIL pooling + sigmoid output.

    A fully connected layer with a single neuron and sigmoid activation,
    preceded by MIL pooling: the instance representations are summed into a
    bag representation, which this layer maps to the bag score.
    This layer is used in mi-Net.

    # Arguments
        output_dim: Positive integer, dimensionality of the output space
        kernel_initializer: Initializer of the `kernel` weights matrix
        bias_initializer: Initializer of the `bias` weights
        kernel_regularizer: Regularizer function applied to the `kernel` weights matrix
        bias_regularizer: Regularizer function applied to the `bias` weights
        use_bias: Boolean, whether to use a bias

    # Input shape
        2D tensor with shape: (n_instances, input_dim)
    # Output shape
        2D tensor with shape: (1, output_dim)
    """
    def __init__(self, output_dim, kernel_initializer='glorot_uniform', bias_initializer='zeros',
                 kernel_regularizer=None, bias_regularizer=None,
                 use_bias=True, **kwargs):
        self.output_dim = output_dim

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

        self.use_bias = use_bias
        super(Last_Sigmoid, self).__init__(**kwargs)

    def build(self, input_shape):
        assert len(input_shape) == 2
        input_dim = input_shape[1]

        self.kernel = self.add_weight(shape=(input_dim, self.output_dim),
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer)

        if self.use_bias:
            self.bias = self.add_weight(shape=(self.output_dim,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer)
        else:
            self.bias = None

        super(Last_Sigmoid, self).build(input_shape)

    def call(self, x, mask=None):
        # MIL pooling: sum the (attention-weighted) instance features into one bag feature
        x = K.sum(x, axis=0, keepdims=True)
        # bag-level score
        x = K.dot(x, self.kernel)
        if self.use_bias:
            x = K.bias_add(x, self.bias)

        out = K.sigmoid(x)
        return out

    def compute_output_shape(self, input_shape):
        shape = list(input_shape)
        assert len(shape) == 2
        shape[1] = self.output_dim
        return tuple(shape)

    def get_config(self):
        config = {
            'output_dim': self.output_dim,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'use_bias': self.use_bias
        }
        base_config = super(Last_Sigmoid, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
--------------------------------------------------------------------------------
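# Example (not part of the original file): a minimal sketch of the two layers
# on their own, treating a bag of N instance feature vectors as the batch.
# The dimensions are illustrative.
import numpy as np
from keras.models import Model
from keras.layers import Input, multiply
from utl.custom_layers import Mil_Attention, Last_Sigmoid

feat = Input(shape=(512,))                             # one row per instance in the bag
alpha = Mil_Attention(L_dim=128, output_dim=1)(feat)   # (N, 1) attention weights
weighted = multiply([alpha, feat])                     # attention-weighted instance features
score = Last_Sigmoid(output_dim=1)(weighted)           # (1, 1) bag score
model = Model(inputs=feat, outputs=score)

bag = np.random.rand(7, 512).astype('float32')         # a bag with 7 instances
print(model.predict_on_batch(bag))                     # a single bag-level probability
--------------------------------------------------------------------------------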
/main.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
'''
This is a re-implementation of the following paper:
"Attention-based Deep Multiple Instance Learning"
https://arxiv.org/pdf/1802.04712.pdf
I got very similar results, although some data augmentation techniques
from the paper are not used here.
*---- Jiawen Yao --------------*
'''

import numpy as np
import time
import argparse
import glob

import scipy.misc as sci  # NOTE: scipy.misc.imread needs scipy < 1.2 and Pillow
import matplotlib.pyplot as plt

from keras.callbacks import ModelCheckpoint, EarlyStopping

from utl import Cell_Net
from utl.dataset import load_dataset


def parse_args():
    """Parse input arguments.

    Returns
    -------------------
    args : argparse.Namespace class object
        An argparse.Namespace object containing the experimental hyper-parameters.
    """
    parser = argparse.ArgumentParser(description='Train an Attention-based Deep MIL model')
    parser.add_argument('--lr', dest='init_lr',
                        help='initial learning rate',
                        default=1e-4, type=float)
    parser.add_argument('--decay', dest='weight_decay',
                        help='weight decay',
                        default=0.0005, type=float)
    parser.add_argument('--momentum', dest='momentum',
                        help='momentum',
                        default=0.9, type=float)
    parser.add_argument('--epoch', dest='max_epoch',
                        help='number of epochs to train',
                        default=100, type=int)
    parser.add_argument('--useGated', dest='useGated',
                        help='use gated attention',
                        default=False, type=int)

    args = parser.parse_args()
    return args

def generate_batch(path):
    """Load each bag directory into a (instances, instance_labels, names) tuple."""
    bags = []
    for each_path in path:
        name_img = []
        img = []
        img_path = glob.glob(each_path + '/*.bmp')
        num_ins = len(img_path)

        # the class folder ('0' or '1') gives the bag label,
        # which is copied to every instance of the bag
        label = int(each_path.split('/')[-2])

        if label == 1:
            curr_label = np.ones(num_ins, dtype=np.uint8)
        else:
            curr_label = np.zeros(num_ins, dtype=np.uint8)
        for each_img in img_path:
            img_data = np.asarray(sci.imread(each_img), dtype=np.float32)
            # subtract the ImageNet channel means, then rescale
            img_data[:, :, 0] -= 123.68
            img_data[:, :, 1] -= 116.779
            img_data[:, :, 2] -= 103.939
            img_data /= 255
            img.append(np.expand_dims(img_data, 0))
            name_img.append(each_img.split('/')[-1])
        stack_img = np.concatenate(img, axis=0)
        bags.append((stack_img, curr_label, name_img))

    return bags
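# Example (a sketch, not executed): shape check for one bag; the path is a
# hypothetical positive bag under ../data/Patches.
#
#   bags = generate_batch(['../data/Patches/1/img1'])
#   stack_img, instance_labels, names = bags[0]
#   print(stack_img.shape)      # (num_instances, 27, 27, 3): one bag = one batch
#   print(instance_labels[:5])  # instance labels inherit the bag label (all 1s here)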

def Get_train_valid_Path(Train_set, train_percentage=0.8):
    """Randomly split the training bags into model-train and model-validation parts.

    :param Train_set: list of training bags
    :param train_percentage: fraction of bags used for model training
    :return: (Model_Train, Model_Val)
    """
    import random
    indexes = np.arange(len(Train_set))
    random.shuffle(indexes)

    num_train = int(train_percentage * len(Train_set))
    train_index, test_index = np.asarray(indexes[:num_train]), np.asarray(indexes[num_train:])

    Model_Train = [Train_set[i] for i in train_index]
    Model_Val = [Train_set[j] for j in test_index]

    return Model_Train, Model_Val


def test_eval(model, test_set):
    """Evaluate on the testing set.

    Parameters
    -----------------
    model : keras.engine.training.Model object
        The trained mi-Cell-Net model.
    test_set : list
        A list containing all testing bags' features and labels.

    Returns
    -----------------
    test_loss : float
        Mean loss over the testing set.
    test_acc : float
        Mean accuracy over the testing set.
    """
    num_test_batch = len(test_set)
    test_loss = np.zeros((num_test_batch, 1), dtype=float)
    test_acc = np.zeros((num_test_batch, 1), dtype=float)
    for ibatch, batch in enumerate(test_set):
        result = model.test_on_batch(x=batch[0], y=batch[1])
        test_loss[ibatch] = result[0]
        test_acc[ibatch] = result[1]
    return np.mean(test_loss), np.mean(test_acc)

def train_eval(model, train_set, irun, ifold):
    """Train the model with Keras fit_generator and keep the best checkpoint.

    Parameters
    -----------------
    model : keras.engine.training.Model object
        The mi-Cell-Net model to train.
    train_set : list
        A list containing all training bags' features and labels.

    Returns
    -----------------
    model_name : str
        Filename of the saved model with the lowest validation loss.
    """
    batch_size = 1
    model_train_set, model_val_set = Get_train_valid_Path(train_set, train_percentage=0.9)

    from utl.DataGenerator import DataGenerator
    train_gen = DataGenerator(batch_size=1, shuffle=True).generate(model_train_set)
    val_gen = DataGenerator(batch_size=1, shuffle=False).generate(model_val_set)

    # NOTE: the Saved_model/ and Results/ directories are assumed to exist
    model_name = "Saved_model/" + "_Batch_size_" + str(batch_size) + "epoch_" + "best.hd5"

    checkpoint_fixed_name = ModelCheckpoint(model_name,
                                            monitor='val_loss', verbose=1, save_best_only=True,
                                            save_weights_only=True, mode='auto', period=1)

    EarlyStop = EarlyStopping(monitor='val_loss', patience=20)

    callbacks = [checkpoint_fixed_name, EarlyStop]

    # args is the global namespace parsed in __main__
    history = model.fit_generator(generator=train_gen, steps_per_epoch=len(model_train_set) // batch_size,
                                  epochs=args.max_epoch, validation_data=val_gen,
                                  validation_steps=len(model_val_set) // batch_size, callbacks=callbacks)

    train_loss = history.history['loss']
    val_loss = history.history['val_loss']

    train_acc = history.history['bag_accuracy']
    val_acc = history.history['val_bag_accuracy']

    fig = plt.figure()
    plt.plot(train_loss)
    plt.plot(val_loss)
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    save_fig_name = 'Results/' + str(irun) + '_' + str(ifold) + "_loss_batchsize_" + str(batch_size) + "_epoch" + ".png"
    fig.savefig(save_fig_name)

    fig = plt.figure()
    plt.plot(train_acc)
    plt.plot(val_acc)
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    save_fig_name = 'Results/' + str(irun) + '_' + str(ifold) + "_val_batchsize_" + str(batch_size) + "_epoch" + ".png"
    fig.savefig(save_fig_name)

    return model_name

def model_training(input_dim, dataset, irun, ifold):

    train_bags = dataset['train']
    test_bags = dataset['test']

    # convert bags to batches
    train_set = generate_batch(train_bags)
    test_set = generate_batch(test_bags)

    model = Cell_Net.cell_net(input_dim, args, useMulGpu=False)

    # train the model, then reload the best checkpoint for testing
    t1 = time.time()
    model_name = train_eval(model, train_set, irun, ifold)

    print("load saved model weights")
    model.load_weights(model_name)

    test_loss, test_acc = test_eval(model, test_set)

    t2 = time.time()

    print('run time:', (t2 - t1) / 60.0, 'min')
    print('test_acc={:.3f}'.format(test_acc))

    return test_acc


if __name__ == "__main__":

    args = parse_args()

    print('Called with args:')
    print(args)

    input_dim = (27, 27, 3)

    run = 1
    n_folds = 10
    acc = np.zeros((run, n_folds), dtype=float)
    data_path = '../data/Patches'

    for irun in range(run):
        dataset = load_dataset(dataset_path=data_path, n_folds=n_folds, rand_state=irun)
        for ifold in range(n_folds):
            print('run=', irun, ' fold=', ifold)
            acc[irun][ifold] = model_training(input_dim, dataset[ifold], irun, ifold)
    print('mean accuracy = ', np.mean(acc))
    print('std = ', np.std(acc))
--------------------------------------------------------------------------------