├── .gitignore ├── LICENSE ├── README.md ├── configs ├── config_setting.py ├── config_setting_synapse.py └── config_setting_v2.py ├── data └── README.md ├── datasets └── dataset.py ├── engine.py ├── engine_synapse.py ├── models ├── unet_v2 │ ├── UNet_v2.py │ ├── __init__.py │ └── pvtv2.py ├── vmunet │ ├── vmamba.py │ ├── vmunet.py │ └── vmunet_v2.py └── vmunet_old │ ├── vmamba.py │ └── vmunet.py ├── pre_trained_weights └── README.md ├── train.py ├── train_gim.py ├── train_isic_all.py ├── train_npy.py ├── train_synapse.py ├── unet_v2 ├── UNet_v2.py ├── __init__.py └── pvtv2.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VM-UNet 2 | This is the official code repository for "VM-UNetV2: Rethinking Vision Mamba UNet for Medical 3 | Image Segmentation". 
{[Arxiv Paper](https://arxiv.org/abs/2403.09157)} 4 | 5 | This work is based on VM-UNet {[Arxiv Paper](https://arxiv.org/abs/2402.02491)} and UNetV2 {[Arxiv Paper](https://arxiv.org/abs/2311.17791)} 6 | 7 | This code is adapted from VM-UNet {[git link](https://github.com/JCruan519/VM-UNet)}, UNetV2 {[git link](https://github.com/yaoppeng/U-Net_v2)} and VMamba {[git link](https://github.com/MzeroMiko/VMamba)} 8 | 9 | All VMamba pretrained weights can be downloaded [here](https://github.com/MzeroMiko/VMamba) and the PVT weights [here](https://github.com/yaoppeng/U-Net_v2) 10 | **Thanks to the authors for their base code** -------------------------------------------------------------------------------- /configs/config_setting.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from utils import * 3 | 4 | from datetime import datetime 5 | 6 | class setting_config: 7 | """ 8 | Configuration of the training settings. 9 | """ 10 | 11 | network = 'vmunet' 12 | model_config = { 13 | 'num_classes': 1, 14 | 'input_channels': 3, 15 | # ----- VM-UNet ----- # 16 | 'depths': [2,2,2,2], 17 | 'depths_decoder': [2,2,2,1], 18 | 'drop_path_rate': 0.2, 19 | 'load_ckpt_path': './pre_trained_weights/vmamba_small_e238_ema.pth', 20 | } 21 | 22 | datasets = 'isic18' 23 | if datasets == 'isic18': 24 | data_path = '/raid/code/mamba_all/VM-UNet/data/zd-medic/isic/isic2018/' 25 | elif datasets == 'isic17': 26 | data_path = '/raid/code/mamba_all/VM-UNet/data/zd-medic/isic/isic2017/' 27 | elif datasets == 'polyp': 28 | data_path = '/raid/code/mamba_all/VM-UNet/data/zd-medic/polyp_unet2/' 29 | else: 30 | raise Exception('datasets is not right!') 31 | 32 | criterion = BceDiceLoss(wb=1, wd=1) 33 | 34 | pretrained_path = './pre_trained/' 35 | num_classes = 1 36 | input_size_h = 256 37 | input_size_w = 256 38 | input_channels = 3 39 | distributed = False 40 | local_rank = -1 41 | num_workers = 8 42 | seed = 42 43 | world_size = None 44 | rank = None 45 | amp = False 46 | gpu_id = '0' 47 | batch_size = 60 48 | epochs = 300 49 | 50 | work_dir = 'results/' + network + '_' + datasets + '_' + datetime.now().strftime('%A_%d_%B_%Y_%Hh_%Mm_%Ss') + '/' 51 | 52 | print_interval = 20 53 | val_interval = 2 54 | save_interval = 10 55 | threshold = 0.5 56 | 57 | train_transformer = transforms.Compose([ 58 | myNormalize(datasets, train=True), 59 | myToTensor(), 60 | myRandomHorizontalFlip(p=0.5), 61 | myRandomVerticalFlip(p=0.5), 62 | myRandomRotation(p=0.5, degree=[0, 360]), 63 | myResize(input_size_h, input_size_w) 64 | ]) 65 | test_transformer = transforms.Compose([ 66 | myNormalize(datasets, train=False), 67 | myToTensor(), 68 | myResize(input_size_h, input_size_w) 69 | ]) 70 | 71 | opt = 'AdamW' 72 | assert opt in ['Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'ASGD', 'RMSprop', 'Rprop', 'SGD'], 'Unsupported optimizer!'
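# How these flat attributes are consumed (a minimal sketch, assuming the
# get_optimizer/get_scheduler helpers that the training scripts conventionally
# pull in via `from utils import *`; utils.py is not shown in this dump):
#
#     from configs.config_setting import setting_config as config
#     optimizer = get_optimizer(config, model)      # dispatches on config.opt
#     scheduler = get_scheduler(config, optimizer)  # dispatches on config.sch
#
# Only the branch matching `opt` below takes effect for a given run; the other
# branches merely document each optimizer's defaults.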
73 | if opt == 'Adadelta': 74 | lr = 0.01 # default: 1.0 – coefficient that scale delta before it is applied to the parameters 75 | rho = 0.9 # default: 0.9 – coefficient used for computing a running average of squared gradients 76 | eps = 1e-6 # default: 1e-6 – term added to the denominator to improve numerical stability 77 | weight_decay = 0.05 # default: 0 – weight decay (L2 penalty) 78 | elif opt == 'Adagrad': 79 | lr = 0.01 # default: 0.01 – learning rate 80 | lr_decay = 0 # default: 0 – learning rate decay 81 | eps = 1e-10 # default: 1e-10 – term added to the denominator to improve numerical stability 82 | weight_decay = 0.05 # default: 0 – weight decay (L2 penalty) 83 | elif opt == 'Adam': 84 | lr = 0.001 # default: 1e-3 – learning rate 85 | betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square 86 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 87 | weight_decay = 0.0001 # default: 0 – weight decay (L2 penalty) 88 | amsgrad = False # default: False – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond 89 | elif opt == 'AdamW': 90 | lr = 0.001 # default: 1e-3 – learning rate 91 | betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square 92 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 93 | weight_decay = 1e-2 # default: 1e-2 – weight decay coefficient 94 | amsgrad = False # default: False – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond 95 | elif opt == 'Adamax': 96 | lr = 2e-3 # default: 2e-3 – learning rate 97 | betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square 98 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 99 | weight_decay = 0 # default: 0 – weight decay (L2 penalty) 100 | elif opt == 'ASGD': 101 | lr = 0.01 # default: 1e-2 – learning rate 102 | lambd = 1e-4 # default: 1e-4 – decay term 103 | alpha = 0.75 # default: 0.75 – power for eta update 104 | t0 = 1e6 # default: 1e6 – point at which to start averaging 105 | weight_decay = 0 # default: 0 – weight decay 106 | elif opt == 'RMSprop': 107 | lr = 1e-2 # default: 1e-2 – learning rate 108 | momentum = 0 # default: 0 – momentum factor 109 | alpha = 0.99 # default: 0.99 – smoothing constant 110 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 111 | centered = False # default: False – if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance 112 | weight_decay = 0 # default: 0 – weight decay (L2 penalty) 113 | elif opt == 'Rprop': 114 | lr = 1e-2 # default: 1e-2 – learning rate 115 | etas = (0.5, 1.2) # default: (0.5, 1.2) – pair of (etaminus, etaplis), that are multiplicative increase and decrease factors 116 | step_sizes = (1e-6, 50) # default: (1e-6, 50) – a pair of minimal and maximal allowed step sizes 117 | elif opt == 'SGD': 118 | lr = 0.01 # – learning rate 119 | momentum = 0.9 # default: 0 – momentum factor 120 | weight_decay = 0.05 # default: 0 – weight decay (L2 penalty) 121 | dampening = 0 # default: 0 – dampening for momentum 122 | nesterov = False # default: False – enables Nesterov momentum 123 | 124 | sch = 'CosineAnnealingLR' 125 | if sch == 'StepLR': 126 | step_size = epochs // 5 # – 
Period of learning rate decay. 127 | gamma = 0.5 # – Multiplicative factor of learning rate decay. Default: 0.1 128 | last_epoch = -1 # – The index of last epoch. Default: -1. 129 | elif sch == 'MultiStepLR': 130 | milestones = [60, 120, 150] # – List of epoch indices. Must be increasing. 131 | gamma = 0.1 # – Multiplicative factor of learning rate decay. Default: 0.1. 132 | last_epoch = -1 # – The index of last epoch. Default: -1. 133 | elif sch == 'ExponentialLR': 134 | gamma = 0.99 # – Multiplicative factor of learning rate decay. 135 | last_epoch = -1 # – The index of last epoch. Default: -1. 136 | elif sch == 'CosineAnnealingLR': 137 | T_max = 50 # – Maximum number of iterations. Cosine function period. 138 | eta_min = 0.00001 # – Minimum learning rate. Default: 0. 139 | last_epoch = -1 # – The index of last epoch. Default: -1. 140 | elif sch == 'ReduceLROnPlateau': 141 | mode = 'min' # – One of min, max. In min mode, lr will be reduced when the quantity monitored has stopped decreasing; in max mode it will be reduced when the quantity monitored has stopped increasing. Default: ‘min’. 142 | factor = 0.1 # – Factor by which the learning rate will be reduced. new_lr = lr * factor. Default: 0.1. 143 | patience = 10 # – Number of epochs with no improvement after which learning rate will be reduced. For example, if patience = 2, then we will ignore the first 2 epochs with no improvement, and will only decrease the LR after the 3rd epoch if the loss still hasn’t improved then. Default: 10. 144 | threshold = 0.0001 # – Threshold for measuring the new optimum, to only focus on significant changes. Default: 1e-4. 145 | threshold_mode = 'rel' # – One of rel, abs. In rel mode, dynamic_threshold = best * ( 1 + threshold ) in ‘max’ mode or best * ( 1 - threshold ) in min mode. In abs mode, dynamic_threshold = best + threshold in max mode or best - threshold in min mode. Default: ‘rel’. 146 | cooldown = 0 # – Number of epochs to wait before resuming normal operation after lr has been reduced. Default: 0. 147 | min_lr = 0 # – A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively. Default: 0. 148 | eps = 1e-08 # – Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is ignored. Default: 1e-8. 149 | elif sch == 'CosineAnnealingWarmRestarts': 150 | T_0 = 50 # – Number of iterations for the first restart. 151 | T_mult = 2 # – A factor increases T_{i} after a restart. Default: 1. 152 | eta_min = 1e-6 # – Minimum learning rate. Default: 0. 153 | last_epoch = -1 # – The index of last epoch. Default: -1. 154 | elif sch == 'WP_MultiStepLR': 155 | warm_up_epochs = 10 156 | gamma = 0.1 157 | milestones = [125, 225] 158 | elif sch == 'WP_CosineLR': 159 | warm_up_epochs = 20 160 | -------------------------------------------------------------------------------- /configs/config_setting_synapse.py: -------------------------------------------------------------------------------- 1 | from pickle import FALSE 2 | from torchvision import transforms 3 | from datasets.dataset import * 4 | from utils import * 5 | 6 | from datetime import datetime 7 | import ml_collections 8 | 9 | class setting_config: 10 | """ 11 | the config of training setting. 
12 | """ 13 | network = 'vmunet' 14 | model_config = { 15 | 'num_classes': 9, 16 | 'input_channels': 3, 17 | # ----- VM-UNet ----- # 18 | 'depths': [2,2,2,2], 19 | 'depths_decoder': [2,2,2,1], 20 | 'drop_path_rate': 0.2, 21 | 'load_ckpt_path': './pre_trained_weights/vmamba_small_e238_ema.pth', 22 | } 23 | datasets_name = 'synapse' 24 | input_size_h = 224 25 | input_size_w = 224 26 | if datasets_name == 'synapse': 27 | data_path = './data/Synapse/train_npz/' 28 | datasets = Synapse_dataset 29 | list_dir = './data/Synapse/lists/lists_Synapse/' 30 | volume_path = './data/Synapse/test_vol_h5/' 31 | else: 32 | raise Exception('datasets in not right!') 33 | 34 | pretrained_path = '' # if using pretrained, please enter the path of weights 35 | num_classes = 9 36 | loss_weight = [1, 1] 37 | criterion = CeDiceLoss(num_classes, loss_weight) 38 | z_spacing = 1 39 | input_channels = 3 40 | 41 | distributed = False 42 | local_rank = -1 43 | num_workers = 0 #16 44 | seed = 42 45 | world_size = None 46 | rank = None 47 | amp = False 48 | 49 | batch_size = 32 50 | epochs = 300 51 | work_dir = 'results/' + network + '_' + datasets_name + '_' + datetime.now().strftime('%A_%d_%B_%Y_%Hh_%Mm_%Ss') + '/' 52 | # 'D:/CODES/MedSeg/BIBM22/results/datrm2_isic18_Sunday_04_September_2022_12h_04m_10s/' 53 | print_interval = 20 54 | val_interval = 100 55 | test_weights_path = '' 56 | 57 | threshold = 0.5 58 | 59 | opt = 'AdamW' 60 | assert opt in ['Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'ASGD', 'RMSprop', 'Rprop', 'SGD'], 'Unsupported optimizer!' 61 | if opt == 'Adadelta': 62 | lr = 0.01 # default: 1.0 – coefficient that scale delta before it is applied to the parameters 63 | rho = 0.9 # default: 0.9 – coefficient used for computing a running average of squared gradients 64 | eps = 1e-6 # default: 1e-6 – term added to the denominator to improve numerical stability 65 | weight_decay = 0.05 # default: 0 – weight decay (L2 penalty) 66 | elif opt == 'Adagrad': 67 | lr = 0.01 # default: 0.01 – learning rate 68 | lr_decay = 0 # default: 0 – learning rate decay 69 | eps = 1e-10 # default: 1e-10 – term added to the denominator to improve numerical stability 70 | weight_decay = 0.05 # default: 0 – weight decay (L2 penalty) 71 | elif opt == 'Adam': 72 | lr = 0.0001 # default: 1e-3 – learning rate 73 | betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square 74 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 75 | weight_decay = 0.05 # default: 0 – weight decay (L2 penalty) 76 | amsgrad = False # default: False – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond 77 | elif opt == 'AdamW': 78 | lr = 0.001 # default: 1e-3 – learning rate 79 | betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square 80 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 81 | weight_decay = 1e-2 # default: 1e-2 – weight decay coefficient 82 | amsgrad = False # default: False – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond 83 | elif opt == 'Adamax': 84 | lr = 2e-3 # default: 2e-3 – learning rate 85 | betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square 86 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 87 | 
weight_decay = 0 # default: 0 – weight decay (L2 penalty) 88 | elif opt == 'ASGD': 89 | lr = 0.01 # default: 1e-2 – learning rate 90 | lambd = 1e-4 # default: 1e-4 – decay term 91 | alpha = 0.75 # default: 0.75 – power for eta update 92 | t0 = 1e6 # default: 1e6 – point at which to start averaging 93 | weight_decay = 0 # default: 0 – weight decay 94 | elif opt == 'RMSprop': 95 | lr = 1e-2 # default: 1e-2 – learning rate 96 | momentum = 0 # default: 0 – momentum factor 97 | alpha = 0.99 # default: 0.99 – smoothing constant 98 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 99 | centered = False # default: False – if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance 100 | weight_decay = 0 # default: 0 – weight decay (L2 penalty) 101 | elif opt == 'Rprop': 102 | lr = 1e-2 # default: 1e-2 – learning rate 103 | etas = (0.5, 1.2) # default: (0.5, 1.2) – pair of (etaminus, etaplis), that are multiplicative increase and decrease factors 104 | step_sizes = (1e-6, 50) # default: (1e-6, 50) – a pair of minimal and maximal allowed step sizes 105 | elif opt == 'SGD': 106 | lr = 0.003 # – learning rate 107 | momentum = 0.9 # default: 0 – momentum factor 108 | weight_decay = 0.0001 # default: 0 – weight decay (L2 penalty) 109 | dampening = 0 # default: 0 – dampening for momentum 110 | nesterov = False # default: False – enables Nesterov momentum 111 | 112 | sch = 'CosineAnnealingLR' 113 | if sch == 'StepLR': 114 | step_size = epochs // 5 # – Period of learning rate decay. 115 | gamma = 0.5 # – Multiplicative factor of learning rate decay. Default: 0.1 116 | last_epoch = -1 # – The index of last epoch. Default: -1. 117 | elif sch == 'MultiStepLR': 118 | milestones = [60, 120, 150] # – List of epoch indices. Must be increasing. 119 | gamma = 0.1 # – Multiplicative factor of learning rate decay. Default: 0.1. 120 | last_epoch = -1 # – The index of last epoch. Default: -1. 121 | elif sch == 'ExponentialLR': 122 | gamma = 0.99 # – Multiplicative factor of learning rate decay. 123 | last_epoch = -1 # – The index of last epoch. Default: -1. 124 | elif sch == 'CosineAnnealingLR': 125 | T_max = 100 # – Maximum number of iterations. Cosine function period. 126 | eta_min = 0.00001 # – Minimum learning rate. Default: 0. 127 | last_epoch = -1 # – The index of last epoch. Default: -1. 128 | elif sch == 'ReduceLROnPlateau': 129 | mode = 'min' # – One of min, max. In min mode, lr will be reduced when the quantity monitored has stopped decreasing; in max mode it will be reduced when the quantity monitored has stopped increasing. Default: ‘min’. 130 | factor = 0.1 # – Factor by which the learning rate will be reduced. new_lr = lr * factor. Default: 0.1. 131 | patience = 10 # – Number of epochs with no improvement after which learning rate will be reduced. For example, if patience = 2, then we will ignore the first 2 epochs with no improvement, and will only decrease the LR after the 3rd epoch if the loss still hasn’t improved then. Default: 10. 132 | threshold = 0.0001 # – Threshold for measuring the new optimum, to only focus on significant changes. Default: 1e-4. 133 | threshold_mode = 'rel' # – One of rel, abs. In rel mode, dynamic_threshold = best * ( 1 + threshold ) in ‘max’ mode or best * ( 1 - threshold ) in min mode. In abs mode, dynamic_threshold = best + threshold in max mode or best - threshold in min mode. Default: ‘rel’. 
134 | cooldown = 0 # – Number of epochs to wait before resuming normal operation after lr has been reduced. Default: 0. 135 | min_lr = 0 # – A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively. Default: 0. 136 | eps = 1e-08 # – Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is ignored. Default: 1e-8. 137 | elif sch == 'CosineAnnealingWarmRestarts': 138 | T_0 = 50 # – Number of iterations for the first restart. 139 | T_mult = 2 # – A factor increases T_{i} after a restart. Default: 1. 140 | eta_min = 1e-6 # – Minimum learning rate. Default: 0. 141 | last_epoch = -1 # – The index of last epoch. Default: -1. 142 | elif sch == 'WP_MultiStepLR': 143 | warm_up_epochs = 10 144 | gamma = 0.1 145 | milestones = [125, 225] 146 | elif sch == 'WP_CosineLR': 147 | warm_up_epochs = 20 148 | -------------------------------------------------------------------------------- /configs/config_setting_v2.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from utils import * 3 | 4 | from datetime import datetime 5 | 6 | class setting_config: 7 | """ 8 | the config of training setting. 9 | """ 10 | 11 | network = 'vmunet-v2' 12 | model_config = { 13 | 'num_classes': 1, 14 | 'input_channels': 3, 15 | # ----- VM-UNet-V2 -----2 9 27 small # 16 | 'depths': [2,2,9,2], 17 | 'depths_decoder': [2,2,2,1], 18 | 'drop_path_rate': 0.2, 19 | 'load_ckpt_path': './pre_trained_weights/vmamba_small_e238_ema.pth', 20 | 'deep_supervision': True, 21 | } 22 | 23 | datasets = 'polyp' 24 | if datasets == 'isic18': 25 | data_path = '/raid/code/mamba_all/VM-UNet/data/zd-medic/isic/isic2018/' 26 | elif datasets == 'isic17': 27 | data_path = '/raid/code/mamba_all/VM-UNet/data/zd-medic/isic/isic2017/' 28 | elif datasets == 'polyp': 29 | data_path = '/raid/code/mamba_all/VM-UNet/data/zd-medic/polyp_unet2/' 30 | elif datasets == 'gim': 31 | data_path = '/raid/code/mamba_all/VM-UNet/data/zd-medic/endoscopy/endoscopy_sod/' 32 | elif datasets == 'isic_all': 33 | data_path = '/raid/code/mamba_all/VM-UNet/data/zd-medic/isic_all/' 34 | else: 35 | raise Exception('datasets in not right!') 36 | 37 | criterion = BceDiceLoss(wb=1, wd=1) 38 | 39 | pretrained_path = './pre_trained/' 40 | num_classes = 1 41 | input_size_h = 256 42 | input_size_w = 256 43 | input_channels = 3 44 | distributed = False 45 | local_rank = -1 46 | num_workers = 8 47 | seed = 42 48 | world_size = None 49 | rank = None 50 | amp = False 51 | gpu_id = '2' 52 | batch_size = 80 53 | epochs = 300 54 | 55 | work_dir = 'results/' + network + '_' + datasets + '_' + datetime.now().strftime('%A_%d_%B_%Y_%Hh_%Mm_%Ss') + '/' 56 | 57 | print_interval = 20 58 | val_interval = 2 59 | save_interval = 10 60 | threshold = 0.5 61 | 62 | train_transformer = transforms.Compose([ 63 | myNormalize(datasets, train=True), # 64 | myToTensor(), 65 | myRandomHorizontalFlip(p=0.5), 66 | myRandomVerticalFlip(p=0.5), 67 | myRandomRotation(p=0.5, degree=[0, 360]), 68 | myResize(input_size_h, input_size_w) 69 | ]) 70 | test_transformer = transforms.Compose([ 71 | myNormalize(datasets, train=False), 72 | myToTensor(), 73 | myResize(input_size_h, input_size_w) 74 | ]) 75 | 76 | opt = 'AdamW' 77 | assert opt in ['Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'ASGD', 'RMSprop', 'Rprop', 'SGD'], 'Unsupported optimizer!' 
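# Note: unlike config_setting.py, this V2 config sets 'deep_supervision': True
# in model_config above, so the model can return a tuple of side outputs
# rather than a single tensor. The evaluation loops in engine.py already
# handle this by keeping only the primary map:
#
#     out = model(img)
#     if type(out) is tuple:
#         out = out[0]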
78 | if opt == 'Adadelta': 79 | lr = 0.01 # default: 1.0 – coefficient that scale delta before it is applied to the parameters 80 | rho = 0.9 # default: 0.9 – coefficient used for computing a running average of squared gradients 81 | eps = 1e-6 # default: 1e-6 – term added to the denominator to improve numerical stability 82 | weight_decay = 0.05 # default: 0 – weight decay (L2 penalty) 83 | elif opt == 'Adagrad': 84 | lr = 0.01 # default: 0.01 – learning rate 85 | lr_decay = 0 # default: 0 – learning rate decay 86 | eps = 1e-10 # default: 1e-10 – term added to the denominator to improve numerical stability 87 | weight_decay = 0.05 # default: 0 – weight decay (L2 penalty) 88 | elif opt == 'Adam': 89 | lr = 0.001 # default: 1e-3 – learning rate 90 | betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square 91 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 92 | weight_decay = 0.0001 # default: 0 – weight decay (L2 penalty) 93 | amsgrad = False # default: False – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond 94 | elif opt == 'AdamW': 95 | lr = 0.001 # default: 1e-3 – learning rate 96 | betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square 97 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 98 | weight_decay = 1e-2 # default: 1e-2 – weight decay coefficient 99 | amsgrad = False # default: False – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond 100 | elif opt == 'Adamax': 101 | lr = 2e-3 # default: 2e-3 – learning rate 102 | betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square 103 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 104 | weight_decay = 0 # default: 0 – weight decay (L2 penalty) 105 | elif opt == 'ASGD': 106 | lr = 0.01 # default: 1e-2 – learning rate 107 | lambd = 1e-4 # default: 1e-4 – decay term 108 | alpha = 0.75 # default: 0.75 – power for eta update 109 | t0 = 1e6 # default: 1e6 – point at which to start averaging 110 | weight_decay = 0 # default: 0 – weight decay 111 | elif opt == 'RMSprop': 112 | lr = 1e-2 # default: 1e-2 – learning rate 113 | momentum = 0 # default: 0 – momentum factor 114 | alpha = 0.99 # default: 0.99 – smoothing constant 115 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 116 | centered = False # default: False – if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance 117 | weight_decay = 0 # default: 0 – weight decay (L2 penalty) 118 | elif opt == 'Rprop': 119 | lr = 1e-2 # default: 1e-2 – learning rate 120 | etas = (0.5, 1.2) # default: (0.5, 1.2) – pair of (etaminus, etaplis), that are multiplicative increase and decrease factors 121 | step_sizes = (1e-6, 50) # default: (1e-6, 50) – a pair of minimal and maximal allowed step sizes 122 | elif opt == 'SGD': 123 | lr = 0.01 # – learning rate 124 | momentum = 0.9 # default: 0 – momentum factor 125 | weight_decay = 0.05 # default: 0 – weight decay (L2 penalty) 126 | dampening = 0 # default: 0 – dampening for momentum 127 | nesterov = False # default: False – enables Nesterov momentum 128 | 129 | sch = 'CosineAnnealingLR' 130 | if sch == 'StepLR': 131 | step_size = epochs // 5 
# – Period of learning rate decay. 132 | gamma = 0.5 # – Multiplicative factor of learning rate decay. Default: 0.1 133 | last_epoch = -1 # – The index of last epoch. Default: -1. 134 | elif sch == 'MultiStepLR': 135 | milestones = [60, 120, 150] # – List of epoch indices. Must be increasing. 136 | gamma = 0.1 # – Multiplicative factor of learning rate decay. Default: 0.1. 137 | last_epoch = -1 # – The index of last epoch. Default: -1. 138 | elif sch == 'ExponentialLR': 139 | gamma = 0.99 # – Multiplicative factor of learning rate decay. 140 | last_epoch = -1 # – The index of last epoch. Default: -1. 141 | elif sch == 'CosineAnnealingLR': 142 | T_max = 50 # – Maximum number of iterations. Cosine function period. 143 | eta_min = 0.00001 # – Minimum learning rate. Default: 0. 144 | last_epoch = -1 # – The index of last epoch. Default: -1. 145 | elif sch == 'ReduceLROnPlateau': 146 | mode = 'min' # – One of min, max. In min mode, lr will be reduced when the quantity monitored has stopped decreasing; in max mode it will be reduced when the quantity monitored has stopped increasing. Default: ‘min’. 147 | factor = 0.1 # – Factor by which the learning rate will be reduced. new_lr = lr * factor. Default: 0.1. 148 | patience = 10 # – Number of epochs with no improvement after which learning rate will be reduced. For example, if patience = 2, then we will ignore the first 2 epochs with no improvement, and will only decrease the LR after the 3rd epoch if the loss still hasn’t improved then. Default: 10. 149 | threshold = 0.0001 # – Threshold for measuring the new optimum, to only focus on significant changes. Default: 1e-4. 150 | threshold_mode = 'rel' # – One of rel, abs. In rel mode, dynamic_threshold = best * ( 1 + threshold ) in ‘max’ mode or best * ( 1 - threshold ) in min mode. In abs mode, dynamic_threshold = best + threshold in max mode or best - threshold in min mode. Default: ‘rel’. 151 | cooldown = 0 # – Number of epochs to wait before resuming normal operation after lr has been reduced. Default: 0. 152 | min_lr = 0 # – A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively. Default: 0. 153 | eps = 1e-08 # – Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is ignored. Default: 1e-8. 154 | elif sch == 'CosineAnnealingWarmRestarts': 155 | T_0 = 50 # – Number of iterations for the first restart. 156 | T_mult = 2 # – A factor increases T_{i} after a restart. Default: 1. 157 | eta_min = 1e-6 # – Minimum learning rate. Default: 0. 158 | last_epoch = -1 # – The index of last epoch. Default: -1. 159 | elif sch == 'WP_MultiStepLR': 160 | warm_up_epochs = 10 161 | gamma = 0.1 162 | milestones = [125, 225] 163 | elif sch == 'WP_CosineLR': 164 | warm_up_epochs = 20 165 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | You could put the datasets here. 
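For reference, the dataset classes in `datasets/dataset.py` (and `configs/config_setting_synapse.py` for Synapse) build their paths as follows, so a layout along these lines should work (a sketch inferred from that path handling, not an official specification):

```
# Isic_datasets / NPY_datasets (GIM_datasets uses singular image/ and mask/):
<data_path>/train/images/   <data_path>/train/masks/
<data_path>/val/images/     <data_path>/val/masks/

# Polyp_datasets, with <name> one of CVC-300, CVC-ClinicDB, Kvasir,
# CVC-ColonDB, ETIS-LaribPolypDB:
<data_path>/TrainDataset/images/         <data_path>/TrainDataset/masks/
<data_path>/TestDataset/<name>/images/   <data_path>/TestDataset/<name>/masks/

# Synapse_dataset (see configs/config_setting_synapse.py):
./data/Synapse/train_npz/   ./data/Synapse/lists/lists_Synapse/   ./data/Synapse/test_vol_h5/
```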
-------------------------------------------------------------------------------- /datasets/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import numpy as np 3 | import os 4 | from PIL import Image 5 | 6 | import random 7 | import h5py 8 | import torch 9 | from scipy import ndimage 10 | from scipy.ndimage import zoom 11 | from torch.utils.data import Dataset 12 | from scipy import ndimage 13 | from PIL import Image 14 | 15 | from configs.config_setting import setting_config # debug use only 16 | 17 | 18 | # ['CVC-300', 'CVC-ClinicDB', 'Kvasir', 'CVC-ColonDB', 'ETIS-LaribPolypDB'] 19 | class Polyp_datasets(Dataset): 20 | def __init__(self, path_Data, config, train=True, test_dataset = 'CVC-300'): 21 | super().__init__() 22 | if train: 23 | images_list = sorted(os.listdir(path_Data+'TrainDataset/images/')) 24 | masks_list = sorted(os.listdir(path_Data+'TrainDataset/masks/')) 25 | self.data = [] 26 | for i in range(len(images_list)): 27 | img_path = path_Data+'TrainDataset/images/' + images_list[i] 28 | mask_path = path_Data+'TrainDataset/masks/' + masks_list[i] 29 | self.data.append([img_path, mask_path]) 30 | self.transformer = config.train_transformer 31 | else: # the test split additionally needs the test dataset's name 32 | images_list = sorted(os.listdir(path_Data+'TestDataset/' + test_dataset + '/images/')) 33 | masks_list = sorted(os.listdir(path_Data+'TestDataset/' + test_dataset + '/masks/')) 34 | self.data = [] 35 | for i in range(len(images_list)): 36 | img_path = path_Data+'TestDataset/' + test_dataset + '/images/' + images_list[i] 37 | mask_path = path_Data+'TestDataset/' + test_dataset + '/masks/' + masks_list[i] 38 | self.data.append([img_path, mask_path]) 39 | self.transformer = config.test_transformer 40 | 41 | def __getitem__(self, index): 42 | img_path, msk_path = self.data[index] 43 | img = np.array(Image.open(img_path).convert('RGB')) 44 | # the ISIC masks are not binarized 45 | msk = np.expand_dims(np.array(Image.open(msk_path).convert('L')), axis=2) / 255 46 | img, msk = self.transformer((img, msk)) 47 | return img, msk 48 | 49 | def __len__(self): 50 | return len(self.data) 51 | 52 | # ['isic17', 'isic18'] 53 | class Isic_datasets(Dataset): 54 | def __init__(self, path_Data, config, train=True, test_dataset = 'isic17'): 55 | super().__init__() 56 | if train: 57 | images_list = sorted(os.listdir(path_Data+'train/images/')) 58 | masks_list = sorted(os.listdir(path_Data+'train/masks/')) 59 | self.data = [] 60 | for i in range(len(images_list)): 61 | img_path = path_Data+'train/images/' + images_list[i] 62 | mask_path = path_Data+'train/masks/' + masks_list[i] 63 | self.data.append([img_path, mask_path]) 64 | self.transformer = config.train_transformer 65 | else: # the test split additionally needs the test dataset's name 66 | images_list = sorted(os.listdir(path_Data+'val/' + test_dataset + '/images/')) 67 | masks_list = sorted(os.listdir(path_Data+'val/' + test_dataset + '/masks/')) 68 | self.data = [] 69 | for i in range(len(images_list)): 70 | img_path = path_Data+'val/' + test_dataset + '/images/' + images_list[i] 71 | mask_path = path_Data+'val/' + test_dataset + '/masks/' + masks_list[i] 72 | self.data.append([img_path, mask_path]) 73 | self.transformer = config.test_transformer 74 | 75 | def __getitem__(self, index): 76 | img_path, msk_path = self.data[index] 77 | img = np.array(Image.open(img_path).convert('RGB')) 78 | # the ISIC masks are not binarized 79 | msk = np.expand_dims(np.array(Image.open(msk_path).convert('L')), axis=2) / 255 80 | img, msk = self.transformer((img, msk)) 81 | return img, msk 82 | 83 | def __len__(self): 84 | return len(self.data) 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | class GIM_datasets(Dataset): 93 | def __init__(self, path_Data, config, train=True): 94 | super().__init__() 95 | if train: 96 | images_list = sorted(os.listdir(path_Data+'train/image/')) 97 | masks_list = sorted(os.listdir(path_Data+'train/mask/')) 98 | self.data = [] 99 | for i in range(len(images_list)): 100 | img_path = path_Data+'train/image/' + images_list[i] 101 | mask_path = path_Data+'train/mask/' + masks_list[i] 102 | self.data.append([img_path, mask_path]) 103 | self.transformer = config.train_transformer 104 | else: 105 | images_list = sorted(os.listdir(path_Data+'val/image/')) 106 | masks_list = sorted(os.listdir(path_Data+'val/mask/')) 107 | self.data = [] 108 | for i in range(len(images_list)): 109 | img_path = path_Data+'val/image/' + images_list[i] 110 | mask_path = path_Data+'val/mask/' + masks_list[i] 111 | self.data.append([img_path, mask_path]) 112 | self.transformer = config.test_transformer 113 | 114 | def __getitem__(self, index): 115 | img_path, msk_path = self.data[index] 116 | img = np.array(Image.open(img_path).convert('RGB')) 117 | # the ISIC masks are not binarized 118 | msk = np.expand_dims(np.array(Image.open(msk_path).convert('L')), axis=2) / 255 119 | img, msk = self.transformer((img, msk)) 120 | return img, msk 121 | 122 | def __len__(self): 123 | return len(self.data) 124 | 125 | 126 | 127 | 128 | class NPY_datasets(Dataset): 129 | def __init__(self, path_Data, config, train=True): 130 | super().__init__() 131 | if train: 132 | images_list = sorted(os.listdir(path_Data+'train/images/')) 133 | masks_list = sorted(os.listdir(path_Data+'train/masks/')) 134 | self.data = [] 135 | for i in range(len(images_list)): 136 | img_path = path_Data+'train/images/' + images_list[i] 137 | mask_path = path_Data+'train/masks/' + masks_list[i] 138 | self.data.append([img_path, mask_path]) 139 | self.transformer = config.train_transformer 140 | else: 141 | images_list = sorted(os.listdir(path_Data+'val/images/')) 142 | masks_list = sorted(os.listdir(path_Data+'val/masks/')) 143 | self.data = [] 144 | for i in range(len(images_list)): 145 | img_path = path_Data+'val/images/' + images_list[i] 146 | mask_path = path_Data+'val/masks/' + masks_list[i] 147 | self.data.append([img_path, mask_path]) 148 | self.transformer = config.test_transformer 149 | 150 | def __getitem__(self, index): 151 | img_path, msk_path = self.data[index] 152 | img = np.array(Image.open(img_path).convert('RGB')) 153 | # the ISIC masks are not binarized 154 | msk = np.expand_dims(np.array(Image.open(msk_path).convert('L')), axis=2) / 255 155 | img, msk = self.transformer((img, msk)) 156 | return img, msk 157 | 158 | def __len__(self): 159 | return len(self.data) 160 | 161 | 162 | 163 | def random_rot_flip(image, label): 164 | k = np.random.randint(0, 4) 165 | image = np.rot90(image, k) 166 | label = np.rot90(label, k) 167 | axis = np.random.randint(0, 2) 168 | image = np.flip(image, axis=axis).copy() 169 | label = np.flip(label, axis=axis).copy() 170 | return image, label 171 | 172 | 173 | def random_rotate(image, label): 174 | angle = np.random.randint(-20, 20) 175 | image = ndimage.rotate(image, angle, order=0, reshape=False) 176 | label = ndimage.rotate(label, angle, order=0, reshape=False) 177 | return image, label 178 | 179 | 180 | class RandomGenerator(object): 181 | def __init__(self, output_size): 182 | self.output_size = output_size 183 | 184 | def
__call__(self, sample): 185 | image, label = sample['image'], sample['label'] 186 | 187 | if random.random() > 0.5: 188 | image, label = random_rot_flip(image, label) 189 | elif random.random() > 0.5: 190 | image, label = random_rotate(image, label) 191 | x, y = image.shape 192 | if x != self.output_size[0] or y != self.output_size[1]: 193 | image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=3) # why not 3? 194 | label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0) 195 | image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0) 196 | label = torch.from_numpy(label.astype(np.float32)) 197 | sample = {'image': image, 'label': label.long()} 198 | return sample 199 | 200 | 201 | class Synapse_dataset(Dataset): 202 | def __init__(self, base_dir, list_dir, split, transform=None): 203 | self.transform = transform # using transform in torch! 204 | self.split = split 205 | self.sample_list = open(os.path.join(list_dir, self.split+'.txt')).readlines() 206 | self.data_dir = base_dir 207 | 208 | def __len__(self): 209 | return len(self.sample_list) 210 | 211 | def __getitem__(self, idx): 212 | if self.split == "train": 213 | slice_name = self.sample_list[idx].strip('\n') 214 | data_path = os.path.join(self.data_dir, slice_name+'.npz') 215 | data = np.load(data_path) 216 | image, label = data['image'], data['label'] 217 | else: 218 | vol_name = self.sample_list[idx].strip('\n') 219 | filepath = self.data_dir + "/{}.npy.h5".format(vol_name) 220 | data = h5py.File(filepath) 221 | image, label = data['image'][:], data['label'][:] 222 | 223 | sample = {'image': image, 'label': label} 224 | if self.transform: 225 | sample = self.transform(sample) 226 | sample['case_name'] = self.sample_list[idx].strip('\n') 227 | return sample 228 | 229 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import torch 4 | from torch.cuda.amp import autocast as autocast 5 | from sklearn.metrics import confusion_matrix 6 | from utils import save_imgs 7 | 8 | 9 | def train_one_epoch(train_loader, 10 | model, 11 | criterion, 12 | optimizer, 13 | scheduler, 14 | epoch, 15 | step, 16 | logger, 17 | config, 18 | writer): 19 | ''' 20 | train model for one epoch 21 | ''' 22 | # switch to train mode 23 | model.train() 24 | 25 | loss_list = [] 26 | 27 | for iter, data in enumerate(train_loader): 28 | step += iter 29 | optimizer.zero_grad() 30 | images, targets = data 31 | images, targets = images.cuda(non_blocking=True).float(), targets.cuda(non_blocking=True).float() 32 | 33 | out = model(images) 34 | loss = criterion(out, targets) 35 | 36 | loss.backward() 37 | optimizer.step() 38 | 39 | loss_list.append(loss.item()) 40 | 41 | now_lr = optimizer.state_dict()['param_groups'][0]['lr'] 42 | 43 | writer.add_scalar('loss', loss, global_step=step) 44 | 45 | if iter % config.print_interval == 0: 46 | log_info = f'train: epoch {epoch}, iter:{iter}, loss: {np.mean(loss_list):.4f}, lr: {now_lr}' 47 | print(log_info) 48 | logger.info(log_info) 49 | scheduler.step() 50 | return step 51 | 52 | 53 | def val_one_epoch(test_loader, 54 | model, 55 | criterion, 56 | epoch, 57 | logger, 58 | config, 59 | val_data_name=None): 60 | # switch to evaluate mode 61 | model.eval() 62 | preds = [] 63 | gts = [] 64 | loss_list = [] 65 | with torch.no_grad(): 66 | for data in tqdm(test_loader): 67 | img, msk = data 68 | img, msk 
= img.cuda(non_blocking=True).float(), msk.cuda(non_blocking=True).float() 69 | 70 | out = model(img) 71 | 72 | loss = criterion(out, msk) 73 | 74 | loss_list.append(loss.item()) 75 | gts.append(msk.squeeze(1).cpu().detach().numpy()) 76 | if type(out) is tuple: 77 | out = out[0] 78 | out = out.squeeze(1).cpu().detach().numpy() 79 | preds.append(out) 80 | 81 | if epoch % config.val_interval == 0: 82 | preds = np.array(preds).reshape(-1) 83 | gts = np.array(gts).reshape(-1) 84 | 85 | y_pre = np.where(preds>=config.threshold, 1, 0) 86 | y_true = np.where(gts>=0.5, 1, 0) 87 | 88 | confusion = confusion_matrix(y_true, y_pre) 89 | TN, FP, FN, TP = confusion[0,0], confusion[0,1], confusion[1,0], confusion[1,1] 90 | 91 | accuracy = float(TN + TP) / float(np.sum(confusion)) if float(np.sum(confusion)) != 0 else 0 92 | sensitivity = float(TP) / float(TP + FN) if float(TP + FN) != 0 else 0 93 | specificity = float(TN) / float(TN + FP) if float(TN + FP) != 0 else 0 94 | f1_or_dsc = float(2 * TP) / float(2 * TP + FP + FN) if float(2 * TP + FP + FN) != 0 else 0 95 | miou = float(TP) / float(TP + FP + FN) if float(TP + FP + FN) != 0 else 0 96 | 97 | if val_data_name is not None: 98 | log_info = f'val_datasets_name: {val_data_name}' 99 | print(log_info) 100 | logger.info(log_info) 101 | log_info = f' val epoch: {epoch}, loss: {np.mean(loss_list):.4f}, miou: {miou}, f1_or_dsc: {f1_or_dsc}, accuracy: {accuracy}, \ 102 | specificity: {specificity}, sensitivity: {sensitivity}, confusion_matrix: {confusion}' 103 | print(log_info) 104 | logger.info(log_info) 105 | 106 | else: 107 | log_info = f' val epoch: {epoch}, loss: {np.mean(loss_list):.4f}' 108 | print(log_info) 109 | logger.info(log_info) 110 | 111 | return np.mean(loss_list) 112 | 113 | 114 | def test_one_epoch(test_loader, 115 | model, 116 | criterion, 117 | logger, 118 | config, 119 | test_data_name=None): 120 | # switch to evaluate mode 121 | model.eval() 122 | preds = [] 123 | gts = [] 124 | loss_list = [] 125 | with torch.no_grad(): 126 | for i, data in enumerate(tqdm(test_loader)): 127 | img, msk = data 128 | img, msk = img.cuda(non_blocking=True).float(), msk.cuda(non_blocking=True).float() 129 | 130 | out = model(img) 131 | loss = criterion(out, msk) 132 | 133 | loss_list.append(loss.item()) 134 | msk = msk.squeeze(1).cpu().detach().numpy() 135 | gts.append(msk) 136 | if type(out) is tuple: 137 | out = out[0] 138 | out = out.squeeze(1).cpu().detach().numpy() 139 | preds.append(out) 140 | if i % config.save_interval == 0: 141 | save_imgs(img, msk, out, i, config.work_dir + 'outputs/', config.datasets, config.threshold, test_data_name=test_data_name) 142 | 143 | preds = np.array(preds).reshape(-1) 144 | gts = np.array(gts).reshape(-1) 145 | 146 | y_pre = np.where(preds>=config.threshold, 1, 0) 147 | y_true = np.where(gts>=0.5, 1, 0) 148 | 149 | confusion = confusion_matrix(y_true, y_pre) 150 | TN, FP, FN, TP = confusion[0,0], confusion[0,1], confusion[1,0], confusion[1,1] 151 | 152 | accuracy = float(TN + TP) / float(np.sum(confusion)) if float(np.sum(confusion)) != 0 else 0 153 | sensitivity = float(TP) / float(TP + FN) if float(TP + FN) != 0 else 0 154 | specificity = float(TN) / float(TN + FP) if float(TN + FP) != 0 else 0 155 | f1_or_dsc = float(2 * TP) / float(2 * TP + FP + FN) if float(2 * TP + FP + FN) != 0 else 0 156 | miou = float(TP) / float(TP + FP + FN) if float(TP + FP + FN) != 0 else 0 157 | 158 | if test_data_name is not None: 159 | log_info = f'test_datasets_name: {test_data_name}' 160 | print(log_info) 161 | logger.info(log_info) 
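# Reference for the metrics logged below, as computed above from the pixelwise
# confusion matrix (TN, FP, FN, TP) over all flattened predictions:
#   accuracy    = (TP + TN) / (TP + TN + FP + FN)
#   sensitivity = TP / (TP + FN)             # foreground recall
#   specificity = TN / (TN + FP)
#   f1_or_dsc   = 2*TP / (2*TP + FP + FN)    # Dice coefficient
#   miou        = TP / (TP + FP + FN)        # single-foreground IoU here,
#                                            # not a multi-class mean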
162 | log_info = f'test of best model, loss: {np.mean(loss_list):.4f}, miou: {miou}, f1_or_dsc: {f1_or_dsc}, accuracy: {accuracy}, \ 163 | specificity: {specificity}, sensitivity: {sensitivity}, confusion_matrix: {confusion}' 164 | print(log_info) 165 | logger.info(log_info) 166 | 167 | return np.mean(loss_list) -------------------------------------------------------------------------------- /engine_synapse.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | 4 | from torch.cuda.amp import autocast as autocast 5 | import torch 6 | 7 | from sklearn.metrics import confusion_matrix 8 | 9 | from scipy.ndimage import binary_fill_holes, binary_opening 10 | 11 | from utils import test_single_volume 12 | 13 | import time 14 | 15 | 16 | def train_one_epoch(train_loader, 17 | model, 18 | criterion, 19 | optimizer, 20 | scheduler, 21 | epoch, 22 | logger, 23 | config, 24 | scaler=None): 25 | ''' 26 | train model for one epoch 27 | ''' 28 | stime = time.time() 29 | model.train() 30 | 31 | loss_list = [] 32 | 33 | for iter, data in enumerate(train_loader): 34 | optimizer.zero_grad() 35 | 36 | images, targets = data['image'], data['label'] 37 | images, targets = images.cuda(non_blocking=True).float(), targets.cuda(non_blocking=True).float() 38 | 39 | if config.amp: 40 | with autocast(): 41 | out = model(images) 42 | loss = criterion(out, targets) 43 | scaler.scale(loss).backward() 44 | scaler.step(optimizer) 45 | scaler.update() 46 | else: 47 | out = model(images) 48 | loss = criterion(out, targets) 49 | loss.backward() 50 | optimizer.step() 51 | 52 | loss_list.append(loss.item()) 53 | now_lr = optimizer.state_dict()['param_groups'][0]['lr'] 54 | mean_loss = np.mean(loss_list) 55 | if iter % config.print_interval == 0: 56 | log_info = f'train: epoch {epoch}, iter:{iter}, loss: {loss.item():.4f}, lr: {now_lr}' 57 | print(log_info) 58 | logger.info(log_info) 59 | scheduler.step() 60 | etime = time.time() 61 | log_info = f'Finish one epoch train: epoch {epoch}, loss: {mean_loss:.4f}, time(s): {etime-stime:.2f}' 62 | print(log_info) 63 | logger.info(log_info) 64 | return mean_loss 65 | 66 | 67 | 68 | 69 | 70 | def val_one_epoch(test_datasets, 71 | test_loader, 72 | model, 73 | epoch, 74 | logger, 75 | config, 76 | test_save_path, 77 | val_or_test=False): 78 | # switch to evaluate mode 79 | stime = time.time() 80 | model.eval() 81 | with torch.no_grad(): 82 | metric_list = 0.0 83 | i_batch = 0 84 | for data in tqdm(test_loader): 85 | img, msk, case_name = data['image'], data['label'], data['case_name'][0] 86 | metric_i = test_single_volume(img, msk, model, classes=config.num_classes, patch_size=[config.input_size_h, config.input_size_w], 87 | test_save_path=test_save_path, case=case_name, z_spacing=config.z_spacing, val_or_test=val_or_test) 88 | metric_list += np.array(metric_i) 89 | 90 | logger.info('idx %d case %s mean_dice %f mean_hd95 %f' % (i_batch, case_name, 91 | np.mean(metric_i, axis=0)[0], np.mean(metric_i, axis=0)[1])) 92 | i_batch += 1 93 | metric_list = metric_list / len(test_datasets) 94 | performance = np.mean(metric_list, axis=0)[0] 95 | mean_hd95 = np.mean(metric_list, axis=0)[1] 96 | for i in range(1, config.num_classes): 97 | logger.info('Mean class %d mean_dice %f mean_hd95 %f' % (i, metric_list[i-1][0], metric_list[i-1][1])) 98 | etime = time.time() 99 | log_info = f'val epoch: {epoch}, 
mean_dice: {performance}, mean_hd95: {mean_hd95}, time(s): {etime-stime:.2f}' 100 | print(log_info) 101 | logger.info(log_info) 102 | 103 | return performance, mean_hd95 -------------------------------------------------------------------------------- /models/unet_v2/UNet_v2.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import warnings 3 | 4 | import torch 5 | from torch import nn 6 | from unet_v2.pvtv2 import * 7 | import torch.nn.functional as F 8 | 9 | 10 | class ChannelAttention(nn.Module): 11 | def __init__(self, in_planes, ratio=16): 12 | super(ChannelAttention, self).__init__() 13 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 14 | self.max_pool = nn.AdaptiveMaxPool2d(1) 15 | 16 | self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False) # use the ratio argument (was hard-coded to 16) 17 | self.relu1 = nn.ReLU() 18 | self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False) 19 | 20 | self.sigmoid = nn.Sigmoid() 21 | 22 | def forward(self, x): 23 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 24 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 25 | out = avg_out + max_out 26 | return self.sigmoid(out) 27 | 28 | 29 | class SpatialAttention(nn.Module): 30 | def __init__(self, kernel_size=7): 31 | super(SpatialAttention, self).__init__() 32 | 33 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 34 | padding = 3 if kernel_size == 7 else 1 35 | 36 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) 37 | self.sigmoid = nn.Sigmoid() 38 | 39 | def forward(self, x): 40 | avg_out = torch.mean(x, dim=1, keepdim=True) 41 | max_out, _ = torch.max(x, dim=1, keepdim=True) 42 | x = torch.cat([avg_out, max_out], dim=1) 43 | x = self.conv1(x) 44 | return self.sigmoid(x) 45 | 46 | 47 | class BasicConv2d(nn.Module): 48 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 49 | super(BasicConv2d, self).__init__() 50 | 51 | self.conv = nn.Conv2d(in_planes, out_planes, 52 | kernel_size=kernel_size, stride=stride, 53 | padding=padding, dilation=dilation, bias=False) 54 | self.bn = nn.BatchNorm2d(out_planes) 55 | self.relu = nn.ReLU(inplace=True) 56 | 57 | def forward(self, x): 58 | x = self.conv(x) 59 | x = self.bn(x) 60 | return x # note: self.relu is defined but not applied here -- this block is conv + BN only 61 | 62 | 63 | class Encoder(nn.Module): 64 | def __init__(self, pretrain_path): 65 | super().__init__() 66 | self.backbone = pvt_v2_b2() 67 | 68 | if pretrain_path is None: 69 | warnings.warn('please provide the pretrained pvt model. Not using pretrained model.') 70 | elif not os.path.isfile(pretrain_path): 71 | warnings.warn(f'path: {pretrain_path} does not exist. 
Not using pretrained model.') 72 | else: 73 | print(f"using pretrained file: {pretrain_path}") 74 | save_model = torch.load(pretrain_path) 75 | model_dict = self.backbone.state_dict() 76 | state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()} 77 | model_dict.update(state_dict) 78 | 79 | self.backbone.load_state_dict(model_dict) 80 | 81 | def forward(self, x): 82 | f1, f2, f3, f4 = self.backbone(x) # (x: 3, 352, 352) 83 | return f1, f2, f3, f4 84 | 85 | 86 | class SDI(nn.Module): 87 | def __init__(self, channel): 88 | super().__init__() 89 | 90 | self.convs = nn.ModuleList( 91 | [nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1) for _ in range(4)]) 92 | 93 | def forward(self, xs, anchor): 94 | ans = torch.ones_like(anchor) 95 | target_size = anchor.shape[-1] 96 | 97 | for i, x in enumerate(xs): 98 | if x.shape[-1] > target_size: 99 | x = F.adaptive_avg_pool2d(x, (target_size, target_size)) 100 | elif x.shape[-1] < target_size: 101 | x = F.interpolate(x, size=(target_size, target_size), 102 | mode='bilinear', align_corners=True) 103 | 104 | ans = ans * self.convs[i](x) 105 | 106 | return ans 107 | 108 | 109 | class UNetV2(nn.Module): 110 | """ 111 | use SpatialAtt + ChannelAtt 112 | """ 113 | def __init__(self, channel=32, n_classes=1, deep_supervision=True, pretrained_path=None): 114 | super().__init__() 115 | self.deep_supervision = deep_supervision 116 | 117 | self.encoder = Encoder(pretrained_path) 118 | 119 | self.ca_1 = ChannelAttention(64) 120 | self.sa_1 = SpatialAttention() 121 | 122 | self.ca_2 = ChannelAttention(128) 123 | self.sa_2 = SpatialAttention() 124 | 125 | self.ca_3 = ChannelAttention(320) 126 | self.sa_3 = SpatialAttention() 127 | 128 | self.ca_4 = ChannelAttention(512) 129 | self.sa_4 = SpatialAttention() 130 | 131 | self.Translayer_1 = BasicConv2d(64, channel, 1) 132 | self.Translayer_2 = BasicConv2d(128, channel, 1) 133 | self.Translayer_3 = BasicConv2d(320, channel, 1) 134 | self.Translayer_4 = BasicConv2d(512, channel, 1) 135 | 136 | self.sdi_1 = SDI(channel) 137 | self.sdi_2 = SDI(channel) 138 | self.sdi_3 = SDI(channel) 139 | self.sdi_4 = SDI(channel) 140 | 141 | self.seg_outs = nn.ModuleList([ 142 | nn.Conv2d(channel, n_classes, 1, 1) for _ in range(4)]) 143 | 144 | self.deconv2 = nn.ConvTranspose2d(channel, channel, kernel_size=4, stride=2, padding=1, 145 | bias=False) 146 | self.deconv3 = nn.ConvTranspose2d(channel, channel, kernel_size=4, stride=2, 147 | padding=1, bias=False) 148 | self.deconv4 = nn.ConvTranspose2d(channel, channel, kernel_size=4, stride=2, 149 | padding=1, bias=False) 150 | self.deconv5 = nn.ConvTranspose2d(channel, channel, kernel_size=4, stride=2, 151 | padding=1, bias=False) 152 | 153 | def forward(self, x): 154 | seg_outs = [] 155 | f1, f2, f3, f4 = self.encoder(x) 156 | 157 | f1 = self.ca_1(f1) * f1 158 | f1 = self.sa_1(f1) * f1 159 | f1 = self.Translayer_1(f1) 160 | 161 | f2 = self.ca_2(f2) * f2 162 | f2 = self.sa_2(f2) * f2 163 | f2 = self.Translayer_2(f2) 164 | 165 | f3 = self.ca_3(f3) * f3 166 | f3 = self.sa_3(f3) * f3 167 | f3 = self.Translayer_3(f3) 168 | 169 | f4 = self.ca_4(f4) * f4 170 | f4 = self.sa_4(f4) * f4 171 | f4 = self.Translayer_4(f4) 172 | 173 | f41 = self.sdi_4([f1, f2, f3, f4], f4) 174 | f31 = self.sdi_3([f1, f2, f3, f4], f3) 175 | f21 = self.sdi_2([f1, f2, f3, f4], f2) 176 | f11 = self.sdi_1([f1, f2, f3, f4], f1) 177 | 178 | seg_outs.append(self.seg_outs[0](f41)) 179 | 180 | y = self.deconv2(f41) + f31 181 | seg_outs.append(self.seg_outs[1](y)) 182 | 183 | y = self.deconv3(y) + 
f21 184 | seg_outs.append(self.seg_outs[2](y)) 185 | 186 | y = self.deconv4(y) + f11 187 | seg_outs.append(self.seg_outs[3](y)) 188 | 189 | for i, o in enumerate(seg_outs): 190 | seg_outs[i] = F.interpolate(o, scale_factor=4, mode='bilinear') 191 | 192 | if self.deep_supervision: 193 | return seg_outs[::-1] 194 | else: 195 | return seg_outs[-1] 196 | 197 | 198 | if __name__ == "__main__": 199 | pretrained_path = "/root/workspace/code/U-Net_v2/PolypSeg/pvt_pretrained/pvt_v2_b2.pth" 200 | model = UNetV2(n_classes=2, deep_supervision=True, pretrained_path=None) 201 | x = torch.rand((2, 3, 256, 256)) 202 | ys = model(x) 203 | for y in ys: 204 | print(y.shape) 205 | -------------------------------------------------------------------------------- /models/unet_v2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nobodyplayer1/VM-UNetV2/b65f04c10237158ddd4ebe5bf2c1b5cbb15524fb/models/unet_v2/__init__.py -------------------------------------------------------------------------------- /models/unet_v2/pvtv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | 6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 7 | from timm.models.registry import register_model 8 | from timm.models.vision_transformer import _cfg 9 | 10 | 11 | import math 12 | 13 | 14 | class Mlp(nn.Module): 15 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 16 | super().__init__() 17 | out_features = out_features or in_features 18 | hidden_features = hidden_features or in_features 19 | self.fc1 = nn.Linear(in_features, hidden_features) 20 | self.dwconv = DWConv(hidden_features) 21 | self.act = act_layer() 22 | self.fc2 = nn.Linear(hidden_features, out_features) 23 | self.drop = nn.Dropout(drop) 24 | 25 | self.apply(self._init_weights) 26 | 27 | def _init_weights(self, m): 28 | if isinstance(m, nn.Linear): 29 | trunc_normal_(m.weight, std=.02) 30 | if isinstance(m, nn.Linear) and m.bias is not None: 31 | nn.init.constant_(m.bias, 0) 32 | elif isinstance(m, nn.LayerNorm): 33 | nn.init.constant_(m.bias, 0) 34 | nn.init.constant_(m.weight, 1.0) 35 | elif isinstance(m, nn.Conv2d): 36 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 37 | fan_out //= m.groups 38 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 39 | if m.bias is not None: 40 | m.bias.data.zero_() 41 | 42 | def forward(self, x, H, W): 43 | x = self.fc1(x) 44 | x = self.dwconv(x, H, W) 45 | x = self.act(x) 46 | x = self.drop(x) 47 | x = self.fc2(x) 48 | x = self.drop(x) 49 | return x 50 | 51 | 52 | class Attention(nn.Module): 53 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 54 | super().__init__() 55 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
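# Spatial-reduction (SR) attention, the PVTv2 variant implemented below:
# queries keep the full token length N = H*W, while keys/values are taken
# from a feature map shrunk by `sr_ratio` (strided conv + LayerNorm), so
# the attention matrix costs roughly O(N * N / sr_ratio**2) instead of O(N**2).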
56 | 57 | self.dim = dim 58 | self.num_heads = num_heads 59 | head_dim = dim // num_heads 60 | self.scale = qk_scale or head_dim ** -0.5 61 | 62 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 63 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 64 | self.attn_drop = nn.Dropout(attn_drop) 65 | self.proj = nn.Linear(dim, dim) 66 | self.proj_drop = nn.Dropout(proj_drop) 67 | 68 | self.sr_ratio = sr_ratio 69 | if sr_ratio > 1: 70 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 71 | self.norm = nn.LayerNorm(dim) 72 | 73 | self.apply(self._init_weights) 74 | 75 | def _init_weights(self, m): 76 | if isinstance(m, nn.Linear): 77 | trunc_normal_(m.weight, std=.02) 78 | if isinstance(m, nn.Linear) and m.bias is not None: 79 | nn.init.constant_(m.bias, 0) 80 | elif isinstance(m, nn.LayerNorm): 81 | nn.init.constant_(m.bias, 0) 82 | nn.init.constant_(m.weight, 1.0) 83 | elif isinstance(m, nn.Conv2d): 84 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 85 | fan_out //= m.groups 86 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 87 | if m.bias is not None: 88 | m.bias.data.zero_() 89 | 90 | def forward(self, x, H, W): 91 | B, N, C = x.shape 92 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 93 | 94 | if self.sr_ratio > 1: 95 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 96 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 97 | x_ = self.norm(x_) 98 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 99 | else: 100 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 101 | k, v = kv[0], kv[1] 102 | 103 | attn = (q @ k.transpose(-2, -1)) * self.scale 104 | attn = attn.softmax(dim=-1) 105 | attn = self.attn_drop(attn) 106 | 107 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 108 | x = self.proj(x) 109 | x = self.proj_drop(x) 110 | 111 | return x 112 | 113 | 114 | class Block(nn.Module): 115 | 116 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 117 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 118 | super().__init__() 119 | self.norm1 = norm_layer(dim) 120 | self.attn = Attention( 121 | dim, 122 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 123 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 124 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 125 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 126 | self.norm2 = norm_layer(dim) 127 | mlp_hidden_dim = int(dim * mlp_ratio) 128 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 129 | 130 | self.apply(self._init_weights) 131 | 132 | def _init_weights(self, m): 133 | if isinstance(m, nn.Linear): 134 | trunc_normal_(m.weight, std=.02) 135 | if isinstance(m, nn.Linear) and m.bias is not None: 136 | nn.init.constant_(m.bias, 0) 137 | elif isinstance(m, nn.LayerNorm): 138 | nn.init.constant_(m.bias, 0) 139 | nn.init.constant_(m.weight, 1.0) 140 | elif isinstance(m, nn.Conv2d): 141 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 142 | fan_out //= m.groups 143 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 144 | if m.bias is not None: 145 | m.bias.data.zero_() 146 | 147 | def forward(self, x, H, W): 148 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 149 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 150 | 151 | return x 152 | 153 | 154 | class OverlapPatchEmbed(nn.Module): 155 | """ Image to Patch Embedding 156 | """ 157 | 158 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 159 | super().__init__() 160 | img_size = to_2tuple(img_size) 161 | patch_size = to_2tuple(patch_size) 162 | 163 | self.img_size = img_size 164 | self.patch_size = patch_size 165 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 166 | self.num_patches = self.H * self.W 167 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 168 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 169 | self.norm = nn.LayerNorm(embed_dim) 170 | 171 | self.apply(self._init_weights) 172 | 173 | def _init_weights(self, m): 174 | if isinstance(m, nn.Linear): 175 | trunc_normal_(m.weight, std=.02) 176 | if isinstance(m, nn.Linear) and m.bias is not None: 177 | nn.init.constant_(m.bias, 0) 178 | elif isinstance(m, nn.LayerNorm): 179 | nn.init.constant_(m.bias, 0) 180 | nn.init.constant_(m.weight, 1.0) 181 | elif isinstance(m, nn.Conv2d): 182 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 183 | fan_out //= m.groups 184 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 185 | if m.bias is not None: 186 | m.bias.data.zero_() 187 | 188 | def forward(self, x): 189 | x = self.proj(x) 190 | _, _, H, W = x.shape 191 | x = x.flatten(2).transpose(1, 2) 192 | x = self.norm(x) 193 | 194 | return x, H, W 195 | 196 | 197 | class PyramidVisionTransformerImpr(nn.Module): 198 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 199 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 200 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 201 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): 202 | super().__init__() 203 | self.num_classes = num_classes 204 | self.depths = depths 205 | 206 | # patch_embed 207 | self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, 208 | embed_dim=embed_dims[0]) 209 | self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], 210 | embed_dim=embed_dims[1]) 211 | self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], 212 | embed_dim=embed_dims[2]) 213 | self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], 214 | embed_dim=embed_dims[3]) 215 | 
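# The four overlapping patch embeddings above downsample by 4x once and
# then 2x per stage, so forward_features() returns feature maps at strides
# 4 / 8 / 16 / 32 of the input -- the pyramid that UNetV2's decoder consumes.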
216 | # transformer encoder 217 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 218 | cur = 0 219 | self.block1 = nn.ModuleList([Block( 220 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, 221 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 222 | sr_ratio=sr_ratios[0]) 223 | for i in range(depths[0])]) 224 | self.norm1 = norm_layer(embed_dims[0]) 225 | 226 | cur += depths[0] 227 | self.block2 = nn.ModuleList([Block( 228 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, 229 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 230 | sr_ratio=sr_ratios[1]) 231 | for i in range(depths[1])]) 232 | self.norm2 = norm_layer(embed_dims[1]) 233 | 234 | cur += depths[1] 235 | self.block3 = nn.ModuleList([Block( 236 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, 237 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 238 | sr_ratio=sr_ratios[2]) 239 | for i in range(depths[2])]) 240 | self.norm3 = norm_layer(embed_dims[2]) 241 | 242 | cur += depths[2] 243 | self.block4 = nn.ModuleList([Block( 244 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, 245 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 246 | sr_ratio=sr_ratios[3]) 247 | for i in range(depths[3])]) 248 | self.norm4 = norm_layer(embed_dims[3]) 249 | 250 | # classification head 251 | # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() 252 | 253 | self.apply(self._init_weights) 254 | 255 | def _init_weights(self, m): 256 | if isinstance(m, nn.Linear): 257 | trunc_normal_(m.weight, std=.02) 258 | if isinstance(m, nn.Linear) and m.bias is not None: 259 | nn.init.constant_(m.bias, 0) 260 | elif isinstance(m, nn.LayerNorm): 261 | nn.init.constant_(m.bias, 0) 262 | nn.init.constant_(m.weight, 1.0) 263 | elif isinstance(m, nn.Conv2d): 264 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 265 | fan_out //= m.groups 266 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 267 | if m.bias is not None: 268 | m.bias.data.zero_() 269 | 270 | def init_weights(self, pretrained=None): 271 | if isinstance(pretrained, str): 272 | logger = 1 273 | #load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) 274 | 275 | def reset_drop_path(self, drop_path_rate): 276 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] 277 | cur = 0 278 | for i in range(self.depths[0]): 279 | self.block1[i].drop_path.drop_prob = dpr[cur + i] 280 | 281 | cur += self.depths[0] 282 | for i in range(self.depths[1]): 283 | self.block2[i].drop_path.drop_prob = dpr[cur + i] 284 | 285 | cur += self.depths[1] 286 | for i in range(self.depths[2]): 287 | self.block3[i].drop_path.drop_prob = dpr[cur + i] 288 | 289 | cur += self.depths[2] 290 | for i in range(self.depths[3]): 291 | self.block4[i].drop_path.drop_prob = dpr[cur + i] 292 | 293 | def freeze_patch_emb(self): 294 | self.patch_embed1.requires_grad = False 295 | 296 | @torch.jit.ignore 297 | def no_weight_decay(self): 298 | return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better 299 | 300 | def get_classifier(self): 301 | return 
self.head 302 | 303 | def reset_classifier(self, num_classes, global_pool=''): 304 | self.num_classes = num_classes 305 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 306 | 307 | # def _get_pos_embed(self, pos_embed, patch_embed, H, W): 308 | # if H * W == self.patch_embed1.num_patches: 309 | # return pos_embed 310 | # else: 311 | # return F.interpolate( 312 | # pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2), 313 | # size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1) 314 | 315 | def forward_features(self, x): 316 | B = x.shape[0] 317 | outs = [] 318 | 319 | # stage 1 320 | x, H, W = self.patch_embed1(x) 321 | for i, blk in enumerate(self.block1): 322 | x = blk(x, H, W) 323 | x = self.norm1(x) 324 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 325 | outs.append(x) 326 | 327 | # stage 2 328 | x, H, W = self.patch_embed2(x) 329 | for i, blk in enumerate(self.block2): 330 | x = blk(x, H, W) 331 | x = self.norm2(x) 332 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 333 | outs.append(x) 334 | 335 | # stage 3 336 | x, H, W = self.patch_embed3(x) 337 | for i, blk in enumerate(self.block3): 338 | x = blk(x, H, W) 339 | x = self.norm3(x) 340 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 341 | outs.append(x) 342 | 343 | # stage 4 344 | x, H, W = self.patch_embed4(x) 345 | for i, blk in enumerate(self.block4): 346 | x = blk(x, H, W) 347 | x = self.norm4(x) 348 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 349 | outs.append(x) 350 | 351 | return outs 352 | 353 | # return x.mean(dim=1) 354 | 355 | def forward(self, x): 356 | x = self.forward_features(x) 357 | # x = self.head(x) 358 | 359 | return x 360 | 361 | 362 | class DWConv(nn.Module): 363 | def __init__(self, dim=768): 364 | super(DWConv, self).__init__() 365 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 366 | 367 | def forward(self, x, H, W): 368 | B, N, C = x.shape 369 | x = x.transpose(1, 2).view(B, C, H, W) 370 | x = self.dwconv(x) 371 | x = x.flatten(2).transpose(1, 2) 372 | 373 | return x 374 | 375 | 376 | def _conv_filter(state_dict, patch_size=16): 377 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 378 | out_dict = {} 379 | for k, v in state_dict.items(): 380 | if 'patch_embed.proj.weight' in k: 381 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 382 | out_dict[k] = v 383 | 384 | return out_dict 385 | 386 | 387 | @register_model 388 | class pvt_v2_b0(PyramidVisionTransformerImpr): 389 | def __init__(self, **kwargs): 390 | super(pvt_v2_b0, self).__init__( 391 | patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 392 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 393 | drop_rate=0.0, drop_path_rate=0.1) 394 | 395 | 396 | 397 | @register_model 398 | class pvt_v2_b1(PyramidVisionTransformerImpr): 399 | def __init__(self, **kwargs): 400 | super(pvt_v2_b1, self).__init__( 401 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 402 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 403 | drop_rate=0.0, drop_path_rate=0.1) 404 | 405 | @register_model 406 | class pvt_v2_b2(PyramidVisionTransformerImpr): 407 | def __init__(self, **kwargs): 408 | super(pvt_v2_b2, self).__init__( 409 | patch_size=4, embed_dims=[64, 128, 320, 512], 
num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 410 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], 411 | drop_rate=0.0, drop_path_rate=0.1) 412 | 413 | @register_model 414 | class pvt_v2_b3(PyramidVisionTransformerImpr): 415 | def __init__(self, **kwargs): 416 | super(pvt_v2_b3, self).__init__( 417 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 418 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 419 | drop_rate=0.0, drop_path_rate=0.1) 420 | 421 | @register_model 422 | class pvt_v2_b4(PyramidVisionTransformerImpr): 423 | def __init__(self, **kwargs): 424 | super(pvt_v2_b4, self).__init__( 425 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 426 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], 427 | drop_rate=0.0, drop_path_rate=0.1) 428 | 429 | 430 | @register_model 431 | class pvt_v2_b5(PyramidVisionTransformerImpr): 432 | def __init__(self, **kwargs): 433 | super(pvt_v2_b5, self).__init__( 434 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 435 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], 436 | drop_rate=0.0, drop_path_rate=0.1) -------------------------------------------------------------------------------- /models/vmunet/vmunet.py: -------------------------------------------------------------------------------- 1 | # from .vmamba import VSSM 2 | from vmamba import VSSM # debug use 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class VMUNet(nn.Module): 8 | def __init__(self, 9 | input_channels=3, 10 | num_classes=1, 11 | depths=[2, 2, 9, 2], 12 | depths_decoder=[2, 9, 2, 2], 13 | drop_path_rate=0.2, 14 | load_ckpt_path=None, 15 | ): 16 | super().__init__() 17 | 18 | self.load_ckpt_path = load_ckpt_path 19 | self.num_classes = num_classes 20 | 21 | self.vmunet = VSSM(in_chans=input_channels, 22 | num_classes=num_classes, 23 | depths=depths, 24 | depths_decoder=depths_decoder, 25 | drop_path_rate=drop_path_rate, 26 | ) 27 | 28 | def forward(self, x): 29 | if x.size()[1] == 1: # grayscale input: repeat the single channel to get 3 channels 30 | x = x.repeat(1,3,1,1) 31 | logits = self.vmunet(x) # x [1, 3, 256, 256], logits [1, 1, 256, 256] 32 | if self.num_classes == 1: return torch.sigmoid(logits) 33 | else: return logits 34 | 35 | def load_from(self): 36 | if self.load_ckpt_path is not None: 37 | model_dict = self.vmunet.state_dict() 38 | modelCheckpoint = torch.load(self.load_ckpt_path) 39 | pretrained_dict = modelCheckpoint['model'] 40 | # filter: keep only keys that also exist in the current model 41 | new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()} 42 | model_dict.update(new_dict) 43 | # report how many parameters were updated 44 | print('Total model_dict: {}, Total pretrained_dict: {}, update: {}'.format(len(model_dict), len(pretrained_dict), len(new_dict))) 45 | self.vmunet.load_state_dict(model_dict) 46 | 47 | not_loaded_keys = [k for k in pretrained_dict.keys() if k not in new_dict.keys()] 48 | print('Not loaded keys:', not_loaded_keys) 49 | print("encoder loading finished!") 50 | 51 | model_dict = self.vmunet.state_dict() 52 | modelCheckpoint = torch.load(self.load_ckpt_path) 53 | pretrained_odict = modelCheckpoint['model'] 54 | pretrained_dict = {} 55 | for k, v in pretrained_odict.items(): # reuse encoder weights for the decoder: layers.i -> layers_up.(3 - i) 56 | if 'layers.0' in k: 57 | new_k = k.replace('layers.0', 'layers_up.3') 58 | 
pretrained_dict[new_k] = v 59 | elif 'layers.1' in k: 60 | new_k = k.replace('layers.1', 'layers_up.2') 61 | pretrained_dict[new_k] = v 62 | elif 'layers.2' in k: 63 | new_k = k.replace('layers.2', 'layers_up.1') 64 | pretrained_dict[new_k] = v 65 | elif 'layers.3' in k: 66 | new_k = k.replace('layers.3', 'layers_up.0') 67 | pretrained_dict[new_k] = v 68 | # filter: keep only keys that also exist in the current model 69 | new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()} 70 | model_dict.update(new_dict) 71 | # report how many parameters were updated 72 | print('Total model_dict: {}, Total pretrained_dict: {}, update: {}'.format(len(model_dict), len(pretrained_dict), len(new_dict))) 73 | self.vmunet.load_state_dict(model_dict) 74 | 75 | # find the keys that were not loaded 76 | not_loaded_keys = [k for k in pretrained_dict.keys() if k not in new_dict.keys()] 77 | print('Not loaded keys:', not_loaded_keys) 78 | print("decoder loading finished!") 79 | 80 | 81 | 82 | if __name__ == '__main__': 83 | pretrained_path = '/raid/code/mamba_all/VM-UNet/pre_trained_weights/vmamba_small_e238_ema.pth' 84 | model = VMUNet(load_ckpt_path=pretrained_path).cuda() 85 | model.load_from() 86 | x = torch.randn(1, 3, 256, 256).cuda() 87 | predict = model(x) 88 | print(predict.shape) 89 | 90 | 91 | 92 | -------------------------------------------------------------------------------- /models/vmunet/vmunet_v2.py: -------------------------------------------------------------------------------- 1 | from .vmamba import VSSM 2 | # from vmamba import VSSM # debug use 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | 8 | 9 | 10 | class ChannelAttention(nn.Module): 11 | def __init__(self, in_planes, ratio=16): 12 | super(ChannelAttention, self).__init__() 13 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 14 | self.max_pool = nn.AdaptiveMaxPool2d(1) 15 | 16 | self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False) # use the ratio argument (was hard-coded to 16) 17 | self.relu1 = nn.ReLU() 18 | self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False) 19 | 20 | self.sigmoid = nn.Sigmoid() 21 | 22 | def forward(self, x): 23 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 24 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 25 | out = avg_out + max_out 26 | return self.sigmoid(out) 27 | 28 | 29 | class SpatialAttention(nn.Module): 30 | def __init__(self, kernel_size=7): 31 | super(SpatialAttention, self).__init__() 32 | 33 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 34 | padding = 3 if kernel_size == 7 else 1 35 | 36 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) 37 | self.sigmoid = nn.Sigmoid() 38 | 39 | def forward(self, x): 40 | avg_out = torch.mean(x, dim=1, keepdim=True) 41 | max_out, _ = torch.max(x, dim=1, keepdim=True) 42 | x = torch.cat([avg_out, max_out], dim=1) 43 | x = self.conv1(x) 44 | return self.sigmoid(x) 45 | 46 | 47 | class BasicConv2d(nn.Module): 48 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 49 | super(BasicConv2d, self).__init__() 50 | 51 | self.conv = nn.Conv2d(in_planes, out_planes, 52 | kernel_size=kernel_size, stride=stride, 53 | padding=padding, dilation=dilation, bias=False) 54 | self.bn = nn.BatchNorm2d(out_planes) 55 | self.relu = nn.ReLU(inplace=True) 56 | 57 | def forward(self, x): 58 | x = self.conv(x) 59 | x = self.bn(x) 60 | return x 61 | 62 | 63 | class SDI(nn.Module): 64 | def __init__(self, channel): 65 | super().__init__() 66 | 67 | self.convs = nn.ModuleList( 68 | [nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1) for _ in range(4)]) 
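# SDI (Semantics and Detail Infusion, from U-Net v2): each decoder level
# receives all four encoder features, resamples them to the anchor's
# spatial size (adaptive avg-pool to shrink, bilinear interpolation to
# enlarge), passes each through its own 3x3 conv, and fuses them by
# element-wise multiplication -- hence the torch.ones_like accumulator
# in forward() below.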
69 | 70 | def forward(self, xs, anchor): 71 | ans = torch.ones_like(anchor) 72 | target_size = anchor.shape[-1] 73 | 74 | for i, x in enumerate(xs): # xs = [f1, f2, f3, f4] 75 | if x.shape[-1] > target_size: 76 | x = F.adaptive_avg_pool2d(x, (target_size, target_size)) 77 | elif x.shape[-1] < target_size: 78 | x = F.interpolate(x, size=(target_size, target_size), 79 | mode='bilinear', align_corners=True) 80 | 81 | ans = ans * self.convs[i](x) 82 | 83 | return ans 84 | 85 | 86 | 87 | 88 | 89 | class VMUNetV2(nn.Module): 90 | def __init__(self, 91 | input_channels=3, 92 | num_classes=1, 93 | mid_channel = 48, 94 | depths=[2, 2, 9, 2], 95 | depths_decoder=[2, 9, 2, 2], 96 | drop_path_rate=0.2, 97 | load_ckpt_path=None, 98 | deep_supervision=True 99 | ): 100 | super().__init__() 101 | 102 | self.load_ckpt_path = load_ckpt_path 103 | self.num_classes = num_classes 104 | self.deep_supervision = deep_supervision 105 | 106 | # SDI 107 | self.ca_1 = ChannelAttention(2*mid_channel) 108 | self.sa_1 = SpatialAttention() 109 | 110 | self.ca_2 = ChannelAttention(4*mid_channel) 111 | self.sa_2 = SpatialAttention() 112 | # TODO 320 or mid_channel * 8? 113 | self.ca_3 = ChannelAttention(8*mid_channel) 114 | self.sa_3 = SpatialAttention() 115 | 116 | self.ca_4 = ChannelAttention(16*mid_channel) 117 | self.sa_4 = SpatialAttention() 118 | 119 | self.Translayer_1 = BasicConv2d(2*mid_channel, mid_channel, 1) 120 | self.Translayer_2 = BasicConv2d(4*mid_channel, mid_channel, 1) 121 | self.Translayer_3 = BasicConv2d(8*mid_channel, mid_channel, 1) 122 | self.Translayer_4 = BasicConv2d(16*mid_channel, mid_channel, 1) 123 | 124 | self.sdi_1 = SDI(mid_channel) 125 | self.sdi_2 = SDI(mid_channel) 126 | self.sdi_3 = SDI(mid_channel) 127 | self.sdi_4 = SDI(mid_channel) 128 | 129 | self.seg_outs = nn.ModuleList([ 130 | nn.Conv2d(mid_channel, num_classes, 1, 1) for _ in range(4)]) 131 | 132 | 133 | 134 | self.deconv2 = nn.ConvTranspose2d(mid_channel, mid_channel, kernel_size=4, stride=2, padding=1, bias=False) 135 | self.deconv3 = nn.ConvTranspose2d(mid_channel, mid_channel, kernel_size=4, stride=2, padding=1, bias=False) 136 | self.deconv4 = nn.ConvTranspose2d(mid_channel, mid_channel, kernel_size=4, stride=2, padding=1, bias=False) 137 | self.deconv5 = nn.ConvTranspose2d(mid_channel, mid_channel, kernel_size=4, stride=2, padding=1, bias=False) 138 | self.deconv6 = nn.ConvTranspose2d(1, 1, 3, stride=2, padding=1, output_padding=1) 139 | 140 | 141 | 142 | self.vmunet = VSSM(in_chans=input_channels, 143 | num_classes=num_classes, 144 | depths=depths, 145 | depths_decoder=depths_decoder, 146 | drop_path_rate=drop_path_rate, 147 | ) 148 | 149 | def forward(self, x): 150 | seg_outs = [] 151 | if x.size()[1] == 1: # grayscale input: repeat the single channel to get 3 channels 152 | x = x.repeat(1,3,1,1) 153 | f1, f2, f3, f4 = self.vmunet(x) # f1 [2, 64, 64, 96] f3 [2, 8, 8, 768] [b h w c] 154 | # b h w c --> b c h w 155 | f1 = f1.permute(0, 3, 1, 2) # f1 [2, 96, 64, 64] 156 | f2 = f2.permute(0, 3, 1, 2) 157 | f3 = f3.permute(0, 3, 1, 2) 158 | f4 = f4.permute(0, 3, 1, 2) 159 | 160 | # use sdi 161 | f1 = self.ca_1(f1) * f1 162 | f1 = self.sa_1(f1) * f1 163 | f1 = self.Translayer_1(f1) # f1 [2, 48, 64, 64] 164 | 165 | f2 = self.ca_2(f2) * f2 166 | f2 = self.sa_2(f2) * f2 167 | f2 = self.Translayer_2(f2) # f2 [2, 48, 32, 32] 168 | 169 | f3 = self.ca_3(f3) * f3 170 | f3 = self.sa_3(f3) * f3 171 | f3 = self.Translayer_3(f3) # f3 [2, 48, 16, 16] 172 | 173 | f4 = self.ca_4(f4) * f4 174 | f4 = self.sa_4(f4) * f4 175 | f4 = self.Translayer_4(f4) # f4 [2, 48, 8, 8] 176 | 177 | 
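# Decoder fusion below (spatial sizes in the inline comments assume a
# 256x256 input): each sdi_k re-anchors all four refined features at
# level k, the deconvs then upsample 8 -> 16 -> 32 -> 64 while a 1x1 head
# fires at every scale, producing the multi-scale predictions used for
# deep supervision.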
178 | f41 = self.sdi_4([f1, f2, f3, f4], f4) # [2, 48, 8, 8] 179 | f31 = self.sdi_3([f1, f2, f3, f4], f3) # [2, 48, 16, 16] 180 | f21 = self.sdi_2([f1, f2, f3, f4], f2) # [2, 48, 32, 32] 181 | f11 = self.sdi_1([f1, f2, f3, f4], f1) # [2, 48, 64, 64] 182 | 183 | # note: the local list seg_outs shares its name with the nn.ModuleList self.seg_outs (the 1x1 prediction heads) 184 | seg_outs.append(self.seg_outs[0](f41)) # seg_outs[0] [2, 1, 8, 8] 185 | 186 | y = self.deconv2(f41) + f31 187 | seg_outs.append(self.seg_outs[1](y)) # seg_outs[1] [2, 1, 16, 16] 188 | 189 | y = self.deconv3(y) + f21 190 | seg_outs.append(self.seg_outs[2](y)) # seg_outs[2] [2, 1, 32, 32] 191 | 192 | y = self.deconv4(y) + f11 193 | seg_outs.append(self.seg_outs[3](y)) # seg_outs[3] [2, 1, 64, 64] 194 | 195 | for i, o in enumerate(seg_outs): # 4x upsampling 196 | seg_outs[i] = F.interpolate(o, scale_factor=4, mode='bilinear') 197 | 198 | if self.deep_supervision: 199 | 200 | temp = seg_outs[::-1] # 0 [2, 1, 256, 256] 1 [2, 1, 128, 128] 201 | out_0 = temp[0] 202 | out_1 = temp[1] 203 | out_1 = self.deconv6(out_1) 204 | return torch.sigmoid(out_0 + out_1) # [2, 1, 256, 256] 205 | else: 206 | if self.num_classes == 1: return torch.sigmoid(seg_outs[-1]) 207 | else: return seg_outs[-1] 208 | 209 | 210 | def load_from(self): 211 | if self.load_ckpt_path is not None: 212 | model_dict = self.vmunet.state_dict() 213 | modelCheckpoint = torch.load(self.load_ckpt_path) 214 | pretrained_dict = modelCheckpoint['model'] 215 | # filter: keep only keys that also exist in the current model 216 | new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()} 217 | model_dict.update(new_dict) 218 | # report how many parameters were updated 219 | print('Total model_dict: {}, Total pretrained_dict: {}, update: {}'.format(len(model_dict), len(pretrained_dict), len(new_dict))) 220 | self.vmunet.load_state_dict(model_dict) 221 | 222 | not_loaded_keys = [k for k in pretrained_dict.keys() if k not in new_dict.keys()] 223 | print('Not loaded keys:', not_loaded_keys) 224 | print("encoder loading finished!") 225 | 226 | # model_dict = self.vmunet.state_dict() 227 | # modelCheckpoint = torch.load(self.load_ckpt_path) 228 | # # below: the decoder (layers_up) weights 229 | # pretrained_odict = modelCheckpoint['model'] 230 | # pretrained_dict = {} 231 | # for k, v in pretrained_odict.items(): 232 | # if 'layers.0' in k: 233 | # new_k = k.replace('layers.0', 'layers_up.3') 234 | # pretrained_dict[new_k] = v 235 | # elif 'layers.1' in k: 236 | # new_k = k.replace('layers.1', 'layers_up.2') 237 | # pretrained_dict[new_k] = v 238 | # elif 'layers.2' in k: 239 | # new_k = k.replace('layers.2', 'layers_up.1') 240 | # pretrained_dict[new_k] = v 241 | # elif 'layers.3' in k: 242 | # new_k = k.replace('layers.3', 'layers_up.0') 243 | # pretrained_dict[new_k] = v 244 | # # filter: keep only keys that also exist in the current model 245 | # new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()} 246 | # model_dict.update(new_dict) 247 | # # report how many parameters were updated 248 | # print('Total model_dict: {}, Total pretrained_dict: {}, update: {}'.format(len(model_dict), len(pretrained_dict), len(new_dict))) 249 | # self.vmunet.load_state_dict(model_dict) 250 | 251 | # # find the keys that were not loaded 252 | # not_loaded_keys = [k for k in pretrained_dict.keys() if k not in new_dict.keys()] 253 | # print('Not loaded keys:', not_loaded_keys) 254 | # print("decoder loading finished!") 255 | 256 | 257 | 258 | if __name__ == '__main__': 259 | pretrained_path = '/raid/code/mamba_all/VM-UNet/pre_trained_weights/vmamba_small_e238_ema.pth' 260 | model = VMUNetV2(load_ckpt_path=pretrained_path, deep_supervision=True).cuda() 261 | model.load_from() 262 | x = torch.randn(2, 3, 256, 256).cuda() 263 | predict = 
model(x) 264 | # print(predict.shape) # deep_supervision true: predict[0] [2, 1, 256, 256], predict[1] [2, 1, 128, 128] -- these two outputs are used for supervision 265 | 266 | 267 | 268 | -------------------------------------------------------------------------------- /models/vmunet_old/vmamba.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | from functools import partial 4 | from typing import Optional, Callable 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.checkpoint as checkpoint 10 | from einops import rearrange, repeat 11 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 12 | try: 13 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref 14 | except: 15 | pass 16 | 17 | # an alternative to mamba_ssm (which requires causal_conv1d) 18 | try: 19 | from selective_scan import selective_scan_fn as selective_scan_fn_v1 20 | from selective_scan import selective_scan_ref as selective_scan_ref_v1 21 | except: 22 | pass 23 | 24 | DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})" 25 | 26 | 27 | def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False): 28 | """ 29 | u: r(B D L) 30 | delta: r(B D L) 31 | A: r(D N) 32 | B: r(B N L) 33 | C: r(B N L) 34 | D: r(D) 35 | z: r(B D L) 36 | delta_bias: r(D), fp32 37 | 38 | ignores: 39 | [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] 40 | """ 41 | import numpy as np 42 | 43 | # fvcore.nn.jit_handles 44 | def get_flops_einsum(input_shapes, equation): 45 | np_arrs = [np.zeros(s) for s in input_shapes] 46 | optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1] 47 | for line in optim.split("\n"): 48 | if "optimized flop" in line.lower(): 49 | # divided by 2 because we count MAC (multiply-add counted as one flop) 50 | flop = float(np.floor(float(line.split(":")[-1]) / 2)) 51 | return flop 52 | 53 | 54 | assert not with_complex 55 | 56 | flops = 0 # running FLOP total for the reference code quoted below 57 | if False: 58 | ... 59 | """ 60 | dtype_in = u.dtype 61 | u = u.float() 62 | delta = delta.float() 63 | if delta_bias is not None: 64 | delta = delta + delta_bias[..., None].float() 65 | if delta_softplus: 66 | delta = F.softplus(delta) 67 | batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] 68 | is_variable_B = B.dim() >= 3 69 | is_variable_C = C.dim() >= 3 70 | if A.is_complex(): 71 | if is_variable_B: 72 | B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) 73 | if is_variable_C: 74 | C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2)) 75 | else: 76 | B = B.float() 77 | C = C.float() 78 | x = A.new_zeros((batch, dim, dstate)) 79 | ys = [] 80 | """ 81 | 82 | flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln") 83 | if with_Group: 84 | flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln") 85 | else: 86 | flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln") 87 | if False: 88 | ... 
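# Note: the triple-quoted blocks inside this function are intentionally
# dead code -- they quote the reference selective-scan implementation so
# that each get_flops_einsum() call can be matched to the einsum it
# accounts for; only the accounting itself executes.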
89 | """ 90 | deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) 91 | if not is_variable_B: 92 | deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) 93 | else: 94 | if B.dim() == 3: 95 | deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) 96 | else: 97 | B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) 98 | deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) 99 | if is_variable_C and C.dim() == 4: 100 | C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) 101 | last_state = None 102 | """ 103 | 104 | in_for_flops = B * D * N 105 | if with_Group: 106 | in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd") 107 | else: 108 | in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd") 109 | flops += L * in_for_flops 110 | if False: 111 | ... 112 | """ 113 | for i in range(u.shape[2]): 114 | x = deltaA[:, :, i] * x + deltaB_u[:, :, i] 115 | if not is_variable_C: 116 | y = torch.einsum('bdn,dn->bd', x, C) 117 | else: 118 | if C.dim() == 3: 119 | y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) 120 | else: 121 | y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) 122 | if i == u.shape[2] - 1: 123 | last_state = x 124 | if y.is_complex(): 125 | y = y.real * 2 126 | ys.append(y) 127 | y = torch.stack(ys, dim=2) # (batch dim L) 128 | """ 129 | 130 | if with_D: 131 | flops += B * D * L 132 | if with_Z: 133 | flops += B * D * L 134 | if False: 135 | ... 136 | """ 137 | out = y if D is None else y + u * rearrange(D, "d -> d 1") 138 | if z is not None: 139 | out = out * F.silu(z) 140 | out = out.to(dtype=dtype_in) 141 | """ 142 | 143 | return flops 144 | 145 | 146 | class PatchEmbed2D(nn.Module): 147 | r""" Image to Patch Embedding 148 | Args: 149 | patch_size (int): Patch token size. Default: 4. 150 | in_chans (int): Number of input image channels. Default: 3. 151 | embed_dim (int): Number of linear projection output channels. Default: 96. 152 | norm_layer (nn.Module, optional): Normalization layer. Default: None 153 | """ 154 | def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs): 155 | super().__init__() 156 | if isinstance(patch_size, int): 157 | patch_size = (patch_size, patch_size) 158 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 159 | if norm_layer is not None: 160 | self.norm = norm_layer(embed_dim) 161 | else: 162 | self.norm = None 163 | 164 | def forward(self, x): 165 | x = self.proj(x).permute(0, 2, 3, 1) 166 | if self.norm is not None: 167 | x = self.norm(x) 168 | return x 169 | 170 | 171 | class PatchMerging2D(nn.Module): 172 | r""" Patch Merging Layer. 173 | Args: 174 | input_resolution (tuple[int]): Resolution of input feature. 175 | dim (int): Number of input channels. 176 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 177 | """ 178 | 179 | def __init__(self, dim, norm_layer=nn.LayerNorm): 180 | super().__init__() 181 | self.dim = dim 182 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 183 | self.norm = norm_layer(4 * dim) 184 | 185 | def forward(self, x): 186 | B, H, W, C = x.shape 187 | 188 | SHAPE_FIX = [-1, -1] 189 | if (W % 2 != 0) or (H % 2 != 0): 190 | print(f"Warning, x.shape {x.shape} is not match even ===========", flush=True) 191 | SHAPE_FIX[0] = H // 2 192 | SHAPE_FIX[1] = W // 2 193 | 194 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 195 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 196 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 197 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 198 | 199 | if SHAPE_FIX[0] > 0: 200 | x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] 201 | x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] 202 | x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] 203 | x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] 204 | 205 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 206 | x = x.view(B, H//2, W//2, 4 * C) # B H/2*W/2 4*C 207 | 208 | x = self.norm(x) 209 | x = self.reduction(x) 210 | 211 | return x 212 | 213 | 214 | class PatchExpand2D(nn.Module): 215 | def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm): 216 | super().__init__() 217 | self.dim = dim*2 218 | self.dim_scale = dim_scale 219 | self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False) 220 | self.norm = norm_layer(self.dim // dim_scale) 221 | 222 | def forward(self, x): 223 | B, H, W, C = x.shape 224 | x = self.expand(x) 225 | 226 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale) 227 | x= self.norm(x) 228 | 229 | return x 230 | 231 | 232 | class Final_PatchExpand2D(nn.Module): 233 | def __init__(self, dim, dim_scale=4, norm_layer=nn.LayerNorm): 234 | super().__init__() 235 | self.dim = dim 236 | self.dim_scale = dim_scale 237 | self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False) 238 | self.norm = norm_layer(self.dim // dim_scale) 239 | 240 | def forward(self, x): 241 | B, H, W, C = x.shape 242 | x = self.expand(x) 243 | 244 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale) 245 | x= self.norm(x) 246 | 247 | return x 248 | 249 | 250 | class SS2D(nn.Module): 251 | def __init__( 252 | self, 253 | d_model, 254 | d_state=16, 255 | # d_state="auto", # 20240109 256 | d_conv=3, 257 | expand=2, 258 | dt_rank="auto", 259 | dt_min=0.001, 260 | dt_max=0.1, 261 | dt_init="random", 262 | dt_scale=1.0, 263 | dt_init_floor=1e-4, 264 | dropout=0., 265 | conv_bias=True, 266 | bias=False, 267 | device=None, 268 | dtype=None, 269 | **kwargs, 270 | ): 271 | factory_kwargs = {"device": device, "dtype": dtype} 272 | super().__init__() 273 | self.d_model = d_model 274 | self.d_state = d_state 275 | # self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_model # 20240109 276 | self.d_conv = d_conv 277 | self.expand = expand 278 | self.d_inner = int(self.expand * self.d_model) 279 | self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank 280 | 281 | self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) 282 | self.conv2d = nn.Conv2d( 283 | in_channels=self.d_inner, 284 | out_channels=self.d_inner, 285 | groups=self.d_inner, 286 | bias=conv_bias, 287 | kernel_size=d_conv, 288 | padding=(d_conv - 1) // 2, 289 | **factory_kwargs, 290 | ) 291 | self.act = nn.SiLU() 292 | 293 | self.x_proj = ( 294 | 
nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 295 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 296 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 297 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 298 | ) 299 | self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=4, N, inner) 300 | del self.x_proj 301 | 302 | self.dt_projs = ( 303 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), 304 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), 305 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), 306 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), 307 | ) 308 | self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=4, inner, rank) 309 | self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=4, inner) 310 | del self.dt_projs 311 | 312 | self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True) # (K=4, D, N) 313 | self.Ds = self.D_init(self.d_inner, copies=4, merge=True) # (K=4, D, N) 314 | 315 | # self.selective_scan = selective_scan_fn 316 | self.forward_core = self.forward_corev0 317 | 318 | self.out_norm = nn.LayerNorm(self.d_inner) 319 | self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) 320 | self.dropout = nn.Dropout(dropout) if dropout > 0. else None 321 | 322 | @staticmethod 323 | def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs): 324 | dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs) 325 | 326 | # Initialize special dt projection to preserve variance at initialization 327 | dt_init_std = dt_rank**-0.5 * dt_scale 328 | if dt_init == "constant": 329 | nn.init.constant_(dt_proj.weight, dt_init_std) 330 | elif dt_init == "random": 331 | nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) 332 | else: 333 | raise NotImplementedError 334 | 335 | # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max 336 | dt = torch.exp( 337 | torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) 338 | + math.log(dt_min) 339 | ).clamp(min=dt_init_floor) 340 | # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 341 | inv_dt = dt + torch.log(-torch.expm1(-dt)) 342 | with torch.no_grad(): 343 | dt_proj.bias.copy_(inv_dt) 344 | # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit 345 | dt_proj.bias._no_reinit = True 346 | 347 | return dt_proj 348 | 349 | @staticmethod 350 | def A_log_init(d_state, d_inner, copies=1, device=None, merge=True): 351 | # S4D real initialization 352 | A = repeat( 353 | torch.arange(1, d_state + 1, dtype=torch.float32, device=device), 354 | "n -> d n", 355 | d=d_inner, 356 | ).contiguous() 357 | A_log = torch.log(A) # Keep A_log in fp32 358 | if copies > 1: 359 | A_log = repeat(A_log, "d n -> r d n", r=copies) 360 | if merge: 361 | A_log = A_log.flatten(0, 1) 362 | A_log = nn.Parameter(A_log) 363 | A_log._no_weight_decay = True 364 | return A_log 365 | 366 | @staticmethod 367 | def D_init(d_inner, 
copies=1, device=None, merge=True): 368 | # D "skip" parameter 369 | D = torch.ones(d_inner, device=device) 370 | if copies > 1: 371 | D = repeat(D, "n1 -> r n1", r=copies) 372 | if merge: 373 | D = D.flatten(0, 1) 374 | D = nn.Parameter(D) # Keep in fp32 375 | D._no_weight_decay = True 376 | return D 377 | 378 | def forward_corev0(self, x: torch.Tensor): 379 | self.selective_scan = selective_scan_fn 380 | 381 | B, C, H, W = x.shape 382 | L = H * W 383 | K = 4 384 | 385 | x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) 386 | xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) 387 | 388 | x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight) 389 | # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) 390 | dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) 391 | dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight) 392 | # dts = dts + self.dt_projs_bias.view(1, K, -1, 1) 393 | 394 | xs = xs.float().view(B, -1, L) # (b, k * d, l) 395 | dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) 396 | Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l) 397 | Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l) 398 | Ds = self.Ds.float().view(-1) # (k * d) 399 | As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state) 400 | dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) 401 | 402 | out_y = self.selective_scan( 403 | xs, dts, 404 | As, Bs, Cs, Ds, z=None, 405 | delta_bias=dt_projs_bias, 406 | delta_softplus=True, 407 | return_last_state=False, 408 | ).view(B, K, -1, L) 409 | assert out_y.dtype == torch.float 410 | 411 | inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) 412 | wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) 413 | invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) 414 | 415 | return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y 416 | 417 | # an alternative to forward_corev1 418 | def forward_corev1(self, x: torch.Tensor): 419 | self.selective_scan = selective_scan_fn_v1 420 | 421 | B, C, H, W = x.shape 422 | L = H * W 423 | K = 4 424 | 425 | x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) 426 | xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) 427 | 428 | x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight) 429 | # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) 430 | dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) 431 | dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight) 432 | # dts = dts + self.dt_projs_bias.view(1, K, -1, 1) 433 | 434 | xs = xs.float().view(B, -1, L) # (b, k * d, l) 435 | dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) 436 | Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l) 437 | Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l) 438 | Ds = self.Ds.float().view(-1) # (k * d) 439 | As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state) 440 | dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) 441 | 442 | out_y = self.selective_scan( 443 | xs, dts, 444 | As, Bs, Cs, Ds, 445 | delta_bias=dt_projs_bias, 446 | 
delta_softplus=True, 447 | ).view(B, K, -1, L) 448 | assert out_y.dtype == torch.float 449 | 450 | inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) 451 | wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) 452 | invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) 453 | 454 | return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y 455 | 456 | def forward(self, x: torch.Tensor, **kwargs): 457 | B, H, W, C = x.shape 458 | 459 | xz = self.in_proj(x) 460 | x, z = xz.chunk(2, dim=-1) # (b, h, w, d) 461 | 462 | x = x.permute(0, 3, 1, 2).contiguous() 463 | x = self.act(self.conv2d(x)) # (b, d, h, w) 464 | y1, y2, y3, y4 = self.forward_core(x) 465 | assert y1.dtype == torch.float32 466 | y = y1 + y2 + y3 + y4 467 | y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) 468 | y = self.out_norm(y) 469 | y = y * F.silu(z) 470 | out = self.out_proj(y) 471 | if self.dropout is not None: 472 | out = self.dropout(out) 473 | return out 474 | 475 | 476 | class VSSBlock(nn.Module): 477 | def __init__( 478 | self, 479 | hidden_dim: int = 0, 480 | drop_path: float = 0, 481 | norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), 482 | attn_drop_rate: float = 0, 483 | d_state: int = 16, 484 | **kwargs, 485 | ): 486 | super().__init__() 487 | self.ln_1 = norm_layer(hidden_dim) 488 | self.self_attention = SS2D(d_model=hidden_dim, dropout=attn_drop_rate, d_state=d_state, **kwargs) 489 | self.drop_path = DropPath(drop_path) 490 | 491 | def forward(self, input: torch.Tensor): 492 | x = input + self.drop_path(self.self_attention(self.ln_1(input))) 493 | return x 494 | 495 | 496 | class VSSLayer(nn.Module): 497 | """ A basic VSS layer for one stage (structured like a Swin Transformer stage). 498 | Args: 499 | dim (int): Number of input channels. 500 | depth (int): Number of blocks. 501 | drop (float, optional): Dropout rate. Default: 0.0 502 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 503 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 504 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 505 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 506 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 507 | """ 508 | 509 | def __init__( 510 | self, 511 | dim, 512 | depth, 513 | attn_drop=0., 514 | drop_path=0., 515 | norm_layer=nn.LayerNorm, 516 | downsample=None, 517 | use_checkpoint=False, 518 | d_state=16, 519 | **kwargs, 520 | ): 521 | super().__init__() 522 | self.dim = dim 523 | self.use_checkpoint = use_checkpoint 524 | 525 | self.blocks = nn.ModuleList([ 526 | VSSBlock( 527 | hidden_dim=dim, 528 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 529 | norm_layer=norm_layer, 530 | attn_drop_rate=attn_drop, 531 | d_state=d_state, 532 | ) 533 | for i in range(depth)]) 534 | 535 | if True: # is this really applied? Yes, but it is overridden later in VSSM! 536 | def _init_weights(module: nn.Module): 537 | for name, p in module.named_parameters(): 538 | if name in ["out_proj.weight"]: 539 | p = p.clone().detach_() # fake init, just to keep the seed .... 
540 |                         nn.init.kaiming_uniform_(p, a=math.sqrt(5))
541 |             self.apply(_init_weights)
542 | 
543 |         if downsample is not None:
544 |             self.downsample = downsample(dim=dim, norm_layer=norm_layer)
545 |         else:
546 |             self.downsample = None
547 | 
548 | 
549 |     def forward(self, x):
550 |         for blk in self.blocks:
551 |             if self.use_checkpoint:
552 |                 x = checkpoint.checkpoint(blk, x)
553 |             else:
554 |                 x = blk(x)
555 | 
556 |         if self.downsample is not None:
557 |             x = self.downsample(x)
558 | 
559 |         return x
560 | 
561 | 
562 | 
563 | class VSSLayer_up(nn.Module):
564 |     """ A basic VSS layer for one decoder stage (adapted from the Swin Transformer layer).
565 |     Args:
566 |         dim (int): Number of input channels.
567 |         depth (int): Number of blocks.
568 |         drop (float, optional): Dropout rate. Default: 0.0
569 |         attn_drop (float, optional): Attention dropout rate. Default: 0.0
570 |         drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
571 |         norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
572 |         upsample (nn.Module | None, optional): Upsample layer at the start of the layer. Default: None
573 |         use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
574 |     """
575 | 
576 |     def __init__(
577 |         self,
578 |         dim,
579 |         depth,
580 |         attn_drop=0.,
581 |         drop_path=0.,
582 |         norm_layer=nn.LayerNorm,
583 |         upsample=None,
584 |         use_checkpoint=False,
585 |         d_state=16,
586 |         **kwargs,
587 |     ):
588 |         super().__init__()
589 |         self.dim = dim
590 |         self.use_checkpoint = use_checkpoint
591 | 
592 |         self.blocks = nn.ModuleList([
593 |             VSSBlock(
594 |                 hidden_dim=dim,
595 |                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
596 |                 norm_layer=norm_layer,
597 |                 attn_drop_rate=attn_drop,
598 |                 d_state=d_state,
599 |             )
600 |             for i in range(depth)])
601 | 
602 |         if True: # is this really applied? Yes, but it is overridden later in VSSM!
603 |             def _init_weights(module: nn.Module):
604 |                 for name, p in module.named_parameters():
605 |                     if name in ["out_proj.weight"]:
606 |                         p = p.clone().detach_() # fake init, just to keep the seed ....
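                        # NOTE: same RNG-preserving fake init as in VSSLayer above. Unlike
                        # VSSLayer, which downsamples after its blocks, this decoder layer
                        # applies `upsample` before its blocks in forward().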
607 | nn.init.kaiming_uniform_(p, a=math.sqrt(5)) 608 | self.apply(_init_weights) 609 | 610 | if upsample is not None: 611 | self.upsample = upsample(dim=dim, norm_layer=norm_layer) 612 | else: 613 | self.upsample = None 614 | 615 | 616 | def forward(self, x): 617 | if self.upsample is not None: 618 | x = self.upsample(x) 619 | for blk in self.blocks: 620 | if self.use_checkpoint: 621 | x = checkpoint.checkpoint(blk, x) 622 | else: 623 | x = blk(x) 624 | return x 625 | 626 | 627 | 628 | class VSSM(nn.Module): 629 | def __init__(self, patch_size=4, in_chans=3, num_classes=1000, depths=[2, 2, 9, 2], depths_decoder=[2, 9, 2, 2], 630 | dims=[96, 192, 384, 768], dims_decoder=[768, 384, 192, 96], d_state=16, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 631 | norm_layer=nn.LayerNorm, patch_norm=True, 632 | use_checkpoint=False, **kwargs): 633 | super().__init__() 634 | self.num_classes = num_classes 635 | self.num_layers = len(depths) 636 | if isinstance(dims, int): 637 | dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)] 638 | self.embed_dim = dims[0] 639 | self.num_features = dims[-1] 640 | self.dims = dims 641 | 642 | self.patch_embed = PatchEmbed2D(patch_size=patch_size, in_chans=in_chans, embed_dim=self.embed_dim, 643 | norm_layer=norm_layer if patch_norm else None) 644 | 645 | # WASTED absolute position embedding ====================== 646 | self.ape = False 647 | # self.ape = False 648 | # drop_rate = 0.0 649 | if self.ape: 650 | self.patches_resolution = self.patch_embed.patches_resolution 651 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, *self.patches_resolution, self.embed_dim)) 652 | trunc_normal_(self.absolute_pos_embed, std=.02) 653 | self.pos_drop = nn.Dropout(p=drop_rate) 654 | 655 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 656 | dpr_decoder = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_decoder))][::-1] 657 | 658 | self.layers = nn.ModuleList() 659 | for i_layer in range(self.num_layers): 660 | layer = VSSLayer( 661 | dim=dims[i_layer], 662 | depth=depths[i_layer], 663 | d_state=math.ceil(dims[0] / 6) if d_state is None else d_state, # 20240109 664 | drop=drop_rate, 665 | attn_drop=attn_drop_rate, 666 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 667 | norm_layer=norm_layer, 668 | downsample=PatchMerging2D if (i_layer < self.num_layers - 1) else None, 669 | use_checkpoint=use_checkpoint, 670 | ) 671 | self.layers.append(layer) 672 | 673 | self.layers_up = nn.ModuleList() 674 | for i_layer in range(self.num_layers): 675 | layer = VSSLayer_up( 676 | dim=dims_decoder[i_layer], 677 | depth=depths_decoder[i_layer], 678 | d_state=math.ceil(dims[0] / 6) if d_state is None else d_state, # 20240109 679 | drop=drop_rate, 680 | attn_drop=attn_drop_rate, 681 | drop_path=dpr_decoder[sum(depths_decoder[:i_layer]):sum(depths_decoder[:i_layer + 1])], 682 | norm_layer=norm_layer, 683 | upsample=PatchExpand2D if (i_layer != 0) else None, 684 | use_checkpoint=use_checkpoint, 685 | ) 686 | self.layers_up.append(layer) 687 | 688 | self.final_up = Final_PatchExpand2D(dim=dims_decoder[-1], dim_scale=4, norm_layer=norm_layer) 689 | self.final_conv = nn.Conv2d(dims_decoder[-1]//4, num_classes, 1) 690 | 691 | # self.norm = norm_layer(self.num_features) 692 | # self.avgpool = nn.AdaptiveAvgPool1d(1) 693 | # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 694 | 695 | self.apply(self._init_weights) 696 | 697 | def 
_init_weights(self, m: nn.Module):
698 |         """
699 |         out_proj.weight, which was previously initialized in VSSBlock, is overwritten by the nn.Linear branch below;
700 |         no fc.weight is found in any of the model parameters;
701 |         no nn.Embedding is found in any of the model parameters;
702 |         so the VSSBlock initialization above is effectively a no-op.
703 | 
704 |         Conv2d layers are NOT initialized here!
705 |         """
706 |         if isinstance(m, nn.Linear):
707 |             trunc_normal_(m.weight, std=.02)
708 |             if isinstance(m, nn.Linear) and m.bias is not None:
709 |                 nn.init.constant_(m.bias, 0)
710 |         elif isinstance(m, nn.LayerNorm):
711 |             nn.init.constant_(m.bias, 0)
712 |             nn.init.constant_(m.weight, 1.0)
713 | 
714 |     @torch.jit.ignore
715 |     def no_weight_decay(self):
716 |         return {'absolute_pos_embed'}
717 | 
718 |     @torch.jit.ignore
719 |     def no_weight_decay_keywords(self):
720 |         return {'relative_position_bias_table'}
721 | 
722 |     def forward_features(self, x):
723 |         skip_list = []
724 |         x = self.patch_embed(x)
725 |         if self.ape:
726 |             x = x + self.absolute_pos_embed
727 |         x = self.pos_drop(x)
728 | 
729 |         for layer in self.layers:
730 |             skip_list.append(x)
731 |             x = layer(x)
732 |         return x, skip_list
733 | 
734 |     def forward_features_up(self, x, skip_list):
735 |         for inx, layer_up in enumerate(self.layers_up):
736 |             if inx == 0:
737 |                 x = layer_up(x)
738 |             else:
739 |                 x = layer_up(x+skip_list[-inx])
740 | 
741 |         return x
742 | 
743 |     def forward_final(self, x):
744 |         x = self.final_up(x)
745 |         x = x.permute(0,3,1,2)
746 |         x = self.final_conv(x)
747 |         return x
748 | 
749 |     def forward_backbone(self, x):
750 |         x = self.patch_embed(x)
751 |         if self.ape:
752 |             x = x + self.absolute_pos_embed
753 |         x = self.pos_drop(x)
754 | 
755 |         for layer in self.layers:
756 |             x = layer(x)
757 |         return x
758 | 
759 |     def forward(self, x):
760 |         x, skip_list = self.forward_features(x)
761 |         x = self.forward_features_up(x, skip_list)
762 |         x = self.forward_final(x)
763 | 
764 |         return x
765 | 
766 | 
767 | 
768 | 
769 | 
--------------------------------------------------------------------------------
/models/vmunet_old/vmunet.py:
--------------------------------------------------------------------------------
1 | from .vmamba import VSSM
2 | import torch
3 | from torch import nn
4 | 
5 | 
6 | class VMUNet(nn.Module):
7 |     def __init__(self,
8 |                  input_channels=3,
9 |                  num_classes=1,
10 |                  depths=[2, 2, 9, 2],
11 |                  depths_decoder=[2, 9, 2, 2],
12 |                  drop_path_rate=0.2,
13 |                  load_ckpt_path=None,
14 |                 ):
15 |         super().__init__()
16 | 
17 |         self.load_ckpt_path = load_ckpt_path
18 |         self.num_classes = num_classes
19 | 
20 |         self.vmunet = VSSM(in_chans=input_channels,
21 |                            num_classes=num_classes,
22 |                            depths=depths,
23 |                            depths_decoder=depths_decoder,
24 |                            drop_path_rate=drop_path_rate,
25 |                           )
26 | 
27 |     def forward(self, x):
28 |         if x.size()[1] == 1:
29 |             x = x.repeat(1,3,1,1)
30 |         logits = self.vmunet(x)
31 |         if self.num_classes == 1: return torch.sigmoid(logits)
32 |         else: return logits
33 | 
34 |     def load_from(self):
35 |         if self.load_ckpt_path is not None:
36 |             model_dict = self.vmunet.state_dict()
37 |             modelCheckpoint = torch.load(self.load_ckpt_path)
38 |             pretrained_dict = modelCheckpoint['model']
39 |             # filter: keep only the pretrained keys that exist in the current model
40 |             new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
41 |             model_dict.update(new_dict)
42 |             # report how many parameters were actually updated
43 |             print('Total model_dict: {}, Total pretrained_dict: {}, update: {}'.format(len(model_dict), len(pretrained_dict), len(new_dict)))
44 |             self.vmunet.load_state_dict(model_dict)
45 | 
46 |             not_loaded_keys = [k for k in pretrained_dict.keys() if k not in new_dict.keys()]
47 |             print('Not loaded keys:', not_loaded_keys)
48 |             print("encoder loading finished!")
49 | 
50 |             model_dict = self.vmunet.state_dict()
51 |             modelCheckpoint = torch.load(self.load_ckpt_path)
52 |             pretrained_odict = modelCheckpoint['model']
53 |             pretrained_dict = {}
54 |             for k, v in pretrained_odict.items():
55 |                 if 'layers.0' in k:
56 |                     new_k = k.replace('layers.0', 'layers_up.3')
57 |                     pretrained_dict[new_k] = v
58 |                 elif 'layers.1' in k:
59 |                     new_k = k.replace('layers.1', 'layers_up.2')
60 |                     pretrained_dict[new_k] = v
61 |                 elif 'layers.2' in k:
62 |                     new_k = k.replace('layers.2', 'layers_up.1')
63 |                     pretrained_dict[new_k] = v
64 |                 elif 'layers.3' in k:
65 |                     new_k = k.replace('layers.3', 'layers_up.0')
66 |                     pretrained_dict[new_k] = v
67 |             # filter: keep only the remapped keys that exist in the current model
68 |             new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
69 |             model_dict.update(new_dict)
70 |             # report how many parameters were actually updated
71 |             print('Total model_dict: {}, Total pretrained_dict: {}, update: {}'.format(len(model_dict), len(pretrained_dict), len(new_dict)))
72 |             self.vmunet.load_state_dict(model_dict)
73 | 
74 |             # collect the keys that were not loaded
75 |             not_loaded_keys = [k for k in pretrained_dict.keys() if k not in new_dict.keys()]
76 |             print('Not loaded keys:', not_loaded_keys)
77 |             print("decoder loading finished!")
--------------------------------------------------------------------------------
/pre_trained_weights/README.md:
--------------------------------------------------------------------------------
1 | You can download the pretrained weights used by VM-UNet here: https://pan.baidu.com/s/144nc45k28NBB_8k9hf_OeQ?pwd=xuns
2 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | import timm
4 | from datasets.dataset import NPY_datasets, Polyp_datasets
5 | from tensorboardX import SummaryWriter
6 | # from models.vmunet.vmunet import VMUNet
7 | from models.vmunet.vmunet_v2 import VMUNetV2
8 | 
9 | from engine import *
10 | import os
11 | import sys
12 | 
13 | from utils import *
14 | from configs.config_setting_v2 import setting_config
15 | 
16 | import warnings
17 | warnings.filterwarnings("ignore")
18 | 
19 | 
20 | 
21 | def main(config):
22 | 
23 |     print('#----------Creating logger----------#')
24 |     sys.path.append(config.work_dir + '/')
25 |     log_dir = os.path.join(config.work_dir, 'log')
26 |     checkpoint_dir = os.path.join(config.work_dir, 'checkpoints')
27 |     resume_model = os.path.join(checkpoint_dir, 'latest.pth')
28 |     outputs = os.path.join(config.work_dir, 'outputs')
29 |     if not os.path.exists(checkpoint_dir):
30 |         os.makedirs(checkpoint_dir)
31 |     if not os.path.exists(outputs):
32 |         os.makedirs(outputs)
33 | 
34 |     global logger
35 |     logger = get_logger('train', log_dir)
36 |     global writer
37 |     writer = SummaryWriter(config.work_dir + 'summary')
38 | 
39 |     log_config_info(config, logger)
40 | 
41 | 
42 | 
43 | 
44 | 
45 |     print('#----------GPU init----------#')
46 |     os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu_id
47 |     set_seed(config.seed)
48 |     torch.cuda.empty_cache()
49 | 
50 | 
51 | 
52 | 
53 | 
54 |     print('#----------Preparing dataset----------#')
55 |     # train_dataset = NPY_datasets(config.data_path, config, train=True)
56 |     train_dataset = Polyp_datasets(config.data_path, config, train=True)
57 | 
58 |     train_loader = DataLoader(train_dataset,
59 |                               batch_size=config.batch_size,
60 |                               shuffle=True,
61 |                               pin_memory=True,
62 |                               num_workers=config.num_workers)
63 | 
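    # Validation below builds one batch-size-1 loader per polyp benchmark
    # (CVC-300, CVC-ClinicDB, Kvasir, CVC-ColonDB, ETIS-LaribPolypDB); the best
    # checkpoint is then chosen on the mean validation loss across all five sets.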
64 | 
65 |     # val_dataset = NPY_datasets(config.data_path, config, train=False)
66 | 
67 | 
68 |     val_loader_dict = {}
69 |     for dataset in ['CVC-300', 'CVC-ClinicDB', 'Kvasir', 'CVC-ColonDB', 'ETIS-LaribPolypDB']:
70 |         val_dataset = Polyp_datasets(config.data_path, config, train=False, test_dataset=dataset)
71 |         val_loader = DataLoader(val_dataset,
72 |                                 batch_size=1,
73 |                                 shuffle=False,
74 |                                 pin_memory=True,
75 |                                 num_workers=config.num_workers,
76 |                                 drop_last=True)
77 | 
78 |         val_loader_dict[dataset] = val_loader
79 | 
80 | 
81 | 
82 |     print('#----------Preparing Model----------#')
83 |     model_cfg = config.model_config
84 |     if config.network == 'vmunet-v2':
85 |         model = VMUNetV2(
86 |             num_classes=model_cfg['num_classes'],
87 |             input_channels=model_cfg['input_channels'],
88 |             depths=model_cfg['depths'],
89 |             depths_decoder=model_cfg['depths_decoder'],
90 |             drop_path_rate=model_cfg['drop_path_rate'],
91 |             load_ckpt_path=model_cfg['load_ckpt_path'],
92 |             deep_supervision = model_cfg['deep_supervision'],
93 |         )
94 |         model.load_from()
95 | 
96 |     else: raise Exception('network is not right!')
97 |     model = model.cuda()
98 | 
99 |     cal_params_flops(model, 256, logger)
100 | 
101 | 
102 | 
103 | 
104 | 
105 |     print('#----------Preparing loss, opt, sch and amp----------#')
106 |     criterion = config.criterion
107 |     optimizer = get_optimizer(config, model)
108 |     scheduler = get_scheduler(config, optimizer)
109 | 
110 | 
111 | 
112 | 
113 | 
114 |     print('#----------Set other params----------#')
115 |     min_loss = 999
116 |     start_epoch = 1
117 |     min_epoch = 1
118 | 
119 | 
120 | 
121 | 
122 | 
123 |     if os.path.exists(resume_model):
124 |         print('#----------Resume Model and Other params----------#')
125 |         checkpoint = torch.load(resume_model, map_location=torch.device('cpu'))
126 |         model.load_state_dict(checkpoint['model_state_dict'])
127 |         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
128 |         scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
129 |         saved_epoch = checkpoint['epoch']
130 |         start_epoch += saved_epoch
131 |         min_loss, min_epoch, loss = checkpoint['min_loss'], checkpoint['min_epoch'], checkpoint['loss']
132 | 
133 |         log_info = f'resuming model from {resume_model}.
resume_epoch: {saved_epoch}, min_loss: {min_loss:.4f}, min_epoch: {min_epoch}, loss: {loss:.4f}' 134 | logger.info(log_info) 135 | 136 | 137 | 138 | 139 | step = 0 140 | print('#----------Training----------#') 141 | for epoch in range(start_epoch, config.epochs + 1): 142 | 143 | torch.cuda.empty_cache() 144 | 145 | step = train_one_epoch( 146 | train_loader, 147 | model, 148 | criterion, 149 | optimizer, 150 | scheduler, 151 | epoch, 152 | step, 153 | logger, 154 | config, 155 | writer 156 | ) 157 | 158 | loss_all = [] 159 | for name in ['CVC-300', 'CVC-ClinicDB', 'Kvasir', 'CVC-ColonDB', 'ETIS-LaribPolypDB']: 160 | val_loader_t = val_loader_dict[name] 161 | 162 | loss_t = val_one_epoch( 163 | val_loader_t, 164 | model, 165 | criterion, 166 | epoch, 167 | logger, 168 | config, 169 | val_data_name=name 170 | ) 171 | loss_all.append(loss_t) 172 | 173 | loss = np.mean(loss_all) 174 | if loss < min_loss: 175 | torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'best.pth')) 176 | min_loss = loss 177 | min_epoch = epoch 178 | 179 | torch.save( 180 | { 181 | 'epoch': epoch, 182 | 'min_loss': min_loss, 183 | 'min_epoch': min_epoch, 184 | 'loss': loss, 185 | 'model_state_dict': model.state_dict(), 186 | 'optimizer_state_dict': optimizer.state_dict(), 187 | 'scheduler_state_dict': scheduler.state_dict(), 188 | }, os.path.join(checkpoint_dir, 'latest.pth')) 189 | 190 | # 测试的步骤:也要按照 Polyp 的文件夹列表 191 | if os.path.exists(os.path.join(checkpoint_dir, 'best.pth')): 192 | print('#----------Testing----------#') 193 | best_weight = torch.load(config.work_dir + 'checkpoints/best.pth', map_location=torch.device('cpu')) 194 | model.load_state_dict(best_weight) 195 | for name in ['CVC-300', 'CVC-ClinicDB', 'Kvasir', 'CVC-ColonDB', 'ETIS-LaribPolypDB']: 196 | val_loader_t = val_loader_dict[name] 197 | loss = test_one_epoch( 198 | val_loader_t, 199 | model, 200 | criterion, 201 | logger, 202 | config, 203 | test_data_name=name 204 | ) 205 | os.rename( 206 | os.path.join(checkpoint_dir, 'best.pth'), 207 | os.path.join(checkpoint_dir, f'best-epoch{min_epoch}-loss{min_loss:.4f}.pth') 208 | ) 209 | 210 | 211 | if __name__ == '__main__': 212 | config = setting_config 213 | main(config) -------------------------------------------------------------------------------- /train_gim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import timm 4 | from datasets.dataset import NPY_datasets,GIM_datasets 5 | from tensorboardX import SummaryWriter 6 | from models.vmunet.vmunet_v2 import VMUNetV2 7 | 8 | from engine import * 9 | import os 10 | import sys 11 | 12 | from utils import * 13 | from configs.config_setting_v2 import setting_config 14 | 15 | import warnings 16 | warnings.filterwarnings("ignore") 17 | import os 18 | os.environ['CUDA_LAUNCH_BLOCKING'] = '2' 19 | 20 | 21 | def main(config): 22 | 23 | print('#----------Creating logger----------#') 24 | sys.path.append(config.work_dir + '/') 25 | log_dir = os.path.join(config.work_dir, 'log') 26 | checkpoint_dir = os.path.join(config.work_dir, 'checkpoints') 27 | resume_model = os.path.join(checkpoint_dir, 'latest.pth') 28 | outputs = os.path.join(config.work_dir, 'outputs') 29 | if not os.path.exists(checkpoint_dir): 30 | os.makedirs(checkpoint_dir) 31 | if not os.path.exists(outputs): 32 | os.makedirs(outputs) 33 | 34 | global logger 35 | logger = get_logger('train', log_dir) 36 | global writer 37 | writer = SummaryWriter(config.work_dir + 'summary') 38 | 39 | 
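    # NOTE: config.work_dir is assumed to end with a trailing '/', since paths
    # such as work_dir + 'summary' and work_dir + 'checkpoints/best.pth' are
    # built by plain string concatenation rather than os.path.join.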
log_config_info(config, logger) 40 | 41 | 42 | 43 | 44 | 45 | print('#----------GPU init----------#') 46 | os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu_id 47 | set_seed(config.seed) 48 | torch.cuda.empty_cache() 49 | 50 | 51 | 52 | 53 | 54 | print('#----------Preparing dataset----------#') 55 | train_dataset = GIM_datasets(config.data_path, config, train=True) 56 | train_loader = DataLoader(train_dataset, 57 | batch_size=config.batch_size, 58 | shuffle=True, 59 | pin_memory=True, 60 | num_workers=config.num_workers) 61 | val_dataset = GIM_datasets(config.data_path, config, train=False) 62 | val_loader = DataLoader(val_dataset, 63 | batch_size=1, 64 | shuffle=False, 65 | pin_memory=True, 66 | num_workers=config.num_workers, 67 | drop_last=True) 68 | 69 | 70 | 71 | 72 | 73 | print('#----------Prepareing Model----------#') 74 | model_cfg = config.model_config 75 | if config.network == 'vmunet-v2': 76 | model = VMUNetV2( 77 | num_classes=model_cfg['num_classes'], 78 | input_channels=model_cfg['input_channels'], 79 | depths=model_cfg['depths'], 80 | depths_decoder=model_cfg['depths_decoder'], 81 | drop_path_rate=model_cfg['drop_path_rate'], 82 | load_ckpt_path=model_cfg['load_ckpt_path'], 83 | deep_supervision = model_cfg['deep_supervision'], 84 | ) 85 | model.load_from() 86 | 87 | else: raise Exception('network in not right!') 88 | model = model.cuda() 89 | 90 | cal_params_flops(model, 256, logger) 91 | 92 | 93 | 94 | 95 | 96 | print('#----------Prepareing loss, opt, sch and amp----------#') 97 | criterion = config.criterion 98 | optimizer = get_optimizer(config, model) 99 | scheduler = get_scheduler(config, optimizer) 100 | 101 | 102 | 103 | 104 | 105 | print('#----------Set other params----------#') 106 | min_loss = 999 107 | start_epoch = 1 108 | min_epoch = 1 109 | 110 | 111 | 112 | 113 | 114 | if os.path.exists(resume_model): 115 | print('#----------Resume Model and Other params----------#') 116 | checkpoint = torch.load(resume_model, map_location=torch.device('cpu')) 117 | model.load_state_dict(checkpoint['model_state_dict']) 118 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 119 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 120 | saved_epoch = checkpoint['epoch'] 121 | start_epoch += saved_epoch 122 | min_loss, min_epoch, loss = checkpoint['min_loss'], checkpoint['min_epoch'], checkpoint['loss'] 123 | 124 | log_info = f'resuming model from {resume_model}. 
resume_epoch: {saved_epoch}, min_loss: {min_loss:.4f}, min_epoch: {min_epoch}, loss: {loss:.4f}' 125 | logger.info(log_info) 126 | 127 | 128 | 129 | 130 | step = 0 131 | print('#----------Training----------#') 132 | for epoch in range(start_epoch, config.epochs + 1): 133 | 134 | torch.cuda.empty_cache() 135 | 136 | step = train_one_epoch( 137 | train_loader, 138 | model, 139 | criterion, 140 | optimizer, 141 | scheduler, 142 | epoch, 143 | step, 144 | logger, 145 | config, 146 | writer 147 | ) 148 | 149 | loss = val_one_epoch( 150 | val_loader, 151 | model, 152 | criterion, 153 | epoch, 154 | logger, 155 | config 156 | ) 157 | 158 | if loss < min_loss: 159 | torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'best.pth')) 160 | min_loss = loss 161 | min_epoch = epoch 162 | 163 | torch.save( 164 | { 165 | 'epoch': epoch, 166 | 'min_loss': min_loss, 167 | 'min_epoch': min_epoch, 168 | 'loss': loss, 169 | 'model_state_dict': model.state_dict(), 170 | 'optimizer_state_dict': optimizer.state_dict(), 171 | 'scheduler_state_dict': scheduler.state_dict(), 172 | }, os.path.join(checkpoint_dir, 'latest.pth')) 173 | 174 | if os.path.exists(os.path.join(checkpoint_dir, 'best.pth')): 175 | print('#----------Testing----------#') 176 | best_weight = torch.load(config.work_dir + 'checkpoints/best.pth', map_location=torch.device('cpu')) 177 | model.load_state_dict(best_weight) 178 | loss = test_one_epoch( 179 | val_loader, 180 | model, 181 | criterion, 182 | logger, 183 | config, 184 | ) 185 | os.rename( 186 | os.path.join(checkpoint_dir, 'best.pth'), 187 | os.path.join(checkpoint_dir, f'best-epoch{min_epoch}-loss{min_loss:.4f}.pth') 188 | ) 189 | 190 | 191 | if __name__ == '__main__': 192 | config = setting_config 193 | main(config) -------------------------------------------------------------------------------- /train_isic_all.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import timm 4 | from datasets.dataset import NPY_datasets, Polyp_datasets, Isic_datasets 5 | from tensorboardX import SummaryWriter 6 | # from models.vmunet.vmunet import VMUNet 7 | from models.vmunet.vmunet_v2 import VMUNetV2 8 | 9 | from engine import * 10 | import os 11 | import sys 12 | 13 | from utils import * 14 | from configs.config_setting_v2 import setting_config 15 | 16 | import warnings 17 | warnings.filterwarnings("ignore") 18 | 19 | 20 | 21 | def main(config): 22 | 23 | print('#----------Creating logger----------#') 24 | sys.path.append(config.work_dir + '/') 25 | log_dir = os.path.join(config.work_dir, 'log') 26 | checkpoint_dir = os.path.join(config.work_dir, 'checkpoints') 27 | resume_model = os.path.join(checkpoint_dir, 'latest.pth') 28 | outputs = os.path.join(config.work_dir, 'outputs') 29 | if not os.path.exists(checkpoint_dir): 30 | os.makedirs(checkpoint_dir) 31 | if not os.path.exists(outputs): 32 | os.makedirs(outputs) 33 | 34 | global logger 35 | logger = get_logger('train', log_dir) 36 | global writer 37 | writer = SummaryWriter(config.work_dir + 'summary') 38 | 39 | log_config_info(config, logger) 40 | 41 | 42 | 43 | 44 | 45 | print('#----------GPU init----------#') 46 | os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu_id 47 | set_seed(config.seed) 48 | torch.cuda.empty_cache() 49 | 50 | 51 | 52 | 53 | 54 | print('#----------Preparing dataset----------#') 55 | # train_dataset = NPY_datasets(config.data_path, config, train=True) 56 | train_dataset = Isic_datasets(config.data_path, config, train=True) 
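    # Training uses the full ISIC training split in one dataset; validation
    # further below keeps isic17 and isic18 separate, and the best checkpoint
    # is selected on the mean validation loss across both benchmarks.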
57 | 58 | train_loader = DataLoader(train_dataset, 59 | batch_size=config.batch_size, 60 | shuffle=True, 61 | pin_memory=True, 62 | num_workers=config.num_workers) 63 | 64 | 65 | # val_dataset = NPY_datasets(config.data_path, config, train=False) 66 | 67 | 68 | val_loader_dict = {} 69 | for dataset in ['isic17', 'isic18']: 70 | val_dataset = Isic_datasets(config.data_path, config, train=False, test_dataset=dataset) 71 | val_loader = DataLoader(val_dataset, 72 | batch_size=1, 73 | shuffle=False, 74 | pin_memory=True, 75 | num_workers=config.num_workers, 76 | drop_last=True) 77 | 78 | val_loader_dict[dataset] = val_loader 79 | 80 | 81 | 82 | print('#----------Prepareing Model----------#') 83 | model_cfg = config.model_config 84 | if config.network == 'vmunet-v2': 85 | model = VMUNetV2( 86 | num_classes=model_cfg['num_classes'], 87 | input_channels=model_cfg['input_channels'], 88 | depths=model_cfg['depths'], 89 | depths_decoder=model_cfg['depths_decoder'], 90 | drop_path_rate=model_cfg['drop_path_rate'], 91 | load_ckpt_path=model_cfg['load_ckpt_path'], 92 | deep_supervision = model_cfg['deep_supervision'], 93 | ) 94 | model.load_from() 95 | 96 | else: raise Exception('network in not right!') 97 | model = model.cuda() 98 | 99 | cal_params_flops(model, 256, logger) 100 | 101 | 102 | 103 | 104 | 105 | print('#----------Prepareing loss, opt, sch and amp----------#') 106 | criterion = config.criterion 107 | optimizer = get_optimizer(config, model) 108 | scheduler = get_scheduler(config, optimizer) 109 | 110 | 111 | 112 | 113 | 114 | print('#----------Set other params----------#') 115 | min_loss = 999 116 | start_epoch = 1 117 | min_epoch = 1 118 | 119 | 120 | 121 | 122 | 123 | if os.path.exists(resume_model): 124 | print('#----------Resume Model and Other params----------#') 125 | checkpoint = torch.load(resume_model, map_location=torch.device('cpu')) 126 | model.load_state_dict(checkpoint['model_state_dict']) 127 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 128 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 129 | saved_epoch = checkpoint['epoch'] 130 | start_epoch += saved_epoch 131 | min_loss, min_epoch, loss = checkpoint['min_loss'], checkpoint['min_epoch'], checkpoint['loss'] 132 | 133 | log_info = f'resuming model from {resume_model}. 
resume_epoch: {saved_epoch}, min_loss: {min_loss:.4f}, min_epoch: {min_epoch}, loss: {loss:.4f}' 134 | logger.info(log_info) 135 | 136 | 137 | 138 | 139 | step = 0 140 | print('#----------Training----------#') 141 | for epoch in range(start_epoch, config.epochs + 1): 142 | 143 | torch.cuda.empty_cache() 144 | 145 | step = train_one_epoch( 146 | train_loader, 147 | model, 148 | criterion, 149 | optimizer, 150 | scheduler, 151 | epoch, 152 | step, 153 | logger, 154 | config, 155 | writer 156 | ) 157 | 158 | loss_all = [] 159 | for name in ['isic17', 'isic18']: 160 | val_loader_t = val_loader_dict[name] 161 | 162 | loss_t = val_one_epoch( 163 | val_loader_t, 164 | model, 165 | criterion, 166 | epoch, 167 | logger, 168 | config, 169 | val_data_name=name 170 | ) 171 | loss_all.append(loss_t) 172 | 173 | loss = np.mean(loss_all) 174 | if loss < min_loss: 175 | torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'best.pth')) 176 | min_loss = loss 177 | min_epoch = epoch 178 | 179 | torch.save( 180 | { 181 | 'epoch': epoch, 182 | 'min_loss': min_loss, 183 | 'min_epoch': min_epoch, 184 | 'loss': loss, 185 | 'model_state_dict': model.state_dict(), 186 | 'optimizer_state_dict': optimizer.state_dict(), 187 | 'scheduler_state_dict': scheduler.state_dict(), 188 | }, os.path.join(checkpoint_dir, 'latest.pth')) 189 | 190 | # 测试的步骤:也要按照 Polyp 的文件夹列表 191 | if os.path.exists(os.path.join(checkpoint_dir, 'best.pth')): 192 | print('#----------Testing----------#') 193 | best_weight = torch.load(config.work_dir + 'checkpoints/best.pth', map_location=torch.device('cpu')) 194 | model.load_state_dict(best_weight) 195 | for name in ['isic17','isic18']: 196 | val_loader_t = val_loader_dict[name] 197 | loss = test_one_epoch( 198 | val_loader_t, 199 | model, 200 | criterion, 201 | logger, 202 | config, 203 | test_data_name=name 204 | ) 205 | os.rename( 206 | os.path.join(checkpoint_dir, 'best.pth'), 207 | os.path.join(checkpoint_dir, f'best-epoch{min_epoch}-loss{min_loss:.4f}.pth') 208 | ) 209 | 210 | 211 | if __name__ == '__main__': 212 | config = setting_config 213 | main(config) -------------------------------------------------------------------------------- /train_npy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import timm 4 | from datasets.dataset import NPY_datasets 5 | from tensorboardX import SummaryWriter 6 | from models.vmunet.vmunet_v2 import VMUNetV2 7 | 8 | from engine import * 9 | import os 10 | import sys 11 | 12 | from utils import * 13 | from configs.config_setting_v2 import setting_config 14 | 15 | import warnings 16 | warnings.filterwarnings("ignore") 17 | 18 | 19 | 20 | def main(config): 21 | 22 | print('#----------Creating logger----------#') 23 | sys.path.append(config.work_dir + '/') 24 | log_dir = os.path.join(config.work_dir, 'log') 25 | checkpoint_dir = os.path.join(config.work_dir, 'checkpoints') 26 | resume_model = os.path.join(checkpoint_dir, 'latest.pth') 27 | outputs = os.path.join(config.work_dir, 'outputs') 28 | if not os.path.exists(checkpoint_dir): 29 | os.makedirs(checkpoint_dir) 30 | if not os.path.exists(outputs): 31 | os.makedirs(outputs) 32 | 33 | global logger 34 | logger = get_logger('train', log_dir) 35 | global writer 36 | writer = SummaryWriter(config.work_dir + 'summary') 37 | 38 | log_config_info(config, logger) 39 | 40 | 41 | 42 | 43 | 44 | print('#----------GPU init----------#') 45 | os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu_id 46 | 
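    # CUDA_VISIBLE_DEVICES only takes effect if it is set before the first CUDA
    # context is created; it is set here before any .cuda() call in this script,
    # so the GPU mask still applies.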
set_seed(config.seed) 47 | torch.cuda.empty_cache() 48 | 49 | 50 | 51 | 52 | 53 | print('#----------Preparing dataset----------#') 54 | train_dataset = NPY_datasets(config.data_path, config, train=True) 55 | train_loader = DataLoader(train_dataset, 56 | batch_size=config.batch_size, 57 | shuffle=True, 58 | pin_memory=True, 59 | num_workers=config.num_workers) 60 | val_dataset = NPY_datasets(config.data_path, config, train=False) 61 | val_loader = DataLoader(val_dataset, 62 | batch_size=1, 63 | shuffle=False, 64 | pin_memory=True, 65 | num_workers=config.num_workers, 66 | drop_last=True) 67 | 68 | 69 | 70 | 71 | 72 | print('#----------Prepareing Model----------#') 73 | model_cfg = config.model_config 74 | if config.network == 'vmunet-v2': 75 | model = VMUNetV2( 76 | num_classes=model_cfg['num_classes'], 77 | input_channels=model_cfg['input_channels'], 78 | depths=model_cfg['depths'], 79 | depths_decoder=model_cfg['depths_decoder'], 80 | drop_path_rate=model_cfg['drop_path_rate'], 81 | load_ckpt_path=model_cfg['load_ckpt_path'], 82 | deep_supervision = model_cfg['deep_supervision'], 83 | ) 84 | model.load_from() 85 | 86 | else: raise Exception('network in not right!') 87 | model = model.cuda() 88 | 89 | cal_params_flops(model, 256, logger) 90 | 91 | 92 | 93 | 94 | 95 | print('#----------Prepareing loss, opt, sch and amp----------#') 96 | criterion = config.criterion 97 | optimizer = get_optimizer(config, model) 98 | scheduler = get_scheduler(config, optimizer) 99 | 100 | 101 | 102 | 103 | 104 | print('#----------Set other params----------#') 105 | min_loss = 999 106 | start_epoch = 1 107 | min_epoch = 1 108 | 109 | 110 | 111 | 112 | 113 | if os.path.exists(resume_model): 114 | print('#----------Resume Model and Other params----------#') 115 | checkpoint = torch.load(resume_model, map_location=torch.device('cpu')) 116 | model.load_state_dict(checkpoint['model_state_dict']) 117 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 118 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 119 | saved_epoch = checkpoint['epoch'] 120 | start_epoch += saved_epoch 121 | min_loss, min_epoch, loss = checkpoint['min_loss'], checkpoint['min_epoch'], checkpoint['loss'] 122 | 123 | log_info = f'resuming model from {resume_model}. 
resume_epoch: {saved_epoch}, min_loss: {min_loss:.4f}, min_epoch: {min_epoch}, loss: {loss:.4f}' 124 | logger.info(log_info) 125 | 126 | 127 | 128 | 129 | step = 0 130 | print('#----------Training----------#') 131 | for epoch in range(start_epoch, config.epochs + 1): 132 | 133 | torch.cuda.empty_cache() 134 | 135 | step = train_one_epoch( 136 | train_loader, 137 | model, 138 | criterion, 139 | optimizer, 140 | scheduler, 141 | epoch, 142 | step, 143 | logger, 144 | config, 145 | writer 146 | ) 147 | 148 | loss = val_one_epoch( 149 | val_loader, 150 | model, 151 | criterion, 152 | epoch, 153 | logger, 154 | config 155 | ) 156 | 157 | if loss < min_loss: 158 | torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'best.pth')) 159 | min_loss = loss 160 | min_epoch = epoch 161 | 162 | torch.save( 163 | { 164 | 'epoch': epoch, 165 | 'min_loss': min_loss, 166 | 'min_epoch': min_epoch, 167 | 'loss': loss, 168 | 'model_state_dict': model.state_dict(), 169 | 'optimizer_state_dict': optimizer.state_dict(), 170 | 'scheduler_state_dict': scheduler.state_dict(), 171 | }, os.path.join(checkpoint_dir, 'latest.pth')) 172 | 173 | if os.path.exists(os.path.join(checkpoint_dir, 'best.pth')): 174 | print('#----------Testing----------#') 175 | best_weight = torch.load(config.work_dir + 'checkpoints/best.pth', map_location=torch.device('cpu')) 176 | model.load_state_dict(best_weight) 177 | loss = test_one_epoch( 178 | val_loader, 179 | model, 180 | criterion, 181 | logger, 182 | config, 183 | ) 184 | os.rename( 185 | os.path.join(checkpoint_dir, 'best.pth'), 186 | os.path.join(checkpoint_dir, f'best-epoch{min_epoch}-loss{min_loss:.4f}.pth') 187 | ) 188 | 189 | 190 | if __name__ == '__main__': 191 | config = setting_config 192 | main(config) -------------------------------------------------------------------------------- /train_synapse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.cuda.amp import autocast, GradScaler 4 | from torch.utils.data import DataLoader 5 | import torch.distributed as dist 6 | from torch.nn.parallel import DistributedDataParallel as DDP 7 | from torch.utils.data.distributed import DistributedSampler 8 | from torchvision import transforms 9 | 10 | from datasets.dataset import RandomGenerator 11 | from engine_synapse import * 12 | 13 | from models.vmunet.vmunet import VMUNet 14 | 15 | import os 16 | import sys 17 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # "0, 1, 2, 3" 18 | 19 | from utils import * 20 | from configs.config_setting_synapse import setting_config 21 | 22 | import warnings 23 | warnings.filterwarnings("ignore") 24 | 25 | 26 | 27 | 28 | 29 | def main(config): 30 | 31 | print('#----------Creating logger----------#') 32 | sys.path.append(config.work_dir + '/') 33 | log_dir = os.path.join(config.work_dir, 'log') 34 | checkpoint_dir = os.path.join(config.work_dir, 'checkpoints') 35 | resume_model = os.path.join(checkpoint_dir, 'latest.pth') 36 | outputs = os.path.join(config.work_dir, 'outputs') 37 | if not os.path.exists(checkpoint_dir): 38 | os.makedirs(checkpoint_dir) 39 | if not os.path.exists(outputs): 40 | os.makedirs(outputs) 41 | 42 | global logger 43 | logger = get_logger('train', log_dir) 44 | 45 | log_config_info(config, logger) 46 | 47 | 48 | 49 | 50 | 51 | print('#----------GPU init----------#') 52 | set_seed(config.seed) 53 | gpu_ids = [0]# [0, 1, 2, 3] 54 | torch.cuda.empty_cache() 55 | gpus_type, gpus_num = torch.cuda.get_device_name(), torch.cuda.device_count() 56 
|     if config.distributed:
57 |         print('#----------Start DDP----------#')
58 |         dist.init_process_group(backend='nccl', init_method='env://')
59 |         torch.cuda.manual_seed_all(config.seed)
60 |         config.local_rank = torch.distributed.get_rank()
61 | 
62 | 
63 | 
64 | 
65 | 
66 |     print('#----------Preparing dataset----------#')
67 |     train_dataset = config.datasets(base_dir=config.data_path, list_dir=config.list_dir, split="train",
68 |                                     transform=transforms.Compose(
69 |                                         [RandomGenerator(output_size=[config.input_size_h, config.input_size_w])]))
70 |     train_sampler = DistributedSampler(train_dataset, shuffle=True) if config.distributed else None
71 |     train_loader = DataLoader(train_dataset,
72 |                               batch_size=config.batch_size//gpus_num if config.distributed else config.batch_size,
73 |                               shuffle=(train_sampler is None),
74 |                               pin_memory=True,
75 |                               num_workers=config.num_workers,
76 |                               sampler=train_sampler)
77 | 
78 |     val_dataset = config.datasets(base_dir=config.volume_path, split="test_vol", list_dir=config.list_dir)
79 |     val_sampler = DistributedSampler(val_dataset, shuffle=False) if config.distributed else None
80 |     val_loader = DataLoader(val_dataset,
81 |                             batch_size=1, # if config.distributed else config.batch_size,
82 |                             shuffle=False,
83 |                             pin_memory=True,
84 |                             num_workers=config.num_workers,
85 |                             sampler=val_sampler,
86 |                             drop_last=True)
87 | 
88 | 
89 | 
90 | 
91 | 
92 |     print('#----------Preparing Models----------#')
93 |     model_cfg = config.model_config
94 |     if config.network == 'vmunet':
95 |         model = VMUNet(
96 |             num_classes=model_cfg['num_classes'],
97 |             input_channels=model_cfg['input_channels'],
98 |             depths=model_cfg['depths'],
99 |             depths_decoder=model_cfg['depths_decoder'],
100 |             drop_path_rate=model_cfg['drop_path_rate'],
101 |             load_ckpt_path=model_cfg['load_ckpt_path'],
102 |         )
103 |         model.load_from()
104 |     else: raise ValueError('Please prepare a right net!')
105 | 
106 |     if config.distributed:
107 |         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda()
108 |         model = DDP(model, device_ids=[config.local_rank], output_device=config.local_rank)
109 |     else:
110 |         model = torch.nn.DataParallel(model.cuda(), device_ids=gpu_ids, output_device=gpu_ids[0])
111 | 
112 | 
113 | 
114 | 
115 | 
116 |     print('#----------Preparing loss, opt, sch and amp----------#')
117 |     criterion = config.criterion
118 |     optimizer = get_optimizer(config, model)
119 |     scheduler = get_scheduler(config, optimizer)
120 |     scaler = GradScaler()
121 | 
122 | 
123 | 
124 | 
125 | 
126 |     print('#----------Set other params----------#')
127 |     min_loss = 999
128 |     start_epoch = 1
129 |     min_epoch = 1
130 | 
131 | 
132 | 
133 | 
134 | 
135 |     if os.path.exists(resume_model):
136 |         print('#----------Resume Model and Other params----------#')
137 |         checkpoint = torch.load(resume_model, map_location=torch.device('cpu'))
138 |         model.module.load_state_dict(checkpoint['model_state_dict'])
139 |         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
140 |         scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
141 |         saved_epoch = checkpoint['epoch']
142 |         start_epoch += saved_epoch
143 |         min_loss, min_epoch, loss = checkpoint['min_loss'], checkpoint['min_epoch'], checkpoint['loss']
144 | 
145 |         log_info = f'resuming model from {resume_model}.
resume_epoch: {saved_epoch}, min_loss: {min_loss:.4f}, min_epoch: {min_epoch}, loss: {loss:.4f}' 146 | logger.info(log_info) 147 | 148 | 149 | 150 | 151 | 152 | print('#----------Training----------#') 153 | for epoch in range(start_epoch, config.epochs + 1): 154 | 155 | torch.cuda.empty_cache() 156 | train_sampler.set_epoch(epoch) if config.distributed else None 157 | 158 | loss = train_one_epoch( 159 | train_loader, 160 | model, 161 | criterion, 162 | optimizer, 163 | scheduler, 164 | epoch, 165 | logger, 166 | config, 167 | scaler=scaler 168 | ) 169 | 170 | if loss < min_loss: 171 | torch.save(model.module.state_dict(), os.path.join(checkpoint_dir, 'best.pth')) 172 | min_loss = loss 173 | min_epoch = epoch 174 | 175 | if epoch % config.val_interval == 0: 176 | mean_dice, mean_hd95 = val_one_epoch( 177 | val_dataset, 178 | val_loader, 179 | model, 180 | epoch, 181 | logger, 182 | config, 183 | test_save_path=outputs, 184 | val_or_test=False 185 | ) 186 | 187 | torch.save( 188 | { 189 | 'epoch': epoch, 190 | 'min_loss': min_loss, 191 | 'min_epoch': min_epoch, 192 | 'loss': loss, 193 | 'model_state_dict': model.module.state_dict(), 194 | 'optimizer_state_dict': optimizer.state_dict(), 195 | 'scheduler_state_dict': scheduler.state_dict(), 196 | }, os.path.join(checkpoint_dir, 'latest.pth')) 197 | 198 | if os.path.exists(os.path.join(checkpoint_dir, 'best.pth')): 199 | print('#----------Testing----------#') 200 | best_weight = torch.load(config.work_dir + 'checkpoints/best.pth', map_location=torch.device('cpu')) 201 | model.module.load_state_dict(best_weight) 202 | mean_dice, mean_hd95 = val_one_epoch( 203 | val_dataset, 204 | val_loader, 205 | model, 206 | epoch, 207 | logger, 208 | config, 209 | test_save_path=outputs, 210 | val_or_test=True 211 | ) 212 | os.rename( 213 | os.path.join(checkpoint_dir, 'best.pth'), 214 | os.path.join(checkpoint_dir, 215 | f'best-epoch{min_epoch}-mean_dice{mean_dice:.4f}-mean_hd95{mean_hd95:.4f}.pth') 216 | ) 217 | 218 | 219 | if __name__ == '__main__': 220 | config = setting_config 221 | main(config) -------------------------------------------------------------------------------- /unet_v2/UNet_v2.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import warnings 3 | 4 | import torch 5 | from torch import nn 6 | from unet_v2.pvtv2 import * 7 | # from pvtv2 import * # debug use only 8 | import torch.nn.functional as F 9 | 10 | 11 | class ChannelAttention(nn.Module): 12 | def __init__(self, in_planes, ratio=16): 13 | super(ChannelAttention, self).__init__() 14 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 15 | self.max_pool = nn.AdaptiveMaxPool2d(1) 16 | 17 | self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False) 18 | self.relu1 = nn.ReLU() 19 | self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False) 20 | 21 | self.sigmoid = nn.Sigmoid() 22 | 23 | def forward(self, x): 24 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 25 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 26 | out = avg_out + max_out 27 | return self.sigmoid(out) 28 | 29 | 30 | class SpatialAttention(nn.Module): 31 | def __init__(self, kernel_size=7): 32 | super(SpatialAttention, self).__init__() 33 | 34 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 35 | padding = 3 if kernel_size == 7 else 1 36 | 37 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) 38 | self.sigmoid = nn.Sigmoid() 39 | 40 | def forward(self, x): 41 | avg_out = torch.mean(x, dim=1, keepdim=True) 42 | 
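        # The line above computes the channel-wise mean and the line below the
        # channel-wise max, each a (B, 1, H, W) map; the two maps are then
        # concatenated and passed through the conv (7x7 by default) to yield
        # the spatial attention weights.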
max_out, _ = torch.max(x, dim=1, keepdim=True) 43 | x = torch.cat([avg_out, max_out], dim=1) 44 | x = self.conv1(x) 45 | return self.sigmoid(x) 46 | 47 | 48 | class BasicConv2d(nn.Module): 49 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 50 | super(BasicConv2d, self).__init__() 51 | 52 | self.conv = nn.Conv2d(in_planes, out_planes, 53 | kernel_size=kernel_size, stride=stride, 54 | padding=padding, dilation=dilation, bias=False) 55 | self.bn = nn.BatchNorm2d(out_planes) 56 | self.relu = nn.ReLU(inplace=True) 57 | 58 | def forward(self, x): 59 | x = self.conv(x) 60 | x = self.bn(x) 61 | return x 62 | 63 | 64 | class Encoder(nn.Module): 65 | def __init__(self, pretrain_path): 66 | super().__init__() 67 | self.backbone = pvt_v2_b2() 68 | 69 | if pretrain_path is None: 70 | warnings.warn('please provide the pretrained pvt model. Not using pretrained model.') 71 | elif not os.path.isfile(pretrain_path): 72 | warnings.warn(f'path: {pretrain_path} does not exists. Not using pretrained model.') 73 | else: 74 | print(f"using pretrained file: {pretrain_path}") 75 | save_model = torch.load(pretrain_path) 76 | model_dict = self.backbone.state_dict() 77 | state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()} 78 | model_dict.update(state_dict) 79 | 80 | self.backbone.load_state_dict(model_dict) 81 | 82 | def forward(self, x): 83 | f1, f2, f3, f4 = self.backbone(x) # (x: 3, 352, 352) 84 | return f1, f2, f3, f4 85 | 86 | 87 | class SDI(nn.Module): 88 | def __init__(self, channel): 89 | super().__init__() 90 | 91 | self.convs = nn.ModuleList( 92 | [nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1) for _ in range(4)]) 93 | 94 | def forward(self, xs, anchor): 95 | ans = torch.ones_like(anchor) 96 | target_size = anchor.shape[-1] 97 | 98 | for i, x in enumerate(xs): 99 | if x.shape[-1] > target_size: 100 | x = F.adaptive_avg_pool2d(x, (target_size, target_size)) 101 | elif x.shape[-1] < target_size: 102 | x = F.interpolate(x, size=(target_size, target_size), 103 | mode='bilinear', align_corners=True) 104 | 105 | ans = ans * self.convs[i](x) 106 | 107 | return ans 108 | 109 | 110 | class UNetV2(nn.Module): 111 | """ 112 | use SpatialAtt + ChannelAtt 113 | """ 114 | def __init__(self, channel=32, n_classes=1, deep_supervision=True, pretrained_path=None): 115 | super().__init__() 116 | self.deep_supervision = deep_supervision 117 | 118 | self.encoder = Encoder(pretrained_path) 119 | 120 | self.ca_1 = ChannelAttention(64) 121 | self.sa_1 = SpatialAttention() 122 | 123 | self.ca_2 = ChannelAttention(128) 124 | self.sa_2 = SpatialAttention() 125 | 126 | self.ca_3 = ChannelAttention(320) 127 | self.sa_3 = SpatialAttention() 128 | 129 | self.ca_4 = ChannelAttention(512) 130 | self.sa_4 = SpatialAttention() 131 | 132 | self.Translayer_1 = BasicConv2d(64, channel, 1) 133 | self.Translayer_2 = BasicConv2d(128, channel, 1) 134 | self.Translayer_3 = BasicConv2d(320, channel, 1) 135 | self.Translayer_4 = BasicConv2d(512, channel, 1) 136 | 137 | self.sdi_1 = SDI(channel) 138 | self.sdi_2 = SDI(channel) 139 | self.sdi_3 = SDI(channel) 140 | self.sdi_4 = SDI(channel) 141 | 142 | self.seg_outs = nn.ModuleList([ 143 | nn.Conv2d(channel, n_classes, 1, 1) for _ in range(4)]) 144 | 145 | self.deconv2 = nn.ConvTranspose2d(channel, channel, kernel_size=4, stride=2, padding=1, 146 | bias=False) 147 | self.deconv3 = nn.ConvTranspose2d(channel, channel, kernel_size=4, stride=2, 148 | padding=1, bias=False) 149 | self.deconv4 = nn.ConvTranspose2d(channel, 
channel, kernel_size=4, stride=2,
150 |                                           padding=1, bias=False)
151 |         self.deconv5 = nn.ConvTranspose2d(channel, channel, kernel_size=4, stride=2,
152 |                                           padding=1, bias=False)
153 | 
154 |     def forward(self, x):
155 |         seg_outs = []
156 |         f1, f2, f3, f4 = self.encoder(x)  # f1 [2, 64, 64, 64], f4 [2, 512, 8, 8]
157 | 
158 |         f1 = self.ca_1(f1) * f1
159 |         f1 = self.sa_1(f1) * f1
160 |         f1 = self.Translayer_1(f1)  # f1 [2, 32, 64, 64]
161 | 
162 |         f2 = self.ca_2(f2) * f2
163 |         f2 = self.sa_2(f2) * f2
164 |         f2 = self.Translayer_2(f2)  # f2 [2, 32, 32, 32]
165 | 
166 |         f3 = self.ca_3(f3) * f3
167 |         f3 = self.sa_3(f3) * f3
168 |         f3 = self.Translayer_3(f3)  # f3 [2, 32, 16, 16]
169 | 
170 |         f4 = self.ca_4(f4) * f4
171 |         f4 = self.sa_4(f4) * f4
172 |         f4 = self.Translayer_4(f4)  # f4 [2, 32, 8, 8]
173 | 
174 |         f41 = self.sdi_4([f1, f2, f3, f4], f4)  # [2, 32, 8, 8]
175 |         f31 = self.sdi_3([f1, f2, f3, f4], f3)  # [2, 32, 16, 16]
176 |         f21 = self.sdi_2([f1, f2, f3, f4], f2)  # [2, 32, 32, 32]
177 |         f11 = self.sdi_1([f1, f2, f3, f4], f1)  # [2, 32, 64, 64]
178 |         # the self.seg_outs conv heads and the local seg_outs list intentionally share the same name
179 |         seg_outs.append(self.seg_outs[0](f41))  # seg_outs[0] [2, 2, 8, 8]
180 | 
181 |         y = self.deconv2(f41) + f31
182 |         seg_outs.append(self.seg_outs[1](y))  # seg_outs[1] [2, 2, 16, 16]
183 | 
184 |         y = self.deconv3(y) + f21
185 |         seg_outs.append(self.seg_outs[2](y))  # seg_outs[2] [2, 2, 32, 32]
186 | 
187 |         y = self.deconv4(y) + f11
188 |         seg_outs.append(self.seg_outs[3](y))  # seg_outs[3] [2, 2, 64, 64]
189 | 
190 |         for i, o in enumerate(seg_outs):  # 4x upsampling
191 |             seg_outs[i] = F.interpolate(o, scale_factor=4, mode='bilinear')
192 | 
193 |         if self.deep_supervision:
194 |             return seg_outs[::-1]
195 |         else:
196 |             return seg_outs[-1]
197 | 
198 | 
199 | 
200 | 
201 | 
202 | from thop import profile  ## import the thop profiling module
203 | def cal_params_flops(model, size):
204 |     input = torch.randn(1, 3, size, size).cuda()
205 |     flops, params = profile(model, inputs=(input,))
206 |     print('flops', flops/1e9)  ## print FLOPs (G)
207 |     print('params', params/1e6)  ## print parameter count (M)
208 | 
209 |     total = sum(p.numel() for p in model.parameters())
210 |     print("Total params: %.2fM" % (total/1e6))
211 |     # logger.info(f'flops: {flops/1e9}, params: {params/1e6}, Total params: : {total/1e6:.4f}')
212 | 
213 | 
214 | 
215 | 
216 | # flops 5.399307904
217 | # params 25.128652
218 | # Total params: 25.15M
219 | 
220 | 
221 | if __name__ == "__main__":
222 |     pretrained_path = "/raid/code/mamba_all/U-Net_v2/PolypSeg/pvt_pth/pvt_v2_b2.pth"
223 |     model = UNetV2(n_classes=1, deep_supervision=True, pretrained_path=None).cuda()
224 |     x = torch.rand((1, 3, 256, 256))
225 |     # ys = model(x)
226 |     # for y in ys:
227 |     #     print(y.shape)
228 |     cal_params_flops(model, 256)
229 | 
--------------------------------------------------------------------------------
/unet_v2/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nobodyplayer1/VM-UNetV2/b65f04c10237158ddd4ebe5bf2c1b5cbb15524fb/unet_v2/__init__.py
--------------------------------------------------------------------------------
/unet_v2/pvtv2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from functools import partial
5 | 
6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
7 | from timm.models.registry import register_model
8 | from timm.models.vision_transformer import _cfg
9 | from timm.models.registry import register_model
10 | 
11 | import math
12 | 
13 | 
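# The PVTv2 building blocks follow: an Mlp with a depthwise conv, spatial-
# reduction attention, transformer Blocks, and overlapping patch embedding.
# A minimal, hypothetical usage sketch of the pvt_v2_b2 backbone defined in
# this file (and used by Encoder in UNet_v2.py):
#
#     backbone = pvt_v2_b2()
#     f1, f2, f3, f4 = backbone(torch.randn(1, 3, 352, 352))
#     # four feature maps at strides 4/8/16/32 with 64/128/320/512 channels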
14 | class Mlp(nn.Module): 15 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 16 | super().__init__() 17 | out_features = out_features or in_features 18 | hidden_features = hidden_features or in_features 19 | self.fc1 = nn.Linear(in_features, hidden_features) 20 | self.dwconv = DWConv(hidden_features) 21 | self.act = act_layer() 22 | self.fc2 = nn.Linear(hidden_features, out_features) 23 | self.drop = nn.Dropout(drop) 24 | 25 | self.apply(self._init_weights) 26 | 27 | def _init_weights(self, m): 28 | if isinstance(m, nn.Linear): 29 | trunc_normal_(m.weight, std=.02) 30 | if isinstance(m, nn.Linear) and m.bias is not None: 31 | nn.init.constant_(m.bias, 0) 32 | elif isinstance(m, nn.LayerNorm): 33 | nn.init.constant_(m.bias, 0) 34 | nn.init.constant_(m.weight, 1.0) 35 | elif isinstance(m, nn.Conv2d): 36 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 37 | fan_out //= m.groups 38 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 39 | if m.bias is not None: 40 | m.bias.data.zero_() 41 | 42 | def forward(self, x, H, W): 43 | x = self.fc1(x) 44 | x = self.dwconv(x, H, W) 45 | x = self.act(x) 46 | x = self.drop(x) 47 | x = self.fc2(x) 48 | x = self.drop(x) 49 | return x 50 | 51 | 52 | class Attention(nn.Module): 53 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 54 | super().__init__() 55 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 56 | 57 | self.dim = dim 58 | self.num_heads = num_heads 59 | head_dim = dim // num_heads 60 | self.scale = qk_scale or head_dim ** -0.5 61 | 62 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 63 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 64 | self.attn_drop = nn.Dropout(attn_drop) 65 | self.proj = nn.Linear(dim, dim) 66 | self.proj_drop = nn.Dropout(proj_drop) 67 | 68 | self.sr_ratio = sr_ratio 69 | if sr_ratio > 1: 70 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 71 | self.norm = nn.LayerNorm(dim) 72 | 73 | self.apply(self._init_weights) 74 | 75 | def _init_weights(self, m): 76 | if isinstance(m, nn.Linear): 77 | trunc_normal_(m.weight, std=.02) 78 | if isinstance(m, nn.Linear) and m.bias is not None: 79 | nn.init.constant_(m.bias, 0) 80 | elif isinstance(m, nn.LayerNorm): 81 | nn.init.constant_(m.bias, 0) 82 | nn.init.constant_(m.weight, 1.0) 83 | elif isinstance(m, nn.Conv2d): 84 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 85 | fan_out //= m.groups 86 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 87 | if m.bias is not None: 88 | m.bias.data.zero_() 89 | 90 | def forward(self, x, H, W): 91 | B, N, C = x.shape 92 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 93 | 94 | if self.sr_ratio > 1: 95 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 96 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 97 | x_ = self.norm(x_) 98 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 99 | else: 100 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 101 | k, v = kv[0], kv[1] 102 | 103 | attn = (q @ k.transpose(-2, -1)) * self.scale 104 | attn = attn.softmax(dim=-1) 105 | attn = self.attn_drop(attn) 106 | 107 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 108 | x = self.proj(x) 109 | x = self.proj_drop(x) 110 | 111 | return x 112 | 113 | 114 | class Block(nn.Module): 115 | 116 | def __init__(self, dim, 
num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 117 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 118 | super().__init__() 119 | self.norm1 = norm_layer(dim) 120 | self.attn = Attention( 121 | dim, 122 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 123 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 124 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 125 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 126 | self.norm2 = norm_layer(dim) 127 | mlp_hidden_dim = int(dim * mlp_ratio) 128 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 129 | 130 | self.apply(self._init_weights) 131 | 132 | def _init_weights(self, m): 133 | if isinstance(m, nn.Linear): 134 | trunc_normal_(m.weight, std=.02) 135 | if isinstance(m, nn.Linear) and m.bias is not None: 136 | nn.init.constant_(m.bias, 0) 137 | elif isinstance(m, nn.LayerNorm): 138 | nn.init.constant_(m.bias, 0) 139 | nn.init.constant_(m.weight, 1.0) 140 | elif isinstance(m, nn.Conv2d): 141 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 142 | fan_out //= m.groups 143 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 144 | if m.bias is not None: 145 | m.bias.data.zero_() 146 | 147 | def forward(self, x, H, W): 148 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 149 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 150 | 151 | return x 152 | 153 | 154 | class OverlapPatchEmbed(nn.Module): 155 | """ Image to Patch Embedding 156 | """ 157 | 158 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 159 | super().__init__() 160 | img_size = to_2tuple(img_size) 161 | patch_size = to_2tuple(patch_size) 162 | 163 | self.img_size = img_size 164 | self.patch_size = patch_size 165 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 166 | self.num_patches = self.H * self.W 167 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 168 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 169 | self.norm = nn.LayerNorm(embed_dim) 170 | 171 | self.apply(self._init_weights) 172 | 173 | def _init_weights(self, m): 174 | if isinstance(m, nn.Linear): 175 | trunc_normal_(m.weight, std=.02) 176 | if isinstance(m, nn.Linear) and m.bias is not None: 177 | nn.init.constant_(m.bias, 0) 178 | elif isinstance(m, nn.LayerNorm): 179 | nn.init.constant_(m.bias, 0) 180 | nn.init.constant_(m.weight, 1.0) 181 | elif isinstance(m, nn.Conv2d): 182 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 183 | fan_out //= m.groups 184 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 185 | if m.bias is not None: 186 | m.bias.data.zero_() 187 | 188 | def forward(self, x): 189 | x = self.proj(x) 190 | _, _, H, W = x.shape 191 | x = x.flatten(2).transpose(1, 2) 192 | x = self.norm(x) 193 | 194 | return x, H, W 195 | 196 | 197 | class PyramidVisionTransformerImpr(nn.Module): 198 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 199 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 200 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 201 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): 202 | super().__init__() 203 | self.num_classes = num_classes 204 | self.depths = depths 205 | 206 | # patch_embed 207 | self.patch_embed1 = 
OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, 208 | embed_dim=embed_dims[0]) 209 | self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], 210 | embed_dim=embed_dims[1]) 211 | self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], 212 | embed_dim=embed_dims[2]) 213 | self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], 214 | embed_dim=embed_dims[3]) 215 | 216 | # transformer encoder 217 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 218 | cur = 0 219 | self.block1 = nn.ModuleList([Block( 220 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, 221 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 222 | sr_ratio=sr_ratios[0]) 223 | for i in range(depths[0])]) 224 | self.norm1 = norm_layer(embed_dims[0]) 225 | 226 | cur += depths[0] 227 | self.block2 = nn.ModuleList([Block( 228 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, 229 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 230 | sr_ratio=sr_ratios[1]) 231 | for i in range(depths[1])]) 232 | self.norm2 = norm_layer(embed_dims[1]) 233 | 234 | cur += depths[1] 235 | self.block3 = nn.ModuleList([Block( 236 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, 237 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 238 | sr_ratio=sr_ratios[2]) 239 | for i in range(depths[2])]) 240 | self.norm3 = norm_layer(embed_dims[2]) 241 | 242 | cur += depths[2] 243 | self.block4 = nn.ModuleList([Block( 244 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, 245 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 246 | sr_ratio=sr_ratios[3]) 247 | for i in range(depths[3])]) 248 | self.norm4 = norm_layer(embed_dims[3]) 249 | 250 | # classification head (removed: this backbone is used as a segmentation encoder) 251 | # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() 252 | 253 | self.apply(self._init_weights) 254 | 255 | def _init_weights(self, m): 256 | if isinstance(m, nn.Linear): 257 | trunc_normal_(m.weight, std=.02) 258 | if isinstance(m, nn.Linear) and m.bias is not None: 259 | nn.init.constant_(m.bias, 0) 260 | elif isinstance(m, nn.LayerNorm): 261 | nn.init.constant_(m.bias, 0) 262 | nn.init.constant_(m.weight, 1.0) 263 | elif isinstance(m, nn.Conv2d): 264 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 265 | fan_out //= m.groups 266 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 267 | if m.bias is not None: 268 | m.bias.data.zero_() 269 | 270 | def init_weights(self, pretrained=None): 271 | if isinstance(pretrained, str): 272 | pass # checkpoint loading is disabled here; pretrained weights are loaded externally by the wrapper model 273 | #load_checkpoint(self, pretrained, map_location='cpu', strict=False) 274 | 275 | def reset_drop_path(self, drop_path_rate): 276 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] 277 | cur = 0 278 | for i in range(self.depths[0]): 279 | self.block1[i].drop_path.drop_prob = dpr[cur + i] 280 | 281 | cur += self.depths[0] 282 | for i in range(self.depths[1]): 283 | self.block2[i].drop_path.drop_prob = dpr[cur + i] 284 | 285 | cur 
+= self.depths[1] 286 | for i in range(self.depths[2]): 287 | self.block3[i].drop_path.drop_prob = dpr[cur + i] 288 | 289 | cur += self.depths[2] 290 | for i in range(self.depths[3]): 291 | self.block4[i].drop_path.drop_prob = dpr[cur + i] 292 | 293 | def freeze_patch_emb(self): 294 | for p in self.patch_embed1.parameters(): p.requires_grad = False # freeze the parameters; setting requires_grad on the module itself has no effect 295 | 296 | @torch.jit.ignore 297 | def no_weight_decay(self): 298 | return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # retaining a pos_embed might perform better; this variant defines none of these parameters 299 | 300 | def get_classifier(self): 301 | return self.head # NOTE: vestigial classification API; self.head is commented out above 302 | 303 | def reset_classifier(self, num_classes, global_pool=''): 304 | self.num_classes = num_classes 305 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() # NOTE: self.embed_dim is never set; unused in this segmentation backbone 306 | 307 | # def _get_pos_embed(self, pos_embed, patch_embed, H, W): 308 | # if H * W == self.patch_embed1.num_patches: 309 | # return pos_embed 310 | # else: 311 | # return F.interpolate( 312 | # pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2), 313 | # size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1) 314 | 315 | def forward_features(self, x): 316 | B = x.shape[0] 317 | outs = [] 318 | 319 | # stage 1 320 | x, H, W = self.patch_embed1(x) 321 | for i, blk in enumerate(self.block1): 322 | x = blk(x, H, W) 323 | x = self.norm1(x) 324 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 325 | outs.append(x) 326 | 327 | # stage 2 328 | x, H, W = self.patch_embed2(x) 329 | for i, blk in enumerate(self.block2): 330 | x = blk(x, H, W) 331 | x = self.norm2(x) 332 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 333 | outs.append(x) 334 | 335 | # stage 3 336 | x, H, W = self.patch_embed3(x) 337 | for i, blk in enumerate(self.block3): 338 | x = blk(x, H, W) 339 | x = self.norm3(x) 340 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 341 | outs.append(x) 342 | 343 | # stage 4 344 | x, H, W = self.patch_embed4(x) 345 | for i, blk in enumerate(self.block4): 346 | x = blk(x, H, W) 347 | x = self.norm4(x) 348 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 349 | outs.append(x) 350 | 351 | return outs 352 | 353 | # return x.mean(dim=1) 354 | 355 | def forward(self, x): 356 | x = self.forward_features(x) 357 | # x = self.head(x) 358 | 359 | return x 360 | 361 | 362 | class DWConv(nn.Module): 363 | def __init__(self, dim=768): 364 | super(DWConv, self).__init__() 365 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 366 | 367 | def forward(self, x, H, W): 368 | B, N, C = x.shape 369 | x = x.transpose(1, 2).view(B, C, H, W) 370 | x = self.dwconv(x) 371 | x = x.flatten(2).transpose(1, 2) 372 | 373 | return x 374 | 375 | 376 | def _conv_filter(state_dict, patch_size=16): 377 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 378 | out_dict = {} 379 | for k, v in state_dict.items(): 380 | if 'patch_embed.proj.weight' in k: 381 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 382 | out_dict[k] = v 383 | 384 | return out_dict 385 | 386 | 387 | @register_model 388 | class pvt_v2_b0(PyramidVisionTransformerImpr): 389 | def __init__(self, **kwargs): # kwargs are accepted for timm compatibility but unused (likewise in the variants below) 390 | super(pvt_v2_b0, self).__init__( 391 | patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 392 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 393 | drop_rate=0.0, drop_path_rate=0.1) 394 | 395 | 396 | 397 | @register_model 398 | class 
pvt_v2_b1(PyramidVisionTransformerImpr): 399 | def __init__(self, **kwargs): 400 | super(pvt_v2_b1, self).__init__( 401 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 402 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 403 | drop_rate=0.0, drop_path_rate=0.1) 404 | 405 | @register_model 406 | class pvt_v2_b2(PyramidVisionTransformerImpr): 407 | def __init__(self, **kwargs): 408 | super(pvt_v2_b2, self).__init__( 409 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 410 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], 411 | drop_rate=0.0, drop_path_rate=0.1) 412 | 413 | @register_model 414 | class pvt_v2_b3(PyramidVisionTransformerImpr): 415 | def __init__(self, **kwargs): 416 | super(pvt_v2_b3, self).__init__( 417 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 418 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 419 | drop_rate=0.0, drop_path_rate=0.1) 420 | 421 | @register_model 422 | class pvt_v2_b4(PyramidVisionTransformerImpr): 423 | def __init__(self, **kwargs): 424 | super(pvt_v2_b4, self).__init__( 425 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 426 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], 427 | drop_rate=0.0, drop_path_rate=0.1) 428 | 429 | 430 | @register_model 431 | class pvt_v2_b5(PyramidVisionTransformerImpr): 432 | def __init__(self, **kwargs): 433 | super(pvt_v2_b5, self).__init__( 434 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 435 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], 436 | drop_rate=0.0, drop_path_rate=0.1) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.backends.cudnn as cudnn 5 | import torchvision.transforms.functional as TF 6 | import numpy as np 7 | import os 8 | import math 9 | import random 10 | import logging 11 | import logging.handlers 12 | from matplotlib import pyplot as plt 13 | 14 | from scipy.ndimage import zoom 15 | import SimpleITK as sitk 16 | from medpy import metric 17 | 18 | 19 | def set_seed(seed): 20 | # for hash 21 | os.environ['PYTHONHASHSEED'] = str(seed) 22 | # for python and numpy 23 | random.seed(seed) 24 | np.random.seed(seed) 25 | # for cpu gpu 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.cuda.manual_seed_all(seed) 29 | # for cudnn 30 | cudnn.benchmark = False 31 | cudnn.deterministic = True 32 | 33 | 34 | def get_logger(name, log_dir): 35 | ''' 36 | Args: 37 | name(str): name of logger 38 | log_dir(str): path of log 39 | ''' 40 | 41 | if not os.path.exists(log_dir): 42 | os.makedirs(log_dir) 43 | 44 | logger = logging.getLogger(name) 45 | logger.setLevel(logging.INFO) 46 | 47 | info_name = os.path.join(log_dir, '{}.info.log'.format(name)) 48 | info_handler = logging.handlers.TimedRotatingFileHandler(info_name, 49 | when='D', 50 | encoding='utf-8') 51 | info_handler.setLevel(logging.INFO) 52 | 53 | formatter = logging.Formatter('%(asctime)s - 
%(message)s', 54 | datefmt='%Y-%m-%d %H:%M:%S') 55 | 56 | info_handler.setFormatter(formatter) 57 | 58 | logger.addHandler(info_handler) 59 | 60 | return logger 61 | 62 | 63 | def log_config_info(config, logger): 64 | config_dict = config.__dict__ 65 | log_info = f'#----------Config info----------#' 66 | logger.info(log_info) 67 | for k, v in config_dict.items(): 68 | if k[0] == '_': 69 | continue 70 | else: 71 | log_info = f'{k}: {v},' 72 | logger.info(log_info) 73 | 74 | 75 | 76 | def get_optimizer(config, model): 77 | assert config.opt in ['Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'ASGD', 'RMSprop', 'Rprop', 'SGD'], 'Unsupported optimizer!' 78 | 79 | if config.opt == 'Adadelta': 80 | return torch.optim.Adadelta( 81 | model.parameters(), 82 | lr = config.lr, 83 | rho = config.rho, 84 | eps = config.eps, 85 | weight_decay = config.weight_decay 86 | ) 87 | elif config.opt == 'Adagrad': 88 | return torch.optim.Adagrad( 89 | model.parameters(), 90 | lr = config.lr, 91 | lr_decay = config.lr_decay, 92 | eps = config.eps, 93 | weight_decay = config.weight_decay 94 | ) 95 | elif config.opt == 'Adam': 96 | return torch.optim.Adam( 97 | model.parameters(), 98 | lr = config.lr, 99 | betas = config.betas, 100 | eps = config.eps, 101 | weight_decay = config.weight_decay, 102 | amsgrad = config.amsgrad 103 | ) 104 | elif config.opt == 'AdamW': 105 | return torch.optim.AdamW( 106 | model.parameters(), 107 | lr = config.lr, 108 | betas = config.betas, 109 | eps = config.eps, 110 | weight_decay = config.weight_decay, 111 | amsgrad = config.amsgrad 112 | ) 113 | elif config.opt == 'Adamax': 114 | return torch.optim.Adamax( 115 | model.parameters(), 116 | lr = config.lr, 117 | betas = config.betas, 118 | eps = config.eps, 119 | weight_decay = config.weight_decay 120 | ) 121 | elif config.opt == 'ASGD': 122 | return torch.optim.ASGD( 123 | model.parameters(), 124 | lr = config.lr, 125 | lambd = config.lambd, 126 | alpha = config.alpha, 127 | t0 = config.t0, 128 | weight_decay = config.weight_decay 129 | ) 130 | elif config.opt == 'RMSprop': 131 | return torch.optim.RMSprop( 132 | model.parameters(), 133 | lr = config.lr, 134 | momentum = config.momentum, 135 | alpha = config.alpha, 136 | eps = config.eps, 137 | centered = config.centered, 138 | weight_decay = config.weight_decay 139 | ) 140 | elif config.opt == 'Rprop': 141 | return torch.optim.Rprop( 142 | model.parameters(), 143 | lr = config.lr, 144 | etas = config.etas, 145 | step_sizes = config.step_sizes, 146 | ) 147 | elif config.opt == 'SGD': 148 | return torch.optim.SGD( 149 | model.parameters(), 150 | lr = config.lr, 151 | momentum = config.momentum, 152 | weight_decay = config.weight_decay, 153 | dampening = config.dampening, 154 | nesterov = config.nesterov 155 | ) 156 | else: # default opt is SGD 157 | return torch.optim.SGD( 158 | model.parameters(), 159 | lr = 0.01, 160 | momentum = 0.9, 161 | weight_decay = 0.05, 162 | ) 163 | 164 | 165 | def get_scheduler(config, optimizer): 166 | assert config.sch in ['StepLR', 'MultiStepLR', 'ExponentialLR', 'CosineAnnealingLR', 'ReduceLROnPlateau', 167 | 'CosineAnnealingWarmRestarts', 'WP_MultiStepLR', 'WP_CosineLR'], 'Unsupported scheduler!' 
168 | if config.sch == 'StepLR': 169 | scheduler = torch.optim.lr_scheduler.StepLR( 170 | optimizer, 171 | step_size = config.step_size, 172 | gamma = config.gamma, 173 | last_epoch = config.last_epoch 174 | ) 175 | elif config.sch == 'MultiStepLR': 176 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 177 | optimizer, 178 | milestones = config.milestones, 179 | gamma = config.gamma, 180 | last_epoch = config.last_epoch 181 | ) 182 | elif config.sch == 'ExponentialLR': 183 | scheduler = torch.optim.lr_scheduler.ExponentialLR( 184 | optimizer, 185 | gamma = config.gamma, 186 | last_epoch = config.last_epoch 187 | ) 188 | elif config.sch == 'CosineAnnealingLR': 189 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 190 | optimizer, 191 | T_max = config.T_max, 192 | eta_min = config.eta_min, 193 | last_epoch = config.last_epoch 194 | ) 195 | elif config.sch == 'ReduceLROnPlateau': 196 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 197 | optimizer, 198 | mode = config.mode, 199 | factor = config.factor, 200 | patience = config.patience, 201 | threshold = config.threshold, 202 | threshold_mode = config.threshold_mode, 203 | cooldown = config.cooldown, 204 | min_lr = config.min_lr, 205 | eps = config.eps 206 | ) 207 | elif config.sch == 'CosineAnnealingWarmRestarts': 208 | scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( 209 | optimizer, 210 | T_0 = config.T_0, 211 | T_mult = config.T_mult, 212 | eta_min = config.eta_min, 213 | last_epoch = config.last_epoch 214 | ) 215 | elif config.sch == 'WP_MultiStepLR': # the 'WP_' variants prepend a linear warm-up over the first warm_up_epochs epochs 216 | lr_func = lambda epoch: epoch / config.warm_up_epochs if epoch <= config.warm_up_epochs else config.gamma**len( 217 | [m for m in config.milestones if m <= epoch]) 218 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func) 219 | elif config.sch == 'WP_CosineLR': 220 | lr_func = lambda epoch: epoch / config.warm_up_epochs if epoch <= config.warm_up_epochs else 0.5 * ( 221 | math.cos((epoch - config.warm_up_epochs) / (config.epochs - config.warm_up_epochs) * math.pi) + 1) 222 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func) 223 | 224 | return scheduler 225 | 226 | 227 | 
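# --- editorial usage sketch (not part of the original utils.py) ---
# get_optimizer/get_scheduler only read the attributes named in the chosen
# branch, so any attribute container works as `config`. The field values below
# are illustrative assumptions, not defaults from this repo's config files.
def _optim_usage_example(model):
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        opt='AdamW', lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
        weight_decay=1e-2, amsgrad=False,
        sch='CosineAnnealingLR', T_max=50, eta_min=1e-5, last_epoch=-1,
    )
    optimizer = get_optimizer(cfg, model)   # dispatches on cfg.opt
    scheduler = get_scheduler(cfg, optimizer)  # dispatches on cfg.sch
    return optimizer, scheduler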
228 | def save_imgs(img, msk, msk_pred, i, save_path, datasets, threshold=0.5, test_data_name=None): 229 | img = img.squeeze(0).permute(1,2,0).detach().cpu().numpy() 230 | img = img / 255. if img.max() > 1.1 else img 231 | if datasets == 'retinal': 232 | msk = np.squeeze(msk, axis=0) 233 | msk_pred = np.squeeze(msk_pred, axis=0) 234 | else: 235 | msk = np.where(np.squeeze(msk, axis=0) > 0.5, 1, 0) 236 | msk_pred = np.where(np.squeeze(msk_pred, axis=0) > threshold, 1, 0) 237 | 238 | plt.figure(figsize=(7,15)) 239 | 240 | plt.subplot(3,1,1) 241 | plt.imshow(img) 242 | plt.axis('off') 243 | 244 | plt.subplot(3,1,2) 245 | plt.imshow(msk, cmap='gray') 246 | plt.axis('off') 247 | 248 | plt.subplot(3,1,3) 249 | plt.imshow(msk_pred, cmap='gray') 250 | plt.axis('off') 251 | 252 | if test_data_name is not None: 253 | save_path = save_path + test_data_name + '_' 254 | plt.savefig(save_path + str(i) + '.png') 255 | plt.close() 256 | 257 | 258 | 259 | class BCELoss(nn.Module): 260 | def __init__(self): 261 | super(BCELoss, self).__init__() 262 | self.bceloss = nn.BCELoss() 263 | 264 | def forward(self, pred, target): 265 | size = pred.size(0) 266 | pred_ = pred.view(size, -1) 267 | target_ = target.view(size, -1) 268 | 269 | return self.bceloss(pred_, target_) 270 | 271 | # On what basis should one choose between Dice and BCE loss, and does the weighting need tuning when segmenting small lesions? See: 272 | # https://blog.csdn.net/longshaonihaoa/article/details/111824916 273 | class DiceLoss(nn.Module): 274 | def __init__(self): 275 | super(DiceLoss, self).__init__() 276 | 277 | def forward(self, pred, target): 278 | smooth = 1 279 | size = pred.size(0) 280 | 281 | pred_ = pred.view(size, -1) 282 | target_ = target.view(size, -1) 283 | intersection = pred_ * target_ 284 | dice_score = (2 * intersection.sum(1) + smooth)/(pred_.sum(1) + target_.sum(1) + smooth) 285 | dice_loss = 1 - dice_score.sum()/size 286 | 287 | return dice_loss 288 | 289 | 290 | class nDiceLoss(nn.Module): 291 | def __init__(self, n_classes): 292 | super(nDiceLoss, self).__init__() 293 | self.n_classes = n_classes 294 | 295 | def _one_hot_encoder(self, input_tensor): 296 | tensor_list = [] 297 | for i in range(self.n_classes): 298 | temp_prob = input_tensor == i # * torch.ones_like(input_tensor) 299 | tensor_list.append(temp_prob.unsqueeze(1)) 300 | output_tensor = torch.cat(tensor_list, dim=1) 301 | return output_tensor.float() 302 | 303 | def _dice_loss(self, score, target): 304 | target = target.float() 305 | smooth = 1e-5 306 | intersect = torch.sum(score * target) 307 | y_sum = torch.sum(target * target) 308 | z_sum = torch.sum(score * score) 309 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 310 | loss = 1 - loss 311 | return loss 312 | 313 | def forward(self, inputs, target, weight=None, softmax=False): 314 | if softmax: 315 | inputs = torch.softmax(inputs, dim=1) 316 | target = self._one_hot_encoder(target) 317 | if weight is None: 318 | weight = [1] * self.n_classes 319 | assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size()) 320 | class_wise_dice = [] 321 | loss = 0.0 322 | for i in range(0, self.n_classes): 323 | dice = self._dice_loss(inputs[:, i], target[:, i]) 324 | class_wise_dice.append(1.0 - dice.item()) 325 | loss += dice * weight[i] 326 | return loss / self.n_classes 327 | 328 | 329 | class CeDiceLoss(nn.Module): 330 | def __init__(self, num_classes, loss_weight=[0.4, 0.6]): 331 | super(CeDiceLoss, self).__init__() 332 | self.celoss = nn.CrossEntropyLoss() 333 | self.diceloss = nDiceLoss(num_classes) 334 | self.loss_weight = loss_weight 335 | 336 | def forward(self, pred, target): 337 | loss_ce = self.celoss(pred, target[:].long()) 338 | loss_dice = self.diceloss(pred, target, softmax=True) 339 | loss = self.loss_weight[0] * loss_ce + self.loss_weight[1] * loss_dice 340 | return loss 341 | 342 | 343 | class BceDiceLoss(nn.Module): 344 | def __init__(self, wb=1, wd=1): 345 | super(BceDiceLoss, self).__init__() 346 | self.bce = BCELoss() 347 | self.dice = DiceLoss() 348 | self.wb = wb 349 | self.wd = wd 350 | 351 | def forward(self, pred, target): 352 | bceloss = self.bce(pred, target) 353 | diceloss = self.dice(pred, target) 354 | 355 | loss = self.wd * diceloss + self.wb * bceloss 356 | return loss 357 | 358 | 359 | class GT_BceDiceLoss(nn.Module): 360 | def __init__(self, wb=1, wd=1): 361 | super(GT_BceDiceLoss, self).__init__() 362 | self.bcedice = BceDiceLoss(wb, wd) 363 | 364 | def forward(self, gt_pre, out, target): 365 | bcediceloss = self.bcedice(out, target) 366 | gt_pre5, gt_pre4, gt_pre3, gt_pre2, gt_pre1 = gt_pre 367 | gt_loss = self.bcedice(gt_pre5, target) * 0.1 + self.bcedice(gt_pre4, target) * 0.2 + self.bcedice(gt_pre3, target) * 0.3 + self.bcedice(gt_pre2, target) * 0.4 + self.bcedice(gt_pre1, target) * 0.5 # deep supervision: deeper, coarser side outputs receive smaller weights 368 | return bcediceloss + gt_loss 369 | 370 | 371 | 
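# --- editorial usage sketch (not part of the original utils.py) ---
# GT_BceDiceLoss expects `gt_pre` to be a 5-tuple of side outputs, ordered from
# the deepest (coarsest, weight 0.1) to the shallowest (weight 0.5) prediction.
# Every tensor must already hold probabilities in [0, 1], because BCELoss (not
# BCEWithLogitsLoss) is used internally. The shapes below are illustrative.
def _gt_loss_usage_example():
    criterion = GT_BceDiceLoss(wb=1, wd=1)
    out = torch.rand(2, 1, 256, 256)                              # final prediction
    gt_pre = tuple(torch.rand(2, 1, 256, 256) for _ in range(5))  # deep-supervision side outputs
    target = (torch.rand(2, 1, 256, 256) > 0.5).float()           # binary ground-truth mask
    return criterion(gt_pre, out, target)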
372 | class myToTensor: 373 | def __init__(self): 374 | pass 375 | def __call__(self, data): 376 | image, mask = data 377 | return torch.tensor(image).permute(2,0,1), torch.tensor(mask).permute(2,0,1) 378 | 379 | 380 | class myResize: 381 | def __init__(self, size_h=256, size_w=256): 382 | self.size_h = size_h 383 | self.size_w = size_w 384 | def __call__(self, data): 385 | image, mask = data 386 | return TF.resize(image, [self.size_h, self.size_w]), TF.resize(mask, [self.size_h, self.size_w]) 387 | 388 | 389 | class myRandomHorizontalFlip: 390 | def __init__(self, p=0.5): 391 | self.p = p 392 | def __call__(self, data): 393 | image, mask = data 394 | if random.random() < self.p: return TF.hflip(image), TF.hflip(mask) 395 | else: return image, mask 396 | 397 | 398 | class myRandomVerticalFlip: 399 | def __init__(self, p=0.5): 400 | self.p = p 401 | def __call__(self, data): 402 | image, mask = data 403 | if random.random() < self.p: return TF.vflip(image), TF.vflip(mask) 404 | else: return image, mask 405 | 406 | 407 | class myRandomRotation: 408 | def __init__(self, p=0.5, degree=[0,360]): 409 | self.degree = degree # store the range and sample a fresh angle per call; the original sampled once in __init__ and reused the same angle for every image 410 | self.p = p 411 | def __call__(self, data): 412 | image, mask = data 413 | if random.random() < self.p: angle = random.uniform(self.degree[0], self.degree[1]); return TF.rotate(image,angle), TF.rotate(mask,angle) 414 | else: return image, mask 415 | 416 | 417 | class myNormalize: 418 | def __init__(self, data_name, train=True): 419 | if data_name == 'isic18': 420 | if train: 421 | self.mean = 157.561 422 | self.std = 26.706 423 | else: 424 | self.mean = 149.034 425 | self.std = 32.022 426 | elif data_name == 'isic17': 427 | if train: 428 | self.mean = 159.922 429 | self.std = 28.871 430 | else: 431 | self.mean = 148.429 432 | self.std = 25.748 433 | elif data_name == 'isic18_82': 434 | if train: 435 | self.mean = 156.2899 436 | self.std = 26.5457 437 | else: 438 | self.mean = 149.8485 439 | self.std = 35.3346 440 | elif data_name == 'polyp': 441 | if train: 442 | self.mean = 86.17 443 | self.std = 69.08 444 | else: 445 | self.mean = 86.17 446 | self.std = 69.08 447 | elif data_name == 'gim': 448 | if train: 449 | self.mean = 87.84 450 | self.std = 55.37 451 | else: 452 | self.mean = 85.27 453 | self.std = 54.75 454 | elif data_name == 'isic_all': 455 | if train: 456 | self.mean = 158.6 457 | self.std = 44.92 458 | else: 459 | self.mean = 156.2899 460 | self.std = 26.5457 461 | 462 | # TODO: note the difference between the polyp dataset and the gastric intestinal metaplasia (GIM) dataset 463 | # the image is standardized with the dataset statistics, then min-max rescaled back to the [0, 255] range 464 | def __call__(self, data): 465 | img, msk = data 466 | img_normalized = (img-self.mean)/self.std 467 | img_normalized = ((img_normalized - np.min(img_normalized)) 468 | / (np.max(img_normalized)-np.min(img_normalized))) * 255. 469 | return img_normalized, msk
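# --- editorial usage sketch (not part of the original utils.py) ---
# The transforms above all take and return an (image, mask) pair, so they
# compose with a plain callable chain such as torchvision's Compose. The
# dataset name, probabilities, and sizes below are illustrative assumptions.
def _transform_pipeline_example():
    from torchvision import transforms
    train_transformer = transforms.Compose([
        myNormalize('isic18', train=True),   # numpy HWC in, numpy HWC out
        myToTensor(),                        # -> CHW float tensors
        myRandomHorizontalFlip(p=0.5),
        myRandomVerticalFlip(p=0.5),
        myRandomRotation(p=0.5),
        myResize(256, 256),
    ])
    img = np.random.randint(0, 256, (512, 512, 3)).astype(np.float32)
    msk = (np.random.rand(512, 512, 1) > 0.5).astype(np.float32)
    return train_transformer((img, msk))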
470 | 471 | 472 | 473 | from thop import profile # thop provides FLOPs/params profiling 474 | def cal_params_flops(model, size, logger): 475 | input = torch.randn(1, 3, size, size).cuda() 476 | flops, params = profile(model, inputs=(input,)) 477 | print('flops', flops/1e9) # FLOPs in GFLOPs 478 | print('params', params/1e6) # parameter count in millions 479 | 480 | total = sum(p.numel() for p in model.parameters()) 481 | print("Total params: %.2fM" % (total/1e6)) 482 | logger.info(f'flops: {flops/1e9}, params: {params/1e6}, Total params: {total/1e6:.4f}') 483 | 484 | 485 | 486 | 487 | 488 | 489 | def calculate_metric_percase(pred, gt): 490 | pred[pred > 0] = 1 491 | gt[gt > 0] = 1 492 | if pred.sum() > 0 and gt.sum() > 0: 493 | dice = metric.binary.dc(pred, gt) 494 | hd95 = metric.binary.hd95(pred, gt) 495 | return dice, hd95 496 | elif pred.sum() > 0 and gt.sum() == 0: 497 | return 1, 0 498 | else: 499 | return 0, 0 500 | 501 | 502 | 503 | def test_single_volume(image, label, net, classes, patch_size=[256, 256], 504 | test_save_path=None, case=None, z_spacing=1, val_or_test=False): 505 | image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy() 506 | if len(image.shape) == 3: 507 | prediction = np.zeros_like(label) 508 | for ind in range(image.shape[0]): 509 | slice = image[ind, :, :] 510 | x, y = slice.shape[0], slice.shape[1] 511 | if x != patch_size[0] or y != patch_size[1]: 512 | slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=3) # previously order=0 was used 513 | input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda() 514 | net.eval() 515 | with torch.no_grad(): 516 | outputs = net(input) 517 | out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0) 518 | out = out.cpu().detach().numpy() 519 | if x != patch_size[0] or y != patch_size[1]: 520 | pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0) 521 | else: 522 | pred = out 523 | prediction[ind] = pred 524 | else: 525 | input = torch.from_numpy(image).unsqueeze( 526 | 0).unsqueeze(0).float().cuda() 527 | net.eval() 528 | with torch.no_grad(): 529 | out = torch.argmax(torch.softmax(net(input), dim=1), dim=1).squeeze(0) 530 | prediction = out.cpu().detach().numpy() 531 | metric_list = [] 532 | for i in range(1, classes): 533 | metric_list.append(calculate_metric_percase(prediction == i, label == i)) 534 | 535 | if test_save_path is not None and val_or_test is True: 536 | img_itk = sitk.GetImageFromArray(image.astype(np.float32)) 537 | prd_itk = sitk.GetImageFromArray(prediction.astype(np.float32)) 538 | lab_itk = sitk.GetImageFromArray(label.astype(np.float32)) 539 | img_itk.SetSpacing((1, 1, z_spacing)) 540 | prd_itk.SetSpacing((1, 1, z_spacing)) 541 | lab_itk.SetSpacing((1, 1, z_spacing)) 542 | sitk.WriteImage(prd_itk, test_save_path + '/' + case + "_pred.nii.gz") 543 | sitk.WriteImage(img_itk, test_save_path + '/' + case + "_img.nii.gz") 544 | sitk.WriteImage(lab_itk, test_save_path + '/' + case + "_gt.nii.gz") 545 | # cv2.imwrite(test_save_path + '/' + case + '.png', prediction*255) 546 | return metric_list --------------------------------------------------------------------------------
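A minimal smoke test for the PVT-v2 backbone in pvtv2.py (an editorial sketch, not a file from this repository): the import path follows the models/unet_v2 layout shown in the tree, the module's usual timm imports are assumed to be present at the top of pvtv2.py, and the printed shapes correspond to pvt_v2_b0 with a 224x224 input.

import torch
from models.unet_v2.pvtv2 import pvt_v2_b0

model = pvt_v2_b0().eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))  # forward() returns the 4-level feature pyramid
for f in feats:
    print(tuple(f.shape))
# expected: (1, 32, 56, 56), (1, 64, 28, 28), (1, 160, 14, 14), (1, 256, 7, 7)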