├── README.md
├── config.py
├── configs
│   └── swin_tiny_patch4_window7_224_lite.yaml
├── data_demo
│   ├── classification
│   │   └── UDIAT
│   │       ├── 0
│   │       │   └── demo_01.png
│   │       ├── 1
│   │       │   └── demo_02.png
│   │       ├── config.yaml
│   │       └── test.txt
│   └── segmentation
│       └── BUSIS
│           ├── config.yaml
│           ├── imgs
│           │   ├── demo_01.png
│           │   └── demo_02.png
│           ├── masks
│           │   ├── demo_01.png
│           │   └── demo_02.png
│           └── test.txt
├── datasets
│   ├── dataset.py
│   └── omni_dataset.py
├── networks
│   └── omni_vision_transformer.py
├── omni_test.py
├── omni_train.py
├── omni_trainer.py
├── pretrained_ckpt
│   └── .gitkeeep
├── requirements.txt
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # UniUSNet: A Promptable Framework for Universal Ultrasound Disease Prediction and Tissue Segmentation
2 |
3 | UniUSNet is a universal framework for ultrasound image classification and segmentation, featuring:
4 |
5 | - A novel promptable module for incorporating detailed information into the model's learning process.
6 | - Versatility across various ultrasound natures, anatomical positions, and input types, with proficiency in both segmentation and classification tasks.
7 | - Strong generalization capabilities demonstrated through zero-shot and fine-tuning experiments on new datasets.
8 |
9 | For more details, see the accompanying paper and the [Project Page](https://zehui-lin.github.io/UniUSNet/).
10 |
11 | > [**UniUSNet: A Promptable Framework for Universal Ultrasound Disease Prediction and Tissue Segmentation**](https://doi.org/10.1109/BIBM62325.2024.10822429)
12 | Zehui Lin, Zhuoneng Zhang, Xindi Hu, Zhifan Gao, Xin Yang, Yue Sun, Dong Ni, Tao Tan. BIBM, 2024.
13 |
14 | ## Installation
15 | - Clone this repository.
16 | ```
17 | git clone https://github.com/Zehui-Lin/UniUSNet.git
18 | cd UniUSNet
19 | ```
20 | - Create a new conda environment.
21 | ```
22 | conda create -n UniUSNet python=3.10
23 | conda activate UniUSNet
24 | ```
25 | - Install the required packages.
26 | ```
27 | pip install -r requirements.txt
28 | ```
29 |
30 | ## Data
31 |
32 | - BroadUS-9.7K consists of ten publicly-available datasets, including [BUSI](https://www.kaggle.com/datasets/aryashah2k/breast-ultrasound-images-dataset), [BUSIS](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC9025635/), [UDIAT](https://ieeexplore.ieee.org/abstract/document/8003418), [BUS-BRA](https://aapm.onlinelibrary.wiley.com/doi/abs/10.1002/mp.16812), [Fatty-Liver](https://link.springer.com/article/10.1007/s11548-018-1843-2#Sec8), [kidneyUS](http://rsingla.ca/kidneyUS/), [DDTI](https://www.kaggle.com/datasets/dasmehdixtr/ddti-thyroid-ultrasound-images/data), [Fetal HC](https://hc18.grand-challenge.org/), [CAMUS](https://www.creatis.insa-lyon.fr/Challenge/camus/index.html) and [Appendix](https://zenodo.org/records/7669442).
33 | - You can prepare the data by downloading the datasets and organizing them as follows:
34 |
35 | ```
36 | data
37 | ├── classification
38 | │   ├── UDIAT
39 | │   │   ├── 0
40 | │   │   │   ├── 000001.png
41 | │   │   │   ├── ...
42 | │   │   ├── 1
43 | │   │   │   ├── 000100.png
44 | │   │   │   ├── ...
45 | │   │   ├── config.yaml
46 | │   │   ├── test.txt
47 | │   │   ├── train.txt
48 | │   │   └── val.txt
49 | │   └── ...
50 | └── segmentation
51 |     ├── BUSIS
52 |     │   ├── config.yaml
53 |     │   ├── imgs
54 |     │   │   ├── 000001.png
55 |     │   │   ├── ...
56 |     │   ├── masks
57 |     │   │   ├── 000001.png
58 |     │   │   ├── ...
59 |     │   ├── test.txt
60 |     │   ├── train.txt
61 |     │   └── val.txt
62 |     └── ...
63 | ```
64 | - Please refer to the `data_demo` folder for examples.
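
For a new dataset, the `train.txt` / `val.txt` / `test.txt` files simply list file names relative to the dataset folder (compare `data_demo/segmentation/BUSIS/test.txt`). Below is a minimal sketch for generating such split files; the `make_splits` helper and the 70/10/20 ratios are illustrative and not part of this repository.

```
# make_splits.py -- illustrative helper, not part of this repository
import os
import random

def make_splits(dataset_dir, ratios=(0.7, 0.1, 0.2), seed=0):
    """Write train/val/test .txt files listing images relative to dataset_dir."""
    # Segmentation datasets keep images in <dataset_dir>/imgs;
    # for classification datasets, list "<label>/<file>.png" entries instead.
    img_dir = os.path.join(dataset_dir, "imgs")
    names = sorted(os.listdir(img_dir))
    random.Random(seed).shuffle(names)
    n_train = int(len(names) * ratios[0])
    n_val = int(len(names) * ratios[1])
    splits = {
        "train": names[:n_train],
        "val": names[n_train:n_train + n_val],
        "test": names[n_train + n_val:],
    }
    for split, files in splits.items():
        with open(os.path.join(dataset_dir, split + ".txt"), "w") as f:
            f.write("\n".join(files) + "\n")

make_splits("data/segmentation/BUSIS")
```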
65 |
66 | ## Training
67 | We use `torch.distributed` for multi-GPU training (single-GPU training is also supported). To train the model, run the following command:
68 | ```
69 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=1234 omni_train.py --output_dir exp_out/trial_1 --prompt
70 | ```
71 |
72 | ## Testing
73 | To test the model, run the following command:
74 | ```
75 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=1234 omni_test.py --output_dir exp_out/trial_1 --prompt
76 | ```
77 |
78 |
79 | ## Checkpoints
80 | - You can download the pre-trained checkpoints from [BaiduYun](https://pan.baidu.com/s/1uciwM5K4wRiMWnrAsB4qMQ?pwd=x390).
81 |
82 | ## Pretrained Weights
83 |
84 | To train your own model, please download the Swin Transformer backbone weights and place the file in the `pretrained_ckpt/` directory:
85 |
86 | * [swin\_tiny\_patch4\_window7\_224.pth](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth)
87 |
88 | The folder structure should look like:
89 |
90 | ```
91 | pretrained_ckpt
92 | └── swin_tiny_patch4_window7_224.pth
93 | ```
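
As a quick sanity check, the downloaded file can be opened with PyTorch. A minimal sketch (the key layout of the official release is noted as an assumption):

```
import torch

ckpt = torch.load("pretrained_ckpt/swin_tiny_patch4_window7_224.pth", map_location="cpu")
print(list(ckpt.keys()))              # the official Swin release typically stores weights under "model"
print(len(ckpt.get("model", ckpt)))   # number of entries in the backbone state dict
```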
94 |
95 | ## Citation
96 | If you find this work useful, please consider citing:
97 |
98 | ```
99 | @inproceedings{lin2024uniusnet,
100 | title={UniUSNet: A Promptable Framework for Universal Ultrasound Disease Prediction and Tissue Segmentation},
101 | author={Lin, Zehui and Zhang, Zhuoneng and Hu, Xindi and Gao, Zhifan and Yang, Xin and Sun, Yue and Ni, Dong and Tan, Tao},
102 | booktitle={2024 IEEE International Conference on Bioinformatics and Biomedicine (BIBM)},
103 | pages={3501--3504},
104 | year={2024},
105 | organization={IEEE}
106 | }
107 | ```
108 |
109 | ## Acknowledgements
110 | This repository is based on the [Swin-Unet](https://github.com/HuCaoFighting/Swin-Unet) repository. We thank the authors for their contributions.
111 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import yaml
10 | from yacs.config import CfgNode as CN
11 |
12 | _C = CN()
13 |
14 | # Base config files
15 | _C.BASE = ['']
16 |
17 | # -----------------------------------------------------------------------------
18 | # Data settings
19 | # -----------------------------------------------------------------------------
20 | _C.DATA = CN()
21 | # Batch size for a single GPU, could be overwritten by command line argument
22 | _C.DATA.BATCH_SIZE = 128
23 | # Path to dataset, could be overwritten by command line argument
24 | _C.DATA.DATA_PATH = ''
25 | # Dataset name
26 | _C.DATA.DATASET = 'imagenet'
27 | # Input image size
28 | _C.DATA.IMG_SIZE = 224
29 | # Interpolation to resize image (random, bilinear, bicubic)
30 | _C.DATA.INTERPOLATION = 'bicubic'
31 | # Use zipped dataset instead of folder dataset
32 | # could be overwritten by command line argument
33 | _C.DATA.ZIP_MODE = False
34 | # Cache Data in Memory, could be overwritten by command line argument
35 | _C.DATA.CACHE_MODE = 'part'
36 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
37 | _C.DATA.PIN_MEMORY = True
38 | # Number of data loading threads
39 | _C.DATA.NUM_WORKERS = 8
40 |
41 | # -----------------------------------------------------------------------------
42 | # Model settings
43 | # -----------------------------------------------------------------------------
44 | _C.MODEL = CN()
45 | # Model type
46 | _C.MODEL.TYPE = 'swin'
47 | # Model name
48 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
49 | # Checkpoint to resume, could be overwritten by command line argument
50 | _C.MODEL.PRETRAIN_CKPT = './pretrained_ckpt/swin_tiny_patch4_window7_224.pth'
51 | _C.MODEL.RESUME = ''
52 | # Number of classes, overwritten in data preparation
53 | _C.MODEL.NUM_CLASSES = 1000
54 | # Dropout rate
55 | _C.MODEL.DROP_RATE = 0.0
56 | # Drop path rate
57 | _C.MODEL.DROP_PATH_RATE = 0.1
58 | # Label Smoothing
59 | _C.MODEL.LABEL_SMOOTHING = 0.1
60 |
61 | # Swin Transformer parameters
62 | _C.MODEL.SWIN = CN()
63 | _C.MODEL.SWIN.PATCH_SIZE = 4
64 | _C.MODEL.SWIN.IN_CHANS = 3
65 | _C.MODEL.SWIN.EMBED_DIM = 96
66 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
67 | _C.MODEL.SWIN.ENCODER_DEPTHS = [2, 2, 6, 2]
68 | _C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2]
69 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
70 | _C.MODEL.SWIN.WINDOW_SIZE = 7
71 | _C.MODEL.SWIN.MLP_RATIO = 4.
72 | _C.MODEL.SWIN.QKV_BIAS = True
73 | _C.MODEL.SWIN.QK_SCALE = None
74 | _C.MODEL.SWIN.APE = False
75 | _C.MODEL.SWIN.PATCH_NORM = True
76 | _C.MODEL.SWIN.FINAL_UPSAMPLE = "expand_first"
77 |
78 | # -----------------------------------------------------------------------------
79 | # Training settings
80 | # -----------------------------------------------------------------------------
81 | _C.TRAIN = CN()
82 | _C.TRAIN.START_EPOCH = 0
83 | _C.TRAIN.EPOCHS = 300
84 | _C.TRAIN.WARMUP_EPOCHS = 20
85 | _C.TRAIN.WEIGHT_DECAY = 0.05
86 | _C.TRAIN.BASE_LR = 5e-4
87 | _C.TRAIN.WARMUP_LR = 5e-7
88 | _C.TRAIN.MIN_LR = 5e-6
89 | # Clip gradient norm
90 | _C.TRAIN.CLIP_GRAD = 5.0
91 | # Auto resume from latest checkpoint
92 | _C.TRAIN.AUTO_RESUME = True
93 | # Gradient accumulation steps
94 | # could be overwritten by command line argument
95 | _C.TRAIN.ACCUMULATION_STEPS = 0
96 | # Whether to use gradient checkpointing to save memory
97 | # could be overwritten by command line argument
98 | _C.TRAIN.USE_CHECKPOINT = False
99 |
100 | # LR scheduler
101 | _C.TRAIN.LR_SCHEDULER = CN()
102 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
103 | # Epoch interval to decay LR, used in StepLRScheduler
104 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
105 | # LR decay rate, used in StepLRScheduler
106 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
107 |
108 | # Optimizer
109 | _C.TRAIN.OPTIMIZER = CN()
110 | _C.TRAIN.OPTIMIZER.NAME = 'adamw'
111 | # Optimizer Epsilon
112 | _C.TRAIN.OPTIMIZER.EPS = 1e-8
113 | # Optimizer Betas
114 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
115 | # SGD momentum
116 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
117 |
118 | # -----------------------------------------------------------------------------
119 | # Augmentation settings
120 | # -----------------------------------------------------------------------------
121 | _C.AUG = CN()
122 | # Color jitter factor
123 | _C.AUG.COLOR_JITTER = 0.4
124 | # Use AutoAugment policy. "v0" or "original"
125 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
126 | # Random erase prob
127 | _C.AUG.REPROB = 0.25
128 | # Random erase mode
129 | _C.AUG.REMODE = 'pixel'
130 | # Random erase count
131 | _C.AUG.RECOUNT = 1
132 | # Mixup alpha, mixup enabled if > 0
133 | _C.AUG.MIXUP = 0.8
134 | # Cutmix alpha, cutmix enabled if > 0
135 | _C.AUG.CUTMIX = 1.0
136 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set
137 | _C.AUG.CUTMIX_MINMAX = None
138 | # Probability of performing mixup or cutmix when either/both is enabled
139 | _C.AUG.MIXUP_PROB = 1.0
140 | # Probability of switching to cutmix when both mixup and cutmix enabled
141 | _C.AUG.MIXUP_SWITCH_PROB = 0.5
142 | # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
143 | _C.AUG.MIXUP_MODE = 'batch'
144 |
145 | # -----------------------------------------------------------------------------
146 | # Testing settings
147 | # -----------------------------------------------------------------------------
148 | _C.TEST = CN()
149 | # Whether to use center crop when testing
150 | _C.TEST.CROP = True
151 |
152 | # -----------------------------------------------------------------------------
153 | # Misc
154 | # -----------------------------------------------------------------------------
155 | # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
156 | # overwritten by command line argument
157 | _C.AMP_OPT_LEVEL = ''
158 | # Path to output folder, overwritten by command line argument
159 | _C.OUTPUT = ''
160 | # Tag of experiment, overwritten by command line argument
161 | _C.TAG = 'default'
162 | # Frequency to save checkpoint
163 | _C.SAVE_FREQ = 1
164 | # Frequency to logging info
165 | _C.PRINT_FREQ = 10
166 | # Fixed random seed
167 | _C.SEED = 0
168 | # Perform evaluation only, overwritten by command line argument
169 | _C.EVAL_MODE = False
170 | # Test throughput only, overwritten by command line argument
171 | _C.THROUGHPUT_MODE = False
172 | # local rank for DistributedDataParallel, given by command line argument
173 | _C.LOCAL_RANK = 0
174 |
175 |
176 | def _update_config_from_file(config, cfg_file):
177 | config.defrost()
178 | with open(cfg_file, 'r') as f:
179 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
180 |
181 | for cfg in yaml_cfg.setdefault('BASE', ['']):
182 | if cfg:
183 | _update_config_from_file(
184 | config, os.path.join(os.path.dirname(cfg_file), cfg)
185 | )
186 | print('=> merge config from {}'.format(cfg_file))
187 | config.merge_from_file(cfg_file)
188 | config.freeze()
189 |
190 |
191 | def update_config(config, args):
192 | _update_config_from_file(config, args.cfg)
193 |
194 | config.defrost()
195 | if args.opts:
196 | config.merge_from_list(args.opts)
197 |
198 | # merge from specific arguments
199 | if args.batch_size:
200 | config.DATA.BATCH_SIZE = args.batch_size
201 | if args.zip:
202 | config.DATA.ZIP_MODE = True
203 | if args.cache_mode:
204 | config.DATA.CACHE_MODE = args.cache_mode
205 | if args.resume:
206 | config.MODEL.RESUME = args.resume
207 | if args.accumulation_steps:
208 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
209 | if args.use_checkpoint:
210 | config.TRAIN.USE_CHECKPOINT = True
211 | if args.amp_opt_level:
212 | config.AMP_OPT_LEVEL = args.amp_opt_level
213 | if args.tag:
214 | config.TAG = args.tag
215 | if args.eval:
216 | config.EVAL_MODE = True
217 | if args.throughput:
218 | config.THROUGHPUT_MODE = True
219 |
220 | config.freeze()
221 |
222 |
223 | def get_config(args):
224 | """Get a yacs CfgNode object with default values."""
225 | # Return a clone so that the defaults will not be altered
226 | # This is for the "local variable" use pattern
227 | config = _C.clone()
228 | update_config(config, args)
229 |
230 | return config
231 |
--------------------------------------------------------------------------------
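
A minimal sketch of how `get_config` above is typically driven. The real argument parser lives in `omni_train.py` / `omni_test.py` (not shown in full here), so the flags below are assumptions derived from the attributes that `update_config` reads:

```
import argparse
from config import get_config

# Illustrative parser: it defines exactly the attributes that update_config() touches.
parser = argparse.ArgumentParser()
parser.add_argument("--cfg", default="configs/swin_tiny_patch4_window7_224_lite.yaml")
parser.add_argument("--opts", nargs="+", default=None)        # extra "KEY VALUE" overrides for yacs
parser.add_argument("--batch_size", type=int, default=None)
parser.add_argument("--zip", action="store_true")
parser.add_argument("--cache_mode", default=None)
parser.add_argument("--resume", default=None)
parser.add_argument("--accumulation_steps", type=int, default=None)
parser.add_argument("--use_checkpoint", action="store_true")
parser.add_argument("--amp_opt_level", default=None)
parser.add_argument("--tag", default=None)
parser.add_argument("--eval", action="store_true")
parser.add_argument("--throughput", action="store_true")
args = parser.parse_args([])       # empty list -> all defaults, for illustration only

config = get_config(args)          # defaults from _C, merged with the YAML, then CLI overrides
print(config.MODEL.SWIN.ENCODER_DEPTHS)   # [4, 4, 4, 4], from swin_tiny_patch4_window7_224_lite.yaml (next file)
print(config.MODEL.PRETRAIN_CKPT)         # ./pretrained_ckpt/swin_tiny_patch4_window7_224.pth
```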
/configs/swin_tiny_patch4_window7_224_lite.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: swin
3 | NAME: swin_tiny_patch4_window7_224
4 | DROP_PATH_RATE: 0.2
5 | PRETRAIN_CKPT: "./pretrained_ckpt/swin_tiny_patch4_window7_224.pth"
6 | SWIN:
7 | FINAL_UPSAMPLE: "expand_first"
8 | EMBED_DIM: 96
9 | ENCODER_DEPTHS: [ 4, 4, 4, 4]
10 | DECODER_DEPTHS: [ 2, 2, 2, 2]
11 | NUM_HEADS: [ 3, 6, 12, 24 ]
12 | WINDOW_SIZE: 7
--------------------------------------------------------------------------------
/data_demo/classification/UDIAT/0/demo_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zehui-Lin/UniUSNet/c61a55a3cd967188dd9eaf0cb4f3d29e924c999e/data_demo/classification/UDIAT/0/demo_01.png
--------------------------------------------------------------------------------
/data_demo/classification/UDIAT/1/demo_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zehui-Lin/UniUSNet/c61a55a3cd967188dd9eaf0cb4f3d29e924c999e/data_demo/classification/UDIAT/1/demo_02.png
--------------------------------------------------------------------------------
/data_demo/classification/UDIAT/config.yaml:
--------------------------------------------------------------------------------
1 | 0:Benign
2 | 1:Malignant
3 |
--------------------------------------------------------------------------------
/data_demo/classification/UDIAT/test.txt:
--------------------------------------------------------------------------------
1 | 0/demo_01.png
2 | 1/demo_02.png
--------------------------------------------------------------------------------
/data_demo/segmentation/BUSIS/config.yaml:
--------------------------------------------------------------------------------
1 | 0:background:0
2 | 1:nodule:255
3 |
--------------------------------------------------------------------------------
/data_demo/segmentation/BUSIS/imgs/demo_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zehui-Lin/UniUSNet/c61a55a3cd967188dd9eaf0cb4f3d29e924c999e/data_demo/segmentation/BUSIS/imgs/demo_01.png
--------------------------------------------------------------------------------
/data_demo/segmentation/BUSIS/imgs/demo_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zehui-Lin/UniUSNet/c61a55a3cd967188dd9eaf0cb4f3d29e924c999e/data_demo/segmentation/BUSIS/imgs/demo_02.png
--------------------------------------------------------------------------------
/data_demo/segmentation/BUSIS/masks/demo_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zehui-Lin/UniUSNet/c61a55a3cd967188dd9eaf0cb4f3d29e924c999e/data_demo/segmentation/BUSIS/masks/demo_01.png
--------------------------------------------------------------------------------
/data_demo/segmentation/BUSIS/masks/demo_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zehui-Lin/UniUSNet/c61a55a3cd967188dd9eaf0cb4f3d29e924c999e/data_demo/segmentation/BUSIS/masks/demo_02.png
--------------------------------------------------------------------------------
/data_demo/segmentation/BUSIS/test.txt:
--------------------------------------------------------------------------------
1 | demo_01.png
2 | demo_02.png
--------------------------------------------------------------------------------
/datasets/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import random
4 | import numpy as np
5 | import torch
6 | from scipy import ndimage
7 | from scipy.ndimage.interpolation import zoom
8 | from torch.utils.data import Dataset
9 |
10 | from datasets.omni_dataset import position_prompt_dict
11 | from datasets.omni_dataset import nature_prompt_dict
12 |
13 | from datasets.omni_dataset import position_prompt_one_hot_dict
14 | from datasets.omni_dataset import nature_prompt_one_hot_dict
15 | from datasets.omni_dataset import type_prompt_one_hot_dict
16 |
17 |
18 | def random_horizontal_flip(image, label):
19 | axis = 1
20 | image = np.flip(image, axis=axis).copy()
21 | label = np.flip(label, axis=axis).copy()
22 | return image, label
23 |
24 |
25 | def random_rotate(image, label):
26 | angle = np.random.randint(-20, 20)
27 | image = ndimage.rotate(image, angle, order=0, reshape=False)
28 | label = ndimage.rotate(label, angle, order=0, reshape=False)
29 | return image, label
30 |
31 |
32 | class RandomGenerator(object):
33 | def __init__(self, output_size):
34 | self.output_size = output_size
35 |
36 | def __call__(self, sample):
37 | image, label = sample['image'], sample['label']
38 | if 'type_prompt' in sample:
39 | type_prompt = sample['type_prompt']
40 |
41 | if random.random() > 0.5:
42 | image, label = random_horizontal_flip(image, label)
43 | elif random.random() > 0.5:
44 | image, label = random_rotate(image, label)
45 | x, y, _ = image.shape
46 |
47 | if x > y:
48 | image = zoom(image, (self.output_size[0] / y, self.output_size[1] / y, 1), order=1)
49 | label = zoom(label, (self.output_size[0] / y, self.output_size[1] / y), order=0)
50 | else:
51 | image = zoom(image, (self.output_size[0] / x, self.output_size[1] / x, 1), order=1)
52 | label = zoom(label, (self.output_size[0] / x, self.output_size[1] / x), order=0)
53 |
54 | scale = random.uniform(0.8, 1.2)
55 | image = zoom(image, (scale, scale, 1), order=1)
56 | label = zoom(label, (scale, scale), order=0)
57 |
58 | x, y, _ = image.shape
59 | if scale > 1:
60 | startx = x//2 - (self.output_size[0]//2)
61 | starty = y//2 - (self.output_size[1]//2)
62 | image = image[startx:startx+self.output_size[0], starty:starty+self.output_size[1], :]
63 | label = label[startx:startx+self.output_size[0], starty:starty+self.output_size[1]]
64 | else:
65 | if x > self.output_size[0]:
66 | startx = x//2 - (self.output_size[0]//2)
67 | image = image[startx:startx+self.output_size[0], :, :]
68 | label = label[startx:startx+self.output_size[0], :]
69 | if y > self.output_size[1]:
70 | starty = y//2 - (self.output_size[1]//2)
71 | image = image[:, starty:starty+self.output_size[1], :]
72 | label = label[:, starty:starty+self.output_size[1]]
73 | x, y, _ = image.shape
74 | new_image = np.zeros((self.output_size[0], self.output_size[1], 3))
75 | new_label = np.zeros((self.output_size[0], self.output_size[1]))
76 | if x < y:
77 | startx = self.output_size[0]//2 - (x//2)
78 | starty = 0
79 | new_image[startx:startx+x, starty:starty+y, :] = image
80 | new_label[startx:startx+x, starty:starty+y] = label
81 | else:
82 | startx = 0
83 | starty = self.output_size[1]//2 - (y//2)
84 | new_image[startx:startx+x, starty:starty+y, :] = image
85 | new_label[startx:startx+x, starty:starty+y] = label
86 | image = new_image
87 | label = new_label
88 |
89 | image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)
90 | label = torch.from_numpy(label.astype(np.float32))
91 | if 'type_prompt' in sample:
92 | sample = {'image': image, 'label': label.long(), 'type_prompt': type_prompt}
93 | else:
94 | sample = {'image': image, 'label': label.long()}
95 | return sample
96 |
97 |
98 | class CenterCropGenerator(object):
99 | def __init__(self, output_size):
100 | self.output_size = output_size
101 |
102 | def __call__(self, sample):
103 | image, label = sample['image'], sample['label']
104 | if 'type_prompt' in sample:
105 | type_prompt = sample['type_prompt']
106 | x, y, _ = image.shape
107 | if x > y:
108 | image = zoom(image, (self.output_size[0] / y, self.output_size[1] / y, 1), order=1)
109 | label = zoom(label, (self.output_size[0] / y, self.output_size[1] / y), order=0)
110 | else:
111 | image = zoom(image, (self.output_size[0] / x, self.output_size[1] / x, 1), order=1)
112 | label = zoom(label, (self.output_size[0] / x, self.output_size[1] / x), order=0)
113 | x, y, _ = image.shape
114 | startx = x//2 - (self.output_size[0]//2)
115 | starty = y//2 - (self.output_size[1]//2)
116 | image = image[startx:startx+self.output_size[0], starty:starty+self.output_size[1], :]
117 | label = label[startx:startx+self.output_size[0], starty:starty+self.output_size[1]]
118 |
119 | image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)
120 | label = torch.from_numpy(label.astype(np.float32))
121 | if 'type_prompt' in sample:
122 | sample = {'image': image, 'label': label.long(), 'type_prompt': type_prompt}
123 | else:
124 | sample = {'image': image, 'label': label.long()}
125 | return sample
126 |
127 |
128 | class USdatasetSeg(Dataset):
129 | def __init__(self, base_dir, list_dir, split, transform=None, prompt=False):
130 | self.transform = transform
131 | self.split = split
132 | self.sample_list = open(os.path.join(list_dir, self.split+'.txt')).readlines()
133 |
134 | # BUSI
135 | self.sample_list = [sample for sample in self.sample_list if not "normal" in sample]
136 |
137 | self.data_dir = base_dir
138 | self.label_info = open(os.path.join(list_dir, "config.yaml")).readlines()
139 | self.prompt = prompt
140 |
141 | def __len__(self):
142 | return len(self.sample_list)
143 |
144 | def __getitem__(self, idx):
145 |
146 | img_name = self.sample_list[idx].strip('\n')
147 | img_path = os.path.join(self.data_dir, "imgs", img_name)
148 | label_path = os.path.join(self.data_dir, "masks", img_name)
149 |
150 | image = cv2.imread(img_path)
151 | label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
152 |
153 | label_info_list = [info.strip().split(":") for info in self.label_info]
154 | for single_label_info in label_info_list:
155 | label_index = int(single_label_info[0])
156 | label_value_in_image = int(single_label_info[2])
157 | label[label == label_value_in_image] = label_index
158 |
159 | label[label > 0] = 1
160 |
161 | sample = {'image': image/255.0, 'label': label}
162 | if self.transform:
163 | sample = self.transform(sample)
164 | sample['case_name'] = self.sample_list[idx].strip('\n')
165 | if self.prompt:
166 | dataset_name = img_path.split("/")[-3]
167 | sample['type_prompt'] = type_prompt_one_hot_dict["whole"]
168 | sample['nature_prompt'] = nature_prompt_one_hot_dict[nature_prompt_dict[dataset_name]]
169 | sample['position_prompt'] = position_prompt_one_hot_dict[position_prompt_dict[dataset_name]]
170 | return sample
171 |
172 |
173 | class USdatasetCls(Dataset):
174 | def __init__(self, base_dir, list_dir, split, transform=None, prompt=False):
175 | self.transform = transform # using transform in torch!
176 | self.split = split
177 | self.sample_list = open(os.path.join(list_dir, self.split+'.txt')).readlines()
178 |
179 | # BUSI
180 | self.sample_list = [sample for sample in self.sample_list if not "normal" in sample]
181 |
182 | self.data_dir = base_dir
183 | self.label_info = open(os.path.join(list_dir, "config.yaml")).readlines()
184 | self.prompt = prompt
185 |
186 | def __len__(self):
187 | return len(self.sample_list)
188 |
189 | def __getitem__(self, idx):
190 |
191 | img_name = self.sample_list[idx].strip('\n')
192 | img_path = os.path.join(self.data_dir, img_name)
193 |
194 | image = cv2.imread(img_path)
195 | label = int(img_name.split("/")[0])
196 |
197 | sample = {'image': image/255.0, 'label': np.zeros(image.shape[:2])}
198 | if self.transform:
199 | sample = self.transform(sample)
200 | sample['label'] = torch.from_numpy(np.array(label))
201 | sample['case_name'] = self.sample_list[idx].strip('\n')
202 | if self.prompt:
203 | dataset_name = img_path.split("/")[-3]
204 | sample['type_prompt'] = type_prompt_one_hot_dict["whole"]
205 | sample['nature_prompt'] = nature_prompt_one_hot_dict[nature_prompt_dict[dataset_name]]
206 | sample['position_prompt'] = position_prompt_one_hot_dict[position_prompt_dict[dataset_name]]
207 | return sample
208 |
--------------------------------------------------------------------------------
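
A usage sketch pairing `USdatasetSeg` with `CenterCropGenerator`, assuming the `data_demo` layout shipped with the repository (illustrative only; the repository's own entry points are `omni_train.py` and `omni_test.py`):

```
from torch.utils.data import DataLoader
from datasets.dataset import USdatasetSeg, CenterCropGenerator

# data_demo/segmentation/BUSIS contains imgs/, masks/, config.yaml and test.txt,
# so it serves as both base_dir (images and masks) and list_dir (split files).
db_test = USdatasetSeg(
    base_dir="data_demo/segmentation/BUSIS",
    list_dir="data_demo/segmentation/BUSIS",
    split="test",
    transform=CenterCropGenerator(output_size=[224, 224]),
    prompt=False,
)
loader = DataLoader(db_test, batch_size=1, shuffle=False)
for batch in loader:
    print(batch["case_name"], batch["image"].shape, batch["label"].shape)
```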
/datasets/omni_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import random
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import Dataset
7 | from torch.utils.data.distributed import DistributedSampler
8 | from torch import Tensor
9 | from typing import Sequence
10 |
11 | # prompt info dict
12 | # task prompt
13 | task_prompt_list = [
14 | "segmentation",
15 | "classification",
16 | ]
17 | # position prompt
18 | position_prompt_dict = {
19 | "BUS-BRA": "breast",
20 | "BUSIS": "breast",
21 | "CAMUS": "cardiac",
22 | "DDTI": "thyroid",
23 | "Fetal_HC": "head",
24 | "kidneyUS": "kidney",
25 | "UDIAT": "breast",
26 | "Appendix": "appendix",
27 | "Fatty-Liver": "liver",
28 | "BUSI": "breast",
29 | }
30 | # nature prompt
31 | nature_prompt_dict = {
32 | "BUS-BRA": "tumor",
33 | "BUSIS": "tumor",
34 | "CAMUS": "organ",
35 | "DDTI": "tumor",
36 | "Fetal_HC": "organ",
37 | "kidneyUS": "organ",
38 | "UDIAT": "organ",
39 | "Appendix": "organ",
40 | "Fatty-Liver": "organ",
41 | "BUSI": "tumor",
42 | }
43 | # type prompt
44 | available_type_prompt_list = [
45 | "BUS-BRA",
46 | "BUSIS",
47 | "CAMUS",
48 | "DDTI",
49 | "Fetal_HC",
50 | "kidneyUS",
51 | "UDIAT",
52 | "BUSI"
53 | ]
54 |
55 | # prompt one-hot
56 | # organ prompt
57 | position_prompt_one_hot_dict = {
58 | "breast": [1, 0, 0, 0, 0, 0, 0, 0],
59 | "cardiac": [0, 1, 0, 0, 0, 0, 0, 0],
60 | "thyroid": [0, 0, 1, 0, 0, 0, 0, 0],
61 | "head": [0, 0, 0, 1, 0, 0, 0, 0],
62 | "kidney": [0, 0, 0, 0, 1, 0, 0, 0],
63 | "appendix": [0, 0, 0, 0, 0, 1, 0, 0],
64 | "liver": [0, 0, 0, 0, 0, 0, 1, 0],
65 | "indis": [0, 0, 0, 0, 0, 0, 0, 1]
66 | }
67 | # task prompt
68 | task_prompt_one_hot_dict = {
69 | "segmentation": [1, 0],
70 | "classification": [0, 1]
71 | }
72 | # nature prompt
73 | nature_prompt_one_hot_dict = {
74 | "tumor": [1, 0],
75 | "organ": [0, 1],
76 | }
77 | # type prompt
78 | type_prompt_one_hot_dict = {
79 | "whole": [1, 0, 0],
80 | "local": [0, 1, 0],
81 | "location": [0, 0, 1],
82 | }
83 |
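# Editor's note (illustrative, not part of the original file): each omni-dataset sample carries
# one vector from each of the four dictionaries above. Their lengths are 8 (position) + 2 (task)
# + 2 (nature) + 3 (type) = 15, which matches the nn.Linear(8+2+2+3, ...) prompt projections in
# networks/omni_vision_transformer.py; the exact concatenation order used at train time is
# defined in omni_trainer.py (not shown here).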
84 |
85 | def list_add_prefix(txt_path, prefix_1, prefix_2):
86 |
87 | with open(txt_path, 'r') as f:
88 | lines = f.readlines()
89 | if prefix_2 is not None:
90 | return [os.path.join(prefix_1, prefix_2, line.strip('\n')) for line in lines]
91 | else:
92 | return [os.path.join(prefix_1, line.strip('\n')) for line in lines]
93 |
94 |
95 | class WeightedRandomSamplerDDP(DistributedSampler):
96 | r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).
97 |
98 | Args:
99 | data_set: Dataset used for sampling.
100 | weights (sequence) : a sequence of weights, not necessarily summing up to one
101 | num_replicas (int, optional): Number of processes participating in
102 | distributed training. By default, :attr:`world_size` is retrieved from the
103 | current distributed group.
104 | rank (int, optional): Rank of the current process within :attr:`num_replicas`.
105 | By default, :attr:`rank` is retrieved from the current distributed
106 | group.
107 | num_samples (int): number of samples to draw
108 | replacement (bool): if ``True``, samples are drawn with replacement.
109 | If not, they are drawn without replacement, which means that when a
110 | sample index is drawn for a row, it cannot be drawn again for that row.
111 | generator (Generator): Generator used in sampling.
112 | """
113 | weights: Tensor
114 | num_samples: int
115 | replacement: bool
116 |
117 | def __init__(self, data_set, weights: Sequence[float], num_replicas: int, rank: int, num_samples: int,
118 | replacement: bool = True, generator=None) -> None:
119 | super(WeightedRandomSamplerDDP, self).__init__(data_set, num_replicas, rank)
120 | if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \
121 | num_samples <= 0:
122 | raise ValueError("num_samples should be a positive integer "
123 | "value, but got num_samples={}".format(num_samples))
124 | if not isinstance(replacement, bool):
125 | raise ValueError("replacement should be a boolean value, but got "
126 | "replacement={}".format(replacement))
127 | self.weights = torch.as_tensor(weights, dtype=torch.double)
128 | self.num_samples = num_samples
129 | self.replacement = replacement
130 | self.generator = generator
131 | self.num_replicas = num_replicas
132 | self.rank = rank
133 | self.weights = self.weights[self.rank::self.num_replicas]
134 | self.num_samples = self.num_samples // self.num_replicas
135 |
136 | def __iter__(self):
137 | rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
138 | rand_tensor = self.rank + rand_tensor * self.num_replicas
139 | return iter(rand_tensor.tolist())
140 |
141 | def __len__(self):
142 | return self.num_samples
143 |
144 |
145 | class USdatasetOmni_seg(Dataset):
146 | def __init__(self, base_dir, split, transform=None, prompt=False):
147 | self.transform = transform
148 | self.split = split
149 | self.data_dir = base_dir
150 | self.sample_list = []
151 | self.subset_len = []
152 | self.prompt = prompt
153 |
154 | self.sample_list.extend(list_add_prefix(os.path.join(
155 | base_dir, "segmentation", "BUS-BRA", split + ".txt"), "BUS-BRA", "imgs"))
156 | self.sample_list.extend(list_add_prefix(os.path.join(
157 | base_dir, "segmentation", "BUSIS", split + ".txt"), "BUSIS", "imgs"))
158 | self.sample_list.extend(list_add_prefix(os.path.join(
159 | base_dir, "segmentation", "CAMUS", split + ".txt"), "CAMUS", "imgs"))
160 | self.sample_list.extend(list_add_prefix(os.path.join(
161 | base_dir, "segmentation", "DDTI", split + ".txt"), "DDTI", "imgs"))
162 | self.sample_list.extend(list_add_prefix(os.path.join(base_dir, "segmentation",
163 | "Fetal_HC", split + ".txt"), "Fetal_HC", "imgs"))
164 | self.sample_list.extend(list_add_prefix(os.path.join(base_dir, "segmentation",
165 | "kidneyUS", split + ".txt"), "kidneyUS", "imgs"))
166 | self.sample_list.extend(list_add_prefix(os.path.join(
167 | base_dir, "segmentation", "UDIAT", split + ".txt"), "UDIAT", "imgs"))
168 |
169 | self.subset_len.append(len(list_add_prefix(os.path.join(
170 | base_dir, "segmentation", "BUS-BRA", split + ".txt"), "BUS-BRA", "imgs")))
171 | self.subset_len.append(len(list_add_prefix(os.path.join(
172 | base_dir, "segmentation", "BUSIS", split + ".txt"), "BUSIS", "imgs")))
173 | self.subset_len.append(len(list_add_prefix(os.path.join(
174 | base_dir, "segmentation", "CAMUS", split + ".txt"), "CAMUS", "imgs")))
175 | self.subset_len.append(len(list_add_prefix(os.path.join(
176 | base_dir, "segmentation", "DDTI", split + ".txt"), "DDTI", "imgs")))
177 | self.subset_len.append(len(list_add_prefix(os.path.join(
178 | base_dir, "segmentation", "Fetal_HC", split + ".txt"), "Fetal_HC", "imgs")))
179 | self.subset_len.append(len(list_add_prefix(os.path.join(
180 | base_dir, "segmentation", "kidneyUS", split + ".txt"), "kidneyUS", "imgs")))
181 | self.subset_len.append(len(list_add_prefix(os.path.join(
182 | base_dir, "segmentation", "UDIAT", split + ".txt"), "UDIAT", "imgs")))
183 |
184 | def __len__(self):
185 | return len(self.sample_list)
186 |
187 | def __getitem__(self, idx):
188 |
189 | img_name = self.sample_list[idx].strip('\n')
190 | img_path = os.path.join(self.data_dir, "segmentation", img_name)
191 | label_path = os.path.join(self.data_dir, "segmentation", img_name).replace("imgs", "masks")
192 |
193 | image = cv2.imread(img_path)
194 | label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
195 |
196 | dataset_name = img_name.split("/")[0]
197 | label_info = open(os.path.join(self.data_dir, "segmentation", dataset_name, "config.yaml")).readlines()
198 |
199 | label_info_list = [info.strip().split(":") for info in label_info]
200 | for single_label_info in label_info_list:
201 | label_index = int(single_label_info[0])
202 | label_value_in_image = int(single_label_info[2])
203 | label[label == label_value_in_image] = label_index
204 |
205 | label[label > 0] = 1
206 |
207 | if not self.prompt:
208 | sample = {'image': image/255.0, 'label': label}
209 | else:
210 | if random.random() > 0.5:
211 | x, y, w, h = cv2.boundingRect(label)
212 | length = max(w, h)
213 |
214 | if 0 in image[y:y+length, x:x+length, :].shape:
215 | image = image
216 | label = label
217 | sample = {'image': image/255.0, 'label': label}
218 | sample['type_prompt'] = type_prompt_one_hot_dict["whole"]
219 | else:
220 | image = image[y:y+length, x:x+length, :]
221 | label = label[y:y+length, x:x+length]
222 | sample = {'image': image/255.0, 'label': label}
223 | sample['type_prompt'] = type_prompt_one_hot_dict["local"]
224 |
225 | else:
226 | sample = {'image': image/255.0, 'label': label}
227 | sample['type_prompt'] = type_prompt_one_hot_dict["whole"]
228 | pass
229 | if self.transform:
230 | sample = self.transform(sample)
231 | sample['case_name'] = self.sample_list[idx].strip('\n')
232 | sample['nature_prompt'] = nature_prompt_one_hot_dict[nature_prompt_dict[dataset_name]]
233 | sample['position_prompt'] = position_prompt_one_hot_dict[position_prompt_dict[dataset_name]]
234 | sample['task_prompt'] = task_prompt_one_hot_dict["segmentation"]
235 |
236 | return sample
237 |
238 |
239 | class USdatasetOmni_cls(Dataset):
240 | def __init__(self, base_dir, split, transform=None, prompt=False):
241 | self.transform = transform
242 | self.split = split
243 | self.data_dir = base_dir
244 | self.sample_list = []
245 | self.subset_len = []
246 | self.prompt = prompt
247 |
248 | self.sample_list.extend(list_add_prefix(os.path.join(
249 | base_dir, "classification", "Appendix", split + ".txt"), "Appendix", None))
250 | self.sample_list.extend(list_add_prefix(os.path.join(
251 | base_dir, "classification", "BUS-BRA", split + ".txt"), "BUS-BRA", None))
252 | self.sample_list.extend(list_add_prefix(os.path.join(base_dir, "classification",
253 | "Fatty-Liver", split + ".txt"), "Fatty-Liver", None))
254 | self.sample_list.extend(list_add_prefix(os.path.join(
255 | base_dir, "classification", "UDIAT", split + ".txt"), "UDIAT", None))
256 |
257 | self.subset_len.append(len(list_add_prefix(os.path.join(
258 | base_dir, "classification", "Appendix", split + ".txt"), "Appendix", None)))
259 | self.subset_len.append(len(list_add_prefix(os.path.join(
260 | base_dir, "classification", "BUS-BRA", split + ".txt"), "BUS-BRA", None)))
261 | self.subset_len.append(len(list_add_prefix(os.path.join(base_dir, "classification",
262 | "Fatty-Liver", split + ".txt"), "Fatty-Liver", None)))
263 | self.subset_len.append(len(list_add_prefix(os.path.join(
264 | base_dir, "classification", "UDIAT", split + ".txt"), "UDIAT", None)))
265 |
266 | def __len__(self):
267 | return len(self.sample_list)
268 |
269 | def __getitem__(self, idx):
270 |
271 | img_name = self.sample_list[idx].strip('\n')
272 | img_path = os.path.join(self.data_dir, "classification", img_name)
273 |
274 | image = cv2.imread(img_path)
275 | dataset_name = img_name.split("/")[0]
276 | label = int(img_name.split("/")[-2])
277 |
278 | if not self.prompt:
279 | sample = {'image': image/255.0, 'label': np.zeros(image.shape[:2])}
280 | else:
281 | if dataset_name in available_type_prompt_list:
282 | random_number = random.random()
283 | mask_path = os.path.join(self.data_dir, "segmentation",
284 | "/".join([img_name.split("/")[0], "masks", img_name.split("/")[2]]))
285 | if random_number < 0.3:
286 | sample = {'image': image/255.0, 'label': np.zeros(image.shape[:2])}
287 | sample['type_prompt'] = type_prompt_one_hot_dict["whole"]
288 | elif random_number < 0.6:
289 | mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
290 | x, y, w, h = cv2.boundingRect(mask)
291 | length = max(w, h)
292 |
293 | if 0 in image[y:y+length, x:x+length, :].shape:
294 | sample = {'image': image/255.0, 'label': np.zeros(image.shape[:2])}
295 | sample['type_prompt'] = type_prompt_one_hot_dict["whole"]
296 | else:
297 | image = image[y:y+length, x:x+length, :]
298 | sample = {'image': image/255.0, 'label': np.zeros(image.shape[:2])}
299 | sample['type_prompt'] = type_prompt_one_hot_dict["local"]
300 | else:
301 | mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
302 | mask[mask > 0] = 255
303 | image = image + (np.expand_dims(mask, axis=2)*0.1).astype('uint8')
304 | sample = {'image': image/255.0, 'label': np.zeros(image.shape[:2])}
305 | sample['type_prompt'] = type_prompt_one_hot_dict["location"]
306 | else:
307 | sample = {'image': image/255.0, 'label': np.zeros(image.shape[:2])}
308 | sample['type_prompt'] = type_prompt_one_hot_dict["whole"]
309 | if self.transform:
310 | sample = self.transform(sample)
311 | sample['label'] = torch.from_numpy(np.array(label))
312 | sample['case_name'] = self.sample_list[idx].strip('\n')
313 | sample['nature_prompt'] = nature_prompt_one_hot_dict[nature_prompt_dict[dataset_name]]
314 | sample['position_prompt'] = position_prompt_one_hot_dict[position_prompt_dict[dataset_name]]
315 | sample['task_prompt'] = task_prompt_one_hot_dict["classification"]
316 |
317 | return sample
318 |
--------------------------------------------------------------------------------
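
One plausible way to wire `USdatasetOmni_seg` together with `WeightedRandomSamplerDDP` so that small sub-datasets are not drowned out by large ones. The actual wiring lives in `omni_trainer.py` (not shown here), so treat the weighting scheme and the single-process rank/world-size values below as assumptions:

```
from torch.utils.data import DataLoader

from datasets.dataset import RandomGenerator
from datasets.omni_dataset import USdatasetOmni_seg, WeightedRandomSamplerDDP

db_train = USdatasetOmni_seg(base_dir="data", split="train",
                             transform=RandomGenerator(output_size=[224, 224]),
                             prompt=True)

# Weight every sample inversely to the size of its sub-dataset, so each of the
# seven segmentation datasets is drawn roughly equally often per epoch.
weights = []
for subset_len in db_train.subset_len:
    weights += [1.0 / subset_len] * subset_len

sampler = WeightedRandomSamplerDDP(
    data_set=db_train,
    weights=weights,
    num_replicas=1,              # world size; under torch.distributed this is get_world_size()
    rank=0,                      # this process's rank; under torch.distributed this is get_rank()
    num_samples=len(db_train),
)
loader = DataLoader(db_train, batch_size=4, sampler=sampler, num_workers=2)
```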
/networks/omni_vision_transformer.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | import torch.utils.checkpoint as checkpoint
8 | from timm.models.layers import trunc_normal_
9 |
10 | from torch.functional import F
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.utils.checkpoint as checkpoint
15 | from einops import rearrange
16 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
17 | from torch.functional import F
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 |
22 | class Mlp(nn.Module):
23 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
24 | super().__init__()
25 | out_features = out_features or in_features
26 | hidden_features = hidden_features or in_features
27 | self.fc1 = nn.Linear(in_features, hidden_features)
28 | self.act = act_layer()
29 | self.fc2 = nn.Linear(hidden_features, out_features)
30 | self.drop = nn.Dropout(drop)
31 |
32 | def forward(self, x):
33 | x = self.fc1(x)
34 | x = self.act(x)
35 | x = self.drop(x)
36 | x = self.fc2(x)
37 | x = self.drop(x)
38 | return x
39 |
40 |
41 | def window_partition(x, window_size):
42 | """
43 | Args:
44 | x: (B, H, W, C)
45 | window_size (int): window size
46 |
47 | Returns:
48 | windows: (num_windows*B, window_size, window_size, C)
49 | """
50 | B, H, W, C = x.shape
51 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
52 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
53 | return windows
54 |
55 |
56 | def window_reverse(windows, window_size, H, W):
57 | """
58 | Args:
59 | windows: (num_windows*B, window_size, window_size, C)
60 | window_size (int): Window size
61 | H (int): Height of image
62 | W (int): Width of image
63 |
64 | Returns:
65 | x: (B, H, W, C)
66 | """
67 | B = int(windows.shape[0] / (H * W / window_size / window_size))
68 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
69 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
70 | return x
71 |
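# Editor's note (illustrative, not part of the original file): window_partition and
# window_reverse are exact inverses when H and W are divisible by window_size, e.g.
#   x = torch.randn(2, 14, 14, 96)
#   windows = window_partition(x, 7)                   # -> (2 * 4, 7, 7, 96)
#   assert torch.equal(window_reverse(windows, 7, 14, 14), x)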
72 |
73 | class WindowAttention(nn.Module):
74 | r""" Window based multi-head self attention (W-MSA) module with relative position bias.
75 | It supports both of shifted and non-shifted window.
76 |
77 | Args:
78 | dim (int): Number of input channels.
79 | window_size (tuple[int]): The height and width of the window.
80 | num_heads (int): Number of attention heads.
81 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
82 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
83 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
84 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0
85 | """
86 |
87 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
88 |
89 | super().__init__()
90 | self.dim = dim
91 | self.window_size = window_size # Wh, Ww
92 | self.num_heads = num_heads
93 | head_dim = dim // num_heads
94 | self.scale = qk_scale or head_dim ** -0.5
95 |
96 | # define a parameter table of relative position bias
97 | self.relative_position_bias_table = nn.Parameter(
98 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
99 |
100 | # get pair-wise relative position index for each token inside the window
101 | coords_h = torch.arange(self.window_size[0])
102 | coords_w = torch.arange(self.window_size[1])
103 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
104 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
105 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
106 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
107 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
108 | relative_coords[:, :, 1] += self.window_size[1] - 1
109 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
110 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
111 | self.register_buffer("relative_position_index", relative_position_index)
112 |
113 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
114 | self.attn_drop = nn.Dropout(attn_drop)
115 | self.proj = nn.Linear(dim, dim)
116 | self.proj_drop = nn.Dropout(proj_drop)
117 |
118 | trunc_normal_(self.relative_position_bias_table, std=.02)
119 | self.softmax = nn.Softmax(dim=-1)
120 |
121 | def forward(self, x, mask=None):
122 | """
123 | Args:
124 | x: input features with shape of (num_windows*B, N, C)
125 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
126 | """
127 | B_, N, C = x.shape
128 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
129 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
130 |
131 | q = q * self.scale
132 | attn = (q @ k.transpose(-2, -1))
133 |
134 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
135 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
136 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
137 | attn = attn + relative_position_bias.unsqueeze(0)
138 |
139 | if mask is not None:
140 | nW = mask.shape[0]
141 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
142 | attn = attn.view(-1, self.num_heads, N, N)
143 | attn = self.softmax(attn)
144 | else:
145 | attn = self.softmax(attn)
146 |
147 | attn = self.attn_drop(attn)
148 |
149 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
150 | x = self.proj(x)
151 | x = self.proj_drop(x)
152 | return x
153 |
154 | def extra_repr(self) -> str:
155 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
156 |
157 | def flops(self, N):
158 | # calculate flops for 1 window with token length of N
159 | flops = 0
160 | # qkv = self.qkv(x)
161 | flops += N * self.dim * 3 * self.dim
162 | # attn = (q @ k.transpose(-2, -1))
163 | flops += self.num_heads * N * (self.dim // self.num_heads) * N
164 | # x = (attn @ v)
165 | flops += self.num_heads * N * N * (self.dim // self.num_heads)
166 | # x = self.proj(x)
167 | flops += N * self.dim * self.dim
168 | return flops
169 |
170 |
171 | class SwinTransformerBlock(nn.Module):
172 | r""" Swin Transformer Block.
173 |
174 | Args:
175 | dim (int): Number of input channels.
176 | input_resolution (tuple[int]): Input resolution.
177 | num_heads (int): Number of attention heads.
178 | window_size (int): Window size.
179 | shift_size (int): Shift size for SW-MSA.
180 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
181 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
182 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
183 | drop (float, optional): Dropout rate. Default: 0.0
184 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
185 | drop_path (float, optional): Stochastic depth rate. Default: 0.0
186 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
187 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
188 | """
189 |
190 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
191 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
192 | act_layer=nn.GELU, norm_layer=nn.LayerNorm):
193 | super().__init__()
194 | self.dim = dim
195 | self.input_resolution = input_resolution
196 | self.num_heads = num_heads
197 | self.window_size = window_size
198 | self.shift_size = shift_size
199 | self.mlp_ratio = mlp_ratio
200 | if min(self.input_resolution) <= self.window_size:
201 | # if window size is larger than input resolution, we don't partition windows
202 | self.shift_size = 0
203 | self.window_size = min(self.input_resolution)
204 | assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"
205 |
206 | self.norm1 = norm_layer(dim)
207 | self.attn = WindowAttention(
208 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
209 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
210 |
211 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
212 | self.norm2 = norm_layer(dim)
213 | mlp_hidden_dim = int(dim * mlp_ratio)
214 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
215 |
216 | if self.shift_size > 0:
217 | # calculate attention mask for SW-MSA
218 | H, W = self.input_resolution
219 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
220 | h_slices = (slice(0, -self.window_size),
221 | slice(-self.window_size, -self.shift_size),
222 | slice(-self.shift_size, None))
223 | w_slices = (slice(0, -self.window_size),
224 | slice(-self.window_size, -self.shift_size),
225 | slice(-self.shift_size, None))
226 | cnt = 0
227 | for h in h_slices:
228 | for w in w_slices:
229 | img_mask[:, h, w, :] = cnt
230 | cnt += 1
231 |
232 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
233 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
234 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
235 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
236 | else:
237 | attn_mask = None
238 |
239 | self.register_buffer("attn_mask", attn_mask)
240 |
241 | def forward(self, x):
242 | H, W = self.input_resolution
243 | B, L, C = x.shape
244 | assert L == H * W, "input feature has wrong size"
245 |
246 | shortcut = x
247 | x = self.norm1(x)
248 | x = x.view(B, H, W, C)
249 |
250 | # cyclic shift
251 | if self.shift_size > 0:
252 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
253 | else:
254 | shifted_x = x
255 |
256 | # partition windows
257 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
258 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
259 |
260 | # W-MSA/SW-MSA
261 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
262 |
263 | # merge windows
264 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
265 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
266 |
267 | # reverse cyclic shift
268 | if self.shift_size > 0:
269 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
270 | else:
271 | x = shifted_x
272 | x = x.view(B, H * W, C)
273 |
274 | # FFN
275 | x = shortcut + self.drop_path(x)
276 | x = x + self.drop_path(self.mlp(self.norm2(x)))
277 |
278 | return x
279 |
280 | def extra_repr(self) -> str:
281 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
282 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
283 |
284 | def flops(self):
285 | flops = 0
286 | H, W = self.input_resolution
287 | # norm1
288 | flops += self.dim * H * W
289 | # W-MSA/SW-MSA
290 | nW = H * W / self.window_size / self.window_size
291 | flops += nW * self.attn.flops(self.window_size * self.window_size)
292 | # mlp
293 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
294 | # norm2
295 | flops += self.dim * H * W
296 | return flops
297 |
298 |
299 | class FinalPatchExpand_X4(nn.Module):
300 | def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
301 | super().__init__()
302 | self.input_resolution = input_resolution
303 | self.dim = dim
304 | self.dim_scale = dim_scale
305 | self.expand = nn.Linear(dim, 16*dim, bias=False)
306 | self.output_dim = dim
307 | self.norm = norm_layer(self.output_dim)
308 |
309 | def forward(self, x):
310 | """
311 | x: B, H*W, C
312 | """
313 | H, W = self.input_resolution
314 | x = self.expand(x)
315 | B, L, C = x.shape
316 | assert L == H * W, "input feature has wrong size"
317 |
318 | x = x.view(B, H, W, C)
319 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale,
320 | p2=self.dim_scale, c=C//(self.dim_scale**2))
321 | x = x.view(B, -1, self.output_dim)
322 | x = self.norm(x)
323 |
324 | return x
325 |
326 |
327 | class PatchMerging(nn.Module):
328 | r""" Patch Merging Layer.
329 |
330 | Args:
331 | input_resolution (tuple[int]): Resolution of input feature.
332 | dim (int): Number of input channels.
333 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
334 | """
335 |
336 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
337 | super().__init__()
338 | self.input_resolution = input_resolution
339 | self.dim = dim
340 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
341 | self.norm = norm_layer(4 * dim)
342 |
343 | def forward(self, x):
344 | """
345 | x: B, H*W, C
346 | """
347 | H, W = self.input_resolution
348 | B, L, C = x.shape
349 | assert L == H * W, "input feature has wrong size"
350 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
351 |
352 | x = x.view(B, H, W, C)
353 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
354 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
355 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
356 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
357 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
358 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
359 |
360 | x = self.norm(x)
361 | x = self.reduction(x)
362 |
363 | return x
364 |
365 | def extra_repr(self) -> str:
366 | return f"input_resolution={self.input_resolution}, dim={self.dim}"
367 |
368 | def flops(self):
369 | H, W = self.input_resolution
370 | flops = H * W * self.dim
371 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim # reduction 4 * self.dim -> 2 * self.dim
372 | return flops
373 |
374 |
375 | class PatchExpand(nn.Module):
376 | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
377 | super().__init__()
378 | self.input_resolution = input_resolution
379 | self.dim = dim
380 | self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale == 2 else nn.Identity()
381 | self.norm = norm_layer(dim // dim_scale)
382 |
383 | def forward(self, x):
384 | """
385 | x: B, H*W, C
386 | """
387 | H, W = self.input_resolution
388 | x = self.expand(x)
389 | B, L, C = x.shape
390 | assert L == H * W, "input feature has wrong size"
391 |
392 | x = x.view(B, H, W, C)
393 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4)
394 | x = x.view(B, -1, C//4)
395 | x = self.norm(x)
396 |
397 | return x
398 |
399 |
400 | class ChannelHalf(nn.Module):
401 | def __init__(self, input_resolution=None, dim=0, norm_layer=nn.LayerNorm):
402 | super().__init__()
403 | self.linear = nn.Linear(dim, dim // 2, bias=False)
404 | self.norm = norm_layer(dim // 2)
405 | self.input_resolution = input_resolution
406 |
407 | def forward(self, x):
408 | x = self.linear(x)
409 | x = self.norm(x)
410 | return x
411 |
412 |
413 | class PatchEmbed(nn.Module):
414 | r""" Image to Patch Embedding
415 |
416 | Args:
417 | img_size (int): Image size. Default: 224.
418 | patch_size (int): Patch token size. Default: 4.
419 | in_chans (int): Number of input image channels. Default: 3.
420 | embed_dim (int): Number of linear projection output channels. Default: 96.
421 | norm_layer (nn.Module, optional): Normalization layer. Default: None
422 | """
423 |
424 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
425 | super().__init__()
426 | img_size = to_2tuple(img_size)
427 | patch_size = to_2tuple(patch_size)
428 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
429 | self.img_size = img_size
430 | self.patch_size = patch_size
431 | self.patches_resolution = patches_resolution
432 | self.num_patches = patches_resolution[0] * patches_resolution[1]
433 |
434 | self.in_chans = in_chans
435 | self.embed_dim = embed_dim
436 |
437 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
438 | if norm_layer is not None:
439 | self.norm = norm_layer(embed_dim)
440 | else:
441 | self.norm = None
442 |
443 | def forward(self, x):
444 | B, C, H, W = x.shape
445 | # FIXME look at relaxing size constraints
446 | assert H == self.img_size[0] and W == self.img_size[1], \
447 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
448 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
449 | if self.norm is not None:
450 | x = self.norm(x)
451 | return x
452 |
453 | def flops(self):
454 | Ho, Wo = self.patches_resolution
455 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
456 | if self.norm is not None:
457 | flops += Ho * Wo * self.embed_dim
458 | return flops
459 |
460 |
461 | class BasicLayer(nn.Module):
462 | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
463 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
464 | drop_path=0., norm_layer=nn.LayerNorm, res_scale=None, use_checkpoint=False,
465 | ):
466 |
467 | super().__init__()
468 | self.dim = dim
469 | self.input_resolution = input_resolution
470 | self.depth = depth
471 | self.use_checkpoint = use_checkpoint
472 |
473 | # build blocks
474 | self.blocks = nn.ModuleList([
475 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
476 | num_heads=num_heads, window_size=window_size,
477 | shift_size=0 if (
478 | i % 2 == 0) else window_size // 2,
479 | mlp_ratio=mlp_ratio,
480 | qkv_bias=qkv_bias, qk_scale=qk_scale,
481 | drop=drop, attn_drop=attn_drop,
482 | drop_path=drop_path[i] if isinstance(
483 | drop_path, list) else drop_path,
484 | norm_layer=norm_layer)
485 | for i in range(depth)])
486 |
487 | # patch merging layer
488 | if res_scale is not None:
489 | self.res_scale = res_scale(input_resolution, dim)
490 | else:
491 | self.res_scale = None
492 |
493 | def forward(self, x):
494 | for blk in self.blocks:
495 | if self.use_checkpoint:
496 | x = checkpoint.checkpoint(blk, x)
497 | else:
498 | x = blk(x)
499 | if self.res_scale is not None:
500 | x = self.res_scale(x)
501 | return x
502 |
503 |
504 | class SwinTransformer(nn.Module):
505 | def __init__(self, img_size=224, patch_size=4, in_chans=3,
506 | embed_dim=96,
507 | encoder_depths=[2, 2, 2, 2],
508 | decoder_depths=[2, 2, 2, 2],
509 | num_heads=[3, 6, 12, 24],
510 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
511 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
512 | norm_layer=nn.LayerNorm, patch_norm=True,
513 | ape=False,
514 | use_checkpoint=False,
515 | prompt=False,
516 | ):
517 | super().__init__()
518 |
519 | print("SwinTransformer architecture information:")
520 |
521 | self.num_layers = len(encoder_depths)
522 | self.embed_dim = embed_dim
523 | self.ape = ape
524 | self.patch_norm = patch_norm
525 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
526 | self.mlp_ratio = mlp_ratio
527 | self.prompt = prompt
528 |
529 | self.patch_embed = PatchEmbed(
530 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
531 | norm_layer=norm_layer if self.patch_norm else None)
532 | num_patches = self.patch_embed.num_patches
533 | patches_resolution = self.patch_embed.patches_resolution
534 | self.patches_resolution = patches_resolution
535 |
536 | # absolute position embedding
537 | if self.ape:
538 | self.absolute_pos_embed = nn.Parameter(
539 | torch.zeros(1, num_patches, embed_dim))
540 | trunc_normal_(self.absolute_pos_embed, std=.02)
541 |
542 |         # learnable prompt embedding: each MLP maps the concatenated one-hot prompts (position + task + type + nature, 8+2+2+3 = 15 dims) to the channel width of the stage where it is added
543 | if self.prompt:
544 | self.dec_prompt_mlp = nn.Linear(8+2+2+3, embed_dim*8)
545 | self.dec_prompt_mlp_cls2 = nn.Linear(8+2+2+3, embed_dim*4)
546 | self.dec_prompt_mlp_seg2_cls3 = nn.Linear(8+2+2+3, embed_dim*2)
547 | self.dec_prompt_mlp_seg3 = nn.Linear(8+2+2+3, embed_dim*1)
548 |
549 | self.pos_drop = nn.Dropout(p=drop_rate)
550 |
551 | # stochastic depth
552 | enc_dpr = [x.item() for x in torch.linspace(
553 | 0, drop_path_rate, sum(encoder_depths))]
554 |
555 | ## Encoder + bottleneck ##
556 | self.layers = nn.ModuleList()
557 | for i_layer in range(self.num_layers):
558 |
559 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
560 | input_resolution=(patches_resolution[0] // (2 ** i_layer),
561 | patches_resolution[1] // (2 ** i_layer)),
562 | depth=encoder_depths[i_layer],
563 | num_heads=num_heads[i_layer],
564 | window_size=window_size,
565 | mlp_ratio=self.mlp_ratio,
566 | qkv_bias=qkv_bias, qk_scale=qk_scale,
567 | drop=drop_rate, attn_drop=attn_drop_rate,
568 | drop_path=enc_dpr[sum(encoder_depths[:i_layer]):sum(encoder_depths[:i_layer + 1])],
569 | norm_layer=norm_layer,
570 | res_scale=PatchMerging if (i_layer < self.num_layers - 1) else None,
571 | use_checkpoint=use_checkpoint
572 | )
573 | self.layers.append(layer)
574 |
575 | ## Multi Decoder ##
576 |
577 | self.layers_task_seg_up = nn.ModuleList()
578 | self.layers_task_seg_skip = nn.ModuleList()
579 | self.layers_task_seg_head = nn.ModuleList()
580 |
581 | self.layers_task_cls_up = nn.ModuleList()
582 | self.layers_task_cls_head = nn.ModuleList()
583 |
584 | # stochastic depth
585 | dec_dpr = [x.item() for x in torch.linspace(
586 | 0, drop_path_rate, sum(decoder_depths))]
587 |
588 | for i_layer in range(self.num_layers):
589 | # seg
590 | self.layers_task_seg_skip.append(
591 | nn.Linear(2*int(embed_dim*2**(self.num_layers-1-i_layer)),
592 | int(embed_dim*2**(self.num_layers-1-i_layer))) if i_layer > 0 else nn.Identity()
593 | )
594 | if i_layer == 0:
595 | self.layers_task_seg_up.append(
596 | PatchExpand(input_resolution=(
597 | patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)),
598 | patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))),
599 | dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)),
600 | dim_scale=2, norm_layer=norm_layer))
601 | else:
602 | self.layers_task_seg_up.append(
603 | BasicLayer(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)),
604 | input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)),
605 | patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))),
606 | depth=decoder_depths[(self.num_layers-1-i_layer)],
607 | num_heads=num_heads[(
608 | self.num_layers-1-i_layer)],
609 | window_size=window_size,
610 | mlp_ratio=self.mlp_ratio,
611 | qkv_bias=qkv_bias, qk_scale=qk_scale,
612 | drop=drop_rate, attn_drop=attn_drop_rate,
613 | drop_path=dec_dpr[sum(decoder_depths[:(
614 | self.num_layers-1-i_layer)]):sum(decoder_depths[:(self.num_layers-1-i_layer) + 1])],
615 | norm_layer=norm_layer,
616 | res_scale=PatchExpand if (i_layer < self.num_layers - 1) else None,
617 | use_checkpoint=use_checkpoint,
618 | )
619 | )
620 | # cls
621 | if i_layer == 0:
622 | pass
623 | else:
624 | self.layers_task_cls_up.append(
625 | BasicLayer(dim=int(embed_dim * 2 ** (self.num_layers-i_layer)),
626 | input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-0)),
627 | patches_resolution[1] // (2 ** (self.num_layers-1-0))),
628 | depth=decoder_depths[(self.num_layers-i_layer)],
629 | num_heads=num_heads[(self.num_layers-i_layer)],
630 | window_size=window_size,
631 | mlp_ratio=self.mlp_ratio,
632 | qkv_bias=qkv_bias, qk_scale=qk_scale,
633 | drop=drop_rate, attn_drop=attn_drop_rate,
634 | drop_path=dec_dpr[sum(decoder_depths[:(self.num_layers-i_layer)]):sum(decoder_depths[:(self.num_layers-i_layer) + 1])],
635 | norm_layer=norm_layer,
636 | res_scale=ChannelHalf if (i_layer < self.num_layers - 1) else None,
637 | use_checkpoint=use_checkpoint
638 | ))
639 |
640 | self.layers_task_seg_head.append(
641 | FinalPatchExpand_X4(input_resolution=(img_size//patch_size, img_size//patch_size), dim=embed_dim)
642 | )
643 | self.layers_task_seg_head.append(
644 | nn.Conv2d(in_channels=embed_dim, out_channels=2, kernel_size=1, bias=False)
645 | )
646 | self.layers_task_cls_head.append(
647 | nn.Linear(self.embed_dim*2, 2)
648 | )
649 |
650 | ## Norm Layer ##
651 | self.norm = norm_layer(self.num_features)
652 | self.norm_task_seg = norm_layer(self.embed_dim)
653 | self.norm_task_cls = norm_layer(self.embed_dim*2)
654 |
655 | self.apply(self._init_weights)
656 |
657 | def _init_weights(self, m):
658 | if isinstance(m, nn.Linear):
659 | trunc_normal_(m.weight, std=.02)
660 | if isinstance(m, nn.Linear) and m.bias is not None:
661 | nn.init.constant_(m.bias, 0)
662 | elif isinstance(m, nn.LayerNorm):
663 | nn.init.constant_(m.bias, 0)
664 | nn.init.constant_(m.weight, 1.0)
665 |
666 | @torch.jit.ignore
667 | def no_weight_decay(self):
668 | return {'absolute_pos_embed'}
669 |
670 | @torch.jit.ignore
671 | def no_weight_decay_keywords(self):
672 | return {'relative_position_bias_table'}
673 |
674 | # Encoder and Bottleneck
675 | def forward_features(self, x):
676 | x = self.patch_embed(x)
677 | if self.ape:
678 | x = x + self.absolute_pos_embed
679 |
680 | x = self.pos_drop(x)
681 | x_downsample = []
682 |
683 | for layer in self.layers:
684 | x_downsample.append(x)
685 | x = layer(x)
686 |
687 | x = self.norm(x)
688 |
689 | return x, x_downsample
690 |
691 | # Decoder task head
692 | def forward_task_features(self, x, x_downsample):
693 | if self.prompt:
694 | x, position_prompt, task_prompt, type_prompt, nature_prompt = x
695 |
696 | # seg
697 | for inx, layer_seg in enumerate(self.layers_task_seg_up):
698 | if inx == 0:
699 | x_seg = layer_seg(x)
700 | else:
701 | x_seg = torch.cat([x_seg, x_downsample[3-inx]], -1)
702 | x_seg = self.layers_task_seg_skip[inx](x_seg)
703 |
704 | if self.prompt and inx > 1:
705 | if inx == 2:
706 | x_seg = layer_seg(x_seg +
707 | self.dec_prompt_mlp_seg2_cls3(torch.cat([position_prompt, task_prompt, type_prompt, nature_prompt], dim=1)).unsqueeze(1))
708 | if inx == 3:
709 | x_seg = layer_seg(x_seg +
710 | self.dec_prompt_mlp_seg3(torch.cat([position_prompt, task_prompt, type_prompt, nature_prompt], dim=1)).unsqueeze(1))
711 | else:
712 | x_seg = layer_seg(x_seg)
713 |
714 | x_seg = self.norm_task_seg(x_seg)
715 |
716 | H, W = self.patches_resolution
717 | B, _, _ = x_seg.shape
718 | x_seg = self.layers_task_seg_head[0](x_seg)
719 | x_seg = x_seg.view(B, 4*H, 4*W, -1)
720 | x_seg = x_seg.permute(0, 3, 1, 2)
721 | x_seg = self.layers_task_seg_head[1](x_seg)
722 |
723 | # cls
724 | for inx, layer_head in enumerate(self.layers_task_cls_up):
725 | if inx == 0:
726 | x_cls = layer_head(x)
727 | else:
728 | if self.prompt:
729 | if inx == 1:
730 | x_cls = layer_head(x_cls +
731 | self.dec_prompt_mlp_cls2(torch.cat([position_prompt, task_prompt, type_prompt, nature_prompt], dim=1)).unsqueeze(1))
732 | if inx == 2:
733 | x_cls = layer_head(x_cls +
734 | self.dec_prompt_mlp_seg2_cls3(torch.cat([position_prompt, task_prompt, type_prompt, nature_prompt], dim=1)).unsqueeze(1))
735 | else:
736 | x_cls = layer_head(x_cls)
737 |
738 | x_cls = self.norm_task_cls(x_cls)
739 |
740 | B, _, _ = x_cls.shape
741 | x_cls = x_cls.transpose(1, 2)
742 | x_cls = F.adaptive_avg_pool1d(x_cls, 1).view(B, -1)
743 | x_cls = self.layers_task_cls_head[0](x_cls)
744 |
745 | return (x_seg, x_cls)
746 |
747 | def forward(self, x):
748 | if self.prompt:
749 | x, position_prompt, task_prompt, type_prompt, nature_prompt = x
750 | x, x_downsample = self.forward_features(x)
751 | x = x + self.dec_prompt_mlp(torch.cat([position_prompt, task_prompt,
752 | type_prompt, nature_prompt], dim=1)).unsqueeze(1)
753 | x_tuple = self.forward_task_features(
754 | (x, position_prompt, task_prompt, type_prompt, nature_prompt), x_downsample)
755 | else:
756 | x, x_downsample = self.forward_features(x)
757 | x_tuple = self.forward_task_features(x, x_downsample)
758 | return x_tuple
759 |
760 |
761 | class OmniVisionTransformer(nn.Module):
762 | def __init__(self, config,
763 | prompt=False,
764 | ):
765 | super(OmniVisionTransformer, self).__init__()
766 | self.config = config
767 | self.prompt = prompt
768 |
769 | self.swin = SwinTransformer(img_size=config.DATA.IMG_SIZE,
770 | patch_size=config.MODEL.SWIN.PATCH_SIZE,
771 | in_chans=config.MODEL.SWIN.IN_CHANS,
772 | embed_dim=config.MODEL.SWIN.EMBED_DIM,
773 | encoder_depths=config.MODEL.SWIN.ENCODER_DEPTHS,
774 | decoder_depths=config.MODEL.SWIN.DECODER_DEPTHS,
775 | num_heads=config.MODEL.SWIN.NUM_HEADS,
776 | window_size=config.MODEL.SWIN.WINDOW_SIZE,
777 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
778 | qkv_bias=config.MODEL.SWIN.QKV_BIAS,
779 | qk_scale=config.MODEL.SWIN.QK_SCALE,
780 | drop_rate=config.MODEL.DROP_RATE,
781 | drop_path_rate=config.MODEL.DROP_PATH_RATE,
782 | ape=config.MODEL.SWIN.APE,
783 | patch_norm=config.MODEL.SWIN.PATCH_NORM,
784 | use_checkpoint=config.TRAIN.USE_CHECKPOINT,
785 | prompt=prompt,
786 | )
787 |
788 | def forward(self, x):
789 | if self.prompt:
790 | image = x[0].squeeze(1).permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
791 | position_prompt = x[1]
792 | task_prompt = x[2]
793 | type_prompt = x[3]
794 | nature_prompt = x[4]
795 | result = self.swin((image, position_prompt, task_prompt, type_prompt, nature_prompt))
796 | else:
797 | x = x.squeeze(1).permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
798 | result = self.swin(x)
799 | return result
800 |
801 | def load_from(self, config):
802 | pretrained_path = config.MODEL.PRETRAIN_CKPT
803 | if pretrained_path is not None:
804 | print("pretrained_path:{}".format(pretrained_path))
805 | device = torch.device(
806 | 'cuda' if torch.cuda.is_available() else 'cpu')
807 | pretrained_dict = torch.load(pretrained_path, map_location=device)
808 | pretrained_dict = pretrained_dict['model']
809 | print("---start load pretrained model of swin encoder---")
810 | model_dict = self.swin.state_dict()
811 | full_dict = copy.deepcopy(pretrained_dict)
812 | for k, v in pretrained_dict.items():
813 | if "layers." in k:
814 | current_layer_num = 3-int(k[7:8])
815 | current_k = "layers_up." + str(current_layer_num) + k[8:]
816 | full_dict.update({current_k: v})
817 | for k in list(full_dict.keys()):
818 | if k in model_dict:
819 | if full_dict[k].shape != model_dict[k].shape:
820 | print("delete:{};shape pretrain:{};shape model:{}".format(
821 |                             k, full_dict[k].shape, model_dict[k].shape))
822 | del full_dict[k]
823 |
824 | self.swin.load_state_dict(full_dict, strict=False)
825 | else:
826 | print("none pretrain")
827 |
828 | def load_from_self(self, pretrained_path):
829 | print("pretrained_path:{}".format(pretrained_path))
830 | device = torch.device(
831 | 'cuda' if torch.cuda.is_available() else 'cpu')
832 | pretrained_dict = torch.load(pretrained_path, map_location=device)
833 | full_dict = copy.deepcopy(pretrained_dict)
834 | for k, v in pretrained_dict.items():
835 | if "module.swin." in k:
836 | current_k = k[12:]
837 | full_dict.update({current_k: v})
838 | del full_dict[k]
839 |
840 | self.swin.load_state_dict(full_dict)
841 |
--------------------------------------------------------------------------------
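
A minimal smoke test for the prompt branch of `SwinTransformer` above, assuming its default constructor arguments and that the remaining classes in this file (`SwinTransformerBlock`, `PatchMerging`, `PatchExpand`, `FinalPatchExpand_X4`) are available as defined earlier. The prompt widths used here (position 8, task 2, type 2, nature 3) are an assumption read off the `8+2+2+3` MLP input size and the concatenation order; any split summing to 15 would drive the model identically.

```
import torch
from networks.omni_vision_transformer import SwinTransformer

# build the backbone directly with its default hyper-parameters and the prompt branch enabled
model = SwinTransformer(prompt=True).eval()

image = torch.randn(1, 3, 224, 224)                                 # B, C, H, W (PatchEmbed expects channels first)
position_prompt = torch.zeros(1, 8); position_prompt[0, 0] = 1.0    # assumed 8-dim one-hot
task_prompt = torch.tensor([[1.0, 0.0]])                            # segmentation task, as in omni_test.py
type_prompt = torch.zeros(1, 2); type_prompt[0, 0] = 1.0            # assumed 2-dim one-hot
nature_prompt = torch.zeros(1, 3); nature_prompt[0, 0] = 1.0        # assumed 3-dim one-hot

with torch.no_grad():
    seg_logits, cls_logits = model((image, position_prompt, task_prompt, type_prompt, nature_prompt))

print(seg_logits.shape)  # torch.Size([1, 2, 224, 224]) - per-pixel background/foreground logits
print(cls_logits.shape)  # torch.Size([1, 2]) - binary classification logits
```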
/omni_test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import random
5 | import sys
6 | import numpy as np
7 | import torch
8 | import torch.backends.cudnn as cudnn
9 | from torch.utils.data import DataLoader
10 | from tqdm import tqdm
11 |
12 | from config import get_config
13 |
14 | from datasets.dataset import CenterCropGenerator
15 | from datasets.dataset import USdatasetCls, USdatasetSeg
16 |
17 | from utils import omni_seg_test
18 | from sklearn.metrics import accuracy_score
19 |
20 | from networks.omni_vision_transformer import OmniVisionTransformer as ViT_omni
21 |
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument('--root_path', type=str,
24 | default='data_demo/', help='root dir for data')
25 | parser.add_argument('--output_dir', type=str, help='output dir')
26 | parser.add_argument('--max_epochs', type=int, default=200, help='maximum epoch number to train')
27 | parser.add_argument('--batch_size', type=int, default=16,
28 | help='batch_size per gpu')
29 | parser.add_argument('--img_size', type=int, default=224, help='input patch size of network input')
30 | parser.add_argument('--is_saveout', action="store_true", help='whether to save results during inference')
31 | parser.add_argument('--test_save_dir', type=str, default='../predictions', help='directory for saving predictions')
32 | parser.add_argument('--deterministic', type=int, default=1, help='whether use deterministic training')
33 | parser.add_argument('--base_lr', type=float, default=0.01, help='segmentation network learning rate')
34 | parser.add_argument('--seed', type=int, default=1234, help='random seed')
35 | parser.add_argument('--cfg', type=str, default="configs/swin_tiny_patch4_window7_224_lite.yaml",
36 | metavar="FILE", help='path to config file', )
37 | parser.add_argument(
38 | "--opts",
39 | help="Modify config options by adding 'KEY VALUE' pairs. ",
40 | default=None,
41 | nargs='+',
42 | )
43 | parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
44 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
45 | help='no: no cache, '
46 | 'full: cache all data, '
47 |                          'part: sharding the dataset into non-overlapping pieces and only cache one piece')
48 | parser.add_argument('--resume', help='resume from checkpoint')
49 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
50 | parser.add_argument('--use-checkpoint', action='store_true',
51 | help="whether to use gradient checkpointing to save memory")
52 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
53 | help='mixed precision opt level, if O0, no amp is used')
54 | parser.add_argument('--tag', help='tag of experiment')
55 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
56 | parser.add_argument('--throughput', action='store_true', help='Test throughput only')
57 |
58 | parser.add_argument('--prompt', action='store_true', help='using prompt')
59 |
60 | args = parser.parse_args()
61 | config = get_config(args)
62 |
63 |
64 | def inference(args, model, test_save_path=None):
65 | import csv
66 | import time
67 |
68 | if not os.path.exists("exp_out/result.csv"):
69 | with open("exp_out/result.csv", 'w', newline='') as csvfile:
70 | writer = csv.writer(csvfile)
71 | writer.writerow(['dataset', 'task', 'metric', 'time'])
72 |
73 | seg_test_set = ["BUS-BRA", "BUSIS", "CAMUS", "DDTI", "Fetal_HC", "kidneyUS", "UDIAT"]
74 |
75 | for dataset_name in seg_test_set:
76 | num_classes = 2
77 | db_test = USdatasetSeg(
78 | base_dir=os.path.join(args.root_path, "segmentation", dataset_name),
79 | split="test",
80 | list_dir=os.path.join(args.root_path, "segmentation", dataset_name),
81 | transform=CenterCropGenerator(output_size=[args.img_size, args.img_size]),
82 | prompt=args.prompt
83 | )
84 | testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)
85 | logging.info("{} test iterations per epoch".format(len(testloader)))
86 | model.eval()
87 |
88 | metric_list = 0.0
89 | count_matrix = np.ones((len(db_test), num_classes-1))
90 | for i_batch, sampled_batch in tqdm(enumerate(testloader)):
91 | image, label, case_name = sampled_batch["image"], sampled_batch["label"], sampled_batch['case_name'][0]
92 | if args.prompt:
93 | position_prompt = torch.tensor(np.array(sampled_batch['position_prompt'])).permute([1, 0]).float()
94 | task_prompt = torch.tensor(np.array([[1], [0]])).permute([1, 0]).float()
95 | type_prompt = torch.tensor(np.array(sampled_batch['type_prompt'])).permute([1, 0]).float()
96 | nature_prompt = torch.tensor(np.array(sampled_batch['nature_prompt'])).permute([1, 0]).float()
97 | metric_i = omni_seg_test(image, label, model,
98 | classes=num_classes,
99 | test_save_path=test_save_path,
100 | case=case_name,
101 | prompt=args.prompt,
102 | type_prompt=type_prompt,
103 | nature_prompt=nature_prompt,
104 | position_prompt=position_prompt,
105 | task_prompt=task_prompt
106 | )
107 | else:
108 | metric_i = omni_seg_test(image, label, model,
109 | classes=num_classes,
110 | test_save_path=test_save_path,
111 | case=case_name)
112 | zero_label_flag = False
113 | for i in range(1, num_classes):
114 | if not metric_i[i-1][1]:
115 | count_matrix[i_batch, i-1] = 0
116 | zero_label_flag = True
117 | metric_i = [element[0] for element in metric_i]
118 | metric_list += np.array(metric_i)
119 | logging.info('idx %d case %s mean_dice %f' %
120 | (i_batch, case_name, np.mean(metric_i, axis=0)))
121 | logging.info("This case has zero label: %s" % zero_label_flag)
122 |
123 | metric_list = metric_list / (count_matrix.sum(axis=0) + 1e-6)
124 | for i in range(1, num_classes):
125 | logging.info('Mean class %d mean_dice %f' % (i, metric_list[i-1]))
126 | performance = np.mean(metric_list, axis=0)
127 | logging.info('Testing performance in best val model: mean_dice : %f' % (performance))
128 |
129 | with open("exp_out/result.csv", 'a', newline='') as csvfile:
130 | writer = csv.writer(csvfile)
131 | if args.prompt:
132 | writer.writerow([dataset_name, 'omni_seg_prompt@'+args.output_dir, performance,
133 | time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())])
134 | else:
135 | writer.writerow([dataset_name, 'omni_seg@'+args.output_dir, performance,
136 | time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())])
137 |
138 | cls_test_set = ["Appendix", "BUS-BRA", "Fatty-Liver", "UDIAT"]
139 |
140 | for dataset_name in cls_test_set:
141 | num_classes = 2
142 | db_test = USdatasetCls(
143 | base_dir=os.path.join(args.root_path, "classification", dataset_name),
144 | split="test",
145 | list_dir=os.path.join(args.root_path, "classification", dataset_name),
146 | transform=CenterCropGenerator(output_size=[args.img_size, args.img_size]),
147 | prompt=args.prompt
148 | )
149 |
150 | testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)
151 | logging.info("{} test iterations per epoch".format(len(testloader)))
152 | model.eval()
153 |
154 | label_list = []
155 | prediction_list = []
156 | for i_batch, sampled_batch in tqdm(enumerate(testloader)):
157 | image, label, case_name = sampled_batch["image"], sampled_batch["label"], sampled_batch['case_name'][0]
158 | if args.prompt:
159 | position_prompt = torch.tensor(np.array(sampled_batch['position_prompt'])).permute([1, 0]).float()
160 | task_prompt = torch.tensor(np.array([[0], [1]])).permute([1, 0]).float()
161 | type_prompt = torch.tensor(np.array(sampled_batch['type_prompt'])).permute([1, 0]).float()
162 | nature_prompt = torch.tensor(np.array(sampled_batch['nature_prompt'])).permute([1, 0]).float()
163 | with torch.no_grad():
164 | output = model((image.cuda(), position_prompt.cuda(), task_prompt.cuda(),
165 | type_prompt.cuda(), nature_prompt.cuda()))[1]
166 | else:
167 | with torch.no_grad():
168 | output = model(image.cuda())[1]
169 |
170 | output = np.argmax(torch.softmax(output, dim=1).data.cpu().numpy())
171 | logging.info('idx %d case %s label: %d predict: %d' % (i_batch, case_name, label, output))
172 |
173 | label_list.append(label.numpy())
174 | prediction_list.append(output)
175 |
176 | label_list = np.array(label_list)
177 | prediction_list = np.array(prediction_list)
178 | for i in range(num_classes):
179 | logging.info('class %d acc %f' % (i, accuracy_score(
180 | (label_list == i).astype(int), (prediction_list == i).astype(int))))
181 | performance = accuracy_score(label_list, prediction_list)
182 | logging.info('Testing performance in best val model: acc : %f' % (performance))
183 |
184 | with open("exp_out/result.csv", 'a', newline='') as csvfile:
185 | writer = csv.writer(csvfile)
186 | if args.prompt:
187 | writer.writerow([dataset_name, 'omni_cls_prompt@'+args.output_dir, performance,
188 | time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())])
189 | else:
190 | writer.writerow([dataset_name, 'omni_cls@'+args.output_dir, performance,
191 | time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())])
192 |
193 |
194 | if __name__ == "__main__":
195 | if not args.deterministic:
196 | cudnn.benchmark = True
197 | cudnn.deterministic = False
198 | else:
199 | cudnn.benchmark = False
200 | cudnn.deterministic = True
201 | random.seed(args.seed)
202 | np.random.seed(args.seed)
203 | torch.manual_seed(args.seed)
204 | torch.cuda.manual_seed(args.seed)
205 |
206 | net = ViT_omni(
207 | config,
208 | prompt=args.prompt,
209 | ).cuda()
210 | net.load_from(config)
211 |
212 | snapshot = os.path.join(args.output_dir, 'best_model.pth')
213 | if not os.path.exists(snapshot):
214 | snapshot = snapshot.replace('best_model', 'epoch_'+str(args.max_epochs-1))
215 |
216 | device = torch.device("cuda")
217 | model = net.to(device=device)
218 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
219 | torch.distributed.init_process_group(backend="nccl", init_method='env://', world_size=1, rank=0)
220 | model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
221 |
222 | import copy
223 | pretrained_dict = torch.load(snapshot, map_location=device)
224 | full_dict = copy.deepcopy(pretrained_dict)
225 | for k, v in pretrained_dict.items():
226 | if "module." not in k:
227 | full_dict["module."+k] = v
228 | del full_dict[k]
229 |
230 | msg = model.load_state_dict(full_dict)
231 |
232 | print("self trained swin unet", msg)
233 | snapshot_name = snapshot.split('/')[-1]
234 |
235 | logging.basicConfig(filename=args.output_dir+"/"+"test_result.txt", level=logging.INFO,
236 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
237 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
238 | logging.info(str(args))
239 | logging.info(snapshot_name)
240 |
241 | if args.is_saveout:
242 | args.test_save_dir = os.path.join(args.output_dir, "predictions")
243 | test_save_path = args.test_save_dir
244 | os.makedirs(test_save_path, exist_ok=True)
245 | else:
246 | test_save_path = None
247 | inference(args, net, test_save_path)
248 |
--------------------------------------------------------------------------------
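
`omni_test.py` initializes a single-process `DistributedDataParallel` group through `init_method='env://'`, so the distributed environment variables must be set even for one GPU. A hypothetical single-GPU invocation through `torchrun` is sketched below; the output directory name is an assumption and must already contain `best_model.pth` (or the final `epoch_*.pth`).

```
torchrun --nproc_per_node=1 --master_port=29500 omni_test.py --output_dir exp_out/UniUSNet --prompt --is_saveout
```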
/omni_train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import numpy as np
5 | import torch
6 | import torch.backends.cudnn as cudnn
7 | from networks.omni_vision_transformer import OmniVisionTransformer as ViT_omni
8 | from omni_trainer import omni_train
9 | from config import get_config
10 |
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--root_path', type=str,
13 | default='data_demo/', help='root dir for data')
14 | parser.add_argument('--output_dir', type=str, help='output dir')
15 | parser.add_argument('--max_epochs', type=int,
16 | default=200, help='maximum epoch number to train')
17 | parser.add_argument('--batch_size', type=int,
18 | default=16, help='batch_size per gpu')
19 | parser.add_argument('--gpu', type=str, default=None)
20 | parser.add_argument('--deterministic', type=int, default=1,
21 | help='whether use deterministic training')
22 | parser.add_argument('--base_lr', type=float, default=0.01,
23 | help='segmentation network learning rate')
24 | parser.add_argument('--img_size', type=int,
25 | default=224, help='input patch size of network input')
26 | parser.add_argument('--seed', type=int,
27 | default=1234, help='random seed')
28 | parser.add_argument('--cfg', type=str, default="configs/swin_tiny_patch4_window7_224_lite.yaml",
29 | metavar="FILE", help='path to config file', )
30 | parser.add_argument(
31 | "--opts",
32 | help="Modify config options by adding 'KEY VALUE' pairs. ",
33 | default=None,
34 | nargs='+',
35 | )
36 | parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
37 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
38 | help='no: no cache, '
39 | 'full: cache all data, '
40 | 'part: sharding the dataset into non-overlapping pieces and only cache one piece')
41 | parser.add_argument('--resume', help='resume from checkpoint')
42 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
43 | parser.add_argument('--use-checkpoint', action='store_true',
44 | help="whether to use gradient checkpointing to save memory")
45 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
46 | help='mixed precision opt level, if O0, no amp is used')
47 | parser.add_argument('--tag', help='tag of experiment')
48 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
49 | parser.add_argument('--throughput', action='store_true', help='Test throughput only')
50 |
51 | parser.add_argument('--pretrain_ckpt', type=str, help='pretrained checkpoint')
52 |
53 | parser.add_argument('--prompt', action='store_true', help='using prompt for training')
54 | parser.add_argument('--adapter_ft', action='store_true', help='using adapter for fine-tuning')
55 |
56 |
57 | args = parser.parse_args()
58 |
59 | config = get_config(args)
60 |
61 |
62 | if __name__ == "__main__":
63 | if not args.deterministic:
64 | cudnn.benchmark = True
65 | cudnn.deterministic = False
66 | else:
67 | cudnn.benchmark = False
68 | cudnn.deterministic = True
69 |
70 | random.seed(args.seed)
71 | np.random.seed(args.seed)
72 | torch.manual_seed(args.seed)
73 | torch.cuda.manual_seed(args.seed)
74 |
75 | if args.batch_size != 24 and args.batch_size % 6 == 0:
76 | args.base_lr *= args.batch_size / 24
77 |
78 | if not os.path.exists(args.output_dir):
79 | os.makedirs(args.output_dir, exist_ok=True)
80 |
81 | net = ViT_omni(
82 | config,
83 | prompt=args.prompt,
84 | ).cuda()
85 | if args.pretrain_ckpt is not None:
86 | net.load_from_self(args.pretrain_ckpt)
87 | else:
88 | net.load_from(config)
89 |
90 | if args.prompt and args.adapter_ft:
91 |
92 | for name, param in net.named_parameters():
93 | if 'prompt' in name:
94 | param.requires_grad = True
95 | print(name)
96 | else:
97 | param.requires_grad = False
98 |
99 | omni_train(args, net, args.output_dir)
100 |
--------------------------------------------------------------------------------
/omni_trainer.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import sys
4 | import random
5 | import logging
6 | import datetime
7 | import numpy as np
8 | from tqdm import tqdm
9 |
10 | import torch
11 | import torch.optim as optim
12 | import torch.distributed as dist
13 | from torch.nn.modules.loss import CrossEntropyLoss
14 | from torch.utils.data import DataLoader
15 | from torchvision import transforms
16 | from torch.utils.tensorboard import SummaryWriter
17 |
18 |
19 | from utils import DiceLoss
20 | from datasets.dataset import USdatasetCls, USdatasetSeg
21 | from datasets.omni_dataset import WeightedRandomSamplerDDP
22 | from datasets.omni_dataset import USdatasetOmni_cls, USdatasetOmni_seg
23 | from datasets.dataset import RandomGenerator, CenterCropGenerator
24 | from sklearn.metrics import roc_auc_score
25 | from utils import omni_seg_test
26 |
27 |
28 | def omni_train(args, model, snapshot_path):
29 |
30 | if args.gpu:
31 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
32 |
33 | torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
34 | device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
35 | gpu_id = rank = int(os.environ["LOCAL_RANK"])
36 | world_size = int(os.environ["WORLD_SIZE"])
37 | torch.distributed.init_process_group(backend="nccl", init_method='env://', timeout=datetime.timedelta(seconds=7200))
38 |
39 | if int(os.environ["LOCAL_RANK"]) == 0:
40 | print('** GPU NUM ** : ', torch.cuda.device_count())
41 | print('** WORLD SIZE ** : ', torch.distributed.get_world_size())
42 | print(f"** DDP ** : Start running on rank {rank}.")
43 |
44 | logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
45 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
46 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
47 | logging.info(str(args))
48 | base_lr = args.base_lr
49 | batch_size = args.batch_size
50 |
51 | def worker_init_fn(worker_id):
52 | random.seed(args.seed + worker_id)
53 |
54 | db_train_seg = USdatasetOmni_seg(base_dir=args.root_path, split="train", transform=transforms.Compose(
55 | [RandomGenerator(output_size=[args.img_size, args.img_size])]), prompt=args.prompt)
56 |
57 | weight_base = [1/4, 1/2, 2, 2, 1, 2, 2]
58 | sample_weight_seq = [[weight_base[dataset_index]] *
59 | element for dataset_index, element in enumerate(db_train_seg.subset_len)]
60 | sample_weight_seq = [element for sublist in sample_weight_seq for element in sublist]
61 |
62 | weighted_sampler_seg = WeightedRandomSamplerDDP(
63 | data_set=db_train_seg,
64 | weights=sample_weight_seq,
65 | num_replicas=world_size,
66 | rank=rank,
67 | num_samples=args.num_samples_seg,
68 | replacement=True
69 | )
70 | trainloader_seg = DataLoader(db_train_seg,
71 | batch_size=batch_size,
72 | num_workers=16,
73 | pin_memory=True,
74 | worker_init_fn=worker_init_fn,
75 | sampler=weighted_sampler_seg
76 | )
77 |
78 | db_train_cls = USdatasetOmni_cls(base_dir=args.root_path, split="train", transform=transforms.Compose(
79 | [RandomGenerator(output_size=[args.img_size, args.img_size])]), prompt=args.prompt)
80 |
81 | weight_base = [2, 1/4, 2, 2]
82 | sample_weight_seq = [[weight_base[dataset_index]] *
83 | element for dataset_index, element in enumerate(db_train_cls.subset_len)]
84 | sample_weight_seq = [element for sublist in sample_weight_seq for element in sublist]
85 |
86 | weighted_sampler_cls = WeightedRandomSamplerDDP(
87 | data_set=db_train_cls,
88 | weights=sample_weight_seq,
89 | num_replicas=world_size,
90 | rank=rank,
91 | num_samples=args.num_samples_cls,
92 | replacement=True
93 | )
94 | trainloader_cls = DataLoader(db_train_cls,
95 | batch_size=batch_size,
96 | num_workers=16,
97 | pin_memory=True,
98 | worker_init_fn=worker_init_fn,
99 | sampler=weighted_sampler_cls
100 | )
101 |
102 | model = model.to(device=device)
103 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
104 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu_id], find_unused_parameters=True)
105 |
106 | model.train()
107 |
108 | seg_ce_loss = CrossEntropyLoss()
109 | seg_dice_loss = DiceLoss()
110 | cls_ce_loss = CrossEntropyLoss()
111 |
112 | optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.05, betas=(0.9, 0.999))
113 |
114 | resume_epoch = 0
115 | if args.resume is not None:
116 | model.load_state_dict(torch.load(args.resume, map_location='cpu')['model'])
117 | optimizer.load_state_dict(torch.load(args.resume, map_location='cpu')['optimizer'])
118 | resume_epoch = torch.load(args.resume, map_location='cpu')['epoch']
119 |
120 | writer = SummaryWriter(snapshot_path + '/log')
121 | global_iter_num = 0
122 | seg_iter_num = 0
123 | cls_iter_num = 0
124 | max_epoch = args.max_epochs
125 | total_iterations = (len(trainloader_seg) + len(trainloader_cls))
126 | max_iterations = args.max_epochs * total_iterations
127 | logging.info("{} batch size. {} iterations per epoch. {} max iterations ".format(
128 | batch_size, total_iterations, max_iterations))
129 | best_performance = 0.0
130 | best_epoch = 0
131 |
132 | if int(os.environ["LOCAL_RANK"]) != 0:
133 | iterator = tqdm(range(resume_epoch, max_epoch), ncols=70, disable=True)
134 | else:
135 | iterator = tqdm(range(resume_epoch, max_epoch), ncols=70, disable=False)
136 |
137 | for epoch_num in iterator:
138 | logging.info("\n epoch: {}".format(epoch_num))
139 | weighted_sampler_seg.set_epoch(epoch_num)
140 | weighted_sampler_cls.set_epoch(epoch_num)
141 |
142 | torch.cuda.empty_cache()
143 | for i_batch, sampled_batch in tqdm(enumerate(trainloader_seg)):
144 | image_batch, label_batch = sampled_batch['image'], sampled_batch['label']
145 | image_batch, label_batch = image_batch.to(device=device), label_batch.to(device=device)
146 | if args.prompt:
147 | position_prompt = torch.tensor(np.array(sampled_batch['position_prompt'])).permute([
148 | 1, 0]).float().to(device=device)
149 | task_prompt = torch.tensor(np.array(sampled_batch['task_prompt'])).permute([
150 | 1, 0]).float().to(device=device)
151 | type_prompt = torch.tensor(np.array(sampled_batch['type_prompt'])).permute([
152 | 1, 0]).float().to(device=device)
153 | nature_prompt = torch.tensor(np.array(sampled_batch['nature_prompt'])).permute([
154 | 1, 0]).float().to(device=device)
155 | (x_seg, _) = model((image_batch, position_prompt, task_prompt, type_prompt, nature_prompt))
156 | else:
157 | (x_seg, _) = model(image_batch)
158 |
159 | loss_ce = seg_ce_loss(x_seg, label_batch[:].long())
160 | loss_dice = seg_dice_loss(x_seg, label_batch, softmax=True)
161 | loss = 0.4 * loss_ce + 0.6 * loss_dice
162 |
163 | optimizer.zero_grad()
164 | loss.backward()
165 | optimizer.step()
166 | lr_ = base_lr * (1.0 - global_iter_num / max_iterations) ** 0.9
167 | for param_group in optimizer.param_groups:
168 | param_group['lr'] = lr_
169 |
170 | seg_iter_num = seg_iter_num + 1
171 | global_iter_num = global_iter_num + 1
172 |
173 | writer.add_scalar('info/lr', lr_, seg_iter_num)
174 | writer.add_scalar('info/seg_loss', loss, seg_iter_num)
175 |
176 | logging.info('global iteration %d and seg iteration %d : loss : %f' %
177 | (global_iter_num, seg_iter_num, loss.item()))
178 |
179 | torch.cuda.empty_cache()
180 | for i_batch, sampled_batch in tqdm(enumerate(trainloader_cls)):
181 | image_batch, label_batch = sampled_batch['image'], sampled_batch['label']
182 | image_batch, label_batch = image_batch.to(device=device), label_batch.to(device=device)
183 | if args.prompt:
184 | position_prompt = torch.tensor(np.array(sampled_batch['position_prompt'])).permute([
185 | 1, 0]).float().to(device=device)
186 | task_prompt = torch.tensor(np.array(sampled_batch['task_prompt'])).permute([
187 | 1, 0]).float().to(device=device)
188 | type_prompt = torch.tensor(np.array(sampled_batch['type_prompt'])).permute([
189 | 1, 0]).float().to(device=device)
190 | nature_prompt = torch.tensor(np.array(sampled_batch['nature_prompt'])).permute([
191 | 1, 0]).float().to(device=device)
192 | (_, x_cls) = model((image_batch, position_prompt, task_prompt, type_prompt, nature_prompt))
193 | else:
194 | (_, x_cls) = model(image_batch)
195 |
196 | loss_ce = cls_ce_loss(x_cls, label_batch[:].long())
197 | loss = loss_ce
198 |
199 | optimizer.zero_grad()
200 | loss.backward()
201 | optimizer.step()
202 | lr_ = base_lr * (1.0 - global_iter_num / max_iterations) ** 0.9
203 | for param_group in optimizer.param_groups:
204 | param_group['lr'] = lr_
205 |
206 | cls_iter_num = cls_iter_num + 1
207 | global_iter_num = global_iter_num + 1
208 |
209 | writer.add_scalar('info/lr', lr_, cls_iter_num)
210 | writer.add_scalar('info/cls_loss', loss, cls_iter_num)
211 |
212 | logging.info('global iteration %d and cls iteration %d : loss : %f' %
213 | (global_iter_num, cls_iter_num, loss.item()))
214 |
215 | dist.barrier()
216 |
217 | if int(os.environ["LOCAL_RANK"]) == 0:
218 | torch.cuda.empty_cache()
219 |
220 | save_dict = {'model': model.state_dict(),
221 | 'optimizer': optimizer.state_dict(),
222 | 'epoch': epoch_num}
223 | save_latest_path = os.path.join(snapshot_path, 'latest_{}.pth'.format(epoch_num))
224 | if os.path.exists(os.path.join(snapshot_path, 'latest_{}.pth'.format(epoch_num-1))):
225 | os.remove(os.path.join(snapshot_path, 'latest_{}.pth'.format(epoch_num-1)))
226 | os.remove(os.path.join(snapshot_path, 'latest.pth'))
227 | torch.save(save_dict, save_latest_path)
228 | os.system('ln -s ' + os.path.abspath(save_latest_path) + ' ' + os.path.join(snapshot_path, 'latest.pth'))
229 |
230 | model.eval()
231 | total_performance = 0.0
232 |
233 | seg_val_set = ["BUS-BRA", "BUSIS", "CAMUS", "DDTI", "Fetal_HC", "kidneyUS", "UDIAT"]
234 | seg_avg_performance = 0.0
235 |
236 | for dataset_name in seg_val_set:
237 | num_classes = 2
238 | db_val = USdatasetSeg(
239 | base_dir=os.path.join(args.root_path, "segmentation", dataset_name),
240 | split="val",
241 | list_dir=os.path.join(args.root_path, "segmentation", dataset_name),
242 | transform=CenterCropGenerator(output_size=[args.img_size, args.img_size]),
243 | prompt=args.prompt
244 | )
245 | val_loader = DataLoader(db_val, batch_size=batch_size, shuffle=False, num_workers=8)
246 | logging.info("{} val iterations per epoch".format(len(val_loader)))
247 |
248 | metric_list = 0.0
249 | count_matrix = np.ones((len(db_val), num_classes-1))
250 | for i_batch, sampled_batch in tqdm(enumerate(val_loader)):
251 | image, label = sampled_batch["image"], sampled_batch["label"]
252 | if args.prompt:
253 | position_prompt = torch.tensor(
254 | np.array(sampled_batch['position_prompt'])).permute([1, 0]).float()
255 | task_prompt = torch.tensor(
256 | np.array([[1]*position_prompt.shape[0], [0]*position_prompt.shape[0]])).permute([1, 0]).float()
257 | type_prompt = torch.tensor(np.array(sampled_batch['type_prompt'])).permute([1, 0]).float()
258 | nature_prompt = torch.tensor(np.array(sampled_batch['nature_prompt'])).permute([1, 0]).float()
259 | metric_i = omni_seg_test(image, label, model,
260 | classes=num_classes,
261 | prompt=args.prompt,
262 | type_prompt=type_prompt,
263 | nature_prompt=nature_prompt,
264 | position_prompt=position_prompt,
265 | task_prompt=task_prompt
266 | )
267 | else:
268 | metric_i = omni_seg_test(image, label, model,
269 | classes=num_classes)
270 |
271 | for sample_index in range(len(metric_i)):
272 | if not metric_i[sample_index][1]:
273 | count_matrix[i_batch*batch_size+sample_index, 0] = 0
274 | metric_i = [element[0] for element in metric_i]
275 | metric_list += np.array(metric_i).sum()
276 |
277 | metric_list = metric_list / (count_matrix.sum(axis=0) + 1e-6)
278 | performance = np.mean(metric_list, axis=0)
279 |
280 | writer.add_scalar('info/val_seg_metric_{}'.format(dataset_name), performance, epoch_num)
281 |
282 | seg_avg_performance += performance
283 |
284 | seg_avg_performance = seg_avg_performance / len(seg_val_set)
285 | total_performance += seg_avg_performance
286 | writer.add_scalar('info/val_metric_seg_Total', seg_avg_performance, epoch_num)
287 |
288 | cls_val_set = ["Appendix", "BUS-BRA", "Fatty-Liver", "UDIAT"]
289 | cls_avg_performance = 0.0
290 |
291 | for dataset_name in cls_val_set:
292 | num_classes = 2
293 | db_val = USdatasetCls(
294 | base_dir=os.path.join(args.root_path, "classification", dataset_name),
295 | split="val",
296 | list_dir=os.path.join(args.root_path, "classification", dataset_name),
297 | transform=CenterCropGenerator(output_size=[args.img_size, args.img_size]),
298 | prompt=args.prompt
299 | )
300 |
301 | val_loader = DataLoader(db_val, batch_size=batch_size, shuffle=False, num_workers=8)
302 | logging.info("{} val iterations per epoch".format(len(val_loader)))
303 | model.eval()
304 |
305 | label_list = []
306 | prediction_prob_list = []
307 | for i_batch, sampled_batch in tqdm(enumerate(val_loader)):
308 | image, label = sampled_batch["image"], sampled_batch["label"]
309 | if args.prompt:
310 | position_prompt = torch.tensor(
311 | np.array(sampled_batch['position_prompt'])).permute([1, 0]).float()
312 | task_prompt = torch.tensor(
313 | np.array([[0]*position_prompt.shape[0], [1]*position_prompt.shape[0]])).permute([1, 0]).float()
314 | type_prompt = torch.tensor(np.array(sampled_batch['type_prompt'])).permute([1, 0]).float()
315 | nature_prompt = torch.tensor(np.array(sampled_batch['nature_prompt'])).permute([1, 0]).float()
316 | with torch.no_grad():
317 | output = model((image.cuda(), position_prompt.cuda(), task_prompt.cuda(),
318 | type_prompt.cuda(), nature_prompt.cuda()))[1]
319 | else:
320 | with torch.no_grad():
321 | output = model(image.cuda())[1]
322 | output_prob = torch.softmax(output, dim=1).data.cpu().numpy()
323 |
324 | label_list.append(label.numpy())
325 | prediction_prob_list.append(output_prob)
326 |
327 | label_list = np.expand_dims(np.concatenate(
328 | (np.array(label_list[:-1]).flatten(), np.array(label_list[-1]).flatten())), axis=1).astype('uint8')
329 | label_list_OneHot = np.eye(num_classes)[label_list].squeeze(1)
330 | performance = roc_auc_score(label_list_OneHot, np.concatenate(
331 | (np.array(prediction_prob_list[:-1]).reshape(-1, 2), prediction_prob_list[-1])), multi_class='ovo')
332 |
333 | writer.add_scalar('info/val_cls_metric_{}'.format(dataset_name), performance, epoch_num)
334 |
335 | cls_avg_performance += performance
336 |
337 | cls_avg_performance = cls_avg_performance / len(cls_val_set)
338 | total_performance += cls_avg_performance
339 | writer.add_scalar('info/val_metric_cls_Total', cls_avg_performance, epoch_num)
340 |
341 | TotalAvgPerformance = total_performance/2
342 |
343 | logging.info('This epoch %d Validation performance: %f' % (epoch_num, TotalAvgPerformance))
344 | logging.info('But the best epoch is: %d and performance: %f' % (best_epoch, best_performance))
345 | writer.add_scalar('info/val_metric_TotalMean', TotalAvgPerformance, epoch_num)
346 | if TotalAvgPerformance >= best_performance:
347 | if os.path.exists(os.path.join(snapshot_path, 'best_model_{}_{}.pth'.format(best_epoch, round(best_performance, 4)))):
348 | os.remove(os.path.join(snapshot_path, 'best_model_{}_{}.pth'.format(
349 | best_epoch, round(best_performance, 4))))
350 | os.remove(os.path.join(snapshot_path, 'best_model.pth'))
351 | best_epoch = epoch_num
352 | best_performance = TotalAvgPerformance
353 | logging.info('Validation TotalAvgPerformance in best val model: %f' % (TotalAvgPerformance))
354 | save_model_path = os.path.join(snapshot_path, 'best_model_{}_{}.pth'.format(
355 | epoch_num, round(best_performance, 4)))
356 | os.system('ln -s ' + os.path.abspath(save_model_path) +
357 | ' ' + os.path.join(snapshot_path, 'best_model.pth'))
358 | torch.save(model.state_dict(), save_model_path)
359 | logging.info("save model to {}".format(save_model_path))
360 |
361 | model.train()
362 |
363 | writer.close()
364 | return "Training Finished!"
365 |
--------------------------------------------------------------------------------
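
The per-sample weights fed to `WeightedRandomSamplerDDP` in `omni_trainer.py` are built by repeating each subset's base weight once per sample of that subset and then flattening. A minimal illustration with made-up weights and subset sizes:

```
# hypothetical per-dataset base weights and subset sizes, mirroring the
# sample_weight_seq construction in omni_trainer.py
weight_base = [1/4, 1/2, 2]
subset_len = [2, 3, 1]

sample_weight_seq = [[weight_base[i]] * n for i, n in enumerate(subset_len)]
sample_weight_seq = [w for sublist in sample_weight_seq for w in sublist]

print(sample_weight_seq)  # [0.25, 0.25, 0.5, 0.5, 0.5, 2]
```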
/pretrained_ckpt/.gitkeeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zehui-Lin/UniUSNet/c61a55a3cd967188dd9eaf0cb4f3d29e924c999e/pretrained_ckpt/.gitkeeep
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops
2 | MedPy
3 | numpy
4 | opencv_python
5 | PyYAML
6 | scikit_learn
7 | scipy
8 | timm
9 | torch
10 | torchvision
11 | tqdm
12 | yacs
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from medpy import metric
4 | import torch.nn as nn
5 | import cv2
6 |
7 |
8 | class DiceLoss(nn.Module):
9 | def __init__(self, n_classes=2):
10 | super(DiceLoss, self).__init__()
11 | self.n_classes = n_classes
12 |
13 | def _one_hot_encoder(self, input_tensor):
14 | tensor_list = []
15 | for i in range(self.n_classes):
16 | temp_prob = input_tensor == i
17 | tensor_list.append(temp_prob.unsqueeze(1))
18 | output_tensor = torch.cat(tensor_list, dim=1)
19 | return output_tensor.float()
20 |
21 | def _dice_loss(self, score, target):
22 | target = target.float()
23 | smooth = 1e-5
24 | intersect = torch.sum(score * target)
25 | y_sum = torch.sum(target * target)
26 | z_sum = torch.sum(score * score)
27 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
28 | loss = 1 - loss
29 | return loss
30 |
31 | def forward(self, inputs, target, weight=None, softmax=False):
32 | if softmax:
33 | inputs = torch.softmax(inputs, dim=1)
34 | target = self._one_hot_encoder(target)
35 | if weight is None:
36 | weight = [1] * self.n_classes
37 | assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())
38 | class_wise_dice = []
39 | loss = 0.0
40 | for i in range(0, self.n_classes):
41 | dice = self._dice_loss(inputs[:, i], target[:, i])
42 | class_wise_dice.append(1.0 - dice.item())
43 | loss += dice * weight[i]
44 | return loss / self.n_classes
45 |
46 |
47 | def calculate_metric_percase(pred, gt):
48 | pred[pred > 0] = 1
49 | gt[gt > 0] = 1
50 | if pred.sum() > 0 and gt.sum() > 0:
51 | dice = metric.binary.dc(pred, gt)
52 | return dice, True
53 | elif pred.sum() > 0 and gt.sum() == 0:
54 | return 0, False
55 | elif pred.sum() == 0 and gt.sum() > 0:
56 | return 0, True
57 | else:
58 | return 0, False
59 |
60 |
61 | def omni_seg_test(image, label, net, classes, ClassStartIndex=1, test_save_path=None, case=None,
62 | prompt=False,
63 | type_prompt=None,
64 | nature_prompt=None,
65 | position_prompt=None,
66 | task_prompt=None
67 | ):
68 | label = label.squeeze(0).cpu().detach().numpy()
69 | image_save = image.squeeze(0).cpu().detach().numpy()
70 | input = image.cuda()
71 | if prompt:
72 | position_prompt = position_prompt.cuda()
73 | task_prompt = task_prompt.cuda()
74 | type_prompt = type_prompt.cuda()
75 | nature_prompt = nature_prompt.cuda()
76 | net.eval()
77 | with torch.no_grad():
78 | if prompt:
79 | seg_out = net((input, position_prompt, task_prompt, type_prompt, nature_prompt))[0]
80 | else:
81 | seg_out = net(input)[0]
82 | out_label_back_transform = torch.cat(
83 | [seg_out[:, 0:1], seg_out[:, ClassStartIndex:ClassStartIndex+classes-1]], axis=1)
84 | out = torch.argmax(torch.softmax(out_label_back_transform, dim=1), dim=1).squeeze(0)
85 | prediction = out.cpu().detach().numpy()
86 |
87 | metric_list = []
88 |     for i in range(1, classes):  # the meaning of the second dimension is different here: it is the number of classes
89 | metric_list.append(calculate_metric_percase(prediction == i, label == i))
90 |
91 | if test_save_path is not None:
92 | image = (image_save - np.min(image_save)) / (np.max(image_save) - np.min(image_save))
93 | cv2.imwrite(test_save_path + '/'+case + "_pred.png", (prediction*255).astype(np.uint8))
94 | cv2.imwrite(test_save_path + '/'+case + "_img.png", ((image.squeeze(0))*255).astype(np.uint8))
95 | cv2.imwrite(test_save_path + '/'+case + "_gt.png", (label*255).astype(np.uint8))
96 | return metric_list
97 |
--------------------------------------------------------------------------------
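
A short usage sketch for `DiceLoss` above, with random tensors standing in for network output and ground-truth masks:

```
import torch
from utils import DiceLoss

logits = torch.randn(2, 2, 224, 224)           # raw segmentation logits: (B, n_classes, H, W)
target = torch.randint(0, 2, (2, 224, 224))    # integer mask with values in {0, 1}

criterion = DiceLoss(n_classes=2)
loss = criterion(logits, target, softmax=True)  # softmax=True matches how omni_trainer.py calls it
print(loss.item())
```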