├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── README.md ├── __init__.py ├── examples ├── human-seg_atten-unet-backbone_coco.ipynb ├── segmentation_unet-three-plus_oxford-iiit.ipynb └── user_guide_models.ipynb ├── keras_unet_collection ├── __init__.py ├── _backbone_zoo.py ├── _model_att_unet_2d.py ├── _model_r2_unet_2d.py ├── _model_resunet_a_2d.py ├── _model_swin_unet_2d.py ├── _model_transunet_2d.py ├── _model_u2net_2d.py ├── _model_unet_2d.py ├── _model_unet_3plus_2d.py ├── _model_unet_plus_2d.py ├── _model_vnet_2d.py ├── activations.py ├── backbones.py ├── base.py ├── layer_utils.py ├── losses.py ├── models.py ├── transformer_layers.py └── utils.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | 3 | # author's testing scripts 4 | *TEST* 5 | pyproject.toml 6 | setup.cfg 7 | setup.py 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | pip-wheel-metadata/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | */.ipynb_checkpoints/ 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 104 | __pypackages__/ 105 | 106 | # Celery stuff 107 | celerybeat-schedule 108 | celerybeat.pid 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | 2 | | Version | Release date | Update | 3 | |:--------:|:-------------:|:-------- | 4 | | 0.0.2 | 2020-12-30 | (1) CRPS loss function.
(2) Semi-hard triplet loss function.
(3) Fixing user-specified names on keras models. | 5 | | 0.0.3 | 2021-01-01 | (1) Bugfix.
(2) keyword and documentation fixes for R2U-Net. | 6 | | 0.0.4 | 2021-01-04 | (1) Bugfix.
(2) New feature. ResUnet-a implementation.
(3) New feature. Tversky loss
(4) New feature. GELU and Snake activation layers. | 7 | | 0.0.5 | 2021-01-06 | (1) Bug and typo fixes.
(2) File structure changes.
(3) New feature. U^2-Net.
(4) New feature. Sigmoid output activation. | 8 | | 0.0.6 | 2021-01-10 | (1) New feature. Deep supervision for Unet++.
(2) New feature. Conv-based down- and upsampling for U^2-Net. | 9 | | 0.0.7 | 2021-01-12 | (1) Bugfix.
(2) New feature. backbone functions.
(3) New feature. UNET 3+. | 10 | | 0.0.8 | 2021-01-13 | Bugfix. | 11 | | 0.0.9 | 2021-01-14 | (1) New feature. two util functions.
(2) Documentation. UNET 3+ image segmentation. | 12 | | 0.0.10 | 2021-01-18 | (1) Bugfix.
(2) Rename the `backbone` module as `base`.
(3) New documentation.
(4) New feature. 2-d V-net. | 13 | | 0.0.11 | 2021-01-22 | Bugfix for UNET 3+. | 14 | | 0.0.12 | 2021-01-25 | (1) Bugfix.
(2) New feature. Backbone models for Unet, Unet++, and UNET 3+. | 15 | | 0.0.13 | 2021-01-25 | (1) Bugfix.
(2) New feature. Backbone models for Attention Unet. | 16 | | 0.0.14 | 2021-01-30 | Python script cleaning and documentation fix. | 17 | | 0.0.15 | 2021-02-03 | (1) Bugfix.
(2) New feature. More `pool` and `unpool` options. | 18 | | 0.0.16 | 2021-02-05 | (1) Bugfix.
(2) New feature. Loss functions. | 19 | | 0.0.17 | 2021-02-28 | Bugfix on UNET 3+ | 20 | | 0.0.18 | 2021-03-04 | Bugfix on UNET 3+ | 21 | | 0.1.1 | 2021-05-27 | (1) Bugfix.
(2) New feature. TransUNET. | 22 | | 0.1.5 | 2021-05-27 | Bugfix on the packaging issue. | 23 | | 0.1.6 | 2021-05-30 | New feature. MS-SSIM and IoU loss functions. | 24 | | 0.1.7 | 2021-06-07 | Bugfix on TransUNET | 25 | | 0.1.8 | 2021-06-07 | Bugfix on TransUNET | 26 | | 0.1.9 | 2021-06-25 | (1) New feature. Swin-UNET.
(2) Update utils and the user guide. | 27 | | 0.1.10 | 2021-07-13 | Bugfix on model saving issues. | 28 | | 0.1.11 | 2021-09-01 | Bugfix on transformer layers. | 29 | | 0.1.12 | 2021-09-04 | Add citation info into the main page.
Generate package doi reference. | 30 | | 0.1.13 | 2022-01-10 | Bugfix on `utils.dummpy_loader` | 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Yingkai Sha 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # keras-unet-collection 2 | 3 | [![PyPI version](https://badge.fury.io/py/keras-unet-collection.svg)](https://badge.fury.io/py/keras-unet-collection) 4 | [![PyPI license](https://img.shields.io/pypi/l/keras-unet-collection.svg)](https://pypi.org/project/keras-unet-collection/) 5 | [![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/yingkaisha/keras-unet-collection/graphs/commit-activity) 6 | 7 | [![DOI](https://zenodo.org/badge/323426984.svg)](https://zenodo.org/badge/latestdoi/323426984) 8 | 9 | The `tensorflow.keras` implementation of U-net, V-net, U-net++, UNET 3+, Attention U-net, R2U-net, ResUnet-a, U^2-Net, TransUNET, and Swin-UNET with optional ImageNet-trained backbones. 10 | 11 | ---------- 12 | 13 | `keras_unet_collection.models` contains functions that configure keras models with hyper-parameter options. 14 | 15 | * Pre-trained ImageNet backbones are supported for U-net, U-net++, UNET 3+, Attention U-net, and TransUNET. 16 | * Deep supervision is supported for U-net++, UNET 3+, and U^2-Net. 17 | * See the [User guide](https://github.com/yingkaisha/keras-unet-collection/blob/main/examples/user_guide_models.ipynb) for other options and use cases. 18 | 19 | | `keras_unet_collection.models` | Name | Reference | 20 | |:---------------|:----------------|:----------------| 21 | | `unet_2d` | U-net | [Ronneberger et al. (2015)](https://link.springer.com/chapter/10.1007/978-3-319-24574-4_28) | 22 | | `vnet_2d` | V-net (modified for 2-d inputs) | [Milletari et al. (2016)](https://arxiv.org/abs/1606.04797) | 23 | | `unet_plus_2d` | U-net++ | [Zhou et al. (2018)](https://link.springer.com/chapter/10.1007/978-3-030-00889-5_1) | 24 | | `r2_unet_2d` | R2U-Net | [Alom et al. 
(2018)](https://arxiv.org/abs/1802.06955) | 25 | | `att_unet_2d` | Attention U-net | [Oktay et al. (2018)](https://arxiv.org/abs/1804.03999) | 26 | | `resunet_a_2d` | ResUnet-a | [Diakogiannis et al. (2020)](https://doi.org/10.1016/j.isprsjprs.2020.01.013) | 27 | | `u2net_2d` | U^2-Net | [Qin et al. (2020)](https://arxiv.org/abs/2005.09007) | 28 | | `unet_3plus_2d` | UNET 3+ | [Huang et al. (2020)](https://arxiv.org/abs/2004.08790) | 29 | | `transunet_2d` | TransUNET | [Chen et al. (2021)](https://arxiv.org/abs/2102.04306) | 30 | | `swin_unet_2d` | Swin-UNET | [Hu et al. (2021)](https://arxiv.org/abs/2105.05537) | 31 | 32 | **Note**: the two Transformer models are incompatible with `NumPy 1.20`; `NumPy 1.19.5` is recommended. 33 | 34 | ---------- 35 | 36 | ` keras_unet_collection.base` contains functions that build the base architecture (i.e., without model heads) of Unet variants for model customization and debugging. 37 | 38 | | ` keras_unet_collection.base` | Notes | 39 | |:-----------------------------------|:------| 40 | | `unet_2d_base`, `vnet_2d_base`, `unet_plus_2d_base`, `unet_3plus_2d_base`, `att_unet_2d_base`, `r2_unet_2d_base`, `resunet_a_2d_base`, `u2net_2d_base`, `transunet_2d_base`, `swin_unet_2d_base` | Functions that accept an input tensor and hyper-parameters of the corresponded model, and produce output tensors of the base architecture. | 41 | 42 | ---------- 43 | 44 | `keras_unet_collection.activations` and `keras_unet_collection.losses` provide additional activation layers and loss functions. 45 | 46 | | `keras_unet_collection.activations` | Name | Reference | 47 | |:--------|:----------------|:----------------| 48 | | `GELU` | Gaussian Error Linear Units (GELU) | [Hendrycks et al. (2016)](https://arxiv.org/abs/1606.08415) | 49 | | `Snake` | Snake activation | [Liu et al. 
(2020)](https://arxiv.org/abs/2006.08195) | 50 | 51 | | `keras_unet_collection.losses` | Name | Reference | 52 | |:----------------|:----------------|:----------------| 53 | | `dice` | Dice loss | [Sudre et al. (2017)](https://link.springer.com/chapter/10.1007/978-3-319-67558-9_28) | 54 | | `tversky` | Tversky loss | [Hashemi et al. (2018)](https://ieeexplore.ieee.org/abstract/document/8573779) | 55 | | `focal_tversky` | Focal Tversky loss | [Abraham et al. (2019)](https://ieeexplore.ieee.org/abstract/document/8759329) | 56 | | `ms_ssim` | Multi-scale Structural Similarity Index Measure loss | [Wang et al. (2003)](https://ieeexplore.ieee.org/abstract/document/1292216) | 57 | | `iou_seg` | Intersection over Union (IoU) loss for segmentation | [Rahman and Wang (2016)](https://link.springer.com/chapter/10.1007/978-3-319-50835-1_22) | 58 | | `iou_box` | (Generalized) IoU loss for object detection | [Rezatofighi et al. (2019)](https://openaccess.thecvf.com/content_CVPR_2019/html/Rezatofighi_Generalized_Intersection_Over_Union_A_Metric_and_a_Loss_for_CVPR_2019_paper.html) | 59 | | `triplet_1d` | Semi-hard triplet loss (experimental) | | 60 | | `crps2d_tf` | CRPS loss (experimental) | | 61 | 62 | # Installation 63 | 64 | The project is hosted on [PyPI](https://pypi.org/project/keras-unet-collection/) and can thus be installed by: 65 | 66 | ``` 67 | pip install keras-unet-collection 68 | ``` 69 | 70 | # Usage 71 | 72 | ```python 73 | from keras_unet_collection import models 74 | # e.g. models.unet_2d(...) 75 | ``` 76 | * **Note**: Currently supported backbone models are: `VGG[16,19]`, `ResNet[50,101,152]`, `ResNet[50,101,152]V2`, `DenseNet[121,169,201]`, and `EfficientNetB[0-7]`. See [Keras Applications](https://keras.io/api/applications/) for details. 77 | 78 | * **Note**: Neural networks produced by this package may contain customized layers that are not part of TensorFlow. It is recommended to save and load model weights.
79 | 80 | * [Changelog](https://github.com/yingkaisha/keras-unet-collection/blob/main/CHANGELOG.md) 81 | 82 | # Examples 83 | 84 | * Jupyter notebooks are provided as [examples](https://github.com/yingkaisha/keras-unet-collection/tree/main/examples): 85 | 86 | * Attention U-net with VGG16 backbone [[link]](https://github.com/yingkaisha/keras-unet-collection/blob/main/examples/human-seg_atten-unet-backbone_coco.ipynb). 87 | 88 | * UNET 3+ with deep supervision, classification-guided module, and hybrid loss [[link]](https://github.com/yingkaisha/keras-unet-collection/blob/main/examples/segmentation_unet-three-plus_oxford-iiit.ipynb). 89 | 90 | * Vision-Transformer-based examples are in progress, and available at [**keras-vision-transformer**](https://github.com/yingkaisha/keras-vision-transformer). 91 | 92 | # Dependencies 93 | 94 | * TensorFlow 2.5.0, Keras 2.5.0, NumPy 1.19.5. 95 | 96 | * (Optional for examples) Pillow, matplotlib, etc. 97 | 98 | # Overview 99 | 100 | U-net is a convolutional neural network with encoder-decoder architecture and skip-connections, loosely defined under the concept of "fully convolutional networks." U-net was originally proposed for the semantic segmentation of medical images and is modified for solving a wider range of gridded learning problems. 101 | 102 | U-net and many of its variants take three or four-dimensional tensors as inputs and produce outputs of the same shape. One technical highlight of these models is the skip-connections from downsampling to upsampling layers, which benefit the reconstruction of high-resolution, gridded outputs. 103 | 104 | # Contact 105 | 106 | Yingkai (Kyle) Sha <> <> 107 | 108 | # License 109 | 110 | [MIT License](https://github.com/yingkaisha/keras-unet/blob/main/LICENSE) 111 | 112 | # Citation 113 | 114 | * Sha, Y., 2021: Keras-unet-collection. 
GitHub repository, accessed 4 September 2021, https://doi.org/10.5281/zenodo.5449801 115 | 116 | ``` 117 | @misc{keras-unet-collection, 118 | author = {Sha, Yingkai}, 119 | title = {Keras-unet-collection}, 120 | year = {2021}, 121 | publisher = {GitHub}, 122 | journal = {GitHub repository}, 123 | howpublished = {\url{https://github.com/yingkaisha/keras-unet-collection}}, 124 | doi = {10.5281/zenodo.5449801} 125 | } 126 | ``` 127 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yingkaisha/keras-unet-collection/13e0e7e0b3c12d2be58d4a3157f3468652e00969/__init__.py -------------------------------------------------------------------------------- /keras_unet_collection/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /keras_unet_collection/_backbone_zoo.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | from tensorflow.keras.applications import * 5 | from tensorflow.keras.models import Model 6 | 7 | from keras_unet_collection.utils import freeze_model 8 | 9 | import warnings 10 | 11 | layer_cadidates = { 12 | 'VGG16': ('block1_conv2', 'block2_conv2', 'block3_conv3', 'block4_conv3', 'block5_conv3'), 13 | 'VGG19': ('block1_conv2', 'block2_conv2', 'block3_conv4', 'block4_conv4', 'block5_conv4'), 14 | 'ResNet50': ('conv1_relu', 'conv2_block3_out', 'conv3_block4_out', 'conv4_block6_out', 'conv5_block3_out'), 15 | 'ResNet101': ('conv1_relu', 'conv2_block3_out', 'conv3_block4_out', 'conv4_block23_out', 'conv5_block3_out'), 16 | 'ResNet152': ('conv1_relu', 'conv2_block3_out', 'conv3_block8_out', 'conv4_block36_out', 'conv5_block3_out'), 17 | 'ResNet50V2': ('conv1_conv', 'conv2_block3_1_relu', 
'conv3_block4_1_relu', 'conv4_block6_1_relu', 'post_relu'), 18 | 'ResNet101V2': ('conv1_conv', 'conv2_block3_1_relu', 'conv3_block4_1_relu', 'conv4_block23_1_relu', 'post_relu'), 19 | 'ResNet152V2': ('conv1_conv', 'conv2_block3_1_relu', 'conv3_block8_1_relu', 'conv4_block36_1_relu', 'post_relu'), 20 | 'DenseNet121': ('conv1/relu', 'pool2_conv', 'pool3_conv', 'pool4_conv', 'relu'), 21 | 'DenseNet169': ('conv1/relu', 'pool2_conv', 'pool3_conv', 'pool4_conv', 'relu'), 22 | 'DenseNet201': ('conv1/relu', 'pool2_conv', 'pool3_conv', 'pool4_conv', 'relu'), 23 | 'EfficientNetB0': ('block2a_expand_activation', 'block3a_expand_activation', 'block4a_expand_activation', 'block6a_expand_activation', 'top_activation'), 24 | 'EfficientNetB1': ('block2a_expand_activation', 'block3a_expand_activation', 'block4a_expand_activation', 'block6a_expand_activation', 'top_activation'), 25 | 'EfficientNetB2': ('block2a_expand_activation', 'block3a_expand_activation', 'block4a_expand_activation', 'block6a_expand_activation', 'top_activation'), 26 | 'EfficientNetB3': ('block2a_expand_activation', 'block3a_expand_activation', 'block4a_expand_activation', 'block6a_expand_activation', 'top_activation'), 27 | 'EfficientNetB4': ('block2a_expand_activation', 'block3a_expand_activation', 'block4a_expand_activation', 'block6a_expand_activation', 'top_activation'), 28 | 'EfficientNetB5': ('block2a_expand_activation', 'block3a_expand_activation', 'block4a_expand_activation', 'block6a_expand_activation', 'top_activation'), 29 | 'EfficientNetB6': ('block2a_expand_activation', 'block3a_expand_activation', 'block4a_expand_activation', 'block6a_expand_activation', 'top_activation'), 30 | 'EfficientNetB7': ('block2a_expand_activation', 'block3a_expand_activation', 'block4a_expand_activation', 'block6a_expand_activation', 'top_activation'),} 31 | 32 | def bach_norm_checker(backbone_name, batch_norm): 33 | '''batch norm checker''' 34 | if 'VGG' in backbone_name: 35 | batch_norm_backbone = False 36 | else: 37 | 
batch_norm_backbone = True 38 | 39 | if batch_norm_backbone != batch_norm: 40 | if batch_norm_backbone: 41 | param_mismatch = "\n\nBackbone {} uses batch norm, but other layers received batch_norm={}".format(backbone_name, batch_norm) 42 | else: 43 | param_mismatch = "\n\nBackbone {} does not use batch norm, but other layers received batch_norm={}".format(backbone_name, batch_norm) 44 | 45 | warnings.warn(param_mismatch); 46 | 47 | def backbone_zoo(backbone_name, weights, input_tensor, depth, freeze_backbone, freeze_batch_norm): 48 | ''' 49 | Configuring a user specified encoder model based on the `tensorflow.keras.applications` 50 | 51 | Input 52 | ---------- 53 | backbone_name: the backbone model name. Expected to be one of the `tensorflow.keras.applications` classes. 54 | Currently supported backbones are: 55 | (1) VGG16, VGG19 56 | (2) ResNet50, ResNet101, ResNet152 57 | (3) ResNet50V2, ResNet101V2, ResNet152V2 58 | (4) DenseNet121, DenseNet169, DenseNet201 59 | (5) EfficientNetB[0-7] 60 | 61 | weights: one of None (random initialization), 'imagenet' (pre-training on ImageNet), 62 | or the path to the weights file to be loaded. 63 | input_tensor: the input tensor 64 | depth: number of encoded feature maps; larger values are clamped to the number of candidate layers. 65 | If four downsampling levels are needed, then depth=4. 66 | 67 | freeze_backbone: True for a frozen backbone 68 | freeze_batch_norm: False for not freezing batch normalization layers. 69 | 70 | Output 71 | ---------- 72 | model: a keras backbone model.
73 | 74 | ''' 75 | 76 | cadidate = layer_cadidates[backbone_name] 77 | 78 | # ----- # 79 | # depth checking 80 | depth_max = len(cadidate) 81 | if depth > depth_max: 82 | depth = depth_max 83 | # ----- # 84 | 85 | backbone_func = eval(backbone_name) 86 | backbone_ = backbone_func(include_top=False, weights=weights, input_tensor=input_tensor, pooling=None,) 87 | 88 | X_skip = [] 89 | 90 | for i in range(depth): 91 | X_skip.append(backbone_.get_layer(cadidate[i]).output) 92 | 93 | model = Model(inputs=[input_tensor,], outputs=X_skip, name='{}_backbone'.format(backbone_name)) 94 | 95 | if freeze_backbone: 96 | model = freeze_model(model, freeze_batch_norm=freeze_batch_norm) 97 | 98 | return model 99 | -------------------------------------------------------------------------------- /keras_unet_collection/_model_att_unet_2d.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | from keras_unet_collection.layer_utils import * 5 | from keras_unet_collection.activations import GELU, Snake 6 | from keras_unet_collection._model_unet_2d import UNET_left, UNET_right 7 | from keras_unet_collection._backbone_zoo import backbone_zoo, bach_norm_checker 8 | 9 | from tensorflow.keras.layers import Input 10 | from tensorflow.keras.models import Model 11 | 12 | 13 | def UNET_att_right(X, X_left, channel, att_channel, kernel_size=3, stack_num=2, 14 | activation='ReLU', atten_activation='ReLU', attention='add', 15 | unpool=True, batch_norm=False, name='right0'): 16 | ''' 17 | the decoder block of Attention U-net. 
18 | 19 | UNET_att_right(X, X_left, channel, att_channel, kernel_size=3, stack_num=2, 20 | activation='ReLU', atten_activation='ReLU', attention='add', 21 | unpool=True, batch_norm=False, name='right0') 22 | 23 | Input 24 | ---------- 25 | X: input tensor (the tensor to be upsampled). 26 | X_left: the skip-connection tensor from the corresponding downsampling level; it passes through the attention gate (gated by the upsampled X) before concatenation. 27 | channel: number of convolution filters 28 | att_channel: number of intermediate channels of the attention gate. 29 | kernel_size: size of 2-d convolution kernels. 30 | stack_num: number of convolutional layers. 31 | activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'. 32 | atten_activation: a nonlinear attention activation. 33 | The `sigma_1` in Oktay et al. 2018. Default is 'ReLU'. 34 | attention: 'add' for additive attention. 'multiply' for multiplicative attention. 35 | Oktay et al. 2018 applied additive attention. 36 | unpool: True or "bilinear" for Upsampling2D with bilinear interpolation. 37 | "nearest" for Upsampling2D with nearest interpolation. 38 | False for Conv2DTranspose + batch norm + activation. 39 | batch_norm: True for batch normalization, False otherwise. 40 | name: prefix of the created keras layers. 41 | Output 42 | ---------- 43 | X: output tensor.
44 | 45 | ''' 46 | 47 | pool_size = 2 48 | 49 | X = decode_layer(X, channel, pool_size, unpool, 50 | activation=activation, batch_norm=batch_norm, name='{}_decode'.format(name)) 51 | 52 | X_left = attention_gate(X=X_left, g=X, channel=att_channel, activation=atten_activation, 53 | attention=attention, name='{}_att'.format(name)) 54 | 55 | # Tensor concatenation 56 | H = concatenate([X, X_left], axis=-1, name='{}_concat'.format(name)) 57 | 58 | # stacked linear convolutional layers after concatenation 59 | H = CONV_stack(H, channel, kernel_size, stack_num=stack_num, activation=activation, 60 | batch_norm=batch_norm, name='{}_conv_after_concat'.format(name)) 61 | 62 | return H 63 | 64 | def att_unet_2d_base(input_tensor, filter_num, stack_num_down=2, stack_num_up=2, 65 | activation='ReLU', atten_activation='ReLU', attention='add', batch_norm=False, pool=True, unpool=True, 66 | backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='attunet'): 67 | ''' 68 | The base of Attention U-net with an optional ImageNet backbone 69 | 70 | att_unet_2d_base(input_tensor, filter_num, stack_num_down=2, stack_num_up=2, 71 | activation='ReLU', atten_activation='ReLU', attention='add', batch_norm=False, pool=True, unpool=True, 72 | backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='attunet') 73 | 74 | ---------- 75 | Oktay, O., Schlemper, J., Folgoc, L.L., Lee, M., Heinrich, M., Misawa, K., Mori, K., McDonagh, S., Hammerla, N.Y., Kainz, B. 76 | and Glocker, B., 2018. Attention u-net: Learning where to look for the pancreas. arXiv preprint arXiv:1804.03999. 77 | 78 | Input 79 | ---------- 80 | input_tensor: the input tensor of the base, e.g., `keras.layers.Input((None, None, 3))`. 81 | filter_num: a list that defines the number of filters for each \ 82 | down- and upsampling levels. e.g., `[64, 128, 256, 512]`. 83 | The depth is expected as `len(filter_num)`.
84 | stack_num_down: number of convolutional layers per downsampling level/block. 85 | stack_num_up: number of convolutional layers (after concatenation) per upsampling level/block. 86 | activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'. 87 | atten_activation: a nonlinear attention activation. 88 | The `sigma_1` in Oktay et al. 2018. Default is 'ReLU'. 89 | attention: 'add' for additive attention. 'multiply' for multiplicative attention. 90 | Oktay et al. 2018 applied additive attention. 91 | batch_norm: True for batch normalization. 92 | pool: True or 'max' for MaxPooling2D. 93 | 'ave' for AveragePooling2D. 94 | False for strided conv + batch norm + activation. 95 | unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation. 96 | 'nearest' for Upsampling2D with nearest interpolation. 97 | False for Conv2DTranspose + batch norm + activation. 98 | name: prefix of the created keras model and its layers. 99 | 100 | ---------- (keywords of backbone options) ---------- 101 | backbone: the backbone model name. Should be one of the `tensorflow.keras.applications` classes. 102 | None (default) means no backbone. 103 | Currently supported backbones are: 104 | (1) VGG16, VGG19 105 | (2) ResNet50, ResNet101, ResNet152 106 | (3) ResNet50V2, ResNet101V2, ResNet152V2 107 | (4) DenseNet121, DenseNet169, DenseNet201 108 | (5) EfficientNetB[0-7] 109 | weights: one of None (random initialization), 'imagenet' (pre-training on ImageNet), 110 | or the path to the weights file to be loaded. 111 | freeze_backbone: True for a frozen backbone. 112 | freeze_batch_norm: False for not freezing batch normalization layers. 113 | 114 | Output 115 | ---------- 116 | X: the output tensor of the base.
117 | 118 | ''' 119 | activation_func = eval(activation) 120 | 121 | depth_ = len(filter_num) 122 | X_skip = [] 123 | 124 | # no backbone cases 125 | if backbone is None: 126 | X = input_tensor 127 | # downsampling blocks 128 | X = CONV_stack(X, filter_num[0], stack_num=stack_num_down, activation=activation, 129 | batch_norm=batch_norm, name='{}_down0'.format(name)) 130 | X_skip.append(X) 131 | 132 | for i, f in enumerate(filter_num[1:]): 133 | X = UNET_left(X, f, stack_num=stack_num_down, activation=activation, pool=pool, 134 | batch_norm=batch_norm, name='{}_down{}'.format(name, i+1)) 135 | X_skip.append(X) 136 | 137 | else: 138 | # handling VGG16 and VGG19 separately 139 | if 'VGG' in backbone: 140 | backbone_ = backbone_zoo(backbone, weights, input_tensor, depth_, freeze_backbone, freeze_batch_norm) 141 | # collecting backbone feature maps 142 | X_skip = backbone_([input_tensor,]) 143 | depth_encode = len(X_skip) 144 | 145 | # for other backbones 146 | else: 147 | backbone_ = backbone_zoo(backbone, weights, input_tensor, depth_-1, freeze_backbone, freeze_batch_norm) 148 | # collecting backbone feature maps 149 | X_skip = backbone_([input_tensor,]) 150 | depth_encode = len(X_skip) + 1 151 | 152 | # extra conv2d blocks are applied 153 | # if downsampling levels of a backbone < user-specified downsampling levels 154 | if depth_encode < depth_: 155 | 156 | # begins at the deepest available tensor 157 | X = X_skip[-1] 158 | 159 | # extra downsamplings 160 | for i in range(depth_-depth_encode): 161 | i_real = i + depth_encode 162 | 163 | X = UNET_left(X, filter_num[i_real], stack_num=stack_num_down, activation=activation, pool=pool, 164 | batch_norm=batch_norm, name='{}_down{}'.format(name, i_real+1)) 165 | X_skip.append(X) 166 | 167 | # reverse indexing encoded feature maps 168 | X_skip = X_skip[::-1] 169 | # upsampling begins at the deepest available tensor 170 | X = X_skip[0] 171 | # other tensors are preserved for concatenation 172 | X_decode = X_skip[1:] 173 | 
depth_decode = len(X_decode) 174 | 175 | # reverse indexing filter numbers 176 | filter_num_decode = filter_num[:-1][::-1] 177 | 178 | for i in range(depth_decode): 179 | f = filter_num_decode[i] 180 | 181 | X = UNET_att_right(X, X_decode[i], f, att_channel=f//2, stack_num=stack_num_up, 182 | activation=activation, atten_activation=atten_activation, attention=attention, 183 | unpool=unpool, batch_norm=batch_norm, name='{}_up{}'.format(name, i)) 184 | 185 | # if tensors for concatenation is not enough 186 | # then use upsampling without concatenation 187 | if depth_decode < depth_-1: 188 | for i in range(depth_-depth_decode-1): 189 | i_real = i + depth_decode 190 | X = UNET_right(X, None, filter_num_decode[i_real], stack_num=stack_num_up, activation=activation, 191 | unpool=unpool, batch_norm=batch_norm, concat=False, name='{}_up{}'.format(name, i_real)) 192 | return X 193 | 194 | def att_unet_2d(input_size, filter_num, n_labels, stack_num_down=2, stack_num_up=2, activation='ReLU', 195 | atten_activation='ReLU', attention='add', output_activation='Softmax', batch_norm=False, pool=True, unpool=True, 196 | backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='attunet'): 197 | ''' 198 | Attention U-net with an optional ImageNet backbone 199 | 200 | att_unet_2d(input_size, filter_num, n_labels, stack_num_down=2, stack_num_up=2, activation='ReLU', 201 | atten_activation='ReLU', attention='add', output_activation='Softmax', batch_norm=False, pool=True, unpool=True, 202 | backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='att-unet') 203 | 204 | ---------- 205 | Oktay, O., Schlemper, J., Folgoc, L.L., Lee, M., Heinrich, M., Misawa, K., Mori, K., McDonagh, S., Hammerla, N.Y., Kainz, B. 206 | and Glocker, B., 2018. Attention u-net: Learning where to look for the pancreas. arXiv preprint arXiv:1804.03999. 
207 | 208 | Input 209 | ---------- 210 | input_size: the size/shape of network input, e.g., `(128, 128, 3)`. 211 | filter_num: a list that defines the number of filters for each \ 212 | down- and upsampling levels. e.g., `[64, 128, 256, 512]`. 213 | The depth is expected as `len(filter_num)`. 214 | n_labels: number of output labels. 215 | stack_num_down: number of convolutional layers per downsampling level/block. 216 | stack_num_up: number of convolutional layers (after concatenation) per upsampling level/block. 217 | activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'. 218 | output_activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interface or 'Sigmoid'. 219 | Default option is 'Softmax'. 220 | if None is received, then linear activation is applied. 221 | atten_activation: a nonlinear atteNtion activation. 222 | The `sigma_1` in Oktay et al. 2018. Default is 'ReLU'. 223 | attention: 'add' for additive attention. 'multiply' for multiplicative attention. 224 | Oktay et al. 2018 applied additive attention. 225 | batch_norm: True for batch normalization. 226 | pool: True or 'max' for MaxPooling2D. 227 | 'ave' for AveragePooling2D. 228 | False for strided conv + batch norm + activation. 229 | unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation. 230 | 'nearest' for Upsampling2D with nearest interpolation. 231 | False for Conv2DTranspose + batch norm + activation. 232 | name: prefix of the created keras model and its layers. 233 | 234 | ---------- (keywords of backbone options) ---------- 235 | backbone_name: the bakcbone model name. Should be one of the `tensorflow.keras.applications` class. 236 | None (default) means no backbone. 
237 | Currently supported backbones are: 238 | (1) VGG16, VGG19 239 | (2) ResNet50, ResNet101, ResNet152 240 | (3) ResNet50V2, ResNet101V2, ResNet152V2 241 | (4) DenseNet121, DenseNet169, DenseNet201 242 | (5) EfficientNetB[0-7] 243 | weights: one of None (random initialization), 'imagenet' (pre-training on ImageNet), 244 | or the path to the weights file to be loaded. 245 | freeze_backbone: True for a frozen backbone. 246 | freeze_batch_norm: False for not freezing batch normalization layers. 247 | 248 | Output 249 | ---------- 250 | model: a keras model 251 | 252 | ''' 253 | 254 | # one of the ReLU, LeakyReLU, PReLU, ELU 255 | activation_func = eval(activation) 256 | 257 | if backbone is not None: 258 | bach_norm_checker(backbone, batch_norm) 259 | 260 | IN = Input(input_size) 261 | 262 | # base 263 | X = att_unet_2d_base(IN, filter_num, stack_num_down=stack_num_down, stack_num_up=stack_num_up, 264 | activation=activation, atten_activation=atten_activation, attention=attention, 265 | batch_norm=batch_norm, pool=pool, unpool=unpool, 266 | backbone=backbone, weights=weights, freeze_backbone=freeze_backbone, 267 | freeze_batch_norm=freeze_backbone, name=name) 268 | 269 | # output layer 270 | OUT = CONV_output(X, n_labels, kernel_size=1, activation=output_activation, name='{}_output'.format(name)) 271 | 272 | # functional API model 273 | model = Model(inputs=[IN,], outputs=[OUT,], name='{}_model'.format(name)) 274 | 275 | return model 276 | -------------------------------------------------------------------------------- /keras_unet_collection/_model_r2_unet_2d.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | from keras_unet_collection.layer_utils import * 5 | from keras_unet_collection.activations import GELU, Snake 6 | 7 | from tensorflow.keras.layers import Input 8 | from tensorflow.keras.models import Model 9 | 10 | def RR_CONV(X, channel, kernel_size=3, stack_num=2, 
recur_num=2, activation='ReLU', batch_norm=False, name='rr'): 11 | ''' 12 | Recurrent convolutional layers with skip connection. 13 | 14 | RR_CONV(X, channel, kernel_size=3, stack_num=2, recur_num=2, activation='ReLU', batch_norm=False, name='rr') 15 | 16 | Input 17 | ---------- 18 | X: input tensor. 19 | channel: number of convolution filters. 20 | kernel_size: size of 2-d convolution kernels. 21 | stack_num: number of stacked recurrent convolutional layers. 22 | recur_num: number of recurrent iterations. 23 | activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'. 24 | batch_norm: True for batch normalization, False otherwise. 25 | name: prefix of the created keras layers. 26 | 27 | Output 28 | ---------- 29 | X: output tensor. 30 | 31 | ''' 32 | 33 | activation_func = eval(activation) 34 | 35 | layer_skip = Conv2D(channel, 1, name='{}_conv'.format(name))(X) 36 | layer_main = layer_skip 37 | 38 | for i in range(stack_num): 39 | 40 | layer_res = Conv2D(channel, kernel_size, padding='same', name='{}_conv{}'.format(name, i))(layer_main) 41 | 42 | if batch_norm: 43 | layer_res = BatchNormalization(name='{}_bn{}'.format(name, i))(layer_res) 44 | 45 | layer_res = activation_func(name='{}_activation{}'.format(name, i))(layer_res) 46 | 47 | for j in range(recur_num): 48 | 49 | layer_add = add([layer_res, layer_main], name='{}_add{}_{}'.format(name, i, j)) 50 | 51 | layer_res = Conv2D(channel, kernel_size, padding='same', name='{}_conv{}_{}'.format(name, i, j))(layer_add) 52 | 53 | if batch_norm: 54 | layer_res = BatchNormalization(name='{}_bn{}_{}'.format(name, i, j))(layer_res) 55 | 56 | layer_res = activation_func(name='{}_activation{}_{}'.format(name, i, j))(layer_res) 57 | 58 | layer_main = layer_res 59 | 60 | out_layer = add([layer_main, layer_skip], name='{}_add{}'.format(name, i)) 61 | 62 | return out_layer 63 | 64 | 65 | def UNET_RR_left(X, channel, kernel_size=3, 66 | stack_num=2, recur_num=2, 
activation='ReLU', 67 | pool=True, batch_norm=False, name='left0'): 68 | ''' 69 | The encoder block of R2U-Net. 70 | 71 | UNET_RR_left(X, channel, kernel_size=3, 72 | stack_num=2, recur_num=2, activation='ReLU', 73 | pool=True, batch_norm=False, name='left0') 74 | 75 | Input 76 | ---------- 77 | X: input tensor. 78 | channel: number of convolution filters. 79 | kernel_size: size of 2-d convolution kernels. 80 | stack_num: number of stacked recurrent convolutional layers. 81 | recur_num: number of recurrent iterations. 82 | activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'. 83 | pool: True or 'max' for MaxPooling2D. 84 | 'ave' for AveragePooling2D. 85 | False for strided conv + batch norm + activation. 86 | batch_norm: True for batch normalization, False otherwise. 87 | name: prefix of the created keras layers. 88 | 89 | Output 90 | ---------- 91 | X: output tensor. 92 | 93 | *downsampling is fixed to 2-by-2, e.g., reducing feature map sizes from 64-by-64 to 32-by-32 94 | ''' 95 | pool_size = 2 96 | 97 | # maxpooling layer vs strided convolutional layers 98 | X = encode_layer(X, channel, pool_size, pool, activation=activation, 99 | batch_norm=batch_norm, name='{}_encode'.format(name)) 100 | 101 | # stack linear convolutional layers 102 | X = RR_CONV(X, channel, stack_num=stack_num, recur_num=recur_num, 103 | activation=activation, batch_norm=batch_norm, name=name) 104 | return X 105 | 106 | 107 | def UNET_RR_right(X, X_list, channel, kernel_size=3, 108 | stack_num=2, recur_num=2, activation='ReLU', 109 | unpool=True, batch_norm=False, name='right0'): 110 | ''' 111 | The decoder block of R2U-Net. 112 | 113 | UNET_RR_right(X, X_list, channel, kernel_size=3, 114 | stack_num=2, recur_num=2, activation='ReLU', 115 | unpool=True, batch_norm=False, name='right0') 116 | 117 | Input 118 | ---------- 119 | X: input tensor. 120 | X_list: a list of other tensors that connected to the input tensor. 
121 | channel: number of convolution filters. 122 | kernel_size: size of 2-d convolution kernels. 123 | stack_num: number of stacked recurrent convolutional layers. 124 | recur_num: number of recurrent iterations. 125 | activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'. 126 | unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation. 127 | 'nearest' for Upsampling2D with nearest interpolation. 128 | False for Conv2DTranspose + batch norm + activation. 129 | batch_norm: True for batch normalization, False otherwise. 130 | name: prefix of the created keras layers. 131 | 132 | Output 133 | ---------- 134 | X: output tensor 135 | 136 | ''' 137 | 138 | pool_size = 2 139 | 140 | X = decode_layer(X, channel, pool_size, unpool, 141 | activation=activation, batch_norm=batch_norm, name='{}_decode'.format(name)) 142 | 143 | # linear convolutional layers before concatenation 144 | X = CONV_stack(X, channel, kernel_size, stack_num=1, activation=activation, 145 | batch_norm=batch_norm, name='{}_conv_before_concat'.format(name)) 146 | 147 | # Tensor concatenation 148 | H = concatenate([X,]+X_list, axis=-1, name='{}_concat'.format(name)) 149 | 150 | # stacked linear convolutional layers after concatenation 151 | H = RR_CONV(H, channel, stack_num=stack_num, recur_num=recur_num, 152 | activation=activation, batch_norm=batch_norm, name=name) 153 | 154 | return H 155 | 156 | def r2_unet_2d_base(input_tensor, filter_num, stack_num_down=2, stack_num_up=2, recur_num=2, 157 | activation='ReLU', batch_norm=False, pool=True, unpool=True, name='res_unet'): 158 | 159 | ''' 160 | The base of Recurrent Residual (R2) U-Net. 161 | 162 | r2_unet_2d_base(input_tensor, filter_num, stack_num_down=2, stack_num_up=2, recur_num=2, 163 | activation='ReLU', batch_norm=False, pool=True, unpool=True, name='res_unet') 164 | 165 | ---------- 166 | Alom, M.Z., Hasan, M., Yakopcic, C., Taha, T.M. and Asari, V.K., 2018. 
Recurrent residual convolutional neural network 167 | based on u-net (r2u-net) for medical image segmentation. arXiv preprint arXiv:1802.06955. 168 | 169 | Input 170 | ---------- 171 | input_tensor: the input tensor of the base, e.g., `keras.layers.Inpyt((None, None, 3))`. 172 | filter_num: a list that defines the number of filters for each \ 173 | down- and upsampling levels. e.g., `[64, 128, 256, 512]`. 174 | The depth is expected as `len(filter_num)`. 175 | stack_num_down: number of stacked recurrent convolutional layers per downsampling level/block. 176 | stack_num_down: number of stacked recurrent convolutional layers per upsampling level/block. 177 | recur_num: number of recurrent iterations. 178 | activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'. 179 | batch_norm: True for batch normalization. 180 | pool: True or 'max' for MaxPooling2D. 181 | 'ave' for AveragePooling2D. 182 | False for strided conv + batch norm + activation. 183 | unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation. 184 | 'nearest' for Upsampling2D with nearest interpolation. 185 | False for Conv2DTranspose + batch norm + activation. 186 | name: prefix of the created keras layers. 187 | 188 | Output 189 | ---------- 190 | X: output tensor. 
191 | 192 | ''' 193 | 194 | activation_func = eval(activation) 195 | 196 | X = input_tensor 197 | X_skip = [] 198 | 199 | # downsampling blocks 200 | X = RR_CONV(X, filter_num[0], stack_num=stack_num_down, recur_num=recur_num, 201 | activation=activation, batch_norm=batch_norm, name='{}_down0'.format(name)) 202 | X_skip.append(X) 203 | 204 | for i, f in enumerate(filter_num[1:]): 205 | X = UNET_RR_left(X, f, kernel_size=3, stack_num=stack_num_down, recur_num=recur_num, 206 | activation=activation, pool=pool, batch_norm=batch_norm, name='{}_down{}'.format(name, i+1)) 207 | X_skip.append(X) 208 | 209 | # upsampling blocks 210 | X_skip = X_skip[:-1][::-1] 211 | for i, f in enumerate(filter_num[:-1][::-1]): 212 | X = UNET_RR_right(X, [X_skip[i],], f, stack_num=stack_num_up, recur_num=recur_num, 213 | activation=activation, unpool=unpool, batch_norm=batch_norm, name='{}_up{}'.format(name, i+1)) 214 | 215 | return X 216 | 217 | 218 | def r2_unet_2d(input_size, filter_num, n_labels, 219 | stack_num_down=2, stack_num_up=2, recur_num=2, 220 | activation='ReLU', output_activation='Softmax', 221 | batch_norm=False, pool=True, unpool=True, name='r2_unet'): 222 | 223 | ''' 224 | Recurrent Residual (R2) U-Net 225 | 226 | r2_unet_2d(input_size, filter_num, n_labels, 227 | stack_num_down=2, stack_num_up=2, recur_num=2, 228 | activation='ReLU', output_activation='Softmax', 229 | batch_norm=False, pool=True, unpool=True, name='r2_unet') 230 | 231 | ---------- 232 | Alom, M.Z., Hasan, M., Yakopcic, C., Taha, T.M. and Asari, V.K., 2018. Recurrent residual convolutional neural network 233 | based on u-net (r2u-net) for medical image segmentation. arXiv preprint arXiv:1802.06955. 234 | 235 | Input 236 | ---------- 237 | input_size: the size/shape of network input, e.g., `(128, 128, 3)`. 238 | filter_num: a list that defines the number of filters for each \ 239 | down- and upsampling levels. e.g., `[64, 128, 256, 512]`. 240 | The depth is expected as `len(filter_num)`. 
241 | n_labels: number of output labels. 242 | stack_num_down: number of stacked recurrent convolutional layers per downsampling level/block. 243 | stack_num_down: number of stacked recurrent convolutional layers per upsampling level/block. 244 | recur_num: number of recurrent iterations. 245 | activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'. 246 | output_activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interface or 'Sigmoid'. 247 | Default option is 'Softmax'. 248 | if None is received, then linear activation is applied. 249 | batch_norm: True for batch normalization. 250 | pool: True or 'max' for MaxPooling2D. 251 | 'ave' for AveragePooling2D. 252 | False for strided conv + batch norm + activation. 253 | unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation. 254 | 'nearest' for Upsampling2D with nearest interpolation. 255 | False for Conv2DTranspose + batch norm + activation. 256 | name: prefix of the created keras layers. 257 | 258 | Output 259 | ---------- 260 | model: a keras model. 
261 | 262 | ''' 263 | 264 | activation_func = eval(activation) 265 | 266 | IN = Input(input_size, name='{}_input'.format(name)) 267 | 268 | # base 269 | X = r2_unet_2d_base(IN, filter_num, 270 | stack_num_down=stack_num_down, stack_num_up=stack_num_up, recur_num=recur_num, 271 | activation=activation, batch_norm=batch_norm, pool=pool, unpool=unpool, name=name) 272 | # output layer 273 | OUT = CONV_output(X, n_labels, kernel_size=1, activation=output_activation, name='{}_output'.format(name)) 274 | 275 | # functional API model 276 | model = Model(inputs=[IN], outputs=[OUT], name='{}_model'.format(name)) 277 | 278 | return model 279 | 280 | -------------------------------------------------------------------------------- /keras_unet_collection/_model_resunet_a_2d.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | from keras_unet_collection.layer_utils import * 5 | from keras_unet_collection.activations import GELU, Snake 6 | 7 | from tensorflow.keras.layers import Input 8 | from tensorflow.keras.models import Model 9 | 10 | def ResUNET_a_block(X, channel, kernel_size=3, dilation_num=1.0, activation='ReLU', batch_norm=False, name='res_a_block'): 11 | ''' 12 | The "ResUNET-a" block 13 | 14 | ResUNET_a_block(X, channel, kernel_size=3, dilation_num=1.0, activation='ReLU', batch_norm=False, name='res_a_block') 15 | 16 | ---------- 17 | Diakogiannis, F.I., Waldner, F., Caccetta, P. and Wu, C., 2020. Resunet-a: a deep learning framework for 18 | semantic segmentation of remotely sensed data. ISPRS Journal of Photogrammetry and Remote Sensing, 162, pp.94-114. 19 | 20 | Input 21 | ---------- 22 | X: input tensor. 23 | channel: number of convolution filters. 24 | kernel_size: size of 2-d convolution kernels. 25 | dilation_num: an iterable that defines dilation rates of convolutional layers. 26 | stacks of conv2d is expected as `len(dilation_num)`. 
27 | activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'. 28 | batch_norm: True for batch normalization, False otherwise. 29 | name: prefix of the created keras layers. 30 | 31 | Output 32 | ---------- 33 | X: output tensor. 34 | 35 | ''' 36 | 37 | X_res = [] 38 | 39 | for i, d in enumerate(dilation_num): 40 | 41 | X_res.append(CONV_stack(X, channel, kernel_size=kernel_size, stack_num=2, dilation_rate=d, 42 | activation=activation, batch_norm=batch_norm, name='{}_stack{}'.format(name, i))) 43 | 44 | if len(X_res) > 1: 45 | return add(X_res) 46 | 47 | else: 48 | return X_res[0] 49 | 50 | 51 | def ResUNET_a_right(X, X_list, channel, kernel_size=3, dilation_num=[1,], 52 | activation='ReLU', unpool=True, batch_norm=False, name='right0'): 53 | ''' 54 | The decoder block of ResUNet-a 55 | 56 | ResUNET_a_right(X, X_list, channel, kernel_size=3, dilation_num=[1,], 57 | activation='ReLU', unpool=True, batch_norm=False, name='right0') 58 | 59 | Input 60 | ---------- 61 | X: input tensor. 62 | X_list: a list of other tensors that connected to the input tensor. 63 | channel: number of convolution filters. 64 | kernel_size: size of 2-d convolution kernels. 65 | dilation_num: an iterable that defines dilation rates of convolutional layers. 66 | stacks of conv2d is expected as `len(dilation_num)`. 67 | activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'. 68 | unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation. 69 | 'nearest' for Upsampling2D with nearest interpolation. 70 | False for Conv2DTranspose + batch norm + activation. 71 | batch_norm: True for batch normalization, False otherwise. 72 | name: name of the created keras layers. 73 | 74 | Output 75 | ---------- 76 | X: output tensor. 
77 | 78 | 79 | ''' 80 | 81 | pool_size = 2 82 | 83 | X = decode_layer(X, channel, pool_size, unpool, 84 | activation=activation, batch_norm=batch_norm, name='{}_decode'.format(name)) 85 | 86 | # <--- *stacked convolutional can be applied here 87 | X = concatenate([X,]+X_list, axis=3, name=name+'_concat') 88 | 89 | # Stacked convolutions after concatenation 90 | X = ResUNET_a_block(X, channel, kernel_size=kernel_size, dilation_num=dilation_num, activation=activation, 91 | batch_norm=batch_norm, name='{}_resblock'.format(name)) 92 | 93 | return X 94 | 95 | def resunet_a_2d_base(input_tensor, filter_num, dilation_num, 96 | aspp_num_down=256, aspp_num_up=128, activation='ReLU', 97 | batch_norm=True, pool=True, unpool=True, name='resunet'): 98 | ''' 99 | The base of ResUNet-a 100 | 101 | resunet_a_2d_base(input_tensor, filter_num, dilation_num, 102 | aspp_num_down=256, aspp_num_up=128, activation='ReLU', 103 | batch_norm=True, pool=True, unpool=True, name='resunet') 104 | 105 | ---------- 106 | Diakogiannis, F.I., Waldner, F., Caccetta, P. and Wu, C., 2020. Resunet-a: a deep learning framework for 107 | semantic segmentation of remotely sensed data. ISPRS Journal of Photogrammetry and Remote Sensing, 162, pp.94-114. 108 | 109 | Input 110 | ---------- 111 | input_tensor: the input tensor of the base, e.g., `keras.layers.Inpyt((None, None, 3))`. 112 | filter_num: a list that defines the number of filters for each \ 113 | down- and upsampling levels. e.g., `[64, 128, 256, 512]`. 114 | The depth is expected as `len(filter_num)`. 115 | dilation_num: an iterable that defines the dilation rates of convolutional layers. 116 | Diakogiannis et al. (2020) suggested `[1, 3, 15, 31]`. 117 | * This base function requires `len(filter_num) == len(dilation_num)`. 118 | Explicitly defining dilation rates for each down-/upsampling level. 119 | aspp_num_down: number of Atrous Spatial Pyramid Pooling (ASPP) layer filters after the last downsampling block. 
120 | aspp_num_up: number of ASPP layer filters after the last upsampling block. 121 | activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'. 122 | batch_norm: True for batch normalization. 123 | unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation. 124 | 'nearest' for Upsampling2D with nearest interpolation. 125 | False for Conv2DTranspose + batch norm + activation. 126 | name: prefix of the created keras layers. 127 | 128 | Output 129 | ---------- 130 | X: output tensor. 131 | 132 | * Downsampling is achieved through strided convolutional layers with 1-by-1 kernels in Diakogiannis et al., (2020), 133 | and is here is achieved either with pooling layers or strided convolutional layers with 2-by-2 kernels. 134 | * If this base function is involved in network training, then the input shape cannot have NoneType. 135 | * `dilation_num` should be provided as 2d iterables, with the second dimension matches the model depth. 136 | e.g., for `len(filter_num) = 4`, dilation_num can be provided as: `[[1, 3, 15, 31], [1, 3, 15], [1,], [1,]]`. 
137 | 138 | ''' 139 | 140 | pool_size = 2 141 | 142 | activation_func = eval(activation) 143 | 144 | depth_ = len(filter_num) 145 | X_skip = [] 146 | 147 | # ----- # 148 | # rejecting auto-mode from this base function 149 | if isinstance(dilation_num[0], int): 150 | raise ValueError('`resunet_a_2d_base` does not support automated determination of `dilation_num`.') 151 | else: 152 | dilation_ = dilation_num 153 | # ----- # 154 | 155 | X = input_tensor 156 | 157 | # input mapping with 1-by-1 conv 158 | X = Conv2D(filter_num[0], 1, 1, dilation_rate=1, padding='same', 159 | use_bias=True, name='{}_input_mapping'.format(name))(X) 160 | X = activation_func(name='{}_input_activation'.format(name))(X) 161 | X_skip.append(X) 162 | # ----- # 163 | 164 | X = ResUNET_a_block(X, filter_num[0], kernel_size=3, dilation_num=dilation_[0], 165 | activation=activation, batch_norm=batch_norm, name='{}_res0'.format(name)) 166 | X_skip.append(X) 167 | 168 | for i, f in enumerate(filter_num[1:]): 169 | ind_ = i+1 170 | 171 | X = encode_layer(X, f, pool_size, pool, activation=activation, 172 | batch_norm=batch_norm, name='{}_down{}'.format(name, i)) 173 | X = ResUNET_a_block(X, f, kernel_size=3, dilation_num=dilation_[ind_], activation=activation, 174 | batch_norm=batch_norm, name='{}_resblock_{}'.format(name, ind_)) 175 | X_skip.append(X) 176 | 177 | X = ASPP_conv(X, aspp_num_down, activation=activation, batch_norm=batch_norm, name='{}_aspp_bottom'.format(name)) 178 | 179 | X_skip = X_skip[:-1][::-1] 180 | dilation_ = dilation_[:-1][::-1] 181 | 182 | for i, f in enumerate(filter_num[:-1][::-1]): 183 | 184 | X = ResUNET_a_right(X, [X_skip[i],], f, kernel_size=3, activation=activation, dilation_num=dilation_[i], 185 | unpool=unpool, batch_norm=batch_norm, name='{}_up{}'.format(name, i)) 186 | 187 | X = concatenate([X_skip[-1], X], name='{}_concat_out'.format(name)) 188 | 189 | X = ASPP_conv(X, aspp_num_up, activation=activation, batch_norm=batch_norm, name='{}_aspp_out'.format(name)) 190 | 
191 | return X 192 | 193 | 194 | def resunet_a_2d(input_size, filter_num, dilation_num, n_labels, 195 | aspp_num_down=256, aspp_num_up=128, activation='ReLU', output_activation='Softmax', 196 | batch_norm=True, pool=True, unpool=True, name='resunet'): 197 | ''' 198 | ResUNet-a 199 | 200 | resunet_a_2d(input_size, filter_num, dilation_num, n_labels, 201 | aspp_num_down=256, aspp_num_up=128, activation='ReLU', output_activation='Softmax', 202 | batch_norm=True, pool=True, unpool=True, name='resunet') 203 | 204 | ---------- 205 | Diakogiannis, F.I., Waldner, F., Caccetta, P. and Wu, C., 2020. Resunet-a: a deep learning framework for 206 | semantic segmentation of remotely sensed data. ISPRS Journal of Photogrammetry and Remote Sensing, 162, pp.94-114. 207 | 208 | Input 209 | ---------- 210 | input_size: the size/shape of network input, e.g., `(128, 128, 3)`. 211 | filter_num: a list that defines the number of filters for each \ 212 | down- and upsampling levels. e.g., `[64, 128, 256, 512]`. 213 | The depth is expected as `len(filter_num)`. 214 | dilation_num: an iterable that defines the dilation rates of convolutional layers. 215 | Diakogiannis et al. (2020) suggested `[1, 3, 15, 31]`. 216 | * `dilation_num` can be provided as 2d iterables, with the second dimension matches 217 | the model depth. e.g., for len(filter_num) = 4; dilation_num can be provided as: 218 | `[[1, 3, 15, 31], [1, 3, 15], [1,], [1,]]`. 219 | * If `dilation_num` is not provided per down-/upsampling level, then the automated 220 | determinations will be applied. 221 | n_labels: number of output labels. 222 | aspp_num_down: number of Atrous Spatial Pyramid Pooling (ASPP) layer filters after the last downsampling block. 223 | aspp_num_up: number of ASPP layer filters after the last upsampling block. 224 | activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'. 
225 | output_activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interface or 'Sigmoid'. 226 | Default option is 'Softmax'. 227 | if None is received, then linear activation is applied. 228 | batch_norm: True for batch normalization. 229 | unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation. 230 | 'nearest' for Upsampling2D with nearest interpolation. 231 | False for Conv2DTranspose + batch norm + activation. 232 | name: prefix of the created keras layers. 233 | 234 | Output 235 | ---------- 236 | model: a keras model. 237 | 238 | * Downsampling is achieved through strided convolutional layers with 1-by-1 kernels in Diakogiannis et al., (2020), 239 | and is here is achieved either with pooling layers or strided convolutional layers with 2-by-2 kernels. 240 | * `resunet_a_2d` does not support NoneType input shape. 241 | 242 | ''' 243 | 244 | activation_func = eval(activation) 245 | depth_ = len(filter_num) 246 | 247 | X_skip = [] 248 | 249 | # input_size cannot have None 250 | if input_size[0] is None or input_size[1] is None: 251 | raise ValueError('`resunet_a_2d` does not support NoneType input shape') 252 | 253 | # ----- # 254 | if isinstance(dilation_num[0], int): 255 | print("Received dilation rates: {}".format(dilation_num)) 256 | 257 | deep_ = (depth_-2)//2 258 | dilation_ = [[] for _ in range(depth_)] 259 | 260 | print("Received dilation rates are not defined on a per downsampling level basis.") 261 | print("Automated determinations are applied with the following details:") 262 | 263 | for i in range(depth_): 264 | if i <= 1: 265 | dilation_[i] += dilation_num 266 | elif i > 1 and i <= deep_+1: 267 | dilation_[i] += dilation_num[:-1] 268 | else: 269 | dilation_[i] += [1,] 270 | print('\tdepth-{}, dilation_rate = {}'.format(i, dilation_[i])) 271 | 272 | else: 273 | dilation_ = dilation_num 274 | # ----- # 275 | 276 | IN = Input(input_size) 277 | 278 | # base 279 | X = resunet_a_2d_base(IN, filter_num, 
dilation_, 280 | aspp_num_down=aspp_num_down, aspp_num_up=aspp_num_up, activation=activation, 281 | batch_norm=batch_norm, pool=pool, unpool=unpool, name=name) 282 | 283 | OUT = CONV_output(X, n_labels, kernel_size=1, activation=output_activation, name='{}_output'.format(name)) 284 | 285 | model = Model([IN], [OUT,], name='{}_model'.format(name)) 286 | 287 | return model 288 | 289 | -------------------------------------------------------------------------------- /keras_unet_collection/_model_swin_unet_2d.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | from keras_unet_collection.layer_utils import * 5 | from keras_unet_collection.transformer_layers import patch_extract, patch_embedding, SwinTransformerBlock, patch_merging, patch_expanding 6 | 7 | from tensorflow.keras.layers import Input, Dense 8 | from tensorflow.keras.models import Model 9 | 10 | def swin_transformer_stack(X, stack_num, embed_dim, num_patch, num_heads, window_size, num_mlp, shift_window=True, name=''): 11 | ''' 12 | Stacked Swin Transformers that share the same token size. 13 | 14 | Alternated Window-MSA and Swin-MSA will be configured if `shift_window=True`, Window-MSA only otherwise. 15 | *Dropout is turned off. 
16 | ''' 17 | # Turn-off dropouts 18 | mlp_drop_rate = 0 # Droupout after each MLP layer 19 | attn_drop_rate = 0 # Dropout after Swin-Attention 20 | proj_drop_rate = 0 # Dropout at the end of each Swin-Attention block, i.e., after linear projections 21 | drop_path_rate = 0 # Drop-path within skip-connections 22 | 23 | qkv_bias = True # Convert embedded patches to query, key, and values with a learnable additive value 24 | qk_scale = None # None: Re-scale query based on embed dimensions per attention head # Float for user specified scaling factor 25 | 26 | if shift_window: 27 | shift_size = window_size // 2 28 | else: 29 | shift_size = 0 30 | 31 | for i in range(stack_num): 32 | 33 | if i % 2 == 0: 34 | shift_size_temp = 0 35 | else: 36 | shift_size_temp = shift_size 37 | 38 | X = SwinTransformerBlock(dim=embed_dim, num_patch=num_patch, num_heads=num_heads, 39 | window_size=window_size, shift_size=shift_size_temp, num_mlp=num_mlp, qkv_bias=qkv_bias, qk_scale=qk_scale, 40 | mlp_drop=mlp_drop_rate, attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, drop_path_prob=drop_path_rate, 41 | name='name{}'.format(i))(X) 42 | return X 43 | 44 | 45 | def swin_unet_2d_base(input_tensor, filter_num_begin, depth, stack_num_down, stack_num_up, 46 | patch_size, num_heads, window_size, num_mlp, shift_window=True, name='swin_unet'): 47 | ''' 48 | The base of SwinUNET. 49 | 50 | ---------- 51 | Cao, H., Wang, Y., Chen, J., Jiang, D., Zhang, X., Tian, Q. and Wang, M., 2021. 52 | Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation. arXiv preprint arXiv:2105.05537. 53 | 54 | Input 55 | ---------- 56 | input_tensor: the input tensor of the base, e.g., `keras.layers.Inpyt((None, None, 3))`. 57 | filter_num_begin: number of channels in the first downsampling block; 58 | it is also the number of embedded dimensions. 59 | depth: the depth of Swin-UNET, e.g., depth=4 means three down/upsampling levels and a bottom level. 
60 | stack_num_down: number of convolutional layers per downsampling level/block. 61 | stack_num_up: number of convolutional layers (after concatenation) per upsampling level/block. 62 | name: prefix of the created keras model and its layers. 63 | 64 | ---------- (keywords of Swin-Transformers) ---------- 65 | 66 | patch_size: The size of extracted patches, 67 | e.g., patch_size=(2, 2) means 2-by-2 patches 68 | *Height and width of the patch must be equal. 69 | 70 | num_heads: number of attention heads per down/upsampling level, 71 | e.g., num_heads=[4, 8, 16, 16] means increased attention heads with increasing depth. 72 | *The length of num_heads must equal to `depth`. 73 | 74 | window_size: the size of attention window per down/upsampling level, 75 | e.g., window_size=[4, 2, 2, 2] means decreased window size with increasing depth. 76 | 77 | num_mlp: number of MLP nodes. 78 | 79 | shift_window: The indicator of window shifting; 80 | shift_window=True means applying Swin-MSA for every two Swin-Transformer blocks. 81 | shift_window=False means MSA with fixed window locations for all blocks. 82 | 83 | Output 84 | ---------- 85 | output tensor. 86 | 87 | Note: This function is experimental. 88 | The activation functions of all Swin-Transformers are fixed to GELU. 
89 | 90 | ''' 91 | # Compute number be patches to be embeded 92 | input_size = input_tensor.shape.as_list()[1:] 93 | num_patch_x = input_size[0]//patch_size[0] 94 | num_patch_y = input_size[1]//patch_size[1] 95 | 96 | # Number of Embedded dimensions 97 | embed_dim = filter_num_begin 98 | 99 | depth_ = depth 100 | 101 | X_skip = [] 102 | 103 | X = input_tensor 104 | 105 | # Patch extraction 106 | X = patch_extract(patch_size)(X) 107 | 108 | # Embed patches to tokens 109 | X = patch_embedding(num_patch_x*num_patch_y, embed_dim)(X) 110 | 111 | # The first Swin Transformer stack 112 | X = swin_transformer_stack(X, stack_num=stack_num_down, 113 | embed_dim=embed_dim, num_patch=(num_patch_x, num_patch_y), 114 | num_heads=num_heads[0], window_size=window_size[0], num_mlp=num_mlp, 115 | shift_window=shift_window, name='{}_swin_down0'.format(name)) 116 | X_skip.append(X) 117 | 118 | # Downsampling blocks 119 | for i in range(depth_-1): 120 | 121 | # Patch merging 122 | X = patch_merging((num_patch_x, num_patch_y), embed_dim=embed_dim, name='down{}'.format(i))(X) 123 | 124 | # update token shape info 125 | embed_dim = embed_dim*2 126 | num_patch_x = num_patch_x//2 127 | num_patch_y = num_patch_y//2 128 | 129 | # Swin Transformer stacks 130 | X = swin_transformer_stack(X, stack_num=stack_num_down, 131 | embed_dim=embed_dim, num_patch=(num_patch_x, num_patch_y), 132 | num_heads=num_heads[i+1], window_size=window_size[i+1], num_mlp=num_mlp, 133 | shift_window=shift_window, name='{}_swin_down{}'.format(name, i+1)) 134 | 135 | # Store tensors for concat 136 | X_skip.append(X) 137 | 138 | # reverse indexing encoded tensors and hyperparams 139 | X_skip = X_skip[::-1] 140 | num_heads = num_heads[::-1] 141 | window_size = window_size[::-1] 142 | 143 | # upsampling begins at the deepest available tensor 144 | X = X_skip[0] 145 | 146 | # other tensors are preserved for concatenation 147 | X_decode = X_skip[1:] 148 | 149 | depth_decode = len(X_decode) 150 | 151 | for i in 
range(depth_decode): 152 | 153 | # Patch expanding 154 | X = patch_expanding(num_patch=(num_patch_x, num_patch_y), 155 | embed_dim=embed_dim, upsample_rate=2, return_vector=True, name='{}_swin_up{}'.format(name, i))(X) 156 | 157 | 158 | # update token shape info 159 | embed_dim = embed_dim//2 160 | num_patch_x = num_patch_x*2 161 | num_patch_y = num_patch_y*2 162 | 163 | # Concatenation and linear projection 164 | X = concatenate([X, X_decode[i]], axis=-1, name='{}_concat_{}'.format(name, i)) 165 | X = Dense(embed_dim, use_bias=False, name='{}_concat_linear_proj_{}'.format(name, i))(X) 166 | 167 | # Swin Transformer stacks 168 | X = swin_transformer_stack(X, stack_num=stack_num_up, 169 | embed_dim=embed_dim, num_patch=(num_patch_x, num_patch_y), 170 | num_heads=num_heads[i], window_size=window_size[i], num_mlp=num_mlp, 171 | shift_window=shift_window, name='{}_swin_up{}'.format(name, i)) 172 | 173 | # The last expanding layer; it produces full-size feature maps based on the patch size 174 | # !!! <--- "patch_size[0]" is used; it assumes patch_size = (size, size) 175 | X = patch_expanding(num_patch=(num_patch_x, num_patch_y), 176 | embed_dim=embed_dim, upsample_rate=patch_size[0], return_vector=False)(X) 177 | 178 | return X 179 | 180 | 181 | def swin_unet_2d(input_size, filter_num_begin, n_labels, depth, stack_num_down, stack_num_up, 182 | patch_size, num_heads, window_size, num_mlp, output_activation='Softmax', shift_window=True, name='swin_unet'): 183 | ''' 184 | The base of SwinUNET. 185 | 186 | ---------- 187 | Cao, H., Wang, Y., Chen, J., Jiang, D., Zhang, X., Tian, Q. and Wang, M., 2021. 188 | Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation. arXiv preprint arXiv:2105.05537. 189 | 190 | Input 191 | ---------- 192 | input_size: the size/shape of network input, e.g., `(128, 128, 3)`. 193 | filter_num_begin: number of channels in the first downsampling block; 194 | it is also the number of embedded dimensions. 
195 | n_labels: number of output labels. 196 | depth: the depth of Swin-UNET, e.g., depth=4 means three down/upsampling levels and a bottom level. 197 | stack_num_down: the `stack_num` of the Swin Transformer stack at each downsampling level/block. 198 | stack_num_up: the `stack_num` of the Swin Transformer stack (applied after concatenation) at each upsampling level/block. 199 | name: prefix of the created keras model and its layers. 200 | 201 | ---------- (keywords of Swin-Transformers) ---------- 202 | 203 | patch_size: The size of extracted patches, 204 | e.g., patch_size=(2, 2) means 2-by-2 patches 205 | *Height and width of the patch must be equal. 206 | 207 | num_heads: number of attention heads per down/upsampling level, 208 | e.g., num_heads=[4, 8, 16, 16] means increased attention heads with increasing depth. 209 | *The length of num_heads must equal `depth`. 210 | 211 | window_size: the size of attention window per down/upsampling level, 212 | e.g., window_size=[4, 2, 2, 2] means decreased window size with increasing depth. 213 | 214 | num_mlp: number of MLP nodes. 215 | 216 | shift_window: The indicator of window shifting; 217 | shift_window=True means applying Swin-MSA for every two Swin-Transformer blocks. 218 | shift_window=False means MSA with fixed window locations for all blocks. 219 | 220 | Output 221 | ---------- 222 | model: a keras model. 223 | 224 | Note: This function is experimental. 225 | The activation functions of all Swin-Transformers are fixed to GELU.
226 | ''' 227 | IN = Input(input_size) 228 | 229 | # base 230 | X = swin_unet_2d_base(IN, filter_num_begin=filter_num_begin, depth=depth, stack_num_down=stack_num_down, stack_num_up=stack_num_up, 231 | patch_size=patch_size, num_heads=num_heads, window_size=window_size, num_mlp=num_mlp, shift_window=shift_window, name=name) 232 | 233 | # output layer 234 | OUT = CONV_output(X, n_labels, kernel_size=1, activation=output_activation, name='{}_output'.format(name)) 235 | 236 | # functional API model 237 | model = Model(inputs=[IN,], outputs=[OUT,], name='{}_model'.format(name)) 238 | 239 | return model 240 | -------------------------------------------------------------------------------- /keras_unet_collection/_model_transunet_2d.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | from keras_unet_collection.layer_utils import * 5 | from keras_unet_collection.activations import GELU, Snake 6 | from keras_unet_collection._model_unet_2d import UNET_left, UNET_right 7 | from keras_unet_collection.transformer_layers import patch_extract, patch_embedding 8 | from keras_unet_collection._backbone_zoo import backbone_zoo, bach_norm_checker 9 | 10 | import tensorflow as tf 11 | from tensorflow.keras.layers import Input 12 | from tensorflow.keras.models import Model 13 | from tensorflow.keras.layers import Layer, MultiHeadAttention, LayerNormalization, Dense, Embedding 14 | 15 | def ViT_MLP(X, filter_num, activation='GELU', name='MLP'): 16 | ''' 17 | The MLP block of ViT. 18 | 19 | ---------- 20 | Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, 21 | T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S. and Uszkoreit, J., 2020. 22 | An image is worth 16x16 words: Transformers for image recognition at scale. 23 | arXiv preprint arXiv:2010.11929. 
24 | 25 | Input 26 | ---------- 27 | X: the input tensor of MLP, i.e., after MSA and skip connections 28 | filter_num: a list that defines the number of nodes for each MLP layer. 29 | For the last MLP layer, its number of node must equal to the dimension of key. 30 | activation: activation of MLP nodes. 31 | name: prefix of the created keras layers. 32 | 33 | Output 34 | ---------- 35 | V: output tensor. 36 | 37 | ''' 38 | activation_func = eval(activation) 39 | 40 | for i, f in enumerate(filter_num): 41 | X = Dense(f, name='{}_dense_{}'.format(name, i))(X) 42 | X = activation_func(name='{}_activation_{}'.format(name, i))(X) 43 | 44 | return X 45 | 46 | def ViT_block(V, num_heads, key_dim, filter_num_MLP, activation='GELU', name='ViT'): 47 | ''' 48 | 49 | Vision transformer (ViT) block. 50 | 51 | ViT_block(V, num_heads, key_dim, filter_num_MLP, activation='GELU', name='ViT') 52 | 53 | ---------- 54 | Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, 55 | T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S. and Uszkoreit, J., 2020. 56 | An image is worth 16x16 words: Transformers for image recognition at scale. 57 | arXiv preprint arXiv:2010.11929. 58 | 59 | Input 60 | ---------- 61 | V: embedded input features. 62 | num_heads: number of attention heads. 63 | key_dim: dimension of the attention key (equals to the embeded dimensions). 64 | filter_num_MLP: a list that defines the number of nodes for each MLP layer. 65 | For the last MLP layer, its number of node must equal to the dimension of key. 66 | activation: activation of MLP nodes. 67 | name: prefix of the created keras layers. 68 | 69 | Output 70 | ---------- 71 | V: output tensor. 
72 | 73 | ''' 74 | # Multiheaded self-attention (MSA) 75 | V_atten = V # <--- skip 76 | V_atten = LayerNormalization(name='{}_layer_norm_1'.format(name))(V_atten) 77 | V_atten = MultiHeadAttention(num_heads=num_heads, key_dim=key_dim, 78 | name='{}_atten'.format(name))(V_atten, V_atten) 79 | # Skip connection 80 | V_add = add([V_atten, V], name='{}_skip_1'.format(name)) # <--- skip 81 | 82 | # MLP 83 | V_MLP = V_add # <--- skip 84 | V_MLP = LayerNormalization(name='{}_layer_norm_2'.format(name))(V_MLP) 85 | V_MLP = ViT_MLP(V_MLP, filter_num_MLP, activation, name='{}_mlp'.format(name)) 86 | # Skip connection 87 | V_out = add([V_MLP, V_add], name='{}_skip_2'.format(name)) # <--- skip 88 | 89 | return V_out 90 | 91 | 92 | def transunet_2d_base(input_tensor, filter_num, stack_num_down=2, stack_num_up=2, 93 | embed_dim=768, num_mlp=3072, num_heads=12, num_transformer=12, 94 | activation='ReLU', mlp_activation='GELU', batch_norm=False, pool=True, unpool=True, 95 | backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='transunet'): 96 | ''' 97 | The base of transUNET with an optional ImageNet-trained backbone. 98 | 99 | ---------- 100 | Chen, J., Lu, Y., Yu, Q., Luo, X., Adeli, E., Wang, Y., Lu, L., Yuille, A.L. and Zhou, Y., 2021. 101 | Transunet: Transformers make strong encoders for medical image segmentation. arXiv preprint arXiv:2102.04306. 102 | 103 | Input 104 | ---------- 105 | input_tensor: the input tensor of the base, e.g., `keras.layers.Inpyt((None, None, 3))`. 106 | filter_num: a list that defines the number of filters for each \ 107 | down- and upsampling levels. e.g., `[64, 128, 256, 512]`. 108 | The depth is expected as `len(filter_num)`. 109 | stack_num_down: number of convolutional layers per downsampling level/block. 110 | stack_num_up: number of convolutional layers (after concatenation) per upsampling level/block. 
111 | activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'. 112 | batch_norm: True for batch normalization. 113 | pool: True or 'max' for MaxPooling2D. 114 | 'ave' for AveragePooling2D. 115 | False for strided conv + batch norm + activation. 116 | unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation. 117 | 'nearest' for Upsampling2D with nearest interpolation. 118 | False for Conv2DTranspose + batch norm + activation. 119 | name: prefix of the created keras model and its layers. 120 | 121 | ---------- (keywords of ViT) ---------- 122 | embed_dim: number of embedded dimensions. 123 | num_mlp: number of MLP nodes. 124 | num_heads: number of attention heads. 125 | num_transformer: number of stacked ViTs. 126 | mlp_activation: activation of MLP nodes. 127 | 128 | ---------- (keywords of backbone options) ---------- 129 | backbone: the backbone model name. Should be one of the `tensorflow.keras.applications` classes. 130 | None (default) means no backbone. 131 | Currently supported backbones are: 132 | (1) VGG16, VGG19 133 | (2) ResNet50, ResNet101, ResNet152 134 | (3) ResNet50V2, ResNet101V2, ResNet152V2 135 | (4) DenseNet121, DenseNet169, DenseNet201 136 | (5) EfficientNetB[0-7] 137 | weights: one of None (random initialization), 'imagenet' (pre-training on ImageNet), 138 | or the path to the weights file to be loaded. 139 | freeze_backbone: True for a frozen backbone. 140 | freeze_batch_norm: False for not freezing batch normalization layers. 141 | 142 | Output 143 | ---------- 144 | X: output tensor.
145 | 146 | ''' 147 | activation_func = eval(activation) 148 | 149 | X_skip = [] 150 | depth_ = len(filter_num) 151 | 152 | # ----- internal parameters ----- # 153 | 154 | # patch size (fixed to 1-by-1) 155 | patch_size = 1 156 | 157 | # input tensor size 158 | input_size = input_tensor.shape[1] 159 | 160 | # encoded feature map size 161 | encode_size = input_size // 2**(depth_-1) 162 | 163 | # number of size-1 patches 164 | num_patches = encode_size ** 2 165 | 166 | # dimension of the attention key (= dimension of embedings) 167 | key_dim = embed_dim 168 | 169 | # number of MLP nodes 170 | filter_num_MLP = [num_mlp, embed_dim] 171 | 172 | # ----- UNet-like downsampling ----- # 173 | 174 | # no backbone cases 175 | if backbone is None: 176 | 177 | X = input_tensor 178 | 179 | # stacked conv2d before downsampling 180 | X = CONV_stack(X, filter_num[0], stack_num=stack_num_down, activation=activation, 181 | batch_norm=batch_norm, name='{}_down0'.format(name)) 182 | X_skip.append(X) 183 | 184 | # downsampling blocks 185 | for i, f in enumerate(filter_num[1:]): 186 | X = UNET_left(X, f, stack_num=stack_num_down, activation=activation, pool=pool, 187 | batch_norm=batch_norm, name='{}_down{}'.format(name, i+1)) 188 | X_skip.append(X) 189 | 190 | # backbone cases 191 | else: 192 | # handling VGG16 and VGG19 separately 193 | if 'VGG' in backbone: 194 | backbone_ = backbone_zoo(backbone, weights, input_tensor, depth_, freeze_backbone, freeze_batch_norm) 195 | # collecting backbone feature maps 196 | X_skip = backbone_([input_tensor,]) 197 | depth_encode = len(X_skip) 198 | 199 | # for other backbones 200 | else: 201 | backbone_ = backbone_zoo(backbone, weights, input_tensor, depth_-1, freeze_backbone, freeze_batch_norm) 202 | # collecting backbone feature maps 203 | X_skip = backbone_([input_tensor,]) 204 | depth_encode = len(X_skip) + 1 205 | 206 | 207 | # extra conv2d blocks are applied 208 | # if downsampling levels of a backbone < user-specified downsampling levels 209 | 
if depth_encode < depth_: 210 | 211 | # begins at the deepest available tensor 212 | X = X_skip[-1] 213 | 214 | # extra downsamplings 215 | for i in range(depth_-depth_encode): 216 | i_real = i + depth_encode 217 | 218 | X = UNET_left(X, filter_num[i_real], stack_num=stack_num_down, activation=activation, pool=pool, 219 | batch_norm=batch_norm, name='{}_down{}'.format(name, i_real+1)) 220 | X_skip.append(X) 221 | 222 | # subtrack the last tensor (will be replaced by the ViT output) 223 | X = X_skip[-1] 224 | X_skip = X_skip[:-1] 225 | 226 | # 1-by-1 linear transformation before entering ViT blocks 227 | X = Conv2D(filter_num[-1], 1, padding='valid', use_bias=False, name='{}_conv_trans_before'.format(name))(X) 228 | 229 | X = patch_extract((patch_size, patch_size))(X) 230 | X = patch_embedding(num_patches, embed_dim)(X) 231 | 232 | # stacked ViTs 233 | for i in range(num_transformer): 234 | X = ViT_block(X, num_heads, key_dim, filter_num_MLP, activation=mlp_activation, 235 | name='{}_ViT_{}'.format(name, i)) 236 | 237 | # reshape patches to feature maps 238 | X = tf.reshape(X, (-1, encode_size, encode_size, embed_dim)) 239 | 240 | # 1-by-1 linear transformation to adjust the number of channels 241 | X = Conv2D(filter_num[-1], 1, padding='valid', use_bias=False, name='{}_conv_trans_after'.format(name))(X) 242 | 243 | X_skip.append(X) 244 | 245 | # ----- UNet-like upsampling ----- # 246 | 247 | # reverse indexing encoded feature maps 248 | X_skip = X_skip[::-1] 249 | # upsampling begins at the deepest available tensor 250 | X = X_skip[0] 251 | # other tensors are preserved for concatenation 252 | X_decode = X_skip[1:] 253 | depth_decode = len(X_decode) 254 | 255 | # reverse indexing filter numbers 256 | filter_num_decode = filter_num[:-1][::-1] 257 | 258 | # upsampling with concatenation 259 | for i in range(depth_decode): 260 | X = UNET_right(X, [X_decode[i],], filter_num_decode[i], stack_num=stack_num_up, activation=activation, 261 | unpool=unpool, 
batch_norm=batch_norm, name='{}_up{}'.format(name, i)) 262 | 263 | # if tensors for concatenation is not enough 264 | # then use upsampling without concatenation 265 | if depth_decode < depth_-1: 266 | for i in range(depth_-depth_decode-1): 267 | i_real = i + depth_decode 268 | X = UNET_right(X, None, filter_num_decode[i_real], stack_num=stack_num_up, activation=activation, 269 | unpool=unpool, batch_norm=batch_norm, concat=False, name='{}_up{}'.format(name, i_real)) 270 | 271 | return X 272 | 273 | def transunet_2d(input_size, filter_num, n_labels, stack_num_down=2, stack_num_up=2, 274 | embed_dim=768, num_mlp = 3072, num_heads=12, num_transformer=12, 275 | activation='ReLU', mlp_activation='GELU', output_activation='Softmax', batch_norm=False, pool=True, unpool=True, 276 | backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='transunet'): 277 | ''' 278 | TransUNET with an optional ImageNet-trained bakcbone. 279 | 280 | 281 | ---------- 282 | Chen, J., Lu, Y., Yu, Q., Luo, X., Adeli, E., Wang, Y., Lu, L., Yuille, A.L. and Zhou, Y., 2021. 283 | Transunet: Transformers make strong encoders for medical image segmentation. arXiv preprint arXiv:2102.04306. 284 | 285 | Input 286 | ---------- 287 | input_size: the size/shape of network input, e.g., `(128, 128, 3)`. 288 | filter_num: a list that defines the number of filters for each \ 289 | down- and upsampling levels. e.g., `[64, 128, 256, 512]`. 290 | The depth is expected as `len(filter_num)`. 291 | n_labels: number of output labels. 292 | stack_num_down: number of convolutional layers per downsampling level/block. 293 | stack_num_up: number of convolutional layers (after concatenation) per upsampling level/block. 294 | activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'. 295 | output_activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interface or 'Sigmoid'. 
296 | Default option is 'Softmax'. 297 | if None is received, then linear activation is applied. 298 | batch_norm: True for batch normalization. 299 | pool: True or 'max' for MaxPooling2D. 300 | 'ave' for AveragePooling2D. 301 | False for strided conv + batch norm + activation. 302 | unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation. 303 | 'nearest' for Upsampling2D with nearest interpolation. 304 | False for Conv2DTranspose + batch norm + activation. 305 | name: prefix of the created keras model and its layers. 306 | 307 | ---------- (keywords of ViT) ---------- 308 | embed_dim: number of embedded dimensions. 309 | num_mlp: number of MLP nodes. 310 | num_heads: number of attention heads. 311 | num_transformer: number of stacked ViTs. 312 | mlp_activation: activation of MLP nodes. 313 | 314 | ---------- (keywords of backbone options) ---------- 315 | backbone: the backbone model name. Should be one of the `tensorflow.keras.applications` classes. 316 | None (default) means no backbone. 317 | Currently supported backbones are: 318 | (1) VGG16, VGG19 319 | (2) ResNet50, ResNet101, ResNet152 320 | (3) ResNet50V2, ResNet101V2, ResNet152V2 321 | (4) DenseNet121, DenseNet169, DenseNet201 322 | (5) EfficientNetB[0-7] 323 | weights: one of None (random initialization), 'imagenet' (pre-training on ImageNet), 324 | or the path to the weights file to be loaded. 325 | freeze_backbone: True for a frozen backbone. 326 | freeze_batch_norm: False for not freezing batch normalization layers. 327 | 328 | Output 329 | ---------- 330 | model: a keras model.
331 | 332 | ''' 333 | 334 | activation_func = eval(activation) 335 | 336 | IN = Input(input_size) 337 | 338 | # base 339 | X = transunet_2d_base(IN, filter_num, stack_num_down=stack_num_down, stack_num_up=stack_num_up, 340 | embed_dim=embed_dim, num_mlp=num_mlp, num_heads=num_heads, num_transformer=num_transformer, 341 | activation=activation, mlp_activation=mlp_activation, batch_norm=batch_norm, pool=pool, unpool=unpool, 342 | backbone=backbone, weights=weights, freeze_backbone=freeze_backbone, freeze_batch_norm=freeze_batch_norm, name=name) 343 | 344 | # output layer 345 | OUT = CONV_output(X, n_labels, kernel_size=1, activation=output_activation, name='{}_output'.format(name)) 346 | 347 | # functional API model 348 | model = Model(inputs=[IN,], outputs=[OUT,], name='{}_model'.format(name)) 349 | 350 | return model 351 | -------------------------------------------------------------------------------- /keras_unet_collection/_model_u2net_2d.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | from keras_unet_collection.layer_utils import * 5 | from keras_unet_collection.activations import GELU, Snake 6 | 7 | from tensorflow.keras.layers import Input 8 | from tensorflow.keras.models import Model 9 | 10 | 11 | def RSU(X, channel_in, channel_out, depth=5, activation='ReLU', batch_norm=True, pool=True, unpool=True, name='RSU'): 12 | ''' 13 | The Residual U-blocks (RSU). 14 | 15 | RSU(X, channel_in, channel_out, depth=5, activation='ReLU', batch_norm=True, pool=True, unpool=True, name='RSU') 16 | 17 | ---------- 18 | Qin, X., Zhang, Z., Huang, C., Dehghan, M., Zaiane, O.R. and Jagersand, M., 2020. 19 | U2-Net: Going deeper with nested U-structure for salient object detection. 20 | Pattern Recognition, 106, p.107404. 21 | 22 | Input 23 | ---------- 24 | X: input tensor. 25 | channel_in: number of intermediate channels. 26 | channel_out: number of output channels. 
27 | depth: number of down- and upsampling levels. 28 | activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'. 29 | batch_norm: True for batch normalization, False otherwise. 30 | pool: True or 'max' for MaxPooling2D. 31 | 'ave' for AveragePooling2D. 32 | False for strided conv + batch norm + activation. 33 | unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation. 34 | 'nearest' for Upsampling2D with nearest interpolation. 35 | False for Conv2DTranspose + batch norm + activation. 36 | name: prefix of the created keras layers. 37 | 38 | Output 39 | ---------- 40 | X: output tensor. 41 | 42 | ''' 43 | 44 | pool_size = 2 45 | 46 | X_skip = [] 47 | 48 | X = CONV_stack(X, channel_out, kernel_size=3, stack_num=1, 49 | dilation_rate=1, activation=activation, batch_norm=batch_norm, 50 | name='{}_in'.format(name)) 51 | X_skip.append(X) 52 | 53 | X = CONV_stack(X, channel_in, kernel_size=3, stack_num=1, dilation_rate=1, 54 | activation=activation, batch_norm=batch_norm, name='{}_down_0'.format(name)) 55 | X_skip.append(X) 56 | 57 | for i in range(depth): 58 | 59 | X = encode_layer(X, channel_in, pool_size, pool, activation=activation, 60 | batch_norm=batch_norm, name='{}_encode_{}'.format(name, i)) 61 | 62 | X = CONV_stack(X, channel_in, kernel_size=3, stack_num=1, dilation_rate=1, 63 | activation=activation, batch_norm=batch_norm, name='{}_down_{}'.format(name, i+1)) 64 | X_skip.append(X) 65 | 66 | X = CONV_stack(X, channel_in, kernel_size=3, stack_num=1, 67 | dilation_rate=2, activation=activation, batch_norm=batch_norm, 68 | name='{}_up_0'.format(name)) 69 | 70 | X_skip = X_skip[::-1] 71 | 72 | for i in range(depth): 73 | 74 | X = concatenate([X, X_skip[i]], axis=-1, name='{}_concat_{}'.format(name, i)) 75 | 76 | X = CONV_stack(X, channel_in, kernel_size=3, stack_num=1, dilation_rate=1, 77 | activation=activation, batch_norm=batch_norm, name='{}_up_{}'.format(name, i+1)) 78 | 79 | X = 
decode_layer(X, channel_in, pool_size, unpool, 80 | activation=activation, batch_norm=batch_norm, name='{}_decode_{}'.format(name, i)) 81 | 82 | X = concatenate([X, X_skip[depth]], axis=-1, name='{}_concat_out'.format(name)) 83 | 84 | X = CONV_stack(X, channel_out, kernel_size=3, stack_num=1, dilation_rate=1, 85 | activation=activation, batch_norm=batch_norm, name='{}_out'.format(name)) 86 | X = add([X, X_skip[-1]], name='{}_out_add'.format(name)) 87 | return X 88 | 89 | def RSU4F(X, channel_in, channel_out, dilation_num=[1, 2, 4, 8], activation='ReLU', batch_norm=True, name='RSU4F'): 90 | ''' 91 | The Residual U-blocks with dilated convolutional kernels (RSU4F). 92 | 93 | RSU4F(X, channel_in, channel_out, dilation_num=[1, 2, 4, 8], activation='ReLU', batch_norm=True, name='RSU4F') 94 | 95 | ---------- 96 | Qin, X., Zhang, Z., Huang, C., Dehghan, M., Zaiane, O.R. and Jagersand, M., 2020. 97 | U2-Net: Going deeper with nested U-structure for salient object detection. 98 | Pattern Recognition, 106, p.107404. 99 | 100 | Input 101 | ---------- 102 | X: input tensor. 103 | channel_in: number of intermediate channels. 104 | channel_out: number of output channels. 105 | dilation_num: an iterable that defines dilation rates of convolutional layers. 106 | Qin et al. (2020) suggested `[1, 2, 4, 8]`. 107 | activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'. 108 | batch_norm: True for batch normalization, False otherwise. 109 | name: prefix of the created keras layers. 
110 | 111 | Output 112 | ---------- 113 | X: output tensor 114 | 115 | ''' 116 | 117 | X_skip = [] 118 | X = CONV_stack(X, channel_out, kernel_size=3, stack_num=1, dilation_rate=1, 119 | activation=activation, batch_norm=batch_norm, name='{}_in'.format(name)) 120 | X_skip.append(X) 121 | 122 | for i, d in enumerate(dilation_num): 123 | 124 | X = CONV_stack(X, channel_in, kernel_size=3, stack_num=1, dilation_rate=d, 125 | activation=activation, batch_norm=batch_norm, name='{}_down_{}'.format(name, i)) 126 | X_skip.append(X) 127 | 128 | X_skip = X_skip[:-1][::-1] 129 | dilation_num = dilation_num[:-1][::-1] 130 | 131 | for i, d in enumerate(dilation_num[:-1]): 132 | 133 | X = concatenate([X, X_skip[i]], axis=-1, name='{}_concat_{}'.format(name, i)) 134 | X = CONV_stack(X, channel_in, kernel_size=3, stack_num=1, dilation_rate=d, 135 | activation=activation, batch_norm=batch_norm, name='{}_up_{}'.format(name, i)) 136 | 137 | X = concatenate([X, X_skip[2]], axis=-1, name='{}_concat_out'.format(name)) 138 | X = CONV_stack(X, channel_out, kernel_size=3, stack_num=1, dilation_rate=1, 139 | activation=activation, batch_norm=batch_norm, name='{}_out'.format(name)) 140 | 141 | return add([X, X_skip[-1]], name='{}_out_add'.format(name)) 142 | 143 | def u2net_2d_base(input_tensor, 144 | filter_num_down, filter_num_up, 145 | filter_mid_num_down, filter_mid_num_up, 146 | filter_4f_num, filter_4f_mid_num, activation='ReLU', 147 | batch_norm=False, pool=True, unpool=True, name='u2net'): 148 | 149 | ''' 150 | The base of U^2-Net 151 | 152 | u2net_2d_base(input_tensor, 153 | filter_num_down, filter_num_up, 154 | filter_mid_num_down, filter_mid_num_up, 155 | filter_4f_num, filter_4f_mid_num, activation='ReLU', 156 | batch_norm=False, pool=True, unpool=True, name='u2net') 157 | 158 | ---------- 159 | Qin, X., Zhang, Z., Huang, C., Dehghan, M., Zaiane, O.R. and Jagersand, M., 2020. 160 | U2-Net: Going deeper with nested U-structure for salient object detection. 
161 | Pattern Recognition, 106, p.107404. 162 | 163 | Input 164 | ---------- 165 | input_tensor: the input tensor of the base, e.g., keras.layers.Input((None, None, 3)) 166 | filter_num_down: a list that defines the number of RSU output filters for each 167 | downsampling level. e.g., `[64, 128, 256, 512]`. 168 | the network depth is expected as `len(filter_num_down) + len(filter_4f_num)` 169 | filter_mid_num_down: a list that defines the number of RSU intermediate filters for each 170 | downsampling level. e.g., `[16, 32, 64, 128]`. 171 | * RSU intermediate and output filters must be paired, i.e., lists with the same length. 172 | * RSU intermediate filter numbers are expected to be smaller than output filter numbers. 173 | filter_mid_num_up: a list that defines the number of RSU intermediate filters for each 174 | upsampling level. e.g., `[16, 32, 64, 128]`. 175 | * RSU intermediate and output filters must be paired, i.e., lists with the same length. 176 | * RSU intermediate filter numbers are expected to be smaller than output filter numbers. 177 | filter_4f_num: a list that defines the number of RSU-4F output filters for each 178 | downsampling and bottom level. e.g., `[512, 512]`. 179 | the network depth is expected as `len(filter_num_down) + len(filter_4f_num)`. 180 | filter_4f_mid_num: a list that defines the number of RSU-4F intermediate filters for each 181 | downsampling and bottom level. e.g., `[256, 256]`. 182 | * RSU-4F intermediate and output filters must be paired, i.e., lists with the same length. 183 | * RSU-4F intermediate filter numbers are expected to be smaller than output filter numbers. 184 | activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'. 185 | batch_norm: True for batch normalization. 186 | pool: True or 'max' for MaxPooling2D. 187 | 'ave' for AveragePooling2D. 188 | False for strided conv + batch norm + activation.
189 | unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation. 190 | 'nearest' for Upsampling2D with nearest interpolation. 191 | False for Conv2DTranspose + batch norm + activation. 192 | deep_supervision: True for a model that supports deep supervision. Details see Qin et al. (2020). 193 | name: prefix of the created keras layers. 194 | 195 | Output 196 | ---------- 197 | A list of tensors with the first/second/third tensor obtained from 198 | the deepest/second deepest/third deepest upsampling block, etc. 199 | * The feature map sizes of these tensors are different, 200 | with first tensor has the smallest size. 201 | 202 | * Dilation rates of RSU4F layers are fixed to `[1, 2, 4, 8]`. 203 | * Downsampling is achieved through maxpooling in Qin et al. (2020), 204 | and can be replaced by strided convolutional layers here. 205 | * Upsampling is achieved through bilinear interpolation in Qin et al. (2020), 206 | and can be replaced by transpose convolutional layers here. 207 | 208 | ''' 209 | 210 | pool_size = 2 211 | 212 | X_skip = []; X_out = []; OUT_stack = [] 213 | depth_backup = [] 214 | depth_ = len(filter_num_down) 215 | 216 | X = input_tensor 217 | 218 | X = RSU(X, filter_mid_num_down[0], filter_num_down[0], depth=depth_+1, activation=activation, 219 | batch_norm=batch_norm, pool=pool, unpool=unpool, name='{}_in'.format(name)) 220 | X_skip.append(X) 221 | 222 | depth_backup.append(depth_+1) 223 | 224 | for i, f in enumerate(filter_num_down[1:]): 225 | 226 | X = encode_layer(X, f, pool_size, pool, activation=activation, 227 | batch_norm=batch_norm, name='{}_encode_{}'.format(name, i)) 228 | 229 | X = RSU(X, filter_mid_num_down[i+1], f, depth=depth_-i, activation=activation, 230 | batch_norm=batch_norm, pool=pool, unpool=unpool, name='{}_down_{}'.format(name, i)) 231 | 232 | depth_backup.append(depth_-i) 233 | 234 | X_skip.append(X) 235 | 236 | for i, f in enumerate(filter_4f_num): 237 | 238 | X = encode_layer(X, f, pool_size, pool, 
activation=activation, 239 | batch_norm=batch_norm, name='{}_encode_4f_{}'.format(name, i)) 240 | 241 | X = RSU4F(X, filter_4f_mid_num[i], f, activation=activation, 242 | batch_norm=batch_norm, name='{}_down_4f_{}'.format(name, i)) 243 | X_skip.append(X) 244 | 245 | X_out.append(X) 246 | 247 | # ---------- # 248 | X_skip = X_skip[:-1][::-1] 249 | depth_backup = depth_backup[::-1] 250 | 251 | filter_num_up = filter_num_up[::-1] 252 | filter_mid_num_up = filter_mid_num_up[::-1] 253 | 254 | filter_4f_num = filter_4f_num[:-1][::-1] 255 | filter_4f_mid_num = filter_4f_mid_num[:-1][::-1] 256 | 257 | tensor_count = 0 258 | for i, f in enumerate(filter_4f_num): 259 | 260 | X = decode_layer(X, f, pool_size, unpool, 261 | activation=activation, batch_norm=batch_norm, name='{}_decode_4f_{}'.format(name, i)) 262 | 263 | X = concatenate([X, X_skip[tensor_count]], axis=-1, name='{}_concat_4f_{}'.format(name, i)) 264 | 265 | X = RSU4F(X, filter_4f_mid_num[i], f, activation=activation, 266 | batch_norm=batch_norm, name='{}_up_4f_{}'.format(name, i)) 267 | X_out.append(X) 268 | 269 | tensor_count += 1 270 | 271 | for i, f in enumerate(filter_num_up): 272 | 273 | X = decode_layer(X, f, pool_size, unpool, 274 | activation=activation, batch_norm=batch_norm, name='{}_decode_{}'.format(name, i)) 275 | 276 | X = concatenate([X, X_skip[tensor_count]], axis=-1, name='{}_concat_{}'.format(name, i)) 277 | 278 | X = RSU(X, filter_mid_num_up[i], f, depth=depth_backup[i], 279 | activation=activation, batch_norm=batch_norm, pool=pool, unpool=unpool, name='{}_up_{}'.format(name, i)) 280 | X_out.append(X) 281 | 282 | tensor_count += 1 283 | 284 | return X_out 285 | 286 | 287 | def u2net_2d(input_size, n_labels, filter_num_down, filter_num_up='auto', filter_mid_num_down='auto', filter_mid_num_up='auto', 288 | filter_4f_num='auto', filter_4f_mid_num='auto', activation='ReLU', output_activation='Sigmoid', 289 | batch_norm=False, pool=True, unpool=True, deep_supervision=False, name='u2net'): 290 | 291 
| ''' 292 | U^2-Net 293 | 294 | u2net_2d(input_size, n_labels, filter_num_down, filter_num_up='auto', filter_mid_num_down='auto', filter_mid_num_up='auto', 295 | filter_4f_num='auto', filter_4f_mid_num='auto', activation='ReLU', output_activation='Sigmoid', 296 | batch_norm=False, pool=True, unpool=True, deep_supervision=False, name='u2net') 297 | 298 | ---------- 299 | Qin, X., Zhang, Z., Huang, C., Dehghan, M., Zaiane, O.R. and Jagersand, M., 2020. 300 | U2-Net: Going deeper with nested U-structure for salient object detection. 301 | Pattern Recognition, 106, p.107404. 302 | 303 | Input 304 | ---------- 305 | input_size: the size/shape of network input, e.g., `(128, 128, 3)`. 306 | filter_num_down: a list that defines the number of RSU output filters for each 307 | downsampling level. e.g., `[64, 128, 256, 512]`. 308 | the network depth is expected as `len(filter_num_down) + len(filter_4f_num)` 309 | filter_mid_num_down: a list that defines the number of RSU intermediate filters for each 310 | downsampling level. e.g., `[16, 32, 64, 128]`. 311 | * RSU intermediate and output filters must be paired, i.e., lists with the same length. 312 | * RSU intermediate filter numbers are expected to be smaller than output filter numbers. 313 | filter_mid_num_up: a list that defines the number of RSU intermediate filters for each 314 | upsampling level. e.g., `[16, 32, 64, 128]`. 315 | * RSU intermediate and output filters must be paired, i.e., lists with the same length. 316 | * RSU intermediate filter numbers are expected to be smaller than output filter numbers. 317 | filter_4f_num: a list that defines the number of RSU-4F output filters for each 318 | downsampling and bottom level. e.g., `[512, 512]`. 319 | the network depth is expected as `len(filter_num_down) + len(filter_4f_num)`. 320 | filter_4f_mid_num: a list that defines the number of RSU-4F intermediate filters for each 321 | downsampling and bottom level. e.g., `[256, 256]`.
322 | * RSU-4F intermediate and output filters must be paired, i.e., lists with the same length. 323 | * RSU-4F intermediate filter numbers are expected to be smaller than output filter numbers. 324 | activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'. 325 | output_activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interface or 'Sigmoid'. 326 | Default option is 'Sigmoid'. 327 | if None is received, then linear activation is applied. 328 | batch_norm: True for batch normalization. 329 | pool: True or 'max' for MaxPooling2D. 330 | 'ave' for AveragePooling2D. 331 | False for strided conv + batch norm + activation. 332 | unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation. 333 | 'nearest' for Upsampling2D with nearest interpolation. 334 | False for Conv2DTranspose + batch norm + activation. 335 | deep_supervision: True for a model that supports deep supervision. For details, see Qin et al. (2020). 336 | name: prefix of the created keras layers. 337 | 338 | Output 339 | ---------- 340 | model: a keras model. 341 | 342 | * Automated hyper-parameter estimation will produce a slightly larger network, different from that of Qin et al. (2020). 343 | * Dilation rates of RSU4F layers are fixed to `[1, 2, 4, 8]`. 344 | * The default output activation is sigmoid, the same as in Qin et al. (2020). 345 | * Downsampling is achieved through maxpooling and can be replaced by strided convolutional layers. 346 | * Upsampling is achieved through bilinear interpolation and can be replaced by transpose convolutional layers.
347 | 348 | ''' 349 | 350 | verbose = False 351 | 352 | if filter_num_up == 'auto': 353 | verbose = True 354 | filter_num_up = filter_num_down 355 | 356 | if filter_mid_num_down == 'auto': 357 | verbose = True 358 | filter_mid_num_down = [num//4 for num in filter_num_down] 359 | 360 | if filter_mid_num_up == 'auto': 361 | verbose = True 362 | filter_mid_num_up = filter_mid_num_down 363 | 364 | if filter_4f_num == 'auto': 365 | verbose = True 366 | filter_4f_num = [filter_num_down[-1], filter_num_down[-1]] 367 | 368 | if filter_4f_mid_num == 'auto': 369 | verbose = True 370 | filter_4f_mid_num = [num//2 for num in filter_4f_num] 371 | 372 | if verbose: 373 | print('Automated hyper-parameter determination is applied with the following details:\n----------') 374 | print('\tNumber of RSU output channels within downsampling blocks: filter_num_down = {}'.format(filter_num_down)) 375 | print('\tNumber of RSU intermediate channels within downsampling blocks: filter_mid_num_down = {}'.format(filter_mid_num_down)) 376 | print('\tNumber of RSU output channels within upsampling blocks: filter_num_up = {}'.format(filter_num_up)) 377 | print('\tNumber of RSU intermediate channels within upsampling blocks: filter_mid_num_up = {}'.format(filter_mid_num_up)) 378 | print('\tNumber of RSU-4F output channels within downsampling and bottom blocks: filter_4f_num = {}'.format(filter_4f_num)) 379 | print('\tNumber of RSU-4F intermediate channels within downsampling and bottom blocks: filter_4f_num = {}'.format(filter_4f_mid_num)) 380 | print('----------\nExplicitly specifying keywords listed above if their "auto" settings do not satisfy your needs') 381 | 382 | print("----------\nThe depth of u2net_2d = len(filter_num_down) + len(filter_4f_num) = {}".format(len(filter_num_down)+len(filter_4f_num))) 383 | 384 | X_skip = []; X_out = []; OUT_stack = [] 385 | depth_backup = [] 386 | depth_ = len(filter_num_down) 387 | 388 | IN = Input(shape=input_size) 389 | 390 | # base (before conv + 
activation + upsample) 391 | X_out = u2net_2d_base(IN, 392 | filter_num_down, filter_num_up, 393 | filter_mid_num_down, filter_mid_num_up, 394 | filter_4f_num, filter_4f_mid_num, activation=activation, 395 | batch_norm=batch_norm, pool=pool, unpool=unpool, name=name) 396 | 397 | # output layers 398 | X_out = X_out[::-1] 399 | L_out = len(X_out) 400 | 401 | X = CONV_output(X_out[0], n_labels, kernel_size=3, activation=output_activation, 402 | name='{}_output_sup0'.format(name)) 403 | OUT_stack.append(X) 404 | 405 | for i in range(1, L_out): 406 | 407 | pool_size = 2**(i) 408 | 409 | X = Conv2D(n_labels, 3, padding='same', name='{}_output_conv_{}'.format(name, i))(X_out[i]) 410 | 411 | X = decode_layer(X, n_labels, pool_size, unpool, 412 | activation=None, batch_norm=False, name='{}_sup{}'.format(name, i)) 413 | 414 | if output_activation: 415 | if output_activation == 'Sigmoid': 416 | X = Activation('sigmoid', name='{}_output_sup{}_activation'.format(name, i))(X) 417 | else: 418 | activation_func = eval(output_activation) 419 | X = activation_func(name='{}_output_sup{}_activation'.format(name, i))(X) 420 | 421 | OUT_stack.append(X) 422 | 423 | D = concatenate(OUT_stack, axis=-1, name='{}_output_concat'.format(name)) 424 | 425 | D = CONV_output(D, n_labels, kernel_size=1, activation=output_activation, 426 | name='{}_output_final'.format(name)) 427 | 428 | if deep_supervision: 429 | 430 | OUT_stack.append(D) 431 | print('----------\ndeep_supervision = True\nnames of output tensors are listed as follows ("sup0" is the shallowest supervision layer;\n"final" is the final output layer):\n') 432 | 433 | if output_activation == None: 434 | if unpool is False: 435 | for i in range(L_out): 436 | print('\t{}_output_sup{}_trans_conv'.format(name, i)) 437 | else: 438 | for i in range(L_out): 439 | print('\t{}_output_sup{}_unpool'.format(name, i)) 440 | 441 | print('\t{}_output_final'.format(name)) 442 | 443 | else: 444 | for i in range(L_out): 445 | 
def UNET_left(X, channel, kernel_size=3, stack_num=2, activation='ReLU',
              pool=True, batch_norm=False, name='left0'):
    '''
    The encoder block of U-net.

    UNET_left(X, channel, kernel_size=3, stack_num=2, activation='ReLU',
              pool=True, batch_norm=False, name='left0')

    The input is first passed through `encode_layer` with a pool size of 2,
    then through `CONV_stack` with `stack_num` convolutional layers.

    Input
    ----------
        X: input tensor.
        channel: number of convolution filters.
        kernel_size: size of 2-d convolution kernels.
        stack_num: number of convolutional layers.
        activation: one of the `tensorflow.keras.layers` interface, e.g., 'ReLU'.
        pool: True or 'max' for MaxPooling2D.
              'ave' for AveragePooling2D.
              False for strided conv + batch norm + activation.
        batch_norm: True for batch normalization, False otherwise.
        name: prefix of the created keras layers.

    Output
    ----------
        X: output tensor.

    '''
    pool_size = 2

    # downsampling step
    X = encode_layer(X, channel, pool_size, pool, activation=activation,
                     batch_norm=batch_norm, name='{}_encode'.format(name))

    # stacked convolutions at the new level
    X = CONV_stack(X, channel, kernel_size, stack_num=stack_num, activation=activation,
                   batch_norm=batch_norm, name='{}_conv'.format(name))

    return X


def UNET_right(X, X_list, channel, kernel_size=3,
               stack_num=2, activation='ReLU',
               unpool=True, batch_norm=False, concat=True, name='right0'):
    '''
    The decoder block of U-net.

    UNET_right(X, X_list, channel, kernel_size=3,
               stack_num=2, activation='ReLU',
               unpool=True, batch_norm=False, concat=True, name='right0')

    The input is passed through `decode_layer` with a pool size of 2 and one
    convolutional layer; it is then optionally concatenated with `X_list`
    and passed through `stack_num` further convolutional layers.

    Input
    ----------
        X: input tensor.
        X_list: a list of other tensors that connected to the input tensor.
                It is only read when `concat=True`; None is accepted otherwise.
        channel: number of convolution filters.
        kernel_size: size of 2-d convolution kernels.
        stack_num: number of convolutional layers (after concatenation).
        activation: one of the `tensorflow.keras.layers` interface, e.g., 'ReLU'.
        unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
                'nearest' for Upsampling2D with nearest interpolation.
                False for Conv2DTranspose + batch norm + activation.
        batch_norm: True for batch normalization, False otherwise.
        concat: True for concatenating the corresponded X_list elements.
        name: prefix of the created keras layers.

    Output
    ----------
        X: output tensor.

    '''
    pool_size = 2

    # upsampling step
    X = decode_layer(X, channel, pool_size, unpool,
                     activation=activation, batch_norm=batch_norm, name='{}_decode'.format(name))

    # a single convolutional layer (with `activation`) before concatenation
    X = CONV_stack(X, channel, kernel_size, stack_num=1, activation=activation,
                   batch_norm=batch_norm, name='{}_conv_before_concat'.format(name))

    if concat:
        # <--- *stacked convolutional can be applied here
        # channels-last concatenation of the upsampled tensor and the skip tensors
        X = concatenate([X,]+X_list, axis=3, name=name+'_concat')

    # stacked convolutions after concatenation
    X = CONV_stack(X, channel, kernel_size, stack_num=stack_num, activation=activation,
                   batch_norm=batch_norm, name=name+'_conv_after_concat')

    return X


def unet_2d_base(input_tensor, filter_num, stack_num_down=2, stack_num_up=2,
                 activation='ReLU', batch_norm=False, pool=True, unpool=True,
                 backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='unet'):
    '''
    The base of U-net with an optional ImageNet-trained backbone.

    unet_2d_base(input_tensor, filter_num, stack_num_down=2, stack_num_up=2,
                 activation='ReLU', batch_norm=False, pool=True, unpool=True,
                 backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='unet')

    ----------
    Ronneberger, O., Fischer, P. and Brox, T., 2015, October. U-net: Convolutional networks for biomedical image segmentation.
    In International Conference on Medical image computing and computer-assisted intervention (pp. 234-241). Springer, Cham.

    Input
    ----------
        input_tensor: the input tensor of the base, e.g., `keras.layers.Input((None, None, 3))`.
        filter_num: a list that defines the number of filters for each
                    down- and upsampling levels. e.g., `[64, 128, 256, 512]`.
                    The depth is expected as `len(filter_num)`.
        stack_num_down: number of convolutional layers per downsampling level/block.
        stack_num_up: number of convolutional layers (after concatenation) per upsampling level/block.
        activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
        batch_norm: True for batch normalization.
        pool: True or 'max' for MaxPooling2D.
              'ave' for AveragePooling2D.
              False for strided conv + batch norm + activation.
        unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
                'nearest' for Upsampling2D with nearest interpolation.
                False for Conv2DTranspose + batch norm + activation.
        name: prefix of the created keras model and its layers.

        ---------- (keywords of backbone options) ----------
        backbone: the backbone model name. Should be one of the `tensorflow.keras.applications` class.
                  None (default) means no backbone.
                  Currently supported backbones are:
                  (1) VGG16, VGG19
                  (2) ResNet50, ResNet101, ResNet152
                  (3) ResNet50V2, ResNet101V2, ResNet152V2
                  (4) DenseNet121, DenseNet169, DenseNet201
                  (5) EfficientNetB[0-7]
        weights: one of None (random initialization), 'imagenet' (pre-training on ImageNet),
                 or the path to the weights file to be loaded.
        freeze_backbone: True for a frozen backbone.
        freeze_batch_norm: False for not freezing batch normalization layers.

    Output
    ----------
        X: output tensor.

    '''

    X_skip = []
    depth_ = len(filter_num)

    # no backbone cases
    if backbone is None:

        # stacked conv2d before downsampling
        X = CONV_stack(input_tensor, filter_num[0], stack_num=stack_num_down, activation=activation,
                       batch_norm=batch_norm, name='{}_down0'.format(name))
        X_skip.append(X)

        # downsampling blocks
        for i, f in enumerate(filter_num[1:]):
            X = UNET_left(X, f, stack_num=stack_num_down, activation=activation, pool=pool,
                          batch_norm=batch_norm, name='{}_down{}'.format(name, i+1))
            X_skip.append(X)

    # backbone cases
    else:
        # handling VGG16 and VGG19 separately
        if 'VGG' in backbone:
            backbone_ = backbone_zoo(backbone, weights, input_tensor, depth_, freeze_backbone, freeze_batch_norm)
            # collecting backbone feature maps
            X_skip = backbone_([input_tensor,])
            depth_encode = len(X_skip)

        # for other backbones
        else:
            # NOTE(review): one level fewer is requested and the encoder depth is counted as
            # `len(X_skip) + 1`; presumably these backbones provide no full-resolution
            # feature map -- confirm against `backbone_zoo`.
            backbone_ = backbone_zoo(backbone, weights, input_tensor, depth_-1, freeze_backbone, freeze_batch_norm)
            # collecting backbone feature maps
            X_skip = backbone_([input_tensor,])
            depth_encode = len(X_skip) + 1

        # extra conv2d blocks are applied
        # if downsampling levels of a backbone < user-specified downsampling levels
        if depth_encode < depth_:

            # begins at the deepest available tensor
            X = X_skip[-1]

            # extra downsamplings
            for i_real in range(depth_encode, depth_):
                X = UNET_left(X, filter_num[i_real], stack_num=stack_num_down, activation=activation, pool=pool,
                              batch_norm=batch_norm, name='{}_down{}'.format(name, i_real+1))
                X_skip.append(X)

    # reverse indexing encoded feature maps
    X_skip = X_skip[::-1]
    # upsampling begins at the deepest available tensor
    X = X_skip[0]
    # other tensors are preserved for concatenation
    X_decode = X_skip[1:]
    depth_decode = len(X_decode)

    # reverse indexing filter numbers
    filter_num_decode = filter_num[:-1][::-1]

    # upsampling with concatenation
    for i in range(depth_decode):
        X = UNET_right(X, [X_decode[i],], filter_num_decode[i], stack_num=stack_num_up, activation=activation,
                       unpool=unpool, batch_norm=batch_norm, name='{}_up{}'.format(name, i))

    # if tensors for concatenation is not enough
    # then use upsampling without concatenation
    # (the range is empty when every upsampling level had a skip tensor)
    for i_real in range(depth_decode, depth_-1):
        X = UNET_right(X, None, filter_num_decode[i_real], stack_num=stack_num_up, activation=activation,
                       unpool=unpool, batch_norm=batch_norm, concat=False, name='{}_up{}'.format(name, i_real))
    return X


def unet_2d(input_size, filter_num, n_labels, stack_num_down=2, stack_num_up=2,
            activation='ReLU', output_activation='Softmax', batch_norm=False, pool=True, unpool=True,
            backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='unet'):
    '''
    U-net with an optional ImageNet-trained backbone.

    unet_2d(input_size, filter_num, n_labels, stack_num_down=2, stack_num_up=2,
            activation='ReLU', output_activation='Softmax', batch_norm=False, pool=True, unpool=True,
            backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='unet')

    ----------
    Ronneberger, O., Fischer, P. and Brox, T., 2015, October. U-net: Convolutional networks for biomedical image segmentation.
    In International Conference on Medical image computing and computer-assisted intervention (pp. 234-241). Springer, Cham.

    Input
    ----------
        input_size: the size/shape of network input, e.g., `(128, 128, 3)`.
        filter_num: a list that defines the number of filters for each
                    down- and upsampling levels. e.g., `[64, 128, 256, 512]`.
                    The depth is expected as `len(filter_num)`.
        n_labels: number of output labels.
        stack_num_down: number of convolutional layers per downsampling level/block.
        stack_num_up: number of convolutional layers (after concatenation) per upsampling level/block.
        activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
        output_activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interface or 'Sigmoid'.
                           Default option is 'Softmax'.
                           if None is received, then linear activation is applied.
        batch_norm: True for batch normalization.
        pool: True or 'max' for MaxPooling2D.
              'ave' for AveragePooling2D.
              False for strided conv + batch norm + activation.
        unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
                'nearest' for Upsampling2D with nearest interpolation.
                False for Conv2DTranspose + batch norm + activation.
        name: prefix of the created keras model and its layers.

        ---------- (keywords of backbone options) ----------
        backbone: the backbone model name. Should be one of the `tensorflow.keras.applications` class.
                  None (default) means no backbone.
                  Currently supported backbones are:
                  (1) VGG16, VGG19
                  (2) ResNet50, ResNet101, ResNet152
                  (3) ResNet50V2, ResNet101V2, ResNet152V2
                  (4) DenseNet121, DenseNet169, DenseNet201
                  (5) EfficientNetB[0-7]
        weights: one of None (random initialization), 'imagenet' (pre-training on ImageNet),
                 or the path to the weights file to be loaded.
        freeze_backbone: True for a frozen backbone.
        freeze_batch_norm: False for not freezing batch normalization layers.

    Output
    ----------
        model: a keras model.

    '''

    if backbone is not None:
        bach_norm_checker(backbone, batch_norm)

    IN = Input(input_size)

    # base
    # `freeze_batch_norm` is forwarded as given; it was previously overwritten by `freeze_backbone`
    X = unet_2d_base(IN, filter_num, stack_num_down=stack_num_down, stack_num_up=stack_num_up,
                     activation=activation, batch_norm=batch_norm, pool=pool, unpool=unpool,
                     backbone=backbone, weights=weights, freeze_backbone=freeze_backbone,
                     freeze_batch_norm=freeze_batch_norm, name=name)

    # output layer
    OUT = CONV_output(X, n_labels, kernel_size=1, activation=output_activation, name='{}_output'.format(name))

    # functional API model
    model = Model(inputs=[IN,], outputs=[OUT,], name='{}_model'.format(name))

    return model
17 | 18 | unet_3plus_2d_base(input_tensor, filter_num_down, filter_num_skip, filter_num_aggregate, 19 | stack_num_down=2, stack_num_up=1, activation='ReLU', batch_norm=False, pool=True, unpool=True, 20 | backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='unet3plus') 21 | 22 | ---------- 23 | Huang, H., Lin, L., Tong, R., Hu, H., Zhang, Q., Iwamoto, Y., Han, X., Chen, Y.W. and Wu, J., 2020. 24 | UNet 3+: A Full-Scale Connected UNet for Medical Image Segmentation. 25 | In ICASSP 2020-2020 IEEE International Conference on Acoustics, 26 | Speech and Signal Processing (ICASSP) (pp. 1055-1059). IEEE. 27 | 28 | Input 29 | ---------- 30 | input_tensor: the input tensor of the base, e.g., `keras.layers.Inpyt((None, None, 3))`. 31 | filter_num_down: a list that defines the number of filters for each 32 | downsampling level. e.g., `[64, 128, 256, 512, 1024]`. 33 | the network depth is expected as `len(filter_num_down)` 34 | filter_num_skip: a list that defines the number of filters after each 35 | full-scale skip connection. Number of elements is expected to be `depth-1`. 36 | i.e., the bottom level is not included. 37 | * Huang et al. (2020) applied the same numbers for all levels. 38 | e.g., `[64, 64, 64, 64]`. 39 | filter_num_aggregate: an int that defines the number of channels of full-scale aggregations. 40 | stack_num_down: number of convolutional layers per downsampling level/block. 41 | stack_num_up: number of convolutional layers (after full-scale concat) per upsampling level/block. 42 | activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., ReLU 43 | batch_norm: True for batch normalization. 44 | pool: True or 'max' for MaxPooling2D. 45 | 'ave' for AveragePooling2D. 46 | False for strided conv + batch norm + activation. 47 | unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation. 48 | 'nearest' for Upsampling2D with nearest interpolation. 
49 | False for Conv2DTranspose + batch norm + activation. 50 | name: prefix of the created keras model and its layers. 51 | 52 | ---------- (keywords of backbone options) ---------- 53 | backbone_name: the bakcbone model name. Should be one of the `tensorflow.keras.applications` class. 54 | None (default) means no backbone. 55 | Currently supported backbones are: 56 | (1) VGG16, VGG19 57 | (2) ResNet50, ResNet101, ResNet152 58 | (3) ResNet50V2, ResNet101V2, ResNet152V2 59 | (4) DenseNet121, DenseNet169, DenseNet201 60 | (5) EfficientNetB[0-7] 61 | weights: one of None (random initialization), 'imagenet' (pre-training on ImageNet), 62 | or the path to the weights file to be loaded. 63 | freeze_backbone: True for a frozen backbone. 64 | freeze_batch_norm: False for not freezing batch normalization layers. 65 | 66 | * Downsampling is achieved through maxpooling and can be replaced by strided convolutional layers here. 67 | * Upsampling is achieved through bilinear interpolation and can be replaced by transpose convolutional layers here. 68 | 69 | Output 70 | ---------- 71 | A list of tensors with the first/second/third tensor obtained from 72 | the deepest/second deepest/third deepest upsampling block, etc. 73 | * The feature map sizes of these tensors are different, 74 | with the first tensor has the smallest size. 
75 | 76 | ''' 77 | 78 | depth_ = len(filter_num_down) 79 | 80 | X_encoder = [] 81 | X_decoder = [] 82 | 83 | # no backbone cases 84 | if backbone is None: 85 | 86 | X = input_tensor 87 | 88 | # stacked conv2d before downsampling 89 | X = CONV_stack(X, filter_num_down[0], kernel_size=3, stack_num=stack_num_down, 90 | activation=activation, batch_norm=batch_norm, name='{}_down0'.format(name)) 91 | X_encoder.append(X) 92 | 93 | # downsampling levels 94 | for i, f in enumerate(filter_num_down[1:]): 95 | 96 | # UNET-like downsampling 97 | X = UNET_left(X, f, kernel_size=3, stack_num=stack_num_down, activation=activation, 98 | pool=pool, batch_norm=batch_norm, name='{}_down{}'.format(name, i+1)) 99 | X_encoder.append(X) 100 | 101 | else: 102 | # handling VGG16 and VGG19 separately 103 | if 'VGG' in backbone: 104 | backbone_ = backbone_zoo(backbone, weights, input_tensor, depth_, freeze_backbone, freeze_batch_norm) 105 | # collecting backbone feature maps 106 | X_encoder = backbone_([input_tensor,]) 107 | depth_encode = len(X_encoder) 108 | 109 | # for other backbones 110 | else: 111 | backbone_ = backbone_zoo(backbone, weights, input_tensor, depth_-1, freeze_backbone, freeze_batch_norm) 112 | # collecting backbone feature maps 113 | X_encoder = backbone_([input_tensor,]) 114 | depth_encode = len(X_encoder) + 1 115 | 116 | # extra conv2d blocks are applied 117 | # if downsampling levels of a backbone < user-specified downsampling levels 118 | if depth_encode < depth_: 119 | 120 | # begins at the deepest available tensor 121 | X = X_encoder[-1] 122 | 123 | # extra downsamplings 124 | for i in range(depth_-depth_encode): 125 | 126 | i_real = i + depth_encode 127 | 128 | X = UNET_left(X, filter_num_down[i_real], stack_num=stack_num_down, activation=activation, pool=pool, 129 | batch_norm=batch_norm, name='{}_down{}'.format(name, i_real+1)) 130 | X_encoder.append(X) 131 | 132 | 133 | # treat the last encoded tensor as the first decoded tensor 134 | 
X_decoder.append(X_encoder[-1]) 135 | 136 | # upsampling levels 137 | X_encoder = X_encoder[::-1] 138 | 139 | depth_decode = len(X_encoder)-1 140 | 141 | # loop over upsampling levels 142 | for i in range(depth_decode): 143 | 144 | f = filter_num_skip[i] 145 | 146 | # collecting tensors for layer fusion 147 | X_fscale = [] 148 | 149 | # for each upsampling level, loop over all available downsampling levels (similar to the unet++) 150 | for lev in range(depth_decode): 151 | 152 | # counting scale difference between the current down- and upsampling levels 153 | pool_scale = lev-i-1 # -1 for python indexing 154 | 155 | # deeper tensors are obtained from **decoder** outputs 156 | if pool_scale < 0: 157 | pool_size = 2**(-1*pool_scale) 158 | 159 | X = decode_layer(X_decoder[lev], f, pool_size, unpool, 160 | activation=activation, batch_norm=batch_norm, name='{}_up_{}_en{}'.format(name, i, lev)) 161 | 162 | # unet skip connection (identity mapping) 163 | elif pool_scale == 0: 164 | 165 | X = X_encoder[lev] 166 | 167 | # shallower tensors are obtained from **encoder** outputs 168 | else: 169 | pool_size = 2**(pool_scale) 170 | 171 | X = encode_layer(X_encoder[lev], f, pool_size, pool, activation=activation, 172 | batch_norm=batch_norm, name='{}_down_{}_en{}'.format(name, i, lev)) 173 | 174 | # a conv layer after feature map scale change 175 | X = CONV_stack(X, f, kernel_size=3, stack_num=1, 176 | activation=activation, batch_norm=batch_norm, name='{}_down_from{}_to{}'.format(name, i, lev)) 177 | 178 | X_fscale.append(X) 179 | 180 | # layer fusion at the end of each level 181 | # stacked conv layers after concat. 
def unet_3plus_2d(input_size, n_labels, filter_num_down, filter_num_skip='auto', filter_num_aggregate='auto',
                  stack_num_down=2, stack_num_up=1, activation='ReLU', output_activation='Sigmoid',
                  batch_norm=False, pool=True, unpool=True, deep_supervision=False,
                  backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='unet3plus'):

    '''
    UNET 3+ with an optional ImageNet-trained backbone.

    unet_3plus_2d(input_size, n_labels, filter_num_down, filter_num_skip='auto', filter_num_aggregate='auto',
                  stack_num_down=2, stack_num_up=1, activation='ReLU', output_activation='Sigmoid',
                  batch_norm=False, pool=True, unpool=True, deep_supervision=False,
                  backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='unet3plus')

    ----------
    Huang, H., Lin, L., Tong, R., Hu, H., Zhang, Q., Iwamoto, Y., Han, X., Chen, Y.W. and Wu, J., 2020.
    UNet 3+: A Full-Scale Connected UNet for Medical Image Segmentation.
    In ICASSP 2020-2020 IEEE International Conference on Acoustics,
    Speech and Signal Processing (ICASSP) (pp. 1055-1059). IEEE.

    Input
    ----------
        input_size: the size/shape of network input, e.g., `(128, 128, 3)`.
        n_labels: number of output labels.
        filter_num_down: a list that defines the number of filters for each
                         downsampling level. e.g., `[64, 128, 256, 512, 1024]`.
                         the network depth is expected as `len(filter_num_down)`
        filter_num_skip: a list that defines the number of filters after each
                         full-scale skip connection. Number of elements is expected to be `depth-1`.
                         i.e., the bottom level is not included.
                         * Huang et al. (2020) applied the same numbers for all levels.
                           e.g., `[64, 64, 64, 64]`.
                         * 'auto' repeats `filter_num_down[0]` for `depth-1` levels.
        filter_num_aggregate: an int that defines the number of channels of full-scale aggregations.
                              * 'auto' uses `depth * filter_num_down[0]`.
        stack_num_down: number of convolutional layers per downsampling level/block.
        stack_num_up: number of convolutional layers (after full-scale concat) per upsampling level/block.
        activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'
        output_activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interface or 'Sigmoid'.
                           Default option is 'Sigmoid'.
                           if None is received, then linear activation is applied.
        batch_norm: True for batch normalization.
        pool: True or 'max' for MaxPooling2D.
              'ave' for AveragePooling2D.
              False for strided conv + batch norm + activation.
        unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
                'nearest' for Upsampling2D with nearest interpolation.
                False for Conv2DTranspose + batch norm + activation.
        deep_supervision: True for a model that supports deep supervision. Details see Huang et al. (2020).
        name: prefix of the created keras model and its layers.

        ---------- (keywords of backbone options) ----------
        backbone: the backbone model name. Should be one of the `tensorflow.keras.applications` class.
                  None (default) means no backbone.
                  Currently supported backbones are:
                  (1) VGG16, VGG19
                  (2) ResNet50, ResNet101, ResNet152
                  (3) ResNet50V2, ResNet101V2, ResNet152V2
                  (4) DenseNet121, DenseNet169, DenseNet201
                  (5) EfficientNetB[0-7]
        weights: one of None (random initialization), 'imagenet' (pre-training on ImageNet),
                 or the path to the weights file to be loaded.
        freeze_backbone: True for a frozen backbone.
        freeze_batch_norm: False for not freezing batch normalization layers.

    * The Classification-guided Module (CGM) is not implemented.
      See https://github.com/yingkaisha/keras-unet-collection/tree/main/examples for a relevant example.
    * Automated mode is applied for determining `filter_num_skip`, `filter_num_aggregate`.
    * The default output activation is sigmoid, consistent with Huang et al. (2020).
    * Downsampling is achieved through maxpooling and can be replaced by strided convolutional layers here.
    * Upsampling is achieved through bilinear interpolation and can be replaced by transpose convolutional layers here.

    Output
    ----------
        model: a keras model. With `deep_supervision=True` the model has several outputs:
               the upsampled deep-supervision branches first, the final output last.

    '''
    # imported locally: `warnings` is not among this module's top-level imports
    import warnings

    depth_ = len(filter_num_down)

    verbose = False

    if filter_num_skip == 'auto':
        verbose = True
        filter_num_skip = [filter_num_down[0]] * (depth_-1)

    if filter_num_aggregate == 'auto':
        verbose = True
        filter_num_aggregate = int(depth_*filter_num_down[0])

    if verbose:
        print('Automated hyper-parameter determination is applied with the following details:\n----------')
        print('\tNumber of convolution filters after each full-scale skip connection: filter_num_skip = {}'.format(filter_num_skip))
        print('\tNumber of channels of full-scale aggregated feature maps: filter_num_aggregate = {}'.format(filter_num_aggregate))

    if backbone is not None:
        bach_norm_checker(backbone, batch_norm)

    IN = Input(input_size)

    X_decoder = unet_3plus_2d_base(IN, filter_num_down, filter_num_skip, filter_num_aggregate,
                                   stack_num_down=stack_num_down, stack_num_up=stack_num_up, activation=activation,
                                   batch_norm=batch_norm, pool=pool, unpool=unpool,
                                   backbone=backbone, weights=weights, freeze_backbone=freeze_backbone,
                                   freeze_batch_norm=freeze_batch_norm, name=name)
    # after reversal, index 0 is the last tensor returned by the base (used for the final output)
    X_decoder = X_decoder[::-1]

    if deep_supervision:

        # ----- frozen backbone issue checker ----- #
        # NOTE(review): the message talks about the *deepest* branch, but after the reversal
        # above index 0 is the final decoder tensor -- confirm whether `X_decoder[-1]` was intended.
        if ('{}_backbone_'.format(backbone) in X_decoder[0].name) and freeze_backbone:

            backbone_warn = '\n\nThe deepest UNET 3+ deep supervision branch directly connects to a frozen backbone.\nTesting your configurations on `keras_unet_collection.base.unet_3plus_2d_base` is recommended.'
            warnings.warn(backbone_warn)
        # ----------------------------------------- #

        OUT_stack = []
        L_out = len(X_decoder)

        print('----------\ndeep_supervision = True\nnames of output tensors are listed as follows ("sup0" is the shallowest supervision layer;\n"final" is the final output layer):\n')

        # conv2d --> upsampling --> output activation.
        # index 0 is final output
        for i in range(1, L_out):

            # the i-th tensor is upsampled by 2**i before being used as an output
            pool_size = 2**(i)
            sup_name = '{}_output_sup{}'.format(name, i-1)

            X = Conv2D(n_labels, 3, padding='same', name='{}_output_conv_{}'.format(name, i-1))(X_decoder[i])

            X = decode_layer(X, n_labels, pool_size, unpool,
                             activation=None, batch_norm=False, name=sup_name)

            if output_activation:
                print('\t{}_activation'.format(sup_name))

                if output_activation == 'Sigmoid':
                    X = Activation('sigmoid', name='{}_activation'.format(sup_name))(X)
                else:
                    # NOTE: eval() resolves the activation class by name from this module's
                    # namespace; never pass untrusted strings as `output_activation`.
                    activation_func = eval(output_activation)
                    X = activation_func(name='{}_activation'.format(sup_name))(X)
            else:
                if unpool is False:
                    print('\t{}_trans_conv'.format(sup_name))
                else:
                    print('\t{}_unpool'.format(sup_name))

            OUT_stack.append(X)

        # the final output is appended last
        X = CONV_output(X_decoder[0], n_labels, kernel_size=3,
                        activation=output_activation, name='{}_output_final'.format(name))
        OUT_stack.append(X)

        if output_activation:
            print('\t{}_output_final_activation'.format(name))
        else:
            print('\t{}_output_final'.format(name))

        model = Model([IN,], OUT_stack)

    else:
        OUT = CONV_output(X_decoder[0], n_labels, kernel_size=3,
                          activation=output_activation, name='{}_output_final'.format(name))

        model = Model([IN,], [OUT,])

    return model
def unet_plus_2d_base(input_tensor, filter_num, stack_num_down=2, stack_num_up=2,
                      activation='ReLU', batch_norm=False, pool=True, unpool=True, deep_supervision=False,
                      backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='xnet'):
    '''
    The base of U-net++ with an optional ImageNet-trained backbone.

    unet_plus_2d_base(input_tensor, filter_num, stack_num_down=2, stack_num_up=2,
                      activation='ReLU', batch_norm=False, pool=True, unpool=True, deep_supervision=False,
                      backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='xnet')

    ----------
    Zhou, Z., Siddiquee, M.M.R., Tajbakhsh, N. and Liang, J., 2018. Unet++: A nested u-net architecture
    for medical image segmentation. In Deep Learning in Medical Image Analysis and Multimodal Learning
    for Clinical Decision Support (pp. 3-11). Springer, Cham.

    Input
    ----------
        input_tensor: the input tensor of the base, e.g., `keras.layers.Input((None, None, 3))`.
        filter_num: a list that defines the number of filters for each
                    down- and upsampling levels. e.g., `[64, 128, 256, 512]`.
                    The depth is expected as `len(filter_num)`.
        stack_num_down: number of convolutional layers per downsampling level/block.
        stack_num_up: number of convolutional layers (after concatenation) per upsampling level/block.
        activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
                    The name is forwarded unchanged to the down-/upsampling blocks.
        batch_norm: True for batch normalization.
        pool: True or 'max' for MaxPooling2D.
              'ave' for AveragePooling2D.
              False for strided conv + batch norm + activation.
        unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
                'nearest' for Upsampling2D with nearest interpolation.
                False for Conv2DTranspose + batch norm + activation.
        deep_supervision: True for a model that supports deep supervision. Details see Zhou et al. (2018).
        name: prefix of the created keras model and its layers.

        ---------- (keywords of backbone options) ----------
        backbone: the backbone model name. Should be one of the `tensorflow.keras.applications` class.
                  None (default) means no backbone.
                  Currently supported backbones are:
                  (1) VGG16, VGG19
                  (2) ResNet50, ResNet101, ResNet152
                  (3) ResNet50V2, ResNet101V2, ResNet152V2
                  (4) DenseNet121, DenseNet169, DenseNet201
                  (5) EfficientNetB[0-7]
        weights: one of None (random initialization), 'imagenet' (pre-training on ImageNet),
                 or the path to the weights file to be loaded.
        freeze_backbone: True for a frozen backbone.
        freeze_batch_norm: False for not freezing batch normalization layers.

    Output
    ----------
        If deep_supervision = False; Then the output is a tensor (the first tensor of the last nest level).
        If deep_supervision = True; Then the output is a list of `len(filter_num)` tensors,
            one per nest level (the first tensor of each level):
            the first one comes from the first downsampling level (for checking the input/output shapes only),
            the intermediate ones are the deep supervision tensors,
            and the last one is obtained from the end of the base.

    '''

    depth_ = len(filter_num)

    # X_nest_skip[k] collects the tensors of nest level k, shallowest first;
    # level 0 is the encoder (plain conv blocks or backbone feature maps)
    X_nest_skip = [[] for _ in range(depth_)]

    # no backbone cases
    if backbone is None:

        X = input_tensor

        # downsampling blocks (same as in 'unet_2d')
        X = CONV_stack(X, filter_num[0], stack_num=stack_num_down, activation=activation,
                       batch_norm=batch_norm, name='{}_down0'.format(name))
        X_nest_skip[0].append(X)
        for i, f in enumerate(filter_num[1:]):
            X = UNET_left(X, f, stack_num=stack_num_down, activation=activation,
                          pool=pool, batch_norm=batch_norm, name='{}_down{}'.format(name, i+1))
            X_nest_skip[0].append(X)

    # backbone cases
    else:
        # handling VGG16 and VGG19 separately
        if 'VGG' in backbone:
            backbone_ = backbone_zoo(backbone, weights, input_tensor, depth_, freeze_backbone, freeze_batch_norm)
            # collecting backbone feature maps
            X_nest_skip[0] += backbone_([input_tensor,])
            depth_encode = len(X_nest_skip[0])

        # for other backbones: one level fewer is requested and the level count is offset by one
        else:
            backbone_ = backbone_zoo(backbone, weights, input_tensor, depth_-1, freeze_backbone, freeze_batch_norm)
            # collecting backbone feature maps
            X_nest_skip[0] += backbone_([input_tensor,])
            depth_encode = len(X_nest_skip[0]) + 1

        # extra conv2d blocks are applied
        # if downsampling levels of a backbone < user-specified downsampling levels
        if depth_encode < depth_:

            # begins at the deepest available tensor
            X = X_nest_skip[0][-1]

            # extra downsamplings
            for i in range(depth_-depth_encode):
                i_real = i + depth_encode

                X = UNET_left(X, filter_num[i_real], stack_num=stack_num_down, activation=activation, pool=pool,
                              batch_norm=batch_norm, name='{}_down{}'.format(name, i_real+1))
                X_nest_skip[0].append(X)

    for nest_lev in range(1, depth_):

        # depth difference between the deepest nest skip and the current upsampling
        depth_lev = depth_-nest_lev

        # number of tensors available in the previous nest level
        depth_decode = len(X_nest_skip[nest_lev-1])

        # loop over individual upsampling levels
        for i in range(1, depth_decode):

            # all earlier nest levels contribute their tensor at position i-1
            previous_skip = [X_nest_skip[previous_lev][i-1] for previous_lev in range(nest_lev)]

            # upsampling block that concatenates all available (same feature map size) down-/upsampling outputs
            X_nest_skip[nest_lev].append(
                UNET_right(X_nest_skip[nest_lev-1][i], previous_skip, filter_num[i-1],
                           stack_num=stack_num_up, activation=activation, unpool=unpool,
                           batch_norm=batch_norm, concat=False, name='{}_up{}_from{}'.format(name, nest_lev-1, i-1)))

        # the previous level holds fewer tensors than this level needs:
        # keep upsampling its deepest tensor without skip connections
        if depth_decode < depth_lev+1:

            X = X_nest_skip[nest_lev-1][-1]

            for j in range(depth_lev-depth_decode+1):
                j_real = j + depth_decode
                X = UNET_right(X, None, filter_num[j_real-1],
                               stack_num=stack_num_up, activation=activation, unpool=unpool,
                               batch_norm=batch_norm, concat=False, name='{}_up{}_from{}'.format(name, nest_lev-1, j_real-1))
                X_nest_skip[nest_lev].append(X)

    # output
    if deep_supervision:
        # the first tensor of every nest level
        return [X_nest_skip[i][0] for i in range(depth_)]

    return X_nest_skip[-1][0]
def unet_plus_2d(input_size, filter_num, n_labels, stack_num_down=2, stack_num_up=2,
                 activation='ReLU', output_activation='Softmax', batch_norm=False, pool=True, unpool=True, deep_supervision=False,
                 backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='xnet'):
    '''
    U-net++ with an optional ImageNet-trained backbone.

    unet_plus_2d(input_size, filter_num, n_labels, stack_num_down=2, stack_num_up=2,
                 activation='ReLU', output_activation='Softmax', batch_norm=False, pool=True, unpool=True, deep_supervision=False,
                 backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='xnet')

    ----------
    Zhou, Z., Siddiquee, M.M.R., Tajbakhsh, N. and Liang, J., 2018. Unet++: A nested u-net architecture
    for medical image segmentation. In Deep Learning in Medical Image Analysis and Multimodal Learning
    for Clinical Decision Support (pp. 3-11). Springer, Cham.

    Input
    ----------
        input_size: the size/shape of network input, e.g., `(128, 128, 3)`.
        filter_num: a list that defines the number of filters for each
                    down- and upsampling levels. e.g., `[64, 128, 256, 512]`.
                    The depth is expected as `len(filter_num)`.
        n_labels: number of output labels.
        stack_num_down: number of convolutional layers per downsampling level/block.
        stack_num_up: number of convolutional layers (after concatenation) per upsampling level/block.
        activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
        output_activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interface or 'Sigmoid'.
                           Default option is 'Softmax'.
                           if None is received, then linear activation is applied.
        batch_norm: True for batch normalization.
        pool: True or 'max' for MaxPooling2D.
              'ave' for AveragePooling2D.
              False for strided conv + batch norm + activation.
        unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
                'nearest' for Upsampling2D with nearest interpolation.
                False for Conv2DTranspose + batch norm + activation.
        deep_supervision: True for a model that supports deep supervision. Details see Zhou et al. (2018).
        name: prefix of the created keras model and its layers.

        ---------- (keywords of backbone options) ----------
        backbone: the backbone model name. Should be one of the `tensorflow.keras.applications` class.
                  None (default) means no backbone.
                  Currently supported backbones are:
                  (1) VGG16, VGG19
                  (2) ResNet50, ResNet101, ResNet152
                  (3) ResNet50V2, ResNet101V2, ResNet152V2
                  (4) DenseNet121, DenseNet169, DenseNet201
                  (5) EfficientNetB[0-7]
        weights: one of None (random initialization), 'imagenet' (pre-training on ImageNet),
                 or the path to the weights file to be loaded.
        freeze_backbone: True for a frozen backbone.
        freeze_batch_norm: False for not freezing batch normalization layers.

    Output
    ----------
        model: a keras model. With deep_supervision = True it has one output per
               supervision branch followed by the final output; their names are printed.

    '''

    n_levels = len(filter_num)

    if backbone is not None:
        bach_norm_checker(backbone, batch_norm)

    IN = Input(input_size)

    # base: a single tensor, or a list of tensors when deep supervision is on
    base_out = unet_plus_2d_base(IN, filter_num,
                                 stack_num_down=stack_num_down, stack_num_up=stack_num_up,
                                 activation=activation, batch_norm=batch_norm,
                                 pool=pool, unpool=unpool, deep_supervision=deep_supervision,
                                 backbone=backbone, weights=weights,
                                 freeze_backbone=freeze_backbone, freeze_batch_norm=freeze_batch_norm,
                                 name=name)

    # ----- single output head ----- #
    if not deep_supervision:
        head = CONV_output(base_out, n_labels, kernel_size=1, activation=output_activation,
                           name='{}_output'.format(name))
        return Model(inputs=[IN,], outputs=[head,], name='{}_model'.format(name))

    # ----- deep supervision heads ----- #
    if (backbone is not None) and freeze_backbone:
        warnings.warn('\n\nThe shallowest U-net++ deep supervision branch directly connects to a frozen backbone.\nTesting your configurations on `keras_unet_collection.base.unet_plus_2d_base` is recommended.')

    print('----------\ndeep_supervision = True\nnames of output tensors are listed as follows ("sup0" is the shallowest supervision layer;\n"final" is the final output layer):\n')

    # the printed tensor names carry an "_activation" suffix whenever an output activation is applied
    if output_activation is None:
        sup_label = '\t{}_output_sup{}'
        final_label = '\t{}_output_final'
    else:
        sup_label = '\t{}_output_sup{}_activation'
        final_label = '\t{}_output_final_activation'

    # no backbone or VGG backbones: every nest level gets a head directly.
    # other backbones: the first level is skipped and each branch is upsampled first.
    # (the original author notes that at least two downsampling blocks are expected)
    needs_upsampling = not ((backbone is None) or 'VGG' in backbone)
    first_lev = 1 if needs_upsampling else 0

    heads = []
    for lev in range(first_lev, n_levels-1):
        sup_id = lev - first_lev
        print(sup_label.format(name, sup_id))

        branch = base_out[lev]
        if needs_upsampling:
            # an extra upsampling for creating full resolution feature maps
            branch = decode_layer(branch, filter_num[lev], 2, unpool, activation=activation,
                                  batch_norm=batch_norm, name='{}_sup{}_up'.format(name, sup_id))

        heads.append(CONV_output(branch, n_labels, kernel_size=1, activation=output_activation,
                                 name='{}_output_sup{}'.format(name, sup_id)))

    print(final_label.format(name))
    heads.append(CONV_output(base_out[-1], n_labels, kernel_size=1, activation=output_activation,
                             name='{}_output_final'.format(name)))

    return Model(inputs=[IN,], outputs=heads, name='{}_model'.format(name))
def vnet_left(X, channel, res_num, activation='ReLU', pool=True, batch_norm=False, name='left'):
    '''
    The encoder block of 2-d V-net: downsampling followed by a residual conv stack.

    vnet_left(X, channel, res_num, activation='ReLU', pool=True, batch_norm=False, name='left')

    Input
    ----------
        X: input tensor.
        channel: number of convolution filters.
        res_num: number of convolutional layers within the residual path.
        activation: one of the `tensorflow.keras.layers` interface, e.g., 'ReLU'.
        pool: True or 'max' for MaxPooling2D.
              'ave' for AveragePooling2D.
              False for strided conv + batch norm + activation.
        batch_norm: True for batch normalization, False otherwise.
        name: name of the created keras layers.

    Output
    ----------
        X: output tensor.

    '''
    # keywords shared by every layer helper of this block
    shared = dict(activation=activation, batch_norm=batch_norm)

    # downsampling by a factor of 2
    encoded = encode_layer(X, channel, 2, pool, name='{}_encode'.format(name), **shared)

    # with any pooling option (`pool` is not False) one extra conv layer is applied
    if pool is not False:
        encoded = CONV_stack(encoded, channel, kernel_size=3, stack_num=1, dilation_rate=1,
                             name='{}_pre_conv'.format(name), **shared)

    # the same tensor is passed as both the input and the skip of the residual stack
    return Res_CONV_stack(encoded, encoded, channel, res_num=res_num,
                          name='{}_res_conv'.format(name), **shared)
def vnet_right(X, X_list, channel, res_num, activation='ReLU', unpool=True, batch_norm=False, name='right'):
    '''
    The decoder block of 2-d V-net: upsampling, concatenation and a residual conv stack.

    vnet_right(X, X_list, channel, res_num, activation='ReLU', unpool=True, batch_norm=False, name='right')

    Input
    ----------
        X: input tensor.
        X_list: a list of other tensors that are concatenated with the upsampled input tensor.
        channel: number of convolution filters.
        res_num: number of convolutional layers within the residual path.
        activation: one of the `tensorflow.keras.layers` interface, e.g., 'ReLU'.
        unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
                'nearest' for Upsampling2D with nearest interpolation.
                False for Conv2DTranspose + batch norm + activation.
        batch_norm: True for batch normalization, False otherwise.
        name: name of the created keras layers.

    Output
    ----------
        X: output tensor.

    '''
    # upsampling by a factor of 2
    upsampled = decode_layer(X, channel, 2, unpool, activation=activation,
                             batch_norm=batch_norm, name='{}_decode'.format(name))

    # the upsampled tensor comes first, then the tensors of `X_list`
    merged = concatenate([upsampled,] + X_list, axis=-1, name='{}_concat'.format(name))

    # the upsampled tensor (before concatenation) is passed as the skip of the residual stack
    return Res_CONV_stack(merged, upsampled, channel, res_num, activation=activation,
                          batch_norm=batch_norm, name='{}_res_conv'.format(name))
def vnet_2d_base(input_tensor, filter_num, res_num_ini=1, res_num_max=3,
                 activation='ReLU', batch_norm=False, pool=True, unpool=True, name='vnet'):
    '''
    The base of 2-d V-net.

    vnet_2d_base(input_tensor, filter_num, res_num_ini=1, res_num_max=3,
                 activation='ReLU', batch_norm=False, pool=True, unpool=True, name='vnet')

    Milletari, F., Navab, N. and Ahmadi, S.A., 2016, October. V-net: Fully convolutional neural
    networks for volumetric medical image segmentation. In 2016 fourth international conference
    on 3D vision (3DV) (pp. 565-571). IEEE.

    The Two-dimensional version is inspired by:
    https://github.com/FENGShuanglang/2D-Vnet-Keras

    Input
    ----------
        input_tensor: the input tensor of the base, e.g., `keras.layers.Input((None, None, 3))`.
        filter_num: a list that defines the number of filters for each
                    down- and upsampling levels. e.g., `[64, 128, 256, 512]`.
                    The depth is expected as `len(filter_num)`.
        res_num_ini: number of convolutional layers of the first residual block (before downsampling).
        res_num_max: the max number of convolutional layers within a residual block.
        activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
        batch_norm: True for batch normalization.
        pool: True or 'max' for MaxPooling2D.
              'ave' for AveragePooling2D.
              False for strided conv + batch norm + activation.
        unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
                'nearest' for Upsampling2D with nearest interpolation.
                False for Conv2DTranspose + batch norm + activation.
        name: prefix of the created keras layers.

    Output
    ----------
        X: output tensor.

    * This is a modified version of V-net for 2-d inputs.
    * The original work supports `pool=False` only.
      If pool is True, 'max', or 'ave', an additional conv2d layer will be applied.
    * All the 5-by-5 convolutional kernels are changed (and fixed) to 3-by-3.

    '''

    n_levels = len(filter_num)

    # residual depth grows by one per level, starting at `res_num_ini` and capped at `res_num_max`
    res_per_level = [min(res_num_ini + lev, res_num_max) for lev in range(n_levels)]

    # ----- encoder ----- #
    # initial conv layer followed by the first residual block (no downsampling)
    X = CONV_stack(input_tensor, filter_num[0], kernel_size=3, stack_num=1, dilation_rate=1,
                   activation=activation, batch_norm=batch_norm, name='{}_input_conv'.format(name))
    X = Res_CONV_stack(X, X, filter_num[0], res_num=res_per_level[0], activation=activation,
                       batch_norm=batch_norm, name='{}_down_0'.format(name))
    encoded = [X,]

    # downsampling levels
    for lev, n_filter in enumerate(filter_num[1:], start=1):
        X = vnet_left(X, n_filter, res_num=res_per_level[lev], activation=activation, pool=pool,
                      batch_norm=batch_norm, name='{}_down_{}'.format(name, lev))
        encoded.append(X)

    # ----- decoder ----- #
    # the deepest level is the starting point `X`; the remaining levels are
    # revisited from deep to shallow, each with its own skip tensor and settings
    decoder_plan = zip(filter_num[:-1][::-1], encoded[:-1][::-1], res_per_level[:-1][::-1])

    # upsampling levels
    for step, (n_filter, skip, n_res) in enumerate(decoder_plan):
        X = vnet_right(X, [skip,], n_filter, res_num=n_res,
                       activation=activation, unpool=unpool, batch_norm=batch_norm,
                       name='{}_up_{}'.format(name, step))

    return X
def vnet_2d(input_size, filter_num, n_labels,
            res_num_ini=1, res_num_max=3,
            activation='ReLU', output_activation='Softmax',
            batch_norm=False, pool=True, unpool=True, name='vnet'):
    '''
    2-d V-net: `vnet_2d_base` followed by a 1-by-1 convolutional output layer.

    vnet_2d(input_size, filter_num, n_labels,
            res_num_ini=1, res_num_max=3,
            activation='ReLU', output_activation='Softmax',
            batch_norm=False, pool=True, unpool=True, name='vnet')

    Milletari, F., Navab, N. and Ahmadi, S.A., 2016, October. V-net: Fully convolutional neural
    networks for volumetric medical image segmentation. In 2016 fourth international conference
    on 3D vision (3DV) (pp. 565-571). IEEE.

    The Two-dimensional version is inspired by:
    https://github.com/FENGShuanglang/2D-Vnet-Keras

    Input
    ----------
        input_size: the size/shape of network input, e.g., `(128, 128, 3)`.
        filter_num: a list that defines the number of filters for each
                    down- and upsampling levels. e.g., `[64, 128, 256, 512]`.
                    The depth is expected as `len(filter_num)`.
        n_labels: number of output labels.
        res_num_ini: number of convolutional layers of the first residual block (before downsampling).
        res_num_max: the max number of convolutional layers within a residual block.
        activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
        output_activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interface or 'Sigmoid'.
                           Default option is 'Softmax'.
                           if None is received, then linear activation is applied.
        batch_norm: True for batch normalization.
        pool: True or 'max' for MaxPooling2D.
              'ave' for AveragePooling2D.
              False for strided conv + batch norm + activation.
        unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
                'nearest' for Upsampling2D with nearest interpolation.
                False for Conv2DTranspose + batch norm + activation.
        name: prefix of the created keras layers.

    Output
    ----------
        model: a keras model.

    * This is a modified version of V-net for 2-d inputs.
    * The original work supports `pool=False` only.
      If pool is True, 'max', or 'ave', an additional conv2d layer will be applied.
    * All the 5-by-5 convolutional kernels are changed (and fixed) to 3-by-3.
    '''

    IN = Input(input_size)

    # base
    features = vnet_2d_base(IN, filter_num,
                            res_num_ini=res_num_ini, res_num_max=res_num_max,
                            activation=activation, batch_norm=batch_norm,
                            pool=pool, unpool=unpool, name=name)

    # output layer
    OUT = CONV_output(features, n_labels, kernel_size=1, activation=output_activation,
                      name='{}_output'.format(name))

    # functional API model
    return Model(inputs=[IN,], outputs=[OUT,], name='{}_model'.format(name))
def gelu_(X):
    '''
    The tanh-based expression of GELU used by the `GELU` layer:
    0.5*X*(1 + tanh(0.7978845608028654*(X + 0.044715*X^3))).
    '''
    cubic_term = X + 0.044715*math.pow(X, 3)
    gate = math.tanh(0.7978845608028654*cubic_term)

    return 0.5*X*(1.0 + gate)

def snake_(X, beta):
    '''
    The Snake expression X + (1/beta)*sin^2(beta*X) used by the `Snake` layer.
    '''
    periodic_term = math.square(math.sin(beta*X))

    return X + (1/beta)*periodic_term
class GELU(Layer):
    '''
    Gaussian Error Linear Unit (GELU), an alternative of ReLU

    Y = GELU()(X)

    ----------
    Hendrycks, D. and Gimpel, K., 2016. Gaussian error linear units (gelus). arXiv preprint arXiv:1606.08415.

    Usage: use it as a tf.keras.Layer

    The layer applies `gelu_` to its input and keeps the input shape.

    '''
    def __init__(self, trainable=False, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True
        # set after the parent constructor so the explicit argument is the value that is kept
        self.trainable = trainable

    def build(self, input_shape):
        # nothing is created here; the call is handed to the parent class
        super().build(input_shape)

    def call(self, inputs, mask=None):
        # `mask` is accepted but not used
        return gelu_(inputs)

    def get_config(self):
        # parent configuration first, then the entry of this layer on top of it
        merged = dict(super().get_config())
        merged.update({'trainable': self.trainable})
        return merged

    def compute_output_shape(self, input_shape):
        return input_shape
class Snake(Layer):
    '''
    Snake activation function $X + (1/b)*sin^2(b*X)$. Proposed to learn periodic targets.

    Y = Snake(beta=0.5, trainable=False)(X)

    ----------
    Ziyin, L., Hartwig, T. and Ueda, M., 2020. Neural networks fail to learn periodic functions
    and how to fix it. arXiv preprint arXiv:2006.08195.

    Input
    ----------
        beta: the value used to create the `beta_factor` variable in `build`.
        trainable: True for adding `beta_factor` to the trainable weights of the layer.

    '''
    def __init__(self, beta=0.5, trainable=False, **kwargs):
        super(Snake, self).__init__(**kwargs)
        self.supports_masking = True
        self.beta = beta
        self.trainable = trainable

    def build(self, input_shape):
        # the variable only exists once the layer has been built
        self.beta_factor = K.variable(self.beta, dtype=K.floatx(), name='beta_factor')
        if self.trainable:
            self._trainable_weights.append(self.beta_factor)

        super(Snake, self).build(input_shape)

    def call(self, inputs, mask=None):
        # `mask` is accepted but not used
        return snake_(inputs, self.beta_factor)

    def get_config(self):
        '''
        Return the layer configuration.

        'beta' is the current value of `beta_factor` for a trainable layer that has
        been built, and the constructor value otherwise. Reading the variable only
        when it exists keeps `get_config` usable before `build` (indexing the weight
        list of an unbuilt layer raised an IndexError).
        '''
        beta = self.beta
        beta_factor = getattr(self, 'beta_factor', None)
        if self.trainable and beta_factor is not None:
            # a plain Python float keeps the configuration free of array objects
            beta = float(K.get_value(beta_factor))

        config = dict(super(Snake, self).get_config())
        config.update({'beta': beta, 'trainable': self.trainable})
        return config

    def compute_output_shape(self, input_shape):
        return input_shape
def decode_layer(X, channel, pool_size, unpool, kernel_size=3, 
                 activation='ReLU', batch_norm=False, name='decode'):
    '''
    An overall decode layer, based on either upsampling or trans conv.
    
    decode_layer(X, channel, pool_size, unpool, kernel_size=3,
                 activation='ReLU', batch_norm=False, name='decode')
    
    Input
    ----------
        X: input tensor.
        pool_size: the decoding factor.
        channel: (for trans conv only) number of convolution filters.
        unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
                'nearest' for Upsampling2D with nearest interpolation.
                False for Conv2DTranspose + batch norm + activation.           
        kernel_size: size of convolution kernels. 
                     If kernel_size='auto', then it equals to the `pool_size`.
        activation: one of the `tensorflow.keras.layers` interface, e.g., ReLU.
                    None skips the activation (trans conv only).
        batch_norm: True for batch normalization, False otherwise (trans conv only).
        name: prefix of the created keras layers.
        
    Output
    ----------
        X: output tensor.
        
    Raises
    ----------
        ValueError: if `unpool` is not one of False, True, 'bilinear', 'nearest'.
    
    * The default: `kernel_size=3`, is suitable for `pool_size=2`.
    * The transpose convolution drops its bias when `batch_norm=True`
      (the bias is redundant before batch normalization), matching `encode_layer`.
      Weights saved from a version that kept this bias will not match the layer.
    
    '''
    # parsers: `interp` stays None for the transpose-convolution path
    if unpool is False:
        interp = None
        
    elif unpool == 'nearest':
        interp = 'nearest'
        
    elif (unpool is True) or (unpool == 'bilinear'):
        interp = 'bilinear'
        
    else:
        raise ValueError('Invalid unpool keyword')
    
    # upsampling path: no convolution, batch norm or activation is applied
    if interp is not None:
        return UpSampling2D(size=(pool_size, pool_size), interpolation=interp, 
                            name='{}_unpool'.format(name))(X)
    
    # transpose-convolution path
    if kernel_size == 'auto':
        kernel_size = pool_size
    
    # bias is only kept when no batch normalization follows
    bias_flag = not batch_norm
    
    X = Conv2DTranspose(channel, kernel_size, strides=(pool_size, pool_size), 
                        padding='same', use_bias=bias_flag, 
                        name='{}_trans_conv'.format(name))(X)
    
    # batch normalization
    if batch_norm:
        X = BatchNormalization(axis=3, name='{}_bn'.format(name))(X)
    
    # activation
    if activation is not None:
        # NOTE(review): `eval` resolves the layer class from this module's namespace;
        # `activation` must never come from untrusted input.
        activation_func = eval(activation)
        X = activation_func(name='{}_activation'.format(name))(X)
        
    return X
def attention_gate(X, g, channel,  
                   activation='ReLU', 
                   attention='add', name='att'):
    '''
    Self-attention gate modified from Oktay et al. 2018.
    
    attention_gate(X, g, channel,  activation='ReLU', attention='add', name='att')
    
    Input
    ----------
        X: input tensor, i.e., key and value.
        g: gated tensor, i.e., query.
        channel: number of intermediate channels (F_int in Oktay et al. 2018),
                 expected to be smaller than the input channel.
        activation: name of the nonlinear attention activation layer
                    (the `sigma_1` in Oktay et al. 2018). Default is 'ReLU'.
        attention: 'add' for additive attention; 'multiply' for multiplicative attention.
        name: prefix of the created keras layers.
        
    Output
    ----------
        X_att: the input tensor masked by the learned attention coefficients.
    
    '''
    # resolve both layer factories up front, before any layer is created
    nonlinearity = eval(activation)
    merge_op = eval(attention)
    
    # 1-by-1 projections of the input and of the gate onto the intermediate channels
    projected_x = Conv2D(channel, 1, use_bias=True, name='{}_theta_x'.format(name))(X)
    projected_g = Conv2D(channel, 1, use_bias=True, name='{}_phi_g'.format(name))(g)
    
    # merge the two projections, then apply the nonlinearity
    merged = merge_op([projected_x, projected_g], name='{}_add'.format(name))
    activated = nonlinearity(name='{}_activation'.format(name))(merged)
    
    # collapse to one channel and squash to attention coefficients
    logits = Conv2D(1, 1, use_bias=True, name='{}_psi_f'.format(name))(activated)
    coefficients = Activation('sigmoid', name='{}_sigmoid'.format(name))(logits)
    
    # mask the input with the coefficients
    return multiply([X, coefficients], name='{}_masking'.format(name))
def Res_CONV_stack(X, X_skip, channel, res_num, activation='ReLU', batch_norm=False, name='res_conv'):
    '''
    Stacked convolutional layers with residual path.
     
    Res_CONV_stack(X, X_skip, channel, res_num, activation='ReLU', batch_norm=False, name='res_conv')
     
    Input
    ----------
        X: input tensor.
        X_skip: the tensor added back after the convolution stack;
                can be a copy of X (e.g., the identity block of ResNet).
        channel: number of convolution filters.
        res_num: number of convolutional layers within the residual path.
        activation: one of the `tensorflow.keras.layers` interface, e.g., 'ReLU'.
        batch_norm: True for batch normalization, False otherwise.
        name: prefix of the created keras layers.
        
    Output
    ----------
        X: output tensor.
        
    '''  
    # residual branch: a plain 3-by-3 convolution stack
    X = CONV_stack(X, channel, kernel_size=3, stack_num=res_num, dilation_rate=1, 
                   activation=activation, batch_norm=batch_norm, name=name)

    # merge with the skip tensor, then activate the sum
    summed = add([X_skip, X], name='{}_add'.format(name))
    
    return eval(activation)(name='{}_add_activation'.format(name))(summed)
298 | 299 | ''' 300 | 301 | activation_func = eval(activation) 302 | bias_flag = not batch_norm 303 | 304 | for i in range(stack_num): 305 | X = DepthwiseConv2D(kernel_size, dilation_rate=dilation_rate, padding='same', 306 | use_bias=bias_flag, name='{}_{}_depthwise'.format(name, i))(X) 307 | 308 | if batch_norm: 309 | X = BatchNormalization(name='{}_{}_depthwise_BN'.format(name, i))(X) 310 | 311 | X = activation_func(name='{}_{}_depthwise_activation'.format(name, i))(X) 312 | 313 | X = Conv2D(channel, (1, 1), padding='same', use_bias=bias_flag, name='{}_{}_pointwise'.format(name, i))(X) 314 | 315 | if batch_norm: 316 | X = BatchNormalization(name='{}_{}_pointwise_BN'.format(name, i))(X) 317 | 318 | X = activation_func(name='{}_{}_pointwise_activation'.format(name, i))(X) 319 | 320 | return X 321 | 322 | def ASPP_conv(X, channel, activation='ReLU', batch_norm=True, name='aspp'): 323 | ''' 324 | Atrous Spatial Pyramid Pooling (ASPP). 325 | 326 | ASPP_conv(X, channel, activation='ReLU', batch_norm=True, name='aspp') 327 | 328 | ---------- 329 | Wang, Y., Liang, B., Ding, M. and Li, J., 2019. Dense semantic labeling 330 | with atrous spatial pyramid pooling and decoder for high-resolution remote 331 | sensing imagery. Remote Sensing, 11(1), p.20. 332 | 333 | Input 334 | ---------- 335 | X: input tensor. 336 | channel: number of convolution filters. 337 | activation: one of the `tensorflow.keras.layers` interface, e.g., ReLU. 338 | batch_norm: True for batch normalization, False otherwise. 339 | name: prefix of the created keras layers. 340 | 341 | Output 342 | ---------- 343 | X: output tensor. 344 | 345 | * dilation rates are fixed to `[6, 9, 12]`. 
346 | ''' 347 | 348 | activation_func = eval(activation) 349 | bias_flag = not batch_norm 350 | 351 | shape_before = X.get_shape().as_list() 352 | b4 = GlobalAveragePooling2D(name='{}_avepool_b4'.format(name))(X) 353 | 354 | b4 = expand_dims(expand_dims(b4, 1), 1, name='{}_expdim_b4'.format(name)) 355 | 356 | b4 = Conv2D(channel, 1, padding='same', use_bias=bias_flag, name='{}_conv_b4'.format(name))(b4) 357 | 358 | if batch_norm: 359 | b4 = BatchNormalization(name='{}_conv_b4_BN'.format(name))(b4) 360 | 361 | b4 = activation_func(name='{}_conv_b4_activation'.format(name))(b4) 362 | 363 | # <----- tensorflow v1 resize. 364 | b4 = Lambda(lambda X: image.resize(X, shape_before[1:3], method='bilinear', align_corners=True), 365 | name='{}_resize_b4'.format(name))(b4) 366 | 367 | b0 = Conv2D(channel, (1, 1), padding='same', use_bias=bias_flag, name='{}_conv_b0'.format(name))(X) 368 | 369 | if batch_norm: 370 | b0 = BatchNormalization(name='{}_conv_b0_BN'.format(name))(b0) 371 | 372 | b0 = activation_func(name='{}_conv_b0_activation'.format(name))(b0) 373 | 374 | # dilation rates are fixed to `[6, 9, 12]`. 375 | b_r6 = Sep_CONV_stack(X, channel, kernel_size=3, stack_num=1, activation='ReLU', 376 | dilation_rate=6, batch_norm=True, name='{}_sepconv_r6'.format(name)) 377 | b_r9 = Sep_CONV_stack(X, channel, kernel_size=3, stack_num=1, activation='ReLU', 378 | dilation_rate=9, batch_norm=True, name='{}_sepconv_r9'.format(name)) 379 | b_r12 = Sep_CONV_stack(X, channel, kernel_size=3, stack_num=1, activation='ReLU', 380 | dilation_rate=12, batch_norm=True, name='{}_sepconv_r12'.format(name)) 381 | 382 | return concatenate([b4, b0, b_r6, b_r9, b_r12]) 383 | 384 | def CONV_output(X, n_labels, kernel_size=1, activation='Softmax', name='conv_output'): 385 | ''' 386 | Convolutional layer with output activation. 387 | 388 | CONV_output(X, n_labels, kernel_size=1, activation='Softmax', name='conv_output') 389 | 390 | Input 391 | ---------- 392 | X: input tensor. 
393 | n_labels: number of classification label(s). 394 | kernel_size: size of 2-d convolution kernels. Default is 1-by-1. 395 | activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interface or 'Sigmoid'. 396 | Default option is 'Softmax'. 397 | if None is received, then linear activation is applied. 398 | name: prefix of the created keras layers. 399 | 400 | Output 401 | ---------- 402 | X: output tensor. 403 | 404 | ''' 405 | 406 | X = Conv2D(n_labels, kernel_size, padding='same', use_bias=True, name=name)(X) 407 | 408 | if activation: 409 | 410 | if activation == 'Sigmoid': 411 | X = Activation('sigmoid', name='{}_activation'.format(name))(X) 412 | 413 | else: 414 | activation_func = eval(activation) 415 | X = activation_func(name='{}_activation'.format(name))(X) 416 | 417 | return X 418 | 419 | -------------------------------------------------------------------------------- /keras_unet_collection/losses.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | import tensorflow.keras.backend as K 7 | 8 | def _crps_tf(y_true, y_pred, factor=0.05): 9 | 10 | ''' 11 | core of (pseudo) CRPS loss. 
def crps2d_tf(y_true, y_pred, factor=0.05):
    
    '''
    (Experimental)
    An approximated continuous ranked probability score (CRPS) loss function:
    
        CRPS = mean_abs_err - factor * std
        
    * Note that the "real CRPS" = mean_abs_err - mean_pairwise_abs_diff
    
     Replacing mean pairwise absolute difference by standard deviation offers
     a complexity reduction from O(N^2) to O(N*logN) 
    
    ** factor > 0.1 may yield negative loss values.
    
    Compatible with high-level Keras training methods
    
    Input
    ----------
        y_true: training target with shape=(batch_num, x, y, 1)
        y_pred: a forward pass with shape=(batch_num, x, y, 1)
        factor: relative importance of standard deviation term.
        
    Output
    ----------
        A scalar tensor: the per-sample (mean_abs_err - factor * std), averaged over the batch.
        
    '''
    
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)
    
    # Reduce over every non-batch axis instead of looping over `shape[0]`:
    # the static batch size is None for dynamic batches, and squeezing would
    # drop the batch axis when it has length 1.
    rank = y_pred.shape.rank
    if rank is not None:
        sample_axes = list(range(1, rank))
    else:
        sample_axes = tf.range(1, tf.rank(y_pred))
    
    # per-sample mean absolute error and spread of the prediction
    mae = tf.reduce_mean(tf.abs(y_pred - y_true), axis=sample_axes)
    dist = tf.math.reduce_std(y_pred, axis=sample_axes)
    
    # batch average of the per-sample scores
    return tf.reduce_mean(mae - factor*dist)
def dice_coef(y_true, y_pred, const=K.epsilon()):
    '''
    Sørensen–Dice coefficient for 2-d samples.
    
    Input
    ----------
        y_true, y_pred: targets and predicted outputs (same shape).
        const: a constant that smooths the loss gradient and reduces numerical instabilities.
               It is added to both numerator and denominator, so two empty masks
               give a coefficient of 1 instead of a division by zero.
               
    Output
    ----------
        A scalar tensor: (2TP + const) / (2TP + FP + FN + const).
        
    '''
    
    # flatten 2-d tensors
    y_true_pos = tf.reshape(y_true, [-1])
    y_pred_pos = tf.reshape(y_pred, [-1])
    
    # get true pos (TP), false neg (FN), false pos (FP).
    true_pos  = tf.reduce_sum(y_true_pos * y_pred_pos)
    false_neg = tf.reduce_sum(y_true_pos * (1-y_pred_pos))
    false_pos = tf.reduce_sum((1-y_true_pos) * y_pred_pos)
    
    # 2TP/(2TP+FP+FN), smoothed on both sides (same convention as `tversky_coef`)
    coef_val = (2.0 * true_pos + const)/(2.0 * true_pos + false_pos + false_neg + const)
    
    return coef_val
def tversky_coef(y_true, y_pred, alpha=0.5, const=K.epsilon()):
    '''
    Weighted Sørensen–Dice coefficient.
    
    Input
    ----------
        y_true, y_pred: targets and predicted outputs (same shape).
        alpha: weight of the false negatives; false positives are weighted by (1 - alpha).
        const: a constant that smooths the loss gradient and reduces numerical instabilities.
        
    Output
    ----------
        A scalar tensor: (TP + const) / (TP + alpha*FN + (1-alpha)*FP + const).
        
    '''
    
    # work on flattened copies of both tensors
    truth = tf.reshape(y_true, [-1])
    guess = tf.reshape(y_pred, [-1])
    
    # soft confusion counts
    tp = tf.reduce_sum(truth * guess)
    fn = tf.reduce_sum(truth * (1-guess))
    fp = tf.reduce_sum((1-truth) * guess)
    
    # TP/(TP + a*FN + b*FP); a+b = 1
    numerator = tp + const
    denominator = tp + alpha*fn + (1-alpha)*fp + const
    
    return numerator/denominator
def focal_tversky(y_true, y_pred, alpha=0.5, gamma=4/3, const=K.epsilon()):
    
    '''
    Focal Tversky Loss (FTL)
    
    focal_tversky(y_true, y_pred, alpha=0.5, gamma=4/3)
    
    ----------
    Abraham, N. and Khan, N.M., 2019, April. A novel focal tversky loss function with improved 
    attention u-net for lesion segmentation. In 2019 IEEE 16th International Symposium on Biomedical Imaging 
    (ISBI 2019) (pp. 683-687). IEEE.
    
    ----------
    Input
        alpha: tunable parameter within [0, 1]. Alpha handles imbalance classification cases 
        gamma: tunable parameter within [1, 3].
        const: a constant that smooths the loss gradient and reduces numerical instabilities.
        
    Output
        A scalar tensor: (1 - Tversky coefficient) raised to the power 1/gamma.
        
    '''
    # bring both inputs to the dtype of the prediction
    pred = tf.convert_to_tensor(y_pred)
    truth = tf.cast(y_true, pred.dtype)
    
    # drop length-1 dimensions
    pred = tf.squeeze(pred)
    truth = tf.squeeze(truth)
    
    # (Tversky loss)**(1/gamma)
    tversky_loss = 1-tversky_coef(truth, pred, alpha=alpha, const=const)
    
    return tf.math.pow(tversky_loss, 1/gamma)
def iou_box_coef(y_true, y_pred, mode='giou', dtype=tf.float32):
    
    """
    Intersection over Union (IoU) and generalized IoU coefficients for bounding boxes.
    
    iou_box_coef(y_true, y_pred, mode='giou', dtype=tf.float32)
    
    ----------
    Rezatofighi, H., Tsoi, N., Gwak, J., Sadeghian, A., Reid, I. and Savarese, S., 2019. 
    Generalized intersection over union: A metric and a loss for bounding box regression. 
    In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 658-666).
    
    ----------
    Input
        y_true: the target bounding box. 
        y_pred: the predicted bounding box.
        
        Elements of a bounding box should be organized as: [y_min, x_min, y_max, x_max].

        mode: 'iou' for IoU coeff (i.e., Jaccard index);
              any other value (default 'giou') returns the generalized IoU coeff.
        
        dtype: the data type of input tensors.
               Default is tf.float32.

    """
    
    zero = tf.convert_to_tensor(0.0, dtype)
    
    def _box_area(y_lo, x_lo, y_hi, x_hi):
        # side lengths are clipped at zero so inverted boxes have zero area
        width = tf.maximum(zero, x_hi - x_lo)
        height = tf.maximum(zero, y_hi - y_lo)
        return width * height
    
    # split the last axis into the four box coordinates
    t_ymin, t_xmin, t_ymax, t_xmax = tf.unstack(y_true, 4, axis=-1)
    p_ymin, p_xmin, p_ymax, p_xmax = tf.unstack(y_pred, 4, axis=-1)
    
    area_true = _box_area(t_ymin, t_xmin, t_ymax, t_xmax)
    area_pred = _box_area(p_ymin, p_xmin, p_ymax, p_xmax)

    # overlap box: the innermost edges of the two boxes
    area_intersect = _box_area(tf.maximum(t_ymin, p_ymin), tf.maximum(t_xmin, p_xmin),
                               tf.minimum(t_ymax, p_ymax), tf.minimum(t_xmax, p_xmax))
    
    # IoU
    area_union = area_true + area_pred - area_intersect
    iou = tf.math.divide_no_nan(area_intersect, area_union)
    
    if mode == "iou":
        return iou
    
    # enclosing box: the outermost edges of the two boxes
    area_enclose = _box_area(tf.minimum(t_ymin, p_ymin), tf.minimum(t_xmin, p_xmin),
                             tf.maximum(t_ymax, p_ymax), tf.maximum(t_xmax, p_xmax))
    
    # generalized IoU: penalize the empty part of the enclosing box
    return iou - tf.math.divide_no_nan((area_enclose - area_union), area_enclose)
def iou_seg(y_true, y_pred, dtype=tf.float32):
    """
    Intersection over Union (IoU) loss for segmentation maps. 
    
    iou_seg(y_true, y_pred, dtype=tf.float32)
    
    ----------
    Rahman, M.A. and Wang, Y., 2016, December. Optimizing intersection-over-union in deep neural networks for 
    image segmentation. In International symposium on visual computing (pp. 234-244). Springer, Cham.
    
    ----------
    Input
        y_true: segmentation targets, c.f. `keras.losses.categorical_crossentropy`
        y_pred: segmentation predictions.
        
        dtype: the data type of input tensors.
               Default is tf.float32.
               
    Output
        A scalar tensor: 1 - intersection/union, with 0/0 treated as 0.
        
    """

    # cast both inputs to the requested dtype
    pred = tf.cast(tf.convert_to_tensor(y_pred), dtype)
    truth = tf.cast(y_true, pred.dtype)

    # drop length-1 dimensions, then flatten
    pred_flat = tf.reshape(tf.squeeze(pred), [-1])
    truth_flat = tf.reshape(tf.squeeze(truth), [-1])

    # soft intersection and union over all pixels
    overlap = tf.reduce_sum(tf.multiply(truth_flat, pred_flat))
    union = tf.reduce_sum(truth_flat) + tf.reduce_sum(pred_flat) - overlap
    
    return 1-tf.math.divide_no_nan(overlap, union)
449 | Embd_anchor = y_pred[:, 0:N] 450 | Embd_pos = y_pred[:, N:2*N] 451 | Embd_neg = y_pred[:, 2*N:] 452 | 453 | # squared distance measures 454 | d_pos = tf.reduce_sum(tf.square(Embd_anchor - Embd_pos), 1) 455 | d_neg = tf.reduce_sum(tf.square(Embd_anchor - Embd_neg), 1) 456 | loss_val = tf.maximum(0., margin + d_pos - d_neg) 457 | loss_val = tf.reduce_mean(loss_val) 458 | 459 | return loss_val -------------------------------------------------------------------------------- /keras_unet_collection/models.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | from keras_unet_collection._model_unet_2d import unet_2d 5 | from keras_unet_collection._model_vnet_2d import vnet_2d 6 | from keras_unet_collection._model_unet_plus_2d import unet_plus_2d 7 | from keras_unet_collection._model_r2_unet_2d import r2_unet_2d 8 | from keras_unet_collection._model_att_unet_2d import att_unet_2d 9 | from keras_unet_collection._model_resunet_a_2d import resunet_a_2d 10 | from keras_unet_collection._model_u2net_2d import u2net_2d 11 | from keras_unet_collection._model_unet_3plus_2d import unet_3plus_2d 12 | from keras_unet_collection._model_transunet_2d import transunet_2d 13 | from keras_unet_collection._model_swin_unet_2d import swin_unet_2d 14 | -------------------------------------------------------------------------------- /keras_unet_collection/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | #import tensorflow as tf 5 | from tensorflow import keras 6 | 7 | from PIL import Image 8 | 9 | def dummy_loader(model_path): 10 | ''' 11 | Load a stored keras model and return its weights. 12 | 13 | Input 14 | ---------- 15 | The file path of the stored keras model. 16 | 17 | Output 18 | ---------- 19 | Weights of the model. 
def dummy_loader(model_path):
    '''
    Load a stored keras model and return its weights.

    Input
    ----------
        model_path: the file path of the stored keras model.

    Output
    ----------
        The value of `get_weights()` of the loaded model.

    '''
    # compile=False: only the weights are needed, so the model is not compiled on load
    stored_model = keras.models.load_model(model_path, compile=False)
    return stored_model.get_weights()


def image_to_array(filenames, size, channel):
    '''
    Converting image files to a single numpy array.

    Input
    ----------
        filenames: an iterable of the path of image files. Any iterable is
                   accepted (list, tuple, generator, ...); it is consumed once.
        size: the output size (height == width) of image.
              Processed through PIL.Image.NEAREST
        channel: number of image channels, e.g. channel=3 for RGB.
                 With channel=1 the resized image is written to the single
                 output channel as-is; otherwise the first `channel` entries
                 of the last image axis are kept.

    Output
    ----------
        An array with shape = (filenum, size, size, channel) and numpy's
        default dtype. Axis 1 of the result is reversed before returning,
        i.e. every image comes back flipped along its first spatial axis.

    '''

    # materialise the input so that unsized iterables (e.g. generators) work too
    paths = list(filenames)

    # allocation
    out = np.empty((len(paths), size, size, channel))

    # loop over filenames
    for i, path in enumerate(paths):
        with Image.open(path) as pixio:
            pix = np.array(pixio.resize((size, size), Image.NEAREST))
        if channel == 1:
            out[i, ..., 0] = pix
        else:
            out[i, ...] = pix[..., :channel]

    return out[:, ::-1, ...]


def shuffle_ind(L):
    '''
    Generating random shuffled indices.

    Input
    ----------
        L: an int that defines the number of indices; the indices are 0, ..., L-1.

    Output
    ----------
        a numpy array of shuffled indices with shape = (L,)
    '''

    ind = np.arange(L)
    # in-place shuffle using numpy's global random state
    np.random.shuffle(ind)
    return ind
def freeze_model(model, freeze_batch_norm=False):
    '''
    Freeze a keras model in place by switching off `trainable` on its layers.

    Input
    ----------
        model: a keras model; only the layers listed in `model.layers` are visited.
        freeze_batch_norm: True for freezing every layer, batch normalization
                           included. False (default) for freezing every layer
                           except batch normalization layers, which are
                           explicitly set to trainable.

    Output
    ----------
        The same model object that was passed in.
    '''
    if freeze_batch_norm:
        # everything is frozen, no layer type check needed
        for layer in model.layers:
            layer.trainable = False
        return model

    # imported only on this path, where the layer type has to be checked
    from tensorflow.keras.layers import BatchNormalization

    for layer in model.layers:
        # batch normalization stays trainable, every other layer is frozen
        layer.trainable = isinstance(layer, BatchNormalization)
    return model