├── runs
│   └── __init__.py
├── utils
│   ├── __init__.py
│   ├── utils.py
│   └── data_utils.py
├── dataset
│   └── __init__.py
├── networks
│   ├── __init__.py
│   └── unetr.py
├── optimizers
│   ├── __init__.py
│   └── lr_scheduler.py
├── pretrained_models
│   └── __init__.py
├── requirements.txt
├── test.py
├── README.md
├── trainer.py
└── main.py
/runs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pretrained_models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.9.1 2 | monai==0.7.0 3 | nibabel==3.1.1 4 | tqdm==4.59.0 5 | einops==0.3.0 6 | tensorboardX==2.1 7 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License.
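# utils/utils.py collects small helpers shared by trainer.py and test.py: a NumPy-based
# Dice score, a running-average meter, and a thin wrapper around torch.distributed.all_gather.
# Quick sanity check of dice() on toy arrays (values are illustrative):
#   >>> dice(np.ones((2, 2)), np.ones((2, 2)))
#   1.0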
11 | 12 | import numpy as np 13 | import torch 14 | 15 | 16 | def dice(x, y): 17 | intersect = np.sum(np.sum(np.sum(x * y))) 18 | y_sum = np.sum(np.sum(np.sum(y))) 19 | if y_sum == 0: 20 | return 0.0 21 | x_sum = np.sum(np.sum(np.sum(x))) 22 | return 2 * intersect / (x_sum + y_sum) 23 | 24 | 25 | class AverageMeter(object): 26 | def __init__(self): 27 | self.reset() 28 | 29 | def reset(self): 30 | self.val = 0 31 | self.avg = 0 32 | self.sum = 0 33 | self.count = 0 34 | 35 | def update(self, val, n=1): 36 | self.val = val 37 | self.sum += val * n 38 | self.count += n 39 | self.avg = np.where(self.count > 0, self.sum / self.count, self.sum) 40 | 41 | 42 | def distributed_all_gather( 43 | tensor_list, valid_batch_size=None, out_numpy=False, world_size=None, no_barrier=False, is_valid=None 44 | ): 45 | 46 | if world_size is None: 47 | world_size = torch.distributed.get_world_size() 48 | if valid_batch_size is not None: 49 | valid_batch_size = min(valid_batch_size, world_size) 50 | elif is_valid is not None: 51 | is_valid = torch.tensor(bool(is_valid), dtype=torch.bool, device=tensor_list[0].device) 52 | if not no_barrier: 53 | torch.distributed.barrier() 54 | tensor_list_out = [] 55 | with torch.no_grad(): 56 | if is_valid is not None: 57 | is_valid_list = [torch.zeros_like(is_valid) for _ in range(world_size)] 58 | torch.distributed.all_gather(is_valid_list, is_valid) 59 | is_valid = [x.item() for x in is_valid_list] 60 | for tensor in tensor_list: 61 | gather_list = [torch.zeros_like(tensor) for _ in range(world_size)] 62 | torch.distributed.all_gather(gather_list, tensor) 63 | if valid_batch_size is not None: 64 | gather_list = gather_list[:valid_batch_size] 65 | elif is_valid is not None: 66 | gather_list = [g for g, v in zip(gather_list, is_valid_list) if v] 67 | if out_numpy: 68 | gather_list = [t.cpu().numpy() for t in gather_list] 69 | tensor_list_out.append(gather_list) 70 | return tensor_list_out 71 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
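# test.py evaluates a trained UNETR model: it restores a plain checkpoint (or a TorchScript
# archive), runs sliding-window inference over the validation split, and reports per-case
# and overall mean Dice across the 13 BTCV organ labels.
# Example invocation, mirroring the README's Testing section:
#   python test.py --infer_overlap=0.5 --data_dir=/dataset/dataset0/ \
#       --pretrained_dir='./pretrained_models/' --saved_checkpoint=ckpt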
11 | 12 | import argparse 13 | import os 14 | 15 | import numpy as np 16 | import torch 17 | from networks.unetr import UNETR 18 | from trainer import dice 19 | from utils.data_utils import get_loader 20 | 21 | from monai.inferers import sliding_window_inference 22 | 23 | parser = argparse.ArgumentParser(description="UNETR segmentation pipeline") 24 | parser.add_argument( 25 | "--pretrained_dir", default="./pretrained_models/", type=str, help="pretrained checkpoint directory" 26 | ) 27 | parser.add_argument("--data_dir", default="/dataset/dataset0/", type=str, help="dataset directory") 28 | parser.add_argument("--json_list", default="dataset_0.json", type=str, help="dataset json file") 29 | parser.add_argument( 30 | "--pretrained_model_name", default="UNETR_model_best_acc.pth", type=str, help="pretrained model name" 31 | ) 32 | parser.add_argument( 33 | "--saved_checkpoint", default="ckpt", type=str, help="Supports torchscript or ckpt pretrained checkpoint type" 34 | ) 35 | parser.add_argument("--mlp_dim", default=3072, type=int, help="mlp dimension in ViT encoder") 36 | parser.add_argument("--hidden_size", default=768, type=int, help="hidden size dimension in ViT encoder") 37 | parser.add_argument("--feature_size", default=16, type=int, help="feature size dimension") 38 | parser.add_argument("--infer_overlap", default=0.5, type=float, help="sliding window inference overlap") 39 | parser.add_argument("--in_channels", default=1, type=int, help="number of input channels") 40 | parser.add_argument("--out_channels", default=14, type=int, help="number of output channels") 41 | parser.add_argument("--num_heads", default=12, type=int, help="number of attention heads in ViT encoder") 42 | parser.add_argument("--res_block", action="store_true", help="use residual blocks") 43 | parser.add_argument("--conv_block", action="store_true", help="use conv blocks") 44 | parser.add_argument("--a_min", default=-175.0, type=float, help="a_min in ScaleIntensityRanged") 45 | parser.add_argument("--a_max", default=250.0, type=float, help="a_max in ScaleIntensityRanged") 46 | parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged") 47 | parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged") 48 | parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction") 49 | parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction") 50 | parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction") 51 | parser.add_argument("--roi_x", default=96, type=int, help="roi size in x direction") 52 | parser.add_argument("--roi_y", default=96, type=int, help="roi size in y direction") 53 | parser.add_argument("--roi_z", default=96, type=int, help="roi size in z direction") 54 | parser.add_argument("--dropout_rate", default=0.0, type=float, help="dropout rate") 55 | parser.add_argument("--distributed", action="store_true", help="start distributed training") 56 | parser.add_argument("--workers", default=8, type=int, help="number of workers") 57 | parser.add_argument("--RandFlipd_prob", default=0.2, type=float, help="RandFlipd aug probability") 58 | parser.add_argument("--RandRotate90d_prob", default=0.2, type=float, help="RandRotate90d aug probability") 59 | parser.add_argument("--RandScaleIntensityd_prob", default=0.1, type=float, help="RandScaleIntensityd aug probability") 60 | parser.add_argument("--RandShiftIntensityd_prob", default=0.1, type=float, help="RandShiftIntensityd aug
probability") 61 | parser.add_argument("--pos_embed", default="perceptron", type=str, help="type of position embedding") 62 | parser.add_argument("--norm_name", default="instance", type=str, help="normalization layer type in decoder") 63 | 64 | 65 | def main(): 66 | args = parser.parse_args() 67 | args.test_mode = True 68 | val_loader = get_loader(args) 69 | pretrained_dir = args.pretrained_dir 70 | model_name = args.pretrained_model_name 71 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 72 | pretrained_pth = os.path.join(pretrained_dir, model_name) 73 | if args.saved_checkpoint == "torchscript": 74 | model = torch.jit.load(pretrained_pth) 75 | elif args.saved_checkpoint == "ckpt": 76 | model = UNETR( 77 | in_channels=args.in_channels, 78 | out_channels=args.out_channels, 79 | img_size=(args.roi_x, args.roi_y, args.roi_z), 80 | feature_size=args.feature_size, 81 | hidden_size=args.hidden_size, 82 | mlp_dim=args.mlp_dim, 83 | num_heads=args.num_heads, 84 | pos_embed=args.pos_embed, 85 | norm_name=args.norm_name, 86 | conv_block=True, 87 | res_block=True, 88 | dropout_rate=args.dropout_rate, 89 | ) 90 | model_dict = torch.load(pretrained_pth) 91 | model.load_state_dict(model_dict) 92 | model.eval() 93 | model.to(device) 94 | 95 | with torch.no_grad(): 96 | dice_list_case = [] 97 | for i, batch in enumerate(val_loader): 98 | val_inputs, val_labels = (batch["image"].cuda(), batch["label"].cuda()) 99 | img_name = batch["image_meta_dict"]["filename_or_obj"][0].split("/")[-1] 100 | print("Inference on case {}".format(img_name)) 101 | val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, model, overlap=args.infer_overlap) 102 | val_outputs = torch.softmax(val_outputs, 1).cpu().numpy() 103 | val_outputs = np.argmax(val_outputs, axis=1).astype(np.uint8) 104 | val_labels = val_labels.cpu().numpy()[:, 0, :, :, :] 105 | dice_list_sub = [] 106 | for i in range(1, 14): 107 | organ_Dice = dice(val_outputs[0] == i, val_labels[0] == i) 108 | dice_list_sub.append(organ_Dice) 109 | mean_dice = np.mean(dice_list_sub) 110 | print("Mean Organ Dice: {}".format(mean_dice)) 111 | dice_list_case.append(mean_dice) 112 | print("Overall Mean Dice: {}".format(np.mean(dice_list_case))) 113 | 114 | 115 | if __name__ == "__main__": 116 | main() 117 | -------------------------------------------------------------------------------- /optimizers/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | import math 13 | import warnings 14 | from typing import List 15 | 16 | from torch.optim import Optimizer 17 | from torch.optim.lr_scheduler import LambdaLR, _LRScheduler 18 | 19 | 20 | __all__ = ["LinearLR", "ExponentialLR", "WarmupCosineSchedule", "LinearWarmupCosineAnnealingLR"] 21 | 22 | 23 | class _LRSchedulerMONAI(_LRScheduler): 24 | """Base class for increasing the learning rate between two boundaries over a number 25 | of iterations""" 26 | 27 | def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1) -> None: 28 | """ 29 | Args: 30 | optimizer: wrapped optimizer. 31 | end_lr: the final learning rate. 32 | num_iter: the number of iterations over which the test occurs. 33 | last_epoch: the index of last epoch. 34 | Returns: 35 | None 36 | """ 37 | self.end_lr = end_lr 38 | self.num_iter = num_iter 39 | super(_LRSchedulerMONAI, self).__init__(optimizer, last_epoch) 40 | 41 | 42 | class LinearLR(_LRSchedulerMONAI): 43 | """Linearly increases the learning rate between two boundaries over a number of 44 | iterations. 45 | """ 46 | 47 | def get_lr(self): 48 | r = self.last_epoch / (self.num_iter - 1) 49 | return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs] 50 | 51 | 52 | class ExponentialLR(_LRSchedulerMONAI): 53 | """Exponentially increases the learning rate between two boundaries over a number of 54 | iterations. 55 | """ 56 | 57 | def get_lr(self): 58 | r = self.last_epoch / (self.num_iter - 1) 59 | return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] 60 | 61 | 62 | class WarmupCosineSchedule(LambdaLR): 63 | """Linear warmup and then cosine decay. 64 | Based on https://huggingface.co/ implementation. 65 | """ 66 | 67 | def __init__( 68 | self, optimizer: Optimizer, warmup_steps: int, t_total: int, cycles: float = 0.5, last_epoch: int = -1 69 | ) -> None: 70 | """ 71 | Args: 72 | optimizer: wrapped optimizer. 73 | warmup_steps: number of warmup iterations. 74 | t_total: total number of training iterations. 75 | cycles: cosine cycles parameter. 76 | last_epoch: the index of last epoch. 77 | Returns: 78 | None 79 | """ 80 | self.warmup_steps = warmup_steps 81 | self.t_total = t_total 82 | self.cycles = cycles 83 | super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch) 84 | 85 | def lr_lambda(self, step): 86 | if step < self.warmup_steps: 87 | return float(step) / float(max(1.0, self.warmup_steps)) 88 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) 89 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) 90 | 91 | 92 | class LinearWarmupCosineAnnealingLR(_LRScheduler): 93 | def __init__( 94 | self, 95 | optimizer: Optimizer, 96 | warmup_epochs: int, 97 | max_epochs: int, 98 | warmup_start_lr: float = 0.0, 99 | eta_min: float = 0.0, 100 | last_epoch: int = -1, 101 | ) -> None: 102 | """ 103 | Args: 104 | optimizer (Optimizer): Wrapped optimizer. 105 | warmup_epochs (int): Maximum number of iterations for linear warmup. 106 | max_epochs (int): Maximum number of iterations. 107 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0. 108 | eta_min (float): Minimum learning rate. Default: 0. 109 | last_epoch (int): The index of last epoch. Default: -1.
110 | """ 111 | self.warmup_epochs = warmup_epochs 112 | self.max_epochs = max_epochs 113 | self.warmup_start_lr = warmup_start_lr 114 | self.eta_min = eta_min 115 | 116 | super(LinearWarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch) 117 | 118 | def get_lr(self) -> List[float]: 119 | """ 120 | Compute learning rate using chainable form of the scheduler 121 | """ 122 | if not self._get_lr_called_within_step: 123 | warnings.warn( 124 | "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning 125 | ) 126 | 127 | if self.last_epoch == 0: 128 | return [self.warmup_start_lr] * len(self.base_lrs) 129 | elif self.last_epoch < self.warmup_epochs: 130 | return [ 131 | group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 132 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 133 | ] 134 | elif self.last_epoch == self.warmup_epochs: 135 | return self.base_lrs 136 | elif (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0: 137 | return [ 138 | group["lr"] 139 | + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2 140 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 141 | ] 142 | 143 | return [ 144 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) 145 | / ( 146 | 1 147 | + math.cos( 148 | math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs) 149 | ) 150 | ) 151 | * (group["lr"] - self.eta_min) 152 | + self.eta_min 153 | for group in self.optimizer.param_groups 154 | ] 155 | 156 | def _get_closed_form_lr(self) -> List[float]: 157 | """ 158 | Called when epoch is passed as a param to the `step` function of the scheduler. 159 | """ 160 | if self.last_epoch < self.warmup_epochs: 161 | return [ 162 | self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 163 | for base_lr in self.base_lrs 164 | ] 165 | 166 | return [ 167 | self.eta_min 168 | + 0.5 169 | * (base_lr - self.eta_min) 170 | * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) 171 | for base_lr in self.base_lrs 172 | ] 173 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | import math 13 | import os 14 | 15 | import numpy as np 16 | import torch 17 | 18 | from monai import data, transforms 19 | from monai.data import load_decathlon_datalist 20 | 21 | 22 | class Sampler(torch.utils.data.Sampler): 23 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, make_even=True): 24 | if num_replicas is None: 25 | if not torch.distributed.is_available(): 26 | raise RuntimeError("Requires distributed package to be available") 27 | num_replicas = torch.distributed.get_world_size() 28 | if rank is None: 29 | if not torch.distributed.is_available(): 30 | raise RuntimeError("Requires distributed package to be available") 31 | rank = torch.distributed.get_rank() 32 | self.shuffle = shuffle 33 | self.make_even = make_even 34 | self.dataset = dataset 35 | self.num_replicas = num_replicas 36 | self.rank = rank 37 | self.epoch = 0 38 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 39 | self.total_size = self.num_samples * self.num_replicas 40 | indices = list(range(len(self.dataset))) 41 | self.valid_length = len(indices[self.rank : self.total_size : self.num_replicas]) 42 | 43 | def __iter__(self): 44 | if self.shuffle: 45 | g = torch.Generator() 46 | g.manual_seed(self.epoch) 47 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 48 | else: 49 | indices = list(range(len(self.dataset))) 50 | if self.make_even: 51 | if len(indices) < self.total_size: 52 | if self.total_size - len(indices) < len(indices): 53 | indices += indices[: (self.total_size - len(indices))] 54 | else: 55 | extra_ids = np.random.randint(low=0, high=len(indices), size=self.total_size - len(indices)) 56 | indices += [indices[ids] for ids in extra_ids] 57 | assert len(indices) == self.total_size 58 | indices = indices[self.rank : self.total_size : self.num_replicas] 59 | self.num_samples = len(indices) 60 | return iter(indices) 61 | 62 | def __len__(self): 63 | return self.num_samples 64 | 65 | def set_epoch(self, epoch): 66 | self.epoch = epoch 67 | 68 | 69 | def get_loader(args): 70 | data_dir = args.data_dir 71 | datalist_json = os.path.join(data_dir, args.json_list) 72 | train_transform = transforms.Compose( 73 | [ 74 | transforms.LoadImaged(keys=["image", "label"]), 75 | transforms.AddChanneld(keys=["image", "label"]), 76 | transforms.Orientationd(keys=["image", "label"], axcodes="RAS"), 77 | transforms.Spacingd( 78 | keys=["image", "label"], pixdim=(args.space_x, args.space_y, args.space_z), mode=("bilinear", "nearest") 79 | ), 80 | transforms.ScaleIntensityRanged( 81 | keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True 82 | ), 83 | transforms.CropForegroundd(keys=["image", "label"], source_key="image"), 84 | transforms.RandCropByPosNegLabeld( 85 | keys=["image", "label"], 86 | label_key="label", 87 | spatial_size=(args.roi_x, args.roi_y, args.roi_z), 88 | pos=1, 89 | neg=1, 90 | num_samples=4, 91 | image_key="image", 92 | image_threshold=0, 93 | ), 94 | transforms.RandFlipd(keys=["image", "label"], prob=args.RandFlipd_prob, spatial_axis=0), 95 | transforms.RandFlipd(keys=["image", "label"], prob=args.RandFlipd_prob, spatial_axis=1), 96 | transforms.RandFlipd(keys=["image", "label"], prob=args.RandFlipd_prob, spatial_axis=2), 97 | transforms.RandRotate90d(keys=["image", "label"], prob=args.RandRotate90d_prob, max_k=3), 98 | transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=args.RandScaleIntensityd_prob), 99 | transforms.RandShiftIntensityd(keys="image", offsets=0.1, 
prob=args.RandShiftIntensityd_prob), 100 | transforms.ToTensord(keys=["image", "label"]), 101 | ] 102 | ) 103 | val_transform = transforms.Compose( 104 | [ 105 | transforms.LoadImaged(keys=["image", "label"]), 106 | transforms.AddChanneld(keys=["image", "label"]), 107 | transforms.Orientationd(keys=["image", "label"], axcodes="RAS"), 108 | transforms.Spacingd( 109 | keys=["image", "label"], pixdim=(args.space_x, args.space_y, args.space_z), mode=("bilinear", "nearest") 110 | ), 111 | transforms.ScaleIntensityRanged( 112 | keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True 113 | ), 114 | transforms.CropForegroundd(keys=["image", "label"], source_key="image"), 115 | transforms.ToTensord(keys=["image", "label"]), 116 | ] 117 | ) 118 | 119 | if args.test_mode: 120 | test_files = load_decathlon_datalist(datalist_json, True, "validation", base_dir=data_dir) 121 | test_ds = data.Dataset(data=test_files, transform=val_transform) 122 | test_sampler = Sampler(test_ds, shuffle=False) if args.distributed else None 123 | test_loader = data.DataLoader( 124 | test_ds, 125 | batch_size=1, 126 | shuffle=False, 127 | num_workers=args.workers, 128 | sampler=test_sampler, 129 | pin_memory=True, 130 | persistent_workers=True, 131 | ) 132 | loader = test_loader 133 | else: 134 | datalist = load_decathlon_datalist(datalist_json, True, "training", base_dir=data_dir) 135 | if args.use_normal_dataset: 136 | train_ds = data.Dataset(data=datalist, transform=train_transform) 137 | else: 138 | train_ds = data.CacheDataset( 139 | data=datalist, transform=train_transform, cache_num=24, cache_rate=1.0, num_workers=args.workers 140 | ) 141 | train_sampler = Sampler(train_ds) if args.distributed else None 142 | train_loader = data.DataLoader( 143 | train_ds, 144 | batch_size=args.batch_size, 145 | shuffle=(train_sampler is None), 146 | num_workers=args.workers, 147 | sampler=train_sampler, 148 | pin_memory=True, 149 | persistent_workers=True, 150 | ) 151 | val_files = load_decathlon_datalist(datalist_json, True, "validation", base_dir=data_dir) 152 | val_ds = data.Dataset(data=val_files, transform=val_transform) 153 | val_sampler = Sampler(val_ds, shuffle=False) if args.distributed else None 154 | val_loader = data.DataLoader( 155 | val_ds, 156 | batch_size=1, 157 | shuffle=False, 158 | num_workers=args.workers, 159 | sampler=val_sampler, 160 | pin_memory=True, 161 | persistent_workers=True, 162 | ) 163 | loader = [train_loader, val_loader] 164 | 165 | return loader 166 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Paper Overview 2 | * Published: Oct 2021 3 | * Published in: IEEE Winter Conference on Applications of Computer Vision (WACV) 2022 4 | * Paper: https://arxiv.org/abs/2103.10504 5 | * Code: https://monai.io/research/unetr 6 | 7 | # Model Overview 8 | This repository contains the code for UNETR: Transformers for 3D Medical Image Segmentation [1]. UNETR is the first 3D segmentation network that uses a pure vision transformer as its encoder without relying on CNNs for feature extraction. 9 | The code presents a volumetric (3D) multi-organ segmentation application using the BTCV challenge dataset.
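The 14 output channels correspond to background plus the 13 organ labels listed in the Dataset section below. As a reference for interpreting predictions (treating label 0 as background is an assumption consistent with ```out_channels=14``` and with ```test.py```, which scores only labels 1-13):

``` python
BTCV_LABELS = {
    0: "background",  # assumed; test.py computes Dice only for labels 1-13
    1: "spleen", 2: "right kidney", 3: "left kidney", 4: "gallbladder",
    5: "esophagus", 6: "liver", 7: "stomach", 8: "aorta",
    9: "inferior vena cava", 10: "portal and splenic veins",
    11: "pancreas", 12: "right adrenal gland", 13: "left adrenal gland",
}
```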
10 | ![image](https://lh3.googleusercontent.com/pw/AM-JKLU2eTW17rYtCmiZP3WWC-U1HCPOHwLe6pxOfJXwv2W-00aHfsNy7jeGV1dwUq0PXFOtkqasQ2Vyhcu6xkKsPzy3wx7O6yGOTJ7ZzA01S6LSh8szbjNLfpbuGgMe6ClpiS61KGvqu71xXFnNcyvJNFjN=w1448-h496-no?authuser=0) 11 | 12 | ### Installing Dependencies 13 | Dependencies can be installed using: 14 | ``` bash 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | ### Training 19 | 20 | A UNETR network with standard hyper-parameters for the task of multi-organ semantic segmentation (BTCV dataset) can be defined as follows: 21 | 22 | ``` python 23 | model = UNETR( 24 | in_channels=1, 25 | out_channels=14, 26 | img_size=(96, 96, 96), 27 | feature_size=16, 28 | hidden_size=768, 29 | mlp_dim=3072, 30 | num_heads=12, 31 | pos_embed='perceptron', 32 | norm_name='instance', 33 | conv_block=True, 34 | res_block=True, 35 | dropout_rate=0.0) 36 | ``` 37 | 38 | The above UNETR model is used for CT images (1-channel input) and for 14-class segmentation outputs. The network expects 39 | resampled input images of size ```(96, 96, 96)```, which are converted into non-overlapping patches of size ```(16, 16, 16)```, i.e. a sequence of (96/16)^3 = 216 patch tokens per volume. 40 | The position embedding is performed using a perceptron layer. The ViT encoder follows standard hyper-parameters as introduced in [2]. 41 | The decoder uses convolutional and residual blocks as well as instance normalization. More details can be found in [1]. 42 | 43 | Using the default values for hyper-parameters, the following command can be used to initiate training using the PyTorch native AMP package: 44 | ``` bash 45 | python main.py 46 | --feature_size=32 47 | --batch_size=1 48 | --logdir=unetr_test 49 | --fold=0 50 | --optim_lr=1e-4 51 | --lrschedule=warmup_cosine 52 | --infer_overlap=0.5 53 | --save_checkpoint 54 | --data_dir=/dataset/dataset0/ 55 | ``` 56 | 57 | Note that you need to provide the location of your dataset directory by using ```--data_dir```. 58 | 59 | To initiate distributed multi-gpu training, ```--distributed``` needs to be added to the training command. 60 | 61 | To disable AMP, ```--noamp``` needs to be added to the training command. 62 | 63 | If UNETR is used in distributed multi-gpu training, we recommend increasing the learning rate (i.e. ```--optim_lr```) 64 | according to the number of GPUs. For instance, ```--optim_lr=4e-4``` is recommended for training with 4 GPUs. 65 | 66 | ### Finetuning 67 | We provide state-of-the-art pre-trained checkpoints and TorchScript models of UNETR trained on the BTCV dataset.
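Internally, ```--resume_ckpt``` simply loads the downloaded state dict into the network; a minimal sketch of the equivalent manual steps, assuming ```model``` is the UNETR instance defined above and the default checkpoint name:

``` python
import os
import torch

state_dict = torch.load(
    os.path.join("./pretrained_models/", "UNETR_model_best_acc.pth"), map_location="cpu"
)
model.load_state_dict(state_dict)  # the weights are saved as a plain state dict
model.eval()
```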
68 | 69 | For using the pre-trained checkpoint, please download the weights from the following directory: 70 | 71 | https://drive.google.com/file/d/1kR5QuRAuooYcTNLMnMj80Z9IgSs8jtLO/view?usp=sharing 72 | 73 | Once downloaded, please place the checkpoint in the following directory or use ```--pretrained_dir``` to provide the path where the model is located: 74 | 75 | ```./pretrained_models``` 76 | 77 | The following command initiates finetuning using the pretrained checkpoint: 78 | ``` bash 79 | python main.py 80 | --batch_size=1 81 | --logdir=unetr_pretrained 82 | --fold=0 83 | --optim_lr=1e-4 84 | --lrschedule=warmup_cosine 85 | --infer_overlap=0.5 86 | --save_checkpoint 87 | --data_dir=/dataset/dataset0/ 88 | --pretrained_dir='./pretrained_models/' 89 | --pretrained_model_name='UNETR_model_best_acc.pth' 90 | --resume_ckpt 91 | ``` 92 | 93 | For using the pre-trained TorchScript model, please download the model from the following directory: 94 | 95 | https://drive.google.com/file/d/1_YbUE0abQFJUR4Luwict6BB8S77yUaWN/view?usp=sharing 96 | 97 | Once downloaded, please place the TorchScript model in the following directory or use ```--pretrained_dir``` to provide the path where the model is located: 98 | 99 | ```./pretrained_models``` 100 | 101 | The following command initiates finetuning using the TorchScript model: 102 | ``` bash 103 | python main.py 104 | --batch_size=1 105 | --logdir=unetr_pretrained 106 | --fold=0 107 | --optim_lr=1e-4 108 | --lrschedule=warmup_cosine 109 | --infer_overlap=0.5 110 | --save_checkpoint 111 | --data_dir=/dataset/dataset0/ 112 | --pretrained_dir='./pretrained_models/' 113 | --noamp 114 | --pretrained_model_name='UNETR_model_best_acc.pt' 115 | --resume_jit 116 | ``` 117 | Note that finetuning from the provided TorchScript model does not support AMP. 118 | 119 | 120 | ### Testing 121 | You can test the state-of-the-art pre-trained UNETR TorchScript model or checkpoint on your own data. 122 | 123 | Once the pretrained weights are downloaded using the links above, please place the TorchScript model in the following directory or 124 | use ```--pretrained_dir``` to provide the path where the model is located: 125 | 126 | ```./pretrained_models``` 127 | 128 | The following command runs inference using the provided checkpoint: 129 | ``` bash 130 | python test.py 131 | --infer_overlap=0.5 132 | --data_dir=/dataset/dataset0/ 133 | --pretrained_dir='./pretrained_models/' 134 | --saved_checkpoint=ckpt 135 | ``` 136 | 137 | Note that ```--infer_overlap``` determines the overlap between the sliding window patches. A higher value typically results in more accurate segmentation outputs but at the cost of longer inference time. 138 | 139 | If you would like to use the pretrained TorchScript model, ```--saved_checkpoint=torchscript``` should be used.
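Under the hood, ```test.py``` wraps the network with MONAI's sliding-window inferer. A minimal sketch of the same call, assuming ```model``` is a loaded UNETR and ```val_inputs``` is a preprocessed ```(1, 1, D, H, W)``` CT tensor:

``` python
import torch
from monai.inferers import sliding_window_inference

with torch.no_grad():
    val_outputs = sliding_window_inference(
        val_inputs, roi_size=(96, 96, 96), sw_batch_size=4, predictor=model, overlap=0.5
    )
    # channel-wise softmax, then argmax to a (1, D, H, W) label map
    pred = torch.softmax(val_outputs, dim=1).argmax(dim=1)
```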
140 | 141 | ### Tutorial 142 | A tutorial for the task of multi-organ segmentation using the BTCV dataset can be found at the following link: 143 | 144 | https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unetr_btcv_segmentation_3d.ipynb 145 | 146 | Additionally, a tutorial that leverages PyTorch Lightning can be found at the following link: 147 | 148 | https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unetr_btcv_segmentation_3d_lightning.ipynb 149 | ## Dataset 150 | ![image](https://lh3.googleusercontent.com/pw/AM-JKLX0svvlMdcrchGAgiWWNkg40lgXYjSHsAAuRc5Frakmz2pWzSzf87JQCRgYpqFR0qAjJWPzMQLc_mmvzNjfF9QWl_1OHZ8j4c9qrbR6zQaDJWaCLArRFh0uPvk97qAa11HtYbD6HpJ-wwTCUsaPcYvM=w1724-h522-no?authuser=0) 151 | 152 | The training data is from the [BTCV challenge dataset](https://www.synapse.org/#!Synapse:syn3193805/wiki/217752). 153 | 154 | Under Institutional Review Board (IRB) supervision, 50 abdomen CT scans were randomly selected from a combination of an ongoing colorectal cancer chemotherapy trial and a retrospective ventral hernia study. The 50 scans were captured during the portal venous contrast phase with variable volume sizes (512 x 512 x 85 - 512 x 512 x 198) and fields of view (approx. 280 x 280 x 280 mm3 - 500 x 500 x 650 mm3). The in-plane resolution varies from 0.54 x 0.54 mm2 to 0.98 x 0.98 mm2, while the slice thickness ranges from 2.5 mm to 5.0 mm. 155 | 156 | - Target: 13 abdominal organs including 1. Spleen 2. Right Kidney 3. Left Kidney 4. Gallbladder 5. Esophagus 6. Liver 7. Stomach 8. Aorta 9. IVC 10. Portal and Splenic Veins 11. Pancreas 12. Right adrenal gland 13. Left adrenal gland. 157 | - Task: Segmentation 158 | - Modality: CT 159 | - Size: 30 3D volumes (24 Training + 6 Testing) 160 | 161 | 162 | We provide the json file that is used to train our models at the following link: 163 | 164 | https://drive.google.com/file/d/1t4fIQQkONv7ArTSZe4Nucwkk1KfdUDvW/view?usp=sharing 165 | 166 | Once the json file is downloaded, please place it in the same folder as the dataset. 167 | 168 | ## Citation 169 | If you find this repository useful, please consider citing the UNETR paper: 170 | 171 | ``` 172 | @inproceedings{hatamizadeh2022unetr, 173 | title={Unetr: Transformers for 3d medical image segmentation}, 174 | author={Hatamizadeh, Ali and Tang, Yucheng and Nath, Vishwesh and Yang, Dong and Myronenko, Andriy and Landman, Bennett and Roth, Holger R and Xu, Daguang}, 175 | booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision}, 176 | pages={574--584}, 177 | year={2022} 178 | } 179 | ``` 180 | 181 | ## References 182 | [1] Hatamizadeh, Ali, et al. "UNETR: Transformers for 3D Medical Image Segmentation", 2021. https://arxiv.org/abs/2103.10504. 183 | 184 | [2] Dosovitskiy, Alexey, et al. "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", 2020. https://arxiv.org/abs/2010.11929. 185 | 186 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | import os 13 | import shutil 14 | import time 15 | 16 | import numpy as np 17 | import torch 18 | import torch.nn.parallel 19 | import torch.utils.data.distributed 20 | from tensorboardX import SummaryWriter 21 | from torch.cuda.amp import GradScaler, autocast 22 | from utils.utils import distributed_all_gather 23 | 24 | from monai.data import decollate_batch 25 | 26 | 27 | def dice(x, y): 28 | intersect = np.sum(np.sum(np.sum(x * y))) 29 | y_sum = np.sum(np.sum(np.sum(y))) 30 | if y_sum == 0: 31 | return 0.0 32 | x_sum = np.sum(np.sum(np.sum(x))) 33 | return 2 * intersect / (x_sum + y_sum) 34 | 35 | 36 | class AverageMeter(object): 37 | def __init__(self): 38 | self.reset() 39 | 40 | def reset(self): 41 | self.val = 0 42 | self.avg = 0 43 | self.sum = 0 44 | self.count = 0 45 | 46 | def update(self, val, n=1): 47 | self.val = val 48 | self.sum += val * n 49 | self.count += n 50 | self.avg = np.where(self.count > 0, self.sum / self.count, self.sum) 51 | 52 | 53 | def train_epoch(model, loader, optimizer, scaler, epoch, loss_func, args): 54 | model.train() 55 | start_time = time.time() 56 | run_loss = AverageMeter() 57 | for idx, batch_data in enumerate(loader): 58 | if isinstance(batch_data, list): 59 | data, target = batch_data 60 | else: 61 | data, target = batch_data["image"], batch_data["label"] 62 | data, target = data.cuda(args.rank), target.cuda(args.rank) 63 | for param in model.parameters(): 64 | param.grad = None 65 | with autocast(enabled=args.amp): 66 | logits = model(data) 67 | loss = loss_func(logits, target) 68 | if args.amp: 69 | scaler.scale(loss).backward() 70 | scaler.step(optimizer) 71 | scaler.update() 72 | else: 73 | loss.backward() 74 | optimizer.step() 75 | if args.distributed: 76 | loss_list = distributed_all_gather([loss], out_numpy=True, is_valid=idx < loader.sampler.valid_length) 77 | run_loss.update( 78 | np.mean(np.mean(np.stack(loss_list, axis=0), axis=0), axis=0), n=args.batch_size * args.world_size 79 | ) 80 | else: 81 | run_loss.update(loss.item(), n=args.batch_size) 82 | if args.rank == 0: 83 | print( 84 | "Epoch {}/{} {}/{}".format(epoch, args.max_epochs, idx, len(loader)), 85 | "loss: {:.4f}".format(run_loss.avg), 86 | "time {:.2f}s".format(time.time() - start_time), 87 | ) 88 | start_time = time.time() 89 | for param in model.parameters(): 90 | param.grad = None 91 | return run_loss.avg 92 | 93 | 94 | def val_epoch(model, loader, epoch, acc_func, args, model_inferer=None, post_label=None, post_pred=None): 95 | model.eval() 96 | start_time = time.time() 97 | with torch.no_grad(): 98 | for idx, batch_data in enumerate(loader): 99 | if isinstance(batch_data, list): 100 | data, target = batch_data 101 | else: 102 | data, target = batch_data["image"], batch_data["label"] 103 | data, target = data.cuda(args.rank), target.cuda(args.rank) 104 | with autocast(enabled=args.amp): 105 | if model_inferer is not None: 106 | logits = model_inferer(data) 107 | else: 108 | logits = model(data) 109 | if not logits.is_cuda: 110 | target = target.cpu() 111 | val_labels_list = decollate_batch(target) 112 | val_labels_convert = 
[post_label(val_label_tensor) for val_label_tensor in val_labels_list] 113 | val_outputs_list = decollate_batch(logits) 114 | val_output_convert = [post_pred(val_pred_tensor) for val_pred_tensor in val_outputs_list] 115 | acc = acc_func(y_pred=val_output_convert, y=val_labels_convert) 116 | acc = acc.cuda(args.rank) 117 | 118 | if args.distributed: 119 | acc_list = distributed_all_gather([acc], out_numpy=True, is_valid=idx < loader.sampler.valid_length) 120 | avg_acc = np.mean([np.nanmean(l) for l in acc_list]) 121 | 122 | else: 123 | acc_list = acc.detach().cpu().numpy() 124 | avg_acc = np.mean([np.nanmean(l) for l in acc_list]) 125 | 126 | if args.rank == 0: 127 | print( 128 | "Val {}/{} {}/{}".format(epoch, args.max_epochs, idx, len(loader)), 129 | "acc", 130 | avg_acc, 131 | "time {:.2f}s".format(time.time() - start_time), 132 | ) 133 | start_time = time.time() 134 | return avg_acc 135 | 136 | 137 | def save_checkpoint(model, epoch, args, filename="model.pt", best_acc=0, optimizer=None, scheduler=None): 138 | state_dict = model.state_dict() if not args.distributed else model.module.state_dict() 139 | save_dict = {"epoch": epoch, "best_acc": best_acc, "state_dict": state_dict} 140 | if optimizer is not None: 141 | save_dict["optimizer"] = optimizer.state_dict() 142 | if scheduler is not None: 143 | save_dict["scheduler"] = scheduler.state_dict() 144 | filename = os.path.join(args.logdir, filename) 145 | torch.save(save_dict, filename) 146 | print("Saving checkpoint", filename) 147 | 148 | 149 | def run_training( 150 | model, 151 | train_loader, 152 | val_loader, 153 | optimizer, 154 | loss_func, 155 | acc_func, 156 | args, 157 | model_inferer=None, 158 | scheduler=None, 159 | start_epoch=0, 160 | post_label=None, 161 | post_pred=None, 162 | ): 163 | writer = None 164 | if args.logdir is not None and args.rank == 0: 165 | writer = SummaryWriter(log_dir=args.logdir) 166 | if args.rank == 0: 167 | print("Writing Tensorboard logs to ", args.logdir) 168 | scaler = None 169 | if args.amp: 170 | scaler = GradScaler() 171 | val_acc_max = 0.0 172 | for epoch in range(start_epoch, args.max_epochs): 173 | if args.distributed: 174 | train_loader.sampler.set_epoch(epoch) 175 | torch.distributed.barrier() 176 | print(args.rank, time.ctime(), "Epoch:", epoch) 177 | epoch_time = time.time() 178 | train_loss = train_epoch( 179 | model, train_loader, optimizer, scaler=scaler, epoch=epoch, loss_func=loss_func, args=args 180 | ) 181 | if args.rank == 0: 182 | print( 183 | "Final training {}/{}".format(epoch, args.max_epochs - 1), 184 | "loss: {:.4f}".format(train_loss), 185 | "time {:.2f}s".format(time.time() - epoch_time), 186 | ) 187 | if args.rank == 0 and writer is not None: 188 | writer.add_scalar("train_loss", train_loss, epoch) 189 | b_new_best = False 190 | if (epoch + 1) % args.val_every == 0: 191 | if args.distributed: 192 | torch.distributed.barrier() 193 | epoch_time = time.time() 194 | val_avg_acc = val_epoch( 195 | model, 196 | val_loader, 197 | epoch=epoch, 198 | acc_func=acc_func, 199 | model_inferer=model_inferer, 200 | args=args, 201 | post_label=post_label, 202 | post_pred=post_pred, 203 | ) 204 | if args.rank == 0: 205 | print( 206 | "Final validation {}/{}".format(epoch, args.max_epochs - 1), 207 | "acc", 208 | val_avg_acc, 209 | "time {:.2f}s".format(time.time() - epoch_time), 210 | ) 211 | if writer is not None: 212 | writer.add_scalar("val_acc", val_avg_acc, epoch) 213 | if val_avg_acc > val_acc_max: 214 | print("new best ({:.6f} --> {:.6f}). 
".format(val_acc_max, val_avg_acc)) 215 | val_acc_max = val_avg_acc 216 | b_new_best = True 217 | if args.rank == 0 and args.logdir is not None and args.save_checkpoint: 218 | save_checkpoint( 219 | model, epoch, args, best_acc=val_acc_max, optimizer=optimizer, scheduler=scheduler 220 | ) 221 | if args.rank == 0 and args.logdir is not None and args.save_checkpoint: 222 | save_checkpoint(model, epoch, args, best_acc=val_acc_max, filename="model_final.pt") 223 | if b_new_best: 224 | print("Copying to model.pt new best model!!!!") 225 | shutil.copyfile(os.path.join(args.logdir, "model_final.pt"), os.path.join(args.logdir, "model.pt")) 226 | 227 | if scheduler is not None: 228 | scheduler.step() 229 | 230 | print("Training Finished !, Best Accuracy: ", val_acc_max) 231 | 232 | return val_acc_max 233 | -------------------------------------------------------------------------------- /networks/unetr.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from typing import Tuple, Union 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from monai.networks.blocks import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock 18 | from monai.networks.blocks.dynunet_block import UnetOutBlock 19 | from monai.networks.nets import ViT 20 | 21 | 22 | class UNETR(nn.Module): 23 | """ 24 | UNETR based on: "Hatamizadeh et al., 25 | UNETR: Transformers for 3D Medical Image Segmentation " 26 | """ 27 | 28 | def __init__( 29 | self, 30 | in_channels: int, 31 | out_channels: int, 32 | img_size: Tuple[int, int, int], 33 | feature_size: int = 16, 34 | hidden_size: int = 768, 35 | mlp_dim: int = 3072, 36 | num_heads: int = 12, 37 | pos_embed: str = "perceptron", 38 | norm_name: Union[Tuple, str] = "instance", 39 | conv_block: bool = False, 40 | res_block: bool = True, 41 | dropout_rate: float = 0.0, 42 | ) -> None: 43 | """ 44 | Args: 45 | in_channels: dimension of input channels. 46 | out_channels: dimension of output channels. 47 | img_size: dimension of input image. 48 | feature_size: dimension of network feature size. 49 | hidden_size: dimension of hidden layer. 50 | mlp_dim: dimension of feedforward layer. 51 | num_heads: number of attention heads. 52 | pos_embed: position embedding layer type. 53 | norm_name: feature normalization type and arguments. 54 | conv_block: bool argument to determine if convolutional block is used. 55 | res_block: bool argument to determine if residual block is used. 56 | dropout_rate: faction of the input units to drop. 
57 | 58 | Examples:: 59 | 60 | # for single channel input 4-channel output with image size of (96,96,96), feature size of 32 and batch norm 61 | >>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch') 62 | 63 | # for 4-channel input 3-channel output with image size of (128,128,128), conv position embedding and instance norm 64 | >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance') 65 | 66 | """ 67 | 68 | super().__init__() 69 | 70 | if not (0 <= dropout_rate <= 1): 71 | raise AssertionError("dropout_rate should be between 0 and 1.") 72 | 73 | if hidden_size % num_heads != 0: 74 | raise AssertionError("hidden size should be divisible by num_heads.") 75 | 76 | if pos_embed not in ["conv", "perceptron"]: 77 | raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") 78 | 79 | self.num_layers = 12 80 | self.patch_size = (16, 16, 16) 81 | self.feat_size = ( 82 | img_size[0] // self.patch_size[0], 83 | img_size[1] // self.patch_size[1], 84 | img_size[2] // self.patch_size[2], 85 | ) 86 | self.hidden_size = hidden_size 87 | self.classification = False 88 | self.vit = ViT( 89 | in_channels=in_channels, 90 | img_size=img_size, 91 | patch_size=self.patch_size, 92 | hidden_size=hidden_size, 93 | mlp_dim=mlp_dim, 94 | num_layers=self.num_layers, 95 | num_heads=num_heads, 96 | pos_embed=pos_embed, 97 | classification=self.classification, 98 | dropout_rate=dropout_rate, 99 | ) 100 | self.encoder1 = UnetrBasicBlock( 101 | spatial_dims=3, 102 | in_channels=in_channels, 103 | out_channels=feature_size, 104 | kernel_size=3, 105 | stride=1, 106 | norm_name=norm_name, 107 | res_block=res_block, 108 | ) 109 | self.encoder2 = UnetrPrUpBlock( 110 | spatial_dims=3, 111 | in_channels=hidden_size, 112 | out_channels=feature_size * 2, 113 | num_layer=2, 114 | kernel_size=3, 115 | stride=1, 116 | upsample_kernel_size=2, 117 | norm_name=norm_name, 118 | conv_block=conv_block, 119 | res_block=res_block, 120 | ) 121 | self.encoder3 = UnetrPrUpBlock( 122 | spatial_dims=3, 123 | in_channels=hidden_size, 124 | out_channels=feature_size * 4, 125 | num_layer=1, 126 | kernel_size=3, 127 | stride=1, 128 | upsample_kernel_size=2, 129 | norm_name=norm_name, 130 | conv_block=conv_block, 131 | res_block=res_block, 132 | ) 133 | self.encoder4 = UnetrPrUpBlock( 134 | spatial_dims=3, 135 | in_channels=hidden_size, 136 | out_channels=feature_size * 8, 137 | num_layer=0, 138 | kernel_size=3, 139 | stride=1, 140 | upsample_kernel_size=2, 141 | norm_name=norm_name, 142 | conv_block=conv_block, 143 | res_block=res_block, 144 | ) 145 | self.decoder5 = UnetrUpBlock( 146 | spatial_dims=3, 147 | in_channels=hidden_size, 148 | out_channels=feature_size * 8, 149 | kernel_size=3, 150 | upsample_kernel_size=2, 151 | norm_name=norm_name, 152 | res_block=res_block, 153 | ) 154 | self.decoder4 = UnetrUpBlock( 155 | spatial_dims=3, 156 | in_channels=feature_size * 8, 157 | out_channels=feature_size * 4, 158 | kernel_size=3, 159 | upsample_kernel_size=2, 160 | norm_name=norm_name, 161 | res_block=res_block, 162 | ) 163 | self.decoder3 = UnetrUpBlock( 164 | spatial_dims=3, 165 | in_channels=feature_size * 4, 166 | out_channels=feature_size * 2, 167 | kernel_size=3, 168 | upsample_kernel_size=2, 169 | norm_name=norm_name, 170 | res_block=res_block, 171 | ) 172 | self.decoder2 = UnetrUpBlock( 173 | spatial_dims=3, 174 | in_channels=feature_size * 2, 175 | out_channels=feature_size, 176 | kernel_size=3, 177 |
upsample_kernel_size=2, 178 | norm_name=norm_name, 179 | res_block=res_block, 180 | ) 181 | self.out = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels) # type: ignore 182 | 183 | def proj_feat(self, x, hidden_size, feat_size): 184 | x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) 185 | x = x.permute(0, 4, 1, 2, 3).contiguous() 186 | return x 187 | 188 | def load_from(self, weights): 189 | with torch.no_grad(): 190 | res_weight = weights 191 | # copy weights from patch embedding 192 | for i in weights["state_dict"]: 193 | print(i) 194 | self.vit.patch_embedding.position_embeddings.copy_( 195 | weights["state_dict"]["module.transformer.patch_embedding.position_embeddings_3d"] 196 | ) 197 | self.vit.patch_embedding.cls_token.copy_( 198 | weights["state_dict"]["module.transformer.patch_embedding.cls_token"] 199 | ) 200 | self.vit.patch_embedding.patch_embeddings[1].weight.copy_( 201 | weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings.1.weight"] 202 | ) 203 | self.vit.patch_embedding.patch_embeddings[1].bias.copy_( 204 | weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings.1.bias"] 205 | ) 206 | 207 | # copy weights from encoding blocks (default: num of blocks: 12) 208 | for bname, block in self.vit.blocks.named_children(): 209 | print(block) 210 | block.loadFrom(weights, n_block=bname) 211 | # last norm layer of transformer 212 | self.vit.norm.weight.copy_(weights["state_dict"]["module.transformer.norm.weight"]) 213 | self.vit.norm.bias.copy_(weights["state_dict"]["module.transformer.norm.bias"]) 214 | 215 | def forward(self, x_in): 216 | x, hidden_states_out = self.vit(x_in) 217 | enc1 = self.encoder1(x_in) 218 | x2 = hidden_states_out[3] 219 | enc2 = self.encoder2(self.proj_feat(x2, self.hidden_size, self.feat_size)) 220 | x3 = hidden_states_out[6] 221 | enc3 = self.encoder3(self.proj_feat(x3, self.hidden_size, self.feat_size)) 222 | x4 = hidden_states_out[9] 223 | enc4 = self.encoder4(self.proj_feat(x4, self.hidden_size, self.feat_size)) 224 | dec4 = self.proj_feat(x, self.hidden_size, self.feat_size) 225 | dec3 = self.decoder5(dec4, enc4) 226 | dec2 = self.decoder4(dec3, enc3) 227 | dec1 = self.decoder3(dec2, enc2) 228 | out = self.decoder2(dec1, enc1) 229 | logits = self.out(out) 230 | return logits 231 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
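# main.py is the training entry point: it parses the flags below, optionally spawns one
# process per GPU for distributed data parallel training, builds the UNETR model, the
# DiceCE loss, the Dice metric and a sliding-window inferer, and then hands everything
# to run_training in trainer.py.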
11 | 12 | import argparse 13 | import os 14 | from functools import partial 15 | 16 | import numpy as np 17 | import torch 18 | import torch.distributed as dist 19 | import torch.multiprocessing as mp 20 | import torch.nn.parallel 21 | import torch.utils.data.distributed 22 | from networks.unetr import UNETR 23 | from optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR 24 | from trainer import run_training 25 | from utils.data_utils import get_loader 26 | 27 | from monai.inferers import sliding_window_inference 28 | from monai.losses import DiceCELoss, DiceLoss 29 | from monai.metrics import DiceMetric 30 | from monai.transforms import Activations, AsDiscrete, Compose 31 | from monai.utils.enums import MetricReduction 32 | 33 | parser = argparse.ArgumentParser(description="UNETR segmentation pipeline") 34 | parser.add_argument("--checkpoint", default=None, help="start training from saved checkpoint") 35 | parser.add_argument("--logdir", default="test", type=str, help="directory to save the tensorboard logs") 36 | parser.add_argument( 37 | "--pretrained_dir", default="./pretrained_models/", type=str, help="pretrained checkpoint directory" 38 | ) 39 | parser.add_argument("--data_dir", default="/dataset/dataset0/", type=str, help="dataset directory") 40 | parser.add_argument("--json_list", default="dataset_0.json", type=str, help="dataset json file") 41 | parser.add_argument( 42 | "--pretrained_model_name", default="UNETR_model_best_acc.pth", type=str, help="pretrained model name" 43 | ) 44 | parser.add_argument("--save_checkpoint", action="store_true", help="save checkpoint during training") 45 | parser.add_argument("--max_epochs", default=5000, type=int, help="max number of training epochs") 46 | parser.add_argument("--batch_size", default=1, type=int, help="batch size") 47 | parser.add_argument("--sw_batch_size", default=1, type=int, help="sliding window batch size") 48 | parser.add_argument("--optim_lr", default=1e-4, type=float, help="optimization learning rate") 49 | parser.add_argument("--optim_name", default="adamw", type=str, help="optimization algorithm") 50 | parser.add_argument("--reg_weight", default=1e-5, type=float, help="regularization weight") 51 | parser.add_argument("--momentum", default=0.99, type=float, help="momentum") 52 | parser.add_argument("--noamp", action="store_true", help="do NOT use amp for training") 53 | parser.add_argument("--val_every", default=100, type=int, help="validation frequency") 54 | parser.add_argument("--distributed", action="store_true", help="start distributed training") 55 | parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training") 56 | parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training") 57 | parser.add_argument("--dist-url", default="tcp://127.0.0.1:23456", type=str, help="distributed url") 58 | parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend") 59 | parser.add_argument("--workers", default=8, type=int, help="number of workers") 60 | parser.add_argument("--model_name", default="unetr", type=str, help="model name") 61 | parser.add_argument("--pos_embed", default="perceptron", type=str, help="type of position embedding") 62 | parser.add_argument("--norm_name", default="instance", type=str, help="normalization layer type in decoder") 63 | parser.add_argument("--num_heads", default=12, type=int, help="number of attention heads in ViT encoder") 64 | parser.add_argument("--mlp_dim", default=3072, type=int, help="mlp dimension in ViT encoder") 65 | parser.add_argument("--hidden_size", default=768, type=int, help="hidden size dimension in ViT encoder") 66 | parser.add_argument("--feature_size", default=16, type=int, help="feature size dimension") 67 | parser.add_argument("--in_channels", default=1, type=int, help="number of input channels") 68 | parser.add_argument("--out_channels", default=14, type=int, help="number of output channels") 69 | parser.add_argument("--res_block", action="store_true", help="use residual blocks") 70 | parser.add_argument("--conv_block", action="store_true", help="use conv blocks") 71 | parser.add_argument("--use_normal_dataset", action="store_true", help="use monai Dataset class") 72 | parser.add_argument("--a_min", default=-175.0, type=float, help="a_min in ScaleIntensityRanged") 73 | parser.add_argument("--a_max", default=250.0, type=float, help="a_max in ScaleIntensityRanged") 74 | parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged") 75 | parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged") 76 | parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction") 77 | parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction") 78 | parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction") 79 | parser.add_argument("--roi_x", default=96, type=int, help="roi size in x direction") 80 | parser.add_argument("--roi_y", default=96, type=int, help="roi size in y direction") 81 | parser.add_argument("--roi_z", default=96, type=int, help="roi size in z direction") 82 | parser.add_argument("--dropout_rate", default=0.0, type=float, help="dropout rate") 83 | parser.add_argument("--RandFlipd_prob", default=0.2, type=float, help="RandFlipd aug probability") 84 | parser.add_argument("--RandRotate90d_prob", default=0.2, type=float, help="RandRotate90d aug probability") 85 | parser.add_argument("--RandScaleIntensityd_prob", default=0.1, type=float, help="RandScaleIntensityd aug probability") 86 | parser.add_argument("--RandShiftIntensityd_prob", default=0.1, type=float, help="RandShiftIntensityd aug probability") 87 | parser.add_argument("--infer_overlap", default=0.5, type=float, help="sliding window inference overlap") 88 | parser.add_argument("--lrschedule", default="warmup_cosine", type=str, help="type of learning rate scheduler") 89 | parser.add_argument("--warmup_epochs", default=50, type=int, help="number of warmup epochs") 90 | parser.add_argument("--resume_ckpt", action="store_true", help="resume training from pretrained checkpoint") 91 | parser.add_argument("--resume_jit", action="store_true", help="resume training from pretrained torchscript checkpoint") 92 | parser.add_argument("--smooth_dr", default=1e-6, type=float, help="constant added to dice denominator to avoid nan") 93 | parser.add_argument("--smooth_nr", default=0.0, type=float, help="constant added to dice numerator to avoid zero") 94 | 95 | 96 | def main(): 97 | args = parser.parse_args() 98 | args.amp = not args.noamp 99 | args.logdir = "./runs/" + args.logdir 100 | if args.distributed: 101 | args.ngpus_per_node = torch.cuda.device_count() 102 | print("Found total gpus", args.ngpus_per_node) 103 | args.world_size = args.ngpus_per_node * args.world_size 104 | mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args,)) 105 | else: 106 | main_worker(gpu=0, args=args) 107 | 108 | 109 | def main_worker(gpu, args): 110 | 111 |
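    # main_worker runs once per process: it derives the global rank, initializes the
    # process group for DDP if requested, builds loaders, model, loss, metric, optimizer
    # and scheduler, and finally calls run_training.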
if args.distributed: 112 | torch.multiprocessing.set_start_method("fork", force=True) 113 | np.set_printoptions(formatter={"float": "{: 0.3f}".format}, suppress=True) 114 | args.gpu = gpu 115 | if args.distributed: 116 | args.rank = args.rank * args.ngpus_per_node + gpu 117 | dist.init_process_group( 118 | backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank 119 | ) 120 | torch.cuda.set_device(args.gpu) 121 | torch.backends.cudnn.benchmark = True 122 | args.test_mode = False 123 | loader = get_loader(args) 124 | print(args.rank, " gpu", args.gpu) 125 | if args.rank == 0: 126 | print("Batch size is:", args.batch_size, "epochs", args.max_epochs) 127 | inf_size = [args.roi_x, args.roi_y, args.roi_z] 128 | pretrained_dir = args.pretrained_dir 129 | if (args.model_name is None) or args.model_name == "unetr": 130 | model = UNETR( 131 | in_channels=args.in_channels, 132 | out_channels=args.out_channels, 133 | img_size=(args.roi_x, args.roi_y, args.roi_z), 134 | feature_size=args.feature_size, 135 | hidden_size=args.hidden_size, 136 | mlp_dim=args.mlp_dim, 137 | num_heads=args.num_heads, 138 | pos_embed=args.pos_embed, 139 | norm_name=args.norm_name, 140 | conv_block=True, 141 | res_block=True, 142 | dropout_rate=args.dropout_rate, 143 | ) 144 | 145 | if args.resume_ckpt: 146 | model_dict = torch.load(os.path.join(pretrained_dir, args.pretrained_model_name)) 147 | model.load_state_dict(model_dict) 148 | print("Use pretrained weights") 149 | 150 | if args.resume_jit: 151 | if not args.noamp: 152 | print("Training from pre-trained checkpoint does not support AMP\nAMP is disabled.") 153 | args.amp = args.noamp 154 | model = torch.jit.load(os.path.join(pretrained_dir, args.pretrained_model_name)) 155 | else: 156 | raise ValueError("Unsupported model " + str(args.model_name)) 157 | 158 | dice_loss = DiceCELoss( 159 | to_onehot_y=True, softmax=True, squared_pred=True, smooth_nr=args.smooth_nr, smooth_dr=args.smooth_dr 160 | ) 161 | post_label = AsDiscrete(to_onehot=True, n_classes=args.out_channels) 162 | post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=args.out_channels) 163 | dice_acc = DiceMetric(include_background=True, reduction=MetricReduction.MEAN, get_not_nans=True) 164 | model_inferer = partial( 165 | sliding_window_inference, 166 | roi_size=inf_size, 167 | sw_batch_size=args.sw_batch_size, 168 | predictor=model, 169 | overlap=args.infer_overlap, 170 | ) 171 | 172 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 173 | print("Total parameters count", pytorch_total_params) 174 | 175 | best_acc = 0 176 | start_epoch = 0 177 | 178 | if args.checkpoint is not None: 179 | checkpoint = torch.load(args.checkpoint, map_location="cpu") 180 | from collections import OrderedDict 181 | 182 | new_state_dict = OrderedDict() 183 | for k, v in checkpoint["state_dict"].items(): 184 | new_state_dict[k.replace("backbone.", "")] = v 185 | model.load_state_dict(new_state_dict, strict=False) 186 | if "epoch" in checkpoint: 187 | start_epoch = checkpoint["epoch"] 188 | if "best_acc" in checkpoint: 189 | best_acc = checkpoint["best_acc"] 190 | print("=> loaded checkpoint '{}' (epoch {}) (bestacc {})".format(args.checkpoint, start_epoch, best_acc)) 191 | 192 | model.cuda(args.gpu) 193 | 194 | if args.distributed: 195 | torch.cuda.set_device(args.gpu) 196 | if args.norm_name == "batch": 197 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 198 | model.cuda(args.gpu) 199 | model = 
torch.nn.parallel.DistributedDataParallel( 200 | model, device_ids=[args.gpu], output_device=args.gpu, find_unused_parameters=True 201 | ) 202 | if args.optim_name == "adam": 203 | optimizer = torch.optim.Adam(model.parameters(), lr=args.optim_lr, weight_decay=args.reg_weight) 204 | elif args.optim_name == "adamw": 205 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.optim_lr, weight_decay=args.reg_weight) 206 | elif args.optim_name == "sgd": 207 | optimizer = torch.optim.SGD( 208 | model.parameters(), lr=args.optim_lr, momentum=args.momentum, nesterov=True, weight_decay=args.reg_weight 209 | ) 210 | else: 211 | raise ValueError("Unsupported Optimization Procedure: " + str(args.optim_name)) 212 | 213 | if args.lrschedule == "warmup_cosine": 214 | scheduler = LinearWarmupCosineAnnealingLR( 215 | optimizer, warmup_epochs=args.warmup_epochs, max_epochs=args.max_epochs 216 | ) 217 | elif args.lrschedule == "cosine_anneal": 218 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epochs) 219 | if args.checkpoint is not None: 220 | scheduler.step(epoch=start_epoch) 221 | else: 222 | scheduler = None 223 | accuracy = run_training( 224 | model=model, 225 | train_loader=loader[0], 226 | val_loader=loader[1], 227 | optimizer=optimizer, 228 | loss_func=dice_loss, 229 | acc_func=dice_acc, 230 | args=args, 231 | model_inferer=model_inferer, 232 | scheduler=scheduler, 233 | start_epoch=start_epoch, 234 | post_label=post_label, 235 | post_pred=post_pred, 236 | ) 237 | return accuracy 238 | 239 | 240 | if __name__ == "__main__": 241 | main() 242 | --------------------------------------------------------------------------------
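For reference, a minimal smoke test of the network in ```networks/unetr.py``` (a sketch; the dummy input follows the default single-channel 96^3 ROI, so it needs no dataset):

``` python
import torch
from networks.unetr import UNETR

model = UNETR(in_channels=1, out_channels=14, img_size=(96, 96, 96), feature_size=16)
x = torch.randn(1, 1, 96, 96, 96)  # one single-channel 96^3 volume
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 14, 96, 96, 96])
```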