├── models ├── __init__.py ├── build.py ├── swin_3D.py └── phtrans.py ├── img └── method.png ├── data ├── __init__.py ├── dataset_val.py ├── build.py ├── data_augmentation.py ├── utils.py └── dataset_train.py ├── .gitignore ├── predict.sh ├── utils.py ├── requirements.txt ├── optimizer.py ├── README.md ├── lr_scheduler.py ├── metrics.py ├── coarse_train.py ├── fine_train.py ├── data_preprocess.py ├── unlabel_data_preprocess.py ├── trainer.py ├── losses.py ├── config.py ├── predict.py └── LICENSE /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_coarse_model,build_fine_model -------------------------------------------------------------------------------- /img/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lseventeen/FLARE22-TwoStagePHTrans/HEAD/img/method.png -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_loader,DataLoaderX 2 | from .dataset_val import predict_dataset -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | wandb 2 | *egg-info 3 | .vscode 4 | *__pycache__* 5 | save* 6 | val* 7 | *image 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /predict.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash -e 2 | python /home/lwt/code/flare/FLARE22-TwoStagePHTrans/predict.py -dp '/workspace/inputs/' -op '/workspace/outputs/' 3 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | import os 5 | from batchgenerators.utilities.file_and_folder_operations import * 6 | 7 | 8 | def seed_torch(seed=42): 9 | random.seed(seed) 10 | os.environ['PYTHONHASHSEED'] = str(seed) 11 | np.random.seed(seed) 12 | torch.manual_seed(seed) 13 | torch.cuda.manual_seed(seed) 14 | torch.backends.cudnn.deterministic = True 15 | 16 | 17 | def to_cuda(data, non_blocking=True): 18 | if isinstance(data, list): 19 | data = [i.cuda(non_blocking=non_blocking) for i in data] 20 | else: 21 | data = data.cuda(non_blocking=non_blocking) 22 | return data 23 | 24 | 25 | def load_checkpoint(checkpoint_path): 26 | checkpoint_file = "final_checkpoint.pth" 27 | checkpoint = torch.load( 28 | join(checkpoint_path, checkpoint_file), map_location=torch.device('cpu')) 29 | return checkpoint 30 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | 2 | batchgenerators==0.23 3 | brotlipy==0.7.0 4 | bunch==1.0.1 5 | certifi==2021.10.8 6 | 7 | 8 | click==8.1.3 9 | connected-components-3d==3.10.0 10 | 11 | docker-pycreds==0.4.0 12 | einops==0.4.1 13 | fastremap==1.13.0 14 | future==0.18.2 15 | gitdb==4.0.9 16 | GitPython==3.1.27 17 | 18 | imageio==2.17.0 19 | joblib==1.1.0 20 | linecache2==1.0.0 21 | lmdb==1.3.0 22 | loguru==0.6.0 23 | mkl-fft==1.3.1 24 | 25 | mkl-service==2.4.0 26 | networkx==2.8 27 | nibabel==3.2.2 28 | numpy==1.22.3 29 | opencv-python==4.5.5.64 30 | packaging==21.3 31 | pathtools==0.1.2 32 | Pillow==9.1.0 33 | prefetch-generator==1.0.1 34 | promise==2.3 35 | protobuf==3.20.1 36 | psutil==5.9.0 37 | 38 | pyparsing==3.0.8 39 | 40 | python-dateutil==2.8.2 41 | PyWavelets==1.3.0 42 | PyYAML==6.0 43 | 44 | ruamel.yaml==0.17.21 45 | ruamel.yaml.clib==0.2.6 46 | scikit-image==0.19.2 47 | scikit-learn==1.0.2 48 | scipy==1.8.0 49 | sentry-sdk==1.5.10 50 | setproctitle==1.2.3 51 | shortuuid==1.0.8 52 | SimpleITK==2.0.2 53 | 54 | smmap==5.0.0 55 | threadpoolctl==3.1.0 56 | tifffile==2022.4.22 57 | timm==0.5.4 58 | 59 | tqdm==4.64.0 60 | traceback2==1.4.0 61 | 62 | unittest2==1.1.0 63 | 64 | wandb==0.12.15 65 | yacs==0.1.8 66 | -------------------------------------------------------------------------------- /data/dataset_val.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import Dataset 3 | from batchgenerators.utilities.file_and_folder_operations import * 4 | from .utils import load_data,change_axes_of_image 5 | 6 | class predict_dataset(Dataset): 7 | def __init__(self, config): 8 | super(predict_dataset, self).__init__() 9 | self.config = config 10 | self.data_path = config.DATASET.VAL_IMAGE_PATH 11 | 12 | self.is_nor_dir = self.config.DATASET.IS_NORMALIZATION_DIRECTION 13 | 14 | self.series_ids = subfiles(self.data_path, join=False, suffix='gz') 15 | def __len__(self): 16 | return len(self.series_ids) 17 | 18 | def __getitem__(self, idx): 19 | image_id = self.series_ids[idx].split("_")[1] 20 | raw_image, image_spacing, image_direction= load_data(join(self.data_path,self.series_ids[idx])) 21 | if self.is_nor_dir: 22 | raw_image = change_axes_of_image(raw_image, image_direction) 23 | return {'image_id': image_id, 24 | 'raw_image': np.ascontiguousarray(raw_image), 25 | 'raw_spacing': image_spacing, 26 | 'image_direction': image_direction 27 | } 28 | 29 | -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | from torch import optim as optim 2 | 3 | def build_optimizer(config, model): 4 | """ 5 | Build optimizer, set weight decay of normalization to 0 by default. 6 | """ 7 | skip = {} 8 | skip_keywords = {} 9 | if hasattr(model, 'no_weight_decay'): 10 | skip = model.no_weight_decay() 11 | if hasattr(model, 'no_weight_decay_keywords'): 12 | skip_keywords = model.no_weight_decay_keywords() 13 | parameters = set_weight_decay(model, skip, skip_keywords) 14 | 15 | opt_lower = config.TRAIN.OPTIMIZER.NAME.lower() 16 | optimizer = None 17 | if opt_lower == 'sgd': 18 | optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True, 19 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 20 | elif opt_lower == 'adamw': 21 | optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 22 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 23 | 24 | return optimizer 25 | 26 | 27 | def set_weight_decay(model, skip_list=(), skip_keywords=()): 28 | has_decay = [] 29 | no_decay = [] 30 | 31 | for name, param in model.named_parameters(): 32 | if not param.requires_grad: 33 | continue 34 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 35 | check_keywords_in_name(name, skip_keywords): 36 | no_decay.append(param) 37 | else: 38 | has_decay.append(param) 39 | return [{'params': has_decay}, 40 | {'params': no_decay, 'weight_decay': 0.}] 41 | 42 | 43 | def check_keywords_in_name(name, keywords=()): 44 | isin = False 45 | for keyword in keywords: 46 | if keyword in name: 47 | isin = True 48 | return isin 49 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Two-stage PHTrans 2 | 3 | This repository is a solution for the [MICCAI FLARE2022 challenge](https://flare22.grand-challenge.org/). A detailed description of the method introduction, experiments and analysis of the results for this solution is presented in paper : [Combining Self-Training and Hybrid Architecture for Semi-supervised Abdominal Organ Segmentation](https://arxiv.org/abs/2207.11512). As shown in the figure below, this pipeline consists of two parts: (a) pseudo-label generation for unlabeled data, which is implemented using PHTrans under the nn-UNet framework (for more information, see [PHTrans](https://github.com/lseventeen/PHTrans)); (b) a two-stage segmentation framework with Lightweight PHTrans. This repository is the code implementation of this part. 4 | 5 |
6 | 7 |
8 | 9 | 10 | ## Prerequisites 11 | 12 | 13 | 14 | Download our repo and install packages: 15 | ``` 16 | git clone https://github.com/lseventeen/FLARE22-TwoStagePHTrans 17 | cd FLARE22-TwoStagePHTrans 18 | pip install -r requirements.txt 19 | ``` 20 | 21 | 22 | ## Datasets processing 23 | Download [FLARE 2022](https://flare22.grand-challenge.org/Dataset/) datasets. Generate pseudo-labels for unlabeled data based on the repository [PHTrans](https://github.com/lseventeen/PHTrans). Modify the data path in the [config.py](https://github.com/lseventeen/FLARE22-TwoStagePHTrans/blob/master/config.py) file. Type this in the terminal to perform dataset processing: 24 | 25 | ``` 26 | python data_processing.py 27 | ``` 28 | 29 | ## Training 30 | Type this in terminal to run coarse segmentation train: 31 | 32 | ``` 33 | python coarse_train.py 34 | ``` 35 | Type this in terminal to run fine segmentation train: 36 | 37 | ``` 38 | python fine_train.py 39 | ``` 40 | ## Inference 41 | Type this in terminal to Inference: 42 | 43 | ``` 44 | python predict.py -dp DATA_PATH -op SAVE_RESULTS_PATH 45 | ``` 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /data/build.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from batchgenerators.utilities.file_and_folder_operations import * 3 | from data.dataset_train import flare22_dataset 4 | from sklearn.model_selection import train_test_split 5 | from prefetch_generator import BackgroundGenerator 6 | import torch 7 | class DataLoaderX(DataLoader): 8 | def __iter__(self): 9 | return BackgroundGenerator(super().__iter__()) 10 | 11 | def build_loader(config,data_size, data_path,unlab_data_path, pool_op_kernel_sizes, num_each_epoch): 12 | series_ids_train = subfiles(data_path, join=False, suffix='npz') 13 | 14 | if config.DATASET.WITH_VAL: 15 | 16 | series_ids_train, series_ids_val = train_test_split(series_ids_train, test_size=config.DATASET.VAL_SPLIT,random_state=42) 17 | val_dataset = flare22_dataset(config,series_ids_val,data_size, data_path, pool_op_kernel_sizes,is_train=False) 18 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) if config.DIS else None 19 | val_loader = DataLoaderX( 20 | dataset=val_dataset, 21 | sampler=val_sampler , 22 | batch_size = config.DATALOADER.BATCH_SIZE, 23 | num_workers=config.DATALOADER.NUM_WORKERS, 24 | pin_memory= config.DATALOADER.PIN_MEMORY, 25 | shuffle=False, 26 | drop_last=False 27 | ) 28 | else: 29 | val_loader = None 30 | 31 | 32 | train_dataset = flare22_dataset(config, data_size, data_path, unlab_data_path, pool_op_kernel_sizes, num_each_epoch,is_train=True) 33 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,shuffle=True) if config.DIS else None 34 | train_loader = DataLoaderX( 35 | train_dataset, 36 | sampler=train_sampler, 37 | batch_size = config.DATALOADER.BATCH_SIZE, 38 | num_workers=config.DATALOADER.NUM_WORKERS, 39 | pin_memory= config.DATALOADER.PIN_MEMORY, 40 | shuffle=True if train_sampler is None else False, 41 | drop_last=True 42 | ) 43 | return train_loader,val_loader 44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /models/build.py: -------------------------------------------------------------------------------- 1 | from .phtrans import PHTrans 2 | 3 | 4 | def build_coarse_model(config, is_VAL = False): 5 | if config.MODEL.COARSE.TYPE == 'phtrans': 6 | model = PHTrans( 7 | img_size = config.DATASET.COARSE.SIZE, 8 | base_num_features = config.MODEL.COARSE.BASE_NUM_FEATURES, 9 | num_classes = config.DATASET.COARSE.LABEL_CLASSES, 10 | num_only_conv_stage = config.MODEL.COARSE.NUM_ONLY_CONV_STAGE, 11 | num_conv_per_stage = config.MODEL.COARSE.NUM_CONV_PER_STAGE, 12 | feat_map_mul_on_downscale = config.MODEL.COARSE.FEAT_MAP_MUL_ON_DOWNSCALE, 13 | pool_op_kernel_sizes = config.MODEL.COARSE.POOL_OP_KERNEL_SIZES, 14 | conv_kernel_sizes = config.MODEL.COARSE.CONV_KERNEL_SIZES, 15 | dropout_p = config.MODEL.COARSE.DROPOUT_P, 16 | deep_supervision = config.MODEL.DEEP_SUPERVISION if not is_VAL else False, 17 | max_num_features = config.MODEL.COARSE.MAX_NUM_FEATURES, 18 | depths = config.MODEL.COARSE.DEPTHS, 19 | num_heads = config.MODEL.COARSE.NUM_HEADS, 20 | window_size = config.MODEL.COARSE.WINDOW_SIZE, 21 | mlp_ratio = config.MODEL.COARSE.MLP_RATIO, 22 | qkv_bias = config.MODEL.COARSE.DROP_RATE, 23 | qk_scale = config.MODEL.COARSE.QK_SCALE, 24 | drop_rate = config.MODEL.COARSE.DROP_RATE, 25 | drop_path_rate = config.MODEL.COARSE.DROP_PATH_RATE, 26 | ) 27 | else: 28 | raise NotImplementedError(f"Unkown model: {config.MODEL.COARSE.TYPE}") 29 | 30 | return model 31 | 32 | 33 | 34 | def build_fine_model(config, is_VAL = False): 35 | if config.MODEL.FINE.TYPE == 'phtrans': 36 | model = PHTrans( 37 | img_size = config.DATASET.FINE.SIZE, 38 | base_num_features = config.MODEL.FINE.BASE_NUM_FEATURES, 39 | num_classes = config.DATASET.FINE.LABEL_CLASSES, 40 | num_only_conv_stage = config.MODEL.FINE.NUM_ONLY_CONV_STAGE, 41 | num_conv_per_stage = config.MODEL.FINE.NUM_CONV_PER_STAGE, 42 | feat_map_mul_on_downscale = config.MODEL.FINE.FEAT_MAP_MUL_ON_DOWNSCALE, 43 | pool_op_kernel_sizes = config.MODEL.FINE.POOL_OP_KERNEL_SIZES, 44 | conv_kernel_sizes = config.MODEL.FINE.CONV_KERNEL_SIZES, 45 | dropout_p = config.MODEL.FINE.DROPOUT_P, 46 | deep_supervision = config.MODEL.DEEP_SUPERVISION if not is_VAL else False, 47 | max_num_features = config.MODEL.FINE.MAX_NUM_FEATURES, 48 | depths = config.MODEL.FINE.DEPTHS, 49 | num_heads = config.MODEL.FINE.NUM_HEADS, 50 | window_size = config.MODEL.FINE.WINDOW_SIZE, 51 | mlp_ratio = config.MODEL.FINE.MLP_RATIO, 52 | qkv_bias = config.MODEL.FINE.DROP_RATE, 53 | qk_scale = config.MODEL.FINE.QK_SCALE, 54 | drop_rate = config.MODEL.FINE.DROP_RATE, 55 | drop_path_rate = config.MODEL.FINE.DROP_PATH_RATE, 56 | ) 57 | else: 58 | raise NotImplementedError(f"Unkown model: {config.MODEL.FINE.TYPE}") 59 | 60 | return model 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from timm.scheduler.cosine_lr import CosineLRScheduler 3 | from timm.scheduler.step_lr import StepLRScheduler 4 | from timm.scheduler.scheduler import Scheduler 5 | 6 | 7 | def build_scheduler(config, optimizer, n_iter_per_epoch): 8 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch) 9 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch) 10 | decay_steps = int( 11 | config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch) 12 | 13 | lr_scheduler = None 14 | if config.TRAIN.LR_SCHEDULER.NAME == 'cosine': 15 | lr_scheduler = CosineLRScheduler( 16 | optimizer, 17 | t_initial=num_steps, 18 | cycle_mul=1., 19 | lr_min=config.TRAIN.MIN_LR, 20 | warmup_lr_init=config.TRAIN.WARMUP_LR, 21 | warmup_t=warmup_steps, 22 | cycle_limit=1, 23 | t_in_epochs=False, 24 | ) 25 | elif config.TRAIN.LR_SCHEDULER.NAME == 'linear': 26 | lr_scheduler = LinearLRScheduler( 27 | optimizer, 28 | t_initial=num_steps, 29 | lr_min_rate=0.01, 30 | warmup_lr_init=config.TRAIN.WARMUP_LR, 31 | warmup_t=warmup_steps, 32 | t_in_epochs=False, 33 | ) 34 | elif config.TRAIN.LR_SCHEDULER.NAME == 'step': 35 | lr_scheduler = StepLRScheduler( 36 | optimizer, 37 | decay_t=decay_steps, 38 | decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE, 39 | warmup_lr_init=config.TRAIN.WARMUP_LR, 40 | warmup_t=warmup_steps, 41 | t_in_epochs=False, 42 | ) 43 | 44 | return lr_scheduler 45 | 46 | 47 | class LinearLRScheduler(Scheduler): 48 | def __init__(self, 49 | optimizer: torch.optim.Optimizer, 50 | t_initial: int, 51 | lr_min_rate: float, 52 | warmup_t=0, 53 | warmup_lr_init=0., 54 | t_in_epochs=True, 55 | noise_range_t=None, 56 | noise_pct=0.67, 57 | noise_std=1.0, 58 | noise_seed=42, 59 | initialize=True, 60 | ) -> None: 61 | super().__init__( 62 | optimizer, param_group_field="lr", 63 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 64 | initialize=initialize) 65 | 66 | self.t_initial = t_initial 67 | self.lr_min_rate = lr_min_rate 68 | self.warmup_t = warmup_t 69 | self.warmup_lr_init = warmup_lr_init 70 | self.t_in_epochs = t_in_epochs 71 | if self.warmup_t: 72 | self.warmup_steps = [(v - warmup_lr_init) / 73 | self.warmup_t for v in self.base_values] 74 | super().update_groups(self.warmup_lr_init) 75 | else: 76 | self.warmup_steps = [1 for _ in self.base_values] 77 | 78 | def _get_lr(self, t): 79 | if t < self.warmup_t: 80 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 81 | else: 82 | t = t - self.warmup_t 83 | total_t = self.t_initial - self.warmup_t 84 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) 85 | for v in self.base_values] 86 | return lrs 87 | 88 | def get_epoch_values(self, epoch: int): 89 | if self.t_in_epochs: 90 | return self._get_lr(epoch) 91 | else: 92 | return None 93 | 94 | def get_update_values(self, num_updates: int): 95 | if not self.t_in_epochs: 96 | return self._get_lr(num_updates) 97 | else: 98 | return None 99 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | class AverageMeter(object): 6 | def __init__(self): 7 | self.initialized = False 8 | self.val = None 9 | self.avg = None 10 | self.sum = None 11 | self.count = None 12 | 13 | def initialize(self, val, weight): 14 | self.val = val 15 | self.avg = val 16 | self.sum = np.multiply(val, weight) 17 | self.count = weight 18 | self.initialized = True 19 | 20 | def update(self, val, weight=1): 21 | if not self.initialized: 22 | self.initialize(val, weight) 23 | else: 24 | self.add(val, weight) 25 | 26 | def add(self, val, weight): 27 | self.val = val 28 | self.sum = np.add(self.sum, np.multiply(val, weight)) 29 | self.count = self.count + weight 30 | self.avg = self.sum / self.count 31 | 32 | @property 33 | def value(self): 34 | return np.round(self.val, 4) 35 | 36 | @property 37 | def average(self): 38 | return np.round(self.avg, 4) 39 | 40 | 41 | def run_online_evaluation(output, target): 42 | if isinstance(output, list): 43 | output = output[0] 44 | if isinstance(target, list): 45 | target = target[0] 46 | online_eval_foreground_dc = [] 47 | online_eval_tp = [] 48 | online_eval_fp = [] 49 | online_eval_fn = [] 50 | with torch.no_grad(): 51 | num_classes = output.shape[1] 52 | output_softmax = F.softmax(output, 1) 53 | output_seg = output_softmax.argmax(1) 54 | target = target[:, 0] 55 | axes = tuple(range(1, len(target.shape))) 56 | tp_hard = torch.zeros( 57 | (target.shape[0], num_classes - 1)).to(output_seg.device.index) 58 | fp_hard = torch.zeros( 59 | (target.shape[0], num_classes - 1)).to(output_seg.device.index) 60 | fn_hard = torch.zeros( 61 | (target.shape[0], num_classes - 1)).to(output_seg.device.index) 62 | for c in range(1, num_classes): 63 | tp_hard[:, c - 1] = sum_tensor( 64 | (output_seg == c).float() * (target == c).float(), axes=axes) 65 | fp_hard[:, c - 1] = sum_tensor( 66 | (output_seg == c).float() * (target != c).float(), axes=axes) 67 | fn_hard[:, c - 1] = sum_tensor( 68 | (output_seg != c).float() * (target == c).float(), axes=axes) 69 | 70 | tp_hard = tp_hard.sum(0, keepdim=False).detach().cpu().numpy() 71 | fp_hard = fp_hard.sum(0, keepdim=False).detach().cpu().numpy() 72 | fn_hard = fn_hard.sum(0, keepdim=False).detach().cpu().numpy() 73 | 74 | online_eval_foreground_dc.append( 75 | list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8))) 76 | online_eval_tp.append(list(tp_hard)) 77 | online_eval_fp.append(list(fp_hard)) 78 | online_eval_fn.append(list(fn_hard)) 79 | 80 | online_eval_tp = np.sum(online_eval_tp, 0) 81 | online_eval_fp = np.sum(online_eval_fp, 0) 82 | online_eval_fn = np.sum(online_eval_fn, 0) 83 | 84 | global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in 85 | zip(online_eval_tp, online_eval_fp, online_eval_fn)] 86 | if not np.isnan(i)] 87 | average_global_dc = np.mean(global_dc_per_class) 88 | return average_global_dc 89 | 90 | 91 | def sum_tensor(inp, axes, keepdim=False): 92 | axes = np.unique(axes).astype(int) 93 | if keepdim: 94 | for ax in axes: 95 | inp = inp.sum(int(ax), keepdim=True) 96 | else: 97 | for ax in sorted(axes, reverse=True): 98 | inp = inp.sum(int(ax)) 99 | return inp 100 | -------------------------------------------------------------------------------- /coarse_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from loguru import logger 3 | from data import build_loader 4 | from trainer import Trainer 5 | from utils import seed_torch 6 | from losses import build_loss 7 | from datetime import datetime 8 | import wandb 9 | from config import get_config 10 | from models import build_coarse_model 11 | from lr_scheduler import build_scheduler 12 | from optimizer import build_optimizer 13 | import os 14 | import torch.backends.cudnn as cudnn 15 | import numpy as np 16 | import torch 17 | import torch.multiprocessing as mp 18 | import torch.distributed as dist 19 | 20 | 21 | def parse_option(): 22 | parser = argparse.ArgumentParser("FLARE2022_corase_training") 23 | parser.add_argument('--cfg', type=str, metavar="FILE", 24 | help='path to config file') 25 | parser.add_argument( 26 | "--opts", 27 | help="Modify config options by adding 'KEY VALUE' pairs. ", 28 | default=None, 29 | nargs='+', 30 | ) 31 | parser.add_argument("--tag", help='tag of experiment') 32 | parser.add_argument("-wm", "--wandb_mode", default="offline") 33 | parser.add_argument('-bs', '--batch-size', type=int, 34 | help="batch size for single GPU") 35 | parser.add_argument('-wd', '--with_distributed', help="training without DDP", 36 | required=False, default=False, action="store_true") 37 | parser.add_argument('-ws', '--world_size', type=int, 38 | help="process number for DDP") 39 | args = parser.parse_args() 40 | config = get_config(args) 41 | 42 | return args, config 43 | 44 | 45 | def main(config): 46 | if config.DIS: 47 | mp.spawn(main_worker, 48 | args=(config,), 49 | nprocs=config.WORLD_SIZE,) 50 | else: 51 | main_worker(0, config) 52 | 53 | 54 | def main_worker(local_rank, config): 55 | if local_rank == 0: 56 | config.defrost() 57 | config.EXPERIMENT_ID = f"{config.WANDB.TAG}_{datetime.now().strftime('%y%m%d_%H%M%S')}" 58 | config.freeze() 59 | wandb.init(project=config.WANDB.COARSE_PROJECT, 60 | name=config.EXPERIMENT_ID, config=config, mode=config.WANDB.MODE) 61 | np.set_printoptions(formatter={'float': '{: 0.4f}'.format}, suppress=True) 62 | torch.cuda.set_device(local_rank) 63 | if config.DIS: 64 | dist.init_process_group( 65 | "nccl", init_method='env://', rank=local_rank, world_size=config.WORLD_SIZE) 66 | seed = config.SEED + local_rank 67 | seed_torch(seed) 68 | cudnn.benchmark = True 69 | 70 | train_loader, val_loader = build_loader(config, 71 | config.DATASET.COARSE.SIZE, 72 | config.DATASET.COARSE.PROPRECESS_PATH, 73 | config.DATASET.COARSE.PROPRECESS_UL_PATH, 74 | config.MODEL.COARSE.POOL_OP_KERNEL_SIZES, 75 | config.DATASET.COARSE.NUM_EACH_EPOCH 76 | ) 77 | model = build_coarse_model(config).cuda() 78 | if config.DIS: 79 | model = torch.nn.parallel.DistributedDataParallel( 80 | model, device_ids=[local_rank], find_unused_parameters=True) 81 | logger.info(f'\n{model}\n') 82 | loss = build_loss(config.MODEL.DEEP_SUPERVISION, 83 | config.MODEL.COARSE.POOL_OP_KERNEL_SIZES) 84 | optimizer = build_optimizer(config, model) 85 | lr_scheduler = build_scheduler(config, optimizer, len(train_loader)) 86 | trainer = Trainer(config=config, 87 | train_loader=train_loader, 88 | val_loader=val_loader, 89 | model=model, 90 | loss=loss, 91 | optimizer=optimizer, 92 | lr_scheduler=lr_scheduler) 93 | trainer.train() 94 | 95 | 96 | if __name__ == '__main__': 97 | os.environ["MASTER_ADDR"] = "localhost" 98 | os.environ["MASTER_PORT"] = "10000" 99 | _, config = parse_option() 100 | 101 | main(config) 102 | -------------------------------------------------------------------------------- /fine_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from loguru import logger 3 | from data import build_loader 4 | from trainer import Trainer 5 | from utils import seed_torch 6 | from losses import build_loss 7 | from datetime import datetime 8 | import wandb 9 | from config import get_config 10 | from models import build_fine_model 11 | from lr_scheduler import build_scheduler 12 | from optimizer import build_optimizer 13 | import os 14 | import torch.backends.cudnn as cudnn 15 | import numpy as np 16 | import torch 17 | import torch.multiprocessing as mp 18 | import torch.distributed as dist 19 | from batchgenerators.utilities.file_and_folder_operations import * 20 | 21 | 22 | def parse_option(): 23 | parser = argparse.ArgumentParser("FLARE2022_fine_training") 24 | parser.add_argument('--cfg', type=str, metavar="FILE", 25 | help='path to config file') 26 | parser.add_argument( 27 | "--opts", 28 | help="Modify config options by adding 'KEY VALUE' pairs. ", 29 | default=None, 30 | nargs='+', 31 | ) 32 | parser.add_argument("--tag", help='tag of experiment') 33 | parser.add_argument("-wm", "--wandb_mode", default="offline") 34 | parser.add_argument('-bs', '--batch-size', type=int, 35 | help="batch size for single GPU") 36 | parser.add_argument('-wd', '--with_distributed', help="training without DDP", 37 | required=False, default=False, action="store_true") 38 | parser.add_argument('-ws', '--world_size', type=int, 39 | help="process number for DDP") 40 | args = parser.parse_args() 41 | config = get_config(args) 42 | 43 | return args, config 44 | 45 | 46 | def main(config): 47 | if config.DIS: 48 | mp.spawn(main_worker, 49 | args=(config,), 50 | nprocs=config.WORLD_SIZE,) 51 | else: 52 | main_worker(0, config) 53 | 54 | 55 | def main_worker(local_rank, config): 56 | if local_rank == 0: 57 | config.defrost() 58 | config.EXPERIMENT_ID = f"{config.WANDB.TAG}_{datetime.now().strftime('%y%m%d_%H%M%S')}" 59 | config.freeze() 60 | wandb.init(project=config.WANDB.FINE_PROJECT, 61 | name=config.EXPERIMENT_ID, config=config, mode=config.WANDB.MODE) 62 | np.set_printoptions(formatter={'float': '{: 0.4f}'.format}, suppress=True) 63 | torch.cuda.set_device(local_rank) 64 | if config.DIS: 65 | dist.init_process_group( 66 | "nccl", init_method='env://', rank=local_rank, world_size=config.WORLD_SIZE) 67 | seed = config.SEED + local_rank 68 | seed_torch(seed) 69 | cudnn.benchmark = True 70 | 71 | train_loader, val_loader = build_loader(config, 72 | config.DATASET.FINE.SIZE, 73 | config.DATASET.FINE.PROPRECESS_PATH, 74 | config.DATASET.FINE.PROPRECESS_UL_PATH, 75 | config.MODEL.FINE.POOL_OP_KERNEL_SIZES, 76 | config.DATASET.FINE.NUM_EACH_EPOCH 77 | ) 78 | model = build_fine_model(config).cuda() 79 | if config.DIS: 80 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda() 81 | model = torch.nn.parallel.DistributedDataParallel( 82 | model, device_ids=[local_rank], find_unused_parameters=True) 83 | logger.info(f'\n{model}\n') 84 | loss = build_loss(config.MODEL.DEEP_SUPERVISION, 85 | config.MODEL.FINE.POOL_OP_KERNEL_SIZES) 86 | optimizer = build_optimizer(config, model) 87 | lr_scheduler = build_scheduler(config, optimizer, len(train_loader)) 88 | trainer = Trainer(config=config, 89 | train_loader=train_loader, 90 | val_loader=val_loader, 91 | model=model, 92 | loss=loss, 93 | optimizer=optimizer, 94 | lr_scheduler=lr_scheduler) 95 | trainer.train() 96 | 97 | 98 | if __name__ == '__main__': 99 | os.environ["MASTER_ADDR"] = "localhost" 100 | os.environ["MASTER_PORT"] = "10000" 101 | _, config = parse_option() 102 | 103 | main(config) 104 | -------------------------------------------------------------------------------- /data/data_augmentation.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import numpy as np 3 | from batchgenerators.augmentations.utils import resize_segmentation 4 | from batchgenerators.transforms.abstract_transforms import AbstractTransform 5 | 6 | default_3D_augmentation_params = { 7 | 8 | # "do_elastic": False, 9 | "elastic_deform_alpha": (0., 900.), 10 | "elastic_deform_sigma": (9., 13.), 11 | "p_eldef": 0.2, 12 | 13 | # "do_scaling": True, 14 | "scale_range": (0.85, 1.25), 15 | "independent_scale_factor_for_each_axis": False, 16 | "p_independent_scale_per_axis": 1, 17 | "p_scale": 0.2, 18 | 19 | # "do_rotation": True, 20 | "rotation_x": (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), 21 | "rotation_y": (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), 22 | "rotation_z": (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), 23 | "rotation_p_per_axis": 1, 24 | "p_rot": 0.2, 25 | 26 | # "random_crop": False, 27 | "random_crop_dist_to_border": None, 28 | 29 | # "do_gamma": True, 30 | "gamma_retain_stats": True, 31 | "gamma_range": (0.7, 1.5), 32 | "p_gamma": 0.3, 33 | 34 | # "do_mirror": True, 35 | "mirror_axes": (0, 1, 2), 36 | 37 | "border_mode_data": "constant", 38 | 39 | # "do_additive_brightness": False, 40 | "additive_brightness_p_per_sample": 0.15, 41 | "additive_brightness_p_per_channel": 0.5, 42 | "additive_brightness_mu": 0.0, 43 | "additive_brightness_sigma": 0.1 44 | } 45 | 46 | default_2D_augmentation_params = deepcopy(default_3D_augmentation_params) 47 | 48 | default_2D_augmentation_params["elastic_deform_alpha"] = (0., 200.) 49 | default_2D_augmentation_params["elastic_deform_sigma"] = (9., 13.) 50 | default_2D_augmentation_params["rotation_x"] = ( 51 | -180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi) 52 | default_2D_augmentation_params["rotation_y"] = ( 53 | -0. / 360 * 2. * np.pi, 0. / 360 * 2. * np.pi) 54 | default_2D_augmentation_params["rotation_z"] = ( 55 | -0. / 360 * 2. * np.pi, 0. / 360 * 2. * np.pi) 56 | 57 | 58 | def get_patch_size(final_patch_size, rot_x, rot_y, rot_z, scale_range): 59 | if isinstance(rot_x, (tuple, list)): 60 | rot_x = max(np.abs(rot_x)) 61 | if isinstance(rot_y, (tuple, list)): 62 | rot_y = max(np.abs(rot_y)) 63 | if isinstance(rot_z, (tuple, list)): 64 | rot_z = max(np.abs(rot_z)) 65 | rot_x = min(90 / 360 * 2. * np.pi, rot_x) 66 | rot_y = min(90 / 360 * 2. * np.pi, rot_y) 67 | rot_z = min(90 / 360 * 2. * np.pi, rot_z) 68 | from batchgenerators.augmentations.utils import rotate_coords_3d, rotate_coords_2d 69 | coords = np.array(final_patch_size) 70 | final_shape = np.copy(coords) 71 | if len(coords) == 3: 72 | final_shape = np.max( 73 | np.vstack((np.abs(rotate_coords_3d(coords, rot_x, 0, 0)), final_shape)), 0) 74 | final_shape = np.max( 75 | np.vstack((np.abs(rotate_coords_3d(coords, 0, rot_y, 0)), final_shape)), 0) 76 | final_shape = np.max( 77 | np.vstack((np.abs(rotate_coords_3d(coords, 0, 0, rot_z)), final_shape)), 0) 78 | elif len(coords) == 2: 79 | final_shape = np.max( 80 | np.vstack((np.abs(rotate_coords_2d(coords, rot_x)), final_shape)), 0) 81 | final_shape /= min(scale_range) 82 | return final_shape.astype(int) 83 | 84 | 85 | class DownsampleSegForDSTransform(AbstractTransform): 86 | 87 | def __init__(self, ds_scales=(1, 0.5, 0.25), order=0, input_key="seg", output_key="seg", axes=None): 88 | self.axes = axes 89 | self.output_key = output_key 90 | self.input_key = input_key 91 | self.order = order 92 | self.ds_scales = ds_scales 93 | 94 | def __call__(self, **data_dict): 95 | data_dict[self.output_key] = downsample_seg_for_ds_transform(data_dict[self.input_key], self.ds_scales, 96 | self.order, self.axes) 97 | return data_dict 98 | 99 | 100 | def downsample_seg_for_ds_transform(seg, ds_scales=((1, 1, 1), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25)), order=0, axes=None): 101 | if axes is None: 102 | axes = list(range(1, len(seg.shape))) 103 | output = [] 104 | for s in ds_scales: 105 | if all([i == 1 for i in s]): 106 | output.append(seg) 107 | else: 108 | new_shape = np.array(seg.shape).astype(float) 109 | for i, a in enumerate(axes): 110 | new_shape[a] *= s[i] 111 | new_shape = np.round(new_shape).astype(int) 112 | out_seg = np.zeros(new_shape, dtype=seg.dtype) 113 | for c in range(seg.shape[0]): 114 | out_seg[c] = resize_segmentation(seg[c], new_shape[1:], order) 115 | output.append(out_seg) 116 | return output 117 | -------------------------------------------------------------------------------- /data_preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from batchgenerators.utilities.file_and_folder_operations import * 4 | import shutil 5 | import traceback 6 | from multiprocessing import Pool, cpu_count 7 | from config import get_config_no_args 8 | from data.utils import load_data, clip_and_normalize_mean_std, resize_segmentation, change_axes_of_image, crop_image_according_to_mask, create_two_class_mask 9 | from collections import OrderedDict 10 | from skimage.transform import resize 11 | 12 | def run_prepare_data(config, is_overwrite, is_multiprocessing=True): 13 | 14 | data_prepare = data_process(config, is_overwrite) 15 | if is_multiprocessing: 16 | pool = Pool(int(cpu_count() * 0.2)) 17 | for data in data_prepare.data_list: 18 | try: 19 | pool.apply_async(data_prepare.process, (data,)) 20 | except Exception as err: 21 | traceback.print_exc() 22 | print('Create image/label throws exception %s, with series_id %s!' % 23 | (err, data_prepare.data_info)) 24 | 25 | pool.close() 26 | pool.join() 27 | else: 28 | for data in data_prepare.data_list: 29 | data_prepare.process(data) 30 | 31 | 32 | class data_process(object): 33 | def __init__(self, config, is_overwrite=True): 34 | self.config = config 35 | self.coarse_size = self.config.DATASET.COARSE.SIZE 36 | self.fine_size = self.config.DATASET.FINE.SIZE 37 | self.nor_dir = self.config.DATASET.IS_NORMALIZATION_DIRECTION 38 | self.extend_size = self.config.DATASET.EXTEND_SIZE 39 | 40 | self.image_path = config.DATASET.TRAIN_IMAGE_PATH 41 | self.mask_path = config.DATASET.TRAIN_MASK_PATH 42 | self.preprocess_coarse_path = config.DATASET.COARSE.PROPRECESS_PATH 43 | self.preprocess_fine_path = config.DATASET.FINE.PROPRECESS_PATH 44 | self.data_list = subfiles(self.image_path, join=False, suffix='nii.gz') 45 | if is_overwrite and isdir(self.preprocess_coarse_path): 46 | shutil.rmtree(self.preprocess_coarse_path) 47 | os.makedirs(self.preprocess_coarse_path, exist_ok=True) 48 | if is_overwrite and isdir(self.preprocess_fine_path): 49 | shutil.rmtree(self.preprocess_fine_path) 50 | os.makedirs(self.preprocess_fine_path, exist_ok=True) 51 | 52 | def process(self, image_id): 53 | data_id = image_id.split("_0000.nii.gz")[0] 54 | image, image_spacing, image_direction = load_data( 55 | join(self.image_path, data_id + "_0000.nii.gz")) 56 | mask, _, mask_direction = load_data( 57 | join(self.mask_path, data_id + ".nii.gz")) 58 | assert image_direction.all() == mask_direction.all() 59 | if self.nor_dir: 60 | image = change_axes_of_image(image, image_direction) 61 | mask = change_axes_of_image(mask, mask_direction) 62 | data_info = OrderedDict() 63 | data_info["raw_shape"] = image.shape 64 | data_info["raw_spacing"] = image_spacing 65 | resize_spacing = image_spacing*image.shape/self.coarse_size 66 | data_info["resize_spacing"] = resize_spacing 67 | data_info["image_direction"] = image_direction 68 | with open(os.path.join(self.preprocess_coarse_path, "%s_info.pkl" % data_id), 'wb') as f: 69 | pickle.dump(data_info, f) 70 | print(data_id, image.shape) 71 | 72 | image_resize = resize(image, self.coarse_size, 73 | order=3, mode='edge', anti_aliasing=False) 74 | 75 | mask_resize = resize_segmentation(mask, self.coarse_size, order=0) 76 | mask_binary = create_two_class_mask(mask_resize) 77 | 78 | image_normal = clip_and_normalize_mean_std(image_resize) 79 | 80 | np.savez_compressed(os.path.join(self.preprocess_coarse_path, "%s.npz" % 81 | data_id), data=image_normal[None, ...], seg=mask_binary[None, ...]) 82 | margin = [int(self.extend_size / image_spacing[0]), 83 | int(self.extend_size / image_spacing[1]), 84 | int(self.extend_size / image_spacing[2])] 85 | crop_image, crop_mask = crop_image_according_to_mask( 86 | image, np.array(mask, dtype=int), margin) 87 | 88 | data_info_crop = OrderedDict() 89 | data_info_crop["raw_shape"] = image.shape 90 | data_info_crop["crop_shape"] = crop_image.shape 91 | data_info_crop["raw_spacing"] = image_spacing 92 | resize_crop_spacing = image_spacing*crop_image.shape/self.fine_size 93 | data_info_crop["resize_crop_spacing"] = resize_crop_spacing 94 | data_info_crop["image_direction"] = image_direction 95 | with open(os.path.join(self.preprocess_fine_path, "%s_info.pkl" % data_id), 'wb') as f: 96 | pickle.dump(data_info_crop, f) 97 | 98 | crop_image_resize = resize(crop_image, self.fine_size, 99 | order=3, mode='edge', anti_aliasing=False) 100 | crop_mask_resize = resize_segmentation( 101 | crop_mask, self.fine_size, order=0) 102 | crop_image_normal = clip_and_normalize_mean_std(crop_image_resize) 103 | np.savez_compressed(os.path.join(self.preprocess_fine_path, "%s.npz" % data_id), 104 | data=crop_image_normal[None, ...], seg=crop_mask_resize[None, ...]) 105 | print('End processing %s.' % data_id) 106 | 107 | 108 | 109 | if __name__ == '__main__': 110 | config = get_config_no_args() 111 | 112 | run_prepare_data(config, True, True) 113 | -------------------------------------------------------------------------------- /unlabel_data_preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from batchgenerators.utilities.file_and_folder_operations import * 4 | import shutil 5 | import traceback 6 | from multiprocessing import Pool, cpu_count 7 | from config import get_config_no_args 8 | from data.utils import crop_image_according_to_mask, load_data, clip_and_normalize_mean_std, resize_segmentation, change_axes_of_image, create_two_class_mask 9 | from collections import OrderedDict 10 | from skimage.transform import resize 11 | 12 | 13 | def run_prepare_data(config, is_overwrite, is_multiprocessing=True): 14 | 15 | data_prepare = data_process(config, is_overwrite) 16 | if is_multiprocessing: 17 | pool = Pool(int(cpu_count() * 0.2)) 18 | for data in data_prepare.data_list: 19 | try: 20 | pool.apply_async(data_prepare.process, (data,)) 21 | except Exception as err: 22 | traceback.print_exc() 23 | print('Create image/label throws exception %s, with series_id %s!' % 24 | (err, data_prepare.data_info)) 25 | 26 | pool.close() 27 | pool.join() 28 | else: 29 | for data in data_prepare.data_list: 30 | data_prepare.process(data) 31 | 32 | class data_process(object): 33 | def __init__(self, config, is_overwrite=False): 34 | self.config = config 35 | self.coarse_size = self.config.DATASET.COARSE.SIZE 36 | self.fine_size = self.config.DATASET.FINE.SIZE 37 | self.nor_dir = self.config.DATASET.IS_NORMALIZATION_DIRECTION 38 | self.extend_size = self.config.DATASET.EXTEND_SIZE 39 | 40 | self.image_path = config.DATASET.TRAIN_UNLABELED_IMAGE_PATH 41 | self.mask_path = config.DATASET.TRAIN_UNLABELED_MASK_PATH 42 | self.preprocess_coarse_path = config.DATASET.COARSE.PROPRECESS_UL_PATH 43 | self.preprocess_fine_path = config.DATASET.FINE.PROPRECESS_UL_PATH 44 | self.data_list = subfiles(self.image_path, join=False, suffix='nii.gz') 45 | if is_overwrite and isdir(self.preprocess_coarse_path): 46 | shutil.rmtree(self.preprocess_coarse_path) 47 | os.makedirs(self.preprocess_coarse_path, exist_ok=True) 48 | if is_overwrite and isdir(self.preprocess_fine_path): 49 | shutil.rmtree(self.preprocess_fine_path) 50 | os.makedirs(self.preprocess_fine_path, exist_ok=True) 51 | 52 | def process(self, image_id): 53 | 54 | data_id = image_id.split("_0000.nii.gz")[0] 55 | 56 | image, image_spacing, image_direction = load_data( 57 | join(self.image_path, data_id + "_0000.nii.gz")) 58 | mask, _, mask_direction = load_data( 59 | join(self.mask_path, data_id + ".nii.gz")) 60 | assert image_direction.all() == mask_direction.all() 61 | print(data_id, image.shape) 62 | if self.nor_dir: 63 | image = change_axes_of_image(image, image_direction) 64 | mask = change_axes_of_image(mask, mask_direction) 65 | data_info = OrderedDict() 66 | 67 | data_info["raw_shape"] = image.shape 68 | data_info["raw_spacing"] = image_spacing 69 | resize_spacing = image_spacing*image.shape/self.coarse_size 70 | data_info["resize_spacing"] = resize_spacing 71 | data_info["image_direction"] = image_direction 72 | with open(os.path.join(self.preprocess_coarse_path, "%s_info.pkl" % data_id), 'wb') as f: 73 | pickle.dump(data_info, f) 74 | 75 | image_resize = resize(image, self.coarse_size, 76 | order=3, mode='edge', anti_aliasing=False) 77 | mask_resize = resize_segmentation( 78 | mask, self.coarse_size, order=0) 79 | mask_binary = create_two_class_mask(mask_resize) 80 | image_normal = clip_and_normalize_mean_std(image_resize) 81 | 82 | np.savez_compressed(os.path.join(self.preprocess_coarse_path, "%s.npz" % 83 | data_id), data=image_normal[None, ...], seg=mask_binary[None, ...]) 84 | 85 | 86 | margin = [int(self.extend_size / image_spacing[0]), 87 | int(self.extend_size / image_spacing[1]), 88 | int(self.extend_size / image_spacing[2])] 89 | crop_image, crop_mask = crop_image_according_to_mask( 90 | image, np.array(mask, dtype=int), margin) 91 | data_info_crop = OrderedDict() 92 | data_info_crop["raw_shape"] = image.shape 93 | data_info_crop["crop_shape"] = crop_image.shape 94 | data_info_crop["raw_spacing"] = image_spacing 95 | resize_crop_spacing = image_spacing*crop_image.shape/self.fine_size 96 | data_info_crop["resize_crop_spacing"] = resize_crop_spacing 97 | data_info_crop["image_direction"] = image_direction 98 | with open(os.path.join(self.preprocess_fine_path, "%s_info.pkl" % data_id), 'wb') as f: 99 | pickle.dump(data_info_crop, f) 100 | 101 | crop_image_resize = resize( 102 | crop_image, self.fine_size, order=3, mode='edge', anti_aliasing=False) 103 | crop_mask_resize = resize_segmentation( 104 | crop_mask, self.fine_size, order=0) 105 | crop_image_normal = clip_and_normalize_mean_std(crop_image_resize) 106 | np.savez_compressed(os.path.join(self.preprocess_fine_path, "%s.npz" % data_id), 107 | data=crop_image_normal[None, ...], seg=crop_mask_resize[None, ...]) 108 | 109 | print('End processing %s.' % data_id) 110 | 111 | if __name__ == '__main__': 112 | config = get_config_no_args() 113 | run_prepare_data(config, False, False) 114 | 115 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from batchgenerators.utilities.file_and_folder_operations import * 3 | from skimage.transform import resize 4 | import cv2 5 | import shutil 6 | import torch 7 | import SimpleITK as sitk 8 | import cc3d 9 | import fastremap 10 | import torch.nn.functional as F 11 | from scipy.ndimage import binary_fill_holes 12 | 13 | def load_pickle(file: str, mode: str = 'rb'): 14 | with open(file, mode) as f: 15 | a = pickle.load(f) 16 | return a 17 | 18 | 19 | def load_data(data_path): 20 | 21 | data_itk = sitk.ReadImage(data_path) 22 | data_npy = sitk.GetArrayFromImage(data_itk)[None].astype(np.float32) 23 | data_spacing = np.array(data_itk.GetSpacing())[[2, 1, 0]] 24 | direction = data_itk.GetDirection() 25 | direction = np.array((direction[8], direction[4], direction[0])) 26 | return data_npy[0], data_spacing, direction 27 | 28 | 29 | def change_axes_of_image(npy_image, orientation): 30 | if orientation[0] < 0: 31 | npy_image = np.flip(npy_image, axis=0) 32 | if orientation[1] > 0: 33 | npy_image = np.flip(npy_image, axis=1) 34 | if orientation[2] > 0: 35 | npy_image = np.flip(npy_image, axis=2) 36 | return npy_image 37 | 38 | 39 | def clip_and_normalize_mean_std(image): 40 | mean = np.mean(image) 41 | std = np.std(image) 42 | 43 | image = (image - mean) / (std + 1e-5) 44 | return image 45 | 46 | 47 | def resize_segmentation(segmentation, new_shape, order=3): 48 | tpe = segmentation.dtype 49 | unique_labels = np.unique(segmentation) 50 | assert len(segmentation.shape) == len( 51 | new_shape), "new shape must have same dimensionality as segmentation" 52 | if order == 0: 53 | return resize(segmentation.astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False).astype(tpe) 54 | else: 55 | reshaped = np.zeros(new_shape, dtype=segmentation.dtype) 56 | 57 | for i, c in enumerate(unique_labels): 58 | mask = segmentation == c 59 | reshaped_multihot = resize(mask.astype( 60 | float), new_shape, order, mode="edge", clip=True, anti_aliasing=False) 61 | reshaped[reshaped_multihot >= 0.5] = c 62 | return reshaped 63 | 64 | def maybe_to_torch(d): 65 | if isinstance(d, list): 66 | d = [maybe_to_torch(i) if not isinstance( 67 | i, torch.Tensor) else i for i in d] 68 | elif not isinstance(d, torch.Tensor): 69 | d = torch.from_numpy(d).float() 70 | return d 71 | 72 | 73 | def create_two_class_mask(mask): 74 | 75 | mask = np.clip(mask, 0, 1) 76 | mask = binary_fill_holes(mask, origin=1,) 77 | return mask 78 | 79 | 80 | def extract_topk_largest_candidates(npy_mask: np.array, label_unique, out_num_label: List) -> np.array: 81 | mask_shape = npy_mask.shape 82 | out_mask = np.zeros( 83 | [mask_shape[1], mask_shape[2], mask_shape[3]], np.uint8) 84 | for i in range(1, mask_shape[0]): 85 | t_mask = npy_mask[i].copy() 86 | keep_topk_largest_connected_object( 87 | t_mask, out_num_label, out_mask, label_unique[i]) 88 | 89 | return out_mask 90 | 91 | 92 | def keep_topk_largest_connected_object(npy_mask, k, out_mask, out_label): 93 | labels_out = cc3d.connected_components(npy_mask, connectivity=26) 94 | areas = {} 95 | for label, extracted in cc3d.each(labels_out, binary=True, in_place=True): 96 | areas[label] = fastremap.foreground(extracted) 97 | candidates = sorted(areas.items(), key=lambda item: item[1], reverse=True) 98 | 99 | for i in range(min(k, len(candidates))): 100 | out_mask[labels_out == int(candidates[i][0])] = out_label 101 | 102 | 103 | def to_one_hot(seg, all_seg_labels=None): 104 | if all_seg_labels is None: 105 | all_seg_labels = np.unique(seg) 106 | result = np.zeros((len(all_seg_labels), *seg.shape), dtype=seg.dtype) 107 | for i, l in enumerate(all_seg_labels): 108 | result[i][seg == l] = 1 109 | return result 110 | 111 | 112 | def input_downsample(x, input_size): 113 | x = F.interpolate(x, size=input_size, mode='trilinear',align_corners=False) 114 | mean = torch.mean(x) 115 | std = torch.std(x) 116 | x = (x - mean) / (1e-5 + std) 117 | return x 118 | 119 | 120 | def output_upsample(x, output_size): 121 | x = F.interpolate(x, size=output_size, 122 | mode='trilinear', align_corners=False) 123 | return x 124 | 125 | 126 | def get_bbox_from_mask(mask, outside_value=0): 127 | mask_voxel_coords = np.where(mask != outside_value) 128 | minzidx = int(np.min(mask_voxel_coords[0])) 129 | maxzidx = int(np.max(mask_voxel_coords[0])) + 1 130 | minxidx = int(np.min(mask_voxel_coords[1])) 131 | maxxidx = int(np.max(mask_voxel_coords[1])) + 1 132 | minyidx = int(np.min(mask_voxel_coords[2])) 133 | maxyidx = int(np.max(mask_voxel_coords[2])) + 1 134 | return np.array([[minzidx, maxzidx], [minxidx, maxxidx], [minyidx, maxyidx]]) 135 | 136 | 137 | def crop_image_according_to_mask(npy_image, npy_mask, margin=None): 138 | if margin is None: 139 | margin = [20, 20, 20] 140 | 141 | bbox = get_bbox_from_mask(npy_mask) 142 | 143 | extend_bbox = np.concatenate( 144 | [np.max([[0, 0, 0], bbox[:, 0] - margin], axis=0)[:, np.newaxis], 145 | np.min([npy_image.shape, bbox[:, 1] + margin], axis=0)[:, np.newaxis]], axis=1) 146 | 147 | 148 | crop_mask = crop_to_bbox(npy_mask,extend_bbox) 149 | crop_image = crop_to_bbox(npy_image,extend_bbox) 150 | 151 | 152 | return crop_image, crop_mask 153 | 154 | 155 | def crop_to_bbox(image, bbox): 156 | assert len(image.shape) == 3, "only supports 3d images" 157 | resizer = (slice(bbox[0][0], bbox[0][1]), slice( 158 | bbox[1][0], bbox[1][1]), slice(bbox[2][0], bbox[2][1])) 159 | return image[resizer] 160 | 161 | 162 | def crop_image_according_to_bbox(npy_image, bbox, margin=None): 163 | if margin is None: 164 | margin = [20, 20, 20] 165 | 166 | image_shape = npy_image.shape 167 | extend_bbox = [[max(0, int(bbox[0][0]-margin[0])), 168 | min(image_shape[0], int(bbox[0][1]+margin[0]))], 169 | [max(0, int(bbox[1][0]-margin[1])), 170 | min(image_shape[1], int(bbox[1][1]+margin[1]))], 171 | [max(0, int(bbox[2][0]-margin[2])), 172 | min(image_shape[2], int(bbox[2][1]+margin[2]))]] 173 | 174 | 175 | crop_image = crop_to_bbox(npy_image, extend_bbox) 176 | 177 | 178 | return crop_image, extend_bbox 179 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | from loguru import logger 5 | from tqdm import tqdm 6 | from utils import to_cuda 7 | from metrics import AverageMeter,run_online_evaluation 8 | import torch.distributed as dist 9 | import wandb 10 | 11 | 12 | class Trainer: 13 | def __init__(self, config, train_loader,val_loader, model,loss,optimizer,lr_scheduler): 14 | self.config = config 15 | 16 | self.scaler = torch.cuda.amp.GradScaler(enabled=True) 17 | self.loss = loss 18 | self.model = model 19 | self.train_loader = train_loader 20 | self.val_loader = val_loader 21 | self.optimizer = optimizer 22 | self.lr_scheduler = lr_scheduler 23 | self.num_steps = len(self.train_loader) 24 | if self._get_rank()==0: 25 | self.checkpoint_dir = os.path.join(config.SAVE_DIR,config.EXPERIMENT_ID) 26 | 27 | os.makedirs(self.checkpoint_dir) 28 | def train(self): 29 | 30 | for epoch in range(1, self.config.TRAIN.EPOCHS+1): 31 | if self.config.DIS: 32 | self.train_loader.sampler.set_epoch(epoch) 33 | self._train_epoch(epoch) 34 | if self.val_loader is not None and epoch % self.config.TRAIN.VAL_NUM_EPOCHS == 0: 35 | results = self._valid_epoch(epoch) 36 | if self._get_rank()==0 : 37 | logger.info(f'## Info for epoch {epoch} ## ') 38 | for k, v in results.items(): 39 | logger.info(f'{str(k):15s}: {v}') 40 | if epoch % self.config.TRAIN.VAL_NUM_EPOCHS == 0 and self._get_rank()==0: 41 | self._save_checkpoint(epoch) 42 | 43 | 44 | def _train_epoch(self, epoch): 45 | self.batch_time = AverageMeter() 46 | self.data_time = AverageMeter() 47 | self.total_loss = AverageMeter() 48 | self.DICE = AverageMeter() 49 | 50 | self.model.train() 51 | 52 | 53 | tbar = tqdm(self.train_loader, ncols=150) 54 | tic = time.time() 55 | for idx, (data,_) in enumerate(tbar): 56 | self.data_time.update(time.time() - tic) 57 | img = to_cuda(data["data"]) 58 | gt = to_cuda(data["seg"]) 59 | self.optimizer.zero_grad() 60 | 61 | with torch.cuda.amp.autocast(enabled=self.config.AMP): 62 | pre = self.model(img) 63 | loss = self.loss(pre, gt) 64 | if self.config.AMP: 65 | self.scaler.scale(loss).backward() 66 | if self.config.TRAIN.DO_BACKPROP: 67 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 12) 68 | self.scaler.step(self.optimizer) 69 | self.scaler.update() 70 | else: 71 | loss.backward() 72 | if self.config.TRAIN.DO_BACKPROP: 73 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 12) 74 | self.optimizer.step() 75 | 76 | self.total_loss.update(loss.item()) 77 | self.batch_time.update(time.time() - tic) 78 | self.DICE.update(run_online_evaluation(pre, gt)) 79 | 80 | tbar.set_description( 81 | 'TRAIN ({}) | Loss: {} | DICE {} |B {} D {} |'.format( 82 | epoch, self.total_loss.average, self.DICE.average, self.batch_time.average, self.data_time.average)) 83 | tic = time.time() 84 | 85 | self.lr_scheduler.step_update(epoch * self.num_steps + idx) 86 | if self._get_rank()==0: 87 | wandb.log({'train/loss': self.total_loss.average, 88 | 'train/dice': self.DICE.average, 89 | 'train/lr': self.optimizer.param_groups[0]['lr']}, 90 | step=epoch) 91 | def _valid_epoch(self, epoch): 92 | logger.info('\n###### EVALUATION ######') 93 | self.batch_time = AverageMeter() 94 | self.data_time = AverageMeter() 95 | self.total_loss = AverageMeter() 96 | self.DICE = AverageMeter() 97 | 98 | self.model.eval() 99 | 100 | tbar = tqdm(self.val_loader, ncols=150) 101 | tic = time.time() 102 | with torch.no_grad(): 103 | 104 | for idx, (data, _) in enumerate(tbar): 105 | self.data_time.update(time.time() - tic) 106 | img = to_cuda(data["data"]) 107 | gt = to_cuda(data["seg"]) 108 | 109 | with torch.cuda.amp.autocast(enabled=self.config.AMP): 110 | 111 | pre = self.model(img) 112 | loss = self.loss(pre, gt) 113 | 114 | self.total_loss.update(loss.item()) 115 | self.batch_time.update(time.time() - tic) 116 | 117 | self.DICE.update(run_online_evaluation(pre, gt)) 118 | tbar.set_description( 119 | 'TEST ({}) | Loss: {} | DICE {} |B {} D {} |'.format( 120 | epoch, self.total_loss.average, self.DICE.average, self.batch_time.average, self.data_time.average)) 121 | tic = time.time() 122 | if self._get_rank()==0: 123 | wandb.log({'val/loss': self.total_loss.average, 124 | 'val/dice': self.DICE.average, 125 | 'val/batch_time': self.batch_time.average, 126 | 'val/data_time': self.data_time.average 127 | }, 128 | step=epoch) 129 | log = {'val_loss': self.total_loss.average, 130 | 'val_dice': self.DICE.average 131 | } 132 | return log 133 | def _get_rank(self): 134 | """get gpu id in distribution training.""" 135 | if not dist.is_available(): 136 | return 0 137 | if not dist.is_initialized(): 138 | return 0 139 | return dist.get_rank() 140 | 141 | def _save_checkpoint(self, epoch): 142 | state = { 143 | 'arch': type(self.model).__name__, 144 | 'epoch': epoch, 145 | 'state_dict': self.model.state_dict(), 146 | 'optimizer': self.optimizer.state_dict(), 147 | 'config': self.config 148 | } 149 | filename = os.path.join(self.checkpoint_dir, 150 | 'final_checkpoint.pth') 151 | logger.info(f'Saving a checkpoint: {filename} ...') 152 | torch.save(state, filename) 153 | return filename 154 | 155 | def _reset_metrics(self): 156 | self.batch_time = AverageMeter() 157 | self.data_time = AverageMeter() 158 | self.total_loss = AverageMeter() 159 | self.DICE = AverageMeter() 160 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | from torch import nn, Tensor 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | def softmax_helper(x): return F.softmax(x, 1) 6 | 7 | 8 | def build_loss(deep_supervision, pool_op_kernel_sizes): 9 | if deep_supervision: 10 | weight = get_weight_factors(len(pool_op_kernel_sizes)) 11 | loss = MultipleOutputLoss2(DC_and_CE_loss(), weight) 12 | else: 13 | loss = DC_and_CE_loss() 14 | return loss 15 | 16 | 17 | def get_weight_factors(net_numpool): 18 | weights = np.array([1 / (2 ** i) for i in range(net_numpool)]) 19 | mask = np.array([True] + [True if i < net_numpool - 20 | 1 else False for i in range(1, net_numpool)]) 21 | weights[~mask] = 0 22 | weights = weights / weights.sum() 23 | return weights 24 | 25 | 26 | class MultipleOutputLoss2(nn.Module): 27 | def __init__(self, loss, weight_factors=None): 28 | super(MultipleOutputLoss2, self).__init__() 29 | self.weight_factors = weight_factors 30 | self.loss = loss 31 | 32 | def forward(self, x, y): 33 | assert isinstance(x, (tuple, list)), "x must be either tuple or list" 34 | assert isinstance(y, (tuple, list)), "y must be either tuple or list" 35 | if self.weight_factors is None: 36 | weights = [1] * len(x) 37 | else: 38 | weights = self.weight_factors 39 | 40 | l = weights[0] * self.loss(x[0], y[0]) 41 | for i in range(1, len(x)): 42 | if weights[i] != 0: 43 | l += weights[i] * self.loss(x[i], y[i]) 44 | return l 45 | 46 | 47 | class RobustCrossEntropyLoss(nn.CrossEntropyLoss): 48 | 49 | def forward(self, input: Tensor, target: Tensor) -> Tensor: 50 | if len(target.shape) == len(input.shape): 51 | assert target.shape[1] == 1 52 | target = target[:, 0] 53 | return super().forward(input, target.long()) 54 | 55 | 56 | class DC_and_CE_loss(nn.Module): 57 | def __init__(self, aggregate="sum", weight_ce=1, weight_dice=1, 58 | log_dice=False, ignore_label=None): 59 | super(DC_and_CE_loss, self).__init__() 60 | 61 | self.log_dice = log_dice 62 | self.weight_dice = weight_dice 63 | self.weight_ce = weight_ce 64 | self.aggregate = aggregate 65 | self.ce = RobustCrossEntropyLoss() 66 | 67 | self.ignore_label = ignore_label 68 | self.dc = SoftDiceLoss(apply_nonlin=softmax_helper) 69 | 70 | def forward(self, net_output, target): 71 | if self.ignore_label is not None: 72 | assert target.shape[1] == 1, 'not implemented for one hot encoding' 73 | mask = target != self.ignore_label 74 | target[~mask] = 0 75 | mask = mask.float() 76 | else: 77 | mask = None 78 | 79 | dc_loss = self.dc(net_output, target, 80 | loss_mask=mask) if self.weight_dice != 0 else 0 81 | if self.log_dice: 82 | dc_loss = -torch.log(-dc_loss) 83 | 84 | ce_loss = self.ce( 85 | net_output, target[:, 0].long()) if self.weight_ce != 0 else 0 86 | if self.ignore_label is not None: 87 | ce_loss *= mask[:, 0] 88 | ce_loss = ce_loss.sum() / mask.sum() 89 | 90 | if self.aggregate == "sum": 91 | result = self.weight_ce * ce_loss + self.weight_dice * dc_loss 92 | else: 93 | # reserved for other stuff (later) 94 | raise NotImplementedError("nah son") 95 | return result 96 | 97 | 98 | class SoftDiceLoss(nn.Module): 99 | def __init__(self, apply_nonlin=None, batch_dice=True, do_bg=False, smooth=1e-5): 100 | """ 101 | """ 102 | super(SoftDiceLoss, self).__init__() 103 | 104 | self.do_bg = do_bg 105 | self.batch_dice = batch_dice 106 | self.apply_nonlin = apply_nonlin 107 | self.smooth = smooth 108 | 109 | def forward(self, x, y, loss_mask=None): 110 | shp_x = x.shape 111 | 112 | if self.batch_dice: 113 | axes = [0] + list(range(2, len(shp_x))) 114 | else: 115 | axes = list(range(2, len(shp_x))) 116 | 117 | if self.apply_nonlin is not None: 118 | x = self.apply_nonlin(x) 119 | 120 | tp, fp, fn, _ = get_tp_fp_fn_tn(x, y, axes, loss_mask, False) 121 | 122 | nominator = 2 * tp + self.smooth 123 | denominator = 2 * tp + fp + fn + self.smooth 124 | 125 | dc = nominator / (denominator + 1e-8) 126 | 127 | if not self.do_bg: 128 | if self.batch_dice: 129 | dc = dc[1:] 130 | else: 131 | dc = dc[:, 1:] 132 | dc = dc.mean() 133 | 134 | return -dc 135 | 136 | 137 | def sum_tensor(inp, axes, keepdim=False): 138 | axes = np.unique(axes).astype(int) 139 | if keepdim: 140 | for ax in axes: 141 | inp = inp.sum(int(ax), keepdim=True) 142 | else: 143 | for ax in sorted(axes, reverse=True): 144 | inp = inp.sum(int(ax)) 145 | return inp 146 | 147 | 148 | def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False): 149 | if axes is None: 150 | axes = tuple(range(2, len(net_output.size()))) 151 | 152 | shp_x = net_output.shape 153 | shp_y = gt.shape 154 | 155 | with torch.no_grad(): 156 | if len(shp_x) != len(shp_y): 157 | gt = gt.view((shp_y[0], 1, *shp_y[1:])) 158 | 159 | if all([i == j for i, j in zip(net_output.shape, gt.shape)]): 160 | # if this is the case then gt is probably already a one hot encoding 161 | y_onehot = gt 162 | else: 163 | gt = gt.long() 164 | y_onehot = torch.zeros(shp_x, device=net_output.device) 165 | y_onehot.scatter_(1, gt, 1) 166 | 167 | tp = net_output * y_onehot 168 | fp = net_output * (1 - y_onehot) 169 | fn = (1 - net_output) * y_onehot 170 | tn = (1 - net_output) * (1 - y_onehot) 171 | 172 | if mask is not None: 173 | tp = torch.stack(tuple(x_i * mask[:, 0] 174 | for x_i in torch.unbind(tp, dim=1)), dim=1) 175 | fp = torch.stack(tuple(x_i * mask[:, 0] 176 | for x_i in torch.unbind(fp, dim=1)), dim=1) 177 | fn = torch.stack(tuple(x_i * mask[:, 0] 178 | for x_i in torch.unbind(fn, dim=1)), dim=1) 179 | tn = torch.stack(tuple(x_i * mask[:, 0] 180 | for x_i in torch.unbind(tn, dim=1)), dim=1) 181 | 182 | if square: 183 | tp = tp ** 2 184 | fp = fp ** 2 185 | fn = fn ** 2 186 | tn = tn ** 2 187 | 188 | if len(axes) > 0: 189 | tp = sum_tensor(tp, axes, keepdim=False) 190 | fp = sum_tensor(fp, axes, keepdim=False) 191 | fn = sum_tensor(fn, axes, keepdim=False) 192 | tn = sum_tensor(tn, axes, keepdim=False) 193 | 194 | return tp, fp, fn, tn 195 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | from yacs.config import CfgNode as CN 4 | 5 | _C = CN() 6 | # ----------------------------------------------------------------------------- 7 | # Base settings 8 | # ----------------------------------------------------------------------------- 9 | _C.BASE = [''] 10 | 11 | _C.DIS = False 12 | _C.WORLD_SIZE = 1 13 | _C.SEED = 1234 14 | _C.AMP = True 15 | _C.EXPERIMENT_ID = "" 16 | _C.SAVE_DIR = "save_pth" 17 | _C.VAL_OUTPUT_PATH = "/home/lwt/code/flare/FLARE22-TwoStagePHTrans/save_results" 18 | _C.COARSE_MODEL_PATH = "/home/lwt/code/flare/FLARE22-TwoStagePHTrans/save_pth/phtrans_c_220813_001000" 19 | _C.FINE_MODEL_PATH = "/home/lwt/code/flare/FLARE22-TwoStagePHTrans/save_pth/phtrans_f_220813_001056" 20 | 21 | # ----------------------------------------------------------------------------- 22 | # Wandb settings 23 | # ----------------------------------------------------------------------------- 24 | _C.WANDB = CN() 25 | _C.WANDB.COARSE_PROJECT = "FLARE2022_COARSE" 26 | _C.WANDB.FINE_PROJECT = "FLARE2022_FINE" 27 | _C.WANDB.TAG = "PHTrans" 28 | _C.WANDB.MODE = "offline" 29 | 30 | # ----------------------------------------------------------------------------- 31 | # Data settings 32 | # ----------------------------------------------------------------------------- 33 | _C.DATASET = CN() 34 | _C.DATASET.WITH_VAL = False 35 | _C.DATASET.TRAIN_UNLABELED_IMAGE_PATH = "/home/lwt/data/flare22/UnlabeledCase" 36 | _C.DATASET.TRAIN_UNLABELED_MASK_PATH = "/home/lwt/data/flare22/Unlabel2000_phtranPre" 37 | _C.DATASET.TRAIN_IMAGE_PATH = "/home/lwt/data/flare22/Training/FLARE22_LabeledCase50/images" 38 | _C.DATASET.TRAIN_MASK_PATH = "/home/lwt/data/flare22/Training/FLARE22_LabeledCase50/labels" 39 | _C.DATASET.VAL_IMAGE_PATH = "/home/lwt/data/flare22/Validation" 40 | _C.DATASET.EXTEND_SIZE = 20 41 | _C.DATASET.IS_NORMALIZATION_DIRECTION = True 42 | 43 | _C.DATASET.COARSE = CN() 44 | _C.DATASET.COARSE.PROPRECESS_PATH = "/home/lwt/data_pro/flare22/Training/coarse_646464" 45 | _C.DATASET.COARSE.PROPRECESS_UL_PATH = "/home/lwt/data_pro/flare22/Unlabel2000_coarse_646464" 46 | _C.DATASET.COARSE.NUM_EACH_EPOCH = 512 47 | _C.DATASET.COARSE.SIZE = [64, 64, 64] 48 | _C.DATASET.COARSE.LABEL_CLASSES = 2 49 | 50 | _C.DATASET.FINE = CN() 51 | _C.DATASET.FINE.PROPRECESS_PATH = "/home/lwt/data_pro/flare22/Training/fine_96192192" 52 | _C.DATASET.FINE.PROPRECESS_UL_PATH = "/home/lwt/data_pro/flare22/Unlabel2000_fine_96192192" 53 | _C.DATASET.FINE.NUM_EACH_EPOCH = 512 54 | _C.DATASET.FINE.SIZE = [96, 192, 192] 55 | _C.DATASET.FINE.LABEL_CLASSES = 14 56 | 57 | _C.DATASET.DA = CN() 58 | _C.DATASET.DA.DO_2D_AUG = True 59 | _C.DATASET.DA.DO_ELASTIC = True 60 | _C.DATASET.DA.DO_SCALING = True 61 | _C.DATASET.DA.DO_ROTATION = True 62 | _C.DATASET.DA.RANDOM_CROP = False 63 | _C.DATASET.DA.DO_GAMMA = True 64 | _C.DATASET.DA.DO_MIRROR = False 65 | _C.DATASET.DA.DO_ADDITIVE_BRIGHTNESS = True 66 | 67 | # ----------------------------------------------------------------------------- 68 | # Dataloader settings 69 | # ----------------------------------------------------------------------------- 70 | _C.DATALOADER = CN() 71 | _C.DATALOADER.BATCH_SIZE = 1 72 | _C.DATALOADER.PIN_MEMORY = True 73 | _C.DATALOADER.NUM_WORKERS = 8 74 | 75 | # ----------------------------------------------------------------------------- 76 | # Model settings 77 | # ----------------------------------------------------------------------------- 78 | _C.MODEL = CN() 79 | _C.MODEL.DEEP_SUPERVISION = True 80 | 81 | _C.MODEL.COARSE = CN() 82 | _C.MODEL.COARSE.TYPE = "phtrans" 83 | _C.MODEL.COARSE.BASE_NUM_FEATURES = 16 84 | _C.MODEL.COARSE.NUM_ONLY_CONV_STAGE = 2 85 | _C.MODEL.COARSE.NUM_CONV_PER_STAGE = 2 86 | _C.MODEL.COARSE.FEAT_MAP_MUL_ON_DOWNSCALE = 2 87 | _C.MODEL.COARSE.POOL_OP_KERNEL_SIZES = [ 88 | [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]] 89 | _C.MODEL.COARSE.CONV_KERNEL_SIZES = [ 90 | [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]] 91 | _C.MODEL.COARSE.DROPOUT_P = 0.1 92 | 93 | _C.MODEL.COARSE.MAX_NUM_FEATURES = 200 94 | _C.MODEL.COARSE.DEPTHS = [2, 2, 2, 2] 95 | _C.MODEL.COARSE.NUM_HEADS = [4, 4, 4, 4] 96 | _C.MODEL.COARSE.WINDOW_SIZE = [4, 4, 4] 97 | _C.MODEL.COARSE.MLP_RATIO = 1. 98 | _C.MODEL.COARSE.QKV_BIAS = True 99 | _C.MODEL.COARSE.QK_SCALE = None 100 | _C.MODEL.COARSE.DROP_RATE = 0. 101 | _C.MODEL.COARSE.DROP_PATH_RATE = 0.1 102 | 103 | _C.MODEL.FINE = CN() 104 | _C.MODEL.FINE.TYPE = "phtrans" 105 | _C.MODEL.FINE.BASE_NUM_FEATURES = 16 106 | _C.MODEL.FINE.NUM_ONLY_CONV_STAGE = 2 107 | _C.MODEL.FINE.NUM_CONV_PER_STAGE = 2 108 | _C.MODEL.FINE.FEAT_MAP_MUL_ON_DOWNSCALE = 2 109 | _C.MODEL.FINE.POOL_OP_KERNEL_SIZES = [ 110 | [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]] 111 | _C.MODEL.FINE.CONV_KERNEL_SIZES = [[3, 3, 3], [ 112 | 3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]] 113 | _C.MODEL.FINE.DROPOUT_P = 0.1 114 | 115 | _C.MODEL.FINE.MAX_NUM_FEATURES = 200 116 | _C.MODEL.FINE.DEPTHS = [2, 2, 2, 2] 117 | _C.MODEL.FINE.NUM_HEADS = [4, 4, 4, 4] 118 | _C.MODEL.FINE.WINDOW_SIZE = [3, 4, 4] 119 | _C.MODEL.FINE.MLP_RATIO = 1. 120 | _C.MODEL.FINE.QKV_BIAS = True 121 | _C.MODEL.FINE.QK_SCALE = None 122 | _C.MODEL.FINE.DROP_RATE = 0. 123 | _C.MODEL.FINE.DROP_PATH_RATE = 0.1 124 | 125 | # ----------------------------------------------------------------------------- 126 | # Training settings 127 | # ----------------------------------------------------------------------------- 128 | _C.TRAIN = CN() 129 | _C.TRAIN.DO_BACKPROP = True 130 | _C.TRAIN.VAL_NUM_EPOCHS = 1 131 | _C.TRAIN.SAVE_PERIOD = 1 132 | 133 | _C.TRAIN.EPOCHS = 300 134 | _C.TRAIN.WEIGHT_DECAY = 0.01 135 | _C.TRAIN.WARMUP_EPOCHS = 20 136 | _C.TRAIN.BASE_LR = 5e-4 137 | _C.TRAIN.WARMUP_LR = 5e-7 138 | _C.TRAIN.MIN_LR = 5e-6 139 | # LR scheduler 140 | _C.TRAIN.LR_SCHEDULER = CN() 141 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' 142 | 143 | # Epoch interval to decay LR, used in StepLRScheduler 144 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 145 | # LR decay rate, used in StepLRScheduler 146 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 147 | 148 | # Optimizer 149 | _C.TRAIN.OPTIMIZER = CN() 150 | _C.TRAIN.OPTIMIZER.NAME = 'adamw' 151 | # Optimizer Epsilon 152 | _C.TRAIN.OPTIMIZER.EPS = 1e-8 153 | # Optimizer Betas 154 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) 155 | # SGD momentum 156 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 157 | 158 | # ----------------------------------------------------------------------------- 159 | # Test settings 160 | # ----------------------------------------------------------------------------- 161 | _C.VAL = CN() 162 | _C.VAL.IS_POST_PROCESS = True 163 | _C.VAL.IS_WITH_DATALOADER = True 164 | 165 | 166 | def _update_config_from_file(config, cfg_file): 167 | config.defrost() 168 | with open(cfg_file, 'r') as f: 169 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 170 | 171 | for cfg in yaml_cfg.setdefault('BASE', ['']): 172 | if cfg: 173 | _update_config_from_file( 174 | config, os.path.join(os.path.dirname(cfg_file), cfg) 175 | ) 176 | print('=> merge config from {}'.format(cfg_file)) 177 | config.merge_from_file(cfg_file) 178 | config.freeze() 179 | 180 | 181 | def update_config(config, args): 182 | if args.cfg is not None: 183 | _update_config_from_file(config, args.cfg) 184 | 185 | config.defrost() 186 | if args.opts: 187 | config.merge_from_list(args.opts) 188 | if args.batch_size: 189 | config.DATALOADER.BATCH_SIZE = args.batch_size 190 | if args.tag: 191 | config.WANDB.TAG = args.tag 192 | if args.wandb_mode == "online": 193 | config.WANDB.MODE = args.wandb_mode 194 | if args.world_size: 195 | config.WORLD_SIZE = args.world_size 196 | if args.with_distributed: 197 | config.DIS = True 198 | config.freeze() 199 | 200 | 201 | def update_val_config(config, args): 202 | if args.cfg is not None: 203 | _update_config_from_file(config, args.cfg) 204 | 205 | config.defrost() 206 | if args.opts: 207 | config.merge_from_list(args.opts) 208 | 209 | # merge from specific arguments 210 | if args.save_model_path: 211 | config.SAVE_MODEL_PATH = args.save_model_path 212 | if args.data_path: 213 | config.DATASET.VAL_IMAGE_PATH = args.data_path 214 | if args.output_path: 215 | config.VAL_OUTPUT_PATH = args.output_path 216 | 217 | config.freeze() 218 | 219 | 220 | def get_config(args=None): 221 | config = _C.clone() 222 | update_config(config, args) 223 | 224 | return config 225 | 226 | 227 | def get_config_no_args(): 228 | config = _C.clone() 229 | 230 | return config 231 | 232 | 233 | def get_val_config(args=None): 234 | config = _C.clone() 235 | update_val_config(config, args) 236 | 237 | return config 238 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from config import get_val_config 3 | from models import build_coarse_model, build_fine_model 4 | import os 5 | import torch.backends.cudnn as cudnn 6 | import numpy as np 7 | import time 8 | import torch 9 | from torch.cuda.amp import autocast 10 | import SimpleITK as sitk 11 | from utils import to_cuda, load_checkpoint 12 | from data import predict_dataset, DataLoaderX 13 | from data.utils import change_axes_of_image, extract_topk_largest_candidates, to_one_hot, input_downsample, output_upsample, crop_image_according_to_bbox, get_bbox_from_mask 14 | from batchgenerators.utilities.file_and_folder_operations import * 15 | import torch.nn.functional as F 16 | 17 | def parse_option(): 18 | parser = argparse.ArgumentParser("FLARE2022_training") 19 | parser.add_argument('--cfg', type=str, metavar="FILE", 20 | help='path to config file') 21 | parser.add_argument( 22 | "--opts", 23 | help="Modify config options by adding 'KEY VALUE' pairs. ", 24 | default=None, 25 | nargs='+', 26 | ) 27 | parser.add_argument('-smp', '--save_model_path', type=str, 28 | default=None, help='path to model.pth') 29 | parser.add_argument('-dp', '--data_path', type=str, 30 | default=None, help='path to validation image path') 31 | parser.add_argument('-op', '--output_path', type=str, 32 | default=None, help='path to output image path') 33 | args = parser.parse_args() 34 | config = get_val_config(args) 35 | 36 | return args, config 37 | 38 | class Inference: 39 | def __init__(self, config) -> None: 40 | self.config = config 41 | self.output_path = self.config.VAL_OUTPUT_PATH 42 | os.makedirs(config.VAL_OUTPUT_PATH, exist_ok=True) 43 | self.coarse_size = self.config.DATASET.COARSE.SIZE 44 | self.fine_size = self.config.DATASET.FINE.SIZE 45 | self.extend_size = self.config.DATASET.EXTEND_SIZE 46 | self.is_post_process = self.config.VAL.IS_POST_PROCESS 47 | self.is_nor_dir = self.config.DATASET.IS_NORMALIZATION_DIRECTION 48 | self.is_with_dataloader = self.config.VAL.IS_WITH_DATALOADER 49 | if self.is_with_dataloader: 50 | val_dataset = predict_dataset(config) 51 | self.val_loader = DataLoaderX( 52 | val_dataset, 53 | batch_size=1, 54 | num_workers=0, 55 | pin_memory=config.DATALOADER.PIN_MEMORY, 56 | shuffle=False, 57 | ) 58 | else: 59 | self.val_loader = predict_dataset(config) 60 | cudnn.benchmark = True 61 | 62 | def run(self): 63 | torch.cuda.synchronize() 64 | t_start = time.time() 65 | with autocast(): 66 | with torch.no_grad(): 67 | for image_dict in self.val_loader: 68 | image_dict = image_dict[0] if type(image_dict) is list else image_dict 69 | if self.is_with_dataloader: 70 | image_id = image_dict['image_id'][0] 71 | raw_image = np.array(image_dict['raw_image'].squeeze(0)) 72 | raw_spacing = np.array(image_dict['raw_spacing'][0]) 73 | image_direction = np.array(image_dict['image_direction'][0]) 74 | else: 75 | image_id = image_dict['image_id'] 76 | raw_image = image_dict['raw_image'] 77 | raw_spacing = image_dict['raw_spacing'] 78 | image_direction = image_dict['image_direction'] 79 | coarse_image = torch.from_numpy( 80 | raw_image).unsqueeze(0).unsqueeze(0).float() 81 | raw_image_shape = raw_image.shape 82 | coarse_resize_factor = np.array(raw_image.shape) / np.array(self.coarse_size) 83 | coarse_image = input_downsample(coarse_image, self.coarse_size) 84 | coarse_image = self.coarse_predict(coarse_image, self.config.COARSE_MODEL_PATH) 85 | coarse_pre = F.softmax(coarse_image, 1) 86 | coarse_pre = coarse_pre.cpu().float() 87 | torch.cuda.empty_cache() 88 | coarse_mask = coarse_pre.argmax(1).squeeze(axis=0).numpy().astype(np.uint8) 89 | lab_unique = np.unique(coarse_mask) 90 | coarse_mask = to_one_hot(coarse_mask) 91 | coarse_mask = extract_topk_largest_candidates(coarse_mask,lab_unique, 1) 92 | coarse_bbox = get_bbox_from_mask(coarse_mask) 93 | raw_bbox = [[int(coarse_bbox[0][0] * coarse_resize_factor[0]), 94 | int(coarse_bbox[0][1] * coarse_resize_factor[0])], 95 | [int(coarse_bbox[1][0] * coarse_resize_factor[1]), 96 | int(coarse_bbox[1][1] * coarse_resize_factor[1])], 97 | [int(coarse_bbox[2][0] * coarse_resize_factor[2]), 98 | int(coarse_bbox[2][1] * coarse_resize_factor[2])]] 99 | margin = [self.extend_size / raw_spacing[i] 100 | for i in range(3)] 101 | crop_image, crop_fine_bbox = crop_image_according_to_bbox( 102 | raw_image, raw_bbox, margin) 103 | print(crop_fine_bbox) 104 | crop_image_size = crop_image.shape 105 | crop_image = torch.from_numpy(crop_image).unsqueeze(0).unsqueeze(0) 106 | crop_image = input_downsample(crop_image, self.fine_size) 107 | crop_image = self.fine_predict(crop_image, config.FINE_MODEL_PATH) 108 | torch.cuda.empty_cache() 109 | crop_image = output_upsample(crop_image, crop_image_size) 110 | crop_image = F.softmax(crop_image, 1) 111 | fine_mask = crop_image.argmax(1).squeeze(axis=0).numpy().astype(np.uint8) 112 | if self.is_post_process: 113 | lab_unique = np.unique(fine_mask) 114 | fine_mask = to_one_hot(fine_mask) 115 | fine_mask = extract_topk_largest_candidates(fine_mask,lab_unique, 1) 116 | out_mask = np.zeros(raw_image_shape, np.uint8) 117 | out_mask[crop_fine_bbox[0][0]:crop_fine_bbox[0][1], 118 | crop_fine_bbox[1][0]:crop_fine_bbox[1][1], 119 | crop_fine_bbox[2][0]:crop_fine_bbox[2][1]] = fine_mask 120 | if self.is_nor_dir: 121 | out_mask = change_axes_of_image(out_mask, image_direction) 122 | sitk_image = sitk.GetImageFromArray(out_mask) 123 | sitk.WriteImage(sitk_image, os.path.join( 124 | self.output_path, "FLARETs_{}.nii.gz".format(image_id)), True) 125 | print(f"{image_id} Done") 126 | 127 | torch.cuda.synchronize() 128 | t_end = time.time() 129 | average_time_usage = (t_end - t_start) * 1.0 / len(self.val_loader) 130 | print("Average time usage: {} s".format(average_time_usage)) 131 | 132 | def coarse_predict(self, input, model_path): 133 | coarse_model_checkpoint = load_checkpoint(model_path) 134 | coarse_model = build_coarse_model(coarse_model_checkpoint["config"], True).eval() 135 | coarse_model.load_state_dict({k.replace('module.', ''): v for k, v in coarse_model_checkpoint['state_dict'].items()}) 136 | self._set_requires_grad(coarse_model, False) 137 | coarse_model = coarse_model.cuda().half() 138 | input = to_cuda(input).half() 139 | out = coarse_model(input) 140 | coarse_model = coarse_model.cpu() 141 | return out.cpu().float() 142 | 143 | def fine_predict(self, input, model_path): 144 | fine_model_checkpoint = load_checkpoint(model_path) 145 | fine_model = build_fine_model(fine_model_checkpoint["config"], True).eval() 146 | fine_model.load_state_dict({k.replace('module.', ''): v for k, v in fine_model_checkpoint['state_dict'].items()}) 147 | self._set_requires_grad(fine_model, False) 148 | fine_model = fine_model.cuda().half() 149 | input = to_cuda(input).half() 150 | out = fine_model(input) 151 | fine_model = fine_model.cpu() 152 | return out.cpu().float() 153 | 154 | @staticmethod 155 | def _set_requires_grad(model, requires_grad=False): 156 | for param in model.parameters(): 157 | param.requires_grad = requires_grad 158 | 159 | if __name__ == '__main__': 160 | torch.cuda.synchronize() 161 | t_start = time.time() 162 | _, config = parse_option() 163 | 164 | predict = Inference(config) 165 | predict.run() 166 | torch.cuda.synchronize() 167 | t_end = time.time() 168 | total_time = t_end - t_start 169 | print("Total_time: {} s".format(total_time)) 170 | -------------------------------------------------------------------------------- /data/dataset_train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | from torch.utils.data import Dataset 4 | from batchgenerators.utilities.file_and_folder_operations import * 5 | from .data_augmentation import default_3D_augmentation_params,default_2D_augmentation_params,get_patch_size,DownsampleSegForDSTransform 6 | from batchgenerators.transforms.abstract_transforms import Compose 7 | from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \ 8 | ContrastAugmentationTransform, BrightnessTransform 9 | from batchgenerators.transforms.color_transforms import GammaTransform 10 | from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform 11 | from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform 12 | from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform 13 | from batchgenerators.transforms.utility_transforms import NumpyToTensor 14 | from data.utils import load_pickle 15 | class flare22_dataset(Dataset): 16 | def __init__(self, config, data_size, data_path, unlab_data_path, pool_op_kernel_sizes, num_each_epoch,is_train=True, is_deep_supervision=True): 17 | self.config=config 18 | self.data_path = data_path 19 | self.data_size = data_size 20 | self.unlab_data_path = unlab_data_path 21 | self.pool_op_kernel_sizes = pool_op_kernel_sizes 22 | self.num_each_epoch = num_each_epoch 23 | self.series_ids = subfiles(data_path, join=False, suffix='npz') 24 | self.unlab_series_ids = subfiles(unlab_data_path, join=False, suffix='npz') 25 | self.setup_DA_params() 26 | 27 | self.transforms = self.get_augmentation( 28 | data_size, 29 | self.data_aug_params,is_train=is_train, 30 | deep_supervision_scales=self.deep_supervision_scales if is_deep_supervision else None 31 | ) 32 | def __getitem__(self, idx): 33 | if idx < len(self.series_ids): 34 | data_id = self.series_ids[idx] 35 | data_info = load_pickle(join(self.data_path, data_id.split(".")[0] + "_info.pkl")) 36 | data_load = np.load(join(self.data_path,data_id)) 37 | else: 38 | data_id = self.unlab_series_ids[random.randint(0,len(self.unlab_series_ids)-1)] 39 | data_info = load_pickle(join(self.unlab_data_path, data_id.split(".")[0] + "_info.pkl")) 40 | data_load = np.load(join(self.unlab_data_path,data_id)) 41 | 42 | data_trans = self.transforms(**data_load) 43 | return data_trans, data_info 44 | 45 | def __len__(self): 46 | return self.num_each_epoch 47 | 48 | 49 | 50 | def setup_DA_params(self): 51 | if self.config.MODEL.DEEP_SUPERVISION: 52 | self.deep_supervision_scales = [[1, 1, 1]] + list(list(i) for i in 1 / np.cumprod(np.vstack(self.pool_op_kernel_sizes), axis=0))[:-1] 53 | self.data_aug_params = default_3D_augmentation_params 54 | self.data_aug_params['rotation_x'] = ( 55 | -30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi) 56 | self.data_aug_params['rotation_y'] = ( 57 | -30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi) 58 | self.data_aug_params['rotation_z'] = ( 59 | -30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi) 60 | 61 | if self.config.DATASET.DA.DO_2D_AUG: 62 | if self.config.DATASET.DA.DO_ELASTIC: 63 | self.data_aug_params["elastic_deform_alpha"] = \ 64 | default_2D_augmentation_params["elastic_deform_alpha"] 65 | self.data_aug_params["elastic_deform_sigma"] = \ 66 | default_2D_augmentation_params["elastic_deform_sigma"] 67 | self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"] 68 | 69 | if self.config.DATASET.DA.DO_2D_AUG: 70 | self.basic_generator_patch_size = get_patch_size(self.data_size[1:], 71 | self.data_aug_params['rotation_x'], 72 | self.data_aug_params['rotation_y'], 73 | self.data_aug_params['rotation_z'], 74 | self.data_aug_params['scale_range']) 75 | self.basic_generator_patch_size = np.array( 76 | [self.data_size[0]] + list(self.basic_generator_patch_size)) 77 | else: 78 | self.basic_generator_patch_size = get_patch_size(self.data_size, self.data_aug_params['rotation_x'], 79 | self.data_aug_params['rotation_y'], 80 | self.data_aug_params['rotation_z'], 81 | self.data_aug_params['scale_range']) 82 | 83 | 84 | 85 | def get_augmentation(self, patch_size, params=default_3D_augmentation_params,is_train=True,border_val_seg=-1, 86 | order_seg=1, order_data=3, deep_supervision_scales=None,): 87 | transforms = [] 88 | if is_train: 89 | 90 | if self.config.DATASET.DA.DO_2D_AUG: 91 | ignore_axes = (1,) 92 | 93 | patch_size_spatial = patch_size[1:] 94 | else: 95 | patch_size_spatial = patch_size 96 | ignore_axes = None 97 | 98 | transforms.append(SpatialTransform( 99 | patch_size_spatial, patch_center_dist_from_border=None, 100 | do_elastic_deform=self.config.DATASET.DA.DO_ELASTIC, alpha=params.get("elastic_deform_alpha"), 101 | sigma=params.get("elastic_deform_sigma"), 102 | do_rotation=self.config.DATASET.DA.DO_ROTATION, angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"), 103 | angle_z=params.get("rotation_z"), p_rot_per_axis=params.get("rotation_p_per_axis"), 104 | do_scale=self.config.DATASET.DA.DO_SCALING, scale=params.get("scale_range"), 105 | border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=order_data, 106 | border_mode_seg="constant", border_cval_seg=border_val_seg, 107 | order_seg=order_seg, random_crop=self.config.DATASET.DA.RANDOM_CROP, p_el_per_sample=params.get("p_eldef"), 108 | p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"), 109 | independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis") 110 | )) 111 | 112 | 113 | transforms.append(GaussianNoiseTransform(p_per_sample=0.1)) 114 | transforms.append(GaussianBlurTransform((0.5, 1.), different_sigma_per_channel=True, p_per_sample=0.2, 115 | p_per_channel=0.5)) 116 | transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25), p_per_sample=0.15)) 117 | 118 | if self.config.DATASET.DA.DO_ADDITIVE_BRIGHTNESS: 119 | transforms.append(BrightnessTransform(params.get("additive_brightness_mu"), 120 | params.get("additive_brightness_sigma"), 121 | True, p_per_sample=params.get("additive_brightness_p_per_sample"), 122 | p_per_channel=params.get("additive_brightness_p_per_channel"))) 123 | 124 | transforms.append(ContrastAugmentationTransform(p_per_sample=0.15)) 125 | transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True, 126 | p_per_channel=0.5, 127 | order_downsample=0, order_upsample=3, p_per_sample=0.25, 128 | ignore_axes=ignore_axes)) 129 | transforms.append( 130 | GammaTransform(params.get("gamma_range"), True, True, retain_stats=params.get("gamma_retain_stats"), 131 | p_per_sample=0.1)) # inverted gamma 132 | 133 | if self.config.DATASET.DA.DO_GAMMA: 134 | transforms.append( 135 | GammaTransform(params.get("gamma_range"), False, True, retain_stats=params.get("gamma_retain_stats"), 136 | p_per_sample=params["p_gamma"])) 137 | 138 | if self.config.DATASET.DA.DO_MIRROR: 139 | transforms.append(MirrorTransform(params.get("mirror_axes"))) 140 | 141 | if deep_supervision_scales is not None: 142 | transforms.append(DownsampleSegForDSTransform(deep_supervision_scales, 0, input_key='seg', 143 | output_key='seg')) 144 | 145 | transforms.append(NumpyToTensor(['data', 'seg'], 'float')) 146 | transforms = Compose(transforms) 147 | return transforms 148 | 149 | -------------------------------------------------------------------------------- /models/swin_3D.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from timm.models.layers import DropPath, trunc_normal_ 4 | from einops import rearrange 5 | 6 | 7 | class Mlp(nn.Module): 8 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 9 | super().__init__() 10 | out_features = out_features or in_features 11 | hidden_features = hidden_features or in_features 12 | self.fc1 = nn.Linear(in_features, hidden_features) 13 | self.act = act_layer() 14 | self.fc2 = nn.Linear(hidden_features, out_features) 15 | self.drop = nn.Dropout(drop) 16 | 17 | def forward(self, x): 18 | x = self.fc1(x) 19 | x = self.act(x) 20 | x = self.drop(x) 21 | x = self.fc2(x) 22 | x = self.drop(x) 23 | return x 24 | 25 | 26 | def window_partition(x, window_size): 27 | 28 | B, S, H, W, C = x.shape 29 | windows = rearrange(x, 'b (s p1) (h p2) (w p3) c -> (b s h w) p1 p2 p3 c', 30 | p1=window_size[0], p2=window_size[1], p3=window_size[2], c=C) 31 | return windows 32 | 33 | 34 | def window_reverse(windows, window_size, S, H, W): 35 | B = int(windows.shape[0] / (S * H * W / 36 | window_size[0] / window_size[1] / window_size[2])) 37 | 38 | x = rearrange(windows, '(b s h w) p1 p2 p3 c -> b (s p1) (h p2) (w p3) c', 39 | p1=window_size[0], p2=window_size[1], p3=window_size[2], b=B, 40 | s=S//window_size[0], h=H//window_size[1], w=W//window_size[2]) 41 | return x 42 | 43 | 44 | class WindowAttention(nn.Module): 45 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 46 | 47 | super().__init__() 48 | self.dim = dim 49 | self.window_size = window_size 50 | self.num_heads = num_heads 51 | head_dim = dim // num_heads 52 | self.scale = qk_scale or head_dim ** -0.5 53 | 54 | self.relative_position_bias_table = nn.Parameter( 55 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), 56 | num_heads)) 57 | 58 | coords_s = torch.arange(self.window_size[0]) 59 | coords_h = torch.arange(self.window_size[1]) 60 | coords_w = torch.arange(self.window_size[2]) 61 | coords = torch.stack(torch.meshgrid( 62 | [coords_s, coords_h, coords_w])) 63 | coords_flatten = torch.flatten(coords, 1) 64 | relative_coords = coords_flatten[:, :, 65 | None] - coords_flatten[:, None, :] 66 | relative_coords = relative_coords.permute( 67 | 1, 2, 0).contiguous() 68 | relative_coords[:, :, 0] += self.window_size[0] - 1 69 | relative_coords[:, :, 1] += self.window_size[1] - 1 70 | relative_coords[:, :, 2] += self.window_size[2] - 1 71 | 72 | relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * \ 73 | (2 * self.window_size[2] - 1) 74 | relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1 75 | relative_position_index = relative_coords.sum(-1) 76 | self.register_buffer("relative_position_index", 77 | relative_position_index) 78 | 79 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 80 | self.attn_drop = nn.Dropout(attn_drop) 81 | self.proj = nn.Linear(dim, dim) 82 | self.proj_drop = nn.Dropout(proj_drop) 83 | 84 | trunc_normal_(self.relative_position_bias_table, std=.02) 85 | self.softmax = nn.Softmax(dim=-1) 86 | 87 | def forward(self, x, mask=None): 88 | 89 | B_, N, C = x.shape 90 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // 91 | self.num_heads).permute(2, 0, 3, 1, 4) 92 | 93 | q, k, v = qkv[0], qkv[1], qkv[2] 94 | 95 | q = q * self.scale 96 | attn = (q @ k.transpose(-2, -1)) 97 | 98 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 99 | self.window_size[0] * self.window_size[1] * self.window_size[2], 100 | self.window_size[0] * self.window_size[1] * self.window_size[2], -1) 101 | relative_position_bias = relative_position_bias.permute( 102 | 2, 0, 1).contiguous() 103 | attn = attn + relative_position_bias.unsqueeze(0) 104 | 105 | if mask is not None: 106 | nW = mask.shape[0] 107 | attn = attn.view(B_ // nW, nW, self.num_heads, N, 108 | N) + mask.unsqueeze(1).unsqueeze(0) 109 | attn = attn.view(-1, self.num_heads, N, N) 110 | attn = self.softmax(attn) 111 | else: 112 | attn = self.softmax(attn) 113 | 114 | attn = self.attn_drop(attn) 115 | 116 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 117 | x = self.proj(x) 118 | x = self.proj_drop(x) 119 | return x 120 | 121 | 122 | class SwinTransformerBlock(nn.Module): 123 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 124 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 125 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 126 | super().__init__() 127 | self.dim = dim 128 | self.input_resolution = input_resolution 129 | self.num_heads = num_heads 130 | self.window_size = window_size 131 | self.shift_size = shift_size 132 | self.mlp_ratio = mlp_ratio 133 | 134 | if self.shift_size != 0: 135 | assert 0 <= min(self.shift_size) < min( 136 | self.window_size), "shift_size must in 0-window_size" 137 | 138 | self.norm1 = norm_layer(dim) 139 | self.attn = WindowAttention( 140 | dim, window_size=self.window_size, num_heads=num_heads, 141 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 142 | 143 | self.drop_path = DropPath( 144 | drop_path) if drop_path > 0. else nn.Identity() 145 | self.norm2 = norm_layer(dim) 146 | mlp_hidden_dim = int(dim * mlp_ratio) 147 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, 148 | act_layer=act_layer, drop=drop) 149 | 150 | if max(self.shift_size) > 0: 151 | 152 | S, H, W = self.input_resolution 153 | img_mask = torch.zeros((1, S, H, W, 1)) 154 | s_slices = (slice(0, -self.window_size[0]), 155 | slice(-self.window_size[0], -self.shift_size[0]), 156 | slice(-self.shift_size[0], None)) 157 | h_slices = (slice(0, -self.window_size[1]), 158 | slice(-self.window_size[1], -self.shift_size[1]), 159 | slice(-self.shift_size[1], None)) 160 | w_slices = (slice(0, -self.window_size[2]), 161 | slice(-self.window_size[2], -self.shift_size[2]), 162 | slice(-self.shift_size[2], None)) 163 | cnt = 0 164 | for s in s_slices: 165 | for h in h_slices: 166 | for w in w_slices: 167 | img_mask[:, s, h, w, :] = cnt 168 | cnt += 1 169 | 170 | mask_windows = window_partition(img_mask, self.window_size) 171 | mask_windows = mask_windows.view( 172 | -1, self.window_size[0] * self.window_size[1] * self.window_size[2]) 173 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 174 | attn_mask = attn_mask.masked_fill( 175 | attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 176 | else: 177 | attn_mask = None 178 | 179 | self.register_buffer("attn_mask", attn_mask) 180 | 181 | def forward(self, x): 182 | s, h, w = self.input_resolution 183 | B, C, S, H, W = x.shape 184 | assert S == s and H == h and W == w, "input feature has wrong size" 185 | x = rearrange(x, 'b c s h w -> b (s h w) c') 186 | shortcut = x 187 | x = self.norm1(x) 188 | x = x.view(B, S, H, W, C) 189 | 190 | # cyclic shift 191 | if max(self.shift_size) > 0: 192 | shifted_x = torch.roll( 193 | x, shifts=(-self.shift_size[0], -self.shift_size[1], -self.shift_size[2]), dims=(1, 2, 3)) 194 | else: 195 | shifted_x = x 196 | 197 | x_windows = window_partition(shifted_x, self.window_size) 198 | 199 | x_windows = x_windows.view( 200 | -1, self.window_size[0] * self.window_size[1] * self.window_size[2], C) 201 | 202 | attn_windows = self.attn(x_windows, mask=self.attn_mask) 203 | attn_windows = attn_windows.view( 204 | -1, self.window_size[0], self.window_size[1], self.window_size[2], C) 205 | shifted_x = window_reverse( 206 | attn_windows, self.window_size, S, H, W) 207 | if max(self.shift_size) > 0: 208 | x = torch.roll(shifted_x, shifts=( 209 | self.shift_size[0], self.shift_size[1], self.shift_size[2]), dims=(1, 2, 3)) 210 | else: 211 | x = shifted_x 212 | x = x.view(B, S * H * W, C) 213 | 214 | x = shortcut + self.drop_path(x) 215 | x = x + self.drop_path(self.mlp(self.norm2(x))) 216 | x = rearrange(x, 'b (s h w) c -> b c s h w', s=S, h=H, w=W) 217 | return x 218 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /models/phtrans.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from timm.models.layers import trunc_normal_ 5 | from .swin_3D import * 6 | 7 | 8 | class PHTrans(nn.Module): 9 | def __init__(self, img_size, base_num_features, num_classes, image_channels=1, num_only_conv_stage=2, num_conv_per_stage=2, 10 | feat_map_mul_on_downscale=2, pool_op_kernel_sizes=None, 11 | conv_kernel_sizes=None, dropout_p=0., deep_supervision=True, max_num_features=None, only_conv=False, depths=None, num_heads=None, 12 | window_size=None, mlp_ratio=4., qkv_bias=True, qk_scale=None, 13 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 14 | norm_layer=nn.LayerNorm, is_preprocess=False, **kwargs): 15 | super().__init__() 16 | 17 | conv_op = nn.Conv3d 18 | norm_op = nn.InstanceNorm3d 19 | norm_op_kwargs = {'eps': 1e-5, 'affine': True} 20 | dropout_op = nn.Dropout3d 21 | dropout_op_kwargs = {'p': dropout_p, 'inplace': True} 22 | nonlin = nn.GELU 23 | nonlin_kwargs = {} 24 | 25 | self.is_preprocess = is_preprocess 26 | self._deep_supervision = deep_supervision 27 | self.num_pool = len(pool_op_kernel_sizes) 28 | conv_pad_sizes = [] 29 | for krnl in conv_kernel_sizes: 30 | conv_pad_sizes.append([1 if i == 3 else 0 for i in krnl]) 31 | dpr = [x.item() for x in torch.linspace( 32 | 0, drop_path_rate, sum(depths))] 33 | 34 | self.seg_outputs = [] 35 | for ds in range(self.num_pool): 36 | self.seg_outputs.append(DeepSupervision(min( 37 | (base_num_features * feat_map_mul_on_downscale ** ds), max_num_features), num_classes)) 38 | self.seg_outputs = nn.ModuleList(self.seg_outputs) 39 | 40 | # build layers 41 | self.down_layers = nn.ModuleList() 42 | for i_layer in range(self.num_pool+1): 43 | layer = BasicLayer(num_stage=i_layer, 44 | only_conv=only_conv, 45 | num_only_conv_stage=num_only_conv_stage, 46 | num_pool=self.num_pool, 47 | base_num_features=base_num_features, 48 | dim=min( 49 | (base_num_features * feat_map_mul_on_downscale ** i_layer), max_num_features), 50 | input_resolution=( 51 | img_size // np.prod(pool_op_kernel_sizes[:i_layer], 0, dtype=np.int64)), 52 | depth=depths[i_layer-num_only_conv_stage] if ( 53 | i_layer >= num_only_conv_stage) else None, 54 | num_heads=num_heads[i_layer-num_only_conv_stage] if ( 55 | i_layer >= num_only_conv_stage) else None, 56 | window_size=window_size, 57 | image_channels=image_channels, num_conv_per_stage=num_conv_per_stage, 58 | conv_op=conv_op, norm_op=norm_op, norm_op_kwargs=norm_op_kwargs, dropout_op=dropout_op, 59 | dropout_op_kwargs=dropout_op_kwargs, nonlin=nonlin, nonlin_kwargs=nonlin_kwargs, 60 | conv_kernel_sizes=conv_kernel_sizes, conv_pad_sizes=conv_pad_sizes, pool_op_kernel_sizes=pool_op_kernel_sizes, 61 | max_num_features=max_num_features, 62 | mlp_ratio=mlp_ratio, 63 | qkv_bias=qkv_bias, qk_scale=qk_scale, 64 | drop=drop_rate, attn_drop=attn_drop_rate, 65 | drop_path=dpr[sum(depths[:i_layer-num_only_conv_stage]):sum(depths[:i_layer-num_only_conv_stage + 1])] if ( 66 | i_layer >= num_only_conv_stage) else None, 67 | norm_layer=norm_layer, 68 | down_or_upsample=nn.Conv3d if i_layer > 0 else None, 69 | feat_map_mul_on_downscale=feat_map_mul_on_downscale, 70 | is_encoder=True) 71 | self.down_layers.append(layer) 72 | self.up_layers = nn.ModuleList() 73 | for i_layer in range(self.num_pool)[::-1]: 74 | layer = BasicLayer(num_stage=i_layer, 75 | only_conv=only_conv, 76 | num_only_conv_stage=num_only_conv_stage, 77 | num_pool=self.num_pool, 78 | base_num_features=base_num_features, 79 | dim=min( 80 | (base_num_features * feat_map_mul_on_downscale ** i_layer), max_num_features), 81 | input_resolution=( 82 | img_size // np.prod(pool_op_kernel_sizes[:i_layer], 0, dtype=np.int64)), 83 | depth=depths[i_layer-num_only_conv_stage] if ( 84 | i_layer >= num_only_conv_stage) else None, 85 | num_heads=num_heads[i_layer-num_only_conv_stage] if ( 86 | i_layer >= num_only_conv_stage) else None, 87 | window_size=window_size, 88 | image_channels=image_channels, num_conv_per_stage=num_conv_per_stage, 89 | conv_op=conv_op, norm_op=norm_op, norm_op_kwargs=norm_op_kwargs, dropout_op=dropout_op, 90 | dropout_op_kwargs=dropout_op_kwargs, nonlin=nonlin, nonlin_kwargs=nonlin_kwargs, 91 | conv_kernel_sizes=conv_kernel_sizes, conv_pad_sizes=conv_pad_sizes, pool_op_kernel_sizes=pool_op_kernel_sizes, 92 | max_num_features=max_num_features, 93 | mlp_ratio=mlp_ratio, 94 | qkv_bias=qkv_bias, qk_scale=qk_scale, 95 | drop=drop_rate, attn_drop=attn_drop_rate, 96 | drop_path=dpr[sum(depths[:i_layer-num_only_conv_stage]):sum(depths[:i_layer-num_only_conv_stage + 1])] if ( 97 | i_layer >= num_only_conv_stage) else None, 98 | norm_layer=norm_layer, 99 | down_or_upsample=nn.ConvTranspose3d, 100 | feat_map_mul_on_downscale=feat_map_mul_on_downscale, 101 | is_encoder=False) 102 | self.up_layers.append(layer) 103 | self.apply(self._InitWeights) 104 | 105 | def _InitWeights(self, module): 106 | if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d): 107 | module.weight = nn.init.kaiming_normal_(module.weight, a=.02) 108 | if module.bias is not None: 109 | module.bias = nn.init.constant_(module.bias, 0) 110 | elif isinstance(module, nn.Linear): 111 | trunc_normal_(module.weight, std=.02) 112 | if isinstance(module, nn.Linear) and module.bias is not None: 113 | nn.init.constant_(module.bias, 0) 114 | elif isinstance(module, nn.LayerNorm): 115 | nn.init.constant_(module.bias, 0) 116 | nn.init.constant_(module.weight, 1.0) 117 | 118 | def forward(self, x): 119 | 120 | x_skip = list() 121 | for i, layer in enumerate(self.down_layers): 122 | x = layer(x, None) 123 | if i < self.num_pool: 124 | x_skip.append(x) 125 | out = [] 126 | for i, layer in enumerate(self.up_layers): 127 | x = layer(x, x_skip[-(i+1)]) 128 | if self._deep_supervision: 129 | out.append(x) 130 | 131 | if self._deep_supervision: 132 | ds = [] 133 | for i in range(len(out)): 134 | ds.append(self.seg_outputs[i](out[-(i+1)])) 135 | else: 136 | ds = self.seg_outputs[0](x) 137 | 138 | return ds 139 | 140 | @torch.jit.ignore 141 | def no_weight_decay(self): 142 | return {'absolute_pos_embed'} 143 | 144 | @torch.jit.ignore 145 | def no_weight_decay_keywords(self): 146 | return {'relative_position_bias_table'} 147 | 148 | 149 | class ConvDropoutNormNonlin(nn.Module): 150 | def __init__(self, input_channels, output_channels, 151 | conv_op=nn.Conv3d, conv_kwargs=None, 152 | norm_op=nn.BatchNorm3d, norm_op_kwargs=None, 153 | dropout_op=nn.Dropout3d, dropout_op_kwargs=None, 154 | nonlin=nn.LeakyReLU, nonlin_kwargs=None): 155 | super(ConvDropoutNormNonlin, self).__init__() 156 | self.conv = conv_op(input_channels, output_channels, **conv_kwargs) 157 | 158 | if dropout_op is not None and dropout_op_kwargs['p'] is not None and dropout_op_kwargs[ 159 | 'p'] > 0: 160 | 161 | self.dropout = dropout_op(**dropout_op_kwargs) 162 | else: 163 | self.dropout = None 164 | self.instnorm = norm_op(output_channels, **norm_op_kwargs) 165 | if nonlin == nn.GELU: 166 | self.lrelu = nonlin() 167 | else: 168 | self.lrelu = nonlin(**nonlin_kwargs) 169 | 170 | def forward(self, x): 171 | x = self.conv(x) 172 | if self.dropout is not None: 173 | x = self.dropout(x) 174 | return self.lrelu(self.instnorm(x)) 175 | 176 | 177 | class DeepSupervision(nn.Module): 178 | def __init__(self, dim, num_classes): 179 | super().__init__() 180 | self.proj = nn.Conv3d( 181 | dim, num_classes, kernel_size=1, stride=1, bias=False) 182 | 183 | def forward(self, x): 184 | x = self.proj(x) 185 | return x 186 | 187 | 188 | class BasicLayer(nn.Module): 189 | 190 | def __init__(self, num_stage, only_conv, num_only_conv_stage, num_pool, base_num_features, dim, input_resolution, depth, num_heads, 191 | window_size, image_channels=1, num_conv_per_stage=2, conv_op=None, 192 | norm_op=None, norm_op_kwargs=None, 193 | dropout_op=None, dropout_op_kwargs=None, 194 | nonlin=None, nonlin_kwargs=None, 195 | conv_kernel_sizes=None, conv_pad_sizes=None, pool_op_kernel_sizes=None, basic_block=ConvDropoutNormNonlin, max_num_features=None, 196 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 197 | drop_path=0., norm_layer=nn.LayerNorm, down_or_upsample=None, feat_map_mul_on_downscale=2, is_encoder=True): 198 | 199 | super().__init__() 200 | self.num_stage = num_stage 201 | self.only_conv = only_conv 202 | self.num_only_conv_stage = num_only_conv_stage 203 | self.num_pool = num_pool 204 | self.is_encoder = is_encoder 205 | self.image_channels = image_channels 206 | conv_kwargs = {'stride': 1, 'dilation': 1, 'bias': True} 207 | 208 | if is_encoder: 209 | input_features = dim 210 | else: 211 | input_features = 2*dim 212 | 213 | # self.depth = depth 214 | conv_kwargs['kernel_size'] = conv_kernel_sizes[num_stage] 215 | conv_kwargs['padding'] = conv_pad_sizes[num_stage] 216 | 217 | input_du_channels = min(int(base_num_features * feat_map_mul_on_downscale ** (num_stage-1 if is_encoder else num_stage+1)), 218 | max_num_features) 219 | output_du_channels = dim 220 | if self.is_encoder and self.num_stage == 0: 221 | self.frist_conv = conv_op( 222 | image_channels, dim, kernel_size=1, stride=1, bias=True) 223 | else: 224 | self.frist_conv = None 225 | self.conv_blocks = nn.Sequential( 226 | *([basic_block(input_features, dim, conv_op, 227 | conv_kwargs, 228 | norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, 229 | nonlin, nonlin_kwargs)] + 230 | [basic_block(dim, dim, conv_op, 231 | conv_kwargs, 232 | norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, 233 | nonlin, nonlin_kwargs) for _ in range(num_conv_per_stage - 1)])) 234 | 235 | # build blocks 236 | if num_stage >= num_only_conv_stage and not only_conv: 237 | self.swin_blocks = nn.ModuleList([ 238 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 239 | num_heads=num_heads, window_size=window_size, 240 | shift_size=[0, 0, 0] if (i % 2 == 0) else [ 241 | window_size[0] // 2, window_size[1] // 2, window_size[2] // 2], 242 | mlp_ratio=mlp_ratio, 243 | qkv_bias=qkv_bias, qk_scale=qk_scale, 244 | drop=drop, attn_drop=attn_drop, 245 | drop_path=drop_path[i] if isinstance( 246 | drop_path, list) else drop_path, 247 | norm_layer=norm_layer) 248 | for i in range(depth)]) 249 | 250 | # patch merging layer 251 | if down_or_upsample is not None: 252 | dowm_stage = num_stage-1 if is_encoder else num_stage 253 | self.down_or_upsample = nn.Sequential(down_or_upsample(input_du_channels, output_du_channels, pool_op_kernel_sizes[dowm_stage], 254 | pool_op_kernel_sizes[dowm_stage], bias=False), 255 | norm_op( 256 | output_du_channels, **norm_op_kwargs) 257 | ) 258 | else: 259 | self.down_or_upsample = None 260 | 261 | def forward(self, x, skip): 262 | if self.frist_conv is not None: 263 | x = self.frist_conv(x) 264 | if self.down_or_upsample is not None: 265 | x = self.down_or_upsample(x) 266 | s = x 267 | if not self.is_encoder and self.num_stage < self.num_pool: 268 | x = torch.cat((x, skip), dim=1) 269 | x = self.conv_blocks(x) 270 | if self.num_stage >= self.num_only_conv_stage and not self.only_conv: 271 | if not self.is_encoder and self.num_stage < self.num_pool: 272 | s = s + skip 273 | for tblk in self.swin_blocks: 274 | s = tblk(s) 275 | x = x + s 276 | return x 277 | --------------------------------------------------------------------------------