├── README.md ├── segmentation ├── 3d_attention_unet.py ├── BuildingBlocks.py └── sca_3d.py └── survival_prediction ├── matlab ├── Bland Altman plot.ipynb ├── Brats_valid │ ├── .ipynb_checkpoints │ │ ├── Best_model_withRFE_XGB-checkpoint.ipynb │ │ ├── Normalizing-checkpoint.ipynb │ │ ├── Radiomics_model-checkpoint.ipynb │ │ ├── XGB_withRFE_crossvalidation-checkpoint.ipynb │ │ ├── XGBregressor-checkpoint.ipynb │ │ ├── XGBregressor-rfe-valid-checkpoint.ipynb │ │ └── training_cross_valid-checkpoint.ipynb │ ├── Best_model_withRFE_XGB.ipynb │ ├── Normalizing.ipynb │ ├── XGB_withRFE_crossvalidation.ipynb │ ├── XGBregressor.ipynb │ ├── radiomics_normalized.csv │ ├── radiomics_normalized_SS.csv │ ├── radiomics_normalized_new.csv │ ├── radiomics_test_normalized.csv │ ├── radiomics_valid_normalized.csv │ ├── radiomics_valid_normalized_SS.csv │ ├── radiomics_valid_normalized_new.csv │ ├── submission.csv │ ├── submission_best_14.csv │ ├── submission_best_14_test.csv │ ├── submission_best_30.csv │ ├── test_dataset_beforenormalizing.csv │ ├── train_cross.csv │ ├── train_dataset_beforenormalizing.csv │ ├── training_cross_valid.ipynb │ └── valid_dataset_beforenormalizing.csv ├── Feature_Extraction │ ├── boxcount │ │ ├── Apollonian_gasket.gif │ │ ├── Contents.m │ │ ├── Thumbs.db │ │ ├── boxcount.m │ │ ├── demo.m │ │ ├── dla.gif │ │ ├── fractal_tree.jpg │ │ ├── html │ │ │ ├── Thumbs.db │ │ │ ├── demo.html │ │ │ ├── demo.png │ │ │ ├── demo_01.png │ │ │ ├── demo_02.png │ │ │ ├── demo_03.png │ │ │ ├── demo_04.png │ │ │ ├── demo_05.png │ │ │ ├── demo_06.png │ │ │ ├── demo_07.png │ │ │ ├── demo_08.png │ │ │ ├── demo_09.png │ │ │ ├── demo_10.png │ │ │ ├── demo_11.png │ │ │ ├── demo_12.png │ │ │ ├── demo_13.png │ │ │ ├── demo_14.png │ │ │ ├── demo_15.png │ │ │ └── demo_16.png │ │ └── randcantor.m │ ├── fractal.m │ ├── fractal_nec.txt │ ├── fractal_tc.txt │ ├── geometry.m │ ├── histogram.m │ ├── regionprops3.m │ ├── test_brats.csv │ ├── test_features │ │ ├── fractal_nec.txt │ │ ├── fractal_tc.txt │ │ ├── geometry_nec.txt │ │ ├── geometry_tc.txt │ │ ├── geometry_wt.txt │ │ ├── hist_enh.txt │ │ └── hist_nec.txt │ ├── train_features │ │ ├── fractal_nec.txt │ │ ├── fractal_tc.txt │ │ ├── geometry_nec.txt │ │ ├── geometry_tc.txt │ │ ├── geometry_wt.txt │ │ ├── hist_enh.txt │ │ └── hist_nec.txt │ ├── valid_brats.csv │ └── valid_features │ │ ├── fractal_nec.txt │ │ ├── fractal_tc.txt │ │ ├── geometry_nec.txt │ │ ├── geometry_tc.txt │ │ ├── geometry_wt.txt │ │ ├── hist_enh.txt │ │ └── hist_nec.txt ├── Filename_into_textfile.ipynb ├── keplen mier.ipynb ├── npy_fromcsv.ipynb └── spearmanr.ipynb └── python ├── Classification ├── ANN.ipynb ├── Fold1 │ ├── Train_dir.txt │ └── Val_Dir.txt ├── Fold2 │ ├── Train_dir.txt │ └── Val_Dir.txt ├── Fold3 │ ├── Train_dir.txt │ └── Val_Dir.txt ├── Fold4 │ ├── Train_dir.txt │ └── Val_Dir.txt ├── ICHFeatures.csv ├── ICH_Features.csv ├── Logistic regression.ipynb ├── RFC.ipynb ├── SVC.ipynb └── xgboost.ipynb ├── Features_Final.ipynb ├── Regression ├── GroundTruth.xlsx ├── ICHFeatures.csv ├── Linear Regression.ipynb ├── SVR.ipynb ├── random forest regression.ipynb └── xgboost.ipynb ├── base_nn.py ├── check_snap.py ├── shap_box.py └── svm_rfe.py /README.md: -------------------------------------------------------------------------------- 1 | # 3D_Attention_UNet 2 | This repository contains the official implementation of the paper "Brain Tumor Segmentation and Survival 3 | Prediction using 3D Attention UNet" [preprint](https://arxiv.org/pdf/2104.00985.pdf) and [in workshop 
Proceedings](https://link.springer.com/chapter/10.1007/978-3-030-46640-4_25)
4 | 5 | The baseline implementation of UNet3D is adopted from
6 | pytorch-3dunet: https://github.com/wolny/pytorch-3dunet
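What distinguishes this network from the plain UNet3D baseline is the attention block in `segmentation/sca_3d.py` (`SCA3D`, a concurrent spatial- and channel-wise excitation block), which `BuildingBlocks.Decoder` applies at every decoder stage. Below is a minimal standalone shape check of that block; the channel count and tensor sizes are illustrative assumptions, and the import assumes `segmentation/` is on `sys.path`:

```python
import torch
from sca_3d import SCA3D  # defined in segmentation/sca_3d.py

attn = SCA3D(channel=32)               # channel count chosen only for illustration
feat = torch.randn(2, 32, 16, 16, 16)  # (batch, channels, D, H, W)
out = attn(feat)                       # channel- and spatial-attention-weighted maps are added back onto the input
assert out.shape == feat.shape         # the block preserves the feature-map shape
```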
7 | 8 | ## Dataset 9 | [Multimodal Brain Tumor Segmentation Challenge 2019 (BraTS)](https://www.med.upenn.edu/cbica/brats2019/data.html)
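A minimal sketch (not part of the original README) of instantiating the segmentation model defined in `segmentation/3d_attention_unet.py`; the 4-channel input, 4 output classes, and 64³ patch size are assumptions for BraTS-style data (4 MRI modalities), not values prescribed by the repository:

```python
import importlib
import torch

# the module file name starts with a digit, so load it programmatically;
# assumes the segmentation/ directory is the current working directory (or on sys.path)
attention_unet = importlib.import_module("3d_attention_unet")

# assumed BraTS-style setup: 4 MRI modalities in, 4 labels out (background + 3 tumour sub-regions)
model = attention_unet.UNet3D(in_channels=4, out_channels=4, final_sigmoid=False, f_maps=16)
model.eval()

with torch.no_grad():
    patch = torch.randn(1, 4, 64, 64, 64)  # toy patch, purely illustrative
    seg, pooled = model(patch)             # forward() returns the segmentation and pooled bottleneck features

print(seg.shape)     # torch.Size([1, 4, 64, 64, 64]); softmax is applied because the model is in eval mode
print(pooled.shape)  # torch.Size([512]) with f_maps=16 and a batch of one
```

During training the network returns raw logits; the softmax/sigmoid in `final_activation` is only applied in evaluation mode.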
10 | 11 | ## Citation 12 | If you use this code for your research, please cite our paper. 13 | 14 | ``` 15 | @inproceedings{islam2019brain, 16 | title={Brain tumor segmentation and survival prediction using 3D attention UNet}, 17 | author={Islam, Mobarakol and Vibashan, VS and Jose, V Jeya Maria and Wijethilake, Navodini and Utkarsh, Uppal and Ren, Hongliang}, 18 | booktitle={International MICCAI Brainlesion Workshop}, 19 | pages={262--272}, 20 | year={2019}, 21 | organization={Springer} 22 | } 23 | ``` -------------------------------------------------------------------------------- /segmentation/3d_attention_unet.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from BuildingBlocks import Encoder, Decoder, FinalConv, DoubleConv, ExtResNetBlock, SingleConv 7 | 8 | 9 | 10 | def create_feature_maps(init_channel_number, number_of_fmaps): 11 | return [init_channel_number * 2 ** k for k in range(number_of_fmaps)] 12 | 13 | class UNet3D(nn.Module): 14 | """ 15 | 3DUnet model from 16 | `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation" 17 | `. 18 | Args: 19 | in_channels (int): number of input channels 20 | out_channels (int): number of output segmentation masks; 21 | Note that that the of out_channels might correspond to either 22 | different semantic classes or to different binary segmentation mask. 23 | It's up to the user of the class to interpret the out_channels and 24 | use the proper loss criterion during training (i.e. CrossEntropyLoss (multi-class) 25 | or BCEWithLogitsLoss (two-class) respectively) 26 | f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number 27 | of feature maps is given by the geometric progression: f_maps ^ k, k=1,2,3,4 28 | final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the 29 | final 1x1 convolution, otherwise apply nn.Softmax. MUST be True if nn.BCELoss (two-class) is used 30 | to train the model. MUST be False if nn.CrossEntropyLoss (multi-class) is used to train the model. 31 | layer_order (string): determines the order of layers 32 | in `SingleConv` module. e.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d. 33 | See `SingleConv` for more info 34 | init_channel_number (int): number of feature maps in the first conv layer of the encoder; default: 64 35 | num_groups (int): number of groups for the GroupNorm 36 | """ 37 | 38 | def __init__(self, in_channels, out_channels, final_sigmoid, f_maps=16, layer_order='crg', num_groups=8, 39 | **kwargs): 40 | super(UNet3D, self).__init__() 41 | 42 | if isinstance(f_maps, int): 43 | # use 4 levels in the encoder path as suggested in the paper 44 | f_maps = create_feature_maps(f_maps, number_of_fmaps=6) 45 | 46 | # create encoder path consisting of Encoder modules. The length of the encoder is equal to `len(f_maps)` 47 | # uses DoubleConv as a basic_module for the Encoder 48 | encoders = [] 49 | for i, out_feature_num in enumerate(f_maps): 50 | if i == 0: 51 | encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=DoubleConv, 52 | conv_layer_order=layer_order, num_groups=num_groups) 53 | else: 54 | encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=DoubleConv, 55 | conv_layer_order=layer_order, num_groups=num_groups) 56 | encoders.append(encoder) 57 | 58 | self.encoders = nn.ModuleList(encoders) 59 | 60 | # create decoder path consisting of the Decoder modules. 
The length of the decoder is equal to `len(f_maps) - 1` 61 | # uses DoubleConv as a basic_module for the Decoder 62 | decoders = [] 63 | reversed_f_maps = list(reversed(f_maps)) 64 | for i in range(len(reversed_f_maps) - 1): 65 | in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1] 66 | out_feature_num = reversed_f_maps[i + 1] 67 | decoder = Decoder(in_feature_num, out_feature_num, basic_module=DoubleConv, 68 | conv_layer_order=layer_order, num_groups=num_groups) 69 | decoders.append(decoder) 70 | 71 | self.decoders = nn.ModuleList(decoders) 72 | 73 | # in the last layer a 1×1 convolution reduces the number of output 74 | # channels to the number of labels 75 | self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1) 76 | self.avg_pool = nn.AdaptiveAvgPool3d(1) 77 | 78 | if final_sigmoid: 79 | self.final_activation = nn.Sigmoid() 80 | else: 81 | self.final_activation = nn.Softmax(dim=1) 82 | 83 | def forward(self, x): 84 | # encoder part 85 | encoders_features = [] 86 | for encoder in self.encoders: 87 | x = encoder(x) 88 | # reverse the encoder outputs to be aligned with the decoder 89 | encoders_features.insert(0, x) 90 | 91 | # remove the last encoder's output from the list 92 | # !!remember: it's the 1st in the list 93 | pool_fea = self.avg_pool(encoders_features[0]).squeeze(0).squeeze(1).squeeze(1).squeeze(1) 94 | encoders_features = encoders_features[1:] 95 | # decoder part 96 | for decoder, encoder_features in zip(self.decoders, encoders_features): 97 | # pass the output from the corresponding encoder and the output 98 | # of the previous decoder 99 | x = decoder(encoder_features, x) 100 | 101 | x = self.final_conv(x) 102 | 103 | # apply final_activation (i.e. Sigmoid or Softmax) only for prediction. During training the network outputs 104 | # logits and it's up to the user to normalize it before visualising with tensorboard or computing validation metric 105 | if not self.training: 106 | x = self.final_activation(x) 107 | 108 | return x, pool_fea 109 | -------------------------------------------------------------------------------- /segmentation/BuildingBlocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | from sca_3d import SCA3D 6 | 7 | def conv3d(in_channels, out_channels, kernel_size, bias, padding=1): 8 | return nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias) 9 | 10 | 11 | def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=1): 12 | """ 13 | Create a list of modules with together constitute a single conv layer with non-linearity 14 | and optional batchnorm/groupnorm. 15 | Args: 16 | in_channels (int): number of input channels 17 | out_channels (int): number of output channels 18 | order (string): order of things, e.g. 
19 | 'cr' -> conv + ReLU 20 | 'crg' -> conv + ReLU + groupnorm 21 | 'cl' -> conv + LeakyReLU 22 | 'ce' -> conv + ELU 23 | num_groups (int): number of groups for the GroupNorm 24 | padding (int): add zero-padding to the input 25 | Return: 26 | list of tuple (name, module) 27 | """ 28 | assert 'c' in order, "Conv layer MUST be present" 29 | assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer' 30 | 31 | modules = [] 32 | for i, char in enumerate(order): 33 | if char == 'r': 34 | modules.append(('ReLU', nn.ReLU(inplace=True))) 35 | elif char == 'l': 36 | modules.append(('LeakyReLU', nn.LeakyReLU(negative_slope=0.1, inplace=True))) 37 | elif char == 'e': 38 | modules.append(('ELU', nn.ELU(inplace=True))) 39 | elif char == 'c': 40 | # add learnable bias only in the absence of gatchnorm/groupnorm 41 | bias = not ('g' in order or 'b' in order) 42 | modules.append(('conv', conv3d(in_channels, out_channels, kernel_size, bias, padding=padding))) 43 | elif char == 'g': 44 | is_before_conv = i < order.index('c') 45 | assert not is_before_conv, 'GroupNorm MUST go after the Conv3d' 46 | # number of groups must be less or equal the number of channels 47 | if out_channels < num_groups: 48 | num_groups = out_channels 49 | modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=out_channels))) 50 | elif char == 'b': 51 | is_before_conv = i < order.index('c') 52 | if is_before_conv: 53 | modules.append(('batchnorm', nn.BatchNorm3d(in_channels))) 54 | else: 55 | modules.append(('batchnorm', nn.BatchNorm3d(out_channels))) 56 | else: 57 | raise ValueError("Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']") 58 | 59 | return modules 60 | 61 | 62 | class SingleConv(nn.Sequential): 63 | """ 64 | Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order 65 | of operations can be specified via the `order` parameter 66 | Args: 67 | in_channels (int): number of input channels 68 | out_channels (int): number of output channels 69 | kernel_size (int): size of the convolving kernel 70 | order (string): determines the order of layers, e.g. 71 | 'cr' -> conv + ReLU 72 | 'crg' -> conv + ReLU + groupnorm 73 | 'cl' -> conv + LeakyReLU 74 | 'ce' -> conv + ELU 75 | num_groups (int): number of groups for the GroupNorm 76 | """ 77 | 78 | def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8, padding=1): 79 | super(SingleConv, self).__init__() 80 | 81 | for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding): 82 | self.add_module(name, module) 83 | 84 | 85 | class DoubleConv(nn.Sequential): 86 | """ 87 | A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d). 88 | We use (Conv3d+ReLU+GroupNorm3d) by default. 89 | This can be changed however by providing the 'order' argument, e.g. in order 90 | to change to Conv3d+BatchNorm3d+ELU use order='cbe'. 91 | Use padded convolutions to make sure that the output (H_out, W_out) is the same 92 | as (H_in, W_in), so that you don't have to crop in the decoder path. 93 | Args: 94 | in_channels (int): number of input channels 95 | out_channels (int): number of output channels 96 | encoder (bool): if True we're in the encoder path, otherwise we're in the decoder 97 | kernel_size (int): size of the convolving kernel 98 | order (string): determines the order of layers, e.g. 
99 | 'cr' -> conv + ReLU 100 | 'crg' -> conv + ReLU + groupnorm 101 | 'cl' -> conv + LeakyReLU 102 | 'ce' -> conv + ELU 103 | num_groups (int): number of groups for the GroupNorm 104 | """ 105 | 106 | def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='crg', num_groups=8): 107 | super(DoubleConv, self).__init__() 108 | if encoder: 109 | # we're in the encoder path 110 | conv1_in_channels = in_channels 111 | conv1_out_channels = out_channels // 2 112 | if conv1_out_channels < in_channels: 113 | conv1_out_channels = in_channels 114 | conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels 115 | else: 116 | # we're in the decoder path, decrease the number of channels in the 1st convolution 117 | conv1_in_channels, conv1_out_channels = in_channels, out_channels 118 | conv2_in_channels, conv2_out_channels = out_channels, out_channels 119 | 120 | # conv1 121 | self.add_module('SingleConv1', 122 | SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups)) 123 | # conv2 124 | self.add_module('SingleConv2', 125 | SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups)) 126 | 127 | 128 | class ExtResNetBlock(nn.Module): 129 | """ 130 | Basic UNet block consisting of a SingleConv followed by the residual block. 131 | The SingleConv takes care of increasing/decreasing the number of channels and also ensures that the number 132 | of output channels is compatible with the residual block that follows. 133 | This block can be used instead of standard DoubleConv in the Encoder module. 134 | Motivated by: https://arxiv.org/pdf/1706.00120.pdf 135 | Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm. 136 | """ 137 | 138 | def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, **kwargs): 139 | super(ExtResNetBlock, self).__init__() 140 | 141 | # first convolution 142 | self.conv1 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups) 143 | # residual block 144 | self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups) 145 | # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual 146 | n_order = order 147 | for c in 'rel': 148 | n_order = n_order.replace(c, '') 149 | self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order, 150 | num_groups=num_groups) 151 | 152 | # create non-linearity separately 153 | if 'l' in order: 154 | self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True) 155 | elif 'e' in order: 156 | self.non_linearity = nn.ELU(inplace=True) 157 | else: 158 | self.non_linearity = nn.ReLU(inplace=True) 159 | 160 | def forward(self, x): 161 | # apply first convolution and save the output as a residual 162 | out = self.conv1(x) 163 | residual = out 164 | 165 | # residual block 166 | out = self.conv2(out) 167 | out = self.conv3(out) 168 | 169 | out += residual 170 | out = self.non_linearity(out) 171 | 172 | return out 173 | 174 | 175 | class Encoder(nn.Module): 176 | """ 177 | A single module from the encoder path consisting of the optional max 178 | pooling layer (one may specify the MaxPool kernel_size to be different 179 | than the standard (2,2,2), e.g. if the volumetric data is anisotropic 180 | (make sure to use complementary scale_factor in the decoder path) followed by 181 | a DoubleConv module. 
182 | Args: 183 | in_channels (int): number of input channels 184 | out_channels (int): number of output channels 185 | conv_kernel_size (int): size of the convolving kernel 186 | apply_pooling (bool): if True use MaxPool3d before DoubleConv 187 | pool_kernel_size (tuple): the size of the window to take a max over 188 | pool_type (str): pooling layer: 'max' or 'avg' 189 | basic_module(nn.Module): either ResNetBlock or DoubleConv 190 | conv_layer_order (string): determines the order of layers 191 | in `DoubleConv` module. See `DoubleConv` for more info. 192 | num_groups (int): number of groups for the GroupNorm 193 | """ 194 | 195 | def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True, 196 | pool_kernel_size=(2, 2, 2), pool_type='max', basic_module=DoubleConv, conv_layer_order='crg', 197 | num_groups=8): 198 | super(Encoder, self).__init__() 199 | assert pool_type in ['max', 'avg'] 200 | if apply_pooling: 201 | if pool_type == 'max': 202 | self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size) 203 | else: 204 | self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size) 205 | else: 206 | self.pooling = None 207 | 208 | self.basic_module = basic_module(in_channels, out_channels, 209 | encoder=True, 210 | kernel_size=conv_kernel_size, 211 | order=conv_layer_order, 212 | num_groups=num_groups) 213 | 214 | def forward(self, x): 215 | if self.pooling is not None: 216 | x = self.pooling(x) 217 | #x = self.scse(x) 218 | x = self.basic_module(x) 219 | return x 220 | 221 | 222 | class Decoder(nn.Module): 223 | """ 224 | A single module for decoder path consisting of the upsample layer 225 | (either learned ConvTranspose3d or interpolation) followed by a DoubleConv 226 | module. 227 | Args: 228 | in_channels (int): number of input channels 229 | out_channels (int): number of output channels 230 | kernel_size (int): size of the convolving kernel 231 | scale_factor (tuple): used as the multiplier for the image H/W/D in 232 | case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation 233 | from the corresponding encoder 234 | basic_module(nn.Module): either ResNetBlock or DoubleConv 235 | conv_layer_order (string): determines the order of layers 236 | in `DoubleConv` module. See `DoubleConv` for more info. 
237 | num_groups (int): number of groups for the GroupNorm 238 | """ 239 | 240 | def __init__(self, in_channels, out_channels, kernel_size=3, 241 | scale_factor=(2, 2, 2), basic_module=DoubleConv, conv_layer_order='crg', num_groups=8): 242 | super(Decoder, self).__init__() 243 | if basic_module == DoubleConv: 244 | # if DoubleConv is the basic_module use nearest neighbor interpolation for upsampling 245 | self.upsample = None 246 | else: 247 | # otherwise use ConvTranspose3d (bear in mind your GPU memory) 248 | # make sure that the output size reverses the MaxPool3d from the corresponding encoder 249 | # (D_out = (D_in − 1) ×  stride[0] − 2 ×  padding[0] +  kernel_size[0] +  output_padding[0]) 250 | # also scale the number of channels from in_channels to out_channels so that summation joining 251 | # works correctly 252 | self.upsample = nn.ConvTranspose3d(in_channels, 253 | out_channels, 254 | kernel_size=kernel_size, 255 | stride=scale_factor, 256 | padding=1, 257 | output_padding=1) 258 | # adapt the number of in_channels for the ExtResNetBlock 259 | in_channels = out_channels 260 | 261 | self.scse = SCA3D(in_channels) 262 | 263 | self.basic_module = basic_module(in_channels, out_channels, 264 | encoder=False, 265 | kernel_size=kernel_size, 266 | order=conv_layer_order, 267 | num_groups=num_groups) 268 | 269 | def forward(self, encoder_features, x): 270 | if self.upsample is None: 271 | # use nearest neighbor interpolation and concatenation joining 272 | output_size = encoder_features.size()[2:] 273 | x = F.interpolate(x, size=output_size, mode='nearest') 274 | # concatenate encoder_features (encoder path) with the upsampled input across channel dimension 275 | x = torch.cat((encoder_features, x), dim=1) 276 | else: 277 | # use ConvTranspose3d and summation joining 278 | x = self.upsample(x) 279 | x += encoder_features 280 | x = self.scse(x) 281 | x = self.basic_module(x) 282 | return x 283 | 284 | 285 | class FinalConv(nn.Sequential): 286 | """ 287 | A module consisting of a convolution layer (e.g. Conv3d+ReLU+GroupNorm3d) and the final 1x1 convolution 288 | which reduces the number of channels to 'out_channels'. 289 | with the number of output channels 'out_channels // 2' and 'out_channels' respectively. 290 | We use (Conv3d+ReLU+GroupNorm3d) by default. 291 | This can be change however by providing the 'order' argument, e.g. in order 292 | to change to Conv3d+BatchNorm3d+ReLU use order='cbr'. 293 | Args: 294 | in_channels (int): number of input channels 295 | out_channels (int): number of output channels 296 | kernel_size (int): size of the convolving kernel 297 | order (string): determines the order of layers, e.g. 
298 | 'cr' -> conv + ReLU 299 | 'crg' -> conv + ReLU + groupnorm 300 | num_groups (int): number of groups for the GroupNorm 301 | """ 302 | 303 | def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8): 304 | super(FinalConv, self).__init__() 305 | 306 | # conv1 307 | self.add_module('SingleConv', SingleConv(in_channels, in_channels, kernel_size, order, num_groups)) 308 | 309 | # in the last layer a 1×1 convolution reduces the number of output channels to out_channels 310 | final_conv = nn.Conv3d(in_channels, out_channels, 1) 311 | self.add_module('final_conv', final_conv) -------------------------------------------------------------------------------- /segmentation/sca_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class SCA3D(nn.Module): 6 | def __init__(self, channel, reduction=16): 7 | super().__init__() 8 | self.avg_pool = nn.AdaptiveAvgPool3d(1) 9 | self.channel_excitation = nn.Sequential(nn.Linear(channel, int(channel // reduction)), 10 | nn.ReLU(inplace=True), 11 | nn.Linear(int(channel // reduction), channel)) 12 | self.spatial_se = nn.Conv3d(channel, 1, kernel_size=1, 13 | stride=1, padding=0, bias=False) 14 | 15 | def forward(self, x): 16 | bahs, chs, _, _, _ = x.size() 17 | chn_se = self.avg_pool(x).view(bahs, chs) 18 | chn_se = torch.sigmoid(self.channel_excitation(chn_se).view(bahs, chs, 1, 1,1)) 19 | chn_se = torch.mul(x, chn_se) 20 | spa_se = torch.sigmoid(self.spatial_se(x)) 21 | spa_se = torch.mul(x, spa_se) 22 | net_out = spa_se + x + chn_se 23 | return net_out 24 | -------------------------------------------------------------------------------- /survival_prediction/matlab/Brats_valid/.ipynb_checkpoints/Best_model_withRFE_XGB-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /survival_prediction/matlab/Brats_valid/.ipynb_checkpoints/Normalizing-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /survival_prediction/matlab/Brats_valid/.ipynb_checkpoints/Radiomics_model-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import numpy as np\n", 11 | "import pandas as pd \n", 12 | "from sklearn.preprocessing import StandardScaler\n", 13 | "\n", 14 | "dataset_radiomics = pd.read_csv('/home/navodini/Documents/NUS/Brats19/train_brats.csv',header=None,names = 
['f1_nec','f2_nec','f3_nec','f4_nec','f5_nec','f1_tc','f2_tc','f3_tc','f4_tc','f5_tc','FirstAxis1_nec','FirstAxis2_nec','FirstAxis3_nec','SecondAxis1_nec','SecondAxis2_nec','SecondAxis3_nec','ThirdAxis1_nec','ThirdAxis2_nec','ThirdAxis3_nec','EigenValues1_nec','EigenValues2_nec','EigenValues3_nec','FirstAxisLength_nec','SecondAxisLength_nec','ThirdAxisLength_nec','Centroid1_nec','Centroid2_nec','Centroid3_nec','MeridionalEccentricity_nec','EquatorialEccentricity_nec','FirstAxis1_tc','FirstAxis2_tc','FirstAxis3_tc','SecondAxis1_tc','SecondAxis2_tc','SecondAxis3_tc','ThirdAxis1_tc','ThirdAxis2_tc','ThirdAxis3_tc','EigenValues1_tc','EigenValues2_tc','EigenValues3_tc','FirstAxisLength_tc','SecondAxisLength_tc','ThirdAxisLength_tc','Centroid1_tc','Centroid2_tc','Centroid3_tc','MeridionalEccentricity_tc','EquatorialEccentricity_tc','FirstAxis1_wt','FirstAxis2_wt','FirstAxis3_wt','SecondAxis1_wt','SecondAxis2_wt','SecondAxis3_wt','ThirdAxis1_wt','ThirdAxis2_wt','ThirdAxis3_wt','EigenValues1_wt','EigenValues2_wt','EigenValues3_wt','FirstAxisLength_wt','SecondAxisLength_wt','ThirdAxisLength_wt','Centroid1_wt','Centroid2_wt','Centroid3_wt','MeridionalEccentricity_wt','EquatorialEccentricity_wt','kurtosis_necrosis','entropy_necrosis','histogram_necrosis','entropy_enhancement','histogram_enhancement'])\n", 15 | "dataset_valid_radiomics = pd.read_csv('/home/navodini/Documents/NUS/Brats19/valid_brats.csv',header=None,names = ['f1_nec','f2_nec','f3_nec','f4_nec','f5_nec','f1_tc','f2_tc','f3_tc','f4_tc','f5_tc','FirstAxis1_nec','FirstAxis2_nec','FirstAxis3_nec','SecondAxis1_nec','SecondAxis2_nec','SecondAxis3_nec','ThirdAxis1_nec','ThirdAxis2_nec','ThirdAxis3_nec','EigenValues1_nec','EigenValues2_nec','EigenValues3_nec','FirstAxisLength_nec','SecondAxisLength_nec','ThirdAxisLength_nec','Centroid1_nec','Centroid2_nec','Centroid3_nec','MeridionalEccentricity_nec','EquatorialEccentricity_nec','FirstAxis1_tc','FirstAxis2_tc','FirstAxis3_tc','SecondAxis1_tc','SecondAxis2_tc','SecondAxis3_tc','ThirdAxis1_tc','ThirdAxis2_tc','ThirdAxis3_tc','EigenValues1_tc','EigenValues2_tc','EigenValues3_tc','FirstAxisLength_tc','SecondAxisLength_tc','ThirdAxisLength_tc','Centroid1_tc','Centroid2_tc','Centroid3_tc','MeridionalEccentricity_tc','EquatorialEccentricity_tc','FirstAxis1_wt','FirstAxis2_wt','FirstAxis3_wt','SecondAxis1_wt','SecondAxis2_wt','SecondAxis3_wt','ThirdAxis1_wt','ThirdAxis2_wt','ThirdAxis3_wt','EigenValues1_wt','EigenValues2_wt','EigenValues3_wt','FirstAxisLength_wt','SecondAxisLength_wt','ThirdAxisLength_wt','Centroid1_wt','Centroid2_wt','Centroid3_wt','MeridionalEccentricity_wt','EquatorialEccentricity_wt','kurtosis_necrosis','entropy_necrosis','histogram_necrosis','entropy_enhancement','histogram_enhancement'])\n", 16 | "\n", 17 | "dataset = pd.read_csv('/home/navodini/Documents/NUS/Brats19/train_OS.csv',header=None,names = ['Patient_ID','Age','Survival','class'])\n", 18 | "age = ((dataset['Age'].values)[1:]).astype('float').reshape(-1,1)\n", 19 | "radiomics = 
(dataset_radiomics[['f1_nec','f2_nec','f3_nec','f4_nec','f5_nec','f1_tc','f2_tc','f3_tc','f4_tc','f5_tc','FirstAxis1_nec','FirstAxis2_nec','FirstAxis3_nec','SecondAxis1_nec','SecondAxis2_nec','SecondAxis3_nec','ThirdAxis1_nec','ThirdAxis2_nec','ThirdAxis3_nec','EigenValues1_nec','EigenValues2_nec','EigenValues3_nec','FirstAxisLength_nec','SecondAxisLength_nec','ThirdAxisLength_nec','Centroid1_nec','Centroid2_nec','Centroid3_nec','MeridionalEccentricity_nec','EquatorialEccentricity_nec','FirstAxis1_tc','FirstAxis2_tc','FirstAxis3_tc','SecondAxis1_tc','SecondAxis2_tc','SecondAxis3_tc','ThirdAxis1_tc','ThirdAxis2_tc','ThirdAxis3_tc','EigenValues1_tc','EigenValues2_tc','EigenValues3_tc','FirstAxisLength_tc','SecondAxisLength_tc','ThirdAxisLength_tc','Centroid1_tc','Centroid2_tc','Centroid3_tc','MeridionalEccentricity_tc','EquatorialEccentricity_tc','FirstAxis1_wt','FirstAxis2_wt','FirstAxis3_wt','SecondAxis1_wt','SecondAxis2_wt','SecondAxis3_wt','ThirdAxis1_wt','ThirdAxis2_wt','ThirdAxis3_wt','EigenValues1_wt','EigenValues2_wt','EigenValues3_wt','FirstAxisLength_wt','SecondAxisLength_wt','ThirdAxisLength_wt','Centroid1_wt','Centroid2_wt','Centroid3_wt','MeridionalEccentricity_wt','EquatorialEccentricity_wt','kurtosis_necrosis','entropy_necrosis','histogram_necrosis','entropy_enhancement','histogram_enhancement']].values)[1:].astype('float')\n", 20 | "radiomics_valid = (dataset_valid_radiomics[['f1_nec','f2_nec','f3_nec','f4_nec','f5_nec','f1_tc','f2_tc','f3_tc','f4_tc','f5_tc','FirstAxis1_nec','FirstAxis2_nec','FirstAxis3_nec','SecondAxis1_nec','SecondAxis2_nec','SecondAxis3_nec','ThirdAxis1_nec','ThirdAxis2_nec','ThirdAxis3_nec','EigenValues1_nec','EigenValues2_nec','EigenValues3_nec','FirstAxisLength_nec','SecondAxisLength_nec','ThirdAxisLength_nec','Centroid1_nec','Centroid2_nec','Centroid3_nec','MeridionalEccentricity_nec','EquatorialEccentricity_nec','FirstAxis1_tc','FirstAxis2_tc','FirstAxis3_tc','SecondAxis1_tc','SecondAxis2_tc','SecondAxis3_tc','ThirdAxis1_tc','ThirdAxis2_tc','ThirdAxis3_tc','EigenValues1_tc','EigenValues2_tc','EigenValues3_tc','FirstAxisLength_tc','SecondAxisLength_tc','ThirdAxisLength_tc','Centroid1_tc','Centroid2_tc','Centroid3_tc','MeridionalEccentricity_tc','EquatorialEccentricity_tc','FirstAxis1_wt','FirstAxis2_wt','FirstAxis3_wt','SecondAxis1_wt','SecondAxis2_wt','SecondAxis3_wt','ThirdAxis1_wt','ThirdAxis2_wt','ThirdAxis3_wt','EigenValues1_wt','EigenValues2_wt','EigenValues3_wt','FirstAxisLength_wt','SecondAxisLength_wt','ThirdAxisLength_wt','Centroid1_wt','Centroid2_wt','Centroid3_wt','MeridionalEccentricity_wt','EquatorialEccentricity_wt','kurtosis_necrosis','entropy_necrosis','histogram_necrosis','entropy_enhancement','histogram_enhancement']].values)[1:].astype('float')\n", 21 | "\n", 22 | "OS = ((dataset['Survival'].values)[1:]).astype('float').reshape(-1,1)\n", 23 | "\n", 24 | "sc_X = StandardScaler()\n", 25 | "X = sc_X.fit_transform(radiomics)\n", 26 | "#X_test = sc_X.fit_transform(radiomics_valid)\n", 27 | "#X = np.append(age/100,X,axis=1)\n", 28 | "y = OS/1000" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 3, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "name": "stdout", 38 | "output_type": "stream", 39 | "text": [ 40 | "Accuracy: 47.62%\n" 41 | ] 42 | } 43 | ], 44 | "source": [ 45 | "from numpy import loadtxt\n", 46 | "from xgboost import XGBRegressor\n", 47 | "from sklearn.model_selection import train_test_split\n", 48 | "from sklearn.metrics import accuracy_score\n", 49 | "\n", 50 | "def categorize(array):\n", 51 | " 
new_array=np.zeros_like(array)\n", 52 | " for i in range(0,array.shape[0]): \n", 53 | " k=array[i]\n", 54 | " #print(k)\n", 55 | " if k>0.45:\n", 56 | " new_array[i,:]=2\n", 57 | " elif 0.3=8 && size(c,2)>=8 59 | c = sum(c,3); 60 | end 61 | end 62 | 63 | warning off 64 | c = logical(squeeze(c)); 65 | warning on 66 | 67 | dim = ndims(c); % dim is 2 for a vector or a matrix, 3 for a cube 68 | if dim>3 69 | error('Maximum dimension is 3.'); 70 | end 71 | 72 | % transpose the vector to a 1-by-n vector 73 | if length(c)==numel(c) 74 | dim=1; 75 | if size(c,1)~=1 76 | c = c'; 77 | end 78 | end 79 | 80 | width = max(size(c)); % largest size of the box 81 | p = log(width)/log(2); % nbre of generations 82 | 83 | % remap the array if the sizes are not all equal, 84 | % or if they are not power of two 85 | % (this slows down the computation!) 86 | if p~=round(p) || any(size(c)~=width) 87 | p = ceil(p); 88 | width = 2^p; 89 | switch dim 90 | case 1 91 | mz = zeros(1,width); 92 | mz(1:length(c)) = c; 93 | c = mz; 94 | case 2 95 | mz = zeros(width, width); 96 | mz(1:size(c,1), 1:size(c,2)) = c; 97 | c = mz; 98 | case 3 99 | mz = zeros(width, width, width); 100 | mz(1:size(c,1), 1:size(c,2), 1:size(c,3)) = c; 101 | c = mz; 102 | end 103 | end 104 | 105 | n=zeros(1,p+1); % pre-allocate the number of box of size r 106 | 107 | switch dim 108 | 109 | case 1 %------------------- 1D boxcount ---------------------% 110 | 111 | n(p+1) = sum(c); 112 | for g=(p-1):-1:0 113 | siz = 2^(p-g); 114 | siz2 = round(siz/2); 115 | for i=1:siz:(width-siz+1) 116 | c(i) = ( c(i) || c(i+siz2)); 117 | end 118 | n(g+1) = sum(c(1:siz:(width-siz+1))); 119 | end 120 | 121 | case 2 %------------------- 2D boxcount ---------------------% 122 | 123 | n(p+1) = sum(c(:)); 124 | for g=(p-1):-1:0 125 | siz = 2^(p-g); 126 | siz2 = round(siz/2); 127 | for i=1:siz:(width-siz+1) 128 | for j=1:siz:(width-siz+1) 129 | c(i,j) = ( c(i,j) || c(i+siz2,j) || c(i,j+siz2) || c(i+siz2,j+siz2) ); 130 | end 131 | end 132 | n(g+1) = sum(sum(c(1:siz:(width-siz+1),1:siz:(width-siz+1)))); 133 | end 134 | 135 | case 3 %------------------- 3D boxcount ---------------------% 136 | 137 | n(p+1) = sum(c(:)); 138 | for g=(p-1):-1:0 139 | siz = 2^(p-g); 140 | siz2 = round(siz/2); 141 | for i=1:siz:(width-siz+1), 142 | for j=1:siz:(width-siz+1), 143 | for k=1:siz:(width-siz+1), 144 | c(i,j,k)=( c(i,j,k) || c(i+siz2,j,k) || c(i,j+siz2,k) ... 145 | || c(i+siz2,j+siz2,k) || c(i,j,k+siz2) || c(i+siz2,j,k+siz2) ... 146 | || c(i,j+siz2,k+siz2) || c(i+siz2,j+siz2,k+siz2)); 147 | end 148 | end 149 | end 150 | n(g+1) = sum(sum(sum(c(1:siz:(width-siz+1),1:siz:(width-siz+1),1:siz:(width-siz+1))))); 151 | end 152 | 153 | end 154 | n = n(end:-1:1); 155 | r = 2.^(0:p); % box size (1, 2, 4, 8...) 
156 | 157 | if any(strncmpi(varargin,'slope',1)) 158 | s=-gradient(log(n))./gradient(log(r)); 159 | semilogx(r, s, 's-'); 160 | ylim([0 dim]); 161 | xlabel('r, box size'); ylabel('- d ln n / d ln r, local dimension'); 162 | title([num2str(dim) 'D box-count']); 163 | elseif nargout==0 || any(strncmpi(varargin,'plot',1)) 164 | loglog(r,n,'s-'); 165 | xlabel('r, box size'); ylabel('n(r), number of boxes'); 166 | title([num2str(dim) 'D box-count']); 167 | end 168 | if nargout==0 169 | clear r n 170 | end 171 | -------------------------------------------------------------------------------- /survival_prediction/matlab/Feature_Extraction/boxcount/demo.m: -------------------------------------------------------------------------------- 1 | %% Computing a fractal dimension with Matlab: 1D, 2D and 3D Box-counting 2 | % F. Moisy, 9 july 2008. 3 | % University Paris Sud. 4 | 5 | %% About Fractals and box-counting 6 | % A set (e.g. an image) is called "fractal" if it displays self-similarity: 7 | % it can be split into parts, each of which is (at least approximately) 8 | % a reduced-size copy of the whole. 9 | % 10 | % A possible characterisation of a fractal set is provided by the 11 | % "box-counting" method: The number N of boxes of size R needed to cover a 12 | % fractal set follows a power-law, N = N0 * R^(-DF), with DF<=D (D is the 13 | % dimension of the space, usually D=1, 2, 3). 14 | % 15 | % DF is known as the Minkowski-Bouligand dimension, or Kolmogorov capacity, 16 | % or Kolmogorov dimension, or simply box-counting dimension. 17 | % 18 | % To learn more about box-counting, fractals and fractal dimensions: 19 | % 20 | % - http://en.wikipedia.org/wiki/Fractal 21 | % 22 | % - http://en.wikipedia.org/wiki/Box_counting_dimension 23 | % 24 | % - http://mathworld.wolfram.com/Fractal.html 25 | % 26 | % - http://mathworld.wolfram.com/CapacityDimension.html 27 | 28 | %% About the 'boxcount' package for Matlab 29 | % The following examples illustrate how to use the Matlab package 30 | % 'boxcount' to compute the fractal dimension of 1D, 2D or 3D sets, using 31 | % the 'box-counting' method. 32 | % 33 | % The directory contains the main function 'boxcount', three sample images, 34 | % and an additional function 'randcantor' to generate 1D, 2D and 3D 35 | % generalized random Cantor sets. 36 | % 37 | % Type 'help boxcount' or 'help randcantor' for more details. 38 | 39 | 40 | %% Box-counting of a 2D image 41 | % Let's start with the image 'dla.gif', a 800x800 logical array (i.e., it 42 | % contains only 0 and 1). It originates from a numerical simulation of a 43 | % "Diffusion Limited Aggregation" process, in which particles move randomly 44 | % until they hit a central seed. 45 | % (see P. Bourke, http://local.wasp.uwa.edu.au/~pbourke/fractals/dla/ ) 46 | 47 | c = imread('dla.gif'); 48 | imagesc(~c) 49 | colormap gray 50 | axis square 51 | 52 | %% 53 | % Calling boxcount without output arguments simply displays N (the number 54 | % of boxes needed to cover the set) as a function of R (the size of the 55 | % boxes). If the set is a fractal, then a power-law N = N0 * R^(-DF) 56 | % should appear, with DF the fractal dimension (Kolmogorov capacity). 
57 | 58 | boxcount(c) 59 | 60 | %% 61 | % The result of the box count can be obtained using: 62 | 63 | [n, r] = boxcount(c) 64 | loglog(r, n,'bo-', r, (r/r(end)).^(-2), 'r--') 65 | xlabel('r') 66 | ylabel('n(r)') 67 | legend('actual box-count','space-filling box-count'); 68 | 69 | %% 70 | % The red dotted line shows the scaling N(R) = R^-2 for comparision, 71 | % expected for a space-filling 2D image. The discrepancy between the two 72 | % curves indicates a possible fractal behaviour. 73 | 74 | 75 | %% Local scaling exponent 76 | % If the set has some fractal properties over a limited range of box size 77 | % R, this may be appreciated by plotting the local exponent, 78 | % D = - d ln N / ln R. For this, use the option 'slope': 79 | 80 | boxcount(c, 'slope') 81 | 82 | %% 83 | % Strictly speaking, the local exponent is not constant, but lies in the 84 | % range [1.6 1.8]. 85 | 86 | %% 87 | % Let's try now with another image, the so-called Apollonian gasket 88 | % (Wikipedia, http://en.wikipedia.org/wiki/Image:Apollonian_gasket.gif ). 89 | % The background level is 198 for this image, so this value is used to 90 | % binarize the image: 91 | 92 | c = imread('Apollonian_gasket.gif'); 93 | c = (c<198); 94 | imagesc(~c) 95 | colormap gray 96 | axis square 97 | figure 98 | boxcount(c) 99 | figure 100 | boxcount(c,'slope') 101 | 102 | %% 103 | % The local slope shows that the image is indeed approximately fractal, 104 | % with a fractal dimension DF = 1.4 +/- 0.1 for scales R < 100. 105 | 106 | 107 | %% Box-counting of a natural image. 108 | % Consider now this RGB (2272x1704) picture of a tree (J.A. Adam, 109 | % http://epod.usra.edu/archive/images/fractal_tree.jpg ): 110 | c = imread('fractal_tree.jpg'); 111 | image(c) 112 | axis image 113 | 114 | %% 115 | % Let's extract a rectangle in the blue (3rd) plane, and binarize the 116 | % image for levels < 80 (white pixels are logical 'true'): 117 | 118 | i = c(1:1200, 120:2150, 3); 119 | bi = (i<80); 120 | imagesc(bi) 121 | colormap gray 122 | axis image 123 | 124 | %% 125 | 126 | [n,r] = boxcount(bi,'slope'); 127 | 128 | %% 129 | % The boxcount shows that the local exponent is approximately constant for 130 | % less than one decade, in the range 8 < R < 128 (the 'exact' value of Df 131 | % depends on the threshold, 80 gray levels here): 132 | 133 | df = -diff(log(n))./diff(log(r)); 134 | disp(['Fractal dimension, Df = ' num2str(mean(df(4:8))) ' +/- ' num2str(std(df(4:8)))]); 135 | 136 | 137 | %% Generalized random Cantor sets 138 | % Fractal sets may be obtained from an IFS (iterated function system). 139 | % For example, the function 'randcantor' provided with the package generates a 1D, 2D or 3D 140 | % generalized random Cantor set. This set is obtained by iteratively 141 | % dividing an initial set filled with 1 into 2^D subsets, and setting each 142 | % subset to 0 with probability P. The result is a fractal set (or "fractal 143 | % dust") of dimension DF = D + log(P)/log(2) < D. 144 | 145 | %% 146 | % The following example generates a 2048x2048 image with probability P=0.8, 147 | % i.e. fractal dimension DF = 1.678. 148 | 149 | c = randcantor(0.8, 2048, 2); 150 | imagesc(~c) 151 | colormap gray 152 | axis image 153 | 154 | %% 155 | % Let's see its box-count and local exponent 156 | 157 | boxcount(c) 158 | figure 159 | boxcount(c, 'slope') 160 | 161 | %% 162 | % For such set generated by an iterated scheme, the local slope shows as 163 | % expected a well defined plateau, around DF = 1.7. 
164 | 165 | %% 1D random Cantor set 166 | % 1D random Cantor sets may also be generated. Here, a 2^18 = 262144 long 167 | % set with P = 0.9 and expected fractal dimension DF = 1 + log(P)/log(2) = 168 | % 0.848: 169 | 170 | tic 171 | c = randcantor(0.9, 2^18, 1, 'show'); 172 | figure 173 | boxcount(c, 'slope'); 174 | toc 175 | 176 | %% 3D random Cantor set 177 | % Now a 3D random Cantor set of size (2^7)^3 = 128^3 with P = 0.7 and 178 | % expected fractal dimension DF = 3 + log(P)/log(2) = 2.485. Note that 179 | % 3D sets cannot be displayed using randcantor. 180 | 181 | tic 182 | c = randcantor(0.7, 2^7, 3); 183 | toc 184 | tic 185 | boxcount(c, 'slope'); 186 | toc 187 | 188 | %% More? 189 | % That's all. To learn more about this package, type: 190 | 191 | help boxcount.m 192 | -------------------------------------------------------------------------------- /survival_prediction/matlab/Feature_Extraction/boxcount/dla.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobarakol/3D_Attention_UNet/6c6ef922b12673d53a8a11e29ad14df36fbb92ed/survival_prediction/matlab/Feature_Extraction/boxcount/dla.gif -------------------------------------------------------------------------------- /survival_prediction/matlab/Feature_Extraction/boxcount/fractal_tree.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobarakol/3D_Attention_UNet/6c6ef922b12673d53a8a11e29ad14df36fbb92ed/survival_prediction/matlab/Feature_Extraction/boxcount/fractal_tree.jpg -------------------------------------------------------------------------------- /survival_prediction/matlab/Feature_Extraction/boxcount/html/Thumbs.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobarakol/3D_Attention_UNet/6c6ef922b12673d53a8a11e29ad14df36fbb92ed/survival_prediction/matlab/Feature_Extraction/boxcount/html/Thumbs.db -------------------------------------------------------------------------------- /survival_prediction/matlab/Feature_Extraction/boxcount/html/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobarakol/3D_Attention_UNet/6c6ef922b12673d53a8a11e29ad14df36fbb92ed/survival_prediction/matlab/Feature_Extraction/boxcount/html/demo.png -------------------------------------------------------------------------------- /survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobarakol/3D_Attention_UNet/6c6ef922b12673d53a8a11e29ad14df36fbb92ed/survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_01.png -------------------------------------------------------------------------------- /survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobarakol/3D_Attention_UNet/6c6ef922b12673d53a8a11e29ad14df36fbb92ed/survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_02.png -------------------------------------------------------------------------------- /survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_03.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mobarakol/3D_Attention_UNet/6c6ef922b12673d53a8a11e29ad14df36fbb92ed/survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_03.png -------------------------------------------------------------------------------- /survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobarakol/3D_Attention_UNet/6c6ef922b12673d53a8a11e29ad14df36fbb92ed/survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_04.png -------------------------------------------------------------------------------- /survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobarakol/3D_Attention_UNet/6c6ef922b12673d53a8a11e29ad14df36fbb92ed/survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_05.png -------------------------------------------------------------------------------- /survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobarakol/3D_Attention_UNet/6c6ef922b12673d53a8a11e29ad14df36fbb92ed/survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_06.png -------------------------------------------------------------------------------- /survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_07.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobarakol/3D_Attention_UNet/6c6ef922b12673d53a8a11e29ad14df36fbb92ed/survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_07.png -------------------------------------------------------------------------------- /survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_08.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobarakol/3D_Attention_UNet/6c6ef922b12673d53a8a11e29ad14df36fbb92ed/survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_08.png -------------------------------------------------------------------------------- /survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_09.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobarakol/3D_Attention_UNet/6c6ef922b12673d53a8a11e29ad14df36fbb92ed/survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_09.png -------------------------------------------------------------------------------- /survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobarakol/3D_Attention_UNet/6c6ef922b12673d53a8a11e29ad14df36fbb92ed/survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_10.png -------------------------------------------------------------------------------- /survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobarakol/3D_Attention_UNet/6c6ef922b12673d53a8a11e29ad14df36fbb92ed/survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_11.png -------------------------------------------------------------------------------- 
/survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobarakol/3D_Attention_UNet/6c6ef922b12673d53a8a11e29ad14df36fbb92ed/survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_12.png -------------------------------------------------------------------------------- /survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobarakol/3D_Attention_UNet/6c6ef922b12673d53a8a11e29ad14df36fbb92ed/survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_13.png -------------------------------------------------------------------------------- /survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobarakol/3D_Attention_UNet/6c6ef922b12673d53a8a11e29ad14df36fbb92ed/survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_14.png -------------------------------------------------------------------------------- /survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobarakol/3D_Attention_UNet/6c6ef922b12673d53a8a11e29ad14df36fbb92ed/survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_15.png -------------------------------------------------------------------------------- /survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobarakol/3D_Attention_UNet/6c6ef922b12673d53a8a11e29ad14df36fbb92ed/survival_prediction/matlab/Feature_Extraction/boxcount/html/demo_16.png -------------------------------------------------------------------------------- /survival_prediction/matlab/Feature_Extraction/boxcount/randcantor.m: -------------------------------------------------------------------------------- 1 | function c = randcantor(p,n,d,varargin) 2 | %RANDCANTOR 1D, 2D or 3D generalized random Cantor set 3 | % C = RANDCANTOR(P, N, D) generates a logical D-dimensional array (with 4 | % D=1, 2, or 3) of size N^D, containing a set of fractally-distributed 1. 5 | % The size N must be a power of 2. 
C is obtained by iteratively dividing 6 | % an initial set filled with 1 into 2^D subsets, multiplying each by 0 7 | % with probability P (with 0 0) & (S < k*k))[0])\n", 208 | "\n", 209 | "\n", 210 | " Z = (Z < threshold)\n", 211 | "\n", 212 | "\n", 213 | " p = min(Z.shape)\n", 214 | "\n", 215 | " n = 2**np.floor(np.log(p)/np.log(2))\n", 216 | "\n", 217 | " n = int(np.log(n)/np.log(2))\n", 218 | "\n", 219 | " sizes = 2**np.arange(n, 1, -1)\n", 220 | "\n", 221 | " counts = []\n", 222 | " for size in sizes:\n", 223 | " counts.append(boxcount(Z, size))\n", 224 | "\n", 225 | " coeffs = np.polyfit(np.log(sizes), np.log(counts), 1)\n", 226 | " return -coeffs[0]\n", 227 | "\n", 228 | " FractalDim = fractal_dimension(img)\n", 229 | " Entropy = skimage.measure.shannon_entropy(img, base=2)\n", 230 | " parameters = []\n", 231 | " parameters.append(Centroid)\n", 232 | " parameters.append(MajorAxisLength)\n", 233 | " parameters.append(MinorAxisLength)\n", 234 | " parameters.append(DiagonalAxis)\n", 235 | " parameters.append(DiagonalPerp)\n", 236 | " parameters.append(Extent)\n", 237 | " parameters.append(Diameter)\n", 238 | " parameters.append(EigenValues)\n", 239 | " parameters.append(Solidity)\n", 240 | " parameters.append(FirstAxis)\n", 241 | " parameters.append(SecondAxis)\n", 242 | " parameters.append(ThirdAxis)\n", 243 | " parameters.append(FirstAxisLength)\n", 244 | " parameters.append(SecondAxisLength)\n", 245 | " parameters.append(ThirdAxisLength)\n", 246 | " parameters.append(kurt)\n", 247 | " parameters.append(histo)\n", 248 | " parameters.append(hemorrhage)\n", 249 | " parameters.append(FractalDim)\n", 250 | " parameters.append(Entropy)\n", 251 | " parameters = np.asarray(parameters)\n", 252 | " np.save(mri_file[-27:-21]+\"_seg.npy\",parameters)" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": null, 258 | "metadata": {}, 259 | "outputs": [], 260 | "source": [ 261 | "#Format\n", 262 | "how2read = []\n", 263 | "how2read.append(\"Centroid,3\")\n", 264 | "how2read.append(\"MajorAxisLength,1\")\n", 265 | "how2read.append()" 266 | ] 267 | } 268 | ], 269 | "metadata": { 270 | "kernelspec": { 271 | "display_name": "Python 2", 272 | "language": "python", 273 | "name": "python2" 274 | }, 275 | "language_info": { 276 | "codemirror_mode": { 277 | "name": "ipython", 278 | "version": 2 279 | }, 280 | "file_extension": ".py", 281 | "mimetype": "text/x-python", 282 | "name": "python", 283 | "nbconvert_exporter": "python", 284 | "pygments_lexer": "ipython2", 285 | "version": "2.7.12" 286 | } 287 | }, 288 | "nbformat": 4, 289 | "nbformat_minor": 2 290 | } 291 | -------------------------------------------------------------------------------- /survival_prediction/python/Regression/GroundTruth.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mobarakol/3D_Attention_UNet/6c6ef922b12673d53a8a11e29ad14df36fbb92ed/survival_prediction/python/Regression/GroundTruth.xlsx -------------------------------------------------------------------------------- /survival_prediction/python/Regression/random forest regression.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "scrolled": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import os\n", 12 | "import numpy as np\n", 13 | "import keras.backend as K\n", 14 | "from keras.wrappers.scikit_learn import KerasRegressor\n", 15 | "from 
sklearn.model_selection import train_test_split\n", 16 | "from sklearn import metrics\n", 17 | "from sklearn.feature_selection import RFE\n", 18 | "from sklearn import preprocessing\n", 19 | "from sklearn.ensemble import RandomForestRegressor\n", 20 | "import matplotlib.pyplot as plt\n", 21 | "import torch\n", 22 | "import torch.nn as nn\n", 23 | "from torch.autograd import Variable\n", 24 | "from sklearn.preprocessing import StandardScaler,Normalizer\n", 25 | "import pandas as pd\n", 26 | "import sklearn.svm as svm\n", 27 | "from imblearn.datasets import make_imbalance\n", 28 | "from imblearn.over_sampling import RandomOverSampler\n", 29 | "from sklearn.metrics import confusion_matrix, r2_score\n", 30 | "from sklearn.metrics import mean_squared_error\n", 31 | "\n", 32 | "def categorize(array):\n", 33 | " #print(array)\n", 34 | " new_array=np.zeros_like(array)\n", 35 | " for i in range(0,array.shape[0]): \n", 36 | " k=array[i]\n", 37 | " if k>0.33:\n", 38 | " new_array[i,:]=1\n", 39 | " else: \n", 40 | " new_array[i,:]=0\n", 41 | " return new_array\n", 42 | "#df = pd.DataFrame(columns=['Parameter','fold1','fold2','fold3','fol4'])\n", 43 | "df = pd.DataFrame(columns=['Parameter','fold1','fold2','fold3','fol4'])\n", 44 | "#for i in range(1,80):\n", 45 | "for i in range(1,25):\n", 46 | "\n", 47 | " print(i)\n", 48 | " MSE = np.array('mse')\n", 49 | " Accuracy = np.array('Acc')\n", 50 | " r2_sc = np.array('r2_score')\n", 51 | " \n", 52 | " for fold in range(1,5):\n", 53 | " features = pd.read_csv('ICHFeatures.csv',header=0)\n", 54 | " OS_train = pd.read_excel(r'GroundTruth.xlsx', sheet_name=\"Fold\"+str(fold)+'_Seg',header = 0, dtype=str)\n", 55 | " OS_train[\"ID\"] = OS_train[\"ID\"].str.zfill(3)\n", 56 | " #OS_train.columns = ['ID','OS']\n", 57 | " OS_valid = pd.read_excel(r'GroundTruth.xlsx', sheet_name=\"Fold\"+str(fold)+'_Val',header = 0, dtype=str)\n", 58 | " OS_valid[\"ID\"] = OS_valid[\"ID\"].str.zfill(3)\n", 59 | " #OS_valid.columns = ['ID','OS']\n", 60 | " features['ID']=features['ID'].str.replace('ct1','')\n", 61 | " train = pd.merge(features, OS_train, how='right', on='ID')\n", 62 | " test = pd.merge(features, OS_valid, how='right', on='ID')\n", 63 | " norm_wihtout = [col for col in train.columns if col not in ['ID','Delta','Class']]\n", 64 | " #norm_valid = [col for col in test.columns if col not in ['ID','GCS','Onset','OS']]\n", 65 | " scaler = StandardScaler()\n", 66 | " train_ss = scaler.fit_transform(train[norm_wihtout])\n", 67 | " test_ss = scaler.transform(test[norm_wihtout])\n", 68 | " train[norm_wihtout] = train_ss\n", 69 | " test[norm_wihtout] = test_ss\n", 70 | " #train = train.assign(norm_train.values = train_ss)\n", 71 | " col_withoutID = [col for col in train.columns if col not in ['ID','Class']]\n", 72 | " ros = RandomOverSampler(random_state=42)\n", 73 | " X_res, y_res = ros.fit_resample(train[col_withoutID], train['Class'].values.astype(float))\n", 74 | " X_withDelta = pd.DataFrame(X_res,columns = col_withoutID)\n", 75 | " train_class = pd.DataFrame(y_res, columns = ['Class'])\n", 76 | " col_withoutDelta = [col for col in X_withDelta.columns if col not in ['Delta']]\n", 77 | " train_X = X_withDelta[col_withoutDelta]\n", 78 | " train_y = X_withDelta[\"Delta\"]\n", 79 | " num_features = i\n", 80 | " estimator = RandomForestRegressor(max_depth=2, random_state=0)\n", 81 | " #print(num_features)\n", 82 | " rfe=RFE(estimator, n_features_to_select=num_features,step=1)\n", 83 | " rfe.fit(train_X,train_y)\n", 84 | " ranking_RFE=rfe.ranking_\n", 85 | " 
indices=np.where(ranking_RFE==1)\n", 86 | " indices = list(indices[0])\n", 87 | " data_RFE=train_X.iloc[:,indices]\n", 88 | " valid_RFE = test[col_withoutID].iloc[:,indices]\n", 89 | " #print(data_RFE.columns)\n", 90 | " model = RandomForestRegressor(max_depth=2, random_state=0)\n", 91 | " model.fit(data_RFE, train_y)\n", 92 | "\n", 93 | " Y_pred=model.predict(valid_RFE).ravel()\n", 94 | " #acc=metrics.accuracy_score(test['Delta'].values,Y_pred)\n", 95 | " #print(\"accuracy score = \"+str(acc)) \n", 96 | " mse = mean_squared_error(test['Delta'].values, Y_pred)\n", 97 | " MSE = np.append(MSE,mse)\n", 98 | " r2_s = r2_score(test['Delta'].values, Y_pred)\n", 99 | " #print(mse)\n", 100 | " #print(r2_s)\n", 101 | " r2_sc = np.append(r2_sc,r2_s)\n", 102 | " # con_matrix = confusion_matrix(test['Delta'].values.tolist(),Y_pred.tolist())\n", 103 | " # TN,FP,FN,TP = con_matrix.ravel()\n", 104 | " # # Sensitivity, hit rate, recall, or true positive rate\n", 105 | " # TPR = TP/(TP+FN)\n", 106 | " # # Specificity or true negative rate\n", 107 | " # TNR = TN/(TN+FP) \n", 108 | " # #Precision\n", 109 | " # PPV = TP/(TP+FP)\n", 110 | " # Prec = np.append(Prec,PPV)\n", 111 | " # Sens = np.append(Sens,TPR)\n", 112 | " # Spec = np.append(Spec,TNR)\n", 113 | " predictions = categorize(Y_pred.reshape(20,1))\n", 114 | " #print(predictions)\n", 115 | " y_test_class = categorize(pd.to_numeric(test['Delta']).values.reshape(20,1))#.reshape(20,1)\n", 116 | " # evaluate predictions\n", 117 | " accuracy = metrics.accuracy_score(y_test_class, predictions)\n", 118 | " #print(\"Accuracy: %.2f%%\" % (accuracy * 100.0))\n", 119 | " Accuracy = np.append(Accuracy,accuracy)\n", 120 | " #best=metrics.mean_squared_error(y_test*1000, y_pred*1000) \n", 121 | " #print(best)\n", 122 | " MSE = pd.DataFrame(data = MSE.reshape(1,5),columns = df.columns)\n", 123 | " ACC = pd.DataFrame(data = Accuracy.reshape(1,5),columns = df.columns)\n", 124 | " R2_Score = pd.DataFrame(data = r2_sc.reshape(1,5),columns = df.columns)\n", 125 | " # Spec = pd.DataFrame(data = Spec.reshape(1,5),columns = df.columns)\n", 126 | " # df = df.append(Accuracy)\n", 127 | " df = df.append(MSE)\n", 128 | " df = df.append(ACC)\n", 129 | " df= df.append(R2_Score)\n", 130 | " #del Accuracy\n", 131 | " print(np.average(Accuracy[1:].astype(np.float)))" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 1, 137 | "metadata": {}, 138 | "outputs": [ 139 | { 140 | "name": "stderr", 141 | "output_type": "stream", 142 | "text": [ 143 | "Using TensorFlow backend.\n" 144 | ] 145 | } 146 | ], 147 | "source": [ 148 | "import os\n", 149 | "import numpy as np\n", 150 | "import keras.backend as K\n", 151 | "from keras.wrappers.scikit_learn import KerasRegressor\n", 152 | "from sklearn.model_selection import train_test_split\n", 153 | "from sklearn import metrics\n", 154 | "from sklearn.feature_selection import RFE\n", 155 | "from sklearn import preprocessing\n", 156 | "from sklearn.ensemble import RandomForestRegressor\n", 157 | "import matplotlib.pyplot as plt\n", 158 | "import torch\n", 159 | "import torch.nn as nn\n", 160 | "from torch.autograd import Variable\n", 161 | "from sklearn.preprocessing import StandardScaler,Normalizer\n", 162 | "import pandas as pd\n", 163 | "import sklearn.svm as svm\n", 164 | "from imblearn.datasets import make_imbalance\n", 165 | "from imblearn.over_sampling import RandomOverSampler\n", 166 | "from sklearn.metrics import confusion_matrix, r2_score\n", 167 | "from sklearn.metrics import mean_squared_error\n", 168 | "\n", 
169 | "def categorize(array):\n", 170 | " #print(array)\n", 171 | " new_array=np.zeros_like(array)\n", 172 | " for i in range(0,array.shape[0]): \n", 173 | " k=array[i]\n", 174 | " if k>0.33:\n", 175 | " new_array[i,:]=1\n", 176 | " else: \n", 177 | " new_array[i,:]=0\n", 178 | " return new_array\n", 179 | "df = pd.DataFrame(columns=['Parameter','fold1','fold2','fold3','fol4'])\n", 180 | "# df = pd.DataFrame(columns=['Parameter','fold1','fold2','fold3','fol4'])\n", 181 | "# #for i in range(1,80):\n", 182 | "# for i in range(1,25):\n", 183 | "\n", 184 | "# print(i)\n", 185 | "MSE = np.array('mse')\n", 186 | "Accuracy = np.array('Acc')\n", 187 | "r2_sc = np.array('r2_score')\n", 188 | "Prec = np.array('Precision')\n", 189 | "Sens = np.array('Sensitivity')\n", 190 | "Spec =np.array('Specificity')\n", 191 | "for fold in range(3,4):\n", 192 | " features = pd.read_csv('ICHFeatures.csv',header=0)\n", 193 | " OS_train = pd.read_excel(r'GroundTruth.xlsx', sheet_name=\"Fold\"+str(fold)+'_Seg',header = 0, dtype=str)\n", 194 | " OS_train[\"ID\"] = OS_train[\"ID\"].str.zfill(3)\n", 195 | " #OS_train.columns = ['ID','OS']\n", 196 | " OS_valid = pd.read_excel(r'GroundTruth.xlsx', sheet_name=\"Fold\"+str(fold)+'_Val',header = 0, dtype=str)\n", 197 | " OS_valid[\"ID\"] = OS_valid[\"ID\"].str.zfill(3)\n", 198 | " #OS_valid.columns = ['ID','OS']\n", 199 | " features['ID']=features['ID'].str.replace('ct1','')\n", 200 | " train = pd.merge(features, OS_train, how='right', on='ID')\n", 201 | " test = pd.merge(features, OS_valid, how='right', on='ID')\n", 202 | " norm_wihtout = [col for col in train.columns if col not in ['ID','Delta','Class']]\n", 203 | " #norm_valid = [col for col in test.columns if col not in ['ID','GCS','Onset','OS']]\n", 204 | " scaler = StandardScaler()\n", 205 | " train_ss = scaler.fit_transform(train[norm_wihtout])\n", 206 | " test_ss = scaler.transform(test[norm_wihtout])\n", 207 | " train[norm_wihtout] = train_ss\n", 208 | " test[norm_wihtout] = test_ss\n", 209 | " #train = train.assign(norm_train.values = train_ss)\n", 210 | " col_withoutID = [col for col in train.columns if col not in ['ID','Class']]\n", 211 | " ros = RandomOverSampler(random_state=42)\n", 212 | " X_res, y_res = ros.fit_resample(train[col_withoutID], train['Class'].values.astype(float))\n", 213 | " X_withDelta = pd.DataFrame(X_res,columns = col_withoutID)\n", 214 | " train_class = pd.DataFrame(y_res, columns = ['Class'])\n", 215 | " col_withoutDelta = [col for col in X_withDelta.columns if col not in ['Delta']]\n", 216 | " train_X = X_withDelta[col_withoutDelta]\n", 217 | " train_y = X_withDelta[\"Delta\"]\n", 218 | " num_features = 6\n", 219 | " estimator = RandomForestRegressor(max_depth=2, random_state=0)\n", 220 | " #print(num_features)\n", 221 | " rfe=RFE(estimator, n_features_to_select=num_features,step=1)\n", 222 | " rfe.fit(train_X,train_y)\n", 223 | " ranking_RFE=rfe.ranking_\n", 224 | " indices=np.where(ranking_RFE==1)\n", 225 | " indices = list(indices[0])\n", 226 | " data_RFE=train_X.iloc[:,indices]\n", 227 | " valid_RFE = test[col_withoutID].iloc[:,indices]\n", 228 | " #print(data_RFE.columns)\n", 229 | " model = RandomForestRegressor(max_depth=2, random_state=0)\n", 230 | " model.fit(data_RFE, train_y)\n", 231 | "\n", 232 | " Y_pred=model.predict(valid_RFE).ravel()\n", 233 | " #acc=metrics.accuracy_score(test['Delta'].values,Y_pred)\n", 234 | " #print(\"accuracy score = \"+str(acc)) \n", 235 | " mse = mean_squared_error(test['Delta'].values, Y_pred)\n", 236 | " MSE = np.append(MSE,mse)\n", 237 | " r2_s = 
r2_score(test['Delta'].values, Y_pred)\n", 238 | " #print(mse)\n", 239 | " #print(r2_s)\n", 240 | " r2_sc = np.append(r2_sc,r2_s)\n", 241 | "\n", 242 | " predictions = categorize(Y_pred.reshape(20,1))\n", 243 | " #print(predictions)\n", 244 | " y_test_class = categorize(pd.to_numeric(test['Delta']).values.reshape(20,1))#.reshape(20,1)\n", 245 | " # evaluate predictions\n", 246 | " accuracy = metrics.accuracy_score(y_test_class, predictions)\n", 247 | " #print(\"Accuracy: %.2f%%\" % (accuracy * 100.0))\n", 248 | " Accuracy = np.append(Accuracy,accuracy)\n", 249 | " #best=metrics.mean_squared_error(y_test*1000, y_pred*1000) \n", 250 | " #print(best)\n", 251 | " con_matrix = confusion_matrix(y_test_class,predictions)\n", 252 | " TN,FP,FN,TP = con_matrix.ravel()\n", 253 | " # Sensitivity, hit rate, recall, or true positive rate\n", 254 | " TPR = TP/(TP+FN)\n", 255 | " # Specificity or true negative rate\n", 256 | " TNR = TN/(TN+FP) \n", 257 | " #Precision\n", 258 | " PPV = TP/(TP+FP)\n", 259 | " Prec = np.append(Prec,PPV)\n", 260 | " Sens = np.append(Sens,TPR)\n", 261 | " Spec = np.append(Spec,TNR)\n", 262 | "# MSE = pd.DataFrame(data = MSE.reshape(1,5),columns = df.columns)\n", 263 | "# ACC = pd.DataFrame(data = Accuracy.reshape(1,5),columns = df.columns)\n", 264 | "# R2_Score = pd.DataFrame(data = r2_sc.reshape(1,5),columns = df.columns)\n", 265 | "# Spec = pd.DataFrame(data = Spec.reshape(1,5),columns = df.columns)\n", 266 | "# Prec = pd.DataFrame(data = Prec.reshape(1,5),columns = df.columns)\n", 267 | "# Sens = pd.DataFrame(data = Sens.reshape(1,5),columns = df.columns)\n", 268 | "# # df = df.append(Accuracy)\n", 269 | "# df = df.append(MSE)\n", 270 | "# df = df.append(ACC)\n", 271 | "# df= df.append(R2_Score)\n", 272 | "# df = df.append(Spec)\n", 273 | "# df = df.append(Sens)\n", 274 | "# df = df.append(Prec)\n", 275 | "# #del Accuracy\n", 276 | "# print(np.average(Accuracy[1:].astype(np.float)))" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "df.to_csv('RFR_results.csv')" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": null, 291 | "metadata": {}, 292 | "outputs": [], 293 | "source": [ 294 | "df2 = pd.DataFrame(Y_pred)\n", 295 | "df2.to_csv('rfr_pred.csv')" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 3, 301 | "metadata": {}, 302 | "outputs": [], 303 | "source": [ 304 | "df3 = pd.DataFrame(test['Delta'].values)\n", 305 | "df3.to_csv('gt.csv')" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": null, 311 | "metadata": {}, 312 | "outputs": [], 313 | "source": [ 314 | "\n", 315 | "\n", 316 | "\n", 317 | "\n", 318 | "\n", 319 | "\n", 320 | "\n", 321 | "\n", 322 | "\n", 323 | "\n", 324 | "\n", 325 | "\n", 326 | "\n", 327 | "\n", 328 | "\n", 329 | "\n", 330 | "\n", 331 | "\n", 332 | "\n", 333 | "\n", 334 | "\n" 335 | ] 336 | } 337 | ], 338 | "metadata": { 339 | "kernelspec": { 340 | "display_name": "py3", 341 | "language": "python", 342 | "name": "py3" 343 | }, 344 | "language_info": { 345 | "codemirror_mode": { 346 | "name": "ipython", 347 | "version": 3 348 | }, 349 | "file_extension": ".py", 350 | "mimetype": "text/x-python", 351 | "name": "python", 352 | "nbconvert_exporter": "python", 353 | "pygments_lexer": "ipython3", 354 | "version": "3.5.2" 355 | } 356 | }, 357 | "nbformat": 4, 358 | "nbformat_minor": 2 359 | } 360 | -------------------------------------------------------------------------------- 
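The regression notebook above repeats one per-fold recipe for every feature count and every fold: standardise the features with training-fold statistics, oversample by the binary outcome class while keeping the continuous `Delta` target attached, rank features with RFE wrapped around a shallow `RandomForestRegressor`, refit on the selected columns, and score the held-out fold with MSE, R², and the accuracy obtained after thresholding predictions at 0.33 (the `categorize` helper). The following is a minimal, self-contained sketch of that recipe, not a reproduction of it: synthetic data stands in for `ICHFeatures.csv` and `GroundTruth.xlsx`, and a class label derived from the same 0.33 threshold stands in for the notebook's ground-truth `Class` column.

```python
import numpy as np
import pandas as pd
from imblearn.over_sampling import RandomOverSampler
from sklearn.ensemble import RandomForestRegressor
from sklearn.feature_selection import RFE
from sklearn.metrics import accuracy_score, mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
n_train, n_test, n_feat = 80, 20, 24
cols = ["f%d" % i for i in range(n_feat)]

def make_split(n):
    # Synthetic stand-in for the radiomics feature table and the continuous "Delta" target.
    X = pd.DataFrame(rng.normal(size=(n, n_feat)), columns=cols)
    delta = 0.5 * X["f0"] + rng.normal(scale=0.1, size=n)
    return X, delta

X_train, y_train = make_split(n_train)
X_test, y_test = make_split(n_test)

# 1) Standardise using statistics from the training fold only.
scaler = StandardScaler()
X_train[cols] = scaler.fit_transform(X_train[cols])
X_test[cols] = scaler.transform(X_test[cols])

# 2) Oversample rows by the binary class, keeping the continuous Delta column attached
#    so the resampled rows still carry their regression target.
cls = (y_train > 0.33).astype(int)          # stand-in for the ground-truth "Class" column
with_delta = X_train.assign(Delta=y_train.values)
X_res, _ = RandomOverSampler(random_state=42).fit_resample(with_delta, cls)
X_res = pd.DataFrame(X_res, columns=with_delta.columns)
train_y = X_res.pop("Delta")
train_X = X_res

# 3) RFE around a shallow random forest picks a fixed number of features.
rfe = RFE(RandomForestRegressor(max_depth=2, random_state=0),
          n_features_to_select=6, step=1).fit(train_X, train_y)
kept = train_X.columns[rfe.support_]

# 4) Refit on the selected columns and score the held-out fold.
model = RandomForestRegressor(max_depth=2, random_state=0).fit(train_X[kept], train_y)
pred = model.predict(X_test[kept])
print("MSE :", mean_squared_error(y_test, pred))
print("R2  :", r2_score(y_test, pred))
print("Acc :", accuracy_score((y_test > 0.33).astype(int), (pred > 0.33).astype(int)))
```

Fitting the scaler on the training fold and reusing `transform` on the held-out fold keeps the cross-validation estimate honest, and wrapping the resampled output in `pd.DataFrame(..., columns=...)`, as the notebook does, works whether `fit_resample` returns arrays (older imbalanced-learn) or DataFrames (newer releases).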
/survival_prediction/python/base_nn.py: -------------------------------------------------------------------------------- 1 | #System 2 | import numpy as np 3 | import sys 4 | import os 5 | import random 6 | from glob import glob 7 | from skimage import io 8 | from PIL import Image 9 | import random 10 | import SimpleITK as sitk 11 | #Torch 12 | from torch.autograd import Variable 13 | from torch.utils.data import Dataset, DataLoader 14 | import torch.optim as optim 15 | import torch.nn.functional as F 16 | from torch.autograd import Function 17 | import torch 18 | import torch.nn as nn 19 | import torchvision.transforms as standard_transforms 20 | #from torchvision.models import resnet18 21 | import nibabel as nib 22 | from sklearn.metrics import classification_report, confusion_matrix, accuracy_score 23 | 24 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 25 | ckpt_path = 'ckpt' 26 | exp_name = 'lol' 27 | if not os.path.exists(ckpt_path): 28 | os.makedirs(ckpt_path) 29 | if not os.path.exists(os.path.join(ckpt_path, exp_name)): 30 | os.makedirs(os.path.join(ckpt_path, exp_name)) 31 | args = { 32 | 'num_class': 2, 33 | 'num_gpus': 1, 34 | 'start_epoch': 1, 35 | 'num_epoch': 100, 36 | 'batch_size': 8, 37 | 'lr': 0.001, 38 | 'lr_decay': 0.9, 39 | 'weight_decay': 1e-4, 40 | 'momentum': 0.9, 41 | 'snapshot': '', 42 | 'opt': 'adam', 43 | 'crop_size1': 138, 44 | 45 | } 46 | 47 | class HEMDataset(Dataset): 48 | def __init__(self, text_dir): 49 | file_pairs = open(text_dir,'r') 50 | self.img_anno_pairs = file_pairs.readlines() 51 | self.req_file, self.req_tar = [],[] 52 | for i in range(len(self.img_anno_pairs)): 53 | net = self.img_anno_pairs[i][:-1] 54 | self.req_file.append(net[:3]) 55 | self.req_tar.append(net[4]) 56 | 57 | 58 | def __len__(self): 59 | return len(self.req_tar) 60 | 61 | def __getitem__(self, index): 62 | _file_num = self.req_file[index] 63 | _gt = float(self.req_tar[index]) 64 | 65 | req_npy = './Features_Train/'+ str(_file_num) + 'ct1_seg.npy' 66 | _input_arr = np.load(req_npy, allow_pickle=True) 67 | _input = np.array([]) 68 | for i in range(len(_input_arr)): 69 | if i > 18: 70 | _input = np.concatenate((_input, _input_arr[i]), axis=None) 71 | _input = torch.from_numpy(np.array(_input)).float() 72 | _target = torch.from_numpy(np.array(_gt)).long() 73 | 74 | return _input, _target 75 | 76 | class HEMDataset_test(Dataset): 77 | def __init__(self, text_dir): 78 | file_pairs = open(text_dir,'r') 79 | self.img_anno_pairs = file_pairs.readlines() 80 | self.req_file, self.req_tar = [],[] 81 | for i in range(len(self.img_anno_pairs)): 82 | net = self.img_anno_pairs[i][:-1] 83 | self.req_file.append(net[:3]) 84 | self.req_tar.append(net[4]) 85 | 86 | 87 | def __len__(self): 88 | return len(self.req_tar) 89 | 90 | def __getitem__(self, index): 91 | _file_num = self.req_file[index] 92 | _gt = float(self.req_tar[index]) 93 | 94 | req_npy = './Features_Val/'+ str(_file_num) + 'ct1_seg.npy' 95 | _input_arr = np.load(req_npy, allow_pickle=True) 96 | _input = np.array([]) 97 | for i in range(len(_input_arr)): 98 | if i > 18: 99 | _input = np.concatenate((_input, _input_arr[i]), axis=None) 100 | _input = torch.from_numpy(np.array(_input)).float() 101 | _target = torch.from_numpy(np.array(_gt)).long() 102 | 103 | return _input, _target 104 | 105 | class Net(nn.Module): 106 | def __init__(self): 107 | super(Net, self).__init__() 108 | self.fc1 = nn.Linear(4, 2048) 109 | self.fc2 = nn.Linear(2048, 1024) 110 | self.fc3 = nn.Linear(1024, 2) 111 | 112 | def forward(self, x): 113 | x = 
F.relu(self.fc1(x)) 114 | x = F.relu(self.fc2(x)) 115 | x = self.fc3(x) 116 | return x 117 | 118 | if __name__ == '__main__': 119 | 120 | train_file = 'Train_dir.txt' 121 | test_file = 'Val_dir.txt' 122 | train_dataset = HEMDataset(text_dir=train_file) 123 | test_dataset = HEMDataset_test(text_dir=test_file) 124 | train_loader = DataLoader(dataset=train_dataset, batch_size=args['batch_size'], shuffle=True, num_workers=2,drop_last=True) 125 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=2,drop_last=False) 126 | 127 | net = Net().cuda() 128 | criterion = nn.NLLLoss() 129 | optimizer = torch.optim.Adam(net.parameters(), lr=args['lr']) 130 | max_epoch = 50 131 | for epoch in range (max_epoch): 132 | net.train() 133 | for batch_idx, data in enumerate(train_loader): 134 | inputs, labels = data 135 | inputs = Variable(inputs).cuda() 136 | labels = Variable(labels).cuda() 137 | 138 | optimizer.zero_grad() 139 | outputs = net(inputs) 140 | loss = criterion(outputs, labels) 141 | loss.backward() 142 | optimizer.step() 143 | 144 | net.eval() 145 | correct, total = 0, 0 146 | class_pred, class_gt = [], [] 147 | with torch.no_grad(): 148 | for batch_idx, (inputs, targets) in enumerate(test_loader): 149 | inputs, targets = inputs.cuda(), targets.cuda() 150 | inputs, targets = Variable(inputs), Variable(targets) 151 | outputs = net(inputs) 152 | 153 | _, predicted = torch.max(outputs.data, 1) 154 | class_pred.append(predicted.item()) 155 | class_gt.append(targets.item()) 156 | total += targets.size(0) 157 | correct += (predicted == targets).sum().item() 158 | 159 | print('Epoch:', epoch)#, 'Accuracy: %f %%' % (100 * correct / total)) 160 | print(confusion_matrix(np.array(class_pred),np.array(class_gt))) 161 | print(classification_report(np.array(class_pred),np.array(class_gt))) 162 | print(accuracy_score(np.array(class_pred),np.array(class_gt))) 163 | print('') 164 | print('Finished Training') 165 | -------------------------------------------------------------------------------- /survival_prediction/python/check_snap.py: -------------------------------------------------------------------------------- 1 | import sklearn 2 | import shap 3 | from sklearn.model_selection import train_test_split 4 | 5 | # print the JS visualization code to the notebook 6 | shap.initjs() 7 | 8 | # train a SVM classifier 9 | X_train,X_test,Y_train,Y_test = train_test_split(*shap.datasets.iris(), test_size=0.2, random_state=0) 10 | svm = sklearn.svm.SVC(kernel='rbf', probability=True) 11 | svm.fit(X_train, Y_train) 12 | 13 | # use Kernel SHAP to explain test set predictions 14 | explainer = shap.KernelExplainer(svm.predict_proba, X_train, link="logit") 15 | shap_values = explainer.shap_values(X_test, nsamples=100) 16 | 17 | # plot the SHAP values for the Setosa output of the first instance 18 | shap.summary_plot(shap_values, X_train, plot_type="bar") 19 | -------------------------------------------------------------------------------- /survival_prediction/python/shap_box.py: -------------------------------------------------------------------------------- 1 | #System 2 | import numpy as np 3 | import sys 4 | import os 5 | import random 6 | from glob import glob 7 | from skimage import io 8 | from PIL import Image 9 | import random 10 | import SimpleITK as sitk 11 | #Torch 12 | from torch.autograd import Variable 13 | from torch.utils.data import Dataset, DataLoader 14 | import torch.optim as optim 15 | import torch.nn.functional as F 16 | from torch.autograd import Function 17 | import 
torch 18 | import torch.nn as nn 19 | import torchvision.transforms as standard_transforms 20 | #from torchvision.models import resnet18 21 | import nibabel as nib 22 | import matplotlib.pyplot as plt 23 | 24 | from sklearn.metrics import classification_report, confusion_matrix, accuracy_score 25 | from sklearn.feature_selection import RFE 26 | from sklearn.linear_model import LogisticRegression 27 | from sklearn.datasets import make_friedman1 28 | from sklearn.svm import LinearSVC 29 | from sklearn.svm import SVR 30 | 31 | import xgboost 32 | import shap 33 | import sklearn 34 | 35 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 36 | ckpt_path = 'ckpt' 37 | exp_name = 'lol' 38 | if not os.path.exists(ckpt_path): 39 | os.makedirs(ckpt_path) 40 | if not os.path.exists(os.path.join(ckpt_path, exp_name)): 41 | os.makedirs(os.path.join(ckpt_path, exp_name)) 42 | args = { 43 | 'num_class': 2, 44 | 'num_gpus': 1, 45 | 'start_epoch': 1, 46 | 'num_epoch': 100, 47 | 'batch_size': 1, 48 | 'lr': 3, 49 | 'lr_decay': 0.9, 50 | 'weight_decay': 1e-4, 51 | 'momentum': 0.9, 52 | 'snapshot': '', 53 | 'opt': 'adam', 54 | 'crop_size1': 138, 55 | 56 | } 57 | 58 | class HEMDataset(Dataset): 59 | def __init__(self, text_dir): 60 | file_pairs = open(text_dir,'r') 61 | self.img_anno_pairs = file_pairs.readlines() 62 | print(self.img_anno_pairs) 63 | self.req_file, self.req_tar = [],[] 64 | for i in range(len(self.img_anno_pairs)): 65 | net = self.img_anno_pairs[i][:-1] 66 | self.req_file.append(net[:3]) 67 | self.req_tar.append(net[4]) 68 | 69 | def __len__(self): 70 | return len(self.req_tar) 71 | 72 | def __getitem__(self, index): 73 | _file_num = self.req_file[index] 74 | _gt = float(self.req_tar[index]) 75 | 76 | req_npy = './Features_Train/'+ str(_file_num) + 'ct1_seg.npy' 77 | _input_arr = np.load(req_npy, allow_pickle=True) 78 | _input = np.array([]) 79 | for i in range(len(_input_arr)): 80 | _input = np.concatenate((_input, _input_arr[i]), axis=None) 81 | _input = torch.from_numpy(np.array(_input)).float() 82 | _target = torch.from_numpy(np.array(_gt)).long() 83 | 84 | return _input, _target 85 | 86 | class HEMDataset_test(Dataset): 87 | def __init__(self, text_dir): 88 | file_pairs = open(text_dir,'r') 89 | self.img_anno_pairs = file_pairs.readlines() 90 | print(self.img_anno_pairs) 91 | self.req_file, self.req_tar = [],[] 92 | for i in range(len(self.img_anno_pairs)): 93 | net = self.img_anno_pairs[i][:-1] 94 | self.req_file.append(net[:3]) 95 | self.req_tar.append(net[4]) 96 | 97 | def __len__(self): 98 | return len(self.req_tar) 99 | 100 | def __getitem__(self, index): 101 | _file_num = self.req_file[index] 102 | _gt = float(self.req_tar[index]) 103 | 104 | req_npy = './Features_Val/'+ str(_file_num) + 'ct1_seg.npy' 105 | _input_arr = np.load(req_npy, allow_pickle=True) 106 | _input = np.array([]) 107 | for i in range(len(_input_arr)): 108 | _input = np.concatenate((_input, _input_arr[i]), axis=None) 109 | _input = torch.from_numpy(np.array(_input)).float() 110 | _target = torch.from_numpy(np.array(_gt)).long() 111 | 112 | return _input, _target 113 | 114 | class Net(nn.Module): 115 | def __init__(self): 116 | super(Net, self).__init__() 117 | self.fc1 = nn.Linear(4, 2048) 118 | self.fc2 = nn.Linear(2048, 256) 119 | self.fc3 = nn.Linear(256, 2) 120 | self.out_act = nn.LogSoftmax(dim=1) 121 | 122 | def forward(self, x): 123 | x = F.relu(self.fc1(x)) 124 | x = F.relu(self.fc2(x)) 125 | x = self.fc3(x) 126 | x = self.out_act(x) 127 | return x 128 | 129 | 130 | 131 | if __name__ == '__main__': 132 | 133 | 
train_file = 'Train_dir.txt' 134 | test_file = 'Val_dir.txt' 135 | train_dataset = HEMDataset(text_dir=train_file) 136 | test_dataset = HEMDataset_test(text_dir=test_file) 137 | train_loader = DataLoader(dataset=train_dataset, batch_size=args['batch_size'], shuffle=True, num_workers=2,drop_last=True) 138 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=2,drop_last=False) 139 | 140 | max_epoch = 1 141 | X_train, Y_train = [], [] 142 | for epoch in range (max_epoch): 143 | for batch_idx, data in enumerate(train_loader): 144 | inputs, labels = data 145 | inputs, labels = inputs.cpu().numpy(), labels.cpu().numpy() 146 | X_train.append(inputs) 147 | Y_train.append(labels) 148 | print(batch_idx) 149 | print('okay') 150 | X_train, Y_train = np.squeeze(X_train, axis=1), np.squeeze(Y_train, axis=1) 151 | print(X_train.shape, Y_train.shape) 152 | 153 | X_test, Y_test = [], [] 154 | for epoch in range(max_epoch): 155 | for batch_idx, data in enumerate(test_loader): 156 | inputs, labels = data 157 | inputs, labels = inputs.cpu().numpy(), labels.cpu().numpy() 158 | X_test.append(inputs) 159 | Y_test.append(labels) 160 | print('okay') 161 | X_test, Y_test = np.squeeze(X_test, axis=1), np.squeeze(Y_test, axis=1) 162 | print(X_test.shape, Y_test.shape) 163 | 164 | X_train, X_test, Y_train, Y_test = np.array(X_train, dtype='f'), np.array(X_test, dtype='f'), np.array(Y_train, dtype='f'), np.array(Y_test, dtype='f') 165 | print(np.max(X_train),np.max(X_test),np.max(Y_train),np.max(Y_test)) 166 | print(np.where(np.isnan(X_test))) 167 | 168 | svm = sklearn.svm.SVC(kernel='rbf', probability=True) 169 | svm.fit(X_train, Y_train) 170 | 171 | # use Kernel SHAP to explain test set predictions 172 | explainer = shap.KernelExplainer(svm.predict_proba, X_train, link="logit") 173 | shap_values = explainer.shap_values(X_test, nsamples=100) 174 | 175 | # plot the SHAP values for the Setosa output of the first instance 176 | shap.summary_plot(shap_values, X_train, plot_type="bar") 177 | 178 | 179 | -------------------------------------------------------------------------------- /survival_prediction/python/svm_rfe.py: -------------------------------------------------------------------------------- 1 | #System 2 | import numpy as np 3 | import sys 4 | import os 5 | import random 6 | from glob import glob 7 | from skimage import io 8 | from PIL import Image 9 | import random 10 | import SimpleITK as sitk 11 | #Torch 12 | from torch.autograd import Variable 13 | from torch.utils.data import Dataset, DataLoader 14 | import torch.optim as optim 15 | import torch.nn.functional as F 16 | from torch.autograd import Function 17 | import torch 18 | import torch.nn as nn 19 | import torchvision.transforms as standard_transforms 20 | #from torchvision.models import resnet18 21 | import nibabel as nib 22 | from sklearn.metrics import classification_report, confusion_matrix, accuracy_score 23 | 24 | from sklearn.feature_selection import RFE 25 | from sklearn.linear_model import LogisticRegression 26 | from sklearn.datasets import make_friedman1 27 | from sklearn.svm import LinearSVC 28 | from sklearn.svm import SVR 29 | from sklearn.svm import SVC 30 | 31 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 32 | ckpt_path = 'ckpt' 33 | exp_name = 'lol' 34 | if not os.path.exists(ckpt_path): 35 | os.makedirs(ckpt_path) 36 | if not os.path.exists(os.path.join(ckpt_path, exp_name)): 37 | os.makedirs(os.path.join(ckpt_path, exp_name)) 38 | args = { 39 | 'num_class': 2, 40 | 'num_gpus': 1, 41 | 'start_epoch': 
1, 42 | 'num_epoch': 100, 43 | 'batch_size': 1, 44 | 'lr': 0.01, 45 | 'lr_decay': 0.9, 46 | 'weight_decay': 1e-4, 47 | 'momentum': 0.9, 48 | 'snapshot': '', 49 | 'opt': 'adam', 50 | 'crop_size1': 138, 51 | 52 | } 53 | 54 | class HEMDataset(Dataset): 55 | def __init__(self, text_dir): 56 | file_pairs = open(text_dir,'r') 57 | self.img_anno_pairs = file_pairs.readlines() 58 | self.req_file, self.req_tar = [],[] 59 | for i in range(len(self.img_anno_pairs)): 60 | net = self.img_anno_pairs[i][:-1] 61 | self.req_file.append(net[:3]) 62 | self.req_tar.append(net[4]) 63 | 64 | 65 | def __len__(self): 66 | return len(self.req_tar) 67 | 68 | def __getitem__(self, index): 69 | _file_num = self.req_file[index] 70 | _gt = float(self.req_tar[index]) 71 | 72 | req_npy = './Features_Train/'+ str(_file_num) + 'ct1_seg.npy' 73 | _input_arr = np.load(req_npy, allow_pickle=True) 74 | _input = np.array([]) 75 | for i in range(len(_input_arr)): 76 | _input = np.concatenate((_input, _input_arr[i]), axis=None) 77 | _input = torch.from_numpy(np.array(_input)).float() 78 | _target = torch.from_numpy(np.array(_gt)).long() 79 | 80 | return _input, _target 81 | 82 | class HEMDataset_test(Dataset): 83 | def __init__(self, text_dir): 84 | file_pairs = open(text_dir,'r') 85 | self.img_anno_pairs = file_pairs.readlines() 86 | self.req_file, self.req_tar = [],[] 87 | for i in range(len(self.img_anno_pairs)): 88 | net = self.img_anno_pairs[i][:-1] 89 | self.req_file.append(net[:3]) 90 | self.req_tar.append(net[4]) 91 | 92 | 93 | def __len__(self): 94 | return len(self.req_tar) 95 | 96 | def __getitem__(self, index): 97 | _file_num = self.req_file[index] 98 | _gt = float(self.req_tar[index]) 99 | 100 | req_npy = './Features_Val/'+ str(_file_num) + 'ct1_seg.npy' 101 | _input_arr = np.load(req_npy, allow_pickle=True) 102 | _input = np.array([]) 103 | for i in range(len(_input_arr)): 104 | _input = np.concatenate((_input, _input_arr[i]), axis=None) 105 | #print(_input) 106 | _input = torch.from_numpy(np.array(_input)).float() 107 | _target = torch.from_numpy(np.array(_gt)).long() 108 | 109 | return _input, _target 110 | 111 | class Net(nn.Module): 112 | def __init__(self): 113 | super(Net, self).__init__() 114 | self.fc1 = nn.Linear(10, 1024) 115 | self.fc2 = nn.Linear(1024, 128) 116 | self.fc3 = nn.Linear(128, 2) 117 | 118 | def forward(self, x): 119 | x = F.relu(self.fc1(x)) 120 | x = F.relu(self.fc2(x)) 121 | x = self.fc3(x) 122 | return x 123 | 124 | def important_features(fea, idx): 125 | batch = [] 126 | for j in range(len(fea)): 127 | req_inputs = [] 128 | for i in idx[0]: 129 | req_inputs.append(fea[0][i]) 130 | batch.append(req_inputs) 131 | return req_inputs 132 | 133 | if __name__ == '__main__': 134 | 135 | train_file = 'Train_dir.txt' 136 | test_file = 'Val_dir.txt' 137 | train_dataset = HEMDataset(text_dir=train_file) 138 | test_dataset = HEMDataset_test(text_dir=test_file) 139 | rfe_loader = DataLoader(dataset=train_dataset, batch_size=1, shuffle=True, num_workers=2, drop_last=True) 140 | train_loader = DataLoader(dataset=train_dataset, batch_size=args['batch_size'], shuffle=True, num_workers=2,drop_last=True) 141 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=2,drop_last=False) 142 | 143 | max_epoch = 1 144 | X_rfe, Y_rfe = [], [] 145 | for epoch in range (max_epoch): 146 | for batch_idx, data in enumerate(rfe_loader): 147 | inputs, labels = data 148 | inputs, labels = inputs.cpu().numpy(), labels.cpu().numpy() 149 | X_rfe.append(inputs) 150 | Y_rfe.append(labels) 151 | 
152 | X_rfe, Y_rfe = np.squeeze(X_rfe, axis=1), np.squeeze(Y_rfe, axis=1) 153 | rfe_model = SVR(kernel="linear") 154 | rfe = RFE(rfe_model, n_features_to_select=5, step=1) 155 | fit = rfe.fit(X_rfe, Y_rfe) 156 | 157 | rank = fit.ranking_ 158 | req_idx = np.where(rank == 1) 159 | print(fit.ranking_) 160 | print('Finished RFE') 161 | 162 | X_train, Y_train = [], [] 163 | for batch_idx, data in enumerate(train_loader): 164 | inputs, labels = data 165 | req_inputs = important_features(inputs.cpu().numpy(), req_idx) 166 | X_train.append(req_inputs) 167 | Y_train.append(labels.cpu().numpy()) 168 | X_train, Y_train = np.array(X_train), np.squeeze(np.array(Y_train), axis=1) 169 | 170 | X_test, Y_test = [], [] 171 | for batch_idx, data in enumerate(test_loader): 172 | inputs, labels = data 173 | req_inputs = important_features(inputs.cpu().numpy(), req_idx) 174 | X_test.append(req_inputs) 175 | Y_test.append(labels.cpu().numpy()) 176 | X_test, Y_test = np.array(X_test), np.squeeze(np.array(Y_test), axis=1) 177 | 178 | score, count = [], [] 179 | #model = LogisticRegression() 180 | model = SVC(gamma='auto') 181 | model.fit(X_train, Y_train) 182 | score.append(sum(model.predict(X_test) == Y_test)) 183 | count.append(len(Y_test)) 184 | print(model.predict(X_test)) 185 | print(score) 186 | 187 | #print(X_train.shape, Y_train.shape, X_test.shape,Y_test.shape) --------------------------------------------------------------------------------
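svm_rfe.py above ends by ranking features with RFE around a linear-kernel SVR, keeping only the rank-1 features, and training an RBF-kernel SVC on that subset. The same flow can be written directly on NumPy arrays, without the PyTorch `DataLoader` detour or the hand-rolled `important_features` indexing, because `rfe.support_` already gives the column mask. The sketch below is a minimal illustration on synthetic arrays; the real script reads per-case feature vectors from `Features_Train/` and `Features_Val/` instead.

```python
import numpy as np
from sklearn.feature_selection import RFE
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.svm import SVC, SVR

rng = np.random.default_rng(0)
n_train, n_test, n_feat = 60, 20, 30

# Synthetic stand-ins for the flattened .npy feature vectors and their binary labels.
X_train = rng.normal(size=(n_train, n_feat))
y_train = (X_train[:, 0] + X_train[:, 3] > 0).astype(int)
X_test = rng.normal(size=(n_test, n_feat))
y_test = (X_test[:, 0] + X_test[:, 3] > 0).astype(int)

# Rank features with RFE around a linear-kernel SVR and keep the top 5 (ranking_ == 1).
rfe = RFE(SVR(kernel="linear"), n_features_to_select=5, step=1).fit(X_train, y_train)
keep = rfe.support_        # boolean column mask, equivalent to np.where(rfe.ranking_ == 1)

# Fit the final RBF-kernel SVC on the selected columns only and score the held-out set.
clf = SVC(gamma="auto").fit(X_train[:, keep], y_train)
pred = clf.predict(X_test[:, keep])
print(confusion_matrix(y_test, pred))
print("accuracy:", accuracy_score(y_test, pred))
```

Masking with `support_` keeps the training and test matrices selected by exactly the same columns, which is the invariant the per-sample `important_features` loop is trying to maintain.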