├── img80.jpg
├── result.png
├── utl
│   ├── __init__.py
│   ├── dataset.py
│   ├── metrics.py
│   ├── data_aug_op.py
│   ├── Cell_Net.py
│   ├── DataGenerator.py
│   └── custom_layers.py
├── README.md
└── main.py
/img80.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utayao/Atten_Deep_MIL/HEAD/img80.jpg
--------------------------------------------------------------------------------
/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utayao/Atten_Deep_MIL/HEAD/result.png
--------------------------------------------------------------------------------
/utl/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from . import dataset
3 | from . import Cell_Net
4 | from . import custom_layers
5 | from . import metrics
--------------------------------------------------------------------------------
/utl/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import glob
3 | from sklearn.model_selection import KFold
4 |
5 | def load_dataset(dataset_path, n_folds, rand_state):
   6 |     """
   7 |     Parameters
   8 |     --------------------
   9 |     :param dataset_path: path to the patch folders (expects <path>/1/img* and <path>/0/img*)
  10 |     :param n_folds: number of cross-validation folds
  11 |     :param rand_state: random seed for the K-Fold shuffle
  12 |     :return: list of dicts, each with 'train' and 'test' path splits for K-Fold cross-validation
  13 |     """
14 |
15 | # load datapath from path
16 | pos_path = glob.glob(dataset_path+'/1/img*')
17 | neg_path = glob.glob(dataset_path+'/0/img*')
18 |
19 | pos_num = len(pos_path)
20 | neg_num = len(neg_path)
21 |
22 | all_path = pos_path + neg_path
23 |
24 | #num_bag = len(all_path)
25 | kf = KFold(n_splits=n_folds, shuffle=True, random_state=rand_state)
26 | datasets = []
27 | for train_idx, test_idx in kf.split(all_path):
28 | dataset = {}
29 | dataset['train'] = [all_path[ibag] for ibag in train_idx]
30 | dataset['test'] = [all_path[ibag] for ibag in test_idx]
31 | datasets.append(dataset)
32 | return datasets
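  33 |
  34 | if __name__ == '__main__':
  35 |     # Minimal usage sketch (added for illustration); assumes the patch layout
  36 |     # expected above and used by main.py: <dataset_path>/1/img* and <dataset_path>/0/img*.
  37 |     folds = load_dataset('../data/Patches', n_folds=10, rand_state=0)
  38 |     print(len(folds), 'folds;', len(folds[0]['train']), 'train /', len(folds[0]['test']), 'test bags in fold 0')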
--------------------------------------------------------------------------------
/utl/metrics.py:
--------------------------------------------------------------------------------
1 | from keras import backend as K
2 |
3 | def bag_accuracy(y_true, y_pred):
4 | """Compute accuracy of one bag.
5 | Parameters
6 | ---------------------
7 | y_true : Tensor (N x 1)
8 | GroundTruth of bag.
9 | y_pred : Tensor (1 X 1)
10 | Prediction score of bag.
11 | Return
12 | ---------------------
13 | acc : Tensor (1 x 1)
14 | Accuracy of bag label prediction.
15 | """
16 | y_true = K.mean(y_true, axis=0, keepdims=False)
17 | y_pred = K.mean(y_pred, axis=0, keepdims=False)
18 | acc = K.mean(K.equal(y_true, K.round(y_pred)))
19 | return acc
20 |
21 |
22 | def bag_loss(y_true, y_pred):
  23 | def bag_loss(y_true, y_pred):
  24 |     """Compute binary crossentropy loss of predicting bag label.
24 | Parameters
25 | ---------------------
26 | y_true : Tensor (N x 1)
27 | GroundTruth of bag.
28 | y_pred : Tensor (1 X 1)
29 | Prediction score of bag.
30 | Return
31 | ---------------------
  32 |     loss : Tensor (1 x 1)
  33 |         Binary Crossentropy loss of predicting bag label.
34 | """
35 | y_true = K.mean(y_true, axis=0, keepdims=False)
36 | y_pred = K.mean(y_pred, axis=0, keepdims=False)
37 | loss = K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1)
38 | return loss
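  39 |
  40 | if __name__ == '__main__':
  41 |     # Minimal sanity check (added for illustration): a positive bag of three
  42 |     # instances with a bag score of 0.8 should be counted as correct.
  43 |     import numpy as np
  44 |     y_true = K.variable(np.ones((3, 1)))      # instance-level copies of the bag label
  45 |     y_pred = K.variable(np.array([[0.8]]))    # single bag-level prediction score
  46 |     print('bag_accuracy:', K.eval(bag_accuracy(y_true, y_pred)))  # -> 1.0
  47 |     print('bag_loss:', K.eval(bag_loss(y_true, y_pred)))          # -> -log(0.8) ~= 0.223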
--------------------------------------------------------------------------------
/utl/data_aug_op.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import cv2
4 |
5 | def random_flip_img(img, horizontal_chance=0, vertical_chance=0):
6 | flip_horizontal = False
7 | if random.random() < horizontal_chance:
8 | flip_horizontal = True
9 |
10 | flip_vertical = False
11 | if random.random() < vertical_chance:
12 | flip_vertical = True
13 |
14 | if not flip_horizontal and not flip_vertical:
15 | return img
16 |
17 | flip_val = 1
18 | if flip_vertical:
19 | flip_val = -1 if flip_horizontal else 0
20 |
21 | if not isinstance(img, list):
22 | res = cv2.flip(img, flip_val) # 0 = X axis, 1 = Y axis, -1 = both
23 | else:
24 | res = []
25 | for img_item in img:
26 | img_flip = cv2.flip(img_item, flip_val)
27 | res.append(img_flip)
28 | return res
29 |
  30 | def random_rotate_img(images):
  31 |     rand_rot = np.random.randint(4)  # random multiple of 90 degrees
  32 |     angle = 90 * rand_rot
  33 |     center = (images.shape[1] / 2, images.shape[0] / 2)  # (x, y) image center
  34 |     rot_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
  35 |
  36 |     img_inst = cv2.warpAffine(images, rot_matrix, dsize=(images.shape[1], images.shape[0]), borderMode=cv2.BORDER_CONSTANT)
  37 |
  38 |     return img_inst
39 |
  40 | def random_crop(image, crop_size=(400, 400)):
  41 |     height, width = image.shape[:-1]
  42 |     dy, dx = crop_size
  43 |     if width < dx or height < dy:
  44 |         return None
  45 |     x = np.random.randint(0, width - dx + 1)
  46 |     y = np.random.randint(0, height - dy + 1)
  47 |     return image[y:(y + dy), x:(x + dx), :]
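  48 |
  49 | if __name__ == '__main__':
  50 |     # Minimal sanity check (added for illustration) on random RGB images.
  51 |     dummy = np.random.rand(27, 27, 3).astype(np.float32)
  52 |     flipped = random_flip_img(dummy, horizontal_chance=0.5, vertical_chance=0.5)
  53 |     rotated = random_rotate_img(dummy)
  54 |     cropped = random_crop(np.random.rand(500, 500, 3), crop_size=(400, 400))
  55 |     print(np.shape(flipped), rotated.shape, cropped.shape)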
--------------------------------------------------------------------------------
/utl/Cell_Net.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 | from random import shuffle
4 | import numpy as np
5 | import argparse
6 | import tensorflow as tf
7 |
8 | from keras.utils import multi_gpu_model
9 | from keras.models import Model
10 | from keras.optimizers import SGD,Adam
11 | from keras.regularizers import l2
12 | from keras.layers import Input, Dense, Layer, Dropout, Conv2D, MaxPooling2D, Flatten, multiply
13 | from .metrics import bag_accuracy, bag_loss
14 | from .custom_layers import Mil_Attention, Last_Sigmoid
15 |
16 | def cell_net(input_dim, args, useMulGpu=False):
17 |
18 | lr = args.init_lr
  19 |     weight_decay = args.weight_decay
20 | momentum = args.momentum
21 |
22 | data_input = Input(shape=input_dim, dtype='float32', name='input')
23 | conv1 = Conv2D(36, kernel_size=(4,4), kernel_regularizer=l2(weight_decay), activation='relu')(data_input)
24 | conv1 = MaxPooling2D((2,2))(conv1)
25 |
26 | conv2 = Conv2D(48, kernel_size=(3,3), kernel_regularizer=l2(weight_decay), activation='relu')(conv1)
27 | conv2 = MaxPooling2D((2,2))(conv2)
28 | x = Flatten()(conv2)
29 |
30 | fc1 = Dense(512, activation='relu',kernel_regularizer=l2(weight_decay), name='fc1')(x)
31 | fc1 = Dropout(0.5)(fc1)
32 | fc2 = Dense(512, activation='relu', kernel_regularizer=l2(weight_decay), name='fc2')(fc1)
33 | fc2 = Dropout(0.5)(fc2)
34 |
35 | # fp = Feature_pooling(output_dim=1, kernel_regularizer=l2(0.0005), pooling_mode='max',
36 | # name='fp')(fc2)
37 |
38 | alpha = Mil_Attention(L_dim=128, output_dim=1, kernel_regularizer=l2(weight_decay), name='alpha', use_gated=args.useGated)(fc2)
39 | x_mul = multiply([alpha, fc2])
40 |
41 | out = Last_Sigmoid(output_dim=1, name='FC1_sigmoid')(x_mul)
42 | #
43 | model = Model(inputs=[data_input], outputs=[out])
44 |
45 | # model.summary()
46 |
  47 |     if useMulGpu:
48 | parallel_model = multi_gpu_model(model, gpus=2)
49 | parallel_model.compile(optimizer=Adam(lr=lr, beta_1=0.9, beta_2=0.999), loss=bag_loss, metrics=[bag_accuracy])
50 | else:
51 | model.compile(optimizer=Adam(lr=lr, beta_1=0.9, beta_2=0.999), loss=bag_loss, metrics=[bag_accuracy])
52 | parallel_model = model
53 |
54 | return parallel_model
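  55 |
  56 |
  57 | if __name__ == '__main__':
  58 |     # Minimal build sketch (added for illustration); run from the repo root with
  59 |     # `python -m utl.Cell_Net` because of the relative imports above. An
  60 |     # argparse.Namespace stands in for the command-line args built in main.py.
  61 |     hparams = argparse.Namespace(init_lr=1e-4, weight_decay=0.0005, momentum=0.9, useGated=False)
  62 |     model = cell_net((27, 27, 3), hparams, useMulGpu=False)
  63 |     model.summary()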
--------------------------------------------------------------------------------
/utl/DataGenerator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import threading
4 | from .data_aug_op import random_flip_img, random_rotate_img
5 | #from keras.preprocessing.image import ImageDataGenerator
6 | import scipy.misc as sci
7 |
8 | class threadsafe_iter(object):
9 | """
10 | Takes an iterator/generator and makes it thread-safe by
11 | serializing call to the `next` method of given iterator/generator.
12 | """
13 | def __init__(self, it):
14 | self.it = it
15 | self.lock = threading.Lock()
16 |
17 | def __iter__(self):
18 | return self
19 |
  20 |     def __next__(self):  # Python 3 iterator protocol; aliased as `next` below
  21 |         with self.lock:
  22 |             return next(self.it)
  23 |     next = __next__  # Python 2 compatibility
24 | def threadsafe_generator(f):
25 | """
26 | A decorator that takes a generator function and makes it thread-safe.
27 | """
28 | def g(*a, **kw):
29 | return threadsafe_iter(f(*a, **kw))
30 | return g
31 |
32 | class DataGenerator(object):
33 | def __init__(self, batch_size=32, shuffle=True):
34 | self.shuffle = shuffle
35 | self.batch_size = batch_size
36 |
37 | def __Get_exploration_order(self, list_patient, shuffle):
38 | indexes = np.arange(len(list_patient))
39 | if shuffle:
40 | random.shuffle(indexes)
41 | return indexes
42 |
  43 |     def __Data_Generation(self, batch_train):
44 | bag_batch = []
45 | bag_label = []
46 |
47 | #datagen = ImageDataGenerator()
48 |
49 | #transforms = {
50 | # "theta" : 0.25,
51 | # "tx" : 0.2,
52 | # "ty" : 0.2,
53 | # "shear" : 0.2,
54 | # "zx" : 0.2,
55 | # "zy" : 0.2,
56 | # "flip_horizontal" : True,
57 | # "zx" : 0.2,
58 | # "zy" : 0.2,
59 | #}
60 |
61 | for ibatch, batch in enumerate(batch_train):
62 | aug_batch = []
63 | img_data = batch[0]
64 | for i in range(img_data.shape[0]):
65 | ori_img = img_data[i, :, :, :]
66 | # sci.imshow(ori_img)
67 | if self.shuffle:
68 | img = random_flip_img(ori_img, horizontal_chance=0.5, vertical_chance=0.5)
69 | img = random_rotate_img(img)
70 | #img = datagen.apply_transform(ori_img,transforms)
71 | else:
72 | img = ori_img
73 | exp_img = np.expand_dims(img, 0)
74 | # sci.imshow(img)
75 | aug_batch.append(exp_img)
76 | input_batch = np.concatenate(aug_batch)
77 | bag_batch.append((input_batch))
78 | bag_label.append(batch[1])
79 |
80 | return bag_batch, bag_label
81 |
82 |
83 | def generate(self, train_set):
84 | flag_train = self.shuffle
85 |
86 | while 1:
87 |
88 | # status_list = np.zeros(batch_size)
89 | # status_list = []
90 | indexes = self.__Get_exploration_order(train_set, shuffle=flag_train)
91 |
92 | # Generate batches
93 | imax = int(len(indexes) / self.batch_size)
94 |
95 | for i in range(imax):
96 | Batch_train_set = [train_set[k] for k in indexes[i * self.batch_size:(i + 1) * self.batch_size]]
97 |
  98 |                 X, y = self.__Data_Generation(Batch_train_set)
99 |
100 | yield X, y
101 |
102 |
103 | # batch_train_set = Generate_Batch_Set(train_set, batch_size, flag_train) # Get small batch from the original set
104 |
105 |
106 | # print img_list_1[0].shape
107 | # yield img_list, status_list
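 108 |
 109 | if __name__ == '__main__':
 110 |     # Minimal usage sketch (added for illustration); run from the repo root with
 111 |     # `python -m utl.DataGenerator` because of the relative import above.
 112 |     # One dummy bag shaped like the output of generate_batch() in main.py:
 113 |     # (instance stack, instance labels, instance names).
 114 |     dummy_bag = (np.random.rand(5, 27, 27, 3).astype(np.float32),
 115 |                  np.ones(5, dtype=np.uint8),
 116 |                  ['img%d.bmp' % i for i in range(5)])
 117 |     gen = DataGenerator(batch_size=1, shuffle=True).generate([dummy_bag])
 118 |     X, y = next(gen)
 119 |     print(len(X), X[0].shape, y[0].shape)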
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Attention-based Deep Multiple Instance Learning
2 |
   3 | Attention-based deep multiple instance learning can be applied to a wide range of medical imaging problems. Supported by the project "[Deep Learning for Survival Prediction](http://ranger.uta.edu/~huang/R_Survival.htm)"@[UTA-SMILE](http://ranger.uta.edu/~huang/), I wrote a **Keras** version of the ICML 2018 paper "Attention-based Deep Multiple Instance Learning" (https://arxiv.org/pdf/1802.04712.pdf) and share it in this repo for Keras users.
4 |
   5 | The official PyTorch implementation can be found [here](https://github.com/AMLab-Amsterdam/AttentionDeepMIL). I built this version with **Keras** on the TensorFlow backend, implemented the attention layers described in the paper, and ran experiments on colon cancer images with 10-fold cross-validation. The average accuracy I obtained is very close to the one reported in the paper, and visualization results are shown below. Parts of the code are adapted from https://github.com/yanyongluan/MINNs.
6 |
   7 | When training the model, we use only the image-level label (0 or 1, indicating whether it is a cancer image). The attention layer provides an interpretation of each decision by highlighting a small subset of positive patches; an illustrative sketch of the attention pooling is given at the end of this README.
8 |
9 | ---
10 |
11 | ### Results from my implementation
12 |
  13 |
  14 | ![Attention visualization example](img80.jpg)
  15 | ![Results](result.png)
  16 |
17 | ### Dataset
18 | - Colon cancer dataset [[Data]](https://warwick.ac.uk/fac/sci/dcs/research/tia/data/crchistolabelednucleihe/)
19 | - Processed patches [[Google Drive]](https://drive.google.com/file/d/1RcNlwg0TwaZoaFO0uMXHFtAo_DCVPE6z/view?usp=sharing)
20 |
  21 | My processed data is available at the link above, and you can also prepare the patches yourself as described in the paper. If you have any problems, please feel free to contact me.
22 |
23 | ---
24 | ### Applications
25 |
  26 | #### The first entry is our recent work.
27 | |Year|Author list|Title|Conference/Journal|
28 | |---|---|---|---|
29 | |2020|[Jiawen Yao](https://utayao.github.io/), Xinliang Zhu, Jitendra Jonnagaddala, Nicholas Hawkins, [Junzhou Huang](https://ranger.uta.edu/~huang/)|Whole slide images based cancer survival prediction using attention guided deep multiple instance learning networks. [[Pytorch]](https://github.com/uta-smile/DeepAttnMISL) | Medical Image Analysis, 101789, 2020, [[PDF]](https://www.sciencedirect.com/science/article/abs/pii/S1361841520301535?dgcid=rss_sd_all), [[arxiv]](https://arxiv.org/pdf/2009.11169.pdf)|
30 |
31 |
32 |
33 |
34 |
35 |
  36 | #### Other important work applying multiple-instance learning to medical imaging (this list will be updated frequently)
37 |
38 | |Year|Author list|Title|Conference/Journal|
39 | |---|---|---|---|
40 | |2021|Ming Y. Lu, Tiffany Y. Chen, Drew F. K. Williamson, Melissa Zhao, Maha Shady, Jana Lipkova & [Faisal Mahmood](https://faisal.ai/)|AI-based pathology predicts origins for cancers of unknown primary. [[Pytorch]](https://github.com/mahmoodlab/TOAD) | [Nature](https://www.nature.com/articles/s41586-021-03512-4), [arxiv](https://arxiv.org/abs/2006.13932)|
41 | |2021|Ming Y. Lu, Drew F. K. Williamson, Tiffany Y. Chen, Richard J. Chen, Matteo Barbieri & [Faisal Mahmood](https://faisal.ai/)|Data-efficient and weakly supervised computational pathology on whole-slide images. [[Pytorch]](https://github.com/mahmoodlab/CLAM) | [Nature Biomedical Engineering](https://www.nature.com/articles/s41551-020-00682-w), [arxiv](https://arxiv.org/pdf/2004.09666.pdf)|
42 | |2021|Jianan Chen, Helen M. C. Cheung, Laurent Milot and [Anne L. Martel](http://martellab.com/)|AMINN: Autoencoder-based Multiple Instance Neural Network Improves Outcome Prediction of Multifocal Liver Metastases. [[Keras]](https://github.com/martellab-sri/AMINN) | MICCAI 2021 [arxiv](https://arxiv.org/pdf/2012.06875.pdf)|
43 | |2020|Ole-Johan Skrede et al.|Deep learning for prediction of colorectal cancer outcome: a discovery and validation study|[Lancet](https://www.thelancet.com/journals/lancet/article/PIIS0140-6736(19)32998-8/fulltext)|
44 | |2019|[Shujun Wang](https://emma-sjwang.github.io/), Yaxi Zhu, et al.|RMDL: Recalibrated multi-instance deep learning for whole slide gastric image classification [[Keras]](https://github.com/EmmaW8/RMDL)|Medical Image Analysis [arxiv](https://arxiv.org/abs/2010.06440)|
45 |
46 |
47 | ---
48 | ### Contact
49 | If you have any questions about this code, I am happy to answer your issues or emails (to yjiaweneecs@gmail.com).
50 |
  51 | I plan to review recent work using deep MIL techniques in medical imaging, and your suggestions are very welcome!
52 |
  53 | ### Acknowledgments
  54 | The work conducted by [Jiawen Yao](https://utayao.github.io/) was funded by grants from the [UTA-SMILE Lab](https://github.com/uta-smile).
  55 |
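  56 | ### Attention pooling at a glance
  57 | For reference, here is an illustrative NumPy sketch (not part of the training code) of the ungated attention pooling implemented in `utl/custom_layers.py` and used in `utl/Cell_Net.py`. `H` holds the N x d instance features; `V` and `w` are the learned attention parameters.
  58 |
  59 | ```python
  60 | import numpy as np
  61 |
  62 | def attention_pool(H, V, w):
  63 |     scores = np.tanh(H @ V) @ w                # (N, 1) per-instance attention scores
  64 |     a = np.exp(scores) / np.exp(scores).sum()  # softmax over the N instances
  65 |     return (a * H).sum(axis=0), a              # bag embedding and attention weights
  66 | ```
  67 |
  68 | The bag embedding is then mapped to a bag-level probability by the single-neuron sigmoid layer (`Last_Sigmoid`).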
--------------------------------------------------------------------------------
/utl/custom_layers.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Layer
2 | from keras import backend as K
3 | from keras import activations, initializers, regularizers
4 |
5 | class Mil_Attention(Layer):
6 | """
7 | Mil Attention Mechanism
8 |
9 | This layer contains Mil Attention Mechanism
10 |
11 | # Input Shape
12 | 2D tensor with shape: (batch_size, input_dim)
13 |
14 | # Output Shape
  15 |         2D tensor with shape: (batch_size, output_dim), one attention weight per instance
16 | """
17 |
18 | def __init__(self, L_dim, output_dim, kernel_initializer='glorot_uniform', kernel_regularizer=None,
19 | use_bias=True, use_gated=False, **kwargs):
20 | self.L_dim = L_dim
21 | self.output_dim = output_dim
22 | self.use_bias = use_bias
23 | self.use_gated = use_gated
24 |
25 | self.v_init = initializers.get(kernel_initializer)
26 | self.w_init = initializers.get(kernel_initializer)
27 | self.u_init = initializers.get(kernel_initializer)
28 |
29 |
30 | self.v_regularizer = regularizers.get(kernel_regularizer)
31 | self.w_regularizer = regularizers.get(kernel_regularizer)
32 | self.u_regularizer = regularizers.get(kernel_regularizer)
33 |
34 | super(Mil_Attention, self).__init__(**kwargs)
35 |
36 | def build(self, input_shape):
37 |
38 | assert len(input_shape) == 2
39 | input_dim = input_shape[1]
40 |
41 | self.V = self.add_weight(shape=(input_dim, self.L_dim),
42 | initializer=self.v_init,
43 | name='v',
44 | regularizer=self.v_regularizer,
45 | trainable=True)
46 |
47 |
48 | self.w = self.add_weight(shape=(self.L_dim, 1),
49 | initializer=self.w_init,
50 | name='w',
51 | regularizer=self.w_regularizer,
52 | trainable=True)
53 |
54 |
55 | if self.use_gated:
56 | self.U = self.add_weight(shape=(input_dim, self.L_dim),
57 | initializer=self.u_init,
58 | name='U',
59 | regularizer=self.u_regularizer,
60 | trainable=True)
61 | else:
62 | self.U = None
63 |
  64 |         self.built = True
65 |
66 |
67 | def call(self, x, mask=None):
68 | n, d = x.shape
69 | ori_x = x
70 | # do Vhk^T
  71 |         x = K.tanh(K.dot(x, self.V))  # tanh(xV): (N, L_dim)
72 |
73 | if self.use_gated:
74 | gate_x = K.sigmoid(K.dot(ori_x, self.U))
75 | ac_x = x * gate_x
76 | else:
77 | ac_x = x
78 |
79 | # do w^T x
  80 |         soft_x = K.dot(ac_x, self.w)  # (N, L_dim) x (L_dim, 1) = (N, 1)
  81 |         alpha = K.softmax(K.transpose(soft_x))  # softmax over the N instances -> (1, N)
  82 |         alpha = K.transpose(alpha)  # (N, 1), one attention weight per instance
83 | return alpha
84 |
85 | def compute_output_shape(self, input_shape):
86 | shape = list(input_shape)
87 | assert len(shape) == 2
88 | shape[1] = self.output_dim
89 | return tuple(shape)
90 |
91 | def get_config(self):
92 | config = {
93 | 'output_dim': self.output_dim,
94 | 'v_initializer': initializers.serialize(self.V.initializer),
95 | 'w_initializer': initializers.serialize(self.w.initializer),
96 | 'v_regularizer': regularizers.serialize(self.v_regularizer),
97 | 'w_regularizer': regularizers.serialize(self.w_regularizer),
98 | 'use_bias': self.use_bias
99 | }
100 | base_config = super(Mil_Attention, self).get_config()
101 | return dict(list(base_config.items()) + list(config.items()))
102 |
103 |
104 | class Last_Sigmoid(Layer):
105 | """
106 | Attention Activation
107 |
 108 |     This layer contains an FC layer which has a single neuron with sigmoid
 109 |     activation, combined with MIL pooling. The input is instance features;
 110 |     the instances are first aggregated by sum pooling, and the FC layer then
 111 |     maps the pooled representation to a bag score, the output of this layer.
 112 |     (Adapted from the score-pooling layer used in mi-Net.)
113 |
114 | # Arguments
115 | output_dim: Positive integer, dimensionality of the output space
116 | kernel_initializer: Initializer of the `kernel` weights matrix
117 | bias_initializer: Initializer of the `bias` weights
118 | kernel_regularizer: Regularizer function applied to the `kernel` weights matrix
119 | bias_regularizer: Regularizer function applied to the `bias` weights
120 | use_bias: Boolean, whether use bias or not
121 | pooling_mode: A string,
122 | the mode of MIL pooling method, like 'max' (max pooling),
123 | 'ave' (average pooling), 'lse' (log-sum-exp pooling)
124 |
125 | # Input shape
126 | 2D tensor with shape: (batch_size, input_dim)
127 | # Output shape
128 | 2D tensor with shape: (1, units)
129 | """
130 | def __init__(self, output_dim, kernel_initializer='glorot_uniform', bias_initializer='zeros',
131 | kernel_regularizer=None, bias_regularizer=None,
132 | use_bias=True, **kwargs):
133 | self.output_dim = output_dim
134 |
135 | self.kernel_initializer = initializers.get(kernel_initializer)
136 | self.bias_initializer = initializers.get(bias_initializer)
137 | self.kernel_regularizer = regularizers.get(kernel_regularizer)
138 | self.bias_regularizer = regularizers.get(bias_regularizer)
139 |
140 | self.use_bias = use_bias
141 | super(Last_Sigmoid, self).__init__(**kwargs)
142 |
143 | def build(self, input_shape):
144 | assert len(input_shape) == 2
145 | input_dim = input_shape[1]
146 |
147 | self.kernel = self.add_weight(shape=(input_dim, self.output_dim),
148 | initializer=self.kernel_initializer,
149 | name='kernel',
150 | regularizer=self.kernel_regularizer)
151 |
152 | if self.use_bias:
153 | self.bias = self.add_weight(shape=(self.output_dim,),
154 | initializer=self.bias_initializer,
155 | name='bias',
156 | regularizer=self.bias_regularizer)
157 | else:
158 | self.bias = None
159 |
 160 |         self.built = True
161 |
162 | def call(self, x, mask=None):
163 | n, d = x.shape
 164 |         x = K.sum(x, axis=0, keepdims=True)  # MIL pooling: sum the instances into a bag representation
 165 |         # compute the bag-level score
166 | x = K.dot(x, self.kernel)
167 | if self.use_bias:
168 | x = K.bias_add(x, self.bias)
169 |
170 | # sigmoid
171 | out = K.sigmoid(x)
172 |
173 |
174 | return out
175 |
176 | def compute_output_shape(self, input_shape):
177 | shape = list(input_shape)
178 | assert len(shape) == 2
179 | shape[1] = self.output_dim
180 | return tuple(shape)
181 |
182 | def get_config(self):
183 | config = {
184 | 'output_dim': self.output_dim,
185 | 'kernel_initializer': initializers.serialize(self.kernel.initializer),
186 | 'bias_initializer': initializers.serialize(self.bias_initializer),
187 | 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
188 | 'bias_regularizer': regularizers.serialize(self.bias_regularizer),
189 | 'use_bias': self.use_bias
190 | }
191 | base_config = super(Last_Sigmoid, self).get_config()
192 | return dict(list(base_config.items()) + list(config.items()))
193 |
194 |
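 195 | if __name__ == '__main__':
 196 |     # Minimal wiring sketch (added for illustration), mirroring utl/Cell_Net.py:
 197 |     # attention weights gate the per-instance features, then Last_Sigmoid pools
 198 |     # them into a single bag probability.
 199 |     from keras.layers import Input, multiply
 200 |     from keras.models import Model
 201 |     feat = Input(shape=(512,))  # one row per instance in the bag
 202 |     alpha = Mil_Attention(L_dim=128, output_dim=1, name='alpha', use_gated=False)(feat)
 203 |     out = Last_Sigmoid(output_dim=1, name='bag_score')(multiply([alpha, feat]))
 204 |     Model(inputs=feat, outputs=out).summary()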
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | '''
3 | This is a re-implementation of the following paper:
4 | "Attention-based Deep Multiple Instance Learning"
   5 | I obtained very similar results, although some data augmentation techniques are not used here.
   6 | https://arxiv.org/pdf/1802.04712.pdf
7 | *---- Jiawen Yao--------------*
8 | '''
9 |
10 |
11 | import numpy as np
12 | import time
13 | from utl import Cell_Net
14 | from random import shuffle
15 | import argparse
16 | from keras.models import Model
17 | from utl.dataset import load_dataset
18 | from utl.data_aug_op import random_flip_img, random_rotate_img
19 | import glob
20 | import scipy.misc as sci
21 | import tensorflow as tf
22 |
23 | from keras import backend as K
24 | from keras.utils import multi_gpu_model
25 | from keras.callbacks import ModelCheckpoint, LearningRateScheduler, TensorBoard, EarlyStopping
26 |
27 | import matplotlib.pyplot as plt
28 |
29 | import os
30 |
31 | def parse_args():
32 | """Parse input arguments.
33 | Parameters
34 | -------------------
35 | No parameters.
36 | Returns
37 | -------------------
38 | args: argparser.Namespace class object
39 | An argparse.Namespace class object contains experimental hyper-parameters.
40 | """
  41 |     parser = argparse.ArgumentParser(description='Train an Attention-based Deep MIL model')
42 | parser.add_argument('--lr', dest='init_lr',
43 | help='initial learning rate',
44 | default=1e-4, type=float)
45 | parser.add_argument('--decay', dest='weight_decay',
46 | help='weight decay',
47 | default=0.0005, type=float)
48 | parser.add_argument('--momentum', dest='momentum',
49 | help='momentum',
50 | default=0.9, type=float)
51 | parser.add_argument('--epoch', dest='max_epoch',
52 | help='number of epoch to train',
53 | default=100, type=int)
54 | parser.add_argument('--useGated', dest='useGated',
55 | help='use Gated Attention',
56 | default=False, type=int)
57 |
58 | # if len(sys.argv) == 1:
59 | # parser.print_help()
60 | # sys.exit(1)
61 |
62 | args = parser.parse_args()
63 | return args
64 |
65 | def generate_batch(path):
66 | bags = []
67 | for each_path in path:
68 | name_img = []
69 | img = []
70 | img_path = glob.glob(each_path + '/*.bmp')
71 | num_ins = len(img_path)
72 |
73 | label = int(each_path.split('/')[-2])
74 |
75 | if label == 1:
76 | curr_label = np.ones(num_ins,dtype=np.uint8)
77 | else:
78 | curr_label = np.zeros(num_ins, dtype=np.uint8)
79 | for each_img in img_path:
80 | img_data = np.asarray(sci.imread(each_img), dtype=np.float32)
81 | #img_data -= 255
  82 |             img_data[:, :, 0] -= 123.68   # subtract ImageNet channel mean (R)
  83 |             img_data[:, :, 1] -= 116.779  # subtract ImageNet channel mean (G)
  84 |             img_data[:, :, 2] -= 103.939  # subtract ImageNet channel mean (B)
85 | img_data /= 255
86 | # sci.imshow(img_data)
87 | img.append(np.expand_dims(img_data,0))
88 | name_img.append(each_img.split('/')[-1])
89 | stack_img = np.concatenate(img, axis=0)
90 | bags.append((stack_img, curr_label, name_img))
91 |
92 | return bags
93 |
94 |
95 | def Get_train_valid_Path(Train_set, train_percentage=0.8):
  96 |     """
  97 |     Randomly split the training bags into model-training and validation subsets.
  98 |     :param Train_set: list of training bags
  99 |     :param train_percentage: fraction of bags used for model training
 100 |     :return: (Model_Train, Model_Val) bag lists
 101 |     """
102 | import random
103 | indexes = np.arange(len(Train_set))
104 | random.shuffle(indexes)
105 |
106 | num_train = int(train_percentage*len(Train_set))
107 | train_index, test_index = np.asarray(indexes[:num_train]), np.asarray(indexes[num_train:])
108 |
109 | Model_Train = [Train_set[i] for i in train_index]
110 | Model_Val = [Train_set[j] for j in test_index]
111 |
112 | return Model_Train, Model_Val
113 |
114 |
115 | def test_eval(model, test_set):
116 | """Evaluate on testing set.
117 | Parameters
118 | -----------------
119 | model : keras.engine.training.Model object
120 | The training mi-Cell-Net model.
 121 |     test_set : list
 122 |         A list of testing bags containing bag features and labels.
123 | Returns
124 | -----------------
125 | test_loss : float
126 | Mean loss of evaluating on testing set.
127 | test_acc : float
128 | Mean accuracy of evaluating on testing set.
129 | """
130 | num_test_batch = len(test_set)
131 | test_loss = np.zeros((num_test_batch, 1), dtype=float)
132 | test_acc = np.zeros((num_test_batch, 1), dtype=float)
133 | for ibatch, batch in enumerate(test_set):
134 | result = model.test_on_batch(x=batch[0], y=batch[1])
135 | test_loss[ibatch] = result[0]
136 | test_acc[ibatch] = result[1]
137 | return np.mean(test_loss), np.mean(test_acc)
138 |
139 | def train_eval(model, train_set, irun, ifold):
 140 |     """Train the model with Keras fit_generator and validate on a held-out split.
141 | Parameters
142 | -----------------
143 | model : keras.engine.training.Model object
144 | The training mi-Cell-Net model.
 145 |     train_set : list
 146 |         A list of training bags containing bag features and labels.
147 | Returns
148 | -----------------
 149 |     model_name : str, path of the saved weights with the lowest val_loss
150 | """
151 | batch_size = 1
152 | model_train_set, model_val_set = Get_train_valid_Path(train_set, train_percentage=0.9)
153 |
154 | from utl.DataGenerator import DataGenerator
155 | train_gen = DataGenerator(batch_size=1, shuffle=True).generate(model_train_set)
156 | val_gen = DataGenerator(batch_size=1, shuffle=False).generate(model_val_set)
157 |
158 | model_name = "Saved_model/" + "_Batch_size_" + str(batch_size) + "epoch_" + "best.hd5"
159 |
160 | checkpoint_fixed_name = ModelCheckpoint(model_name,
161 | monitor='val_loss', verbose=1, save_best_only=True,
162 | save_weights_only=True, mode='auto', period=1)
163 |
164 | EarlyStop = EarlyStopping(monitor='val_loss', patience=20)
165 |
166 | callbacks = [checkpoint_fixed_name, EarlyStop]
167 |
168 | history = model.fit_generator(generator=train_gen, steps_per_epoch=len(model_train_set)//batch_size,
169 | epochs=args.max_epoch, validation_data=val_gen,
170 | validation_steps=len(model_val_set)//batch_size, callbacks=callbacks)
171 |
172 | train_loss = history.history['loss']
173 | val_loss = history.history['val_loss']
174 |
175 | train_acc = history.history['bag_accuracy']
176 | val_acc = history.history['val_bag_accuracy']
177 |
178 | fig = plt.figure()
179 | plt.plot(train_loss)
180 | plt.plot(val_loss)
181 | plt.title('model loss')
182 | plt.ylabel('loss')
183 | plt.xlabel('epoch')
184 | plt.legend(['train', 'val'], loc='upper left')
185 | save_fig_name = 'Results/' + str(irun) + '_' + str(ifold) + "_loss_batchsize_" + str(batch_size) + "_epoch" + ".png"
186 | fig.savefig(save_fig_name)
187 |
188 |
189 | fig = plt.figure()
190 | plt.plot(train_acc)
191 | plt.plot(val_acc)
 192 |     plt.title('model accuracy')
 193 |     plt.ylabel('accuracy')
194 | plt.xlabel('epoch')
195 | plt.legend(['train', 'val'], loc='upper left')
196 | save_fig_name = 'Results/' + str(irun) + '_' + str(ifold) + "_val_batchsize_" + str(batch_size) + "_epoch" + ".png"
197 | fig.savefig(save_fig_name)
198 |
199 | return model_name
200 |
201 |
202 | def model_training(input_dim, dataset, irun, ifold):
203 |
204 | train_bags = dataset['train']
205 | test_bags = dataset['test']
206 |
207 | # convert bag to batch
208 | train_set = generate_batch(train_bags)
209 | test_set = generate_batch(test_bags)
210 |
211 | model = Cell_Net.cell_net(input_dim, args, useMulGpu=False)
212 |
213 | # train model
214 | t1 = time.time()
215 | num_batch = len(train_set)
216 | # for epoch in range(args.max_epoch):
217 | model_name = train_eval(model, train_set, irun, ifold)
218 |
219 | print("load saved model weights")
220 | model.load_weights(model_name)
221 |
222 | test_loss, test_acc = test_eval(model, test_set)
223 |
224 | t2 = time.time()
225 | #
226 |
227 | print ('run time:', (t2 - t1) / 60.0, 'min')
228 | print ('test_acc={:.3f}'.format(test_acc))
229 |
230 | return test_acc
231 |
232 |
233 |
234 | if __name__ == "__main__":
235 |
236 | args = parse_args()
237 |
238 | print ('Called with args:')
239 | print (args)
240 |
241 | input_dim = (27,27,3)
242 |
243 | run = 1
244 | n_folds = 10
245 | acc = np.zeros((run, n_folds), dtype=float)
246 | data_path = '../data/Patches'
247 |
248 | for irun in range(run):
249 | dataset = load_dataset(dataset_path=data_path, n_folds=n_folds, rand_state=irun)
250 | for ifold in range(n_folds):
251 | print ('run=', irun, ' fold=', ifold)
252 | acc[irun][ifold] = model_training(input_dim, dataset[ifold], irun, ifold)
 253 |     print ('mean accuracy = ', np.mean(acc))
254 | print ('std = ', np.std(acc))
255 |
256 |
--------------------------------------------------------------------------------