├── LICENSE ├── README.md ├── lib ├── __init__.py ├── dataloader │ ├── MRIDataloader.py │ ├── __init__.py │ ├── df_reader.py │ └── image_processing.py ├── model │ ├── DuoAttention.py │ ├── __init__.py │ ├── attention_block.py │ └── create_model.py ├── training │ ├── __init__.py │ ├── train.py │ └── train_helper.py └── utils │ ├── EarlyStopping.py │ ├── __init__.py │ └── utils.py ├── photo ├── model.png └── test_performance.png └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 giaminhgist 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A reproducible 3D convolutional neural network with dual attention module (3D-DAM) for Alzheimer's disease classification 2 | 3 | The paper describing this work is available on arXiv [here](https://doi.org/10.48550/arXiv.2310.12574) 4 | 5 | ## Abstract 6 | 7 | Alzheimer's disease is one of the most common types of neurodegenerative disease, characterized by the accumulation of amyloid-beta plaques and tau tangles. Recently, deep learning approaches have shown promise in Alzheimer's disease diagnosis. In this study, we propose a reproducible model that utilizes a 3D convolutional neural network with a dual attention module for Alzheimer's disease classification. We trained the model on the ADNI database and verified the generalizability of our method on two independent datasets (AIBL and OASIS1). Our method achieved state-of-the-art classification performance, with an accuracy of 91.94% for MCI progression classification and 96.30% for Alzheimer's disease classification on the ADNI dataset. Furthermore, the model demonstrated good generalizability, achieving an accuracy of 86.37% on the AIBL dataset and 83.42% on the OASIS1 dataset. These results indicate that our proposed approach has competitive performance and generalizability when compared to recent studies in the field. 
# ==============================================================================
# /README.md (continued, lines 8-35)
# ==============================================================================
#
# ## Model Architecture
# ![model architecture](https://github.com/giaminhgist/3D-DAM/blob/main/photo/model.png)
#
# ## Main Results
#
# ### ADNI - AIBL - OASIS
#
# | Training | Test | Accuracy(%) | Sensitivity(%) | Specificity(%) |
# |-------------|----------|-----------|--------|----------|
# | ADNI | AIBL | 86.3 | 80.2 | 87.1 |
# | ADNI | OASIS | 83.4 | 85.8 | 82.6 |
# | AIBL - OASIS | ADNI | 85.4 | 80.1 | 89.5 |
#
# ![Test Performance](https://github.com/giaminhgist/3D-DAM/blob/main/photo/test_performance.png)
#
# ## Citation
#
# If you find this project useful for your research, please use the following BibTeX entry.
#
#     @misc{vu2024reproducible,
#       title={A reproducible 3D convolutional neural network with dual attention module (3D-DAM) for Alzheimer's disease classification},
#       year={2024},
#       eprint={2310.12574},
#       archivePrefix={arXiv},
#       primaryClass={eess.IV}
#     }

# ==============================================================================
# /lib/__init__.py  (empty)
# ==============================================================================

# ==============================================================================
# /lib/dataloader/MRIDataloader.py
# ==============================================================================
import os

import numpy as np
from torch.utils.data import Dataset, DataLoader
import nibabel as nib
from .image_processing import normalise_zero_one, reshape_zero_padding

# Class names per supported task. The position of a name in its list is the
# integer label returned by MRIDataset.__getitem__.
TASK_CLASSES = {
    'AD_CN': ['AD', 'CN'],
}


class MRIDataset(Dataset):
    """Dataset of NIfTI MRI volumes labelled with a diagnosis string.

    Args:
        image_paths: Sequence of paths to NIfTI files (as built by ``df_reader``).
        label_dict: Maps an image file name (the last path component) to its
            diagnosis string, which must be one of the task's class names.
        feature_dict: Maps the same file name to clinical features. It is
            stored on the instance but not returned by ``__getitem__``; may be
            ``None``.
        task: Key into ``TASK_CLASSES`` selecting the class names.

    Raises:
        ValueError: If ``task`` is not a key of ``TASK_CLASSES`` (previously an
            unsupported task crashed with an unbound ``classes`` variable).
    """

    def __init__(self, image_paths, label_dict, feature_dict=None,
                 task='AD_CN',
                 ):
        if task not in TASK_CLASSES:
            raise ValueError(
                f'Unsupported task {task!r}; expected one of {sorted(TASK_CLASSES)}')

        self.image_paths = image_paths
        self.label_dict = label_dict
        self.feature_dict = feature_dict

        classes = TASK_CLASSES[task]
        self.idx_to_class = dict(enumerate(classes))
        self.class_to_idx = {name: i for i, name in self.idx_to_class.items()}

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is the volume loaded with nibabel and min-max scaled to
        [0, 1] as float32; ``label`` is the integer class index.

        Raises:
            KeyError: If the file name has no entry in ``label_dict`` or its
                diagnosis is not one of the task's class names.
        """
        image_filepath = self.image_paths[idx]
        image_id = os.path.basename(image_filepath)
        label = self.class_to_idx[self.label_dict[image_id]]

        image = normalise_zero_one(nib.load(image_filepath).get_fdata())
        # NOTE(review): zero-padding and channel expansion are intentionally not
        # applied here; Duo_Attention.forward inserts the channel dimension.
        return image, label


# ==============================================================================
# /lib/dataloader/__init__.py  (empty)
# ==============================================================================

# ==============================================================================
# /lib/dataloader/df_reader.py
# ==============================================================================
import pandas as pd
import os
import numpy as np

# File-name suffix shared by every preprocessed T1w image.
IMAGE_SUFFIX = 'T1w_space-MNI152NLin2009cSym_desc-Crop_res-1x1x1_T1w.nii.gz'
# Clinical columns collected into feature_dict, in this order.
FEATURE_COLUMNS = ['MMSE', 'CDR', 'APOE1', 'APOE2', 'AGE']


def df_reader(df_path, process_path='/media/tedi/Elements/ADNI_Database/Images/PROCESS/subjects/'):
    """Read a subject/session spreadsheet and locate the matching MRI files.

    Each row needs ``participant_id``, ``session_id`` and ``diagnosis``
    columns. A row is kept only if its image exists at
    ``<process_path>/<participant_id>/<session_id>/t1_linear/<file name>``.

    Args:
        df_path: Excel file read with ``pandas.read_excel``.
        process_path: Root directory of the preprocessed subjects.

    Returns:
        ``(image_path, label_dict, feature_dict)``: the list of existing image
        paths; a dict from image file name to diagnosis; and a dict from image
        file name to a NumPy array of the ``FEATURE_COLUMNS`` values, or
        ``None`` when the spreadsheet lacks any of those columns.
    """
    image_path = []
    label_dict = {}
    feature_dict = {}
    df = pd.read_excel(df_path)
    # Checked once up front instead of catching an exception on every row.
    has_features = all(col in df.columns for col in FEATURE_COLUMNS)

    for _, row in df.iterrows():
        # str() so numeric ids do not break the file-name construction.
        participant_id = str(row['participant_id'])
        session_id = str(row['session_id'])
        diagnosis = row['diagnosis']

        idx = f'{participant_id}_{session_id}_{IMAGE_SUFFIX}'
        img_path = os.path.join(process_path, participant_id, session_id, 't1_linear', idx)

        if not os.path.isfile(img_path):
            continue
        image_path.append(img_path)
        label_dict[idx] = diagnosis
        feature_dict[idx] = np.array([row[col] for col in FEATURE_COLUMNS]) if has_features else None

    return image_path, label_dict, feature_dict
# ==============================================================================
# /lib/dataloader/image_processing.py
# ==============================================================================
import numpy as np
import torch.nn.functional as F


# Preprocessing function
def normalise_zero_one(image):
    """Min-max scale ``image`` to the [0, 1] range as float32.

    A constant image (max == min) is mapped to all zeros.
    """
    image = image.astype(np.float32)

    minimum = np.min(image)
    maximum = np.max(image)

    if maximum > minimum:
        ret = (image - minimum) / (maximum - minimum)
    else:
        ret = image * 0.
    return ret


def reshape_zero_padding(img, target_shape=224):
    """Centre the first three axes of ``img`` in a zero-filled volume.

    Args:
        img: Array with at least three axes.
        target_shape: One int used for all three axes, or a sequence of three
            ints. With odd padding the extra voxel goes after the image.

    Raises:
        ValueError: If an axis of ``img`` is larger than its target (np.pad
            would otherwise fail with an opaque negative-padding error).
    """
    if np.isscalar(target_shape):
        target_shape = (target_shape,) * 3

    pad_width = []
    for axis, target in enumerate(target_shape):
        total = target - img.shape[axis]
        if total < 0:
            raise ValueError(
                f'Axis {axis} has size {img.shape[axis]}, larger than target {target}')
        before = total // 2
        pad_width.append((before, total - before))
    return np.pad(img, pad_width, 'constant', constant_values=0)


# ==============================================================================
# /lib/model/DuoAttention.py
# ==============================================================================
import numpy as np
import torch
from torch import nn
from lib.model.attention_block import SpatialAttention3D, ChannelAttention3D, residual_block


class DAM(nn.Module):
    """Dual attention module: channel attention, then spatial attention, plus an identity skip."""

    def __init__(self, channels=64):
        super(DAM, self).__init__()

        self.sa = SpatialAttention3D(out_channel=channels)
        self.ca = ChannelAttention3D(in_planes=channels)

    def forward(self, x):
        return self.sa(self.ca(x)) + x


class Duo_Attention(nn.Module):
    """3D CNN with residual blocks and dual attention modules (3D-DAM).

    Args:
        input_size: ``(channels, D, H, W)`` of one input volume. It sizes the
            first convolution and is used to infer the classifier input size.
            ``forward`` always inserts a single channel dimension, so
            ``input_size[0]`` is expected to be 1.
        num_classes: Number of output logits.
        dropout: Dropout probability applied before the classifier.
    """

    def __init__(
            self, input_size=(1, 169, 208, 179), num_classes=3, dropout=0
    ):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(input_size[0], 8, 3, padding=1),
            nn.BatchNorm3d(8),
            nn.ReLU(),

            nn.Conv3d(8, 16, 3, padding=1, stride=2),
            nn.BatchNorm3d(16),
            nn.ReLU(),
            residual_block(channel_size=16),
            nn.MaxPool3d(2, 2),

            nn.Conv3d(16, 32, 3, padding=1, stride=2),
            nn.BatchNorm3d(32),
            nn.ReLU(),
            residual_block(channel_size=32),
            DAM(channels=32),
            nn.MaxPool3d(2, 2),

            nn.Conv3d(32, 64, 3, padding=1, stride=2),
            nn.BatchNorm3d(64),
            nn.ReLU(),
            residual_block(channel_size=64),
            nn.MaxPool3d(2, 2),
            DAM(channels=64),

            nn.AvgPool3d(3, stride=1),
        )

        # Infer the flattened feature size with a dummy forward pass. It runs
        # in eval mode and without autograd. Otherwise the all-zero input would
        # update the BatchNorm running statistics, and an activation graph
        # would be kept for a full-size volume.
        was_training = self.conv.training
        self.conv.eval()
        with torch.no_grad():
            output_convolutions = self.conv(torch.zeros(input_size).unsqueeze(0))
        self.conv.train(was_training)
        n_features = int(np.prod(output_convolutions.shape[1:]))

        # NOTE(review): the two Linear layers have no non-linearity between
        # them, so together they form a single affine map. Kept as-is so that
        # existing state_dict keys (fc.2, fc.3) still load.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(p=dropout),
            nn.Linear(n_features, 1024),
            nn.Linear(1024, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch ``x`` of shape ``(B, D, H, W)``."""
        y = torch.unsqueeze(x, dim=1)  # add the single input channel
        y = self.conv(y)
        y = self.fc(y)
        return y


# ==============================================================================
# /lib/model/__init__.py  (empty)
# ==============================================================================

# ==============================================================================
# /lib/model/attention_block.py
# ==============================================================================
import torch.nn as nn
import torch


class ChannelAttention3D(nn.Module):
    """Channel attention for 3D feature maps ``(B, C, D, H, W)``.

    A shared bottleneck of 1x1x1 convolutions is applied to the global
    average- and max-pooled descriptors. The sigmoid of their sum rescales
    each input channel.

    Args:
        in_planes: Number of input channels C.
        ratio: Bottleneck reduction ratio. The hidden width is
            ``in_planes // ratio``, clamped to at least 1 so that
            ``in_planes < ratio`` does not create a zero-width layer.
    """

    def __init__(self, in_planes=64, ratio=8):
        super(ChannelAttention3D, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.max_pool = nn.AdaptiveMaxPool3d(1)

        hidden = max(1, in_planes // ratio)
        self.fc = nn.Sequential(nn.Conv3d(in_planes, hidden, 1, bias=False),
                                nn.ReLU(),
                                nn.Conv3d(hidden, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        return self.sigmoid(avg_out + max_out) * x


class SpatialAttention3D(nn.Module):
    """Spatial attention gated by channel statistics.

    The channel-wise mean and max (2 channels) are convolved to
    ``out_channel`` channels and passed through a sigmoid. The result
    multiplies the input element-wise, so ``out_channel`` must equal the
    input's channel count (or 1) for the product to broadcast.
    """

    def __init__(self, out_channel=64, kernel_size=3, stride=1, padding=1):
        super(SpatialAttention3D, self).__init__()

        self.conv = nn.Conv3d(2, out_channel,
                              kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        attention = self.sigmoid(self.conv(torch.cat([avg_out, max_out], dim=1)))
        return attention * x


class residual_block(nn.Module):
    """Conv3d(3x3x3, padding 1) -> BatchNorm3d -> ReLU, added to the input; channel count is unchanged."""

    def __init__(self, channel_size=64):
        super(residual_block, self).__init__()

        self.conv = nn.Conv3d(channel_size, channel_size, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm3d(channel_size)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x))) + x


# ==============================================================================
# /lib/model/create_model.py
# ==============================================================================
import os
import numpy as np
from .DuoAttention import Duo_Attention
import torch

model_dict = {
    'DuoAttention': Duo_Attention,
}


def create_model(
        model_name: str,
        num_classes: int,
        pretrained_path: str = None,
        **kwargs,
):
    """Instantiate a registered model and optionally load weights into it.

    Args:
        model_name: Key of ``model_dict``.
        num_classes: Number of output classes passed to the constructor.
        pretrained_path: Optional state_dict file. Plain state dicts and ones
            saved from a ``torch.nn.DataParallel`` wrapper (every key prefixed
            with ``'module.'``) are both accepted.
        **kwargs: Forwarded to the model constructor.

    Returns:
        The unwrapped model. It is moved to the available device only when
        ``pretrained_path`` is given.

    Raises:
        KeyError: If ``model_name`` is not registered in ``model_dict``.
    """
    if model_name not in model_dict:
        raise KeyError(f'Unknown model_name {model_name!r}; available: {sorted(model_dict)}')
    model = model_dict[model_name](
        num_classes=num_classes,
        **kwargs,
    )

    if pretrained_path is not None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        print('Load pretrained...')
        state_dict = torch.load(pretrained_path, map_location=str(device))
        # Checkpoints written from a DataParallel wrapper prefix every key with
        # 'module.'; strip it so they load into the bare model.
        prefix = 'module.'
        if state_dict and all(key.startswith(prefix) for key in state_dict):
            state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
        # The model built above is never wrapped, so load into it directly
        # (the former `model.module.load_state_dict` raised AttributeError).
        model.load_state_dict(state_dict)

    return model


# ==============================================================================
# /lib/training/__init__.py  (empty)
# ==============================================================================

# ==============================================================================
# /lib/training/train.py
# ==============================================================================
import os
import numpy as np
import torch
from torch import nn
from collections import OrderedDict
from lib.utils.utils import AverageMeter, accuracy
from lib.utils.EarlyStopping import EarlyStopping
from lib.training.train_helper import plot_result
from sklearn.metrics import confusion_matrix
import torch.nn.functional as F
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Default root directory for the checkpoints and plots written by train().
DEFAULT_SAVE_DIR = '/media/tedi/Elements/YJ_GM_Project/WEIGHT'


def _synchronize():
    """Wait for queued CUDA work; a no-op on CPU, where torch.cuda.synchronize() would fail."""
    if device.type == 'cuda':
        torch.cuda.synchronize()


def train_one_epoch(
        model,
        loader,
        optimizer,
        epoch_idx: int,
        lr_scheduler=None,
):
    """Run one training epoch.

    Args:
        model: Model producing class logits for a batch of images.
        loader: Iterable of ``(images, target)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        epoch_idx: Epoch number, used only for logging.
        lr_scheduler: Optional scheduler, stepped once at the end of the epoch.

    Returns:
        ``OrderedDict`` with ``'loss'`` and ``'Acc'`` (top-1, percent),
        averaged over samples.
    """
    losses_m = AverageMeter()
    acc_m = AverageMeter()
    criterion = nn.CrossEntropyLoss()

    model.train()
    print('Start training epoch: ', epoch_idx)
    for images, target in tqdm(loader):
        images, target = images.to(device), target.to(device)
        target = target.flatten()

        output = model(images)
        loss = criterion(output, target)

        losses_m.update(loss.item(), images.size(0))
        acc1 = accuracy(output, target, topk=(1,))
        acc_m.update(acc1[0].item(), output.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        _synchronize()

    # Current learning rate, logged once per epoch rather than once per batch.
    print(optimizer.param_groups[0]['lr'])

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    metrics = OrderedDict([('loss', losses_m.avg), ('Acc', acc_m.avg)])
    if lr_scheduler is not None:
        lr_scheduler.step()

    return metrics


def validate(model, loader):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        ``OrderedDict`` with ``'loss'`` and ``'Acc'`` (top-1, percent),
        averaged over samples.
    """
    losses_m = AverageMeter()
    acc_m = AverageMeter()
    criterion = nn.CrossEntropyLoss()

    model.eval()

    with torch.no_grad():
        for images, target in loader:
            images, target = images.to(device), target.to(device)
            target = target.flatten()

            output = model(images)

            loss = criterion(output, target)
            acc1 = accuracy(output, target, topk=(1,))

            _synchronize()

            losses_m.update(loss.item(), images.size(0))
            acc_m.update(acc1[0].item(), output.size(0))

    return OrderedDict([('loss', losses_m.avg), ('Acc', acc_m.avg)])


def train(model,
          train_loader,
          val_loader,
          epoch_size=300,
          lr_scheduler=True,
          learning_rate=1e-7, optimizer_setup='Adam', w_decay=1e-7,
          patience=20, save_last=True,
          name='save', fold=0,
          save_dir=DEFAULT_SAVE_DIR,
          ):
    """Train ``model`` with early stopping, saving checkpoints and curves.

    The model is wrapped in ``torch.nn.DataParallel``. Weights are written to
    ``<save_dir>/<name>/Fold<fold>/`` (created if missing):
    ``<name>_<fold>_last.pth`` every epoch when ``save_last``, and
    ``<name>_<fold>_best_loss.pth`` / ``<name>_<fold>_best_acc.pth`` whenever
    the validation loss / accuracy improves. Loss and accuracy plots are
    written to the same directory when training ends. ``EarlyStopping`` also
    saves its own checkpoint at its default path.

    Args:
        model: Unwrapped model to train.
        train_loader, val_loader: Iterables of ``(images, target)`` batches.
        epoch_size: Maximum number of epochs.
        lr_scheduler: If true, use ``StepLR(step_size=100, gamma=0.1)``.
        learning_rate, w_decay: Optimizer learning rate and weight decay.
        optimizer_setup: ``'Adam'``; any other value selects SGD.
        patience: Early-stopping patience on validation loss.
        save_last: Save the latest weights after every epoch.
        name, fold: Experiment name and fold index used in output paths.
        save_dir: Root output directory.

    Returns:
        Dict with per-epoch lists ``'train_loss'``, ``'train_acc'``,
        ``'val_loss'`` and ``'val_acc'``.
    """
    seed = 42
    torch.manual_seed(seed)
    np.random.seed(seed)
    print('Training using:', device)
    model = torch.nn.DataParallel(model)
    model.to(device)

    if optimizer_setup == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=w_decay)
    else:  # 'SGD' and any unrecognised value
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=w_decay)

    # torch.save does not create directories, so make sure the fold directory exists.
    fold_dir = os.path.join(save_dir, name, f'Fold{fold}')
    os.makedirs(fold_dir, exist_ok=True)
    prefix = os.path.join(fold_dir, f'{name}_{fold}')

    min_valid_loss = np.inf
    max_acc = 0
    highest_val_epoch = 0
    train_acc = []
    train_losses = []
    val_acc = []
    val_losses = []

    if lr_scheduler:
        print('Applied lr_scheduler')
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
    else:
        scheduler = None

    early_stopping = EarlyStopping(patience=patience, verbose=True)

    print('Start Training Process:...')

    for epoch in range(epoch_size):

        train_metrics = train_one_epoch(
            model=model,
            loader=train_loader,
            optimizer=optimizer,
            epoch_idx=epoch + 1,
            lr_scheduler=scheduler,
        )

        eval_metrics = validate(model, val_loader)

        train_acc.append(train_metrics['Acc'])
        train_losses.append(train_metrics['loss'])
        val_acc.append(eval_metrics['Acc'])
        val_losses.append(eval_metrics['loss'])

        if save_last:
            torch.save(model.module.state_dict(), f'{prefix}_last.pth')
        print(f'Epoch {epoch + 1}: Train: {train_metrics}-----Val: {eval_metrics}')

        if min_valid_loss > eval_metrics['loss']:
            print(f'Validation Loss Decreased. \t Saving The Model')
            min_valid_loss = eval_metrics['loss']
            torch.save(model.module.state_dict(), f'{prefix}_best_loss.pth')

        if max_acc < eval_metrics['Acc']:
            print(f'Validation Acc Increased. \t Saving The Model')
            max_acc = eval_metrics['Acc']
            highest_val_epoch = epoch + 1
            torch.save(model.module.state_dict(), f'{prefix}_best_acc.pth')

        early_stopping(eval_metrics['loss'], model)
        if early_stopping.early_stop:
            # The stop epoch and last-improvement epoch are derived from the
            # actual patience; the old `epoch - 9` assumed patience == 10.
            last_improved = epoch + 1 - early_stopping.counter
            print(f'Early stopping at: {epoch + 1} '
                  f'(validation loss last improved at epoch {last_improved})')
            break

    print(f'Highest validation accuracy: {max_acc} at epoch {highest_val_epoch}')
    # Plot whether training stopped early or ran for all epochs.
    plot_result(f'{prefix}__Loss', val_losses, train_losses, type_data='Loss')
    plot_result(f'{prefix}__Acc', val_acc, train_acc, type_data='Accuracy')

    return {
        'train_loss': train_losses,
        'train_acc': train_acc,
        'val_loss': val_losses,
        'val_acc': val_acc,
    }


def test(
        model,
        test_loader,
        output_size
):
    """Evaluate ``model`` on ``test_loader``.

    Args:
        model: Model producing ``output_size`` logits per sample, already on
            ``device``.
        test_loader: Iterable of ``(images, target)`` batches.
        output_size: Number of classes, used for the one-hot targets.

    Returns:
        A tuple ``(prob_1, y_true_2, conf_mat)``:
        - ``prob_1``: CPU float tensor ``(N, output_size)`` of the model's raw
          outputs (logits, not softmax probabilities).
        - ``y_true_2``: one-hot targets ``(N, output_size)``.
        - ``conf_mat``: the sklearn confusion matrix of argmax predictions
          vs. targets.
    """
    y_pred = []
    y_true = []
    outputs = []

    seed = 42
    torch.manual_seed(seed)
    np.random.seed(seed)
    model.eval()
    with torch.no_grad():
        print('Start Testing:...')
        for images, target in test_loader:
            images, target = images.to(device), target.to(device)
            target = target.flatten()

            output = model(images)

            # Keep whole batches on the CPU and concatenate at the end.
            # torch.FloatTensor() cannot build a tensor from a list of
            # multi-element (possibly CUDA) tensors.
            outputs.append(output.detach().float().cpu())
            # Argmax of the logits directly; exp() could overflow to inf and create ties.
            y_pred.extend(output.argmax(dim=1).cpu().numpy())
            y_true.extend(target.cpu().numpy())  # Save Truth

    conf_mat = confusion_matrix(y_true, y_pred)
    y_true_1 = torch.LongTensor(y_true)
    y_true_2 = F.one_hot(y_true_1, num_classes=output_size)
    prob_1 = torch.cat(outputs) if outputs else torch.empty(0, output_size)
    print('Testing has finished.')
    return prob_1, y_true_2, conf_mat
# ==============================================================================
# /lib/training/train_helper.py
# ==============================================================================
from matplotlib import pyplot as plt


def plot_result(title, val_list, train_list, type_data='Loss'):
    """Plot validation and training curves and save them as a PNG.

    Args:
        title: Figure title. It is also the output path prefix: the file is
            written to ``f'{title}_{type_data}.png'``.
        val_list: Validation values (train() records one per epoch).
        train_list: Training values (train() records one per epoch).
        type_data: Y-axis label. ``'Loss'`` fixes the y-range to [0, 3]; any
            other value is treated as accuracy in percent with range [40, 100].
    """
    fig = plt.figure(figsize=(10, 10), dpi=500)
    plt.title(f'{title}')
    plt.plot(val_list, label="val")
    plt.plot(train_list, label="train")
    # train() appends one value per epoch, not per iteration.
    plt.xlabel("epochs")
    plt.ylabel(f'{type_data}')
    if type_data == 'Loss':
        plt.ylim(0, 3)
    else:
        plt.ylim(40, 100)
    plt.legend()
    fig.savefig(f'{title}_{type_data}.png', bbox_inches='tight')
    plt.close(fig)


# ==============================================================================
# /lib/utils/EarlyStopping.py
# ==============================================================================
import numpy as np
import torch


class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf rather than np.Inf: the capitalised alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one validation loss and update the early-stop state.

        A loss of at most ``best_loss - delta`` counts as an improvement (with
        ``delta=0`` an equal loss also counts). An improvement saves a
        checkpoint and resets the counter. Otherwise the counter grows, and
        ``early_stop`` becomes True once it reaches ``patience``.
        """
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


# ==============================================================================
# /lib/utils/__init__.py  (empty)
# ==============================================================================

# ==============================================================================
# /lib/utils/utils.py
# ==============================================================================
import argparse
import ast


class AverageMeter:
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` samples and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k.

    Args:
        output: ``(batch, num_classes)`` scores.
        target: ``(batch,)`` class indices.
        topk: Values of k; k is capped at ``num_classes``.

    Returns:
        A list with one 0-dim tensor per k: the percentage of samples whose
        target is among the top-k predictions.
    """
    maxk = min(max(topk), output.size()[1])
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    return [correct[:min(k, maxk)].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]


class ParseKwargs(argparse.Action):
    """argparse action turning ``key=value`` tokens into a dict on ``namespace.<dest>``.

    Values are parsed with ``ast.literal_eval`` (``64`` -> int, ``0.5`` ->
    float, ``True`` -> bool). Anything that is not a Python literal is kept as
    a string. Only the first ``=`` separates key from value, so values may
    themselves contain ``=``.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        kw = {}
        for item in values:
            key, sep, raw = item.partition('=')
            if not sep:
                parser.error(f'{option_string}: expected key=value, got {item!r}')
            try:
                kw[key] = ast.literal_eval(raw)
            except (ValueError, TypeError, SyntaxError):
                kw[key] = raw  # fallback to string (avoid need to escape on command line)
        setattr(namespace, self.dest, kw)


# ==============================================================================
# /photo/model.png
#   https://raw.githubusercontent.com/giaminhgist/3D-DAM/ee6841f6d3ce5eb68e511d233f29c6bcbe93272e/photo/model.png
# /photo/test_performance.png
#   https://raw.githubusercontent.com/giaminhgist/3D-DAM/ee6841f6d3ce5eb68e511d233f29c6bcbe93272e/photo/test_performance.png
# ==============================================================================

# ==============================================================================
# /train.py
# ==============================================================================
from lib.dataloader.df_reader import df_reader
from lib.dataloader.MRIDataloader import MRIDataset
from lib.training import train
from lib.model.create_model import create_model
from lib.utils.utils import ParseKwargs
from torch.utils.data import DataLoader
import argparse
import torch
import numpy as np

parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('--experiment_name', type=str, default='AD_CN')
parser.add_argument('--task', type=str, default='AD_CN')
parser.add_argument('--fold', type=int, default=0)
parser.add_argument('--train_type', type=str, default='image_level')
# Number of model outputs. The only task MRIDataset supports (AD_CN) has 2 classes.
parser.add_argument('--output_size', type=int, default=2)
parser.add_argument('--learning_rate', type=float, default=1e-4)
parser.add_argument('--w_decay', type=float, default=1e-4)
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--patch_size', type=int, default=32)  # NOTE(review): currently unused
parser.add_argument('--epoch_size', type=int, default=400)
parser.add_argument('--drop_out', type=float, default=0.5)
parser.add_argument('--patience', type=int, default=15)
parser.add_argument('--image_folder', type=str, default='/media/tedi/Elements/ADNI_Database/Images/PROCESS/subjects/')
parser.add_argument('--train_path', type=str, default=None)
parser.add_argument('--val_path', type=str, default=None)
# 'DuoAttention' is the only key registered in lib.model.create_model.model_dict;
# the previous default 'SEModule' raised KeyError.
parser.add_argument('--model_name', type=str, default='DuoAttention')
parser.add_argument('--model_kwargs', nargs='*', default={},
                    action=ParseKwargs)  # example --model_kwargs dropout=0.3
args = parser.parse_args()

# Default location of the per-fold split spreadsheets: <XLS_ROOT>/<train_type>/<task>/.
XLS_ROOT = '/media/tedi/Elements/ADNI_Database/XLS_Files'


def _split_path(split):
    """Default spreadsheet path for the given split ('train' or 'val') of the current fold."""
    return f'{XLS_ROOT}/{args.train_type}/{args.task}/{args.task}_{split}_fold_{args.fold}.xlsx'


if __name__ == '__main__':
    torch.cuda.empty_cache()
    seed = 42
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Training data
    train_path = args.train_path if args.train_path is not None else _split_path('train')
    val_path = args.val_path if args.val_path is not None else _split_path('val')

    # df_reader returns (image paths, labels, clinical features). The old code
    # unpacked only two values and raised ValueError.
    train_image_path, train_label_dict, train_feature_dict = df_reader(train_path, process_path=args.image_folder)
    val_image_path, val_label_dict, val_feature_dict = df_reader(val_path, process_path=args.image_folder)

    train_dataset = MRIDataset(
        train_image_path, train_label_dict, train_feature_dict, task=args.task
    )
    valid_dataset = MRIDataset(
        val_image_path, val_label_dict, val_feature_dict, task=args.task
    )

    print('Number of train files', len(train_dataset))
    print('Number of val files', len(valid_dataset))

    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True
    )

    valid_loader = DataLoader(
        valid_dataset, batch_size=args.batch_size, shuffle=False
    )

    # Model configuration. --drop_out is forwarded as the model's `dropout`
    # (it was previously parsed but ignored). An explicit `dropout` in
    # --model_kwargs takes precedence.
    model_kwargs = {'dropout': args.drop_out, **args.model_kwargs}
    model = create_model(
        model_name=args.model_name,
        num_classes=args.output_size,
        **model_kwargs,
    )
    train.train(
        model=model,
        train_loader=train_loader,
        val_loader=valid_loader,
        epoch_size=args.epoch_size,
        lr_scheduler=True,
        learning_rate=args.learning_rate, optimizer_setup='Adam', w_decay=args.w_decay,
        patience=args.patience, save_last=True,
        name=args.experiment_name, fold=args.fold
    )