├── .gitignore ├── LICENSE ├── README.md ├── configs ├── config_setting.py └── config_setting_synapse.py ├── data └── README.md ├── datasets └── dataset.py ├── engine.py ├── engine_synapse.py ├── models └── vmunet │ ├── vmamba.py │ └── vmunet.py ├── pre_trained_weights └── README.md ├── results └── README.md ├── train.py ├── train_synapse.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VM-UNet 2 | This is the official code repository for "VM-UNet: Vision Mamba UNet for Medical 3 | Image Segmentation". {[Arxiv Paper](https://arxiv.org/abs/2402.02491)} 4 | 5 | ## Abstract 6 | In the realm of medical image segmentation, both CNN-based and Transformer-based models have been extensively explored. 
However, CNNs exhibit limitations in long-range modeling, whereas Transformers are hampered by their quadratic computational complexity. Recently, State Space Models (SSMs), exemplified by Mamba, have emerged as a promising approach: they excel at modeling long-range interactions while maintaining linear computational complexity. In this paper, leveraging state space models, we propose a U-shaped architecture for medical image segmentation, named Vision Mamba UNet (VM-UNet). Specifically, the Visual State Space (VSS) block is introduced as the foundation block to capture extensive contextual information, and an asymmetrical encoder-decoder structure is constructed. We conduct comprehensive experiments on the ISIC17, ISIC18, and Synapse datasets, and the results indicate that VM-UNet performs competitively in medical image segmentation tasks. To the best of our knowledge, this is the first medical image segmentation model built purely on SSMs. We aim to establish a baseline and provide valuable insights for the future development of more efficient and effective SSM-based segmentation systems. 7 | 8 | ## 0. Main Environments 9 | ```bash 10 | conda create -n vmunet python=3.8 11 | conda activate vmunet 12 | pip install torch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117 13 | pip install packaging 14 | pip install timm==0.4.12 15 | pip install pytest chardet yacs termcolor 16 | pip install submitit tensorboardX 17 | pip install triton==2.0.0 18 | pip install causal_conv1d==1.0.0 # causal_conv1d-1.0.0+cu118torch1.13cxx11abiFALSE-cp38-cp38-linux_x86_64.whl 19 | pip install mamba_ssm==1.0.1 # mamba_ssm-1.0.1+cu118torch1.13cxx11abiFALSE-cp38-cp38-linux_x86_64.whl 20 | pip install scikit-learn matplotlib thop h5py SimpleITK scikit-image medpy yacs 21 | ``` 22 | The .whl files of causal_conv1d and mamba_ssm can be found here: {[Baidu](https://pan.baidu.com/s/1Tibn8Xh4FMwj0ths8Ufazw?pwd=uu5k) or [GoogleDrive](https://drive.google.com/drive/folders/1tZGs1YFHiDrMa-MjYY8ZoEnCyy7m7Gaj?usp=sharing)} 23 | 24 | ## 1. Prepare the dataset 25 | 26 | ### ISIC datasets 27 | - The ISIC17 and ISIC18 datasets, already split into train and validation sets at a 7:3 ratio, can be found here: {[Baidu](https://pan.baidu.com/s/1Y0YupaH21yDN5uldl7IcZA?pwd=dybm) or [GoogleDrive](https://drive.google.com/file/d/1XM10fmAXndVLtXWOt5G0puYSQyI2veWy/view?usp=sharing)}. 28 | 29 | - After downloading the datasets, put them into './data/isic2017/' and './data/isic2018/' (the paths expected by `configs/config_setting.py`); the expected layout is as follows, taking ISIC17 as an example. A layout sanity check for both datasets is sketched at the end of this section. 30 | 31 | - './data/isic2017/' 32 | - train 33 | - images 34 | - .png 35 | - masks 36 | - .png 37 | - val 38 | - images 39 | - .png 40 | - masks 41 | - .png 42 | 43 | ### Synapse datasets 44 | 45 | - For the Synapse dataset, you can follow [Swin-UNet](https://github.com/HuCaoFighting/Swin-Unet) to download it, or download it from {[Baidu](https://pan.baidu.com/s/1JCXBfRL9y1cjfJUKtbEhiQ?pwd=9jti)}. 46 | 47 | - After downloading the dataset, put it into './data/Synapse/'; the expected layout is as follows. 48 | 49 | - './data/Synapse/' 50 | - lists 51 | - lists_Synapse 52 | - all.lst 53 | - test_vol.txt 54 | - train.txt 55 | - test_vol_h5 56 | - casexxxx.npy.h5 57 | - train_npz 58 | - casexxxx_slicexxx.npz 59 |
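Before training, it can save time to verify the layout programmatically. The helper below is not part of the repository; it is a minimal sanity-check sketch built only from the paths listed above.

```python
# Hypothetical helper (not part of the repo): verify the dataset layout that
# train.py and train_synapse.py expect, using the paths documented above.
import os

def check_isic_layout(root):  # e.g. './data/isic2017/' or './data/isic2018/'
    for split in ('train', 'val'):
        for sub in ('images', 'masks'):
            d = os.path.join(root, split, sub)
            assert os.path.isdir(d), f'missing directory: {d}'
            n = sum(f.endswith('.png') for f in os.listdir(d))
            print(f'{d}: {n} png files')

def check_synapse_layout(root='./data/Synapse/'):
    for sub in ('train_npz', 'test_vol_h5'):
        assert os.path.isdir(os.path.join(root, sub)), f'missing {sub}/'
    for lst in ('all.lst', 'test_vol.txt', 'train.txt'):
        path = os.path.join(root, 'lists', 'lists_Synapse', lst)
        assert os.path.isfile(path), f'missing list file: {path}'

if __name__ == '__main__':
    check_isic_layout('./data/isic2017/')
    check_synapse_layout()
```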
60 | ## 2. Prepare the pre-trained weights 61 | 62 | - The weights of the pre-trained VMamba can be downloaded from [Baidu](https://pan.baidu.com/s/1ci_YvPPEiUT2bIIK5x8Igw?pwd=wnyy) or [GoogleDrive](https://drive.google.com/drive/folders/1tZGs1YFHiDrMa-MjYY8ZoEnCyy7m7Gaj?usp=sharing). After that, the pre-trained weights should be stored in './pre_trained_weights/' (the path referenced by `load_ckpt_path` in the configs). 63 | 64 | 65 | 66 | ## 3. Train the VM-UNet 67 | ```bash 68 | cd VM-UNet 69 | python train.py # Train and test VM-UNet on the ISIC17 or ISIC18 dataset. 70 | python train_synapse.py # Train and test VM-UNet on the Synapse dataset. 71 | ``` 72 | 73 | **NOTE**: If you want to use a trained checkpoint for inference testing only and save the corresponding test images, follow these steps: 74 | 75 | - **In `config_setting`**: 76 | - Set the parameter `only_test_and_save_figs` to `True`. 77 | - Fill in the path of the trained checkpoint in `best_ckpt_path`. 78 | - Specify the save path for test images in `img_save_path`. 79 | 80 | - **Execute the script**: 81 | After setting the above parameters, run `train.py`. 82 | 83 | ## 4. Obtain the outputs 84 | - After training, the results can be found in './results/'. 85 | 86 | ## 5. Acknowledgments 87 | 88 | - We thank the authors of [VMamba](https://github.com/MzeroMiko/VMamba) and [Swin-UNet](https://github.com/HuCaoFighting/Swin-Unet) for their open-source code. 89 | -------------------------------------------------------------------------------- /configs/config_setting.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from utils import * 3 | 4 | from datetime import datetime 5 | 6 | class setting_config: 7 | """ 8 | the config of the training setting. 9 | """ 10 | 11 | network = 'vmunet' 12 | model_config = { 13 | 'num_classes': 1, 14 | 'input_channels': 3, 15 | # ----- VM-UNet ----- # 16 | 'depths': [2,2,2,2], 17 | 'depths_decoder': [2,2,2,1], 18 | 'drop_path_rate': 0.2, 19 | 'load_ckpt_path': './pre_trained_weights/vmamba_small_e238_ema.pth', 20 | } 21 | 22 | datasets = 'isic18' 23 | if datasets == 'isic18': 24 | data_path = './data/isic2018/' 25 | elif datasets == 'isic17': 26 | data_path = './data/isic2017/' 27 | else: 28 | raise Exception('datasets is not right!') 29 | 30 | criterion = BceDiceLoss(wb=1, wd=1) 31 | 32 | pretrained_path = './pre_trained/' 33 | num_classes = 1 34 | input_size_h = 256 35 | input_size_w = 256 36 | input_channels = 3 37 | distributed = False 38 | local_rank = -1 39 | num_workers = 0 40 | seed = 42 41 | world_size = None 42 | rank = None 43 | amp = False 44 | gpu_id = '0' 45 | batch_size = 32 46 | epochs = 300 47 | 48 | work_dir = 'results/' + network + '_' + datasets + '_' + datetime.now().strftime('%A_%d_%B_%Y_%Hh_%Mm_%Ss') + '/' 49 | 50 | print_interval = 20 51 | val_interval = 30 52 | save_interval = 100 53 | threshold = 0.5 54 | only_test_and_save_figs = False 55 | best_ckpt_path = 'PATH_TO_YOUR_BEST_CKPT' 56 | img_save_path = 'PATH_TO_SAVE_IMAGES' 57 | 58 | train_transformer = transforms.Compose([ 59 | myNormalize(datasets, train=True), 60 | myToTensor(), 61 | myRandomHorizontalFlip(p=0.5), 62 | myRandomVerticalFlip(p=0.5), 63 | myRandomRotation(p=0.5, degree=[0, 360]), 64 | myResize(input_size_h, input_size_w) 65 | ]) 66 | test_transformer = transforms.Compose([ 67 | myNormalize(datasets, train=False), 68 | myToTensor(), 69 | myResize(input_size_h, input_size_w) 70 | ]) 71 | 72 | opt = 'AdamW' 73 | assert opt in ['Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax',
'ASGD', 'RMSprop', 'Rprop', 'SGD'], 'Unsupported optimizer!' 74 | if opt == 'Adadelta': 75 | lr = 0.01 # default: 1.0 – coefficient that scale delta before it is applied to the parameters 76 | rho = 0.9 # default: 0.9 – coefficient used for computing a running average of squared gradients 77 | eps = 1e-6 # default: 1e-6 – term added to the denominator to improve numerical stability 78 | weight_decay = 0.05 # default: 0 – weight decay (L2 penalty) 79 | elif opt == 'Adagrad': 80 | lr = 0.01 # default: 0.01 – learning rate 81 | lr_decay = 0 # default: 0 – learning rate decay 82 | eps = 1e-10 # default: 1e-10 – term added to the denominator to improve numerical stability 83 | weight_decay = 0.05 # default: 0 – weight decay (L2 penalty) 84 | elif opt == 'Adam': 85 | lr = 0.001 # default: 1e-3 – learning rate 86 | betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square 87 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 88 | weight_decay = 0.0001 # default: 0 – weight decay (L2 penalty) 89 | amsgrad = False # default: False – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond 90 | elif opt == 'AdamW': 91 | lr = 0.001 # default: 1e-3 – learning rate 92 | betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square 93 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 94 | weight_decay = 1e-2 # default: 1e-2 – weight decay coefficient 95 | amsgrad = False # default: False – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond 96 | elif opt == 'Adamax': 97 | lr = 2e-3 # default: 2e-3 – learning rate 98 | betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square 99 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 100 | weight_decay = 0 # default: 0 – weight decay (L2 penalty) 101 | elif opt == 'ASGD': 102 | lr = 0.01 # default: 1e-2 – learning rate 103 | lambd = 1e-4 # default: 1e-4 – decay term 104 | alpha = 0.75 # default: 0.75 – power for eta update 105 | t0 = 1e6 # default: 1e6 – point at which to start averaging 106 | weight_decay = 0 # default: 0 – weight decay 107 | elif opt == 'RMSprop': 108 | lr = 1e-2 # default: 1e-2 – learning rate 109 | momentum = 0 # default: 0 – momentum factor 110 | alpha = 0.99 # default: 0.99 – smoothing constant 111 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 112 | centered = False # default: False – if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance 113 | weight_decay = 0 # default: 0 – weight decay (L2 penalty) 114 | elif opt == 'Rprop': 115 | lr = 1e-2 # default: 1e-2 – learning rate 116 | etas = (0.5, 1.2) # default: (0.5, 1.2) – pair of (etaminus, etaplis), that are multiplicative increase and decrease factors 117 | step_sizes = (1e-6, 50) # default: (1e-6, 50) – a pair of minimal and maximal allowed step sizes 118 | elif opt == 'SGD': 119 | lr = 0.01 # – learning rate 120 | momentum = 0.9 # default: 0 – momentum factor 121 | weight_decay = 0.05 # default: 0 – weight decay (L2 penalty) 122 | dampening = 0 # default: 0 – dampening for momentum 123 | nesterov = False # default: False – enables Nesterov momentum 124 | 125 | sch = 
'CosineAnnealingLR' 126 | if sch == 'StepLR': 127 | step_size = epochs // 5 # – Period of learning rate decay. 128 | gamma = 0.5 # – Multiplicative factor of learning rate decay. Default: 0.1 129 | last_epoch = -1 # – The index of last epoch. Default: -1. 130 | elif sch == 'MultiStepLR': 131 | milestones = [60, 120, 150] # – List of epoch indices. Must be increasing. 132 | gamma = 0.1 # – Multiplicative factor of learning rate decay. Default: 0.1. 133 | last_epoch = -1 # – The index of last epoch. Default: -1. 134 | elif sch == 'ExponentialLR': 135 | gamma = 0.99 # – Multiplicative factor of learning rate decay. 136 | last_epoch = -1 # – The index of last epoch. Default: -1. 137 | elif sch == 'CosineAnnealingLR': 138 | T_max = 50 # – Maximum number of iterations. Cosine function period. 139 | eta_min = 0.00001 # – Minimum learning rate. Default: 0. 140 | last_epoch = -1 # – The index of last epoch. Default: -1. 141 | elif sch == 'ReduceLROnPlateau': 142 | mode = 'min' # – One of min, max. In min mode, lr will be reduced when the quantity monitored has stopped decreasing; in max mode it will be reduced when the quantity monitored has stopped increasing. Default: ‘min’. 143 | factor = 0.1 # – Factor by which the learning rate will be reduced. new_lr = lr * factor. Default: 0.1. 144 | patience = 10 # – Number of epochs with no improvement after which learning rate will be reduced. For example, if patience = 2, then we will ignore the first 2 epochs with no improvement, and will only decrease the LR after the 3rd epoch if the loss still hasn’t improved then. Default: 10. 145 | threshold = 0.0001 # – Threshold for measuring the new optimum, to only focus on significant changes. Default: 1e-4. 146 | threshold_mode = 'rel' # – One of rel, abs. In rel mode, dynamic_threshold = best * ( 1 + threshold ) in ‘max’ mode or best * ( 1 - threshold ) in min mode. In abs mode, dynamic_threshold = best + threshold in max mode or best - threshold in min mode. Default: ‘rel’. 147 | cooldown = 0 # – Number of epochs to wait before resuming normal operation after lr has been reduced. Default: 0. 148 | min_lr = 0 # – A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively. Default: 0. 149 | eps = 1e-08 # – Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is ignored. Default: 1e-8. 150 | elif sch == 'CosineAnnealingWarmRestarts': 151 | T_0 = 50 # – Number of iterations for the first restart. 152 | T_mult = 2 # – A factor by which T_{i} increases after a restart. Default: 1. 153 | eta_min = 1e-6 # – Minimum learning rate. Default: 0. 154 | last_epoch = -1 # – The index of last epoch. Default: -1. 155 | elif sch == 'WP_MultiStepLR': 156 | warm_up_epochs = 10 157 | gamma = 0.1 158 | milestones = [125, 225] 159 | elif sch == 'WP_CosineLR': 160 | warm_up_epochs = 20 161 |
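The config classes only declare hyperparameters; `train.py` turns them into torch objects through helpers imported from `utils.py`, which is not shown in this dump. The sketch below is a hypothetical reconstruction of that wiring for the two branches actually selected above (`AdamW` + `CosineAnnealingLR`); the helper name is an assumption.

```python
# Hypothetical reconstruction (the repo's real helpers live in utils.py):
# build the optimizer/scheduler pair described by the config fields above.
import torch.optim as optim

def build_optimizer_and_scheduler(config, model):
    # only the default branches are sketched; the other optimizers/schedulers
    # listed in the config map onto torch.optim in the same way
    assert config.opt == 'AdamW' and config.sch == 'CosineAnnealingLR'
    optimizer = optim.AdamW(model.parameters(), lr=config.lr, betas=config.betas,
                            eps=config.eps, weight_decay=config.weight_decay,
                            amsgrad=config.amsgrad)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.T_max,
                                                     eta_min=config.eta_min,
                                                     last_epoch=config.last_epoch)
    return optimizer, scheduler
```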
12 | """ 13 | network = 'vmunet' 14 | model_config = { 15 | 'num_classes': 9, 16 | 'input_channels': 3, 17 | # ----- VM-UNet ----- # 18 | 'depths': [2,2,2,2], 19 | 'depths_decoder': [2,2,2,1], 20 | 'drop_path_rate': 0.2, 21 | 'load_ckpt_path': './pre_trained_weights/vmamba_small_e238_ema.pth', 22 | } 23 | datasets_name = 'synapse' 24 | input_size_h = 224 25 | input_size_w = 224 26 | if datasets_name == 'synapse': 27 | data_path = './data/Synapse/train_npz/' 28 | datasets = Synapse_dataset 29 | list_dir = './data/Synapse/lists/lists_Synapse/' 30 | volume_path = './data/Synapse/test_vol_h5/' 31 | else: 32 | raise Exception('datasets in not right!') 33 | 34 | pretrained_path = '' # if using pretrained, please enter the path of weights 35 | num_classes = 9 36 | loss_weight = [1, 1] 37 | criterion = CeDiceLoss(num_classes, loss_weight) 38 | z_spacing = 1 39 | input_channels = 3 40 | 41 | distributed = False 42 | local_rank = -1 43 | num_workers = 0 #16 44 | seed = 42 45 | world_size = None 46 | rank = None 47 | amp = False 48 | 49 | batch_size = 32 50 | epochs = 300 51 | work_dir = 'results/' + network + '_' + datasets_name + '_' + datetime.now().strftime('%A_%d_%B_%Y_%Hh_%Mm_%Ss') + '/' 52 | print_interval = 20 53 | val_interval = 100 54 | threshold = 0.5 55 | only_test_and_save_figs = False 56 | best_ckpt_path = 'PATH_TO_YOUR_BEST_CKPT' 57 | img_save_path = 'PATH_TO_SAVE_IMAGES' 58 | 59 | opt = 'AdamW' 60 | assert opt in ['Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'ASGD', 'RMSprop', 'Rprop', 'SGD'], 'Unsupported optimizer!' 61 | if opt == 'Adadelta': 62 | lr = 0.01 # default: 1.0 – coefficient that scale delta before it is applied to the parameters 63 | rho = 0.9 # default: 0.9 – coefficient used for computing a running average of squared gradients 64 | eps = 1e-6 # default: 1e-6 – term added to the denominator to improve numerical stability 65 | weight_decay = 0.05 # default: 0 – weight decay (L2 penalty) 66 | elif opt == 'Adagrad': 67 | lr = 0.01 # default: 0.01 – learning rate 68 | lr_decay = 0 # default: 0 – learning rate decay 69 | eps = 1e-10 # default: 1e-10 – term added to the denominator to improve numerical stability 70 | weight_decay = 0.05 # default: 0 – weight decay (L2 penalty) 71 | elif opt == 'Adam': 72 | lr = 0.0001 # default: 1e-3 – learning rate 73 | betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square 74 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 75 | weight_decay = 0.05 # default: 0 – weight decay (L2 penalty) 76 | amsgrad = False # default: False – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond 77 | elif opt == 'AdamW': 78 | lr = 0.001 # default: 1e-3 – learning rate 79 | betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square 80 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 81 | weight_decay = 1e-2 # default: 1e-2 – weight decay coefficient 82 | amsgrad = False # default: False – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond 83 | elif opt == 'Adamax': 84 | lr = 2e-3 # default: 2e-3 – learning rate 85 | betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square 86 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 87 
| weight_decay = 0 # default: 0 – weight decay (L2 penalty) 88 | elif opt == 'ASGD': 89 | lr = 0.01 # default: 1e-2 – learning rate 90 | lambd = 1e-4 # default: 1e-4 – decay term 91 | alpha = 0.75 # default: 0.75 – power for eta update 92 | t0 = 1e6 # default: 1e6 – point at which to start averaging 93 | weight_decay = 0 # default: 0 – weight decay 94 | elif opt == 'RMSprop': 95 | lr = 1e-2 # default: 1e-2 – learning rate 96 | momentum = 0 # default: 0 – momentum factor 97 | alpha = 0.99 # default: 0.99 – smoothing constant 98 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 99 | centered = False # default: False – if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance 100 | weight_decay = 0 # default: 0 – weight decay (L2 penalty) 101 | elif opt == 'Rprop': 102 | lr = 1e-2 # default: 1e-2 – learning rate 103 | etas = (0.5, 1.2) # default: (0.5, 1.2) – pair of (etaminus, etaplis), that are multiplicative increase and decrease factors 104 | step_sizes = (1e-6, 50) # default: (1e-6, 50) – a pair of minimal and maximal allowed step sizes 105 | elif opt == 'SGD': 106 | lr = 0.003 # – learning rate 107 | momentum = 0.9 # default: 0 – momentum factor 108 | weight_decay = 0.0001 # default: 0 – weight decay (L2 penalty) 109 | dampening = 0 # default: 0 – dampening for momentum 110 | nesterov = False # default: False – enables Nesterov momentum 111 | 112 | sch = 'CosineAnnealingLR' 113 | if sch == 'StepLR': 114 | step_size = epochs // 5 # – Period of learning rate decay. 115 | gamma = 0.5 # – Multiplicative factor of learning rate decay. Default: 0.1 116 | last_epoch = -1 # – The index of last epoch. Default: -1. 117 | elif sch == 'MultiStepLR': 118 | milestones = [60, 120, 150] # – List of epoch indices. Must be increasing. 119 | gamma = 0.1 # – Multiplicative factor of learning rate decay. Default: 0.1. 120 | last_epoch = -1 # – The index of last epoch. Default: -1. 121 | elif sch == 'ExponentialLR': 122 | gamma = 0.99 # – Multiplicative factor of learning rate decay. 123 | last_epoch = -1 # – The index of last epoch. Default: -1. 124 | elif sch == 'CosineAnnealingLR': 125 | T_max = 100 # – Maximum number of iterations. Cosine function period. 126 | eta_min = 0.00001 # – Minimum learning rate. Default: 0. 127 | last_epoch = -1 # – The index of last epoch. Default: -1. 128 | elif sch == 'ReduceLROnPlateau': 129 | mode = 'min' # – One of min, max. In min mode, lr will be reduced when the quantity monitored has stopped decreasing; in max mode it will be reduced when the quantity monitored has stopped increasing. Default: ‘min’. 130 | factor = 0.1 # – Factor by which the learning rate will be reduced. new_lr = lr * factor. Default: 0.1. 131 | patience = 10 # – Number of epochs with no improvement after which learning rate will be reduced. For example, if patience = 2, then we will ignore the first 2 epochs with no improvement, and will only decrease the LR after the 3rd epoch if the loss still hasn’t improved then. Default: 10. 132 | threshold = 0.0001 # – Threshold for measuring the new optimum, to only focus on significant changes. Default: 1e-4. 133 | threshold_mode = 'rel' # – One of rel, abs. In rel mode, dynamic_threshold = best * ( 1 + threshold ) in ‘max’ mode or best * ( 1 - threshold ) in min mode. In abs mode, dynamic_threshold = best + threshold in max mode or best - threshold in min mode. Default: ‘rel’. 
134 | cooldown = 0 # – Number of epochs to wait before resuming normal operation after lr has been reduced. Default: 0. 135 | min_lr = 0 # – A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively. Default: 0. 136 | eps = 1e-08 # – Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is ignored. Default: 1e-8. 137 | elif sch == 'CosineAnnealingWarmRestarts': 138 | T_0 = 50 # – Number of iterations for the first restart. 139 | T_mult = 2 # – A factor by which T_{i} increases after a restart. Default: 1. 140 | eta_min = 1e-6 # – Minimum learning rate. Default: 0. 141 | last_epoch = -1 # – The index of last epoch. Default: -1. 142 | elif sch == 'WP_MultiStepLR': 143 | warm_up_epochs = 10 144 | gamma = 0.1 145 | milestones = [125, 225] 146 | elif sch == 'WP_CosineLR': 147 | warm_up_epochs = 20 148 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | You can put the datasets here. -------------------------------------------------------------------------------- /datasets/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import numpy as np 3 | import os 4 | from PIL import Image 5 | 6 | import random 7 | import h5py 8 | import torch 9 | from scipy import ndimage 10 | from scipy.ndimage import zoom 11 | 12 | 13 | 14 | 15 | 16 | class NPY_datasets(Dataset): 17 | def __init__(self, path_Data, config, train=True): 18 | super(NPY_datasets, self).__init__() 19 | if train: 20 | images_list = sorted(os.listdir(path_Data+'train/images/')) 21 | masks_list = sorted(os.listdir(path_Data+'train/masks/')) 22 | self.data = [] 23 | for i in range(len(images_list)): 24 | img_path = path_Data+'train/images/' + images_list[i] 25 | mask_path = path_Data+'train/masks/' + masks_list[i] 26 | self.data.append([img_path, mask_path]) 27 | self.transformer = config.train_transformer 28 | else: 29 | images_list = sorted(os.listdir(path_Data+'val/images/')) 30 | masks_list = sorted(os.listdir(path_Data+'val/masks/')) 31 | self.data = [] 32 | for i in range(len(images_list)): 33 | img_path = path_Data+'val/images/' + images_list[i] 34 | mask_path = path_Data+'val/masks/' + masks_list[i] 35 | self.data.append([img_path, mask_path]) 36 | self.transformer = config.test_transformer 37 | 38 | def __getitem__(self, indx): 39 | img_path, msk_path = self.data[indx] 40 | img = np.array(Image.open(img_path).convert('RGB')) 41 | msk = np.expand_dims(np.array(Image.open(msk_path).convert('L')), axis=2) / 255 42 | img, msk = self.transformer((img, msk)) 43 | return img, msk 44 | 45 | def __len__(self): 46 | return len(self.data) 47 | 48 | 49 | 50 | def random_rot_flip(image, label): 51 | k = np.random.randint(0, 4) 52 | image = np.rot90(image, k) 53 | label = np.rot90(label, k) 54 | axis = np.random.randint(0, 2) 55 | image = np.flip(image, axis=axis).copy() 56 | label = np.flip(label, axis=axis).copy() 57 | return image, label 58 | 59 | 60 | def random_rotate(image, label): 61 | angle = np.random.randint(-20, 20) 62 | image = ndimage.rotate(image, angle, order=0, reshape=False) 63 | label = ndimage.rotate(label, angle, order=0, reshape=False) 64 | return image, label 65 | 66 | 67 | class RandomGenerator(object): 68 | def __init__(self, output_size): 69 | self.output_size = output_size 70 | 71 | def __call__(self, sample): 72 | image, label = sample['image'], sample['label'] 73 | 74 | if random.random() > 0.5: 75 | image, label = random_rot_flip(image, label) 76 | elif random.random() > 0.5: 77 | image, label = random_rotate(image, label) 78 | x, y = image.shape 79 | if x != self.output_size[0] or y != self.output_size[1]: 80 | image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=3) # bicubic interpolation for the image 81 | label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0) # nearest-neighbour for the label 82 | image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0) 83 | label = torch.from_numpy(label.astype(np.float32)) 84 | sample = {'image': image, 'label': label.long()} 85 | return sample 86 | 87 | 88 | class Synapse_dataset(Dataset): 89 | def __init__(self, base_dir, list_dir, split, transform=None): 90 | self.transform = transform # using transform in torch! 91 | self.split = split 92 | self.sample_list = open(os.path.join(list_dir, self.split+'.txt')).readlines() 93 | self.data_dir = base_dir 94 | 95 | def __len__(self): 96 | return len(self.sample_list) 97 | 98 | def __getitem__(self, idx): 99 | if self.split == "train": 100 | slice_name = self.sample_list[idx].strip('\n') 101 | data_path = os.path.join(self.data_dir, slice_name+'.npz') 102 | data = np.load(data_path) 103 | image, label = data['image'], data['label'] 104 | else: 105 | vol_name = self.sample_list[idx].strip('\n') 106 | filepath = self.data_dir + "/{}.npy.h5".format(vol_name) 107 | data = h5py.File(filepath, 'r') 108 | image, label = data['image'][:], data['label'][:] 109 | 110 | sample = {'image': image, 'label': label} 111 | if self.transform: 112 | sample = self.transform(sample) 113 | sample['case_name'] = self.sample_list[idx].strip('\n') 114 | return sample 115 | 116 |
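The two dataset classes above are wired up by `train.py` and `train_synapse.py` together with the config classes; that wiring is not part of this dump, so the following is a minimal usage sketch under the assumption that the configs shown earlier are used unchanged (the `configs.config_setting` import path is inferred from the repo tree).

```python
# Minimal usage sketch (assumed wiring; see train.py / train_synapse.py for
# the repo's actual setup).
from torch.utils.data import DataLoader
from torchvision import transforms

from datasets.dataset import NPY_datasets, Synapse_dataset, RandomGenerator
from configs.config_setting import setting_config

# ISIC: image/mask pairs read from <data_path>/{train,val}/{images,masks}/
train_ds = NPY_datasets(setting_config.data_path, setting_config, train=True)
train_loader = DataLoader(train_ds, batch_size=setting_config.batch_size,
                          shuffle=True, num_workers=setting_config.num_workers)

# Synapse: 2D slices from train_npz/, resized to 224x224 by RandomGenerator
synapse_ds = Synapse_dataset(base_dir='./data/Synapse/train_npz/',
                             list_dir='./data/Synapse/lists/lists_Synapse/',
                             split='train',
                             transform=transforms.Compose(
                                 [RandomGenerator(output_size=[224, 224])]))
```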
-------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import torch 4 | from torch.cuda.amp import autocast as autocast 5 | from sklearn.metrics import confusion_matrix 6 | from utils import save_imgs 7 | 8 | 9 | def train_one_epoch(train_loader, 10 | model, 11 | criterion, 12 | optimizer, 13 | scheduler, 14 | epoch, 15 | step, 16 | logger, 17 | config, 18 | writer): 19 | ''' 20 | train model for one epoch 21 | ''' 22 | # switch to train mode 23 | model.train() 24 | 25 | loss_list = [] 26 | 27 | for iter, data in enumerate(train_loader): 28 | step += iter 29 | optimizer.zero_grad() 30 | images, targets = data 31 | images, targets = images.cuda(non_blocking=True).float(), targets.cuda(non_blocking=True).float() 32 | 33 | out = model(images) 34 | loss = criterion(out, targets) 35 | 36 | loss.backward() 37 | optimizer.step() 38 | 39 | loss_list.append(loss.item()) 40 | 41 | now_lr = optimizer.state_dict()['param_groups'][0]['lr'] 42 | 43 | writer.add_scalar('loss', loss, global_step=step) 44 | 45 | if iter % config.print_interval == 0: 46 | log_info = f'train: epoch {epoch}, iter:{iter}, loss: {np.mean(loss_list):.4f}, lr: {now_lr}' 47 | print(log_info) 48 | logger.info(log_info) 49 | scheduler.step() 50 | return step 51 | 52 | 53 | def val_one_epoch(test_loader, 54 | model, 55 | criterion, 56 | epoch, 57 | logger, 58 | config): 59 | # switch to evaluate mode 60 | model.eval() 61 | preds = [] 62 | gts = [] 63 | loss_list = [] 64 | with torch.no_grad(): 65 | for data in tqdm(test_loader): 66 | img, msk = data 67
| img, msk = img.cuda(non_blocking=True).float(), msk.cuda(non_blocking=True).float() 68 | 69 | out = model(img) 70 | loss = criterion(out, msk) 71 | 72 | loss_list.append(loss.item()) 73 | gts.append(msk.squeeze(1).cpu().detach().numpy()) 74 | if type(out) is tuple: 75 | out = out[0] 76 | out = out.squeeze(1).cpu().detach().numpy() 77 | preds.append(out) 78 | 79 | if epoch % config.val_interval == 0: 80 | preds = np.array(preds).reshape(-1) 81 | gts = np.array(gts).reshape(-1) 82 | 83 | y_pre = np.where(preds>=config.threshold, 1, 0) 84 | y_true = np.where(gts>=0.5, 1, 0) 85 | 86 | confusion = confusion_matrix(y_true, y_pre) 87 | TN, FP, FN, TP = confusion[0,0], confusion[0,1], confusion[1,0], confusion[1,1] 88 | 89 | accuracy = float(TN + TP) / float(np.sum(confusion)) if float(np.sum(confusion)) != 0 else 0 90 | sensitivity = float(TP) / float(TP + FN) if float(TP + FN) != 0 else 0 91 | specificity = float(TN) / float(TN + FP) if float(TN + FP) != 0 else 0 92 | f1_or_dsc = float(2 * TP) / float(2 * TP + FP + FN) if float(2 * TP + FP + FN) != 0 else 0 93 | miou = float(TP) / float(TP + FP + FN) if float(TP + FP + FN) != 0 else 0 94 | 95 | log_info = f'val epoch: {epoch}, loss: {np.mean(loss_list):.4f}, miou: {miou}, f1_or_dsc: {f1_or_dsc}, accuracy: {accuracy}, \ 96 | specificity: {specificity}, sensitivity: {sensitivity}, confusion_matrix: {confusion}' 97 | print(log_info) 98 | logger.info(log_info) 99 | 100 | else: 101 | log_info = f'val epoch: {epoch}, loss: {np.mean(loss_list):.4f}' 102 | print(log_info) 103 | logger.info(log_info) 104 | 105 | return np.mean(loss_list) 106 | 107 | 108 | def test_one_epoch(test_loader, 109 | model, 110 | criterion, 111 | logger, 112 | config, 113 | test_data_name=None): 114 | # switch to evaluate mode 115 | model.eval() 116 | preds = [] 117 | gts = [] 118 | loss_list = [] 119 | with torch.no_grad(): 120 | for i, data in enumerate(tqdm(test_loader)): 121 | img, msk = data 122 | img, msk = img.cuda(non_blocking=True).float(), msk.cuda(non_blocking=True).float() 123 | 124 | out = model(img) 125 | loss = criterion(out, msk) 126 | 127 | loss_list.append(loss.item()) 128 | msk = msk.squeeze(1).cpu().detach().numpy() 129 | gts.append(msk) 130 | if type(out) is tuple: 131 | out = out[0] 132 | out = out.squeeze(1).cpu().detach().numpy() 133 | preds.append(out) 134 | if i % config.save_interval == 0: 135 | save_imgs(img, msk, out, i, config.work_dir + 'outputs/', config.datasets, config.threshold, test_data_name=test_data_name) 136 | 137 | preds = np.array(preds).reshape(-1) 138 | gts = np.array(gts).reshape(-1) 139 | 140 | y_pre = np.where(preds>=config.threshold, 1, 0) 141 | y_true = np.where(gts>=0.5, 1, 0) 142 | 143 | confusion = confusion_matrix(y_true, y_pre) 144 | TN, FP, FN, TP = confusion[0,0], confusion[0,1], confusion[1,0], confusion[1,1] 145 | 146 | accuracy = float(TN + TP) / float(np.sum(confusion)) if float(np.sum(confusion)) != 0 else 0 147 | sensitivity = float(TP) / float(TP + FN) if float(TP + FN) != 0 else 0 148 | specificity = float(TN) / float(TN + FP) if float(TN + FP) != 0 else 0 149 | f1_or_dsc = float(2 * TP) / float(2 * TP + FP + FN) if float(2 * TP + FP + FN) != 0 else 0 150 | miou = float(TP) / float(TP + FP + FN) if float(TP + FP + FN) != 0 else 0 151 | 152 | if test_data_name is not None: 153 | log_info = f'test_datasets_name: {test_data_name}' 154 | print(log_info) 155 | logger.info(log_info) 156 | log_info = f'test of best model, loss: {np.mean(loss_list):.4f},miou: {miou}, f1_or_dsc: {f1_or_dsc}, accuracy: {accuracy}, \ 157 | 
specificity: {specificity}, sensitivity: {sensitivity}, confusion_matrix: {confusion}' 158 | print(log_info) 159 | logger.info(log_info) 160 | 161 | return np.mean(loss_list) -------------------------------------------------------------------------------- /engine_synapse.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | 4 | from torch.cuda.amp import autocast as autocast 5 | import torch 6 | 7 | from sklearn.metrics import confusion_matrix 8 | 9 | 10 | 11 | from utils import test_single_volume 12 | 13 | import time 14 | 15 | 16 | def train_one_epoch(train_loader, 17 | model, 18 | criterion, 19 | optimizer, 20 | scheduler, 21 | epoch, 22 | logger, 23 | config, 24 | scaler=None): 25 | ''' 26 | train model for one epoch 27 | ''' 28 | stime = time.time() 29 | model.train() 30 | 31 | loss_list = [] 32 | 33 | for iter, data in enumerate(train_loader): 34 | optimizer.zero_grad() 35 | 36 | images, targets = data['image'], data['label'] 37 | images, targets = images.cuda(non_blocking=True).float(), targets.cuda(non_blocking=True).float() 38 | 39 | if config.amp: 40 | with autocast(): 41 | out = model(images) 42 | loss = criterion(out, targets) 43 | scaler.scale(loss).backward() 44 | scaler.step(optimizer) 45 | scaler.update() 46 | else: 47 | out = model(images) 48 | loss = criterion(out, targets) 49 | loss.backward() 50 | optimizer.step() 51 | 52 | loss_list.append(loss.item()) 53 | now_lr = optimizer.state_dict()['param_groups'][0]['lr'] 54 | mean_loss = np.mean(loss_list) 55 | if iter % config.print_interval == 0: 56 | log_info = f'train: epoch {epoch}, iter:{iter}, loss: {loss.item():.4f}, lr: {now_lr}' 57 | print(log_info) 58 | logger.info(log_info) 59 | scheduler.step() 60 | etime = time.time() 61 | log_info = f'Finish one epoch train: epoch {epoch}, loss: {mean_loss:.4f}, time(s): {etime-stime:.2f}' 62 | print(log_info) 63 | logger.info(log_info) 64 | return mean_loss 65 | 66 | 67 | 68 | 69 | 70 | def val_one_epoch(test_datasets, 71 | test_loader, 72 | model, 73 | epoch, 74 | logger, 75 | config, 76 | test_save_path, 77 | val_or_test=False): 78 | # switch to evaluate mode 79 | stime = time.time() 80 | model.eval() 81 | with torch.no_grad(): 82 | metric_list = 0.0 83 | i_batch = 0 84 | for data in tqdm(test_loader): 85 | img, msk, case_name = data['image'], data['label'], data['case_name'][0] 86 | metric_i = test_single_volume(img, msk, model, classes=config.num_classes, patch_size=[config.input_size_h, config.input_size_w], 87 | test_save_path=test_save_path, case=case_name, z_spacing=config.z_spacing, val_or_test=val_or_test) 88 | metric_list += np.array(metric_i) 89 | 90 | logger.info('idx %d case %s mean_dice %f mean_hd95 %f' % (i_batch, case_name, 91 | np.mean(metric_i, axis=0)[0], np.mean(metric_i, axis=0)[1])) 92 | i_batch += 1 93 | metric_list = metric_list / len(test_datasets) 94 | performance = np.mean(metric_list, axis=0)[0] 95 | mean_hd95 = np.mean(metric_list, axis=0)[1] 96 | for i in range(1, config.num_classes): 97 | logger.info('Mean class %d mean_dice %f mean_hd95 %f' % (i, metric_list[i-1][0], metric_list[i-1][1])) 98 | 99 | 100 | etime = time.time() 101 | log_info = f'val epoch: {epoch}, mean_dice: {performance}, mean_hd95: {mean_hd95}, time(s): {etime-stime:.2f}' 102 | print(log_info) 103 | logger.info(log_info) 104 | 105 | return performance, mean_hd95
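`vmamba.py` below opens with `flops_selective_scan_ref`, which counts the FLOPs of the selective scan by feeding its einsum equations to `np.einsum_path`; the scanned computation itself appears there only as commented-out reference code. For orientation, here is that recurrence condensed into a runnable sketch. Shapes follow the function's docstring; this is illustrative, not the repo's optimized kernel.

```python
# Reference recurrence that flops_selective_scan_ref (below) is counting,
# condensed from the commented-out code inside that function. Shapes follow
# its docstring: u, delta (B,D,L); A (D,N); Bmat, C (B,N,L); Dvec (D,).
import torch

def selective_scan_reference(u, delta, A, Bmat, C, Dvec):
    Bsz, D, L = u.shape
    N = A.shape[1]
    deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))    # discretized state matrix
    deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, Bmat, u)  # discretized input term
    x = u.new_zeros(Bsz, D, N)
    ys = []
    for i in range(L):                                            # sequential state update
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        ys.append(torch.einsum('bdn,bn->bd', x, C[:, :, i]))
    y = torch.stack(ys, dim=2)                                    # (B, D, L)
    return y + u * Dvec[None, :, None]                            # skip connection via D

# Tiny smoke test (sizes scaled down from the docstring defaults)
B_, D_, L_, N_ = 1, 4, 8, 2
y = selective_scan_reference(torch.randn(B_, D_, L_), torch.rand(B_, D_, L_),
                             -torch.rand(D_, N_), torch.randn(B_, N_, L_),
                             torch.randn(B_, N_, L_), torch.randn(D_))
print(y.shape)  # torch.Size([1, 4, 8])
```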
-------------------------------------------------------------------------------- /models/vmunet/vmamba.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | from functools import partial 4 | from typing import Optional, Callable 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.checkpoint as checkpoint 10 | from einops import rearrange, repeat 11 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 12 | try: 13 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref 14 | except ImportError: 15 | pass 16 | 17 | # an alternative to mamba_ssm (which requires causal_conv1d) 18 | try: 19 | from selective_scan import selective_scan_fn as selective_scan_fn_v1 20 | from selective_scan import selective_scan_ref as selective_scan_ref_v1 21 | except ImportError: 22 | pass 23 | 24 | DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})" 25 | 26 | 27 | def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False): 28 | """ 29 | u: r(B D L) 30 | delta: r(B D L) 31 | A: r(D N) 32 | B: r(B N L) 33 | C: r(B N L) 34 | D: r(D) 35 | z: r(B D L) 36 | delta_bias: r(D), fp32 37 | 38 | ignores: 39 | [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] 40 | """ 41 | import numpy as np 42 | 43 | # fvcore.nn.jit_handles 44 | def get_flops_einsum(input_shapes, equation): 45 | np_arrs = [np.zeros(s) for s in input_shapes] 46 | optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1] 47 | for line in optim.split("\n"): 48 | if "optimized flop" in line.lower(): 49 | # divided by 2 because we count MAC (multiply-add counted as one flop) 50 | flop = float(np.floor(float(line.split(":")[-1]) / 2)) 51 | return flop 52 | 53 | 54 | assert not with_complex 55 | 56 | flops = 0 # below code flops = 0 57 | if False: 58 | ... 59 | """ 60 | dtype_in = u.dtype 61 | u = u.float() 62 | delta = delta.float() 63 | if delta_bias is not None: 64 | delta = delta + delta_bias[..., None].float() 65 | if delta_softplus: 66 | delta = F.softplus(delta) 67 | batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] 68 | is_variable_B = B.dim() >= 3 69 | is_variable_C = C.dim() >= 3 70 | if A.is_complex(): 71 | if is_variable_B: 72 | B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) 73 | if is_variable_C: 74 | C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2)) 75 | else: 76 | B = B.float() 77 | C = C.float() 78 | x = A.new_zeros((batch, dim, dstate)) 79 | ys = [] 80 | """ 81 | 82 | flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln") 83 | if with_Group: 84 | flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln") 85 | else: 86 | flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln") 87 | if False: 88
89 | """ 90 | deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) 91 | if not is_variable_B: 92 | deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) 93 | else: 94 | if B.dim() == 3: 95 | deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) 96 | else: 97 | B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) 98 | deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) 99 | if is_variable_C and C.dim() == 4: 100 | C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) 101 | last_state = None 102 | """ 103 | 104 | in_for_flops = B * D * N 105 | if with_Group: 106 | in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd") 107 | else: 108 | in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd") 109 | flops += L * in_for_flops 110 | if False: 111 | ... 112 | """ 113 | for i in range(u.shape[2]): 114 | x = deltaA[:, :, i] * x + deltaB_u[:, :, i] 115 | if not is_variable_C: 116 | y = torch.einsum('bdn,dn->bd', x, C) 117 | else: 118 | if C.dim() == 3: 119 | y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) 120 | else: 121 | y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) 122 | if i == u.shape[2] - 1: 123 | last_state = x 124 | if y.is_complex(): 125 | y = y.real * 2 126 | ys.append(y) 127 | y = torch.stack(ys, dim=2) # (batch dim L) 128 | """ 129 | 130 | if with_D: 131 | flops += B * D * L 132 | if with_Z: 133 | flops += B * D * L 134 | if False: 135 | ... 136 | """ 137 | out = y if D is None else y + u * rearrange(D, "d -> d 1") 138 | if z is not None: 139 | out = out * F.silu(z) 140 | out = out.to(dtype=dtype_in) 141 | """ 142 | 143 | return flops 144 | 145 | 146 | class PatchEmbed2D(nn.Module): 147 | r""" Image to Patch Embedding 148 | Args: 149 | patch_size (int): Patch token size. Default: 4. 150 | in_chans (int): Number of input image channels. Default: 3. 151 | embed_dim (int): Number of linear projection output channels. Default: 96. 152 | norm_layer (nn.Module, optional): Normalization layer. Default: None 153 | """ 154 | def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs): 155 | super().__init__() 156 | if isinstance(patch_size, int): 157 | patch_size = (patch_size, patch_size) 158 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 159 | if norm_layer is not None: 160 | self.norm = norm_layer(embed_dim) 161 | else: 162 | self.norm = None 163 | 164 | def forward(self, x): 165 | x = self.proj(x).permute(0, 2, 3, 1) 166 | if self.norm is not None: 167 | x = self.norm(x) 168 | return x 169 | 170 | 171 | class PatchMerging2D(nn.Module): 172 | r""" Patch Merging Layer. 173 | Args: 174 | input_resolution (tuple[int]): Resolution of input feature. 175 | dim (int): Number of input channels. 176 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm
177 |     """
178 | 
179 |     def __init__(self, dim, norm_layer=nn.LayerNorm):
180 |         super().__init__()
181 |         self.dim = dim
182 |         self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
183 |         self.norm = norm_layer(4 * dim)
184 | 
185 |     def forward(self, x):
186 |         B, H, W, C = x.shape
187 | 
188 |         SHAPE_FIX = [-1, -1]
189 |         if (W % 2 != 0) or (H % 2 != 0):
190 |             print(f"Warning: x.shape {x.shape} does not have even H and W; cropping to even size ===========", flush=True)
191 |             SHAPE_FIX[0] = H // 2
192 |             SHAPE_FIX[1] = W // 2
193 | 
194 |         x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
195 |         x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
196 |         x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
197 |         x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
198 | 
199 |         if SHAPE_FIX[0] > 0:
200 |             x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
201 |             x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
202 |             x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
203 |             x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
204 | 
205 |         x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
206 |         x = x.view(B, H//2, W//2, 4 * C)  # B H/2 W/2 4*C
207 | 
208 |         x = self.norm(x)
209 |         x = self.reduction(x)
210 | 
211 |         return x
212 | 
213 | 
214 | class PatchExpand2D(nn.Module):
215 |     def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm):
216 |         super().__init__()
217 |         self.dim = dim*2
218 |         self.dim_scale = dim_scale
219 |         self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False)
220 |         self.norm = norm_layer(self.dim // dim_scale)
221 | 
222 |     def forward(self, x):
223 |         B, H, W, C = x.shape
224 |         x = self.expand(x)
225 | 
226 |         x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale)
227 |         x = self.norm(x)
228 | 
229 |         return x
230 | 
231 | 
232 | class Final_PatchExpand2D(nn.Module):
233 |     def __init__(self, dim, dim_scale=4, norm_layer=nn.LayerNorm):
234 |         super().__init__()
235 |         self.dim = dim
236 |         self.dim_scale = dim_scale
237 |         self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False)
238 |         self.norm = norm_layer(self.dim // dim_scale)
239 | 
240 |     def forward(self, x):
241 |         B, H, W, C = x.shape
242 |         x = self.expand(x)
243 | 
244 |         x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale)
245 |         x = self.norm(x)
246 | 
247 |         return x
248 | 
249 | 
250 | class SS2D(nn.Module):
251 |     def __init__(
252 |         self,
253 |         d_model,
254 |         d_state=16,
255 |         # d_state="auto", # 20240109
256 |         d_conv=3,
257 |         expand=2,
258 |         dt_rank="auto",
259 |         dt_min=0.001,
260 |         dt_max=0.1,
261 |         dt_init="random",
262 |         dt_scale=1.0,
263 |         dt_init_floor=1e-4,
264 |         dropout=0.,
265 |         conv_bias=True,
266 |         bias=False,
267 |         device=None,
268 |         dtype=None,
269 |         **kwargs,
270 |     ):
271 |         factory_kwargs = {"device": device, "dtype": dtype}
272 |         super().__init__()
273 |         self.d_model = d_model
274 |         self.d_state = d_state
275 |         # self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_model # 20240109
276 |         self.d_conv = d_conv
277 |         self.expand = expand
278 |         self.d_inner = int(self.expand * self.d_model)
279 |         self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
280 | 
281 |         self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
282 |         self.conv2d = nn.Conv2d(
283 |             in_channels=self.d_inner,
284 |             out_channels=self.d_inner,
285 |             groups=self.d_inner,
286 |             bias=conv_bias,
287 |             kernel_size=d_conv,
288 |             padding=(d_conv - 1) // 2,
289 |             **factory_kwargs,
290 |         )
291 |         self.act = nn.SiLU()
292 | 
293 |         self.x_proj = (
294 | 
nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 295 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 296 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 297 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 298 | ) 299 | self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=4, N, inner) 300 | del self.x_proj 301 | 302 | self.dt_projs = ( 303 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), 304 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), 305 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), 306 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), 307 | ) 308 | self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=4, inner, rank) 309 | self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=4, inner) 310 | del self.dt_projs 311 | 312 | self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True) # (K=4, D, N) 313 | self.Ds = self.D_init(self.d_inner, copies=4, merge=True) # (K=4, D, N) 314 | 315 | # self.selective_scan = selective_scan_fn 316 | self.forward_core = self.forward_corev0 317 | 318 | self.out_norm = nn.LayerNorm(self.d_inner) 319 | self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) 320 | self.dropout = nn.Dropout(dropout) if dropout > 0. else None 321 | 322 | @staticmethod 323 | def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs): 324 | dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs) 325 | 326 | # Initialize special dt projection to preserve variance at initialization 327 | dt_init_std = dt_rank**-0.5 * dt_scale 328 | if dt_init == "constant": 329 | nn.init.constant_(dt_proj.weight, dt_init_std) 330 | elif dt_init == "random": 331 | nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) 332 | else: 333 | raise NotImplementedError 334 | 335 | # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max 336 | dt = torch.exp( 337 | torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) 338 | + math.log(dt_min) 339 | ).clamp(min=dt_init_floor) 340 | # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 341 | inv_dt = dt + torch.log(-torch.expm1(-dt)) 342 | with torch.no_grad(): 343 | dt_proj.bias.copy_(inv_dt) 344 | # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit 345 | dt_proj.bias._no_reinit = True 346 | 347 | return dt_proj 348 | 349 | @staticmethod 350 | def A_log_init(d_state, d_inner, copies=1, device=None, merge=True): 351 | # S4D real initialization 352 | A = repeat( 353 | torch.arange(1, d_state + 1, dtype=torch.float32, device=device), 354 | "n -> d n", 355 | d=d_inner, 356 | ).contiguous() 357 | A_log = torch.log(A) # Keep A_log in fp32 358 | if copies > 1: 359 | A_log = repeat(A_log, "d n -> r d n", r=copies) 360 | if merge: 361 | A_log = A_log.flatten(0, 1) 362 | A_log = nn.Parameter(A_log) 363 | A_log._no_weight_decay = True 364 | return A_log 365 | 366 | @staticmethod 367 | def D_init(d_inner, 
copies=1, device=None, merge=True):
368 |         # D "skip" parameter
369 |         D = torch.ones(d_inner, device=device)
370 |         if copies > 1:
371 |             D = repeat(D, "n1 -> r n1", r=copies)
372 |             if merge:
373 |                 D = D.flatten(0, 1)
374 |         D = nn.Parameter(D)  # Keep in fp32
375 |         D._no_weight_decay = True
376 |         return D
377 | 
378 |     def forward_corev0(self, x: torch.Tensor):
379 |         self.selective_scan = selective_scan_fn
380 | 
381 |         B, C, H, W = x.shape
382 |         L = H * W
383 |         K = 4
384 | 
385 |         x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
386 |         xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)
387 | 
388 |         x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
389 |         # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
390 |         dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
391 |         dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
392 |         # dts = dts + self.dt_projs_bias.view(1, K, -1, 1)
393 | 
394 |         xs = xs.float().view(B, -1, L) # (b, k * d, l)
395 |         dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
396 |         Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l)
397 |         Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l)
398 |         Ds = self.Ds.float().view(-1) # (k * d)
399 |         As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state)
400 |         dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)
401 | 
402 |         out_y = self.selective_scan(
403 |             xs, dts,
404 |             As, Bs, Cs, Ds, z=None,
405 |             delta_bias=dt_projs_bias,
406 |             delta_softplus=True,
407 |             return_last_state=False,
408 |         ).view(B, K, -1, L)
409 |         assert out_y.dtype == torch.float
410 | 
411 |         inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
412 |         wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
413 |         invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
414 | 
415 |         return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y
416 | 
417 |     # an alternative to forward_corev0, using the standalone selective_scan package
418 |     def forward_corev1(self, x: torch.Tensor):
419 |         self.selective_scan = selective_scan_fn_v1
420 | 
421 |         B, C, H, W = x.shape
422 |         L = H * W
423 |         K = 4
424 | 
425 |         x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
426 |         xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)
427 | 
428 |         x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
429 |         # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
430 |         dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
431 |         dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
432 |         # dts = dts + self.dt_projs_bias.view(1, K, -1, 1)
433 | 
434 |         xs = xs.float().view(B, -1, L) # (b, k * d, l)
435 |         dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
436 |         Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l)
437 |         Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l)
438 |         Ds = self.Ds.float().view(-1) # (k * d)
439 |         As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state)
440 |         dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)
441 | 
442 |         out_y = self.selective_scan(
443 |             xs, dts,
444 |             As, Bs, Cs, Ds,
445 |             delta_bias=dt_projs_bias,
446 |             delta_softplus=True,
447 |         ).view(B, K, -1, L)
448 |         assert out_y.dtype == torch.float
449 | 
450 |         inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
451 |         wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
452 |         invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
453 | 
454 |         return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y
455 | 
456 |     def forward(self, x: torch.Tensor, **kwargs):
457 |         B, H, W, C = x.shape
458 | 
459 |         xz = self.in_proj(x)
460 |         x, z = xz.chunk(2, dim=-1) # (b, h, w, d)
461 | 
462 |         x = x.permute(0, 3, 1, 2).contiguous()
463 |         x = self.act(self.conv2d(x)) # (b, d, h, w)
464 |         y1, y2, y3, y4 = self.forward_core(x)
465 |         assert y1.dtype == torch.float32
466 |         y = y1 + y2 + y3 + y4
467 |         y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
468 |         y = self.out_norm(y)
469 |         y = y * F.silu(z)
470 |         out = self.out_proj(y)
471 |         if self.dropout is not None:
472 |             out = self.dropout(out)
473 |         return out
474 | 
475 | 
476 | class VSSBlock(nn.Module):
477 |     def __init__(
478 |         self,
479 |         hidden_dim: int = 0,
480 |         drop_path: float = 0,
481 |         norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
482 |         attn_drop_rate: float = 0,
483 |         d_state: int = 16,
484 |         **kwargs,
485 |     ):
486 |         super().__init__()
487 |         self.ln_1 = norm_layer(hidden_dim)
488 |         self.self_attention = SS2D(d_model=hidden_dim, dropout=attn_drop_rate, d_state=d_state, **kwargs)
489 |         self.drop_path = DropPath(drop_path)
490 | 
491 |     def forward(self, input: torch.Tensor):
492 |         x = input + self.drop_path(self.self_attention(self.ln_1(input)))
493 |         return x
494 | 
495 | 
496 | class VSSLayer(nn.Module):
497 |     """ A basic VSS layer for one stage, adapted from the Swin Transformer layer.
498 |     Args:
499 |         dim (int): Number of input channels.
500 |         depth (int): Number of blocks.
501 |         drop (float, optional): Dropout rate. Default: 0.0
502 |         attn_drop (float, optional): Attention dropout rate. Default: 0.0
503 |         drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
504 |         norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
505 |         downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
506 |         use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
507 |     """
508 | 
509 |     def __init__(
510 |         self,
511 |         dim,
512 |         depth,
513 |         attn_drop=0.,
514 |         drop_path=0.,
515 |         norm_layer=nn.LayerNorm,
516 |         downsample=None,
517 |         use_checkpoint=False,
518 |         d_state=16,
519 |         **kwargs,
520 |     ):
521 |         super().__init__()
522 |         self.dim = dim
523 |         self.use_checkpoint = use_checkpoint
524 | 
525 |         self.blocks = nn.ModuleList([
526 |             VSSBlock(
527 |                 hidden_dim=dim,
528 |                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
529 |                 norm_layer=norm_layer,
530 |                 attn_drop_rate=attn_drop,
531 |                 d_state=d_state,
532 |             )
533 |             for i in range(depth)])
534 | 
535 |         if True: # is this really applied? Yes, but it is overridden later in VSSM!
536 |             def _init_weights(module: nn.Module):
537 |                 for name, p in module.named_parameters():
538 |                     if name in ["out_proj.weight"]:
539 |                         p = p.clone().detach_() # fake init, just to keep the seed ....
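                        # the kaiming init below is applied to the detached clone, so the real
                        # parameter is untouched; it only consumes RNG state so seeding stays
                        # reproducible (hence "fake init"), and VSSM._init_weights performs
                        # the actual initialization afterwards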
540 |                         nn.init.kaiming_uniform_(p, a=math.sqrt(5))
541 |         self.apply(_init_weights)
542 | 
543 |         if downsample is not None:
544 |             self.downsample = downsample(dim=dim, norm_layer=norm_layer)
545 |         else:
546 |             self.downsample = None
547 | 
548 | 
549 |     def forward(self, x):
550 |         for blk in self.blocks:
551 |             if self.use_checkpoint:
552 |                 x = checkpoint.checkpoint(blk, x)
553 |             else:
554 |                 x = blk(x)
555 | 
556 |         if self.downsample is not None:
557 |             x = self.downsample(x)
558 | 
559 |         return x
560 | 
561 | 
562 | 
563 | class VSSLayer_up(nn.Module):
564 |     """ A basic VSS layer for one stage of the decoder, adapted from the Swin Transformer layer.
565 |     Args:
566 |         dim (int): Number of input channels.
567 |         depth (int): Number of blocks.
568 |         drop (float, optional): Dropout rate. Default: 0.0
569 |         attn_drop (float, optional): Attention dropout rate. Default: 0.0
570 |         drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
571 |         norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
572 |         upsample (nn.Module | None, optional): Upsample layer at the start of the layer. Default: None
573 |         use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
574 |     """
575 | 
576 |     def __init__(
577 |         self,
578 |         dim,
579 |         depth,
580 |         attn_drop=0.,
581 |         drop_path=0.,
582 |         norm_layer=nn.LayerNorm,
583 |         upsample=None,
584 |         use_checkpoint=False,
585 |         d_state=16,
586 |         **kwargs,
587 |     ):
588 |         super().__init__()
589 |         self.dim = dim
590 |         self.use_checkpoint = use_checkpoint
591 | 
592 |         self.blocks = nn.ModuleList([
593 |             VSSBlock(
594 |                 hidden_dim=dim,
595 |                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
596 |                 norm_layer=norm_layer,
597 |                 attn_drop_rate=attn_drop,
598 |                 d_state=d_state,
599 |             )
600 |             for i in range(depth)])
601 | 
602 |         if True: # is this really applied? Yes, but it is overridden later in VSSM!
603 |             def _init_weights(module: nn.Module):
604 |                 for name, p in module.named_parameters():
605 |                     if name in ["out_proj.weight"]:
606 |                         p = p.clone().detach_() # fake init, just to keep the seed ....
607 | nn.init.kaiming_uniform_(p, a=math.sqrt(5)) 608 | self.apply(_init_weights) 609 | 610 | if upsample is not None: 611 | self.upsample = upsample(dim=dim, norm_layer=norm_layer) 612 | else: 613 | self.upsample = None 614 | 615 | 616 | def forward(self, x): 617 | if self.upsample is not None: 618 | x = self.upsample(x) 619 | for blk in self.blocks: 620 | if self.use_checkpoint: 621 | x = checkpoint.checkpoint(blk, x) 622 | else: 623 | x = blk(x) 624 | return x 625 | 626 | 627 | 628 | class VSSM(nn.Module): 629 | def __init__(self, patch_size=4, in_chans=3, num_classes=1000, depths=[2, 2, 9, 2], depths_decoder=[2, 9, 2, 2], 630 | dims=[96, 192, 384, 768], dims_decoder=[768, 384, 192, 96], d_state=16, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 631 | norm_layer=nn.LayerNorm, patch_norm=True, 632 | use_checkpoint=False, **kwargs): 633 | super().__init__() 634 | self.num_classes = num_classes 635 | self.num_layers = len(depths) 636 | if isinstance(dims, int): 637 | dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)] 638 | self.embed_dim = dims[0] 639 | self.num_features = dims[-1] 640 | self.dims = dims 641 | 642 | self.patch_embed = PatchEmbed2D(patch_size=patch_size, in_chans=in_chans, embed_dim=self.embed_dim, 643 | norm_layer=norm_layer if patch_norm else None) 644 | 645 | # WASTED absolute position embedding ====================== 646 | self.ape = False 647 | # self.ape = False 648 | # drop_rate = 0.0 649 | if self.ape: 650 | self.patches_resolution = self.patch_embed.patches_resolution 651 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, *self.patches_resolution, self.embed_dim)) 652 | trunc_normal_(self.absolute_pos_embed, std=.02) 653 | self.pos_drop = nn.Dropout(p=drop_rate) 654 | 655 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 656 | dpr_decoder = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_decoder))][::-1] 657 | 658 | self.layers = nn.ModuleList() 659 | for i_layer in range(self.num_layers): 660 | layer = VSSLayer( 661 | dim=dims[i_layer], 662 | depth=depths[i_layer], 663 | d_state=math.ceil(dims[0] / 6) if d_state is None else d_state, # 20240109 664 | drop=drop_rate, 665 | attn_drop=attn_drop_rate, 666 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 667 | norm_layer=norm_layer, 668 | downsample=PatchMerging2D if (i_layer < self.num_layers - 1) else None, 669 | use_checkpoint=use_checkpoint, 670 | ) 671 | self.layers.append(layer) 672 | 673 | self.layers_up = nn.ModuleList() 674 | for i_layer in range(self.num_layers): 675 | layer = VSSLayer_up( 676 | dim=dims_decoder[i_layer], 677 | depth=depths_decoder[i_layer], 678 | d_state=math.ceil(dims[0] / 6) if d_state is None else d_state, # 20240109 679 | drop=drop_rate, 680 | attn_drop=attn_drop_rate, 681 | drop_path=dpr_decoder[sum(depths_decoder[:i_layer]):sum(depths_decoder[:i_layer + 1])], 682 | norm_layer=norm_layer, 683 | upsample=PatchExpand2D if (i_layer != 0) else None, 684 | use_checkpoint=use_checkpoint, 685 | ) 686 | self.layers_up.append(layer) 687 | 688 | self.final_up = Final_PatchExpand2D(dim=dims_decoder[-1], dim_scale=4, norm_layer=norm_layer) 689 | self.final_conv = nn.Conv2d(dims_decoder[-1]//4, num_classes, 1) 690 | 691 | # self.norm = norm_layer(self.num_features) 692 | # self.avgpool = nn.AdaptiveAvgPool1d(1) 693 | # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 694 | 695 | self.apply(self._init_weights) 696 | 697 | def 
_init_weights(self, m: nn.Module):
698 |         """
699 |         out_proj.weight, which was previously initialized in VSSBlock, is overwritten here by the nn.Linear branch
700 |         no fc.weight is found in any of the model parameters
701 |         no nn.Embedding is found in any of the model parameters
702 |         so the VSSBlock initialization is effectively useless
703 | 
704 |         Conv2D is not initialized !!!
705 |         """
706 |         if isinstance(m, nn.Linear):
707 |             trunc_normal_(m.weight, std=.02)
708 |             if isinstance(m, nn.Linear) and m.bias is not None:
709 |                 nn.init.constant_(m.bias, 0)
710 |         elif isinstance(m, nn.LayerNorm):
711 |             nn.init.constant_(m.bias, 0)
712 |             nn.init.constant_(m.weight, 1.0)
713 | 
714 |     @torch.jit.ignore
715 |     def no_weight_decay(self):
716 |         return {'absolute_pos_embed'}
717 | 
718 |     @torch.jit.ignore
719 |     def no_weight_decay_keywords(self):
720 |         return {'relative_position_bias_table'}
721 | 
722 |     def forward_features(self, x):
723 |         skip_list = []
724 |         x = self.patch_embed(x)
725 |         if self.ape:
726 |             x = x + self.absolute_pos_embed
727 |         x = self.pos_drop(x)
728 | 
729 |         for layer in self.layers:
730 |             skip_list.append(x)
731 |             x = layer(x)
732 |         return x, skip_list
733 | 
734 |     def forward_features_up(self, x, skip_list):
735 |         for inx, layer_up in enumerate(self.layers_up):
736 |             if inx == 0:
737 |                 x = layer_up(x)
738 |             else:
739 |                 x = layer_up(x+skip_list[-inx])
740 | 
741 |         return x
742 | 
743 |     def forward_final(self, x):
744 |         x = self.final_up(x)
745 |         x = x.permute(0,3,1,2)
746 |         x = self.final_conv(x)
747 |         return x
748 | 
749 |     def forward_backbone(self, x):
750 |         x = self.patch_embed(x)
751 |         if self.ape:
752 |             x = x + self.absolute_pos_embed
753 |         x = self.pos_drop(x)
754 | 
755 |         for layer in self.layers:
756 |             x = layer(x)
757 |         return x
758 | 
759 |     def forward(self, x):
760 |         x, skip_list = self.forward_features(x)
761 |         x = self.forward_features_up(x, skip_list)
762 |         x = self.forward_final(x)
763 | 
764 |         return x
765 | 
766 | 
767 | 
768 | 
769 | 
770 | 
771 | 
772 | 
--------------------------------------------------------------------------------
/models/vmunet/vmunet.py:
--------------------------------------------------------------------------------
1 | from .vmamba import VSSM
2 | import torch
3 | from torch import nn
4 | 
5 | 
6 | class VMUNet(nn.Module):
7 |     def __init__(self,
8 |                  input_channels=3,
9 |                  num_classes=1,
10 |                  depths=[2, 2, 9, 2],
11 |                  depths_decoder=[2, 9, 2, 2],
12 |                  drop_path_rate=0.2,
13 |                  load_ckpt_path=None,
14 |                 ):
15 |         super().__init__()
16 | 
17 |         self.load_ckpt_path = load_ckpt_path
18 |         self.num_classes = num_classes
19 | 
20 |         self.vmunet = VSSM(in_chans=input_channels,
21 |                            num_classes=num_classes,
22 |                            depths=depths,
23 |                            depths_decoder=depths_decoder,
24 |                            drop_path_rate=drop_path_rate,
25 |                         )
26 | 
27 |     def forward(self, x):
28 |         if x.size()[1] == 1:
29 |             x = x.repeat(1,3,1,1)
30 |         logits = self.vmunet(x)
31 |         if self.num_classes == 1: return torch.sigmoid(logits)
32 |         else: return logits
33 | 
34 |     def load_from(self):
35 |         if self.load_ckpt_path is not None:
36 |             model_dict = self.vmunet.state_dict()
37 |             modelCheckpoint = torch.load(self.load_ckpt_path)
38 |             pretrained_dict = modelCheckpoint['model']
39 |             # filter: keep only the keys that exist in the current model
40 |             new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
41 |             model_dict.update(new_dict)
42 |             # report how many parameters were updated
43 |             print('Total model_dict: {}, Total pretrained_dict: {}, update: {}'.format(len(model_dict), len(pretrained_dict), len(new_dict)))
44 |             self.vmunet.load_state_dict(model_dict)
45 | 
46 |             not_loaded_keys = [k for k in pretrained_dict.keys() if k not in new_dict.keys()]
47 |             print('Not loaded keys:', not_loaded_keys)
48 |             print("encoder loading finished!")
49 | 
50 |             model_dict = self.vmunet.state_dict()
51 |             modelCheckpoint = torch.load(self.load_ckpt_path)
52 |             pretrained_odict = modelCheckpoint['model']
53 |             pretrained_dict = {}
54 |             for k, v in pretrained_odict.items():
55 |                 if 'layers.0' in k:
56 |                     new_k = k.replace('layers.0', 'layers_up.3')
57 |                     pretrained_dict[new_k] = v
58 |                 elif 'layers.1' in k:
59 |                     new_k = k.replace('layers.1', 'layers_up.2')
60 |                     pretrained_dict[new_k] = v
61 |                 elif 'layers.2' in k:
62 |                     new_k = k.replace('layers.2', 'layers_up.1')
63 |                     pretrained_dict[new_k] = v
64 |                 elif 'layers.3' in k:
65 |                     new_k = k.replace('layers.3', 'layers_up.0')
66 |                     pretrained_dict[new_k] = v
67 |             # filter: keep only the keys that exist in the current model
68 |             new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
69 |             model_dict.update(new_dict)
70 |             # report how many parameters were updated
71 |             print('Total model_dict: {}, Total pretrained_dict: {}, update: {}'.format(len(model_dict), len(pretrained_dict), len(new_dict)))
72 |             self.vmunet.load_state_dict(model_dict)
73 | 
74 |             # find the keys that were not loaded
75 |             not_loaded_keys = [k for k in pretrained_dict.keys() if k not in new_dict.keys()]
76 |             print('Not loaded keys:', not_loaded_keys)
77 |             print("decoder loading finished!")
--------------------------------------------------------------------------------
/pre_trained_weights/README.md:
--------------------------------------------------------------------------------
1 | You can download the weights used by VM-UNet from [Baidu Drive](https://pan.baidu.com/s/1ci_YvPPEiUT2bIIK5x8Igw?pwd=wnyy) or [Google Drive](https://drive.google.com/drive/folders/1Fr7zM1wq7106d0P7_3oeU5UZqUvk2KaP?usp=sharing)
2 | 
--------------------------------------------------------------------------------
/results/README.md:
--------------------------------------------------------------------------------
1 | The results of the experiments can be found here.
2 | 
3 | Our results on the ISIC17, ISIC18 and Synapse datasets can be found [here](https://pan.baidu.com/s/1FzDpfBQGz56BH6v2APoymA?pwd=8bqf)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | import timm
4 | from datasets.dataset import NPY_datasets
5 | from tensorboardX import SummaryWriter
6 | from models.vmunet.vmunet import VMUNet
7 | 
8 | from engine import *
9 | import os
10 | import sys
11 | 
12 | from utils import *
13 | from configs.config_setting import setting_config
14 | 
15 | import warnings
16 | warnings.filterwarnings("ignore")
17 | 
18 | 
19 | 
20 | def main(config):
21 | 
22 |     print('#----------Creating logger----------#')
23 |     sys.path.append(config.work_dir + '/')
24 |     log_dir = os.path.join(config.work_dir, 'log')
25 |     checkpoint_dir = os.path.join(config.work_dir, 'checkpoints')
26 |     resume_model = os.path.join(checkpoint_dir, 'latest.pth')
27 |     outputs = os.path.join(config.work_dir, 'outputs')
28 |     if not os.path.exists(checkpoint_dir):
29 |         os.makedirs(checkpoint_dir)
30 |     if not os.path.exists(outputs):
31 |         os.makedirs(outputs)
32 | 
33 |     global logger
34 |     logger = get_logger('train', log_dir)
35 |     global writer
36 |     writer = SummaryWriter(config.work_dir + 'summary')
37 | 
38 |     log_config_info(config, logger)
39 | 
40 | 
41 | 
42 | 
43 | 
44 |     print('#----------GPU init----------#')
45 |     os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu_id
46 |     set_seed(config.seed)
47 |     torch.cuda.empty_cache()
48 | 
49 | 
50 | 
51 | 
52 | 
53 |     print('#----------Preparing dataset----------#')
54 |     train_dataset = NPY_datasets(config.data_path, config, train=True)
55 |     train_loader = DataLoader(train_dataset,
56 |                               batch_size=config.batch_size,
57 |                               shuffle=True,
58 |                               pin_memory=True,
59 |                               num_workers=config.num_workers)
60 |     val_dataset = NPY_datasets(config.data_path, config, train=False)
61 |     val_loader = DataLoader(val_dataset,
62 |                             batch_size=1,
63 |                             shuffle=False,
64 |                             pin_memory=True,
65 |                             num_workers=config.num_workers,
66 |                             drop_last=True)
67 | 
68 | 
69 | 
70 | 
71 | 
72 |     print('#----------Preparing Model----------#')
73 |     model_cfg = config.model_config
74 |     if config.network == 'vmunet':
75 |         model = VMUNet(
76 |             num_classes=model_cfg['num_classes'],
77 |             input_channels=model_cfg['input_channels'],
78 |             depths=model_cfg['depths'],
79 |             depths_decoder=model_cfg['depths_decoder'],
80 |             drop_path_rate=model_cfg['drop_path_rate'],
81 |             load_ckpt_path=model_cfg['load_ckpt_path'],
82 |         )
83 |         model.load_from()
84 | 
85 |     else: raise Exception('network is not right!')
86 |     model = model.cuda()
87 | 
88 |     cal_params_flops(model, 256, logger)
89 | 
90 | 
91 | 
92 | 
93 | 
94 |     print('#----------Preparing loss, opt, sch and amp----------#')
95 |     criterion = config.criterion
96 |     optimizer = get_optimizer(config, model)
97 |     scheduler = get_scheduler(config, optimizer)
98 | 
99 | 
100 | 
101 | 
102 | 
103 |     print('#----------Set other params----------#')
104 |     min_loss = 999
105 |     start_epoch = 1
106 |     min_epoch = 1
107 | 
108 |     if config.only_test_and_save_figs:
109 |         checkpoint = torch.load(config.best_ckpt_path, map_location=torch.device('cpu'))
110 |         model.load_state_dict(checkpoint)
111 |         config.work_dir = config.img_save_path
112 |         if not os.path.exists(config.work_dir + 'outputs/'):
113 |             os.makedirs(config.work_dir + 'outputs/')
114 |         loss = test_one_epoch(
115 |                 val_loader,
116 |                 model,
117 |                 criterion,
118 |                 logger,
119 | config, 120 | ) 121 | return 122 | 123 | 124 | 125 | 126 | if os.path.exists(resume_model): 127 | print('#----------Resume Model and Other params----------#') 128 | checkpoint = torch.load(resume_model, map_location=torch.device('cpu')) 129 | model.load_state_dict(checkpoint['model_state_dict']) 130 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 131 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 132 | saved_epoch = checkpoint['epoch'] 133 | start_epoch += saved_epoch 134 | min_loss, min_epoch, loss = checkpoint['min_loss'], checkpoint['min_epoch'], checkpoint['loss'] 135 | 136 | log_info = f'resuming model from {resume_model}. resume_epoch: {saved_epoch}, min_loss: {min_loss:.4f}, min_epoch: {min_epoch}, loss: {loss:.4f}' 137 | logger.info(log_info) 138 | 139 | 140 | 141 | 142 | step = 0 143 | print('#----------Training----------#') 144 | for epoch in range(start_epoch, config.epochs + 1): 145 | 146 | torch.cuda.empty_cache() 147 | 148 | step = train_one_epoch( 149 | train_loader, 150 | model, 151 | criterion, 152 | optimizer, 153 | scheduler, 154 | epoch, 155 | step, 156 | logger, 157 | config, 158 | writer 159 | ) 160 | 161 | loss = val_one_epoch( 162 | val_loader, 163 | model, 164 | criterion, 165 | epoch, 166 | logger, 167 | config 168 | ) 169 | 170 | if loss < min_loss: 171 | torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'best.pth')) 172 | min_loss = loss 173 | min_epoch = epoch 174 | 175 | torch.save( 176 | { 177 | 'epoch': epoch, 178 | 'min_loss': min_loss, 179 | 'min_epoch': min_epoch, 180 | 'loss': loss, 181 | 'model_state_dict': model.state_dict(), 182 | 'optimizer_state_dict': optimizer.state_dict(), 183 | 'scheduler_state_dict': scheduler.state_dict(), 184 | }, os.path.join(checkpoint_dir, 'latest.pth')) 185 | 186 | if os.path.exists(os.path.join(checkpoint_dir, 'best.pth')): 187 | print('#----------Testing----------#') 188 | best_weight = torch.load(config.work_dir + 'checkpoints/best.pth', map_location=torch.device('cpu')) 189 | model.load_state_dict(best_weight) 190 | loss = test_one_epoch( 191 | val_loader, 192 | model, 193 | criterion, 194 | logger, 195 | config, 196 | ) 197 | os.rename( 198 | os.path.join(checkpoint_dir, 'best.pth'), 199 | os.path.join(checkpoint_dir, f'best-epoch{min_epoch}-loss{min_loss:.4f}.pth') 200 | ) 201 | 202 | 203 | if __name__ == '__main__': 204 | config = setting_config 205 | main(config) -------------------------------------------------------------------------------- /train_synapse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.cuda.amp import autocast, GradScaler 4 | from torch.utils.data import DataLoader 5 | import torch.distributed as dist 6 | from torch.nn.parallel import DistributedDataParallel as DDP 7 | from torch.utils.data.distributed import DistributedSampler 8 | from torchvision import transforms 9 | 10 | from datasets.dataset import RandomGenerator 11 | from engine_synapse import * 12 | 13 | from models.vmunet.vmunet import VMUNet 14 | 15 | import os 16 | import sys 17 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # "0, 1, 2, 3" 18 | 19 | from utils import * 20 | from configs.config_setting_synapse import setting_config 21 | 22 | import warnings 23 | warnings.filterwarnings("ignore") 24 | 25 | 26 | 27 | 28 | 29 | def main(config): 30 | 31 | print('#----------Creating logger----------#') 32 | sys.path.append(config.work_dir + '/') 33 | log_dir = os.path.join(config.work_dir, 'log') 34 | 
checkpoint_dir = os.path.join(config.work_dir, 'checkpoints')
35 |     resume_model = os.path.join(checkpoint_dir, 'latest.pth')
36 |     outputs = os.path.join(config.work_dir, 'outputs')
37 |     if not os.path.exists(checkpoint_dir):
38 |         os.makedirs(checkpoint_dir)
39 |     if not os.path.exists(outputs):
40 |         os.makedirs(outputs)
41 | 
42 |     global logger
43 |     logger = get_logger('train', log_dir)
44 | 
45 |     log_config_info(config, logger)
46 | 
47 | 
48 | 
49 | 
50 | 
51 |     print('#----------GPU init----------#')
52 |     set_seed(config.seed)
53 |     gpu_ids = [0]# [0, 1, 2, 3]
54 |     torch.cuda.empty_cache()
55 |     gpus_type, gpus_num = torch.cuda.get_device_name(), torch.cuda.device_count()
56 |     if config.distributed:
57 |         print('#----------Start DDP----------#')
58 |         dist.init_process_group(backend='nccl', init_method='env://')
59 |         torch.cuda.manual_seed_all(config.seed)
60 |         config.local_rank = torch.distributed.get_rank()
61 | 
62 | 
63 | 
64 | 
65 | 
66 |     print('#----------Preparing dataset----------#')
67 |     train_dataset = config.datasets(base_dir=config.data_path, list_dir=config.list_dir, split="train",
68 |                                     transform=transforms.Compose(
69 |                                         [RandomGenerator(output_size=[config.input_size_h, config.input_size_w])]))
70 |     train_sampler = DistributedSampler(train_dataset, shuffle=True) if config.distributed else None
71 |     train_loader = DataLoader(train_dataset,
72 |                               batch_size=config.batch_size//gpus_num if config.distributed else config.batch_size,
73 |                               shuffle=(train_sampler is None),
74 |                               pin_memory=True,
75 |                               num_workers=config.num_workers,
76 |                               sampler=train_sampler)
77 | 
78 |     val_dataset = config.datasets(base_dir=config.volume_path, split="test_vol", list_dir=config.list_dir)
79 |     val_sampler = DistributedSampler(val_dataset, shuffle=False) if config.distributed else None
80 |     val_loader = DataLoader(val_dataset,
81 |                             batch_size=1, # if config.distributed else config.batch_size,
82 |                             shuffle=False,
83 |                             pin_memory=True,
84 |                             num_workers=config.num_workers,
85 |                             sampler=val_sampler,
86 |                             drop_last=True)
87 | 
88 | 
89 | 
90 | 
91 | 
92 |     print('#----------Preparing Models----------#')
93 |     model_cfg = config.model_config
94 |     if config.network == 'vmunet':
95 |         model = VMUNet(
96 |             num_classes=model_cfg['num_classes'],
97 |             input_channels=model_cfg['input_channels'],
98 |             depths=model_cfg['depths'],
99 |             depths_decoder=model_cfg['depths_decoder'],
100 |             drop_path_rate=model_cfg['drop_path_rate'],
101 |             load_ckpt_path=model_cfg['load_ckpt_path'],
102 |         )
103 |         model.load_from()
104 |     else: raise ValueError('Please specify a valid network!')
105 | 
106 |     if config.distributed:
107 |         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda()
108 |         model = DDP(model, device_ids=[config.local_rank], output_device=config.local_rank)
109 |     else:
110 |         model = torch.nn.DataParallel(model.cuda(), device_ids=gpu_ids, output_device=gpu_ids[0])
111 | 
112 | 
113 | 
114 | 
115 | 
116 |     print('#----------Preparing loss, opt, sch and amp----------#')
117 |     criterion = config.criterion
118 |     optimizer = get_optimizer(config, model)
119 |     scheduler = get_scheduler(config, optimizer)
120 |     scaler = GradScaler()
121 | 
122 | 
123 | 
124 | 
125 | 
126 |     print('#----------Set other params----------#')
127 |     min_loss = 999
128 |     start_epoch = 1
129 |     min_epoch = 1
130 | 
131 | 
132 |     if config.only_test_and_save_figs:
133 |         checkpoint = torch.load(config.best_ckpt_path, map_location=torch.device('cpu'))
134 |         model.load_state_dict(checkpoint)
135 |         config.work_dir = config.img_save_path
136 |         if not os.path.exists(config.work_dir + 'outputs/'):
137 | 
os.makedirs(config.work_dir + 'outputs/') 138 | mean_dice, mean_hd95 = val_one_epoch( 139 | val_dataset, 140 | val_loader, 141 | model, 142 | epoch, 143 | logger, 144 | config, 145 | test_save_path=outputs, 146 | val_or_test=True 147 | ) 148 | print(mean_dice, mean_hd95) 149 | return 150 | 151 | 152 | 153 | if os.path.exists(resume_model): 154 | print('#----------Resume Model and Other params----------#') 155 | checkpoint = torch.load(resume_model, map_location=torch.device('cpu')) 156 | model.module.load_state_dict(checkpoint['model_state_dict']) 157 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 158 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 159 | saved_epoch = checkpoint['epoch'] 160 | start_epoch += saved_epoch 161 | min_loss, min_epoch, loss = checkpoint['min_loss'], checkpoint['min_epoch'], checkpoint['loss'] 162 | 163 | log_info = f'resuming model from {resume_model}. resume_epoch: {saved_epoch}, min_loss: {min_loss:.4f}, min_epoch: {min_epoch}, loss: {loss:.4f}' 164 | logger.info(log_info) 165 | 166 | 167 | 168 | 169 | 170 | print('#----------Training----------#') 171 | for epoch in range(start_epoch, config.epochs + 1): 172 | 173 | torch.cuda.empty_cache() 174 | train_sampler.set_epoch(epoch) if config.distributed else None 175 | 176 | loss = train_one_epoch( 177 | train_loader, 178 | model, 179 | criterion, 180 | optimizer, 181 | scheduler, 182 | epoch, 183 | logger, 184 | config, 185 | scaler=scaler 186 | ) 187 | 188 | if loss < min_loss: 189 | torch.save(model.module.state_dict(), os.path.join(checkpoint_dir, 'best.pth')) 190 | min_loss = loss 191 | min_epoch = epoch 192 | 193 | if epoch % config.val_interval == 0: 194 | mean_dice, mean_hd95 = val_one_epoch( 195 | val_dataset, 196 | val_loader, 197 | model, 198 | epoch, 199 | logger, 200 | config, 201 | test_save_path=outputs, 202 | val_or_test=False 203 | ) 204 | 205 | torch.save( 206 | { 207 | 'epoch': epoch, 208 | 'min_loss': min_loss, 209 | 'min_epoch': min_epoch, 210 | 'loss': loss, 211 | 'model_state_dict': model.module.state_dict(), 212 | 'optimizer_state_dict': optimizer.state_dict(), 213 | 'scheduler_state_dict': scheduler.state_dict(), 214 | }, os.path.join(checkpoint_dir, 'latest.pth')) 215 | 216 | if os.path.exists(os.path.join(checkpoint_dir, 'best.pth')): 217 | print('#----------Testing----------#') 218 | best_weight = torch.load(config.work_dir + 'checkpoints/best.pth', map_location=torch.device('cpu')) 219 | model.module.load_state_dict(best_weight) 220 | mean_dice, mean_hd95 = val_one_epoch( 221 | val_dataset, 222 | val_loader, 223 | model, 224 | epoch, 225 | logger, 226 | config, 227 | test_save_path=outputs, 228 | val_or_test=True 229 | ) 230 | os.rename( 231 | os.path.join(checkpoint_dir, 'best.pth'), 232 | os.path.join(checkpoint_dir, 233 | f'best-epoch{min_epoch}-mean_dice{mean_dice:.4f}-mean_hd95{mean_hd95:.4f}.pth') 234 | ) 235 | 236 | 237 | if __name__ == '__main__': 238 | config = setting_config 239 | main(config) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.backends.cudnn as cudnn 5 | import torchvision.transforms.functional as TF 6 | import numpy as np 7 | import os 8 | import math 9 | import random 10 | import logging 11 | import logging.handlers 12 | from matplotlib import pyplot as plt 13 | 14 | from scipy.ndimage import zoom 15 | import 
SimpleITK as sitk 16 | from medpy import metric 17 | 18 | 19 | def set_seed(seed): 20 | # for hash 21 | os.environ['PYTHONHASHSEED'] = str(seed) 22 | # for python and numpy 23 | random.seed(seed) 24 | np.random.seed(seed) 25 | # for cpu gpu 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.cuda.manual_seed_all(seed) 29 | # for cudnn 30 | cudnn.benchmark = False 31 | cudnn.deterministic = True 32 | 33 | 34 | def get_logger(name, log_dir): 35 | ''' 36 | Args: 37 | name(str): name of logger 38 | log_dir(str): path of log 39 | ''' 40 | 41 | if not os.path.exists(log_dir): 42 | os.makedirs(log_dir) 43 | 44 | logger = logging.getLogger(name) 45 | logger.setLevel(logging.INFO) 46 | 47 | info_name = os.path.join(log_dir, '{}.info.log'.format(name)) 48 | info_handler = logging.handlers.TimedRotatingFileHandler(info_name, 49 | when='D', 50 | encoding='utf-8') 51 | info_handler.setLevel(logging.INFO) 52 | 53 | formatter = logging.Formatter('%(asctime)s - %(message)s', 54 | datefmt='%Y-%m-%d %H:%M:%S') 55 | 56 | info_handler.setFormatter(formatter) 57 | 58 | logger.addHandler(info_handler) 59 | 60 | return logger 61 | 62 | 63 | def log_config_info(config, logger): 64 | config_dict = config.__dict__ 65 | log_info = f'#----------Config info----------#' 66 | logger.info(log_info) 67 | for k, v in config_dict.items(): 68 | if k[0] == '_': 69 | continue 70 | else: 71 | log_info = f'{k}: {v},' 72 | logger.info(log_info) 73 | 74 | 75 | 76 | def get_optimizer(config, model): 77 | assert config.opt in ['Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'ASGD', 'RMSprop', 'Rprop', 'SGD'], 'Unsupported optimizer!' 78 | 79 | if config.opt == 'Adadelta': 80 | return torch.optim.Adadelta( 81 | model.parameters(), 82 | lr = config.lr, 83 | rho = config.rho, 84 | eps = config.eps, 85 | weight_decay = config.weight_decay 86 | ) 87 | elif config.opt == 'Adagrad': 88 | return torch.optim.Adagrad( 89 | model.parameters(), 90 | lr = config.lr, 91 | lr_decay = config.lr_decay, 92 | eps = config.eps, 93 | weight_decay = config.weight_decay 94 | ) 95 | elif config.opt == 'Adam': 96 | return torch.optim.Adam( 97 | model.parameters(), 98 | lr = config.lr, 99 | betas = config.betas, 100 | eps = config.eps, 101 | weight_decay = config.weight_decay, 102 | amsgrad = config.amsgrad 103 | ) 104 | elif config.opt == 'AdamW': 105 | return torch.optim.AdamW( 106 | model.parameters(), 107 | lr = config.lr, 108 | betas = config.betas, 109 | eps = config.eps, 110 | weight_decay = config.weight_decay, 111 | amsgrad = config.amsgrad 112 | ) 113 | elif config.opt == 'Adamax': 114 | return torch.optim.Adamax( 115 | model.parameters(), 116 | lr = config.lr, 117 | betas = config.betas, 118 | eps = config.eps, 119 | weight_decay = config.weight_decay 120 | ) 121 | elif config.opt == 'ASGD': 122 | return torch.optim.ASGD( 123 | model.parameters(), 124 | lr = config.lr, 125 | lambd = config.lambd, 126 | alpha = config.alpha, 127 | t0 = config.t0, 128 | weight_decay = config.weight_decay 129 | ) 130 | elif config.opt == 'RMSprop': 131 | return torch.optim.RMSprop( 132 | model.parameters(), 133 | lr = config.lr, 134 | momentum = config.momentum, 135 | alpha = config.alpha, 136 | eps = config.eps, 137 | centered = config.centered, 138 | weight_decay = config.weight_decay 139 | ) 140 | elif config.opt == 'Rprop': 141 | return torch.optim.Rprop( 142 | model.parameters(), 143 | lr = config.lr, 144 | etas = config.etas, 145 | step_sizes = config.step_sizes, 146 | ) 147 | elif config.opt == 'SGD': 148 | return torch.optim.SGD( 149 | 
model.parameters(),
150 |             lr = config.lr,
151 |             momentum = config.momentum,
152 |             weight_decay = config.weight_decay,
153 |             dampening = config.dampening,
154 |             nesterov = config.nesterov
155 |         )
156 |     else: # default opt is SGD
157 |         return torch.optim.SGD(
158 |             model.parameters(),
159 |             lr = 0.01,
160 |             momentum = 0.9,
161 |             weight_decay = 0.05,
162 |         )
163 | 
164 | 
165 | def get_scheduler(config, optimizer):
166 |     assert config.sch in ['StepLR', 'MultiStepLR', 'ExponentialLR', 'CosineAnnealingLR', 'ReduceLROnPlateau',
167 |                         'CosineAnnealingWarmRestarts', 'WP_MultiStepLR', 'WP_CosineLR'], 'Unsupported scheduler!'
168 |     if config.sch == 'StepLR':
169 |         scheduler = torch.optim.lr_scheduler.StepLR(
170 |             optimizer,
171 |             step_size = config.step_size,
172 |             gamma = config.gamma,
173 |             last_epoch = config.last_epoch
174 |         )
175 |     elif config.sch == 'MultiStepLR':
176 |         scheduler = torch.optim.lr_scheduler.MultiStepLR(
177 |             optimizer,
178 |             milestones = config.milestones,
179 |             gamma = config.gamma,
180 |             last_epoch = config.last_epoch
181 |         )
182 |     elif config.sch == 'ExponentialLR':
183 |         scheduler = torch.optim.lr_scheduler.ExponentialLR(
184 |             optimizer,
185 |             gamma = config.gamma,
186 |             last_epoch = config.last_epoch
187 |         )
188 |     elif config.sch == 'CosineAnnealingLR':
189 |         scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
190 |             optimizer,
191 |             T_max = config.T_max,
192 |             eta_min = config.eta_min,
193 |             last_epoch = config.last_epoch
194 |         )
195 |     elif config.sch == 'ReduceLROnPlateau':
196 |         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
197 |             optimizer,
198 |             mode = config.mode,
199 |             factor = config.factor,
200 |             patience = config.patience,
201 |             threshold = config.threshold,
202 |             threshold_mode = config.threshold_mode,
203 |             cooldown = config.cooldown,
204 |             min_lr = config.min_lr,
205 |             eps = config.eps
206 |         )
207 |     elif config.sch == 'CosineAnnealingWarmRestarts':
208 |         scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
209 |             optimizer,
210 |             T_0 = config.T_0,
211 |             T_mult = config.T_mult,
212 |             eta_min = config.eta_min,
213 |             last_epoch = config.last_epoch
214 |         )
215 |     elif config.sch == 'WP_MultiStepLR':
216 |         lr_func = lambda epoch: epoch / config.warm_up_epochs if epoch <= config.warm_up_epochs else config.gamma**len(
217 |                 [m for m in config.milestones if m <= epoch])
218 |         scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func)
219 |     elif config.sch == 'WP_CosineLR':
220 |         lr_func = lambda epoch: epoch / config.warm_up_epochs if epoch <= config.warm_up_epochs else 0.5 * (
221 |                 math.cos((epoch - config.warm_up_epochs) / (config.epochs - config.warm_up_epochs) * math.pi) + 1)
222 |         scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func)
223 | 
224 |     return scheduler
225 | 
226 | 
227 | 
228 | def save_imgs(img, msk, msk_pred, i, save_path, datasets, threshold=0.5, test_data_name=None):
229 |     # convert the image tensor into a displayable numpy array
230 |     img = img.squeeze(0).permute(1,2,0).detach().cpu().numpy()
231 |     img = img / 255. if img.max() > 1.1 else img
232 | 
233 |     # process the masks according to the dataset
234 |     if datasets == 'retinal':
235 |         msk = np.squeeze(msk, axis=0)
236 |         msk_pred = np.squeeze(msk_pred, axis=0)
237 |         att_msk = msk_pred  # keep the raw prediction for the attention overlay
238 |     else:
239 |         att_msk = np.squeeze(msk_pred, axis=0)
240 |         msk = np.where(np.squeeze(msk, axis=0) > 0.5, 1, 0)
241 |         msk_pred = np.where(np.squeeze(msk_pred, axis=0) > threshold, 1, 0)
242 |     # set the figure size
243 |     plt.figure(figsize=(10, 20))
244 | 
245 |     # adjust subplot spacing to reduce the white borders
246 |     plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0.05, hspace=0.05)
247 | 
248 |     # original image
249 |     plt.subplot(4, 1, 1)
250 |     plt.imshow(img)
251 |     plt.axis('off')
252 |     plt.title('Original Image')
253 | 
254 |     # ground-truth mask
255 |     plt.subplot(4, 1, 2)
256 |     plt.imshow(msk, cmap='gray')
257 |     plt.axis('off')
258 |     plt.title('Ground Truth Mask')
259 | 
260 |     # predicted mask
261 |     plt.subplot(4, 1, 3)
262 |     plt.imshow(msk_pred, cmap='gray')
263 |     plt.axis('off')
264 |     plt.title('Predicted Mask')
265 | 
266 |     # attention map overlaid on the original image
267 |     plt.subplot(4, 1, 4)
268 |     plt.imshow(img)
269 |     plt.imshow(att_msk, cmap='jet', alpha=0.5) # heatmap over the image; alpha sets the transparency
270 |     plt.axis('off')
271 |     plt.title('Attention Map Overlay')
272 | 
273 |     # save the figure
274 |     if test_data_name is not None:
275 |         save_path = save_path + test_data_name + '_'
276 |     plt.savefig(save_path + str(i) + '.png', bbox_inches='tight', pad_inches=0)
277 |     plt.close()
278 | 
279 | 
280 | 
281 | class BCELoss(nn.Module):
282 |     def __init__(self):
283 |         super(BCELoss, self).__init__()
284 |         self.bceloss = nn.BCELoss()
285 | 
286 |     def forward(self, pred, target):
287 |         size = pred.size(0)
288 |         pred_ = pred.view(size, -1)
289 |         target_ = target.view(size, -1)
290 | 
291 |         return self.bceloss(pred_, target_)
292 | 
293 | 
294 | class DiceLoss(nn.Module):
295 |     def __init__(self):
296 |         super(DiceLoss, self).__init__()
297 | 
298 |     def forward(self, pred, target):
299 |         smooth = 1
300 |         size = pred.size(0)
301 | 
302 |         pred_ = pred.view(size, -1)
303 |         target_ = target.view(size, -1)
304 |         intersection = pred_ * target_
305 |         dice_score = (2 * intersection.sum(1) + smooth)/(pred_.sum(1) + target_.sum(1) + smooth)
306 |         dice_loss = 1 - dice_score.sum()/size
307 | 
308 |         return dice_loss
309 | 
310 | 
311 | class nDiceLoss(nn.Module):
312 |     def __init__(self, n_classes):
313 |         super(nDiceLoss, self).__init__()
314 |         self.n_classes = n_classes
315 | 
316 |     def _one_hot_encoder(self, input_tensor):
317 |         tensor_list = []
318 |         for i in range(self.n_classes):
319 |             temp_prob = input_tensor == i # * torch.ones_like(input_tensor)
320 |             tensor_list.append(temp_prob.unsqueeze(1))
321 |         output_tensor = torch.cat(tensor_list, dim=1)
322 |         return output_tensor.float()
323 | 
324 |     def _dice_loss(self, score, target):
325 |         target = target.float()
326 |         smooth = 1e-5
327 |         intersect = torch.sum(score * target)
328 |         y_sum = torch.sum(target * target)
329 |         z_sum = torch.sum(score * score)
330 |         loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
331 |         loss = 1 - loss
332 |         return loss
333 | 
334 |     def forward(self, inputs, target, weight=None, softmax=False):
335 |         if softmax:
336 |             inputs = torch.softmax(inputs, dim=1)
337 |         target = self._one_hot_encoder(target)
338 |         if weight is None:
339 |             weight = [1] * self.n_classes
340 |         assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())
341 |         class_wise_dice = []
342 |         loss = 0.0
343 |         for i in range(0, self.n_classes):
344 |             dice = self._dice_loss(inputs[:, i], target[:, i])
345 |             class_wise_dice.append(1.0 - dice.item())
346 |             loss += 
dice * weight[i] 347 | return loss / self.n_classes 348 | 349 | 350 | class CeDiceLoss(nn.Module): 351 | def __init__(self, num_classes, loss_weight=[0.4, 0.6]): 352 | super(CeDiceLoss, self).__init__() 353 | self.celoss = nn.CrossEntropyLoss() 354 | self.diceloss = nDiceLoss(num_classes) 355 | self.loss_weight = loss_weight 356 | 357 | def forward(self, pred, target): 358 | loss_ce = self.celoss(pred, target[:].long()) 359 | loss_dice = self.diceloss(pred, target, softmax=True) 360 | loss = self.loss_weight[0] * loss_ce + self.loss_weight[1] * loss_dice 361 | return loss 362 | 363 | 364 | class BceDiceLoss(nn.Module): 365 | def __init__(self, wb=1, wd=1): 366 | super(BceDiceLoss, self).__init__() 367 | self.bce = BCELoss() 368 | self.dice = DiceLoss() 369 | self.wb = wb 370 | self.wd = wd 371 | 372 | def forward(self, pred, target): 373 | bceloss = self.bce(pred, target) 374 | diceloss = self.dice(pred, target) 375 | 376 | loss = self.wd * diceloss + self.wb * bceloss 377 | return loss 378 | 379 | 380 | class GT_BceDiceLoss(nn.Module): 381 | def __init__(self, wb=1, wd=1): 382 | super(GT_BceDiceLoss, self).__init__() 383 | self.bcedice = BceDiceLoss(wb, wd) 384 | 385 | def forward(self, gt_pre, out, target): 386 | bcediceloss = self.bcedice(out, target) 387 | gt_pre5, gt_pre4, gt_pre3, gt_pre2, gt_pre1 = gt_pre 388 | gt_loss = self.bcedice(gt_pre5, target) * 0.1 + self.bcedice(gt_pre4, target) * 0.2 + self.bcedice(gt_pre3, target) * 0.3 + self.bcedice(gt_pre2, target) * 0.4 + self.bcedice(gt_pre1, target) * 0.5 389 | return bcediceloss + gt_loss 390 | 391 | 392 | 393 | class myToTensor: 394 | def __init__(self): 395 | pass 396 | def __call__(self, data): 397 | image, mask = data 398 | return torch.tensor(image).permute(2,0,1), torch.tensor(mask).permute(2,0,1) 399 | 400 | 401 | class myResize: 402 | def __init__(self, size_h=256, size_w=256): 403 | self.size_h = size_h 404 | self.size_w = size_w 405 | def __call__(self, data): 406 | image, mask = data 407 | return TF.resize(image, [self.size_h, self.size_w]), TF.resize(mask, [self.size_h, self.size_w]) 408 | 409 | 410 | class myRandomHorizontalFlip: 411 | def __init__(self, p=0.5): 412 | self.p = p 413 | def __call__(self, data): 414 | image, mask = data 415 | if random.random() < self.p: return TF.hflip(image), TF.hflip(mask) 416 | else: return image, mask 417 | 418 | 419 | class myRandomVerticalFlip: 420 | def __init__(self, p=0.5): 421 | self.p = p 422 | def __call__(self, data): 423 | image, mask = data 424 | if random.random() < self.p: return TF.vflip(image), TF.vflip(mask) 425 | else: return image, mask 426 | 427 | 428 | class myRandomRotation: 429 | def __init__(self, p=0.5, degree=[0,360]): 430 | self.angle = random.uniform(degree[0], degree[1]) 431 | self.p = p 432 | def __call__(self, data): 433 | image, mask = data 434 | if random.random() < self.p: return TF.rotate(image,self.angle), TF.rotate(mask,self.angle) 435 | else: return image, mask 436 | 437 | 438 | class myNormalize: 439 | def __init__(self, data_name, train=True): 440 | if data_name == 'isic18': 441 | if train: 442 | self.mean = 157.561 443 | self.std = 26.706 444 | else: 445 | self.mean = 149.034 446 | self.std = 32.022 447 | elif data_name == 'isic17': 448 | if train: 449 | self.mean = 159.922 450 | self.std = 28.871 451 | else: 452 | self.mean = 148.429 453 | self.std = 25.748 454 | elif data_name == 'isic18_82': 455 | if train: 456 | self.mean = 156.2899 457 | self.std = 26.5457 458 | else: 459 | self.mean = 149.8485 460 | self.std = 35.3346 461 | 462 | def 
__call__(self, data):
463 |         img, msk = data
464 |         img_normalized = (img-self.mean)/self.std
465 |         img_normalized = ((img_normalized - np.min(img_normalized))
466 |                             / (np.max(img_normalized)-np.min(img_normalized))) * 255.
467 |         return img_normalized, msk
468 | 
469 | 
470 | 
471 | from thop import profile ## import the thop profiler
472 | def cal_params_flops(model, size, logger):
473 |     input = torch.randn(1, 3, size, size).cuda()
474 |     flops, params = profile(model, inputs=(input,))
475 |     print('flops',flops/1e9) ## print the FLOPs
476 |     print('params',params/1e6) ## print the parameter count
477 | 
478 |     total = sum(p.numel() for p in model.parameters())
479 |     print("Total params: %.2fM" % (total/1e6))
480 |     logger.info(f'flops: {flops/1e9}, params: {params/1e6}, Total params: {total/1e6:.4f}')
481 | 
482 | 
483 | 
484 | 
485 | 
486 | 
487 | def calculate_metric_percase(pred, gt):
488 |     pred[pred > 0] = 1
489 |     gt[gt > 0] = 1
490 |     if pred.sum() > 0 and gt.sum()>0:
491 |         dice = metric.binary.dc(pred, gt)
492 |         hd95 = metric.binary.hd95(pred, gt)
493 |         return dice, hd95
494 |     elif pred.sum() > 0 and gt.sum()==0:
495 |         return 1, 0
496 |     else:
497 |         return 0, 0
498 | 
499 | 
500 | 
501 | def test_single_volume(image, label, net, classes, patch_size=[256, 256],
502 |                        test_save_path=None, case=None, z_spacing=1, val_or_test=False):
503 |     image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
504 |     if len(image.shape) == 3:
505 |         prediction = np.zeros_like(label)
506 |         for ind in range(image.shape[0]):
507 |             slice = image[ind, :, :]
508 |             x, y = slice.shape[0], slice.shape[1]
509 |             if x != patch_size[0] or y != patch_size[1]:
510 |                 slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=3)  # previous using 0
511 |             input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda()
512 |             net.eval()
513 |             with torch.no_grad():
514 |                 outputs = net(input)
515 |                 out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0)
516 |                 out = out.cpu().detach().numpy()
517 |                 if x != patch_size[0] or y != patch_size[1]:
518 |                     pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
519 |                 else:
520 |                     pred = out
521 |                 prediction[ind] = pred
522 |     else:
523 |         input = torch.from_numpy(image).unsqueeze(
524 |             0).unsqueeze(0).float().cuda()
525 |         net.eval()
526 |         with torch.no_grad():
527 |             out = torch.argmax(torch.softmax(net(input), dim=1), dim=1).squeeze(0)
528 |             prediction = out.cpu().detach().numpy()
529 |     metric_list = []
530 |     for i in range(1, classes):
531 |         metric_list.append(calculate_metric_percase(prediction == i, label == i))
532 | 
533 |     if test_save_path is not None and val_or_test is True:
534 |         img_itk = sitk.GetImageFromArray(image.astype(np.float32))
535 |         prd_itk = sitk.GetImageFromArray(prediction.astype(np.float32))
536 |         lab_itk = sitk.GetImageFromArray(label.astype(np.float32))
537 |         img_itk.SetSpacing((1, 1, z_spacing))
538 |         prd_itk.SetSpacing((1, 1, z_spacing))
539 |         lab_itk.SetSpacing((1, 1, z_spacing))
540 |         sitk.WriteImage(prd_itk, test_save_path + '/'+ case + "_pred.nii.gz")
541 |         sitk.WriteImage(img_itk, test_save_path + '/'+ case + "_img.nii.gz")
542 |         sitk.WriteImage(lab_itk, test_save_path + '/'+ case + "_gt.nii.gz")
543 |         # cv2.imwrite(test_save_path + '/'+case + '.png', prediction*255)
544 |     return metric_list
--------------------------------------------------------------------------------
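Example usage (not part of the repository): a minimal inference sketch for the VMUNet class defined in models/vmunet/vmunet.py. It assumes a CUDA device and a working selective-scan kernel (mamba_ssm or the standalone selective_scan package), since SS2D's forward_core calls selective_scan_fn; the 256x256 input size matches the one passed to cal_params_flops in train.py.

import torch
from models.vmunet.vmunet import VMUNet

# the binary-segmentation configuration used by train.py (num_classes=1)
model = VMUNet(
    input_channels=3,
    num_classes=1,
    depths=[2, 2, 9, 2],
    depths_decoder=[2, 9, 2, 2],
    drop_path_rate=0.2,
    load_ckpt_path=None,  # set to a pre-trained VSSM checkpoint and call model.load_from() to initialize
).cuda()
model.eval()

# H and W should stay even through every PatchMerging2D stage (e.g. multiples of 32)
x = torch.randn(1, 3, 256, 256).cuda()
with torch.no_grad():
    probs = model(x)  # (1, 1, 256, 256) sigmoid probabilities, since num_classes == 1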