├── src ├── utils.py ├── models │ ├── unetaspp │ │ ├── __init__.py │ │ └── aspp.py │ ├── vnet │ │ ├── __init__.py │ │ ├── scse.py │ │ ├── cbam.py │ │ ├── aspp.py │ │ └── vnet.py │ ├── ModelDeepLab │ │ ├── __init__.py │ │ ├── deeplab.py │ │ ├── backbones_deeplab.py │ │ ├── resnext.py │ │ ├── deeplab_jpu.py │ │ └── resnet.py │ ├── __init__.py │ ├── unet3d │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── buildingblocks.py │ │ └── models.py │ └── DialResUnet.py ├── runner.py ├── __init__.py ├── extract_slices.py ├── losses.py ├── augmentation.py ├── callbacks.py ├── inference.py ├── experiment.py ├── preprocessing.py ├── swa.py ├── prepare_data.py ├── Segmentation.py ├── Segmentation3d.py ├── optimizers.py ├── schedulers.py └── dataset.py ├── bin ├── train.sh ├── train_2d.sh └── extract_slices.sh ├── README.md ├── configs ├── config.yml └── config_2d.yml └── requirements.txt /src/utils.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/unetaspp/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/vnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .vnet import VNet -------------------------------------------------------------------------------- /src/models/ModelDeepLab/__init__.py: -------------------------------------------------------------------------------- 1 | # from .deeplab import DeepLab 2 | # from .deeplab_jpu import DeepLab as DeepLabJPU -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet3d import * 2 | # from .ModelDeepLab.deeplab import DeepLab 3 | from .vnet import VNet -------------------------------------------------------------------------------- /src/models/unet3d/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import UNet3D, ResidualUNet3D, Noise2NoiseUNet3D, TagsUNet3D, DistanceTransformUNet3D, EndToEndDTUNet3D, UNet3D2 -------------------------------------------------------------------------------- /src/runner.py: -------------------------------------------------------------------------------- 1 | try: 2 | import os 3 | 4 | if os.environ.get("USE_WANDB", "1") == "1": 5 | from catalyst.dl import SupervisedWandbRunner as Runner 6 | else: 7 | from catalyst.dl import SupervisedRunner as Runner 8 | except ImportError: 9 | from catalyst.dl import SupervisedRunner as Runner 10 | 11 | 12 | class ModelRunner(Runner): 13 | def __init__(self, model=None, device=None): 14 | super().__init__( 15 | model=model, device=device, input_key="images", output_key='logits' 16 | ) 17 | -------------------------------------------------------------------------------- /src/models/vnet/scse.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class SCSEModule(nn.Module): 5 | def __init__(self, in_channels, reduction=16): 6 | super().__init__() 7 | self.cSE = nn.Sequential( 8 | nn.AdaptiveAvgPool2d(1), 9 | nn.Conv2d(in_channels, in_channels // reduction, 1), 10 | nn.ReLU(inplace=True), 11 | nn.Conv2d(in_channels // reduction, in_channels, 1), 12 | nn.Sigmoid(), 13 | ) 14 | self.sSE = 
nn.Sequential(nn.Conv2d(in_channels, in_channels, 1), nn.Sigmoid()) 15 | 16 | def forward(self, x): 17 | return x * self.cSE(x) + x * self.sSE(x) 18 | -------------------------------------------------------------------------------- /bin/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export CUDA_VISIBLE_DEVICES=2,3 4 | RUN_CONFIG=config.yml 5 | 6 | for model in se_resnext101_32x4d; do 7 | for fold in 0 1 2 3 4; do 8 | log_name=Vnet-$model-weighted-cedice19-cbam-fold-${fold} 9 | # tag="["Unet","$model","$loss","fold-$fold"]" 10 | #stage 1 11 | LOGDIR=/logs/ss_task3/${log_name}/ 12 | catalyst-dl run \ 13 | --config=./configs/${RUN_CONFIG} \ 14 | --logdir=$LOGDIR \ 15 | --out_dir=$LOGDIR:str \ 16 | --model_params/encoder_name=$model:str \ 17 | --monitoring_params/name=${log_name}:str \ 18 | --stages/data_params/train_csv=./csv/5folds/train_$fold.csv:str \ 19 | --stages/data_params/valid_csv=./csv/5folds/valid_$fold.csv:str \ 20 | --verbose 21 | done 22 | done -------------------------------------------------------------------------------- /bin/train_2d.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0,1 4 | RUN_CONFIG=config_2d.yml 5 | 6 | for model in se_resnext50_32x4d resnet34; do 7 | for fold in 0 1 2 3 4; do 8 | log_name=FPN-$model-fold-${fold} 9 | # tag="["Unet","$model","$loss","fold-$fold"]" 10 | #stage 1 11 | LOGDIR=/logs/ss_miccai/${log_name}/ 12 | 13 | # catalyst-dl trace $LOGDIR 14 | USE_WANDB=0 catalyst-dl run \ 15 | --config ./configs/${RUN_CONFIG} \ 16 | --logdir=$LOGDIR \ 17 | --out_dir=$LOGDIR:str \ 18 | --model_params/encoder_name=$model:str \ 19 | --monitoring_params/name=${log_name}:str \ 20 | --stages/data_params/train_csv=./csv/5folds/train_$fold.csv:str \ 21 | --stages/data_params/valid_csv=./csv/5folds/valid_$fold.csv:str \ 22 | --verbose 23 | done 24 | done -------------------------------------------------------------------------------- /bin/extract_slices.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export LC_ALL=C.UTF-8 4 | export LANG=C.UTF-8 5 | 6 | 7 | slice_thichness=1 8 | scale_ratio=0.125 9 | slice=16 10 | patch=32 11 | 12 | #python src/preprocessing.py extract --csv_file ./data/Lung_GTV/idx-train.csv \ 13 | # --root ./data/Lung_GTV/ \ 14 | # --save_dir ./data/Lung_GTV_st${slice_thichness}_sr${scale_ratio}_s${slice}_p${patch} \ 15 | # --slice_thichness $slice_thichness \ 16 | # --scale_ratio $scale_ratio \ 17 | # --slice $slice \ 18 | # --patch $patch 19 | 20 | 21 | 22 | python src/preprocessing.py extract-2d --root /data/Thoracic_OAR/ \ 23 | --save_dir /data/Thoracic_OAR_2d/ 24 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from catalyst.dl import registry 3 | from .experiment import Experiment 4 | from .runner import ModelRunner as Runner 5 | from models import * 6 | from losses import * 7 | from callbacks import * 8 | from optimizers import * 9 | from schedulers import * 10 | from segmentation_models_pytorch import Unet as smpUnet 11 | from segmentation_models_pytorch import FPN 12 | 13 | import torchvision 14 | 15 | 16 | # Register models 17 | registry.Model(UNet3D) 18 | registry.Model(UNet3D2) 19 | registry.Model(ResidualUNet3D) 20 | registry.Model(VNet) 21 | 
registry.Model(smpUnet) 22 | registry.Model(FPN) 23 | # registry.Model(DeepLab) 24 | registry.MODELS._late_add_callbacks = [] 25 | 26 | # Register callbacks 27 | registry.Callback(MultiDiceCallback) 28 | 29 | # Register criterions 30 | registry.Criterion(MultiDiceLoss) 31 | 32 | # Register optimizers 33 | # registry.Optimizer(AdamW) 34 | # registry.Optimizer(Nadam) 35 | 36 | # registry.Scheduler(CyclicLRFix) -------------------------------------------------------------------------------- /src/extract_slices.py: -------------------------------------------------------------------------------- 1 | import click 2 | import pandas as pd 3 | import os 4 | import numpy as np 5 | from dataset import slice_builder 6 | 7 | 8 | @click.group() 9 | def cli(): 10 | print("Extract slices") 11 | 12 | 13 | @cli.command() 14 | @click.option('--csv_file', type=str) 15 | @click.option('--root', type=str) 16 | @click.option('--save_dir', type=str) 17 | @click.option('--slice', type=int) 18 | @click.option('--patch', type=int) 19 | def extract( 20 | csv_file, 21 | root, 22 | save_dir=None, 23 | slice=16, 24 | patch=32 25 | ): 26 | df = pd.read_csv(csv_file) 27 | all_patient_df = [] 28 | for imgpath, mskpath in zip(df.path, df.pathmsk): 29 | imgpath = os.path.join(root, imgpath) 30 | mskpath = os.path.join(root, mskpath) 31 | patient_df = slice_builder(imgpath, mskpath, slice, patch, save_dir) 32 | all_patient_df.append(patient_df) 33 | all_patient_df = pd.concat(all_patient_df, axis=0).reset_index(drop=True) 34 | all_patient_df.to_csv(os.path.join(save_dir, 'data.csv')) 35 | 36 | 37 | if __name__ == '__main__': 38 | cli() 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StructSeg2019 2 | 3rd place in Task 3 and 5th place in Task 4 of the StructSeg2019 competition (MICCAI 2019) 3 | 4 | # Overview 5 | 6 | This repository contains the 3rd-place solution for Task 3 and the 5th-place solution for Task 4 of the [StructSeg2019](https://structseg2019.grand-challenge.org/) challenge, held as part of MICCAI 2019. 7 | 8 | 9 | # Requirements 10 | - catalyst==19.9.1 11 | - albumentations==0.3.2 12 | - segmentation-models-pytorch==0.0.2 13 | 14 | 15 | # Note 16 | The model in this repository is named `VNet`, but it is not the V-Net of this [paper](https://arxiv.org/abs/1606.04797). The `V` is a personal naming choice (the full name is `VUONGNet`). 17 | 18 | # How to run 19 | 20 | ## Extract 2d slices 21 | Change the input and output paths in [extract_slices.sh](bin/extract_slices.sh#L22). 22 | 23 | ```bash 24 | bash bin/extract_slices.sh 25 | ``` 26 | 27 | The output directory should contain a numpy array for each slice and a csv index file (`data.csv`). 28 | 29 | ## Split kfold 30 | 31 | ```bash 32 | python src/preprocessing.py split-kfold --csv_file <path/to/data.csv> --n_folds 5 --save_dir <output_dir> 33 | ``` 34 | 35 | ## Train 36 | 37 | Both tasks are trained on 2D slices.
38 | All the settings are placed at: [config_2d.yml](configs/config_2d.yml) 39 | 40 | ```bash 41 | bash bin/train_2d.sh 42 | ``` 43 | -------------------------------------------------------------------------------- /src/losses.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from catalyst.dl.utils import criterion 7 | from catalyst.utils import get_activation_fn 8 | 9 | 10 | class MultiDiceLoss(nn.Module): 11 | def __init__( 12 | self, 13 | activation: str = "Softmax", 14 | num_classes: int = 7, 15 | weight = None, 16 | dice_weight: float = 0.3, 17 | ): 18 | super().__init__() 19 | if weight is not None: 20 | weight = torch.from_numpy(np.asarray(weight).astype(np.float32)) 21 | else: 22 | weight = None 23 | self.num_classes = num_classes 24 | self.activation = activation 25 | self.dice_weight = dice_weight 26 | self.ce_loss = nn.CrossEntropyLoss(weight=weight) 27 | self.dice_loss = criterion.dice 28 | 29 | def forward(self, logits, targets): 30 | activation_fnc = get_activation_fn(self.activation) 31 | logits_softmax = activation_fnc(logits) 32 | 33 | ce_loss = self.ce_loss(logits, targets) 34 | 35 | dice_loss = 0 36 | for cls in range(self.num_classes): 37 | targets_cls = (targets == cls).float() 38 | outputs_cls = logits_softmax[:, cls] 39 | score = 1 - criterion.dice(outputs_cls, targets_cls, eps=1e-7, activation='none', threshold=None) 40 | dice_loss += score / self.num_classes 41 | 42 | loss = (1 - self.dice_weight) * ce_loss + self.dice_weight * dice_loss 43 | return loss 44 | -------------------------------------------------------------------------------- /src/augmentation.py: -------------------------------------------------------------------------------- 1 | from albumentations import * 2 | 3 | import itertools 4 | 5 | 6 | def train_aug(image_size=224): 7 | # return Compose([ 8 | # HorizontalFlip(), 9 | # Normalize() 10 | # ],p=1) 11 | 12 | return Compose([ 13 | Resize(image_size, image_size), 14 | HorizontalFlip(p=0.5), 15 | # OneOf([ 16 | # RandomContrast(), 17 | # RandomGamma(), 18 | # RandomBrightness(), 19 | # ], p=0.3), 20 | OneOf([ 21 | ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03), 22 | GridDistortion(), 23 | OpticalDistortion(distort_limit=2, shift_limit=0.5), 24 | ], p=0.3), 25 | ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=15), 26 | Normalize(max_pixel_value=1) 27 | # RandomSizedCrop(min_max_height=(156, 256), height=h, width=w, p=0.25), 28 | # ToFloat(max_value=1) 29 | ], p=1) 30 | 31 | 32 | def valid_aug(image_size=224): 33 | return Compose([ 34 | Resize(image_size, image_size), 35 | Normalize(max_pixel_value=1) 36 | ], p=1) 37 | 38 | 39 | def test_tta(image_size): 40 | test_dict = { 41 | 'normal': Compose([ 42 | Resize(image_size, image_size) 43 | ]), 44 | # 'hflip': Compose([ 45 | # HorizontalFlip(p=1), 46 | # Resize(image_size, image_size), 47 | # ], p=1), 48 | # 'rot90': Compose([ 49 | # Rotate(limit=(90, 90), p=1), 50 | # Resize(image_size, image_size), 51 | # ], p=1), 52 | } 53 | 54 | return test_dict -------------------------------------------------------------------------------- /src/callbacks.py: -------------------------------------------------------------------------------- 1 | from catalyst.dl.core import Callback, CallbackOrder 2 | from catalyst.dl.utils import criterion 3 | from catalyst.dl.core.state import RunnerState 4 | from catalyst.utils import get_activation_fn 
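# MultiDiceCallback (defined below) logs a mean multi-class Dice metric on every batch:
# the logits are passed through the configured activation (Softmax), argmax'd into a
# label map, and the Dice score is averaged over all `num_classes` classes
# (background class 0 included).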
5 | 6 | 7 | class MultiDiceCallback(Callback): 8 | """ 9 | Dice metric callback. 10 | """ 11 | 12 | def __init__( 13 | self, 14 | input_key: str = "targets", 15 | output_key: str = "logits", 16 | prefix: str = "dice", 17 | activation: str = "Softmax", 18 | num_classes : int = 7, 19 | ): 20 | """ 21 | :param input_key: input key to use for dice calculation; 22 | specifies our `y_true`. 23 | :param output_key: output key to use for dice calculation; 24 | specifies our `y_pred`. 25 | """ 26 | super().__init__(CallbackOrder.Metric) 27 | self.input_key = input_key 28 | self.output_key = output_key 29 | self.prefix = prefix 30 | self.activation = activation 31 | self.num_classes = num_classes 32 | 33 | def on_batch_end(self, state: RunnerState): 34 | outputs = state.output[self.output_key] 35 | targets = state.input[self.input_key] 36 | 37 | activation_fnc = get_activation_fn(self.activation) 38 | outputs = activation_fnc(outputs) 39 | _, outputs = outputs.max(dim=1) 40 | 41 | dice = 0 42 | for cls in range(self.num_classes): 43 | targets_cls = (targets == cls).float() 44 | outputs_cls = (outputs == cls).float() 45 | score = criterion.dice(outputs_cls, targets_cls, eps=1e-7, activation='none', threshold=None) 46 | dice += score / self.num_classes 47 | state.metrics.add_batch_value(name=self.prefix, value=dice) 48 | -------------------------------------------------------------------------------- /src/models/vnet/cbam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class CBAM_Module(nn.Module): 6 | def __init__(self, channels, reduction,attention_kernel_size=3): 7 | super(CBAM_Module, self).__init__() 8 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 9 | self.max_pool = nn.AdaptiveMaxPool2d(1) 10 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, 11 | padding=0) 12 | self.relu = nn.ReLU(inplace=True) 13 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, 14 | padding=0) 15 | self.sigmoid_channel = nn.Sigmoid() 16 | k=2 17 | self.conv_after_concat = nn.Conv2d(k, 1, 18 | kernel_size = attention_kernel_size, 19 | stride=1, 20 | padding = attention_kernel_size//2) 21 | self.sigmoid_spatial = nn.Sigmoid() 22 | 23 | def forward(self, x): 24 | # Channel attention module 25 | module_input = x 26 | avg = self.avg_pool(x) 27 | mx = self.max_pool(x) 28 | avg = self.fc1(avg) 29 | mx = self.fc1(mx) 30 | avg = self.relu(avg) 31 | mx = self.relu(mx) 32 | avg = self.fc2(avg) 33 | mx = self.fc2(mx) 34 | x = avg + mx 35 | x = self.sigmoid_channel(x) 36 | # Spatial attention module 37 | x = module_input * x 38 | module_input = x 39 | b, c, h, w = x.size() 40 | avg = torch.mean(x, 1, True) 41 | mx, _ = torch.max(x, 1, True) 42 | x = torch.cat((avg, mx), 1) 43 | x = self.conv_after_concat(x) 44 | x = self.sigmoid_spatial(x) 45 | x = module_input * x 46 | return x 47 | -------------------------------------------------------------------------------- /configs/config.yml: -------------------------------------------------------------------------------- 1 | model_params: 2 | model: &model VNet 3 | encoder_name: &encoder_name "resnet34" 4 | group_norm: &group_norm False 5 | classes: 7 6 | center: !!str ¢er 'none' 7 | attention_type: !!str &attention_type 'cbam' 8 | reslink: True 9 | multi_task: &multi_task False 10 | 11 | 12 | args: 13 | expdir: "src" 14 | logdir: &logdir "./logs/structseg" 15 | baselogdir: "./logs/structseg" 16 | 17 | distributed_params: 18 | opt_level: O1 19 | 20 | 21 | stages: 22 | 
23 | state_params: 24 | main_metric: &reduce_metric dice 25 | minimize_metric: False 26 | 27 | criterion_params: 28 | criterion: &criterion MultiDiceLoss 29 | activation: 'Softmax' 30 | weight: [0.1, 0.2, 0.2, 0.3, 0.4, 0.4, 0.4] 31 | dice_weight: 0.9 32 | 33 | data_params: 34 | batch_size: 16 35 | num_workers: 4 36 | drop_last: False 37 | # drop_last: True 38 | 39 | image_size: &image_size 512 40 | train_csv: "./csv/train_0.csv" 41 | valid_csv: "./csv/valid_0.csv" 42 | data: "2D" 43 | 44 | stage1: 45 | 46 | optimizer_params: 47 | optimizer: Adam 48 | lr: 0.0001 49 | 50 | scheduler_params: 51 | scheduler: OneCycleLR 52 | num_steps: &num_epochs 20 53 | lr_range: [0.0005, 0.00001] 54 | warmup_steps: 5 55 | momentum_range: [0.85, 0.95] 56 | 57 | state_params: 58 | num_epochs: *num_epochs 59 | 60 | callbacks_params: 61 | loss: 62 | callback: CriterionCallback 63 | optimizer: 64 | callback: OptimizerCallback 65 | accumulation_steps: 1 66 | dice: 67 | callback: MultiDiceCallback 68 | activation: 'Softmax' 69 | scheduler: 70 | callback: SchedulerCallback 71 | reduce_metric: *reduce_metric 72 | # mode: 'batch' 73 | saver: 74 | callback: CheckpointCallback 75 | 76 | monitoring_params: 77 | project: "SS_Task3" 78 | tags: [*model, *encoder_name, *criterion] -------------------------------------------------------------------------------- /configs/config_2d.yml: -------------------------------------------------------------------------------- 1 | model_params: 2 | model: &model VNet 3 | encoder_name: &encoder_name "resnet34" 4 | group_norm: &group_norm False 5 | classes: 7 6 | center: !!str ¢er 'none' 7 | attention_type: !!str &attention_type 'none' 8 | reslink: False 9 | multi_task: &multi_task False 10 | 11 | 12 | args: 13 | expdir: "src" 14 | logdir: &logdir "./logs/structseg" 15 | baselogdir: "./logs/structseg" 16 | 17 | distributed_params: 18 | opt_level: O1 19 | 20 | 21 | stages: 22 | 23 | state_params: 24 | main_metric: &reduce_metric dice 25 | minimize_metric: False 26 | 27 | criterion_params: 28 | criterion: &criterion MultiDiceLoss 29 | activation: 'Softmax' 30 | weight: [0.1, 0.2, 0.2, 0.3, 0.4, 0.4, 0.4] 31 | dice_weight: 0.9 32 | 33 | data_params: 34 | batch_size: 16 35 | num_workers: 4 36 | drop_last: False 37 | # drop_last: True 38 | 39 | image_size: &image_size 512 40 | train_csv: "./csv/train_0.csv" 41 | valid_csv: "./csv/valid_0.csv" 42 | data: "2D" 43 | 44 | stage1: 45 | 46 | optimizer_params: 47 | optimizer: Adam 48 | lr: 0.0001 49 | 50 | scheduler_params: 51 | scheduler: OneCycleLRWithWarmup 52 | num_steps: &num_epochs 50 53 | lr_range: [0.0005, 0.00001] 54 | warmup_steps: 5 55 | momentum_range: [0.85, 0.95] 56 | 57 | state_params: 58 | num_epochs: *num_epochs 59 | 60 | callbacks_params: 61 | loss: 62 | callback: CriterionCallback 63 | optimizer: 64 | callback: OptimizerCallback 65 | accumulation_steps: 1 66 | dice: 67 | callback: MultiDiceCallback 68 | activation: 'Softmax' 69 | scheduler: 70 | callback: SchedulerCallback 71 | reduce_metric: *reduce_metric 72 | # mode: 'batch' 73 | saver: 74 | callback: CheckpointCallback 75 | 76 | monitoring_params: 77 | project: "SS_Task3" 78 | tags: [*model, *encoder_name, *criterion] -------------------------------------------------------------------------------- /src/inference.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as Ftorch 7 | from torch.utils.data import DataLoader 8 | 
import os 9 | import glob 10 | import click 11 | from tqdm import * 12 | import cv2 13 | 14 | from models import * 15 | from segmentation_models_pytorch.unet import Unet 16 | from augmentation import * 17 | from dataset import * 18 | from utils import * 19 | 20 | 21 | device = torch.device('cuda') 22 | 23 | 24 | def predict(model, loader): 25 | model.eval() 26 | preds = [] 27 | gts = [] 28 | with torch.no_grad(): 29 | for dct in tqdm(loader, total=len(loader)): 30 | images = dct['images'].to(device) 31 | pred = model(images) 32 | pred = Ftorch.sigmoid(pred) 33 | pred = pred.detach().cpu().numpy() 34 | preds.append(pred) 35 | mask = dct['targets'].numpy() 36 | gts.append(mask) 37 | 38 | preds = np.concatenate(preds, axis=0) 39 | gts = np.concatenate(gts, axis=0) 40 | return preds, gts 41 | 42 | 43 | data_csv = "../Lung_GTV_2d/data.csv" 44 | log_dir = f"../logs/Unet-SEResnext50-0/" 45 | 46 | 47 | def predict_valid(): 48 | test_csv = './csv/valid_0.csv' 49 | 50 | model = Unet( 51 | encoder_name='se_resnext50_32x4d', 52 | classes=1, 53 | activation='sigmoid' 54 | ) 55 | ckp = os.path.join(log_dir, "checkpoints/best.pth") 56 | checkpoint = torch.load(ckp) 57 | model.load_state_dict(checkpoint['model_state_dict']) 58 | model = nn.DataParallel(model) 59 | model = model.to(device) 60 | 61 | print("*" * 50) 62 | print(f"checkpoint: {ckp}") 63 | # Dataset 64 | dataset = StructSegTrain2D( 65 | csv_file=test_csv, 66 | data_csv=data_csv, 67 | transform=valid_aug(image_size=512), 68 | ) 69 | 70 | loader = DataLoader( 71 | dataset=dataset, 72 | batch_size=32, 73 | shuffle=False, 74 | num_workers=8, 75 | ) 76 | 77 | preds, gts = predict(model, loader) 78 | 79 | os.makedirs("./prediction/", exist_ok=True) 80 | np.save(f"./prediction/valid.npy", preds) 81 | np.save(f"./prediction/gts.npy", gts) 82 | 83 | 84 | if __name__ == '__main__': 85 | # predict_test() 86 | predict_valid() 87 | -------------------------------------------------------------------------------- /src/experiment.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import torch 3 | import torch.nn as nn 4 | from torch.utils.data import ConcatDataset 5 | import random 6 | from catalyst.dl.experiment import ConfigExperiment 7 | from dataset import * 8 | from augmentation import train_aug, valid_aug 9 | 10 | 11 | class Experiment(ConfigExperiment): 12 | def _postprocess_model_for_stage(self, stage: str, model: nn.Module): 13 | 14 | import warnings 15 | warnings.filterwarnings("ignore") 16 | 17 | random.seed(2411) 18 | np.random.seed(2411) 19 | torch.manual_seed(2411) 20 | 21 | model_ = model 22 | if isinstance(model, torch.nn.DataParallel): 23 | model_ = model_.module 24 | 25 | return model_ 26 | 27 | def get_datasets(self, stage: str, **kwargs): 28 | datasets = OrderedDict() 29 | 30 | """ 31 | image_key: 'id' 32 | label_key: 'attribute_ids' 33 | """ 34 | 35 | image_size = kwargs.get("image_size", 224) 36 | train_csv = kwargs.get('train_csv', None) 37 | valid_csv = kwargs.get('valid_csv', None) 38 | root = kwargs.get('root', None) 39 | data_csv = kwargs.get('data_csv', None) 40 | data = kwargs.get('data', '2D') 41 | 42 | if train_csv: 43 | transform = train_aug(image_size) 44 | if data == '2D': 45 | train_set = StructSegTrain2D( 46 | csv_file=train_csv, 47 | transform=transform, 48 | ) 49 | else: 50 | train_set = StructSegTrain3D( 51 | csv_file=train_csv, 52 | transform=transform, 53 | mode='train' 54 | ) 55 | datasets["train"] = train_set 56 | 57 | if valid_csv: 58 | 
transform = valid_aug(image_size) 59 | if data == '2D': 60 | valid_set = StructSegTrain2D( 61 | csv_file=valid_csv, 62 | transform=transform, 63 | ) 64 | else: 65 | valid_set = StructSegTrain3D( 66 | csv_file=valid_csv, 67 | transform=transform, 68 | mode='valid' 69 | ) 70 | datasets["valid"] = valid_set 71 | 72 | return datasets 73 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | addict==2.2.1 2 | albumentations==0.3.2 3 | apex==0.1 4 | argh==0.26.2 5 | asn1crypto==0.24.0 6 | backcall==0.1.0 7 | beautifulsoup4==4.7.1 8 | catalyst==19.9.1 9 | certifi==2019.3.9 10 | cffi==1.12.3 11 | chardet==3.0.4 12 | click==7.0 13 | cnn-finetune==0.6.0 14 | conda-build==3.17.8 15 | conda==4.6.14 16 | configparser==4.0.2 17 | crc32c==1.7 18 | cryptography==2.6.1 19 | cycler==0.10.0 20 | cython==0.29.13 21 | decorator==4.4.0 22 | docker-pycreds==0.4.0 23 | filelock==3.0.10 24 | gitdb2==2.0.5 25 | gitpython==3.0.2 26 | glob2==0.6 27 | gql==0.1.0 28 | graphql-core==2.2.1 29 | idna==2.8 30 | imagecorruptions==1.0.0 31 | imageio==2.5.0 32 | imgaug==0.2.6 33 | ipython-genutils==0.2.0 34 | ipython==7.5.0 35 | jedi==0.13.3 36 | jinja2==2.10.1 37 | joblib==0.13.2 38 | kiwisolver==1.1.0 39 | libarchive-c==2.8 40 | lief==0.9.0 41 | lz4==2.2.1 42 | markupsafe==1.1.1 43 | matplotlib==3.1.1 44 | mkl-fft==1.0.12 45 | mkl-random==1.0.2 46 | mmcv==0.2.13 47 | mmdet==1.0rc0+c64beaf 48 | munch==2.3.2 49 | networkx==2.3 50 | numpy==1.17.1 51 | nvidia-ml-py3==7.352.0 52 | olefile==0.46 53 | opencv-python-headless==4.1.1.26 54 | opencv-python==4.1.1.26 55 | pandas==0.25.1 56 | parso==0.4.0 57 | pathtools==0.1.2 58 | pexpect==4.7.0 59 | pickleshare==0.7.5 60 | pillow==6.0.0 61 | pip==19.1 62 | pkginfo==1.5.0.1 63 | plotly==4.1.1 64 | pretrainedmodels==0.7.4 65 | promise==2.2.1 66 | prompt-toolkit==2.0.9 67 | protobuf==3.9.1 68 | psutil==5.6.2 69 | ptyprocess==0.6.0 70 | pyarrow==0.14.1 71 | pycocotools==2.0.0 72 | pycosat==0.6.3 73 | pycparser==2.19 74 | pygments==2.3.1 75 | pyopenssl==19.0.0 76 | pyparsing==2.4.2 77 | pysocks==1.6.8 78 | python-dateutil==2.8.0 79 | pytz==2019.1 80 | pywavelets==1.0.3 81 | pyyaml==5.1 82 | requests==2.21.0 83 | retrying==1.3.3 84 | ruamel-yaml==0.15.46 85 | rx==1.6.1 86 | safitty==1.2.5 87 | scikit-image==0.15.0 88 | scikit-learn==0.21.3 89 | scipy==1.3.1 90 | seaborn==0.9.0 91 | segmentation-models-pytorch==0.0.2 92 | sentry-sdk==0.12.0 93 | setuptools==41.0.1 94 | shortuuid==0.5.0 95 | simpleitk==1.2.2 96 | six==1.12.0 97 | smmap2==2.0.5 98 | soupsieve==1.8 99 | subprocess32==3.5.4 100 | tensorboardx==1.8 101 | terminaltables==3.1.0 102 | timm==0.1.12 103 | torch==1.2.0 104 | torchvision==0.4.0 105 | tqdm==4.31.1 106 | traitlets==4.3.2 107 | urllib3==1.24.2 108 | wandb==0.8.10 109 | watchdog==0.9.0 110 | wcwidth==0.1.7 111 | wheel==0.33.1 112 | -------------------------------------------------------------------------------- /src/models/unetaspp/aspp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class GroupNorm32(nn.GroupNorm): 7 | def __init__(self, num_channels): 8 | super(GroupNorm32, self).__init__(num_channels=num_channels, num_groups=32) 9 | 10 | class _ASPPModule(nn.Module): 11 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, norm_layer): 12 | super(_ASPPModule, self).__init__() 13 | self.norm = 
norm_layer(planes) 14 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 15 | stride=1, padding=padding, dilation=dilation, bias=False) 16 | self.elu = nn.ELU(True) 17 | 18 | def forward(self, x): 19 | x = self.atrous_conv(x) 20 | x = self.norm(x) 21 | return self.elu(x) 22 | 23 | class ASPP(nn.Module): 24 | def __init__(self, dilations, inplanes, planes, norm_layer, dropout=0.5): 25 | super(ASPP, self).__init__() 26 | 27 | self.aspp1 = _ASPPModule(inplanes, planes, 1, padding=0, dilation=dilations[0], norm_layer=norm_layer) 28 | self.aspp2 = _ASPPModule(inplanes, planes, 3, padding=dilations[1], dilation=dilations[1], norm_layer=norm_layer) 29 | self.aspp3 = _ASPPModule(inplanes, planes, 3, padding=dilations[2], dilation=dilations[2], norm_layer=norm_layer) 30 | self.aspp4 = _ASPPModule(inplanes, planes, 3, padding=dilations[3], dilation=dilations[3], norm_layer=norm_layer) 31 | 32 | self.norm1 = norm_layer(planes) 33 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 34 | nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False), 35 | norm_layer(planes), 36 | nn.ELU(True)) 37 | self.conv1 = nn.Conv2d(5 * planes, planes, 1, bias=False) 38 | self.elu = nn.ELU(True) 39 | self.dropout = nn.Dropout2d(dropout) 40 | 41 | def forward(self, x): 42 | x1 = self.aspp1(x) 43 | x2 = self.aspp2(x) 44 | x3 = self.aspp3(x) 45 | x4 = self.aspp4(x) 46 | x5 = self.global_avg_pool(x) 47 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 48 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 49 | 50 | x = self.conv1(x) 51 | x = self.norm1(x) 52 | x = self.elu(x) 53 | 54 | return self.dropout(x) 55 | -------------------------------------------------------------------------------- /src/models/vnet/aspp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class GroupNorm32(nn.GroupNorm): 7 | def __init__(self, num_channels): 8 | super(GroupNorm32, self).__init__(num_channels=num_channels, num_groups=16) 9 | 10 | 11 | class _ASPPModule(nn.Module): 12 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, norm_layer): 13 | super(_ASPPModule, self).__init__() 14 | self.norm = norm_layer(planes) 15 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 16 | stride=1, padding=padding, dilation=dilation, bias=False) 17 | self.elu = nn.ELU(True) 18 | 19 | def forward(self, x): 20 | x = self.atrous_conv(x) 21 | x = self.norm(x) 22 | return self.elu(x) 23 | 24 | 25 | class ASPP(nn.Module): 26 | def __init__(self, inplanes, planes, norm_layer=GroupNorm32, dilations=[1, 6, 12, 18], dropout=0.5): 27 | super(ASPP, self).__init__() 28 | 29 | self.aspp1 = _ASPPModule(inplanes, planes, 1, padding=0, dilation=dilations[0], norm_layer=norm_layer) 30 | self.aspp2 = _ASPPModule(inplanes, planes, 3, padding=dilations[1], dilation=dilations[1], norm_layer=norm_layer) 31 | self.aspp3 = _ASPPModule(inplanes, planes, 3, padding=dilations[2], dilation=dilations[2], norm_layer=norm_layer) 32 | self.aspp4 = _ASPPModule(inplanes, planes, 3, padding=dilations[3], dilation=dilations[3], norm_layer=norm_layer) 33 | 34 | self.norm1 = norm_layer(planes) 35 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 36 | nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False), 37 | norm_layer(planes), 38 | nn.ELU(True)) 39 | self.conv1 = nn.Conv2d(5 * planes, planes, 1, bias=False) 40 | self.elu = 
nn.ELU(True) 41 | self.dropout = nn.Dropout2d(dropout) 42 | self._init_weight() 43 | 44 | def forward(self, x): 45 | x1 = self.aspp1(x) 46 | x2 = self.aspp2(x) 47 | x3 = self.aspp3(x) 48 | x4 = self.aspp4(x) 49 | x5 = self.global_avg_pool(x) 50 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 51 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 52 | 53 | x = self.conv1(x) 54 | x = self.norm1(x) 55 | x = self.elu(x) 56 | 57 | return self.dropout(x) 58 | 59 | def _init_weight(self): 60 | for m in self.modules(): 61 | if isinstance(m, nn.Conv2d): 62 | torch.nn.init.kaiming_normal_(m.weight) 63 | elif isinstance(m, nn.BatchNorm2d): 64 | m.weight.data.fill_(1) 65 | m.bias.data.zero_() 66 | elif isinstance(m, GroupNorm32): 67 | m.weight.data.fill_(1) 68 | m.bias.data.zero_() 69 | -------------------------------------------------------------------------------- /src/preprocessing.py: -------------------------------------------------------------------------------- 1 | import click 2 | import pandas as pd 3 | import os 4 | import numpy as np 5 | from dataset import slice_builder, slice_builder_2d 6 | from sklearn.model_selection import KFold, GroupKFold 7 | 8 | 9 | @click.group() 10 | def cli(): 11 | print("Extract slices") 12 | 13 | 14 | @cli.command() 15 | @click.option('--csv_file', type=str) 16 | @click.option('--root', type=str) 17 | @click.option('--save_dir', type=str) 18 | @click.option('--slice_thichness', type=int) 19 | @click.option('--scale_ratio', type=float) 20 | @click.option('--slice', type=int) 21 | @click.option('--patch', type=int) 22 | def extract( 23 | csv_file, 24 | root, 25 | save_dir, 26 | slice_thichness, 27 | scale_ratio, 28 | slice=16, 29 | patch=32 30 | ): 31 | df = pd.read_csv(csv_file) 32 | all_patient_df = [] 33 | for imgpath, mskpath in zip(df.path, df.pathmsk): 34 | imgpath = os.path.join(root, imgpath) 35 | mskpath = os.path.join(root, mskpath) 36 | patient_df = slice_builder(imgpath, mskpath, slice_thichness, scale_ratio, slice, patch, save_dir) 37 | all_patient_df.append(patient_df) 38 | all_patient_df = pd.concat(all_patient_df, axis=0).reset_index(drop=True) 39 | all_patient_df.to_csv(os.path.join(save_dir, 'data.csv')) 40 | 41 | 42 | @cli.command() 43 | @click.option('--root', type=str) 44 | @click.option('--save_dir', type=str) 45 | def extract_2d( 46 | root, 47 | save_dir, 48 | ): 49 | # df = pd.read_csv(csv_file) 50 | all_patient_df = [] 51 | import glob 52 | paths = glob.glob(root + "/*/*data*") 53 | masks = glob.glob(root + "/*/*label*") 54 | 55 | for imgpath, mskpath in zip(paths, masks): 56 | patient_df = slice_builder_2d(imgpath, mskpath, save_dir) 57 | all_patient_df.append(patient_df) 58 | all_patient_df = pd.concat(all_patient_df, axis=0).reset_index(drop=True) 59 | all_patient_df.to_csv(os.path.join(save_dir, 'data.csv')) 60 | 61 | 62 | @cli.command() 63 | @click.option('--csv_file', type=str) 64 | @click.option('--n_folds', type=int) 65 | @click.option('--save_dir', type=str) 66 | def split_kfold( 67 | csv_file, 68 | n_folds, 69 | save_dir, 70 | ): 71 | os.makedirs(save_dir, exist_ok=True) 72 | df = pd.read_csv(csv_file) 73 | patient_ids = df['patient_id'].values 74 | kf = GroupKFold(n_splits=n_folds) 75 | for fold, (train_idx, valid_idx) in enumerate(kf.split(df, groups=patient_ids)): 76 | # train_patient = patient_ids[train_idx] 77 | # valid_patient = patient_ids[valid_idx] 78 | # train_df = df[df['patient_id'].isin(train_patient)].reset_index(drop=True) 79 | # valid_df = 
df[df['patient_id'].isin(valid_patient)].reset_index(drop=True) 80 | train_df = df.iloc[train_idx].reset_index(drop=True) 81 | valid_df = df.iloc[valid_idx].reset_index(drop=True) 82 | 83 | train_df.to_csv(os.path.join(save_dir, f'train_{fold}.csv'), index=False) 84 | valid_df.to_csv(os.path.join(save_dir, f'valid_{fold}.csv'), index=False) 85 | 86 | 87 | @cli.command() 88 | @click.option('--csv_file', type=str) 89 | @click.option('--n_folds', type=int) 90 | @click.option('--save_dir', type=str) 91 | def split_kfold_semi( 92 | csv_file, 93 | n_folds, 94 | save_dir, 95 | ): 96 | os.makedirs(save_dir, exist_ok=True) 97 | df = pd.read_csv(csv_file) 98 | all_patients = df['patient_id'].unique() 99 | unlabeled_patients = np.random.choice(all_patients, size=10, replace=False) 100 | unlabeled_df = df[df['patient_id'].isin(unlabeled_patients)] 101 | labeled_df = df[~df['patient_id'].isin(unlabeled_patients)].reset_index(drop=True) 102 | patient_ids = labeled_df['patient_id'].values 103 | kf = GroupKFold(n_splits=n_folds) 104 | for fold, (train_idx, valid_idx) in enumerate(kf.split(labeled_df, groups=patient_ids)): 105 | train_df = labeled_df.iloc[train_idx].reset_index(drop=True) 106 | valid_df = labeled_df.iloc[valid_idx].reset_index(drop=True) 107 | 108 | train_df.to_csv(os.path.join(save_dir, f'train_{fold}.csv'), index=False) 109 | valid_df.to_csv(os.path.join(save_dir, f'valid_{fold}.csv'), index=False) 110 | 111 | unlabeled_df.to_csv(os.path.join(save_dir, 'unlabeled_patients.csv'), index=False) 112 | 113 | 114 | if __name__ == '__main__': 115 | cli() 116 | -------------------------------------------------------------------------------- /src/swa.py: -------------------------------------------------------------------------------- 1 | 2 | #!/usr/bin/env python 3 | 4 | """ 5 | Stochastic Weight Averaging (SWA) 6 | Averaging Weights Leads to Wider Optima and Better Generalization 7 | https://github.com/timgaripov/swa 8 | """ 9 | import torch 10 | import models 11 | from tqdm import tqdm 12 | import glob 13 | 14 | 15 | def moving_average(net1, net2, alpha=1.): 16 | for param1, param2 in zip(net1.parameters(), net2.parameters()): 17 | param1.data *= (1.0 - alpha) 18 | param1.data += param2.data * alpha 19 | 20 | 21 | def _check_bn(module, flag): 22 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 23 | flag[0] = True 24 | 25 | 26 | def check_bn(model): 27 | flag = [False] 28 | model.apply(lambda module: _check_bn(module, flag)) 29 | return flag[0] 30 | 31 | 32 | def reset_bn(module): 33 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 34 | module.running_mean = torch.zeros_like(module.running_mean) 35 | module.running_var = torch.ones_like(module.running_var) 36 | 37 | 38 | def _get_momenta(module, momenta): 39 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 40 | momenta[module] = module.momentum 41 | 42 | 43 | def _set_momenta(module, momenta): 44 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 45 | module.momentum = momenta[module] 46 | 47 | 48 | def bn_update(loader, model): 49 | """ 50 | BatchNorm buffers update (if any). 51 | Performs 1 epochs to estimate buffers average using train dataset. 52 | :param loader: train dataset loader for buffers average estimation. 
53 | :param model: model being update 54 | :return: None 55 | """ 56 | if not check_bn(model): 57 | return 58 | model.train() 59 | momenta = {} 60 | model.apply(reset_bn) 61 | model.apply(lambda module: _get_momenta(module, momenta)) 62 | n = 0 63 | 64 | pbar = tqdm(loader, unit="images", unit_scale=loader.batch_size) 65 | for batch in pbar: 66 | input, targets = batch['images'], batch['targets'] 67 | input = input.cuda() 68 | b = input.size(0) 69 | 70 | momentum = b / (n + b) 71 | for module in momenta.keys(): 72 | module.momentum = momentum 73 | 74 | model(input) 75 | n += b 76 | 77 | model.apply(lambda module: _set_momenta(module, momenta)) 78 | 79 | 80 | if __name__ == '__main__': 81 | import argparse 82 | from pathlib import Path 83 | from torchvision.transforms import Compose 84 | from torch.utils.data import DataLoader 85 | from augmentation import valid_aug 86 | from dataset import SIIMDataset 87 | 88 | parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter) 89 | parser.add_argument("--input", type=str, help='input directory') 90 | parser.add_argument("--output", type=str, default='swa_model.pth', help='output model file') 91 | parser.add_argument("--batch-size", type=int, default=16, help='batch size') 92 | args = parser.parse_args() 93 | 94 | # directory = Path(args.input) 95 | # files = [f for f in directory.iterdir() if f.suffix == ".pth"] 96 | files = glob.glob(args.input + "/stage1/checkpoints/stage1.*.pth") 97 | files += glob.glob(args.input + "/stage2/checkpoints/stage1.*.pth") 98 | assert(len(files) > 1) 99 | 100 | net = models.Unet( 101 | encoder_name="resnet34", 102 | activation='sigmoid', 103 | classes=1, 104 | # center=True 105 | ) 106 | checkpoint = torch.load(files[0]) 107 | net.load_state_dict(checkpoint['model_state_dict']) 108 | 109 | for i, f in enumerate(files[1:]): 110 | # net2 = model.load(f) 111 | net2 = models.Unet( 112 | encoder_name="resnet34", 113 | activation='sigmoid', 114 | classes=1, 115 | # center=True 116 | ) 117 | checkpoint = torch.load(f) 118 | net2.load_state_dict(checkpoint['model_state_dict']) 119 | moving_average(net, net2, 1. 
/ (i + 2)) 120 | 121 | test_csv = './csv/train_0.csv' 122 | root = "/raid/data/kaggle/siim/siim256/" 123 | # img_size = 128 124 | batch_size = 16 125 | train_transform = valid_aug() 126 | train_dataset = SIIMDataset( 127 | csv_file=test_csv, 128 | root=root, 129 | transform=train_transform, 130 | mode='train' 131 | ) 132 | train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, drop_last=True) 133 | net.cuda() 134 | bn_update(train_dataloader, net) 135 | 136 | # models.save(net, args.output) 137 | torch.save({ 138 | 'model_state_dict': net.state_dict() 139 | }, args.output) 140 | -------------------------------------------------------------------------------- /src/prepare_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | 获取固定取样方式下的训练数据 3 | 首先将灰度值超过upper和低于lower的灰度进行截断 4 | 然后调整slice thickness,然后将slice的分辨率调整为256*256 5 | 只有包含肝脏以及肝脏上下 expand_slice 张slice作为训练样本 6 | 最后将输入数据分块,以轴向 stride 张slice为步长进行取样 7 | 网络输入为256*256*size 8 | 当前脚本依然对金标准进行了缩小,如果要改变,直接修改第70行就行 9 | """ 10 | 11 | import os 12 | import shutil 13 | from time import time 14 | 15 | import numpy as np 16 | import SimpleITK as sitk 17 | import matplotlib.pyplot as plt 18 | import scipy.ndimage as ndimage 19 | 20 | 21 | upper = 400 22 | lower = -1000 23 | expand_slice = 20 # 轴向上向外扩张的slice数量 24 | size = 48 # 取样的slice数量 25 | stride = 3 # 取样的步长 26 | down_scale = 0.5 27 | slice_thickness = 2 28 | 29 | ct_dir = '../Lung_GTV/' 30 | # seg_dir = '/home/zcy/Desktop/dataset/MICCAI-LITS-2017/train/seg/' 31 | 32 | new_ct_dir = f'../Lung_GTV_sz{size}_st{stride}_ds{down_scale}_st{slice_thickness}' 33 | # new_seg_dir = '/home/zcy/Desktop/train/seg/' 34 | 35 | # if os.path.exists('/home/zcy/Desktop/train/'): 36 | # shutil.rmtree('/home/zcy/Desktop/train/') 37 | 38 | # os.mkdir('/home/zcy/Desktop/train/') 39 | os.makedirs(new_ct_dir, exist_ok=True) 40 | # os.mkdir(new_seg_dir) 41 | 42 | 43 | # 用来记录产生的数据的序号 44 | file_index = 0 45 | 46 | # 用来统计最终剩下的slice数量 47 | left_slice_list = [] 48 | 49 | start_time = time() 50 | for ct_file in os.listdir(ct_dir): 51 | 52 | # 将CT和金标准入读内存 53 | ct = sitk.ReadImage(os.path.join(ct_dir, ct_file, 'data.nii.gz'), sitk.sitkInt16) 54 | ct_array = sitk.GetArrayFromImage(ct) 55 | 56 | seg = sitk.ReadImage(os.path.join(ct_dir, ct_file, 'label.nii.gz'), sitk.sitkInt8) 57 | seg_array = sitk.GetArrayFromImage(seg) 58 | 59 | # 将金标准中肝脏和肝肿瘤的标签融合为一个 60 | seg_array[seg_array > 0] = 1 61 | 62 | # 将灰度值在阈值之外的截断掉 63 | ct_array[ct_array > upper] = upper 64 | ct_array[ct_array < lower] = lower 65 | 66 | # 对CT和金标准进行插值,插值之后的array依然是int类型 67 | ct_array = ndimage.zoom(ct_array, (ct.GetSpacing()[-1] / slice_thickness, down_scale, down_scale), order=3) 68 | seg_array = ndimage.zoom(seg_array, (ct.GetSpacing()[-1] / slice_thickness, down_scale, down_scale), order=0) 69 | 70 | # 找到肝脏区域开始和结束的slice,并各向外扩张 71 | z = np.any(seg_array, axis=(1, 2)) 72 | start_slice, end_slice = np.where(z)[0][[0, -1]] 73 | 74 | # 两个方向上各扩张个slice 75 | if start_slice - expand_slice < 0: 76 | start_slice = 0 77 | else: 78 | start_slice -= expand_slice 79 | 80 | if end_slice + expand_slice >= seg_array.shape[0]: 81 | end_slice = seg_array.shape[0] - 1 82 | else: 83 | end_slice += expand_slice 84 | 85 | # 如果这时候剩下的slice数量不足size,直接放弃,这样的数据很少 86 | if end_slice - start_slice + 1 < size: 87 | print('!!!!!!!!!!!!!!!!') 88 | print(ct_file, 'too little slice') 89 | print('!!!!!!!!!!!!!!!!') 90 | continue 91 | 92 | ct_array = ct_array[start_slice:end_slice + 1, :, :] 93 | seg_array = seg_array[start_slice:end_slice 
+ 1, :, :] 94 | 95 | print('{} have {} slice left'.format(ct_file, ct_array.shape[0])) 96 | left_slice_list.append(ct_array.shape[0]) 97 | 98 | # 在轴向上按照一定的步长进行切块取样,并将结果保存为nii数据 99 | start_slice = 0 100 | end_slice = start_slice + size - 1 101 | 102 | while end_slice <= ct_array.shape[0] - 1: 103 | 104 | new_ct_array = ct_array[start_slice:end_slice + 1, :, :] 105 | new_seg_array = seg_array[start_slice:end_slice + 1, :, :] 106 | 107 | new_ct = sitk.GetImageFromArray(new_ct_array) 108 | 109 | new_ct.SetDirection(ct.GetDirection()) 110 | new_ct.SetOrigin(ct.GetOrigin()) 111 | new_ct.SetSpacing((ct.GetSpacing()[0] * int(1 / down_scale), ct.GetSpacing()[1] * int(1 / down_scale), slice_thickness)) 112 | 113 | new_seg = sitk.GetImageFromArray(new_seg_array) 114 | 115 | new_seg.SetDirection(ct.GetDirection()) 116 | new_seg.SetOrigin(ct.GetOrigin()) 117 | new_seg.SetSpacing((ct.GetSpacing()[0] * int(1 / down_scale), ct.GetSpacing()[1] * int(1 / down_scale), slice_thickness)) 118 | 119 | new_ct_name = 'data-' + str(file_index) + '.nii.gz' 120 | new_seg_name = 'label-' + str(file_index) + '.nii.gz' 121 | 122 | os.makedirs(os.path.join(new_ct_dir, ct_file), exist_ok=True) 123 | 124 | sitk.WriteImage(new_ct, os.path.join(new_ct_dir, ct_file, new_ct_name)) 125 | sitk.WriteImage(new_seg, os.path.join(new_ct_dir, ct_file, new_seg_name)) 126 | 127 | file_index += 1 128 | 129 | start_slice += stride 130 | end_slice = start_slice + size - 1 131 | 132 | # 当无法整除的时候反向取最后一个block 133 | if end_slice is not ct_array.shape[0] - 1: 134 | new_ct_array = ct_array[-size:, :, :] 135 | new_seg_array = seg_array[-size:, :, :] 136 | 137 | new_ct = sitk.GetImageFromArray(new_ct_array) 138 | 139 | new_ct.SetDirection(ct.GetDirection()) 140 | new_ct.SetOrigin(ct.GetOrigin()) 141 | new_ct.SetSpacing((ct.GetSpacing()[0] * int(1 / down_scale), ct.GetSpacing()[1] * int(1 / down_scale), slice_thickness)) 142 | 143 | new_seg = sitk.GetImageFromArray(new_seg_array) 144 | 145 | new_seg.SetDirection(ct.GetDirection()) 146 | new_seg.SetOrigin(ct.GetOrigin()) 147 | new_seg.SetSpacing((ct.GetSpacing()[0], ct.GetSpacing()[1], slice_thickness)) 148 | 149 | new_ct_name = 'data-' + str(file_index) + '.nii.gz' 150 | new_seg_name = 'label-' + str(file_index) + '.nii.gz' 151 | 152 | os.makedirs(os.path.join(new_ct_dir, ct_file), exist_ok=True) 153 | 154 | sitk.WriteImage(new_ct, os.path.join(new_ct_dir, ct_file, new_ct_name)) 155 | sitk.WriteImage(new_seg, os.path.join(new_ct_dir, ct_file, new_seg_name)) 156 | 157 | file_index += 1 158 | 159 | # 每处理完一个数据,打印一次已经使用的时间 160 | print('already use {:.3f} min'.format((time() - start_time) / 60)) 161 | print('-----------') 162 | 163 | 164 | left_slice_list = np.array(left_slice_list) 165 | 166 | plt.hist(left_slice_list, 200, rwidth=1) 167 | plt.show() -------------------------------------------------------------------------------- /src/Segmentation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import glob 3 | import os 4 | import cv2 5 | import pandas as pd 6 | import SimpleITK 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.utils.data import Dataset, DataLoader 11 | from tqdm import tqdm 12 | import SimpleITK 13 | from augmentation import valid_aug 14 | from segmentation_models_pytorch.unet import Unet 15 | from segmentation_models_pytorch import FPN 16 | from models import VNet 17 | 18 | 19 | import scipy.ndimage as ndimage 20 | 21 | 22 | UPPER_BOUND = 400 23 | LOWER_BOUND = -1000 24 | 25 
| 26 | device = torch.device('cuda') 27 | 28 | 29 | def predict(model, loader): 30 | model.eval() 31 | preds = [] 32 | pred_logits = [] 33 | with torch.no_grad(): 34 | for dct in loader: 35 | images = dct['images'].to(device) 36 | pred = model(images) 37 | pred_sofmax = F.softmax(pred, dim=1) 38 | pred_sofmax = pred_sofmax.detach().cpu().numpy() 39 | pred = pred.detach().cpu().numpy() 40 | preds.append(pred_sofmax) 41 | pred_logits.append(pred) 42 | 43 | preds = np.concatenate(preds, axis=0) 44 | pred_logits = np.concatenate(pred_logits, axis=0) 45 | return preds, pred_logits 46 | 47 | 48 | class TestDataset(Dataset): 49 | def __init__(self, image_slices, transform): 50 | self.image_slices = image_slices 51 | self.transform = transform 52 | 53 | def __len__(self): 54 | return len(self.image_slices) 55 | 56 | def __getitem__(self, idx): 57 | image = self.image_slices[idx] 58 | image = np.stack((image, image, image), axis=-1).astype(np.float32) 59 | 60 | if self.transform: 61 | transform = self.transform(image=image) 62 | image = transform['image'] 63 | 64 | image = np.transpose(image, (2, 0, 1)) 65 | 66 | return { 67 | 'images': image 68 | } 69 | 70 | 71 | class TestDatasetNB(Dataset): 72 | def __init__(self, image_slices, transform): 73 | self.image_slices = image_slices 74 | self.transform = transform 75 | 76 | def __len__(self): 77 | return len(self.image_slices) 78 | 79 | def __getitem__(self, idx): 80 | image = self.image_slices[idx] 81 | if idx == 0: 82 | image_prev = image 83 | else: 84 | image_prev = self.image_slices[idx-1] 85 | 86 | if idx == len(self.image_slices) - 1: 87 | image_next = image 88 | else: 89 | image_next = self.image_slices[idx + 1] 90 | 91 | image = np.stack((image_prev, image, image_next), axis=-1).astype(np.float32) 92 | 93 | if self.transform: 94 | transform = self.transform(image=image) 95 | image = transform['image'] 96 | 97 | image = np.transpose(image, (2, 0, 1)) 98 | 99 | return { 100 | 'images': image 101 | } 102 | 103 | 104 | def extract_slice(file): 105 | ct_image = SimpleITK.ReadImage(file) 106 | image = SimpleITK.GetArrayFromImage(ct_image).astype(np.float32) 107 | 108 | image = (image - LOWER_BOUND) / (UPPER_BOUND - LOWER_BOUND) 109 | image[image > 1] = 1. 110 | image[image < 0] = 0. 
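# At this point the CT volume has been windowed to [LOWER_BOUND, UPPER_BOUND] HU
# (i.e. [-1000, 400]) and linearly rescaled to [0, 1] before being split into 2D slices.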
111 | image = image.astype(np.float32) 112 | 113 | image_slices = [] 114 | for i, image_slice in enumerate(image): 115 | image_slices.append(image_slice) 116 | 117 | return image_slices, ct_image 118 | 119 | 120 | def predict_valid(): 121 | inputdir = "/data/Thoracic_OAR/" 122 | 123 | transform = valid_aug(image_size=512) 124 | 125 | # nii_files = glob.glob(inputdir + "/*/data.nii.gz") 126 | 127 | folds = [0, 1, 2, 3, 4] 128 | 129 | for fold in folds: 130 | print(fold) 131 | outdir = f"/data/Thoracic_OAR_predict/FPN-seresnext50/" 132 | log_dir = f"/logs/ss_miccai/FPN-se_resnext50_32x4d-fold-{fold}" 133 | # model = VNet( 134 | # encoder_name='se_resnext50_32x4d', 135 | # encoder_weights=None, 136 | # classes=7, 137 | # # activation='sigmoid', 138 | # group_norm=False, 139 | # center='none', 140 | # attention_type='scse', 141 | # reslink=True, 142 | # multi_task=False 143 | # ) 144 | 145 | model = FPN( 146 | encoder_name='se_resnext50_32x4d', 147 | encoder_weights=None, 148 | classes=7 149 | ) 150 | 151 | ckp = os.path.join(log_dir, "checkpoints/best.pth") 152 | checkpoint = torch.load(ckp) 153 | model.load_state_dict(checkpoint['model_state_dict']) 154 | model = nn.DataParallel(model) 155 | model = model.to(device) 156 | 157 | df = pd.read_csv(f'./csv/5folds/valid_{fold}.csv') 158 | patient_ids = df.patient_id.unique() 159 | for patient_id in patient_ids: 160 | print(patient_id) 161 | nii_file = f"{inputdir}/{patient_id}/data.nii.gz" 162 | 163 | image_slices, ct_image = extract_slice(nii_file) 164 | dataset = TestDataset(image_slices, transform) 165 | dataloader = DataLoader( 166 | dataset=dataset, 167 | num_workers=4, 168 | batch_size=8, 169 | drop_last=False 170 | ) 171 | 172 | pred_mask, pred_logits = predict(model, dataloader) 173 | # import pdb 174 | # pdb.set_trace() 175 | pred_mask = np.argmax(pred_mask, axis=1).astype(np.uint8) 176 | pred_mask = SimpleITK.GetImageFromArray(pred_mask) 177 | 178 | pred_mask.SetDirection(ct_image.GetDirection()) 179 | pred_mask.SetOrigin(ct_image.GetOrigin()) 180 | pred_mask.SetSpacing(ct_image.GetSpacing()) 181 | 182 | # patient_id = nii_file.split("/")[-2] 183 | patient_dir = f"{outdir}/{patient_id}" 184 | os.makedirs(patient_dir, exist_ok=True) 185 | patient_pred = f"{patient_dir}/predict.nii.gz" 186 | SimpleITK.WriteImage( 187 | pred_mask, patient_pred 188 | ) 189 | # np.save(f"{patient_dir}/predic_logits.npy", pred_logits) 190 | 191 | 192 | 193 | if __name__ == '__main__': 194 | predict_valid() 195 | -------------------------------------------------------------------------------- /src/Segmentation3d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import glob 3 | import os 4 | import cv2 5 | import pandas as pd 6 | import SimpleITK 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.utils.data import Dataset, DataLoader 11 | from tqdm import tqdm 12 | import SimpleITK 13 | from augmentation import valid_aug 14 | from segmentation_models_pytorch.unet import Unet 15 | from models import ResidualUNet3D 16 | 17 | import scipy.ndimage as ndimage 18 | 19 | 20 | UPPER_BOUND = 400 21 | LOWER_BOUND = -1000 22 | 23 | 24 | device = torch.device('cuda') 25 | 26 | 27 | def predict(model, loader): 28 | model.eval() 29 | preds = [] 30 | pred_logits = [] 31 | with torch.no_grad(): 32 | for dct in loader: 33 | images = dct['images'].to(device) 34 | pred = model(images) 35 | pred_sofmax = F.softmax(pred, dim=1) 36 | pred_sofmax = 
pred_sofmax.detach().cpu().numpy() 37 | pred = pred.detach().cpu().numpy() 38 | preds.append(pred_sofmax) 39 | pred_logits.append(pred) 40 | 41 | preds = np.concatenate(preds, axis=0) 42 | pred_logits = np.concatenate(pred_logits, axis=0) 43 | return preds, pred_logits 44 | 45 | 46 | def load_ct_images(path): 47 | image = SimpleITK.ReadImage(path) 48 | spacing = image.GetSpacing()[-1] 49 | image_arr = SimpleITK.GetArrayFromImage(image).astype(np.float32) 50 | return image_arr, image 51 | 52 | 53 | 54 | class TestDataset(Dataset): 55 | def __init__(self, image_slices, transform): 56 | self.image_slices = image_slices 57 | self.transform = transform 58 | 59 | def __len__(self): 60 | return len(self.image_slices) 61 | 62 | def __getitem__(self, idx): 63 | image = self.image_slices[idx] 64 | image = np.stack((image, image, image), axis=-1).astype(np.float32) 65 | 66 | if self.transform: 67 | transform = self.transform(image=image) 68 | image = transform['image'] 69 | 70 | image = np.transpose(image, (2, 0, 1)) 71 | 72 | return { 73 | 'images': image 74 | } 75 | 76 | 77 | def predict_valid(): 78 | inputdir = "/data/Thoracic_OAR/" 79 | 80 | transform = valid_aug(image_size=512) 81 | 82 | # nii_files = glob.glob(inputdir + "/*/data.nii.gz") 83 | 84 | folds = [0] 85 | 86 | crop_size = (32, 256, 256) 87 | xstep = 1 88 | ystep = 256 89 | zstep = 256 90 | num_classes = 7 91 | 92 | for fold in folds: 93 | print(fold) 94 | outdir = f"/data/Thoracic_OAR_predict/Unet3D/" 95 | log_dir = f"/logs/ss_miccai/Unet3D-fold-{fold}" 96 | model = ResidualUNet3D( 97 | in_channels=1, 98 | out_channels=num_classes 99 | ) 100 | 101 | ckp = os.path.join(log_dir, "checkpoints/best.pth") 102 | checkpoint = torch.load(ckp) 103 | model.load_state_dict(checkpoint['model_state_dict']) 104 | model = nn.DataParallel(model) 105 | model = model.to(device) 106 | 107 | df = pd.read_csv(f'./csv/5folds/valid_{fold}.csv') 108 | patient_ids = df.patient_id.unique() 109 | for patient_id in patient_ids: 110 | print(patient_id) 111 | nii_file = f"{inputdir}/{patient_id}/data.nii.gz" 112 | 113 | image, ct_image = load_ct_images(nii_file) 114 | 115 | image = (image - LOWER_BOUND) / (UPPER_BOUND - LOWER_BOUND) 116 | image[image > 1] = 1. 117 | image[image < 0] = 0. 
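# Same HU windowing as the 2D pipeline. Below, the normalized volume is tiled into
# (32, 256, 256) crops, stepping one slice at a time along the depth axis; per-crop
# softmax outputs are accumulated into `whole_pred` and divided by `count_used` to
# average overlapping predictions.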
118 | image = image.astype(np.float32) 119 | C, H, W = image.shape 120 | 121 | deep_slices = np.arange(0, C - crop_size[0] + xstep, xstep) 122 | height_slices = np.arange(0, H - crop_size[1] + ystep, ystep) 123 | width_slices = np.arange(0, W - crop_size[2] + zstep, zstep) 124 | 125 | whole_pred = np.zeros((num_classes, C, H, W)) 126 | count_used = np.zeros((C, H, W)) + 1e-5 127 | 128 | # no update parameter gradients during testing 129 | with torch.no_grad(): 130 | for i in tqdm(range(len(deep_slices))): 131 | for j in range(len(height_slices)): 132 | for k in range(len(width_slices)): 133 | deep = deep_slices[i] 134 | height = height_slices[j] 135 | width = width_slices[k] 136 | image_crop = image[deep: deep + crop_size[0], 137 | height: height + crop_size[1], 138 | width: width + crop_size[2]] 139 | image_crop = np.expand_dims(image_crop, axis=0) 140 | image_crop = np.expand_dims(image_crop, axis=0) 141 | image_crop = torch.from_numpy(image_crop).to(device) 142 | # import pdb 143 | # pdb.set_trace() 144 | outputs = model(image_crop) 145 | outputs = F.softmax(outputs, dim=1) 146 | # ----------------Average------------------------------- 147 | whole_pred[:, 148 | deep: deep + crop_size[0], 149 | height: height + crop_size[1], 150 | width: width + crop_size[2] 151 | ] += outputs.data.cpu().numpy()[0] 152 | 153 | count_used[deep: deep + crop_size[0], 154 | height: height + crop_size[1], 155 | width: width + crop_size[2]] += 1 156 | 157 | whole_pred = whole_pred / count_used 158 | pred_mask = np.argmax(whole_pred, axis=0).astype(np.uint8) 159 | 160 | # pred_mask, pred_logits = predict(model, dataloader) 161 | # # import pdb 162 | # # pdb.set_trace() 163 | pred_mask = SimpleITK.GetImageFromArray(pred_mask) 164 | 165 | pred_mask.SetDirection(ct_image.GetDirection()) 166 | pred_mask.SetOrigin(ct_image.GetOrigin()) 167 | pred_mask.SetSpacing(ct_image.GetSpacing()) 168 | 169 | # patient_id = nii_file.split("/")[-2] 170 | patient_dir = f"{outdir}/{patient_id}" 171 | os.makedirs(patient_dir, exist_ok=True) 172 | patient_pred = f"{patient_dir}/predict.nii.gz" 173 | SimpleITK.WriteImage( 174 | pred_mask, patient_pred 175 | ) 176 | # np.save(f"{patient_dir}/predic_logits.npy", pred_logits) 177 | 178 | 179 | 180 | if __name__ == '__main__': 181 | predict_valid() 182 | -------------------------------------------------------------------------------- /src/models/DialResUnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | 基础网络脚本 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class DialResUNet(nn.Module): 11 | """ 12 | 共9498260个可训练的参数, 接近九百五十万 13 | """ 14 | def __init__(self, training): 15 | super().__init__() 16 | 17 | self.training = training 18 | 19 | self.encoder_stage1 = nn.Sequential( 20 | nn.Conv3d(1, 16, 3, 1, padding=1), 21 | nn.PReLU(16), 22 | 23 | nn.Conv3d(16, 16, 3, 1, padding=1), 24 | nn.PReLU(16), 25 | ) 26 | 27 | self.encoder_stage2 = nn.Sequential( 28 | nn.Conv3d(32, 32, 3, 1, padding=1), 29 | nn.PReLU(32), 30 | 31 | nn.Conv3d(32, 32, 3, 1, padding=1), 32 | nn.PReLU(32), 33 | 34 | nn.Conv3d(32, 32, 3, 1, padding=1), 35 | nn.PReLU(32), 36 | ) 37 | 38 | self.encoder_stage3 = nn.Sequential( 39 | nn.Conv3d(64, 64, 3, 1, padding=1), 40 | nn.PReLU(64), 41 | 42 | nn.Conv3d(64, 64, 3, 1, padding=2, dilation=2), 43 | nn.PReLU(64), 44 | 45 | nn.Conv3d(64, 64, 3, 1, padding=4, dilation=4), 46 | nn.PReLU(64), 47 | ) 48 | 49 | self.encoder_stage4 = nn.Sequential( 50 | nn.Conv3d(128, 128, 3, 1, 
padding=3, dilation=3), 51 | nn.PReLU(128), 52 | 53 | nn.Conv3d(128, 128, 3, 1, padding=4, dilation=4), 54 | nn.PReLU(128), 55 | 56 | nn.Conv3d(128, 128, 3, 1, padding=5, dilation=5), 57 | nn.PReLU(128), 58 | ) 59 | 60 | self.decoder_stage1 = nn.Sequential( 61 | nn.Conv3d(128, 256, 3, 1, padding=1), 62 | nn.PReLU(256), 63 | 64 | nn.Conv3d(256, 256, 3, 1, padding=1), 65 | nn.PReLU(256), 66 | 67 | nn.Conv3d(256, 256, 3, 1, padding=1), 68 | nn.PReLU(256), 69 | ) 70 | 71 | self.decoder_stage2 = nn.Sequential( 72 | nn.Conv3d(128 + 64, 128, 3, 1, padding=1), 73 | nn.PReLU(128), 74 | 75 | nn.Conv3d(128, 128, 3, 1, padding=1), 76 | nn.PReLU(128), 77 | 78 | nn.Conv3d(128, 128, 3, 1, padding=1), 79 | nn.PReLU(128), 80 | ) 81 | 82 | self.decoder_stage3 = nn.Sequential( 83 | nn.Conv3d(64 + 32, 64, 3, 1, padding=1), 84 | nn.PReLU(64), 85 | 86 | nn.Conv3d(64, 64, 3, 1, padding=1), 87 | nn.PReLU(64), 88 | 89 | nn.Conv3d(64, 64, 3, 1, padding=1), 90 | nn.PReLU(64), 91 | ) 92 | 93 | self.decoder_stage4 = nn.Sequential( 94 | nn.Conv3d(32 + 16, 32, 3, 1, padding=1), 95 | nn.PReLU(32), 96 | 97 | nn.Conv3d(32, 32, 3, 1, padding=1), 98 | nn.PReLU(32), 99 | ) 100 | 101 | self.down_conv1 = nn.Sequential( 102 | nn.Conv3d(16, 32, 2, 2), 103 | nn.PReLU(32) 104 | ) 105 | 106 | self.down_conv2 = nn.Sequential( 107 | nn.Conv3d(32, 64, 2, 2), 108 | nn.PReLU(64) 109 | ) 110 | 111 | self.down_conv3 = nn.Sequential( 112 | nn.Conv3d(64, 128, 2, 2), 113 | nn.PReLU(128) 114 | ) 115 | 116 | self.down_conv4 = nn.Sequential( 117 | nn.Conv3d(128, 256, 3, 1, padding=1), 118 | nn.PReLU(256) 119 | ) 120 | 121 | self.up_conv2 = nn.Sequential( 122 | nn.ConvTranspose3d(256, 128, 2, 2), 123 | nn.PReLU(128) 124 | ) 125 | 126 | self.up_conv3 = nn.Sequential( 127 | nn.ConvTranspose3d(128, 64, 2, 2), 128 | nn.PReLU(64) 129 | ) 130 | 131 | self.up_conv4 = nn.Sequential( 132 | nn.ConvTranspose3d(64, 32, 2, 2), 133 | nn.PReLU(32) 134 | ) 135 | 136 | # 最后大尺度下的映射(256*256),下面的尺度依次递减 137 | self.map4 = nn.Sequential( 138 | nn.Conv3d(32, 1, 1, 1), 139 | nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear'), 140 | nn.Sigmoid() 141 | ) 142 | 143 | # 128*128 尺度下的映射 144 | self.map3 = nn.Sequential( 145 | nn.Conv3d(64, 1, 1, 1), 146 | nn.Upsample(scale_factor=(2, 4, 4), mode='trilinear'), 147 | nn.Sigmoid() 148 | ) 149 | 150 | # 64*64 尺度下的映射 151 | self.map2 = nn.Sequential( 152 | nn.Conv3d(128, 1, 1, 1), 153 | nn.Upsample(scale_factor=(4, 8, 8), mode='trilinear'), 154 | nn.Sigmoid() 155 | ) 156 | 157 | # 32*32 尺度下的映射 158 | self.map1 = nn.Sequential( 159 | nn.Conv3d(256, 1, 1, 1), 160 | nn.Upsample(scale_factor=(8, 16, 16), mode='trilinear'), 161 | nn.Sigmoid() 162 | ) 163 | 164 | def forward(self, inputs): 165 | 166 | long_range1 = self.encoder_stage1(inputs) + inputs 167 | 168 | short_range1 = self.down_conv1(long_range1) 169 | 170 | long_range2 = self.encoder_stage2(short_range1) + short_range1 171 | long_range2 = F.dropout(long_range2, 0.3, self.training) 172 | 173 | short_range2 = self.down_conv2(long_range2) 174 | 175 | long_range3 = self.encoder_stage3(short_range2) + short_range2 176 | long_range3 = F.dropout(long_range3, 0.3, self.training) 177 | 178 | short_range3 = self.down_conv3(long_range3) 179 | 180 | long_range4 = self.encoder_stage4(short_range3) + short_range3 181 | long_range4 = F.dropout(long_range4, 0.3, self.training) 182 | 183 | short_range4 = self.down_conv4(long_range4) 184 | 185 | outputs = self.decoder_stage1(long_range4) + short_range4 186 | outputs = F.dropout(outputs, 0.3, self.training) 187 | 188 | output1 = 
self.map1(outputs) 189 | 190 | short_range6 = self.up_conv2(outputs) 191 | 192 | outputs = self.decoder_stage2(torch.cat([short_range6, long_range3], dim=1)) + short_range6 193 | outputs = F.dropout(outputs, 0.3, self.training) 194 | 195 | output2 = self.map2(outputs) 196 | 197 | short_range7 = self.up_conv3(outputs) 198 | 199 | outputs = self.decoder_stage3(torch.cat([short_range7, long_range2], dim=1)) + short_range7 200 | outputs = F.dropout(outputs, 0.3, self.training) 201 | 202 | output3 = self.map3(outputs) 203 | 204 | short_range8 = self.up_conv4(outputs) 205 | 206 | outputs = self.decoder_stage4(torch.cat([short_range8, long_range1], dim=1)) + short_range8 207 | 208 | output4 = self.map4(outputs) 209 | 210 | if self.training is True: 211 | return output1, output2, output3, output4 212 | else: 213 | return output4 214 | 215 | 216 | def init(module): 217 | if isinstance(module, nn.Conv3d) or isinstance(module, nn.ConvTranspose3d): 218 | nn.init.kaiming_normal(module.weight.data, 0.25) 219 | nn.init.constant(module.bias.data, 0) -------------------------------------------------------------------------------- /src/models/ModelDeepLab/deeplab.py: -------------------------------------------------------------------------------- 1 | # From https://github.com/jfzhang95/pytorch-deeplab-xception 2 | 3 | import models.ModelDeepLab.backbones_deeplab as backbones 4 | 5 | import torch 6 | 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | from torch.nn.modules.batchnorm import _BatchNorm 11 | 12 | class GroupNorm32(nn.GroupNorm): 13 | def __init__(self, num_channels): 14 | super(GroupNorm32, self).__init__(num_channels=num_channels, num_groups=32) 15 | 16 | class _ASPPModule(nn.Module): 17 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, norm_layer): 18 | super(_ASPPModule, self).__init__() 19 | self.norm = norm_layer(planes) 20 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 21 | stride=1, padding=padding, dilation=dilation, bias=False) 22 | self.elu = nn.ELU(True) 23 | 24 | def forward(self, x): 25 | x = self.atrous_conv(x) 26 | x = self.norm(x) 27 | return self.elu(x) 28 | 29 | class ASPP(nn.Module): 30 | def __init__(self, dilations, inplanes, planes, norm_layer, dropout=0.5): 31 | super(ASPP, self).__init__() 32 | 33 | self.aspp1 = _ASPPModule(inplanes, planes, 1, padding=0, dilation=dilations[0], norm_layer=norm_layer) 34 | self.aspp2 = _ASPPModule(inplanes, planes, 3, padding=dilations[1], dilation=dilations[1], norm_layer=norm_layer) 35 | self.aspp3 = _ASPPModule(inplanes, planes, 3, padding=dilations[2], dilation=dilations[2], norm_layer=norm_layer) 36 | self.aspp4 = _ASPPModule(inplanes, planes, 3, padding=dilations[3], dilation=dilations[3], norm_layer=norm_layer) 37 | 38 | self.norm1 = norm_layer(planes) 39 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 40 | nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False), 41 | norm_layer(planes), 42 | nn.ELU(True)) 43 | self.conv1 = nn.Conv2d(5 * planes, planes, 1, bias=False) 44 | self.elu = nn.ELU(True) 45 | self.dropout = nn.Dropout2d(dropout) 46 | 47 | def forward(self, x): 48 | x1 = self.aspp1(x) 49 | x2 = self.aspp2(x) 50 | x3 = self.aspp3(x) 51 | x4 = self.aspp4(x) 52 | x5 = self.global_avg_pool(x) 53 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 54 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 55 | 56 | x = self.conv1(x) 57 | x = self.norm1(x) 58 | x = self.elu(x) 59 | 60 | return self.dropout(x) 61 
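# A minimal usage sketch for the ASPP head above (the tensor shapes are illustrative
# assumptions, not values taken from the training configs): with output_stride=16 the
# backbone emits a stride-16 feature map, and ASPP fuses four parallel atrous branches
# plus image-level pooling into `planes` output channels at the same spatial size.
#
#   aspp = ASPP(dilations=(1, 6, 12, 18), inplanes=2048, planes=256,
#               norm_layer=GroupNorm32, dropout=0.5)
#   feats = torch.randn(2, 2048, 32, 32)   # e.g. a 512x512 input at stride 16
#   out = aspp(feats)                      # -> torch.Size([2, 256, 32, 32])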
| 62 | class Decoder(nn.Module): 63 | def __init__(self, num_classes, spp_inplanes, low_level_inplanes, inplanes, dropout, norm_layer): 64 | super(Decoder, self).__init__() 65 | 66 | self.conv1 = nn.Conv2d(low_level_inplanes, inplanes, 1, bias=False) 67 | self.norm1 = norm_layer(inplanes) 68 | 69 | self.elu = nn.ELU(True) 70 | self.last_conv = nn.Sequential(nn.Conv2d(spp_inplanes + inplanes, spp_inplanes, kernel_size=3, stride=1, padding=1, bias=False), 71 | norm_layer(spp_inplanes), 72 | nn.ELU(True), 73 | nn.Dropout2d(dropout[0]), 74 | nn.Conv2d(spp_inplanes, spp_inplanes, kernel_size=3, stride=1, padding=1, bias=False), 75 | norm_layer(spp_inplanes), 76 | nn.ELU(True), 77 | nn.Dropout2d(dropout[1]), 78 | nn.Conv2d(spp_inplanes, num_classes, kernel_size=1, stride=1)) 79 | 80 | def forward(self, x, low_level_feat, classifier=False): 81 | low_level_feat = self.conv1(low_level_feat) 82 | low_level_feat = self.norm1(low_level_feat) 83 | low_level_feat = self.elu(low_level_feat) 84 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=False) 85 | x = torch.cat((x, low_level_feat), dim=1) 86 | decoder_output = x 87 | x = self.last_conv(x) 88 | if classifier: 89 | return x, decoder_output 90 | else: 91 | return x 92 | 93 | 94 | class DeepLab(nn.Module): 95 | def __init__(self, backbone, 96 | output_stride=16, 97 | group_norm=True, 98 | classifier=False, 99 | dropout=dict( 100 | spp=0.5, 101 | cls=0.2, 102 | dc0=0.5, 103 | dc1=0.1 104 | ), 105 | num_classes=2, 106 | norm_eval=False): 107 | super(DeepLab, self).__init__() 108 | 109 | layers, channels, backbone = getattr(backbones, backbone)(output_stride=output_stride) 110 | 111 | self.input_range = backbone.input_range 112 | self.mean = backbone.mean 113 | self.std = backbone.std 114 | 115 | self.classifier = classifier 116 | 117 | # default is freeze BatchNorm 118 | self.norm_eval = norm_eval 119 | 120 | norm_layer = GroupNorm32 if group_norm else nn.BatchNorm2d 121 | 122 | self.backbone = layers[0] 123 | self.low_level = layers[1] 124 | 125 | self.aspp_planes = 256 126 | 127 | if output_stride == 16: 128 | aspp_dilations = (1, 6, 12, 18) 129 | elif output_stride == 8: 130 | aspp_dilations = (1, 12, 24, 36) 131 | 132 | self.spp = ASPP(aspp_dilations, inplanes=channels[1], planes=self.aspp_planes, dropout=dropout['spp'], norm_layer=norm_layer) 133 | self.decoder = Decoder(num_classes, self.aspp_planes, channels[0], 64, (dropout['dc0'], dropout['dc1']), norm_layer) 134 | self.train_mode = True 135 | 136 | # classifier branch 137 | if classifier: 138 | self.logit_image = nn.Sequential(nn.Dropout(dropout['cls']), nn.Linear(channels[1]+self.aspp_planes+64, num_classes)) 139 | 140 | def forward(self, x_input): 141 | low_level_feat = self.low_level(x_input) 142 | features = self.backbone(x_input) 143 | 144 | x = self.spp(features) 145 | if self.classifier: 146 | x, decoder_output = self.decoder(x, low_level_feat, classifier=self.classifier) 147 | else: 148 | x = self.decoder(x, low_level_feat, classifier=self.classifier) 149 | out_size = x_input.size()[2:] 150 | x = F.interpolate(x, size=out_size, mode='bilinear', align_corners=False) 151 | 152 | # classifier branch 153 | if self.classifier: 154 | features = features.mean([2, 3]) 155 | decoder_output = decoder_output.mean([2,3]) 156 | features = torch.cat((features, decoder_output), dim=1) 157 | c = self.logit_image(features) 158 | return x, c 159 | else: 160 | return x 161 | 162 | def train(self, mode=True): 163 | super(DeepLab, self).train(mode) 164 | if mode and 
self.norm_eval: 165 | for m in self.modules(): 166 | # trick: eval has an effect on BatchNorm only 167 | if isinstance(m, _BatchNorm): 168 | m.eval() 169 | return self 170 | 171 | 172 | -------------------------------------------------------------------------------- /src/models/unet3d/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import shutil 4 | import sys 5 | import scipy.sparse as sparse 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | def save_checkpoint(state, is_best, checkpoint_dir, logger=None): 12 | """Saves model and training parameters at '{checkpoint_dir}/last_checkpoint.pytorch'. 13 | If is_best==True saves '{checkpoint_dir}/best_checkpoint.pytorch' as well. 14 | Args: 15 | state (dict): contains model's state_dict, optimizer's state_dict, epoch 16 | and best evaluation metric value so far 17 | is_best (bool): if True state contains the best model seen so far 18 | checkpoint_dir (string): directory where the checkpoints are to be saved 19 | """ 20 | 21 | def log_info(message): 22 | if logger is not None: 23 | logger.info(message) 24 | 25 | if not os.path.exists(checkpoint_dir): 26 | log_info( 27 | f"Checkpoint directory does not exist. Creating {checkpoint_dir}") 28 | os.mkdir(checkpoint_dir) 29 | 30 | last_file_path = os.path.join(checkpoint_dir, 'last_checkpoint.pytorch') 31 | log_info(f"Saving last checkpoint to '{last_file_path}'") 32 | torch.save(state, last_file_path) 33 | if is_best: 34 | best_file_path = os.path.join(checkpoint_dir, 'best_checkpoint.pytorch') 35 | log_info(f"Saving best checkpoint to '{best_file_path}'") 36 | shutil.copyfile(last_file_path, best_file_path) 37 | 38 | 39 | def load_checkpoint(checkpoint_path, model, optimizer=None): 40 | """Loads model and training parameters from a given checkpoint_path 41 | If optimizer is provided, loads the optimizer's state_dict as well.
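    Typical call (the checkpoint path and the model/optimizer pair are illustrative,
    not part of this project's fixed API; any nn.Module and Optimizer whose states
    were stored under the 'model_state_dict' / 'optimizer_state_dict' keys will do):
        model = torch.nn.Conv3d(1, 7, 3, padding=1)
        optimizer = torch.optim.Adam(model.parameters())
        state = load_checkpoint('checkpoints/last_checkpoint.pytorch', model, optimizer)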
42 | Args: 43 | checkpoint_path (string): path to the checkpoint to be loaded 44 | model (torch.nn.Module): model into which the parameters are to be copied 45 | optimizer (torch.optim.Optimizer) optional: optimizer instance into 46 | which the parameters are to be copied 47 | Returns: 48 | state 49 | """ 50 | if not os.path.exists(checkpoint_path): 51 | raise IOError(f"Checkpoint '{checkpoint_path}' does not exist") 52 | 53 | state = torch.load(checkpoint_path) 54 | model.load_state_dict(state['model_state_dict']) 55 | 56 | if optimizer is not None: 57 | optimizer.load_state_dict(state['optimizer_state_dict']) 58 | 59 | return state 60 | 61 | 62 | def get_logger(name, level=logging.INFO): 63 | logger = logging.getLogger(name) 64 | logger.setLevel(level) 65 | # Logging to console 66 | stream_handler = logging.StreamHandler(sys.stdout) 67 | formatter = logging.Formatter( 68 | '%(asctime)s [%(threadName)s] %(levelname)s %(name)s - %(message)s') 69 | stream_handler.setFormatter(formatter) 70 | logger.addHandler(stream_handler) 71 | 72 | return logger 73 | 74 | 75 | def get_number_of_learnable_parameters(model): 76 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 77 | return sum([np.prod(p.size()) for p in model_parameters]) 78 | 79 | 80 | class RunningAverage: 81 | """Computes and stores the average 82 | """ 83 | 84 | def __init__(self): 85 | self.count = 0 86 | self.sum = 0 87 | self.avg = 0 88 | 89 | def update(self, value, n=1): 90 | self.count += n 91 | self.sum += value * n 92 | self.avg = self.sum / self.count 93 | 94 | 95 | def find_maximum_patch_size(model, device): 96 | """Tries to find the biggest patch size that can be send to GPU for inference 97 | without throwing CUDA out of memory""" 98 | logger = get_logger('PatchFinder') 99 | in_channels = model.in_channels 100 | 101 | patch_shapes = [(64, 128, 128), (96, 128, 128), 102 | (64, 160, 160), (96, 160, 160), 103 | (64, 192, 192), (96, 192, 192)] 104 | 105 | for shape in patch_shapes: 106 | # generate random patch of a given size 107 | patch = np.random.randn(*shape).astype('float32') 108 | 109 | patch = torch \ 110 | .from_numpy(patch) \ 111 | .view((1, in_channels) + patch.shape) \ 112 | .to(device) 113 | 114 | logger.info(f"Current patch size: {shape}") 115 | model(patch) 116 | 117 | 118 | def unpad(probs, index, shape, pad_width=8): 119 | def _new_slices(slicing, max_size): 120 | if slicing.start == 0: 121 | p_start = 0 122 | i_start = 0 123 | else: 124 | p_start = pad_width 125 | i_start = slicing.start + pad_width 126 | 127 | if slicing.stop == max_size: 128 | p_stop = None 129 | i_stop = max_size 130 | else: 131 | p_stop = -pad_width 132 | i_stop = slicing.stop - pad_width 133 | 134 | return slice(p_start, p_stop), slice(i_start, i_stop) 135 | 136 | D, H, W = shape 137 | 138 | i_c, i_z, i_y, i_x = index 139 | p_c = slice(0, probs.shape[0]) 140 | 141 | p_z, i_z = _new_slices(i_z, D) 142 | p_y, i_y = _new_slices(i_y, H) 143 | p_x, i_x = _new_slices(i_x, W) 144 | 145 | probs_index = (p_c, p_z, p_y, p_x) 146 | index = (i_c, i_z, i_y, i_x) 147 | return probs[probs_index], index 148 | 149 | 150 | def create_feature_maps(init_channel_number, number_of_fmaps): 151 | return [init_channel_number * 2 ** k for k in range(number_of_fmaps)] 152 | 153 | 154 | # Code taken from https://github.com/cremi/cremi_python 155 | def adapted_rand(seg, gt, all_stats=False): 156 | """Compute Adapted Rand error as defined by the SNEMI3D contest [1] 157 | Formula is given as 1 - the maximal F-score of the Rand index 158 | 
(excluding the zero component of the original labels). Adapted 159 | from the SNEMI3D MATLAB script, hence the strange style. 160 | Parameters 161 | ---------- 162 | seg : np.ndarray 163 | the segmentation to score, where each value is the label at that point 164 | gt : np.ndarray, same shape as seg 165 | the groundtruth to score against, where each value is a label 166 | all_stats : boolean, optional 167 | whether to also return precision and recall as a 3-tuple with rand_error 168 | Returns 169 | ------- 170 | are : float 171 | The adapted Rand error; equal to $1 - \frac{2pr}{p + r}$, 172 | where $p$ and $r$ are the precision and recall described below. 173 | prec : float, optional 174 | The adapted Rand precision. (Only returned when `all_stats` is ``True``.) 175 | rec : float, optional 176 | The adapted Rand recall. (Only returned when `all_stats` is ``True``.) 177 | References 178 | ---------- 179 | [1]: http://brainiac2.mit.edu/SNEMI3D/evaluation 180 | """ 181 | # just to prevent division by 0 182 | epsilon = 1e-6 183 | 184 | # segA is truth, segB is query 185 | segA = np.ravel(gt) 186 | segB = np.ravel(seg) 187 | n = segA.size 188 | 189 | n_labels_A = np.amax(segA) + 1 190 | n_labels_B = np.amax(segB) + 1 191 | 192 | ones_data = np.ones(n) 193 | 194 | p_ij = sparse.csr_matrix((ones_data, (segA[:], segB[:])), shape=(n_labels_A, n_labels_B)) 195 | 196 | a = p_ij[1:n_labels_A, :] 197 | b = p_ij[1:n_labels_A, 1:n_labels_B] 198 | c = p_ij[1:n_labels_A, 0].todense() 199 | d = b.multiply(b) 200 | 201 | a_i = np.array(a.sum(1)) 202 | b_i = np.array(b.sum(0)) 203 | 204 | sumA = np.sum(a_i * a_i) 205 | sumB = np.sum(b_i * b_i) + (np.sum(c) / n) 206 | sumAB = np.sum(d) + (np.sum(c) / n) 207 | 208 | precision = sumAB / max(sumB, epsilon) 209 | recall = sumAB / max(sumA, epsilon) 210 | 211 | fScore = 2.0 * precision * recall / max(precision + recall, epsilon) 212 | are = 1.0 - fScore 213 | 214 | if all_stats: 215 | return are, precision, recall 216 | else: 217 | return are 218 | -------------------------------------------------------------------------------- /src/models/ModelDeepLab/backbones_deeplab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .resnext import ResNeXt, ResNet 6 | 7 | class SpatialAttention2d(nn.Module): 8 | def __init__(self, channel, conv_layer): 9 | super(SpatialAttention2d, self).__init__() 10 | self.squeeze = nn.Conv2d(channel, 1, kernel_size=1, bias=False) 11 | self.sigmoid = nn.Sigmoid() 12 | 13 | def forward(self, x): 14 | z = self.squeeze(x) 15 | z = self.sigmoid(z) 16 | return x * z 17 | 18 | class GAB(nn.Module): 19 | def __init__(self, input_dim, conv_layer, reduction=4): 20 | super(GAB, self).__init__() 21 | self.global_avgpool = nn.AdaptiveAvgPool2d(1) 22 | self.conv1 = nn.Conv2d(input_dim, input_dim // reduction, kernel_size=1, stride=1) 23 | self.conv2 = nn.Conv2d(input_dim // reduction, input_dim, kernel_size=1, stride=1) 24 | self.relu = nn.ReLU(inplace=True) 25 | self.sigmoid = nn.Sigmoid() 26 | 27 | def forward(self, x): 28 | z = self.global_avgpool(x) 29 | z = self.relu(self.conv1(z)) 30 | z = self.sigmoid(self.conv2(z)) 31 | return x * z 32 | 33 | class scSE(nn.Module): 34 | def __init__(self, dim, conv_layer, reduction=4): 35 | super(scSE, self).__init__() 36 | self.satt = SpatialAttention2d(dim, conv_layer) 37 | self.catt = GAB(dim, conv_layer, reduction) 38 | 39 | def forward(self, x): 40 | return self.satt(x) + self.catt(x) 
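# The three modules above form a concurrent spatial / channel attention block:
# SpatialAttention2d gates every pixel with a 1x1-conv + sigmoid map, GAB is a
# squeeze-and-excitation style channel gate, and scSE sums the two gated tensors,
# so the block is shape-preserving. A minimal smoke test (shapes are illustrative):
#
#   block = scSE(dim=256, conv_layer=nn.Conv2d, reduction=4)
#   x = torch.randn(1, 256, 64, 64)
#   assert block(x).shape == x.shape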
41 | 42 | def resnet50_gn_ws(output_stride=16, use_scse=True): 43 | if output_stride == 16: 44 | strides = (1, 2, 2, 1) 45 | dilations = (1, 1, 1, 2) 46 | elif output_stride == 8: 47 | strides = (1, 2, 1, 1) 48 | dilations = (1, 1, 2, 4) 49 | backbone = dict( 50 | type='ResNet', 51 | depth=50, 52 | num_stages=4, 53 | out_indices=(0, 1, 2, 3), 54 | mg_rates=(1, 2, 4), 55 | strides=strides, 56 | dilations=dilations, 57 | frozen_stages=0, 58 | style='pytorch', 59 | norm_cfg=dict(type='GN', num_groups=32, requires_grad=True), 60 | conv_cfg=dict(type='ConvWS') 61 | ) 62 | # 63 | model = backbone.pop('type') 64 | model = eval(model)(**backbone) 65 | model.init_weights('open-mmlab://jhu/resnet50_gn_ws') 66 | model.input_range = [0, 255] 67 | model.mean = [102.9801, 115.9465, 122.7717] 68 | model.std = [1.0, 1.0, 1.0] 69 | # 70 | low_level = nn.Sequential(model.conv1, model.gn1, model.relu, model.maxpool, model.layer1) 71 | if use_scse: 72 | encoder = nn.Sequential(low_level, 73 | scSE(256, nn.Conv2d), 74 | model.layer2, 75 | scSE(512, nn.Conv2d), 76 | model.layer3, 77 | scSE(1024, nn.Conv2d), 78 | model.layer4, 79 | scSE(2048, nn.Conv2d)) 80 | else: 81 | encoder = model 82 | return (encoder, low_level), [256, 2048], model 83 | 84 | def resnet101_gn_ws(output_stride=16, use_scse=True): 85 | if output_stride == 16: 86 | strides = (1, 2, 2, 1) 87 | dilations = (1, 1, 1, 2) 88 | elif output_stride == 8: 89 | strides = (1, 2, 1, 1) 90 | dilations = (1, 1, 2, 4) 91 | backbone = dict( 92 | type='ResNet', 93 | depth=101, 94 | num_stages=4, 95 | out_indices=(0, 1, 2, 3), 96 | mg_rates=(1, 2, 4), 97 | strides=strides, 98 | dilations=dilations, 99 | frozen_stages=0, 100 | style='pytorch', 101 | norm_cfg=dict(type='GN', num_groups=32, requires_grad=True), 102 | conv_cfg=dict(type='ConvWS') 103 | ) 104 | # 105 | model = backbone.pop('type') 106 | model = eval(model)(**backbone) 107 | model.init_weights('open-mmlab://jhu/resnet101_gn_ws') 108 | model.input_range = [0, 255] 109 | model.mean = [123.675, 116.28, 103.53] 110 | model.std = [58.395, 57.12, 57.375] 111 | # 112 | low_level = nn.Sequential(model.conv1, model.gn1, model.relu, model.maxpool, model.layer1) 113 | if use_scse: 114 | encoder = nn.Sequential(low_level, 115 | scSE(256, nn.Conv2d), 116 | model.layer2, 117 | scSE(512, nn.Conv2d), 118 | model.layer3, 119 | scSE(1024, nn.Conv2d), 120 | model.layer4, 121 | scSE(2048, nn.Conv2d)) 122 | else: 123 | encoder = model 124 | return (encoder, low_level), [256, 2048], model 125 | 126 | 127 | def resnext50_gn_ws(output_stride=16, use_scse=True): 128 | if output_stride == 16: 129 | strides = (1, 2, 2, 1) 130 | dilations = (1, 1, 1, 2) 131 | elif output_stride == 8: 132 | strides = (1, 2, 1, 1) 133 | dilations = (1, 1, 2, 4) 134 | backbone = dict( 135 | type='ResNeXt', 136 | depth=50, 137 | groups=32, 138 | base_width=4, 139 | num_stages=4, 140 | out_indices=(0, 1, 2, 3), 141 | mg_rates=(1, 2, 4), 142 | strides=strides, 143 | dilations=dilations, 144 | frozen_stages=0, 145 | style='pytorch', 146 | norm_cfg=dict(type='GN', num_groups=32, requires_grad=True), 147 | conv_cfg=dict(type='ConvWS') 148 | ) 149 | # 150 | model = backbone.pop('type') 151 | model = eval(model)(**backbone) 152 | model.init_weights('open-mmlab://jhu/resnext50_32x4d_gn_ws') 153 | model.input_range = [0, 255] 154 | model.mean = [123.675, 116.28, 103.53] 155 | model.std = [58.395, 57.12, 57.375] 156 | # 157 | low_level = nn.Sequential(model.conv1, model.gn1, model.relu, model.maxpool, model.layer1) 158 | if use_scse: 159 | encoder = 
nn.Sequential(low_level, 160 | scSE(256, nn.Conv2d), 161 | model.layer2, 162 | scSE(512, nn.Conv2d), 163 | model.layer3, 164 | scSE(1024, nn.Conv2d), 165 | model.layer4, 166 | scSE(2048, nn.Conv2d)) 167 | else: 168 | encoder = model 169 | return (encoder, low_level), [256, 2048], model 170 | 171 | def resnext101_gn_ws(output_stride=16, use_scse=True): 172 | if output_stride == 16: 173 | strides = (1, 2, 2, 1) 174 | dilations = (1, 1, 1, 2) 175 | elif output_stride == 8: 176 | strides = (1, 2, 1, 1) 177 | dilations = (1, 1, 2, 4) 178 | backbone = dict( 179 | type='ResNeXt', 180 | depth=101, 181 | groups=32, 182 | base_width=4, 183 | num_stages=4, 184 | out_indices=(0, 1, 2, 3), 185 | mg_rates=(1, 2, 4), 186 | strides=strides, 187 | dilations=dilations, 188 | frozen_stages=0, 189 | style='pytorch', 190 | norm_cfg=dict(type='GN', num_groups=32, requires_grad=True), 191 | conv_cfg=dict(type='ConvWS') 192 | ) 193 | # 194 | model = backbone.pop('type') 195 | model = eval(model)(**backbone) 196 | model.init_weights('open-mmlab://jhu/resnext101_32x4d_gn_ws') 197 | model.input_range = [0, 255] 198 | model.mean = [123.675, 116.28, 103.53] 199 | model.std = [58.395, 57.12, 57.375] 200 | # 201 | low_level = nn.Sequential(model.conv1, model.gn1, model.relu, model.maxpool, model.layer1) 202 | if use_scse: 203 | encoder = nn.Sequential(low_level, 204 | scSE(256, nn.Conv2d), 205 | model.layer2, 206 | scSE(512, nn.Conv2d), 207 | model.layer3, 208 | scSE(1024, nn.Conv2d), 209 | model.layer4, 210 | scSE(2048, nn.Conv2d)) 211 | else: 212 | encoder = model 213 | return (encoder, low_level), [256, 2048], model -------------------------------------------------------------------------------- /src/optimizers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer 4 | 5 | 6 | class AdamW(Optimizer): 7 | r"""Implements AdamW algorithm. 8 | 9 | The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. 10 | The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. 11 | 12 | Arguments: 13 | params (iterable): iterable of parameters to optimize or dicts defining 14 | parameter groups 15 | lr (float, optional): learning rate (default: 1e-3) 16 | betas (Tuple[float, float], optional): coefficients used for computing 17 | running averages of gradient and its square (default: (0.9, 0.999)) 18 | eps (float, optional): term added to the denominator to improve 19 | numerical stability (default: 1e-8) 20 | weight_decay (float, optional): weight decay coefficient (default: 1e-2) 21 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 22 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 23 | (default: False) 24 | 25 | .. _Adam\: A Method for Stochastic Optimization: 26 | https://arxiv.org/abs/1412.6980 27 | .. _Decoupled Weight Decay Regularization: 28 | https://arxiv.org/abs/1711.05101 29 | .. 
_On the Convergence of Adam and Beyond: 30 | https://openreview.net/forum?id=ryQu7f-RZ 31 | """ 32 | 33 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 34 | weight_decay=1e-2, amsgrad=False): 35 | if not 0.0 <= lr: 36 | raise ValueError("Invalid learning rate: {}".format(lr)) 37 | if not 0.0 <= eps: 38 | raise ValueError("Invalid epsilon value: {}".format(eps)) 39 | if not 0.0 <= betas[0] < 1.0: 40 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 41 | if not 0.0 <= betas[1] < 1.0: 42 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 43 | defaults = dict(lr=lr, betas=betas, eps=eps, 44 | weight_decay=weight_decay, amsgrad=amsgrad) 45 | super(AdamW, self).__init__(params, defaults) 46 | 47 | def __setstate__(self, state): 48 | super(AdamW, self).__setstate__(state) 49 | for group in self.param_groups: 50 | group.setdefault('amsgrad', False) 51 | 52 | def step(self, closure=None): 53 | """Performs a single optimization step. 54 | 55 | Arguments: 56 | closure (callable, optional): A closure that reevaluates the model 57 | and returns the loss. 58 | """ 59 | loss = None 60 | if closure is not None: 61 | loss = closure() 62 | 63 | for group in self.param_groups: 64 | for p in group['params']: 65 | if p.grad is None: 66 | continue 67 | 68 | # Perform stepweight decay 69 | p.data.mul_(1 - group['lr'] * group['weight_decay']) 70 | 71 | # Perform optimization step 72 | grad = p.grad.data 73 | if grad.is_sparse: 74 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 75 | amsgrad = group['amsgrad'] 76 | 77 | state = self.state[p] 78 | 79 | # State initialization 80 | if len(state) == 0: 81 | state['step'] = 0 82 | # Exponential moving average of gradient values 83 | state['exp_avg'] = torch.zeros_like(p.data) 84 | # Exponential moving average of squared gradient values 85 | state['exp_avg_sq'] = torch.zeros_like(p.data) 86 | if amsgrad: 87 | # Maintains max of all exp. moving avg. of sq. grad. values 88 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 89 | 90 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 91 | if amsgrad: 92 | max_exp_avg_sq = state['max_exp_avg_sq'] 93 | beta1, beta2 = group['betas'] 94 | 95 | state['step'] += 1 96 | 97 | # Decay the first and second moment running average coefficient 98 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 99 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 100 | if amsgrad: 101 | # Maintains the maximum of all 2nd moment running avg. till now 102 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 103 | # Use the max. for normalizing running avg. 
of gradient 104 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 105 | else: 106 | denom = exp_avg_sq.sqrt().add_(group['eps']) 107 | 108 | bias_correction1 = 1 - beta1 ** state['step'] 109 | bias_correction2 = 1 - beta2 ** state['step'] 110 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 111 | 112 | p.data.addcdiv_(-step_size, exp_avg, denom) 113 | 114 | return loss 115 | 116 | 117 | class Nadam(Optimizer): 118 | 119 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 120 | schedule_decay=0.004,amsgrad=False): 121 | if not 0.0 <= betas[0] < 1.0: 122 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 123 | if not 0.0 <= betas[1] < 1.0: 124 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 125 | defaults = dict(lr=lr, betas=betas, eps=eps, 126 | amsgrad=amsgrad,schedule_decay=schedule_decay) 127 | super(Nadam, self).__init__(params, defaults) 128 | 129 | def step(self, closure=None): 130 | loss = None 131 | if closure is not None: 132 | loss = closure() 133 | 134 | for group in self.param_groups: 135 | for p in group['params']: 136 | if p.grad is None: 137 | continue 138 | grad = p.grad.data 139 | if grad.is_sparse: 140 | raise RuntimeError('Nadam does not support sparse gradients, please consider SparseAdam instead') 141 | amsgrad = group['amsgrad'] 142 | 143 | state = self.state[p] 144 | 145 | # State initialization 146 | if len(state) == 0: 147 | state['step'] = 0 148 | # Exponential moving average of gradient values 149 | state['exp_avg'] = torch.zeros_like(p.data) 150 | # Exponential moving average of squared gradient values 151 | state['exp_avg_sq'] = torch.zeros_like(p.data) 152 | 153 | state['m_schedule'] = 1 154 | if amsgrad: 155 | # Maintains max of all exp. moving avg. of sq. grad. values 156 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 157 | 158 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 159 | if amsgrad: 160 | max_exp_avg_sq = state['max_exp_avg_sq'] 161 | beta1, beta2 = group['betas'] 162 | 163 | 164 | state['step'] += 1 165 | momentum_cache_t = beta1 * ( 166 | 1. - 0.5 * math.pow(0.96, state['step'] * group['schedule_decay'] )) 167 | momentum_cache_t_1 = beta1 * ( 168 | 1. - 0.5 * math.pow(0.96, (state['step']+1) * group['schedule_decay'] )) 169 | state['m_schedule'] = state['m_schedule'] * momentum_cache_t 170 | 171 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 172 | m_t_prime = exp_avg/(1 - state['m_schedule'] * momentum_cache_t_1) 173 | 174 | g_prime = grad.div(1 - state['m_schedule']) 175 | m_t_bar = (1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime 176 | 177 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 178 | if amsgrad: 179 | # Maintains the maximum of all 2nd moment running avg. till now 180 | torch.max(max_exp_avg_sq, exp_avg_sq , out=max_exp_avg_sq) 181 | # Use the max. for normalizing running avg. 
of gradient 182 | v_t_prime = max_exp_avg_sq/(1 - beta2 ** state['step']) 183 | else: 184 | v_t_prime = exp_avg_sq / (1 - beta2 ** state['step']) 185 | 186 | denom = v_t_prime.sqrt().add_(group['eps']) 187 | p.data.addcdiv_(-group['lr'], m_t_bar , denom) 188 | 189 | return loss -------------------------------------------------------------------------------- /src/models/vnet/vnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | from segmentation_models_pytorch.base.model import Model 7 | from segmentation_models_pytorch.base.encoder_decoder import EncoderDecoder 8 | # from segmentation_models_pytorch.common.blocks import Conv2dReLU 9 | from segmentation_models_pytorch.encoders import get_encoder 10 | 11 | from .cbam import CBAM_Module 12 | from .aspp import ASPP, GroupNorm32 13 | from .scse import SCSEModule 14 | 15 | 16 | class ConvBn2d(nn.Module): 17 | def __init__(self, in_channels, out_channels, norm_layer, kernel_size=(3,3), stride=(1,1), padding=(1,1)): 18 | super(ConvBn2d, self).__init__() 19 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) 20 | self.bn = norm_layer(out_channels) 21 | 22 | def forward(self, z): 23 | x = self.conv(z) 24 | x = self.bn(x) 25 | return x 26 | 27 | 28 | class ConvBnRelu2d(nn.Module): 29 | def __init__(self, in_channels, out_channels, norm_layer, kernel_size=(3,3), stride=(1,1), padding=(1,1)): 30 | super(ConvBnRelu2d, self).__init__() 31 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) 32 | self.bn = norm_layer(out_channels) 33 | 34 | def forward(self, z): 35 | x = self.conv(z) 36 | x = self.bn(x) 37 | x = F.relu(x, inplace=True) 38 | return x 39 | 40 | 41 | class DecoderBlock(nn.Module): 42 | def __init__(self, 43 | in_channels, 44 | out_channels, 45 | norm_layer=GroupNorm32, 46 | reduction=16, 47 | attention_kernel_size=3, 48 | attention_type='none', 49 | reslink=False, 50 | ): 51 | super().__init__() 52 | self.block = nn.Sequential( 53 | ConvBnRelu2d(in_channels, out_channels, kernel_size=3, padding=1, norm_layer=norm_layer), 54 | ConvBnRelu2d(out_channels, out_channels, kernel_size=3, padding=1, norm_layer=norm_layer), 55 | ) 56 | 57 | self.attention = not attention_type == 'none' 58 | self.attention_type = attention_type 59 | self.reslink = reslink 60 | 61 | if self.attention_type.find('cbam') >= 0: 62 | self.channel_gate = CBAM_Module(out_channels, reduction, attention_kernel_size) 63 | elif self.attention_type.find('scse') >= 0: 64 | self.channel_gate = SCSEModule(out_channels, reduction) 65 | 66 | if self.reslink: 67 | self.shortcut = ConvBn2d( 68 | in_channels, 69 | out_channels, 70 | kernel_size=(1, 1), 71 | stride=(1, 1), 72 | padding=(0, 0), 73 | norm_layer=norm_layer 74 | ) 75 | 76 | def forward(self, x): 77 | x, skip = x 78 | x = F.interpolate(x, scale_factor=2, mode='nearest') 79 | if skip is not None: 80 | x = torch.cat([x, skip], dim=1) 81 | 82 | if self.reslink: 83 | shortcut = self.shortcut(x) 84 | x = self.block(x) 85 | if self.attention: 86 | x = self.channel_gate(x) 87 | if self.reslink: 88 | x = F.relu(x + shortcut) 89 | return x 90 | 91 | 92 | class CenterBlock(DecoderBlock): 93 | 94 | def forward(self, x): 95 | return self.block(x) 96 | 97 | 98 | class UnetDecoder(Model): 99 | 100 | def __init__( 101 | self, 102 | encoder_channels, 103 | decoder_channels=(256, 128, 
64, 32, 16), 104 | final_channels=1, 105 | center='none', 106 | group_norm=False, 107 | reslink=False, 108 | attention_type='none', 109 | multi_task=False 110 | ): 111 | super().__init__() 112 | 113 | norm_layer = GroupNorm32 if group_norm else nn.BatchNorm2d 114 | 115 | if center.find('none') >= 0: 116 | self.center = None 117 | else: 118 | channels = encoder_channels[0] 119 | if center.find('normal') >= 0: 120 | self.center = CenterBlock( 121 | channels, channels, 122 | norm_layer=norm_layer, 123 | attention_type=attention_type, 124 | reslink=reslink 125 | ) 126 | elif center.find('aspp') >= 0: 127 | self.center = ASPP(channels, channels, norm_layer) 128 | 129 | in_channels = self.compute_channels(encoder_channels, decoder_channels) 130 | out_channels = decoder_channels 131 | 132 | self.layer1 = DecoderBlock( 133 | in_channels[0], 134 | out_channels[0], 135 | norm_layer=norm_layer, 136 | attention_type=attention_type, 137 | reslink=reslink 138 | ) 139 | self.layer2 = DecoderBlock( 140 | in_channels[1], 141 | out_channels[1], 142 | norm_layer=norm_layer, 143 | attention_type=attention_type, 144 | reslink=reslink 145 | ) 146 | self.layer3 = DecoderBlock( 147 | in_channels[2], 148 | out_channels[2], 149 | norm_layer=norm_layer, 150 | attention_type=attention_type, 151 | reslink=reslink 152 | ) 153 | self.layer4 = DecoderBlock( 154 | in_channels[3], 155 | out_channels[3], 156 | norm_layer=norm_layer, 157 | attention_type=attention_type, 158 | reslink=reslink 159 | ) 160 | self.layer5 = DecoderBlock( 161 | in_channels[4], 162 | out_channels[4], 163 | norm_layer=norm_layer, 164 | attention_type=attention_type, 165 | reslink=reslink 166 | ) 167 | self.final_conv = nn.Conv2d(out_channels[4], final_channels, kernel_size=(1, 1)) 168 | 169 | self.multi_task = multi_task 170 | if self.multi_task: 171 | self.fc_1 = nn.Linear(encoder_channels[0], final_channels) 172 | 173 | self.initialize() 174 | 175 | def compute_channels(self, encoder_channels, decoder_channels): 176 | channels = [ 177 | encoder_channels[0] + encoder_channels[1], 178 | encoder_channels[2] + decoder_channels[0], 179 | encoder_channels[3] + decoder_channels[1], 180 | encoder_channels[4] + decoder_channels[2], 181 | 0 + decoder_channels[3], 182 | ] 183 | return channels 184 | 185 | def forward(self, x): 186 | encoder_head = x[0] 187 | skips = x[1:] 188 | 189 | if self.multi_task: 190 | encoder_features = F.adaptive_avg_pool2d(encoder_head, 1) 191 | encoder_features = encoder_features.view(encoder_features.size(0), -1) 192 | # print(encoder_features.shape) 193 | x_logit = self.fc_1(encoder_features) 194 | 195 | if self.center: 196 | encoder_head = self.center(encoder_head) 197 | 198 | x = self.layer1([encoder_head, skips[0]]) 199 | x = self.layer2([x, skips[1]]) 200 | x = self.layer3([x, skips[2]]) 201 | x = self.layer4([x, skips[3]]) 202 | x = self.layer5([x, None]) 203 | x = self.final_conv(x) 204 | 205 | if self.multi_task: 206 | return x, x_logit 207 | else: 208 | return x 209 | 210 | 211 | class VNet(EncoderDecoder): 212 | def __init__( 213 | self, 214 | encoder_name='resnet34', 215 | encoder_weights='imagenet', 216 | group_norm=True, 217 | decoder_channels=(256, 128, 64, 32, 16), 218 | classes=1, 219 | activation='sigmoid', 220 | center='none', # usefull for VGG models 221 | attention_type=None, 222 | reslink=False, 223 | multi_task=False 224 | ): 225 | assert center in ['none', 'normal', 'aspp'] 226 | assert attention_type in ['none', 'cbam', 'scse'] 227 | 228 | print("**" * 50) 229 | print("Encoder name: 
\t\t{}".format(encoder_name)) 230 | print("Center: \t\t{}".format(center)) 231 | print("Attention type: \t\t{}".format(attention_type)) 232 | print("Reslink: \t\t{}".format(reslink)) 233 | 234 | encoder = get_encoder( 235 | encoder_name, 236 | encoder_weights=encoder_weights 237 | ) 238 | 239 | decoder = UnetDecoder( 240 | encoder_channels=encoder.out_shapes, 241 | decoder_channels=decoder_channels, 242 | final_channels=classes, 243 | group_norm=group_norm, 244 | center=center, 245 | attention_type=attention_type, 246 | reslink=reslink, 247 | multi_task=multi_task 248 | ) 249 | 250 | super().__init__(encoder, decoder, activation) 251 | 252 | self.name = 'vnet-{}'.format(encoder_name) 253 | -------------------------------------------------------------------------------- /src/schedulers.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.optim.optimizer import Optimizer 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | class CyclicLRFix(_LRScheduler): 7 | """Sets the learning rate of each parameter group according to 8 | cyclical learning rate policy (CLR). The policy cycles the learning 9 | rate between two boundaries with a constant frequency, as detailed in 10 | the paper `Cyclical Learning Rates for Training Neural Networks`_. 11 | The distance between the two boundaries can be scaled on a per-iteration 12 | or per-cycle basis. 13 | Cyclical learning rate policy changes the learning rate after every batch. 14 | `step` should be called after a batch has been used for training. 15 | This class has three built-in policies, as put forth in the paper: 16 | "triangular": 17 | A basic triangular cycle w/ no amplitude scaling. 18 | "triangular2": 19 | A basic triangular cycle that scales initial amplitude by half each cycle. 20 | "exp_range": 21 | A cycle that scales initial amplitude by gamma**(cycle iterations) at each 22 | cycle iteration. 23 | This implementation was adapted from the github repo: `bckenstler/CLR`_ 24 | Args: 25 | optimizer (Optimizer): Wrapped optimizer. 26 | base_lr (float or list): Initial learning rate which is the 27 | lower boundary in the cycle for each parameter group. 28 | max_lr (float or list): Upper learning rate boundaries in the cycle 29 | for each parameter group. Functionally, 30 | it defines the cycle amplitude (max_lr - base_lr). 31 | The lr at any cycle is the sum of base_lr 32 | and some scaling of the amplitude; therefore 33 | max_lr may not actually be reached depending on 34 | scaling function. 35 | step_size_up (int): Number of training iterations in the 36 | increasing half of a cycle. Default: 2000 37 | step_size_down (int): Number of training iterations in the 38 | decreasing half of a cycle. If step_size_down is None, 39 | it is set to step_size_up. Default: None 40 | mode (str): One of {triangular, triangular2, exp_range}. 41 | Values correspond to policies detailed above. 42 | If scale_fn is not None, this argument is ignored. 43 | Default: 'triangular' 44 | gamma (float): Constant in 'exp_range' scaling function: 45 | gamma**(cycle iterations) 46 | Default: 1.0 47 | scale_fn (function): Custom scaling policy defined by a single 48 | argument lambda function, where 49 | 0 <= scale_fn(x) <= 1 for all x >= 0. 50 | If specified, then 'mode' is ignored. 51 | Default: None 52 | scale_mode (str): {'cycle', 'iterations'}. 53 | Defines whether scale_fn is evaluated on 54 | cycle number or cycle iterations (training 55 | iterations since start of cycle). 
56 | Default: 'cycle' 57 | cycle_momentum (bool): If ``True``, momentum is cycled inversely 58 | to learning rate between 'base_momentum' and 'max_momentum'. 59 | Default: True 60 | base_momentum (float or list): Lower momentum boundaries in the cycle 61 | for each parameter group. Note that momentum is cycled inversely 62 | to learning rate; at the peak of a cycle, momentum is 63 | 'base_momentum' and learning rate is 'max_lr'. 64 | Default: 0.8 65 | max_momentum (float or list): Upper momentum boundaries in the cycle 66 | for each parameter group. Functionally, 67 | it defines the cycle amplitude (max_momentum - base_momentum). 68 | The momentum at any cycle is the difference of max_momentum 69 | and some scaling of the amplitude; therefore 70 | base_momentum may not actually be reached depending on 71 | scaling function. Note that momentum is cycled inversely 72 | to learning rate; at the start of a cycle, momentum is 'max_momentum' 73 | and learning rate is 'base_lr' 74 | Default: 0.9 75 | last_epoch (int): The index of the last batch. This parameter is used when 76 | resuming a training job. Since `step()` should be invoked after each 77 | batch instead of after each epoch, this number represents the total 78 | number of *batches* computed, not the total number of epochs computed. 79 | When last_epoch=-1, the schedule is started from the beginning. 80 | Default: -1 81 | Example: 82 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 83 | >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1) 84 | >>> data_loader = torch.utils.data.DataLoader(...) 85 | >>> for epoch in range(10): 86 | >>> for batch in data_loader: 87 | >>> train_batch(...) 88 | >>> scheduler.step() 89 | .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186 90 | .. 
_bckenstler/CLR: https://github.com/bckenstler/CLR 91 | """ 92 | 93 | def __init__(self, 94 | optimizer, 95 | base_lr, 96 | max_lr, 97 | step_size_up=2000, 98 | step_size_down=None, 99 | mode='triangular', 100 | gamma=1., 101 | scale_fn=None, 102 | scale_mode='cycle', 103 | cycle_momentum=True, 104 | base_momentum=0.8, 105 | max_momentum=0.9, 106 | last_epoch=-1): 107 | 108 | if not isinstance(optimizer, Optimizer): 109 | raise TypeError('{} is not an Optimizer'.format( 110 | type(optimizer).__name__)) 111 | self.optimizer = optimizer 112 | 113 | base_lrs = self._format_param('base_lr', optimizer, base_lr) 114 | if last_epoch == -1: 115 | for lr, group in zip(base_lrs, optimizer.param_groups): 116 | group['lr'] = lr 117 | 118 | self.max_lrs = self._format_param('max_lr', optimizer, max_lr) 119 | 120 | step_size_up = float(step_size_up) 121 | step_size_down = float(step_size_down) if step_size_down is not None else step_size_up 122 | self.total_size = step_size_up + step_size_down 123 | self.step_ratio = step_size_up / self.total_size 124 | 125 | if mode not in ['triangular', 'triangular2', 'exp_range'] \ 126 | and scale_fn is None: 127 | raise ValueError('mode is invalid and scale_fn is None') 128 | 129 | self.mode = mode 130 | self.gamma = gamma 131 | 132 | if scale_fn is None: 133 | if self.mode == 'triangular': 134 | self.scale_fn = self._triangular_scale_fn 135 | self.scale_mode = 'cycle' 136 | elif self.mode == 'triangular2': 137 | self.scale_fn = self._triangular2_scale_fn 138 | self.scale_mode = 'cycle' 139 | elif self.mode == 'exp_range': 140 | self.scale_fn = self._exp_range_scale_fn 141 | self.scale_mode = 'iterations' 142 | else: 143 | self.scale_fn = scale_fn 144 | self.scale_mode = scale_mode 145 | 146 | self.cycle_momentum = cycle_momentum 147 | if cycle_momentum: 148 | if 'momentum' not in optimizer.defaults: 149 | raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled') 150 | 151 | base_momentums = self._format_param('base_momentum', optimizer, base_momentum) 152 | if last_epoch == -1: 153 | for momentum, group in zip(base_momentums, optimizer.param_groups): 154 | group['momentum'] = momentum 155 | self.base_momentums = list(map(lambda group: group['momentum'], optimizer.param_groups)) 156 | self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum) 157 | 158 | super(CyclicLRFix, self).__init__(optimizer, last_epoch) 159 | 160 | def _format_param(self, name, optimizer, param): 161 | """Return correctly formatted lr/momentum for each param group.""" 162 | if isinstance(param, (list, tuple)): 163 | if len(param) != len(optimizer.param_groups): 164 | raise ValueError("expected {} values for {}, got {}".format( 165 | len(optimizer.param_groups), name, len(param))) 166 | return param 167 | else: 168 | return [param] * len(optimizer.param_groups) 169 | 170 | def _triangular_scale_fn(self, x): 171 | return 1. 172 | 173 | def _triangular2_scale_fn(self, x): 174 | return 1 / (2. ** (x - 1)) 175 | 176 | def _exp_range_scale_fn(self, x): 177 | return self.gamma**(x) 178 | 179 | def get_lr(self): 180 | """Calculates the learning rate at batch index. This function treats 181 | `self.last_epoch` as the last batch index. 182 | If `self.cycle_momentum` is ``True``, this function has a side effect of 183 | updating the optimizer's momentum. 184 | """ 185 | cycle = math.floor(1 + self.last_epoch / self.total_size) 186 | x = 1. 
+ self.last_epoch / self.total_size - cycle 187 | if x <= self.step_ratio: 188 | scale_factor = x / self.step_ratio 189 | else: 190 | scale_factor = (x - 1) / (self.step_ratio - 1) 191 | 192 | lrs = [] 193 | for base_lr, max_lr in zip(self.base_lrs, self.max_lrs): 194 | base_height = (max_lr - base_lr) * scale_factor 195 | if self.scale_mode == 'cycle': 196 | lr = base_lr + base_height * self.scale_fn(cycle) 197 | else: 198 | lr = base_lr + base_height * self.scale_fn(self.last_epoch) 199 | lrs.append(lr) 200 | 201 | if self.cycle_momentum: 202 | momentums = [] 203 | for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums): 204 | base_height = (max_momentum - base_momentum) * scale_factor 205 | if self.scale_mode == 'cycle': 206 | momentum = max_momentum - base_height * self.scale_fn(cycle) 207 | else: 208 | momentum = max_momentum - base_height * self.scale_fn(self.last_epoch) 209 | momentums.append(momentum) 210 | for param_group, momentum in zip(self.optimizer.param_groups, momentums): 211 | param_group['momentum'] = momentum 212 | 213 | return lrs 214 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import cv2 4 | import pandas as pd 5 | from torch.utils.data import Dataset 6 | from tqdm import tqdm 7 | import SimpleITK 8 | import scipy.ndimage as ndimage 9 | import SimpleITK as sitk 10 | 11 | 12 | UPPER_BOUND = 400 13 | LOWER_BOUND = -1000 14 | 15 | 16 | def load_ct_images(path): 17 | image = SimpleITK.ReadImage(path) 18 | spacing = image.GetSpacing()[-1] 19 | image = SimpleITK.GetArrayFromImage(image).astype(np.float32) 20 | return image, spacing 21 | 22 | 23 | def load_itkfilewithtrucation(filename, upper=200, lower=-200): 24 | """ 25 | load mhd files,set truncted value range and normalization 0-255 26 | :param filename: 27 | :param upper: 28 | :param lower: 29 | :return: 30 | """ 31 | # 1,tructed outside of liver value 32 | srcitkimage = sitk.Cast(sitk.ReadImage(filename), sitk.sitkFloat32) 33 | srcitkimagearray = sitk.GetArrayFromImage(srcitkimage) 34 | srcitkimagearray[srcitkimagearray > upper] = upper 35 | srcitkimagearray[srcitkimagearray < lower] = lower 36 | # 2,get tructed outside of liver value image 37 | sitktructedimage = sitk.GetImageFromArray(srcitkimagearray) 38 | origin = np.array(srcitkimage.GetOrigin()) 39 | spacing = np.array(srcitkimage.GetSpacing()) 40 | sitktructedimage.SetSpacing(spacing) 41 | sitktructedimage.SetOrigin(origin) 42 | # 3 normalization value to 0-255 43 | rescalFilt = sitk.RescaleIntensityImageFilter() 44 | rescalFilt.SetOutputMaximum(255) 45 | rescalFilt.SetOutputMinimum(0) 46 | itkimage = rescalFilt.Execute(sitk.Cast(sitktructedimage, sitk.sitkFloat32)) 47 | return sitk.GetArrayFromImage(itkimage) 48 | 49 | 50 | def resize(image, mask, spacing, slice_thickness, scale_ratio): 51 | image = (image - LOWER_BOUND) / (UPPER_BOUND - LOWER_BOUND) 52 | image[image > 1] = 1. 53 | image[image < 0] = 0. 
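# The resampling below rescales the volume to the requested slice thickness and
# in-plane scale ratio: cubic interpolation (order=3) is used for the image, but
# nearest-neighbour (order=0) for the mask, so the integer organ labels are never
# blended across class boundaries.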
54 | image = image.astype(np.float32) 55 | 56 | if slice_thickness and scale_ratio: 57 | image = ndimage.zoom(image, (spacing / slice_thickness, scale_ratio, scale_ratio), order=3) 58 | mask = ndimage.zoom(mask, (spacing / slice_thickness, scale_ratio, scale_ratio), order=0) 59 | return image, mask 60 | 61 | 62 | def load_patient(imgpath, mskpath, slice_thickness=None, scale_ratio=None): 63 | image, spacing = load_ct_images(imgpath) 64 | 65 | mask, _ = load_ct_images(mskpath) 66 | image, mask = resize(image, mask, spacing, slice_thickness, scale_ratio) 67 | return image, mask 68 | 69 | 70 | def pad_if_need(image, mask, patch): 71 | assert image.shape == mask.shape 72 | 73 | n_slices, x, y = image.shape 74 | if n_slices < patch: 75 | padding = patch - n_slices 76 | offset = padding // 2 77 | image = np.pad(image, (offset, patch - n_slices - offset), 'edge') 78 | mask = np.pad(mask, (offset, patch - n_slices - offset), 'edge') 79 | 80 | return image, mask 81 | 82 | 83 | def slice_window(image, mask, slice, patch): 84 | image, mask = pad_if_need(image, mask, patch) 85 | n_slices, x, y = image.shape 86 | idx = 0 87 | 88 | image_patches = [] 89 | mask_patches = [] 90 | 91 | while idx + patch <= n_slices: 92 | image_patch = image[idx:idx + patch] 93 | mask_patch = mask[idx:idx + patch] 94 | 95 | # Save patch 96 | image_patches.append(image_patch) 97 | mask_patches.append(mask_patch) 98 | 99 | idx += slice 100 | 101 | return image_patches, mask_patches 102 | 103 | 104 | def slice_builder(imgpath, mskpath, slice_thichness, scale_ratio, slice, patch, save_dir): 105 | image, mask = load_patient(imgpath, mskpath, slice_thichness, scale_ratio) 106 | image_patches, mask_patches = slice_window(image, mask, slice, patch) 107 | patient_id = imgpath.split("/")[-2] 108 | save_dir = os.path.join(save_dir, patient_id) 109 | os.makedirs(save_dir, exist_ok=True) 110 | 111 | image_paths = [] 112 | mask_paths = [] 113 | for i, (image_patch, mask_patch) in enumerate(zip(image_patches, mask_patches)): 114 | image_path = os.path.join(save_dir, f'image.{i}.npy') 115 | mask_path = os.path.join(save_dir, f'mask.{i}.npy') 116 | 117 | image_paths.append(image_path) 118 | mask_paths.append(mask_path) 119 | 120 | np.save(image_path, image_patch) 121 | np.save(mask_path, mask_patch) 122 | 123 | df = pd.DataFrame({ 124 | 'image': image_paths, 125 | 'mask': mask_paths 126 | }) 127 | 128 | df['patient_id'] = patient_id 129 | return df 130 | 131 | 132 | def slice_builder_2d(imgpath, mskpath, save_dir): 133 | image, mask = load_patient(imgpath, mskpath) 134 | patient_id = imgpath.split("/")[-2] 135 | save_dir = os.path.join(save_dir, patient_id) 136 | os.makedirs(save_dir, exist_ok=True) 137 | 138 | image_paths = [] 139 | mask_paths = [] 140 | for i, (image_slice, mask_slice) in enumerate(zip(image, mask)): 141 | # if np.any(mask_slice): 142 | image_path = os.path.join(save_dir, f'image.{i}.npy') 143 | mask_path = os.path.join(save_dir, f'mask.{i}.npy') 144 | 145 | image_paths.append(image_path) 146 | mask_paths.append(mask_path) 147 | 148 | np.save(image_path, image_slice) 149 | np.save(mask_path, mask_slice) 150 | 151 | df = pd.DataFrame({ 152 | 'image': image_paths, 153 | 'mask': mask_paths 154 | }) 155 | 156 | df['patient_id'] = patient_id 157 | return df 158 | 159 | 160 | def random_crop(image, mask, patch): 161 | n_slices = image.shape[0] 162 | start = 0 163 | end = int(n_slices - patch) 164 | 165 | rnd_idx = np.random.randint(start, end) 166 | return image[rnd_idx:rnd_idx + patch, :, :], mask[rnd_idx:rnd_idx + patch, :, 
:] 167 | 168 | 169 | def center_crop(image, mask, patch): 170 | n_slices = image.shape[0] 171 | mid = n_slices // 2 172 | start = int(mid - patch // 2) 173 | end = int(mid + patch // 2) 174 | 175 | return image[start:end, :, :], mask[start:end, :, :] 176 | 177 | 178 | class StructSegTrain2D(Dataset): 179 | 180 | def __init__(self, 181 | csv_file, 182 | transform 183 | ): 184 | df = pd.read_csv(csv_file) 185 | self.transform = transform 186 | self.images = df['image'].values 187 | self.masks = df['mask'].values 188 | 189 | def __len__(self): 190 | return len(self.images) 191 | 192 | def __getitem__(self, idx): 193 | 194 | image = self.images[idx] 195 | mask = self.masks[idx] 196 | 197 | image = np.load(image) 198 | mask = np.load(mask) 199 | 200 | image = np.stack((image, image, image), axis=-1).astype(np.float32) 201 | 202 | if self.transform: 203 | transform = self.transform(image=image, mask=mask) 204 | image = transform['image'] 205 | mask = transform['mask'] 206 | 207 | # image = np.stack((image, image, image), axis=0).astype(np.float32) 208 | image = np.transpose(image, (2, 0, 1)) 209 | # mask = np.transpose(mask, (2, 0, 1)) 210 | 211 | # image = np.expand_dims(image, axis=0) 212 | mask = mask.astype(np.int) 213 | 214 | return { 215 | 'images': image, 216 | 'targets': mask 217 | } 218 | 219 | 220 | def cut_edge(data, keep_margin): 221 | ''' 222 | function that cuts zero edge 223 | ''' 224 | D, H, W = data.shape 225 | D_s, D_e = 0, D - 1 226 | H_s, H_e = 0, H - 1 227 | W_s, W_e = 0, W - 1 228 | 229 | while D_s < D: 230 | if data[D_s].sum() != 0: 231 | break 232 | D_s += 1 233 | while D_e > D_s: 234 | if data[D_e].sum() != 0: 235 | break 236 | D_e -= 1 237 | while H_s < H: 238 | if data[:, H_s].sum() != 0: 239 | break 240 | H_s += 1 241 | while H_e > H_s: 242 | if data[:, H_e].sum() != 0: 243 | break 244 | H_e -= 1 245 | while W_s < W: 246 | if data[:, :, W_s].sum() != 0: 247 | break 248 | W_s += 1 249 | while W_e > W_s: 250 | if data[:, :, W_e].sum() != 0: 251 | break 252 | W_e -= 1 253 | 254 | if keep_margin != 0: 255 | D_s = max(0, D_s - keep_margin) 256 | D_e = min(D - 1, D_e + keep_margin) 257 | H_s = max(0, H_s - keep_margin) 258 | H_e = min(H - 1, H_e + keep_margin) 259 | W_s = max(0, W_s - keep_margin) 260 | W_e = min(W - 1, W_e + keep_margin) 261 | 262 | return int(D_s), int(D_e), int(H_s), int(H_e), int(W_s), int(W_e) 263 | 264 | 265 | 266 | import random 267 | class StructSegTrain3D(Dataset): 268 | 269 | def __init__(self, 270 | csv_file, 271 | transform, 272 | mode='train' 273 | ): 274 | df = pd.read_csv(csv_file) 275 | self.transform = transform 276 | self.patients = df['patient_id'].unique() 277 | self.root = "/data/Thoracic_OAR/" 278 | self.crop_size = (16, 256, 256) 279 | self.mode = mode 280 | 281 | def __len__(self): 282 | return len(self.patients) 283 | 284 | def __getitem__(self, idx): 285 | 286 | patient_id = self.patients[idx] 287 | data_path = os.path.join(self.root, str(patient_id), 'data.nii.gz') 288 | label_path = os.path.join(self.root, str(patient_id), 'label.nii.gz') 289 | 290 | image, mask = load_patient(data_path, label_path) 291 | # image[image < 0.5] = 0 292 | # image[image >= 0.5] = 1 293 | # # image2 = image > 0 294 | # margin = 32 295 | # min_D_s, max_D_e, min_H_s, max_H_e, min_W_s, max_W_e = cut_edge(image, margin) 296 | # 297 | # image = image[min_D_s:max_D_e + 1, min_H_s: max_H_e + 1, min_W_s:max_W_e + 1] 298 | # mask = mask[min_D_s:max_D_e + 1, min_H_s: max_H_e + 1, min_W_s:max_W_e + 1] 299 | 300 | # print(image.shape) 301 | 302 | D, H, W = 
image.shape 303 | 304 | if self.mode == 'train': 305 | rd = random.randint(0, D - self.crop_size[0]) 306 | rh = random.randint(0, H - self.crop_size[1]) 307 | rw = random.randint(0, W - self.crop_size[2]) 308 | 309 | # rd = (D - self.crop_size[0]) // 2 310 | # rh = (H - self.crop_size[1]) // 2 311 | # rw = (W - self.crop_size[2]) // 2 312 | else: 313 | rd = (D - self.crop_size[0]) // 2 314 | rh = (H - self.crop_size[1]) // 2 315 | rw = (W - self.crop_size[2]) // 2 316 | 317 | image = image[rd: rd + self.crop_size[0], rh: rh + self.crop_size[1], rw: rw + self.crop_size[2]] 318 | mask = mask[rd: rd + self.crop_size[0], rh: rh + self.crop_size[1], rw: rw + self.crop_size[2]] 319 | 320 | image = np.expand_dims(image, axis=0).astype(np.float32) 321 | mask = mask.astype(np.int) 322 | # mask = np.expand_dims(mask, axis=0).astype(np.int) 323 | 324 | return { 325 | 'images': image, 326 | 'targets': mask 327 | } -------------------------------------------------------------------------------- /src/models/ModelDeepLab/resnext.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | 5 | from mmdet.ops import DeformConv, ModulatedDeformConv 6 | from .resnet import Bottleneck as _Bottleneck 7 | from .resnet import ResNet 8 | from mmdet.models.utils import build_conv_layer, build_norm_layer 9 | 10 | 11 | class Bottleneck(_Bottleneck): 12 | 13 | def __init__(self, inplanes, planes, groups=1, base_width=4, **kwargs): 14 | """Bottleneck block for ResNeXt. 15 | If style is "pytorch", the stride-two layer is the 3x3 conv layer, 16 | if it is "caffe", the stride-two layer is the first 1x1 conv layer. 17 | """ 18 | super(Bottleneck, self).__init__(inplanes, planes, **kwargs) 19 | 20 | if groups == 1: 21 | width = self.planes 22 | else: 23 | width = math.floor(self.planes * (base_width / 64)) * groups 24 | 25 | self.norm1_name, norm1 = build_norm_layer( 26 | self.norm_cfg, width, postfix=1) 27 | self.norm2_name, norm2 = build_norm_layer( 28 | self.norm_cfg, width, postfix=2) 29 | self.norm3_name, norm3 = build_norm_layer( 30 | self.norm_cfg, self.planes * self.expansion, postfix=3) 31 | 32 | self.conv1 = build_conv_layer( 33 | self.conv_cfg, 34 | self.inplanes, 35 | width, 36 | kernel_size=1, 37 | stride=self.conv1_stride, 38 | bias=False) 39 | self.add_module(self.norm1_name, norm1) 40 | fallback_on_stride = False 41 | self.with_modulated_dcn = False 42 | if self.with_dcn: 43 | fallback_on_stride = self.dcn.get('fallback_on_stride', False) 44 | self.with_modulated_dcn = self.dcn.get('modulated', False) 45 | if not self.with_dcn or fallback_on_stride: 46 | self.conv2 = build_conv_layer( 47 | self.conv_cfg, 48 | width, 49 | width, 50 | kernel_size=3, 51 | stride=self.conv2_stride, 52 | padding=self.dilation, 53 | dilation=self.dilation, 54 | groups=groups, 55 | bias=False) 56 | else: 57 | assert self.conv_cfg is None, 'conv_cfg must be None for DCN' 58 | groups = self.dcn.get('groups', 1) 59 | deformable_groups = self.dcn.get('deformable_groups', 1) 60 | if not self.with_modulated_dcn: 61 | conv_op = DeformConv 62 | offset_channels = 18 63 | else: 64 | conv_op = ModulatedDeformConv 65 | offset_channels = 27 66 | self.conv2_offset = nn.Conv2d( 67 | width, 68 | deformable_groups * offset_channels, 69 | kernel_size=3, 70 | stride=self.conv2_stride, 71 | padding=self.dilation, 72 | dilation=self.dilation) 73 | self.conv2 = conv_op( 74 | width, 75 | width, 76 | kernel_size=3, 77 | stride=self.conv2_stride, 78 | padding=self.dilation, 79 | 
dilation=self.dilation, 80 | groups=groups, 81 | deformable_groups=deformable_groups, 82 | bias=False) 83 | self.add_module(self.norm2_name, norm2) 84 | self.conv3 = build_conv_layer( 85 | self.conv_cfg, 86 | width, 87 | self.planes * self.expansion, 88 | kernel_size=1, 89 | bias=False) 90 | self.add_module(self.norm3_name, norm3) 91 | 92 | 93 | def make_multigrid(block, 94 | inplanes, 95 | planes, 96 | blocks, 97 | stride=1, 98 | dilation=1, 99 | groups=1, 100 | base_width=4, 101 | style='pytorch', 102 | with_cp=False, 103 | conv_cfg=None, 104 | norm_cfg=dict(type='BN'), 105 | dcn=None, 106 | gcb=None): 107 | downsample = None 108 | if stride != 1 or inplanes != planes * block.expansion: 109 | downsample = nn.Sequential( 110 | build_conv_layer( 111 | conv_cfg, 112 | inplanes, 113 | planes * block.expansion, 114 | kernel_size=1, 115 | stride=stride, 116 | bias=False), 117 | build_norm_layer(norm_cfg, planes * block.expansion)[1], 118 | ) 119 | 120 | layers = [] 121 | layers.append( 122 | block( 123 | inplanes=inplanes, 124 | planes=planes, 125 | stride=stride, 126 | dilation=blocks[0]*dilation, 127 | downsample=downsample, 128 | groups=groups, 129 | base_width=base_width, 130 | style=style, 131 | with_cp=with_cp, 132 | conv_cfg=conv_cfg, 133 | norm_cfg=norm_cfg, 134 | dcn=dcn, 135 | gcb=gcb)) 136 | inplanes = planes * block.expansion 137 | for i in range(1, len(blocks)): 138 | layers.append( 139 | block( 140 | inplanes=inplanes, 141 | planes=planes, 142 | stride=1, 143 | dilation=blocks[i]*dilation, 144 | groups=groups, 145 | base_width=base_width, 146 | style=style, 147 | with_cp=with_cp, 148 | conv_cfg=conv_cfg, 149 | norm_cfg=norm_cfg, 150 | dcn=dcn, 151 | gcb=gcb)) 152 | 153 | return nn.Sequential(*layers) 154 | 155 | def make_res_layer(block, 156 | inplanes, 157 | planes, 158 | blocks, 159 | stride=1, 160 | dilation=1, 161 | groups=1, 162 | base_width=4, 163 | style='pytorch', 164 | with_cp=False, 165 | conv_cfg=None, 166 | norm_cfg=dict(type='BN'), 167 | dcn=None, 168 | gcb=None): 169 | downsample = None 170 | if stride != 1 or inplanes != planes * block.expansion: 171 | downsample = nn.Sequential( 172 | build_conv_layer( 173 | conv_cfg, 174 | inplanes, 175 | planes * block.expansion, 176 | kernel_size=1, 177 | stride=stride, 178 | bias=False), 179 | build_norm_layer(norm_cfg, planes * block.expansion)[1], 180 | ) 181 | 182 | layers = [] 183 | layers.append( 184 | block( 185 | inplanes=inplanes, 186 | planes=planes, 187 | stride=stride, 188 | dilation=dilation, 189 | downsample=downsample, 190 | groups=groups, 191 | base_width=base_width, 192 | style=style, 193 | with_cp=with_cp, 194 | conv_cfg=conv_cfg, 195 | norm_cfg=norm_cfg, 196 | dcn=dcn, 197 | gcb=gcb)) 198 | inplanes = planes * block.expansion 199 | for i in range(1, blocks): 200 | layers.append( 201 | block( 202 | inplanes=inplanes, 203 | planes=planes, 204 | stride=1, 205 | dilation=dilation, 206 | groups=groups, 207 | base_width=base_width, 208 | style=style, 209 | with_cp=with_cp, 210 | conv_cfg=conv_cfg, 211 | norm_cfg=norm_cfg, 212 | dcn=dcn, 213 | gcb=gcb)) 214 | 215 | return nn.Sequential(*layers) 216 | 217 | class ResNeXt(ResNet): 218 | """ResNeXt backbone. 219 | 220 | Args: 221 | depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. 222 | num_stages (int): Resnet stages, normally 4. 223 | groups (int): Group of resnext. 224 | base_width (int): Base width of resnext. 225 | strides (Sequence[int]): Strides of the first block of each stage. 226 | dilations (Sequence[int]): Dilation of each stage. 
227 | out_indices (Sequence[int]): Output from which stages. 228 | style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two 229 | layer is the 3x3 conv layer, otherwise the stride-two layer is 230 | the first 1x1 conv layer. 231 | frozen_stages (int): Stages to be frozen (all param fixed). -1 means 232 | not freezing any parameters. 233 | norm_cfg (dict): dictionary to construct and config norm layer. 234 | norm_eval (bool): Whether to set norm layers to eval mode, namely, 235 | freeze running stats (mean and var). Note: Effect on Batch Norm 236 | and its variants only. 237 | with_cp (bool): Use checkpoint or not. Using checkpoint will save some 238 | memory while slowing down the training speed. 239 | zero_init_residual (bool): whether to use zero init for last norm layer 240 | in resblocks to let them behave as identity. 241 | """ 242 | 243 | arch_settings = { 244 | 50: (Bottleneck, (3, 4, 6, 3)), 245 | 101: (Bottleneck, (3, 4, 23, 3)), 246 | 152: (Bottleneck, (3, 8, 36, 3)) 247 | } 248 | 249 | def __init__(self, groups=1, base_width=4, **kwargs): 250 | super(ResNeXt, self).__init__(**kwargs) 251 | self.groups = groups 252 | self.base_width = base_width 253 | 254 | self.inplanes = 64 255 | self.res_layers = [] 256 | for i, num_blocks in enumerate(self.stage_blocks): 257 | stride = self.strides[i] 258 | dilation = self.dilations[i] 259 | dcn = self.dcn if self.stage_with_dcn[i] else None 260 | gcb = self.gcb if self.stage_with_gcb[i] else None 261 | planes = 64 * 2**i 262 | # Allow for DCN + multigrid final block 263 | if i == len(self.stage_blocks) - 1: #and not True in self.stage_with_dcn: 264 | assert len(self.mg_rates) == num_blocks 265 | res_layer = make_multigrid( 266 | self.block, 267 | self.inplanes, 268 | planes, 269 | self.mg_rates, 270 | stride=stride, 271 | dilation=dilation, 272 | groups=self.groups, 273 | base_width=self.base_width, 274 | style=self.style, 275 | with_cp=self.with_cp, 276 | conv_cfg=self.conv_cfg, 277 | norm_cfg=self.norm_cfg, 278 | dcn=dcn, 279 | gcb=gcb) 280 | else: 281 | res_layer = make_res_layer( 282 | self.block, 283 | self.inplanes, 284 | planes, 285 | num_blocks, 286 | stride=stride, 287 | dilation=dilation, 288 | groups=self.groups, 289 | base_width=self.base_width, 290 | style=self.style, 291 | with_cp=self.with_cp, 292 | conv_cfg=self.conv_cfg, 293 | norm_cfg=self.norm_cfg, 294 | dcn=dcn, 295 | gcb=gcb) 296 | self.inplanes = planes * self.block.expansion 297 | layer_name = 'layer{}'.format(i + 1) 298 | self.add_module(layer_name, res_layer) 299 | self.res_layers.append(layer_name) 300 | 301 | self._freeze_stages() 302 | 303 | -------------------------------------------------------------------------------- /src/models/unet3d/buildingblocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | def conv3d(in_channels, out_channels, kernel_size, bias, padding=1): 7 | return nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias) 8 | 9 | 10 | def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=1): 11 | """ 12 | Create a list of modules with together constitute a single conv layer with non-linearity 13 | and optional batchnorm/groupnorm. 14 | Args: 15 | in_channels (int): number of input channels 16 | out_channels (int): number of output channels 17 | order (string): order of things, e.g. 
18 | 'cr' -> conv + ReLU 19 | 'crg' -> conv + ReLU + groupnorm 20 | 'cl' -> conv + LeakyReLU 21 | 'ce' -> conv + ELU 22 | num_groups (int): number of groups for the GroupNorm 23 | padding (int): add zero-padding to the input 24 | Return: 25 | list of tuple (name, module) 26 | """ 27 | assert 'c' in order, "Conv layer MUST be present" 28 | assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer' 29 | 30 | modules = [] 31 | for i, char in enumerate(order): 32 | if char == 'r': 33 | modules.append(('ReLU', nn.ReLU(inplace=True))) 34 | elif char == 'l': 35 | modules.append(('LeakyReLU', nn.LeakyReLU(negative_slope=0.1, inplace=True))) 36 | elif char == 'e': 37 | modules.append(('ELU', nn.ELU(inplace=True))) 38 | elif char == 'c': 39 | # add learnable bias only in the absence of gatchnorm/groupnorm 40 | bias = not ('g' in order or 'b' in order) 41 | modules.append(('conv', conv3d(in_channels, out_channels, kernel_size, bias, padding=padding))) 42 | elif char == 'g': 43 | is_before_conv = i < order.index('c') 44 | assert not is_before_conv, 'GroupNorm MUST go after the Conv3d' 45 | # number of groups must be less or equal the number of channels 46 | if out_channels < num_groups: 47 | num_groups = out_channels 48 | modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=out_channels))) 49 | elif char == 'b': 50 | is_before_conv = i < order.index('c') 51 | if is_before_conv: 52 | modules.append(('batchnorm', nn.BatchNorm3d(in_channels))) 53 | else: 54 | modules.append(('batchnorm', nn.BatchNorm3d(out_channels))) 55 | else: 56 | raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']") 57 | 58 | return modules 59 | 60 | 61 | class SingleConv(nn.Sequential): 62 | """ 63 | Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order 64 | of operations can be specified via the `order` parameter 65 | Args: 66 | in_channels (int): number of input channels 67 | out_channels (int): number of output channels 68 | kernel_size (int): size of the convolving kernel 69 | order (string): determines the order of layers, e.g. 70 | 'cr' -> conv + ReLU 71 | 'crg' -> conv + ReLU + groupnorm 72 | 'cl' -> conv + LeakyReLU 73 | 'ce' -> conv + ELU 74 | num_groups (int): number of groups for the GroupNorm 75 | """ 76 | 77 | def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8, padding=1): 78 | super(SingleConv, self).__init__() 79 | 80 | for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding): 81 | self.add_module(name, module) 82 | 83 | 84 | class DoubleConv(nn.Sequential): 85 | """ 86 | A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d). 87 | We use (Conv3d+ReLU+GroupNorm3d) by default. 88 | This can be changed however by providing the 'order' argument, e.g. in order 89 | to change to Conv3d+BatchNorm3d+ELU use order='cbe'. 90 | Use padded convolutions to make sure that the output (H_out, W_out) is the same 91 | as (H_in, W_in), so that you don't have to crop in the decoder path. 92 | Args: 93 | in_channels (int): number of input channels 94 | out_channels (int): number of output channels 95 | encoder (bool): if True we're in the encoder path, otherwise we're in the decoder 96 | kernel_size (int): size of the convolving kernel 97 | order (string): determines the order of layers, e.g. 
98 | 'cr' -> conv + ReLU 99 | 'crg' -> conv + ReLU + groupnorm 100 | 'cl' -> conv + LeakyReLU 101 | 'ce' -> conv + ELU 102 | num_groups (int): number of groups for the GroupNorm 103 | """ 104 | 105 | def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='crg', num_groups=8): 106 | super(DoubleConv, self).__init__() 107 | if encoder: 108 | # we're in the encoder path 109 | conv1_in_channels = in_channels 110 | conv1_out_channels = out_channels // 2 111 | if conv1_out_channels < in_channels: 112 | conv1_out_channels = in_channels 113 | conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels 114 | else: 115 | # we're in the decoder path, decrease the number of channels in the 1st convolution 116 | conv1_in_channels, conv1_out_channels = in_channels, out_channels 117 | conv2_in_channels, conv2_out_channels = out_channels, out_channels 118 | 119 | # conv1 120 | self.add_module('SingleConv1', 121 | SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups)) 122 | # conv2 123 | self.add_module('SingleConv2', 124 | SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups)) 125 | 126 | 127 | class ExtResNetBlock(nn.Module): 128 | """ 129 | Basic UNet block consisting of a SingleConv followed by the residual block. 130 | The SingleConv takes care of increasing/decreasing the number of channels and also ensures that the number 131 | of output channels is compatible with the residual block that follows. 132 | This block can be used instead of standard DoubleConv in the Encoder module. 133 | Motivated by: https://arxiv.org/pdf/1706.00120.pdf 134 | Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm. 135 | """ 136 | 137 | def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, **kwargs): 138 | super(ExtResNetBlock, self).__init__() 139 | 140 | # first convolution 141 | self.conv1 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups) 142 | # residual block 143 | self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups) 144 | # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual 145 | n_order = order 146 | for c in 'rel': 147 | n_order = n_order.replace(c, '') 148 | self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order, 149 | num_groups=num_groups) 150 | 151 | # create non-linearity separately 152 | if 'l' in order: 153 | self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True) 154 | elif 'e' in order: 155 | self.non_linearity = nn.ELU(inplace=True) 156 | else: 157 | self.non_linearity = nn.ReLU(inplace=True) 158 | 159 | def forward(self, x): 160 | # apply first convolution and save the output as a residual 161 | out = self.conv1(x) 162 | residual = out 163 | 164 | # residual block 165 | out = self.conv2(out) 166 | out = self.conv3(out) 167 | 168 | out += residual 169 | out = self.non_linearity(out) 170 | 171 | return out 172 | 173 | 174 | class Encoder(nn.Module): 175 | """ 176 | A single module from the encoder path consisting of the optional max 177 | pooling layer (one may specify the MaxPool kernel_size to be different 178 | than the standard (2,2,2), e.g. if the volumetric data is anisotropic 179 | (make sure to use complementary scale_factor in the decoder path) followed by 180 | a DoubleConv module. 
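For example, with anisotropic CT crops such as the 16x256x256 patches used by StructSegTrain3D in this repository, one might pass pool_kernel_size=(1, 2, 2) here together with the complementary scale_factor=(1, 2, 2) in the matching Decoder (illustrative values, not defaults of this code).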
181 | Args: 182 | in_channels (int): number of input channels 183 | out_channels (int): number of output channels 184 | conv_kernel_size (int): size of the convolving kernel 185 | apply_pooling (bool): if True use MaxPool3d before DoubleConv 186 | pool_kernel_size (tuple): the size of the window to take a max over 187 | pool_type (str): pooling layer: 'max' or 'avg' 188 | basic_module(nn.Module): either ResNetBlock or DoubleConv 189 | conv_layer_order (string): determines the order of layers 190 | in `DoubleConv` module. See `DoubleConv` for more info. 191 | num_groups (int): number of groups for the GroupNorm 192 | """ 193 | 194 | def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True, 195 | pool_kernel_size=(2, 2, 2), pool_type='max', basic_module=DoubleConv, conv_layer_order='crg', 196 | num_groups=8): 197 | super(Encoder, self).__init__() 198 | assert pool_type in ['max', 'avg'] 199 | if apply_pooling: 200 | if pool_type == 'max': 201 | self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size) 202 | else: 203 | self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size) 204 | else: 205 | self.pooling = None 206 | 207 | self.basic_module = basic_module(in_channels, out_channels, 208 | encoder=True, 209 | kernel_size=conv_kernel_size, 210 | order=conv_layer_order, 211 | num_groups=num_groups) 212 | 213 | def forward(self, x): 214 | if self.pooling is not None: 215 | x = self.pooling(x) 216 | x = self.basic_module(x) 217 | return x 218 | 219 | 220 | class Decoder(nn.Module): 221 | """ 222 | A single module for decoder path consisting of the upsample layer 223 | (either learned ConvTranspose3d or interpolation) followed by a DoubleConv 224 | module. 225 | Args: 226 | in_channels (int): number of input channels 227 | out_channels (int): number of output channels 228 | kernel_size (int): size of the convolving kernel 229 | scale_factor (tuple): used as the multiplier for the image H/W/D in 230 | case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation 231 | from the corresponding encoder 232 | basic_module(nn.Module): either ResNetBlock or DoubleConv 233 | conv_layer_order (string): determines the order of layers 234 | in `DoubleConv` module. See `DoubleConv` for more info. 
235 | num_groups (int): number of groups for the GroupNorm 236 | """ 237 | 238 | def __init__(self, in_channels, out_channels, kernel_size=3, 239 | scale_factor=(2, 2, 2), basic_module=DoubleConv, conv_layer_order='crg', num_groups=8): 240 | super(Decoder, self).__init__() 241 | if basic_module == DoubleConv: 242 | # if DoubleConv is the basic_module use nearest neighbor interpolation for upsampling 243 | self.upsample = None 244 | else: 245 | # otherwise use ConvTranspose3d (bear in mind your GPU memory) 246 | # make sure that the output size reverses the MaxPool3d from the corresponding encoder 247 | # (D_out = (D_in − 1) ×  stride[0] − 2 ×  padding[0] +  kernel_size[0] +  output_padding[0]) 248 | # also scale the number of channels from in_channels to out_channels so that summation joining 249 | # works correctly 250 | self.upsample = nn.ConvTranspose3d(in_channels, 251 | out_channels, 252 | kernel_size=kernel_size, 253 | stride=scale_factor, 254 | padding=1, 255 | output_padding=1) 256 | # adapt the number of in_channels for the ExtResNetBlock 257 | in_channels = out_channels 258 | 259 | self.basic_module = basic_module(in_channels, out_channels, 260 | encoder=False, 261 | kernel_size=kernel_size, 262 | order=conv_layer_order, 263 | num_groups=num_groups) 264 | 265 | def forward(self, encoder_features, x): 266 | if self.upsample is None: 267 | # use nearest neighbor interpolation and concatenation joining 268 | output_size = encoder_features.size()[2:] 269 | x = F.interpolate(x, size=output_size, mode='nearest') 270 | # concatenate encoder_features (encoder path) with the upsampled input across channel dimension 271 | x = torch.cat((encoder_features, x), dim=1) 272 | else: 273 | # use ConvTranspose3d and summation joining 274 | x = self.upsample(x) 275 | x += encoder_features 276 | 277 | x = self.basic_module(x) 278 | return x 279 | 280 | 281 | class FinalConv(nn.Sequential): 282 | """ 283 | A module consisting of a convolution layer (e.g. Conv3d+ReLU+GroupNorm3d) and the final 1x1 convolution 284 | which reduces the number of channels to 'out_channels'. 285 | with the number of output channels 'out_channels // 2' and 'out_channels' respectively. 286 | We use (Conv3d+ReLU+GroupNorm3d) by default. 287 | This can be change however by providing the 'order' argument, e.g. in order 288 | to change to Conv3d+BatchNorm3d+ReLU use order='cbr'. 289 | Args: 290 | in_channels (int): number of input channels 291 | out_channels (int): number of output channels 292 | kernel_size (int): size of the convolving kernel 293 | order (string): determines the order of layers, e.g. 
294 | 'cr' -> conv + ReLU 295 | 'crg' -> conv + ReLU + groupnorm 296 | num_groups (int): number of groups for the GroupNorm 297 | """ 298 | 299 | def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8): 300 | super(FinalConv, self).__init__() 301 | 302 | # conv1 303 | self.add_module('SingleConv', SingleConv(in_channels, in_channels, kernel_size, order, num_groups)) 304 | 305 | # in the last layer a 1×1 convolution reduces the number of output channels to out_channels 306 | final_conv = nn.Conv3d(in_channels, out_channels, 1) 307 | self.add_module('final_conv', final_conv) 308 | -------------------------------------------------------------------------------- /src/models/ModelDeepLab/deeplab_jpu.py: -------------------------------------------------------------------------------- 1 | # From https://github.com/jfzhang95/pytorch-deeplab-xception 2 | 3 | import models.ModelDeepLab.backbones_deeplab_jpu as backbones 4 | try: 5 | from model.encoding import * 6 | except: 7 | pass 8 | 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | import torch 13 | 14 | # Wrapper for GroupNorm with 32 channels 15 | class GroupNorm32(nn.GroupNorm): 16 | def __init__(self, num_channels): 17 | super(GroupNorm32, self).__init__(num_channels=num_channels, num_groups=32) 18 | 19 | ######## 20 | # ASPP # 21 | ######## 22 | 23 | class _ASPPModule(nn.Module): 24 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, norm_layer): 25 | super(_ASPPModule, self).__init__() 26 | self.norm = norm_layer(planes) 27 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 28 | stride=1, padding=padding, dilation=dilation, bias=False) 29 | self.elu = nn.ELU(True) 30 | 31 | def forward(self, x): 32 | x = self.atrous_conv(x) 33 | x = self.norm(x) 34 | return self.elu(x) 35 | 36 | class ASPP(nn.Module): 37 | def __init__(self, dilations, inplanes, planes, norm_layer, dropout=0.5): 38 | super(ASPP, self).__init__() 39 | 40 | self.aspp1 = _ASPPModule(inplanes, planes, 1, padding=0, dilation=dilations[0], norm_layer=norm_layer) 41 | self.aspp2 = _ASPPModule(inplanes, planes, 3, padding=dilations[1], dilation=dilations[1], norm_layer=norm_layer) 42 | self.aspp3 = _ASPPModule(inplanes, planes, 3, padding=dilations[2], dilation=dilations[2], norm_layer=norm_layer) 43 | self.aspp4 = _ASPPModule(inplanes, planes, 3, padding=dilations[3], dilation=dilations[3], norm_layer=norm_layer) 44 | 45 | self.norm1 = norm_layer(planes) 46 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 47 | nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False), 48 | norm_layer(planes), 49 | nn.ELU(True)) 50 | self.conv1 = nn.Conv2d(5 * planes, planes, 1, bias=False) 51 | self.elu = nn.ELU(True) 52 | self.dropout = nn.Dropout2d(dropout) 53 | 54 | def forward(self, x): 55 | x1 = self.aspp1(x) 56 | x2 = self.aspp2(x) 57 | x3 = self.aspp3(x) 58 | x4 = self.aspp4(x) 59 | x5 = self.global_avg_pool(x) 60 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear') 61 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 62 | 63 | x = self.conv1(x) 64 | x = self.norm1(x) 65 | x = self.elu(x) 66 | 67 | return self.dropout(x) 68 | 69 | ####### 70 | # FPA # 71 | ####### 72 | 73 | # From phalanx 74 | class FPAv2(nn.Module): 75 | def __init__(self, input_dim, output_dim, norm_layer): 76 | super(FPAv2, self).__init__() 77 | self.glob = nn.Sequential(nn.AdaptiveAvgPool2d(1), 78 | nn.Conv2d(input_dim, output_dim, kernel_size=1, bias=False)) 79 | 80 | self.down2_1 = 
nn.Sequential(nn.Conv2d(input_dim, input_dim, kernel_size=5, stride=2, padding=2, bias=False), 81 | norm_layer(input_dim)) 82 | self.down2_2 = nn.Sequential(nn.Conv2d(input_dim, output_dim, kernel_size=5, padding=2, bias=False), 83 | norm_layer(output_dim)) 84 | 85 | self.down3_1 = nn.Sequential(nn.Conv2d(input_dim, input_dim, kernel_size=3, stride=2, padding=1, bias=False), 86 | norm_layer(input_dim)) 87 | self.down3_2 = nn.Sequential(nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1, bias=False), 88 | norm_layer(output_dim)) 89 | 90 | self.conv1 = nn.Sequential(nn.Conv2d(input_dim, output_dim, kernel_size=1, bias=False), 91 | norm_layer(output_dim)) 92 | 93 | def forward(self, x): 94 | x_glob = self.glob(x) 95 | x_glob = F.interpolate(x_glob, scale_factor=int(x.size()[-1] / x_glob.size()[-1]), mode='bilinear') # 256, 16, 16 96 | 97 | d2 = F.elu(self.down2_1(x)) 98 | d3 = F.elu(self.down3_1(d2)) 99 | d2 = F.elu(self.down2_2(d2)) 100 | d3 = F.elu(self.down3_2(d3)) 101 | d3 = F.interpolate(d3, scale_factor=2, mode='bilinear') # 256, 8, 8 102 | d2 = d2 + d3 103 | d2 = F.interpolate(d2, scale_factor=2, mode='bilinear') # 256, 16, 16 104 | 105 | x = F.elu(self.conv1(x)) 106 | x = x * d2 107 | x = x + x_glob 108 | 109 | return x 110 | 111 | ####### 112 | # JPU # 113 | ####### 114 | 115 | # From https://github.com/wuhuikai/FastFCN/ 116 | class SeparableConv2d(nn.Module): 117 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=1, dilation=1, bias=False, BatchNorm=nn.BatchNorm2d): 118 | super(SeparableConv2d, self).__init__() 119 | 120 | self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation, groups=inplanes, bias=bias) 121 | self.bn = BatchNorm(inplanes) 122 | self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) 123 | 124 | def forward(self, x): 125 | x = self.conv1(x) 126 | x = self.bn(x) 127 | x = self.pointwise(x) 128 | return x 129 | 130 | class JPU16(nn.Module): 131 | def __init__(self, in_channels, width=512, norm_layer=None): 132 | super(JPU16, self).__init__() 133 | 134 | self.conv5 = nn.Sequential( 135 | nn.Conv2d(in_channels[-1], width, 3, padding=1, bias=False), 136 | norm_layer(width), 137 | nn.ReLU(inplace=True)) 138 | self.conv4 = nn.Sequential( 139 | nn.Conv2d(in_channels[-2], width, 3, padding=1, bias=False), 140 | norm_layer(width), 141 | nn.ReLU(inplace=True)) 142 | 143 | self.dilation1 = nn.Sequential(SeparableConv2d(2*width, width, kernel_size=3, padding=1, dilation=1, bias=False), 144 | norm_layer(width), 145 | nn.ReLU(inplace=True)) 146 | self.dilation2 = nn.Sequential(SeparableConv2d(2*width, width, kernel_size=3, padding=2, dilation=2, bias=False), 147 | norm_layer(width), 148 | nn.ReLU(inplace=True)) 149 | self.dilation3 = nn.Sequential(SeparableConv2d(2*width, width, kernel_size=3, padding=4, dilation=4, bias=False), 150 | norm_layer(width), 151 | nn.ReLU(inplace=True)) 152 | self.dilation4 = nn.Sequential(SeparableConv2d(2*width, width, kernel_size=3, padding=8, dilation=8, bias=False), 153 | norm_layer(width), 154 | nn.ReLU(inplace=True)) 155 | 156 | def forward(self, *inputs): 157 | feats = [self.conv5(inputs[-1]), self.conv4(inputs[-2])] 158 | _, _, h, w = feats[-1].size() 159 | feats[-2] = F.interpolate(feats[-2], size=(h, w), mode='bilinear') 160 | feat = torch.cat(feats, dim=1) 161 | feat = torch.cat([self.dilation1(feat), self.dilation2(feat), self.dilation3(feat), self.dilation4(feat)], dim=1) 162 | 163 | return feat 164 | 165 | class JPU08(nn.Module): 166 | def __init__(self, in_channels, 
width=512, norm_layer=None): 167 | super(JPU08, self).__init__() 168 | 169 | self.conv5 = nn.Sequential( 170 | nn.Conv2d(in_channels[-1], width, 3, padding=1, bias=False), 171 | norm_layer(width), 172 | nn.ReLU(inplace=True)) 173 | self.conv4 = nn.Sequential( 174 | nn.Conv2d(in_channels[-2], width, 3, padding=1, bias=False), 175 | norm_layer(width), 176 | nn.ReLU(inplace=True)) 177 | self.conv3 = nn.Sequential( 178 | nn.Conv2d(in_channels[-3], width, 3, padding=1, bias=False), 179 | norm_layer(width), 180 | nn.ReLU(inplace=True)) 181 | 182 | self.dilation1 = nn.Sequential(SeparableConv2d(3*width, width, kernel_size=3, padding=1, dilation=1, bias=False), 183 | norm_layer(width), 184 | nn.ReLU(inplace=True)) 185 | self.dilation2 = nn.Sequential(SeparableConv2d(3*width, width, kernel_size=3, padding=2, dilation=2, bias=False), 186 | norm_layer(width), 187 | nn.ReLU(inplace=True)) 188 | self.dilation3 = nn.Sequential(SeparableConv2d(3*width, width, kernel_size=3, padding=4, dilation=4, bias=False), 189 | norm_layer(width), 190 | nn.ReLU(inplace=True)) 191 | self.dilation4 = nn.Sequential(SeparableConv2d(3*width, width, kernel_size=3, padding=8, dilation=8, bias=False), 192 | norm_layer(width), 193 | nn.ReLU(inplace=True)) 194 | 195 | def forward(self, *inputs): 196 | feats = [self.conv5(inputs[-1]), self.conv4(inputs[-2]), self.conv3(inputs[-3])] 197 | _, _, h, w = feats[-1].size() 198 | feats[-2] = F.interpolate(feats[-2], size=(h, w), mode='bilinear') 199 | feats[-3] = F.interpolate(feats[-3], size=(h, w), mode='bilinear') 200 | feat = torch.cat(feats, dim=1) 201 | feat = torch.cat([self.dilation1(feat), self.dilation2(feat), self.dilation3(feat), self.dilation4(feat)], dim=1) 202 | 203 | return feat 204 | 205 | ####### 206 | # PSP # 207 | ####### 208 | # From https://github.com/Lextal/pspnet-pytorch 209 | 210 | class PSPModule(nn.Module): 211 | def __init__(self, features, out_features=1024, sizes=(1, 2, 3, 6)): 212 | super().__init__() 213 | self.stages = [] 214 | self.stages = nn.ModuleList([self._make_stage(features, size) for size in sizes]) 215 | self.bottleneck = nn.Conv2d(features * (len(sizes) + 1), out_features, kernel_size=1) 216 | self.relu = nn.ReLU() 217 | 218 | def _make_stage(self, features, size): 219 | prior = nn.AdaptiveAvgPool2d(output_size=(size, size)) 220 | conv = nn.Conv2d(features, features, kernel_size=1, bias=False) 221 | return nn.Sequential(prior, conv) 222 | 223 | def forward(self, feats): 224 | h, w = feats.size(2), feats.size(3) 225 | priors = [F.interpolate(input=stage(feats), size=(h, w), mode='bilinear') for stage in self.stages] + [feats] 226 | bottle = self.bottleneck(torch.cat(priors, 1)) 227 | return self.relu(bottle) 228 | 229 | # Decoder for ModelDeepLab 230 | class Decoder(nn.Module): 231 | def __init__(self, num_classes, spp_inplanes, low_level_inplanes, inplanes, dropout, norm_layer): 232 | super(Decoder, self).__init__() 233 | 234 | self.conv1 = nn.Conv2d(low_level_inplanes, inplanes, 1, bias=False) 235 | self.norm1 = norm_layer(inplanes) 236 | 237 | self.elu = nn.ELU(True) 238 | self.last_conv = nn.Sequential(nn.Conv2d(spp_inplanes + inplanes, spp_inplanes, kernel_size=3, stride=1, padding=1, bias=False), 239 | norm_layer(spp_inplanes), 240 | nn.ELU(True), 241 | nn.Dropout2d(dropout[0]), 242 | nn.Conv2d(spp_inplanes, spp_inplanes, kernel_size=3, stride=1, padding=1, bias=False), 243 | norm_layer(spp_inplanes), 244 | nn.ELU(True), 245 | nn.Dropout2d(dropout[1]), 246 | nn.Conv2d(spp_inplanes, num_classes, kernel_size=1, stride=1)) 247 | 248 | def 
forward(self, x, low_level_feat): 249 | low_level_feat = self.conv1(low_level_feat) 250 | low_level_feat = self.norm1(low_level_feat) 251 | low_level_feat = self.elu(low_level_feat) 252 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear') 253 | x = torch.cat((x, low_level_feat), dim=1) 254 | decoder_output = x 255 | x = self.last_conv(x) 256 | return x 257 | 258 | class DeepLab(nn.Module): 259 | def __init__(self, backbone, 260 | output_stride=16, 261 | group_norm=True, 262 | dropout=dict( 263 | spp=0.5, 264 | dc0=0.5, 265 | dc1=0.1 266 | ), 267 | num_classes=2, 268 | center='aspp', 269 | jpu=True, 270 | norm_eval=False, 271 | use_maxpool=True): 272 | super(DeepLab, self).__init__() 273 | 274 | layers, channels, backbone = getattr(backbones, backbone)(output_stride=output_stride, jpu=jpu, use_maxpool=use_maxpool) 275 | 276 | self.input_range = backbone.input_range 277 | self.mean = backbone.mean 278 | self.std = backbone.std 279 | 280 | # default is freeze BatchNorm 281 | self.norm_eval = norm_eval 282 | 283 | norm_layer = GroupNorm32 if group_norm else nn.BatchNorm2d 284 | 285 | if jpu: 286 | self.backbone1 = layers[0][0] 287 | self.backbone2 = layers[0][1] 288 | self.backbone3 = layers[0][2] 289 | else: 290 | self.backbone = layers[0] 291 | self.low_level = layers[1] 292 | self.center_type = center 293 | self.use_jpu = jpu 294 | 295 | self.aspp_planes = 256 296 | self.output_stride = output_stride 297 | 298 | if output_stride == 16: 299 | aspp_dilations = (1, 6, 12, 18) 300 | elif output_stride == 8: 301 | aspp_dilations = (1, 12, 24, 36) 302 | 303 | center_input_channels = channels[-1] 304 | 305 | if center == 'fpa': 306 | self.center = FPAv2(center_input_channels, self.aspp_planes, norm_layer=norm_layer) 307 | elif center == 'aspp': 308 | self.center = ASPP(aspp_dilations, inplanes=center_input_channels, planes=self.aspp_planes, dropout=dropout['spp'], norm_layer=norm_layer) 309 | elif center == 'psp': 310 | self.center = PSPModule(center_input_channels, out_features=self.aspp_planes) 311 | elif center == 'enc': 312 | self.center = EncModule(center_input_channels, self.aspp_planes, norm_layer) 313 | if jpu: 314 | if output_stride == 16: 315 | self.jpu = JPU16(channels[2:], norm_layer=norm_layer, width=center_input_channels // 4) 316 | elif output_stride == 8: 317 | self.jpu = JPU08(channels[1:], norm_layer=norm_layer, width=center_input_channels // 4) 318 | 319 | self.decoder = Decoder(num_classes, self.aspp_planes, channels[0], 64, (dropout['dc0'], dropout['dc1']), norm_layer) 320 | self.train_mode = True 321 | 322 | def forward(self, x_input): 323 | low_level_feat = self.low_level(x_input) 324 | if self.use_jpu: 325 | c2 = self.backbone1(low_level_feat) 326 | c3 = self.backbone2(c2) 327 | c4 = self.backbone3(c3) 328 | else: 329 | features = self.backbone(low_level_feat) 330 | 331 | if self.use_jpu: 332 | if self.output_stride == 16: 333 | x = self.center(self.jpu(c3, c4)) 334 | elif self.output_stride == 8: 335 | x = self.center(self.jpu(c2, c3, c4)) 336 | else: 337 | x = self.center(features) 338 | x = self.decoder(x, low_level_feat) 339 | out_size = x_input.size()[2:] 340 | x = F.interpolate(x, size=out_size, mode='bilinear') 341 | return x 342 | 343 | def train(self, mode=True): 344 | super(DeepLab, self).train(mode) 345 | if mode and self.norm_eval: 346 | for m in self.modules(): 347 | # trick: eval have effect on BatchNorm only 348 | if isinstance(m, nn.BatchNorm2d): 349 | m.eval() 350 | return self 351 | 352 | 353 | 
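# --- Illustrative usage sketch (editorial addition, not part of the original file) ---
# A minimal smoke test for the ASPP head defined above, assuming the shapes that the
# output_stride=16 branch of DeepLab would feed it (dilations (1, 6, 12, 18), 256 ASPP
# channels, 2048-channel backbone features); the tensor sizes below are illustrative only.
if __name__ == "__main__":
    import torch

    aspp = ASPP(dilations=(1, 6, 12, 18), inplanes=2048, planes=256,
                norm_layer=GroupNorm32, dropout=0.5)
    feats = torch.randn(2, 2048, 32, 32)   # stand-in for backbone features at 1/16 resolution
    out = aspp(feats)
    print(out.shape)                        # expected: torch.Size([2, 256, 32, 32])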
-------------------------------------------------------------------------------- /src/models/ModelDeepLab/resnet.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch.nn as nn 4 | import torch.utils.checkpoint as cp 5 | from torch.nn.modules.batchnorm import _BatchNorm 6 | 7 | from mmcv.cnn import constant_init, kaiming_init 8 | from mmcv.runner import load_checkpoint 9 | 10 | from mmdet.ops import DeformConv, ModulatedDeformConv, ContextBlock 11 | from mmdet.models.plugins import GeneralizedAttention 12 | 13 | from mmdet.models.utils import build_conv_layer, build_norm_layer 14 | 15 | class BasicBlock(nn.Module): 16 | expansion = 1 17 | 18 | def __init__(self, 19 | inplanes, 20 | planes, 21 | stride=1, 22 | dilation=1, 23 | downsample=None, 24 | style='pytorch', 25 | with_cp=False, 26 | conv_cfg=None, 27 | norm_cfg=dict(type='BN'), 28 | dcn=None, 29 | gcb=None, 30 | gen_attention=None): 31 | super(BasicBlock, self).__init__() 32 | assert dcn is None, "Not implemented yet." 33 | assert gen_attention is None, "Not implemented yet." 34 | assert gcb is None, "Not implemented yet." 35 | 36 | self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) 37 | self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) 38 | 39 | self.conv1 = build_conv_layer( 40 | conv_cfg, 41 | inplanes, 42 | planes, 43 | 3, 44 | stride=stride, 45 | padding=dilation, 46 | dilation=dilation, 47 | bias=False) 48 | self.add_module(self.norm1_name, norm1) 49 | self.conv2 = build_conv_layer( 50 | conv_cfg, planes, planes, 3, padding=1, bias=False) 51 | self.add_module(self.norm2_name, norm2) 52 | 53 | self.relu = nn.ReLU(inplace=True) 54 | self.downsample = downsample 55 | self.stride = stride 56 | self.dilation = dilation 57 | assert not with_cp 58 | 59 | @property 60 | def norm1(self): 61 | return getattr(self, self.norm1_name) 62 | 63 | @property 64 | def norm2(self): 65 | return getattr(self, self.norm2_name) 66 | 67 | def forward(self, x): 68 | identity = x 69 | 70 | out = self.conv1(x) 71 | out = self.norm1(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv2(out) 75 | out = self.norm2(out) 76 | 77 | if self.downsample is not None: 78 | identity = self.downsample(x) 79 | 80 | out += identity 81 | out = self.relu(out) 82 | 83 | return out 84 | 85 | 86 | class Bottleneck(nn.Module): 87 | expansion = 4 88 | 89 | def __init__(self, 90 | inplanes, 91 | planes, 92 | stride=1, 93 | dilation=1, 94 | downsample=None, 95 | style='pytorch', 96 | with_cp=False, 97 | conv_cfg=None, 98 | norm_cfg=dict(type='BN'), 99 | dcn=None, 100 | gcb=None, 101 | gen_attention=None): 102 | """Bottleneck block for ResNet. 103 | If style is "pytorch", the stride-two layer is the 3x3 conv layer, 104 | if it is "caffe", the stride-two layer is the first 1x1 conv layer. 
105 | """ 106 | super(Bottleneck, self).__init__() 107 | assert style in ['pytorch', 'caffe'] 108 | assert dcn is None or isinstance(dcn, dict) 109 | assert gcb is None or isinstance(gcb, dict) 110 | assert gen_attention is None or isinstance(gen_attention, dict) 111 | 112 | self.inplanes = inplanes 113 | self.planes = planes 114 | self.stride = stride 115 | self.dilation = dilation 116 | self.style = style 117 | self.with_cp = with_cp 118 | self.conv_cfg = conv_cfg 119 | self.norm_cfg = norm_cfg 120 | self.dcn = dcn 121 | self.with_dcn = dcn is not None 122 | self.gcb = gcb 123 | self.with_gcb = gcb is not None 124 | self.gen_attention = gen_attention 125 | self.with_gen_attention = gen_attention is not None 126 | 127 | if self.style == 'pytorch': 128 | self.conv1_stride = 1 129 | self.conv2_stride = stride 130 | else: 131 | self.conv1_stride = stride 132 | self.conv2_stride = 1 133 | 134 | self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) 135 | self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) 136 | self.norm3_name, norm3 = build_norm_layer( 137 | norm_cfg, planes * self.expansion, postfix=3) 138 | 139 | self.conv1 = build_conv_layer( 140 | conv_cfg, 141 | inplanes, 142 | planes, 143 | kernel_size=1, 144 | stride=self.conv1_stride, 145 | bias=False) 146 | self.add_module(self.norm1_name, norm1) 147 | fallback_on_stride = False 148 | self.with_modulated_dcn = False 149 | if self.with_dcn: 150 | fallback_on_stride = dcn.get('fallback_on_stride', False) 151 | self.with_modulated_dcn = dcn.get('modulated', False) 152 | if not self.with_dcn or fallback_on_stride: 153 | self.conv2 = build_conv_layer( 154 | conv_cfg, 155 | planes, 156 | planes, 157 | kernel_size=3, 158 | stride=self.conv2_stride, 159 | padding=dilation, 160 | dilation=dilation, 161 | bias=False) 162 | else: 163 | assert conv_cfg is None, 'conv_cfg must be None for DCN' 164 | deformable_groups = dcn.get('deformable_groups', 1) 165 | if not self.with_modulated_dcn: 166 | conv_op = DeformConv 167 | offset_channels = 18 168 | else: 169 | conv_op = ModulatedDeformConv 170 | offset_channels = 27 171 | self.conv2_offset = nn.Conv2d( 172 | planes, 173 | deformable_groups * offset_channels, 174 | kernel_size=3, 175 | stride=self.conv2_stride, 176 | padding=dilation, 177 | dilation=dilation) 178 | self.conv2 = conv_op( 179 | planes, 180 | planes, 181 | kernel_size=3, 182 | stride=self.conv2_stride, 183 | padding=dilation, 184 | dilation=dilation, 185 | deformable_groups=deformable_groups, 186 | bias=False) 187 | self.add_module(self.norm2_name, norm2) 188 | self.conv3 = build_conv_layer( 189 | conv_cfg, 190 | planes, 191 | planes * self.expansion, 192 | kernel_size=1, 193 | bias=False) 194 | self.add_module(self.norm3_name, norm3) 195 | 196 | self.relu = nn.ReLU(inplace=True) 197 | self.downsample = downsample 198 | 199 | if self.with_gcb: 200 | gcb_inplanes = planes * self.expansion 201 | self.context_block = ContextBlock(inplanes=gcb_inplanes, **gcb) 202 | 203 | # gen_attention 204 | if self.with_gen_attention: 205 | self.gen_attention_block = GeneralizedAttention( 206 | planes, **gen_attention) 207 | 208 | @property 209 | def norm1(self): 210 | return getattr(self, self.norm1_name) 211 | 212 | @property 213 | def norm2(self): 214 | return getattr(self, self.norm2_name) 215 | 216 | @property 217 | def norm3(self): 218 | return getattr(self, self.norm3_name) 219 | 220 | def forward(self, x): 221 | 222 | def _inner_forward(x): 223 | identity = x 224 | 225 | out = self.conv1(x) 226 | out = 
self.norm1(out) 227 | out = self.relu(out) 228 | 229 | if not self.with_dcn: 230 | out = self.conv2(out) 231 | elif self.with_modulated_dcn: 232 | offset_mask = self.conv2_offset(out) 233 | offset = offset_mask[:, :18, :, :] 234 | mask = offset_mask[:, -9:, :, :].sigmoid() 235 | out = self.conv2(out, offset, mask) 236 | else: 237 | offset = self.conv2_offset(out) 238 | out = self.conv2(out, offset) 239 | out = self.norm2(out) 240 | out = self.relu(out) 241 | 242 | if self.with_gen_attention: 243 | out = self.gen_attention_block(out) 244 | 245 | out = self.conv3(out) 246 | out = self.norm3(out) 247 | 248 | if self.with_gcb: 249 | out = self.context_block(out) 250 | 251 | if self.downsample is not None: 252 | identity = self.downsample(x) 253 | 254 | out += identity 255 | 256 | return out 257 | 258 | if self.with_cp and x.requires_grad: 259 | out = cp.checkpoint(_inner_forward, x) 260 | else: 261 | out = _inner_forward(x) 262 | 263 | out = self.relu(out) 264 | 265 | return out 266 | 267 | 268 | def make_multigrid(block, 269 | inplanes, 270 | planes, 271 | blocks, 272 | stride=1, 273 | dilation=1, 274 | style='pytorch', 275 | with_cp=False, 276 | conv_cfg=None, 277 | norm_cfg=dict(type='BN'), 278 | dcn=None, 279 | gcb=None, 280 | gen_attention=None, 281 | gen_attention_blocks=[]): 282 | downsample = None 283 | if stride != 1 or inplanes != planes * block.expansion: 284 | downsample = nn.Sequential( 285 | build_conv_layer( 286 | conv_cfg, 287 | inplanes, 288 | planes * block.expansion, 289 | kernel_size=1, 290 | stride=stride, 291 | bias=False), 292 | build_norm_layer(norm_cfg, planes * block.expansion)[1], 293 | ) 294 | 295 | layers = [] 296 | layers.append( 297 | block( 298 | inplanes=inplanes, 299 | planes=planes, 300 | stride=stride, 301 | dilation=blocks[0]*dilation, 302 | downsample=downsample, 303 | style=style, 304 | with_cp=with_cp, 305 | conv_cfg=conv_cfg, 306 | norm_cfg=norm_cfg, 307 | dcn=dcn, 308 | gcb=gcb, 309 | gen_attention=gen_attention if 310 | (0 in gen_attention_blocks) else None)) 311 | inplanes = planes * block.expansion 312 | for i in range(1, len(blocks)): 313 | layers.append( 314 | block( 315 | inplanes=inplanes, 316 | planes=planes, 317 | stride=1, 318 | dilation=blocks[i]*dilation, 319 | style=style, 320 | with_cp=with_cp, 321 | conv_cfg=conv_cfg, 322 | norm_cfg=norm_cfg, 323 | dcn=dcn, 324 | gcb=gcb, 325 | gen_attention=gen_attention if 326 | (i in gen_attention_blocks) else None)) 327 | 328 | return nn.Sequential(*layers) 329 | 330 | def make_res_layer(block, 331 | inplanes, 332 | planes, 333 | blocks, 334 | stride=1, 335 | dilation=1, 336 | style='pytorch', 337 | with_cp=False, 338 | conv_cfg=None, 339 | norm_cfg=dict(type='BN'), 340 | dcn=None, 341 | gcb=None, 342 | gen_attention=None, 343 | gen_attention_blocks=[]): 344 | downsample = None 345 | if stride != 1 or inplanes != planes * block.expansion: 346 | downsample = nn.Sequential( 347 | build_conv_layer( 348 | conv_cfg, 349 | inplanes, 350 | planes * block.expansion, 351 | kernel_size=1, 352 | stride=stride, 353 | bias=False), 354 | build_norm_layer(norm_cfg, planes * block.expansion)[1], 355 | ) 356 | 357 | layers = [] 358 | layers.append( 359 | block( 360 | inplanes=inplanes, 361 | planes=planes, 362 | stride=stride, 363 | dilation=dilation, 364 | downsample=downsample, 365 | style=style, 366 | with_cp=with_cp, 367 | conv_cfg=conv_cfg, 368 | norm_cfg=norm_cfg, 369 | dcn=dcn, 370 | gcb=gcb, 371 | gen_attention=gen_attention if 372 | (0 in gen_attention_blocks) else None)) 373 | inplanes = planes * 
block.expansion 374 | for i in range(1, blocks): 375 | layers.append( 376 | block( 377 | inplanes=inplanes, 378 | planes=planes, 379 | stride=1, 380 | dilation=dilation, 381 | style=style, 382 | with_cp=with_cp, 383 | conv_cfg=conv_cfg, 384 | norm_cfg=norm_cfg, 385 | dcn=dcn, 386 | gcb=gcb, 387 | gen_attention=gen_attention if 388 | (i in gen_attention_blocks) else None)) 389 | 390 | return nn.Sequential(*layers) 391 | 392 | class ResNet(nn.Module): 393 | """ResNet backbone. 394 | 395 | Args: 396 | depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. 397 | num_stages (int): Resnet stages, normally 4. 398 | strides (Sequence[int]): Strides of the first block of each stage. 399 | dilations (Sequence[int]): Dilation of each stage. 400 | out_indices (Sequence[int]): Output from which stages. 401 | style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two 402 | layer is the 3x3 conv layer, otherwise the stride-two layer is 403 | the first 1x1 conv layer. 404 | frozen_stages (int): Stages to be frozen (stop grad and set eval mode). 405 | -1 means not freezing any parameters. 406 | norm_cfg (dict): dictionary to construct and config norm layer. 407 | norm_eval (bool): Whether to set norm layers to eval mode, namely, 408 | freeze running stats (mean and var). Note: Effect on Batch Norm 409 | and its variants only. 410 | with_cp (bool): Use checkpoint or not. Using checkpoint will save some 411 | memory while slowing down the training speed. 412 | zero_init_residual (bool): whether to use zero init for last norm layer 413 | in resblocks to let them behave as identity. 414 | """ 415 | 416 | arch_settings = { 417 | 18: (BasicBlock, (2, 2, 2, 2)), 418 | 34: (BasicBlock, (3, 4, 6, 3)), 419 | 50: (Bottleneck, (3, 4, 6, 3)), 420 | 101: (Bottleneck, (3, 4, 23, 3)), 421 | 152: (Bottleneck, (3, 8, 36, 3)) 422 | } 423 | 424 | def __init__(self, 425 | depth, 426 | num_stages=4, 427 | mg_rates=(1, 2, 4), 428 | strides=(1, 2, 2, 2), 429 | dilations=(1, 1, 1, 1), 430 | out_indices=(0, 1, 2, 3), 431 | style='pytorch', 432 | frozen_stages=-1, 433 | conv_cfg=None, 434 | norm_cfg=dict(type='BN', requires_grad=True), 435 | norm_eval=True, 436 | dcn=None, 437 | stage_with_dcn=(False, False, False, False), 438 | gcb=None, 439 | stage_with_gcb=(False, False, False, False), 440 | gen_attention=None, 441 | stage_with_gen_attention=((), (), (), ()), 442 | with_cp=False, 443 | zero_init_residual=True): 444 | super(ResNet, self).__init__() 445 | if depth not in self.arch_settings: 446 | raise KeyError('invalid depth {} for resnet'.format(depth)) 447 | self.depth = depth 448 | self.num_stages = num_stages 449 | self.mg_rates = mg_rates 450 | assert num_stages >= 1 and num_stages <= 4 451 | self.strides = strides 452 | self.dilations = dilations 453 | assert len(strides) == len(dilations) == num_stages 454 | self.out_indices = out_indices 455 | assert max(out_indices) < num_stages 456 | self.style = style 457 | self.frozen_stages = frozen_stages 458 | self.conv_cfg = conv_cfg 459 | self.norm_cfg = norm_cfg 460 | self.with_cp = with_cp 461 | self.norm_eval = norm_eval 462 | self.dcn = dcn 463 | self.stage_with_dcn = stage_with_dcn 464 | if dcn is not None: 465 | assert len(stage_with_dcn) == num_stages 466 | self.gen_attention = gen_attention 467 | self.gcb = gcb 468 | self.stage_with_gcb = stage_with_gcb 469 | if gcb is not None: 470 | assert len(stage_with_gcb) == num_stages 471 | self.zero_init_residual = zero_init_residual 472 | self.block, stage_blocks = self.arch_settings[depth] 473 | self.stage_blocks 
= stage_blocks[:num_stages] 474 | self.inplanes = 64 475 | 476 | self._make_stem_layer() 477 | 478 | self.res_layers = [] 479 | for i, num_blocks in enumerate(self.stage_blocks): 480 | stride = strides[i] 481 | dilation = dilations[i] 482 | dcn = self.dcn if self.stage_with_dcn[i] else None 483 | gcb = self.gcb if self.stage_with_gcb[i] else None 484 | planes = 64 * 2**i 485 | if i == len(self.stage_blocks) - 1 and not True in self.stage_with_dcn: 486 | assert len(self.mg_rates) == num_blocks 487 | res_layer = make_multigrid( 488 | self.block, 489 | self.inplanes, 490 | planes, 491 | self.mg_rates, 492 | stride=stride, 493 | dilation=dilation, 494 | style=self.style, 495 | with_cp=with_cp, 496 | conv_cfg=conv_cfg, 497 | norm_cfg=norm_cfg, 498 | dcn=dcn, 499 | gcb=gcb, 500 | gen_attention=gen_attention, 501 | gen_attention_blocks=stage_with_gen_attention[i]) 502 | else: 503 | res_layer = make_res_layer( 504 | self.block, 505 | self.inplanes, 506 | planes, 507 | num_blocks, 508 | stride=stride, 509 | dilation=dilation, 510 | style=self.style, 511 | with_cp=with_cp, 512 | conv_cfg=conv_cfg, 513 | norm_cfg=norm_cfg, 514 | dcn=dcn, 515 | gcb=gcb, 516 | gen_attention=gen_attention, 517 | gen_attention_blocks=stage_with_gen_attention[i]) 518 | self.inplanes = planes * self.block.expansion 519 | layer_name = 'layer{}'.format(i + 1) 520 | self.add_module(layer_name, res_layer) 521 | self.res_layers.append(layer_name) 522 | 523 | self._freeze_stages() 524 | 525 | self.feat_dim = self.block.expansion * 64 * 2**( 526 | len(self.stage_blocks) - 1) 527 | 528 | @property 529 | def norm1(self): 530 | return getattr(self, self.norm1_name) 531 | 532 | def _make_stem_layer(self): 533 | self.conv1 = build_conv_layer( 534 | self.conv_cfg, 535 | 3, 536 | 64, 537 | kernel_size=7, 538 | stride=2, 539 | padding=3, 540 | bias=False) 541 | self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1) 542 | self.add_module(self.norm1_name, norm1) 543 | self.relu = nn.ReLU(inplace=True) 544 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 545 | 546 | def _freeze_stages(self): 547 | if self.frozen_stages >= 0: 548 | self.norm1.eval() 549 | for m in [self.conv1, self.norm1]: 550 | for param in m.parameters(): 551 | param.requires_grad = False 552 | 553 | for i in range(1, self.frozen_stages + 1): 554 | m = getattr(self, 'layer{}'.format(i)) 555 | m.eval() 556 | for param in m.parameters(): 557 | param.requires_grad = False 558 | 559 | def init_weights(self, pretrained=None): 560 | if isinstance(pretrained, str): 561 | logger = logging.getLogger() 562 | load_checkpoint(self, pretrained, strict=False, logger=logger) 563 | elif pretrained is None: 564 | for m in self.modules(): 565 | if isinstance(m, nn.Conv2d): 566 | kaiming_init(m) 567 | elif isinstance(m, (_BatchNorm, nn.GroupNorm)): 568 | constant_init(m, 1) 569 | 570 | if self.dcn is not None: 571 | for m in self.modules(): 572 | if isinstance(m, Bottleneck) and hasattr( 573 | m, 'conv2_offset'): 574 | constant_init(m.conv2_offset, 0) 575 | 576 | if self.zero_init_residual: 577 | for m in self.modules(): 578 | if isinstance(m, Bottleneck): 579 | constant_init(m.norm3, 0) 580 | elif isinstance(m, BasicBlock): 581 | constant_init(m.norm2, 0) 582 | else: 583 | raise TypeError('pretrained must be a str or None') 584 | 585 | def forward(self, x): 586 | x = self.conv1(x) 587 | x = self.norm1(x) 588 | x = self.relu(x) 589 | x = self.maxpool(x) 590 | for i, layer_name in enumerate(self.res_layers): 591 | res_layer = getattr(self, layer_name) 592 | x 
= res_layer(x) 593 | return x 594 | 595 | def train(self, mode=True): 596 | super(ResNet, self).train(mode) 597 | self._freeze_stages() 598 | if mode and self.norm_eval: 599 | for m in self.modules(): 600 | # trick: eval have effect on BatchNorm only 601 | if isinstance(m, _BatchNorm): 602 | m.eval() 603 | 604 | -------------------------------------------------------------------------------- /src/models/unet3d/models.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .buildingblocks import Encoder, Decoder, FinalConv, DoubleConv, ExtResNetBlock, SingleConv 7 | from .utils import create_feature_maps 8 | 9 | 10 | class UNet3D(nn.Module): 11 | """ 12 | 3DUnet model from 13 | `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation" 14 | `. 15 | Args: 16 | in_channels (int): number of input channels 17 | out_channels (int): number of output segmentation masks; 18 | Note that that the of out_channels might correspond to either 19 | different semantic classes or to different binary segmentation mask. 20 | It's up to the user of the class to interpret the out_channels and 21 | use the proper loss criterion during training (i.e. CrossEntropyLoss (multi-class) 22 | or BCEWithLogitsLoss (two-class) respectively) 23 | f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number 24 | of feature maps is given by the geometric progression: f_maps ^ k, k=1,2,3,4 25 | final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the 26 | final 1x1 convolution, otherwise apply nn.Softmax. MUST be True if nn.BCELoss (two-class) is used 27 | to train the model. MUST be False if nn.CrossEntropyLoss (multi-class) is used to train the model. 28 | layer_order (string): determines the order of layers 29 | in `SingleConv` module. e.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d. 30 | See `SingleConv` for more info 31 | init_channel_number (int): number of feature maps in the first conv layer of the encoder; default: 64 32 | num_groups (int): number of groups for the GroupNorm 33 | """ 34 | 35 | def __init__(self, in_channels, out_channels, f_maps=64, layer_order='crg', num_groups=8, 36 | **kwargs): 37 | super(UNet3D, self).__init__() 38 | 39 | if isinstance(f_maps, int): 40 | # use 4 levels in the encoder path as suggested in the paper 41 | f_maps = create_feature_maps(f_maps, number_of_fmaps=4) 42 | 43 | # create encoder path consisting of Encoder modules. The length of the encoder is equal to `len(f_maps)` 44 | # uses DoubleConv as a basic_module for the Encoder 45 | encoders = [] 46 | for i, out_feature_num in enumerate(f_maps): 47 | if i == 0: 48 | encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=DoubleConv, 49 | conv_layer_order=layer_order, num_groups=num_groups) 50 | else: 51 | encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=DoubleConv, 52 | conv_layer_order=layer_order, num_groups=num_groups) 53 | encoders.append(encoder) 54 | 55 | self.encoders = nn.ModuleList(encoders) 56 | 57 | # create decoder path consisting of the Decoder modules. 
The length of the decoder is equal to `len(f_maps) - 1` 58 | # uses DoubleConv as a basic_module for the Decoder 59 | decoders = [] 60 | reversed_f_maps = list(reversed(f_maps)) 61 | for i in range(len(reversed_f_maps) - 1): 62 | in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1] 63 | out_feature_num = reversed_f_maps[i + 1] 64 | decoder = Decoder(in_feature_num, out_feature_num, basic_module=DoubleConv, 65 | conv_layer_order=layer_order, num_groups=num_groups) 66 | decoders.append(decoder) 67 | 68 | self.decoders = nn.ModuleList(decoders) 69 | 70 | # in the last layer a 1×1 convolution reduces the number of output 71 | # channels to the number of labels 72 | self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1) 73 | 74 | def forward(self, x): 75 | # encoder part 76 | encoders_features = [] 77 | for encoder in self.encoders: 78 | x = encoder(x) 79 | # reverse the encoder outputs to be aligned with the decoder 80 | encoders_features.insert(0, x) 81 | 82 | # remove the last encoder's output from the list 83 | # !!remember: it's the 1st in the list 84 | encoders_features = encoders_features[1:] 85 | 86 | # decoder part 87 | for decoder, encoder_features in zip(self.decoders, encoders_features): 88 | # pass the output from the corresponding encoder and the output 89 | # of the previous decoder 90 | x = decoder(encoder_features, x) 91 | 92 | x = self.final_conv(x) 93 | return x 94 | 95 | 96 | class ResidualUNet3D(nn.Module): 97 | """ 98 | Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf. 99 | Uses ExtResNetBlock instead of DoubleConv as a basic building block as well as summation joining instead 100 | of concatenation joining. Since the model effectively becomes a residual net, in theory it allows for deeper UNet. 101 | Args: 102 | in_channels (int): number of input channels 103 | out_channels (int): number of output segmentation masks; 104 | Note that that the of out_channels might correspond to either 105 | different semantic classes or to different binary segmentation mask. 106 | It's up to the user of the class to interpret the out_channels and 107 | use the proper loss criterion during training (i.e. NLLLoss (multi-class) 108 | or BCELoss (two-class) respectively) 109 | f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number 110 | of feature maps is given by the geometric progression: f_maps ^ k, k=1,2,3,4,5 111 | final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the 112 | final 1x1 convolution, otherwise apply nn.Softmax. MUST be True if nn.BCELoss (two-class) is used 113 | to train the model. MUST be False if nn.CrossEntropyLoss (multi-class) is used to train the model. 114 | conv_layer_order (string): determines the order of layers 115 | in `SingleConv` module. e.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d. 116 | See `SingleConv` for more info 117 | init_channel_number (int): number of feature maps in the first conv layer of the encoder; default: 64 118 | num_groups (int): number of groups for the GroupNorm 119 | """ 120 | 121 | def __init__(self, in_channels, out_channels, f_maps=32, conv_layer_order='cge', num_groups=8, 122 | **kwargs): 123 | super(ResidualUNet3D, self).__init__() 124 | 125 | if isinstance(f_maps, int): 126 | # use 5 levels in the encoder path as suggested in the paper 127 | f_maps = create_feature_maps(f_maps, number_of_fmaps=5) 128 | 129 | # create encoder path consisting of Encoder modules. 
The length of the encoder is equal to `len(f_maps)` 130 | # uses ExtResNetBlock as a basic_module for the Encoder 131 | encoders = [] 132 | for i, out_feature_num in enumerate(f_maps): 133 | if i == 0: 134 | encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=ExtResNetBlock, 135 | conv_layer_order=conv_layer_order, num_groups=num_groups) 136 | else: 137 | encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=ExtResNetBlock, 138 | conv_layer_order=conv_layer_order, num_groups=num_groups) 139 | encoders.append(encoder) 140 | 141 | self.encoders = nn.ModuleList(encoders) 142 | 143 | # create decoder path consisting of the Decoder modules. The length of the decoder is equal to `len(f_maps) - 1` 144 | # uses ExtResNetBlock as a basic_module for the Decoder 145 | decoders = [] 146 | reversed_f_maps = list(reversed(f_maps)) 147 | for i in range(len(reversed_f_maps) - 1): 148 | decoder = Decoder(reversed_f_maps[i], reversed_f_maps[i + 1], basic_module=ExtResNetBlock, 149 | conv_layer_order=conv_layer_order, num_groups=num_groups) 150 | decoders.append(decoder) 151 | 152 | self.decoders = nn.ModuleList(decoders) 153 | 154 | # in the last layer a 1×1 convolution reduces the number of output 155 | # channels to the number of labels 156 | self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1) 157 | 158 | def forward(self, x): 159 | # encoder part 160 | encoders_features = [] 161 | for encoder in self.encoders: 162 | x = encoder(x) 163 | # reverse the encoder outputs to be aligned with the decoder 164 | encoders_features.insert(0, x) 165 | 166 | # remove the last encoder's output from the list 167 | # !!remember: it's the 1st in the list 168 | encoders_features = encoders_features[1:] 169 | 170 | # decoder part 171 | for decoder, encoder_features in zip(self.decoders, encoders_features): 172 | # pass the output from the corresponding encoder and the output 173 | # of the previous decoder 174 | x = decoder(encoder_features, x) 175 | 176 | x = self.final_conv(x) 177 | 178 | return x 179 | 180 | 181 | class Noise2NoiseUNet3D(nn.Module): 182 | """ 183 | Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf. 184 | Uses ExtResNetBlock instead of DoubleConv as a basic building block as well as summation joining instead 185 | of concatenation joining. Since the model effectively becomes a residual net, in theory it allows for deeper UNet. 186 | Args: 187 | in_channels (int): number of input channels 188 | out_channels (int): number of output segmentation masks; 189 | Note that that the of out_channels might correspond to either 190 | different semantic classes or to different binary segmentation mask. 191 | It's up to the user of the class to interpret the out_channels and 192 | use the proper loss criterion during training (i.e. 
NLLLoss (multi-class) 193 | or BCELoss (two-class) respectively) 194 | f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number 195 | of feature maps is given by the geometric progression: f_maps ^ k, k=1,2,3,4,5 196 | init_channel_number (int): number of feature maps in the first conv layer of the encoder; default: 64 197 | num_groups (int): number of groups for the GroupNorm 198 | """ 199 | 200 | def __init__(self, in_channels, out_channels, f_maps=16, num_groups=8, **kwargs): 201 | super(Noise2NoiseUNet3D, self).__init__() 202 | 203 | # Use LeakyReLU activation everywhere except the last layer 204 | conv_layer_order = 'clg' 205 | 206 | if isinstance(f_maps, int): 207 | # use 5 levels in the encoder path as suggested in the paper 208 | f_maps = create_feature_maps(f_maps, number_of_fmaps=5) 209 | 210 | # create encoder path consisting of Encoder modules. The length of the encoder is equal to `len(f_maps)` 211 | # uses DoubleConv as a basic_module for the Encoder 212 | encoders = [] 213 | for i, out_feature_num in enumerate(f_maps): 214 | if i == 0: 215 | encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=DoubleConv, 216 | conv_layer_order=conv_layer_order, num_groups=num_groups) 217 | else: 218 | encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=DoubleConv, 219 | conv_layer_order=conv_layer_order, num_groups=num_groups) 220 | encoders.append(encoder) 221 | 222 | self.encoders = nn.ModuleList(encoders) 223 | 224 | # create decoder path consisting of the Decoder modules. The length of the decoder is equal to `len(f_maps) - 1` 225 | # uses DoubleConv as a basic_module for the Decoder 226 | decoders = [] 227 | reversed_f_maps = list(reversed(f_maps)) 228 | for i in range(len(reversed_f_maps) - 1): 229 | in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1] 230 | out_feature_num = reversed_f_maps[i + 1] 231 | decoder = Decoder(in_feature_num, out_feature_num, basic_module=DoubleConv, 232 | conv_layer_order=conv_layer_order, num_groups=num_groups) 233 | decoders.append(decoder) 234 | 235 | self.decoders = nn.ModuleList(decoders) 236 | 237 | # 1x1x1 conv + simple ReLU in the final convolution 238 | self.final_conv = SingleConv(f_maps[0], out_channels, kernel_size=1, order='cr', padding=0) 239 | 240 | def forward(self, x): 241 | # encoder part 242 | encoders_features = [] 243 | for encoder in self.encoders: 244 | x = encoder(x) 245 | # reverse the encoder outputs to be aligned with the decoder 246 | encoders_features.insert(0, x) 247 | 248 | # remove the last encoder's output from the list 249 | # !!remember: it's the 1st in the list 250 | encoders_features = encoders_features[1:] 251 | 252 | # decoder part 253 | for decoder, encoder_features in zip(self.decoders, encoders_features): 254 | # pass the output from the corresponding encoder and the output 255 | # of the previous decoder 256 | x = decoder(encoder_features, x) 257 | 258 | x = self.final_conv(x) 259 | 260 | return x 261 | 262 | 263 | def get_model(config): 264 | def _model_class(class_name): 265 | m = importlib.import_module('unet3d.model') 266 | clazz = getattr(m, class_name) 267 | return clazz 268 | 269 | assert 'model' in config, 'Could not find model configuration' 270 | model_config = config['model'] 271 | model_class = _model_class(model_config['name']) 272 | return model_class(**model_config) 273 | 274 | 275 | ###############################################Supervised Tags 3DUnet################################################### 
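# --- Editor's sketch (not part of the original file) --------------------------------
# A hypothetical smoke test for the constructor arguments documented in the UNet3D /
# Noise2NoiseUNet3D docstrings above: the networks return raw logits, and an integer
# f_maps expands to a geometric progression of per-level feature maps (typically
# [32, 64, 128, 256] for f_maps=32). Shapes below are purely illustrative.
if __name__ == '__main__':
    _net = UNet3D(in_channels=1, out_channels=2, f_maps=32, layer_order='crg', num_groups=8)
    _volume = torch.randn(1, 1, 32, 64, 64)   # (batch, channels, depth, height, width)
    _logits = _net(_volume)                   # raw logits; apply sigmoid/softmax externally
    print(_logits.shape)                      # expected: torch.Size([1, 2, 32, 64, 64])
# -------------------------------------------------------------------------------------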
276 | 277 | class TagsUNet3D(nn.Module): 278 | """ 279 | Supervised tags 3DUnet 280 | Args: 281 | in_channels (int): number of input channels 282 | out_channels (int): number of output channels; since most often we're trying to learn 283 | 3D unit vectors we use 3 as a default value 284 | output_heads (int): number of output heads from the network, each head corresponds to different 285 | semantic tag/direction to be learned 286 | conv_layer_order (string): determines the order of layers 287 | in `DoubleConv` module. e.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d. 288 | See `DoubleConv` for more info 289 | init_channel_number (int): number of feature maps in the first conv layer of the encoder; default: 64 290 | """ 291 | 292 | def __init__(self, in_channels, out_channels=3, output_heads=1, conv_layer_order='crg', init_channel_number=32, 293 | **kwargs): 294 | super(TagsUNet3D, self).__init__() 295 | 296 | # number of groups for the GroupNorm 297 | num_groups = min(init_channel_number // 2, 32) 298 | 299 | # encoder path consist of 4 subsequent Encoder modules 300 | # the number of features maps is the same as in the paper 301 | self.encoders = nn.ModuleList([ 302 | Encoder(in_channels, init_channel_number, apply_pooling=False, conv_layer_order=conv_layer_order, 303 | num_groups=num_groups), 304 | Encoder(init_channel_number, 2 * init_channel_number, conv_layer_order=conv_layer_order, 305 | num_groups=num_groups), 306 | Encoder(2 * init_channel_number, 4 * init_channel_number, conv_layer_order=conv_layer_order, 307 | num_groups=num_groups), 308 | Encoder(4 * init_channel_number, 8 * init_channel_number, conv_layer_order=conv_layer_order, 309 | num_groups=num_groups) 310 | ]) 311 | 312 | self.decoders = nn.ModuleList([ 313 | Decoder(4 * init_channel_number + 8 * init_channel_number, 4 * init_channel_number, 314 | conv_layer_order=conv_layer_order, num_groups=num_groups), 315 | Decoder(2 * init_channel_number + 4 * init_channel_number, 2 * init_channel_number, 316 | conv_layer_order=conv_layer_order, num_groups=num_groups), 317 | Decoder(init_channel_number + 2 * init_channel_number, init_channel_number, 318 | conv_layer_order=conv_layer_order, num_groups=num_groups) 319 | ]) 320 | 321 | self.final_heads = nn.ModuleList( 322 | [FinalConv(init_channel_number, out_channels, num_groups=num_groups) for _ in 323 | range(output_heads)]) 324 | 325 | def forward(self, x): 326 | # encoder part 327 | encoders_features = [] 328 | for encoder in self.encoders: 329 | x = encoder(x) 330 | # reverse the encoder outputs to be aligned with the decoder 331 | encoders_features.insert(0, x) 332 | 333 | # remove the last encoder's output from the list 334 | # !!remember: it's the 1st in the list 335 | encoders_features = encoders_features[1:] 336 | 337 | # decoder part 338 | for decoder, encoder_features in zip(self.decoders, encoders_features): 339 | # pass the output from the corresponding encoder and the output 340 | # of the previous decoder 341 | x = decoder(encoder_features, x) 342 | 343 | # apply final layer per each output head 344 | tags = [final_head(x) for final_head in self.final_heads] 345 | 346 | # normalize directions with L2 norm 347 | return [tag / torch.norm(tag, p=2, dim=1).detach().clamp(min=1e-8) for tag in tags] 348 | 349 | 350 | ################################################Distance transform 3DUNet############################################## 351 | class DistanceTransformUNet3D(nn.Module): 352 | """ 353 | Predict Distance Transform to the boundary signal based on the output from the 
Tags3DUnet. For training use either:
354 |     1. PixelWiseCrossEntropyLoss if the distance transform is quantized (classification)
355 |     2. MSELoss if the distance transform is continuous (regression)
356 |     Args:
357 |         in_channels (int): number of input channels
358 |         out_channels (int): number of output segmentation masks;
359 |             Note that out_channels might correspond to either
360 |             different semantic classes or to different binary segmentation masks.
361 |             It's up to the user of the class to interpret the out_channels and
362 |             use the proper loss criterion during training (i.e. NLLLoss (multi-class)
363 |             or BCELoss (two-class) respectively)
364 |         final_sigmoid (bool): whether element-wise nn.Sigmoid or nn.Softmax should be applied after
365 |             the final 1x1 convolution
366 |         init_channel_number (int): number of feature maps in the first conv layer of the encoder; default: 32
367 |     """
368 | 
369 |     def __init__(self, in_channels, out_channels, final_sigmoid, init_channel_number=32, **kwargs):
370 |         super(DistanceTransformUNet3D, self).__init__()
371 | 
372 |         # number of groups for the GroupNorm
373 |         num_groups = min(init_channel_number // 2, 32)
374 | 
375 |         # encoder path consists of 2 subsequent Encoder modules
376 |         # the number of feature maps is the same as in the paper
377 |         self.encoders = nn.ModuleList([
378 |             Encoder(in_channels, init_channel_number, apply_pooling=False, conv_layer_order='crg',
379 |                     num_groups=num_groups),
380 |             Encoder(init_channel_number, 2 * init_channel_number, pool_type='avg', conv_layer_order='crg',
381 |                     num_groups=num_groups)
382 |         ])
383 | 
384 |         self.decoders = nn.ModuleList([
385 |             Decoder(3 * init_channel_number, init_channel_number, conv_layer_order='crg', num_groups=num_groups)
386 |         ])
387 | 
388 |         # in the last layer a 1×1 convolution reduces the number of output
389 |         # channels to the number of labels
390 |         self.final_conv = nn.Conv3d(init_channel_number, out_channels, 1)
391 | 
392 |         if final_sigmoid:
393 |             self.final_activation = nn.Sigmoid()
394 |         else:
395 |             self.final_activation = nn.Softmax(dim=1)
396 | 
397 |     def forward(self, inputs):
398 |         # allow multiple heads
399 |         if isinstance(inputs, list) or isinstance(inputs, tuple):
400 |             x = torch.cat(inputs, dim=1)
401 |         else:
402 |             x = inputs
403 | 
404 |         # encoder part
405 |         encoders_features = []
406 |         for encoder in self.encoders:
407 |             x = encoder(x)
408 |             # reverse the encoder outputs to be aligned with the decoder
409 |             encoders_features.insert(0, x)
410 | 
411 |         # remove the last encoder's output from the list
412 |         # !!remember: it's the 1st in the list
413 |         encoders_features = encoders_features[1:]
414 | 
415 |         # decoder part
416 |         for decoder, encoder_features in zip(self.decoders, encoders_features):
417 |             # pass the output from the corresponding encoder and the output
418 |             # of the previous decoder
419 |             x = decoder(encoder_features, x)
420 | 
421 |         # apply final 1x1 convolution
422 |         x = self.final_conv(x)
423 | 
424 |         # apply final_activation (i.e. Sigmoid or Softmax) only for prediction. During training the network outputs
425 |         # logits and it's up to the user to normalize it before visualising with tensorboard or computing validation metric
426 |         if not self.training:
427 |             x = self.final_activation(x)
428 | 
429 |         return x
430 | 
431 | 
432 | class EndToEndDTUNet3D(nn.Module):
433 |     def __init__(self, tags_in_channels, tags_out_channels, tags_output_heads, tags_init_channel_number,
434 |                  dt_in_channels, dt_out_channels, dt_final_sigmoid, dt_init_channel_number,
435 |                  tags_net_path=None, dt_net_path=None, **kwargs):
436 |         super(EndToEndDTUNet3D, self).__init__()
437 | 
438 |         self.tags_net = TagsUNet3D(tags_in_channels, tags_out_channels, tags_output_heads,
439 |                                    init_channel_number=tags_init_channel_number)
440 |         if tags_net_path is not None:
441 |             # load pre-trained TagsUNet3D
442 |             self.tags_net = self._load_net(tags_net_path, self.tags_net)
443 | 
444 |         self.dt_net = DistanceTransformUNet3D(dt_in_channels, dt_out_channels, dt_final_sigmoid,
445 |                                               init_channel_number=dt_init_channel_number)
446 |         if dt_net_path is not None:
447 |             # load pre-trained DistanceTransformUNet3D
448 |             self.dt_net = self._load_net(dt_net_path, self.dt_net)
449 | 
450 |     @staticmethod
451 |     def _load_net(checkpoint_path, model):
452 |         state = torch.load(checkpoint_path)
453 |         model.load_state_dict(state['model_state_dict'])
454 |         return model
455 | 
456 |     def forward(self, x):
457 |         x = self.tags_net(x)
458 |         return self.dt_net(x)
459 | 
460 | 
461 | class UNet3D2(nn.Module):
462 |     def __init__(self, in_channel, n_classes):
463 |         self.in_channel = in_channel
464 |         self.n_classes = n_classes
465 |         super(UNet3D2, self).__init__()
466 |         self.ec0 = self.down_conv(self.in_channel, 32, bias=False, batchnorm=False)
467 |         self.ec1 = self.down_conv(32, 64, bias=False, batchnorm=False)
468 |         self.ec2 = self.down_conv(64, 64, bias=False, batchnorm=False)
469 |         self.ec3 = self.down_conv(64, 128, bias=False, batchnorm=False)
470 |         self.ec4 = self.down_conv(128, 128, bias=False, batchnorm=False)
471 |         self.ec5 = self.down_conv(128, 256, bias=False, batchnorm=False)
472 |         self.ec6 = self.down_conv(256, 256, bias=False, batchnorm=False)
473 |         self.ec7 = self.down_conv(256, 512, bias=False, batchnorm=False)
474 | 
475 |         self.pool0 = nn.MaxPool3d(2)
476 |         self.pool1 = nn.MaxPool3d(2)
477 |         self.pool2 = nn.MaxPool3d(2)
478 | 
479 |         self.dc9 = self.up_conv(512, 512, kernel_size=2, stride=2, bias=False)
480 |         self.dc8 = self.down_conv(256 + 512, 256, bias=False)
481 |         self.dc7 = self.down_conv(256, 256, bias=False)
482 |         self.dc6 = self.up_conv(256, 256, kernel_size=2, stride=2, bias=False)
483 |         self.dc5 = self.down_conv(128 + 256, 128, bias=False)
484 |         self.dc4 = self.down_conv(128, 128, bias=False)
485 |         self.dc3 = self.up_conv(128, 128, kernel_size=2, stride=2, bias=False)
486 |         self.dc2 = self.down_conv(64 + 128, 64, bias=False)
487 |         self.dc1 = self.down_conv(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
488 |         self.dc0 = self.down_conv(64, n_classes, kernel_size=1, stride=1, padding=0, bias=False)
489 | 
490 | 
491 |     def down_conv(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,
492 |                   bias=True, batchnorm=False):
493 |         if batchnorm:
494 |             layer = nn.Sequential(
495 |                 nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
496 |                 nn.BatchNorm3d(out_channels),  # 3D batch norm to match the Conv3d above
497 |                 nn.ReLU())
498 |         else:
499 |             layer = nn.Sequential(
500 |                 nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
501 |                 nn.ReLU())
502 |         return layer
503 | 
504 | 505 | def up_conv(self, in_channels, out_channels, kernel_size=2, stride=2, padding=0, 506 | output_padding=0, bias=True): 507 | layer = nn.Sequential( 508 | nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, 509 | padding=padding, output_padding=output_padding, bias=bias), 510 | nn.ReLU()) 511 | return layer 512 | 513 | def forward(self, x): 514 | e0 = self.ec0(x) 515 | syn0 = self.ec1(e0) 516 | e1 = self.pool0(syn0) 517 | e2 = self.ec2(e1) 518 | syn1 = self.ec3(e2) 519 | del e0, e1, e2 520 | 521 | e3 = self.pool1(syn1) 522 | e4 = self.ec4(e3) 523 | syn2 = self.ec5(e4) 524 | del e3, e4 525 | 526 | e5 = self.pool2(syn2) 527 | e6 = self.ec6(e5) 528 | e7 = self.ec7(e6) 529 | del e5, e6 530 | 531 | d9 = torch.cat((self.dc9(e7), syn2), dim=1) 532 | del e7, syn2 533 | 534 | d8 = self.dc8(d9) 535 | d7 = self.dc7(d8) 536 | del d9, d8 537 | 538 | d6 = torch.cat((self.dc6(d7), syn1), dim=1) 539 | del d7, syn1 540 | 541 | d5 = self.dc5(d6) 542 | d4 = self.dc4(d5) 543 | del d6, d5 544 | 545 | d3 = torch.cat((self.dc3(d4), syn0), dim=1) 546 | del d4, syn0 547 | 548 | d2 = self.dc2(d3) 549 | d1 = self.dc1(d2) 550 | del d3, d2 551 | 552 | d0 = self.dc0(d1) 553 | return d0 554 | --------------------------------------------------------------------------------
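Editor's note on the two-stage pipeline in models.py above: TagsUNet3D predicts L2-normalised direction tags and DistanceTransformUNet3D consumes them, applying its Sigmoid/Softmax only in eval mode. The snippet below is a minimal sketch of that wiring, not code from the repository; the import path is an assumption about how the package is exposed, and the channel counts and volume sizes are purely illustrative.

    import torch
    from src.models.unet3d.models import TagsUNet3D, DistanceTransformUNet3D

    tags_net = TagsUNet3D(in_channels=1, out_channels=3, output_heads=1, init_channel_number=32)
    dt_net = DistanceTransformUNet3D(in_channels=3, out_channels=1, final_sigmoid=True, init_channel_number=32)

    volume = torch.randn(1, 1, 32, 64, 64)   # (batch, channels, depth, height, width)
    tags = tags_net(volume)                  # list with `output_heads` direction maps
    dt_net.train()
    dt_logits = dt_net(tags)                 # raw logits while training
    dt_net.eval()
    dt_probs = dt_net(tags)                  # Sigmoid is applied only in eval mode

EndToEndDTUNet3D wraps exactly this composition and can additionally load pre-trained weights for either stage.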
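A similarly hedged sketch for UNet3D2: it downsamples three times with MaxPool3d(2) and upsamples with stride-2 transposed convolutions, so each spatial dimension of the input should be divisible by 8 for the skip concatenations to line up. As above, the import path is assumed and the sizes are illustrative.

    import torch
    from src.models.unet3d.models import UNet3D2

    net = UNet3D2(in_channel=1, n_classes=2)
    out = net(torch.randn(1, 1, 32, 64, 64))
    print(out.shape)   # expected: torch.Size([1, 2, 32, 64, 64]), per-voxel class logits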