├── utils
│   ├── __init__.py
│   ├── util.py
│   └── illustration_util.py
├── logger
│   ├── __init__.py
│   ├── logger.py
│   ├── logger_config.json
│   └── visualization.py
├── base
│   ├── __init__.py
│   ├── base_model.py
│   ├── base_data_loader.py
│   └── base_trainer.py
├── requirements.txt
├── data_loader
│   └── Dataloader.py
├── LICENSE
├── model
│   ├── utils
│   │   ├── metric.py
│   │   ├── metric_utils.py
│   │   ├── loss.py
│   │   └── layers.py
│   ├── DeepAtlas.py
│   ├── MultitaskNetwork.py
│   ├── LongitudinalFCDenseNet.py
│   └── FCDenseNet.py
├── dataset
│   ├── DatasetStatic.py
│   ├── DatasetLongitudinal.py
│   ├── DatasetStaticStacked.py
│   └── dataset_utils.py
├── configs
│   ├── Deep_Atlas.py
│   ├── Static_Network.py
│   ├── Static_Network_Zhang.py
│   ├── Static_Network_Asymmetric.py
│   ├── Longitudinal_Siamese_Network.py
│   ├── Longitudinal_Network.py
│   └── Multitask_Longitudinal_Network.py
├── trainer
│   ├── StaticTrainer.py
│   ├── LongitudinalTrainer.py
│   ├── LongitudinalMultitaskTrainer.py
│   ├── DeepAtlasTrainer.py
│   └── Trainer.py
├── .gitignore
├── train.py
├── parse_config.py
├── README.md
├── test_single_view.py
└── test.py

/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .util import *
--------------------------------------------------------------------------------
/logger/__init__.py:
--------------------------------------------------------------------------------
1 | from .logger import *
2 | from .visualization import *
--------------------------------------------------------------------------------
/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_data_loader import *
2 | from .base_model import *
3 | from .base_trainer import *
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | future==0.18.2
2 | h5py==2.10.0
3 | joblib==0.17.0
4 | nibabel==3.1.1
5 | numpy==1.19.2
6 | opencv-python==4.4.0.44
7 | packaging==20.4
8 | pandas==1.1.2
9 | Pillow==7.2.0
10 | pyparsing==2.4.7
11 | python-dateutil==2.8.1
12 | pytz==2020.1
13 | scikit-learn==0.23.2
14 | scipy==1.5.2
15 | six==1.15.0
16 | threadpoolctl==2.1.0
17 | torch==1.6.0
18 | torchvision==0.7.0
--------------------------------------------------------------------------------
/base/base_model.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | 
3 | import numpy as np
4 | import torch.nn as nn
5 | 
6 | 
7 | class BaseModel(nn.Module):
8 |     """
9 |     Base class for all models
10 |     """
11 | 
12 |     @abstractmethod
13 |     def forward(self, *inputs):
14 |         """
15 |         Forward pass logic
16 | 
17 |         :return: Model output
18 |         """
19 |         raise NotImplementedError
20 | 
21 |     def __str__(self):
22 |         """
23 |         Model prints with number of trainable parameters
24 |         """
25 |         model_parameters = filter(lambda p: p.requires_grad, self.parameters())
26 |         params = sum([np.prod(p.size()) for p in model_parameters])
27 |         return super().__str__() + '\nTrainable parameters: {}'.format(params)
--------------------------------------------------------------------------------
/logger/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import logging.config
3 | from pathlib import Path
4 | 
5 | from utils import read_json
6 | 
7 | 
8 | def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO):
9 |     """
10 |     Setup logging configuration
11 |     """
12 |     
log_config = Path(log_config)
13 |     if log_config.is_file():
14 |         config = read_json(log_config)
15 |         # modify logging paths based on run config
16 |         for _, handler in config['handlers'].items():
17 |             if 'filename' in handler:
18 |                 handler['filename'] = str(save_dir / handler['filename'])
19 | 
20 |         logging.config.dictConfig(config)
21 |     else:
22 |         print("Warning: logging configuration file is not found in {}.".format(log_config))
23 |         logging.basicConfig(level=default_level)
--------------------------------------------------------------------------------
/logger/logger_config.json:
--------------------------------------------------------------------------------
1 | {
2 |     "version": 1,
3 |     "disable_existing_loggers": false,
4 |     "formatters": {
5 |         "simple": {
6 |             "format": "%(message)s"
7 |         },
8 |         "datetime": {
9 |             "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
10 |         }
11 |     },
12 |     "handlers": {
13 |         "console": {
14 |             "class": "logging.StreamHandler",
15 |             "level": "DEBUG",
16 |             "formatter": "simple",
17 |             "stream": "ext://sys.stdout"
18 |         },
19 |         "info_file_handler": {
20 |             "class": "logging.handlers.RotatingFileHandler",
21 |             "level": "INFO",
22 |             "formatter": "datetime",
23 |             "filename": "info.log",
24 |             "maxBytes": 10485760,
25 |             "backupCount": 20,
26 |             "encoding": "utf8"
27 |         }
28 |     },
29 |     "root": {
30 |         "level": "INFO",
31 |         "handlers": [
32 |             "console",
33 |             "info_file_handler"
34 |         ]
35 |     }
36 | }
--------------------------------------------------------------------------------
/data_loader/Dataloader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
2 | from torch.utils.data.dataloader import default_collate
3 | 
4 | 
5 | class Dataloader(DataLoader):
6 |     """
7 |     Data loading
8 |     """
9 | 
10 |     def __init__(self, dataset, batch_size, shuffle=True, num_workers=1):
11 |         self.dataset = dataset
12 | 
13 |         self.batch_idx = 0
14 | 
15 |         if shuffle:
16 |             self.sampler = RandomSampler(self.dataset)
17 |         else:
18 |             self.sampler = SequentialSampler(self.dataset)
19 | 
20 |         # a sampler is always passed to DataLoader, and DataLoader rejects a
21 |         # sampler combined with shuffle=True, so shuffle is switched off here
22 |         self.shuffle = False
23 | 
24 |         self.init_kwargs = {
25 |             'dataset': self.dataset,
26 |             'batch_size': batch_size,
27 |             'shuffle': self.shuffle,
28 |             'collate_fn': default_collate,
29 |             'num_workers': num_workers
30 |         }
31 |         super().__init__(sampler=self.sampler, **self.init_kwargs)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Stefan Denner
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /model/utils/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sklearn.metrics import f1_score 4 | 5 | from model.utils import metric_utils 6 | 7 | 8 | def precision(output, target): 9 | with torch.no_grad(): 10 | epsilon, tp, _, fp, _ = metric_utils.eps_tp_tn_fp_fn(output, target) 11 | return tp / (tp + fp + epsilon) 12 | 13 | 14 | def recall(output, target): 15 | with torch.no_grad(): 16 | epsilon, tp, _, _, fn = metric_utils.eps_tp_tn_fp_fn(output, target) 17 | return tp / (tp + fn + epsilon) 18 | 19 | 20 | def dice_loss(output, target): 21 | with torch.no_grad(): 22 | return metric_utils.asymmetric_loss(1, output, target) 23 | 24 | 25 | def dice_score(output, target): 26 | with torch.no_grad(): 27 | target = metric_utils.flatten(target).cpu().detach().float() 28 | output = metric_utils.flatten(output).cpu().detach().float() 29 | if len(output.shape) == 2: # is one hot encoded vector 30 | target = np.argmax(target, axis=0) 31 | output = np.argmax(output, axis=0) 32 | return f1_score(target, output) 33 | 34 | 35 | def asymmetric_loss(output, target): 36 | with torch.no_grad(): 37 | return metric_utils.asymmetric_loss(2, output, target) 38 | -------------------------------------------------------------------------------- /dataset/DatasetStatic.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import h5py 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.utils.data import Dataset 7 | 8 | from dataset.dataset_utils import Phase, Modalities, Views, Mode, retrieve_data_dir_paths, Evaluate 9 | 10 | 11 | class DatasetStatic(Dataset): 12 | """DatasetStatic dataset""" 13 | 14 | def __init__(self, data_dir, phase=Phase.TRAIN, modalities=(), val_patients=None, evaluate: Evaluate = Evaluate.TRAINING, preprocess=True, 15 | view: Views = None): 16 | self.modalities = list(map(lambda x: Modalities(x), modalities)) 17 | self.data_dir_paths = retrieve_data_dir_paths(data_dir, evaluate, phase, preprocess, val_patients, Mode.STATIC, view) 18 | 19 | def __len__(self): 20 | return len(self.data_dir_paths) 21 | 22 | def __getitem__(self, idx): 23 | data, label = [], None 24 | for i, modality in enumerate(self.modalities): 25 | with h5py.File(os.path.join(self.data_dir_paths[idx], f'{modality.value}.h5'), 'r') as f: 26 | data.append(f['data'][()]) 27 | if label is None: 28 | label = F.one_hot(torch.as_tensor(f['label'][()], dtype=torch.int64), num_classes=2).permute(2, 0, 1) 29 | 30 | return torch.as_tensor(data).float(), label.float() 31 | -------------------------------------------------------------------------------- /model/DeepAtlas.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from base import BaseModel 4 | from model.FCDenseNet import FCDenseNet 5 | from model.utils.layers import SpatialTransformer 6 | 7 | 8 | class DeepAtlas(BaseModel): 9 | def __init__(self, in_channels=8, resolution=(217, 217)): 10 | super().__init__() 11 | self.densenet_seg = FCDenseNet( 12 | in_channels=in_channels, n_classes=1, apply_softmax=False 
13 | ) 14 | self.densenet_voxelmorph = FCDenseNet( 15 | in_channels=in_channels, n_classes=2, apply_softmax=False 16 | ) 17 | self.spatial_transform = SpatialTransformer(resolution) 18 | 19 | def forward(self, input_moving, input_fixed): 20 | x = torch.cat([input_moving, input_fixed], dim=1) 21 | x_ref = torch.cat([input_fixed, input_moving], dim=1) 22 | y_seg_moving = torch.sigmoid(self.densenet_seg(x)) 23 | y_seg_fixed = torch.sigmoid(self.densenet_seg(x_ref)) 24 | flow = self.densenet_voxelmorph(x) 25 | 26 | modalities = torch.unbind(input_moving, dim=1) 27 | 28 | y_deformation = torch.stack( 29 | [ 30 | torch.squeeze( 31 | self.spatial_transform(torch.unsqueeze(modality, 1), flow), 32 | dim=1 33 | ) 34 | for modality in modalities 35 | ], 36 | dim=1 37 | ) 38 | 39 | y_seg_deformation = self.spatial_transform(y_seg_moving, flow) 40 | 41 | return y_seg_moving, y_seg_fixed, y_deformation, y_seg_deformation, flow 42 | -------------------------------------------------------------------------------- /configs/Deep_Atlas.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | CONFIG = { 4 | "name": f"{os.path.basename(__file__).split('.')[0]}", 5 | "n_gpu": 1, 6 | 7 | "arch": { 8 | "type": "DeepAtlas", 9 | "args": { 10 | "in_channels": 2 11 | } 12 | }, 13 | "dataset": { 14 | "type": "DatasetLongitudinal", 15 | "args": { 16 | "data_dir": "../ISBIMSlesionChallenge/", 17 | "preprocess": True, 18 | "modalities": ['flair'], 19 | "val_patients": [4] 20 | } 21 | }, 22 | "data_loader": { 23 | "type": "Dataloader", 24 | "args": { 25 | "batch_size": 4, 26 | "shuffle": True, 27 | "num_workers": 4, 28 | } 29 | }, 30 | "optimizer": { 31 | "type": "Adam", 32 | "args": { 33 | "lr": 0.0001, 34 | "weight_decay": 0, 35 | "amsgrad": True 36 | } 37 | }, 38 | "loss": "deep_atlas_loss", 39 | "metrics": [ 40 | "precision", "recall", "dice_loss", "dice_score", "asymmetric_loss" 41 | ], 42 | "lr_scheduler": { 43 | "type": "StepLR", 44 | "args": { 45 | "step_size": 50, 46 | "gamma": 0.1 47 | } 48 | }, 49 | "trainer": { 50 | "type": "DeepAtlasTrainer", 51 | "epochs": 100, 52 | "save_dir": "../saved/", 53 | "save_period": 1, 54 | "verbosity": 2, 55 | "monitor": "min val_dice_loss", 56 | "early_stop": 10, 57 | "tensorboard": True 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /configs/Static_Network.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | CONFIG = { 4 | "name": f"{os.path.basename(__file__).split('.')[0]}", 5 | "n_gpu": 1, 6 | 7 | "arch": { 8 | "type": "FCDenseNet", 9 | "args": { 10 | "in_channels": 4 # 1 # 4 11 | } 12 | }, 13 | "dataset": { 14 | "type": "DatasetStatic", 15 | "args": { 16 | "data_dir": "../ISBIMSlesionChallenge/", 17 | "preprocess": True, 18 | "modalities": ['flair', 'mprage', 'pd', 't2'], 19 | "val_patients": [4] 20 | } 21 | }, 22 | "data_loader": { 23 | "type": "Dataloader", 24 | "args": { 25 | "batch_size": 2, 26 | "shuffle": True, 27 | "num_workers": 4, 28 | } 29 | }, 30 | "optimizer": { 31 | "type": "Adam", 32 | "args": { 33 | "lr": 0.0001, 34 | "weight_decay": 0, 35 | "amsgrad": True 36 | } 37 | }, 38 | "loss": "mse", 39 | "metrics": [ 40 | "precision", "recall", "dice_loss", "dice_score", "asymmetric_loss" 41 | ], 42 | "lr_scheduler": { 43 | "type": "StepLR", 44 | "args": { 45 | "step_size": 50, 46 | "gamma": 0.1 47 | } 48 | }, 49 | "trainer": { 50 | "type": "StaticTrainer", 51 | "epochs": 100, 52 | "save_dir": "../saved/", 53 | 
"save_period": 1, 54 | "verbosity": 2, 55 | "monitor": "min val_dice_loss", 56 | "early_stop": 10, 57 | "tensorboard": True 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /configs/Static_Network_Zhang.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | CONFIG = { 4 | "name": f"{os.path.basename(__file__).split('.')[0]}", 5 | "n_gpu": 1, 6 | 7 | "arch": { 8 | "type": "FCDenseNet", 9 | "args": { 10 | "in_channels": 12 11 | } 12 | }, 13 | "dataset": { 14 | "type": "DatasetStaticStacked", 15 | "args": { 16 | "data_dir": "../ISBIMSlesionChallenge/", 17 | "preprocess": True, 18 | "modalities": ['flair', 'mprage', 'pd', 't2'], 19 | "val_patients": [4] 20 | } 21 | }, 22 | "data_loader": { 23 | "type": "Dataloader", 24 | "args": { 25 | "batch_size": 2, 26 | "shuffle": True, 27 | "num_workers": 4, 28 | } 29 | }, 30 | "optimizer": { 31 | "type": "Adam", 32 | "args": { 33 | "lr": 0.0001, 34 | "weight_decay": 0, 35 | "amsgrad": True 36 | } 37 | }, 38 | "loss": "mse", 39 | "metrics": [ 40 | "precision", "recall", "dice_loss", "dice_score", "asymmetric_loss" 41 | ], 42 | "lr_scheduler": { 43 | "type": "StepLR", 44 | "args": { 45 | "step_size": 50, 46 | "gamma": 0.1 47 | } 48 | }, 49 | "trainer": { 50 | "type": "StaticTrainer", 51 | "epochs": 100, 52 | "save_dir": "../saved/", 53 | "save_period": 1, 54 | "verbosity": 2, 55 | "monitor": "min val_dice_loss", 56 | "early_stop": 10, 57 | "tensorboard": True 58 | }, 59 | } 60 | -------------------------------------------------------------------------------- /configs/Static_Network_Asymmetric.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | CONFIG = { 4 | "name": f"{os.path.basename(__file__).split('.')[0]}", 5 | "n_gpu": 1, 6 | 7 | "arch": { 8 | "type": "FCDenseNet", 9 | "args": { 10 | "in_channels": 4 # 1 # 4 11 | } 12 | }, 13 | "dataset": { 14 | "type": "DatasetStatic", 15 | "args": { 16 | "data_dir": "../ISBIMSlesionChallenge/", 17 | "preprocess": True, 18 | "modalities": ['flair', 'mprage', 'pd', 't2'], 19 | "val_patients": [4] 20 | } 21 | }, 22 | "data_loader": { 23 | "type": "Dataloader", 24 | "args": { 25 | "batch_size": 2, 26 | "shuffle": True, 27 | "num_workers": 4, 28 | } 29 | }, 30 | "optimizer": { 31 | "type": "Adam", 32 | "args": { 33 | "lr": 0.0001, 34 | "weight_decay": 0, 35 | "amsgrad": True 36 | } 37 | }, 38 | "loss": "asymmetric_loss", 39 | "metrics": [ 40 | "precision", "recall", "dice_loss", "dice_score", "asymmetric_loss" 41 | ], 42 | "lr_scheduler": { 43 | "type": "StepLR", 44 | "args": { 45 | "step_size": 50, 46 | "gamma": 0.1 47 | } 48 | }, 49 | "trainer": { 50 | "type": "StaticTrainer", 51 | "epochs": 100, 52 | "save_dir": "../saved/", 53 | "save_period": 1, 54 | "verbosity": 2, 55 | "monitor": "min val_dice_loss", 56 | "early_stop": 10, 57 | "tensorboard": True 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /configs/Longitudinal_Siamese_Network.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | CONFIG = { 4 | "name": f"{os.path.basename(__file__).split('.')[0]}", 5 | "n_gpu": 1, 6 | 7 | "arch": { 8 | "type": "LongitudinalFCDenseNet", 9 | "args": { 10 | "in_channels": 4 11 | } 12 | }, 13 | "dataset": { 14 | "type": "DatasetLongitudinal", 15 | "args": { 16 | "data_dir": "../ISBIMSlesionChallenge/", 17 | "preprocess": True, 18 | "modalities": ['flair', 'mprage', 'pd', 
't2'], 19 | "val_patients": [4] 20 | } 21 | }, 22 | "data_loader": { 23 | "type": "Dataloader", 24 | "args": { 25 | "batch_size": 2, 26 | "shuffle": True, 27 | "num_workers": 4, 28 | } 29 | }, 30 | "optimizer": { 31 | "type": "Adam", 32 | "args": { 33 | "lr": 0.0001, 34 | "weight_decay": 0, 35 | "amsgrad": True 36 | } 37 | }, 38 | "loss": "mse", 39 | "metrics": [ 40 | "precision", "recall", "dice_loss", "dice_score", "asymmetric_loss" 41 | ], 42 | "lr_scheduler": { 43 | "type": "StepLR", 44 | "args": { 45 | "step_size": 50, 46 | "gamma": 0.1 47 | } 48 | }, 49 | "trainer": { 50 | "type": "LongitudinalTrainer", 51 | "epochs": 100, 52 | "save_dir": "../saved/", 53 | "save_period": 1, 54 | "verbosity": 2, 55 | "monitor": "min val_dice_loss", 56 | "early_stop": 10, 57 | "tensorboard": True 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /trainer/StaticTrainer.py: -------------------------------------------------------------------------------- 1 | from logger import Mode 2 | from trainer.Trainer import Trainer 3 | from utils.illustration_util import log_visualizations 4 | 5 | 6 | class StaticTrainer(Trainer): 7 | """ 8 | Trainer class 9 | """ 10 | 11 | def __init__(self, model, loss, metric_ftns, optimizer, config, data_loader, 12 | valid_data_loader=None, lr_scheduler=None, len_epoch=None): 13 | super().__init__(model, loss, metric_ftns, optimizer, config, data_loader, valid_data_loader, lr_scheduler, len_epoch) 14 | 15 | def _process(self, epoch, data_loader, metrics, mode: Mode = Mode.TRAIN): 16 | _len_epoch = self.len_epoch if mode == Mode.TRAIN else self.len_epoch_val 17 | for batch_idx, (data, target) in enumerate(data_loader): 18 | data, target = data.to(self.device), target.to(self.device) 19 | 20 | if mode == Mode.TRAIN: 21 | self.optimizer.zero_grad() 22 | output = self.model(data) 23 | loss = self.loss(output, target) 24 | if mode == Mode.TRAIN: 25 | loss.backward() 26 | self.optimizer.step() 27 | 28 | self.log_scalars(metrics, self.get_step(batch_idx, epoch, _len_epoch), output, target, loss, mode) 29 | 30 | if not (batch_idx % self.log_step): 31 | self.logger.info(f'{mode.value} Epoch: {epoch} {self._progress(data_loader, batch_idx, _len_epoch)} Loss: {loss.item():.6f}') 32 | if not (batch_idx % (_len_epoch // 10)): 33 | log_visualizations(self.writer, data, output, target) 34 | 35 | del data, target 36 | -------------------------------------------------------------------------------- /model/MultitaskNetwork.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from base import BaseModel 4 | from model.FCDenseNet import FCDenseNet, FCDenseNetEncoder 5 | from model.utils.layers import SpatialTransformer 6 | 7 | 8 | class MultitaskNetwork(BaseModel): 9 | def __init__( 10 | self, in_channels=8, resolution=(217, 217) 11 | ): 12 | super().__init__() 13 | self.encoder = FCDenseNetEncoder(in_channels=in_channels) 14 | self.densenet_seg = FCDenseNet( 15 | in_channels=in_channels, n_classes=2, apply_softmax=True, 16 | encoder=self.encoder 17 | ) 18 | self.densenet_voxelmorph = FCDenseNet( 19 | in_channels=in_channels, n_classes=2, apply_softmax=False, 20 | encoder=self.encoder 21 | ) 22 | self.spatial_transform = SpatialTransformer(resolution) 23 | 24 | def forward(self, input_moving, input_fixed): 25 | x = torch.cat([input_moving, input_fixed], dim=1) 26 | out, skip_connections = self.encoder(x) 27 | y_seg = self.densenet_seg( 28 | [out, skip_connections], is_encoder_output=True 29 | 
) 30 | 31 | flow = self.densenet_voxelmorph( 32 | [out, skip_connections], is_encoder_output=True 33 | ) 34 | 35 | modalities = torch.unbind(input_moving, dim=1) 36 | 37 | y_deformation = torch.stack( 38 | [torch.squeeze( 39 | self.spatial_transform(torch.unsqueeze(modality, 1), flow), 40 | dim=1 41 | ) for modality in modalities], 42 | dim=1 43 | ) 44 | return y_seg, y_deformation, flow 45 | -------------------------------------------------------------------------------- /configs/Longitudinal_Network.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | CONFIG = { 4 | "name": f"{os.path.basename(__file__).split('.')[0]}", 5 | "n_gpu": 1, 6 | 7 | "arch": { 8 | "type": "LongitudinalFCDenseNet", 9 | "args": { 10 | "in_channels": 4, 11 | "siamese": False 12 | } 13 | }, 14 | "dataset": { 15 | "type": "DatasetLongitudinal", 16 | "args": { 17 | "data_dir": "../ISBIMSlesionChallenge/", 18 | "preprocess": True, 19 | "modalities": ['flair', 'mprage', 'pd', 't2'], 20 | "val_patients": [4] 21 | } 22 | }, 23 | "data_loader": { 24 | "type": "Dataloader", 25 | "args": { 26 | "batch_size": 4, 27 | "shuffle": True, 28 | "num_workers": 4, 29 | } 30 | }, 31 | "optimizer": { 32 | "type": "Adam", 33 | "args": { 34 | "lr": 0.0001, 35 | "weight_decay": 0, 36 | "amsgrad": True 37 | } 38 | }, 39 | "loss": "mse", 40 | "metrics": [ 41 | "precision", "recall", "dice_loss", "dice_score", "asymmetric_loss" 42 | ], 43 | "lr_scheduler": { 44 | "type": "StepLR", 45 | "args": { 46 | "step_size": 50, 47 | "gamma": 0.1 48 | } 49 | }, 50 | "trainer": { 51 | "type": "LongitudinalTrainer", 52 | "epochs": 100, 53 | "save_dir": "../saved/", 54 | "save_period": 1, 55 | "verbosity": 2, 56 | "monitor": "min val_dice_loss", 57 | "early_stop": 10, 58 | "tensorboard": True 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /configs/Multitask_Longitudinal_Network.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | CONFIG = { 4 | "name": f"{os.path.basename(__file__).split('.')[0]}", 5 | "n_gpu": 1, 6 | 7 | "arch": { 8 | "type": "MultitaskNetwork", 9 | "args": { 10 | "in_channels": 8, 11 | "resolution": (217, 217) 12 | } 13 | }, 14 | "dataset": { 15 | "type": "DatasetLongitudinal", 16 | "args": { 17 | "data_dir": "../ISBIMSlesionChallenge/", 18 | "preprocess": True, 19 | "modalities": ['flair', 'mprage', 'pd', 't2'], 20 | "val_patients": [4] 21 | } 22 | }, 23 | "data_loader": { 24 | "type": "Dataloader", 25 | "args": { 26 | "batch_size": 2, 27 | "shuffle": True, 28 | "num_workers": 4, 29 | } 30 | }, 31 | "optimizer": { 32 | "type": "Adam", 33 | "args": { 34 | "lr": 0.0001, 35 | "weight_decay": 0, 36 | "amsgrad": True 37 | } 38 | }, 39 | "loss": "multitask_loss", 40 | "metrics": [ 41 | "precision", "recall", "dice_loss", "dice_score", "asymmetric_loss" 42 | ], 43 | "lr_scheduler": { 44 | "type": "StepLR", 45 | "args": { 46 | "step_size": 50, 47 | "gamma": 0.1 48 | } 49 | }, 50 | "trainer": { 51 | "type": "LongitudinalMultitaskTrainer", 52 | "epochs": 100, 53 | "save_dir": "../saved/", 54 | "save_period": 1, 55 | "verbosity": 2, 56 | "monitor": "min val_dice_loss", 57 | "early_stop": 10, 58 | "tensorboard": True 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pprint 3 | from collections import OrderedDict 4 | 
from itertools import repeat 5 | from pathlib import Path 6 | 7 | import pandas as pd 8 | 9 | 10 | def write_config(content, fname): 11 | with fname.open('wt') as handle: 12 | handle.write("CONFIG = " + pprint.pformat(content)) 13 | handle.close() 14 | 15 | 16 | def read_json(fname): 17 | fname = Path(fname) 18 | with fname.open('rt') as handle: 19 | return json.load(handle, object_hook=OrderedDict) 20 | 21 | 22 | def write_json(content, fname): 23 | fname = Path(fname) 24 | with fname.open('wt') as handle: 25 | json.dump(content, handle, indent=4, sort_keys=False) 26 | 27 | 28 | def inf_loop(data_loader): 29 | ''' wrapper function for endless data loader. ''' 30 | for loader in repeat(data_loader): 31 | yield from loader 32 | 33 | 34 | class MetricTracker: 35 | def __init__(self, *keys, writer=None): 36 | self.writer = writer 37 | self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average']) 38 | self.reset() 39 | 40 | def reset(self): 41 | for col in self._data.columns: 42 | self._data[col].values[:] = 0 43 | 44 | def update(self, key, value, n=1): 45 | if self.writer is not None: 46 | self.writer.add_scalar(key, value) 47 | self._data.total[key] += value * n 48 | self._data.counts[key] += n 49 | self._data.average[key] = self._data.total[key] / self._data.counts[key] 50 | 51 | def avg(self, key): 52 | return self._data.average[key] 53 | 54 | def result(self): 55 | return dict(self._data.average) 56 | -------------------------------------------------------------------------------- /trainer/LongitudinalTrainer.py: -------------------------------------------------------------------------------- 1 | from logger import Mode 2 | from trainer.Trainer import Trainer 3 | from utils.illustration_util import log_visualizations_longitudinal 4 | 5 | 6 | class LongitudinalTrainer(Trainer): 7 | """ 8 | Trainer class 9 | """ 10 | 11 | def __init__(self, model, loss, metric_ftns, optimizer, config, data_loader, 12 | valid_data_loader=None, lr_scheduler=None, len_epoch=None): 13 | super().__init__(model, loss, metric_ftns, optimizer, config, data_loader, valid_data_loader, lr_scheduler, len_epoch) 14 | 15 | def _process(self, epoch, data_loader, metrics, mode: Mode = Mode.TRAIN): 16 | _len_epoch = self.len_epoch if mode == Mode.TRAIN else self.len_epoch_val 17 | for batch_idx, (x_ref, x, _, target) in enumerate(data_loader): 18 | x_ref, x, target = x_ref.to(self.device), x.to(self.device), target.to(self.device) 19 | 20 | if mode == Mode.TRAIN: 21 | self.optimizer.zero_grad() 22 | output = self.model(x_ref, x) 23 | loss = self.loss(output, target) 24 | if mode == Mode.TRAIN: 25 | loss.backward() 26 | self.optimizer.step() 27 | 28 | self.log_scalars(metrics, self.get_step(batch_idx, epoch, _len_epoch), output, target, loss, mode) 29 | 30 | if not (batch_idx % self.log_step): 31 | self.logger.info(f'{mode.value} Epoch: {epoch} {self._progress(data_loader, batch_idx, _len_epoch)} Loss: {loss.item():.6f}') 32 | if not (batch_idx % (_len_epoch // 10)): 33 | log_visualizations_longitudinal(self.writer, x_ref, x, output, target) 34 | 35 | del x_ref, x, target 36 | -------------------------------------------------------------------------------- /dataset/DatasetLongitudinal.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import h5py 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.utils.data import Dataset 7 | 8 | from dataset.dataset_utils import Phase, Modalities, Mode, retrieve_data_dir_paths, Evaluate 9 | 10 | 11 | 
class DatasetLongitudinal(Dataset): 12 | """DatasetLongitudinal dataset""" 13 | 14 | def __init__(self, data_dir, phase=Phase.TRAIN, modalities=(), val_patients=None, evaluate: Evaluate = Evaluate.TRAINING, preprocess=True, view=None): 15 | self.modalities = list(map(lambda x: Modalities(x), modalities)) 16 | self.data_dir_paths = retrieve_data_dir_paths(data_dir, evaluate, phase, preprocess, val_patients, Mode.LONGITUDINAL, view) 17 | 18 | def __len__(self): 19 | return len(self.data_dir_paths) 20 | 21 | def __getitem__(self, idx): 22 | x_ref, x, ref_label, label = [], [], None, None 23 | x_ref_path, x_path = self.data_dir_paths[idx] 24 | for i, modality in enumerate(self.modalities): 25 | with h5py.File(os.path.join(x_ref_path, f'{modality.value}.h5'), 'r') as f: 26 | x_ref.append(f['data'][()]) 27 | if ref_label is None: 28 | ref_label = F.one_hot(torch.as_tensor(f['label'][()], dtype=torch.int64), num_classes=2).permute(2, 0, 1) 29 | 30 | with h5py.File(os.path.join(x_path, f'{modality.value}.h5'), 'r') as f: 31 | x.append(f['data'][()]) 32 | if label is None: 33 | label = F.one_hot(torch.as_tensor(f['label'][()], dtype=torch.int64), num_classes=2).permute(2, 0, 1) 34 | return torch.as_tensor(x_ref).float(), torch.as_tensor(x).float(), ref_label.float(), label.float() 35 | -------------------------------------------------------------------------------- /model/utils/metric_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def asymmetric_loss(beta, output, target): 6 | g = flatten(target) 7 | p = flatten(output) 8 | pg = (p * g).sum(-1) 9 | beta_sq = beta ** 2 10 | a = beta_sq / (1 + beta_sq) 11 | b = 1 / (1 + beta_sq) 12 | g_p = ((1 - p) * g).sum(-1) 13 | p_g = (p * (1 - g)).sum(-1) 14 | loss = (1. + pg) / (1. + pg + a * g_p + b * p_g) 15 | total_loss = torch.mean(1. - loss) 16 | return total_loss 17 | 18 | 19 | def eps_tp_tn_fp_fn(output, target): 20 | with torch.no_grad(): 21 | epsilon = 1e-7 22 | target = flatten(target).cpu().detach().float() 23 | output = flatten(output).cpu().detach().float() 24 | if len(output.shape) == 2: # is one hot encoded vector 25 | target = np.argmax(target, axis=0) 26 | output = np.argmax(output, axis=0) 27 | tp = torch.sum(target * output) 28 | tn = torch.sum((1 - target) * (1 - output)) 29 | fp = torch.sum((1 - target) * output) 30 | fn = torch.sum(target * (1 - output)) 31 | return epsilon, tp.float(), tn.float(), fp.float(), fn.float() 32 | 33 | 34 | def flatten(tensor): 35 | """Flattens a given tensor such that the channel axis is first. 
36 | The shapes are transformed as follows: 37 | (N, C, D, H, W) -> (C, N * D * H * W) 38 | """ 39 | if type(tensor) == torch.Tensor: 40 | C = tensor.size(1) 41 | # new axis order 42 | axis_order = (1, 0) + tuple(range(2, tensor.dim())) 43 | # Transpose: (N, C, D, H, W) -> (C, N, D, H, W) 44 | transposed = tensor.permute(axis_order) 45 | # Flatten: (C, N, D, H, W) -> (C, N * D * H * W) 46 | return transposed.contiguous().view(C, -1).float() 47 | else: 48 | return torch.as_tensor(tensor.flatten()).float() 49 | -------------------------------------------------------------------------------- /model/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from model.utils import metric_utils 5 | 6 | 7 | def inf(*args): 8 | return torch.as_tensor(float("Inf")) 9 | 10 | 11 | def gradient_loss(s): 12 | dy = torch.abs(s[:, :, 1:, :] - s[:, :, :-1, :]) ** 2 13 | dx = torch.abs(s[:, :, :, 1:] - s[:, :, :, :-1]) ** 2 14 | return (torch.mean(dx) + torch.mean(dy)) / 2.0 15 | 16 | 17 | def multitask_loss(warp, flow, output, input_fixed, target_fixed): 18 | recon_loss = mse(warp, input_fixed) 19 | grad_loss = gradient_loss(flow) 20 | seg_loss = mse(output, target_fixed) 21 | return recon_loss + 0.01 * grad_loss + seg_loss 22 | 23 | 24 | def deformation_loss(warp, flow, input_fixed): 25 | recon_loss = mse(warp, input_fixed) 26 | grad_loss = gradient_loss(flow) 27 | return recon_loss + 0.01 * grad_loss 28 | 29 | 30 | def l1(output, target): 31 | return F.l1_loss(output, target) 32 | 33 | 34 | def mse(output, target): 35 | return F.mse_loss(output, target) 36 | 37 | 38 | def nll_loss(output, target): 39 | return F.nll_loss(metric_utils.flatten(output), metric_utils.flatten(target)) 40 | 41 | 42 | def dice_loss(output, target): 43 | return metric_utils.asymmetric_loss(1, output, target) 44 | 45 | 46 | def asymmetric_loss(output, target): 47 | return metric_utils.asymmetric_loss(2, output, target) 48 | 49 | 50 | def deep_atlas_loss( 51 | y_seg_moving, y_seg_fixed, y_deformation, y_seg_deformation, flow, 52 | target_moving, target_fixed, input_fixed 53 | ): 54 | seg_moving_loss = mse(y_seg_moving, target_moving) 55 | seg_fixed_loss = mse(y_seg_fixed, target_fixed) 56 | seg_deformation_loss = mse(y_seg_deformation, target_fixed) 57 | recon_loss = mse(y_deformation, input_fixed) 58 | grad_loss = 0.01 * gradient_loss(flow) 59 | loss = seg_moving_loss + seg_fixed_loss + seg_deformation_loss + recon_loss + grad_loss 60 | 61 | return loss 62 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # input data, saved log, checkpoints 104 | data/ 105 | input/ 106 | saved/ 107 | datasets/ 108 | 109 | # editor, os cache directory 110 | .vscode/ 111 | .idea/ 112 | __MACOSX/ 113 | /wandb/ 114 | Longitudinal-MS-Lesion-Segmentation.iml 115 | -------------------------------------------------------------------------------- /trainer/LongitudinalMultitaskTrainer.py: -------------------------------------------------------------------------------- 1 | from logger import Mode 2 | from trainer.Trainer import Trainer 3 | from utils.illustration_util import log_visualizations_deformations 4 | 5 | 6 | class LongitudinalMultitaskTrainer(Trainer): 7 | """ 8 | Trainer class 9 | """ 10 | 11 | def __init__(self, model, loss, metric_ftns, optimizer, config, data_loader, 12 | valid_data_loader=None, lr_scheduler=None, len_epoch=None): 13 | super().__init__(model, loss, metric_ftns, optimizer, config, data_loader, valid_data_loader, lr_scheduler, len_epoch) 14 | 15 | def _process(self, epoch, data_loader, metrics, mode: Mode = Mode.TRAIN): 16 | _len_epoch = self.len_epoch if mode == Mode.TRAIN else self.len_epoch_val 17 | for batch_idx, (input_moving, input_fixed, target_moving, target_fixed) in enumerate(data_loader): 18 | input_moving, input_fixed, target_moving, target_fixed = \ 19 | input_moving.to(self.device), input_fixed.to(self.device), target_moving.to(self.device), target_fixed.to(self.device) 20 | 21 | if mode == Mode.TRAIN: 22 | self.optimizer.zero_grad() 23 | 24 | output, warp, flow = self.model(input_moving, input_fixed) 25 | loss = self.loss(warp, flow, output, input_fixed, target_fixed) 26 | 27 | if mode == Mode.TRAIN: 28 | loss.backward() 29 | self.optimizer.step() 30 | 31 | self.log_scalars(metrics, self.get_step(batch_idx, epoch, _len_epoch), output, target_fixed, loss, mode) 32 | 33 | if not (batch_idx % self.log_step): 34 | self.logger.info(f'{mode.value} Epoch: {epoch} {self._progress(data_loader, batch_idx, _len_epoch)} Loss: {loss.item():.6f}') 35 | if not (batch_idx % (_len_epoch // 10)): 36 | log_visualizations_deformations(self.writer, input_moving, input_fixed, flow, target_moving, target_fixed, output) 37 | 38 | del input_moving, input_fixed, target_moving, target_fixed 39 | -------------------------------------------------------------------------------- /model/LongitudinalFCDenseNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 
from base import BaseModel 5 | from model.FCDenseNet import FCDenseNetEncoder, FCDenseNetDecoder 6 | 7 | 8 | class LongitudinalFCDenseNet(BaseModel): 9 | def __init__( 10 | self, in_channels=1, down_blocks=(4, 4, 4, 4, 4), 11 | up_blocks=(4, 4, 4, 4, 4), bottleneck_layers=4, growth_rate=12, 12 | out_chans_first_conv=48, n_classes=2, encoder=None, siamese=True 13 | ): 14 | super().__init__() 15 | self.up_blocks = up_blocks 16 | self.densenet_encoder = encoder 17 | self.siamese = siamese 18 | if not encoder: 19 | self.densenet_encoder = FCDenseNetEncoder( 20 | in_channels=in_channels * (1 if siamese else 2), 21 | down_blocks=down_blocks, 22 | bottleneck_layers=bottleneck_layers, 23 | growth_rate=growth_rate, 24 | out_chans_first_conv=out_chans_first_conv 25 | ) 26 | 27 | prev_block_channels = self.densenet_encoder.prev_block_channels 28 | skip_connection_channel_counts = self.densenet_encoder.skip_connection_channel_counts 29 | 30 | if self.siamese: 31 | self.add_module( 32 | 'merge_conv', 33 | nn.Conv2d(prev_block_channels * 2, prev_block_channels, 1, 1) 34 | ) 35 | 36 | self.decoder = FCDenseNetDecoder( 37 | prev_block_channels, skip_connection_channel_counts, growth_rate, 38 | n_classes, up_blocks 39 | ) 40 | 41 | def forward(self, x_ref, x): 42 | if self.siamese: 43 | out, skip_connections = self.densenet_encoder(x) 44 | out_ref, _ = self.densenet_encoder(x_ref) 45 | out = torch.cat((out, out_ref), dim=1) 46 | out = self.merge_conv(out) 47 | else: 48 | out, skip_connections = self.densenet_encoder( 49 | torch.cat((x_ref, x), dim=1) 50 | ) 51 | 52 | out = self.decoder(out, skip_connections) 53 | 54 | return out 55 | -------------------------------------------------------------------------------- /base/base_data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data.dataloader import default_collate 4 | from torch.utils.data.sampler import SubsetRandomSampler 5 | 6 | 7 | class BaseDataLoader(DataLoader): 8 | """ 9 | Base class for all data loaders 10 | """ 11 | 12 | def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate): 13 | self.validation_split = validation_split 14 | self.shuffle = shuffle 15 | 16 | self.batch_idx = 0 17 | self.n_samples = len(dataset) 18 | 19 | self.sampler, self.valid_sampler = self._split_sampler(self.validation_split) 20 | 21 | self.init_kwargs = { 22 | 'dataset': dataset, 23 | 'batch_size': batch_size, 24 | 'shuffle': self.shuffle, 25 | 'collate_fn': collate_fn, 26 | 'num_workers': num_workers 27 | } 28 | super().__init__(sampler=self.sampler, **self.init_kwargs) 29 | 30 | def _split_sampler(self, split): 31 | if split == 0.0: 32 | return None, None 33 | 34 | idx_full = np.arange(self.n_samples) 35 | 36 | np.random.seed(0) 37 | np.random.shuffle(idx_full) 38 | 39 | if isinstance(split, int): 40 | assert split > 0 41 | assert split < self.n_samples, "validation set size is configured to be larger than entire dataset." 
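            # an integer `split` is an absolute number of held-out validation
            # samples; a float in (0, 1) is handled below as a fraction of the
            # dataset (e.g. validation_split=0.1 holds out 10% of the samples)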
42 | len_valid = split 43 | else: 44 | len_valid = int(self.n_samples * split) 45 | 46 | valid_idx = idx_full[0:len_valid] 47 | train_idx = np.delete(idx_full, np.arange(0, len_valid)) 48 | 49 | train_sampler = SubsetRandomSampler(train_idx) 50 | valid_sampler = SubsetRandomSampler(valid_idx) 51 | 52 | # turn off shuffle option which is mutually exclusive with sampler 53 | self.shuffle = False 54 | self.n_samples = len(train_idx) 55 | 56 | return train_sampler, valid_sampler 57 | 58 | def split_validation(self): 59 | if self.valid_sampler is None: 60 | return None 61 | else: 62 | return DataLoader(sampler=self.valid_sampler, **self.init_kwargs) 63 | -------------------------------------------------------------------------------- /dataset/DatasetStaticStacked.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import h5py 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.utils.data import Dataset 7 | 8 | from dataset.dataset_utils import Phase, Modalities, Views, retrieve_data_dir_paths, Mode, Evaluate 9 | 10 | 11 | class DatasetStaticStacked(Dataset): 12 | """DatasetStaticStacked dataset""" 13 | 14 | def __init__(self, data_dir, phase=Phase.TRAIN, modalities=(), val_patients=None, evaluate: Evaluate = Evaluate.TRAINING, preprocess=True, 15 | view: Views = None): 16 | self.modalities = list(map(lambda x: Modalities(x), modalities)) 17 | self.data_dir_paths = retrieve_data_dir_paths(data_dir, evaluate, phase, preprocess, val_patients, Mode.STATIC, view) 18 | 19 | def __len__(self): 20 | return len(self.data_dir_paths) 21 | 22 | def __getitem__(self, idx): 23 | data, label, shape = [], None, 0 24 | base_path, slice_index = os.path.join(os.sep, *self.data_dir_paths[idx].split(os.sep)[:-1]), int(self.data_dir_paths[idx].split(os.sep)[-1]) 25 | for i, modality in enumerate(self.modalities): 26 | path = os.path.join(base_path, f'{slice_index - 1:03}', f'{modality.value}.h5') 27 | if os.path.exists(path): 28 | with h5py.File(path, 'r') as f: 29 | data.append(f['data'][()]) 30 | path = os.path.join(base_path, f'{slice_index:03}', f'{modality.value}.h5') 31 | if os.path.exists(path): 32 | with h5py.File(path, 'r') as f: 33 | data_ = f['data'][()] 34 | shape = data_.shape 35 | data.append(data_) 36 | if label is None: 37 | label = F.one_hot(torch.as_tensor(f['label'][()], dtype=torch.int64), num_classes=2).permute(2, 0, 1) 38 | path = os.path.join(base_path, f'{slice_index + 1:03}', f'{modality.value}.h5') 39 | if os.path.exists(path): 40 | with h5py.File(path, 'r') as f: 41 | data.append(f['data'][()]) 42 | 43 | if len(data) != len(self.modalities) * 3: 44 | return torch.zeros(len(self.modalities) * 3, *shape), torch.cat([torch.ones(1, *label.shape[-2:]), torch.zeros(1, *label.shape[-2:])], dim=0) 45 | return torch.as_tensor(data), label.float() 46 | -------------------------------------------------------------------------------- /trainer/DeepAtlasTrainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from logger import Mode 4 | from trainer.Trainer import Trainer 5 | from utils.illustration_util import log_visualizations_deformations 6 | 7 | 8 | class DeepAtlasTrainer(Trainer): 9 | """ 10 | Trainer class 11 | """ 12 | 13 | def __init__(self, model, loss, metric_ftns, optimizer, config, data_loader, 14 | valid_data_loader=None, lr_scheduler=None, len_epoch=None): 15 | super().__init__( 16 | model, loss, metric_ftns, optimizer, config, data_loader, 17 | 
            valid_data_loader, lr_scheduler, len_epoch
18 |         )
19 | 
20 |     def _process(self, epoch, data_loader, metrics, mode: Mode = Mode.TRAIN):
21 |         _len_epoch = self.len_epoch if mode == Mode.TRAIN else self.len_epoch_val
22 |         for batch_idx, (x_ref, x, target_ref, target) in enumerate(data_loader):
23 |             x_ref, x, target_ref, target = x_ref.to(self.device), x.to(self.device), target_ref.to(self.device), target.to(self.device)
24 | 
25 |             if mode == Mode.TRAIN:
26 |                 self.optimizer.zero_grad()
27 |             output = self.model(x_ref, x)
28 | 
29 |             # deep_atlas_loss returns a single scalar, so no tuple is unpacked here
30 |             loss = self.loss(
31 |                 *output,
32 |                 torch.argmax(target_ref, dim=1, keepdim=True).float(),
33 |                 torch.argmax(target, dim=1, keepdim=True).float(),
34 |                 x
35 |             )
36 |             if mode == Mode.TRAIN:
37 |                 loss.backward()
38 |                 self.optimizer.step()
39 | 
40 |             self.log_scalars(
41 |                 metrics,
42 |                 self.get_step(batch_idx, epoch, _len_epoch),
43 |                 output[1],
44 |                 torch.argmax(target, dim=1, keepdim=True),
45 |                 loss,
46 |                 mode
47 |             )
48 | 
49 |             if not (batch_idx % self.log_step):
50 |                 self.logger.info(
51 |                     f'{mode.value} Epoch: {epoch} '
52 |                     f'{self._progress(data_loader, batch_idx, _len_epoch)} '
53 |                     f'Loss: {loss.item():.6f}'
54 |                 )
55 |             if not (batch_idx % (_len_epoch // 10)):
56 |                 log_visualizations_deformations(
57 |                     self.writer, x_ref, x, output[-1], target_ref, target
58 |                 )
59 | 
60 |             del x_ref, x, target_ref, target
--------------------------------------------------------------------------------
/logger/visualization.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from datetime import datetime
3 | from enum import Enum
4 | 
5 | 
6 | class Mode(Enum):
7 |     TRAIN = 'Train'
8 |     VAL = 'Val'
9 | 
10 | 
11 | class TensorboardWriter():
12 |     def __init__(self, log_dir, logger, enabled):
13 |         self.writer = None
14 |         self.selected_module = ""
15 | 
16 |         if enabled:
17 |             log_dir = str(log_dir)
18 | 
19 |             # Retrieve visualization writer.
20 |             succeeded = False
21 |             for module in ["torch.utils.tensorboard", "tensorboardX"]:
22 |                 try:
23 |                     self.writer = importlib.import_module(module).SummaryWriter(log_dir)
24 |                     succeeded = True
25 |                     break
26 |                 except ImportError:
27 |                     succeeded = False
28 |                 self.selected_module = module
29 | 
30 |             if not succeeded:
31 |                 message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
32 |                           "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " \
33 |                           "version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file."
34 |                 logger.warning(message)
35 | 
36 |         self.step = 0
37 |         self.mode = None
38 | 
39 |         self.tb_writer_ftns = {
40 |             'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio', 'add_graph',
41 |             'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
42 |         }
43 |         self.tag_mode_exceptions = {'add_graph', 'add_histogram', 'add_embedding'}
44 |         self.timer = datetime.now()
45 | 
46 |     def set_step(self, step, mode=Mode.TRAIN):
47 |         self.mode = mode
48 |         self.step = step
49 |         if step == 0:
50 |             self.timer = datetime.now()
51 |         else:
52 |             duration = datetime.now() - self.timer
53 |             self.add_scalar('steps_per_sec', 1 / duration.total_seconds())
54 |             self.timer = datetime.now()
55 | 
56 |     def __getattr__(self, name):
57 |         """
58 |         If visualization is configured to use:
59 |             return add_data() methods of tensorboard with additional information (step, tag) added.
60 | Otherwise: 61 | return a blank function handle that does nothing 62 | """ 63 | if name in self.tb_writer_ftns: 64 | add_data = getattr(self.writer, name, None) 65 | 66 | def wrapper(tag, data, *args, **kwargs): 67 | if add_data is not None: 68 | # add mode(train/valid) tag 69 | if name not in self.tag_mode_exceptions: 70 | tag = '{}/{}'.format(tag, self.mode.value) 71 | add_data(tag, data, self.step, *args, **kwargs) 72 | 73 | return wrapper 74 | else: 75 | # default action for returning methods defined in this class, set_step() for instance. 76 | try: 77 | attr = object.__getattr__(name) 78 | except AttributeError: 79 | raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name)) 80 | return attr 81 | -------------------------------------------------------------------------------- /utils/illustration_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from torchvision.utils import make_grid 5 | 6 | 7 | def warp_flow(img, flow): 8 | h, w = flow.shape[:2] 9 | flow = -flow 10 | flow[:, :, 0] += np.arange(w) 11 | flow[:, :, 1] += np.arange(h)[:, np.newaxis] 12 | res = cv2.remap(img, flow, None, cv2.INTER_LINEAR) 13 | return res 14 | 15 | 16 | def visualize_flow(flow): 17 | """Visualize optical flow 18 | 19 | Args: 20 | flow: optical flow map with shape of (H, W, 2), with (y, x) order 21 | 22 | Returns: 23 | RGB image of shape (H, W, 3) 24 | """ 25 | assert flow.ndim == 3 26 | assert flow.shape[2] == 2 27 | 28 | hsv = np.zeros([flow.shape[0], flow.shape[1], 3], dtype=np.uint8) 29 | mag, ang = cv2.cartToPolar(flow[..., 1], flow[..., 0]) 30 | hsv[..., 0] = ang * 180 / np.pi / 2 31 | hsv[..., 1] = 255 32 | hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) 33 | rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) 34 | return rgb 35 | 36 | 37 | def log_visualizations(writer, data, output, target): 38 | for (a, b, c) in zip(cast(data), cast(output, True), cast(target, True)): 39 | tensor = np.expand_dims(np.transpose(np.hstack([a, b, c]), (2, 0, 1)), axis=0) 40 | writer.add_image('input_output_target', make_grid(torch.as_tensor(tensor), nrow=8, normalize=True)) 41 | 42 | 43 | def log_visualizations_longitudinal(writer, x_ref, x, output, target): 44 | for (a1, a2, b, c) in zip(cast(x_ref), cast(x), cast(output, True), cast(target, True)): 45 | tensor = np.expand_dims(np.transpose(np.hstack([a1, a2, b, c]), (2, 0, 1)), axis=0) 46 | writer.add_image('x_ref_x_output_target', make_grid(torch.as_tensor(tensor), nrow=8, normalize=True)) 47 | 48 | 49 | def log_visualizations_deformations(writer, input_moving, input_fixed, flow, target_moving, target_fixed, output=None): 50 | zipped_data = zip( 51 | cast(input_moving), 52 | cast(input_fixed), 53 | cast(flow, normalize_data=False), 54 | cast(target_moving, True), 55 | cast(target_fixed, True), 56 | cast(output, True) if type(None) != type(output) else [None for _ in input_moving] 57 | ) 58 | for (_input_moving, _input_fixed, _flow, _target_moving, _target_fixed, _output) in zipped_data: 59 | transposed_flow = np.transpose(_flow, (1, 2, 0)) 60 | 61 | illustration = [ 62 | _input_moving, 63 | _input_fixed, 64 | visualize_flow(transposed_flow) / 255., 65 | _target_moving, 66 | _target_fixed 67 | ] 68 | if type(None) != type(_output): 69 | illustration.append(_output) 70 | 71 | tensor = np.expand_dims(np.transpose(np.hstack(illustration), (2, 0, 1)), axis=0) 72 | description = 
'inputmoving_inputfixed_flowfield_targetmoving_targetfixed_output' 73 | writer.add_image(description, make_grid(torch.as_tensor(tensor), nrow=8, normalize=True)) 74 | 75 | 76 | def cast(data, argmax=False, normalize_data=True): 77 | data = data.cpu().detach().numpy() 78 | if argmax: 79 | data = np.argmax(data, axis=1) 80 | 81 | data = data.astype('float32') 82 | 83 | if normalize_data: 84 | data = np.asarray([normalize(date) for date in data]) 85 | 86 | return data 87 | 88 | 89 | def normalize(x): 90 | if len(x.shape) > 2: 91 | x = x[0] 92 | 93 | return cv2.cvtColor(cv2.normalize(x, None, 0, 1, cv2.NORM_MINMAX), cv2.COLOR_GRAY2RGB) 94 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from collections import defaultdict 4 | from copy import copy 5 | 6 | import numpy as np 7 | import torch 8 | 9 | import data_loader as module_data_loader 10 | import dataset as module_dataset 11 | import model as module_arch 12 | import model.utils.loss as module_loss 13 | import model.utils.metric as module_metric 14 | import trainer as trainer_module 15 | from dataset.DatasetStatic import Phase 16 | from dataset.dataset_utils import Views 17 | from parse_config import ConfigParser, parse_cmd_args 18 | 19 | 20 | def main(config, resume=None): 21 | torch.manual_seed(0) 22 | torch.backends.cudnn.deterministic = True 23 | torch.backends.cudnn.benchmark = False 24 | np.random.seed(0) 25 | 26 | if resume: 27 | config.resume = resume 28 | 29 | logger = config.get_logger('train') 30 | 31 | # get function handles of loss and metrics 32 | loss = getattr(module_loss, config['loss']) 33 | metrics = [getattr(module_metric, met) for met in config['metrics']] 34 | 35 | # setup data_loader instances 36 | if config['single_view']: 37 | results = defaultdict(list) 38 | for view in list(Views): 39 | _cfg = copy(config) 40 | logs = train(logger, _cfg, loss, metrics, view=view) 41 | for k, v in list(logs.items()): 42 | results[k].append(v) 43 | 44 | else: 45 | train(logger, config, loss, metrics) 46 | 47 | 48 | def train(logger, config, loss, metrics, view: Views = None): 49 | dataset = config.retrieve_class('dataset', module_dataset)(**config['dataset']['args'], phase=Phase.TRAIN, view=view) 50 | data_loader = config.retrieve_class('data_loader', module_data_loader)(**config['data_loader']['args'], dataset=dataset) 51 | 52 | val_dataset = config.retrieve_class('dataset', module_dataset)(**config['dataset']['args'], phase=Phase.VAL, view=view) 53 | valid_data_loader = config.retrieve_class('data_loader', module_data_loader)(**config['data_loader']['args'], dataset=val_dataset) 54 | 55 | # build model architecture, then print to console 56 | model = config.initialize_class('arch', module_arch) 57 | logger.info(model) 58 | 59 | trainable_params = filter(lambda p: p.requires_grad, model.parameters()) 60 | optimizer = config.initialize('optimizer', torch.optim, trainable_params) 61 | 62 | lr_scheduler = config.initialize('lr_scheduler', torch.optim.lr_scheduler, optimizer) 63 | if view: 64 | config._save_dir = os.path.join(config._save_dir, str(view.name)) 65 | config._log_dir = os.path.join(config._log_dir, str(view.name)) 66 | os.mkdir(config._save_dir) 67 | os.mkdir(config._log_dir) 68 | trainer = config.retrieve_class('trainer', trainer_module)(model, loss, metrics, optimizer, config, data_loader, valid_data_loader, lr_scheduler) 69 | return trainer.train() 70 | 71 | 
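# Example invocation (a sketch; the config file path and GPU index below are
# illustrative choices, not prescribed by this repository):
#   python train.py --config configs/Longitudinal_Network.py --device 0
# Adding --single_view instead trains one model per plane orientation, with
# checkpoints and logs written to per-view subdirectories of save_dir
# (see the view handling in train() above).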
72 | if __name__ == '__main__':
73 |     args = argparse.ArgumentParser(description='PyTorch Template')
74 |     args.add_argument('-c', '--config', default=None, type=str, help='config file path (default: None)')
75 |     args.add_argument('-r', '--resume', default=None, type=str, help='path to latest checkpoint (default: None)')
76 |     args.add_argument('-d', '--device', default=None, type=str, help='indices of GPUs to enable (default: all)')
77 |     # argparse's `type=bool` treats any non-empty string (even "False") as True,
78 |     # so the flag is exposed via `action='store_true'` instead
79 |     args.add_argument('-s', '--single_view', default=False, action='store_true',
80 |                       help='defines whether a single model is trained per plane orientation (default: False)')
81 | 
82 |     config = ConfigParser(*parse_cmd_args(args))
83 |     main(config)
--------------------------------------------------------------------------------
/trainer/Trainer.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | 
3 | import numpy as np
4 | import torch
5 | 
6 | from base import BaseTrainer
7 | from logger import Mode
8 | from utils import MetricTracker, inf_loop
9 | 
10 | 
11 | class Trainer(BaseTrainer):
12 |     """
13 |     Trainer class
14 |     """
15 | 
16 |     def __init__(self, model, loss, metric_ftns, optimizer, config, data_loader,
17 |                  valid_data_loader=None, lr_scheduler=None, len_epoch=None):
18 |         super().__init__(model, loss, metric_ftns, optimizer, config)
19 |         self.config = config
20 |         self.data_loader = data_loader
21 |         self.valid_data_loader = valid_data_loader
22 |         self.do_validation = self.valid_data_loader is not None
23 | 
24 |         if len_epoch is None:
25 |             self.len_epoch = len(self.data_loader)  # epoch-based training
26 |         else:
27 |             self.data_loader, self.len_epoch = inf_loop(data_loader), len_epoch  # iteration-based training
28 |         self.len_epoch_val = len(self.valid_data_loader) if self.do_validation else 0
29 |         self.lr_scheduler = lr_scheduler
30 |         self.log_step = int(np.sqrt(data_loader.batch_size))
31 | 
32 |         self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
33 |         self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
34 | 
35 |     @abstractmethod
36 |     def _process(self, epoch, data_loader, metrics, mode: Mode = Mode.TRAIN):
37 |         raise NotImplementedError('Method _process() from Trainer class has to be implemented!')
38 | 
39 |     def _train_epoch(self, epoch):
40 |         """
41 |         Training logic for an epoch
42 | 
43 |         :param epoch: Integer, current training epoch.
44 |         :return: A log that contains average loss and metric in this epoch.
45 |         """
46 |         self.model.train()
47 |         self.train_metrics.reset()
48 | 
49 |         self._process(epoch, self.data_loader, self.train_metrics, Mode.TRAIN)
50 | 
51 |         log = self.train_metrics.result()
52 | 
53 |         if self.do_validation:
54 |             val_log = self._valid_epoch(epoch)
55 |             log.update(**{'val_' + k: v for k, v in val_log.items()})
56 | 
57 |         if self.lr_scheduler is not None:
58 |             self.lr_scheduler.step()
59 |         return log
60 | 
61 |     def _valid_epoch(self, epoch):
62 |         """
63 |         Validate after training an epoch
64 | 
65 |         :param epoch: Integer, current training epoch.
66 | :return: A log that contains information about validation 67 | """ 68 | self.model.eval() 69 | self.valid_metrics.reset() 70 | with torch.no_grad(): 71 | self._process(epoch, self.valid_data_loader, self.valid_metrics, Mode.VAL) 72 | 73 | # add histogram of model parameters to the tensorboard 74 | for name, p in self.model.named_parameters(): 75 | self.writer.add_histogram(name, p, bins='auto') 76 | return self.valid_metrics.result() 77 | 78 | def log_scalars(self, metrics, step, output, target, loss, mode=Mode.TRAIN): 79 | self.writer.set_step(step, mode) 80 | metrics.update('loss', loss.item()) 81 | for met in self.metric_ftns: 82 | metrics.update(met.__name__, met(output, target)) 83 | 84 | @staticmethod 85 | def _progress(data_loader, batch_idx, batches): 86 | base = '[{}/{} ({:.0f}%)]' 87 | if hasattr(data_loader, 'n_samples'): 88 | current = batch_idx * data_loader.batch_size 89 | total = data_loader.n_samples 90 | else: 91 | current = batch_idx 92 | total = batches 93 | return base.format(current, total, 100.0 * current / total) 94 | 95 | @staticmethod 96 | def get_step(batch_idx, epoch, len_epoch): 97 | return (epoch - 1) * len_epoch + batch_idx 98 | -------------------------------------------------------------------------------- /model/utils/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class SpatialTransformer(nn.Module): 7 | def __init__(self, size, mode='bilinear'): 8 | super(SpatialTransformer, self).__init__() 9 | 10 | vectors = [torch.arange(0, s) for s in size] 11 | grid = torch.unsqueeze(torch.stack(torch.meshgrid(vectors)), dim=0).float() 12 | self.register_buffer('grid', grid) 13 | self.mode = mode 14 | 15 | def forward(self, src, flow): 16 | new_locs = self.grid + flow 17 | 18 | shape = flow.shape[2:] 19 | 20 | for i in range(len(shape)): 21 | new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] 
/ (shape[i] - 1) - 0.5) 22 | 23 | new_locs = new_locs.permute(0, 2, 3, 1) 24 | new_locs = new_locs[..., [1, 0]] 25 | 26 | return F.grid_sample(src, new_locs, mode=self.mode, align_corners=True) 27 | 28 | 29 | class DenseLayer(nn.Sequential): 30 | def __init__(self, in_channels, growth_rate): 31 | super().__init__() 32 | self.add_module('norm', nn.BatchNorm2d(in_channels)) 33 | self.add_module('relu', nn.ReLU(True)) 34 | self.add_module('conv', nn.Conv2d(in_channels, growth_rate, kernel_size=3, stride=1, padding=1, bias=True)) 35 | self.add_module('drop', nn.Dropout2d(0.2)) 36 | 37 | def forward(self, x): 38 | return super().forward(x) 39 | 40 | 41 | class DenseBlock(nn.Module): 42 | def __init__(self, in_channels, growth_rate, n_layers, upsample=False): 43 | super().__init__() 44 | self.upsample = upsample 45 | self.layers = nn.ModuleList([DenseLayer(in_channels + i * growth_rate, growth_rate) for i in range(n_layers)]) 46 | 47 | def forward(self, x): 48 | if self.upsample: 49 | new_features = [] 50 | for layer in self.layers: 51 | out = layer(x) 52 | x = torch.cat([x, out], 1) 53 | new_features.append(out) 54 | return torch.cat(new_features, 1) 55 | else: 56 | for layer in self.layers: 57 | out = layer(x) 58 | x = torch.cat([x, out], 1) 59 | return x 60 | 61 | 62 | class TransitionDown(nn.Sequential): 63 | def __init__(self, in_channels): 64 | super().__init__() 65 | self.add_module('norm', nn.BatchNorm2d(num_features=in_channels)) 66 | self.add_module('relu', nn.ReLU(inplace=True)) 67 | self.add_module('conv', nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, bias=True)) 68 | self.add_module('drop', nn.Dropout2d(0.2)) 69 | self.add_module('maxpool', nn.MaxPool2d(2)) 70 | 71 | def forward(self, x): 72 | return super().forward(x) 73 | 74 | 75 | class TransitionUp(nn.Module): 76 | def __init__(self, in_channels, out_channels): 77 | super().__init__() 78 | self.convTrans = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=0, bias=True) 79 | 80 | def forward(self, x, skip_x): 81 | out = self.convTrans(x) 82 | out = center_crop(out, skip_x.size(2), skip_x.size(3)) 83 | out = torch.cat([out, skip_x], 1) 84 | return out 85 | 86 | 87 | class Bottleneck(nn.Sequential): 88 | def __init__(self, in_channels, growth_rate, n_layers): 89 | super().__init__() 90 | self.add_module('bottleneck', DenseBlock(in_channels, growth_rate, n_layers, upsample=True)) 91 | 92 | def forward(self, x): 93 | return super().forward(x) 94 | 95 | 96 | def center_crop(layer, max_height, max_width): 97 | _, _, h, w = layer.size() 98 | xy1 = (w - max_width) // 2 99 | xy2 = (h - max_height) // 2 100 | return layer[:, :, xy2:(xy2 + max_height), xy1:(xy1 + max_width)] 101 | -------------------------------------------------------------------------------- /parse_config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from datetime import datetime 4 | from functools import reduce 5 | from importlib.machinery import SourceFileLoader 6 | from operator import getitem 7 | from pathlib import Path 8 | 9 | from logger import setup_logging 10 | from utils.util import write_config 11 | 12 | 13 | def parse_cmd_args(args): 14 | # parse default cli options 15 | args = args.parse_args() 16 | 17 | if args.device: 18 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 19 | if args.resume: 20 | resume = Path(args.resume) 21 | cfg_fname = resume.parent / 'config.py' 22 | else: 23 | msg_no_cfg = 
"Configuration file need to be specified. Add '-c config.py', for example." 24 | assert args.config is not None, msg_no_cfg 25 | resume = None 26 | cfg_fname = Path(args.config) 27 | 28 | # load config file and apply custom cli options 29 | config = SourceFileLoader("CONFIG", str(cfg_fname)).load_module().CONFIG 30 | 31 | for key, value in args.__dict__.items(): 32 | config[key] = value 33 | return config, resume 34 | 35 | 36 | class ConfigParser: 37 | def __init__(self, config, resume=None, modification=None, run_id=None): 38 | """ 39 | class to parse configuration json file. Handles hyperparameters for training, initializations of modules, checkpoint saving 40 | and logging module. 41 | :param config: Dict containing configurations, hyperparameters for training. contents of `config.json` file for example. 42 | :param resume: String, path to the checkpoint being loaded. 43 | :param modification: Dict keychain:value, specifying position values to be replaced from config dict. 44 | :param run_id: Unique Identifier for training processes. Used to save checkpoints and training log. Timestamp is being used as default 45 | """ 46 | # load config file and apply modification 47 | self._config = config 48 | self.resume = resume 49 | 50 | # set save_dir where trained model and log will be saved. 51 | save_dir = Path(self.config['trainer']['save_dir']) 52 | 53 | exper_name = self.config['name'] 54 | if run_id is None: # use timestamp as default run-id 55 | run_id = datetime.now().strftime(r'%m%d_%H%M%S') 56 | self._save_dir = save_dir / 'models' / exper_name / run_id 57 | self._log_dir = save_dir / 'log' / exper_name / run_id 58 | 59 | # make directory for saving checkpoints and log. 60 | exist_ok = run_id == '' 61 | self.save_dir.mkdir(parents=True, exist_ok=exist_ok) 62 | self.log_dir.mkdir(parents=True, exist_ok=exist_ok) 63 | 64 | # save updated config file to the checkpoint dir 65 | write_config(self.config, self.save_dir / 'config.py') 66 | 67 | # configure logging module 68 | setup_logging(self.log_dir) 69 | self.log_levels = { 70 | 0: logging.WARNING, 71 | 1: logging.INFO, 72 | 2: logging.DEBUG 73 | } 74 | 75 | def initialize(self, name, module, *args): 76 | """ 77 | finds a function handle with the name given as 'type' in config, and returns the 78 | instance initialized with corresponding keyword args given as 'args'. 79 | """ 80 | module_cfg = self[name] 81 | return getattr(module, module_cfg['type'])(*args, **module_cfg['args']) 82 | 83 | def initialize_class(self, name, module, *args): 84 | """ 85 | finds a function handle with the name given as 'type' in config, and returns the 86 | instance initialized with corresponding keyword args given as 'args'. 87 | """ 88 | class_instance = self.retrieve_class(name, module) 89 | return class_instance(*args, **self[name]['args']) 90 | 91 | def retrieve_class(self, name, module): 92 | module_cfg = self[name] 93 | class_name = module_cfg["type"] 94 | base_path = os.path.join(Path(os.path.dirname(os.path.abspath(__file__))), module.__name__, f'{class_name}.py') 95 | class_instance = getattr(SourceFileLoader(class_name, base_path).load_module(), class_name) 96 | return class_instance 97 | 98 | def __getitem__(self, name): 99 | """Access items like ordinary dict.""" 100 | return self.config[name] 101 | 102 | def get_logger(self, name, verbosity=2): 103 | msg_verbosity = 'verbosity option {} is invalid. 
Valid options are {}.'.format(verbosity, self.log_levels.keys()) 104 | assert verbosity in self.log_levels, msg_verbosity 105 | logger = logging.getLogger(name) 106 | logger.setLevel(self.log_levels[verbosity]) 107 | return logger 108 | 109 | @property 110 | def config(self): 111 | return self._config 112 | 113 | @property 114 | def save_dir(self): 115 | return self._save_dir 116 | 117 | @property 118 | def log_dir(self): 119 | return self._log_dir 120 | 121 | 122 | def _get_opt_name(flags): 123 | for flg in flags: 124 | if flg.startswith('--'): 125 | return flg.replace('--', '') 126 | return flags[0].replace('--', '') 127 | 128 | 129 | def _set_by_path(tree, keys, value): 130 | """Set a value in a nested object in tree by sequence of keys.""" 131 | keys = keys.split(';') 132 | _get_by_path(tree, keys[:-1])[keys[-1]] = value 133 | 134 | 135 | def _get_by_path(tree, keys): 136 | """Access a nested object in tree by sequence of keys.""" 137 | return reduce(getitem, keys, tree) 138 | -------------------------------------------------------------------------------- /model/FCDenseNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from base import BaseModel 4 | from model.utils.layers import DenseBlock, TransitionDown, Bottleneck, \ 5 | TransitionUp 6 | 7 | 8 | class FCDenseNetEncoder(BaseModel): 9 | def __init__( 10 | self, in_channels=1, down_blocks=(4, 4, 4, 4, 4), 11 | bottleneck_layers=4, growth_rate=12, out_chans_first_conv=48 12 | ): 13 | super().__init__() 14 | self.down_blocks = down_blocks 15 | self.skip_connection_channel_counts = [] 16 | 17 | self.add_module( 18 | 'firstconv', 19 | nn.Conv2d( 20 | in_channels=in_channels, out_channels=out_chans_first_conv, 21 | kernel_size=3, stride=1, padding=1, bias=True 22 | ) 23 | ) 24 | self.cur_channels_count = out_chans_first_conv 25 | 26 | self.denseBlocksDown = nn.ModuleList([]) 27 | self.transDownBlocks = nn.ModuleList([]) 28 | for i in range(len(down_blocks)): 29 | self.denseBlocksDown.append( 30 | DenseBlock(self.cur_channels_count, growth_rate, down_blocks[i]) 31 | ) 32 | self.cur_channels_count += (growth_rate * down_blocks[i]) 33 | self.skip_connection_channel_counts.insert( 34 | 0, self.cur_channels_count 35 | ) 36 | self.transDownBlocks.append(TransitionDown(self.cur_channels_count)) 37 | 38 | self.add_module( 39 | 'bottleneck', 40 | Bottleneck(self.cur_channels_count, growth_rate, bottleneck_layers) 41 | ) 42 | self.prev_block_channels = growth_rate * bottleneck_layers 43 | self.cur_channels_count += self.prev_block_channels 44 | 45 | def forward(self, x): 46 | out = self.firstconv(x) 47 | 48 | skip_connections = [] 49 | for i in range(len(self.down_blocks)): 50 | out = self.denseBlocksDown[i](out) 51 | skip_connections.append(out) 52 | out = self.transDownBlocks[i](out) 53 | 54 | out = self.bottleneck(out) 55 | return out, skip_connections 56 | 57 | 58 | class FCDenseNetDecoder(BaseModel): 59 | def __init__( 60 | self, prev_block_channels, skip_connection_channel_counts, 61 | growth_rate, n_classes, up_blocks, apply_softmax=True 62 | ): 63 | super().__init__() 64 | self.apply_softmax = apply_softmax 65 | self.up_blocks = up_blocks 66 | self.transUpBlocks = nn.ModuleList([]) 67 | self.denseBlocksUp = nn.ModuleList([]) 68 | for i in range(len(self.up_blocks) - 1): 69 | self.transUpBlocks.append( 70 | TransitionUp(prev_block_channels, prev_block_channels) 71 | ) 72 | cur_channels_count = prev_block_channels + \ 73 | skip_connection_channel_counts[i] 74 | 75 | 
self.denseBlocksUp.append( 76 | DenseBlock( 77 | cur_channels_count, growth_rate, self.up_blocks[i], 78 | upsample=True 79 | ) 80 | ) 81 | prev_block_channels = growth_rate * self.up_blocks[i] 82 | cur_channels_count += prev_block_channels 83 | 84 | self.transUpBlocks.append( 85 | TransitionUp(prev_block_channels, prev_block_channels) 86 | ) 87 | cur_channels_count = prev_block_channels + \ 88 | skip_connection_channel_counts[-1] 89 | self.denseBlocksUp.append( 90 | DenseBlock( 91 | cur_channels_count, growth_rate, self.up_blocks[-1], 92 | upsample=False 93 | ) 94 | ) 95 | cur_channels_count += growth_rate * self.up_blocks[-1] 96 | 97 | self.finalConv = nn.Conv2d( 98 | in_channels=cur_channels_count, out_channels=n_classes, 99 | kernel_size=1, stride=1, padding=0, bias=True 100 | ) 101 | self.softmax = nn.Softmax2d() 102 | 103 | def forward(self, out, skip_connections): 104 | for i in range(len(self.up_blocks)): 105 | skip = skip_connections[-i - 1] 106 | out = self.transUpBlocks[i](out, skip) 107 | out = self.denseBlocksUp[i](out) 108 | 109 | out = self.finalConv(out) 110 | if self.apply_softmax: 111 | out = self.softmax(out) 112 | 113 | return out 114 | 115 | 116 | class FCDenseNet(BaseModel): 117 | def __init__( 118 | self, in_channels=1, down_blocks=(4, 4, 4, 4, 4), 119 | up_blocks=(4, 4, 4, 4, 4), bottleneck_layers=4, growth_rate=12, 120 | out_chans_first_conv=48, n_classes=2, apply_softmax=True, 121 | encoder=None 122 | ): 123 | super().__init__() 124 | self.up_blocks = up_blocks 125 | self.encoder = encoder 126 | if not encoder: 127 | self.encoder = FCDenseNetEncoder( 128 | in_channels=in_channels, down_blocks=down_blocks, 129 | bottleneck_layers=bottleneck_layers, growth_rate=growth_rate, 130 | out_chans_first_conv=out_chans_first_conv 131 | ) 132 | 133 | prev_block_channels = self.encoder.prev_block_channels 134 | skip_connection_channel_counts = self.encoder.skip_connection_channel_counts 135 | 136 | self.decoder = FCDenseNetDecoder( 137 | prev_block_channels, skip_connection_channel_counts, growth_rate, 138 | n_classes, up_blocks, apply_softmax 139 | ) 140 | 141 | def forward(self, x, is_encoder_output=False): 142 | if is_encoder_output: 143 | out, skip_connections = x 144 | else: 145 | out, skip_connections = self.encoder(x) 146 | 147 | out = self.decoder(out, skip_connections) 148 | return out 149 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Spatio-temporal Learning from Longitudinal Data for Multiple Sclerosis Lesion Segmentation 2 | 3 | This is the code for our paper Spatio-temporal Learning from Longitudinal Data for Multiple Sclerosis Lesion Segmentation which can be found [here](https://arxiv.org/pdf/2004.03675.pdf) 4 | 5 | If you use any of our code, please cite: 6 | ``` 7 | @article{Denner2020, 8 | author = {Denner, Stefan and Khakzar, Ashkan and Sajid, Moiz and Saleh, Mahdi and Spiclin, Ziga and Kim, Seong Tae and Navab, Nassir}, 9 | title = {Spatio-temporal Learning from Longitudinal Data for Multiple Sclerosis Lesion Segmentation}, 10 | url = {http://arxiv.org/abs/2004.03675}, 11 | year = {2020} 12 | } 13 | ``` 14 | 15 | 16 | 17 | 18 | * [Spatio-temporal Learning from Longitudinal Data for Multiple Sclerosis Lesion Segmentation](#spatio-temporal-learning-from-longitudinal-data-for-multiple-sclerosis-lesion-segmentation) 19 | * [Requirements](#requirements) 20 | * [Folder Structure](#folder-structure) 21 | * [Usage](#usage) 22 | * 
[Train](#train)
23 |   * [Resuming from checkpoints](#resuming-from-checkpoints)
24 |   * [Test](#test)
25 | * [Disclaimer](#disclaimer)
26 | * [License](#license)
27 | * [Acknowledgements](#acknowledgements)
28 | 
29 | 
30 | 
31 | ## Requirements
32 | * Python >= 3.5 (3.6 recommended)
33 | * PyTorch >= 1.4 (requirements.txt pins torch 1.6.0)
34 | * tqdm
35 | * tensorboard >= 1.14
36 | * nibabel >= 2.5
37 | 
38 | ## Folder Structure
39 | ```
40 | Spatio-temporal-MS-Lesion-Segmentation/
41 | │
42 | ├── train.py - main script to start/resume training
43 | ├── test.py - evaluation of trained model
44 | ├── test_single_view.py - evaluation of models which use a single model for each plane orientation
45 | │
46 | ├── base/ - abstract base classes
47 | │
48 | ├── configs/ - holds all the configuration files for the different models
49 | │   ├── Longitudinal_Network.py
50 | │   ├── Longitudinal_Siamese_Network.py
51 | │   ├── Multitask_Longitudinal_Network.py
52 | │   ├── Deep_Atlas.py
53 | │   ├── Static_Network.py
54 | │   ├── Static_Network_Asymmetric.py
55 | │   └── Static_Network_Zhang.py
56 | │
57 | ├── data_loader/
58 | │   └── Dataloader.py - dataloader for the dataset
59 | │
60 | ├── model/
61 | │   ├── utils/ - contains additional modules, losses and metrics
62 | │   ├── FCDenseNet.py
63 | │   ├── LongitudinalFCDenseNet.py
64 | │   ├── MultitaskNetwork.py
65 | │   └── DeepAtlas.py
66 | │
67 | └── trainer/ - trainers
68 |     ├── Trainer.py
69 |     ├── LongitudinalMultitaskTrainer.py
70 |     ├── LongitudinalTrainer.py
71 |     ├── DeepAtlasTrainer.py
72 |     └── StaticTrainer.py
73 | 
74 | ```
75 | 
76 | ## Usage
77 | Before the models can be trained or tested, the paths in the config files (located in `configs/`) have to be adjusted:
78 | - `data_loader.args.data_dir` specifies where the data is located
79 | - `trainer.save_dir` specifies where to store the model checkpoints and logs.
80 | 
81 | ### Train
82 | To run the experiments from our paper, the following table lists the commands to use:
83 | 
84 | | Network                                       | Command                                                             |
85 | |-----------------------------------------------|---------------------------------------------------------------------|
86 | | Multitask Longitudinal Network                | python train.py -c Multitask_Longitudinal_Network.py                |
87 | | Baseline Longitudinal Network                 | python train.py -c Longitudinal_Network.py                          |
88 | | Baseline Static Network                       | python train.py -c Static_Network.py                                |
89 | | Longitudinal Siamese Network                  | python train.py -c Longitudinal_Siamese_Network.py                  |
90 | | Static Network (Zhang et al. [2])             | python train.py -c Static_Network_Zhang.py -s True                  |
91 | | Static Network (Asymmetric Dice Loss [12])    | python train.py -c Static_Network_Asymmetric.py                     |
92 | 
93 | 
94 | ### Resuming from checkpoints
95 | Resuming the training from a certain checkpoint can be done by executing:
96 | 
97 | ```
98 | python train.py --resume path/to/checkpoint
99 | ```
100 | 
101 | ### Test
102 | A trained model can be tested by executing `test.py`, passing the path to the trained checkpoint via the `--resume` argument.
103 | For the networks which use a separate model for each plane orientation, the model can be tested with `test_single_view.py`.
104 | Here the _parent folder_ of a checkpoint has to be provided as the `--resume` argument.
105 | 
106 | For testing a **longitudinal** model we perform a majority vote over all possible combinations of a given target image with its reference images.
107 | A longitudinal model usually has a reference image (timepoint t-1) and the target/follow-up image (timepoint t) as input.
108 | Our experiments have shown that we achieve the best performance when applying a majority vote over all possible permutations for a certain target image.
109 | This means, for a patient with four timesteps t ∈ {0, 1, 2, 3} and having t = 1 as the target image, we do
110 | a majority vote over the probability outputs of the inputs (reference, target): (0, 1), (2, 1), (3, 1).
111 | 
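112 | The voting step itself is simple. A minimal sketch (the helper name `majority_vote` and the random stand-in volumes are illustrative, not part of this repository):
113 | 
114 | ```
115 | import numpy as np
116 | 
117 | def majority_vote(prob_maps):
118 |     """Average soft probability maps (one per (reference, target) pairing or
119 |     per plane orientation) and threshold at 0.5 for the final binary mask."""
120 |     stacked = np.stack(prob_maps, axis=0)  # (n_votes, X, Y, Z)
121 |     return np.round(stacked.mean(axis=0)).astype(np.int8)
122 | 
123 | # target t = 1 of a patient with timesteps {0, 1, 2, 3}: one probability volume
124 | # per (reference, target) pair, random data standing in for network outputs
125 | votes = [np.random.rand(217, 217, 217) for _ in [(0, 1), (2, 1), (3, 1)]]
126 | final_seg = majority_vote(votes)  # binary volume with values {0, 1}
127 | ```
128 | 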
129 | ### General notes
130 | All hyperparameters are defined in the config file.
131 | The majority vote for merging the outputs of the different plane orientations is only applied in `test.py` and `test_single_view.py`.
132 | The majority voting (MV) is done on the merged probability maps for a voxel from each view.
133 | Since it is of great interest to see the actual performance (i.e. after MV) of a model on the validation set, the
134 | test script has an argument `-e` or `--evaluate` which can be either `training` or `test`.
135 | This argument specifies which data should be used. For evaluating the performance of the model on the
136 | train/validation set, this argument has to be `training`, otherwise `test` (default).
137 | 
138 | ## Disclaimer
139 | The code has been cleaned and polished for the sake of clarity and reproducibility, and even though it has been checked thoroughly, it might contain bugs or mistakes. Please do not hesitate to open an issue or contact the authors to inform us of any problem you may find within this repository.
140 | 
141 | ## License
142 | This project is licensed under the MIT License. See LICENSE for more details.
143 | 
144 | ## Acknowledgements
145 | This project is a fork of the project [PyTorch-Template](https://github.com/victoresque/pytorch-template) by [Victor Huang](https://github.com/victoresque).
146 | 
--------------------------------------------------------------------------------
/base/base_trainer.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | 
3 | import torch
4 | from numpy import inf
5 | 
6 | from logger import TensorboardWriter
7 | 
8 | 
9 | class BaseTrainer:
10 |     """
11 |     Base class for all trainers
12 |     """
13 | 
14 |     def __init__(self, model, loss, metric_ftns, optimizer, config):
15 |         self.config = config
16 |         self.logger = config.get_logger('trainer', config['trainer']['verbosity'])
17 | 
18 |         # setup GPU device if available, move model into configured device
19 |         self.device, device_ids = self._prepare_device(config['n_gpu'])
20 |         self.model = model.to(self.device)
21 |         if len(device_ids) > 1:
22 |             self.model = torch.nn.DataParallel(model, device_ids=device_ids)
23 | 
24 |         self.loss = loss
25 |         self.metric_ftns = metric_ftns
26 |         self.optimizer = optimizer
27 | 
28 |         cfg_trainer = config['trainer']
29 |         self.epochs = cfg_trainer['epochs']
30 |         self.save_period = cfg_trainer['save_period']
31 |         self.monitor = cfg_trainer.get('monitor', 'off')
32 | 
33 |         # configuration to monitor model performance and save best
34 |         if self.monitor == 'off':
35 |             self.mnt_mode = 'off'
36 |             self.mnt_best = 0
37 |         else:
38 |             self.mnt_mode, self.mnt_metric = self.monitor.split()
39 |             assert self.mnt_mode in ['min', 'max']
40 | 
41 |             self.mnt_best = inf if self.mnt_mode == 'min' else -inf
42 |             self.early_stop = cfg_trainer.get('early_stop', inf)
43 | 
44 |         self.start_epoch = 1
45 | 
46 |         self.checkpoint_dir = config.save_dir
47 | 
48 |         # setup visualization writer instance
49 |         self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])
50 | 
51 |         if config.resume is not None:
52 |             self._resume_checkpoint(config.resume)
53 |         self.not_improved_count = 0
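    # Illustrative note (not part of the original file): the 'trainer' section
    # consumed above is expected to look roughly like
    #     'trainer': {'epochs': 200, 'save_period': 1, 'verbosity': 2,
    #                 'monitor': 'max dice', 'early_stop': 10,
    #                 'save_dir': 'saved/', 'tensorboard': True}
    # (values hypothetical). 'monitor' has the form '<min|max> <metric_name>' and
    # drives both best-checkpoint saving and early stopping in train() below.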
54 | 
55 |     @abstractmethod
56 |     def _train_epoch(self, epoch):
57 |         """
58 |         Training logic for an epoch
59 | 
60 |         :param epoch: Current epoch number
61 |         """
62 |         raise NotImplementedError
63 | 
64 |     def train(self):
65 |         """
66 |         Full training logic
67 |         """
68 |         best_log = None
69 |         for epoch in range(self.start_epoch, self.epochs + 1):
70 |             result = self._train_epoch(epoch)
71 | 
72 |             # save logged information into log dict
73 |             log = {'epoch': epoch}
74 |             log.update(result)
75 | 
76 |             # print logged information to the screen
77 |             for key, value in log.items():
78 |                 self.logger.info('    {:15s}: {}'.format(str(key), value))
79 | 
80 |             # evaluate model performance according to configured metric, save best checkpoint as model_best
81 |             best = False
82 |             if self.mnt_mode != 'off':
83 |                 try:
84 |                     # check whether model performance improved or not, according to specified metric (mnt_metric)
85 |                     improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
86 |                                (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
87 |                 except KeyError:
88 |                     self.logger.warning("Warning: Metric '{}' is not found. "
89 |                                         "Model performance monitoring is disabled.".format(self.mnt_metric))
90 |                     self.mnt_mode = 'off'
91 |                     improved = False
92 | 
93 |                 if improved:
94 |                     self.mnt_best = log[self.mnt_metric]
95 |                     best_log = log
96 |                     self.not_improved_count = 0
97 |                     best = True
98 |                 else:
99 |                     self.not_improved_count += 1
100 | 
101 |                 if self.not_improved_count > self.early_stop:
102 |                     self.logger.info("Validation performance hasn't improved for {} epochs. Training stops.".format(self.early_stop))
103 |                     break
104 | 
105 |             if epoch % self.save_period == 0:
106 |                 self._save_checkpoint(epoch, save_best=best)
107 | 
108 |         return best_log
109 | 
110 |     def _prepare_device(self, n_gpu_use):
111 |         """
112 |         Setup GPU device if available; move model into the configured device
113 |         """
114 |         n_gpu = torch.cuda.device_count()
115 |         if n_gpu_use > 0 and n_gpu == 0:
116 |             self.logger.warning("Warning: There's no GPU available on this machine, training will be performed on CPU.")
117 |             n_gpu_use = 0
118 |         if n_gpu_use > n_gpu:
119 |             self.logger.warning("Warning: The number of GPUs configured to use is {}, but only {} are available "
120 |                                 "on this machine.".format(n_gpu_use, n_gpu))
121 |             n_gpu_use = n_gpu
122 |         device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
123 |         list_ids = list(range(n_gpu_use))
124 |         return device, list_ids
125 | 
126 |     def _save_checkpoint(self, epoch, save_best=False):
127 |         """
128 |         Saving checkpoints
129 | 
130 |         :param epoch: current epoch number
131 |         :param save_best: if True, additionally save the checkpoint as 'model_best.pth'
132 |         """
133 |         arch = type(self.model).__name__
134 |         state = {
135 |             'arch': arch,
136 |             'epoch': epoch,
137 |             'state_dict': self.model.state_dict(),
138 |             'optimizer': self.optimizer.state_dict(),
139 |             'monitor_best': self.mnt_best,
140 |             'config': self.config
141 |         }
142 |         filename = f'{str(self.checkpoint_dir)}/checkpoint-epoch{epoch}.pth'
143 |         # filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch))
144 |         torch.save(state, filename)
145 |         self.logger.info("Saving checkpoint: {} ...".format(filename))
146 |         if save_best:
147 |             best_path = f'{str(self.checkpoint_dir)}/model_best.pth'
148 |             torch.save(state, best_path)
149 |             self.logger.info("Saving current best: model_best.pth ...")
150 | 
151 |     def _resume_checkpoint(self, resume_path):
152 |         """
153 |         
Resume from saved checkpoints 154 | 155 | :param resume_path: Checkpoint path to be resumed 156 | """ 157 | resume_path = str(resume_path) 158 | self.logger.info("Loading checkpoint: {} ...".format(resume_path)) 159 | checkpoint = torch.load(resume_path, map_location=lambda storage, loc: storage) 160 | self.start_epoch = checkpoint['epoch'] + 1 161 | self.mnt_best = checkpoint['monitor_best'] 162 | 163 | # load architecture params from checkpoint. 164 | if checkpoint['config']['arch'] != self.config['arch']: 165 | self.logger.warning("Warning: Architecture configuration given in config file is different from that of " 166 | "checkpoint. This may yield an exception while state_dict is being loaded.") 167 | status = self._load_dict(checkpoint) 168 | 169 | self.logger.warning(f'Missing keys: {str(status[0])}') if status[0] else None 170 | self.logger.warning(f'Unexpected keys: {str(status[1])}') if status[1] else None 171 | 172 | # load optimizer state from checkpoint only when optimizer type is not changed. 173 | if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']: 174 | self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. " 175 | "Optimizer parameters not being resumed.") 176 | else: 177 | self.optimizer.load_state_dict(checkpoint['optimizer']) 178 | 179 | self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch)) 180 | 181 | def _load_dict(self, checkpoint): 182 | return list(self.model.load_state_dict(checkpoint['state_dict'], False)) 183 | -------------------------------------------------------------------------------- /test_single_view.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | 5 | import nibabel 6 | import numpy as np 7 | import torch 8 | from scipy.ndimage import rotate 9 | from tqdm import tqdm 10 | 11 | import data_loader as module_data_loader 12 | import dataset as module_dataset 13 | import model as module_arch 14 | import model.utils.metric as module_metric 15 | from dataset.DatasetStatic import Phase 16 | from dataset.dataset_utils import Evaluate, Dataset, Views 17 | from parse_config import ConfigParser, parse_cmd_args 18 | from test import get_timestep_limit 19 | 20 | 21 | def main(config, resume=None): 22 | logger = config.get_logger('test') 23 | 24 | if config.resume: 25 | resume = config.resume 26 | 27 | resume = Path(resume).parent 28 | for view in list(Views): 29 | 30 | dataset = config.retrieve_class('dataset', module_dataset)( 31 | **config['dataset']['args'], phase=Phase.TEST, evaluate=config['evaluate'], view=view 32 | ) 33 | data_loader = config.retrieve_class('data_loader', module_data_loader)( 34 | dataset=dataset, 35 | batch_size=config['data_loader']['args']['batch_size'], 36 | num_workers=config['data_loader']['args']['num_workers'], 37 | shuffle=False 38 | ) 39 | resume_path = os.path.join(resume, view.name, 'model_best.pth') 40 | if os.path.exists(resume_path): 41 | evaluate_model(config, data_loader, logger, resume_path, view) 42 | 43 | create_final_segmentations(config, logger) 44 | logger.info('================================') 45 | logger.info(f'Done') 46 | 47 | 48 | def create_final_segmentations(config, logger): 49 | metric_fns = [getattr(module_metric, met) for met in config['metrics']] 50 | total_metrics = torch.zeros(len(metric_fns)) 51 | patient_metrics = torch.zeros(len(metric_fns)) 52 | 53 | save_dir_path = 
os.path.join(config.config['trainer']['save_dir'], 'output', *str(config._save_dir).split(os.sep)[-2:]) 54 | patient_paths = sorted(os.listdir(save_dir_path)) 55 | for patient_path, patient_dir in zip(map(lambda x: os.path.join(save_dir_path, x), patient_paths), patient_paths): 56 | if not os.path.isdir(patient_path): 57 | continue 58 | 59 | segmentation_dirs = sorted(os.listdir(patient_path)) 60 | timestep = 0 61 | for seg_path, seg_dir in zip(map(lambda x: os.path.join(patient_path, x), segmentation_dirs), segmentation_dirs): 62 | if not os.path.isdir(seg_path): 63 | continue 64 | seg_path_elements = list(map(lambda x: os.path.join(seg_path, x), os.listdir(seg_path))) 65 | # load all segmentations and average over them 66 | segmentations = np.round(np.mean(np.asarray(list(map(lambda x: nibabel.load(x).get_data(), seg_path_elements))), axis=0)).astype('int8') 67 | # save all segmentations with directory's name 68 | nibabel.save(nibabel.Nifti1Image(segmentations, np.eye(4)), os.path.join(save_dir_path, f'{patient_dir}_{seg_dir}_seg.nii')) 69 | logger.info(f'Patient {int(patient_path[-2:])} - Timestep {timestep}:') 70 | if config['evaluate'] == 'train': 71 | target_path = os.path.join(config.config['data_loader']['args']['data_dir'], patient_dir, 'masks', f'{patient_dir}_{seg_dir}_mask1.nii') 72 | target_volume = np.asarray(nibabel.load(target_path).get_data()) 73 | timestep += 1 74 | for i, metric in enumerate(metric_fns): 75 | current_metric = metric(segmentations, target_volume) 76 | logger.info(f' {metric.__name__}: {current_metric}') 77 | patient_metrics[i] += current_metric 78 | total_metrics[i] += current_metric 79 | 80 | if config['evaluate'] == 'train': 81 | logger.info(f'Averaged over patient {int(patient_path[-2:])}:') 82 | for i, met in enumerate(metric_fns): 83 | logger.info(f' {met.__name__}: {patient_metrics[i].item() / timestep}') 84 | patient_metrics = torch.zeros(len(metric_fns)) 85 | 86 | 87 | def evaluate_model(config, data_loader, logger, resume, view): 88 | # build model architecture 89 | model = config.initialize_class('arch', module_arch) 90 | logger.info('Loading checkpoint: {} ...'.format(resume)) 91 | checkpoint = torch.load(resume, map_location=lambda storage, loc: storage) 92 | state_dict = checkpoint['state_dict'] 93 | if config['n_gpu'] > 1: 94 | model = torch.nn.DataParallel(model) 95 | model.load_state_dict(state_dict) 96 | # prepare model for testing 97 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 98 | model = model.to(device) 99 | model.eval() 100 | with torch.no_grad(): 101 | # setup 102 | patient = 0 103 | timestep = 0 # max 3 104 | c = 0 105 | output_list = [] 106 | n_samples = 0 107 | timestep_limit = 4 if config['dataset_type'] == Dataset.ISBI else 2 108 | res = 217 if config['dataset_type'] == Dataset.ISBI else 229 109 | 110 | for i, (data, target) in enumerate(tqdm(data_loader)): 111 | data, target = data.to(device), target.to(device) 112 | output = model(data) 113 | for slice_output, slice_data, slice_target in zip(output, data, target): 114 | # we only deal with binary data. 
Storing only prob for label 1 is enough because of softmax normalization: P(0) = 1 - P(1) 115 | output_list.append(np.expand_dims(slice_output.cpu().detach().float()[1], axis=0)) 116 | c += 1 117 | 118 | if not c % res and c > 0: 119 | n_samples += 1 120 | path = os.path.join(config.config['trainer']['save_dir'], 'output', *str(config._save_dir).split('/')[-2:]) 121 | evaluate_timestep(output_list, path, patient, timestep, logger, view, config) 122 | 123 | # axis = 0 124 | timestep += 1 125 | if not timestep % timestep_limit and timestep > 0: 126 | # inferred whole patient 127 | logger.info('---------------------------------') 128 | logger.info(f'Done with patient {int(patient) + 1}:') 129 | timestep = 0 130 | patient += 1 131 | timestep_limit = get_timestep_limit(config['evaluate'], patient, config['dataset_type']) 132 | logger.info(f'There exist {timestep_limit} timesteps for Patient {int(patient) + 1}') 133 | 134 | output_list = [] 135 | 136 | 137 | def evaluate_timestep(output_list, path, patient, timestep, logger, view, config): 138 | sub_path = os.path.join(path, f'{config["evaluate"].value}{(int(patient) + 1):02}', f'{int(timestep) + 1:02}') 139 | os.makedirs(sub_path, exist_ok=True) 140 | 141 | seg_volume = np.moveaxis(np.squeeze(np.asarray(output_list)), 0, int(view.value)) 142 | rotated_seg_volume = rotate(rotate(seg_volume, 90, axes=(0, 1)), -90, axes=(1, 2)) 143 | cropped_seg_volume = rotated_seg_volume[18:-18, :, 18:-18] 144 | 145 | nibabel.save(nibabel.Nifti1Image(cropped_seg_volume, np.eye(4)), os.path.join(sub_path, f'{view.name}.nii')) 146 | logger.info(f'Done with Patient {int(patient) + 1} - Timestep {int(timestep) + 1:02}') 147 | 148 | 149 | if __name__ == '__main__': 150 | args = argparse.ArgumentParser(description='PyTorch Template') 151 | args.add_argument('-c', '--config', default=None, type=str, help='config file path (default: None)') 152 | args.add_argument('-r', '--resume', default=None, type=str, help='path to latest checkpoint (default: None)') 153 | args.add_argument('-d', '--device', default=None, type=str, help='indices of GPUs to enable (default: all)') 154 | args.add_argument('-e', '--evaluate', default=Evaluate.TEST, type=Evaluate, help='Either "training" or "test"; Determines the prefix of the folders to use') 155 | args.add_argument('-m', '--dataset_type', default=Dataset.ISBI, type=Dataset, help='Dataset to use') 156 | config = ConfigParser(*parse_cmd_args(args)) 157 | main(config) 158 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import nibabel 5 | import numpy as np 6 | import torch 7 | from scipy.ndimage import rotate 8 | from tqdm import tqdm 9 | 10 | import data_loader as module_data_loader 11 | import dataset as module_dataset 12 | import model as module_arch 13 | import model.utils.metric as module_metric 14 | from dataset.DatasetStatic import Phase 15 | from dataset.dataset_utils import Evaluate, Dataset 16 | from parse_config import ConfigParser, parse_cmd_args 17 | 18 | 19 | def main(config, resume=None): 20 | if config.resume: 21 | resume = config.resume 22 | 23 | logger = config.get_logger('test') 24 | 25 | # setup data_loader instances 26 | dataset = config.retrieve_class('dataset', module_dataset)( 27 | **config['dataset']['args'], phase=Phase.TEST, evaluate=config['evaluate'] 28 | ) 29 | data_loader = config.retrieve_class('data_loader', module_data_loader)( 30 | 
dataset=dataset, 31 | batch_size=config['data_loader']['args']['batch_size'], 32 | num_workers=config['data_loader']['args']['num_workers'], 33 | shuffle=False 34 | ) 35 | 36 | # build model architecture 37 | model = config.initialize_class('arch', module_arch) 38 | logger.info(model) 39 | 40 | # get function handles of loss and metrics 41 | metric_fns = [getattr(module_metric, met) for met in config['metrics']] 42 | 43 | logger.info('Loading checkpoint: {} ...'.format(resume)) 44 | checkpoint = torch.load(resume, map_location=lambda storage, loc: storage) 45 | if config['n_gpu'] > 1: 46 | model = torch.nn.DataParallel(model) 47 | if 'state_dict' in checkpoint.keys(): 48 | model.load_state_dict(checkpoint['state_dict']) 49 | else: 50 | model.load_state_dict(checkpoint) 51 | 52 | # prepare model for testing 53 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 54 | model = model.to(device) 55 | model.eval() 56 | 57 | timestep_limit = 4 if config['dataset_type'] == Dataset.ISBI else 2 58 | res = 217 if config['dataset_type'] == Dataset.ISBI else 229 59 | 60 | total_metrics = torch.zeros(len(metric_fns)) 61 | patient_metrics = torch.zeros(len(metric_fns)) 62 | 63 | with torch.no_grad(): 64 | # setup 65 | patient = 0 66 | inner_timestep = 0 67 | timestep = 0 # max 3 68 | axis = 0 # max 2 69 | c = 0 70 | data_shape = [res, res, res] 71 | 72 | output_agg = torch.zeros([3, *data_shape]).to(device) 73 | avg_seg_volume = None 74 | target_agg = torch.zeros(data_shape).to(device) 75 | 76 | n_samples = 0 77 | for idx, loaded_data in enumerate(tqdm(data_loader)): 78 | if len(loaded_data) == 2: 79 | inner_timestep_limit = 1 80 | # static case 81 | data, target = loaded_data 82 | data, target = data.to(device), target.to(device) 83 | output = model(data) 84 | else: 85 | # longitudinal case 86 | inner_timestep_limit = timestep_limit - 1 87 | x_ref, x, _, target = loaded_data 88 | x_ref, x, target = x_ref.to(device), x.to(device), target.to(device) 89 | output = model(x_ref, x) 90 | if isinstance(output, tuple): 91 | output, warp, flow = output 92 | 93 | for slice_output, slice_target in zip(output, target): 94 | # we only deal with binary data. 
Storing only prob for label 1 is enough because of softmax normalization: P(0) = 1 - P(1) 95 | output_agg[axis][c % res] = torch.unsqueeze(slice_output.float()[1], dim=0) 96 | 97 | if axis == 0 and inner_timestep == 0: 98 | target_agg[c % res] = torch.argmax(slice_target.float(), dim=0) 99 | 100 | c += 1 101 | 102 | if not c % res and c > 0: 103 | axis += 1 104 | if not axis % 3 and axis > 0: 105 | path = os.path.join(config.config['trainer']['save_dir'], 'output', *str(config._save_dir).split(os.sep)[-2:], 106 | str(resume).split(os.sep)[-1][:-4]) 107 | os.makedirs(path, exist_ok=True) 108 | 109 | if avg_seg_volume is None: 110 | avg_seg_volume = get_avg_seg_volume(output_agg, data_shape) 111 | else: 112 | avg_seg_volume = torch.cat([avg_seg_volume, get_avg_seg_volume(output_agg, data_shape)], dim=0) 113 | 114 | axis = 0 115 | inner_timestep += 1 116 | if not inner_timestep % inner_timestep_limit and inner_timestep > 0: 117 | # inferred one timestep 118 | n_samples += 1 119 | avg_seg_volume = avg_seg_volume.mean(dim=0) 120 | 121 | evaluate_timestep(avg_seg_volume, target_agg, metric_fns, config, path, patient, patient_metrics, total_metrics, 122 | timestep, 123 | logger) 124 | timestep += 1 125 | if not timestep % timestep_limit and timestep > 0: 126 | # inferred whole patient 127 | logger.info('---------------------------------') 128 | logger.info(f'Averaged over patient {int(patient) + 1}:') 129 | for i, met in enumerate(metric_fns): 130 | logger.info(f' {met.__name__}: {patient_metrics[i].item() / timestep}') 131 | patient_metrics = torch.zeros(len(metric_fns)) 132 | timestep = 0 133 | patient += 1 134 | timestep_limit = get_timestep_limit(config['evaluate'], patient, config['dataset_type']) 135 | logger.info(f'There exist {timestep_limit} timesteps for Patient {int(patient) + 1}') 136 | 137 | inner_timestep = 0 138 | avg_seg_volume = None 139 | 140 | logger.info('================================') 141 | logger.info(f'Averaged over all patients:') 142 | for i, met in enumerate(metric_fns): 143 | logger.info(f' {met.__name__}: {total_metrics[i].item() / n_samples}') 144 | 145 | 146 | def evaluate_timestep(avg_seg_volume, target_agg, metric_fns, config, path, patient, patient_metrics, total_metrics, timestep, logger): 147 | prefix = f'{config["evaluate"].value}{(int(patient) + 1):02}_{int(timestep) + 1:02}' 148 | seg_volume = torch.round(avg_seg_volume).int().cpu().detach().numpy() 149 | rotated_seg_volume = rotate(rotate(seg_volume, 90, axes=(0, 1)), -90, axes=(1, 2)) 150 | cropped_seg_volume = rotated_seg_volume[18:-18, :, 18:-18] 151 | nibabel.save(nibabel.Nifti1Image(cropped_seg_volume, np.eye(4)), os.path.join(path, f'{prefix}_seg.nii')) 152 | 153 | target_volume = torch.squeeze(target_agg).int().cpu().detach().numpy() 154 | rotated_target_volume = rotate(rotate(target_volume, 90, axes=(0, 1)), -90, axes=(1, 2)) 155 | cropped_target_volume = rotated_target_volume[18:-18, :, 18:-18] 156 | nibabel.save(nibabel.Nifti1Image(cropped_target_volume, np.eye(4)), os.path.join(path, f'{prefix}_target.nii.gz')) 157 | # computing loss, metrics on test set 158 | logger.info(f'Patient {int(patient) + 1} - Timestep {int(timestep) + 1:02}:') 159 | for i, metric in enumerate(metric_fns): 160 | current_metric = metric(cropped_seg_volume, cropped_target_volume) 161 | logger.info(f' {metric.__name__}: {current_metric}') 162 | patient_metrics[i] += current_metric 163 | total_metrics[i] += current_metric 164 | 165 | 166 | def get_avg_seg_volume(output_dict, data_shape): 167 | axis_volumes = 
torch.zeros([3, *data_shape]) 168 | for i in range(len(output_dict)): 169 | axis_volume = torch.squeeze(output_dict[i]) 170 | if i == 1: 171 | rotated_axis_volume = axis_volume.permute(1, 0, 2) 172 | elif i == 2: 173 | rotated_axis_volume = axis_volume.permute(1, 2, 0) 174 | else: 175 | rotated_axis_volume = axis_volume 176 | axis_volumes[i] = rotated_axis_volume 177 | 178 | # Some explanations for the following line: 179 | # for axis_volumes we only used the predictions for the 1 label. By building the mean over all values up and rounding this we get the value 1 180 | # for those where the label 1 has the majority in softmax space, else 0. This exactly corresponds to our prediction as we would have taken argmax. 181 | return torch.unsqueeze(axis_volumes.mean(dim=0), dim=0) 182 | 183 | 184 | def get_timestep_limit(evaluate, patient, dataset): 185 | if dataset == Dataset.ISBI: 186 | timestep_limit = 4 187 | if evaluate == Evaluate.TEST: 188 | if patient == 1 or patient == 10 or patient == 13: 189 | timestep_limit = 5 190 | elif patient == 9: 191 | timestep_limit = 6 192 | else: 193 | timestep_limit = 4 194 | elif evaluate == Evaluate.TRAINING: 195 | if patient == 2: 196 | timestep_limit = 5 197 | else: 198 | timestep_limit = 4 199 | elif dataset == Dataset.INHOUSE: 200 | timestep_limit = 2 201 | if evaluate == Evaluate.TEST: 202 | if patient == 40 or patient == 87 or patient == 100 or patient == 110: 203 | timestep_limit = 3 204 | elif patient == 122: 205 | timestep_limit = 4 206 | else: 207 | timestep_limit = 2 208 | elif evaluate == Evaluate.TRAINING: 209 | if patient == 27: 210 | timestep_limit = 3 211 | else: 212 | raise ValueError(f'Invalid dataset type given: {dataset}') 213 | 214 | return timestep_limit 215 | 216 | 217 | if __name__ == '__main__': 218 | args = argparse.ArgumentParser(description='PyTorch Template') 219 | args.add_argument('-c', '--config', default=None, type=str, help='config file path (default: None)') 220 | args.add_argument('-r', '--resume', default=None, type=str, help='path to latest checkpoint (default: None)') 221 | args.add_argument('-d', '--device', default=None, type=str, help='indices of GPUs to enable (default: all)') 222 | args.add_argument('-e', '--evaluate', default=Evaluate.TEST, type=Evaluate, help='Either "training" or "test"; Determines the prefix of the folders to use') 223 | args.add_argument('-m', '--dataset_type', default=Dataset.ISBI, type=Dataset, help='Dataset to use') 224 | config = ConfigParser(*parse_cmd_args(args)) 225 | main(config) 226 | -------------------------------------------------------------------------------- /dataset/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from collections import defaultdict, OrderedDict 4 | from enum import Enum 5 | from glob import glob 6 | 7 | import h5py 8 | import nibabel as nib 9 | import numpy as np 10 | import scipy.ndimage 11 | 12 | 13 | class Modalities(Enum): 14 | FLAIR = 'flair' 15 | MPRAGE = 'mprage' 16 | PD = 'pd' 17 | T2 = 't2' 18 | T1W = 't1w' 19 | 20 | 21 | class Phase(Enum): 22 | TRAIN = 'train' 23 | VAL = 'val' 24 | TEST = 'test' 25 | 26 | 27 | class Views(Enum): 28 | AXIAL = 0 29 | SAGITTAL = 1 30 | CORONAL = 2 31 | 32 | 33 | class Mode(Enum): 34 | STATIC = 'static' 35 | LONGITUDINAL = 'longitudinal' 36 | 37 | 38 | class Dataset(Enum): 39 | ISBI = 'isbi' 40 | INHOUSE = 'inhouse' 41 | 42 | 43 | class Evaluate(Enum): 44 | TRAINING = 'training' 45 | TEST = 'test' 46 | 47 | 48 | def 
retrieve_data_dir_paths(data_dir, evaluate: Evaluate, phase, preprocess, val_patients, mode, view=None): 49 | empty_slices = None 50 | non_positive_slices = None 51 | if preprocess: 52 | print('Preprocessing files...') 53 | empty_slices, non_positive_slices = preprocess_files(data_dir, phase, evaluate) 54 | print('Files preprocessed.') 55 | if mode == Mode.LONGITUDINAL: 56 | data_dir_paths = retrieve_paths_longitudinal(get_patient_paths(data_dir, evaluate, phase)).items() 57 | else: 58 | data_dir_paths = retrieve_paths_static(get_patient_paths(data_dir, evaluate, phase)).items() 59 | data_dir_paths = OrderedDict(sorted(data_dir_paths)) 60 | _data_dir_paths = [] 61 | patient_keys = [key for key in data_dir_paths.keys()] 62 | if phase == Phase.TRAIN: 63 | for val_patient in val_patients[::-1]: 64 | patient_keys.remove(patient_keys[val_patient]) 65 | 66 | for patient in patient_keys: 67 | _data_dir_paths += data_dir_paths[patient] 68 | elif phase == Phase.VAL: 69 | for val_patient in val_patients: 70 | _data_dir_paths += data_dir_paths[patient_keys[val_patient]] 71 | else: 72 | for patient in patient_keys: 73 | _data_dir_paths += data_dir_paths[patient] 74 | 75 | if view: 76 | _data_dir_paths = list(filter(lambda path: int(path.split(os.sep)[-2]) == view.value, _data_dir_paths)) 77 | if phase == Phase.TRAIN or phase == Phase.VAL: 78 | _data_dir_paths = retrieve_filtered_data_dir_paths(data_dir, phase, _data_dir_paths, empty_slices, non_positive_slices, 79 | mode, val_patients, view) 80 | return _data_dir_paths 81 | 82 | 83 | def preprocess_files(root_dir, phase, evaluate, base_path='data'): 84 | patients = list(filter(lambda name: (evaluate.value if phase == Phase.TEST else Evaluate.TRAINING.value) in name, os.listdir(root_dir))) 85 | empty_slices = [] 86 | non_positive_slices = [] 87 | i_patients = len(patients) + 1 88 | for patient in patients: 89 | print(f'Processing patient {patient}') 90 | patient_path = os.path.join(root_dir, patient) 91 | if os.path.exists(os.path.join(patient_path, base_path)): 92 | continue 93 | patient_data_path = os.path.join(patient_path, 'preprocessed', patient) 94 | patient_label_path = os.path.join(patient_path, 'masks', patient) 95 | 96 | for modality in list(Modalities): 97 | mod, value = modality.name, modality.value 98 | for timestep in range(10): 99 | data_path = f'{patient_data_path}_0{timestep + 1}_{value}_pp.nii' 100 | if not os.path.exists(data_path): 101 | continue 102 | rotated_data = transform_data(data_path) 103 | normalized_data = (rotated_data - np.min(rotated_data)) / (np.max(rotated_data) - np.min(rotated_data)) 104 | label_path = f'{patient_label_path}_0{timestep + 1}_mask1.nii' 105 | if os.path.exists(label_path): 106 | rotated_labels = transform_data(label_path) 107 | else: 108 | rotated_labels = np.zeros(normalized_data.shape) 109 | 110 | # create slices through all views 111 | temp_empty_slices, temp_non_positive_slices = create_slices(normalized_data, rotated_labels, 112 | os.path.join(patient_path, base_path, str(timestep)), value) 113 | empty_slices += temp_empty_slices 114 | non_positive_slices += temp_non_positive_slices 115 | 116 | i_patients += 1 117 | return empty_slices, non_positive_slices 118 | 119 | 120 | def transform_data(data_path): 121 | data = nib.load(data_path).get_data() 122 | x_dim, y_dim, z_dim = data.shape 123 | max_dim = max(x_dim, y_dim, z_dim) 124 | x_pad = get_padding(max_dim, x_dim) 125 | y_pad = get_padding(max_dim, y_dim) 126 | z_pad = get_padding(max_dim, z_dim) 127 | padded_data = np.pad(data, (x_pad, y_pad, 
z_pad), 'constant') 128 | rotated_data = scipy.ndimage.rotate(scipy.ndimage.rotate(padded_data, 90, axes=(1, 2)), -90, axes=(0, 1)) 129 | return rotated_data 130 | 131 | 132 | def get_padding(max_dim, current_dim): 133 | diff = max_dim - current_dim 134 | pad = diff // 2 135 | if diff % 2 == 0: 136 | return pad, pad 137 | else: 138 | return pad, pad + 1 139 | 140 | 141 | def create_slices(data, label, timestep_path, modality): 142 | empty_slices = [] 143 | non_positive_slices = [] 144 | for view in list(Views): 145 | name, axis = view.name, view.value 146 | temp_data = np.moveaxis(data, axis, 0) 147 | temp_labels = np.moveaxis(label, axis, 0) 148 | for i, (data_slice, label_slice) in enumerate(zip(temp_data, temp_labels)): 149 | path = os.path.join(timestep_path, str(axis), f'{i:03}') 150 | full_path = os.path.join(path, f'{modality}.h5') 151 | if np.sum(data_slice) <= 1e-5: 152 | empty_slices.append(path) 153 | 154 | if np.sum(label_slice) <= 1e-5: 155 | non_positive_slices.append(path) 156 | 157 | while not os.path.exists(full_path): # sometimes file is not created correctly => Just redo until it exists 158 | if not os.path.exists(path): 159 | os.makedirs(path) 160 | with h5py.File(full_path, 'w') as data_file: 161 | data_file.create_dataset('data', data=data_slice, dtype='f') 162 | data_file.create_dataset('label', data=label_slice, dtype='i') 163 | 164 | return empty_slices, non_positive_slices 165 | 166 | 167 | def retrieve_paths_static(patient_paths): 168 | data_dir_paths = defaultdict(list) 169 | for patient_path in patient_paths: 170 | if not os.path.isdir(patient_path): 171 | continue 172 | patient = patient_path.split(os.sep)[-2] 173 | for timestep in filter(lambda x: os.path.isdir(os.path.join(patient_path, x)), os.listdir(patient_path)): 174 | timestep_path = os.path.join(patient_path, timestep) 175 | for axis in filter(lambda x: os.path.isdir(os.path.join(timestep_path, x)), os.listdir(timestep_path)): 176 | axis_path = os.path.join(timestep_path, axis) 177 | slice_paths = filter(lambda x: os.path.isdir(x), map(lambda x: os.path.join(axis_path, x), os.listdir(axis_path))) 178 | data_dir_paths[patient] += slice_paths 179 | 180 | return data_dir_paths 181 | 182 | 183 | def retrieve_paths_longitudinal(patient_paths): 184 | data_dir_paths = defaultdict(list) 185 | for patient_path in patient_paths: 186 | if not os.path.isdir(patient_path): 187 | continue 188 | patient = patient_path.split(os.sep)[-2] 189 | for timestep_x in sorted(filter(lambda x: os.path.isdir(os.path.join(patient_path, x)), os.listdir(patient_path))): 190 | x_timestep = defaultdict(list) 191 | timestep_x_int = int(timestep_x) 192 | timestep_x_path = os.path.join(patient_path, timestep_x) 193 | for axis in sorted(filter(lambda x: os.path.isdir(os.path.join(timestep_x_path, x)), os.listdir(timestep_x_path))): 194 | axis_path = os.path.join(timestep_x_path, axis) 195 | slice_paths = sorted(filter(lambda x: os.path.isdir(x), map(lambda x: os.path.join(axis_path, x), os.listdir(axis_path)))) 196 | x_timestep[axis] = slice_paths 197 | 198 | for timestep_x_ref in sorted(filter(lambda x: os.path.isdir(os.path.join(patient_path, x)), os.listdir(patient_path))): 199 | x_ref_timestep = defaultdict(list) 200 | timestep_x_ref_int = int(timestep_x_ref) 201 | timestep_x_ref_path = os.path.join(patient_path, timestep_x_ref) 202 | for axis in sorted(filter(lambda x: os.path.isdir(os.path.join(timestep_x_ref_path, x)), os.listdir(timestep_x_ref_path))): 203 | axis_path = os.path.join(timestep_x_ref_path, axis) 204 | slice_paths = 
sorted(filter(lambda x: os.path.isdir(x), map(lambda x: os.path.join(axis_path, x), os.listdir(axis_path)))) 205 | x_ref_timestep[axis] = slice_paths 206 | 207 | if timestep_x_int != timestep_x_ref_int: 208 | data_dir_paths[patient] += zip(x_ref_timestep[axis], x_timestep[axis]) 209 | 210 | return data_dir_paths 211 | 212 | 213 | def get_patient_paths(data_dir, evaluate, phase): 214 | patient_paths = map(lambda name: os.path.join(name, 'data'), 215 | (filter(lambda name: (evaluate.value if phase == Phase.TEST else Evaluate.TRAINING.value) in name, 216 | glob(os.path.join(data_dir, '*'))))) 217 | return patient_paths 218 | 219 | 220 | def retrieve_filtered_data_dir_paths(root_dir, phase, data_dir_paths, empty_slices, non_positive_slices, mode, val_patients, view: Views = None): 221 | empty_file_path = os.path.join(root_dir, 'empty_slices.pckl') 222 | non_positive_slices_path = os.path.join(root_dir, 'non_positive_slices.pckl') 223 | 224 | if empty_slices: 225 | pickle.dump(empty_slices, open(empty_file_path, 'wb')) 226 | if non_positive_slices: 227 | pickle.dump(non_positive_slices, open(non_positive_slices_path, 'wb')) 228 | 229 | data_dir_path = os.path.join(root_dir, f'data_dir_{mode.value}_{phase.value}_{val_patients}{f"_{view.name}" if view else ""}.pckl') 230 | if os.path.exists(data_dir_path): 231 | # means it has been preprocessed before -> directly load data_dir_paths 232 | data_dir_paths = pickle.load(open(data_dir_path, 'rb')) 233 | print(f'Elements in data_dir_paths: {len(data_dir_paths)}') 234 | else: 235 | if not empty_slices: 236 | empty_slices = pickle.load(open(empty_file_path, 'rb')) 237 | if not non_positive_slices: 238 | non_positive_slices = pickle.load(open(non_positive_slices_path, 'rb')) 239 | print(f'Elements in data_dir_paths before filtering empty slices: {len(data_dir_paths)}') 240 | if mode == Mode.STATIC: 241 | data_dir_paths = [x for x in data_dir_paths if x not in set(empty_slices + non_positive_slices)] 242 | else: 243 | data_dir_paths = [(x_ref, x) for x_ref, x in data_dir_paths if x not in set(empty_slices + non_positive_slices)] 244 | 245 | print(f'Elements in data_dir_paths after filtering empty slices: {len(data_dir_paths)}') 246 | pickle.dump(data_dir_paths, open(data_dir_path, 'wb')) 247 | 248 | return data_dir_paths 249 | --------------------------------------------------------------------------------
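Illustrative note (not a repository file): `preprocess_files`/`create_slices` above store every 2D slice as an HDF5 file under `<data_dir>/<patient>/data/<timestep>/<axis>/<slice>/<modality>.h5`, which `retrieve_paths_static`/`retrieve_paths_longitudinal` later walk to build the training samples. A minimal sketch of reading one slice back (the concrete path is hypothetical):

```
import h5py

# e.g. patient 'training01', timestep 0, axial view (axis 0), slice 100, FLAIR
with h5py.File('training01/data/0/0/100/flair.h5', 'r') as f:
    image = f['data'][()]   # float slice, min-max normalized to [0, 1]
    label = f['label'][()]  # integer lesion mask for the same slice
print(image.shape, label.shape)
```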