├── Pretrain
│   ├── optimizers
│   │   ├── __init__.py
│   │   └── lr_scheduler.py
│   ├── pretrain_models
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── deep_unet.py
│   │   ├── deep_unet_v2.py
│   │   ├── deep_unet_v2_64_16.py
│   │   ├── deep_unet_v2_64_32.py
│   │   ├── swinunetr.py
│   │   └── swinunetr_8.py
│   ├── utils
│   │   ├── luna_data_convert.py
│   │   ├── data_utils.py
│   │   └── data_utils_first_level_64.py
│   ├── losses
│   │   └── loss.py
│   └── main.py
├── .DS_Store
├── imgs
│   └── framework.png
├── README.md
├── .gitignore
└── LICENSE

/Pretrain/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/Pretrain/pretrain_models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ge-xing/HybridMIM/HEAD/.DS_Store
--------------------------------------------------------------------------------
/imgs/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ge-xing/HybridMIM/HEAD/imgs/framework.png
--------------------------------------------------------------------------------
/Pretrain/utils/luna_data_convert.py:
--------------------------------------------------------------------------------
1 | 
2 | import SimpleITK as sitk
3 | import glob
4 | import os
5 | 
6 | luna16_paths = glob.glob("/mnt/xingzhaohu/data/luna16/*/*.mhd")
7 | 
8 | new_dir = "/mnt/xingzhaohu/data/luna16_convert/"
9 | os.makedirs(new_dir, exist_ok=True)
10 | index = 0
11 | 
12 | from multiprocessing import Pool, Process
13 | 
14 | def handle(data):
15 |     filename = data.split("/")[-1]
16 |     filename = filename[:-4]
17 |     data = sitk.ReadImage(data)
18 |     sitk.WriteImage(data, os.path.join(new_dir, f'{filename}.nii.gz'))
19 |     print(f"{filename} save done.")
20 |     return
21 | 
22 | p = Pool(16)
23 | p.map_async(handle, luna16_paths, chunksize=16)
24 | p.close()
25 | p.join()
26 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # HybridMIM
2 | HybridMIM: A Hybrid Masked Image Modeling Framework for 3D Medical Image Segmentation
3 | 
4 | ![](/imgs/framework.png)
5 | # Pre-training
6 | 
7 | We design a self-supervised learning method that learns spatial information from high-dimensional medical images at multiple levels: pixel-level, region-level, and sample-level.
8 | 
9 | Unlike other self-supervised learning methods, which are evaluated on a single architecture, we support two architectures: UNet and SwinUNETR.
10 | 
11 | We collect four datasets (Luna16, FLARE2021, Covid-19, and ATM22) for general pre-training. You can find them at https://grand-challenge.org/challenges/all-challenges/.
12 | 
13 | We evaluate HybridMIM on three segmentation datasets: BraTS2020, BTCV, and MSD Liver. You can find them at:
14 | 1. https://www.med.upenn.edu/cbica/brats2020/data.html
15 | 2. https://www.synapse.org/#!Synapse:syn3193805
16 | 3. 
http://medicaldecathlon.com/ 17 | 18 | ## UNet pre-training 19 | ```bash 20 | python -m torch.distributed.launch --nproc_per_node=2 --master_port=11223 main.py --batch_size=1 --num_steps=100000 --lrdecay --lr=1e-4 --decay=0.001 --logdir=./deepunet --model_name=deepunet_v2 --eval_num=500 21 | ``` 22 | 23 | ## SwinUNETR pre-training 24 | 25 | ```bash 26 | python -m torch.distributed.launch --nproc_per_node=2 --master_port=11223 main.py --batch_size=1 --num_steps=100000 --lrdecay --lr=6e-6 --decay=0.1 --logdir=./swin_pretrain --smartcache_dataset --model_name=swin --eval_num=500 --val_cache 27 | ``` -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /Pretrain/losses/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2022 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | import torch 13 | from torch.nn import functional as F 14 | 15 | import einops 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | from pretrain_models.utils import patchify 20 | 21 | def forward_constrast_loss(x_i, x_j, temp=0.5): 22 | device = x_i.device 23 | batch_size = x_i.shape[0] 24 | temp = torch.tensor(temp).to(device) 25 | neg_mask = (~torch.eye(batch_size * 2, batch_size * 2, dtype=bool).to(device)).float() 26 | z_i = F.normalize(x_i, dim=1) 27 | z_j = F.normalize(x_j, dim=1) 28 | z = torch.cat([z_i, z_j], dim=0) 29 | sim = F.cosine_similarity(z.unsqueeze(1), 30 | z.unsqueeze(0), 31 | dim=2) 32 | sim_ij = torch.diag(sim, batch_size) 33 | sim_ji = torch.diag(sim, -batch_size) 34 | pos = torch.cat([sim_ij, sim_ji], dim=0) 35 | nom = torch.exp(pos / temp) 36 | denom = neg_mask * torch.exp(sim / temp) 37 | return torch.sum(-torch.log(nom / torch.sum(denom, dim=1))) / (2 * batch_size) 38 | 39 | def forward_loss_reconstruct_mask(pred, labels, mask_image, mask_value=0.0): 40 | # pred (b c d w h) 41 | # pred = torch.einsum("") 42 | mask = (mask_image == mask_value).float() 43 | 44 | loss = (pred - labels) ** 2 45 | 46 | # loss = loss.mean(dim=-1) # [N, L], mean loss per patch 47 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 48 | return loss 49 | 50 | def forward_loss_reconstruct(pred, labels): 51 | loss = (pred - labels) ** 2 52 | loss = loss.mean() # [N, L], mean loss per patch 53 | return loss 54 | 55 | 56 | def forward_loss_similarity(pred_1, pred_2): 57 | loss_fct = nn.CrossEntropyLoss() 58 | device = pred_1.device 59 | cos = nn.CosineSimilarity(dim=-1) 60 | sim = cos(pred_1.unsqueeze(dim=1), pred_2.unsqueeze(dim=0)) 61 | labels = torch.arange(sim.shape[0], dtype=torch.long).to(device) 62 | 63 | loss = loss_fct(sim, labels) 64 | 65 | return loss 66 | 67 | def forward_loss_mask_region(pred_bottom_feature, mask_labels): 68 | 69 | # pred_bottom_feature = einops.rearrange(pred_bottom_feature, "b d w h c->b (d w h) c") 70 | loss_fct = nn.CrossEntropyLoss() 71 | loss = loss_fct(pred_bottom_feature.reshape(-1, pred_bottom_feature.shape[-1]), 
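# Region-level masking loss: pred_bottom_feature is expected to hold, for each region, logits
# over the possible number of masked sub-patches, and mask_labels the integer count per region
# (see get_mask_labels in pretrain_models/utils.py); both are flattened so that
# CrossEntropyLoss treats every region as one classification sample.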
mask_labels.reshape(-1)) 72 | return loss 73 | 74 | def forward_loss_mask(pred_bottom_feature, mask_labels): 75 | mask_labels = mask_labels.long() 76 | loss_fct = nn.CrossEntropyLoss() 77 | loss = loss_fct(pred_bottom_feature.reshape(-1, pred_bottom_feature.shape[-1]), mask_labels.reshape(-1)) 78 | return loss 79 | 80 | def forward_loss_mask_region_patch(pred_bottom_feature_patch, mask_labels_patch): 81 | 82 | loss_fct = nn.CrossEntropyLoss() 83 | loss = loss_fct(pred_bottom_feature_patch.reshape(-1, pred_bottom_feature_patch.shape[-1]), mask_labels_patch.reshape(-1)) 84 | return loss 85 | 86 | 87 | def forward_loss_mask_region_multi_label(pred_bottom_feature, mask_labels): 88 | loss_fct = nn.BCEWithLogitsLoss() 89 | mask_labels = mask_labels.float() 90 | loss = loss_fct(pred_bottom_feature, mask_labels) 91 | return loss 92 | 93 | 94 | def forward_loss_mask_position(pred_bottom_feature, mask_labels): 95 | mask_labels = mask_labels.float() 96 | loss_fct = nn.BCEWithLogitsLoss() 97 | loss = loss_fct(pred_bottom_feature, mask_labels) 98 | return loss 99 | 100 | 101 | class Contrast(torch.nn.Module): 102 | def __init__(self, args, batch_size, temperature=0.5): 103 | super().__init__() 104 | device = torch.device(f"cuda:{args.local_rank}") 105 | self.batch_size = batch_size 106 | self.register_buffer("temp", torch.tensor(temperature).to(torch.device(f"cuda:{args.local_rank}"))) 107 | self.register_buffer("neg_mask", (~torch.eye(batch_size * 2, batch_size * 2, dtype=bool).to(device)).float()) 108 | 109 | def forward(self, x_i, x_j): 110 | z_i = F.normalize(x_i, dim=1) 111 | z_j = F.normalize(x_j, dim=1) 112 | z = torch.cat([z_i, z_j], dim=0) 113 | sim = F.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=2) 114 | sim_ij = torch.diag(sim, self.batch_size) 115 | sim_ji = torch.diag(sim, -self.batch_size) 116 | pos = torch.cat([sim_ij, sim_ji], dim=0) 117 | nom = torch.exp(pos / self.temp) 118 | denom = self.neg_mask * torch.exp(sim / self.temp) 119 | return torch.sum(-torch.log(nom / torch.sum(denom, dim=1))) / (2 * self.batch_size) 120 | 121 | 122 | class Loss(torch.nn.Module): 123 | def __init__(self, batch_size, args): 124 | super().__init__() 125 | self.rot_loss = torch.nn.CrossEntropyLoss().cuda() 126 | self.recon_loss = torch.nn.L1Loss().cuda() 127 | self.contrast_loss = Contrast(args, batch_size).cuda() 128 | self.alpha1 = 1.0 129 | self.alpha2 = 1.0 130 | self.alpha3 = 1.0 131 | 132 | def __call__(self, output_rot, target_rot, output_contrastive, target_contrastive, output_recons, target_recons): 133 | rot_loss = self.alpha1 * self.rot_loss(output_rot, target_rot) 134 | contrast_loss = self.alpha2 * self.contrast_loss(output_contrastive, target_contrastive) 135 | recon_loss = self.alpha3 * self.recon_loss(output_recons, target_recons) 136 | total_loss = rot_loss + contrast_loss + recon_loss 137 | 138 | return total_loss, (rot_loss, contrast_loss, recon_loss) 139 | -------------------------------------------------------------------------------- /Pretrain/pretrain_models/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | from einops import rearrange 5 | 6 | def patchify(in_channels, imgs, patch_size): 7 | """ 8 | imgs: (N, 4, D, H, W) 9 | x: (N, L, patch_size**3 *4) 10 | """ 11 | p = patch_size[0] 12 | assert imgs.shape[3] == imgs.shape[4] and imgs.shape[2] % p == 0 13 | d = h = w = imgs.shape[2] // p 14 | x = imgs.reshape(shape=(imgs.shape[0], in_channels, d, p, h, p, w, p)) 15 | x = 
torch.einsum('ncdkhpwq->ndhwkpqc', x) 16 | x = x.reshape(shape=(imgs.shape[0], d * h * w, p**3 * in_channels)) 17 | return x 18 | 19 | def unpatchify(in_channels, x, patch_size, image_size): 20 | """ 21 | x: (N, L, patch_size**3 *4) 22 | imgs: (N, 4, D, H, W) 23 | """ 24 | p = patch_size[0] 25 | d, h, w = image_size 26 | assert h * w * d == x.shape[1] 27 | 28 | x = x.reshape(shape=(x.shape[0], d, h, w, p, p, p, in_channels)) 29 | x = torch.einsum('ndhwkpqc->ncdkhpwq', x) 30 | imgs = x.reshape(shape=(x.shape[0], in_channels, d * p, h * p, h * p)) 31 | return imgs 32 | 33 | def random_masking(x, mask_ratio): 34 | """ 35 | Perform per-sample random masking by per-sample shuffling. 36 | Per-sample shuffling is done by argsort random noise. 37 | x: [N, L, D], sequence 38 | """ 39 | N, L, D = x.shape # batch, length, dim 40 | len_keep = int(L * (1 - mask_ratio)) 41 | 42 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 43 | 44 | # sort noise for each sample 45 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 46 | ids_restore = torch.argsort(ids_shuffle, dim=1) 47 | 48 | # keep the first subset 49 | ids_keep = ids_shuffle[:, :len_keep] 50 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 51 | 52 | # generate the binary mask: 0 is keep, 1 is remove 53 | mask = torch.ones([N, L], device=x.device) 54 | mask[:, :len_keep] = 0 55 | # unshuffle to get the binary mask 56 | mask = torch.gather(mask, dim=1, index=ids_restore) 57 | 58 | return x_masked, mask, ids_restore 59 | 60 | def mask_func(x, in_channels, mask_ratio, patch_size, image_size, mask_value=0.0): 61 | batch = x.shape[0] 62 | x_patch = patchify(in_channels, x, patch_size) 63 | 64 | mask_patch, mask, id = random_masking(x_patch, mask_ratio) 65 | mask_tokens = torch.ones(1, 1, in_channels * patch_size[0] * patch_size[1] * patch_size[2]) * mask_value 66 | device = x.device 67 | mask_tokens = mask_tokens.repeat(batch, id.shape[1] - mask_patch.shape[1], 1) 68 | mask_tokens = mask_tokens.to(device) 69 | 70 | x_ = torch.cat([mask_patch, mask_tokens], dim=1) # no cls token 71 | x_ = torch.gather(x_, dim=1, index=id.unsqueeze(-1).repeat(1, 1, mask_patch.shape[2])) # unshuffle 72 | # mask the input 73 | x = unpatchify(in_channels, x_, patch_size=patch_size, image_size=image_size) 74 | return x, mask 75 | 76 | def get_region_nums(mask_nums, patches_of_region): 77 | assert mask_nums % patches_of_region == 0 78 | return mask_nums // patches_of_region, patches_of_region 79 | 80 | def get_mask_labels(batch_size, num_regions, mask, mask_region_patches, device): 81 | mask_labels = [] 82 | for b in range(batch_size): 83 | mask_label_b = [] 84 | for i in range(num_regions): 85 | mask_label_b.append(mask[b, i*mask_region_patches: (i+1)*mask_region_patches].sum().item()) 86 | mask_labels.append(mask_label_b) 87 | mask_labels = torch.tensor(mask_labels, device=device).long() 88 | 89 | return mask_labels 90 | 91 | def get_mask_labelsv2(batch_size, num_regions, mask, mask_region_patches, device): 92 | mask_labels = torch.zeros(batch_size, num_regions, mask_region_patches).to(device) 93 | for b in range(batch_size): 94 | for i in range(len(mask[b])): 95 | region_i = i // mask_region_patches 96 | patch_i = i % mask_region_patches 97 | mask_labels[b, region_i, patch_i] = mask[b, i] 98 | return mask_labels 99 | 100 | def get_random_patch(img, 101 | downsample_scale, 102 | mask_labels, 103 | patches_of_region): 104 | 105 | device = img.device 106 | batch_size = img.shape[0] 107 | in_channels 
= img.shape[1] 108 | d, w, h = img.shape[2], img.shape[3], img.shape[4] 109 | patch_scale = (d // downsample_scale[0], w // downsample_scale[1], h // downsample_scale[2]) 110 | img = rearrange(img, "b c (p f) (q g) (o h) -> b (f g h) (c p q o)", 111 | p=downsample_scale[0], q=downsample_scale[1], o=downsample_scale[2], 112 | f=patch_scale[0], g=patch_scale[1], h=patch_scale[2]) 113 | rec_patchs = torch.zeros(img.shape[0], 114 | in_channels, 115 | downsample_scale[0], 116 | downsample_scale[1], 117 | downsample_scale[2], 118 | device=device) 119 | index = [] 120 | mask_labels_cpu = mask_labels.cpu().numpy() 121 | 122 | for b in range(batch_size): 123 | no_all_mask_patches = np.argwhere(mask_labels_cpu[b] < patches_of_region).reshape(-1) 124 | # get the random patch index 125 | random_rec_patch_index = no_all_mask_patches[np.random.randint(0, len(no_all_mask_patches))] 126 | index.append(random_rec_patch_index) 127 | rec_patchs[b] = rearrange(img[b, random_rec_patch_index], "(c p q o) -> c p q o", 128 | c=in_channels, 129 | p=downsample_scale[0], 130 | q=downsample_scale[1], 131 | o=downsample_scale[2]) 132 | 133 | return rec_patchs, index 134 | 135 | def get_random_patch_new(img, 136 | downsample_scale,): 137 | 138 | device = img.device 139 | batch_size = img.shape[0] 140 | in_channels = img.shape[1] 141 | patch_images = patchify(in_channels, img, downsample_scale) 142 | num_patchs = patch_images.shape[1] 143 | 144 | rec_patchs = torch.zeros(img.shape[0], 145 | in_channels, 146 | downsample_scale[0], 147 | downsample_scale[1], 148 | downsample_scale[2], 149 | device=device) 150 | index = [] 151 | 152 | for b in range(batch_size): 153 | # get the random patch index 154 | p_sum = 0 155 | while p_sum == 0: 156 | random_index = np.random.randint(0, num_patchs) 157 | random_patch = patch_images[b, random_index] 158 | p_sum = random_patch.sum() 159 | 160 | random_patch = random_patch.reshape(shape=(downsample_scale[0], downsample_scale[1], downsample_scale[2], in_channels)) 161 | random_patch = torch.einsum("hpqc->chpq", random_patch) 162 | 163 | index.append(random_index) 164 | rec_patchs[b] = random_patch 165 | 166 | return rec_patchs, index -------------------------------------------------------------------------------- /Pretrain/optimizers/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
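# Usage sketch for the schedulers defined below (illustrative values only, not taken from the
# repository's training script; compute_loss is a placeholder):
#
#     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.001)
#     scheduler = WarmupCosineSchedule(optimizer, warmup_steps=500, t_total=100000)
#     for step in range(100000):
#         loss = compute_loss(...)
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
#         scheduler.step()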
11 | 12 | import math 13 | import warnings 14 | from typing import List 15 | 16 | from torch import nn as nn 17 | from torch.optim import Adam, Optimizer 18 | from torch.optim.lr_scheduler import LambdaLR, _LRScheduler 19 | 20 | __all__ = ["LinearLR", "ExponentialLR"] 21 | 22 | 23 | class _LRSchedulerMONAI(_LRScheduler): 24 | """Base class for increasing the learning rate between two boundaries over a number 25 | of iterations""" 26 | 27 | def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1) -> None: 28 | """ 29 | Args: 30 | optimizer: wrapped optimizer. 31 | end_lr: the final learning rate. 32 | num_iter: the number of iterations over which the test occurs. 33 | last_epoch: the index of last epoch. 34 | Returns: 35 | None 36 | """ 37 | self.end_lr = end_lr 38 | self.num_iter = num_iter 39 | super(_LRSchedulerMONAI, self).__init__(optimizer, last_epoch) 40 | 41 | 42 | class LinearLR(_LRSchedulerMONAI): 43 | """Linearly increases the learning rate between two boundaries over a number of 44 | iterations. 45 | """ 46 | 47 | def get_lr(self): 48 | r = self.last_epoch / (self.num_iter - 1) 49 | return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs] 50 | 51 | 52 | class ExponentialLR(_LRSchedulerMONAI): 53 | """Exponentially increases the learning rate between two boundaries over a number of 54 | iterations. 55 | """ 56 | 57 | def get_lr(self): 58 | r = self.last_epoch / (self.num_iter - 1) 59 | return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] 60 | 61 | 62 | class WarmupCosineSchedule(LambdaLR): 63 | """Linear warmup and then cosine decay. 64 | Based on https://huggingface.co/ implementation. 65 | """ 66 | 67 | def __init__( 68 | self, optimizer: Optimizer, warmup_steps: int, t_total: int, cycles: float = 0.5, last_epoch: int = -1 69 | ) -> None: 70 | """ 71 | Args: 72 | optimizer: wrapped optimizer. 73 | warmup_steps: number of warmup iterations. 74 | t_total: total number of training iterations. 75 | cycles: cosine cycles parameter. 76 | last_epoch: the index of last epoch. 77 | Returns: 78 | None 79 | """ 80 | self.warmup_steps = warmup_steps 81 | self.t_total = t_total 82 | self.cycles = cycles 83 | super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch) 84 | 85 | def lr_lambda(self, step): 86 | if step < self.warmup_steps: 87 | return float(step) / float(max(1.0, self.warmup_steps)) 88 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) 89 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) 90 | 91 | 92 | class LinearWarmupCosineAnnealingLR(_LRScheduler): 93 | def __init__( 94 | self, 95 | optimizer: Optimizer, 96 | warmup_epochs: int, 97 | max_epochs: int, 98 | warmup_start_lr: float = 0.0, 99 | eta_min: float = 0.0, 100 | last_epoch: int = -1, 101 | ) -> None: 102 | """ 103 | Args: 104 | optimizer (Optimizer): Wrapped optimizer. 105 | warmup_epochs (int): Maximum number of iterations for linear warmup 106 | max_epochs (int): Maximum number of iterations 107 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0. 108 | eta_min (float): Minimum learning rate. Default: 0. 109 | last_epoch (int): The index of last epoch. Default: -1. 
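        Example (illustrative values; ``train_one_epoch`` is a placeholder):
            >>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=100)
            >>> for epoch in range(100):
            ...     train_one_epoch(model, loader, optimizer)
            ...     scheduler.step()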
110 | """ 111 | self.warmup_epochs = warmup_epochs 112 | self.max_epochs = max_epochs 113 | self.warmup_start_lr = warmup_start_lr 114 | self.eta_min = eta_min 115 | 116 | super(LinearWarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch) 117 | 118 | def get_lr(self) -> List[float]: 119 | """ 120 | Compute learning rate using chainable form of the scheduler 121 | """ 122 | if not self._get_lr_called_within_step: 123 | warnings.warn( 124 | "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning 125 | ) 126 | 127 | if self.last_epoch == 0: 128 | return [self.warmup_start_lr] * len(self.base_lrs) 129 | elif self.last_epoch < self.warmup_epochs: 130 | return [ 131 | group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 132 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 133 | ] 134 | elif self.last_epoch == self.warmup_epochs: 135 | return self.base_lrs 136 | elif (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0: 137 | return [ 138 | group["lr"] 139 | + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2 140 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 141 | ] 142 | 143 | return [ 144 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) 145 | / ( 146 | 1 147 | + math.cos( 148 | math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs) 149 | ) 150 | ) 151 | * (group["lr"] - self.eta_min) 152 | + self.eta_min 153 | for group in self.optimizer.param_groups 154 | ] 155 | 156 | def _get_closed_form_lr(self) -> List[float]: 157 | """ 158 | Called when epoch is passed as a param to the `step` function of the scheduler. 
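        In closed form (restating the code below):

            epoch < warmup_epochs:
                lr = warmup_start_lr + epoch * (base_lr - warmup_start_lr) / (warmup_epochs - 1)
            epoch >= warmup_epochs:
                lr = eta_min + 0.5 * (base_lr - eta_min)
                     * (1 + cos(pi * (epoch - warmup_epochs) / (max_epochs - warmup_epochs)))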
159 | """ 160 | if self.last_epoch < self.warmup_epochs: 161 | return [ 162 | self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 163 | for base_lr in self.base_lrs 164 | ] 165 | 166 | return [ 167 | self.eta_min 168 | + 0.5 169 | * (base_lr - self.eta_min) 170 | * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) 171 | for base_lr in self.base_lrs 172 | ] 173 | -------------------------------------------------------------------------------- /Pretrain/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | from monai.data import CacheDataset, DataLoader, Dataset, DistributedSampler, SmartCacheDataset, load_decathlon_datalist 2 | from monai.transforms import ( 3 | AddChanneld, 4 | AsChannelFirstd, 5 | Compose, 6 | CropForegroundd, 7 | LoadImaged, 8 | NormalizeIntensityd, 9 | Orientationd, 10 | RandCropByPosNegLabeld, 11 | RandSpatialCropSamplesd, 12 | ScaleIntensityRanged, 13 | Spacingd, 14 | RandSpatialCropd, 15 | SpatialPadd, 16 | ToTensord, 17 | RandFlipd 18 | ) 19 | import glob 20 | 21 | atm22_paths = glob.glob("/mnt/xingzhaohu/data/ATM22_1/imagesTr/*.nii.gz") + \ 22 | glob.glob("/mnt/xingzhaohu/data/TrainBatch2_new/imagesTr/*.nii.gz") 23 | luna16_paths = glob.glob("/mnt/xingzhaohu/data/luna16_convert/*.nii.gz") 24 | covid19_paths = glob.glob("/mnt/xingzhaohu/data/COVID-19-20_v2/*/*.nii.gz") 25 | flare21_paths = glob.glob("/mnt/xingzhaohu/data/FLARE2021/*.nii.gz") + \ 26 | glob.glob("/mnt/xingzhaohu/data/FLARE2021/ValidationImg/*.nii.gz") 27 | 28 | def build_ATM22(): 29 | data_list = [] 30 | for p in atm22_paths: 31 | data_list.append({"image": p}) 32 | 33 | return data_list 34 | 35 | def bulid_covid19(): 36 | data_list = [] 37 | for p in covid19_paths: 38 | data_list.append({"image": p}) 39 | 40 | return data_list 41 | 42 | def build_flare2021(): 43 | data_list = [] 44 | for p in flare21_paths: 45 | data_list.append({"image": p}) 46 | 47 | return data_list 48 | 49 | def build_luna16(): 50 | data_list = [] 51 | for p in luna16_paths: 52 | data_list.append({"image": p}) 53 | 54 | return data_list 55 | 56 | def random_selected(data: list, n): 57 | import random 58 | total = len(data) 59 | val_data = [] 60 | for i in range(n): 61 | random_i = random.randint(0, total-1) 62 | val_data.append(data[random_i]) 63 | data.pop(random_i) 64 | total = len(data) 65 | 66 | return data, val_data 67 | 68 | 69 | def get_loader(args): 70 | datalist1 = bulid_covid19() 71 | datalist2 = build_ATM22() 72 | datalist3 = build_flare2021() 73 | datalist4 = build_luna16() 74 | 75 | num_workers = 8 76 | print("Dataset 1 covid-19: number of data: {}".format(len(datalist1))) 77 | print("Dataset 2 ATM22: number of data: {}".format(len(datalist2))) 78 | print("Dataset 3 FLARE21: number of data: {}".format(len(datalist3))) 79 | print("Dataset 4 Luna16: number of data: {}".format(len(datalist4))) 80 | 81 | datalist = datalist1 + datalist2 + datalist3 + datalist4 82 | print("Dataset all training: number of data: {}".format(len(datalist))) 83 | 84 | train_data, val_data = random_selected(datalist, 270) 85 | 86 | print(f"training data is {len(train_data)}, validation data is {len(val_data)}") 87 | 88 | train_transforms = Compose( 89 | [ 90 | LoadImaged(keys=["image"]), 91 | AddChanneld(keys=["image"]), 92 | Orientationd(keys=["image"], axcodes="RAS"), 93 | ScaleIntensityRanged( 94 | keys=["image"], a_min=-1000, a_max=1000, b_min=0.001, b_max=1.0, clip=True 95 | ), 96 | 
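            # The three spatial transforms below guarantee that every scan can yield valid
            # 96^3 crops: pad small volumes up to 96 voxels per axis, crop the foreground to a
            # shape divisible by 96, then draw two random 96^3 sub-volumes per scan as the
            # self-supervised training samples.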
SpatialPadd(keys="image", spatial_size=[96, 96, 96]), 97 | CropForegroundd(keys=["image"], source_key="image", k_divisible=[96, 96, 96]), 98 | RandSpatialCropSamplesd( 99 | keys=["image"], 100 | roi_size=[96, 96, 96], 101 | num_samples=2, 102 | random_center=True, 103 | random_size=False, 104 | ), 105 | # RandSpatialCropd(keys=["image"], roi_size=[96, 96, 96], 106 | # random_size=False, 107 | # allow_missing_keys=True), 108 | # RandFlipd(keys=["image"], prob=0.5, spatial_axis=0), 109 | # RandFlipd(keys=["image"], prob=0.5, spatial_axis=1), 110 | # RandFlipd(keys=["image"], prob=0.5, spatial_axis=2), 111 | # RandSpatialCropd(keys=["image", "label"], roi_size=[96, 112 | # 96, 113 | # 96], 114 | 115 | # random_size=False, 116 | # allow_missing_keys=True), 117 | ToTensord(keys=["image"]), 118 | ] 119 | ) 120 | 121 | val_transforms = Compose( 122 | [ 123 | LoadImaged(keys=["image"]), 124 | AddChanneld(keys=["image"]), 125 | Orientationd(keys=["image"], axcodes="RAS"), 126 | ScaleIntensityRanged( 127 | keys=["image"], a_min=-1000, a_max=1000, b_min=0.01, b_max=1.0, clip=True 128 | ), 129 | SpatialPadd(keys="image", spatial_size=[96, 96, 96]), 130 | CropForegroundd(keys=["image"], source_key="image", k_divisible=[96, 96, 96]), 131 | RandSpatialCropSamplesd( 132 | keys=["image"], 133 | roi_size=[96, 96, 96], 134 | num_samples=2, 135 | random_center=True, 136 | random_size=False, 137 | ), 138 | # RandSpatialCropd(keys=["image"], roi_size=[96, 96, 96], 139 | # random_size=False, 140 | # allow_missing_keys=True), 141 | # RandSpatialCropd(keys=["image", "label"], roi_size=[96, 142 | # 96, 143 | # 96], 144 | 145 | # random_size=False, 146 | # allow_missing_keys=True), 147 | ToTensord(keys=["image"]), 148 | ] 149 | ) 150 | 151 | if args.cache_dataset: 152 | print("Using MONAI Cache Dataset") 153 | train_ds = CacheDataset(data=train_data, transform=train_transforms, cache_rate=0.5, num_workers=num_workers) 154 | elif args.smartcache_dataset: 155 | print("Using MONAI SmartCache Dataset") 156 | train_ds = SmartCacheDataset( 157 | data=datalist, 158 | transform=train_transforms, 159 | replace_rate=1.0, 160 | cache_num=8, 161 | ) 162 | else: 163 | print("Using generic dataset") 164 | train_ds = Dataset(data=datalist, transform=train_transforms) 165 | 166 | if args.distributed: 167 | train_sampler = DistributedSampler(dataset=train_ds, even_divisible=True, shuffle=True) 168 | else: 169 | train_sampler = None 170 | train_loader = DataLoader( 171 | train_ds, batch_size=args.batch_size, num_workers=num_workers, sampler=train_sampler, drop_last=True 172 | ) 173 | 174 | if args.val_cache: 175 | val_ds = CacheDataset(data=val_data, transform=val_transforms, num_workers=num_workers) 176 | else : 177 | val_ds = Dataset(data=val_data, transform=val_transforms) 178 | 179 | val_loader = DataLoader( 180 | val_ds, batch_size=args.batch_size, num_workers=8, drop_last=True, shuffle=False 181 | ) 182 | 183 | return train_loader, val_loader -------------------------------------------------------------------------------- /Pretrain/utils/data_utils_first_level_64.py: -------------------------------------------------------------------------------- 1 | from monai.data import CacheDataset, DataLoader, Dataset, DistributedSampler, SmartCacheDataset, load_decathlon_datalist 2 | from monai.transforms import ( 3 | AddChanneld, 4 | AsChannelFirstd, 5 | Compose, 6 | CropForegroundd, 7 | LoadImaged, 8 | NormalizeIntensityd, 9 | Orientationd, 10 | RandCropByPosNegLabeld, 11 | RandSpatialCropSamplesd, 12 | ScaleIntensityRanged, 13 | 
Spacingd, 14 | RandSpatialCropd, 15 | SpatialPadd, 16 | ToTensord, 17 | RandFlipd 18 | ) 19 | import glob 20 | 21 | atm22_paths = glob.glob("/mnt/xingzhaohu/data/ATM22_1/imagesTr/*.nii.gz") + \ 22 | glob.glob("/mnt/xingzhaohu/data/TrainBatch2_new/imagesTr/*.nii.gz") 23 | luna16_paths = glob.glob("/mnt/xingzhaohu/data/luna16_convert/*.nii.gz") 24 | covid19_paths = glob.glob("/mnt/xingzhaohu/data/COVID-19-20_v2/*/*.nii.gz") 25 | flare21_paths = glob.glob("/mnt/xingzhaohu/data/FLARE2021/*.nii.gz") + \ 26 | glob.glob("/mnt/xingzhaohu/data/FLARE2021/ValidationImg/*.nii.gz") 27 | 28 | def build_ATM22(): 29 | data_list = [] 30 | for p in atm22_paths: 31 | data_list.append({"image": p}) 32 | 33 | return data_list 34 | 35 | def bulid_covid19(): 36 | data_list = [] 37 | for p in covid19_paths: 38 | data_list.append({"image": p}) 39 | 40 | return data_list 41 | 42 | def build_flare2021(): 43 | data_list = [] 44 | for p in flare21_paths: 45 | data_list.append({"image": p}) 46 | 47 | return data_list 48 | 49 | def build_luna16(): 50 | data_list = [] 51 | for p in luna16_paths: 52 | data_list.append({"image": p}) 53 | 54 | return data_list 55 | 56 | def random_selected(data: list, n): 57 | import random 58 | total = len(data) 59 | val_data = [] 60 | for i in range(n): 61 | random_i = random.randint(0, total-1) 62 | val_data.append(data[random_i]) 63 | data.pop(random_i) 64 | total = len(data) 65 | 66 | return data, val_data 67 | 68 | 69 | def get_loader(args): 70 | datalist1 = bulid_covid19() 71 | datalist2 = build_ATM22() 72 | datalist3 = build_flare2021() 73 | datalist4 = build_luna16() 74 | 75 | num_workers = 8 76 | print("Dataset 1 covid-19: number of data: {}".format(len(datalist1))) 77 | print("Dataset 2 ATM22: number of data: {}".format(len(datalist2))) 78 | print("Dataset 3 FLARE21: number of data: {}".format(len(datalist3))) 79 | print("Dataset 4 Luna16: number of data: {}".format(len(datalist4))) 80 | 81 | datalist = datalist1 + datalist2 + datalist3 + datalist4 82 | print("Dataset all training: number of data: {}".format(len(datalist))) 83 | 84 | train_data, val_data = random_selected(datalist, 270) 85 | 86 | print(f"training data is {len(train_data)}, validation data is {len(val_data)}") 87 | 88 | train_transforms = Compose( 89 | [ 90 | LoadImaged(keys=["image"]), 91 | AddChanneld(keys=["image"]), 92 | Orientationd(keys=["image"], axcodes="RAS"), 93 | ScaleIntensityRanged( 94 | keys=["image"], a_min=-1000, a_max=1000, b_min=0.001, b_max=1.0, clip=True 95 | ), 96 | SpatialPadd(keys="image", spatial_size=[128, 128, 128]), 97 | CropForegroundd(keys=["image"], source_key="image", k_divisible=[128, 128, 128]), 98 | RandSpatialCropSamplesd( 99 | keys=["image"], 100 | roi_size=[128, 128, 128], 101 | num_samples=2, 102 | random_center=True, 103 | random_size=False, 104 | ), 105 | # RandSpatialCropd(keys=["image"], roi_size=[96, 96, 96], 106 | # random_size=False, 107 | # allow_missing_keys=True), 108 | # RandFlipd(keys=["image"], prob=0.5, spatial_axis=0), 109 | # RandFlipd(keys=["image"], prob=0.5, spatial_axis=1), 110 | # RandFlipd(keys=["image"], prob=0.5, spatial_axis=2), 111 | # RandSpatialCropd(keys=["image", "label"], roi_size=[96, 112 | # 96, 113 | # 96], 114 | 115 | # random_size=False, 116 | # allow_missing_keys=True), 117 | ToTensord(keys=["image"]), 118 | ] 119 | ) 120 | 121 | val_transforms = Compose( 122 | [ 123 | LoadImaged(keys=["image"]), 124 | AddChanneld(keys=["image"]), 125 | Orientationd(keys=["image"], axcodes="RAS"), 126 | ScaleIntensityRanged( 127 | keys=["image"], 
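                # CT intensities are clipped to the [-1000, 1000] HU window and linearly
                # rescaled into [0.01, 1.0] before padding and cropping: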
a_min=-1000, a_max=1000, b_min=0.01, b_max=1.0, clip=True 128 | ), 129 | SpatialPadd(keys="image", spatial_size=[128, 128, 128]), 130 | CropForegroundd(keys=["image"], source_key="image", k_divisible=[128, 128, 128]), 131 | RandSpatialCropSamplesd( 132 | keys=["image"], 133 | roi_size=[128, 128, 128], 134 | num_samples=2, 135 | random_center=True, 136 | random_size=False, 137 | ), 138 | # RandSpatialCropd(keys=["image"], roi_size=[96, 96, 96], 139 | # random_size=False, 140 | # allow_missing_keys=True), 141 | # RandSpatialCropd(keys=["image", "label"], roi_size=[96, 142 | # 96, 143 | # 96], 144 | 145 | # random_size=False, 146 | # allow_missing_keys=True), 147 | ToTensord(keys=["image"]), 148 | ] 149 | ) 150 | 151 | if args.cache_dataset: 152 | print("Using MONAI Cache Dataset") 153 | train_ds = CacheDataset(data=train_data, transform=train_transforms, cache_rate=0.5, num_workers=num_workers) 154 | elif args.smartcache_dataset: 155 | print("Using MONAI SmartCache Dataset") 156 | train_ds = SmartCacheDataset( 157 | data=datalist, 158 | transform=train_transforms, 159 | replace_rate=1.0, 160 | cache_num=8, 161 | ) 162 | else: 163 | print("Using generic dataset") 164 | train_ds = Dataset(data=datalist, transform=train_transforms) 165 | 166 | if args.distributed: 167 | train_sampler = DistributedSampler(dataset=train_ds, even_divisible=True, shuffle=True) 168 | else: 169 | train_sampler = None 170 | train_loader = DataLoader( 171 | train_ds, batch_size=args.batch_size, num_workers=num_workers, sampler=train_sampler, drop_last=True 172 | ) 173 | 174 | if args.val_cache: 175 | val_ds = CacheDataset(data=val_data, transform=val_transforms, num_workers=num_workers) 176 | else : 177 | val_ds = Dataset(data=val_data, transform=val_transforms) 178 | 179 | val_loader = DataLoader( 180 | val_ds, batch_size=args.batch_size, num_workers=8, drop_last=True, shuffle=False 181 | ) 182 | 183 | return train_loader, val_loader -------------------------------------------------------------------------------- /Pretrain/pretrain_models/deep_unet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from multiprocessing import pool 13 | from typing import Sequence, Union 14 | import torch 15 | import torch.nn as nn 16 | from monai.networks.blocks import Convolution, UpSample 17 | from monai.networks.layers.factories import Conv, Pool 18 | from .utils import mask_func 19 | from .utils import get_mask_labels, get_mask_labelsv2 20 | import copy 21 | 22 | class TwoConv(nn.Sequential): 23 | """two convolutions.""" 24 | def __init__( 25 | self, 26 | dim: int, 27 | in_chns: int, 28 | out_chns: int, 29 | act: Union[str, tuple], 30 | norm: Union[str, tuple], 31 | dropout: Union[float, tuple] = 0.0, 32 | ): 33 | """ 34 | Args: 35 | dim: number of spatial dimensions. 36 | in_chns: number of input channels. 37 | out_chns: number of output channels. 38 | act: activation type and arguments. 
39 | norm: feature normalization type and arguments. 40 | dropout: dropout ratio. Defaults to no dropout. 41 | """ 42 | super().__init__() 43 | 44 | conv_0 = Convolution(dim, in_chns, out_chns, act=act, norm=norm, dropout=dropout, padding=1) 45 | conv_1 = Convolution(dim, out_chns, out_chns, act=act, norm=norm, dropout=dropout, padding=1) 46 | self.add_module("conv_0", conv_0) 47 | self.add_module("conv_1", conv_1) 48 | 49 | class Down(nn.Sequential): 50 | """maxpooling downsampling and two convolutions.""" 51 | 52 | def __init__( 53 | self, 54 | dim: int, 55 | in_chns: int, 56 | out_chns: int, 57 | act: Union[str, tuple], 58 | norm: Union[str, tuple], 59 | dropout: Union[float, tuple] = 0.0, 60 | pool_size=(2, 2, 2) 61 | ): 62 | """ 63 | Args: 64 | dim: number of spatial dimensions. 65 | in_chns: number of input channels. 66 | out_chns: number of output channels. 67 | act: activation type and arguments. 68 | norm: feature normalization type and arguments. 69 | dropout: dropout ratio. Defaults to no dropout. 70 | """ 71 | super().__init__() 72 | 73 | max_pooling = Pool["MAX", dim](kernel_size=pool_size) 74 | convs = TwoConv(dim, in_chns, out_chns, act, norm, dropout) 75 | self.add_module("max_pooling", max_pooling) 76 | self.add_module("convs", convs) 77 | 78 | 79 | class UpCat(nn.Module): 80 | """upsampling, concatenation with the encoder feature map, two convolutions""" 81 | 82 | def __init__( 83 | self, 84 | dim: int, 85 | in_chns: int, 86 | cat_chns: int, 87 | out_chns: int, 88 | act: Union[str, tuple], 89 | norm: Union[str, tuple], 90 | dropout: Union[float, tuple] = 0.0, 91 | upsample: str = "deconv", 92 | halves: bool = True, 93 | pool_size = (2, 2, 2) 94 | ): 95 | """ 96 | Args: 97 | dim: number of spatial dimensions. 98 | in_chns: number of input channels to be upsampled. 99 | cat_chns: number of channels from the decoder. 100 | out_chns: number of output channels. 101 | act: activation type and arguments. 102 | norm: feature normalization type and arguments. 103 | dropout: dropout ratio. Defaults to no dropout. 104 | upsample: upsampling mode, available options are 105 | ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``. 106 | halves: whether to halve the number of channels during upsampling. 107 | """ 108 | super().__init__() 109 | 110 | up_chns = in_chns // 2 if halves else in_chns 111 | self.upsample = UpSample(dim, in_chns, up_chns, pool_size, mode=upsample) 112 | self.convs = TwoConv(dim, cat_chns + up_chns, out_chns, act, norm, dropout) 113 | 114 | def forward(self, x: torch.Tensor, x_e: torch.Tensor): 115 | """ 116 | 117 | Args: 118 | x: features to be upsampled. 119 | x_e: features from the encoder. 120 | """ 121 | x_0 = self.upsample(x) 122 | 123 | # handling spatial shapes due to the 2x maxpooling with odd edge lengths. 
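        # When an encoder feature map has an odd edge length, max pooling floors the size, so
        # the upsampled tensor comes back one voxel short along that axis. The loop below
        # compares every spatial dimension of x_0 with the skip connection x_e and
        # replicate-pads the missing voxel so the two tensors can be concatenated.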
124 | dimensions = len(x.shape) - 2 125 | sp = [0] * (dimensions * 2) 126 | for i in range(dimensions): 127 | if x_e.shape[-i - 1] != x_0.shape[-i - 1]: 128 | sp[i * 2 + 1] = 1 129 | x_0 = torch.nn.functional.pad(x_0, sp, "replicate") 130 | 131 | x = self.convs(torch.cat([x_e, x_0], dim=1)) # input channels: (cat_chns + up_chns) 132 | return x 133 | 134 | class DeepUNet(nn.Module): 135 | 136 | def cons_stages(self, pools, region): 137 | stage = [(copy.deepcopy(region[0]), copy.deepcopy(region[1]))] 138 | for pool in pools: 139 | for i, r in enumerate(region): 140 | region[i][0] = region[i][0] * pool[0] 141 | region[i][1] = region[i][1] * pool[1] 142 | region[i][2] = region[i][2] * pool[2] 143 | stage.append((copy.deepcopy(region[0]), copy.deepcopy(region[1]))) 144 | 145 | return stage 146 | 147 | def __init__( 148 | self, 149 | in_channels: int = 1, 150 | out_channels: int = 2, 151 | features: Sequence[int] = (32, 32, 64, 128, 256), 152 | act: Union[str, tuple] = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}), 153 | norm: Union[str, tuple] = ("instance", {"affine": True}), 154 | dropout: Union[float, tuple] = 0.0, 155 | upsample: str = "deconv", 156 | pool_size = ((2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)), 157 | select_reconstruct_region=[[0, 0, 0], [3, 3, 3]], # 重构范围, 158 | pretrain=False, 159 | ): 160 | super().__init__() 161 | deepth = len(pool_size) 162 | self.deepth = deepth 163 | self.in_channels = in_channels 164 | fea = features 165 | down_sample_size = len(pool_size) 166 | print(f"BasicUNet features: {fea}.") 167 | self.select_reconstruct_region = select_reconstruct_region 168 | self.stages = self.cons_stages(pool_size, select_reconstruct_region) 169 | print(f"self.stages is {self.stages}") 170 | self.pretrain = pretrain 171 | ## get patches of region 172 | self.drop = nn.Dropout() 173 | self.conv_0 = TwoConv(3, in_channels, features[0], act, norm, dropout) 174 | 175 | self.downs = nn.ModuleList([]) 176 | 177 | for d in range(deepth): 178 | self.downs.append(Down(3, fea[d], fea[d+1], act=act, norm=norm, pool_size=pool_size[d])) 179 | 180 | self.ups = nn.ModuleList([]) 181 | for d in range(deepth): 182 | self.ups.append(UpCat(3, fea[deepth-d], fea[deepth-d-1], fea[deepth-d-1], act, norm, dropout, pool_size=pool_size[deepth-d-1], upsample=upsample)) 183 | 184 | self.decoder_pred = nn.Conv3d(fea[0], out_channels, 1, 1) 185 | 186 | if pretrain: 187 | bottom_feature = features[-1] 188 | self.pred_mask_region = nn.Linear(bottom_feature, 9)# 一个region 4个 patch 189 | self.contrast_learning_head = nn.Linear(bottom_feature, 384) 190 | self.pred_mask_region_position = nn.Linear(bottom_feature, 8) 191 | 192 | def wrap_feature_selection(self, feature, region_box): 193 | # feature: b, c, d, w, h 194 | return feature[..., region_box[0][0]:region_box[1][0], region_box[0][1]:region_box[1][1], region_box[0][2]:region_box[1][2]] 195 | 196 | def get_local_images(self, images): 197 | images = self.wrap_feature_selection(images, region_box=self.stages[-1]) 198 | return images 199 | 200 | def forward_encoder(self, x): 201 | x = self.conv_0(x) 202 | x_downs = [x] 203 | for d in range(self.deepth): 204 | x = self.downs[d](x) 205 | x_downs.append(x) 206 | return x_downs 207 | 208 | def forward_decoder(self, x_downs): 209 | x = self.wrap_feature_selection(x_downs[-1], self.stages[0]) 210 | 211 | for d in range(self.deepth): 212 | x = self.ups[d](x, self.wrap_feature_selection(x_downs[self.deepth-d-1], self.stages[d+1])) 213 | logits = self.decoder_pred(x) 214 | return logits 215 | 216 | def 
forward(self, x): 217 | device = x.device 218 | images = x.detach() 219 | local_images = self.get_local_images(images) 220 | if self.pretrain: 221 | mask_ratio = torch.clamp(torch.rand(1), 0.4, 0.7) 222 | x, mask = mask_func(x, self.in_channels, mask_ratio, (16, 16, 16), (6, 6, 6), mask_value=0.0) 223 | region_mask_labels = get_mask_labels(x.shape[0], 3*3*3, mask, 2*2*2, device) 224 | region_mask_position = get_mask_labelsv2(x.shape[0], 3*3*3, mask, 2*2*2, device=device) 225 | 226 | x_mask = self.wrap_feature_selection(x, region_box=self.stages[-1]) 227 | 228 | hidden_states_out = self.forward_encoder(x) 229 | logits = self.forward_decoder(hidden_states_out) 230 | 231 | if self.pretrain: 232 | with torch.no_grad(): 233 | hidden_states_out_2 = self.forward_encoder(x) 234 | encode_feature = hidden_states_out[-1] 235 | encode_feature_2 = hidden_states_out_2[-1] 236 | 237 | x4_reshape = encode_feature.flatten(start_dim=2, end_dim=4) 238 | x4_reshape = x4_reshape.transpose(1, 2) 239 | 240 | x4_reshape_2 = encode_feature_2.flatten(start_dim=2, end_dim=4) 241 | x4_reshape_2 = x4_reshape_2.transpose(1, 2) 242 | 243 | contrast_pred = self.contrast_learning_head(x4_reshape.mean(dim=1)) 244 | contrast_pred_2 = self.contrast_learning_head(x4_reshape_2.mean(dim=1)) 245 | 246 | pred_mask_feature = encode_feature.flatten(start_dim=2, end_dim=4) 247 | pred_mask_feature = pred_mask_feature.transpose(1, 2) 248 | mask_region_pred = self.pred_mask_region(pred_mask_feature) 249 | 250 | pred_mask_feature_position = encode_feature.flatten(start_dim=2, end_dim=4) 251 | pred_mask_feature_position = pred_mask_feature_position.transpose(1, 2) 252 | mask_region_position_pred = self.pred_mask_region_position(pred_mask_feature_position) 253 | 254 | return { 255 | "logits": logits, 256 | 'images': local_images, 257 | "pred_mask_region": mask_region_pred, 258 | "pred_mask_region_position": mask_region_position_pred, 259 | "mask_position_lables": region_mask_position, 260 | "mask": mask, 261 | "x_mask": x_mask, 262 | # "patch_size": patch_size, 263 | # "mask_feat_size": mask_feat_size, 264 | # "mask_labels": mask_labels, 265 | 266 | "mask_labels": region_mask_labels, 267 | "contrast_pred_1": contrast_pred, 268 | "contrast_pred_2": contrast_pred_2, 269 | } 270 | else : 271 | return logits 272 | 273 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /Pretrain/pretrain_models/deep_unet_v2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | from einops import rearrange 12 | 13 | from multiprocessing import pool 14 | from typing import Sequence, Union 15 | import torch 16 | import torch.nn as nn 17 | from monai.networks.blocks import Convolution, UpSample 18 | from monai.networks.layers.factories import Conv, Pool 19 | from .utils import mask_func 20 | from .utils import get_mask_labels, get_mask_labelsv2 21 | import copy 22 | 23 | class TwoConv(nn.Sequential): 24 | """two convolutions.""" 25 | def __init__( 26 | self, 27 | dim: int, 28 | in_chns: int, 29 | out_chns: int, 30 | act: Union[str, tuple], 31 | norm: Union[str, tuple], 32 | dropout: Union[float, tuple] = 0.0, 33 | ): 34 | """ 35 | Args: 36 | dim: number of spatial dimensions. 37 | in_chns: number of input channels. 38 | out_chns: number of output channels. 39 | act: activation type and arguments. 40 | norm: feature normalization type and arguments. 41 | dropout: dropout ratio. Defaults to no dropout. 42 | """ 43 | super().__init__() 44 | 45 | conv_0 = Convolution(dim, in_chns, out_chns, act=act, norm=norm, dropout=dropout, padding=1) 46 | conv_1 = Convolution(dim, out_chns, out_chns, act=act, norm=norm, dropout=dropout, padding=1) 47 | self.add_module("conv_0", conv_0) 48 | self.add_module("conv_1", conv_1) 49 | 50 | class Down(nn.Sequential): 51 | """maxpooling downsampling and two convolutions.""" 52 | 53 | def __init__( 54 | self, 55 | dim: int, 56 | in_chns: int, 57 | out_chns: int, 58 | act: Union[str, tuple], 59 | norm: Union[str, tuple], 60 | dropout: Union[float, tuple] = 0.0, 61 | pool_size=(2, 2, 2) 62 | ): 63 | """ 64 | Args: 65 | dim: number of spatial dimensions. 66 | in_chns: number of input channels. 67 | out_chns: number of output channels. 68 | act: activation type and arguments. 69 | norm: feature normalization type and arguments. 70 | dropout: dropout ratio. Defaults to no dropout. 71 | """ 72 | super().__init__() 73 | 74 | max_pooling = Pool["MAX", dim](kernel_size=pool_size) 75 | convs = TwoConv(dim, in_chns, out_chns, act, norm, dropout) 76 | self.add_module("max_pooling", max_pooling) 77 | self.add_module("convs", convs) 78 | 79 | 80 | class UpCat(nn.Module): 81 | """upsampling, concatenation with the encoder feature map, two convolutions""" 82 | 83 | def __init__( 84 | self, 85 | dim: int, 86 | in_chns: int, 87 | cat_chns: int, 88 | out_chns: int, 89 | act: Union[str, tuple], 90 | norm: Union[str, tuple], 91 | dropout: Union[float, tuple] = 0.0, 92 | upsample: str = "deconv", 93 | halves: bool = True, 94 | pool_size = (2, 2, 2) 95 | ): 96 | """ 97 | Args: 98 | dim: number of spatial dimensions. 99 | in_chns: number of input channels to be upsampled. 100 | cat_chns: number of channels from the decoder. 
101 | out_chns: number of output channels. 102 | act: activation type and arguments. 103 | norm: feature normalization type and arguments. 104 | dropout: dropout ratio. Defaults to no dropout. 105 | upsample: upsampling mode, available options are 106 | ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``. 107 | halves: whether to halve the number of channels during upsampling. 108 | """ 109 | super().__init__() 110 | 111 | up_chns = in_chns // 2 if halves else in_chns 112 | self.upsample = UpSample(dim, in_chns, up_chns, pool_size, mode=upsample) 113 | self.convs = TwoConv(dim, cat_chns + up_chns, out_chns, act, norm, dropout) 114 | 115 | def forward(self, x: torch.Tensor, x_e: torch.Tensor): 116 | """ 117 | 118 | Args: 119 | x: features to be upsampled. 120 | x_e: features from the encoder. 121 | """ 122 | x_0 = self.upsample(x) 123 | 124 | # handling spatial shapes due to the 2x maxpooling with odd edge lengths. 125 | dimensions = len(x.shape) - 2 126 | sp = [0] * (dimensions * 2) 127 | for i in range(dimensions): 128 | if x_e.shape[-i - 1] != x_0.shape[-i - 1]: 129 | sp[i * 2 + 1] = 1 130 | x_0 = torch.nn.functional.pad(x_0, sp, "replicate") 131 | 132 | x = self.convs(torch.cat([x_e, x_0], dim=1)) # input channels: (cat_chns + up_chns) 133 | return x 134 | 135 | class DeepUNet(nn.Module): 136 | 137 | def cons_stages(self, pools, region): 138 | stage = [(copy.deepcopy(region[0]), copy.deepcopy(region[1]))] 139 | for pool in reversed(pools): 140 | for i, r in enumerate(region): 141 | region[i][0] = region[i][0] * pool[0] 142 | region[i][1] = region[i][1] * pool[1] 143 | region[i][2] = region[i][2] * pool[2] 144 | stage.append((copy.deepcopy(region[0]), copy.deepcopy(region[1]))) 145 | 146 | return stage 147 | 148 | def __init__( 149 | self, 150 | in_channels: int = 1, 151 | out_channels: int = 2, 152 | features: Sequence[int] = (32, 32, 64, 128, 256), 153 | act: Union[str, tuple] = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}), 154 | norm: Union[str, tuple] = ("instance", {"affine": True}), 155 | dropout: Union[float, tuple] = 0.0, 156 | upsample: str = "deconv", 157 | pool_size = ((2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)), 158 | select_reconstruct_region=[[0, 0, 0], [8, 8, 8]], # 重构范围, 159 | first_level_region = (32, 32, 32), 160 | two_level_region = (16, 16, 16), 161 | pretrain=False, 162 | ): 163 | super().__init__() 164 | deepth = len(pool_size) 165 | self.deepth = deepth 166 | self.in_channels = in_channels 167 | fea = features 168 | print(f"BasicUNet features: {fea}.") 169 | self.select_reconstruct_region = select_reconstruct_region 170 | self.stages = self.cons_stages(pool_size, select_reconstruct_region) 171 | print(f"self.stages is {self.stages}") 172 | self.pool_size_all = self.get_pool_size_all(pool_size) 173 | self.window_size = torch.tensor(first_level_region) // torch.tensor(self.pool_size_all) 174 | print(f"window size is {self.window_size}") 175 | self.pretrain = pretrain 176 | ## get patches of region 177 | self.drop = nn.Dropout() 178 | self.conv_0 = TwoConv(3, in_channels, features[0], act, norm, dropout) 179 | 180 | self.downs = nn.ModuleList([]) 181 | 182 | for d in range(deepth): 183 | self.downs.append(Down(3, fea[d], fea[d+1], act=act, norm=norm, pool_size=pool_size[d])) 184 | 185 | self.ups = nn.ModuleList([]) 186 | for d in range(deepth): 187 | self.ups.append(UpCat(3, fea[deepth-d], fea[deepth-d-1], fea[deepth-d-1], act, norm, dropout, pool_size=pool_size[deepth-d-1], upsample=upsample)) 188 | 189 | self.decoder_pred = nn.Conv3d(fea[0], 
out_channels, 1, 1) 190 | 191 | if pretrain: 192 | bottom_feature = features[-1] 193 | self.pred_mask_region = nn.Linear(bottom_feature, 9)# 一个region 4个 patch 194 | self.contrast_learning_head = nn.Linear(bottom_feature, 384) 195 | self.pred_mask_region_position = nn.Linear(bottom_feature, 8) 196 | 197 | def get_pool_size_all(self, pool_size): 198 | p_all = [1, 1, 1] 199 | for p in pool_size: 200 | p_all[0] = p_all[0] * p[0] 201 | p_all[1] = p_all[1] * p[1] 202 | p_all[2] = p_all[2] * p[2] 203 | return p_all 204 | 205 | def wrap_feature_selection(self, feature, region_box): 206 | # feature: b, c, d, w, h 207 | return feature[..., region_box[0][0]:region_box[1][0], region_box[0][1]:region_box[1][1], region_box[0][2]:region_box[1][2]] 208 | 209 | def get_local_images(self, images): 210 | images = self.wrap_feature_selection(images, region_box=self.stages[-1]) 211 | return images 212 | 213 | def forward_encoder(self, x): 214 | x = self.conv_0(x) 215 | x_downs = [x] 216 | for d in range(self.deepth): 217 | x = self.downs[d](x) 218 | x_downs.append(x) 219 | return x_downs 220 | 221 | def forward_decoder(self, x_downs): 222 | x = self.wrap_feature_selection(x_downs[-1], self.stages[0]) 223 | 224 | for d in range(self.deepth): 225 | x = self.ups[d](x, self.wrap_feature_selection(x_downs[self.deepth-d-1], self.stages[d+1])) 226 | logits = self.decoder_pred(x) 227 | return logits 228 | 229 | def forward(self, x): 230 | device = x.device 231 | images = x.detach() 232 | local_images = self.get_local_images(images) 233 | if self.pretrain: 234 | # mask_ratio = torch.clamp(torch.rand(1), 0.4, 0.75) 235 | mask_ratio = 0.4 236 | x, mask = mask_func(x, self.in_channels, mask_ratio, (16, 16, 16), (6, 6, 6), mask_value=0.0) 237 | region_mask_labels = get_mask_labels(x.shape[0], 3*3*3, mask, 2*2*2, device) 238 | region_mask_position = get_mask_labelsv2(x.shape[0], 3*3*3, mask, 2*2*2, device=device) 239 | 240 | x_mask = self.wrap_feature_selection(x, region_box=self.stages[-1]) 241 | 242 | hidden_states_out = self.forward_encoder(x) 243 | logits = self.forward_decoder(hidden_states_out) 244 | 245 | if self.pretrain: 246 | # print(hidden_states_out.shape) 247 | classifier_hidden_states = rearrange(hidden_states_out[-1], "b c (d m) (w n) (h l) -> b c d w h (m n l)", m=self.window_size[0], n=self.window_size[1], l=self.window_size[2]) 248 | classifier_hidden_states = classifier_hidden_states.mean(dim=-1) 249 | with torch.no_grad(): 250 | hidden_states_out_2 = self.forward_encoder(x) 251 | encode_feature = hidden_states_out[-1] 252 | encode_feature_2 = hidden_states_out_2[-1] 253 | 254 | x4_reshape = encode_feature.flatten(start_dim=2, end_dim=4) 255 | x4_reshape = x4_reshape.transpose(1, 2) 256 | 257 | x4_reshape_2 = encode_feature_2.flatten(start_dim=2, end_dim=4) 258 | x4_reshape_2 = x4_reshape_2.transpose(1, 2) 259 | 260 | contrast_pred = self.contrast_learning_head(x4_reshape.mean(dim=1)) 261 | contrast_pred_2 = self.contrast_learning_head(x4_reshape_2.mean(dim=1)) 262 | 263 | pred_mask_feature = classifier_hidden_states.flatten(start_dim=2, end_dim=4) 264 | pred_mask_feature = pred_mask_feature.transpose(1, 2) 265 | mask_region_pred = self.pred_mask_region(pred_mask_feature) 266 | 267 | pred_mask_feature_position = classifier_hidden_states.flatten(start_dim=2, end_dim=4) 268 | pred_mask_feature_position = pred_mask_feature_position.transpose(1, 2) 269 | mask_region_position_pred = self.pred_mask_region_position(pred_mask_feature_position) 270 | 271 | return { 272 | "logits": logits, 273 | 'images': 
local_images, 274 | "pred_mask_region": mask_region_pred, 275 | "pred_mask_region_position": mask_region_position_pred, 276 | "mask_position_lables": region_mask_position, 277 | "mask": mask, 278 | "x_mask": x_mask, 279 | "mask_labels": region_mask_labels, 280 | "contrast_pred_1": contrast_pred, 281 | "contrast_pred_2": contrast_pred_2, 282 | } 283 | else : 284 | return logits 285 | 286 | -------------------------------------------------------------------------------- /Pretrain/pretrain_models/deep_unet_v2_64_16.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | from einops import rearrange 12 | 13 | from multiprocessing import pool 14 | from typing import Sequence, Union 15 | import torch 16 | import torch.nn as nn 17 | from monai.networks.blocks import Convolution, UpSample 18 | from monai.networks.layers.factories import Conv, Pool 19 | from .utils import mask_func 20 | from .utils import get_mask_labels, get_mask_labelsv2 21 | import copy 22 | 23 | class TwoConv(nn.Sequential): 24 | """two convolutions.""" 25 | def __init__( 26 | self, 27 | dim: int, 28 | in_chns: int, 29 | out_chns: int, 30 | act: Union[str, tuple], 31 | norm: Union[str, tuple], 32 | dropout: Union[float, tuple] = 0.0, 33 | ): 34 | """ 35 | Args: 36 | dim: number of spatial dimensions. 37 | in_chns: number of input channels. 38 | out_chns: number of output channels. 39 | act: activation type and arguments. 40 | norm: feature normalization type and arguments. 41 | dropout: dropout ratio. Defaults to no dropout. 42 | """ 43 | super().__init__() 44 | 45 | conv_0 = Convolution(dim, in_chns, out_chns, act=act, norm=norm, dropout=dropout, padding=1) 46 | conv_1 = Convolution(dim, out_chns, out_chns, act=act, norm=norm, dropout=dropout, padding=1) 47 | self.add_module("conv_0", conv_0) 48 | self.add_module("conv_1", conv_1) 49 | 50 | class Down(nn.Sequential): 51 | """maxpooling downsampling and two convolutions.""" 52 | 53 | def __init__( 54 | self, 55 | dim: int, 56 | in_chns: int, 57 | out_chns: int, 58 | act: Union[str, tuple], 59 | norm: Union[str, tuple], 60 | dropout: Union[float, tuple] = 0.0, 61 | pool_size=(2, 2, 2) 62 | ): 63 | """ 64 | Args: 65 | dim: number of spatial dimensions. 66 | in_chns: number of input channels. 67 | out_chns: number of output channels. 68 | act: activation type and arguments. 69 | norm: feature normalization type and arguments. 70 | dropout: dropout ratio. Defaults to no dropout. 
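Example (editorial sketch, not part of the original file; it assumes a MONAI
version that accepts the positional ``Convolution(dim, in_chns, out_chns, ...)``
call used throughout this repository)::

    >>> import torch
    >>> down = Down(3, in_chns=32, out_chns=64, act="LeakyReLU", norm="instance")
    >>> feat = torch.randn(1, 32, 64, 64, 64)
    >>> down(feat).shape  # 2x max pooling halves each spatial dim; TwoConv changes channels
    torch.Size([1, 64, 32, 32, 32])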
71 | """ 72 | super().__init__() 73 | 74 | max_pooling = Pool["MAX", dim](kernel_size=pool_size) 75 | convs = TwoConv(dim, in_chns, out_chns, act, norm, dropout) 76 | self.add_module("max_pooling", max_pooling) 77 | self.add_module("convs", convs) 78 | 79 | 80 | class UpCat(nn.Module): 81 | """upsampling, concatenation with the encoder feature map, two convolutions""" 82 | 83 | def __init__( 84 | self, 85 | dim: int, 86 | in_chns: int, 87 | cat_chns: int, 88 | out_chns: int, 89 | act: Union[str, tuple], 90 | norm: Union[str, tuple], 91 | dropout: Union[float, tuple] = 0.0, 92 | upsample: str = "deconv", 93 | halves: bool = True, 94 | pool_size = (2, 2, 2) 95 | ): 96 | """ 97 | Args: 98 | dim: number of spatial dimensions. 99 | in_chns: number of input channels to be upsampled. 100 | cat_chns: number of channels from the decoder. 101 | out_chns: number of output channels. 102 | act: activation type and arguments. 103 | norm: feature normalization type and arguments. 104 | dropout: dropout ratio. Defaults to no dropout. 105 | upsample: upsampling mode, available options are 106 | ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``. 107 | halves: whether to halve the number of channels during upsampling. 108 | """ 109 | super().__init__() 110 | 111 | up_chns = in_chns // 2 if halves else in_chns 112 | self.upsample = UpSample(dim, in_chns, up_chns, pool_size, mode=upsample) 113 | self.convs = TwoConv(dim, cat_chns + up_chns, out_chns, act, norm, dropout) 114 | 115 | def forward(self, x: torch.Tensor, x_e: torch.Tensor): 116 | """ 117 | 118 | Args: 119 | x: features to be upsampled. 120 | x_e: features from the encoder. 121 | """ 122 | x_0 = self.upsample(x) 123 | 124 | # handling spatial shapes due to the 2x maxpooling with odd edge lengths. 125 | dimensions = len(x.shape) - 2 126 | sp = [0] * (dimensions * 2) 127 | for i in range(dimensions): 128 | if x_e.shape[-i - 1] != x_0.shape[-i - 1]: 129 | sp[i * 2 + 1] = 1 130 | x_0 = torch.nn.functional.pad(x_0, sp, "replicate") 131 | 132 | x = self.convs(torch.cat([x_e, x_0], dim=1)) # input channels: (cat_chns + up_chns) 133 | return x 134 | 135 | class DeepUNet(nn.Module): 136 | 137 | def cons_stages(self, pools, region): 138 | stage = [(copy.deepcopy(region[0]), copy.deepcopy(region[1]))] 139 | for pool in reversed(pools): 140 | for i, r in enumerate(region): 141 | region[i][0] = region[i][0] * pool[0] 142 | region[i][1] = region[i][1] * pool[1] 143 | region[i][2] = region[i][2] * pool[2] 144 | stage.append((copy.deepcopy(region[0]), copy.deepcopy(region[1]))) 145 | 146 | return stage 147 | 148 | def __init__( 149 | self, 150 | in_channels: int = 1, 151 | out_channels: int = 2, 152 | features: Sequence[int] = (32, 32, 64, 128, 256), 153 | act: Union[str, tuple] = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}), 154 | norm: Union[str, tuple] = ("instance", {"affine": True}), 155 | dropout: Union[float, tuple] = 0.0, 156 | upsample: str = "deconv", 157 | pool_size = ((2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)), 158 | select_reconstruct_region=[[0, 0, 0], [8, 8, 8]], # 重构范围, 159 | first_level_region = (64, 64, 64), 160 | two_level_region = (32, 32, 32), 161 | pretrain=False, 162 | ): 163 | super().__init__() 164 | deepth = len(pool_size) 165 | self.deepth = deepth 166 | self.in_channels = in_channels 167 | fea = features 168 | print(f"BasicUNet features: {fea}.") 169 | self.select_reconstruct_region = select_reconstruct_region 170 | self.stages = self.cons_stages(pool_size, select_reconstruct_region) 171 | print(f"self.stages is 
{self.stages}") 172 | self.pool_size_all = self.get_pool_size_all(pool_size) 173 | self.window_size = torch.tensor(first_level_region) // torch.tensor(self.pool_size_all) 174 | print(f"window size is {self.window_size}") 175 | self.pretrain = pretrain 176 | ## get patches of region 177 | self.drop = nn.Dropout() 178 | self.conv_0 = TwoConv(3, in_channels, features[0], act, norm, dropout) 179 | 180 | self.downs = nn.ModuleList([]) 181 | 182 | for d in range(deepth): 183 | self.downs.append(Down(3, fea[d], fea[d+1], act=act, norm=norm, pool_size=pool_size[d])) 184 | 185 | self.ups = nn.ModuleList([]) 186 | for d in range(deepth): 187 | self.ups.append(UpCat(3, fea[deepth-d], fea[deepth-d-1], fea[deepth-d-1], act, norm, dropout, pool_size=pool_size[deepth-d-1], upsample=upsample)) 188 | 189 | self.decoder_pred = nn.Conv3d(fea[0], out_channels, 1, 1) 190 | 191 | if pretrain: 192 | bottom_feature = features[-1] 193 | self.pred_mask_region = nn.Linear(bottom_feature, 65)# 一个region 4个 patch 194 | self.contrast_learning_head = nn.Linear(bottom_feature, 384) 195 | self.pred_mask_region_position = nn.Linear(bottom_feature, 64) 196 | 197 | def get_pool_size_all(self, pool_size): 198 | p_all = [1, 1, 1] 199 | for p in pool_size: 200 | p_all[0] = p_all[0] * p[0] 201 | p_all[1] = p_all[1] * p[1] 202 | p_all[2] = p_all[2] * p[2] 203 | return p_all 204 | 205 | def wrap_feature_selection(self, feature, region_box): 206 | # feature: b, c, d, w, h 207 | return feature[..., region_box[0][0]:region_box[1][0], region_box[0][1]:region_box[1][1], region_box[0][2]:region_box[1][2]] 208 | 209 | def get_local_images(self, images): 210 | images = self.wrap_feature_selection(images, region_box=self.stages[-1]) 211 | return images 212 | 213 | def forward_encoder(self, x): 214 | x = self.conv_0(x) 215 | x_downs = [x] 216 | for d in range(self.deepth): 217 | x = self.downs[d](x) 218 | x_downs.append(x) 219 | return x_downs 220 | 221 | def forward_decoder(self, x_downs): 222 | x = self.wrap_feature_selection(x_downs[-1], self.stages[0]) 223 | 224 | for d in range(self.deepth): 225 | x = self.ups[d](x, self.wrap_feature_selection(x_downs[self.deepth-d-1], self.stages[d+1])) 226 | logits = self.decoder_pred(x) 227 | return logits 228 | 229 | def forward(self, x): 230 | device = x.device 231 | images = x.detach() 232 | local_images = self.get_local_images(images) 233 | if self.pretrain: 234 | # mask_ratio = torch.clamp(torch.rand(1), 0.4, 0.75) 235 | mask_ratio = 0.4 236 | x, mask = mask_func(x, self.in_channels, mask_ratio, (32, 32, 32), (4, 4, 4), mask_value=0.0) 237 | region_mask_labels = get_mask_labels(x.shape[0], 2*2*2, mask, 4*4*4, device) 238 | region_mask_position = get_mask_labelsv2(x.shape[0], 2*2*2, mask, 4*4*4, device=device) 239 | 240 | x_mask = self.wrap_feature_selection(x, region_box=self.stages[-1]) 241 | 242 | hidden_states_out = self.forward_encoder(x) 243 | logits = self.forward_decoder(hidden_states_out) 244 | 245 | if self.pretrain: 246 | # print(hidden_states_out.shape) 247 | classifier_hidden_states = rearrange(hidden_states_out[-1], "b c (d m) (w n) (h l) -> b c d w h (m n l)", m=self.window_size[0], n=self.window_size[1], l=self.window_size[2]) 248 | classifier_hidden_states = classifier_hidden_states.mean(dim=-1) 249 | with torch.no_grad(): 250 | hidden_states_out_2 = self.forward_encoder(x) 251 | encode_feature = hidden_states_out[-1] 252 | encode_feature_2 = hidden_states_out_2[-1] 253 | 254 | x4_reshape = encode_feature.flatten(start_dim=2, end_dim=4) 255 | x4_reshape = x4_reshape.transpose(1, 
2) 256 | 257 | x4_reshape_2 = encode_feature_2.flatten(start_dim=2, end_dim=4) 258 | x4_reshape_2 = x4_reshape_2.transpose(1, 2) 259 | 260 | contrast_pred = self.contrast_learning_head(x4_reshape.mean(dim=1)) 261 | contrast_pred_2 = self.contrast_learning_head(x4_reshape_2.mean(dim=1)) 262 | 263 | pred_mask_feature = classifier_hidden_states.flatten(start_dim=2, end_dim=4) 264 | pred_mask_feature = pred_mask_feature.transpose(1, 2) 265 | mask_region_pred = self.pred_mask_region(pred_mask_feature) 266 | 267 | pred_mask_feature_position = classifier_hidden_states.flatten(start_dim=2, end_dim=4) 268 | pred_mask_feature_position = pred_mask_feature_position.transpose(1, 2) 269 | mask_region_position_pred = self.pred_mask_region_position(pred_mask_feature_position) 270 | 271 | return { 272 | "logits": logits, 273 | 'images': local_images, 274 | "pred_mask_region": mask_region_pred, 275 | "pred_mask_region_position": mask_region_position_pred, 276 | "mask_position_lables": region_mask_position, 277 | "mask": mask, 278 | "x_mask": x_mask, 279 | "mask_labels": region_mask_labels, 280 | "contrast_pred_1": contrast_pred, 281 | "contrast_pred_2": contrast_pred_2, 282 | } 283 | else : 284 | return logits 285 | 286 | -------------------------------------------------------------------------------- /Pretrain/pretrain_models/deep_unet_v2_64_32.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | from einops import rearrange 12 | 13 | from multiprocessing import pool 14 | from typing import Sequence, Union 15 | import torch 16 | import torch.nn as nn 17 | from monai.networks.blocks import Convolution, UpSample 18 | from monai.networks.layers.factories import Conv, Pool 19 | from .utils import mask_func 20 | from .utils import get_mask_labels, get_mask_labelsv2 21 | import copy 22 | 23 | class TwoConv(nn.Sequential): 24 | """two convolutions.""" 25 | def __init__( 26 | self, 27 | dim: int, 28 | in_chns: int, 29 | out_chns: int, 30 | act: Union[str, tuple], 31 | norm: Union[str, tuple], 32 | dropout: Union[float, tuple] = 0.0, 33 | ): 34 | """ 35 | Args: 36 | dim: number of spatial dimensions. 37 | in_chns: number of input channels. 38 | out_chns: number of output channels. 39 | act: activation type and arguments. 40 | norm: feature normalization type and arguments. 41 | dropout: dropout ratio. Defaults to no dropout. 
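Example (editorial sketch, not part of the original file)::

    >>> import torch
    >>> block = TwoConv(3, in_chns=1, out_chns=32, act="LeakyReLU", norm="instance")
    >>> block(torch.randn(2, 1, 48, 48, 48)).shape  # kernel 3 with padding=1 keeps the spatial size
    torch.Size([2, 32, 48, 48, 48])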
42 | """ 43 | super().__init__() 44 | 45 | conv_0 = Convolution(dim, in_chns, out_chns, act=act, norm=norm, dropout=dropout, padding=1) 46 | conv_1 = Convolution(dim, out_chns, out_chns, act=act, norm=norm, dropout=dropout, padding=1) 47 | self.add_module("conv_0", conv_0) 48 | self.add_module("conv_1", conv_1) 49 | 50 | class Down(nn.Sequential): 51 | """maxpooling downsampling and two convolutions.""" 52 | 53 | def __init__( 54 | self, 55 | dim: int, 56 | in_chns: int, 57 | out_chns: int, 58 | act: Union[str, tuple], 59 | norm: Union[str, tuple], 60 | dropout: Union[float, tuple] = 0.0, 61 | pool_size=(2, 2, 2) 62 | ): 63 | """ 64 | Args: 65 | dim: number of spatial dimensions. 66 | in_chns: number of input channels. 67 | out_chns: number of output channels. 68 | act: activation type and arguments. 69 | norm: feature normalization type and arguments. 70 | dropout: dropout ratio. Defaults to no dropout. 71 | """ 72 | super().__init__() 73 | 74 | max_pooling = Pool["MAX", dim](kernel_size=pool_size) 75 | convs = TwoConv(dim, in_chns, out_chns, act, norm, dropout) 76 | self.add_module("max_pooling", max_pooling) 77 | self.add_module("convs", convs) 78 | 79 | 80 | class UpCat(nn.Module): 81 | """upsampling, concatenation with the encoder feature map, two convolutions""" 82 | 83 | def __init__( 84 | self, 85 | dim: int, 86 | in_chns: int, 87 | cat_chns: int, 88 | out_chns: int, 89 | act: Union[str, tuple], 90 | norm: Union[str, tuple], 91 | dropout: Union[float, tuple] = 0.0, 92 | upsample: str = "deconv", 93 | halves: bool = True, 94 | pool_size = (2, 2, 2) 95 | ): 96 | """ 97 | Args: 98 | dim: number of spatial dimensions. 99 | in_chns: number of input channels to be upsampled. 100 | cat_chns: number of channels from the decoder. 101 | out_chns: number of output channels. 102 | act: activation type and arguments. 103 | norm: feature normalization type and arguments. 104 | dropout: dropout ratio. Defaults to no dropout. 105 | upsample: upsampling mode, available options are 106 | ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``. 107 | halves: whether to halve the number of channels during upsampling. 108 | """ 109 | super().__init__() 110 | 111 | up_chns = in_chns // 2 if halves else in_chns 112 | self.upsample = UpSample(dim, in_chns, up_chns, pool_size, mode=upsample) 113 | self.convs = TwoConv(dim, cat_chns + up_chns, out_chns, act, norm, dropout) 114 | 115 | def forward(self, x: torch.Tensor, x_e: torch.Tensor): 116 | """ 117 | 118 | Args: 119 | x: features to be upsampled. 120 | x_e: features from the encoder. 121 | """ 122 | x_0 = self.upsample(x) 123 | 124 | # handling spatial shapes due to the 2x maxpooling with odd edge lengths. 
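# Editorial note (added comments, not in the original file): `sp` is the
# torch.nn.functional.pad spec, ordered from the last spatial dim backwards as
# (before, after) pairs. Setting sp[i * 2 + 1] = 1 pads one voxel on the trailing
# side of every dim where the upsampled tensor ended up shorter than the encoder
# feature, so the channel-wise concatenation below always sees matching spatial
# shapes; "replicate" repeats the edge values instead of padding with zeros.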
125 | dimensions = len(x.shape) - 2 126 | sp = [0] * (dimensions * 2) 127 | for i in range(dimensions): 128 | if x_e.shape[-i - 1] != x_0.shape[-i - 1]: 129 | sp[i * 2 + 1] = 1 130 | x_0 = torch.nn.functional.pad(x_0, sp, "replicate") 131 | 132 | x = self.convs(torch.cat([x_e, x_0], dim=1)) # input channels: (cat_chns + up_chns) 133 | return x 134 | 135 | class DeepUNet(nn.Module): 136 | 137 | def cons_stages(self, pools, region): 138 | stage = [(copy.deepcopy(region[0]), copy.deepcopy(region[1]))] 139 | for pool in reversed(pools): 140 | for i, r in enumerate(region): 141 | region[i][0] = region[i][0] * pool[0] 142 | region[i][1] = region[i][1] * pool[1] 143 | region[i][2] = region[i][2] * pool[2] 144 | stage.append((copy.deepcopy(region[0]), copy.deepcopy(region[1]))) 145 | 146 | return stage 147 | 148 | def __init__( 149 | self, 150 | in_channels: int = 1, 151 | out_channels: int = 2, 152 | features: Sequence[int] = (32, 32, 64, 128, 256), 153 | act: Union[str, tuple] = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}), 154 | norm: Union[str, tuple] = ("instance", {"affine": True}), 155 | dropout: Union[float, tuple] = 0.0, 156 | upsample: str = "deconv", 157 | pool_size = ((2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)), 158 | select_reconstruct_region=[[0, 0, 0], [8, 8, 8]], # 重构范围, 159 | first_level_region = (64, 64, 64), 160 | two_level_region = (32, 32, 32), 161 | pretrain=False, 162 | ): 163 | super().__init__() 164 | deepth = len(pool_size) 165 | self.deepth = deepth 166 | self.in_channels = in_channels 167 | fea = features 168 | print(f"BasicUNet features: {fea}.") 169 | self.select_reconstruct_region = select_reconstruct_region 170 | self.stages = self.cons_stages(pool_size, select_reconstruct_region) 171 | print(f"self.stages is {self.stages}") 172 | self.pool_size_all = self.get_pool_size_all(pool_size) 173 | self.window_size = torch.tensor(first_level_region) // torch.tensor(self.pool_size_all) 174 | print(f"window size is {self.window_size}") 175 | self.pretrain = pretrain 176 | ## get patches of region 177 | self.drop = nn.Dropout() 178 | self.conv_0 = TwoConv(3, in_channels, features[0], act, norm, dropout) 179 | 180 | self.downs = nn.ModuleList([]) 181 | 182 | for d in range(deepth): 183 | self.downs.append(Down(3, fea[d], fea[d+1], act=act, norm=norm, pool_size=pool_size[d])) 184 | 185 | self.ups = nn.ModuleList([]) 186 | for d in range(deepth): 187 | self.ups.append(UpCat(3, fea[deepth-d], fea[deepth-d-1], fea[deepth-d-1], act, norm, dropout, pool_size=pool_size[deepth-d-1], upsample=upsample)) 188 | 189 | self.decoder_pred = nn.Conv3d(fea[0], out_channels, 1, 1) 190 | 191 | if pretrain: 192 | bottom_feature = features[-1] 193 | self.pred_mask_region = nn.Linear(bottom_feature, 9)# 一个region 4个 patch 194 | self.contrast_learning_head = nn.Linear(bottom_feature, 384) 195 | self.pred_mask_region_position = nn.Linear(bottom_feature, 8) 196 | 197 | def get_pool_size_all(self, pool_size): 198 | p_all = [1, 1, 1] 199 | for p in pool_size: 200 | p_all[0] = p_all[0] * p[0] 201 | p_all[1] = p_all[1] * p[1] 202 | p_all[2] = p_all[2] * p[2] 203 | return p_all 204 | 205 | def wrap_feature_selection(self, feature, region_box): 206 | # feature: b, c, d, w, h 207 | return feature[..., region_box[0][0]:region_box[1][0], region_box[0][1]:region_box[1][1], region_box[0][2]:region_box[1][2]] 208 | 209 | def get_local_images(self, images): 210 | images = self.wrap_feature_selection(images, region_box=self.stages[-1]) 211 | return images 212 | 213 | def forward_encoder(self, x): 214 | x = 
self.conv_0(x) 215 | x_downs = [x] 216 | for d in range(self.deepth): 217 | x = self.downs[d](x) 218 | x_downs.append(x) 219 | return x_downs 220 | 221 | def forward_decoder(self, x_downs): 222 | x = self.wrap_feature_selection(x_downs[-1], self.stages[0]) 223 | 224 | for d in range(self.deepth): 225 | x = self.ups[d](x, self.wrap_feature_selection(x_downs[self.deepth-d-1], self.stages[d+1])) 226 | logits = self.decoder_pred(x) 227 | return logits 228 | 229 | def forward(self, x): 230 | device = x.device 231 | images = x.detach() 232 | local_images = self.get_local_images(images) 233 | if self.pretrain: 234 | # mask_ratio = torch.clamp(torch.rand(1), 0.4, 0.75) 235 | mask_ratio = 0.4 236 | x, mask = mask_func(x, self.in_channels, mask_ratio, (32, 32, 32), (4, 4, 4), mask_value=0.0) 237 | region_mask_labels = get_mask_labels(x.shape[0], 2*2*2, mask, 2*2*2, device) 238 | region_mask_position = get_mask_labelsv2(x.shape[0], 2*2*2, mask, 2*2*2, device=device) 239 | 240 | x_mask = self.wrap_feature_selection(x, region_box=self.stages[-1]) 241 | 242 | hidden_states_out = self.forward_encoder(x) 243 | logits = self.forward_decoder(hidden_states_out) 244 | 245 | if self.pretrain: 246 | # print(hidden_states_out.shape) 247 | classifier_hidden_states = rearrange(hidden_states_out[-1], "b c (d m) (w n) (h l) -> b c d w h (m n l)", m=self.window_size[0], n=self.window_size[1], l=self.window_size[2]) 248 | classifier_hidden_states = classifier_hidden_states.mean(dim=-1) 249 | with torch.no_grad(): 250 | hidden_states_out_2 = self.forward_encoder(x) 251 | encode_feature = hidden_states_out[-1] 252 | encode_feature_2 = hidden_states_out_2[-1] 253 | 254 | x4_reshape = encode_feature.flatten(start_dim=2, end_dim=4) 255 | x4_reshape = x4_reshape.transpose(1, 2) 256 | 257 | x4_reshape_2 = encode_feature_2.flatten(start_dim=2, end_dim=4) 258 | x4_reshape_2 = x4_reshape_2.transpose(1, 2) 259 | 260 | contrast_pred = self.contrast_learning_head(x4_reshape.mean(dim=1)) 261 | contrast_pred_2 = self.contrast_learning_head(x4_reshape_2.mean(dim=1)) 262 | 263 | pred_mask_feature = classifier_hidden_states.flatten(start_dim=2, end_dim=4) 264 | pred_mask_feature = pred_mask_feature.transpose(1, 2) 265 | mask_region_pred = self.pred_mask_region(pred_mask_feature) 266 | 267 | pred_mask_feature_position = classifier_hidden_states.flatten(start_dim=2, end_dim=4) 268 | pred_mask_feature_position = pred_mask_feature_position.transpose(1, 2) 269 | mask_region_position_pred = self.pred_mask_region_position(pred_mask_feature_position) 270 | 271 | return { 272 | "logits": logits, 273 | 'images': local_images, 274 | "pred_mask_region": mask_region_pred, 275 | "pred_mask_region_position": mask_region_position_pred, 276 | "mask_position_lables": region_mask_position, 277 | "mask": mask, 278 | "x_mask": x_mask, 279 | "mask_labels": region_mask_labels, 280 | "contrast_pred_1": contrast_pred, 281 | "contrast_pred_2": contrast_pred_2, 282 | } 283 | else : 284 | return logits 285 | 286 | -------------------------------------------------------------------------------- /Pretrain/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2022 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | import argparse 13 | import os 14 | from time import time 15 | 16 | import numpy as np 17 | import torch 18 | import torch.distributed as dist 19 | import torch.optim as optim 20 | from losses.loss import Loss 21 | from optimizers.lr_scheduler import WarmupCosineSchedule 22 | from torch.cuda.amp import GradScaler, autocast 23 | from torch.nn.parallel import DistributedDataParallel 24 | from torch.utils.tensorboard import SummaryWriter 25 | from utils.data_utils import get_loader 26 | 27 | from losses.loss import * 28 | 29 | mse_loss = torch.nn.MSELoss() 30 | def main(): 31 | def save_ckp(state, checkpoint_dir): 32 | torch.save(state, checkpoint_dir) 33 | 34 | def train(args, global_step, train_loader, best_val=1e8): 35 | 36 | model.train() 37 | 38 | for step, batch in enumerate(train_loader): 39 | t1 = time() 40 | model.train() 41 | 42 | image = batch["image"].cuda() 43 | 44 | model_out = model(image) 45 | x_rec = torch.sigmoid(model_out["logits"]) 46 | labels = model_out['images'] 47 | mask_images = model_out["x_mask"] 48 | pred_mask_region = model_out["pred_mask_region"] 49 | mask_labels = model_out["mask_labels"] 50 | contrast_pred_1 = model_out["contrast_pred_1"] 51 | contrast_pred_2 = model_out["contrast_pred_2"] 52 | pred_mask_region_position = model_out["pred_mask_region_position"] 53 | mask_region_position_label = model_out["mask_position_lables"] 54 | loss_rec = forward_loss_reconstruct_mask(x_rec, labels, mask_images, mask_value=0.0) 55 | # loss_rec = forward_loss_reconstruct(x_rec, labels) 56 | # loss_rec = mse_loss(x_rec, labels) 57 | loss_mask_region = forward_loss_mask(pred_mask_region, mask_labels) 58 | position_pred = (torch.sigmoid(pred_mask_region_position) > 0.5).float() 59 | position_pred_num_region = position_pred.sum(dim=-1) 60 | 61 | loss_consistency = (forward_loss_mask(pred_mask_region, position_pred_num_region.detach()) + nn.MSELoss()(position_pred_num_region, pred_mask_region.argmax(dim=-1).float().detach())) / 2 62 | 63 | loss_contrast = forward_constrast_loss(contrast_pred_1, contrast_pred_2) 64 | loss_position = forward_loss_mask_position(pred_mask_region_position, mask_region_position_label) 65 | 66 | if args.distributed: 67 | if args.rank == 0: 68 | print(f"Step:{global_step}/{args.num_steps}, loss_rec is {loss_rec}, loss_mask_num is {loss_mask_region}, loss_mask_position is {loss_position}, loss_consistency is {loss_consistency}") 69 | else : 70 | print(f"loss_rec is {loss_rec}") 71 | 72 | loss = loss_rec + 0.1 * loss_mask_region + 0.1 * loss_position + 0.01 * loss_consistency + 0.1 * loss_contrast 73 | loss.backward() 74 | if args.grad_clip: 75 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 76 | optimizer.step() 77 | 78 | if args.lrdecay: 79 | scheduler.step() 80 | optimizer.zero_grad() 81 | if args.distributed: 82 | if dist.get_rank() == 0: 83 | print("Step:{}/{}, Loss:{:.4f}, Time:{:.4f}".format(global_step, args.num_steps, loss, time() - t1)) 84 | else: 85 | print("Step:{}/{}, Loss:{:.4f}, Time:{:.4f}".format(global_step, args.num_steps, loss, time() - t1)) 86 | 87 | global_step += 1 88 | 89 | 
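# Editorial note (added comments, not in the original file): the block below is
# the periodic evaluation hook. On rank 0 (or always, in single-process runs),
# every `eval_num` steps it runs `validation`, writes the scalar losses and the
# reconstruction images to TensorBoard, and saves `model_bestValRMSE.pt`
# whenever the validation reconstruction loss improves on `best_val`.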
90 | if args.distributed: 91 | val_cond = (dist.get_rank() == 0) and (global_step % args.eval_num == 0) 92 | else: 93 | val_cond = global_step % args.eval_num == 0 94 | 95 | if val_cond: 96 | val_losses, img_list = validation(args, test_loader) 97 | writer.add_scalar("Validation/loss_total", scalar_value=val_losses[0], global_step=global_step) 98 | writer.add_scalar("Validation/loss_recon", scalar_value=val_losses[1], global_step=global_step) 99 | writer.add_scalar("Validation/loss_num", scalar_value=val_losses[2], global_step=global_step) 100 | writer.add_scalar("Validation/loss_position", scalar_value=val_losses[3], global_step=global_step) 101 | writer.add_scalar("Validation/loss_cl", scalar_value=val_losses[4], global_step=global_step) 102 | writer.add_scalar("train/loss_total", scalar_value=np.mean(loss), global_step=global_step) 103 | writer.add_scalar("train/loss_recon", scalar_value=np.mean(loss_rec), global_step=global_step) 104 | 105 | writer.add_image("Validation/x1_row", img_list[0], global_step, dataformats="HW") 106 | writer.add_image("Validation/x1_gt", img_list[1], global_step, dataformats="HW") 107 | writer.add_image("Validation/x1_recon", img_list[2], global_step, dataformats="HW") 108 | writer.add_image("Validation/x1_mask", img_list[3], global_step, dataformats="HW") 109 | writer.add_image("Validation/x1_recon_raw", img_list[4], global_step, dataformats="HW") 110 | 111 | val_loss_recon = val_losses[1] 112 | if val_loss_recon < best_val: 113 | best_val = val_loss_recon 114 | 115 | torch.save(model.state_dict(), logdir + "/model_bestValRMSE.pt") 116 | print( 117 | "Model was saved ! Best Recon. Val Loss: {:.4f}, Recon. Val Loss: {:.4f}".format( 118 | best_val, val_loss_recon 119 | ) 120 | ) 121 | else: 122 | print( 123 | "Model was not saved ! Best Recon. Val Loss: {:.4f} Recon. 
Val Loss: {:.4f}".format( 124 | best_val, val_loss_recon 125 | ) 126 | ) 127 | return global_step, best_val 128 | 129 | def validation(args, test_loader): 130 | model.eval() 131 | loss_val = [] 132 | loss_val_recon = [] 133 | loss_num_total = [] 134 | loss_position_total = [] 135 | loss_cl_total = [] 136 | with torch.no_grad(): 137 | for step, batch in enumerate(test_loader): 138 | val_inputs = batch["image"].cuda() 139 | model_out = model(val_inputs) 140 | 141 | x_rec = torch.sigmoid(model_out["logits"]) 142 | labels = model_out['images'] 143 | mask_images = model_out["x_mask"] 144 | pred_mask_region = model_out["pred_mask_region"] 145 | mask_labels = model_out["mask_labels"] 146 | contrast_pred_1 = model_out["contrast_pred_1"] 147 | contrast_pred_2 = model_out["contrast_pred_2"] 148 | pred_mask_region_position = model_out["pred_mask_region_position"] 149 | mask_region_position_label = model_out["mask_position_lables"] 150 | loss_rec = forward_loss_reconstruct_mask(x_rec, labels, mask_images, mask_value=0.0) 151 | # loss_rec = forward_loss_reconstruct(x_rec, labels) 152 | loss_rec = mse_loss(x_rec, labels) 153 | 154 | num_pred = pred_mask_region.argmax(dim=-1) 155 | 156 | num_acc = (num_pred == mask_labels).float().sum() / (num_pred.shape[0] * num_pred.shape[1]) 157 | 158 | 159 | loss_mask_region = forward_loss_mask(pred_mask_region, mask_labels) 160 | position_pred = (torch.sigmoid(pred_mask_region_position) > 0.5).float() 161 | position_pred_num_region = position_pred.sum(dim=-1) 162 | position_acc = (position_pred == mask_region_position_label).float().sum() / (position_pred.shape[0] * position_pred.shape[1] * position_pred.shape[2]) 163 | 164 | loss_consistency = (forward_loss_mask(pred_mask_region, position_pred_num_region.detach()) + nn.MSELoss()(position_pred_num_region, pred_mask_region.argmax(dim=-1).float().detach())) / 2 165 | 166 | loss_contrast = forward_constrast_loss(contrast_pred_1, contrast_pred_2) 167 | loss_position = forward_loss_mask_position(pred_mask_region_position, mask_region_position_label) 168 | 169 | loss = loss_rec + 0.1 * loss_mask_region + 0.1 * loss_position + 0.001 * loss_consistency + 0.001 * loss_contrast 170 | loss_val.append(loss.item()) 171 | loss_val_recon.append(loss_rec.item()) 172 | loss_num_total.append(num_acc.item()) 173 | loss_position_total.append(position_acc.item()) 174 | loss_cl_total.append(loss_contrast.item()) 175 | 176 | x_gt = labels.detach().cpu().numpy() 177 | x_gt = (x_gt - np.min(x_gt)) / (np.max(x_gt) - np.min(x_gt)) 178 | xgt = x_gt[0][0][:, :, 48] * 255.0 179 | xgt = xgt.astype(np.uint8) 180 | 181 | x_mask = mask_images.detach().cpu().numpy() 182 | x_mask = (x_mask - np.min(x_mask)) / (np.max(x_mask) - np.min(x_mask)) 183 | x_mask = x_mask[0][0][:, :, 48] * 255.0 184 | x_mask = x_mask.astype(np.uint8) 185 | 186 | x_row = val_inputs.detach().cpu().numpy() 187 | x_row = (x_row - np.min(x_row)) / (np.max(x_row) - np.min(x_row)) 188 | x_row = x_row[0][0][:, :, 48] * 255.0 189 | x_row = x_row.astype(np.uint8) 190 | 191 | mask = (mask_images != 0.0).float() 192 | rec_x1 = mask_images.detach().cpu().numpy() + (1 - mask) * x_rec.detach().cpu().numpy() 193 | rec_x1 = (rec_x1 - np.min(rec_x1)) / (np.max(rec_x1) - np.min(rec_x1)) 194 | recon = rec_x1[0][0][:, :, 48] * 255.0 195 | recon = recon.astype(np.uint8) 196 | 197 | rec_row = x_rec.detach().cpu().numpy() 198 | rec_row = (rec_row - np.min(rec_row)) / (np.max(rec_row) - np.min(rec_row)) 199 | rec_row = rec_row[0][0][:, :, 48] * 255.0 200 | rec_row = rec_row.astype(np.uint8) 201 | 202 
| img_list = [x_row, xgt, recon, x_mask, rec_row] 203 | 204 | print("Validation step:{}, Loss:{:.4f}, Loss Reconstruction:{:.4f}".format(step, loss, loss_rec)) 205 | 206 | return [np.mean(loss_val), np.mean(loss_val_recon), np.mean(loss_num_total), np.mean(loss_position_total), np.mean(loss_cl_total)], img_list 207 | 208 | parser = argparse.ArgumentParser(description="PyTorch Training") 209 | parser.add_argument("--model_name", default="deepunet_v2", type=str, help="directory to save the tensorboard logs") 210 | parser.add_argument("--logdir", default="test", type=str, help="directory to save the tensorboard logs") 211 | parser.add_argument("--epochs", default=200, type=int, help="number of training epochs") 212 | parser.add_argument("--num_steps", default=100000, type=int, help="number of training iterations") 213 | parser.add_argument("--eval_num", default=100, type=int, help="evaluation frequency") 214 | parser.add_argument("--warmup_steps", default=500, type=int, help="warmup steps") 215 | parser.add_argument("--in_channels", default=1, type=int, help="number of input channels") 216 | parser.add_argument("--feature_size", default=48, type=int, help="embedding size") 217 | parser.add_argument("--dropout_path_rate", default=0.0, type=float, help="drop path rate") 218 | parser.add_argument("--use_checkpoint", action="store_true", help="use gradient checkpointing to save memory") 219 | parser.add_argument("--spatial_dims", default=3, type=int, help="spatial dimension of input data") 220 | parser.add_argument("--a_min", default=-1000, type=float, help="a_min in ScaleIntensityRanged") 221 | parser.add_argument("--a_max", default=1000, type=float, help="a_max in ScaleIntensityRanged") 222 | parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged") 223 | parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged") 224 | parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction") 225 | parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction") 226 | parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction") 227 | parser.add_argument("--roi_x", default=96, type=int, help="roi size in x direction") 228 | parser.add_argument("--roi_y", default=96, type=int, help="roi size in y direction") 229 | parser.add_argument("--roi_z", default=96, type=int, help="roi size in z direction") 230 | parser.add_argument("--batch_size", default=2, type=int, help="number of batch size") 231 | parser.add_argument("--sw_batch_size", default=2, type=int, help="number of sliding window batch size") 232 | parser.add_argument("--lr", default=4e-4, type=float, help="learning rate") 233 | parser.add_argument("--decay", default=0.1, type=float, help="decay rate") 234 | parser.add_argument("--momentum", default=0.9, type=float, help="momentum") 235 | parser.add_argument("--lrdecay", action="store_true", help="enable learning rate decay") 236 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="maximum gradient norm") 237 | parser.add_argument("--loss_type", default="SSL", type=str) 238 | parser.add_argument("--opt", default="adamw", type=str, help="optimization algorithm") 239 | parser.add_argument("--lr_schedule", default="warmup_cosine", type=str) 240 | parser.add_argument("--resume", default=None, type=str, help="resume training") 241 | parser.add_argument("--local_rank", type=int, default=0, help="local rank") 242 | 
parser.add_argument("--grad_clip", action="store_true", help="gradient clip") 243 | parser.add_argument("--noamp", action="store_true", help="do NOT use amp for training") 244 | parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training") 245 | parser.add_argument("--smartcache_dataset", action="store_true", help="use monai smartcache Dataset") 246 | parser.add_argument("--cache_dataset", action="store_true", help="use monai cache Dataset") 247 | parser.add_argument("--val_cache", action="store_true", help="use monai cache Dataset") 248 | 249 | args = parser.parse_args() 250 | logdir = "./runs/" + args.logdir 251 | args.amp = not args.noamp 252 | torch.backends.cudnn.benchmark = True 253 | torch.autograd.set_detect_anomaly(True) 254 | args.distributed = False 255 | if "WORLD_SIZE" in os.environ: 256 | args.distributed = int(os.environ["WORLD_SIZE"]) > 1 257 | args.device = "cuda:0" 258 | args.world_size = 1 259 | args.rank = 0 260 | 261 | if args.distributed: 262 | args.device = "cuda:%d" % args.local_rank 263 | torch.cuda.set_device(args.local_rank) 264 | torch.distributed.init_process_group(backend="nccl", init_method=args.dist_url) 265 | args.world_size = torch.distributed.get_world_size() 266 | args.rank = torch.distributed.get_rank() 267 | print( 268 | "Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d." 269 | % (args.rank, args.world_size) 270 | ) 271 | else: 272 | print("Training with a single process on 1 GPUs.") 273 | assert args.rank >= 0 274 | 275 | if args.rank == 0: 276 | os.makedirs(logdir, exist_ok=True) 277 | writer = SummaryWriter(logdir) 278 | else: 279 | writer = None 280 | 281 | if args.model_name == "deepunet_v2": 282 | from pretrain_models.deep_unet_v2 import DeepUNet 283 | model = DeepUNet(1, 1, features=[64, 64, 128, 256, 512], 284 | pool_size=((2, 2, 2), (2, 2, 2), (2, 2, 2), (1, 1, 1)), 285 | select_reconstruct_region=[[4, 4, 4], [12, 12, 12]], 286 | dropout=0.1, 287 | pretrain=True) 288 | 289 | else : 290 | from pretrain_models.swinunetr import SwinUNETR 291 | model = SwinUNETR((96, 96, 96), 292 | in_channels=1, 293 | out_channels=1, 294 | drop_rate=0.1, 295 | feature_size=args.feature_size, 296 | pretrain=True, 297 | select_reconstruct_region=(1, 3)) 298 | 299 | model.cuda() 300 | 301 | if args.opt == "adam": 302 | optimizer = optim.Adam(params=model.parameters(), lr=args.lr, weight_decay=args.decay) 303 | 304 | elif args.opt == "adamw": 305 | optimizer = optim.AdamW(params=model.parameters(), lr=args.lr, weight_decay=args.decay) 306 | 307 | elif args.opt == "sgd": 308 | optimizer = optim.SGD(params=model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.decay) 309 | 310 | if args.resume: 311 | model_pth = args.resume 312 | model_dict = torch.load(model_pth) 313 | model.load_state_dict(model_dict["state_dict"]) 314 | model.epoch = model_dict["epoch"] 315 | model.optimizer = model_dict["optimizer"] 316 | 317 | if args.lrdecay: 318 | if args.lr_schedule == "warmup_cosine": 319 | scheduler = WarmupCosineSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=args.num_steps) 320 | 321 | elif args.lr_schedule == "poly": 322 | 323 | def lambdas(epoch): 324 | return (1 - float(epoch) / float(args.epochs)) ** 0.9 325 | 326 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambdas) 327 | 328 | if args.distributed: 329 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 330 | model = DistributedDataParallel(model, 
device_ids=[args.local_rank], find_unused_parameters=True) 331 | train_loader, test_loader = get_loader(args) 332 | 333 | global_step = 0 334 | best_val = 1e8 335 | while global_step < args.num_steps: 336 | global_step, best_val = train(args, global_step, train_loader, best_val=best_val) 337 | 338 | if args.distributed: 339 | if dist.get_rank() == 0: 340 | torch.save(model.state_dict(), logdir + "final_model.pth") 341 | dist.destroy_process_group() 342 | else: 343 | torch.save(model.state_dict(), logdir + "final_model.pth") 344 | 345 | 346 | if __name__ == "__main__": 347 | main() -------------------------------------------------------------------------------- /Pretrain/pretrain_models/swinunetr.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from typing import Sequence, Tuple, Type, Union 13 | 14 | import numpy as np 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import torch.utils.checkpoint as checkpoint 19 | from torch.nn import LayerNorm 20 | 21 | from monai.networks.blocks import MLPBlock as Mlp 22 | from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock 23 | from monai.networks.layers import DropPath, trunc_normal_ 24 | from monai.utils import ensure_tuple_rep, optional_import 25 | from .utils import mask_func 26 | from .utils import get_mask_labels, get_mask_labelsv2 27 | 28 | rearrange, _ = optional_import("einops", name="rearrange") 29 | 30 | class SwinUNETR(nn.Module): 31 | """ 32 | Swin UNETR based on: "Hatamizadeh et al., 33 | Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images 34 | " 35 | """ 36 | 37 | def __init__( 38 | self, 39 | img_size: Union[Sequence[int], int], 40 | in_channels: int, 41 | out_channels: int, 42 | depths: Sequence[int] = (2, 2, 2, 2), 43 | num_heads: Sequence[int] = (3, 6, 12, 24), 44 | feature_size: int = 24, 45 | norm_name: Union[Tuple, str] = "instance", 46 | drop_rate: float = 0.0, 47 | attn_drop_rate: float = 0.0, 48 | dropout_path_rate: float = 0.0, 49 | normalize: bool = True, 50 | use_checkpoint: bool = False, 51 | spatial_dims: int = 3, 52 | pretrain=False, 53 | select_reconstruct_region=(0, 3), 54 | ) -> None: 55 | """ 56 | Args: 57 | img_size: dimension of input image. 58 | in_channels: dimension of input channels. 59 | out_channels: dimension of output channels. 60 | feature_size: dimension of network feature size. 61 | depths: number of layers in each stage. 62 | num_heads: number of attention heads. 63 | norm_name: feature normalization type and arguments. 64 | drop_rate: dropout rate. 65 | attn_drop_rate: attention dropout rate. 66 | dropout_path_rate: drop path rate. 67 | normalize: normalize output intermediate features in each stage. 68 | use_checkpoint: use gradient checkpointing for reduced memory usage. 69 | spatial_dims: number of spatial dims. 
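pretrain: (editorial addition, inferred from the forward pass below) if True,
    mask the input volume with ``mask_func`` and return the HybridMIM
    pre-training dictionary (reconstruction logits, masked-patch-count and
    mask-position predictions, contrastive embeddings) instead of plain logits.
select_reconstruct_region: (editorial addition, inferred from
    ``wrap_feature_selection``) start/end indices of the reconstruction window
    at the coarsest resolution; the window is doubled at each higher-resolution
    stage and only this cropped region is decoded and reconstructed.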
70 | 71 | Examples:: 72 | 73 | # for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48. 74 | >>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48) 75 | 76 | # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage. 77 | >>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2)) 78 | 79 | # for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing. 80 | >>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2) 81 | 82 | """ 83 | 84 | super().__init__() 85 | self.in_channels = in_channels 86 | downsample_size = 5 87 | img_size = ensure_tuple_rep(img_size, spatial_dims) 88 | patch_size = ensure_tuple_rep(2, spatial_dims) 89 | window_size = ensure_tuple_rep(7, spatial_dims) 90 | self.select_reconstruct_region = select_reconstruct_region 91 | self.stages = [] 92 | for i in range(downsample_size+1): 93 | self.stages.append((select_reconstruct_region[0] * 2**i, select_reconstruct_region[1] * 2**i)) 94 | print(self.stages) 95 | self.pretrain = pretrain 96 | if not (spatial_dims == 2 or spatial_dims == 3): 97 | raise ValueError("spatial dimension should be 2 or 3.") 98 | 99 | for m, p in zip(img_size, patch_size): 100 | for i in range(5): 101 | if m % np.power(p, i + 1) != 0: 102 | raise ValueError("input image size (img_size) should be divisible by stage-wise image resolution.") 103 | 104 | if not (0 <= drop_rate <= 1): 105 | raise ValueError("dropout rate should be between 0 and 1.") 106 | 107 | if not (0 <= attn_drop_rate <= 1): 108 | raise ValueError("attention dropout rate should be between 0 and 1.") 109 | 110 | if not (0 <= dropout_path_rate <= 1): 111 | raise ValueError("drop path rate should be between 0 and 1.") 112 | 113 | if feature_size % 12 != 0: 114 | raise ValueError("feature_size should be divisible by 12.") 115 | 116 | self.normalize = normalize 117 | 118 | self.swinViT = SwinTransformer( 119 | in_chans=in_channels, 120 | embed_dim=feature_size, 121 | window_size=window_size, 122 | patch_size=patch_size, 123 | depths=depths, 124 | num_heads=num_heads, 125 | mlp_ratio=4.0, 126 | qkv_bias=True, 127 | drop_rate=drop_rate, 128 | attn_drop_rate=attn_drop_rate, 129 | drop_path_rate=dropout_path_rate, 130 | norm_layer=nn.LayerNorm, 131 | use_checkpoint=use_checkpoint, 132 | spatial_dims=spatial_dims, 133 | ) 134 | 135 | self.encoder1 = UnetrBasicBlock( 136 | spatial_dims=spatial_dims, 137 | in_channels=in_channels, 138 | out_channels=feature_size, 139 | kernel_size=3, 140 | stride=1, 141 | norm_name=norm_name, 142 | res_block=True, 143 | ) 144 | 145 | self.encoder2 = UnetrBasicBlock( 146 | spatial_dims=spatial_dims, 147 | in_channels=feature_size, 148 | out_channels=feature_size, 149 | kernel_size=3, 150 | stride=1, 151 | norm_name=norm_name, 152 | res_block=True, 153 | ) 154 | 155 | self.encoder3 = UnetrBasicBlock( 156 | spatial_dims=spatial_dims, 157 | in_channels=2 * feature_size, 158 | out_channels=2 * feature_size, 159 | kernel_size=3, 160 | stride=1, 161 | norm_name=norm_name, 162 | res_block=True, 163 | ) 164 | 165 | self.encoder4 = UnetrBasicBlock( 166 | spatial_dims=spatial_dims, 167 | in_channels=4 * feature_size, 168 | out_channels=4 * feature_size, 169 | kernel_size=3, 170 | stride=1, 171 | norm_name=norm_name, 172 | res_block=True, 173 | ) 174 | 175 | self.encoder10 = UnetrBasicBlock( 176 | spatial_dims=spatial_dims, 177 | in_channels=16 * 
feature_size, 178 | out_channels=16 * feature_size, 179 | kernel_size=3, 180 | stride=1, 181 | norm_name=norm_name, 182 | res_block=True, 183 | ) 184 | 185 | self.decoder5 = UnetrUpBlock( 186 | spatial_dims=spatial_dims, 187 | in_channels=16 * feature_size, 188 | out_channels=8 * feature_size, 189 | kernel_size=3, 190 | upsample_kernel_size=2, 191 | norm_name=norm_name, 192 | res_block=True, 193 | ) 194 | 195 | self.decoder4 = UnetrUpBlock( 196 | spatial_dims=spatial_dims, 197 | in_channels=feature_size * 8, 198 | out_channels=feature_size * 4, 199 | kernel_size=3, 200 | upsample_kernel_size=2, 201 | norm_name=norm_name, 202 | res_block=True, 203 | ) 204 | 205 | self.decoder3 = UnetrUpBlock( 206 | spatial_dims=spatial_dims, 207 | in_channels=feature_size * 4, 208 | out_channels=feature_size * 2, 209 | kernel_size=3, 210 | upsample_kernel_size=2, 211 | norm_name=norm_name, 212 | res_block=True, 213 | ) 214 | self.decoder2 = UnetrUpBlock( 215 | spatial_dims=spatial_dims, 216 | in_channels=feature_size * 2, 217 | out_channels=feature_size, 218 | kernel_size=3, 219 | upsample_kernel_size=2, 220 | norm_name=norm_name, 221 | res_block=True, 222 | ) 223 | 224 | self.decoder1 = UnetrUpBlock( 225 | spatial_dims=spatial_dims, 226 | in_channels=feature_size, 227 | out_channels=feature_size, 228 | kernel_size=3, 229 | upsample_kernel_size=2, 230 | norm_name=norm_name, 231 | res_block=True, 232 | ) 233 | 234 | self.out = UnetOutBlock( 235 | spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels 236 | ) # type: ignore 237 | if pretrain: 238 | if feature_size == 24: 239 | self.pred_mask_region = nn.Linear(384, 9)# 一个region 8个 patch 240 | self.contrast_learning_head = nn.Linear(384, 384) 241 | else: 242 | self.pred_mask_region = nn.Linear(768, 9)# 一个region 8个 patch 243 | self.contrast_learning_head = nn.Linear(768, 384) 244 | self.pred_mask_region_position = nn.Linear(768, 8) 245 | 246 | def wrap_feature_selection(self, feature, region_box): 247 | # feature: b, c, d, w, h 248 | return feature[..., region_box[0]:region_box[1], region_box[0]:region_box[1], region_box[0]:region_box[1]] 249 | 250 | def get_local_images(self, images): 251 | images = self.wrap_feature_selection(images, region_box=self.stages[5]) 252 | return images 253 | 254 | def forward(self, x_in): 255 | device = x_in.device 256 | images = x_in.detach() 257 | if self.pretrain: 258 | # mask_ratio = torch.clamp(torch.rand(1), 0.2, 0.4) 259 | mask_ratio = 0.4 260 | # x_in, mask = mask_func(x_in, self.in_channels, mask_ratio, (32, 32, 32), (4, 4, 4)) 261 | x_in, mask = mask_func(x_in, self.in_channels, mask_ratio, (16, 16, 16), (6, 6, 6)) 262 | 263 | region_mask_labels = get_mask_labels(x_in.shape[0], 3*3*3, mask, 2*2*2, device) 264 | region_mask_position = get_mask_labelsv2(x_in.shape[0], 3*3*3, mask, 2*2*2, device=device) 265 | 266 | hidden_states_out = self.swinViT(x_in, self.normalize) 267 | local_images = self.get_local_images(images) 268 | return_x_in = self.wrap_feature_selection(x_in, region_box=self.stages[5]) 269 | 270 | enc0 = self.encoder1(self.wrap_feature_selection(x_in, region_box=self.stages[5])) 271 | enc1 = self.encoder2(self.wrap_feature_selection(hidden_states_out[0], region_box=self.stages[4])) 272 | enc2 = self.encoder3(self.wrap_feature_selection(hidden_states_out[1], region_box=self.stages[3])) 273 | enc3 = self.encoder4(self.wrap_feature_selection(hidden_states_out[2], region_box=self.stages[2])) 274 | dec4 = self.encoder10(self.wrap_feature_selection(hidden_states_out[4], 
region_box=self.stages[0])) 275 | dec3 = self.decoder5(dec4, self.wrap_feature_selection(hidden_states_out[3], region_box=self.stages[1])) 276 | dec2 = self.decoder4(dec3, enc3) 277 | dec1 = self.decoder3(dec2, enc2) 278 | dec0 = self.decoder2(dec1, enc1) 279 | out = self.decoder1(dec0, enc0) 280 | logits = self.out(out) 281 | 282 | if self.pretrain: 283 | with torch.no_grad(): 284 | hidden_states_out_2 = self.swinViT(x_in, self.normalize) 285 | encode_feature = hidden_states_out[4] 286 | encode_feature_2 = hidden_states_out_2[4] 287 | 288 | x4_reshape = encode_feature.flatten(start_dim=2, end_dim=4) 289 | x4_reshape = x4_reshape.transpose(1, 2) 290 | 291 | x4_reshape_2 = encode_feature_2.flatten(start_dim=2, end_dim=4) 292 | x4_reshape_2 = x4_reshape_2.transpose(1, 2) 293 | 294 | contrast_pred = self.contrast_learning_head(x4_reshape[:, 1]) 295 | contrast_pred_2 = self.contrast_learning_head(x4_reshape_2[:, 1]) 296 | 297 | pred_mask_feature = encode_feature.flatten(start_dim=2, end_dim=4) 298 | pred_mask_feature = pred_mask_feature.transpose(1, 2) 299 | mask_region_pred = self.pred_mask_region(pred_mask_feature) 300 | 301 | pred_mask_feature_position = encode_feature.flatten(start_dim=2, end_dim=4) 302 | pred_mask_feature_position = pred_mask_feature_position.transpose(1, 2) 303 | mask_region_position_pred = self.pred_mask_region_position(pred_mask_feature_position) 304 | 305 | return { 306 | "logits": logits, 307 | 'images': local_images, 308 | "pred_mask_region": mask_region_pred, 309 | "pred_mask_region_position": mask_region_position_pred, 310 | "mask": mask, 311 | "x_mask": return_x_in, 312 | # "patch_size": patch_size, 313 | # "mask_feat_size": mask_feat_size, 314 | # "mask_labels": mask_labels, 315 | "mask_position_lables": region_mask_position, 316 | "mask_labels": region_mask_labels, 317 | "contrast_pred_1": contrast_pred, 318 | "contrast_pred_2": contrast_pred_2, 319 | } 320 | else : 321 | return logits 322 | 323 | 324 | def window_partition(x, window_size): 325 | """window partition operation based on: "Liu et al., 326 | Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 327 | " 328 | https://github.com/microsoft/Swin-Transformer 329 | 330 | Args: 331 | x: input tensor. 332 | window_size: local window size. 333 | """ 334 | x_shape = x.size() 335 | if len(x_shape) == 5: 336 | b, d, h, w, c = x_shape 337 | x = x.view( 338 | b, 339 | d // window_size[0], 340 | window_size[0], 341 | h // window_size[1], 342 | window_size[1], 343 | w // window_size[2], 344 | window_size[2], 345 | c, 346 | ) 347 | windows = ( 348 | x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0] * window_size[1] * window_size[2], c) 349 | ) 350 | elif len(x_shape) == 4: 351 | b, h, w, c = x.shape 352 | x = x.view(b, h // window_size[0], window_size[0], w // window_size[1], window_size[1], c) 353 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], c) 354 | return windows 355 | 356 | 357 | def window_reverse(windows, window_size, dims): 358 | """window reverse operation based on: "Liu et al., 359 | Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 360 | " 361 | https://github.com/microsoft/Swin-Transformer 362 | 363 | Args: 364 | windows: windows tensor. 365 | window_size: local window size. 366 | dims: dimension values. 
367 | """ 368 | if len(dims) == 4: 369 | b, d, h, w = dims 370 | x = windows.view( 371 | b, 372 | d // window_size[0], 373 | h // window_size[1], 374 | w // window_size[2], 375 | window_size[0], 376 | window_size[1], 377 | window_size[2], 378 | -1, 379 | ) 380 | x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(b, d, h, w, -1) 381 | 382 | elif len(dims) == 3: 383 | b, h, w = dims 384 | x = windows.view(b, h // window_size[0], w // window_size[0], window_size[0], window_size[1], -1) 385 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1) 386 | return x 387 | 388 | 389 | def get_window_size(x_size, window_size, shift_size=None): 390 | """Computing window size based on: "Liu et al., 391 | Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 392 | " 393 | https://github.com/microsoft/Swin-Transformer 394 | 395 | Args: 396 | x_size: input size. 397 | window_size: local window size. 398 | shift_size: window shifting size. 399 | """ 400 | 401 | use_window_size = list(window_size) 402 | if shift_size is not None: 403 | use_shift_size = list(shift_size) 404 | for i in range(len(x_size)): 405 | if x_size[i] <= window_size[i]: 406 | use_window_size[i] = x_size[i] 407 | if shift_size is not None: 408 | use_shift_size[i] = 0 409 | 410 | if shift_size is None: 411 | return tuple(use_window_size) 412 | else: 413 | return tuple(use_window_size), tuple(use_shift_size) 414 | 415 | 416 | class WindowAttention(nn.Module): 417 | """ 418 | Window based multi-head self attention module with relative position bias based on: "Liu et al., 419 | Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 420 | " 421 | https://github.com/microsoft/Swin-Transformer 422 | """ 423 | 424 | def __init__( 425 | self, 426 | dim: int, 427 | num_heads: int, 428 | window_size: Sequence[int], 429 | qkv_bias: bool = False, 430 | attn_drop: float = 0.0, 431 | proj_drop: float = 0.0, 432 | ) -> None: 433 | """ 434 | Args: 435 | dim: number of feature channels. 436 | num_heads: number of attention heads. 437 | window_size: local window size. 438 | qkv_bias: add a learnable bias to query, key, value. 439 | attn_drop: attention dropout rate. 440 | proj_drop: dropout rate of output. 
441 | """ 442 | 443 | super().__init__() 444 | self.dim = dim 445 | self.window_size = window_size 446 | self.num_heads = num_heads 447 | head_dim = dim // num_heads 448 | self.scale = head_dim**-0.5 449 | mesh_args = torch.meshgrid.__kwdefaults__ 450 | 451 | if len(self.window_size) == 3: 452 | self.relative_position_bias_table = nn.Parameter( 453 | torch.zeros( 454 | (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1), 455 | num_heads, 456 | ) 457 | ) 458 | coords_d = torch.arange(self.window_size[0]) 459 | coords_h = torch.arange(self.window_size[1]) 460 | coords_w = torch.arange(self.window_size[2]) 461 | if mesh_args is not None: 462 | coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing="ij")) 463 | else: 464 | coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) 465 | coords_flatten = torch.flatten(coords, 1) 466 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] 467 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() 468 | relative_coords[:, :, 0] += self.window_size[0] - 1 469 | relative_coords[:, :, 1] += self.window_size[1] - 1 470 | relative_coords[:, :, 2] += self.window_size[2] - 1 471 | relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) 472 | relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1 473 | elif len(self.window_size) == 2: 474 | self.relative_position_bias_table = nn.Parameter( 475 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) 476 | ) 477 | coords_h = torch.arange(self.window_size[0]) 478 | coords_w = torch.arange(self.window_size[1]) 479 | if mesh_args is not None: 480 | coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) 481 | else: 482 | coords = torch.stack(torch.meshgrid(coords_h, coords_w)) 483 | coords_flatten = torch.flatten(coords, 1) 484 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] 485 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() 486 | relative_coords[:, :, 0] += self.window_size[0] - 1 487 | relative_coords[:, :, 1] += self.window_size[1] - 1 488 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 489 | 490 | relative_position_index = relative_coords.sum(-1) 491 | self.register_buffer("relative_position_index", relative_position_index) 492 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 493 | self.attn_drop = nn.Dropout(attn_drop) 494 | self.proj = nn.Linear(dim, dim) 495 | self.proj_drop = nn.Dropout(proj_drop) 496 | trunc_normal_(self.relative_position_bias_table, std=0.02) 497 | self.softmax = nn.Softmax(dim=-1) 498 | 499 | def forward(self, x, mask): 500 | b, n, c = x.shape 501 | qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4) 502 | q, k, v = qkv[0], qkv[1], qkv[2] 503 | q = q * self.scale 504 | attn = q @ k.transpose(-2, -1) 505 | relative_position_bias = self.relative_position_bias_table[ 506 | self.relative_position_index.clone()[:n, :n].reshape(-1) 507 | ].reshape(n, n, -1) 508 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() 509 | attn = attn + relative_position_bias.unsqueeze(0) 510 | if mask is not None: 511 | nw = mask.shape[0] 512 | attn = attn.view(b // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0) 513 | attn = attn.view(-1, self.num_heads, n, n) 514 | attn = self.softmax(attn) 515 | else: 516 | attn = self.softmax(attn) 517 | 518 | attn = self.attn_drop(attn) 519 | x = (attn @ 
v).transpose(1, 2).reshape(b, n, c) 520 | x = self.proj(x) 521 | x = self.proj_drop(x) 522 | return x 523 | 524 | 525 | class SwinTransformerBlock(nn.Module): 526 | """ 527 | Swin Transformer block based on: "Liu et al., 528 | Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 529 | " 530 | https://github.com/microsoft/Swin-Transformer 531 | """ 532 | 533 | def __init__( 534 | self, 535 | dim: int, 536 | num_heads: int, 537 | window_size: Sequence[int], 538 | shift_size: Sequence[int], 539 | mlp_ratio: float = 4.0, 540 | qkv_bias: bool = True, 541 | drop: float = 0.0, 542 | attn_drop: float = 0.0, 543 | drop_path: float = 0.0, 544 | act_layer: str = "GELU", 545 | norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore 546 | use_checkpoint: bool = False, 547 | ) -> None: 548 | """ 549 | Args: 550 | dim: number of feature channels. 551 | num_heads: number of attention heads. 552 | window_size: local window size. 553 | shift_size: window shift size. 554 | mlp_ratio: ratio of mlp hidden dim to embedding dim. 555 | qkv_bias: add a learnable bias to query, key, value. 556 | drop: dropout rate. 557 | attn_drop: attention dropout rate. 558 | drop_path: stochastic depth rate. 559 | act_layer: activation layer. 560 | norm_layer: normalization layer. 561 | use_checkpoint: use gradient checkpointing for reduced memory usage. 562 | """ 563 | 564 | super().__init__() 565 | self.dim = dim 566 | self.num_heads = num_heads 567 | self.window_size = window_size 568 | self.shift_size = shift_size 569 | self.mlp_ratio = mlp_ratio 570 | self.use_checkpoint = use_checkpoint 571 | self.norm1 = norm_layer(dim) 572 | self.attn = WindowAttention( 573 | dim, 574 | window_size=self.window_size, 575 | num_heads=num_heads, 576 | qkv_bias=qkv_bias, 577 | attn_drop=attn_drop, 578 | proj_drop=drop, 579 | ) 580 | 581 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 582 | self.norm2 = norm_layer(dim) 583 | mlp_hidden_dim = int(dim * mlp_ratio) 584 | self.mlp = Mlp(hidden_size=dim, mlp_dim=mlp_hidden_dim, act=act_layer, dropout_rate=drop, dropout_mode="swin") 585 | 586 | def forward_part1(self, x, mask_matrix): 587 | x_shape = x.size() 588 | x = self.norm1(x) 589 | if len(x_shape) == 5: 590 | b, d, h, w, c = x.shape 591 | window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size) 592 | pad_l = pad_t = pad_d0 = 0 593 | pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0] 594 | pad_b = (window_size[1] - h % window_size[1]) % window_size[1] 595 | pad_r = (window_size[2] - w % window_size[2]) % window_size[2] 596 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1)) 597 | _, dp, hp, wp, _ = x.shape 598 | dims = [b, dp, hp, wp] 599 | 600 | elif len(x_shape) == 4: 601 | b, h, w, c = x.shape 602 | window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size) 603 | pad_l = pad_t = 0 604 | pad_r = (window_size[0] - h % window_size[0]) % window_size[0] 605 | pad_b = (window_size[1] - w % window_size[1]) % window_size[1] 606 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) 607 | _, hp, wp, _ = x.shape 608 | dims = [b, hp, wp] 609 | 610 | if any(i > 0 for i in shift_size): 611 | if len(x_shape) == 5: 612 | shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3)) 613 | elif len(x_shape) == 4: 614 | shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2)) 615 | attn_mask = mask_matrix 616 | else: 617 | shifted_x = x 618 | attn_mask = None 619 | 
x_windows = window_partition(shifted_x, window_size) 620 | attn_windows = self.attn(x_windows, mask=attn_mask) 621 | attn_windows = attn_windows.view(-1, *(window_size + (c,))) 622 | shifted_x = window_reverse(attn_windows, window_size, dims) 623 | if any(i > 0 for i in shift_size): 624 | if len(x_shape) == 5: 625 | x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3)) 626 | elif len(x_shape) == 4: 627 | x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2)) 628 | else: 629 | x = shifted_x 630 | 631 | if len(x_shape) == 5: 632 | if pad_d1 > 0 or pad_r > 0 or pad_b > 0: 633 | x = x[:, :d, :h, :w, :].contiguous() 634 | elif len(x_shape) == 4: 635 | if pad_r > 0 or pad_b > 0: 636 | x = x[:, :h, :w, :].contiguous() 637 | 638 | return x 639 | 640 | def forward_part2(self, x): 641 | return self.drop_path(self.mlp(self.norm2(x))) 642 | 643 | def load_from(self, weights, n_block, layer): 644 | root = f"module.{layer}.0.blocks.{n_block}." 645 | block_names = [ 646 | "norm1.weight", 647 | "norm1.bias", 648 | "attn.relative_position_bias_table", 649 | "attn.relative_position_index", 650 | "attn.qkv.weight", 651 | "attn.qkv.bias", 652 | "attn.proj.weight", 653 | "attn.proj.bias", 654 | "norm2.weight", 655 | "norm2.bias", 656 | "mlp.fc1.weight", 657 | "mlp.fc1.bias", 658 | "mlp.fc2.weight", 659 | "mlp.fc2.bias", 660 | ] 661 | with torch.no_grad(): 662 | self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]]) 663 | self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]]) 664 | self.attn.relative_position_bias_table.copy_(weights["state_dict"][root + block_names[2]]) 665 | self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]]) 666 | self.attn.qkv.weight.copy_(weights["state_dict"][root + block_names[4]]) 667 | self.attn.qkv.bias.copy_(weights["state_dict"][root + block_names[5]]) 668 | self.attn.proj.weight.copy_(weights["state_dict"][root + block_names[6]]) 669 | self.attn.proj.bias.copy_(weights["state_dict"][root + block_names[7]]) 670 | self.norm2.weight.copy_(weights["state_dict"][root + block_names[8]]) 671 | self.norm2.bias.copy_(weights["state_dict"][root + block_names[9]]) 672 | self.mlp.linear1.weight.copy_(weights["state_dict"][root + block_names[10]]) 673 | self.mlp.linear1.bias.copy_(weights["state_dict"][root + block_names[11]]) 674 | self.mlp.linear2.weight.copy_(weights["state_dict"][root + block_names[12]]) 675 | self.mlp.linear2.bias.copy_(weights["state_dict"][root + block_names[13]]) 676 | 677 | def forward(self, x, mask_matrix): 678 | shortcut = x 679 | if self.use_checkpoint: 680 | x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix) 681 | else: 682 | x = self.forward_part1(x, mask_matrix) 683 | x = shortcut + self.drop_path(x) 684 | if self.use_checkpoint: 685 | x = x + checkpoint.checkpoint(self.forward_part2, x) 686 | else: 687 | x = x + self.forward_part2(x) 688 | return x 689 | 690 | 691 | class PatchMerging(nn.Module): 692 | """ 693 | Patch merging layer based on: "Liu et al., 694 | Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 695 | " 696 | https://github.com/microsoft/Swin-Transformer 697 | """ 698 | 699 | def __init__( 700 | self, dim: int, norm_layer: Type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3 701 | ) -> None: # type: ignore 702 | """ 703 | Args: 704 | dim: number of feature channels. 705 | norm_layer: normalization layer. 706 | spatial_dims: number of spatial dims. 
707 | """ 708 | 709 | super().__init__() 710 | self.dim = dim 711 | if spatial_dims == 3: 712 | self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False) 713 | self.norm = norm_layer(8 * dim) 714 | elif spatial_dims == 2: 715 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 716 | self.norm = norm_layer(4 * dim) 717 | 718 | def forward(self, x): 719 | 720 | x_shape = x.size() 721 | if len(x_shape) == 5: 722 | b, d, h, w, c = x_shape 723 | pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1) 724 | if pad_input: 725 | x = F.pad(x, (0, 0, 0, d % 2, 0, w % 2, 0, h % 2)) 726 | x0 = x[:, 0::2, 0::2, 0::2, :] 727 | x1 = x[:, 1::2, 0::2, 0::2, :] 728 | x2 = x[:, 0::2, 1::2, 0::2, :] 729 | x3 = x[:, 0::2, 0::2, 1::2, :] 730 | x4 = x[:, 1::2, 0::2, 1::2, :] 731 | x5 = x[:, 0::2, 1::2, 0::2, :] 732 | x6 = x[:, 0::2, 0::2, 1::2, :] 733 | x7 = x[:, 1::2, 1::2, 1::2, :] 734 | x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1) 735 | 736 | elif len(x_shape) == 4: 737 | b, h, w, c = x_shape 738 | pad_input = (h % 2 == 1) or (w % 2 == 1) 739 | if pad_input: 740 | x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2)) 741 | x0 = x[:, 0::2, 0::2, :] 742 | x1 = x[:, 1::2, 0::2, :] 743 | x2 = x[:, 0::2, 1::2, :] 744 | x3 = x[:, 1::2, 1::2, :] 745 | x = torch.cat([x0, x1, x2, x3], -1) 746 | 747 | x = self.norm(x) 748 | x = self.reduction(x) 749 | return x 750 | 751 | 752 | def compute_mask(dims, window_size, shift_size, device): 753 | """Computing region masks based on: "Liu et al., 754 | Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 755 | " 756 | https://github.com/microsoft/Swin-Transformer 757 | 758 | Args: 759 | dims: dimension values. 760 | window_size: local window size. 761 | shift_size: shift size. 762 | device: device. 763 | """ 764 | 765 | cnt = 0 766 | 767 | if len(dims) == 3: 768 | d, h, w = dims 769 | img_mask = torch.zeros((1, d, h, w, 1), device=device) 770 | for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None): 771 | for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None): 772 | for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None): 773 | img_mask[:, d, h, w, :] = cnt 774 | cnt += 1 775 | 776 | elif len(dims) == 2: 777 | h, w = dims 778 | img_mask = torch.zeros((1, h, w, 1), device=device) 779 | for h in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None): 780 | for w in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None): 781 | img_mask[:, h, w, :] = cnt 782 | cnt += 1 783 | 784 | mask_windows = window_partition(img_mask, window_size) 785 | mask_windows = mask_windows.squeeze(-1) 786 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 787 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 788 | 789 | return attn_mask 790 | 791 | 792 | class BasicLayer(nn.Module): 793 | """ 794 | Basic Swin Transformer layer in one stage based on: "Liu et al., 795 | Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 796 | " 797 | https://github.com/microsoft/Swin-Transformer 798 | """ 799 | 800 | def __init__( 801 | self, 802 | dim: int, 803 | depth: int, 804 | num_heads: int, 805 | window_size: Sequence[int], 806 | drop_path: list, 807 | mlp_ratio: float = 4.0, 808 | qkv_bias: bool = False, 809 | drop: float = 0.0, 810 | attn_drop: float = 0.0, 811 | norm_layer: Type[LayerNorm] = 
nn.LayerNorm, # type: ignore 812 | downsample: isinstance = None, # type: ignore 813 | use_checkpoint: bool = False, 814 | ) -> None: 815 | """ 816 | Args: 817 | dim: number of feature channels. 818 | depths: number of layers in each stage. 819 | num_heads: number of attention heads. 820 | window_size: local window size. 821 | drop_path: stochastic depth rate. 822 | mlp_ratio: ratio of mlp hidden dim to embedding dim. 823 | qkv_bias: add a learnable bias to query, key, value. 824 | drop: dropout rate. 825 | attn_drop: attention dropout rate. 826 | norm_layer: normalization layer. 827 | downsample: downsample layer at the end of the layer. 828 | use_checkpoint: use gradient checkpointing for reduced memory usage. 829 | """ 830 | 831 | super().__init__() 832 | self.window_size = window_size 833 | self.shift_size = tuple(i // 2 for i in window_size) 834 | self.no_shift = tuple(0 for i in window_size) 835 | self.depth = depth 836 | self.use_checkpoint = use_checkpoint 837 | self.blocks = nn.ModuleList( 838 | [ 839 | SwinTransformerBlock( 840 | dim=dim, 841 | num_heads=num_heads, 842 | window_size=self.window_size, 843 | shift_size=self.no_shift if (i % 2 == 0) else self.shift_size, 844 | mlp_ratio=mlp_ratio, 845 | qkv_bias=qkv_bias, 846 | drop=drop, 847 | attn_drop=attn_drop, 848 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 849 | norm_layer=norm_layer, 850 | use_checkpoint=use_checkpoint, 851 | ) 852 | for i in range(depth) 853 | ] 854 | ) 855 | self.downsample = downsample 856 | if self.downsample is not None: 857 | self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size)) 858 | 859 | def forward(self, x): 860 | x_shape = x.size() 861 | if len(x_shape) == 5: 862 | b, c, d, h, w = x_shape 863 | window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size) 864 | x = rearrange(x, "b c d h w -> b d h w c") 865 | dp = int(np.ceil(d / window_size[0])) * window_size[0] 866 | hp = int(np.ceil(h / window_size[1])) * window_size[1] 867 | wp = int(np.ceil(w / window_size[2])) * window_size[2] 868 | attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device) 869 | for blk in self.blocks: 870 | x = blk(x, attn_mask) 871 | x = x.view(b, d, h, w, -1) 872 | if self.downsample is not None: 873 | x = self.downsample(x) 874 | x = rearrange(x, "b d h w c -> b c d h w") 875 | 876 | elif len(x_shape) == 4: 877 | b, c, h, w = x_shape 878 | window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size) 879 | x = rearrange(x, "b c h w -> b h w c") 880 | hp = int(np.ceil(h / window_size[0])) * window_size[0] 881 | wp = int(np.ceil(w / window_size[1])) * window_size[1] 882 | attn_mask = compute_mask([hp, wp], window_size, shift_size, x.device) 883 | for blk in self.blocks: 884 | x = blk(x, attn_mask) 885 | x = x.view(b, h, w, -1) 886 | if self.downsample is not None: 887 | x = self.downsample(x) 888 | x = rearrange(x, "b h w c -> b c h w") 889 | return x 890 | 891 | 892 | class SwinTransformer(nn.Module): 893 | """ 894 | Swin Transformer based on: "Liu et al., 895 | Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 896 | " 897 | https://github.com/microsoft/Swin-Transformer 898 | """ 899 | 900 | def __init__( 901 | self, 902 | in_chans: int, 903 | embed_dim: int, 904 | window_size: Sequence[int], 905 | patch_size: Sequence[int], 906 | depths: Sequence[int], 907 | num_heads: Sequence[int], 908 | mlp_ratio: float = 4.0, 909 | qkv_bias: bool = True, 910 | drop_rate: float = 
0.0, 911 | attn_drop_rate: float = 0.0, 912 | drop_path_rate: float = 0.0, 913 | norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore 914 | patch_norm: bool = False, 915 | use_checkpoint: bool = False, 916 | spatial_dims: int = 3, 917 | ) -> None: 918 | """ 919 | Args: 920 | in_chans: dimension of input channels. 921 | embed_dim: number of linear projection output channels. 922 | window_size: local window size. 923 | patch_size: patch size. 924 | depths: number of layers in each stage. 925 | num_heads: number of attention heads. 926 | mlp_ratio: ratio of mlp hidden dim to embedding dim. 927 | qkv_bias: add a learnable bias to query, key, value. 928 | drop_rate: dropout rate. 929 | attn_drop_rate: attention dropout rate. 930 | drop_path_rate: stochastic depth rate. 931 | norm_layer: normalization layer. 932 | patch_norm: add normalization after patch embedding. 933 | use_checkpoint: use gradient checkpointing for reduced memory usage. 934 | spatial_dims: spatial dimension. 935 | """ 936 | 937 | super().__init__() 938 | self.num_layers = len(depths) 939 | self.embed_dim = embed_dim 940 | self.patch_norm = patch_norm 941 | self.window_size = window_size 942 | self.patch_size = patch_size 943 | self.patch_embed = PatchEmbed( 944 | patch_size=self.patch_size, 945 | in_chans=in_chans, 946 | embed_dim=embed_dim, 947 | norm_layer=norm_layer if self.patch_norm else None, # type: ignore 948 | spatial_dims=spatial_dims, 949 | ) 950 | self.pos_drop = nn.Dropout(p=drop_rate) 951 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 952 | self.layers1 = nn.ModuleList() 953 | self.layers2 = nn.ModuleList() 954 | self.layers3 = nn.ModuleList() 955 | self.layers4 = nn.ModuleList() 956 | for i_layer in range(self.num_layers): 957 | layer = BasicLayer( 958 | dim=int(embed_dim * 2**i_layer), 959 | depth=depths[i_layer], 960 | num_heads=num_heads[i_layer], 961 | window_size=self.window_size, 962 | drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], 963 | mlp_ratio=mlp_ratio, 964 | qkv_bias=qkv_bias, 965 | drop=drop_rate, 966 | attn_drop=attn_drop_rate, 967 | norm_layer=norm_layer, 968 | downsample=PatchMerging, 969 | use_checkpoint=use_checkpoint, 970 | ) 971 | if i_layer == 0: 972 | self.layers1.append(layer) 973 | elif i_layer == 1: 974 | self.layers2.append(layer) 975 | elif i_layer == 2: 976 | self.layers3.append(layer) 977 | elif i_layer == 3: 978 | self.layers4.append(layer) 979 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 980 | 981 | def proj_out(self, x, normalize=False): 982 | if normalize: 983 | x_shape = x.size() 984 | if len(x_shape) == 5: 985 | n, ch, d, h, w = x_shape 986 | x = rearrange(x, "n c d h w -> n d h w c") 987 | x = F.layer_norm(x, [ch]) 988 | x = rearrange(x, "n d h w c -> n c d h w") 989 | elif len(x_shape) == 4: 990 | n, ch, h, w = x_shape 991 | x = rearrange(x, "n c h w -> n h w c") 992 | x = F.layer_norm(x, [ch]) 993 | x = rearrange(x, "n h w c -> n c h w") 994 | return x 995 | 996 | def forward(self, x, normalize=True): 997 | x0 = self.patch_embed(x) 998 | x0 = self.pos_drop(x0) 999 | x0_out = self.proj_out(x0, normalize) 1000 | x1 = self.layers1[0](x0.contiguous()) 1001 | x1_out = self.proj_out(x1, normalize) 1002 | x2 = self.layers2[0](x1.contiguous()) 1003 | x2_out = self.proj_out(x2, normalize) 1004 | x3 = self.layers3[0](x2.contiguous()) 1005 | x3_out = self.proj_out(x3, normalize) 1006 | x4 = self.layers4[0](x3.contiguous()) 1007 | x4_out = self.proj_out(x4, normalize) 1008 | return [x0_out, x1_out, x2_out, 
x3_out, x4_out] 1009 | -------------------------------------------------------------------------------- /Pretrain/pretrain_models/swinunetr_8.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from typing import Sequence, Tuple, Type, Union 13 | 14 | import numpy as np 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import torch.utils.checkpoint as checkpoint 19 | from torch.nn import LayerNorm 20 | 21 | from monai.networks.blocks import MLPBlock as Mlp 22 | from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock 23 | from monai.networks.layers import DropPath, trunc_normal_ 24 | from monai.utils import ensure_tuple_rep, optional_import 25 | from .utils import mask_func 26 | from .utils import get_mask_labels, get_mask_labelsv2 27 | 28 | rearrange, _ = optional_import("einops", name="rearrange") 29 | 30 | class SwinUNETR(nn.Module): 31 | """ 32 | Swin UNETR based on: "Hatamizadeh et al., 33 | Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images 34 | " 35 | """ 36 | 37 | def __init__( 38 | self, 39 | img_size: Union[Sequence[int], int], 40 | in_channels: int, 41 | out_channels: int, 42 | depths: Sequence[int] = (2, 2, 2, 2), 43 | num_heads: Sequence[int] = (3, 6, 12, 24), 44 | feature_size: int = 24, 45 | norm_name: Union[Tuple, str] = "instance", 46 | drop_rate: float = 0.0, 47 | attn_drop_rate: float = 0.0, 48 | dropout_path_rate: float = 0.0, 49 | normalize: bool = True, 50 | use_checkpoint: bool = False, 51 | spatial_dims: int = 3, 52 | pretrain=False, 53 | select_reconstruct_region=(0, 3), 54 | ) -> None: 55 | """ 56 | Args: 57 | img_size: dimension of input image. 58 | in_channels: dimension of input channels. 59 | out_channels: dimension of output channels. 60 | feature_size: dimension of network feature size. 61 | depths: number of layers in each stage. 62 | num_heads: number of attention heads. 63 | norm_name: feature normalization type and arguments. 64 | drop_rate: dropout rate. 65 | attn_drop_rate: attention dropout rate. 66 | dropout_path_rate: drop path rate. 67 | normalize: normalize output intermediate features in each stage. 68 | use_checkpoint: use gradient checkpointing for reduced memory usage. 69 | spatial_dims: number of spatial dims. 70 | 71 | Examples:: 72 | 73 | # for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48. 74 | >>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48) 75 | 76 | # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage. 77 | >>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2)) 78 | 79 | # for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing. 
80 | >>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2) 81 | 82 | """ 83 | 84 | super().__init__() 85 | self.in_channels = in_channels 86 | downsample_size = 5 87 | img_size = ensure_tuple_rep(img_size, spatial_dims) 88 | patch_size = ensure_tuple_rep(2, spatial_dims) 89 | window_size = ensure_tuple_rep(7, spatial_dims) 90 | self.select_reconstruct_region = select_reconstruct_region 91 | self.stages = [] 92 | for i in range(downsample_size+1): 93 | self.stages.append((select_reconstruct_region[0] * 2**i, select_reconstruct_region[1] * 2**i)) 94 | print(self.stages) 95 | self.pretrain = pretrain 96 | if not (spatial_dims == 2 or spatial_dims == 3): 97 | raise ValueError("spatial dimension should be 2 or 3.") 98 | 99 | for m, p in zip(img_size, patch_size): 100 | for i in range(5): 101 | if m % np.power(p, i + 1) != 0: 102 | raise ValueError("input image size (img_size) should be divisible by stage-wise image resolution.") 103 | 104 | if not (0 <= drop_rate <= 1): 105 | raise ValueError("dropout rate should be between 0 and 1.") 106 | 107 | if not (0 <= attn_drop_rate <= 1): 108 | raise ValueError("attention dropout rate should be between 0 and 1.") 109 | 110 | if not (0 <= dropout_path_rate <= 1): 111 | raise ValueError("drop path rate should be between 0 and 1.") 112 | 113 | if feature_size % 12 != 0: 114 | raise ValueError("feature_size should be divisible by 12.") 115 | 116 | self.normalize = normalize 117 | 118 | self.swinViT = SwinTransformer( 119 | in_chans=in_channels, 120 | embed_dim=feature_size, 121 | window_size=window_size, 122 | patch_size=patch_size, 123 | depths=depths, 124 | num_heads=num_heads, 125 | mlp_ratio=4.0, 126 | qkv_bias=True, 127 | drop_rate=drop_rate, 128 | attn_drop_rate=attn_drop_rate, 129 | drop_path_rate=dropout_path_rate, 130 | norm_layer=nn.LayerNorm, 131 | use_checkpoint=use_checkpoint, 132 | spatial_dims=spatial_dims, 133 | ) 134 | 135 | self.encoder1 = UnetrBasicBlock( 136 | spatial_dims=spatial_dims, 137 | in_channels=in_channels, 138 | out_channels=feature_size, 139 | kernel_size=3, 140 | stride=1, 141 | norm_name=norm_name, 142 | res_block=True, 143 | ) 144 | 145 | self.encoder2 = UnetrBasicBlock( 146 | spatial_dims=spatial_dims, 147 | in_channels=feature_size, 148 | out_channels=feature_size, 149 | kernel_size=3, 150 | stride=1, 151 | norm_name=norm_name, 152 | res_block=True, 153 | ) 154 | 155 | self.encoder3 = UnetrBasicBlock( 156 | spatial_dims=spatial_dims, 157 | in_channels=2 * feature_size, 158 | out_channels=2 * feature_size, 159 | kernel_size=3, 160 | stride=1, 161 | norm_name=norm_name, 162 | res_block=True, 163 | ) 164 | 165 | self.encoder4 = UnetrBasicBlock( 166 | spatial_dims=spatial_dims, 167 | in_channels=4 * feature_size, 168 | out_channels=4 * feature_size, 169 | kernel_size=3, 170 | stride=1, 171 | norm_name=norm_name, 172 | res_block=True, 173 | ) 174 | 175 | self.encoder10 = UnetrBasicBlock( 176 | spatial_dims=spatial_dims, 177 | in_channels=16 * feature_size, 178 | out_channels=16 * feature_size, 179 | kernel_size=3, 180 | stride=1, 181 | norm_name=norm_name, 182 | res_block=True, 183 | ) 184 | 185 | self.decoder5 = UnetrUpBlock( 186 | spatial_dims=spatial_dims, 187 | in_channels=16 * feature_size, 188 | out_channels=8 * feature_size, 189 | kernel_size=3, 190 | upsample_kernel_size=2, 191 | norm_name=norm_name, 192 | res_block=True, 193 | ) 194 | 195 | self.decoder4 = UnetrUpBlock( 196 | spatial_dims=spatial_dims, 197 | in_channels=feature_size * 8, 198 | 
out_channels=feature_size * 4, 199 | kernel_size=3, 200 | upsample_kernel_size=2, 201 | norm_name=norm_name, 202 | res_block=True, 203 | ) 204 | 205 | self.decoder3 = UnetrUpBlock( 206 | spatial_dims=spatial_dims, 207 | in_channels=feature_size * 4, 208 | out_channels=feature_size * 2, 209 | kernel_size=3, 210 | upsample_kernel_size=2, 211 | norm_name=norm_name, 212 | res_block=True, 213 | ) 214 | self.decoder2 = UnetrUpBlock( 215 | spatial_dims=spatial_dims, 216 | in_channels=feature_size * 2, 217 | out_channels=feature_size, 218 | kernel_size=3, 219 | upsample_kernel_size=2, 220 | norm_name=norm_name, 221 | res_block=True, 222 | ) 223 | 224 | self.decoder1 = UnetrUpBlock( 225 | spatial_dims=spatial_dims, 226 | in_channels=feature_size, 227 | out_channels=feature_size, 228 | kernel_size=3, 229 | upsample_kernel_size=2, 230 | norm_name=norm_name, 231 | res_block=True, 232 | ) 233 | 234 | self.out = UnetOutBlock( 235 | spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels 236 | ) # type: ignore 237 | if pretrain: 238 | if feature_size == 24: 239 | self.pred_mask_region = nn.Linear(384, 65)  # one region contains 64 patches in this variant 240 | self.contrast_learning_head = nn.Linear(384, 384) 241 | else: 242 | self.pred_mask_region = nn.Linear(768, 65)  # one region contains 64 patches in this variant 243 | self.contrast_learning_head = nn.Linear(768, 384) 244 | self.pred_mask_region_position = nn.Linear(768, 64) 245 | 246 | def wrap_feature_selection(self, feature, region_box): 247 | # feature: b, c, d, w, h 248 | return feature[..., region_box[0]:region_box[1], region_box[0]:region_box[1], region_box[0]:region_box[1]] 249 | 250 | def get_local_images(self, images): 251 | images = self.wrap_feature_selection(images, region_box=self.stages[5]) 252 | return images 253 | 254 | def forward(self, x_in): 255 | device = x_in.device 256 | images = x_in.detach() 257 | if self.pretrain: 258 | # mask_ratio = torch.clamp(torch.rand(1), 0.2, 0.4) 259 | mask_ratio = 0.4 260 | # x_in, mask = mask_func(x_in, self.in_channels, mask_ratio, (32, 32, 32), (4, 4, 4)) 261 | x_in, mask = mask_func(x_in, self.in_channels, mask_ratio, (8, 8, 8), (12, 12, 12)) 262 | 263 | region_mask_labels = get_mask_labels(x_in.shape[0], 3*3*3, mask, 4*4*4, device) 264 | region_mask_position = get_mask_labelsv2(x_in.shape[0], 3*3*3, mask, 4*4*4, device=device) 265 | 266 | hidden_states_out = self.swinViT(x_in, self.normalize) 267 | local_images = self.get_local_images(images) 268 | return_x_in = self.wrap_feature_selection(x_in, region_box=self.stages[5]) 269 | 270 | enc0 = self.encoder1(self.wrap_feature_selection(x_in, region_box=self.stages[5])) 271 | enc1 = self.encoder2(self.wrap_feature_selection(hidden_states_out[0], region_box=self.stages[4])) 272 | enc2 = self.encoder3(self.wrap_feature_selection(hidden_states_out[1], region_box=self.stages[3])) 273 | enc3 = self.encoder4(self.wrap_feature_selection(hidden_states_out[2], region_box=self.stages[2])) 274 | dec4 = self.encoder10(self.wrap_feature_selection(hidden_states_out[4], region_box=self.stages[0])) 275 | dec3 = self.decoder5(dec4, self.wrap_feature_selection(hidden_states_out[3], region_box=self.stages[1])) 276 | dec2 = self.decoder4(dec3, enc3) 277 | dec1 = self.decoder3(dec2, enc2) 278 | dec0 = self.decoder2(dec1, enc1) 279 | out = self.decoder1(dec0, enc0) 280 | logits = self.out(out) 281 | 282 | if self.pretrain: 283 | with torch.no_grad(): 284 | hidden_states_out_2 = self.swinViT(x_in, self.normalize) 285 | encode_feature = hidden_states_out[4] 286 | encode_feature_2 = hidden_states_out_2[4] 287 | 
288 | x4_reshape = encode_feature.flatten(start_dim=2, end_dim=4) 289 | x4_reshape = x4_reshape.transpose(1, 2) 290 | 291 | x4_reshape_2 = encode_feature_2.flatten(start_dim=2, end_dim=4) 292 | x4_reshape_2 = x4_reshape_2.transpose(1, 2) 293 | 294 | contrast_pred = self.contrast_learning_head(x4_reshape[:, 1]) 295 | contrast_pred_2 = self.contrast_learning_head(x4_reshape_2[:, 1]) 296 | 297 | pred_mask_feature = encode_feature.flatten(start_dim=2, end_dim=4) 298 | pred_mask_feature = pred_mask_feature.transpose(1, 2) 299 | mask_region_pred = self.pred_mask_region(pred_mask_feature) 300 | 301 | pred_mask_feature_position = encode_feature.flatten(start_dim=2, end_dim=4) 302 | pred_mask_feature_position = pred_mask_feature_position.transpose(1, 2) 303 | mask_region_position_pred = self.pred_mask_region_position(pred_mask_feature_position) 304 | 305 | return { 306 | "logits": logits, 307 | 'images': local_images, 308 | "pred_mask_region": mask_region_pred, 309 | "pred_mask_region_position": mask_region_position_pred, 310 | "mask": mask, 311 | "x_mask": return_x_in, 312 | # "patch_size": patch_size, 313 | # "mask_feat_size": mask_feat_size, 314 | # "mask_labels": mask_labels, 315 | "mask_position_lables": region_mask_position, 316 | "mask_labels": region_mask_labels, 317 | "contrast_pred_1": contrast_pred, 318 | "contrast_pred_2": contrast_pred_2, 319 | } 320 | else : 321 | return logits 322 | 323 | 324 | def window_partition(x, window_size): 325 | """window partition operation based on: "Liu et al., 326 | Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 327 | " 328 | https://github.com/microsoft/Swin-Transformer 329 | 330 | Args: 331 | x: input tensor. 332 | window_size: local window size. 333 | """ 334 | x_shape = x.size() 335 | if len(x_shape) == 5: 336 | b, d, h, w, c = x_shape 337 | x = x.view( 338 | b, 339 | d // window_size[0], 340 | window_size[0], 341 | h // window_size[1], 342 | window_size[1], 343 | w // window_size[2], 344 | window_size[2], 345 | c, 346 | ) 347 | windows = ( 348 | x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0] * window_size[1] * window_size[2], c) 349 | ) 350 | elif len(x_shape) == 4: 351 | b, h, w, c = x.shape 352 | x = x.view(b, h // window_size[0], window_size[0], w // window_size[1], window_size[1], c) 353 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], c) 354 | return windows 355 | 356 | 357 | def window_reverse(windows, window_size, dims): 358 | """window reverse operation based on: "Liu et al., 359 | Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 360 | " 361 | https://github.com/microsoft/Swin-Transformer 362 | 363 | Args: 364 | windows: windows tensor. 365 | window_size: local window size. 366 | dims: dimension values. 
367 | """ 368 | if len(dims) == 4: 369 | b, d, h, w = dims 370 | x = windows.view( 371 | b, 372 | d // window_size[0], 373 | h // window_size[1], 374 | w // window_size[2], 375 | window_size[0], 376 | window_size[1], 377 | window_size[2], 378 | -1, 379 | ) 380 | x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(b, d, h, w, -1) 381 | 382 | elif len(dims) == 3: 383 | b, h, w = dims 384 | x = windows.view(b, h // window_size[0], w // window_size[0], window_size[0], window_size[1], -1) 385 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1) 386 | return x 387 | 388 | 389 | def get_window_size(x_size, window_size, shift_size=None): 390 | """Computing window size based on: "Liu et al., 391 | Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 392 | " 393 | https://github.com/microsoft/Swin-Transformer 394 | 395 | Args: 396 | x_size: input size. 397 | window_size: local window size. 398 | shift_size: window shifting size. 399 | """ 400 | 401 | use_window_size = list(window_size) 402 | if shift_size is not None: 403 | use_shift_size = list(shift_size) 404 | for i in range(len(x_size)): 405 | if x_size[i] <= window_size[i]: 406 | use_window_size[i] = x_size[i] 407 | if shift_size is not None: 408 | use_shift_size[i] = 0 409 | 410 | if shift_size is None: 411 | return tuple(use_window_size) 412 | else: 413 | return tuple(use_window_size), tuple(use_shift_size) 414 | 415 | 416 | class WindowAttention(nn.Module): 417 | """ 418 | Window based multi-head self attention module with relative position bias based on: "Liu et al., 419 | Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 420 | " 421 | https://github.com/microsoft/Swin-Transformer 422 | """ 423 | 424 | def __init__( 425 | self, 426 | dim: int, 427 | num_heads: int, 428 | window_size: Sequence[int], 429 | qkv_bias: bool = False, 430 | attn_drop: float = 0.0, 431 | proj_drop: float = 0.0, 432 | ) -> None: 433 | """ 434 | Args: 435 | dim: number of feature channels. 436 | num_heads: number of attention heads. 437 | window_size: local window size. 438 | qkv_bias: add a learnable bias to query, key, value. 439 | attn_drop: attention dropout rate. 440 | proj_drop: dropout rate of output. 
441 | """ 442 | 443 | super().__init__() 444 | self.dim = dim 445 | self.window_size = window_size 446 | self.num_heads = num_heads 447 | head_dim = dim // num_heads 448 | self.scale = head_dim**-0.5 449 | mesh_args = torch.meshgrid.__kwdefaults__ 450 | 451 | if len(self.window_size) == 3: 452 | self.relative_position_bias_table = nn.Parameter( 453 | torch.zeros( 454 | (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1), 455 | num_heads, 456 | ) 457 | ) 458 | coords_d = torch.arange(self.window_size[0]) 459 | coords_h = torch.arange(self.window_size[1]) 460 | coords_w = torch.arange(self.window_size[2]) 461 | if mesh_args is not None: 462 | coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing="ij")) 463 | else: 464 | coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) 465 | coords_flatten = torch.flatten(coords, 1) 466 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] 467 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() 468 | relative_coords[:, :, 0] += self.window_size[0] - 1 469 | relative_coords[:, :, 1] += self.window_size[1] - 1 470 | relative_coords[:, :, 2] += self.window_size[2] - 1 471 | relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) 472 | relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1 473 | elif len(self.window_size) == 2: 474 | self.relative_position_bias_table = nn.Parameter( 475 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) 476 | ) 477 | coords_h = torch.arange(self.window_size[0]) 478 | coords_w = torch.arange(self.window_size[1]) 479 | if mesh_args is not None: 480 | coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) 481 | else: 482 | coords = torch.stack(torch.meshgrid(coords_h, coords_w)) 483 | coords_flatten = torch.flatten(coords, 1) 484 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] 485 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() 486 | relative_coords[:, :, 0] += self.window_size[0] - 1 487 | relative_coords[:, :, 1] += self.window_size[1] - 1 488 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 489 | 490 | relative_position_index = relative_coords.sum(-1) 491 | self.register_buffer("relative_position_index", relative_position_index) 492 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 493 | self.attn_drop = nn.Dropout(attn_drop) 494 | self.proj = nn.Linear(dim, dim) 495 | self.proj_drop = nn.Dropout(proj_drop) 496 | trunc_normal_(self.relative_position_bias_table, std=0.02) 497 | self.softmax = nn.Softmax(dim=-1) 498 | 499 | def forward(self, x, mask): 500 | b, n, c = x.shape 501 | qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4) 502 | q, k, v = qkv[0], qkv[1], qkv[2] 503 | q = q * self.scale 504 | attn = q @ k.transpose(-2, -1) 505 | relative_position_bias = self.relative_position_bias_table[ 506 | self.relative_position_index.clone()[:n, :n].reshape(-1) 507 | ].reshape(n, n, -1) 508 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() 509 | attn = attn + relative_position_bias.unsqueeze(0) 510 | if mask is not None: 511 | nw = mask.shape[0] 512 | attn = attn.view(b // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0) 513 | attn = attn.view(-1, self.num_heads, n, n) 514 | attn = self.softmax(attn) 515 | else: 516 | attn = self.softmax(attn) 517 | 518 | attn = self.attn_drop(attn) 519 | x = (attn @ 
v).transpose(1, 2).reshape(b, n, c) 520 | x = self.proj(x) 521 | x = self.proj_drop(x) 522 | return x 523 | 524 | 525 | class SwinTransformerBlock(nn.Module): 526 | """ 527 | Swin Transformer block based on: "Liu et al., 528 | Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 529 | " 530 | https://github.com/microsoft/Swin-Transformer 531 | """ 532 | 533 | def __init__( 534 | self, 535 | dim: int, 536 | num_heads: int, 537 | window_size: Sequence[int], 538 | shift_size: Sequence[int], 539 | mlp_ratio: float = 4.0, 540 | qkv_bias: bool = True, 541 | drop: float = 0.0, 542 | attn_drop: float = 0.0, 543 | drop_path: float = 0.0, 544 | act_layer: str = "GELU", 545 | norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore 546 | use_checkpoint: bool = False, 547 | ) -> None: 548 | """ 549 | Args: 550 | dim: number of feature channels. 551 | num_heads: number of attention heads. 552 | window_size: local window size. 553 | shift_size: window shift size. 554 | mlp_ratio: ratio of mlp hidden dim to embedding dim. 555 | qkv_bias: add a learnable bias to query, key, value. 556 | drop: dropout rate. 557 | attn_drop: attention dropout rate. 558 | drop_path: stochastic depth rate. 559 | act_layer: activation layer. 560 | norm_layer: normalization layer. 561 | use_checkpoint: use gradient checkpointing for reduced memory usage. 562 | """ 563 | 564 | super().__init__() 565 | self.dim = dim 566 | self.num_heads = num_heads 567 | self.window_size = window_size 568 | self.shift_size = shift_size 569 | self.mlp_ratio = mlp_ratio 570 | self.use_checkpoint = use_checkpoint 571 | self.norm1 = norm_layer(dim) 572 | self.attn = WindowAttention( 573 | dim, 574 | window_size=self.window_size, 575 | num_heads=num_heads, 576 | qkv_bias=qkv_bias, 577 | attn_drop=attn_drop, 578 | proj_drop=drop, 579 | ) 580 | 581 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 582 | self.norm2 = norm_layer(dim) 583 | mlp_hidden_dim = int(dim * mlp_ratio) 584 | self.mlp = Mlp(hidden_size=dim, mlp_dim=mlp_hidden_dim, act=act_layer, dropout_rate=drop, dropout_mode="swin") 585 | 586 | def forward_part1(self, x, mask_matrix): 587 | x_shape = x.size() 588 | x = self.norm1(x) 589 | if len(x_shape) == 5: 590 | b, d, h, w, c = x.shape 591 | window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size) 592 | pad_l = pad_t = pad_d0 = 0 593 | pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0] 594 | pad_b = (window_size[1] - h % window_size[1]) % window_size[1] 595 | pad_r = (window_size[2] - w % window_size[2]) % window_size[2] 596 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1)) 597 | _, dp, hp, wp, _ = x.shape 598 | dims = [b, dp, hp, wp] 599 | 600 | elif len(x_shape) == 4: 601 | b, h, w, c = x.shape 602 | window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size) 603 | pad_l = pad_t = 0 604 | pad_r = (window_size[0] - h % window_size[0]) % window_size[0] 605 | pad_b = (window_size[1] - w % window_size[1]) % window_size[1] 606 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) 607 | _, hp, wp, _ = x.shape 608 | dims = [b, hp, wp] 609 | 610 | if any(i > 0 for i in shift_size): 611 | if len(x_shape) == 5: 612 | shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3)) 613 | elif len(x_shape) == 4: 614 | shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2)) 615 | attn_mask = mask_matrix 616 | else: 617 | shifted_x = x 618 | attn_mask = None 619 | 
x_windows = window_partition(shifted_x, window_size) 620 | attn_windows = self.attn(x_windows, mask=attn_mask) 621 | attn_windows = attn_windows.view(-1, *(window_size + (c,))) 622 | shifted_x = window_reverse(attn_windows, window_size, dims) 623 | if any(i > 0 for i in shift_size): 624 | if len(x_shape) == 5: 625 | x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3)) 626 | elif len(x_shape) == 4: 627 | x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2)) 628 | else: 629 | x = shifted_x 630 | 631 | if len(x_shape) == 5: 632 | if pad_d1 > 0 or pad_r > 0 or pad_b > 0: 633 | x = x[:, :d, :h, :w, :].contiguous() 634 | elif len(x_shape) == 4: 635 | if pad_r > 0 or pad_b > 0: 636 | x = x[:, :h, :w, :].contiguous() 637 | 638 | return x 639 | 640 | def forward_part2(self, x): 641 | return self.drop_path(self.mlp(self.norm2(x))) 642 | 643 | def load_from(self, weights, n_block, layer): 644 | root = f"module.{layer}.0.blocks.{n_block}." 645 | block_names = [ 646 | "norm1.weight", 647 | "norm1.bias", 648 | "attn.relative_position_bias_table", 649 | "attn.relative_position_index", 650 | "attn.qkv.weight", 651 | "attn.qkv.bias", 652 | "attn.proj.weight", 653 | "attn.proj.bias", 654 | "norm2.weight", 655 | "norm2.bias", 656 | "mlp.fc1.weight", 657 | "mlp.fc1.bias", 658 | "mlp.fc2.weight", 659 | "mlp.fc2.bias", 660 | ] 661 | with torch.no_grad(): 662 | self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]]) 663 | self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]]) 664 | self.attn.relative_position_bias_table.copy_(weights["state_dict"][root + block_names[2]]) 665 | self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]]) 666 | self.attn.qkv.weight.copy_(weights["state_dict"][root + block_names[4]]) 667 | self.attn.qkv.bias.copy_(weights["state_dict"][root + block_names[5]]) 668 | self.attn.proj.weight.copy_(weights["state_dict"][root + block_names[6]]) 669 | self.attn.proj.bias.copy_(weights["state_dict"][root + block_names[7]]) 670 | self.norm2.weight.copy_(weights["state_dict"][root + block_names[8]]) 671 | self.norm2.bias.copy_(weights["state_dict"][root + block_names[9]]) 672 | self.mlp.linear1.weight.copy_(weights["state_dict"][root + block_names[10]]) 673 | self.mlp.linear1.bias.copy_(weights["state_dict"][root + block_names[11]]) 674 | self.mlp.linear2.weight.copy_(weights["state_dict"][root + block_names[12]]) 675 | self.mlp.linear2.bias.copy_(weights["state_dict"][root + block_names[13]]) 676 | 677 | def forward(self, x, mask_matrix): 678 | shortcut = x 679 | if self.use_checkpoint: 680 | x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix) 681 | else: 682 | x = self.forward_part1(x, mask_matrix) 683 | x = shortcut + self.drop_path(x) 684 | if self.use_checkpoint: 685 | x = x + checkpoint.checkpoint(self.forward_part2, x) 686 | else: 687 | x = x + self.forward_part2(x) 688 | return x 689 | 690 | 691 | class PatchMerging(nn.Module): 692 | """ 693 | Patch merging layer based on: "Liu et al., 694 | Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 695 | " 696 | https://github.com/microsoft/Swin-Transformer 697 | """ 698 | 699 | def __init__( 700 | self, dim: int, norm_layer: Type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3 701 | ) -> None: # type: ignore 702 | """ 703 | Args: 704 | dim: number of feature channels. 705 | norm_layer: normalization layer. 706 | spatial_dims: number of spatial dims. 
707 | """ 708 | 709 | super().__init__() 710 | self.dim = dim 711 | if spatial_dims == 3: 712 | self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False) 713 | self.norm = norm_layer(8 * dim) 714 | elif spatial_dims == 2: 715 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 716 | self.norm = norm_layer(4 * dim) 717 | 718 | def forward(self, x): 719 | 720 | x_shape = x.size() 721 | if len(x_shape) == 5: 722 | b, d, h, w, c = x_shape 723 | pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1) 724 | if pad_input: 725 | x = F.pad(x, (0, 0, 0, d % 2, 0, w % 2, 0, h % 2)) 726 | x0 = x[:, 0::2, 0::2, 0::2, :] 727 | x1 = x[:, 1::2, 0::2, 0::2, :] 728 | x2 = x[:, 0::2, 1::2, 0::2, :] 729 | x3 = x[:, 0::2, 0::2, 1::2, :] 730 | x4 = x[:, 1::2, 0::2, 1::2, :] 731 | x5 = x[:, 0::2, 1::2, 0::2, :] 732 | x6 = x[:, 0::2, 0::2, 1::2, :] 733 | x7 = x[:, 1::2, 1::2, 1::2, :] 734 | x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1) 735 | 736 | elif len(x_shape) == 4: 737 | b, h, w, c = x_shape 738 | pad_input = (h % 2 == 1) or (w % 2 == 1) 739 | if pad_input: 740 | x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2)) 741 | x0 = x[:, 0::2, 0::2, :] 742 | x1 = x[:, 1::2, 0::2, :] 743 | x2 = x[:, 0::2, 1::2, :] 744 | x3 = x[:, 1::2, 1::2, :] 745 | x = torch.cat([x0, x1, x2, x3], -1) 746 | 747 | x = self.norm(x) 748 | x = self.reduction(x) 749 | return x 750 | 751 | 752 | def compute_mask(dims, window_size, shift_size, device): 753 | """Computing region masks based on: "Liu et al., 754 | Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 755 | " 756 | https://github.com/microsoft/Swin-Transformer 757 | 758 | Args: 759 | dims: dimension values. 760 | window_size: local window size. 761 | shift_size: shift size. 762 | device: device. 763 | """ 764 | 765 | cnt = 0 766 | 767 | if len(dims) == 3: 768 | d, h, w = dims 769 | img_mask = torch.zeros((1, d, h, w, 1), device=device) 770 | for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None): 771 | for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None): 772 | for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None): 773 | img_mask[:, d, h, w, :] = cnt 774 | cnt += 1 775 | 776 | elif len(dims) == 2: 777 | h, w = dims 778 | img_mask = torch.zeros((1, h, w, 1), device=device) 779 | for h in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None): 780 | for w in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None): 781 | img_mask[:, h, w, :] = cnt 782 | cnt += 1 783 | 784 | mask_windows = window_partition(img_mask, window_size) 785 | mask_windows = mask_windows.squeeze(-1) 786 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 787 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 788 | 789 | return attn_mask 790 | 791 | 792 | class BasicLayer(nn.Module): 793 | """ 794 | Basic Swin Transformer layer in one stage based on: "Liu et al., 795 | Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 796 | " 797 | https://github.com/microsoft/Swin-Transformer 798 | """ 799 | 800 | def __init__( 801 | self, 802 | dim: int, 803 | depth: int, 804 | num_heads: int, 805 | window_size: Sequence[int], 806 | drop_path: list, 807 | mlp_ratio: float = 4.0, 808 | qkv_bias: bool = False, 809 | drop: float = 0.0, 810 | attn_drop: float = 0.0, 811 | norm_layer: Type[LayerNorm] = 
nn.LayerNorm,  # type: ignore
812 |         downsample: Type[nn.Module] = None,  # type: ignore
813 |         use_checkpoint: bool = False,
814 |     ) -> None:
815 |         """
816 |         Args:
817 |             dim: number of feature channels.
818 |             depth: number of layers in this stage.
819 |             num_heads: number of attention heads.
820 |             window_size: local window size.
821 |             drop_path: stochastic depth rate.
822 |             mlp_ratio: ratio of mlp hidden dim to embedding dim.
823 |             qkv_bias: add a learnable bias to query, key, value.
824 |             drop: dropout rate.
825 |             attn_drop: attention dropout rate.
826 |             norm_layer: normalization layer.
827 |             downsample: downsample layer at the end of the layer.
828 |             use_checkpoint: use gradient checkpointing for reduced memory usage.
829 |         """
830 | 
831 |         super().__init__()
832 |         self.window_size = window_size
833 |         self.shift_size = tuple(i // 2 for i in window_size)
834 |         self.no_shift = tuple(0 for i in window_size)
835 |         self.depth = depth
836 |         self.use_checkpoint = use_checkpoint
837 |         self.blocks = nn.ModuleList(
838 |             [
839 |                 SwinTransformerBlock(
840 |                     dim=dim,
841 |                     num_heads=num_heads,
842 |                     window_size=self.window_size,
843 |                     shift_size=self.no_shift if (i % 2 == 0) else self.shift_size,
844 |                     mlp_ratio=mlp_ratio,
845 |                     qkv_bias=qkv_bias,
846 |                     drop=drop,
847 |                     attn_drop=attn_drop,
848 |                     drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
849 |                     norm_layer=norm_layer,
850 |                     use_checkpoint=use_checkpoint,
851 |                 )
852 |                 for i in range(depth)
853 |             ]
854 |         )
855 |         self.downsample = downsample
856 |         if self.downsample is not None:
857 |             self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size))
858 | 
859 |     def forward(self, x):
860 |         x_shape = x.size()
861 |         if len(x_shape) == 5:
862 |             b, c, d, h, w = x_shape
863 |             window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
864 |             x = rearrange(x, "b c d h w -> b d h w c")
865 |             dp = int(np.ceil(d / window_size[0])) * window_size[0]
866 |             hp = int(np.ceil(h / window_size[1])) * window_size[1]
867 |             wp = int(np.ceil(w / window_size[2])) * window_size[2]
868 |             attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device)
869 |             for blk in self.blocks:
870 |                 x = blk(x, attn_mask)
871 |             x = x.view(b, d, h, w, -1)
872 |             if self.downsample is not None:
873 |                 x = self.downsample(x)
874 |             x = rearrange(x, "b d h w c -> b c d h w")
875 | 
876 |         elif len(x_shape) == 4:
877 |             b, c, h, w = x_shape
878 |             window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
879 |             x = rearrange(x, "b c h w -> b h w c")
880 |             hp = int(np.ceil(h / window_size[0])) * window_size[0]
881 |             wp = int(np.ceil(w / window_size[1])) * window_size[1]
882 |             attn_mask = compute_mask([hp, wp], window_size, shift_size, x.device)
883 |             for blk in self.blocks:
884 |                 x = blk(x, attn_mask)
885 |             x = x.view(b, h, w, -1)
886 |             if self.downsample is not None:
887 |                 x = self.downsample(x)
888 |             x = rearrange(x, "b h w c -> b c h w")
889 |         return x
890 | 
891 | 
892 | class SwinTransformer(nn.Module):
893 |     """
894 |     Swin Transformer based on: "Liu et al.,
895 |     Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
896 |     <https://arxiv.org/abs/2103.14030>"
897 |     https://github.com/microsoft/Swin-Transformer
898 |     """
899 | 
900 |     def __init__(
901 |         self,
902 |         in_chans: int,
903 |         embed_dim: int,
904 |         window_size: Sequence[int],
905 |         patch_size: Sequence[int],
906 |         depths: Sequence[int],
907 |         num_heads: Sequence[int],
908 |         mlp_ratio: float = 4.0,
909 |         qkv_bias: bool = True,
910 |         drop_rate: float = 0.0,
911 |         attn_drop_rate: float = 0.0,
912 |         drop_path_rate: float = 0.0,
913 |         norm_layer: Type[LayerNorm] = nn.LayerNorm,  # type: ignore
914 |         patch_norm: bool = False,
915 |         use_checkpoint: bool = False,
916 |         spatial_dims: int = 3,
917 |     ) -> None:
918 |         """
919 |         Args:
920 |             in_chans: dimension of input channels.
921 |             embed_dim: number of linear projection output channels.
922 |             window_size: local window size.
923 |             patch_size: patch size.
924 |             depths: number of layers in each stage.
925 |             num_heads: number of attention heads.
926 |             mlp_ratio: ratio of mlp hidden dim to embedding dim.
927 |             qkv_bias: add a learnable bias to query, key, value.
928 |             drop_rate: dropout rate.
929 |             attn_drop_rate: attention dropout rate.
930 |             drop_path_rate: stochastic depth rate.
931 |             norm_layer: normalization layer.
932 |             patch_norm: add normalization after patch embedding.
933 |             use_checkpoint: use gradient checkpointing for reduced memory usage.
934 |             spatial_dims: spatial dimension.
935 |         """
936 | 
937 |         super().__init__()
938 |         self.num_layers = len(depths)
939 |         self.embed_dim = embed_dim
940 |         self.patch_norm = patch_norm
941 |         self.window_size = window_size
942 |         self.patch_size = patch_size
943 |         self.patch_embed = PatchEmbed(
944 |             patch_size=self.patch_size,
945 |             in_chans=in_chans,
946 |             embed_dim=embed_dim,
947 |             norm_layer=norm_layer if self.patch_norm else None,  # type: ignore
948 |             spatial_dims=spatial_dims,
949 |         )
950 |         self.pos_drop = nn.Dropout(p=drop_rate)
951 |         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
952 |         self.layers1 = nn.ModuleList()
953 |         self.layers2 = nn.ModuleList()
954 |         self.layers3 = nn.ModuleList()
955 |         self.layers4 = nn.ModuleList()
956 |         for i_layer in range(self.num_layers):
957 |             layer = BasicLayer(
958 |                 dim=int(embed_dim * 2**i_layer),
959 |                 depth=depths[i_layer],
960 |                 num_heads=num_heads[i_layer],
961 |                 window_size=self.window_size,
962 |                 drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
963 |                 mlp_ratio=mlp_ratio,
964 |                 qkv_bias=qkv_bias,
965 |                 drop=drop_rate,
966 |                 attn_drop=attn_drop_rate,
967 |                 norm_layer=norm_layer,
968 |                 downsample=PatchMerging,
969 |                 use_checkpoint=use_checkpoint,
970 |             )
971 |             if i_layer == 0:
972 |                 self.layers1.append(layer)
973 |             elif i_layer == 1:
974 |                 self.layers2.append(layer)
975 |             elif i_layer == 2:
976 |                 self.layers3.append(layer)
977 |             elif i_layer == 3:
978 |                 self.layers4.append(layer)
979 |         self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
980 | 
981 |     def proj_out(self, x, normalize=False):
982 |         if normalize:
983 |             x_shape = x.size()
984 |             if len(x_shape) == 5:
985 |                 n, ch, d, h, w = x_shape
986 |                 x = rearrange(x, "n c d h w -> n d h w c")
987 |                 x = F.layer_norm(x, [ch])
988 |                 x = rearrange(x, "n d h w c -> n c d h w")
989 |             elif len(x_shape) == 4:
990 |                 n, ch, h, w = x_shape
991 |                 x = rearrange(x, "n c h w -> n h w c")
992 |                 x = F.layer_norm(x, [ch])
993 |                 x = rearrange(x, "n h w c -> n c h w")
994 |         return x
995 | 
996 |     def forward(self, x, normalize=True):
997 |         x0 = self.patch_embed(x)
998 |         x0 = self.pos_drop(x0)
999 |         x0_out = self.proj_out(x0, normalize)
1000 |         x1 = self.layers1[0](x0.contiguous())
1001 |         x1_out = self.proj_out(x1, normalize)
1002 |         x2 = self.layers2[0](x1.contiguous())
1003 |         x2_out = self.proj_out(x2, normalize)
1004 |         x3 = self.layers3[0](x2.contiguous())
1005 |         x3_out = self.proj_out(x3, normalize)
1006 |         x4 = self.layers4[0](x3.contiguous())
1007 |         x4_out = self.proj_out(x4, normalize)
1008 |         return [x0_out, x1_out, x2_out, x3_out, x4_out]
1009 | 
--------------------------------------------------------------------------------
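
The `SwinTransformer.forward` above returns a list of five progressively downsampled feature maps (one per stage), which is what a SwinUNETR-style decoder and the pre-training heads consume. The snippet below is a minimal sketch, not part of the repository: it assumes the class is importable from `Pretrain/pretrain_models/swinunetr.py` (with `Pretrain/` on `PYTHONPATH`) and uses typical SwinUNETR-style settings (`embed_dim=48`, patch size 2, window size 7) purely for illustration.

```python
# Minimal sketch (assumed import path and hyperparameters, for illustration only):
# build the multi-scale Swin encoder defined above and inspect its five outputs.
import torch

from pretrain_models.swinunetr import SwinTransformer  # assumed module path

encoder = SwinTransformer(
    in_chans=1,                 # single-channel CT volume
    embed_dim=48,
    window_size=(7, 7, 7),
    patch_size=(2, 2, 2),
    depths=(2, 2, 2, 2),
    num_heads=(3, 6, 12, 24),
    mlp_ratio=4.0,
    qkv_bias=True,
    spatial_dims=3,
)

# A dummy 96^3 sub-volume, the crop size commonly used for 3D pre-training.
x = torch.randn(1, 1, 96, 96, 96)

with torch.no_grad():
    feats = encoder(x)          # [x0_out, x1_out, x2_out, x3_out, x4_out]

for i, f in enumerate(feats):
    print(f"stage {i}: {tuple(f.shape)}")
# With these assumed settings, channels double and spatial size halves per stage:
# roughly 48@48^3, 96@24^3, 192@12^3, 384@6^3, 768@3^3.
```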