├── LICENSE
├── README.md
├── configs
│   └── config_setting.py
├── dataprepare
│   ├── Prepare_ISIC2017.py
│   ├── Prepare_ISIC2018.py
│   ├── Prepare_PH2.py
│   └── Prepare_your_dataset.py
├── engine.py
├── loader.py
├── models
│   └── UltraLight_VM_UNet.py
├── results
│   └── Readme.txt
├── test.py
├── train.py
└── utils.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2025 Renkai Wu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# UltraLight VM-UNet

### Parallel Vision Mamba Significantly Reduces Parameters for Skin Lesion Segmentation

Renkai Wu¹, Yinghao Liu², Pengchen Liang¹\*, Qing Chang¹\*

¹ Shanghai University, ² University of Shanghai for Science and Technology

ArXiv Preprint ([arXiv:2403.20035](https://arxiv.org/abs/2403.20035))

## 🔥🔥Highlights🔥🔥
### *1. The UltraLight VM-UNet has only 0.049M parameters, 0.060 GFLOPs, and a model weight file of only 229.1 KB.*
### *2. Parallel Vision Mamba (or Mamba) is a winner for lightweight models.*
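
The parameter savings reported for the PVM Layer settings can be approximated with simple arithmetic. The sketch below assumes that a single Mamba block is shared by all channel groups (an assumption about how the PVM Layer is wired; check `models/UltraLight_VM_UNet.py` before relying on it) and counts only the weights of an idealized input projection (C to 2·expand·C) and output projection (expand·C to C), with `expand = 2` chosen purely for illustration. Convolution, SSM-specific and non-Mamba parameters are ignored, so it shows the trend rather than the exact Params figures in the settings table.

```python
# Idealized weight count for the linear projections of one Mamba-style block.
# Assumptions (illustrative only): input projection C -> 2*E and output
# projection E -> C, with E = expand * C; biases, conv1d and SSM parameters
# are ignored.

def projection_params(channels: int, expand: int = 2) -> int:
    inner = expand * channels
    in_proj = channels * (2 * inner)  # weights of the input projection
    out_proj = inner * channels       # weights of the output projection
    return in_proj + out_proj


def pvm_projection_params(channels: int, groups: int, expand: int = 2) -> int:
    """One block shared by every group, so only a (channels // groups)-wide block is counted."""
    if channels % groups != 0:
        raise ValueError("channels must be divisible by groups")
    return projection_params(channels // groups, expand)


if __name__ == "__main__":
    C = 64  # widest stage in the default c_list of configs/config_setting.py
    for groups in (1, 2, 4):
        print(groups, pvm_projection_params(C, groups))
```

Under these assumptions the projection weights shrink roughly with the square of the group count (16x for four groups at C = 64). The sharing is what produces the saving: four independent C/4-wide blocks would need 4 × 1536 = 6144 weights here, the same as the two-group case.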

## News🚀
(2024.04.24) ***The third version of our paper has been uploaded to [arXiv](https://arxiv.org/abs/2403.20035), adding richer experimental validation. The additions include, but are not limited to:***
- Key parameter analysis of Mamba variants.
- Experiments on the parallel connection of multiple Mamba variants.
- An exploration of the plug-and-play PVM Layer.
- More ablation experiments for analysis.

(2024.04.09) ***The second version of our paper has been uploaded to arXiv with adjustments to the description in the methods section.***

(2024.04.04) ***Added a preprocessing step for private datasets.***

(2024.04.01) ***The project code has been uploaded.***

(2024.03.29) ***The first edition of our paper has been uploaded to arXiv.*** 📃

### Abstract
Traditionally, most approaches improve a model's segmentation performance by adding more complex modules. This is not suitable for the medical field, especially for mobile medical devices, where computationally heavy models are impractical in real clinical environments because of limited computational resources. Recently, state-space models (SSMs), represented by Mamba, have become strong competitors to traditional CNNs and Transformers. In this paper, we investigate in depth the key factors that influence the parameter count of Mamba and, based on this analysis, propose an UltraLight Vision Mamba UNet (UltraLight VM-UNet). Specifically, we propose a method for processing features with parallel Vision Mamba, named the PVM Layer, which achieves excellent performance with the lowest computational load while keeping the overall number of processing channels constant. We conducted comparison and ablation experiments against several state-of-the-art lightweight models on three public skin lesion datasets and demonstrated that UltraLight VM-UNet remains highly competitive with only 0.049M parameters and 0.060 GFLOPs. In addition, this study's in-depth analysis of the factors that drive Mamba's parameter count lays a theoretical foundation for Mamba to possibly become a new mainstream module for lightweight models in the future.

### Different Parallel Vision Mamba (PVM Layer) settings:
| Setting | Description | Params | GFLOPs | DSC |
| --- | --- | --- | --- | --- |
| 1 | No parallelism (channel number ```C```) | 0.136M | 0.060 | 0.9069 |
| 2 | Double parallel (channel number ```(C/2)+(C/2)```) | 0.070M | 0.060 | 0.9073 |
| 3 | Quadruple parallel (channel number ```(C/4)+(C/4)+(C/4)+(C/4)```) | 0.049M | 0.060 | 0.9091 |

**0. Main Environments.**
You can follow the environment setup of [VM-UNet](https://github.com/JCruan519/VM-UNet), or use the steps below (python=3.8):
```
conda create -n vmunet python=3.8
conda activate vmunet
pip install torch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117
pip install packaging
pip install timm==0.4.12
pip install pytest chardet yacs termcolor
pip install submitit tensorboardX
pip install triton==2.0.0
pip install causal_conv1d==1.0.0  # causal_conv1d-1.0.0+cu118torch1.13cxx11abiFALSE-cp38-cp38-linux_x86_64.whl
pip install mamba_ssm==1.0.1  # mamba_ssm-1.0.1+cu118torch1.13cxx11abiFALSE-cp38-cp38-linux_x86_64.whl
pip install scikit-learn matplotlib thop h5py SimpleITK scikit-image medpy yacs
```

**1. Datasets.**
Data preprocessing environment installation (python=3.7):
```
conda create -n tool python=3.7
conda activate tool
pip install h5py
conda install scipy==1.2.1  # scipy 1.2.1 only supports Python 3.7 and below.
pip install pillow
```

*A. ISIC2017*
1. Download the ISIC 2017 training dataset from [this](https://challenge.isic-archive.com/data) link and extract both the training image folder and the ground-truth folder into `/data/dataset_isic17/`.
2. Run `Prepare_ISIC2017.py` to prepare the data and split it into training, validation and test sets.
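
The `Prepare_*.py` scripts split the data contiguously, in whatever order `glob.glob` returns the files: for ISIC 2017, the first 1250 images go to training, the next 150 to validation and the remaining 600 to testing. If you prefer an explicitly shuffled but reproducible split, one option is to shuffle the indices with a fixed seed before slicing. `split_indices` below is a hypothetical helper, not part of this repository.

```python
import random
from typing import List, Optional, Tuple


def split_indices(total: int, n_train: int, n_val: int, n_test: int,
                  seed: Optional[int] = None) -> Tuple[List[int], List[int], List[int]]:
    """Return train/val/test index lists that partition range(total).

    seed=None keeps the contiguous order used by the Prepare_*.py scripts;
    an integer seed shuffles the indices first.
    """
    if n_train + n_val + n_test != total:
        raise ValueError("split sizes must add up to the number of samples")
    order = list(range(total))
    if seed is not None:
        random.Random(seed).shuffle(order)  # private RNG: reproducible, leaves global state alone
    train = order[:n_train]
    val = order[n_train:n_train + n_val]
    test = order[n_train + n_val:]
    return train, val, test


# ISIC 2017 sizes used by Prepare_ISIC2017.py
tr_idx, va_idx, te_idx = split_indices(2000, 1250, 150, 600)
print(len(tr_idx), len(va_idx), len(te_idx), va_idx[0], te_idx[0])  # 1250 150 600 1250 1400
```

Because the scripts hold all samples in NumPy arrays, the returned lists could index them directly, for example `Data_train_2017[tr_idx]`.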

*B. ISIC2018*
1. Download the ISIC 2018 training dataset from [this](https://challenge.isic-archive.com/data) link and extract both the training image folder and the ground-truth folder into `/data/dataset_isic18/`.
2. Run `Prepare_ISIC2018.py` to prepare the data and split it into training, validation and test sets.

*C. PH2*
1. Download the PH2 dataset from [Dropbox](https://www.dropbox.com/s/k88qukc20ljnbuo/PH2Dataset.rar) or [Google Drive](https://drive.google.com/file/d/1AEMJKAiORlrwdDi37dRqbqXi6zLmnU3Q/view?usp=sharing) and extract both the image folder and the ground-truth folder into `/data/PH2/`.
2. Run `Prepare_PH2.py` to preprocess the data and build the test set for external validation.
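
PH2 ships its images and masks as `.bmp` files, while `Prepare_PH2.py` globs `./PH2/images/*.jpg` and reads masks from `./PH2/masks/` with a `.png` suffix; the script's header note asks you to batch-rename the suffixes first. Below is a minimal standard-library sketch of that step. It assumes the images and masks have already been gathered into those two folders, and `change_suffix` is a hypothetical helper. Only the extension changes, which is enough for `scipy.misc.imread` because it reads through PIL, and PIL identifies the format from the file contents. Note also that the script derives each mask name from the 12 characters before `.jpg`, so file names of a different length need that slice adjusted.

```python
from pathlib import Path
from typing import List


def change_suffix(folder: str, old: str = ".bmp", new: str = ".png") -> List[Path]:
    """Rename every file ending in `old` directly inside `folder` to end in `new`.

    Only the extension changes; the file contents are untouched.
    Returns the new paths in sorted order.
    """
    renamed = []
    # sorted() materializes the listing before any file is renamed
    for path in sorted(Path(folder).glob("*" + old)):
        target = path.with_suffix(new)
        if target.exists():
            raise FileExistsError(f"refusing to overwrite {target}")
        path.rename(target)
        renamed.append(target)
    return renamed


# Hypothetical layout matching the paths Prepare_PH2.py reads:
# change_suffix("./PH2/images", new=".jpg")
# change_suffix("./PH2/masks", new=".png")
```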

*D. Prepare your own dataset*
1. Organize the files as follows. Images are 24-bit PNGs; masks are 8-bit PNGs with pixel value 0 for the background and 255 for the target.
   - './your_dataset/'
     - images
       - 0000.png
       - 0001.png
     - masks
       - 0000.png
       - 0001.png
     - Prepare_your_dataset.py
2. In `Prepare_your_dataset.py`, set the number of images for the training, validation and test sets.
3. Run `Prepare_your_dataset.py`.
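
`Prepare_your_dataset.py` finds each mask by taking the four characters before `.png` in the image path, so it assumes four-character names such as `0000.png`, and a missing mask only surfaces as a read error partway through the run. Before running it, you could check that every image has a same-named mask. The function below is a hypothetical standard-library helper (`pair_images_with_masks` is not part of the repository) that pairs files by their stem instead of a fixed-width slice.

```python
from pathlib import Path
from typing import List, Tuple


def pair_images_with_masks(root: str) -> Tuple[List[Tuple[Path, Path]], List[str]]:
    """Pair root/images/<stem>.png with root/masks/<stem>.png.

    Returns (pairs, missing), where `missing` lists image stems that have no mask.
    """
    images = sorted(Path(root, "images").glob("*.png"))
    masks = {p.stem: p for p in Path(root, "masks").glob("*.png")}
    pairs, missing = [], []
    for img in images:
        mask = masks.get(img.stem)
        if mask is None:
            missing.append(img.stem)
        else:
            pairs.append((img, mask))
    return pairs, missing


# Example (hypothetical path):
# pairs, missing = pair_images_with_masks("./your_dataset")
# print(f"{len(pairs)} pairs, missing masks: {missing}")
```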

**2. Train the UltraLight VM-UNet.**
You can start training with the command below. Alternatively, download the weights file as described in this [issue](https://github.com/wurenkai/UltraLight-VM-UNet/issues/38) before training.
```
python train.py
```
- After training, the outputs are saved in './results/'.
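
The learning-rate schedule is set in `configs/config_setting.py`. The defaults are AdamW with `lr = 0.001`, `sch = 'CosineAnnealingLR'`, `T_max = 50`, `eta_min = 0.00001` and `epochs = 250`, and `engine.train_one_epoch` calls `scheduler.step()` once per epoch. Assuming these values are passed unchanged to PyTorch's `CosineAnnealingLR` (the scheduler construction is not shown in this section), the learning rate follows the closed form below, which has a period of `2 * T_max`: it reaches `eta_min` at epoch 50, climbs back to `lr` at epoch 100, and keeps oscillating over the 250 epochs rather than decaying once. This pure-Python sketch evaluates the formula from the PyTorch documentation; PyTorch computes it recursively, which is mathematically equivalent but may differ slightly in floating point.

```python
import math


def cosine_annealing_lr(epoch: int, base_lr: float = 1e-3,
                        eta_min: float = 1e-5, t_max: int = 50) -> float:
    """Closed form: eta_min + (base_lr - eta_min) * (1 + cos(pi * epoch / t_max)) / 2."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


if __name__ == "__main__":
    # Defaults mirror configs/config_setting.py (AdamW lr, T_max, eta_min); epochs = 250
    for epoch in (0, 25, 50, 75, 100, 150, 249):
        print(epoch, f"{cosine_annealing_lr(epoch):.6f}")
```

If a single decay over the whole run is what you want, setting `T_max = epochs` gives that; it changes the training recipe, so results may differ from those reported in the paper.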

**3. Test the UltraLight VM-UNet.**
First, set the checkpoint path in `resume_model` in `test.py`, then run:
```
python test.py
```
- After testing, the outputs are saved in './results/'.
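
`engine.py` scores a run by thresholding every predicted pixel at `config.threshold` (0.5 by default), binarizing the ground truth at 0.5, and building a single pixel-level confusion matrix over the whole set with `sklearn.metrics.confusion_matrix`; every reported metric follows from its four counts. The sketch below restates those formulas in plain Python so they can be checked by hand. `confusion_counts` and `segmentation_metrics` are hypothetical names, and the fallback to 0 for a zero denominator mirrors `engine.py`.

```python
from typing import Dict, Sequence, Tuple


def confusion_counts(y_true: Sequence[int], y_pred: Sequence[int]) -> Tuple[int, int, int, int]:
    """Pixel-level (TN, FP, FN, TP) for binary labels (0 = background, 1 = lesion)."""
    tn = fp = fn = tp = 0
    for t, p in zip(y_true, y_pred):
        if t and p:
            tp += 1
        elif t:
            fn += 1
        elif p:
            fp += 1
        else:
            tn += 1
    return tn, fp, fn, tp


def _ratio(num: float, den: float) -> float:
    return num / den if den != 0 else 0.0  # same zero guard as engine.py


def segmentation_metrics(tn: int, fp: int, fn: int, tp: int) -> Dict[str, float]:
    return {
        "accuracy": _ratio(tn + tp, tn + fp + fn + tp),
        "sensitivity": _ratio(tp, tp + fn),
        "specificity": _ratio(tn, tn + fp),
        "f1_or_dsc": _ratio(2 * tp, 2 * tp + fp + fn),
        "miou": _ratio(tp, tp + fp + fn),  # IoU of the lesion class only
    }


if __name__ == "__main__":
    probs = [0.9, 0.7, 0.2, 0.6, 0.1, 0.4]
    truth = [1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
    y_pred = [1 if p >= 0.5 else 0 for p in probs]  # config.threshold
    y_true = [1 if g >= 0.5 else 0 for g in truth]
    print(segmentation_metrics(*confusion_counts(y_true, y_pred)))
```

Note that the value logged as `miou` is the foreground IoU, TP / (TP + FP + FN), not an average over the background and lesion classes.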

**4. Additional information.**
- The PVM Layer can easily be embedded into other models to reduce their overall parameter count. See [issue 7](https://github.com/wurenkai/UltraLight-VM-UNet/issues/7) for how to calculate the model's parameters and GFLOPs. Beyond those steps, an exact GFLOPs figure also requires adding the SSM's own operations, owing to the specific nature of the SSM; see [here](https://github.com/state-spaces/mamba/issues/110#issuecomment-1919470069) for details. However, because UltraLight VM-UNet has very few channels, adding all the SSM operations has almost no effect on the GFLOPs obtained with the steps above (to three significant digits).

## Citation
If you find this repository helpful, please consider citing:
```
@article{wu2024ultralight,
  title={UltraLight VM-UNet: Parallel Vision Mamba Significantly Reduces Parameters for Skin Lesion Segmentation},
  author={Wu, Renkai and Liu, Yinghao and Liang, Pengchen and Chang, Qing},
  journal={arXiv preprint arXiv:2403.20035},
  year={2024}
}
```

## Acknowledgement
Thanks to [Vim](https://github.com/hustvl/Vim), [VMamba](https://github.com/MzeroMiko/VMamba), [VM-UNet](https://github.com/JCruan519/VM-UNet) and [LightM-UNet](https://github.com/MrBlankness/LightM-UNet) for their outstanding work.

--------------------------------------------------------------------------------
/configs/config_setting.py:
--------------------------------------------------------------------------------
from torchvision import transforms
from utils import *

from datetime import datetime

class setting_config:
    """
    the config of training setting.
    """
    network = 'UltraLight_VM_UNet'
    model_config = {
        'num_classes': 1,
        'input_channels': 3,
        'c_list': [8,16,24,32,48,64],
        'split_att': 'fc',
        'bridge': True,
    }

    test_weights = ''

    datasets = 'ISIC2017'
    # data_path: folder holding the .npy files produced in dataprepare/, with a trailing '/'
    # (loader.py appends file names such as 'data_train.npy' directly to it)
    if datasets == 'ISIC2017':
        data_path = ''
    elif datasets == 'ISIC2018':
        data_path = ''
    elif datasets == 'PH2':
        data_path = ''
    else:
        raise Exception('datasets is not right!')

    criterion = BceDiceLoss()

    num_classes = 1
    input_size_h = 256
    input_size_w = 256
    input_channels = 3
    distributed = False
    local_rank = -1
    num_workers = 0
    seed = 42
    world_size = None
    rank = None
    amp = False
    batch_size = 8
    epochs = 250

    work_dir = 'results/' + network + '_' + datasets + '_' + datetime.now().strftime('%A_%d_%B_%Y_%Hh_%Mm_%Ss') + '/'

    print_interval = 20
    val_interval = 30
    save_interval = 100
    threshold = 0.5


    opt = 'AdamW'
    assert opt in ['Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'ASGD', 'RMSprop', 'Rprop', 'SGD'], 'Unsupported optimizer!'
    if opt == 'Adadelta':
        lr = 0.01 # default: 1.0 – coefficient that scales delta before it is applied to the parameters
        rho = 0.9 # default: 0.9 – coefficient used for computing a running average of squared gradients
        eps = 1e-6 # default: 1e-6 – term added to the denominator to improve numerical stability
        weight_decay = 0.05 # default: 0 – weight decay (L2 penalty)
    elif opt == 'Adagrad':
        lr = 0.01 # default: 0.01 – learning rate
        lr_decay = 0 # default: 0 – learning rate decay
        eps = 1e-10 # default: 1e-10 – term added to the denominator to improve numerical stability
        weight_decay = 0.05 # default: 0 – weight decay (L2 penalty)
    elif opt == 'Adam':
        lr = 0.001 # default: 1e-3 – learning rate
        betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square
        eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability
        weight_decay = 0.0001 # default: 0 – weight decay (L2 penalty)
        amsgrad = False # default: False – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond
    elif opt == 'AdamW':
        lr = 0.001 # default: 1e-3 – learning rate
        betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square
        eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability
        weight_decay = 1e-2 # default: 1e-2 – weight decay coefficient
        amsgrad = False # default: False – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond
    elif opt == 'Adamax':
        lr = 2e-3 # default: 2e-3 – learning rate
        betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square
        eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability
        weight_decay = 0 # default: 0 – weight decay (L2 penalty)
    elif opt == 'ASGD':
        lr = 0.01 # default: 1e-2 – learning rate
        lambd = 1e-4 # default: 1e-4 – decay term
        alpha = 0.75 # default: 0.75 – power for eta update
        t0 = 1e6 # default: 1e6 – point at which to start averaging
        weight_decay = 0 # default: 0 – weight decay
    elif opt == 'RMSprop':
        lr = 1e-2 # default: 1e-2 – learning rate
        momentum = 0 # default: 0 – momentum factor
        alpha = 0.99 # default: 0.99 – smoothing constant
        eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability
        centered = False # default: False – if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance
        weight_decay = 0 # default: 0 – weight decay (L2 penalty)
    elif opt == 'Rprop':
        lr = 1e-2 # default: 1e-2 – learning rate
        etas = (0.5, 1.2) # default: (0.5, 1.2) – pair of (etaminus, etaplus), that are multiplicative increase and decrease factors
        step_sizes = (1e-6, 50) # default: (1e-6, 50) – a pair of minimal and maximal allowed step sizes
    elif opt == 'SGD':
        lr = 0.01 # – learning rate
        momentum = 0.9 # default: 0 – momentum factor
        weight_decay = 0.05 # default: 0 – weight decay (L2 penalty)
        dampening = 0 # default: 0 – dampening for momentum
        nesterov = False # default: False – enables Nesterov momentum

    sch = 'CosineAnnealingLR'
    if sch == 'StepLR':
        step_size = epochs // 5 # – Period of learning rate decay.
        gamma = 0.5 # – Multiplicative factor of learning rate decay. Default: 0.1
        last_epoch = -1 # – The index of last epoch. Default: -1.
    elif sch == 'MultiStepLR':
        milestones = [60, 120, 150] # – List of epoch indices. Must be increasing.
        gamma = 0.1 # – Multiplicative factor of learning rate decay. Default: 0.1.
        last_epoch = -1 # – The index of last epoch. Default: -1.
    elif sch == 'ExponentialLR':
        gamma = 0.99 # – Multiplicative factor of learning rate decay.
        last_epoch = -1 # – The index of last epoch. Default: -1.
    elif sch == 'CosineAnnealingLR':
        T_max = 50 # – Maximum number of iterations (epochs here, since the scheduler steps once per epoch); half of the cosine period.
        eta_min = 0.00001 # – Minimum learning rate. Default: 0.
        last_epoch = -1 # – The index of last epoch. Default: -1.
    elif sch == 'ReduceLROnPlateau':
        mode = 'min' # – One of min, max. In min mode, lr will be reduced when the quantity monitored has stopped decreasing; in max mode it will be reduced when the quantity monitored has stopped increasing. Default: ‘min’.
        factor = 0.1 # – Factor by which the learning rate will be reduced. new_lr = lr * factor. Default: 0.1.
        patience = 10 # – Number of epochs with no improvement after which learning rate will be reduced. For example, if patience = 2, then we will ignore the first 2 epochs with no improvement, and will only decrease the LR after the 3rd epoch if the loss still hasn’t improved then. Default: 10.
        threshold = 0.0001 # – Threshold for measuring the new optimum, to only focus on significant changes. Default: 1e-4. NOTE: this overrides the segmentation threshold (0.5) defined above.
        threshold_mode = 'rel' # – One of rel, abs. In rel mode, dynamic_threshold = best * ( 1 + threshold ) in ‘max’ mode or best * ( 1 - threshold ) in min mode. In abs mode, dynamic_threshold = best + threshold in max mode or best - threshold in min mode. Default: ‘rel’.
        cooldown = 0 # – Number of epochs to wait before resuming normal operation after lr has been reduced. Default: 0.
        min_lr = 0 # – A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively. Default: 0.
        eps = 1e-08 # – Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is ignored. Default: 1e-8.
    elif sch == 'CosineAnnealingWarmRestarts':
        T_0 = 50 # – Number of iterations for the first restart.
        T_mult = 2 # – A factor increases T_{i} after a restart. Default: 1.
        eta_min = 1e-6 # – Minimum learning rate. Default: 0.
        last_epoch = -1 # – The index of last epoch. Default: -1.
    elif sch == 'WP_MultiStepLR':
        warm_up_epochs = 10
        gamma = 0.1
        milestones = [125, 225]
    elif sch == 'WP_CosineLR':
        warm_up_epochs = 20

--------------------------------------------------------------------------------
/dataprepare/Prepare_ISIC2017.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Code created on Sat Jun 8 18:15:43 2019
@author: Reza Azad
"""

"""
Reminder added on December 6, 2023.
Reminder Created on Wed Dec 6 2023
@author: Renkai Wu
1. The scipy package needs to be downgraded to scipy==1.2.1; otherwise, the code below must be modified.
2. The names of the files to be processed are printed. If none appear, the output .npy files are incorrect.
"""

import h5py
import numpy as np
import scipy.io as sio
import scipy.misc as sc
import glob

# Parameters
height = 256
width = 256
channels = 3

############################################################# Prepare ISIC 2017 data set #################################################
Dataset_add = './ISIC2017/'
Tr_add = 'ISIC2017_Task1-2_Training_Input'

Tr_list = glob.glob(Dataset_add+ Tr_add+'/*.jpg')
# It contains 2000 training samples
Data_train_2017 = np.zeros([2000, height, width, channels])
Label_train_2017 = np.zeros([2000, height, width])

print('Reading ISIC 2017')
print(Tr_list)
for idx in range(len(Tr_list)):
    print(idx+1)
    img = sc.imread(Tr_list[idx])
    img = np.double(sc.imresize(img, [height, width, channels], interp='bilinear', mode = 'RGB'))
    Data_train_2017[idx, :,:,:] = img

    b = Tr_list[idx]
    a = b[0:len(Dataset_add)]
    b = b[len(b)-16: len(b)-4]  # the 12 characters before '.jpg', e.g. 'ISIC_0000000'
    add = (a+ 'ISIC2017_Task1_Training_GroundTruth/' + b +'_segmentation.png')
    img2 = sc.imread(add)
    img2 = np.double(sc.imresize(img2, [height, width], interp='bilinear'))
    Label_train_2017[idx, :,:] = img2

print('Reading ISIC 2017 finished')

################################################################ Make the train and test sets ########################################
# We consider 1250 samples for training, 150 samples for validation and 600 samples for testing

Train_img = Data_train_2017[0:1250,:,:,:]
Validation_img = Data_train_2017[1250:1250+150,:,:,:]
Test_img = Data_train_2017[1250+150:2000,:,:,:]

Train_mask = Label_train_2017[0:1250,:,:]
Validation_mask = Label_train_2017[1250:1250+150,:,:]
Test_mask = Label_train_2017[1250+150:2000,:,:]


np.save('data_train', Train_img)
np.save('data_test' , Test_img)
np.save('data_val' , Validation_img)

np.save('mask_train', Train_mask)
np.save('mask_test' , Test_mask)
np.save('mask_val' , Validation_mask)


--------------------------------------------------------------------------------
/dataprepare/Prepare_ISIC2018.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Code created on Sat Jun 8 18:15:43 2019
@author: Reza Azad
"""

"""
Reminder added on December 6, 2023.
Reminder Created on Wed Dec 6 2023
@author: Renkai Wu
1. The scipy package needs to be downgraded to scipy==1.2.1; otherwise, the code below must be modified.
2. The names of the files to be processed are printed. If none appear, the output .npy files are incorrect.
"""

import h5py
import numpy as np
import scipy.io as sio
import scipy.misc as sc
import glob

# Parameters
height = 256
width = 256
channels = 3

############################################################# Prepare ISIC 2018 data set #################################################
Dataset_add = './ISIC2018/'
Tr_add = 'ISIC2018_Task1-2_Training_Input'

Tr_list = glob.glob(Dataset_add+ Tr_add+'/*.jpg')
# It contains 2594 training samples
Data_train_2017 = np.zeros([2594, height, width, channels])
Label_train_2017 = np.zeros([2594, height, width])

print('Reading ISIC 2018')
print(Tr_list)
for idx in range(len(Tr_list)):
    print(idx+1)
    img = sc.imread(Tr_list[idx])
    img = np.double(sc.imresize(img, [height, width, channels], interp='bilinear', mode = 'RGB'))
    Data_train_2017[idx, :,:,:] = img

    b = Tr_list[idx]
    a = b[0:len(Dataset_add)]
    b = b[len(b)-16: len(b)-4]  # the 12 characters before '.jpg', e.g. 'ISIC_0000000'
    add = (a+ 'ISIC2018_Task1_Training_GroundTruth/' + b +'_segmentation.png')
    img2 = sc.imread(add)
    img2 = np.double(sc.imresize(img2, [height, width], interp='bilinear'))
    Label_train_2017[idx, :,:] = img2

print('Reading ISIC 2018 finished')

################################################################ Make the train and test sets ########################################
# We consider 1815 samples for training, 259 samples for validation and 520 samples for testing

Train_img = Data_train_2017[0:1815,:,:,:]
Validation_img = Data_train_2017[1815:1815+259,:,:,:]
Test_img = Data_train_2017[1815+259:2594,:,:,:]

Train_mask = Label_train_2017[0:1815,:,:]
Validation_mask = Label_train_2017[1815:1815+259,:,:]
Test_mask = Label_train_2017[1815+259:2594,:,:]


np.save('data_train', Train_img)
np.save('data_test' , Test_img)
np.save('data_val' , Validation_img)

np.save('mask_train', Train_mask)
np.save('mask_test' , Test_mask)
np.save('mask_val' , Validation_mask)


--------------------------------------------------------------------------------
/dataprepare/Prepare_PH2.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Code created on Sat Jun 8 18:15:43 2019
@author: Reza Azad
"""

"""
Reminder added on December 6, 2023.
Reminder Created on Wed Dec 6 2023
@author: Renkai Wu
1. The scipy package needs to be downgraded to scipy==1.2.1; otherwise, the code below must be modified.
2. The names of the files to be processed are printed. If none appear, the output .npy files are incorrect.
3. The official PH2 data is provided in '.bmp' format. Batch-rename the files to the '.jpg' (images) or '.png' (masks) suffix before processing.
"""

import h5py
import numpy as np
import scipy.io as sio
import scipy.misc as sc
import glob

# Parameters
height = 256
width = 256
channels = 3

############################################################# Prepare PH2 data set #################################################
Dataset_add = './PH2/'
Tr_add = 'images'

Tr_list = glob.glob(Dataset_add+ Tr_add+'/*.jpg')
# It contains 200 samples (PH2 is used only as an external test set)
Data_train_2017 = np.zeros([200, height, width, channels])
Label_train_2017 = np.zeros([200, height, width])

print('Reading PH2')
print(Tr_list)
for idx in range(len(Tr_list)):
    print(idx+1)
    img = sc.imread(Tr_list[idx])
    img = np.double(sc.imresize(img, [height, width, channels], interp='bilinear', mode = 'RGB'))
    Data_train_2017[idx, :,:,:] = img

    b = Tr_list[idx]
    a = b[0:len(Dataset_add)]
    b = b[len(b)-16: len(b)-4]  # the 12 characters before '.jpg'; shorter file names pull in part of the folder path
    add = (a+ 'masks/' + b +'.png')
    img2 = sc.imread(add)
    img2 = np.double(sc.imresize(img2, [height, width], interp='bilinear'))
    Label_train_2017[idx, :,:] = img2

print('Reading PH2 finished')

################################################################ Make test sets ########################################
# We consider 200 samples for testing

#Train_img = Data_train_2017[0:1815,:,:,:]
#Validation_img = Data_train_2017[1815:1815+259,:,:,:]
Test_img = Data_train_2017[0:200,:,:,:]

#Train_mask = Label_train_2017[0:1815,:,:]
#Validation_mask = Label_train_2017[1815:1815+259,:,:]
Test_mask = Label_train_2017[0:200,:,:]


#np.save('data_train', Train_img)
np.save('data_test' , Test_img)
#np.save('data_val' , Validation_img)

#np.save('mask_train', Train_mask)
np.save('mask_test' , Test_mask)
#np.save('mask_val' , Validation_mask)

--------------------------------------------------------------------------------
/dataprepare/Prepare_your_dataset.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
##scipy==1.2.1

import h5py
import numpy as np
import scipy.io as sio
import scipy.misc as sc
import glob

# Parameters
height = 256 # Input image height of the model.
width = 256 # Input image width of the model.
channels = 3 # Number of image channels

train_number = 1000 # Number of images for the training set (taken in file-listing order, not at random).
val_number = 200 # Number of images for the validation set.
test_number = 400 # Number of images for the test set.
all = int(train_number) + int(val_number) + int(test_number)

############################################################# Prepare your data set #################################################
Tr_list = glob.glob("images"+'/*.png') # Images storage folder. The images should be 24-bit PNG files.
# Preallocate arrays for all images; len(Tr_list) must not exceed train_number + val_number + test_number
Data_train_2018 = np.zeros([all, height, width, channels])
Label_train_2018 = np.zeros([all, height, width])

print('Reading')
print(len(Tr_list))
for idx in range(len(Tr_list)):
    print(idx+1)
    img = sc.imread(Tr_list[idx])
    img = np.double(sc.imresize(img, [height, width, channels], interp='bilinear', mode = 'RGB'))
    Data_train_2018[idx, :,:,:] = img

    b = Tr_list[idx]
    b = b[len(b)-8: len(b)-4]  # the 4 characters before '.png', e.g. '0000'
    add = ("masks/" + b +'.png') # Masks storage folder. Masks should be 8-bit black-and-white PNGs (0 for the background, 255 for the target).
37 | img2 = sc.imread(add) 38 | img2 = np.double(sc.imresize(img2, [height, width], interp='bilinear')) 39 | Label_train_2018[idx, :,:] = img2 40 | 41 | print('Reading your dataset finished') 42 | 43 | ################################################################ Make the training, validation and test sets ######################################## 44 | Train_img = Data_train_2018[0:train_number,:,:,:] 45 | Validation_img = Data_train_2018[train_number:train_number+val_number,:,:,:] 46 | Test_img = Data_train_2018[train_number+val_number:all,:,:,:] 47 | 48 | Train_mask = Label_train_2018[0:train_number,:,:] 49 | Validation_mask = Label_train_2018[train_number:train_number+val_number,:,:] 50 | Test_mask = Label_train_2018[train_number+val_number:all,:,:] 51 | 52 | 53 | np.save('data_train', Train_img) 54 | np.save('data_test' , Test_img) 55 | np.save('data_val' , Validation_img) 56 | 57 | np.save('mask_train', Train_mask) 58 | np.save('mask_test' , Test_mask) 59 | np.save('mask_val' , Validation_mask) 60 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import torch 4 | from torch.cuda.amp import autocast as autocast 5 | from sklearn.metrics import confusion_matrix 6 | from utils import save_imgs 7 | 8 | 9 | def train_one_epoch(train_loader, 10 | model, 11 | criterion, 12 | optimizer, 13 | scheduler, 14 | epoch, 15 | logger, 16 | config, 17 | scaler=None): 18 | ''' 19 | train model for one epoch 20 | ''' 21 | # switch to train mode 22 | model.train() 23 | 24 | loss_list = [] 25 | 26 | for iter, data in enumerate(train_loader): 27 | optimizer.zero_grad() 28 | images, targets = data 29 | images, targets = images.cuda(non_blocking=True).float(), targets.cuda(non_blocking=True).float() 30 | if config.amp: 31 | with autocast(): 32 | out = model(images) 33 | loss = criterion(out, targets) 34 | 
scaler.scale(loss).backward() 35 | scaler.step(optimizer) 36 | scaler.update() 37 | else: 38 | out = model(images) 39 | loss = criterion(out, targets) 40 | loss.backward() 41 | optimizer.step() 42 | 43 | loss_list.append(loss.item()) 44 | 45 | now_lr = optimizer.state_dict()['param_groups'][0]['lr'] 46 | if iter % config.print_interval == 0: 47 | log_info = f'train: epoch {epoch}, iter:{iter}, loss: {np.mean(loss_list):.4f}, lr: {now_lr}' 48 | print(log_info) 49 | logger.info(log_info) 50 | scheduler.step() 51 | 52 | 53 | def val_one_epoch(test_loader, 54 | model, 55 | criterion, 56 | epoch, 57 | logger, 58 | config): 59 | # switch to evaluate mode 60 | model.eval() 61 | preds = [] 62 | gts = [] 63 | loss_list = [] 64 | with torch.no_grad(): 65 | for data in tqdm(test_loader): 66 | img, msk = data 67 | img, msk = img.cuda(non_blocking=True).float(), msk.cuda(non_blocking=True).float() 68 | out = model(img) 69 | loss = criterion(out, msk) 70 | loss_list.append(loss.item()) 71 | gts.append(msk.squeeze(1).cpu().detach().numpy()) 72 | if type(out) is tuple: 73 | out = out[0] 74 | out = out.squeeze(1).cpu().detach().numpy() 75 | preds.append(out) 76 | 77 | if epoch % config.val_interval == 0: 78 | preds = np.array(preds).reshape(-1) 79 | gts = np.array(gts).reshape(-1) 80 | 81 | y_pre = np.where(preds>=config.threshold, 1, 0) 82 | y_true = np.where(gts>=0.5, 1, 0) 83 | 84 | confusion = confusion_matrix(y_true, y_pre) 85 | TN, FP, FN, TP = confusion[0,0], confusion[0,1], confusion[1,0], confusion[1,1] 86 | 87 | accuracy = float(TN + TP) / float(np.sum(confusion)) if float(np.sum(confusion)) != 0 else 0 88 | sensitivity = float(TP) / float(TP + FN) if float(TP + FN) != 0 else 0 89 | specificity = float(TN) / float(TN + FP) if float(TN + FP) != 0 else 0 90 | f1_or_dsc = float(2 * TP) / float(2 * TP + FP + FN) if float(2 * TP + FP + FN) != 0 else 0 91 | miou = float(TP) / float(TP + FP + FN) if float(TP + FP + FN) != 0 else 0 92 | 93 | log_info = f'val epoch: {epoch}, loss: 
{np.mean(loss_list):.4f}, miou: {miou}, f1_or_dsc: {f1_or_dsc}, accuracy: {accuracy}, \ 94 | specificity: {specificity}, sensitivity: {sensitivity}, confusion_matrix: {confusion}' 95 | print(log_info) 96 | logger.info(log_info) 97 | 98 | else: 99 | log_info = f'val epoch: {epoch}, loss: {np.mean(loss_list):.4f}' 100 | print(log_info) 101 | logger.info(log_info) 102 | 103 | return np.mean(loss_list) 104 | 105 | 106 | def test_one_epoch(test_loader, 107 | model, 108 | criterion, 109 | logger, 110 | config, 111 | test_data_name=None): 112 | # switch to evaluate mode 113 | model.eval() 114 | preds = [] 115 | gts = [] 116 | loss_list = [] 117 | with torch.no_grad(): 118 | for i, data in enumerate(tqdm(test_loader)): 119 | img, msk = data 120 | img, msk = img.cuda(non_blocking=True).float(), msk.cuda(non_blocking=True).float() 121 | out = model(img) 122 | loss = criterion(out, msk) 123 | loss_list.append(loss.item()) 124 | msk = msk.squeeze(1).cpu().detach().numpy() 125 | gts.append(msk) 126 | if type(out) is tuple: 127 | out = out[0] 128 | out = out.squeeze(1).cpu().detach().numpy() 129 | preds.append(out) 130 | save_imgs(img, msk, out, i, config.work_dir + 'outputs/', config.datasets, config.threshold, test_data_name=test_data_name) 131 | 132 | preds = np.array(preds).reshape(-1) 133 | gts = np.array(gts).reshape(-1) 134 | 135 | y_pre = np.where(preds>=config.threshold, 1, 0) 136 | y_true = np.where(gts>=0.5, 1, 0) 137 | 138 | confusion = confusion_matrix(y_true, y_pre) 139 | TN, FP, FN, TP = confusion[0,0], confusion[0,1], confusion[1,0], confusion[1,1] 140 | 141 | accuracy = float(TN + TP) / float(np.sum(confusion)) if float(np.sum(confusion)) != 0 else 0 142 | sensitivity = float(TP) / float(TP + FN) if float(TP + FN) != 0 else 0 143 | specificity = float(TN) / float(TN + FP) if float(TN + FP) != 0 else 0 144 | f1_or_dsc = float(2 * TP) / float(2 * TP + FP + FN) if float(2 * TP + FP + FN) != 0 else 0 145 | miou = float(TP) / float(TP + FP + FN) if float(TP + FP + 
FN) != 0 else 0 146 | 147 | if test_data_name is not None: 148 | log_info = f'test_datasets_name: {test_data_name}' 149 | print(log_info) 150 | logger.info(log_info) 151 | log_info = f'test of best model, loss: {np.mean(loss_list):.4f}, miou: {miou}, f1_or_dsc: {f1_or_dsc}, accuracy: {accuracy}, \ 152 | specificity: {specificity}, sensitivity: {sensitivity}, confusion_matrix: {confusion}' 153 | print(log_info) 154 | logger.info(log_info) 155 | 156 | return np.mean(loss_list) 157 | -------------------------------------------------------------------------------- /loader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | import torch 3 | import numpy as np 4 | import random 5 | import os 6 | from PIL import Image 7 | from einops.layers.torch import Rearrange 8 | from scipy.ndimage import binary_dilation 9 | from torch.utils.data import Dataset 10 | from torchvision import transforms 11 | from scipy import ndimage 12 | from utils import * 13 | 14 | 15 | # ===== normalize over the dataset 16 | def dataset_normalized(imgs): 17 | imgs_normalized = np.empty(imgs.shape) 18 | imgs_std = np.std(imgs) 19 | imgs_mean = np.mean(imgs) 20 | imgs_normalized = (imgs-imgs_mean)/imgs_std 21 | for i in range(imgs.shape[0]): 22 | imgs_normalized[i] = ((imgs_normalized[i] - np.min(imgs_normalized[i])) / (np.max(imgs_normalized[i])-np.min(imgs_normalized[i])))*255 23 | return imgs_normalized 24 | 25 | 26 | ## Temporary 27 | class isic_loader(Dataset): 28 | """ dataset class for ISIC-style skin lesion datasets stored as .npy arrays (data_*.npy / mask_*.npy) 29 | """ 30 | def __init__(self, path_Data, train = True, Test = False): 31 | super(isic_loader, self).__init__() 32 | self.train = train 33 | if train: 34 | self.data = np.load(path_Data+'data_train.npy') 35 | self.mask = np.load(path_Data+'mask_train.npy') 36 | else: 37 | if Test: 38 | self.data = np.load(path_Data+'data_test.npy') 39 | self.mask = np.load(path_Data+'mask_test.npy') 40 | else: 41 | self.data = 
np.load(path_Data+'data_val.npy') 42 | self.mask = np.load(path_Data+'mask_val.npy') 43 | 44 | self.data = dataset_normalized(self.data) 45 | self.mask = np.expand_dims(self.mask, axis=3) 46 | self.mask = self.mask/255. 47 | 48 | def __getitem__(self, indx): 49 | img = self.data[indx] 50 | seg = self.mask[indx] 51 | if self.train: 52 | if random.random() > 0.5: 53 | img, seg = self.random_rot_flip(img, seg) 54 | if random.random() > 0.5: 55 | img, seg = self.random_rotate(img, seg) 56 | 57 | seg = torch.tensor(seg.copy()) 58 | img = torch.tensor(img.copy()) 59 | img = img.permute( 2, 0, 1) 60 | seg = seg.permute( 2, 0, 1) 61 | 62 | return img, seg 63 | 64 | def random_rot_flip(self,image, label): 65 | k = np.random.randint(0, 4) 66 | image = np.rot90(image, k) 67 | label = np.rot90(label, k) 68 | axis = np.random.randint(0, 2) 69 | image = np.flip(image, axis=axis).copy() 70 | label = np.flip(label, axis=axis).copy() 71 | return image, label 72 | 73 | def random_rotate(self,image, label): 74 | angle = np.random.randint(20, 80) 75 | image = ndimage.rotate(image, angle, order=0, reshape=False) 76 | label = ndimage.rotate(label, angle, order=0, reshape=False) 77 | return image, label 78 | 79 | 80 | 81 | def __len__(self): 82 | return len(self.data) 83 | -------------------------------------------------------------------------------- /models/UltraLight_VM_UNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | from timm.models.layers import trunc_normal_ 6 | import math 7 | from mamba_ssm import Mamba 8 | 9 | 10 | class PVMLayer(nn.Module): 11 | def __init__(self, input_dim, output_dim, d_state = 16, d_conv = 4, expand = 2): 12 | super().__init__() 13 | self.input_dim = input_dim 14 | self.output_dim = output_dim 15 | self.norm = nn.LayerNorm(input_dim) 16 | self.mamba = Mamba( 17 | d_model=input_dim//4, # Model dimension d_model 18 | d_state=d_state, 
# SSM state expansion factor 19 | d_conv=d_conv, # Local convolution width 20 | expand=expand, # Block expansion factor 21 | ) 22 | self.proj = nn.Linear(input_dim, output_dim) 23 | self.skip_scale= nn.Parameter(torch.ones(1)) 24 | 25 | def forward(self, x): 26 | if x.dtype == torch.float16: 27 | x = x.type(torch.float32) 28 | B, C = x.shape[:2] 29 | assert C == self.input_dim 30 | n_tokens = x.shape[2:].numel() 31 | img_dims = x.shape[2:] 32 | x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2) 33 | x_norm = self.norm(x_flat) 34 | 35 | x1, x2, x3, x4 = torch.chunk(x_norm, 4, dim=2) 36 | x_mamba1 = self.mamba(x1) + self.skip_scale * x1 37 | x_mamba2 = self.mamba(x2) + self.skip_scale * x2 38 | x_mamba3 = self.mamba(x3) + self.skip_scale * x3 39 | x_mamba4 = self.mamba(x4) + self.skip_scale * x4 40 | x_mamba = torch.cat([x_mamba1, x_mamba2,x_mamba3,x_mamba4], dim=2) 41 | 42 | x_mamba = self.norm(x_mamba) 43 | x_mamba = self.proj(x_mamba) 44 | out = x_mamba.transpose(-1, -2).reshape(B, self.output_dim, *img_dims) 45 | return out 46 | 47 | 48 | class Channel_Att_Bridge(nn.Module): 49 | def __init__(self, c_list, split_att='fc'): 50 | super().__init__() 51 | c_list_sum = sum(c_list) - c_list[-1] 52 | self.split_att = split_att 53 | self.avgpool = nn.AdaptiveAvgPool2d(1) 54 | self.get_all_att = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False) 55 | self.att1 = nn.Linear(c_list_sum, c_list[0]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[0], 1) 56 | self.att2 = nn.Linear(c_list_sum, c_list[1]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[1], 1) 57 | self.att3 = nn.Linear(c_list_sum, c_list[2]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[2], 1) 58 | self.att4 = nn.Linear(c_list_sum, c_list[3]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[3], 1) 59 | self.att5 = nn.Linear(c_list_sum, c_list[4]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[4], 1) 60 | self.sigmoid = nn.Sigmoid() 61 | 62 | def forward(self, t1, t2, t3, 
t4, t5): 63 | att = torch.cat((self.avgpool(t1), 64 | self.avgpool(t2), 65 | self.avgpool(t3), 66 | self.avgpool(t4), 67 | self.avgpool(t5)), dim=1) 68 | att = self.get_all_att(att.squeeze(-1).transpose(-1, -2)) 69 | if self.split_att != 'fc': 70 | att = att.transpose(-1, -2) 71 | att1 = self.sigmoid(self.att1(att)) 72 | att2 = self.sigmoid(self.att2(att)) 73 | att3 = self.sigmoid(self.att3(att)) 74 | att4 = self.sigmoid(self.att4(att)) 75 | att5 = self.sigmoid(self.att5(att)) 76 | if self.split_att == 'fc': 77 | att1 = att1.transpose(-1, -2).unsqueeze(-1).expand_as(t1) 78 | att2 = att2.transpose(-1, -2).unsqueeze(-1).expand_as(t2) 79 | att3 = att3.transpose(-1, -2).unsqueeze(-1).expand_as(t3) 80 | att4 = att4.transpose(-1, -2).unsqueeze(-1).expand_as(t4) 81 | att5 = att5.transpose(-1, -2).unsqueeze(-1).expand_as(t5) 82 | else: 83 | att1 = att1.unsqueeze(-1).expand_as(t1) 84 | att2 = att2.unsqueeze(-1).expand_as(t2) 85 | att3 = att3.unsqueeze(-1).expand_as(t3) 86 | att4 = att4.unsqueeze(-1).expand_as(t4) 87 | att5 = att5.unsqueeze(-1).expand_as(t5) 88 | 89 | return att1, att2, att3, att4, att5 90 | 91 | 92 | class Spatial_Att_Bridge(nn.Module): 93 | def __init__(self): 94 | super().__init__() 95 | self.shared_conv2d = nn.Sequential(nn.Conv2d(2, 1, 7, stride=1, padding=9, dilation=3), 96 | nn.Sigmoid()) 97 | 98 | def forward(self, t1, t2, t3, t4, t5): 99 | t_list = [t1, t2, t3, t4, t5] 100 | att_list = [] 101 | for t in t_list: 102 | avg_out = torch.mean(t, dim=1, keepdim=True) 103 | max_out, _ = torch.max(t, dim=1, keepdim=True) 104 | att = torch.cat([avg_out, max_out], dim=1) 105 | att = self.shared_conv2d(att) 106 | att_list.append(att) 107 | return att_list[0], att_list[1], att_list[2], att_list[3], att_list[4] 108 | 109 | 110 | class SC_Att_Bridge(nn.Module): 111 | def __init__(self, c_list, split_att='fc'): 112 | super().__init__() 113 | 114 | self.catt = Channel_Att_Bridge(c_list, split_att=split_att) 115 | self.satt = Spatial_Att_Bridge() 116 | 117 | def 
forward(self, t1, t2, t3, t4, t5): 118 | r1, r2, r3, r4, r5 = t1, t2, t3, t4, t5 119 | 120 | satt1, satt2, satt3, satt4, satt5 = self.satt(t1, t2, t3, t4, t5) 121 | t1, t2, t3, t4, t5 = satt1 * t1, satt2 * t2, satt3 * t3, satt4 * t4, satt5 * t5 122 | 123 | r1_, r2_, r3_, r4_, r5_ = t1, t2, t3, t4, t5 124 | t1, t2, t3, t4, t5 = t1 + r1, t2 + r2, t3 + r3, t4 + r4, t5 + r5 125 | 126 | catt1, catt2, catt3, catt4, catt5 = self.catt(t1, t2, t3, t4, t5) 127 | t1, t2, t3, t4, t5 = catt1 * t1, catt2 * t2, catt3 * t3, catt4 * t4, catt5 * t5 128 | 129 | return t1 + r1_, t2 + r2_, t3 + r3_, t4 + r4_, t5 + r5_ 130 | 131 | 132 | class UltraLight_VM_UNet(nn.Module): 133 | 134 | def __init__(self, num_classes=1, input_channels=3, c_list=[8,16,24,32,48,64], 135 | split_att='fc', bridge=True): 136 | super().__init__() 137 | 138 | self.bridge = bridge 139 | 140 | self.encoder1 = nn.Sequential( 141 | nn.Conv2d(input_channels, c_list[0], 3, stride=1, padding=1), 142 | ) 143 | self.encoder2 =nn.Sequential( 144 | nn.Conv2d(c_list[0], c_list[1], 3, stride=1, padding=1), 145 | ) 146 | self.encoder3 = nn.Sequential( 147 | nn.Conv2d(c_list[1], c_list[2], 3, stride=1, padding=1), 148 | ) 149 | self.encoder4 = nn.Sequential( 150 | PVMLayer(input_dim=c_list[2], output_dim=c_list[3]) 151 | ) 152 | self.encoder5 = nn.Sequential( 153 | PVMLayer(input_dim=c_list[3], output_dim=c_list[4]) 154 | ) 155 | self.encoder6 = nn.Sequential( 156 | PVMLayer(input_dim=c_list[4], output_dim=c_list[5]) 157 | ) 158 | 159 | if bridge: 160 | self.scab = SC_Att_Bridge(c_list, split_att) 161 | print('SC_Att_Bridge was used') 162 | 163 | self.decoder1 = nn.Sequential( 164 | PVMLayer(input_dim=c_list[5], output_dim=c_list[4]) 165 | ) 166 | self.decoder2 = nn.Sequential( 167 | PVMLayer(input_dim=c_list[4], output_dim=c_list[3]) 168 | ) 169 | self.decoder3 = nn.Sequential( 170 | PVMLayer(input_dim=c_list[3], output_dim=c_list[2]) 171 | ) 172 | self.decoder4 = nn.Sequential( 173 | nn.Conv2d(c_list[2], c_list[1], 3, 
stride=1, padding=1), 174 | ) 175 | self.decoder5 = nn.Sequential( 176 | nn.Conv2d(c_list[1], c_list[0], 3, stride=1, padding=1), 177 | ) 178 | self.ebn1 = nn.GroupNorm(4, c_list[0]) 179 | self.ebn2 = nn.GroupNorm(4, c_list[1]) 180 | self.ebn3 = nn.GroupNorm(4, c_list[2]) 181 | self.ebn4 = nn.GroupNorm(4, c_list[3]) 182 | self.ebn5 = nn.GroupNorm(4, c_list[4]) 183 | self.dbn1 = nn.GroupNorm(4, c_list[4]) 184 | self.dbn2 = nn.GroupNorm(4, c_list[3]) 185 | self.dbn3 = nn.GroupNorm(4, c_list[2]) 186 | self.dbn4 = nn.GroupNorm(4, c_list[1]) 187 | self.dbn5 = nn.GroupNorm(4, c_list[0]) 188 | 189 | self.final = nn.Conv2d(c_list[0], num_classes, kernel_size=1) 190 | 191 | self.apply(self._init_weights) 192 | 193 | def _init_weights(self, m): 194 | if isinstance(m, nn.Linear): 195 | trunc_normal_(m.weight, std=.02) 196 | if isinstance(m, nn.Linear) and m.bias is not None: 197 | nn.init.constant_(m.bias, 0) 198 | elif isinstance(m, nn.Conv1d): 199 | n = m.kernel_size[0] * m.out_channels 200 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 201 | elif isinstance(m, nn.Conv2d): 202 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 203 | fan_out //= m.groups 204 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 205 | if m.bias is not None: 206 | m.bias.data.zero_() 207 | 208 | def forward(self, x): 209 | 210 | out = F.gelu(F.max_pool2d(self.ebn1(self.encoder1(x)),2,2)) 211 | t1 = out # b, c0, H/2, W/2 212 | 213 | out = F.gelu(F.max_pool2d(self.ebn2(self.encoder2(out)),2,2)) 214 | t2 = out # b, c1, H/4, W/4 215 | 216 | out = F.gelu(F.max_pool2d(self.ebn3(self.encoder3(out)),2,2)) 217 | t3 = out # b, c2, H/8, W/8 218 | 219 | out = F.gelu(F.max_pool2d(self.ebn4(self.encoder4(out)),2,2)) 220 | t4 = out # b, c3, H/16, W/16 221 | 222 | out = F.gelu(F.max_pool2d(self.ebn5(self.encoder5(out)),2,2)) 223 | t5 = out # b, c4, H/32, W/32 224 | 225 | if self.bridge: t1, t2, t3, t4, t5 = self.scab(t1, t2, t3, t4, t5) 226 | 227 | out = F.gelu(self.encoder6(out)) # b, c5, H/32, W/32 228 | 229 | out5 = F.gelu(self.dbn1(self.decoder1(out))) # b, c4, H/32, W/32 230 | out5 = torch.add(out5, t5) # b, c4, H/32, W/32 231 | 232 | out4 = F.gelu(F.interpolate(self.dbn2(self.decoder2(out5)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c3, H/16, W/16 233 | out4 = torch.add(out4, t4) # b, c3, H/16, W/16 234 | 235 | out3 = F.gelu(F.interpolate(self.dbn3(self.decoder3(out4)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c2, H/8, W/8 236 | out3 = torch.add(out3, t3) # b, c2, H/8, W/8 237 | 238 | out2 = F.gelu(F.interpolate(self.dbn4(self.decoder4(out3)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c1, H/4, W/4 239 | out2 = torch.add(out2, t2) # b, c1, H/4, W/4 240 | 241 | out1 = F.gelu(F.interpolate(self.dbn5(self.decoder5(out2)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c0, H/2, W/2 242 | out1 = torch.add(out1, t1) # b, c0, H/2, W/2 243 | 244 | out0 = F.interpolate(self.final(out1),scale_factor=(2,2),mode ='bilinear',align_corners=True) # b, 
num_class, H, W 245 | 246 | return torch.sigmoid(out0) 247 | 248 | 249 | -------------------------------------------------------------------------------- /results/Readme.txt: -------------------------------------------------------------------------------- 1 | Result save location -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.cuda.amp import autocast, GradScaler 4 | from torch.utils.data import DataLoader 5 | from loader import * 6 | 7 | from models.UltraLight_VM_UNet import UltraLight_VM_UNet 8 | from engine import * 9 | import os 10 | import sys 11 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # "0, 1, 2, 3" 12 | 13 | from utils import * 14 | from configs.config_setting import setting_config 15 | 16 | import warnings 17 | warnings.filterwarnings("ignore") 18 | 19 | 20 | 21 | def main(config): 22 | 23 | print('#----------Creating logger----------#') 24 | sys.path.append(config.work_dir + '/') 25 | log_dir = os.path.join(config.work_dir, 'log') 26 | checkpoint_dir = os.path.join(config.work_dir, 'checkpoints') 27 | resume_model = os.path.join('') # set to the path of a best*.pth state_dict saved by train.py before running 28 | outputs = os.path.join(config.work_dir, 'outputs') 29 | if not os.path.exists(checkpoint_dir): 30 | os.makedirs(checkpoint_dir) 31 | if not os.path.exists(outputs): 32 | os.makedirs(outputs) 33 | 34 | global logger 35 | logger = get_logger('test', log_dir) 36 | 37 | log_config_info(config, logger) 38 | 39 | 40 | 41 | 42 | 43 | print('#----------GPU init----------#') 44 | set_seed(config.seed) 45 | gpu_ids = [0]# [0, 1, 2, 3] 46 | torch.cuda.empty_cache() 47 | 48 | 49 | 50 | print('#----------Preparing Models----------#') 51 | model_cfg = config.model_config 52 | model = UltraLight_VM_UNet(num_classes=model_cfg['num_classes'], 53 | input_channels=model_cfg['input_channels'], 54 | c_list=model_cfg['c_list'], 55 | 
split_att=model_cfg['split_att'], 56 | bridge=model_cfg['bridge'],) 57 | 58 | model = torch.nn.DataParallel(model.cuda(), device_ids=gpu_ids, output_device=gpu_ids[0]) 59 | 60 | 61 | print('#----------Preparing dataset----------#') 62 | test_dataset = isic_loader(path_Data = config.data_path, train = False, Test = True) 63 | test_loader = DataLoader(test_dataset, 64 | batch_size=1, 65 | shuffle=False, 66 | pin_memory=True, 67 | num_workers=config.num_workers, 68 | drop_last=True) 69 | 70 | print('#----------Preparing loss, opt, sch and amp----------#') 71 | criterion = config.criterion 72 | optimizer = get_optimizer(config, model) 73 | scheduler = get_scheduler(config, optimizer) 74 | scaler = GradScaler() 75 | 76 | 77 | 78 | 79 | 80 | print('#----------Set other params----------#') 81 | min_loss = 999 82 | start_epoch = 1 83 | min_epoch = 1 84 | 85 | 86 | print('#----------Testing----------#') 87 | best_weight = torch.load(resume_model, map_location=torch.device('cpu')) 88 | model.module.load_state_dict(best_weight) 89 | loss = test_one_epoch( 90 | test_loader, 91 | model, 92 | criterion, 93 | logger, 94 | config, 95 | ) 96 | 97 | 98 | 99 | if __name__ == '__main__': 100 | config = setting_config 101 | main(config) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.cuda.amp import autocast, GradScaler 4 | from torch.utils.data import DataLoader 5 | from loader import * 6 | 7 | from models.UltraLight_VM_UNet import UltraLight_VM_UNet 8 | from engine import * 9 | import os 10 | import sys 11 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # "0, 1, 2, 3" 12 | 13 | from utils import * 14 | from configs.config_setting import setting_config 15 | 16 | import warnings 17 | warnings.filterwarnings("ignore") 18 | 19 | 20 | def main(config): 21 | 22 | print('#----------Creating logger----------#') 23 | 
sys.path.append(config.work_dir + '/') 24 | log_dir = os.path.join(config.work_dir, 'log') 25 | checkpoint_dir = os.path.join(config.work_dir, 'checkpoints') 26 | resume_model = os.path.join(checkpoint_dir, 'latest.pth') 27 | outputs = os.path.join(config.work_dir, 'outputs') 28 | if not os.path.exists(checkpoint_dir): 29 | os.makedirs(checkpoint_dir) 30 | if not os.path.exists(outputs): 31 | os.makedirs(outputs) 32 | 33 | global logger 34 | logger = get_logger('train', log_dir) 35 | 36 | log_config_info(config, logger) 37 | 38 | 39 | 40 | 41 | 42 | print('#----------GPU init----------#') 43 | set_seed(config.seed) 44 | gpu_ids = [0]# [0, 1, 2, 3] 45 | torch.cuda.empty_cache() 46 | 47 | 48 | 49 | 50 | 51 | print('#----------Preparing dataset----------#') 52 | train_dataset = isic_loader(path_Data = config.data_path, train = True) 53 | train_loader = DataLoader(train_dataset, 54 | batch_size=config.batch_size, 55 | shuffle=True, 56 | pin_memory=True, 57 | num_workers=config.num_workers) 58 | val_dataset = isic_loader(path_Data = config.data_path, train = False) 59 | val_loader = DataLoader(val_dataset, 60 | batch_size=1, 61 | shuffle=False, 62 | pin_memory=True, 63 | num_workers=config.num_workers, 64 | drop_last=True) 65 | test_dataset = isic_loader(path_Data = config.data_path, train = False, Test = True) 66 | test_loader = DataLoader(test_dataset, 67 | batch_size=1, 68 | shuffle=False, 69 | pin_memory=True, 70 | num_workers=config.num_workers, 71 | drop_last=True) 72 | 73 | 74 | 75 | 76 | print('#----------Preparing Models----------#') 77 | model_cfg = config.model_config 78 | model = UltraLight_VM_UNet(num_classes=model_cfg['num_classes'], 79 | input_channels=model_cfg['input_channels'], 80 | c_list=model_cfg['c_list'], 81 | split_att=model_cfg['split_att'], 82 | bridge=model_cfg['bridge'],) 83 | 84 | model = torch.nn.DataParallel(model.cuda(), device_ids=gpu_ids, output_device=gpu_ids[0]) 85 | 86 | 87 | 88 | 89 | 90 | 91 | print('#----------Preparing loss, 
opt, sch and amp----------#') 92 | criterion = config.criterion 93 | optimizer = get_optimizer(config, model) 94 | scheduler = get_scheduler(config, optimizer) 95 | scaler = GradScaler() 96 | 97 | 98 | 99 | 100 | 101 | print('#----------Set other params----------#') 102 | min_loss = 999 103 | start_epoch = 1 104 | min_epoch = 1 105 | 106 | 107 | 108 | 109 | 110 | if os.path.exists(resume_model): 111 | print('#----------Resume Model and Other params----------#') 112 | checkpoint = torch.load(resume_model, map_location=torch.device('cpu')) 113 | model.module.load_state_dict(checkpoint['model_state_dict']) 114 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 115 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 116 | saved_epoch = checkpoint['epoch'] 117 | start_epoch += saved_epoch 118 | min_loss, min_epoch, loss = checkpoint['min_loss'], checkpoint['min_epoch'], checkpoint['loss'] 119 | 120 | log_info = f'resuming model from {resume_model}. resume_epoch: {saved_epoch}, min_loss: {min_loss:.4f}, min_epoch: {min_epoch}, loss: {loss:.4f}' 121 | logger.info(log_info) 122 | 123 | 124 | 125 | 126 | 127 | print('#----------Training----------#') 128 | for epoch in range(start_epoch, config.epochs + 1): 129 | 130 | torch.cuda.empty_cache() 131 | 132 | train_one_epoch( 133 | train_loader, 134 | model, 135 | criterion, 136 | optimizer, 137 | scheduler, 138 | epoch, 139 | logger, 140 | config, 141 | scaler=scaler 142 | ) 143 | 144 | loss = val_one_epoch( 145 | val_loader, 146 | model, 147 | criterion, 148 | epoch, 149 | logger, 150 | config 151 | ) 152 | 153 | 154 | if loss < min_loss: 155 | torch.save(model.module.state_dict(), os.path.join(checkpoint_dir, 'best.pth')) 156 | min_loss = loss 157 | min_epoch = epoch 158 | 159 | torch.save( 160 | { 161 | 'epoch': epoch, 162 | 'min_loss': min_loss, 163 | 'min_epoch': min_epoch, 164 | 'loss': loss, 165 | 'model_state_dict': model.module.state_dict(), 166 | 'optimizer_state_dict': optimizer.state_dict(), 
167 | 'scheduler_state_dict': scheduler.state_dict(), 168 | }, os.path.join(checkpoint_dir, 'latest.pth')) 169 | 170 | if os.path.exists(os.path.join(checkpoint_dir, 'best.pth')): 171 | print('#----------Testing----------#') 172 | best_weight = torch.load(config.work_dir + 'checkpoints/best.pth', map_location=torch.device('cpu')) 173 | model.module.load_state_dict(best_weight) 174 | loss = test_one_epoch( 175 | test_loader, 176 | model, 177 | criterion, 178 | logger, 179 | config, 180 | ) 181 | os.rename( 182 | os.path.join(checkpoint_dir, 'best.pth'), 183 | os.path.join(checkpoint_dir, f'best-epoch{min_epoch}-loss{min_loss:.4f}.pth') 184 | ) 185 | 186 | 187 | if __name__ == '__main__': 188 | config = setting_config 189 | main(config) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.backends.cudnn as cudnn 5 | import torchvision.transforms.functional as TF 6 | import numpy as np 7 | import os 8 | import math 9 | import random 10 | import logging 11 | import logging.handlers 12 | from matplotlib import pyplot as plt 13 | 14 | 15 | def set_seed(seed): 16 | # for hash 17 | os.environ['PYTHONHASHSEED'] = str(seed) 18 | # for python and numpy 19 | random.seed(seed) 20 | np.random.seed(seed) 21 | # for cpu gpu 22 | torch.manual_seed(seed) 23 | torch.cuda.manual_seed(seed) 24 | torch.cuda.manual_seed_all(seed) 25 | # for cudnn 26 | cudnn.benchmark = False 27 | cudnn.deterministic = True 28 | 29 | 30 | def get_logger(name, log_dir): 31 | ''' 32 | Args: 33 | name(str): name of logger 34 | log_dir(str): path of log 35 | ''' 36 | 37 | if not os.path.exists(log_dir): 38 | os.makedirs(log_dir) 39 | 40 | logger = logging.getLogger(name) 41 | logger.setLevel(logging.INFO) 42 | 43 | info_name = os.path.join(log_dir, '{}.info.log'.format(name)) 44 | info_handler = 
logging.handlers.TimedRotatingFileHandler(info_name, 45 | when='D', 46 | encoding='utf-8') 47 | info_handler.setLevel(logging.INFO) 48 | 49 | formatter = logging.Formatter('%(asctime)s - %(message)s', 50 | datefmt='%Y-%m-%d %H:%M:%S') 51 | 52 | info_handler.setFormatter(formatter) 53 | 54 | logger.addHandler(info_handler) 55 | 56 | return logger 57 | 58 | 59 | def log_config_info(config, logger): 60 | config_dict = config.__dict__ 61 | log_info = f'#----------Config info----------#' 62 | logger.info(log_info) 63 | for k, v in config_dict.items(): 64 | if k[0] == '_': 65 | continue 66 | else: 67 | log_info = f'{k}: {v},' 68 | logger.info(log_info) 69 | 70 | 71 | 72 | def get_optimizer(config, model): 73 | assert config.opt in ['Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'ASGD', 'RMSprop', 'Rprop', 'SGD'], 'Unsupported optimizer!' 74 | 75 | if config.opt == 'Adadelta': 76 | return torch.optim.Adadelta( 77 | model.parameters(), 78 | lr = config.lr, 79 | rho = config.rho, 80 | eps = config.eps, 81 | weight_decay = config.weight_decay 82 | ) 83 | elif config.opt == 'Adagrad': 84 | return torch.optim.Adagrad( 85 | model.parameters(), 86 | lr = config.lr, 87 | lr_decay = config.lr_decay, 88 | eps = config.eps, 89 | weight_decay = config.weight_decay 90 | ) 91 | elif config.opt == 'Adam': 92 | return torch.optim.Adam( 93 | model.parameters(), 94 | lr = config.lr, 95 | betas = config.betas, 96 | eps = config.eps, 97 | weight_decay = config.weight_decay, 98 | amsgrad = config.amsgrad 99 | ) 100 | elif config.opt == 'AdamW': 101 | return torch.optim.AdamW( 102 | model.parameters(), 103 | lr = config.lr, 104 | betas = config.betas, 105 | eps = config.eps, 106 | weight_decay = config.weight_decay, 107 | amsgrad = config.amsgrad 108 | ) 109 | elif config.opt == 'Adamax': 110 | return torch.optim.Adamax( 111 | model.parameters(), 112 | lr = config.lr, 113 | betas = config.betas, 114 | eps = config.eps, 115 | weight_decay = config.weight_decay 116 | ) 117 | elif config.opt 
== 'ASGD': 118 | return torch.optim.ASGD( 119 | model.parameters(), 120 | lr = config.lr, 121 | lambd = config.lambd, 122 | alpha = config.alpha, 123 | t0 = config.t0, 124 | weight_decay = config.weight_decay 125 | ) 126 | elif config.opt == 'RMSprop': 127 | return torch.optim.RMSprop( 128 | model.parameters(), 129 | lr = config.lr, 130 | momentum = config.momentum, 131 | alpha = config.alpha, 132 | eps = config.eps, 133 | centered = config.centered, 134 | weight_decay = config.weight_decay 135 | ) 136 | elif config.opt == 'Rprop': 137 | return torch.optim.Rprop( 138 | model.parameters(), 139 | lr = config.lr, 140 | etas = config.etas, 141 | step_sizes = config.step_sizes, 142 | ) 143 | elif config.opt == 'SGD': 144 | return torch.optim.SGD( 145 | model.parameters(), 146 | lr = config.lr, 147 | momentum = config.momentum, 148 | weight_decay = config.weight_decay, 149 | dampening = config.dampening, 150 | nesterov = config.nesterov 151 | ) 152 | else: # default opt is SGD 153 | return torch.optim.SGD( 154 | model.parameters(), 155 | lr = 0.01, 156 | momentum = 0.9, 157 | weight_decay = 0.05, 158 | ) 159 | 160 | 161 | 162 | def get_scheduler(config, optimizer): 163 | assert config.sch in ['StepLR', 'MultiStepLR', 'ExponentialLR', 'CosineAnnealingLR', 'ReduceLROnPlateau', 164 | 'CosineAnnealingWarmRestarts', 'WP_MultiStepLR', 'WP_CosineLR'], 'Unsupported scheduler!' 
165 | if config.sch == 'StepLR': 166 | scheduler = torch.optim.lr_scheduler.StepLR( 167 | optimizer, 168 | step_size = config.step_size, 169 | gamma = config.gamma, 170 | last_epoch = config.last_epoch 171 | ) 172 | elif config.sch == 'MultiStepLR': 173 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 174 | optimizer, 175 | milestones = config.milestones, 176 | gamma = config.gamma, 177 | last_epoch = config.last_epoch 178 | ) 179 | elif config.sch == 'ExponentialLR': 180 | scheduler = torch.optim.lr_scheduler.ExponentialLR( 181 | optimizer, 182 | gamma = config.gamma, 183 | last_epoch = config.last_epoch 184 | ) 185 | elif config.sch == 'CosineAnnealingLR': 186 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 187 | optimizer, 188 | T_max = config.T_max, 189 | eta_min = config.eta_min, 190 | last_epoch = config.last_epoch 191 | ) 192 | elif config.sch == 'ReduceLROnPlateau': 193 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 194 | optimizer, 195 | mode = config.mode, 196 | factor = config.factor, 197 | patience = config.patience, 198 | threshold = config.threshold, 199 | threshold_mode = config.threshold_mode, 200 | cooldown = config.cooldown, 201 | min_lr = config.min_lr, 202 | eps = config.eps 203 | ) 204 | elif config.sch == 'CosineAnnealingWarmRestarts': 205 | scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( 206 | optimizer, 207 | T_0 = config.T_0, 208 | T_mult = config.T_mult, 209 | eta_min = config.eta_min, 210 | last_epoch = config.last_epoch 211 | ) 212 | elif config.sch == 'WP_MultiStepLR': 213 | lr_func = lambda epoch: epoch / config.warm_up_epochs if epoch <= config.warm_up_epochs else config.gamma**len( 214 | [m for m in config.milestones if m <= epoch]) 215 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func) 216 | elif config.sch == 'WP_CosineLR': 217 | lr_func = lambda epoch: epoch / config.warm_up_epochs if epoch <= config.warm_up_epochs else 0.5 * ( 218 | math.cos((epoch - 
config.warm_up_epochs) / (config.epochs - config.warm_up_epochs) * math.pi) + 1) 219 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func) 220 | 221 | return scheduler 222 | 223 | 224 | 225 | def save_imgs(img, msk, msk_pred, i, save_path, datasets, threshold=0.5, test_data_name=None): 226 | img = img.squeeze(0).permute(1,2,0).detach().cpu().numpy() 227 | img = img / 255. if img.max() > 1.1 else img 228 | if datasets == 'retinal': 229 | msk = np.squeeze(msk, axis=0) 230 | msk_pred = np.squeeze(msk_pred, axis=0) 231 | else: 232 | msk = np.where(np.squeeze(msk, axis=0) > 0.5, 1, 0) 233 | msk_pred = np.where(np.squeeze(msk_pred, axis=0) > threshold, 1, 0) 234 | 235 | plt.figure(figsize=(7,15)) 236 | 237 | plt.subplot(3,1,1) 238 | plt.imshow(img) 239 | plt.axis('off') 240 | 241 | plt.subplot(3,1,2) 242 | plt.imshow(msk, cmap= 'gray') 243 | plt.axis('off') 244 | 245 | plt.subplot(3,1,3) 246 | plt.imshow(msk_pred, cmap = 'gray') 247 | plt.axis('off') 248 | 249 | if test_data_name is not None: 250 | save_path = save_path + test_data_name + '_' 251 | plt.savefig(save_path + str(i) +'.png') 252 | plt.close() 253 | 254 | 255 | 256 | class BCELoss(nn.Module): 257 | def __init__(self): 258 | super(BCELoss, self).__init__() 259 | self.bceloss = nn.BCELoss() 260 | 261 | def forward(self, pred, target): 262 | size = pred.size(0) 263 | pred_ = pred.view(size, -1) 264 | target_ = target.view(size, -1) 265 | 266 | return self.bceloss(pred_, target_) 267 | 268 | 269 | class DiceLoss(nn.Module): 270 | def __init__(self): 271 | super(DiceLoss, self).__init__() 272 | 273 | def forward(self, pred, target): 274 | smooth = 1 275 | size = pred.size(0) 276 | 277 | pred_ = pred.view(size, -1) 278 | target_ = target.view(size, -1) 279 | intersection = pred_ * target_ 280 | dice_score = (2 * intersection.sum(1) + smooth)/(pred_.sum(1) + target_.sum(1) + smooth) 281 | dice_loss = 1 - dice_score.sum()/size 282 | 283 | return dice_loss 284 | 285 | 286 | class 
BceDiceLoss(nn.Module): 287 | def __init__(self, wb=1, wd=1): 288 | super(BceDiceLoss, self).__init__() 289 | self.bce = BCELoss() 290 | self.dice = DiceLoss() 291 | self.wb = wb 292 | self.wd = wd 293 | 294 | def forward(self, pred, target): 295 | bceloss = self.bce(pred, target) 296 | diceloss = self.dice(pred, target) 297 | 298 | loss = self.wd * diceloss + self.wb * bceloss 299 | return loss 300 | 301 | 302 | from thop import profile ## import the thop profiler 303 | def cal_params_flops(model, size, logger): 304 | input = torch.randn(1, 3, size, size).cuda() 305 | flops, params = profile(model, inputs=(input,)) 306 | print('flops',flops/1e9) ## print the compute cost in G (thop's profile counts multiply-accumulate operations) 307 | print('params',params/1e6) ## print the parameter count in M 308 | 309 | total = sum(p.numel() for p in model.parameters()) 310 | print("Total params: %.3fM" % (total/1e6)) 311 | logger.info(f'flops: {flops/1e9}, params: {params/1e6}, Total params: {total/1e6:.4f}') 312 | 313 | 314 | --------------------------------------------------------------------------------
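`val_one_epoch` and `test_one_epoch` in engine.py pool every pixel of the split, threshold the sigmoid output at `config.threshold` (ground truth at 0.5), and derive accuracy, sensitivity, specificity, DSC/F1, and `miou` from a single 2x2 confusion matrix. The sketch below restates that arithmetic in plain Python so it can be checked without a GPU. `pixel_metrics` is a hypothetical name; it counts TN/FP/FN/TP directly instead of calling `sklearn.metrics.confusion_matrix`, which, when called without `labels=[0, 1]`, returns a 1x1 matrix if only one class occurs in both arrays. Note that the code's `miou` formula is the IoU of the foreground (lesion) class only.

```python
def pixel_metrics(pred_probs, gt_values, threshold=0.5):
    """Binary segmentation metrics over flattened pixels (sketch of engine.py's logic)."""
    # Binarize as engine.py does: predictions at `threshold`, ground truth at 0.5.
    y_pre = [1 if p >= threshold else 0 for p in pred_probs]
    y_true = [1 if g >= 0.5 else 0 for g in gt_values]

    # Count the four confusion-matrix cells directly (always a full 2x2 table).
    tp = sum(1 for t, p in zip(y_true, y_pre) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pre) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pre) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pre) if t == 1 and p == 0)

    def safe_div(num, den):
        # Same zero-denominator guard as engine.py: report 0 instead of raising.
        return num / den if den != 0 else 0.0

    return {
        "accuracy": safe_div(tn + tp, tn + tp + fn + fp),
        "sensitivity": safe_div(tp, tp + fn),
        "specificity": safe_div(tn, tn + fp),
        "f1_or_dsc": safe_div(2 * tp, 2 * tp + fp + fn),
        "miou": safe_div(tp, tp + fp + fn),  # foreground IoU, as computed in engine.py
    }


# Six toy pixels: two true positives, two true negatives, one FP, one FN.
print(pixel_metrics([0.9, 0.8, 0.2, 0.6, 0.1, 0.4], [1.0, 1.0, 1.0, 0.0, 0.0, 0.0]))
```

Because DSC is 2TP/(2TP+FP+FN), it is the pixel-level F1 score, which is why engine.py logs it as `f1_or_dsc`.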
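`BceDiceLoss` in utils.py returns `wd * DiceLoss + wb * BCELoss`, where `DiceLoss` is a soft Dice with smoothing constant 1 averaged over the batch. The dependency-free sketch below computes the same quantities for a single flattened sample, which can help when checking the formula by hand. The function names are hypothetical; the log clamp at -100 mirrors the behavior documented for PyTorch's `nn.BCELoss`, so that a prediction of exactly 0 or 1 stays finite.

```python
import math


def soft_dice_loss(pred, target, smooth=1.0):
    """1 - soft Dice for one flattened sample, mirroring DiceLoss in utils.py."""
    intersection = sum(p * t for p, t in zip(pred, target))
    dice = (2 * intersection + smooth) / (sum(pred) + sum(target) + smooth)
    return 1 - dice


def _clamped_log(x):
    # nn.BCELoss clamps its log outputs at -100, so log(0) does not become -inf.
    return max(math.log(x), -100.0) if x > 0 else -100.0


def bce_loss(pred, target):
    """Mean binary cross-entropy over the pixels of one sample."""
    terms = [-(t * _clamped_log(p) + (1 - t) * _clamped_log(1 - p))
             for p, t in zip(pred, target)]
    return sum(terms) / len(terms)


def bce_dice_loss(pred, target, wb=1.0, wd=1.0):
    # Same weighting as BceDiceLoss.forward: wd * dice + wb * bce.
    return wd * soft_dice_loss(pred, target) + wb * bce_loss(pred, target)


print(bce_dice_loss([0.9, 0.2, 0.7], [1.0, 0.0, 1.0]))
```

The utils.py version averages the per-sample Dice score over the batch (`dice_score.sum()/size`) before subtracting from 1; for a batch of one, the two agree.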
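When `config.sch == 'WP_CosineLR'`, `get_scheduler` wraps a lambda in `torch.optim.lr_scheduler.LambdaLR`: the learning-rate multiplier rises linearly over `config.warm_up_epochs`, then follows half a cosine down to 0 at `config.epochs`. The sketch below tabulates that multiplier in plain Python. `wp_cosine_factor` is a hypothetical name, and the 10/100 epoch counts are illustrative rather than values taken from `config_setting.py`. Because `LambdaLR` evaluates its lambda at step 0 when it is constructed, this formula gives a factor of 0 until the first `scheduler.step()` call.

```python
import math


def wp_cosine_factor(epoch, warm_up_epochs, epochs):
    """LR multiplier of the 'WP_CosineLR' lambda in utils.get_scheduler (sketch)."""
    if epoch <= warm_up_epochs:
        # Linear warm-up: 0 at epoch 0, 1 at epoch == warm_up_epochs.
        return epoch / warm_up_epochs
    # Half-cosine decay from 1 (end of warm-up) to 0 (epoch == epochs).
    progress = (epoch - warm_up_epochs) / (epochs - warm_up_epochs)
    return 0.5 * (math.cos(progress * math.pi) + 1)


# Illustrative schedule: 10 warm-up epochs out of 100.
for e in (0, 5, 10, 55, 100):
    print(e, round(wp_cosine_factor(e, 10, 100), 4))
```

The effective learning rate at a given scheduler step is `config.lr` multiplied by this factor.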
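`UltraLight_VM_UNet.forward` applies `F.max_pool2d(..., 2, 2)` after each of the first five encoder stages, so `t5` sits at H/32 x W/32. The decoder restores resolution with four x2 bilinear upsamplings plus a final x2 on the logits, and each skip is merged with `torch.add`, which needs exactly matching shapes. In addition, `PVMLayer` splits channels into four chunks for a Mamba block with `d_model=input_dim//4`, `nn.GroupNorm(4, c)` is applied to `c_list[0]` through `c_list[4]`, and `Spatial_Att_Bridge` uses a 7x7 convolution with `dilation=3, padding=9`. The sketch below checks these constraints using the output-size formula documented for `torch.nn.Conv2d`; the helper names are hypothetical and the checks are inferred from the layer definitions, not from the authors.

```python
def conv_out_size(size, kernel, stride=1, padding=0, dilation=1):
    # Output-size formula documented for torch.nn.Conv2d (floor division).
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def check_ultralight_shapes(img_size, c_list=(8, 16, 24, 32, 48, 64)):
    """Sketch: sanity checks implied by UltraLight_VM_UNet's layer definitions."""
    problems = []
    # Five 2x2 max-pools, then five x2 upsamplings; torch.add skips need exact halving.
    if img_size % 32 != 0:
        problems.append(f"input size {img_size} is not divisible by 32")
    # PVMLayer inputs are c_list[2]..c_list[5]; torch.chunk(x, 4) must give equal parts.
    for c in c_list[2:6]:
        if c % 4 != 0:
            problems.append(f"PVM input_dim {c} is not divisible by 4")
    # nn.GroupNorm(4, c) requires c to be divisible by the 4 groups.
    for c in c_list[:5]:
        if c % 4 != 0:
            problems.append(f"GroupNorm channels {c} not divisible by 4")
    # Spatial_Att_Bridge: 7x7 conv, dilation 3, padding 9 must preserve every skip size.
    for k in range(1, 6):
        s = img_size // 2 ** k
        if conv_out_size(s, kernel=7, padding=9, dilation=3) != s:
            problems.append(f"spatial attention conv changes size {s}")
    return problems


print([256 // 2 ** k for k in range(1, 6)])  # skip resolutions t1..t5 for a 256x256 input
print(check_ultralight_shapes(256))
```

The dilated 7x7 kernel spans 3 * (7 - 1) + 1 = 19 pixels, so a padding of 9 on each side keeps the spatial size unchanged at every skip level.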