├── models
│   ├── __init__.py
│   ├── build.py
│   ├── swin_3D.py
│   └── phtrans.py
├── img
│   └── method.png
├── data
│   ├── __init__.py
│   ├── dataset_val.py
│   ├── build.py
│   ├── data_augmentation.py
│   ├── utils.py
│   └── dataset_train.py
├── .gitignore
├── predict.sh
├── utils.py
├── requirements.txt
├── optimizer.py
├── README.md
├── lr_scheduler.py
├── metrics.py
├── coarse_train.py
├── fine_train.py
├── data_preprocess.py
├── unlabel_data_preprocess.py
├── trainer.py
├── losses.py
├── config.py
├── predict.py
└── LICENSE
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_coarse_model, build_fine_model
--------------------------------------------------------------------------------
/img/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lseventeen/FLARE22-TwoStagePHTrans/HEAD/img/method.png
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_loader, DataLoaderX
2 | from .dataset_val import predict_dataset
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | wandb
2 | *egg-info
3 | .vscode
4 | *__pycache__*
5 | save*
6 | val*
7 | *image
--------------------------------------------------------------------------------
/predict.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -e
2 | python /home/lwt/code/flare/FLARE22-TwoStagePHTrans/predict.py -dp '/workspace/inputs/' -op '/workspace/outputs/'
3 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import random
4 | import os
5 | from batchgenerators.utilities.file_and_folder_operations import *
6 |
7 |
8 | def seed_torch(seed=42):
9 | random.seed(seed)
10 | os.environ['PYTHONHASHSEED'] = str(seed)
11 | np.random.seed(seed)
12 | torch.manual_seed(seed)
13 | torch.cuda.manual_seed(seed)
14 | torch.backends.cudnn.deterministic = True
15 |
16 |
17 | def to_cuda(data, non_blocking=True):
18 | if isinstance(data, list):
19 | data = [i.cuda(non_blocking=non_blocking) for i in data]
20 | else:
21 | data = data.cuda(non_blocking=non_blocking)
22 | return data
23 |
24 |
25 | def load_checkpoint(checkpoint_path):
26 | checkpoint_file = "final_checkpoint.pth"
27 | checkpoint = torch.load(
28 | join(checkpoint_path, checkpoint_file), map_location=torch.device('cpu'))
29 | return checkpoint
30 |
--------------------------------------------------------------------------------
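
A minimal usage sketch for the helpers above (the checkpoint directory `save/EXPERIMENT_ID` is hypothetical; `load_checkpoint` only assumes it contains the `final_checkpoint.pth` written by `Trainer._save_checkpoint` in trainer.py):

```
import torch
from utils import seed_torch, to_cuda, load_checkpoint

seed_torch(42)  # seeds python, numpy and torch, and makes cudnn deterministic

# to_cuda moves a tensor, or a list of tensors, onto the current GPU
batch = [torch.randn(2, 1, 16, 16, 16), torch.randint(0, 2, (2, 1, 16, 16, 16))]
if torch.cuda.is_available():
    batch = to_cuda(batch)

# load_checkpoint expects a directory containing final_checkpoint.pth:
# checkpoint = load_checkpoint("save/EXPERIMENT_ID")
# model.load_state_dict(checkpoint["state_dict"])
```
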
/requirements.txt:
--------------------------------------------------------------------------------
 1 | batchgenerators==0.23
 2 | brotlipy==0.7.0
 3 | bunch==1.0.1
 4 | certifi==2021.10.8
 5 | click==8.1.3
 6 | connected-components-3d==3.10.0
 7 | docker-pycreds==0.4.0
 8 | einops==0.4.1
 9 | fastremap==1.13.0
10 | future==0.18.2
11 | gitdb==4.0.9
12 | GitPython==3.1.27
13 | imageio==2.17.0
14 | joblib==1.1.0
15 | linecache2==1.0.0
16 | lmdb==1.3.0
17 | loguru==0.6.0
18 | mkl-fft==1.3.1
19 | mkl-service==2.4.0
20 | networkx==2.8
21 | nibabel==3.2.2
22 | numpy==1.22.3
23 | opencv-python==4.5.5.64
24 | packaging==21.3
25 | pathtools==0.1.2
26 | Pillow==9.1.0
27 | prefetch-generator==1.0.1
28 | promise==2.3
29 | protobuf==3.20.1
30 | psutil==5.9.0
31 | pyparsing==3.0.8
32 | python-dateutil==2.8.2
33 | PyWavelets==1.3.0
34 | PyYAML==6.0
35 | ruamel.yaml==0.17.21
36 | ruamel.yaml.clib==0.2.6
37 | scikit-image==0.19.2
38 | scikit-learn==1.0.2
39 | scipy==1.8.0
40 | sentry-sdk==1.5.10
41 | setproctitle==1.2.3
42 | shortuuid==1.0.8
43 | SimpleITK==2.0.2
44 | smmap==5.0.0
45 | threadpoolctl==3.1.0
46 | tifffile==2022.4.22
47 | timm==0.5.4
48 | tqdm==4.64.0
49 | traceback2==1.4.0
50 | unittest2==1.1.0
51 | wandb==0.12.15
52 | yacs==0.1.8
--------------------------------------------------------------------------------
/data/dataset_val.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import Dataset
3 | from batchgenerators.utilities.file_and_folder_operations import *
4 | from .utils import load_data,change_axes_of_image
5 |
6 | class predict_dataset(Dataset):
7 | def __init__(self, config):
8 | super(predict_dataset, self).__init__()
9 | self.config = config
10 | self.data_path = config.DATASET.VAL_IMAGE_PATH
11 |
12 | self.is_nor_dir = self.config.DATASET.IS_NORMALIZATION_DIRECTION
13 |
14 | self.series_ids = subfiles(self.data_path, join=False, suffix='gz')
15 | def __len__(self):
16 | return len(self.series_ids)
17 |
18 | def __getitem__(self, idx):
19 | image_id = self.series_ids[idx].split("_")[1]
20 |         raw_image, image_spacing, image_direction = load_data(join(self.data_path, self.series_ids[idx]))
21 | if self.is_nor_dir:
22 | raw_image = change_axes_of_image(raw_image, image_direction)
23 | return {'image_id': image_id,
24 | 'raw_image': np.ascontiguousarray(raw_image),
25 | 'raw_spacing': image_spacing,
26 | 'image_direction': image_direction
27 | }
28 |
29 |
--------------------------------------------------------------------------------
/optimizer.py:
--------------------------------------------------------------------------------
1 | from torch import optim as optim
2 |
3 | def build_optimizer(config, model):
4 | """
5 | Build optimizer, set weight decay of normalization to 0 by default.
6 | """
7 | skip = {}
8 | skip_keywords = {}
9 | if hasattr(model, 'no_weight_decay'):
10 | skip = model.no_weight_decay()
11 | if hasattr(model, 'no_weight_decay_keywords'):
12 | skip_keywords = model.no_weight_decay_keywords()
13 | parameters = set_weight_decay(model, skip, skip_keywords)
14 |
15 | opt_lower = config.TRAIN.OPTIMIZER.NAME.lower()
16 | optimizer = None
17 | if opt_lower == 'sgd':
18 | optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True,
19 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
20 | elif opt_lower == 'adamw':
21 | optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
22 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
23 |
24 | return optimizer
25 |
26 |
27 | def set_weight_decay(model, skip_list=(), skip_keywords=()):
28 | has_decay = []
29 | no_decay = []
30 |
31 | for name, param in model.named_parameters():
32 | if not param.requires_grad:
33 | continue
34 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \
35 | check_keywords_in_name(name, skip_keywords):
36 | no_decay.append(param)
37 | else:
38 | has_decay.append(param)
39 | return [{'params': has_decay},
40 | {'params': no_decay, 'weight_decay': 0.}]
41 |
42 |
43 | def check_keywords_in_name(name, keywords=()):
44 | isin = False
45 | for keyword in keywords:
46 | if keyword in name:
47 | isin = True
48 | return isin
49 |
--------------------------------------------------------------------------------
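
A quick sanity check of `set_weight_decay` (a sketch, not part of the repo): biases and 1-D parameters such as normalization weights land in the zero-weight-decay group, while everything else keeps the optimizer's default decay.

```
import torch
from torch import nn
from optimizer import set_weight_decay

model = nn.Sequential(
    nn.Conv3d(1, 8, kernel_size=3),  # 5-D weight -> decayed
    nn.BatchNorm3d(8),               # 1-D weight/bias -> not decayed
)
groups = set_weight_decay(model)
assert all(p.ndim > 1 for p in groups[0]["params"])   # has_decay group
assert all(p.ndim == 1 for p in groups[1]["params"])  # no_decay group

optimizer = torch.optim.AdamW(groups, lr=1e-3, weight_decay=5e-2)
```
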
/README.md:
--------------------------------------------------------------------------------
1 | # Two-stage PHTrans
2 |
3 | This repository is a solution for the [MICCAI FLARE2022 challenge](https://flare22.grand-challenge.org/). A detailed description of the method, the experiments, and an analysis of the results is presented in the paper [Combining Self-Training and Hybrid Architecture for Semi-supervised Abdominal Organ Segmentation](https://arxiv.org/abs/2207.11512). As shown in the figure below, the pipeline consists of two parts: (a) pseudo-label generation for unlabeled data, implemented with PHTrans under the nnU-Net framework (for more information, see [PHTrans](https://github.com/lseventeen/PHTrans)); and (b) a two-stage segmentation framework with a lightweight PHTrans, which is the part this repository implements.
4 |
5 |
6 | ![method](https://raw.githubusercontent.com/lseventeen/FLARE22-TwoStagePHTrans/HEAD/img/method.png)
7 |
8 |
9 |
10 | ## Prerequisites
11 |
12 |
13 |
14 | Download our repo and install packages:
15 | ```
16 | git clone https://github.com/lseventeen/FLARE22-TwoStagePHTrans
17 | cd FLARE22-TwoStagePHTrans
18 | pip install -r requirements.txt
19 | ```
20 |
21 |
22 | ## Datasets processing
23 | Download the [FLARE 2022](https://flare22.grand-challenge.org/Dataset/) datasets. Generate pseudo-labels for the unlabeled data with the [PHTrans](https://github.com/lseventeen/PHTrans) repository. Modify the data paths in [config.py](https://github.com/lseventeen/FLARE22-TwoStagePHTrans/blob/master/config.py), then run the dataset preprocessing (`unlabel_data_preprocess.py` handles the pseudo-labeled unlabeled data in the same way):
24 |
25 | ```
26 | python data_preprocess.py
27 | ```
28 |
29 | ## Training
30 | Type this in the terminal to run coarse segmentation training:
31 |
32 | ```
33 | python coarse_train.py
34 | ```
35 | Type this in the terminal to run fine segmentation training:
36 |
37 | ```
38 | python fine_train.py
39 | ```
40 | ## Inference
41 | Type this in the terminal to run inference:
42 |
43 | ```
44 | python predict.py -dp DATA_PATH -op SAVE_RESULTS_PATH
45 | ```
46 |
--------------------------------------------------------------------------------
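
To make the two-stage design described above concrete, the sketch below outlines the inference flow (an illustration under simplifying assumptions, not the repo's `predict.py`: per-stage intensity renormalization is omitted, the coarse mask is assumed non-empty, and `coarse_model`/`fine_model` are assumed to return `(1, C, D, H, W)` logits):

```
import torch
import torch.nn.functional as F
from data.utils import get_bbox_from_mask, crop_image_according_to_bbox

@torch.no_grad()
def two_stage_segment(image, coarse_model, fine_model,
                      coarse_size=(96, 96, 96), fine_size=(192, 192, 192)):
    """image: normalized 3-D torch tensor (D, H, W) on CPU."""
    x = image[None, None]  # -> (1, 1, D, H, W)

    # Stage 1: binary organ mask at a small fixed size, upsampled back
    xc = F.interpolate(x, size=coarse_size, mode='trilinear', align_corners=False)
    coarse = coarse_model(xc).argmax(1, keepdim=True).float()
    coarse = F.interpolate(coarse, size=image.shape, mode='nearest')[0, 0]

    # Stage 2: crop the ROI found by stage 1 and segment organs inside it
    bbox = get_bbox_from_mask(coarse.numpy())
    roi, ext = crop_image_according_to_bbox(image.numpy(), bbox)
    xf = torch.from_numpy(roi)[None, None].float()
    xf = F.interpolate(xf, size=fine_size, mode='trilinear', align_corners=False)
    fine = fine_model(xf).argmax(1, keepdim=True).float()

    # Paste the fine prediction back into a full-size volume
    roi_shape = tuple(b[1] - b[0] for b in ext)
    fine = F.interpolate(fine, size=roi_shape, mode='nearest')[0, 0]
    out = torch.zeros_like(image)
    out[ext[0][0]:ext[0][1], ext[1][0]:ext[1][1], ext[2][0]:ext[2][1]] = fine
    return out
```
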
/data/build.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | from batchgenerators.utilities.file_and_folder_operations import *
3 | from data.dataset_train import flare22_dataset
4 | from sklearn.model_selection import train_test_split
5 | from prefetch_generator import BackgroundGenerator
6 | import torch
7 | class DataLoaderX(DataLoader):
8 | def __iter__(self):
9 | return BackgroundGenerator(super().__iter__())
10 |
11 | def build_loader(config, data_size, data_path, unlab_data_path, pool_op_kernel_sizes, num_each_epoch):
12 |     series_ids_train = subfiles(data_path, join=False, suffix='npz')
13 | 
14 |     if config.DATASET.WITH_VAL:
15 | 
16 |         series_ids_train, series_ids_val = train_test_split(series_ids_train, test_size=config.DATASET.VAL_SPLIT, random_state=42)
17 |         val_dataset = flare22_dataset(config, series_ids_val, data_size, data_path, pool_op_kernel_sizes, is_train=False)
18 |         val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) if config.DIS else None
19 |         val_loader = DataLoaderX(
20 |             dataset=val_dataset,
21 |             sampler=val_sampler,
22 |             batch_size=config.DATALOADER.BATCH_SIZE,
23 |             num_workers=config.DATALOADER.NUM_WORKERS,
24 |             pin_memory=config.DATALOADER.PIN_MEMORY,
25 |             shuffle=False,
26 |             drop_last=False
27 |         )
28 |     else:
29 |         val_loader = None
30 | 
31 | 
32 |     train_dataset = flare22_dataset(config, data_size, data_path, unlab_data_path, pool_op_kernel_sizes, num_each_epoch, is_train=True)
33 |     train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True) if config.DIS else None
34 |     train_loader = DataLoaderX(
35 |         train_dataset,
36 |         sampler=train_sampler,
37 |         batch_size=config.DATALOADER.BATCH_SIZE,
38 |         num_workers=config.DATALOADER.NUM_WORKERS,
39 |         pin_memory=config.DATALOADER.PIN_MEMORY,
40 |         shuffle=True if train_sampler is None else False,
41 |         drop_last=True
42 |     )
43 |     return train_loader, val_loader
44 |
45 |
46 |
47 |
48 |
--------------------------------------------------------------------------------
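
`DataLoaderX` wraps the next-batch fetch in a `BackgroundGenerator` thread so host-side loading overlaps GPU compute. A standalone demo with a toy dataset (not `flare22_dataset`):

```
import torch
from torch.utils.data import TensorDataset
from data.build import DataLoaderX

dataset = TensorDataset(torch.randn(32, 1, 8, 8, 8), torch.zeros(32))
loader = DataLoaderX(dataset, batch_size=4, num_workers=2, shuffle=True)
for images, labels in loader:  # batches are prefetched in a background thread
    print(images.shape)        # torch.Size([4, 1, 8, 8, 8])
    break
```
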
/models/build.py:
--------------------------------------------------------------------------------
1 | from .phtrans import PHTrans
2 |
3 |
4 | def build_coarse_model(config, is_VAL = False):
5 | if config.MODEL.COARSE.TYPE == 'phtrans':
6 | model = PHTrans(
7 | img_size = config.DATASET.COARSE.SIZE,
8 | base_num_features = config.MODEL.COARSE.BASE_NUM_FEATURES,
9 | num_classes = config.DATASET.COARSE.LABEL_CLASSES,
10 | num_only_conv_stage = config.MODEL.COARSE.NUM_ONLY_CONV_STAGE,
11 | num_conv_per_stage = config.MODEL.COARSE.NUM_CONV_PER_STAGE,
12 | feat_map_mul_on_downscale = config.MODEL.COARSE.FEAT_MAP_MUL_ON_DOWNSCALE,
13 | pool_op_kernel_sizes = config.MODEL.COARSE.POOL_OP_KERNEL_SIZES,
14 | conv_kernel_sizes = config.MODEL.COARSE.CONV_KERNEL_SIZES,
15 | dropout_p = config.MODEL.COARSE.DROPOUT_P,
16 | deep_supervision = config.MODEL.DEEP_SUPERVISION if not is_VAL else False,
17 | max_num_features = config.MODEL.COARSE.MAX_NUM_FEATURES,
18 | depths = config.MODEL.COARSE.DEPTHS,
19 | num_heads = config.MODEL.COARSE.NUM_HEADS,
20 | window_size = config.MODEL.COARSE.WINDOW_SIZE,
21 | mlp_ratio = config.MODEL.COARSE.MLP_RATIO,
22 |             qkv_bias = config.MODEL.COARSE.QKV_BIAS,
23 | qk_scale = config.MODEL.COARSE.QK_SCALE,
24 | drop_rate = config.MODEL.COARSE.DROP_RATE,
25 | drop_path_rate = config.MODEL.COARSE.DROP_PATH_RATE,
26 | )
27 | else:
28 | raise NotImplementedError(f"Unkown model: {config.MODEL.COARSE.TYPE}")
29 |
30 | return model
31 |
32 |
33 |
34 | def build_fine_model(config, is_VAL = False):
35 | if config.MODEL.FINE.TYPE == 'phtrans':
36 | model = PHTrans(
37 | img_size = config.DATASET.FINE.SIZE,
38 | base_num_features = config.MODEL.FINE.BASE_NUM_FEATURES,
39 | num_classes = config.DATASET.FINE.LABEL_CLASSES,
40 | num_only_conv_stage = config.MODEL.FINE.NUM_ONLY_CONV_STAGE,
41 | num_conv_per_stage = config.MODEL.FINE.NUM_CONV_PER_STAGE,
42 | feat_map_mul_on_downscale = config.MODEL.FINE.FEAT_MAP_MUL_ON_DOWNSCALE,
43 | pool_op_kernel_sizes = config.MODEL.FINE.POOL_OP_KERNEL_SIZES,
44 | conv_kernel_sizes = config.MODEL.FINE.CONV_KERNEL_SIZES,
45 | dropout_p = config.MODEL.FINE.DROPOUT_P,
46 | deep_supervision = config.MODEL.DEEP_SUPERVISION if not is_VAL else False,
47 | max_num_features = config.MODEL.FINE.MAX_NUM_FEATURES,
48 | depths = config.MODEL.FINE.DEPTHS,
49 | num_heads = config.MODEL.FINE.NUM_HEADS,
50 | window_size = config.MODEL.FINE.WINDOW_SIZE,
51 | mlp_ratio = config.MODEL.FINE.MLP_RATIO,
52 |             qkv_bias = config.MODEL.FINE.QKV_BIAS,
53 | qk_scale = config.MODEL.FINE.QK_SCALE,
54 | drop_rate = config.MODEL.FINE.DROP_RATE,
55 | drop_path_rate = config.MODEL.FINE.DROP_PATH_RATE,
56 | )
57 | else:
58 | raise NotImplementedError(f"Unkown model: {config.MODEL.FINE.TYPE}")
59 |
60 | return model
61 |
62 |
63 |
64 |
--------------------------------------------------------------------------------
/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from timm.scheduler.cosine_lr import CosineLRScheduler
3 | from timm.scheduler.step_lr import StepLRScheduler
4 | from timm.scheduler.scheduler import Scheduler
5 |
6 |
7 | def build_scheduler(config, optimizer, n_iter_per_epoch):
8 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch)
9 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)
10 | decay_steps = int(
11 | config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch)
12 |
13 | lr_scheduler = None
14 | if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':
15 | lr_scheduler = CosineLRScheduler(
16 | optimizer,
17 | t_initial=num_steps,
18 | cycle_mul=1.,
19 | lr_min=config.TRAIN.MIN_LR,
20 | warmup_lr_init=config.TRAIN.WARMUP_LR,
21 | warmup_t=warmup_steps,
22 | cycle_limit=1,
23 | t_in_epochs=False,
24 | )
25 | elif config.TRAIN.LR_SCHEDULER.NAME == 'linear':
26 | lr_scheduler = LinearLRScheduler(
27 | optimizer,
28 | t_initial=num_steps,
29 | lr_min_rate=0.01,
30 | warmup_lr_init=config.TRAIN.WARMUP_LR,
31 | warmup_t=warmup_steps,
32 | t_in_epochs=False,
33 | )
34 | elif config.TRAIN.LR_SCHEDULER.NAME == 'step':
35 | lr_scheduler = StepLRScheduler(
36 | optimizer,
37 | decay_t=decay_steps,
38 | decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
39 | warmup_lr_init=config.TRAIN.WARMUP_LR,
40 | warmup_t=warmup_steps,
41 | t_in_epochs=False,
42 | )
43 |
44 | return lr_scheduler
45 |
46 |
47 | class LinearLRScheduler(Scheduler):
48 | def __init__(self,
49 | optimizer: torch.optim.Optimizer,
50 | t_initial: int,
51 | lr_min_rate: float,
52 | warmup_t=0,
53 | warmup_lr_init=0.,
54 | t_in_epochs=True,
55 | noise_range_t=None,
56 | noise_pct=0.67,
57 | noise_std=1.0,
58 | noise_seed=42,
59 | initialize=True,
60 | ) -> None:
61 | super().__init__(
62 | optimizer, param_group_field="lr",
63 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
64 | initialize=initialize)
65 |
66 | self.t_initial = t_initial
67 | self.lr_min_rate = lr_min_rate
68 | self.warmup_t = warmup_t
69 | self.warmup_lr_init = warmup_lr_init
70 | self.t_in_epochs = t_in_epochs
71 | if self.warmup_t:
72 | self.warmup_steps = [(v - warmup_lr_init) /
73 | self.warmup_t for v in self.base_values]
74 | super().update_groups(self.warmup_lr_init)
75 | else:
76 | self.warmup_steps = [1 for _ in self.base_values]
77 |
78 | def _get_lr(self, t):
79 | if t < self.warmup_t:
80 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
81 | else:
82 | t = t - self.warmup_t
83 | total_t = self.t_initial - self.warmup_t
84 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t))
85 | for v in self.base_values]
86 | return lrs
87 |
88 | def get_epoch_values(self, epoch: int):
89 | if self.t_in_epochs:
90 | return self._get_lr(epoch)
91 | else:
92 | return None
93 |
94 | def get_update_values(self, num_updates: int):
95 | if not self.t_in_epochs:
96 | return self._get_lr(num_updates)
97 | else:
98 | return None
99 |
--------------------------------------------------------------------------------
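
A quick numerical check of `LinearLRScheduler` with made-up hyperparameters: the learning rate ramps from `warmup_lr_init` to the base LR over `warmup_t` update steps, then decays linearly to `lr_min_rate` times the base LR at `t_initial`:

```
import torch
from lr_scheduler import LinearLRScheduler

model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
sched = LinearLRScheduler(opt, t_initial=100, lr_min_rate=0.01,
                          warmup_t=10, warmup_lr_init=1e-4, t_in_epochs=False)

for t in (0, 10, 55, 100):
    sched.step_update(t)                 # t counts optimizer updates
    print(t, opt.param_groups[0]["lr"])
# t=0  -> 1e-4 (warmup start), t=10 -> 1e-2 (base LR),
# t=55 -> ~5.05e-3 (halfway),  t=100 -> 1e-4 (lr_min_rate * base LR)
```
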
/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | class AverageMeter(object):
6 | def __init__(self):
7 | self.initialized = False
8 | self.val = None
9 | self.avg = None
10 | self.sum = None
11 | self.count = None
12 |
13 | def initialize(self, val, weight):
14 | self.val = val
15 | self.avg = val
16 | self.sum = np.multiply(val, weight)
17 | self.count = weight
18 | self.initialized = True
19 |
20 | def update(self, val, weight=1):
21 | if not self.initialized:
22 | self.initialize(val, weight)
23 | else:
24 | self.add(val, weight)
25 |
26 | def add(self, val, weight):
27 | self.val = val
28 | self.sum = np.add(self.sum, np.multiply(val, weight))
29 | self.count = self.count + weight
30 | self.avg = self.sum / self.count
31 |
32 | @property
33 | def value(self):
34 | return np.round(self.val, 4)
35 |
36 | @property
37 | def average(self):
38 | return np.round(self.avg, 4)
39 |
40 |
41 | def run_online_evaluation(output, target):
42 | if isinstance(output, list):
43 | output = output[0]
44 | if isinstance(target, list):
45 | target = target[0]
46 | online_eval_foreground_dc = []
47 | online_eval_tp = []
48 | online_eval_fp = []
49 | online_eval_fn = []
50 | with torch.no_grad():
51 | num_classes = output.shape[1]
52 | output_softmax = F.softmax(output, 1)
53 | output_seg = output_softmax.argmax(1)
54 | target = target[:, 0]
55 | axes = tuple(range(1, len(target.shape)))
56 | tp_hard = torch.zeros(
57 | (target.shape[0], num_classes - 1)).to(output_seg.device.index)
58 | fp_hard = torch.zeros(
59 | (target.shape[0], num_classes - 1)).to(output_seg.device.index)
60 | fn_hard = torch.zeros(
61 | (target.shape[0], num_classes - 1)).to(output_seg.device.index)
62 | for c in range(1, num_classes):
63 | tp_hard[:, c - 1] = sum_tensor(
64 | (output_seg == c).float() * (target == c).float(), axes=axes)
65 | fp_hard[:, c - 1] = sum_tensor(
66 | (output_seg == c).float() * (target != c).float(), axes=axes)
67 | fn_hard[:, c - 1] = sum_tensor(
68 | (output_seg != c).float() * (target == c).float(), axes=axes)
69 |
70 | tp_hard = tp_hard.sum(0, keepdim=False).detach().cpu().numpy()
71 | fp_hard = fp_hard.sum(0, keepdim=False).detach().cpu().numpy()
72 | fn_hard = fn_hard.sum(0, keepdim=False).detach().cpu().numpy()
73 |
74 | online_eval_foreground_dc.append(
75 | list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
76 | online_eval_tp.append(list(tp_hard))
77 | online_eval_fp.append(list(fp_hard))
78 | online_eval_fn.append(list(fn_hard))
79 |
80 | online_eval_tp = np.sum(online_eval_tp, 0)
81 | online_eval_fp = np.sum(online_eval_fp, 0)
82 | online_eval_fn = np.sum(online_eval_fn, 0)
83 |
84 | global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in
85 | zip(online_eval_tp, online_eval_fp, online_eval_fn)]
86 | if not np.isnan(i)]
87 | average_global_dc = np.mean(global_dc_per_class)
88 | return average_global_dc
89 |
90 |
91 | def sum_tensor(inp, axes, keepdim=False):
92 | axes = np.unique(axes).astype(int)
93 | if keepdim:
94 | for ax in axes:
95 | inp = inp.sum(int(ax), keepdim=True)
96 | else:
97 | for ax in sorted(axes, reverse=True):
98 | inp = inp.sum(int(ax))
99 | return inp
100 |
--------------------------------------------------------------------------------
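
`run_online_evaluation` returns the mean foreground Dice, 2·TP / (2·TP + FP + FN), over classes 1…C−1. A toy check with a perfect prediction (GPU-guarded, since the buffers are placed via `output_seg.device.index`):

```
import torch
from metrics import run_online_evaluation

if torch.cuda.is_available():  # buffers are placed on the prediction's GPU
    target = (torch.arange(128).view(2, 1, 4, 4, 4) % 3).cuda()  # (B, 1, D, H, W)
    logits = torch.zeros(2, 3, 4, 4, 4).cuda()                   # (B, C, D, H, W)
    logits.scatter_(1, target, 10.0)    # make the prediction match the target
    print(run_online_evaluation(logits, target))                 # ~1.0
```
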
/coarse_train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from loguru import logger
3 | from data import build_loader
4 | from trainer import Trainer
5 | from utils import seed_torch
6 | from losses import build_loss
7 | from datetime import datetime
8 | import wandb
9 | from config import get_config
10 | from models import build_coarse_model
11 | from lr_scheduler import build_scheduler
12 | from optimizer import build_optimizer
13 | import os
14 | import torch.backends.cudnn as cudnn
15 | import numpy as np
16 | import torch
17 | import torch.multiprocessing as mp
18 | import torch.distributed as dist
19 |
20 |
21 | def parse_option():
22 | parser = argparse.ArgumentParser("FLARE2022_corase_training")
23 | parser.add_argument('--cfg', type=str, metavar="FILE",
24 | help='path to config file')
25 | parser.add_argument(
26 | "--opts",
27 | help="Modify config options by adding 'KEY VALUE' pairs. ",
28 | default=None,
29 | nargs='+',
30 | )
31 | parser.add_argument("--tag", help='tag of experiment')
32 | parser.add_argument("-wm", "--wandb_mode", default="offline")
33 | parser.add_argument('-bs', '--batch-size', type=int,
34 | help="batch size for single GPU")
35 | parser.add_argument('-wd', '--with_distributed', help="training without DDP",
36 | required=False, default=False, action="store_true")
37 | parser.add_argument('-ws', '--world_size', type=int,
38 | help="process number for DDP")
39 | args = parser.parse_args()
40 | config = get_config(args)
41 |
42 | return args, config
43 |
44 |
45 | def main(config):
46 | if config.DIS:
47 | mp.spawn(main_worker,
48 | args=(config,),
49 | nprocs=config.WORLD_SIZE,)
50 | else:
51 | main_worker(0, config)
52 |
53 |
54 | def main_worker(local_rank, config):
55 | if local_rank == 0:
56 | config.defrost()
57 | config.EXPERIMENT_ID = f"{config.WANDB.TAG}_{datetime.now().strftime('%y%m%d_%H%M%S')}"
58 | config.freeze()
59 | wandb.init(project=config.WANDB.COARSE_PROJECT,
60 | name=config.EXPERIMENT_ID, config=config, mode=config.WANDB.MODE)
61 | np.set_printoptions(formatter={'float': '{: 0.4f}'.format}, suppress=True)
62 | torch.cuda.set_device(local_rank)
63 | if config.DIS:
64 | dist.init_process_group(
65 | "nccl", init_method='env://', rank=local_rank, world_size=config.WORLD_SIZE)
66 | seed = config.SEED + local_rank
67 | seed_torch(seed)
68 | cudnn.benchmark = True
69 |
70 | train_loader, val_loader = build_loader(config,
71 | config.DATASET.COARSE.SIZE,
72 | config.DATASET.COARSE.PROPRECESS_PATH,
73 | config.DATASET.COARSE.PROPRECESS_UL_PATH,
74 | config.MODEL.COARSE.POOL_OP_KERNEL_SIZES,
75 | config.DATASET.COARSE.NUM_EACH_EPOCH
76 | )
77 | model = build_coarse_model(config).cuda()
78 | if config.DIS:
79 | model = torch.nn.parallel.DistributedDataParallel(
80 | model, device_ids=[local_rank], find_unused_parameters=True)
81 | logger.info(f'\n{model}\n')
82 | loss = build_loss(config.MODEL.DEEP_SUPERVISION,
83 | config.MODEL.COARSE.POOL_OP_KERNEL_SIZES)
84 | optimizer = build_optimizer(config, model)
85 | lr_scheduler = build_scheduler(config, optimizer, len(train_loader))
86 | trainer = Trainer(config=config,
87 | train_loader=train_loader,
88 | val_loader=val_loader,
89 | model=model,
90 | loss=loss,
91 | optimizer=optimizer,
92 | lr_scheduler=lr_scheduler)
93 | trainer.train()
94 |
95 |
96 | if __name__ == '__main__':
97 | os.environ["MASTER_ADDR"] = "localhost"
98 | os.environ["MASTER_PORT"] = "10000"
99 | _, config = parse_option()
100 |
101 | main(config)
102 |
--------------------------------------------------------------------------------
/fine_train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from loguru import logger
3 | from data import build_loader
4 | from trainer import Trainer
5 | from utils import seed_torch
6 | from losses import build_loss
7 | from datetime import datetime
8 | import wandb
9 | from config import get_config
10 | from models import build_fine_model
11 | from lr_scheduler import build_scheduler
12 | from optimizer import build_optimizer
13 | import os
14 | import torch.backends.cudnn as cudnn
15 | import numpy as np
16 | import torch
17 | import torch.multiprocessing as mp
18 | import torch.distributed as dist
19 | from batchgenerators.utilities.file_and_folder_operations import *
20 |
21 |
22 | def parse_option():
23 | parser = argparse.ArgumentParser("FLARE2022_fine_training")
24 | parser.add_argument('--cfg', type=str, metavar="FILE",
25 | help='path to config file')
26 | parser.add_argument(
27 | "--opts",
28 | help="Modify config options by adding 'KEY VALUE' pairs. ",
29 | default=None,
30 | nargs='+',
31 | )
32 | parser.add_argument("--tag", help='tag of experiment')
33 | parser.add_argument("-wm", "--wandb_mode", default="offline")
34 | parser.add_argument('-bs', '--batch-size', type=int,
35 | help="batch size for single GPU")
36 | parser.add_argument('-wd', '--with_distributed', help="training without DDP",
37 | required=False, default=False, action="store_true")
38 | parser.add_argument('-ws', '--world_size', type=int,
39 | help="process number for DDP")
40 | args = parser.parse_args()
41 | config = get_config(args)
42 |
43 | return args, config
44 |
45 |
46 | def main(config):
47 | if config.DIS:
48 | mp.spawn(main_worker,
49 | args=(config,),
50 | nprocs=config.WORLD_SIZE,)
51 | else:
52 | main_worker(0, config)
53 |
54 |
55 | def main_worker(local_rank, config):
56 | if local_rank == 0:
57 | config.defrost()
58 | config.EXPERIMENT_ID = f"{config.WANDB.TAG}_{datetime.now().strftime('%y%m%d_%H%M%S')}"
59 | config.freeze()
60 | wandb.init(project=config.WANDB.FINE_PROJECT,
61 | name=config.EXPERIMENT_ID, config=config, mode=config.WANDB.MODE)
62 | np.set_printoptions(formatter={'float': '{: 0.4f}'.format}, suppress=True)
63 | torch.cuda.set_device(local_rank)
64 | if config.DIS:
65 | dist.init_process_group(
66 | "nccl", init_method='env://', rank=local_rank, world_size=config.WORLD_SIZE)
67 | seed = config.SEED + local_rank
68 | seed_torch(seed)
69 | cudnn.benchmark = True
70 |
71 | train_loader, val_loader = build_loader(config,
72 | config.DATASET.FINE.SIZE,
73 | config.DATASET.FINE.PROPRECESS_PATH,
74 | config.DATASET.FINE.PROPRECESS_UL_PATH,
75 | config.MODEL.FINE.POOL_OP_KERNEL_SIZES,
76 | config.DATASET.FINE.NUM_EACH_EPOCH
77 | )
78 | model = build_fine_model(config).cuda()
79 | if config.DIS:
80 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda()
81 | model = torch.nn.parallel.DistributedDataParallel(
82 | model, device_ids=[local_rank], find_unused_parameters=True)
83 | logger.info(f'\n{model}\n')
84 | loss = build_loss(config.MODEL.DEEP_SUPERVISION,
85 | config.MODEL.FINE.POOL_OP_KERNEL_SIZES)
86 | optimizer = build_optimizer(config, model)
87 | lr_scheduler = build_scheduler(config, optimizer, len(train_loader))
88 | trainer = Trainer(config=config,
89 | train_loader=train_loader,
90 | val_loader=val_loader,
91 | model=model,
92 | loss=loss,
93 | optimizer=optimizer,
94 | lr_scheduler=lr_scheduler)
95 | trainer.train()
96 |
97 |
98 | if __name__ == '__main__':
99 | os.environ["MASTER_ADDR"] = "localhost"
100 | os.environ["MASTER_PORT"] = "10000"
101 | _, config = parse_option()
102 |
103 | main(config)
104 |
--------------------------------------------------------------------------------
/data/data_augmentation.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import numpy as np
3 | from batchgenerators.augmentations.utils import resize_segmentation
4 | from batchgenerators.transforms.abstract_transforms import AbstractTransform
5 |
6 | default_3D_augmentation_params = {
7 |
8 | # "do_elastic": False,
9 | "elastic_deform_alpha": (0., 900.),
10 | "elastic_deform_sigma": (9., 13.),
11 | "p_eldef": 0.2,
12 |
13 | # "do_scaling": True,
14 | "scale_range": (0.85, 1.25),
15 | "independent_scale_factor_for_each_axis": False,
16 | "p_independent_scale_per_axis": 1,
17 | "p_scale": 0.2,
18 |
19 | # "do_rotation": True,
20 | "rotation_x": (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi),
21 | "rotation_y": (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi),
22 | "rotation_z": (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi),
23 | "rotation_p_per_axis": 1,
24 | "p_rot": 0.2,
25 |
26 | # "random_crop": False,
27 | "random_crop_dist_to_border": None,
28 |
29 | # "do_gamma": True,
30 | "gamma_retain_stats": True,
31 | "gamma_range": (0.7, 1.5),
32 | "p_gamma": 0.3,
33 |
34 | # "do_mirror": True,
35 | "mirror_axes": (0, 1, 2),
36 |
37 | "border_mode_data": "constant",
38 |
39 | # "do_additive_brightness": False,
40 | "additive_brightness_p_per_sample": 0.15,
41 | "additive_brightness_p_per_channel": 0.5,
42 | "additive_brightness_mu": 0.0,
43 | "additive_brightness_sigma": 0.1
44 | }
45 |
46 | default_2D_augmentation_params = deepcopy(default_3D_augmentation_params)
47 |
48 | default_2D_augmentation_params["elastic_deform_alpha"] = (0., 200.)
49 | default_2D_augmentation_params["elastic_deform_sigma"] = (9., 13.)
50 | default_2D_augmentation_params["rotation_x"] = (
51 | -180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi)
52 | default_2D_augmentation_params["rotation_y"] = (
53 | -0. / 360 * 2. * np.pi, 0. / 360 * 2. * np.pi)
54 | default_2D_augmentation_params["rotation_z"] = (
55 | -0. / 360 * 2. * np.pi, 0. / 360 * 2. * np.pi)
56 |
57 |
58 | def get_patch_size(final_patch_size, rot_x, rot_y, rot_z, scale_range):
59 | if isinstance(rot_x, (tuple, list)):
60 | rot_x = max(np.abs(rot_x))
61 | if isinstance(rot_y, (tuple, list)):
62 | rot_y = max(np.abs(rot_y))
63 | if isinstance(rot_z, (tuple, list)):
64 | rot_z = max(np.abs(rot_z))
65 | rot_x = min(90 / 360 * 2. * np.pi, rot_x)
66 | rot_y = min(90 / 360 * 2. * np.pi, rot_y)
67 | rot_z = min(90 / 360 * 2. * np.pi, rot_z)
68 | from batchgenerators.augmentations.utils import rotate_coords_3d, rotate_coords_2d
69 | coords = np.array(final_patch_size)
70 | final_shape = np.copy(coords)
71 | if len(coords) == 3:
72 | final_shape = np.max(
73 | np.vstack((np.abs(rotate_coords_3d(coords, rot_x, 0, 0)), final_shape)), 0)
74 | final_shape = np.max(
75 | np.vstack((np.abs(rotate_coords_3d(coords, 0, rot_y, 0)), final_shape)), 0)
76 | final_shape = np.max(
77 | np.vstack((np.abs(rotate_coords_3d(coords, 0, 0, rot_z)), final_shape)), 0)
78 | elif len(coords) == 2:
79 | final_shape = np.max(
80 | np.vstack((np.abs(rotate_coords_2d(coords, rot_x)), final_shape)), 0)
81 | final_shape /= min(scale_range)
82 | return final_shape.astype(int)
83 |
84 |
85 | class DownsampleSegForDSTransform(AbstractTransform):
86 |
87 | def __init__(self, ds_scales=(1, 0.5, 0.25), order=0, input_key="seg", output_key="seg", axes=None):
88 | self.axes = axes
89 | self.output_key = output_key
90 | self.input_key = input_key
91 | self.order = order
92 | self.ds_scales = ds_scales
93 |
94 | def __call__(self, **data_dict):
95 | data_dict[self.output_key] = downsample_seg_for_ds_transform(data_dict[self.input_key], self.ds_scales,
96 | self.order, self.axes)
97 | return data_dict
98 |
99 |
100 | def downsample_seg_for_ds_transform(seg, ds_scales=((1, 1, 1), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25)), order=0, axes=None):
101 | if axes is None:
102 | axes = list(range(1, len(seg.shape)))
103 | output = []
104 | for s in ds_scales:
105 | if all([i == 1 for i in s]):
106 | output.append(seg)
107 | else:
108 | new_shape = np.array(seg.shape).astype(float)
109 | for i, a in enumerate(axes):
110 | new_shape[a] *= s[i]
111 | new_shape = np.round(new_shape).astype(int)
112 | out_seg = np.zeros(new_shape, dtype=seg.dtype)
113 | for c in range(seg.shape[0]):
114 | out_seg[c] = resize_segmentation(seg[c], new_shape[1:], order)
115 | output.append(out_seg)
116 | return output
117 |
--------------------------------------------------------------------------------
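
`downsample_seg_for_ds_transform` yields one segmentation per deep-supervision resolution; a small shape check (the `ds_scales` values here are illustrative):

```
import numpy as np
from data.data_augmentation import downsample_seg_for_ds_transform

seg = np.zeros((1, 64, 64, 64), dtype=np.uint8)  # (C, D, H, W)
targets = downsample_seg_for_ds_transform(
    seg, ds_scales=((1, 1, 1), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25)))
print([t.shape for t in targets])
# [(1, 64, 64, 64), (1, 32, 32, 32), (1, 16, 16, 16)]
```
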
/data_preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from batchgenerators.utilities.file_and_folder_operations import *
4 | import shutil
5 | import traceback
6 | from multiprocessing import Pool, cpu_count
7 | from config import get_config_no_args
8 | from data.utils import load_data, clip_and_normalize_mean_std, resize_segmentation, change_axes_of_image, crop_image_according_to_mask, create_two_class_mask
9 | from collections import OrderedDict
10 | from skimage.transform import resize
11 |
12 | def run_prepare_data(config, is_overwrite, is_multiprocessing=True):
13 |
14 | data_prepare = data_process(config, is_overwrite)
15 | if is_multiprocessing:
16 | pool = Pool(int(cpu_count() * 0.2))
17 | for data in data_prepare.data_list:
18 | try:
19 | pool.apply_async(data_prepare.process, (data,))
20 | except Exception as err:
21 | traceback.print_exc()
22 |                 print('Creating image/label raised exception %s for series_id %s!' %
23 |                       (err, data))
24 |
25 | pool.close()
26 | pool.join()
27 | else:
28 | for data in data_prepare.data_list:
29 | data_prepare.process(data)
30 |
31 |
32 | class data_process(object):
33 | def __init__(self, config, is_overwrite=True):
34 | self.config = config
35 | self.coarse_size = self.config.DATASET.COARSE.SIZE
36 | self.fine_size = self.config.DATASET.FINE.SIZE
37 | self.nor_dir = self.config.DATASET.IS_NORMALIZATION_DIRECTION
38 | self.extend_size = self.config.DATASET.EXTEND_SIZE
39 |
40 | self.image_path = config.DATASET.TRAIN_IMAGE_PATH
41 | self.mask_path = config.DATASET.TRAIN_MASK_PATH
42 | self.preprocess_coarse_path = config.DATASET.COARSE.PROPRECESS_PATH
43 | self.preprocess_fine_path = config.DATASET.FINE.PROPRECESS_PATH
44 | self.data_list = subfiles(self.image_path, join=False, suffix='nii.gz')
45 | if is_overwrite and isdir(self.preprocess_coarse_path):
46 | shutil.rmtree(self.preprocess_coarse_path)
47 | os.makedirs(self.preprocess_coarse_path, exist_ok=True)
48 | if is_overwrite and isdir(self.preprocess_fine_path):
49 | shutil.rmtree(self.preprocess_fine_path)
50 | os.makedirs(self.preprocess_fine_path, exist_ok=True)
51 |
52 | def process(self, image_id):
53 | data_id = image_id.split("_0000.nii.gz")[0]
54 | image, image_spacing, image_direction = load_data(
55 | join(self.image_path, data_id + "_0000.nii.gz"))
56 | mask, _, mask_direction = load_data(
57 | join(self.mask_path, data_id + ".nii.gz"))
58 |         assert (image_direction == mask_direction).all()
59 | if self.nor_dir:
60 | image = change_axes_of_image(image, image_direction)
61 | mask = change_axes_of_image(mask, mask_direction)
62 | data_info = OrderedDict()
63 | data_info["raw_shape"] = image.shape
64 | data_info["raw_spacing"] = image_spacing
65 | resize_spacing = image_spacing*image.shape/self.coarse_size
66 | data_info["resize_spacing"] = resize_spacing
67 | data_info["image_direction"] = image_direction
68 | with open(os.path.join(self.preprocess_coarse_path, "%s_info.pkl" % data_id), 'wb') as f:
69 | pickle.dump(data_info, f)
70 | print(data_id, image.shape)
71 |
72 | image_resize = resize(image, self.coarse_size,
73 | order=3, mode='edge', anti_aliasing=False)
74 |
75 | mask_resize = resize_segmentation(mask, self.coarse_size, order=0)
76 | mask_binary = create_two_class_mask(mask_resize)
77 |
78 | image_normal = clip_and_normalize_mean_std(image_resize)
79 |
80 | np.savez_compressed(os.path.join(self.preprocess_coarse_path, "%s.npz" %
81 | data_id), data=image_normal[None, ...], seg=mask_binary[None, ...])
82 | margin = [int(self.extend_size / image_spacing[0]),
83 | int(self.extend_size / image_spacing[1]),
84 | int(self.extend_size / image_spacing[2])]
85 | crop_image, crop_mask = crop_image_according_to_mask(
86 | image, np.array(mask, dtype=int), margin)
87 |
88 | data_info_crop = OrderedDict()
89 | data_info_crop["raw_shape"] = image.shape
90 | data_info_crop["crop_shape"] = crop_image.shape
91 | data_info_crop["raw_spacing"] = image_spacing
92 | resize_crop_spacing = image_spacing*crop_image.shape/self.fine_size
93 | data_info_crop["resize_crop_spacing"] = resize_crop_spacing
94 | data_info_crop["image_direction"] = image_direction
95 | with open(os.path.join(self.preprocess_fine_path, "%s_info.pkl" % data_id), 'wb') as f:
96 | pickle.dump(data_info_crop, f)
97 |
98 | crop_image_resize = resize(crop_image, self.fine_size,
99 | order=3, mode='edge', anti_aliasing=False)
100 | crop_mask_resize = resize_segmentation(
101 | crop_mask, self.fine_size, order=0)
102 | crop_image_normal = clip_and_normalize_mean_std(crop_image_resize)
103 | np.savez_compressed(os.path.join(self.preprocess_fine_path, "%s.npz" % data_id),
104 | data=crop_image_normal[None, ...], seg=crop_mask_resize[None, ...])
105 | print('End processing %s.' % data_id)
106 |
107 |
108 |
109 | if __name__ == '__main__':
110 | config = get_config_no_args()
111 |
112 | run_prepare_data(config, True, True)
113 |
--------------------------------------------------------------------------------
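
The `resize_spacing` bookkeeping above records the effective voxel spacing after resampling to the fixed coarse size, which is what allows predictions to be mapped back later. A worked example with made-up numbers:

```
import numpy as np

image_spacing = np.array([2.5, 0.8, 0.8])  # (z, y, x) in mm, as read by load_data
raw_shape = np.array([120, 512, 512])
coarse_size = np.array([96, 96, 96])

# spacing scales by the same factor the shape is resampled by
resize_spacing = image_spacing * raw_shape / coarse_size
print(resize_spacing)  # [3.125 4.267 4.267] -> mm per voxel at the coarse size
```
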
/unlabel_data_preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from batchgenerators.utilities.file_and_folder_operations import *
4 | import shutil
5 | import traceback
6 | from multiprocessing import Pool, cpu_count
7 | from config import get_config_no_args
8 | from data.utils import crop_image_according_to_mask, load_data, clip_and_normalize_mean_std, resize_segmentation, change_axes_of_image, create_two_class_mask
9 | from collections import OrderedDict
10 | from skimage.transform import resize
11 |
12 |
13 | def run_prepare_data(config, is_overwrite, is_multiprocessing=True):
14 |
15 | data_prepare = data_process(config, is_overwrite)
16 | if is_multiprocessing:
17 | pool = Pool(int(cpu_count() * 0.2))
18 | for data in data_prepare.data_list:
19 | try:
20 | pool.apply_async(data_prepare.process, (data,))
21 | except Exception as err:
22 | traceback.print_exc()
23 |                 print('Creating image/label raised exception %s for series_id %s!' %
24 |                       (err, data))
25 |
26 | pool.close()
27 | pool.join()
28 | else:
29 | for data in data_prepare.data_list:
30 | data_prepare.process(data)
31 |
32 | class data_process(object):
33 | def __init__(self, config, is_overwrite=False):
34 | self.config = config
35 | self.coarse_size = self.config.DATASET.COARSE.SIZE
36 | self.fine_size = self.config.DATASET.FINE.SIZE
37 | self.nor_dir = self.config.DATASET.IS_NORMALIZATION_DIRECTION
38 | self.extend_size = self.config.DATASET.EXTEND_SIZE
39 |
40 | self.image_path = config.DATASET.TRAIN_UNLABELED_IMAGE_PATH
41 | self.mask_path = config.DATASET.TRAIN_UNLABELED_MASK_PATH
42 | self.preprocess_coarse_path = config.DATASET.COARSE.PROPRECESS_UL_PATH
43 | self.preprocess_fine_path = config.DATASET.FINE.PROPRECESS_UL_PATH
44 | self.data_list = subfiles(self.image_path, join=False, suffix='nii.gz')
45 | if is_overwrite and isdir(self.preprocess_coarse_path):
46 | shutil.rmtree(self.preprocess_coarse_path)
47 | os.makedirs(self.preprocess_coarse_path, exist_ok=True)
48 | if is_overwrite and isdir(self.preprocess_fine_path):
49 | shutil.rmtree(self.preprocess_fine_path)
50 | os.makedirs(self.preprocess_fine_path, exist_ok=True)
51 |
52 | def process(self, image_id):
53 |
54 | data_id = image_id.split("_0000.nii.gz")[0]
55 |
56 | image, image_spacing, image_direction = load_data(
57 | join(self.image_path, data_id + "_0000.nii.gz"))
58 | mask, _, mask_direction = load_data(
59 | join(self.mask_path, data_id + ".nii.gz"))
60 |         assert (image_direction == mask_direction).all()
61 | print(data_id, image.shape)
62 | if self.nor_dir:
63 | image = change_axes_of_image(image, image_direction)
64 | mask = change_axes_of_image(mask, mask_direction)
65 | data_info = OrderedDict()
66 |
67 | data_info["raw_shape"] = image.shape
68 | data_info["raw_spacing"] = image_spacing
69 | resize_spacing = image_spacing*image.shape/self.coarse_size
70 | data_info["resize_spacing"] = resize_spacing
71 | data_info["image_direction"] = image_direction
72 | with open(os.path.join(self.preprocess_coarse_path, "%s_info.pkl" % data_id), 'wb') as f:
73 | pickle.dump(data_info, f)
74 |
75 | image_resize = resize(image, self.coarse_size,
76 | order=3, mode='edge', anti_aliasing=False)
77 | mask_resize = resize_segmentation(
78 | mask, self.coarse_size, order=0)
79 | mask_binary = create_two_class_mask(mask_resize)
80 | image_normal = clip_and_normalize_mean_std(image_resize)
81 |
82 | np.savez_compressed(os.path.join(self.preprocess_coarse_path, "%s.npz" %
83 | data_id), data=image_normal[None, ...], seg=mask_binary[None, ...])
84 |
85 |
86 | margin = [int(self.extend_size / image_spacing[0]),
87 | int(self.extend_size / image_spacing[1]),
88 | int(self.extend_size / image_spacing[2])]
89 | crop_image, crop_mask = crop_image_according_to_mask(
90 | image, np.array(mask, dtype=int), margin)
91 | data_info_crop = OrderedDict()
92 | data_info_crop["raw_shape"] = image.shape
93 | data_info_crop["crop_shape"] = crop_image.shape
94 | data_info_crop["raw_spacing"] = image_spacing
95 | resize_crop_spacing = image_spacing*crop_image.shape/self.fine_size
96 | data_info_crop["resize_crop_spacing"] = resize_crop_spacing
97 | data_info_crop["image_direction"] = image_direction
98 | with open(os.path.join(self.preprocess_fine_path, "%s_info.pkl" % data_id), 'wb') as f:
99 | pickle.dump(data_info_crop, f)
100 |
101 | crop_image_resize = resize(
102 | crop_image, self.fine_size, order=3, mode='edge', anti_aliasing=False)
103 | crop_mask_resize = resize_segmentation(
104 | crop_mask, self.fine_size, order=0)
105 | crop_image_normal = clip_and_normalize_mean_std(crop_image_resize)
106 | np.savez_compressed(os.path.join(self.preprocess_fine_path, "%s.npz" % data_id),
107 | data=crop_image_normal[None, ...], seg=crop_mask_resize[None, ...])
108 |
109 | print('End processing %s.' % data_id)
110 |
111 | if __name__ == '__main__':
112 | config = get_config_no_args()
113 | run_prepare_data(config, False, False)
114 |
115 |
--------------------------------------------------------------------------------
/data/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from batchgenerators.utilities.file_and_folder_operations import *
3 | from skimage.transform import resize
4 | import cv2
5 | import shutil
6 | import torch
7 | import SimpleITK as sitk
8 | import cc3d
9 | import fastremap
10 | import torch.nn.functional as F
11 | from scipy.ndimage import binary_fill_holes
12 |
13 | def load_pickle(file: str, mode: str = 'rb'):
14 | with open(file, mode) as f:
15 | a = pickle.load(f)
16 | return a
17 |
18 |
19 | def load_data(data_path):
20 |
21 | data_itk = sitk.ReadImage(data_path)
22 | data_npy = sitk.GetArrayFromImage(data_itk)[None].astype(np.float32)
23 | data_spacing = np.array(data_itk.GetSpacing())[[2, 1, 0]]
24 | direction = data_itk.GetDirection()
25 | direction = np.array((direction[8], direction[4], direction[0]))
26 | return data_npy[0], data_spacing, direction
27 |
28 |
29 | def change_axes_of_image(npy_image, orientation):
30 | if orientation[0] < 0:
31 | npy_image = np.flip(npy_image, axis=0)
32 | if orientation[1] > 0:
33 | npy_image = np.flip(npy_image, axis=1)
34 | if orientation[2] > 0:
35 | npy_image = np.flip(npy_image, axis=2)
36 | return npy_image
37 |
38 |
39 | def clip_and_normalize_mean_std(image):
40 | mean = np.mean(image)
41 | std = np.std(image)
42 |
43 | image = (image - mean) / (std + 1e-5)
44 | return image
45 |
46 |
47 | def resize_segmentation(segmentation, new_shape, order=3):
48 | tpe = segmentation.dtype
49 | unique_labels = np.unique(segmentation)
50 | assert len(segmentation.shape) == len(
51 | new_shape), "new shape must have same dimensionality as segmentation"
52 | if order == 0:
53 | return resize(segmentation.astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False).astype(tpe)
54 | else:
55 | reshaped = np.zeros(new_shape, dtype=segmentation.dtype)
56 |
57 | for i, c in enumerate(unique_labels):
58 | mask = segmentation == c
59 | reshaped_multihot = resize(mask.astype(
60 | float), new_shape, order, mode="edge", clip=True, anti_aliasing=False)
61 | reshaped[reshaped_multihot >= 0.5] = c
62 | return reshaped
63 |
64 | def maybe_to_torch(d):
65 | if isinstance(d, list):
66 | d = [maybe_to_torch(i) if not isinstance(
67 | i, torch.Tensor) else i for i in d]
68 | elif not isinstance(d, torch.Tensor):
69 | d = torch.from_numpy(d).float()
70 | return d
71 |
72 |
73 | def create_two_class_mask(mask):
74 |
75 | mask = np.clip(mask, 0, 1)
76 |     mask = binary_fill_holes(mask, origin=1)
77 | return mask
78 |
79 |
80 | def extract_topk_largest_candidates(npy_mask: np.ndarray, label_unique, out_num_label: int) -> np.ndarray:
81 | mask_shape = npy_mask.shape
82 | out_mask = np.zeros(
83 | [mask_shape[1], mask_shape[2], mask_shape[3]], np.uint8)
84 | for i in range(1, mask_shape[0]):
85 | t_mask = npy_mask[i].copy()
86 | keep_topk_largest_connected_object(
87 | t_mask, out_num_label, out_mask, label_unique[i])
88 |
89 | return out_mask
90 |
91 |
92 | def keep_topk_largest_connected_object(npy_mask, k, out_mask, out_label):
93 | labels_out = cc3d.connected_components(npy_mask, connectivity=26)
94 | areas = {}
95 | for label, extracted in cc3d.each(labels_out, binary=True, in_place=True):
96 | areas[label] = fastremap.foreground(extracted)
97 | candidates = sorted(areas.items(), key=lambda item: item[1], reverse=True)
98 |
99 | for i in range(min(k, len(candidates))):
100 | out_mask[labels_out == int(candidates[i][0])] = out_label
101 |
102 |
103 | def to_one_hot(seg, all_seg_labels=None):
104 | if all_seg_labels is None:
105 | all_seg_labels = np.unique(seg)
106 | result = np.zeros((len(all_seg_labels), *seg.shape), dtype=seg.dtype)
107 | for i, l in enumerate(all_seg_labels):
108 | result[i][seg == l] = 1
109 | return result
110 |
111 |
112 | def input_downsample(x, input_size):
113 |     x = F.interpolate(x, size=input_size, mode='trilinear', align_corners=False)
114 | mean = torch.mean(x)
115 | std = torch.std(x)
116 | x = (x - mean) / (1e-5 + std)
117 | return x
118 |
119 |
120 | def output_upsample(x, output_size):
121 | x = F.interpolate(x, size=output_size,
122 | mode='trilinear', align_corners=False)
123 | return x
124 |
125 |
126 | def get_bbox_from_mask(mask, outside_value=0):
127 | mask_voxel_coords = np.where(mask != outside_value)
128 | minzidx = int(np.min(mask_voxel_coords[0]))
129 | maxzidx = int(np.max(mask_voxel_coords[0])) + 1
130 | minxidx = int(np.min(mask_voxel_coords[1]))
131 | maxxidx = int(np.max(mask_voxel_coords[1])) + 1
132 | minyidx = int(np.min(mask_voxel_coords[2]))
133 | maxyidx = int(np.max(mask_voxel_coords[2])) + 1
134 | return np.array([[minzidx, maxzidx], [minxidx, maxxidx], [minyidx, maxyidx]])
135 |
136 |
137 | def crop_image_according_to_mask(npy_image, npy_mask, margin=None):
138 | if margin is None:
139 | margin = [20, 20, 20]
140 |
141 | bbox = get_bbox_from_mask(npy_mask)
142 |
143 | extend_bbox = np.concatenate(
144 | [np.max([[0, 0, 0], bbox[:, 0] - margin], axis=0)[:, np.newaxis],
145 | np.min([npy_image.shape, bbox[:, 1] + margin], axis=0)[:, np.newaxis]], axis=1)
146 |
147 |
148 | crop_mask = crop_to_bbox(npy_mask,extend_bbox)
149 | crop_image = crop_to_bbox(npy_image,extend_bbox)
150 |
151 |
152 | return crop_image, crop_mask
153 |
154 |
155 | def crop_to_bbox(image, bbox):
156 | assert len(image.shape) == 3, "only supports 3d images"
157 | resizer = (slice(bbox[0][0], bbox[0][1]), slice(
158 | bbox[1][0], bbox[1][1]), slice(bbox[2][0], bbox[2][1]))
159 | return image[resizer]
160 |
161 |
162 | def crop_image_according_to_bbox(npy_image, bbox, margin=None):
163 | if margin is None:
164 | margin = [20, 20, 20]
165 |
166 | image_shape = npy_image.shape
167 | extend_bbox = [[max(0, int(bbox[0][0]-margin[0])),
168 | min(image_shape[0], int(bbox[0][1]+margin[0]))],
169 | [max(0, int(bbox[1][0]-margin[1])),
170 | min(image_shape[1], int(bbox[1][1]+margin[1]))],
171 | [max(0, int(bbox[2][0]-margin[2])),
172 | min(image_shape[2], int(bbox[2][1]+margin[2]))]]
173 |
174 |
175 | crop_image = crop_to_bbox(npy_image, extend_bbox)
176 |
177 |
178 | return crop_image, extend_bbox
179 |
--------------------------------------------------------------------------------
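
`get_bbox_from_mask` and `crop_image_according_to_mask` in action on a toy volume (note the margin here is in voxels; the preprocessing scripts convert `EXTEND_SIZE` from millimeters using the spacing):

```
import numpy as np
from data.utils import get_bbox_from_mask, crop_image_according_to_mask

image = np.random.randn(32, 32, 32).astype(np.float32)
mask = np.zeros((32, 32, 32), dtype=int)
mask[10:20, 8:24, 12:16] = 1

print(get_bbox_from_mask(mask).tolist())  # [[10, 20], [8, 24], [12, 16]]
crop_image, crop_mask = crop_image_according_to_mask(image, mask, margin=[2, 2, 2])
print(crop_image.shape)                   # (14, 20, 8): the bbox grown by the margin
```
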
/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | from loguru import logger
5 | from tqdm import tqdm
6 | from utils import to_cuda
7 | from metrics import AverageMeter, run_online_evaluation
8 | import torch.distributed as dist
9 | import wandb
10 |
11 |
12 | class Trainer:
13 |     def __init__(self, config, train_loader, val_loader, model, loss, optimizer, lr_scheduler):
14 | self.config = config
15 |
16 | self.scaler = torch.cuda.amp.GradScaler(enabled=True)
17 | self.loss = loss
18 | self.model = model
19 | self.train_loader = train_loader
20 | self.val_loader = val_loader
21 | self.optimizer = optimizer
22 | self.lr_scheduler = lr_scheduler
23 | self.num_steps = len(self.train_loader)
24 | if self._get_rank()==0:
25 | self.checkpoint_dir = os.path.join(config.SAVE_DIR,config.EXPERIMENT_ID)
26 |
27 | os.makedirs(self.checkpoint_dir)
28 | def train(self):
29 |
30 | for epoch in range(1, self.config.TRAIN.EPOCHS+1):
31 | if self.config.DIS:
32 | self.train_loader.sampler.set_epoch(epoch)
33 | self._train_epoch(epoch)
34 | if self.val_loader is not None and epoch % self.config.TRAIN.VAL_NUM_EPOCHS == 0:
35 | results = self._valid_epoch(epoch)
36 | if self._get_rank()==0 :
37 | logger.info(f'## Info for epoch {epoch} ## ')
38 | for k, v in results.items():
39 | logger.info(f'{str(k):15s}: {v}')
40 | if epoch % self.config.TRAIN.VAL_NUM_EPOCHS == 0 and self._get_rank()==0:
41 | self._save_checkpoint(epoch)
42 |
43 |
44 | def _train_epoch(self, epoch):
45 | self.batch_time = AverageMeter()
46 | self.data_time = AverageMeter()
47 | self.total_loss = AverageMeter()
48 | self.DICE = AverageMeter()
49 |
50 | self.model.train()
51 |
52 |
53 | tbar = tqdm(self.train_loader, ncols=150)
54 | tic = time.time()
55 | for idx, (data,_) in enumerate(tbar):
56 | self.data_time.update(time.time() - tic)
57 | img = to_cuda(data["data"])
58 | gt = to_cuda(data["seg"])
59 | self.optimizer.zero_grad()
60 |
61 | with torch.cuda.amp.autocast(enabled=self.config.AMP):
62 | pre = self.model(img)
63 | loss = self.loss(pre, gt)
64 | if self.config.AMP:
65 | self.scaler.scale(loss).backward()
66 | if self.config.TRAIN.DO_BACKPROP:
67 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 12)
68 | self.scaler.step(self.optimizer)
69 | self.scaler.update()
70 | else:
71 | loss.backward()
72 | if self.config.TRAIN.DO_BACKPROP:
73 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 12)
74 | self.optimizer.step()
75 |
76 | self.total_loss.update(loss.item())
77 | self.batch_time.update(time.time() - tic)
78 | self.DICE.update(run_online_evaluation(pre, gt))
79 |
80 | tbar.set_description(
81 | 'TRAIN ({}) | Loss: {} | DICE {} |B {} D {} |'.format(
82 | epoch, self.total_loss.average, self.DICE.average, self.batch_time.average, self.data_time.average))
83 | tic = time.time()
84 |
85 |             self.lr_scheduler.step_update(epoch * self.num_steps + idx)
86 | if self._get_rank()==0:
87 | wandb.log({'train/loss': self.total_loss.average,
88 | 'train/dice': self.DICE.average,
89 | 'train/lr': self.optimizer.param_groups[0]['lr']},
90 | step=epoch)
91 | def _valid_epoch(self, epoch):
92 | logger.info('\n###### EVALUATION ######')
93 | self.batch_time = AverageMeter()
94 | self.data_time = AverageMeter()
95 | self.total_loss = AverageMeter()
96 | self.DICE = AverageMeter()
97 |
98 | self.model.eval()
99 |
100 | tbar = tqdm(self.val_loader, ncols=150)
101 | tic = time.time()
102 | with torch.no_grad():
103 |
104 | for idx, (data, _) in enumerate(tbar):
105 | self.data_time.update(time.time() - tic)
106 | img = to_cuda(data["data"])
107 | gt = to_cuda(data["seg"])
108 |
109 | with torch.cuda.amp.autocast(enabled=self.config.AMP):
110 |
111 | pre = self.model(img)
112 | loss = self.loss(pre, gt)
113 |
114 | self.total_loss.update(loss.item())
115 | self.batch_time.update(time.time() - tic)
116 |
117 | self.DICE.update(run_online_evaluation(pre, gt))
118 | tbar.set_description(
119 | 'TEST ({}) | Loss: {} | DICE {} |B {} D {} |'.format(
120 | epoch, self.total_loss.average, self.DICE.average, self.batch_time.average, self.data_time.average))
121 | tic = time.time()
122 | if self._get_rank()==0:
123 | wandb.log({'val/loss': self.total_loss.average,
124 | 'val/dice': self.DICE.average,
125 | 'val/batch_time': self.batch_time.average,
126 | 'val/data_time': self.data_time.average
127 | },
128 | step=epoch)
129 | log = {'val_loss': self.total_loss.average,
130 | 'val_dice': self.DICE.average
131 | }
132 | return log
133 | def _get_rank(self):
134 | """get gpu id in distribution training."""
135 | if not dist.is_available():
136 | return 0
137 | if not dist.is_initialized():
138 | return 0
139 | return dist.get_rank()
140 |
141 | def _save_checkpoint(self, epoch):
142 | state = {
143 | 'arch': type(self.model).__name__,
144 | 'epoch': epoch,
145 | 'state_dict': self.model.state_dict(),
146 | 'optimizer': self.optimizer.state_dict(),
147 | 'config': self.config
148 | }
149 | filename = os.path.join(self.checkpoint_dir,
150 | 'final_checkpoint.pth')
151 | logger.info(f'Saving a checkpoint: {filename} ...')
152 | torch.save(state, filename)
153 | return filename
154 |
155 | def _reset_metrics(self):
156 | self.batch_time = AverageMeter()
157 | self.data_time = AverageMeter()
158 | self.total_loss = AverageMeter()
159 | self.DICE = AverageMeter()
160 |
--------------------------------------------------------------------------------
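
One caveat in `_train_epoch`: with AMP enabled, `clip_grad_norm_` runs on gradients that are still loss-scaled. The `torch.cuda.amp` recipe unscales first; a sketch of that variant (illustrative names, not the repo's code):

```
import torch

def amp_step(scaler, optimizer, model, loss, clip=True, max_norm=12):
    """AMP optimizer step that clips *unscaled* gradients."""
    scaler.scale(loss).backward()
    if clip:
        scaler.unscale_(optimizer)  # bring gradients back to their true scale
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    scaler.step(optimizer)          # skips the step if any gradient is inf/nan
    scaler.update()
```
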
/losses.py:
--------------------------------------------------------------------------------
1 | from torch import nn, Tensor
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 | def softmax_helper(x): return F.softmax(x, 1)
6 |
7 |
8 | def build_loss(deep_supervision, pool_op_kernel_sizes):
9 | if deep_supervision:
10 | weight = get_weight_factors(len(pool_op_kernel_sizes))
11 | loss = MultipleOutputLoss2(DC_and_CE_loss(), weight)
12 | else:
13 | loss = DC_and_CE_loss()
14 | return loss
15 |
16 |
17 | def get_weight_factors(net_numpool):
18 | weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
19 | mask = np.array([True] + [True if i < net_numpool -
20 | 1 else False for i in range(1, net_numpool)])
21 | weights[~mask] = 0
22 | weights = weights / weights.sum()
23 | return weights
24 |
25 |
26 | class MultipleOutputLoss2(nn.Module):
27 | def __init__(self, loss, weight_factors=None):
28 | super(MultipleOutputLoss2, self).__init__()
29 | self.weight_factors = weight_factors
30 | self.loss = loss
31 |
32 | def forward(self, x, y):
33 | assert isinstance(x, (tuple, list)), "x must be either tuple or list"
34 | assert isinstance(y, (tuple, list)), "y must be either tuple or list"
35 | if self.weight_factors is None:
36 | weights = [1] * len(x)
37 | else:
38 | weights = self.weight_factors
39 |
40 | l = weights[0] * self.loss(x[0], y[0])
41 | for i in range(1, len(x)):
42 | if weights[i] != 0:
43 | l += weights[i] * self.loss(x[i], y[i])
44 | return l
45 |
46 |
47 | class RobustCrossEntropyLoss(nn.CrossEntropyLoss):
48 |
49 | def forward(self, input: Tensor, target: Tensor) -> Tensor:
50 | if len(target.shape) == len(input.shape):
51 | assert target.shape[1] == 1
52 | target = target[:, 0]
53 | return super().forward(input, target.long())
54 |
55 |
56 | class DC_and_CE_loss(nn.Module):
57 | def __init__(self, aggregate="sum", weight_ce=1, weight_dice=1,
58 | log_dice=False, ignore_label=None):
59 | super(DC_and_CE_loss, self).__init__()
60 |
61 | self.log_dice = log_dice
62 | self.weight_dice = weight_dice
63 | self.weight_ce = weight_ce
64 | self.aggregate = aggregate
65 | self.ce = RobustCrossEntropyLoss()
66 |
67 | self.ignore_label = ignore_label
68 | self.dc = SoftDiceLoss(apply_nonlin=softmax_helper)
69 |
70 | def forward(self, net_output, target):
71 | if self.ignore_label is not None:
72 | assert target.shape[1] == 1, 'not implemented for one hot encoding'
73 | mask = target != self.ignore_label
74 | target[~mask] = 0
75 | mask = mask.float()
76 | else:
77 | mask = None
78 |
79 | dc_loss = self.dc(net_output, target,
80 | loss_mask=mask) if self.weight_dice != 0 else 0
81 | if self.log_dice:
82 | dc_loss = -torch.log(-dc_loss)
83 |
84 | ce_loss = self.ce(
85 | net_output, target[:, 0].long()) if self.weight_ce != 0 else 0
86 | if self.ignore_label is not None:
87 | ce_loss *= mask[:, 0]
88 | ce_loss = ce_loss.sum() / mask.sum()
89 |
90 | if self.aggregate == "sum":
91 | result = self.weight_ce * ce_loss + self.weight_dice * dc_loss
92 | else:
93 |             # only summation of CE and Dice is currently supported
94 |             raise NotImplementedError(f"aggregate='{self.aggregate}' is not implemented")
95 | return result
96 |
97 |
98 | class SoftDiceLoss(nn.Module):
99 | def __init__(self, apply_nonlin=None, batch_dice=True, do_bg=False, smooth=1e-5):
100 |         """Soft Dice loss. smooth guards against empty masks; with
101 |         do_bg=False the background channel is excluded from the mean."""
102 | super(SoftDiceLoss, self).__init__()
103 |
104 | self.do_bg = do_bg
105 | self.batch_dice = batch_dice
106 | self.apply_nonlin = apply_nonlin
107 | self.smooth = smooth
108 |
109 | def forward(self, x, y, loss_mask=None):
110 | shp_x = x.shape
111 |
112 | if self.batch_dice:
113 | axes = [0] + list(range(2, len(shp_x)))
114 | else:
115 | axes = list(range(2, len(shp_x)))
116 |
117 | if self.apply_nonlin is not None:
118 | x = self.apply_nonlin(x)
119 |
120 | tp, fp, fn, _ = get_tp_fp_fn_tn(x, y, axes, loss_mask, False)
121 |
122 |         numerator = 2 * tp + self.smooth
123 | denominator = 2 * tp + fp + fn + self.smooth
124 |
125 |         dc = numerator / (denominator + 1e-8)
126 |
127 | if not self.do_bg:
128 | if self.batch_dice:
129 | dc = dc[1:]
130 | else:
131 | dc = dc[:, 1:]
132 | dc = dc.mean()
133 |
134 | return -dc
135 |
136 |
137 | def sum_tensor(inp, axes, keepdim=False):
138 | axes = np.unique(axes).astype(int)
139 | if keepdim:
140 | for ax in axes:
141 | inp = inp.sum(int(ax), keepdim=True)
142 | else:
143 | for ax in sorted(axes, reverse=True):
144 | inp = inp.sum(int(ax))
145 | return inp
146 |
147 |
148 | def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False):
149 | if axes is None:
150 | axes = tuple(range(2, len(net_output.size())))
151 |
152 | shp_x = net_output.shape
153 | shp_y = gt.shape
154 |
155 | with torch.no_grad():
156 | if len(shp_x) != len(shp_y):
157 | gt = gt.view((shp_y[0], 1, *shp_y[1:]))
158 |
159 | if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
160 | # if this is the case then gt is probably already a one hot encoding
161 | y_onehot = gt
162 | else:
163 | gt = gt.long()
164 | y_onehot = torch.zeros(shp_x, device=net_output.device)
165 | y_onehot.scatter_(1, gt, 1)
166 |
167 | tp = net_output * y_onehot
168 | fp = net_output * (1 - y_onehot)
169 | fn = (1 - net_output) * y_onehot
170 | tn = (1 - net_output) * (1 - y_onehot)
171 |
172 | if mask is not None:
173 | tp = torch.stack(tuple(x_i * mask[:, 0]
174 | for x_i in torch.unbind(tp, dim=1)), dim=1)
175 | fp = torch.stack(tuple(x_i * mask[:, 0]
176 | for x_i in torch.unbind(fp, dim=1)), dim=1)
177 | fn = torch.stack(tuple(x_i * mask[:, 0]
178 | for x_i in torch.unbind(fn, dim=1)), dim=1)
179 | tn = torch.stack(tuple(x_i * mask[:, 0]
180 | for x_i in torch.unbind(tn, dim=1)), dim=1)
181 |
182 | if square:
183 | tp = tp ** 2
184 | fp = fp ** 2
185 | fn = fn ** 2
186 | tn = tn ** 2
187 |
188 | if len(axes) > 0:
189 | tp = sum_tensor(tp, axes, keepdim=False)
190 | fp = sum_tensor(fp, axes, keepdim=False)
191 | fn = sum_tensor(fn, axes, keepdim=False)
192 | tn = sum_tensor(tn, axes, keepdim=False)
193 |
194 | return tp, fp, fn, tn
195 |
--------------------------------------------------------------------------------
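A minimal usage sketch for build_loss (not part of the repository; shapes are illustrative and the script is assumed to run from the repository root). With four pooling stages, get_weight_factors yields approximately [0.571, 0.286, 0.143, 0.0], so the lowest-resolution output is ignored; note that DC_and_CE_loss expects targets with a singleton channel axis:

    import torch
    from losses import build_loss

    loss_fn = build_loss(deep_supervision=True,
                         pool_op_kernel_sizes=[[2, 2, 2]] * 4)
    sizes = [64, 32, 16, 8]  # one resolution per supervision level
    preds = [torch.randn(1, 3, s, s, s) for s in sizes]  # logits, 3 classes
    targets = [torch.randint(0, 3, (1, 1, s, s, s)).float() for s in sizes]
    print(loss_fn(preds, targets))  # scalar: weighted sum of Dice + CE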
/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | from yacs.config import CfgNode as CN
4 |
5 | _C = CN()
6 | # -----------------------------------------------------------------------------
7 | # Base settings
8 | # -----------------------------------------------------------------------------
9 | _C.BASE = ['']
10 |
11 | _C.DIS = False
12 | _C.WORLD_SIZE = 1
13 | _C.SEED = 1234
14 | _C.AMP = True
15 | _C.EXPERIMENT_ID = ""
16 | _C.SAVE_DIR = "save_pth"
17 | _C.VAL_OUTPUT_PATH = "/home/lwt/code/flare/FLARE22-TwoStagePHTrans/save_results"
18 | _C.COARSE_MODEL_PATH = "/home/lwt/code/flare/FLARE22-TwoStagePHTrans/save_pth/phtrans_c_220813_001000"
19 | _C.FINE_MODEL_PATH = "/home/lwt/code/flare/FLARE22-TwoStagePHTrans/save_pth/phtrans_f_220813_001056"
20 |
21 | # -----------------------------------------------------------------------------
22 | # Wandb settings
23 | # -----------------------------------------------------------------------------
24 | _C.WANDB = CN()
25 | _C.WANDB.COARSE_PROJECT = "FLARE2022_COARSE"
26 | _C.WANDB.FINE_PROJECT = "FLARE2022_FINE"
27 | _C.WANDB.TAG = "PHTrans"
28 | _C.WANDB.MODE = "offline"
29 |
30 | # -----------------------------------------------------------------------------
31 | # Data settings
32 | # -----------------------------------------------------------------------------
33 | _C.DATASET = CN()
34 | _C.DATASET.WITH_VAL = False
35 | _C.DATASET.TRAIN_UNLABELED_IMAGE_PATH = "/home/lwt/data/flare22/UnlabeledCase"
36 | _C.DATASET.TRAIN_UNLABELED_MASK_PATH = "/home/lwt/data/flare22/Unlabel2000_phtranPre"
37 | _C.DATASET.TRAIN_IMAGE_PATH = "/home/lwt/data/flare22/Training/FLARE22_LabeledCase50/images"
38 | _C.DATASET.TRAIN_MASK_PATH = "/home/lwt/data/flare22/Training/FLARE22_LabeledCase50/labels"
39 | _C.DATASET.VAL_IMAGE_PATH = "/home/lwt/data/flare22/Validation"
40 | _C.DATASET.EXTEND_SIZE = 20
41 | _C.DATASET.IS_NORMALIZATION_DIRECTION = True
42 |
43 | _C.DATASET.COARSE = CN()
44 | _C.DATASET.COARSE.PROPRECESS_PATH = "/home/lwt/data_pro/flare22/Training/coarse_646464"
45 | _C.DATASET.COARSE.PROPRECESS_UL_PATH = "/home/lwt/data_pro/flare22/Unlabel2000_coarse_646464"
46 | _C.DATASET.COARSE.NUM_EACH_EPOCH = 512
47 | _C.DATASET.COARSE.SIZE = [64, 64, 64]
48 | _C.DATASET.COARSE.LABEL_CLASSES = 2
49 |
50 | _C.DATASET.FINE = CN()
51 | _C.DATASET.FINE.PROPRECESS_PATH = "/home/lwt/data_pro/flare22/Training/fine_96192192"
52 | _C.DATASET.FINE.PROPRECESS_UL_PATH = "/home/lwt/data_pro/flare22/Unlabel2000_fine_96192192"
53 | _C.DATASET.FINE.NUM_EACH_EPOCH = 512
54 | _C.DATASET.FINE.SIZE = [96, 192, 192]
55 | _C.DATASET.FINE.LABEL_CLASSES = 14
56 |
57 | _C.DATASET.DA = CN()
58 | _C.DATASET.DA.DO_2D_AUG = True
59 | _C.DATASET.DA.DO_ELASTIC = True
60 | _C.DATASET.DA.DO_SCALING = True
61 | _C.DATASET.DA.DO_ROTATION = True
62 | _C.DATASET.DA.RANDOM_CROP = False
63 | _C.DATASET.DA.DO_GAMMA = True
64 | _C.DATASET.DA.DO_MIRROR = False
65 | _C.DATASET.DA.DO_ADDITIVE_BRIGHTNESS = True
66 |
67 | # -----------------------------------------------------------------------------
68 | # Dataloader settings
69 | # -----------------------------------------------------------------------------
70 | _C.DATALOADER = CN()
71 | _C.DATALOADER.BATCH_SIZE = 1
72 | _C.DATALOADER.PIN_MEMORY = True
73 | _C.DATALOADER.NUM_WORKERS = 8
74 |
75 | # -----------------------------------------------------------------------------
76 | # Model settings
77 | # -----------------------------------------------------------------------------
78 | _C.MODEL = CN()
79 | _C.MODEL.DEEP_SUPERVISION = True
80 |
81 | _C.MODEL.COARSE = CN()
82 | _C.MODEL.COARSE.TYPE = "phtrans"
83 | _C.MODEL.COARSE.BASE_NUM_FEATURES = 16
84 | _C.MODEL.COARSE.NUM_ONLY_CONV_STAGE = 2
85 | _C.MODEL.COARSE.NUM_CONV_PER_STAGE = 2
86 | _C.MODEL.COARSE.FEAT_MAP_MUL_ON_DOWNSCALE = 2
87 | _C.MODEL.COARSE.POOL_OP_KERNEL_SIZES = [
88 | [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
89 | _C.MODEL.COARSE.CONV_KERNEL_SIZES = [
90 | [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
91 | _C.MODEL.COARSE.DROPOUT_P = 0.1
92 |
93 | _C.MODEL.COARSE.MAX_NUM_FEATURES = 200
94 | _C.MODEL.COARSE.DEPTHS = [2, 2, 2, 2]
95 | _C.MODEL.COARSE.NUM_HEADS = [4, 4, 4, 4]
96 | _C.MODEL.COARSE.WINDOW_SIZE = [4, 4, 4]
97 | _C.MODEL.COARSE.MLP_RATIO = 1.
98 | _C.MODEL.COARSE.QKV_BIAS = True
99 | _C.MODEL.COARSE.QK_SCALE = None
100 | _C.MODEL.COARSE.DROP_RATE = 0.
101 | _C.MODEL.COARSE.DROP_PATH_RATE = 0.1
102 |
103 | _C.MODEL.FINE = CN()
104 | _C.MODEL.FINE.TYPE = "phtrans"
105 | _C.MODEL.FINE.BASE_NUM_FEATURES = 16
106 | _C.MODEL.FINE.NUM_ONLY_CONV_STAGE = 2
107 | _C.MODEL.FINE.NUM_CONV_PER_STAGE = 2
108 | _C.MODEL.FINE.FEAT_MAP_MUL_ON_DOWNSCALE = 2
109 | _C.MODEL.FINE.POOL_OP_KERNEL_SIZES = [
110 | [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
111 | _C.MODEL.FINE.CONV_KERNEL_SIZES = [[3, 3, 3], [
112 | 3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
113 | _C.MODEL.FINE.DROPOUT_P = 0.1
114 |
115 | _C.MODEL.FINE.MAX_NUM_FEATURES = 200
116 | _C.MODEL.FINE.DEPTHS = [2, 2, 2, 2]
117 | _C.MODEL.FINE.NUM_HEADS = [4, 4, 4, 4]
118 | _C.MODEL.FINE.WINDOW_SIZE = [3, 4, 4]
119 | _C.MODEL.FINE.MLP_RATIO = 1.
120 | _C.MODEL.FINE.QKV_BIAS = True
121 | _C.MODEL.FINE.QK_SCALE = None
122 | _C.MODEL.FINE.DROP_RATE = 0.
123 | _C.MODEL.FINE.DROP_PATH_RATE = 0.1
124 |
125 | # -----------------------------------------------------------------------------
126 | # Training settings
127 | # -----------------------------------------------------------------------------
128 | _C.TRAIN = CN()
129 | _C.TRAIN.DO_BACKPROP = True
130 | _C.TRAIN.VAL_NUM_EPOCHS = 1
131 | _C.TRAIN.SAVE_PERIOD = 1
132 |
133 | _C.TRAIN.EPOCHS = 300
134 | _C.TRAIN.WEIGHT_DECAY = 0.01
135 | _C.TRAIN.WARMUP_EPOCHS = 20
136 | _C.TRAIN.BASE_LR = 5e-4
137 | _C.TRAIN.WARMUP_LR = 5e-7
138 | _C.TRAIN.MIN_LR = 5e-6
139 | # LR scheduler
140 | _C.TRAIN.LR_SCHEDULER = CN()
141 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
142 |
143 | # Epoch interval to decay LR, used in StepLRScheduler
144 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
145 | # LR decay rate, used in StepLRScheduler
146 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
147 |
148 | # Optimizer
149 | _C.TRAIN.OPTIMIZER = CN()
150 | _C.TRAIN.OPTIMIZER.NAME = 'adamw'
151 | # Optimizer Epsilon
152 | _C.TRAIN.OPTIMIZER.EPS = 1e-8
153 | # Optimizer Betas
154 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
155 | # SGD momentum
156 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
157 |
158 | # -----------------------------------------------------------------------------
159 | # Validation / inference settings
160 | # -----------------------------------------------------------------------------
161 | _C.VAL = CN()
162 | _C.VAL.IS_POST_PROCESS = True
163 | _C.VAL.IS_WITH_DATALOADER = True
164 |
165 |
166 | def _update_config_from_file(config, cfg_file):
167 | config.defrost()
168 | with open(cfg_file, 'r') as f:
169 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
170 |
171 | for cfg in yaml_cfg.setdefault('BASE', ['']):
172 | if cfg:
173 | _update_config_from_file(
174 | config, os.path.join(os.path.dirname(cfg_file), cfg)
175 | )
176 | print('=> merge config from {}'.format(cfg_file))
177 | config.merge_from_file(cfg_file)
178 | config.freeze()
179 |
180 |
181 | def update_config(config, args):
182 | if args.cfg is not None:
183 | _update_config_from_file(config, args.cfg)
184 |
185 | config.defrost()
186 | if args.opts:
187 | config.merge_from_list(args.opts)
188 | if args.batch_size:
189 | config.DATALOADER.BATCH_SIZE = args.batch_size
190 | if args.tag:
191 | config.WANDB.TAG = args.tag
192 | if args.wandb_mode == "online":
193 | config.WANDB.MODE = args.wandb_mode
194 | if args.world_size:
195 | config.WORLD_SIZE = args.world_size
196 | if args.with_distributed:
197 | config.DIS = True
198 | config.freeze()
199 |
200 |
201 | def update_val_config(config, args):
202 | if args.cfg is not None:
203 | _update_config_from_file(config, args.cfg)
204 |
205 | config.defrost()
206 | if args.opts:
207 | config.merge_from_list(args.opts)
208 |
209 | # merge from specific arguments
210 | if args.save_model_path:
211 | config.SAVE_MODEL_PATH = args.save_model_path
212 | if args.data_path:
213 | config.DATASET.VAL_IMAGE_PATH = args.data_path
214 | if args.output_path:
215 | config.VAL_OUTPUT_PATH = args.output_path
216 |
217 | config.freeze()
218 |
219 |
220 | def get_config(args=None):
221 | config = _C.clone()
222 | update_config(config, args)
223 |
224 | return config
225 |
226 |
227 | def get_config_no_args():
228 | config = _C.clone()
229 |
230 | return config
231 |
232 |
233 | def get_val_config(args=None):
234 | config = _C.clone()
235 | update_val_config(config, args)
236 |
237 | return config
238 |
--------------------------------------------------------------------------------
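The config object is a standard yacs CfgNode, so besides the --cfg/--opts plumbing above it can also be adjusted programmatically; a minimal sketch using keys from the defaults above:

    from config import get_config_no_args

    cfg = get_config_no_args()
    cfg.defrost()  # a CfgNode must be unfrozen before mutation
    cfg.merge_from_list(['DATALOADER.BATCH_SIZE', 2, 'TRAIN.EPOCHS', 10])
    cfg.freeze()
    print(cfg.DATASET.FINE.SIZE)  # [96, 192, 192]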
/predict.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from config import get_val_config
3 | from models import build_coarse_model, build_fine_model
4 | import os
5 | import torch.backends.cudnn as cudnn
6 | import numpy as np
7 | import time
8 | import torch
9 | from torch.cuda.amp import autocast
10 | import SimpleITK as sitk
11 | from utils import to_cuda, load_checkpoint
12 | from data import predict_dataset, DataLoaderX
13 | from data.utils import change_axes_of_image, extract_topk_largest_candidates, to_one_hot, input_downsample, output_upsample, crop_image_according_to_bbox, get_bbox_from_mask
14 | from batchgenerators.utilities.file_and_folder_operations import *
15 | import torch.nn.functional as F
16 |
17 | def parse_option():
18 |     parser = argparse.ArgumentParser("FLARE2022_predict")
19 | parser.add_argument('--cfg', type=str, metavar="FILE",
20 | help='path to config file')
21 | parser.add_argument(
22 | "--opts",
23 | help="Modify config options by adding 'KEY VALUE' pairs. ",
24 | default=None,
25 | nargs='+',
26 | )
27 | parser.add_argument('-smp', '--save_model_path', type=str,
28 | default=None, help='path to model.pth')
29 | parser.add_argument('-dp', '--data_path', type=str,
30 | default=None, help='path to validation image path')
31 | parser.add_argument('-op', '--output_path', type=str,
32 | default=None, help='path to output image path')
33 | args = parser.parse_args()
34 | config = get_val_config(args)
35 |
36 | return args, config
37 |
38 | class Inference:
39 | def __init__(self, config) -> None:
40 | self.config = config
41 | self.output_path = self.config.VAL_OUTPUT_PATH
42 | os.makedirs(config.VAL_OUTPUT_PATH, exist_ok=True)
43 | self.coarse_size = self.config.DATASET.COARSE.SIZE
44 | self.fine_size = self.config.DATASET.FINE.SIZE
45 | self.extend_size = self.config.DATASET.EXTEND_SIZE
46 | self.is_post_process = self.config.VAL.IS_POST_PROCESS
47 | self.is_nor_dir = self.config.DATASET.IS_NORMALIZATION_DIRECTION
48 | self.is_with_dataloader = self.config.VAL.IS_WITH_DATALOADER
49 | if self.is_with_dataloader:
50 | val_dataset = predict_dataset(config)
51 | self.val_loader = DataLoaderX(
52 | val_dataset,
53 | batch_size=1,
54 | num_workers=0,
55 | pin_memory=config.DATALOADER.PIN_MEMORY,
56 | shuffle=False,
57 | )
58 | else:
59 | self.val_loader = predict_dataset(config)
60 | cudnn.benchmark = True
61 |
62 | def run(self):
63 | torch.cuda.synchronize()
64 | t_start = time.time()
65 | with autocast():
66 | with torch.no_grad():
67 | for image_dict in self.val_loader:
68 |                     image_dict = image_dict[0] if isinstance(image_dict, list) else image_dict
69 | if self.is_with_dataloader:
70 | image_id = image_dict['image_id'][0]
71 | raw_image = np.array(image_dict['raw_image'].squeeze(0))
72 | raw_spacing = np.array(image_dict['raw_spacing'][0])
73 | image_direction = np.array(image_dict['image_direction'][0])
74 | else:
75 | image_id = image_dict['image_id']
76 | raw_image = image_dict['raw_image']
77 | raw_spacing = image_dict['raw_spacing']
78 | image_direction = image_dict['image_direction']
79 | coarse_image = torch.from_numpy(
80 | raw_image).unsqueeze(0).unsqueeze(0).float()
81 | raw_image_shape = raw_image.shape
82 | coarse_resize_factor = np.array(raw_image.shape) / np.array(self.coarse_size)
83 | coarse_image = input_downsample(coarse_image, self.coarse_size)
84 | coarse_image = self.coarse_predict(coarse_image, self.config.COARSE_MODEL_PATH)
85 | coarse_pre = F.softmax(coarse_image, 1)
86 | coarse_pre = coarse_pre.cpu().float()
87 | torch.cuda.empty_cache()
88 |                     coarse_mask = coarse_pre.argmax(1).squeeze(0).numpy().astype(np.uint8)
89 | lab_unique = np.unique(coarse_mask)
90 | coarse_mask = to_one_hot(coarse_mask)
91 | coarse_mask = extract_topk_largest_candidates(coarse_mask,lab_unique, 1)
92 | coarse_bbox = get_bbox_from_mask(coarse_mask)
93 | raw_bbox = [[int(coarse_bbox[0][0] * coarse_resize_factor[0]),
94 | int(coarse_bbox[0][1] * coarse_resize_factor[0])],
95 | [int(coarse_bbox[1][0] * coarse_resize_factor[1]),
96 | int(coarse_bbox[1][1] * coarse_resize_factor[1])],
97 | [int(coarse_bbox[2][0] * coarse_resize_factor[2]),
98 | int(coarse_bbox[2][1] * coarse_resize_factor[2])]]
99 | margin = [self.extend_size / raw_spacing[i]
100 | for i in range(3)]
101 | crop_image, crop_fine_bbox = crop_image_according_to_bbox(
102 | raw_image, raw_bbox, margin)
103 |                     print(f"fine-stage crop bbox: {crop_fine_bbox}")
104 | crop_image_size = crop_image.shape
105 | crop_image = torch.from_numpy(crop_image).unsqueeze(0).unsqueeze(0)
106 | crop_image = input_downsample(crop_image, self.fine_size)
107 |                     crop_image = self.fine_predict(crop_image, self.config.FINE_MODEL_PATH)
108 | torch.cuda.empty_cache()
109 | crop_image = output_upsample(crop_image, crop_image_size)
110 | crop_image = F.softmax(crop_image, 1)
111 |                     fine_mask = crop_image.argmax(1).squeeze(0).numpy().astype(np.uint8)
112 | if self.is_post_process:
113 | lab_unique = np.unique(fine_mask)
114 | fine_mask = to_one_hot(fine_mask)
115 | fine_mask = extract_topk_largest_candidates(fine_mask,lab_unique, 1)
116 | out_mask = np.zeros(raw_image_shape, np.uint8)
117 | out_mask[crop_fine_bbox[0][0]:crop_fine_bbox[0][1],
118 | crop_fine_bbox[1][0]:crop_fine_bbox[1][1],
119 | crop_fine_bbox[2][0]:crop_fine_bbox[2][1]] = fine_mask
120 | if self.is_nor_dir:
121 | out_mask = change_axes_of_image(out_mask, image_direction)
122 | sitk_image = sitk.GetImageFromArray(out_mask)
123 | sitk.WriteImage(sitk_image, os.path.join(
124 | self.output_path, "FLARETs_{}.nii.gz".format(image_id)), True)
125 | print(f"{image_id} Done")
126 |
127 | torch.cuda.synchronize()
128 | t_end = time.time()
129 | average_time_usage = (t_end - t_start) * 1.0 / len(self.val_loader)
130 | print("Average time usage: {} s".format(average_time_usage))
131 |
132 | def coarse_predict(self, input, model_path):
133 | coarse_model_checkpoint = load_checkpoint(model_path)
134 | coarse_model = build_coarse_model(coarse_model_checkpoint["config"], True).eval()
135 | coarse_model.load_state_dict({k.replace('module.', ''): v for k, v in coarse_model_checkpoint['state_dict'].items()})
136 | self._set_requires_grad(coarse_model, False)
137 | coarse_model = coarse_model.cuda().half()
138 | input = to_cuda(input).half()
139 | out = coarse_model(input)
140 | coarse_model = coarse_model.cpu()
141 | return out.cpu().float()
142 |
143 | def fine_predict(self, input, model_path):
144 | fine_model_checkpoint = load_checkpoint(model_path)
145 | fine_model = build_fine_model(fine_model_checkpoint["config"], True).eval()
146 | fine_model.load_state_dict({k.replace('module.', ''): v for k, v in fine_model_checkpoint['state_dict'].items()})
147 | self._set_requires_grad(fine_model, False)
148 | fine_model = fine_model.cuda().half()
149 | input = to_cuda(input).half()
150 | out = fine_model(input)
151 | fine_model = fine_model.cpu()
152 | return out.cpu().float()
153 |
154 | @staticmethod
155 | def _set_requires_grad(model, requires_grad=False):
156 | for param in model.parameters():
157 | param.requires_grad = requires_grad
158 |
159 | if __name__ == '__main__':
160 | torch.cuda.synchronize()
161 | t_start = time.time()
162 | _, config = parse_option()
163 |
164 | predict = Inference(config)
165 | predict.run()
166 | torch.cuda.synchronize()
167 | t_end = time.time()
168 | total_time = t_end - t_start
169 | print("Total_time: {} s".format(total_time))
170 |
--------------------------------------------------------------------------------
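Besides the CLI entry point above, Inference can be driven programmatically; a minimal sketch (not part of the repository) that assumes a CUDA device, the checkpoint paths from config.py, and hypothetical input/output directories:

    from argparse import Namespace
    from config import get_val_config
    from predict import Inference

    # update_val_config reads exactly these attributes from args
    args = Namespace(cfg=None, opts=None, save_model_path=None,
                     data_path='/workspace/inputs/',    # hypothetical paths
                     output_path='/workspace/outputs/')
    Inference(get_val_config(args)).run()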
/data/dataset_train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | from torch.utils.data import Dataset
4 | from batchgenerators.utilities.file_and_folder_operations import *
5 | from .data_augmentation import default_3D_augmentation_params, default_2D_augmentation_params, get_patch_size, DownsampleSegForDSTransform
6 | from batchgenerators.transforms.abstract_transforms import Compose
7 | from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \
8 | ContrastAugmentationTransform, BrightnessTransform
9 | from batchgenerators.transforms.color_transforms import GammaTransform
10 | from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
11 | from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
12 | from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform
13 | from batchgenerators.transforms.utility_transforms import NumpyToTensor
14 | from data.utils import load_pickle
15 | class flare22_dataset(Dataset):
16 | def __init__(self, config, data_size, data_path, unlab_data_path, pool_op_kernel_sizes, num_each_epoch,is_train=True, is_deep_supervision=True):
17 | self.config=config
18 | self.data_path = data_path
19 | self.data_size = data_size
20 | self.unlab_data_path = unlab_data_path
21 | self.pool_op_kernel_sizes = pool_op_kernel_sizes
22 | self.num_each_epoch = num_each_epoch
23 | self.series_ids = subfiles(data_path, join=False, suffix='npz')
24 | self.unlab_series_ids = subfiles(unlab_data_path, join=False, suffix='npz')
25 | self.setup_DA_params()
26 |
27 | self.transforms = self.get_augmentation(
28 | data_size,
29 |             self.data_aug_params, is_train=is_train,
30 | deep_supervision_scales=self.deep_supervision_scales if is_deep_supervision else None
31 | )
32 | def __getitem__(self, idx):
33 | if idx < len(self.series_ids):
34 | data_id = self.series_ids[idx]
35 | data_info = load_pickle(join(self.data_path, data_id.split(".")[0] + "_info.pkl"))
36 |             data_load = np.load(join(self.data_path, data_id))
37 | else:
38 |             data_id = random.choice(self.unlab_series_ids)
39 |             data_info = load_pickle(join(self.unlab_data_path, data_id.split(".")[0] + "_info.pkl"))
40 |             data_load = np.load(join(self.unlab_data_path, data_id))
41 |
42 | data_trans = self.transforms(**data_load)
43 | return data_trans, data_info
44 |
45 | def __len__(self):
46 | return self.num_each_epoch
47 |
48 |
49 |
50 | def setup_DA_params(self):
51 |         # computed unconditionally so get_augmentation can always reference it
52 |         self.deep_supervision_scales = [[1, 1, 1]] + list(list(i) for i in 1 / np.cumprod(np.vstack(self.pool_op_kernel_sizes), axis=0))[:-1]
53 | self.data_aug_params = default_3D_augmentation_params
54 | self.data_aug_params['rotation_x'] = (
55 | -30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
56 | self.data_aug_params['rotation_y'] = (
57 | -30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
58 | self.data_aug_params['rotation_z'] = (
59 | -30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
60 |
61 | if self.config.DATASET.DA.DO_2D_AUG:
62 | if self.config.DATASET.DA.DO_ELASTIC:
63 | self.data_aug_params["elastic_deform_alpha"] = \
64 | default_2D_augmentation_params["elastic_deform_alpha"]
65 | self.data_aug_params["elastic_deform_sigma"] = \
66 | default_2D_augmentation_params["elastic_deform_sigma"]
67 | self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"]
68 |
69 | if self.config.DATASET.DA.DO_2D_AUG:
70 | self.basic_generator_patch_size = get_patch_size(self.data_size[1:],
71 | self.data_aug_params['rotation_x'],
72 | self.data_aug_params['rotation_y'],
73 | self.data_aug_params['rotation_z'],
74 | self.data_aug_params['scale_range'])
75 | self.basic_generator_patch_size = np.array(
76 | [self.data_size[0]] + list(self.basic_generator_patch_size))
77 | else:
78 | self.basic_generator_patch_size = get_patch_size(self.data_size, self.data_aug_params['rotation_x'],
79 | self.data_aug_params['rotation_y'],
80 | self.data_aug_params['rotation_z'],
81 | self.data_aug_params['scale_range'])
82 |
83 |
84 |
85 |     def get_augmentation(self, patch_size, params=default_3D_augmentation_params, is_train=True, border_val_seg=-1,
86 |                          order_seg=1, order_data=3, deep_supervision_scales=None):
87 | transforms = []
88 | if is_train:
89 |
90 | if self.config.DATASET.DA.DO_2D_AUG:
91 | ignore_axes = (1,)
92 |
93 | patch_size_spatial = patch_size[1:]
94 | else:
95 | patch_size_spatial = patch_size
96 | ignore_axes = None
97 |
98 | transforms.append(SpatialTransform(
99 | patch_size_spatial, patch_center_dist_from_border=None,
100 | do_elastic_deform=self.config.DATASET.DA.DO_ELASTIC, alpha=params.get("elastic_deform_alpha"),
101 | sigma=params.get("elastic_deform_sigma"),
102 | do_rotation=self.config.DATASET.DA.DO_ROTATION, angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"),
103 | angle_z=params.get("rotation_z"), p_rot_per_axis=params.get("rotation_p_per_axis"),
104 | do_scale=self.config.DATASET.DA.DO_SCALING, scale=params.get("scale_range"),
105 | border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=order_data,
106 | border_mode_seg="constant", border_cval_seg=border_val_seg,
107 | order_seg=order_seg, random_crop=self.config.DATASET.DA.RANDOM_CROP, p_el_per_sample=params.get("p_eldef"),
108 | p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"),
109 | independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis")
110 | ))
111 |
112 |
113 | transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
114 | transforms.append(GaussianBlurTransform((0.5, 1.), different_sigma_per_channel=True, p_per_sample=0.2,
115 | p_per_channel=0.5))
116 | transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25), p_per_sample=0.15))
117 |
118 | if self.config.DATASET.DA.DO_ADDITIVE_BRIGHTNESS:
119 | transforms.append(BrightnessTransform(params.get("additive_brightness_mu"),
120 | params.get("additive_brightness_sigma"),
121 | True, p_per_sample=params.get("additive_brightness_p_per_sample"),
122 | p_per_channel=params.get("additive_brightness_p_per_channel")))
123 |
124 | transforms.append(ContrastAugmentationTransform(p_per_sample=0.15))
125 | transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True,
126 | p_per_channel=0.5,
127 | order_downsample=0, order_upsample=3, p_per_sample=0.25,
128 | ignore_axes=ignore_axes))
129 | transforms.append(
130 | GammaTransform(params.get("gamma_range"), True, True, retain_stats=params.get("gamma_retain_stats"),
131 | p_per_sample=0.1)) # inverted gamma
132 |
133 | if self.config.DATASET.DA.DO_GAMMA:
134 | transforms.append(
135 | GammaTransform(params.get("gamma_range"), False, True, retain_stats=params.get("gamma_retain_stats"),
136 | p_per_sample=params["p_gamma"]))
137 |
138 | if self.config.DATASET.DA.DO_MIRROR:
139 | transforms.append(MirrorTransform(params.get("mirror_axes")))
140 |
141 | if deep_supervision_scales is not None:
142 | transforms.append(DownsampleSegForDSTransform(deep_supervision_scales, 0, input_key='seg',
143 | output_key='seg'))
144 |
145 | transforms.append(NumpyToTensor(['data', 'seg'], 'float'))
146 | transforms = Compose(transforms)
147 | return transforms
148 |
149 |
--------------------------------------------------------------------------------
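A minimal sketch (not part of the repository) wiring flare22_dataset into a plain torch DataLoader with the coarse-stage settings; it assumes the preprocessed .npz/.pkl pairs already exist at the configured paths (training itself uses the repository's build_loader/DataLoaderX instead):

    from torch.utils.data import DataLoader
    from config import get_config_no_args
    from data.dataset_train import flare22_dataset

    cfg = get_config_no_args()
    ds = flare22_dataset(cfg,
                         data_size=cfg.DATASET.COARSE.SIZE,
                         data_path=cfg.DATASET.COARSE.PROPRECESS_PATH,
                         unlab_data_path=cfg.DATASET.COARSE.PROPRECESS_UL_PATH,
                         pool_op_kernel_sizes=cfg.MODEL.COARSE.POOL_OP_KERNEL_SIZES,
                         num_each_epoch=cfg.DATASET.COARSE.NUM_EACH_EPOCH)
    loader = DataLoader(ds, batch_size=cfg.DATALOADER.BATCH_SIZE, num_workers=0)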
/models/swin_3D.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from timm.models.layers import DropPath, trunc_normal_
4 | from einops import rearrange
5 |
6 |
7 | class Mlp(nn.Module):
8 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
9 | super().__init__()
10 | out_features = out_features or in_features
11 | hidden_features = hidden_features or in_features
12 | self.fc1 = nn.Linear(in_features, hidden_features)
13 | self.act = act_layer()
14 | self.fc2 = nn.Linear(hidden_features, out_features)
15 | self.drop = nn.Dropout(drop)
16 |
17 | def forward(self, x):
18 | x = self.fc1(x)
19 | x = self.act(x)
20 | x = self.drop(x)
21 | x = self.fc2(x)
22 | x = self.drop(x)
23 | return x
24 |
25 |
26 | def window_partition(x, window_size):
27 |     # split (B, S, H, W, C) into (B*num_windows, p1, p2, p3, C) windows
28 | B, S, H, W, C = x.shape
29 | windows = rearrange(x, 'b (s p1) (h p2) (w p3) c -> (b s h w) p1 p2 p3 c',
30 | p1=window_size[0], p2=window_size[1], p3=window_size[2], c=C)
31 | return windows
32 |
33 |
34 | def window_reverse(windows, window_size, S, H, W):
35 | B = int(windows.shape[0] / (S * H * W /
36 | window_size[0] / window_size[1] / window_size[2]))
37 |
38 | x = rearrange(windows, '(b s h w) p1 p2 p3 c -> b (s p1) (h p2) (w p3) c',
39 | p1=window_size[0], p2=window_size[1], p3=window_size[2], b=B,
40 | s=S//window_size[0], h=H//window_size[1], w=W//window_size[2])
41 | return x
42 |
43 |
44 | class WindowAttention(nn.Module):
45 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
46 |
47 | super().__init__()
48 | self.dim = dim
49 | self.window_size = window_size
50 | self.num_heads = num_heads
51 | head_dim = dim // num_heads
52 | self.scale = qk_scale or head_dim ** -0.5
53 |
54 | self.relative_position_bias_table = nn.Parameter(
55 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1),
56 | num_heads))
57 |
58 | coords_s = torch.arange(self.window_size[0])
59 | coords_h = torch.arange(self.window_size[1])
60 | coords_w = torch.arange(self.window_size[2])
61 | coords = torch.stack(torch.meshgrid(
62 | [coords_s, coords_h, coords_w]))
63 | coords_flatten = torch.flatten(coords, 1)
64 | relative_coords = coords_flatten[:, :,
65 | None] - coords_flatten[:, None, :]
66 | relative_coords = relative_coords.permute(
67 | 1, 2, 0).contiguous()
68 | relative_coords[:, :, 0] += self.window_size[0] - 1
69 | relative_coords[:, :, 1] += self.window_size[1] - 1
70 | relative_coords[:, :, 2] += self.window_size[2] - 1
71 |
72 | relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * \
73 | (2 * self.window_size[2] - 1)
74 | relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1
75 | relative_position_index = relative_coords.sum(-1)
76 | self.register_buffer("relative_position_index",
77 | relative_position_index)
78 |
79 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
80 | self.attn_drop = nn.Dropout(attn_drop)
81 | self.proj = nn.Linear(dim, dim)
82 | self.proj_drop = nn.Dropout(proj_drop)
83 |
84 | trunc_normal_(self.relative_position_bias_table, std=.02)
85 | self.softmax = nn.Softmax(dim=-1)
86 |
87 | def forward(self, x, mask=None):
88 |         # x: (num_windows*B, N, C) with N = window volume
89 | B_, N, C = x.shape
90 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C //
91 | self.num_heads).permute(2, 0, 3, 1, 4)
92 |
93 | q, k, v = qkv[0], qkv[1], qkv[2]
94 |
95 | q = q * self.scale
96 | attn = (q @ k.transpose(-2, -1))
97 |
98 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
99 | self.window_size[0] * self.window_size[1] * self.window_size[2],
100 | self.window_size[0] * self.window_size[1] * self.window_size[2], -1)
101 | relative_position_bias = relative_position_bias.permute(
102 | 2, 0, 1).contiguous()
103 | attn = attn + relative_position_bias.unsqueeze(0)
104 |
105 | if mask is not None:
106 | nW = mask.shape[0]
107 | attn = attn.view(B_ // nW, nW, self.num_heads, N,
108 | N) + mask.unsqueeze(1).unsqueeze(0)
109 | attn = attn.view(-1, self.num_heads, N, N)
110 | attn = self.softmax(attn)
111 | else:
112 | attn = self.softmax(attn)
113 |
114 | attn = self.attn_drop(attn)
115 |
116 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
117 | x = self.proj(x)
118 | x = self.proj_drop(x)
119 | return x
120 |
121 |
122 | class SwinTransformerBlock(nn.Module):
123 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
124 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
125 | act_layer=nn.GELU, norm_layer=nn.LayerNorm):
126 | super().__init__()
127 | self.dim = dim
128 | self.input_resolution = input_resolution
129 | self.num_heads = num_heads
130 | self.window_size = window_size
131 | self.shift_size = shift_size
132 | self.mlp_ratio = mlp_ratio
133 |
134 |         if max(self.shift_size) > 0:
135 |             assert 0 <= min(self.shift_size) < min(
136 |                 self.window_size), "shift_size must be in [0, window_size)"
137 |
138 | self.norm1 = norm_layer(dim)
139 | self.attn = WindowAttention(
140 | dim, window_size=self.window_size, num_heads=num_heads,
141 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
142 |
143 | self.drop_path = DropPath(
144 | drop_path) if drop_path > 0. else nn.Identity()
145 | self.norm2 = norm_layer(dim)
146 | mlp_hidden_dim = int(dim * mlp_ratio)
147 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
148 | act_layer=act_layer, drop=drop)
149 |
150 | if max(self.shift_size) > 0:
151 |
152 | S, H, W = self.input_resolution
153 | img_mask = torch.zeros((1, S, H, W, 1))
154 | s_slices = (slice(0, -self.window_size[0]),
155 | slice(-self.window_size[0], -self.shift_size[0]),
156 | slice(-self.shift_size[0], None))
157 | h_slices = (slice(0, -self.window_size[1]),
158 | slice(-self.window_size[1], -self.shift_size[1]),
159 | slice(-self.shift_size[1], None))
160 | w_slices = (slice(0, -self.window_size[2]),
161 | slice(-self.window_size[2], -self.shift_size[2]),
162 | slice(-self.shift_size[2], None))
163 | cnt = 0
164 | for s in s_slices:
165 | for h in h_slices:
166 | for w in w_slices:
167 | img_mask[:, s, h, w, :] = cnt
168 | cnt += 1
169 |
170 | mask_windows = window_partition(img_mask, self.window_size)
171 | mask_windows = mask_windows.view(
172 | -1, self.window_size[0] * self.window_size[1] * self.window_size[2])
173 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
174 | attn_mask = attn_mask.masked_fill(
175 | attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
176 | else:
177 | attn_mask = None
178 |
179 | self.register_buffer("attn_mask", attn_mask)
180 |
181 | def forward(self, x):
182 | s, h, w = self.input_resolution
183 | B, C, S, H, W = x.shape
184 | assert S == s and H == h and W == w, "input feature has wrong size"
185 | x = rearrange(x, 'b c s h w -> b (s h w) c')
186 | shortcut = x
187 | x = self.norm1(x)
188 | x = x.view(B, S, H, W, C)
189 |
190 | # cyclic shift
191 | if max(self.shift_size) > 0:
192 | shifted_x = torch.roll(
193 | x, shifts=(-self.shift_size[0], -self.shift_size[1], -self.shift_size[2]), dims=(1, 2, 3))
194 | else:
195 | shifted_x = x
196 |
197 | x_windows = window_partition(shifted_x, self.window_size)
198 |
199 | x_windows = x_windows.view(
200 | -1, self.window_size[0] * self.window_size[1] * self.window_size[2], C)
201 |
202 | attn_windows = self.attn(x_windows, mask=self.attn_mask)
203 | attn_windows = attn_windows.view(
204 | -1, self.window_size[0], self.window_size[1], self.window_size[2], C)
205 | shifted_x = window_reverse(
206 | attn_windows, self.window_size, S, H, W)
207 | if max(self.shift_size) > 0:
208 | x = torch.roll(shifted_x, shifts=(
209 | self.shift_size[0], self.shift_size[1], self.shift_size[2]), dims=(1, 2, 3))
210 | else:
211 | x = shifted_x
212 | x = x.view(B, S * H * W, C)
213 |
214 | x = shortcut + self.drop_path(x)
215 | x = x + self.drop_path(self.mlp(self.norm2(x)))
216 | x = rearrange(x, 'b (s h w) c -> b c s h w', s=S, h=H, w=W)
217 | return x
218 |
--------------------------------------------------------------------------------
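A small sanity check (not part of the repository) showing that window_partition and window_reverse above are inverses whenever each spatial extent is divisible by the window size:

    import torch
    from models.swin_3D import window_partition, window_reverse

    x = torch.randn(2, 8, 8, 8, 16)  # (B, S, H, W, C)
    win = (4, 4, 4)
    w = window_partition(x, win)     # (16, 4, 4, 4, 16): B * 8 windows
    y = window_reverse(w, win, 8, 8, 8)
    assert torch.equal(x, y)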
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/models/phtrans.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | from timm.models.layers import trunc_normal_
5 | from .swin_3D import *
6 |
7 |
8 | class PHTrans(nn.Module):
9 | def __init__(self, img_size, base_num_features, num_classes, image_channels=1, num_only_conv_stage=2, num_conv_per_stage=2,
10 | feat_map_mul_on_downscale=2, pool_op_kernel_sizes=None,
11 | conv_kernel_sizes=None, dropout_p=0., deep_supervision=True, max_num_features=None, only_conv=False, depths=None, num_heads=None,
12 | window_size=None, mlp_ratio=4., qkv_bias=True, qk_scale=None,
13 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
14 | norm_layer=nn.LayerNorm, is_preprocess=False, **kwargs):
15 | super().__init__()
16 |
17 | conv_op = nn.Conv3d
18 | norm_op = nn.InstanceNorm3d
19 | norm_op_kwargs = {'eps': 1e-5, 'affine': True}
20 | dropout_op = nn.Dropout3d
21 | dropout_op_kwargs = {'p': dropout_p, 'inplace': True}
22 | nonlin = nn.GELU
23 | nonlin_kwargs = {}
24 |
25 | self.is_preprocess = is_preprocess
26 | self._deep_supervision = deep_supervision
27 | self.num_pool = len(pool_op_kernel_sizes)
28 | conv_pad_sizes = []
29 | for krnl in conv_kernel_sizes:
30 | conv_pad_sizes.append([1 if i == 3 else 0 for i in krnl])
31 | dpr = [x.item() for x in torch.linspace(
32 | 0, drop_path_rate, sum(depths))]
33 |
34 | self.seg_outputs = []
35 | for ds in range(self.num_pool):
36 | self.seg_outputs.append(DeepSupervision(min(
37 | (base_num_features * feat_map_mul_on_downscale ** ds), max_num_features), num_classes))
38 | self.seg_outputs = nn.ModuleList(self.seg_outputs)
39 |
40 | # build layers
41 | self.down_layers = nn.ModuleList()
42 | for i_layer in range(self.num_pool+1):
43 | layer = BasicLayer(num_stage=i_layer,
44 | only_conv=only_conv,
45 | num_only_conv_stage=num_only_conv_stage,
46 | num_pool=self.num_pool,
47 | base_num_features=base_num_features,
48 | dim=min(
49 | (base_num_features * feat_map_mul_on_downscale ** i_layer), max_num_features),
50 | input_resolution=(
51 | img_size // np.prod(pool_op_kernel_sizes[:i_layer], 0, dtype=np.int64)),
52 | depth=depths[i_layer-num_only_conv_stage] if (
53 | i_layer >= num_only_conv_stage) else None,
54 | num_heads=num_heads[i_layer-num_only_conv_stage] if (
55 | i_layer >= num_only_conv_stage) else None,
56 | window_size=window_size,
57 | image_channels=image_channels, num_conv_per_stage=num_conv_per_stage,
58 | conv_op=conv_op, norm_op=norm_op, norm_op_kwargs=norm_op_kwargs, dropout_op=dropout_op,
59 | dropout_op_kwargs=dropout_op_kwargs, nonlin=nonlin, nonlin_kwargs=nonlin_kwargs,
60 | conv_kernel_sizes=conv_kernel_sizes, conv_pad_sizes=conv_pad_sizes, pool_op_kernel_sizes=pool_op_kernel_sizes,
61 | max_num_features=max_num_features,
62 | mlp_ratio=mlp_ratio,
63 | qkv_bias=qkv_bias, qk_scale=qk_scale,
64 | drop=drop_rate, attn_drop=attn_drop_rate,
65 | drop_path=dpr[sum(depths[:i_layer-num_only_conv_stage]):sum(depths[:i_layer-num_only_conv_stage + 1])] if (
66 | i_layer >= num_only_conv_stage) else None,
67 | norm_layer=norm_layer,
68 | down_or_upsample=nn.Conv3d if i_layer > 0 else None,
69 | feat_map_mul_on_downscale=feat_map_mul_on_downscale,
70 | is_encoder=True)
71 | self.down_layers.append(layer)
72 | self.up_layers = nn.ModuleList()
73 | for i_layer in range(self.num_pool)[::-1]:
74 | layer = BasicLayer(num_stage=i_layer,
75 | only_conv=only_conv,
76 | num_only_conv_stage=num_only_conv_stage,
77 | num_pool=self.num_pool,
78 | base_num_features=base_num_features,
79 | dim=min(
80 | (base_num_features * feat_map_mul_on_downscale ** i_layer), max_num_features),
81 | input_resolution=(
82 | img_size // np.prod(pool_op_kernel_sizes[:i_layer], 0, dtype=np.int64)),
83 | depth=depths[i_layer-num_only_conv_stage] if (
84 | i_layer >= num_only_conv_stage) else None,
85 | num_heads=num_heads[i_layer-num_only_conv_stage] if (
86 | i_layer >= num_only_conv_stage) else None,
87 | window_size=window_size,
88 | image_channels=image_channels, num_conv_per_stage=num_conv_per_stage,
89 | conv_op=conv_op, norm_op=norm_op, norm_op_kwargs=norm_op_kwargs, dropout_op=dropout_op,
90 | dropout_op_kwargs=dropout_op_kwargs, nonlin=nonlin, nonlin_kwargs=nonlin_kwargs,
91 | conv_kernel_sizes=conv_kernel_sizes, conv_pad_sizes=conv_pad_sizes, pool_op_kernel_sizes=pool_op_kernel_sizes,
92 | max_num_features=max_num_features,
93 | mlp_ratio=mlp_ratio,
94 | qkv_bias=qkv_bias, qk_scale=qk_scale,
95 | drop=drop_rate, attn_drop=attn_drop_rate,
96 | drop_path=dpr[sum(depths[:i_layer-num_only_conv_stage]):sum(depths[:i_layer-num_only_conv_stage + 1])] if (
97 | i_layer >= num_only_conv_stage) else None,
98 | norm_layer=norm_layer,
99 | down_or_upsample=nn.ConvTranspose3d,
100 | feat_map_mul_on_downscale=feat_map_mul_on_downscale,
101 | is_encoder=False)
102 | self.up_layers.append(layer)
103 | self.apply(self._InitWeights)
104 |
105 | def _InitWeights(self, module):
106 |         if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
107 |             nn.init.kaiming_normal_(module.weight, a=.02)
108 |             if module.bias is not None:
109 |                 nn.init.constant_(module.bias, 0)
110 | elif isinstance(module, nn.Linear):
111 | trunc_normal_(module.weight, std=.02)
112 | if isinstance(module, nn.Linear) and module.bias is not None:
113 | nn.init.constant_(module.bias, 0)
114 | elif isinstance(module, nn.LayerNorm):
115 | nn.init.constant_(module.bias, 0)
116 | nn.init.constant_(module.weight, 1.0)
117 |
118 | def forward(self, x):
119 |
120 | x_skip = list()
121 | for i, layer in enumerate(self.down_layers):
122 | x = layer(x, None)
123 | if i < self.num_pool:
124 | x_skip.append(x)
125 | out = []
126 | for i, layer in enumerate(self.up_layers):
127 | x = layer(x, x_skip[-(i+1)])
128 | if self._deep_supervision:
129 | out.append(x)
130 |
131 | if self._deep_supervision:
132 | ds = []
133 | for i in range(len(out)):
134 | ds.append(self.seg_outputs[i](out[-(i+1)]))
135 | else:
136 | ds = self.seg_outputs[0](x)
137 |
138 | return ds
139 |
140 | @torch.jit.ignore
141 | def no_weight_decay(self):
142 | return {'absolute_pos_embed'}
143 |
144 | @torch.jit.ignore
145 | def no_weight_decay_keywords(self):
146 | return {'relative_position_bias_table'}
147 |
148 |
149 | class ConvDropoutNormNonlin(nn.Module):
150 | def __init__(self, input_channels, output_channels,
151 | conv_op=nn.Conv3d, conv_kwargs=None,
152 | norm_op=nn.BatchNorm3d, norm_op_kwargs=None,
153 | dropout_op=nn.Dropout3d, dropout_op_kwargs=None,
154 | nonlin=nn.LeakyReLU, nonlin_kwargs=None):
155 | super(ConvDropoutNormNonlin, self).__init__()
156 | self.conv = conv_op(input_channels, output_channels, **conv_kwargs)
157 |
158 | if dropout_op is not None and dropout_op_kwargs['p'] is not None and dropout_op_kwargs[
159 | 'p'] > 0:
160 |
161 | self.dropout = dropout_op(**dropout_op_kwargs)
162 | else:
163 | self.dropout = None
164 | self.instnorm = norm_op(output_channels, **norm_op_kwargs)
165 | if nonlin == nn.GELU:
166 | self.lrelu = nonlin()
167 | else:
168 | self.lrelu = nonlin(**nonlin_kwargs)
169 |
170 | def forward(self, x):
171 | x = self.conv(x)
172 | if self.dropout is not None:
173 | x = self.dropout(x)
174 | return self.lrelu(self.instnorm(x))
175 |
176 |
177 | class DeepSupervision(nn.Module):
178 | def __init__(self, dim, num_classes):
179 | super().__init__()
180 | self.proj = nn.Conv3d(
181 | dim, num_classes, kernel_size=1, stride=1, bias=False)
182 |
183 | def forward(self, x):
184 | x = self.proj(x)
185 | return x
186 |
187 |
188 | class BasicLayer(nn.Module):
189 |
190 | def __init__(self, num_stage, only_conv, num_only_conv_stage, num_pool, base_num_features, dim, input_resolution, depth, num_heads,
191 | window_size, image_channels=1, num_conv_per_stage=2, conv_op=None,
192 | norm_op=None, norm_op_kwargs=None,
193 | dropout_op=None, dropout_op_kwargs=None,
194 | nonlin=None, nonlin_kwargs=None,
195 | conv_kernel_sizes=None, conv_pad_sizes=None, pool_op_kernel_sizes=None, basic_block=ConvDropoutNormNonlin, max_num_features=None,
196 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
197 | drop_path=0., norm_layer=nn.LayerNorm, down_or_upsample=None, feat_map_mul_on_downscale=2, is_encoder=True):
198 |
199 | super().__init__()
200 | self.num_stage = num_stage
201 | self.only_conv = only_conv
202 | self.num_only_conv_stage = num_only_conv_stage
203 | self.num_pool = num_pool
204 | self.is_encoder = is_encoder
205 | self.image_channels = image_channels
206 | conv_kwargs = {'stride': 1, 'dilation': 1, 'bias': True}
207 |
208 | if is_encoder:
209 | input_features = dim
210 | else:
211 | input_features = 2*dim
212 |
213 | # self.depth = depth
214 | conv_kwargs['kernel_size'] = conv_kernel_sizes[num_stage]
215 | conv_kwargs['padding'] = conv_pad_sizes[num_stage]
216 |
217 | input_du_channels = min(int(base_num_features * feat_map_mul_on_downscale ** (num_stage-1 if is_encoder else num_stage+1)),
218 | max_num_features)
219 | output_du_channels = dim
220 |         if self.is_encoder and self.num_stage == 0:
221 |             self.first_conv = conv_op(
222 |                 image_channels, dim, kernel_size=1, stride=1, bias=True)
223 |         else:
224 |             self.first_conv = None
225 | self.conv_blocks = nn.Sequential(
226 | *([basic_block(input_features, dim, conv_op,
227 | conv_kwargs,
228 | norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
229 | nonlin, nonlin_kwargs)] +
230 | [basic_block(dim, dim, conv_op,
231 | conv_kwargs,
232 | norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
233 | nonlin, nonlin_kwargs) for _ in range(num_conv_per_stage - 1)]))
234 |
235 | # build blocks
236 | if num_stage >= num_only_conv_stage and not only_conv:
237 | self.swin_blocks = nn.ModuleList([
238 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
239 | num_heads=num_heads, window_size=window_size,
240 | shift_size=[0, 0, 0] if (i % 2 == 0) else [
241 | window_size[0] // 2, window_size[1] // 2, window_size[2] // 2],
242 | mlp_ratio=mlp_ratio,
243 | qkv_bias=qkv_bias, qk_scale=qk_scale,
244 | drop=drop, attn_drop=attn_drop,
245 | drop_path=drop_path[i] if isinstance(
246 | drop_path, list) else drop_path,
247 | norm_layer=norm_layer)
248 | for i in range(depth)])
249 |
250 | # patch merging layer
251 | if down_or_upsample is not None:
252 |             down_stage = num_stage-1 if is_encoder else num_stage
253 |             self.down_or_upsample = nn.Sequential(down_or_upsample(input_du_channels, output_du_channels, pool_op_kernel_sizes[down_stage],
254 |                                                                    pool_op_kernel_sizes[down_stage], bias=False),
255 | norm_op(
256 | output_du_channels, **norm_op_kwargs)
257 | )
258 | else:
259 | self.down_or_upsample = None
260 |
261 | def forward(self, x, skip):
262 |         if self.first_conv is not None:
263 |             x = self.first_conv(x)
264 | if self.down_or_upsample is not None:
265 | x = self.down_or_upsample(x)
266 | s = x
267 | if not self.is_encoder and self.num_stage < self.num_pool:
268 | x = torch.cat((x, skip), dim=1)
269 | x = self.conv_blocks(x)
270 | if self.num_stage >= self.num_only_conv_stage and not self.only_conv:
271 | if not self.is_encoder and self.num_stage < self.num_pool:
272 | s = s + skip
273 | for tblk in self.swin_blocks:
274 | s = tblk(s)
275 | x = x + s
276 | return x
277 |
--------------------------------------------------------------------------------
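Finally, a minimal sketch (not part of the repository) instantiating PHTrans with the coarse-model hyperparameters from config.py; img_size is passed as a NumPy array because the constructor divides it by the cumulative pooling kernel sizes:

    import numpy as np
    import torch
    from models.phtrans import PHTrans

    model = PHTrans(img_size=np.array([64, 64, 64]),
                    base_num_features=16, num_classes=2,
                    num_only_conv_stage=2, num_conv_per_stage=2,
                    feat_map_mul_on_downscale=2,
                    pool_op_kernel_sizes=[[2, 2, 2]] * 4,
                    conv_kernel_sizes=[[3, 3, 3]] * 5,
                    deep_supervision=True, max_num_features=200,
                    depths=[2, 2, 2, 2], num_heads=[4, 4, 4, 4],
                    window_size=[4, 4, 4], mlp_ratio=1.)
    outs = model(torch.randn(1, 1, 64, 64, 64))
    print([o.shape for o in outs])  # one logit map per deep-supervision level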