├── LICENSE
├── README.md
├── configs
│   └── config_setting.py
├── dataprepare
│   ├── Prepare_ISIC2017.py
│   ├── Prepare_ISIC2018.py
│   ├── Prepare_PH2.py
│   └── Prepare_your_dataset.py
├── engine.py
├── loader.py
├── models
│   └── UltraLight_VM_UNet.py
├── results
│   └── Readme.txt
├── test.py
├── train.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Renkai Wu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # UltraLight VM-UNet
3 | Parallel Vision Mamba Significantly Reduces Parameters for Skin Lesion Segmentation
4 |
5 | Renkai Wu<sup>1</sup>, Yinghao Liu<sup>2</sup>, Pengchen Liang<sup>1</sup>\*, Qing Chang<sup>1</sup>\*
6 |
7 | <sup>1</sup> Shanghai University, <sup>2</sup> University of Shanghai for Science and Technology
8 |
9 |
10 | ArXiv Preprint ([arXiv:2403.20035](https://arxiv.org/abs/2403.20035))
11 |
12 |
13 |
14 |
15 |
16 | ## 🔥🔥Highlights🔥🔥
17 | ### *1. The UltraLight VM-UNet has only 0.049M parameters, 0.060 GFLOPs, and a model weight file of only 229.1 KB.*
18 | ### *2. Parallel Vision Mamba (or Mamba) is a strong choice for lightweight models.*
19 |
20 | ## News🚀
21 | (2024.04.24) ***The third version of our paper has been uploaded to [arXiv](https://arxiv.org/abs/2403.20035), adding richer experimental validation, including but not limited to:***
22 | - Adding key parameter analysis of Mamba variants.
23 | - Adding experiments on parallel connection of multiple Mamba variants.
24 | - Adding the exploration of plug-and-play PVM Layer.
25 | - Adding more ablation experiments for analysis.
26 |
27 | (2024.04.09) ***The second version of our paper has been uploaded to arXiv with adjustments to the description in the methods section.***
28 |
29 | (2024.04.04) ***Added a preprocessing step for private datasets.***
30 |
31 | (2024.04.01) ***The project code has been uploaded.***
32 |
33 | (2024.03.29) ***The first edition of our paper has been uploaded to arXiv.*** 📃
34 |
35 | ### Abstract
36 | Traditionally, most approaches improve a model's segmentation performance by adding more complex modules. This is not suitable for the medical field, especially for mobile medical devices, where computationally heavy models are impractical in real clinical environments because of limited computational resources. Recently, state-space models (SSMs), represented by Mamba, have become strong competitors to traditional CNNs and Transformers. In this paper, we explore in depth the key factors that influence the number of parameters in Mamba and, based on this analysis, propose an UltraLight Vision Mamba UNet (UltraLight VM-UNet). Specifically, we propose a method for processing features with Vision Mamba in parallel, named the PVM Layer, which achieves excellent performance with the lowest computational load while keeping the overall number of processing channels constant. We conducted comparison and ablation experiments against several state-of-the-art lightweight models on three public skin lesion datasets and demonstrated that UltraLight VM-UNet remains strongly competitive with only 0.049M parameters and 0.060 GFLOPs. In addition, this study explores in depth the key factors that influence the number of parameters in Mamba, which will lay a theoretical foundation for Mamba to possibly become a new mainstream module for lightweight models in the future.
37 |
38 | ### Different Parallel Vision Mamba (PVM Layer) settings:
39 | | Setting | Description | Params | GFLOPs | DSC |
40 | | --- | --- | --- | --- | --- |
41 | | 1 | No parallelism (channel number `C`) | 0.136M | 0.060 | 0.9069 |
42 | | 2 | Double parallel (channel number `(C/2)+(C/2)`) | 0.070M | 0.060 | 0.9073 |
43 | | 3 | Quadruple parallel (channel number `(C/4)+(C/4)+(C/4)+(C/4)`) | 0.049M | 0.060 | 0.9091 |
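The savings in the table come from the fact that the weight count of a Mamba block grows roughly quadratically with its input width. The sketch below counts the parameters of one Mamba block, assuming the default hyperparameters of the `Mamba` module in `mamba_ssm` (`d_state=16`, `d_conv=4`, `expand=2`, `dt_rank = ceil(d_model / 16)`), and compares a block on `C` channels with a block on `C/2` or `C/4` channels that, as an assumption here, is shared by all parallel branches. It ignores LayerNorm, skip scales and the projection layers, so it illustrates the trend rather than reproducing the numbers above.

```python
import math


def mamba_block_params(d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2) -> int:
    """Approximate parameter count of one Mamba block (assumed mamba_ssm defaults)."""
    d_inner = expand * d_model
    dt_rank = math.ceil(d_model / 16)
    in_proj = d_model * 2 * d_inner             # x and z branches, no bias
    conv1d = d_inner * d_conv + d_inner         # depthwise conv weights + bias
    x_proj = d_inner * (dt_rank + 2 * d_state)  # produces dt, B and C, no bias
    dt_proj = dt_rank * d_inner + d_inner       # weights + bias
    a_log = d_inner * d_state                   # state matrix A, stored as log
    d_skip = d_inner                            # D skip parameter
    out_proj = d_inner * d_model                # no bias
    return in_proj + conv1d + x_proj + dt_proj + a_log + d_skip + out_proj


def pvm_params(channels: int, groups: int) -> int:
    """One Mamba block on channels // groups inputs, shared by every branch (assumption)."""
    return mamba_block_params(channels // groups)


C = 64  # e.g. the widest entry of c_list = [8, 16, 24, 32, 48, 64]
for groups in (1, 2, 4):
    print(f"groups={groups}: {pvm_params(C, groups)} parameters")
```

For `C = 64` this gives 32,640, 9,920 and 3,360 parameters for one, two and four branches: splitting the channels shrinks every `d_model × d_inner` term quadratically, while the processed channel count stays `C`.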
44 |
45 | **0. Main Environments.**
46 | You can install the environment by following [VM-UNet](https://github.com/JCruan519/VM-UNet), or by following the steps below (python=3.8):
47 | ```
48 | conda create -n vmunet python=3.8
49 | conda activate vmunet
50 | pip install torch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117
51 | pip install packaging
52 | pip install timm==0.4.12
53 | pip install pytest chardet yacs termcolor
54 | pip install submitit tensorboardX
55 | pip install triton==2.0.0
56 | pip install causal_conv1d==1.0.0 # causal_conv1d-1.0.0+cu118torch1.13cxx11abiFALSE-cp38-cp38-linux_x86_64.whl
57 | pip install mamba_ssm==1.0.1 # mamba_ssm-1.0.1+cu118torch1.13cxx11abiFALSE-cp38-cp38-linux_x86_64.whl
58 | pip install scikit-learn matplotlib thop h5py SimpleITK scikit-image medpy yacs
59 | ```
60 |
61 | **1. Datasets.**
62 | Data preprocessing environment installation (python=3.7):
63 | ```
64 | conda create -n tool python=3.7
65 | conda activate tool
66 | pip install h5py
67 | conda install scipy==1.2.1 # scipy1.2.1 only supports python 3.7 and below.
68 | pip install pillow
69 | ```
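The separate Python 3.7 environment exists because the `dataprepare` scripts call `scipy.misc.imread` and `scipy.misc.imresize`, which newer SciPy releases no longer provide. If you would rather not downgrade, the read-and-resize step can be approximated with Pillow. The helper below, `read_resized`, is a hypothetical replacement, not part of the repository, and its pixel values may differ slightly from `scipy.misc.imresize`.

```python
import numpy as np
from PIL import Image


def read_resized(path: str, height: int = 256, width: int = 256, rgb: bool = True) -> np.ndarray:
    """Read an image, resize it bilinearly and return a float64 array.

    Approximates np.double(sc.imresize(sc.imread(path), ...)) from the
    dataprepare scripts; results are close but not guaranteed identical.
    """
    with Image.open(path) as img:
        img = img.convert("RGB" if rgb else "L")  # 3-channel image or 1-channel mask
        img = img.resize((width, height), resample=Image.BILINEAR)  # PIL expects (W, H)
        return np.asarray(img, dtype=np.float64)
```

In a preparation script this would replace the two SciPy calls per sample, for example `Data_train_2017[idx] = read_resized(Tr_list[idx])` for the image and `Label_train_2017[idx] = read_resized(add, rgb=False)` for the mask.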
70 |
71 | *A. ISIC2017*
72 | 1. Download the ISIC 2017 training dataset from [this](https://challenge.isic-archive.com/data) link and extract both the training images and ground truth folders into `/data/dataset_isic17/`. `Prepare_ISIC2017.py` reads from `Dataset_add = './ISIC2017/'`, so adjust that path to wherever you extracted the data.
73 | 2. Run `Prepare_ISIC2017.py` to prepare the data and split it into training, validation and test sets.
74 |
75 | *B. ISIC2018*
76 | 1. Download the ISIC 2018 training dataset from [this](https://challenge.isic-archive.com/data) link and extract both the training images and ground truth folders into `/data/dataset_isic18/`. `Prepare_ISIC2018.py` reads from `Dataset_add = './ISIC2018/'`, so adjust that path to wherever you extracted the data.
77 | 2. Run `Prepare_ISIC2018.py` to prepare the data and split it into training, validation and test sets.
78 |
79 | *C. PH2*
80 | 1. Download the PH2 dataset from [Dropbox](https://www.dropbox.com/s/k88qukc20ljnbuo/PH2Dataset.rar) or [Google Drive](https://drive.google.com/file/d/1AEMJKAiORlrwdDi37dRqbqXi6zLmnU3Q/view?usp=sharing) and extract both the image and ground truth folders into `/data/PH2/`. `Prepare_PH2.py` reads images from `./PH2/images/` and masks from `./PH2/masks/`, so adjust `Dataset_add` if your layout differs.
81 | 2. Run `Prepare_PH2.py` to preprocess the data and build a test set for external validation.
82 |
83 | *D. Prepare your own dataset*
84 | 1. Arrange your files as shown below. Images must be 24-bit PNGs; masks must be 8-bit PNGs with pixel value 0 for the background and 255 for the target.
85 |     - './your_dataset/'
86 |       - images
87 |         - 0000.png
88 |         - 0001.png
89 |       - masks
90 |         - 0000.png
91 |         - 0001.png
92 |       - Prepare_your_dataset.py
93 | 2. In `Prepare_your_dataset.py`, set `train_number`, `val_number` and `test_number` to the number of images you want in the training, validation and test sets.
94 | 3. Run `Prepare_your_dataset.py`.
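`Prepare_your_dataset.py` reads every `images/*.png`, derives the mask name from the four characters before `.png` (`b[len(b)-8: len(b)-4]`), opens `masks/<name>.png`, and allocates exactly `train_number + val_number + test_number` rows. A quick pre-flight check for those three assumptions might look like the sketch below; `check_layout` is a hypothetical helper, not part of the repository, and it sticks to `typing.List` so it also runs in the Python 3.7 preprocessing environment.

```python
from pathlib import Path
from typing import List, Optional


def check_layout(root: str = ".", expected_total: Optional[int] = None) -> List[str]:
    """Return a list of problems found in root/images and root/masks."""
    problems = []
    images = sorted(Path(root, "images").glob("*.png"))
    masks_dir = Path(root, "masks")
    if not images:
        problems.append("no PNG files found in images/")
    for img in images:
        if len(img.stem) != 4:  # the script slices exactly four characters
            problems.append(f"{img.name}: stem must be 4 characters, e.g. 0000.png")
        if not (masks_dir / img.name).is_file():
            problems.append(f"{img.name}: missing masks/{img.name}")
    # More images than allocated rows would raise an IndexError in the script.
    if expected_total is not None and len(images) > expected_total:
        problems.append(f"{len(images)} images found but only {expected_total} "
                        "slots allocated (train + val + test)")
    return problems


for problem in check_layout(".", expected_total=1000 + 200 + 400):
    print(problem)
```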
95 |
96 | **2. Train the UltraLight VM-UNet.**
97 | Run the following command to start training. Alternatively, you can download the weights file as described in this [issue](https://github.com/wurenkai/UltraLight-VM-UNet/issues/38) before training.
98 | ```
99 | python train.py
100 | ```
101 | - After training, the outputs are saved in './results/'.
102 |
103 | **3. Test the UltraLight VM-UNet.**
104 | First, in test.py, set 'resume_model' to the path of your checkpoint.
105 | ```
106 | python test.py
107 | ```
108 | - After testing, the outputs are saved in './results/'.
109 |
110 | **4. Additional information.**
111 | - The PVM Layer can easily be embedded into any model to reduce its overall number of parameters. Please refer to [issue 7](https://github.com/wurenkai/UltraLight-VM-UNet/issues/7) for how the model parameters and GFLOPs are calculated. Because of the specific nature of the SSM, an exact GFLOPs figure also requires adding the SSM operations on top of that calculation; see [here](https://github.com/state-spaces/mamba/issues/110#issuecomment-1919470069) for details. However, because UltraLight VM-UNet has so few channels, adding all the SSM operations has almost no effect on the GFLOPs obtained as described above (to 3 significant digits).
112 | ## Citation
113 | If you find this repository helpful, please consider citing:
114 | ```
115 | @article{wu2024ultralight,
116 | title={UltraLight VM-UNet: Parallel Vision Mamba Significantly Reduces Parameters for Skin Lesion Segmentation},
117 | author={Wu, Renkai and Liu, Yinghao and Liang, Pengchen and Chang, Qing},
118 | journal={arXiv preprint arXiv:2403.20035},
119 | year={2024}
120 | }
121 | ```
122 |
123 | ## Acknowledgement
124 | Thanks to [Vim](https://github.com/hustvl/Vim), [VMamba](https://github.com/MzeroMiko/VMamba), [VM-UNet](https://github.com/JCruan519/VM-UNet) and [LightM-UNet](https://github.com/MrBlankness/LightM-UNet) for their outstanding work.
125 |
--------------------------------------------------------------------------------
/configs/config_setting.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from utils import *
3 |
4 | from datetime import datetime
5 |
6 | class setting_config:
7 | """
8 | the config of training setting.
9 | """
10 | network = 'UltraLight_VM_UNet'
11 | model_config = {
12 | 'num_classes': 1,
13 | 'input_channels': 3,
14 | 'c_list': [8,16,24,32,48,64],
15 | 'split_att': 'fc',
16 | 'bridge': True,
17 | }
18 |
19 | test_weights = ''
20 |
21 | datasets = 'ISIC2017'
22 | if datasets == 'ISIC2017':
23 | data_path = ''
24 | elif datasets == 'ISIC2018':
25 | data_path = ''
26 | elif datasets == 'PH2':
27 | data_path = ''
28 | else:
29 | raise Exception('Unsupported dataset!')
30 |
31 | criterion = BceDiceLoss()
32 |
33 | num_classes = 1
34 | input_size_h = 256
35 | input_size_w = 256
36 | input_channels = 3
37 | distributed = False
38 | local_rank = -1
39 | num_workers = 0
40 | seed = 42
41 | world_size = None
42 | rank = None
43 | amp = False
44 | batch_size = 8
45 | epochs = 250
46 |
47 | work_dir = 'results/' + network + '_' + datasets + '_' + datetime.now().strftime('%A_%d_%B_%Y_%Hh_%Mm_%Ss') + '/'
48 |
49 | print_interval = 20
50 | val_interval = 30
51 | save_interval = 100
52 | threshold = 0.5
53 |
54 |
55 | opt = 'AdamW'
56 | assert opt in ['Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'ASGD', 'RMSprop', 'Rprop', 'SGD'], 'Unsupported optimizer!'
57 | if opt == 'Adadelta':
58 | lr = 0.01 # default: 1.0 – coefficient that scales delta before it is applied to the parameters
59 | rho = 0.9 # default: 0.9 – coefficient used for computing a running average of squared gradients
60 | eps = 1e-6 # default: 1e-6 – term added to the denominator to improve numerical stability
61 | weight_decay = 0.05 # default: 0 – weight decay (L2 penalty)
62 | elif opt == 'Adagrad':
63 | lr = 0.01 # default: 0.01 – learning rate
64 | lr_decay = 0 # default: 0 – learning rate decay
65 | eps = 1e-10 # default: 1e-10 – term added to the denominator to improve numerical stability
66 | weight_decay = 0.05 # default: 0 – weight decay (L2 penalty)
67 | elif opt == 'Adam':
68 | lr = 0.001 # default: 1e-3 – learning rate
69 | betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square
70 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability
71 | weight_decay = 0.0001 # default: 0 – weight decay (L2 penalty)
72 | amsgrad = False # default: False – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond
73 | elif opt == 'AdamW':
74 | lr = 0.001 # default: 1e-3 – learning rate
75 | betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square
76 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability
77 | weight_decay = 1e-2 # default: 1e-2 – weight decay coefficient
78 | amsgrad = False # default: False – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond
79 | elif opt == 'Adamax':
80 | lr = 2e-3 # default: 2e-3 – learning rate
81 | betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square
82 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability
83 | weight_decay = 0 # default: 0 – weight decay (L2 penalty)
84 | elif opt == 'ASGD':
85 | lr = 0.01 # default: 1e-2 – learning rate
86 | lambd = 1e-4 # default: 1e-4 – decay term
87 | alpha = 0.75 # default: 0.75 – power for eta update
88 | t0 = 1e6 # default: 1e6 – point at which to start averaging
89 | weight_decay = 0 # default: 0 – weight decay
90 | elif opt == 'RMSprop':
91 | lr = 1e-2 # default: 1e-2 – learning rate
92 | momentum = 0 # default: 0 – momentum factor
93 | alpha = 0.99 # default: 0.99 – smoothing constant
94 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability
95 | centered = False # default: False – if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance
96 | weight_decay = 0 # default: 0 – weight decay (L2 penalty)
97 | elif opt == 'Rprop':
98 | lr = 1e-2 # default: 1e-2 – learning rate
99 | etas = (0.5, 1.2) # default: (0.5, 1.2) – pair of (etaminus, etaplus), the multiplicative decrease and increase factors
100 | step_sizes = (1e-6, 50) # default: (1e-6, 50) – a pair of minimal and maximal allowed step sizes
101 | elif opt == 'SGD':
102 | lr = 0.01 # – learning rate
103 | momentum = 0.9 # default: 0 – momentum factor
104 | weight_decay = 0.05 # default: 0 – weight decay (L2 penalty)
105 | dampening = 0 # default: 0 – dampening for momentum
106 | nesterov = False # default: False – enables Nesterov momentum
107 |
108 | sch = 'CosineAnnealingLR'
109 | if sch == 'StepLR':
110 | step_size = epochs // 5 # – Period of learning rate decay.
111 | gamma = 0.5 # – Multiplicative factor of learning rate decay. Default: 0.1
112 | last_epoch = -1 # – The index of last epoch. Default: -1.
113 | elif sch == 'MultiStepLR':
114 | milestones = [60, 120, 150] # – List of epoch indices. Must be increasing.
115 | gamma = 0.1 # – Multiplicative factor of learning rate decay. Default: 0.1.
116 | last_epoch = -1 # – The index of last epoch. Default: -1.
117 | elif sch == 'ExponentialLR':
118 | gamma = 0.99 # – Multiplicative factor of learning rate decay.
119 | last_epoch = -1 # – The index of last epoch. Default: -1.
120 | elif sch == 'CosineAnnealingLR':
121 | T_max = 50 # – Maximum number of iterations, i.e. half of the cosine period.
122 | eta_min = 0.00001 # – Minimum learning rate. Default: 0.
123 | last_epoch = -1 # – The index of last epoch. Default: -1.
124 | elif sch == 'ReduceLROnPlateau':
125 | mode = 'min' # – One of min, max. In min mode, lr will be reduced when the quantity monitored has stopped decreasing; in max mode it will be reduced when the quantity monitored has stopped increasing. Default: ‘min’.
126 | factor = 0.1 # – Factor by which the learning rate will be reduced. new_lr = lr * factor. Default: 0.1.
127 | patience = 10 # – Number of epochs with no improvement after which learning rate will be reduced. For example, if patience = 2, then we will ignore the first 2 epochs with no improvement, and will only decrease the LR after the 3rd epoch if the loss still hasn’t improved then. Default: 10.
128 | threshold = 0.0001 # – Threshold for measuring the new optimum, to only focus on significant changes. Default: 1e-4.
129 | threshold_mode = 'rel' # – One of rel, abs. In rel mode, dynamic_threshold = best * ( 1 + threshold ) in ‘max’ mode or best * ( 1 - threshold ) in min mode. In abs mode, dynamic_threshold = best + threshold in max mode or best - threshold in min mode. Default: ‘rel’.
130 | cooldown = 0 # – Number of epochs to wait before resuming normal operation after lr has been reduced. Default: 0.
131 | min_lr = 0 # – A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively. Default: 0.
132 | eps = 1e-08 # – Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is ignored. Default: 1e-8.
133 | elif sch == 'CosineAnnealingWarmRestarts':
134 | T_0 = 50 # – Number of iterations for the first restart.
135 | T_mult = 2 # – A factor increases T_{i} after a restart. Default: 1.
136 | eta_min = 1e-6 # – Minimum learning rate. Default: 0.
137 | last_epoch = -1 # – The index of last epoch. Default: -1.
138 | elif sch == 'WP_MultiStepLR':
139 | warm_up_epochs = 10
140 | gamma = 0.1
141 | milestones = [125, 225]
142 | elif sch == 'WP_CosineLR':
143 | warm_up_epochs = 20
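With `sch = 'CosineAnnealingLR'`, the schedule follows the closed form $\eta_t = \eta_{min} + \tfrac{1}{2}(\eta_{max} - \eta_{min})(1 + \cos(\pi t / T_{max}))$. The sketch below evaluates it with the values configured above (AdamW `lr = 0.001`, `T_max = 50`, `eta_min = 0.00001`), assuming the scheduler is stepped once per epoch; if it is stepped per iteration, read `epoch` as the step count. `cosine_annealing_lr` is a hypothetical helper for inspection only.

```python
import math


def cosine_annealing_lr(epoch: int, base_lr: float = 1e-3, eta_min: float = 1e-5,
                        t_max: int = 50) -> float:
    """Closed-form learning rate of CosineAnnealingLR after `epoch` scheduler steps."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


for epoch in (0, 25, 50, 100, 150, 200, 250):
    print(epoch, f"{cosine_annealing_lr(epoch):.6f}")
```

Because the cosine keeps going past `T_max`, the rate reaches `eta_min` at epoch 50, climbs back to the base rate at epoch 100, and so on, so with `epochs = 250` training passes through several decay and rise phases rather than a single decay.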
--------------------------------------------------------------------------------
/dataprepare/Prepare_ISIC2017.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Code created on Sat Jun 8 18:15:43 2019
4 | @author: Reza Azad
5 | """
6 |
7 | """
8 | Reminder added on December 6, 2023.
9 | Reminder Created on Wed Dec 6 2023
10 | @author: Renkai Wu
11 | 1. Note that scipy must be downgraded (scipy==1.2.1), because the scipy.misc.imread and scipy.misc.imresize calls below are not available in newer SciPy releases. Otherwise, the following code needs to be modified.
12 | 2. The script prints the list of files to be processed. If no file names appear, the output .npy files will be incorrect.
13 | """
14 |
15 | import h5py
16 | import numpy as np
17 | import scipy.io as sio
18 | import scipy.misc as sc
19 | import glob
20 |
21 | # Parameters
22 | height = 256
23 | width = 256
24 | channels = 3
25 |
26 | ############################################################# Prepare ISIC 2017 data set #################################################
27 | Dataset_add = './ISIC2017/'
28 | Tr_add = 'ISIC2017_Task1-2_Training_Input'
29 |
30 | Tr_list = glob.glob(Dataset_add+ Tr_add+'/*.jpg')
31 | # It contains 2000 training samples
32 | Data_train_2017 = np.zeros([2000, height, width, channels])
33 | Label_train_2017 = np.zeros([2000, height, width])
34 |
35 | print('Reading ISIC 2017')
36 | print(Tr_list)
37 | for idx in range(len(Tr_list)):
38 | print(idx+1)
39 | img = sc.imread(Tr_list[idx])
40 | img = np.double(sc.imresize(img, [height, width, channels], interp='bilinear', mode = 'RGB'))
41 | Data_train_2017[idx, :,:,:] = img
42 |
43 | b = Tr_list[idx]
44 | a = b[0:len(Dataset_add)]
45 | b = b[len(b)-16: len(b)-4]
46 | add = (a+ 'ISIC2017_Task1_Training_GroundTruth/' + b +'_segmentation.png')
47 | img2 = sc.imread(add)
48 | img2 = np.double(sc.imresize(img2, [height, width], interp='bilinear'))
49 | Label_train_2017[idx, :,:] = img2
50 |
51 | print('Reading ISIC 2017 finished')
52 |
53 | ################################################################ Make the train and test sets ########################################
54 | # We consider 1250 samples for training, 150 samples for validation and 600 samples for testing
55 |
56 | Train_img = Data_train_2017[0:1250,:,:,:]
57 | Validation_img = Data_train_2017[1250:1250+150,:,:,:]
58 | Test_img = Data_train_2017[1250+150:2000,:,:,:]
59 |
60 | Train_mask = Label_train_2017[0:1250,:,:]
61 | Validation_mask = Label_train_2017[1250:1250+150,:,:]
62 | Test_mask = Label_train_2017[1250+150:2000,:,:]
63 |
64 |
65 | np.save('data_train', Train_img)
66 | np.save('data_test' , Test_img)
67 | np.save('data_val' , Validation_img)
68 |
69 | np.save('mask_train', Train_mask)
70 | np.save('mask_test' , Test_mask)
71 | np.save('mask_val' , Validation_mask)
72 |
73 |
74 |
--------------------------------------------------------------------------------
/dataprepare/Prepare_ISIC2018.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Code created on Sat Jun 8 18:15:43 2019
4 | @author: Reza Azad
5 | """
6 |
7 | """
8 | Reminder added on December 6, 2023.
9 | Reminder Created on Wed Dec 6 2023
10 | @author: Renkai Wu
11 | 1. Note that scipy must be downgraded (scipy==1.2.1), because the scipy.misc.imread and scipy.misc.imresize calls below are not available in newer SciPy releases. Otherwise, the following code needs to be modified.
12 | 2. The script prints the list of files to be processed. If no file names appear, the output .npy files will be incorrect.
13 | """
14 |
15 | import h5py
16 | import numpy as np
17 | import scipy.io as sio
18 | import scipy.misc as sc
19 | import glob
20 |
21 | # Parameters
22 | height = 256
23 | width = 256
24 | channels = 3
25 |
26 | ############################################################# Prepare ISIC 2018 data set #################################################
27 | Dataset_add = './ISIC2018/'
28 | Tr_add = 'ISIC2018_Task1-2_Training_Input'
29 |
30 | Tr_list = glob.glob(Dataset_add+ Tr_add+'/*.jpg')
31 | # It contains 2594 training samples
32 | Data_train_2017 = np.zeros([2594, height, width, channels])
33 | Label_train_2017 = np.zeros([2594, height, width])
34 |
35 | print('Reading ISIC 2018')
36 | print(Tr_list)
37 | for idx in range(len(Tr_list)):
38 | print(idx+1)
39 | img = sc.imread(Tr_list[idx])
40 | img = np.double(sc.imresize(img, [height, width, channels], interp='bilinear', mode = 'RGB'))
41 | Data_train_2017[idx, :,:,:] = img
42 |
43 | b = Tr_list[idx]
44 | a = b[0:len(Dataset_add)]
45 | b = b[len(b)-16: len(b)-4]
46 | add = (a+ 'ISIC2018_Task1_Training_GroundTruth/' + b +'_segmentation.png')
47 | img2 = sc.imread(add)
48 | img2 = np.double(sc.imresize(img2, [height, width], interp='bilinear'))
49 | Label_train_2017[idx, :,:] = img2
50 |
51 | print('Reading ISIC 2018 finished')
52 |
53 | ################################################################ Make the train and test sets ########################################
54 | # We consider 1815 samples for training, 259 samples for validation and 520 samples for testing
55 |
56 | Train_img = Data_train_2017[0:1815,:,:,:]
57 | Validation_img = Data_train_2017[1815:1815+259,:,:,:]
58 | Test_img = Data_train_2017[1815+259:2594,:,:,:]
59 |
60 | Train_mask = Label_train_2017[0:1815,:,:]
61 | Validation_mask = Label_train_2017[1815:1815+259,:,:]
62 | Test_mask = Label_train_2017[1815+259:2594,:,:]
63 |
64 |
65 | np.save('data_train', Train_img)
66 | np.save('data_test' , Test_img)
67 | np.save('data_val' , Validation_img)
68 |
69 | np.save('mask_train', Train_mask)
70 | np.save('mask_test' , Test_mask)
71 | np.save('mask_val' , Validation_mask)
72 |
73 |
74 |
--------------------------------------------------------------------------------
/dataprepare/Prepare_PH2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Code created on Sat Jun 8 18:15:43 2019
4 | @author: Reza Azad
5 | """
6 |
7 | """
8 | Reminder added on December 6, 2023.
9 | Reminder Created on Wed Dec 6 2023
10 | @author: Renkai Wu
11 | 1. Note that scipy must be downgraded (scipy==1.2.1), because the scipy.misc.imread and scipy.misc.imresize calls below are not available in newer SciPy releases. Otherwise, the following code needs to be modified.
12 | 2. The script prints the list of files to be processed. If no file names appear, the output .npy files will be incorrect.
13 | 3. The official PH2 data is provided in '.bmp' format; batch-rename the files to the '.jpg' (images) or '.png' (masks) suffix before processing.
14 | """
15 |
16 | import h5py
17 | import numpy as np
18 | import scipy.io as sio
19 | import scipy.misc as sc
20 | import glob
21 |
22 | # Parameters
23 | height = 256
24 | width = 256
25 | channels = 3
26 |
27 | ############################################################# Prepare PH2 data set #################################################
28 | Dataset_add = './PH2/'
29 | Tr_add = 'images'
30 |
31 | Tr_list = glob.glob(Dataset_add+ Tr_add+'/*.jpg')
32 | # It contains 200 images
33 | Data_train_2017 = np.zeros([200, height, width, channels])
34 | Label_train_2017 = np.zeros([200, height, width])
35 |
36 | print('Reading PH2')
37 | print(Tr_list)
38 | for idx in range(len(Tr_list)):
39 | print(idx+1)
40 | img = sc.imread(Tr_list[idx])
41 | img = np.double(sc.imresize(img, [height, width, channels], interp='bilinear', mode = 'RGB'))
42 | Data_train_2017[idx, :,:,:] = img
43 |
44 | b = Tr_list[idx]
45 | a = b[0:len(Dataset_add)]
46 | b = b[len(b)-16: len(b)-4]
47 | add = (a+ 'masks/' + b +'.png')
48 | img2 = sc.imread(add)
49 | img2 = np.double(sc.imresize(img2, [height, width], interp='bilinear'))
50 | Label_train_2017[idx, :,:] = img2
51 |
52 | print('Reading PH2 finished')
53 |
54 | ################################################################ Make test sets ########################################
55 | # We consider 200 samples for testing
56 |
57 | #Train_img = Data_train_2017[0:1815,:,:,:]
58 | #Validation_img = Data_train_2017[1815:1815+259,:,:,:]
59 | Test_img = Data_train_2017[0:200,:,:,:]
60 |
61 | #Train_mask = Label_train_2017[0:1815,:,:]
62 | #Validation_mask = Label_train_2017[1815:1815+259,:,:]
63 | Test_mask = Label_train_2017[0:200,:,:]
64 |
65 |
66 | #np.save('data_train', Train_img)
67 | np.save('data_test' , Test_img)
68 | #np.save('data_val' , Validation_img)
69 |
70 | #np.save('mask_train', Train_mask)
71 | np.save('mask_test' , Test_mask)
72 | #np.save('mask_val' , Validation_mask)
73 |
74 |
75 |
--------------------------------------------------------------------------------
/dataprepare/Prepare_your_dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | ##scipy==1.2.1
3 |
4 | import h5py
5 | import numpy as np
6 | import scipy.io as sio
7 | import scipy.misc as sc
8 | import glob
9 |
10 | # Parameters
11 | height = 256 # Enter the image size of the model.
12 | width = 256 # Enter the image size of the model.
13 | channels = 3 # Number of image channels
14 |
15 | train_number = 1000 # Number of images assigned to the training set (taken in file-listing order).
16 | val_number = 200 # Number of images assigned to the validation set (taken in file-listing order).
17 | test_number = 400 # Number of images assigned to the test set (taken in file-listing order).
18 | all = int(train_number) + int(val_number) + int(test_number)
19 |
20 | ############################################################# Prepare your data set #################################################
21 | Tr_list = glob.glob("images"+'/*.png') # Images storage folder. The image type should be 24-bit png format.
22 | # Allocate arrays for all train + validation + test samples
23 | Data_train_2018 = np.zeros([all, height, width, channels])
24 | Label_train_2018 = np.zeros([all, height, width])
25 |
26 | print('Reading')
27 | print(len(Tr_list))
28 | for idx in range(len(Tr_list)):
29 | print(idx+1)
30 | img = sc.imread(Tr_list[idx])
31 | img = np.double(sc.imresize(img, [height, width, channels], interp='bilinear', mode = 'RGB'))
32 | Data_train_2018[idx, :,:,:] = img
33 |
34 | b = Tr_list[idx]
35 | b = b[len(b)-8: len(b)-4] # File stem; assumes four-character names such as 0000.png
36 | add = ("masks/" + b +'.png') # Masks storage folder. The Mask type should be a black and white image of an 8-bit png (0 pixels for the background and 255 pixels for the target).
37 | img2 = sc.imread(add)
38 | img2 = np.double(sc.imresize(img2, [height, width], interp='bilinear'))
39 | Label_train_2018[idx, :,:] = img2
40 |
41 | print('Reading your dataset finished')
42 |
43 | ################################################################ Make the training, validation and test sets ########################################
44 | Train_img = Data_train_2018[0:train_number,:,:,:]
45 | Validation_img = Data_train_2018[train_number:train_number+val_number,:,:,:]
46 | Test_img = Data_train_2018[train_number+val_number:all,:,:,:]
47 |
48 | Train_mask = Label_train_2018[0:train_number,:,:]
49 | Validation_mask = Label_train_2018[train_number:train_number+val_number,:,:]
50 | Test_mask = Label_train_2018[train_number+val_number:all,:,:]
51 |
52 |
53 | np.save('data_train', Train_img)
54 | np.save('data_test' , Test_img)
55 | np.save('data_val' , Validation_img)
56 |
57 | np.save('mask_train', Train_mask)
58 | np.save('mask_test' , Test_mask)
59 | np.save('mask_val' , Validation_mask)
60 |
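The script above splits the arrays contiguously in the order returned by `glob.glob`, which is not sorted and can vary between file systems. If a reproducible random split is preferred, one option is to shuffle the indices with a fixed seed before slicing. `split_indices` is a hypothetical helper, not part of the repository; the default seed of 42 simply mirrors `seed` in `config_setting.py`.

```python
import random
from typing import List, Tuple


def split_indices(n_total: int, train_number: int, val_number: int, test_number: int,
                  seed: int = 42) -> Tuple[List[int], List[int], List[int]]:
    """Shuffle range(n_total) with a fixed seed and cut it into train/val/test indices."""
    if train_number + val_number + test_number > n_total:
        raise ValueError("requested more samples than are available")
    order = list(range(n_total))
    random.Random(seed).shuffle(order)  # local RNG, leaves the global random state alone
    train = order[:train_number]
    val = order[train_number:train_number + val_number]
    test = order[train_number + val_number:train_number + val_number + test_number]
    return train, val, test


train_idx, val_idx, test_idx = split_indices(1600, 1000, 200, 400)
# NumPy fancy indexing then selects the subsets, e.g.
# Train_img = Data_train_2018[train_idx]
```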
--------------------------------------------------------------------------------
/engine.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import tqdm
3 | import torch
4 | from torch.cuda.amp import autocast as autocast
5 | from sklearn.metrics import confusion_matrix
6 | from utils import save_imgs
7 |
8 |
9 | def train_one_epoch(train_loader,
10 | model,
11 | criterion,
12 | optimizer,
13 | scheduler,
14 | epoch,
15 | logger,
16 | config,
17 | scaler=None):
18 | '''
19 | train model for one epoch
20 | '''
21 | # switch to train mode
22 | model.train()
23 |
24 | loss_list = []
25 |
26 | for iter, data in enumerate(train_loader):
27 | optimizer.zero_grad()
28 | images, targets = data
29 | images, targets = images.cuda(non_blocking=True).float(), targets.cuda(non_blocking=True).float()
30 | if config.amp:
31 | with autocast():
32 | out = model(images)
33 | loss = criterion(out, targets)
34 | scaler.scale(loss).backward()
35 | scaler.step(optimizer)
36 | scaler.update()
37 | else:
38 | out = model(images)
39 | loss = criterion(out, targets)
40 | loss.backward()
41 | optimizer.step()
42 |
43 | loss_list.append(loss.item())
44 |
45 | now_lr = optimizer.state_dict()['param_groups'][0]['lr']
46 | if iter % config.print_interval == 0:
47 | log_info = f'train: epoch {epoch}, iter:{iter}, loss: {np.mean(loss_list):.4f}, lr: {now_lr}'
48 | print(log_info)
49 | logger.info(log_info)
50 | scheduler.step()
51 |
52 |
53 | def val_one_epoch(test_loader,
54 | model,
55 | criterion,
56 | epoch,
57 | logger,
58 | config):
59 | # switch to evaluate mode
60 | model.eval()
61 | preds = []
62 | gts = []
63 | loss_list = []
64 | with torch.no_grad():
65 | for data in tqdm(test_loader):
66 | img, msk = data
67 | img, msk = img.cuda(non_blocking=True).float(), msk.cuda(non_blocking=True).float()
68 | out = model(img)
69 | loss = criterion(out, msk)
70 | loss_list.append(loss.item())
71 | gts.append(msk.squeeze(1).cpu().detach().numpy())
72 | if type(out) is tuple:
73 | out = out[0]
74 | out = out.squeeze(1).cpu().detach().numpy()
75 | preds.append(out)
76 |
77 | if epoch % config.val_interval == 0:
78 | preds = np.array(preds).reshape(-1)
79 | gts = np.array(gts).reshape(-1)
80 |
81 | y_pre = np.where(preds>=config.threshold, 1, 0)
82 | y_true = np.where(gts>=0.5, 1, 0)
83 |
84 | confusion = confusion_matrix(y_true, y_pre, labels=[0, 1]) # always 2x2, even if only one class occurs
85 | TN, FP, FN, TP = confusion[0,0], confusion[0,1], confusion[1,0], confusion[1,1]
86 |
87 | accuracy = float(TN + TP) / float(np.sum(confusion)) if float(np.sum(confusion)) != 0 else 0
88 | sensitivity = float(TP) / float(TP + FN) if float(TP + FN) != 0 else 0
89 | specificity = float(TN) / float(TN + FP) if float(TN + FP) != 0 else 0
90 | f1_or_dsc = float(2 * TP) / float(2 * TP + FP + FN) if float(2 * TP + FP + FN) != 0 else 0
91 | miou = float(TP) / float(TP + FP + FN) if float(TP + FP + FN) != 0 else 0
92 |
93 | log_info = (f'val epoch: {epoch}, loss: {np.mean(loss_list):.4f}, miou: {miou}, f1_or_dsc: {f1_or_dsc}, accuracy: {accuracy}, '
94 | f'specificity: {specificity}, sensitivity: {sensitivity}, confusion_matrix: {confusion}')
95 | print(log_info)
96 | logger.info(log_info)
97 |
98 | else:
99 | log_info = f'val epoch: {epoch}, loss: {np.mean(loss_list):.4f}'
100 | print(log_info)
101 | logger.info(log_info)
102 |
103 | return np.mean(loss_list)
104 |
105 |
106 | def test_one_epoch(test_loader,
107 | model,
108 | criterion,
109 | logger,
110 | config,
111 | test_data_name=None):
112 | # switch to evaluate mode
113 | model.eval()
114 | preds = []
115 | gts = []
116 | loss_list = []
117 | with torch.no_grad():
118 | for i, data in enumerate(tqdm(test_loader)):
119 | img, msk = data
120 | img, msk = img.cuda(non_blocking=True).float(), msk.cuda(non_blocking=True).float()
121 | out = model(img)
122 | loss = criterion(out, msk)
123 | loss_list.append(loss.item())
124 | msk = msk.squeeze(1).cpu().detach().numpy()
125 | gts.append(msk)
126 | if type(out) is tuple:
127 | out = out[0]
128 | out = out.squeeze(1).cpu().detach().numpy()
129 | preds.append(out)
130 | save_imgs(img, msk, out, i, config.work_dir + 'outputs/', config.datasets, config.threshold, test_data_name=test_data_name)
131 |
132 | preds = np.array(preds).reshape(-1)
133 | gts = np.array(gts).reshape(-1)
134 |
135 | y_pre = np.where(preds>=config.threshold, 1, 0)
136 | y_true = np.where(gts>=0.5, 1, 0)
137 |
138 | confusion = confusion_matrix(y_true, y_pre, labels=[0, 1]) # keep the matrix 2x2 even if one class is absent
139 | TN, FP, FN, TP = confusion[0,0], confusion[0,1], confusion[1,0], confusion[1,1]
140 |
141 | accuracy = float(TN + TP) / float(np.sum(confusion)) if float(np.sum(confusion)) != 0 else 0
142 | sensitivity = float(TP) / float(TP + FN) if float(TP + FN) != 0 else 0
143 | specificity = float(TN) / float(TN + FP) if float(TN + FP) != 0 else 0
144 | f1_or_dsc = float(2 * TP) / float(2 * TP + FP + FN) if float(2 * TP + FP + FN) != 0 else 0
145 | miou = float(TP) / float(TP + FP + FN) if float(TP + FP + FN) != 0 else 0
146 |
147 | if test_data_name is not None:
148 | log_info = f'test_datasets_name: {test_data_name}'
149 | print(log_info)
150 | logger.info(log_info)
151 | log_info = (f'test of best model, loss: {np.mean(loss_list):.4f}, miou: {miou}, f1_or_dsc: {f1_or_dsc}, accuracy: {accuracy}, '
152 | f'specificity: {specificity}, sensitivity: {sensitivity}, confusion_matrix: {confusion}')
153 | print(log_info)
154 | logger.info(log_info)
155 |
156 | return np.mean(loss_list)
157 |
--------------------------------------------------------------------------------
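The metric block shared by `val_one_epoch` and `test_one_epoch` reduces to four pixel counts. The sketch below restates it in plain Python (no NumPy or scikit-learn) so the formulas are explicit; the helper name `binary_seg_metrics` is hypothetical. Note that the value logged as `miou` is the foreground IoU (Jaccard index), TP/(TP+FP+FN), rather than an average over classes, and that DSC and IoU are tied by DSC = 2*IoU/(1+IoU).

```python
def binary_seg_metrics(preds, gts, threshold=0.5):
    """Pixel-level metrics mirroring val_one_epoch/test_one_epoch (hypothetical helper).

    preds: flat iterable of predicted probabilities (model output after sigmoid)
    gts:   flat iterable of ground-truth mask values in [0, 1]
    """
    tn = fp = fn = tp = 0
    for p, g in zip(preds, gts):
        y_pre = 1 if p >= threshold else 0  # same rule as np.where(preds >= config.threshold, 1, 0)
        y_true = 1 if g >= 0.5 else 0       # same rule as np.where(gts >= 0.5, 1, 0)
        if y_true == 1 and y_pre == 1:
            tp += 1
        elif y_true == 1:
            fn += 1
        elif y_pre == 1:
            fp += 1
        else:
            tn += 1

    def safe_div(num, den):
        # The engine reports 0 whenever a denominator is 0.
        return num / den if den != 0 else 0.0

    return {
        "accuracy": safe_div(tp + tn, tp + tn + fp + fn),
        "sensitivity": safe_div(tp, tp + fn),
        "specificity": safe_div(tn, tn + fp),
        "f1_or_dsc": safe_div(2 * tp, 2 * tp + fp + fn),
        "miou": safe_div(tp, tp + fp + fn),  # foreground IoU (Jaccard index)
    }


if __name__ == "__main__":
    print(binary_seg_metrics([0.9, 0.8, 0.2, 0.6], [1, 1, 0, 0]))
```

Because the four counts are always tallied here, this sketch never runs into the 1x1 matrix that scikit-learn's `confusion_matrix` returns when only one class appears and no `labels` argument is given.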
/loader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | import torch
3 | import numpy as np
4 | import random
5 | import os
6 | from PIL import Image
7 | from einops.layers.torch import Rearrange
8 | from scipy.ndimage import binary_dilation
9 | from torch.utils.data import Dataset
10 | from torchvision import transforms
11 | from scipy import ndimage
12 | from utils import *
13 |
14 |
15 | # ===== normalize over the dataset
16 | def dataset_normalized(imgs):
17 | imgs_normalized = np.empty(imgs.shape)
18 | imgs_std = np.std(imgs)
19 | imgs_mean = np.mean(imgs)
20 | imgs_normalized = (imgs-imgs_mean)/imgs_std
21 | for i in range(imgs.shape[0]):
22 | imgs_normalized[i] = ((imgs_normalized[i] - np.min(imgs_normalized[i])) / (np.max(imgs_normalized[i])-np.min(imgs_normalized[i])))*255
23 | return imgs_normalized
24 |
25 |
26 | ## Temporary
27 | class isic_loader(Dataset):
28 | """ dataset class for ISIC-style skin lesion datasets stored as .npy arrays
29 | """
30 | def __init__(self, path_Data, train = True, Test = False):
31 | super(isic_loader, self).__init__()
32 | self.train = train
33 | if train:
34 | self.data = np.load(path_Data+'data_train.npy')
35 | self.mask = np.load(path_Data+'mask_train.npy')
36 | else:
37 | if Test:
38 | self.data = np.load(path_Data+'data_test.npy')
39 | self.mask = np.load(path_Data+'mask_test.npy')
40 | else:
41 | self.data = np.load(path_Data+'data_val.npy')
42 | self.mask = np.load(path_Data+'mask_val.npy')
43 |
44 | self.data = dataset_normalized(self.data)
45 | self.mask = np.expand_dims(self.mask, axis=3)
46 | self.mask = self.mask/255.
47 |
48 | def __getitem__(self, indx):
49 | img = self.data[indx]
50 | seg = self.mask[indx]
51 | if self.train:
52 | if random.random() > 0.5:
53 | img, seg = self.random_rot_flip(img, seg)
54 | if random.random() > 0.5:
55 | img, seg = self.random_rotate(img, seg)
56 |
57 | seg = torch.tensor(seg.copy())
58 | img = torch.tensor(img.copy())
59 | img = img.permute( 2, 0, 1)
60 | seg = seg.permute( 2, 0, 1)
61 |
62 | return img, seg
63 |
64 | def random_rot_flip(self,image, label):
65 | k = np.random.randint(0, 4)
66 | image = np.rot90(image, k)
67 | label = np.rot90(label, k)
68 | axis = np.random.randint(0, 2)
69 | image = np.flip(image, axis=axis).copy()
70 | label = np.flip(label, axis=axis).copy()
71 | return image, label
72 |
73 | def random_rotate(self,image, label):
74 | angle = np.random.randint(20, 80)
75 | image = ndimage.rotate(image, angle, order=0, reshape=False)
76 | label = ndimage.rotate(label, angle, order=0, reshape=False)
77 | return image, label
78 |
79 |
80 |
81 | def __len__(self):
82 | return len(self.data)
83 |
--------------------------------------------------------------------------------
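`dataset_normalized` first z-scores all pixels with the dataset-wide mean and standard deviation, then min-max rescales each image to [0, 255]. Min-max rescaling is unchanged by any affine map with a positive scale, so as long as the dataset standard deviation is non-zero the z-score step cancels out: each image ends up as a per-image min-max stretch of its raw pixels. The pure-Python sketch below checks that reasoning; the function names are hypothetical and images are flat pixel lists rather than NumPy arrays.

```python
import statistics


def dataset_normalized_flat(imgs):
    """Pure-Python analogue of loader.dataset_normalized for flat pixel lists."""
    pixels = [v for img in imgs for v in img]
    mean = statistics.fmean(pixels)
    std = statistics.pstdev(pixels)  # np.std uses the population std (ddof=0) by default
    out = []
    for img in imgs:
        z = [(v - mean) / std for v in img]  # dataset-wide z-score
        lo, hi = min(z), max(z)
        out.append([(v - lo) / (hi - lo) * 255 for v in z])  # per-image stretch to [0, 255]
    return out


def per_image_minmax(imgs):
    """Per-image min-max stretch applied directly to the raw pixels."""
    out = []
    for img in imgs:
        lo, hi = min(img), max(img)
        out.append([(v - lo) / (hi - lo) * 255 for v in img])
    return out


if __name__ == "__main__":
    imgs = [[0, 10, 20], [5, 5, 15]]
    print(dataset_normalized_flat(imgs))
    print(per_image_minmax(imgs))
```

A constant image (max equal to min) gives a zero denominator in both functions; in the NumPy original the same case yields NaN with a RuntimeWarning instead of raising.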
/models/UltraLight_VM_UNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | from timm.models.layers import trunc_normal_
6 | import math
7 | from mamba_ssm import Mamba
8 |
9 |
10 | class PVMLayer(nn.Module):
11 | def __init__(self, input_dim, output_dim, d_state = 16, d_conv = 4, expand = 2):
12 | super().__init__()
13 | self.input_dim = input_dim
14 | self.output_dim = output_dim
15 | self.norm = nn.LayerNorm(input_dim)
16 | self.mamba = Mamba(
17 | d_model=input_dim//4, # Model dimension d_model
18 | d_state=d_state, # SSM state expansion factor
19 | d_conv=d_conv, # Local convolution width
20 | expand=expand, # Block expansion factor
21 | )
22 | self.proj = nn.Linear(input_dim, output_dim)
23 | self.skip_scale= nn.Parameter(torch.ones(1))
24 |
25 | def forward(self, x):
26 | if x.dtype == torch.float16:
27 | x = x.type(torch.float32)
28 | B, C = x.shape[:2]
29 | assert C == self.input_dim
30 | n_tokens = x.shape[2:].numel()
31 | img_dims = x.shape[2:]
32 | x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2)
33 | x_norm = self.norm(x_flat)
34 |
35 | x1, x2, x3, x4 = torch.chunk(x_norm, 4, dim=2)
36 | x_mamba1 = self.mamba(x1) + self.skip_scale * x1
37 | x_mamba2 = self.mamba(x2) + self.skip_scale * x2
38 | x_mamba3 = self.mamba(x3) + self.skip_scale * x3
39 | x_mamba4 = self.mamba(x4) + self.skip_scale * x4
40 | x_mamba = torch.cat([x_mamba1, x_mamba2,x_mamba3,x_mamba4], dim=2)
41 |
42 | x_mamba = self.norm(x_mamba)
43 | x_mamba = self.proj(x_mamba)
44 | out = x_mamba.transpose(-1, -2).reshape(B, self.output_dim, *img_dims)
45 | return out
46 |
47 |
48 | class Channel_Att_Bridge(nn.Module):
49 | def __init__(self, c_list, split_att='fc'):
50 | super().__init__()
51 | c_list_sum = sum(c_list) - c_list[-1]
52 | self.split_att = split_att
53 | self.avgpool = nn.AdaptiveAvgPool2d(1)
54 | self.get_all_att = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False)
55 | self.att1 = nn.Linear(c_list_sum, c_list[0]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[0], 1)
56 | self.att2 = nn.Linear(c_list_sum, c_list[1]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[1], 1)
57 | self.att3 = nn.Linear(c_list_sum, c_list[2]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[2], 1)
58 | self.att4 = nn.Linear(c_list_sum, c_list[3]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[3], 1)
59 | self.att5 = nn.Linear(c_list_sum, c_list[4]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[4], 1)
60 | self.sigmoid = nn.Sigmoid()
61 |
62 | def forward(self, t1, t2, t3, t4, t5):
63 | att = torch.cat((self.avgpool(t1),
64 | self.avgpool(t2),
65 | self.avgpool(t3),
66 | self.avgpool(t4),
67 | self.avgpool(t5)), dim=1)
68 | att = self.get_all_att(att.squeeze(-1).transpose(-1, -2))
69 | if self.split_att != 'fc':
70 | att = att.transpose(-1, -2)
71 | att1 = self.sigmoid(self.att1(att))
72 | att2 = self.sigmoid(self.att2(att))
73 | att3 = self.sigmoid(self.att3(att))
74 | att4 = self.sigmoid(self.att4(att))
75 | att5 = self.sigmoid(self.att5(att))
76 | if self.split_att == 'fc':
77 | att1 = att1.transpose(-1, -2).unsqueeze(-1).expand_as(t1)
78 | att2 = att2.transpose(-1, -2).unsqueeze(-1).expand_as(t2)
79 | att3 = att3.transpose(-1, -2).unsqueeze(-1).expand_as(t3)
80 | att4 = att4.transpose(-1, -2).unsqueeze(-1).expand_as(t4)
81 | att5 = att5.transpose(-1, -2).unsqueeze(-1).expand_as(t5)
82 | else:
83 | att1 = att1.unsqueeze(-1).expand_as(t1)
84 | att2 = att2.unsqueeze(-1).expand_as(t2)
85 | att3 = att3.unsqueeze(-1).expand_as(t3)
86 | att4 = att4.unsqueeze(-1).expand_as(t4)
87 | att5 = att5.unsqueeze(-1).expand_as(t5)
88 |
89 | return att1, att2, att3, att4, att5
90 |
91 |
92 | class Spatial_Att_Bridge(nn.Module):
93 | def __init__(self):
94 | super().__init__()
95 | self.shared_conv2d = nn.Sequential(nn.Conv2d(2, 1, 7, stride=1, padding=9, dilation=3),
96 | nn.Sigmoid())
97 |
98 | def forward(self, t1, t2, t3, t4, t5):
99 | t_list = [t1, t2, t3, t4, t5]
100 | att_list = []
101 | for t in t_list:
102 | avg_out = torch.mean(t, dim=1, keepdim=True)
103 | max_out, _ = torch.max(t, dim=1, keepdim=True)
104 | att = torch.cat([avg_out, max_out], dim=1)
105 | att = self.shared_conv2d(att)
106 | att_list.append(att)
107 | return att_list[0], att_list[1], att_list[2], att_list[3], att_list[4]
108 |
109 |
110 | class SC_Att_Bridge(nn.Module):
111 | def __init__(self, c_list, split_att='fc'):
112 | super().__init__()
113 |
114 | self.catt = Channel_Att_Bridge(c_list, split_att=split_att)
115 | self.satt = Spatial_Att_Bridge()
116 |
117 | def forward(self, t1, t2, t3, t4, t5):
118 | r1, r2, r3, r4, r5 = t1, t2, t3, t4, t5
119 |
120 | satt1, satt2, satt3, satt4, satt5 = self.satt(t1, t2, t3, t4, t5)
121 | t1, t2, t3, t4, t5 = satt1 * t1, satt2 * t2, satt3 * t3, satt4 * t4, satt5 * t5
122 |
123 | r1_, r2_, r3_, r4_, r5_ = t1, t2, t3, t4, t5
124 | t1, t2, t3, t4, t5 = t1 + r1, t2 + r2, t3 + r3, t4 + r4, t5 + r5
125 |
126 | catt1, catt2, catt3, catt4, catt5 = self.catt(t1, t2, t3, t4, t5)
127 | t1, t2, t3, t4, t5 = catt1 * t1, catt2 * t2, catt3 * t3, catt4 * t4, catt5 * t5
128 |
129 | return t1 + r1_, t2 + r2_, t3 + r3_, t4 + r4_, t5 + r5_
130 |
131 |
132 | class UltraLight_VM_UNet(nn.Module):
133 |
134 | def __init__(self, num_classes=1, input_channels=3, c_list=[8,16,24,32,48,64],
135 | split_att='fc', bridge=True):
136 | super().__init__()
137 |
138 | self.bridge = bridge
139 |
140 | self.encoder1 = nn.Sequential(
141 | nn.Conv2d(input_channels, c_list[0], 3, stride=1, padding=1),
142 | )
143 | self.encoder2 =nn.Sequential(
144 | nn.Conv2d(c_list[0], c_list[1], 3, stride=1, padding=1),
145 | )
146 | self.encoder3 = nn.Sequential(
147 | nn.Conv2d(c_list[1], c_list[2], 3, stride=1, padding=1),
148 | )
149 | self.encoder4 = nn.Sequential(
150 | PVMLayer(input_dim=c_list[2], output_dim=c_list[3])
151 | )
152 | self.encoder5 = nn.Sequential(
153 | PVMLayer(input_dim=c_list[3], output_dim=c_list[4])
154 | )
155 | self.encoder6 = nn.Sequential(
156 | PVMLayer(input_dim=c_list[4], output_dim=c_list[5])
157 | )
158 |
159 | if bridge:
160 | self.scab = SC_Att_Bridge(c_list, split_att)
161 | print('SC_Att_Bridge was used')
162 |
163 | self.decoder1 = nn.Sequential(
164 | PVMLayer(input_dim=c_list[5], output_dim=c_list[4])
165 | )
166 | self.decoder2 = nn.Sequential(
167 | PVMLayer(input_dim=c_list[4], output_dim=c_list[3])
168 | )
169 | self.decoder3 = nn.Sequential(
170 | PVMLayer(input_dim=c_list[3], output_dim=c_list[2])
171 | )
172 | self.decoder4 = nn.Sequential(
173 | nn.Conv2d(c_list[2], c_list[1], 3, stride=1, padding=1),
174 | )
175 | self.decoder5 = nn.Sequential(
176 | nn.Conv2d(c_list[1], c_list[0], 3, stride=1, padding=1),
177 | )
178 | self.ebn1 = nn.GroupNorm(4, c_list[0])
179 | self.ebn2 = nn.GroupNorm(4, c_list[1])
180 | self.ebn3 = nn.GroupNorm(4, c_list[2])
181 | self.ebn4 = nn.GroupNorm(4, c_list[3])
182 | self.ebn5 = nn.GroupNorm(4, c_list[4])
183 | self.dbn1 = nn.GroupNorm(4, c_list[4])
184 | self.dbn2 = nn.GroupNorm(4, c_list[3])
185 | self.dbn3 = nn.GroupNorm(4, c_list[2])
186 | self.dbn4 = nn.GroupNorm(4, c_list[1])
187 | self.dbn5 = nn.GroupNorm(4, c_list[0])
188 |
189 | self.final = nn.Conv2d(c_list[0], num_classes, kernel_size=1)
190 |
191 | self.apply(self._init_weights)
192 |
193 | def _init_weights(self, m):
194 | if isinstance(m, nn.Linear):
195 | trunc_normal_(m.weight, std=.02)
196 | if isinstance(m, nn.Linear) and m.bias is not None:
197 | nn.init.constant_(m.bias, 0)
198 | elif isinstance(m, nn.Conv1d):
199 | n = m.kernel_size[0] * m.out_channels
200 | m.weight.data.normal_(0, math.sqrt(2. / n))
201 | elif isinstance(m, nn.Conv2d):
202 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
203 | fan_out //= m.groups
204 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
205 | if m.bias is not None:
206 | m.bias.data.zero_()
207 |
208 | def forward(self, x):
209 |
210 | out = F.gelu(F.max_pool2d(self.ebn1(self.encoder1(x)),2,2))
211 | t1 = out # b, c0, H/2, W/2
212 |
213 | out = F.gelu(F.max_pool2d(self.ebn2(self.encoder2(out)),2,2))
214 | t2 = out # b, c1, H/4, W/4
215 |
216 | out = F.gelu(F.max_pool2d(self.ebn3(self.encoder3(out)),2,2))
217 | t3 = out # b, c2, H/8, W/8
218 |
219 | out = F.gelu(F.max_pool2d(self.ebn4(self.encoder4(out)),2,2))
220 | t4 = out # b, c3, H/16, W/16
221 |
222 | out = F.gelu(F.max_pool2d(self.ebn5(self.encoder5(out)),2,2))
223 | t5 = out # b, c4, H/32, W/32
224 |
225 | if self.bridge: t1, t2, t3, t4, t5 = self.scab(t1, t2, t3, t4, t5)
226 |
227 | out = F.gelu(self.encoder6(out)) # b, c5, H/32, W/32
228 |
229 | out5 = F.gelu(self.dbn1(self.decoder1(out))) # b, c4, H/32, W/32
230 | out5 = torch.add(out5, t5) # b, c4, H/32, W/32
231 |
232 | out4 = F.gelu(F.interpolate(self.dbn2(self.decoder2(out5)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c3, H/16, W/16
233 | out4 = torch.add(out4, t4) # b, c3, H/16, W/16
234 |
235 | out3 = F.gelu(F.interpolate(self.dbn3(self.decoder3(out4)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c2, H/8, W/8
236 | out3 = torch.add(out3, t3) # b, c2, H/8, W/8
237 |
238 | out2 = F.gelu(F.interpolate(self.dbn4(self.decoder4(out3)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c1, H/4, W/4
239 | out2 = torch.add(out2, t2) # b, c1, H/4, W/4
240 |
241 | out1 = F.gelu(F.interpolate(self.dbn5(self.decoder5(out2)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c0, H/2, W/2
242 | out1 = torch.add(out1, t1) # b, c0, H/2, W/2
243 |
244 | out0 = F.interpolate(self.final(out1),scale_factor=(2,2),mode ='bilinear',align_corners=True) # b, num_class, H, W
245 |
246 | return torch.sigmoid(out0)
247 |
248 |
249 |
--------------------------------------------------------------------------------
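`PVMLayer` splits the C normalized channels into four groups of C/4 and passes each group through one shared `Mamba(d_model=C/4)`, so the SSM block is sized for C/4 instead of C. The sketch below estimates the parameter count of one Mamba block as a function of `d_model`. The per-layer breakdown is an assumption based on the default `mamba_ssm` `Mamba` layout (in_proj and out_proj without bias, depthwise conv1d with bias, `dt_rank = ceil(d_model / 16)`, and the `expand=2`, `d_state=16`, `d_conv=4` values used by `PVMLayer`); it should be checked against `sum(p.numel() for p in pvm_layer.mamba.parameters())` on the installed version. The function names and `pvm_layer` are hypothetical.

```python
import math


def mamba_param_estimate(d_model, d_state=16, d_conv=4, expand=2):
    """Estimated parameter count of one mamba_ssm Mamba block (assumed default layout)."""
    d_inner = expand * d_model
    dt_rank = math.ceil(d_model / 16)
    in_proj = d_model * 2 * d_inner             # Linear(d_model, 2*d_inner), no bias
    conv1d = d_inner * d_conv + d_inner         # depthwise Conv1d with bias
    x_proj = d_inner * (dt_rank + 2 * d_state)  # Linear(d_inner, dt_rank + 2*d_state), no bias
    dt_proj = dt_rank * d_inner + d_inner       # Linear(dt_rank, d_inner) with bias
    a_log = d_inner * d_state                   # A_log matrix
    d_skip = d_inner                            # D vector
    out_proj = d_inner * d_model                # Linear(d_inner, d_model), no bias
    return in_proj + conv1d + x_proj + dt_proj + a_log + d_skip + out_proj


def pvm_mamba_params(channels):
    """One Mamba shared by the four C/4 channel groups, as in PVMLayer."""
    return mamba_param_estimate(channels // 4)


if __name__ == "__main__":
    for c in (24, 32, 48, 64):  # PVMLayer input dims for the constructor's default c_list
        print(c, mamba_param_estimate(c), pvm_mamba_params(c))
```

Under these assumptions the in/out projections grow with the square of `d_model`, so sizing the block for C/4 shrinks them about 16x; for C = 64 the estimate falls from 32,640 to 3,360 parameters. The `LayerNorm` and output `proj` of `PVMLayer` are not included in these numbers.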
/results/Readme.txt:
--------------------------------------------------------------------------------
1 | Result save location
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.cuda.amp import autocast, GradScaler
4 | from torch.utils.data import DataLoader
5 | from loader import *
6 |
7 | from models.UltraLight_VM_UNet import UltraLight_VM_UNet
8 | from engine import *
9 | import os
10 | import sys
11 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # "0, 1, 2, 3"
12 |
13 | from utils import *
14 | from configs.config_setting import setting_config
15 |
16 | import warnings
17 | warnings.filterwarnings("ignore")
18 |
19 |
20 |
21 | def main(config):
22 |
23 | print('#----------Creating logger----------#')
24 | sys.path.append(config.work_dir + '/')
25 | log_dir = os.path.join(config.work_dir, 'log')
26 | checkpoint_dir = os.path.join(config.work_dir, 'checkpoints')
27 | resume_model = os.path.join('') # path to the best-model weights saved by train.py; must be set before testing
28 | outputs = os.path.join(config.work_dir, 'outputs')
29 | if not os.path.exists(checkpoint_dir):
30 | os.makedirs(checkpoint_dir)
31 | if not os.path.exists(outputs):
32 | os.makedirs(outputs)
33 |
34 | global logger
35 | logger = get_logger('test', log_dir)
36 |
37 | log_config_info(config, logger)
38 |
39 |
40 |
41 |
42 |
43 | print('#----------GPU init----------#')
44 | set_seed(config.seed)
45 | gpu_ids = [0]# [0, 1, 2, 3]
46 | torch.cuda.empty_cache()
47 |
48 |
49 |
50 | print('#----------Preparing Models----------#')
51 | model_cfg = config.model_config
52 | model = UltraLight_VM_UNet(num_classes=model_cfg['num_classes'],
53 | input_channels=model_cfg['input_channels'],
54 | c_list=model_cfg['c_list'],
55 | split_att=model_cfg['split_att'],
56 | bridge=model_cfg['bridge'],)
57 |
58 | model = torch.nn.DataParallel(model.cuda(), device_ids=gpu_ids, output_device=gpu_ids[0])
59 |
60 |
61 | print('#----------Preparing dataset----------#')
62 | test_dataset = isic_loader(path_Data = config.data_path, train = False, Test = True)
63 | test_loader = DataLoader(test_dataset,
64 | batch_size=1,
65 | shuffle=False,
66 | pin_memory=True,
67 | num_workers=config.num_workers,
68 | drop_last=True)
69 |
70 | print('#----------Preparing loss, opt, sch and amp----------#')
71 | criterion = config.criterion
72 | optimizer = get_optimizer(config, model)
73 | scheduler = get_scheduler(config, optimizer)
74 | scaler = GradScaler()
75 |
76 |
77 |
78 |
79 |
80 | print('#----------Set other params----------#')
81 | min_loss = 999
82 | start_epoch = 1
83 | min_epoch = 1
84 |
85 |
86 | print('#----------Testing----------#')
87 | best_weight = torch.load(resume_model, map_location=torch.device('cpu'))
88 | model.module.load_state_dict(best_weight)
89 | loss = test_one_epoch(
90 | test_loader,
91 | model,
92 | criterion,
93 | logger,
94 | config,
95 | )
96 |
97 |
98 |
99 | if __name__ == '__main__':
100 | config = setting_config
101 | main(config)
--------------------------------------------------------------------------------
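`test.py` leaves `resume_model` empty, so it must point at a weights file before running. `train.py` saves the best weights as a plain `state_dict` in `checkpoints/best.pth` and, after its final test, renames that file to `best-epoch{min_epoch}-loss{min_loss:.4f}.pth`; `latest.pth` instead holds a dict with `model_state_dict` plus optimizer and scheduler state, so it cannot be passed directly to `model.module.load_state_dict` as `test.py` does. The hypothetical helper below picks the renamed best checkpoint with the lowest loss in a directory; it only parses file names and loads nothing.

```python
import os
import re
import tempfile

# Name train.py gives the best weights after training, e.g. "best-epoch87-loss0.1234.pth".
BEST_PATTERN = re.compile(r"^best-epoch(\d+)-loss(\d+(?:\.\d+)?)\.pth$")


def find_best_checkpoint(checkpoint_dir):
    """Return the path of the renamed best checkpoint with the lowest loss, or None."""
    candidates = []
    for name in os.listdir(checkpoint_dir):
        match = BEST_PATTERN.match(name)
        if match:
            candidates.append((float(match.group(2)), int(match.group(1)), name))
    if not candidates:
        return None
    _, _, best_name = min(candidates)  # lowest loss first, then earliest epoch
    return os.path.join(checkpoint_dir, best_name)


if __name__ == "__main__":
    # Demo on a throwaway directory containing empty placeholder files.
    with tempfile.TemporaryDirectory() as d:
        for name in ("best-epoch10-loss0.3000.pth", "best-epoch42-loss0.1234.pth", "latest.pth"):
            open(os.path.join(d, name), "w").close()
        print(find_best_checkpoint(d))
```

If `config.work_dir` differs between the training and test runs (for example, if it embeds a timestamp), the training run's checkpoint directory should be passed instead of the `checkpoint_dir` that `test.py` builds.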
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.cuda.amp import autocast, GradScaler
4 | from torch.utils.data import DataLoader
5 | from loader import *
6 |
7 | from models.UltraLight_VM_UNet import UltraLight_VM_UNet
8 | from engine import *
9 | import os
10 | import sys
11 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # "0, 1, 2, 3"
12 |
13 | from utils import *
14 | from configs.config_setting import setting_config
15 |
16 | import warnings
17 | warnings.filterwarnings("ignore")
18 |
19 |
20 | def main(config):
21 |
22 | print('#----------Creating logger----------#')
23 | sys.path.append(config.work_dir + '/')
24 | log_dir = os.path.join(config.work_dir, 'log')
25 | checkpoint_dir = os.path.join(config.work_dir, 'checkpoints')
26 | resume_model = os.path.join(checkpoint_dir, 'latest.pth')
27 | outputs = os.path.join(config.work_dir, 'outputs')
28 | if not os.path.exists(checkpoint_dir):
29 | os.makedirs(checkpoint_dir)
30 | if not os.path.exists(outputs):
31 | os.makedirs(outputs)
32 |
33 | global logger
34 | logger = get_logger('train', log_dir)
35 |
36 | log_config_info(config, logger)
37 |
38 |
39 |
40 |
41 |
42 | print('#----------GPU init----------#')
43 | set_seed(config.seed)
44 | gpu_ids = [0]# [0, 1, 2, 3]
45 | torch.cuda.empty_cache()
46 |
47 |
48 |
49 |
50 |
51 | print('#----------Preparing dataset----------#')
52 | train_dataset = isic_loader(path_Data = config.data_path, train = True)
53 | train_loader = DataLoader(train_dataset,
54 | batch_size=config.batch_size,
55 | shuffle=True,
56 | pin_memory=True,
57 | num_workers=config.num_workers)
58 | val_dataset = isic_loader(path_Data = config.data_path, train = False)
59 | val_loader = DataLoader(val_dataset,
60 | batch_size=1,
61 | shuffle=False,
62 | pin_memory=True,
63 | num_workers=config.num_workers,
64 | drop_last=True)
65 | test_dataset = isic_loader(path_Data = config.data_path, train = False, Test = True)
66 | test_loader = DataLoader(test_dataset,
67 | batch_size=1,
68 | shuffle=False,
69 | pin_memory=True,
70 | num_workers=config.num_workers,
71 | drop_last=True)
72 |
73 |
74 |
75 |
76 | print('#----------Preparing Models----------#')
77 | model_cfg = config.model_config
78 | model = UltraLight_VM_UNet(num_classes=model_cfg['num_classes'],
79 | input_channels=model_cfg['input_channels'],
80 | c_list=model_cfg['c_list'],
81 | split_att=model_cfg['split_att'],
82 | bridge=model_cfg['bridge'],)
83 |
84 | model = torch.nn.DataParallel(model.cuda(), device_ids=gpu_ids, output_device=gpu_ids[0])
85 |
86 |
87 |
88 |
89 |
90 |
91 | print('#----------Preparing loss, opt, sch and amp----------#')
92 | criterion = config.criterion
93 | optimizer = get_optimizer(config, model)
94 | scheduler = get_scheduler(config, optimizer)
95 | scaler = GradScaler()
96 |
97 |
98 |
99 |
100 |
101 | print('#----------Set other params----------#')
102 | min_loss = 999
103 | start_epoch = 1
104 | min_epoch = 1
105 |
106 |
107 |
108 |
109 |
110 | if os.path.exists(resume_model):
111 | print('#----------Resume Model and Other params----------#')
112 | checkpoint = torch.load(resume_model, map_location=torch.device('cpu'))
113 | model.module.load_state_dict(checkpoint['model_state_dict'])
114 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
115 | scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
116 | saved_epoch = checkpoint['epoch']
117 | start_epoch += saved_epoch
118 | min_loss, min_epoch, loss = checkpoint['min_loss'], checkpoint['min_epoch'], checkpoint['loss']
119 |
120 | log_info = f'resuming model from {resume_model}. resume_epoch: {saved_epoch}, min_loss: {min_loss:.4f}, min_epoch: {min_epoch}, loss: {loss:.4f}'
121 | logger.info(log_info)
122 |
123 |
124 |
125 |
126 |
127 | print('#----------Training----------#')
128 | for epoch in range(start_epoch, config.epochs + 1):
129 |
130 | torch.cuda.empty_cache()
131 |
132 | train_one_epoch(
133 | train_loader,
134 | model,
135 | criterion,
136 | optimizer,
137 | scheduler,
138 | epoch,
139 | logger,
140 | config,
141 | scaler=scaler
142 | )
143 |
144 | loss = val_one_epoch(
145 | val_loader,
146 | model,
147 | criterion,
148 | epoch,
149 | logger,
150 | config
151 | )
152 |
153 |
154 | if loss < min_loss:
155 | torch.save(model.module.state_dict(), os.path.join(checkpoint_dir, 'best.pth'))
156 | min_loss = loss
157 | min_epoch = epoch
158 |
159 | torch.save(
160 | {
161 | 'epoch': epoch,
162 | 'min_loss': min_loss,
163 | 'min_epoch': min_epoch,
164 | 'loss': loss,
165 | 'model_state_dict': model.module.state_dict(),
166 | 'optimizer_state_dict': optimizer.state_dict(),
167 | 'scheduler_state_dict': scheduler.state_dict(),
168 | }, os.path.join(checkpoint_dir, 'latest.pth'))
169 |
170 | if os.path.exists(os.path.join(checkpoint_dir, 'best.pth')):
171 | print('#----------Testing----------#')
172 | best_weight = torch.load(os.path.join(checkpoint_dir, 'best.pth'), map_location=torch.device('cpu'))
173 | model.module.load_state_dict(best_weight)
174 | loss = test_one_epoch(
175 | test_loader,
176 | model,
177 | criterion,
178 | logger,
179 | config,
180 | )
181 | os.rename(
182 | os.path.join(checkpoint_dir, 'best.pth'),
183 | os.path.join(checkpoint_dir, f'best-epoch{min_epoch}-loss{min_loss:.4f}.pth')
184 | )
185 |
186 |
187 | if __name__ == '__main__':
188 | config = setting_config
189 | main(config)
--------------------------------------------------------------------------------
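The training loop above uses the scheduler built by `get_scheduler` in utils.py. For `WP_CosineLR` that is a `LambdaLR` whose multiplier, written in terms of epochs, rises linearly over `warm_up_epochs` and then follows a half cosine down to 0 at `config.epochs`. The sketch restates that lambda as a plain function so the multipliers can be inspected without building an optimizer; the function name is hypothetical and the example values are arbitrary, not the repository's configuration.

```python
import math


def wp_cosine_factor(epoch, warm_up_epochs, epochs):
    """Multiplier of the WP_CosineLR lambda in get_scheduler, as a plain function."""
    if epoch <= warm_up_epochs:
        return epoch / warm_up_epochs  # linear warm-up from 0 to 1
    progress = (epoch - warm_up_epochs) / (epochs - warm_up_epochs)
    return 0.5 * (math.cos(progress * math.pi) + 1)  # half cosine from 1 down to 0


if __name__ == "__main__":
    base_lr, warm_up_epochs, epochs = 1e-3, 5, 50  # arbitrary example values
    for epoch in (0, 1, 5, 10, 27, 50):
        print(epoch, base_lr * wp_cosine_factor(epoch, warm_up_epochs, epochs))
```

One consequence worth checking: the multiplier at epoch 0 is 0, and PyTorch's `LambdaLR` applies the step-0 multiplier when it is constructed, so under this schedule the optimizer's learning rate is 0 until the first `scheduler.step()`.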
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.backends.cudnn as cudnn
5 | import torchvision.transforms.functional as TF
6 | import numpy as np
7 | import os
8 | import math
9 | import random
10 | import logging
11 | import logging.handlers
12 | from matplotlib import pyplot as plt
13 |
14 |
15 | def set_seed(seed):
16 | # for hash
17 | os.environ['PYTHONHASHSEED'] = str(seed)
18 | # for python and numpy
19 | random.seed(seed)
20 | np.random.seed(seed)
21 | # for cpu gpu
22 | torch.manual_seed(seed)
23 | torch.cuda.manual_seed(seed)
24 | torch.cuda.manual_seed_all(seed)
25 | # for cudnn
26 | cudnn.benchmark = False
27 | cudnn.deterministic = True
28 |
29 |
30 | def get_logger(name, log_dir):
31 | '''
32 | Args:
33 | name(str): name of logger
34 | log_dir(str): path of log
35 | '''
36 |
37 | if not os.path.exists(log_dir):
38 | os.makedirs(log_dir)
39 |
40 | logger = logging.getLogger(name)
41 | logger.setLevel(logging.INFO)
42 |
43 | info_name = os.path.join(log_dir, '{}.info.log'.format(name))
44 | info_handler = logging.handlers.TimedRotatingFileHandler(info_name,
45 | when='D',
46 | encoding='utf-8')
47 | info_handler.setLevel(logging.INFO)
48 |
49 | formatter = logging.Formatter('%(asctime)s - %(message)s',
50 | datefmt='%Y-%m-%d %H:%M:%S')
51 |
52 | info_handler.setFormatter(formatter)
53 |
54 | logger.addHandler(info_handler)
55 |
56 | return logger
57 |
58 |
59 | def log_config_info(config, logger):
60 | config_dict = config.__dict__
61 | log_info = f'#----------Config info----------#'
62 | logger.info(log_info)
63 | for k, v in config_dict.items():
64 | if k[0] == '_':
65 | continue
66 | else:
67 | log_info = f'{k}: {v},'
68 | logger.info(log_info)
69 |
70 |
71 |
72 | def get_optimizer(config, model):
73 | assert config.opt in ['Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'ASGD', 'RMSprop', 'Rprop', 'SGD'], 'Unsupported optimizer!'
74 |
75 | if config.opt == 'Adadelta':
76 | return torch.optim.Adadelta(
77 | model.parameters(),
78 | lr = config.lr,
79 | rho = config.rho,
80 | eps = config.eps,
81 | weight_decay = config.weight_decay
82 | )
83 | elif config.opt == 'Adagrad':
84 | return torch.optim.Adagrad(
85 | model.parameters(),
86 | lr = config.lr,
87 | lr_decay = config.lr_decay,
88 | eps = config.eps,
89 | weight_decay = config.weight_decay
90 | )
91 | elif config.opt == 'Adam':
92 | return torch.optim.Adam(
93 | model.parameters(),
94 | lr = config.lr,
95 | betas = config.betas,
96 | eps = config.eps,
97 | weight_decay = config.weight_decay,
98 | amsgrad = config.amsgrad
99 | )
100 | elif config.opt == 'AdamW':
101 | return torch.optim.AdamW(
102 | model.parameters(),
103 | lr = config.lr,
104 | betas = config.betas,
105 | eps = config.eps,
106 | weight_decay = config.weight_decay,
107 | amsgrad = config.amsgrad
108 | )
109 | elif config.opt == 'Adamax':
110 | return torch.optim.Adamax(
111 | model.parameters(),
112 | lr = config.lr,
113 | betas = config.betas,
114 | eps = config.eps,
115 | weight_decay = config.weight_decay
116 | )
117 | elif config.opt == 'ASGD':
118 | return torch.optim.ASGD(
119 | model.parameters(),
120 | lr = config.lr,
121 | lambd = config.lambd,
122 | alpha = config.alpha,
123 | t0 = config.t0,
124 | weight_decay = config.weight_decay
125 | )
126 | elif config.opt == 'RMSprop':
127 | return torch.optim.RMSprop(
128 | model.parameters(),
129 | lr = config.lr,
130 | momentum = config.momentum,
131 | alpha = config.alpha,
132 | eps = config.eps,
133 | centered = config.centered,
134 | weight_decay = config.weight_decay
135 | )
136 | elif config.opt == 'Rprop':
137 | return torch.optim.Rprop(
138 | model.parameters(),
139 | lr = config.lr,
140 | etas = config.etas,
141 | step_sizes = config.step_sizes,
142 | )
143 | elif config.opt == 'SGD':
144 | return torch.optim.SGD(
145 | model.parameters(),
146 | lr = config.lr,
147 | momentum = config.momentum,
148 | weight_decay = config.weight_decay,
149 | dampening = config.dampening,
150 | nesterov = config.nesterov
151 | )
152 | else: # default opt is SGD
153 | return torch.optim.SGD(
154 | model.parameters(),
155 | lr = 0.01,
156 | momentum = 0.9,
157 | weight_decay = 0.05,
158 | )
159 |
160 |
161 |
162 | def get_scheduler(config, optimizer):
163 | assert config.sch in ['StepLR', 'MultiStepLR', 'ExponentialLR', 'CosineAnnealingLR', 'ReduceLROnPlateau',
164 | 'CosineAnnealingWarmRestarts', 'WP_MultiStepLR', 'WP_CosineLR'], 'Unsupported scheduler!'
165 | if config.sch == 'StepLR':
166 | scheduler = torch.optim.lr_scheduler.StepLR(
167 | optimizer,
168 | step_size = config.step_size,
169 | gamma = config.gamma,
170 | last_epoch = config.last_epoch
171 | )
172 | elif config.sch == 'MultiStepLR':
173 | scheduler = torch.optim.lr_scheduler.MultiStepLR(
174 | optimizer,
175 | milestones = config.milestones,
176 | gamma = config.gamma,
177 | last_epoch = config.last_epoch
178 | )
179 | elif config.sch == 'ExponentialLR':
180 | scheduler = torch.optim.lr_scheduler.ExponentialLR(
181 | optimizer,
182 | gamma = config.gamma,
183 | last_epoch = config.last_epoch
184 | )
185 | elif config.sch == 'CosineAnnealingLR':
186 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
187 | optimizer,
188 | T_max = config.T_max,
189 | eta_min = config.eta_min,
190 | last_epoch = config.last_epoch
191 | )
192 | elif config.sch == 'ReduceLROnPlateau':
193 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
194 | optimizer,
195 | mode = config.mode,
196 | factor = config.factor,
197 | patience = config.patience,
198 | threshold = config.threshold,
199 | threshold_mode = config.threshold_mode,
200 | cooldown = config.cooldown,
201 | min_lr = config.min_lr,
202 | eps = config.eps
203 | )
204 | elif config.sch == 'CosineAnnealingWarmRestarts':
205 | scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
206 | optimizer,
207 | T_0 = config.T_0,
208 | T_mult = config.T_mult,
209 | eta_min = config.eta_min,
210 | last_epoch = config.last_epoch
211 | )
212 | elif config.sch == 'WP_MultiStepLR':
213 | lr_func = lambda epoch: epoch / config.warm_up_epochs if epoch <= config.warm_up_epochs else config.gamma**len(
214 | [m for m in config.milestones if m <= epoch])
215 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func)
216 | elif config.sch == 'WP_CosineLR':
217 | lr_func = lambda epoch: epoch / config.warm_up_epochs if epoch <= config.warm_up_epochs else 0.5 * (
218 | math.cos((epoch - config.warm_up_epochs) / (config.epochs - config.warm_up_epochs) * math.pi) + 1)
219 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func)
220 |
221 | return scheduler
222 |
223 |
224 |
225 | def save_imgs(img, msk, msk_pred, i, save_path, datasets, threshold=0.5, test_data_name=None):
226 | img = img.squeeze(0).permute(1,2,0).detach().cpu().numpy()
227 | img = img / 255. if img.max() > 1.1 else img
228 | if datasets == 'retinal':
229 | msk = np.squeeze(msk, axis=0)
230 | msk_pred = np.squeeze(msk_pred, axis=0)
231 | else:
232 | msk = np.where(np.squeeze(msk, axis=0) > 0.5, 1, 0)
233 | msk_pred = np.where(np.squeeze(msk_pred, axis=0) > threshold, 1, 0)
234 |
235 | plt.figure(figsize=(7,15))
236 |
237 | plt.subplot(3,1,1)
238 | plt.imshow(img)
239 | plt.axis('off')
240 |
241 | plt.subplot(3,1,2)
242 | plt.imshow(msk, cmap= 'gray')
243 | plt.axis('off')
244 |
245 | plt.subplot(3,1,3)
246 | plt.imshow(msk_pred, cmap = 'gray')
247 | plt.axis('off')
248 |
249 | if test_data_name is not None:
250 | save_path = save_path + test_data_name + '_'
251 | plt.savefig(save_path + str(i) +'.png')
252 | plt.close()
253 |
254 |
255 |
256 | class BCELoss(nn.Module):
257 | def __init__(self):
258 | super(BCELoss, self).__init__()
259 | self.bceloss = nn.BCELoss()
260 |
261 | def forward(self, pred, target):
262 | size = pred.size(0)
263 | pred_ = pred.view(size, -1)
264 | target_ = target.view(size, -1)
265 |
266 | return self.bceloss(pred_, target_)
267 |
268 |
269 | class DiceLoss(nn.Module):
270 | def __init__(self):
271 | super(DiceLoss, self).__init__()
272 |
273 | def forward(self, pred, target):
274 | smooth = 1
275 | size = pred.size(0)
276 |
277 | pred_ = pred.view(size, -1)
278 | target_ = target.view(size, -1)
279 | intersection = pred_ * target_
280 | dice_score = (2 * intersection.sum(1) + smooth)/(pred_.sum(1) + target_.sum(1) + smooth)
281 | dice_loss = 1 - dice_score.sum()/size
282 |
283 | return dice_loss
284 |
285 |
286 | class BceDiceLoss(nn.Module):
287 | def __init__(self, wb=1, wd=1):
288 | super(BceDiceLoss, self).__init__()
289 | self.bce = BCELoss()
290 | self.dice = DiceLoss()
291 | self.wb = wb
292 | self.wd = wd
293 |
294 | def forward(self, pred, target):
295 | bceloss = self.bce(pred, target)
296 | diceloss = self.dice(pred, target)
297 |
298 | loss = self.wd * diceloss + self.wb * bceloss
299 | return loss
300 |
301 |
302 | from thop import profile ## import the thop module
303 | def cal_params_flops(model, size, logger):
304 | input = torch.randn(1, 3, size, size).cuda()
305 | flops, params = profile(model, inputs=(input,))
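306 | print('flops',flops/1e9) ## print the computational cost in G (thop's profile counts multiply-accumulate operations)
307 | print('params',params/1e6) ## print the parameter count in M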
308 |
309 | total = sum(p.numel() for p in model.parameters())
310 | print("Total params: %.3fM" % (total/1e6))
311 | logger.info(f'flops: {flops/1e9}, params: {params/1e6}, Total params: {total/1e6:.4f}')
312 |
313 |
314 |
--------------------------------------------------------------------------------
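`BceDiceLoss` adds a per-sample soft Dice loss (smoothing term 1, averaged over the batch) to the mean binary cross-entropy over all pixels, weighted by `wd` and `wb`. The plain-Python sketch below evaluates the same formula for a batch of flattened probability maps to show what each term contributes; the function name is hypothetical, and it omits the clamping of log terms that PyTorch's `nn.BCELoss` applies, so it is not a drop-in replacement.

```python
import math


def bce_dice_loss(preds, targets, wb=1.0, wd=1.0, smooth=1.0):
    """preds/targets: lists of per-sample flat lists; preds must lie strictly in (0, 1)."""
    batch = len(preds)
    # BCE: mean over every pixel of every sample (nn.BCELoss default reduction='mean')
    terms = [
        -(t * math.log(p) + (1 - t) * math.log(1 - p))
        for pred, target in zip(preds, targets)
        for p, t in zip(pred, target)
    ]
    bce = sum(terms) / len(terms)
    # Soft Dice computed per sample, then averaged over the batch, as in DiceLoss
    dice_scores = []
    for pred, target in zip(preds, targets):
        inter = sum(p * t for p, t in zip(pred, target))
        dice_scores.append((2 * inter + smooth) / (sum(pred) + sum(target) + smooth))
    dice = 1 - sum(dice_scores) / batch
    return wd * dice + wb * bce


if __name__ == "__main__":
    print(bce_dice_loss([[0.5, 0.5]], [[1.0, 0.0]]))    # uninformative prediction
    print(bce_dice_loss([[0.99, 0.01]], [[1.0, 0.0]]))  # confident, correct prediction
```

With an uninformative prediction of 0.5 everywhere, the BCE term equals ln 2 regardless of the mask, while the Dice term still depends on how much foreground the mask contains.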