├── README.md
├── config.py
├── configs
│   └── swin_tiny_patch4_window7_224_lite.yaml
├── data_demo
│   ├── classification
│   │   └── UDIAT
│   │       ├── 0
│   │       │   └── demo_01.png
│   │       ├── 1
│   │       │   └── demo_02.png
│   │       ├── config.yaml
│   │       └── test.txt
│   └── segmentation
│       └── BUSIS
│           ├── config.yaml
│           ├── imgs
│           │   ├── demo_01.png
│           │   └── demo_02.png
│           ├── masks
│           │   ├── demo_01.png
│           │   └── demo_02.png
│           └── test.txt
├── datasets
│   ├── dataset.py
│   └── omni_dataset.py
├── networks
│   └── omni_vision_transformer.py
├── omni_test.py
├── omni_train.py
├── omni_trainer.py
├── pretrained_ckpt
│   └── .gitkeeep
├── requirements.txt
└── utils.py


/README.md:
--------------------------------------------------------------------------------
# UniUSNet: A Promptable Framework for Universal Ultrasound Disease Prediction and Tissue Segmentation

UniUSNet is a universal framework for ultrasound image classification and segmentation, featuring:

- A novel promptable module for incorporating detailed information into the model's learning process.
- Versatility across ultrasound natures, anatomical positions, and input types, with proficiency in both segmentation and classification tasks.
- Strong generalization capabilities, demonstrated through zero-shot and fine-tuning experiments on new datasets.

For more details, see the accompanying paper and the [Project Page](https://zehui-lin.github.io/UniUSNet/).

> [**UniUSNet: A Promptable Framework for Universal Ultrasound Disease Prediction and Tissue Segmentation**](https://doi.org/10.1109/BIBM62325.2024.10822429)
Zehui Lin, Zhuoneng Zhang, Xindi Hu, Zhifan Gao, Xin Yang, Yue Sun, Dong Ni, Tao Tan. BIBM, 2024.

## Installation
- Clone this repository.
```
git clone https://github.com/Zehui-Lin/UniUSNet.git
cd UniUSNet
```
- Create a new conda environment.
```
conda create -n UniUSNet python=3.10
conda activate UniUSNet
```
- Install the required packages.
```
pip install -r requirements.txt
```

## Data

- BroadUS-9.7K consists of ten publicly available datasets: [BUSI](https://www.kaggle.com/datasets/aryashah2k/breast-ultrasound-images-dataset), [BUSIS](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC9025635/), [UDIAT](https://ieeexplore.ieee.org/abstract/document/8003418), [BUS-BRA](https://aapm.onlinelibrary.wiley.com/doi/abs/10.1002/mp.16812), [Fatty-Liver](https://link.springer.com/article/10.1007/s11548-018-1843-2#Sec8), [kidneyUS](http://rsingla.ca/kidneyUS/), [DDTI](https://www.kaggle.com/datasets/dasmehdixtr/ddti-thyroid-ultrasound-images/data), [Fetal HC](https://hc18.grand-challenge.org/), [CAMUS](https://www.creatis.insa-lyon.fr/Challenge/camus/index.html), and [Appendix](https://zenodo.org/records/7669442).
- Prepare the data by downloading the datasets and organizing them as follows:

```
data
├── classification
│   ├── UDIAT
│   │   ├── 0
│   │   │   ├── 000001.png
│   │   │   └── ...
│   │   ├── 1
│   │   │   ├── 000100.png
│   │   │   └── ...
│   │   ├── config.yaml
│   │   ├── test.txt
│   │   ├── train.txt
│   │   └── val.txt
│   └── ...
└── segmentation
    ├── BUSIS
    │   ├── config.yaml
    │   ├── imgs
    │   │   ├── 000001.png
    │   │   └── ...
    │   ├── masks
    │   │   ├── 000001.png
    │   │   └── ...
    │   ├── test.txt
    │   ├── train.txt
    │   └── val.txt
    └── ...
```
- `train.txt`, `val.txt`, and `test.txt` list image paths relative to the dataset folder: `<class>/<file>` for classification and file names under `imgs/` for segmentation (each mask in `masks/` has the same name as its image).
- `config.yaml` maps label indices to class names for classification (e.g. `0:Benign`) and to class names and mask pixel values for segmentation (e.g. `1:nodule:255`).
- Please refer to the `data_demo` folder for examples.
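If you need to create the split files for a new dataset, the helper below is one way to do it. It is a minimal sketch and not part of this repository: the function name `write_splits`, the 70/10/20 ratios, and the fixed seed are assumptions. It follows the path conventions of `data_demo`: classification entries are `<class>/<file>` (the dataset code parses the class index from the folder name, so class folders should be named with integers), and segmentation entries are file names under `imgs/`.

```python
import os
import random


def write_splits(dataset_dir, task, ratios=(0.7, 0.1, 0.2), seed=0):
    """Write train.txt, val.txt and test.txt into dataset_dir.

    Hypothetical helper (not part of UniUSNet). Entries follow the data_demo
    layout: "<class>/<file>" for classification, "<file>" (relative to
    imgs/) for segmentation. Returns the number of entries per split.
    """
    if task == "classification":
        # One sub-folder per integer class label; loose files such as
        # config.yaml or existing split files are skipped.
        entries = [
            f"{cls}/{name}"
            for cls in sorted(os.listdir(dataset_dir))
            if os.path.isdir(os.path.join(dataset_dir, cls))
            for name in sorted(os.listdir(os.path.join(dataset_dir, cls)))
        ]
    elif task == "segmentation":
        # Masks are expected under masks/ with the same file names.
        entries = sorted(os.listdir(os.path.join(dataset_dir, "imgs")))
    else:
        raise ValueError(f"unknown task: {task}")

    # A seeded generator keeps the split reproducible across runs.
    random.Random(seed).shuffle(entries)
    n_train = int(len(entries) * ratios[0])
    n_val = int(len(entries) * ratios[1])
    splits = {
        "train": entries[:n_train],
        "val": entries[n_train:n_train + n_val],
        "test": entries[n_train + n_val:],
    }
    for split, names in splits.items():
        with open(os.path.join(dataset_dir, split + ".txt"), "w") as f:
            f.write("".join(name + "\n" for name in names))
    return {split: len(names) for split, names in splits.items()}
```

For example, `write_splits("data/classification/UDIAT", "classification")` would write the three list files next to the class folders; whatever is not assigned to train or val ends up in test, so rounding never drops an image.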

## Training
Training uses `torch.distributed` and supports both multi-GPU and single-GPU runs (set `--nproc_per_node` to the number of GPUs). To train the model, run:
```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=1234 omni_train.py --output_dir exp_out/trial_1 --prompt
```

## Testing
To test the model, run:
```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=1234 omni_test.py --output_dir exp_out/trial_1 --prompt
```


## Checkpoints
- You can download the pre-trained checkpoints from [BaiduYun](https://pan.baidu.com/s/1uciwM5K4wRiMWnrAsB4qMQ?pwd=x390).

## Pretrained Weights

To train your own model, please download the Swin Transformer backbone weights and place them in the `pretrained_ckpt/` directory:

* [swin\_tiny\_patch4\_window7\_224.pth](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth)

The folder structure should look like:

```
pretrained_ckpt
└── swin_tiny_patch4_window7_224.pth
```

## Citation
If you find this work useful, please consider citing:

```
@inproceedings{lin2024uniusnet,
  title={UniUSNet: A Promptable Framework for Universal Ultrasound Disease Prediction and Tissue Segmentation},
  author={Lin, Zehui and Zhang, Zhuoneng and Hu, Xindi and Gao, Zhifan and Yang, Xin and Sun, Yue and Ni, Dong and Tan, Tao},
  booktitle={2024 IEEE International Conference on Bioinformatics and Biomedicine (BIBM)},
  pages={3501--3504},
  year={2024},
  organization={IEEE}
}
```

## Acknowledgements
This repository is based on the [Swin-Unet](https://github.com/HuCaoFighting/Swin-Unet) repository. We thank the authors for their contributions.

--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------

import os
import yaml
from yacs.config import CfgNode as CN

_C = CN()

# Base config files
_C.BASE = ['']

# -----------------------------------------------------------------------------
# Data settings
# -----------------------------------------------------------------------------
_C.DATA = CN()
# Batch size for a single GPU, could be overwritten by command line argument
_C.DATA.BATCH_SIZE = 128
# Path to dataset, could be overwritten by command line argument
_C.DATA.DATA_PATH = ''
# Dataset name
_C.DATA.DATASET = 'imagenet'
# Input image size
_C.DATA.IMG_SIZE = 224
# Interpolation to resize image (random, bilinear, bicubic)
_C.DATA.INTERPOLATION = 'bicubic'
# Use zipped dataset instead of folder dataset
# could be overwritten by command line argument
_C.DATA.ZIP_MODE = False
# Cache data in memory, could be overwritten by command line argument
_C.DATA.CACHE_MODE = 'part'
# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
_C.DATA.PIN_MEMORY = True
# Number of data loading threads
_C.DATA.NUM_WORKERS = 8

# -----------------------------------------------------------------------------
# Model settings
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Model type
_C.MODEL.TYPE = 'swin'
# Model name
_C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
# Pretrained backbone checkpoint
_C.MODEL.PRETRAIN_CKPT = './pretrained_ckpt/swin_tiny_patch4_window7_224.pth'
# Checkpoint to resume, could be overwritten by command line argument
_C.MODEL.RESUME = ''
# Number of classes, overwritten in data preparation
_C.MODEL.NUM_CLASSES = 1000
# Dropout rate
_C.MODEL.DROP_RATE = 0.0
# Drop path rate
_C.MODEL.DROP_PATH_RATE = 0.1
# Label smoothing
_C.MODEL.LABEL_SMOOTHING = 0.1

# Swin Transformer parameters
_C.MODEL.SWIN = CN()
_C.MODEL.SWIN.PATCH_SIZE = 4
_C.MODEL.SWIN.IN_CHANS = 3
_C.MODEL.SWIN.EMBED_DIM = 96
_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN.ENCODER_DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN.WINDOW_SIZE = 7
_C.MODEL.SWIN.MLP_RATIO = 4.
_C.MODEL.SWIN.QKV_BIAS = True
_C.MODEL.SWIN.QK_SCALE = None
_C.MODEL.SWIN.APE = False
_C.MODEL.SWIN.PATCH_NORM = True
_C.MODEL.SWIN.FINAL_UPSAMPLE = "expand_first"

# -----------------------------------------------------------------------------
# Training settings
# -----------------------------------------------------------------------------
_C.TRAIN = CN()
_C.TRAIN.START_EPOCH = 0
_C.TRAIN.EPOCHS = 300
_C.TRAIN.WARMUP_EPOCHS = 20
_C.TRAIN.WEIGHT_DECAY = 0.05
_C.TRAIN.BASE_LR = 5e-4
_C.TRAIN.WARMUP_LR = 5e-7
_C.TRAIN.MIN_LR = 5e-6
# Clip gradient norm
_C.TRAIN.CLIP_GRAD = 5.0
# Auto resume from latest checkpoint
_C.TRAIN.AUTO_RESUME = True
# Gradient accumulation steps
# could be overwritten by command line argument
_C.TRAIN.ACCUMULATION_STEPS = 0
# Whether to use gradient checkpointing to save memory
# could be overwritten by command line argument
_C.TRAIN.USE_CHECKPOINT = False

# LR scheduler
_C.TRAIN.LR_SCHEDULER = CN()
_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
# Epoch interval to decay LR, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
# LR decay rate, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1

# Optimizer
_C.TRAIN.OPTIMIZER = CN()
_C.TRAIN.OPTIMIZER.NAME = 'adamw'
# Optimizer epsilon
_C.TRAIN.OPTIMIZER.EPS = 1e-8
# Optimizer betas
_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
# SGD momentum
_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9

# -----------------------------------------------------------------------------
# Augmentation settings
# -----------------------------------------------------------------------------
_C.AUG = CN()
# Color jitter factor
_C.AUG.COLOR_JITTER = 0.4
# Use AutoAugment policy. "v0" or "original"
_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
# Random erase prob
_C.AUG.REPROB = 0.25
# Random erase mode
_C.AUG.REMODE = 'pixel'
# Random erase count
_C.AUG.RECOUNT = 1
# Mixup alpha, mixup enabled if > 0
_C.AUG.MIXUP = 0.8
# Cutmix alpha, cutmix enabled if > 0
_C.AUG.CUTMIX = 1.0
# Cutmix min/max ratio, overrides alpha and enables cutmix if set
_C.AUG.CUTMIX_MINMAX = None
# Probability of performing mixup or cutmix when either/both is enabled
_C.AUG.MIXUP_PROB = 1.0
# Probability of switching to cutmix when both mixup and cutmix enabled
_C.AUG.MIXUP_SWITCH_PROB = 0.5
# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
_C.AUG.MIXUP_MODE = 'batch'

# -----------------------------------------------------------------------------
# Testing settings
# -----------------------------------------------------------------------------
_C.TEST = CN()
# Whether to use center crop when testing
_C.TEST.CROP = True

# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------
# Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
# overwritten by command line argument
_C.AMP_OPT_LEVEL = ''
# Path to output folder, overwritten by command line argument
_C.OUTPUT = ''
# Tag of experiment, overwritten by command line argument
_C.TAG = 'default'
# Frequency to save checkpoint
_C.SAVE_FREQ = 1
# Frequency of logging info
_C.PRINT_FREQ = 10
# Fixed random seed
_C.SEED = 0
# Perform evaluation only, overwritten by command line argument
_C.EVAL_MODE = False
# Test throughput only, overwritten by command line argument
_C.THROUGHPUT_MODE = False
# local rank for DistributedDataParallel, given by command line argument
_C.LOCAL_RANK = 0


def _update_config_from_file(config, cfg_file):
    config.defrost()
    with open(cfg_file, 'r') as f:
        yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)

    # Recursively merge any BASE configs (paths relative to cfg_file) first
    for cfg in yaml_cfg.setdefault('BASE', ['']):
        if cfg:
            _update_config_from_file(
                config, os.path.join(os.path.dirname(cfg_file), cfg)
            )
    print('=> merge config from {}'.format(cfg_file))
    config.merge_from_file(cfg_file)
    config.freeze()


def update_config(config, args):
    _update_config_from_file(config, args.cfg)

    config.defrost()
    if args.opts:
        config.merge_from_list(args.opts)

    # merge from specific arguments
    if args.batch_size:
        config.DATA.BATCH_SIZE = args.batch_size
    if args.zip:
        config.DATA.ZIP_MODE = True
    if args.cache_mode:
        config.DATA.CACHE_MODE = args.cache_mode
    if args.resume:
        config.MODEL.RESUME = args.resume
    if args.accumulation_steps:
        config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
    if args.use_checkpoint:
        config.TRAIN.USE_CHECKPOINT = True
    if args.amp_opt_level:
        config.AMP_OPT_LEVEL = args.amp_opt_level
    if args.tag:
        config.TAG = args.tag
    if args.eval:
        config.EVAL_MODE = True
    if args.throughput:
        config.THROUGHPUT_MODE = True

    config.freeze()


def get_config(args):
    """Get a yacs CfgNode object with default values."""
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    config = _C.clone()
    update_config(config, args)

    return config

--------------------------------------------------------------------------------
/configs/swin_tiny_patch4_window7_224_lite.yaml:
--------------------------------------------------------------------------------
MODEL:
  TYPE: swin
  NAME: swin_tiny_patch4_window7_224
  DROP_PATH_RATE: 0.2
  PRETRAIN_CKPT: "./pretrained_ckpt/swin_tiny_patch4_window7_224.pth"
  SWIN:
    FINAL_UPSAMPLE: "expand_first"
    EMBED_DIM: 96
    ENCODER_DEPTHS: [4, 4, 4, 4]
    DECODER_DEPTHS: [2, 2, 2, 2]
    NUM_HEADS: [3, 6, 12, 24]
    WINDOW_SIZE: 7
--------------------------------------------------------------------------------
/data_demo/classification/UDIAT/0/demo_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zehui-Lin/UniUSNet/c61a55a3cd967188dd9eaf0cb4f3d29e924c999e/data_demo/classification/UDIAT/0/demo_01.png
--------------------------------------------------------------------------------
/data_demo/classification/UDIAT/1/demo_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zehui-Lin/UniUSNet/c61a55a3cd967188dd9eaf0cb4f3d29e924c999e/data_demo/classification/UDIAT/1/demo_02.png
--------------------------------------------------------------------------------
/data_demo/classification/UDIAT/config.yaml:
--------------------------------------------------------------------------------
0:Benign
1:Malignant

--------------------------------------------------------------------------------
/data_demo/classification/UDIAT/test.txt:
--------------------------------------------------------------------------------
0/demo_01.png
1/demo_02.png
--------------------------------------------------------------------------------
/data_demo/segmentation/BUSIS/config.yaml:
--------------------------------------------------------------------------------
0:background:0
1:nodule:255

--------------------------------------------------------------------------------
/data_demo/segmentation/BUSIS/imgs/demo_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zehui-Lin/UniUSNet/c61a55a3cd967188dd9eaf0cb4f3d29e924c999e/data_demo/segmentation/BUSIS/imgs/demo_01.png
--------------------------------------------------------------------------------
/data_demo/segmentation/BUSIS/imgs/demo_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zehui-Lin/UniUSNet/c61a55a3cd967188dd9eaf0cb4f3d29e924c999e/data_demo/segmentation/BUSIS/imgs/demo_02.png
--------------------------------------------------------------------------------
/data_demo/segmentation/BUSIS/masks/demo_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zehui-Lin/UniUSNet/c61a55a3cd967188dd9eaf0cb4f3d29e924c999e/data_demo/segmentation/BUSIS/masks/demo_01.png
--------------------------------------------------------------------------------
/data_demo/segmentation/BUSIS/masks/demo_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zehui-Lin/UniUSNet/c61a55a3cd967188dd9eaf0cb4f3d29e924c999e/data_demo/segmentation/BUSIS/masks/demo_02.png
--------------------------------------------------------------------------------
/data_demo/segmentation/BUSIS/test.txt:
--------------------------------------------------------------------------------
demo_01.png
demo_02.png
--------------------------------------------------------------------------------
/datasets/dataset.py:
--------------------------------------------------------------------------------
import os
import cv2
import random
import numpy as np
import torch
from scipy import ndimage
from scipy.ndimage import zoom
from torch.utils.data import Dataset

from datasets.omni_dataset import position_prompt_dict
from datasets.omni_dataset import nature_prompt_dict

from datasets.omni_dataset import position_prompt_one_hot_dict
from datasets.omni_dataset import nature_prompt_one_hot_dict
from datasets.omni_dataset import type_prompt_one_hot_dict


def random_horizontal_flip(image, label):
    # Mirror image and mask along the width axis
    axis = 1
    image = np.flip(image, axis=axis).copy()
    label = np.flip(label, axis=axis).copy()
    return image, label


def random_rotate(image, label):
    # Rotate image and mask by the same random integer angle in [-20, 20);
    # order=0 (nearest neighbour) keeps mask values discrete
    angle = np.random.randint(-20, 20)
    image = ndimage.rotate(image, angle, order=0, reshape=False)
    label = ndimage.rotate(label, angle, order=0, reshape=False)
    return image, label


class RandomGenerator(object):
    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        if 'type_prompt' in sample:
            type_prompt = sample['type_prompt']

        # Flip with probability 0.5; otherwise rotate with probability 0.5
        if random.random() > 0.5:
            image, label = random_horizontal_flip(image, label)
        elif random.random() > 0.5:
            image, label = random_rotate(image, label)
        x, y, _ = image.shape

        # Resize so that the shorter side matches the (square) output size
        if x > y:
            image = zoom(image, (self.output_size[0] / y, self.output_size[1] / y, 1), order=1)
            label = zoom(label, (self.output_size[0] / y, self.output_size[1] / y), order=0)
        else:
            image = zoom(image, (self.output_size[0] / x, self.output_size[1] / x, 1), order=1)
            label = zoom(label, (self.output_size[0] / x, self.output_size[1] / x), order=0)

        # Random rescaling by a factor in [0.8, 1.2]
        scale = random.uniform(0.8, 1.2)
        image = zoom(image, (scale, scale, 1), order=1)
        label = zoom(label, (scale, scale), order=0)

        x, y, _ = image.shape
        if scale > 1:
            # Upscaled: center-crop back to the output size
            startx = x//2 - (self.output_size[0]//2)
            starty = y//2 - (self.output_size[1]//2)
            image = image[startx:startx+self.output_size[0], starty:starty+self.output_size[1], :]
            label = label[startx:startx+self.output_size[0], starty:starty+self.output_size[1]]
        else:
            # Downscaled: center-crop any side that is still too long ...
            if x > self.output_size[0]:
                startx = x//2 - (self.output_size[0]//2)
                image = image[startx:startx+self.output_size[0], :, :]
                label = label[startx:startx+self.output_size[0], :]
            if y > self.output_size[1]:
                starty = y//2 - (self.output_size[1]//2)
                image = image[:, starty:starty+self.output_size[1], :]
                label = label[:, starty:starty+self.output_size[1]]
            # ... then zero-pad to the output size
            x, y, _ = image.shape
            new_image = np.zeros((self.output_size[0], self.output_size[1], 3))
            new_label = np.zeros((self.output_size[0], self.output_size[1]))
            if x < y:
                startx = self.output_size[0]//2 - (x//2)
                starty = 0
                new_image[startx:startx+x, starty:starty+y, :] = image
                new_label[startx:startx+x, starty:starty+y] = label
            else:
                startx = 0
                starty = self.output_size[1]//2 - (y//2)
                new_image[startx:startx+x, starty:starty+y, :] = image
                new_label[startx:startx+x, starty:starty+y] = label
            image = new_image
            label = new_label

        image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)
        label = torch.from_numpy(label.astype(np.float32))
        if 'type_prompt' in sample:
            sample = {'image': image, 'label': label.long(), 'type_prompt': type_prompt}
        else:
            sample = {'image': image, 'label': label.long()}
        return sample


class CenterCropGenerator(object):
    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        if 'type_prompt' in sample:
            type_prompt = sample['type_prompt']
        x, y, _ = image.shape
        # Resize so that the shorter side matches the output size, then center-crop
        if x > y:
            image = zoom(image, (self.output_size[0] / y, self.output_size[1] / y, 1), order=1)
            label = zoom(label, (self.output_size[0] / y, self.output_size[1] / y), order=0)
        else:
            image = zoom(image, (self.output_size[0] / x, self.output_size[1] / x, 1), order=1)
            label = zoom(label, (self.output_size[0] / x, self.output_size[1] / x), order=0)
        x, y, _ = image.shape
        startx = x//2 - (self.output_size[0]//2)
        starty = y//2 - (self.output_size[1]//2)
        image = image[startx:startx+self.output_size[0], starty:starty+self.output_size[1], :]
        label = label[startx:startx+self.output_size[0], starty:starty+self.output_size[1]]

        image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)
        label = torch.from_numpy(label.astype(np.float32))
        if 'type_prompt' in sample:
            sample = {'image': image, 'label': label.long(), 'type_prompt': type_prompt}
        else:
            sample = {'image': image, 'label': label.long()}
        return sample


class USdatasetSeg(Dataset):
    def __init__(self, base_dir, list_dir, split, transform=None, prompt=False):
        self.transform = transform
        self.split = split
        with open(os.path.join(list_dir, self.split + '.txt')) as f:
            self.sample_list = f.readlines()

        # BUSI: skip entries from the "normal" class
        self.sample_list = [sample for sample in self.sample_list if "normal" not in sample]

        self.data_dir = base_dir
        with open(os.path.join(list_dir, "config.yaml")) as f:
            self.label_info = f.readlines()
        self.prompt = prompt

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, idx):

        img_name = self.sample_list[idx].strip('\n')
        img_path = os.path.join(self.data_dir, "imgs", img_name)
        label_path = os.path.join(self.data_dir, "masks", img_name)

        image = cv2.imread(img_path)
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)

        # config.yaml lines have the form "index:name:pixel_value";
        # map each pixel value in the mask to its class index
        label_info_list = [info.strip().split(":") for info in self.label_info]
        for single_label_info in label_info_list:
            label_index = int(single_label_info[0])
            label_value_in_image = int(single_label_info[2])
            label[label == label_value_in_image] = label_index

        # Binarize: every foreground class becomes 1
        label[label > 0] = 1

        sample = {'image': image/255.0, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        sample['case_name'] = self.sample_list[idx].strip('\n')
        if self.prompt:
            dataset_name = img_path.split("/")[-3]
            sample['type_prompt'] = type_prompt_one_hot_dict["whole"]
            sample['nature_prompt'] = nature_prompt_one_hot_dict[nature_prompt_dict[dataset_name]]
            sample['position_prompt'] = position_prompt_one_hot_dict[position_prompt_dict[dataset_name]]
        return sample


class USdatasetCls(Dataset):
    def __init__(self, base_dir, list_dir, split, transform=None, prompt=False):
        self.transform = transform  # applied to the sample dict
        self.split = split
        with open(os.path.join(list_dir, self.split + '.txt')) as f:
            self.sample_list = f.readlines()

        # BUSI: skip entries from the "normal" class
        self.sample_list = [sample for sample in self.sample_list if "normal" not in sample]

        self.data_dir = base_dir
        with open(os.path.join(list_dir, "config.yaml")) as f:
            self.label_info = f.readlines()
        self.prompt = prompt

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, idx):

        img_name = self.sample_list[idx].strip('\n')
        img_path = os.path.join(self.data_dir, img_name)

        image = cv2.imread(img_path)
        # The class index is the parent folder name, e.g. "1/demo_02.png" -> 1
        label = int(img_name.split("/")[0])

        # A dummy all-zero mask lets the segmentation transforms be reused
        sample = {'image': image/255.0, 'label': np.zeros(image.shape[:2])}
        if self.transform:
            sample = self.transform(sample)
        sample['label'] = torch.from_numpy(np.array(label))
        sample['case_name'] = self.sample_list[idx].strip('\n')
        if self.prompt:
            dataset_name = img_path.split("/")[-3]
            sample['type_prompt'] = type_prompt_one_hot_dict["whole"]
            sample['nature_prompt'] = nature_prompt_one_hot_dict[nature_prompt_dict[dataset_name]]
            sample['position_prompt'] = position_prompt_one_hot_dict[position_prompt_dict[dataset_name]]
        return sample

--------------------------------------------------------------------------------
/datasets/omni_dataset.py:
--------------------------------------------------------------------------------
import os
import cv2
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch import Tensor
from typing import Sequence

# prompt info dict
# task prompt
task_prompt_list = [
    "segmentation",
    "classification",
]
# position prompt
position_prompt_dict = {
    "BUS-BRA": "breast",
    "BUSIS": "breast",
    "CAMUS": "cardiac",
    "DDTI": "thyroid",
    "Fetal_HC": "head",
    "kidneyUS": "kidney",
    "UDIAT": "breast",
    "Appendix": "appendix",
    "Fatty-Liver": "liver",
    "BUSI": "breast",
}
# nature prompt
nature_prompt_dict = {
    "BUS-BRA": "tumor",
    "BUSIS": "tumor",
    "CAMUS": "organ",
    "DDTI": "tumor",
    "Fetal_HC": "organ",
    "kidneyUS": "organ",
    "UDIAT": "organ",
    "Appendix": "organ",
    "Fatty-Liver": "organ",
    "BUSI": "tumor",
}
# type prompt
available_type_prompt_list = [
    "BUS-BRA",
    "BUSIS",
    "CAMUS",
    "DDTI",
    "Fetal_HC",
    "kidneyUS",
    "UDIAT",
    "BUSI"
]

# prompt one-hot
# position prompt
position_prompt_one_hot_dict = {
    "breast": [1, 0, 0, 0, 0, 0, 0, 0],
    "cardiac": [0, 1, 0, 0, 0, 0, 0, 0],
    "thyroid": [0, 0, 1, 0, 0, 0, 0, 0],
    "head": [0, 0, 0, 1, 0, 0, 0, 0],
    "kidney": [0, 0, 0, 0, 1, 0, 0, 0],
    "appendix": [0, 0, 0, 0, 0, 1, 0, 0],
    "liver": [0, 0, 0, 0, 0, 0, 1, 0],
    "indis": [0, 0, 0, 0, 0, 0, 0, 1]
}
# task prompt
task_prompt_one_hot_dict = {
    "segmentation": [1, 0],
    "classification": [0, 1]
}
# nature prompt
nature_prompt_one_hot_dict = {
    "tumor": [1, 0],
    "organ": [0, 1],
}
# type prompt
type_prompt_one_hot_dict = {
    "whole": [1, 0, 0],
    "local": [0, 1, 0],
    "location": [0, 0, 1],
}


def list_add_prefix(txt_path, prefix_1, prefix_2):

    with open(txt_path, 'r') as f:
        lines = f.readlines()
    if prefix_2 is not None:
        return [os.path.join(prefix_1, prefix_2, line.strip('\n')) for line in lines]
    else:
        return [os.path.join(prefix_1, line.strip('\n')) for line in lines]


class WeightedRandomSamplerDDP(DistributedSampler):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Each replica only draws from the strided index subset ``rank, rank + num_replicas, ...``
    and yields ``num_samples // num_replicas`` indices.

    Args:
        data_set: Dataset used for sampling.
        weights (sequence) : a sequence of weights, not necessarily summing up to one
        num_replicas (int): Number of processes participating in
            distributed training.
        rank (int): Rank of the current process within :attr:`num_replicas`.
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
        generator (Generator): Generator used in sampling.
112 | """ 113 | weights: Tensor 114 | num_samples: int 115 | replacement: bool 116 | 117 | def __init__(self, data_set, weights: Sequence[float], num_replicas: int, rank: int, num_samples: int, 118 | replacement: bool = True, generator=None) -> None: 119 | super(WeightedRandomSamplerDDP, self).__init__(data_set, num_replicas, rank) 120 | if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \ 121 | num_samples <= 0: 122 | raise ValueError("num_samples should be a positive integer " 123 | "value, but got num_samples={}".format(num_samples)) 124 | if not isinstance(replacement, bool): 125 | raise ValueError("replacement should be a boolean value, but got " 126 | "replacement={}".format(replacement)) 127 | self.weights = torch.as_tensor(weights, dtype=torch.double) 128 | self.num_samples = num_samples 129 | self.replacement = replacement 130 | self.generator = generator 131 | self.num_replicas = num_replicas 132 | self.rank = rank 133 | self.weights = self.weights[self.rank::self.num_replicas] 134 | self.num_samples = self.num_samples // self.num_replicas 135 | 136 | def __iter__(self): 137 | rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator) 138 | rand_tensor = self.rank + rand_tensor * self.num_replicas 139 | return iter(rand_tensor.tolist()) 140 | 141 | def __len__(self): 142 | return self.num_samples 143 | 144 | 145 | class USdatasetOmni_seg(Dataset): 146 | def __init__(self, base_dir, split, transform=None, prompt=False): 147 | self.transform = transform 148 | self.split = split 149 | self.data_dir = base_dir 150 | self.sample_list = [] 151 | self.subset_len = [] 152 | self.prompt = prompt 153 | 154 | self.sample_list.extend(list_add_prefix(os.path.join( 155 | base_dir, "segmentation", "BUS-BRA", split + ".txt"), "BUS-BRA", "imgs")) 156 | self.sample_list.extend(list_add_prefix(os.path.join( 157 | base_dir, "segmentation", "BUSIS", split + ".txt"), "BUSIS", "imgs")) 158 | 
self.sample_list.extend(list_add_prefix(os.path.join( 159 | base_dir, "segmentation", "CAMUS", split + ".txt"), "CAMUS", "imgs")) 160 | self.sample_list.extend(list_add_prefix(os.path.join( 161 | base_dir, "segmentation", "DDTI", split + ".txt"), "DDTI", "imgs")) 162 | self.sample_list.extend(list_add_prefix(os.path.join(base_dir, "segmentation", 163 | "Fetal_HC", split + ".txt"), "Fetal_HC", "imgs")) 164 | self.sample_list.extend(list_add_prefix(os.path.join(base_dir, "segmentation", 165 | "kidneyUS", split + ".txt"), "kidneyUS", "imgs")) 166 | self.sample_list.extend(list_add_prefix(os.path.join( 167 | base_dir, "segmentation", "UDIAT", split + ".txt"), "UDIAT", "imgs")) 168 | 169 | self.subset_len.append(len(list_add_prefix(os.path.join( 170 | base_dir, "segmentation", "BUS-BRA", split + ".txt"), "BUS-BRA", "imgs"))) 171 | self.subset_len.append(len(list_add_prefix(os.path.join( 172 | base_dir, "segmentation", "BUSIS", split + ".txt"), "BUSIS", "imgs"))) 173 | self.subset_len.append(len(list_add_prefix(os.path.join( 174 | base_dir, "segmentation", "CAMUS", split + ".txt"), "CAMUS", "imgs"))) 175 | self.subset_len.append(len(list_add_prefix(os.path.join( 176 | base_dir, "segmentation", "DDTI", split + ".txt"), "DDTI", "imgs"))) 177 | self.subset_len.append(len(list_add_prefix(os.path.join( 178 | base_dir, "segmentation", "Fetal_HC", split + ".txt"), "Fetal_HC", "imgs"))) 179 | self.subset_len.append(len(list_add_prefix(os.path.join( 180 | base_dir, "segmentation", "kidneyUS", split + ".txt"), "kidneyUS", "imgs"))) 181 | self.subset_len.append(len(list_add_prefix(os.path.join( 182 | base_dir, "segmentation", "UDIAT", split + ".txt"), "UDIAT", "imgs"))) 183 | 184 | def __len__(self): 185 | return len(self.sample_list) 186 | 187 | def __getitem__(self, idx): 188 | 189 | img_name = self.sample_list[idx].strip('\n') 190 | img_path = os.path.join(self.data_dir, "segmentation", img_name) 191 | label_path = os.path.join(self.data_dir, "segmentation", 
img_name).replace("imgs", "masks") 192 | 193 | image = cv2.imread(img_path) 194 | label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE) 195 | 196 | dataset_name = img_name.split("/")[0] 197 | label_info = open(os.path.join(self.data_dir, "segmentation", dataset_name, "config.yaml")).readlines() 198 | 199 | label_info_list = [info.strip().split(":") for info in label_info] 200 | for single_label_info in label_info_list: 201 | label_index = int(single_label_info[0]) 202 | label_value_in_image = int(single_label_info[2]) 203 | label[label == label_value_in_image] = label_index 204 | 205 | label[label > 0] = 1 206 | 207 | if not self.prompt: 208 | sample = {'image': image/255.0, 'label': label} 209 | else: 210 | if random.random() > 0.5: 211 | x, y, w, h = cv2.boundingRect(label) 212 | length = max(w, h) 213 | 214 | if 0 in image[y:y+length, x:x+length, :].shape: 215 | # empty crop: keep the full image and label 216 | # and fall back to the "whole" type prompt 217 | sample = {'image': image/255.0, 'label': label} 218 | sample['type_prompt'] = type_prompt_one_hot_dict["whole"] 219 | else: 220 | image = image[y:y+length, x:x+length, :] 221 | label = label[y:y+length, x:x+length] 222 | sample = {'image': image/255.0, 'label': label} 223 | sample['type_prompt'] = type_prompt_one_hot_dict["local"] 224 | 225 | else: 226 | sample = {'image': image/255.0, 'label': label} 227 | sample['type_prompt'] = type_prompt_one_hot_dict["whole"] 228 | 229 | if self.transform: 230 | sample = self.transform(sample) 231 | sample['case_name'] = self.sample_list[idx].strip('\n') 232 | sample['nature_prompt'] = nature_prompt_one_hot_dict[nature_prompt_dict[dataset_name]] 233 | sample['position_prompt'] = position_prompt_one_hot_dict[position_prompt_dict[dataset_name]] 234 | sample['task_prompt'] = task_prompt_one_hot_dict["segmentation"] 235 | 236 | return sample 237 | 238 | 239 | class USdatasetOmni_cls(Dataset): 240 | def __init__(self, base_dir, split, transform=None, prompt=False): 241 | self.transform = transform 242 | self.split = split 243 | 
self.data_dir = base_dir 244 | self.sample_list = [] 245 | self.subset_len = [] 246 | self.prompt = prompt 247 | 248 | self.sample_list.extend(list_add_prefix(os.path.join( 249 | base_dir, "classification", "Appendix", split + ".txt"), "Appendix", None)) 250 | self.sample_list.extend(list_add_prefix(os.path.join( 251 | base_dir, "classification", "BUS-BRA", split + ".txt"), "BUS-BRA", None)) 252 | self.sample_list.extend(list_add_prefix(os.path.join(base_dir, "classification", 253 | "Fatty-Liver", split + ".txt"), "Fatty-Liver", None)) 254 | self.sample_list.extend(list_add_prefix(os.path.join( 255 | base_dir, "classification", "UDIAT", split + ".txt"), "UDIAT", None)) 256 | 257 | self.subset_len.append(len(list_add_prefix(os.path.join( 258 | base_dir, "classification", "Appendix", split + ".txt"), "Appendix", None))) 259 | self.subset_len.append(len(list_add_prefix(os.path.join( 260 | base_dir, "classification", "BUS-BRA", split + ".txt"), "BUS-BRA", None))) 261 | self.subset_len.append(len(list_add_prefix(os.path.join(base_dir, "classification", 262 | "Fatty-Liver", split + ".txt"), "Fatty-Liver", None))) 263 | self.subset_len.append(len(list_add_prefix(os.path.join( 264 | base_dir, "classification", "UDIAT", split + ".txt"), "UDIAT", None))) 265 | 266 | def __len__(self): 267 | return len(self.sample_list) 268 | 269 | def __getitem__(self, idx): 270 | 271 | img_name = self.sample_list[idx].strip('\n') 272 | img_path = os.path.join(self.data_dir, "classification", img_name) 273 | 274 | image = cv2.imread(img_path) 275 | dataset_name = img_name.split("/")[0] 276 | label = int(img_name.split("/")[-2]) 277 | 278 | if not self.prompt: 279 | sample = {'image': image/255.0, 'label': np.zeros(image.shape[:2])} 280 | else: 281 | if dataset_name in available_type_prompt_list: 282 | random_number = random.random() 283 | mask_path = os.path.join(self.data_dir, "segmentation", 284 | "/".join([img_name.split("/")[0], "masks", img_name.split("/")[2]])) 285 | if random_number < 
0.3: 286 | sample = {'image': image/255.0, 'label': np.zeros(image.shape[:2])} 287 | sample['type_prompt'] = type_prompt_one_hot_dict["whole"] 288 | elif random_number < 0.6: 289 | mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) 290 | x, y, w, h = cv2.boundingRect(mask) 291 | length = max(w, h) 292 | 293 | if 0 in image[y:y+length, x:x+length, :].shape: 294 | sample = {'image': image/255.0, 'label': np.zeros(image.shape[:2])} 295 | sample['type_prompt'] = type_prompt_one_hot_dict["whole"] 296 | else: 297 | image = image[y:y+length, x:x+length, :] 298 | sample = {'image': image/255.0, 'label': np.zeros(image.shape[:2])} 299 | sample['type_prompt'] = type_prompt_one_hot_dict["local"] 300 | else: 301 | mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) 302 | mask[mask > 0] = 255 303 | image = np.clip(image.astype(np.int16) + (np.expand_dims(mask, axis=2)*0.1).astype(np.int16), 0, 255).astype(np.uint8) # clip instead of letting uint8 addition wrap past 255 304 | sample = {'image': image/255.0, 'label': np.zeros(image.shape[:2])} 305 | sample['type_prompt'] = type_prompt_one_hot_dict["location"] 306 | else: 307 | sample = {'image': image/255.0, 'label': np.zeros(image.shape[:2])} 308 | sample['type_prompt'] = type_prompt_one_hot_dict["whole"] 309 | if self.transform: 310 | sample = self.transform(sample) 311 | sample['label'] = torch.from_numpy(np.array(label)) 312 | sample['case_name'] = self.sample_list[idx].strip('\n') 313 | sample['nature_prompt'] = nature_prompt_one_hot_dict[nature_prompt_dict[dataset_name]] 314 | sample['position_prompt'] = position_prompt_one_hot_dict[position_prompt_dict[dataset_name]] 315 | sample['task_prompt'] = task_prompt_one_hot_dict["classification"] 316 | 317 | return sample 318 | -------------------------------------------------------------------------------- /networks/omni_vision_transformer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | import torch.utils.checkpoint as checkpoint 8 | from timm.models.layers
import trunc_normal_ 9 | 10 | from torch.functional import F 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.utils.checkpoint as checkpoint 15 | from einops import rearrange 16 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 17 | from torch.functional import F 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class Mlp(nn.Module): 23 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 24 | super().__init__() 25 | out_features = out_features or in_features 26 | hidden_features = hidden_features or in_features 27 | self.fc1 = nn.Linear(in_features, hidden_features) 28 | self.act = act_layer() 29 | self.fc2 = nn.Linear(hidden_features, out_features) 30 | self.drop = nn.Dropout(drop) 31 | 32 | def forward(self, x): 33 | x = self.fc1(x) 34 | x = self.act(x) 35 | x = self.drop(x) 36 | x = self.fc2(x) 37 | x = self.drop(x) 38 | return x 39 | 40 | 41 | def window_partition(x, window_size): 42 | """ 43 | Args: 44 | x: (B, H, W, C) 45 | window_size (int): window size 46 | 47 | Returns: 48 | windows: (num_windows*B, window_size, window_size, C) 49 | """ 50 | B, H, W, C = x.shape 51 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 52 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 53 | return windows 54 | 55 | 56 | def window_reverse(windows, window_size, H, W): 57 | """ 58 | Args: 59 | windows: (num_windows*B, window_size, window_size, C) 60 | window_size (int): Window size 61 | H (int): Height of image 62 | W (int): Width of image 63 | 64 | Returns: 65 | x: (B, H, W, C) 66 | """ 67 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 68 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 69 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 70 | return x 71 | 72 | 73 | class WindowAttention(nn.Module): 74 | r""" Window based multi-head 
self attention (W-MSA) module with relative position bias. 75 | It supports both shifted and non-shifted windows. 76 | 77 | Args: 78 | dim (int): Number of input channels. 79 | window_size (tuple[int]): The height and width of the window. 80 | num_heads (int): Number of attention heads. 81 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 82 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 83 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 84 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 85 | """ 86 | 87 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 88 | 89 | super().__init__() 90 | self.dim = dim 91 | self.window_size = window_size # Wh, Ww 92 | self.num_heads = num_heads 93 | head_dim = dim // num_heads 94 | self.scale = qk_scale or head_dim ** -0.5 95 | 96 | # define a parameter table of relative position bias 97 | self.relative_position_bias_table = nn.Parameter( 98 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 99 | 100 | # get pair-wise relative position index for each token inside the window 101 | coords_h = torch.arange(self.window_size[0]) 102 | coords_w = torch.arange(self.window_size[1]) 103 | coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww 104 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 105 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 106 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 107 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 108 | relative_coords[:, :, 1] += self.window_size[1] - 1 109 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 110 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 111 | 
self.register_buffer("relative_position_index", relative_position_index) 112 | 113 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 114 | self.attn_drop = nn.Dropout(attn_drop) 115 | self.proj = nn.Linear(dim, dim) 116 | self.proj_drop = nn.Dropout(proj_drop) 117 | 118 | trunc_normal_(self.relative_position_bias_table, std=.02) 119 | self.softmax = nn.Softmax(dim=-1) 120 | 121 | def forward(self, x, mask=None): 122 | """ 123 | Args: 124 | x: input features with shape of (num_windows*B, N, C) 125 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 126 | """ 127 | B_, N, C = x.shape 128 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 129 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 130 | 131 | q = q * self.scale 132 | attn = (q @ k.transpose(-2, -1)) 133 | 134 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 135 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 136 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 137 | attn = attn + relative_position_bias.unsqueeze(0) 138 | 139 | if mask is not None: 140 | nW = mask.shape[0] 141 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 142 | attn = attn.view(-1, self.num_heads, N, N) 143 | attn = self.softmax(attn) 144 | else: 145 | attn = self.softmax(attn) 146 | 147 | attn = self.attn_drop(attn) 148 | 149 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 150 | x = self.proj(x) 151 | x = self.proj_drop(x) 152 | return x 153 | 154 | def extra_repr(self) -> str: 155 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 156 | 157 | def flops(self, N): 158 | # calculate flops for 1 window with token length of N 159 | flops = 0 160 | # qkv = self.qkv(x) 161 | flops += N * self.dim 
* 3 * self.dim 162 | # attn = (q @ k.transpose(-2, -1)) 163 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 164 | # x = (attn @ v) 165 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 166 | # x = self.proj(x) 167 | flops += N * self.dim * self.dim 168 | return flops 169 | 170 | 171 | class SwinTransformerBlock(nn.Module): 172 | r""" Swin Transformer Block. 173 | 174 | Args: 175 | dim (int): Number of input channels. 176 | input_resolution (tuple[int]): Input resolution. 177 | num_heads (int): Number of attention heads. 178 | window_size (int): Window size. 179 | shift_size (int): Shift size for SW-MSA. 180 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 181 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 182 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 183 | drop (float, optional): Dropout rate. Default: 0.0 184 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 185 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 186 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 187 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 188 | """ 189 | 190 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 191 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 192 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 193 | super().__init__() 194 | self.dim = dim 195 | self.input_resolution = input_resolution 196 | self.num_heads = num_heads 197 | self.window_size = window_size 198 | self.shift_size = shift_size 199 | self.mlp_ratio = mlp_ratio 200 | if min(self.input_resolution) <= self.window_size: 201 | # if the window is at least as large as the input resolution, we don't partition windows 202 | self.shift_size = 0 203 | self.window_size = min(self.input_resolution) 204 | assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)" 205 | 206 | self.norm1 = norm_layer(dim) 207 | self.attn = WindowAttention( 208 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 209 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 210 | 211 | self.drop_path = DropPath(drop_path) if drop_path > 0.
else nn.Identity() 212 | self.norm2 = norm_layer(dim) 213 | mlp_hidden_dim = int(dim * mlp_ratio) 214 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 215 | 216 | if self.shift_size > 0: 217 | # calculate attention mask for SW-MSA 218 | H, W = self.input_resolution 219 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 220 | h_slices = (slice(0, -self.window_size), 221 | slice(-self.window_size, -self.shift_size), 222 | slice(-self.shift_size, None)) 223 | w_slices = (slice(0, -self.window_size), 224 | slice(-self.window_size, -self.shift_size), 225 | slice(-self.shift_size, None)) 226 | cnt = 0 227 | for h in h_slices: 228 | for w in w_slices: 229 | img_mask[:, h, w, :] = cnt 230 | cnt += 1 231 | 232 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 233 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 234 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 235 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 236 | else: 237 | attn_mask = None 238 | 239 | self.register_buffer("attn_mask", attn_mask) 240 | 241 | def forward(self, x): 242 | H, W = self.input_resolution 243 | B, L, C = x.shape 244 | assert L == H * W, "input feature has wrong size" 245 | 246 | shortcut = x 247 | x = self.norm1(x) 248 | x = x.view(B, H, W, C) 249 | 250 | # cyclic shift 251 | if self.shift_size > 0: 252 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 253 | else: 254 | shifted_x = x 255 | 256 | # partition windows 257 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 258 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 259 | 260 | # W-MSA/SW-MSA 261 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 262 | 263 | # 
merge windows 264 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 265 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 266 | 267 | # reverse cyclic shift 268 | if self.shift_size > 0: 269 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 270 | else: 271 | x = shifted_x 272 | x = x.view(B, H * W, C) 273 | 274 | # FFN 275 | x = shortcut + self.drop_path(x) 276 | x = x + self.drop_path(self.mlp(self.norm2(x))) 277 | 278 | return x 279 | 280 | def extra_repr(self) -> str: 281 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 282 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 283 | 284 | def flops(self): 285 | flops = 0 286 | H, W = self.input_resolution 287 | # norm1 288 | flops += self.dim * H * W 289 | # W-MSA/SW-MSA 290 | nW = H * W / self.window_size / self.window_size 291 | flops += nW * self.attn.flops(self.window_size * self.window_size) 292 | # mlp 293 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 294 | # norm2 295 | flops += self.dim * H * W 296 | return flops 297 | 298 | 299 | class FinalPatchExpand_X4(nn.Module): 300 | def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm): 301 | super().__init__() 302 | self.input_resolution = input_resolution 303 | self.dim = dim 304 | self.dim_scale = dim_scale 305 | self.expand = nn.Linear(dim, 16*dim, bias=False) 306 | self.output_dim = dim 307 | self.norm = norm_layer(self.output_dim) 308 | 309 | def forward(self, x): 310 | """ 311 | x: B, H*W, C 312 | """ 313 | H, W = self.input_resolution 314 | x = self.expand(x) 315 | B, L, C = x.shape 316 | assert L == H * W, "input feature has wrong size" 317 | 318 | x = x.view(B, H, W, C) 319 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, 320 | p2=self.dim_scale, c=C//(self.dim_scale**2)) 321 | x = x.view(B, -1, 
self.output_dim) 322 | x = self.norm(x) 323 | 324 | return x 325 | 326 | 327 | class PatchMerging(nn.Module): 328 | r""" Patch Merging Layer. 329 | 330 | Args: 331 | input_resolution (tuple[int]): Resolution of input feature. 332 | dim (int): Number of input channels. 333 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 334 | """ 335 | 336 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 337 | super().__init__() 338 | self.input_resolution = input_resolution 339 | self.dim = dim 340 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 341 | self.norm = norm_layer(4 * dim) 342 | 343 | def forward(self, x): 344 | """ 345 | x: B, H*W, C 346 | """ 347 | H, W = self.input_resolution 348 | B, L, C = x.shape 349 | assert L == H * W, "input feature has wrong size" 350 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even." 351 | 352 | x = x.view(B, H, W, C) 353 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 354 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 355 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 356 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 357 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 358 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 359 | 360 | x = self.norm(x) 361 | x = self.reduction(x) 362 | 363 | return x 364 | 365 | def extra_repr(self) -> str: 366 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 367 | 368 | def flops(self): 369 | H, W = self.input_resolution 370 | flops = H * W * self.dim 371 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim # reduction 4 * self.dim -> 2 * self.dim 372 | return flops 373 | 374 | 375 | class PatchExpand(nn.Module): 376 | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): 377 | super().__init__() 378 | self.input_resolution = input_resolution 379 | self.dim = dim 380 | self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale == 2 else nn.Identity() 381 | self.norm = norm_layer(dim // dim_scale) 382
| 383 | def forward(self, x): 384 | """ 385 | x: B, H*W, C 386 | """ 387 | H, W = self.input_resolution 388 | x = self.expand(x) 389 | B, L, C = x.shape 390 | assert L == H * W, "input feature has wrong size" 391 | 392 | x = x.view(B, H, W, C) 393 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4) 394 | x = x.view(B, -1, C//4) 395 | x = self.norm(x) 396 | 397 | return x 398 | 399 | 400 | class ChannelHalf(nn.Module): 401 | def __init__(self, input_resolution=None, dim=0, norm_layer=nn.LayerNorm): 402 | super().__init__() 403 | self.linear = nn.Linear(dim, dim // 2, bias=False) 404 | self.norm = norm_layer(dim // 2) 405 | self.input_resolution = input_resolution 406 | 407 | def forward(self, x): 408 | x = self.linear(x) 409 | x = self.norm(x) 410 | return x 411 | 412 | 413 | class PatchEmbed(nn.Module): 414 | r""" Image to Patch Embedding 415 | 416 | Args: 417 | img_size (int): Image size. Default: 224. 418 | patch_size (int): Patch token size. Default: 4. 419 | in_chans (int): Number of input image channels. Default: 3. 420 | embed_dim (int): Number of linear projection output channels. Default: 96. 421 | norm_layer (nn.Module, optional): Normalization layer. 
Default: None 422 | """ 423 | 424 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 425 | super().__init__() 426 | img_size = to_2tuple(img_size) 427 | patch_size = to_2tuple(patch_size) 428 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 429 | self.img_size = img_size 430 | self.patch_size = patch_size 431 | self.patches_resolution = patches_resolution 432 | self.num_patches = patches_resolution[0] * patches_resolution[1] 433 | 434 | self.in_chans = in_chans 435 | self.embed_dim = embed_dim 436 | 437 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 438 | if norm_layer is not None: 439 | self.norm = norm_layer(embed_dim) 440 | else: 441 | self.norm = None 442 | 443 | def forward(self, x): 444 | B, C, H, W = x.shape 445 | # FIXME look at relaxing size constraints 446 | assert H == self.img_size[0] and W == self.img_size[1], \ 447 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
448 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 449 | if self.norm is not None: 450 | x = self.norm(x) 451 | return x 452 | 453 | def flops(self): 454 | Ho, Wo = self.patches_resolution 455 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 456 | if self.norm is not None: 457 | flops += Ho * Wo * self.embed_dim 458 | return flops 459 | 460 | 461 | class BasicLayer(nn.Module): 462 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 463 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 464 | drop_path=0., norm_layer=nn.LayerNorm, res_scale=None, use_checkpoint=False, 465 | ): 466 | 467 | super().__init__() 468 | self.dim = dim 469 | self.input_resolution = input_resolution 470 | self.depth = depth 471 | self.use_checkpoint = use_checkpoint 472 | 473 | # build blocks 474 | self.blocks = nn.ModuleList([ 475 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 476 | num_heads=num_heads, window_size=window_size, 477 | shift_size=0 if ( 478 | i % 2 == 0) else window_size // 2, 479 | mlp_ratio=mlp_ratio, 480 | qkv_bias=qkv_bias, qk_scale=qk_scale, 481 | drop=drop, attn_drop=attn_drop, 482 | drop_path=drop_path[i] if isinstance( 483 | drop_path, list) else drop_path, 484 | norm_layer=norm_layer) 485 | for i in range(depth)]) 486 | 487 | # optional rescaling layer (PatchMerging, PatchExpand or ChannelHalf; None at the final stage) 488 | if res_scale is not None: 489 | self.res_scale = res_scale(input_resolution, dim) 490 | else: 491 | self.res_scale = None 492 | 493 | def forward(self, x): 494 | for blk in self.blocks: 495 | if self.use_checkpoint: 496 | x = checkpoint.checkpoint(blk, x) 497 | else: 498 | x = blk(x) 499 | if self.res_scale is not None: 500 | x = self.res_scale(x) 501 | return x 502 | 503 | 504 | class SwinTransformer(nn.Module): 505 | def __init__(self, img_size=224, patch_size=4, in_chans=3, 506 | embed_dim=96, 507 | encoder_depths=[2, 2, 2, 2], 508 | decoder_depths=[2, 2, 2, 2], 509 | num_heads=[3, 6, 12, 24], 510 | 
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 511 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 512 | norm_layer=nn.LayerNorm, patch_norm=True, 513 | ape=False, 514 | use_checkpoint=False, 515 | prompt=False, 516 | ): 517 | super().__init__() 518 | 519 | print("SwinTransformer architecture information:") 520 | 521 | self.num_layers = len(encoder_depths) 522 | self.embed_dim = embed_dim 523 | self.ape = ape 524 | self.patch_norm = patch_norm 525 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 526 | self.mlp_ratio = mlp_ratio 527 | self.prompt = prompt 528 | 529 | self.patch_embed = PatchEmbed( 530 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 531 | norm_layer=norm_layer if self.patch_norm else None) 532 | num_patches = self.patch_embed.num_patches 533 | patches_resolution = self.patch_embed.patches_resolution 534 | self.patches_resolution = patches_resolution 535 | 536 | # absolute position embedding 537 | if self.ape: 538 | self.absolute_pos_embed = nn.Parameter( 539 | torch.zeros(1, num_patches, embed_dim)) 540 | trunc_normal_(self.absolute_pos_embed, std=.02) 541 | 542 | # learnable prompt embedding 543 | if self.prompt: 544 | self.dec_prompt_mlp = nn.Linear(8+2+2+3, embed_dim*8) 545 | self.dec_prompt_mlp_cls2 = nn.Linear(8+2+2+3, embed_dim*4) 546 | self.dec_prompt_mlp_seg2_cls3 = nn.Linear(8+2+2+3, embed_dim*2) 547 | self.dec_prompt_mlp_seg3 = nn.Linear(8+2+2+3, embed_dim*1) 548 | 549 | self.pos_drop = nn.Dropout(p=drop_rate) 550 | 551 | # stochastic depth 552 | enc_dpr = [x.item() for x in torch.linspace( 553 | 0, drop_path_rate, sum(encoder_depths))] 554 | 555 | ## Encoder + bottleneck ## 556 | self.layers = nn.ModuleList() 557 | for i_layer in range(self.num_layers): 558 | 559 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 560 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 561 | patches_resolution[1] // (2 ** i_layer)), 562 | depth=encoder_depths[i_layer], 563 
| num_heads=num_heads[i_layer], 564 | window_size=window_size, 565 | mlp_ratio=self.mlp_ratio, 566 | qkv_bias=qkv_bias, qk_scale=qk_scale, 567 | drop=drop_rate, attn_drop=attn_drop_rate, 568 | drop_path=enc_dpr[sum(encoder_depths[:i_layer]):sum(encoder_depths[:i_layer + 1])], 569 | norm_layer=norm_layer, 570 | res_scale=PatchMerging if (i_layer < self.num_layers - 1) else None, 571 | use_checkpoint=use_checkpoint 572 | ) 573 | self.layers.append(layer) 574 | 575 | ## Multi Decoder ## 576 | 577 | self.layers_task_seg_up = nn.ModuleList() 578 | self.layers_task_seg_skip = nn.ModuleList() 579 | self.layers_task_seg_head = nn.ModuleList() 580 | 581 | self.layers_task_cls_up = nn.ModuleList() 582 | self.layers_task_cls_head = nn.ModuleList() 583 | 584 | # stochastic depth 585 | dec_dpr = [x.item() for x in torch.linspace( 586 | 0, drop_path_rate, sum(decoder_depths))] 587 | 588 | for i_layer in range(self.num_layers): 589 | # seg 590 | self.layers_task_seg_skip.append( 591 | nn.Linear(2*int(embed_dim*2**(self.num_layers-1-i_layer)), 592 | int(embed_dim*2**(self.num_layers-1-i_layer))) if i_layer > 0 else nn.Identity() 593 | ) 594 | if i_layer == 0: 595 | self.layers_task_seg_up.append( 596 | PatchExpand(input_resolution=( 597 | patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)), 598 | patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), 599 | dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), 600 | dim_scale=2, norm_layer=norm_layer)) 601 | else: 602 | self.layers_task_seg_up.append( 603 | BasicLayer(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), 604 | input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)), 605 | patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), 606 | depth=decoder_depths[(self.num_layers-1-i_layer)], 607 | num_heads=num_heads[( 608 | self.num_layers-1-i_layer)], 609 | window_size=window_size, 610 | mlp_ratio=self.mlp_ratio, 611 | qkv_bias=qkv_bias, qk_scale=qk_scale, 612 | drop=drop_rate, 
attn_drop=attn_drop_rate, 613 | drop_path=dec_dpr[sum(decoder_depths[:( 614 | self.num_layers-1-i_layer)]):sum(decoder_depths[:(self.num_layers-1-i_layer) + 1])], 615 | norm_layer=norm_layer, 616 | res_scale=PatchExpand if (i_layer < self.num_layers - 1) else None, 617 | use_checkpoint=use_checkpoint, 618 | ) 619 | ) 620 | # cls 621 | if i_layer == 0: 622 | pass 623 | else: 624 | self.layers_task_cls_up.append( 625 | BasicLayer(dim=int(embed_dim * 2 ** (self.num_layers-i_layer)), 626 | input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-0)), 627 | patches_resolution[1] // (2 ** (self.num_layers-1-0))), 628 | depth=decoder_depths[(self.num_layers-i_layer)], 629 | num_heads=num_heads[(self.num_layers-i_layer)], 630 | window_size=window_size, 631 | mlp_ratio=self.mlp_ratio, 632 | qkv_bias=qkv_bias, qk_scale=qk_scale, 633 | drop=drop_rate, attn_drop=attn_drop_rate, 634 | drop_path=dec_dpr[sum(decoder_depths[:(self.num_layers-i_layer)]):sum(decoder_depths[:(self.num_layers-i_layer) + 1])], 635 | norm_layer=norm_layer, 636 | res_scale=ChannelHalf if (i_layer < self.num_layers - 1) else None, 637 | use_checkpoint=use_checkpoint 638 | )) 639 | 640 | self.layers_task_seg_head.append( 641 | FinalPatchExpand_X4(input_resolution=(img_size//patch_size, img_size//patch_size), dim=embed_dim) 642 | ) 643 | self.layers_task_seg_head.append( 644 | nn.Conv2d(in_channels=embed_dim, out_channels=2, kernel_size=1, bias=False) 645 | ) 646 | self.layers_task_cls_head.append( 647 | nn.Linear(self.embed_dim*2, 2) 648 | ) 649 | 650 | ## Norm Layer ## 651 | self.norm = norm_layer(self.num_features) 652 | self.norm_task_seg = norm_layer(self.embed_dim) 653 | self.norm_task_cls = norm_layer(self.embed_dim*2) 654 | 655 | self.apply(self._init_weights) 656 | 657 | def _init_weights(self, m): 658 | if isinstance(m, nn.Linear): 659 | trunc_normal_(m.weight, std=.02) 660 | if isinstance(m, nn.Linear) and m.bias is not None: 661 | nn.init.constant_(m.bias, 0) 662 | elif isinstance(m, 
nn.LayerNorm): 663 | nn.init.constant_(m.bias, 0) 664 | nn.init.constant_(m.weight, 1.0) 665 | 666 | @torch.jit.ignore 667 | def no_weight_decay(self): 668 | return {'absolute_pos_embed'} 669 | 670 | @torch.jit.ignore 671 | def no_weight_decay_keywords(self): 672 | return {'relative_position_bias_table'} 673 | 674 | # Encoder and Bottleneck 675 | def forward_features(self, x): 676 | x = self.patch_embed(x) 677 | if self.ape: 678 | x = x + self.absolute_pos_embed 679 | 680 | x = self.pos_drop(x) 681 | x_downsample = [] 682 | 683 | for layer in self.layers: 684 | x_downsample.append(x) 685 | x = layer(x) 686 | 687 | x = self.norm(x) 688 | 689 | return x, x_downsample 690 | 691 | # Decoder task head 692 | def forward_task_features(self, x, x_downsample): 693 | if self.prompt: 694 | x, position_prompt, task_prompt, type_prompt, nature_prompt = x 695 | 696 | # seg 697 | for inx, layer_seg in enumerate(self.layers_task_seg_up): 698 | if inx == 0: 699 | x_seg = layer_seg(x) 700 | else: 701 | x_seg = torch.cat([x_seg, x_downsample[3-inx]], -1) 702 | x_seg = self.layers_task_seg_skip[inx](x_seg) 703 | 704 | if self.prompt and inx > 1: 705 | if inx == 2: 706 | x_seg = layer_seg(x_seg + 707 | self.dec_prompt_mlp_seg2_cls3(torch.cat([position_prompt, task_prompt, type_prompt, nature_prompt], dim=1)).unsqueeze(1)) 708 | if inx == 3: 709 | x_seg = layer_seg(x_seg + 710 | self.dec_prompt_mlp_seg3(torch.cat([position_prompt, task_prompt, type_prompt, nature_prompt], dim=1)).unsqueeze(1)) 711 | else: 712 | x_seg = layer_seg(x_seg) 713 | 714 | x_seg = self.norm_task_seg(x_seg) 715 | 716 | H, W = self.patches_resolution 717 | B, _, _ = x_seg.shape 718 | x_seg = self.layers_task_seg_head[0](x_seg) 719 | x_seg = x_seg.view(B, 4*H, 4*W, -1) 720 | x_seg = x_seg.permute(0, 3, 1, 2) 721 | x_seg = self.layers_task_seg_head[1](x_seg) 722 | 723 | # cls 724 | for inx, layer_head in enumerate(self.layers_task_cls_up): 725 | if inx == 0: 726 | x_cls = layer_head(x) 727 | else: 728 | if 
self.prompt: 729 | if inx == 1: 730 | x_cls = layer_head(x_cls + 731 | self.dec_prompt_mlp_cls2(torch.cat([position_prompt, task_prompt, type_prompt, nature_prompt], dim=1)).unsqueeze(1)) 732 | if inx == 2: 733 | x_cls = layer_head(x_cls + 734 | self.dec_prompt_mlp_seg2_cls3(torch.cat([position_prompt, task_prompt, type_prompt, nature_prompt], dim=1)).unsqueeze(1)) 735 | else: 736 | x_cls = layer_head(x_cls) 737 | 738 | x_cls = self.norm_task_cls(x_cls) 739 | 740 | B, _, _ = x_cls.shape 741 | x_cls = x_cls.transpose(1, 2) 742 | x_cls = F.adaptive_avg_pool1d(x_cls, 1).view(B, -1) 743 | x_cls = self.layers_task_cls_head[0](x_cls) 744 | 745 | return (x_seg, x_cls) 746 | 747 | def forward(self, x): 748 | if self.prompt: 749 | x, position_prompt, task_prompt, type_prompt, nature_prompt = x 750 | x, x_downsample = self.forward_features(x) 751 | x = x + self.dec_prompt_mlp(torch.cat([position_prompt, task_prompt, 752 | type_prompt, nature_prompt], dim=1)).unsqueeze(1) 753 | x_tuple = self.forward_task_features( 754 | (x, position_prompt, task_prompt, type_prompt, nature_prompt), x_downsample) 755 | else: 756 | x, x_downsample = self.forward_features(x) 757 | x_tuple = self.forward_task_features(x, x_downsample) 758 | return x_tuple 759 | 760 | 761 | class OmniVisionTransformer(nn.Module): 762 | def __init__(self, config, 763 | prompt=False, 764 | ): 765 | super(OmniVisionTransformer, self).__init__() 766 | self.config = config 767 | self.prompt = prompt 768 | 769 | self.swin = SwinTransformer(img_size=config.DATA.IMG_SIZE, 770 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 771 | in_chans=config.MODEL.SWIN.IN_CHANS, 772 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 773 | encoder_depths=config.MODEL.SWIN.ENCODER_DEPTHS, 774 | decoder_depths=config.MODEL.SWIN.DECODER_DEPTHS, 775 | num_heads=config.MODEL.SWIN.NUM_HEADS, 776 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 777 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 778 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 779 | 
qk_scale=config.MODEL.SWIN.QK_SCALE, 780 | drop_rate=config.MODEL.DROP_RATE, 781 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 782 | ape=config.MODEL.SWIN.APE, 783 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 784 | use_checkpoint=config.TRAIN.USE_CHECKPOINT, 785 | prompt=prompt, 786 | ) 787 | 788 | def forward(self, x): 789 | if self.prompt: 790 | image = x[0].squeeze(1).permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] 791 | position_prompt = x[1] 792 | task_prompt = x[2] 793 | type_prompt = x[3] 794 | nature_prompt = x[4] 795 | result = self.swin((image, position_prompt, task_prompt, type_prompt, nature_prompt)) 796 | else: 797 | x = x.squeeze(1).permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] 798 | result = self.swin(x) 799 | return result 800 | 801 | def load_from(self, config): 802 | pretrained_path = config.MODEL.PRETRAIN_CKPT 803 | if pretrained_path is not None: 804 | print("pretrained_path:{}".format(pretrained_path)) 805 | device = torch.device( 806 | 'cuda' if torch.cuda.is_available() else 'cpu') 807 | pretrained_dict = torch.load(pretrained_path, map_location=device) 808 | pretrained_dict = pretrained_dict['model'] 809 | print("---start load pretrained model of swin encoder---") 810 | model_dict = self.swin.state_dict() 811 | full_dict = copy.deepcopy(pretrained_dict) 812 | for k, v in pretrained_dict.items(): 813 | if "layers." in k: 814 | current_layer_num = 3-int(k[7:8]) 815 | current_k = "layers_up." 
+ str(current_layer_num) + k[8:] 816 | full_dict.update({current_k: v}) 817 | for k in list(full_dict.keys()): 818 | if k in model_dict: 819 | if full_dict[k].shape != model_dict[k].shape: 820 | print("delete:{};shape pretrain:{};shape model:{}".format( 821 | k, full_dict[k].shape, model_dict[k].shape)) 822 | del full_dict[k] 823 | 824 | self.swin.load_state_dict(full_dict, strict=False) 825 | else: 826 | print("none pretrain") 827 | 828 | def load_from_self(self, pretrained_path): 829 | print("pretrained_path:{}".format(pretrained_path)) 830 | device = torch.device( 831 | 'cuda' if torch.cuda.is_available() else 'cpu') 832 | pretrained_dict = torch.load(pretrained_path, map_location=device) 833 | full_dict = copy.deepcopy(pretrained_dict) 834 | for k, v in pretrained_dict.items(): 835 | if "module.swin." in k: 836 | current_k = k[12:] 837 | full_dict.update({current_k: v}) 838 | del full_dict[k] 839 | 840 | self.swin.load_state_dict(full_dict) 841 | -------------------------------------------------------------------------------- /omni_test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import sys 6 | import numpy as np 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | 12 | from config import get_config 13 | 14 | from datasets.dataset import CenterCropGenerator 15 | from datasets.dataset import USdatasetCls, USdatasetSeg 16 | 17 | from utils import omni_seg_test 18 | from sklearn.metrics import accuracy_score 19 | 20 | from networks.omni_vision_transformer import OmniVisionTransformer as ViT_omni 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--root_path', type=str, 24 | default='data_demo/', help='root dir for data') 25 | parser.add_argument('--output_dir', type=str, help='output dir') 26 | parser.add_argument('--max_epochs', 
type=int, default=200,
help='maximum epoch number to train') 27 | parser.add_argument('--batch_size', type=int, default=16, 28 | help='batch_size per gpu') 29 | parser.add_argument('--img_size', type=int, default=224, help='input image size of the network') 30 | parser.add_argument('--is_saveout', action="store_true", help='whether to save results during inference') 31 | parser.add_argument('--test_save_dir', type=str, default='../predictions', help='directory for saving predictions') 32 | parser.add_argument('--deterministic', type=int, default=1, help='whether to use deterministic training') 33 | parser.add_argument('--base_lr', type=float, default=0.01, help='segmentation network learning rate') 34 | parser.add_argument('--seed', type=int, default=1234, help='random seed') 35 | parser.add_argument('--cfg', type=str, default="configs/swin_tiny_patch4_window7_224_lite.yaml", 36 | metavar="FILE", help='path to config file', ) 37 | parser.add_argument( 38 | "--opts", 39 | help="Modify config options by adding 'KEY VALUE' pairs.
", 40 | default=None, 41 | nargs='+', 42 | ) 43 | parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset') 44 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], 45 | help='no: no cache, ' 46 | 'full: cache all data, ' 47 | 'part: sharding the dataset into non-overlapping pieces and only cache one piece') 48 | parser.add_argument('--resume', help='resume from checkpoint') 49 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") 50 | parser.add_argument('--use-checkpoint', action='store_true', 51 | help="whether to use gradient checkpointing to save memory") 52 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'], 53 | help='mixed precision opt level, if O0, no amp is used') 54 | parser.add_argument('--tag', help='tag of experiment') 55 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 56 | parser.add_argument('--throughput', action='store_true', help='Test throughput only') 57 | 58 | parser.add_argument('--prompt', action='store_true', help='use prompts') 59 | 60 | args = parser.parse_args() 61 | config = get_config(args) 62 | 63 | 64 | def inference(args, model, test_save_path=None): 65 | import csv 66 | import time 67 | os.makedirs("exp_out", exist_ok=True) 68 | if not os.path.exists("exp_out/result.csv"): 69 | with open("exp_out/result.csv", 'w', newline='') as csvfile: 70 | writer = csv.writer(csvfile) 71 | writer.writerow(['dataset', 'task', 'metric', 'time']) 72 | 73 | seg_test_set = ["BUS-BRA", "BUSIS", "CAMUS", "DDTI", "Fetal_HC", "kidneyUS", "UDIAT"] 74 | 75 | for dataset_name in seg_test_set: 76 | num_classes = 2 77 | db_test = USdatasetSeg( 78 | base_dir=os.path.join(args.root_path, "segmentation", dataset_name), 79 | split="test", 80 | list_dir=os.path.join(args.root_path, "segmentation", dataset_name), 81 | transform=CenterCropGenerator(output_size=[args.img_size, 
args.img_size]), 82 |
prompt=args.prompt 83 | ) 84 | testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1) 85 | logging.info("{} test iterations per epoch".format(len(testloader))) 86 | model.eval() 87 | 88 | metric_list = 0.0 89 | count_matrix = np.ones((len(db_test), num_classes-1)) 90 | for i_batch, sampled_batch in tqdm(enumerate(testloader)): 91 | image, label, case_name = sampled_batch["image"], sampled_batch["label"], sampled_batch['case_name'][0] 92 | if args.prompt: 93 | position_prompt = torch.tensor(np.array(sampled_batch['position_prompt'])).permute([1, 0]).float() 94 | task_prompt = torch.tensor(np.array([[1], [0]])).permute([1, 0]).float() 95 | type_prompt = torch.tensor(np.array(sampled_batch['type_prompt'])).permute([1, 0]).float() 96 | nature_prompt = torch.tensor(np.array(sampled_batch['nature_prompt'])).permute([1, 0]).float() 97 | metric_i = omni_seg_test(image, label, model, 98 | classes=num_classes, 99 | test_save_path=test_save_path, 100 | case=case_name, 101 | prompt=args.prompt, 102 | type_prompt=type_prompt, 103 | nature_prompt=nature_prompt, 104 | position_prompt=position_prompt, 105 | task_prompt=task_prompt 106 | ) 107 | else: 108 | metric_i = omni_seg_test(image, label, model, 109 | classes=num_classes, 110 | test_save_path=test_save_path, 111 | case=case_name) 112 | zero_label_flag = False 113 | for i in range(1, num_classes): 114 | if not metric_i[i-1][1]: 115 | count_matrix[i_batch, i-1] = 0 116 | zero_label_flag = True 117 | metric_i = [element[0] for element in metric_i] 118 | metric_list += np.array(metric_i) 119 | logging.info('idx %d case %s mean_dice %f' % 120 | (i_batch, case_name, np.mean(metric_i, axis=0))) 121 | logging.info("This case has zero label: %s" % zero_label_flag) 122 | 123 | metric_list = metric_list / (count_matrix.sum(axis=0) + 1e-6) 124 | for i in range(1, num_classes): 125 | logging.info('Mean class %d mean_dice %f' % (i, metric_list[i-1])) 126 | performance = np.mean(metric_list, axis=0) 127 | 
logging.info('Testing performance in best val model: mean_dice : %f' % (performance)) 128 | 129 | with open("exp_out/result.csv", 'a', newline='') as csvfile: 130 | writer = csv.writer(csvfile) 131 | if args.prompt: 132 | writer.writerow([dataset_name, 'omni_seg_prompt@'+args.output_dir, performance, 133 | time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())]) 134 | else: 135 | writer.writerow([dataset_name, 'omni_seg@'+args.output_dir, performance, 136 | time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())]) 137 | 138 | cls_test_set = ["Appendix", "BUS-BRA", "Fatty-Liver", "UDIAT"] 139 | 140 | for dataset_name in cls_test_set: 141 | num_classes = 2 142 | db_test = USdatasetCls( 143 | base_dir=os.path.join(args.root_path, "classification", dataset_name), 144 | split="test", 145 | list_dir=os.path.join(args.root_path, "classification", dataset_name), 146 | transform=CenterCropGenerator(output_size=[args.img_size, args.img_size]), 147 | prompt=args.prompt 148 | ) 149 | 150 | testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1) 151 | logging.info("{} test iterations per epoch".format(len(testloader))) 152 | model.eval() 153 | 154 | label_list = [] 155 | prediction_list = [] 156 | for i_batch, sampled_batch in tqdm(enumerate(testloader)): 157 | image, label, case_name = sampled_batch["image"], sampled_batch["label"], sampled_batch['case_name'][0] 158 | if args.prompt: 159 | position_prompt = torch.tensor(np.array(sampled_batch['position_prompt'])).permute([1, 0]).float() 160 | task_prompt = torch.tensor(np.array([[0], [1]])).permute([1, 0]).float() 161 | type_prompt = torch.tensor(np.array(sampled_batch['type_prompt'])).permute([1, 0]).float() 162 | nature_prompt = torch.tensor(np.array(sampled_batch['nature_prompt'])).permute([1, 0]).float() 163 | with torch.no_grad(): 164 | output = model((image.cuda(), position_prompt.cuda(), task_prompt.cuda(), 165 | type_prompt.cuda(), nature_prompt.cuda()))[1] 166 | else: 167 | with torch.no_grad(): 168 | 
output = model(image.cuda())[1] 169 | 170 | output = np.argmax(torch.softmax(output, dim=1).data.cpu().numpy()) 171 | logging.info('idx %d case %s label: %d predict: %d' % (i_batch, case_name, label, output)) 172 | 173 | label_list.append(label.numpy()) 174 | prediction_list.append(output) 175 | 176 | label_list = np.array(label_list) 177 | prediction_list = np.array(prediction_list) 178 | for i in range(num_classes): 179 | logging.info('class %d acc %f' % (i, accuracy_score( 180 | (label_list == i).astype(int), (prediction_list == i).astype(int)))) 181 | performance = accuracy_score(label_list, prediction_list) 182 | logging.info('Testing performance in best val model: acc : %f' % (performance)) 183 | 184 | with open("exp_out/result.csv", 'a', newline='') as csvfile: 185 | writer = csv.writer(csvfile) 186 | if args.prompt: 187 | writer.writerow([dataset_name, 'omni_cls_prompt@'+args.output_dir, performance, 188 | time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())]) 189 | else: 190 | writer.writerow([dataset_name, 'omni_cls@'+args.output_dir, performance, 191 | time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())]) 192 | 193 | 194 | if __name__ == "__main__": 195 | if not args.deterministic: 196 | cudnn.benchmark = True 197 | cudnn.deterministic = False 198 | else: 199 | cudnn.benchmark = False 200 | cudnn.deterministic = True 201 | random.seed(args.seed) 202 | np.random.seed(args.seed) 203 | torch.manual_seed(args.seed) 204 | torch.cuda.manual_seed(args.seed) 205 | 206 | net = ViT_omni( 207 | config, 208 | prompt=args.prompt, 209 | ).cuda() 210 | net.load_from(config) 211 | 212 | snapshot = os.path.join(args.output_dir, 'best_model.pth') 213 | if not os.path.exists(snapshot): 214 | snapshot = snapshot.replace('best_model', 'epoch_'+str(args.max_epochs-1)) 215 | 216 | device = torch.device("cuda") 217 | model = net.to(device=device) 218 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 219 | torch.distributed.init_process_group(backend="nccl", 
init_method='env://', world_size=1, rank=0) 220 | model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True) 221 | 222 | import copy 223 | pretrained_dict = torch.load(snapshot, map_location=device) 224 | full_dict = copy.deepcopy(pretrained_dict) 225 | for k, v in pretrained_dict.items(): 226 | if "module." not in k: 227 | full_dict["module."+k] = v 228 | del full_dict[k] 229 | 230 | msg = model.load_state_dict(full_dict) 231 | 232 | print("self trained swin unet", msg) 233 | snapshot_name = snapshot.split('/')[-1] 234 | 235 | logging.basicConfig(filename=args.output_dir+"/"+"test_result.txt", level=logging.INFO, 236 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 237 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 238 | logging.info(str(args)) 239 | logging.info(snapshot_name) 240 | 241 | if args.is_saveout: 242 | args.test_save_dir = os.path.join(args.output_dir, "predictions") 243 | test_save_path = args.test_save_dir 244 | os.makedirs(test_save_path, exist_ok=True) 245 | else: 246 | test_save_path = None 247 | inference(args, net, test_save_path) 248 | -------------------------------------------------------------------------------- /omni_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import numpy as np 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | from networks.omni_vision_transformer import OmniVisionTransformer as ViT_omni 8 | from omni_trainer import omni_train 9 | from config import get_config 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--root_path', type=str, 13 | default='data_demo/', help='root dir for data') 14 | parser.add_argument('--output_dir', type=str, help='output dir') 15 | parser.add_argument('--max_epochs', type=int, 16 | default=200, help='maximum epoch number to train') 17 | parser.add_argument('--batch_size', type=int, 18 | 
default=16, help='batch_size per gpu') 19 | parser.add_argument('--gpu', type=str, default=None) 20 | parser.add_argument('--deterministic', type=int, default=1, 21 | help='whether to use deterministic training') 22 | parser.add_argument('--base_lr', type=float, default=0.01, 23 | help='segmentation network learning rate') 24 | parser.add_argument('--img_size', type=int, 25 | default=224, help='input image size of the network') 26 | parser.add_argument('--seed', type=int, 27 | default=1234, help='random seed') 28 | parser.add_argument('--cfg', type=str, default="configs/swin_tiny_patch4_window7_224_lite.yaml", 29 | metavar="FILE", help='path to config file', ) 30 | parser.add_argument( 31 | "--opts", 32 | help="Modify config options by adding 'KEY VALUE' pairs. ", 33 | default=None, 34 | nargs='+', 35 | ) 36 | parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset') 37 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], 38 | help='no: no cache, ' 39 | 'full: cache all data, ' 40 | 'part: sharding the dataset into non-overlapping pieces and only cache one piece') 41 | parser.add_argument('--resume', help='resume from checkpoint') 42 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") 43 | parser.add_argument('--use-checkpoint', action='store_true', 44 | help="whether to use gradient checkpointing to save memory") 45 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'], 46 | help='mixed precision opt level, if O0, no amp is used') 47 | parser.add_argument('--tag', help='tag of experiment') 48 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 49 | parser.add_argument('--throughput', action='store_true', help='Test throughput only') 50 | 51 | parser.add_argument('--pretrain_ckpt', type=str, help='pretrained checkpoint') 52 | 53 | parser.add_argument('--prompt',
action='store_true', help='using prompt for training') 54 | parser.add_argument('--adapter_ft', action='store_true', help='using adapter for fine-tuning') 55 | 56 | 57 | args = parser.parse_args() 58 | 59 | config = get_config(args) 60 | 61 | 62 | if __name__ == "__main__": 63 | if not args.deterministic: 64 | cudnn.benchmark = True 65 | cudnn.deterministic = False 66 | else: 67 | cudnn.benchmark = False 68 | cudnn.deterministic = True 69 | 70 | random.seed(args.seed) 71 | np.random.seed(args.seed) 72 | torch.manual_seed(args.seed) 73 | torch.cuda.manual_seed(args.seed) 74 | 75 | if args.batch_size != 24 and args.batch_size % 6 == 0: 76 | args.base_lr *= args.batch_size / 24 77 | 78 | if not os.path.exists(args.output_dir): 79 | os.makedirs(args.output_dir, exist_ok=True) 80 | 81 | net = ViT_omni( 82 | config, 83 | prompt=args.prompt, 84 | ).cuda() 85 | if args.pretrain_ckpt is not None: 86 | net.load_from_self(args.pretrain_ckpt) 87 | else: 88 | net.load_from(config) 89 | 90 | if args.prompt and args.adapter_ft: 91 | 92 | for name, param in net.named_parameters(): 93 | if 'prompt' in name: 94 | param.requires_grad = True 95 | print(name) 96 | else: 97 | param.requires_grad = False 98 | 99 | omni_train(args, net, args.output_dir) 100 | -------------------------------------------------------------------------------- /omni_trainer.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | import random 5 | import logging 6 | import datetime 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | import torch 11 | import torch.optim as optim 12 | import torch.distributed as dist 13 | from torch.nn.modules.loss import CrossEntropyLoss 14 | from torch.utils.data import DataLoader 15 | from torchvision import transforms 16 | from torch.utils.tensorboard import SummaryWriter 17 | 18 | 19 | from utils import DiceLoss 20 | from datasets.dataset import USdatasetCls, USdatasetSeg 21 | from datasets.omni_dataset 
import WeightedRandomSamplerDDP 22 | from datasets.omni_dataset import USdatasetOmni_cls, USdatasetOmni_seg 23 | from datasets.dataset import RandomGenerator, CenterCropGenerator 24 | from sklearn.metrics import roc_auc_score 25 | from utils import omni_seg_test 26 | 27 | 28 | def omni_train(args, model, snapshot_path): 29 | 30 | if args.gpu: 31 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 32 | 33 | torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) 34 | device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) 35 | gpu_id = rank = int(os.environ["LOCAL_RANK"]) 36 | world_size = int(os.environ["WORLD_SIZE"]) 37 | torch.distributed.init_process_group(backend="nccl", init_method='env://', timeout=datetime.timedelta(seconds=7200)) 38 | 39 | if int(os.environ["LOCAL_RANK"]) == 0: 40 | print('** GPU NUM ** : ', torch.cuda.device_count()) 41 | print('** WORLD SIZE ** : ', torch.distributed.get_world_size()) 42 | print(f"** DDP ** : Start running on rank {rank}.") 43 | 44 | logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO, 45 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 46 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 47 | logging.info(str(args)) 48 | base_lr = args.base_lr 49 | batch_size = args.batch_size 50 | 51 | def worker_init_fn(worker_id): 52 | random.seed(args.seed + worker_id) 53 | 54 | db_train_seg = USdatasetOmni_seg(base_dir=args.root_path, split="train", transform=transforms.Compose( 55 | [RandomGenerator(output_size=[args.img_size, args.img_size])]), prompt=args.prompt) 56 | 57 | weight_base = [1/4, 1/2, 2, 2, 1, 2, 2] 58 | sample_weight_seq = [[weight_base[dataset_index]] * 59 | element for dataset_index, element in enumerate(db_train_seg.subset_len)] 60 | sample_weight_seq = [element for sublist in sample_weight_seq for element in sublist] 61 | 62 | weighted_sampler_seg = WeightedRandomSamplerDDP( 63 | data_set=db_train_seg, 64 | weights=sample_weight_seq, 65 | 
num_replicas=world_size, 66 | rank=rank, 67 | num_samples=args.num_samples_seg, 68 | replacement=True 69 | ) 70 | trainloader_seg = DataLoader(db_train_seg, 71 | batch_size=batch_size, 72 | num_workers=16, 73 | pin_memory=True, 74 | worker_init_fn=worker_init_fn, 75 | sampler=weighted_sampler_seg 76 | ) 77 | 78 | db_train_cls = USdatasetOmni_cls(base_dir=args.root_path, split="train", transform=transforms.Compose( 79 | [RandomGenerator(output_size=[args.img_size, args.img_size])]), prompt=args.prompt) 80 | 81 | weight_base = [2, 1/4, 2, 2] 82 | sample_weight_seq = [[weight_base[dataset_index]] * 83 | element for dataset_index, element in enumerate(db_train_cls.subset_len)] 84 | sample_weight_seq = [element for sublist in sample_weight_seq for element in sublist] 85 | 86 | weighted_sampler_cls = WeightedRandomSamplerDDP( 87 | data_set=db_train_cls, 88 | weights=sample_weight_seq, 89 | num_replicas=world_size, 90 | rank=rank, 91 | num_samples=args.num_samples_cls, 92 | replacement=True 93 | ) 94 | trainloader_cls = DataLoader(db_train_cls, 95 | batch_size=batch_size, 96 | num_workers=16, 97 | pin_memory=True, 98 | worker_init_fn=worker_init_fn, 99 | sampler=weighted_sampler_cls 100 | ) 101 | 102 | model = model.to(device=device) 103 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 104 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu_id], find_unused_parameters=True) 105 | 106 | model.train() 107 | 108 | seg_ce_loss = CrossEntropyLoss() 109 | seg_dice_loss = DiceLoss() 110 | cls_ce_loss = CrossEntropyLoss() 111 | 112 | optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.05, betas=(0.9, 0.999)) 113 | 114 | resume_epoch = 0 115 | if args.resume is not None: 116 | model.load_state_dict(torch.load(args.resume, map_location='cpu')['model']) 117 | optimizer.load_state_dict(torch.load(args.resume, map_location='cpu')['optimizer']) 118 | resume_epoch = torch.load(args.resume, map_location='cpu')['epoch'] 119 | 120 | 
writer = SummaryWriter(snapshot_path + '/log') 121 | global_iter_num = 0 122 | seg_iter_num = 0 123 | cls_iter_num = 0 124 | max_epoch = args.max_epochs 125 | total_iterations = (len(trainloader_seg) + len(trainloader_cls)) 126 | max_iterations = args.max_epochs * total_iterations 127 | logging.info("{} batch size. {} iterations per epoch. {} max iterations ".format( 128 | batch_size, total_iterations, max_iterations)) 129 | best_performance = 0.0 130 | best_epoch = 0 131 | 132 | if int(os.environ["LOCAL_RANK"]) != 0: 133 | iterator = tqdm(range(resume_epoch, max_epoch), ncols=70, disable=True) 134 | else: 135 | iterator = tqdm(range(resume_epoch, max_epoch), ncols=70, disable=False) 136 | 137 | for epoch_num in iterator: 138 | logging.info("\n epoch: {}".format(epoch_num)) 139 | weighted_sampler_seg.set_epoch(epoch_num) 140 | weighted_sampler_cls.set_epoch(epoch_num) 141 | 142 | torch.cuda.empty_cache() 143 | for i_batch, sampled_batch in tqdm(enumerate(trainloader_seg)): 144 | image_batch, label_batch = sampled_batch['image'], sampled_batch['label'] 145 | image_batch, label_batch = image_batch.to(device=device), label_batch.to(device=device) 146 | if args.prompt: 147 | position_prompt = torch.tensor(np.array(sampled_batch['position_prompt'])).permute([ 148 | 1, 0]).float().to(device=device) 149 | task_prompt = torch.tensor(np.array(sampled_batch['task_prompt'])).permute([ 150 | 1, 0]).float().to(device=device) 151 | type_prompt = torch.tensor(np.array(sampled_batch['type_prompt'])).permute([ 152 | 1, 0]).float().to(device=device) 153 | nature_prompt = torch.tensor(np.array(sampled_batch['nature_prompt'])).permute([ 154 | 1, 0]).float().to(device=device) 155 | (x_seg, _) = model((image_batch, position_prompt, task_prompt, type_prompt, nature_prompt)) 156 | else: 157 | (x_seg, _) = model(image_batch) 158 | 159 | loss_ce = seg_ce_loss(x_seg, label_batch[:].long()) 160 | loss_dice = seg_dice_loss(x_seg, label_batch, softmax=True) 161 | loss = 0.4 * loss_ce + 0.6 * 
loss_dice 162 | 163 | optimizer.zero_grad() 164 | loss.backward() 165 | optimizer.step() 166 | lr_ = base_lr * (1.0 - global_iter_num / max_iterations) ** 0.9 167 | for param_group in optimizer.param_groups: 168 | param_group['lr'] = lr_ 169 | 170 | seg_iter_num = seg_iter_num + 1 171 | global_iter_num = global_iter_num + 1 172 | 173 | writer.add_scalar('info/lr', lr_, seg_iter_num) 174 | writer.add_scalar('info/seg_loss', loss, seg_iter_num) 175 | 176 | logging.info('global iteration %d and seg iteration %d : loss : %f' % 177 | (global_iter_num, seg_iter_num, loss.item())) 178 | 179 | torch.cuda.empty_cache() 180 | for i_batch, sampled_batch in tqdm(enumerate(trainloader_cls)): 181 | image_batch, label_batch = sampled_batch['image'], sampled_batch['label'] 182 | image_batch, label_batch = image_batch.to(device=device), label_batch.to(device=device) 183 | if args.prompt: 184 | position_prompt = torch.tensor(np.array(sampled_batch['position_prompt'])).permute([ 185 | 1, 0]).float().to(device=device) 186 | task_prompt = torch.tensor(np.array(sampled_batch['task_prompt'])).permute([ 187 | 1, 0]).float().to(device=device) 188 | type_prompt = torch.tensor(np.array(sampled_batch['type_prompt'])).permute([ 189 | 1, 0]).float().to(device=device) 190 | nature_prompt = torch.tensor(np.array(sampled_batch['nature_prompt'])).permute([ 191 | 1, 0]).float().to(device=device) 192 | (_, x_cls) = model((image_batch, position_prompt, task_prompt, type_prompt, nature_prompt)) 193 | else: 194 | (_, x_cls) = model(image_batch) 195 | 196 | loss_ce = cls_ce_loss(x_cls, label_batch[:].long()) 197 | loss = loss_ce 198 | 199 | optimizer.zero_grad() 200 | loss.backward() 201 | optimizer.step() 202 | lr_ = base_lr * (1.0 - global_iter_num / max_iterations) ** 0.9 203 | for param_group in optimizer.param_groups: 204 | param_group['lr'] = lr_ 205 | 206 | cls_iter_num = cls_iter_num + 1 207 | global_iter_num = global_iter_num + 1 208 | 209 | writer.add_scalar('info/lr', lr_, cls_iter_num) 210 | 
writer.add_scalar('info/cls_loss', loss, cls_iter_num) 211 | 212 | logging.info('global iteration %d and cls iteration %d : loss : %f' % 213 | (global_iter_num, cls_iter_num, loss.item())) 214 | 215 | dist.barrier() 216 | 217 | if int(os.environ["LOCAL_RANK"]) == 0: 218 | torch.cuda.empty_cache() 219 | 220 | save_dict = {'model': model.state_dict(), 221 | 'optimizer': optimizer.state_dict(), 222 | 'epoch': epoch_num} 223 | save_latest_path = os.path.join(snapshot_path, 'latest_{}.pth'.format(epoch_num)) 224 | if os.path.exists(os.path.join(snapshot_path, 'latest_{}.pth'.format(epoch_num-1))): 225 | os.remove(os.path.join(snapshot_path, 'latest_{}.pth'.format(epoch_num-1))) 226 | os.remove(os.path.join(snapshot_path, 'latest.pth')) 227 | torch.save(save_dict, save_latest_path) 228 | os.system('ln -s ' + os.path.abspath(save_latest_path) + ' ' + os.path.join(snapshot_path, 'latest.pth')) 229 | 230 | model.eval() 231 | total_performance = 0.0 232 | 233 | seg_val_set = ["BUS-BRA", "BUSIS", "CAMUS", "DDTI", "Fetal_HC", "kidneyUS", "UDIAT"] 234 | seg_avg_performance = 0.0 235 | 236 | for dataset_name in seg_val_set: 237 | num_classes = 2 238 | db_val = USdatasetSeg( 239 | base_dir=os.path.join(args.root_path, "segmentation", dataset_name), 240 | split="val", 241 | list_dir=os.path.join(args.root_path, "segmentation", dataset_name), 242 | transform=CenterCropGenerator(output_size=[args.img_size, args.img_size]), 243 | prompt=args.prompt 244 | ) 245 | val_loader = DataLoader(db_val, batch_size=batch_size, shuffle=False, num_workers=8) 246 | logging.info("{} val iterations per epoch".format(len(val_loader))) 247 | 248 | metric_list = 0.0 249 | count_matrix = np.ones((len(db_val), num_classes-1)) 250 | for i_batch, sampled_batch in tqdm(enumerate(val_loader)): 251 | image, label = sampled_batch["image"], sampled_batch["label"] 252 | if args.prompt: 253 | position_prompt = torch.tensor( 254 | np.array(sampled_batch['position_prompt'])).permute([1, 0]).float() 255 | 
task_prompt = torch.tensor( 256 | np.array([[1]*position_prompt.shape[0], [0]*position_prompt.shape[0]])).permute([1, 0]).float() 257 | type_prompt = torch.tensor(np.array(sampled_batch['type_prompt'])).permute([1, 0]).float() 258 | nature_prompt = torch.tensor(np.array(sampled_batch['nature_prompt'])).permute([1, 0]).float() 259 | metric_i = omni_seg_test(image, label, model, 260 | classes=num_classes, 261 | prompt=args.prompt, 262 | type_prompt=type_prompt, 263 | nature_prompt=nature_prompt, 264 | position_prompt=position_prompt, 265 | task_prompt=task_prompt 266 | ) 267 | else: 268 | metric_i = omni_seg_test(image, label, model, 269 | classes=num_classes) 270 | 271 | for sample_index in range(len(metric_i)): 272 | if not metric_i[sample_index][1]: 273 | count_matrix[i_batch*batch_size+sample_index, 0] = 0 274 | metric_i = [element[0] for element in metric_i] 275 | metric_list += np.array(metric_i).sum() 276 | 277 | metric_list = metric_list / (count_matrix.sum(axis=0) + 1e-6) 278 | performance = np.mean(metric_list, axis=0) 279 | 280 | writer.add_scalar('info/val_seg_metric_{}'.format(dataset_name), performance, epoch_num) 281 | 282 | seg_avg_performance += performance 283 | 284 | seg_avg_performance = seg_avg_performance / len(seg_val_set) 285 | total_performance += seg_avg_performance 286 | writer.add_scalar('info/val_metric_seg_Total', seg_avg_performance, epoch_num) 287 | 288 | cls_val_set = ["Appendix", "BUS-BRA", "Fatty-Liver", "UDIAT"] 289 | cls_avg_performance = 0.0 290 | 291 | for dataset_name in cls_val_set: 292 | num_classes = 2 293 | db_val = USdatasetCls( 294 | base_dir=os.path.join(args.root_path, "classification", dataset_name), 295 | split="val", 296 | list_dir=os.path.join(args.root_path, "classification", dataset_name), 297 | transform=CenterCropGenerator(output_size=[args.img_size, args.img_size]), 298 | prompt=args.prompt 299 | ) 300 | 301 | val_loader = DataLoader(db_val, batch_size=batch_size, shuffle=False, num_workers=8) 302 | 
                logging.info("{} val iterations per epoch".format(len(val_loader)))
                model.eval()

                label_list = []
                prediction_prob_list = []
                for i_batch, sampled_batch in tqdm(enumerate(val_loader)):
                    image, label = sampled_batch["image"], sampled_batch["label"]
                    if args.prompt:
                        position_prompt = torch.tensor(
                            np.array(sampled_batch['position_prompt'])).permute([1, 0]).float()
                        task_prompt = torch.tensor(
                            np.array([[0]*position_prompt.shape[0], [1]*position_prompt.shape[0]])).permute([1, 0]).float()
                        type_prompt = torch.tensor(np.array(sampled_batch['type_prompt'])).permute([1, 0]).float()
                        nature_prompt = torch.tensor(np.array(sampled_batch['nature_prompt'])).permute([1, 0]).float()
                        with torch.no_grad():
                            output = model((image.cuda(), position_prompt.cuda(), task_prompt.cuda(),
                                            type_prompt.cuda(), nature_prompt.cuda()))[1]
                    else:
                        with torch.no_grad():
                            output = model(image.cuda())[1]
                    output_prob = torch.softmax(output, dim=1).data.cpu().numpy()

                    label_list.append(label.numpy())
                    prediction_prob_list.append(output_prob)

                label_list = np.expand_dims(np.concatenate(
                    (np.array(label_list[:-1]).flatten(), np.array(label_list[-1]).flatten())), axis=1).astype('uint8')
                label_list_OneHot = np.eye(num_classes)[label_list].squeeze(1)
                performance = roc_auc_score(label_list_OneHot, np.concatenate(
                    (np.array(prediction_prob_list[:-1]).reshape(-1, 2), prediction_prob_list[-1])), multi_class='ovo')

                writer.add_scalar('info/val_cls_metric_{}'.format(dataset_name), performance, epoch_num)

                cls_avg_performance += performance

            cls_avg_performance = cls_avg_performance / len(cls_val_set)
            total_performance += cls_avg_performance
            writer.add_scalar('info/val_metric_cls_Total', cls_avg_performance, epoch_num)

            TotalAvgPerformance = total_performance/2

            logging.info('This epoch %d Validation performance: %f' % (epoch_num, TotalAvgPerformance))
            logging.info('But the best epoch is: %d and performance: %f' % (best_epoch, best_performance))
            writer.add_scalar('info/val_metric_TotalMean', TotalAvgPerformance, epoch_num)
            if TotalAvgPerformance >= best_performance:
                if os.path.exists(os.path.join(snapshot_path, 'best_model_{}_{}.pth'.format(best_epoch, round(best_performance, 4)))):
                    os.remove(os.path.join(snapshot_path, 'best_model_{}_{}.pth'.format(
                        best_epoch, round(best_performance, 4))))
                    os.remove(os.path.join(snapshot_path, 'best_model.pth'))
                best_epoch = epoch_num
                best_performance = TotalAvgPerformance
                logging.info('Validation TotalAvgPerformance in best val model: %f' % (TotalAvgPerformance))
                save_model_path = os.path.join(snapshot_path, 'best_model_{}_{}.pth'.format(
                    epoch_num, round(best_performance, 4)))
                os.system('ln -s ' + os.path.abspath(save_model_path) +
                          ' ' + os.path.join(snapshot_path, 'best_model.pth'))
                torch.save(model.state_dict(), save_model_path)
                logging.info("save model to {}".format(save_model_path))

            model.train()

    writer.close()
    return "Training Finished!"
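The classification branch of the validation loop above scores each dataset with `roc_auc_score`. It passes one-hot labels of shape (N, 2) and softmax probabilities of the same shape. To our understanding, scikit-learn treats such a label matrix as a multilabel-indicator target. It then averages the per-column AUCs (the default `average='macro'`), and the `multi_class='ovo'` argument does not come into play.

Column 0 of the softmax output is one minus column 1. It therefore orders the samples in exactly reverse order, and swapping which class counts as positive leaves the AUC unchanged. The two-column average therefore collapses to the ordinary binary AUC of the positive-class probability.

The dependency-free sketch below checks this on a toy example. `binary_auc` and `macro_auc_one_hot` are hypothetical helpers written for illustration only; they are not part of this repository or of scikit-learn. The pairwise (Mann-Whitney) loop is quadratic, so it suits small sanity checks only.

```python
from typing import List, Sequence


def binary_auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """ROC AUC as the probability that a random positive scores higher than a
    random negative, with ties counted as half (the Mann-Whitney statistic)."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC is undefined unless both classes are present")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def macro_auc_one_hot(labels: Sequence[int], probs: Sequence[Sequence[float]]) -> float:
    """Unweighted mean of the one-vs-rest AUC of every probability column,
    assumed here to match scoring (N, C) one-hot labels with average='macro'."""
    num_classes = len(probs[0])
    per_class: List[float] = []
    for c in range(num_classes):
        one_vs_rest = [1 if y == c else 0 for y in labels]
        column = [row[c] for row in probs]
        per_class.append(binary_auc(one_vs_rest, column))
    return sum(per_class) / num_classes


if __name__ == "__main__":
    labels = [0, 0, 1, 1]
    probs = [[0.9, 0.1], [0.3, 0.7], [0.6, 0.4], [0.2, 0.8]]  # softmax rows sum to 1
    print(binary_auc(labels, [row[1] for row in probs]))  # 0.75
    print(macro_auc_one_hot(labels, probs))  # 0.75
```

For two classes, the per-dataset number written to `info/val_cls_metric_*` can therefore be read as a plain binary ROC AUC. The same identity would not hold for more than two classes.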
--------------------------------------------------------------------------------
/pretrained_ckpt/.gitkeeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zehui-Lin/UniUSNet/c61a55a3cd967188dd9eaf0cb4f3d29e924c999e/pretrained_ckpt/.gitkeeep
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
einops
MedPy
numpy
opencv_python
PyYAML
scikit_learn
scipy
timm
torch
torchvision
tqdm
yacs
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from medpy import metric
import torch.nn as nn
import cv2


class DiceLoss(nn.Module):
    def __init__(self, n_classes=2):
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes

    def _one_hot_encoder(self, input_tensor):
        tensor_list = []
        for i in range(self.n_classes):
            temp_prob = input_tensor == i
            tensor_list.append(temp_prob.unsqueeze(1))
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()

    def _dice_loss(self, score, target):
        target = target.float()
        smooth = 1e-5
        intersect = torch.sum(score * target)
        y_sum = torch.sum(target * target)
        z_sum = torch.sum(score * score)
        loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
        loss = 1 - loss
        return loss

    def forward(self, inputs, target, weight=None, softmax=False):
        if softmax:
            inputs = torch.softmax(inputs, dim=1)
        target = self._one_hot_encoder(target)
        if weight is None:
            weight = [1] * self.n_classes
        assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())
        class_wise_dice = []
        loss = 0.0
        for i in range(0, self.n_classes):
            dice = self._dice_loss(inputs[:, i], target[:, i])
            class_wise_dice.append(1.0 - dice.item())
            loss += dice * weight[i]
        return loss / self.n_classes


def calculate_metric_percase(pred, gt):
    pred[pred > 0] = 1
    gt[gt > 0] = 1
    if pred.sum() > 0 and gt.sum() > 0:
        dice = metric.binary.dc(pred, gt)
        return dice, True
    elif pred.sum() > 0 and gt.sum() == 0:
        return 0, False
    elif pred.sum() == 0 and gt.sum() > 0:
        return 0, True
    else:
        return 0, False


def omni_seg_test(image, label, net, classes, ClassStartIndex=1, test_save_path=None, case=None,
                  prompt=False,
                  type_prompt=None,
                  nature_prompt=None,
                  position_prompt=None,
                  task_prompt=None
                  ):
    label = label.squeeze(0).cpu().detach().numpy()
    image_save = image.squeeze(0).cpu().detach().numpy()
    input = image.cuda()
    if prompt:
        position_prompt = position_prompt.cuda()
        task_prompt = task_prompt.cuda()
        type_prompt = type_prompt.cuda()
        nature_prompt = nature_prompt.cuda()
    net.eval()
    with torch.no_grad():
        if prompt:
            seg_out = net((input, position_prompt, task_prompt, type_prompt, nature_prompt))[0]
        else:
            seg_out = net(input)[0]
        out_label_back_transform = torch.cat(
            [seg_out[:, 0:1], seg_out[:, ClassStartIndex:ClassStartIndex+classes-1]], dim=1)
        out = torch.argmax(torch.softmax(out_label_back_transform, dim=1), dim=1).squeeze(0)
        prediction = out.cpu().detach().numpy()

    metric_list = []
    for i in range(1, classes):  # the second dimension has a different meaning here: it indexes the classes
        metric_list.append(calculate_metric_percase(prediction == i, label == i))

    if test_save_path is not None:
        image = (image_save - np.min(image_save)) / (np.max(image_save) - np.min(image_save))
        cv2.imwrite(test_save_path + '/'+case + "_pred.png", (prediction*255).astype(np.uint8))
        cv2.imwrite(test_save_path + '/'+case + "_img.png", ((image.squeeze(0))*255).astype(np.uint8))
        cv2.imwrite(test_save_path + '/'+case + "_gt.png", (label*255).astype(np.uint8))
    return metric_list
--------------------------------------------------------------------------------
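`calculate_metric_percase` returns a `(dice, counted)` pair. The segmentation validation loop in the trainer sums the Dice scores and divides by the number of counted cases (plus `1e-6`). As a result, a missed lesion (empty prediction, non-empty ground truth) scores 0 and is included in the average. Images with empty ground truth are left out, whether or not the model predicted foreground there.

The pure-Python sketch below restates that convention on flat 0/1 masks, so it can be checked without MedPy or PyTorch. `dice`, `metric_percase` and `mean_dice` are illustrative names, not functions in `utils.py`. The hard Dice here is assumed to agree with `medpy.metric.binary.dc` for the non-empty masks it is applied to.

```python
from typing import List, Sequence, Tuple

Mask = Sequence[int]  # flattened binary mask, 1 = foreground


def dice(pred: Mask, gt: Mask) -> float:
    """Hard Dice coefficient 2 * |pred AND gt| / (|pred| + |gt|) for flat binary masks."""
    intersection = sum(1 for p, g in zip(pred, gt) if p and g)
    total = sum(1 for p in pred if p) + sum(1 for g in gt if g)
    return 2.0 * intersection / total if total else 0.0


def metric_percase(pred: Mask, gt: Mask) -> Tuple[float, bool]:
    """Restates the (score, counted) rule of calculate_metric_percase."""
    has_pred, has_gt = any(pred), any(gt)
    if has_pred and has_gt:
        return dice(pred, gt), True
    if has_gt:  # lesion present but nothing predicted: scores 0 and is counted
        return 0.0, True
    return 0.0, False  # empty ground truth: excluded from the average


def mean_dice(cases: Sequence[Tuple[Mask, Mask]]) -> float:
    """Sum of per-case scores over the number of counted cases, with the
    same 1e-6 guard the trainer adds to the denominator."""
    results: List[Tuple[float, bool]] = [metric_percase(p, g) for p, g in cases]
    counted = sum(1 for _, ok in results if ok)
    return sum(score for score, _ in results) / (counted + 1e-6)


if __name__ == "__main__":
    cases = [
        ([1, 1, 0, 0], [1, 0, 0, 0]),  # partial overlap: Dice 2/3
        ([0, 0, 0, 0], [0, 1, 1, 0]),  # missed lesion: 0, counted
        ([1, 0, 0, 0], [0, 0, 0, 0]),  # false positive on an empty image: not counted
        ([0, 0, 1, 1], [0, 0, 1, 1]),  # perfect match: Dice 1
    ]
    print(round(mean_dice(cases), 4))  # 0.5556
```

One consequence of this rule is that spurious foreground on lesion-free images does not lower the reported validation Dice. Only the classification AUC and the training loss see those images.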