├── README.md
├── configs
│   └── config_setting.py
├── dataprepare
│   ├── Prepare_ISIC2017.py
│   └── Prepare_your_dataset.py
├── engine.py
├── loader.py
├── models
│   ├── H_vmunet.py
│   └── vmamba.py
├── results
│   └── Readme.txt
├── test.py
├── train.py
└── utils.py


/README.md:
--------------------------------------------------------------------------------

# H-vmunet: High-order Vision Mamba UNet for Medical Image Segmentation

Renkai Wu, Yinghao Liu, Pengchen Liang*, and Qing Chang*

[![Neucom](https://img.shields.io/badge/Neucom-2025.129447-blue)](https://doi.org/10.1016/j.neucom.2025.129447)
[![arXiv](https://img.shields.io/badge/arXiv-2403.13642-b31b1b.svg)](https://arxiv.org/abs/2403.13642)


## News🚀
(2025.01.12) ***The paper has been accepted by Neurocomputing***🔥🔥

(2024.03.21) ***Model weights have been uploaded for download***🔥

(2024.03.21) ***The project code has been uploaded***

(2024.03.20) ***The first edition of our paper has been uploaded to arXiv*** 📃

**0. Main Environments.**
You can follow the environment installation procedure of [VM-UNet](https://github.com/JCruan519/VM-UNet), or use the steps below:
```
conda create -n vmunet python=3.8
conda activate vmunet
pip install torch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117
pip install packaging
pip install timm==0.4.12
pip install pytest chardet yacs termcolor
pip install submitit tensorboardX
pip install triton==2.0.0
pip install causal_conv1d==1.0.0  # causal_conv1d-1.0.0+cu118torch1.13cxx11abiFALSE-cp38-cp38-linux_x86_64.whl
pip install mamba_ssm==1.0.1  # mamba_ssm-1.0.1+cu118torch1.13cxx11abiFALSE-cp38-cp38-linux_x86_64.whl
pip install scikit-learn matplotlib thop h5py SimpleITK scikit-image medpy yacs
```

**1. Datasets.**

*A.ISIC2017*
1- Download the ISIC 2017 training dataset from [this](https://challenge.isic-archive.com/data) link and extract both the training images and the ground-truth folders into `/data/dataset_isic17/`.
2- Run `Prepare_ISIC2017.py` to prepare the data and split it into training, validation and test sets.
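
The mask lookup and the fixed split in `Prepare_ISIC2017.py` can be sketched without any image I/O. The snippet below is a stdlib-only illustration, assuming the training images keep their original `ISIC_XXXXXXX.jpg` names; `mask_path_for` and `split_ranges` are hypothetical helpers, not functions in this repository.

```python
from pathlib import PurePosixPath

# Hypothetical helpers that mirror the logic of Prepare_ISIC2017.py.
GT_DIR = 'ISIC2017_Task1_Training_GroundTruth'


def mask_path_for(image_path, dataset_root='./ISIC2017/'):
    """Map a training image path to its ground-truth mask path."""
    # 'ISIC_0000000.jpg' -> 'ISIC_0000000', the same text the script
    # extracts with b[len(b)-16: len(b)-4]
    stem = PurePosixPath(image_path).stem
    return dataset_root + GT_DIR + '/' + stem + '_segmentation.png'


def split_ranges(n_train, n_val, n_test):
    """Return (start, stop) pairs for consecutive train/val/test slices."""
    train = (0, n_train)
    val = (n_train, n_train + n_val)
    test = (n_train + n_val, n_train + n_val + n_test)
    return train, val, test


print(mask_path_for('./ISIC2017/ISIC2017_Task1-2_Training_Input/ISIC_0000000.jpg'))
# ./ISIC2017/ISIC2017_Task1_Training_GroundTruth/ISIC_0000000_segmentation.png
print(split_ranges(1250, 150, 600))
# ((0, 1250), (1250, 1400), (1400, 2000))
```

The three ranges correspond to the 1250/150/600 slices that the script saves as `data_train.npy`, `data_val.npy` and `data_test.npy` (and the matching `mask_*.npy` files).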

*B.Spleen*
1- Download the Spleen dataset from [this](http://medicaldecathlon.com/) link.

*C.CVC-ClinicDB*
1- Download the CVC-ClinicDB dataset from [this](https://polyp.grand-challenge.org/CVCClinicDB/) link.

*D. Prepare your own dataset*
1. Organize the files as follows. Each image is a 24-bit PNG and each mask an 8-bit PNG (pixel value 0 for background, 255 for target). Masks are matched to images by the four characters before `.png`, so use four-digit names such as `0000.png` in both folders.
    - './your_dataset/'
        - images
            - 0000.png
            - 0001.png
        - masks
            - 0000.png
            - 0001.png
        - Prepare_your_dataset.py
2. In 'Prepare_your_dataset.py', set the number of images for the training, validation and test sets. The images are split in the order returned by `glob`, without shuffling.
3. Run 'Prepare_your_dataset.py'.
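
Before running the script, it can help to check that every image has a matching mask and that the folder size matches the requested split. This is a minimal stdlib-only sketch, assuming plain file-name lists as input; `mask_key` and `check_pairs` are hypothetical helpers that reproduce the slicing used in `Prepare_your_dataset.py`, not code from the repository.

```python
# Hypothetical pre-flight check for Prepare_your_dataset.py (not part of the repo).


def mask_key(image_path):
    """Same slice as the script: the four characters before '.png'."""
    return image_path[len(image_path) - 8: len(image_path) - 4]


def check_pairs(image_names, mask_names, train_number, val_number, test_number):
    """Return a list of problems; an empty list means the folders look consistent."""
    problems = []
    # glob('images/*.png') only picks up .png files, so ignore anything else.
    images = sorted(name for name in image_names if name.endswith('.png'))
    masks = set(mask_names)
    for name in images:
        key = mask_key('images/' + name)
        if key + '.png' not in masks:
            problems.append(f'missing mask for {name}: masks/{key}.png')
    total = train_number + val_number + test_number
    if len(images) > total:
        problems.append(f'{len(images)} images but only {total} slots: the script would raise IndexError')
    elif len(images) < total:
        problems.append(f'{len(images)} images for {total} slots: the remaining samples stay zero-filled')
    return problems


print(check_pairs(['0000.png', '0001.png'], ['0000.png', '0001.png'], 1, 1, 0))  # []
print(check_pairs(['0000.png', '0002.png'], ['0000.png'], 1, 1, 1))
```

In practice the two lists could come from `os.listdir('images')` and `os.listdir('masks')`, called from inside './your_dataset/'.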

**2. Train the H_vmunet.**
```
python train.py
```
- After training, the outputs are saved in './results/'.
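
Each run writes into its own timestamped sub-folder of './results/'. The folder name is built from `network`, `datasets` and the start time by the `work_dir` expression in `configs/config_setting.py`; the sketch below reproduces that expression with a fixed timestamp so the result is predictable. `make_work_dir` is a hypothetical helper for illustration only.

```python
from datetime import datetime


def make_work_dir(network, datasets, now):
    # Same expression as work_dir in configs/config_setting.py,
    # with the timestamp passed in instead of datetime.now().
    return 'results/' + network + '_' + datasets + '_' + now.strftime('%A_%d_%B_%Y_%Hh_%Mm_%Ss') + '/'


print(make_work_dir('H_vmunet', 'ISIC2017', datetime(2024, 3, 21, 14, 5, 9)))
# results/H_vmunet_ISIC2017_Thursday_21_March_2024_14h_05m_09s/
```

Day and month names come from `%A` and `%B`, so they follow the process locale (English under the default C locale).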

**3. Test the H_vmunet.**
First, set 'resume_model' in test.py to the path of your checkpoint.
```
python test.py
```
- After testing, the outputs are saved in './results/'.
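
The metrics logged by `engine.py` are computed once over all test pixels pooled together, from a single confusion matrix. The pure-Python sketch below mirrors those formulas, including the rule that a metric is 0 when its denominator is 0; `binary_metrics` is a hypothetical name, and the repository itself uses NumPy and `sklearn.metrics.confusion_matrix`. Note that the value reported as `miou` is the IoU of the foreground class only.

```python
def binary_metrics(preds, gts, threshold=0.5):
    """Sketch of the metric formulas in engine.py (not the repository code)."""
    tp = tn = fp = fn = 0
    for p, g in zip(preds, gts):
        y_pre = 1 if p >= threshold else 0   # prediction binarized with config.threshold
        y_true = 1 if g >= 0.5 else 0        # ground truth binarized at 0.5
        if y_true and y_pre:
            tp += 1
        elif y_true:
            fn += 1
        elif y_pre:
            fp += 1
        else:
            tn += 1

    def ratio(num, den):
        # engine.py returns 0 when the denominator is 0
        return num / den if den != 0 else 0

    return {
        'accuracy': ratio(tp + tn, tp + tn + fp + fn),
        'sensitivity': ratio(tp, tp + fn),
        'specificity': ratio(tn, tn + fp),
        'f1_or_dsc': ratio(2 * tp, 2 * tp + fp + fn),
        'miou': ratio(tp, tp + fp + fn),
    }


print(binary_metrics([0.9, 0.2, 0.7, 0.4], [1, 0, 0, 1]))
# {'accuracy': 0.5, 'sensitivity': 0.5, 'specificity': 0.5, 'f1_or_dsc': 0.5, 'miou': 0.3333333333333333}
```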

**4. Get model weights**

*A.ISIC2017*
[Google Drive](https://drive.google.com/file/d/10If43saeVW06p9q3oePAL3hOHqRxFoZV/view?usp=sharing)

*B.Spleen*
[Google Drive](https://drive.google.com/file/d/18aXOv8u-nFIbBdiUwnzHdQ7ELrNIhMu3/view?usp=sharing)

*C.CVC-ClinicDB*
[Google Drive](https://drive.google.com/file/d/1mG_zOlsz7OuX_qHVmB3mjMeb1GUNgtkP/view?usp=sharing)


## Citation
If you find this repository helpful, please consider citing:
```
@article{wu2025h,
  title={H-vmunet: High-order vision mamba unet for medical image segmentation},
  author={Wu, Renkai and Liu, Yinghao and Liang, Pengchen and Chang, Qing},
  journal={Neurocomputing},
  pages={129447},
  year={2025},
  publisher={Elsevier}
}
```

## Acknowledgement
Thanks to [Vim](https://github.com/hustvl/Vim), [HorNet](https://github.com/raoyongming/HorNet) and [VM-UNet](https://github.com/JCruan519/VM-UNet) for their outstanding work.

--------------------------------------------------------------------------------
/configs/config_setting.py:
--------------------------------------------------------------------------------
from torchvision import transforms
from utils import *

from datetime import datetime

class setting_config:
    """
    the config of training setting.
    """
    network = 'H_vmunet'
    model_config = {
        'num_classes': 1,
        'input_channels': 3,
        'c_list': [8,16,32,64,128,256],
        'split_att': 'fc',
        'bridge': True,
        'drop_path_rate':0.4
    }

    test_weights = ''

    datasets = 'ISIC2017'
    if datasets == 'ISIC2017':
        data_path = ''
    elif datasets == 'Spleen':
        data_path = ''
    elif datasets == 'CVC-ClinicDB':
        data_path = ''
    else:
        raise Exception('datasets is not right!')

    criterion = BceDiceLoss()

    num_classes = 1
    input_size_h = 256
    input_size_w = 256
    input_channels = 3
    distributed = False
    local_rank = -1
    num_workers = 0
    seed = 42
    world_size = None
    rank = None
    amp = False
    batch_size = 8
    epochs = 250

    work_dir = 'results/' + network + '_' + datasets + '_' + datetime.now().strftime('%A_%d_%B_%Y_%Hh_%Mm_%Ss') + '/'

    print_interval = 20
    val_interval = 30
    save_interval = 100
    threshold = 0.5


    opt = 'AdamW'
    assert opt in ['Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'ASGD', 'RMSprop', 'Rprop', 'SGD'], 'Unsupported optimizer!'
    if opt == 'Adadelta':
        lr = 0.01 # default: 1.0 – coefficient that scales delta before it is applied to the parameters
        rho = 0.9 # default: 0.9 – coefficient used for computing a running average of squared gradients
        eps = 1e-6 # default: 1e-6 – term added to the denominator to improve numerical stability
        weight_decay = 0.05 # default: 0 – weight decay (L2 penalty)
    elif opt == 'Adagrad':
        lr = 0.01 # default: 0.01 – learning rate
        lr_decay = 0 # default: 0 – learning rate decay
        eps = 1e-10 # default: 1e-10 – term added to the denominator to improve numerical stability
        weight_decay = 0.05 # default: 0 – weight decay (L2 penalty)
    elif opt == 'Adam':
        lr = 0.001 # default: 1e-3 – learning rate
        betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square
        eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability
        weight_decay = 0.0001 # default: 0 – weight decay (L2 penalty)
        amsgrad = False # default: False – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond
    elif opt == 'AdamW':
        lr = 0.001 # default: 1e-3 – learning rate
        betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square
        eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability
        weight_decay = 1e-2 # default: 1e-2 – weight decay coefficient
        amsgrad = False # default: False – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond
    elif opt == 'Adamax':
        lr = 2e-3 # default: 2e-3 – learning rate
        betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square
        eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability
        weight_decay = 0 # default: 0 – weight decay (L2 penalty)
    elif opt == 'ASGD':
        lr = 0.01 # default: 1e-2 – learning rate
        lambd = 1e-4 # default: 1e-4 – decay term
        alpha = 0.75 # default: 0.75 – power for eta update
        t0 = 1e6 # default: 1e6 – point at which to start averaging
        weight_decay = 0 # default: 0 – weight decay
    elif opt == 'RMSprop':
        lr = 1e-2 # default: 1e-2 – learning rate
        momentum = 0 # default: 0 – momentum factor
        alpha = 0.99 # default: 0.99 – smoothing constant
        eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability
        centered = False # default: False – if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance
        weight_decay = 0 # default: 0 – weight decay (L2 penalty)
    elif opt == 'Rprop':
        lr = 1e-2 # default: 1e-2 – learning rate
        etas = (0.5, 1.2) # default: (0.5, 1.2) – pair of (etaminus, etaplus), the multiplicative decrease and increase factors
        step_sizes = (1e-6, 50) # default: (1e-6, 50) – a pair of minimal and maximal allowed step sizes
    elif opt == 'SGD':
        lr = 0.01 # – learning rate
        momentum = 0.9 # default: 0 – momentum factor
        weight_decay = 0.05 # default: 0 – weight decay (L2 penalty)
        dampening = 0 # default: 0 – dampening for momentum
        nesterov = False # default: False – enables Nesterov momentum

    sch = 'CosineAnnealingLR'
    if sch == 'StepLR':
        step_size = epochs // 5 # – Period of learning rate decay.
        gamma = 0.5 # – Multiplicative factor of learning rate decay. Default: 0.1
        last_epoch = -1 # – The index of last epoch. Default: -1.
    elif sch == 'MultiStepLR':
        milestones = [60, 120, 150] # – List of epoch indices. Must be increasing.
        gamma = 0.1 # – Multiplicative factor of learning rate decay. Default: 0.1.
        last_epoch = -1 # – The index of last epoch. Default: -1.
    elif sch == 'ExponentialLR':
        gamma = 0.99 # – Multiplicative factor of learning rate decay.
        last_epoch = -1 # – The index of last epoch. Default: -1.
    elif sch == 'CosineAnnealingLR':
        T_max = 50 # – Maximum number of iterations (half of the cosine period).
        eta_min = 0.00001 # – Minimum learning rate. Default: 0.
        last_epoch = -1 # – The index of last epoch. Default: -1.
    elif sch == 'ReduceLROnPlateau':
        mode = 'min' # – One of min, max. In min mode, lr will be reduced when the quantity monitored has stopped decreasing; in max mode it will be reduced when the quantity monitored has stopped increasing. Default: ‘min’.
        factor = 0.1 # – Factor by which the learning rate will be reduced. new_lr = lr * factor. Default: 0.1.
        patience = 10 # – Number of epochs with no improvement after which learning rate will be reduced. For example, if patience = 2, then we will ignore the first 2 epochs with no improvement, and will only decrease the LR after the 3rd epoch if the loss still hasn’t improved then. Default: 10.
        threshold = 0.0001 # – Threshold for measuring the new optimum, to only focus on significant changes. Default: 1e-4. Note: this overrides the segmentation threshold (0.5) defined above.
        threshold_mode = 'rel' # – One of rel, abs. In rel mode, dynamic_threshold = best * ( 1 + threshold ) in ‘max’ mode or best * ( 1 - threshold ) in min mode. In abs mode, dynamic_threshold = best + threshold in max mode or best - threshold in min mode. Default: ‘rel’.
        cooldown = 0 # – Number of epochs to wait before resuming normal operation after lr has been reduced. Default: 0.
        min_lr = 0 # – A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively. Default: 0.
        eps = 1e-08 # – Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is ignored. Default: 1e-8.
    elif sch == 'CosineAnnealingWarmRestarts':
        T_0 = 50 # – Number of iterations for the first restart.
        T_mult = 2 # – A factor increases T_{i} after a restart. Default: 1.
        eta_min = 1e-6 # – Minimum learning rate. Default: 0.
        last_epoch = -1 # – The index of last epoch. Default: -1.
    elif sch == 'WP_MultiStepLR':
        warm_up_epochs = 10
        gamma = 0.1
        milestones = [125, 225]
    elif sch == 'WP_CosineLR':
        warm_up_epochs = 20
--------------------------------------------------------------------------------
/dataprepare/Prepare_ISIC2017.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Code created on Sat Jun 8 18:15:43 2019
@author: Reza Azad
"""

"""
Reminder Created on December 6, 2023.
@author: Renkai Wu
1. The scipy package needs to be downgraded (scipy==1.2.1). Otherwise, you need to modify the following code.
2. The list of files to be processed is printed first. If it is empty, the output npy files are incorrect.
3. When Dataset_add is a relative path, it must start with './'.
"""

import h5py
import numpy as np
import scipy.io as sio
import scipy.misc as sc
import glob

# Parameters
height = 256
width = 256
channels = 3

############################################################# Prepare ISIC 2017 data set #################################################
Dataset_add = './ISIC2017/'
Tr_add = 'ISIC2017_Task1-2_Training_Input'

Tr_list = glob.glob(Dataset_add+ Tr_add+'/*.jpg')
# It contains 2000 training samples
Data_train_2017 = np.zeros([2000, height, width, channels])
Label_train_2017 = np.zeros([2000, height, width])

print('Reading ISIC 2017')
print(Tr_list)
for idx in range(len(Tr_list)):
    print(idx+1)
    img = sc.imread(Tr_list[idx])
    img = np.double(sc.imresize(img, [height, width, channels], interp='bilinear', mode = 'RGB'))
    Data_train_2017[idx, :,:,:] = img

    b = Tr_list[idx]
    a = b[0:len(Dataset_add)]
    b = b[len(b)-16: len(b)-4]  # image stem, e.g. 'ISIC_0000000'
    add = (a+ 'ISIC2017_Task1_Training_GroundTruth/' + b +'_segmentation.png')
    img2 = sc.imread(add)
    img2 = np.double(sc.imresize(img2, [height, width], interp='bilinear'))
    Label_train_2017[idx, :,:] = img2

print('Reading ISIC 2017 finished')

################################################################ Make the train and test sets ########################################
# We consider 1250 samples for training, 150 samples for validation and 600 samples for testing

Train_img = Data_train_2017[0:1250,:,:,:]
Validation_img = Data_train_2017[1250:1250+150,:,:,:]
Test_img = Data_train_2017[1250+150:2000,:,:,:]

Train_mask = Label_train_2017[0:1250,:,:]
Validation_mask = Label_train_2017[1250:1250+150,:,:]
Test_mask = Label_train_2017[1250+150:2000,:,:]


np.save('data_train', Train_img)
np.save('data_test' , Test_img)
np.save('data_val' , Validation_img)

np.save('mask_train', Train_mask)
np.save('mask_test' , Test_mask)
np.save('mask_val' , Validation_mask)


--------------------------------------------------------------------------------
/dataprepare/Prepare_your_dataset.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

##scipy==1.2.1

import h5py
import numpy as np
import scipy.io as sio
import scipy.misc as sc
import glob

# Parameters
height = 256 # Input image height of the model.
width = 256 # Input image width of the model.
channels = 3 # Number of image channels

train_number = 1000 # Number of images for the training set.
val_number = 200 # Number of images for the validation set.
test_number = 400 # Number of images for the test set.
all = int(train_number) + int(val_number) + int(test_number)

############################################################# Prepare your data set #################################################
Tr_list = glob.glob("images"+'/*.png') # Images storage folder. The image type should be 24-bit png format.
# Preallocate one slot per image. The images folder should contain exactly `all` images:
# extra images raise an IndexError, missing ones leave zero-filled samples.
Data_train_2018 = np.zeros([all, height, width, channels])
Label_train_2018 = np.zeros([all, height, width])

print('Reading')
print(len(Tr_list))
for idx in range(len(Tr_list)):
    print(idx+1)
    img = sc.imread(Tr_list[idx])
    img = np.double(sc.imresize(img, [height, width, channels], interp='bilinear', mode = 'RGB'))
    Data_train_2018[idx, :,:,:] = img

    b = Tr_list[idx]
    b = b[len(b)-8: len(b)-4]  # the four characters before '.png', e.g. '0000'
    add = ("masks/" + b +'.png') # Masks storage folder. Each mask should be a black-and-white 8-bit png (pixel value 0 for the background, 255 for the target).
    img2 = sc.imread(add)
    img2 = np.double(sc.imresize(img2, [height, width], interp='bilinear'))
    Label_train_2018[idx, :,:] = img2

print('Reading your dataset finished')

################################################################ Make the training, validation and test sets ########################################
# The split follows the order returned by glob; the images are not shuffled.
Train_img = Data_train_2018[0:train_number,:,:,:]
Validation_img = Data_train_2018[train_number:train_number+val_number,:,:,:]
Test_img = Data_train_2018[train_number+val_number:all,:,:,:]

Train_mask = Label_train_2018[0:train_number,:,:]
Validation_mask = Label_train_2018[train_number:train_number+val_number,:,:]
Test_mask = Label_train_2018[train_number+val_number:all,:,:]


np.save('data_train', Train_img)
np.save('data_test' , Test_img)
np.save('data_val' , Validation_img)

np.save('mask_train', Train_mask)
np.save('mask_test' , Test_mask)
np.save('mask_val' , Validation_mask)
--------------------------------------------------------------------------------
/engine.py:
--------------------------------------------------------------------------------
import numpy as np
from tqdm import tqdm
import torch
from torch.cuda.amp import autocast as autocast
from sklearn.metrics import confusion_matrix
from utils import save_imgs


def train_one_epoch(train_loader,
                    model,
                    criterion,
                    optimizer,
                    scheduler,
                    epoch,
                    logger,
                    config,
                    scaler=None):
    '''
    train model for one epoch
    '''
    # switch to train mode
    model.train()

    loss_list = []

    for iter, data in enumerate(train_loader):
        optimizer.zero_grad()
        images, targets = data
        images, targets = images.cuda(non_blocking=True).float(), targets.cuda(non_blocking=True).float()
        if config.amp:
            with autocast():
                out = model(images)
                loss = criterion(out, targets)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            out = model(images)
            loss = criterion(out, targets)
            loss.backward()
            optimizer.step()

        loss_list.append(loss.item())

        now_lr = optimizer.state_dict()['param_groups'][0]['lr']
        if iter % config.print_interval == 0:
            log_info = f'train: epoch {epoch}, iter:{iter}, loss: {np.mean(loss_list):.4f}, lr: {now_lr}'
            print(log_info)
            logger.info(log_info)
    # the learning-rate scheduler is stepped once per epoch
    scheduler.step()


def val_one_epoch(test_loader,
                  model,
                  criterion,
                  epoch,
                  logger,
                  config):
    # switch to evaluate mode
    model.eval()
    preds = []
    gts = []
    loss_list = []
    with torch.no_grad():
        for data in tqdm(test_loader):
            img, msk = data
            img, msk = img.cuda(non_blocking=True).float(), msk.cuda(non_blocking=True).float()
            out = model(img)
            loss = criterion(out, msk)
            loss_list.append(loss.item())
            gts.append(msk.squeeze(1).cpu().detach().numpy())
            if type(out) is tuple:
                out = out[0]
            out = out.squeeze(1).cpu().detach().numpy()
            preds.append(out)

    if epoch % config.val_interval == 0:
        preds = np.array(preds).reshape(-1)
        gts = np.array(gts).reshape(-1)

        y_pre = np.where(preds>=config.threshold, 1, 0)
        y_true = np.where(gts>=0.5, 1, 0)

        confusion = confusion_matrix(y_true, y_pre)
        TN, FP, FN, TP = confusion[0,0], confusion[0,1], confusion[1,0], confusion[1,1]

        accuracy = float(TN + TP) / float(np.sum(confusion)) if float(np.sum(confusion)) != 0 else 0
        sensitivity = float(TP) / float(TP + FN) if float(TP + FN) != 0 else 0
        specificity = float(TN) / float(TN + FP) if float(TN + FP) != 0 else 0
        f1_or_dsc = float(2 * TP) / float(2 * TP + FP + FN) if float(2 * TP + FP + FN) != 0 else 0
        miou = float(TP) / float(TP + FP + FN) if float(TP + FP + FN) != 0 else 0

        log_info = (f'val epoch: {epoch}, loss: {np.mean(loss_list):.4f}, miou: {miou}, f1_or_dsc: {f1_or_dsc}, '
                    f'accuracy: {accuracy}, specificity: {specificity}, sensitivity: {sensitivity}, '
                    f'confusion_matrix: {confusion}')
        print(log_info)
        logger.info(log_info)

    else:
        log_info = f'val epoch: {epoch}, loss: {np.mean(loss_list):.4f}'
        print(log_info)
        logger.info(log_info)

    return np.mean(loss_list)


def test_one_epoch(test_loader,
                   model,
                   criterion,
                   logger,
                   config,
                   test_data_name=None):
    # switch to evaluate mode
    model.eval()
    preds = []
    gts = []
    loss_list = []
    with torch.no_grad():
        for i, data in enumerate(tqdm(test_loader)):
            img, msk = data
            img, msk = img.cuda(non_blocking=True).float(), msk.cuda(non_blocking=True).float()
            out = model(img)
            loss = criterion(out, msk)
            loss_list.append(loss.item())
            msk = msk.squeeze(1).cpu().detach().numpy()
            gts.append(msk)
            if type(out) is tuple:
                out = out[0]
            out = out.squeeze(1).cpu().detach().numpy()
            preds.append(out)
            save_imgs(img, msk, out, i, config.work_dir + 'outputs/', config.datasets, config.threshold, test_data_name=test_data_name)

        preds = np.array(preds).reshape(-1)
        gts = np.array(gts).reshape(-1)

        y_pre = np.where(preds>=config.threshold, 1, 0)
        y_true = np.where(gts>=0.5, 1, 0)

        confusion = confusion_matrix(y_true, y_pre)
        TN, FP, FN, TP = confusion[0,0], confusion[0,1], confusion[1,0], confusion[1,1]

        accuracy = float(TN + TP) / float(np.sum(confusion)) if float(np.sum(confusion)) != 0 else 0
        sensitivity = float(TP) / float(TP + FN) if float(TP + FN) != 0 else 0
        specificity = float(TN) / float(TN + FP) if float(TN + FP) != 0 else 0
        f1_or_dsc = float(2 * TP) / float(2 * TP + FP + FN) if float(2 * TP + FP + FN) != 0 else 0
        miou = float(TP) / float(TP + FP + FN) if float(TP + FP + FN) != 0 else 0

        if test_data_name is not None:
            log_info = f'test_datasets_name: {test_data_name}'
            print(log_info)
            logger.info(log_info)
        log_info = (f'test of best model, loss: {np.mean(loss_list):.4f}, miou: {miou}, f1_or_dsc: {f1_or_dsc}, '
                    f'accuracy: {accuracy}, specificity: {specificity}, sensitivity: {sensitivity}, '
                    f'confusion_matrix: {confusion}')
        print(log_info)
        logger.info(log_info)

    return np.mean(loss_list)
--------------------------------------------------------------------------------
/loader.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
import random
import os
from PIL import Image
from einops.layers.torch import Rearrange
from scipy.ndimage import binary_dilation
from torch.utils.data import Dataset
from torchvision import transforms
from scipy import ndimage
from utils import *


# ===== normalize over the dataset
def dataset_normalized(imgs):
    imgs_normalized = np.empty(imgs.shape)
    imgs_std = np.std(imgs)
    imgs_mean = np.mean(imgs)
    imgs_normalized = (imgs-imgs_mean)/imgs_std
    for i in range(imgs.shape[0]):
        imgs_normalized[i] = ((imgs_normalized[i] - np.min(imgs_normalized[i])) / (np.max(imgs_normalized[i])-np.min(imgs_normalized[i])))*255
    return imgs_normalized


## Temporary
class isic_loader(Dataset):
    """ dataset class for the .npy files produced by the dataprepare scripts
    """
    def __init__(self, path_Data, train = True, Test = False):
        super(isic_loader, self).__init__()
        self.train = train
        if train:
            self.data = np.load(path_Data+'data_train.npy')
            self.mask = np.load(path_Data+'mask_train.npy')
        else:
            if Test:
                self.data = np.load(path_Data+'data_test.npy')
                self.mask = np.load(path_Data+'mask_test.npy')
            else:
                self.data = np.load(path_Data+'data_val.npy')
                self.mask = np.load(path_Data+'mask_val.npy')

        self.data = dataset_normalized(self.data)
        self.mask = np.expand_dims(self.mask, axis=3)
        self.mask = self.mask/255.

    def __getitem__(self, indx):
        img = self.data[indx]
        seg = self.mask[indx]
        if self.train:
            if random.random() > 0.5:
                img, seg = self.random_rot_flip(img, seg)
            if random.random() > 0.5:
                img, seg = self.random_rotate(img, seg)

        seg = torch.tensor(seg.copy())
        img = torch.tensor(img.copy())
        img = img.permute( 2, 0, 1)
        seg = seg.permute( 2, 0, 1)

        return img, seg

    def random_rot_flip(self,image, label):
        k = np.random.randint(0, 4)
        image = np.rot90(image, k)
        label = np.rot90(label, k)
        axis = np.random.randint(0, 2)
        image = np.flip(image, axis=axis).copy()
        label = np.flip(label, axis=axis).copy()
        return image, label

    def random_rotate(self,image, label):
        angle = np.random.randint(20, 80)
        image = ndimage.rotate(image, angle, order=0, reshape=False)
        label = ndimage.rotate(label, angle, order=0, reshape=False)
        return image, label



    def __len__(self):
        return len(self.data)
--------------------------------------------------------------------------------
/models/H_vmunet.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import torch.nn.functional as F
from functools import partial
from timm.models.layers import trunc_normal_, DropPath
import os
import sys
import torch.fft
import math
from .vmamba import SS2D

import traceback

class DepthWiseConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size=3, padding=1, stride=1, dilation=1):
        super().__init__()

        self.conv1 = nn.Conv2d(dim_in, dim_in, kernel_size=kernel_size, padding=padding,
                               stride=stride, dilation=dilation, groups=dim_in)
        self.norm_layer = nn.GroupNorm(4, dim_in)
        self.conv2 = nn.Conv2d(dim_in, dim_out, kernel_size=1)

    def forward(self, x):
        return self.conv2(self.norm_layer(self.conv1(x)))

if 'DWCONV_IMPL' in os.environ:
    try:
        sys.path.append(os.environ['DWCONV_IMPL'])
        from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM
        def get_dwconv(dim, kernel, bias):
            return DepthWiseConv2dImplicitGEMM(dim, kernel, bias)
        # print('Using Megvii large kernel dw conv impl')
    except Exception:
        print(traceback.format_exc())
        def get_dwconv(dim, kernel, bias):
            return nn.Conv2d(dim, dim, kernel_size=kernel, padding=(kernel-1)//2 ,bias=bias, groups=dim)

        # print('[fail to use Megvii Large kernel] Using PyTorch large kernel dw conv impl')
else:
    def get_dwconv(dim, kernel, bias):
        return nn.Conv2d(dim, dim, kernel_size=kernel, padding=(kernel-1)//2 ,bias=bias, groups=dim)

    # print('Using PyTorch large kernel dw conv impl')

class H_SS2D(nn.Module):
    def __init__(self, dim, order=5, gflayer=None, h=14, w=8, s=1.0,d_state=16):
        super().__init__()
        self.order = order
        self.dims = [dim // 2 ** i for i in range(order)]
        self.dims.reverse()
        self.proj_in = nn.Conv2d(dim, 2*dim, 1)

        if gflayer is None:
            self.dwconv = get_dwconv(sum(self.dims), 7, True)
        else:
            self.dwconv = gflayer(sum(self.dims), h=h, w=w)

        self.proj_out = nn.Conv2d(dim, dim, 1)

        self.pws = nn.ModuleList(
            [nn.Conv2d(self.dims[i], self.dims[i+1], 1) for i in range(order-1)]
        )

        num = len(self.dims)
        if num == 2:
            self.ss2d_1 = SS2D(d_model=self.dims[1], dropout=0, d_state=d_state)
        elif num == 3 :
            self.ss2d_1 = SS2D(d_model=self.dims[1], dropout=0, d_state=d_state)
            self.ss2d_2 = SS2D(d_model=self.dims[2], dropout=0, d_state=d_state)
        elif num == 4 :
            self.ss2d_1 = SS2D(d_model=self.dims[1], dropout=0, d_state=d_state)
            self.ss2d_2 = SS2D(d_model=self.dims[2], dropout=0, d_state=d_state)
            self.ss2d_3 = SS2D(d_model=self.dims[3], dropout=0, d_state=d_state)
        elif num == 5 :
            self.ss2d_1 = SS2D(d_model=self.dims[1], dropout=0, d_state=d_state)
            self.ss2d_2 = SS2D(d_model=self.dims[2], dropout=0, d_state=d_state)
            self.ss2d_3 = SS2D(d_model=self.dims[3], dropout=0, d_state=d_state)
            self.ss2d_4 = SS2D(d_model=self.dims[4], dropout=0, d_state=d_state)

        self.ss2d_in = SS2D(d_model=self.dims[0], dropout=0, d_state=d_state)

        self.scale = s

        print('[H_SS2D]', order, 'order with dims=', self.dims, 'scale=%.4f'%self.scale)


    def forward(self, x, mask=None, dummy=False):
        B, C, H, W = x.shape

        fused_x = self.proj_in(x)
        pwa, abc = torch.split(fused_x, (self.dims[0], sum(self.dims)), dim=1)

        dw_abc = self.dwconv(abc) * self.scale

        dw_list = torch.split(dw_abc, self.dims, dim=1)
        x = pwa * dw_list[0]
        x = x.permute(0, 2, 3, 1)
        x = self.ss2d_in(x)
        x = x.permute(0, 3, 1, 2)

        for i in range(self.order -1):
            x = self.pws[i](x) * dw_list[i+1]
            if i == 0 :
                x = x.permute(0, 2, 3, 1)
                x = self.ss2d_1(x)
                x = x.permute(0, 3, 1, 2)
            elif i == 1 :
                x = x.permute(0, 2, 3, 1)
                x = self.ss2d_2(x)
                x = x.permute(0, 3, 1, 2)
            elif i == 2 :
                x = x.permute(0, 2, 3, 1)
                x = self.ss2d_3(x)
                x = x.permute(0, 3, 1, 2)
            elif i == 3 :
                x = x.permute(0, 2, 3, 1)
                x = self.ss2d_4(x)
                x = x.permute(0, 3, 1, 2)

        x = self.proj_out(x)

        return x

class Block(nn.Module):
    r""" H_VSS Block
    """
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, H_SS2D=H_SS2D):
        super().__init__()

        self.norm1 = LayerNorm(dim, eps=1e-6, data_format='channels_first')
        self.H_SS2D = H_SS2D(dim)
        self.norm2 = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)

        self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones(dim),
                                   requires_grad=True) if layer_scale_init_value > 0 else None

        self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                   requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        if self.gamma1 is not None:
            gamma1 = self.gamma1.view(C, 1, 1)
        else:
            gamma1 = 1
        x = x + self.drop_path(gamma1 * self.H_SS2D(self.norm1(x)))

        input = x
        x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
        x = self.norm2(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma2 is not None:
            x = self.gamma2 * x
        x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
        return x


class Channel_Att_Bridge(nn.Module):
    def __init__(self, c_list, split_att='fc'):
        super().__init__()
        c_list_sum = sum(c_list) - c_list[-1]
        self.split_att = split_att
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.get_all_att = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False)
        self.att1 = nn.Linear(c_list_sum, c_list[0]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[0], 1)
        self.att2 = nn.Linear(c_list_sum, c_list[1]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[1], 1)
        self.att3 = nn.Linear(c_list_sum, c_list[2]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[2], 1)
        self.att4 = nn.Linear(c_list_sum, c_list[3]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[3], 1)
        self.att5 = nn.Linear(c_list_sum, c_list[4]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[4], 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, t1, t2, t3, t4, t5):
        att = torch.cat((self.avgpool(t1),
                         self.avgpool(t2),
                         self.avgpool(t3),
                         self.avgpool(t4),
                         self.avgpool(t5)), dim=1)
        att = self.get_all_att(att.squeeze(-1).transpose(-1, -2))
        if self.split_att != 'fc':
            att = att.transpose(-1, -2)
        att1 = self.sigmoid(self.att1(att))
        att2 = self.sigmoid(self.att2(att))
        att3 = self.sigmoid(self.att3(att))
        att4 = self.sigmoid(self.att4(att))
        att5 = self.sigmoid(self.att5(att))
        if self.split_att == 'fc':
            att1 = att1.transpose(-1, -2).unsqueeze(-1).expand_as(t1)
            att2 = att2.transpose(-1, -2).unsqueeze(-1).expand_as(t2)
            att3 = att3.transpose(-1, -2).unsqueeze(-1).expand_as(t3)
            att4 = att4.transpose(-1, -2).unsqueeze(-1).expand_as(t4)
            att5 = att5.transpose(-1, -2).unsqueeze(-1).expand_as(t5)
        else:
            att1 = att1.unsqueeze(-1).expand_as(t1)
            att2 = att2.unsqueeze(-1).expand_as(t2)
            att3 = att3.unsqueeze(-1).expand_as(t3)
            att4 = att4.unsqueeze(-1).expand_as(t4)
            att5 = att5.unsqueeze(-1).expand_as(t5)

        return att1, att2, att3, att4, att5


class Spatial_Att_Bridge(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared_conv2d = nn.Sequential(nn.Conv2d(2, 1, 7, stride=1, padding=9, dilation=3),
                                           nn.Sigmoid())

    def forward(self, t1, t2, t3, t4, t5):
        t_list = [t1, t2, t3, t4, t5]
        att_list = []
        for t in t_list:
            avg_out = torch.mean(t, dim=1, keepdim=True)
            max_out, _ = torch.max(t, dim=1, keepdim=True)
            att = torch.cat([avg_out, max_out], dim=1)
            att = self.shared_conv2d(att)
            att_list.append(att)
        return att_list[0], att_list[1], att_list[2], att_list[3], att_list[4]


class SC_Att_Bridge(nn.Module):
    def __init__(self, c_list, split_att='fc'):
        super().__init__()

        self.catt = 
Channel_Att_Bridge(c_list, split_att=split_att) 233 | self.satt = Spatial_Att_Bridge() 234 | 235 | def forward(self, t1, t2, t3, t4, t5): 236 | r1, r2, r3, r4, r5 = t1, t2, t3, t4, t5 237 | 238 | satt1, satt2, satt3, satt4, satt5 = self.satt(t1, t2, t3, t4, t5) 239 | t1, t2, t3, t4, t5 = satt1 * t1, satt2 * t2, satt3 * t3, satt4 * t4, satt5 * t5 240 | 241 | r1_, r2_, r3_, r4_, r5_ = t1, t2, t3, t4, t5 242 | t1, t2, t3, t4, t5 = t1 + r1, t2 + r2, t3 + r3, t4 + r4, t5 + r5 243 | 244 | catt1, catt2, catt3, catt4, catt5 = self.catt(t1, t2, t3, t4, t5) 245 | t1, t2, t3, t4, t5 = catt1 * t1, catt2 * t2, catt3 * t3, catt4 * t4, catt5 * t5 246 | 247 | return t1 + r1_, t2 + r2_, t3 + r3_, t4 + r4_, t5 + r5_ 248 | 249 | 250 | class H_vmunet(nn.Module): 251 | 252 | def __init__(self, num_classes=1, input_channels=3,layer_scale_init_value=1e-6,H_SS2D=H_SS2D, block=Block,pretrained=None, 253 | use_checkpoint=False, c_list=[8,16,32,64,128,256], depths=[2, 2, 2, 2],drop_path_rate=0., 254 | split_att='fc', bridge=True): 255 | super().__init__() 256 | self.pretrained = pretrained 257 | self.use_checkpoint = use_checkpoint 258 | self.bridge = bridge 259 | 260 | self.encoder1 = nn.Sequential( 261 | nn.Conv2d(input_channels, c_list[0], 3, stride=1, padding=1), 262 | ) 263 | self.encoder2 =nn.Sequential( 264 | nn.Conv2d(c_list[0], c_list[1], 3, stride=1, padding=1), 265 | ) 266 | 267 | 268 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 269 | 270 | if not isinstance(H_SS2D, list): 271 | H_SS2D = [partial(H_SS2D, order=2, s=1/3, gflayer=Local_SS2D), 272 | partial(H_SS2D, order=3, s=1/3, gflayer=Local_SS2D), 273 | partial(H_SS2D, order=4, s=1/3, h=24, w=13, gflayer=Local_SS2D), 274 | partial(H_SS2D, order=5, s=1/3, h=12, w=7, gflayer=Local_SS2D)] 275 | else: 276 | H_SS2D = H_SS2D 277 | assert len(H_SS2D) == 4 278 | 279 | if isinstance(H_SS2D[0], str): 280 | H_SS2D = [eval(h) for h in H_SS2D] 281 | 282 | if isinstance(block, str): 283 | block = eval(block) 
284 |
285 |
286 |
287 |         self.encoder3 = nn.Sequential(
288 |             *[block(dim=c_list[1], drop_path=dp_rates[0 + j],
289 |                     layer_scale_init_value=layer_scale_init_value, H_SS2D=H_SS2D[0]) for j in range(depths[0])],
290 |             nn.Conv2d(c_list[1], c_list[2], 3, stride=1, padding=1),
291 |         )
292 |
293 |         self.encoder4 = nn.Sequential(
294 |             *[block(dim=c_list[2], drop_path=dp_rates[2 + j],
295 |                     layer_scale_init_value=layer_scale_init_value, H_SS2D=H_SS2D[1]) for j in range(depths[1])],
296 |             nn.Conv2d(c_list[2], c_list[3], 3, stride=1, padding=1),
297 |         )
298 |
299 |
300 |         self.encoder5 = nn.Sequential(
301 |             *[block(dim=c_list[3], drop_path=dp_rates[4 + j],
302 |                     layer_scale_init_value=layer_scale_init_value, H_SS2D=H_SS2D[2]) for j in range(depths[2])],
303 |             nn.Conv2d(c_list[3], c_list[4], 3, stride=1, padding=1),
304 |         )
305 |
306 |         self.encoder6 = nn.Sequential(
307 |             *[block(dim=c_list[4], drop_path=dp_rates[6 + j],
308 |                     layer_scale_init_value=layer_scale_init_value, H_SS2D=H_SS2D[3]) for j in range(depths[3])],
309 |             nn.Conv2d(c_list[4], c_list[5], 3, stride=1, padding=1),
310 |         )
311 |
312 |
313 |         if bridge:
314 |             self.scab = SC_Att_Bridge(c_list, split_att)
315 |             print('SC_Att_Bridge was used')
316 |
317 |         self.decoder1 = nn.Sequential(
318 |             *[block(dim=c_list[5], drop_path=dp_rates[6],
319 |                     layer_scale_init_value=layer_scale_init_value, H_SS2D=H_SS2D[3]) for j in range(depths[3])],
320 |             nn.Conv2d(c_list[5], c_list[4], 3, stride=1, padding=1),
321 |         )
322 |
323 |
324 |         self.decoder2 = nn.Sequential(
325 |             *[block(dim=c_list[4], drop_path=dp_rates[4+j],
326 |                     layer_scale_init_value=layer_scale_init_value, H_SS2D=H_SS2D[2]) for j in range(depths[2])],
327 |             nn.Conv2d(c_list[4], c_list[3], 3, stride=1, padding=1),
328 |         )
329 |
330 |         self.decoder3 = nn.Sequential(
331 |             *[block(dim=c_list[3], drop_path=dp_rates[2+j],
332 |                     layer_scale_init_value=layer_scale_init_value, H_SS2D=H_SS2D[1]) for j in range(depths[1])],
333 |             nn.Conv2d(c_list[3], c_list[2], 3, stride=1, padding=1),
334 |         )
335 |
336 |         self.decoder4 = nn.Sequential(
337 |             *[block(dim=c_list[2], drop_path=dp_rates[0+j],
338 |                     layer_scale_init_value=layer_scale_init_value, H_SS2D=H_SS2D[0]) for j in range(depths[0])],
339 |             nn.Conv2d(c_list[2], c_list[1], 3, stride=1, padding=1),
340 |         )
341 |
342 |         self.decoder5 = nn.Sequential(
343 |             nn.Conv2d(c_list[1], c_list[0], 3, stride=1, padding=1),
344 |         )
345 |
346 |         self.ebn1 = nn.GroupNorm(4, c_list[0])
347 |         self.ebn2 = nn.GroupNorm(4, c_list[1])
348 |         self.ebn3 = nn.GroupNorm(4, c_list[2])
349 |         self.ebn4 = nn.GroupNorm(4, c_list[3])
350 |         self.ebn5 = nn.GroupNorm(4, c_list[4])
351 |         self.dbn1 = nn.GroupNorm(4, c_list[4])
352 |         self.dbn2 = nn.GroupNorm(4, c_list[3])
353 |         self.dbn3 = nn.GroupNorm(4, c_list[2])
354 |         self.dbn4 = nn.GroupNorm(4, c_list[1])
355 |         self.dbn5 = nn.GroupNorm(4, c_list[0])
356 |
357 |         self.final = nn.Conv2d(c_list[0], num_classes, kernel_size=1)
358 |
359 |         self.apply(self._init_weights)
360 |
361 |
362 |
363 |     def _init_weights(self, m):
364 |         if isinstance(m, nn.Linear):
365 |             trunc_normal_(m.weight, std=.02)
366 |             if isinstance(m, nn.Linear) and m.bias is not None:
367 |                 nn.init.constant_(m.bias, 0)
368 |         elif isinstance(m, nn.Conv1d):
369 |             n = m.kernel_size[0] * m.out_channels
370 |             m.weight.data.normal_(0, math.sqrt(2. / n))
371 |         elif isinstance(m, nn.Conv2d):
372 |             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
373 |             fan_out //= m.groups
374 |             m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
375 |             if m.bias is not None:
376 |                 m.bias.data.zero_()
377 |
378 |
379 |     def forward(self, x):
380 |
381 |         out = F.gelu(F.max_pool2d(self.ebn1(self.encoder1(x)),2,2))
382 |         t1 = out # b, c0, H/2, W/2
383 |
384 |         out = F.gelu(F.max_pool2d(self.ebn2(self.encoder2(out)),2,2))
385 |         t2 = out # b, c1, H/4, W/4
386 |
387 |         out = F.gelu(F.max_pool2d(self.ebn3(self.encoder3(out)),2,2))
388 |         t3 = out # b, c2, H/8, W/8
389 |
390 |
391 |         out = F.gelu(F.max_pool2d(self.ebn4(self.encoder4(out)),2,2))
392 |         t4 = out # b, c3, H/16, W/16
393 |
394 |         out = F.gelu(F.max_pool2d(self.ebn5(self.encoder5(out)),2,2))
395 |         t5 = out # b, c4, H/32, W/32
396 |
397 |         if self.bridge: t1, t2, t3, t4, t5 = self.scab(t1, t2, t3, t4, t5)
398 |
399 |         out = F.gelu(self.encoder6(out)) # b, c5, H/32, W/32
400 |
401 |         out5 = F.gelu(self.dbn1(self.decoder1(out))) # b, c4, H/32, W/32
402 |         out5 = torch.add(out5, t5) # b, c4, H/32, W/32
403 |
404 |         out4 = F.gelu(F.interpolate(self.dbn2(self.decoder2(out5)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c3, H/16, W/16
405 |         out4 = torch.add(out4, t4) # b, c3, H/16, W/16
406 |
407 |         out3 = F.gelu(F.interpolate(self.dbn3(self.decoder3(out4)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c2, H/8, W/8
408 |         out3 = torch.add(out3, t3) # b, c2, H/8, W/8
409 |
410 |         out2 = F.gelu(F.interpolate(self.dbn4(self.decoder4(out3)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c1, H/4, W/4
411 |         out2 = torch.add(out2, t2) # b, c1, H/4, W/4
412 |
413 |         out1 = F.gelu(F.interpolate(self.dbn5(self.decoder5(out2)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c0, H/2, W/2
414 |         out1 = torch.add(out1, t1) # b, c0, H/2, W/2
415 |
416 |         out0 = F.interpolate(self.final(out1),scale_factor=(2,2),mode ='bilinear',align_corners=True) # b, num_class, H, W
417 |
418 |         return torch.sigmoid(out0)
419 |
420 | class LayerNorm(nn.Module):
421 |     r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
422 |     ``data_format`` sets the ordering of the dimensions in the inputs. channels_last corresponds to inputs with
423 |     shape (batch_size, height, width, channels) while channels_first corresponds to inputs
424 |     with shape (batch_size, channels, height, width).
425 |     """
426 |     def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
427 |         super().__init__()
428 |         self.weight = nn.Parameter(torch.ones(normalized_shape))
429 |         self.bias = nn.Parameter(torch.zeros(normalized_shape))
430 |         self.eps = eps
431 |         self.data_format = data_format
432 |         if self.data_format not in ["channels_last", "channels_first"]:
433 |             raise NotImplementedError
434 |         self.normalized_shape = (normalized_shape, )
435 |
436 |     def forward(self, x):
437 |         if self.data_format == "channels_last":
438 |             return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
439 |         elif self.data_format == "channels_first":
440 |             u = x.mean(1, keepdim=True)
441 |             s = (x - u).pow(2).mean(1, keepdim=True)
442 |             x = (x - u) / torch.sqrt(s + self.eps)
443 |             x = self.weight[:, None, None] * x + self.bias[:, None, None]
444 |             return x
445 |
446 | class Local_SS2D(nn.Module):
447 |     def __init__(self, dim, h=14, w=8):
448 |         super().__init__()
449 |         self.dw = nn.Conv2d(dim // 2, dim // 2, kernel_size=3, padding=1, bias=False, groups=dim // 2)
450 |         self.complex_weight = nn.Parameter(torch.randn(dim // 2, h, w, 2, dtype=torch.float32) * 0.02)
451 |         trunc_normal_(self.complex_weight, std=.02)
452 |         self.pre_norm = LayerNorm(dim, eps=1e-6, data_format='channels_first')
453 |         self.post_norm = LayerNorm(dim, eps=1e-6, data_format='channels_first')
454 |
455 |         self.SS2D = SS2D(d_model=dim // 2, dropout=0, d_state=16)
456 |
457 |
458 |     def forward(self, x):
459 |         x = self.pre_norm(x)
460 |         x1, x2 = torch.chunk(x, 2, dim=1)
461 |         x1 = self.dw(x1)
462 |
463 |         B, C, a, b = x2.shape
464 |
465 |         x2 = x2.permute(0, 2, 3, 1)
466 |
467 |         x2 = self.SS2D(x2)
468 |
469 |         x2 = x2.permute(0, 3, 1, 2)
470 |
471 |         x = torch.cat([x1.unsqueeze(2), x2.unsqueeze(2)], dim=2).reshape(B, 2 * C, a, b)
472 |         x = self.post_norm(x)
473 |         return x
474 |
--------------------------------------------------------------------------------
/models/vmamba.py:
--------------------------------------------------------------------------------
1 | import time
2 | import math
3 | from functools import partial
4 | from typing import Optional, Callable
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.utils.checkpoint as checkpoint
10 | from einops import rearrange, repeat
11 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
12 | try:
13 |     from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
14 | except:
15 |     pass
16 |
17 | # an alternative for mamba_ssm (in which causal_conv1d is needed)
18 | try:
19 |     from selective_scan import selective_scan_fn as selective_scan_fn_v1
20 |     from selective_scan import selective_scan_ref as selective_scan_ref_v1
21 | except:
22 |     pass
23 |
24 | DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"
25 |
26 |
27 | def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
28 |     """
29 |     u: r(B D L)
30 |     delta: r(B D L)
31 |     A: r(D N)
32 |     B: r(B N L)
33 |     C: r(B N L)
34 |     D: r(D)
35 |     z: r(B D L)
36 |     delta_bias: r(D), fp32
37 |
38 |     ignores:
39 |         [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
40 |     """
41 |     import numpy as np
42 |
43 |     # fvcore.nn.jit_handles
44 |     def get_flops_einsum(input_shapes, equation):
45 |         np_arrs = [np.zeros(s) for s in input_shapes]
46 |         optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
47 |         for line in optim.split("\n"):
48 |             if "optimized flop" in line.lower():
49 |                 # divided by 2 because we count MAC (multiply-add counted as one flop)
50 |                 flop = float(np.floor(float(line.split(":")[-1]) / 2))
51 |                 return flop
52 |
53 |
54 |     assert not with_complex
55 |
56 |     flops = 0 # below code flops = 0
57 |     if False:
58 |         ...
59 |         """
60 |         dtype_in = u.dtype
61 |         u = u.float()
62 |         delta = delta.float()
63 |         if delta_bias is not None:
64 |             delta = delta + delta_bias[..., None].float()
65 |         if delta_softplus:
66 |             delta = F.softplus(delta)
67 |         batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
68 |         is_variable_B = B.dim() >= 3
69 |         is_variable_C = C.dim() >= 3
70 |         if A.is_complex():
71 |             if is_variable_B:
72 |                 B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
73 |             if is_variable_C:
74 |                 C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
75 |         else:
76 |             B = B.float()
77 |             C = C.float()
78 |         x = A.new_zeros((batch, dim, dstate))
79 |         ys = []
80 |         """
81 |
82 |     flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")
83 |     if with_Group:
84 |         flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln")
85 |     else:
86 |         flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln")
87 |     if False:
88 |         ...
89 |         """
90 |         deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
91 |         if not is_variable_B:
92 |             deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
93 |         else:
94 |             if B.dim() == 3:
95 |                 deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
96 |             else:
97 |                 B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
98 |                 deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
99 |         if is_variable_C and C.dim() == 4:
100 |             C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
101 |         last_state = None
102 |         """
103 |
104 |     in_for_flops = B * D * N
105 |     if with_Group:
106 |         in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd")
107 |     else:
108 |         in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd")
109 |     flops += L * in_for_flops
110 |     if False:
111 |         ...
112 |         """
113 |         for i in range(u.shape[2]):
114 |             x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
115 |             if not is_variable_C:
116 |                 y = torch.einsum('bdn,dn->bd', x, C)
117 |             else:
118 |                 if C.dim() == 3:
119 |                     y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
120 |                 else:
121 |                     y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
122 |             if i == u.shape[2] - 1:
123 |                 last_state = x
124 |             if y.is_complex():
125 |                 y = y.real * 2
126 |             ys.append(y)
127 |         y = torch.stack(ys, dim=2) # (batch dim L)
128 |         """
129 |
130 |     if with_D:
131 |         flops += B * D * L
132 |     if with_Z:
133 |         flops += B * D * L
134 |     if False:
135 |         ...
136 |         """
137 |         out = y if D is None else y + u * rearrange(D, "d -> d 1")
138 |         if z is not None:
139 |             out = out * F.silu(z)
140 |         out = out.to(dtype=dtype_in)
141 |         """
142 |
143 |     return flops
144 |
145 |
146 | class PatchEmbed2D(nn.Module):
147 |     r""" Image to Patch Embedding
148 |     Args:
149 |         patch_size (int): Patch token size. Default: 4.
150 |         in_chans (int): Number of input image channels. Default: 3.
151 |         embed_dim (int): Number of linear projection output channels. Default: 96.
152 |         norm_layer (nn.Module, optional): Normalization layer. Default: None
153 |     """
154 |     def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs):
155 |         super().__init__()
156 |         if isinstance(patch_size, int):
157 |             patch_size = (patch_size, patch_size)
158 |         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
159 |         if norm_layer is not None:
160 |             self.norm = norm_layer(embed_dim)
161 |         else:
162 |             self.norm = None
163 |
164 |     def forward(self, x):
165 |         x = self.proj(x).permute(0, 2, 3, 1)
166 |         if self.norm is not None:
167 |             x = self.norm(x)
168 |         return x
169 |
170 |
171 | class PatchMerging2D(nn.Module):
172 |     r""" Patch Merging Layer.
173 |     Args:
174 |         input_resolution (tuple[int]): Resolution of input feature.
175 |         dim (int): Number of input channels.
176 |         norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
177 |     """
178 |
179 |     def __init__(self, dim, norm_layer=nn.LayerNorm):
180 |         super().__init__()
181 |         self.dim = dim
182 |         self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
183 |         self.norm = norm_layer(4 * dim)
184 |
185 |     def forward(self, x):
186 |         B, H, W, C = x.shape
187 |
188 |         SHAPE_FIX = [-1, -1]
189 |         if (W % 2 != 0) or (H % 2 != 0):
190 |             print(f"Warning, x.shape {x.shape} is not match even ===========", flush=True)
191 |             SHAPE_FIX[0] = H // 2
192 |             SHAPE_FIX[1] = W // 2
193 |
194 |         x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
195 |         x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
196 |         x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
197 |         x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
198 |
199 |         if SHAPE_FIX[0] > 0:
200 |             x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
201 |             x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
202 |             x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
203 |             x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
204 |
205 |         x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
206 |         x = x.view(B, H//2, W//2, 4 * C)  # B H/2 W/2 4*C
207 |
208 |         x = self.norm(x)
209 |         x = self.reduction(x)
210 |
211 |         return x
212 |
213 |
214 | class PatchExpand2D(nn.Module):
215 |     def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm):
216 |         super().__init__()
217 |         self.dim = dim*2
218 |         self.dim_scale = dim_scale
219 |         self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False)
220 |         self.norm = norm_layer(self.dim // dim_scale)
221 |
222 |     def forward(self, x):
223 |         B, H, W, C = x.shape
224 |         x = self.expand(x)
225 |
226 |         x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale)
227 |         x= self.norm(x)
228 |
229 |         return x
230 |
231 |
232 | class Final_PatchExpand2D(nn.Module):
233 |     def __init__(self, dim, dim_scale=4, norm_layer=nn.LayerNorm):
234 |         super().__init__()
235 |         self.dim = dim
236 |         self.dim_scale = dim_scale
237 |         self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False)
238 |         self.norm = norm_layer(self.dim // dim_scale)
239 |
240 |     def forward(self, x):
241 |         B, H, W, C = x.shape
242 |         x = self.expand(x)
243 |
244 |         x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale)
245 |         x= self.norm(x)
246 |
247 |         return x
248 |
249 |
250 | class SS2D(nn.Module):
251 |     def __init__(
252 |         self,
253 |         d_model,
254 |         d_state=16,
255 |         # d_state="auto", # 20240109
256 |         d_conv=3,
257 |         expand=2,
258 |         dt_rank="auto",
259 |         dt_min=0.001,
260 |         dt_max=0.1,
261 |         dt_init="random",
262 |         dt_scale=1.0,
263 |         dt_init_floor=1e-4,
264 |         dropout=0.,
265 |         conv_bias=True,
266 |         bias=False,
267 |         device=None,
268 |         dtype=None,
269 |         **kwargs,
270 |     ):
271 |         factory_kwargs = {"device": device, "dtype": dtype}
272 |         super().__init__()
273 |         self.d_model = d_model
274 |         self.d_state = d_state
275 |         # self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_model # 20240109
276 |         self.d_conv = d_conv
277 |         self.expand = expand
278 |         self.d_inner = int(self.expand * self.d_model)
279 |         self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
280 |
281 |         self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
282 |         self.conv2d = nn.Conv2d(
283 |             in_channels=self.d_inner,
284 |             out_channels=self.d_inner,
285 |             groups=self.d_inner,
286 |             bias=conv_bias,
287 |             kernel_size=d_conv,
288 |             padding=(d_conv - 1) // 2,
289 |             **factory_kwargs,
290 |         )
291 |         self.act = nn.SiLU()
292 |
293 |         self.x_proj = (
294 |             nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
295 |             nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
296 |             nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
297 |             nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
298 |         )
299 |         self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=4, N, inner)
300 |         del self.x_proj
301 |
302 |         self.dt_projs = (
303 |             self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
304 |             self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
305 |             self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
306 |             self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
307 |         )
308 |         self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=4, inner, rank)
309 |         self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=4, inner)
310 |         del self.dt_projs
311 |
312 |         self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True) # (K * D, N) with K=4
313 |         self.Ds = self.D_init(self.d_inner, copies=4, merge=True) # (K * D) with K=4
314 |
315 |         # self.selective_scan = selective_scan_fn
316 |         self.forward_core = self.forward_corev0
317 |
318 |         self.out_norm = nn.LayerNorm(self.d_inner)
319 |         self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
320 |         self.dropout = nn.Dropout(dropout) if dropout > 0. else None
321 |
322 |     @staticmethod
323 |     def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs):
324 |         dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)
325 |
326 |         # Initialize special dt projection to preserve variance at initialization
327 |         dt_init_std = dt_rank**-0.5 * dt_scale
328 |         if dt_init == "constant":
329 |             nn.init.constant_(dt_proj.weight, dt_init_std)
330 |         elif dt_init == "random":
331 |             nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
332 |         else:
333 |             raise NotImplementedError
334 |
335 |         # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
336 |         dt = torch.exp(
337 |             torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
338 |             + math.log(dt_min)
339 |         ).clamp(min=dt_init_floor)
340 |         # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
341 |         inv_dt = dt + torch.log(-torch.expm1(-dt))
342 |         with torch.no_grad():
343 |             dt_proj.bias.copy_(inv_dt)
344 |         # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
345 |         dt_proj.bias._no_reinit = True
346 |
347 |         return dt_proj
348 |
349 |     @staticmethod
350 |     def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
351 |         # S4D real initialization
352 |         A = repeat(
353 |             torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
354 |             "n -> d n",
355 |             d=d_inner,
356 |         ).contiguous()
357 |         A_log = torch.log(A)  # Keep A_log in fp32
358 |         if copies > 1:
359 |             A_log = repeat(A_log, "d n -> r d n", r=copies)
360 |             if merge:
361 |                 A_log = A_log.flatten(0, 1)
362 |         A_log = nn.Parameter(A_log)
363 |         A_log._no_weight_decay = True
364 |         return A_log
365 |
366 |     @staticmethod
367 |     def D_init(d_inner, copies=1, device=None, merge=True):
368 |         # D "skip" parameter
369 |         D = torch.ones(d_inner, device=device)
370 |         if copies > 1:
371 |             D = repeat(D, "n1 -> r n1", r=copies)
372 |             if merge:
373 |                 D = D.flatten(0, 1)
374 |         D = nn.Parameter(D)  # Keep in fp32
375 |         D._no_weight_decay = True
376 |         return D
377 |
378 |     def forward_corev0(self, x: torch.Tensor):
379 |         self.selective_scan = selective_scan_fn
380 |
381 |         B, C, H, W = x.shape
382 |         L = H * W
383 |         K = 4
384 |
385 |         x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
386 |         xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)
387 |
388 |         x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
389 |         # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
390 |         dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
391 |         dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
392 |         # dts = dts + self.dt_projs_bias.view(1, K, -1, 1)
393 |
394 |         xs = xs.float().view(B, -1, L) # (b, k * d, l)
395 |         dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
396 |         Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l)
397 |         Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l)
398 |         Ds = self.Ds.float().view(-1) # (k * d)
399 |         As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)  # (k * d, d_state)
400 |         dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)
401 |
402 |         out_y = self.selective_scan(
403 |             xs, dts,
404 |             As, Bs, Cs, Ds, z=None,
405 |             delta_bias=dt_projs_bias,
406 |             delta_softplus=True,
407 |             return_last_state=False,
408 |         ).view(B, K, -1, L)
409 |         assert out_y.dtype == torch.float
410 |
411 |         inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
412 |         wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
413 |         invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
414 |
415 |         return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y
416 |
417 |     # an alternative to forward_corev0
418 |     def forward_corev1(self, x: torch.Tensor):
419 |         self.selective_scan = selective_scan_fn_v1
420 |
421 |         B, C, H, W = x.shape
422 |         L = H * W
423 |         K = 4
424 |
425 |         x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
426 |         xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)
427 |
428 |         x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
429 |         # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
430 |         dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
431 |         dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
432 |         # dts = dts + self.dt_projs_bias.view(1, K, -1, 1)
433 |
434 |         xs = xs.float().view(B, -1, L) # (b, k * d, l)
435 |         dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
436 |         Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l)
437 |         Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l)
438 |         Ds = self.Ds.float().view(-1) # (k * d)
439 |         As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)  # (k * d, d_state)
440 |         dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)
441 |
442 |         out_y = self.selective_scan(
443 |             xs, dts,
444 |             As, Bs, Cs, Ds,
445 |             delta_bias=dt_projs_bias,
446 |             delta_softplus=True,
447 |         ).view(B, K, -1, L)
448 |         assert out_y.dtype == torch.float
449 |
450 |         inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
451 |         wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
452 |         invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
453 |
454 |         return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y
455 |
456 |     def forward(self, x: torch.Tensor, **kwargs):
457 |         B, H, W, C = x.shape
458 |
459 |         xz = self.in_proj(x)
460 |         x, z = xz.chunk(2, dim=-1) # (b, h, w, d)
461 |
462 |         x = x.permute(0, 3, 1, 2).contiguous()
463 |         x = self.act(self.conv2d(x)) # (b, d, h, w)
464 |         y1, y2, y3, y4 = self.forward_core(x)
465 |         assert y1.dtype == torch.float32
466 |         y = y1 + y2 + y3 + y4
467 |         y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
468 |         y = self.out_norm(y)
469 |         y = y * F.silu(z)
470 |         out = self.out_proj(y)
471 |         if self.dropout is not None:
472 |             out = self.dropout(out)
473 |         return out
474 |
475 |
476 | class VSSBlock(nn.Module):
477 |     def __init__(
478 |         self,
479 |         hidden_dim: int = 0,
480 |         drop_path: float = 0,
481 |         norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
482 |         attn_drop_rate: float = 0,
483 |         d_state: int = 16,
484 |         **kwargs,
485 |     ):
486 |         super().__init__()
487 |         self.ln_1 = norm_layer(hidden_dim)
488 |         self.self_attention = SS2D(d_model=hidden_dim, dropout=attn_drop_rate, d_state=d_state, **kwargs)
489 |         self.drop_path = DropPath(drop_path)
490 |
491 |     def forward(self, input: torch.Tensor):
492 |         x = input + self.drop_path(self.self_attention(self.ln_1(input)))
493 |         return x
494 |
495 |
496 | class VSSLayer(nn.Module):
497 |     """ A basic VSS layer for one stage.
498 |     Args:
499 |         dim (int): Number of input channels.
500 |         depth (int): Number of blocks.
501 |         d_state (int, optional): State dimension of each SS2D block. Default: 16
502 |         attn_drop (float, optional): Attention dropout rate. Default: 0.0
503 |         drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
504 |         norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
505 |         downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
506 |         use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
        self,
        dim,
        depth,
        attn_drop=0.,
        drop_path=0.,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
        d_state=16,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.use_checkpoint = use_checkpoint

        self.blocks = nn.ModuleList([
            VSSBlock(
                hidden_dim=dim,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                attn_drop_rate=attn_drop,
                d_state=d_state,
            )
            for i in range(depth)])

        if True: # is this really applied? Yes, but it is overridden later in VSSM!
            def _init_weights(module: nn.Module):
                for name, p in module.named_parameters():
                    if name in ["out_proj.weight"]:
                        p = p.clone().detach_() # fake init, just to keep the seed ....
                        nn.init.kaiming_uniform_(p, a=math.sqrt(5))
            self.apply(_init_weights)

        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None


    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        if self.downsample is not None:
            x = self.downsample(x)

        return x



class VSSLayer_up(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        upsample (nn.Module | None, optional): Upsample layer applied at the start of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
        self,
        dim,
        depth,
        attn_drop=0.,
        drop_path=0.,
        norm_layer=nn.LayerNorm,
        upsample=None,
        use_checkpoint=False,
        d_state=16,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.use_checkpoint = use_checkpoint

        self.blocks = nn.ModuleList([
            VSSBlock(
                hidden_dim=dim,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                attn_drop_rate=attn_drop,
                d_state=d_state,
            )
            for i in range(depth)])

        if True: # is this really applied? Yes, but it is overridden later in VSSM!
            def _init_weights(module: nn.Module):
                for name, p in module.named_parameters():
                    if name in ["out_proj.weight"]:
                        p = p.clone().detach_() # fake init, just to keep the seed ....
                        nn.init.kaiming_uniform_(p, a=math.sqrt(5))
            self.apply(_init_weights)

        if upsample is not None:
            self.upsample = upsample(dim=dim, norm_layer=norm_layer)
        else:
            self.upsample = None


    def forward(self, x):
        if self.upsample is not None:
            x = self.upsample(x)
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        return x



class VSSM(nn.Module):
    def __init__(self, patch_size=4, in_chans=3, num_classes=1000, depths=[2, 2, 9, 2], depths_decoder=[2, 9, 2, 2],
                 dims=[96, 192, 384, 768], dims_decoder=[768, 384, 192, 96], d_state=16, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        if isinstance(dims, int):
            dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
        self.embed_dim = dims[0]
        self.num_features = dims[-1]
        self.dims = dims

        self.patch_embed = PatchEmbed2D(patch_size=patch_size, in_chans=in_chans, embed_dim=self.embed_dim,
                                        norm_layer=norm_layer if patch_norm else None)

        # WASTED absolute position embedding ======================
        self.ape = False
        # self.ape = False
        # drop_rate = 0.0
        if self.ape:
            self.patches_resolution = self.patch_embed.patches_resolution
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, *self.patches_resolution, self.embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        dpr_decoder = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_decoder))][::-1]

        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = VSSLayer(
                dim=dims[i_layer],
                depth=depths[i_layer],
                d_state=math.ceil(dims[0] / 6) if d_state is None else d_state, # 20240109
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging2D if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
            )
            self.layers.append(layer)

        self.layers_up = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = VSSLayer_up(
                dim=dims_decoder[i_layer],
                depth=depths_decoder[i_layer],
                d_state=math.ceil(dims[0] / 6) if d_state is None else d_state, # 20240109
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr_decoder[sum(depths_decoder[:i_layer]):sum(depths_decoder[:i_layer + 1])],
                norm_layer=norm_layer,
                upsample=PatchExpand2D if (i_layer != 0) else None,
                use_checkpoint=use_checkpoint,
            )
            self.layers_up.append(layer)

        self.final_up = Final_PatchExpand2D(dim=dims_decoder[-1], dim_scale=4, norm_layer=norm_layer)
        self.final_conv = nn.Conv2d(dims_decoder[-1]//4, num_classes, 1)

        # self.norm = norm_layer(self.num_features)
        # self.avgpool = nn.AdaptiveAvgPool1d(1)
        # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module):
        """
        out_proj.weight, which was previously initialized in VSSBlock, is overwritten here by the nn.Linear branch
        no fc.weight found in any of the model parameters
        no nn.Embedding found in any of the model parameters
        so the thing is, VSSBlock initialization is useless

        Conv2D is not initialized !!!
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        skip_list = []
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            skip_list.append(x)
            x = layer(x)
        return x, skip_list

    def forward_features_up(self, x, skip_list):
        for inx, layer_up in enumerate(self.layers_up):
            if inx == 0:
                x = layer_up(x)
            else:
                x = layer_up(x+skip_list[-inx])

        return x

    def forward_final(self, x):
        x = self.final_up(x)
        x = x.permute(0,3,1,2)
        x = self.final_conv(x)
        return x

    def forward_backbone(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)
        return x

    def forward(self, x):
        x, skip_list = self.forward_features(x)
        x = self.forward_features_up(x, skip_list)
        x = self.forward_final(x)

        return x

--------------------------------------------------------------------------------
/results/Readme.txt:
--------------------------------------------------------------------------------
Result save location
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from loader import *

from models.H_vmunet import H_vmunet
from engine import *
import os
import sys
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # "0, 1, 2, 3"

from utils import *
from configs.config_setting import setting_config

import warnings
warnings.filterwarnings("ignore")



def main(config):

    print('#----------Creating logger----------#')
    sys.path.append(config.work_dir + '/')
    log_dir = os.path.join(config.work_dir, 'log')
    checkpoint_dir = os.path.join(config.work_dir, 'checkpoints')
    # Path to the trained weights to evaluate (e.g. a best-epoch*.pth file saved by train.py); must be set before running
    resume_model = os.path.join('')
    outputs = os.path.join(config.work_dir, 'outputs')
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    if not os.path.exists(outputs):
        os.makedirs(outputs)

    global logger
    logger = get_logger('test', log_dir)

    log_config_info(config, logger)

    print('#----------GPU init----------#')
    set_seed(config.seed)
    gpu_ids = [0]# [0, 1, 2, 3]
    torch.cuda.empty_cache()


    print('#----------Preparing Models----------#')
    model_cfg = config.model_config
    model = H_vmunet(num_classes=model_cfg['num_classes'],
                     input_channels=model_cfg['input_channels'],
                     c_list=model_cfg['c_list'],
                     split_att=model_cfg['split_att'],
                     bridge=model_cfg['bridge'],
                     drop_path_rate=model_cfg['drop_path_rate'])

    model = torch.nn.DataParallel(model.cuda(), device_ids=gpu_ids, output_device=gpu_ids[0])


    print('#----------Preparing dataset----------#')
    test_dataset = isic_loader(path_Data = config.data_path, train = False, Test = True)
    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=False,
                             pin_memory=True,
                             num_workers=config.num_workers,
                             drop_last=True)

    print('#----------Preparing loss, opt, sch and amp----------#')
    criterion = config.criterion
    optimizer = get_optimizer(config, model)
    scheduler = get_scheduler(config, optimizer)
    scaler = GradScaler()


    print('#----------Set other params----------#')
    min_loss = 999
    start_epoch = 1
    min_epoch = 1


    print('#----------Testing----------#')
    best_weight = torch.load(resume_model, map_location=torch.device('cpu'))
    model.module.load_state_dict(best_weight)
    loss = test_one_epoch(
            test_loader,
            model,
            criterion,
            logger,
            config,
        )



if __name__ == '__main__':
    config = setting_config
    main(config)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from loader import *

from models.H_vmunet import H_vmunet
from engine import *
import os
import sys
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # "0, 1, 2, 3"

from utils import *
from configs.config_setting import setting_config

import warnings
warnings.filterwarnings("ignore")

def main(config):

    print('#----------Creating logger----------#')
    sys.path.append(config.work_dir + '/')
    log_dir = os.path.join(config.work_dir, 'log')
    checkpoint_dir = os.path.join(config.work_dir, 'checkpoints')
    resume_model = os.path.join(checkpoint_dir, 'latest.pth')
    outputs = os.path.join(config.work_dir, 'outputs')
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    if not os.path.exists(outputs):
        os.makedirs(outputs)

    global logger
    logger = get_logger('train', log_dir)

    log_config_info(config, logger)


    print('#----------GPU init----------#')
    set_seed(config.seed)
    gpu_ids = [0]# [0, 1, 2, 3]
    torch.cuda.empty_cache()


    print('#----------Preparing dataset----------#')
    train_dataset = isic_loader(path_Data = config.data_path, train = True)
    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=True,
                              pin_memory=True,
                              num_workers=config.num_workers)
    val_dataset = isic_loader(path_Data = config.data_path, train = False)
    val_loader = DataLoader(val_dataset,
                            batch_size=1,
                            shuffle=False,
                            pin_memory=True,
                            num_workers=config.num_workers,
                            drop_last=True)
    test_dataset = isic_loader(path_Data = config.data_path, train = False, Test = True)
    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=False,
                             pin_memory=True,
                             num_workers=config.num_workers,
                             drop_last=True)


    print('#----------Preparing Models----------#')
    model_cfg = config.model_config
    model = H_vmunet(num_classes=model_cfg['num_classes'],
                     input_channels=model_cfg['input_channels'],
                     c_list=model_cfg['c_list'],
                     split_att=model_cfg['split_att'],
                     bridge=model_cfg['bridge'],
                     drop_path_rate=model_cfg['drop_path_rate'])

    model = torch.nn.DataParallel(model.cuda(), device_ids=gpu_ids, output_device=gpu_ids[0])


    print('#----------Preparing loss, opt, sch and amp----------#')
    criterion = config.criterion
    optimizer = get_optimizer(config, model)
    scheduler = get_scheduler(config, optimizer)
    scaler = GradScaler()


    print('#----------Set other params----------#')
    min_loss = 999
    start_epoch = 1
    min_epoch = 1


    if os.path.exists(resume_model):
        print('#----------Resume Model and Other params----------#')
        checkpoint = torch.load(resume_model, map_location=torch.device('cpu'))
        model.module.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        saved_epoch = checkpoint['epoch']
        start_epoch += saved_epoch
        min_loss, min_epoch, loss = checkpoint['min_loss'], checkpoint['min_epoch'], checkpoint['loss']

        log_info = f'resuming model from {resume_model}. resume_epoch: {saved_epoch}, min_loss: {min_loss:.4f}, min_epoch: {min_epoch}, loss: {loss:.4f}'
        logger.info(log_info)


    print('#----------Training----------#')
    for epoch in range(start_epoch, config.epochs + 1):

        torch.cuda.empty_cache()

        train_one_epoch(
            train_loader,
            model,
            criterion,
            optimizer,
            scheduler,
            epoch,
            logger,
            config,
            scaler=scaler
        )

        loss = val_one_epoch(
                val_loader,
                model,
                criterion,
                epoch,
                logger,
                config
            )


        if loss < min_loss:
            torch.save(model.module.state_dict(), os.path.join(checkpoint_dir, 'best.pth'))
            min_loss = loss
            min_epoch = epoch

        torch.save(
            {
                'epoch': epoch,
                'min_loss': min_loss,
                'min_epoch': min_epoch,
                'loss': loss,
                'model_state_dict': model.module.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
            }, os.path.join(checkpoint_dir, 'latest.pth'))

    if os.path.exists(os.path.join(checkpoint_dir, 'best.pth')):
        print('#----------Testing----------#')
        # Load from the same path checked above instead of concatenating work_dir, which breaks if work_dir lacks a trailing '/'
        best_weight = torch.load(os.path.join(checkpoint_dir, 'best.pth'), map_location=torch.device('cpu'))
        model.module.load_state_dict(best_weight)
        loss = test_one_epoch(
                test_loader,
                model,
                criterion,
                logger,
                config,
            )
        os.rename(
            os.path.join(checkpoint_dir, 'best.pth'),
            os.path.join(checkpoint_dir, f'best-epoch{min_epoch}-loss{min_loss:.4f}.pth')
        )


if __name__ == '__main__':
    config = setting_config
    main(config)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision.transforms.functional as TF
import numpy as np
import os
import math
import random
import logging
import logging.handlers
from matplotlib import pyplot as plt


def set_seed(seed):
    # for hash
    os.environ['PYTHONHASHSEED'] = str(seed)
    # for python and numpy
    random.seed(seed)
    np.random.seed(seed)
    # for cpu gpu
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # for cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


def get_logger(name, log_dir):
    '''
    Args:
        name(str): name of logger
        log_dir(str): path of log
    '''

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    info_name = os.path.join(log_dir, '{}.info.log'.format(name))
    info_handler = logging.handlers.TimedRotatingFileHandler(info_name,
                                                             when='D',
                                                             encoding='utf-8')
    info_handler.setLevel(logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(message)s',
                                  datefmt='%Y-%m-%d %H:%M:%S')

    info_handler.setFormatter(formatter)

    logger.addHandler(info_handler)

    return logger


def log_config_info(config, logger):
    config_dict = config.__dict__
    log_info = f'#----------Config info----------#'
    logger.info(log_info)
    for k, v in config_dict.items():
        if k[0] == '_':
            continue
        else:
            log_info = f'{k}: {v},'
            logger.info(log_info)



def get_optimizer(config, model):
    assert config.opt in ['Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'ASGD', 'RMSprop', 'Rprop', 'SGD'], 'Unsupported optimizer!'

    if config.opt == 'Adadelta':
        return torch.optim.Adadelta(
            model.parameters(),
            lr = config.lr,
            rho = config.rho,
            eps = config.eps,
            weight_decay = config.weight_decay
        )
    elif config.opt == 'Adagrad':
        return torch.optim.Adagrad(
            model.parameters(),
            lr = config.lr,
            lr_decay = config.lr_decay,
            eps = config.eps,
            weight_decay = config.weight_decay
        )
    elif config.opt == 'Adam':
        return torch.optim.Adam(
            model.parameters(),
            lr = config.lr,
            betas = config.betas,
            eps = config.eps,
            weight_decay = config.weight_decay,
            amsgrad = config.amsgrad
        )
    elif config.opt == 'AdamW':
        return torch.optim.AdamW(
            model.parameters(),
            lr = config.lr,
            betas = config.betas,
            eps = config.eps,
            weight_decay = config.weight_decay,
            amsgrad = config.amsgrad
        )
    elif config.opt == 'Adamax':
        return torch.optim.Adamax(
            model.parameters(),
            lr = config.lr,
            betas = config.betas,
            eps = config.eps,
            weight_decay = config.weight_decay
        )
    elif config.opt == 'ASGD':
        return torch.optim.ASGD(
            model.parameters(),
            lr = config.lr,
            lambd = config.lambd,
            alpha = config.alpha,
            t0 = config.t0,
            weight_decay = config.weight_decay
        )
    elif config.opt == 'RMSprop':
        return torch.optim.RMSprop(
            model.parameters(),
            lr = config.lr,
            momentum = config.momentum,
            alpha = config.alpha,
            eps = config.eps,
            centered = config.centered,
            weight_decay = config.weight_decay
        )
    elif config.opt == 'Rprop':
        return torch.optim.Rprop(
            model.parameters(),
            lr = config.lr,
            etas = config.etas,
            step_sizes = config.step_sizes,
        )
    elif config.opt == 'SGD':
        return torch.optim.SGD(
            model.parameters(),
            lr = config.lr,
            momentum = config.momentum,
            weight_decay = config.weight_decay,
            dampening = config.dampening,
            nesterov = config.nesterov
        )
    else: # default opt is SGD
        return torch.optim.SGD(
            model.parameters(),
            lr = 0.01,
            momentum = 0.9,
            weight_decay = 0.05,
        )



def get_scheduler(config, optimizer):
    assert config.sch in ['StepLR', 'MultiStepLR', 'ExponentialLR', 'CosineAnnealingLR', 'ReduceLROnPlateau',
                          'CosineAnnealingWarmRestarts', 'WP_MultiStepLR', 'WP_CosineLR'], 'Unsupported scheduler!'
    if config.sch == 'StepLR':
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size = config.step_size,
            gamma = config.gamma,
            last_epoch = config.last_epoch
        )
    elif config.sch == 'MultiStepLR':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones = config.milestones,
            gamma = config.gamma,
            last_epoch = config.last_epoch
        )
    elif config.sch == 'ExponentialLR':
        scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer,
            gamma = config.gamma,
            last_epoch = config.last_epoch
        )
    elif config.sch == 'CosineAnnealingLR':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max = config.T_max,
            eta_min = config.eta_min,
            last_epoch = config.last_epoch
        )
    elif config.sch == 'ReduceLROnPlateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode = config.mode,
            factor = config.factor,
            patience = config.patience,
            threshold = config.threshold,
            threshold_mode = config.threshold_mode,
            cooldown = config.cooldown,
            min_lr = config.min_lr,
            eps = config.eps
        )
    elif config.sch == 'CosineAnnealingWarmRestarts':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0 = config.T_0,
            T_mult = config.T_mult,
            eta_min = config.eta_min,
            last_epoch = config.last_epoch
        )
    elif config.sch == 'WP_MultiStepLR':
        lr_func = lambda epoch: epoch / config.warm_up_epochs if epoch <= config.warm_up_epochs else config.gamma**len(
            [m for m in config.milestones if m <= epoch])
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func)
    elif config.sch == 'WP_CosineLR':
        lr_func = lambda epoch: epoch / config.warm_up_epochs if epoch <= config.warm_up_epochs else 0.5 * (
            math.cos((epoch - config.warm_up_epochs) / (config.epochs - config.warm_up_epochs) * math.pi) + 1)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func)

    return scheduler



def save_imgs(img, msk, msk_pred, i, save_path, datasets, threshold=0.5, test_data_name=None):
    img = img.squeeze(0).permute(1,2,0).detach().cpu().numpy()
    img = img / 255. if img.max() > 1.1 else img
    if datasets == 'retinal':
        msk = np.squeeze(msk, axis=0)
        msk_pred = np.squeeze(msk_pred, axis=0)
    else:
        msk = np.where(np.squeeze(msk, axis=0) > 0.5, 1, 0)
        msk_pred = np.where(np.squeeze(msk_pred, axis=0) > threshold, 1, 0)

    plt.figure(figsize=(7,15))

    plt.subplot(3,1,1)
    plt.imshow(img)
    plt.axis('off')

    plt.subplot(3,1,2)
    plt.imshow(msk, cmap= 'gray')
    plt.axis('off')

    plt.subplot(3,1,3)
    plt.imshow(msk_pred, cmap = 'gray')
    plt.axis('off')

    if test_data_name is not None:
        save_path = save_path + test_data_name + '_'
    plt.savefig(save_path + str(i) +'.png')
    plt.close()



class BCELoss(nn.Module):
    def __init__(self):
        super(BCELoss, self).__init__()
        self.bceloss = nn.BCELoss()

    def forward(self, pred, target):
        size = pred.size(0)
        pred_ = pred.view(size, -1)
        target_ = target.view(size, -1)

        return self.bceloss(pred_, target_)


class DiceLoss(nn.Module):
    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, pred, target):
        smooth = 1
        size = pred.size(0)

        pred_ = pred.view(size, -1)
        target_ = target.view(size, -1)
        intersection = pred_ * target_
        dice_score = (2 * intersection.sum(1) + smooth)/(pred_.sum(1) + target_.sum(1) + smooth)
        dice_loss = 1 - dice_score.sum()/size

        return dice_loss


class BceDiceLoss(nn.Module):
    def __init__(self, wb=1, wd=1):
        super(BceDiceLoss, self).__init__()
        self.bce = BCELoss()
        self.dice = DiceLoss()
        self.wb = wb
        self.wd = wd

    def forward(self, pred, target):
        bceloss = self.bce(pred, target)
        diceloss = self.dice(pred, target)

        loss = self.wd * diceloss + self.wb * bceloss
        return loss

--------------------------------------------------------------------------------
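The `WP_CosineLR` option in `get_scheduler` (utils.py) wraps a `LambdaLR` multiplier that ramps linearly from 0 to 1 over `config.warm_up_epochs` and then follows a half-cosine down to 0 at `config.epochs`. As a minimal sketch for inspecting that curve without a model or optimizer, the same formula can be pulled out into a plain function. The helper name `wp_cosine_factor` is hypothetical, and the numbers used below (`warm_up_epochs=10`, `epochs=300`, base LR `1e-3`) are illustrative assumptions, not values taken from configs/config_setting.py.

```python
import math


def wp_cosine_factor(epoch: int, warm_up_epochs: int, epochs: int) -> float:
    """Hypothetical helper reproducing the 'WP_CosineLR' lambda from get_scheduler.

    Returns the multiplier LambdaLR applies to the base learning rate:
    linear warm-up from 0 to 1, then half-cosine decay from 1 to 0.
    """
    if epoch <= warm_up_epochs:
        # Warm-up phase: factor grows linearly with the epoch counter.
        return epoch / warm_up_epochs
    # Decay phase: progress runs from 0 (end of warm-up) to 1 (final epoch).
    progress = (epoch - warm_up_epochs) / (epochs - warm_up_epochs)
    return 0.5 * (math.cos(progress * math.pi) + 1)


if __name__ == "__main__":
    # Illustrative settings only (assumed, not read from the repository config).
    base_lr, warm_up_epochs, epochs = 1e-3, 10, 300
    for epoch in (0, 5, 10, 155, 300):
        factor = wp_cosine_factor(epoch, warm_up_epochs, epochs)
        print(f"epoch {epoch:3d}: factor={factor:.3f}, lr={base_lr * factor:.2e}")
```

One consequence worth noting: `LambdaLR` evaluates the lambda at counter 0 when it is constructed, so the learning rate starts at 0. Whether that leaves the first epoch entirely at a zero learning rate depends on where `scheduler.step()` is called in engine.py, which is not shown here.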