├── .gitignore ├── README.md ├── configs ├── isic │ ├── isic2018_attunet.yaml │ ├── isic2018_missformer.yaml │ ├── isic2018_multiresunet.yaml │ ├── isic2018_resunet.yaml │ ├── isic2018_transunet.yaml │ ├── isic2018_uctransnet.yaml │ ├── isic2018_unet.yaml │ └── isic2018_unetpp.yaml └── segpc │ ├── segpc2021_attunet.yaml │ ├── segpc2021_missformer.yaml │ ├── segpc2021_multiresunet.yaml │ ├── segpc2021_resunet.yaml │ ├── segpc2021_transunet.yaml │ ├── segpc2021_uctransnet.yaml │ ├── segpc2021_unet.yaml │ └── segpc2021_unetpp.yaml ├── datasets ├── README.md ├── __init__.py ├── isic.ipynb ├── isic.py ├── prepare_isic.ipynb ├── prepare_segpc.ipynb ├── segpc.ipynb └── segpc.py ├── images ├── ComparisonOfModels.png ├── U-Net_Taxonomy.png ├── isic2018.png ├── isic2018_sample.png ├── segpc.png ├── segpc2021_sample.png ├── synapse.png └── unet-pipeline.png ├── losses.py ├── models ├── __init__.py ├── _missformer │ ├── MISSFormer.py │ ├── __init__.py │ └── segformer.py ├── _resunet │ ├── __init__.py │ ├── modules.py │ └── res_unet.py ├── _transunet │ ├── vit_seg_configs.py │ ├── vit_seg_modeling.py │ ├── vit_seg_modeling_c4.py │ ├── vit_seg_modeling_resnet_skip.py │ └── vit_seg_modeling_resnet_skip_c4.py ├── _uctransnet │ ├── CTrans.py │ ├── Config.py │ ├── UCTransNet.py │ └── UNet.py ├── attunet.py ├── multiresunet.py ├── unet.py └── unetpp.py ├── train_and_test ├── isic │ ├── attunet-isic.ipynb │ ├── attunet-isic.py │ ├── missformer-isic.ipynb │ ├── missformer-isic.py │ ├── multiresunet-isic.ipynb │ ├── multiresunet-isic.py │ ├── resunet-isic.ipynb │ ├── resunet-isic.py │ ├── transunet-isic.ipynb │ ├── transunet-isic.py │ ├── uctransnet-isic.ipynb │ ├── uctransnet-isic.py │ ├── unet-isic.ipynb │ ├── unet-isic.py │ ├── unetpp-isic.ipynb │ └── unetpp-isic.py └── segpc │ ├── attunet-segpc.ipynb │ ├── attunet-segpc.py │ ├── missformer-segpc.ipynb │ ├── missformer-segpc.py │ ├── multiresunet-segpc.ipynb │ ├── multiresunet-segpc.py │ ├── resunet-segpc.ipynb │ ├── resunet-segpc.py │ ├── transunet-segpc.ipynb │ ├── transunet-segpc.py │ ├── uctransnet-segpc.ipynb │ ├── uctransnet-segpc.py │ ├── unet-segpc.ipynb │ ├── unet-segpc.py │ ├── unetpp-segpc.ipynb │ └── unetpp-segpc.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | # db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | #vscode 132 | .vscode 133 | 134 | 135 | **.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Awesome U-Net 2 | 3 | [![Awesome](https://cdn.rawgit.com/sindresorhus/awesome/d7305f38d29fed78fa85652e3a63e154dd8e8829/media/badge.svg)](https://github.com/hee9joon/Awesome-Diffusion-Models) 4 | [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT) 5 | 6 | Official repo for [Medical Image Segmentation Review: The Success of U-Net](https://ieeexplore.ieee.org/abstract/document/10643318) 7 | 8 |

9 | 10 |

11 | 12 | ### Announcements 13 | 14 | August 21, 2024: The final draft is published at the [IEEE TPAMI](https://ieeexplore.ieee.org/abstract/document/10643318) :fire::fire: 15 | 16 | November 27, 2022: [arXiv](https://arxiv.org/abs/2211.14830) release version. 17 | 18 | #### Citation 19 | 20 | ```latex 21 | @article{azad2024medical, 22 | author={Azad, Reza and Aghdam, Ehsan Khodapanah and Rauland, Amelie and Jia, Yiwei and Avval, Atlas Haddadi and Bozorgpour, Afshin and Karimijafarbigloo, Sanaz and Cohen, Joseph Paul and Adeli, Ehsan and Merhof, Dorit}, 23 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 24 | title={Medical Image Segmentation Review: The Success of U-Net}, 25 | year={2024}, 26 | pages={1-20}, 27 | keywords={Image segmentation;Biomedical imaging;Taxonomy;Computer architecture;Feature extraction;Transformers;Task analysis;Medical Image Segmentation;Deep Learning;U-Net;Convolutional Neural Network;Transformer}, 28 | doi={10.1109/TPAMI.2024.3435571} 29 | } 30 | ``` 31 | 32 | --- 33 | 34 | ### Abstract 35 | 36 | Automatic medical image segmentation is a crucial topic in the medical domain and successively a critical counterpart in the computer-aided diagnosis paradigm. U-Net is the most widespread image segmentation architecture due to its flexibility, optimized modular design, and success in all medical image modalities. Over the years, the U-Net model achieved tremendous attention from academic and industrial researchers. Several extensions of this network have been proposed to address the scale and complexity created by medical tasks. Addressing the deficiency of the naive U-Net model is the foremost step for vendors to utilize the proper U-Net variant model for their business. Having a compendium of different variants in one place makes it easier for builders to identify the relevant research. Also, for ML researchers it will help them understand the challenges of the biological tasks that challenge the model. To address this, we discuss the practical aspects of the U-Net model and suggest a taxonomy to categorize each network variant. Moreover, to measure the performance of these strategies in a clinical application, we propose fair evaluations of some unique and famous designs on well-known datasets. We provide a comprehensive implementation library with trained models for future research. In addition, for ease of future studies, we created an online list of U-Net papers with their possible official implementation. 37 | 38 |

39 | 40 |

41 | 42 | --- 43 | 44 | 45 | 46 | ### The structure of codes 47 | 48 | Here is the overall structure of the code: 49 | 50 | ```bash 51 | . 52 | ├── README.md 53 | ├── images 54 | │ └── *.png 55 | ├── configs 56 | │ ├── isic 57 | │ │ ├── isic2018_*.yaml 58 | │ └── segpc 59 | │ └── segpc2021_*.yaml 60 | ├── datasets 61 | │ ├── *.py 62 | │ ├── *.ipynb 63 | │ └── prepare_*.ipynb 64 | ├── models 65 | │ ├── *.py 66 | │ └── _* 67 | │ └── *.py 68 | ├── train_and_test 69 | │ ├── isic 70 | │ │ ├── *-isic.ipynb 71 | │ │ └── *-isic.py 72 | │ └── segpc 73 | │ ├── *-segpc.ipynb 74 | │ └── *-segpc.py 75 | ├── losses.py 76 | └── utils.py 77 | ``` 78 | 79 | 80 | 81 | ## Dataset preparation 82 | 83 | Please go to ["./datasets/README.md"](https://github.com/NITR098/Awesome-U-Net/blob/main/datasets/README.md) for details. We used three datasets in this work. After preparing the required data, put the corresponding data paths in the relevant config files. 84 | 85 | 86 | 87 | ## Train and Test 88 | 89 | The `train_and_test` folder contains one subfolder per dataset. Each subfolder provides the files for each network in two formats (`.py` and `.ipynb`). Each notebook contains both the training and testing steps and follows the procedure below. 90 | 91 | - Preparation step 92 | - Import packages & functions 93 | - Set the seed 94 | - Load the config file 95 | - Dataset and Dataloader 96 | - Prepare Metrics 97 | - Define test and validate functions 98 | - Load and prepare model 99 | - Training 100 | - Save the best model 101 | - Test the best model 102 | - Load the best model 103 | - Evaluation 104 | - Plot graphs and print results 105 | - Save images 106 | 107 | 108 | 109 | ### Pretrained model weights 110 | 111 | Here you can download pre-trained weights for the networks.
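The table below lists the checkpoints together with the corresponding train and test notebooks. As a rough, hedged sketch of how such a checkpoint could be restored and used for inference (the file name and the assumption that it stores a plain `state_dict` are hypothetical; the exact loading code lives in the linked notebooks):

```python
import yaml
import torch

from models.unet import UNet  # model classes are provided under ./models

# Read the matching config to obtain the model hyper-parameters.
with open("configs/isic/isic2018_unet.yaml") as f:
    cfg = yaml.safe_load(f)

model = UNet(**cfg["model"]["params"])  # in_channels=3, out_channels=2, with_bn=False

# "isic2018_unet_best.pth" is a placeholder name for the downloaded weight file;
# this assumes it contains a plain state_dict.
state_dict = torch.load("isic2018_unet_best.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# Inference on one normalized 224x224 RGB image (random stand-in data here).
with torch.no_grad():
    x = torch.rand(1, 3, 224, 224)
    logits = model(x)             # (1, 2, 224, 224) class scores
    pred = logits.argmax(dim=1)   # (1, 224, 224) per-pixel class indices
```

For the other networks, swap in the corresponding class and config (for example `AttU_Net` with `configs/isic/isic2018_attunet.yaml`).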
112 | 113 | | Network | Model Weight | Train and Test File | 114 | | ------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | 115 | | **U-Net** | [ISIC2018](https://mega.nz/file/pNd0xLIB#LqY-e-hdQhq6_dQZpAw_7MxKclMB5DAFMybL5w99OzM) - [SegPC2021](https://mega.nz/file/EZEjTYyT#UMsliboXuqrsobGHV_mn4jiBrOf_dMZF7hp2aY0o2hI) | [ISIC2018](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/isic/unet-isic.ipynb) - [SegPC2021](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/segpc/unet-segpc.ipynb) | 116 | | **Att-UNet** | [ISIC2018](https://mega.nz/file/5VsBTKgK#vNu_nvuz-9Lktw6aMOxuguQyim1sVnG4QdkGtVX3pEs) - [SegPC2021](https://mega.nz/file/gRVCXCgT#We3_nPsx_xIBXy6-bsg85rQYYzKHut17Zn5HDnh0Aqw) | [ISIC2018](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/isic/attunet-isic.ipynb) - [SegPC2021](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/segpc/attunet-segpc.ipynb) | 117 | | **U-Net++** | [ISIC2018](https://mega.nz/file/NcFQUY5D#1mSGOC4GGTA8arWzcM77yyH9GoApciw0mB4pFp18n0Q) - [SegPC2021](https://mega.nz/file/JFVSHLxY#EwPpPZ5N0KDaXhDXxyyuQ_HaD2iNiv5hdqplznrP8Os) | [ISIC2018](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/isic/unetpp-isic.ipynb) - [SegPC2021](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/segpc/unetpp-segpc.ipynb) | 118 | | **MultiResUNet** | [ISIC2018](https://mega.nz/file/tIEVAAba#t-5vLCMwlH6hzAri7DJ8ut-eT2vFN5b6qj6Vc3By6_g) - [SegPC2021](https://mega.nz/file/tUN11R5C#I_JpAT7mYDM1q40ulp8TJxnzHFR4Fh3WX_klep62ywE) | [ISIC2018](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/isic/multiresunet-isic.ipynb) - [SegPC2021](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/segpc/multiresunet-segpc.ipynb) | 119 | | **Residual U-Net** | [ISIC2018](https://mega.nz/file/NAVHSSJa#FwcYG6bKOdpcEorN_nnjWFEx29toSspSiMzFTqIrVW4) - [SegPC2021](https://mega.nz/file/gQ91WBRB#mzIAeEUze4cAi74dMa3rqivGdYtzpKqDI16vNao7-6A) | [ISIC2018](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/isic/resunet-isic.ipynb) - [SegPC2021](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/segpc/resunet-segpc.ipynb) | 120 | | **TransUNet** | [ISIC2018](https://mega.nz/file/UM9jkK6B#7rFd9TiOY6pEGt-gDosFopdV78slgpHbj_wKZ4H39OM) - [SegPC2021](https://mega.nz/file/5YFBXBoZ#6S8B6MyAsSsr5cNw0-QIIIzF6CgxhEUsOl0xwAknTr8) | [ISIC2018](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/isic/transunet-isic.ipynb) - [SegPC2021](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/segpc/transunet-segpc.ipynb) | 121 | | **UCTransNet** | [ISIC2018](https://mega.nz/file/RMNQmKoQ#j8zGEuud33eh-tOIZa1dpkReB8DYKt1De75eeR7wLnM) - [SegPC2021](https://mega.nz/file/hYMShICa#kg5VFhE-m5X0ouE1rc_teaYSb_E15NpbBVE0P_V7WH8) | [ISIC2018](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/isic/uctransnet-isic.ipynb) - [SegPC2021](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/segpc/uctransnet-segpc.ipynb) | 122 | | **MISSFormer** | [ISIC2018](https://mega.nz/file/EANRiBoQ#E2LC0ZS7LU5OuEdQJ8dYGihjzqpEEotUqLEnEGZ59wU) - [SegPC2021](https://mega.nz/file/9I1CUJbZ#V6zdx8vZDyPJjHmVgoJH4D86sTuqNu6OuHeUQVB6ees) | [ISIC2018](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/isic/missformer-isic.ipynb) - [SegPC2021](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/segpc/missformer-segpc.ipynb) | 123 | 124 | 125 | ## 
Results 126 | 127 | To evaluate the performance of the selected methods, three challenging medical image segmentation tasks were considered. The results are reported below. 128 | 129 |
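The tables use the abbreviations AC (accuracy), PR (precision), SE (sensitivity/recall), SP (specificity), Dice, and IoU, all computed pixel-wise on the predicted and ground-truth masks. The snippet below is only an illustrative sketch of how such scores are conventionally derived from confusion-matrix counts; the evaluation code actually used for the reported numbers is in the `train_and_test` notebooks.

```python
import torch

def binary_seg_metrics(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-6):
    """Pixel-wise metrics for two binary masks (values in {0, 1}) of equal shape."""
    pred = pred.float().flatten()
    target = target.float().flatten()
    tp = (pred * target).sum()                # true positives
    tn = ((1 - pred) * (1 - target)).sum()    # true negatives
    fp = (pred * (1 - target)).sum()          # false positives
    fn = ((1 - pred) * target).sum()          # false negatives
    return {
        "AC":   ((tp + tn) / (tp + tn + fp + fn + eps)).item(),  # accuracy
        "PR":   (tp / (tp + fp + eps)).item(),                   # precision
        "SE":   (tp / (tp + fn + eps)).item(),                   # sensitivity / recall
        "SP":   (tn / (tn + fp + eps)).item(),                   # specificity
        "Dice": (2 * tp / (2 * tp + fp + fn + eps)).item(),
        "IoU":  (tp / (tp + fp + fn + eps)).item(),
    }
```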
130 | 131 | Performance comparison on ***ISIC 2018*** dataset (best results are bolded). 132 | 133 | | Methods | AC | PR | SE | SP | Dice | IoU | 134 | | ------------------ | ---------- | ---------- | ---------- | ---------- | ---------- | ---------- | 135 | | **U-Net** | 0.9446 | 0.8746 | 0.8603 | 0.9671 | 0.8674 | 0.8491 | 136 | | **Att-UNet** | 0.9516 | 0.9075 | 0.8579 | 0.9766 | 0.8820 | 0.8649 | 137 | | **U-Net++** | 0.9517 | 0.9067 | 0.8590 | 0.9764 | 0.8822 | 0.8651 | 138 | | **MultiResUNet** | 0.9473 | 0.8765 | 0.8689 | 0.9704 | 0.8694 | 0.8537 | 139 | | **Residual U-Net** | 0.9468 | 0.8753 | 0.8659 | 0.9688 | 0.8689 | 0.8509 | 140 | | **TransUNet** | 0.9452 | 0.8823 | 0.8578 | 0.9653 | 0.8499 | 0.8365 | 141 | | **UCTransNet** | **0.9546** | **0.9100** | **0.8704** | **0.9770** | **0.8898** | **0.8729** | 142 | | **MISSFormer** | 0.9453 | 0.8964 | 0.8371 | 0.9742 | 0.8657 | 0.8484 | 143 | 144 |
145 | 146 | Performance comparison on ***SegPC 2021*** dataset (best results are bolded). 147 | 148 | | Methods | AC | PR | SE | SP | Dice | IoU | 149 | | ------------------ | ---------- | ---------- | ---------- | ---------- | ---------- | ---------- | 150 | | **U-Net** | 0.9795 | 0.9084 | 0.8548 | 0.9916 | 0.8808 | 0.8824 | 151 | | **Att-UNet** | 0.9854 | 0.9360 | 0.8964 | 0.9940 | 0.9158 | 0.9144 | 152 | | **U-Net++** | 0.9845 | 0.9328 | 0.8887 | 0.9938 | 0.9102 | 0.9092 | 153 | | **MultiResUNet** | 0.9753 | 0.8391 | 0.8925 | 0.9834 | 0.8649 | 0.8676 | 154 | | **Residual U-Net** | 0.9743 | 0.8920 | 0.8080 | 0.9905 | 0.8479 | 0.8541 | 155 | | **TransUNet** | 0.9702 | 0.8678 | 0.7831 | 0.9884 | 0.8233 | 0.8338 | 156 | | **UCTransNet** | **0.9857** | **0.9365** | **0.8991** | **0.9941** | **0.9174** | **0.9159** | 157 | | **MISSFormer** | 0.9663 | 0.8152 | 0.8014 | 0.9823 | 0.8082 | 0.8209 | 158 | 159 |
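The Synapse comparison below reports the average Dice score (DSC), the Hausdorff distance (HD), and per-organ Dice. As a hedged reference for the DSC columns only (the repository's Synapse evaluation script and its organ-to-label mapping are not included in this section, so the background convention here is an assumption), per-class Dice over integer label maps can be computed as follows; HD is typically obtained with a dedicated library and is not sketched here.

```python
import torch

def per_class_dice(pred_labels: torch.Tensor, gt_labels: torch.Tensor,
                   num_classes: int, eps: float = 1e-6):
    """Dice per foreground class for integer label maps (class 0 assumed to be background)."""
    scores = {}
    for c in range(1, num_classes):  # skip the background class
        p = (pred_labels == c).float()
        g = (gt_labels == c).float()
        intersection = (p * g).sum()
        scores[c] = (2 * intersection / (p.sum() + g.sum() + eps)).item()
    return scores

# Example with 8 organ classes plus background; the index-to-organ mapping
# (e.g. 1 = Aorta, 2 = Gallbladder, ...) is hypothetical.
dice_per_organ = per_class_dice(torch.randint(0, 9, (224, 224)),
                                torch.randint(0, 9, (224, 224)),
                                num_classes=9)
```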
160 | 161 | Performance comparison on ***Synapse*** dataset (best results are bolded). 162 | | Method | DSC↑ | HD↓ | Aorta | Gallbladder | Kidney(L) | Kidney(R) | Liver | Pancreas | Spleen | Stomach | 163 | | ------------------ | --------- | --------- | --------- | ----------- | --------- | --------- | --------- | --------- | --------- | --------- | 164 | | **U-Net** | 76.85 | 39.70 | 89.07 | 69.72 | 77.77 | 68.60 | 93.43 | 53.98 | 86.67 | 75.58 | 165 | | **Att-UNet** | 77.77 | 36.02 | **89.55** | **68.88** | 77.98 | 71.11 | 93.57 | 58.04 | 87.30 | 75.75 | 166 | | **U-Net++** | 76.91 | 36.93 | 88.19 | 65.89 | 81.76 | 74.27 | 93.01 | 58.20 | 83.44 | 70.52 | 167 | | **MultiResUNet** | 77.42 | 36.84 | 87.73 | 65.67 | 82.08 | 70.43 | 93.49 | 60.09 | 85.23 | 74.66 | 168 | | **Residual U-Net** | 76.95 | 38.44 | 87.06 | 66.05 | 83.43 | 76.83 | 93.99 | 51.86 | 85.25 | 70.13 | 169 | | **TransUNet** | 77.48 | 31.69 | 87.23 | 63.13 | 81.87 | 77.02 | 94.08 | 55.86 | 85.08 | 75.62 | 170 | | **UCTransNet** | 78.23 | 26.75 | 84.25 | 64.65 | 82.35 | 77.65 | 94.36 | 58.18 | 84.74 | 79.66 | 171 | | **MISSFormer** | **81.96** | **18.20** | 86.99 | 68.65 | **85.21** | **82.00** | **94.41** | **65.67** | **91.92** | **80.81** | 172 | 173 | ### Visualization 174 | 175 | - **Results on ISIC 2018** 176 | 177 | ![isic2018.png](./images/isic2018.png) 178 | 179 | Visual comparisons of different methods on the *ISIC 2018* skin lesion segmentation dataset. Ground truth boundaries are shown in green, and predicted boundaries are shown in blue. 180 | 181 | - **Results on SegPC 2021** 182 | 183 | ![segpc.png](./images/segpc.png) 184 | 185 | Visual comparisons of different methods on the *SegPC 2021* cell segmentation dataset. Red regions indicate the cytoplasm and blue regions denote the nucleus of each cell. 186 | 187 | - **Results on Synapse** 188 | 189 | ![synapse.png](./images/synapse.png) 190 | 191 | Visual comparisons of different methods on the *Synapse* multi-organ segmentation dataset. 192 | 193 | ## References 194 | 195 | ### Codes [GitHub Pages] 196 | 197 | - AttU-Net: [https://github.com/LeeJunHyun/Image_Segmentation](https://github.com/LeeJunHyun/Image_Segmentation) 198 | 199 | - U-Net++: [https://github.com/4uiiurz1/pytorch-nested-unet/blob/master/archs.py](https://github.com/4uiiurz1/pytorch-nested-unet/blob/master/archs.py) 200 | 201 | - MultiResUNet: https://github.com/j-sripad/mulitresunet-pytorch/blob/main/multiresunet.py 202 | 203 | - Residual U-Net: https://github.com/rishikksh20/ResUnet 204 | 205 | - TransUNet: https://github.com/Beckschen/TransUNet 206 | 207 | - UCTransNet: https://github.com/McGregorWwww/UCTransNet 208 | 209 | - MISSFormer: https://github.com/ZhifangDeng/MISSFormer 210 | 211 | ### Query 212 | 213 | For any query, please contact us.
214 | 215 | ``` 216 | rezazad68@gmail.com 217 | engtekh@gmail.com 218 | afshinbigboy@gmail.com 219 | ``` 220 | -------------------------------------------------------------------------------- /configs/isic/isic2018_attunet.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | mode: 'train' 3 | device: 'gpu' 4 | transforms: none 5 | dataset: 6 | class_name: "ISIC2018Dataset" 7 | input_size: 224 8 | training: 9 | params: 10 | data_dir: "/path/to/datasets/ISIC2018" 11 | validation: 12 | params: 13 | data_dir: "/path/to/datasets/ISIC2018" 14 | number_classes: 2 15 | data_loader: 16 | train: 17 | batch_size: 16 18 | shuffle: true 19 | num_workers: 8 20 | pin_memory: true 21 | validation: 22 | batch_size: 16 23 | shuffle: false 24 | num_workers: 8 25 | pin_memory: true 26 | test: 27 | batch_size: 16 28 | shuffle: false 29 | num_workers: 4 30 | pin_memory: false 31 | training: 32 | optimizer: 33 | name: 'Adam' 34 | params: 35 | lr: 0.0001 36 | criterion: 37 | name: "DiceLoss" 38 | params: {} 39 | scheduler: 40 | factor: 0.5 41 | patience: 10 42 | epochs: 100 43 | model: 44 | save_dir: '../../saved_models/isic2018_attunet' 45 | load_weights: false 46 | name: 'AttU_Net' 47 | params: 48 | img_ch: 3 49 | output_ch: 2 50 | # preprocess: -------------------------------------------------------------------------------- /configs/isic/isic2018_missformer.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | mode: 'train' 3 | device: 'gpu' 4 | transforms: none 5 | dataset: 6 | class_name: "ISIC2018Dataset" 7 | input_size: 224 8 | training: 9 | params: 10 | data_dir: "/path/to/datasets/ISIC2018" 11 | validation: 12 | params: 13 | data_dir: "/path/to/datasets/ISIC2018" 14 | number_classes: 2 15 | data_loader: 16 | train: 17 | batch_size: 16 18 | shuffle: true 19 | num_workers: 8 20 | pin_memory: true 21 | validation: 22 | batch_size: 16 23 | shuffle: false 24 | num_workers: 8 25 | pin_memory: true 26 | test: 27 | batch_size: 16 28 | shuffle: false 29 | num_workers: 4 30 | pin_memory: false 31 | training: 32 | optimizer: 33 | name: 'SGD' 34 | params: 35 | lr: 0.0001 36 | momentum: 0.9 37 | weight_decay: 0.0001 38 | criterion: 39 | name: "DiceLoss" 40 | params: {} 41 | scheduler: 42 | factor: 0.5 43 | patience: 10 44 | epochs: 300 45 | model: 46 | save_dir: '../../saved_models/isic2018_missformer' 47 | load_weights: false 48 | name: "MISSFormer" 49 | params: 50 | in_ch: 3 51 | num_classes: 2 52 | # preprocess: -------------------------------------------------------------------------------- /configs/isic/isic2018_multiresunet.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | mode: 'train' 3 | device: 'gpu' 4 | transforms: none 5 | dataset: 6 | class_name: "ISIC2018Dataset" 7 | input_size: 224 8 | training: 9 | params: 10 | data_dir: "/path/to/datasets/ISIC2018" 11 | validation: 12 | params: 13 | data_dir: "/path/to/datasets/ISIC2018" 14 | number_classes: 2 15 | data_loader: 16 | train: 17 | batch_size: 16 18 | shuffle: true 19 | num_workers: 2 20 | pin_memory: true 21 | validation: 22 | batch_size: 16 23 | shuffle: false 24 | num_workers: 2 25 | pin_memory: true 26 | test: 27 | batch_size: 16 28 | shuffle: false 29 | num_workers: 2 30 | pin_memory: false 31 | training: 32 | optimizer: 33 | name: 'Adam' 34 | params: 35 | lr: 0.0005 36 | criterion: 37 | name: "DiceLoss" 38 | params: {} 39 | scheduler: 40 | factor: 0.5 41 | patience: 10 42 | epochs: 100 43 | model: 44 | 
save_dir: '../../saved_models/isic2018_multiresunet' 45 | load_weights: false 46 | name: 'MultiResUnet' 47 | params: 48 | channels: 3 49 | filters: 32 50 | nclasses: 2 51 | # preprocess: -------------------------------------------------------------------------------- /configs/isic/isic2018_resunet.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | mode: 'train' 3 | device: 'gpu' 4 | transforms: none 5 | dataset: 6 | class_name: "ISIC2018Dataset" 7 | input_size: 224 8 | training: 9 | params: 10 | data_dir: "/path/to/datasets/ISIC2018" 11 | validation: 12 | params: 13 | data_dir: "/path/to/datasets/ISIC2018" 14 | number_classes: 2 15 | data_loader: 16 | train: 17 | batch_size: 16 18 | shuffle: true 19 | num_workers: 8 20 | pin_memory: true 21 | validation: 22 | batch_size: 16 23 | shuffle: false 24 | num_workers: 8 25 | pin_memory: true 26 | test: 27 | batch_size: 16 28 | shuffle: false 29 | num_workers: 4 30 | pin_memory: false 31 | training: 32 | optimizer: 33 | name: 'Adam' 34 | params: 35 | lr: 0.0001 36 | criterion: 37 | name: "DiceLoss" 38 | params: {} 39 | scheduler: 40 | factor: 0.5 41 | patience: 10 42 | epochs: 100 43 | model: 44 | save_dir: '../../saved_models/isic2018_resunet' 45 | load_weights: false 46 | name: 'ResUnet' 47 | params: 48 | in_ch: 3 49 | out_ch: 2 50 | # preprocess: -------------------------------------------------------------------------------- /configs/isic/isic2018_transunet.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | mode: 'train' 3 | device: 'gpu' 4 | transforms: none 5 | dataset: 6 | class_name: "ISIC2018Dataset" 7 | input_size: 224 8 | training: 9 | params: 10 | data_dir: "/path/to/datasets/ISIC2018" 11 | validation: 12 | params: 13 | data_dir: "/path/to/datasets/ISIC2018" 14 | number_classes: 2 15 | data_loader: 16 | train: 17 | batch_size: 16 18 | shuffle: true 19 | num_workers: 8 20 | pin_memory: true 21 | validation: 22 | batch_size: 16 23 | shuffle: false 24 | num_workers: 8 25 | pin_memory: true 26 | test: 27 | batch_size: 16 28 | shuffle: false 29 | num_workers: 4 30 | pin_memory: false 31 | training: 32 | optimizer: 33 | name: 'SGD' 34 | params: 35 | lr: 0.0001 36 | momentum: 0.9 37 | weight_decay: 0.0001 38 | criterion: 39 | name: "DiceLoss" 40 | params: {} 41 | scheduler: 42 | factor: 0.5 43 | patience: 10 44 | epochs: 100 45 | model: 46 | save_dir: '../../saved_models/isic2018_transunet' 47 | load_weights: false 48 | name: 'VisionTransformer' 49 | params: 50 | img_size: 224 51 | num_classes: 2 52 | # preprocess: -------------------------------------------------------------------------------- /configs/isic/isic2018_uctransnet.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | mode: 'train' 3 | device: 'gpu' 4 | transforms: none 5 | dataset: 6 | class_name: "ISIC2018Dataset" 7 | input_size: 224 8 | training: 9 | params: 10 | data_dir: "/path/to/datasets/ISIC2018" 11 | validation: 12 | params: 13 | data_dir: "/path/to/datasets/ISIC2018" 14 | number_classes: 2 15 | data_loader: 16 | train: 17 | batch_size: 16 18 | shuffle: true 19 | num_workers: 8 20 | pin_memory: true 21 | validation: 22 | batch_size: 16 23 | shuffle: false 24 | num_workers: 8 25 | pin_memory: true 26 | test: 27 | batch_size: 16 28 | shuffle: false 29 | num_workers: 4 30 | pin_memory: false 31 | training: 32 | optimizer: 33 | name: 'Adam' 34 | params: 35 | lr: 0.0001 36 | criterion: 37 | name: "DiceLoss" 38 | params: {} 39 | 
scheduler: 40 | factor: 0.5 41 | patience: 10 42 | epochs: 100 43 | model: 44 | save_dir: '../../saved_models/isic2018_uctransnet' 45 | load_weights: false 46 | name: "UCTransNet" 47 | params: 48 | n_channels: 3 49 | n_classes: 2 50 | # preprocess: -------------------------------------------------------------------------------- /configs/isic/isic2018_unet.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | mode: 'train' 3 | device: 'gpu' 4 | transforms: none 5 | dataset: 6 | class_name: "ISIC2018Dataset" 7 | input_size: 224 8 | training: 9 | params: 10 | data_dir: "/path/to/datasets/ISIC2018" 11 | validation: 12 | params: 13 | data_dir: "/path/to/datasets/ISIC2018" 14 | number_classes: 2 15 | data_loader: 16 | train: 17 | batch_size: 16 18 | shuffle: true 19 | num_workers: 8 20 | pin_memory: true 21 | validation: 22 | batch_size: 16 23 | shuffle: false 24 | num_workers: 8 25 | pin_memory: true 26 | test: 27 | batch_size: 16 28 | shuffle: false 29 | num_workers: 4 30 | pin_memory: false 31 | training: 32 | optimizer: 33 | name: 'Adam' 34 | params: 35 | lr: 0.0001 36 | criterion: 37 | name: "DiceLoss" 38 | params: {} 39 | scheduler: 40 | factor: 0.5 41 | patience: 10 42 | epochs: 100 43 | model: 44 | save_dir: '../../saved_models/isic2018_unet' 45 | load_weights: false 46 | name: 'UNet' 47 | params: 48 | in_channels: 3 49 | out_channels: 2 50 | with_bn: false 51 | # preprocess: -------------------------------------------------------------------------------- /configs/isic/isic2018_unetpp.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | mode: 'train' 3 | device: 'gpu' 4 | transforms: none 5 | dataset: 6 | class_name: "ISIC2018Dataset" 7 | input_size: 224 8 | training: 9 | params: 10 | data_dir: "/path/to/datasets/ISIC2018" 11 | validation: 12 | params: 13 | data_dir: "/path/to/datasets/ISIC2018" 14 | number_classes: 2 15 | data_loader: 16 | train: 17 | batch_size: 16 18 | shuffle: true 19 | num_workers: 8 20 | pin_memory: true 21 | validation: 22 | batch_size: 16 23 | shuffle: false 24 | num_workers: 8 25 | pin_memory: true 26 | test: 27 | batch_size: 16 28 | shuffle: false 29 | num_workers: 4 30 | pin_memory: false 31 | training: 32 | optimizer: 33 | name: 'Adam' 34 | params: 35 | lr: 0.0001 36 | criterion: 37 | name: "DiceLoss" 38 | params: {} 39 | scheduler: 40 | factor: 0.5 41 | patience: 10 42 | epochs: 100 43 | model: 44 | save_dir: '../../saved_models/isic2018_unetpp' 45 | load_weights: false 46 | name: 'NestedUNet' 47 | params: 48 | num_classes: 2 49 | input_channels: 3 50 | deep_supervision: false 51 | # preprocess: -------------------------------------------------------------------------------- /configs/segpc/segpc2021_attunet.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | mode: 'train' 3 | device: 'gpu' 4 | transforms: none 5 | dataset: 6 | class_name: "SegPC2021Dataset" 7 | input_size: 224 8 | scale: 2.5 9 | data_dir: "/path/to/datasets/segpc/np" 10 | dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset" 11 | number_classes: 2 12 | data_loader: 13 | train: 14 | batch_size: 16 15 | shuffle: true 16 | num_workers: 4 17 | pin_memory: true 18 | validation: 19 | batch_size: 16 20 | shuffle: false 21 | num_workers: 4 22 | pin_memory: true 23 | test: 24 | batch_size: 16 25 | shuffle: false 26 | num_workers: 4 27 | pin_memory: false 28 | training: 29 | optimizer: 30 | name: 'Adam' 31 | params: 32 | lr: 0.0001 33 | criterion: 34 | name: 
"DiceLoss" 35 | params: {} 36 | scheduler: 37 | factor: 0.5 38 | patience: 10 39 | epochs: 100 40 | model: 41 | save_dir: '../../saved_models/segpc2021_attunet' 42 | load_weights: false 43 | name: 'AttU_Net' 44 | params: 45 | img_ch: 4 46 | output_ch: 2 47 | # preprocess: -------------------------------------------------------------------------------- /configs/segpc/segpc2021_missformer.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | mode: 'train' 3 | device: 'gpu' 4 | transforms: none 5 | dataset: 6 | class_name: "SegPC2021Dataset" 7 | input_size: 224 8 | scale: 2.5 9 | data_dir: "/path/to/datasets/segpc/np" 10 | dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset" 11 | number_classes: 2 12 | data_loader: 13 | train: 14 | batch_size: 16 15 | shuffle: true 16 | num_workers: 4 17 | pin_memory: true 18 | validation: 19 | batch_size: 16 20 | shuffle: false 21 | num_workers: 4 22 | pin_memory: true 23 | test: 24 | batch_size: 16 25 | shuffle: false 26 | num_workers: 4 27 | pin_memory: false 28 | training: 29 | optimizer: 30 | name: 'SGD' 31 | params: 32 | lr: 0.0001 33 | momentum: 0.9 34 | weight_decay: 0.0001 35 | criterion: 36 | name: "DiceLoss" 37 | params: {} 38 | scheduler: 39 | factor: 0.5 40 | patience: 10 41 | epochs: 500 42 | model: 43 | save_dir: '../../saved_models/segpc2021_missformer' 44 | load_weights: false 45 | name: 'MISSFormer' 46 | params: 47 | in_ch: 4 48 | num_classes: 2 49 | # preprocess: -------------------------------------------------------------------------------- /configs/segpc/segpc2021_multiresunet.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | mode: 'train' 3 | device: 'gpu' 4 | transforms: none 5 | dataset: 6 | class_name: "SegPC2021Dataset" 7 | input_size: 224 8 | scale: 2.5 9 | data_dir: "/path/to/datasets/segpc/np" 10 | dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset" 11 | number_classes: 2 12 | data_loader: 13 | train: 14 | batch_size: 16 15 | shuffle: true 16 | num_workers: 4 17 | pin_memory: true 18 | validation: 19 | batch_size: 16 20 | shuffle: false 21 | num_workers: 4 22 | pin_memory: true 23 | test: 24 | batch_size: 16 25 | shuffle: false 26 | num_workers: 4 27 | pin_memory: false 28 | training: 29 | optimizer: 30 | name: 'Adam' 31 | params: 32 | lr: 0.0001 33 | # name: "SGD" 34 | # params: 35 | # lr: 0.0001 36 | # momentum: 0.9 37 | # weight_decay: 0.0001 38 | criterion: 39 | name: "DiceLoss" 40 | params: {} 41 | scheduler: 42 | factor: 0.5 43 | patience: 10 44 | epochs: 100 45 | model: 46 | save_dir: '../../saved_models/segpc2021_multiresunet' 47 | load_weights: false 48 | name: 'MultiResUnet' 49 | params: 50 | channels: 4 51 | filters: 32 52 | nclasses: 2 53 | # preprocess: -------------------------------------------------------------------------------- /configs/segpc/segpc2021_resunet.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | mode: 'train' 3 | device: 'gpu' 4 | transforms: none 5 | dataset: 6 | class_name: "SegPC2021Dataset" 7 | input_size: 224 8 | scale: 2.5 9 | data_dir: "/path/to/datasets/segpc/np" 10 | dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset" 11 | number_classes: 2 12 | data_loader: 13 | train: 14 | batch_size: 16 15 | shuffle: true 16 | num_workers: 4 17 | pin_memory: true 18 | validation: 19 | batch_size: 16 20 | shuffle: false 21 | num_workers: 4 22 | pin_memory: true 23 | test: 24 | batch_size: 16 25 | shuffle: false 26 | num_workers: 4 27 | 
pin_memory: false 28 | training: 29 | optimizer: 30 | name: 'Adam' 31 | params: 32 | lr: 0.0001 33 | criterion: 34 | name: "DiceLoss" 35 | params: {} 36 | scheduler: 37 | factor: 0.5 38 | patience: 10 39 | epochs: 100 40 | model: 41 | save_dir: '../../saved_models/segpc2021_resunet' 42 | load_weights: false 43 | name: 'ResUnet' 44 | params: 45 | in_ch: 4 46 | out_ch: 2 47 | # preprocess: -------------------------------------------------------------------------------- /configs/segpc/segpc2021_transunet.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | mode: 'train' 3 | device: 'gpu' 4 | transforms: none 5 | dataset: 6 | class_name: "SegPC2021Dataset" 7 | input_size: 224 8 | scale: 2.5 9 | data_dir: "/path/to/datasets/segpc/np" 10 | dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset" 11 | number_classes: 2 12 | data_loader: 13 | train: 14 | batch_size: 16 15 | shuffle: true 16 | num_workers: 4 17 | pin_memory: true 18 | validation: 19 | batch_size: 16 20 | shuffle: false 21 | num_workers: 4 22 | pin_memory: true 23 | test: 24 | batch_size: 16 25 | shuffle: false 26 | num_workers: 4 27 | pin_memory: false 28 | training: 29 | optimizer: 30 | # name: 'Adam' 31 | # params: 32 | # lr: 0.0001 33 | name: "SGD" 34 | params: 35 | lr: 0.0001 36 | momentum: 0.9 37 | weight_decay: 0.0001 38 | criterion: 39 | name: "DiceLoss" 40 | params: {} 41 | scheduler: 42 | factor: 0.5 43 | patience: 10 44 | epochs: 100 45 | model: 46 | save_dir: '../../saved_models/segpc2021_transunet' 47 | load_weights: false 48 | name: 'VisionTransformer' 49 | params: 50 | img_size: 224 51 | num_classes: 2 52 | # preprocess: -------------------------------------------------------------------------------- /configs/segpc/segpc2021_uctransnet.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | mode: 'train' 3 | device: 'gpu' 4 | transforms: none 5 | dataset: 6 | class_name: "SegPC2021Dataset" 7 | input_size: 224 8 | scale: 2.5 9 | data_dir: "/path/to/datasets/segpc/np" 10 | dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset" 11 | number_classes: 2 12 | data_loader: 13 | train: 14 | batch_size: 16 15 | shuffle: true 16 | num_workers: 4 17 | pin_memory: true 18 | validation: 19 | batch_size: 16 20 | shuffle: false 21 | num_workers: 4 22 | pin_memory: true 23 | test: 24 | batch_size: 16 25 | shuffle: false 26 | num_workers: 4 27 | pin_memory: false 28 | training: 29 | optimizer: 30 | name: 'Adam' 31 | params: 32 | lr: 0.0001 33 | criterion: 34 | name: "DiceLoss" 35 | params: {} 36 | scheduler: 37 | factor: 0.5 38 | patience: 10 39 | epochs: 100 40 | model: 41 | save_dir: '../../saved_models/segpc2021_uctransnet' 42 | load_weights: false 43 | name: 'UCTransNet' 44 | params: 45 | n_channels: 4 46 | n_classes: 2 47 | # preprocess: -------------------------------------------------------------------------------- /configs/segpc/segpc2021_unet.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | mode: 'train' 3 | device: 'gpu' 4 | transforms: none 5 | dataset: 6 | class_name: "SegPC2021Dataset" 7 | input_size: 224 8 | scale: 2.5 9 | data_dir: "/path/to/datasets/segpc/np" 10 | dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset" 11 | number_classes: 2 12 | data_loader: 13 | train: 14 | batch_size: 16 15 | shuffle: true 16 | num_workers: 4 17 | pin_memory: true 18 | validation: 19 | batch_size: 16 20 | shuffle: false 21 | num_workers: 4 22 | pin_memory: true 23 | test: 24 | batch_size: 
16 25 | shuffle: false 26 | num_workers: 4 27 | pin_memory: false 28 | training: 29 | optimizer: 30 | name: 'Adam' 31 | params: 32 | lr: 0.0001 33 | criterion: 34 | name: "DiceLoss" 35 | params: {} 36 | scheduler: 37 | factor: 0.5 38 | patience: 10 39 | epochs: 100 40 | model: 41 | save_dir: '../../saved_models/segpc2021_unet' 42 | load_weights: false 43 | name: 'UNet' 44 | params: 45 | in_channels: 4 46 | out_channels: 2 47 | with_bn: false 48 | # preprocess: -------------------------------------------------------------------------------- /configs/segpc/segpc2021_unetpp.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | mode: 'train' 3 | device: 'gpu' 4 | transforms: none 5 | dataset: 6 | class_name: "SegPC2021Dataset" 7 | input_size: 224 8 | scale: 2.5 9 | data_dir: "/path/to/datasets/segpc/np" 10 | dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset" 11 | number_classes: 2 12 | data_loader: 13 | train: 14 | batch_size: 16 15 | shuffle: true 16 | num_workers: 4 17 | pin_memory: true 18 | validation: 19 | batch_size: 16 20 | shuffle: false 21 | num_workers: 4 22 | pin_memory: true 23 | test: 24 | batch_size: 16 25 | shuffle: false 26 | num_workers: 4 27 | pin_memory: false 28 | training: 29 | optimizer: 30 | name: 'Adam' 31 | params: 32 | lr: 0.0001 33 | criterion: 34 | name: "DiceLoss" 35 | params: {} 36 | scheduler: 37 | factor: 0.5 38 | patience: 10 39 | epochs: 100 40 | model: 41 | save_dir: '../../saved_models/segpc2021_unetpp' 42 | load_weights: false 43 | name: 'NestedUNet' 44 | params: 45 | num_classes: 2 46 | input_channels: 4 47 | deep_supervision: false 48 | # preprocess: -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | # Data Preparation 2 | 3 | The structure of the `datasets` and `configs` folders is shown below. 4 | 5 | ```bash 6 | . 7 | ├── configs 8 | │ ├── isic 9 | │ │ ├── isic2018_*.yaml 10 | │ │ └── ... 11 | │ └── segpc 12 | │ ├── segpc2021_*.yaml 13 | │ └── ... 14 | └── datasets 15 | ├── isic.ipynb 16 | ├── isic.py 17 | ├── segpc.ipynb 18 | ├── segpc.py 19 | ├── prepare_isic.ipynb 20 | └── prepare_segpc.ipynb 21 | ``` 22 | 23 | In this work, we used three datasets (ISIC, SegPC, and Synapse). In the `datasets` folder, you can see files in three different formats: 24 | 25 | - `.py`: These files are used in the main code for the training and testing procedures. 26 | 27 | - `.ipynb`: These files are provided for a better understanding of the datasets and how we use them. Each one also shows at least one sample image of the related dataset. 28 | 29 | - `prepare_*.ipynb`: Use these notebooks to prepare the files required by the corresponding `Dataset` classes. 30 | 31 | When you use the code for preparing or loading the datasets, **do not forget to set the correct data path** in the related dataset files and the corresponding config files. 32 | 33 | In the following, you can find more information regarding accessing, preparing, and using the datasets. 34 | 35 | 36 | 37 | ## ISIC2018 [(ISIC Challenge 2016-2020)](https://challenge.isic-archive.com/) 38 | 39 | The lesion images were acquired with a variety of dermatoscope types, from all anatomic sites (excluding mucosa and nails), from a historical sample of patients presented for skin cancer screening, from several different institutions.
Every lesion image contains exactly one primary lesion; other fiducial markers, smaller secondary lesions, or other pigmented regions may be neglected. 40 | 41 | #### [Data 2018](https://challenge.isic-archive.com/data/#2018) 42 | 43 | You can download the ISIC2018 dataset using the following links: 44 | 45 | - [Download training input images](https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task1-2_Training_Input.zip) 46 | 47 | - [Download training ground truth masks](https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task1_Training_GroundTruth.zip) 48 | 49 | After downloading, extract the archives into a folder, and ***do not forget to put the path*** in the relevant files. 50 | 51 | The input data are dermoscopic lesion images in JPEG format. 52 | 53 | All lesion images are named using the scheme `ISIC_<image_id>.jpg`, where `<image_id>` is a 7-digit unique identifier. EXIF tags in the images have been removed; any remaining EXIF tags should not be relied upon to provide accurate metadata. 54 | 55 | Below, you can see a sample of this dataset. 56 | 57 |

58 | 59 |

60 | 61 | 62 | 63 | 64 | ## SegPC2021 [(SegPC Challenge 2021)](https://ieee-dataport.org/open-access/segpc-2021-segmentation-multiple-myeloma-plasma-cells-microscopic-images) 65 | 66 | This [challenge](https://segpc-2021.grand-challenge.org/) targeted robust segmentation of cells, which is the first stage in building a diagnostic tool for plasma cell cancer, namely Multiple Myeloma (MM), a type of blood cancer. The organizers provided images after stain color normalization. 67 | 68 | In this challenge, the target was instance segmentation of cells of interest. For each input image, both the nucleus and the cytoplasm had to be segmented separately for each cell of interest. 69 | 70 | ### Our purpose 71 | 72 | In this work, we used this dataset for a slightly different purpose. We use the SegPC2021 data to create an independent image centered on the nucleus of each cell of interest, with the ultimate goal of segmenting the cytoplasm of that cell. The nucleus mask is also used as an auxiliary channel of the input images. You can refer to [this article](https://arxiv.org/abs/2105.06238) for more information. You can download the main dataset using the following links: 73 | 74 | - [Official page](https://ieee-dataport.org/open-access/segpc-2021-segmentation-multiple-myeloma-plasma-cells-microscopic-images) 75 | - [Alternative link](https://www.kaggle.com/datasets/sbilab/segpc2021dataset) 76 | 77 | After downloading the dataset, you can use the `prepare_segpc.ipynb` notebook to produce the required data files. ***Do not forget to put the right path in the relevant files***. 78 | 79 | Below, you can see a sample of the prepared dataset with `scale=2.5`. 80 | 81 |

82 | 83 |

84 | 85 | 86 | 87 | ## Synapse 88 | 89 | 1. Access to the synapse multi-organ dataset: 90 | 91 | 1. Sign up in the [official Synapse website](https://www.synapse.org/#!Synapse:syn3193805/wiki/) and download the dataset. Convert them to numpy format, clip the images within [-125, 275], normalize each 3D image to [0, 1], and extract 2D slices from 3D volume for training cases while keeping the 3D volume in h5 format for testing cases. 92 | 93 | 2. The directory structure of the whole project is as follows: 94 | 95 | ```bash 96 | . 97 | ├── model 98 | │ └── vit_checkpoint 99 | │ └── imagenet21k 100 | │ ├── R50+ViT-B_16.npz 101 | │ └── *.npz 102 | └── data 103 | └──Synapse 104 | ├── test_vol_h5 105 | │ ├── case0001.npy.h5 106 | │ └── *.npy.h5 107 | └── train_npz 108 | ├── case0005_slice000.npz 109 | └── *.npz 110 | ``` 111 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NITR098/Awesome-U-Net/8785625b1113979303e8f7c3f88cc97a1f68324c/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/isic.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # ## [ISIC Challenge (2016-2020)](https://challenge.isic-archive.com/) 5 | # --- 6 | # 7 | # ### [Data 2018](https://challenge.isic-archive.com/data/) 8 | # 9 | # The input data are dermoscopic lesion images in JPEG format. 10 | # 11 | # All lesion images are named using the scheme `ISIC_.jpg`, where `` is a 7-digit unique identifier. EXIF tags in the images have been removed; any remaining EXIF tags should not be relied upon to provide accurate metadata. 12 | # 13 | # The lesion images were acquired with a variety of dermatoscope types, from all anatomic sites (excluding mucosa and nails), from a historical sample of patients presented for skin cancer screening, from several different institutions. Every lesion image contains exactly one primary lesion; other fiducial markers, smaller secondary lesions, or other pigmented regions may be neglected. 14 | # 15 | # The distribution of disease states represent a modified "real world" setting whereby there are more benign lesions than malignant lesions, but an over-representation of malignancies. 
16 | 17 | # In[2]: 18 | 19 | 20 | import os 21 | import glob 22 | import numpy as np 23 | import torch 24 | from torch.utils.data import Dataset 25 | from torchvision import transforms, utils 26 | from torchvision.io import read_image 27 | from torchvision.io.image import ImageReadMode 28 | import torch.nn.functional as F 29 | 30 | 31 | 32 | # In[3]: 33 | 34 | class ISIC2018Dataset(Dataset): 35 | def __init__(self, data_dir=None, one_hot=True, img_transform=None, msk_transform=None): 36 | # pre-set variables 37 | self.data_prefix = "ISIC_" 38 | self.target_postfix = "_segmentation" 39 | self.target_fex = "png" 40 | self.input_fex = "jpg" 41 | self.data_dir = data_dir if data_dir else "/path/to/datasets/ISIC2018" 42 | self.imgs_dir = os.path.join(self.data_dir, "ISIC2018_Task1-2_Training_Input") 43 | self.msks_dir = os.path.join(self.data_dir, "ISIC2018_Task1_Training_GroundTruth") 44 | 45 | # input parameters 46 | self.img_dirs = glob.glob(f"{self.imgs_dir}/*.{self.input_fex}") 47 | self.data_ids = [d.split(self.data_prefix)[1].split(f".{self.input_fex}")[0] for d in self.img_dirs] 48 | self.one_hot = one_hot 49 | self.img_transform = img_transform 50 | self.msk_transform = msk_transform 51 | 52 | def get_img_by_id(self, id): 53 | img_dir = os.path.join(self.imgs_dir, f"{self.data_prefix}{id}.{self.input_fex}") 54 | img = read_image(img_dir, ImageReadMode.RGB) 55 | return img 56 | 57 | def get_msk_by_id(self, id): 58 | msk_dir = os.path.join(self.msks_dir, f"{self.data_prefix}{id}{self.target_postfix}.{self.target_fex}") 59 | msk = read_image(msk_dir, ImageReadMode.GRAY) 60 | return msk 61 | 62 | def __len__(self): 63 | return len(self.data_ids) 64 | 65 | def __getitem__(self, idx): 66 | data_id = self.data_ids[idx] 67 | img = self.get_img_by_id(data_id) 68 | msk = self.get_msk_by_id(data_id) 69 | 70 | if self.img_transform: 71 | img = self.img_transform(img) 72 | img = (img - img.min())/(img.max() - img.min()) 73 | if self.msk_transform: 74 | msk = self.msk_transform(msk) 75 | msk = (msk - msk.min())/(msk.max() - msk.min()) 76 | 77 | if self.one_hot: 78 | msk = F.one_hot(torch.squeeze(msk).to(torch.int64)) 79 | msk = torch.moveaxis(msk, -1, 0).to(torch.float) 80 | 81 | sample = {'image': img, 'mask': msk, 'id': data_id} 82 | return sample 83 | 84 | 85 | class ISIC2018DatasetFast(Dataset): 86 | def __init__(self, mode, data_dir=None, one_hot=True, img_transform=None, msk_transform=None): 87 | # pre-set variables 88 | self.data_dir = data_dir if data_dir else "/path/to/datasets/ISIC2018/np" 89 | 90 | # input parameters 91 | self.one_hot = one_hot 92 | 93 | X = np.load(f"{self.data_dir}/X_tr_224x224.npy") 94 | Y = np.load(f"{self.data_dir}/Y_tr_224x224.npy") 95 | 96 | X = torch.tensor(X) 97 | Y = torch.tensor(Y) 98 | 99 | if mode == "tr": 100 | self.imgs = X[0:1815] 101 | self.msks = Y[0:1815] 102 | elif mode == "vl": 103 | self.imgs = X[1815:1815+259] 104 | self.msks = Y[1815:1815+259] 105 | elif mode == "te": 106 | self.imgs = X[1815+259:2594] 107 | self.msks = Y[1815+259:2594] 108 | else: 109 | raise ValueError() 110 | 111 | def __len__(self): 112 | return len(self.imgs) 113 | 114 | def __getitem__(self, idx): 115 | data_id = idx 116 | img = self.imgs[idx] 117 | msk = self.msks[idx] 118 | 119 | if self.one_hot: 120 | msk = F.one_hot(torch.squeeze(msk).to(torch.int64)) 121 | msk = torch.moveaxis(msk, -1, 0).to(torch.float) 122 | 123 | sample = {'image': img, 'mask': msk, 'id': data_id} 124 | return sample 125 | 126 | 127 | class ISIC2018TrainingDataset(Dataset): 128 | def __init__(self, 
data_dir=None, img_transform=None, msk_transform=None): 129 | # pre-set variables 130 | self.data_prefix = "ISIC_" 131 | self.target_postfix = "_segmentation" 132 | self.target_fex = "png" 133 | self.input_fex = "jpg" 134 | self.data_dir = data_dir if data_dir else "/path/to/datasets/ISIC2018" 135 | self.imgs_dir = os.path.join(self.data_dir, "ISIC2018_Task1-2_Training_Input") 136 | self.msks_dir = os.path.join(self.data_dir, "ISIC2018_Task1_Training_GroundTruth") 137 | 138 | # input parameters 139 | self.img_dirs = glob.glob(f"{self.imgs_dir}/*.{self.input_fex}") 140 | self.data_ids = [d.split(self.data_prefix)[1].split(f".{self.input_fex}")[0] for d in self.img_dirs] 141 | self.img_transform = img_transform 142 | self.msk_transform = msk_transform 143 | 144 | def get_img_by_id(self, id): 145 | img_dir = os.path.join(self.imgs_dir, f"{self.data_prefix}{id}.{self.input_fex}") 146 | img = read_image(img_dir, ImageReadMode.RGB) 147 | return img 148 | 149 | def get_msk_by_id(self, id): 150 | msk_dir = os.path.join(self.msks_dir, f"{self.data_prefix}{id}{self.target_postfix}.{self.target_fex}") 151 | msk = read_image(msk_dir, ImageReadMode.GRAY) 152 | return msk 153 | 154 | def __len__(self): 155 | return len(self.data_ids) 156 | 157 | def __getitem__(self, idx): 158 | data_id = self.data_ids[idx] 159 | img = self.get_img_by_id(data_id) 160 | msk = self.get_msk_by_id(data_id) 161 | 162 | if self.img_transform: 163 | img = self.img_transform(img) 164 | img = (img - img.min())/(img.max() - img.min()) 165 | if self.msk_transform: 166 | msk = self.msk_transform(msk) 167 | msk = (msk - msk.min())/(msk.max() - msk.min()) 168 | sample = {'image': img, 'mask': msk, 'id': data_id} 169 | return sample 170 | 171 | 172 | # In[10]: 173 | 174 | 175 | class ISIC2018ValidationDataset(Dataset): 176 | def __init__(self, data_dir=None, img_transform=None, msk_transform=None): 177 | # pre-set variables 178 | self.data_prefix = "ISIC_" 179 | self.target_postfix = "_segmentation" 180 | self.target_fex = "png" 181 | self.input_fex = "jpg" 182 | self.data_dir = data_dir if data_dir else "/path/to/datasets/ISIC2018" 183 | self.imgs_dir = os.path.join(self.data_dir, "ISIC2018_Task1-2_Validation_Input") 184 | self.msks_dir = os.path.join(self.data_dir, "ISIC2018_Task1_Validation_GroundTruth") 185 | 186 | # input parameters 187 | self.img_dirs = glob.glob(f"{self.imgs_dir}/*.{self.input_fex}") 188 | self.data_ids = [d.split(self.data_prefix)[1].split(f".{self.input_fex}")[0] for d in self.img_dirs] 189 | self.img_transform = img_transform 190 | self.msk_transform = msk_transform 191 | 192 | def get_img_by_id(self, id): 193 | img_dir = os.path.join(self.imgs_dir, f"{self.data_prefix}{id}.{self.input_fex}") 194 | img = read_image(img_dir, ImageReadMode.RGB) 195 | return img 196 | 197 | def get_msk_by_id(self, id): 198 | msk_dir = os.path.join(self.msks_dir, f"{self.data_prefix}{id}{self.target_postfix}.{self.target_fex}") 199 | msk = read_image(msk_dir, ImageReadMode.GRAY) 200 | return msk 201 | 202 | def __len__(self): 203 | return len(self.data_ids) 204 | 205 | def __getitem__(self, idx): 206 | data_id = self.data_ids[idx] 207 | img = self.get_img_by_id(data_id) 208 | msk = self.get_msk_by_id(data_id) 209 | 210 | print(f"msk shape: {msk.shape}") 211 | print(f"img shape: {img.shape}") 212 | 213 | 214 | if self.img_transform: 215 | img = self.img_transform(img) 216 | img = (img - img.min())/(img.max() - img.min()) 217 | if self.msk_transform: 218 | msk = self.msk_transform(msk) 219 | msk = (msk - msk.min())/(msk.max() - 
msk.min()) 220 | 221 | sample = {'image': img, 'mask': msk, 'id': data_id} 222 | return sample 223 | 224 | 225 | # ## Test dataset and dataloader 226 | # --- 227 | 228 | # # In[13]: 229 | 230 | 231 | # import sys 232 | # sys.path.append('..') 233 | # from utils import show_sbs 234 | # from torch.utils.data import DataLoader, Subset 235 | # from torchvision import transforms 236 | 237 | 238 | 239 | # # ------------------- params -------------------- 240 | # INPUT_SIZE = 224 241 | 242 | # TR_BATCH_SIZE = 8 243 | # TR_DL_SHUFFLE = True 244 | # TR_DL_WORKER = 1 245 | 246 | # VL_BATCH_SIZE = 12 247 | # VL_DL_SHUFFLE = False 248 | # VL_DL_WORKER = 1 249 | 250 | # TE_BATCH_SIZE = 12 251 | # TE_DL_SHUFFLE = False 252 | # TE_DL_WORKER = 1 253 | # # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< 254 | 255 | 256 | # # ----------------- transform ------------------ 257 | # # transform for image 258 | # img_transform = transforms.Compose([ 259 | # transforms.Resize( 260 | # size=[INPUT_SIZE, INPUT_SIZE], 261 | # interpolation=transforms.functional.InterpolationMode.BILINEAR 262 | # ), 263 | # ]) 264 | # # transform for mask 265 | # msk_transform = transforms.Compose([ 266 | # transforms.Resize( 267 | # size=[INPUT_SIZE, INPUT_SIZE], 268 | # interpolation=transforms.functional.InterpolationMode.NEAREST 269 | # ), 270 | # ]) 271 | # # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< 272 | 273 | 274 | # # ----------------- dataset -------------------- 275 | # # preparing training dataset 276 | # train_dataset = ISIC2018TrainingDataset( 277 | # img_transform=img_transform, 278 | # msk_transform=msk_transform 279 | # ) 280 | 281 | # # We consider 1815 samples for training, 259 samples for validation and 520 samples for testing 282 | # # !cat ~/deeplearning/skin/Prepare_ISIC2018.py 283 | 284 | # indices = list(range(len(train_dataset))) 285 | 286 | # # split indices to: -> train, validation, and test 287 | # tr_indices = indices[0:1815] 288 | # vl_indices = indices[1815:1815+259] 289 | # te_indices = indices[1815+259:2594] 290 | 291 | # # create new datasets from train dataset as training, validation, and test 292 | # tr_dataset = Subset(train_dataset, tr_indices) 293 | # vl_dataset = Subset(train_dataset, vl_indices) 294 | # te_dataset = Subset(train_dataset, te_indices) 295 | 296 | # # prepare train dataloader 297 | # tr_loader = DataLoader( 298 | # tr_dataset, 299 | # batch_size=TR_BATCH_SIZE, 300 | # shuffle=TR_DL_SHUFFLE, 301 | # num_workers=TR_DL_WORKER, 302 | # pin_memory=True 303 | # ) 304 | 305 | # # prepare validation dataloader 306 | # vl_loader = DataLoader( 307 | # vl_dataset, 308 | # batch_size=VL_BATCH_SIZE, 309 | # shuffle=VL_DL_SHUFFLE, 310 | # num_workers=VL_DL_WORKER, 311 | # pin_memory=True 312 | # ) 313 | 314 | # # prepare test dataloader 315 | # te_loader = DataLoader( 316 | # te_dataset, 317 | # batch_size=TE_BATCH_SIZE, 318 | # shuffle=TE_DL_SHUFFLE, 319 | # num_workers=TE_DL_WORKER, 320 | # pin_memory=True 321 | # ) 322 | 323 | # # -------------- test ----------------- 324 | # # test and visualize the input data 325 | # for sample in tr_loader: 326 | # img = sample['image'] 327 | # msk = sample['mask'] 328 | # print("Training") 329 | # show_sbs(img[0], msk[0]) 330 | # break 331 | 332 | # for sample in vl_loader: 333 | # img = sample['image'] 334 | # msk = sample['mask'] 335 | # print("Validation") 336 | # show_sbs(img[0], msk[0]) 337 | # break 338 | 339 | # for sample in te_loader: 340 | # img = sample['image'] 341 | # msk = sample['mask'] 342 | # print("Test") 343 | # show_sbs(img[0], 
msk[0]) 344 | # break 345 | 346 | -------------------------------------------------------------------------------- /datasets/segpc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms, utils 7 | from torchvision.io import read_image 8 | from torchvision.io.image import ImageReadMode 9 | import torch.nn.functional as F 10 | 11 | 12 | 13 | class SegPC2021Dataset(Dataset): 14 | def __init__(self, 15 | mode, # 'tr'-> train, 'vl' -> validation, 'te' -> test 16 | input_size=224, 17 | scale=2.5, 18 | data_dir=None, 19 | dataset_dir=None, 20 | one_hot=True, 21 | force_rebuild=False, 22 | img_transform=None, 23 | msk_transform=None): 24 | # pre-set variables 25 | self.data_dir = data_dir if data_dir else "/path/to/datasets/segpc/np" 26 | self.dataset_dir = dataset_dir if dataset_dir else "/path/to/datasets/segpc/TCIA_SegPC_dataset/" 27 | self.mode = mode 28 | # input parameters 29 | self.img_transform = img_transform 30 | self.msk_transform = msk_transform 31 | self.input_size = input_size 32 | self.scale = scale 33 | self.one_hot = one_hot 34 | 35 | # loading data 36 | self.load_dataset(force_rebuild=force_rebuild) 37 | 38 | 39 | def load_dataset(self, force_rebuild): 40 | INPUT_SIZE = self.input_size 41 | ADD = self.data_dir 42 | 43 | # build_segpc_dataset( 44 | # input_size = self.input_size, 45 | # scale = self.scale, 46 | # data_dir = self.data_dir, 47 | # dataset_dir = self.dataset_dir, 48 | # mode = self.mode, 49 | # force_rebuild = force_rebuild, 50 | # ) 51 | 52 | print(f'loading X_{self.mode}...') 53 | self.X = np.load(f'{ADD}/cyts_{self.mode}_{self.input_size}x{self.input_size}_s{self.scale}_X.npy') 54 | print(f'loading Y_{self.mode}...') 55 | self.Y = np.load(f'{ADD}/cyts_{self.mode}_{self.input_size}x{self.input_size}_s{self.scale}_Y.npy') 56 | print('finished.') 57 | 58 | 59 | def __len__(self): 60 | return len(self.X) 61 | 62 | 63 | def __getitem__(self, idx): 64 | img = self.X[idx] 65 | msk = self.Y[idx] 66 | msk = np.where(msk<0.5, 0, 1) 67 | 68 | if self.img_transform: 69 | img = self.img_transform(img) 70 | img = (img - img.min())/(img.max() - img.min()) 71 | if self.msk_transform: 72 | msk = self.msk_transform(msk) 73 | msk = (msk - msk.min())/(msk.max() - msk.min()) 74 | 75 | if self.one_hot: 76 | msk = F.one_hot(torch.squeeze(msk).to(torch.int64)) 77 | msk = torch.moveaxis(msk, -1, 0).to(torch.float) 78 | 79 | sample = {'image': img, 'mask': msk, 'id': idx} 80 | return sample -------------------------------------------------------------------------------- /images/ComparisonOfModels.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NITR098/Awesome-U-Net/8785625b1113979303e8f7c3f88cc97a1f68324c/images/ComparisonOfModels.png -------------------------------------------------------------------------------- /images/U-Net_Taxonomy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NITR098/Awesome-U-Net/8785625b1113979303e8f7c3f88cc97a1f68324c/images/U-Net_Taxonomy.png -------------------------------------------------------------------------------- /images/isic2018.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NITR098/Awesome-U-Net/8785625b1113979303e8f7c3f88cc97a1f68324c/images/isic2018.png 
-------------------------------------------------------------------------------- /images/isic2018_sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NITR098/Awesome-U-Net/8785625b1113979303e8f7c3f88cc97a1f68324c/images/isic2018_sample.png -------------------------------------------------------------------------------- /images/segpc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NITR098/Awesome-U-Net/8785625b1113979303e8f7c3f88cc97a1f68324c/images/segpc.png -------------------------------------------------------------------------------- /images/segpc2021_sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NITR098/Awesome-U-Net/8785625b1113979303e8f7c3f88cc97a1f68324c/images/segpc2021_sample.png -------------------------------------------------------------------------------- /images/synapse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NITR098/Awesome-U-Net/8785625b1113979303e8f7c3f88cc97a1f68324c/images/synapse.png -------------------------------------------------------------------------------- /images/unet-pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NITR098/Awesome-U-Net/8785625b1113979303e8f7c3f88cc97a1f68324c/images/unet-pipeline.png -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | 5 | 6 | EPSILON = 1e-6 7 | 8 | class DiceLoss(torch.nn.Module): 9 | def __init__(self,): 10 | super().__init__() 11 | 12 | def forward(self, pred, mask): 13 | pred = pred.flatten() 14 | mask = mask.flatten() 15 | 16 | intersect = (mask * pred).sum() 17 | dice_score = 2*intersect / (pred.sum() + mask.sum() + EPSILON) 18 | dice_loss = 1 - dice_score 19 | return dice_loss 20 | 21 | 22 | class DiceLossWithLogtis(torch.nn.Module): 23 | def __init__(self,): 24 | super().__init__() 25 | 26 | def forward(self, pred, mask): 27 | prob = F.softmax(pred, dim=1) 28 | true_1_hot = mask.type(prob.type()) 29 | 30 | dims = (0,) + tuple(range(2, true_1_hot.ndimension())) 31 | intersection = torch.sum(prob * true_1_hot, dims) 32 | cardinality = torch.sum(prob + true_1_hot, dims) 33 | dice_loss = (2. 
* intersection / (cardinality + EPSILON)).mean() 34 | return (1 - dice_loss) 35 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NITR098/Awesome-U-Net/8785625b1113979303e8f7c3f88cc97a1f68324c/models/__init__.py -------------------------------------------------------------------------------- /models/_missformer/MISSFormer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .segformer import * 4 | from typing import Tuple 5 | from einops import rearrange 6 | 7 | class PatchExpand(nn.Module): 8 | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): 9 | super().__init__() 10 | self.input_resolution = input_resolution 11 | self.dim = dim 12 | self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity() 13 | self.norm = norm_layer(dim // dim_scale) 14 | 15 | def forward(self, x): 16 | """ 17 | x: B, H*W, C 18 | """ 19 | # print("x_shape-----",x.shape) 20 | H, W = self.input_resolution 21 | x = self.expand(x) 22 | 23 | B, L, C = x.shape 24 | # print(x.shape) 25 | assert L == H * W, "input feature has wrong size" 26 | 27 | x = x.view(B, H, W, C) 28 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4) 29 | x = x.view(B,-1,C//4) 30 | x= self.norm(x.clone()) 31 | 32 | return x 33 | 34 | class FinalPatchExpand_X4(nn.Module): 35 | def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm): 36 | super().__init__() 37 | self.input_resolution = input_resolution 38 | self.dim = dim 39 | self.dim_scale = dim_scale 40 | self.expand = nn.Linear(dim, 16*dim, bias=False) 41 | self.output_dim = dim 42 | self.norm = norm_layer(self.output_dim) 43 | 44 | def forward(self, x): 45 | """ 46 | x: B, H*W, C 47 | """ 48 | H, W = self.input_resolution 49 | x = self.expand(x) 50 | B, L, C = x.shape 51 | assert L == H * W, "input feature has wrong size" 52 | 53 | x = x.view(B, H, W, C) 54 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2)) 55 | x = x.view(B,-1,self.output_dim) 56 | x= self.norm(x.clone()) 57 | 58 | return x 59 | 60 | 61 | class SegU_decoder(nn.Module): 62 | def __init__(self, input_size, in_out_chan, heads, reduction_ratios, n_class=9, norm_layer=nn.LayerNorm, is_last=False): 63 | super().__init__() 64 | dims = in_out_chan[0] 65 | out_dim = in_out_chan[1] 66 | if not is_last: 67 | self.concat_linear = nn.Linear(dims*2, out_dim) 68 | # transformer decoder 69 | self.layer_up = PatchExpand(input_resolution=input_size, dim=out_dim, dim_scale=2, norm_layer=norm_layer) 70 | self.last_layer = None 71 | else: 72 | self.concat_linear = nn.Linear(dims*4, out_dim) 73 | # transformer decoder 74 | self.layer_up = FinalPatchExpand_X4(input_resolution=input_size, dim=out_dim, dim_scale=4, norm_layer=norm_layer) 75 | # self.last_layer = nn.Linear(out_dim, n_class) 76 | self.last_layer = nn.Conv2d(out_dim, n_class,1) 77 | # self.last_layer = None 78 | 79 | self.layer_former_1 = TransformerBlock(out_dim, heads, reduction_ratios) 80 | self.layer_former_2 = TransformerBlock(out_dim, heads, reduction_ratios) 81 | 82 | 83 | def init_weights(self): 84 | for m in self.modules(): 85 | if isinstance(m, nn.Linear): 86 | nn.init.xavier_uniform_(m.weight) 87 | if m.bias is not None: 88 | nn.init.zeros_(m.bias) 89 | elif 
isinstance(m, nn.LayerNorm): 90 | nn.init.ones_(m.weight) 91 | nn.init.zeros_(m.bias) 92 | elif isinstance(m, nn.Conv2d): 93 | nn.init.xavier_uniform_(m.weight) 94 | if m.bias is not None: 95 | nn.init.zeros_(m.bias) 96 | 97 | init_weights(self) 98 | 99 | 100 | 101 | def forward(self, x1, x2=None): 102 | if x2 is not None: 103 | b, h, w, c = x2.shape 104 | x2 = x2.view(b, -1, c) 105 | # print("------",x1.shape, x2.shape) 106 | cat_x = torch.cat([x1, x2], dim=-1) 107 | # print("-----catx shape", cat_x.shape) 108 | cat_linear_x = self.concat_linear(cat_x) 109 | tran_layer_1 = self.layer_former_1(cat_linear_x, h, w) 110 | tran_layer_2 = self.layer_former_2(tran_layer_1, h, w) 111 | 112 | if self.last_layer: 113 | out = self.last_layer(self.layer_up(tran_layer_2).view(b, 4*h, 4*w, -1).permute(0,3,1,2)) 114 | else: 115 | out = self.layer_up(tran_layer_2) 116 | else: 117 | # if len(x1.shape)>3: 118 | # x1 = x1.permute(0,2,3,1) 119 | # b, h, w, c = x1.shape 120 | # x1 = x1.view(b, -1, c) 121 | out = self.layer_up(x1) 122 | return out 123 | 124 | 125 | class BridgeLayer_4(nn.Module): 126 | def __init__(self, dims, head, reduction_ratios): 127 | super().__init__() 128 | 129 | self.norm1 = nn.LayerNorm(dims) 130 | self.attn = M_EfficientSelfAtten(dims, head, reduction_ratios) 131 | self.norm2 = nn.LayerNorm(dims) 132 | self.mixffn1 = MixFFN_skip(dims,dims*4) 133 | self.mixffn2 = MixFFN_skip(dims*2,dims*8) 134 | self.mixffn3 = MixFFN_skip(dims*5,dims*20) 135 | self.mixffn4 = MixFFN_skip(dims*8,dims*32) 136 | 137 | 138 | def forward(self, inputs): 139 | B = inputs[0].shape[0] 140 | C = 64 141 | if (type(inputs) == list): 142 | # print("-----1-----") 143 | c1, c2, c3, c4 = inputs 144 | B, C, _, _= c1.shape 145 | c1f = c1.permute(0, 2, 3, 1).reshape(B, -1, C) # 3136*64 146 | c2f = c2.permute(0, 2, 3, 1).reshape(B, -1, C) # 1568*64 147 | c3f = c3.permute(0, 2, 3, 1).reshape(B, -1, C) # 980*64 148 | c4f = c4.permute(0, 2, 3, 1).reshape(B, -1, C) # 392*64 149 | 150 | # print(c1f.shape, c2f.shape, c3f.shape, c4f.shape) 151 | inputs = torch.cat([c1f, c2f, c3f, c4f], -2) 152 | else: 153 | B,_,C = inputs.shape 154 | 155 | tx1 = inputs + self.attn(self.norm1(inputs)) 156 | tx = self.norm2(tx1) 157 | 158 | 159 | tem1 = tx[:,:3136,:].reshape(B, -1, C) 160 | tem2 = tx[:,3136:4704,:].reshape(B, -1, C*2) 161 | tem3 = tx[:,4704:5684,:].reshape(B, -1, C*5) 162 | tem4 = tx[:,5684:6076,:].reshape(B, -1, C*8) 163 | 164 | m1f = self.mixffn1(tem1, 56, 56).reshape(B, -1, C) 165 | m2f = self.mixffn2(tem2, 28, 28).reshape(B, -1, C) 166 | m3f = self.mixffn3(tem3, 14, 14).reshape(B, -1, C) 167 | m4f = self.mixffn4(tem4, 7, 7).reshape(B, -1, C) 168 | 169 | t1 = torch.cat([m1f, m2f, m3f, m4f], -2) 170 | 171 | tx2 = tx1 + t1 172 | 173 | 174 | return tx2 175 | 176 | 177 | class BridgeLayer_3(nn.Module): 178 | def __init__(self, dims, head, reduction_ratios): 179 | super().__init__() 180 | 181 | self.norm1 = nn.LayerNorm(dims) 182 | self.attn = M_EfficientSelfAtten(dims, head, reduction_ratios) 183 | self.norm2 = nn.LayerNorm(dims) 184 | # self.mixffn1 = MixFFN(dims,dims*4) 185 | self.mixffn2 = MixFFN(dims*2,dims*8) 186 | self.mixffn3 = MixFFN(dims*5,dims*20) 187 | self.mixffn4 = MixFFN(dims*8,dims*32) 188 | 189 | 190 | def forward(self, inputs: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor: 191 | B = inputs[0].shape[0] 192 | C = 64 193 | if (type(inputs) == list): 194 | # print("-----1-----") 195 | c1, c2, c3, c4 = inputs 196 | B, C, _, _= c1.shape 197 | c1f = c1.permute(0, 2, 3, 1).reshape(B, -1, 
C) # 3136*64 198 | c2f = c2.permute(0, 2, 3, 1).reshape(B, -1, C) # 1568*64 199 | c3f = c3.permute(0, 2, 3, 1).reshape(B, -1, C) # 980*64 200 | c4f = c4.permute(0, 2, 3, 1).reshape(B, -1, C) # 392*64 201 | 202 | # print(c1f.shape, c2f.shape, c3f.shape, c4f.shape) 203 | inputs = torch.cat([c2f, c3f, c4f], -2) 204 | else: 205 | B,_,C = inputs.shape 206 | 207 | tx1 = inputs + self.attn(self.norm1(inputs)) 208 | tx = self.norm2(tx1) 209 | 210 | 211 | # tem1 = tx[:,:3136,:].reshape(B, -1, C) 212 | tem2 = tx[:,:1568,:].reshape(B, -1, C*2) 213 | tem3 = tx[:,1568:2548,:].reshape(B, -1, C*5) 214 | tem4 = tx[:,2548:2940,:].reshape(B, -1, C*8) 215 | 216 | # m1f = self.mixffn1(tem1, 56, 56).reshape(B, -1, C) 217 | m2f = self.mixffn2(tem2, 28, 28).reshape(B, -1, C) 218 | m3f = self.mixffn3(tem3, 14, 14).reshape(B, -1, C) 219 | m4f = self.mixffn4(tem4, 7, 7).reshape(B, -1, C) 220 | 221 | t1 = torch.cat([m2f, m3f, m4f], -2) 222 | 223 | tx2 = tx1 + t1 224 | 225 | 226 | return tx2 227 | 228 | 229 | 230 | class BridegeBlock_4(nn.Module): 231 | def __init__(self, dims, head, reduction_ratios): 232 | super().__init__() 233 | self.bridge_layer1 = BridgeLayer_4(dims, head, reduction_ratios) 234 | self.bridge_layer2 = BridgeLayer_4(dims, head, reduction_ratios) 235 | self.bridge_layer3 = BridgeLayer_4(dims, head, reduction_ratios) 236 | self.bridge_layer4 = BridgeLayer_4(dims, head, reduction_ratios) 237 | 238 | def forward(self, x: torch.Tensor) -> torch.Tensor: 239 | bridge1 = self.bridge_layer1(x) 240 | bridge2 = self.bridge_layer2(bridge1) 241 | bridge3 = self.bridge_layer3(bridge2) 242 | bridge4 = self.bridge_layer4(bridge3) 243 | 244 | B,_,C = bridge4.shape 245 | outs = [] 246 | 247 | sk1 = bridge4[:,:3136,:].reshape(B, 56, 56, C).permute(0,3,1,2) 248 | sk2 = bridge4[:,3136:4704,:].reshape(B, 28, 28, C*2).permute(0,3,1,2) 249 | sk3 = bridge4[:,4704:5684,:].reshape(B, 14, 14, C*5).permute(0,3,1,2) 250 | sk4 = bridge4[:,5684:6076,:].reshape(B, 7, 7, C*8).permute(0,3,1,2) 251 | 252 | outs.append(sk1) 253 | outs.append(sk2) 254 | outs.append(sk3) 255 | outs.append(sk4) 256 | 257 | return outs 258 | 259 | 260 | class BridegeBlock_3(nn.Module): 261 | def __init__(self, dims, head, reduction_ratios): 262 | super().__init__() 263 | self.bridge_layer1 = BridgeLayer_3(dims, head, reduction_ratios) 264 | self.bridge_layer2 = BridgeLayer_3(dims, head, reduction_ratios) 265 | self.bridge_layer3 = BridgeLayer_3(dims, head, reduction_ratios) 266 | self.bridge_layer4 = BridgeLayer_3(dims, head, reduction_ratios) 267 | 268 | def forward(self, x: torch.Tensor) -> torch.Tensor: 269 | outs = [] 270 | if (type(x) == list): 271 | # print("-----1-----") 272 | outs.append(x[0]) 273 | bridge1 = self.bridge_layer1(x) 274 | bridge2 = self.bridge_layer2(bridge1) 275 | bridge3 = self.bridge_layer3(bridge2) 276 | bridge4 = self.bridge_layer4(bridge3) 277 | 278 | B,_,C = bridge4.shape 279 | 280 | 281 | # sk1 = bridge2[:,:3136,:].reshape(B, 56, 56, C).permute(0,3,1,2) 282 | sk2 = bridge4[:,:1568,:].reshape(B, 28, 28, C*2).permute(0,3,1,2) 283 | sk3 = bridge4[:,1568:2548,:].reshape(B, 14, 14, C*5).permute(0,3,1,2) 284 | sk4 = bridge4[:,2548:2940,:].reshape(B, 7, 7, C*8).permute(0,3,1,2) 285 | 286 | # outs.append(sk1) 287 | outs.append(sk2) 288 | outs.append(sk3) 289 | outs.append(sk4) 290 | 291 | return outs 292 | 293 | 294 | class MyDecoderLayer(nn.Module): 295 | def __init__(self, input_size, in_out_chan, heads, reduction_ratios,token_mlp_mode, n_class=9, norm_layer=nn.LayerNorm, is_last=False): 296 | super().__init__() 297 | dims = 
in_out_chan[0] 298 | out_dim = in_out_chan[1] 299 | if not is_last: 300 | self.concat_linear = nn.Linear(dims*2, out_dim) 301 | # transformer decoder 302 | self.layer_up = PatchExpand(input_resolution=input_size, dim=out_dim, dim_scale=2, norm_layer=norm_layer) 303 | self.last_layer = None 304 | else: 305 | self.concat_linear = nn.Linear(dims*4, out_dim) 306 | # transformer decoder 307 | self.layer_up = FinalPatchExpand_X4(input_resolution=input_size, dim=out_dim, dim_scale=4, norm_layer=norm_layer) 308 | # self.last_layer = nn.Linear(out_dim, n_class) 309 | self.last_layer = nn.Conv2d(out_dim, n_class,1) 310 | # self.last_layer = None 311 | 312 | self.layer_former_1 = TransformerBlock(out_dim, heads, reduction_ratios, token_mlp_mode) 313 | self.layer_former_2 = TransformerBlock(out_dim, heads, reduction_ratios, token_mlp_mode) 314 | 315 | 316 | def init_weights(self): 317 | for m in self.modules(): 318 | if isinstance(m, nn.Linear): 319 | nn.init.xavier_uniform_(m.weight) 320 | if m.bias is not None: 321 | nn.init.zeros_(m.bias) 322 | elif isinstance(m, nn.LayerNorm): 323 | nn.init.ones_(m.weight) 324 | nn.init.zeros_(m.bias) 325 | elif isinstance(m, nn.Conv2d): 326 | nn.init.xavier_uniform_(m.weight) 327 | if m.bias is not None: 328 | nn.init.zeros_(m.bias) 329 | 330 | init_weights(self) 331 | 332 | def forward(self, x1, x2=None): 333 | if x2 is not None: 334 | b, h, w, c = x2.shape 335 | x2 = x2.view(b, -1, c) 336 | # print("------",x1.shape, x2.shape) 337 | cat_x = torch.cat([x1, x2], dim=-1) 338 | # print("-----catx shape", cat_x.shape) 339 | cat_linear_x = self.concat_linear(cat_x) 340 | tran_layer_1 = self.layer_former_1(cat_linear_x, h, w) 341 | tran_layer_2 = self.layer_former_2(tran_layer_1, h, w) 342 | 343 | if self.last_layer: 344 | out = self.last_layer(self.layer_up(tran_layer_2).view(b, 4*h, 4*w, -1).permute(0,3,1,2)) 345 | else: 346 | out = self.layer_up(tran_layer_2) 347 | else: 348 | # if len(x1.shape)>3: 349 | # x1 = x1.permute(0,2,3,1) 350 | # b, h, w, c = x1.shape 351 | # x1 = x1.view(b, -1, c) 352 | out = self.layer_up(x1) 353 | return out 354 | 355 | class MISSFormer(nn.Module): 356 | def __init__(self, num_classes=9, in_ch=3, token_mlp_mode="mix_skip", encoder_pretrained=True): 357 | super().__init__() 358 | 359 | reduction_ratios = [8, 4, 2, 1] 360 | heads = [1, 2, 5, 8] 361 | d_base_feat_size = 7 #16 for 512 inputsize 7for 224 362 | in_out_chan = [[32, 64],[144, 128],[288, 320],[512, 512]] 363 | 364 | dims, layers = [[64, 128, 320, 512], [2, 2, 2, 2]] 365 | self.backbone = MiT(224, dims, layers,in_ch, token_mlp_mode) 366 | 367 | self.reduction_ratios = [1, 2, 4, 8] 368 | self.bridge = BridegeBlock_4(64, 1, self.reduction_ratios) 369 | 370 | self.decoder_3= MyDecoderLayer((d_base_feat_size,d_base_feat_size), in_out_chan[3], heads[3], reduction_ratios[3],token_mlp_mode, n_class=num_classes) 371 | self.decoder_2= MyDecoderLayer((d_base_feat_size*2,d_base_feat_size*2),in_out_chan[2], heads[2], reduction_ratios[2], token_mlp_mode, n_class=num_classes) 372 | self.decoder_1= MyDecoderLayer((d_base_feat_size*4,d_base_feat_size*4), in_out_chan[1], heads[1], reduction_ratios[1], token_mlp_mode, n_class=num_classes) 373 | self.decoder_0= MyDecoderLayer((d_base_feat_size*8,d_base_feat_size*8), in_out_chan[0], heads[0], reduction_ratios[0], token_mlp_mode, n_class=num_classes, is_last=True) 374 | 375 | 376 | def forward(self, x): 377 | #---------------Encoder------------------------- 378 | if x.size()[1] == 1: 379 | x = x.repeat(1,3,1,1) 380 | 381 | encoder = self.backbone(x) 
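# the backbone returns a list of four multi-scale feature maps; for a 224x224 input
# the bridge below flattens them to the 56x56x64, 28x28x128, 14x14x320 and 7x7x512
# grids hard-coded in BridgeLayer_4 / BridegeBlock_4, and the decoder stages then
# consume them deepest-first.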
382 | bridge = self.bridge(encoder) #list 383 | 384 | b,c,_,_ = bridge[3].shape 385 | # print(bridge[3].shape, bridge[2].shape,bridge[1].shape, bridge[0].shape) 386 | #---------------Decoder------------------------- 387 | # print("stage3-----") 388 | tmp_3 = self.decoder_3(bridge[3].permute(0,2,3,1).view(b,-1,c)) 389 | # print("stage2-----") 390 | tmp_2 = self.decoder_2(tmp_3, bridge[2].permute(0,2,3,1)) 391 | # print("stage1-----") 392 | tmp_1 = self.decoder_1(tmp_2, bridge[1].permute(0,2,3,1)) 393 | # print("stage0-----") 394 | tmp_0 = self.decoder_0(tmp_1, bridge[0].permute(0,2,3,1)) 395 | 396 | return tmp_0 397 | 398 | 399 | -------------------------------------------------------------------------------- /models/_missformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NITR098/Awesome-U-Net/8785625b1113979303e8f7c3f88cc97a1f68324c/models/_missformer/__init__.py -------------------------------------------------------------------------------- /models/_resunet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NITR098/Awesome-U-Net/8785625b1113979303e8f7c3f88cc97a1f68324c/models/_resunet/__init__.py -------------------------------------------------------------------------------- /models/_resunet/modules.py: -------------------------------------------------------------------------------- 1 | # https://github.com/rishikksh20/ResUnet/blob/master/core/modules.py 2 | 3 | 4 | import torch.nn as nn 5 | import torch 6 | 7 | 8 | class ResidualConv(nn.Module): 9 | def __init__(self, input_dim, output_dim, stride, padding): 10 | super(ResidualConv, self).__init__() 11 | 12 | self.conv_block = nn.Sequential( 13 | nn.BatchNorm2d(input_dim), 14 | nn.ReLU(), 15 | nn.Conv2d( 16 | input_dim, output_dim, kernel_size=3, stride=stride, padding=padding 17 | ), 18 | nn.BatchNorm2d(output_dim), 19 | nn.ReLU(), 20 | nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1), 21 | ) 22 | self.conv_skip = nn.Sequential( 23 | nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1), 24 | nn.BatchNorm2d(output_dim), 25 | ) 26 | 27 | def forward(self, x): 28 | 29 | return self.conv_block(x) + self.conv_skip(x) 30 | 31 | 32 | class Upsample(nn.Module): 33 | def __init__(self, input_dim, output_dim, kernel, stride): 34 | super(Upsample, self).__init__() 35 | 36 | self.upsample = nn.ConvTranspose2d( 37 | input_dim, output_dim, kernel_size=kernel, stride=stride 38 | ) 39 | 40 | def forward(self, x): 41 | return self.upsample(x) 42 | 43 | 44 | class Squeeze_Excite_Block(nn.Module): 45 | def __init__(self, channel, reduction=16): 46 | super(Squeeze_Excite_Block, self).__init__() 47 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 48 | self.fc = nn.Sequential( 49 | nn.Linear(channel, channel // reduction, bias=False), 50 | nn.ReLU(inplace=True), 51 | nn.Linear(channel // reduction, channel, bias=False), 52 | nn.Sigmoid(), 53 | ) 54 | 55 | def forward(self, x): 56 | b, c, _, _ = x.size() 57 | y = self.avg_pool(x).view(b, c) 58 | y = self.fc(y).view(b, c, 1, 1) 59 | return x * y.expand_as(x) 60 | 61 | 62 | class ASPP(nn.Module): 63 | def __init__(self, in_dims, out_dims, rate=[6, 12, 18]): 64 | super(ASPP, self).__init__() 65 | 66 | self.aspp_block1 = nn.Sequential( 67 | nn.Conv2d( 68 | in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0] 69 | ), 70 | nn.ReLU(inplace=True), 71 | nn.BatchNorm2d(out_dims), 72 | ) 73 | self.aspp_block2 = 
nn.Sequential( 74 | nn.Conv2d( 75 | in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1] 76 | ), 77 | nn.ReLU(inplace=True), 78 | nn.BatchNorm2d(out_dims), 79 | ) 80 | self.aspp_block3 = nn.Sequential( 81 | nn.Conv2d( 82 | in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2] 83 | ), 84 | nn.ReLU(inplace=True), 85 | nn.BatchNorm2d(out_dims), 86 | ) 87 | 88 | self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1) 89 | self._init_weights() 90 | 91 | def forward(self, x): 92 | x1 = self.aspp_block1(x) 93 | x2 = self.aspp_block2(x) 94 | x3 = self.aspp_block3(x) 95 | out = torch.cat([x1, x2, x3], dim=1) 96 | return self.output(out) 97 | 98 | def _init_weights(self): 99 | for m in self.modules(): 100 | if isinstance(m, nn.Conv2d): 101 | nn.init.kaiming_normal_(m.weight) 102 | elif isinstance(m, nn.BatchNorm2d): 103 | m.weight.data.fill_(1) 104 | m.bias.data.zero_() 105 | 106 | 107 | class Upsample_(nn.Module): 108 | def __init__(self, scale=2): 109 | super(Upsample_, self).__init__() 110 | 111 | self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale) 112 | 113 | def forward(self, x): 114 | return self.upsample(x) 115 | 116 | 117 | class AttentionBlock(nn.Module): 118 | def __init__(self, input_encoder, input_decoder, output_dim): 119 | super(AttentionBlock, self).__init__() 120 | 121 | self.conv_encoder = nn.Sequential( 122 | nn.BatchNorm2d(input_encoder), 123 | nn.ReLU(), 124 | nn.Conv2d(input_encoder, output_dim, 3, padding=1), 125 | nn.MaxPool2d(2, 2), 126 | ) 127 | 128 | self.conv_decoder = nn.Sequential( 129 | nn.BatchNorm2d(input_decoder), 130 | nn.ReLU(), 131 | nn.Conv2d(input_decoder, output_dim, 3, padding=1), 132 | ) 133 | 134 | self.conv_attn = nn.Sequential( 135 | nn.BatchNorm2d(output_dim), 136 | nn.ReLU(), 137 | nn.Conv2d(output_dim, 1, 1), 138 | ) 139 | 140 | def forward(self, x1, x2): 141 | out = self.conv_encoder(x1) + self.conv_decoder(x2) 142 | out = self.conv_attn(out) 143 | return out * x2 -------------------------------------------------------------------------------- /models/_resunet/res_unet.py: -------------------------------------------------------------------------------- 1 | # https://github.com/rishikksh20/ResUnet 2 | 3 | import torch 4 | import torch.nn as nn 5 | from .modules import ResidualConv, Upsample 6 | 7 | 8 | class ResUnet(nn.Module): 9 | def __init__(self, in_ch, out_ch, filters=[64, 128, 256, 512]): 10 | super(ResUnet, self).__init__() 11 | 12 | self.input_layer = nn.Sequential( 13 | nn.Conv2d(in_ch, filters[0], kernel_size=3, padding=1), 14 | nn.BatchNorm2d(filters[0]), 15 | nn.ReLU(), 16 | nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1), 17 | ) 18 | self.input_skip = nn.Sequential( 19 | nn.Conv2d(in_ch, filters[0], kernel_size=3, padding=1) 20 | ) 21 | 22 | self.residual_conv_1 = ResidualConv(filters[0], filters[1], 2, 1) 23 | self.residual_conv_2 = ResidualConv(filters[1], filters[2], 2, 1) 24 | 25 | self.bridge = ResidualConv(filters[2], filters[3], 2, 1) 26 | 27 | self.upsample_1 = Upsample(filters[3], filters[3], 2, 2) 28 | self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], 1, 1) 29 | 30 | self.upsample_2 = Upsample(filters[2], filters[2], 2, 2) 31 | self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], 1, 1) 32 | 33 | self.upsample_3 = Upsample(filters[1], filters[1], 2, 2) 34 | self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], 1, 1) 35 | 36 | self.output_layer = nn.Sequential( 37 | nn.Conv2d(filters[0], out_ch, 1, 1), 38 | ) 39 | 
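    # Illustrative usage sketch: three stride-2 residual stages and three 2x
    # transpose-conv upsamples keep the output at the input resolution, e.g.
    #   model = ResUnet(in_ch=3, out_ch=1)
    #   logits = model(torch.randn(1, 3, 224, 224))   # -> (1, 1, 224, 224)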
40 | def forward(self, x): 41 | # Encode 42 | x1 = self.input_layer(x) + self.input_skip(x) 43 | x2 = self.residual_conv_1(x1) 44 | x3 = self.residual_conv_2(x2) 45 | # Bridge 46 | x4 = self.bridge(x3) 47 | # Decode 48 | x4 = self.upsample_1(x4) 49 | x5 = torch.cat([x4, x3], dim=1) 50 | 51 | x6 = self.up_residual_conv1(x5) 52 | 53 | x6 = self.upsample_2(x6) 54 | x7 = torch.cat([x6, x2], dim=1) 55 | 56 | x8 = self.up_residual_conv2(x7) 57 | 58 | x8 = self.upsample_3(x8) 59 | x9 = torch.cat([x8, x1], dim=1) 60 | 61 | x10 = self.up_residual_conv3(x9) 62 | 63 | output = self.output_layer(x10) 64 | 65 | return output 66 | -------------------------------------------------------------------------------- /models/_transunet/vit_seg_configs.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | def get_b16_config(): 4 | """Returns the ViT-B/16 configuration.""" 5 | config = ml_collections.ConfigDict() 6 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 7 | config.hidden_size = 768 8 | config.transformer = ml_collections.ConfigDict() 9 | config.transformer.mlp_dim = 3072 10 | config.transformer.num_heads = 12 11 | config.transformer.num_layers = 12 12 | config.transformer.attention_dropout_rate = 0.0 13 | config.transformer.dropout_rate = 0.1 14 | 15 | config.classifier = 'seg' 16 | config.representation_size = None 17 | config.resnet_pretrained_path = None 18 | config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz' 19 | config.patch_size = 16 20 | 21 | config.decoder_channels = (256, 128, 64, 16) 22 | config.n_classes = 2 23 | config.activation = 'softmax' 24 | return config 25 | 26 | 27 | def get_testing(): 28 | """Returns a minimal configuration for testing.""" 29 | config = ml_collections.ConfigDict() 30 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 31 | config.hidden_size = 1 32 | config.transformer = ml_collections.ConfigDict() 33 | config.transformer.mlp_dim = 1 34 | config.transformer.num_heads = 1 35 | config.transformer.num_layers = 1 36 | config.transformer.attention_dropout_rate = 0.0 37 | config.transformer.dropout_rate = 0.1 38 | config.classifier = 'token' 39 | config.representation_size = None 40 | return config 41 | 42 | def get_r50_b16_config(): 43 | """Returns the Resnet50 + ViT-B/16 configuration.""" 44 | config = get_b16_config() 45 | config.patches.grid = (16, 16) 46 | config.resnet = ml_collections.ConfigDict() 47 | config.resnet.num_layers = (3, 4, 9) 48 | config.resnet.width_factor = 1 49 | 50 | config.classifier = 'seg' 51 | config.pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' 52 | config.decoder_channels = (256, 128, 64, 16) 53 | config.skip_channels = [512, 256, 64, 16] 54 | config.n_classes = 2 55 | config.n_skip = 3 56 | config.activation = 'softmax' 57 | 58 | return config 59 | 60 | 61 | def get_b32_config(): 62 | """Returns the ViT-B/32 configuration.""" 63 | config = get_b16_config() 64 | config.patches.size = (32, 32) 65 | config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz' 66 | return config 67 | 68 | 69 | def get_l16_config(): 70 | """Returns the ViT-L/16 configuration.""" 71 | config = ml_collections.ConfigDict() 72 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 73 | config.hidden_size = 1024 74 | config.transformer = ml_collections.ConfigDict() 75 | config.transformer.mlp_dim = 4096 76 | config.transformer.num_heads = 16 77 | config.transformer.num_layers = 24 78 | 
config.transformer.attention_dropout_rate = 0.0 79 | config.transformer.dropout_rate = 0.1 80 | config.representation_size = None 81 | 82 | # custom 83 | config.classifier = 'seg' 84 | config.resnet_pretrained_path = None 85 | config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz' 86 | config.decoder_channels = (256, 128, 64, 16) 87 | config.n_classes = 2 88 | config.activation = 'softmax' 89 | return config 90 | 91 | 92 | def get_r50_l16_config(): 93 | """Returns the Resnet50 + ViT-L/16 configuration. customized """ 94 | config = get_l16_config() 95 | config.patches.grid = (16, 16) 96 | config.resnet = ml_collections.ConfigDict() 97 | config.resnet.num_layers = (3, 4, 9) 98 | config.resnet.width_factor = 1 99 | 100 | config.classifier = 'seg' 101 | config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' 102 | config.decoder_channels = (256, 128, 64, 16) 103 | config.skip_channels = [512, 256, 64, 16] 104 | config.n_classes = 2 105 | config.activation = 'softmax' 106 | return config 107 | 108 | 109 | def get_l32_config(): 110 | """Returns the ViT-L/32 configuration.""" 111 | config = get_l16_config() 112 | config.patches.size = (32, 32) 113 | return config 114 | 115 | 116 | def get_h14_config(): 117 | """Returns the ViT-L/16 configuration.""" 118 | config = ml_collections.ConfigDict() 119 | config.patches = ml_collections.ConfigDict({'size': (14, 14)}) 120 | config.hidden_size = 1280 121 | config.transformer = ml_collections.ConfigDict() 122 | config.transformer.mlp_dim = 5120 123 | config.transformer.num_heads = 16 124 | config.transformer.num_layers = 32 125 | config.transformer.attention_dropout_rate = 0.0 126 | config.transformer.dropout_rate = 0.1 127 | config.classifier = 'token' 128 | config.representation_size = None 129 | 130 | return config 131 | -------------------------------------------------------------------------------- /models/_transunet/vit_seg_modeling_resnet_skip.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from os.path import join as pjoin 4 | from collections import OrderedDict 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | def np2th(weights, conv=False): 12 | """Possibly convert HWIO to OIHW.""" 13 | if conv: 14 | weights = weights.transpose([3, 2, 0, 1]) 15 | return torch.from_numpy(weights) 16 | 17 | 18 | class StdConv2d(nn.Conv2d): 19 | 20 | def forward(self, x): 21 | w = self.weight 22 | v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) 23 | w = (w - m) / torch.sqrt(v + 1e-5) 24 | return F.conv2d(x, w, self.bias, self.stride, self.padding, 25 | self.dilation, self.groups) 26 | 27 | 28 | def conv3x3(cin, cout, stride=1, groups=1, bias=False): 29 | return StdConv2d(cin, cout, kernel_size=3, stride=stride, 30 | padding=1, bias=bias, groups=groups) 31 | 32 | 33 | def conv1x1(cin, cout, stride=1, bias=False): 34 | return StdConv2d(cin, cout, kernel_size=1, stride=stride, 35 | padding=0, bias=bias) 36 | 37 | 38 | class PreActBottleneck(nn.Module): 39 | """Pre-activation (v2) bottleneck block. 40 | """ 41 | 42 | def __init__(self, cin, cout=None, cmid=None, stride=1): 43 | super().__init__() 44 | cout = cout or cin 45 | cmid = cmid or cout//4 46 | 47 | self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6) 48 | self.conv1 = conv1x1(cin, cmid, bias=False) 49 | self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6) 50 | self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!! 
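        # note: the stride-2 downsampling is applied on the 3x3 conv2 here (the comment
        # above flags that the original implementation strides conv1 instead), so conv1
        # and conv3 remain plain 1x1 projections.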
51 | self.gn3 = nn.GroupNorm(32, cout, eps=1e-6) 52 | self.conv3 = conv1x1(cmid, cout, bias=False) 53 | self.relu = nn.ReLU(inplace=True) 54 | 55 | if (stride != 1 or cin != cout): 56 | # Projection also with pre-activation according to paper. 57 | self.downsample = conv1x1(cin, cout, stride, bias=False) 58 | self.gn_proj = nn.GroupNorm(cout, cout) 59 | 60 | def forward(self, x): 61 | 62 | # Residual branch 63 | residual = x 64 | if hasattr(self, 'downsample'): 65 | residual = self.downsample(x) 66 | residual = self.gn_proj(residual) 67 | 68 | # Unit's branch 69 | y = self.relu(self.gn1(self.conv1(x))) 70 | y = self.relu(self.gn2(self.conv2(y))) 71 | y = self.gn3(self.conv3(y)) 72 | 73 | y = self.relu(residual + y) 74 | return y 75 | 76 | def load_from(self, weights, n_block, n_unit): 77 | conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True) 78 | conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True) 79 | conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True) 80 | 81 | gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")]) 82 | gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")]) 83 | 84 | gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")]) 85 | gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")]) 86 | 87 | gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")]) 88 | gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")]) 89 | 90 | self.conv1.weight.copy_(conv1_weight) 91 | self.conv2.weight.copy_(conv2_weight) 92 | self.conv3.weight.copy_(conv3_weight) 93 | 94 | self.gn1.weight.copy_(gn1_weight.view(-1)) 95 | self.gn1.bias.copy_(gn1_bias.view(-1)) 96 | 97 | self.gn2.weight.copy_(gn2_weight.view(-1)) 98 | self.gn2.bias.copy_(gn2_bias.view(-1)) 99 | 100 | self.gn3.weight.copy_(gn3_weight.view(-1)) 101 | self.gn3.bias.copy_(gn3_bias.view(-1)) 102 | 103 | if hasattr(self, 'downsample'): 104 | proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True) 105 | proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")]) 106 | proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")]) 107 | 108 | self.downsample.weight.copy_(proj_conv_weight) 109 | self.gn_proj.weight.copy_(proj_gn_weight.view(-1)) 110 | self.gn_proj.bias.copy_(proj_gn_bias.view(-1)) 111 | 112 | class ResNetV2(nn.Module): 113 | """Implementation of Pre-activation (v2) ResNet mode.""" 114 | 115 | def __init__(self, block_units, width_factor): 116 | super().__init__() 117 | width = int(64 * width_factor) 118 | self.width = width 119 | 120 | self.root = nn.Sequential(OrderedDict([ 121 | ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)), 122 | ('gn', nn.GroupNorm(32, width, eps=1e-6)), 123 | ('relu', nn.ReLU(inplace=True)), 124 | # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)) 125 | ])) 126 | 127 | self.body = nn.Sequential(OrderedDict([ 128 | ('block1', nn.Sequential(OrderedDict( 129 | [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] + 130 | [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)], 131 | ))), 132 | ('block2', nn.Sequential(OrderedDict( 133 | [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] + 134 | [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)], 135 | ))), 136 | ('block3', nn.Sequential(OrderedDict( 137 | [('unit1', 
PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] + 138 | [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)], 139 | ))), 140 | ])) 141 | 142 | def forward(self, x): 143 | features = [] 144 | b, c, in_size, _ = x.size() 145 | x = self.root(x) 146 | features.append(x) 147 | x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x) 148 | for i in range(len(self.body)-1): 149 | x = self.body[i](x) 150 | right_size = int(in_size / 4 / (i+1)) 151 | if x.size()[2] != right_size: 152 | pad = right_size - x.size()[2] 153 | assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size) 154 | feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device) 155 | feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:] 156 | else: 157 | feat = x 158 | features.append(feat) 159 | x = self.body[-1](x) 160 | return x, features[::-1] 161 | -------------------------------------------------------------------------------- /models/_transunet/vit_seg_modeling_resnet_skip_c4.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from os.path import join as pjoin 4 | from collections import OrderedDict 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | def np2th(weights, conv=False): 12 | """Possibly convert HWIO to OIHW.""" 13 | if conv: 14 | weights = weights.transpose([3, 2, 0, 1]) 15 | return torch.from_numpy(weights) 16 | 17 | 18 | class StdConv2d(nn.Conv2d): 19 | 20 | def forward(self, x): 21 | w = self.weight 22 | v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) 23 | w = (w - m) / torch.sqrt(v + 1e-5) 24 | return F.conv2d(x, w, self.bias, self.stride, self.padding, 25 | self.dilation, self.groups) 26 | 27 | 28 | def conv3x3(cin, cout, stride=1, groups=1, bias=False): 29 | return StdConv2d(cin, cout, kernel_size=3, stride=stride, 30 | padding=1, bias=bias, groups=groups) 31 | 32 | 33 | def conv1x1(cin, cout, stride=1, bias=False): 34 | return StdConv2d(cin, cout, kernel_size=1, stride=stride, 35 | padding=0, bias=bias) 36 | 37 | 38 | class PreActBottleneck(nn.Module): 39 | """Pre-activation (v2) bottleneck block. 40 | """ 41 | 42 | def __init__(self, cin, cout=None, cmid=None, stride=1): 43 | super().__init__() 44 | cout = cout or cin 45 | cmid = cmid or cout//4 46 | 47 | self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6) 48 | self.conv1 = conv1x1(cin, cmid, bias=False) 49 | self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6) 50 | self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!! 51 | self.gn3 = nn.GroupNorm(32, cout, eps=1e-6) 52 | self.conv3 = conv1x1(cmid, cout, bias=False) 53 | self.relu = nn.ReLU(inplace=True) 54 | 55 | if (stride != 1 or cin != cout): 56 | # Projection also with pre-activation according to paper. 
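            # the 1x1 projection shortcut (and its GroupNorm) is only created when the
            # stride or channel count changes, which is why forward() guards it with
            # hasattr(self, 'downsample').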
57 | self.downsample = conv1x1(cin, cout, stride, bias=False) 58 | self.gn_proj = nn.GroupNorm(cout, cout) 59 | 60 | def forward(self, x): 61 | 62 | # Residual branch 63 | residual = x 64 | if hasattr(self, 'downsample'): 65 | residual = self.downsample(x) 66 | residual = self.gn_proj(residual) 67 | 68 | # Unit's branch 69 | y = self.relu(self.gn1(self.conv1(x))) 70 | y = self.relu(self.gn2(self.conv2(y))) 71 | y = self.gn3(self.conv3(y)) 72 | 73 | y = self.relu(residual + y) 74 | return y 75 | 76 | def load_from(self, weights, n_block, n_unit): 77 | conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True) 78 | conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True) 79 | conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True) 80 | 81 | gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")]) 82 | gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")]) 83 | 84 | gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")]) 85 | gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")]) 86 | 87 | gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")]) 88 | gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")]) 89 | 90 | self.conv1.weight.copy_(conv1_weight) 91 | self.conv2.weight.copy_(conv2_weight) 92 | self.conv3.weight.copy_(conv3_weight) 93 | 94 | self.gn1.weight.copy_(gn1_weight.view(-1)) 95 | self.gn1.bias.copy_(gn1_bias.view(-1)) 96 | 97 | self.gn2.weight.copy_(gn2_weight.view(-1)) 98 | self.gn2.bias.copy_(gn2_bias.view(-1)) 99 | 100 | self.gn3.weight.copy_(gn3_weight.view(-1)) 101 | self.gn3.bias.copy_(gn3_bias.view(-1)) 102 | 103 | if hasattr(self, 'downsample'): 104 | proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True) 105 | proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")]) 106 | proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")]) 107 | 108 | self.downsample.weight.copy_(proj_conv_weight) 109 | self.gn_proj.weight.copy_(proj_gn_weight.view(-1)) 110 | self.gn_proj.bias.copy_(proj_gn_bias.view(-1)) 111 | 112 | class ResNetV2(nn.Module): 113 | """Implementation of Pre-activation (v2) ResNet mode.""" 114 | 115 | def __init__(self, block_units, width_factor): 116 | super().__init__() 117 | width = int(64 * width_factor) 118 | self.width = width 119 | 120 | self.root = nn.Sequential(OrderedDict([ 121 | ('conv', StdConv2d(4, width, kernel_size=7, stride=2, bias=False, padding=3)), 122 | ('gn', nn.GroupNorm(32, width, eps=1e-6)), 123 | ('relu', nn.ReLU(inplace=True)), 124 | # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)) 125 | ])) 126 | 127 | self.body = nn.Sequential(OrderedDict([ 128 | ('block1', nn.Sequential(OrderedDict( 129 | [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] + 130 | [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)], 131 | ))), 132 | ('block2', nn.Sequential(OrderedDict( 133 | [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] + 134 | [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)], 135 | ))), 136 | ('block3', nn.Sequential(OrderedDict( 137 | [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] + 138 | [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)], 139 | ))), 140 | ])) 141 | 142 | def forward(self, x): 
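        # returns the deepest feature map plus the intermediate skip features in
        # deepest-first order (features[::-1]); shallower features are zero-padded
        # whenever their spatial size falls short of the expected in_size / 4 / (i+1) grid.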
143 | features = [] 144 | b, c, in_size, _ = x.size() 145 | x = self.root(x) 146 | features.append(x) 147 | x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x) 148 | for i in range(len(self.body)-1): 149 | x = self.body[i](x) 150 | right_size = int(in_size / 4 / (i+1)) 151 | if x.size()[2] != right_size: 152 | pad = right_size - x.size()[2] 153 | assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size) 154 | feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device) 155 | feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:] 156 | else: 157 | feat = x 158 | features.append(feat) 159 | x = self.body[-1](x) 160 | return x, features[::-1] 161 | -------------------------------------------------------------------------------- /models/_uctransnet/Config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/6/19 2:44 下午 3 | # @Author : Haonan Wang 4 | # @File : Config.py 5 | # @Software: PyCharm 6 | import os 7 | import torch 8 | import time 9 | import ml_collections 10 | 11 | ## PARAMETERS OF THE MODEL 12 | save_model = True 13 | tensorboard = True 14 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 15 | use_cuda = torch.cuda.is_available() 16 | seed = 666 17 | os.environ['PYTHONHASHSEED'] = str(seed) 18 | 19 | cosineLR = True # whether use cosineLR or not 20 | n_channels = 3 21 | n_labels = 1 22 | epochs = 2000 23 | img_size = 224 24 | print_frequency = 1 25 | save_frequency = 5000 26 | vis_frequency = 10 27 | early_stopping_patience = 50 28 | 29 | pretrain = False 30 | task_name = 'MoNuSeg' # GlaS MoNuSeg 31 | # task_name = 'GlaS' 32 | learning_rate = 1e-3 33 | batch_size = 4 34 | 35 | 36 | # model_name = 'UCTransNet' 37 | model_name = 'UCTransNet_pretrain' 38 | 39 | train_dataset = './datasets/'+ task_name+ '/Train_Folder/' 40 | val_dataset = './datasets/'+ task_name+ '/Val_Folder/' 41 | test_dataset = './datasets/'+ task_name+ '/Test_Folder/' 42 | session_name = 'Test_session' + '_' + time.strftime('%m.%d_%Hh%M') 43 | save_path = task_name +'/'+ model_name +'/' + session_name + '/' 44 | model_path = save_path + 'models/' 45 | tensorboard_folder = save_path + 'tensorboard_logs/' 46 | logger_path = save_path + session_name + ".log" 47 | visualize_path = save_path + 'visualize_val/' 48 | 49 | 50 | ########################################################################## 51 | # CTrans configs 52 | ########################################################################## 53 | def get_CTranS_config(): 54 | config = ml_collections.ConfigDict() 55 | config.transformer = ml_collections.ConfigDict() 56 | config.KV_size = 960 # KV_size = Q1 + Q2 + Q3 + Q4 57 | config.transformer.num_heads = 4 58 | config.transformer.num_layers = 4 59 | config.expand_ratio = 4 # MLP channel dimension expand ratio 60 | config.transformer.embeddings_dropout_rate = 0.1 61 | config.transformer.attention_dropout_rate = 0.1 62 | config.transformer.dropout_rate = 0 63 | config.patch_sizes = [16,8,4,2] 64 | config.base_channel = 64 # base channel of U-Net 65 | config.n_classes = 1 66 | return config 67 | 68 | 69 | 70 | 71 | # used in testing phase, copy the session name in training phase 72 | test_session = "Test_session_07.03_20h39" -------------------------------------------------------------------------------- /models/_uctransnet/UCTransNet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/7/8 8:59 上午 3 | # @File : UCTransNet.py 4 | # 
@Software: PyCharm 5 | import torch.nn as nn 6 | import torch 7 | import torch.nn.functional as F 8 | from .CTrans import ChannelTransformer 9 | 10 | def get_activation(activation_type): 11 | activation_type = activation_type.lower() 12 | if hasattr(nn, activation_type): 13 | return getattr(nn, activation_type)() 14 | else: 15 | return nn.ReLU() 16 | 17 | def _make_nConv(in_channels, out_channels, nb_Conv, activation='ReLU'): 18 | layers = [] 19 | layers.append(ConvBatchNorm(in_channels, out_channels, activation)) 20 | 21 | for _ in range(nb_Conv - 1): 22 | layers.append(ConvBatchNorm(out_channels, out_channels, activation)) 23 | return nn.Sequential(*layers) 24 | 25 | class ConvBatchNorm(nn.Module): 26 | """(convolution => [BN] => ReLU)""" 27 | 28 | def __init__(self, in_channels, out_channels, activation='ReLU'): 29 | super(ConvBatchNorm, self).__init__() 30 | self.conv = nn.Conv2d(in_channels, out_channels, 31 | kernel_size=3, padding=1) 32 | self.norm = nn.BatchNorm2d(out_channels) 33 | self.activation = get_activation(activation) 34 | 35 | def forward(self, x): 36 | out = self.conv(x) 37 | out = self.norm(out) 38 | return self.activation(out) 39 | 40 | class DownBlock(nn.Module): 41 | """Downscaling with maxpool convolution""" 42 | def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'): 43 | super(DownBlock, self).__init__() 44 | self.maxpool = nn.MaxPool2d(2) 45 | self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation) 46 | 47 | def forward(self, x): 48 | out = self.maxpool(x) 49 | return self.nConvs(out) 50 | 51 | class Flatten(nn.Module): 52 | def forward(self, x): 53 | return x.view(x.size(0), -1) 54 | 55 | class CCA(nn.Module): 56 | """ 57 | CCA Block 58 | """ 59 | def __init__(self, F_g, F_x): 60 | super().__init__() 61 | self.mlp_x = nn.Sequential( 62 | Flatten(), 63 | nn.Linear(F_x, F_x)) 64 | self.mlp_g = nn.Sequential( 65 | Flatten(), 66 | nn.Linear(F_g, F_x)) 67 | self.relu = nn.ReLU(inplace=True) 68 | 69 | def forward(self, g, x): 70 | # channel-wise attention 71 | avg_pool_x = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 72 | channel_att_x = self.mlp_x(avg_pool_x) 73 | avg_pool_g = F.avg_pool2d( g, (g.size(2), g.size(3)), stride=(g.size(2), g.size(3))) 74 | channel_att_g = self.mlp_g(avg_pool_g) 75 | channel_att_sum = (channel_att_x + channel_att_g)/2.0 76 | scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x) 77 | x_after_channel = x * scale 78 | out = self.relu(x_after_channel) 79 | return out 80 | 81 | class UpBlock_attention(nn.Module): 82 | def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'): 83 | super().__init__() 84 | self.up = nn.Upsample(scale_factor=2) 85 | self.coatt = CCA(F_g=in_channels//2, F_x=in_channels//2) 86 | self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation) 87 | 88 | def forward(self, x, skip_x): 89 | up = self.up(x) 90 | skip_x_att = self.coatt(g=up, x=skip_x) 91 | x = torch.cat([skip_x_att, up], dim=1) # dim 1 is the channel dimension 92 | return self.nConvs(x) 93 | 94 | class UCTransNet(nn.Module): 95 | def __init__(self, config,n_channels=3, n_classes=1,img_size=224,vis=False): 96 | super().__init__() 97 | self.vis = vis 98 | self.n_channels = n_channels 99 | self.n_classes = n_classes 100 | in_channels = config.base_channel 101 | self.inc = ConvBatchNorm(n_channels, in_channels) 102 | self.down1 = DownBlock(in_channels, in_channels*2, nb_Conv=2) 103 | self.down2 = DownBlock(in_channels*2, in_channels*4, nb_Conv=2) 
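        # with config.base_channel = 64 the encoder stages yield 64/128/256/512/512
        # channels; the ChannelTransformer (self.mtc) defined below refines the first
        # four skip connections before the CCA up-blocks fuse them.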
104 | self.down3 = DownBlock(in_channels*4, in_channels*8, nb_Conv=2) 105 | self.down4 = DownBlock(in_channels*8, in_channels*8, nb_Conv=2) 106 | self.mtc = ChannelTransformer(config, vis, img_size, 107 | channel_num=[in_channels, in_channels*2, in_channels*4, in_channels*8], 108 | patchSize=config.patch_sizes) 109 | self.up4 = UpBlock_attention(in_channels*16, in_channels*4, nb_Conv=2) 110 | self.up3 = UpBlock_attention(in_channels*8, in_channels*2, nb_Conv=2) 111 | self.up2 = UpBlock_attention(in_channels*4, in_channels, nb_Conv=2) 112 | self.up1 = UpBlock_attention(in_channels*2, in_channels, nb_Conv=2) 113 | self.outc = nn.Conv2d(in_channels, n_classes, kernel_size=(1,1), stride=(1,1)) 114 | self.last_activation = nn.Sigmoid() # if using BCELoss 115 | 116 | def forward(self, x): 117 | x = x.float() 118 | x1 = self.inc(x) 119 | x2 = self.down1(x1) 120 | x3 = self.down2(x2) 121 | x4 = self.down3(x3) 122 | x5 = self.down4(x4) 123 | x1,x2,x3,x4,att_weights = self.mtc(x1,x2,x3,x4) 124 | x = self.up4(x5, x4) 125 | x = self.up3(x, x3) 126 | x = self.up2(x, x2) 127 | x = self.up1(x, x1) 128 | if self.n_classes ==1: 129 | logits = self.last_activation(self.outc(x)) 130 | else: 131 | logits = self.outc(x) # if nusing BCEWithLogitsLoss or class>1 132 | if self.vis: # visualize the attention maps 133 | return logits, att_weights 134 | else: 135 | return logits 136 | 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /models/_uctransnet/UNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | def get_activation(activation_type): 5 | activation_type = activation_type.lower() 6 | if hasattr(nn, activation_type): 7 | return getattr(nn, activation_type)() 8 | else: 9 | return nn.ReLU() 10 | 11 | def _make_nConv(in_channels, out_channels, nb_Conv, activation='ReLU'): 12 | layers = [] 13 | layers.append(ConvBatchNorm(in_channels, out_channels, activation)) 14 | 15 | for _ in range(nb_Conv - 1): 16 | layers.append(ConvBatchNorm(out_channels, out_channels, activation)) 17 | return nn.Sequential(*layers) 18 | 19 | class ConvBatchNorm(nn.Module): 20 | """(convolution => [BN] => ReLU)""" 21 | 22 | def __init__(self, in_channels, out_channels, activation='ReLU'): 23 | super(ConvBatchNorm, self).__init__() 24 | self.conv = nn.Conv2d(in_channels, out_channels, 25 | kernel_size=3, padding=1) 26 | self.norm = nn.BatchNorm2d(out_channels) 27 | self.activation = get_activation(activation) 28 | 29 | def forward(self, x): 30 | out = self.conv(x) 31 | out = self.norm(out) 32 | return self.activation(out) 33 | 34 | class DownBlock(nn.Module): 35 | """Downscaling with maxpool convolution""" 36 | 37 | def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'): 38 | super(DownBlock, self).__init__() 39 | self.maxpool = nn.MaxPool2d(2) 40 | self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation) 41 | 42 | def forward(self, x): 43 | out = self.maxpool(x) 44 | return self.nConvs(out) 45 | 46 | class UpBlock(nn.Module): 47 | """Upscaling then conv""" 48 | 49 | def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'): 50 | super(UpBlock, self).__init__() 51 | 52 | # self.up = nn.Upsample(scale_factor=2) 53 | self.up = nn.ConvTranspose2d(in_channels//2,in_channels//2,(2,2),2) 54 | self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation) 55 | 56 | def forward(self, x, skip_x): 57 | out = self.up(x) 58 | x = torch.cat([out, 
skip_x], dim=1) # dim 1 is the channel dimension 59 | return self.nConvs(x) 60 | 61 | class UNet(nn.Module): 62 | def __init__(self, n_channels=3, n_classes=9): 63 | ''' 64 | n_channels : number of channels of the input. 65 | By default 3, because we have RGB images 66 | n_labels : number of channels of the ouput. 67 | By default 3 (2 labels + 1 for the background) 68 | ''' 69 | super().__init__() 70 | self.n_channels = n_channels 71 | self.n_classes = n_classes 72 | # Question here 73 | in_channels = 64 74 | self.inc = ConvBatchNorm(n_channels, in_channels) 75 | self.down1 = DownBlock(in_channels, in_channels*2, nb_Conv=2) 76 | self.down2 = DownBlock(in_channels*2, in_channels*4, nb_Conv=2) 77 | self.down3 = DownBlock(in_channels*4, in_channels*8, nb_Conv=2) 78 | self.down4 = DownBlock(in_channels*8, in_channels*8, nb_Conv=2) 79 | self.up4 = UpBlock(in_channels*16, in_channels*4, nb_Conv=2) 80 | self.up3 = UpBlock(in_channels*8, in_channels*2, nb_Conv=2) 81 | self.up2 = UpBlock(in_channels*4, in_channels, nb_Conv=2) 82 | self.up1 = UpBlock(in_channels*2, in_channels, nb_Conv=2) 83 | self.outc = nn.Conv2d(in_channels, n_classes, kernel_size=(1,1)) 84 | if n_classes == 1: 85 | self.last_activation = nn.Sigmoid() 86 | else: 87 | self.last_activation = None 88 | 89 | def forward(self, x): 90 | # Question here 91 | x = x.float() 92 | x1 = self.inc(x) 93 | x2 = self.down1(x1) 94 | x3 = self.down2(x2) 95 | x4 = self.down3(x3) 96 | x5 = self.down4(x4) 97 | x = self.up4(x5, x4) 98 | x = self.up3(x, x3) 99 | x = self.up2(x, x2) 100 | x = self.up1(x, x1) 101 | if self.last_activation is not None: 102 | logits = self.last_activation(self.outc(x)) 103 | # print("111") 104 | else: 105 | logits = self.outc(x) 106 | # print("222") 107 | # logits = self.outc(x) # if using BCEWithLogitsLoss 108 | # print(logits.size()) 109 | return logits 110 | 111 | 112 | -------------------------------------------------------------------------------- /models/attunet.py: -------------------------------------------------------------------------------- 1 | # https://github.com/LeeJunHyun/Image_Segmentation 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import init 7 | 8 | def init_weights(net, init_type='normal', gain=0.02): 9 | def init_func(m): 10 | classname = m.__class__.__name__ 11 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 12 | if init_type == 'normal': 13 | init.normal_(m.weight.data, 0.0, gain) 14 | elif init_type == 'xavier': 15 | init.xavier_normal_(m.weight.data, gain=gain) 16 | elif init_type == 'kaiming': 17 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 18 | elif init_type == 'orthogonal': 19 | init.orthogonal_(m.weight.data, gain=gain) 20 | else: 21 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 22 | if hasattr(m, 'bias') and m.bias is not None: 23 | init.constant_(m.bias.data, 0.0) 24 | elif classname.find('BatchNorm2d') != -1: 25 | init.normal_(m.weight.data, 1.0, gain) 26 | init.constant_(m.bias.data, 0.0) 27 | 28 | print('initialize network with %s' % init_type) 29 | net.apply(init_func) 30 | 31 | class conv_block(nn.Module): 32 | def __init__(self,ch_in,ch_out): 33 | super(conv_block,self).__init__() 34 | self.conv = nn.Sequential( 35 | nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True), 36 | nn.BatchNorm2d(ch_out), 37 | nn.ReLU(inplace=True), 38 | nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True), 39 | 
nn.BatchNorm2d(ch_out), 40 | nn.ReLU(inplace=True) 41 | ) 42 | 43 | 44 | def forward(self,x): 45 | x = self.conv(x) 46 | return x 47 | 48 | class up_conv(nn.Module): 49 | def __init__(self,ch_in,ch_out): 50 | super(up_conv,self).__init__() 51 | self.up = nn.Sequential( 52 | nn.Upsample(scale_factor=2), 53 | nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True), 54 | nn.BatchNorm2d(ch_out), 55 | nn.ReLU(inplace=True) 56 | ) 57 | 58 | def forward(self,x): 59 | x = self.up(x) 60 | return x 61 | 62 | class Recurrent_block(nn.Module): 63 | def __init__(self,ch_out,t=2): 64 | super(Recurrent_block,self).__init__() 65 | self.t = t 66 | self.ch_out = ch_out 67 | self.conv = nn.Sequential( 68 | nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1,bias=True), 69 | nn.BatchNorm2d(ch_out), 70 | nn.ReLU(inplace=True) 71 | ) 72 | 73 | def forward(self,x): 74 | for i in range(self.t): 75 | 76 | if i==0: 77 | x1 = self.conv(x) 78 | 79 | x1 = self.conv(x+x1) 80 | return x1 81 | 82 | class RRCNN_block(nn.Module): 83 | def __init__(self,ch_in,ch_out,t=2): 84 | super(RRCNN_block,self).__init__() 85 | self.RCNN = nn.Sequential( 86 | Recurrent_block(ch_out,t=t), 87 | Recurrent_block(ch_out,t=t) 88 | ) 89 | self.Conv_1x1 = nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=1,padding=0) 90 | 91 | def forward(self,x): 92 | x = self.Conv_1x1(x) 93 | x1 = self.RCNN(x) 94 | return x+x1 95 | 96 | 97 | class single_conv(nn.Module): 98 | def __init__(self,ch_in,ch_out): 99 | super(single_conv,self).__init__() 100 | self.conv = nn.Sequential( 101 | nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True), 102 | nn.BatchNorm2d(ch_out), 103 | nn.ReLU(inplace=True) 104 | ) 105 | 106 | def forward(self,x): 107 | x = self.conv(x) 108 | return x 109 | 110 | class Attention_block(nn.Module): 111 | def __init__(self,F_g,F_l,F_int): 112 | super(Attention_block,self).__init__() 113 | self.W_g = nn.Sequential( 114 | nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True), 115 | nn.BatchNorm2d(F_int) 116 | ) 117 | 118 | self.W_x = nn.Sequential( 119 | nn.Conv2d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True), 120 | nn.BatchNorm2d(F_int) 121 | ) 122 | 123 | self.psi = nn.Sequential( 124 | nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True), 125 | nn.BatchNorm2d(1), 126 | nn.Sigmoid() 127 | ) 128 | 129 | self.relu = nn.ReLU(inplace=True) 130 | 131 | def forward(self,g,x): 132 | g1 = self.W_g(g) 133 | x1 = self.W_x(x) 134 | psi = self.relu(g1+x1) 135 | psi = self.psi(psi) 136 | 137 | return x*psi 138 | 139 | 140 | class U_Net(nn.Module): 141 | def __init__(self,img_ch=3,output_ch=1): 142 | super(U_Net,self).__init__() 143 | 144 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) 145 | 146 | self.Conv1 = conv_block(ch_in=img_ch,ch_out=64) 147 | self.Conv2 = conv_block(ch_in=64,ch_out=128) 148 | self.Conv3 = conv_block(ch_in=128,ch_out=256) 149 | self.Conv4 = conv_block(ch_in=256,ch_out=512) 150 | self.Conv5 = conv_block(ch_in=512,ch_out=1024) 151 | 152 | self.Up5 = up_conv(ch_in=1024,ch_out=512) 153 | self.Up_conv5 = conv_block(ch_in=1024, ch_out=512) 154 | 155 | self.Up4 = up_conv(ch_in=512,ch_out=256) 156 | self.Up_conv4 = conv_block(ch_in=512, ch_out=256) 157 | 158 | self.Up3 = up_conv(ch_in=256,ch_out=128) 159 | self.Up_conv3 = conv_block(ch_in=256, ch_out=128) 160 | 161 | self.Up2 = up_conv(ch_in=128,ch_out=64) 162 | self.Up_conv2 = conv_block(ch_in=128, ch_out=64) 163 | 164 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0) 165 | 166 | 167 | def 
forward(self,x): 168 | # encoding path 169 | x1 = self.Conv1(x) 170 | 171 | x2 = self.Maxpool(x1) 172 | x2 = self.Conv2(x2) 173 | 174 | x3 = self.Maxpool(x2) 175 | x3 = self.Conv3(x3) 176 | 177 | x4 = self.Maxpool(x3) 178 | x4 = self.Conv4(x4) 179 | 180 | x5 = self.Maxpool(x4) 181 | x5 = self.Conv5(x5) 182 | 183 | # decoding + concat path 184 | d5 = self.Up5(x5) 185 | d5 = torch.cat((x4,d5),dim=1) 186 | 187 | d5 = self.Up_conv5(d5) 188 | 189 | d4 = self.Up4(d5) 190 | d4 = torch.cat((x3,d4),dim=1) 191 | d4 = self.Up_conv4(d4) 192 | 193 | d3 = self.Up3(d4) 194 | d3 = torch.cat((x2,d3),dim=1) 195 | d3 = self.Up_conv3(d3) 196 | 197 | d2 = self.Up2(d3) 198 | d2 = torch.cat((x1,d2),dim=1) 199 | d2 = self.Up_conv2(d2) 200 | 201 | d1 = self.Conv_1x1(d2) 202 | 203 | return d1 204 | 205 | 206 | class R2U_Net(nn.Module): 207 | def __init__(self,img_ch=3,output_ch=1,t=2): 208 | super(R2U_Net,self).__init__() 209 | 210 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) 211 | self.Upsample = nn.Upsample(scale_factor=2) 212 | 213 | self.RRCNN1 = RRCNN_block(ch_in=img_ch,ch_out=64,t=t) 214 | 215 | self.RRCNN2 = RRCNN_block(ch_in=64,ch_out=128,t=t) 216 | 217 | self.RRCNN3 = RRCNN_block(ch_in=128,ch_out=256,t=t) 218 | 219 | self.RRCNN4 = RRCNN_block(ch_in=256,ch_out=512,t=t) 220 | 221 | self.RRCNN5 = RRCNN_block(ch_in=512,ch_out=1024,t=t) 222 | 223 | 224 | self.Up5 = up_conv(ch_in=1024,ch_out=512) 225 | self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512,t=t) 226 | 227 | self.Up4 = up_conv(ch_in=512,ch_out=256) 228 | self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256,t=t) 229 | 230 | self.Up3 = up_conv(ch_in=256,ch_out=128) 231 | self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128,t=t) 232 | 233 | self.Up2 = up_conv(ch_in=128,ch_out=64) 234 | self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64,t=t) 235 | 236 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0) 237 | 238 | 239 | def forward(self,x): 240 | # encoding path 241 | x1 = self.RRCNN1(x) 242 | 243 | x2 = self.Maxpool(x1) 244 | x2 = self.RRCNN2(x2) 245 | 246 | x3 = self.Maxpool(x2) 247 | x3 = self.RRCNN3(x3) 248 | 249 | x4 = self.Maxpool(x3) 250 | x4 = self.RRCNN4(x4) 251 | 252 | x5 = self.Maxpool(x4) 253 | x5 = self.RRCNN5(x5) 254 | 255 | # decoding + concat path 256 | d5 = self.Up5(x5) 257 | d5 = torch.cat((x4,d5),dim=1) 258 | d5 = self.Up_RRCNN5(d5) 259 | 260 | d4 = self.Up4(d5) 261 | d4 = torch.cat((x3,d4),dim=1) 262 | d4 = self.Up_RRCNN4(d4) 263 | 264 | d3 = self.Up3(d4) 265 | d3 = torch.cat((x2,d3),dim=1) 266 | d3 = self.Up_RRCNN3(d3) 267 | 268 | d2 = self.Up2(d3) 269 | d2 = torch.cat((x1,d2),dim=1) 270 | d2 = self.Up_RRCNN2(d2) 271 | 272 | d1 = self.Conv_1x1(d2) 273 | 274 | return d1 275 | 276 | 277 | 278 | class AttU_Net(nn.Module): 279 | def __init__(self,img_ch=3,output_ch=1): 280 | super(AttU_Net,self).__init__() 281 | 282 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) 283 | 284 | self.Conv1 = conv_block(ch_in=img_ch,ch_out=64) 285 | self.Conv2 = conv_block(ch_in=64,ch_out=128) 286 | self.Conv3 = conv_block(ch_in=128,ch_out=256) 287 | self.Conv4 = conv_block(ch_in=256,ch_out=512) 288 | self.Conv5 = conv_block(ch_in=512,ch_out=1024) 289 | 290 | self.Up5 = up_conv(ch_in=1024,ch_out=512) 291 | self.Att5 = Attention_block(F_g=512,F_l=512,F_int=256) 292 | self.Up_conv5 = conv_block(ch_in=1024, ch_out=512) 293 | 294 | self.Up4 = up_conv(ch_in=512,ch_out=256) 295 | self.Att4 = Attention_block(F_g=256,F_l=256,F_int=128) 296 | self.Up_conv4 = conv_block(ch_in=512, ch_out=256) 297 | 298 | self.Up3 = 
up_conv(ch_in=256,ch_out=128) 299 | self.Att3 = Attention_block(F_g=128,F_l=128,F_int=64) 300 | self.Up_conv3 = conv_block(ch_in=256, ch_out=128) 301 | 302 | self.Up2 = up_conv(ch_in=128,ch_out=64) 303 | self.Att2 = Attention_block(F_g=64,F_l=64,F_int=32) 304 | self.Up_conv2 = conv_block(ch_in=128, ch_out=64) 305 | 306 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0) 307 | 308 | 309 | def forward(self,x): 310 | # encoding path 311 | x1 = self.Conv1(x) 312 | 313 | x2 = self.Maxpool(x1) 314 | x2 = self.Conv2(x2) 315 | 316 | x3 = self.Maxpool(x2) 317 | x3 = self.Conv3(x3) 318 | 319 | x4 = self.Maxpool(x3) 320 | x4 = self.Conv4(x4) 321 | 322 | x5 = self.Maxpool(x4) 323 | x5 = self.Conv5(x5) 324 | 325 | # decoding + concat path 326 | d5 = self.Up5(x5) 327 | x4 = self.Att5(g=d5,x=x4) 328 | d5 = torch.cat((x4,d5),dim=1) 329 | d5 = self.Up_conv5(d5) 330 | 331 | d4 = self.Up4(d5) 332 | x3 = self.Att4(g=d4,x=x3) 333 | d4 = torch.cat((x3,d4),dim=1) 334 | d4 = self.Up_conv4(d4) 335 | 336 | d3 = self.Up3(d4) 337 | x2 = self.Att3(g=d3,x=x2) 338 | d3 = torch.cat((x2,d3),dim=1) 339 | d3 = self.Up_conv3(d3) 340 | 341 | d2 = self.Up2(d3) 342 | x1 = self.Att2(g=d2,x=x1) 343 | d2 = torch.cat((x1,d2),dim=1) 344 | d2 = self.Up_conv2(d2) 345 | 346 | d1 = self.Conv_1x1(d2) 347 | 348 | return d1 349 | 350 | 351 | class R2AttU_Net(nn.Module): 352 | def __init__(self,img_ch=3,output_ch=1,t=2): 353 | super(R2AttU_Net,self).__init__() 354 | 355 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) 356 | self.Upsample = nn.Upsample(scale_factor=2) 357 | 358 | self.RRCNN1 = RRCNN_block(ch_in=img_ch,ch_out=64,t=t) 359 | 360 | self.RRCNN2 = RRCNN_block(ch_in=64,ch_out=128,t=t) 361 | 362 | self.RRCNN3 = RRCNN_block(ch_in=128,ch_out=256,t=t) 363 | 364 | self.RRCNN4 = RRCNN_block(ch_in=256,ch_out=512,t=t) 365 | 366 | self.RRCNN5 = RRCNN_block(ch_in=512,ch_out=1024,t=t) 367 | 368 | 369 | self.Up5 = up_conv(ch_in=1024,ch_out=512) 370 | self.Att5 = Attention_block(F_g=512,F_l=512,F_int=256) 371 | self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512,t=t) 372 | 373 | self.Up4 = up_conv(ch_in=512,ch_out=256) 374 | self.Att4 = Attention_block(F_g=256,F_l=256,F_int=128) 375 | self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256,t=t) 376 | 377 | self.Up3 = up_conv(ch_in=256,ch_out=128) 378 | self.Att3 = Attention_block(F_g=128,F_l=128,F_int=64) 379 | self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128,t=t) 380 | 381 | self.Up2 = up_conv(ch_in=128,ch_out=64) 382 | self.Att2 = Attention_block(F_g=64,F_l=64,F_int=32) 383 | self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64,t=t) 384 | 385 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0) 386 | 387 | 388 | def forward(self,x): 389 | # encoding path 390 | x1 = self.RRCNN1(x) 391 | 392 | x2 = self.Maxpool(x1) 393 | x2 = self.RRCNN2(x2) 394 | 395 | x3 = self.Maxpool(x2) 396 | x3 = self.RRCNN3(x3) 397 | 398 | x4 = self.Maxpool(x3) 399 | x4 = self.RRCNN4(x4) 400 | 401 | x5 = self.Maxpool(x4) 402 | x5 = self.RRCNN5(x5) 403 | 404 | # decoding + concat path 405 | d5 = self.Up5(x5) 406 | x4 = self.Att5(g=d5,x=x4) 407 | d5 = torch.cat((x4,d5),dim=1) 408 | d5 = self.Up_RRCNN5(d5) 409 | 410 | d4 = self.Up4(d5) 411 | x3 = self.Att4(g=d4,x=x3) 412 | d4 = torch.cat((x3,d4),dim=1) 413 | d4 = self.Up_RRCNN4(d4) 414 | 415 | d3 = self.Up3(d4) 416 | x2 = self.Att3(g=d3,x=x2) 417 | d3 = torch.cat((x2,d3),dim=1) 418 | d3 = self.Up_RRCNN3(d3) 419 | 420 | d2 = self.Up2(d3) 421 | x1 = self.Att2(g=d2,x=x1) 422 | d2 = torch.cat((x1,d2),dim=1) 423 | d2 = 
self.Up_RRCNN2(d2) 424 | 425 | d1 = self.Conv_1x1(d2) 426 | 427 | return d1 -------------------------------------------------------------------------------- /models/multiresunet.py: -------------------------------------------------------------------------------- 1 | # https://github.com/j-sripad/mulitresunet-pytorch/blob/main/multiresunet.py 2 | 3 | from typing import Tuple, Dict 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch 7 | 8 | 9 | class Multiresblock(nn.Module): 10 | def __init__(self,input_features : int, corresponding_unet_filters : int ,alpha : float =1.67)->None: 11 | """ 12 | MultiResblock 13 | Arguments: 14 | x - input layer 15 | corresponding_unet_filters - Unet filters for the same stage 16 | alpha - 1.67 - factor used in the paper to dervie number of filters for multiresunet filters from Unet filters 17 | Returns - None 18 | """ 19 | super().__init__() 20 | self.corresponding_unet_filters = corresponding_unet_filters 21 | self.alpha = alpha 22 | self.W = corresponding_unet_filters * alpha 23 | self.conv2d_bn_1x1 = Conv2d_batchnorm(input_features=input_features,num_of_filters = int(self.W*0.167)+int(self.W*0.333)+int(self.W*0.5), 24 | kernel_size = (1,1),activation='None',padding = 0) 25 | 26 | self.conv2d_bn_3x3 = Conv2d_batchnorm(input_features=input_features,num_of_filters = int(self.W*0.167), 27 | kernel_size = (3,3),activation='relu',padding = 1) 28 | self.conv2d_bn_5x5 = Conv2d_batchnorm(input_features=int(self.W*0.167),num_of_filters = int(self.W*0.333), 29 | kernel_size = (3,3),activation='relu',padding = 1) 30 | self.conv2d_bn_7x7 = Conv2d_batchnorm(input_features=int(self.W*0.333),num_of_filters = int(self.W*0.5), 31 | kernel_size = (3,3),activation='relu',padding = 1) 32 | self.batch_norm1 = nn.BatchNorm2d(int(self.W*0.5)+int(self.W*0.167)+int(self.W*0.333) ,affine=False) 33 | 34 | def forward(self,x: torch.Tensor)->torch.Tensor: 35 | 36 | temp = self.conv2d_bn_1x1(x) 37 | a = self.conv2d_bn_3x3(x) 38 | b = self.conv2d_bn_5x5(a) 39 | c = self.conv2d_bn_7x7(b) 40 | x = torch.cat([a,b,c],axis=1) 41 | x = self.batch_norm1(x) 42 | x = x + temp 43 | x = self.batch_norm1(x) 44 | return x 45 | 46 | class Conv2d_batchnorm(nn.Module): 47 | def __init__(self,input_features : int,num_of_filters : int ,kernel_size : Tuple = (2,2),stride : Tuple = (1,1), activation : str = 'relu',padding : int= 0)->None: 48 | """ 49 | Arguments: 50 | x - input layer 51 | num_of_filters - no. 
of filter outputs 52 | filters - shape of the filters to be used 53 | stride - stride dimension 54 | activation -activation function to be used 55 | Returns - None 56 | """ 57 | super().__init__() 58 | self.activation = activation 59 | self.conv1 = nn.Conv2d(in_channels=input_features,out_channels=num_of_filters,kernel_size=kernel_size,stride=stride,padding = padding) 60 | self.batchnorm = nn.BatchNorm2d(num_of_filters,affine=False) 61 | 62 | def forward(self,x : torch.Tensor)->torch.Tensor: 63 | x = self.conv1(x) 64 | x = self.batchnorm(x) 65 | if self.activation == 'relu': 66 | return F.relu(x) 67 | else: 68 | return x 69 | 70 | 71 | class Respath(nn.Module): 72 | def __init__(self,input_features : int,filters : int,respath_length : int)->None: 73 | """ 74 | Arguments: 75 | input_features - input layer filters 76 | filters - output channels 77 | respath_length - length of the Respath 78 | 79 | Returns - None 80 | """ 81 | super().__init__() 82 | self.filters = filters 83 | self.respath_length = respath_length 84 | self.conv2d_bn_1x1 = Conv2d_batchnorm(input_features=input_features,num_of_filters = self.filters, 85 | kernel_size = (1,1),activation='None',padding = 0) 86 | self.conv2d_bn_3x3 = Conv2d_batchnorm(input_features=input_features,num_of_filters = self.filters, 87 | kernel_size = (3,3),activation='relu',padding = 1) 88 | self.conv2d_bn_1x1_common = Conv2d_batchnorm(input_features=self.filters,num_of_filters = self.filters, 89 | kernel_size = (1,1),activation='None',padding = 0) 90 | self.conv2d_bn_3x3_common = Conv2d_batchnorm(input_features=self.filters,num_of_filters = self.filters, 91 | kernel_size = (3,3),activation='relu',padding = 1) 92 | self.batch_norm1 = nn.BatchNorm2d(filters,affine=False) 93 | 94 | def forward(self,x : torch.Tensor)->torch.Tensor: 95 | shortcut = self.conv2d_bn_1x1(x) 96 | x = self.conv2d_bn_3x3(x) 97 | x = x + shortcut 98 | x = F.relu(x) 99 | x = self.batch_norm1(x) 100 | if self.respath_length>1: 101 | for i in range(self.respath_length): 102 | shortcut = self.conv2d_bn_1x1_common(x) 103 | x = self.conv2d_bn_3x3_common(x) 104 | x = x + shortcut 105 | x = F.relu(x) 106 | x = self.batch_norm1(x) 107 | return x 108 | else: 109 | return x 110 | 111 | class MultiResUnet(nn.Module): 112 | def __init__(self,channels : int,filters : int =32,nclasses : int =1)->None: 113 | 114 | """ 115 | Arguments: 116 | channels - input image channels 117 | filters - filters to begin with (Unet) 118 | nclasses - number of classes 119 | Returns - None 120 | """ 121 | super().__init__() 122 | self.alpha = 1.67 123 | self.filters = filters 124 | self.nclasses = nclasses 125 | self.multiresblock1 = Multiresblock(input_features=channels,corresponding_unet_filters=self.filters) 126 | self.pool1 = nn.MaxPool2d(2,stride= 2) 127 | self.in_filters1 = int(self.filters*self.alpha* 0.5)+int(self.filters*self.alpha*0.167)+int(self.filters*self.alpha*0.333) 128 | self.respath1 = Respath(input_features=self.in_filters1 ,filters=self.filters,respath_length=4) 129 | self.multiresblock2 = Multiresblock(input_features= self.in_filters1,corresponding_unet_filters=self.filters*2) 130 | self.pool2 = nn.MaxPool2d(2, 2) 131 | self.in_filters2 = int(self.filters*2*self.alpha* 0.5)+int(self.filters*2*self.alpha*0.167)+int(self.filters*2*self.alpha*0.333) 132 | self.respath2 = Respath(input_features=self.in_filters2,filters=self.filters*2,respath_length=3) 133 | self.multiresblock3 = Multiresblock(input_features= self.in_filters2,corresponding_unet_filters=self.filters*4) 134 | self.pool3 = 
nn.MaxPool2d(2, 2) 135 | self.in_filters3 = int(self.filters*4*self.alpha* 0.5)+int(self.filters*4*self.alpha*0.167)+int(self.filters*4*self.alpha*0.333) 136 | self.respath3 = Respath(input_features=self.in_filters3,filters=self.filters*4,respath_length=2) 137 | self.multiresblock4 = Multiresblock(input_features= self.in_filters3,corresponding_unet_filters=self.filters*8) 138 | self.pool4 = nn.MaxPool2d(2, 2) 139 | self.in_filters4 = int(self.filters*8*self.alpha* 0.5)+int(self.filters*8*self.alpha*0.167)+int(self.filters*8*self.alpha*0.333) 140 | self.respath4 = Respath(input_features=self.in_filters4,filters=self.filters*8,respath_length=1) 141 | self.multiresblock5 = Multiresblock(input_features= self.in_filters4,corresponding_unet_filters=self.filters*16) 142 | self.in_filters5 = int(self.filters*16*self.alpha* 0.5)+int(self.filters*16*self.alpha*0.167)+int(self.filters*16*self.alpha*0.333) 143 | 144 | #Decoder path 145 | self.upsample6 = nn.ConvTranspose2d(in_channels=self.in_filters5,out_channels=self.filters*8,kernel_size=(2,2),stride=(2,2),padding = 0) 146 | self.concat_filters1 = self.filters*8+self.filters*8 147 | self.multiresblock6 = Multiresblock(input_features=self.concat_filters1,corresponding_unet_filters=self.filters*8) 148 | self.in_filters6 = int(self.filters*8*self.alpha* 0.5)+int(self.filters*8*self.alpha*0.167)+int(self.filters*8*self.alpha*0.333) 149 | self.upsample7 = nn.ConvTranspose2d(in_channels=self.in_filters6,out_channels=self.filters*4,kernel_size=(2,2),stride=(2,2),padding = 0) 150 | self.concat_filters2 = self.filters*4+self.filters*4 151 | self.multiresblock7 = Multiresblock(input_features=self.concat_filters2,corresponding_unet_filters=self.filters*4) 152 | self.in_filters7 = int(self.filters*4*self.alpha* 0.5)+int(self.filters*4*self.alpha*0.167)+int(self.filters*4*self.alpha*0.333) 153 | self.upsample8 = nn.ConvTranspose2d(in_channels=self.in_filters7,out_channels=self.filters*2,kernel_size=(2,2),stride=(2,2),padding = 0) 154 | self.concat_filters3 = self.filters*2+self.filters*2 155 | self.multiresblock8 = Multiresblock(input_features=self.concat_filters3,corresponding_unet_filters=self.filters*2) 156 | self.in_filters8 = int(self.filters*2*self.alpha* 0.5)+int(self.filters*2*self.alpha*0.167)+int(self.filters*2*self.alpha*0.333) 157 | self.upsample9 = nn.ConvTranspose2d(in_channels=self.in_filters8,out_channels=self.filters,kernel_size=(2,2),stride=(2,2),padding = 0) 158 | self.concat_filters4 = self.filters+self.filters 159 | self.multiresblock9 = Multiresblock(input_features=self.concat_filters4,corresponding_unet_filters=self.filters) 160 | self.in_filters9 = int(self.filters*self.alpha* 0.5)+int(self.filters*self.alpha*0.167)+int(self.filters*self.alpha*0.333) 161 | self.conv_final = Conv2d_batchnorm(input_features=self.in_filters9,num_of_filters = self.nclasses, 162 | kernel_size = (1,1),activation='None') 163 | 164 | def forward(self,x : torch.Tensor)->torch.Tensor: 165 | x_multires1 = self.multiresblock1(x) 166 | x_pool1 = self.pool1(x_multires1) 167 | x_multires1 = self.respath1(x_multires1) 168 | x_multires2 = self.multiresblock2(x_pool1) 169 | x_pool2 = self.pool2(x_multires2) 170 | x_multires2 = self.respath2(x_multires2) 171 | x_multires3 = self.multiresblock3(x_pool2) 172 | x_pool3 = self.pool3(x_multires3) 173 | x_multires3 = self.respath3(x_multires3) 174 | x_multires4 = self.multiresblock4(x_pool3) 175 | x_pool4 = self.pool4(x_multires4) 176 | x_multires4 = self.respath4(x_multires4) 177 | x_multires5 = self.multiresblock5(x_pool4) 178 
| up6 = torch.cat([self.upsample6(x_multires5),x_multires4],axis=1) 179 | x_multires6 = self.multiresblock6(up6) 180 | up7 = torch.cat([self.upsample7(x_multires6),x_multires3],axis=1) 181 | x_multires7 = self.multiresblock7(up7) 182 | up8 = torch.cat([self.upsample8(x_multires7),x_multires2],axis=1) 183 | x_multires8 = self.multiresblock8(up8) 184 | up9 = torch.cat([self.upsample9(x_multires8),x_multires1],axis=1) 185 | x_multires9 = self.multiresblock9(up9) 186 | if self.nclasses > 1: 187 | conv_final_layer = self.conv_final(x_multires9) 188 | else: 189 | conv_final_layer = torch.sigmoid(self.conv_final(x_multires9)) 190 | return conv_final_layer -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | 6 | class DoubleConv(nn.Module): 7 | def __init__(self, in_channels, out_channels, with_bn=False): 8 | super().__init__() 9 | if with_bn: 10 | self.step = nn.Sequential( 11 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 12 | nn.BatchNorm2d(out_channels), 13 | nn.ReLU(), 14 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 15 | nn.BatchNorm2d(out_channels), 16 | nn.ReLU(), 17 | ) 18 | else: 19 | self.step = nn.Sequential( 20 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 21 | nn.ReLU(), 22 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 23 | nn.ReLU(), 24 | ) 25 | 26 | def forward(self, x): 27 | return self.step(x) 28 | 29 | 30 | class UNet(nn.Module): 31 | def __init__(self, in_channels, out_channels, with_bn=False): 32 | super().__init__() 33 | init_channels = 32 34 | self.out_channels = out_channels 35 | 36 | self.en_1 = DoubleConv(in_channels , init_channels , with_bn) 37 | self.en_2 = DoubleConv(1*init_channels, 2*init_channels, with_bn) 38 | self.en_3 = DoubleConv(2*init_channels, 4*init_channels, with_bn) 39 | self.en_4 = DoubleConv(4*init_channels, 8*init_channels, with_bn) 40 | 41 | self.de_1 = DoubleConv((4 + 8)*init_channels, 4*init_channels, with_bn) 42 | self.de_2 = DoubleConv((2 + 4)*init_channels, 2*init_channels, with_bn) 43 | self.de_3 = DoubleConv((1 + 2)*init_channels, 1*init_channels, with_bn) 44 | self.de_4 = nn.Conv2d(init_channels, out_channels, 1) 45 | 46 | self.maxpool = nn.MaxPool2d(kernel_size=2) 47 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear') 48 | 49 | def forward(self, x): 50 | e1 = self.en_1(x) 51 | e2 = self.en_2(self.maxpool(e1)) 52 | e3 = self.en_3(self.maxpool(e2)) 53 | e4 = self.en_4(self.maxpool(e3)) 54 | 55 | d1 = self.de_1(torch.cat([self.upsample(e4), e3], dim=1)) 56 | d2 = self.de_2(torch.cat([self.upsample(d1), e2], dim=1)) 57 | d3 = self.de_3(torch.cat([self.upsample(d2), e1], dim=1)) 58 | d4 = self.de_4(d3) 59 | 60 | return d4 61 | 62 | # if self.out_channels<2: 63 | # return torch.sigmoid(d4) 64 | # return torch.softmax(d4, 1) 65 | -------------------------------------------------------------------------------- /models/unetpp.py: -------------------------------------------------------------------------------- 1 | # https://github.com/4uiiurz1/pytorch-nested-unet/blob/master/archs.py (unetpp) 2 | 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn.functional import softmax, sigmoid 7 | 8 | 9 | __all__ = ['UNet', 'NestedUNet'] 10 | 11 | 12 | class VGGBlock(nn.Module): 13 | def __init__(self, in_channels, middle_channels, out_channels): 14 | super().__init__() 15 | self.relu 
= nn.ReLU(inplace=True) 16 | self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1) 17 | self.bn1 = nn.BatchNorm2d(middle_channels) 18 | self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1) 19 | self.bn2 = nn.BatchNorm2d(out_channels) 20 | 21 | def forward(self, x): 22 | out = self.conv1(x) 23 | out = self.bn1(out) 24 | out = self.relu(out) 25 | 26 | out = self.conv2(out) 27 | out = self.bn2(out) 28 | out = self.relu(out) 29 | 30 | return out 31 | 32 | 33 | class UNet(nn.Module): 34 | def __init__(self, num_classes, input_channels=3, **kwargs): 35 | super().__init__() 36 | 37 | nb_filter = [32, 64, 128, 256, 512] 38 | 39 | self.pool = nn.MaxPool2d(2, 2) 40 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 41 | 42 | self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0]) 43 | self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1]) 44 | self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2]) 45 | self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3]) 46 | self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4]) 47 | 48 | self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3]) 49 | self.conv2_2 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2]) 50 | self.conv1_3 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1]) 51 | self.conv0_4 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0]) 52 | 53 | self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1) 54 | 55 | 56 | def forward(self, input): 57 | x0_0 = self.conv0_0(input) 58 | x1_0 = self.conv1_0(self.pool(x0_0)) 59 | x2_0 = self.conv2_0(self.pool(x1_0)) 60 | x3_0 = self.conv3_0(self.pool(x2_0)) 61 | x4_0 = self.conv4_0(self.pool(x3_0)) 62 | 63 | x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1)) 64 | x2_2 = self.conv2_2(torch.cat([x2_0, self.up(x3_1)], 1)) 65 | x1_3 = self.conv1_3(torch.cat([x1_0, self.up(x2_2)], 1)) 66 | x0_4 = self.conv0_4(torch.cat([x0_0, self.up(x1_3)], 1)) 67 | 68 | output = self.final(x0_4) 69 | return output 70 | 71 | 72 | class NestedUNet(nn.Module): 73 | def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs): 74 | super().__init__() 75 | 76 | nb_filter = [32, 64, 128, 256, 512] 77 | 78 | self.deep_supervision = deep_supervision 79 | 80 | self.pool = nn.MaxPool2d(2, 2) 81 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 82 | 83 | self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0]) 84 | self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1]) 85 | self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2]) 86 | self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3]) 87 | self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4]) 88 | 89 | self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0]) 90 | self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1]) 91 | self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2]) 92 | self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3]) 93 | 94 | self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0]) 95 | self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1]) 96 | self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2]) 97 | 98 | self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0]) 99 | self.conv1_3 = 
VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1]) 100 | 101 | self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0]) 102 | 103 | if self.deep_supervision: 104 | self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1) 105 | self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1) 106 | self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1) 107 | self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1) 108 | else: 109 | self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1) 110 | 111 | 112 | def forward(self, input): 113 | x0_0 = self.conv0_0(input) 114 | x1_0 = self.conv1_0(self.pool(x0_0)) 115 | x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1)) 116 | 117 | x2_0 = self.conv2_0(self.pool(x1_0)) 118 | x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1)) 119 | x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1)) 120 | 121 | x3_0 = self.conv3_0(self.pool(x2_0)) 122 | x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1)) 123 | x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1)) 124 | x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1)) 125 | 126 | x4_0 = self.conv4_0(self.pool(x3_0)) 127 | x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1)) 128 | x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1)) 129 | x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1)) 130 | x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1)) 131 | 132 | if self.deep_supervision: 133 | output1 = self.final1(x0_1) 134 | output2 = self.final2(x0_2) 135 | output3 = self.final3(x0_3) 136 | output4 = self.final4(x0_4) 137 | return [output1, output2, output3, output4] 138 | 139 | else: 140 | output = self.final(x0_4) 141 | return output 142 | -------------------------------------------------------------------------------- /train_and_test/isic/attunet-isic.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # # AttUNet - ISIC2018 5 | # --- 6 | 7 | # ## Import packages & functions 8 | 9 | # In[7]: 10 | 11 | 12 | from __future__ import print_function, division 13 | 14 | 15 | import os 16 | import sys 17 | sys.path.append('../..') 18 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 19 | 20 | import copy 21 | import json 22 | import importlib 23 | import glob 24 | import pandas as pd 25 | from skimage import io, transform 26 | import matplotlib.pyplot as plt 27 | from matplotlib.image import imread 28 | import numpy as np 29 | from tqdm import tqdm 30 | 31 | import torch 32 | import torch.nn as nn 33 | import torch.nn.functional as F 34 | import torch.optim as optim 35 | import torchmetrics 36 | from torch.optim import Adam, SGD 37 | from losses import DiceLoss, DiceLossWithLogtis 38 | from torch.nn import BCELoss, CrossEntropyLoss 39 | 40 | from utils import ( 41 | show_sbs, 42 | load_config, 43 | _print, 44 | ) 45 | 46 | # Ignore warnings 47 | import warnings 48 | warnings.filterwarnings("ignore") 49 | 50 | # plt.ion() # interactive mode 51 | 52 | 53 | # ## Set the seed 54 | 55 | # In[2]: 56 | 57 | 58 | torch.manual_seed(0) 59 | np.random.seed(0) 60 | torch.cuda.manual_seed(0) 61 | import random 62 | random.seed(0) 63 | 64 | 65 | # ## Load the config 66 | 67 | # In[3]: 68 | 69 | 70 | CONFIG_NAME = "isic/isic2018_attunet.yaml" 71 | CONFIG_FILE_PATH = os.path.join("../../configs", CONFIG_NAME) 72 | 73 | 74 | # In[4]: 75 | 76 | 77 | config = 
load_config(CONFIG_FILE_PATH) 78 | _print("Config:", "info_underline") 79 | print(json.dumps(config, indent=2)) 80 | print(20*"~-", "\n") 81 | 82 | 83 | # ## Dataset and Dataloader 84 | 85 | # In[6]: 86 | 87 | 88 | from datasets.isic import ISIC2018DatasetFast 89 | from torch.utils.data import DataLoader, Subset 90 | from torchvision import transforms 91 | 92 | 93 | # In[7]: 94 | 95 | 96 | # ------------------- params -------------------- 97 | INPUT_SIZE = config['dataset']['input_size'] 98 | # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< 99 | 100 | 101 | # ----------------- dataset -------------------- 102 | # preparing training dataset 103 | tr_dataset = ISIC2018DatasetFast(mode="tr", one_hot=True) 104 | vl_dataset = ISIC2018DatasetFast(mode="vl", one_hot=True) 105 | te_dataset = ISIC2018DatasetFast(mode="te", one_hot=True) 106 | 107 | # We consider 1815 samples for training, 259 samples for validation and 520 samples for testing 108 | # !cat ~/deeplearning/skin/Prepare_ISIC2018.py 109 | 110 | print(f"Length of trainig_dataset:\t{len(tr_dataset)}") 111 | print(f"Length of validation_dataset:\t{len(vl_dataset)}") 112 | print(f"Length of test_dataset:\t\t{len(te_dataset)}") 113 | 114 | 115 | # prepare train dataloader 116 | tr_dataloader = DataLoader(tr_dataset, **config['data_loader']['train']) 117 | 118 | # prepare validation dataloader 119 | vl_dataloader = DataLoader(vl_dataset, **config['data_loader']['validation']) 120 | 121 | # prepare test dataloader 122 | te_dataloader = DataLoader(te_dataset, **config['data_loader']['test']) 123 | 124 | # -------------- test ----------------- 125 | # test and visualize the input data 126 | for sample in tr_dataloader: 127 | img = sample['image'] 128 | msk = sample['mask'] 129 | print("\n Training") 130 | show_sbs(img[0], msk[0,1]) 131 | break 132 | 133 | for sample in vl_dataloader: 134 | img = sample['image'] 135 | msk = sample['mask'] 136 | print("Validation") 137 | show_sbs(img[0], msk[0,1]) 138 | break 139 | 140 | for sample in te_dataloader: 141 | img = sample['image'] 142 | msk = sample['mask'] 143 | print("Test") 144 | show_sbs(img[0], msk[0,1]) 145 | break 146 | 147 | 148 | # ### Device 149 | 150 | # In[9]: 151 | 152 | 153 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 154 | print(f"Torch device: {device}") 155 | 156 | 157 | # ## Metrics 158 | 159 | # In[ ]: 160 | 161 | 162 | metrics = torchmetrics.MetricCollection( 163 | [ 164 | torchmetrics.F1Score(), 165 | torchmetrics.Accuracy(), 166 | torchmetrics.Dice(), 167 | torchmetrics.Precision(), 168 | torchmetrics.Specificity(), 169 | torchmetrics.Recall(), 170 | # IoU 171 | torchmetrics.JaccardIndex(2) 172 | ], 173 | prefix='train_metrics/' 174 | ) 175 | 176 | # train_metrics 177 | train_metrics = metrics.clone(prefix='train_metrics/').to(device) 178 | 179 | # valid_metrics 180 | valid_metrics = metrics.clone(prefix='valid_metrics/').to(device) 181 | 182 | # test_metrics 183 | test_metrics = metrics.clone(prefix='test_metrics/').to(device) 184 | 185 | 186 | # In[13]: 187 | 188 | 189 | def make_serializeable_metrics(computed_metrics): 190 | res = {} 191 | for k, v in computed_metrics.items(): 192 | res[k] = float(v.cpu().detach().numpy()) 193 | return res 194 | 195 | 196 | # ## Define validate function 197 | 198 | # In[14]: 199 | 200 | 201 | def validate(model, criterion, vl_dataloader): 202 | model.eval() 203 | with torch.no_grad(): 204 | 205 | evaluator = valid_metrics.clone().to(device) 206 | 207 | losses = [] 208 | cnt = 0. 
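        # Shapes in the loop below: ISIC2018DatasetFast(one_hot=True) is expected to yield
        # imgs of shape (B, 3, H, W) and one-hot msks of shape (B, 2, H, W); the argmax over
        # dim=1 further down converts both predictions and masks into integer class-index maps,
        # which is the form the torchmetrics collection defined above consumes.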
209 | for batch, batch_data in enumerate(vl_dataloader): 210 | imgs = batch_data['image'] 211 | msks = batch_data['mask'] 212 | 213 | cnt += msks.shape[0] 214 | 215 | imgs = imgs.to(device) 216 | msks = msks.to(device) 217 | 218 | preds = model(imgs) 219 | loss = criterion(preds, msks) 220 | losses.append(loss.item()) 221 | 222 | 223 | preds_ = torch.argmax(preds, 1, keepdim=False).float() 224 | msks_ = torch.argmax(msks, 1, keepdim=False) 225 | evaluator.update(preds_, msks_) 226 | 227 | # _cml = f"curr_mean-loss:{np.sum(losses)/cnt:0.5f}" 228 | # _bl = f"batch-loss:{losses[-1]/msks.shape[0]:0.5f}" 229 | # iterator.set_description(f"Validation) batch:{batch+1:04d} -> {_cml}, {_bl}") 230 | 231 | # print the final results 232 | loss = np.sum(losses)/cnt 233 | metrics = evaluator.compute() 234 | 235 | return evaluator, loss 236 | 237 | 238 | # ## Define train function 239 | 240 | # In[15]: 241 | 242 | 243 | def train( 244 | model, 245 | device, 246 | tr_dataloader, 247 | vl_dataloader, 248 | config, 249 | 250 | criterion, 251 | optimizer, 252 | scheduler, 253 | 254 | save_dir='./', 255 | save_file_id=None, 256 | ): 257 | 258 | EPOCHS = tr_prms['epochs'] 259 | 260 | torch.cuda.empty_cache() 261 | model = model.to(device) 262 | 263 | evaluator = train_metrics.clone().to(device) 264 | 265 | epochs_info = [] 266 | best_model = None 267 | best_result = {} 268 | best_vl_loss = np.Inf 269 | for epoch in range(EPOCHS): 270 | model.train() 271 | 272 | evaluator.reset() 273 | tr_iterator = tqdm(enumerate(tr_dataloader)) 274 | tr_losses = [] 275 | cnt = 0 276 | for batch, batch_data in tr_iterator: 277 | imgs = batch_data['image'] 278 | msks = batch_data['mask'] 279 | 280 | imgs = imgs.to(device) 281 | msks = msks.to(device) 282 | 283 | optimizer.zero_grad() 284 | preds = model(imgs) 285 | loss = criterion(preds, msks) 286 | loss.backward() 287 | optimizer.step() 288 | 289 | # evaluate by metrics 290 | preds_ = torch.argmax(preds, 1, keepdim=False).float() 291 | msks_ = torch.argmax(msks, 1, keepdim=False) 292 | evaluator.update(preds_, msks_) 293 | 294 | cnt += imgs.shape[0] 295 | tr_losses.append(loss.item()) 296 | 297 | # write details for each training batch 298 | _cml = f"curr_mean-loss:{np.sum(tr_losses)/cnt:0.5f}" 299 | _bl = f"mean_batch-loss:{tr_losses[-1]/imgs.shape[0]:0.5f}" 300 | tr_iterator.set_description(f"Training) ep:{epoch:03d}, batch:{batch+1:04d} -> {_cml}, {_bl}") 301 | 302 | tr_loss = np.sum(tr_losses)/cnt 303 | 304 | # validate model 305 | vl_metrics, vl_loss = validate(model, criterion, vl_dataloader) 306 | if vl_loss < best_vl_loss: 307 | # find a better model 308 | best_model = model 309 | best_vl_loss = vl_loss 310 | best_result = { 311 | 'tr_loss': tr_loss, 312 | 'vl_loss': vl_loss, 313 | 'tr_metrics': make_serializeable_metrics(evaluator.compute()), 314 | 'vl_metrics': make_serializeable_metrics(vl_metrics.compute()) 315 | } 316 | 317 | # write the final results 318 | epoch_info = { 319 | 'tr_loss': tr_loss, 320 | 'vl_loss': vl_loss, 321 | 'tr_metrics': make_serializeable_metrics(evaluator.compute()), 322 | 'vl_metrics': make_serializeable_metrics(vl_metrics.compute()) 323 | } 324 | epochs_info.append(epoch_info) 325 | # epoch_tqdm.set_description(f"Epoch:{epoch+1}/{EPOCHS} -> tr_loss:{tr_loss}, vl_loss:{vl_loss}") 326 | evaluator.reset() 327 | 328 | scheduler.step(vl_loss) 329 | 330 | # save final results 331 | res = { 332 | 'id': save_file_id, 333 | 'config': config, 334 | 'epochs_info': epochs_info, 335 | 'best_result': best_result 336 | } 337 | fn = f"{save_file_id+'_' 
if save_file_id else ''}result.json" 338 | fp = os.path.join(config['model']['save_dir'],fn) 339 | with open(fp, "w") as write_file: 340 | json.dump(res, write_file, indent=4) 341 | 342 | # save model's state_dict 343 | fn = "last_model_state_dict.pt" 344 | fp = os.path.join(config['model']['save_dir'],fn) 345 | torch.save(model.state_dict(), fp) 346 | 347 | # save the best model's state_dict 348 | fn = "best_model_state_dict.pt" 349 | fp = os.path.join(config['model']['save_dir'], fn) 350 | torch.save(best_model.state_dict(), fp) 351 | 352 | return best_model, model, res 353 | 354 | 355 | # ## Define test function 356 | 357 | # In[17]: 358 | 359 | 360 | def test(model, te_dataloader): 361 | model.eval() 362 | with torch.no_grad(): 363 | evaluator = test_metrics.clone().to(device) 364 | for batch_data in tqdm(te_dataloader): 365 | imgs = batch_data['image'] 366 | msks = batch_data['mask'] 367 | 368 | imgs = imgs.to(device) 369 | msks = msks.to(device) 370 | 371 | preds = model(imgs) 372 | 373 | # evaluate by metrics 374 | preds_ = torch.argmax(preds, 1, keepdim=False).float() 375 | msks_ = torch.argmax(msks, 1, keepdim=False) 376 | evaluator.update(preds_, msks_) 377 | 378 | return evaluator 379 | 380 | 381 | # ## Load and prepare model 382 | 383 | # In[18]: 384 | 385 | 386 | from models.attunet import AttU_Net as Net 387 | 388 | 389 | model = Net(**config['model']['params']) 390 | torch.cuda.empty_cache() 391 | model = model.to(device) 392 | print("Number of parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad)) 393 | 394 | os.makedirs(config['model']['save_dir'], exist_ok=True) 395 | model_path = f"{config['model']['save_dir']}/model_state_dict.pt" 396 | 397 | if config['model']['load_weights']: 398 | model.load_state_dict(torch.load(model_path)) 399 | print("Loaded pre-trained weights...") 400 | 401 | 402 | # criterion_dice = DiceLoss() 403 | criterion_dice = DiceLossWithLogtis() 404 | # criterion_ce = BCELoss() 405 | criterion_ce = CrossEntropyLoss() 406 | 407 | 408 | def criterion(preds, masks): 409 | c_dice = criterion_dice(preds, masks) 410 | c_ce = criterion_ce(preds, masks) 411 | return 0.5*c_dice + 0.5*c_ce 412 | 413 | tr_prms = config['training'] 414 | optimizer = globals()[tr_prms['optimizer']['name']](model.parameters(), **tr_prms['optimizer']['params']) 415 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', **tr_prms['scheduler']) 416 | 417 | 418 | # ## Start traning 419 | 420 | # In[19]: 421 | 422 | 423 | best_model, model, res = train( 424 | model, 425 | device, 426 | tr_dataloader, 427 | vl_dataloader, 428 | config, 429 | 430 | criterion, 431 | optimizer, 432 | scheduler, 433 | 434 | save_dir = config['model']['save_dir'], 435 | save_file_id = None, 436 | ) 437 | 438 | 439 | # In[20]: 440 | 441 | 442 | te_metrics = test(best_model, te_dataloader) 443 | te_metrics.compute() 444 | 445 | 446 | # In[21]: 447 | 448 | 449 | f"{config['model']['save_dir']}" 450 | 451 | 452 | # # Test the best inferred model 453 | # ---- 454 | 455 | # ## Load the best model 456 | 457 | # In[22]: 458 | 459 | 460 | best_model = Net(**config['model']['params']) 461 | torch.cuda.empty_cache() 462 | best_model = best_model.to(device) 463 | 464 | fn = "best_model_state_dict.pt" 465 | os.makedirs(config['model']['save_dir'], exist_ok=True) 466 | model_path = f"{config['model']['save_dir']}/{fn}" 467 | 468 | best_model.load_state_dict(torch.load(model_path)) 469 | print("Loaded best model weights...") 470 | 471 | 472 | # ## Evaluation 473 | 474 | # In[23]: 475 
| 476 | 477 | te_metrics = test(best_model, te_dataloader) 478 | te_metrics.compute() 479 | 480 | 481 | # ## Plot graphs 482 | 483 | # In[24]: 484 | 485 | 486 | result_file_path = f"{config['model']['save_dir']}/result.json" 487 | with open(result_file_path, 'r') as f: 488 | results = json.loads(''.join(f.readlines())) 489 | epochs_info = results['epochs_info'] 490 | 491 | tr_losses = [d['tr_loss'] for d in epochs_info] 492 | vl_losses = [d['vl_loss'] for d in epochs_info] 493 | tr_dice = [d['tr_metrics']['train_metrics/Dice'] for d in epochs_info] 494 | vl_dice = [d['vl_metrics']['valid_metrics/Dice'] for d in epochs_info] 495 | tr_js = [d['tr_metrics']['train_metrics/JaccardIndex'] for d in epochs_info] 496 | vl_js = [d['vl_metrics']['valid_metrics/JaccardIndex'] for d in epochs_info] 497 | tr_acc = [d['tr_metrics']['train_metrics/Accuracy'] for d in epochs_info] 498 | vl_acc = [d['vl_metrics']['valid_metrics/Accuracy'] for d in epochs_info] 499 | 500 | 501 | _, axs = plt.subplots(1, 4, figsize=[16,3]) 502 | 503 | axs[0].set_title("Loss") 504 | axs[0].plot(tr_losses, 'r-', label="train loss") 505 | axs[0].plot(vl_losses, 'b-', label="validatiton loss") 506 | axs[0].legend() 507 | 508 | axs[1].set_title("Dice score") 509 | axs[1].plot(tr_dice, 'r-', label="train dice") 510 | axs[1].plot(vl_dice, 'b-', label="validation dice") 511 | axs[1].legend() 512 | 513 | axs[2].set_title("Jaccard Similarity") 514 | axs[2].plot(tr_js, 'r-', label="train JaccardIndex") 515 | axs[2].plot(vl_js, 'b-', label="validatiton JaccardIndex") 516 | axs[2].legend() 517 | 518 | axs[3].set_title("Accuracy") 519 | axs[3].plot(tr_acc, 'r-', label="train Accuracy") 520 | axs[3].plot(vl_acc, 'b-', label="validation Accuracy") 521 | axs[3].legend() 522 | 523 | plt.show() 524 | 525 | 526 | # In[25]: 527 | 528 | 529 | epochs_info 530 | 531 | 532 | # ## Save images 533 | 534 | # In[30]: 535 | 536 | 537 | from PIL import Image 538 | import cv2 539 | def skin_plot(img, gt, pred): 540 | img = np.array(img) 541 | gt = np.array(gt) 542 | pred = np.array(pred) 543 | edged_test = cv2.Canny(pred, 100, 255) 544 | contours_test, _ = cv2.findContours(edged_test, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 545 | edged_gt = cv2.Canny(gt, 100, 255) 546 | contours_gt, _ = cv2.findContours(edged_gt, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 547 | for cnt_test in contours_test: 548 | cv2.drawContours(img, [cnt_test], -1, (0, 0, 255), 1) 549 | for cnt_gt in contours_gt: 550 | cv2.drawContours(img, [cnt_gt], -1, (0,255,0), 1) 551 | return img 552 | 553 | #--------------------------------------------------------------------------------------------- 554 | 555 | 556 | save_imgs_dir = f"{config['model']['save_dir']}/visualized" 557 | 558 | if not os.path.isdir(save_imgs_dir): 559 | os.mkdir(save_imgs_dir) 560 | 561 | with torch.no_grad(): 562 | for batch in tqdm(te_dataloader): 563 | imgs = batch['image'] 564 | msks = batch['mask'] 565 | ids = batch['id'] 566 | 567 | preds = best_model(imgs.to(device)) 568 | 569 | txm = imgs.cpu().numpy() 570 | tbm = torch.argmax(msks, 1).cpu().numpy() 571 | tpm = torch.argmax(preds, 1).cpu().numpy() 572 | tid = ids 573 | 574 | for idx in range(len(tbm)): 575 | img = np.moveaxis(txm[idx, :3], 0, -1)*255. 576 | img = np.ascontiguousarray(img, dtype=np.uint8) 577 | gt = np.uint8(tbm[idx]*255.) 
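            # tpm holds argmax class indices (0 = background, 1 = lesion), so the comparison
            # below simply keeps the lesion class and rescales it to {0, 255}, the form
            # cv2.Canny inside skin_plot needs in order to trace the predicted boundary.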
578 | pred = np.where(tpm[idx]>0.5, 255, 0) 579 | pred = np.ascontiguousarray(pred, dtype=np.uint8) 580 | 581 | res_img = skin_plot(img, gt, pred) 582 | 583 | fid = tid[idx] 584 | Image.fromarray(img).save(f"{save_imgs_dir}/{fid}_img.png") 585 | Image.fromarray(res_img).save(f"{save_imgs_dir}/{fid}_img_gt_pred.png") 586 | 587 | 588 | # In[31]: 589 | 590 | 591 | f"{config['model']['save_dir']}/visualized" 592 | 593 | -------------------------------------------------------------------------------- /train_and_test/isic/resunet-isic.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # # Residual UNet - ISIC2018 5 | # --- 6 | 7 | # ## Import packages & functions 8 | 9 | # In[1]: 10 | 11 | 12 | from __future__ import print_function, division 13 | 14 | 15 | import os 16 | import sys 17 | sys.path.append('../..') 18 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 19 | 20 | import copy 21 | import json 22 | import importlib 23 | import glob 24 | import pandas as pd 25 | from skimage import io, transform 26 | import matplotlib.pyplot as plt 27 | from matplotlib.image import imread 28 | import numpy as np 29 | from tqdm import tqdm 30 | 31 | import torch 32 | import torch.nn as nn 33 | import torch.nn.functional as F 34 | import torch.optim as optim 35 | import torchmetrics 36 | from torch.optim import Adam, SGD 37 | from losses import DiceLoss, DiceLossWithLogtis 38 | from torch.nn import BCELoss, CrossEntropyLoss 39 | 40 | from utils import ( 41 | show_sbs, 42 | load_config, 43 | _print, 44 | ) 45 | 46 | # Ignore warnings 47 | import warnings 48 | warnings.filterwarnings("ignore") 49 | 50 | # plt.ion() # interactive mode 51 | 52 | 53 | # ## Set the seed 54 | 55 | # In[2]: 56 | 57 | 58 | torch.manual_seed(0) 59 | np.random.seed(0) 60 | torch.cuda.manual_seed(0) 61 | import random 62 | random.seed(0) 63 | 64 | 65 | # ## Load the config 66 | 67 | # In[3]: 68 | 69 | 70 | CONFIG_NAME = "isic/isic2018_resunet.yaml" 71 | CONFIG_FILE_PATH = os.path.join("../../configs", CONFIG_NAME) 72 | 73 | 74 | # In[4]: 75 | 76 | 77 | config = load_config(CONFIG_FILE_PATH) 78 | _print("Config:", "info_underline") 79 | print(json.dumps(config, indent=2)) 80 | print(20*"~-", "\n") 81 | 82 | 83 | # ## Dataset and Dataloader 84 | 85 | # In[6]: 86 | 87 | 88 | from datasets.isic import ISIC2018DatasetFast 89 | from torch.utils.data import DataLoader, Subset 90 | from torchvision import transforms 91 | 92 | 93 | # In[7]: 94 | 95 | 96 | # ------------------- params -------------------- 97 | INPUT_SIZE = config['dataset']['input_size'] 98 | # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< 99 | 100 | 101 | # ----------------- dataset -------------------- 102 | # preparing training dataset 103 | tr_dataset = ISIC2018DatasetFast(mode="tr", one_hot=True) 104 | vl_dataset = ISIC2018DatasetFast(mode="vl", one_hot=True) 105 | te_dataset = ISIC2018DatasetFast(mode="te", one_hot=True) 106 | 107 | # We consider 1815 samples for training, 259 samples for validation and 520 samples for testing 108 | # !cat ~/deeplearning/skin/Prepare_ISIC2018.py 109 | 110 | print(f"Length of training_dataset:\t{len(tr_dataset)}") 111 | print(f"Length of validation_dataset:\t{len(vl_dataset)}") 112 | print(f"Length of test_dataset:\t\t{len(te_dataset)}") 113 | 114 | 115 | # prepare train dataloader 116 | tr_dataloader = DataLoader(tr_dataset, **config['data_loader']['train']) 117 | 118 | # prepare validation dataloader 119 | vl_dataloader = DataLoader(vl_dataset,
**config['data_loader']['validation']) 120 | 121 | # prepare test dataloader 122 | te_dataloader = DataLoader(te_dataset, **config['data_loader']['test']) 123 | 124 | # -------------- test ----------------- 125 | # test and visualize the input data 126 | for sample in tr_dataloader: 127 | img = sample['image'] 128 | msk = sample['mask'] 129 | print("\n Training") 130 | show_sbs(img[0], msk[0,1]) 131 | break 132 | 133 | for sample in vl_dataloader: 134 | img = sample['image'] 135 | msk = sample['mask'] 136 | print("Validation") 137 | show_sbs(img[0], msk[0,1]) 138 | break 139 | 140 | for sample in te_dataloader: 141 | img = sample['image'] 142 | msk = sample['mask'] 143 | print("Test") 144 | show_sbs(img[0], msk[0,1]) 145 | break 146 | 147 | 148 | # ### Device 149 | 150 | # In[9]: 151 | 152 | 153 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 154 | print(f"Torch device: {device}") 155 | 156 | 157 | # ## Metrics 158 | 159 | # In[12]: 160 | 161 | 162 | metrics = torchmetrics.MetricCollection( 163 | [ 164 | torchmetrics.F1Score(), 165 | torchmetrics.Accuracy(), 166 | torchmetrics.Dice(), 167 | torchmetrics.Precision(), 168 | torchmetrics.Specificity(), 169 | torchmetrics.Recall(), 170 | # IoU 171 | torchmetrics.JaccardIndex(2) 172 | ], 173 | prefix='train_metrics/' 174 | ) 175 | 176 | # train_metrics 177 | train_metrics = metrics.clone(prefix='train_metrics/').to(device) 178 | 179 | # valid_metrics 180 | valid_metrics = metrics.clone(prefix='valid_metrics/').to(device) 181 | 182 | # test_metrics 183 | test_metrics = metrics.clone(prefix='test_metrics/').to(device) 184 | 185 | 186 | # In[13]: 187 | 188 | 189 | def make_serializeable_metrics(computed_metrics): 190 | res = {} 191 | for k, v in computed_metrics.items(): 192 | res[k] = float(v.cpu().detach().numpy()) 193 | return res 194 | 195 | 196 | # ## Define validate function 197 | 198 | # In[14]: 199 | 200 | 201 | def validate(model, criterion, vl_dataloader): 202 | model.eval() 203 | with torch.no_grad(): 204 | 205 | evaluator = valid_metrics.clone().to(device) 206 | 207 | losses = [] 208 | cnt = 0. 
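        # The cloned metric collection accumulates its internal state with evaluator.update(...)
        # on every validation batch; the aggregated epoch-level scores are only materialised by
        # the evaluator.compute() call after the loop (torchmetrics' usual update/compute flow).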
209 | for batch, batch_data in enumerate(vl_dataloader): 210 | imgs = batch_data['image'] 211 | msks = batch_data['mask'] 212 | 213 | cnt += msks.shape[0] 214 | 215 | imgs = imgs.to(device) 216 | msks = msks.to(device) 217 | 218 | preds = model(imgs) 219 | loss = criterion(preds, msks) 220 | losses.append(loss.item()) 221 | 222 | 223 | preds_ = torch.argmax(preds, 1, keepdim=False).float() 224 | msks_ = torch.argmax(msks, 1, keepdim=False) 225 | evaluator.update(preds_, msks_) 226 | 227 | # _cml = f"curr_mean-loss:{np.sum(losses)/cnt:0.5f}" 228 | # _bl = f"batch-loss:{losses[-1]/msks.shape[0]:0.5f}" 229 | # iterator.set_description(f"Validation) batch:{batch+1:04d} -> {_cml}, {_bl}") 230 | 231 | # print the final results 232 | loss = np.sum(losses)/cnt 233 | metrics = evaluator.compute() 234 | 235 | return evaluator, loss 236 | 237 | 238 | # ## Define train function 239 | 240 | # In[15]: 241 | 242 | 243 | def train( 244 | model, 245 | device, 246 | tr_dataloader, 247 | vl_dataloader, 248 | config, 249 | 250 | criterion, 251 | optimizer, 252 | scheduler, 253 | 254 | save_dir='./', 255 | save_file_id=None, 256 | ): 257 | 258 | EPOCHS = tr_prms['epochs'] 259 | 260 | torch.cuda.empty_cache() 261 | model = model.to(device) 262 | 263 | evaluator = train_metrics.clone().to(device) 264 | 265 | epochs_info = [] 266 | best_model = None 267 | best_result = {} 268 | best_vl_loss = np.Inf 269 | for epoch in range(EPOCHS): 270 | model.train() 271 | 272 | evaluator.reset() 273 | tr_iterator = tqdm(enumerate(tr_dataloader)) 274 | tr_losses = [] 275 | cnt = 0 276 | for batch, batch_data in tr_iterator: 277 | imgs = batch_data['image'] 278 | msks = batch_data['mask'] 279 | 280 | imgs = imgs.to(device) 281 | msks = msks.to(device) 282 | 283 | optimizer.zero_grad() 284 | preds = model(imgs) 285 | loss = criterion(preds, msks) 286 | loss.backward() 287 | optimizer.step() 288 | 289 | # evaluate by metrics 290 | preds_ = torch.argmax(preds, 1, keepdim=False).float() 291 | msks_ = torch.argmax(msks, 1, keepdim=False) 292 | evaluator.update(preds_, msks_) 293 | 294 | cnt += imgs.shape[0] 295 | tr_losses.append(loss.item()) 296 | 297 | # write details for each training batch 298 | _cml = f"curr_mean-loss:{np.sum(tr_losses)/cnt:0.5f}" 299 | _bl = f"mean_batch-loss:{tr_losses[-1]/imgs.shape[0]:0.5f}" 300 | tr_iterator.set_description(f"Training) ep:{epoch:03d}, batch:{batch+1:04d} -> {_cml}, {_bl}") 301 | 302 | tr_loss = np.sum(tr_losses)/cnt 303 | 304 | # validate model 305 | vl_metrics, vl_loss = validate(model, criterion, vl_dataloader) 306 | if vl_loss < best_vl_loss: 307 | # find a better model 308 | best_model = model 309 | best_vl_loss = vl_loss 310 | best_result = { 311 | 'tr_loss': tr_loss, 312 | 'vl_loss': vl_loss, 313 | 'tr_metrics': make_serializeable_metrics(evaluator.compute()), 314 | 'vl_metrics': make_serializeable_metrics(vl_metrics.compute()) 315 | } 316 | 317 | # write the final results 318 | epoch_info = { 319 | 'tr_loss': tr_loss, 320 | 'vl_loss': vl_loss, 321 | 'tr_metrics': make_serializeable_metrics(evaluator.compute()), 322 | 'vl_metrics': make_serializeable_metrics(vl_metrics.compute()) 323 | } 324 | epochs_info.append(epoch_info) 325 | # epoch_tqdm.set_description(f"Epoch:{epoch+1}/{EPOCHS} -> tr_loss:{tr_loss}, vl_loss:{vl_loss}") 326 | evaluator.reset() 327 | 328 | scheduler.step(vl_loss) 329 | 330 | # save final results 331 | res = { 332 | 'id': save_file_id, 333 | 'config': config, 334 | 'epochs_info': epochs_info, 335 | 'best_result': best_result 336 | } 337 | fn = f"{save_file_id+'_' 
if save_file_id else ''}result.json" 338 | fp = os.path.join(config['model']['save_dir'],fn) 339 | with open(fp, "w") as write_file: 340 | json.dump(res, write_file, indent=4) 341 | 342 | # save model's state_dict 343 | fn = "last_model_state_dict.pt" 344 | fp = os.path.join(config['model']['save_dir'],fn) 345 | torch.save(model.state_dict(), fp) 346 | 347 | # save the best model's state_dict 348 | fn = "best_model_state_dict.pt" 349 | fp = os.path.join(config['model']['save_dir'], fn) 350 | torch.save(best_model.state_dict(), fp) 351 | 352 | return best_model, model, res 353 | 354 | 355 | # ## Define test function 356 | 357 | # In[17]: 358 | 359 | 360 | def test(model, te_dataloader): 361 | model.eval() 362 | with torch.no_grad(): 363 | evaluator = test_metrics.clone().to(device) 364 | for batch_data in tqdm(te_dataloader): 365 | imgs = batch_data['image'] 366 | msks = batch_data['mask'] 367 | 368 | imgs = imgs.to(device) 369 | msks = msks.to(device) 370 | 371 | preds = model(imgs) 372 | 373 | # evaluate by metrics 374 | preds_ = torch.argmax(preds, 1, keepdim=False).float() 375 | msks_ = torch.argmax(msks, 1, keepdim=False) 376 | evaluator.update(preds_, msks_) 377 | 378 | return evaluator 379 | 380 | 381 | # ## Load and prepare model 382 | 383 | # In[18]: 384 | 385 | 386 | from models._resunet.res_unet import ResUnet as Net 387 | 388 | 389 | model = Net(**config['model']['params']) 390 | torch.cuda.empty_cache() 391 | model = model.to(device) 392 | print("Number of parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad)) 393 | 394 | os.makedirs(config['model']['save_dir'], exist_ok=True) 395 | model_path = f"{config['model']['save_dir']}/model_state_dict.pt" 396 | 397 | if config['model']['load_weights']: 398 | model.load_state_dict(torch.load(model_path)) 399 | print("Loaded pre-trained weights...") 400 | 401 | 402 | # criterion_dice = DiceLoss() 403 | criterion_dice = DiceLossWithLogtis() 404 | # criterion_ce = BCELoss() 405 | criterion_ce = CrossEntropyLoss() 406 | 407 | 408 | def criterion(preds, masks): 409 | c_dice = criterion_dice(preds, masks) 410 | c_ce = criterion_ce(preds, masks) 411 | return 0.5*c_dice + 0.5*c_ce 412 | 413 | tr_prms = config['training'] 414 | optimizer = globals()[tr_prms['optimizer']['name']](model.parameters(), **tr_prms['optimizer']['params']) 415 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', **tr_prms['scheduler']) 416 | 417 | 418 | # ## Start traning 419 | 420 | # In[19]: 421 | 422 | 423 | best_model, model, res = train( 424 | model, 425 | device, 426 | tr_dataloader, 427 | vl_dataloader, 428 | config, 429 | 430 | criterion, 431 | optimizer, 432 | scheduler, 433 | 434 | save_dir = config['model']['save_dir'], 435 | save_file_id = None, 436 | ) 437 | 438 | 439 | # In[20]: 440 | 441 | 442 | te_metrics = test(best_model, te_dataloader) 443 | te_metrics.compute() 444 | 445 | 446 | # In[21]: 447 | 448 | 449 | f"{config['model']['save_dir']}" 450 | 451 | 452 | # # Test the best inferred model 453 | # ---- 454 | 455 | # ## Load the best model 456 | 457 | # In[22]: 458 | 459 | 460 | best_model = Net(**config['model']['params']) 461 | torch.cuda.empty_cache() 462 | best_model = best_model.to(device) 463 | 464 | fn = "best_model_state_dict.pt" 465 | os.makedirs(config['model']['save_dir'], exist_ok=True) 466 | model_path = f"{config['model']['save_dir']}/{fn}" 467 | 468 | best_model.load_state_dict(torch.load(model_path)) 469 | print("Loaded best model weights...") 470 | 471 | 472 | # ## Evaluation 473 | 474 | # 
In[23]: 475 | 476 | 477 | te_metrics = test(best_model, te_dataloader) 478 | te_metrics.compute() 479 | 480 | 481 | # ## Plot graphs 482 | 483 | # In[24]: 484 | 485 | 486 | result_file_path = f"{config['model']['save_dir']}/result.json" 487 | with open(result_file_path, 'r') as f: 488 | results = json.loads(''.join(f.readlines())) 489 | epochs_info = results['epochs_info'] 490 | 491 | tr_losses = [d['tr_loss'] for d in epochs_info] 492 | vl_losses = [d['vl_loss'] for d in epochs_info] 493 | tr_dice = [d['tr_metrics']['train_metrics/Dice'] for d in epochs_info] 494 | vl_dice = [d['vl_metrics']['valid_metrics/Dice'] for d in epochs_info] 495 | tr_js = [d['tr_metrics']['train_metrics/JaccardIndex'] for d in epochs_info] 496 | vl_js = [d['vl_metrics']['valid_metrics/JaccardIndex'] for d in epochs_info] 497 | tr_acc = [d['tr_metrics']['train_metrics/Accuracy'] for d in epochs_info] 498 | vl_acc = [d['vl_metrics']['valid_metrics/Accuracy'] for d in epochs_info] 499 | 500 | 501 | _, axs = plt.subplots(1, 4, figsize=[16,3]) 502 | 503 | axs[0].set_title("Loss") 504 | axs[0].plot(tr_losses, 'r-', label="train loss") 505 | axs[0].plot(vl_losses, 'b-', label="validatiton loss") 506 | axs[0].legend() 507 | 508 | axs[1].set_title("Dice score") 509 | axs[1].plot(tr_dice, 'r-', label="train dice") 510 | axs[1].plot(vl_dice, 'b-', label="validation dice") 511 | axs[1].legend() 512 | 513 | axs[2].set_title("Jaccard Similarity") 514 | axs[2].plot(tr_js, 'r-', label="train JaccardIndex") 515 | axs[2].plot(vl_js, 'b-', label="validatiton JaccardIndex") 516 | axs[2].legend() 517 | 518 | axs[3].set_title("Accuracy") 519 | axs[3].plot(tr_acc, 'r-', label="train Accuracy") 520 | axs[3].plot(vl_acc, 'b-', label="validation Accuracy") 521 | axs[3].legend() 522 | 523 | plt.show() 524 | 525 | 526 | # In[25]: 527 | 528 | 529 | epochs_info 530 | 531 | 532 | # ## Save images 533 | 534 | # In[28]: 535 | 536 | 537 | from PIL import Image 538 | import cv2 539 | def skin_plot(img, gt, pred): 540 | img = np.array(img) 541 | gt = np.array(gt) 542 | pred = np.array(pred) 543 | edged_test = cv2.Canny(pred, 100, 255) 544 | contours_test, _ = cv2.findContours(edged_test, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 545 | edged_gt = cv2.Canny(gt, 100, 255) 546 | contours_gt, _ = cv2.findContours(edged_gt, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 547 | for cnt_test in contours_test: 548 | cv2.drawContours(img, [cnt_test], -1, (0, 0, 255), 1) 549 | for cnt_gt in contours_gt: 550 | cv2.drawContours(img, [cnt_gt], -1, (0,255,0), 1) 551 | return img 552 | 553 | #--------------------------------------------------------------------------------------------- 554 | 555 | 556 | save_imgs_dir = f"{config['model']['save_dir']}/visualized" 557 | 558 | if not os.path.isdir(save_imgs_dir): 559 | os.mkdir(save_imgs_dir) 560 | 561 | with torch.no_grad(): 562 | for batch in tqdm(te_dataloader): 563 | imgs = batch['image'] 564 | msks = batch['mask'] 565 | ids = batch['id'] 566 | 567 | preds = best_model(imgs.to(device)) 568 | 569 | txm = imgs.cpu().numpy() 570 | tbm = torch.argmax(msks, 1).cpu().numpy() 571 | tpm = torch.argmax(preds, 1).cpu().numpy() 572 | tid = ids 573 | 574 | for idx in range(len(tbm)): 575 | img = np.moveaxis(txm[idx, :3], 0, -1)*255. 576 | img = np.ascontiguousarray(img, dtype=np.uint8) 577 | gt = np.uint8(tbm[idx]*255.) 
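            # img was moved from CHW back to HWC and rescaled to uint8 above (the inputs are
            # assumed to be normalised to [0, 1], hence the *255 factor), and gt is the
            # ground-truth class map stretched to {0, 255} so skin_plot can contour it.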
578 | pred = np.where(tpm[idx]>0.5, 255, 0) 579 | pred = np.ascontiguousarray(pred, dtype=np.uint8) 580 | 581 | res_img = skin_plot(img, gt, pred) 582 | 583 | fid = tid[idx] 584 | Image.fromarray(img).save(f"{save_imgs_dir}/{fid}_img.png") 585 | Image.fromarray(res_img).save(f"{save_imgs_dir}/{fid}_img_gt_pred.png") 586 | 587 | 588 | # In[29]: 589 | 590 | 591 | f"{config['model']['save_dir']}/visualized" 592 | 593 | -------------------------------------------------------------------------------- /train_and_test/isic/unet-isic.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # # UNet - ISIC2018 5 | # --- 6 | 7 | # ## Import packages & functions 8 | 9 | # In[1]: 10 | 11 | 12 | from __future__ import print_function, division 13 | 14 | 15 | import os 16 | import sys 17 | sys.path.append('../..') 18 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 19 | 20 | import copy 21 | import json 22 | import importlib 23 | import glob 24 | import pandas as pd 25 | from skimage import io, transform 26 | import matplotlib.pyplot as plt 27 | from matplotlib.image import imread 28 | import numpy as np 29 | from tqdm import tqdm 30 | 31 | import torch 32 | import torch.nn as nn 33 | import torch.nn.functional as F 34 | import torch.optim as optim 35 | import torchmetrics 36 | from torch.optim import Adam, SGD 37 | from losses import DiceLoss, DiceLossWithLogtis 38 | from torch.nn import BCELoss, CrossEntropyLoss 39 | 40 | from utils import ( 41 | show_sbs, 42 | load_config, 43 | _print, 44 | ) 45 | 46 | # Ignore warnings 47 | import warnings 48 | warnings.filterwarnings("ignore") 49 | 50 | # plt.ion() # interactive mode 51 | 52 | 53 | # ## Set the seed 54 | 55 | # In[2]: 56 | 57 | 58 | torch.manual_seed(0) 59 | np.random.seed(0) 60 | torch.cuda.manual_seed(0) 61 | import random 62 | random.seed(0) 63 | 64 | 65 | # ## Load the config 66 | 67 | # In[3]: 68 | 69 | 70 | CONFIG_NAME = "isic/isic2018_unet.yaml" 71 | CONFIG_FILE_PATH = os.path.join("../../configs", CONFIG_NAME) 72 | 73 | 74 | # In[4]: 75 | 76 | 77 | config = load_config(CONFIG_FILE_PATH) 78 | _print("Config:", "info_underline") 79 | print(json.dumps(config, indent=2)) 80 | print(20*"~-", "\n") 81 | 82 | 83 | # ## Dataset and Dataloader 84 | 85 | # In[6]: 86 | 87 | 88 | from datasets.isic import ISIC2018DatasetFast 89 | from torch.utils.data import DataLoader, Subset 90 | from torchvision import transforms 91 | 92 | 93 | # In[7]: 94 | 95 | 96 | # ------------------- params -------------------- 97 | INPUT_SIZE = config['dataset']['input_size'] 98 | # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< 99 | 100 | 101 | # ----------------- dataset -------------------- 102 | # preparing training dataset 103 | tr_dataset = ISIC2018DatasetFast(mode="tr", one_hot=True) 104 | vl_dataset = ISIC2018DatasetFast(mode="vl", one_hot=True) 105 | te_dataset = ISIC2018DatasetFast(mode="te", one_hot=True) 106 | 107 | # We consider 1815 samples for training, 259 samples for validation and 520 samples for testing 108 | # !cat ~/deeplearning/skin/Prepare_ISIC2018.py 109 | 110 | print(f"Length of trainig_dataset:\t{len(tr_dataset)}") 111 | print(f"Length of validation_dataset:\t{len(vl_dataset)}") 112 | print(f"Length of test_dataset:\t\t{len(te_dataset)}") 113 | 114 | 115 | # prepare train dataloader 116 | tr_dataloader = DataLoader(tr_dataset, **config['data_loader']['train']) 117 | 118 | # prepare validation dataloader 119 | vl_dataloader = DataLoader(vl_dataset, 
**config['data_loader']['validation']) 120 | 121 | # prepare test dataloader 122 | te_dataloader = DataLoader(te_dataset, **config['data_loader']['test']) 123 | 124 | # -------------- test ----------------- 125 | # test and visualize the input data 126 | for sample in tr_dataloader: 127 | img = sample['image'] 128 | msk = sample['mask'] 129 | print("\n Training") 130 | show_sbs(img[0], msk[0,1]) 131 | break 132 | 133 | for sample in vl_dataloader: 134 | img = sample['image'] 135 | msk = sample['mask'] 136 | print("Validation") 137 | show_sbs(img[0], msk[0,1]) 138 | break 139 | 140 | for sample in te_dataloader: 141 | img = sample['image'] 142 | msk = sample['mask'] 143 | print("Test") 144 | show_sbs(img[0], msk[0,1]) 145 | break 146 | 147 | 148 | # ### Device 149 | 150 | # In[9]: 151 | 152 | 153 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 154 | print(f"Torch device: {device}") 155 | 156 | 157 | # ## Metrics 158 | 159 | # In[12]: 160 | 161 | 162 | metrics = torchmetrics.MetricCollection( 163 | [ 164 | torchmetrics.F1Score(), 165 | torchmetrics.Accuracy(), 166 | torchmetrics.Dice(), 167 | torchmetrics.Precision(), 168 | torchmetrics.Recall(), 169 | torchmetrics.Specificity(), 170 | # IoU 171 | torchmetrics.JaccardIndex(2) 172 | ], 173 | prefix='train_metrics/' 174 | ) 175 | 176 | # train_metrics 177 | train_metrics = metrics.clone(prefix='train_metrics/').to(device) 178 | 179 | # valid_metrics 180 | valid_metrics = metrics.clone(prefix='valid_metrics/').to(device) 181 | 182 | # test_metrics 183 | test_metrics = metrics.clone(prefix='test_metrics/').to(device) 184 | 185 | 186 | # In[13]: 187 | 188 | 189 | def make_serializeable_metrics(computed_metrics): 190 | res = {} 191 | for k, v in computed_metrics.items(): 192 | res[k] = float(v.cpu().detach().numpy()) 193 | return res 194 | 195 | 196 | # ## Define validate function 197 | 198 | # In[14]: 199 | 200 | 201 | def validate(model, criterion, vl_dataloader): 202 | model.eval() 203 | with torch.no_grad(): 204 | 205 | evaluator = valid_metrics.clone().to(device) 206 | 207 | losses = [] 208 | cnt = 0. 
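# Evaluation convention used below: the network outputs per-pixel logits with
# one channel per class and the masks are one-hot encoded, so both are reduced
# with argmax over the channel dimension to (B, H, W) label maps before being
# passed to the torchmetrics collection. `losses` keeps one criterion value
# per batch and `cnt` counts the validation samples seen so far.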
209 | for batch, batch_data in enumerate(vl_dataloader): 210 | imgs = batch_data['image'] 211 | msks = batch_data['mask'] 212 | 213 | cnt += msks.shape[0] 214 | 215 | imgs = imgs.to(device) 216 | msks = msks.to(device) 217 | 218 | preds = model(imgs) 219 | loss = criterion(preds, msks) 220 | losses.append(loss.item()) 221 | 222 | 223 | preds_ = torch.argmax(preds, 1, keepdim=False).float() 224 | msks_ = torch.argmax(msks, 1, keepdim=False) 225 | evaluator.update(preds_, msks_) 226 | 227 | # _cml = f"curr_mean-loss:{np.sum(losses)/cnt:0.5f}" 228 | # _bl = f"batch-loss:{losses[-1]/msks.shape[0]:0.5f}" 229 | # iterator.set_description(f"Validation) batch:{batch+1:04d} -> {_cml}, {_bl}") 230 | 231 | # print the final results 232 | loss = np.sum(losses)/cnt 233 | metrics = evaluator.compute() 234 | 235 | return evaluator, loss 236 | 237 | 238 | # ## Define train function 239 | 240 | # In[15]: 241 | 242 | 243 | def train( 244 | model, 245 | device, 246 | tr_dataloader, 247 | vl_dataloader, 248 | config, 249 | 250 | criterion, 251 | optimizer, 252 | scheduler, 253 | 254 | save_dir='./', 255 | save_file_id=None, 256 | ): 257 | 258 | EPOCHS = tr_prms['epochs'] 259 | 260 | torch.cuda.empty_cache() 261 | model = model.to(device) 262 | 263 | evaluator = train_metrics.clone().to(device) 264 | 265 | epochs_info = [] 266 | best_model = None 267 | best_result = {} 268 | best_vl_loss = np.Inf 269 | for epoch in range(EPOCHS): 270 | model.train() 271 | 272 | evaluator.reset() 273 | tr_iterator = tqdm(enumerate(tr_dataloader)) 274 | tr_losses = [] 275 | cnt = 0 276 | for batch, batch_data in tr_iterator: 277 | imgs = batch_data['image'] 278 | msks = batch_data['mask'] 279 | 280 | imgs = imgs.to(device) 281 | msks = msks.to(device) 282 | 283 | optimizer.zero_grad() 284 | preds = model(imgs) 285 | loss = criterion(preds, msks) 286 | loss.backward() 287 | optimizer.step() 288 | 289 | # evaluate by metrics 290 | preds_ = torch.argmax(preds, 1, keepdim=False).float() 291 | msks_ = torch.argmax(msks, 1, keepdim=False) 292 | evaluator.update(preds_, msks_) 293 | 294 | cnt += imgs.shape[0] 295 | tr_losses.append(loss.item()) 296 | 297 | # write details for each training batch 298 | _cml = f"curr_mean-loss:{np.sum(tr_losses)/cnt:0.5f}" 299 | _bl = f"mean_batch-loss:{tr_losses[-1]/imgs.shape[0]:0.5f}" 300 | tr_iterator.set_description(f"Training) ep:{epoch:03d}, batch:{batch+1:04d} -> {_cml}, {_bl}") 301 | 302 | tr_loss = np.sum(tr_losses)/cnt 303 | 304 | # validate model 305 | vl_metrics, vl_loss = validate(model, criterion, vl_dataloader) 306 | if vl_loss < best_vl_loss: 307 | # find a better model 308 | best_model = model 309 | best_vl_loss = vl_loss 310 | best_result = { 311 | 'tr_loss': tr_loss, 312 | 'vl_loss': vl_loss, 313 | 'tr_metrics': make_serializeable_metrics(evaluator.compute()), 314 | 'vl_metrics': make_serializeable_metrics(vl_metrics.compute()) 315 | } 316 | 317 | # write the final results 318 | epoch_info = { 319 | 'tr_loss': tr_loss, 320 | 'vl_loss': vl_loss, 321 | 'tr_metrics': make_serializeable_metrics(evaluator.compute()), 322 | 'vl_metrics': make_serializeable_metrics(vl_metrics.compute()) 323 | } 324 | epochs_info.append(epoch_info) 325 | # epoch_tqdm.set_description(f"Epoch:{epoch+1}/{EPOCHS} -> tr_loss:{tr_loss}, vl_loss:{vl_loss}") 326 | evaluator.reset() 327 | 328 | scheduler.step(vl_loss) 329 | 330 | # save final results 331 | res = { 332 | 'id': save_file_id, 333 | 'config': config, 334 | 'epochs_info': epochs_info, 335 | 'best_result': best_result 336 | } 337 | fn = f"{save_file_id+'_' 
if save_file_id else ''}result.json" 338 | fp = os.path.join(config['model']['save_dir'],fn) 339 | with open(fp, "w") as write_file: 340 | json.dump(res, write_file, indent=4) 341 | 342 | # save model's state_dict 343 | fn = "last_model_state_dict.pt" 344 | fp = os.path.join(config['model']['save_dir'],fn) 345 | torch.save(model.state_dict(), fp) 346 | 347 | # save the best model's state_dict 348 | fn = "best_model_state_dict.pt" 349 | fp = os.path.join(config['model']['save_dir'], fn) 350 | torch.save(best_model.state_dict(), fp) 351 | 352 | return best_model, model, res 353 | 354 | 355 | # ## Define test function 356 | 357 | # In[17]: 358 | 359 | 360 | def test(model, te_dataloader): 361 | model.eval() 362 | with torch.no_grad(): 363 | evaluator = test_metrics.clone().to(device) 364 | for batch_data in tqdm(te_dataloader): 365 | imgs = batch_data['image'] 366 | msks = batch_data['mask'] 367 | 368 | imgs = imgs.to(device) 369 | msks = msks.to(device) 370 | 371 | preds = model(imgs) 372 | 373 | # evaluate by metrics 374 | preds_ = torch.argmax(preds, 1, keepdim=False).float() 375 | msks_ = torch.argmax(msks, 1, keepdim=False) 376 | evaluator.update(preds_, msks_) 377 | 378 | return evaluator 379 | 380 | 381 | # ## Load and prepare model 382 | 383 | # In[18]: 384 | 385 | 386 | from models.unet import UNet as Net 387 | 388 | 389 | model = Net(**config['model']['params']) 390 | torch.cuda.empty_cache() 391 | model = model.to(device) 392 | print("Number of parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad)) 393 | 394 | os.makedirs(config['model']['save_dir'], exist_ok=True) 395 | model_path = f"{config['model']['save_dir']}/model_state_dict.pt" 396 | 397 | if config['model']['load_weights']: 398 | model.load_state_dict(torch.load(model_path)) 399 | print("Loaded pre-trained weights...") 400 | 401 | 402 | # criterion_dice = DiceLoss() 403 | criterion_dice = DiceLossWithLogtis() 404 | # criterion_ce = BCELoss() 405 | criterion_ce = CrossEntropyLoss() 406 | 407 | 408 | def criterion(preds, masks): 409 | c_dice = criterion_dice(preds, masks) 410 | c_ce = criterion_ce(preds, masks) 411 | return 0.5*c_dice + 0.5*c_ce 412 | 413 | tr_prms = config['training'] 414 | optimizer = globals()[tr_prms['optimizer']['name']](model.parameters(), **tr_prms['optimizer']['params']) 415 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', **tr_prms['scheduler']) 416 | 417 | 418 | # ## Start traning 419 | 420 | # In[19]: 421 | 422 | 423 | best_model, model, res = train( 424 | model, 425 | device, 426 | tr_dataloader, 427 | vl_dataloader, 428 | config, 429 | 430 | criterion, 431 | optimizer, 432 | scheduler, 433 | 434 | save_dir = config['model']['save_dir'], 435 | save_file_id = None, 436 | ) 437 | 438 | 439 | # In[27]: 440 | 441 | 442 | te_metrics = test(best_model, te_dataloader) 443 | te_metrics.compute() 444 | 445 | 446 | # In[28]: 447 | 448 | 449 | f"{config['model']['save_dir']}" 450 | 451 | 452 | # # Test the best inferred model 453 | # ---- 454 | 455 | # ## Load the best model 456 | 457 | # In[29]: 458 | 459 | 460 | best_model = Net(**config['model']['params']) 461 | torch.cuda.empty_cache() 462 | best_model = best_model.to(device) 463 | 464 | fn = "best_model_state_dict.pt" 465 | os.makedirs(config['model']['save_dir'], exist_ok=True) 466 | model_path = f"{config['model']['save_dir']}/{fn}" 467 | 468 | best_model.load_state_dict(torch.load(model_path)) 469 | print("Loaded best model weights...") 470 | 471 | 472 | # ## Evaluation 473 | 474 | # In[30]: 475 | 476 
| 477 | te_metrics = test(best_model, te_dataloader) 478 | te_metrics.compute() 479 | 480 | 481 | # ## Plot graphs 482 | 483 | # In[33]: 484 | 485 | 486 | result_file_path = f"{config['model']['save_dir']}/result.json" 487 | with open(result_file_path, 'r') as f: 488 | results = json.loads(''.join(f.readlines())) 489 | epochs_info = results['epochs_info'] 490 | 491 | tr_losses = [d['tr_loss'] for d in epochs_info] 492 | vl_losses = [d['vl_loss'] for d in epochs_info] 493 | tr_dice = [d['tr_metrics']['train_metrics/Dice'] for d in epochs_info] 494 | vl_dice = [d['vl_metrics']['valid_metrics/Dice'] for d in epochs_info] 495 | tr_js = [d['tr_metrics']['train_metrics/JaccardIndex'] for d in epochs_info] 496 | vl_js = [d['vl_metrics']['valid_metrics/JaccardIndex'] for d in epochs_info] 497 | tr_acc = [d['tr_metrics']['train_metrics/Accuracy'] for d in epochs_info] 498 | vl_acc = [d['vl_metrics']['valid_metrics/Accuracy'] for d in epochs_info] 499 | 500 | 501 | _, axs = plt.subplots(1, 4, figsize=[16,3]) 502 | 503 | axs[0].set_title("Loss") 504 | axs[0].plot(tr_losses, 'r-', label="train loss") 505 | axs[0].plot(vl_losses, 'b-', label="validatiton loss") 506 | axs[0].legend() 507 | 508 | axs[1].set_title("Dice score") 509 | axs[1].plot(tr_dice, 'r-', label="train dice") 510 | axs[1].plot(vl_dice, 'b-', label="validation dice") 511 | axs[1].legend() 512 | 513 | axs[2].set_title("Jaccard Similarity") 514 | axs[2].plot(tr_js, 'r-', label="train JaccardIndex") 515 | axs[2].plot(vl_js, 'b-', label="validatiton JaccardIndex") 516 | axs[2].legend() 517 | 518 | axs[3].set_title("Accuracy") 519 | axs[3].plot(tr_acc, 'r-', label="train Accuracy") 520 | axs[3].plot(vl_acc, 'b-', label="validation Accuracy") 521 | axs[3].legend() 522 | 523 | plt.show() 524 | 525 | 526 | # In[32]: 527 | 528 | 529 | epochs_info 530 | 531 | 532 | # ## Save images 533 | 534 | # In[13]: 535 | 536 | 537 | from PIL import Image 538 | import cv2 539 | def skin_plot(img, gt, pred): 540 | img = np.array(img) 541 | gt = np.array(gt) 542 | pred = np.array(pred) 543 | edged_test = cv2.Canny(pred, 100, 255) 544 | contours_test, _ = cv2.findContours(edged_test, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 545 | edged_gt = cv2.Canny(gt, 100, 255) 546 | contours_gt, _ = cv2.findContours(edged_gt, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 547 | for cnt_test in contours_test: 548 | cv2.drawContours(img, [cnt_test], -1, (0, 0, 255), 1) 549 | for cnt_gt in contours_gt: 550 | cv2.drawContours(img, [cnt_gt], -1, (0,255,0), 1) 551 | return img 552 | 553 | #--------------------------------------------------------------------------------------------- 554 | 555 | 556 | save_imgs_dir = f"{config['model']['save_dir']}/visualized" 557 | 558 | if not os.path.isdir(save_imgs_dir): 559 | os.mkdir(save_imgs_dir) 560 | 561 | with torch.no_grad(): 562 | for batch in tqdm(te_dataloader): 563 | imgs = batch['image'] 564 | msks = batch['mask'] 565 | ids = batch['id'] 566 | 567 | preds = best_model(imgs.to(device)) 568 | 569 | txm = imgs.cpu().numpy() 570 | tbm = torch.argmax(msks, 1).cpu().numpy() 571 | tpm = torch.argmax(preds, 1).cpu().numpy() 572 | tid = ids 573 | 574 | for idx in range(len(tbm)): 575 | img = np.moveaxis(txm[idx, :3], 0, -1)*255. 576 | img = np.ascontiguousarray(img*255., dtype=np.uint8) 577 | gt = np.uint8(tbm[idx]*255.) 
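# NOTE: `img` was already scaled to the 0-255 range when it was created above,
# so the extra `*255.` inside the np.ascontiguousarray call saturates the
# saved image; the resunet and unetpp versions of this script convert the
# already-scaled array without multiplying again.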
578 | pred = np.where(tpm[idx]>0.5, 255, 0) 579 | pred = np.ascontiguousarray(pred, dtype=np.uint8) 580 | 581 | res_img = skin_plot(img, gt, pred) 582 | 583 | fid = tid[idx] 584 | Image.fromarray(img).save(f"{save_imgs_dir}/{fid}_img.png") 585 | Image.fromarray(gt).save(f"{save_imgs_dir}/{fid}_gt.png") 586 | Image.fromarray(res_img).save(f"{save_imgs_dir}/{fid}_img_gt_pred.png") 587 | 588 | 589 | # In[ ]: 590 | 591 | 592 | f"{config['model']['save_dir']}/visualized" 593 | 594 | -------------------------------------------------------------------------------- /train_and_test/isic/unetpp-isic.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # # UNet++ - ISIC2018 5 | # --- 6 | 7 | # ## Import packages & functions 8 | 9 | # In[1]: 10 | 11 | 12 | from __future__ import print_function, division 13 | 14 | 15 | import os 16 | import sys 17 | sys.path.append('../..') 18 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 19 | 20 | import copy 21 | import json 22 | import importlib 23 | import glob 24 | import pandas as pd 25 | from skimage import io, transform 26 | import matplotlib.pyplot as plt 27 | from matplotlib.image import imread 28 | import numpy as np 29 | from tqdm import tqdm 30 | 31 | import torch 32 | import torch.nn as nn 33 | import torch.nn.functional as F 34 | import torch.optim as optim 35 | import torchmetrics 36 | from torch.optim import Adam, SGD 37 | from losses import DiceLoss, DiceLossWithLogtis 38 | from torch.nn import BCELoss, CrossEntropyLoss 39 | 40 | from utils import ( 41 | show_sbs, 42 | load_config, 43 | _print, 44 | ) 45 | 46 | # Ignore warnings 47 | import warnings 48 | warnings.filterwarnings("ignore") 49 | 50 | # plt.ion() # interactive mode 51 | 52 | 53 | # ## Set the seed 54 | 55 | # In[2]: 56 | 57 | 58 | torch.manual_seed(0) 59 | np.random.seed(0) 60 | torch.cuda.manual_seed(0) 61 | import random 62 | random.seed(0) 63 | 64 | 65 | # ## Load the config 66 | 67 | # In[3]: 68 | 69 | 70 | CONFIG_NAME = "isic/isic2018_unetpp.yaml" 71 | CONFIG_FILE_PATH = os.path.join("./configs", CONFIG_NAME) 72 | 73 | 74 | # In[5]: 75 | 76 | 77 | config = load_config(CONFIG_FILE_PATH) 78 | _print("Config:", "info_underline") 79 | print(json.dumps(config, indent=2)) 80 | print(20*"~-", "\n") 81 | 82 | 83 | # ## Dataset and Dataloader 84 | 85 | # In[7]: 86 | 87 | 88 | from datasets.isic import ISIC2018DatasetFast 89 | from torch.utils.data import DataLoader, Subset 90 | from torchvision import transforms 91 | 92 | 93 | # In[8]: 94 | 95 | 96 | # ------------------- params -------------------- 97 | INPUT_SIZE = config['dataset']['input_size'] 98 | # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< 99 | 100 | 101 | # ----------------- dataset -------------------- 102 | # preparing training dataset 103 | tr_dataset = ISIC2018DatasetFast(mode="tr", one_hot=True) 104 | vl_dataset = ISIC2018DatasetFast(mode="vl", one_hot=True) 105 | te_dataset = ISIC2018DatasetFast(mode="te", one_hot=True) 106 | 107 | # We consider 1815 samples for training, 259 samples for validation and 520 samples for testing 108 | # !cat ~/deeplearning/skin/Prepare_ISIC2018.py 109 | 110 | print(f"Length of trainig_dataset:\t{len(tr_dataset)}") 111 | print(f"Length of validation_dataset:\t{len(vl_dataset)}") 112 | print(f"Length of test_dataset:\t\t{len(te_dataset)}") 113 | 114 | 115 | # prepare train dataloader 116 | tr_dataloader = DataLoader(tr_dataset, **config['data_loader']['train']) 117 | 118 | # prepare validation dataloader 119 | 
vl_dataloader = DataLoader(vl_dataset, **config['data_loader']['validation']) 120 | 121 | # prepare test dataloader 122 | te_dataloader = DataLoader(te_dataset, **config['data_loader']['test']) 123 | 124 | # -------------- test ----------------- 125 | # test and visualize the input data 126 | for sample in tr_dataloader: 127 | img = sample['image'] 128 | msk = sample['mask'] 129 | print("\n Training") 130 | show_sbs(img[0], msk[0,1]) 131 | break 132 | 133 | for sample in vl_dataloader: 134 | img = sample['image'] 135 | msk = sample['mask'] 136 | print("Validation") 137 | show_sbs(img[0], msk[0,1]) 138 | break 139 | 140 | for sample in te_dataloader: 141 | img = sample['image'] 142 | msk = sample['mask'] 143 | print("Test") 144 | show_sbs(img[0], msk[0,1]) 145 | break 146 | 147 | 148 | # ### Device 149 | 150 | # In[10]: 151 | 152 | 153 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 154 | print(f"Torch device: {device}") 155 | 156 | 157 | # ## Metrics 158 | 159 | # In[13]: 160 | 161 | 162 | metrics = torchmetrics.MetricCollection( 163 | [ 164 | torchmetrics.F1Score(), 165 | torchmetrics.Accuracy(), 166 | torchmetrics.Dice(), 167 | torchmetrics.Precision(), 168 | torchmetrics.Specificity(), 169 | torchmetrics.Recall(), 170 | # IoU 171 | torchmetrics.JaccardIndex(2) 172 | ], 173 | prefix='train_metrics/' 174 | ) 175 | 176 | # train_metrics 177 | train_metrics = metrics.clone(prefix='train_metrics/').to(device) 178 | 179 | # valid_metrics 180 | valid_metrics = metrics.clone(prefix='valid_metrics/').to(device) 181 | 182 | # test_metrics 183 | test_metrics = metrics.clone(prefix='test_metrics/').to(device) 184 | 185 | 186 | # In[14]: 187 | 188 | 189 | def make_serializeable_metrics(computed_metrics): 190 | res = {} 191 | for k, v in computed_metrics.items(): 192 | res[k] = float(v.cpu().detach().numpy()) 193 | return res 194 | 195 | 196 | # ## Define validate function 197 | 198 | # In[15]: 199 | 200 | 201 | def validate(model, criterion, vl_dataloader): 202 | model.eval() 203 | with torch.no_grad(): 204 | 205 | evaluator = valid_metrics.clone().to(device) 206 | 207 | losses = [] 208 | cnt = 0. 
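# Note on the reported loss: `losses` collects one criterion value per batch
# while `cnt` counts samples, so the value returned below, np.sum(losses)/cnt,
# depends on the batch size. The same convention is used for the training
# loss, so it remains a consistent signal for model selection and for the
# ReduceLROnPlateau scheduler.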
209 | for batch, batch_data in enumerate(vl_dataloader): 210 | imgs = batch_data['image'] 211 | msks = batch_data['mask'] 212 | 213 | cnt += msks.shape[0] 214 | 215 | imgs = imgs.to(device) 216 | msks = msks.to(device) 217 | 218 | preds = model(imgs) 219 | loss = criterion(preds, msks) 220 | losses.append(loss.item()) 221 | 222 | 223 | preds_ = torch.argmax(preds, 1, keepdim=False).float() 224 | msks_ = torch.argmax(msks, 1, keepdim=False) 225 | evaluator.update(preds_, msks_) 226 | 227 | # _cml = f"curr_mean-loss:{np.sum(losses)/cnt:0.5f}" 228 | # _bl = f"batch-loss:{losses[-1]/msks.shape[0]:0.5f}" 229 | # iterator.set_description(f"Validation) batch:{batch+1:04d} -> {_cml}, {_bl}") 230 | 231 | # print the final results 232 | loss = np.sum(losses)/cnt 233 | metrics = evaluator.compute() 234 | 235 | return evaluator, loss 236 | 237 | 238 | # ## Define train function¶ 239 | 240 | # In[16]: 241 | 242 | 243 | def train( 244 | model, 245 | device, 246 | tr_dataloader, 247 | vl_dataloader, 248 | config, 249 | 250 | criterion, 251 | optimizer, 252 | scheduler, 253 | 254 | save_dir='./', 255 | save_file_id=None, 256 | ): 257 | 258 | EPOCHS = tr_prms['epochs'] 259 | 260 | torch.cuda.empty_cache() 261 | model = model.to(device) 262 | 263 | evaluator = train_metrics.clone().to(device) 264 | 265 | epochs_info = [] 266 | best_model = None 267 | best_result = {} 268 | best_vl_loss = np.Inf 269 | for epoch in range(EPOCHS): 270 | model.train() 271 | 272 | evaluator.reset() 273 | tr_iterator = tqdm(enumerate(tr_dataloader)) 274 | tr_losses = [] 275 | cnt = 0 276 | for batch, batch_data in tr_iterator: 277 | imgs = batch_data['image'] 278 | msks = batch_data['mask'] 279 | 280 | imgs = imgs.to(device) 281 | msks = msks.to(device) 282 | 283 | optimizer.zero_grad() 284 | preds = model(imgs) 285 | loss = criterion(preds, msks) 286 | loss.backward() 287 | optimizer.step() 288 | 289 | # evaluate by metrics 290 | preds_ = torch.argmax(preds, 1, keepdim=False).float() 291 | msks_ = torch.argmax(msks, 1, keepdim=False) 292 | evaluator.update(preds_, msks_) 293 | 294 | cnt += imgs.shape[0] 295 | tr_losses.append(loss.item()) 296 | 297 | # write details for each training batch 298 | _cml = f"curr_mean-loss:{np.sum(tr_losses)/cnt:0.5f}" 299 | _bl = f"mean_batch-loss:{tr_losses[-1]/imgs.shape[0]:0.5f}" 300 | tr_iterator.set_description(f"Training) ep:{epoch:03d}, batch:{batch+1:04d} -> {_cml}, {_bl}") 301 | 302 | tr_loss = np.sum(tr_losses)/cnt 303 | 304 | # validate model 305 | vl_metrics, vl_loss = validate(model, criterion, vl_dataloader) 306 | if vl_loss < best_vl_loss: 307 | # find a better model 308 | best_model = model 309 | best_vl_loss = vl_loss 310 | best_result = { 311 | 'tr_loss': tr_loss, 312 | 'vl_loss': vl_loss, 313 | 'tr_metrics': make_serializeable_metrics(evaluator.compute()), 314 | 'vl_metrics': make_serializeable_metrics(vl_metrics.compute()) 315 | } 316 | 317 | # write the final results 318 | epoch_info = { 319 | 'tr_loss': tr_loss, 320 | 'vl_loss': vl_loss, 321 | 'tr_metrics': make_serializeable_metrics(evaluator.compute()), 322 | 'vl_metrics': make_serializeable_metrics(vl_metrics.compute()) 323 | } 324 | epochs_info.append(epoch_info) 325 | # epoch_tqdm.set_description(f"Epoch:{epoch+1}/{EPOCHS} -> tr_loss:{tr_loss}, vl_loss:{vl_loss}") 326 | evaluator.reset() 327 | 328 | scheduler.step(vl_loss) 329 | 330 | # save final results 331 | res = { 332 | 'id': save_file_id, 333 | 'config': config, 334 | 'epochs_info': epochs_info, 335 | 'best_result': best_result 336 | } 337 | fn = f"{save_file_id+'_' 
if save_file_id else ''}result.json" 338 | fp = os.path.join(config['model']['save_dir'],fn) 339 | with open(fp, "w") as write_file: 340 | json.dump(res, write_file, indent=4) 341 | 342 | # save model's state_dict 343 | fn = "last_model_state_dict.pt" 344 | fp = os.path.join(config['model']['save_dir'],fn) 345 | torch.save(model.state_dict(), fp) 346 | 347 | # save the best model's state_dict 348 | fn = "best_model_state_dict.pt" 349 | fp = os.path.join(config['model']['save_dir'], fn) 350 | torch.save(best_model.state_dict(), fp) 351 | 352 | return best_model, model, res 353 | 354 | 355 | # ## Define test function 356 | 357 | # In[18]: 358 | 359 | 360 | def test(model, te_dataloader): 361 | model.eval() 362 | with torch.no_grad(): 363 | evaluator = test_metrics.clone().to(device) 364 | for batch_data in tqdm(te_dataloader): 365 | imgs = batch_data['image'] 366 | msks = batch_data['mask'] 367 | 368 | imgs = imgs.to(device) 369 | msks = msks.to(device) 370 | 371 | preds = model(imgs) 372 | 373 | # evaluate by metrics 374 | preds_ = torch.argmax(preds, 1, keepdim=False).float() 375 | msks_ = torch.argmax(msks, 1, keepdim=False) 376 | evaluator.update(preds_, msks_) 377 | 378 | return evaluator 379 | 380 | 381 | # ## Load and prepare model 382 | 383 | # In[19]: 384 | 385 | 386 | from models.unetpp import NestedUNet as Net 387 | 388 | 389 | model = Net(**config['model']['params']) 390 | torch.cuda.empty_cache() 391 | model = model.to(device) 392 | print("Number of parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad)) 393 | 394 | os.makedirs(config['model']['save_dir'], exist_ok=True) 395 | model_path = f"{config['model']['save_dir']}/model_state_dict.pt" 396 | 397 | if config['model']['load_weights']: 398 | model.load_state_dict(torch.load(model_path)) 399 | print("Loaded pre-trained weights...") 400 | 401 | 402 | # criterion_dice = DiceLoss() 403 | criterion_dice = DiceLossWithLogtis() 404 | # criterion_ce = BCELoss() 405 | criterion_ce = CrossEntropyLoss() 406 | 407 | 408 | def criterion(preds, masks): 409 | c_dice = criterion_dice(preds, masks) 410 | c_ce = criterion_ce(preds, masks) 411 | return 0.5*c_dice + 0.5*c_ce 412 | 413 | tr_prms = config['training'] 414 | optimizer = globals()[tr_prms['optimizer']['name']](model.parameters(), **tr_prms['optimizer']['params']) 415 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', **tr_prms['scheduler']) 416 | 417 | 418 | # ## Start traning 419 | 420 | # In[20]: 421 | 422 | 423 | best_model, model, res = train( 424 | model, 425 | device, 426 | tr_dataloader, 427 | vl_dataloader, 428 | config, 429 | 430 | criterion, 431 | optimizer, 432 | scheduler, 433 | 434 | save_dir = config['model']['save_dir'], 435 | save_file_id = None, 436 | ) 437 | 438 | 439 | # In[21]: 440 | 441 | 442 | te_metrics = test(best_model, te_dataloader) 443 | te_metrics.compute() 444 | 445 | 446 | # In[22]: 447 | 448 | 449 | f"{config['model']['save_dir']}" 450 | 451 | 452 | # # Test the best inferred model 453 | # ---- 454 | 455 | # ## Load the best model 456 | 457 | # In[23]: 458 | 459 | 460 | best_model = Net(**config['model']['params']) 461 | torch.cuda.empty_cache() 462 | best_model = best_model.to(device) 463 | 464 | fn = "best_model_state_dict.pt" 465 | os.makedirs(config['model']['save_dir'], exist_ok=True) 466 | model_path = f"{config['model']['save_dir']}/{fn}" 467 | 468 | best_model.load_state_dict(torch.load(model_path)) 469 | print("Loaded best model weights...") 470 | 471 | 472 | # ## Evaluation 473 | 474 | # In[24]: 
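# Added sketch (not part of the original notebook): the cell below re-runs the
# test loop on the reloaded best model; the snippet here shows one way to also
# persist the metric values already computed earlier, reusing the
# make_serializeable_metrics() helper defined in this script. The file name
# "test_metrics.json" is an assumption rather than a repository convention.
te_scores = make_serializeable_metrics(te_metrics.compute())
with open(os.path.join(config['model']['save_dir'], "test_metrics.json"), "w") as fp_out:
    json.dump(te_scores, fp_out, indent=4)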
475 | 476 | 477 | te_metrics = test(best_model, te_dataloader) 478 | te_metrics.compute() 479 | 480 | 481 | # ## Plot graphs 482 | 483 | # In[25]: 484 | 485 | 486 | result_file_path = f"{config['model']['save_dir']}/result.json" 487 | with open(result_file_path, 'r') as f: 488 | results = json.loads(''.join(f.readlines())) 489 | epochs_info = results['epochs_info'] 490 | 491 | tr_losses = [d['tr_loss'] for d in epochs_info] 492 | vl_losses = [d['vl_loss'] for d in epochs_info] 493 | tr_dice = [d['tr_metrics']['train_metrics/Dice'] for d in epochs_info] 494 | vl_dice = [d['vl_metrics']['valid_metrics/Dice'] for d in epochs_info] 495 | tr_js = [d['tr_metrics']['train_metrics/JaccardIndex'] for d in epochs_info] 496 | vl_js = [d['vl_metrics']['valid_metrics/JaccardIndex'] for d in epochs_info] 497 | tr_acc = [d['tr_metrics']['train_metrics/Accuracy'] for d in epochs_info] 498 | vl_acc = [d['vl_metrics']['valid_metrics/Accuracy'] for d in epochs_info] 499 | 500 | 501 | _, axs = plt.subplots(1, 4, figsize=[16,3]) 502 | 503 | axs[0].set_title("Loss") 504 | axs[0].plot(tr_losses, 'r-', label="train loss") 505 | axs[0].plot(vl_losses, 'b-', label="validatiton loss") 506 | axs[0].legend() 507 | 508 | axs[1].set_title("Dice score") 509 | axs[1].plot(tr_dice, 'r-', label="train dice") 510 | axs[1].plot(vl_dice, 'b-', label="validation dice") 511 | axs[1].legend() 512 | 513 | axs[2].set_title("Jaccard Similarity") 514 | axs[2].plot(tr_js, 'r-', label="train JaccardIndex") 515 | axs[2].plot(vl_js, 'b-', label="validatiton JaccardIndex") 516 | axs[2].legend() 517 | 518 | axs[3].set_title("Accuracy") 519 | axs[3].plot(tr_acc, 'r-', label="train Accuracy") 520 | axs[3].plot(vl_acc, 'b-', label="validation Accuracy") 521 | axs[3].legend() 522 | 523 | plt.show() 524 | 525 | 526 | # In[ ]: 527 | 528 | 529 | epochs_info 530 | 531 | 532 | # ## Save images 533 | 534 | # In[ ]: 535 | 536 | 537 | from PIL import Image 538 | import cv2 539 | def skin_plot(img, gt, pred): 540 | img = np.array(img) 541 | gt = np.array(gt) 542 | pred = np.array(pred) 543 | edged_test = cv2.Canny(pred, 100, 255) 544 | contours_test, _ = cv2.findContours(edged_test, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 545 | edged_gt = cv2.Canny(gt, 100, 255) 546 | contours_gt, _ = cv2.findContours(edged_gt, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 547 | for cnt_test in contours_test: 548 | cv2.drawContours(img, [cnt_test], -1, (0, 0, 255), 1) 549 | for cnt_gt in contours_gt: 550 | cv2.drawContours(img, [cnt_gt], -1, (0,255,0), 1) 551 | return img 552 | 553 | #--------------------------------------------------------------------------------------------- 554 | 555 | 556 | save_imgs_dir = f"{config['model']['save_dir']}/visualized" 557 | 558 | if not os.path.isdir(save_imgs_dir): 559 | os.mkdir(save_imgs_dir) 560 | 561 | with torch.no_grad(): 562 | for batch in tqdm(te_dataloader): 563 | imgs = batch['image'] 564 | msks = batch['mask'] 565 | ids = batch['id'] 566 | 567 | preds = best_model(imgs.to(device)) 568 | 569 | txm = imgs.cpu().numpy() 570 | tbm = torch.argmax(msks, 1).cpu().numpy() 571 | tpm = torch.argmax(preds, 1).cpu().numpy() 572 | tid = ids 573 | 574 | for idx in range(len(tbm)): 575 | img = np.moveaxis(txm[idx, :3], 0, -1)*255. 576 | img = np.ascontiguousarray(img, dtype=np.uint8) 577 | gt = np.uint8(tbm[idx]*255.) 
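# `tpm` comes from torch.argmax over the two output channels, so its entries
# are already 0/1 class indices; the `> 0.5` comparison below simply selects
# the foreground class and np.where maps it to a 0/255 uint8 mask that
# cv2.Canny and the contour drawing in skin_plot() can work with.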
578 | pred = np.where(tpm[idx]>0.5, 255, 0) 579 | pred = np.ascontiguousarray(pred, dtype=np.uint8) 580 | 581 | res_img = skin_plot(img, gt, pred) 582 | 583 | fid = tid[idx] 584 | Image.fromarray(img).save(f"{save_imgs_dir}/{fid}_img.png") 585 | Image.fromarray(res_img).save(f"{save_imgs_dir}/{fid}_img_gt_pred.png") 586 | 587 | 588 | # In[ ]: 589 | 590 | 591 | f"{config['model']['save_dir']}/visualized" 592 | 593 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | class bcolors: 2 | HEADER = '\033[95m' 3 | OKBLUE = '\033[94m' 4 | OKCYAN = '\033[96m' 5 | OKGREEN = '\033[92m' 6 | WARNING = '\033[93m' 7 | FAIL = '\033[91m' 8 | ENDC = '\033[0m' 9 | BOLD = '\033[1m' 10 | UNDERLINE = '\033[4m' 11 | 12 | 13 | 14 | from termcolor import colored 15 | def _print(string, p=None): 16 | if not p: 17 | print(string) 18 | return 19 | pre = f"{bcolors.ENDC}" 20 | 21 | if "bold" in p.lower(): 22 | pre += bcolors.BOLD 23 | elif "underline" in p.lower(): 24 | pre += bcolors.UNDERLINE 25 | elif "header" in p.lower(): 26 | pre += bcolors.HEADER 27 | 28 | if "warning" in p.lower(): 29 | pre += bcolors.WARNING 30 | elif "error" in p.lower(): 31 | pre += bcolors.FAIL 32 | elif "ok" in p.lower(): 33 | pre += bcolors.OKGREEN 34 | elif "info" in p.lower(): 35 | if "blue" in p.lower(): 36 | pre += bcolors.OKBLUE 37 | else: 38 | pre += bcolors.OKCYAN 39 | 40 | print(f"{pre}{string}{bcolors.ENDC}") 41 | 42 | 43 | 44 | import yaml 45 | def load_config(config_filepath): 46 | try: 47 | with open(config_filepath, 'r') as file: 48 | config = yaml.safe_load(file) 49 | return config 50 | except FileNotFoundError: 51 | _print(f"Config file not found! <{config_filepath}>", "error_bold") 52 | exit(1) 53 | 54 | 55 | 56 | import numpy as np 57 | from matplotlib import pyplot as plt 58 | def show_sbs(im1, im2, figsize=[8,4], im1_title="Image", im2_title="Mask", show=True): 59 | if im1.shape[0]<4: 60 | im1 = np.array(im1) 61 | im1 = np.transpose(im1, [1, 2, 0]) 62 | 63 | if im2.shape[0]<4: 64 | im2 = np.array(im2) 65 | im2 = np.transpose(im2, [1, 2, 0]) 66 | 67 | _, axs = plt.subplots(1, 2, figsize=figsize) 68 | axs[0].imshow(im1) 69 | axs[0].set_title(im1_title) 70 | axs[1].imshow(im2, cmap='gray') 71 | axs[1].set_title(im2_title) 72 | if show: plt.show() --------------------------------------------------------------------------------