├── README.md
├── configs
│   └── config_setting.py
├── dataprepare
│   ├── Prepare_ISIC2017.py
│   └── Prepare_your_dataset.py
├── engine.py
├── loader.py
├── models
│   ├── H_vmunet.py
│   └── vmamba.py
├── results
│   └── Readme.txt
├── test.py
├── train.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # H-vmunet: High-order Vision Mamba UNet for Medical Image Segmentation
4 |
5 | Renkai Wu, Yinghao Liu, Pengchen Liang*, and Qing Chang*
6 |
7 | [DOI: 10.1016/j.neucom.2025.129447](https://doi.org/10.1016/j.neucom.2025.129447)
8 | [arXiv: 2403.13642](https://arxiv.org/abs/2403.13642)
9 |
10 |
11 |
12 | ## News🚀
13 | (2025.01.12) ***The paper has been accepted by Neurocomputing***🔥🔥
14 |
15 | (2024.03.21) ***Model weights have been uploaded for download***🔥
16 |
17 | (2024.03.21) ***The project code has been uploaded***
18 |
19 | (2024.03.20) ***The first edition of our paper has been uploaded to arXiv*** 📃
20 |
21 | **0. Main Environments.**
22 | The environment can be installed by following [VM-UNet](https://github.com/JCruan519/VM-UNet), or with the steps below:
23 | ```
24 | conda create -n vmunet python=3.8
25 | conda activate vmunet
26 | pip install torch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117
27 | pip install packaging
28 | pip install timm==0.4.12
29 | pip install pytest chardet yacs termcolor
30 | pip install submitit tensorboardX
31 | pip install triton==2.0.0
32 | pip install causal_conv1d==1.0.0 # causal_conv1d-1.0.0+cu118torch1.13cxx11abiFALSE-cp38-cp38-linux_x86_64.whl
33 | pip install mamba_ssm==1.0.1  # mamba_ssm-1.0.1+cu118torch1.13cxx11abiFALSE-cp38-cp38-linux_x86_64.whl
34 | pip install scikit-learn matplotlib thop h5py SimpleITK scikit-image medpy yacs
35 | ```
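A quick sanity check that the key packages import correctly (a minimal sketch, assuming a CUDA-capable GPU):
```
import torch
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn  # used by models/vmamba.py
print(torch.__version__, torch.cuda.is_available())  # expect 1.13.0 and True
```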
36 |
37 | **1. Datasets.**
38 |
39 | *A.ISIC2017*
40 | 1- Download the ISIC 2017 training dataset from [this](https://challenge.isic-archive.com/data) link and extract both the training images and the ground-truth folders into `/data/dataset_isic17/`.
41 | 2- Run `Prepare_ISIC2017.py` to prepare the data and divide it into training, validation and test sets; the files it produces are listed below.
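Running the script saves six NumPy arrays; `loader.py` later reads them from the configured `data_path`:
```
data_train.npy  mask_train.npy
data_val.npy    mask_val.npy
data_test.npy   mask_test.npy
```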
42 |
43 | *B.Spleen*
44 | 1- Download the Spleen dataset from [this](http://medicaldecathlon.com/) link.
45 |
46 | *C.CVC-ClinicDB*
47 | 1- Download the CVC-ClinicDB dataset from [this](https://polyp.grand-challenge.org/CVCClinicDB/) link.
48 |
49 | *D. Prepare your own dataset*
50 | 1. Organize your files as follows. Images are 24-bit PNGs; masks are 8-bit PNGs (pixel value 0 for background, 255 for target).
51 | - './your_dataset/'
52 | - images
53 | - 0000.png
54 | - 0001.png
55 | - masks
56 | - 0000.png
57 | - 0001.png
58 | - Prepare_your_dataset.py
59 | 2. In the 'Prepare_your_dataset.py' file, set the desired number of training, validation and test images.
60 | 3. Run 'Prepare_your_dataset.py'. A quick check of the generated arrays is shown below.
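A minimal sanity check of the generated `.npy` files (a sketch, assuming the default 256x256 size and the counts set in the script):
```
import numpy as np

# Shapes follow Prepare_your_dataset.py: images are (N, height, width, 3),
# masks are (N, height, width) with values in [0, 255].
data = np.load('data_train.npy')
mask = np.load('mask_train.npy')
print(data.shape, mask.shape)  # e.g. (1000, 256, 256, 3) (1000, 256, 256)
```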
61 |
62 | **2. Train the H_vmunet.**
63 | ```
64 | python train.py
65 | ```
66 | - After training, the outputs can be found in './results/'
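
Note that training is configured in `configs/config_setting.py`; in particular, `data_path` must point to the folder containing the prepared `.npy` files. The path below is an example, not a shipped default:
```
datasets = 'ISIC2017'
if datasets == 'ISIC2017':
    data_path = './data/dataset_isic17/'  # folder holding data_*.npy and mask_*.npy
```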
67 |
68 | **3. Test the H_vmunet.**
69 | First, set the checkpoint path via 'resume_model' in the test.py file.
70 | ```
71 | python test.py
72 | ```
73 | - After testing, the outputs can be found in './results/'
74 |
75 | **4. Get model weights**
76 |
77 | *A.ISIC2017*
78 | [Google Drive](https://drive.google.com/file/d/10If43saeVW06p9q3oePAL3hOHqRxFoZV/view?usp=sharing)
79 |
80 | *B.Spleen*
81 | [Google Drive](https://drive.google.com/file/d/18aXOv8u-nFIbBdiUwnzHdQ7ELrNIhMu3/view?usp=sharing)
82 |
83 | *C.CVC-ClinicDB*
84 | [Google Drive](https://drive.google.com/file/d/1mG_zOlsz7OuX_qHVmB3mjMeb1GUNgtkP/view?usp=sharing)
85 |
86 |
87 | ## Citation
88 | If you find this repository helpful, please consider citing:
89 | ```
90 | @article{wu2025h,
91 | title={H-vmunet: High-order vision mamba unet for medical image segmentation},
92 | author={Wu, Renkai and Liu, Yinghao and Liang, Pengchen and Chang, Qing},
93 | journal={Neurocomputing},
94 | pages={129447},
95 | year={2025},
96 | publisher={Elsevier}
97 | }
98 | ```
99 | ## Acknowledgement
100 | Thanks to [Vim](https://github.com/hustvl/Vim), [HorNet](https://github.com/raoyongming/HorNet) and [VM-UNet](https://github.com/JCruan519/VM-UNet) for their outstanding work.
101 |
--------------------------------------------------------------------------------
/configs/config_setting.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from utils import *
3 |
4 | from datetime import datetime
5 |
6 | class setting_config:
7 | """
8 | the config of the training settings.
9 | """
10 | network = 'H_vmunet'
11 | model_config = {
12 | 'num_classes': 1,
13 | 'input_channels': 3,
14 | 'c_list': [8,16,32,64,128,256],
15 | 'split_att': 'fc',
16 | 'bridge': True,
17 | 'drop_path_rate':0.4
18 | }
19 |
20 | test_weights = ''
21 |
22 | datasets = 'ISIC2017'
23 | if datasets == 'ISIC2017':
24 | data_path = ''
25 | elif datasets == 'Spleen':
26 | data_path = ''
27 | elif datasets == 'CVC-ClinicDB':
28 | data_path = ''
29 | else:
30 | raise Exception('datasets is not right!')
31 |
32 | criterion = BceDiceLoss()
33 |
34 | num_classes = 1
35 | input_size_h = 256
36 | input_size_w = 256
37 | input_channels = 3
38 | distributed = False
39 | local_rank = -1
40 | num_workers = 0
41 | seed = 42
42 | world_size = None
43 | rank = None
44 | amp = False
45 | batch_size = 8
46 | epochs = 250
47 |
48 | work_dir = 'results/' + network + '_' + datasets + '_' + datetime.now().strftime('%A_%d_%B_%Y_%Hh_%Mm_%Ss') + '/'
49 |
50 | print_interval = 20
51 | val_interval = 30
52 | save_interval = 100
53 | threshold = 0.5
54 |
55 |
56 | opt = 'AdamW'
57 | assert opt in ['Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'ASGD', 'RMSprop', 'Rprop', 'SGD'], 'Unsupported optimizer!'
58 | if opt == 'Adadelta':
59 | lr = 0.01 # default: 1.0 – coefficient that scale delta before it is applied to the parameters
60 | rho = 0.9 # default: 0.9 – coefficient used for computing a running average of squared gradients
61 | eps = 1e-6 # default: 1e-6 – term added to the denominator to improve numerical stability
62 | weight_decay = 0.05 # default: 0 – weight decay (L2 penalty)
63 | elif opt == 'Adagrad':
64 | lr = 0.01 # default: 0.01 – learning rate
65 | lr_decay = 0 # default: 0 – learning rate decay
66 | eps = 1e-10 # default: 1e-10 – term added to the denominator to improve numerical stability
67 | weight_decay = 0.05 # default: 0 – weight decay (L2 penalty)
68 | elif opt == 'Adam':
69 | lr = 0.001 # default: 1e-3 – learning rate
70 | betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square
71 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability
72 | weight_decay = 0.0001 # default: 0 – weight decay (L2 penalty)
73 | amsgrad = False # default: False – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond
74 | elif opt == 'AdamW':
75 | lr = 0.001 # default: 1e-3 – learning rate
76 | betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square
77 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability
78 | weight_decay = 1e-2 # default: 1e-2 – weight decay coefficient
79 | amsgrad = False # default: False – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond
80 | elif opt == 'Adamax':
81 | lr = 2e-3 # default: 2e-3 – learning rate
82 | betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square
83 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability
84 | weight_decay = 0 # default: 0 – weight decay (L2 penalty)
85 | elif opt == 'ASGD':
86 | lr = 0.01 # default: 1e-2 – learning rate
87 | lambd = 1e-4 # default: 1e-4 – decay term
88 | alpha = 0.75 # default: 0.75 – power for eta update
89 | t0 = 1e6 # default: 1e6 – point at which to start averaging
90 | weight_decay = 0 # default: 0 – weight decay
91 | elif opt == 'RMSprop':
92 | lr = 1e-2 # default: 1e-2 – learning rate
93 | momentum = 0 # default: 0 – momentum factor
94 | alpha = 0.99 # default: 0.99 – smoothing constant
95 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability
96 | centered = False # default: False – if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance
97 | weight_decay = 0 # default: 0 – weight decay (L2 penalty)
98 | elif opt == 'Rprop':
99 | lr = 1e-2 # default: 1e-2 – learning rate
100 | etas = (0.5, 1.2) # default: (0.5, 1.2) – pair of (etaminus, etaplus), that are multiplicative increase and decrease factors
101 | step_sizes = (1e-6, 50) # default: (1e-6, 50) – a pair of minimal and maximal allowed step sizes
102 | elif opt == 'SGD':
103 | lr = 0.01 # – learning rate
104 | momentum = 0.9 # default: 0 – momentum factor
105 | weight_decay = 0.05 # default: 0 – weight decay (L2 penalty)
106 | dampening = 0 # default: 0 – dampening for momentum
107 | nesterov = False # default: False – enables Nesterov momentum
108 |
109 | sch = 'CosineAnnealingLR'
110 | if sch == 'StepLR':
111 | step_size = epochs // 5 # – Period of learning rate decay.
112 | gamma = 0.5 # – Multiplicative factor of learning rate decay. Default: 0.1
113 | last_epoch = -1 # – The index of last epoch. Default: -1.
114 | elif sch == 'MultiStepLR':
115 | milestones = [60, 120, 150] # – List of epoch indices. Must be increasing.
116 | gamma = 0.1 # – Multiplicative factor of learning rate decay. Default: 0.1.
117 | last_epoch = -1 # – The index of last epoch. Default: -1.
118 | elif sch == 'ExponentialLR':
119 | gamma = 0.99 # – Multiplicative factor of learning rate decay.
120 | last_epoch = -1 # – The index of last epoch. Default: -1.
121 | elif sch == 'CosineAnnealingLR':
122 | T_max = 50 # – Maximum number of iterations. Cosine function period.
123 | eta_min = 0.00001 # – Minimum learning rate. Default: 0.
124 | last_epoch = -1 # – The index of last epoch. Default: -1.
125 | elif sch == 'ReduceLROnPlateau':
126 | mode = 'min' # – One of min, max. In min mode, lr will be reduced when the quantity monitored has stopped decreasing; in max mode it will be reduced when the quantity monitored has stopped increasing. Default: ‘min’.
127 | factor = 0.1 # – Factor by which the learning rate will be reduced. new_lr = lr * factor. Default: 0.1.
128 | patience = 10 # – Number of epochs with no improvement after which learning rate will be reduced. For example, if patience = 2, then we will ignore the first 2 epochs with no improvement, and will only decrease the LR after the 3rd epoch if the loss still hasn’t improved then. Default: 10.
129 | threshold = 0.0001 # – Threshold for measuring the new optimum, to only focus on significant changes. Default: 1e-4.
130 | threshold_mode = 'rel' # – One of rel, abs. In rel mode, dynamic_threshold = best * ( 1 + threshold ) in ‘max’ mode or best * ( 1 - threshold ) in min mode. In abs mode, dynamic_threshold = best + threshold in max mode or best - threshold in min mode. Default: ‘rel’.
131 | cooldown = 0 # – Number of epochs to wait before resuming normal operation after lr has been reduced. Default: 0.
132 | min_lr = 0 # – A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively. Default: 0.
133 | eps = 1e-08 # – Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is ignored. Default: 1e-8.
134 | elif sch == 'CosineAnnealingWarmRestarts':
135 | T_0 = 50 # – Number of iterations for the first restart.
136 | T_mult = 2 # – A factor increases T_{i} after a restart. Default: 1.
137 | eta_min = 1e-6 # – Minimum learning rate. Default: 0.
138 | last_epoch = -1 # – The index of last epoch. Default: -1.
139 | elif sch == 'WP_MultiStepLR':
140 | warm_up_epochs = 10
141 | gamma = 0.1
142 | milestones = [125, 225]
143 | elif sch == 'WP_CosineLR':
144 | warm_up_epochs = 20
--------------------------------------------------------------------------------
/dataprepare/Prepare_ISIC2017.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Code created on Sat Jun 8 18:15:43 2019
4 | @author: Reza Azad
5 | """
6 |
7 | """
8 | Reminder Created on December 6, 2023.
9 | @author: Renkai Wu
10 | 1. Note that the scipy package needs to be downgraded (##scipy==1.2.1); otherwise, you need to modify the code below.
11 | 2. The script prints the name of each file it processes; if the names do not appear, the output npy files are incorrect.
12 | 3. When Dataset_add is a relative path, it must start with './'.
13 | """
14 |
15 | import h5py
16 | import numpy as np
17 | import scipy.io as sio
18 | import scipy.misc as sc
19 | import glob
20 |
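# scipy.misc.imread/imresize were removed in SciPy >= 1.3, hence the scipy==1.2.1
# note above. The optional shim below is a hedged alternative for newer SciPy
# (an assumption, not part of the original script), mimicking the two calls used
# in this file via Pillow:
if not hasattr(sc, 'imread'):
    from PIL import Image
    def _imread(path):
        return np.asarray(Image.open(path))
    def _imresize(img, size, interp='bilinear', mode=None):
        # size is [height, width(, channels)]; PIL.Image.resize expects (width, height)
        return np.asarray(Image.fromarray(np.uint8(img)).resize((size[1], size[0]), Image.BILINEAR))
    sc.imread, sc.imresize = _imread, _imresize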
21 | # Parameters
22 | height = 256
23 | width = 256
24 | channels = 3
25 |
26 | ############################################################# Prepare ISIC 2017 data set #################################################
27 | Dataset_add = './ISIC2017/'
28 | Tr_add = 'ISIC2017_Task1-2_Training_Input'
29 |
30 | Tr_list = glob.glob(Dataset_add+ Tr_add+'/*.jpg')
31 | # It contains 2000 training samples
32 | Data_train_2017 = np.zeros([2000, height, width, channels])
33 | Label_train_2017 = np.zeros([2000, height, width])
34 |
35 | print('Reading ISIC 2017')
36 | print(Tr_list)
37 | for idx in range(len(Tr_list)):
38 | print(idx+1)
39 | img = sc.imread(Tr_list[idx])
40 | img = np.double(sc.imresize(img, [height, width, channels], interp='bilinear', mode = 'RGB'))
41 | Data_train_2017[idx, :,:,:] = img
42 |
43 | b = Tr_list[idx]
44 | a = b[0:len(Dataset_add)]
45 | b = b[len(b)-16: len(b)-4]
46 | add = (a+ 'ISIC2017_Task1_Training_GroundTruth/' + b +'_segmentation.png')
47 | img2 = sc.imread(add)
48 | img2 = np.double(sc.imresize(img2, [height, width], interp='bilinear'))
49 | Label_train_2017[idx, :,:] = img2
50 |
51 | print('Reading ISIC 2017 finished')
52 |
53 | ################################################################ Make the train and test sets ########################################
54 | # We consider 1250 samples for training, 150 samples for validation and 600 samples for testing
55 |
56 | Train_img = Data_train_2017[0:1250,:,:,:]
57 | Validation_img = Data_train_2017[1250:1250+150,:,:,:]
58 | Test_img = Data_train_2017[1250+150:2000,:,:,:]
59 |
60 | Train_mask = Label_train_2017[0:1250,:,:]
61 | Validation_mask = Label_train_2017[1250:1250+150,:,:]
62 | Test_mask = Label_train_2017[1250+150:2000,:,:]
63 |
64 |
65 | np.save('data_train', Train_img)
66 | np.save('data_test' , Test_img)
67 | np.save('data_val' , Validation_img)
68 |
69 | np.save('mask_train', Train_mask)
70 | np.save('mask_test' , Test_mask)
71 | np.save('mask_val' , Validation_mask)
72 |
73 |
74 |
--------------------------------------------------------------------------------
/dataprepare/Prepare_your_dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | ##scipy==1.2.1
4 |
5 | import h5py
6 | import numpy as np
7 | import scipy.io as sio
8 | import scipy.misc as sc
9 | import glob
10 |
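# As in Prepare_ISIC2017.py, scipy.misc.imread/imresize require scipy==1.2.1
# (they were removed in SciPy >= 1.3); the Pillow shim shown in that file also works here.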
11 | # Parameters
12 | height = 256 # Enter the image size of the model.
13 | width = 256 # Enter the image size of the model.
14 | channels = 3 # Number of image channels
15 |
16 | train_number = 1000 # Randomly assign the number of images for generating the training set.
17 | val_number = 200 # Randomly assign the number of images for generating the validation set.
18 | test_number = 400 # Randomly assign the number of images for generating the test set.
19 | all = int(train_number) + int(val_number) + int(test_number)
20 |
21 | ############################################################# Prepare your data set #################################################
22 | Tr_list = glob.glob("images"+'/*.png') # Images storage folder. The image type should be 24-bit png format.
23 | # The arrays hold all = train_number + val_number + test_number samples
24 | Data_train_2018 = np.zeros([all, height, width, channels])
25 | Label_train_2018 = np.zeros([all, height, width])
26 |
27 | print('Reading')
28 | print(len(Tr_list))
29 | for idx in range(len(Tr_list)):
30 | print(idx+1)
31 | img = sc.imread(Tr_list[idx])
32 | img = np.double(sc.imresize(img, [height, width, channels], interp='bilinear', mode = 'RGB'))
33 | Data_train_2018[idx, :,:,:] = img
34 |
35 | b = Tr_list[idx]
36 | b = b[len(b)-8: len(b)-4]
37 | add = ("masks/" + b +'.png') # Masks storage folder. The Mask type should be a black and white image of an 8-bit png (0 pixels for the background and 255 pixels for the target).
38 | img2 = sc.imread(add)
39 | img2 = np.double(sc.imresize(img2, [height, width], interp='bilinear'))
40 | Label_train_2018[idx, :,:] = img2
41 |
42 | print('Reading your dataset finished')
43 |
44 | ################################################################ Make the training, validation and test sets ########################################
45 | Train_img = Data_train_2018[0:train_number,:,:,:]
46 | Validation_img = Data_train_2018[train_number:train_number+val_number,:,:,:]
47 | Test_img = Data_train_2018[train_number+val_number:all,:,:,:]
48 |
49 | Train_mask = Label_train_2018[0:train_number,:,:]
50 | Validation_mask = Label_train_2018[train_number:train_number+val_number,:,:]
51 | Test_mask = Label_train_2018[train_number+val_number:all,:,:]
52 |
53 |
54 | np.save('data_train', Train_img)
55 | np.save('data_test' , Test_img)
56 | np.save('data_val' , Validation_img)
57 |
58 | np.save('mask_train', Train_mask)
59 | np.save('mask_test' , Test_mask)
60 | np.save('mask_val' , Validation_mask)
61 |
--------------------------------------------------------------------------------
/engine.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import tqdm
3 | import torch
4 | from torch.cuda.amp import autocast as autocast
5 | from sklearn.metrics import confusion_matrix
6 | from utils import save_imgs
7 |
8 |
9 | def train_one_epoch(train_loader,
10 | model,
11 | criterion,
12 | optimizer,
13 | scheduler,
14 | epoch,
15 | logger,
16 | config,
17 | scaler=None):
18 | '''
19 | train model for one epoch
20 | '''
21 | # switch to train mode
22 | model.train()
23 |
24 | loss_list = []
25 |
26 | for iter, data in enumerate(train_loader):
27 | optimizer.zero_grad()
28 | images, targets = data
29 | images, targets = images.cuda(non_blocking=True).float(), targets.cuda(non_blocking=True).float()
30 | if config.amp:
31 | with autocast():
32 | out = model(images)
33 | loss = criterion(out, targets)
34 | scaler.scale(loss).backward()
35 | scaler.step(optimizer)
36 | scaler.update()
37 | else:
38 | out = model(images)
39 | loss = criterion(out, targets)
40 | loss.backward()
41 | optimizer.step()
42 |
43 | loss_list.append(loss.item())
44 |
45 | now_lr = optimizer.state_dict()['param_groups'][0]['lr']
46 | if iter % config.print_interval == 0:
47 | log_info = f'train: epoch {epoch}, iter:{iter}, loss: {np.mean(loss_list):.4f}, lr: {now_lr}'
48 | print(log_info)
49 | logger.info(log_info)
50 | scheduler.step()
51 |
52 |
53 | def val_one_epoch(test_loader,
54 | model,
55 | criterion,
56 | epoch,
57 | logger,
58 | config):
59 | # switch to evaluate mode
60 | model.eval()
61 | preds = []
62 | gts = []
63 | loss_list = []
64 | with torch.no_grad():
65 | for data in tqdm(test_loader):
66 | img, msk = data
67 | img, msk = img.cuda(non_blocking=True).float(), msk.cuda(non_blocking=True).float()
68 | out = model(img)
69 | loss = criterion(out, msk)
70 | loss_list.append(loss.item())
71 | gts.append(msk.squeeze(1).cpu().detach().numpy())
72 | if type(out) is tuple:
73 | out = out[0]
74 | out = out.squeeze(1).cpu().detach().numpy()
75 | preds.append(out)
76 |
77 | if epoch % config.val_interval == 0:
78 | preds = np.array(preds).reshape(-1)
79 | gts = np.array(gts).reshape(-1)
80 |
81 | y_pre = np.where(preds>=config.threshold, 1, 0)
82 | y_true = np.where(gts>=0.5, 1, 0)
83 |
84 | confusion = confusion_matrix(y_true, y_pre)
85 | TN, FP, FN, TP = confusion[0,0], confusion[0,1], confusion[1,0], confusion[1,1]
86 |
87 | accuracy = float(TN + TP) / float(np.sum(confusion)) if float(np.sum(confusion)) != 0 else 0
88 | sensitivity = float(TP) / float(TP + FN) if float(TP + FN) != 0 else 0
89 | specificity = float(TN) / float(TN + FP) if float(TN + FP) != 0 else 0
90 | f1_or_dsc = float(2 * TP) / float(2 * TP + FP + FN) if float(2 * TP + FP + FN) != 0 else 0
91 | miou = float(TP) / float(TP + FP + FN) if float(TP + FP + FN) != 0 else 0
92 |
93 | log_info = f'val epoch: {epoch}, loss: {np.mean(loss_list):.4f}, miou: {miou}, f1_or_dsc: {f1_or_dsc}, accuracy: {accuracy}, \
94 | specificity: {specificity}, sensitivity: {sensitivity}, confusion_matrix: {confusion}'
95 | print(log_info)
96 | logger.info(log_info)
97 |
98 | else:
99 | log_info = f'val epoch: {epoch}, loss: {np.mean(loss_list):.4f}'
100 | print(log_info)
101 | logger.info(log_info)
102 |
103 | return np.mean(loss_list)
104 |
105 |
106 | def test_one_epoch(test_loader,
107 | model,
108 | criterion,
109 | logger,
110 | config,
111 | test_data_name=None):
112 | # switch to evaluate mode
113 | model.eval()
114 | preds = []
115 | gts = []
116 | loss_list = []
117 | with torch.no_grad():
118 | for i, data in enumerate(tqdm(test_loader)):
119 | img, msk = data
120 | img, msk = img.cuda(non_blocking=True).float(), msk.cuda(non_blocking=True).float()
121 | out = model(img)
122 | loss = criterion(out, msk)
123 | loss_list.append(loss.item())
124 | msk = msk.squeeze(1).cpu().detach().numpy()
125 | gts.append(msk)
126 | if type(out) is tuple:
127 | out = out[0]
128 | out = out.squeeze(1).cpu().detach().numpy()
129 | preds.append(out)
130 | save_imgs(img, msk, out, i, config.work_dir + 'outputs/', config.datasets, config.threshold, test_data_name=test_data_name)
131 |
132 | preds = np.array(preds).reshape(-1)
133 | gts = np.array(gts).reshape(-1)
134 |
135 | y_pre = np.where(preds>=config.threshold, 1, 0)
136 | y_true = np.where(gts>=0.5, 1, 0)
137 |
138 | confusion = confusion_matrix(y_true, y_pre)
139 | TN, FP, FN, TP = confusion[0,0], confusion[0,1], confusion[1,0], confusion[1,1]
140 |
141 | accuracy = float(TN + TP) / float(np.sum(confusion)) if float(np.sum(confusion)) != 0 else 0
142 | sensitivity = float(TP) / float(TP + FN) if float(TP + FN) != 0 else 0
143 | specificity = float(TN) / float(TN + FP) if float(TN + FP) != 0 else 0
144 | f1_or_dsc = float(2 * TP) / float(2 * TP + FP + FN) if float(2 * TP + FP + FN) != 0 else 0
145 | miou = float(TP) / float(TP + FP + FN) if float(TP + FP + FN) != 0 else 0
146 |
147 | if test_data_name is not None:
148 | log_info = f'test_datasets_name: {test_data_name}'
149 | print(log_info)
150 | logger.info(log_info)
151 | log_info = f'test of best model, loss: {np.mean(loss_list):.4f},miou: {miou}, f1_or_dsc: {f1_or_dsc}, accuracy: {accuracy}, \
152 | specificity: {specificity}, sensitivity: {sensitivity}, confusion_matrix: {confusion}'
153 | print(log_info)
154 | logger.info(log_info)
155 |
156 | return np.mean(loss_list)
157 |
--------------------------------------------------------------------------------
/loader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | import torch
3 | import numpy as np
4 | import random
5 | import os
6 | from PIL import Image
7 | from einops.layers.torch import Rearrange
8 | from scipy.ndimage import binary_dilation
9 | from torch.utils.data import Dataset
10 | from torchvision import transforms
11 | from scipy import ndimage
12 | from utils import *
13 |
14 |
15 | # ===== normalize over the dataset
16 | def dataset_normalized(imgs):
17 | imgs_normalized = np.empty(imgs.shape)
18 | imgs_std = np.std(imgs)
19 | imgs_mean = np.mean(imgs)
20 | imgs_normalized = (imgs-imgs_mean)/imgs_std
21 | for i in range(imgs.shape[0]):
22 | imgs_normalized[i] = ((imgs_normalized[i] - np.min(imgs_normalized[i])) / (np.max(imgs_normalized[i])-np.min(imgs_normalized[i])))*255
23 | return imgs_normalized
24 |
25 |
26 | ## Temporary
27 | class isic_loader(Dataset):
28 | """ dataset class for Brats datasets
29 | """
30 | def __init__(self, path_Data, train = True, Test = False):
31 | super(isic_loader, self).__init__()
32 | self.train = train
33 | if train:
34 | self.data = np.load(path_Data+'data_train.npy')
35 | self.mask = np.load(path_Data+'mask_train.npy')
36 | else:
37 | if Test:
38 | self.data = np.load(path_Data+'data_test.npy')
39 | self.mask = np.load(path_Data+'mask_test.npy')
40 | else:
41 | self.data = np.load(path_Data+'data_val.npy')
42 | self.mask = np.load(path_Data+'mask_val.npy')
43 |
44 | self.data = dataset_normalized(self.data)
45 | self.mask = np.expand_dims(self.mask, axis=3)
46 | self.mask = self.mask/255.
47 |
48 | def __getitem__(self, indx):
49 | img = self.data[indx]
50 | seg = self.mask[indx]
51 | if self.train:
52 | if random.random() > 0.5:
53 | img, seg = self.random_rot_flip(img, seg)
54 | if random.random() > 0.5:
55 | img, seg = self.random_rotate(img, seg)
56 |
57 | seg = torch.tensor(seg.copy())
58 | img = torch.tensor(img.copy())
59 | img = img.permute( 2, 0, 1)
60 | seg = seg.permute( 2, 0, 1)
61 |
62 | return img, seg
63 |
64 | def random_rot_flip(self,image, label):
65 | k = np.random.randint(0, 4)
66 | image = np.rot90(image, k)
67 | label = np.rot90(label, k)
68 | axis = np.random.randint(0, 2)
69 | image = np.flip(image, axis=axis).copy()
70 | label = np.flip(label, axis=axis).copy()
71 | return image, label
72 |
73 | def random_rotate(self,image, label):
74 | angle = np.random.randint(20, 80)
75 | image = ndimage.rotate(image, angle, order=0, reshape=False)
76 | label = ndimage.rotate(label, angle, order=0, reshape=False)
77 | return image, label
78 |
79 |
80 |
81 | def __len__(self):
82 | return len(self.data)
83 |
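# Minimal usage sketch (the path is an assumption; point it at the folder
# holding the prepared .npy files):
# train_dataset = isic_loader(path_Data='./data/dataset_isic17/', train=True)
# train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)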
--------------------------------------------------------------------------------
/models/H_vmunet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from functools import partial
5 | from timm.models.layers import trunc_normal_, DropPath
6 | import os
7 | import sys
8 | import torch.fft
9 | import math
10 | from .vmamba import SS2D
11 |
12 | import traceback
13 |
14 | class DepthWiseConv2d(nn.Module):
15 | def __init__(self, dim_in, dim_out, kernel_size=3, padding=1, stride=1, dilation=1):
16 | super().__init__()
17 |
18 | self.conv1 = nn.Conv2d(dim_in, dim_in, kernel_size=kernel_size, padding=padding,
19 | stride=stride, dilation=dilation, groups=dim_in)
20 | self.norm_layer = nn.GroupNorm(4, dim_in)
21 | self.conv2 = nn.Conv2d(dim_in, dim_out, kernel_size=1)
22 |
23 | def forward(self, x):
24 | return self.conv2(self.norm_layer(self.conv1(x)))
25 |
26 | if 'DWCONV_IMPL' in os.environ:
27 | try:
28 | sys.path.append(os.environ['DWCONV_IMPL'])
29 | from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM
30 | def get_dwconv(dim, kernel, bias):
31 | return DepthWiseConv2dImplicitGEMM(dim, kernel, bias)
32 | # print('Using Megvii large kernel dw conv impl')
33 | except Exception:
34 | print(traceback.format_exc())
35 | def get_dwconv(dim, kernel, bias):
36 | return nn.Conv2d(dim, dim, kernel_size=kernel, padding=(kernel-1)//2 ,bias=bias, groups=dim)
37 |
38 | # print('[fail to use Megvii Large kernel] Using PyTorch large kernel dw conv impl')
39 | else:
40 | def get_dwconv(dim, kernel, bias):
41 | return nn.Conv2d(dim, dim, kernel_size=kernel, padding=(kernel-1)//2 ,bias=bias, groups=dim)
42 |
43 | # print('Using PyTorch large kernel dw conv impl')
44 |
45 | class H_SS2D(nn.Module):
46 | def __init__(self, dim, order=5, gflayer=None, h=14, w=8, s=1.0,d_state=16):
47 | super().__init__()
48 | self.order = order
49 | self.dims = [dim // 2 ** i for i in range(order)]
50 | self.dims.reverse()
51 | self.proj_in = nn.Conv2d(dim, 2*dim, 1)
52 |
53 | if gflayer is None:
54 | self.dwconv = get_dwconv(sum(self.dims), 7, True)
55 | else:
56 | self.dwconv = gflayer(sum(self.dims), h=h, w=w)
57 |
58 | self.proj_out = nn.Conv2d(dim, dim, 1)
59 |
60 | self.pws = nn.ModuleList(
61 | [nn.Conv2d(self.dims[i], self.dims[i+1], 1) for i in range(order-1)]
62 | )
63 |
64 | num = len(self.dims)
65 | if num == 2:
66 | self.ss2d_1 = SS2D(d_model=self.dims[1], dropout=0, d_state=16)
67 | elif num == 3 :
68 | self.ss2d_1 = SS2D(d_model=self.dims[1], dropout=0, d_state=16)
69 | self.ss2d_2 = SS2D(d_model=self.dims[2], dropout=0, d_state=16)
70 | elif num == 4 :
71 | self.ss2d_1 = SS2D(d_model=self.dims[1], dropout=0, d_state=16)
72 | self.ss2d_2 = SS2D(d_model=self.dims[2], dropout=0, d_state=16)
73 | self.ss2d_3 = SS2D(d_model=self.dims[3], dropout=0, d_state=16)
74 | elif num == 5 :
75 | self.ss2d_1 = SS2D(d_model=self.dims[1], dropout=0, d_state=16)
76 | self.ss2d_2 = SS2D(d_model=self.dims[2], dropout=0, d_state=16)
77 | self.ss2d_3 = SS2D(d_model=self.dims[3], dropout=0, d_state=16)
78 | self.ss2d_4 = SS2D(d_model=self.dims[4], dropout=0, d_state=16)
79 |
80 | self.ss2d_in = SS2D(d_model=self.dims[0], dropout=0, d_state=16)
81 |
82 | self.scale = s
83 |
84 | print('[H_SS2D]', order, 'order with dims=', self.dims, 'scale=%.4f'%self.scale)
85 |
86 |
87 | def forward(self, x, mask=None, dummy=False):
88 | B, C, H, W = x.shape
89 |
90 | fused_x = self.proj_in(x)
91 | pwa, abc = torch.split(fused_x, (self.dims[0], sum(self.dims)), dim=1)
92 |
93 | dw_abc = self.dwconv(abc) * self.scale
94 |
95 | dw_list = torch.split(dw_abc, self.dims, dim=1)
96 | x = pwa * dw_list[0]
97 | x = x.permute(0, 2, 3, 1)
98 | x = self.ss2d_in(x)
99 | x = x.permute(0, 3, 1, 2)
100 |
101 | for i in range(self.order -1):
102 | x = self.pws[i](x) * dw_list[i+1]
103 | if i == 0 :
104 | x = x.permute(0, 2, 3, 1)
105 | x = self.ss2d_1(x)
106 | x = x.permute(0, 3, 1, 2)
107 | elif i == 1 :
108 | x = x.permute(0, 2, 3, 1)
109 | x = self.ss2d_2(x)
110 | x = x.permute(0, 3, 1, 2)
111 | elif i == 2 :
112 | x = x.permute(0, 2, 3, 1)
113 | x = self.ss2d_3(x)
114 | x = x.permute(0, 3, 1, 2)
115 | elif i == 3 :
116 | x = x.permute(0, 2, 3, 1)
117 | x = self.ss2d_4(x)
118 | x = x.permute(0, 3, 1, 2)
119 |
120 | x = self.proj_out(x)
121 |
122 | return x
123 |
124 | class Block(nn.Module):
125 | r""" H_VSS Block
126 | """
127 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, H_SS2D=H_SS2D):
128 | super().__init__()
129 |
130 | self.norm1 = LayerNorm(dim, eps=1e-6, data_format='channels_first')
131 | self.H_SS2D = H_SS2D(dim)
132 | self.norm2 = LayerNorm(dim, eps=1e-6)
133 | self.pwconv1 = nn.Linear(dim, 4 * dim)
134 | self.act = nn.GELU()
135 | self.pwconv2 = nn.Linear(4 * dim, dim)
136 |
137 | self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones(dim),
138 | requires_grad=True) if layer_scale_init_value > 0 else None
139 |
140 | self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
141 | requires_grad=True) if layer_scale_init_value > 0 else None
142 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
143 |
144 | def forward(self, x):
145 | B, C, H, W = x.shape
146 | if self.gamma1 is not None:
147 | gamma1 = self.gamma1.view(C, 1, 1)
148 | else:
149 | gamma1 = 1
150 | x = x + self.drop_path(gamma1 * self.H_SS2D(self.norm1(x)))
151 |
152 | input = x
153 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
154 | x = self.norm2(x)
155 | x = self.pwconv1(x)
156 | x = self.act(x)
157 | x = self.pwconv2(x)
158 | if self.gamma2 is not None:
159 | x = self.gamma2 * x
160 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
161 |
162 | x = input + self.drop_path(x)
163 | return x
164 |
165 |
166 | class Channel_Att_Bridge(nn.Module):
167 | def __init__(self, c_list, split_att='fc'):
168 | super().__init__()
169 | c_list_sum = sum(c_list) - c_list[-1]
170 | self.split_att = split_att
171 | self.avgpool = nn.AdaptiveAvgPool2d(1)
172 | self.get_all_att = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False)
173 | self.att1 = nn.Linear(c_list_sum, c_list[0]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[0], 1)
174 | self.att2 = nn.Linear(c_list_sum, c_list[1]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[1], 1)
175 | self.att3 = nn.Linear(c_list_sum, c_list[2]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[2], 1)
176 | self.att4 = nn.Linear(c_list_sum, c_list[3]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[3], 1)
177 | self.att5 = nn.Linear(c_list_sum, c_list[4]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[4], 1)
178 | self.sigmoid = nn.Sigmoid()
179 |
180 | def forward(self, t1, t2, t3, t4, t5):
181 | att = torch.cat((self.avgpool(t1),
182 | self.avgpool(t2),
183 | self.avgpool(t3),
184 | self.avgpool(t4),
185 | self.avgpool(t5)), dim=1)
186 | att = self.get_all_att(att.squeeze(-1).transpose(-1, -2))
187 | if self.split_att != 'fc':
188 | att = att.transpose(-1, -2)
189 | att1 = self.sigmoid(self.att1(att))
190 | att2 = self.sigmoid(self.att2(att))
191 | att3 = self.sigmoid(self.att3(att))
192 | att4 = self.sigmoid(self.att4(att))
193 | att5 = self.sigmoid(self.att5(att))
194 | if self.split_att == 'fc':
195 | att1 = att1.transpose(-1, -2).unsqueeze(-1).expand_as(t1)
196 | att2 = att2.transpose(-1, -2).unsqueeze(-1).expand_as(t2)
197 | att3 = att3.transpose(-1, -2).unsqueeze(-1).expand_as(t3)
198 | att4 = att4.transpose(-1, -2).unsqueeze(-1).expand_as(t4)
199 | att5 = att5.transpose(-1, -2).unsqueeze(-1).expand_as(t5)
200 | else:
201 | att1 = att1.unsqueeze(-1).expand_as(t1)
202 | att2 = att2.unsqueeze(-1).expand_as(t2)
203 | att3 = att3.unsqueeze(-1).expand_as(t3)
204 | att4 = att4.unsqueeze(-1).expand_as(t4)
205 | att5 = att5.unsqueeze(-1).expand_as(t5)
206 |
207 | return att1, att2, att3, att4, att5
208 |
209 |
210 | class Spatial_Att_Bridge(nn.Module):
211 | def __init__(self):
212 | super().__init__()
213 | self.shared_conv2d = nn.Sequential(nn.Conv2d(2, 1, 7, stride=1, padding=9, dilation=3),
214 | nn.Sigmoid())
215 |
216 | def forward(self, t1, t2, t3, t4, t5):
217 | t_list = [t1, t2, t3, t4, t5]
218 | att_list = []
219 | for t in t_list:
220 | avg_out = torch.mean(t, dim=1, keepdim=True)
221 | max_out, _ = torch.max(t, dim=1, keepdim=True)
222 | att = torch.cat([avg_out, max_out], dim=1)
223 | att = self.shared_conv2d(att)
224 | att_list.append(att)
225 | return att_list[0], att_list[1], att_list[2], att_list[3], att_list[4]
226 |
227 |
228 | class SC_Att_Bridge(nn.Module):
229 | def __init__(self, c_list, split_att='fc'):
230 | super().__init__()
231 |
232 | self.catt = Channel_Att_Bridge(c_list, split_att=split_att)
233 | self.satt = Spatial_Att_Bridge()
234 |
235 | def forward(self, t1, t2, t3, t4, t5):
236 | r1, r2, r3, r4, r5 = t1, t2, t3, t4, t5
237 |
238 | satt1, satt2, satt3, satt4, satt5 = self.satt(t1, t2, t3, t4, t5)
239 | t1, t2, t3, t4, t5 = satt1 * t1, satt2 * t2, satt3 * t3, satt4 * t4, satt5 * t5
240 |
241 | r1_, r2_, r3_, r4_, r5_ = t1, t2, t3, t4, t5
242 | t1, t2, t3, t4, t5 = t1 + r1, t2 + r2, t3 + r3, t4 + r4, t5 + r5
243 |
244 | catt1, catt2, catt3, catt4, catt5 = self.catt(t1, t2, t3, t4, t5)
245 | t1, t2, t3, t4, t5 = catt1 * t1, catt2 * t2, catt3 * t3, catt4 * t4, catt5 * t5
246 |
247 | return t1 + r1_, t2 + r2_, t3 + r3_, t4 + r4_, t5 + r5_
248 |
249 |
250 | class H_vmunet(nn.Module):
251 |
252 | def __init__(self, num_classes=1, input_channels=3,layer_scale_init_value=1e-6,H_SS2D=H_SS2D, block=Block,pretrained=None,
253 | use_checkpoint=False, c_list=[8,16,32,64,128,256], depths=[2, 2, 2, 2],drop_path_rate=0.,
254 | split_att='fc', bridge=True):
255 | super().__init__()
256 | self.pretrained = pretrained
257 | self.use_checkpoint = use_checkpoint
258 | self.bridge = bridge
259 |
260 | self.encoder1 = nn.Sequential(
261 | nn.Conv2d(input_channels, c_list[0], 3, stride=1, padding=1),
262 | )
263 | self.encoder2 =nn.Sequential(
264 | nn.Conv2d(c_list[0], c_list[1], 3, stride=1, padding=1),
265 | )
266 |
267 |
268 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
269 |
270 | if not isinstance(H_SS2D, list):
271 | H_SS2D = [partial(H_SS2D, order=2, s=1/3, gflayer=Local_SS2D),
272 | partial(H_SS2D, order=3, s=1/3, gflayer=Local_SS2D),
273 | partial(H_SS2D, order=4, s=1/3, h=24, w=13, gflayer=Local_SS2D),
274 | partial(H_SS2D, order=5, s=1/3, h=12, w=7, gflayer=Local_SS2D)]
275 | else:
276 | H_SS2D = H_SS2D
277 | assert len(H_SS2D) == 4
278 |
279 | if isinstance(H_SS2D[0], str):
280 | H_SS2D = [eval(h) for h in H_SS2D]
281 |
282 | if isinstance(block, str):
283 | block = eval(block)
284 |
285 |
286 |
287 | self.encoder3 = nn.Sequential(
288 | *[block(dim=c_list[1], drop_path=dp_rates[0 + j],
289 | layer_scale_init_value=layer_scale_init_value, H_SS2D=H_SS2D[0]) for j in range(depths[0])],
290 | nn.Conv2d(c_list[1], c_list[2], 3, stride=1, padding=1),
291 | )
292 |
293 | self.encoder4 = nn.Sequential(
294 | *[block(dim=c_list[2], drop_path=dp_rates[2 + j],
295 | layer_scale_init_value=layer_scale_init_value, H_SS2D=H_SS2D[1]) for j in range(depths[1])],
296 | nn.Conv2d(c_list[2], c_list[3], 3, stride=1, padding=1),
297 | )
298 |
299 |
300 | self.encoder5 = nn.Sequential(
301 | *[block(dim=c_list[3], drop_path=dp_rates[4 + j],
302 | layer_scale_init_value=layer_scale_init_value, H_SS2D=H_SS2D[2]) for j in range(depths[2])],
303 | nn.Conv2d(c_list[3], c_list[4], 3, stride=1, padding=1),
304 | )
305 |
306 | self.encoder6 = nn.Sequential(
307 | *[block(dim=c_list[4], drop_path=dp_rates[6 + j],
308 | layer_scale_init_value=layer_scale_init_value, H_SS2D=H_SS2D[3]) for j in range(depths[3])],
309 | nn.Conv2d(c_list[4], c_list[5], 3, stride=1, padding=1),
310 | )
311 |
312 |
313 | if bridge:
314 | self.scab = SC_Att_Bridge(c_list, split_att)
315 | print('SC_Att_Bridge was used')
316 |
317 | self.decoder1 = nn.Sequential(
318 | *[block(dim=c_list[5], drop_path=dp_rates[6],
319 | layer_scale_init_value=layer_scale_init_value, H_SS2D=H_SS2D[3]) for j in range(depths[3])],
320 | nn.Conv2d(c_list[5], c_list[4], 3, stride=1, padding=1),
321 | )
322 |
323 |
324 | self.decoder2 = nn.Sequential(
325 | *[block(dim=c_list[4], drop_path=dp_rates[4+j],
326 | layer_scale_init_value=layer_scale_init_value, H_SS2D=H_SS2D[2]) for j in range(depths[2])],
327 | nn.Conv2d(c_list[4], c_list[3], 3, stride=1, padding=1),
328 | )
329 |
330 | self.decoder3 = nn.Sequential(
331 | *[block(dim=c_list[3], drop_path=dp_rates[2+j],
332 | layer_scale_init_value=layer_scale_init_value, H_SS2D=H_SS2D[1]) for j in range(depths[1])],
333 | nn.Conv2d(c_list[3], c_list[2], 3, stride=1, padding=1),
334 | )
335 |
336 | self.decoder4 = nn.Sequential(
337 | *[block(dim=c_list[2], drop_path=dp_rates[0+j],
338 | layer_scale_init_value=layer_scale_init_value, H_SS2D=H_SS2D[0]) for j in range(depths[0])],
339 | nn.Conv2d(c_list[2], c_list[1], 3, stride=1, padding=1),
340 | )
341 |
342 | self.decoder5 = nn.Sequential(
343 | nn.Conv2d(c_list[1], c_list[0], 3, stride=1, padding=1),
344 | )
345 |
346 | self.ebn1 = nn.GroupNorm(4, c_list[0])
347 | self.ebn2 = nn.GroupNorm(4, c_list[1])
348 | self.ebn3 = nn.GroupNorm(4, c_list[2])
349 | self.ebn4 = nn.GroupNorm(4, c_list[3])
350 | self.ebn5 = nn.GroupNorm(4, c_list[4])
351 | self.dbn1 = nn.GroupNorm(4, c_list[4])
352 | self.dbn2 = nn.GroupNorm(4, c_list[3])
353 | self.dbn3 = nn.GroupNorm(4, c_list[2])
354 | self.dbn4 = nn.GroupNorm(4, c_list[1])
355 | self.dbn5 = nn.GroupNorm(4, c_list[0])
356 |
357 | self.final = nn.Conv2d(c_list[0], num_classes, kernel_size=1)
358 |
359 | self.apply(self._init_weights)
360 |
361 |
362 |
363 | def _init_weights(self, m):
364 | if isinstance(m, nn.Linear):
365 | trunc_normal_(m.weight, std=.02)
366 | if isinstance(m, nn.Linear) and m.bias is not None:
367 | nn.init.constant_(m.bias, 0)
368 | elif isinstance(m, nn.Conv1d):
369 | n = m.kernel_size[0] * m.out_channels
370 | m.weight.data.normal_(0, math.sqrt(2. / n))
371 | elif isinstance(m, nn.Conv2d):
372 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
373 | fan_out //= m.groups
374 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
375 | if m.bias is not None:
376 | m.bias.data.zero_()
377 |
378 |
379 | def forward(self, x):
380 |
381 | out = F.gelu(F.max_pool2d(self.ebn1(self.encoder1(x)),2,2))
382 | t1 = out # b, c0, H/2, W/2
383 |
384 | out = F.gelu(F.max_pool2d(self.ebn2(self.encoder2(out)),2,2))
385 | t2 = out # b, c1, H/4, W/4
386 |
387 | out = F.gelu(F.max_pool2d(self.ebn3(self.encoder3(out)),2,2))
388 | t3 = out # b, c2, H/8, W/8
389 |
390 |
391 | out = F.gelu(F.max_pool2d(self.ebn4(self.encoder4(out)),2,2))
392 | t4 = out # b, c3, H/16, W/16
393 |
394 | out = F.gelu(F.max_pool2d(self.ebn5(self.encoder5(out)),2,2))
395 | t5 = out # b, c4, H/32, W/32
396 |
397 | if self.bridge: t1, t2, t3, t4, t5 = self.scab(t1, t2, t3, t4, t5)
398 |
399 | out = F.gelu(self.encoder6(out)) # b, c5, H/32, W/32
400 |
401 | out5 = F.gelu(self.dbn1(self.decoder1(out))) # b, c4, H/32, W/32
402 | out5 = torch.add(out5, t5) # b, c4, H/32, W/32
403 |
404 | out4 = F.gelu(F.interpolate(self.dbn2(self.decoder2(out5)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c3, H/16, W/16
405 | out4 = torch.add(out4, t4) # b, c3, H/16, W/16
406 |
407 | out3 = F.gelu(F.interpolate(self.dbn3(self.decoder3(out4)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c2, H/8, W/8
408 | out3 = torch.add(out3, t3) # b, c2, H/8, W/8
409 |
410 | out2 = F.gelu(F.interpolate(self.dbn4(self.decoder4(out3)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c1, H/4, W/4
411 | out2 = torch.add(out2, t2) # b, c1, H/4, W/4
412 |
413 | out1 = F.gelu(F.interpolate(self.dbn5(self.decoder5(out2)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c0, H/2, W/2
414 | out1 = torch.add(out1, t1) # b, c0, H/2, W/2
415 |
416 | out0 = F.interpolate(self.final(out1),scale_factor=(2,2),mode ='bilinear',align_corners=True) # b, num_class, H, W
417 |
418 | return torch.sigmoid(out0)
419 |
420 | class LayerNorm(nn.Module):
421 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
422 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
423 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
424 | with shape (batch_size, channels, height, width).
425 | """
426 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
427 | super().__init__()
428 | self.weight = nn.Parameter(torch.ones(normalized_shape))
429 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
430 | self.eps = eps
431 | self.data_format = data_format
432 | if self.data_format not in ["channels_last", "channels_first"]:
433 | raise NotImplementedError
434 | self.normalized_shape = (normalized_shape, )
435 |
436 | def forward(self, x):
437 | if self.data_format == "channels_last":
438 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
439 | elif self.data_format == "channels_first":
440 | u = x.mean(1, keepdim=True)
441 | s = (x - u).pow(2).mean(1, keepdim=True)
442 | x = (x - u) / torch.sqrt(s + self.eps)
443 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
444 | return x
445 |
446 | class Local_SS2D(nn.Module):
447 | def __init__(self, dim, h=14, w=8):
448 | super().__init__()
449 | self.dw = nn.Conv2d(dim // 2, dim // 2, kernel_size=3, padding=1, bias=False, groups=dim // 2)
450 | self.complex_weight = nn.Parameter(torch.randn(dim // 2, h, w, 2, dtype=torch.float32) * 0.02)
451 | trunc_normal_(self.complex_weight, std=.02)
452 | self.pre_norm = LayerNorm(dim, eps=1e-6, data_format='channels_first')
453 | self.post_norm = LayerNorm(dim, eps=1e-6, data_format='channels_first')
454 |
455 | self.SS2D = SS2D(d_model=dim // 2, dropout=0, d_state=16)
456 |
457 |
458 | def forward(self, x):
459 | x = self.pre_norm(x)
460 | x1, x2 = torch.chunk(x, 2, dim=1)
461 | x1 = self.dw(x1)
462 |
463 | B, C, a, b = x2.shape
464 |
465 | x2 = x2.permute(0, 2, 3, 1)
466 |
467 | x2 = self.SS2D(x2)
468 |
469 | x2 = x2.permute(0, 3, 1, 2)
470 |
471 | x = torch.cat([x1.unsqueeze(2), x2.unsqueeze(2)], dim=2).reshape(B, 2 * C, a, b)
472 | x = self.post_norm(x)
473 | return x
474 |
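# Minimal smoke test (a sketch, assuming a CUDA GPU and working mamba_ssm kernels,
# since SS2D's selective scan runs on the GPU):
if __name__ == '__main__':
    net = H_vmunet(num_classes=1, input_channels=3, c_list=[8, 16, 32, 64, 128, 256]).cuda()
    x = torch.randn(1, 3, 256, 256).cuda()
    print(net(x).shape)  # expected: torch.Size([1, 1, 256, 256])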
--------------------------------------------------------------------------------
/models/vmamba.py:
--------------------------------------------------------------------------------
1 | import time
2 | import math
3 | from functools import partial
4 | from typing import Optional, Callable
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.utils.checkpoint as checkpoint
10 | from einops import rearrange, repeat
11 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
12 | try:
13 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
14 | except:
15 | pass
16 |
17 | # an alternative for mamba_ssm (in which causal_conv1d is needed)
18 | try:
19 | from selective_scan import selective_scan_fn as selective_scan_fn_v1
20 | from selective_scan import selective_scan_ref as selective_scan_ref_v1
21 | except:
22 | pass
23 |
24 | DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"
25 |
26 |
27 | def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
28 | """
29 | u: r(B D L)
30 | delta: r(B D L)
31 | A: r(D N)
32 | B: r(B N L)
33 | C: r(B N L)
34 | D: r(D)
35 | z: r(B D L)
36 | delta_bias: r(D), fp32
37 |
38 | ignores:
39 | [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
40 | """
41 | import numpy as np
42 |
43 | # fvcore.nn.jit_handles
44 | def get_flops_einsum(input_shapes, equation):
45 | np_arrs = [np.zeros(s) for s in input_shapes]
46 | optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
47 | for line in optim.split("\n"):
48 | if "optimized flop" in line.lower():
49 | # divided by 2 because we count MAC (multiply-add counted as one flop)
50 | flop = float(np.floor(float(line.split(":")[-1]) / 2))
51 | return flop
52 |
53 |
54 | assert not with_complex
55 |
56 | flops = 0 # below code flops = 0
57 | if False:
58 | ...
59 | """
60 | dtype_in = u.dtype
61 | u = u.float()
62 | delta = delta.float()
63 | if delta_bias is not None:
64 | delta = delta + delta_bias[..., None].float()
65 | if delta_softplus:
66 | delta = F.softplus(delta)
67 | batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
68 | is_variable_B = B.dim() >= 3
69 | is_variable_C = C.dim() >= 3
70 | if A.is_complex():
71 | if is_variable_B:
72 | B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
73 | if is_variable_C:
74 | C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
75 | else:
76 | B = B.float()
77 | C = C.float()
78 | x = A.new_zeros((batch, dim, dstate))
79 | ys = []
80 | """
81 |
82 | flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")
83 | if with_Group:
84 | flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln")
85 | else:
86 | flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln")
87 | if False:
88 | ...
89 | """
90 | deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
91 | if not is_variable_B:
92 | deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
93 | else:
94 | if B.dim() == 3:
95 | deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
96 | else:
97 | B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
98 | deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
99 | if is_variable_C and C.dim() == 4:
100 | C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
101 | last_state = None
102 | """
103 |
104 | in_for_flops = B * D * N
105 | if with_Group:
106 | in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd")
107 | else:
108 | in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd")
109 | flops += L * in_for_flops
110 | if False:
111 | ...
112 | """
113 | for i in range(u.shape[2]):
114 | x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
115 | if not is_variable_C:
116 | y = torch.einsum('bdn,dn->bd', x, C)
117 | else:
118 | if C.dim() == 3:
119 | y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
120 | else:
121 | y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
122 | if i == u.shape[2] - 1:
123 | last_state = x
124 | if y.is_complex():
125 | y = y.real * 2
126 | ys.append(y)
127 | y = torch.stack(ys, dim=2) # (batch dim L)
128 | """
129 |
130 | if with_D:
131 | flops += B * D * L
132 | if with_Z:
133 | flops += B * D * L
134 | if False:
135 | ...
136 | """
137 | out = y if D is None else y + u * rearrange(D, "d -> d 1")
138 | if z is not None:
139 | out = out * F.silu(z)
140 | out = out.to(dtype=dtype_in)
141 | """
142 |
143 | return flops
144 |
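# Example usage (illustrative values): estimated FLOPs of one selective scan over
# L=256 tokens with D=768 channels and N=16 states per channel:
# flops = flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False)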
145 |
146 | class PatchEmbed2D(nn.Module):
147 | r""" Image to Patch Embedding
148 | Args:
149 | patch_size (int): Patch token size. Default: 4.
150 | in_chans (int): Number of input image channels. Default: 3.
151 | embed_dim (int): Number of linear projection output channels. Default: 96.
152 | norm_layer (nn.Module, optional): Normalization layer. Default: None
153 | """
154 | def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs):
155 | super().__init__()
156 | if isinstance(patch_size, int):
157 | patch_size = (patch_size, patch_size)
158 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
159 | if norm_layer is not None:
160 | self.norm = norm_layer(embed_dim)
161 | else:
162 | self.norm = None
163 |
164 | def forward(self, x):
165 | x = self.proj(x).permute(0, 2, 3, 1)
166 | if self.norm is not None:
167 | x = self.norm(x)
168 | return x
169 |
170 |
171 | class PatchMerging2D(nn.Module):
172 | r""" Patch Merging Layer.
173 | Args:
174 | input_resolution (tuple[int]): Resolution of input feature.
175 | dim (int): Number of input channels.
176 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
177 | """
178 |
179 | def __init__(self, dim, norm_layer=nn.LayerNorm):
180 | super().__init__()
181 | self.dim = dim
182 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
183 | self.norm = norm_layer(4 * dim)
184 |
185 | def forward(self, x):
186 | B, H, W, C = x.shape
187 |
188 | SHAPE_FIX = [-1, -1]
189 | if (W % 2 != 0) or (H % 2 != 0):
190 | print(f"Warning, x.shape {x.shape} is not match even ===========", flush=True)
191 | SHAPE_FIX[0] = H // 2
192 | SHAPE_FIX[1] = W // 2
193 |
194 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
195 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
196 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
197 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
198 |
199 | if SHAPE_FIX[0] > 0:
200 | x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
201 | x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
202 | x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
203 | x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
204 |
205 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
206 | x = x.view(B, H//2, W//2, 4 * C) # B H/2*W/2 4*C
207 |
208 | x = self.norm(x)
209 | x = self.reduction(x)
210 |
211 | return x
212 |
213 |
214 | class PatchExpand2D(nn.Module):
215 | def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm):
216 | super().__init__()
217 | self.dim = dim*2
218 | self.dim_scale = dim_scale
219 | self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False)
220 | self.norm = norm_layer(self.dim // dim_scale)
221 |
222 | def forward(self, x):
223 | B, H, W, C = x.shape
224 | x = self.expand(x)
225 |
226 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale)
227 | x= self.norm(x)
228 |
229 | return x
230 |
231 |
232 | class Final_PatchExpand2D(nn.Module):
233 | def __init__(self, dim, dim_scale=4, norm_layer=nn.LayerNorm):
234 | super().__init__()
235 | self.dim = dim
236 | self.dim_scale = dim_scale
237 | self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False)
238 | self.norm = norm_layer(self.dim // dim_scale)
239 |
240 | def forward(self, x):
241 | B, H, W, C = x.shape
242 | x = self.expand(x)
243 |
244 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale)
245 | x= self.norm(x)
246 |
247 | return x
248 |
249 |
250 | class SS2D(nn.Module):
251 | def __init__(
252 | self,
253 | d_model,
254 | d_state=16,
255 | # d_state="auto", # 20240109
256 | d_conv=3,
257 | expand=2,
258 | dt_rank="auto",
259 | dt_min=0.001,
260 | dt_max=0.1,
261 | dt_init="random",
262 | dt_scale=1.0,
263 | dt_init_floor=1e-4,
264 | dropout=0.,
265 | conv_bias=True,
266 | bias=False,
267 | device=None,
268 | dtype=None,
269 | **kwargs,
270 | ):
271 | factory_kwargs = {"device": device, "dtype": dtype}
272 | super().__init__()
273 | self.d_model = d_model
274 | self.d_state = d_state
275 | # self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_model # 20240109
276 | self.d_conv = d_conv
277 | self.expand = expand
278 | self.d_inner = int(self.expand * self.d_model)
279 | self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
280 |
281 | self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
282 | self.conv2d = nn.Conv2d(
283 | in_channels=self.d_inner,
284 | out_channels=self.d_inner,
285 | groups=self.d_inner,
286 | bias=conv_bias,
287 | kernel_size=d_conv,
288 | padding=(d_conv - 1) // 2,
289 | **factory_kwargs,
290 | )
291 | self.act = nn.SiLU()
292 |
293 | self.x_proj = (
294 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
295 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
296 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
297 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
298 | )
299 | self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=4, N, inner)
300 | del self.x_proj
301 |
302 | self.dt_projs = (
303 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
304 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
305 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
306 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
307 | )
308 | self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=4, inner, rank)
309 | self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=4, inner)
310 | del self.dt_projs
311 |
312 | self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True) # (K=4, D, N)
313 | self.Ds = self.D_init(self.d_inner, copies=4, merge=True) # (K=4, D)
314 |
315 | # self.selective_scan = selective_scan_fn
316 | self.forward_core = self.forward_corev0
317 |
318 | self.out_norm = nn.LayerNorm(self.d_inner)
319 | self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
320 | self.dropout = nn.Dropout(dropout) if dropout > 0. else None
321 |
322 | @staticmethod
323 | def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs):
324 | dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)
325 |
326 | # Initialize special dt projection to preserve variance at initialization
327 | dt_init_std = dt_rank**-0.5 * dt_scale
328 | if dt_init == "constant":
329 | nn.init.constant_(dt_proj.weight, dt_init_std)
330 | elif dt_init == "random":
331 | nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
332 | else:
333 | raise NotImplementedError
334 |
335 | # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
336 | dt = torch.exp(
337 | torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
338 | + math.log(dt_min)
339 | ).clamp(min=dt_init_floor)
340 | # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
341 | inv_dt = dt + torch.log(-torch.expm1(-dt))
342 | with torch.no_grad():
343 | dt_proj.bias.copy_(inv_dt)
344 | # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
345 | dt_proj.bias._no_reinit = True
346 |
347 | return dt_proj
348 |
349 | @staticmethod
350 | def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
351 | # S4D real initialization
352 | A = repeat(
353 | torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
354 | "n -> d n",
355 | d=d_inner,
356 | ).contiguous()
357 | A_log = torch.log(A) # Keep A_log in fp32
358 | if copies > 1:
359 | A_log = repeat(A_log, "d n -> r d n", r=copies)
360 | if merge:
361 | A_log = A_log.flatten(0, 1)
362 | A_log = nn.Parameter(A_log)
363 | A_log._no_weight_decay = True
364 | return A_log
365 |
366 | @staticmethod
367 | def D_init(d_inner, copies=1, device=None, merge=True):
368 | # D "skip" parameter
369 | D = torch.ones(d_inner, device=device)
370 | if copies > 1:
371 | D = repeat(D, "n1 -> r n1", r=copies)
372 | if merge:
373 | D = D.flatten(0, 1)
374 | D = nn.Parameter(D) # Keep in fp32
375 | D._no_weight_decay = True
376 | return D
377 |
378 | def forward_corev0(self, x: torch.Tensor):
379 | self.selective_scan = selective_scan_fn
380 |
381 | B, C, H, W = x.shape
382 | L = H * W
383 | K = 4
384 |
385 | x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
386 | xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)
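 | # Editor's note: x_hwwh holds the row-major and column-major unrollings of
 | # the feature map; flipping both along L adds their reversed counterparts,
 | # giving the K=4 cross-scan directions consumed by the selective scan below.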
387 |
388 | x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
389 | # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
390 | dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
391 | dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
392 | # dts = dts + self.dt_projs_bias.view(1, K, -1, 1)
393 |
394 | xs = xs.float().view(B, -1, L) # (b, k * d, l)
395 | dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
396 | Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l)
397 | Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l)
398 | Ds = self.Ds.float().view(-1) # (k * d)
399 | As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state)
400 | dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)
401 |
402 | out_y = self.selective_scan(
403 | xs, dts,
404 | As, Bs, Cs, Ds, z=None,
405 | delta_bias=dt_projs_bias,
406 | delta_softplus=True,
407 | return_last_state=False,
408 | ).view(B, K, -1, L)
409 | assert out_y.dtype == torch.float
410 |
411 | inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
412 | wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
413 | invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
414 |
415 | return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y
416 |
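 | # Editor's note: the four returned sequences correspond to the four scan
 | # directions; the reversed scans are flipped back and the column-major scans
 | # transposed back to row-major before all four are summed in forward().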
417 | # an alternative to forward_corev0
418 | def forward_corev1(self, x: torch.Tensor):
419 | self.selective_scan = selective_scan_fn_v1
420 |
421 | B, C, H, W = x.shape
422 | L = H * W
423 | K = 4
424 |
425 | x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
426 | xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)
427 |
428 | x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
429 | # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
430 | dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
431 | dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
432 | # dts = dts + self.dt_projs_bias.view(1, K, -1, 1)
433 |
434 | xs = xs.float().view(B, -1, L) # (b, k * d, l)
435 | dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
436 | Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l)
437 | Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l)
438 | Ds = self.Ds.float().view(-1) # (k * d)
439 | As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state)
440 | dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)
441 |
442 | out_y = self.selective_scan(
443 | xs, dts,
444 | As, Bs, Cs, Ds,
445 | delta_bias=dt_projs_bias,
446 | delta_softplus=True,
447 | ).view(B, K, -1, L)
448 | assert out_y.dtype == torch.float
449 |
450 | inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
451 | wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
452 | invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
453 |
454 | return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y
455 |
456 | def forward(self, x: torch.Tensor, **kwargs):
457 | B, H, W, C = x.shape
458 |
459 | xz = self.in_proj(x)
460 | x, z = xz.chunk(2, dim=-1) # (b, h, w, d)
461 |
462 | x = x.permute(0, 3, 1, 2).contiguous()
463 | x = self.act(self.conv2d(x)) # (b, d, h, w)
464 | y1, y2, y3, y4 = self.forward_core(x)
465 | assert y1.dtype == torch.float32
466 | y = y1 + y2 + y3 + y4
467 | y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
468 | y = self.out_norm(y)
469 | y = y * F.silu(z)
470 | out = self.out_proj(y)
471 | if self.dropout is not None:
472 | out = self.dropout(out)
473 | return out
474 |
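 | # Editor's sketch (hypothetical usage; assumes the CUDA selective-scan kernels
 | # from mamba_ssm are installed and a GPU is available):
 | #   m = SS2D(d_model=96, d_state=16).cuda()
 | #   x = torch.randn(2, 32, 32, 96).cuda()  # channels-last (B, H, W, C)
 | #   y = m(x)                               # shape preserved: (2, 32, 32, 96)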
475 |
476 | class VSSBlock(nn.Module):
477 | def __init__(
478 | self,
479 | hidden_dim: int = 0,
480 | drop_path: float = 0,
481 | norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
482 | attn_drop_rate: float = 0,
483 | d_state: int = 16,
484 | **kwargs,
485 | ):
486 | super().__init__()
487 | self.ln_1 = norm_layer(hidden_dim)
488 | self.self_attention = SS2D(d_model=hidden_dim, dropout=attn_drop_rate, d_state=d_state, **kwargs)
489 | self.drop_path = DropPath(drop_path)
490 |
491 | def forward(self, input: torch.Tensor):
492 | x = input + self.drop_path(self.self_attention(self.ln_1(input)))
493 | return x
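 | # Editor's note: VSSBlock is a pre-norm residual block,
 | #   x -> x + DropPath(SS2D(LayerNorm(x))),
 | # operating on channels-last (B, H, W, C) tensors.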
494 |
495 |
496 | class VSSLayer(nn.Module):
497 | """ A basic Swin Transformer layer for one stage.
498 | Args:
499 | dim (int): Number of input channels.
500 | depth (int): Number of blocks.
501 | drop (float, optional): Dropout rate. Default: 0.0
502 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
503 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
504 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
505 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
506 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
507 | """
508 |
509 | def __init__(
510 | self,
511 | dim,
512 | depth,
513 | attn_drop=0.,
514 | drop_path=0.,
515 | norm_layer=nn.LayerNorm,
516 | downsample=None,
517 | use_checkpoint=False,
518 | d_state=16,
519 | **kwargs,
520 | ):
521 | super().__init__()
522 | self.dim = dim
523 | self.use_checkpoint = use_checkpoint
524 |
525 | self.blocks = nn.ModuleList([
526 | VSSBlock(
527 | hidden_dim=dim,
528 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
529 | norm_layer=norm_layer,
530 | attn_drop_rate=attn_drop,
531 | d_state=d_state,
532 | )
533 | for i in range(depth)])
534 |
535 | if True: # is this really applied? Yes, but it is overridden later in VSSM._init_weights!
536 | def _init_weights(module: nn.Module):
537 | for name, p in module.named_parameters():
538 | if name in ["out_proj.weight"]:
539 | p = p.clone().detach_() # fake init on a detached clone, just to keep the RNG state consistent
540 | nn.init.kaiming_uniform_(p, a=math.sqrt(5))
541 | self.apply(_init_weights)
542 |
543 | if downsample is not None:
544 | self.downsample = downsample(dim=dim, norm_layer=norm_layer)
545 | else:
546 | self.downsample = None
547 |
548 |
549 | def forward(self, x):
550 | for blk in self.blocks:
551 | if self.use_checkpoint:
552 | x = checkpoint.checkpoint(blk, x)
553 | else:
554 | x = blk(x)
555 |
556 | if self.downsample is not None:
557 | x = self.downsample(x)
558 |
559 | return x
560 |
561 |
562 |
563 | class VSSLayer_up(nn.Module):
564 | """ A basic Swin Transformer layer for one stage.
565 | Args:
566 | dim (int): Number of input channels.
567 | depth (int): Number of blocks.
568 | drop (float, optional): Dropout rate. Default: 0.0
569 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
570 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
571 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
572 | upsample (nn.Module | None, optional): Upsample layer at the beginning of the layer. Default: None
573 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
574 | """
575 |
576 | def __init__(
577 | self,
578 | dim,
579 | depth,
580 | attn_drop=0.,
581 | drop_path=0.,
582 | norm_layer=nn.LayerNorm,
583 | upsample=None,
584 | use_checkpoint=False,
585 | d_state=16,
586 | **kwargs,
587 | ):
588 | super().__init__()
589 | self.dim = dim
590 | self.use_checkpoint = use_checkpoint
591 |
592 | self.blocks = nn.ModuleList([
593 | VSSBlock(
594 | hidden_dim=dim,
595 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
596 | norm_layer=norm_layer,
597 | attn_drop_rate=attn_drop,
598 | d_state=d_state,
599 | )
600 | for i in range(depth)])
601 |
602 | if True: # is this really applied? Yes, but it is overridden later in VSSM._init_weights!
603 | def _init_weights(module: nn.Module):
604 | for name, p in module.named_parameters():
605 | if name in ["out_proj.weight"]:
606 | p = p.clone().detach_() # fake init on a detached clone, just to keep the RNG state consistent
607 | nn.init.kaiming_uniform_(p, a=math.sqrt(5))
608 | self.apply(_init_weights)
609 |
610 | if upsample is not None:
611 | self.upsample = upsample(dim=dim, norm_layer=norm_layer)
612 | else:
613 | self.upsample = None
614 |
615 |
616 | def forward(self, x):
617 | if self.upsample is not None:
618 | x = self.upsample(x)
619 | for blk in self.blocks:
620 | if self.use_checkpoint:
621 | x = checkpoint.checkpoint(blk, x)
622 | else:
623 | x = blk(x)
624 | return x
625 |
626 |
627 |
628 | class VSSM(nn.Module):
629 | def __init__(self, patch_size=4, in_chans=3, num_classes=1000, depths=[2, 2, 9, 2], depths_decoder=[2, 9, 2, 2],
630 | dims=[96, 192, 384, 768], dims_decoder=[768, 384, 192, 96], d_state=16, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
631 | norm_layer=nn.LayerNorm, patch_norm=True,
632 | use_checkpoint=False, **kwargs):
633 | super().__init__()
634 | self.num_classes = num_classes
635 | self.num_layers = len(depths)
636 | if isinstance(dims, int):
637 | dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
638 | self.embed_dim = dims[0]
639 | self.num_features = dims[-1]
640 | self.dims = dims
641 |
642 | self.patch_embed = PatchEmbed2D(patch_size=patch_size, in_chans=in_chans, embed_dim=self.embed_dim,
643 | norm_layer=norm_layer if patch_norm else None)
644 |
645 | # Absolute position embedding (unused: disabled by default) ======================
646 | self.ape = False
649 | if self.ape:
650 | self.patches_resolution = self.patch_embed.patches_resolution
651 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, *self.patches_resolution, self.embed_dim))
652 | trunc_normal_(self.absolute_pos_embed, std=.02)
653 | self.pos_drop = nn.Dropout(p=drop_rate)
654 |
655 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
656 | dpr_decoder = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_decoder))][::-1]
657 |
658 | self.layers = nn.ModuleList()
659 | for i_layer in range(self.num_layers):
660 | layer = VSSLayer(
661 | dim=dims[i_layer],
662 | depth=depths[i_layer],
663 | d_state=math.ceil(dims[0] / 6) if d_state is None else d_state, # 20240109
664 | drop=drop_rate,
665 | attn_drop=attn_drop_rate,
666 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
667 | norm_layer=norm_layer,
668 | downsample=PatchMerging2D if (i_layer < self.num_layers - 1) else None,
669 | use_checkpoint=use_checkpoint,
670 | )
671 | self.layers.append(layer)
672 |
673 | self.layers_up = nn.ModuleList()
674 | for i_layer in range(self.num_layers):
675 | layer = VSSLayer_up(
676 | dim=dims_decoder[i_layer],
677 | depth=depths_decoder[i_layer],
678 | d_state=math.ceil(dims[0] / 6) if d_state is None else d_state, # 20240109
679 | drop=drop_rate,
680 | attn_drop=attn_drop_rate,
681 | drop_path=dpr_decoder[sum(depths_decoder[:i_layer]):sum(depths_decoder[:i_layer + 1])],
682 | norm_layer=norm_layer,
683 | upsample=PatchExpand2D if (i_layer != 0) else None,
684 | use_checkpoint=use_checkpoint,
685 | )
686 | self.layers_up.append(layer)
687 |
688 | self.final_up = Final_PatchExpand2D(dim=dims_decoder[-1], dim_scale=4, norm_layer=norm_layer)
689 | self.final_conv = nn.Conv2d(dims_decoder[-1]//4, num_classes, 1)
690 |
691 | # self.norm = norm_layer(self.num_features)
692 | # self.avgpool = nn.AdaptiveAvgPool1d(1)
693 | # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
694 |
695 | self.apply(self._init_weights)
696 |
697 | def _init_weights(self, m: nn.Module):
698 | """
699 | out_proj.weight, which was previously initialized in VSSBlock, is re-initialized here by the nn.Linear branch
700 | no fc.weight is found in any of the model parameters
701 | no nn.Embedding is found in any of the model parameters
702 | so, in effect, the VSSBlock initialization is a no-op
703 |
704 | Conv2d is not initialized here!
705 | """
706 | if isinstance(m, nn.Linear):
707 | trunc_normal_(m.weight, std=.02)
708 | if isinstance(m, nn.Linear) and m.bias is not None:
709 | nn.init.constant_(m.bias, 0)
710 | elif isinstance(m, nn.LayerNorm):
711 | nn.init.constant_(m.bias, 0)
712 | nn.init.constant_(m.weight, 1.0)
713 |
714 | @torch.jit.ignore
715 | def no_weight_decay(self):
716 | return {'absolute_pos_embed'}
717 |
718 | @torch.jit.ignore
719 | def no_weight_decay_keywords(self):
720 | return {'relative_position_bias_table'}
721 |
722 | def forward_features(self, x):
723 | skip_list = []
724 | x = self.patch_embed(x)
725 | if self.ape:
726 | x = x + self.absolute_pos_embed
727 | x = self.pos_drop(x)
728 |
729 | for layer in self.layers:
730 | skip_list.append(x)
731 | x = layer(x)
732 | return x, skip_list
733 |
734 | def forward_features_up(self, x, skip_list):
735 | for inx, layer_up in enumerate(self.layers_up):
736 | if inx == 0:
737 | x = layer_up(x)
738 | else:
739 | x = layer_up(x+skip_list[-inx])
740 |
741 | return x
742 |
743 | def forward_final(self, x):
744 | x = self.final_up(x)
745 | x = x.permute(0,3,1,2)
746 | x = self.final_conv(x)
747 | return x
748 |
749 | def forward_backbone(self, x):
750 | x = self.patch_embed(x)
751 | if self.ape:
752 | x = x + self.absolute_pos_embed
753 | x = self.pos_drop(x)
754 |
755 | for layer in self.layers:
756 | x = layer(x)
757 | return x
758 |
759 | def forward(self, x):
760 | x, skip_list = self.forward_features(x)
761 | x = self.forward_features_up(x, skip_list)
762 | x = self.forward_final(x)
763 |
764 | return x
765 |
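 | # Editor's sketch (hypothetical shape check, not part of the original file):
 | #   net = VSSM(patch_size=4, in_chans=3, num_classes=1)
 | #   x = torch.randn(1, 3, 256, 256)
 | #   net(x).shape  # -> (1, 1, 256, 256): /4 patch embed, /8 encoder,
 | #                 #    x8 decoder, x4 final expand
 | # (actually running it still requires the selective-scan CUDA kernels)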
766 |
767 |
768 |
769 |
770 |
771 |
772 |
--------------------------------------------------------------------------------
/results/Readme.txt:
--------------------------------------------------------------------------------
1 | Result save location
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.cuda.amp import autocast, GradScaler
4 | from torch.utils.data import DataLoader
5 | from loader import *
6 |
7 | from models.H_vmunet import H_vmunet
8 | from engine import *
9 | import os
10 | import sys
11 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # "0, 1, 2, 3"
12 |
13 | from utils import *
14 | from configs.config_setting import setting_config
15 |
16 | import warnings
17 | warnings.filterwarnings("ignore")
18 |
19 |
20 |
21 | def main(config):
22 |
23 | print('#----------Creating logger----------#')
24 | sys.path.append(config.work_dir + '/')
25 | log_dir = os.path.join(config.work_dir, 'log')
26 | checkpoint_dir = os.path.join(config.work_dir, 'checkpoints')
27 | resume_model = os.path.join('')  # set this to your checkpoint path before testing (see README, step 3)
28 | outputs = os.path.join(config.work_dir, 'outputs')
29 | if not os.path.exists(checkpoint_dir):
30 | os.makedirs(checkpoint_dir)
31 | if not os.path.exists(outputs):
32 | os.makedirs(outputs)
33 |
34 | global logger
35 | logger = get_logger('test', log_dir)
36 |
37 | log_config_info(config, logger)
38 |
39 |
40 |
41 |
42 |
43 | print('#----------GPU init----------#')
44 | set_seed(config.seed)
45 | gpu_ids = [0]# [0, 1, 2, 3]
46 | torch.cuda.empty_cache()
47 |
48 |
49 |
50 | print('#----------Preparing Models----------#')
51 | model_cfg = config.model_config
52 | model = H_vmunet(num_classes=model_cfg['num_classes'],
53 | input_channels=model_cfg['input_channels'],
54 | c_list=model_cfg['c_list'],
55 | split_att=model_cfg['split_att'],
56 | bridge=model_cfg['bridge'],
57 | drop_path_rate=model_cfg['drop_path_rate'])
58 |
59 | model = torch.nn.DataParallel(model.cuda(), device_ids=gpu_ids, output_device=gpu_ids[0])
60 |
61 |
62 | print('#----------Preparing dataset----------#')
63 | test_dataset = isic_loader(path_Data = config.data_path, train = False, Test = True)
64 | test_loader = DataLoader(test_dataset,
65 | batch_size=1,
66 | shuffle=False,
67 | pin_memory=True,
68 | num_workers=config.num_workers,
69 | drop_last=True)
70 |
71 | print('#----------Preparing loss, opt, sch and amp----------#')
72 | criterion = config.criterion
73 | optimizer = get_optimizer(config, model)
74 | scheduler = get_scheduler(config, optimizer)
75 | scaler = GradScaler()
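 | # NOTE (editor): the optimizer, scheduler and scaler appear to be built only
 | # for parity with train.py; they are not used during testing.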
76 |
77 |
78 |
79 |
80 |
81 | print('#----------Set other params----------#')
82 | min_loss = 999
83 | start_epoch = 1
84 | min_epoch = 1
85 |
86 |
87 | print('#----------Testing----------#')
88 | best_weight = torch.load(resume_model, map_location=torch.device('cpu'))
89 | model.module.load_state_dict(best_weight)
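 | # NOTE (editor): train.py saves model.module.state_dict() (the model without
 | # the DataParallel wrapper), so the weights are loaded back through
 | # model.module here as well.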
90 | loss = test_one_epoch(
91 | test_loader,
92 | model,
93 | criterion,
94 | logger,
95 | config,
96 | )
97 |
98 |
99 |
100 | if __name__ == '__main__':
101 | config = setting_config
102 | main(config)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.cuda.amp import autocast, GradScaler
4 | from torch.utils.data import DataLoader
5 | from loader import *
6 |
7 | from models.H_vmunet import H_vmunet
8 | from engine import *
9 | import os
10 | import sys
11 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # "0, 1, 2, 3"
12 |
13 | from utils import *
14 | from configs.config_setting import setting_config
15 |
16 | import warnings
17 | warnings.filterwarnings("ignore")
18 |
19 | def main(config):
20 |
21 | print('#----------Creating logger----------#')
22 | sys.path.append(config.work_dir + '/')
23 | log_dir = os.path.join(config.work_dir, 'log')
24 | checkpoint_dir = os.path.join(config.work_dir, 'checkpoints')
25 | resume_model = os.path.join(checkpoint_dir, 'latest.pth')
26 | outputs = os.path.join(config.work_dir, 'outputs')
27 | if not os.path.exists(checkpoint_dir):
28 | os.makedirs(checkpoint_dir)
29 | if not os.path.exists(outputs):
30 | os.makedirs(outputs)
31 |
32 | global logger
33 | logger = get_logger('train', log_dir)
34 |
35 | log_config_info(config, logger)
36 |
37 |
38 |
39 |
40 |
41 | print('#----------GPU init----------#')
42 | set_seed(config.seed)
43 | gpu_ids = [0]# [0, 1, 2, 3]
44 | torch.cuda.empty_cache()
45 |
46 |
47 |
48 |
49 |
50 | print('#----------Preparing dataset----------#')
51 | train_dataset = isic_loader(path_Data = config.data_path, train = True)
52 | train_loader = DataLoader(train_dataset,
53 | batch_size=config.batch_size,
54 | shuffle=True,
55 | pin_memory=True,
56 | num_workers=config.num_workers)
57 | val_dataset = isic_loader(path_Data = config.data_path, train = False)
58 | val_loader = DataLoader(val_dataset,
59 | batch_size=1,
60 | shuffle=False,
61 | pin_memory=True,
62 | num_workers=config.num_workers,
63 | drop_last=True)
64 | test_dataset = isic_loader(path_Data = config.data_path, train = False, Test = True)
65 | test_loader = DataLoader(test_dataset,
66 | batch_size=1,
67 | shuffle=False,
68 | pin_memory=True,
69 | num_workers=config.num_workers,
70 | drop_last=True)
71 |
72 |
73 |
74 |
75 | print('#----------Preparing Models----------#')
76 | model_cfg = config.model_config
77 | model = H_vmunet(num_classes=model_cfg['num_classes'],
78 | input_channels=model_cfg['input_channels'],
79 | c_list=model_cfg['c_list'],
80 | split_att=model_cfg['split_att'],
81 | bridge=model_cfg['bridge'],
82 | drop_path_rate=model_cfg['drop_path_rate'])
83 |
84 | model = torch.nn.DataParallel(model.cuda(), device_ids=gpu_ids, output_device=gpu_ids[0])
85 |
86 |
87 |
88 |
89 |
90 | print('#----------Preparing loss, opt, sch and amp----------#')
91 | criterion = config.criterion
92 | optimizer = get_optimizer(config, model)
93 | scheduler = get_scheduler(config, optimizer)
94 | scaler = GradScaler()
95 |
96 |
97 |
98 |
99 |
100 | print('#----------Set other params----------#')
101 | min_loss = 999
102 | start_epoch = 1
103 | min_epoch = 1
104 |
105 |
106 |
107 |
108 |
109 | if os.path.exists(resume_model):
110 | print('#----------Resume Model and Other params----------#')
111 | checkpoint = torch.load(resume_model, map_location=torch.device('cpu'))
112 | model.module.load_state_dict(checkpoint['model_state_dict'])
113 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
114 | scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
115 | saved_epoch = checkpoint['epoch']
116 | start_epoch += saved_epoch
117 | min_loss, min_epoch, loss = checkpoint['min_loss'], checkpoint['min_epoch'], checkpoint['loss']
118 |
119 | log_info = f'resuming model from {resume_model}. resume_epoch: {saved_epoch}, min_loss: {min_loss:.4f}, min_epoch: {min_epoch}, loss: {loss:.4f}'
120 | logger.info(log_info)
121 |
122 |
123 |
124 |
125 |
126 | print('#----------Training----------#')
127 | for epoch in range(start_epoch, config.epochs + 1):
128 |
129 | torch.cuda.empty_cache()
130 |
131 | train_one_epoch(
132 | train_loader,
133 | model,
134 | criterion,
135 | optimizer,
136 | scheduler,
137 | epoch,
138 | logger,
139 | config,
140 | scaler=scaler
141 | )
142 |
143 | loss = val_one_epoch(
144 | val_loader,
145 | model,
146 | criterion,
147 | epoch,
148 | logger,
149 | config
150 | )
151 |
152 |
153 | if loss < min_loss:
154 | torch.save(model.module.state_dict(), os.path.join(checkpoint_dir, 'best.pth'))
155 | min_loss = loss
156 | min_epoch = epoch
157 |
158 | torch.save(
159 | {
160 | 'epoch': epoch,
161 | 'min_loss': min_loss,
162 | 'min_epoch': min_epoch,
163 | 'loss': loss,
164 | 'model_state_dict': model.module.state_dict(),
165 | 'optimizer_state_dict': optimizer.state_dict(),
166 | 'scheduler_state_dict': scheduler.state_dict(),
167 | }, os.path.join(checkpoint_dir, 'latest.pth'))
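 | # NOTE (editor): 'latest.pth' stores the full training state for resuming,
 | # while 'best.pth' (above) stores the model weights only.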
168 |
169 | if os.path.exists(os.path.join(checkpoint_dir, 'best.pth')):
170 | print('#----------Testing----------#')
171 | best_weight = torch.load(os.path.join(checkpoint_dir, 'best.pth'), map_location=torch.device('cpu'))
172 | model.module.load_state_dict(best_weight)
173 | loss = test_one_epoch(
174 | test_loader,
175 | model,
176 | criterion,
177 | logger,
178 | config,
179 | )
180 | os.rename(
181 | os.path.join(checkpoint_dir, 'best.pth'),
182 | os.path.join(checkpoint_dir, f'best-epoch{min_epoch}-loss{min_loss:.4f}.pth')
183 | )
184 |
185 |
186 | if __name__ == '__main__':
187 | config = setting_config
188 | main(config)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.backends.cudnn as cudnn
5 | import torchvision.transforms.functional as TF
6 | import numpy as np
7 | import os
8 | import math
9 | import random
10 | import logging
11 | import logging.handlers
12 | from matplotlib import pyplot as plt
13 |
14 |
15 | def set_seed(seed):
16 | # for hash
17 | os.environ['PYTHONHASHSEED'] = str(seed)
18 | # for python and numpy
19 | random.seed(seed)
20 | np.random.seed(seed)
21 | # for cpu gpu
22 | torch.manual_seed(seed)
23 | torch.cuda.manual_seed(seed)
24 | torch.cuda.manual_seed_all(seed)
25 | # for cudnn
26 | cudnn.benchmark = False
27 | cudnn.deterministic = True
28 |
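 | # Usage note (editor): call set_seed once at startup, before building models
 | # and dataloaders; cudnn.deterministic=True trades some speed for reproducibility.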
29 |
30 | def get_logger(name, log_dir):
31 | '''
32 | Args:
33 | name(str): name of logger
34 | log_dir(str): path of log
35 | '''
36 |
37 | if not os.path.exists(log_dir):
38 | os.makedirs(log_dir)
39 |
40 | logger = logging.getLogger(name)
41 | logger.setLevel(logging.INFO)
42 |
43 | info_name = os.path.join(log_dir, '{}.info.log'.format(name))
44 | info_handler = logging.handlers.TimedRotatingFileHandler(info_name,
45 | when='D',
46 | encoding='utf-8')
47 | info_handler.setLevel(logging.INFO)
48 |
49 | formatter = logging.Formatter('%(asctime)s - %(message)s',
50 | datefmt='%Y-%m-%d %H:%M:%S')
51 |
52 | info_handler.setFormatter(formatter)
53 |
54 | logger.addHandler(info_handler)
55 |
56 | return logger
57 |
58 |
59 | def log_config_info(config, logger):
60 | config_dict = config.__dict__
61 | log_info = f'#----------Config info----------#'
62 | logger.info(log_info)
63 | for k, v in config_dict.items():
64 | if k[0] == '_':
65 | continue
66 | else:
67 | log_info = f'{k}: {v},'
68 | logger.info(log_info)
69 |
70 |
71 |
72 | def get_optimizer(config, model):
73 | assert config.opt in ['Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'ASGD', 'RMSprop', 'Rprop', 'SGD'], 'Unsupported optimizer!'
74 |
75 | if config.opt == 'Adadelta':
76 | return torch.optim.Adadelta(
77 | model.parameters(),
78 | lr = config.lr,
79 | rho = config.rho,
80 | eps = config.eps,
81 | weight_decay = config.weight_decay
82 | )
83 | elif config.opt == 'Adagrad':
84 | return torch.optim.Adagrad(
85 | model.parameters(),
86 | lr = config.lr,
87 | lr_decay = config.lr_decay,
88 | eps = config.eps,
89 | weight_decay = config.weight_decay
90 | )
91 | elif config.opt == 'Adam':
92 | return torch.optim.Adam(
93 | model.parameters(),
94 | lr = config.lr,
95 | betas = config.betas,
96 | eps = config.eps,
97 | weight_decay = config.weight_decay,
98 | amsgrad = config.amsgrad
99 | )
100 | elif config.opt == 'AdamW':
101 | return torch.optim.AdamW(
102 | model.parameters(),
103 | lr = config.lr,
104 | betas = config.betas,
105 | eps = config.eps,
106 | weight_decay = config.weight_decay,
107 | amsgrad = config.amsgrad
108 | )
109 | elif config.opt == 'Adamax':
110 | return torch.optim.Adamax(
111 | model.parameters(),
112 | lr = config.lr,
113 | betas = config.betas,
114 | eps = config.eps,
115 | weight_decay = config.weight_decay
116 | )
117 | elif config.opt == 'ASGD':
118 | return torch.optim.ASGD(
119 | model.parameters(),
120 | lr = config.lr,
121 | lambd = config.lambd,
122 | alpha = config.alpha,
123 | t0 = config.t0,
124 | weight_decay = config.weight_decay
125 | )
126 | elif config.opt == 'RMSprop':
127 | return torch.optim.RMSprop(
128 | model.parameters(),
129 | lr = config.lr,
130 | momentum = config.momentum,
131 | alpha = config.alpha,
132 | eps = config.eps,
133 | centered = config.centered,
134 | weight_decay = config.weight_decay
135 | )
136 | elif config.opt == 'Rprop':
137 | return torch.optim.Rprop(
138 | model.parameters(),
139 | lr = config.lr,
140 | etas = config.etas,
141 | step_sizes = config.step_sizes,
142 | )
143 | elif config.opt == 'SGD':
144 | return torch.optim.SGD(
145 | model.parameters(),
146 | lr = config.lr,
147 | momentum = config.momentum,
148 | weight_decay = config.weight_decay,
149 | dampening = config.dampening,
150 | nesterov = config.nesterov
151 | )
152 | else: # default opt is SGD
153 | return torch.optim.SGD(
154 | model.parameters(),
155 | lr = 0.01,
156 | momentum = 0.9,
157 | weight_decay = 0.05,
158 | )
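 | # NOTE (editor): this fallback is unreachable in practice, since the assert at
 | # the top of get_optimizer already restricts config.opt to the names above.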
159 |
160 |
161 |
162 | def get_scheduler(config, optimizer):
163 | assert config.sch in ['StepLR', 'MultiStepLR', 'ExponentialLR', 'CosineAnnealingLR', 'ReduceLROnPlateau',
164 | 'CosineAnnealingWarmRestarts', 'WP_MultiStepLR', 'WP_CosineLR'], 'Unsupported scheduler!'
165 | if config.sch == 'StepLR':
166 | scheduler = torch.optim.lr_scheduler.StepLR(
167 | optimizer,
168 | step_size = config.step_size,
169 | gamma = config.gamma,
170 | last_epoch = config.last_epoch
171 | )
172 | elif config.sch == 'MultiStepLR':
173 | scheduler = torch.optim.lr_scheduler.MultiStepLR(
174 | optimizer,
175 | milestones = config.milestones,
176 | gamma = config.gamma,
177 | last_epoch = config.last_epoch
178 | )
179 | elif config.sch == 'ExponentialLR':
180 | scheduler = torch.optim.lr_scheduler.ExponentialLR(
181 | optimizer,
182 | gamma = config.gamma,
183 | last_epoch = config.last_epoch
184 | )
185 | elif config.sch == 'CosineAnnealingLR':
186 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
187 | optimizer,
188 | T_max = config.T_max,
189 | eta_min = config.eta_min,
190 | last_epoch = config.last_epoch
191 | )
192 | elif config.sch == 'ReduceLROnPlateau':
193 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
194 | optimizer,
195 | mode = config.mode,
196 | factor = config.factor,
197 | patience = config.patience,
198 | threshold = config.threshold,
199 | threshold_mode = config.threshold_mode,
200 | cooldown = config.cooldown,
201 | min_lr = config.min_lr,
202 | eps = config.eps
203 | )
204 | elif config.sch == 'CosineAnnealingWarmRestarts':
205 | scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
206 | optimizer,
207 | T_0 = config.T_0,
208 | T_mult = config.T_mult,
209 | eta_min = config.eta_min,
210 | last_epoch = config.last_epoch
211 | )
212 | elif config.sch == 'WP_MultiStepLR':
213 | lr_func = lambda epoch: epoch / config.warm_up_epochs if epoch <= config.warm_up_epochs else config.gamma**len(
214 | [m for m in config.milestones if m <= epoch])
215 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func)
216 | elif config.sch == 'WP_CosineLR':
217 | lr_func = lambda epoch: epoch / config.warm_up_epochs if epoch <= config.warm_up_epochs else 0.5 * (
218 | math.cos((epoch - config.warm_up_epochs) / (config.epochs - config.warm_up_epochs) * math.pi) + 1)
219 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func)
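 | # NOTE (editor): both WP_* schedules ramp the lr factor linearly from 0 over
 | # warm_up_epochs; WP_MultiStepLR then steps it down by gamma at each milestone,
 | # while WP_CosineLR decays it from 1 to 0 along half a cosine period.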
220 |
221 | return scheduler
222 |
223 |
224 |
225 | def save_imgs(img, msk, msk_pred, i, save_path, datasets, threshold=0.5, test_data_name=None):
226 | img = img.squeeze(0).permute(1,2,0).detach().cpu().numpy()
227 | img = img / 255. if img.max() > 1.1 else img
228 | if datasets == 'retinal':
229 | msk = np.squeeze(msk, axis=0)
230 | msk_pred = np.squeeze(msk_pred, axis=0)
231 | else:
232 | msk = np.where(np.squeeze(msk, axis=0) > 0.5, 1, 0)
233 | msk_pred = np.where(np.squeeze(msk_pred, axis=0) > threshold, 1, 0)
234 |
235 | plt.figure(figsize=(7,15))
236 |
237 | plt.subplot(3,1,1)
238 | plt.imshow(img)
239 | plt.axis('off')
240 |
241 | plt.subplot(3,1,2)
242 | plt.imshow(msk, cmap= 'gray')
243 | plt.axis('off')
244 |
245 | plt.subplot(3,1,3)
246 | plt.imshow(msk_pred, cmap = 'gray')
247 | plt.axis('off')
248 |
249 | if test_data_name is not None:
250 | save_path = save_path + test_data_name + '_'
251 | plt.savefig(save_path + str(i) +'.png')
252 | plt.close()
253 |
254 |
255 |
256 | class BCELoss(nn.Module):
257 | def __init__(self):
258 | super(BCELoss, self).__init__()
259 | self.bceloss = nn.BCELoss()
260 |
261 | def forward(self, pred, target):
262 | size = pred.size(0)
263 | pred_ = pred.view(size, -1)
264 | target_ = target.view(size, -1)
265 |
266 | return self.bceloss(pred_, target_)
267 |
268 |
269 | class DiceLoss(nn.Module):
270 | def __init__(self):
271 | super(DiceLoss, self).__init__()
272 |
273 | def forward(self, pred, target):
274 | smooth = 1
275 | size = pred.size(0)
276 |
277 | pred_ = pred.view(size, -1)
278 | target_ = target.view(size, -1)
279 | intersection = pred_ * target_
280 | dice_score = (2 * intersection.sum(1) + smooth)/(pred_.sum(1) + target_.sum(1) + smooth)
281 | dice_loss = 1 - dice_score.sum()/size
282 |
283 | return dice_loss
284 |
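 | # Worked check (editor's note): with pred == target == 1 over N pixels,
 | # dice_score = (2N + 1) / (2N + 1) = 1, so dice_loss = 0; smooth=1 also keeps
 | # the ratio defined when both pred and target are empty.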
285 |
286 | class BceDiceLoss(nn.Module):
287 | def __init__(self, wb=1, wd=1):
288 | super(BceDiceLoss, self).__init__()
289 | self.bce = BCELoss()
290 | self.dice = DiceLoss()
291 | self.wb = wb
292 | self.wd = wd
293 |
294 | def forward(self, pred, target):
295 | bceloss = self.bce(pred, target)
296 | diceloss = self.dice(pred, target)
297 |
298 | loss = self.wd * diceloss + self.wb * bceloss
299 | return loss
300 |
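 | # Editor's sketch (hypothetical values):
 | #   criterion = BceDiceLoss(wb=1, wd=1)
 | # Inputs must already be probabilities in [0, 1], since nn.BCELoss
 | # (not nn.BCEWithLogitsLoss) is used internally.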
301 |
302 |
303 |
304 |
--------------------------------------------------------------------------------