├── src
│   ├── __init__.py
│   ├── metrics.py
│   ├── losses.py
│   ├── data
│   │   ├── splits
│   │   │   ├── train_val_split_0.pkl
│   │   │   ├── train_val_split_1.pkl
│   │   │   ├── train_val_split_2.pkl
│   │   │   ├── train_val_split_3.pkl
│   │   │   ├── train_val_new_split_0.pkl
│   │   │   ├── train_val_new_split_1.pkl
│   │   │   ├── train_val_new_split_2.pkl
│   │   │   └── train_val_new_split_3.pkl
│   │   ├── make_dataset.py
│   │   └── utils.py
│   ├── dataset.py
│   ├── layers.py
│   ├── predictor.py
│   ├── models.py
│   ├── transforms.py
│   └── trainer.py
├── .gitignore
├── config
│   ├── make_dataset.yaml
│   ├── model_train.yaml
│   └── model_predict.yaml
├── LICENSE
├── model
│   ├── predict.py
│   └── train.py
├── README.md
└── notebooks
    └── make_dataset.ipynb

/src/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | temp.ipynb
2 | .ipynb_checkpoints
--------------------------------------------------------------------------------
/config/make_dataset.yaml:
--------------------------------------------------------------------------------
1 | path_to_input: '/home/iantsen/hecktor/data/hecktor_test/hecktor_nii_test/' # directory with images
2 | path_to_bb: '/home/iantsen/hecktor/data/hecktor_test/bbox_test.csv' # file with bounding box coordinates
3 | path_to_output: '/home/iantsen/hecktor/data/hecktor_test/hecktor_nii_resampled/' # directory to save resampled images
4 | is_mask_available: false # if `true`, masks will be preprocessed. Use `false` if masks are unavailable
5 | verbose: false # if `true`, the progress bar will be shown
--------------------------------------------------------------------------------
/src/metrics.py:
--------------------------------------------------------------------------------
1 | def dice(input, target):
2 |     axes = tuple(range(1, input.dim()))
3 |     bin_input = (input > 0.5).float()
4 | 
5 |     intersect = (bin_input * target).sum(dim=axes)
6 |     union = bin_input.sum(dim=axes) + target.sum(dim=axes)
7 |     score = 2 * intersect / (union + 1e-3)
8 | 
9 |     return score.mean()
10 | 
11 | 
12 | def recall(input, target):
13 |     axes = tuple(range(1, input.dim()))
14 |     binary_input = (input > 0.5).float()
15 | 
16 |     true_positives = (binary_input * target).sum(dim=axes)
17 |     all_positives = target.sum(dim=axes)
18 |     recall = true_positives / (all_positives + 1e-3)  # small constant avoids division by zero for empty masks
19 | 
20 |     return recall.mean()
21 | 
22 | 
23 | def precision(input, target):
24 |     axes = tuple(range(1, input.dim()))
25 |     binary_input = (input > 0.5).float()
26 | 
27 |     true_positives = (binary_input * target).sum(dim=axes)
28 |     all_positive_calls = binary_input.sum(dim=axes)
29 |     precision = true_positives / (all_positive_calls + 1e-3)  # small constant avoids division by zero for empty predictions
30 | 
31 |     return precision.mean()
--------------------------------------------------------------------------------
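A quick sanity check of these metrics on toy tensors (an editor's sketch, not a file from the repo; it assumes `src/` is on `sys.path`, as in `model/train.py`):

import torch
from metrics import dice, recall, precision

target = torch.zeros(1, 1, 4, 4, 4)
target[0, 0, 1:3, 1:3, 1:3] = 1   # 8 tumor voxels
pred = target.clone()
pred[0, 0, 1, 1, 1] = 0           # one missed voxel
print(dice(pred, target))         # ~0.933: 2 * 7 / (7 + 8), up to the smoothing constant
print(recall(pred, target))       # ~0.875: 7 of 8 tumor voxels recovered
print(precision(pred, target))    # ~1.0: no false positives
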
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 iantsen
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/config/model_train.yaml:
--------------------------------------------------------------------------------
1 | # paths:
2 | path_to_data: 'C:/inserm/hecktor/hecktor_train/hecktor_nii_resampled/' # directory with images
3 | path_to_pkl: 'C:/inserm/hecktor/splits/train_val_split_0.pkl' # pkl file with train / val splits
4 | path_to_save_dir: 'C:/inserm/hecktor/results/' # all results (weights, learning curves, etc.) will be saved here
5 | 
6 | # train settings:
7 | train_batch_size: 1
8 | val_batch_size: 1
9 | num_workers: 2 # e.g., the number of CPU cores
10 | 
11 | lr: 1e-3 # initial learning rate
12 | n_epochs: 2 # number of training epochs (300 was used in the paper)
13 | n_cls: 2 # number of classes to predict (background and tumor)
14 | in_channels: 2 # number of input modalities
15 | n_filters: 4 # number of filters after the input (24 was used in the paper)
16 | reduction: 2 # parameter controlling the size of the bottleneck in SENorm layers
17 | 
18 | T_0: 25 # parameter for 'torch.optim.lr_scheduler.CosineAnnealingWarmRestarts'
19 | eta_min: 1e-5 # parameter for 'torch.optim.lr_scheduler.CosineAnnealingWarmRestarts'
20 | 
21 | # model:
22 | baseline: false # if `true`, U-Net will be used. Otherwise, the model described in the paper will be trained.
23 | 
--------------------------------------------------------------------------------
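For reference, `T_0` and `eta_min` are passed straight to the PyTorch scheduler; a minimal sketch of the wiring (mirroring `model/train.py`, with a stand-in model):

import torch
model = torch.nn.Conv3d(2, 1, kernel_size=3, padding=1)  # stand-in for the real network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.99))
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=25, eta_min=1e-5)
for epoch in range(300):
    # ... one training epoch ...
    scheduler.step()  # the lr anneals from `lr` towards `eta_min` and restarts every T_0 epochs
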
/src/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | 
5 | class DiceLoss(nn.Module):
6 |     def __init__(self):
7 |         super(DiceLoss, self).__init__()
8 |         self.smooth = 1
9 | 
10 |     def forward(self, input, target):
11 |         axes = tuple(range(1, input.dim()))
12 |         intersect = (input * target).sum(dim=axes)
13 |         union = torch.pow(input, 2).sum(dim=axes) + torch.pow(target, 2).sum(dim=axes)
14 |         loss = 1 - (2 * intersect + self.smooth) / (union + self.smooth)
15 |         return loss.mean()
16 | 
17 | 
18 | class FocalLoss(nn.Module):
19 |     def __init__(self, gamma=2):
20 |         super(FocalLoss, self).__init__()
21 |         self.gamma = gamma
22 |         self.eps = 1e-3
23 | 
24 |     def forward(self, input, target):
25 |         input = input.clamp(self.eps, 1 - self.eps)
26 |         loss = - (target * torch.pow((1 - input), self.gamma) * torch.log(input) +
27 |                   (1 - target) * torch.pow(input, self.gamma) * torch.log(1 - input))
28 |         return loss.mean()
29 | 
30 | 
31 | class Dice_and_FocalLoss(nn.Module):
32 |     def __init__(self, gamma=2):
33 |         super(Dice_and_FocalLoss, self).__init__()
34 |         self.dice_loss = DiceLoss()
35 |         self.focal_loss = FocalLoss(gamma)
36 | 
37 |     def forward(self, input, target):
38 |         loss = self.dice_loss(input, target) + self.focal_loss(input, target)
39 |         return loss
40 | 
--------------------------------------------------------------------------------
/config/model_predict.yaml:
--------------------------------------------------------------------------------
1 | # paths:
2 | path_to_data: '/home/iantsen/hecktor/data/hecktor_test/hecktor_nii_resampled/' # directory with test images
3 | path_to_save_dir: '/home/iantsen/hecktor/data/hecktor_test/preds/' # predictions will be saved here
4 | 
5 | path_to_weights: # path or paths to weights. If multiple paths are provided, an ensemble of models will be used
6 | - '/home/iantsen/hecktor/model/weights/s0_best_model_weights.pt'
7 | - '/home/iantsen/hecktor/model/weights/s1_best_model_weights.pt'
8 | - '/home/iantsen/hecktor/model/weights/s2_best_model_weights.pt'
9 | - '/home/iantsen/hecktor/model/weights/s3_best_model_weights.pt'
10 | - '/home/iantsen/hecktor/model/weights/ns0_best_model_weights.pt'
11 | - '/home/iantsen/hecktor/model/weights/ns1_best_model_weights.pt'
12 | - '/home/iantsen/hecktor/model/weights/ns2_best_model_weights.pt'
13 | - '/home/iantsen/hecktor/model/weights/ns3_best_model_weights.pt'
14 | 
15 | # output:
16 | probs: false # if `true`, the sigmoid output will be saved.
Otherwise, 0.5-threshold will be applied to get binary labels 17 | 18 | # train settings: 19 | num_workers: 2 # for example, use a number of CPU cores 20 | 21 | n_cls: 2 # number of classes to predict (background and tumor) 22 | in_channels: 2 # number of input modalities 23 | n_filters: 24 # number of filters after the input (24 was used in the paper) 24 | reduction: 2 # parameter controls the size of the bottleneck in SENorm layers 25 | -------------------------------------------------------------------------------- /src/data/splits/train_val_split_0.pkl: -------------------------------------------------------------------------------- 1 | {"train": ["CHUS003", "CHUS004", "CHUS005", "CHUS006", "CHUS007", "CHUS008", "CHUS009", "CHUS010", "CHUS013", "CHUS015", "CHUS016", "CHUS019", "CHUS020", "CHUS021", "CHUS022", "CHUS026", "CHUS027", "CHUS028", "CHUS030", "CHUS031", "CHUS033", "CHUS035", "CHUS036", "CHUS038", "CHUS039", "CHUS040", "CHUS041", "CHUS042", "CHUS043", "CHUS045", "CHUS046", "CHUS047", "CHUS048", "CHUS049", "CHUS050", "CHUS051", "CHUS052", "CHUS053", "CHUS055", "CHUS056", "CHUS057", "CHUS058", "CHUS060", "CHUS061", "CHUS064", "CHUS065", "CHUS066", "CHUS067", "CHUS068", "CHUS069", "CHUS073", "CHUS074", "CHUS076", "CHUS077", "CHUS078", "CHUS080", "CHUS081", "CHUS083", "CHUS085", "CHUS086", "CHUS087", "CHUS088", "CHUS089", "CHUS090", "CHUS091", "CHUS094", "CHUS095", "CHUS096", "CHUS097", "CHUS098", "CHUS100", "CHUS101", "CHUM001", "CHUM002", "CHUM006", "CHUM007", "CHUM008", "CHUM010", "CHUM011", "CHUM012", "CHUM013", "CHUM014", "CHUM015", "CHUM016", "CHUM017", "CHUM018", "CHUM019", "CHUM021", "CHUM022", "CHUM023", "CHUM024", "CHUM026", "CHUM027", "CHUM029", "CHUM030", "CHUM032", "CHUM033", "CHUM034", "CHUM035", "CHUM036", "CHUM037", "CHUM038", "CHUM039", "CHUM040", "CHUM041", "CHUM042", "CHUM043", "CHUM044", "CHUM045", "CHUM046", "CHUM047", "CHUM048", "CHUM049", "CHUM050", "CHUM051", "CHUM053", "CHUM054", "CHUM055", "CHUM056", "CHUM057", "CHUM058", "CHUM059", "CHUM060", "CHUM061", "CHUM062", "CHUM063", "CHUM064", "CHUM065", "CHMR001", "CHMR004", "CHMR005", "CHMR011", "CHMR012", "CHMR013", "CHMR014", "CHMR016", "CHMR020", "CHMR021", "CHMR023", "CHMR024", "CHMR025", "CHMR028", "CHMR029", "CHMR030", "CHMR034", "CHMR040"], "val": ["CHGJ007", "CHGJ008", "CHGJ010", "CHGJ013", "CHGJ015", "CHGJ016", "CHGJ017", "CHGJ018", "CHGJ025", "CHGJ026", "CHGJ028", "CHGJ029", "CHGJ030", "CHGJ031", "CHGJ032", "CHGJ034", "CHGJ035", "CHGJ036", "CHGJ037", "CHGJ038", "CHGJ039", "CHGJ043", "CHGJ046", "CHGJ048", "CHGJ050", "CHGJ052", "CHGJ053", "CHGJ055", "CHGJ057", "CHGJ058", "CHGJ062", "CHGJ065", "CHGJ066", "CHGJ067", "CHGJ069", "CHGJ070", "CHGJ071", "CHGJ072", "CHGJ073", "CHGJ074", "CHGJ076", "CHGJ077", "CHGJ078", "CHGJ080", "CHGJ081", "CHGJ082", "CHGJ083", "CHGJ085", "CHGJ086", "CHGJ087", "CHGJ088", "CHGJ089", "CHGJ090", "CHGJ091", "CHGJ092"]} -------------------------------------------------------------------------------- /src/data/splits/train_val_split_1.pkl: -------------------------------------------------------------------------------- 1 | {"train": ["CHGJ007", "CHGJ008", "CHGJ010", "CHGJ013", "CHGJ015", "CHGJ016", "CHGJ017", "CHGJ018", "CHGJ025", "CHGJ026", "CHGJ028", "CHGJ029", "CHGJ030", "CHGJ031", "CHGJ032", "CHGJ034", "CHGJ035", "CHGJ036", "CHGJ037", "CHGJ038", "CHGJ039", "CHGJ043", "CHGJ046", "CHGJ048", "CHGJ050", "CHGJ052", "CHGJ053", "CHGJ055", "CHGJ057", "CHGJ058", "CHGJ062", "CHGJ065", "CHGJ066", "CHGJ067", "CHGJ069", "CHGJ070", "CHGJ071", "CHGJ072", "CHGJ073", "CHGJ074", "CHGJ076", 
"CHGJ077", "CHGJ078", "CHGJ080", "CHGJ081", "CHGJ082", "CHGJ083", "CHGJ085", "CHGJ086", "CHGJ087", "CHGJ088", "CHGJ089", "CHGJ090", "CHGJ091", "CHGJ092", "CHUM001", "CHUM002", "CHUM006", "CHUM007", "CHUM008", "CHUM010", "CHUM011", "CHUM012", "CHUM013", "CHUM014", "CHUM015", "CHUM016", "CHUM017", "CHUM018", "CHUM019", "CHUM021", "CHUM022", "CHUM023", "CHUM024", "CHUM026", "CHUM027", "CHUM029", "CHUM030", "CHUM032", "CHUM033", "CHUM034", "CHUM035", "CHUM036", "CHUM037", "CHUM038", "CHUM039", "CHUM040", "CHUM041", "CHUM042", "CHUM043", "CHUM044", "CHUM045", "CHUM046", "CHUM047", "CHUM048", "CHUM049", "CHUM050", "CHUM051", "CHUM053", "CHUM054", "CHUM055", "CHUM056", "CHUM057", "CHUM058", "CHUM059", "CHUM060", "CHUM061", "CHUM062", "CHUM063", "CHUM064", "CHUM065", "CHMR001", "CHMR004", "CHMR005", "CHMR011", "CHMR012", "CHMR013", "CHMR014", "CHMR016", "CHMR020", "CHMR021", "CHMR023", "CHMR024", "CHMR025", "CHMR028", "CHMR029", "CHMR030", "CHMR034", "CHMR040"], "val": ["CHUS003", "CHUS004", "CHUS005", "CHUS006", "CHUS007", "CHUS008", "CHUS009", "CHUS010", "CHUS013", "CHUS015", "CHUS016", "CHUS019", "CHUS020", "CHUS021", "CHUS022", "CHUS026", "CHUS027", "CHUS028", "CHUS030", "CHUS031", "CHUS033", "CHUS035", "CHUS036", "CHUS038", "CHUS039", "CHUS040", "CHUS041", "CHUS042", "CHUS043", "CHUS045", "CHUS046", "CHUS047", "CHUS048", "CHUS049", "CHUS050", "CHUS051", "CHUS052", "CHUS053", "CHUS055", "CHUS056", "CHUS057", "CHUS058", "CHUS060", "CHUS061", "CHUS064", "CHUS065", "CHUS066", "CHUS067", "CHUS068", "CHUS069", "CHUS073", "CHUS074", "CHUS076", "CHUS077", "CHUS078", "CHUS080", "CHUS081", "CHUS083", "CHUS085", "CHUS086", "CHUS087", "CHUS088", "CHUS089", "CHUS090", "CHUS091", "CHUS094", "CHUS095", "CHUS096", "CHUS097", "CHUS098", "CHUS100", "CHUS101"]} -------------------------------------------------------------------------------- /src/data/splits/train_val_split_2.pkl: -------------------------------------------------------------------------------- 1 | {"train": ["CHGJ007", "CHGJ008", "CHGJ010", "CHGJ013", "CHGJ015", "CHGJ016", "CHGJ017", "CHGJ018", "CHGJ025", "CHGJ026", "CHGJ028", "CHGJ029", "CHGJ030", "CHGJ031", "CHGJ032", "CHGJ034", "CHGJ035", "CHGJ036", "CHGJ037", "CHGJ038", "CHGJ039", "CHGJ043", "CHGJ046", "CHGJ048", "CHGJ050", "CHGJ052", "CHGJ053", "CHGJ055", "CHGJ057", "CHGJ058", "CHGJ062", "CHGJ065", "CHGJ066", "CHGJ067", "CHGJ069", "CHGJ070", "CHGJ071", "CHGJ072", "CHGJ073", "CHGJ074", "CHGJ076", "CHGJ077", "CHGJ078", "CHGJ080", "CHGJ081", "CHGJ082", "CHGJ083", "CHGJ085", "CHGJ086", "CHGJ087", "CHGJ088", "CHGJ089", "CHGJ090", "CHGJ091", "CHGJ092", "CHUS003", "CHUS004", "CHUS005", "CHUS006", "CHUS007", "CHUS008", "CHUS009", "CHUS010", "CHUS013", "CHUS015", "CHUS016", "CHUS019", "CHUS020", "CHUS021", "CHUS022", "CHUS026", "CHUS027", "CHUS028", "CHUS030", "CHUS031", "CHUS033", "CHUS035", "CHUS036", "CHUS038", "CHUS039", "CHUS040", "CHUS041", "CHUS042", "CHUS043", "CHUS045", "CHUS046", "CHUS047", "CHUS048", "CHUS049", "CHUS050", "CHUS051", "CHUS052", "CHUS053", "CHUS055", "CHUS056", "CHUS057", "CHUS058", "CHUS060", "CHUS061", "CHUS064", "CHUS065", "CHUS066", "CHUS067", "CHUS068", "CHUS069", "CHUS073", "CHUS074", "CHUS076", "CHUS077", "CHUS078", "CHUS080", "CHUS081", "CHUS083", "CHUS085", "CHUS086", "CHUS087", "CHUS088", "CHUS089", "CHUS090", "CHUS091", "CHUS094", "CHUS095", "CHUS096", "CHUS097", "CHUS098", "CHUS100", "CHUS101", "CHMR001", "CHMR004", "CHMR005", "CHMR011", "CHMR012", "CHMR013", "CHMR014", "CHMR016", "CHMR020", "CHMR021", "CHMR023", "CHMR024", "CHMR025", "CHMR028", "CHMR029", 
"CHMR030", "CHMR034", "CHMR040"], "val": ["CHUM001", "CHUM002", "CHUM006", "CHUM007", "CHUM008", "CHUM010", "CHUM011", "CHUM012", "CHUM013", "CHUM014", "CHUM015", "CHUM016", "CHUM017", "CHUM018", "CHUM019", "CHUM021", "CHUM022", "CHUM023", "CHUM024", "CHUM026", "CHUM027", "CHUM029", "CHUM030", "CHUM032", "CHUM033", "CHUM034", "CHUM035", "CHUM036", "CHUM037", "CHUM038", "CHUM039", "CHUM040", "CHUM041", "CHUM042", "CHUM043", "CHUM044", "CHUM045", "CHUM046", "CHUM047", "CHUM048", "CHUM049", "CHUM050", "CHUM051", "CHUM053", "CHUM054", "CHUM055", "CHUM056", "CHUM057", "CHUM058", "CHUM059", "CHUM060", "CHUM061", "CHUM062", "CHUM063", "CHUM064", "CHUM065"]} -------------------------------------------------------------------------------- /src/data/splits/train_val_split_3.pkl: -------------------------------------------------------------------------------- 1 | {"train": ["CHGJ007", "CHGJ008", "CHGJ010", "CHGJ013", "CHGJ015", "CHGJ016", "CHGJ017", "CHGJ018", "CHGJ025", "CHGJ026", "CHGJ028", "CHGJ029", "CHGJ030", "CHGJ031", "CHGJ032", "CHGJ034", "CHGJ035", "CHGJ036", "CHGJ037", "CHGJ038", "CHGJ039", "CHGJ043", "CHGJ046", "CHGJ048", "CHGJ050", "CHGJ052", "CHGJ053", "CHGJ055", "CHGJ057", "CHGJ058", "CHGJ062", "CHGJ065", "CHGJ066", "CHGJ067", "CHGJ069", "CHGJ070", "CHGJ071", "CHGJ072", "CHGJ073", "CHGJ074", "CHGJ076", "CHGJ077", "CHGJ078", "CHGJ080", "CHGJ081", "CHGJ082", "CHGJ083", "CHGJ085", "CHGJ086", "CHGJ087", "CHGJ088", "CHGJ089", "CHGJ090", "CHGJ091", "CHGJ092", "CHUS003", "CHUS004", "CHUS005", "CHUS006", "CHUS007", "CHUS008", "CHUS009", "CHUS010", "CHUS013", "CHUS015", "CHUS016", "CHUS019", "CHUS020", "CHUS021", "CHUS022", "CHUS026", "CHUS027", "CHUS028", "CHUS030", "CHUS031", "CHUS033", "CHUS035", "CHUS036", "CHUS038", "CHUS039", "CHUS040", "CHUS041", "CHUS042", "CHUS043", "CHUS045", "CHUS046", "CHUS047", "CHUS048", "CHUS049", "CHUS050", "CHUS051", "CHUS052", "CHUS053", "CHUS055", "CHUS056", "CHUS057", "CHUS058", "CHUS060", "CHUS061", "CHUS064", "CHUS065", "CHUS066", "CHUS067", "CHUS068", "CHUS069", "CHUS073", "CHUS074", "CHUS076", "CHUS077", "CHUS078", "CHUS080", "CHUS081", "CHUS083", "CHUS085", "CHUS086", "CHUS087", "CHUS088", "CHUS089", "CHUS090", "CHUS091", "CHUS094", "CHUS095", "CHUS096", "CHUS097", "CHUS098", "CHUS100", "CHUS101", "CHUM001", "CHUM002", "CHUM006", "CHUM007", "CHUM008", "CHUM010", "CHUM011", "CHUM012", "CHUM013", "CHUM014", "CHUM015", "CHUM016", "CHUM017", "CHUM018", "CHUM019", "CHUM021", "CHUM022", "CHUM023", "CHUM024", "CHUM026", "CHUM027", "CHUM029", "CHUM030", "CHUM032", "CHUM033", "CHUM034", "CHUM035", "CHUM036", "CHUM037", "CHUM038", "CHUM039", "CHUM040", "CHUM041", "CHUM042", "CHUM043", "CHUM044", "CHUM045", "CHUM046", "CHUM047", "CHUM048", "CHUM049", "CHUM050", "CHUM051", "CHUM053", "CHUM054", "CHUM055", "CHUM056", "CHUM057", "CHUM058", "CHUM059", "CHUM060", "CHUM061", "CHUM062", "CHUM063", "CHUM064", "CHUM065"], "val": ["CHMR001", "CHMR004", "CHMR005", "CHMR011", "CHMR012", "CHMR013", "CHMR014", "CHMR016", "CHMR020", "CHMR021", "CHMR023", "CHMR024", "CHMR025", "CHMR028", "CHMR029", "CHMR030", "CHMR034", "CHMR040"]} -------------------------------------------------------------------------------- /src/data/splits/train_val_new_split_0.pkl: -------------------------------------------------------------------------------- 1 | {"train": ["CHUS052", "CHUS053", "CHMR024", "CHUS050", "CHUM038", "CHGJ030", "CHMR020", "CHMR021", "CHUS065", "CHGJ052", "CHMR014", "CHUS036", "CHMR040", "CHUM032", "CHUS033", "CHUS069", "CHUM044", "CHUS021", "CHUS013", "CHUS094", "CHGJ035", 
"CHUM006", "CHMR011", "CHUM002", "CHUM046", "CHGJ038", "CHUM043", "CHGJ073", "CHGJ085", "CHMR030", "CHUM065", "CHUM007", "CHUM061", "CHUS003", "CHUM037", "CHUS027", "CHGJ062", "CHMR013", "CHUM016", "CHUS030", "CHGJ086", "CHUS043", "CHGJ025", "CHUS091", "CHUM029", "CHGJ091", "CHGJ048", "CHUS051", "CHUS090", "CHUM026", "CHGJ046", "CHUM036", "CHUS009", "CHUM023", "CHGJ008", "CHUM012", "CHMR016", "CHUS097", "CHUM060", "CHUS042", "CHGJ078", "CHUM054", "CHMR005", "CHGJ026", "CHUS007", "CHUM014", "CHUS098", "CHUS004", "CHUS066", "CHUS035", "CHGJ018", "CHGJ077", "CHUS049", "CHUM048", "CHUM040", "CHUS078", "CHMR012", "CHUM015", "CHUS074", "CHUS045", "CHUS048", "CHGJ088", "CHGJ083", "CHUS055", "CHGJ050", "CHUM050", "CHGJ065", "CHUS068", "CHGJ072", "CHUS089", "CHGJ071", "CHUS077", "CHGJ037", "CHUS016", "CHUS083", "CHUS005", "CHUS067", "CHGJ028", "CHUS064", "CHUS056", "CHUM034", "CHUS010", "CHMR025", "CHUM001", "CHGJ013", "CHUM035", "CHGJ070", "CHUS095", "CHGJ057", "CHUS046", "CHUM059", "CHUM013", "CHUM064", "CHUM022", "CHMR001", "CHUM011", "CHMR034", "CHUM030", "CHUS031", "CHGJ031", "CHUS100", "CHGJ087", "CHUM062", "CHMR029", "CHUS096", "CHUM021", "CHUS038", "CHUM010", "CHUS088", "CHGJ067", "CHUS060", "CHUS076", "CHGJ034", "CHUM017", "CHUM063", "CHGJ058", "CHUS019", "CHUM058", "CHUM033", "CHUM045", "CHUM053", "CHUM041", "CHUS101", "CHUM039", "CHUM008", "CHGJ036", "CHUS041", "CHUS039", "CHGJ032", "CHUM056"], "val": ["CHUM051", "CHUM055", "CHUS073", "CHGJ010", "CHUS026", "CHUS022", "CHGJ039", "CHUS015", "CHUS061", "CHUS008", "CHGJ076", "CHGJ007", "CHUS087", "CHGJ029", "CHGJ017", "CHUM042", "CHGJ080", "CHUM019", "CHMR028", "CHGJ015", "CHGJ081", "CHGJ043", "CHGJ069", "CHGJ016", "CHUS020", "CHUS058", "CHGJ082", "CHUS047", "CHUM018", "CHUS080", "CHGJ055", "CHGJ066", "CHGJ089", "CHUM057", "CHUS081", "CHGJ090", "CHUS085", "CHUM049", "CHGJ074", "CHUS040", "CHUS086", "CHUS006", "CHUM027", "CHUS028", "CHUM024", "CHGJ053", "CHMR004", "CHGJ092", "CHUS057", "CHMR023", "CHUM047"]} -------------------------------------------------------------------------------- /src/data/splits/train_val_new_split_1.pkl: -------------------------------------------------------------------------------- 1 | {"train": ["CHUM051", "CHUM055", "CHUS073", "CHGJ010", "CHUS026", "CHUS022", "CHGJ039", "CHUS015", "CHUS061", "CHUS008", "CHGJ076", "CHGJ007", "CHUS087", "CHGJ029", "CHGJ017", "CHUM042", "CHGJ080", "CHUM019", "CHMR028", "CHGJ015", "CHGJ081", "CHGJ043", "CHGJ069", "CHGJ016", "CHUS020", "CHUS058", "CHGJ082", "CHUS047", "CHUM018", "CHUS080", "CHGJ055", "CHGJ066", "CHGJ089", "CHUM057", "CHUS081", "CHGJ090", "CHUS085", "CHUM049", "CHGJ074", "CHUS040", "CHUS086", "CHUS006", "CHUM027", "CHUS028", "CHUM024", "CHGJ053", "CHMR004", "CHGJ092", "CHUS057", "CHMR023", "CHUM047", "CHUM002", "CHGJ038", "CHUM029", "CHGJ091", "CHUS051", "CHUS090", "CHGJ046", "CHUS009", "CHUM023", "CHGJ008", "CHUM012", "CHMR016", "CHUM060", "CHUS042", "CHGJ078", "CHMR005", "CHGJ026", "CHUS007", "CHUM014", "CHUS004", "CHUS066", "CHUS035", "CHGJ018", "CHGJ077", "CHUS049", "CHUM048", "CHUM040", "CHUS078", "CHMR012", "CHUM015", "CHUS074", "CHUS045", "CHUS048", "CHGJ088", "CHGJ083", "CHUM050", "CHGJ065", "CHUS068", "CHGJ072", "CHUS089", "CHGJ071", "CHUS077", "CHGJ037", "CHUS016", "CHUS083", "CHUS005", "CHUS067", "CHGJ028", "CHUS064", "CHUS056", "CHUM034", "CHUS010", "CHMR025", "CHUM001", "CHGJ013", "CHUM035", "CHGJ070", "CHUS095", "CHGJ057", "CHUS046", "CHUM059", "CHUM013", "CHUM064", "CHUM022", "CHMR001", "CHUM011", "CHMR034", "CHUM030", "CHUS031", "CHGJ031", "CHUS100", 
"CHGJ087", "CHUM062", "CHMR029", "CHUS096", "CHUM021", "CHUS038", "CHUM010", "CHUS088", "CHGJ067", "CHUS060", "CHUS076", "CHGJ034", "CHUM017", "CHUM063", "CHGJ058", "CHUS019", "CHUM058", "CHUM033", "CHUM045", "CHUM053", "CHUM041", "CHUS101", "CHUM039", "CHUM008", "CHGJ036", "CHUS041", "CHUS039", "CHGJ032", "CHUM056"], "val": ["CHUS052", "CHUS053", "CHMR024", "CHUS050", "CHUM038", "CHGJ030", "CHMR020", "CHMR021", "CHUS065", "CHGJ052", "CHMR014", "CHUS036", "CHMR040", "CHUM032", "CHUS033", "CHUS069", "CHUM044", "CHUS021", "CHUS013", "CHUS094", "CHGJ035", "CHUM006", "CHMR011", "CHUM046", "CHUM043", "CHGJ073", "CHGJ085", "CHMR030", "CHUM065", "CHUM007", "CHUM061", "CHUS003", "CHUM037", "CHUS027", "CHGJ062", "CHMR013", "CHUM016", "CHUS030", "CHGJ086", "CHUS043", "CHGJ025", "CHUS091", "CHGJ048", "CHUM026", "CHUM036", "CHUS097", "CHUM054", "CHUS098", "CHUS055", "CHGJ050"]} -------------------------------------------------------------------------------- /src/data/splits/train_val_new_split_2.pkl: -------------------------------------------------------------------------------- 1 | {"train": ["CHUM051", "CHUM055", "CHUS073", "CHGJ010", "CHUS026", "CHUS022", "CHGJ039", "CHUS015", "CHUS061", "CHUS008", "CHGJ076", "CHGJ007", "CHUS087", "CHGJ029", "CHGJ017", "CHUM042", "CHGJ080", "CHUM019", "CHMR028", "CHGJ015", "CHGJ081", "CHGJ043", "CHGJ069", "CHGJ016", "CHUS020", "CHUS058", "CHGJ082", "CHUS047", "CHUM018", "CHUS080", "CHGJ055", "CHUS052", "CHGJ066", "CHGJ089", "CHUM057", "CHUS081", "CHGJ090", "CHUS085", "CHUM049", "CHGJ074", "CHUS040", "CHUS053", "CHUS086", "CHUS006", "CHUM027", "CHUS028", "CHMR024", "CHUM024", "CHGJ053", "CHUS050", "CHMR004", "CHGJ092", "CHUM038", "CHGJ030", "CHMR020", "CHMR021", "CHUS065", "CHUS057", "CHGJ052", "CHMR014", "CHUS036", "CHMR023", "CHMR040", "CHUM032", "CHUS033", "CHUS069", "CHUM044", "CHUM047", "CHUS021", "CHUS013", "CHUS094", "CHGJ035", "CHUM006", "CHMR011", "CHUM046", "CHUM043", "CHGJ073", "CHGJ085", "CHMR030", "CHUM065", "CHUM007", "CHUM061", "CHUS003", "CHUM037", "CHUS027", "CHGJ062", "CHMR013", "CHUM016", "CHUS030", "CHGJ086", "CHUS043", "CHGJ025", "CHUS091", "CHGJ048", "CHUM026", "CHUM036", "CHUS097", "CHUM054", "CHUS098", "CHGJ083", "CHUS055", "CHGJ050", "CHUM050", "CHGJ071", "CHUS005", "CHUS064", "CHUM034", "CHGJ013", "CHUM035", "CHUS095", "CHGJ057", "CHUS046", "CHUM059", "CHUM013", "CHUM022", "CHMR001", "CHUM011", "CHMR034", "CHUM030", "CHUS031", "CHGJ031", "CHUS100", "CHGJ087", "CHUM062", "CHMR029", "CHUS096", "CHUM021", "CHUS038", "CHUM010", "CHUS088", "CHGJ067", "CHUS060", "CHUS076", "CHGJ034", "CHUM017", "CHUM063", "CHGJ058", "CHUS019", "CHUM058", "CHUM033", "CHUM045", "CHUM053", "CHUM041", "CHUS101", "CHUM039", "CHUM008", "CHGJ036", "CHUS041", "CHUS039", "CHGJ032", "CHUM056"], "val": ["CHUM002", "CHGJ038", "CHUM029", "CHGJ091", "CHUS051", "CHUS090", "CHGJ046", "CHUS009", "CHUM023", "CHGJ008", "CHUM012", "CHMR016", "CHUM060", "CHUS042", "CHGJ078", "CHMR005", "CHGJ026", "CHUS007", "CHUM014", "CHUS004", "CHUS066", "CHUS035", "CHGJ018", "CHGJ077", "CHUS049", "CHUM048", "CHUM040", "CHUS078", "CHMR012", "CHUM015", "CHUS074", "CHUS045", "CHUS048", "CHGJ088", "CHGJ065", "CHUS068", "CHGJ072", "CHUS089", "CHUS077", "CHGJ037", "CHUS016", "CHUS083", "CHUS067", "CHGJ028", "CHUS056", "CHUS010", "CHMR025", "CHUM001", "CHGJ070", "CHUM064"]} -------------------------------------------------------------------------------- /src/data/splits/train_val_new_split_3.pkl: -------------------------------------------------------------------------------- 1 | {"train": ["CHUM051", 
"CHUM055", "CHUS073", "CHGJ010", "CHUS026", "CHUS022", "CHGJ039", "CHUS015", "CHUS061", "CHUS008", "CHGJ076", "CHGJ007", "CHUS087", "CHGJ029", "CHGJ017", "CHUM042", "CHGJ080", "CHUM019", "CHMR028", "CHGJ015", "CHGJ081", "CHGJ043", "CHGJ069", "CHGJ016", "CHUS020", "CHUS058", "CHGJ082", "CHUS047", "CHUM018", "CHUS080", "CHGJ055", "CHUS052", "CHGJ066", "CHGJ089", "CHUM057", "CHUS081", "CHGJ090", "CHUS085", "CHUM049", "CHGJ074", "CHUS040", "CHUS053", "CHUS086", "CHUS006", "CHUM027", "CHUS028", "CHMR024", "CHUM024", "CHGJ053", "CHUS050", "CHMR004", "CHGJ092", "CHUM038", "CHGJ030", "CHMR020", "CHMR021", "CHUS065", "CHUS057", "CHGJ052", "CHMR014", "CHUS036", "CHMR023", "CHMR040", "CHUM032", "CHUS033", "CHUS069", "CHUM044", "CHUM047", "CHUS021", "CHUS013", "CHUS094", "CHGJ035", "CHUM006", "CHMR011", "CHUM002", "CHUM046", "CHGJ038", "CHUM043", "CHGJ073", "CHGJ085", "CHMR030", "CHUM065", "CHUM007", "CHUM061", "CHUS003", "CHUM037", "CHUS027", "CHGJ062", "CHMR013", "CHUM016", "CHUS030", "CHGJ086", "CHUS043", "CHGJ025", "CHUS091", "CHUM029", "CHGJ091", "CHGJ048", "CHUS051", "CHUS090", "CHUM026", "CHGJ046", "CHUM036", "CHUS009", "CHUM023", "CHGJ008", "CHUM012", "CHMR016", "CHUS097", "CHUM060", "CHUS042", "CHGJ078", "CHUM054", "CHMR005", "CHGJ026", "CHUS007", "CHUM014", "CHUS098", "CHUS004", "CHUS066", "CHUS035", "CHGJ018", "CHGJ077", "CHUS049", "CHUM048", "CHUM040", "CHUS078", "CHMR012", "CHUM015", "CHUS074", "CHUS045", "CHUS048", "CHGJ088", "CHUS055", "CHGJ050", "CHGJ065", "CHUS068", "CHGJ072", "CHUS089", "CHUS077", "CHGJ037", "CHUS016", "CHUS083", "CHUS067", "CHGJ028", "CHUS056", "CHUS010", "CHMR025", "CHUM001", "CHGJ070", "CHUM064"], "val": ["CHGJ083", "CHUM050", "CHGJ071", "CHUS005", "CHUS064", "CHUM034", "CHGJ013", "CHUM035", "CHUS095", "CHGJ057", "CHUS046", "CHUM059", "CHUM013", "CHUM022", "CHMR001", "CHUM011", "CHMR034", "CHUM030", "CHUS031", "CHGJ031", "CHUS100", "CHGJ087", "CHUM062", "CHMR029", "CHUS096", "CHUM021", "CHUS038", "CHUM010", "CHUS088", "CHGJ067", "CHUS060", "CHUS076", "CHGJ034", "CHUM017", "CHUM063", "CHGJ058", "CHUS019", "CHUM058", "CHUM033", "CHUM045", "CHUM053", "CHUM041", "CHUS101", "CHUM039", "CHUM008", "CHGJ036", "CHUS041", "CHUS039", "CHGJ032", "CHUM056"]} -------------------------------------------------------------------------------- /model/predict.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import yaml 4 | import pathlib 5 | 6 | import torch 7 | from torch.utils.data import DataLoader 8 | torch.backends.cudnn.benchmark = True 9 | 10 | sys.path.append('../src/') 11 | sys.path.append('../src/data/') 12 | import dataset 13 | import transforms 14 | import utils 15 | import models 16 | import predictor 17 | 18 | 19 | def main(args): 20 | path_to_config = pathlib.Path(args.path) 21 | with open(path_to_config) as f: 22 | config = yaml.safe_load(f) 23 | 24 | # read config: 25 | path_to_data = pathlib.Path(config['path_to_data']) 26 | path_to_save_dir = pathlib.Path(config['path_to_save_dir']) 27 | path_to_weights = config['path_to_weights'] 28 | probs = config['probs'] 29 | num_workers = int(config['num_workers']) 30 | n_cls = int(config['n_cls']) 31 | in_channels = int(config['in_channels']) 32 | n_filters = int(config['n_filters']) 33 | reduction = int(config['reduction']) 34 | 35 | # test data paths: 36 | all_paths = utils.get_paths_to_patient_files(path_to_imgs=path_to_data, append_mask=False) 37 | 38 | # input transforms: 39 | input_transforms = transforms.Compose([ 40 | 
        transforms.NormalizeIntensity(),
41 |         transforms.ToTensor(mode='test')
42 |     ])
43 | 
44 |     # ensemble output transforms:
45 |     output_transforms = [
46 |         transforms.InverseToTensor(),
47 |         transforms.CheckOutputShape(shape=(144, 144, 144))
48 |     ]
49 |     if not probs:
50 |         output_transforms.append(transforms.ProbsToLabels())
51 | 
52 |     output_transforms = transforms.Compose(output_transforms)
53 | 
54 |     # dataset and dataloader:
55 |     data_set = dataset.HecktorDataset(all_paths, transforms=input_transforms, mode='test')
56 |     data_loader = DataLoader(data_set, batch_size=1, shuffle=False, num_workers=num_workers)
57 | 
58 |     # model:
59 |     model = models.FastSmoothSENormDeepUNet_supervision_skip_no_drop(in_channels, n_cls, n_filters, reduction)
60 | 
61 |     # init predictor:
62 |     predictor_ = predictor.Predictor(
63 |         model=model,
64 |         path_to_model_weights=path_to_weights,
65 |         dataloader=data_loader,
66 |         output_transforms=output_transforms,
67 |         path_to_save_dir=path_to_save_dir
68 |     )
69 | 
70 |     # check if multiple paths were provided to run an ensemble:
71 |     if isinstance(path_to_weights, list):
72 |         predictor_.ensemble_predict()
73 | 
74 |     elif isinstance(path_to_weights, str):
75 |         predictor_.predict()
76 | 
77 |     else:
78 |         raise ValueError(f"Argument 'path_to_weights' must be str or list of str, provided {type(path_to_weights)}")
79 | 
80 | 
81 | if __name__ == "__main__":
82 |     parser = argparse.ArgumentParser(description='Model Inference Script')
83 |     parser.add_argument("-p", "--path", type=str, required=True, help="path to the config file")
84 |     args = parser.parse_args()
85 |     main(args)
86 | 
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import nibabel as nib
3 | from torch.utils.data import Dataset
4 | 
5 | 
6 | class HecktorDataset(Dataset):
7 |     """A class for fetching data samples.
8 | 
9 |     Parameters
10 |     ----------
11 |     paths_to_samples : list
12 |         A list wherein each element is a tuple with two (three) `pathlib.Path` objects for a single patient.
13 |         The first one is the path to the CT image, the second one - to the PET image. If `mode == 'train'`, a path to
14 |         a ground truth mask must be provided for each patient.
15 |     transforms
16 |         Transformations applied to each data sample.
17 |     mode : str
18 |         Must be `train` or `test`. If `train`, a ground truth mask is loaded using a path from `paths_to_samples` and
19 |         added to a sample.
20 |         If `test`, additional information (an affine array) that describes the position of the image data
21 |         in a reference space is added to each data sample. Ground truth masks are not loaded in this mode.
22 | 
23 |     Returns
24 |     -------
25 |     dict
26 |         A dictionary corresponding to a data sample.
27 |         Keys:
28 |             id : A patient's ID.
29 |             input : A numpy array containing CT & PET images stacked along the last (4th) dimension.
30 |             target : A numpy array containing a ground truth mask.
31 |             affine : A numpy array with the position of the image data in a reference space (needed for resampling).
32 |     """
33 | 
34 |     def __init__(self, paths_to_samples, transforms=None, mode='train'):
35 |         self.paths_to_samples = paths_to_samples
36 |         self.transforms = transforms
37 |         if mode not in ['train', 'test']:
38 |             raise ValueError(f"Argument 'mode' must be 'train' or 'test'. Received {mode}")
39 |         self.mode = mode
40 |         if mode == 'train':
41 |             self.num_of_seqs = len(paths_to_samples[0]) - 1
42 |         else:
43 |             self.num_of_seqs = len(paths_to_samples[0])
44 | 
45 |     def __len__(self):
46 |         return len(self.paths_to_samples)
47 | 
48 |     def __getitem__(self, index):
49 |         sample = dict()
50 | 
51 |         id_ = self.paths_to_samples[index][0].parent.stem
52 |         sample['id'] = id_
53 | 
54 |         img = [self.read_data(self.paths_to_samples[index][i]) for i in range(self.num_of_seqs)]
55 |         img = np.stack(img, axis=-1)
56 |         sample['input'] = img
57 | 
58 |         if self.mode == 'train':
59 |             mask = self.read_data(self.paths_to_samples[index][-1])
60 |             mask = np.expand_dims(mask, axis=3)
61 | 
62 |             assert img.shape[:-1] == mask.shape[:-1], \
63 |                 f"Shape mismatch for the image with the shape {img.shape} and the mask with the shape {mask.shape}."
64 | 
65 |             sample['target'] = mask
66 | 
67 |         else:
68 |             sample['affine'] = self.read_data(self.paths_to_samples[index][0], False).affine
69 |         if self.transforms:
70 |             sample = self.transforms(sample)
71 | 
72 |         return sample
73 | 
74 |     @staticmethod
75 |     def read_data(path_to_nifti, return_numpy=True):
76 |         """Read a NIfTI image. Return a numpy array (default) or `nibabel.nifti1.Nifti1Image` object"""
77 |         if return_numpy:
78 |             return nib.load(str(path_to_nifti)).get_fdata()
79 |         return nib.load(str(path_to_nifti))
80 | 
--------------------------------------------------------------------------------
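A hedged usage sketch for `HecktorDataset` (the directory layout and file naming follow `src/data/utils.get_paths_to_patient_files`; the root path is hypothetical):

import pathlib
root = pathlib.Path('/data/hecktor_nii_resampled')
paths = [(root / 'CHGJ007' / 'CHGJ007_ct.nii.gz',
          root / 'CHGJ007' / 'CHGJ007_pt.nii.gz',
          root / 'CHGJ007' / 'CHGJ007_ct_gtvt.nii.gz')]
train_set = HecktorDataset(paths, transforms=None, mode='train')
sample = train_set[0]
# sample['id'] -> 'CHGJ007'
# sample['input'].shape -> (144, 144, 144, 2): CT & PET stacked along the last axis
# sample['target'].shape -> (144, 144, 144, 1)
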
/src/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | 
5 | 
6 | class BasicConv3d(nn.Module):
7 |     def __init__(self, in_channels, out_channels, **kwargs):
8 |         super(BasicConv3d, self).__init__()
9 |         self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs)
10 |         self.norm = nn.InstanceNorm3d(out_channels, affine=True)
11 | 
12 |     def forward(self, x):
13 |         x = self.conv(x)
14 |         x = self.norm(x)
15 |         x = F.relu(x, inplace=True)
16 |         return x
17 | 
18 | 
19 | class FastSmoothSENorm(nn.Module):
20 |     class SEWeights(nn.Module):
21 |         def __init__(self, in_channels, reduction=2):
22 |             super().__init__()
23 |             self.conv1 = nn.Conv3d(in_channels, in_channels // reduction, kernel_size=1, stride=1, padding=0, bias=True)
24 |             self.conv2 = nn.Conv3d(in_channels // reduction, in_channels, kernel_size=1, stride=1, padding=0, bias=True)
25 | 
26 |         def forward(self, x):
27 |             b, c, d, h, w = x.size()
28 |             out = torch.mean(x.view(b, c, -1), dim=-1).view(b, c, 1, 1, 1)  # output_shape: in_channels x (1, 1, 1)
29 |             out = F.relu(self.conv1(out))
30 |             out = self.conv2(out)
31 |             return out
32 | 
33 |     def __init__(self, in_channels, reduction=2):
34 |         super(FastSmoothSENorm, self).__init__()
35 |         self.norm = nn.InstanceNorm3d(in_channels, affine=False)
36 |         self.gamma = self.SEWeights(in_channels, reduction)
37 |         self.beta = self.SEWeights(in_channels, reduction)
38 | 
39 |     def forward(self, x):
40 |         gamma = torch.sigmoid(self.gamma(x))
41 |         beta = torch.tanh(self.beta(x))
42 |         x = self.norm(x)
43 |         return gamma * x + beta
44 | 
45 | 
46 | class FastSmoothSeNormConv3d(nn.Module):
47 |     def __init__(self, in_channels, out_channels, reduction=2, **kwargs):
48 |         super(FastSmoothSeNormConv3d, self).__init__()
49 |         self.conv = nn.Conv3d(in_channels, out_channels, bias=True, **kwargs)
50 |         self.norm = FastSmoothSENorm(out_channels, reduction)
51 | 
52 |     def forward(self, x):
53 |         x = self.conv(x)
54 |         x = F.relu(x, inplace=True)
55 |         x = self.norm(x)
56 |         return x
57 | 
58 | 
59 | class RESseNormConv3d(nn.Module):
60 |     def __init__(self, in_channels, out_channels, reduction=2, **kwargs):
61 |         super().__init__()
62 |         self.conv1 = FastSmoothSeNormConv3d(in_channels, out_channels, reduction, **kwargs)
63 | 
64 |         if in_channels != out_channels:
65 |             self.res_conv = FastSmoothSeNormConv3d(in_channels, out_channels, reduction, kernel_size=1, stride=1, padding=0)
66 |         else:
67 |             self.res_conv = None
68 | 
69 |     def forward(self, x):
70 |         residual = self.res_conv(x) if self.res_conv else x
71 |         x = self.conv1(x)
72 |         x += residual
73 |         return x
74 | 
75 | 
76 | class UpConv(nn.Module):
77 |     def __init__(self, in_channels, out_channels, reduction=2, scale=2):
78 |         super().__init__()
79 |         self.scale = scale
80 |         self.conv = FastSmoothSeNormConv3d(in_channels, out_channels, reduction, kernel_size=1, stride=1, padding=0)
81 | 
82 |     def forward(self, x):
83 |         x = self.conv(x)
84 |         x = F.interpolate(x, scale_factor=self.scale, mode='trilinear', align_corners=False)
85 |         return x
86 | 
--------------------------------------------------------------------------------
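A shape sanity check for the SE-Norm blocks (an editor's sketch): the 3x3x3 blocks preserve spatial size, and `UpConv` doubles it.

import torch
x = torch.randn(1, 24, 18, 18, 18)  # (batch, channels, D, H, W)
block = RESseNormConv3d(24, 48, reduction=2, kernel_size=3, stride=1, padding=1)
up = UpConv(48, 24, reduction=2, scale=2)
y = block(x)  # -> (1, 48, 18, 18, 18); the 1x1x1 res_conv matches channels for the residual sum
z = up(y)     # -> (1, 24, 36, 36, 36); 1x1x1 conv, then trilinear upsampling by `scale`
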
/README.md:
--------------------------------------------------------------------------------
1 | # 1st Place Solution for the [HECKTOR](https://www.aicrowd.com/challenges/miccai-2020-hecktor) challenge
2 | 
3 | > The official implementation of the winning solution for the MICCAI 2020 HEad and neCK TumOR segmentation challenge (HECKTOR).
4 | 
5 | ### Main requirements
6 | - PyTorch 1.6.0 (cuda 10.2)
7 | - SimpleITK 1.2.4 (ITK 4.13)
8 | - nibabel 3.1.1
9 | - skimage 0.17.2
10 | 
11 | ### Dataset
12 | Train and test images are available through the competition [website](https://www.aicrowd.com/challenges/miccai-2020-hecktor). A concise description of the dataset is given in `notebooks/make_dataset.ipynb`.
13 | 
14 | ### Data preprocessing
15 | The data preprocessing consists of:
16 | - Resampling the pair of PET & CT images for each patient to a common reference space.
17 | - Extracting a region of interest (bounding box) of 144x144x144 voxels.
18 | - Saving the transformed images in NIfTI format.
19 | 
20 | To prepare the dataset in _an interactive manner_, one can use `notebooks/make_dataset.ipynb`, which explains each step.
21 | Alternatively, _the fully automated data preprocessing_ can be performed by running `src/data/make_dataset.py`. All required parameters must be provided as _a single config file_ in the YAML data format:
22 | ```sh
23 | cd hecktor/src/data/
24 | python make_dataset.py -p hecktor/config/make_dataset.yaml
25 | ```
26 | Use `/config/make_dataset.yaml` to specify all required parameters.
27 | 
28 | ### Training
29 | For training the model from scratch, one can use `notebooks/model_train.ipynb`, setting all parameters right in the notebook. Otherwise, with all parameters written in the config file, one needs to run `hecktor/model/train.py` from its current directory:
30 | ```sh
31 | cd hecktor/model/
32 | python train.py -p hecktor/config/model_train.yaml
33 | ```
34 | All parameters are described in `hecktor/config/model_train.yaml`, which should be used as a template to build your own config file.
35 | 
36 | ### Inference
37 | For inference, run the script `hecktor/model/predict.py` with parameters defined in the config file `hecktor/config/model_predict.yaml`:
38 | ```sh
39 | cd hecktor/model/
40 | python predict.py -p hecktor/config/model_predict.yaml
41 | ```
42 | 
43 | ### Model weights
44 | To reproduce the results presented in the paper on different train / validation folds, one must download and save the pretrained weights in the folder `hecktor/model/weights/`. Weights of a single model are stored in a file named `{split}_best_model_weights.pt`. IDs of the patients in the train / validation folds for each data split are stored in `train_val_{split}.pkl` files located in the folder `hecktor/src/data/splits/`.
45 | 
46 | In order to download the weights of all pretrained models (eight models in total) built on the different train / validation splits, use the following [link](https://www.dropbox.com/sh/kkvqwn0bpnt1ynk/AABGNdpzTSiIKjiGV5K2ta0Na?dl=0).
47 | 
48 | 
49 | ### Example
50 | ![](notebooks/images/example.png)
51 | 
52 | 
53 | ### Paper
54 | If you use this code in your research, please cite the following paper ([arXiv](https://arxiv.org/abs/2102.10446)):
55 | > Iantsen A., Visvikis D., Hatt M. (2021) Squeeze-and-Excitation Normalization for Automated Delineation of Head and Neck Primary Tumors in Combined PET and CT Images. In: Andrearczyk V., Oreiller V., Depeursinge A. (eds) Head and Neck Tumor Segmentation. HECKTOR 2020. Lecture Notes in Computer Science, vol 12603. Springer, Cham. https://doi.org/10.1007/978-3-030-67194-5_4
56 | 
--------------------------------------------------------------------------------
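A config can also be built programmatically from the template the README mentions (an editor's sketch; the output path and the overridden data paths are placeholders, while `n_filters: 24` and `n_epochs: 300` are the values the config comments cite from the paper):

import yaml
with open('hecktor/config/model_train.yaml') as f:
    config = yaml.safe_load(f)
config['path_to_data'] = '/your/hecktor_nii_resampled/'
config['path_to_pkl'] = 'hecktor/src/data/splits/train_val_split_1.pkl'
config['n_filters'] = 24
config['n_epochs'] = 300
with open('my_model_train.yaml', 'w') as f:
    yaml.safe_dump(config, f)
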
/model/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | import yaml
4 | import pathlib
5 | 
6 | import torch
7 | from torch.utils.data import DataLoader
8 | torch.backends.cudnn.benchmark = True
9 | 
10 | sys.path.append('../src/')
11 | sys.path.append('../src/data/')
12 | import dataset
13 | import transforms
14 | import losses
15 | import metrics
16 | import trainer
17 | import models
18 | 
19 | import utils
20 | 
21 | 
22 | def main(args):
23 |     path_to_config = pathlib.Path(args.path)
24 |     with open(path_to_config) as f:
25 |         config = yaml.safe_load(f)
26 | 
27 |     # read config:
28 |     path_to_data = pathlib.Path(config['path_to_data'])
29 |     path_to_pkl = pathlib.Path(config['path_to_pkl'])
30 |     path_to_save_dir = pathlib.Path(config['path_to_save_dir'])
31 | 
32 |     train_batch_size = int(config['train_batch_size'])
33 |     val_batch_size = int(config['val_batch_size'])
34 |     num_workers = int(config['num_workers'])
35 |     lr = float(config['lr'])
36 |     n_epochs = int(config['n_epochs'])
37 |     n_cls = int(config['n_cls'])
38 |     in_channels = int(config['in_channels'])
39 |     n_filters = int(config['n_filters'])
40 |     reduction = int(config['reduction'])
41 |     T_0 = int(config['T_0'])
42 |     eta_min = float(config['eta_min'])
43 |     baseline = config['baseline']
44 | 
45 |     # train and val data paths:
46 |     all_paths = utils.get_paths_to_patient_files(path_to_imgs=path_to_data, append_mask=True)
47 |     train_paths, val_paths = utils.get_train_val_paths(all_paths=all_paths, path_to_train_val_pkl=path_to_pkl)
48 |     # train_paths = train_paths[:2]  # debugging leftover: restricts training to two patients if enabled
49 |     # val_paths = val_paths[:2]
50 | 
51 |     # train and val data transforms:
52 |     train_transforms = transforms.Compose([
53 |         transforms.RandomRotation(p=0.5, angle_range=[0, 45]),
54 |         transforms.Mirroring(p=0.5),
55 |         transforms.NormalizeIntensity(),
56 |         transforms.ToTensor()
57 |     ])
58 | 
59 |     val_transforms = transforms.Compose([
60 |         transforms.NormalizeIntensity(),
61 |         transforms.ToTensor()
62 |     ])
63 | 
64 |     # datasets:
65 |     train_set = dataset.HecktorDataset(train_paths, transforms=train_transforms)
66 |     val_set = dataset.HecktorDataset(val_paths, transforms=val_transforms)
67 | 
68 |     # dataloaders:
69 |     train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True, num_workers=num_workers)
70 |     val_loader = DataLoader(val_set, batch_size=val_batch_size, shuffle=False, num_workers=num_workers)
71 | 
72 |     dataloaders = {
73 |         'train': train_loader,
74 |         'val': val_loader
75 |     }
76 | 
77 |     if baseline:
78 |         model = models.BaselineUNet(in_channels, n_cls, n_filters)
79 |     else:
80 |         model = models.FastSmoothSENormDeepUNet_supervision_skip_no_drop(in_channels, n_cls, n_filters, reduction)
81 | 
82 |     criterion = losses.Dice_and_FocalLoss()
83 |     optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.99))
84 |     metric = metrics.dice
85 |     scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=T_0, eta_min=eta_min)
86 | 
87 |     trainer_ = trainer.ModelTrainer(
88 |         model=model,
89 |         dataloaders=dataloaders,
90 |         criterion=criterion,
91 |         optimizer=optimizer,
92 |         metric=metric,
93 |         scheduler=scheduler,
94 |         num_epochs=n_epochs,
95 |         parallel=True
96 |     )
97 | 
98 |     trainer_.train_model()
99 |     trainer_.save_results(path_to_dir=path_to_save_dir)
100 | 
101 | 
102 | if __name__ == "__main__":
103 |     parser = argparse.ArgumentParser(description='Model Training Script')
104 |     parser.add_argument("-p", "--path", type=str, required=True, help="path to the config file")
105 |     args = parser.parse_args()
106 |     main(args)
107 | 
--------------------------------------------------------------------------------
/src/data/make_dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import yaml
3 | import os
4 | import pathlib
5 | 
6 | import pandas as pd
7 | import SimpleITK as sitk
8 | from tqdm import tqdm
9 | 
10 | from utils import read_nifti, write_nifti, get_attributes, resample_sitk_image
11 | 
12 | 
13 | def main(args):
14 |     path_to_config = pathlib.Path(args.path)
15 |     with open(path_to_config) as f:
16 |         config = yaml.safe_load(f)
17 | 
18 |     path_to_input = pathlib.Path(config['path_to_input'])
19 |     path_to_bb = pathlib.Path(config['path_to_bb'])
20 |     path_to_output = pathlib.Path(config['path_to_output'])
21 |     is_mask_available = config['is_mask_available']
22 |     verbose = config['verbose']
23 | 
24 |     bb = pd.read_csv(path_to_bb)
25 |     patients = list(bb.PatientID)
26 |     print(f"Total number of patients: {len(patients)}")
27 | 
28 |     print(f"Resampled images will be saved in {path_to_output}")
29 |     if not os.path.exists(path_to_output):
30 |         os.makedirs(path_to_output, exist_ok=True)
31 | 
32 |     for p in tqdm(patients) if verbose else patients:
33 |         # Read images:
34 |         img_ct = read_nifti(path_to_input / p / (p + '_ct.nii.gz'))
35 |         img_pt = read_nifti(path_to_input / p / (p + '_pt.nii.gz'))
36 |         if is_mask_available:
37 |             mask = read_nifti(path_to_input / p / (p + '_ct_gtvt.nii.gz'))
38 | 
39 |         # Get bounding boxes:
40 |         pt1 = bb.loc[bb.PatientID == p, ['x1', 'y1', 'z1']]
41 |         pt2 = bb.loc[bb.PatientID == p, ['x2', 'y2', 'z2']]
42 |         pt1, pt2 = tuple(*pt1.values), tuple(*pt2.values)
43 | 
44 |         # Convert physical points into array indices:
45 |         pt1_ct = img_ct.TransformPhysicalPointToIndex(pt1)
46 |         pt1_pt = img_pt.TransformPhysicalPointToIndex(pt1)
47 |         if is_mask_available:
48 |             pt1_mask = mask.TransformPhysicalPointToIndex(pt1)
49 | 
50 |         pt2_ct = img_ct.TransformPhysicalPointToIndex(pt2)
51 |         pt2_pt = img_pt.TransformPhysicalPointToIndex(pt2)
52 |         if is_mask_available:
53 |             pt2_mask = mask.TransformPhysicalPointToIndex(pt2)
54 | 
55 |         # Extract the patch:
56 |         cr_img_ct = img_ct[pt1_ct[0]: pt2_ct[0], pt1_ct[1]: pt2_ct[1], pt1_ct[2]: pt2_ct[2]]
57 |         cr_img_pt = img_pt[pt1_pt[0]: pt2_pt[0], pt1_pt[1]: pt2_pt[1], pt1_pt[2]: pt2_pt[2]]
58 |         if is_mask_available:
59 |             cr_mask = mask[pt1_mask[0]: pt2_mask[0], pt1_mask[1]: pt2_mask[1], pt1_mask[2]: pt2_mask[2]]
60 | 
61 |         # Resample all images using CT attributes:
62 |         # CT:
63 |         cr_img_ct = resample_sitk_image(
64 |             cr_img_ct,
65 |             new_spacing=[1, 1, 1],
66 |             new_size=[144, 144, 144],
67 |             interpolator=sitk.sitkLinear)
68 |         target_size = list(cr_img_ct.GetSize())
69 |         attributes = get_attributes(cr_img_ct)
70 | 
71 |         # PT:
72 |         cr_img_pt = resample_sitk_image(
73 |             cr_img_pt,
74 |             new_spacing=[1, 1, 1],
75 |             new_size=target_size,
76 |             attributes=attributes,
77 |             interpolator=sitk.sitkLinear
78 |         )
79 | 
80 |         # Mask:
81 |         if is_mask_available:
82 |             cr_mask = resample_sitk_image(
83 |                 cr_mask,
84 |                 new_spacing=[1, 1, 1],
85 |                 new_size=target_size,
86 |                 attributes=attributes,
87 |                 interpolator=sitk.sitkNearestNeighbor
88 |             )
89 | 
90 |         # Save resampled images:
91 |         if not os.path.exists(path_to_output / p):
92 |             os.makedirs(path_to_output / p, exist_ok=True)
93 | 
94 |         write_nifti(cr_img_ct, path_to_output / p / (p + '_ct.nii.gz'))
95 |         write_nifti(cr_img_pt, path_to_output / p / (p + '_pt.nii.gz'))
96 |         if is_mask_available:
97 |             write_nifti(cr_mask, path_to_output / p / (p + '_ct_gtvt.nii.gz'))
98 | 
99 | 
100 | if __name__ == "__main__":
101 |     parser = argparse.ArgumentParser(description='Data Preprocessing Script')
102 |     parser.add_argument("-p", "--path", type=str, required=True, help="path to the config file")
103 |     args = parser.parse_args()
104 |     main(args)
105 | 
--------------------------------------------------------------------------------
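The bounding-box CSV read above is expected to carry one row per patient with two physical-space corner points; the column names are taken from the code, while the coordinate values below are illustrative only:

PatientID,x1,y1,z1,x2,y2,z2
CHGJ007,-80.0,-120.0,-60.0,64.0,24.0,84.0
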
/src/predictor.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import os
3 | import nibabel as nib
4 | import torch
5 | 
6 | 
7 | class Predictor:
8 |     """
9 |     A class for building model predictions.
10 | 
11 |     Parameters
12 |     ----------
13 |     model : a subclass of `torch.nn.Module`
14 |         A model used for prediction.
15 |     path_to_model_weights : list of (`pathlib.Path` or str) or (`pathlib.Path` or str)
16 |         A path to model weights. Provide a single path and use `self.predict` to build predictions using a single model.
17 |         Use a list of paths and `self.ensemble_predict` to get predictions for an ensemble (the same architecture but
18 |         different weights).
19 |     dataloader : `torch.utils.data.DataLoader`
20 |         A dataloader fetching test samples.
21 |     output_transforms
22 |         Transforms applied to outputs.
23 |     path_to_save_dir : `pathlib.Path` or str
24 |         A path to a directory to save predictions.
25 |     """
26 | 
27 |     def __init__(self,
28 |                  model,
29 |                  path_to_model_weights,  # list of paths or a single path
30 |                  dataloader,
31 |                  output_transforms=None,
32 |                  path_to_save_dir='.',
33 |                  device="cuda:0"):
34 | 
35 |         self.model = model
36 |         self.path_to_model_weights = [pathlib.Path(p) for p in path_to_model_weights] \
37 |             if isinstance(path_to_model_weights, list) else pathlib.Path(path_to_model_weights)
38 | 
39 |         self.dataloader = dataloader
40 |         self.output_transforms = output_transforms
41 |         self.path_to_save_dir = pathlib.Path(path_to_save_dir)
42 |         self.device = torch.device(device if torch.cuda.is_available() else "cpu")
43 | 
44 |     def predict(self):
45 |         """Run inference for a single model"""
46 | 
47 |         if self.device.type == 'cpu':
48 |             print(f'Run inference for a model on CPU')
49 |         else:
50 |             print(f'Run inference for a model'
51 |                   f' on {torch.cuda.get_device_name(torch.cuda.current_device())}')
52 | 
53 |         # Check if the directory exists:
54 |         if not os.path.exists(self.path_to_save_dir):
55 |             os.makedirs(self.path_to_save_dir, exist_ok=True)
56 | 
57 |         # Send model to device:
58 |         self.model = self.model.to(self.device)
59 |         self.model.eval()
60 | 
61 |         # Load model weights:
62 |         self.model = self._load_model_weights(self.model, self.path_to_model_weights)
63 | 
64 |         # Inference:
65 |         with torch.no_grad():
66 |             for sample in self.dataloader:
67 |                 input = sample['input']
68 |                 input = input.to(self.device)
69 | 
70 |                 output = self.model(input)
71 |                 output = output.cpu()
72 | 
73 |                 sample['output'] = output
74 | 
75 |                 # apply output transforms, if any:
76 |                 if self.output_transforms:
77 |                     sample = self.output_transforms(sample)
78 | 
79 |                 # Save prediction:
80 |                 self._save_preds(sample, self.path_to_save_dir)
81 | 
82 |         print(f'Predictions have been saved in {self.path_to_save_dir}')
83 | 
84 |     def ensemble_predict(self):
85 |         """Run inference for an ensemble of models"""
86 | 
87 |         if self.device.type == 'cpu':
88 |             print(f'Run inference for an ensemble of {len(self.path_to_model_weights)} models on CPU')
89 |         else:
90 |             print(f'Run inference for an ensemble of {len(self.path_to_model_weights)} models'
91 |                   f' on {torch.cuda.get_device_name(torch.cuda.current_device())}')
92 | 
93 |         # Check if the directory exists:
94 |         if not os.path.exists(self.path_to_save_dir):
95 |             os.makedirs(self.path_to_save_dir, exist_ok=True)
96 | 
97 |         # Send model to device:
98 |         self.model = self.model.to(self.device)
99 |         self.model.eval()
100 | 
101 |         # Inference:
102 |         with torch.no_grad():
103 |             for sample in self.dataloader:
104 |                 input = sample['input']
105 |                 input = input.to(self.device)
106 | 
107 |                 ensemble_output = 0
108 |                 for path in self.path_to_model_weights:
109 |                     self.model = self._load_model_weights(self.model, path)
110 |                     output = self.model(input)
111 |                     output = output.cpu()
112 | 
113 |                     ensemble_output += output
114 | 
115 |                 ensemble_output /= len(self.path_to_model_weights)
116 |                 sample['output'] = ensemble_output
117 | 
118 |                 # apply (ensemble) output transforms, if any:
119 |                 if self.output_transforms:
120 |                     sample = self.output_transforms(sample)
121 | 
122 |                 # Save prediction:
123 |                 self._save_preds(sample, self.path_to_save_dir)
124 | 
125 |         print(f'Predictions have been saved in {self.path_to_save_dir}')
126 | 
127 |     @staticmethod
128 |     def _save_preds(sample, path_to_dir):
129 |         preds = sample['output']
130 |         sample_id = sample['id'][0]
131 |         affine = sample['affine'][0].numpy()
132 |         preds = nib.Nifti1Image(preds, affine=affine)
133 |         nib.save(preds, str(path_to_dir / (sample_id + '.nii.gz')))
134 | 
135 |     @staticmethod
136 |     def _load_model_weights(model, path_to_model_weights):
137 |         model_state_dict = torch.load(path_to_model_weights, map_location=lambda storage, loc: storage)
138 |         try:
139 |             model.load_state_dict(model_state_dict, strict=True)
140 |         except RuntimeError:
141 |             # if the model was trained in parallel, strip the 'module.' prefix from parameter names
142 |             from collections import OrderedDict
143 |             new_model_state_dict = OrderedDict()
144 |             for k, v in model_state_dict.items():
145 |                 k = k.replace('module.', '')
146 |                 new_model_state_dict[k] = v
147 |             model.load_state_dict(new_model_state_dict, strict=True)
148 |         return model
149 | 
--------------------------------------------------------------------------------
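A minimal single-model usage sketch for `Predictor` (`model/predict.py` shows the full pipeline; `model` and `data_loader` are assumed to be built as there, and the paths are placeholders):

predictor_ = Predictor(
    model=model,
    path_to_model_weights='weights/s0_best_model_weights.pt',  # a single path -> predict()
    dataloader=data_loader,
    output_transforms=None,
    path_to_save_dir='preds/'
)
predictor_.predict()  # with a list of weight paths, call ensemble_predict() instead
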
/src/data/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | import json
4 | import numpy as np
5 | import SimpleITK as sitk
6 | import torch
7 | from torch.nn import functional as F
8 | 
9 | 
10 | def get_paths_to_patient_files(path_to_imgs, append_mask=True):
11 |     """
12 |     Get paths to all data samples, i.e., CT & PET images (and a mask) for each patient.
13 | 
14 |     Parameters
15 |     ----------
16 |     path_to_imgs : str
17 |         A path to a directory with patients' data. Each folder in the directory must correspond to a single patient.
18 |     append_mask : bool
19 |         Used to append a path to a ground truth mask.
20 | 
21 |     Returns
22 |     -------
23 |     list of tuple
24 |         A list wherein each element is a tuple with two (three) `pathlib.Path` objects for a single patient.
25 |         The first one is the path to the CT image, the second one - to the PET image. If `append_mask` is True,
26 |         the path to the ground truth mask is added.
27 |     """
28 |     path_to_imgs = pathlib.Path(path_to_imgs)
29 | 
30 |     patients = [p for p in os.listdir(path_to_imgs) if os.path.isdir(path_to_imgs / p)]
31 |     paths = []
32 |     for p in patients:
33 |         path_to_ct = path_to_imgs / p / (p + '_ct.nii.gz')
34 |         path_to_pt = path_to_imgs / p / (p + '_pt.nii.gz')
35 | 
36 |         if append_mask:
37 |             path_to_mask = path_to_imgs / p / (p + '_ct_gtvt.nii.gz')
38 |             paths.append((path_to_ct, path_to_pt, path_to_mask))
39 |         else:
40 |             paths.append((path_to_ct, path_to_pt))
41 |     return paths
42 | 
43 | 
44 | def get_train_val_paths(all_paths, path_to_train_val_pkl):
45 |     """
46 |     Split a list of all paths to patients' data into train & validation parts using patients' IDs.
47 | 
48 |     Parameters
49 |     ----------
50 |     all_paths : list
51 |         An output of `get_paths_to_patient_files`.
52 |     path_to_train_val_pkl : str
53 |         A path to a pkl file storing train & validation IDs.
54 | 
55 |     Returns
56 |     -------
57 |     (list, list)
58 |         Two lists of paths to train & validation data samples.
59 |     """
60 |     path_to_train_val_pkl = pathlib.Path(path_to_train_val_pkl)
61 |     with open(path_to_train_val_pkl) as f:
62 |         train_val_split = json.load(f)
63 | 
64 |     train_paths = [path for path in all_paths
65 |                    if any(patient_id + '_ct.nii.gz' in str(path[0]) for patient_id in train_val_split['train'])]
66 | 
67 |     val_paths = [path for path in all_paths
68 |                  if any(patient_id + '_ct.nii.gz' in str(path[0]) for patient_id in train_val_split['val'])]
69 | 
70 |     return train_paths, val_paths
71 | 
72 | 
73 | def read_nifti(path):
74 |     """Read a NIfTI image. Return a SimpleITK Image."""
75 |     nifti = sitk.ReadImage(str(path))
76 |     return nifti
77 | 
78 | 
79 | def write_nifti(sitk_img, path):
80 |     """Save a SimpleITK Image to disk in NIfTI format."""
81 |     writer = sitk.ImageFileWriter()
82 |     writer.SetImageIO("NiftiImageIO")
83 |     writer.SetFileName(str(path))
84 |     writer.Execute(sitk_img)
85 | 
86 | 
87 | def get_attributes(sitk_image):
88 |     """Get physical space attributes (meta-data) of the image."""
89 |     attributes = {}
90 |     attributes['orig_pixelid'] = sitk_image.GetPixelIDValue()
91 |     attributes['orig_origin'] = sitk_image.GetOrigin()
92 |     attributes['orig_direction'] = sitk_image.GetDirection()
93 |     attributes['orig_spacing'] = np.array(sitk_image.GetSpacing())
94 |     attributes['orig_size'] = np.array(sitk_image.GetSize(), dtype=int)
95 |     return attributes
96 | 
97 | 
98 | def resample_sitk_image(sitk_image,
99 |                         new_spacing=[1, 1, 1],
100 |                         new_size=None,
101 |                         attributes=None,
102 |                         interpolator=sitk.sitkLinear,
103 |                         fill_value=0):
104 |     """
105 |     Resample a SimpleITK Image.
106 | 
107 |     Parameters
108 |     ----------
109 |     sitk_image : sitk.Image
110 |         An input image.
111 |     new_spacing : list of float
112 |         A distance between adjacent voxels in each dimension given in physical units (mm) for the output image.
113 |     new_size : list of int or None
114 |         A number of pixels per dimension of the output image. If None, `new_size` is computed based on the original
115 |         input size, original spacing and new spacing.
116 |     attributes : dict or None
117 |         The desired output image's spatial domain (its meta-data). If None, the original image's meta-data is used.
118 |     interpolator
119 |         Available interpolators:
120 |             - sitk.sitkNearestNeighbor : nearest
121 |             - sitk.sitkLinear : linear
122 |             - sitk.sitkGaussian : gaussian
123 |             - sitk.sitkLabelGaussian : label_gaussian
124 |             - sitk.sitkBSpline : bspline
125 |             - sitk.sitkHammingWindowedSinc : hamming_sinc
126 |             - sitk.sitkCosineWindowedSinc : cosine_windowed_sinc
127 |             - sitk.sitkWelchWindowedSinc : welch_windowed_sinc
128 |             - sitk.sitkLanczosWindowedSinc : lanczos_windowed_sinc
129 |     fill_value : int or float
130 |         A value used for padding, if the output image size is less than `new_size`.
131 | 
132 |     Returns
133 |     -------
134 |     sitk.Image
135 |         The resampled image.
136 | 
137 |     Notes
138 |     -----
139 |     This implementation is based on https://github.com/deepmedic/SimpleITK-examples/blob/master/examples/resample_isotropically.py
140 |     """
141 |     sitk_interpolator = interpolator
142 | 
143 |     # provided attributes:
144 |     if attributes:
145 |         orig_pixelid = attributes['orig_pixelid']
146 |         orig_origin = attributes['orig_origin']
147 |         orig_direction = attributes['orig_direction']
148 |         orig_spacing = attributes['orig_spacing']
149 |         orig_size = attributes['orig_size']
150 | 
151 |     else:
152 |         # use original attributes:
153 |         orig_pixelid = sitk_image.GetPixelIDValue()
154 |         orig_origin = sitk_image.GetOrigin()
155 |         orig_direction = sitk_image.GetDirection()
156 |         orig_spacing = np.array(sitk_image.GetSpacing())
157 |         orig_size = np.array(sitk_image.GetSize(), dtype=int)
158 | 
159 |     # new image size:
160 |     if not new_size:
161 |         new_size = orig_size * (orig_spacing / new_spacing)
162 |         new_size = np.ceil(new_size).astype(int)  # image dimensions are integers
163 |         new_size = [int(s) for s in new_size]  # SimpleITK expects lists, not ndarrays
164 | 
165 |     resample_filter = sitk.ResampleImageFilter()
166 |     resampled_sitk_image = resample_filter.Execute(sitk_image,
167 |                                                    new_size,
168 |                                                    sitk.Transform(),
169 |                                                    sitk_interpolator,
170 |                                                    orig_origin,
171 |                                                    new_spacing,
172 |                                                    orig_direction,
173 |                                                    fill_value,
174 |                                                    orig_pixelid)
175 | 
176 |     return resampled_sitk_image
177 | 
--------------------------------------------------------------------------------
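In miniature, this is how `make_dataset.py` uses the helper (an editor's sketch; the input file is hypothetical):

import SimpleITK as sitk
img_ct = read_nifti('CHGJ007_ct.nii.gz')
iso = resample_sitk_image(img_ct,
                          new_spacing=[1, 1, 1],
                          new_size=[144, 144, 144],
                          interpolator=sitk.sitkLinear)
print(iso.GetSize(), iso.GetSpacing())  # (144, 144, 144) (1.0, 1.0, 1.0)
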
/src/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | 
5 | from layers import BasicConv3d, FastSmoothSeNormConv3d, RESseNormConv3d, UpConv
6 | 
7 | 
8 | class BaselineUNet(nn.Module):
9 |     def __init__(self, in_channels, n_cls, n_filters):
10 |         super(BaselineUNet, self).__init__()
11 |         self.in_channels = in_channels
12 |         self.n_cls = 1 if n_cls == 2 else n_cls
13 |         self.n_filters = n_filters
14 | 
15 |         self.block_1_1_left = BasicConv3d(in_channels, n_filters, kernel_size=3, stride=1, padding=1)
16 |         self.block_1_2_left = BasicConv3d(n_filters, n_filters, kernel_size=3, stride=1, padding=1)
17 | 
18 |         self.pool_1 = nn.MaxPool3d(kernel_size=2, stride=2)  # 64, 1/2
19 |         self.block_2_1_left = BasicConv3d(n_filters, 2 * n_filters, kernel_size=3, stride=1, padding=1)
20 |         self.block_2_2_left = BasicConv3d(2 * n_filters, 2 * n_filters, kernel_size=3, stride=1, padding=1)
21 | 
136 | 137 | Notes 138 | ----- 139 | This implementation is based on https://github.com/deepmedic/SimpleITK-examples/blob/master/examples/resample_isotropically.py 140 | """ 141 | sitk_interpolator = interpolator 142 | 143 | # provided attributes: 144 | if attributes: 145 | orig_pixelid = attributes['orig_pixelid'] 146 | orig_origin = attributes['orig_origin'] 147 | orig_direction = attributes['orig_direction'] 148 | orig_spacing = attributes['orig_spacing'] 149 | orig_size = attributes['orig_size'] 150 | 151 | else: 152 | # use original attributes: 153 | orig_pixelid = sitk_image.GetPixelIDValue() 154 | orig_origin = sitk_image.GetOrigin() 155 | orig_direction = sitk_image.GetDirection() 156 | orig_spacing = np.array(sitk_image.GetSpacing()) 157 | orig_size = np.array(sitk_image.GetSize(), dtype=int) 158 | 159 | # new image size: 160 | if not new_size: 161 | new_size = orig_size * (orig_spacing / new_spacing) 162 | new_size = np.ceil(new_size).astype(int) # Image dimensions are in integers 163 | new_size = [int(s) for s in new_size] # SimpleITK expects lists, not ndarrays 164 | 165 | resample_filter = sitk.ResampleImageFilter() 166 | resampled_sitk_image = resample_filter.Execute(sitk_image, 167 | new_size, 168 | sitk.Transform(), 169 | sitk_interpolator, 170 | orig_origin, 171 | new_spacing, 172 | orig_direction, 173 | fill_value, 174 | orig_pixelid) 175 | 176 | return resampled_sitk_image 177 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from layers import BasicConv3d, FastSmoothSeNormConv3d, RESseNormConv3d, UpConv 6 | 7 | 8 | class BaselineUNet(nn.Module): 9 | def __init__(self, in_channels, n_cls, n_filters): 10 | super(BaselineUNet, self).__init__() 11 | self.in_channels = in_channels 12 | self.n_cls = 1 if n_cls == 2 else n_cls 13 | self.n_filters = n_filters 14 | 15 | self.block_1_1_left = BasicConv3d(in_channels, n_filters, kernel_size=3, stride=1, padding=1) 16 | self.block_1_2_left = BasicConv3d(n_filters, n_filters, kernel_size=3, stride=1, padding=1) 17 | 18 | self.pool_1 = nn.MaxPool3d(kernel_size=2, stride=2) # 64, 1/2 19 | self.block_2_1_left = BasicConv3d(n_filters, 2 * n_filters, kernel_size=3, stride=1, padding=1) 20 | self.block_2_2_left = BasicConv3d(2 * n_filters, 2 * n_filters, kernel_size=3, stride=1, padding=1) 21 | 22 | self.pool_2 = nn.MaxPool3d(kernel_size=2, stride=2) # 128, 1/4 23 | self.block_3_1_left = BasicConv3d(2 * n_filters, 4 * n_filters, kernel_size=3, stride=1, padding=1) 24 | self.block_3_2_left = BasicConv3d(4 * n_filters, 4 * n_filters, kernel_size=3, stride=1, padding=1) 25 | 26 | self.pool_3 = nn.MaxPool3d(kernel_size=2, stride=2) # 256, 1/8 27 | self.block_4_1_left = BasicConv3d(4 * n_filters, 8 * n_filters, kernel_size=3, stride=1, padding=1) 28 | self.block_4_2_left = BasicConv3d(8 * n_filters, 8 * n_filters, kernel_size=3, stride=1, padding=1) 29 | 30 | self.upconv_3 = nn.ConvTranspose3d(8 * n_filters, 4 * n_filters, kernel_size=3, stride=2, padding=1, output_padding=1) 31 | self.block_3_1_right = BasicConv3d((4 + 4) * n_filters, 4 * n_filters, kernel_size=3, stride=1, padding=1) 32 | self.block_3_2_right = BasicConv3d(4 * n_filters, 4 * n_filters, kernel_size=3, stride=1, padding=1) 33 | 34 | self.upconv_2 = nn.ConvTranspose3d(4 * n_filters, 2 * n_filters, kernel_size=3, stride=2, padding=1, output_padding=1) 35 |
self.block_2_1_right = BasicConv3d((2 + 2) * n_filters, 2 * n_filters, kernel_size=3, stride=1, padding=1) 36 | self.block_2_2_right = BasicConv3d(2 * n_filters, 2 * n_filters, kernel_size=3, stride=1, padding=1) 37 | 38 | self.upconv_1 = nn.ConvTranspose3d(2 * n_filters, n_filters, kernel_size=3, stride=2, padding=1, output_padding=1) 39 | self.block_1_1_right = BasicConv3d((1 + 1) * n_filters, n_filters, kernel_size=3, stride=1, padding=1) 40 | self.block_1_2_right = BasicConv3d(n_filters, n_filters, kernel_size=3, stride=1, padding=1) 41 | 42 | self.conv1x1 = nn.Conv3d(n_filters, self.n_cls, kernel_size=1, stride=1, padding=0) 43 | 44 | def forward(self, x): 45 | 46 | ds0 = self.block_1_2_left(self.block_1_1_left(x)) 47 | ds1 = self.block_2_2_left(self.block_2_1_left(self.pool_1(ds0))) 48 | ds2 = self.block_3_2_left(self.block_3_1_left(self.pool_2(ds1))) 49 | x = self.block_4_2_left(self.block_4_1_left(self.pool_3(ds2))) 50 | 51 | x = self.block_3_2_right(self.block_3_1_right(torch.cat([self.upconv_3(x), ds2], 1))) 52 | x = self.block_2_2_right(self.block_2_1_right(torch.cat([self.upconv_2(x), ds1], 1))) 53 | x = self.block_1_2_right(self.block_1_1_right(torch.cat([self.upconv_1(x), ds0], 1))) 54 | 55 | x = self.conv1x1(x) 56 | 57 | if self.n_cls == 1: 58 | return torch.sigmoid(x) 59 | else: 60 | return F.softmax(x, dim=1) 61 | 62 | 63 | class FastSmoothSENormDeepUNet_supervision_skip_no_drop(nn.Module): 64 | """The model presented in the paper. It is one of several models we tried in our experiments, 65 | which is why it has such an awkward name.""" 66 | 67 | def __init__(self, in_channels, n_cls, n_filters, reduction=2, return_logits=False): 68 | super(FastSmoothSENormDeepUNet_supervision_skip_no_drop, self).__init__() 69 | self.in_channels = in_channels 70 | self.n_cls = 1 if n_cls == 2 else n_cls 71 | self.n_filters = n_filters 72 | self.return_logits = return_logits 73 | 74 | self.block_1_1_left = RESseNormConv3d(in_channels, n_filters, reduction, kernel_size=7, stride=1, padding=3) 75 | self.block_1_2_left = RESseNormConv3d(n_filters, n_filters, reduction, kernel_size=3, stride=1, padding=1) 76 | 77 | self.pool_1 = nn.MaxPool3d(kernel_size=2, stride=2) 78 | self.block_2_1_left = RESseNormConv3d(n_filters, 2 * n_filters, reduction, kernel_size=3, stride=1, padding=1) 79 | self.block_2_2_left = RESseNormConv3d(2 * n_filters, 2 * n_filters, reduction, kernel_size=3, stride=1, padding=1) 80 | self.block_2_3_left = RESseNormConv3d(2 * n_filters, 2 * n_filters, reduction, kernel_size=3, stride=1, padding=1) 81 | 82 | self.pool_2 = nn.MaxPool3d(kernel_size=2, stride=2) 83 | self.block_3_1_left = RESseNormConv3d(2 * n_filters, 4 * n_filters, reduction, kernel_size=3, stride=1, padding=1) 84 | self.block_3_2_left = RESseNormConv3d(4 * n_filters, 4 * n_filters, reduction, kernel_size=3, stride=1, padding=1) 85 | self.block_3_3_left = RESseNormConv3d(4 * n_filters, 4 * n_filters, reduction, kernel_size=3, stride=1, padding=1) 86 | 87 | self.pool_3 = nn.MaxPool3d(kernel_size=2, stride=2) 88 | self.block_4_1_left = RESseNormConv3d(4 * n_filters, 8 * n_filters, reduction, kernel_size=3, stride=1, padding=1) 89 | self.block_4_2_left = RESseNormConv3d(8 * n_filters, 8 * n_filters, reduction, kernel_size=3, stride=1, padding=1) 90 | self.block_4_3_left = RESseNormConv3d(8 * n_filters, 8 * n_filters, reduction, kernel_size=3, stride=1, padding=1) 91 | 92 | self.pool_4 = nn.MaxPool3d(kernel_size=2, stride=2) 93 | self.block_5_1_left = RESseNormConv3d(8 * n_filters, 16 * n_filters,
reduction, kernel_size=3, stride=1, padding=1) 94 | self.block_5_2_left = RESseNormConv3d(16 * n_filters, 16 * n_filters, reduction, kernel_size=3, stride=1, padding=1) 95 | self.block_5_3_left = RESseNormConv3d(16 * n_filters, 16 * n_filters, reduction, kernel_size=3, stride=1, padding=1) 96 | 97 | self.upconv_4 = nn.ConvTranspose3d(16 * n_filters, 8 * n_filters, kernel_size=3, stride=2, padding=1, output_padding=1) 98 | self.block_4_1_right = FastSmoothSeNormConv3d((8 + 8) * n_filters, 8 * n_filters, reduction, kernel_size=3, stride=1, padding=1) 99 | self.block_4_2_right = FastSmoothSeNormConv3d(8 * n_filters, 8 * n_filters, reduction, kernel_size=3, stride=1, padding=1) 100 | self.vision_4 = UpConv(8 * n_filters, n_filters, reduction, scale=8) 101 | 102 | self.upconv_3 = nn.ConvTranspose3d(8 * n_filters, 4 * n_filters, kernel_size=3, stride=2, padding=1, output_padding=1) 103 | self.block_3_1_right = FastSmoothSeNormConv3d((4 + 4) * n_filters, 4 * n_filters, reduction, kernel_size=3, stride=1, padding=1) 104 | self.block_3_2_right = FastSmoothSeNormConv3d(4 * n_filters, 4 * n_filters, reduction, kernel_size=3, stride=1, padding=1) 105 | self.vision_3 = UpConv(4 * n_filters, n_filters, reduction, scale=4) 106 | 107 | self.upconv_2 = nn.ConvTranspose3d(4 * n_filters, 2 * n_filters, kernel_size=3, stride=2, padding=1, output_padding=1) 108 | self.block_2_1_right = FastSmoothSeNormConv3d((2 + 2) * n_filters, 2 * n_filters, reduction, kernel_size=3, stride=1, padding=1) 109 | self.block_2_2_right = FastSmoothSeNormConv3d(2 * n_filters, 2 * n_filters, reduction, kernel_size=3, stride=1, padding=1) 110 | self.vision_2 = UpConv(2 * n_filters, n_filters, reduction, scale=2) 111 | 112 | self.upconv_1 = nn.ConvTranspose3d(2 * n_filters, 1 * n_filters, kernel_size=3, stride=2, padding=1, output_padding=1) 113 | self.block_1_1_right = FastSmoothSeNormConv3d((1 + 1) * n_filters, n_filters, reduction, kernel_size=3, stride=1, padding=1) 114 | self.block_1_2_right = FastSmoothSeNormConv3d(n_filters, n_filters, reduction, kernel_size=3, stride=1, padding=1) 115 | 116 | self.conv1x1 = nn.Conv3d(1 * n_filters, self.n_cls, kernel_size=1, stride=1, padding=0) 117 | 118 | def forward(self, x): 119 | 120 | ds0 = self.block_1_2_left(self.block_1_1_left(x)) 121 | ds1 = self.block_2_3_left(self.block_2_2_left(self.block_2_1_left(self.pool_1(ds0)))) 122 | ds2 = self.block_3_3_left(self.block_3_2_left(self.block_3_1_left(self.pool_2(ds1)))) 123 | ds3 = self.block_4_3_left(self.block_4_2_left(self.block_4_1_left(self.pool_3(ds2)))) 124 | x = self.block_5_3_left(self.block_5_2_left(self.block_5_1_left(self.pool_4(ds3)))) 125 | 126 | x = self.block_4_2_right(self.block_4_1_right(torch.cat([self.upconv_4(x), ds3], 1))) 127 | sv4 = self.vision_4(x) 128 | 129 | x = self.block_3_2_right(self.block_3_1_right(torch.cat([self.upconv_3(x), ds2], 1))) 130 | sv3 = self.vision_3(x) 131 | 132 | x = self.block_2_2_right(self.block_2_1_right(torch.cat([self.upconv_2(x), ds1], 1))) 133 | sv2 = self.vision_2(x) 134 | 135 | x = self.block_1_1_right(torch.cat([self.upconv_1(x), ds0], 1)) 136 | x = x + sv4 + sv3 + sv2 137 | x = self.block_1_2_right(x) 138 | 139 | x = self.conv1x1(x) 140 | 141 | if self.return_logits: 142 | return x 143 | else: 144 | if self.n_cls == 1: 145 | return torch.sigmoid(x) 146 | else: 147 | return F.softmax(x, dim=1) 148 | -------------------------------------------------------------------------------- /src/transforms.py: -------------------------------------------------------------------------------- 1 | 
import random 2 | import numpy as np 3 | import torch 4 | 5 | from skimage.transform import rotate 6 | 7 | 8 | class Compose: 9 | def __init__(self, transforms=None): 10 | self.transforms = transforms 11 | 12 | def __call__(self, sample): 13 | for transform in self.transforms: 14 | sample = transform(sample) 15 | 16 | return sample 17 | 18 | 19 | class ToTensor: 20 | def __init__(self, mode='train'): 21 | if mode not in ['train', 'test']: 22 | raise ValueError(f"Argument 'mode' must be 'train' or 'test'. Received {mode}") 23 | self.mode = mode 24 | 25 | def __call__(self, sample): 26 | if self.mode == 'train': 27 | img, mask = sample['input'], sample['target'] 28 | img = np.transpose(img, axes=[3, 0, 1, 2]) 29 | mask = np.transpose(mask, axes=[3, 0, 1, 2]) 30 | img = torch.from_numpy(img).float() 31 | mask = torch.from_numpy(mask).float() 32 | sample['input'], sample['target'] = img, mask 33 | 34 | else: # if self.mode == 'test' 35 | img = sample['input'] 36 | img = np.transpose(img, axes=[3, 0, 1, 2]) 37 | img = torch.from_numpy(img).float() 38 | sample['input'] = img 39 | 40 | return sample 41 | 42 | 43 | class Mirroring: 44 | def __init__(self, p=0.5): 45 | self.p = p 46 | 47 | def __call__(self, sample): 48 | if random.random() < self.p: 49 | img, mask = sample['input'], sample['target'] 50 | 51 | n_axes = random.randint(0, 3) 52 | random_axes = random.sample(range(3), n_axes) 53 | 54 | img = np.flip(img, axis=tuple(random_axes)) 55 | mask = np.flip(mask, axis=tuple(random_axes)) 56 | 57 | sample['input'], sample['target'] = img.copy(), mask.copy() 58 | 59 | return sample 60 | 61 | 62 | class NormalizeIntensity: 63 | 64 | def __call__(self, sample): 65 | img = sample['input'] 66 | img[:, :, :, 0] = self.normalize_ct(img[:, :, :, 0]) 67 | img[:, :, :, 1] = self.normalize_pt(img[:, :, :, 1]) 68 | 69 | sample['input'] = img 70 | return sample 71 | 72 | @staticmethod 73 | def normalize_ct(img): 74 | norm_img = np.clip(img, -1024, 1024) / 1024 75 | return norm_img 76 | 77 | @staticmethod 78 | def normalize_pt(img): 79 | mean = np.mean(img) 80 | std = np.std(img) 81 | return (img - mean) / (std + 1e-3) 82 | 83 | 84 | class RandomRotation: 85 | def __init__(self, p=0.5, angle_range=[5, 15]): 86 | self.p = p 87 | self.angle_range = angle_range 88 | 89 | def __call__(self, sample): 90 | if random.random() < self.p: 91 | img, mask = sample['input'], sample['target'] 92 | 93 | num_of_seqs = img.shape[-1] 94 | n_axes = random.randint(1, 3) 95 | random_axes = random.sample([0, 1, 2], n_axes) 96 | 97 | for axis in random_axes: 98 | 99 | angle = random.randrange(*self.angle_range) 100 | angle = -angle if random.random() < 0.5 else angle 101 | 102 | for i in range(num_of_seqs): 103 | img[:, :, :, i] = RandomRotation.rotate_3d_along_axis(img[:, :, :, i], angle, axis, 1) 104 | 105 | mask[:, :, :, 0] = RandomRotation.rotate_3d_along_axis(mask[:, :, :, 0], angle, axis, 0) 106 | 107 | sample['input'], sample['target'] = img, mask 108 | return sample 109 | 110 | @staticmethod 111 | def rotate_3d_along_axis(img, angle, axis, order): 112 | 113 | if axis == 0: 114 | rot_img = rotate(img, angle, order=order, preserve_range=True) 115 | 116 | if axis == 1: 117 | rot_img = np.transpose(img, axes=(1, 2, 0)) 118 | rot_img = rotate(rot_img, angle, order=order, preserve_range=True) 119 | rot_img = np.transpose(rot_img, axes=(2, 0, 1)) 120 | 121 | if axis == 2: 122 | rot_img = np.transpose(img, axes=(2, 0, 1)) 123 | rot_img = rotate(rot_img, angle, order=order, preserve_range=True) 124 | rot_img = np.transpose(rot_img, 
axes=(1, 2, 0)) 125 | 126 | return rot_img 127 | 128 | 129 | class ZeroPadding: 130 | 131 | def __init__(self, target_shape, mode='train'): 132 | self.target_shape = np.array(target_shape) # without channel dimension 133 | if mode not in ['train', 'test']: 134 | raise ValueError(f"Argument 'mode' must be 'train' or 'test'. Received {mode}") 135 | self.mode = mode 136 | 137 | def __call__(self, sample): 138 | if self.mode == 'train': 139 | img, mask = sample['input'], sample['target'] 140 | 141 | input_shape = np.array(img.shape[:-1]) # last (channel) dimension is ignored 142 | d_x, d_y, d_z = self.target_shape - input_shape 143 | d_x, d_y, d_z = int(d_x), int(d_y), int(d_z) 144 | 145 | if not all(i == 0 for i in (d_x, d_y, d_z)): 146 | positive = [i if i > 0 else 0 for i in (d_x, d_y, d_z)] 147 | negative = [i if i < 0 else None for i in (d_x, d_y, d_z)] 148 | 149 | # padding for positive values: 150 | img = np.pad(img, ((0, positive[0]), (0, positive[1]), (0, positive[2]), (0, 0)), 'constant', constant_values=(0, 0)) 151 | mask = np.pad(mask, ((0, positive[0]), (0, positive[1]), (0, positive[2]), (0, 0)), 'constant', constant_values=(0, 0)) 152 | 153 | # cropping for negative values: 154 | img = img[: negative[0], : negative[1], : negative[2], :].copy() 155 | mask = mask[: negative[0], : negative[1], : negative[2], :].copy() 156 | 157 | assert img.shape[:-1] == mask.shape[:-1], f'Shape mismatch for the image {img.shape[:-1]} and mask {mask.shape[:-1]}' 158 | 159 | sample['input'], sample['target'] = img, mask 160 | 161 | return sample 162 | 163 | else: # if self.mode == 'test' 164 | img = sample['input'] 165 | 166 | input_shape = np.array(img.shape[:-1]) # last (channel) dimension is ignored 167 | d_x, d_y, d_z = self.target_shape - input_shape 168 | d_x, d_y, d_z = int(d_x), int(d_y), int(d_z) 169 | 170 | if not all(i == 0 for i in (d_x, d_y, d_z)): 171 | positive = [i if i > 0 else 0 for i in (d_x, d_y, d_z)] 172 | negative = [i if i < 0 else None for i in (d_x, d_y, d_z)] 173 | 174 | # padding for positive values: 175 | img = np.pad(img, ((0, positive[0]), (0, positive[1]), (0, positive[2]), (0, 0)), 'constant', constant_values=(0, 0)) 176 | 177 | # cropping for negative values: 178 | img = img[: negative[0], : negative[1], : negative[2], :].copy() 179 | 180 | sample['input'] = img 181 | 182 | return sample 183 | 184 | 185 | class ExtractPatch: 186 | """Extracts a patch of a given size from an image (4D numpy array).""" 187 | 188 | def __init__(self, patch_size, p_tumor=0.5): 189 | self.patch_size = patch_size # without channel dimension! 190 | self.p_tumor = p_tumor # probs to extract a patch with a tumor 191 | 192 | def __call__(self, sample): 193 | img = sample['input'] 194 | mask = sample['target'] 195 | 196 | assert all(x <= y for x, y in zip(self.patch_size, img.shape[:-1])), \ 197 | f"Cannot extract the patch with the shape {self.patch_size} from " \ 198 | f"the image with the shape {img.shape}." 
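# NOTE: `patch_size` should be passed as a tuple, e.g. (144, 144, 144): the shape
# assertion at the end of this method compares it with `patch_img.shape[:-1]`,
# which is a tuple, and a list never compares equal to a tuple.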
199 | 200 | # patch_size components: 201 | ps_x, ps_y, ps_z = self.patch_size 202 | 203 | if random.random() < self.p_tumor: 204 | # coordinates of the tumor's center: 205 | xs, ys, zs, _ = np.where(mask != 0) 206 | tumor_center_x = np.min(xs) + (np.max(xs) - np.min(xs)) // 2 207 | tumor_center_y = np.min(ys) + (np.max(ys) - np.min(ys)) // 2 208 | tumor_center_z = np.min(zs) + (np.max(zs) - np.min(zs)) // 2 209 | 210 | # compute the origin of the patch: 211 | patch_org_x = random.randint(tumor_center_x - ps_x, tumor_center_x) 212 | patch_org_x = np.clip(patch_org_x, 0, img.shape[0] - ps_x) 213 | 214 | patch_org_y = random.randint(tumor_center_y - ps_y, tumor_center_y) 215 | patch_org_y = np.clip(patch_org_y, 0, img.shape[1] - ps_y) 216 | 217 | patch_org_z = random.randint(tumor_center_z - ps_z, tumor_center_z) 218 | patch_org_z = np.clip(patch_org_z, 0, img.shape[2] - ps_z) 219 | else: 220 | patch_org_x = random.randint(0, img.shape[0] - ps_x) 221 | patch_org_y = random.randint(0, img.shape[1] - ps_y) 222 | patch_org_z = random.randint(0, img.shape[2] - ps_z) 223 | 224 | # extract the patch: 225 | patch_img = img[patch_org_x: patch_org_x + ps_x, 226 | patch_org_y: patch_org_y + ps_y, 227 | patch_org_z: patch_org_z + ps_z, 228 | :].copy() 229 | 230 | patch_mask = mask[patch_org_x: patch_org_x + ps_x, 231 | patch_org_y: patch_org_y + ps_y, 232 | patch_org_z: patch_org_z + ps_z, 233 | :].copy() 234 | 235 | assert patch_img.shape[:-1] == self.patch_size, \ 236 | f"Shape mismatch for the patch with the shape {patch_img.shape[:-1]}, " \ 237 | f"whereas the required shape is {self.patch_size}." 238 | 239 | sample['input'] = patch_img 240 | sample['target'] = patch_mask 241 | 242 | return sample 243 | 244 | 245 | class InverseToTensor: 246 | def __call__(self, sample): 247 | output = sample['output'] 248 | 249 | output = torch.squeeze(output) # squeeze the batch and channel dimensions 250 | output = output.numpy() 251 | 252 | sample['output'] = output 253 | return sample 254 | 255 | 256 | class CheckOutputShape: 257 | def __init__(self, shape=(144, 144, 144)): 258 | self.shape = shape 259 | 260 | def __call__(self, sample): 261 | output = sample['output'] 262 | assert output.shape == self.shape, \ 263 | f'Received wrong output shape. Must be {self.shape}, but received {output.shape}.' 264 | return sample 265 | 266 | 267 | class ProbsToLabels: 268 | def __call__(self, sample): 269 | output = sample['output'] 270 | output = (output > 0.5).astype(int) # get binary label 271 | sample['output'] = output 272 | return sample 273 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import copy 4 | import numpy as np 5 | import pandas as pd 6 | import matplotlib.pyplot as plt 7 | import torch 8 | 9 | 10 | class ModelTrainer: 11 | """ 12 | A class for fitting a model. 13 | 14 | Parameters 15 | ---------- 16 | model : a subclass of `torch.nn.Module` 17 | A model to fit. 18 | dataloaders : dict of `torch.utils.data.DataLoader` 19 | A dictionary with 'train' and 'val' keys specifying dataloaders for model training and validation. 20 | criterion : a subclass of `torch.nn.Module` 21 | A loss function used for model training. 22 | optimizer : a subclass of `torch.optim.Optimizer` 23 | An optimizer for training a model. 24 | metric : function or a subclass of `torch.nn.Module` 25 | A metric used for evaluation. 
26 | mode : str 27 | Must be 'min' or 'max'. If 'max', a model with the highest metric will be treated as the best one. 28 | scheduler : a class from `torch.optim.lr_scheduler` (a subclass of _LRScheduler) 29 | A method to adjust the learning rate during training. 30 | num_epochs : int 31 | The number of epochs for training. 32 | parallel : bool 33 | Train a model on multiple GPUs using `torch.nn.DataParallel`. 34 | cuda_device : str 35 | A CUDA device used for training and validation. If the device is unavailable, 36 | all computations are performed on the CPU. 37 | save_last_model : bool 38 | If `True`, a checkpoint of the last epoch will be saved (for inference and/or resuming training). 39 | scheduler_step_per_epoch : bool 40 | If `True`, the learning rate is adjusted after each epoch. Otherwise, after each training batch. 41 | 42 | Attributes 43 | ---------- 44 | learning_curves : dict of dict 45 | A dictionary containing train & validation learning curves (loss and metric). 46 | best_val_epoch : int 47 | The epoch with the best metric on the validation set. 48 | best_model_wts 49 | Weights of the best model. 50 | checkpoint 51 | The model weights and optimizer state after the last epoch. 52 | """ 53 | 54 | def __init__(self, model, dataloaders, criterion, optimizer, 55 | metric=None, mode='max', scheduler=None, num_epochs=25, 56 | parallel=False, cuda_device="cuda:0", 57 | save_last_model=True, scheduler_step_per_epoch=True): 58 | 59 | self.model = model 60 | self.dataloaders = dataloaders 61 | self.criterion = criterion 62 | self.metric = metric 63 | self.mode = mode 64 | self.optimizer = optimizer 65 | self.scheduler = scheduler 66 | self.num_epochs = num_epochs 67 | self.parallel = parallel 68 | self.device = torch.device(cuda_device if torch.cuda.is_available() else "cpu") 69 | self.save_last_model = save_last_model 70 | self.scheduler_step_per_epoch = scheduler_step_per_epoch 71 | 72 | # Dicts for saving train and val losses: 73 | self.learning_curves = dict() 74 | self.learning_curves['loss'], self.learning_curves['metric'] = dict(), dict() 75 | self.learning_curves['loss']['train'], self.learning_curves['loss']['val'] = [], [] 76 | self.learning_curves['metric']['train'], self.learning_curves['metric']['val'] = [], [] 77 | 78 | # Summary: Best epoch, loss, metric and best model weights: 79 | self.best_val_epoch = 0 80 | self.best_val_loss = float('inf') 81 | if self.mode == 'max': 82 | self.best_val_avg_metric = -float('inf') 83 | else: 84 | self.best_val_avg_metric = float('inf') 85 | self.best_val_metric = 0.0 86 | self.best_model_wts = None 87 | self.checkpoint = None # last model and optimizer weights 88 | 89 | def train_model(self): 90 | """Fit a model.""" 91 | 92 | if self.device.type == 'cpu': 93 | print('Start training the model on CPU') 94 | elif self.parallel and torch.cuda.device_count() > 1: 95 | print(f'Start training the model on {torch.cuda.device_count()} ' 96 | f'{torch.cuda.get_device_name(torch.cuda.current_device())} in parallel') 97 | self.model = torch.nn.DataParallel(self.model) 98 | else: 99 | print(f'Start training the model on {torch.cuda.get_device_name(torch.cuda.current_device())}') 100 | 101 | self.model = self.model.to(self.device) 102 | 103 | for epoch in range(self.num_epochs): 104 | print(f'Epoch {epoch} / {self.num_epochs - 1}') 105 | print('-' * 20) 106 | 107 | # Each epoch has a training and validation phase: 108 | for phase in ['train', 'val']: 109 | if phase == 'train': 110 | self.model.train() # Set model to training mode 111 |
else: 112 | self.model.eval() # Set model to evaluate mode 113 | 114 | phase_loss = 0.0 # Train or val loss 115 | phase_metric = 0.0 116 | 117 | # Track history only if in train phase: 118 | with torch.set_grad_enabled(phase == 'train'): 119 | # Iterate over data batches: 120 | batch = 0 121 | for sample in self.dataloaders[phase]: 122 | input, target = sample['input'], sample['target'] 123 | input, target = input.to(self.device), target.to(self.device) 124 | 125 | # Forward pass: 126 | output = self.model(input) 127 | 128 | loss = self.criterion(output, target) 129 | metric = self.metric(output.detach(), target.detach()) 130 | 131 | # Losses and metric: 132 | phase_loss += loss.item() 133 | phase_metric += metric.item() 134 | 135 | with np.printoptions(precision=3, suppress=True): 136 | print(f'batch: {batch} batch loss: {loss:.3f} \tmetric: {metric:.3f}') 137 | 138 | del input, target, output, metric 139 | 140 | # Backward pass + optimize only if in training phase: 141 | if phase == 'train': 142 | loss.backward() 143 | self.optimizer.step() 144 | 145 | # zero the parameter gradients: 146 | self.optimizer.zero_grad() 147 | 148 | if self.scheduler and not self.scheduler_step_per_epoch: 149 | self.scheduler.step() 150 | 151 | del loss 152 | batch += 1 153 | 154 | phase_loss /= len(self.dataloaders[phase]) 155 | phase_metric /= len(self.dataloaders[phase]) 156 | self.learning_curves['loss'][phase].append(phase_loss) 157 | self.learning_curves['metric'][phase].append(phase_metric) 158 | 159 | print(f'{phase.upper()} loss: {phase_loss:.3f} \tavg_metric: {np.mean(phase_metric):.3f}') 160 | 161 | # Save a summary if these are the best val results so far: 162 | if phase == 'val': 163 | if self.mode == 'max' and np.mean(phase_metric) > self.best_val_avg_metric: 164 | self.best_val_epoch = epoch 165 | self.best_val_loss = phase_loss 166 | self.best_val_avg_metric = np.mean(phase_metric) 167 | self.best_val_metric = phase_metric 168 | self.best_model_wts = copy.deepcopy(self.model.state_dict()) 169 | 170 | if self.mode == 'min' and np.mean(phase_metric) < self.best_val_avg_metric: 171 | self.best_val_epoch = epoch 172 | self.best_val_loss = phase_loss 173 | self.best_val_avg_metric = np.mean(phase_metric) 174 | self.best_val_metric = phase_metric 175 | self.best_model_wts = copy.deepcopy(self.model.state_dict()) 176 | 177 | # Adjust learning rate after val phase: 178 | if self.scheduler and self.scheduler_step_per_epoch: 179 | if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): 180 | self.scheduler.step(np.mean(phase_metric)) 181 | else: 182 | self.scheduler.step() 183 | 184 | if self.save_last_model: 185 | self.checkpoint = {'model_state_dict': copy.deepcopy(self.model.state_dict()), 186 | 'optimizer_state_dict': copy.deepcopy(self.optimizer.state_dict())} 187 | 188 | def save_results(self, path_to_dir): 189 | """ 190 | Save results to a directory. This method must be called after training. 191 | 192 | A short summary is stored in a csv file ('summary.csv'). Weights of the best model are stored in 193 | 'best_model_weights.pt'. A checkpoint of the last epoch is stored in 'last_model_checkpoint.tar'. Two plots 194 | for the loss function and metric are stored in 'loss_plot.png' and 'metric_plot.png', respectively. 195 | 196 | Parameters 197 | ---------- 198 | path_to_dir : str 199 | A path to the directory for storing all results.
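Examples
--------
A typical end-to-end flow (illustrative; `dice` is the metric defined in `src/metrics.py`):

    >>> trainer = ModelTrainer(model, dataloaders, criterion, optimizer,
    ...                        metric=dice, mode='max', num_epochs=25)
    >>> trainer.train_model()
    >>> trainer.save_results('results/experiment_1')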
200 | """ 201 | 202 | path_to_dir = pathlib.Path(path_to_dir) 203 | 204 | # Check if the directory exists: 205 | if not os.path.exists(path_to_dir): 206 | os.makedirs(path_to_dir) 207 | 208 | # Write a short summary in a csv file: 209 | with open(path_to_dir / 'summary.csv', 'w', newline='', encoding='utf-8') as summary: 210 | summary.write(f'SUMMARY OF THE EXPERIMENT:\n\n') 211 | summary.write(f'BEST VAL EPOCH: {self.best_val_epoch}\n') 212 | summary.write(f'BEST VAL LOSS: {self.best_val_loss}\n') 213 | summary.write(f'BEST VAL AVG metric: {self.best_val_avg_metric}\n') 214 | summary.write(f'BEST VAL metric: {self.best_val_metric}\n') 215 | 216 | # Save best model weights: 217 | torch.save(self.best_model_wts, path_to_dir / 'best_model_weights.pt') 218 | 219 | # Save last model weights (checkpoint): 220 | if self.save_last_model: 221 | torch.save(self.checkpoint, path_to_dir / 'last_model_checkpoint.tar') 222 | 223 | # Save learning curves as pandas df: 224 | df_learning_curves = pd.DataFrame.from_dict({ 225 | 'loss_train': self.learning_curves['loss']['train'], 226 | 'loss_val': self.learning_curves['loss']['val'], 227 | 'metric_train': self.learning_curves['metric']['train'], 228 | 'metric_val': self.learning_curves['metric']['val'] 229 | }) 230 | df_learning_curves.to_csv(path_to_dir / 'learning_curves.csv', sep=';') 231 | 232 | # Save learning curves' plots in png files: 233 | # Loss figure: 234 | plt.figure(figsize=(17.5, 10)) 235 | plt.plot(range(self.num_epochs), self.learning_curves['loss']['train'], label='train') 236 | plt.plot(range(self.num_epochs), self.learning_curves['loss']['val'], label='val') 237 | plt.xlabel('Epoch', fontsize=20) 238 | plt.ylabel('Loss', fontsize=20) 239 | plt.xticks(fontsize=15) 240 | plt.yticks(fontsize=15) 241 | plt.legend(fontsize=20) 242 | plt.grid() 243 | plt.savefig(path_to_dir / 'loss_plot.png', bbox_inches='tight') 244 | 245 | # metric figure: 246 | train_avg_metric = [np.mean(i) for i in self.learning_curves['metric']['train']] 247 | val_avg_metric = [np.mean(i) for i in self.learning_curves['metric']['val']] 248 | 249 | plt.figure(figsize=(17.5, 10)) 250 | plt.plot(range(self.num_epochs), train_avg_metric, label='train') 251 | plt.plot(range(self.num_epochs), val_avg_metric, label='val') 252 | plt.xlabel('Epoch', fontsize=20) 253 | plt.ylabel('Avg metric', fontsize=20) 254 | plt.xticks(fontsize=15) 255 | plt.yticks(fontsize=15) 256 | plt.legend(fontsize=20) 257 | plt.grid() 258 | plt.savefig(path_to_dir / 'metric_plot.png', bbox_inches='tight') 259 | 260 | print(f'All results have been saved in {path_to_dir}') 261 | -------------------------------------------------------------------------------- /notebooks/make_dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Data Preprocessing for the MICCAI 2020 HEad and neCK TumOR segmentation challenge [(HECKTOR)](https://www.aicrowd.com/challenges/miccai-2020-hecktor)" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import os\n", 17 | "import sys\n", 18 | "import pathlib\n", 19 | "\n", 20 | "import numpy as np\n", 21 | "import pandas as pd\n", 22 | "import SimpleITK as sitk\n", 23 | "from tqdm.notebook import tqdm\n", 24 | "\n", 25 | "import matplotlib.pyplot as plt\n", 26 | "%matplotlib inline\n", 27 | "\n", 28 | "sys.path.append('../')\n", 29 | "from src.data.utils 
import read_nifti, write_nifti, get_attributes, resample_sitk_image" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "### Summary:" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "Dataset:\n", 44 | "- Each data sample (patient) consists of PET & CT images and a GTVt (primary Gross Tumor Volume) mask provided in NIfTI format.\n", 45 | "- PET & CT images for a single patient may occupy different regions in physical space, i.e., the images can differ in size (number of pixels per dimension), origin, spacing and direction cosine matrix (axis directions in physical space).\n", 46 | "- For each case, a bounding box of size 144x144x144 mm is available. Segmentation must be performed within the bounding box.\n", 47 | "\n", 48 | "This notebook shows how to transform (resample) each patient's pair of PET & CT images to a common reference space and how to extract a region of interest (a bounding box). Transformed images are saved in NIfTI format.\n", 49 | "\n", 50 | "**From now onward, the train set preprocessing will be demonstrated. For the test set preprocessing, all operations with ground truth labels (segmentation masks) must be omitted.**" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": {}, 56 | "source": [ 57 | "#### Input / Output paths:" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 2, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "path_to_input = pathlib.Path('C:/inserm/hecktor/hecktor_train/hecktor_nii/')\n", 67 | "path_to_bb = pathlib.Path('C:/inserm/hecktor/hecktor_train/bbox.csv')\n", 68 | "\n", 69 | "path_to_output = pathlib.Path('C:/inserm/hecktor/hecktor_train/hecktor_nii_resampled/')" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "#### Bounding boxes:" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 3, 82 | "metadata": {}, 83 | "outputs": [
Output: the first rows of the bounding-box table (coordinates in mm):

|   | PatientID | x1 | x2 | y1 | y2 | z1 | z2 |
|---|-----------|----|----|----|----|----|----|
| 0 | CHGJ007 | -65.039062 | 75.585938 | -166.992188 | -26.367188 | -204.050262 | -60.170746 |
| 1 | CHGJ008 | -65.039062 | 75.585938 | -166.992188 | -26.367188 | -460.319519 | -316.438660 |
| 2 | CHGJ010 | -72.070312 | 68.554688 | -135.351562 | 5.273438 | -232.130219 | -88.250702 |
| 3 | CHGJ013 | -75.585938 | 65.039062 | -152.929688 | -12.304688 | -242.630219 | -98.750702 |
| 4 | CHGJ015 | -65.039062 | 75.585938 | -152.929688 | -12.304688 | -243.780273 | -99.900757 |
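The source of the cell that produces this preview is cut off here; a minimal sketch of what it presumably looks like (illustrative only, reusing `pd` and `path_to_bb` from the cells above):

    bb_df = pd.read_csv(path_to_bb)  # one row per patient: PatientID, x1, x2, y1, y2, z1, z2
    bb_df.head()                     # preview shown above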