├── .gitignore ├── .gitmodules ├── Coronary ├── test │ ├── imgs │ │ ├── 5.png │ │ └── 6.png │ └── masks │ │ ├── 5a.png │ │ └── 6a.png ├── train │ ├── imgs │ │ ├── 1.png │ │ ├── 2.png │ │ └── 3.png │ └── masks │ │ ├── 1a.png │ │ ├── 2a.png │ │ └── 3a.png └── val │ ├── imgs │ └── 4.png │ └── masks │ └── 4a.png ├── DRIVE ├── test │ └── images │ │ ├── 01_test.tif │ │ ├── 02_test.tif │ │ ├── 03_test.tif │ │ ├── 04_test.tif │ │ ├── 05_test.tif │ │ ├── 06_test.tif │ │ ├── 07_test.tif │ │ ├── 08_test.tif │ │ ├── 09_test.tif │ │ ├── 10_test.tif │ │ ├── 11_test.tif │ │ ├── 12_test.tif │ │ ├── 13_test.tif │ │ ├── 14_test.tif │ │ ├── 15_test.tif │ │ ├── 16_test.tif │ │ ├── 17_test.tif │ │ ├── 18_test.tif │ │ ├── 19_test.tif │ │ └── 20_test.tif ├── training │ ├── 1st_manual │ │ ├── 21_manual1.gif │ │ ├── 22_manual1.gif │ │ ├── 23_manual1.gif │ │ ├── 24_manual1.gif │ │ ├── 25_manual1.gif │ │ ├── 26_manual1.gif │ │ ├── 27_manual1.gif │ │ ├── 28_manual1.gif │ │ ├── 29_manual1.gif │ │ ├── 30_manual1.gif │ │ ├── 31_manual1.gif │ │ ├── 32_manual1.gif │ │ ├── 33_manual1.gif │ │ ├── 34_manual1.gif │ │ ├── 35_manual1.gif │ │ └── 36_manual1.gif │ ├── images │ │ ├── 21_training.tif │ │ ├── 22_training.tif │ │ ├── 23_training.tif │ │ ├── 24_training.tif │ │ ├── 25_training.tif │ │ ├── 26_training.tif │ │ ├── 27_training.tif │ │ ├── 28_training.tif │ │ ├── 29_training.tif │ │ ├── 30_training.tif │ │ ├── 31_training.tif │ │ ├── 32_training.tif │ │ ├── 33_training.tif │ │ ├── 34_training.tif │ │ ├── 35_training.tif │ │ └── 36_training.tif │ └── mask │ │ ├── 21_training_mask.gif │ │ ├── 22_training_mask.gif │ │ ├── 23_training_mask.gif │ │ ├── 24_training_mask.gif │ │ ├── 25_training_mask.gif │ │ ├── 26_training_mask.gif │ │ ├── 27_training_mask.gif │ │ ├── 28_training_mask.gif │ │ ├── 29_training_mask.gif │ │ ├── 30_training_mask.gif │ │ ├── 31_training_mask.gif │ │ ├── 32_training_mask.gif │ │ ├── 33_training_mask.gif │ │ ├── 34_training_mask.gif │ │ ├── 35_training_mask.gif │ │ ├── 36_training_mask.gif │ │ ├── 37_training_mask.gif │ │ ├── 38_training_mask.gif │ │ ├── 39_training_mask.gif │ │ └── 40_training_mask.gif └── validation │ ├── 1st_manual │ ├── 37_manual1.gif │ ├── 38_manual1.gif │ ├── 39_manual1.gif │ └── 40_manual1.gif │ └── images │ ├── 37_training.tif │ ├── 38_training.tif │ ├── 39_training.tif │ └── 40_training.tif ├── LICENSE ├── README.md ├── doc_images ├── block_comparison.png ├── decoder_comparison.png └── flops_pareto.png ├── eval.py ├── metrics.py ├── predict.py ├── requirements.txt ├── train.py └── utils ├── augment.py └── dataset.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "segmentation_models_pytorch"] 2 | path = segmentation_models_pytorch 3 | url = https://github.com/jlcsilva/segmentation_models.pytorch.git 4 | -------------------------------------------------------------------------------- /Coronary/test/imgs/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/Coronary/test/imgs/5.png -------------------------------------------------------------------------------- /Coronary/test/imgs/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/Coronary/test/imgs/6.png -------------------------------------------------------------------------------- /Coronary/test/masks/5a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/Coronary/test/masks/5a.png -------------------------------------------------------------------------------- /Coronary/test/masks/6a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/Coronary/test/masks/6a.png -------------------------------------------------------------------------------- /Coronary/train/imgs/1.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/Coronary/train/imgs/1.png -------------------------------------------------------------------------------- /Coronary/train/imgs/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/Coronary/train/imgs/2.png -------------------------------------------------------------------------------- /Coronary/train/imgs/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/Coronary/train/imgs/3.png -------------------------------------------------------------------------------- /Coronary/train/masks/1a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/Coronary/train/masks/1a.png -------------------------------------------------------------------------------- /Coronary/train/masks/2a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/Coronary/train/masks/2a.png -------------------------------------------------------------------------------- /Coronary/train/masks/3a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/Coronary/train/masks/3a.png -------------------------------------------------------------------------------- /Coronary/val/imgs/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/Coronary/val/imgs/4.png -------------------------------------------------------------------------------- /Coronary/val/masks/4a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/Coronary/val/masks/4a.png -------------------------------------------------------------------------------- /DRIVE/test/images/01_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/test/images/01_test.tif -------------------------------------------------------------------------------- /DRIVE/test/images/02_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/test/images/02_test.tif -------------------------------------------------------------------------------- /DRIVE/test/images/03_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/test/images/03_test.tif -------------------------------------------------------------------------------- 
/DRIVE/test/images/04_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/test/images/04_test.tif -------------------------------------------------------------------------------- /DRIVE/test/images/05_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/test/images/05_test.tif -------------------------------------------------------------------------------- /DRIVE/test/images/06_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/test/images/06_test.tif -------------------------------------------------------------------------------- /DRIVE/test/images/07_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/test/images/07_test.tif -------------------------------------------------------------------------------- /DRIVE/test/images/08_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/test/images/08_test.tif -------------------------------------------------------------------------------- /DRIVE/test/images/09_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/test/images/09_test.tif -------------------------------------------------------------------------------- /DRIVE/test/images/10_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/test/images/10_test.tif -------------------------------------------------------------------------------- /DRIVE/test/images/11_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/test/images/11_test.tif -------------------------------------------------------------------------------- /DRIVE/test/images/12_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/test/images/12_test.tif -------------------------------------------------------------------------------- /DRIVE/test/images/13_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/test/images/13_test.tif -------------------------------------------------------------------------------- /DRIVE/test/images/14_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/test/images/14_test.tif 
-------------------------------------------------------------------------------- /DRIVE/test/images/15_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/test/images/15_test.tif -------------------------------------------------------------------------------- /DRIVE/test/images/16_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/test/images/16_test.tif -------------------------------------------------------------------------------- /DRIVE/test/images/17_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/test/images/17_test.tif -------------------------------------------------------------------------------- /DRIVE/test/images/18_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/test/images/18_test.tif -------------------------------------------------------------------------------- /DRIVE/test/images/19_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/test/images/19_test.tif -------------------------------------------------------------------------------- /DRIVE/test/images/20_test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/test/images/20_test.tif -------------------------------------------------------------------------------- /DRIVE/training/1st_manual/21_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/1st_manual/21_manual1.gif -------------------------------------------------------------------------------- /DRIVE/training/1st_manual/22_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/1st_manual/22_manual1.gif -------------------------------------------------------------------------------- /DRIVE/training/1st_manual/23_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/1st_manual/23_manual1.gif -------------------------------------------------------------------------------- /DRIVE/training/1st_manual/24_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/1st_manual/24_manual1.gif -------------------------------------------------------------------------------- /DRIVE/training/1st_manual/25_manual1.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/1st_manual/25_manual1.gif -------------------------------------------------------------------------------- /DRIVE/training/1st_manual/26_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/1st_manual/26_manual1.gif -------------------------------------------------------------------------------- /DRIVE/training/1st_manual/27_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/1st_manual/27_manual1.gif -------------------------------------------------------------------------------- /DRIVE/training/1st_manual/28_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/1st_manual/28_manual1.gif -------------------------------------------------------------------------------- /DRIVE/training/1st_manual/29_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/1st_manual/29_manual1.gif -------------------------------------------------------------------------------- /DRIVE/training/1st_manual/30_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/1st_manual/30_manual1.gif -------------------------------------------------------------------------------- /DRIVE/training/1st_manual/31_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/1st_manual/31_manual1.gif -------------------------------------------------------------------------------- /DRIVE/training/1st_manual/32_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/1st_manual/32_manual1.gif -------------------------------------------------------------------------------- /DRIVE/training/1st_manual/33_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/1st_manual/33_manual1.gif -------------------------------------------------------------------------------- /DRIVE/training/1st_manual/34_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/1st_manual/34_manual1.gif -------------------------------------------------------------------------------- /DRIVE/training/1st_manual/35_manual1.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/1st_manual/35_manual1.gif -------------------------------------------------------------------------------- /DRIVE/training/1st_manual/36_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/1st_manual/36_manual1.gif -------------------------------------------------------------------------------- /DRIVE/training/images/21_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/images/21_training.tif -------------------------------------------------------------------------------- /DRIVE/training/images/22_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/images/22_training.tif -------------------------------------------------------------------------------- /DRIVE/training/images/23_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/images/23_training.tif -------------------------------------------------------------------------------- /DRIVE/training/images/24_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/images/24_training.tif -------------------------------------------------------------------------------- /DRIVE/training/images/25_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/images/25_training.tif -------------------------------------------------------------------------------- /DRIVE/training/images/26_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/images/26_training.tif -------------------------------------------------------------------------------- /DRIVE/training/images/27_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/images/27_training.tif -------------------------------------------------------------------------------- /DRIVE/training/images/28_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/images/28_training.tif -------------------------------------------------------------------------------- /DRIVE/training/images/29_training.tif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/images/29_training.tif -------------------------------------------------------------------------------- /DRIVE/training/images/30_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/images/30_training.tif -------------------------------------------------------------------------------- /DRIVE/training/images/31_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/images/31_training.tif -------------------------------------------------------------------------------- /DRIVE/training/images/32_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/images/32_training.tif -------------------------------------------------------------------------------- /DRIVE/training/images/33_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/images/33_training.tif -------------------------------------------------------------------------------- /DRIVE/training/images/34_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/images/34_training.tif -------------------------------------------------------------------------------- /DRIVE/training/images/35_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/images/35_training.tif -------------------------------------------------------------------------------- /DRIVE/training/images/36_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/images/36_training.tif -------------------------------------------------------------------------------- /DRIVE/training/mask/21_training_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/mask/21_training_mask.gif -------------------------------------------------------------------------------- /DRIVE/training/mask/22_training_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/mask/22_training_mask.gif -------------------------------------------------------------------------------- /DRIVE/training/mask/23_training_mask.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/mask/23_training_mask.gif -------------------------------------------------------------------------------- /DRIVE/training/mask/24_training_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/mask/24_training_mask.gif -------------------------------------------------------------------------------- /DRIVE/training/mask/25_training_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/mask/25_training_mask.gif -------------------------------------------------------------------------------- /DRIVE/training/mask/26_training_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/mask/26_training_mask.gif -------------------------------------------------------------------------------- /DRIVE/training/mask/27_training_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/mask/27_training_mask.gif -------------------------------------------------------------------------------- /DRIVE/training/mask/28_training_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/mask/28_training_mask.gif -------------------------------------------------------------------------------- /DRIVE/training/mask/29_training_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/mask/29_training_mask.gif -------------------------------------------------------------------------------- /DRIVE/training/mask/30_training_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/mask/30_training_mask.gif -------------------------------------------------------------------------------- /DRIVE/training/mask/31_training_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/mask/31_training_mask.gif -------------------------------------------------------------------------------- /DRIVE/training/mask/32_training_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/mask/32_training_mask.gif -------------------------------------------------------------------------------- /DRIVE/training/mask/33_training_mask.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/mask/33_training_mask.gif -------------------------------------------------------------------------------- /DRIVE/training/mask/34_training_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/mask/34_training_mask.gif -------------------------------------------------------------------------------- /DRIVE/training/mask/35_training_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/mask/35_training_mask.gif -------------------------------------------------------------------------------- /DRIVE/training/mask/36_training_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/mask/36_training_mask.gif -------------------------------------------------------------------------------- /DRIVE/training/mask/37_training_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/mask/37_training_mask.gif -------------------------------------------------------------------------------- /DRIVE/training/mask/38_training_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/mask/38_training_mask.gif -------------------------------------------------------------------------------- /DRIVE/training/mask/39_training_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/mask/39_training_mask.gif -------------------------------------------------------------------------------- /DRIVE/training/mask/40_training_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/training/mask/40_training_mask.gif -------------------------------------------------------------------------------- /DRIVE/validation/1st_manual/37_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/validation/1st_manual/37_manual1.gif -------------------------------------------------------------------------------- /DRIVE/validation/1st_manual/38_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/validation/1st_manual/38_manual1.gif -------------------------------------------------------------------------------- /DRIVE/validation/1st_manual/39_manual1.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/validation/1st_manual/39_manual1.gif -------------------------------------------------------------------------------- /DRIVE/validation/1st_manual/40_manual1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/validation/1st_manual/40_manual1.gif -------------------------------------------------------------------------------- /DRIVE/validation/images/37_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/validation/images/37_training.tif -------------------------------------------------------------------------------- /DRIVE/validation/images/38_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/validation/images/38_training.tif -------------------------------------------------------------------------------- /DRIVE/validation/images/39_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/validation/images/39_training.tif -------------------------------------------------------------------------------- /DRIVE/validation/images/40_training.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/DRIVE/validation/images/40_training.tif -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 jlcsilva 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EfficientUNet++
2 | 
3 | [![Paper](http://img.shields.io/badge/Paper-arXiv.2106.11447-B3181B?logo=arXiv)](https://arxiv.org/abs/2106.11447) [![Generic badge](https://img.shields.io/badge/License-MIT-.svg)](https://shields.io/)
4 | 
5 | - June 23, 2021: Initial PyTorch code release for [EfficientUNet++](https://arxiv.org/abs/2106.11447)
6 | 
7 | ## 1. About the EfficientUNet++ decoder architecture
8 | 
9 | The [EfficientUNet++](https://arxiv.org/abs/2106.11447) is a decoder architecture inspired by the [UNet++](https://arxiv.org/abs/1807.10165). Although it was designed and fine-tuned for a coronary artery segmentation task, we expect it to perform well across a wide variety of medical image segmentation tasks. We plan to test it on other tasks soon, and will upload test set performance when we do.
10 | 
11 | In our coronary artery segmentation task, using [EfficientNet](https://arxiv.org/abs/1905.11946) models as backbones, the EfficientUNet++ achieves higher performance than the UNet++, [U-Net](https://arxiv.org/abs/1505.04597), [ResUNet++](https://arxiv.org/abs/1911.07067) and other high-performing medical image segmentation architectures. This improvement is achieved through two simple modifications, inspired by the EfficientNet building blocks:
12 | 
13 | - To reduce computation, the 3x3 convolutional blocks of the UNet++ are replaced with residual bottleneck blocks with depthwise convolutions.
14 | - To enhance performance, the feature maps output by the 3x3 depthwise convolution are processed by an [scSE](https://arxiv.org/abs/1803.02579) block, which applies both channel and spatial attention. The SE channel attention used in the EfficientNet was also tested, but yielded worse results. For more details, refer to the [paper](https://arxiv.org/abs/2106.11447).
15 | 
16 | Comparison of the EfficientUNet++ and UNet++ building blocks (an illustrative code sketch of the block follows the figure):
17 | 
18 | <div align="center">
19 |     <img src="doc_images/block_comparison.png">
20 | </div>
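The block described above can be sketched in a few lines of PyTorch. This is a minimal illustrative sketch, not the reference implementation: the expansion ratio, the SiLU activations, the BatchNorm placement, and the scSE reduction factor are assumptions, and the exact block definition lives in the [segmentation_models.pytorch fork](https://github.com/jlcsilva/segmentation_models.pytorch) used by this repository.

```
import torch
import torch.nn as nn

class SCSE(nn.Module):
    """Concurrent spatial and channel squeeze-and-excitation (scSE)."""
    def __init__(self, channels, reduction=16):
        super().__init__()
        reduced = max(1, channels // reduction)
        self.cse = nn.Sequential(              # channel attention branch
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, reduced, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduced, channels, 1),
            nn.Sigmoid(),
        )
        self.sse = nn.Sequential(nn.Conv2d(channels, 1, 1), nn.Sigmoid())  # spatial attention branch

    def forward(self, x):
        return x * self.cse(x) + x * self.sse(x)

class EfficientUNetPlusPlusBlock(nn.Module):
    """Residual bottleneck block: 1x1 expansion, 3x3 depthwise convolution, scSE, 1x1 projection."""
    def __init__(self, in_ch, out_ch, expansion=2):
        super().__init__()
        mid = in_ch * expansion
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, mid, 1, bias=False),                       # 1x1 expansion
            nn.BatchNorm2d(mid), nn.SiLU(inplace=True),
            nn.Conv2d(mid, mid, 3, padding=1, groups=mid, bias=False),  # 3x3 depthwise
            nn.BatchNorm2d(mid), nn.SiLU(inplace=True),
            SCSE(mid),                                                  # attention on the depthwise output
            nn.Conv2d(mid, out_ch, 1, bias=False),                      # 1x1 projection
            nn.BatchNorm2d(out_ch),
        )
        self.has_skip = in_ch == out_ch

    def forward(self, x):
        y = self.block(x)
        return x + y if self.has_skip else y
```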
21 | 
22 | Performance of different decoder architectures as a function of the number of FLOPS and the number of parameters, when using the EfficientNet B0 to B7 models as encoders:
23 | 
24 | <div align="center">
25 |     <img src="doc_images/decoder_comparison.png">
26 | </div>
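The parameter counts underlying the comparison above can be reproduced with a short script. This is a sketch that assumes the repository submodule is initialized and reuses the import path from `train.py`; measuring FLOPS requires a separate profiling tool and is not shown.

```
import segmentation_models_pytorch.segmentation_models_pytorch as smp

# Count the parameters of the EfficientUNet++ for each EfficientNet encoder (B0 to B7)
for i in range(8):
    encoder = f"timm-efficientnet-b{i}"
    model = smp.EfficientUnetPlusPlus(encoder_name=encoder, encoder_weights=None, in_channels=3, classes=3)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{encoder}: {n_params / 1e6:.1f}M parameters")
```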
27 | 
28 | To help practitioners choose the most suitable model for their clinical needs and available hardware, the Pareto frontier of the tested models was drawn, plotting the number of FLOPS as a function of model performance.
29 | 
30 | <div align="center">
31 |     <img src="doc_images/flops_pareto.png">
32 | 
33 | 
34 | </div>
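When picking a point on the frontier, a dummy forward pass is a quick way to check that a candidate encoder fits the available hardware. The sketch below mirrors how `train.py` and `predict.py` instantiate the model; the 512x512 input size is an arbitrary example.

```
import torch
import segmentation_models_pytorch.segmentation_models_pytorch as smp

net = smp.EfficientUnetPlusPlus(
    encoder_name="timm-efficientnet-b0",  # swap in b1 to b7 to move along the frontier
    encoder_weights="imagenet",
    in_channels=3,
    classes=3,
)
net.eval()
with torch.no_grad():
    logits = net(torch.randn(1, 3, 512, 512))  # dummy NCHW batch
probs = torch.softmax(logits, dim=1)           # per-class probabilities
print(probs.shape)                             # expected: torch.Size([1, 3, 512, 512])
```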
35 | 
36 | ## 2. Installing the repository
37 | 
38 | When installing the repository, do not forget to initialize the submodule by running the `git submodule update --init --recursive` command.
39 | 
40 | ## 3. Using the EfficientUNet++ independently
41 | 
42 | The EfficientUNet++ was implemented as an extension of the Segmentation Models PyTorch repository by [Pavel Yakubovskiy](https://github.com/qubvel). Currently, it is a pending pull request. For detailed information on library installation and model usage, visit the [Read The Docs Project Page](https://segmentation-models-pytorch.readthedocs.io/en/latest/) or the [repository's README](https://github.com/jlcsilva/segmentation_models.pytorch/blob/master/README.md).
43 | 
44 | ## 4. Training the EfficientUNet++
45 | 
46 | It was not possible to obtain the necessary clearances to release the full coronary artery segmentation dataset. Therefore, to enable other researchers and programmers to test our architecture, we provide a toy dataset with 6 coronary angiography images, and a loader for the retinal vessel segmentation [DRIVE](https://drive.grand-challenge.org/) dataset. We have not yet tested the EfficientUNet++ on the retinal vessel segmentation task, but we expect it to perform well when coupled with the right encoder.
47 | 
48 | To train an EfficientUNet++ model on the coronary artery segmentation toy dataset using the EfficientNet-B0 encoder, run:
49 | 
50 | ```
51 | python train.py -d Coronary -enc timm-efficientnet-b0
52 | ```
53 | 
54 | To train the same model on the DRIVE dataset:
55 | 
56 | ```
57 | python train.py -d DRIVE -enc timm-efficientnet-b0
58 | ```
59 | 
60 | For more details on program arguments, run:
61 | 
62 | ```
63 | python train.py -h
64 | ```
65 | 
66 | To consult the available encoder architectures, visit the [Read The Docs Project Page](https://segmentation-models-pytorch.readthedocs.io/en/latest/).
67 | 
68 | ## 5. Inference
69 | 
70 | To perform inference, run the following command:
71 | 
72 | ```
73 | python predict.py -d <dataset> -enc <encoder> -i <input images> -m <model file>
74 | ```
75 | 
76 | where `<dataset>` can be either Coronary or DRIVE.
77 | 
78 | ## 6. Citing 📝
79 | 
80 | If you find our work useful and use it or draw inspiration from it in your research, please consider citing it.
81 | 
82 | ```
83 | @misc{silva2021encoderdecoder,
84 |     title={Encoder-Decoder Architectures for Clinically Relevant Coronary Artery Segmentation},
85 |     author={João Lourenço Silva and Miguel Nobre Menezes and Tiago Rodrigues and Beatriz Silva and Fausto J. Pinto and Arlindo L. Oliveira},
86 |     year={2021},
87 |     eprint={2106.11447},
88 |     archivePrefix={arXiv},
89 |     primaryClass={eess.IV}
90 | }
91 | ```
92 | 
93 | ## 7. License 🛡️
94 | 
95 | Distributed under the [MIT License](https://github.com/jlcsilva/EfficientUNetPlusPlus/blob/main/LICENSE).
96 | 
--------------------------------------------------------------------------------
/doc_images/block_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/doc_images/block_comparison.png
--------------------------------------------------------------------------------
/doc_images/decoder_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/doc_images/decoder_comparison.png
--------------------------------------------------------------------------------
/doc_images/flops_pareto.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jlcsilva/EfficientUNetPlusPlus/a68471d921426219b22bba1077f46869f9515e0d/doc_images/flops_pareto.png
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from tqdm import tqdm
4 | 
5 | from metrics import dice_loss
6 | 
7 | def eval_net(net, loader, device, n_classes=3):
8 |     net.eval()
9 |     mask_type = torch.float32 if n_classes == 1 else torch.long
10 |     n_val = len(loader)
11 |     tot = 0
12 | 
13 |     with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
14 |         for batch in loader:
15 |             imgs, true_masks = batch['image'][0], batch['mask'][0]
16 |             imgs = imgs.to(device=device, dtype=torch.float32)
17 |             true_masks = true_masks.to(device=device, dtype=mask_type)
18 | 
19 |             with torch.no_grad():
20 |                 mask_pred = net(imgs)
21 | 
22 |             if n_classes > 1:
23 |                 true_masks = true_masks.squeeze(1)
24 |                 tot += dice_loss(mask_pred, true_masks, use_weights=True).item()
25 |             else:
26 |                 tot += dice_loss(mask_pred, true_masks, use_weights=False).item()
27 |             pbar.update()
28 | 
29 |     net.train()
30 |     return tot / n_val
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | 
4 | """
5 | Implements metrics to be used as loss functions or performance evaluation criteria.
6 | """
7 | def dice_loss(input: torch.FloatTensor, target: torch.LongTensor, use_weights: bool = False, k: float = 0, eps: float = 0.0001):
8 |     """
9 |     Returns the Generalized Dice Loss of a batch, computed from the input and target tensors. In case `use_weights` \
10 |     is specified and is `True`, the computation of the loss takes the class weights into account.
11 | 
12 |     Args:
13 |         input (torch.FloatTensor): NCHW tensor containing the probabilities predicted for each class.
14 |         target (torch.LongTensor): NHW tensor of class indices, containing the ground truth segmentation mask; it is one-hot encoded internally when there is more than one class.
15 |         use_weights (bool): specifies whether to use class weights in the computation or not.
16 |         k (float): weight for the pGD function. Default is 0 for the ordinary dice loss.
17 |     """
18 |     if input.is_cuda:
19 |         s = torch.FloatTensor(1).cuda().zero_()
20 |     else:
21 |         s = torch.FloatTensor(1).zero_()
22 | 
23 |     # Multiple class case
24 |     n_classes = input.size()[1]
25 |     if n_classes != 1:
26 |         # Convert target to one hot encoding
27 |         target = F.one_hot(target, n_classes).squeeze()
28 |         if target.ndim == 3:
29 |             target = target.unsqueeze(0)
30 |         target = torch.transpose(torch.transpose(target, 2, 3), 1, 2).type(torch.FloatTensor).to(input.device).contiguous()
31 |         input = torch.softmax(input, dim=1)
32 |     else:
33 |         input = torch.sigmoid(input)
34 | 
35 |     class_weights = None
36 |     for i, c in enumerate(zip(input, target)):
37 |         if use_weights:
38 |             class_weights = torch.pow(torch.sum(c[1], (1,2)) + eps, -2)
39 |         s = s + __dice_loss(c[0], c[1], class_weights, k=k)
40 | 
41 |     return s / (i + 1)
42 | 
43 | def __dice_loss(input: torch.FloatTensor, target: torch.LongTensor, weights: torch.FloatTensor = None, k: float = 0, eps: float = 0.0001):
44 |     """
45 |     Returns the Generalized Dice Loss associated with the input and target tensors, and with the input weights,\
46 |     in case they are specified.
47 | 
48 |     Args:
49 |         input (torch.FloatTensor): CHW tensor containing the class probabilities predicted for each pixel.
50 |         target (torch.LongTensor): CHW one-hot encoded tensor, containing the ground truth segmentation mask.
51 |         weights (torch.FloatTensor): 2D tensor of size C, containing the weight of each class, if specified.
52 |         k (float): weight for the pGD function. Default is 0 for the ordinary dice loss.
53 |     """
54 |     n_classes = input.size()[0]
55 | 
56 |     if weights is not None:
57 |         # Accumulate the weighted intersection and union over all classes
58 |         intersection = sum((input[c] * target[c] * weights[c]).sum() for c in range(n_classes))
59 |         union = sum((weights[c] * (input[c] + target[c])).sum() for c in range(n_classes)) + eps
60 |     else:
61 |         intersection = torch.dot(input.view(-1), target.view(-1))
62 |         union = torch.sum(input) + torch.sum(target) + eps
63 | 
64 |     gd = (2 * intersection.float() + eps) / union.float()
65 |     return 1 - (gd / (1 + k*(1-gd)))
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | 
5 | import torch
6 | import torch.nn as nn
7 | from PIL import Image
8 | from torchvision import transforms
9 | 
10 | from utils.dataset import CoronaryArterySegmentationDataset, RetinaSegmentationDataset
11 | import segmentation_models_pytorch.segmentation_models_pytorch as smp
12 | 
13 | from torch.backends import cudnn
14 | 
15 | def predict_img(net, dataset_class, full_img, device, scale_factor=1, n_classes=3):
16 |     net.eval()
17 | 
18 |     img = torch.from_numpy(dataset_class.preprocess(full_img, scale_factor))
19 | 
20 |     img = img.unsqueeze(0)
21 |     img = img.to(device=device, dtype=torch.float32)
22 | 
23 |     with torch.no_grad():
24 |         output = net(img)
25 | 
26 |         if n_classes > 1:
27 |             probs = torch.softmax(output, dim=1)
28 |         else:
29 |             probs = torch.sigmoid(output)
30 | 
31 |         probs = probs.squeeze(0)
32 | 
33 |         tf = transforms.Compose(
34 |             [
35 |                 transforms.ToPILImage(),
36 |                 transforms.Resize(full_img.size[1]),
37 |                 transforms.ToTensor()
38 |             ]
39 |         )
40 | 
41 |         full_mask = tf(probs.cpu())
42 | 
43 |     if n_classes > 1:
44 |         return dataset_class.one_hot2mask(full_mask)
45 |     else:
46 |         return full_mask > 0.5
47 | 
48 | 
49 | def get_args():
50 |     parser = argparse.ArgumentParser(description='Predict masks from input images', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
51 |     parser.add_argument('-d', '--dataset', type=str, help='Specifies the dataset to be used', dest='dataset', required=True)
52 |     parser.add_argument('--model', '-m', default='MODEL.pth', metavar='FILE', help="Specify the file in which the model is stored")
53 |     parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', help='Filenames of input images', required=True)
54 |     parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+', help='Filenames of output images')
55 |     parser.add_argument('--scale', '-s', type=float, help="Scale factor for the input images", default=1)
56 |     parser.add_argument('-enc', '--encoder', metavar='ENC', type=str, default='timm-efficientnet-b0', help='Encoder to be used', dest='encoder')
57 |     return parser.parse_args()
58 | 
59 | 
60 | def get_output_filenames(args, dataset='DRIVE'):
61 |     """
62 |     Validates and/or computes the output paths of the segmentation masks to be generated.
63 |     """
64 |     in_files = args.input
65 |     out_files = []
66 | 
67 |     if not args.output:
68 |         for f in in_files:
69 |             if dataset == 'DRIVE':
70 |                 dest_f = f.replace("images", "mask")
71 |                 dest_f = dest_f.replace("_test.tif", ".png")
72 |             elif dataset == 'Coronary':
73 |                 dest_f = f.replace("imgs", "pred")
74 |             else:
75 |                 dest_f = f.replace(".", "mask.")
76 |             dest_dir, _ = os.path.split(dest_f)
77 |             # Create destination directory
78 |             if not os.path.exists(dest_dir):
79 |                 os.makedirs(dest_dir)
80 |             out_files.append(dest_f)
81 |     elif len(in_files) != len(args.output):
82 |         logging.error("Input files and output files are not of the same length")
83 |         raise SystemExit()
84 |     else:
85 |         out_files = args.output
86 | 
87 |     return out_files
88 | 
89 | if __name__ == "__main__":
90 |     logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
91 |     args = get_args()
92 |     in_files = args.input
93 |     out_files = get_output_filenames(args, args.dataset)
94 | 
95 |     # Number of classes
96 |     if args.dataset == 'DRIVE':
97 |         n_classes = 2
98 |         dataset_class = RetinaSegmentationDataset
99 |     elif args.dataset == 'Coronary':
100 |         n_classes = 3
101 |         dataset_class = CoronaryArterySegmentationDataset
102 |     else:
103 |         print("Invalid dataset")
104 |         exit()
105 |     net = smp.EfficientUnetPlusPlus(encoder_name=args.encoder, encoder_weights="imagenet", in_channels=3, classes=n_classes)
106 |     net = nn.DataParallel(net)
107 | 
108 |     logging.info("Loading model {}".format(args.model))
109 | 
110 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
111 |     logging.info(f'Using device {device}')
112 |     net.to(device=device)
113 |     # faster convolutions, but more memory
114 |     cudnn.benchmark = True
115 |     net.load_state_dict(torch.load(args.model, map_location=device))
116 | 
117 |     logging.info("Model loaded!")
118 | 
119 |     if device.type == 'cuda':
120 |         score = torch.FloatTensor(1).cuda().zero_()
121 |         weighted_score = torch.FloatTensor(1).cuda().zero_()
122 |     else:
123 |         score = torch.FloatTensor(1).zero_()
124 |         weighted_score = torch.FloatTensor(1).zero_()
125 | 
126 | 
127 |     for i, fn in enumerate(in_files):
128 |         if args.dataset == 'DRIVE':
129 |             orig_image = Image.open(fn)
130 |             w, h = orig_image.size
131 |             image = Image.new('RGB', (608, 608), (0, 0, 0))
132 |             image.paste(orig_image, (0, 0))
133 |         elif args.dataset == 'Coronary':
134 |             image = Image.open(fn).convert(mode='RGB')
135 |         mask = predict_img(net=net, dataset_class=dataset_class, full_img=image, scale_factor=args.scale, device=device, n_classes=n_classes)
136 |         result = dataset_class.mask2image(mask)
137 |         if args.dataset == 'DRIVE':
138 |             result = result.crop((0, 0, w, h))
139 | 
result.save(out_files[i]) 140 | logging.info("Mask saved to {}".format(out_files[i])) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.13.0 2 | cachetools==4.2.2 3 | certifi==2021.5.30 4 | chardet==4.0.0 5 | efficientnet-pytorch==0.6.3 6 | google-auth==1.31.0 7 | google-auth-oauthlib==0.4.4 8 | grpcio==1.38.0 9 | idna==2.10 10 | kornia==0.5.4 11 | Markdown==3.3.4 12 | munch==2.5.0 13 | numpy==1.20.3 14 | oauthlib==3.1.1 15 | Pillow==8.2.0 16 | pretrainedmodels==0.7.4 17 | protobuf==3.17.3 18 | pyasn1==0.4.8 19 | pyasn1-modules==0.2.8 20 | requests==2.25.1 21 | requests-oauthlib==1.3.0 22 | rsa==4.7.2 23 | six==1.16.0 24 | tensorboard==2.5.0 25 | tensorboard-data-server==0.6.1 26 | tensorboard-plugin-wit==1.8.0 27 | timm==0.3.2 28 | torch==1.7.1+cu101 29 | torchaudio==0.7.2 30 | torchvision==0.8.2+cu101 31 | tqdm==4.61.1 32 | typing-extensions==3.10.0.0 33 | urllib3==1.26.5 34 | Werkzeug==2.0.1 35 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import sys 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch import optim 9 | from tqdm import tqdm 10 | from metrics import dice_loss 11 | from eval import eval_net 12 | 13 | from torch.utils.tensorboard import SummaryWriter 14 | from utils.dataset import CoronaryArterySegmentationDataset, RetinaSegmentationDataset 15 | from torch.utils.data import DataLoader 16 | from kornia.losses import focal_loss 17 | 18 | import segmentation_models_pytorch.segmentation_models_pytorch as smp 19 | 20 | def train_net(net, 21 | device, 22 | training_set, 23 | validation_set, 24 | dir_checkpoint, 25 | epochs=150, 26 | batch_size=2, 27 | lr=0.001, 28 | save_cp=True, 29 | img_scale=1, 30 | n_classes=3, 31 | n_channels=3, 32 | augmentation_ratio = 0): 33 | 34 | train = training_set 35 | val = validation_set 36 | n_train = len(train) 37 | n_val = len(val) 38 | train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True) 39 | val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True) 40 | 41 | # Sets the effective batch size according to the batch size and the data augmentation ratio 42 | batch_size = (1 + augmentation_ratio)*batch_size 43 | 44 | # Prepares the summary file 45 | writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}') 46 | global_step = 0 47 | 48 | logging.info(f'''Starting training: 49 | Epochs: {epochs} 50 | Batch size: {batch_size} 51 | Learning rate: {lr} 52 | Training size: {n_train} 53 | Validation size: {n_val} 54 | Checkpoints: {save_cp} 55 | Device: {device.type} 56 | Images scaling: {img_scale} 57 | Augmentation ratio: {augmentation_ratio} 58 | ''') 59 | 60 | # Choose the optimizer and scheduler 61 | optimizer = optim.Adam(net.parameters(), lr=lr) 62 | scheduler = optim.lr_scheduler.StepLR(optimizer, epochs//3, gamma=0.1, verbose=True) 63 | 64 | # Train loop 65 | for epoch in range(epochs): 66 | net.train() 67 | epoch_loss = 0 68 | with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar: 69 | for batch in train_loader: 70 | imgs = batch['image'] 71 | true_masks = batch['mask'] 72 | # DataLoaders return lists of tensors. 
# TODO: Concatenate the lists inside the DataLoaders
73 |                 imgs = torch.cat(imgs, dim=0)
74 |                 true_masks = torch.cat(true_masks, dim=0)
75 | 
76 |                 assert imgs.shape[1] == n_channels, \
77 |                     f'Network has been defined with {n_channels} input channels, ' \
78 |                     f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
79 |                     'the images are loaded correctly.'
80 | 
81 |                 imgs = imgs.to(device=device, dtype=torch.float32)
82 |                 mask_type = torch.float32 if n_classes == 1 else torch.long
83 |                 true_masks = true_masks.to(device=device, dtype=mask_type)
84 | 
85 |                 masks_pred = net(imgs)
86 | 
87 |                 # Compute the combined focal + Dice loss
88 |                 loss = focal_loss(masks_pred, true_masks.squeeze(1), alpha=0.25, gamma=2, reduction='mean').unsqueeze(0)
89 |                 loss += dice_loss(masks_pred, true_masks.squeeze(1), True, k=0.75)
90 | 
91 |                 epoch_loss += loss.item()
92 |                 writer.add_scalar('Loss/train', loss.item(), global_step)
93 |                 pbar.set_postfix(**{'loss (batch)': loss.item()})
94 | 
95 |                 optimizer.zero_grad()
96 |                 loss.backward()
97 |                 nn.utils.clip_grad_value_(net.parameters(), 0.1)
98 |                 optimizer.step()
99 | 
100 |                 pbar.update(imgs.shape[0] // (1 + augmentation_ratio))
101 |                 global_step += 1
102 |                 if global_step % (n_train // (batch_size // (1 + augmentation_ratio))) == 0:
103 |                     for tag, value in net.named_parameters():
104 |                         tag = tag.replace('.', '/')
105 | 
106 |                         try:
107 |                             writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
108 |                             writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)
109 |                         except Exception:
110 |                             pass
111 | 
112 |         epoch_score = eval_net(net, train_loader, device)
113 |         val_score = eval_net(net, val_loader, device)
114 |         writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)
115 | 
116 |         if n_classes > 1:
117 |             logging.info('Validation loss: {}'.format(val_score))
118 |             writer.add_scalar('Generalized dice loss/train', epoch_score, global_step)
119 |             writer.add_scalar('Generalized dice loss/validation', val_score, global_step)
120 |         else:
121 |             logging.info('Validation loss: {}'.format(val_score))
122 |             writer.add_scalar('Dice loss/train', epoch_score, global_step)
123 |             writer.add_scalar('Dice loss/validation', val_score, global_step)
124 |         scheduler.step()
125 |         if save_cp:
126 |             try:
127 |                 os.mkdir(dir_checkpoint)
128 |                 logging.info('Created checkpoint directory')
129 |             except OSError:
130 |                 pass
131 |             torch.save(net.state_dict(),
132 |                        dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
133 |             logging.info(f'Checkpoint {epoch + 1} saved!')
134 |     writer.close()
135 | 
136 | def get_args():
137 |     parser = argparse.ArgumentParser(description='EfficientUNet++ train script', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
138 |     parser.add_argument('-d', '--dataset', type=str, help='Specifies the dataset to be used', dest='dataset', required=True)
139 |     parser.add_argument('-ti', '--training-images-dir', type=str, default=None, help='Training images directory', dest='train_img_dir')
140 |     parser.add_argument('-tm', '--training-masks-dir', type=str, default=None, help='Training masks directory', dest='train_mask_dir')
141 |     parser.add_argument('-vi', '--validation-images-dir', type=str, default=None, help='Validation images directory', dest='val_img_dir')
142 |     parser.add_argument('-vm', '--validation-masks-dir', type=str, default=None, help='Validation masks directory', dest='val_mask_dir')
143 |     parser.add_argument('-enc', '--encoder', metavar='ENC', type=str, default='timm-efficientnet-b0', help='Encoder to be used', dest='encoder')
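    # --- Editor's note: a hedged sketch of the combined loss computed in train_net above ---
    # kornia's focal_loss is added to the Dice term from metrics.py. Standalone, with the
    # dice_loss signature inferred from its call above (k is assumed to be a weighting knob):
    #   import torch
    #   from kornia.losses import focal_loss
    #   from metrics import dice_loss
    #   logits = torch.randn(2, 3, 64, 64)            # NCHW class scores
    #   target = torch.randint(0, 3, (2, 1, 64, 64))  # N1HW integer labels
    #   loss = focal_loss(logits, target.squeeze(1), alpha=0.25, gamma=2, reduction='mean')
    #   loss = loss + dice_loss(logits, target.squeeze(1), True, k=0.75)
    # ---------------------------------------------------------------------------------------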
144 |     parser.add_argument('-e', '--epochs', metavar='E', type=int, default=150, help='Number of epochs', dest='epochs')
145 |     parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=1, help='Batch size', dest='batchsize')
146 |     parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.001, help='Learning rate', dest='lr')
147 |     parser.add_argument('-f', '--load', type=str, default=False, help='Load model from a .pth file', dest='load')
148 |     parser.add_argument('-s', '--scale', metavar='S', type=float, default=1, help='Downscaling factor of the images', dest='scale')
149 |     parser.add_argument('-a', '--augmentation-ratio', metavar='AR', type=int, default=0, help='Number of augmentations to generate for each image in the dataset', dest='augmentation_ratio')
150 |     parser.add_argument('-c', '--dir_checkpoint', type=str, default='checkpoints/', help='Directory to save the checkpoints', dest='dir_checkpoint')
151 |     return parser.parse_args()
152 | 
153 | if __name__ == '__main__':
154 |     logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
155 |     args = get_args()
156 | 
157 |     # Determine device
158 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
159 |     logging.info(f'Using device {device}')
160 | 
161 |     # Number of classes depends on the dataset; fail fast on an unknown one
162 |     if args.dataset == 'DRIVE':
163 |         n_classes = 2
164 |     elif args.dataset == 'Coronary':
165 |         n_classes = 3
166 |     else: raise SystemExit('Invalid dataset')  # n_classes is needed below
167 |     # Instantiate EfficientUNet++ with the specified encoder
168 |     net = smp.EfficientUnetPlusPlus(encoder_name=args.encoder, encoder_weights="imagenet", in_channels=3, classes=n_classes)
169 | 
170 |     # Freeze the encoder weights
171 |     net.encoder.eval()
172 |     for m in net.encoder.modules():
173 |         m.requires_grad_(False)  # requires_grad_ is a method; assigning to it would not freeze anything
174 | 
175 |     # Distribute training over GPUs
176 |     net = nn.DataParallel(net)
177 | 
178 |     # Load weights from file
179 |     if args.load:
180 |         net.load_state_dict(
181 |             torch.load(args.load, map_location=device)
182 |         )
183 |         logging.info(f'Model loaded from {args.load}')
184 | 
185 |     net.to(device=device)
186 |     # Faster convolutions, but more memory usage
187 |     #cudnn.benchmark = True
188 | 
189 |     # Instantiate datasets
190 |     if args.dataset == 'DRIVE':
191 |         training_set = RetinaSegmentationDataset(args.train_img_dir if args.train_img_dir is not None else 'DRIVE/training/images/',
192 |                                                  args.train_mask_dir if args.train_mask_dir is not None else 'DRIVE/training/1st_manual/', args.scale,
193 |                                                  augmentation_ratio=args.augmentation_ratio, crop_size=512)
194 |         validation_set = RetinaSegmentationDataset(args.val_img_dir if args.val_img_dir is not None else 'DRIVE/validation/images/',
195 |                                                    args.val_mask_dir if args.val_mask_dir is not None else 'DRIVE/validation/1st_manual/', args.scale)
196 |         dataset_class = RetinaSegmentationDataset
197 |     elif args.dataset == 'Coronary':
198 |         training_set = CoronaryArterySegmentationDataset(args.train_img_dir if args.train_img_dir is not None else 'Coronary/train/imgs/',
199 |                                                          args.train_mask_dir if args.train_mask_dir is not None else 'Coronary/train/masks/', args.scale,
200 |                                                          augmentation_ratio=args.augmentation_ratio)  # CoronaryArterySegmentationDataset takes no crop_size argument
201 |         validation_set = CoronaryArterySegmentationDataset(args.val_img_dir if args.val_img_dir is not None else 'Coronary/val/imgs/',
202 |                                                            args.val_mask_dir if args.val_mask_dir is not None else 'Coronary/val/masks/', args.scale, mask_suffix='a')
203 |         dataset_class = CoronaryArterySegmentationDataset
204 |     else:
205 |         logging.error("Invalid dataset")
206 |         raise SystemExit()
207 | 
208 |     try:
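        # --- Editor's note: a hedged aside on the encoder freezing above (lines 170-173) ---
        # nn.Module.requires_grad_(False) recursively disables gradients for all encoder
        # parameters; an equivalent one-liner would be:
        #   net.encoder.requires_grad_(False)
        # net.encoder.eval() additionally fixes BatchNorm statistics and disables dropout,
        # but note that train_net calls net.train() every epoch, which puts the encoder
        # back in training mode; only the requires_grad setting persists.
        # -------------------------------------------------------------------------------------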
209 |         train_net(net=net,
210 |                   device=device,
211 |                   training_set=training_set,
212 |                   validation_set=validation_set,
213 |                   dir_checkpoint=args.dir_checkpoint,
214 |                   epochs=args.epochs,
215 |                   batch_size=args.batchsize,
216 |                   lr=args.lr,
217 |                   img_scale=args.scale,
218 |                   n_classes=n_classes,
219 |                   n_channels=3,
220 |                   augmentation_ratio=args.augmentation_ratio)
221 |     except KeyboardInterrupt:
222 |         torch.save(net.state_dict(), 'INTERRUPTED.pth')
223 |         logging.info('Saved interrupt')
224 |         try:
225 |             sys.exit(0)
226 |         except SystemExit:
227 |             os._exit(0)
--------------------------------------------------------------------------------
/utils/augment.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageEnhance, ImageOps
2 | import numpy as np
3 | import random
4 | from torchvision import transforms
5 | 
6 | class TNetPolicy(object):
7 |     """
8 |     Applies the augmentation policy used in Jun et al.'s coronary artery segmentation T-Net
9 |     https://arxiv.org/abs/1905.04197. As described by the authors, they first zoom in or out at a
10 |     random ratio within +/- 20%. Then, the image is shifted, horizontally and vertically, at a random
11 |     ratio within +/- 20% of the image size (512 x 512). Then, the angiography is rotated between
12 |     +/- 30 degrees. The rotation angle is not larger because actual angiographies do not deviate much
13 |     from this range. Finally, because the brightness of the angiography image can vary, the brightness is
14 |     also changed within +/- 40% at random rates.
15 |     Referred to in the results as Aug1.
16 |     Example:
17 |         >>> policy = TNetPolicy()
18 |         >>> transformed_img, transformed_mask = policy(image, mask)
19 |     """
20 |     def __init__(self, scale_ranges=[0.8, 1.2], img_size=[512, 512], translate=[0.2, 0.2], rotation=[-30, 30], brightness=0.4):
21 |         self.scale_ranges = scale_ranges
22 |         self.img_size = img_size
23 |         self.translate = translate
24 |         self.rotation = rotation
25 |         self.brightness = brightness
26 | 
27 |     def __call__(self, image, mask=None):
28 |         tf_mask = None
29 |         tf_list = list()  # list of transformations (currently unused)
30 | 
31 |         # Random zoom-in or zoom-out of -20% to 20%
32 |         params = transforms.RandomAffine.get_params(degrees=[0, 0], translate=[0, 0], \
33 |             scale_ranges=self.scale_ranges, img_size=self.img_size, shears=[0, 0])
34 |         tf_image = transforms.functional.affine(image, params[0], params[1], params[2], params[3])
35 |         if mask is not None:
36 |             tf_mask = transforms.functional.affine(mask, params[0], params[1], params[2], params[3])
37 | 
38 |         # Random horizontal and vertical shift of -20% to 20%
39 |         params = transforms.RandomAffine.get_params(degrees=[0, 0], translate=self.translate, \
40 |             scale_ranges=[1, 1], img_size=self.img_size, shears=[0, 0])
41 |         tf_image = transforms.functional.affine(tf_image, params[0], params[1], params[2], params[3])
42 |         if mask is not None:
43 |             tf_mask = transforms.functional.affine(tf_mask, params[0], params[1], params[2], params[3])
44 | 
45 |         # Random rotation of -30 to 30 degrees
46 |         angle = transforms.RandomRotation.get_params(self.rotation)
47 |         tf_image = transforms.functional.rotate(tf_image, angle)
48 |         if mask is not None:
49 |             tf_mask = transforms.functional.rotate(tf_mask, angle)
50 | 
51 |         # Random brightness change of -40% to 40%
52 |         tf = transforms.ColorJitter(brightness=self.brightness)
53 |         tf_image = tf(tf_image)
54 | 
55 |         if mask is not None:
56 |             return (tf_image, tf_mask)
57 |         else:
58 |             return tf_image
59 | 
60 |     def __repr__(self):
61 |         return "TNet Coronary Artery Segmentation Augmentation Policy"
62 | 
63 | class RetinaPolicy(object):
64 |     def __init__(self, scale_ranges=[1, 1.1], img_size=[512, 512], translate=[0.1, 0.1], rotation=[-20, 20], crop_dims=[480, 480], brightness=None):
65 |         self.scale_ranges = scale_ranges
66 |         self.img_size = img_size
67 |         self.translate = translate
68 |         self.rotation = rotation
69 |         self.brightness = brightness
70 |         self.crop_dims = crop_dims
71 | 
72 |     def __call__(self, image, mask=None):
73 |         tf_mask = None
74 | 
75 |         # Random crop
76 |         i, j, h, w = transforms.RandomCrop.get_params(image, self.crop_dims)
77 |         tf_image = transforms.functional.crop(image, i, j, h, w)
78 |         if mask is not None:
79 |             tf_mask = transforms.functional.crop(mask, i, j, h, w)
80 | 
81 |         # Random rotation of -20 to 20 degrees
82 |         angle = transforms.RandomRotation.get_params(self.rotation)
83 |         tf_image = transforms.functional.rotate(tf_image, angle)
84 |         if mask is not None:
85 |             tf_mask = transforms.functional.rotate(tf_mask, angle)
86 | 
87 |         # Random horizontal and vertical shift of -10% to 10%
88 |         params = transforms.RandomAffine.get_params(degrees=[0, 0], translate=self.translate, \
89 |             scale_ranges=[1, 1], img_size=self.img_size, shears=[0, 0])
90 |         tf_image = transforms.functional.affine(tf_image, params[0], params[1], params[2], params[3])
91 |         if mask is not None:
92 |             tf_mask = transforms.functional.affine(tf_mask, params[0], params[1], params[2], params[3])
93 | 
94 |         # TODO: -10% to 10% may make more sense, due to the existence of images with black padding borders
95 |         # Random zoom-in of 0% to 10%
96 |         params = transforms.RandomAffine.get_params(degrees=[0, 0], translate=[0, 0], \
97 |             scale_ranges=self.scale_ranges, img_size=self.img_size, shears=[0, 0])
98 |         tf_image = transforms.functional.affine(tf_image, params[0], params[1], params[2], params[3])
99 |         if mask is not None:
100 |             tf_mask = transforms.functional.affine(tf_mask, params[0], params[1], params[2], params[3])
101 | 
102 |         # TODO: change brightness too
103 |         # Random brightness change
104 |         if self.brightness is not None:
105 |             tf = transforms.ColorJitter(brightness=self.brightness)
106 |             tf_image = tf(tf_image)
107 | 
108 |         if mask is not None:
109 |             return (tf_image, tf_mask)
110 |         else:
111 |             return tf_image
112 | 
113 |     def __repr__(self):
114 |         return "Retinal Vessel Segmentation Augmentation Policy"
115 | 
116 | class CoronaryPolicy(object):
117 |     def __init__(self, scale_ranges=[1, 1.1], img_size=[512, 512], translate=[0.1, 0.1], rotation=[-20, 20], brightness=None):
118 |         self.scale_ranges = scale_ranges
119 |         self.img_size = img_size
120 |         self.translate = translate
121 |         self.rotation = rotation
122 |         self.brightness = brightness
123 | 
124 |     def __call__(self, image, mask=None):
125 |         tf_mask = None
126 |         # Random rotation of -20 to 20 degrees
127 |         angle = transforms.RandomRotation.get_params(self.rotation)
128 |         tf_image = transforms.functional.rotate(image, angle)
129 |         if mask is not None:
130 |             tf_mask = transforms.functional.rotate(mask, angle)
131 | 
132 |         # Random horizontal and vertical shift of -10% to 10%
133 |         params = transforms.RandomAffine.get_params(degrees=[0, 0], translate=self.translate, \
134 |             scale_ranges=[1, 1], img_size=self.img_size, shears=[0, 0])
135 |         tf_image = transforms.functional.affine(tf_image, params[0], params[1], params[2], params[3])
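        # --- Editor's note: why get_params is used instead of the transform objects ---
        # Sampling the random parameters once and applying them with the functional
        # API keeps the image and its mask geometrically aligned, e.g.:
        #   params = transforms.RandomAffine.get_params(degrees=[0, 0], translate=[0.1, 0.1],
        #       scale_ranges=[1, 1], img_size=[512, 512], shears=[0, 0])
        #   img2 = transforms.functional.affine(img, *params)
        #   mask2 = transforms.functional.affine(mask, *params)  # identical displacement
        # Calling transforms.RandomAffine directly on each would re-sample per call.
        # -------------------------------------------------------------------------------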
136 |         if mask is not None:
137 |             tf_mask = transforms.functional.affine(tf_mask, params[0], params[1], params[2], params[3])
138 | 
139 |         # TODO: -10% to 10% may make more sense, due to the existence of images with black padding borders
140 |         # Random zoom-in of 0% to 10%
141 |         params = transforms.RandomAffine.get_params(degrees=[0, 0], translate=[0, 0], \
142 |             scale_ranges=self.scale_ranges, img_size=self.img_size, shears=[0, 0])
143 |         tf_image = transforms.functional.affine(tf_image, params[0], params[1], params[2], params[3])
144 |         if mask is not None:
145 |             tf_mask = transforms.functional.affine(tf_mask, params[0], params[1], params[2], params[3])
146 | 
147 |         # TODO: change brightness too
148 |         # Random brightness change
149 |         if self.brightness is not None:
150 |             tf = transforms.ColorJitter(brightness=self.brightness)
151 |             tf_image = tf(tf_image)
152 | 
153 |         if mask is not None:
154 |             return (tf_image, tf_mask)
155 |         else:
156 |             return tf_image
157 | 
158 |     def __repr__(self):
159 |         return "Coronary Artery Segmentation Augmentation Policy"
--------------------------------------------------------------------------------
/utils/dataset.py:
--------------------------------------------------------------------------------
1 | from os.path import splitext
2 | from os import listdir
3 | from typing import Dict, List
4 | import numpy as np
5 | from glob import glob
6 | import torch
7 | from torch.utils.data import Dataset
8 | import logging
9 | from PIL import Image
10 | import torch.nn.functional as F
11 | from utils.augment import *
12 | 
13 | r"""
14 | Defines the `BasicSegmentationDataset`, `RetinaSegmentationDataset` and `CoronaryArterySegmentationDataset` classes; the latter two extend \
15 | `BasicSegmentationDataset`, which itself extends `Dataset`. Each class defines the specific methods needed for data processing and a method :func:`__getitem__` to return samples.
16 | """
17 | 
18 | class BasicSegmentationDataset(Dataset):
19 |     r"""
20 |     Implements a basic dataset for segmentation tasks, with methods for image and mask scaling and normalization. \
21 |     The filenames of the segmentation ground truths must be equal to the filenames of the images to be segmented, \
22 |     except for a possible suffix.
23 | 
24 |     Args:
25 |         imgs_dir (str): path to the directory containing the images to be segmented.
26 |         masks_dir (str): path to the directory containing the segmentation ground truths.
27 |         scale (float, optional): image scale, between 0 and 1, to be used in the segmentation.
28 |         mask_suffix (str, optional): suffix to be added to an image's filename to obtain its
29 |             ground truth filename.
30 |     """
31 | 
32 |     def __init__(self, imgs_dir: str, masks_dir: str, scale: float = 1, mask_suffix: str = ''):
33 |         self.imgs_dir = imgs_dir
34 |         self.masks_dir = masks_dir
35 |         self.scale = scale
36 |         self.mask_suffix = mask_suffix
37 |         assert 0 < scale <= 1, 'Scale must be between 0 and 1'
38 | 
39 |         self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
40 |                     if not file.startswith('.')]
41 |         logging.info(f'Creating dataset with {len(self.ids)} examples')
42 | 
43 |     def __len__(self) -> int:
44 |         r"""
45 |         Returns the size of the dataset.
46 |         """
47 |         return len(self.ids)
48 | 
49 |     @classmethod
50 |     def preprocess(cls, pil_img: Image, scale: float) -> np.ndarray:
51 |         r"""
52 |         Preprocesses an `Image`, rescaling it and returning it as a NumPy array in
53 |         the CHW format.
54 | 
55 |         Args:
56 |             pil_img (Image): object of class `Image` to be preprocessed.
57 |             scale (float): image scale, between 0 and 1.
58 |         """
59 |         w, h = pil_img.size
60 |         newW, newH = int(scale * w), int(scale * h)
61 |         assert newW > 0 and newH > 0, 'Scale is too small'
62 |         pil_img = pil_img.resize((newW, newH))
63 | 
64 |         img_nd = np.array(pil_img)
65 | 
66 |         if len(img_nd.shape) == 2:
67 |             img_nd = np.expand_dims(img_nd, axis=2)
68 | 
69 |         # HWC to CHW
70 |         img_trans = img_nd.transpose((2, 0, 1))
71 |         if img_trans.max() > 1:
72 |             img_trans = img_trans / 255
73 | 
74 |         return img_trans
75 | 
76 |     def __getitem__(self, i) -> Dict[str, List[torch.FloatTensor]]:
77 |         r"""
78 |         Returns two tensors: an image and the corresponding mask.
79 |         """
80 |         idx = self.ids[i]
81 |         mask_file = glob(self.masks_dir + idx + self.mask_suffix + '.*')
82 |         img_file = glob(self.imgs_dir + idx + '.*')
83 | 
84 |         assert len(mask_file) == 1, \
85 |             f'Either no mask or multiple masks found for the ID {idx}: {mask_file}'
86 |         assert len(img_file) == 1, \
87 |             f'Either no image or multiple images found for the ID {idx}: {img_file}'
88 |         mask = Image.open(mask_file[0])
89 |         img = Image.open(img_file[0])
90 | 
91 |         assert img.size == mask.size, \
92 |             f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}'
93 | 
94 |         img = self.preprocess(img, self.scale)
95 |         mask = self.preprocess(mask, self.scale)
96 | 
97 |         return {
98 |             'image': [torch.from_numpy(img).type(torch.FloatTensor)],
99 |             'mask': [torch.from_numpy(mask).type(torch.FloatTensor)]
100 |         }
101 | 
102 | class RetinaSegmentationDataset(BasicSegmentationDataset):
103 |     r"""
104 |     Implements a dataset for the Retinal Vessel Segmentation task.
105 | 
106 |     Args:
107 |         imgs_dir (str): path to the directory containing the images to be segmented.
108 |         masks_dir (str): path to the directory containing the segmentation ground truths.
109 |         scale (float, optional): image scale, between 0 and 1, to be used in the segmentation.
110 |         augmentation_ratio (int, optional): number of augmentations to generate per image.
111 |         crop_size (int, optional): size of the square image to be fed to the model.
112 |         aug_policy (str, optional): data augmentation policy.
113 |     """
114 |     # Number of classes, including the background class
115 |     n_classes = 2
116 | 
117 |     # Maps mask grayscale values to mask class indices
118 |     gray2class_mapping = {
119 |         0: 0,
120 |         255: 1
121 |     }
122 | 
123 |     # Maps mask grayscale values to mask RGB values
124 |     gray2rgb_mapping = {
125 |         0: (0, 0, 0),
126 |         255: (255, 255, 255)
127 |     }
128 |     rgb2class_mapping = {
129 |         (0, 0, 0): 0,
130 |         (255, 255, 255): 1
131 |     }
132 | 
133 |     def __init__(self, imgs_dir: str, masks_dir: str, scale: float = 1, augmentation_ratio: int = 0, crop_size: int = 512, aug_policy: str = 'retina'):
134 |         super().__init__(imgs_dir, masks_dir, scale)
135 |         self.augmentation_ratio = augmentation_ratio
136 |         self.policy = aug_policy
137 |         self.crop_size = crop_size
138 | 
139 |     @classmethod
140 |     def mask_img2class_mask(cls, pil_mask: Image, scale: float) -> np.ndarray:
141 |         r"""
142 |         Preprocesses a grayscale `Image` containing a segmentation mask, rescaling it, converting its grayscale values \
143 |         to class indices and returning it as a NumPy array in the CHW format.
144 | 
145 |         Args:
146 |             pil_mask (Image): object of class `Image` to be preprocessed.
147 |             scale (float): image scale, between 0 and 1.
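
        Example (editor's sketch, assuming a 2x2 grayscale mask; values are simply divided by 255):
            >>> m = Image.fromarray(np.array([[0, 255], [255, 0]], dtype=np.uint8))
            >>> RetinaSegmentationDataset.mask_img2class_mask(m, 1.0)
            array([[[0., 1.],
                    [1., 0.]]])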
148 |         """
149 |         w, h = pil_mask.size
150 |         newW, newH = int(scale * w), int(scale * h)
151 |         assert newW > 0 and newH > 0, 'Scale is too small'
152 |         pil_mask = pil_mask.resize((newW, newH))
153 | 
154 |         if pil_mask.mode != "L":
155 |             pil_mask = pil_mask.convert(mode="L")
156 |         mask_nd = np.array(pil_mask)
157 | 
158 |         if len(mask_nd.shape) == 2:
159 |             mask_nd = np.expand_dims(mask_nd, axis=2)
160 | 
161 |         # HWC to CHW
162 |         mask = mask_nd.transpose((2, 0, 1))
163 |         mask = mask / 255
164 | 
165 |         return mask
166 | 
167 |     @classmethod
168 |     def one_hot2mask(cls, one_hot_mask: torch.FloatTensor, shape: str = 'CHW') -> np.ndarray:
169 |         r"""
170 |         Returns the single-channel mask (HW) corresponding to the CHW one-hot encoded one (NHW for NCHW input).
171 |         """
172 |         # Assuming tensor in CHW shape
173 |         if shape == 'CHW':
174 |             return np.argmax(one_hot_mask.detach().numpy(), axis=0)
175 |         elif shape == 'NCHW':
176 |             return np.argmax(one_hot_mask.detach().numpy(), axis=1)
177 |         return np.argmax(one_hot_mask.detach().numpy(), axis=0)
178 | 
179 |     @classmethod
180 |     def mask2one_hot(cls, mask_tensor: torch.LongTensor, output_shape: str = 'NHWC') -> torch.Tensor:
181 |         r"""
182 |         Converts the received `LongTensor`, of shape N1HW, to a one-hot encoded `LongTensor`, \
183 |         of shape NHWC by default, or NCHW if specified.
184 | 
185 |         Args:
186 |             mask_tensor (LongTensor): N1HW LongTensor to be one-hot encoded.
187 |             output_shape (str): NHWC or NCHW.
188 |         """
189 |         assert output_shape == 'NHWC' or output_shape == 'NCHW', 'Invalid output shape specified'
190 | 
191 |         # For N1HW input: one_hot yields N1HWC; squeezing the channel dim gives NHWC
192 |         if output_shape == 'NHWC':
193 |             return F.one_hot(mask_tensor, cls.n_classes).squeeze(1)
194 |         # For NHW input: one_hot yields NHWC; swapping axes gives NCHW
195 |         elif output_shape == 'NCHW':
196 |             return torch.transpose(torch.transpose(F.one_hot(mask_tensor, cls.n_classes), 2, 3), 1, 2)
197 | 
198 |     @classmethod
199 |     def class2gray(cls, mask: np.ndarray) -> np.ndarray:
200 |         r"""
201 |         Replaces the class labels in a NumPy-array mask with their grayscale values, according to `gray2class_mapping`.
202 |         """
203 |         assert len(cls.gray2class_mapping) == cls.n_classes, \
204 |             f'Number of class mappings - {len(cls.gray2class_mapping)} - should be the same as the number of classes - {cls.n_classes}'
205 |         for color, label in cls.gray2class_mapping.items():
206 |             mask[mask == label] = color
207 |         return mask
208 | 
209 |     @classmethod
210 |     def gray2rgb(cls, img: Image) -> Image:
211 |         r"""
212 |         Converts a grayscale image into an RGB one, according to gray2rgb_mapping.
213 |         """
214 |         rgb_img = Image.new("RGB", img.size)
215 |         for x in range(img.size[0]):
216 |             for y in range(img.size[1]):
217 |                 rgb_img.putpixel((x, y), cls.gray2rgb_mapping[img.getpixel((x, y))])
218 |         return rgb_img
219 | 
220 |     @classmethod
221 |     def mask2image(cls, mask: np.ndarray) -> Image:
222 |         r"""
223 |         Converts a one-channel mask (HW) with class indices into an RGB image, according to gray2class_mapping and gray2rgb_mapping.
224 |         """
225 |         return cls.gray2rgb(Image.fromarray(cls.class2gray(mask).astype(np.uint8)))
226 | 
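    # --- Editor's note: a hedged sketch of the one-hot round trip (not in the original file) ---
    # mask2one_hot and one_hot2mask invert each other for integer class masks:
    #   >>> m = torch.tensor([[[[0, 1], [1, 0]]]])          # N1HW = (1, 1, 2, 2)
    #   >>> oh = RetinaSegmentationDataset.mask2one_hot(m)  # NHWC = (1, 2, 2, 2)
    #   >>> RetinaSegmentationDataset.one_hot2mask(oh.squeeze(0).permute(2, 0, 1))
    #   array([[0, 1],
    #          [1, 0]])
    # --------------------------------------------------------------------------------------------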
227 |     def augment(self, image, mask, policy='retina', augmentation_ratio=0):
228 |         """
229 |         Returns a list with the original image and mask, plus augmented versions of them.
230 |         The number of augmented image/mask pairs is equal to the specified augmentation_ratio.
231 |         The augmentation policy is chosen by the policy argument.
232 |         """
233 |         tf_imgs = []
234 |         tf_masks = []
235 |         # Data augmentation
236 |         for i in range(augmentation_ratio):
237 |             # Select the policy
238 |             if policy == 'retina':
239 |                 aug_policy = RetinaPolicy(crop_dims=[self.crop_size, self.crop_size], brightness=[0.9, 1.1])
240 | 
241 |             # Apply the transformation
242 |             tf_image, tf_mask = aug_policy(image, mask)
243 | 
244 |             # Further process the images and masks
245 |             tf_image = self.preprocess(tf_image, self.scale)
246 |             tf_mask = self.mask_img2class_mask(tf_mask, self.scale)
247 |             tf_image = torch.from_numpy(tf_image).type(torch.FloatTensor)
248 |             tf_mask = torch.from_numpy(tf_mask).type(torch.FloatTensor)
249 |             tf_imgs.append(tf_image)
250 |             tf_masks.append(tf_mask)
251 | 
252 |         i, j, h, w = transforms.RandomCrop.get_params(image, [self.crop_size, self.crop_size])
253 |         image = transforms.functional.crop(image, i, j, h, w)
254 |         mask = transforms.functional.crop(mask, i, j, h, w)
255 |         image = self.preprocess(image, self.scale)
256 |         mask = self.mask_img2class_mask(mask, self.scale)
257 |         image = torch.from_numpy(image).type(torch.FloatTensor)
258 |         mask = torch.from_numpy(mask).type(torch.FloatTensor)
259 | 
260 |         tf_imgs.insert(0, image)
261 |         tf_masks.insert(0, mask)
262 | 
263 |         return (tf_imgs, tf_masks)
264 | 
265 |     def __getitem__(self, i) -> Dict[str, List[torch.FloatTensor]]:
266 |         r"""
267 |         Returns a dict with two lists of tensors: images of shape CHW and the corresponding masks of shape 1HW.
268 |         """
269 |         idx = self.ids[i]
270 |         mask_file = glob(self.masks_dir + idx.replace('training', 'manual1') + '.*')
271 |         img_file = glob(self.imgs_dir + idx + '.*')
272 | 
273 |         assert len(mask_file) == 1, \
274 |             f'Either no mask or multiple masks found for the ID {idx}: {mask_file}'
275 |         assert len(img_file) == 1, \
276 |             f'Either no image or multiple images found for the ID {idx}: {img_file}'
277 |         mask = Image.open(mask_file[0])
278 |         image = Image.open(img_file[0])
279 | 
280 |         assert image.size == mask.size, \
281 |             f'Image and mask {idx} should be the same size, but are {image.size} and {mask.size}'
282 | 
283 |         images, masks = self.augment(image, mask, policy=self.policy, augmentation_ratio=self.augmentation_ratio)
284 |         return {
285 |             'image': images,
286 |             'mask': masks
287 |         }
288 | 
289 | class CoronaryArterySegmentationDataset(BasicSegmentationDataset):
290 |     r"""
291 |     Implements a dataset for the Coronary Artery Segmentation task, with mappings between grayscale image values and class \
292 |     indices, between grayscale and RGB image values, and methods for the necessary conversions.
293 | 
294 |     Args:
295 |         imgs_dir (str): path to the directory containing the images to be segmented.
296 |         masks_dir (str): path to the directory containing the segmentation ground truths.
297 |         scale (float, optional): image scale, between 0 and 1, to be used in the segmentation.
298 |         mask_suffix (str, optional): suffix to be added to an image's filename to obtain its ground truth filename.
299 |         augmentation_ratio (int, optional): number of augmentations to generate per image.
300 |         aug_policy (str, optional): data augmentation policy.
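
        Example (editor's sketch; the default paths used by train.py):
            >>> ds = CoronaryArterySegmentationDataset('Coronary/train/imgs/', 'Coronary/train/masks/', augmentation_ratio=1)
            >>> sample = ds[0]
            >>> len(sample['image']), len(sample['mask'])  # original + 1 augmentation
            (2, 2)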
301 |     """
302 | 
303 |     # Maps mask grayscale values to mask class indices
304 |     gray2class_mapping = {
305 |         0: 0,
306 |         100: 1,
307 |         255: 2
308 |     }
309 | 
310 |     # Maps mask grayscale values to mask RGB values
311 |     gray2rgb_mapping = {
312 |         0: (0, 0, 0),
313 |         100: (255, 0, 0),
314 |         255: (255, 255, 255)
315 |     }
316 |     rgb2class_mapping = {
317 |         (0, 0, 0): 0,
318 |         (255, 0, 0): 1,
319 |         (255, 255, 255): 2
320 |     }
321 | 
322 |     # Total number of classes, including the background class
323 |     n_classes = 3
324 | 
325 |     def __init__(self, imgs_dir: str, masks_dir: str, scale: float = 1, mask_suffix: str = 'a', augmentation_ratio: int = 0, aug_policy: str = 'coronary'):
326 |         super().__init__(imgs_dir, masks_dir, scale, mask_suffix)
327 |         self.augmentation_ratio = augmentation_ratio
328 |         self.policy = aug_policy
329 | 
330 |     @classmethod
331 |     def mask_img2class_mask(cls, pil_mask: Image, scale: float) -> np.ndarray:
332 |         r"""
333 |         Preprocesses a grayscale `Image` containing a segmentation mask, rescaling it, converting its grayscale values \
334 |         to class indices and returning it as a NumPy array in the CHW format.
335 | 
336 |         Args:
337 |             pil_mask (Image): object of class `Image` to be preprocessed.
338 |             scale (float): image scale, between 0 and 1.
339 |         """
340 |         w, h = pil_mask.size
341 |         newW, newH = int(scale * w), int(scale * h)
342 |         assert newW > 0 and newH > 0, 'Scale is too small'
343 |         pil_mask = pil_mask.resize((newW, newH))
344 | 
345 |         if pil_mask.mode != "L":
346 |             pil_mask = pil_mask.convert(mode="L")
347 |         mask_nd = np.array(pil_mask)
348 | 
349 |         if len(mask_nd.shape) == 2:
350 |             mask_nd = np.expand_dims(mask_nd, axis=2)
351 | 
352 |         # HWC to CHW
353 |         mask = mask_nd.transpose((2, 0, 1))
354 |         mask = mask / 255
355 |         mask = mask * 2
356 |         mask = np.around(mask)  # grayscale 0 -> class 0, 100 -> 1, 255 -> 2
357 | 
358 |         return mask
359 | 
360 |     @classmethod
361 |     def mask2one_hot(cls, mask_tensor: torch.LongTensor, output_shape: str = 'NHWC') -> torch.Tensor:
362 |         r"""
363 |         Converts the received `LongTensor`, of shape N1HW, to a one-hot encoded `LongTensor`, \
364 |         of shape NHWC by default, or NCHW if specified.
365 | 
366 |         Args:
367 |             mask_tensor (LongTensor): N1HW LongTensor to be one-hot encoded.
368 |             output_shape (str): NHWC or NCHW.
369 |         """
370 |         assert output_shape == 'NHWC' or output_shape == 'NCHW', 'Invalid output shape specified'
371 | 
372 |         # For N1HW input: one_hot yields N1HWC; squeezing the channel dim gives NHWC
373 |         if output_shape == 'NHWC':
374 |             return F.one_hot(mask_tensor, cls.n_classes).squeeze(1)
375 |         # For NHW input: one_hot yields NHWC; swapping axes gives NCHW
376 |         elif output_shape == 'NCHW':
377 |             return torch.transpose(torch.transpose(F.one_hot(mask_tensor, cls.n_classes), 2, 3), 1, 2)
378 | 
379 |     @classmethod
380 |     def one_hot2mask(cls, one_hot_mask: torch.FloatTensor, shape: str = 'CHW') -> np.ndarray:
381 |         r"""
382 |         Returns the single-channel mask (HW) corresponding to the CHW one-hot encoded one (NHW for NCHW input).
383 |         """
384 |         # Assuming tensor in CHW shape
385 |         if shape == 'CHW':
386 |             return np.argmax(one_hot_mask.detach().numpy(), axis=0)
387 |         elif shape == 'NCHW':
388 |             return np.argmax(one_hot_mask.detach().numpy(), axis=1)
389 |         return np.argmax(one_hot_mask.detach().numpy(), axis=0)
390 | 
391 |     @classmethod
392 |     def class2gray(cls, mask: np.ndarray) -> np.ndarray:
393 |         r"""
394 |         Replaces the class labels in a NumPy-array mask with their grayscale values, according to `gray2class_mapping`.
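
        Example (editor's sketch; labels follow gray2class_mapping in reverse):
            >>> CoronaryArterySegmentationDataset.class2gray(np.array([[0, 1], [2, 1]]))
            array([[  0, 100],
                   [255, 100]])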
395 |         """
396 |         assert len(cls.gray2class_mapping) == cls.n_classes, \
397 |             f'Number of class mappings - {len(cls.gray2class_mapping)} - should be the same as the number of classes - {cls.n_classes}'
398 |         for color, label in cls.gray2class_mapping.items():
399 |             mask[mask == label] = color
400 |         return mask
401 | 
402 |     @classmethod
403 |     def gray2rgb(cls, img: Image) -> Image:
404 |         r"""
405 |         Converts a grayscale image into an RGB one, according to gray2rgb_mapping.
406 |         """
407 |         rgb_img = Image.new("RGB", img.size)
408 |         for x in range(img.size[0]):
409 |             for y in range(img.size[1]):
410 |                 rgb_img.putpixel((x, y), cls.gray2rgb_mapping[img.getpixel((x, y))])
411 |         return rgb_img
412 | 
413 |     @classmethod
414 |     def mask2image(cls, mask: np.ndarray) -> Image:
415 |         r"""
416 |         Converts a one-channel mask (HW) with class indices into an RGB image, according to gray2class_mapping and gray2rgb_mapping.
417 |         """
418 |         return cls.gray2rgb(Image.fromarray(cls.class2gray(mask).astype(np.uint8)))
419 | 
420 |     def augment(self, image, mask, policy='coronary', augmentation_ratio=0):
421 |         """
422 |         Returns a list with the original image and mask, plus augmented versions of them.
423 |         The number of augmented image/mask pairs is equal to the specified augmentation_ratio.
424 |         The augmentation policy is chosen by the policy argument.
425 |         """
426 |         tf_imgs = []
427 |         tf_masks = []
428 |         # Data augmentation
429 |         for i in range(augmentation_ratio):
430 |             # Select the policy
431 |             if policy == 'coronary':
432 |                 aug_policy = CoronaryPolicy()
433 |             elif policy == 'retina':
434 |                 aug_policy = RetinaPolicy()
435 |             elif policy == 'tnet':
436 |                 aug_policy = TNetPolicy()
437 | 
438 |             # Apply the transformation
439 |             tf_image, tf_mask = aug_policy(image, mask)
440 | 
441 |             # Further process the images and masks
442 |             tf_image = self.preprocess(tf_image, self.scale)
443 |             tf_mask = self.mask_img2class_mask(tf_mask, self.scale)
444 |             tf_image = torch.from_numpy(tf_image).type(torch.FloatTensor)
445 |             tf_mask = torch.from_numpy(tf_mask).type(torch.FloatTensor)
446 |             tf_imgs.append(tf_image)
447 |             tf_masks.append(tf_mask)
448 | 
449 |         image = self.preprocess(image, self.scale)
450 |         mask = self.mask_img2class_mask(mask, self.scale)
451 |         image = torch.from_numpy(image).type(torch.FloatTensor)
452 |         mask = torch.from_numpy(mask).type(torch.FloatTensor)
453 | 
454 |         tf_imgs.insert(0, image)
455 |         tf_masks.insert(0, mask)
456 | 
457 |         return (tf_imgs, tf_masks)
458 | 
459 |     def __getitem__(self, i) -> Dict[str, List[torch.FloatTensor]]:
460 |         r"""
461 |         Returns a dict with two lists of tensors: images of shape CHW and the corresponding masks of shape 1HW.
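
        Example (editor's sketch; image '5.png' pairs with mask '5a.png' via mask_suffix='a'):
            >>> ds = CoronaryArterySegmentationDataset('Coronary/test/imgs/', 'Coronary/test/masks/')
            >>> sample = ds[0]
            >>> len(sample['image']) == len(sample['mask']) == 1  # no augmentations requested
            True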
462 | """ 463 | idx = self.ids[i] 464 | mask_file = glob(self.masks_dir + idx + self.mask_suffix + '.*') 465 | img_file = glob(self.imgs_dir + idx + '.*') 466 | 467 | assert len(mask_file) == 1, \ 468 | f'Either no mask or multiple masks found for the ID {idx}: {mask_file}' 469 | assert len(img_file) == 1, \ 470 | f'Either no image or multiple images found for the ID {idx}: {img_file}' 471 | mask = Image.open(mask_file[0]) 472 | 473 | if mask.mode != 'L': 474 | mask = mask.convert(mode='L') 475 | image = Image.open(img_file[0]) 476 | 477 | assert image.size == mask.size, \ 478 | f'Image and mask {idx} should be the same size, but are {image.size} and {mask.size}' 479 | 480 | images, masks = self.augment(image, mask, policy = self.policy, augmentation_ratio = self.augmentation_ratio) 481 | return { 482 | 'image': images, 483 | 'mask': masks 484 | } --------------------------------------------------------------------------------
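# --- Editor's note: hedged usage sketch (not part of the repository) ---
# Based on the argparse definitions in train.py and predict.py, the scripts can
# presumably be invoked along these lines (checkpoint names are illustrative):
#
#   python train.py --dataset Coronary -e 150 -b 2 -l 0.001 -a 1 -c checkpoints/
#   python predict.py --dataset Coronary -m checkpoints/CP_epoch150.pth -i Coronary/test/imgs/5.png Coronary/test/imgs/6.png
# ------------------------------------------------------------------------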