├── meta
│   ├── images
│   ├── nets.png
│   └── LGG_HGG.jpg
├── runs
│   └── runs
├── outputs
│   └── checkpoints
├── dataset
│   ├── data
│   │   └── data.md
│   └── datasets.py
├── requirements.txt
├── score_gen
│   ├── scores.png
│   ├── pretrained_mixed_conv.png
│   ├── confmatrix.py
│   ├── confmatrix_gen.py
│   └── heatmap_scores.py
├── .idea
│   ├── vcs.xml
│   ├── misc.xml
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── Classification-of-brain-tumor-using-Spatiotemporal-models.iml
│   └── modules.xml
├── run.py
├── stems
│   ├── 3dconv.py
│   ├── resnet_mixed_conv.py
│   └── resnet2p1.py
├── configs
│   └── config.yaml
├── README.md
└── train.py
/meta/images:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/runs/runs:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/outputs/checkpoints:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/dataset/data/data.md:
--------------------------------------------------------------------------------
1 | root/dataset/data/IXI/ \
2 | root/dataset/data/MICCAI_BraTS_2019_Data/
3 |
--------------------------------------------------------------------------------
/meta/nets.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/farazahmeds/Classification-of-brain-tumor-using-Spatiotemporal-models/HEAD/meta/nets.png
--------------------------------------------------------------------------------
/meta/LGG_HGG.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/farazahmeds/Classification-of-brain-tumor-using-Spatiotemporal-models/HEAD/meta/LGG_HGG.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | hydra-core
4 | wandb
5 | scikit-learn
6 | matplotlib
7 | scipy
8 | seaborn
9 | pandas
10 | torchio
11 |
--------------------------------------------------------------------------------
/score_gen/scores.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/farazahmeds/Classification-of-brain-tumor-using-Spatiotemporal-models/HEAD/score_gen/scores.png
--------------------------------------------------------------------------------
/score_gen/pretrained_mixed_conv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/farazahmeds/Classification-of-brain-tumor-using-Spatiotemporal-models/HEAD/score_gen/pretrained_mixed_conv.png
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
(a) ResNet(2+1)D, (b) ResNet Mixed Convolutions, (c) ResNet 3D

##### Getting started: #####

Execute ```run.py```. You can change the hyperparameters and settings for your experiments either by editing the ```configs/config.yaml``` file or by overriding them on the command line, for example ```python run.py training.batch_size=2```. For a multirun, use the ```-m``` flag, for example ```python run.py -m training.batch_size=2,5,10```, which launches one run per listed value.

##### Dataset: #####

The original implementation of this work used the [BraTS 2019](https://www.med.upenn.edu/cbica/brats2019/data.html) dataset (contrast-enhanced T1 images), which provides the high-grade and low-grade glioma samples, and the [IXI](https://brain-development.org/ixi-dataset/) dataset (T1 images), which provides the healthy brain images. The IXI samples were skull-stripped and resampled to be used as non-pathological images for the classification task.

##### Results: #####
Heatmaps showing the class-wise performance of the classifiers, compared using Precision, Recall, Specificity, and F1-score: (a) LGG, (b) HGG, and (c) Healthy
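
These class-wise scores are one-vs-rest quantities derived from a confusion matrix. The following is a minimal sketch, not the repository's own code, assuming rows are true classes and columns are predicted classes (the `sklearn.metrics.confusion_matrix` convention). It works on raw counts; `score_gen/confmatrix_gen.py` instead computes the scores on the row-normalized matrix, which amounts to assuming equal support for every class. The helper name `per_class_metrics` and the toy counts are hypothetical.

```python
def per_class_metrics(cm, names):
    """One-vs-rest precision, recall, specificity and F1 for each class.

    cm[i][j] counts samples of true class i that were predicted as class j.
    """
    n = len(cm)
    total = sum(sum(row) for row in cm)
    scores = {}
    for k in range(n):
        tp = cm[k][k]
        fn = sum(cm[k]) - tp                       # class k predicted as another class
        fp = sum(cm[i][k] for i in range(n)) - tp  # other classes predicted as class k
        tn = total - tp - fn - fp                  # samples that involve class k nowhere
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        specificity = tn / (tn + fp) if tn + fp else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores[names[k]] = {"precision": precision, "recall": recall,
                            "specificity": specificity, "f1": f1}
    return scores


# Hypothetical toy counts, only to show the call
toy_cm = [[5, 1, 0],
          [2, 6, 1],
          [0, 0, 9]]
for name, s in per_class_metrics(toy_cm, ["LGG", "HGG", "Healthy"]).items():
    print(name, {metric: round(value, 3) for metric, value in s.items()})
```

Specificity is the only score that needs the true negatives, which is why the total count of the matrix is computed once up front.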
Confusion matrix for Pretrained ResNet Mixed Convolution (winning model)
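
A row-normalized confusion matrix shows, for each true class, the fraction of its samples assigned to each predicted class, so its diagonal is the per-class recall. A minimal NumPy sketch follows, reusing the counts hard-coded in `score_gen/confmatrix_gen.py` purely as sample input (they are not claimed to be the matrix in this figure). The row totals need `keepdims=True` (or `[:, np.newaxis]`) so that they broadcast along each row; dividing by the plain 1-D vector of row totals silently divides column *j* by the total of row *j*.

```python
import numpy as np

# Counts from score_gen/confmatrix_gen.py; rows are true classes, columns are predictions
labels = ["LGG", "HGG", "Healthy"]
c = np.array([[399, 1767, 138],
              [306, 7033, 149],
              [61, 148, 6991]])

# Correct: row totals of shape (3, 1) broadcast across every column of their row
row_normalized = c / c.sum(axis=1, keepdims=True)

# Pitfall: totals of shape (3,) broadcast along the last axis, i.e. per column
misnormalized = c / c.sum(axis=1)

for name, recall in zip(labels, np.diag(row_normalized)):
    print(f"{name}: {recall:.2%} of true samples classified correctly")
print("row sums:", row_normalized.sum(axis=1))
```

Checking that every row of the normalized matrix sums to 1 is a cheap way to catch the broadcasting mistake.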
##### Preprint: #####
[Soumick Chatterjee, Faraz Ahmed Nizamani, Andreas Nürnberger, and Oliver Speck, Classification of Brain Tumours in MR Images using Deep Spatiospatial Models](https://arxiv.org/pdf/2105.14071.pdf)

BibTeX:
```
@article{chatterjee2021classification,
  title={Classification of Brain Tumours in MR Images using Deep Spatiospatial Models},
  author={Chatterjee, Soumick and Nizamani, Faraz Ahmed and N{\"u}rnberger, Andreas and Speck, Oliver},
  journal={arXiv preprint arXiv:2105.14071},
  year={2021}
}
```
##### Special thanks to: #####
1) [Soumick Chatterjee](https://github.com/soumickmj)
2) [TorchIO and Fernando Pérez-García](https://github.com/fepegar/torchio)
3) [Tristan Payer](https://github.com/sirtris)

This work was in part conducted within the context of the International Graduate School MEMoRIAL at Otto von Guericke University (OVGU) Magdeburg, Germany, kindly supported by the European Structural and Investment Funds (ESF) under the programme "Sachsen-Anhalt WISSENSCHAFT Internationalisierung" (project no. ZS/2016/08/80646).
--------------------------------------------------------------------------------
/score_gen/confmatrix_gen.py:
--------------------------------------------------------------------------------
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix


labels = ['LGG', 'HGG', 'Healthy']

c = ([[399, 1767, 138],
      [306, 7033, 149],
      [61, 148, 6991]])

# c = ([[854, 417, 1],
#       [269, 3862, 3],
#       [12, 30, 3933]])

c = np.asarray(c)
group_names = ['LGG', 'HGG', 'Healthy']

# Row-normalize: divide each row (true class) by its own total.
# keepdims=True keeps the totals as a column so they broadcast along each row;
# np.float was removed from NumPy, so the builtin float is used instead.
k = c / c.astype(float).sum(axis=1, keepdims=True)
# print (k)

df_cm = pd.DataFrame(k, index=labels, columns=labels)
plt.figure(figsize=(3, 3))
sns.heatmap(df_cm, annot=True, fmt='.2%', cmap='Blues')

plt.show()

#

def plot_confusion_matrix(cm,
                          target_names,
                          title='R(2+1)D without data augmentation',
                          cmap=None,
                          normalize=True):
    """
    given a sklearn confusion matrix (cm), make a nice plot

    Arguments
    ---------
    cm:           confusion matrix from sklearn.metrics.confusion_matrix

    target_names: given classification classes such as [0, 1, 2]
                  the class names, for example: ['high', 'medium', 'low']

    title:        the text to display at the top of the matrix

    cmap:         the gradient of the values displayed from matplotlib.pyplot.cm
                  see http://matplotlib.org/examples/color/colormaps_reference.html
                  plt.get_cmap('jet') or plt.cm.Blues

    normalize:    If False, plot the raw numbers
                  If True, plot the proportions

    Usage
    -----
    plot_confusion_matrix(cm           = cm,                  # confusion matrix created by
                                                              # sklearn.metrics.confusion_matrix
                          normalize    = True,                # show proportions
                          target_names = y_labels_vals,       # list of names of the classes
                          title        = best_estimator_name) # title of graph

    Citation
    --------
    http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

    """
    import matplotlib.pyplot as plt
    import numpy as np
    import itertools

    accuracy = np.trace(cm) / np.sum(cm).astype('float')
    misclass = 1 - accuracy

    if cmap is None:
        cmap = plt.get_cmap('Blues')

    plt.figure(figsize=(7, 5))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation=45)
        plt.yticks(tick_marks, target_names)

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        if normalize:
            plt.text(j, i, "{:0.4f}".format(cm[i, j]),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")
        else:
            plt.text(j, i, "{:,}".format(cm[i, j]),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))
    plt.show()

alpha = ['LGG', 'HGG', 'Healthy']

plot_confusion_matrix(c, target_names=alpha)

# Scores on the row-normalized matrix k (every true class weighted equally)
tp_and_fn = k.sum(1)
tp_and_fp = k.sum(0)
tp = k.diagonal()

precision = tp / tp_and_fp
recall = tp / tp_and_fn
# SPC = TN / (FP + TN)

print(precision)

print(recall)

fscore = 2 * precision * recall / (precision + recall)
print(fscore)

#########################################################
k = pd.DataFrame(k)

FP = k.sum(axis=0) - np.diag(k)
FN = k.sum(axis=1) - np.diag(k)
TP = np.diag(k)
TN = k.values.sum() - (FP + FN + TP)

# Sensitivity, hit rate, recall, or true positive rate
TPR = TP / (TP + FN)
# Specificity or true negative rate
TNR = TN / (TN + FP)
# Precision or positive predictive value
PPV = TP / (TP + FP)
# Negative predictive value
NPV = TN / (TN + FN)
# Fall out or false positive rate
FPR = FP / (FP + TN)
# False negative rate
FNR = FN / (TP + FN)
# False discovery rate
FDR = FP / (TP + FP)

# Overall accuracy
ACC = (TP + TN) / (TP + FP + FN + TN)

print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")
# print (PPV)
#
# print (FNR)


print(TNR)
--------------------------------------------------------------------------------
/score_gen/heatmap_scores.py:
--------------------------------------------------------------------------------
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
# sphinx_gallery_thumbnail_number = 2

# Column labels (metrics); the variable names are carried over from the Matplotlib annotated-heatmap example
farmers = ["Precision", "Recall", "Specificity", "F1", ]
# Row labels (models)
vegetables = ["ResNet 3D", "Pre-trained ResNet 3D", "ResNet(2+1)D", "Pre-trained ResNet(2+1)D", "ResNet Mixed Convolution", "Pre-trained ResNet Mixed Convolution" ]

# Score matrix: one row per model, one column per metric
harvest = np.array([[0.843, 2.4, 2.5, 3.3],
                    [2.4, 0.0, 4.0, 4.3],
                    [1.1, 2.4, 0.8, 3.2],
                    [0.8, 2.4, 2.5, 3.3],
                    [0.8, 2.4, 2.5, 3.3],
                    [0.8, 2.4, 2.5, 3.3]]
                   )


# OLD STUFFFF----------------------------------------
# fig, ax = plt.subplots()
# im = ax.imshow(harvest, cmap="Oranges")
#
# # We want to show all ticks...
# ax.set_xticks(np.arange(len(farmers)))
# ax.set_yticks(np.arange(len(vegetables)))
# # ... and label them with the respective list entries
# ax.set_xticklabels(farmers)
# ax.set_yticklabels(vegetables)
#
# # Rotate the tick labels and set their alignment.
# plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
#          rotation_mode="anchor")
# plt.colorbar()
#
# # Loop over data dimensions and create text annotations.
# for i in range(len(vegetables)):
#     for j in range(len(farmers)):
#         text = ax.text(j, i, harvest[i, j],
#                        ha="center", va="center", color="w")
#
# ax.set_title("Classifier performance for class Low Grade Glioma")
# fig.tight_layout()
# plt.show()


def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw={}, cbarlabel="", **kwargs):
    """
    Create a heatmap from a numpy array and two lists of labels.

    Parameters
    ----------
    data
        A 2D numpy array of shape (N, M).
    row_labels
        A list or array of length N with the labels for the rows.
    col_labels
        A list or array of length M with the labels for the columns.
    ax
        A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If
        not provided, use current axes or create a new one. Optional.
    cbar_kw
        A dictionary with arguments to `matplotlib.Figure.colorbar`. Optional.
    cbarlabel
        The label for the colorbar. Optional.
    **kwargs
        All other arguments are forwarded to `imshow`.
    """

    if not ax:
        ax = plt.gca()

    # Plot the heatmap
    im = ax.imshow(data, **kwargs)

    # Create colorbar
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    # We want to show all ticks...
    ax.set_xticks(np.arange(data.shape[1]))
    ax.set_yticks(np.arange(data.shape[0]))
    # ... and label them with the respective list entries.
    ax.set_xticklabels(col_labels)
    ax.set_yticklabels(row_labels)

    # Let the horizontal axes labeling appear on top.
    ax.tick_params(top=True, bottom=False,
                   labeltop=True, labelbottom=False)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
             rotation_mode="anchor")

    # Turn spines off and create white grid.
    for edge, spine in ax.spines.items():
        spine.set_visible(False)

    ax.set_xticks(np.arange(data.shape[1] + 1) - .5, minor=True)
    ax.set_yticks(np.arange(data.shape[0] + 1) - .5, minor=True)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=0)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im, cbar


def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=["black", "white"],
                     threshold=None, **textkw):
    """
    A function to annotate a heatmap.

    Parameters
    ----------
    im
        The AxesImage to be labeled.
    data
        Data used to annotate. If None, the image's data is used. Optional.
    valfmt
        The format of the annotations inside the heatmap. This should either
        use the string format method, e.g. "$ {x:.2f}", or be a
        `matplotlib.ticker.Formatter`. Optional.
    textcolors
        A list or array of two color specifications. The first is used for
        values below a threshold, the second for those above. Optional.
    threshold
        Value in data units according to which the colors from textcolors are
        applied. If None (the default) uses the middle of the colormap as
        separation. Optional.
    **textkw
        All other arguments are forwarded to each call to `text` used to create
        the text labels.
    """

    if not isinstance(data, (list, np.ndarray)):
        data = im.get_array()

    # Normalize the threshold to the image's color range.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max()) / 2.

    # Set default alignment to center, but allow it to be
    # overwritten by textkw.
    kw = dict(horizontalalignment="center",
              verticalalignment="center")
    kw.update(textkw)

    # Get the formatter in case a string is supplied
    if isinstance(valfmt, str):
        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

    # Loop over the data and create a `Text` for each "pixel".
    # Change the text's color depending on the data.
    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
            text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
            texts.append(text)

    return texts


fig, ax = plt.subplots()

im, cbar = heatmap(harvest, vegetables, farmers, ax=ax,
                   cmap="Oranges")
texts = annotate_heatmap(im, valfmt="{x:.4f}")

fig.tight_layout()
plt.title("")
plt.show()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import hydra
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate

import torch
import torch.nn as nn
import torch.optim as optim
from torch import utils
from torch.utils.data import dataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.optim import lr_scheduler
import torch.multiprocessing
import torchvision
from torchvision import datasets, models, transforms

import torchio

from torchio.transforms import (
    CropOrPad,
    OneOf,
    RescaleIntensity,
    RandomAffine,
    RandomElasticDeformation,
    RandomFlip,
    Compose,
)

import matplotlib.pyplot as plt
from scipy.ndimage import zoom

import os
from pathlib import Path

from sklearn.metrics import confusion_matrix
from score_gen.confmatrix import plot_confusion_matrix

from dataset.datasets import Datasets
import wandb


def train(config: DictConfig) -> None:

    wandb.init(project=config.wandb_logger.project, entity=config.wandb_logger.entity, group=config.wandb_logger.group)

    cwd = os.getcwd()

    volumes = hydra.utils.instantiate(config.dataset_)

    total_samples = volumes.return_total_samples()

    torch.manual_seed(config.global_seed)

    load_model = config.load_model

    # Loss weights compensating for class imbalance, in class order [LGG, HGG, Healthy]
    class_weights = torch.FloatTensor([3.54, 1, 1]).cuda()

    # Transforms

    # Note: the first positional argument of TorchIO's RescaleIntensity is out_min_max (the output intensity range)
    rescale = RescaleIntensity((0.05, 99.5))
    randaffine = torchio.RandomAffine(scales=(0.9, 1.2), degrees=10, isotropic=True, image_interpolation='nearest')
    flip = torchio.RandomFlip(axes=('LR',), p=0.5)
    transforms = [rescale, flip, randaffine]

    transform = Compose(transforms)

    subjects_dataset = torchio.SubjectsDataset(total_samples, transform=transform)

    # 70/30 train/test split
    train_set_samples = int(len(total_samples) - 0.3 * len(total_samples))
    test_set_samples = len(total_samples) - train_set_samples

    trainset, testset = torch.utils.data.random_split(subjects_dataset, [train_set_samples, test_set_samples], generator=torch.Generator().manual_seed(config.dataset.train_test_split_seed))

    trainloader = DataLoader(dataset=trainset, batch_size=config.training.batch_size, shuffle=True)
    # Evaluation does not need shuffling
    testloader = DataLoader(dataset=testset, batch_size=config.training.batch_size, shuffle=False)

    # Instantiate the overridden classes:

    if config.models.model == 'resnet2p1':
        model = torchvision.models.video.r2plus1d_18(pretrained=config.pretrain)
        model.stem = hydra.utils.instantiate(config.resnet2p1Stem)
    elif config.models.model == 'resnet_mixed_conv':
        model = torchvision.models.video.mc3_18(pretrained=config.pretrain)
        model.stem = hydra.utils.instantiate(config.resnet_mixed_convStem)
    else:
        model = torchvision.models.video.r3d_18(pretrained=config.pretrain)
        model.stem = hydra.utils.instantiate(config.conv3dStem)

    # Regularization: dropout in front of the new classification head

    model.fc = nn.Sequential(
        nn.Dropout(config.training.dropout),
        nn.Linear(model.fc.in_features, config.training.num_classes)
    )

    model.to('cuda:0')

    criterion = nn.CrossEntropyLoss(weight=class_weights)

    optimizer = torch.optim.Adam(model.parameters(), lr=config.training.learning_rate, weight_decay=config.training.weight_decay)

    # Initialize the prediction and label lists (tensors) for the confusion matrix
    predlist = torch.zeros(0, dtype=torch.long).to('cuda:0')
    lbllist = torch.zeros(0, dtype=torch.long).to('cuda:0')

    # Save to a file inside outputs/checkpoints rather than to the outputs directory path itself
    checkpoint_path = Path(cwd, 'outputs', 'checkpoints', 'checkpoint.pth')

    if load_model:
        # Restore the model and optimizer state saved by a previous run
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    scale = config.dataset.img_scale_factor

    for epoch in range(config.training.num_epoch):

        logs = {}
        total_correct = 0
        total_loss = 0
        total_images = 0
        total_val_loss = 0

        if epoch % 5 == 0:
            checkpoint = {'epoch': epoch + 1, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
            print("Saving checkpoint")
            checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(checkpoint, checkpoint_path)

        # else:
        #     checkpoint = {'epoch': epoch + 1, 'state_dict': model.state_dict(),
        #                   'optimizer': optimizer.state_dict()}
        #     print("Load False: saving checkpoint")
        #     save_checkpoint(checkpoint)

        model.train()  # enable dropout while training

        for i, traindata in enumerate(trainloader):
            images = F.interpolate(traindata['t1'][torchio.DATA], scale_factor=(scale, scale, scale)).to('cuda:0')
            labels = traindata['label'].to('cuda:0')
            optimizer.zero_grad()

            # Forward propagation
            outputs = model(images)

            loss = criterion(outputs, labels)

            # Backward propagation
            loss.backward()

            # Update the weights
            optimizer.step()
            # scheduler.step()

            # Total number of labels
            total_images += labels.size(0)

            # Obtain predictions from the max logit
            _, predicted = torch.max(outputs.data, 1)

            # Calculate the number of correct answers
            correct = (predicted == labels).sum().item()

            total_correct += correct
            # loss is a batch mean, so weight it by the batch size before dividing by total_images
            total_loss += loss.item() * labels.size(0)

            running_trainacc = ((total_correct / total_images) * 100)

            logs['log loss'] = total_loss / total_images
            logs['Accuracy'] = ((total_correct / total_images) * 100)
            wandb.log({'training accuracy': running_trainacc})

            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%'
                  .format(epoch + 1, config.training.num_epoch, i + 1, len(trainloader), (total_loss / total_images),
                          (total_correct / total_images) * 100))

        # Testing the model
        model.eval()  # disable dropout during evaluation

        with torch.no_grad():
            correct = 0
            total = 0
            # Keep only this epoch's predictions for the confusion matrix
            predlist = torch.zeros(0, dtype=torch.long).to('cuda:0')
            lbllist = torch.zeros(0, dtype=torch.long).to('cuda:0')

            for testdata in testloader:

                images = F.interpolate(testdata['t1'][torchio.DATA], scale_factor=(scale, scale, scale)).to('cuda:0')

                labels = testdata['label'].to('cuda:0')
                outputs = model(images)

                _, predicted = torch.max(outputs.data, 1)

                predlist = torch.cat([predlist, predicted.view(-1)])  # Append batch prediction results

                lbllist = torch.cat([lbllist, labels.view(-1)])

                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                # Validation loss on the test batch (the original reused the last training loss here)
                total_val_loss += criterion(outputs, labels).item() * labels.size(0)

            accuracy = correct / total

            print('Test Accuracy of the model: {} %'.format(100 * correct / total))

            logs['val_' + 'log loss'] = total_val_loss / total
            validationloss = total_val_loss / total

            validationacc = ((correct / total) * 100)
            logs['val_' + 'Accuracy'] = ((correct / total) * 100)

            wandb.log({'test accuracy': validationacc, 'val loss': validationloss})

    # Computing metrics:

    conf_mat = confusion_matrix(lbllist.cpu().numpy(), predlist.cpu().numpy())

    print(conf_mat)
    cls = ["lower grade glioma (LGG)", "Glioblastoma (GBM/high grade glioma)", "Normal Brain"]
    # Per-class accuracy (recall)
    class_accuracy = 100 * conf_mat.diagonal() / conf_mat.sum(1)
    print(class_accuracy)
    plt.figure(figsize=(10, 10))
    plot_confusion_matrix(conf_mat, cls)
    plt.show()
--------------------------------------------------------------------------------
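
`train.py` passes fixed class weights `[3.54, 1, 1]` (for `[LGG, HGG, Healthy]`) to `nn.CrossEntropyLoss`. With the default `reduction='mean'`, PyTorch documents the weighted loss as the sum of the per-sample weighted negative log-likelihoods divided by the sum of the weights of the target classes, not by the batch size. The pure-Python sketch below reproduces that definition for illustration only; `weighted_cross_entropy` is a hypothetical helper and the logits are made up.

```python
import math


def weighted_cross_entropy(logits, targets, weights):
    """Weighted mean cross-entropy, normalized as reduction='mean' does."""
    weighted_nll = 0.0
    weight_total = 0.0
    for row, y in zip(logits, targets):
        # Numerically stable log-sum-exp over the class scores
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        nll = log_sum_exp - row[y]           # -log softmax(row)[y]
        weighted_nll += weights[y] * nll
        weight_total += weights[y]
    return weighted_nll / weight_total      # divide by the summed target weights


class_weights = [3.54, 1.0, 1.0]            # [LGG, HGG, Healthy], as in train.py
logits = [[2.0, 0.5, -1.0],                 # hypothetical scores for an LGG sample
          [0.1, 1.5, 0.2]]                  # hypothetical scores for an HGG sample
targets = [0, 1]
print(weighted_cross_entropy(logits, targets, class_weights))
```

Because of this normalization, a batch loss is a weighted average in which each LGG sample counts 3.54 times as much as an HGG or Healthy sample.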
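
Since `loss.item()` is already a per-batch mean, summing it over batches and then dividing by the number of images, as `total_loss / total_images` does when the batch losses are accumulated unscaled, shrinks the logged value by roughly the batch size. Weighting each batch mean by its batch size before dividing recovers the per-sample mean, even when the last batch is smaller. A minimal sketch with a hypothetical `epoch_mean` helper; with a class-weighted loss each batch value is itself a weighted mean, so the result is an approximation in that case.

```python
def epoch_mean(batch_means, batch_sizes):
    """Per-sample mean over an epoch, given per-batch means and batch sizes."""
    weighted_sum = sum(m * n for m, n in zip(batch_means, batch_sizes))
    return weighted_sum / sum(batch_sizes)


# Hypothetical per-sample losses split into batches of 2, 2 and 1
batches = [[1.0, 2.0], [3.0, 4.0], [5.0]]
means = [sum(b) / len(b) for b in batches]  # what loss.item() reports per batch
sizes = [len(b) for b in batches]

naive = sum(means) / sum(sizes)             # sum of batch means / number of samples
correct = epoch_mean(means, sizes)
print(naive, correct)
```

Here `correct` equals the plain mean of the five losses, while `naive` is noticeably smaller.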