├── .gitignore
├── README.md
├── configs
│   ├── isic
│   │   ├── isic2018_attunet.yaml
│   │   ├── isic2018_missformer.yaml
│   │   ├── isic2018_multiresunet.yaml
│   │   ├── isic2018_resunet.yaml
│   │   ├── isic2018_transunet.yaml
│   │   ├── isic2018_uctransnet.yaml
│   │   ├── isic2018_unet.yaml
│   │   └── isic2018_unetpp.yaml
│   └── segpc
│       ├── segpc2021_attunet.yaml
│       ├── segpc2021_missformer.yaml
│       ├── segpc2021_multiresunet.yaml
│       ├── segpc2021_resunet.yaml
│       ├── segpc2021_transunet.yaml
│       ├── segpc2021_uctransnet.yaml
│       ├── segpc2021_unet.yaml
│       └── segpc2021_unetpp.yaml
├── datasets
│   ├── README.md
│   ├── __init__.py
│   ├── isic.ipynb
│   ├── isic.py
│   ├── prepare_isic.ipynb
│   ├── prepare_segpc.ipynb
│   ├── segpc.ipynb
│   └── segpc.py
├── images
│   ├── ComparisonOfModels.png
│   ├── U-Net_Taxonomy.png
│   ├── isic2018.png
│   ├── isic2018_sample.png
│   ├── segpc.png
│   ├── segpc2021_sample.png
│   ├── synapse.png
│   └── unet-pipeline.png
├── losses.py
├── models
│   ├── __init__.py
│   ├── _missformer
│   │   ├── MISSFormer.py
│   │   ├── __init__.py
│   │   └── segformer.py
│   ├── _resunet
│   │   ├── __init__.py
│   │   ├── modules.py
│   │   └── res_unet.py
│   ├── _transunet
│   │   ├── vit_seg_configs.py
│   │   ├── vit_seg_modeling.py
│   │   ├── vit_seg_modeling_c4.py
│   │   ├── vit_seg_modeling_resnet_skip.py
│   │   └── vit_seg_modeling_resnet_skip_c4.py
│   ├── _uctransnet
│   │   ├── CTrans.py
│   │   ├── Config.py
│   │   ├── UCTransNet.py
│   │   └── UNet.py
│   ├── attunet.py
│   ├── multiresunet.py
│   ├── unet.py
│   └── unetpp.py
├── train_and_test
│   ├── isic
│   │   ├── attunet-isic.ipynb
│   │   ├── attunet-isic.py
│   │   ├── missformer-isic.ipynb
│   │   ├── missformer-isic.py
│   │   ├── multiresunet-isic.ipynb
│   │   ├── multiresunet-isic.py
│   │   ├── resunet-isic.ipynb
│   │   ├── resunet-isic.py
│   │   ├── transunet-isic.ipynb
│   │   ├── transunet-isic.py
│   │   ├── uctransnet-isic.ipynb
│   │   ├── uctransnet-isic.py
│   │   ├── unet-isic.ipynb
│   │   ├── unet-isic.py
│   │   ├── unetpp-isic.ipynb
│   │   └── unetpp-isic.py
│   └── segpc
│       ├── attunet-segpc.ipynb
│       ├── attunet-segpc.py
│       ├── missformer-segpc.ipynb
│       ├── missformer-segpc.py
│       ├── multiresunet-segpc.ipynb
│       ├── multiresunet-segpc.py
│       ├── resunet-segpc.ipynb
│       ├── resunet-segpc.py
│       ├── transunet-segpc.ipynb
│       ├── transunet-segpc.py
│       ├── uctransnet-segpc.ipynb
│       ├── uctransnet-segpc.py
│       ├── unet-segpc.ipynb
│       ├── unet-segpc.py
│       ├── unetpp-segpc.ipynb
│       └── unetpp-segpc.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | # db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | #vscode
132 | .vscode
133 |
134 |
135 | **.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Awesome U-Net
2 |
3 | [](https://github.com/hee9joon/Awesome-Diffusion-Models)
4 | [](https://opensource.org/licenses/MIT)
5 |
6 | Official repo for [Medical Image Segmentation Review: The Success of U-Net](https://ieeexplore.ieee.org/abstract/document/10643318)
7 |
8 |
9 |
10 |
11 |
12 | ### Announcements
13 |
14 | August 21, 2024: The final draft is published in [IEEE TPAMI](https://ieeexplore.ieee.org/abstract/document/10643318) :fire::fire:
15 |
16 | November 27, 2022: [arXiv](https://arxiv.org/abs/2211.14830) release version.
17 |
18 | #### Citation
19 |
20 | ```latex
21 | @article{azad2024medical,
22 | author={Azad, Reza and Aghdam, Ehsan Khodapanah and Rauland, Amelie and Jia, Yiwei and Avval, Atlas Haddadi and Bozorgpour, Afshin and Karimijafarbigloo, Sanaz and Cohen, Joseph Paul and Adeli, Ehsan and Merhof, Dorit},
23 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
24 | title={Medical Image Segmentation Review: The Success of U-Net},
25 | year={2024},
26 | pages={1-20},
27 | keywords={Image segmentation;Biomedical imaging;Taxonomy;Computer architecture;Feature extraction;Transformers;Task analysis;Medical Image Segmentation;Deep Learning;U-Net;Convolutional Neural Network;Transformer},
28 | doi={10.1109/TPAMI.2024.3435571}
29 | }
30 | ```
31 |
32 | ---
33 |
34 | ### Abstract
35 |
36 | Automatic medical image segmentation is a crucial topic in the medical domain and successively a critical counterpart in the computer-aided diagnosis paradigm. U-Net is the most widespread image segmentation architecture due to its flexibility, optimized modular design, and success in all medical image modalities. Over the years, the U-Net model achieved tremendous attention from academic and industrial researchers. Several extensions of this network have been proposed to address the scale and complexity created by medical tasks. Addressing the deficiency of the naive U-Net model is the foremost step for vendors to utilize the proper U-Net variant model for their business. Having a compendium of different variants in one place makes it easier for builders to identify the relevant research. Also, for ML researchers it will help them understand the challenges of the biological tasks that challenge the model. To address this, we discuss the practical aspects of the U-Net model and suggest a taxonomy to categorize each network variant. Moreover, to measure the performance of these strategies in a clinical application, we propose fair evaluations of some unique and famous designs on well-known datasets. We provide a comprehensive implementation library with trained models for future research. In addition, for ease of future studies, we created an online list of U-Net papers with their possible official implementation.
37 |
38 |
39 |
40 |
41 |
42 | ---
43 |
44 |
45 |
46 | ### The structure of codes
47 |
48 | Here is the structure of the codes in this repository:
49 |
50 | ```bash
51 | .
52 | ├── README.md
53 | ├── images
54 | │   └── *.png
55 | ├── configs
56 | │   ├── isic
57 | │   │   ├── isic2018_*.yaml
58 | │   └── segpc
59 | │       └── segpc2021_*.yaml
60 | ├── datasets
61 | │   ├── *.py
62 | │   ├── *.ipynb
63 | │   └── prepare_*.ipynb
64 | ├── models
65 | │   ├── *.py
66 | │   └── _*
67 | │       └── *.py
68 | ├── train_and_test
69 | │   ├── isic
70 | │   │   ├── *-isic.ipynb
71 | │   │   └── *-isic.py
72 | │   └── segpc
73 | │       ├── *-segpc.ipynb
74 | │       └── *-segpc.py
75 | ├── losses.py
76 | └── utils.py
77 | ```
78 |
79 |
80 |
81 | ## Dataset preparation
82 |
83 | Please see ["./datasets/README.md"](https://github.com/NITR098/Awesome-U-Net/blob/main/datasets/README.md) for details. We used three datasets in this work. After preparing the required data, put the corresponding data paths in the relevant config files.
84 |
85 |
86 |
87 | ## Train and Test
88 |
89 | The `train_and_test` folder contains one subfolder per dataset. Each subfolder holds the files for each network in two formats (`.py` and `.ipynb`). The notebook files walk through the following procedure and contain both the training and testing steps (a minimal sketch of this flow is given after the list).
90 |
91 | - Preparation step
92 | - Import packages & functions
93 | - Set the seed
94 | - Load the config file
95 | - Dataset and Dataloader
96 | - Prepare Metrics
97 | - Define test and validate function
98 | - Load and prepare model
99 | - Training
100 | - Save the best model
101 | - Test the best inferred model
102 | - Load the best model
103 | - Evaluation
104 | - Plot graphs and print results
105 | - Save images
106 |
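A minimal, hedged sketch of this flow is shown below. It is illustrative only: the exact module paths, class names, and helper functions used in the notebooks may differ, so treat the repo-local imports (`models.unet.UNet`, `datasets.isic.ISIC2018DatasetFast`, `losses.DiceLoss`) and the checkpoint file name as assumptions.

```python
import os
import yaml
import torch
from torch.utils.data import DataLoader

from models.unet import UNet                   # assumed import path (config: model.name = 'UNet')
from datasets.isic import ISIC2018DatasetFast  # expects the pre-built .npy files (see datasets/README.md)
from losses import DiceLoss

with open("configs/isic/isic2018_unet.yaml") as f:
    config = yaml.safe_load(f)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Dataset and DataLoader
tr_dataset = ISIC2018DatasetFast(mode="tr", data_dir=config["dataset"]["training"]["params"]["data_dir"])
tr_loader = DataLoader(tr_dataset, **config["data_loader"]["train"])

# Load and prepare model, criterion, optimizer, and scheduler from the config
model = UNet(**config["model"]["params"]).to(device)
criterion = DiceLoss(**config["training"]["criterion"]["params"])
optimizer = torch.optim.Adam(model.parameters(), **config["training"]["optimizer"]["params"])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=config["training"]["scheduler"]["factor"],
    patience=config["training"]["scheduler"]["patience"],
)

# Training loop (validation, metrics, and plotting are omitted; the notebooks select
# the best checkpoint on validation performance rather than on the training loss)
save_dir = config["model"]["save_dir"]
os.makedirs(save_dir, exist_ok=True)
best_loss = float("inf")
for epoch in range(config["training"]["epochs"]):
    model.train()
    epoch_loss = 0.0
    for batch in tr_loader:
        imgs = batch["image"].to(device).float()
        msks = batch["mask"].to(device)
        optimizer.zero_grad()
        loss = criterion(model(imgs), msks)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    scheduler.step(epoch_loss)
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), os.path.join(save_dir, "best_model_state_dict.pt"))  # hypothetical file name
```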
107 |
108 |
109 | ### Pretrained model weights
110 |
111 | Here you can download pre-trained weights for each network (a minimal loading sketch follows the table).
112 |
113 | | Network | Model Weight | Train and Test File |
114 | | ------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
115 | | **U-Net** | [ISIC2018](https://mega.nz/file/pNd0xLIB#LqY-e-hdQhq6_dQZpAw_7MxKclMB5DAFMybL5w99OzM) - [SegPC2021](https://mega.nz/file/EZEjTYyT#UMsliboXuqrsobGHV_mn4jiBrOf_dMZF7hp2aY0o2hI) | [ISIC2018](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/isic/unet-isic.ipynb) - [SegPC2021](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/segpc/unet-segpc.ipynb) |
116 | | **Att-UNet** | [ISIC2018](https://mega.nz/file/5VsBTKgK#vNu_nvuz-9Lktw6aMOxuguQyim1sVnG4QdkGtVX3pEs) - [SegPC2021](https://mega.nz/file/gRVCXCgT#We3_nPsx_xIBXy6-bsg85rQYYzKHut17Zn5HDnh0Aqw) | [ISIC2018](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/isic/attunet-isic.ipynb) - [SegPC2021](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/segpc/attunet-segpc.ipynb) |
117 | | **U-Net++** | [ISIC2018](https://mega.nz/file/NcFQUY5D#1mSGOC4GGTA8arWzcM77yyH9GoApciw0mB4pFp18n0Q) - [SegPC2021](https://mega.nz/file/JFVSHLxY#EwPpPZ5N0KDaXhDXxyyuQ_HaD2iNiv5hdqplznrP8Os) | [ISIC2018](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/isic/unetpp-isic.ipynb) - [SegPC2021](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/segpc/unetpp-segpc.ipynb) |
118 | | **MultiResUNet** | [ISIC2018](https://mega.nz/file/tIEVAAba#t-5vLCMwlH6hzAri7DJ8ut-eT2vFN5b6qj6Vc3By6_g) - [SegPC2021](https://mega.nz/file/tUN11R5C#I_JpAT7mYDM1q40ulp8TJxnzHFR4Fh3WX_klep62ywE) | [ISIC2018](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/isic/multiresunet-isic.ipynb) - [SegPC2021](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/segpc/multiresunet-segpc.ipynb) |
119 | | **Residual U-Net** | [ISIC2018](https://mega.nz/file/NAVHSSJa#FwcYG6bKOdpcEorN_nnjWFEx29toSspSiMzFTqIrVW4) - [SegPC2021](https://mega.nz/file/gQ91WBRB#mzIAeEUze4cAi74dMa3rqivGdYtzpKqDI16vNao7-6A) | [ISIC2018](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/isic/resunet-isic.ipynb) - [SegPC2021](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/segpc/resunet-segpc.ipynb) |
120 | | **TransUNet** | [ISIC2018](https://mega.nz/file/UM9jkK6B#7rFd9TiOY6pEGt-gDosFopdV78slgpHbj_wKZ4H39OM) - [SegPC2021](https://mega.nz/file/5YFBXBoZ#6S8B6MyAsSsr5cNw0-QIIIzF6CgxhEUsOl0xwAknTr8) | [ISIC2018](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/isic/transunet-isic.ipynb) - [SegPC2021](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/segpc/transunet-segpc.ipynb) |
121 | | **UCTransNet** | [ISIC2018](https://mega.nz/file/RMNQmKoQ#j8zGEuud33eh-tOIZa1dpkReB8DYKt1De75eeR7wLnM) - [SegPC2021](https://mega.nz/file/hYMShICa#kg5VFhE-m5X0ouE1rc_teaYSb_E15NpbBVE0P_V7WH8) | [ISIC2018](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/isic/uctransnet-isic.ipynb) - [SegPC2021](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/segpc/uctransnet-segpc.ipynb) |
122 | | **MISSFormer** | [ISIC2018](https://mega.nz/file/EANRiBoQ#E2LC0ZS7LU5OuEdQJ8dYGihjzqpEEotUqLEnEGZ59wU) - [SegPC2021](https://mega.nz/file/9I1CUJbZ#V6zdx8vZDyPJjHmVgoJH4D86sTuqNu6OuHeUQVB6ees) | [ISIC2018](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/isic/missformer-isic.ipynb) - [SegPC2021](https://github.com/NITR098/Awesome-U-Net/blob/main/train_and_test/segpc/missformer-segpc.ipynb) |
123 |
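Below is a minimal sketch of loading one of these checkpoints for inference. It assumes the downloaded file can be read with `torch.load` and contains either a plain `state_dict` or a dictionary wrapping one; inspect the keys of your download first, and adjust the model class and its parameters to the network and dataset you picked (the SegPC models use 4 input channels).

```python
import torch
from models.unet import UNet  # assumed import path; use the class matching the downloaded weights

device = "cuda" if torch.cuda.is_available() else "cpu"
model = UNet(in_channels=3, out_channels=2, with_bn=False).to(device)  # params from configs/isic/isic2018_unet.yaml

checkpoint = torch.load("unet_isic2018.pt", map_location=device)  # hypothetical file name for the download
state_dict = checkpoint.get("model_state_dict", checkpoint)       # handle both wrapped and plain state_dicts
model.load_state_dict(state_dict)
model.eval()
```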
124 |
125 | ## Results
126 |
127 | To evaluate the performance of the aforementioned methods, three challenging medical image segmentation tasks have been considered. The results are presented below.
128 |
129 |
130 |
131 | Performance comparison on ***ISIC 2018*** dataset (best results are bolded).
132 |
133 | | Methods | AC | PR | SE | SP | Dice | IoU |
134 | | ------------------ | ---------- | ---------- | ---------- | ---------- | ---------- | ---------- |
135 | | **U-Net** | 0.9446 | 0.8746 | 0.8603 | 0.9671 | 0.8674 | 0.8491 |
136 | | **Att-UNet** | 0.9516 | 0.9075 | 0.8579 | 0.9766 | 0.8820 | 0.8649 |
137 | | **U-Net++** | 0.9517 | 0.9067 | 0.8590 | 0.9764 | 0.8822 | 0.8651 |
138 | | **MultiResUNet** | 0.9473 | 0.8765 | 0.8689 | 0.9704 | 0.8694 | 0.8537 |
139 | | **Residual U-Net** | 0.9468 | 0.8753 | 0.8659 | 0.9688 | 0.8689 | 0.8509 |
140 | | **TransUNet** | 0.9452 | 0.8823 | 0.8578 | 0.9653 | 0.8499 | 0.8365 |
141 | | **UCTransNet** | **0.9546** | **0.9100** | **0.8704** | **0.9770** | **0.8898** | **0.8729** |
142 | | **MISSFormer** | 0.9453 | 0.8964 | 0.8371 | 0.9742 | 0.8657 | 0.8484 |
143 |
144 |
145 |
146 | Performance comparison on ***SegPC 2021*** dataset (best results are bolded).
147 |
148 | | Methods | AC | PR | SE | SP | Dice | IoU |
149 | | ------------------ | ---------- | ---------- | ---------- | ---------- | ---------- | ---------- |
150 | | **U-Net** | 0.9795 | 0.9084 | 0.8548 | 0.9916 | 0.8808 | 0.8824 |
151 | | **Att-UNet** | 0.9854 | 0.9360 | 0.8964 | 0.9940 | 0.9158 | 0.9144 |
152 | | **U-Net++** | 0.9845 | 0.9328 | 0.8887 | 0.9938 | 0.9102 | 0.9092 |
153 | | **MultiResUNet** | 0.9753 | 0.8391 | 0.8925 | 0.9834 | 0.8649 | 0.8676 |
154 | | **Residual U-Net** | 0.9743 | 0.8920 | 0.8080 | 0.9905 | 0.8479 | 0.8541 |
155 | | **TransUNet** | 0.9702 | 0.8678 | 0.7831 | 0.9884 | 0.8233 | 0.8338 |
156 | | **UCTransNet** | **0.9857** | **0.9365** | **0.8991** | **0.9941** | **0.9174** | **0.9159** |
157 | | **MISSFormer** | 0.9663 | 0.8152 | 0.8014 | 0.9823 | 0.8082 | 0.8209 |
158 |
159 |
160 |
161 | Performance comparison on ***Synapse*** dataset (best results are bolded).
162 | | Method | DSC↑ | HD↓ | Aorta | Gallbladder | Kidney(L) | Kidney(R) | Liver | Pancreas | Spleen | Stomach |
163 | | ------------------ | --------- | --------- | --------- | ----------- | --------- | --------- | --------- | --------- | --------- | --------- |
164 | | **U-Net** | 76.85 | 39.70 | 89.07 | 69.72 | 77.77 | 68.60 | 93.43 | 53.98 | 86.67 | 75.58 |
165 | | **Att-UNet** | 77.77 | 36.02 | **89.55** | **68.88** | 77.98 | 71.11 | 93.57 | 58.04 | 87.30 | 75.75 |
166 | | **U-Net++** | 76.91 | 36.93 | 88.19 | 65.89 | 81.76 | 74.27 | 93.01 | 58.20 | 83.44 | 70.52 |
167 | | **MultiResUNet** | 77.42 | 36.84 | 87.73 | 65.67 | 82.08 | 70.43 | 93.49 | 60.09 | 85.23 | 74.66 |
168 | | **Residual U-Net** | 76.95 | 38.44 | 87.06 | 66.05 | 83.43 | 76.83 | 93.99 | 51.86 | 85.25 | 70.13 |
169 | | **TransUNet** | 77.48 | 31.69 | 87.23 | 63.13 | 81.87 | 77.02 | 94.08 | 55.86 | 85.08 | 75.62 |
170 | | **UCTransNet** | 78.23 | 26.75 | 84.25 | 64.65 | 82.35 | 77.65 | 94.36 | 58.18 | 84.74 | 79.66 |
171 | | **MISSFormer** | **81.96** | **18.20** | 86.99 | 68.65 | **85.21** | **82.00** | **94.41** | **65.67** | **91.92** | **80.81** |
172 |
173 | ### Visualization
174 |
175 | - **Results on ISIC 2018**
176 |
177 | 
178 |
179 | Visual comparisons of different methods on the *ISIC 2018* skin lesion segmentation dataset. Ground truth boundaries are shown in green, and predicted boundaries are shown in blue.
180 |
181 | - **Result on SegPC 2021**
182 |
183 | 
184 |
185 | Visual comparisons of different methods on the *SegPC 2021* cell segmentation dataset. The red region indicates the cytoplasm and the blue region denotes the nucleus of each cell.
186 |
187 | - **Result on Synapse**
188 |
189 | 
190 |
191 | Visual comparisons of different methods on the *Synapse* multi-organ segmentation dataset.
192 |
193 | ## References
194 |
195 | ### Codes [GitHub Pages]
196 |
197 | - AttU-Net: [https://github.com/LeeJunHyun/Image_Segmentation](https://github.com/LeeJunHyun/Image_Segmentation)
198 |
199 | - U-Net++: [https://github.com/4uiiurz1/pytorch-nested-unet/blob/master/archs.py](https://github.com/4uiiurz1/pytorch-nested-unet/blob/master/archs.py)
200 |
201 | - MultiResUNet: https://github.com/j-sripad/mulitresunet-pytorch/blob/main/multiresunet.py
202 |
203 | - Residual U-Net: https://github.com/rishikksh20/ResUnet
204 |
205 | - TransUNet: https://github.com/Beckschen/TransUNet
206 |
207 | - UCTransNet: https://github.com/McGregorWwww/UCTransNet
208 |
209 | - MISSFormer: https://github.com/ZhifangDeng/MISSFormer
210 |
211 | ### Query
212 |
213 | For any query, please contact us.
214 |
215 | ```
216 | rezazad68@gmail.com
217 | engtekh@gmail.com
218 | afshinbigboy@gmail.com
219 | ```
220 |
--------------------------------------------------------------------------------
/configs/isic/isic2018_attunet.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | mode: 'train'
3 | device: 'gpu'
4 | transforms: none
5 | dataset:
6 | class_name: "ISIC2018Dataset"
7 | input_size: 224
8 | training:
9 | params:
10 | data_dir: "/path/to/datasets/ISIC2018"
11 | validation:
12 | params:
13 | data_dir: "/path/to/datasets/ISIC2018"
14 | number_classes: 2
15 | data_loader:
16 | train:
17 | batch_size: 16
18 | shuffle: true
19 | num_workers: 8
20 | pin_memory: true
21 | validation:
22 | batch_size: 16
23 | shuffle: false
24 | num_workers: 8
25 | pin_memory: true
26 | test:
27 | batch_size: 16
28 | shuffle: false
29 | num_workers: 4
30 | pin_memory: false
31 | training:
32 | optimizer:
33 | name: 'Adam'
34 | params:
35 | lr: 0.0001
36 | criterion:
37 | name: "DiceLoss"
38 | params: {}
39 | scheduler:
40 | factor: 0.5
41 | patience: 10
42 | epochs: 100
43 | model:
44 | save_dir: '../../saved_models/isic2018_attunet'
45 | load_weights: false
46 | name: 'AttU_Net'
47 | params:
48 | img_ch: 3
49 | output_ch: 2
50 | # preprocess:
--------------------------------------------------------------------------------
/configs/isic/isic2018_missformer.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | mode: 'train'
3 | device: 'gpu'
4 | transforms: none
5 | dataset:
6 | class_name: "ISIC2018Dataset"
7 | input_size: 224
8 | training:
9 | params:
10 | data_dir: "/path/to/datasets/ISIC2018"
11 | validation:
12 | params:
13 | data_dir: "/path/to/datasets/ISIC2018"
14 | number_classes: 2
15 | data_loader:
16 | train:
17 | batch_size: 16
18 | shuffle: true
19 | num_workers: 8
20 | pin_memory: true
21 | validation:
22 | batch_size: 16
23 | shuffle: false
24 | num_workers: 8
25 | pin_memory: true
26 | test:
27 | batch_size: 16
28 | shuffle: false
29 | num_workers: 4
30 | pin_memory: false
31 | training:
32 | optimizer:
33 | name: 'SGD'
34 | params:
35 | lr: 0.0001
36 | momentum: 0.9
37 | weight_decay: 0.0001
38 | criterion:
39 | name: "DiceLoss"
40 | params: {}
41 | scheduler:
42 | factor: 0.5
43 | patience: 10
44 | epochs: 300
45 | model:
46 | save_dir: '../../saved_models/isic2018_missformer'
47 | load_weights: false
48 | name: "MISSFormer"
49 | params:
50 | in_ch: 3
51 | num_classes: 2
52 | # preprocess:
--------------------------------------------------------------------------------
/configs/isic/isic2018_multiresunet.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | mode: 'train'
3 | device: 'gpu'
4 | transforms: none
5 | dataset:
6 | class_name: "ISIC2018Dataset"
7 | input_size: 224
8 | training:
9 | params:
10 | data_dir: "/path/to/datasets/ISIC2018"
11 | validation:
12 | params:
13 | data_dir: "/path/to/datasets/ISIC2018"
14 | number_classes: 2
15 | data_loader:
16 | train:
17 | batch_size: 16
18 | shuffle: true
19 | num_workers: 2
20 | pin_memory: true
21 | validation:
22 | batch_size: 16
23 | shuffle: false
24 | num_workers: 2
25 | pin_memory: true
26 | test:
27 | batch_size: 16
28 | shuffle: false
29 | num_workers: 2
30 | pin_memory: false
31 | training:
32 | optimizer:
33 | name: 'Adam'
34 | params:
35 | lr: 0.0005
36 | criterion:
37 | name: "DiceLoss"
38 | params: {}
39 | scheduler:
40 | factor: 0.5
41 | patience: 10
42 | epochs: 100
43 | model:
44 | save_dir: '../../saved_models/isic2018_multiresunet'
45 | load_weights: false
46 | name: 'MultiResUnet'
47 | params:
48 | channels: 3
49 | filters: 32
50 | nclasses: 2
51 | # preprocess:
--------------------------------------------------------------------------------
/configs/isic/isic2018_resunet.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | mode: 'train'
3 | device: 'gpu'
4 | transforms: none
5 | dataset:
6 | class_name: "ISIC2018Dataset"
7 | input_size: 224
8 | training:
9 | params:
10 | data_dir: "/path/to/datasets/ISIC2018"
11 | validation:
12 | params:
13 | data_dir: "/path/to/datasets/ISIC2018"
14 | number_classes: 2
15 | data_loader:
16 | train:
17 | batch_size: 16
18 | shuffle: true
19 | num_workers: 8
20 | pin_memory: true
21 | validation:
22 | batch_size: 16
23 | shuffle: false
24 | num_workers: 8
25 | pin_memory: true
26 | test:
27 | batch_size: 16
28 | shuffle: false
29 | num_workers: 4
30 | pin_memory: false
31 | training:
32 | optimizer:
33 | name: 'Adam'
34 | params:
35 | lr: 0.0001
36 | criterion:
37 | name: "DiceLoss"
38 | params: {}
39 | scheduler:
40 | factor: 0.5
41 | patience: 10
42 | epochs: 100
43 | model:
44 | save_dir: '../../saved_models/isic2018_resunet'
45 | load_weights: false
46 | name: 'ResUnet'
47 | params:
48 | in_ch: 3
49 | out_ch: 2
50 | # preprocess:
--------------------------------------------------------------------------------
/configs/isic/isic2018_transunet.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | mode: 'train'
3 | device: 'gpu'
4 | transforms: none
5 | dataset:
6 | class_name: "ISIC2018Dataset"
7 | input_size: 224
8 | training:
9 | params:
10 | data_dir: "/path/to/datasets/ISIC2018"
11 | validation:
12 | params:
13 | data_dir: "/path/to/datasets/ISIC2018"
14 | number_classes: 2
15 | data_loader:
16 | train:
17 | batch_size: 16
18 | shuffle: true
19 | num_workers: 8
20 | pin_memory: true
21 | validation:
22 | batch_size: 16
23 | shuffle: false
24 | num_workers: 8
25 | pin_memory: true
26 | test:
27 | batch_size: 16
28 | shuffle: false
29 | num_workers: 4
30 | pin_memory: false
31 | training:
32 | optimizer:
33 | name: 'SGD'
34 | params:
35 | lr: 0.0001
36 | momentum: 0.9
37 | weight_decay: 0.0001
38 | criterion:
39 | name: "DiceLoss"
40 | params: {}
41 | scheduler:
42 | factor: 0.5
43 | patience: 10
44 | epochs: 100
45 | model:
46 | save_dir: '../../saved_models/isic2018_transunet'
47 | load_weights: false
48 | name: 'VisionTransformer'
49 | params:
50 | img_size: 224
51 | num_classes: 2
52 | # preprocess:
--------------------------------------------------------------------------------
/configs/isic/isic2018_uctransnet.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | mode: 'train'
3 | device: 'gpu'
4 | transforms: none
5 | dataset:
6 | class_name: "ISIC2018Dataset"
7 | input_size: 224
8 | training:
9 | params:
10 | data_dir: "/path/to/datasets/ISIC2018"
11 | validation:
12 | params:
13 | data_dir: "/path/to/datasets/ISIC2018"
14 | number_classes: 2
15 | data_loader:
16 | train:
17 | batch_size: 16
18 | shuffle: true
19 | num_workers: 8
20 | pin_memory: true
21 | validation:
22 | batch_size: 16
23 | shuffle: false
24 | num_workers: 8
25 | pin_memory: true
26 | test:
27 | batch_size: 16
28 | shuffle: false
29 | num_workers: 4
30 | pin_memory: false
31 | training:
32 | optimizer:
33 | name: 'Adam'
34 | params:
35 | lr: 0.0001
36 | criterion:
37 | name: "DiceLoss"
38 | params: {}
39 | scheduler:
40 | factor: 0.5
41 | patience: 10
42 | epochs: 100
43 | model:
44 | save_dir: '../../saved_models/isic2018_uctransnet'
45 | load_weights: false
46 | name: "UCTransNet"
47 | params:
48 | n_channels: 3
49 | n_classes: 2
50 | # preprocess:
--------------------------------------------------------------------------------
/configs/isic/isic2018_unet.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | mode: 'train'
3 | device: 'gpu'
4 | transforms: none
5 | dataset:
6 | class_name: "ISIC2018Dataset"
7 | input_size: 224
8 | training:
9 | params:
10 | data_dir: "/path/to/datasets/ISIC2018"
11 | validation:
12 | params:
13 | data_dir: "/path/to/datasets/ISIC2018"
14 | number_classes: 2
15 | data_loader:
16 | train:
17 | batch_size: 16
18 | shuffle: true
19 | num_workers: 8
20 | pin_memory: true
21 | validation:
22 | batch_size: 16
23 | shuffle: false
24 | num_workers: 8
25 | pin_memory: true
26 | test:
27 | batch_size: 16
28 | shuffle: false
29 | num_workers: 4
30 | pin_memory: false
31 | training:
32 | optimizer:
33 | name: 'Adam'
34 | params:
35 | lr: 0.0001
36 | criterion:
37 | name: "DiceLoss"
38 | params: {}
39 | scheduler:
40 | factor: 0.5
41 | patience: 10
42 | epochs: 100
43 | model:
44 | save_dir: '../../saved_models/isic2018_unet'
45 | load_weights: false
46 | name: 'UNet'
47 | params:
48 | in_channels: 3
49 | out_channels: 2
50 | with_bn: false
51 | # preprocess:
--------------------------------------------------------------------------------
/configs/isic/isic2018_unetpp.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | mode: 'train'
3 | device: 'gpu'
4 | transforms: none
5 | dataset:
6 | class_name: "ISIC2018Dataset"
7 | input_size: 224
8 | training:
9 | params:
10 | data_dir: "/path/to/datasets/ISIC2018"
11 | validation:
12 | params:
13 | data_dir: "/path/to/datasets/ISIC2018"
14 | number_classes: 2
15 | data_loader:
16 | train:
17 | batch_size: 16
18 | shuffle: true
19 | num_workers: 8
20 | pin_memory: true
21 | validation:
22 | batch_size: 16
23 | shuffle: false
24 | num_workers: 8
25 | pin_memory: true
26 | test:
27 | batch_size: 16
28 | shuffle: false
29 | num_workers: 4
30 | pin_memory: false
31 | training:
32 | optimizer:
33 | name: 'Adam'
34 | params:
35 | lr: 0.0001
36 | criterion:
37 | name: "DiceLoss"
38 | params: {}
39 | scheduler:
40 | factor: 0.5
41 | patience: 10
42 | epochs: 100
43 | model:
44 | save_dir: '../../saved_models/isic2018_unetpp'
45 | load_weights: false
46 | name: 'NestedUNet'
47 | params:
48 | num_classes: 2
49 | input_channels: 3
50 | deep_supervision: false
51 | # preprocess:
--------------------------------------------------------------------------------
/configs/segpc/segpc2021_attunet.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | mode: 'train'
3 | device: 'gpu'
4 | transforms: none
5 | dataset:
6 | class_name: "SegPC2021Dataset"
7 | input_size: 224
8 | scale: 2.5
9 | data_dir: "/path/to/datasets/segpc/np"
10 | dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset"
11 | number_classes: 2
12 | data_loader:
13 | train:
14 | batch_size: 16
15 | shuffle: true
16 | num_workers: 4
17 | pin_memory: true
18 | validation:
19 | batch_size: 16
20 | shuffle: false
21 | num_workers: 4
22 | pin_memory: true
23 | test:
24 | batch_size: 16
25 | shuffle: false
26 | num_workers: 4
27 | pin_memory: false
28 | training:
29 | optimizer:
30 | name: 'Adam'
31 | params:
32 | lr: 0.0001
33 | criterion:
34 | name: "DiceLoss"
35 | params: {}
36 | scheduler:
37 | factor: 0.5
38 | patience: 10
39 | epochs: 100
40 | model:
41 | save_dir: '../../saved_models/segpc2021_attunet'
42 | load_weights: false
43 | name: 'AttU_Net'
44 | params:
45 | img_ch: 4
46 | output_ch: 2
47 | # preprocess:
--------------------------------------------------------------------------------
/configs/segpc/segpc2021_missformer.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | mode: 'train'
3 | device: 'gpu'
4 | transforms: none
5 | dataset:
6 | class_name: "SegPC2021Dataset"
7 | input_size: 224
8 | scale: 2.5
9 | data_dir: "/path/to/datasets/segpc/np"
10 | dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset"
11 | number_classes: 2
12 | data_loader:
13 | train:
14 | batch_size: 16
15 | shuffle: true
16 | num_workers: 4
17 | pin_memory: true
18 | validation:
19 | batch_size: 16
20 | shuffle: false
21 | num_workers: 4
22 | pin_memory: true
23 | test:
24 | batch_size: 16
25 | shuffle: false
26 | num_workers: 4
27 | pin_memory: false
28 | training:
29 | optimizer:
30 | name: 'SGD'
31 | params:
32 | lr: 0.0001
33 | momentum: 0.9
34 | weight_decay: 0.0001
35 | criterion:
36 | name: "DiceLoss"
37 | params: {}
38 | scheduler:
39 | factor: 0.5
40 | patience: 10
41 | epochs: 500
42 | model:
43 | save_dir: '../../saved_models/segpc2021_missformer'
44 | load_weights: false
45 | name: 'MISSFormer'
46 | params:
47 | in_ch: 4
48 | num_classes: 2
49 | # preprocess:
--------------------------------------------------------------------------------
/configs/segpc/segpc2021_multiresunet.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | mode: 'train'
3 | device: 'gpu'
4 | transforms: none
5 | dataset:
6 | class_name: "SegPC2021Dataset"
7 | input_size: 224
8 | scale: 2.5
9 | data_dir: "/path/to/datasets/segpc/np"
10 | dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset"
11 | number_classes: 2
12 | data_loader:
13 | train:
14 | batch_size: 16
15 | shuffle: true
16 | num_workers: 4
17 | pin_memory: true
18 | validation:
19 | batch_size: 16
20 | shuffle: false
21 | num_workers: 4
22 | pin_memory: true
23 | test:
24 | batch_size: 16
25 | shuffle: false
26 | num_workers: 4
27 | pin_memory: false
28 | training:
29 | optimizer:
30 | name: 'Adam'
31 | params:
32 | lr: 0.0001
33 | # name: "SGD"
34 | # params:
35 | # lr: 0.0001
36 | # momentum: 0.9
37 | # weight_decay: 0.0001
38 | criterion:
39 | name: "DiceLoss"
40 | params: {}
41 | scheduler:
42 | factor: 0.5
43 | patience: 10
44 | epochs: 100
45 | model:
46 | save_dir: '../../saved_models/segpc2021_multiresunet'
47 | load_weights: false
48 | name: 'MultiResUnet'
49 | params:
50 | channels: 4
51 | filters: 32
52 | nclasses: 2
53 | # preprocess:
--------------------------------------------------------------------------------
/configs/segpc/segpc2021_resunet.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | mode: 'train'
3 | device: 'gpu'
4 | transforms: none
5 | dataset:
6 | class_name: "SegPC2021Dataset"
7 | input_size: 224
8 | scale: 2.5
9 | data_dir: "/path/to/datasets/segpc/np"
10 | dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset"
11 | number_classes: 2
12 | data_loader:
13 | train:
14 | batch_size: 16
15 | shuffle: true
16 | num_workers: 4
17 | pin_memory: true
18 | validation:
19 | batch_size: 16
20 | shuffle: false
21 | num_workers: 4
22 | pin_memory: true
23 | test:
24 | batch_size: 16
25 | shuffle: false
26 | num_workers: 4
27 | pin_memory: false
28 | training:
29 | optimizer:
30 | name: 'Adam'
31 | params:
32 | lr: 0.0001
33 | criterion:
34 | name: "DiceLoss"
35 | params: {}
36 | scheduler:
37 | factor: 0.5
38 | patience: 10
39 | epochs: 100
40 | model:
41 | save_dir: '../../saved_models/segpc2021_resunet'
42 | load_weights: false
43 | name: 'ResUnet'
44 | params:
45 | in_ch: 4
46 | out_ch: 2
47 | # preprocess:
--------------------------------------------------------------------------------
/configs/segpc/segpc2021_transunet.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | mode: 'train'
3 | device: 'gpu'
4 | transforms: none
5 | dataset:
6 | class_name: "SegPC2021Dataset"
7 | input_size: 224
8 | scale: 2.5
9 | data_dir: "/path/to/datasets/segpc/np"
10 | dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset"
11 | number_classes: 2
12 | data_loader:
13 | train:
14 | batch_size: 16
15 | shuffle: true
16 | num_workers: 4
17 | pin_memory: true
18 | validation:
19 | batch_size: 16
20 | shuffle: false
21 | num_workers: 4
22 | pin_memory: true
23 | test:
24 | batch_size: 16
25 | shuffle: false
26 | num_workers: 4
27 | pin_memory: false
28 | training:
29 | optimizer:
30 | # name: 'Adam'
31 | # params:
32 | # lr: 0.0001
33 | name: "SGD"
34 | params:
35 | lr: 0.0001
36 | momentum: 0.9
37 | weight_decay: 0.0001
38 | criterion:
39 | name: "DiceLoss"
40 | params: {}
41 | scheduler:
42 | factor: 0.5
43 | patience: 10
44 | epochs: 100
45 | model:
46 | save_dir: '../../saved_models/segpc2021_transunet'
47 | load_weights: false
48 | name: 'VisionTransformer'
49 | params:
50 | img_size: 224
51 | num_classes: 2
52 | # preprocess:
--------------------------------------------------------------------------------
/configs/segpc/segpc2021_uctransnet.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | mode: 'train'
3 | device: 'gpu'
4 | transforms: none
5 | dataset:
6 | class_name: "SegPC2021Dataset"
7 | input_size: 224
8 | scale: 2.5
9 | data_dir: "/path/to/datasets/segpc/np"
10 | dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset"
11 | number_classes: 2
12 | data_loader:
13 | train:
14 | batch_size: 16
15 | shuffle: true
16 | num_workers: 4
17 | pin_memory: true
18 | validation:
19 | batch_size: 16
20 | shuffle: false
21 | num_workers: 4
22 | pin_memory: true
23 | test:
24 | batch_size: 16
25 | shuffle: false
26 | num_workers: 4
27 | pin_memory: false
28 | training:
29 | optimizer:
30 | name: 'Adam'
31 | params:
32 | lr: 0.0001
33 | criterion:
34 | name: "DiceLoss"
35 | params: {}
36 | scheduler:
37 | factor: 0.5
38 | patience: 10
39 | epochs: 100
40 | model:
41 | save_dir: '../../saved_models/segpc2021_uctransnet'
42 | load_weights: false
43 | name: 'UCTransNet'
44 | params:
45 | n_channels: 4
46 | n_classes: 2
47 | # preprocess:
--------------------------------------------------------------------------------
/configs/segpc/segpc2021_unet.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | mode: 'train'
3 | device: 'gpu'
4 | transforms: none
5 | dataset:
6 | class_name: "SegPC2021Dataset"
7 | input_size: 224
8 | scale: 2.5
9 | data_dir: "/path/to/datasets/segpc/np"
10 | dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset"
11 | number_classes: 2
12 | data_loader:
13 | train:
14 | batch_size: 16
15 | shuffle: true
16 | num_workers: 4
17 | pin_memory: true
18 | validation:
19 | batch_size: 16
20 | shuffle: false
21 | num_workers: 4
22 | pin_memory: true
23 | test:
24 | batch_size: 16
25 | shuffle: false
26 | num_workers: 4
27 | pin_memory: false
28 | training:
29 | optimizer:
30 | name: 'Adam'
31 | params:
32 | lr: 0.0001
33 | criterion:
34 | name: "DiceLoss"
35 | params: {}
36 | scheduler:
37 | factor: 0.5
38 | patience: 10
39 | epochs: 100
40 | model:
41 | save_dir: '../../saved_models/segpc2021_unet'
42 | load_weights: false
43 | name: 'UNet'
44 | params:
45 | in_channels: 4
46 | out_channels: 2
47 | with_bn: false
48 | # preprocess:
--------------------------------------------------------------------------------
/configs/segpc/segpc2021_unetpp.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | mode: 'train'
3 | device: 'gpu'
4 | transforms: none
5 | dataset:
6 | class_name: "SegPC2021Dataset"
7 | input_size: 224
8 | scale: 2.5
9 | data_dir: "/path/to/datasets/segpc/np"
10 | dataset_dir: "/path/to/datasets/segpc/TCIA_SegPC_dataset"
11 | number_classes: 2
12 | data_loader:
13 | train:
14 | batch_size: 16
15 | shuffle: true
16 | num_workers: 4
17 | pin_memory: true
18 | validation:
19 | batch_size: 16
20 | shuffle: false
21 | num_workers: 4
22 | pin_memory: true
23 | test:
24 | batch_size: 16
25 | shuffle: false
26 | num_workers: 4
27 | pin_memory: false
28 | training:
29 | optimizer:
30 | name: 'Adam'
31 | params:
32 | lr: 0.0001
33 | criterion:
34 | name: "DiceLoss"
35 | params: {}
36 | scheduler:
37 | factor: 0.5
38 | patience: 10
39 | epochs: 100
40 | model:
41 | save_dir: '../../saved_models/segpc2021_unetpp'
42 | load_weights: false
43 | name: 'NestedUNet'
44 | params:
45 | num_classes: 2
46 | input_channels: 4
47 | deep_supervision: false
48 | # preprocess:
--------------------------------------------------------------------------------
/datasets/README.md:
--------------------------------------------------------------------------------
1 | # Data Preparing
2 |
3 | The structure of the `datasets` and `configs` folders is shown here.
4 |
5 | ```bash
6 | .
7 | ├── configs
8 | │   ├── isic
9 | │   │   ├── isic2018_*.yaml
10 | │   │   └── ...
11 | │   └── segpc
12 | │       ├── segpc2021_*.yaml
13 | │       └── ...
14 | └── datasets
15 |     ├── isic.ipynb
16 |     ├── isic.py
17 |     ├── segpc.ipynb
18 |     ├── segpc.py
19 |     ├── prepare_isic.ipynb
20 |     └── prepare_segpc.ipynb
21 | ```
22 |
23 | In this work, we used three datasets (ISIC, SegPC, and Synapse). In the `datasets` folder, you will find files in three different formats:
24 |
25 | - `.py`: These files are used by the main code for the training and testing procedures.
26 |
27 | - `.ipynb`: These files are provided for a better understanding of the datasets and how we use them. Each one also shows at least one sample image from the related dataset.
28 |
29 | - `prepare_*.ipynb`: Use these notebooks to prepare the files required by the corresponding `Dataset` classes.
30 |
31 | When you use the code to prepare or load datasets, **do not forget to set the correct data path** in the related dataset files and the corresponding config files.
32 |
33 | In the following, you can find more information regarding accessing, preparing, and using datasets.
34 |
35 |
36 |
37 | ## ISIC2018 [(ISIC Challenge 2016-2020)](https://challenge.isic-archive.com/)
38 |
39 | The lesion images were acquired with a variety of dermatoscopy types, from all anatomic sites (excluding mucosa and nails), from a historical sample of patients presented for skin cancer screening, from several different institutions. Every lesion image contains exactly one primary lesion; other fiducial markers, smaller secondary lesions, or other pigmented regions may be neglected.
40 |
41 | #### [Data 2018](https://challenge.isic-archive.com/data/#2018)
42 |
43 | You can download the ISIC2018 dataset using the following links:
44 |
45 | - [Download training input images](https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task1-2_Training_Input.zip)
46 |
47 | - [Download training ground truth masks](https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task1_Training_GroundTruth.zip)
48 |
49 | After downloading, extract them into a folder, and ***do not forget to put the path*** in the relevant files.
50 |
51 | The input data are dermoscopic lesion images in JPEG format.
52 |
53 | All lesion images are named using the scheme `ISIC_<image_id>.jpg`, where `<image_id>` is a 7-digit unique identifier. EXIF tags in the images have been removed; any remaining EXIF tags should not be relied upon to provide accurate metadata.
54 |
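As a point of reference, `datasets/isic.py` builds its image and mask paths from two sub-folders inside `data_dir` (the names of the extracted archives). A hedged sketch of the expected layout, using a placeholder root path, is:

```python
import os

# Placeholder root path -- replace it with your own and mirror it in the config files.
data_dir = "/path/to/datasets/ISIC2018"

# Sub-folders expected by datasets/isic.py:
imgs_dir = os.path.join(data_dir, "ISIC2018_Task1-2_Training_Input")      # contains ISIC_<image_id>.jpg
msks_dir = os.path.join(data_dir, "ISIC2018_Task1_Training_GroundTruth")  # contains ISIC_<image_id>_segmentation.png
```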
55 | Below, you can see a sample of this dataset.
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 | ## SegPC2021 [(SegPC Challenge 2021)](https://ieee-dataport.org/open-access/segpc-2021-segmentation-multiple-myeloma-plasma-cells-microscopic-images)
65 |
66 | This [challenge](https://segpc-2021.grand-challenge.org/) targeted robust segmentation of cells, which is the first stage in building such a tool for plasma cell cancer, namely Multiple Myeloma (MM), a type of blood cancer. The organizers provided images after stain color normalization.
67 |
68 | In this challenge, the target was instance segmentation of cells of interest. For each input image, both nuclei and cytoplasm were required to be segmented separately for each cell of interest.
69 |
70 | ### Our purpose
71 |
72 | In this work, we used this dataset for a slightly different purpose. We use the SegPC2021 database to create an independent image centered on the nucleus of each cell of interest, with the ultimate goal of segmenting the cytoplasm of that cell. The nucleus mask is also used as an auxiliary channel of the input images. You can refer to [this article](https://arxiv.org/abs/2105.06238) for more information. You can download the main dataset using the following links:
73 |
74 | - [Official page](https://ieee-dataport.org/open-access/segpc-2021-segmentation-multiple-myeloma-plasma-cells-microscopic-images)
75 | - [Alternative link](https://www.kaggle.com/datasets/sbilab/segpc2021dataset)
76 |
77 | After downloading the dataset, you can use the `prepare_segpc.ipynb` notebook to produce the required data files (a rough sketch of the per-cell input construction follows). ***Do not forget to put the right path in the relevant files***.
78 |
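Because the nucleus mask is stacked onto the RGB crop as an auxiliary channel, the SegPC configs set the number of input channels to 4. The snippet below is a rough, hedged illustration of that per-cell input construction, not the exact code in `prepare_segpc.ipynb`; the function name and the bounding-box convention are hypothetical.

```python
import numpy as np

def make_cell_input(image, nucleus_mask, bbox, scale=2.5):
    """Crop a region centered on one nucleus and stack its mask as a 4th channel.

    image: HxWx3 RGB image, nucleus_mask: HxW binary mask of one nucleus,
    bbox: (y0, y1, x0, x1) bounding box of that nucleus (hypothetical convention).
    """
    y0, y1, x0, x1 = bbox
    cy, cx = (y0 + y1) // 2, (x0 + x1) // 2
    half_h = int((y1 - y0) * scale / 2)
    half_w = int((x1 - x0) * scale / 2)
    ys = slice(max(cy - half_h, 0), min(cy + half_h, image.shape[0]))
    xs = slice(max(cx - half_w, 0), min(cx + half_w, image.shape[1]))
    crop = image[ys, xs]                                  # RGB crop around the nucleus
    aux = nucleus_mask[ys, xs, None].astype(crop.dtype)   # auxiliary nucleus channel
    return np.concatenate([crop, aux], axis=-1)           # HxWx4 input (see img_ch/in_channels: 4 in the configs)
```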
79 | Below, you can see a sample of the prepared dataset with `scale=2.5`.
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 | ## Synapse
88 |
89 | 1. Access the Synapse multi-organ dataset:
90 |
91 | 1. Sign up on the [official Synapse website](https://www.synapse.org/#!Synapse:syn3193805/wiki/) and download the dataset. Convert the scans to NumPy format, clip the image intensities to [-125, 275], normalize each 3D image to [0, 1], and extract 2D slices from the 3D volumes for the training cases, while keeping the 3D volumes in h5 format for the testing cases (a minimal sketch follows the directory listing below).
92 |
93 | 2. The directory structure of the whole project is as follows:
94 |
95 | ```bash
96 | .
97 | ├── model
98 | │   └── vit_checkpoint
99 | │       └── imagenet21k
100 | │           ├── R50+ViT-B_16.npz
101 | │           └── *.npz
102 | └── data
103 |     └── Synapse
104 |         ├── test_vol_h5
105 |         │   ├── case0001.npy.h5
106 |         │   └── *.npy.h5
107 |         └── train_npz
108 |             ├── case0005_slice000.npz
109 |             └── *.npz
110 | ```
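The following is a minimal sketch of that preprocessing step, assuming each CT volume and its label map are already loaded as NumPy arrays; the function and file names are illustrative, not the exact scripts used for the paper.

```python
import os
import h5py
import numpy as np

def preprocess_case(image_3d, label_3d, case_id, is_test, out_dir="data/Synapse"):
    """Clip to [-125, 275], normalize to [0, 1], then save 2D slices (train) or the 3D volume (test)."""
    image_3d = np.clip(image_3d, -125, 275)
    image_3d = (image_3d - image_3d.min()) / (image_3d.max() - image_3d.min())

    if is_test:
        # keep the whole 3D volume in h5 format for the testing cases
        os.makedirs(f"{out_dir}/test_vol_h5", exist_ok=True)
        with h5py.File(f"{out_dir}/test_vol_h5/{case_id}.npy.h5", "w") as f:
            f.create_dataset("image", data=image_3d)
            f.create_dataset("label", data=label_3d)
    else:
        # extract 2D slices from the 3D volume for the training cases
        os.makedirs(f"{out_dir}/train_npz", exist_ok=True)
        for i in range(image_3d.shape[0]):
            np.savez(f"{out_dir}/train_npz/{case_id}_slice{i:03d}.npz",
                     image=image_3d[i], label=label_3d[i])
```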
111 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NITR098/Awesome-U-Net/8785625b1113979303e8f7c3f88cc97a1f68324c/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/isic.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # ## [ISIC Challenge (2016-2020)](https://challenge.isic-archive.com/)
5 | # ---
6 | #
7 | # ### [Data 2018](https://challenge.isic-archive.com/data/)
8 | #
9 | # The input data are dermoscopic lesion images in JPEG format.
10 | #
11 | # All lesion images are named using the scheme `ISIC_<image_id>.jpg`, where `<image_id>` is a 7-digit unique identifier. EXIF tags in the images have been removed; any remaining EXIF tags should not be relied upon to provide accurate metadata.
12 | #
13 | # The lesion images were acquired with a variety of dermatoscope types, from all anatomic sites (excluding mucosa and nails), from a historical sample of patients presented for skin cancer screening, from several different institutions. Every lesion image contains exactly one primary lesion; other fiducial markers, smaller secondary lesions, or other pigmented regions may be neglected.
14 | #
15 | # The distribution of disease states represent a modified "real world" setting whereby there are more benign lesions than malignant lesions, but an over-representation of malignancies.
16 |
17 | # In[2]:
18 |
19 |
20 | import os
21 | import glob
22 | import numpy as np
23 | import torch
24 | from torch.utils.data import Dataset
25 | from torchvision import transforms, utils
26 | from torchvision.io import read_image
27 | from torchvision.io.image import ImageReadMode
28 | import torch.nn.functional as F
29 |
30 |
31 |
32 | # In[3]:
33 |
34 | class ISIC2018Dataset(Dataset):
35 | def __init__(self, data_dir=None, one_hot=True, img_transform=None, msk_transform=None):
36 | # pre-set variables
37 | self.data_prefix = "ISIC_"
38 | self.target_postfix = "_segmentation"
39 | self.target_fex = "png"
40 | self.input_fex = "jpg"
41 | self.data_dir = data_dir if data_dir else "/path/to/datasets/ISIC2018"
42 | self.imgs_dir = os.path.join(self.data_dir, "ISIC2018_Task1-2_Training_Input")
43 | self.msks_dir = os.path.join(self.data_dir, "ISIC2018_Task1_Training_GroundTruth")
44 |
45 | # input parameters
46 | self.img_dirs = glob.glob(f"{self.imgs_dir}/*.{self.input_fex}")
47 | self.data_ids = [d.split(self.data_prefix)[1].split(f".{self.input_fex}")[0] for d in self.img_dirs]
48 | self.one_hot = one_hot
49 | self.img_transform = img_transform
50 | self.msk_transform = msk_transform
51 |
52 | def get_img_by_id(self, id):
53 | img_dir = os.path.join(self.imgs_dir, f"{self.data_prefix}{id}.{self.input_fex}")
54 | img = read_image(img_dir, ImageReadMode.RGB)
55 | return img
56 |
57 | def get_msk_by_id(self, id):
58 | msk_dir = os.path.join(self.msks_dir, f"{self.data_prefix}{id}{self.target_postfix}.{self.target_fex}")
59 | msk = read_image(msk_dir, ImageReadMode.GRAY)
60 | return msk
61 |
62 | def __len__(self):
63 | return len(self.data_ids)
64 |
65 | def __getitem__(self, idx):
66 | data_id = self.data_ids[idx]
67 | img = self.get_img_by_id(data_id)
68 | msk = self.get_msk_by_id(data_id)
69 |
70 | if self.img_transform:
71 | img = self.img_transform(img)
72 | img = (img - img.min())/(img.max() - img.min())
73 | if self.msk_transform:
74 | msk = self.msk_transform(msk)
75 | msk = (msk - msk.min())/(msk.max() - msk.min())
76 |
77 | if self.one_hot:
78 | msk = F.one_hot(torch.squeeze(msk).to(torch.int64))
79 | msk = torch.moveaxis(msk, -1, 0).to(torch.float)
80 |
81 | sample = {'image': img, 'mask': msk, 'id': data_id}
82 | return sample
83 |
84 |
85 | class ISIC2018DatasetFast(Dataset):
86 | def __init__(self, mode, data_dir=None, one_hot=True, img_transform=None, msk_transform=None):
87 | # pre-set variables
88 | self.data_dir = data_dir if data_dir else "/path/to/datasets/ISIC2018/np"
89 |
90 | # input parameters
91 | self.one_hot = one_hot
92 |
93 | X = np.load(f"{self.data_dir}/X_tr_224x224.npy")
94 | Y = np.load(f"{self.data_dir}/Y_tr_224x224.npy")
95 |
96 | X = torch.tensor(X)
97 | Y = torch.tensor(Y)
98 |
99 | if mode == "tr":
100 | self.imgs = X[0:1815]
101 | self.msks = Y[0:1815]
102 | elif mode == "vl":
103 | self.imgs = X[1815:1815+259]
104 | self.msks = Y[1815:1815+259]
105 | elif mode == "te":
106 | self.imgs = X[1815+259:2594]
107 | self.msks = Y[1815+259:2594]
108 | else:
109 | raise ValueError()
110 |
111 | def __len__(self):
112 | return len(self.imgs)
113 |
114 | def __getitem__(self, idx):
115 | data_id = idx
116 | img = self.imgs[idx]
117 | msk = self.msks[idx]
118 |
119 | if self.one_hot:
120 | msk = F.one_hot(torch.squeeze(msk).to(torch.int64))
121 | msk = torch.moveaxis(msk, -1, 0).to(torch.float)
122 |
123 | sample = {'image': img, 'mask': msk, 'id': data_id}
124 | return sample
125 |
126 |
127 | class ISIC2018TrainingDataset(Dataset):
128 | def __init__(self, data_dir=None, img_transform=None, msk_transform=None):
129 | # pre-set variables
130 | self.data_prefix = "ISIC_"
131 | self.target_postfix = "_segmentation"
132 | self.target_fex = "png"
133 | self.input_fex = "jpg"
134 | self.data_dir = data_dir if data_dir else "/path/to/datasets/ISIC2018"
135 | self.imgs_dir = os.path.join(self.data_dir, "ISIC2018_Task1-2_Training_Input")
136 | self.msks_dir = os.path.join(self.data_dir, "ISIC2018_Task1_Training_GroundTruth")
137 |
138 | # input parameters
139 | self.img_dirs = glob.glob(f"{self.imgs_dir}/*.{self.input_fex}")
140 | self.data_ids = [d.split(self.data_prefix)[1].split(f".{self.input_fex}")[0] for d in self.img_dirs]
141 | self.img_transform = img_transform
142 | self.msk_transform = msk_transform
143 |
144 | def get_img_by_id(self, id):
145 | img_dir = os.path.join(self.imgs_dir, f"{self.data_prefix}{id}.{self.input_fex}")
146 | img = read_image(img_dir, ImageReadMode.RGB)
147 | return img
148 |
149 | def get_msk_by_id(self, id):
150 | msk_dir = os.path.join(self.msks_dir, f"{self.data_prefix}{id}{self.target_postfix}.{self.target_fex}")
151 | msk = read_image(msk_dir, ImageReadMode.GRAY)
152 | return msk
153 |
154 | def __len__(self):
155 | return len(self.data_ids)
156 |
157 | def __getitem__(self, idx):
158 | data_id = self.data_ids[idx]
159 | img = self.get_img_by_id(data_id)
160 | msk = self.get_msk_by_id(data_id)
161 |
162 | if self.img_transform:
163 | img = self.img_transform(img)
164 | img = (img - img.min())/(img.max() - img.min())
165 | if self.msk_transform:
166 | msk = self.msk_transform(msk)
167 | msk = (msk - msk.min())/(msk.max() - msk.min())
168 | sample = {'image': img, 'mask': msk, 'id': data_id}
169 | return sample
170 |
171 |
172 | # In[10]:
173 |
174 |
175 | class ISIC2018ValidationDataset(Dataset):
176 | def __init__(self, data_dir=None, img_transform=None, msk_transform=None):
177 | # pre-set variables
178 | self.data_prefix = "ISIC_"
179 | self.target_postfix = "_segmentation"
180 | self.target_fex = "png"
181 | self.input_fex = "jpg"
182 | self.data_dir = data_dir if data_dir else "/path/to/datasets/ISIC2018"
183 | self.imgs_dir = os.path.join(self.data_dir, "ISIC2018_Task1-2_Validation_Input")
184 | self.msks_dir = os.path.join(self.data_dir, "ISIC2018_Task1_Validation_GroundTruth")
185 |
186 | # input parameters
187 | self.img_dirs = glob.glob(f"{self.imgs_dir}/*.{self.input_fex}")
188 | self.data_ids = [d.split(self.data_prefix)[1].split(f".{self.input_fex}")[0] for d in self.img_dirs]
189 | self.img_transform = img_transform
190 | self.msk_transform = msk_transform
191 |
192 | def get_img_by_id(self, id):
193 | img_dir = os.path.join(self.imgs_dir, f"{self.data_prefix}{id}.{self.input_fex}")
194 | img = read_image(img_dir, ImageReadMode.RGB)
195 | return img
196 |
197 | def get_msk_by_id(self, id):
198 | msk_dir = os.path.join(self.msks_dir, f"{self.data_prefix}{id}{self.target_postfix}.{self.target_fex}")
199 | msk = read_image(msk_dir, ImageReadMode.GRAY)
200 | return msk
201 |
202 | def __len__(self):
203 | return len(self.data_ids)
204 |
205 | def __getitem__(self, idx):
206 | data_id = self.data_ids[idx]
207 | img = self.get_img_by_id(data_id)
208 | msk = self.get_msk_by_id(data_id)
209 |
210 | print(f"msk shape: {msk.shape}")
211 | print(f"img shape: {img.shape}")
212 |
213 |
214 | if self.img_transform:
215 | img = self.img_transform(img)
216 | img = (img - img.min())/(img.max() - img.min())
217 | if self.msk_transform:
218 | msk = self.msk_transform(msk)
219 | msk = (msk - msk.min())/(msk.max() - msk.min())
220 |
221 | sample = {'image': img, 'mask': msk, 'id': data_id}
222 | return sample
223 |
224 |
225 | # ## Test dataset and dataloader
226 | # ---
227 |
228 | # # In[13]:
229 |
230 |
231 | # import sys
232 | # sys.path.append('..')
233 | # from utils import show_sbs
234 | # from torch.utils.data import DataLoader, Subset
235 | # from torchvision import transforms
236 |
237 |
238 |
239 | # # ------------------- params --------------------
240 | # INPUT_SIZE = 224
241 |
242 | # TR_BATCH_SIZE = 8
243 | # TR_DL_SHUFFLE = True
244 | # TR_DL_WORKER = 1
245 |
246 | # VL_BATCH_SIZE = 12
247 | # VL_DL_SHUFFLE = False
248 | # VL_DL_WORKER = 1
249 |
250 | # TE_BATCH_SIZE = 12
251 | # TE_DL_SHUFFLE = False
252 | # TE_DL_WORKER = 1
253 | # # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
254 |
255 |
256 | # # ----------------- transform ------------------
257 | # # transform for image
258 | # img_transform = transforms.Compose([
259 | # transforms.Resize(
260 | # size=[INPUT_SIZE, INPUT_SIZE],
261 | # interpolation=transforms.functional.InterpolationMode.BILINEAR
262 | # ),
263 | # ])
264 | # # transform for mask
265 | # msk_transform = transforms.Compose([
266 | # transforms.Resize(
267 | # size=[INPUT_SIZE, INPUT_SIZE],
268 | # interpolation=transforms.functional.InterpolationMode.NEAREST
269 | # ),
270 | # ])
271 | # # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
272 |
273 |
274 | # # ----------------- dataset --------------------
275 | # # preparing training dataset
276 | # train_dataset = ISIC2018TrainingDataset(
277 | # img_transform=img_transform,
278 | # msk_transform=msk_transform
279 | # )
280 |
281 | # # We consider 1815 samples for training, 259 samples for validation and 520 samples for testing
282 | # # !cat ~/deeplearning/skin/Prepare_ISIC2018.py
283 |
284 | # indices = list(range(len(train_dataset)))
285 |
286 | # # split indices to: -> train, validation, and test
287 | # tr_indices = indices[0:1815]
288 | # vl_indices = indices[1815:1815+259]
289 | # te_indices = indices[1815+259:2594]
290 |
291 | # # create new datasets from train dataset as training, validation, and test
292 | # tr_dataset = Subset(train_dataset, tr_indices)
293 | # vl_dataset = Subset(train_dataset, vl_indices)
294 | # te_dataset = Subset(train_dataset, te_indices)
295 |
296 | # # prepare train dataloader
297 | # tr_loader = DataLoader(
298 | # tr_dataset,
299 | # batch_size=TR_BATCH_SIZE,
300 | # shuffle=TR_DL_SHUFFLE,
301 | # num_workers=TR_DL_WORKER,
302 | # pin_memory=True
303 | # )
304 |
305 | # # prepare validation dataloader
306 | # vl_loader = DataLoader(
307 | # vl_dataset,
308 | # batch_size=VL_BATCH_SIZE,
309 | # shuffle=VL_DL_SHUFFLE,
310 | # num_workers=VL_DL_WORKER,
311 | # pin_memory=True
312 | # )
313 |
314 | # # prepare test dataloader
315 | # te_loader = DataLoader(
316 | # te_dataset,
317 | # batch_size=TE_BATCH_SIZE,
318 | # shuffle=TE_DL_SHUFFLE,
319 | # num_workers=TE_DL_WORKER,
320 | # pin_memory=True
321 | # )
322 |
323 | # # -------------- test -----------------
324 | # # test and visualize the input data
325 | # for sample in tr_loader:
326 | # img = sample['image']
327 | # msk = sample['mask']
328 | # print("Training")
329 | # show_sbs(img[0], msk[0])
330 | # break
331 |
332 | # for sample in vl_loader:
333 | # img = sample['image']
334 | # msk = sample['mask']
335 | # print("Validation")
336 | # show_sbs(img[0], msk[0])
337 | # break
338 |
339 | # for sample in te_loader:
340 | # img = sample['image']
341 | # msk = sample['mask']
342 | # print("Test")
343 | # show_sbs(img[0], msk[0])
344 | # break
345 |
346 |
--------------------------------------------------------------------------------
/datasets/segpc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import numpy as np
4 | import torch
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms, utils
7 | from torchvision.io import read_image
8 | from torchvision.io.image import ImageReadMode
9 | import torch.nn.functional as F
10 |
11 |
12 |
13 | class SegPC2021Dataset(Dataset):
14 | def __init__(self,
15 | mode, # 'tr'-> train, 'vl' -> validation, 'te' -> test
16 | input_size=224,
17 | scale=2.5,
18 | data_dir=None,
19 | dataset_dir=None,
20 | one_hot=True,
21 | force_rebuild=False,
22 | img_transform=None,
23 | msk_transform=None):
24 | # pre-set variables
25 | self.data_dir = data_dir if data_dir else "/path/to/datasets/segpc/np"
26 | self.dataset_dir = dataset_dir if dataset_dir else "/path/to/datasets/segpc/TCIA_SegPC_dataset/"
27 | self.mode = mode
28 | # input parameters
29 | self.img_transform = img_transform
30 | self.msk_transform = msk_transform
31 | self.input_size = input_size
32 | self.scale = scale
33 | self.one_hot = one_hot
34 |
35 | # loading data
36 | self.load_dataset(force_rebuild=force_rebuild)
37 |
38 |
39 | def load_dataset(self, force_rebuild):
40 | INPUT_SIZE = self.input_size
41 | ADD = self.data_dir
42 |
43 | # build_segpc_dataset(
44 | # input_size = self.input_size,
45 | # scale = self.scale,
46 | # data_dir = self.data_dir,
47 | # dataset_dir = self.dataset_dir,
48 | # mode = self.mode,
49 | # force_rebuild = force_rebuild,
50 | # )
51 |
52 | print(f'loading X_{self.mode}...')
53 | self.X = np.load(f'{ADD}/cyts_{self.mode}_{self.input_size}x{self.input_size}_s{self.scale}_X.npy')
54 | print(f'loading Y_{self.mode}...')
55 | self.Y = np.load(f'{ADD}/cyts_{self.mode}_{self.input_size}x{self.input_size}_s{self.scale}_Y.npy')
56 | print('finished.')
57 |
58 |
59 | def __len__(self):
60 | return len(self.X)
61 |
62 |
63 | def __getitem__(self, idx):
64 | img = self.X[idx]
65 | msk = self.Y[idx]
66 | msk = np.where(msk<0.5, 0, 1)
67 |
68 | if self.img_transform:
69 | img = self.img_transform(img)
70 | img = (img - img.min())/(img.max() - img.min())
71 | if self.msk_transform:
72 | msk = self.msk_transform(msk)
73 | msk = (msk - msk.min())/(msk.max() - msk.min())
74 |
75 | if self.one_hot:
76 | msk = F.one_hot(torch.squeeze(msk).to(torch.int64))
77 | msk = torch.moveaxis(msk, -1, 0).to(torch.float)
78 |
79 | sample = {'image': img, 'mask': msk, 'id': idx}
80 | return sample
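81 | 
82 | 
83 | # Example usage (illustrative sketch only; the transforms and the data_dir
84 | # below are assumptions -- data_dir must point at the prepared .npy files):
85 | #
86 | # tr_dataset = SegPC2021Dataset(
87 | #     mode='tr',
88 | #     input_size=224,
89 | #     scale=2.5,
90 | #     data_dir='/path/to/datasets/segpc/np',
91 | #     img_transform=transforms.ToTensor(),
92 | #     msk_transform=transforms.ToTensor(),
93 | # )
94 | # sample = tr_dataset[0]
95 | # print(sample['image'].shape, sample['mask'].shape)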
--------------------------------------------------------------------------------
/images/ComparisonOfModels.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NITR098/Awesome-U-Net/8785625b1113979303e8f7c3f88cc97a1f68324c/images/ComparisonOfModels.png
--------------------------------------------------------------------------------
/images/U-Net_Taxonomy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NITR098/Awesome-U-Net/8785625b1113979303e8f7c3f88cc97a1f68324c/images/U-Net_Taxonomy.png
--------------------------------------------------------------------------------
/images/isic2018.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NITR098/Awesome-U-Net/8785625b1113979303e8f7c3f88cc97a1f68324c/images/isic2018.png
--------------------------------------------------------------------------------
/images/isic2018_sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NITR098/Awesome-U-Net/8785625b1113979303e8f7c3f88cc97a1f68324c/images/isic2018_sample.png
--------------------------------------------------------------------------------
/images/segpc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NITR098/Awesome-U-Net/8785625b1113979303e8f7c3f88cc97a1f68324c/images/segpc.png
--------------------------------------------------------------------------------
/images/segpc2021_sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NITR098/Awesome-U-Net/8785625b1113979303e8f7c3f88cc97a1f68324c/images/segpc2021_sample.png
--------------------------------------------------------------------------------
/images/synapse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NITR098/Awesome-U-Net/8785625b1113979303e8f7c3f88cc97a1f68324c/images/synapse.png
--------------------------------------------------------------------------------
/images/unet-pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NITR098/Awesome-U-Net/8785625b1113979303e8f7c3f88cc97a1f68324c/images/unet-pipeline.png
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 |
5 |
6 | EPSILON = 1e-6
7 |
8 | class DiceLoss(torch.nn.Module):
9 | def __init__(self,):
10 | super().__init__()
11 |
12 | def forward(self, pred, mask):
13 | pred = pred.flatten()
14 | mask = mask.flatten()
15 |
16 | intersect = (mask * pred).sum()
17 | dice_score = 2*intersect / (pred.sum() + mask.sum() + EPSILON)
18 | dice_loss = 1 - dice_score
19 | return dice_loss
20 |
21 |
22 | class DiceLossWithLogtis(torch.nn.Module):  # Dice loss computed from raw logits (softmax over dim=1)
23 | def __init__(self,):
24 | super().__init__()
25 |
26 | def forward(self, pred, mask):
27 | prob = F.softmax(pred, dim=1)
28 | true_1_hot = mask.type(prob.type())
29 |
30 | dims = (0,) + tuple(range(2, true_1_hot.ndimension()))
31 | intersection = torch.sum(prob * true_1_hot, dims)
32 | cardinality = torch.sum(prob + true_1_hot, dims)
33 | dice_score = (2. * intersection / (cardinality + EPSILON)).mean()
34 | return (1 - dice_score)
35 |
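36 | 
37 | # Example usage (illustrative sketch; batch size, class count, and spatial
38 | # size below are assumptions for a two-class, one-hot segmentation setup):
39 | #
40 | # if __name__ == "__main__":
41 | #     criterion = DiceLossWithLogtis()
42 | #     logits = torch.randn(4, 2, 224, 224)                      # raw model outputs
43 | #     masks = F.one_hot(torch.randint(0, 2, (4, 224, 224)), 2)  # B x H x W x C
44 | #     masks = masks.permute(0, 3, 1, 2).float()                 # B x C x H x W
45 | #     print(criterion(logits, masks).item())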
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NITR098/Awesome-U-Net/8785625b1113979303e8f7c3f88cc97a1f68324c/models/__init__.py
--------------------------------------------------------------------------------
/models/_missformer/MISSFormer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .segformer import *
4 | from typing import Tuple
5 | from einops import rearrange
6 |
7 | class PatchExpand(nn.Module):
8 | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
9 | super().__init__()
10 | self.input_resolution = input_resolution
11 | self.dim = dim
12 | self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity()
13 | self.norm = norm_layer(dim // dim_scale)
14 |
15 | def forward(self, x):
16 | """
17 | x: B, H*W, C
18 | """
19 | # print("x_shape-----",x.shape)
20 | H, W = self.input_resolution
21 | x = self.expand(x)
22 |
23 | B, L, C = x.shape
24 | # print(x.shape)
25 | assert L == H * W, "input feature has wrong size"
26 |
27 | x = x.view(B, H, W, C)
28 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4)
29 | x = x.view(B,-1,C//4)
30 | x= self.norm(x.clone())
31 |
32 | return x
33 |
34 | class FinalPatchExpand_X4(nn.Module):
35 | def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
36 | super().__init__()
37 | self.input_resolution = input_resolution
38 | self.dim = dim
39 | self.dim_scale = dim_scale
40 | self.expand = nn.Linear(dim, 16*dim, bias=False)
41 | self.output_dim = dim
42 | self.norm = norm_layer(self.output_dim)
43 |
44 | def forward(self, x):
45 | """
46 | x: B, H*W, C
47 | """
48 | H, W = self.input_resolution
49 | x = self.expand(x)
50 | B, L, C = x.shape
51 | assert L == H * W, "input feature has wrong size"
52 |
53 | x = x.view(B, H, W, C)
54 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2))
55 | x = x.view(B,-1,self.output_dim)
56 | x= self.norm(x.clone())
57 |
58 | return x
59 |
60 |
61 | class SegU_decoder(nn.Module):
62 | def __init__(self, input_size, in_out_chan, heads, reduction_ratios, n_class=9, norm_layer=nn.LayerNorm, is_last=False):
63 | super().__init__()
64 | dims = in_out_chan[0]
65 | out_dim = in_out_chan[1]
66 | if not is_last:
67 | self.concat_linear = nn.Linear(dims*2, out_dim)
68 | # transformer decoder
69 | self.layer_up = PatchExpand(input_resolution=input_size, dim=out_dim, dim_scale=2, norm_layer=norm_layer)
70 | self.last_layer = None
71 | else:
72 | self.concat_linear = nn.Linear(dims*4, out_dim)
73 | # transformer decoder
74 | self.layer_up = FinalPatchExpand_X4(input_resolution=input_size, dim=out_dim, dim_scale=4, norm_layer=norm_layer)
75 | # self.last_layer = nn.Linear(out_dim, n_class)
76 | self.last_layer = nn.Conv2d(out_dim, n_class,1)
77 | # self.last_layer = None
78 |
79 | self.layer_former_1 = TransformerBlock(out_dim, heads, reduction_ratios)
80 | self.layer_former_2 = TransformerBlock(out_dim, heads, reduction_ratios)
81 |
82 |
83 | def init_weights(self):
84 | for m in self.modules():
85 | if isinstance(m, nn.Linear):
86 | nn.init.xavier_uniform_(m.weight)
87 | if m.bias is not None:
88 | nn.init.zeros_(m.bias)
89 | elif isinstance(m, nn.LayerNorm):
90 | nn.init.ones_(m.weight)
91 | nn.init.zeros_(m.bias)
92 | elif isinstance(m, nn.Conv2d):
93 | nn.init.xavier_uniform_(m.weight)
94 | if m.bias is not None:
95 | nn.init.zeros_(m.bias)
96 |
97 | init_weights(self)
98 |
99 |
100 |
101 | def forward(self, x1, x2=None):
102 | if x2 is not None:
103 | b, h, w, c = x2.shape
104 | x2 = x2.view(b, -1, c)
105 | # print("------",x1.shape, x2.shape)
106 | cat_x = torch.cat([x1, x2], dim=-1)
107 | # print("-----catx shape", cat_x.shape)
108 | cat_linear_x = self.concat_linear(cat_x)
109 | tran_layer_1 = self.layer_former_1(cat_linear_x, h, w)
110 | tran_layer_2 = self.layer_former_2(tran_layer_1, h, w)
111 |
112 | if self.last_layer:
113 | out = self.last_layer(self.layer_up(tran_layer_2).view(b, 4*h, 4*w, -1).permute(0,3,1,2))
114 | else:
115 | out = self.layer_up(tran_layer_2)
116 | else:
117 | # if len(x1.shape)>3:
118 | # x1 = x1.permute(0,2,3,1)
119 | # b, h, w, c = x1.shape
120 | # x1 = x1.view(b, -1, c)
121 | out = self.layer_up(x1)
122 | return out
123 |
124 |
125 | class BridgeLayer_4(nn.Module):
126 | def __init__(self, dims, head, reduction_ratios):
127 | super().__init__()
128 |
129 | self.norm1 = nn.LayerNorm(dims)
130 | self.attn = M_EfficientSelfAtten(dims, head, reduction_ratios)
131 | self.norm2 = nn.LayerNorm(dims)
132 | self.mixffn1 = MixFFN_skip(dims,dims*4)
133 | self.mixffn2 = MixFFN_skip(dims*2,dims*8)
134 | self.mixffn3 = MixFFN_skip(dims*5,dims*20)
135 | self.mixffn4 = MixFFN_skip(dims*8,dims*32)
136 |
137 |
138 | def forward(self, inputs):
139 | B = inputs[0].shape[0]
140 | C = 64
141 | if (type(inputs) == list):
142 | # print("-----1-----")
143 | c1, c2, c3, c4 = inputs
144 | B, C, _, _= c1.shape
145 | c1f = c1.permute(0, 2, 3, 1).reshape(B, -1, C) # 3136*64
146 | c2f = c2.permute(0, 2, 3, 1).reshape(B, -1, C) # 1568*64
147 | c3f = c3.permute(0, 2, 3, 1).reshape(B, -1, C) # 980*64
148 | c4f = c4.permute(0, 2, 3, 1).reshape(B, -1, C) # 392*64
149 |
150 | # print(c1f.shape, c2f.shape, c3f.shape, c4f.shape)
151 | inputs = torch.cat([c1f, c2f, c3f, c4f], -2)
152 | else:
153 | B,_,C = inputs.shape
154 |
155 | tx1 = inputs + self.attn(self.norm1(inputs))
156 | tx = self.norm2(tx1)
157 |
158 |
159 | tem1 = tx[:,:3136,:].reshape(B, -1, C)
160 | tem2 = tx[:,3136:4704,:].reshape(B, -1, C*2)
161 | tem3 = tx[:,4704:5684,:].reshape(B, -1, C*5)
162 | tem4 = tx[:,5684:6076,:].reshape(B, -1, C*8)
163 |
164 | m1f = self.mixffn1(tem1, 56, 56).reshape(B, -1, C)
165 | m2f = self.mixffn2(tem2, 28, 28).reshape(B, -1, C)
166 | m3f = self.mixffn3(tem3, 14, 14).reshape(B, -1, C)
167 | m4f = self.mixffn4(tem4, 7, 7).reshape(B, -1, C)
168 |
169 | t1 = torch.cat([m1f, m2f, m3f, m4f], -2)
170 |
171 | tx2 = tx1 + t1
172 |
173 |
174 | return tx2
175 |
176 |
177 | class BridgeLayer_3(nn.Module):
178 | def __init__(self, dims, head, reduction_ratios):
179 | super().__init__()
180 |
181 | self.norm1 = nn.LayerNorm(dims)
182 | self.attn = M_EfficientSelfAtten(dims, head, reduction_ratios)
183 | self.norm2 = nn.LayerNorm(dims)
184 | # self.mixffn1 = MixFFN(dims,dims*4)
185 | self.mixffn2 = MixFFN(dims*2,dims*8)
186 | self.mixffn3 = MixFFN(dims*5,dims*20)
187 | self.mixffn4 = MixFFN(dims*8,dims*32)
188 |
189 |
190 | def forward(self, inputs: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor:
191 | B = inputs[0].shape[0]
192 | C = 64
193 | if (type(inputs) == list):
194 | # print("-----1-----")
195 | c1, c2, c3, c4 = inputs
196 | B, C, _, _= c1.shape
197 | c1f = c1.permute(0, 2, 3, 1).reshape(B, -1, C) # 3136*64
198 | c2f = c2.permute(0, 2, 3, 1).reshape(B, -1, C) # 1568*64
199 | c3f = c3.permute(0, 2, 3, 1).reshape(B, -1, C) # 980*64
200 | c4f = c4.permute(0, 2, 3, 1).reshape(B, -1, C) # 392*64
201 |
202 | # print(c1f.shape, c2f.shape, c3f.shape, c4f.shape)
203 | inputs = torch.cat([c2f, c3f, c4f], -2)
204 | else:
205 | B,_,C = inputs.shape
206 |
207 | tx1 = inputs + self.attn(self.norm1(inputs))
208 | tx = self.norm2(tx1)
209 |
210 |
211 | # tem1 = tx[:,:3136,:].reshape(B, -1, C)
212 | tem2 = tx[:,:1568,:].reshape(B, -1, C*2)
213 | tem3 = tx[:,1568:2548,:].reshape(B, -1, C*5)
214 | tem4 = tx[:,2548:2940,:].reshape(B, -1, C*8)
215 |
216 | # m1f = self.mixffn1(tem1, 56, 56).reshape(B, -1, C)
217 | m2f = self.mixffn2(tem2, 28, 28).reshape(B, -1, C)
218 | m3f = self.mixffn3(tem3, 14, 14).reshape(B, -1, C)
219 | m4f = self.mixffn4(tem4, 7, 7).reshape(B, -1, C)
220 |
221 | t1 = torch.cat([m2f, m3f, m4f], -2)
222 |
223 | tx2 = tx1 + t1
224 |
225 |
226 | return tx2
227 |
228 |
229 |
230 | class BridegeBlock_4(nn.Module):
231 | def __init__(self, dims, head, reduction_ratios):
232 | super().__init__()
233 | self.bridge_layer1 = BridgeLayer_4(dims, head, reduction_ratios)
234 | self.bridge_layer2 = BridgeLayer_4(dims, head, reduction_ratios)
235 | self.bridge_layer3 = BridgeLayer_4(dims, head, reduction_ratios)
236 | self.bridge_layer4 = BridgeLayer_4(dims, head, reduction_ratios)
237 |
238 | def forward(self, x: torch.Tensor) -> torch.Tensor:
239 | bridge1 = self.bridge_layer1(x)
240 | bridge2 = self.bridge_layer2(bridge1)
241 | bridge3 = self.bridge_layer3(bridge2)
242 | bridge4 = self.bridge_layer4(bridge3)
243 |
244 | B,_,C = bridge4.shape
245 | outs = []
246 |
247 | sk1 = bridge4[:,:3136,:].reshape(B, 56, 56, C).permute(0,3,1,2)
248 | sk2 = bridge4[:,3136:4704,:].reshape(B, 28, 28, C*2).permute(0,3,1,2)
249 | sk3 = bridge4[:,4704:5684,:].reshape(B, 14, 14, C*5).permute(0,3,1,2)
250 | sk4 = bridge4[:,5684:6076,:].reshape(B, 7, 7, C*8).permute(0,3,1,2)
251 |
252 | outs.append(sk1)
253 | outs.append(sk2)
254 | outs.append(sk3)
255 | outs.append(sk4)
256 |
257 | return outs
258 |
259 |
260 | class BridegeBlock_3(nn.Module):
261 | def __init__(self, dims, head, reduction_ratios):
262 | super().__init__()
263 | self.bridge_layer1 = BridgeLayer_3(dims, head, reduction_ratios)
264 | self.bridge_layer2 = BridgeLayer_3(dims, head, reduction_ratios)
265 | self.bridge_layer3 = BridgeLayer_3(dims, head, reduction_ratios)
266 | self.bridge_layer4 = BridgeLayer_3(dims, head, reduction_ratios)
267 |
268 | def forward(self, x: torch.Tensor) -> torch.Tensor:
269 | outs = []
270 | if (type(x) == list):
271 | # print("-----1-----")
272 | outs.append(x[0])
273 | bridge1 = self.bridge_layer1(x)
274 | bridge2 = self.bridge_layer2(bridge1)
275 | bridge3 = self.bridge_layer3(bridge2)
276 | bridge4 = self.bridge_layer4(bridge3)
277 |
278 | B,_,C = bridge4.shape
279 |
280 |
281 | # sk1 = bridge2[:,:3136,:].reshape(B, 56, 56, C).permute(0,3,1,2)
282 | sk2 = bridge4[:,:1568,:].reshape(B, 28, 28, C*2).permute(0,3,1,2)
283 | sk3 = bridge4[:,1568:2548,:].reshape(B, 14, 14, C*5).permute(0,3,1,2)
284 | sk4 = bridge4[:,2548:2940,:].reshape(B, 7, 7, C*8).permute(0,3,1,2)
285 |
286 | # outs.append(sk1)
287 | outs.append(sk2)
288 | outs.append(sk3)
289 | outs.append(sk4)
290 |
291 | return outs
292 |
293 |
294 | class MyDecoderLayer(nn.Module):
295 | def __init__(self, input_size, in_out_chan, heads, reduction_ratios,token_mlp_mode, n_class=9, norm_layer=nn.LayerNorm, is_last=False):
296 | super().__init__()
297 | dims = in_out_chan[0]
298 | out_dim = in_out_chan[1]
299 | if not is_last:
300 | self.concat_linear = nn.Linear(dims*2, out_dim)
301 | # transformer decoder
302 | self.layer_up = PatchExpand(input_resolution=input_size, dim=out_dim, dim_scale=2, norm_layer=norm_layer)
303 | self.last_layer = None
304 | else:
305 | self.concat_linear = nn.Linear(dims*4, out_dim)
306 | # transformer decoder
307 | self.layer_up = FinalPatchExpand_X4(input_resolution=input_size, dim=out_dim, dim_scale=4, norm_layer=norm_layer)
308 | # self.last_layer = nn.Linear(out_dim, n_class)
309 | self.last_layer = nn.Conv2d(out_dim, n_class,1)
310 | # self.last_layer = None
311 |
312 | self.layer_former_1 = TransformerBlock(out_dim, heads, reduction_ratios, token_mlp_mode)
313 | self.layer_former_2 = TransformerBlock(out_dim, heads, reduction_ratios, token_mlp_mode)
314 |
315 |
316 | def init_weights(self):
317 | for m in self.modules():
318 | if isinstance(m, nn.Linear):
319 | nn.init.xavier_uniform_(m.weight)
320 | if m.bias is not None:
321 | nn.init.zeros_(m.bias)
322 | elif isinstance(m, nn.LayerNorm):
323 | nn.init.ones_(m.weight)
324 | nn.init.zeros_(m.bias)
325 | elif isinstance(m, nn.Conv2d):
326 | nn.init.xavier_uniform_(m.weight)
327 | if m.bias is not None:
328 | nn.init.zeros_(m.bias)
329 |
330 | init_weights(self)
331 |
332 | def forward(self, x1, x2=None):
333 | if x2 is not None:
334 | b, h, w, c = x2.shape
335 | x2 = x2.view(b, -1, c)
336 | # print("------",x1.shape, x2.shape)
337 | cat_x = torch.cat([x1, x2], dim=-1)
338 | # print("-----catx shape", cat_x.shape)
339 | cat_linear_x = self.concat_linear(cat_x)
340 | tran_layer_1 = self.layer_former_1(cat_linear_x, h, w)
341 | tran_layer_2 = self.layer_former_2(tran_layer_1, h, w)
342 |
343 | if self.last_layer:
344 | out = self.last_layer(self.layer_up(tran_layer_2).view(b, 4*h, 4*w, -1).permute(0,3,1,2))
345 | else:
346 | out = self.layer_up(tran_layer_2)
347 | else:
348 | # if len(x1.shape)>3:
349 | # x1 = x1.permute(0,2,3,1)
350 | # b, h, w, c = x1.shape
351 | # x1 = x1.view(b, -1, c)
352 | out = self.layer_up(x1)
353 | return out
354 |
355 | class MISSFormer(nn.Module):
356 | def __init__(self, num_classes=9, in_ch=3, token_mlp_mode="mix_skip", encoder_pretrained=True):
357 | super().__init__()
358 |
359 | reduction_ratios = [8, 4, 2, 1]
360 | heads = [1, 2, 5, 8]
361 |         d_base_feat_size = 7  # 16 for a 512 input size, 7 for 224
362 | in_out_chan = [[32, 64],[144, 128],[288, 320],[512, 512]]
363 |
364 | dims, layers = [[64, 128, 320, 512], [2, 2, 2, 2]]
365 | self.backbone = MiT(224, dims, layers,in_ch, token_mlp_mode)
366 |
367 | self.reduction_ratios = [1, 2, 4, 8]
368 | self.bridge = BridegeBlock_4(64, 1, self.reduction_ratios)
369 |
370 | self.decoder_3= MyDecoderLayer((d_base_feat_size,d_base_feat_size), in_out_chan[3], heads[3], reduction_ratios[3],token_mlp_mode, n_class=num_classes)
371 | self.decoder_2= MyDecoderLayer((d_base_feat_size*2,d_base_feat_size*2),in_out_chan[2], heads[2], reduction_ratios[2], token_mlp_mode, n_class=num_classes)
372 | self.decoder_1= MyDecoderLayer((d_base_feat_size*4,d_base_feat_size*4), in_out_chan[1], heads[1], reduction_ratios[1], token_mlp_mode, n_class=num_classes)
373 | self.decoder_0= MyDecoderLayer((d_base_feat_size*8,d_base_feat_size*8), in_out_chan[0], heads[0], reduction_ratios[0], token_mlp_mode, n_class=num_classes, is_last=True)
374 |
375 |
376 | def forward(self, x):
377 | #---------------Encoder-------------------------
378 | if x.size()[1] == 1:
379 | x = x.repeat(1,3,1,1)
380 |
381 | encoder = self.backbone(x)
382 | bridge = self.bridge(encoder) #list
383 |
384 | b,c,_,_ = bridge[3].shape
385 | # print(bridge[3].shape, bridge[2].shape,bridge[1].shape, bridge[0].shape)
386 | #---------------Decoder-------------------------
387 | # print("stage3-----")
388 | tmp_3 = self.decoder_3(bridge[3].permute(0,2,3,1).view(b,-1,c))
389 | # print("stage2-----")
390 | tmp_2 = self.decoder_2(tmp_3, bridge[2].permute(0,2,3,1))
391 | # print("stage1-----")
392 | tmp_1 = self.decoder_1(tmp_2, bridge[1].permute(0,2,3,1))
393 | # print("stage0-----")
394 | tmp_0 = self.decoder_0(tmp_1, bridge[0].permute(0,2,3,1))
395 |
396 | return tmp_0
397 |
398 |
399 |
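400 | # Example instantiation (illustrative sketch; a 224x224 input is assumed to
401 | # match the MiT backbone resolution hard-coded above):
402 | #
403 | # if __name__ == "__main__":
404 | #     model = MISSFormer(num_classes=2, in_ch=3)
405 | #     y = model(torch.randn(1, 3, 224, 224))
406 | #     print(y.shape)  # expected: torch.Size([1, 2, 224, 224])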
--------------------------------------------------------------------------------
/models/_missformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NITR098/Awesome-U-Net/8785625b1113979303e8f7c3f88cc97a1f68324c/models/_missformer/__init__.py
--------------------------------------------------------------------------------
/models/_resunet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NITR098/Awesome-U-Net/8785625b1113979303e8f7c3f88cc97a1f68324c/models/_resunet/__init__.py
--------------------------------------------------------------------------------
/models/_resunet/modules.py:
--------------------------------------------------------------------------------
1 | # https://github.com/rishikksh20/ResUnet/blob/master/core/modules.py
2 |
3 |
4 | import torch.nn as nn
5 | import torch
6 |
7 |
8 | class ResidualConv(nn.Module):
9 | def __init__(self, input_dim, output_dim, stride, padding):
10 | super(ResidualConv, self).__init__()
11 |
12 | self.conv_block = nn.Sequential(
13 | nn.BatchNorm2d(input_dim),
14 | nn.ReLU(),
15 | nn.Conv2d(
16 | input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
17 | ),
18 | nn.BatchNorm2d(output_dim),
19 | nn.ReLU(),
20 | nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
21 | )
22 | self.conv_skip = nn.Sequential(
23 | nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
24 | nn.BatchNorm2d(output_dim),
25 | )
26 |
27 | def forward(self, x):
28 |
29 | return self.conv_block(x) + self.conv_skip(x)
30 |
31 |
32 | class Upsample(nn.Module):
33 | def __init__(self, input_dim, output_dim, kernel, stride):
34 | super(Upsample, self).__init__()
35 |
36 | self.upsample = nn.ConvTranspose2d(
37 | input_dim, output_dim, kernel_size=kernel, stride=stride
38 | )
39 |
40 | def forward(self, x):
41 | return self.upsample(x)
42 |
43 |
44 | class Squeeze_Excite_Block(nn.Module):
45 | def __init__(self, channel, reduction=16):
46 | super(Squeeze_Excite_Block, self).__init__()
47 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
48 | self.fc = nn.Sequential(
49 | nn.Linear(channel, channel // reduction, bias=False),
50 | nn.ReLU(inplace=True),
51 | nn.Linear(channel // reduction, channel, bias=False),
52 | nn.Sigmoid(),
53 | )
54 |
55 | def forward(self, x):
56 | b, c, _, _ = x.size()
57 | y = self.avg_pool(x).view(b, c)
58 | y = self.fc(y).view(b, c, 1, 1)
59 | return x * y.expand_as(x)
60 |
61 |
62 | class ASPP(nn.Module):
63 | def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
64 | super(ASPP, self).__init__()
65 |
66 | self.aspp_block1 = nn.Sequential(
67 | nn.Conv2d(
68 | in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
69 | ),
70 | nn.ReLU(inplace=True),
71 | nn.BatchNorm2d(out_dims),
72 | )
73 | self.aspp_block2 = nn.Sequential(
74 | nn.Conv2d(
75 | in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
76 | ),
77 | nn.ReLU(inplace=True),
78 | nn.BatchNorm2d(out_dims),
79 | )
80 | self.aspp_block3 = nn.Sequential(
81 | nn.Conv2d(
82 | in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
83 | ),
84 | nn.ReLU(inplace=True),
85 | nn.BatchNorm2d(out_dims),
86 | )
87 |
88 | self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
89 | self._init_weights()
90 |
91 | def forward(self, x):
92 | x1 = self.aspp_block1(x)
93 | x2 = self.aspp_block2(x)
94 | x3 = self.aspp_block3(x)
95 | out = torch.cat([x1, x2, x3], dim=1)
96 | return self.output(out)
97 |
98 | def _init_weights(self):
99 | for m in self.modules():
100 | if isinstance(m, nn.Conv2d):
101 | nn.init.kaiming_normal_(m.weight)
102 | elif isinstance(m, nn.BatchNorm2d):
103 | m.weight.data.fill_(1)
104 | m.bias.data.zero_()
105 |
106 |
107 | class Upsample_(nn.Module):
108 | def __init__(self, scale=2):
109 | super(Upsample_, self).__init__()
110 |
111 | self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)
112 |
113 | def forward(self, x):
114 | return self.upsample(x)
115 |
116 |
117 | class AttentionBlock(nn.Module):
118 | def __init__(self, input_encoder, input_decoder, output_dim):
119 | super(AttentionBlock, self).__init__()
120 |
121 | self.conv_encoder = nn.Sequential(
122 | nn.BatchNorm2d(input_encoder),
123 | nn.ReLU(),
124 | nn.Conv2d(input_encoder, output_dim, 3, padding=1),
125 | nn.MaxPool2d(2, 2),
126 | )
127 |
128 | self.conv_decoder = nn.Sequential(
129 | nn.BatchNorm2d(input_decoder),
130 | nn.ReLU(),
131 | nn.Conv2d(input_decoder, output_dim, 3, padding=1),
132 | )
133 |
134 | self.conv_attn = nn.Sequential(
135 | nn.BatchNorm2d(output_dim),
136 | nn.ReLU(),
137 | nn.Conv2d(output_dim, 1, 1),
138 | )
139 |
140 | def forward(self, x1, x2):
141 | out = self.conv_encoder(x1) + self.conv_decoder(x2)
142 | out = self.conv_attn(out)
143 | return out * x2
--------------------------------------------------------------------------------
/models/_resunet/res_unet.py:
--------------------------------------------------------------------------------
1 | # https://github.com/rishikksh20/ResUnet
2 |
3 | import torch
4 | import torch.nn as nn
5 | from .modules import ResidualConv, Upsample
6 |
7 |
8 | class ResUnet(nn.Module):
9 | def __init__(self, in_ch, out_ch, filters=[64, 128, 256, 512]):
10 | super(ResUnet, self).__init__()
11 |
12 | self.input_layer = nn.Sequential(
13 | nn.Conv2d(in_ch, filters[0], kernel_size=3, padding=1),
14 | nn.BatchNorm2d(filters[0]),
15 | nn.ReLU(),
16 | nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
17 | )
18 | self.input_skip = nn.Sequential(
19 | nn.Conv2d(in_ch, filters[0], kernel_size=3, padding=1)
20 | )
21 |
22 | self.residual_conv_1 = ResidualConv(filters[0], filters[1], 2, 1)
23 | self.residual_conv_2 = ResidualConv(filters[1], filters[2], 2, 1)
24 |
25 | self.bridge = ResidualConv(filters[2], filters[3], 2, 1)
26 |
27 | self.upsample_1 = Upsample(filters[3], filters[3], 2, 2)
28 | self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], 1, 1)
29 |
30 | self.upsample_2 = Upsample(filters[2], filters[2], 2, 2)
31 | self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], 1, 1)
32 |
33 | self.upsample_3 = Upsample(filters[1], filters[1], 2, 2)
34 | self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], 1, 1)
35 |
36 | self.output_layer = nn.Sequential(
37 | nn.Conv2d(filters[0], out_ch, 1, 1),
38 | )
39 |
40 | def forward(self, x):
41 | # Encode
42 | x1 = self.input_layer(x) + self.input_skip(x)
43 | x2 = self.residual_conv_1(x1)
44 | x3 = self.residual_conv_2(x2)
45 | # Bridge
46 | x4 = self.bridge(x3)
47 | # Decode
48 | x4 = self.upsample_1(x4)
49 | x5 = torch.cat([x4, x3], dim=1)
50 |
51 | x6 = self.up_residual_conv1(x5)
52 |
53 | x6 = self.upsample_2(x6)
54 | x7 = torch.cat([x6, x2], dim=1)
55 |
56 | x8 = self.up_residual_conv2(x7)
57 |
58 | x8 = self.upsample_3(x8)
59 | x9 = torch.cat([x8, x1], dim=1)
60 |
61 | x10 = self.up_residual_conv3(x9)
62 |
63 | output = self.output_layer(x10)
64 |
65 | return output
66 |
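67 | # Example instantiation (illustrative sketch; the 224x224 input size is an
68 | # assumption and only needs to be divisible by 8 for the skip connections to align):
69 | #
70 | # if __name__ == "__main__":
71 | #     model = ResUnet(in_ch=3, out_ch=1)
72 | #     y = model(torch.randn(1, 3, 224, 224))
73 | #     print(y.shape)  # expected: torch.Size([1, 1, 224, 224])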
--------------------------------------------------------------------------------
/models/_transunet/vit_seg_configs.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 |
3 | def get_b16_config():
4 | """Returns the ViT-B/16 configuration."""
5 | config = ml_collections.ConfigDict()
6 | config.patches = ml_collections.ConfigDict({'size': (16, 16)})
7 | config.hidden_size = 768
8 | config.transformer = ml_collections.ConfigDict()
9 | config.transformer.mlp_dim = 3072
10 | config.transformer.num_heads = 12
11 | config.transformer.num_layers = 12
12 | config.transformer.attention_dropout_rate = 0.0
13 | config.transformer.dropout_rate = 0.1
14 |
15 | config.classifier = 'seg'
16 | config.representation_size = None
17 | config.resnet_pretrained_path = None
18 | config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz'
19 | config.patch_size = 16
20 |
21 | config.decoder_channels = (256, 128, 64, 16)
22 | config.n_classes = 2
23 | config.activation = 'softmax'
24 | return config
25 |
26 |
27 | def get_testing():
28 | """Returns a minimal configuration for testing."""
29 | config = ml_collections.ConfigDict()
30 | config.patches = ml_collections.ConfigDict({'size': (16, 16)})
31 | config.hidden_size = 1
32 | config.transformer = ml_collections.ConfigDict()
33 | config.transformer.mlp_dim = 1
34 | config.transformer.num_heads = 1
35 | config.transformer.num_layers = 1
36 | config.transformer.attention_dropout_rate = 0.0
37 | config.transformer.dropout_rate = 0.1
38 | config.classifier = 'token'
39 | config.representation_size = None
40 | return config
41 |
42 | def get_r50_b16_config():
43 | """Returns the Resnet50 + ViT-B/16 configuration."""
44 | config = get_b16_config()
45 | config.patches.grid = (16, 16)
46 | config.resnet = ml_collections.ConfigDict()
47 | config.resnet.num_layers = (3, 4, 9)
48 | config.resnet.width_factor = 1
49 |
50 | config.classifier = 'seg'
51 | config.pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'
52 | config.decoder_channels = (256, 128, 64, 16)
53 | config.skip_channels = [512, 256, 64, 16]
54 | config.n_classes = 2
55 | config.n_skip = 3
56 | config.activation = 'softmax'
57 |
58 | return config
59 |
60 |
61 | def get_b32_config():
62 | """Returns the ViT-B/32 configuration."""
63 | config = get_b16_config()
64 | config.patches.size = (32, 32)
65 | config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz'
66 | return config
67 |
68 |
69 | def get_l16_config():
70 | """Returns the ViT-L/16 configuration."""
71 | config = ml_collections.ConfigDict()
72 | config.patches = ml_collections.ConfigDict({'size': (16, 16)})
73 | config.hidden_size = 1024
74 | config.transformer = ml_collections.ConfigDict()
75 | config.transformer.mlp_dim = 4096
76 | config.transformer.num_heads = 16
77 | config.transformer.num_layers = 24
78 | config.transformer.attention_dropout_rate = 0.0
79 | config.transformer.dropout_rate = 0.1
80 | config.representation_size = None
81 |
82 | # custom
83 | config.classifier = 'seg'
84 | config.resnet_pretrained_path = None
85 | config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz'
86 | config.decoder_channels = (256, 128, 64, 16)
87 | config.n_classes = 2
88 | config.activation = 'softmax'
89 | return config
90 |
91 |
92 | def get_r50_l16_config():
93 | """Returns the Resnet50 + ViT-L/16 configuration. customized """
94 | config = get_l16_config()
95 | config.patches.grid = (16, 16)
96 | config.resnet = ml_collections.ConfigDict()
97 | config.resnet.num_layers = (3, 4, 9)
98 | config.resnet.width_factor = 1
99 |
100 | config.classifier = 'seg'
101 | config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'
102 | config.decoder_channels = (256, 128, 64, 16)
103 | config.skip_channels = [512, 256, 64, 16]
104 | config.n_classes = 2
105 | config.activation = 'softmax'
106 | return config
107 |
108 |
109 | def get_l32_config():
110 | """Returns the ViT-L/32 configuration."""
111 | config = get_l16_config()
112 | config.patches.size = (32, 32)
113 | return config
114 |
115 |
116 | def get_h14_config():
117 | """Returns the ViT-L/16 configuration."""
118 | config = ml_collections.ConfigDict()
119 | config.patches = ml_collections.ConfigDict({'size': (14, 14)})
120 | config.hidden_size = 1280
121 | config.transformer = ml_collections.ConfigDict()
122 | config.transformer.mlp_dim = 5120
123 | config.transformer.num_heads = 16
124 | config.transformer.num_layers = 32
125 | config.transformer.attention_dropout_rate = 0.0
126 | config.transformer.dropout_rate = 0.1
127 | config.classifier = 'token'
128 | config.representation_size = None
129 |
130 | return config
131 |
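132 | # Example usage (illustrative sketch; the grid override follows the common
133 | # TransUNet recipe of img_size // 16 patches per side for a 224x224 input,
134 | # and pretrained_path above must point at a downloaded checkpoint to be used):
135 | #
136 | # config = get_r50_b16_config()
137 | # config.n_classes = 2            # e.g. binary segmentation
138 | # config.patches.grid = (14, 14)  # 224 // 16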
--------------------------------------------------------------------------------
/models/_transunet/vit_seg_modeling_resnet_skip.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from os.path import join as pjoin
4 | from collections import OrderedDict
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | def np2th(weights, conv=False):
12 | """Possibly convert HWIO to OIHW."""
13 | if conv:
14 | weights = weights.transpose([3, 2, 0, 1])
15 | return torch.from_numpy(weights)
16 |
17 |
18 | class StdConv2d(nn.Conv2d):
19 |
20 | def forward(self, x):
21 | w = self.weight
22 | v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
23 | w = (w - m) / torch.sqrt(v + 1e-5)
24 | return F.conv2d(x, w, self.bias, self.stride, self.padding,
25 | self.dilation, self.groups)
26 |
27 |
28 | def conv3x3(cin, cout, stride=1, groups=1, bias=False):
29 | return StdConv2d(cin, cout, kernel_size=3, stride=stride,
30 | padding=1, bias=bias, groups=groups)
31 |
32 |
33 | def conv1x1(cin, cout, stride=1, bias=False):
34 | return StdConv2d(cin, cout, kernel_size=1, stride=stride,
35 | padding=0, bias=bias)
36 |
37 |
38 | class PreActBottleneck(nn.Module):
39 | """Pre-activation (v2) bottleneck block.
40 | """
41 |
42 | def __init__(self, cin, cout=None, cmid=None, stride=1):
43 | super().__init__()
44 | cout = cout or cin
45 | cmid = cmid or cout//4
46 |
47 | self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6)
48 | self.conv1 = conv1x1(cin, cmid, bias=False)
49 | self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6)
50 | self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!!
51 | self.gn3 = nn.GroupNorm(32, cout, eps=1e-6)
52 | self.conv3 = conv1x1(cmid, cout, bias=False)
53 | self.relu = nn.ReLU(inplace=True)
54 |
55 | if (stride != 1 or cin != cout):
56 | # Projection also with pre-activation according to paper.
57 | self.downsample = conv1x1(cin, cout, stride, bias=False)
58 | self.gn_proj = nn.GroupNorm(cout, cout)
59 |
60 | def forward(self, x):
61 |
62 | # Residual branch
63 | residual = x
64 | if hasattr(self, 'downsample'):
65 | residual = self.downsample(x)
66 | residual = self.gn_proj(residual)
67 |
68 | # Unit's branch
69 | y = self.relu(self.gn1(self.conv1(x)))
70 | y = self.relu(self.gn2(self.conv2(y)))
71 | y = self.gn3(self.conv3(y))
72 |
73 | y = self.relu(residual + y)
74 | return y
75 |
76 | def load_from(self, weights, n_block, n_unit):
77 | conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True)
78 | conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True)
79 | conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True)
80 |
81 | gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")])
82 | gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")])
83 |
84 | gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")])
85 | gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")])
86 |
87 | gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")])
88 | gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")])
89 |
90 | self.conv1.weight.copy_(conv1_weight)
91 | self.conv2.weight.copy_(conv2_weight)
92 | self.conv3.weight.copy_(conv3_weight)
93 |
94 | self.gn1.weight.copy_(gn1_weight.view(-1))
95 | self.gn1.bias.copy_(gn1_bias.view(-1))
96 |
97 | self.gn2.weight.copy_(gn2_weight.view(-1))
98 | self.gn2.bias.copy_(gn2_bias.view(-1))
99 |
100 | self.gn3.weight.copy_(gn3_weight.view(-1))
101 | self.gn3.bias.copy_(gn3_bias.view(-1))
102 |
103 | if hasattr(self, 'downsample'):
104 | proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True)
105 | proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")])
106 | proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")])
107 |
108 | self.downsample.weight.copy_(proj_conv_weight)
109 | self.gn_proj.weight.copy_(proj_gn_weight.view(-1))
110 | self.gn_proj.bias.copy_(proj_gn_bias.view(-1))
111 |
112 | class ResNetV2(nn.Module):
113 | """Implementation of Pre-activation (v2) ResNet mode."""
114 |
115 | def __init__(self, block_units, width_factor):
116 | super().__init__()
117 | width = int(64 * width_factor)
118 | self.width = width
119 |
120 | self.root = nn.Sequential(OrderedDict([
121 | ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)),
122 | ('gn', nn.GroupNorm(32, width, eps=1e-6)),
123 | ('relu', nn.ReLU(inplace=True)),
124 | # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0))
125 | ]))
126 |
127 | self.body = nn.Sequential(OrderedDict([
128 | ('block1', nn.Sequential(OrderedDict(
129 | [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] +
130 | [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)],
131 | ))),
132 | ('block2', nn.Sequential(OrderedDict(
133 | [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] +
134 | [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)],
135 | ))),
136 | ('block3', nn.Sequential(OrderedDict(
137 | [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] +
138 | [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)],
139 | ))),
140 | ]))
141 |
142 | def forward(self, x):
143 | features = []
144 | b, c, in_size, _ = x.size()
145 | x = self.root(x)
146 | features.append(x)
147 | x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x)
148 | for i in range(len(self.body)-1):
149 | x = self.body[i](x)
150 | right_size = int(in_size / 4 / (i+1))
151 | if x.size()[2] != right_size:
152 | pad = right_size - x.size()[2]
153 | assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size)
154 | feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device)
155 | feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:]
156 | else:
157 | feat = x
158 | features.append(feat)
159 | x = self.body[-1](x)
160 | return x, features[::-1]
161 |
--------------------------------------------------------------------------------
/models/_transunet/vit_seg_modeling_resnet_skip_c4.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from os.path import join as pjoin
4 | from collections import OrderedDict
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | def np2th(weights, conv=False):
12 | """Possibly convert HWIO to OIHW."""
13 | if conv:
14 | weights = weights.transpose([3, 2, 0, 1])
15 | return torch.from_numpy(weights)
16 |
17 |
18 | class StdConv2d(nn.Conv2d):
19 |
20 | def forward(self, x):
21 | w = self.weight
22 | v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
23 | w = (w - m) / torch.sqrt(v + 1e-5)
24 | return F.conv2d(x, w, self.bias, self.stride, self.padding,
25 | self.dilation, self.groups)
26 |
27 |
28 | def conv3x3(cin, cout, stride=1, groups=1, bias=False):
29 | return StdConv2d(cin, cout, kernel_size=3, stride=stride,
30 | padding=1, bias=bias, groups=groups)
31 |
32 |
33 | def conv1x1(cin, cout, stride=1, bias=False):
34 | return StdConv2d(cin, cout, kernel_size=1, stride=stride,
35 | padding=0, bias=bias)
36 |
37 |
38 | class PreActBottleneck(nn.Module):
39 | """Pre-activation (v2) bottleneck block.
40 | """
41 |
42 | def __init__(self, cin, cout=None, cmid=None, stride=1):
43 | super().__init__()
44 | cout = cout or cin
45 | cmid = cmid or cout//4
46 |
47 | self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6)
48 | self.conv1 = conv1x1(cin, cmid, bias=False)
49 | self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6)
50 | self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!!
51 | self.gn3 = nn.GroupNorm(32, cout, eps=1e-6)
52 | self.conv3 = conv1x1(cmid, cout, bias=False)
53 | self.relu = nn.ReLU(inplace=True)
54 |
55 | if (stride != 1 or cin != cout):
56 | # Projection also with pre-activation according to paper.
57 | self.downsample = conv1x1(cin, cout, stride, bias=False)
58 | self.gn_proj = nn.GroupNorm(cout, cout)
59 |
60 | def forward(self, x):
61 |
62 | # Residual branch
63 | residual = x
64 | if hasattr(self, 'downsample'):
65 | residual = self.downsample(x)
66 | residual = self.gn_proj(residual)
67 |
68 | # Unit's branch
69 | y = self.relu(self.gn1(self.conv1(x)))
70 | y = self.relu(self.gn2(self.conv2(y)))
71 | y = self.gn3(self.conv3(y))
72 |
73 | y = self.relu(residual + y)
74 | return y
75 |
76 | def load_from(self, weights, n_block, n_unit):
77 | conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True)
78 | conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True)
79 | conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True)
80 |
81 | gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")])
82 | gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")])
83 |
84 | gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")])
85 | gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")])
86 |
87 | gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")])
88 | gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")])
89 |
90 | self.conv1.weight.copy_(conv1_weight)
91 | self.conv2.weight.copy_(conv2_weight)
92 | self.conv3.weight.copy_(conv3_weight)
93 |
94 | self.gn1.weight.copy_(gn1_weight.view(-1))
95 | self.gn1.bias.copy_(gn1_bias.view(-1))
96 |
97 | self.gn2.weight.copy_(gn2_weight.view(-1))
98 | self.gn2.bias.copy_(gn2_bias.view(-1))
99 |
100 | self.gn3.weight.copy_(gn3_weight.view(-1))
101 | self.gn3.bias.copy_(gn3_bias.view(-1))
102 |
103 | if hasattr(self, 'downsample'):
104 | proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True)
105 | proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")])
106 | proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")])
107 |
108 | self.downsample.weight.copy_(proj_conv_weight)
109 | self.gn_proj.weight.copy_(proj_gn_weight.view(-1))
110 | self.gn_proj.bias.copy_(proj_gn_bias.view(-1))
111 |
112 | class ResNetV2(nn.Module):
113 | """Implementation of Pre-activation (v2) ResNet mode."""
114 |
115 | def __init__(self, block_units, width_factor):
116 | super().__init__()
117 | width = int(64 * width_factor)
118 | self.width = width
119 |
120 | self.root = nn.Sequential(OrderedDict([
121 | ('conv', StdConv2d(4, width, kernel_size=7, stride=2, bias=False, padding=3)),
122 | ('gn', nn.GroupNorm(32, width, eps=1e-6)),
123 | ('relu', nn.ReLU(inplace=True)),
124 | # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0))
125 | ]))
126 |
127 | self.body = nn.Sequential(OrderedDict([
128 | ('block1', nn.Sequential(OrderedDict(
129 | [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] +
130 | [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)],
131 | ))),
132 | ('block2', nn.Sequential(OrderedDict(
133 | [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] +
134 | [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)],
135 | ))),
136 | ('block3', nn.Sequential(OrderedDict(
137 | [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] +
138 | [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)],
139 | ))),
140 | ]))
141 |
142 | def forward(self, x):
143 | features = []
144 | b, c, in_size, _ = x.size()
145 | x = self.root(x)
146 | features.append(x)
147 | x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x)
148 | for i in range(len(self.body)-1):
149 | x = self.body[i](x)
150 | right_size = int(in_size / 4 / (i+1))
151 | if x.size()[2] != right_size:
152 | pad = right_size - x.size()[2]
153 | assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size)
154 | feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device)
155 | feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:]
156 | else:
157 | feat = x
158 | features.append(feat)
159 | x = self.body[-1](x)
160 | return x, features[::-1]
161 |
--------------------------------------------------------------------------------
/models/_uctransnet/Config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time    : 2021/6/19 2:44 PM
3 | # @Author : Haonan Wang
4 | # @File : Config.py
5 | # @Software: PyCharm
6 | import os
7 | import torch
8 | import time
9 | import ml_collections
10 |
11 | ## PARAMETERS OF THE MODEL
12 | save_model = True
13 | tensorboard = True
14 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
15 | use_cuda = torch.cuda.is_available()
16 | seed = 666
17 | os.environ['PYTHONHASHSEED'] = str(seed)
18 |
19 | cosineLR = True # whether use cosineLR or not
20 | n_channels = 3
21 | n_labels = 1
22 | epochs = 2000
23 | img_size = 224
24 | print_frequency = 1
25 | save_frequency = 5000
26 | vis_frequency = 10
27 | early_stopping_patience = 50
28 |
29 | pretrain = False
30 | task_name = 'MoNuSeg' # GlaS MoNuSeg
31 | # task_name = 'GlaS'
32 | learning_rate = 1e-3
33 | batch_size = 4
34 |
35 |
36 | # model_name = 'UCTransNet'
37 | model_name = 'UCTransNet_pretrain'
38 |
39 | train_dataset = './datasets/'+ task_name+ '/Train_Folder/'
40 | val_dataset = './datasets/'+ task_name+ '/Val_Folder/'
41 | test_dataset = './datasets/'+ task_name+ '/Test_Folder/'
42 | session_name = 'Test_session' + '_' + time.strftime('%m.%d_%Hh%M')
43 | save_path = task_name +'/'+ model_name +'/' + session_name + '/'
44 | model_path = save_path + 'models/'
45 | tensorboard_folder = save_path + 'tensorboard_logs/'
46 | logger_path = save_path + session_name + ".log"
47 | visualize_path = save_path + 'visualize_val/'
48 |
49 |
50 | ##########################################################################
51 | # CTrans configs
52 | ##########################################################################
53 | def get_CTranS_config():
54 | config = ml_collections.ConfigDict()
55 | config.transformer = ml_collections.ConfigDict()
56 | config.KV_size = 960 # KV_size = Q1 + Q2 + Q3 + Q4
57 | config.transformer.num_heads = 4
58 | config.transformer.num_layers = 4
59 | config.expand_ratio = 4 # MLP channel dimension expand ratio
60 | config.transformer.embeddings_dropout_rate = 0.1
61 | config.transformer.attention_dropout_rate = 0.1
62 | config.transformer.dropout_rate = 0
63 | config.patch_sizes = [16,8,4,2]
64 | config.base_channel = 64 # base channel of U-Net
65 | config.n_classes = 1
66 | return config
67 |
68 |
69 |
70 |
71 | # used in testing phase, copy the session name in training phase
72 | test_session = "Test_session_07.03_20h39"
--------------------------------------------------------------------------------
/models/_uctransnet/UCTransNet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time    : 2021/7/8 8:59 AM
3 | # @File : UCTransNet.py
4 | # @Software: PyCharm
5 | import torch.nn as nn
6 | import torch
7 | import torch.nn.functional as F
8 | from .CTrans import ChannelTransformer
9 |
10 | def get_activation(activation_type):
11 | activation_type = activation_type.lower()
12 | if hasattr(nn, activation_type):
13 | return getattr(nn, activation_type)()
14 | else:
15 | return nn.ReLU()
16 |
17 | def _make_nConv(in_channels, out_channels, nb_Conv, activation='ReLU'):
18 | layers = []
19 | layers.append(ConvBatchNorm(in_channels, out_channels, activation))
20 |
21 | for _ in range(nb_Conv - 1):
22 | layers.append(ConvBatchNorm(out_channels, out_channels, activation))
23 | return nn.Sequential(*layers)
24 |
25 | class ConvBatchNorm(nn.Module):
26 | """(convolution => [BN] => ReLU)"""
27 |
28 | def __init__(self, in_channels, out_channels, activation='ReLU'):
29 | super(ConvBatchNorm, self).__init__()
30 | self.conv = nn.Conv2d(in_channels, out_channels,
31 | kernel_size=3, padding=1)
32 | self.norm = nn.BatchNorm2d(out_channels)
33 | self.activation = get_activation(activation)
34 |
35 | def forward(self, x):
36 | out = self.conv(x)
37 | out = self.norm(out)
38 | return self.activation(out)
39 |
40 | class DownBlock(nn.Module):
41 | """Downscaling with maxpool convolution"""
42 | def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):
43 | super(DownBlock, self).__init__()
44 | self.maxpool = nn.MaxPool2d(2)
45 | self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)
46 |
47 | def forward(self, x):
48 | out = self.maxpool(x)
49 | return self.nConvs(out)
50 |
51 | class Flatten(nn.Module):
52 | def forward(self, x):
53 | return x.view(x.size(0), -1)
54 |
55 | class CCA(nn.Module):
56 | """
57 | CCA Block
58 | """
59 | def __init__(self, F_g, F_x):
60 | super().__init__()
61 | self.mlp_x = nn.Sequential(
62 | Flatten(),
63 | nn.Linear(F_x, F_x))
64 | self.mlp_g = nn.Sequential(
65 | Flatten(),
66 | nn.Linear(F_g, F_x))
67 | self.relu = nn.ReLU(inplace=True)
68 |
69 | def forward(self, g, x):
70 | # channel-wise attention
71 | avg_pool_x = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
72 | channel_att_x = self.mlp_x(avg_pool_x)
73 | avg_pool_g = F.avg_pool2d( g, (g.size(2), g.size(3)), stride=(g.size(2), g.size(3)))
74 | channel_att_g = self.mlp_g(avg_pool_g)
75 | channel_att_sum = (channel_att_x + channel_att_g)/2.0
76 | scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
77 | x_after_channel = x * scale
78 | out = self.relu(x_after_channel)
79 | return out
80 |
81 | class UpBlock_attention(nn.Module):
82 | def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):
83 | super().__init__()
84 | self.up = nn.Upsample(scale_factor=2)
85 | self.coatt = CCA(F_g=in_channels//2, F_x=in_channels//2)
86 | self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)
87 |
88 | def forward(self, x, skip_x):
89 | up = self.up(x)
90 | skip_x_att = self.coatt(g=up, x=skip_x)
91 | x = torch.cat([skip_x_att, up], dim=1) # dim 1 is the channel dimension
92 | return self.nConvs(x)
93 |
94 | class UCTransNet(nn.Module):
95 | def __init__(self, config,n_channels=3, n_classes=1,img_size=224,vis=False):
96 | super().__init__()
97 | self.vis = vis
98 | self.n_channels = n_channels
99 | self.n_classes = n_classes
100 | in_channels = config.base_channel
101 | self.inc = ConvBatchNorm(n_channels, in_channels)
102 | self.down1 = DownBlock(in_channels, in_channels*2, nb_Conv=2)
103 | self.down2 = DownBlock(in_channels*2, in_channels*4, nb_Conv=2)
104 | self.down3 = DownBlock(in_channels*4, in_channels*8, nb_Conv=2)
105 | self.down4 = DownBlock(in_channels*8, in_channels*8, nb_Conv=2)
106 | self.mtc = ChannelTransformer(config, vis, img_size,
107 | channel_num=[in_channels, in_channels*2, in_channels*4, in_channels*8],
108 | patchSize=config.patch_sizes)
109 | self.up4 = UpBlock_attention(in_channels*16, in_channels*4, nb_Conv=2)
110 | self.up3 = UpBlock_attention(in_channels*8, in_channels*2, nb_Conv=2)
111 | self.up2 = UpBlock_attention(in_channels*4, in_channels, nb_Conv=2)
112 | self.up1 = UpBlock_attention(in_channels*2, in_channels, nb_Conv=2)
113 | self.outc = nn.Conv2d(in_channels, n_classes, kernel_size=(1,1), stride=(1,1))
114 | self.last_activation = nn.Sigmoid() # if using BCELoss
115 |
116 | def forward(self, x):
117 | x = x.float()
118 | x1 = self.inc(x)
119 | x2 = self.down1(x1)
120 | x3 = self.down2(x2)
121 | x4 = self.down3(x3)
122 | x5 = self.down4(x4)
123 | x1,x2,x3,x4,att_weights = self.mtc(x1,x2,x3,x4)
124 | x = self.up4(x5, x4)
125 | x = self.up3(x, x3)
126 | x = self.up2(x, x2)
127 | x = self.up1(x, x1)
128 | if self.n_classes ==1:
129 | logits = self.last_activation(self.outc(x))
130 | else:
131 |             logits = self.outc(x)  # if using BCEWithLogitsLoss or n_classes > 1
132 | if self.vis: # visualize the attention maps
133 | return logits, att_weights
134 | else:
135 | return logits
136 |
137 |
138 |
139 |
140 |
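141 | # Example instantiation (illustrative sketch; get_CTranS_config comes from
142 | # Config.py in this package, and n_classes=1 keeps the Sigmoid output head):
143 | #
144 | # from .Config import get_CTranS_config
145 | # model = UCTransNet(get_CTranS_config(), n_channels=3, n_classes=1, img_size=224)
146 | # y = model(torch.randn(1, 3, 224, 224))
147 | # print(y.shape)  # expected: torch.Size([1, 1, 224, 224])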
--------------------------------------------------------------------------------
/models/_uctransnet/UNet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 | def get_activation(activation_type):
5 | activation_type = activation_type.lower()
6 | if hasattr(nn, activation_type):
7 | return getattr(nn, activation_type)()
8 | else:
9 | return nn.ReLU()
10 |
11 | def _make_nConv(in_channels, out_channels, nb_Conv, activation='ReLU'):
12 | layers = []
13 | layers.append(ConvBatchNorm(in_channels, out_channels, activation))
14 |
15 | for _ in range(nb_Conv - 1):
16 | layers.append(ConvBatchNorm(out_channels, out_channels, activation))
17 | return nn.Sequential(*layers)
18 |
19 | class ConvBatchNorm(nn.Module):
20 | """(convolution => [BN] => ReLU)"""
21 |
22 | def __init__(self, in_channels, out_channels, activation='ReLU'):
23 | super(ConvBatchNorm, self).__init__()
24 | self.conv = nn.Conv2d(in_channels, out_channels,
25 | kernel_size=3, padding=1)
26 | self.norm = nn.BatchNorm2d(out_channels)
27 | self.activation = get_activation(activation)
28 |
29 | def forward(self, x):
30 | out = self.conv(x)
31 | out = self.norm(out)
32 | return self.activation(out)
33 |
34 | class DownBlock(nn.Module):
35 | """Downscaling with maxpool convolution"""
36 |
37 | def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):
38 | super(DownBlock, self).__init__()
39 | self.maxpool = nn.MaxPool2d(2)
40 | self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)
41 |
42 | def forward(self, x):
43 | out = self.maxpool(x)
44 | return self.nConvs(out)
45 |
46 | class UpBlock(nn.Module):
47 | """Upscaling then conv"""
48 |
49 | def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):
50 | super(UpBlock, self).__init__()
51 |
52 | # self.up = nn.Upsample(scale_factor=2)
53 | self.up = nn.ConvTranspose2d(in_channels//2,in_channels//2,(2,2),2)
54 | self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)
55 |
56 | def forward(self, x, skip_x):
57 | out = self.up(x)
58 | x = torch.cat([out, skip_x], dim=1) # dim 1 is the channel dimension
59 | return self.nConvs(x)
60 |
61 | class UNet(nn.Module):
62 | def __init__(self, n_channels=3, n_classes=9):
63 | '''
64 | n_channels : number of channels of the input.
65 | By default 3, because we have RGB images
66 |         n_classes : number of channels of the output.
67 |                     By default 9.
68 | '''
69 | super().__init__()
70 | self.n_channels = n_channels
71 | self.n_classes = n_classes
72 | # Question here
73 | in_channels = 64
74 | self.inc = ConvBatchNorm(n_channels, in_channels)
75 | self.down1 = DownBlock(in_channels, in_channels*2, nb_Conv=2)
76 | self.down2 = DownBlock(in_channels*2, in_channels*4, nb_Conv=2)
77 | self.down3 = DownBlock(in_channels*4, in_channels*8, nb_Conv=2)
78 | self.down4 = DownBlock(in_channels*8, in_channels*8, nb_Conv=2)
79 | self.up4 = UpBlock(in_channels*16, in_channels*4, nb_Conv=2)
80 | self.up3 = UpBlock(in_channels*8, in_channels*2, nb_Conv=2)
81 | self.up2 = UpBlock(in_channels*4, in_channels, nb_Conv=2)
82 | self.up1 = UpBlock(in_channels*2, in_channels, nb_Conv=2)
83 | self.outc = nn.Conv2d(in_channels, n_classes, kernel_size=(1,1))
84 | if n_classes == 1:
85 | self.last_activation = nn.Sigmoid()
86 | else:
87 | self.last_activation = None
88 |
89 | def forward(self, x):
90 | # Question here
91 | x = x.float()
92 | x1 = self.inc(x)
93 | x2 = self.down1(x1)
94 | x3 = self.down2(x2)
95 | x4 = self.down3(x3)
96 | x5 = self.down4(x4)
97 | x = self.up4(x5, x4)
98 | x = self.up3(x, x3)
99 | x = self.up2(x, x2)
100 | x = self.up1(x, x1)
101 | if self.last_activation is not None:
102 | logits = self.last_activation(self.outc(x))
103 | # print("111")
104 | else:
105 | logits = self.outc(x)
106 | # print("222")
107 | # logits = self.outc(x) # if using BCEWithLogitsLoss
108 | # print(logits.size())
109 | return logits
110 |
111 |
112 |
--------------------------------------------------------------------------------
/models/attunet.py:
--------------------------------------------------------------------------------
1 | # https://github.com/LeeJunHyun/Image_Segmentation
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn import init
7 |
8 | def init_weights(net, init_type='normal', gain=0.02):
9 | def init_func(m):
10 | classname = m.__class__.__name__
11 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
12 | if init_type == 'normal':
13 | init.normal_(m.weight.data, 0.0, gain)
14 | elif init_type == 'xavier':
15 | init.xavier_normal_(m.weight.data, gain=gain)
16 | elif init_type == 'kaiming':
17 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
18 | elif init_type == 'orthogonal':
19 | init.orthogonal_(m.weight.data, gain=gain)
20 | else:
21 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
22 | if hasattr(m, 'bias') and m.bias is not None:
23 | init.constant_(m.bias.data, 0.0)
24 | elif classname.find('BatchNorm2d') != -1:
25 | init.normal_(m.weight.data, 1.0, gain)
26 | init.constant_(m.bias.data, 0.0)
27 |
28 | print('initialize network with %s' % init_type)
29 | net.apply(init_func)
30 |
31 | class conv_block(nn.Module):
32 | def __init__(self,ch_in,ch_out):
33 | super(conv_block,self).__init__()
34 | self.conv = nn.Sequential(
35 | nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
36 | nn.BatchNorm2d(ch_out),
37 | nn.ReLU(inplace=True),
38 | nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
39 | nn.BatchNorm2d(ch_out),
40 | nn.ReLU(inplace=True)
41 | )
42 |
43 |
44 | def forward(self,x):
45 | x = self.conv(x)
46 | return x
47 |
48 | class up_conv(nn.Module):
49 | def __init__(self,ch_in,ch_out):
50 | super(up_conv,self).__init__()
51 | self.up = nn.Sequential(
52 | nn.Upsample(scale_factor=2),
53 | nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
54 | nn.BatchNorm2d(ch_out),
55 | nn.ReLU(inplace=True)
56 | )
57 |
58 | def forward(self,x):
59 | x = self.up(x)
60 | return x
61 |
62 | class Recurrent_block(nn.Module):
63 | def __init__(self,ch_out,t=2):
64 | super(Recurrent_block,self).__init__()
65 | self.t = t
66 | self.ch_out = ch_out
67 | self.conv = nn.Sequential(
68 | nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
69 | nn.BatchNorm2d(ch_out),
70 | nn.ReLU(inplace=True)
71 | )
72 |
73 | def forward(self,x):
74 | for i in range(self.t):
75 |
76 | if i==0:
77 | x1 = self.conv(x)
78 |
79 | x1 = self.conv(x+x1)
80 | return x1
81 |
82 | class RRCNN_block(nn.Module):
83 | def __init__(self,ch_in,ch_out,t=2):
84 | super(RRCNN_block,self).__init__()
85 | self.RCNN = nn.Sequential(
86 | Recurrent_block(ch_out,t=t),
87 | Recurrent_block(ch_out,t=t)
88 | )
89 | self.Conv_1x1 = nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=1,padding=0)
90 |
91 | def forward(self,x):
92 | x = self.Conv_1x1(x)
93 | x1 = self.RCNN(x)
94 | return x+x1
95 |
96 |
97 | class single_conv(nn.Module):
98 | def __init__(self,ch_in,ch_out):
99 | super(single_conv,self).__init__()
100 | self.conv = nn.Sequential(
101 | nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
102 | nn.BatchNorm2d(ch_out),
103 | nn.ReLU(inplace=True)
104 | )
105 |
106 | def forward(self,x):
107 | x = self.conv(x)
108 | return x
109 |
110 | class Attention_block(nn.Module):
111 | def __init__(self,F_g,F_l,F_int):
112 | super(Attention_block,self).__init__()
113 | self.W_g = nn.Sequential(
114 | nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True),
115 | nn.BatchNorm2d(F_int)
116 | )
117 |
118 | self.W_x = nn.Sequential(
119 | nn.Conv2d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True),
120 | nn.BatchNorm2d(F_int)
121 | )
122 |
123 | self.psi = nn.Sequential(
124 | nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True),
125 | nn.BatchNorm2d(1),
126 | nn.Sigmoid()
127 | )
128 |
129 | self.relu = nn.ReLU(inplace=True)
130 |
131 | def forward(self,g,x):
132 | g1 = self.W_g(g)
133 | x1 = self.W_x(x)
134 | psi = self.relu(g1+x1)
135 | psi = self.psi(psi)
136 |
137 | return x*psi
138 |
139 |
140 | class U_Net(nn.Module):
141 | def __init__(self,img_ch=3,output_ch=1):
142 | super(U_Net,self).__init__()
143 |
144 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)
145 |
146 | self.Conv1 = conv_block(ch_in=img_ch,ch_out=64)
147 | self.Conv2 = conv_block(ch_in=64,ch_out=128)
148 | self.Conv3 = conv_block(ch_in=128,ch_out=256)
149 | self.Conv4 = conv_block(ch_in=256,ch_out=512)
150 | self.Conv5 = conv_block(ch_in=512,ch_out=1024)
151 |
152 | self.Up5 = up_conv(ch_in=1024,ch_out=512)
153 | self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
154 |
155 | self.Up4 = up_conv(ch_in=512,ch_out=256)
156 | self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
157 |
158 | self.Up3 = up_conv(ch_in=256,ch_out=128)
159 | self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
160 |
161 | self.Up2 = up_conv(ch_in=128,ch_out=64)
162 | self.Up_conv2 = conv_block(ch_in=128, ch_out=64)
163 |
164 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0)
165 |
166 |
167 | def forward(self,x):
168 | # encoding path
169 | x1 = self.Conv1(x)
170 |
171 | x2 = self.Maxpool(x1)
172 | x2 = self.Conv2(x2)
173 |
174 | x3 = self.Maxpool(x2)
175 | x3 = self.Conv3(x3)
176 |
177 | x4 = self.Maxpool(x3)
178 | x4 = self.Conv4(x4)
179 |
180 | x5 = self.Maxpool(x4)
181 | x5 = self.Conv5(x5)
182 |
183 | # decoding + concat path
184 | d5 = self.Up5(x5)
185 | d5 = torch.cat((x4,d5),dim=1)
186 |
187 | d5 = self.Up_conv5(d5)
188 |
189 | d4 = self.Up4(d5)
190 | d4 = torch.cat((x3,d4),dim=1)
191 | d4 = self.Up_conv4(d4)
192 |
193 | d3 = self.Up3(d4)
194 | d3 = torch.cat((x2,d3),dim=1)
195 | d3 = self.Up_conv3(d3)
196 |
197 | d2 = self.Up2(d3)
198 | d2 = torch.cat((x1,d2),dim=1)
199 | d2 = self.Up_conv2(d2)
200 |
201 | d1 = self.Conv_1x1(d2)
202 |
203 | return d1
204 |
205 |
206 | class R2U_Net(nn.Module):
207 | def __init__(self,img_ch=3,output_ch=1,t=2):
208 | super(R2U_Net,self).__init__()
209 |
210 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)
211 | self.Upsample = nn.Upsample(scale_factor=2)
212 |
213 | self.RRCNN1 = RRCNN_block(ch_in=img_ch,ch_out=64,t=t)
214 |
215 | self.RRCNN2 = RRCNN_block(ch_in=64,ch_out=128,t=t)
216 |
217 | self.RRCNN3 = RRCNN_block(ch_in=128,ch_out=256,t=t)
218 |
219 | self.RRCNN4 = RRCNN_block(ch_in=256,ch_out=512,t=t)
220 |
221 | self.RRCNN5 = RRCNN_block(ch_in=512,ch_out=1024,t=t)
222 |
223 |
224 | self.Up5 = up_conv(ch_in=1024,ch_out=512)
225 | self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512,t=t)
226 |
227 | self.Up4 = up_conv(ch_in=512,ch_out=256)
228 | self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256,t=t)
229 |
230 | self.Up3 = up_conv(ch_in=256,ch_out=128)
231 | self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128,t=t)
232 |
233 | self.Up2 = up_conv(ch_in=128,ch_out=64)
234 | self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64,t=t)
235 |
236 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0)
237 |
238 |
239 | def forward(self,x):
240 | # encoding path
241 | x1 = self.RRCNN1(x)
242 |
243 | x2 = self.Maxpool(x1)
244 | x2 = self.RRCNN2(x2)
245 |
246 | x3 = self.Maxpool(x2)
247 | x3 = self.RRCNN3(x3)
248 |
249 | x4 = self.Maxpool(x3)
250 | x4 = self.RRCNN4(x4)
251 |
252 | x5 = self.Maxpool(x4)
253 | x5 = self.RRCNN5(x5)
254 |
255 | # decoding + concat path
256 | d5 = self.Up5(x5)
257 | d5 = torch.cat((x4,d5),dim=1)
258 | d5 = self.Up_RRCNN5(d5)
259 |
260 | d4 = self.Up4(d5)
261 | d4 = torch.cat((x3,d4),dim=1)
262 | d4 = self.Up_RRCNN4(d4)
263 |
264 | d3 = self.Up3(d4)
265 | d3 = torch.cat((x2,d3),dim=1)
266 | d3 = self.Up_RRCNN3(d3)
267 |
268 | d2 = self.Up2(d3)
269 | d2 = torch.cat((x1,d2),dim=1)
270 | d2 = self.Up_RRCNN2(d2)
271 |
272 | d1 = self.Conv_1x1(d2)
273 |
274 | return d1
275 |
276 |
277 |
278 | class AttU_Net(nn.Module):
279 | def __init__(self,img_ch=3,output_ch=1):
280 | super(AttU_Net,self).__init__()
281 |
282 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)
283 |
284 | self.Conv1 = conv_block(ch_in=img_ch,ch_out=64)
285 | self.Conv2 = conv_block(ch_in=64,ch_out=128)
286 | self.Conv3 = conv_block(ch_in=128,ch_out=256)
287 | self.Conv4 = conv_block(ch_in=256,ch_out=512)
288 | self.Conv5 = conv_block(ch_in=512,ch_out=1024)
289 |
290 | self.Up5 = up_conv(ch_in=1024,ch_out=512)
291 | self.Att5 = Attention_block(F_g=512,F_l=512,F_int=256)
292 | self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
293 |
294 | self.Up4 = up_conv(ch_in=512,ch_out=256)
295 | self.Att4 = Attention_block(F_g=256,F_l=256,F_int=128)
296 | self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
297 |
298 | self.Up3 = up_conv(ch_in=256,ch_out=128)
299 | self.Att3 = Attention_block(F_g=128,F_l=128,F_int=64)
300 | self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
301 |
302 | self.Up2 = up_conv(ch_in=128,ch_out=64)
303 | self.Att2 = Attention_block(F_g=64,F_l=64,F_int=32)
304 | self.Up_conv2 = conv_block(ch_in=128, ch_out=64)
305 |
306 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0)
307 |
308 |
309 | def forward(self,x):
310 | # encoding path
311 | x1 = self.Conv1(x)
312 |
313 | x2 = self.Maxpool(x1)
314 | x2 = self.Conv2(x2)
315 |
316 | x3 = self.Maxpool(x2)
317 | x3 = self.Conv3(x3)
318 |
319 | x4 = self.Maxpool(x3)
320 | x4 = self.Conv4(x4)
321 |
322 | x5 = self.Maxpool(x4)
323 | x5 = self.Conv5(x5)
324 |
325 | # decoding + concat path
326 | d5 = self.Up5(x5)
327 | x4 = self.Att5(g=d5,x=x4)
328 | d5 = torch.cat((x4,d5),dim=1)
329 | d5 = self.Up_conv5(d5)
330 |
331 | d4 = self.Up4(d5)
332 | x3 = self.Att4(g=d4,x=x3)
333 | d4 = torch.cat((x3,d4),dim=1)
334 | d4 = self.Up_conv4(d4)
335 |
336 | d3 = self.Up3(d4)
337 | x2 = self.Att3(g=d3,x=x2)
338 | d3 = torch.cat((x2,d3),dim=1)
339 | d3 = self.Up_conv3(d3)
340 |
341 | d2 = self.Up2(d3)
342 | x1 = self.Att2(g=d2,x=x1)
343 | d2 = torch.cat((x1,d2),dim=1)
344 | d2 = self.Up_conv2(d2)
345 |
346 | d1 = self.Conv_1x1(d2)
347 |
348 | return d1
349 |
350 |
351 | class R2AttU_Net(nn.Module):
352 | def __init__(self,img_ch=3,output_ch=1,t=2):
353 | super(R2AttU_Net,self).__init__()
354 |
355 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)
356 | self.Upsample = nn.Upsample(scale_factor=2)
357 |
358 | self.RRCNN1 = RRCNN_block(ch_in=img_ch,ch_out=64,t=t)
359 |
360 | self.RRCNN2 = RRCNN_block(ch_in=64,ch_out=128,t=t)
361 |
362 | self.RRCNN3 = RRCNN_block(ch_in=128,ch_out=256,t=t)
363 |
364 | self.RRCNN4 = RRCNN_block(ch_in=256,ch_out=512,t=t)
365 |
366 | self.RRCNN5 = RRCNN_block(ch_in=512,ch_out=1024,t=t)
367 |
368 |
369 | self.Up5 = up_conv(ch_in=1024,ch_out=512)
370 | self.Att5 = Attention_block(F_g=512,F_l=512,F_int=256)
371 | self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512,t=t)
372 |
373 | self.Up4 = up_conv(ch_in=512,ch_out=256)
374 | self.Att4 = Attention_block(F_g=256,F_l=256,F_int=128)
375 | self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256,t=t)
376 |
377 | self.Up3 = up_conv(ch_in=256,ch_out=128)
378 | self.Att3 = Attention_block(F_g=128,F_l=128,F_int=64)
379 | self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128,t=t)
380 |
381 | self.Up2 = up_conv(ch_in=128,ch_out=64)
382 | self.Att2 = Attention_block(F_g=64,F_l=64,F_int=32)
383 | self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64,t=t)
384 |
385 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0)
386 |
387 |
388 | def forward(self,x):
389 | # encoding path
390 | x1 = self.RRCNN1(x)
391 |
392 | x2 = self.Maxpool(x1)
393 | x2 = self.RRCNN2(x2)
394 |
395 | x3 = self.Maxpool(x2)
396 | x3 = self.RRCNN3(x3)
397 |
398 | x4 = self.Maxpool(x3)
399 | x4 = self.RRCNN4(x4)
400 |
401 | x5 = self.Maxpool(x4)
402 | x5 = self.RRCNN5(x5)
403 |
404 | # decoding + concat path
405 | d5 = self.Up5(x5)
406 | x4 = self.Att5(g=d5,x=x4)
407 | d5 = torch.cat((x4,d5),dim=1)
408 | d5 = self.Up_RRCNN5(d5)
409 |
410 | d4 = self.Up4(d5)
411 | x3 = self.Att4(g=d4,x=x3)
412 | d4 = torch.cat((x3,d4),dim=1)
413 | d4 = self.Up_RRCNN4(d4)
414 |
415 | d3 = self.Up3(d4)
416 | x2 = self.Att3(g=d3,x=x2)
417 | d3 = torch.cat((x2,d3),dim=1)
418 | d3 = self.Up_RRCNN3(d3)
419 |
420 | d2 = self.Up2(d3)
421 | x1 = self.Att2(g=d2,x=x1)
422 | d2 = torch.cat((x1,d2),dim=1)
423 | d2 = self.Up_RRCNN2(d2)
424 |
425 | d1 = self.Conv_1x1(d2)
426 |
427 | return d1
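
# ---------------------------------------------------------------------------
# Added sanity-check sketch (not part of the original file): builds the
# attention U-Net defined above and runs a dummy forward pass. The spatial
# size must be divisible by 16 because of the four 2x2 max-pool stages.
if __name__ == "__main__":
    net = AttU_Net(img_ch=3, output_ch=1)
    init_weights(net, init_type='kaiming')
    y = net(torch.randn(1, 3, 224, 224))
    print(y.shape)  # expected: torch.Size([1, 1, 224, 224])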
--------------------------------------------------------------------------------
/models/multiresunet.py:
--------------------------------------------------------------------------------
1 | # https://github.com/j-sripad/mulitresunet-pytorch/blob/main/multiresunet.py
2 |
3 | from typing import Tuple, Dict
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch
7 |
8 |
9 | class Multiresblock(nn.Module):
10 | def __init__(self, input_features: int, corresponding_unet_filters: int, alpha: float = 1.67) -> None:
11 | """
12 | MultiResblock
13 | Arguments:
14 | input_features - number of input channels
15 | corresponding_unet_filters - U-Net filters for the same stage
16 | alpha - factor (1.67 in the paper) used to derive the MultiResUNet filter counts from the U-Net filters
17 | Returns - None
18 | """
19 | super().__init__()
20 | self.corresponding_unet_filters = corresponding_unet_filters
21 | self.alpha = alpha
22 | self.W = corresponding_unet_filters * alpha
23 | self.conv2d_bn_1x1 = Conv2d_batchnorm(input_features=input_features,num_of_filters = int(self.W*0.167)+int(self.W*0.333)+int(self.W*0.5),
24 | kernel_size = (1,1),activation='None',padding = 0)
25 |
26 | self.conv2d_bn_3x3 = Conv2d_batchnorm(input_features=input_features,num_of_filters = int(self.W*0.167),
27 | kernel_size = (3,3),activation='relu',padding = 1)
28 | self.conv2d_bn_5x5 = Conv2d_batchnorm(input_features=int(self.W*0.167),num_of_filters = int(self.W*0.333),
29 | kernel_size = (3,3),activation='relu',padding = 1)
30 | self.conv2d_bn_7x7 = Conv2d_batchnorm(input_features=int(self.W*0.333),num_of_filters = int(self.W*0.5),
31 | kernel_size = (3,3),activation='relu',padding = 1)
32 | self.batch_norm1 = nn.BatchNorm2d(int(self.W*0.5)+int(self.W*0.167)+int(self.W*0.333) ,affine=False)
33 |
34 | def forward(self,x: torch.Tensor)->torch.Tensor:
35 |
36 | temp = self.conv2d_bn_1x1(x)
37 | a = self.conv2d_bn_3x3(x)
38 | b = self.conv2d_bn_5x5(a)
39 | c = self.conv2d_bn_7x7(b)
40 | x = torch.cat([a,b,c],axis=1)
41 | x = self.batch_norm1(x)
42 | x = x + temp
43 | x = self.batch_norm1(x)
44 | return x
45 |
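# Added worked example: with corresponding_unet_filters = 32 and alpha = 1.67,
# W = 53.44, so the three stacked 3x3 branches produce int(0.167*W) = 8,
# int(0.333*W) = 17 and int(0.5*W) = 26 feature maps; their concatenation (and
# the 1x1 shortcut) therefore carries 8 + 17 + 26 = 51 channels, which is the
# same in_filters* arithmetic repeated in MultiResUnet.__init__ below.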
46 | class Conv2d_batchnorm(nn.Module):
47 | def __init__(self, input_features: int, num_of_filters: int, kernel_size: Tuple = (2,2), stride: Tuple = (1,1), activation: str = 'relu', padding: int = 0) -> None:
48 | """
49 | Arguments:
50 | input_features - number of input channels
51 | num_of_filters - number of output feature maps
52 | kernel_size - shape of the convolution kernel
53 | stride - stride of the convolution
54 | activation - activation function to apply ('relu' or 'None')
55 | Returns - None
56 | """
57 | super().__init__()
58 | self.activation = activation
59 | self.conv1 = nn.Conv2d(in_channels=input_features,out_channels=num_of_filters,kernel_size=kernel_size,stride=stride,padding = padding)
60 | self.batchnorm = nn.BatchNorm2d(num_of_filters,affine=False)
61 |
62 | def forward(self,x : torch.Tensor)->torch.Tensor:
63 | x = self.conv1(x)
64 | x = self.batchnorm(x)
65 | if self.activation == 'relu':
66 | return F.relu(x)
67 | else:
68 | return x
69 |
70 |
71 | class Respath(nn.Module):
72 | def __init__(self,input_features : int,filters : int,respath_length : int)->None:
73 | """
74 | Arguments:
75 | input_features - input layer filters
76 | filters - output channels
77 | respath_length - length of the Respath
78 |
79 | Returns - None
80 | """
81 | super().__init__()
82 | self.filters = filters
83 | self.respath_length = respath_length
84 | self.conv2d_bn_1x1 = Conv2d_batchnorm(input_features=input_features,num_of_filters = self.filters,
85 | kernel_size = (1,1),activation='None',padding = 0)
86 | self.conv2d_bn_3x3 = Conv2d_batchnorm(input_features=input_features,num_of_filters = self.filters,
87 | kernel_size = (3,3),activation='relu',padding = 1)
88 | self.conv2d_bn_1x1_common = Conv2d_batchnorm(input_features=self.filters,num_of_filters = self.filters,
89 | kernel_size = (1,1),activation='None',padding = 0)
90 | self.conv2d_bn_3x3_common = Conv2d_batchnorm(input_features=self.filters,num_of_filters = self.filters,
91 | kernel_size = (3,3),activation='relu',padding = 1)
92 | self.batch_norm1 = nn.BatchNorm2d(filters,affine=False)
93 |
94 | def forward(self,x : torch.Tensor)->torch.Tensor:
95 | shortcut = self.conv2d_bn_1x1(x)
96 | x = self.conv2d_bn_3x3(x)
97 | x = x + shortcut
98 | x = F.relu(x)
99 | x = self.batch_norm1(x)
100 | if self.respath_length>1:
101 | for i in range(self.respath_length):
102 | shortcut = self.conv2d_bn_1x1_common(x)
103 | x = self.conv2d_bn_3x3_common(x)
104 | x = x + shortcut
105 | x = F.relu(x)
106 | x = self.batch_norm1(x)
107 | return x
108 | else:
109 | return x
110 |
111 | class MultiResUnet(nn.Module):
112 | def __init__(self,channels : int,filters : int =32,nclasses : int =1)->None:
113 |
114 | """
115 | Arguments:
116 | channels - input image channels
117 | filters - number of base filters (the width of the corresponding U-Net)
118 | nclasses - number of classes
119 | Returns - None
120 | """
121 | super().__init__()
122 | self.alpha = 1.67
123 | self.filters = filters
124 | self.nclasses = nclasses
125 | self.multiresblock1 = Multiresblock(input_features=channels,corresponding_unet_filters=self.filters)
126 | self.pool1 = nn.MaxPool2d(2,stride= 2)
127 | self.in_filters1 = int(self.filters*self.alpha* 0.5)+int(self.filters*self.alpha*0.167)+int(self.filters*self.alpha*0.333)
128 | self.respath1 = Respath(input_features=self.in_filters1 ,filters=self.filters,respath_length=4)
129 | self.multiresblock2 = Multiresblock(input_features= self.in_filters1,corresponding_unet_filters=self.filters*2)
130 | self.pool2 = nn.MaxPool2d(2, 2)
131 | self.in_filters2 = int(self.filters*2*self.alpha* 0.5)+int(self.filters*2*self.alpha*0.167)+int(self.filters*2*self.alpha*0.333)
132 | self.respath2 = Respath(input_features=self.in_filters2,filters=self.filters*2,respath_length=3)
133 | self.multiresblock3 = Multiresblock(input_features= self.in_filters2,corresponding_unet_filters=self.filters*4)
134 | self.pool3 = nn.MaxPool2d(2, 2)
135 | self.in_filters3 = int(self.filters*4*self.alpha* 0.5)+int(self.filters*4*self.alpha*0.167)+int(self.filters*4*self.alpha*0.333)
136 | self.respath3 = Respath(input_features=self.in_filters3,filters=self.filters*4,respath_length=2)
137 | self.multiresblock4 = Multiresblock(input_features= self.in_filters3,corresponding_unet_filters=self.filters*8)
138 | self.pool4 = nn.MaxPool2d(2, 2)
139 | self.in_filters4 = int(self.filters*8*self.alpha* 0.5)+int(self.filters*8*self.alpha*0.167)+int(self.filters*8*self.alpha*0.333)
140 | self.respath4 = Respath(input_features=self.in_filters4,filters=self.filters*8,respath_length=1)
141 | self.multiresblock5 = Multiresblock(input_features= self.in_filters4,corresponding_unet_filters=self.filters*16)
142 | self.in_filters5 = int(self.filters*16*self.alpha* 0.5)+int(self.filters*16*self.alpha*0.167)+int(self.filters*16*self.alpha*0.333)
143 |
144 | #Decoder path
145 | self.upsample6 = nn.ConvTranspose2d(in_channels=self.in_filters5,out_channels=self.filters*8,kernel_size=(2,2),stride=(2,2),padding = 0)
146 | self.concat_filters1 = self.filters*8+self.filters*8
147 | self.multiresblock6 = Multiresblock(input_features=self.concat_filters1,corresponding_unet_filters=self.filters*8)
148 | self.in_filters6 = int(self.filters*8*self.alpha* 0.5)+int(self.filters*8*self.alpha*0.167)+int(self.filters*8*self.alpha*0.333)
149 | self.upsample7 = nn.ConvTranspose2d(in_channels=self.in_filters6,out_channels=self.filters*4,kernel_size=(2,2),stride=(2,2),padding = 0)
150 | self.concat_filters2 = self.filters*4+self.filters*4
151 | self.multiresblock7 = Multiresblock(input_features=self.concat_filters2,corresponding_unet_filters=self.filters*4)
152 | self.in_filters7 = int(self.filters*4*self.alpha* 0.5)+int(self.filters*4*self.alpha*0.167)+int(self.filters*4*self.alpha*0.333)
153 | self.upsample8 = nn.ConvTranspose2d(in_channels=self.in_filters7,out_channels=self.filters*2,kernel_size=(2,2),stride=(2,2),padding = 0)
154 | self.concat_filters3 = self.filters*2+self.filters*2
155 | self.multiresblock8 = Multiresblock(input_features=self.concat_filters3,corresponding_unet_filters=self.filters*2)
156 | self.in_filters8 = int(self.filters*2*self.alpha* 0.5)+int(self.filters*2*self.alpha*0.167)+int(self.filters*2*self.alpha*0.333)
157 | self.upsample9 = nn.ConvTranspose2d(in_channels=self.in_filters8,out_channels=self.filters,kernel_size=(2,2),stride=(2,2),padding = 0)
158 | self.concat_filters4 = self.filters+self.filters
159 | self.multiresblock9 = Multiresblock(input_features=self.concat_filters4,corresponding_unet_filters=self.filters)
160 | self.in_filters9 = int(self.filters*self.alpha* 0.5)+int(self.filters*self.alpha*0.167)+int(self.filters*self.alpha*0.333)
161 | self.conv_final = Conv2d_batchnorm(input_features=self.in_filters9,num_of_filters = self.nclasses,
162 | kernel_size = (1,1),activation='None')
163 |
164 | def forward(self,x : torch.Tensor)->torch.Tensor:
165 | x_multires1 = self.multiresblock1(x)
166 | x_pool1 = self.pool1(x_multires1)
167 | x_multires1 = self.respath1(x_multires1)
168 | x_multires2 = self.multiresblock2(x_pool1)
169 | x_pool2 = self.pool2(x_multires2)
170 | x_multires2 = self.respath2(x_multires2)
171 | x_multires3 = self.multiresblock3(x_pool2)
172 | x_pool3 = self.pool3(x_multires3)
173 | x_multires3 = self.respath3(x_multires3)
174 | x_multires4 = self.multiresblock4(x_pool3)
175 | x_pool4 = self.pool4(x_multires4)
176 | x_multires4 = self.respath4(x_multires4)
177 | x_multires5 = self.multiresblock5(x_pool4)
178 | up6 = torch.cat([self.upsample6(x_multires5),x_multires4],axis=1)
179 | x_multires6 = self.multiresblock6(up6)
180 | up7 = torch.cat([self.upsample7(x_multires6),x_multires3],axis=1)
181 | x_multires7 = self.multiresblock7(up7)
182 | up8 = torch.cat([self.upsample8(x_multires7),x_multires2],axis=1)
183 | x_multires8 = self.multiresblock8(up8)
184 | up9 = torch.cat([self.upsample9(x_multires8),x_multires1],axis=1)
185 | x_multires9 = self.multiresblock9(up9)
186 | if self.nclasses > 1:
187 | conv_final_layer = self.conv_final(x_multires9)
188 | else:
189 | conv_final_layer = torch.sigmoid(self.conv_final(x_multires9))
190 | return conv_final_layer
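
# ---------------------------------------------------------------------------
# Added sanity-check sketch (not part of the original file): with the default
# filters=32 the input height/width must be divisible by 16 (four max-pool
# stages); nclasses=1 routes the output through the sigmoid branch above.
if __name__ == "__main__":
    net = MultiResUnet(channels=3, filters=32, nclasses=1)
    y = net(torch.randn(1, 3, 256, 256))
    print(y.shape)  # expected: torch.Size([1, 1, 256, 256])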
--------------------------------------------------------------------------------
/models/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 |
6 | class DoubleConv(nn.Module):
7 | def __init__(self, in_channels, out_channels, with_bn=False):
8 | super().__init__()
9 | if with_bn:
10 | self.step = nn.Sequential(
11 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
12 | nn.BatchNorm2d(out_channels),
13 | nn.ReLU(),
14 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
15 | nn.BatchNorm2d(out_channels),
16 | nn.ReLU(),
17 | )
18 | else:
19 | self.step = nn.Sequential(
20 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
21 | nn.ReLU(),
22 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
23 | nn.ReLU(),
24 | )
25 |
26 | def forward(self, x):
27 | return self.step(x)
28 |
29 |
30 | class UNet(nn.Module):
31 | def __init__(self, in_channels, out_channels, with_bn=False):
32 | super().__init__()
33 | init_channels = 32
34 | self.out_channels = out_channels
35 |
36 | self.en_1 = DoubleConv(in_channels , init_channels , with_bn)
37 | self.en_2 = DoubleConv(1*init_channels, 2*init_channels, with_bn)
38 | self.en_3 = DoubleConv(2*init_channels, 4*init_channels, with_bn)
39 | self.en_4 = DoubleConv(4*init_channels, 8*init_channels, with_bn)
40 |
41 | self.de_1 = DoubleConv((4 + 8)*init_channels, 4*init_channels, with_bn)
42 | self.de_2 = DoubleConv((2 + 4)*init_channels, 2*init_channels, with_bn)
43 | self.de_3 = DoubleConv((1 + 2)*init_channels, 1*init_channels, with_bn)
44 | self.de_4 = nn.Conv2d(init_channels, out_channels, 1)
45 |
46 | self.maxpool = nn.MaxPool2d(kernel_size=2)
47 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')
48 |
49 | def forward(self, x):
50 | e1 = self.en_1(x)
51 | e2 = self.en_2(self.maxpool(e1))
52 | e3 = self.en_3(self.maxpool(e2))
53 | e4 = self.en_4(self.maxpool(e3))
54 |
55 | d1 = self.de_1(torch.cat([self.upsample(e4), e3], dim=1))
56 | d2 = self.de_2(torch.cat([self.upsample(d1), e2], dim=1))
57 | d3 = self.de_3(torch.cat([self.upsample(d2), e1], dim=1))
58 | d4 = self.de_4(d3)
59 |
60 | return d4
61 |
62 | # if self.out_channels<2:
63 | # return torch.sigmoid(d4)
64 | # return torch.softmax(d4, 1)
65 |
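# ---------------------------------------------------------------------------
# Added sanity-check sketch (not part of the original file): the three
# max-pool / upsample stages mean the input size must be divisible by 8.
if __name__ == "__main__":
    net = UNet(in_channels=3, out_channels=2, with_bn=True)
    y = net(torch.randn(1, 3, 224, 224))
    print(y.shape)  # expected: torch.Size([1, 2, 224, 224])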
--------------------------------------------------------------------------------
/models/unetpp.py:
--------------------------------------------------------------------------------
1 | # https://github.com/4uiiurz1/pytorch-nested-unet/blob/master/archs.py (unetpp)
2 |
3 |
4 | import torch
5 | from torch import nn
6 | from torch.nn.functional import softmax, sigmoid
7 |
8 |
9 | __all__ = ['UNet', 'NestedUNet']
10 |
11 |
12 | class VGGBlock(nn.Module):
13 | def __init__(self, in_channels, middle_channels, out_channels):
14 | super().__init__()
15 | self.relu = nn.ReLU(inplace=True)
16 | self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
17 | self.bn1 = nn.BatchNorm2d(middle_channels)
18 | self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
19 | self.bn2 = nn.BatchNorm2d(out_channels)
20 |
21 | def forward(self, x):
22 | out = self.conv1(x)
23 | out = self.bn1(out)
24 | out = self.relu(out)
25 |
26 | out = self.conv2(out)
27 | out = self.bn2(out)
28 | out = self.relu(out)
29 |
30 | return out
31 |
32 |
33 | class UNet(nn.Module):
34 | def __init__(self, num_classes, input_channels=3, **kwargs):
35 | super().__init__()
36 |
37 | nb_filter = [32, 64, 128, 256, 512]
38 |
39 | self.pool = nn.MaxPool2d(2, 2)
40 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
41 |
42 | self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
43 | self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
44 | self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
45 | self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
46 | self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])
47 |
48 | self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
49 | self.conv2_2 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
50 | self.conv1_3 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
51 | self.conv0_4 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
52 |
53 | self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
54 |
55 |
56 | def forward(self, input):
57 | x0_0 = self.conv0_0(input)
58 | x1_0 = self.conv1_0(self.pool(x0_0))
59 | x2_0 = self.conv2_0(self.pool(x1_0))
60 | x3_0 = self.conv3_0(self.pool(x2_0))
61 | x4_0 = self.conv4_0(self.pool(x3_0))
62 |
63 | x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
64 | x2_2 = self.conv2_2(torch.cat([x2_0, self.up(x3_1)], 1))
65 | x1_3 = self.conv1_3(torch.cat([x1_0, self.up(x2_2)], 1))
66 | x0_4 = self.conv0_4(torch.cat([x0_0, self.up(x1_3)], 1))
67 |
68 | output = self.final(x0_4)
69 | return output
70 |
71 |
72 | class NestedUNet(nn.Module):
73 | def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs):
74 | super().__init__()
75 |
76 | nb_filter = [32, 64, 128, 256, 512]
77 |
78 | self.deep_supervision = deep_supervision
79 |
80 | self.pool = nn.MaxPool2d(2, 2)
81 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
82 |
83 | self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
84 | self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
85 | self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
86 | self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
87 | self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])
88 |
89 | self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
90 | self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
91 | self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
92 | self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
93 |
94 | self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
95 | self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
96 | self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])
97 |
98 | self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
99 | self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])
100 |
101 | self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])
102 |
103 | if self.deep_supervision:
104 | self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
105 | self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
106 | self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
107 | self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
108 | else:
109 | self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
110 |
111 |
112 | def forward(self, input):
113 | x0_0 = self.conv0_0(input)
114 | x1_0 = self.conv1_0(self.pool(x0_0))
115 | x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))
116 |
117 | x2_0 = self.conv2_0(self.pool(x1_0))
118 | x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
119 | x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))
120 |
121 | x3_0 = self.conv3_0(self.pool(x2_0))
122 | x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
123 | x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
124 | x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))
125 |
126 | x4_0 = self.conv4_0(self.pool(x3_0))
127 | x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
128 | x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
129 | x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
130 | x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))
131 |
132 | if self.deep_supervision:
133 | output1 = self.final1(x0_1)
134 | output2 = self.final2(x0_2)
135 | output3 = self.final3(x0_3)
136 | output4 = self.final4(x0_4)
137 | return [output1, output2, output3, output4]
138 |
139 | else:
140 | output = self.final(x0_4)
141 | return output
142 |
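# ---------------------------------------------------------------------------
# Added sanity-check sketch (not part of the original file): with
# deep_supervision=True NestedUNet returns one prediction per nested decoder
# depth; the input size must be divisible by 16 (four pooling stages).
if __name__ == "__main__":
    net = NestedUNet(num_classes=2, input_channels=3, deep_supervision=True)
    outs = net(torch.randn(1, 3, 224, 224))
    print([o.shape for o in outs])  # four tensors of shape [1, 2, 224, 224]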
--------------------------------------------------------------------------------
/train_and_test/isic/attunet-isic.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # # AttUNet - ISIC2018
5 | # ---
6 |
7 | # ## Import packages & functions
8 |
9 | # In[7]:
10 |
11 |
12 | from __future__ import print_function, division
13 |
14 |
15 | import os
16 | import sys
17 | sys.path.append('../..')
18 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
19 |
20 | import copy
21 | import json
22 | import importlib
23 | import glob
24 | import pandas as pd
25 | from skimage import io, transform
26 | import matplotlib.pyplot as plt
27 | from matplotlib.image import imread
28 | import numpy as np
29 | from tqdm import tqdm
30 |
31 | import torch
32 | import torch.nn as nn
33 | import torch.nn.functional as F
34 | import torch.optim as optim
35 | import torchmetrics
36 | from torch.optim import Adam, SGD
37 | from losses import DiceLoss, DiceLossWithLogtis
38 | from torch.nn import BCELoss, CrossEntropyLoss
39 |
40 | from utils import (
41 | show_sbs,
42 | load_config,
43 | _print,
44 | )
45 |
46 | # Ignore warnings
47 | import warnings
48 | warnings.filterwarnings("ignore")
49 |
50 | # plt.ion() # interactive mode
51 |
52 |
53 | # ## Set the seed
54 |
55 | # In[2]:
56 |
57 |
58 | torch.manual_seed(0)
59 | np.random.seed(0)
60 | torch.cuda.manual_seed(0)
61 | import random
62 | random.seed(0)
63 |
64 |
65 | # ## Load the config
66 |
67 | # In[3]:
68 |
69 |
70 | CONFIG_NAME = "isic/isic2018_attunet.yaml"
71 | CONFIG_FILE_PATH = os.path.join("../../configs", CONFIG_NAME)
72 |
73 |
74 | # In[4]:
75 |
76 |
77 | config = load_config(CONFIG_FILE_PATH)
78 | _print("Config:", "info_underline")
79 | print(json.dumps(config, indent=2))
80 | print(20*"~-", "\n")
81 |
82 |
83 | # ## Dataset and Dataloader
84 |
85 | # In[6]:
86 |
87 |
88 | from datasets.isic import ISIC2018DatasetFast
89 | from torch.utils.data import DataLoader, Subset
90 | from torchvision import transforms
91 |
92 |
93 | # In[7]:
94 |
95 |
96 | # ------------------- params --------------------
97 | INPUT_SIZE = config['dataset']['input_size']
98 | # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
99 |
100 |
101 | # ----------------- dataset --------------------
102 | # preparing training dataset
103 | tr_dataset = ISIC2018DatasetFast(mode="tr", one_hot=True)
104 | vl_dataset = ISIC2018DatasetFast(mode="vl", one_hot=True)
105 | te_dataset = ISIC2018DatasetFast(mode="te", one_hot=True)
106 |
107 | # We consider 1815 samples for training, 259 samples for validation and 520 samples for testing
108 | # !cat ~/deeplearning/skin/Prepare_ISIC2018.py
109 |
110 | print(f"Length of training_dataset:\t{len(tr_dataset)}")
111 | print(f"Length of validation_dataset:\t{len(vl_dataset)}")
112 | print(f"Length of test_dataset:\t\t{len(te_dataset)}")
113 |
114 |
115 | # prepare train dataloader
116 | tr_dataloader = DataLoader(tr_dataset, **config['data_loader']['train'])
117 |
118 | # prepare validation dataloader
119 | vl_dataloader = DataLoader(vl_dataset, **config['data_loader']['validation'])
120 |
121 | # prepare test dataloader
122 | te_dataloader = DataLoader(te_dataset, **config['data_loader']['test'])
123 |
124 | # -------------- test -----------------
125 | # test and visualize the input data
126 | for sample in tr_dataloader:
127 | img = sample['image']
128 | msk = sample['mask']
129 | print("\n Training")
130 | show_sbs(img[0], msk[0,1])
131 | break
132 |
133 | for sample in vl_dataloader:
134 | img = sample['image']
135 | msk = sample['mask']
136 | print("Validation")
137 | show_sbs(img[0], msk[0,1])
138 | break
139 |
140 | for sample in te_dataloader:
141 | img = sample['image']
142 | msk = sample['mask']
143 | print("Test")
144 | show_sbs(img[0], msk[0,1])
145 | break
146 |
147 |
148 | # ### Device
149 |
150 | # In[9]:
151 |
152 |
153 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
154 | print(f"Torch device: {device}")
155 |
156 |
157 | # ## Metrics
158 |
159 | # In[ ]:
160 |
161 |
162 | metrics = torchmetrics.MetricCollection(
163 | [
164 | torchmetrics.F1Score(),
165 | torchmetrics.Accuracy(),
166 | torchmetrics.Dice(),
167 | torchmetrics.Precision(),
168 | torchmetrics.Specificity(),
169 | torchmetrics.Recall(),
170 | # IoU
171 | torchmetrics.JaccardIndex(2)
172 | ],
173 | prefix='train_metrics/'
174 | )
175 |
176 | # train_metrics
177 | train_metrics = metrics.clone(prefix='train_metrics/').to(device)
178 |
179 | # valid_metrics
180 | valid_metrics = metrics.clone(prefix='valid_metrics/').to(device)
181 |
182 | # test_metrics
183 | test_metrics = metrics.clone(prefix='test_metrics/').to(device)
184 |
185 |
186 | # In[13]:
187 |
188 |
189 | def make_serializeable_metrics(computed_metrics):
190 | res = {}
191 | for k, v in computed_metrics.items():
192 | res[k] = float(v.cpu().detach().numpy())
193 | return res
194 |
195 |
196 | # ## Define validate function
197 |
198 | # In[14]:
199 |
200 |
201 | def validate(model, criterion, vl_dataloader):
202 | model.eval()
203 | with torch.no_grad():
204 |
205 | evaluator = valid_metrics.clone().to(device)
206 |
207 | losses = []
208 | cnt = 0.
209 | for batch, batch_data in enumerate(vl_dataloader):
210 | imgs = batch_data['image']
211 | msks = batch_data['mask']
212 |
213 | cnt += msks.shape[0]
214 |
215 | imgs = imgs.to(device)
216 | msks = msks.to(device)
217 |
218 | preds = model(imgs)
219 | loss = criterion(preds, msks)
220 | losses.append(loss.item())
221 |
222 |
223 | preds_ = torch.argmax(preds, 1, keepdim=False).float()
224 | msks_ = torch.argmax(msks, 1, keepdim=False)
225 | evaluator.update(preds_, msks_)
226 |
227 | # _cml = f"curr_mean-loss:{np.sum(losses)/cnt:0.5f}"
228 | # _bl = f"batch-loss:{losses[-1]/msks.shape[0]:0.5f}"
229 | # iterator.set_description(f"Validation) batch:{batch+1:04d} -> {_cml}, {_bl}")
230 |
231 | # print the final results
232 | loss = np.sum(losses)/cnt
233 | metrics = evaluator.compute()
234 |
235 | return evaluator, loss
236 |
237 |
238 | # ## Define train function
239 |
240 | # In[15]:
241 |
242 |
243 | def train(
244 | model,
245 | device,
246 | tr_dataloader,
247 | vl_dataloader,
248 | config,
249 |
250 | criterion,
251 | optimizer,
252 | scheduler,
253 |
254 | save_dir='./',
255 | save_file_id=None,
256 | ):
257 |
258 | EPOCHS = config['training']['epochs']
259 |
260 | torch.cuda.empty_cache()
261 | model = model.to(device)
262 |
263 | evaluator = train_metrics.clone().to(device)
264 |
265 | epochs_info = []
266 | best_model = None
267 | best_result = {}
268 | best_vl_loss = np.Inf
269 | for epoch in range(EPOCHS):
270 | model.train()
271 |
272 | evaluator.reset()
273 | tr_iterator = tqdm(enumerate(tr_dataloader))
274 | tr_losses = []
275 | cnt = 0
276 | for batch, batch_data in tr_iterator:
277 | imgs = batch_data['image']
278 | msks = batch_data['mask']
279 |
280 | imgs = imgs.to(device)
281 | msks = msks.to(device)
282 |
283 | optimizer.zero_grad()
284 | preds = model(imgs)
285 | loss = criterion(preds, msks)
286 | loss.backward()
287 | optimizer.step()
288 |
289 | # evaluate by metrics
290 | preds_ = torch.argmax(preds, 1, keepdim=False).float()
291 | msks_ = torch.argmax(msks, 1, keepdim=False)
292 | evaluator.update(preds_, msks_)
293 |
294 | cnt += imgs.shape[0]
295 | tr_losses.append(loss.item())
296 |
297 | # write details for each training batch
298 | _cml = f"curr_mean-loss:{np.sum(tr_losses)/cnt:0.5f}"
299 | _bl = f"mean_batch-loss:{tr_losses[-1]/imgs.shape[0]:0.5f}"
300 | tr_iterator.set_description(f"Training) ep:{epoch:03d}, batch:{batch+1:04d} -> {_cml}, {_bl}")
301 |
302 | tr_loss = np.sum(tr_losses)/cnt
303 |
304 | # validate model
305 | vl_metrics, vl_loss = validate(model, criterion, vl_dataloader)
306 | if vl_loss < best_vl_loss:
307 | # found a better model; keep a snapshot of it
308 | best_model = copy.deepcopy(model)
309 | best_vl_loss = vl_loss
310 | best_result = {
311 | 'tr_loss': tr_loss,
312 | 'vl_loss': vl_loss,
313 | 'tr_metrics': make_serializeable_metrics(evaluator.compute()),
314 | 'vl_metrics': make_serializeable_metrics(vl_metrics.compute())
315 | }
316 |
317 | # write the final results
318 | epoch_info = {
319 | 'tr_loss': tr_loss,
320 | 'vl_loss': vl_loss,
321 | 'tr_metrics': make_serializeable_metrics(evaluator.compute()),
322 | 'vl_metrics': make_serializeable_metrics(vl_metrics.compute())
323 | }
324 | epochs_info.append(epoch_info)
325 | # epoch_tqdm.set_description(f"Epoch:{epoch+1}/{EPOCHS} -> tr_loss:{tr_loss}, vl_loss:{vl_loss}")
326 | evaluator.reset()
327 |
328 | scheduler.step(vl_loss)
329 |
330 | # save final results
331 | res = {
332 | 'id': save_file_id,
333 | 'config': config,
334 | 'epochs_info': epochs_info,
335 | 'best_result': best_result
336 | }
337 | fn = f"{save_file_id+'_' if save_file_id else ''}result.json"
338 | fp = os.path.join(config['model']['save_dir'],fn)
339 | with open(fp, "w") as write_file:
340 | json.dump(res, write_file, indent=4)
341 |
342 | # save model's state_dict
343 | fn = "last_model_state_dict.pt"
344 | fp = os.path.join(config['model']['save_dir'],fn)
345 | torch.save(model.state_dict(), fp)
346 |
347 | # save the best model's state_dict
348 | fn = "best_model_state_dict.pt"
349 | fp = os.path.join(config['model']['save_dir'], fn)
350 | torch.save(best_model.state_dict(), fp)
351 |
352 | return best_model, model, res
353 |
354 |
355 | # ## Define test function
356 |
357 | # In[17]:
358 |
359 |
360 | def test(model, te_dataloader):
361 | model.eval()
362 | with torch.no_grad():
363 | evaluator = test_metrics.clone().to(device)
364 | for batch_data in tqdm(te_dataloader):
365 | imgs = batch_data['image']
366 | msks = batch_data['mask']
367 |
368 | imgs = imgs.to(device)
369 | msks = msks.to(device)
370 |
371 | preds = model(imgs)
372 |
373 | # evaluate by metrics
374 | preds_ = torch.argmax(preds, 1, keepdim=False).float()
375 | msks_ = torch.argmax(msks, 1, keepdim=False)
376 | evaluator.update(preds_, msks_)
377 |
378 | return evaluator
379 |
380 |
381 | # ## Load and prepare model
382 |
383 | # In[18]:
384 |
385 |
386 | from models.attunet import AttU_Net as Net
387 |
388 |
389 | model = Net(**config['model']['params'])
390 | torch.cuda.empty_cache()
391 | model = model.to(device)
392 | print("Number of parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))
393 |
394 | os.makedirs(config['model']['save_dir'], exist_ok=True)
395 | model_path = f"{config['model']['save_dir']}/model_state_dict.pt"
396 |
397 | if config['model']['load_weights']:
398 | model.load_state_dict(torch.load(model_path))
399 | print("Loaded pre-trained weights...")
400 |
401 |
402 | # criterion_dice = DiceLoss()
403 | criterion_dice = DiceLossWithLogtis()
404 | # criterion_ce = BCELoss()
405 | criterion_ce = CrossEntropyLoss()
406 |
407 |
408 | def criterion(preds, masks):
409 | c_dice = criterion_dice(preds, masks)
410 | c_ce = criterion_ce(preds, masks)
411 | return 0.5*c_dice + 0.5*c_ce
412 |
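# Added note: `criterion` mixes the two losses with equal weight. `preds` are
# the raw model logits and `masks` are the one-hot float masks produced by the
# dataset (one_hot=True above). DiceLossWithLogtis is expected to normalise the
# logits internally (see losses.py), while CrossEntropyLoss is used here with
# probability-style (one-hot) targets, which requires PyTorch >= 1.10.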
413 | tr_prms = config['training']
414 | optimizer = globals()[tr_prms['optimizer']['name']](model.parameters(), **tr_prms['optimizer']['params'])
415 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', **tr_prms['scheduler'])
416 |
417 |
418 | # ## Start training
419 |
420 | # In[19]:
421 |
422 |
423 | best_model, model, res = train(
424 | model,
425 | device,
426 | tr_dataloader,
427 | vl_dataloader,
428 | config,
429 |
430 | criterion,
431 | optimizer,
432 | scheduler,
433 |
434 | save_dir = config['model']['save_dir'],
435 | save_file_id = None,
436 | )
437 |
438 |
439 | # In[20]:
440 |
441 |
442 | te_metrics = test(best_model, te_dataloader)
443 | te_metrics.compute()
444 |
445 |
446 | # In[21]:
447 |
448 |
449 | f"{config['model']['save_dir']}"
450 |
451 |
452 | # # Test the best inferred model
453 | # ----
454 |
455 | # ## Load the best model
456 |
457 | # In[22]:
458 |
459 |
460 | best_model = Net(**config['model']['params'])
461 | torch.cuda.empty_cache()
462 | best_model = best_model.to(device)
463 |
464 | fn = "best_model_state_dict.pt"
465 | os.makedirs(config['model']['save_dir'], exist_ok=True)
466 | model_path = f"{config['model']['save_dir']}/{fn}"
467 |
468 | best_model.load_state_dict(torch.load(model_path))
469 | print("Loaded best model weights...")
470 |
471 |
472 | # ## Evaluation
473 |
474 | # In[23]:
475 |
476 |
477 | te_metrics = test(best_model, te_dataloader)
478 | te_metrics.compute()
479 |
480 |
481 | # ## Plot graphs
482 |
483 | # In[24]:
484 |
485 |
486 | result_file_path = f"{config['model']['save_dir']}/result.json"
487 | with open(result_file_path, 'r') as f:
488 | results = json.loads(''.join(f.readlines()))
489 | epochs_info = results['epochs_info']
490 |
491 | tr_losses = [d['tr_loss'] for d in epochs_info]
492 | vl_losses = [d['vl_loss'] for d in epochs_info]
493 | tr_dice = [d['tr_metrics']['train_metrics/Dice'] for d in epochs_info]
494 | vl_dice = [d['vl_metrics']['valid_metrics/Dice'] for d in epochs_info]
495 | tr_js = [d['tr_metrics']['train_metrics/JaccardIndex'] for d in epochs_info]
496 | vl_js = [d['vl_metrics']['valid_metrics/JaccardIndex'] for d in epochs_info]
497 | tr_acc = [d['tr_metrics']['train_metrics/Accuracy'] for d in epochs_info]
498 | vl_acc = [d['vl_metrics']['valid_metrics/Accuracy'] for d in epochs_info]
499 |
500 |
501 | _, axs = plt.subplots(1, 4, figsize=[16,3])
502 |
503 | axs[0].set_title("Loss")
504 | axs[0].plot(tr_losses, 'r-', label="train loss")
505 | axs[0].plot(vl_losses, 'b-', label="validation loss")
506 | axs[0].legend()
507 |
508 | axs[1].set_title("Dice score")
509 | axs[1].plot(tr_dice, 'r-', label="train dice")
510 | axs[1].plot(vl_dice, 'b-', label="validation dice")
511 | axs[1].legend()
512 |
513 | axs[2].set_title("Jaccard Similarity")
514 | axs[2].plot(tr_js, 'r-', label="train JaccardIndex")
515 | axs[2].plot(vl_js, 'b-', label="validation JaccardIndex")
516 | axs[2].legend()
517 |
518 | axs[3].set_title("Accuracy")
519 | axs[3].plot(tr_acc, 'r-', label="train Accuracy")
520 | axs[3].plot(vl_acc, 'b-', label="validation Accuracy")
521 | axs[3].legend()
522 |
523 | plt.show()
524 |
525 |
526 | # In[25]:
527 |
528 |
529 | epochs_info
530 |
531 |
532 | # ## Save images
533 |
534 | # In[30]:
535 |
536 |
537 | from PIL import Image
538 | import cv2
539 | def skin_plot(img, gt, pred):
540 | img = np.array(img)
541 | gt = np.array(gt)
542 | pred = np.array(pred)
543 | edged_test = cv2.Canny(pred, 100, 255)
544 | contours_test, _ = cv2.findContours(edged_test, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
545 | edged_gt = cv2.Canny(gt, 100, 255)
546 | contours_gt, _ = cv2.findContours(edged_gt, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
547 | for cnt_test in contours_test:
548 | cv2.drawContours(img, [cnt_test], -1, (0, 0, 255), 1)
549 | for cnt_gt in contours_gt:
550 | cv2.drawContours(img, [cnt_gt], -1, (0,255,0), 1)
551 | return img
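# Added note: skin_plot overlays lesion boundaries on the input image. Contours
# of the predicted mask are drawn with colour (0, 0, 255) and contours of the
# ground-truth mask with (0, 255, 0); since the result is saved with PIL (RGB
# order) below, these render as blue and green respectively.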
552 |
553 | #---------------------------------------------------------------------------------------------
554 |
555 |
556 | save_imgs_dir = f"{config['model']['save_dir']}/visualized"
557 |
558 | if not os.path.isdir(save_imgs_dir):
559 | os.mkdir(save_imgs_dir)
560 |
561 | with torch.no_grad():
562 | for batch in tqdm(te_dataloader):
563 | imgs = batch['image']
564 | msks = batch['mask']
565 | ids = batch['id']
566 |
567 | preds = best_model(imgs.to(device))
568 |
569 | txm = imgs.cpu().numpy()
570 | tbm = torch.argmax(msks, 1).cpu().numpy()
571 | tpm = torch.argmax(preds, 1).cpu().numpy()
572 | tid = ids
573 |
574 | for idx in range(len(tbm)):
575 | img = np.moveaxis(txm[idx, :3], 0, -1)*255.
576 | img = np.ascontiguousarray(img, dtype=np.uint8)
577 | gt = np.uint8(tbm[idx]*255.)
578 | pred = np.where(tpm[idx]>0.5, 255, 0)
579 | pred = np.ascontiguousarray(pred, dtype=np.uint8)
580 |
581 | res_img = skin_plot(img, gt, pred)
582 |
583 | fid = tid[idx]
584 | Image.fromarray(img).save(f"{save_imgs_dir}/{fid}_img.png")
585 | Image.fromarray(res_img).save(f"{save_imgs_dir}/{fid}_img_gt_pred.png")
586 |
587 |
588 | # In[31]:
589 |
590 |
591 | f"{config['model']['save_dir']}/visualized"
592 |
593 |
--------------------------------------------------------------------------------
/train_and_test/isic/resunet-isic.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # # Residual UNet - ISIC2018
5 | # ---
6 |
7 | # ## Import packages & functions
8 |
9 | # In[1]:
10 |
11 |
12 | from __future__ import print_function, division
13 |
14 |
15 | import os
16 | import sys
17 | sys.path.append('../..')
18 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
19 |
20 | import copy
21 | import json
22 | import importlib
23 | import glob
24 | import pandas as pd
25 | from skimage import io, transform
26 | import matplotlib.pyplot as plt
27 | from matplotlib.image import imread
28 | import numpy as np
29 | from tqdm import tqdm
30 |
31 | import torch
32 | import torch.nn as nn
33 | import torch.nn.functional as F
34 | import torch.optim as optim
35 | import torchmetrics
36 | from torch.optim import Adam, SGD
37 | from losses import DiceLoss, DiceLossWithLogtis
38 | from torch.nn import BCELoss, CrossEntropyLoss
39 |
40 | from utils import (
41 | show_sbs,
42 | load_config,
43 | _print,
44 | )
45 |
46 | # Ignore warnings
47 | import warnings
48 | warnings.filterwarnings("ignore")
49 |
50 | # plt.ion() # interactive mode
51 |
52 |
53 | # ## Set the seed
54 |
55 | # In[2]:
56 |
57 |
58 | torch.manual_seed(0)
59 | np.random.seed(0)
60 | torch.cuda.manual_seed(0)
61 | import random
62 | random.seed(0)
63 |
64 |
65 | # ## Load the config
66 |
67 | # In[3]:
68 |
69 |
70 | CONFIG_NAME = "isic/isic2018_resunet.yaml"
71 | CONFIG_FILE_PATH = os.path.join("../../configs", CONFIG_NAME)
72 |
73 |
74 | # In[4]:
75 |
76 |
77 | config = load_config(CONFIG_FILE_PATH)
78 | _print("Config:", "info_underline")
79 | print(json.dumps(config, indent=2))
80 | print(20*"~-", "\n")
81 |
82 |
83 | # ## Dataset and Dataloader
84 |
85 | # In[6]:
86 |
87 |
88 | from datasets.isic import ISIC2018DatasetFast
89 | from torch.utils.data import DataLoader, Subset
90 | from torchvision import transforms
91 |
92 |
93 | # In[7]:
94 |
95 |
96 | # ------------------- params --------------------
97 | INPUT_SIZE = config['dataset']['input_size']
98 | # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
99 |
100 |
101 | # ----------------- dataset --------------------
102 | # preparing training dataset
103 | tr_dataset = ISIC2018DatasetFast(mode="tr", one_hot=True)
104 | vl_dataset = ISIC2018DatasetFast(mode="vl", one_hot=True)
105 | te_dataset = ISIC2018DatasetFast(mode="te", one_hot=True)
106 |
107 | # We consider 1815 samples for training, 259 samples for validation and 520 samples for testing
108 | # !cat ~/deeplearning/skin/Prepare_ISIC2018.py
109 |
110 | print(f"Length of training_dataset:\t{len(tr_dataset)}")
111 | print(f"Length of validation_dataset:\t{len(vl_dataset)}")
112 | print(f"Length of test_dataset:\t\t{len(te_dataset)}")
113 |
114 |
115 | # prepare train dataloader
116 | tr_dataloader = DataLoader(tr_dataset, **config['data_loader']['train'])
117 |
118 | # prepare validation dataloader
119 | vl_dataloader = DataLoader(vl_dataset, **config['data_loader']['validation'])
120 |
121 | # prepare test dataloader
122 | te_dataloader = DataLoader(te_dataset, **config['data_loader']['test'])
123 |
124 | # -------------- test -----------------
125 | # test and visualize the input data
126 | for sample in tr_dataloader:
127 | img = sample['image']
128 | msk = sample['mask']
129 | print("\n Training")
130 | show_sbs(img[0], msk[0,1])
131 | break
132 |
133 | for sample in vl_dataloader:
134 | img = sample['image']
135 | msk = sample['mask']
136 | print("Validation")
137 | show_sbs(img[0], msk[0,1])
138 | break
139 |
140 | for sample in te_dataloader:
141 | img = sample['image']
142 | msk = sample['mask']
143 | print("Test")
144 | show_sbs(img[0], msk[0,1])
145 | break
146 |
147 |
148 | # ### Device
149 |
150 | # In[9]:
151 |
152 |
153 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
154 | print(f"Torch device: {device}")
155 |
156 |
157 | # ## Metrics
158 |
159 | # In[12]:
160 |
161 |
162 | metrics = torchmetrics.MetricCollection(
163 | [
164 | torchmetrics.F1Score(),
165 | torchmetrics.Accuracy(),
166 | torchmetrics.Dice(),
167 | torchmetrics.Precision(),
168 | torchmetrics.Specificity(),
169 | torchmetrics.Recall(),
170 | # IoU
171 | torchmetrics.JaccardIndex(2)
172 | ],
173 | prefix='train_metrics/'
174 | )
175 |
176 | # train_metrics
177 | train_metrics = metrics.clone(prefix='train_metrics/').to(device)
178 |
179 | # valid_metrics
180 | valid_metrics = metrics.clone(prefix='valid_metrics/').to(device)
181 |
182 | # test_metrics
183 | test_metrics = metrics.clone(prefix='test_metrics/').to(device)
184 |
185 |
186 | # In[13]:
187 |
188 |
189 | def make_serializeable_metrics(computed_metrics):
190 | res = {}
191 | for k, v in computed_metrics.items():
192 | res[k] = float(v.cpu().detach().numpy())
193 | return res
194 |
195 |
196 | # ## Define validate function
197 |
198 | # In[14]:
199 |
200 |
201 | def validate(model, criterion, vl_dataloader):
202 | model.eval()
203 | with torch.no_grad():
204 |
205 | evaluator = valid_metrics.clone().to(device)
206 |
207 | losses = []
208 | cnt = 0.
209 | for batch, batch_data in enumerate(vl_dataloader):
210 | imgs = batch_data['image']
211 | msks = batch_data['mask']
212 |
213 | cnt += msks.shape[0]
214 |
215 | imgs = imgs.to(device)
216 | msks = msks.to(device)
217 |
218 | preds = model(imgs)
219 | loss = criterion(preds, msks)
220 | losses.append(loss.item())
221 |
222 |
223 | preds_ = torch.argmax(preds, 1, keepdim=False).float()
224 | msks_ = torch.argmax(msks, 1, keepdim=False)
225 | evaluator.update(preds_, msks_)
226 |
227 | # _cml = f"curr_mean-loss:{np.sum(losses)/cnt:0.5f}"
228 | # _bl = f"batch-loss:{losses[-1]/msks.shape[0]:0.5f}"
229 | # iterator.set_description(f"Validation) batch:{batch+1:04d} -> {_cml}, {_bl}")
230 |
231 | # print the final results
232 | loss = np.sum(losses)/cnt
233 | metrics = evaluator.compute()
234 |
235 | return evaluator, loss
236 |
237 |
238 | # ## Define train function
239 |
240 | # In[15]:
241 |
242 |
243 | def train(
244 | model,
245 | device,
246 | tr_dataloader,
247 | vl_dataloader,
248 | config,
249 |
250 | criterion,
251 | optimizer,
252 | scheduler,
253 |
254 | save_dir='./',
255 | save_file_id=None,
256 | ):
257 |
258 | EPOCHS = config['training']['epochs']
259 |
260 | torch.cuda.empty_cache()
261 | model = model.to(device)
262 |
263 | evaluator = train_metrics.clone().to(device)
264 |
265 | epochs_info = []
266 | best_model = None
267 | best_result = {}
268 | best_vl_loss = np.Inf
269 | for epoch in range(EPOCHS):
270 | model.train()
271 |
272 | evaluator.reset()
273 | tr_iterator = tqdm(enumerate(tr_dataloader))
274 | tr_losses = []
275 | cnt = 0
276 | for batch, batch_data in tr_iterator:
277 | imgs = batch_data['image']
278 | msks = batch_data['mask']
279 |
280 | imgs = imgs.to(device)
281 | msks = msks.to(device)
282 |
283 | optimizer.zero_grad()
284 | preds = model(imgs)
285 | loss = criterion(preds, msks)
286 | loss.backward()
287 | optimizer.step()
288 |
289 | # evaluate by metrics
290 | preds_ = torch.argmax(preds, 1, keepdim=False).float()
291 | msks_ = torch.argmax(msks, 1, keepdim=False)
292 | evaluator.update(preds_, msks_)
293 |
294 | cnt += imgs.shape[0]
295 | tr_losses.append(loss.item())
296 |
297 | # write details for each training batch
298 | _cml = f"curr_mean-loss:{np.sum(tr_losses)/cnt:0.5f}"
299 | _bl = f"mean_batch-loss:{tr_losses[-1]/imgs.shape[0]:0.5f}"
300 | tr_iterator.set_description(f"Training) ep:{epoch:03d}, batch:{batch+1:04d} -> {_cml}, {_bl}")
301 |
302 | tr_loss = np.sum(tr_losses)/cnt
303 |
304 | # validate model
305 | vl_metrics, vl_loss = validate(model, criterion, vl_dataloader)
306 | if vl_loss < best_vl_loss:
307 | # found a better model; keep a snapshot of it
308 | best_model = copy.deepcopy(model)
309 | best_vl_loss = vl_loss
310 | best_result = {
311 | 'tr_loss': tr_loss,
312 | 'vl_loss': vl_loss,
313 | 'tr_metrics': make_serializeable_metrics(evaluator.compute()),
314 | 'vl_metrics': make_serializeable_metrics(vl_metrics.compute())
315 | }
316 |
317 | # write the final results
318 | epoch_info = {
319 | 'tr_loss': tr_loss,
320 | 'vl_loss': vl_loss,
321 | 'tr_metrics': make_serializeable_metrics(evaluator.compute()),
322 | 'vl_metrics': make_serializeable_metrics(vl_metrics.compute())
323 | }
324 | epochs_info.append(epoch_info)
325 | # epoch_tqdm.set_description(f"Epoch:{epoch+1}/{EPOCHS} -> tr_loss:{tr_loss}, vl_loss:{vl_loss}")
326 | evaluator.reset()
327 |
328 | scheduler.step(vl_loss)
329 |
330 | # save final results
331 | res = {
332 | 'id': save_file_id,
333 | 'config': config,
334 | 'epochs_info': epochs_info,
335 | 'best_result': best_result
336 | }
337 | fn = f"{save_file_id+'_' if save_file_id else ''}result.json"
338 | fp = os.path.join(config['model']['save_dir'],fn)
339 | with open(fp, "w") as write_file:
340 | json.dump(res, write_file, indent=4)
341 |
342 | # save model's state_dict
343 | fn = "last_model_state_dict.pt"
344 | fp = os.path.join(config['model']['save_dir'],fn)
345 | torch.save(model.state_dict(), fp)
346 |
347 | # save the best model's state_dict
348 | fn = "best_model_state_dict.pt"
349 | fp = os.path.join(config['model']['save_dir'], fn)
350 | torch.save(best_model.state_dict(), fp)
351 |
352 | return best_model, model, res
353 |
354 |
355 | # ## Define test function
356 |
357 | # In[17]:
358 |
359 |
360 | def test(model, te_dataloader):
361 | model.eval()
362 | with torch.no_grad():
363 | evaluator = test_metrics.clone().to(device)
364 | for batch_data in tqdm(te_dataloader):
365 | imgs = batch_data['image']
366 | msks = batch_data['mask']
367 |
368 | imgs = imgs.to(device)
369 | msks = msks.to(device)
370 |
371 | preds = model(imgs)
372 |
373 | # evaluate by metrics
374 | preds_ = torch.argmax(preds, 1, keepdim=False).float()
375 | msks_ = torch.argmax(msks, 1, keepdim=False)
376 | evaluator.update(preds_, msks_)
377 |
378 | return evaluator
379 |
380 |
381 | # ## Load and prepare model
382 |
383 | # In[18]:
384 |
385 |
386 | from models._resunet.res_unet import ResUnet as Net
387 |
388 |
389 | model = Net(**config['model']['params'])
390 | torch.cuda.empty_cache()
391 | model = model.to(device)
392 | print("Number of parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))
393 |
394 | os.makedirs(config['model']['save_dir'], exist_ok=True)
395 | model_path = f"{config['model']['save_dir']}/model_state_dict.pt"
396 |
397 | if config['model']['load_weights']:
398 | model.load_state_dict(torch.load(model_path))
399 | print("Loaded pre-trained weights...")
400 |
401 |
402 | # criterion_dice = DiceLoss()
403 | criterion_dice = DiceLossWithLogtis()
404 | # criterion_ce = BCELoss()
405 | criterion_ce = CrossEntropyLoss()
406 |
407 |
408 | def criterion(preds, masks):
409 | c_dice = criterion_dice(preds, masks)
410 | c_ce = criterion_ce(preds, masks)
411 | return 0.5*c_dice + 0.5*c_ce
412 |
413 | tr_prms = config['training']
414 | optimizer = globals()[tr_prms['optimizer']['name']](model.parameters(), **tr_prms['optimizer']['params'])
415 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', **tr_prms['scheduler'])
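# Added note: the optimizer class is resolved by name from this module's
# globals(), so config['training']['optimizer']['name'] must match one of the
# classes imported above ('Adam' or 'SGD') and ['params'] is forwarded as
# keyword arguments (e.g. lr). ReduceLROnPlateau then lowers the learning rate
# when the validation loss passed to scheduler.step() stops improving.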
416 |
417 |
418 | # ## Start training
419 |
420 | # In[19]:
421 |
422 |
423 | best_model, model, res = train(
424 | model,
425 | device,
426 | tr_dataloader,
427 | vl_dataloader,
428 | config,
429 |
430 | criterion,
431 | optimizer,
432 | scheduler,
433 |
434 | save_dir = config['model']['save_dir'],
435 | save_file_id = None,
436 | )
437 |
438 |
439 | # In[20]:
440 |
441 |
442 | te_metrics = test(best_model, te_dataloader)
443 | te_metrics.compute()
444 |
445 |
446 | # In[21]:
447 |
448 |
449 | f"{config['model']['save_dir']}"
450 |
451 |
452 | # # Test the best inferred model
453 | # ----
454 |
455 | # ## Load the best model
456 |
457 | # In[22]:
458 |
459 |
460 | best_model = Net(**config['model']['params'])
461 | torch.cuda.empty_cache()
462 | best_model = best_model.to(device)
463 |
464 | fn = "best_model_state_dict.pt"
465 | os.makedirs(config['model']['save_dir'], exist_ok=True)
466 | model_path = f"{config['model']['save_dir']}/{fn}"
467 |
468 | best_model.load_state_dict(torch.load(model_path))
469 | print("Loaded best model weights...")
470 |
471 |
472 | # ## Evaluation
473 |
474 | # In[23]:
475 |
476 |
477 | te_metrics = test(best_model, te_dataloader)
478 | te_metrics.compute()
479 |
480 |
481 | # ## Plot graphs
482 |
483 | # In[24]:
484 |
485 |
486 | result_file_path = f"{config['model']['save_dir']}/result.json"
487 | with open(result_file_path, 'r') as f:
488 |     results = json.load(f)
489 | epochs_info = results['epochs_info']
490 |
491 | tr_losses = [d['tr_loss'] for d in epochs_info]
492 | vl_losses = [d['vl_loss'] for d in epochs_info]
493 | tr_dice = [d['tr_metrics']['train_metrics/Dice'] for d in epochs_info]
494 | vl_dice = [d['vl_metrics']['valid_metrics/Dice'] for d in epochs_info]
495 | tr_js = [d['tr_metrics']['train_metrics/JaccardIndex'] for d in epochs_info]
496 | vl_js = [d['vl_metrics']['valid_metrics/JaccardIndex'] for d in epochs_info]
497 | tr_acc = [d['tr_metrics']['train_metrics/Accuracy'] for d in epochs_info]
498 | vl_acc = [d['vl_metrics']['valid_metrics/Accuracy'] for d in epochs_info]
499 |
500 |
501 | _, axs = plt.subplots(1, 4, figsize=[16,3])
502 |
503 | axs[0].set_title("Loss")
504 | axs[0].plot(tr_losses, 'r-', label="train loss")
505 | axs[0].plot(vl_losses, 'b-', label="validation loss")
506 | axs[0].legend()
507 |
508 | axs[1].set_title("Dice score")
509 | axs[1].plot(tr_dice, 'r-', label="train dice")
510 | axs[1].plot(vl_dice, 'b-', label="validation dice")
511 | axs[1].legend()
512 |
513 | axs[2].set_title("Jaccard Similarity")
514 | axs[2].plot(tr_js, 'r-', label="train JaccardIndex")
515 | axs[2].plot(vl_js, 'b-', label="validation JaccardIndex")
516 | axs[2].legend()
517 |
518 | axs[3].set_title("Accuracy")
519 | axs[3].plot(tr_acc, 'r-', label="train Accuracy")
520 | axs[3].plot(vl_acc, 'b-', label="validation Accuracy")
521 | axs[3].legend()
522 |
523 | plt.show()
524 |
525 |
526 | # In[25]:
527 |
528 |
529 | epochs_info
530 |
531 |
532 | # ## Save images
533 |
534 | # In[28]:
535 |
536 |
537 | from PIL import Image
538 | import cv2
539 | def skin_plot(img, gt, pred):  # overlay ground-truth and prediction contours on the image
540 | img = np.array(img)
541 | gt = np.array(gt)
542 | pred = np.array(pred)
543 | edged_test = cv2.Canny(pred, 100, 255)
544 | contours_test, _ = cv2.findContours(edged_test, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
545 | edged_gt = cv2.Canny(gt, 100, 255)
546 | contours_gt, _ = cv2.findContours(edged_gt, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
547 | for cnt_test in contours_test:
548 | cv2.drawContours(img, [cnt_test], -1, (0, 0, 255), 1)
549 | for cnt_gt in contours_gt:
550 | cv2.drawContours(img, [cnt_gt], -1, (0,255,0), 1)
551 | return img
552 |
553 | #---------------------------------------------------------------------------------------------
554 |
555 |
556 | save_imgs_dir = f"{config['model']['save_dir']}/visualized"
557 |
558 | if not os.path.isdir(save_imgs_dir):
559 | os.mkdir(save_imgs_dir)
560 |
561 | with torch.no_grad():
562 | for batch in tqdm(te_dataloader):
563 | imgs = batch['image']
564 | msks = batch['mask']
565 | ids = batch['id']
566 |
567 | preds = best_model(imgs.to(device))
568 |
569 | txm = imgs.cpu().numpy()
570 | tbm = torch.argmax(msks, 1).cpu().numpy()
571 | tpm = torch.argmax(preds, 1).cpu().numpy()
572 | tid = ids
573 |
574 | for idx in range(len(tbm)):
575 | img = np.moveaxis(txm[idx, :3], 0, -1)*255.
576 | img = np.ascontiguousarray(img, dtype=np.uint8)
577 | gt = np.uint8(tbm[idx]*255.)
578 | pred = np.where(tpm[idx]>0.5, 255, 0)
579 | pred = np.ascontiguousarray(pred, dtype=np.uint8)
580 |
581 | res_img = skin_plot(img, gt, pred)
582 |
583 | fid = tid[idx]
584 | Image.fromarray(img).save(f"{save_imgs_dir}/{fid}_img.png")
585 | Image.fromarray(res_img).save(f"{save_imgs_dir}/{fid}_img_gt_pred.png")
586 |
587 |
588 | # In[29]:
589 |
590 |
591 | f"{config['model']['save_dir']}/visualized"
592 |
593 |
--------------------------------------------------------------------------------
/train_and_test/isic/unet-isic.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # # UNet - ISIC2018
5 | # ---
6 |
7 | # ## Import packages & functions
8 |
9 | # In[1]:
10 |
11 |
12 | from __future__ import print_function, division
13 |
14 |
15 | import os
16 | import sys
17 | sys.path.append('../..')
18 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
19 |
20 | import copy
21 | import json
22 | import importlib
23 | import glob
24 | import pandas as pd
25 | from skimage import io, transform
26 | import matplotlib.pyplot as plt
27 | from matplotlib.image import imread
28 | import numpy as np
29 | from tqdm import tqdm
30 |
31 | import torch
32 | import torch.nn as nn
33 | import torch.nn.functional as F
34 | import torch.optim as optim
35 | import torchmetrics
36 | from torch.optim import Adam, SGD
37 | from losses import DiceLoss, DiceLossWithLogtis
38 | from torch.nn import BCELoss, CrossEntropyLoss
39 |
40 | from utils import (
41 | show_sbs,
42 | load_config,
43 | _print,
44 | )
45 |
46 | # Ignore warnings
47 | import warnings
48 | warnings.filterwarnings("ignore")
49 |
50 | # plt.ion() # interactive mode
51 |
52 |
53 | # ## Set the seed
54 |
55 | # In[2]:
56 |
57 |
58 | torch.manual_seed(0)
59 | np.random.seed(0)
60 | torch.cuda.manual_seed(0)
61 | import random
62 | random.seed(0)
63 |
64 |
65 | # ## Load the config
66 |
67 | # In[3]:
68 |
69 |
70 | CONFIG_NAME = "isic/isic2018_unet.yaml"
71 | CONFIG_FILE_PATH = os.path.join("../../configs", CONFIG_NAME)
72 |
73 |
74 | # In[4]:
75 |
76 |
77 | config = load_config(CONFIG_FILE_PATH)
78 | _print("Config:", "info_underline")
79 | print(json.dumps(config, indent=2))
80 | print(20*"~-", "\n")
81 |
82 |
83 | # ## Dataset and Dataloader
84 |
85 | # In[6]:
86 |
87 |
88 | from datasets.isic import ISIC2018DatasetFast
89 | from torch.utils.data import DataLoader, Subset
90 | from torchvision import transforms
91 |
92 |
93 | # In[7]:
94 |
95 |
96 | # ------------------- params --------------------
97 | INPUT_SIZE = config['dataset']['input_size']
98 | # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
99 |
100 |
101 | # ----------------- dataset --------------------
102 | # preparing training dataset
103 | tr_dataset = ISIC2018DatasetFast(mode="tr", one_hot=True)
104 | vl_dataset = ISIC2018DatasetFast(mode="vl", one_hot=True)
105 | te_dataset = ISIC2018DatasetFast(mode="te", one_hot=True)
106 |
107 | # We consider 1815 samples for training, 259 samples for validation and 520 samples for testing
108 | # !cat ~/deeplearning/skin/Prepare_ISIC2018.py
109 |
110 | print(f"Length of training_dataset:\t{len(tr_dataset)}")
111 | print(f"Length of validation_dataset:\t{len(vl_dataset)}")
112 | print(f"Length of test_dataset:\t\t{len(te_dataset)}")
113 |
114 |
115 | # prepare train dataloader
116 | tr_dataloader = DataLoader(tr_dataset, **config['data_loader']['train'])
117 |
118 | # prepare validation dataloader
119 | vl_dataloader = DataLoader(vl_dataset, **config['data_loader']['validation'])
120 |
121 | # prepare test dataloader
122 | te_dataloader = DataLoader(te_dataset, **config['data_loader']['test'])
123 |
124 | # -------------- test -----------------
125 | # test and visualize the input data
126 | for sample in tr_dataloader:
127 | img = sample['image']
128 | msk = sample['mask']
129 | print("\n Training")
130 | show_sbs(img[0], msk[0,1])
131 | break
132 |
133 | for sample in vl_dataloader:
134 | img = sample['image']
135 | msk = sample['mask']
136 | print("Validation")
137 | show_sbs(img[0], msk[0,1])
138 | break
139 |
140 | for sample in te_dataloader:
141 | img = sample['image']
142 | msk = sample['mask']
143 | print("Test")
144 | show_sbs(img[0], msk[0,1])
145 | break
146 |
147 |
148 | # ### Device
149 |
150 | # In[9]:
151 |
152 |
153 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
154 | print(f"Torch device: {device}")
155 |
156 |
157 | # ## Metrics
158 |
159 | # In[12]:
160 |
161 |
162 | metrics = torchmetrics.MetricCollection(
163 | [
164 | torchmetrics.F1Score(),
165 | torchmetrics.Accuracy(),
166 | torchmetrics.Dice(),
167 | torchmetrics.Precision(),
168 | torchmetrics.Recall(),
169 | torchmetrics.Specificity(),
170 | # IoU
171 | torchmetrics.JaccardIndex(2)
172 | ],
173 | prefix='train_metrics/'
174 | )
175 |
176 | # train_metrics
177 | train_metrics = metrics.clone(prefix='train_metrics/').to(device)
178 |
179 | # valid_metrics
180 | valid_metrics = metrics.clone(prefix='valid_metrics/').to(device)
181 |
182 | # test_metrics
183 | test_metrics = metrics.clone(prefix='test_metrics/').to(device)
184 |
185 |
186 | # In[13]:
187 |
188 |
189 | def make_serializeable_metrics(computed_metrics):
190 | res = {}
191 | for k, v in computed_metrics.items():
192 | res[k] = float(v.cpu().detach().numpy())
193 | return res
194 |
195 |
196 | # ## Define validate function
197 |
198 | # In[14]:
199 |
200 |
201 | def validate(model, criterion, vl_dataloader):
202 | model.eval()
203 | with torch.no_grad():
204 |
205 | evaluator = valid_metrics.clone().to(device)
206 |
207 | losses = []
208 | cnt = 0.
209 | for batch, batch_data in enumerate(vl_dataloader):
210 | imgs = batch_data['image']
211 | msks = batch_data['mask']
212 |
213 | cnt += msks.shape[0]
214 |
215 | imgs = imgs.to(device)
216 | msks = msks.to(device)
217 |
218 | preds = model(imgs)
219 | loss = criterion(preds, msks)
220 | losses.append(loss.item())
221 |
222 |
223 | preds_ = torch.argmax(preds, 1, keepdim=False).float()
224 | msks_ = torch.argmax(msks, 1, keepdim=False)
225 | evaluator.update(preds_, msks_)
226 |
227 | # _cml = f"curr_mean-loss:{np.sum(losses)/cnt:0.5f}"
228 | # _bl = f"batch-loss:{losses[-1]/msks.shape[0]:0.5f}"
229 | # iterator.set_description(f"Validation) batch:{batch+1:04d} -> {_cml}, {_bl}")
230 |
231 |         # compute the final results
232 | loss = np.sum(losses)/cnt
233 | metrics = evaluator.compute()
234 |
235 | return evaluator, loss
236 |
237 |
238 | # ## Define train function
239 |
240 | # In[15]:
241 |
242 |
243 | def train(
244 | model,
245 | device,
246 | tr_dataloader,
247 | vl_dataloader,
248 | config,
249 |
250 | criterion,
251 | optimizer,
252 | scheduler,
253 |
254 | save_dir='./',
255 | save_file_id=None,
256 | ):
257 |
258 | EPOCHS = tr_prms['epochs']
259 |
260 | torch.cuda.empty_cache()
261 | model = model.to(device)
262 |
263 | evaluator = train_metrics.clone().to(device)
264 |
265 | epochs_info = []
266 | best_model = None
267 | best_result = {}
268 | best_vl_loss = np.Inf
269 | for epoch in range(EPOCHS):
270 | model.train()
271 |
272 | evaluator.reset()
273 | tr_iterator = tqdm(enumerate(tr_dataloader))
274 | tr_losses = []
275 | cnt = 0
276 | for batch, batch_data in tr_iterator:
277 | imgs = batch_data['image']
278 | msks = batch_data['mask']
279 |
280 | imgs = imgs.to(device)
281 | msks = msks.to(device)
282 |
283 | optimizer.zero_grad()
284 | preds = model(imgs)
285 | loss = criterion(preds, msks)
286 | loss.backward()
287 | optimizer.step()
288 |
289 | # evaluate by metrics
290 | preds_ = torch.argmax(preds, 1, keepdim=False).float()
291 | msks_ = torch.argmax(msks, 1, keepdim=False)
292 | evaluator.update(preds_, msks_)
293 |
294 | cnt += imgs.shape[0]
295 | tr_losses.append(loss.item())
296 |
297 | # write details for each training batch
298 | _cml = f"curr_mean-loss:{np.sum(tr_losses)/cnt:0.5f}"
299 | _bl = f"mean_batch-loss:{tr_losses[-1]/imgs.shape[0]:0.5f}"
300 | tr_iterator.set_description(f"Training) ep:{epoch:03d}, batch:{batch+1:04d} -> {_cml}, {_bl}")
301 |
302 | tr_loss = np.sum(tr_losses)/cnt
303 |
304 | # validate model
305 | vl_metrics, vl_loss = validate(model, criterion, vl_dataloader)
306 | if vl_loss < best_vl_loss:
307 |             # found a better model; keep a detached copy of its current weights
308 |             best_model = copy.deepcopy(model)
309 | best_vl_loss = vl_loss
310 | best_result = {
311 | 'tr_loss': tr_loss,
312 | 'vl_loss': vl_loss,
313 | 'tr_metrics': make_serializeable_metrics(evaluator.compute()),
314 | 'vl_metrics': make_serializeable_metrics(vl_metrics.compute())
315 | }
316 |
317 | # write the final results
318 | epoch_info = {
319 | 'tr_loss': tr_loss,
320 | 'vl_loss': vl_loss,
321 | 'tr_metrics': make_serializeable_metrics(evaluator.compute()),
322 | 'vl_metrics': make_serializeable_metrics(vl_metrics.compute())
323 | }
324 | epochs_info.append(epoch_info)
325 | # epoch_tqdm.set_description(f"Epoch:{epoch+1}/{EPOCHS} -> tr_loss:{tr_loss}, vl_loss:{vl_loss}")
326 | evaluator.reset()
327 |
328 | scheduler.step(vl_loss)
329 |
330 | # save final results
331 | res = {
332 | 'id': save_file_id,
333 | 'config': config,
334 | 'epochs_info': epochs_info,
335 | 'best_result': best_result
336 | }
337 | fn = f"{save_file_id+'_' if save_file_id else ''}result.json"
338 | fp = os.path.join(config['model']['save_dir'],fn)
339 | with open(fp, "w") as write_file:
340 | json.dump(res, write_file, indent=4)
341 |
342 | # save model's state_dict
343 | fn = "last_model_state_dict.pt"
344 | fp = os.path.join(config['model']['save_dir'],fn)
345 | torch.save(model.state_dict(), fp)
346 |
347 | # save the best model's state_dict
348 | fn = "best_model_state_dict.pt"
349 | fp = os.path.join(config['model']['save_dir'], fn)
350 | torch.save(best_model.state_dict(), fp)
351 |
352 | return best_model, model, res
353 |
354 |
355 | # ## Define test function
356 |
357 | # In[17]:
358 |
359 |
360 | def test(model, te_dataloader):
361 | model.eval()
362 | with torch.no_grad():
363 | evaluator = test_metrics.clone().to(device)
364 | for batch_data in tqdm(te_dataloader):
365 | imgs = batch_data['image']
366 | msks = batch_data['mask']
367 |
368 | imgs = imgs.to(device)
369 | msks = msks.to(device)
370 |
371 | preds = model(imgs)
372 |
373 | # evaluate by metrics
374 | preds_ = torch.argmax(preds, 1, keepdim=False).float()
375 | msks_ = torch.argmax(msks, 1, keepdim=False)
376 | evaluator.update(preds_, msks_)
377 |
378 | return evaluator
379 |
380 |
381 | # ## Load and prepare model
382 |
383 | # In[18]:
384 |
385 |
386 | from models.unet import UNet as Net
387 |
388 |
389 | model = Net(**config['model']['params'])
390 | torch.cuda.empty_cache()
391 | model = model.to(device)
392 | print("Number of parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))
393 |
394 | os.makedirs(config['model']['save_dir'], exist_ok=True)
395 | model_path = f"{config['model']['save_dir']}/model_state_dict.pt"
396 |
397 | if config['model']['load_weights']:
398 | model.load_state_dict(torch.load(model_path))
399 | print("Loaded pre-trained weights...")
400 |
401 |
402 | # criterion_dice = DiceLoss()
403 | criterion_dice = DiceLossWithLogtis()
404 | # criterion_ce = BCELoss()
405 | criterion_ce = CrossEntropyLoss()
406 |
407 |
408 | def criterion(preds, masks):
409 | c_dice = criterion_dice(preds, masks)
410 | c_ce = criterion_ce(preds, masks)
411 | return 0.5*c_dice + 0.5*c_ce
412 |
413 | tr_prms = config['training']
414 | optimizer = globals()[tr_prms['optimizer']['name']](model.parameters(), **tr_prms['optimizer']['params'])
415 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', **tr_prms['scheduler'])
416 |
417 |
418 | # ## Start training
419 |
420 | # In[19]:
421 |
422 |
423 | best_model, model, res = train(
424 | model,
425 | device,
426 | tr_dataloader,
427 | vl_dataloader,
428 | config,
429 |
430 | criterion,
431 | optimizer,
432 | scheduler,
433 |
434 | save_dir = config['model']['save_dir'],
435 | save_file_id = None,
436 | )
437 |
438 |
439 | # In[27]:
440 |
441 |
442 | te_metrics = test(best_model, te_dataloader)
443 | te_metrics.compute()
444 |
445 |
446 | # In[28]:
447 |
448 |
449 | f"{config['model']['save_dir']}"
450 |
451 |
452 | # # Test the best inferred model
453 | # ----
454 |
455 | # ## Load the best model
456 |
457 | # In[29]:
458 |
459 |
460 | best_model = Net(**config['model']['params'])
461 | torch.cuda.empty_cache()
462 | best_model = best_model.to(device)
463 |
464 | fn = "best_model_state_dict.pt"
465 | os.makedirs(config['model']['save_dir'], exist_ok=True)
466 | model_path = f"{config['model']['save_dir']}/{fn}"
467 |
468 | best_model.load_state_dict(torch.load(model_path))
469 | print("Loaded best model weights...")
470 |
471 |
472 | # ## Evaluation
473 |
474 | # In[30]:
475 |
476 |
477 | te_metrics = test(best_model, te_dataloader)
478 | te_metrics.compute()
479 |
480 |
481 | # ## Plot graphs
482 |
483 | # In[33]:
484 |
485 |
486 | result_file_path = f"{config['model']['save_dir']}/result.json"
487 | with open(result_file_path, 'r') as f:
488 |     results = json.load(f)
489 | epochs_info = results['epochs_info']
490 |
491 | tr_losses = [d['tr_loss'] for d in epochs_info]
492 | vl_losses = [d['vl_loss'] for d in epochs_info]
493 | tr_dice = [d['tr_metrics']['train_metrics/Dice'] for d in epochs_info]
494 | vl_dice = [d['vl_metrics']['valid_metrics/Dice'] for d in epochs_info]
495 | tr_js = [d['tr_metrics']['train_metrics/JaccardIndex'] for d in epochs_info]
496 | vl_js = [d['vl_metrics']['valid_metrics/JaccardIndex'] for d in epochs_info]
497 | tr_acc = [d['tr_metrics']['train_metrics/Accuracy'] for d in epochs_info]
498 | vl_acc = [d['vl_metrics']['valid_metrics/Accuracy'] for d in epochs_info]
499 |
500 |
501 | _, axs = plt.subplots(1, 4, figsize=[16,3])
502 |
503 | axs[0].set_title("Loss")
504 | axs[0].plot(tr_losses, 'r-', label="train loss")
505 | axs[0].plot(vl_losses, 'b-', label="validation loss")
506 | axs[0].legend()
507 |
508 | axs[1].set_title("Dice score")
509 | axs[1].plot(tr_dice, 'r-', label="train dice")
510 | axs[1].plot(vl_dice, 'b-', label="validation dice")
511 | axs[1].legend()
512 |
513 | axs[2].set_title("Jaccard Similarity")
514 | axs[2].plot(tr_js, 'r-', label="train JaccardIndex")
515 | axs[2].plot(vl_js, 'b-', label="validation JaccardIndex")
516 | axs[2].legend()
517 |
518 | axs[3].set_title("Accuracy")
519 | axs[3].plot(tr_acc, 'r-', label="train Accuracy")
520 | axs[3].plot(vl_acc, 'b-', label="validation Accuracy")
521 | axs[3].legend()
522 |
523 | plt.show()
524 |
525 |
526 | # In[32]:
527 |
528 |
529 | epochs_info
530 |
531 |
532 | # ## Save images
533 |
534 | # In[13]:
535 |
536 |
537 | from PIL import Image
538 | import cv2
539 | def skin_plot(img, gt, pred):  # overlay ground-truth and prediction contours on the image
540 | img = np.array(img)
541 | gt = np.array(gt)
542 | pred = np.array(pred)
543 | edged_test = cv2.Canny(pred, 100, 255)
544 | contours_test, _ = cv2.findContours(edged_test, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
545 | edged_gt = cv2.Canny(gt, 100, 255)
546 | contours_gt, _ = cv2.findContours(edged_gt, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
547 | for cnt_test in contours_test:
548 | cv2.drawContours(img, [cnt_test], -1, (0, 0, 255), 1)
549 | for cnt_gt in contours_gt:
550 | cv2.drawContours(img, [cnt_gt], -1, (0,255,0), 1)
551 | return img
552 |
553 | #---------------------------------------------------------------------------------------------
554 |
555 |
556 | save_imgs_dir = f"{config['model']['save_dir']}/visualized"
557 |
558 | if not os.path.isdir(save_imgs_dir):
559 | os.mkdir(save_imgs_dir)
560 |
561 | with torch.no_grad():
562 | for batch in tqdm(te_dataloader):
563 | imgs = batch['image']
564 | msks = batch['mask']
565 | ids = batch['id']
566 |
567 | preds = best_model(imgs.to(device))
568 |
569 | txm = imgs.cpu().numpy()
570 | tbm = torch.argmax(msks, 1).cpu().numpy()
571 | tpm = torch.argmax(preds, 1).cpu().numpy()
572 | tid = ids
573 |
574 | for idx in range(len(tbm)):
575 | img = np.moveaxis(txm[idx, :3], 0, -1)*255.
576 |             img = np.ascontiguousarray(img, dtype=np.uint8)
577 | gt = np.uint8(tbm[idx]*255.)
578 | pred = np.where(tpm[idx]>0.5, 255, 0)
579 | pred = np.ascontiguousarray(pred, dtype=np.uint8)
580 |
581 | res_img = skin_plot(img, gt, pred)
582 |
583 | fid = tid[idx]
584 | Image.fromarray(img).save(f"{save_imgs_dir}/{fid}_img.png")
585 | Image.fromarray(gt).save(f"{save_imgs_dir}/{fid}_gt.png")
586 | Image.fromarray(res_img).save(f"{save_imgs_dir}/{fid}_img_gt_pred.png")
587 |
588 |
589 | # In[ ]:
590 |
591 |
592 | f"{config['model']['save_dir']}/visualized"
593 |
594 |
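595 | # In[ ]:
596 | 
597 | 
598 | # Optional sketch (illustrative, not part of the original notebook): persist the final
599 | # test metrics next to the training results, reusing make_serializeable_metrics defined above.
600 | # The file name "test_result.json" is an assumed convention for this example.
601 | te_res = make_serializeable_metrics(te_metrics.compute())
602 | with open(f"{config['model']['save_dir']}/test_result.json", "w") as write_file:
603 |     json.dump(te_res, write_file, indent=4)
604 | print(f"Saved test metrics to: {config['model']['save_dir']}/test_result.json")
605 | 
606 | 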
--------------------------------------------------------------------------------
/train_and_test/isic/unetpp-isic.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # # UNet++ - ISIC2018
5 | # ---
6 |
7 | # ## Import packages & functions
8 |
9 | # In[1]:
10 |
11 |
12 | from __future__ import print_function, division
13 |
14 |
15 | import os
16 | import sys
17 | sys.path.append('../..')
18 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
19 |
20 | import copy
21 | import json
22 | import importlib
23 | import glob
24 | import pandas as pd
25 | from skimage import io, transform
26 | import matplotlib.pyplot as plt
27 | from matplotlib.image import imread
28 | import numpy as np
29 | from tqdm import tqdm
30 |
31 | import torch
32 | import torch.nn as nn
33 | import torch.nn.functional as F
34 | import torch.optim as optim
35 | import torchmetrics
36 | from torch.optim import Adam, SGD
37 | from losses import DiceLoss, DiceLossWithLogtis
38 | from torch.nn import BCELoss, CrossEntropyLoss
39 |
40 | from utils import (
41 | show_sbs,
42 | load_config,
43 | _print,
44 | )
45 |
46 | # Ignore warnings
47 | import warnings
48 | warnings.filterwarnings("ignore")
49 |
50 | # plt.ion() # interactive mode
51 |
52 |
53 | # ## Set the seed
54 |
55 | # In[2]:
56 |
57 |
58 | torch.manual_seed(0)
59 | np.random.seed(0)
60 | torch.cuda.manual_seed(0)
61 | import random
62 | random.seed(0)
63 |
64 |
65 | # ## Load the config
66 |
67 | # In[3]:
68 |
69 |
70 | CONFIG_NAME = "isic/isic2018_unetpp.yaml"
71 | CONFIG_FILE_PATH = os.path.join("../../configs", CONFIG_NAME)
72 |
73 |
74 | # In[5]:
75 |
76 |
77 | config = load_config(CONFIG_FILE_PATH)
78 | _print("Config:", "info_underline")
79 | print(json.dumps(config, indent=2))
80 | print(20*"~-", "\n")
81 |
82 |
83 | # ## Dataset and Dataloader
84 |
85 | # In[7]:
86 |
87 |
88 | from datasets.isic import ISIC2018DatasetFast
89 | from torch.utils.data import DataLoader, Subset
90 | from torchvision import transforms
91 |
92 |
93 | # In[8]:
94 |
95 |
96 | # ------------------- params --------------------
97 | INPUT_SIZE = config['dataset']['input_size']
98 | # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
99 |
100 |
101 | # ----------------- dataset --------------------
102 | # preparing training dataset
103 | tr_dataset = ISIC2018DatasetFast(mode="tr", one_hot=True)
104 | vl_dataset = ISIC2018DatasetFast(mode="vl", one_hot=True)
105 | te_dataset = ISIC2018DatasetFast(mode="te", one_hot=True)
106 |
107 | # We consider 1815 samples for training, 259 samples for validation and 520 samples for testing
108 | # !cat ~/deeplearning/skin/Prepare_ISIC2018.py
109 |
110 | print(f"Length of training_dataset:\t{len(tr_dataset)}")
111 | print(f"Length of validation_dataset:\t{len(vl_dataset)}")
112 | print(f"Length of test_dataset:\t\t{len(te_dataset)}")
113 |
114 |
115 | # prepare train dataloader
116 | tr_dataloader = DataLoader(tr_dataset, **config['data_loader']['train'])
117 |
118 | # prepare validation dataloader
119 | vl_dataloader = DataLoader(vl_dataset, **config['data_loader']['validation'])
120 |
121 | # prepare test dataloader
122 | te_dataloader = DataLoader(te_dataset, **config['data_loader']['test'])
123 |
124 | # -------------- test -----------------
125 | # test and visualize the input data
126 | for sample in tr_dataloader:
127 | img = sample['image']
128 | msk = sample['mask']
129 | print("\n Training")
130 | show_sbs(img[0], msk[0,1])
131 | break
132 |
133 | for sample in vl_dataloader:
134 | img = sample['image']
135 | msk = sample['mask']
136 | print("Validation")
137 | show_sbs(img[0], msk[0,1])
138 | break
139 |
140 | for sample in te_dataloader:
141 | img = sample['image']
142 | msk = sample['mask']
143 | print("Test")
144 | show_sbs(img[0], msk[0,1])
145 | break
146 |
147 |
148 | # ### Device
149 |
150 | # In[10]:
151 |
152 |
153 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
154 | print(f"Torch device: {device}")
155 |
156 |
157 | # ## Metrics
158 |
159 | # In[13]:
160 |
161 |
162 | metrics = torchmetrics.MetricCollection(
163 | [
164 | torchmetrics.F1Score(),
165 | torchmetrics.Accuracy(),
166 | torchmetrics.Dice(),
167 | torchmetrics.Precision(),
168 | torchmetrics.Specificity(),
169 | torchmetrics.Recall(),
170 | # IoU
171 | torchmetrics.JaccardIndex(2)
172 | ],
173 | prefix='train_metrics/'
174 | )
175 |
176 | # train_metrics
177 | train_metrics = metrics.clone(prefix='train_metrics/').to(device)
178 |
179 | # valid_metrics
180 | valid_metrics = metrics.clone(prefix='valid_metrics/').to(device)
181 |
182 | # test_metrics
183 | test_metrics = metrics.clone(prefix='test_metrics/').to(device)
184 |
185 |
186 | # In[14]:
187 |
188 |
189 | def make_serializeable_metrics(computed_metrics):
190 | res = {}
191 | for k, v in computed_metrics.items():
192 | res[k] = float(v.cpu().detach().numpy())
193 | return res
194 |
195 |
196 | # ## Define validate function
197 |
198 | # In[15]:
199 |
200 |
201 | def validate(model, criterion, vl_dataloader):
202 | model.eval()
203 | with torch.no_grad():
204 |
205 | evaluator = valid_metrics.clone().to(device)
206 |
207 | losses = []
208 | cnt = 0.
209 | for batch, batch_data in enumerate(vl_dataloader):
210 | imgs = batch_data['image']
211 | msks = batch_data['mask']
212 |
213 | cnt += msks.shape[0]
214 |
215 | imgs = imgs.to(device)
216 | msks = msks.to(device)
217 |
218 | preds = model(imgs)
219 | loss = criterion(preds, msks)
220 | losses.append(loss.item())
221 |
222 |
223 | preds_ = torch.argmax(preds, 1, keepdim=False).float()
224 | msks_ = torch.argmax(msks, 1, keepdim=False)
225 | evaluator.update(preds_, msks_)
226 |
227 | # _cml = f"curr_mean-loss:{np.sum(losses)/cnt:0.5f}"
228 | # _bl = f"batch-loss:{losses[-1]/msks.shape[0]:0.5f}"
229 | # iterator.set_description(f"Validation) batch:{batch+1:04d} -> {_cml}, {_bl}")
230 |
231 |         # compute the final results
232 | loss = np.sum(losses)/cnt
233 | metrics = evaluator.compute()
234 |
235 | return evaluator, loss
236 |
237 |
238 | # ## Define train function
239 |
240 | # In[16]:
241 |
242 |
243 | def train(
244 | model,
245 | device,
246 | tr_dataloader,
247 | vl_dataloader,
248 | config,
249 |
250 | criterion,
251 | optimizer,
252 | scheduler,
253 |
254 | save_dir='./',
255 | save_file_id=None,
256 | ):
257 |
258 | EPOCHS = tr_prms['epochs']
259 |
260 | torch.cuda.empty_cache()
261 | model = model.to(device)
262 |
263 | evaluator = train_metrics.clone().to(device)
264 |
265 | epochs_info = []
266 | best_model = None
267 | best_result = {}
268 | best_vl_loss = np.Inf
269 | for epoch in range(EPOCHS):
270 | model.train()
271 |
272 | evaluator.reset()
273 | tr_iterator = tqdm(enumerate(tr_dataloader))
274 | tr_losses = []
275 | cnt = 0
276 | for batch, batch_data in tr_iterator:
277 | imgs = batch_data['image']
278 | msks = batch_data['mask']
279 |
280 | imgs = imgs.to(device)
281 | msks = msks.to(device)
282 |
283 | optimizer.zero_grad()
284 | preds = model(imgs)
285 | loss = criterion(preds, msks)
286 | loss.backward()
287 | optimizer.step()
288 |
289 | # evaluate by metrics
290 | preds_ = torch.argmax(preds, 1, keepdim=False).float()
291 | msks_ = torch.argmax(msks, 1, keepdim=False)
292 | evaluator.update(preds_, msks_)
293 |
294 | cnt += imgs.shape[0]
295 | tr_losses.append(loss.item())
296 |
297 | # write details for each training batch
298 | _cml = f"curr_mean-loss:{np.sum(tr_losses)/cnt:0.5f}"
299 | _bl = f"mean_batch-loss:{tr_losses[-1]/imgs.shape[0]:0.5f}"
300 | tr_iterator.set_description(f"Training) ep:{epoch:03d}, batch:{batch+1:04d} -> {_cml}, {_bl}")
301 |
302 | tr_loss = np.sum(tr_losses)/cnt
303 |
304 | # validate model
305 | vl_metrics, vl_loss = validate(model, criterion, vl_dataloader)
306 | if vl_loss < best_vl_loss:
307 |             # found a better model; keep a detached copy of its current weights
308 |             best_model = copy.deepcopy(model)
309 | best_vl_loss = vl_loss
310 | best_result = {
311 | 'tr_loss': tr_loss,
312 | 'vl_loss': vl_loss,
313 | 'tr_metrics': make_serializeable_metrics(evaluator.compute()),
314 | 'vl_metrics': make_serializeable_metrics(vl_metrics.compute())
315 | }
316 |
317 | # write the final results
318 | epoch_info = {
319 | 'tr_loss': tr_loss,
320 | 'vl_loss': vl_loss,
321 | 'tr_metrics': make_serializeable_metrics(evaluator.compute()),
322 | 'vl_metrics': make_serializeable_metrics(vl_metrics.compute())
323 | }
324 | epochs_info.append(epoch_info)
325 | # epoch_tqdm.set_description(f"Epoch:{epoch+1}/{EPOCHS} -> tr_loss:{tr_loss}, vl_loss:{vl_loss}")
326 | evaluator.reset()
327 |
328 | scheduler.step(vl_loss)
329 |
330 | # save final results
331 | res = {
332 | 'id': save_file_id,
333 | 'config': config,
334 | 'epochs_info': epochs_info,
335 | 'best_result': best_result
336 | }
337 | fn = f"{save_file_id+'_' if save_file_id else ''}result.json"
338 | fp = os.path.join(config['model']['save_dir'],fn)
339 | with open(fp, "w") as write_file:
340 | json.dump(res, write_file, indent=4)
341 |
342 | # save model's state_dict
343 | fn = "last_model_state_dict.pt"
344 | fp = os.path.join(config['model']['save_dir'],fn)
345 | torch.save(model.state_dict(), fp)
346 |
347 | # save the best model's state_dict
348 | fn = "best_model_state_dict.pt"
349 | fp = os.path.join(config['model']['save_dir'], fn)
350 | torch.save(best_model.state_dict(), fp)
351 |
352 | return best_model, model, res
353 |
354 |
355 | # ## Define test function
356 |
357 | # In[18]:
358 |
359 |
360 | def test(model, te_dataloader):
361 | model.eval()
362 | with torch.no_grad():
363 | evaluator = test_metrics.clone().to(device)
364 | for batch_data in tqdm(te_dataloader):
365 | imgs = batch_data['image']
366 | msks = batch_data['mask']
367 |
368 | imgs = imgs.to(device)
369 | msks = msks.to(device)
370 |
371 | preds = model(imgs)
372 |
373 | # evaluate by metrics
374 | preds_ = torch.argmax(preds, 1, keepdim=False).float()
375 | msks_ = torch.argmax(msks, 1, keepdim=False)
376 | evaluator.update(preds_, msks_)
377 |
378 | return evaluator
379 |
380 |
381 | # ## Load and prepare model
382 |
383 | # In[19]:
384 |
385 |
386 | from models.unetpp import NestedUNet as Net
387 |
388 |
389 | model = Net(**config['model']['params'])
390 | torch.cuda.empty_cache()
391 | model = model.to(device)
392 | print("Number of parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))
393 |
394 | os.makedirs(config['model']['save_dir'], exist_ok=True)
395 | model_path = f"{config['model']['save_dir']}/model_state_dict.pt"
396 |
397 | if config['model']['load_weights']:
398 | model.load_state_dict(torch.load(model_path))
399 | print("Loaded pre-trained weights...")
400 |
401 |
402 | # criterion_dice = DiceLoss()
403 | criterion_dice = DiceLossWithLogtis()
404 | # criterion_ce = BCELoss()
405 | criterion_ce = CrossEntropyLoss()
406 |
407 |
408 | def criterion(preds, masks):
409 | c_dice = criterion_dice(preds, masks)
410 | c_ce = criterion_ce(preds, masks)
411 | return 0.5*c_dice + 0.5*c_ce
412 |
413 | tr_prms = config['training']
414 | optimizer = globals()[tr_prms['optimizer']['name']](model.parameters(), **tr_prms['optimizer']['params'])
415 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', **tr_prms['scheduler'])
416 |
417 |
418 | # ## Start training
419 |
420 | # In[20]:
421 |
422 |
423 | best_model, model, res = train(
424 | model,
425 | device,
426 | tr_dataloader,
427 | vl_dataloader,
428 | config,
429 |
430 | criterion,
431 | optimizer,
432 | scheduler,
433 |
434 | save_dir = config['model']['save_dir'],
435 | save_file_id = None,
436 | )
437 |
438 |
439 | # In[21]:
440 |
441 |
442 | te_metrics = test(best_model, te_dataloader)
443 | te_metrics.compute()
444 |
445 |
446 | # In[22]:
447 |
448 |
449 | f"{config['model']['save_dir']}"
450 |
451 |
452 | # # Test the best inferred model
453 | # ----
454 |
455 | # ## Load the best model
456 |
457 | # In[23]:
458 |
459 |
460 | best_model = Net(**config['model']['params'])
461 | torch.cuda.empty_cache()
462 | best_model = best_model.to(device)
463 |
464 | fn = "best_model_state_dict.pt"
465 | os.makedirs(config['model']['save_dir'], exist_ok=True)
466 | model_path = f"{config['model']['save_dir']}/{fn}"
467 |
468 | best_model.load_state_dict(torch.load(model_path))
469 | print("Loaded best model weights...")
470 |
471 |
472 | # ## Evaluation
473 |
474 | # In[24]:
475 |
476 |
477 | te_metrics = test(best_model, te_dataloader)
478 | te_metrics.compute()
479 |
480 |
481 | # ## Plot graphs
482 |
483 | # In[25]:
484 |
485 |
486 | result_file_path = f"{config['model']['save_dir']}/result.json"
487 | with open(result_file_path, 'r') as f:
488 |     results = json.load(f)
489 | epochs_info = results['epochs_info']
490 |
491 | tr_losses = [d['tr_loss'] for d in epochs_info]
492 | vl_losses = [d['vl_loss'] for d in epochs_info]
493 | tr_dice = [d['tr_metrics']['train_metrics/Dice'] for d in epochs_info]
494 | vl_dice = [d['vl_metrics']['valid_metrics/Dice'] for d in epochs_info]
495 | tr_js = [d['tr_metrics']['train_metrics/JaccardIndex'] for d in epochs_info]
496 | vl_js = [d['vl_metrics']['valid_metrics/JaccardIndex'] for d in epochs_info]
497 | tr_acc = [d['tr_metrics']['train_metrics/Accuracy'] for d in epochs_info]
498 | vl_acc = [d['vl_metrics']['valid_metrics/Accuracy'] for d in epochs_info]
499 |
500 |
501 | _, axs = plt.subplots(1, 4, figsize=[16,3])
502 |
503 | axs[0].set_title("Loss")
504 | axs[0].plot(tr_losses, 'r-', label="train loss")
505 | axs[0].plot(vl_losses, 'b-', label="validation loss")
506 | axs[0].legend()
507 |
508 | axs[1].set_title("Dice score")
509 | axs[1].plot(tr_dice, 'r-', label="train dice")
510 | axs[1].plot(vl_dice, 'b-', label="validation dice")
511 | axs[1].legend()
512 |
513 | axs[2].set_title("Jaccard Similarity")
514 | axs[2].plot(tr_js, 'r-', label="train JaccardIndex")
515 | axs[2].plot(vl_js, 'b-', label="validation JaccardIndex")
516 | axs[2].legend()
517 |
518 | axs[3].set_title("Accuracy")
519 | axs[3].plot(tr_acc, 'r-', label="train Accuracy")
520 | axs[3].plot(vl_acc, 'b-', label="validation Accuracy")
521 | axs[3].legend()
522 |
523 | plt.show()
524 |
525 |
526 | # In[ ]:
527 |
528 |
529 | epochs_info
530 |
531 |
532 | # ## Save images
533 |
534 | # In[ ]:
535 |
536 |
537 | from PIL import Image
538 | import cv2
539 | def skin_plot(img, gt, pred):  # overlay ground-truth and prediction contours on the image
540 | img = np.array(img)
541 | gt = np.array(gt)
542 | pred = np.array(pred)
543 | edged_test = cv2.Canny(pred, 100, 255)
544 | contours_test, _ = cv2.findContours(edged_test, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
545 | edged_gt = cv2.Canny(gt, 100, 255)
546 | contours_gt, _ = cv2.findContours(edged_gt, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
547 | for cnt_test in contours_test:
548 | cv2.drawContours(img, [cnt_test], -1, (0, 0, 255), 1)
549 | for cnt_gt in contours_gt:
550 | cv2.drawContours(img, [cnt_gt], -1, (0,255,0), 1)
551 | return img
552 |
553 | #---------------------------------------------------------------------------------------------
554 |
555 |
556 | save_imgs_dir = f"{config['model']['save_dir']}/visualized"
557 |
558 | if not os.path.isdir(save_imgs_dir):
559 | os.mkdir(save_imgs_dir)
560 |
561 | with torch.no_grad():
562 | for batch in tqdm(te_dataloader):
563 | imgs = batch['image']
564 | msks = batch['mask']
565 | ids = batch['id']
566 |
567 | preds = best_model(imgs.to(device))
568 |
569 | txm = imgs.cpu().numpy()
570 | tbm = torch.argmax(msks, 1).cpu().numpy()
571 | tpm = torch.argmax(preds, 1).cpu().numpy()
572 | tid = ids
573 |
574 | for idx in range(len(tbm)):
575 | img = np.moveaxis(txm[idx, :3], 0, -1)*255.
576 | img = np.ascontiguousarray(img, dtype=np.uint8)
577 | gt = np.uint8(tbm[idx]*255.)
578 | pred = np.where(tpm[idx]>0.5, 255, 0)
579 | pred = np.ascontiguousarray(pred, dtype=np.uint8)
580 |
581 | res_img = skin_plot(img, gt, pred)
582 |
583 | fid = tid[idx]
584 | Image.fromarray(img).save(f"{save_imgs_dir}/{fid}_img.png")
585 | Image.fromarray(res_img).save(f"{save_imgs_dir}/{fid}_img_gt_pred.png")
586 |
587 |
588 | # In[ ]:
589 |
590 |
591 | f"{config['model']['save_dir']}/visualized"
592 |
593 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | class bcolors:
2 | HEADER = '\033[95m'
3 | OKBLUE = '\033[94m'
4 | OKCYAN = '\033[96m'
5 | OKGREEN = '\033[92m'
6 | WARNING = '\033[93m'
7 | FAIL = '\033[91m'
8 | ENDC = '\033[0m'
9 | BOLD = '\033[1m'
10 | UNDERLINE = '\033[4m'
11 |
12 |
13 |
14 | from termcolor import colored
15 | def _print(string, p=None):
16 | if not p:
17 | print(string)
18 | return
19 | pre = f"{bcolors.ENDC}"
20 |
21 | if "bold" in p.lower():
22 | pre += bcolors.BOLD
23 | elif "underline" in p.lower():
24 | pre += bcolors.UNDERLINE
25 | elif "header" in p.lower():
26 | pre += bcolors.HEADER
27 |
28 | if "warning" in p.lower():
29 | pre += bcolors.WARNING
30 | elif "error" in p.lower():
31 | pre += bcolors.FAIL
32 | elif "ok" in p.lower():
33 | pre += bcolors.OKGREEN
34 | elif "info" in p.lower():
35 | if "blue" in p.lower():
36 | pre += bcolors.OKBLUE
37 | else:
38 | pre += bcolors.OKCYAN
39 |
40 | print(f"{pre}{string}{bcolors.ENDC}")
41 |
42 |
43 |
44 | import yaml
45 | def load_config(config_filepath):
46 | try:
47 | with open(config_filepath, 'r') as file:
48 | config = yaml.safe_load(file)
49 | return config
50 | except FileNotFoundError:
51 | _print(f"Config file not found! <{config_filepath}>", "error_bold")
52 | exit(1)
53 |
54 |
55 |
56 | import numpy as np
57 | from matplotlib import pyplot as plt
58 | def show_sbs(im1, im2, figsize=[8,4], im1_title="Image", im2_title="Mask", show=True):
59 | if im1.shape[0]<4:
60 | im1 = np.array(im1)
61 | im1 = np.transpose(im1, [1, 2, 0])
62 |
63 | if im2.shape[0]<4:
64 | im2 = np.array(im2)
65 | im2 = np.transpose(im2, [1, 2, 0])
66 |
67 | _, axs = plt.subplots(1, 2, figsize=figsize)
68 | axs[0].imshow(im1)
69 | axs[0].set_title(im1_title)
70 | axs[1].imshow(im2, cmap='gray')
71 | axs[1].set_title(im2_title)
72 | if show: plt.show()
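73 | 
74 | 
75 | # Minimal usage sketch (illustrative only; the training scripts import these helpers directly).
76 | # Running `python utils.py` from the repository root would load one of the bundled configs and
77 | # echo it with the colored printer above; the chosen config path is just an example.
78 | if __name__ == "__main__":
79 |     import json
80 |     cfg = load_config("configs/isic/isic2018_unet.yaml")
81 |     _print("Loaded config:", "info_underline")
82 |     print(json.dumps(cfg, indent=2))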
--------------------------------------------------------------------------------