├── LICENSE
├── README.md
├── __pycache__
│   ├── config.cpython-38.pyc
│   ├── config.cpython-39.pyc
│   ├── trainer.cpython-38.pyc
│   ├── trainer.cpython-39.pyc
│   ├── utils.cpython-38.pyc
│   └── utils.cpython-39.pyc
├── config.py
├── configs
│   └── swin_tiny_patch4_window7_224_lite.yaml
├── datasets
│   ├── README.md
│   ├── Synapse
│   │   ├── test_vol_h5
│   │   │   └── 1.npz
│   │   ├── tonpz.py
│   │   └── train_npz
│   │       └── 1.npz
│   ├── __pycache__
│   │   ├── dataset_synapse.cpython-38.pyc
│   │   └── dataset_synapse.cpython-39.pyc
│   └── dataset_synapse.py
├── img
│   ├── Comparison.png
│   └── SCUnet++.png
├── lists
│   └── lists_Synapse
│       └── tool.py
├── networks
│   ├── __pycache__
│   │   ├── swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc
│   │   ├── swin_transformer_unet_skip_expand_decoder_sys.cpython-39.pyc
│   │   ├── vision_transformer.cpython-38.pyc
│   │   └── vision_transformer.cpython-39.pyc
│   ├── swin_transformer_unet_skip_expand_decoder_sys.py
│   └── vision_transformer.py
├── requirements.txt
├── test.py
├── train.py
├── trainer.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 JustlfC
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SCUNet++
2 | 
3 | We propose an automatic pulmonary embolism (PE) segmentation method called SCUNet++ (Swin Conv UNet++). The method incorporates multiple fusion dense skip connections between the encoder and decoder and uses the Swin Transformer as the encoder.
4 | ![image](img/SCUnet++.png)
5 | 
6 | Comparison of segmentation performance of different network models on the CAD-PE dataset: (a) input image; (b) ground truth mask; (c) the proposed SCUNet++ model; (d) UNet++ model; (e) UNet model; (f) Swin-UNet model; and (g) ResD-UNet model.
7 | ![Comparison](img/Comparison.png)
8 | 
9 | ## 1. Download the pre-trained Swin Transformer model (Swin-T)
10 | 
11 | * [Get the pre-trained model from this link](https://drive.google.com/drive/folders/1UC3XOoezeum0uck4KBVGa8osahs6rKUY?usp=sharing) and put the pretrained Swin-T checkpoint into the folder "pretrained_ckpt/".
12 | 
13 | ## 2. Environment
14 | 
15 | - Please prepare an environment with Python 3.7, then run "pip install -r requirements.txt" to install the dependencies.
16 | 
17 | ## 3. Train/Test
18 | 
19 | - Run lists/lists_Synapse/tool.py to generate the train/test list (.txt) files.
20 | - Run train.py to train (put the training set, in npz format, into datasets/Synapse/train_npz).
21 | - Run test.py to test (put the test set, in npz format, into datasets/test).
22 | - The batch size we used is 24. If you do not have enough GPU memory, the batch size can be reduced to 12 or 6 to save memory.
23 | 
24 | ## 4. New FUMPE dataset
25 | 
26 | Upon review, we found significant errors and deviations in the original dataset annotations, so we re-annotated these datasets to ensure accuracy.
27 | We have uploaded a download link for the new FUMPE dataset to our GitHub repository so that other users can access and use it.
28 | If you want to use our relabelled dataset, please cite our article.
29 | 
30 | The download link is available at [https://drive.google.com/file/d/1hOmQ9s8eE__nqIe3lpwGYoydR4_UNRrU/view?usp=drive_link](https://drive.google.com/file/d/1hOmQ9s8eE__nqIe3lpwGYoydR4_UNRrU/view?usp=drive_link).
31 | 
32 | ## 5. Citation
33 | 
34 | ```
35 | @InProceedings{Chen_2024_WACV,
36 | author = {Chen, Yifei and Zou, Binfeng and Guo, Zhaoxin and Huang, Yiyu and Huang, Yifan and Qin, Feiwei and Li, Qinhai and Wang, Changmiao},
37 | title = {SCUNet++: Swin-UNet and CNN Bottleneck Hybrid Architecture With Multi-Fusion Dense Skip Connection for Pulmonary Embolism CT Image Segmentation},
38 | booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
39 | month = {January},
40 | year = {2024},
41 | pages = {7759-7767}
42 | }
43 | ```
44 | 
--------------------------------------------------------------------------------
/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JustlfC03/SCUNet-plusplus/c965058fec41f6738def31da7a304c14069d0848/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/config.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JustlfC03/SCUNet-plusplus/c965058fec41f6738def31da7a304c14069d0848/__pycache__/config.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/trainer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JustlfC03/SCUNet-plusplus/c965058fec41f6738def31da7a304c14069d0848/__pycache__/trainer.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/trainer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JustlfC03/SCUNet-plusplus/c965058fec41f6738def31da7a304c14069d0848/__pycache__/trainer.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JustlfC03/SCUNet-plusplus/c965058fec41f6738def31da7a304c14069d0848/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JustlfC03/SCUNet-plusplus/c965058fec41f6738def31da7a304c14069d0848/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import yaml 4 | from yacs.config import CfgNode as CN 5 | 6 | _C = CN() 7 | 8 | # Base config files 9 | _C.BASE = [''] 10 | 11 | # ----------------------------------------------------------------------------- 12 | # Data settings 13 | # ----------------------------------------------------------------------------- 14 | _C.DATA = CN() 15 | # Batch size for a single GPU, could be overwritten by command line argument 16 | _C.DATA.BATCH_SIZE = 128 17 | # Path to dataset, could be overwritten by command line argument 18 | _C.DATA.DATA_PATH = '' 19 | # Dataset name 20 | _C.DATA.DATASET = 'imagenet' 21 | # Input image size 22 | _C.DATA.IMG_SIZE = 224 23 | # Interpolation to resize image (random, bilinear, bicubic) 24 | _C.DATA.INTERPOLATION = 'bicubic' 25 | # Use zipped dataset instead of folder dataset 26 | # could be overwritten by command line argument 27 | _C.DATA.ZIP_MODE = False 28 | # Cache Data in Memory, could be overwritten by command line argument 29 | _C.DATA.CACHE_MODE = 'part' 30 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 31 | _C.DATA.PIN_MEMORY = True 32 | # Number of data loading threads 33 | _C.DATA.NUM_WORKERS = 8 34 | 35 | # ----------------------------------------------------------------------------- 36 | # Model settings 37 | # ----------------------------------------------------------------------------- 38 | _C.MODEL = CN() 39 | # Model type 40 | _C.MODEL.TYPE = 'swin' 41 | # Model name 42 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224' 43 | # Checkpoint to resume, could be overwritten by command line argument 44 | _C.MODEL.PRETRAIN_CKPT = './pretrained_ckpt/swin_tiny_patch4_window7_224.pth' 45 | _C.MODEL.RESUME = '' 46 | # Number of classes, overwritten in data preparation 47 | _C.MODEL.NUM_CLASSES = 1000 48 | # Dropout rate 49 | _C.MODEL.DROP_RATE = 0.0 50 | # Drop path rate 51 | _C.MODEL.DROP_PATH_RATE = 0.1 52 | # Label Smoothing 53 | _C.MODEL.LABEL_SMOOTHING = 0.1 54 | 55 | # Swin Transformer parameters 56 | _C.MODEL.SWIN = CN() 57 | _C.MODEL.SWIN.PATCH_SIZE = 4 58 | _C.MODEL.SWIN.IN_CHANS = 3 59 | _C.MODEL.SWIN.EMBED_DIM = 96 60 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 61 | _C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2] 62 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 63 | _C.MODEL.SWIN.WINDOW_SIZE = 7 64 | _C.MODEL.SWIN.MLP_RATIO = 4. 
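# MLP_RATIO sets each Swin block's MLP hidden width relative to the embedding dimension;
# QKV_BIAS below adds learnable biases to the attention query/key/value projection.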
65 | _C.MODEL.SWIN.QKV_BIAS = True 66 | _C.MODEL.SWIN.QK_SCALE = None 67 | _C.MODEL.SWIN.APE = False 68 | _C.MODEL.SWIN.PATCH_NORM = True 69 | _C.MODEL.SWIN.FINAL_UPSAMPLE = "expand_first" 70 | 71 | # ----------------------------------------------------------------------------- 72 | # Training settings 73 | # ----------------------------------------------------------------------------- 74 | _C.TRAIN = CN() 75 | _C.TRAIN.START_EPOCH = 0 76 | _C.TRAIN.EPOCHS = 300 77 | _C.TRAIN.WARMUP_EPOCHS = 20 78 | _C.TRAIN.WEIGHT_DECAY = 0.05 79 | _C.TRAIN.BASE_LR = 5e-4 80 | _C.TRAIN.WARMUP_LR = 5e-7 81 | _C.TRAIN.MIN_LR = 5e-6 82 | # Clip gradient norm 83 | _C.TRAIN.CLIP_GRAD = 5.0 84 | # Auto resume from latest checkpoint 85 | _C.TRAIN.AUTO_RESUME = True 86 | # Gradient accumulation steps 87 | # could be overwritten by command line argument 88 | _C.TRAIN.ACCUMULATION_STEPS = 0 89 | # Whether to use gradient checkpointing to save memory 90 | # could be overwritten by command line argument 91 | _C.TRAIN.USE_CHECKPOINT = False 92 | 93 | # LR scheduler 94 | _C.TRAIN.LR_SCHEDULER = CN() 95 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' 96 | # Epoch interval to decay LR, used in StepLRScheduler 97 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 98 | # LR decay rate, used in StepLRScheduler 99 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 100 | 101 | # Optimizer 102 | _C.TRAIN.OPTIMIZER = CN() 103 | _C.TRAIN.OPTIMIZER.NAME = 'adamw' 104 | # Optimizer Epsilon 105 | _C.TRAIN.OPTIMIZER.EPS = 1e-8 106 | # Optimizer Betas 107 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) 108 | # SGD momentum 109 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 110 | 111 | # ----------------------------------------------------------------------------- 112 | # Augmentation settings 113 | # ----------------------------------------------------------------------------- 114 | _C.AUG = CN() 115 | # Color jitter factor 116 | _C.AUG.COLOR_JITTER = 0.4 117 | # Use AutoAugment policy. "v0" or "original" 118 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' 119 | # Random erase prob 120 | _C.AUG.REPROB = 0.25 121 | # Random erase mode 122 | _C.AUG.REMODE = 'pixel' 123 | # Random erase count 124 | _C.AUG.RECOUNT = 1 125 | # Mixup alpha, mixup enabled if > 0 126 | _C.AUG.MIXUP = 0.8 127 | # Cutmix alpha, cutmix enabled if > 0 128 | _C.AUG.CUTMIX = 1.0 129 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set 130 | _C.AUG.CUTMIX_MINMAX = None 131 | # Probability of performing mixup or cutmix when either/both is enabled 132 | _C.AUG.MIXUP_PROB = 1.0 133 | # Probability of switching to cutmix when both mixup and cutmix enabled 134 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 135 | # How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem" 136 | _C.AUG.MIXUP_MODE = 'batch' 137 | 138 | # ----------------------------------------------------------------------------- 139 | # Testing settings 140 | # ----------------------------------------------------------------------------- 141 | _C.TEST = CN() 142 | # Whether to use center crop when testing 143 | _C.TEST.CROP = True 144 | 145 | # ----------------------------------------------------------------------------- 146 | # Misc 147 | # ----------------------------------------------------------------------------- 148 | # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2') 149 | # overwritten by command line argument 150 | _C.AMP_OPT_LEVEL = '' 151 | # Path to output folder, overwritten by command line argument 152 | _C.OUTPUT = '' 153 | # Tag of experiment, overwritten by command line argument 154 | _C.TAG = 'default' 155 | # Frequency to save checkpoint 156 | _C.SAVE_FREQ = 1 157 | # Frequency to logging info 158 | _C.PRINT_FREQ = 10 159 | # Fixed random seed 160 | _C.SEED = 0 161 | # Perform evaluation only, overwritten by command line argument 162 | _C.EVAL_MODE = False 163 | # Test throughput only, overwritten by command line argument 164 | _C.THROUGHPUT_MODE = False 165 | # local rank for DistributedDataParallel, given by command line argument 166 | _C.LOCAL_RANK = 0 167 | 168 | 169 | def _update_config_from_file(config, cfg_file): 170 | config.defrost() 171 | with open(cfg_file, 'r') as f: 172 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 173 | 174 | for cfg in yaml_cfg.setdefault('BASE', ['']): 175 | if cfg: 176 | _update_config_from_file( 177 | config, os.path.join(os.path.dirname(cfg_file), cfg) 178 | ) 179 | print('=> merge config from {}'.format(cfg_file)) 180 | config.merge_from_file(cfg_file) 181 | config.freeze() 182 | 183 | 184 | def update_config(config, args): 185 | _update_config_from_file(config, args.cfg) 186 | 187 | config.defrost() 188 | if args.opts: 189 | config.merge_from_list(args.opts) 190 | 191 | # merge from specific arguments 192 | if args.batch_size: 193 | config.DATA.BATCH_SIZE = args.batch_size 194 | if args.zip: 195 | config.DATA.ZIP_MODE = True 196 | if args.cache_mode: 197 | config.DATA.CACHE_MODE = args.cache_mode 198 | if args.resume: 199 | config.MODEL.RESUME = args.resume 200 | if args.accumulation_steps: 201 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps 202 | if args.use_checkpoint: 203 | config.TRAIN.USE_CHECKPOINT = True 204 | if args.amp_opt_level: 205 | config.AMP_OPT_LEVEL = args.amp_opt_level 206 | if args.tag: 207 | config.TAG = args.tag 208 | if args.eval: 209 | config.EVAL_MODE = True 210 | if args.throughput: 211 | config.THROUGHPUT_MODE = True 212 | 213 | config.freeze() 214 | 215 | 216 | def get_config(args): 217 | """Get a yacs CfgNode object with default values.""" 218 | # Return a clone so that the defaults will not be altered 219 | # This is for the "local variable" use pattern 220 | config = _C.clone() 221 | update_config(config, args) 222 | 223 | return config 224 | -------------------------------------------------------------------------------- /configs/swin_tiny_patch4_window7_224_lite.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_tiny_patch4_window7_224 4 | DROP_PATH_RATE: 0.2 5 | PRETRAIN_CKPT: "./pretrained_ckpt/swin_tiny_patch4_window7_224.pth" 6 | SWIN: 7 | FINAL_UPSAMPLE: "expand_first" 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 2, 2 ] 10 | DECODER_DEPTHS: [ 2, 2, 2, 1 ] 11 | 
NUM_HEADS: [ 3, 6, 12, 24 ]
12 | WINDOW_SIZE: 7
--------------------------------------------------------------------------------
/datasets/README.md:
--------------------------------------------------------------------------------
1 | # Data Preparation
2 | We conducted experiments on the publicly available FUMPE and CAD-PE datasets. We only selected fixed-angle PE images and the corresponding labels for each case. The dataset code "dataset_synapse.py" we use is provided by the authors of [TransUNet](https://github.com/Beckschen/TransUNet).
3 | 
--------------------------------------------------------------------------------
/datasets/Synapse/test_vol_h5/1.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JustlfC03/SCUNet-plusplus/c965058fec41f6738def31da7a304c14069d0848/datasets/Synapse/test_vol_h5/1.npz
--------------------------------------------------------------------------------
/datasets/Synapse/tonpz.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import glob
4 | import os
5 | 
6 | 
7 | def npz():
8 |     # Path to the original images
9 |     # path = r'D:\DEMO\SCUNet++\datasets\Synapse\train\images\*.png'
10 |     path = r'G:\FINAL\SCUNet++\datasets\test\images\*.png'
11 |     # Path where the npz files used for training (testing) are stored
12 |     # path2 = r'D:\Research_Topic\Swin-Unet-main\datasets\Synapse\train_npz\\'
13 |     # path2 = r'D:\DEMO\SCUNet++\datasets\Synapse\train_npz'
14 |     path2 = r'G:\FINAL\SCUNet++\datasets/Synapse/test_vol_h5'
15 |     for i, img_path in enumerate(glob.glob(path)):
16 |         # Read the image
17 |         image = cv2.imread(img_path)
18 |         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
19 |         # Read the label
20 |         label_path = img_path.replace('images', 'labels')
21 |         label = cv2.imread(label_path, flags=0)
22 |         # Set non-target pixels to 0
23 |         label[label == 0] = 0
24 |         # Set target pixels to 1
25 |         label[label != 0] = 1
26 |         # Save the npz file
27 |         # print(os.path.join(path2, str(i + 1)))
28 |         np.savez(os.path.join(path2, str(i + 1)), image=image, label=label)
29 |         print('finished:', i + 1)
30 | 
31 |     print('Finished')
32 | 
33 | 
34 | npz()
35 | 
--------------------------------------------------------------------------------
/datasets/Synapse/train_npz/1.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JustlfC03/SCUNet-plusplus/c965058fec41f6738def31da7a304c14069d0848/datasets/Synapse/train_npz/1.npz
--------------------------------------------------------------------------------
/datasets/__pycache__/dataset_synapse.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JustlfC03/SCUNet-plusplus/c965058fec41f6738def31da7a304c14069d0848/datasets/__pycache__/dataset_synapse.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/dataset_synapse.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JustlfC03/SCUNet-plusplus/c965058fec41f6738def31da7a304c14069d0848/datasets/__pycache__/dataset_synapse.cpython-39.pyc
--------------------------------------------------------------------------------
/datasets/dataset_synapse.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import h5py
4 | import numpy as np
5 | import torch
6 | from scipy import ndimage
7 | from scipy.ndimage.interpolation import zoom
8 | 
from torch.utils.data import Dataset 9 | 10 | 11 | def random_rot_flip(image, label): 12 | k = np.random.randint(0, 4) 13 | image = np.rot90(image, k) 14 | label = np.rot90(label, k) 15 | axis = np.random.randint(0, 2) 16 | image = np.flip(image, axis=axis).copy() 17 | label = np.flip(label, axis=axis).copy() 18 | return image, label 19 | 20 | 21 | def random_rotate(image, label): 22 | angle = np.random.randint(-20, 20) 23 | image = ndimage.rotate(image, angle, order=0, reshape=False) 24 | label = ndimage.rotate(label, angle, order=0, reshape=False) 25 | return image, label 26 | 27 | 28 | class RandomGenerator(object): 29 | def __init__(self, output_size): 30 | self.output_size = output_size 31 | 32 | def __call__(self, sample): 33 | image, label = sample['image'], sample['label'] 34 | 35 | if random.random() > 0.5: 36 | image, label = random_rot_flip(image, label) 37 | elif random.random() > 0.5: 38 | image, label = random_rotate(image, label) 39 | # 修改 40 | # x, y = image.shape 41 | x, y, _ = image.shape 42 | if x != self.output_size[0] or y != self.output_size[1]: 43 | # 修改 44 | # image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=3) 45 | image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y, 1), order=3) 46 | label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0) 47 | # 修改 48 | # image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0) 49 | image = torch.from_numpy(image.astype(np.float32)) 50 | image = image.permute(2, 0, 1) 51 | 52 | label = torch.from_numpy(label.astype(np.float32)) 53 | sample = {'image': image, 'label': label.long()} 54 | return sample 55 | 56 | 57 | class Synapse_dataset(Dataset): 58 | def __init__(self, base_dir, list_dir, split, transform=None): 59 | self.transform = transform 60 | self.split = split 61 | self.sample_list = open(os.path.join(list_dir, self.split + '.txt')).readlines() 62 | self.data_dir = base_dir 63 | 64 | def __len__(self): 65 | return len(self.sample_list) 66 | 67 | def __getitem__(self, idx): 68 | if self.split == "train": 69 | slice_name = self.sample_list[idx].strip('\n') 70 | data_path = os.path.join(self.data_dir, slice_name + '.npz') 71 | data = np.load(data_path) 72 | image, label = data['image'], data['label'] 73 | else: 74 | slice_name = self.sample_list[idx].strip('\n') 75 | data_path = os.path.join(self.data_dir, slice_name + '.npz') 76 | data = np.load(data_path) 77 | image, label = data['image'], data['label'] 78 | # 修改 79 | # vol_name = self.sample_list[idx].strip('\n') 80 | # filepath = self.data_dir + "/{}.npy.h5".format(vol_name) 81 | # data = h5py.File(filepath) 82 | # image, label = data['image'][:], data['label'][:] 83 | image = torch.from_numpy(image.astype(np.float32)) 84 | image = image.permute(2, 0, 1) 85 | label = torch.from_numpy(label.astype(np.float32)) 86 | 87 | sample = {'image': image, 'label': label} 88 | if self.transform: 89 | sample = self.transform(sample) 90 | sample['case_name'] = self.sample_list[idx].strip('\n') 91 | return sample 92 | -------------------------------------------------------------------------------- /img/Comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JustlfC03/SCUNet-plusplus/c965058fec41f6738def31da7a304c14069d0848/img/Comparison.png -------------------------------------------------------------------------------- /img/SCUnet++.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/JustlfC03/SCUNet-plusplus/c965058fec41f6738def31da7a304c14069d0848/img/SCUnet++.png -------------------------------------------------------------------------------- /lists/lists_Synapse/tool.py: -------------------------------------------------------------------------------- 1 | import glob 2 | 3 | 4 | def write_name(): 5 | # npz文件路径 6 | # files = glob.glob(r'D:\DEMO\SCUNet++\datasets\Synapse\train_npz\*.npz') 7 | files = glob.glob(r'G:\FINAL\SCUNet++\datasets\Synapse\test_vol_h5\*.npz') 8 | # txt文件路径 9 | # f = open(r'D:\DEMO\SCUNet++\lists\lists_Synapse\train.txt', 'w') 10 | f = open(r'G:\FINAL\SCUNet++\lists\lists_Synapse\test.txt', 'w') 11 | for i in files: 12 | name = i.split('\\')[-1] 13 | name = name[:-4] + '\n' 14 | f.write(name) 15 | 16 | print("Finished!") 17 | 18 | 19 | write_name() 20 | -------------------------------------------------------------------------------- /networks/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JustlfC03/SCUNet-plusplus/c965058fec41f6738def31da7a304c14069d0848/networks/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc -------------------------------------------------------------------------------- /networks/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JustlfC03/SCUNet-plusplus/c965058fec41f6738def31da7a304c14069d0848/networks/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-39.pyc -------------------------------------------------------------------------------- /networks/__pycache__/vision_transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JustlfC03/SCUNet-plusplus/c965058fec41f6738def31da7a304c14069d0848/networks/__pycache__/vision_transformer.cpython-38.pyc -------------------------------------------------------------------------------- /networks/__pycache__/vision_transformer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JustlfC03/SCUNet-plusplus/c965058fec41f6738def31da7a304c14069d0848/networks/__pycache__/vision_transformer.cpython-39.pyc -------------------------------------------------------------------------------- /networks/swin_transformer_unet_skip_expand_decoder_sys.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.utils.checkpoint as checkpoint 6 | from einops import rearrange 7 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 8 | 9 | 10 | class Mlp(nn.Module): 11 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 12 | super().__init__() 13 | out_features = out_features or in_features 14 | hidden_features = hidden_features or in_features 15 | self.fc1 = nn.Linear(in_features, hidden_features) 16 | self.act = act_layer() 17 | self.fc2 = nn.Linear(hidden_features, out_features) 18 | self.drop = nn.Dropout(drop) 19 | 20 | def forward(self, x): 21 | x = self.fc1(x) 22 | x = self.act(x) 23 | x = self.drop(x) 24 | x = self.fc2(x) 25 | x = self.drop(x) 26 | return x 27 | 28 | 29 | def window_partition(x, window_size): 30 | """ 31 | 
Args: 32 | x: (B, H, W, C) 33 | window_size (int): window size 34 | 35 | Returns: 36 | windows: (num_windows*B, window_size, window_size, C) 37 | """ 38 | B, H, W, C = x.shape 39 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 40 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 41 | return windows 42 | 43 | 44 | def window_reverse(windows, window_size, H, W): 45 | """ 46 | Args: 47 | windows: (num_windows*B, window_size, window_size, C) 48 | window_size (int): Window size 49 | H (int): Height of image 50 | W (int): Width of image 51 | 52 | Returns: 53 | x: (B, H, W, C) 54 | """ 55 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 56 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 57 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 58 | return x 59 | 60 | 61 | class WindowAttention(nn.Module): 62 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 63 | It supports both of shifted and non-shifted window. 64 | 65 | Args: 66 | dim (int): Number of input channels. 67 | window_size (tuple[int]): The height and width of the window. 68 | num_heads (int): Number of attention heads. 69 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 70 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 71 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 72 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 73 | """ 74 | 75 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 76 | 77 | super().__init__() 78 | self.dim = dim 79 | self.window_size = window_size # Wh, Ww 80 | self.num_heads = num_heads 81 | head_dim = dim // num_heads 82 | self.scale = qk_scale or head_dim ** -0.5 83 | 84 | # define a parameter table of relative position bias 85 | self.relative_position_bias_table = nn.Parameter( 86 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 87 | 88 | # get pair-wise relative position index for each token inside the window 89 | coords_h = torch.arange(self.window_size[0]) 90 | coords_w = torch.arange(self.window_size[1]) 91 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 92 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 93 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 94 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 95 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 96 | relative_coords[:, :, 1] += self.window_size[1] - 1 97 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 98 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 99 | self.register_buffer("relative_position_index", relative_position_index) 100 | 101 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 102 | self.attn_drop = nn.Dropout(attn_drop) 103 | self.proj = nn.Linear(dim, dim) 104 | self.proj_drop = nn.Dropout(proj_drop) 105 | 106 | trunc_normal_(self.relative_position_bias_table, std=.02) 107 | self.softmax = nn.Softmax(dim=-1) 108 | 109 | def forward(self, x, mask=None): 110 | """ 111 | Args: 112 | x: input features with shape of (num_windows*B, N, C) 113 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 114 | """ 115 | B_, N, C = 
x.shape 116 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 117 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 118 | 119 | q = q * self.scale 120 | attn = (q @ k.transpose(-2, -1)) 121 | 122 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 123 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 124 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 125 | attn = attn + relative_position_bias.unsqueeze(0) 126 | 127 | if mask is not None: 128 | nW = mask.shape[0] 129 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 130 | attn = attn.view(-1, self.num_heads, N, N) 131 | attn = self.softmax(attn) 132 | else: 133 | attn = self.softmax(attn) 134 | 135 | attn = self.attn_drop(attn) 136 | 137 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 138 | x = self.proj(x) 139 | x = self.proj_drop(x) 140 | return x 141 | 142 | def extra_repr(self) -> str: 143 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 144 | 145 | def flops(self, N): 146 | # calculate flops for 1 window with token length of N 147 | flops = 0 148 | # qkv = self.qkv(x) 149 | flops += N * self.dim * 3 * self.dim 150 | # attn = (q @ k.transpose(-2, -1)) 151 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 152 | # x = (attn @ v) 153 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 154 | # x = self.proj(x) 155 | flops += N * self.dim * self.dim 156 | return flops 157 | 158 | 159 | class SwinTransformerBlock(nn.Module): 160 | r""" Swin Transformer Block. 161 | 162 | Args: 163 | dim (int): Number of input channels. 164 | input_resolution (tuple[int]): Input resulotion. 165 | num_heads (int): Number of attention heads. 166 | window_size (int): Window size. 167 | shift_size (int): Shift size for SW-MSA. 168 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 169 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 170 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 171 | drop (float, optional): Dropout rate. Default: 0.0 172 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 173 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 174 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 175 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 176 | """ 177 | 178 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 179 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 180 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 181 | super().__init__() 182 | self.dim = dim 183 | self.input_resolution = input_resolution 184 | self.num_heads = num_heads 185 | self.window_size = window_size 186 | self.shift_size = shift_size 187 | self.mlp_ratio = mlp_ratio 188 | if min(self.input_resolution) <= self.window_size: 189 | # if window size is larger than input resolution, we don't partition windows 190 | self.shift_size = 0 191 | self.window_size = min(self.input_resolution) 192 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 193 | 194 | self.norm1 = norm_layer(dim) 195 | self.attn = WindowAttention( 196 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 197 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 198 | 199 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 200 | self.norm2 = norm_layer(dim) 201 | mlp_hidden_dim = int(dim * mlp_ratio) 202 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 203 | 204 | if self.shift_size > 0: 205 | # calculate attention mask for SW-MSA 206 | H, W = self.input_resolution 207 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 208 | h_slices = (slice(0, -self.window_size), 209 | slice(-self.window_size, -self.shift_size), 210 | slice(-self.shift_size, None)) 211 | w_slices = (slice(0, -self.window_size), 212 | slice(-self.window_size, -self.shift_size), 213 | slice(-self.shift_size, None)) 214 | cnt = 0 215 | for h in h_slices: 216 | for w in w_slices: 217 | img_mask[:, h, w, :] = cnt 218 | cnt += 1 219 | 220 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 221 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 222 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 223 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 224 | else: 225 | attn_mask = None 226 | 227 | self.register_buffer("attn_mask", attn_mask) 228 | 229 | def forward(self, x): 230 | H, W = self.input_resolution 231 | B, L, C = x.shape 232 | assert L == H * W, "input feature has wrong size" 233 | 234 | shortcut = x 235 | x = self.norm1(x) 236 | x = x.view(B, H, W, C) 237 | 238 | # cyclic shift 239 | if self.shift_size > 0: 240 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 241 | else: 242 | shifted_x = x 243 | 244 | # partition windows 245 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 246 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 247 | 248 | # W-MSA/SW-MSA 249 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 250 | 251 | # merge windows 252 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 253 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 254 | 255 | # reverse cyclic shift 256 | if self.shift_size > 0: 257 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 258 | else: 259 | x = shifted_x 260 | x = x.view(B, H * W, C) 261 | 262 | # FFN 263 | x = shortcut + self.drop_path(x) 264 | 
x = x + self.drop_path(self.mlp(self.norm2(x))) 265 | 266 | return x 267 | 268 | def extra_repr(self) -> str: 269 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 270 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 271 | 272 | def flops(self): 273 | flops = 0 274 | H, W = self.input_resolution 275 | # norm1 276 | flops += self.dim * H * W 277 | # W-MSA/SW-MSA 278 | nW = H * W / self.window_size / self.window_size 279 | flops += nW * self.attn.flops(self.window_size * self.window_size) 280 | # mlp 281 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 282 | # norm2 283 | flops += self.dim * H * W 284 | return flops 285 | 286 | 287 | class PatchMerging(nn.Module): 288 | r""" Patch Merging Layer. 289 | 290 | Args: 291 | input_resolution (tuple[int]): Resolution of input feature. 292 | dim (int): Number of input channels. 293 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 294 | """ 295 | 296 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 297 | super().__init__() 298 | self.input_resolution = input_resolution 299 | self.dim = dim 300 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 301 | self.norm = norm_layer(4 * dim) 302 | 303 | def forward(self, x): 304 | """ 305 | x: B, H*W, C 306 | """ 307 | H, W = self.input_resolution 308 | B, L, C = x.shape 309 | assert L == H * W, "input feature has wrong size" 310 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 311 | 312 | x = x.view(B, H, W, C) 313 | 314 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 315 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 316 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 317 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 318 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 319 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 320 | 321 | x = self.norm(x) 322 | x = self.reduction(x) 323 | 324 | return x 325 | 326 | def extra_repr(self) -> str: 327 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 328 | 329 | def flops(self): 330 | H, W = self.input_resolution 331 | flops = H * W * self.dim 332 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 333 | return flops 334 | 335 | 336 | class PatchExpand(nn.Module): 337 | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): 338 | super().__init__() 339 | self.input_resolution = input_resolution 340 | self.dim = dim 341 | self.expand = nn.Linear(dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity() 342 | self.norm = norm_layer(dim // dim_scale) 343 | 344 | def forward(self, x): 345 | """ 346 | x: B, H*W, C 347 | """ 348 | # print(x.shape) 349 | H, W = self.input_resolution 350 | x = self.expand(x) 351 | # print(x.shape) 352 | B, L, C = x.shape 353 | assert L == H * W, "input feature has wrong size" 354 | 355 | x = x.view(B, H, W, C) 356 | # print(x.shape) 357 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4) 358 | # print(x.shape) 359 | x = x.view(B, -1, C // 4) 360 | # print(x.shape) 361 | x = self.norm(x) 362 | # print(x.shape) 363 | 364 | return x 365 | 366 | 367 | class my_PatchExpand(nn.Module): 368 | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): 369 | super().__init__() 370 | self.input_resolution = input_resolution 371 | self.dim = dim 372 | self.expand = nn.Linear(dim, 4 * dim, bias=False) if dim_scale == 2 else nn.Identity() 373 | self.norm = norm_layer(dim // dim_scale) 374 | 
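    # Note: my_PatchExpand widens the channels 4x with the linear layer before the 2x2
    # rearrange, so forward doubles H and W while keeping C equal to `dim`; its only
    # call sites in forward_features further below are commented out.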
375 | def forward(self, x): 376 | """ 377 | x: B, H*W, C 378 | """ 379 | H, W = self.input_resolution 380 | print(x.shape) 381 | x = self.expand(x) 382 | print(x.shape) 383 | B, L, C = x.shape 384 | assert L == H * W, "input feature has wrong size" 385 | 386 | x = x.view(B, H, W, C) 387 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4) 388 | x = x.view(B, -1, C // 4) 389 | x = self.norm(x) 390 | 391 | return x 392 | 393 | 394 | class FinalPatchExpand_X4(nn.Module): 395 | def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm): 396 | super().__init__() 397 | self.input_resolution = input_resolution 398 | self.dim = dim 399 | self.dim_scale = dim_scale 400 | self.expand = nn.Linear(dim, 16 * dim, bias=False) 401 | self.output_dim = dim 402 | self.norm = norm_layer(self.output_dim) 403 | 404 | def forward(self, x): 405 | """ 406 | x: B, H*W, C 407 | """ 408 | H, W = self.input_resolution 409 | x = self.expand(x) 410 | B, L, C = x.shape 411 | assert L == H * W, "input feature has wrong size" 412 | 413 | x = x.view(B, H, W, C) 414 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, 415 | c=C // (self.dim_scale ** 2)) 416 | x = x.view(B, -1, self.output_dim) 417 | x = self.norm(x) 418 | 419 | return x 420 | 421 | 422 | class BasicLayer(nn.Module): 423 | """ A basic Swin Transformer layer for one stage. 424 | 425 | Args: 426 | dim (int): Number of input channels. 427 | input_resolution (tuple[int]): Input resolution. 428 | depth (int): Number of blocks. 429 | num_heads (int): Number of attention heads. 430 | window_size (int): Local window size. 431 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 432 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 433 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 434 | drop (float, optional): Dropout rate. Default: 0.0 435 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 436 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 437 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 438 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 439 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
440 | """ 441 | 442 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 443 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 444 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, flag=False): 445 | 446 | super().__init__() 447 | self.dim = dim 448 | self.input_resolution = input_resolution 449 | self.depth = depth 450 | self.use_checkpoint = use_checkpoint 451 | 452 | # build blocks 453 | self.blocks = nn.ModuleList([ 454 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 455 | num_heads=num_heads, window_size=window_size, 456 | shift_size=0 if (i % 2 == 0) else window_size // 2, 457 | mlp_ratio=mlp_ratio, 458 | qkv_bias=qkv_bias, qk_scale=qk_scale, 459 | drop=drop, attn_drop=attn_drop, 460 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 461 | norm_layer=norm_layer) 462 | for i in range(depth)]) 463 | 464 | # patch merging layer 465 | if downsample is not None: 466 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 467 | else: 468 | self.downsample = None 469 | 470 | self.flag = flag 471 | self.convBottleneckBlock = ConvBottleneckBlock(768, 768, 768) 472 | 473 | def forward(self, x): 474 | for blk in self.blocks: 475 | if self.use_checkpoint: 476 | x = checkpoint.checkpoint(blk, x) 477 | # else: 478 | # x = blk(x) 479 | else: 480 | if not self.flag: 481 | x = blk(x) 482 | if self.flag: 483 | # print(x.shape) 484 | x = x.reshape(-1, 7, 7, 768) # iteration 7:([10, 768, 7, 7]) 485 | # print(x.shape) 486 | x = x.permute(0, 3, 1, 2) 487 | # print(x.shape) 488 | x = self.convBottleneckBlock(x) 489 | # print(x.shape) 490 | if self.downsample is not None: 491 | x = self.downsample(x) 492 | # print(x.shape) 493 | return x 494 | 495 | def extra_repr(self) -> str: 496 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 497 | 498 | def flops(self): 499 | flops = 0 500 | for blk in self.blocks: 501 | flops += blk.flops() 502 | if self.downsample is not None: 503 | flops += self.downsample.flops() 504 | return flops 505 | 506 | 507 | class ConvBottleneckBlock(nn.Module): 508 | def __init__(self, in_channels, middle_channels, out_channels): 509 | super().__init__() 510 | self.relu = nn.ReLU(inplace=True) 511 | self.conv1 = nn.Conv2d(in_channels, middle_channels, 1, padding=0) 512 | self.bn1 = nn.BatchNorm2d(middle_channels) 513 | self.conv2 = nn.Conv2d(middle_channels, middle_channels, 3, padding=1) 514 | self.bn2 = nn.BatchNorm2d(out_channels) 515 | self.conv3 = nn.Conv2d(middle_channels, out_channels, 1, padding=0) 516 | self.bn3 = nn.BatchNorm2d(out_channels) 517 | 518 | def forward(self, x): 519 | out = self.conv1(x) 520 | out = self.bn1(out) 521 | out = self.relu(out) 522 | 523 | out = self.conv2(out) 524 | out = self.bn2(out) 525 | out = self.relu(out) 526 | 527 | out = self.conv3(out) 528 | out = self.bn3(out) 529 | out = self.relu(out) 530 | 531 | return out 532 | 533 | 534 | class BasicLayer_up(nn.Module): 535 | """ A basic Swin Transformer layer for one stage. 536 | 537 | Args: 538 | dim (int): Number of input channels. 539 | input_resolution (tuple[int]): Input resolution. 540 | depth (int): Number of blocks. 541 | num_heads (int): Number of attention heads. 542 | window_size (int): Local window size. 543 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 544 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True 545 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 546 | drop (float, optional): Dropout rate. Default: 0.0 547 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 548 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 549 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 550 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 551 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 552 | """ 553 | 554 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 555 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 556 | drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False): 557 | 558 | super().__init__() 559 | self.dim = dim 560 | self.input_resolution = input_resolution 561 | self.depth = depth 562 | self.use_checkpoint = use_checkpoint 563 | 564 | # build blocks 565 | self.blocks = nn.ModuleList([ 566 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 567 | num_heads=num_heads, window_size=window_size, 568 | shift_size=0 if (i % 2 == 0) else window_size // 2, 569 | mlp_ratio=mlp_ratio, 570 | qkv_bias=qkv_bias, qk_scale=qk_scale, 571 | drop=drop, attn_drop=attn_drop, 572 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 573 | norm_layer=norm_layer) 574 | for i in range(depth)]) 575 | 576 | # patch merging layer 577 | if upsample is not None: 578 | self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer) 579 | else: 580 | self.upsample = None 581 | 582 | def forward(self, x): 583 | for blk in self.blocks: 584 | if self.use_checkpoint: 585 | x = checkpoint.checkpoint(blk, x) 586 | else: 587 | x = blk(x) 588 | if self.upsample is not None: 589 | x = self.upsample(x) 590 | return x 591 | 592 | 593 | class PatchEmbed(nn.Module): 594 | r""" Image to Patch Embedding 595 | 596 | Args: 597 | img_size (int): Image size. Default: 224. 598 | patch_size (int): Patch token size. Default: 4. 599 | in_chans (int): Number of input image channels. Default: 3. 600 | embed_dim (int): Number of linear projection output channels. Default: 96. 601 | norm_layer (nn.Module, optional): Normalization layer. Default: None 602 | """ 603 | 604 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 605 | super().__init__() 606 | img_size = to_2tuple(img_size) 607 | patch_size = to_2tuple(patch_size) 608 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 609 | self.img_size = img_size 610 | self.patch_size = patch_size 611 | self.patches_resolution = patches_resolution 612 | self.num_patches = patches_resolution[0] * patches_resolution[1] 613 | 614 | self.in_chans = in_chans 615 | self.embed_dim = embed_dim 616 | 617 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 618 | if norm_layer is not None: 619 | self.norm = norm_layer(embed_dim) 620 | else: 621 | self.norm = None 622 | 623 | def forward(self, x): 624 | B, C, H, W = x.shape 625 | # FIXME look at relaxing size constraints 626 | assert H == self.img_size[0] and W == self.img_size[1], \ 627 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
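        # The Conv2d with kernel_size == stride == patch_size embeds each
        # non-overlapping patch; flatten(2).transpose(1, 2) then reshapes the
        # (B, embed_dim, Ph, Pw) feature map into a (B, Ph*Pw, embed_dim) token sequence.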
628 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 629 | if self.norm is not None: 630 | x = self.norm(x) 631 | return x 632 | 633 | def flops(self): 634 | Ho, Wo = self.patches_resolution 635 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 636 | if self.norm is not None: 637 | flops += Ho * Wo * self.embed_dim 638 | return flops 639 | 640 | 641 | class SwinTransformerSys(nn.Module): 642 | r""" Swin Transformer 643 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 644 | https://arxiv.org/pdf/2103.14030 645 | 646 | Args: 647 | img_size (int | tuple(int)): Input image size. Default 224 648 | patch_size (int | tuple(int)): Patch size. Default: 4 649 | in_chans (int): Number of input image channels. Default: 3 650 | num_classes (int): Number of classes for classification head. Default: 1000 651 | embed_dim (int): Patch embedding dimension. Default: 96 652 | depths (tuple(int)): Depth of each Swin Transformer layer. 653 | num_heads (tuple(int)): Number of attention heads in different layers. 654 | window_size (int): Window size. Default: 7 655 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 656 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 657 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 658 | drop_rate (float): Dropout rate. Default: 0 659 | attn_drop_rate (float): Attention dropout rate. Default: 0 660 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 661 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 662 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 663 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 664 | use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False 665 | """ 666 | 667 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, 668 | embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24], 669 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 670 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 671 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 672 | use_checkpoint=False, final_upsample="expand_first", **kwargs): 673 | super().__init__() 674 | 675 | print( 676 | "SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format( 677 | depths, 678 | depths_decoder, drop_path_rate, num_classes)) 679 | 680 | self.num_classes = num_classes 681 | self.num_layers = len(depths) 682 | self.embed_dim = embed_dim 683 | self.ape = ape 684 | self.patch_norm = patch_norm 685 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 686 | self.num_features_up = int(embed_dim * 2) 687 | self.mlp_ratio = mlp_ratio 688 | self.final_upsample = final_upsample 689 | 690 | # split image into non-overlapping patches 691 | self.patch_embed = PatchEmbed( 692 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 693 | norm_layer=norm_layer if self.patch_norm else None) 694 | num_patches = self.patch_embed.num_patches 695 | patches_resolution = self.patch_embed.patches_resolution 696 | self.patches_resolution = patches_resolution 697 | 698 | # absolute position embedding 699 | if self.ape: 700 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 701 | trunc_normal_(self.absolute_pos_embed, std=.02) 702 | 703 | self.pos_drop = nn.Dropout(p=drop_rate) 704 | 705 | # stochastic depth 706 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 707 | 708 | # build encoder and bottleneck layers 709 | self.layers = nn.ModuleList() 710 | self.count = 1 711 | for i_layer in range(self.num_layers): 712 | if self.count != 4: 713 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 714 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 715 | patches_resolution[1] // (2 ** i_layer)), 716 | depth=depths[i_layer], 717 | num_heads=num_heads[i_layer], 718 | window_size=window_size, 719 | mlp_ratio=self.mlp_ratio, 720 | qkv_bias=qkv_bias, qk_scale=qk_scale, 721 | drop=drop_rate, attn_drop=attn_drop_rate, 722 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 723 | norm_layer=norm_layer, 724 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 725 | use_checkpoint=use_checkpoint, flag=False) 726 | self.count += 1 727 | elif self.count == 4: 728 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 729 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 730 | patches_resolution[1] // (2 ** i_layer)), 731 | depth=depths[i_layer], 732 | num_heads=num_heads[i_layer], 733 | window_size=window_size, 734 | mlp_ratio=self.mlp_ratio, 735 | qkv_bias=qkv_bias, qk_scale=qk_scale, 736 | drop=drop_rate, attn_drop=attn_drop_rate, 737 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 738 | norm_layer=norm_layer, 739 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 740 | use_checkpoint=use_checkpoint, flag=True) 741 | self.layers.append(layer) 742 | 743 | # build decoder layers 744 | self.layers_up = nn.ModuleList() 745 | self.concat_back_dim = nn.ModuleList() 746 | for i_layer in range(self.num_layers): 747 | concat_linear = nn.Linear(2 * int(embed_dim * 2 
** (self.num_layers - 1 - i_layer)), 748 | int(embed_dim * 2 ** ( 749 | self.num_layers - 1 - i_layer))) if i_layer > 0 else nn.Identity() 750 | if i_layer == 0: 751 | layer_up = PatchExpand( 752 | input_resolution=(patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)), 753 | patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))), 754 | dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)), dim_scale=2, norm_layer=norm_layer) 755 | else: 756 | layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)), 757 | input_resolution=( 758 | patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)), 759 | patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))), 760 | depth=depths[(self.num_layers - 1 - i_layer)], 761 | num_heads=num_heads[(self.num_layers - 1 - i_layer)], 762 | window_size=window_size, 763 | mlp_ratio=self.mlp_ratio, 764 | qkv_bias=qkv_bias, qk_scale=qk_scale, 765 | drop=drop_rate, attn_drop=attn_drop_rate, 766 | drop_path=dpr[sum(depths[:(self.num_layers - 1 - i_layer)]):sum( 767 | depths[:(self.num_layers - 1 - i_layer) + 1])], 768 | norm_layer=norm_layer, 769 | upsample=PatchExpand if (i_layer < self.num_layers - 1) else None, 770 | use_checkpoint=use_checkpoint) 771 | self.layers_up.append(layer_up) 772 | self.concat_back_dim.append(concat_linear) 773 | 774 | self.norm = norm_layer(self.num_features) 775 | self.norm_up = norm_layer(self.embed_dim) 776 | 777 | if self.final_upsample == "expand_first": 778 | print("---final upsample expand_first---") 779 | self.up = FinalPatchExpand_X4(input_resolution=(img_size // patch_size, img_size // patch_size), 780 | dim_scale=4, dim=embed_dim) 781 | self.output = nn.Conv2d(in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False) 782 | 783 | self.apply(self._init_weights) 784 | 785 | self.my_up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 786 | 787 | nb_filter = [96, 192, 384, 768] 788 | self.conv0_1 = VGGBlock(nb_filter[0] + nb_filter[1], nb_filter[1], nb_filter[0]) 789 | self.conv1_1 = VGGBlock(nb_filter[1] + nb_filter[2], nb_filter[1], nb_filter[1]) 790 | self.conv0_2 = VGGBlock(nb_filter[0] * 2 + nb_filter[1], nb_filter[0], nb_filter[0]) 791 | self.conv2_1 = VGGBlock(nb_filter[2] + nb_filter[3], nb_filter[2], nb_filter[2]) 792 | self.conv1_2 = VGGBlock(nb_filter[1] * 2 + nb_filter[2], nb_filter[1], nb_filter[1]) 793 | self.conv0_3 = VGGBlock(nb_filter[0] * 3 + nb_filter[1], nb_filter[0], nb_filter[0]) 794 | 795 | def _init_weights(self, m): 796 | if isinstance(m, nn.Linear): 797 | trunc_normal_(m.weight, std=.02) 798 | if isinstance(m, nn.Linear) and m.bias is not None: 799 | nn.init.constant_(m.bias, 0) 800 | elif isinstance(m, nn.LayerNorm): 801 | nn.init.constant_(m.bias, 0) 802 | nn.init.constant_(m.weight, 1.0) 803 | 804 | @torch.jit.ignore 805 | def no_weight_decay(self): 806 | return {'absolute_pos_embed'} 807 | 808 | @torch.jit.ignore 809 | def no_weight_decay_keywords(self): 810 | return {'relative_position_bias_table'} 811 | 812 | # Encoder and Bottleneck 813 | def forward_features(self, x): 814 | x = self.patch_embed(x) 815 | if self.ape: 816 | x = x + self.absolute_pos_embed 817 | x = self.pos_drop(x) 818 | x_downsample = [] 819 | 820 | for layer in self.layers: 821 | # print(x.shape) 822 | B, L, C = x.shape 823 | H, W = int(math.sqrt(L)), int(math.sqrt(L)) 824 | 825 | x_downsample.append(x.view(B, H, W, C)) 826 | x = layer(x) 827 | 828 | # print(x.shape) # (12,768,7,7) 829 | x = x.permute(0, 2, 3, 1) 830 | # 
print(x.shape) # (12,7,7,768) 831 | x_downsample.append(x) 832 | B, H, W, C = x.shape 833 | x = x.reshape(B, H * W, C) 834 | x = self.norm(x) # B L C 835 | 836 | # print("a:", x.shape) 837 | # print("b:", x_downsample[0].shape) 838 | # print("c:", x_downsample[1].shape) 839 | # print("d:", self.up(x_downsample[1]).shape) 840 | 841 | # my_i_layer = [0, 1, 2, 3] 842 | # layer_up1 = my_PatchExpand( 843 | # input_resolution=(self.patches_resolution[0] // (2 ** (self.num_layers - 1 - my_i_layer[0])), 844 | # self.patches_resolution[1] // (2 ** (self.num_layers - 1 - my_i_layer[0]))), 845 | # dim=int(self.embed_dim * 2 ** (self.num_layers - 1 - my_i_layer[0])), dim_scale=2, 846 | # norm_layer=nn.LayerNorm).cuda() 847 | # 848 | # layer_up2 = my_PatchExpand( 849 | # input_resolution=(self.patches_resolution[0] // (2 ** (self.num_layers - 1 - my_i_layer[1])), 850 | # self.patches_resolution[1] // (2 ** (self.num_layers - 1 - my_i_layer[1]))), 851 | # dim=int(self.embed_dim * 2 ** (self.num_layers - 1 - my_i_layer[1])), dim_scale=2, 852 | # norm_layer=nn.LayerNorm).cuda() 853 | # 854 | # layer_up3 = my_PatchExpand( 855 | # input_resolution=(self.patches_resolution[0] // (2 ** (self.num_layers - 1 - my_i_layer[2])), 856 | # self.patches_resolution[1] // (2 ** (self.num_layers - 1 - my_i_layer[2]))), 857 | # dim=int(self.embed_dim * 2 ** (self.num_layers - 1 - my_i_layer[2])), dim_scale=2, 858 | # norm_layer=nn.LayerNorm).cuda() 859 | 860 | # x0_1 = self.conv0_1(torch.cat([x_downsample[0], layer_up1(x_downsample[1])], -1)) 861 | # x1_1 = self.conv1_1(torch.cat([x_downsample[1], layer_up2(x_downsample[2])], -1)) 862 | # x0_2 = self.conv0_2(torch.cat([x_downsample[0], x0_1, layer_up1(x1_1)], -1)) 863 | # # skip3 864 | # x2_1 = self.conv2_1(torch.cat([x_downsample[2], layer_up3(x_downsample[3])], -1)) 865 | # # skip2 866 | # x1_2 = self.conv1_2(torch.cat([x_downsample[1], x1_1, layer_up2(x2_1)], -1)) 867 | # # skip1 868 | # x0_3 = self.conv0_3(torch.cat([x_downsample[0], x0_1, x0_2, layer_up1(x1_2)], -1)) 869 | 870 | # 范例 871 | # b = torch.randn(4, 3, 32, 32) # [B C H W] 872 | # b = b.transpose(1, 3) # [B W H C] 873 | # [4,32,32,3] 874 | # a = torch.rand(4, 3, 28, 28) # [B C H W] 875 | # a.permute(0, 2, 3, 1) # [B H W C] 876 | # [4,32,32,3] 877 | 878 | x0_1 = self.conv0_1( 879 | torch.cat([x_downsample[0].permute(0, 3, 1, 2), self.my_up(x_downsample[1].permute(0, 3, 1, 2))], 1)) 880 | x1_1 = self.conv1_1( 881 | torch.cat([x_downsample[1].permute(0, 3, 1, 2), self.my_up(x_downsample[2].permute(0, 3, 1, 2))], 1)) 882 | x0_2 = self.conv0_2(torch.cat([x_downsample[0].permute(0, 3, 1, 2), x0_1, self.my_up(x1_1)], 1)) 883 | # skip3 884 | x2_1 = self.conv2_1( 885 | # torch.cat([x_downsample[2].permute(0, 3, 1, 2), self.my_up(x_downsample[3].permute(0, 3, 1, 2))], 1)) 886 | torch.cat([x_downsample[2].permute(0, 3, 1, 2), self.my_up(x_downsample[4].permute(0, 3, 1, 2))], 1)) 887 | # skip2 888 | x1_2 = self.conv1_2(torch.cat([x_downsample[1].permute(0, 3, 1, 2), x1_1, self.my_up(x2_1)], 1)) 889 | # skip1 890 | x0_3 = self.conv0_3(torch.cat([x_downsample[0].permute(0, 3, 1, 2), x0_1, x0_2, self.my_up(x1_2)], 1)) 891 | 892 | # torch.flatten(input, start_dim=0, end_dim=-1) 893 | x_downsample_new = [] 894 | # print(x0_3.shape) 895 | # print(x1_2.shape) 896 | # print(x2_1.shape) 897 | # print(torch.flatten(x0_3, start_dim=2, end_dim=-1).shape) 898 | # print(torch.flatten(x1_2, start_dim=2, end_dim=-1).shape) 899 | # print(torch.flatten(x2_1, start_dim=2, end_dim=-1).shape) 900 | 
x_downsample_new.append(torch.flatten(x0_3, start_dim=2, end_dim=-1).permute(0, 2, 1)) 901 | x_downsample_new.append(torch.flatten(x1_2, start_dim=2, end_dim=-1).permute(0, 2, 1)) 902 | x_downsample_new.append(torch.flatten(x2_1, start_dim=2, end_dim=-1).permute(0, 2, 1)) 903 | 904 | return x, x_downsample_new 905 | 906 | # Dencoder and Skip connection 907 | def forward_up_features(self, x, x_downsample): 908 | for inx, layer_up in enumerate(self.layers_up): 909 | if inx == 0: 910 | x = layer_up(x) 911 | else: 912 | x = torch.cat([x, x_downsample[3 - inx]], -1) 913 | x = self.concat_back_dim[inx](x) 914 | x = layer_up(x) 915 | 916 | x = self.norm_up(x) # B L C 917 | 918 | return x 919 | 920 | def up_x4(self, x): 921 | H, W = self.patches_resolution 922 | B, L, C = x.shape 923 | assert L == H * W, "input features has wrong size" 924 | 925 | if self.final_upsample == "expand_first": 926 | x = self.up(x) 927 | x = x.view(B, 4 * H, 4 * W, -1) 928 | x = x.permute(0, 3, 1, 2) # B,C,H,W 929 | x = self.output(x) 930 | 931 | return x 932 | 933 | def forward(self, x): 934 | x, x_downsample = self.forward_features(x) 935 | x = self.forward_up_features(x, x_downsample) 936 | x = self.up_x4(x) 937 | 938 | return x 939 | 940 | def flops(self): 941 | flops = 0 942 | flops += self.patch_embed.flops() 943 | for i, layer in enumerate(self.layers): 944 | flops += layer.flops() 945 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) 946 | flops += self.num_features * self.num_classes 947 | return flops 948 | 949 | 950 | class VGGBlock(nn.Module): 951 | def __init__(self, in_channels, middle_channels, out_channels): 952 | super().__init__() 953 | self.relu = nn.ReLU(inplace=True) 954 | self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1) 955 | self.bn1 = nn.BatchNorm2d(middle_channels) 956 | self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1) 957 | self.bn2 = nn.BatchNorm2d(out_channels) 958 | 959 | def forward(self, x): 960 | out = self.conv1(x) 961 | out = self.bn1(out) 962 | out = self.relu(out) 963 | 964 | out = self.conv2(out) 965 | out = self.bn2(out) 966 | out = self.relu(out) 967 | 968 | return out 969 | -------------------------------------------------------------------------------- /networks/vision_transformer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import copy 7 | import logging 8 | import math 9 | 10 | from os.path import join as pjoin 11 | 12 | import torch 13 | import torch.nn as nn 14 | import numpy as np 15 | 16 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm 17 | from torch.nn.modules.utils import _pair 18 | from scipy import ndimage 19 | from .swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | class SwinUnet(nn.Module): 25 | def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False): 26 | super(SwinUnet, self).__init__() 27 | self.num_classes = num_classes 28 | self.zero_head = zero_head 29 | self.config = config 30 | 31 | self.swin_unet = SwinTransformerSys(img_size=config.DATA.IMG_SIZE, 32 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 33 | in_chans=config.MODEL.SWIN.IN_CHANS, 34 | num_classes=self.num_classes, 35 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 36 | 
depths=config.MODEL.SWIN.DEPTHS, 37 | num_heads=config.MODEL.SWIN.NUM_HEADS, 38 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 39 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 40 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 41 | qk_scale=config.MODEL.SWIN.QK_SCALE, 42 | drop_rate=config.MODEL.DROP_RATE, 43 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 44 | ape=config.MODEL.SWIN.APE, 45 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 46 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 47 | 48 | def forward(self, x): 49 | if x.size()[1] == 1: 50 | x = x.repeat(1, 3, 1, 1) 51 | logits = self.swin_unet(x) 52 | return logits 53 | 54 | def load_from(self, config): 55 | pretrained_path = config.MODEL.PRETRAIN_CKPT 56 | if pretrained_path is not None: 57 | print("pretrained_path:{}".format(pretrained_path)) 58 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 59 | pretrained_dict = torch.load(pretrained_path, map_location=device) 60 | if "model" not in pretrained_dict: 61 | print("---start load pretrained modle by splitting---") 62 | pretrained_dict = {k[17:]: v for k, v in pretrained_dict.items()} 63 | for k in list(pretrained_dict.keys()): 64 | if "output" in k: 65 | print("delete key:{}".format(k)) 66 | del pretrained_dict[k] 67 | msg = self.swin_unet.load_state_dict(pretrained_dict, strict=False) 68 | # print(msg) 69 | return 70 | pretrained_dict = pretrained_dict['model'] 71 | print("---start load pretrained modle of swin encoder---") 72 | 73 | model_dict = self.swin_unet.state_dict() 74 | full_dict = copy.deepcopy(pretrained_dict) 75 | for k, v in pretrained_dict.items(): 76 | if "layers." in k: 77 | current_layer_num = 3 - int(k[7:8]) 78 | current_k = "layers_up." + str(current_layer_num) + k[8:] 79 | full_dict.update({current_k: v}) 80 | for k in list(full_dict.keys()): 81 | if k in model_dict: 82 | if full_dict[k].shape != model_dict[k].shape: 83 | print("delete:{};shape pretrain:{};shape model:{}".format(k, v.shape, model_dict[k].shape)) 84 | del full_dict[k] 85 | 86 | msg = self.swin_unet.load_state_dict(full_dict, strict=False) 87 | # print(msg) 88 | else: 89 | print("none pretrain") 90 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | numpy 4 | tqdm 5 | tensorboard 6 | tensorboardX 7 | ml-collections 8 | medpy 9 | SimpleITK 10 | scipy 11 | h5py 12 | timm 13 | einops 14 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import sys 6 | import numpy as np 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import torch.nn as nn 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | from datasets.dataset_synapse import Synapse_dataset 13 | from utils import test_single_volume 14 | from networks.vision_transformer import SwinUnet as ViT_seg 15 | from trainer import trainer_synapse 16 | from config import get_config 17 | 18 | """ 19 | --dataset Synapse 20 | --cfg ./configs/swin_tiny_patch4_window7_224_lite.yaml 21 | --is_saveni 22 | --volume_path ./datasets/Synapse 23 | --output_dir ./output 24 | --max_epoch 150 25 | --base_lr 0.05 26 | --img_size 224 27 | --batch_size 1 28 | """ 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--volume_path', type=str, 32 | 
default='./datasets/Synapse',  # 'test_vol_h5' is appended below when --dataset is Synapse
33 | help='root dir for validation volume data')
34 | parser.add_argument('--dataset', type=str,
35 | default='Synapse', help='experiment_name')
36 | # parser.add_argument('--num_classes', type=int,
37 | # default=9, help='output channel of network')
38 | parser.add_argument('--num_classes', type=int,
39 | default=2, help='output channel of network')
40 | parser.add_argument('--list_dir', type=str,
41 | default='./lists/lists_Synapse', help='list dir')
42 | parser.add_argument('--output_dir', default='./output', type=str, help='output dir')
43 | parser.add_argument('--max_iterations', type=int, default=30000, help='maximum iteration number to train')
44 | parser.add_argument('--max_epochs', type=int, default=150, help='maximum epoch number to train')
45 | parser.add_argument('--batch_size', type=int, default=1,
46 | help='batch_size per gpu')
47 | parser.add_argument('--img_size', type=int, default=224, help='input patch size of network input')
48 | parser.add_argument('--is_savenii', action="store_true", help='whether to save results during inference')
49 | parser.add_argument('--test_save_dir', type=str, default='../predictions', help='saving prediction as nii!')
50 | parser.add_argument('--deterministic', type=int, default=1, help='whether use deterministic training')
51 | parser.add_argument('--base_lr', type=float, default=0.01, help='segmentation network learning rate')
52 | parser.add_argument('--seed', type=int, default=1234, help='random seed')
53 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
54 | parser.add_argument(
55 | "--opts",
56 | help="Modify config options by adding 'KEY VALUE' pairs. ",
57 | default=None,
58 | nargs='+',
59 | )
60 | parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
61 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
62 | help='no: no cache, '
63 | 'full: cache all data, '
64 | 'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
65 | parser.add_argument('--resume', help='resume from checkpoint')
66 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
67 | parser.add_argument('--use-checkpoint', action='store_true',
68 | help="whether to use gradient checkpointing to save memory")
69 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
70 | help='mixed precision opt level, if O0, no amp is used')
71 | parser.add_argument('--tag', help='tag of experiment')
72 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
73 | parser.add_argument('--throughput', action='store_true', help='Test throughput only')
74 |
75 | args = parser.parse_args()
76 | if args.dataset == "Synapse":
77 | args.volume_path = os.path.join(args.volume_path, "test_vol_h5")
78 | # print(args.volume_path)
79 | config = get_config(args)
80 |
81 |
82 | def inference(args, model, test_save_path=None):
83 | db_test = args.Dataset(base_dir=args.volume_path, split="test_vol", list_dir=args.list_dir)
84 | testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)
85 | logging.info("{} test iterations per epoch".format(len(testloader)))
86 | model.eval()
87 | metric_list = 0.0
88 | f = open(os.path.join(args.list_dir, 'testxg.txt'), 'w')  # per-case Dice/HD95 log written next to the Synapse list files
89 |
90 | for i_batch, sampled_batch in tqdm(enumerate(testloader)):
91 | h, w = sampled_batch["image"].size()[2:]
sampled_batch["image"].size()[2:] 92 | image, label, case_name = sampled_batch["image"], sampled_batch["label"], sampled_batch['case_name'][0] 93 | metric_i = test_single_volume(image, label, model, classes=args.num_classes, 94 | patch_size=[args.img_size, args.img_size], 95 | test_save_path=test_save_path, case=case_name, z_spacing=args.z_spacing) 96 | metric_list += np.array(metric_i) 97 | # logging.info('idx %d case %s mean_dice %f mean_hd95 %f' % ( 98 | # i_batch, case_name, np.mean(metric_i, axis=0)[0], np.mean(metric_i, axis=0)[1])) 99 | 100 | # if((np.mean(metric_i, axis=0)[0]>0.69)and(np.mean(metric_i, axis=0)[1]<100)): 101 | # f.write( 'order %d name %s mean_dice %f mean_hd95 %f' % ( 102 | # i_batch, case_name, np.mean(metric_i, axis=0)[0], np.mean(metric_i, axis=0)[1])+"\n") 103 | 104 | # f.write((str(case_name)+"\n")) 105 | 106 | f.write('order %d name %s mean_dice %f mean_hd95 %f' % ( 107 | i_batch, case_name, np.mean(metric_i, axis=0)[0], np.mean(metric_i, axis=0)[1]) + "\n") 108 | 109 | metric_list = metric_list / len(db_test) 110 | for i in range(1, args.num_classes): 111 | logging.info('Mean class %d mean_dice %f mean_hd95 %f' % (i, metric_list[i - 1][0], metric_list[i - 1][1])) 112 | performance = np.mean(metric_list, axis=0)[0] 113 | mean_hd95 = np.mean(metric_list, axis=0)[1] 114 | print(metric_list, axis=0) 115 | print(metric_list, axis=0) 116 | 117 | logging.info('Testing performance in best val model: mean_dice : %f mean_hd95 : %f' % (performance, mean_hd95)) 118 | return "Testing Finished!" 119 | 120 | 121 | if __name__ == "__main__": 122 | 123 | if not args.deterministic: 124 | cudnn.benchmark = True 125 | cudnn.deterministic = False 126 | else: 127 | cudnn.benchmark = False 128 | cudnn.deterministic = True 129 | random.seed(args.seed) 130 | np.random.seed(args.seed) 131 | torch.manual_seed(args.seed) 132 | torch.cuda.manual_seed(args.seed) 133 | 134 | dataset_config = { 135 | 'Synapse': { 136 | 'Dataset': Synapse_dataset, 137 | 'volume_path': args.volume_path, 138 | 'list_dir': './lists/lists_Synapse', 139 | # 'num_classes': 9, 140 | 'num_classes': 2, 141 | 'z_spacing': 1, 142 | }, 143 | } 144 | dataset_name = args.dataset 145 | args.num_classes = dataset_config[dataset_name]['num_classes'] 146 | args.volume_path = dataset_config[dataset_name]['volume_path'] 147 | args.Dataset = dataset_config[dataset_name]['Dataset'] 148 | args.list_dir = dataset_config[dataset_name]['list_dir'] 149 | args.z_spacing = dataset_config[dataset_name]['z_spacing'] 150 | args.is_pretrain = True 151 | 152 | net = ViT_seg(config, img_size=args.img_size, num_classes=args.num_classes).cuda() 153 | 154 | snapshot = os.path.join(args.output_dir, 'best_model.pth') 155 | if not os.path.exists(snapshot): 156 | # snapshot = snapshot.replace('best_model', 'epoch_' + str(args.max_epochs - 1)) 157 | snapshot = snapshot.replace('best_model', 'epoch_' + str(59)) 158 | 159 | # msg = net.load_state_dict(torch.load(snapshot), False) 160 | 161 | msg = net.load_state_dict(torch.load(snapshot)) 162 | print("self trained swin unet", msg) 163 | print(args.output_dir) 164 | snapshot_name = args.output_dir.split('/')[-1] 165 | print(snapshot_name) 166 | 167 | log_folder = './test_log/test_log_' 168 | os.makedirs(log_folder, exist_ok=True) 169 | logging.basicConfig(filename=log_folder + '/' + snapshot_name + ".txt", level=logging.INFO, 170 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 171 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 172 | logging.info(str(args)) 173 
| logging.info(snapshot_name) 174 | 175 | if args.is_savenii: 176 | args.test_save_dir = os.path.join(args.output_dir, "predictions") 177 | test_save_path = args.test_save_dir 178 | os.makedirs(test_save_path, exist_ok=True) 179 | else: 180 | test_save_path = None 181 | inference(args, net, test_save_path) 182 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import numpy as np 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | from networks.vision_transformer import SwinUnet as ViT_seg 9 | from trainer import trainer_synapse 10 | from config import get_config 11 | 12 | """ 13 | --dataset Synapse 14 | --cfg ./configs/swin_tiny_patch4_window7_224_lite.yaml 15 | --root_path ./datasets/Synapse 16 | --max_epochs 1500 17 | --output_dir ./output 18 | --img_size 224 19 | --base_lr 0.005 20 | --batch_size 24 21 | """ 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--root_path', type=str, 25 | default='./datasets/Synapse/train_npz', help='root dir for data') 26 | parser.add_argument('--dataset', type=str, 27 | default='Synapse', help='experiment_name') 28 | parser.add_argument('--list_dir', type=str, 29 | default='./lists/lists_Synapse', help='list dir') 30 | # parser.add_argument('--num_classes', type=int, 31 | # default=9, help='output channel of network') 32 | parser.add_argument('--num_classes', type=int, 33 | default=2, help='output channel of network') 34 | parser.add_argument('--output_dir', default='./output', type=str, help='output dir') 35 | parser.add_argument('--max_iterations', type=int, 36 | default=30000, help='maximum epoch number to train') 37 | parser.add_argument('--max_epochs', type=int, 38 | default=1500, help='maximum epoch number to train') 39 | parser.add_argument('--batch_size', type=int, 40 | default=24, help='batch_size per gpu') 41 | parser.add_argument('--n_gpu', type=int, default=1, help='total gpu') 42 | parser.add_argument('--deterministic', type=int, default=1, 43 | help='whether use deterministic training') 44 | parser.add_argument('--base_lr', type=float, default=0.005, 45 | help='segmentation network learning rate') 46 | parser.add_argument('--img_size', type=int, 47 | default=224, help='input patch size of network input') 48 | parser.add_argument('--seed', type=int, 49 | default=1234, help='random seed') 50 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) 51 | parser.add_argument( 52 | "--opts", 53 | help="Modify config options by adding 'KEY VALUE' pairs. 
", 54 | default=None, 55 | nargs='+', 56 | ) 57 | parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset') 58 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], 59 | help='no: no cache, ' 60 | 'full: cache all data, ' 61 | 'part: sharding the dataset into nonoverlapping pieces and only cache one piece') 62 | parser.add_argument('--resume', help='resume from checkpoint') 63 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") 64 | parser.add_argument('--use-checkpoint', action='store_true', 65 | help="whether to use gradient checkpointing to save memory") 66 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'], 67 | help='mixed precision opt level, if O0, no amp is used') 68 | parser.add_argument('--tag', help='tag of experiment') 69 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 70 | parser.add_argument('--throughput', action='store_true', help='Test throughput only') 71 | 72 | args = parser.parse_args() 73 | if args.dataset == "Synapse": 74 | args.root_path = os.path.join(args.root_path, "train_npz") 75 | config = get_config(args) 76 | 77 | if __name__ == "__main__": 78 | if not args.deterministic: 79 | cudnn.benchmark = True 80 | cudnn.deterministic = False 81 | else: 82 | cudnn.benchmark = False 83 | cudnn.deterministic = True 84 | 85 | random.seed(args.seed) 86 | np.random.seed(args.seed) 87 | torch.manual_seed(args.seed) 88 | torch.cuda.manual_seed(args.seed) 89 | 90 | dataset_name = args.dataset 91 | dataset_config = { 92 | 'Synapse': { 93 | 'root_path': args.root_path, 94 | 'list_dir': './lists/lists_Synapse', 95 | # 'num_classes': 9, 96 | 'num_classes': 2, 97 | }, 98 | } 99 | 100 | if args.batch_size != 24 and args.batch_size % 6 == 0: 101 | args.base_lr *= args.batch_size / 24 102 | args.num_classes = dataset_config[dataset_name]['num_classes'] 103 | args.root_path = dataset_config[dataset_name]['root_path'] 104 | args.list_dir = dataset_config[dataset_name]['list_dir'] 105 | 106 | if not os.path.exists(args.output_dir): 107 | os.makedirs(args.output_dir) 108 | net = ViT_seg(config, img_size=args.img_size, num_classes=args.num_classes).cuda() 109 | net.load_from(config) 110 | 111 | trainer = {'Synapse': trainer_synapse, } 112 | trainer[dataset_name](args, net, args.output_dir) 113 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import sys 6 | import time 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | from tensorboardX import SummaryWriter 12 | from torch.nn.modules.loss import CrossEntropyLoss 13 | from torch.utils.data import DataLoader 14 | from tqdm import tqdm 15 | from utils import DiceLoss 16 | from torchvision import transforms 17 | from utils import test_single_volume 18 | 19 | 20 | def trainer_synapse(args, model, snapshot_path): 21 | from datasets.dataset_synapse import Synapse_dataset, RandomGenerator 22 | logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO, 23 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 24 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 25 | logging.info(str(args)) 26 | base_lr = args.base_lr 27 | num_classes = 
args.num_classes 28 | batch_size = args.batch_size * args.n_gpu 29 | max_iterations = args.max_iterations 30 | db_train = Synapse_dataset(base_dir=args.root_path, list_dir=args.list_dir, split="train", 31 | transform=transforms.Compose( 32 | [RandomGenerator(output_size=[args.img_size, args.img_size])])) 33 | print("The length of train set is: {}".format(len(db_train))) 34 | 35 | def worker_init_fn(worker_id): 36 | random.seed(args.seed + worker_id) 37 | 38 | trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True, 39 | worker_init_fn=worker_init_fn) 40 | if args.n_gpu > 1: 41 | model = nn.DataParallel(model) 42 | model.train() 43 | ce_loss = CrossEntropyLoss() 44 | dice_loss = DiceLoss(num_classes) 45 | optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001) 46 | writer = SummaryWriter(snapshot_path + '/log') 47 | iter_num = 0 48 | max_epoch = args.max_epochs 49 | max_iterations = args.max_epochs * len(trainloader) 50 | logging.info("{} iterations per epoch. {} max iterations ".format(len(trainloader), max_iterations)) 51 | best_performance = 0.0 52 | iterator = tqdm(range(max_epoch), ncols=70) 53 | for epoch_num in iterator: 54 | for i_batch, sampled_batch in enumerate(trainloader): 55 | image_batch, label_batch = sampled_batch['image'], sampled_batch['label'] 56 | image_batch, label_batch = image_batch.cuda(), label_batch.cuda() 57 | outputs = model(image_batch) 58 | loss_ce = ce_loss(outputs, label_batch[:].long()) 59 | loss_dice = dice_loss(outputs, label_batch, softmax=True) 60 | loss = 0.4 * loss_ce + 0.6 * loss_dice 61 | optimizer.zero_grad() 62 | loss.backward() 63 | optimizer.step() 64 | lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9 65 | for param_group in optimizer.param_groups: 66 | param_group['lr'] = lr_ 67 | 68 | iter_num = iter_num + 1 69 | writer.add_scalar('info/lr', lr_, iter_num) 70 | writer.add_scalar('info/total_loss', loss, iter_num) 71 | writer.add_scalar('info/loss_ce', loss_ce, iter_num) 72 | 73 | logging.info('iteration %d : loss : %f, loss_ce: %f' % (iter_num, loss.item(), loss_ce.item())) 74 | 75 | if iter_num % 20 == 0: 76 | image = image_batch[1, 0:1, :, :] 77 | image = (image - image.min()) / (image.max() - image.min()) 78 | writer.add_image('train/Image', image, iter_num) 79 | outputs = torch.argmax(torch.softmax(outputs, dim=1), dim=1, keepdim=True) 80 | writer.add_image('train/Prediction', outputs[1, ...] 
* 50, iter_num)
81 | labs = label_batch[1, ...].unsqueeze(0) * 50
82 | writer.add_image('train/GroundTruth', labs, iter_num)
83 |
84 | # save_interval = 50
85 | # if epoch_num > int(max_epoch / 2) and (epoch_num + 1) % save_interval == 0:
86 | # save_mode_path = os.path.join(snapshot_path, 'epoch_' + str(epoch_num) + '.pth')
87 | # torch.save(model.state_dict(), save_mode_path)
88 | # logging.info("save model to {}".format(save_mode_path))
89 | #
90 | # if epoch_num >= max_epoch - 1:
91 | # save_mode_path = os.path.join(snapshot_path, 'epoch_' + str(epoch_num) + '.pth')
92 | # torch.save(model.state_dict(), save_mode_path)
93 | # logging.info("save model to {}".format(save_mode_path))
94 | # iterator.close()
95 | # break
96 |
97 | save_interval = 2
98 | if (epoch_num + 1) % save_interval == 0:
99 | save_mode_path = os.path.join(snapshot_path, 'epoch_' + str(epoch_num) + '.pth')
100 | torch.save(model.state_dict(), save_mode_path)
101 | logging.info("save model to {}".format(save_mode_path))
102 |
103 | if epoch_num >= max_epoch - 1:
104 | save_mode_path = os.path.join(snapshot_path, 'epoch_' + str(epoch_num) + '.pth')
105 | torch.save(model.state_dict(), save_mode_path)
106 | logging.info("save model to {}".format(save_mode_path))
107 | iterator.close()
108 | break
109 |
110 | writer.close()
111 | return "Training Finished!"
112 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | from medpy import metric
6 | from scipy.ndimage import zoom
7 | import torch.nn as nn
8 | import SimpleITK as sitk
9 | import copy
10 | from PIL import Image
11 | import matplotlib.pyplot as plt
12 | from einops import rearrange
13 |
14 | def dice_coefficient(a, b):
15 | """Compute the Dice coefficient between two numpy arrays."""
16 | a = np.asarray(a).astype(bool)
17 | b = np.asarray(b).astype(bool)
18 |
19 | # compute the intersection
20 | intersection = np.logical_and(a, b)
21 | if a.sum() + b.sum() != 0.0:
22 | return 2. 
* intersection.sum() / (a.sum() + b.sum()) 23 | else: 24 | return 1.0 25 | 26 | class DiceLoss(nn.Module): 27 | def __init__(self, n_classes): 28 | super(DiceLoss, self).__init__() 29 | self.n_classes = n_classes 30 | 31 | def _one_hot_encoder(self, input_tensor): 32 | tensor_list = [] 33 | for i in range(self.n_classes): 34 | temp_prob = input_tensor == i 35 | tensor_list.append(temp_prob.unsqueeze(1)) 36 | output_tensor = torch.cat(tensor_list, dim=1) 37 | return output_tensor.float() 38 | 39 | def _dice_loss(self, score, target): 40 | target = target.float() 41 | smooth = 1e-5 42 | intersect = torch.sum(score * target) 43 | y_sum = torch.sum(target * target) 44 | z_sum = torch.sum(score * score) 45 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 46 | loss = 1 - loss 47 | return loss 48 | 49 | def forward(self, inputs, target, weight=None, softmax=False): 50 | if softmax: 51 | inputs = torch.softmax(inputs, dim=1) 52 | target = self._one_hot_encoder(target) 53 | if weight is None: 54 | weight = [1] * self.n_classes 55 | assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), 56 | target.size()) 57 | class_wise_dice = [] 58 | loss = 0.0 59 | for i in range(0, self.n_classes): 60 | dice = self._dice_loss(inputs[:, i], target[:, i]) 61 | class_wise_dice.append(1.0 - dice.item()) 62 | loss += dice * weight[i] 63 | return loss / self.n_classes 64 | 65 | 66 | def calculate_metric_percase(pred, gt): 67 | pred[pred > 0] = 1 68 | gt[gt > 0] = 1 69 | if pred.sum() > 0 and gt.sum() > 0: 70 | dice = dice_coefficient(pred, gt) 71 | hd95 = metric.binary.hd95(pred, gt) 72 | return dice, hd95 73 | elif pred.sum() > 0 and gt.sum() == 0: 74 | return 1, 0 75 | elif pred.sum() == 0 and gt.sum() == 0: 76 | return 1, 0 77 | else: 78 | return 0, 0 79 | 80 | 81 | def test_single_volume(image, label, net, classes, patch_size=[256, 256], test_save_path=None, case=None, z_spacing=1): 82 | image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy() 83 | _, x, y = image.shape 84 | result = image.copy() 85 | 86 | if x != patch_size[0] or y != patch_size[1]: 87 | image = zoom(image, (1, patch_size[0] / x, patch_size[1] / y), order=3) 88 | input = torch.from_numpy(image).unsqueeze(0).float().cuda() 89 | net.eval() 90 | with torch.no_grad(): 91 | out = torch.argmax(torch.softmax(net(input), dim=1), dim=1).squeeze(0) 92 | out = out.cpu().detach().numpy() 93 | if x != patch_size[0] or y != patch_size[1]: 94 | prediction = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0) 95 | else: 96 | prediction = out 97 | 98 | if not os.path.exists('pred_image'): 99 | os.makedirs('pred_image') 100 | 101 | if not os.path.exists('label_image'): 102 | os.makedirs('label_image') 103 | 104 | #展示分割效果 105 | # 创建一个新的NumPy数组来存储结果 106 | result_pred = rearrange(result,'c h w -> h w c') 107 | result_label = rearrange(result, 'c h w -> h w c') 108 | print(result.shape) 109 | print(prediction.shape) 110 | for i in range(512): 111 | for j in range(512): 112 | if prediction[i,j] == 1: 113 | result_pred[i,j] = [255, 0, 0] 114 | 115 | for i in range(512): 116 | for j in range(512): 117 | if label[i,j] == 1: 118 | result_label[i,j] = [255, 0, 0] 119 | 120 | filename, extension = os.path.splitext(case) 121 | print(os.path.join('pred_image',filename)) 122 | plt.imsave(os.path.join('pred_image',filename)+'.png', result_pred/255, cmap='jet') 123 | print(os.path.join('label_image',filename)) 124 | plt.imsave(os.path.join('label_image',filename)+'.png', 
result_label/255, cmap='jet') 125 | print(prediction.shape) 126 | print(label.shape) 127 | print('dice:'+str(dice_coefficient(prediction,label))) 128 | metric_list = [] 129 | for i in range(1, classes): 130 | metric_list.append(calculate_metric_percase(prediction == i, label == i)) 131 | 132 | # if test_save_path is not None: 133 | # prediction = Image.fromarray(np.uint8(prediction)).convert('L') 134 | # prediction.save(test_save_path + '/' + case + '.png') 135 | 136 | if test_save_path is not None: 137 | a1 = copy.deepcopy(prediction) 138 | a2 = copy.deepcopy(prediction) 139 | a3 = copy.deepcopy(prediction) 140 | a1[a1 == 1] = 0 141 | a2[a2 == 1] = 255 142 | a3[a3 == 1] = 0 143 | a1 = Image.fromarray(np.uint8(a1)).convert('L') 144 | a2 = Image.fromarray(np.uint8(a2)).convert('L') 145 | a3 = Image.fromarray(np.uint8(a3)).convert('L') 146 | prediction = Image.merge('RGB', [a1, a2, a3]) 147 | prediction.save(test_save_path + '/' + case + '.png') 148 | 149 | return metric_list 150 | 151 | # def test_single_volume(image, label, net, classes, patch_size=[256, 256], test_save_path=None, case=None, z_spacing=1): 152 | # image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy() 153 | # if len(image.shape) == 3: 154 | # prediction = np.zeros_like(label) 155 | # for ind in range(image.shape[0]): 156 | # slice = image[ind, :, :] 157 | # x, y = slice.shape[0], slice.shape[1] 158 | # if x != patch_size[0] or y != patch_size[1]: 159 | # slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=3) # previous using 0 160 | # input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda() 161 | # net.eval() 162 | # with torch.no_grad(): 163 | # outputs = net(input) 164 | # out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0) 165 | # out = out.cpu().detach().numpy() 166 | # if x != patch_size[0] or y != patch_size[1]: 167 | # pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0) 168 | # else: 169 | # pred = out 170 | # prediction[ind] = pred 171 | # else: 172 | # input = torch.from_numpy(image).unsqueeze( 173 | # 0).unsqueeze(0).float().cuda() 174 | # net.eval() 175 | # with torch.no_grad(): 176 | # out = torch.argmax(torch.softmax(net(input), dim=1), dim=1).squeeze(0) 177 | # prediction = out.cpu().detach().numpy() 178 | # metric_list = [] 179 | # for i in range(1, classes): 180 | # metric_list.append(calculate_metric_percase(prediction == i, label == i)) 181 | # 182 | # if test_save_path is not None: 183 | # img_itk = sitk.GetImageFromArray(image.astype(np.float32)) 184 | # prd_itk = sitk.GetImageFromArray(prediction.astype(np.float32)) 185 | # lab_itk = sitk.GetImageFromArray(label.astype(np.float32)) 186 | # img_itk.SetSpacing((1, 1, z_spacing)) 187 | # prd_itk.SetSpacing((1, 1, z_spacing)) 188 | # lab_itk.SetSpacing((1, 1, z_spacing)) 189 | # sitk.WriteImage(prd_itk, test_save_path + '/' + case + "_pred.nii.gz") 190 | # sitk.WriteImage(img_itk, test_save_path + '/' + case + "_img.nii.gz") 191 | # sitk.WriteImage(lab_itk, test_save_path + '/' + case + "_gt.nii.gz") 192 | # return metric_list 193 | --------------------------------------------------------------------------------
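
As a quick reference, the following is a minimal, self-contained sketch of the multi-fusion dense skip connection computed in `SwinTransformerSys.forward_features`. It assumes a 224x224 input with patch_size=4 and embed_dim=96 (the swin_tiny configuration), so the encoder stages yield feature maps of 96x56x56, 192x28x28, 384x14x14 and 768x7x7; the tensors `x0`..`x3` below are channels-first stand-ins for `x_downsample[0]`, `[1]`, `[2]` and `[4]`, and the `VGGBlock` and upsampling match the definitions in `swin_transformer_unet_skip_expand_decoder_sys.py`.

```python
import torch
import torch.nn as nn


class VGGBlock(nn.Module):
    """Two 3x3 conv + BatchNorm + ReLU layers, as in the repository's VGGBlock."""
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, middle_channels, 3, padding=1),
            nn.BatchNorm2d(middle_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(middle_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)  # self.my_up
nb = [96, 192, 384, 768]  # nb_filter

# Same channel arithmetic as self.conv0_1 ... self.conv0_3 in the model.
conv0_1 = VGGBlock(nb[0] + nb[1], nb[1], nb[0])
conv1_1 = VGGBlock(nb[1] + nb[2], nb[1], nb[1])
conv0_2 = VGGBlock(nb[0] * 2 + nb[1], nb[0], nb[0])
conv2_1 = VGGBlock(nb[2] + nb[3], nb[2], nb[2])
conv1_2 = VGGBlock(nb[1] * 2 + nb[2], nb[1], nb[1])
conv0_3 = VGGBlock(nb[0] * 3 + nb[1], nb[0], nb[0])

# Dummy encoder outputs standing in for x_downsample[0], [1], [2] and [4].
x0 = torch.randn(1, 96, 56, 56)
x1 = torch.randn(1, 192, 28, 28)
x2 = torch.randn(1, 384, 14, 14)
x3 = torch.randn(1, 768, 7, 7)

# UNet++-style nested fusion: each node concatenates same-resolution features
# with the bilinearly upsampled features from the next-coarser stage.
x0_1 = conv0_1(torch.cat([x0, up(x1)], 1))                # 96 x 56 x 56
x1_1 = conv1_1(torch.cat([x1, up(x2)], 1))                # 192 x 28 x 28
x0_2 = conv0_2(torch.cat([x0, x0_1, up(x1_1)], 1))        # 96 x 56 x 56
x2_1 = conv2_1(torch.cat([x2, up(x3)], 1))                # 384 x 14 x 14 (skip3)
x1_2 = conv1_2(torch.cat([x1, x1_1, up(x2_1)], 1))        # 192 x 28 x 28 (skip2)
x0_3 = conv0_3(torch.cat([x0, x0_1, x0_2, up(x1_2)], 1))  # 96 x 56 x 56 (skip1)

# The fused maps are flattened back to (B, H*W, C) and consumed by
# forward_up_features as the decoder skip connections.
for t in (x2_1, x1_2, x0_3):
    print(t.flatten(2).permute(0, 2, 1).shape)
```

Similarly, this is a small sketch of how the training objective in `trainer.py` combines cross-entropy on the raw logits with the `DiceLoss` from `utils.py` (weighted 0.4 / 0.6); the batch size, image size and `num_classes=2` here are illustrative dummies.

```python
import torch
from torch.nn.modules.loss import CrossEntropyLoss
from utils import DiceLoss  # run from the repository root

num_classes = 2
ce_loss = CrossEntropyLoss()
dice_loss = DiceLoss(num_classes)

outputs = torch.randn(4, num_classes, 224, 224)              # model logits
label_batch = torch.randint(0, num_classes, (4, 224, 224))   # integer masks

loss_ce = ce_loss(outputs, label_batch.long())
loss_dice = dice_loss(outputs, label_batch, softmax=True)    # softmax applied inside DiceLoss
loss = 0.4 * loss_ce + 0.6 * loss_dice
print(loss.item())
```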