├── .gitignore
├── CHANGELOG.md
├── LICENSE
├── README.md
├── __init__.py
├── examples
├── human-seg_atten-unet-backbone_coco.ipynb
├── segmentation_unet-three-plus_oxford-iiit.ipynb
└── user_guide_models.ipynb
├── keras_unet_collection
├── __init__.py
├── _backbone_zoo.py
├── _model_att_unet_2d.py
├── _model_r2_unet_2d.py
├── _model_resunet_a_2d.py
├── _model_swin_unet_2d.py
├── _model_transunet_2d.py
├── _model_u2net_2d.py
├── _model_unet_2d.py
├── _model_unet_3plus_2d.py
├── _model_unet_plus_2d.py
├── _model_vnet_2d.py
├── activations.py
├── backbones.py
├── base.py
├── layer_utils.py
├── losses.py
├── models.py
├── transformer_layers.py
└── utils.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | *.png
2 |
3 | # author's testing scripts
4 | *TEST*
5 | pyproject.toml
6 | setup.cfg
7 | setup.py
8 |
9 | # Byte-compiled / optimized / DLL files
10 | __pycache__/
11 | *.py[cod]
12 | *$py.class
13 |
14 | # C extensions
15 | *.so
16 |
17 | # Distribution / packaging
18 | .Python
19 | build/
20 | develop-eggs/
21 | dist/
22 | downloads/
23 | eggs/
24 | .eggs/
25 | lib/
26 | lib64/
27 | parts/
28 | sdist/
29 | var/
30 | wheels/
31 | pip-wheel-metadata/
32 | share/python-wheels/
33 | *.egg-info/
34 | .installed.cfg
35 | *.egg
36 | MANIFEST
37 |
38 | # PyInstaller
39 | # Usually these files are written by a python script from a template
40 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
41 | *.manifest
42 | *.spec
43 |
44 | # Installer logs
45 | pip-log.txt
46 | pip-delete-this-directory.txt
47 |
48 | # Unit test / coverage reports
49 | htmlcov/
50 | .tox/
51 | .nox/
52 | .coverage
53 | .coverage.*
54 | .cache
55 | nosetests.xml
56 | coverage.xml
57 | *.cover
58 | *.py,cover
59 | .hypothesis/
60 | .pytest_cache/
61 |
62 | # Translations
63 | *.mo
64 | *.pot
65 |
66 | # Django stuff:
67 | *.log
68 | local_settings.py
69 | db.sqlite3
70 | db.sqlite3-journal
71 |
72 | # Flask stuff:
73 | instance/
74 | .webassets-cache
75 |
76 | # Scrapy stuff:
77 | .scrapy
78 |
79 | # Sphinx documentation
80 | docs/_build/
81 |
82 | # PyBuilder
83 | target/
84 |
85 | # Jupyter Notebook
86 | .ipynb_checkpoints
87 | */.ipynb_checkpoints/
88 |
89 | # IPython
90 | profile_default/
91 | ipython_config.py
92 |
93 | # pyenv
94 | .python-version
95 |
96 | # pipenv
97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
100 | # install all needed dependencies.
101 | #Pipfile.lock
102 |
103 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
104 | __pypackages__/
105 |
106 | # Celery stuff
107 | celerybeat-schedule
108 | celerybeat.pid
109 |
110 | # SageMath parsed files
111 | *.sage.py
112 |
113 | # Environments
114 | .env
115 | .venv
116 | env/
117 | venv/
118 | ENV/
119 | env.bak/
120 | venv.bak/
121 |
122 | # Spyder project settings
123 | .spyderproject
124 | .spyproject
125 |
126 | # Rope project settings
127 | .ropeproject
128 |
129 | # mkdocs documentation
130 | /site
131 |
132 | # mypy
133 | .mypy_cache/
134 | .dmypy.json
135 | dmypy.json
136 |
137 | # Pyre type checker
138 | .pyre/
139 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 |
2 | | Version | Release date | Update |
3 | |:--------:|:-------------:|:-------- |
4 | | 0.0.2 | 2020-12-30 | (1) CRPS loss function.
(2) Semi-hard triplet loss function.
(3) Fixing user-specified names on keras models. |
5 | | 0.0.3 | 2021-01-01 | (1) Bugfix.
(2) keyword and documentation fixes for R2U-Net. |
6 | | 0.0.4 | 2021-01-04 | (1) Bugfix.
(2) New feature. ResUnet-a implementation.
(3) New feature. Tversky loss
(4) New feature. GELU and Snake activation layers. |
7 | | 0.0.5 | 2021-01-06 | (1) Bug and typo fixes.
(2) File structure changes.
(3) New feature. U^2-Net.
(4) New feature. Sigmoid output activation. |
8 | | 0.0.6 | 2021-01-10 | (1) New feature. Deep supervision for Unet++.
(2) New feature. Conv-based down- and upsampling for U^2-Net. |
9 | | 0.0.7 | 2021-01-12 | (1) Bugfix.
(2) New feature. backbone functions.
(3) New feature. UNET 3+. |
10 | | 0.0.8 | 2021-01-13 | Bugfix. |
11 | | 0.0.9 | 2021-01-14 | (1) New feature. two util functions.
(2) Documentation. UNET 3+ image segmentation. |
12 | | 0.0.10 | 2021-01-18 | (1) Bugfix.
(2) Rename the `backbone` module as `base`.
(3) New documentation.
(4) New feature. 2-d V-net. |
13 | | 0.0.11 | 2021-01-22 | Bugfix for UNET 3+. |
14 | | 0.0.12 | 2021-01-25 | (1) Bugfix.
(2) New feature. Backbone models for Unet, Unet++, and UNET 3+. |
15 | | 0.0.13 | 2021-01-25 | (1) Bugfix.
(2) New feature. Backbone models for Attention Unet. |
16 | | 0.0.14 | 2021-01-30 | Python script cleaning and documentation fix. |
17 | | 0.0.15 | 2021-02-03 | (1) Bugfix.
(2) New feature. More `pool` and `unpool` options. |
18 | | 0.0.16 | 2021-02-05 | (1) Bugfix.
(2) New feature. Loss functions. |
19 | | 0.0.17 | 2021-02-28 | Bugfix on UNET 3+ |
20 | | 0.0.18 | 2021-03-04 | Bugfix on UNET 3+ |
21 | | 0.1.1 | 2021-05-27 | (1) Bugfix.
(2) New feature. TransUNET. |
22 | | 0.1.5 | 2021-05-27 | Bugfix on the packaging issue. |
 23 | | 0.1.6 | 2021-05-30 | New feature. MS-SSIM and IoU loss functions. |
24 | | 0.1.7 | 2021-06-07 | Bugfix on TransUNET |
25 | | 0.1.8 | 2021-06-07 | Bugfix on TransUNET |
26 | | 0.1.9 | 2021-06-25 | (1) New feature. Swin-UNET.
(2) Update utils and the user guide. |
27 | | 0.1.10 | 2021-07-13 | Bugfix on model saving issues. |
28 | | 0.1.11 | 2021-09-01 | Bugfix on transformer layers. |
29 | | 0.1.12 | 2021-09-04 | Add citation info into the main page.
Generate package doi reference. |
30 | | 0.1.13 | 2022-01-10 | Bugfix on `utils.dummpy_loader` |
31 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Yingkai Sha
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # keras-unet-collection
2 |
3 | [](https://badge.fury.io/py/keras-unet-collection)
4 | [](https://pypi.org/project/keras-unet-collection/)
5 | [](https://github.com/yingkaisha/keras-unet-collection/graphs/commit-activity)
6 |
7 | [](https://zenodo.org/badge/latestdoi/323426984)
8 |
9 | The `tensorflow.keras` implementation of U-net, V-net, U-net++, UNET 3+, Attention U-net, R2U-net, ResUnet-a, U^2-Net, TransUNET, and Swin-UNET with optional ImageNet-trained backbones.
10 |
11 | ----------
12 |
13 | `keras_unet_collection.models` contains functions that configure keras models with hyper-parameter options.
14 |
15 | * Pre-trained ImageNet backbones are supported for U-net, U-net++, UNET 3+, Attention U-net, and TransUNET.
16 | * Deep supervision is supported for U-net++, UNET 3+, and U^2-Net.
17 | * See the [User guide](https://github.com/yingkaisha/keras-unet-collection/blob/main/examples/user_guide_models.ipynb) for other options and use cases.
18 |
19 | | `keras_unet_collection.models` | Name | Reference |
20 | |:---------------|:----------------|:----------------|
21 | | `unet_2d` | U-net | [Ronneberger et al. (2015)](https://link.springer.com/chapter/10.1007/978-3-319-24574-4_28) |
22 | | `vnet_2d` | V-net (modified for 2-d inputs) | [Milletari et al. (2016)](https://arxiv.org/abs/1606.04797) |
23 | | `unet_plus_2d` | U-net++ | [Zhou et al. (2018)](https://link.springer.com/chapter/10.1007/978-3-030-00889-5_1) |
24 | | `r2_unet_2d` | R2U-Net | [Alom et al. (2018)](https://arxiv.org/abs/1802.06955) |
25 | | `att_unet_2d` | Attention U-net | [Oktay et al. (2018)](https://arxiv.org/abs/1804.03999) |
26 | | `resunet_a_2d` | ResUnet-a | [Diakogiannis et al. (2020)](https://doi.org/10.1016/j.isprsjprs.2020.01.013) |
27 | | `u2net_2d` | U^2-Net | [Qin et al. (2020)](https://arxiv.org/abs/2005.09007) |
28 | | `unet_3plus_2d` | UNET 3+ | [Huang et al. (2020)](https://arxiv.org/abs/2004.08790) |
29 | | `transunet_2d` | TransUNET | [Chen et al. (2021)](https://arxiv.org/abs/2102.04306) |
 30 | | `swin_unet_2d` | Swin-UNET | [Cao et al. (2021)](https://arxiv.org/abs/2105.05537) |
31 |
32 | **Note**: the two Transformer models are incompatible with `NumPy 1.20`; `NumPy 1.19.5` is recommended.
33 |
34 | ----------
35 |
36 | ` keras_unet_collection.base` contains functions that build the base architecture (i.e., without model heads) of Unet variants for model customization and debugging.
37 |
38 | | ` keras_unet_collection.base` | Notes |
39 | |:-----------------------------------|:------|
40 | | `unet_2d_base`, `vnet_2d_base`, `unet_plus_2d_base`, `unet_3plus_2d_base`, `att_unet_2d_base`, `r2_unet_2d_base`, `resunet_a_2d_base`, `u2net_2d_base`, `transunet_2d_base`, `swin_unet_2d_base` | Functions that accept an input tensor and hyper-parameters of the corresponded model, and produce output tensors of the base architecture. |
41 |
42 | ----------
43 |
44 | `keras_unet_collection.activations` and `keras_unet_collection.losses` provide additional activation layers and loss functions.
45 |
46 | | `keras_unet_collection.activations` | Name | Reference |
47 | |:--------|:----------------|:----------------|
48 | | `GELU` | Gaussian Error Linear Units (GELU) | [Hendrycks et al. (2016)](https://arxiv.org/abs/1606.08415) |
49 | | `Snake` | Snake activation | [Liu et al. (2020)](https://arxiv.org/abs/2006.08195) |
50 |
51 | | `keras_unet_collection.losses` | Name | Reference |
52 | |:----------------|:----------------|:----------------|
53 | | `dice` | Dice loss | [Sudre et al. (2017)](https://link.springer.com/chapter/10.1007/978-3-319-67558-9_28) |
54 | | `tversky` | Tversky loss | [Hashemi et al. (2018)](https://ieeexplore.ieee.org/abstract/document/8573779) |
55 | | `focal_tversky` | Focal Tversky loss | [Abraham et al. (2019)](https://ieeexplore.ieee.org/abstract/document/8759329) |
56 | | `ms_ssim` | Multi-scale Structural Similarity Index Measure loss | [Wang et al. (2003)](https://ieeexplore.ieee.org/abstract/document/1292216) |
57 | | `iou_seg` | Intersection over Union (IoU) loss for segmentation | [Rahman and Wang (2016)](https://link.springer.com/chapter/10.1007/978-3-319-50835-1_22) |
58 | | `iou_box` | (Generalized) IoU loss for object detection | [Rezatofighi et al. (2019)](https://openaccess.thecvf.com/content_CVPR_2019/html/Rezatofighi_Generalized_Intersection_Over_Union_A_Metric_and_a_Loss_for_CVPR_2019_paper.html) |
59 | | `triplet_1d` | Semi-hard triplet loss (experimental) | |
60 | | `crps2d_tf` | CRPS loss (experimental) | |
61 |
62 | # Installation
63 |
64 | The project is hosted on [PyPI](https://pypi.org/project/keras-unet-collection/) and can thus be installed by:
65 |
66 | ```
67 | pip install keras-unet-collection
68 | ```
69 |
70 | # Usage
71 |
72 | ```python
73 | from keras_unet_collection import models
74 | # e.g. models.unet_2d(...)
75 | ```
76 | * **Note**: Currently supported backbone models are: `VGG[16,19]`, `ResNet[50,101,152]`, `ResNet[50,101,152]V2`, `DenseNet[121,169,201]`, and `EfficientNetB[0-7]`. See [Keras Applications](https://keras.io/api/applications/) for details.
77 |
 78 | * **Note**: Neural networks produced by this package may contain customized layers that are not part of TensorFlow. It is recommended to save and load model weights.
79 |
80 | * [Changelog](https://github.com/yingkaisha/keras-unet-collection/blob/main/CHANGELOG.md)
81 |
82 | # Examples
83 |
84 | * Jupyter notebooks are provided as [examples](https://github.com/yingkaisha/keras-unet-collection/tree/main/examples):
85 |
86 | * Attention U-net with VGG16 backbone [[link]](https://github.com/yingkaisha/keras-unet-collection/blob/main/examples/human-seg_atten-unet-backbone_coco.ipynb).
87 |
88 | * UNET 3+ with deep supervision, classification-guided module, and hybrid loss [[link]](https://github.com/yingkaisha/keras-unet-collection/blob/main/examples/segmentation_unet-three-plus_oxford-iiit.ipynb).
89 |
90 | * Vision-Transformer-based examples are in progress, and available at [**keras-vision-transformer**](https://github.com/yingkaisha/keras-vision-transformer).
91 |
92 | # Dependencies
93 |
94 | * TensorFlow 2.5.0, Keras 2.5.0, NumPy 1.19.5.
95 |
96 | * (Optional for examples) Pillow, matplotlib, etc.
97 |
98 | # Overview
99 |
100 | U-net is a convolutional neural network with encoder-decoder architecture and skip-connections, loosely defined under the concept of "fully convolutional networks." U-net was originally proposed for the semantic segmentation of medical images and is modified for solving a wider range of gridded learning problems.
101 |
102 | U-net and many of its variants take three or four-dimensional tensors as inputs and produce outputs of the same shape. One technical highlight of these models is the skip-connections from downsampling to upsampling layers, which benefit the reconstruction of high-resolution, gridded outputs.
103 |
104 | # Contact
105 |
106 | Yingkai (Kyle) Sha <> <>
107 |
108 | # License
109 |
110 | [MIT License](https://github.com/yingkaisha/keras-unet/blob/main/LICENSE)
111 |
112 | # Citation
113 |
114 | * Sha, Y., 2021: Keras-unet-collection. GitHub repository, accessed 4 September 2021, https://doi.org/10.5281/zenodo.5449801
115 |
116 | ```
117 | @misc{keras-unet-collection,
118 | author = {Sha, Yingkai},
119 | title = {Keras-unet-collection},
120 | year = {2021},
121 | publisher = {GitHub},
122 | journal = {GitHub repository},
123 | howpublished = {\url{https://github.com/yingkaisha/keras-unet-collection}},
124 | doi = {10.5281/zenodo.5449801}
125 | }
126 | ```
127 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yingkaisha/keras-unet-collection/13e0e7e0b3c12d2be58d4a3157f3468652e00969/__init__.py
--------------------------------------------------------------------------------
/keras_unet_collection/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/keras_unet_collection/_backbone_zoo.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | from tensorflow.keras.applications import *
5 | from tensorflow.keras.models import Model
6 |
7 | from keras_unet_collection.utils import freeze_model
8 |
9 | import warnings
10 |
11 | layer_cadidates = {
12 | 'VGG16': ('block1_conv2', 'block2_conv2', 'block3_conv3', 'block4_conv3', 'block5_conv3'),
13 | 'VGG19': ('block1_conv2', 'block2_conv2', 'block3_conv4', 'block4_conv4', 'block5_conv4'),
14 | 'ResNet50': ('conv1_relu', 'conv2_block3_out', 'conv3_block4_out', 'conv4_block6_out', 'conv5_block3_out'),
15 | 'ResNet101': ('conv1_relu', 'conv2_block3_out', 'conv3_block4_out', 'conv4_block23_out', 'conv5_block3_out'),
16 | 'ResNet152': ('conv1_relu', 'conv2_block3_out', 'conv3_block8_out', 'conv4_block36_out', 'conv5_block3_out'),
17 | 'ResNet50V2': ('conv1_conv', 'conv2_block3_1_relu', 'conv3_block4_1_relu', 'conv4_block6_1_relu', 'post_relu'),
18 | 'ResNet101V2': ('conv1_conv', 'conv2_block3_1_relu', 'conv3_block4_1_relu', 'conv4_block23_1_relu', 'post_relu'),
19 | 'ResNet152V2': ('conv1_conv', 'conv2_block3_1_relu', 'conv3_block8_1_relu', 'conv4_block36_1_relu', 'post_relu'),
20 | 'DenseNet121': ('conv1/relu', 'pool2_conv', 'pool3_conv', 'pool4_conv', 'relu'),
21 | 'DenseNet169': ('conv1/relu', 'pool2_conv', 'pool3_conv', 'pool4_conv', 'relu'),
22 | 'DenseNet201': ('conv1/relu', 'pool2_conv', 'pool3_conv', 'pool4_conv', 'relu'),
23 | 'EfficientNetB0': ('block2a_expand_activation', 'block3a_expand_activation', 'block4a_expand_activation', 'block6a_expand_activation', 'top_activation'),
24 | 'EfficientNetB1': ('block2a_expand_activation', 'block3a_expand_activation', 'block4a_expand_activation', 'block6a_expand_activation', 'top_activation'),
25 | 'EfficientNetB2': ('block2a_expand_activation', 'block3a_expand_activation', 'block4a_expand_activation', 'block6a_expand_activation', 'top_activation'),
26 | 'EfficientNetB3': ('block2a_expand_activation', 'block3a_expand_activation', 'block4a_expand_activation', 'block6a_expand_activation', 'top_activation'),
27 | 'EfficientNetB4': ('block2a_expand_activation', 'block3a_expand_activation', 'block4a_expand_activation', 'block6a_expand_activation', 'top_activation'),
28 | 'EfficientNetB5': ('block2a_expand_activation', 'block3a_expand_activation', 'block4a_expand_activation', 'block6a_expand_activation', 'top_activation'),
29 | 'EfficientNetB6': ('block2a_expand_activation', 'block3a_expand_activation', 'block4a_expand_activation', 'block6a_expand_activation', 'top_activation'),
30 | 'EfficientNetB7': ('block2a_expand_activation', 'block3a_expand_activation', 'block4a_expand_activation', 'block6a_expand_activation', 'top_activation'),}
31 |
def bach_norm_checker(backbone_name, batch_norm):
    '''
    Warn when the user-specified `batch_norm` flag disagrees with the backbone.

    VGG backbones ship without batch normalization layers; every other
    supported backbone contains them. If the flag used for the non-backbone
    layers differs from what the backbone provides, a `UserWarning` is
    issued. Nothing is raised and no value is returned.

    Input
    ----------
    backbone_name: backbone model name, e.g., 'VGG16' or 'ResNet50'.
    batch_norm: the batch normalization flag of the non-backbone layers.
    '''
    # VGG models are the only supported backbones without batch norm
    batch_norm_backbone = 'VGG' not in backbone_name

    if batch_norm_backbone != batch_norm:
        if batch_norm_backbone:
            param_mismatch = "\n\nBackbone {} uses batch norm, but other layers received batch_norm={}".format(backbone_name, batch_norm)
        else:
            param_mismatch = "\n\nBackbone {} does not use batch norm, but other layers received batch_norm={}".format(backbone_name, batch_norm)

        warnings.warn(param_mismatch)
46 |
def backbone_zoo(backbone_name, weights, input_tensor, depth, freeze_backbone, freeze_batch_norm):
    '''
    Configuring a user specified encoder model based on the `tensorflow.keras.applications`

    Input
    ----------
    backbone_name: the backbone model name. Expected as one of the `tensorflow.keras.applications` class.
                   Currently supported backbones are:
                   (1) VGG16, VGG19
                   (2) ResNet50, ResNet101, ResNet152
                   (3) ResNet50V2, ResNet101V2, ResNet152V2
                   (4) DenseNet121, DenseNet169, DenseNet201
                   (5) EfficientNetB[0,7]

    weights: one of None (random initialization), 'imagenet' (pre-training on ImageNet),
             or the path to the weights file to be loaded.
    input_tensor: the input tensor
    depth: number of encoded feature maps.
           If four downsampling levels are needed, then depth=4.

    freeze_backbone: True for a frozen backbone
    freeze_batch_norm: False for not freezing batch normalization layers.

    Output
    ----------
    model: a keras backbone model.

    '''
    # raises KeyError early for unsupported backbone names
    cadidate = layer_cadidates[backbone_name]

    # cap `depth` at the number of feature maps the backbone can provide
    depth = min(depth, len(cadidate))

    # resolve the constructor from `tensorflow.keras.applications`
    # (star-imported at module level); avoids `eval` on the raw string
    backbone_func = globals()[backbone_name]
    backbone_ = backbone_func(include_top=False, weights=weights, input_tensor=input_tensor, pooling=None,)

    # collect the intermediate outputs used as skip-connection tensors
    X_skip = [backbone_.get_layer(cadidate[i]).output for i in range(depth)]

    model = Model(inputs=[input_tensor,], outputs=X_skip, name='{}_backbone'.format(backbone_name))

    if freeze_backbone:
        model = freeze_model(model, freeze_batch_norm=freeze_batch_norm)

    return model
99 |
--------------------------------------------------------------------------------
/keras_unet_collection/_model_att_unet_2d.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | from keras_unet_collection.layer_utils import *
5 | from keras_unet_collection.activations import GELU, Snake
6 | from keras_unet_collection._model_unet_2d import UNET_left, UNET_right
7 | from keras_unet_collection._backbone_zoo import backbone_zoo, bach_norm_checker
8 |
9 | from tensorflow.keras.layers import Input
10 | from tensorflow.keras.models import Model
11 |
12 |
def UNET_att_right(X, X_left, channel, att_channel, kernel_size=3, stack_num=2,
                   activation='ReLU', atten_activation='ReLU', attention='add',
                   unpool=True, batch_norm=False, name='right0'):
    '''
    The decoder block of Attention U-net: upsample, gate the skip tensor,
    concatenate, then apply a stack of convolutions.

    UNET_att_right(X, X_left, channel, att_channel, kernel_size=3, stack_num=2,
                   activation='ReLU', atten_activation='ReLU', attention='add',
                   unpool=True, batch_norm=False, name='right0')

    Input
    ----------
    X: input tensor (the upsampling branch).
    X_left: output tensor of the corresponding downsampling level (skip connection).
    channel: number of convolution filters.
    att_channel: number of intermediate attention channels.
    kernel_size: size of the 2-d convolution kernels.
    stack_num: number of stacked convolutional layers.
    activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
    atten_activation: a nonlinear attention activation.
                      The `sigma_1` in Oktay et al. 2018. Default is 'ReLU'.
    attention: 'add' for additive attention, 'multiply' for multiplicative attention.
               Oktay et al. 2018 applied additive attention.
    unpool: True or "bilinear" for Upsampling2D with bilinear interpolation.
            "nearest" for Upsampling2D with nearest interpolation.
            False for Conv2DTranspose + batch norm + activation.
    batch_norm: True for batch normalization, False otherwise.
    name: prefix of the created keras layers.

    Output
    ----------
    X_out: output tensor.

    '''
    # upsample by a fixed factor of two
    X = decode_layer(X, channel, 2, unpool,
                     activation=activation, batch_norm=batch_norm, name='{}_decode'.format(name))

    # gate the skip-connection tensor with the upsampled tensor as `g`
    X_left = attention_gate(X=X_left, g=X, channel=att_channel, activation=atten_activation,
                            attention=attention, name='{}_att'.format(name))

    # concatenate, then run the stacked convolutions
    X_out = concatenate([X, X_left], axis=-1, name='{}_concat'.format(name))
    X_out = CONV_stack(X_out, channel, kernel_size, stack_num=stack_num, activation=activation,
                       batch_norm=batch_norm, name='{}_conv_after_concat'.format(name))

    return X_out
63 |
def att_unet_2d_base(input_tensor, filter_num, stack_num_down=2, stack_num_up=2,
                     activation='ReLU', atten_activation='ReLU', attention='add', batch_norm=False, pool=True, unpool=True,
                     backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='attunet'):
    '''
    The base of Attention U-net with an optional ImageNet backbone

    att_unet_2d_base(input_tensor, filter_num, stack_num_down=2, stack_num_up=2,
                     activation='ReLU', atten_activation='ReLU', attention='add', batch_norm=False, pool=True, unpool=True,
                     backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='att-unet')

    ----------
    Oktay, O., Schlemper, J., Folgoc, L.L., Lee, M., Heinrich, M., Misawa, K., Mori, K., McDonagh, S., Hammerla, N.Y., Kainz, B.
    and Glocker, B., 2018. Attention u-net: Learning where to look for the pancreas. arXiv preprint arXiv:1804.03999.

    Input
    ----------
    input_tensor: the input tensor of the base, e.g., `keras.layers.Input((None, None, 3))`.
    filter_num: a list that defines the number of filters for each \
                down- and upsampling levels. e.g., `[64, 128, 256, 512]`.
                The depth is expected as `len(filter_num)`.
    stack_num_down: number of convolutional layers per downsampling level/block.
    stack_num_up: number of convolutional layers (after concatenation) per upsampling level/block.
    activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
    atten_activation: a nonlinear attention activation.
                      The `sigma_1` in Oktay et al. 2018. Default is 'ReLU'.
    attention: 'add' for additive attention. 'multiply' for multiplicative attention.
               Oktay et al. 2018 applied additive attention.
    batch_norm: True for batch normalization.
    pool: True or 'max' for MaxPooling2D.
          'ave' for AveragePooling2D.
          False for strided conv + batch norm + activation.
    unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
            'nearest' for Upsampling2D with nearest interpolation.
            False for Conv2DTranspose + batch norm + activation.
    name: prefix of the created keras model and its layers.

    ---------- (keywords of backbone options) ----------
    backbone_name: the backbone model name. Should be one of the `tensorflow.keras.applications` class.
                   None (default) means no backbone.
                   Currently supported backbones are:
                   (1) VGG16, VGG19
                   (2) ResNet50, ResNet101, ResNet152
                   (3) ResNet50V2, ResNet101V2, ResNet152V2
                   (4) DenseNet121, DenseNet169, DenseNet201
                   (5) EfficientNetB[0-7]
    weights: one of None (random initialization), 'imagenet' (pre-training on ImageNet),
             or the path to the weights file to be loaded.
    freeze_backbone: True for a frozen backbone.
    freeze_batch_norm: False for not freezing batch normalization layers.

    Output
    ----------
    X: the output tensor of the base.

    '''
    # NOTE(review): the result is unused; `eval` appears to serve only as an
    # early validity check of the activation name — confirm before removing.
    activation_func = eval(activation)

    depth_ = len(filter_num)
    X_skip = []

    # no backbone cases
    if backbone is None:
        X = input_tensor
        # downsampling blocks: first level keeps full resolution (no pooling)
        X = CONV_stack(X, filter_num[0], stack_num=stack_num_down, activation=activation,
                       batch_norm=batch_norm, name='{}_down0'.format(name))
        X_skip.append(X)

        # remaining levels: pooling + conv stack per level
        for i, f in enumerate(filter_num[1:]):
            X = UNET_left(X, f, stack_num=stack_num_down, activation=activation, pool=pool,
                          batch_norm=batch_norm, name='{}_down{}'.format(name, i+1))
            X_skip.append(X)

    else:
        # handling VGG16 and VGG19 separately:
        # VGG feature maps start at full resolution, so all `depth_` levels
        # come from the backbone itself
        if 'VGG' in backbone:
            backbone_ = backbone_zoo(backbone, weights, input_tensor, depth_, freeze_backbone, freeze_batch_norm)
            # collecting backbone feature maps
            X_skip = backbone_([input_tensor,])
            depth_encode = len(X_skip)

        # for other backbones: the input tensor itself counts as the first
        # level, so one fewer backbone feature map is requested
        else:
            backbone_ = backbone_zoo(backbone, weights, input_tensor, depth_-1, freeze_backbone, freeze_batch_norm)
            # collecting backbone feature maps
            X_skip = backbone_([input_tensor,])
            depth_encode = len(X_skip) + 1

        # extra conv2d blocks are applied
        # if downsampling levels of a backbone < user-specified downsampling levels
        if depth_encode < depth_:

            # begins at the deepest available tensor
            X = X_skip[-1]

            # extra downsamplings
            for i in range(depth_-depth_encode):
                i_real = i + depth_encode

                X = UNET_left(X, filter_num[i_real], stack_num=stack_num_down, activation=activation, pool=pool,
                              batch_norm=batch_norm, name='{}_down{}'.format(name, i_real+1))
                X_skip.append(X)

    # reverse indexing encoded feature maps (deepest first)
    X_skip = X_skip[::-1]
    # upsampling begins at the deepest available tensor
    X = X_skip[0]
    # other tensors are preserved for concatenation
    X_decode = X_skip[1:]
    depth_decode = len(X_decode)

    # reverse indexing filter numbers (deepest decoder level excluded)
    filter_num_decode = filter_num[:-1][::-1]

    # attention-gated decoder blocks; att_channel is half of the level's filters
    for i in range(depth_decode):
        f = filter_num_decode[i]

        X = UNET_att_right(X, X_decode[i], f, att_channel=f//2, stack_num=stack_num_up,
                           activation=activation, atten_activation=atten_activation, attention=attention,
                           unpool=unpool, batch_norm=batch_norm, name='{}_up{}'.format(name, i))

    # if tensors for concatenation is not enough
    # (e.g., non-VGG backbones leave the first level without a skip tensor)
    # then use upsampling without concatenation
    if depth_decode < depth_-1:
        for i in range(depth_-depth_decode-1):
            i_real = i + depth_decode
            X = UNET_right(X, None, filter_num_decode[i_real], stack_num=stack_num_up, activation=activation,
                           unpool=unpool, batch_norm=batch_norm, concat=False, name='{}_up{}'.format(name, i_real))
    return X
193 |
def att_unet_2d(input_size, filter_num, n_labels, stack_num_down=2, stack_num_up=2, activation='ReLU',
                atten_activation='ReLU', attention='add', output_activation='Softmax', batch_norm=False, pool=True, unpool=True,
                backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='attunet'):
    '''
    Attention U-net with an optional ImageNet backbone

    att_unet_2d(input_size, filter_num, n_labels, stack_num_down=2, stack_num_up=2, activation='ReLU',
                atten_activation='ReLU', attention='add', output_activation='Softmax', batch_norm=False, pool=True, unpool=True,
                backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='att-unet')

    ----------
    Oktay, O., Schlemper, J., Folgoc, L.L., Lee, M., Heinrich, M., Misawa, K., Mori, K., McDonagh, S., Hammerla, N.Y., Kainz, B.
    and Glocker, B., 2018. Attention u-net: Learning where to look for the pancreas. arXiv preprint arXiv:1804.03999.

    Input
    ----------
    input_size: the size/shape of network input, e.g., `(128, 128, 3)`.
    filter_num: a list that defines the number of filters for each \
                down- and upsampling levels. e.g., `[64, 128, 256, 512]`.
                The depth is expected as `len(filter_num)`.
    n_labels: number of output labels.
    stack_num_down: number of convolutional layers per downsampling level/block.
    stack_num_up: number of convolutional layers (after concatenation) per upsampling level/block.
    activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
    output_activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interface or 'Sigmoid'.
                       Default option is 'Softmax'.
                       if None is received, then linear activation is applied.
    atten_activation: a nonlinear attention activation.
                      The `sigma_1` in Oktay et al. 2018. Default is 'ReLU'.
    attention: 'add' for additive attention. 'multiply' for multiplicative attention.
               Oktay et al. 2018 applied additive attention.
    batch_norm: True for batch normalization.
    pool: True or 'max' for MaxPooling2D.
          'ave' for AveragePooling2D.
          False for strided conv + batch norm + activation.
    unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
            'nearest' for Upsampling2D with nearest interpolation.
            False for Conv2DTranspose + batch norm + activation.
    name: prefix of the created keras model and its layers.

    ---------- (keywords of backbone options) ----------
    backbone_name: the backbone model name. Should be one of the `tensorflow.keras.applications` class.
                   None (default) means no backbone.
                   Currently supported backbones are:
                   (1) VGG16, VGG19
                   (2) ResNet50, ResNet101, ResNet152
                   (3) ResNet50V2, ResNet101V2, ResNet152V2
                   (4) DenseNet121, DenseNet169, DenseNet201
                   (5) EfficientNetB[0-7]
    weights: one of None (random initialization), 'imagenet' (pre-training on ImageNet),
             or the path to the weights file to be loaded.
    freeze_backbone: True for a frozen backbone.
    freeze_batch_norm: False for not freezing batch normalization layers.

    Output
    ----------
    model: a keras model

    '''

    # one of the ReLU, LeakyReLU, PReLU, ELU; fails early on invalid names
    activation_func = eval(activation)

    # warn if the batch_norm flag disagrees with the chosen backbone
    if backbone is not None:
        bach_norm_checker(backbone, batch_norm)

    IN = Input(input_size)

    # base
    # BUGFIX: previously passed `freeze_batch_norm=freeze_backbone`, which
    # silently ignored the user's `freeze_batch_norm` argument.
    X = att_unet_2d_base(IN, filter_num, stack_num_down=stack_num_down, stack_num_up=stack_num_up,
                         activation=activation, atten_activation=atten_activation, attention=attention,
                         batch_norm=batch_norm, pool=pool, unpool=unpool,
                         backbone=backbone, weights=weights, freeze_backbone=freeze_backbone,
                         freeze_batch_norm=freeze_batch_norm, name=name)

    # output layer (1-by-1 conv head)
    OUT = CONV_output(X, n_labels, kernel_size=1, activation=output_activation, name='{}_output'.format(name))

    # functional API model
    model = Model(inputs=[IN,], outputs=[OUT,], name='{}_model'.format(name))

    return model
276 |
--------------------------------------------------------------------------------
/keras_unet_collection/_model_r2_unet_2d.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | from keras_unet_collection.layer_utils import *
5 | from keras_unet_collection.activations import GELU, Snake
6 |
7 | from tensorflow.keras.layers import Input
8 | from tensorflow.keras.models import Model
9 |
def RR_CONV(X, channel, kernel_size=3, stack_num=2, recur_num=2, activation='ReLU', batch_norm=False, name='rr'):
    '''
    Recurrent convolutional layers with a residual skip connection.
    
    RR_CONV(X, channel, kernel_size=3, stack_num=2, recur_num=2, activation='ReLU', batch_norm=False, name='rr')
    
    Input
    ----------
    X: input tensor.
    channel: number of convolution filters.
    kernel_size: size of 2-d convolution kernels.
    stack_num: number of stacked recurrent convolutional layers.
    recur_num: number of recurrent iterations.
    activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
    batch_norm: True for batch normalization, False otherwise.
    name: prefix of the created keras layers.
    
    Output
    ----------
    X: output tensor.
    
    '''
    
    act = eval(activation)
    
    # 1-by-1 conv maps the input to `channel` filters; kept aside as the residual branch
    shortcut = Conv2D(channel, 1, name='{}_conv'.format(name))(X)
    current = shortcut
    
    for i in range(stack_num):
        
        # leading conv (+ optional BN) and activation of this recurrent block
        recur = Conv2D(channel, kernel_size, padding='same', name='{}_conv{}'.format(name, i))(current)
        if batch_norm:
            recur = BatchNormalization(name='{}_bn{}'.format(name, i))(recur)
        recur = act(name='{}_activation{}'.format(name, i))(recur)
        
        # recurrent iterations: add the block input back, then convolve again
        for j in range(recur_num):
            merged = add([recur, current], name='{}_add{}_{}'.format(name, i, j))
            recur = Conv2D(channel, kernel_size, padding='same', name='{}_conv{}_{}'.format(name, i, j))(merged)
            if batch_norm:
                recur = BatchNormalization(name='{}_bn{}_{}'.format(name, i, j))(recur)
            recur = act(name='{}_activation{}_{}'.format(name, i, j))(recur)
        
        current = recur
    
    # residual connection; `i` still holds the last block index, matching the original layer naming
    return add([current, shortcut], name='{}_add{}'.format(name, i))
63 |
64 |
def UNET_RR_left(X, channel, kernel_size=3, 
                 stack_num=2, recur_num=2, activation='ReLU', 
                 pool=True, batch_norm=False, name='left0'):
    '''
    The encoder block of R2U-Net.
    
    UNET_RR_left(X, channel, kernel_size=3, 
                 stack_num=2, recur_num=2, activation='ReLU', 
                 pool=True, batch_norm=False, name='left0')
    
    Input
    ----------
    X: input tensor.
    channel: number of convolution filters.
    kernel_size: size of 2-d convolution kernels.
    stack_num: number of stacked recurrent convolutional layers.
    recur_num: number of recurrent iterations.
    activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
    pool: True or 'max' for MaxPooling2D.
          'ave' for AveragePooling2D.
          False for strided conv + batch norm + activation.
    batch_norm: True for batch normalization, False otherwise.
    name: prefix of the created keras layers.
    
    Output
    ----------
    X: output tensor.
    
    *downsampling is fixed to 2-by-2, e.g., reducing feature map sizes from 64-by-64 to 32-by-32
    '''
    pool_size = 2
    
    # maxpooling layer vs strided convolutional layers
    X = encode_layer(X, channel, pool_size, pool, activation=activation, 
                     batch_norm=batch_norm, name='{}_encode'.format(name))

    # stacked recurrent convolutional layers
    # bug fix: `kernel_size` is now forwarded; previously it was accepted but
    # silently ignored, so RR_CONV always fell back to its default of 3
    X = RR_CONV(X, channel, kernel_size=kernel_size, stack_num=stack_num, recur_num=recur_num, 
                activation=activation, batch_norm=batch_norm, name=name)
    return X
105 |
106 |
def UNET_RR_right(X, X_list, channel, kernel_size=3, 
                  stack_num=2, recur_num=2, activation='ReLU',
                  unpool=True, batch_norm=False, name='right0'):
    '''
    The decoder block of R2U-Net.
    
    UNET_RR_right(X, X_list, channel, kernel_size=3, 
                  stack_num=2, recur_num=2, activation='ReLU',
                  unpool=True, batch_norm=False, name='right0')
    
    Input
    ----------
    X: input tensor.
    X_list: a list of other tensors that connected to the input tensor.
    channel: number of convolution filters.
    kernel_size: size of 2-d convolution kernels.
    stack_num: number of stacked recurrent convolutional layers.
    recur_num: number of recurrent iterations.
    activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
    unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
            'nearest' for Upsampling2D with nearest interpolation.
            False for Conv2DTranspose + batch norm + activation.
    batch_norm: True for batch normalization, False otherwise.
    name: prefix of the created keras layers.
    
    Output
    ----------
    X: output tensor
    
    '''
    
    pool_size = 2
    
    # upsampling: Upsampling2D or Conv2DTranspose, depending on `unpool`
    X = decode_layer(X, channel, pool_size, unpool, 
                     activation=activation, batch_norm=batch_norm, name='{}_decode'.format(name))
    
    # linear convolutional layers before concatenation
    X = CONV_stack(X, channel, kernel_size, stack_num=1, activation=activation, 
                   batch_norm=batch_norm, name='{}_conv_before_concat'.format(name))
    
    # Tensor concatenation
    H = concatenate([X,]+X_list, axis=-1, name='{}_concat'.format(name))
    
    # stacked recurrent convolutional layers after concatenation
    # bug fix: `kernel_size` is now forwarded; previously it was accepted but
    # silently ignored, so RR_CONV always fell back to its default of 3
    H = RR_CONV(H, channel, kernel_size=kernel_size, stack_num=stack_num, recur_num=recur_num, 
                activation=activation, batch_norm=batch_norm, name=name)
    
    return H
155 |
def r2_unet_2d_base(input_tensor, filter_num, stack_num_down=2, stack_num_up=2, recur_num=2,
                    activation='ReLU', batch_norm=False, pool=True, unpool=True, name='res_unet'):
    
    '''
    The base of Recurrent Residual (R2) U-Net.
    
    r2_unet_2d_base(input_tensor, filter_num, stack_num_down=2, stack_num_up=2, recur_num=2,
                    activation='ReLU', batch_norm=False, pool=True, unpool=True, name='res_unet')
    
    ----------
    Alom, M.Z., Hasan, M., Yakopcic, C., Taha, T.M. and Asari, V.K., 2018. Recurrent residual convolutional neural network
    based on u-net (r2u-net) for medical image segmentation. arXiv preprint arXiv:1802.06955.
    
    Input
    ----------
    input_tensor: the input tensor of the base, e.g., `keras.layers.Input((None, None, 3))`.
    filter_num: a list that defines the number of filters for each \
                down- and upsampling levels. e.g., `[64, 128, 256, 512]`.
                The depth is expected as `len(filter_num)`.
    stack_num_down: number of stacked recurrent convolutional layers per downsampling level/block.
    stack_num_up: number of stacked recurrent convolutional layers per upsampling level/block.
    recur_num: number of recurrent iterations.
    activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
    batch_norm: True for batch normalization.
    pool: True or 'max' for MaxPooling2D.
          'ave' for AveragePooling2D.
          False for strided conv + batch norm + activation.
    unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
            'nearest' for Upsampling2D with nearest interpolation.
            False for Conv2DTranspose + batch norm + activation.
    name: prefix of the created keras layers.
    
    Output
    ----------
    X: output tensor.
    
    '''
    
    # resolves the activation identifier early so an invalid name fails fast;
    # the resolved class itself is unused here (sub-blocks receive the string).
    activation_func = eval(activation)
    
    X = input_tensor
    X_skip = []  # encoder-side tensors preserved for the skip connections
    
    # downsampling blocks
    X = RR_CONV(X, filter_num[0], stack_num=stack_num_down, recur_num=recur_num, 
                activation=activation, batch_norm=batch_norm, name='{}_down0'.format(name))
    X_skip.append(X)
    
    for i, f in enumerate(filter_num[1:]):
        X = UNET_RR_left(X, f, kernel_size=3, stack_num=stack_num_down, recur_num=recur_num, 
                         activation=activation, pool=pool, batch_norm=batch_norm, name='{}_down{}'.format(name, i+1))        
        X_skip.append(X)
    
    # upsampling blocks: drop the bottom tensor and pair the rest
    # with the decoder levels in reverse (deepest first)
    X_skip = X_skip[:-1][::-1]
    for i, f in enumerate(filter_num[:-1][::-1]):
        X = UNET_RR_right(X, [X_skip[i],], f, stack_num=stack_num_up, recur_num=recur_num, 
                          activation=activation, unpool=unpool, batch_norm=batch_norm, name='{}_up{}'.format(name, i+1))
    
    return X
216 |
217 |
def r2_unet_2d(input_size, filter_num, n_labels, 
               stack_num_down=2, stack_num_up=2, recur_num=2,
               activation='ReLU', output_activation='Softmax', 
               batch_norm=False, pool=True, unpool=True, name='r2_unet'):
    
    '''
    Recurrent Residual (R2) U-Net
    
    r2_unet_2d(input_size, filter_num, n_labels, 
               stack_num_down=2, stack_num_up=2, recur_num=2,
               activation='ReLU', output_activation='Softmax', 
               batch_norm=False, pool=True, unpool=True, name='r2_unet')
    
    ----------
    Alom, M.Z., Hasan, M., Yakopcic, C., Taha, T.M. and Asari, V.K., 2018. Recurrent residual convolutional neural network
    based on u-net (r2u-net) for medical image segmentation. arXiv preprint arXiv:1802.06955.
    
    Input
    ----------
    input_size: the size/shape of network input, e.g., `(128, 128, 3)`.
    filter_num: a list that defines the number of filters for each \
                down- and upsampling levels. e.g., `[64, 128, 256, 512]`.
                The depth is expected as `len(filter_num)`.
    n_labels: number of output labels.
    stack_num_down: number of stacked recurrent convolutional layers per downsampling level/block.
    stack_num_up: number of stacked recurrent convolutional layers per upsampling level/block.
    recur_num: number of recurrent iterations.
    activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
    output_activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interface or 'Sigmoid'.
                       Default option is 'Softmax'.
                       if None is received, then linear activation is applied.
    batch_norm: True for batch normalization.
    pool: True or 'max' for MaxPooling2D.
          'ave' for AveragePooling2D.
          False for strided conv + batch norm + activation.
    unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
            'nearest' for Upsampling2D with nearest interpolation.
            False for Conv2DTranspose + batch norm + activation.
    name: prefix of the created keras layers.
    
    Output
    ----------
    model: a keras model.
    
    '''
    
    # resolves the activation identifier early so an invalid name fails fast;
    # the resolved class itself is unused here (the base receives the string).
    activation_func = eval(activation)
    
    IN = Input(input_size, name='{}_input'.format(name))
    
    # base
    X = r2_unet_2d_base(IN, filter_num, 
                        stack_num_down=stack_num_down, stack_num_up=stack_num_up, recur_num=recur_num,
                        activation=activation, batch_norm=batch_norm, pool=pool, unpool=unpool, name=name)
    # output layer
    OUT = CONV_output(X, n_labels, kernel_size=1, activation=output_activation, name='{}_output'.format(name))
    
    # functional API model
    model = Model(inputs=[IN], outputs=[OUT], name='{}_model'.format(name))
    
    return model
279 |
280 |
--------------------------------------------------------------------------------
/keras_unet_collection/_model_resunet_a_2d.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | from keras_unet_collection.layer_utils import *
5 | from keras_unet_collection.activations import GELU, Snake
6 |
7 | from tensorflow.keras.layers import Input
8 | from tensorflow.keras.models import Model
9 |
def ResUNET_a_block(X, channel, kernel_size=3, dilation_num=1.0, activation='ReLU', batch_norm=False, name='res_a_block'):
    '''
    The "ResUNET-a" block
    
    ResUNET_a_block(X, channel, kernel_size=3, dilation_num=1.0, activation='ReLU', batch_norm=False, name='res_a_block')
    
    ----------
    Diakogiannis, F.I., Waldner, F., Caccetta, P. and Wu, C., 2020. Resunet-a: a deep learning framework for
    semantic segmentation of remotely sensed data. ISPRS Journal of Photogrammetry and Remote Sensing, 162, pp.94-114.
    
    Input
    ----------
    X: input tensor.
    channel: number of convolution filters.
    kernel_size: size of 2-d convolution kernels.
    dilation_num: an iterable that defines dilation rates of convolutional layers.
                  stacks of conv2d is expected as `len(dilation_num)`.
                  A scalar (e.g., the default `1.0`) is treated as a single dilation rate.
    activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
    batch_norm: True for batch normalization, False otherwise.
    name: prefix of the created keras layers.
    
    Output
    ----------
    X: output tensor.
    
    '''
    
    # bug fix: the default `dilation_num=1.0` is a scalar and previously crashed
    # the `enumerate` loop below; scalars are now wrapped as a one-rate list.
    if isinstance(dilation_num, (int, float)):
        dilation_num = [int(dilation_num),]
    
    X_res = []
    
    # one two-layer conv stack per dilation rate, each applied to the same input
    for i, d in enumerate(dilation_num):
        
        X_res.append(CONV_stack(X, channel, kernel_size=kernel_size, stack_num=2, dilation_rate=d, 
                                activation=activation, batch_norm=batch_norm, name='{}_stack{}'.format(name, i)))
    
    # sum the parallel dilated branches (pass-through if only one rate is given)
    if len(X_res) > 1:
        return add(X_res)
    
    else:
        return X_res[0]
49 |
50 |
def ResUNET_a_right(X, X_list, channel, kernel_size=3, dilation_num=[1,],
                    activation='ReLU', unpool=True, batch_norm=False, name='right0'):
    '''
    The decoder block of ResUNet-a
    
    ResUNET_a_right(X, X_list, channel, kernel_size=3, dilation_num=[1,],
                    activation='ReLU', unpool=True, batch_norm=False, name='right0')
    
    Input
    ----------
    X: input tensor.
    X_list: a list of other tensors that connected to the input tensor.
    channel: number of convolution filters.
    kernel_size: size of 2-d convolution kernels.
    dilation_num: an iterable that defines dilation rates of convolutional layers.
                  stacks of conv2d is expected as `len(dilation_num)`.
    activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
    unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
            'nearest' for Upsampling2D with nearest interpolation.
            False for Conv2DTranspose + batch norm + activation.
    batch_norm: True for batch normalization, False otherwise.
    name: name of the created keras layers.
    
    Output
    ----------
    X: output tensor.
    
    '''
    
    up_rate = 2  # upsampling is fixed to 2-by-2
    
    # upsample the incoming decoder tensor
    upsampled = decode_layer(X, channel, up_rate, unpool, 
                             activation=activation, batch_norm=batch_norm,
                             name='{}_decode'.format(name))
    
    # <--- *stacked convolutional can be applied here
    # merge with the encoder-side tensors along the channel axis
    merged = concatenate([upsampled,]+X_list, axis=3, name=name+'_concat')
    
    # stacked (possibly multi-dilation-rate) convolutions after concatenation
    return ResUNET_a_block(merged, channel, kernel_size=kernel_size,
                           dilation_num=dilation_num, activation=activation,
                           batch_norm=batch_norm, name='{}_resblock'.format(name))
94 |
def resunet_a_2d_base(input_tensor, filter_num, dilation_num,
                      aspp_num_down=256, aspp_num_up=128, activation='ReLU',
                      batch_norm=True, pool=True, unpool=True, name='resunet'):
    '''
    The base of ResUNet-a
    
    resunet_a_2d_base(input_tensor, filter_num, dilation_num,
                      aspp_num_down=256, aspp_num_up=128, activation='ReLU',
                      batch_norm=True, pool=True, unpool=True, name='resunet')
    
    ----------
    Diakogiannis, F.I., Waldner, F., Caccetta, P. and Wu, C., 2020. Resunet-a: a deep learning framework for
    semantic segmentation of remotely sensed data. ISPRS Journal of Photogrammetry and Remote Sensing, 162, pp.94-114.
    
    Input
    ----------
    input_tensor: the input tensor of the base, e.g., `keras.layers.Input((None, None, 3))`.
    filter_num: a list that defines the number of filters for each \
                down- and upsampling levels. e.g., `[64, 128, 256, 512]`.
                The depth is expected as `len(filter_num)`.
    dilation_num: an iterable that defines the dilation rates of convolutional layers.
                  Diakogiannis et al. (2020) suggested `[1, 3, 15, 31]`.
                  * This base function requires `len(filter_num) == len(dilation_num)`.
                  Explicitly defining dilation rates for each down-/upsampling level.
    aspp_num_down: number of Atrous Spatial Pyramid Pooling (ASPP) layer filters after the last downsampling block.
    aspp_num_up: number of ASPP layer filters after the last upsampling block.
    activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
    batch_norm: True for batch normalization.
    pool: True or 'max' for MaxPooling2D.
          'ave' for AveragePooling2D.
          False for strided conv + batch norm + activation.
    unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
            'nearest' for Upsampling2D with nearest interpolation.
            False for Conv2DTranspose + batch norm + activation.
    name: prefix of the created keras layers.
    
    Output
    ----------
    X: output tensor.
    
    * Downsampling is achieved through strided convolutional layers with 1-by-1 kernels in Diakogiannis et al., (2020),
      and is here achieved either with pooling layers or strided convolutional layers with 2-by-2 kernels.
    * If this base function is involved in network training, then the input shape cannot have NoneType.
    * `dilation_num` should be provided as 2d iterables, with the second dimension matches the model depth.
      e.g., for `len(filter_num) = 4`, dilation_num can be provided as: `[[1, 3, 15, 31], [1, 3, 15], [1,], [1,]]`.
      
    '''
    
    pool_size = 2
    
    activation_func = eval(activation)
    
    X_skip = []  # encoder-side tensors preserved for the skip connections
    
    # ----- #
    # rejecting auto-mode from this base function
    if isinstance(dilation_num[0], int):
        raise ValueError('`resunet_a_2d_base` does not support automated determination of `dilation_num`.')
    else:
        dilation_ = dilation_num
    # ----- #
    
    X = input_tensor
    
    # input mapping with 1-by-1 conv
    X = Conv2D(filter_num[0], 1, 1, dilation_rate=1, padding='same', 
               use_bias=True, name='{}_input_mapping'.format(name))(X)
    X = activation_func(name='{}_input_activation'.format(name))(X)
    X_skip.append(X)
    # ----- #
    
    X = ResUNET_a_block(X, filter_num[0], kernel_size=3, dilation_num=dilation_[0], 
                        activation=activation, batch_norm=batch_norm, name='{}_res0'.format(name))
    X_skip.append(X)
    
    # downsampling levels
    for i, f in enumerate(filter_num[1:]):
        ind_ = i+1
        
        X = encode_layer(X, f, pool_size, pool, activation=activation, 
                         batch_norm=batch_norm, name='{}_down{}'.format(name, i))
        X = ResUNET_a_block(X, f, kernel_size=3, dilation_num=dilation_[ind_], activation=activation, 
                            batch_norm=batch_norm, name='{}_resblock_{}'.format(name, ind_))
        X_skip.append(X)
    
    # ASPP at the bottom of the "U"
    X = ASPP_conv(X, aspp_num_down, activation=activation, batch_norm=batch_norm, name='{}_aspp_bottom'.format(name))
    
    # drop the bottom tensor and reverse so the decoder pairs deepest-first
    X_skip = X_skip[:-1][::-1]
    dilation_ = dilation_[:-1][::-1]
    
    for i, f in enumerate(filter_num[:-1][::-1]):
        
        X = ResUNET_a_right(X, [X_skip[i],], f, kernel_size=3, activation=activation, dilation_num=dilation_[i], 
                            unpool=unpool, batch_norm=batch_norm, name='{}_up{}'.format(name, i))
    
    # merge with the 1-by-1-mapped input tensor, then the output ASPP
    X = concatenate([X_skip[-1], X], name='{}_concat_out'.format(name))
    
    X = ASPP_conv(X, aspp_num_up, activation=activation, batch_norm=batch_norm, name='{}_aspp_out'.format(name))
    
    return X
192 |
193 |
def resunet_a_2d(input_size, filter_num, dilation_num, n_labels,
                 aspp_num_down=256, aspp_num_up=128, activation='ReLU', output_activation='Softmax', 
                 batch_norm=True, pool=True, unpool=True, name='resunet'):
    '''
    ResUNet-a
    
    resunet_a_2d(input_size, filter_num, dilation_num, n_labels,
                 aspp_num_down=256, aspp_num_up=128, activation='ReLU', output_activation='Softmax', 
                 batch_norm=True, pool=True, unpool=True, name='resunet')
    
    ----------
    Diakogiannis, F.I., Waldner, F., Caccetta, P. and Wu, C., 2020. Resunet-a: a deep learning framework for
    semantic segmentation of remotely sensed data. ISPRS Journal of Photogrammetry and Remote Sensing, 162, pp.94-114.
    
    Input
    ----------
    input_size: the size/shape of network input, e.g., `(128, 128, 3)`.
    filter_num: a list that defines the number of filters for each \
                down- and upsampling levels. e.g., `[64, 128, 256, 512]`.
                The depth is expected as `len(filter_num)`.
    dilation_num: an iterable that defines the dilation rates of convolutional layers.
                  Diakogiannis et al. (2020) suggested `[1, 3, 15, 31]`.
                  * `dilation_num` can be provided as 2d iterables, with the second dimension matches 
                  the model depth. e.g., for len(filter_num) = 4; dilation_num can be provided as: 
                  `[[1, 3, 15, 31], [1, 3, 15], [1,], [1,]]`.
                  * If `dilation_num` is not provided per down-/upsampling level, then the automated 
                  determinations will be applied.
    n_labels: number of output labels.
    aspp_num_down: number of Atrous Spatial Pyramid Pooling (ASPP) layer filters after the last downsampling block.
    aspp_num_up: number of ASPP layer filters after the last upsampling block.
    activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
    output_activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interface or 'Sigmoid'.
                       Default option is 'Softmax'.
                       if None is received, then linear activation is applied.
    batch_norm: True for batch normalization.
    pool: True or 'max' for MaxPooling2D.
          'ave' for AveragePooling2D.
          False for strided conv + batch norm + activation.
    unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
            'nearest' for Upsampling2D with nearest interpolation.
            False for Conv2DTranspose + batch norm + activation.
    name: prefix of the created keras layers.
    
    Output
    ----------
    model: a keras model.
    
    * Downsampling is achieved through strided convolutional layers with 1-by-1 kernels in Diakogiannis et al., (2020),
      and is here achieved either with pooling layers or strided convolutional layers with 2-by-2 kernels.
    * `resunet_a_2d` does not support NoneType input shape.
    
    '''
    
    # resolves the activation identifier early so an invalid name fails fast;
    # the resolved class itself is unused here (the base receives the string).
    activation_func = eval(activation)
    depth_ = len(filter_num)
    
    # input_size cannot have None
    if input_size[0] is None or input_size[1] is None:
        raise ValueError('`resunet_a_2d` does not support NoneType input shape')
        
    # ----- #
    # a flat (1d) `dilation_num` triggers the automated per-level determination:
    # the two shallowest levels get all rates, mid levels drop the largest rate,
    # and the deepest levels use dilation rate 1 only.
    if isinstance(dilation_num[0], int):
        print("Received dilation rates: {}".format(dilation_num))
        
        deep_ = (depth_-2)//2
        dilation_ = [[] for _ in range(depth_)]
        
        print("Received dilation rates are not defined on a per downsampling level basis.")
        print("Automated determinations are applied with the following details:")
        
        for i in range(depth_):
            if i <= 1:
                dilation_[i] += dilation_num
            elif i > 1 and i <= deep_+1:
                dilation_[i] += dilation_num[:-1]
            else:
                dilation_[i] += [1,]
            print('\tdepth-{}, dilation_rate = {}'.format(i, dilation_[i]))
        
    else:
        dilation_ = dilation_num
    # ----- #
    
    IN = Input(input_size)
    
    # base
    X = resunet_a_2d_base(IN, filter_num, dilation_,
                          aspp_num_down=aspp_num_down, aspp_num_up=aspp_num_up, activation=activation,
                          batch_norm=batch_norm, pool=pool, unpool=unpool, name=name)
    
    # output layer
    OUT = CONV_output(X, n_labels, kernel_size=1, activation=output_activation, name='{}_output'.format(name))
    
    # functional API model
    model = Model([IN], [OUT,], name='{}_model'.format(name))
    
    return model
288 |
289 |
--------------------------------------------------------------------------------
/keras_unet_collection/_model_swin_unet_2d.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | from keras_unet_collection.layer_utils import *
5 | from keras_unet_collection.transformer_layers import patch_extract, patch_embedding, SwinTransformerBlock, patch_merging, patch_expanding
6 |
7 | from tensorflow.keras.layers import Input, Dense
8 | from tensorflow.keras.models import Model
9 |
def swin_transformer_stack(X, stack_num, embed_dim, num_patch, num_heads, window_size, num_mlp, shift_window=True, name=''):
    '''
    Stacked Swin Transformers that share the same token size.
    
    Alternated Window-MSA and Swin-MSA will be configured if `shift_window=True`, Window-MSA only otherwise.
    *Dropout is turned off.
    '''
    # Turn-off dropouts
    mlp_drop_rate = 0 # Droupout after each MLP layer
    attn_drop_rate = 0 # Dropout after Swin-Attention
    proj_drop_rate = 0 # Dropout at the end of each Swin-Attention block, i.e., after linear projections
    drop_path_rate = 0 # Drop-path within skip-connections
    
    qkv_bias = True # Convert embedded patches to query, key, and values with a learnable additive value
    qk_scale = None # None: Re-scale query based on embed dimensions per attention head # Float for user specified scaling factor
    
    if shift_window:
        shift_size = window_size // 2
    else:
        shift_size = 0
    
    for i in range(stack_num):
        
        # alternate non-shifted (W-MSA) and shifted (SW-MSA) windows
        if i % 2 == 0:
            shift_size_temp = 0
        else:
            shift_size_temp = shift_size

        # bug fix: the layer name now uses the `name` prefix; the previous
        # 'name{}'.format(i) discarded the prefix, so every stack produced
        # identically named blocks ('name0', 'name1', ...) that collide within
        # a functional model.
        X = SwinTransformerBlock(dim=embed_dim, num_patch=num_patch, num_heads=num_heads, 
                                 window_size=window_size, shift_size=shift_size_temp, num_mlp=num_mlp, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 mlp_drop=mlp_drop_rate, attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, drop_path_prob=drop_path_rate, 
                                 name='{}_block{}'.format(name, i))(X)
    return X
43 |
44 |
def swin_unet_2d_base(input_tensor, filter_num_begin, depth, stack_num_down, stack_num_up, 
                      patch_size, num_heads, window_size, num_mlp, shift_window=True, name='swin_unet'):
    '''
    The base of SwinUNET.
    
    ----------
    Cao, H., Wang, Y., Chen, J., Jiang, D., Zhang, X., Tian, Q. and Wang, M., 2021. 
    Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation. arXiv preprint arXiv:2105.05537.
    
    Input
    ----------
    input_tensor: the input tensor of the base, e.g., `keras.layers.Input((None, None, 3))`.
    filter_num_begin: number of channels in the first downsampling block; 
                      it is also the number of embedded dimensions.
    depth: the depth of Swin-UNET, e.g., depth=4 means three down/upsampling levels and a bottom level.
    stack_num_down: number of convolutional layers per downsampling level/block. 
    stack_num_up: number of convolutional layers (after concatenation) per upsampling level/block.
    name: prefix of the created keras model and its layers.
    
    ---------- (keywords of Swin-Transformers) ----------
    
    patch_size: The size of extracted patches, 
                e.g., patch_size=(2, 2) means 2-by-2 patches
                *Height and width of the patch must be equal.
                
    num_heads: number of attention heads per down/upsampling level,
               e.g., num_heads=[4, 8, 16, 16] means increased attention heads with increasing depth.
               *The length of num_heads must equal to `depth`.
               
    window_size: the size of attention window per down/upsampling level,
                 e.g., window_size=[4, 2, 2, 2] means decreased window size with increasing depth.
                 
    num_mlp: number of MLP nodes.
    
    shift_window: The indicator of window shifting;
                  shift_window=True means applying Swin-MSA for every two Swin-Transformer blocks. 
                  shift_window=False means MSA with fixed window locations for all blocks.
                  
    Output
    ----------
    output tensor.
    
    Note: This function is experimental.
          The activation functions of all Swin-Transformers are fixed to GELU.
    
    '''
    # Compute the number of patches to be embedded.
    # NOTE(review): floor division assumes the input height/width are divisible
    # by the patch size (and by 2 at every merging level) -- confirm upstream.
    input_size = input_tensor.shape.as_list()[1:]
    num_patch_x = input_size[0]//patch_size[0]
    num_patch_y = input_size[1]//patch_size[1]
    
    # Number of Embedded dimensions
    embed_dim = filter_num_begin
    
    depth_ = depth
    
    # encoder tensors preserved for the skip connections
    X_skip = []

    X = input_tensor
    
    # Patch extraction
    X = patch_extract(patch_size)(X)

    # Embed patches to tokens
    X = patch_embedding(num_patch_x*num_patch_y, embed_dim)(X)
    
    # The first Swin Transformer stack
    X = swin_transformer_stack(X, stack_num=stack_num_down, 
                               embed_dim=embed_dim, num_patch=(num_patch_x, num_patch_y), 
                               num_heads=num_heads[0], window_size=window_size[0], num_mlp=num_mlp, 
                               shift_window=shift_window, name='{}_swin_down0'.format(name))
    X_skip.append(X)
    
    # Downsampling blocks
    for i in range(depth_-1):
        
        # Patch merging
        # NOTE(review): this layer name omits the `name` model prefix used by
        # the other layers; combining two such bases in one model could collide.
        X = patch_merging((num_patch_x, num_patch_y), embed_dim=embed_dim, name='down{}'.format(i))(X)
        
        # update token shape info: each merging halves the patch grid and doubles the embedding
        embed_dim = embed_dim*2
        num_patch_x = num_patch_x//2
        num_patch_y = num_patch_y//2
        
        # Swin Transformer stacks
        X = swin_transformer_stack(X, stack_num=stack_num_down, 
                                   embed_dim=embed_dim, num_patch=(num_patch_x, num_patch_y), 
                                   num_heads=num_heads[i+1], window_size=window_size[i+1], num_mlp=num_mlp, 
                                   shift_window=shift_window, name='{}_swin_down{}'.format(name, i+1))
        
        # Store tensors for concat
        X_skip.append(X)
        
    # reverse indexing encoded tensors and hyperparams so the decoder runs deepest-first
    X_skip = X_skip[::-1]
    num_heads = num_heads[::-1]
    window_size = window_size[::-1]
    
    # upsampling begins at the deepest available tensor
    X = X_skip[0]
    
    # other tensors are preserved for concatenation
    X_decode = X_skip[1:]
    
    depth_decode = len(X_decode)
    
    for i in range(depth_decode):
        
        # Patch expanding
        X = patch_expanding(num_patch=(num_patch_x, num_patch_y), 
                            embed_dim=embed_dim, upsample_rate=2, return_vector=True, name='{}_swin_up{}'.format(name, i))(X)
        

        # update token shape info: each expansion doubles the patch grid and halves the embedding
        embed_dim = embed_dim//2
        num_patch_x = num_patch_x*2
        num_patch_y = num_patch_y*2
        
        # Concatenation and linear projection back down to `embed_dim` channels
        X = concatenate([X, X_decode[i]], axis=-1, name='{}_concat_{}'.format(name, i))
        X = Dense(embed_dim, use_bias=False, name='{}_concat_linear_proj_{}'.format(name, i))(X)
        
        # Swin Transformer stacks
        X = swin_transformer_stack(X, stack_num=stack_num_up, 
                                   embed_dim=embed_dim, num_patch=(num_patch_x, num_patch_y), 
                                   num_heads=num_heads[i], window_size=window_size[i], num_mlp=num_mlp, 
                                   shift_window=shift_window, name='{}_swin_up{}'.format(name, i))
        
    # The last expanding layer; it produces full-size feature maps based on the patch size
    # !!! <--- "patch_size[0]" is used; it assumes patch_size = (size, size)
    # NOTE(review): no `name` is passed to this final layer, so it receives an auto-generated name.
    X = patch_expanding(num_patch=(num_patch_x, num_patch_y), 
                        embed_dim=embed_dim, upsample_rate=patch_size[0], return_vector=False)(X)
    
    return X
179 |
180 |
def swin_unet_2d(input_size, filter_num_begin, n_labels, depth, stack_num_down, stack_num_up,
                 patch_size, num_heads, window_size, num_mlp, output_activation='Softmax', shift_window=True, name='swin_unet'):
    '''
    Swin-UNET: a keras functional-API wrapper around `swin_unet_2d_base`
    with a per-pixel classification head.

    ----------
    Cao, H., Wang, Y., Chen, J., Jiang, D., Zhang, X., Tian, Q. and Wang, M., 2021.
    Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation. arXiv preprint arXiv:2105.05537.

    Input
    ----------
    input_size: the size/shape of network input, e.g., `(128, 128, 3)`.
    filter_num_begin: number of channels in the first downsampling block;
                      it is also the number of embedded dimensions.
    n_labels: number of output labels.
    depth: the depth of Swin-UNET, e.g., depth=4 means three down/upsampling levels and a bottom level.
    stack_num_down: number of convolutional layers per downsampling level/block.
    stack_num_up: number of convolutional layers (after concatenation) per upsampling level/block.
    output_activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations`
                       interfaces or 'Sigmoid'; default is 'Softmax'.
    name: prefix of the created keras model and its layers.

    ---------- (keywords of Swin-Transformers) ----------

    patch_size: the size of extracted patches, e.g., patch_size=(2, 2) for 2-by-2 patches.
                *Height and width of the patch must be equal.

    num_heads: number of attention heads per down/upsampling level,
               e.g., num_heads=[4, 8, 16, 16]; its length must equal `depth`.

    window_size: the size of attention window per down/upsampling level,
                 e.g., window_size=[4, 2, 2, 2].

    num_mlp: number of MLP nodes.

    shift_window: True applies Swin-MSA for every two Swin-Transformer blocks;
                  False uses MSA with fixed window locations for all blocks.

    Output
    ----------
    model: a keras model.

    Note: This function is experimental.
          The activation functions of all Swin-Transformers are fixed to GELU.
    '''
    # network input
    model_input = Input(input_size)

    # Swin-UNET encoder/decoder backbone
    features = swin_unet_2d_base(model_input, filter_num_begin=filter_num_begin, depth=depth,
                                 stack_num_down=stack_num_down, stack_num_up=stack_num_up,
                                 patch_size=patch_size, num_heads=num_heads, window_size=window_size,
                                 num_mlp=num_mlp, shift_window=shift_window, name=name)

    # per-pixel classification head (1-by-1 conv)
    model_output = CONV_output(features, n_labels, kernel_size=1,
                               activation=output_activation, name='{}_output'.format(name))

    # assemble the functional-API model
    return Model(inputs=[model_input,], outputs=[model_output,], name='{}_model'.format(name))
240 |
--------------------------------------------------------------------------------
/keras_unet_collection/_model_transunet_2d.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | from keras_unet_collection.layer_utils import *
5 | from keras_unet_collection.activations import GELU, Snake
6 | from keras_unet_collection._model_unet_2d import UNET_left, UNET_right
7 | from keras_unet_collection.transformer_layers import patch_extract, patch_embedding
8 | from keras_unet_collection._backbone_zoo import backbone_zoo, bach_norm_checker
9 |
10 | import tensorflow as tf
11 | from tensorflow.keras.layers import Input
12 | from tensorflow.keras.models import Model
13 | from tensorflow.keras.layers import Layer, MultiHeadAttention, LayerNormalization, Dense, Embedding
14 |
def ViT_MLP(X, filter_num, activation='GELU', name='MLP'):
    '''
    The MLP (feed-forward) block of ViT: a stack of Dense layers,
    each followed by the chosen activation layer.

    ----------
    Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner,
    T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S. and Uszkoreit, J., 2020.
    An image is worth 16x16 words: Transformers for image recognition at scale.
    arXiv preprint arXiv:2010.11929.

    Input
    ----------
    X: the input tensor of MLP, i.e., after MSA and skip connections.
    filter_num: a list that defines the number of nodes for each MLP layer.
                For the last MLP layer, its number of nodes must equal the dimension of the key.
    activation: activation of MLP nodes.
    name: prefix of the created keras layers.

    Output
    ----------
    X: output tensor.

    '''
    # resolve the activation layer class from its name
    # (relies on module-level imports, e.g., GELU from keras_unet_collection.activations)
    activation_layer = eval(activation)

    # Dense -> activation, once per entry of filter_num
    layer_id = 0
    for n_nodes in filter_num:
        X = Dense(n_nodes, name='{}_dense_{}'.format(name, layer_id))(X)
        X = activation_layer(name='{}_activation_{}'.format(name, layer_id))(X)
        layer_id += 1

    return X
45 |
def ViT_block(V, num_heads, key_dim, filter_num_MLP, activation='GELU', name='ViT'):
    '''
    Vision transformer (ViT) block: pre-norm multi-headed self-attention
    followed by a pre-norm MLP, each wrapped with a residual connection.

    ViT_block(V, num_heads, key_dim, filter_num_MLP, activation='GELU', name='ViT')

    ----------
    Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner,
    T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S. and Uszkoreit, J., 2020.
    An image is worth 16x16 words: Transformers for image recognition at scale.
    arXiv preprint arXiv:2010.11929.

    Input
    ----------
    V: embedded input features.
    num_heads: number of attention heads.
    key_dim: dimension of the attention key (equals the embedded dimensions).
    filter_num_MLP: a list that defines the number of nodes for each MLP layer.
                    For the last MLP layer, its number of nodes must equal the dimension of the key.
    activation: activation of MLP nodes.
    name: prefix of the created keras layers.

    Output
    ----------
    V: output tensor.

    '''
    # ----- multi-headed self-attention (MSA) branch ----- #
    attention_branch = LayerNormalization(name='{}_layer_norm_1'.format(name))(V)
    attention_branch = MultiHeadAttention(num_heads=num_heads, key_dim=key_dim,
                                          name='{}_atten'.format(name))(attention_branch, attention_branch)
    # residual connection around MSA
    V_after_attention = add([attention_branch, V], name='{}_skip_1'.format(name))

    # ----- MLP branch ----- #
    mlp_branch = LayerNormalization(name='{}_layer_norm_2'.format(name))(V_after_attention)
    mlp_branch = ViT_MLP(mlp_branch, filter_num_MLP, activation, name='{}_mlp'.format(name))

    # residual connection around the MLP
    return add([mlp_branch, V_after_attention], name='{}_skip_2'.format(name))
90 |
91 |
def transunet_2d_base(input_tensor, filter_num, stack_num_down=2, stack_num_up=2,
                      embed_dim=768, num_mlp=3072, num_heads=12, num_transformer=12,
                      activation='ReLU', mlp_activation='GELU', batch_norm=False, pool=True, unpool=True,
                      backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='transunet'):
    '''
    The base of transUNET with an optional ImageNet-trained backbone:
    a UNet-like CNN encoder, a stack of ViT blocks at the bottleneck,
    and a UNet-like CNN decoder.

    ----------
    Chen, J., Lu, Y., Yu, Q., Luo, X., Adeli, E., Wang, Y., Lu, L., Yuille, A.L. and Zhou, Y., 2021.
    Transunet: Transformers make strong encoders for medical image segmentation. arXiv preprint arXiv:2102.04306.

    Input
    ----------
    input_tensor: the input tensor of the base, e.g., `keras.layers.Input((128, 128, 3))`.
                  * The spatial size must be fully defined (not None) and square,
                    with each side divisible by `2**(len(filter_num)-1)`; only
                    `shape[1]` is read and the ViT output is reshaped to a square map.
    filter_num: a list that defines the number of filters for each \
                down- and upsampling levels. e.g., `[64, 128, 256, 512]`.
                The depth is expected as `len(filter_num)`.
    stack_num_down: number of convolutional layers per downsampling level/block.
    stack_num_up: number of convolutional layers (after concatenation) per upsampling level/block.
    activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
    batch_norm: True for batch normalization.
    pool: True or 'max' for MaxPooling2D.
          'ave' for AveragePooling2D.
          False for strided conv + batch norm + activation.
    unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
            'nearest' for Upsampling2D with nearest interpolation.
            False for Conv2DTranspose + batch norm + activation.
    name: prefix of the created keras model and its layers.

    ---------- (keywords of ViT) ----------
    embed_dim: number of embedded dimensions.
    num_mlp: number of MLP nodes.
    num_heads: number of attention heads.
    num_transformer: number of stacked ViTs.
    mlp_activation: activation of MLP nodes.

    ---------- (keywords of backbone options) ----------
    backbone_name: the backbone model name. Should be one of the `tensorflow.keras.applications` class.
                   None (default) means no backbone.
                   Currently supported backbones are:
                   (1) VGG16, VGG19
                   (2) ResNet50, ResNet101, ResNet152
                   (3) ResNet50V2, ResNet101V2, ResNet152V2
                   (4) DenseNet121, DenseNet169, DenseNet201
                   (5) EfficientNetB[0-7]
    weights: one of None (random initialization), 'imagenet' (pre-training on ImageNet),
             or the path to the weights file to be loaded.
    freeze_backbone: True for a frozen backbone.
    freeze_batch_norm: False for not freezing batch normalization layers.

    Output
    ----------
    X: output tensor.

    '''
    # NOTE(review): activation_func is never used in this function; the activation
    # string itself is forwarded to CONV_stack / UNET_left / UNET_right below.
    # The eval still serves as an early sanity check of the activation name.
    activation_func = eval(activation)

    X_skip = []
    depth_ = len(filter_num)

    # ----- internal parameters ----- #

    # patch size (fixed to 1-by-1, i.e., each bottleneck pixel becomes one token)
    patch_size = 1

    # input tensor size
    # NOTE(review): only shape[1] is read, so the input is assumed square
    # (height == width) with a statically known size -- confirm against callers.
    input_size = input_tensor.shape[1]

    # encoded feature map size after (depth_-1) halvings
    encode_size = input_size // 2**(depth_-1)

    # number of size-1 patches (tokens) fed into the ViT stack
    num_patches = encode_size ** 2

    # dimension of the attention key (= dimension of embeddings)
    key_dim = embed_dim

    # number of MLP nodes: hidden layer, then projection back to embed_dim
    filter_num_MLP = [num_mlp, embed_dim]

    # ----- UNet-like downsampling ----- #

    # no backbone cases
    if backbone is None:

        X = input_tensor

        # stacked conv2d before downsampling
        X = CONV_stack(X, filter_num[0], stack_num=stack_num_down, activation=activation,
                       batch_norm=batch_norm, name='{}_down0'.format(name))
        X_skip.append(X)

        # downsampling blocks
        for i, f in enumerate(filter_num[1:]):
            X = UNET_left(X, f, stack_num=stack_num_down, activation=activation, pool=pool,
                          batch_norm=batch_norm, name='{}_down{}'.format(name, i+1))
            X_skip.append(X)

    # backbone cases
    else:
        # handling VGG16 and VGG19 separately:
        # VGG backbones expose `depth_` usable feature maps directly
        if 'VGG' in backbone:
            backbone_ = backbone_zoo(backbone, weights, input_tensor, depth_, freeze_backbone, freeze_batch_norm)
            # collecting backbone feature maps
            X_skip = backbone_([input_tensor,])
            depth_encode = len(X_skip)

        # for other backbones: one fewer level is requested because
        # the first "level" is the raw-resolution input itself
        else:
            backbone_ = backbone_zoo(backbone, weights, input_tensor, depth_-1, freeze_backbone, freeze_batch_norm)
            # collecting backbone feature maps
            X_skip = backbone_([input_tensor,])
            depth_encode = len(X_skip) + 1


        # extra conv2d blocks are applied
        # if downsampling levels of a backbone < user-specified downsampling levels
        if depth_encode < depth_:

            # begins at the deepest available tensor
            X = X_skip[-1]

            # extra downsamplings
            for i in range(depth_-depth_encode):
                i_real = i + depth_encode

                X = UNET_left(X, filter_num[i_real], stack_num=stack_num_down, activation=activation, pool=pool,
                              batch_norm=batch_norm, name='{}_down{}'.format(name, i_real+1))
                X_skip.append(X)

    # detach the last (deepest) tensor; it will be replaced by the ViT output
    X = X_skip[-1]
    X_skip = X_skip[:-1]

    # 1-by-1 linear transformation before entering ViT blocks
    X = Conv2D(filter_num[-1], 1, padding='valid', use_bias=False, name='{}_conv_trans_before'.format(name))(X)

    # tokenize: extract 1-by-1 patches and embed them
    X = patch_extract((patch_size, patch_size))(X)
    X = patch_embedding(num_patches, embed_dim)(X)

    # stacked ViTs
    for i in range(num_transformer):
        X = ViT_block(X, num_heads, key_dim, filter_num_MLP, activation=mlp_activation,
                      name='{}_ViT_{}'.format(name, i))

    # reshape patches (tokens) back to a square feature map
    X = tf.reshape(X, (-1, encode_size, encode_size, embed_dim))

    # 1-by-1 linear transformation to adjust the number of channels
    X = Conv2D(filter_num[-1], 1, padding='valid', use_bias=False, name='{}_conv_trans_after'.format(name))(X)

    X_skip.append(X)

    # ----- UNet-like upsampling ----- #

    # reverse indexing encoded feature maps
    X_skip = X_skip[::-1]
    # upsampling begins at the deepest available tensor
    X = X_skip[0]
    # other tensors are preserved for concatenation
    X_decode = X_skip[1:]
    depth_decode = len(X_decode)

    # reverse indexing filter numbers
    filter_num_decode = filter_num[:-1][::-1]

    # upsampling with concatenation
    for i in range(depth_decode):
        X = UNET_right(X, [X_decode[i],], filter_num_decode[i], stack_num=stack_num_up, activation=activation,
                       unpool=unpool, batch_norm=batch_norm, name='{}_up{}'.format(name, i))

    # if tensors for concatenation are not enough
    # then use upsampling without concatenation
    if depth_decode < depth_-1:
        for i in range(depth_-depth_decode-1):
            i_real = i + depth_decode
            X = UNET_right(X, None, filter_num_decode[i_real], stack_num=stack_num_up, activation=activation,
                           unpool=unpool, batch_norm=batch_norm, concat=False, name='{}_up{}'.format(name, i_real))

    return X
272 |
def transunet_2d(input_size, filter_num, n_labels, stack_num_down=2, stack_num_up=2,
                 embed_dim=768, num_mlp=3072, num_heads=12, num_transformer=12,
                 activation='ReLU', mlp_activation='GELU', output_activation='Softmax', batch_norm=False, pool=True, unpool=True,
                 backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='transunet'):
    '''
    TransUNET with an optional ImageNet-trained backbone.

    ----------
    Chen, J., Lu, Y., Yu, Q., Luo, X., Adeli, E., Wang, Y., Lu, L., Yuille, A.L. and Zhou, Y., 2021.
    Transunet: Transformers make strong encoders for medical image segmentation. arXiv preprint arXiv:2102.04306.

    Input
    ----------
    input_size: the size/shape of network input, e.g., `(128, 128, 3)`.
    filter_num: a list that defines the number of filters for each \
                down- and upsampling levels. e.g., `[64, 128, 256, 512]`.
                The depth is expected as `len(filter_num)`.
    n_labels: number of output labels.
    stack_num_down: number of convolutional layers per downsampling level/block.
    stack_num_up: number of convolutional layers (after concatenation) per upsampling level/block.
    activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
    output_activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interface or 'Sigmoid'.
                       Default option is 'Softmax'.
                       if None is received, then linear activation is applied.
    batch_norm: True for batch normalization.
    pool: True or 'max' for MaxPooling2D.
          'ave' for AveragePooling2D.
          False for strided conv + batch norm + activation.
    unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
            'nearest' for Upsampling2D with nearest interpolation.
            False for Conv2DTranspose + batch norm + activation.
    name: prefix of the created keras model and its layers.

    ---------- (keywords of ViT) ----------
    embed_dim: number of embedded dimensions.
    num_mlp: number of MLP nodes.
    num_heads: number of attention heads.
    num_transformer: number of stacked ViTs.
    mlp_activation: activation of MLP nodes.

    ---------- (keywords of backbone options) ----------
    backbone_name: the backbone model name. Should be one of the `tensorflow.keras.applications` class.
                   None (default) means no backbone.
                   Currently supported backbones are:
                   (1) VGG16, VGG19
                   (2) ResNet50, ResNet101, ResNet152
                   (3) ResNet50V2, ResNet101V2, ResNet152V2
                   (4) DenseNet121, DenseNet169, DenseNet201
                   (5) EfficientNetB[0-7]
    weights: one of None (random initialization), 'imagenet' (pre-training on ImageNet),
             or the path to the weights file to be loaded.
    freeze_backbone: True for a frozen backbone.
    freeze_batch_norm: False for not freezing batch normalization layers.

    Output
    ----------
    model: a keras model.

    '''
    # (fix) removed the unused local `activation_func = eval(activation)`:
    # the activation string is forwarded to `transunet_2d_base`, which
    # resolves and validates it itself.

    IN = Input(input_size)

    # base: CNN encoder + stacked ViTs at the bottleneck + CNN decoder
    X = transunet_2d_base(IN, filter_num, stack_num_down=stack_num_down, stack_num_up=stack_num_up,
                          embed_dim=embed_dim, num_mlp=num_mlp, num_heads=num_heads, num_transformer=num_transformer,
                          activation=activation, mlp_activation=mlp_activation, batch_norm=batch_norm, pool=pool, unpool=unpool,
                          backbone=backbone, weights=weights, freeze_backbone=freeze_backbone, freeze_batch_norm=freeze_batch_norm, name=name)

    # output layer (1-by-1 conv classification head)
    OUT = CONV_output(X, n_labels, kernel_size=1, activation=output_activation, name='{}_output'.format(name))

    # functional API model
    model = Model(inputs=[IN,], outputs=[OUT,], name='{}_model'.format(name))

    return model
351 |
--------------------------------------------------------------------------------
/keras_unet_collection/_model_u2net_2d.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | from keras_unet_collection.layer_utils import *
5 | from keras_unet_collection.activations import GELU, Snake
6 |
7 | from tensorflow.keras.layers import Input
8 | from tensorflow.keras.models import Model
9 |
10 |
def RSU(X, channel_in, channel_out, depth=5, activation='ReLU', batch_norm=True, pool=True, unpool=True, name='RSU'):
    '''
    Residual U-block (RSU): a small U-Net-shaped encoder/decoder with a
    block-level residual connection from the input transform to the output.

    RSU(X, channel_in, channel_out, depth=5, activation='ReLU', batch_norm=True, pool=True, unpool=True, name='RSU')

    ----------
    Qin, X., Zhang, Z., Huang, C., Dehghan, M., Zaiane, O.R. and Jagersand, M., 2020.
    U2-Net: Going deeper with nested U-structure for salient object detection.
    Pattern Recognition, 106, p.107404.

    Input
    ----------
    X: input tensor.
    channel_in: number of intermediate channels.
    channel_out: number of output channels.
    depth: number of down- and upsampling levels.
    activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
    batch_norm: True for batch normalization, False otherwise.
    pool: True or 'max' for MaxPooling2D.
          'ave' for AveragePooling2D.
          False for strided conv + batch norm + activation.
    unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
            'nearest' for Upsampling2D with nearest interpolation.
            False for Conv2DTranspose + batch norm + activation.
    name: prefix of the created keras layers.

    Output
    ----------
    X: output tensor.

    '''
    pool_size = 2
    skips = []

    # input transform to `channel_out`; kept for the block-level residual add
    x = CONV_stack(X, channel_out, kernel_size=3, stack_num=1,
                   dilation_rate=1, activation=activation, batch_norm=batch_norm,
                   name='{}_in'.format(name))
    skips.append(x)

    # first encoder conv (no downsampling yet)
    x = CONV_stack(x, channel_in, kernel_size=3, stack_num=1, dilation_rate=1,
                   activation=activation, batch_norm=batch_norm, name='{}_down_0'.format(name))
    skips.append(x)

    # encoder: downsample then conv, collecting skip tensors per level
    for level in range(depth):
        x = encode_layer(x, channel_in, pool_size, pool, activation=activation,
                         batch_norm=batch_norm, name='{}_encode_{}'.format(name, level))
        x = CONV_stack(x, channel_in, kernel_size=3, stack_num=1, dilation_rate=1,
                       activation=activation, batch_norm=batch_norm, name='{}_down_{}'.format(name, level+1))
        skips.append(x)

    # bottom: dilated conv (rate 2) without further downsampling
    x = CONV_stack(x, channel_in, kernel_size=3, stack_num=1,
                   dilation_rate=2, activation=activation, batch_norm=batch_norm,
                   name='{}_up_0'.format(name))

    # reverse so skips are consumed from deepest to shallowest
    skips = skips[::-1]

    # decoder: concat with skip, conv, then upsample
    for level in range(depth):
        x = concatenate([x, skips[level]], axis=-1, name='{}_concat_{}'.format(name, level))
        x = CONV_stack(x, channel_in, kernel_size=3, stack_num=1, dilation_rate=1,
                       activation=activation, batch_norm=batch_norm, name='{}_up_{}'.format(name, level+1))
        x = decode_layer(x, channel_in, pool_size, unpool,
                         activation=activation, batch_norm=batch_norm, name='{}_decode_{}'.format(name, level))

    # final concat with the first encoder conv output, then output conv
    x = concatenate([x, skips[depth]], axis=-1, name='{}_concat_out'.format(name))
    x = CONV_stack(x, channel_out, kernel_size=3, stack_num=1, dilation_rate=1,
                   activation=activation, batch_norm=batch_norm, name='{}_out'.format(name))

    # block-level residual connection with the input transform
    return add([x, skips[-1]], name='{}_out_add'.format(name))
88 |
def RSU4F(X, channel_in, channel_out, dilation_num=[1, 2, 4, 8], activation='ReLU', batch_norm=True, name='RSU4F'):
    '''
    Residual U-block with dilated convolutional kernels (RSU4F):
    like RSU, but the receptive field grows via dilation rates instead of
    spatial down/upsampling.

    RSU4F(X, channel_in, channel_out, dilation_num=[1, 2, 4, 8], activation='ReLU', batch_norm=True, name='RSU4F')

    ----------
    Qin, X., Zhang, Z., Huang, C., Dehghan, M., Zaiane, O.R. and Jagersand, M., 2020.
    U2-Net: Going deeper with nested U-structure for salient object detection.
    Pattern Recognition, 106, p.107404.

    Input
    ----------
    X: input tensor.
    channel_in: number of intermediate channels.
    channel_out: number of output channels.
    dilation_num: an iterable that defines dilation rates of convolutional layers.
                  Qin et al. (2020) suggested `[1, 2, 4, 8]`.
    activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
    batch_norm: True for batch normalization, False otherwise.
    name: prefix of the created keras layers.

    Output
    ----------
    X: output tensor

    '''
    skips = []

    # input transform to `channel_out`; kept for the block-level residual add
    x = CONV_stack(X, channel_out, kernel_size=3, stack_num=1, dilation_rate=1,
                   activation=activation, batch_norm=batch_norm, name='{}_in'.format(name))
    skips.append(x)

    # "encoder": stacked convs with increasing dilation (no spatial downsampling)
    for idx, rate in enumerate(dilation_num):
        x = CONV_stack(x, channel_in, kernel_size=3, stack_num=1, dilation_rate=rate,
                       activation=activation, batch_norm=batch_norm, name='{}_down_{}'.format(name, idx))
        skips.append(x)

    # drop the deepest tensor (it is `x` itself) and reverse for decoding
    skips = skips[:-1][::-1]
    dilation_num = dilation_num[:-1][::-1]

    # "decoder": concat with skips, convs with decreasing dilation
    for idx, rate in enumerate(dilation_num[:-1]):
        x = concatenate([x, skips[idx]], axis=-1, name='{}_concat_{}'.format(name, idx))
        x = CONV_stack(x, channel_in, kernel_size=3, stack_num=1, dilation_rate=rate,
                       activation=activation, batch_norm=batch_norm, name='{}_up_{}'.format(name, idx))

    # NOTE(review): the fixed index `skips[2]` assumes len(dilation_num) == 4
    # as in the default -- confirm before passing a different-length iterable.
    x = concatenate([x, skips[2]], axis=-1, name='{}_concat_out'.format(name))
    x = CONV_stack(x, channel_out, kernel_size=3, stack_num=1, dilation_rate=1,
                   activation=activation, batch_norm=batch_norm, name='{}_out'.format(name))

    # block-level residual connection with the input transform
    return add([x, skips[-1]], name='{}_out_add'.format(name))
142 |
def u2net_2d_base(input_tensor,
                  filter_num_down, filter_num_up,
                  filter_mid_num_down, filter_mid_num_up,
                  filter_4f_num, filter_4f_mid_num, activation='ReLU',
                  batch_norm=False, pool=True, unpool=True, name='u2net'):

    '''
    The base of U^2-Net: a nested U-structure of RSU blocks (with spatial
    down/upsampling) at the shallow levels and RSU4F blocks (dilation-based)
    at the deep levels.

    u2net_2d_base(input_tensor,
                  filter_num_down, filter_num_up,
                  filter_mid_num_down, filter_mid_num_up,
                  filter_4f_num, filter_4f_mid_num, activation='ReLU',
                  batch_norm=False, pool=True, unpool=True, name='u2net')

    ----------
    Qin, X., Zhang, Z., Huang, C., Dehghan, M., Zaiane, O.R. and Jagersand, M., 2020.
    U2-Net: Going deeper with nested U-structure for salient object detection.
    Pattern Recognition, 106, p.107404.

    Input
    ----------
    input_tensor: the input tensor of the base, e.g., keras.layers.Input((None, None, 3))
    filter_num_down: a list that defines the number of RSU output filters for each
                     downsampling level. e.g., `[64, 128, 256, 512]`.
                     the network depth is expected as `len(filter_num_down) + len(filter_4f_num)`
    filter_num_up: a list that defines the number of RSU output filters for each
                   upsampling level; consumed deepest-first after internal reversal.
    filter_mid_num_down: a list that defines the number of RSU intermediate filters for each
                         downsampling level. e.g., `[16, 32, 64, 128]`.
                         * RSU intermediate and output filters must be paired, i.e., lists with the same length.
                         * RSU intermediate filter numbers are expected to be smaller than output filter numbers.
    filter_mid_num_up: a list that defines the number of RSU intermediate filters for each
                       upsampling level. e.g., `[16, 32, 64, 128]`.
                       * RSU intermediate and output filters must be paired, i.e., lists with the same length.
                       * RSU intermediate filter numbers are expected to be smaller than output filter numbers.
    filter_4f_num: a list that defines the number of RSU-4F output filters for each
                   downsampling and bottom level. e.g., `[512, 512]`.
                   the network depth is expected as `len(filter_num_down) + len(filter_4f_num)`.
    filter_4f_mid_num: a list that defines the number of RSU-4F intermediate filters for each
                       downsampling and bottom level. e.g., `[256, 256]`.
                       * RSU-4F intermediate and output filters must be paired, i.e., lists with the same length.
                       * RSU-4F intermediate filter numbers are expected to be smaller than output filter numbers.
    activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
    batch_norm: True for batch normalization.
    pool: True or 'max' for MaxPooling2D.
          'ave' for AveragePooling2D.
          False for strided conv + batch norm + activation.
    unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
            'nearest' for Upsampling2D with nearest interpolation.
            False for Conv2DTranspose + batch norm + activation.
    name: prefix of the created keras layers.

    Output
    ----------
    A list of tensors with the first/second/third tensor obtained from
    the deepest/second deepest/third deepest upsampling block, etc.
    * The feature map sizes of these tensors are different,
      with the first tensor having the smallest size.

    * Dilation rates of RSU4F layers are fixed to `[1, 2, 4, 8]`.
    * Downsampling is achieved through maxpooling in Qin et al. (2020),
      and can be replaced by strided convolutional layers here.
    * Upsampling is achieved through bilinear interpolation in Qin et al. (2020),
      and can be replaced by transpose convolutional layers here.

    '''

    pool_size = 2

    # NOTE(review): OUT_stack is never used in this function.
    X_skip = []; X_out = []; OUT_stack = []
    # depth_backup records the RSU depth per encoder level so the decoder can mirror it
    depth_backup = []
    depth_ = len(filter_num_down)

    X = input_tensor

    # input-level RSU (full resolution, deepest internal U)
    X = RSU(X, filter_mid_num_down[0], filter_num_down[0], depth=depth_+1, activation=activation,
            batch_norm=batch_norm, pool=pool, unpool=unpool, name='{}_in'.format(name))
    X_skip.append(X)

    depth_backup.append(depth_+1)

    # RSU encoder levels: downsample, then an RSU whose internal depth shrinks per level
    for i, f in enumerate(filter_num_down[1:]):

        X = encode_layer(X, f, pool_size, pool, activation=activation,
                         batch_norm=batch_norm, name='{}_encode_{}'.format(name, i))

        X = RSU(X, filter_mid_num_down[i+1], f, depth=depth_-i, activation=activation,
                batch_norm=batch_norm, pool=pool, unpool=unpool, name='{}_down_{}'.format(name, i))

        depth_backup.append(depth_-i)

        X_skip.append(X)

    # RSU-4F encoder/bottom levels (dilated convs instead of down/upsampling)
    for i, f in enumerate(filter_4f_num):

        X = encode_layer(X, f, pool_size, pool, activation=activation,
                         batch_norm=batch_norm, name='{}_encode_4f_{}'.format(name, i))

        X = RSU4F(X, filter_4f_mid_num[i], f, activation=activation,
                  batch_norm=batch_norm, name='{}_down_4f_{}'.format(name, i))
        X_skip.append(X)

    # the deepest tensor is the first decoder-side output
    X_out.append(X)

    # ---------- decoder ---------- #
    # drop the deepest tensor (it is X itself) and reverse everything
    # so skips/hyperparameters are consumed deepest-first
    X_skip = X_skip[:-1][::-1]
    depth_backup = depth_backup[::-1]

    filter_num_up = filter_num_up[::-1]
    filter_mid_num_up = filter_mid_num_up[::-1]

    filter_4f_num = filter_4f_num[:-1][::-1]
    filter_4f_mid_num = filter_4f_mid_num[:-1][::-1]

    # tensor_count walks through X_skip across both decoder loops
    tensor_count = 0
    # RSU-4F decoder levels: upsample, concat with skip, RSU-4F
    for i, f in enumerate(filter_4f_num):

        X = decode_layer(X, f, pool_size, unpool,
                         activation=activation, batch_norm=batch_norm, name='{}_decode_4f_{}'.format(name, i))

        X = concatenate([X, X_skip[tensor_count]], axis=-1, name='{}_concat_4f_{}'.format(name, i))

        X = RSU4F(X, filter_4f_mid_num[i], f, activation=activation,
                  batch_norm=batch_norm, name='{}_up_4f_{}'.format(name, i))
        X_out.append(X)

        tensor_count += 1

    # RSU decoder levels: upsample, concat with skip, RSU with mirrored depth
    for i, f in enumerate(filter_num_up):

        X = decode_layer(X, f, pool_size, unpool,
                         activation=activation, batch_norm=batch_norm, name='{}_decode_{}'.format(name, i))

        X = concatenate([X, X_skip[tensor_count]], axis=-1, name='{}_concat_{}'.format(name, i))

        X = RSU(X, filter_mid_num_up[i], f, depth=depth_backup[i],
                activation=activation, batch_norm=batch_norm, pool=pool, unpool=unpool, name='{}_up_{}'.format(name, i))
        X_out.append(X)

        tensor_count += 1

    # outputs ordered deepest (smallest feature map) to shallowest
    return X_out
285 |
286 |
def u2net_2d(input_size, n_labels, filter_num_down, filter_num_up='auto', filter_mid_num_down='auto', filter_mid_num_up='auto',
             filter_4f_num='auto', filter_4f_mid_num='auto', activation='ReLU', output_activation='Sigmoid',
             batch_norm=False, pool=True, unpool=True, deep_supervision=False, name='u2net'):

    '''
    U^2-Net

    u2net_2d(input_size, n_labels, filter_num_down, filter_num_up='auto', filter_mid_num_down='auto', filter_mid_num_up='auto',
             filter_4f_num='auto', filter_4f_mid_num='auto', activation='ReLU', output_activation='Sigmoid',
             batch_norm=False, pool=True, unpool=True, deep_supervision=False, name='u2net')

    ----------
    Qin, X., Zhang, Z., Huang, C., Dehghan, M., Zaiane, O.R. and Jagersand, M., 2020.
    U2-Net: Going deeper with nested U-structure for salient object detection.
    Pattern Recognition, 106, p.107404.

    Input
    ----------
    input_size: the size/shape of network input, e.g., `(128, 128, 3)`.
    n_labels: number of output labels.
    filter_num_down: a list that defines the number of RSU output filters for each
                     downsampling level. e.g., `[64, 128, 256, 512]`.
                     the network depth is expected as `len(filter_num_down) + len(filter_4f_num)`
    filter_num_up: a list that defines the number of RSU output filters for each
                   upsampling level. Mirrors `filter_num_down` when set to 'auto'.
    filter_mid_num_down: a list that defines the number of RSU intermediate filters for each
                         downsampling level. e.g., `[16, 32, 64, 128]`.
                         * RSU intermediate and output filters must paired, i.e., list with the same length.
                         * RSU intermediate filters numbers are expected to be smaller than output filters numbers.
    filter_mid_num_up: a list that defines the number of RSU intermediate filters for each
                       upsampling level. e.g., `[16, 32, 64, 128]`.
                       * RSU intermediate and output filters must paired, i.e., list with the same length.
                       * RSU intermediate filters numbers are expected to be smaller than output filters numbers.
    filter_4f_num: a list that defines the number of RSU-4F output filters for each
                   downsampling and bottom level. e.g., `[512, 512]`.
                   the network depth is expected as `len(filter_num_down) + len(filter_4f_num)`.
    filter_4f_mid_num: a list that defines the number of RSU-4F intermediate filters for each
                       downsampling and bottom level. e.g., `[256, 256]`.
                       * RSU-4F intermediate and output filters must paired, i.e., list with the same length.
                       * RSU-4F intermediate filters numbers are expected to be smaller than output filters numbers.
    activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
    output_activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interface or 'Sigmoid'.
                       Default option is 'Sigmoid'.
                       if None is received, then linear activation is applied.
    batch_norm: True for batch normalization.
    pool: True or 'max' for MaxPooling2D.
          'ave' for AveragePooling2D.
          False for strided conv + batch norm + activation.
    unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
            'nearest' for Upsampling2D with nearest interpolation.
            False for Conv2DTranspose + batch norm + activation.
    deep_supervision: True for a model that supports deep supervision. Details see Qin et al. (2020).
    name: prefix of the created keras layers.

    Output
    ----------
    model: a keras model.

    * Automated hyper-parameter estimation will produce a slightly larger network, different from that of Qin et al. (2020).
    * Dilation rates of RSU4F layers are fixed to `[1, 2, 4, 8]`.
    * The default output activation is sigmoid, the same as in Qin et al. (2020).
    * Downsampling is achieved through maxpooling and can be replaced by strided convolutional layers.
    * Upsampling is achieved through bilinear interpolation and can be replaced by transpose convolutional layers.

    '''

    verbose = False

    # 'auto' hyper-parameters: mirror the downsampling configuration and
    # halve/quarter the channel counts for the intermediate filters
    if filter_num_up == 'auto':
        verbose = True
        filter_num_up = filter_num_down

    if filter_mid_num_down == 'auto':
        verbose = True
        filter_mid_num_down = [num//4 for num in filter_num_down]

    if filter_mid_num_up == 'auto':
        verbose = True
        filter_mid_num_up = filter_mid_num_down

    if filter_4f_num == 'auto':
        verbose = True
        filter_4f_num = [filter_num_down[-1], filter_num_down[-1]]

    if filter_4f_mid_num == 'auto':
        verbose = True
        filter_4f_mid_num = [num//2 for num in filter_4f_num]

    if verbose:
        print('Automated hyper-parameter determination is applied with the following details:\n----------')
        print('\tNumber of RSU output channels within downsampling blocks: filter_num_down = {}'.format(filter_num_down))
        print('\tNumber of RSU intermediate channels within downsampling blocks: filter_mid_num_down = {}'.format(filter_mid_num_down))
        print('\tNumber of RSU output channels within upsampling blocks: filter_num_up = {}'.format(filter_num_up))
        print('\tNumber of RSU intermediate channels within upsampling blocks: filter_mid_num_up = {}'.format(filter_mid_num_up))
        print('\tNumber of RSU-4F output channels within downsampling and bottom blocks: filter_4f_num = {}'.format(filter_4f_num))
        # BUGFIX: this line previously labeled the value as "filter_4f_num"
        # while printing filter_4f_mid_num
        print('\tNumber of RSU-4F intermediate channels within downsampling and bottom blocks: filter_4f_mid_num = {}'.format(filter_4f_mid_num))
        print('----------\nExplicitly specifying keywords listed above if their "auto" settings do not satisfy your needs')

    print("----------\nThe depth of u2net_2d = len(filter_num_down) + len(filter_4f_num) = {}".format(len(filter_num_down)+len(filter_4f_num)))

    # supervision outputs, shallow -> deep
    # (unused locals `X_skip`, `depth_backup`, `depth_` removed)
    OUT_stack = []

    IN = Input(shape=input_size)

    # base (before conv + activation + upsample)
    X_out = u2net_2d_base(IN,
                          filter_num_down, filter_num_up,
                          filter_mid_num_down, filter_mid_num_up,
                          filter_4f_num, filter_4f_mid_num, activation=activation,
                          batch_norm=batch_norm, pool=pool, unpool=unpool, name=name)

    # output layers; reorder decoder tensors shallow -> deep
    X_out = X_out[::-1]
    L_out = len(X_out)

    # the shallowest supervision head is already at full resolution
    X = CONV_output(X_out[0], n_labels, kernel_size=3, activation=output_activation,
                    name='{}_output_sup0'.format(name))
    OUT_stack.append(X)

    # deeper heads: conv + upsample back to full resolution (+ activation)
    for i in range(1, L_out):

        pool_size = 2**(i)

        X = Conv2D(n_labels, 3, padding='same', name='{}_output_conv_{}'.format(name, i))(X_out[i])

        X = decode_layer(X, n_labels, pool_size, unpool,
                         activation=None, batch_norm=False, name='{}_sup{}'.format(name, i))

        if output_activation:
            if output_activation == 'Sigmoid':
                X = Activation('sigmoid', name='{}_output_sup{}_activation'.format(name, i))(X)
            else:
                # NOTE: `eval` on a caller-supplied string; do not pass
                # untrusted values as `output_activation`
                activation_func = eval(output_activation)
                X = activation_func(name='{}_output_sup{}_activation'.format(name, i))(X)

        OUT_stack.append(X)

    # the final output fuses all supervision heads
    D = concatenate(OUT_stack, axis=-1, name='{}_output_concat'.format(name))

    D = CONV_output(D, n_labels, kernel_size=1, activation=output_activation,
                    name='{}_output_final'.format(name))

    if deep_supervision:

        OUT_stack.append(D)
        print('----------\ndeep_supervision = True\nnames of output tensors are listed as follows ("sup0" is the shallowest supervision layer;\n"final" is the final output layer):\n')

        if output_activation == None:
            if unpool is False:
                for i in range(L_out):
                    print('\t{}_output_sup{}_trans_conv'.format(name, i))
            else:
                for i in range(L_out):
                    print('\t{}_output_sup{}_unpool'.format(name, i))

            print('\t{}_output_final'.format(name))

        else:
            for i in range(L_out):
                print('\t{}_output_sup{}_activation'.format(name, i))

            print('\t{}_output_final_activation'.format(name))

        model = Model([IN,], OUT_stack)

    else:
        model = Model([IN,], [D,])

    return model
455 |
--------------------------------------------------------------------------------
/keras_unet_collection/_model_unet_2d.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | from keras_unet_collection.layer_utils import *
5 | from keras_unet_collection.activations import GELU, Snake
6 | from keras_unet_collection._backbone_zoo import backbone_zoo, bach_norm_checker
7 |
8 | from tensorflow.keras.layers import Input
9 | from tensorflow.keras.models import Model
10 |
def UNET_left(X, channel, kernel_size=3, stack_num=2, activation='ReLU',
              pool=True, batch_norm=False, name='left0'):
    '''
    The encoder block of U-net: one downsampling step followed by a
    stack of convolutional layers.

    UNET_left(X, channel, kernel_size=3, stack_num=2, activation='ReLU',
              pool=True, batch_norm=False, name='left0')

    Input
    ----------
        X: input tensor.
        channel: number of convolution filters.
        kernel_size: size of 2-d convolution kernels.
        stack_num: number of convolutional layers.
        activation: one of the `tensorflow.keras.layers` interface, e.g., 'ReLU'.
        pool: True or 'max' for MaxPooling2D.
              'ave' for AveragePooling2D.
              False for strided conv + batch norm + activation.
        batch_norm: True for batch normalization, False otherwise.
        name: prefix of the created keras layers.

    Output
    ----------
        X: output tensor.

    '''
    # every encoder level halves the spatial dimensions
    pool_size = 2

    # downsampling (maxpool / average pool / strided conv, per `pool`)
    X_down = encode_layer(X, channel, pool_size, pool, activation=activation,
                          batch_norm=batch_norm, name='{}_encode'.format(name))

    # feature extraction at the reduced resolution
    X_out = CONV_stack(X_down, channel, kernel_size, stack_num=stack_num,
                       activation=activation, batch_norm=batch_norm,
                       name='{}_conv'.format(name))

    return X_out
46 |
47 |
def UNET_right(X, X_list, channel, kernel_size=3,
               stack_num=2, activation='ReLU',
               unpool=True, batch_norm=False, concat=True, name='right0'):

    '''
    The decoder block of U-net: one upsampling step, an optional skip
    concatenation, and a stack of convolutional layers.

    Input
    ----------
        X: input tensor.
        X_list: a list of other tensors that connected to the input tensor.
        channel: number of convolution filters.
        kernel_size: size of 2-d convolution kernels.
        stack_num: number of convolutional layers.
        activation: one of the `tensorflow.keras.layers` interface, e.g., 'ReLU'.
        unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
                'nearest' for Upsampling2D with nearest interpolation.
                False for Conv2DTranspose + batch norm + activation.
        batch_norm: True for batch normalization, False otherwise.
        concat: True for concatenating the corresponded X_list elements.
        name: prefix of the created keras layers.

    Output
    ----------
        X: output tensor.

    '''

    # every decoder level doubles the spatial dimensions
    pool_size = 2

    # upsampling (bilinear / nearest / transpose conv, per `unpool`)
    X_up = decode_layer(X, channel, pool_size, unpool,
                        activation=activation, batch_norm=batch_norm,
                        name='{}_decode'.format(name))

    # a single conv layer before the (optional) skip fusion
    X_up = CONV_stack(X_up, channel, kernel_size, stack_num=1,
                      activation=activation, batch_norm=batch_norm,
                      name='{}_conv_before_concat'.format(name))

    # fuse the upsampled tensor with the skip-connection tensors
    if concat:
        X_up = concatenate([X_up,]+X_list, axis=3, name=name+'_concat')

    # stacked convolutions after the fusion
    return CONV_stack(X_up, channel, kernel_size, stack_num=stack_num,
                      activation=activation, batch_norm=batch_norm,
                      name=name+'_conv_after_concat')
93 |
def unet_2d_base(input_tensor, filter_num, stack_num_down=2, stack_num_up=2,
                 activation='ReLU', batch_norm=False, pool=True, unpool=True,
                 backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='unet'):

    '''
    The base of U-net with an optional ImageNet-trained backbone.

    unet_2d_base(input_tensor, filter_num, stack_num_down=2, stack_num_up=2,
                 activation='ReLU', batch_norm=False, pool=True, unpool=True,
                 backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='unet')

    ----------
    Ronneberger, O., Fischer, P. and Brox, T., 2015, October. U-net: Convolutional networks for biomedical image segmentation.
    In International Conference on Medical image computing and computer-assisted intervention (pp. 234-241). Springer, Cham.

    Input
    ----------
    input_tensor: the input tensor of the base, e.g., `keras.layers.Input((None, None, 3))`.
    filter_num: a list that defines the number of filters for each \
                down- and upsampling levels. e.g., `[64, 128, 256, 512]`.
                The depth is expected as `len(filter_num)`.
    stack_num_down: number of convolutional layers per downsampling level/block.
    stack_num_up: number of convolutional layers (after concatenation) per upsampling level/block.
    activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
    batch_norm: True for batch normalization.
    pool: True or 'max' for MaxPooling2D.
          'ave' for AveragePooling2D.
          False for strided conv + batch norm + activation.
    unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
            'nearest' for Upsampling2D with nearest interpolation.
            False for Conv2DTranspose + batch norm + activation.
    name: prefix of the created keras model and its layers.

    ---------- (keywords of backbone options) ----------
    backbone: the backbone model name. Should be one of the `tensorflow.keras.applications` class.
              None (default) means no backbone.
              Currently supported backbones are:
              (1) VGG16, VGG19
              (2) ResNet50, ResNet101, ResNet152
              (3) ResNet50V2, ResNet101V2, ResNet152V2
              (4) DenseNet121, DenseNet169, DenseNet201
              (5) EfficientNetB[0-7]
    weights: one of None (random initialization), 'imagenet' (pre-training on ImageNet),
             or the path to the weights file to be loaded.
    freeze_backbone: True for a frozen backbone.
    freeze_batch_norm: False for not freezing batch normalization layers.

    Output
    ----------
    X: output tensor.

    '''

    # NOTE(review): the result is never used below; this line only acts as an
    # early check that `activation` resolves to an importable class.
    # `eval` on a caller-supplied string is unsafe for untrusted input.
    activation_func = eval(activation)

    X_skip = []               # encoder feature maps, shallow -> deep
    depth_ = len(filter_num)  # number of down-/upsampling levels

    # no backbone cases
    if backbone is None:

        X = input_tensor

        # stacked conv2d before downsampling
        X = CONV_stack(X, filter_num[0], stack_num=stack_num_down, activation=activation,
                       batch_norm=batch_norm, name='{}_down0'.format(name))
        X_skip.append(X)

        # downsampling blocks
        for i, f in enumerate(filter_num[1:]):
            X = UNET_left(X, f, stack_num=stack_num_down, activation=activation, pool=pool,
                          batch_norm=batch_norm, name='{}_down{}'.format(name, i+1))
            X_skip.append(X)

    # backbone cases
    else:
        # handling VGG16 and VGG19 separately
        # (VGG backbones are asked for `depth_` feature maps; all others for
        # `depth_-1`, with the count adjusted by +1 below)
        if 'VGG' in backbone:
            backbone_ = backbone_zoo(backbone, weights, input_tensor, depth_, freeze_backbone, freeze_batch_norm)
            # collecting backbone feature maps
            X_skip = backbone_([input_tensor,])
            depth_encode = len(X_skip)

        # for other backbones
        else:
            backbone_ = backbone_zoo(backbone, weights, input_tensor, depth_-1, freeze_backbone, freeze_batch_norm)
            # collecting backbone feature maps
            X_skip = backbone_([input_tensor,])
            depth_encode = len(X_skip) + 1


        # extra conv2d blocks are applied
        # if downsampling levels of a backbone < user-specified downsampling levels
        if depth_encode < depth_:

            # begins at the deepest available tensor
            X = X_skip[-1]

            # extra downsamplings
            for i in range(depth_-depth_encode):
                i_real = i + depth_encode

                X = UNET_left(X, filter_num[i_real], stack_num=stack_num_down, activation=activation, pool=pool,
                              batch_norm=batch_norm, name='{}_down{}'.format(name, i_real+1))
                X_skip.append(X)

    # reverse indexing encoded feature maps (deep -> shallow)
    X_skip = X_skip[::-1]
    # upsampling begins at the deepest available tensor
    X = X_skip[0]
    # other tensors are preserved for concatenation
    X_decode = X_skip[1:]
    depth_decode = len(X_decode)

    # reverse indexing filter numbers (bottom level excluded)
    filter_num_decode = filter_num[:-1][::-1]

    # upsampling with concatenation
    for i in range(depth_decode):
        X = UNET_right(X, [X_decode[i],], filter_num_decode[i], stack_num=stack_num_up, activation=activation,
                       unpool=unpool, batch_norm=batch_norm, name='{}_up{}'.format(name, i))

    # if tensors for concatenation is not enough
    # then use upsampling without concatenation
    # (happens when a backbone provided fewer skip tensors than `depth_-1`)
    if depth_decode < depth_-1:
        for i in range(depth_-depth_decode-1):
            i_real = i + depth_decode
            X = UNET_right(X, None, filter_num_decode[i_real], stack_num=stack_num_up, activation=activation,
                           unpool=unpool, batch_norm=batch_norm, concat=False, name='{}_up{}'.format(name, i_real))
    return X
224 |
def unet_2d(input_size, filter_num, n_labels, stack_num_down=2, stack_num_up=2,
            activation='ReLU', output_activation='Softmax', batch_norm=False, pool=True, unpool=True,
            backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='unet'):
    '''
    U-net with an optional ImageNet-trained backbone.

    unet_2d(input_size, filter_num, n_labels, stack_num_down=2, stack_num_up=2,
            activation='ReLU', output_activation='Softmax', batch_norm=False, pool=True, unpool=True,
            backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='unet')

    ----------
    Ronneberger, O., Fischer, P. and Brox, T., 2015, October. U-net: Convolutional networks for biomedical image segmentation.
    In International Conference on Medical image computing and computer-assisted intervention (pp. 234-241). Springer, Cham.

    Input
    ----------
    input_size: the size/shape of network input, e.g., `(128, 128, 3)`.
    filter_num: a list that defines the number of filters for each \
                down- and upsampling levels. e.g., `[64, 128, 256, 512]`.
                The depth is expected as `len(filter_num)`.
    n_labels: number of output labels.
    stack_num_down: number of convolutional layers per downsampling level/block.
    stack_num_up: number of convolutional layers (after concatenation) per upsampling level/block.
    activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
    output_activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interface or 'Sigmoid'.
                       Default option is 'Softmax'.
                       if None is received, then linear activation is applied.
    batch_norm: True for batch normalization.
    pool: True or 'max' for MaxPooling2D.
          'ave' for AveragePooling2D.
          False for strided conv + batch norm + activation.
    unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
            'nearest' for Upsampling2D with nearest interpolation.
            False for Conv2DTranspose + batch norm + activation.
    name: prefix of the created keras model and its layers.

    ---------- (keywords of backbone options) ----------
    backbone: the backbone model name. Should be one of the `tensorflow.keras.applications` class.
              None (default) means no backbone.
              Currently supported backbones are:
              (1) VGG16, VGG19
              (2) ResNet50, ResNet101, ResNet152
              (3) ResNet50V2, ResNet101V2, ResNet152V2
              (4) DenseNet121, DenseNet169, DenseNet201
              (5) EfficientNetB[0-7]
    weights: one of None (random initialization), 'imagenet' (pre-training on ImageNet),
             or the path to the weights file to be loaded.
    freeze_backbone: True for a frozen backbone.
    freeze_batch_norm: False for not freezing batch normalization layers.

    Output
    ----------
    model: a keras model.

    '''
    # batch norm and certain backbones are incompatible; check early
    if backbone is not None:
        bach_norm_checker(backbone, batch_norm)

    IN = Input(input_size)

    # base
    # BUGFIX: `freeze_batch_norm` was previously passed as `freeze_backbone`,
    # silently ignoring the user's `freeze_batch_norm` argument.
    X = unet_2d_base(IN, filter_num, stack_num_down=stack_num_down, stack_num_up=stack_num_up,
                     activation=activation, batch_norm=batch_norm, pool=pool, unpool=unpool,
                     backbone=backbone, weights=weights, freeze_backbone=freeze_backbone,
                     freeze_batch_norm=freeze_batch_norm, name=name)

    # output layer
    OUT = CONV_output(X, n_labels, kernel_size=1, activation=output_activation, name='{}_output'.format(name))

    # functional API model
    model = Model(inputs=[IN,], outputs=[OUT,], name='{}_model'.format(name))

    return model
300 |
--------------------------------------------------------------------------------
/keras_unet_collection/_model_unet_3plus_2d.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | from keras_unet_collection.layer_utils import *
5 | from keras_unet_collection.activations import GELU, Snake
6 | from keras_unet_collection._backbone_zoo import backbone_zoo, bach_norm_checker
7 | from keras_unet_collection._model_unet_2d import UNET_left, UNET_right
8 |
9 | from tensorflow.keras.layers import Input
10 | from tensorflow.keras.models import Model
11 |
def unet_3plus_2d_base(input_tensor, filter_num_down, filter_num_skip, filter_num_aggregate,
                       stack_num_down=2, stack_num_up=1, activation='ReLU', batch_norm=False, pool=True, unpool=True,
                       backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='unet3plus'):
    '''
    The base of UNET 3+ with an optional ImageNet-trained backbone.

    unet_3plus_2d_base(input_tensor, filter_num_down, filter_num_skip, filter_num_aggregate,
                       stack_num_down=2, stack_num_up=1, activation='ReLU', batch_norm=False, pool=True, unpool=True,
                       backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='unet3plus')

    ----------
    Huang, H., Lin, L., Tong, R., Hu, H., Zhang, Q., Iwamoto, Y., Han, X., Chen, Y.W. and Wu, J., 2020.
    UNet 3+: A Full-Scale Connected UNet for Medical Image Segmentation.
    In ICASSP 2020-2020 IEEE International Conference on Acoustics,
    Speech and Signal Processing (ICASSP) (pp. 1055-1059). IEEE.

    Input
    ----------
    input_tensor: the input tensor of the base, e.g., `keras.layers.Input((None, None, 3))`.
    filter_num_down: a list that defines the number of filters for each
                     downsampling level. e.g., `[64, 128, 256, 512, 1024]`.
                     the network depth is expected as `len(filter_num_down)`
    filter_num_skip: a list that defines the number of filters after each
                     full-scale skip connection. Number of elements is expected to be `depth-1`.
                     i.e., the bottom level is not included.
                     * Huang et al. (2020) applied the same numbers for all levels.
                     e.g., `[64, 64, 64, 64]`.
    filter_num_aggregate: an int that defines the number of channels of full-scale aggregations.
    stack_num_down: number of convolutional layers per downsampling level/block.
    stack_num_up: number of convolutional layers (after full-scale concat) per upsampling level/block.
    activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., ReLU
    batch_norm: True for batch normalization.
    pool: True or 'max' for MaxPooling2D.
          'ave' for AveragePooling2D.
          False for strided conv + batch norm + activation.
    unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
            'nearest' for Upsampling2D with nearest interpolation.
            False for Conv2DTranspose + batch norm + activation.
    name: prefix of the created keras model and its layers.

    ---------- (keywords of backbone options) ----------
    backbone: the backbone model name. Should be one of the `tensorflow.keras.applications` class.
              None (default) means no backbone.
              Currently supported backbones are:
              (1) VGG16, VGG19
              (2) ResNet50, ResNet101, ResNet152
              (3) ResNet50V2, ResNet101V2, ResNet152V2
              (4) DenseNet121, DenseNet169, DenseNet201
              (5) EfficientNetB[0-7]
    weights: one of None (random initialization), 'imagenet' (pre-training on ImageNet),
             or the path to the weights file to be loaded.
    freeze_backbone: True for a frozen backbone.
    freeze_batch_norm: False for not freezing batch normalization layers.

    * Downsampling is achieved through maxpooling and can be replaced by strided convolutional layers here.
    * Upsampling is achieved through bilinear interpolation and can be replaced by transpose convolutional layers here.

    Output
    ----------
    A list of tensors with the first/second/third tensor obtained from
    the deepest/second deepest/third deepest upsampling block, etc.
    * The feature map sizes of these tensors are different,
      with the first tensor has the smallest size.

    '''

    depth_ = len(filter_num_down)

    X_encoder = []  # encoder feature maps, shallow -> deep (reversed later)
    X_decoder = []  # decoder outputs, deep -> shallow

    # no backbone cases
    if backbone is None:

        X = input_tensor

        # stacked conv2d before downsampling
        X = CONV_stack(X, filter_num_down[0], kernel_size=3, stack_num=stack_num_down,
                       activation=activation, batch_norm=batch_norm, name='{}_down0'.format(name))
        X_encoder.append(X)

        # downsampling levels
        for i, f in enumerate(filter_num_down[1:]):

            # UNET-like downsampling
            X = UNET_left(X, f, kernel_size=3, stack_num=stack_num_down, activation=activation,
                          pool=pool, batch_norm=batch_norm, name='{}_down{}'.format(name, i+1))
            X_encoder.append(X)

    else:
        # handling VGG16 and VGG19 separately
        # (VGG backbones provide `depth_` feature maps; others provide
        # `depth_-1`, with the count adjusted by +1 below)
        if 'VGG' in backbone:
            backbone_ = backbone_zoo(backbone, weights, input_tensor, depth_, freeze_backbone, freeze_batch_norm)
            # collecting backbone feature maps
            X_encoder = backbone_([input_tensor,])
            depth_encode = len(X_encoder)

        # for other backbones
        else:
            backbone_ = backbone_zoo(backbone, weights, input_tensor, depth_-1, freeze_backbone, freeze_batch_norm)
            # collecting backbone feature maps
            X_encoder = backbone_([input_tensor,])
            depth_encode = len(X_encoder) + 1

        # extra conv2d blocks are applied
        # if downsampling levels of a backbone < user-specified downsampling levels
        if depth_encode < depth_:

            # begins at the deepest available tensor
            X = X_encoder[-1]

            # extra downsamplings
            for i in range(depth_-depth_encode):

                i_real = i + depth_encode

                X = UNET_left(X, filter_num_down[i_real], stack_num=stack_num_down, activation=activation, pool=pool,
                              batch_norm=batch_norm, name='{}_down{}'.format(name, i_real+1))
                X_encoder.append(X)


    # treat the last encoded tensor as the first decoded tensor
    X_decoder.append(X_encoder[-1])

    # upsampling levels; reorder encoder tensors deep -> shallow
    X_encoder = X_encoder[::-1]

    depth_decode = len(X_encoder)-1

    # loop over upsampling levels
    for i in range(depth_decode):

        f = filter_num_skip[i]

        # collecting tensors for layer fusion
        X_fscale = []

        # for each upsampling level, loop over all available downsampling levels (similar to the unet++)
        for lev in range(depth_decode):

            # counting scale difference between the current down- and upsampling levels
            pool_scale = lev-i-1 # -1 for python indexing

            # deeper tensors are obtained from **decoder** outputs
            # (pool_scale < 0: source is coarser, so upsample it)
            if pool_scale < 0:
                pool_size = 2**(-1*pool_scale)

                X = decode_layer(X_decoder[lev], f, pool_size, unpool,
                                 activation=activation, batch_norm=batch_norm, name='{}_up_{}_en{}'.format(name, i, lev))

            # unet skip connection (identity mapping)
            elif pool_scale == 0:

                X = X_encoder[lev]

            # shallower tensors are obtained from **encoder** outputs
            # (pool_scale > 0: source is finer, so downsample it)
            else:
                pool_size = 2**(pool_scale)

                X = encode_layer(X_encoder[lev], f, pool_size, pool, activation=activation,
                                 batch_norm=batch_norm, name='{}_down_{}_en{}'.format(name, i, lev))

            # a conv layer after feature map scale change
            # NOTE(review): the layer name says "down" on all three branches
            # above, including the upsampled and identity paths — a naming
            # quirk only; it does not affect behavior
            X = CONV_stack(X, f, kernel_size=3, stack_num=1,
                           activation=activation, batch_norm=batch_norm, name='{}_down_from{}_to{}'.format(name, i, lev))

            X_fscale.append(X)

        # layer fusion at the end of each level
        # stacked conv layers after concat. BatchNormalization is fixed to True

        X = concatenate(X_fscale, axis=-1, name='{}_concat_{}'.format(name, i))
        X = CONV_stack(X, filter_num_aggregate, kernel_size=3, stack_num=stack_num_up,
                       activation=activation, batch_norm=True, name='{}_fusion_conv_{}'.format(name, i))
        X_decoder.append(X)

    # if tensors for concatenation is not enough
    # then use upsampling without concatenation
    # (happens when a backbone provided fewer skip tensors than `depth_-1`)
    if depth_decode < depth_-1:
        for i in range(depth_-depth_decode-1):
            i_real = i + depth_decode
            X = UNET_right(X, None, filter_num_aggregate, stack_num=stack_num_up, activation=activation,
                           unpool=unpool, batch_norm=batch_norm, concat=False, name='{}_plain_up{}'.format(name, i_real))
            X_decoder.append(X)

    # return decoder outputs
    return X_decoder
199 |
200 | def unet_3plus_2d(input_size, n_labels, filter_num_down, filter_num_skip='auto', filter_num_aggregate='auto',
201 | stack_num_down=2, stack_num_up=1, activation='ReLU', output_activation='Sigmoid',
202 | batch_norm=False, pool=True, unpool=True, deep_supervision=False,
203 | backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='unet3plus'):
204 |
205 | '''
206 | UNET 3+ with an optional ImageNet-trained backbone.
207 |
208 | unet_3plus_2d(input_size, n_labels, filter_num_down, filter_num_skip='auto', filter_num_aggregate='auto',
209 | stack_num_down=2, stack_num_up=1, activation='ReLU', output_activation='Sigmoid',
210 | batch_norm=False, pool=True, unpool=True, deep_supervision=False,
211 | backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='unet3plus')
212 |
213 | ----------
214 | Huang, H., Lin, L., Tong, R., Hu, H., Zhang, Q., Iwamoto, Y., Han, X., Chen, Y.W. and Wu, J., 2020.
215 | UNet 3+: A Full-Scale Connected UNet for Medical Image Segmentation.
216 | In ICASSP 2020-2020 IEEE International Conference on Acoustics,
217 | Speech and Signal Processing (ICASSP) (pp. 1055-1059). IEEE.
218 |
219 | Input
220 | ----------
221 | input_size: the size/shape of network input, e.g., `(128, 128, 3)`.
222 | filter_num_down: a list that defines the number of filters for each
223 | downsampling level. e.g., `[64, 128, 256, 512, 1024]`.
224 | the network depth is expected as `len(filter_num_down)`
225 | filter_num_skip: a list that defines the number of filters after each
226 | full-scale skip connection. Number of elements is expected to be `depth-1`.
227 | i.e., the bottom level is not included.
228 | * Huang et al. (2020) applied the same numbers for all levels.
229 | e.g., `[64, 64, 64, 64]`.
230 | filter_num_aggregate: an int that defines the number of channels of full-scale aggregations.
231 | stack_num_down: number of convolutional layers per downsampling level/block.
232 | stack_num_up: number of convolutional layers (after full-scale concat) per upsampling level/block.
233 | activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'
234 | output_activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interface or 'Sigmoid'.
235 | Default option is 'Softmax'.
236 | if None is received, then linear activation is applied.
237 | batch_norm: True for batch normalization.
238 | pool: True or 'max' for MaxPooling2D.
239 | 'ave' for AveragePooling2D.
240 | False for strided conv + batch norm + activation.
241 | unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
242 | 'nearest' for Upsampling2D with nearest interpolation.
243 | False for Conv2DTranspose + batch norm + activation.
244 | deep_supervision: True for a model that supports deep supervision. Details see Huang et al. (2020).
245 | name: prefix of the created keras model and its layers.
246 |
247 | ---------- (keywords of backbone options) ----------
248 | backbone_name: the bakcbone model name. Should be one of the `tensorflow.keras.applications` class.
249 | None (default) means no backbone.
250 | Currently supported backbones are:
251 | (1) VGG16, VGG19
252 | (2) ResNet50, ResNet101, ResNet152
253 | (3) ResNet50V2, ResNet101V2, ResNet152V2
254 | (4) DenseNet121, DenseNet169, DenseNet201
255 | (5) EfficientNetB[0-7]
256 | weights: one of None (random initialization), 'imagenet' (pre-training on ImageNet),
257 | or the path to the weights file to be loaded.
258 | freeze_backbone: True for a frozen backbone.
259 | freeze_batch_norm: False for not freezing batch normalization layers.
260 |
261 | * The Classification-guided Module (CGM) is not implemented.
262 | See https://github.com/yingkaisha/keras-unet-collection/tree/main/examples for a relevant example.
263 | * Automated mode is applied for determining `filter_num_skip`, `filter_num_aggregate`.
264 | * The default output activation is sigmoid, consistent with Huang et al. (2020).
265 | * Downsampling is achieved through maxpooling and can be replaced by strided convolutional layers here.
266 | * Upsampling is achieved through bilinear interpolation and can be replaced by transpose convolutional layers here.
267 |
268 | Output
269 | ----------
270 | model: a keras model.
271 |
272 | '''
273 |
274 | depth_ = len(filter_num_down)
275 |
276 | verbose = False
277 |
278 | if filter_num_skip == 'auto':
279 | verbose = True
280 | filter_num_skip = [filter_num_down[0] for num in range(depth_-1)]
281 |
282 | if filter_num_aggregate == 'auto':
283 | verbose = True
284 | filter_num_aggregate = int(depth_*filter_num_down[0])
285 |
286 | if verbose:
287 | print('Automated hyper-parameter determination is applied with the following details:\n----------')
288 | print('\tNumber of convolution filters after each full-scale skip connection: filter_num_skip = {}'.format(filter_num_skip))
289 | print('\tNumber of channels of full-scale aggregated feature maps: filter_num_aggregate = {}'.format(filter_num_aggregate))
290 |
291 | if backbone is not None:
292 | bach_norm_checker(backbone, batch_norm)
293 |
294 | X_encoder = []
295 | X_decoder = []
296 |
297 |
298 | IN = Input(input_size)
299 |
300 | X_decoder = unet_3plus_2d_base(IN, filter_num_down, filter_num_skip, filter_num_aggregate,
301 | stack_num_down=stack_num_down, stack_num_up=stack_num_up, activation=activation,
302 | batch_norm=batch_norm, pool=pool, unpool=unpool,
303 | backbone=backbone, weights=weights, freeze_backbone=freeze_backbone,
304 | freeze_batch_norm=freeze_batch_norm, name=name)
305 | X_decoder = X_decoder[::-1]
306 |
307 | if deep_supervision:
308 |
309 | # ----- frozen backbone issue checker ----- #
310 | if ('{}_backbone_'.format(backbone) in X_decoder[0].name) and freeze_backbone:
311 |
312 | backbone_warn = '\n\nThe deepest UNET 3+ deep supervision branch directly connects to a frozen backbone.\nTesting your configurations on `keras_unet_collection.base.unet_plus_2d_base` is recommended.'
313 | warnings.warn(backbone_warn);
314 | # ----------------------------------------- #
315 |
316 | OUT_stack = []
317 | L_out = len(X_decoder)
318 |
319 | print('----------\ndeep_supervision = True\nnames of output tensors are listed as follows ("sup0" is the shallowest supervision layer;\n"final" is the final output layer):\n')
320 |
321 | # conv2d --> upsampling --> output activation.
322 | # index 0 is final output
323 | for i in range(1, L_out):
324 |
325 | pool_size = 2**(i)
326 |
327 | X = Conv2D(n_labels, 3, padding='same', name='{}_output_conv_{}'.format(name, i-1))(X_decoder[i])
328 |
329 | X = decode_layer(X, n_labels, pool_size, unpool,
330 | activation=None, batch_norm=False, name='{}_output_sup{}'.format(name, i-1))
331 |
332 | if output_activation:
333 | print('\t{}_output_sup{}_activation'.format(name, i-1))
334 |
335 | if output_activation == 'Sigmoid':
336 | X = Activation('sigmoid', name='{}_output_sup{}_activation'.format(name, i-1))(X)
337 | else:
338 | activation_func = eval(output_activation)
339 | X = activation_func(name='{}_output_sup{}_activation'.format(name, i-1))(X)
340 | else:
341 | if unpool is False:
342 | print('\t{}_output_sup{}_trans_conv'.format(name, i-1))
343 | else:
344 | print('\t{}_output_sup{}_unpool'.format(name, i-1))
345 |
346 | OUT_stack.append(X)
347 |
348 | X = CONV_output(X_decoder[0], n_labels, kernel_size=3,
349 | activation=output_activation, name='{}_output_final'.format(name))
350 | OUT_stack.append(X)
351 |
352 | if output_activation:
353 | print('\t{}_output_final_activation'.format(name))
354 | else:
355 | print('\t{}_output_final'.format(name))
356 |
357 | model = Model([IN,], OUT_stack)
358 |
359 | else:
360 | OUT = CONV_output(X_decoder[0], n_labels, kernel_size=3,
361 | activation=output_activation, name='{}_output_final'.format(name))
362 |
363 | model = Model([IN,], [OUT,])
364 |
365 | return model
366 |
367 |
--------------------------------------------------------------------------------
/keras_unet_collection/_model_unet_plus_2d.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | from keras_unet_collection.layer_utils import *
5 | from keras_unet_collection.activations import GELU, Snake
6 | from keras_unet_collection._backbone_zoo import backbone_zoo, bach_norm_checker
7 | from keras_unet_collection._model_unet_2d import UNET_left, UNET_right
8 |
9 | from tensorflow.keras.layers import Input
10 | from tensorflow.keras.models import Model
11 |
12 | import warnings
13 |
def unet_plus_2d_base(input_tensor, filter_num, stack_num_down=2, stack_num_up=2,
                      activation='ReLU', batch_norm=False, pool=True, unpool=True, deep_supervision=False,
                      backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='xnet'):
    '''
    The base of U-net++ with an optional ImageNet-trained backbone

    unet_plus_2d_base(input_tensor, filter_num, stack_num_down=2, stack_num_up=2,
                      activation='ReLU', batch_norm=False, pool=True, unpool=True, deep_supervision=False,
                      backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='xnet')

    ----------
    Zhou, Z., Siddiquee, M.M.R., Tajbakhsh, N. and Liang, J., 2018. Unet++: A nested u-net architecture
    for medical image segmentation. In Deep Learning in Medical Image Analysis and Multimodal Learning
    for Clinical Decision Support (pp. 3-11). Springer, Cham.

    Input
    ----------
        input_tensor: the input tensor of the base, e.g., `keras.layers.Input((None, None, 3))`.
        filter_num: a list that defines the number of filters for each \
                    down- and upsampling levels. e.g., `[64, 128, 256, 512]`.
                    The depth is expected as `len(filter_num)`.
        stack_num_down: number of convolutional layers per downsampling level/block.
        stack_num_up: number of convolutional layers (after concatenation) per upsampling level/block.
        activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
        batch_norm: True for batch normalization.
        pool: True or 'max' for MaxPooling2D.
              'ave' for AveragePooling2D.
              False for strided conv + batch norm + activation.
        unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
                'nearest' for Upsampling2D with nearest interpolation.
                False for Conv2DTranspose + batch norm + activation.
        deep_supervision: True for a model that supports deep supervision. Details see Zhou et al. (2018).
        name: prefix of the created keras model and its layers.

    ---------- (keywords of backbone options) ----------
        backbone_name: the backbone model name. Should be one of the `tensorflow.keras.applications` class.
                       None (default) means no backbone.
                       Currently supported backbones are:
                       (1) VGG16, VGG19
                       (2) ResNet50, ResNet101, ResNet152
                       (3) ResNet50V2, ResNet101V2, ResNet152V2
                       (4) DenseNet121, DenseNet169, DenseNet201
                       (5) EfficientNetB[0-7]
        weights: one of None (random initialization), 'imagenet' (pre-training on ImageNet),
                 or the path to the weights file to be loaded.
        freeze_backbone: True for a frozen backbone.
        freeze_batch_norm: False for not freezing batch normalization layers.

    Output
    ----------
        If deep_supervision = False; Then the output is a tensor.
        If deep_supervision = True; Then the output is a list of tensors
            with the first tensor obtained from the first downsampling level (for checking the input/output shapes only),
            the second to the `depth-1`-th tensors obtained from each intermediate upsampling levels (deep supervision tensors),
            and the last tensor obtained from the end of the base.

    '''

    depth_ = len(filter_num)
    # allocate nested lists for collecting output tensors;
    # X_nest_skip[0] holds the encoder column, X_nest_skip[k] the k-th nested decoder column
    X_nest_skip = [[] for _ in range(depth_)]

    # no backbone cases
    if backbone is None:

        X = input_tensor

        # downsampling blocks (same as in 'unet_2d')
        X = CONV_stack(X, filter_num[0], stack_num=stack_num_down, activation=activation,
                       batch_norm=batch_norm, name='{}_down0'.format(name))
        X_nest_skip[0].append(X)
        for i, f in enumerate(filter_num[1:]):
            X = UNET_left(X, f, stack_num=stack_num_down, activation=activation,
                          pool=pool, batch_norm=batch_norm, name='{}_down{}'.format(name, i+1))
            X_nest_skip[0].append(X)

    # backbone cases
    else:
        # handling VGG16 and VGG19 separately
        if 'VGG' in backbone:
            backbone_ = backbone_zoo(backbone, weights, input_tensor, depth_, freeze_backbone, freeze_batch_norm)
            # collecting backbone feature maps
            X_nest_skip[0] += backbone_([input_tensor,])
            depth_encode = len(X_nest_skip[0])

        # for other backbones
        else:
            # only `depth_-1` maps are requested; the deepest level is built below,
            # hence the `+ 1` on `depth_encode`
            backbone_ = backbone_zoo(backbone, weights, input_tensor, depth_-1, freeze_backbone, freeze_batch_norm)
            # collecting backbone feature maps
            X_nest_skip[0] += backbone_([input_tensor,])
            depth_encode = len(X_nest_skip[0]) + 1

        # extra conv2d blocks are applied
        # if downsampling levels of a backbone < user-specified downsampling levels
        if depth_encode < depth_:

            # begins at the deepest available tensor
            X = X_nest_skip[0][-1]

            # extra downsamplings
            for i in range(depth_-depth_encode):
                i_real = i + depth_encode

                X = UNET_left(X, filter_num[i_real], stack_num=stack_num_down, activation=activation, pool=pool,
                              batch_norm=batch_norm, name='{}_down{}'.format(name, i_real+1))
                X_nest_skip[0].append(X)


    X = X_nest_skip[0][-1]

    for nest_lev in range(1, depth_):

        # depth difference between the deepest nest skip and the current upsampling
        depth_lev = depth_-nest_lev

        # number of available encoded tensors
        depth_decode = len(X_nest_skip[nest_lev-1])

        # loop over individual upsamling levels
        for i in range(1, depth_decode):

            # collecting previous downsampling outputs
            previous_skip = []
            for previous_lev in range(nest_lev):
                previous_skip.append(X_nest_skip[previous_lev][i-1])

            # upsamping block that concatenates all available (same feature map size) down-/upsampling outputs
            X_nest_skip[nest_lev].append(
                UNET_right(X_nest_skip[nest_lev-1][i], previous_skip, filter_num[i-1],
                           stack_num=stack_num_up, activation=activation, unpool=unpool,
                           batch_norm=batch_norm, concat=False, name='{}_up{}_from{}'.format(name, nest_lev-1, i-1)))

        # when a backbone provides fewer columns than needed, extend the nested
        # column with upsampling blocks that take no skip input (X_list=None)
        if depth_decode < depth_lev+1:

            X = X_nest_skip[nest_lev-1][-1]

            for j in range(depth_lev-depth_decode+1):
                j_real = j + depth_decode
                X = UNET_right(X, None, filter_num[j_real-1],
                               stack_num=stack_num_up, activation=activation, unpool=unpool,
                               batch_norm=batch_norm, concat=False, name='{}_up{}_from{}'.format(name, nest_lev-1, j_real-1))
                X_nest_skip[nest_lev].append(X)

    # output
    if deep_supervision:

        # one tensor per nesting level (shallowest first)
        X_list = []

        for i in range(depth_):
            X_list.append(X_nest_skip[i][0])

        return X_list

    else:
        return X_nest_skip[-1][0]
171 |
def unet_plus_2d(input_size, filter_num, n_labels, stack_num_down=2, stack_num_up=2,
                 activation='ReLU', output_activation='Softmax', batch_norm=False, pool=True, unpool=True, deep_supervision=False,
                 backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='xnet'):
    '''
    U-net++ with an optional ImageNet-trained backbone.

    unet_plus_2d(input_size, filter_num, n_labels, stack_num_down=2, stack_num_up=2,
                 activation='ReLU', output_activation='Softmax', batch_norm=False, pool=True, unpool=True, deep_supervision=False,
                 backbone=None, weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='xnet')

    ----------
    Zhou, Z., Siddiquee, M.M.R., Tajbakhsh, N. and Liang, J., 2018. Unet++: A nested u-net architecture
    for medical image segmentation. In Deep Learning in Medical Image Analysis and Multimodal Learning
    for Clinical Decision Support (pp. 3-11). Springer, Cham.

    Input
    ----------
        input_size: the size/shape of network input, e.g., `(128, 128, 3)`.
        filter_num: a list of filter numbers, one per down-/upsampling level,
                    e.g., `[64, 128, 256, 512]`; the network depth is `len(filter_num)`.
        n_labels: number of output labels.
        stack_num_down: number of convolutional layers per downsampling level/block.
        stack_num_up: number of convolutional layers (after concatenation) per upsampling level/block.
        activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
        output_activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interface or 'Sigmoid'.
                           Default option is 'Softmax'.
                           if None is received, then linear activation is applied.
        batch_norm: True for batch normalization.
        pool: True or 'max' for MaxPooling2D; 'ave' for AveragePooling2D;
              False for strided conv + batch norm + activation.
        unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation;
                'nearest' for Upsampling2D with nearest interpolation;
                False for Conv2DTranspose + batch norm + activation.
        deep_supervision: True for a model that supports deep supervision. Details see Zhou et al. (2018).
        name: prefix of the created keras model and its layers.

    ---------- (keywords of backbone options) ----------
        backbone_name: the backbone model name. Should be one of the `tensorflow.keras.applications` class.
                       None (default) means no backbone.
                       Currently supported backbones are:
                       (1) VGG16, VGG19
                       (2) ResNet50, ResNet101, ResNet152
                       (3) ResNet50V2, ResNet101V2, ResNet152V2
                       (4) DenseNet121, DenseNet169, DenseNet201
                       (5) EfficientNetB[0-7]
        weights: one of None (random initialization), 'imagenet' (pre-training on ImageNet),
                 or the path to the weights file to be loaded.
        freeze_backbone: True for a frozen backbone.
        freeze_batch_norm: False for not freezing batch normalization layers.

    Output
    ----------
        model: a keras model (with a list of outputs when deep_supervision=True).

    '''

    num_levels = len(filter_num)

    if backbone is not None:
        bach_norm_checker(backbone, batch_norm)

    model_input = Input(input_size)

    # build the nested U-net++ graph; returns a single tensor, or a list of
    # per-level tensors when deep supervision is requested
    base_out = unet_plus_2d_base(model_input, filter_num, stack_num_down=stack_num_down, stack_num_up=stack_num_up,
                                 activation=activation, batch_norm=batch_norm, pool=pool, unpool=unpool,
                                 deep_supervision=deep_supervision, backbone=backbone, weights=weights,
                                 freeze_backbone=freeze_backbone, freeze_batch_norm=freeze_batch_norm, name=name)

    if not deep_supervision:
        # single output head
        outputs = [CONV_output(base_out, n_labels, kernel_size=1, activation=output_activation,
                               name='{}_output'.format(name)),]

    else:
        if (backbone is not None) and freeze_backbone:
            warnings.warn('\n\nThe shallowest U-net++ deep supervision branch directly connects to a frozen backbone.'
                          '\nTesting your configurations on `keras_unet_collection.base.unet_plus_2d_base` is recommended.')

        # `base_out` is a list of tensors; one supervision head per level
        outputs = []

        print('----------\ndeep_supervision = True\nnames of output tensors are listed as follows ("sup0" is the shallowest supervision layer;\n"final" is the final output layer):\n')

        # no backbone or VGG backbones: every tensor is at full resolution already
        # num_levels > 2 is expected (at least two downsampling blocks)
        if (backbone is None) or 'VGG' in backbone:

            for lev in range(num_levels-1):
                suffix = '' if output_activation is None else '_activation'
                print('\t{}_output_sup{}{}'.format(name, lev, suffix))

                outputs.append(CONV_output(base_out[lev], n_labels, kernel_size=1,
                                           activation=output_activation,
                                           name='{}_output_sup{}'.format(name, lev)))
        # other backbones: supervision tensors need one extra upsampling
        # to recover the full input resolution
        else:
            for lev in range(1, num_levels-1):
                suffix = '' if output_activation is None else '_activation'
                print('\t{}_output_sup{}{}'.format(name, lev-1, suffix))

                sup = decode_layer(base_out[lev], filter_num[lev], 2, unpool, activation=activation,
                                   batch_norm=batch_norm, name='{}_sup{}_up'.format(name, lev-1))
                sup = CONV_output(sup, n_labels, kernel_size=1, activation=output_activation,
                                  name='{}_output_sup{}'.format(name, lev-1))
                outputs.append(sup)

        if output_activation is None:
            print('\t{}_output_final'.format(name))
        else:
            print('\t{}_output_final_activation'.format(name))

        outputs.append(CONV_output(base_out[-1], n_labels, kernel_size=1, activation=output_activation,
                                   name='{}_output_final'.format(name)))

    return Model(inputs=[model_input,], outputs=outputs, name='{}_model'.format(name))
296 |
--------------------------------------------------------------------------------
/keras_unet_collection/_model_vnet_2d.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | from keras_unet_collection.layer_utils import *
5 | from keras_unet_collection.activations import GELU, Snake
6 |
7 | from tensorflow.keras.layers import Input
8 | from tensorflow.keras.models import Model
9 |
10 |
def vnet_left(X, channel, res_num, activation='ReLU', pool=True, batch_norm=False, name='left'):
    '''
    The encoder (downsampling) block of 2-d V-net.

    vnet_left(X, channel, res_num, activation='ReLU', pool=True, batch_norm=False, name='left')

    Input
    ----------
        X: input tensor.
        channel: number of convolution filters.
        res_num: number of convolutional layers within the residual path.
        activation: one of the `tensorflow.keras.layers` interface, e.g., 'ReLU'.
        pool: True or 'max' for MaxPooling2D.
              'ave' for AveragePooling2D.
              False for strided conv + batch norm + activation.
        batch_norm: True for batch normalization, False otherwise.
        name: name of the created keras layers.

    Output
    ----------
        X: output tensor.

    '''

    # halve the spatial size (pooling or strided conv, selected by `pool`)
    tensor = encode_layer(X, channel, 2, pool, activation=activation,
                          batch_norm=batch_norm, name='{}_encode'.format(name))

    # pooling keeps the channel number unchanged, so an extra conv layer
    # aligns the channels before the residual stack
    if pool is not False:
        tensor = CONV_stack(tensor, channel, kernel_size=3, stack_num=1, dilation_rate=1,
                            activation=activation, batch_norm=batch_norm, name='{}_pre_conv'.format(name))

    # residual conv stack; the block input doubles as the shortcut
    return Res_CONV_stack(tensor, tensor, channel, res_num=res_num, activation=activation,
                          batch_norm=batch_norm, name='{}_res_conv'.format(name))
47 |
def vnet_right(X, X_list, channel, res_num, activation='ReLU', unpool=True, batch_norm=False, name='right'):
    '''
    The decoder (upsampling) block of 2-d V-net.

    vnet_right(X, X_list, channel, res_num, activation='ReLU', unpool=True, batch_norm=False, name='right')

    Input
    ----------
        X: input tensor.
        X_list: a list of other tensors that connected to the input tensor.
        channel: number of convolution filters.
        res_num: number of convolutional layers within the residual path.
        activation: one of the `tensorflow.keras.layers` interface, e.g., 'ReLU'.
        unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
                'nearest' for Upsampling2D with nearest interpolation.
                False for Conv2DTranspose + batch norm + activation.
        batch_norm: True for batch normalization, False otherwise.
        name: name of the created keras layers.

    Output
    ----------
        X: output tensor.

    '''

    # double the spatial size (interpolation or transpose conv, selected by `unpool`)
    upsampled = decode_layer(X, channel, 2, unpool,
                             activation=activation, batch_norm=batch_norm, name='{}_decode'.format(name))

    # merge with the encoder-side skip tensors along the channel axis
    merged = concatenate([upsampled,]+X_list, axis=-1, name='{}_concat'.format(name))

    # residual conv stack; the pre-concatenation tensor serves as the shortcut
    return Res_CONV_stack(merged, upsampled, channel, res_num, activation=activation,
                          batch_norm=batch_norm, name='{}_res_conv'.format(name))
86 |
def vnet_2d_base(input_tensor, filter_num, res_num_ini=1, res_num_max=3,
                 activation='ReLU', batch_norm=False, pool=True, unpool=True, name='vnet'):
    '''
    The base of 2-d V-net.

    vnet_2d_base(input_tensor, filter_num, res_num_ini=1, res_num_max=3,
                 activation='ReLU', batch_norm=False, pool=True, unpool=True, name='vnet')

    Milletari, F., Navab, N. and Ahmadi, S.A., 2016, October. V-net: Fully convolutional neural
    networks for volumetric medical image segmentation. In 2016 fourth international conference
    on 3D vision (3DV) (pp. 565-571). IEEE.

    The Two-dimensional version is inspired by:
    https://github.com/FENGShuanglang/2D-Vnet-Keras

    Input
    ----------
        input_tensor: the input tensor of the base, e.g., `keras.layers.Input((None, None, 3))`.
        filter_num: a list of filter numbers, one per down-/upsampling level,
                    e.g., `[64, 128, 256, 512]`; the network depth is `len(filter_num)`.
        res_num_ini: number of convolutional layers of the first residual block (before downsampling).
        res_num_max: the max number of convolutional layers within a residual block.
        activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
        batch_norm: True for batch normalization.
        pool: True or 'max' for MaxPooling2D; 'ave' for AveragePooling2D;
              False for strided conv + batch norm + activation.
        unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation;
                'nearest' for Upsampling2D with nearest interpolation;
                False for Conv2DTranspose + batch norm + activation.
        name: prefix of the created keras layers.

    Output
    ----------
        X: output tensor.

    * This is a modified version of V-net for 2-d inputs.
    * The original work supports `pool=False` only.
      If pool is True, 'max', or 'ave', an additional conv2d layer will be applied.
    * All the 5-by-5 convolutional kernels are changed (and fixed) to 3-by-3.

    '''

    depth = len(filter_num)

    # residual-layer counts grow by one per level, capped at `res_num_max`
    res_num_list = [min(res_num_ini + lev, res_num_max) for lev in range(depth)]

    X_skip = []

    # entry conv + first residual block (no downsampling yet)
    X = CONV_stack(input_tensor, filter_num[0], kernel_size=3, stack_num=1, dilation_rate=1,
                   activation=activation, batch_norm=batch_norm, name='{}_input_conv'.format(name))
    X = Res_CONV_stack(X, X, filter_num[0], res_num=res_num_list[0], activation=activation,
                       batch_norm=batch_norm, name='{}_down_0'.format(name))
    X_skip.append(X)

    # encoder: one vnet_left block per remaining level
    for lev, n_filters in enumerate(filter_num[1:]):
        X = vnet_left(X, n_filters, res_num=res_num_list[lev+1], activation=activation, pool=pool,
                      batch_norm=batch_norm, name='{}_down_{}'.format(name, lev+1))
        X_skip.append(X)

    # decoder: walk back up, pairing each level with its encoder-side skip;
    # the bottom tensor is the decoder input, not a skip connection
    skips = X_skip[:-1][::-1]
    up_filters = filter_num[:-1][::-1]
    up_res_nums = res_num_list[:-1][::-1]

    for lev, n_filters in enumerate(up_filters):
        X = vnet_right(X, [skips[lev],], n_filters, res_num=up_res_nums[lev],
                       activation=activation, unpool=unpool, batch_norm=batch_norm,
                       name='{}_up_{}'.format(name, lev))

    return X
169 |
170 |
def vnet_2d(input_size, filter_num, n_labels,
            res_num_ini=1, res_num_max=3,
            activation='ReLU', output_activation='Softmax',
            batch_norm=False, pool=True, unpool=True, name='vnet'):
    '''
    2-d V-net.

    vnet_2d(input_size, filter_num, n_labels,
            res_num_ini=1, res_num_max=3,
            activation='ReLU', output_activation='Softmax',
            batch_norm=False, pool=True, unpool=True, name='vnet')

    Milletari, F., Navab, N. and Ahmadi, S.A., 2016, October. V-net: Fully convolutional neural
    networks for volumetric medical image segmentation. In 2016 fourth international conference
    on 3D vision (3DV) (pp. 565-571). IEEE.

    The Two-dimensional version is inspired by:
    https://github.com/FENGShuanglang/2D-Vnet-Keras

    Input
    ----------
        input_size: the size/shape of network input, e.g., `(128, 128, 3)`.
        filter_num: a list of filter numbers, one per down-/upsampling level,
                    e.g., `[64, 128, 256, 512]`; the network depth is `len(filter_num)`.
        n_labels: number of output labels.
        res_num_ini: number of convolutional layers of the first residual block (before downsampling).
        res_num_max: the max number of convolutional layers within a residual block.
        activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
        output_activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interface or 'Sigmoid'.
                           Default option is 'Softmax'.
                           if None is received, then linear activation is applied.
        batch_norm: True for batch normalization.
        pool: True or 'max' for MaxPooling2D; 'ave' for AveragePooling2D;
              False for strided conv + batch norm + activation.
        unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation;
                'nearest' for Upsampling2D with nearest interpolation;
                False for Conv2DTranspose + batch norm + activation.
        name: prefix of the created keras layers.

    Output
    ----------
        model: a keras model.

    * This is a modified version of V-net for 2-d inputs.
    * The original work supports `pool=False` only.
      If pool is True, 'max', or 'ave', an additional conv2d layer will be applied.
    * All the 5-by-5 convolutional kernels are changed (and fixed) to 3-by-3.
    '''

    model_input = Input(input_size)

    # encoder-decoder base
    features = vnet_2d_base(model_input, filter_num, res_num_ini=res_num_ini, res_num_max=res_num_max,
                            activation=activation, batch_norm=batch_norm, pool=pool, unpool=unpool, name=name)

    # 1-by-1 conv output head
    model_output = CONV_output(features, n_labels, kernel_size=1, activation=output_activation,
                               name='{}_output'.format(name))

    # functional API model
    return Model(inputs=[model_input,], outputs=[model_output,], name='{}_model'.format(name))
235 |
--------------------------------------------------------------------------------
/keras_unet_collection/activations.py:
--------------------------------------------------------------------------------
1 |
2 | from tensorflow import math
3 | from tensorflow.keras.layers import Layer
4 | import tensorflow.keras.backend as K
5 |
6 |
def gelu_(X):
    # tanh approximation of GELU (Hendrycks & Gimpel, 2016);
    # 0.7978845608028654 = sqrt(2/pi)
    inner = 0.7978845608028654*(X + 0.044715*math.pow(X, 3))
    return 0.5*X*(1.0 + math.tanh(inner))
10 |
def snake_(X, beta):
    # X + (1/beta) * sin^2(beta*X) -- adds a periodic component to identity
    periodic = math.square(math.sin(beta*X))
    return X + (1/beta)*periodic
14 |
15 |
class GELU(Layer):
    '''
    Gaussian Error Linear Unit (GELU) as a Keras layer; an alternative of ReLU.

    Y = GELU()(X)

    ----------
    Hendrycks, D. and Gimpel, K., 2016. Gaussian error linear units (gelus).
    arXiv preprint arXiv:1606.08415.

    Usage: use it as a tf.keras.Layer. The layer is stateless; `trainable`
    is stored only so it round-trips through `get_config`.

    '''
    def __init__(self, trainable=False, **kwargs):
        super(GELU, self).__init__(**kwargs)
        self.supports_masking = True  # pass incoming masks through unchanged
        self.trainable = trainable

    def build(self, input_shape):
        # no weights to create
        super(GELU, self).build(input_shape)

    def call(self, inputs, mask=None):
        # element-wise tanh-approximated GELU
        return gelu_(inputs)

    def get_config(self):
        config = super(GELU, self).get_config()
        config.update({'trainable': self.trainable})
        return config

    def compute_output_shape(self, input_shape):
        # element-wise op: shape is unchanged
        return input_shape
46 |
47 |
class Snake(Layer):
    '''
    Snake activation function $X + (1/b)*sin^2(b*X)$. Proposed to learn periodic targets.

    Y = Snake(beta=0.5, trainable=False)(X)

    ----------
    Ziyin, L., Hartwig, T. and Ueda, M., 2020. Neural networks fail to learn periodic functions
    and how to fix it. arXiv preprint arXiv:2006.08195.

    Input
    ----------
        beta: the frequency/scale factor `b` of the sine term.
        trainable: True to make `beta` a trainable weight.

    '''
    def __init__(self, beta=0.5, trainable=False, **kwargs):
        super(Snake, self).__init__(**kwargs)
        self.supports_masking = True  # pass incoming masks through unchanged
        self.beta = beta
        self.trainable = trainable

    def build(self, input_shape):
        # store beta as a backend variable so it can optionally be trained;
        # appending it to `_trainable_weights` exposes it to the optimizer
        self.beta_factor = K.variable(self.beta, dtype=K.floatx(), name='beta_factor')
        if self.trainable:
            self._trainable_weights.append(self.beta_factor)

        super(Snake, self).build(input_shape)

    def call(self, inputs, mask=None):
        # element-wise snake activation with the (possibly trained) beta
        return snake_(inputs, self.beta_factor)

    def get_config(self):
        # export the trained beta value when trainable, else the initial one
        # NOTE(review): `get_weights()[0]` yields a numpy scalar, which may not be
        # JSON-serializable when saving the model config -- confirm before relying on it
        config = {'beta': self.get_weights()[0] if self.trainable else self.beta, 'trainable': self.trainable}
        base_config = super(Snake, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        # element-wise op: shape is unchanged
        return input_shape
82 |
--------------------------------------------------------------------------------
/keras_unet_collection/backbones.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | import warnings
5 |
6 | depreciated_warning = "\n----------------------------------------\n`backbones`is depreciated, use `keras_unet_collection.base` instead.\ne.g.\nfrom keras_unet_collection import base\nbase.unet_2d_base(...);\n----------------------------------------"
7 | warnings.warn(depreciated_warning);
8 |
9 | from keras_unet_collection._model_unet_2d import unet_2d_base as unet_2d_backbone
10 | from keras_unet_collection._model_unet_plus_2d import unet_plus_2d_base as unet_plus_2d_backbone
11 | from keras_unet_collection._model_r2_unet_2d import r2_unet_2d_base as r2_unet_2d_backbone
12 | from keras_unet_collection._model_att_unet_2d import att_unet_2d_base as att_unet_2d_backbone
13 | from keras_unet_collection._model_resunet_a_2d import resunet_a_2d_base as resunet_a_2d_backbone
14 | from keras_unet_collection._model_u2net_2d import u2net_2d_base as u2net_2d_backbone
15 | from keras_unet_collection._model_unet_3plus_2d import unet_3plus_2d_base as unet_3plus_2d_backbone
16 |
--------------------------------------------------------------------------------
/keras_unet_collection/base.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | from keras_unet_collection._model_unet_2d import unet_2d_base
5 | from keras_unet_collection._model_vnet_2d import vnet_2d_base
6 | from keras_unet_collection._model_unet_plus_2d import unet_plus_2d_base
7 | from keras_unet_collection._model_r2_unet_2d import r2_unet_2d_base
8 | from keras_unet_collection._model_att_unet_2d import att_unet_2d_base
9 | from keras_unet_collection._model_resunet_a_2d import resunet_a_2d_base
10 | from keras_unet_collection._model_u2net_2d import u2net_2d_base
11 | from keras_unet_collection._model_unet_3plus_2d import unet_3plus_2d_base
12 | from keras_unet_collection._model_transunet_2d import transunet_2d_base
13 | from keras_unet_collection._model_swin_unet_2d import swin_unet_2d_base
14 |
--------------------------------------------------------------------------------
/keras_unet_collection/layer_utils.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | from keras_unet_collection.activations import GELU, Snake
5 | from tensorflow import expand_dims
6 | from tensorflow.compat.v1 import image
7 | from tensorflow.keras.layers import MaxPooling2D, AveragePooling2D, UpSampling2D, Conv2DTranspose, GlobalAveragePooling2D
8 | from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, Lambda
9 | from tensorflow.keras.layers import BatchNormalization, Activation, concatenate, multiply, add
10 | from tensorflow.keras.layers import ReLU, LeakyReLU, PReLU, ELU, Softmax
11 |
def decode_layer(X, channel, pool_size, unpool, kernel_size=3,
                 activation='ReLU', batch_norm=False, name='decode'):
    '''
    An overall decode layer, based on either upsampling or trans conv.

    decode_layer(X, channel, pool_size, unpool, kernel_size=3,
                 activation='ReLU', batch_norm=False, name='decode')

    Input
    ----------
        X: input tensor.
        pool_size: the decoding factor.
        channel: (for trans conv only) number of convolution filters.
        unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
                'nearest' for Upsampling2D with nearest interpolation.
                False for Conv2DTranspose + batch norm + activation.
        kernel_size: size of convolution kernels.
                     If kernel_size='auto', then it equals to the `pool_size`.
        activation: one of the `tensorflow.keras.layers` interface, e.g., ReLU.
        batch_norm: True for batch normalization, False otherwise.
        name: prefix of the created keras layers.

    Output
    ----------
        X: output tensor.

    * The default: `kernel_size=3`, is suitable for `pool_size=2`.

    '''
    # parsers
    if unpool is False:
        # trans conv configurations; the conv bias is redundant when a batch
        # normalization layer immediately follows (same as `encode_layer`).
        bias_flag = not batch_norm

    elif unpool == 'nearest':
        # upsample2d configurations
        unpool = True
        interp = 'nearest'

    elif (unpool is True) or (unpool == 'bilinear'):
        # upsample2d configurations
        unpool = True
        interp = 'bilinear'

    else:
        raise ValueError('Invalid unpool keyword')

    if unpool:
        X = UpSampling2D(size=(pool_size, pool_size), interpolation=interp, name='{}_unpool'.format(name))(X)
    else:
        if kernel_size == 'auto':
            kernel_size = pool_size

        # `use_bias=bias_flag` was previously computed but never applied;
        # now wired in so the bias is dropped when batch norm is enabled.
        X = Conv2DTranspose(channel, kernel_size, strides=(pool_size, pool_size),
                            padding='same', use_bias=bias_flag,
                            name='{}_trans_conv'.format(name))(X)

        # batch normalization
        if batch_norm:
            X = BatchNormalization(axis=3, name='{}_bn'.format(name))(X)

        # activation
        if activation is not None:
            activation_func = eval(activation)
            X = activation_func(name='{}_activation'.format(name))(X)

    return X
78 |
def encode_layer(X, channel, pool_size, pool, kernel_size='auto',
                 activation='ReLU', batch_norm=False, name='encode'):
    '''
    An overall encode layer based on one of:
    (1) max-pooling, (2) average-pooling, (3) strided conv2d.

    encode_layer(X, channel, pool_size, pool, kernel_size='auto',
                 activation='ReLU', batch_norm=False, name='encode')

    Input
    ----------
        X: input tensor.
        pool_size: the encoding factor.
        channel: (for strided conv only) number of convolution filters.
        pool: True or 'max' for MaxPooling2D.
              'ave' for AveragePooling2D.
              False for strided conv + batch norm + activation.
        kernel_size: size of convolution kernels.
                     If kernel_size='auto', then it equals to the `pool_size`.
        activation: one of the `tensorflow.keras.layers` interface, e.g., ReLU.
        batch_norm: True for batch normalization, False otherwise.
        name: prefix of the created keras layers.

    Output
    ----------
        X: output tensor.

    '''
    if pool not in [False, True, 'max', 'ave']:
        raise ValueError('Invalid pool keyword')

    # `True` is an alias of max-pooling
    if pool is True:
        pool = 'max'

    if pool == 'max':
        X = MaxPooling2D(pool_size=(pool_size, pool_size),
                         name='{}_maxpool'.format(name))(X)

    elif pool == 'ave':
        X = AveragePooling2D(pool_size=(pool_size, pool_size),
                             name='{}_avepool'.format(name))(X)

    else:
        if kernel_size == 'auto':
            kernel_size = pool_size

        # linear strided convolution; bias is dropped when a batch norm
        # layer follows
        X = Conv2D(channel, kernel_size, strides=(pool_size, pool_size),
                   padding='valid', use_bias=(not batch_norm),
                   name='{}_stride_conv'.format(name))(X)

    # (optional) batch normalization, applied after any of the branches
    if batch_norm:
        X = BatchNormalization(axis=3, name='{}_bn'.format(name))(X)

    # (optional) nonlinear activation, also applied after any branch
    if activation is not None:
        X = eval(activation)(name='{}_activation'.format(name))(X)

    return X
143 |
def attention_gate(X, g, channel,
                   activation='ReLU',
                   attention='add', name='att'):
    '''
    Self-attention gate modified from Oktay et al. 2018.

    attention_gate(X, g, channel, activation='ReLU', attention='add', name='att')

    Input
    ----------
        X: input tensor, i.e., key and value.
        g: gated tensor, i.e., query.
        channel: number of intermediate channels
                 (the `F_int` of Oktay et al. 2018, unspecified there;
                 expected to be smaller than the input channel number).
        activation: the nonlinear attention activation
                    (`sigma_1` in Oktay et al. 2018). Default is 'ReLU'.
        attention: 'add' for additive attention; 'multiply' for multiplicative
                   attention. Oktay et al. 2018 applied additive attention.
        name: prefix of the created keras layers.

    Output
    ----------
        X_att: output tensor.

    '''
    sigma_1 = eval(activation)
    merge_func = eval(attention)

    # 1-by-1 convolutions projecting key/value and query onto the
    # intermediate channel number
    theta_x = Conv2D(channel, 1, use_bias=True, name='{}_theta_x'.format(name))(X)
    phi_g = Conv2D(channel, 1, use_bias=True, name='{}_phi_g'.format(name))(g)

    # ----- attention learning ----- #
    # combine the two projections (additive or multiplicative)
    merged = merge_func([theta_x, phi_g], name='{}_add'.format(name))

    # nonlinearity, then a linear map down to a single channel
    f = sigma_1(name='{}_activation'.format(name))(merged)
    psi_f = Conv2D(1, 1, use_bias=True, name='{}_psi_f'.format(name))(f)
    # ------------------------------ #

    # sigmoid yields attention coefficients in [0, 1]
    coef_att = Activation('sigmoid', name='{}_sigmoid'.format(name))(psi_f)

    # mask the input with the learned coefficients
    return multiply([X, coef_att], name='{}_masking'.format(name))
196 |
def CONV_stack(X, channel, kernel_size=3, stack_num=2,
               dilation_rate=1, activation='ReLU',
               batch_norm=False, name='conv_stack'):
    '''
    Stacked convolutional layers:
    (Convolutional layer --> batch normalization --> Activation)*stack_num

    CONV_stack(X, channel, kernel_size=3, stack_num=2, dilation_rate=1, activation='ReLU',
               batch_norm=False, name='conv_stack')

    Input
    ----------
        X: input tensor.
        channel: number of convolution filters.
        kernel_size: size of 2-d convolution kernels.
        stack_num: number of stacked Conv2D-BN-Activation layers.
        dilation_rate: optional dilated convolution kernel.
        activation: one of the `tensorflow.keras.layers` interface, e.g., ReLU.
        batch_norm: True for batch normalization, False otherwise.
        name: prefix of the created keras layers.

    Output
    ----------
        X: output tensor

    '''
    # conv bias is redundant when a batch norm layer follows
    bias_flag = not batch_norm

    # resolve the activation layer class once, instead of re-evaluating it
    # (twice) on every loop iteration as before
    activation_func = eval(activation)

    # stacking Conv2D-BN-Activation
    for i in range(stack_num):

        # linear convolution
        X = Conv2D(channel, kernel_size, padding='same', use_bias=bias_flag,
                   dilation_rate=dilation_rate, name='{}_{}'.format(name, i))(X)

        # batch normalization
        if batch_norm:
            X = BatchNormalization(axis=3, name='{}_{}_bn'.format(name, i))(X)

        # activation
        X = activation_func(name='{}_{}_activation'.format(name, i))(X)

    return X
245 |
def Res_CONV_stack(X, X_skip, channel, res_num, activation='ReLU', batch_norm=False, name='res_conv'):
    '''
    Stacked convolutional layers with a residual connection.

    Res_CONV_stack(X, X_skip, channel, res_num, activation='ReLU', batch_norm=False, name='res_conv')

    Input
    ----------
        X: input tensor.
        X_skip: the tensor that skips the residual path;
                can be a copy of X (e.g., the identity block of ResNet).
        channel: number of convolution filters.
        res_num: number of convolutional layers within the residual path.
        activation: one of the `tensorflow.keras.layers` interface, e.g., 'ReLU'.
        batch_norm: True for batch normalization, False otherwise.
        name: prefix of the created keras layers.

    Output
    ----------
        X: output tensor.

    '''
    # residual path: stacked Conv2D-BN-Activation
    res_path = CONV_stack(X, channel, kernel_size=3, stack_num=res_num,
                          dilation_rate=1, activation=activation,
                          batch_norm=batch_norm, name=name)

    # merge with the skip connection, then a final activation
    merged = add([X_skip, res_path], name='{}_add'.format(name))
    return eval(activation)(name='{}_add_activation'.format(name))(merged)
277 |
def Sep_CONV_stack(X, channel, kernel_size=3, stack_num=1, dilation_rate=1, activation='ReLU', batch_norm=False, name='sep_conv'):
    '''
    Depthwise separable convolution with (optional) dilated kernels and batch normalization.

    Sep_CONV_stack(X, channel, kernel_size=3, stack_num=1, dilation_rate=1, activation='ReLU', batch_norm=False, name='sep_conv')

    Input
    ----------
        X: input tensor.
        channel: number of convolution filters.
        kernel_size: size of 2-d convolution kernels.
        stack_num: number of stacked depthwise-pointwise layers.
        dilation_rate: optional dilated convolution kernel.
        activation: one of the `tensorflow.keras.layers` interface, e.g., 'ReLU'.
        batch_norm: True for batch normalization, False otherwise.
        name: prefix of the created keras layers.

    Output
    ----------
        X: output tensor.

    '''
    act_layer = eval(activation)
    # conv bias is redundant when a batch norm layer follows
    use_bias = not batch_norm

    for i in range(stack_num):

        # depthwise step (per-channel spatial filtering)
        X = DepthwiseConv2D(kernel_size, dilation_rate=dilation_rate, padding='same',
                            use_bias=use_bias, name='{}_{}_depthwise'.format(name, i))(X)
        if batch_norm:
            X = BatchNormalization(name='{}_{}_depthwise_BN'.format(name, i))(X)
        X = act_layer(name='{}_{}_depthwise_activation'.format(name, i))(X)

        # pointwise step (1-by-1 cross-channel mixing)
        X = Conv2D(channel, (1, 1), padding='same', use_bias=use_bias,
                   name='{}_{}_pointwise'.format(name, i))(X)
        if batch_norm:
            X = BatchNormalization(name='{}_{}_pointwise_BN'.format(name, i))(X)
        X = act_layer(name='{}_{}_pointwise_activation'.format(name, i))(X)

    return X
321 |
def ASPP_conv(X, channel, activation='ReLU', batch_norm=True, name='aspp'):
    '''
    Atrous Spatial Pyramid Pooling (ASPP).

    ASPP_conv(X, channel, activation='ReLU', batch_norm=True, name='aspp')

    ----------
    Wang, Y., Liang, B., Ding, M. and Li, J., 2019. Dense semantic labeling
    with atrous spatial pyramid pooling and decoder for high-resolution remote
    sensing imagery. Remote Sensing, 11(1), p.20.

    Input
    ----------
        X: input tensor.
        channel: number of convolution filters.
        activation: one of the `tensorflow.keras.layers` interface, e.g., ReLU.
        batch_norm: True for batch normalization, False otherwise.
        name: prefix of the created keras layers.

    Output
    ----------
        X: output tensor.

    * dilation rates are fixed to `[6, 9, 12]`.
    '''

    activation_func = eval(activation)
    # conv bias is redundant when a batch norm layer follows
    bias_flag = not batch_norm

    # static spatial size is needed to resize the pooled branch back
    shape_before = X.get_shape().as_list()

    # ----- branch b4: image-level (global average pooled) features ----- #
    b4 = GlobalAveragePooling2D(name='{}_avepool_b4'.format(name))(X)

    # restore a (batch, 1, 1, channel) layout so Conv2D can be applied
    b4 = expand_dims(expand_dims(b4, 1), 1, name='{}_expdim_b4'.format(name))

    b4 = Conv2D(channel, 1, padding='same', use_bias=bias_flag, name='{}_conv_b4'.format(name))(b4)

    if batch_norm:
        b4 = BatchNormalization(name='{}_conv_b4_BN'.format(name))(b4)

    b4 = activation_func(name='{}_conv_b4_activation'.format(name))(b4)

    # <----- tensorflow v1 resize.
    b4 = Lambda(lambda X: image.resize(X, shape_before[1:3], method='bilinear', align_corners=True),
                name='{}_resize_b4'.format(name))(b4)

    # ----- branch b0: plain 1-by-1 convolution ----- #
    b0 = Conv2D(channel, (1, 1), padding='same', use_bias=bias_flag, name='{}_conv_b0'.format(name))(X)

    if batch_norm:
        b0 = BatchNormalization(name='{}_conv_b0_BN'.format(name))(b0)

    b0 = activation_func(name='{}_conv_b0_activation'.format(name))(b0)

    # ----- dilated branches; rates fixed to [6, 9, 12] ----- #
    # previously these hardcoded activation='ReLU' and batch_norm=True,
    # silently ignoring the function arguments; they are now passed through
    # (defaults are unchanged, so default behavior is identical).
    b_r6 = Sep_CONV_stack(X, channel, kernel_size=3, stack_num=1, activation=activation,
                          dilation_rate=6, batch_norm=batch_norm, name='{}_sepconv_r6'.format(name))
    b_r9 = Sep_CONV_stack(X, channel, kernel_size=3, stack_num=1, activation=activation,
                          dilation_rate=9, batch_norm=batch_norm, name='{}_sepconv_r9'.format(name))
    b_r12 = Sep_CONV_stack(X, channel, kernel_size=3, stack_num=1, activation=activation,
                           dilation_rate=12, batch_norm=batch_norm, name='{}_sepconv_r12'.format(name))

    return concatenate([b4, b0, b_r6, b_r9, b_r12])
383 |
def CONV_output(X, n_labels, kernel_size=1, activation='Softmax', name='conv_output'):
    '''
    Convolutional layer with output activation.

    CONV_output(X, n_labels, kernel_size=1, activation='Softmax', name='conv_output')

    Input
    ----------
        X: input tensor.
        n_labels: number of classification label(s).
        kernel_size: size of 2-d convolution kernels. Default is 1-by-1.
        activation: one of the `tensorflow.keras.layers` or
                    `keras_unet_collection.activations` interface, or 'Sigmoid'.
                    Default option is 'Softmax'.
                    If None is received, then linear activation is applied.
        name: prefix of the created keras layers.

    Output
    ----------
        X: output tensor.

    '''
    X = Conv2D(n_labels, kernel_size, padding='same', use_bias=True, name=name)(X)

    # no activation keyword --> linear output
    if not activation:
        return X

    # 'Sigmoid' is special-cased: it maps to the lowercase keras Activation
    if activation == 'Sigmoid':
        return Activation('sigmoid', name='{}_activation'.format(name))(X)

    return eval(activation)(name='{}_activation'.format(name))(X)
418 |
419 |
--------------------------------------------------------------------------------
/keras_unet_collection/losses.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | import numpy as np
5 | import tensorflow as tf
6 | import tensorflow.keras.backend as K
7 |
def _crps_tf(y_true, y_pred, factor=0.05):

    '''
    Core of the (pseudo) CRPS loss:
    mean absolute error minus `factor` times the prediction spread.

    y_true: two-dimensional arrays
    y_pred: two-dimensional arrays
    factor: importance of std term
    '''

    # mean absolute error term
    abs_err = K.mean(tf.abs(y_pred - y_true))

    # spread (standard deviation) term
    spread = tf.math.reduce_std(y_pred)

    return abs_err - factor * spread
24 |
def crps2d_tf(y_true, y_pred, factor=0.05):

    '''
    (Experimental)
    An approximated continuous ranked probability score (CRPS) loss function:

        CRPS = mean_abs_err - factor * std

    * Note that the "real CRPS" = mean_abs_err - mean_pairwise_abs_diff

    Replacing the mean pairwise absolute difference with the standard
    deviation reduces complexity from O(N^2) to O(N*logN).

    ** factor > 0.1 may yield negative loss values.

    Compatible with high-level Keras training methods.

    Input
    ----------
        y_true: training target with shape=(batch_num, x, y, 1)
        y_pred: a forward pass with shape=(batch_num, x, y, 1)
        factor: relative importance of standard deviation term.

    '''

    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)

    # drop length-1 dims, e.g. (batch, x, y, 1) --> (batch, x, y)
    y_pred = tf.squeeze(y_pred)
    y_true = tf.squeeze(y_true)

    batch_num = y_pred.shape.as_list()[0]

    # average the per-sample pseudo-CRPS over the batch
    total = 0
    for ind in range(batch_num):
        total += _crps_tf(y_true[ind, ...], y_pred[ind, ...], factor=factor)

    return total / batch_num
63 |
64 |
65 | def _crps_np(y_true, y_pred, factor=0.05):
66 |
67 | '''
68 | Numpy version of _crps_tf.
69 | '''
70 |
71 | # mean absolute error
72 | mae = np.nanmean(np.abs(y_pred - y_true))
73 | dist = np.nanstd(y_pred)
74 |
75 | return mae - factor*dist
76 |
def crps2d_np(y_true, y_pred, factor=0.05):

    '''
    (Experimental)
    Numpy version of `crps2d_tf`.

    Documentation refers to `crps2d_tf`.
    '''

    # drop length-1 dims, e.g. (batch, x, y, 1) --> (batch, x, y)
    y_true = np.squeeze(y_true)
    y_pred = np.squeeze(y_pred)

    batch_num = len(y_pred)

    # average the per-sample pseudo-CRPS over the batch
    total = 0
    for ind in range(batch_num):
        total += _crps_np(y_true[ind, ...], y_pred[ind, ...], factor=factor)

    return total / batch_num
96 |
97 | # ========================= #
98 | # Dice loss and variants
99 |
def dice_coef(y_true, y_pred, const=K.epsilon()):
    '''
    Sørensen–Dice coefficient for 2-d samples.

    Input
    ----------
        y_true, y_pred: predicted outputs and targets.
        const: a constant that smooths the loss gradient and reduces numerical instabilities.

    '''

    # flatten 2-d tensors
    y_true_pos = tf.reshape(y_true, [-1])
    y_pred_pos = tf.reshape(y_pred, [-1])

    # get true pos (TP), false neg (FN), false pos (FP).
    true_pos = tf.reduce_sum(y_true_pos * y_pred_pos)
    false_neg = tf.reduce_sum(y_true_pos * (1-y_pred_pos))
    false_pos = tf.reduce_sum((1-y_true_pos) * y_pred_pos)

    # 2TP/(2TP+FP+FN); `const` is added to BOTH numerator and denominator
    # (matching `tversky_coef`), so the coefficient is 1 -- not inf/NaN --
    # when both tensors are all-zero. Previously `const` was missing from
    # the denominator, allowing division by zero.
    coef_val = (2.0 * true_pos + const)/(2.0 * true_pos + false_pos + false_neg + const)

    return coef_val
124 |
def dice(y_true, y_pred, const=K.epsilon()):
    '''
    Sørensen–Dice Loss: 1 minus the Dice coefficient.

    dice(y_true, y_pred, const=K.epsilon())

    Input
    ----------
        const: a constant that smooths the loss gradient and reduces numerical instabilities.

    '''
    # tf tensor casting
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)

    # drop length-1 dimensions before computing the coefficient
    return 1 - dice_coef(tf.squeeze(y_true), tf.squeeze(y_pred), const=const)
147 |
148 | # ========================= #
149 | # Tversky loss and variants
150 |
def tversky_coef(y_true, y_pred, alpha=0.5, const=K.epsilon()):
    '''
    Weighted Sørensen–Dice coefficient.

    Input
    ----------
        y_true, y_pred: predicted outputs and targets.
        alpha: weighting between false negatives and false positives.
        const: a constant that smooths the loss gradient and reduces numerical instabilities.

    '''
    # flatten to 1-d
    targets = tf.reshape(y_true, [-1])
    preds = tf.reshape(y_pred, [-1])

    # soft confusion-matrix terms: TP, FN, FP
    tp = tf.reduce_sum(targets * preds)
    fn = tf.reduce_sum(targets * (1 - preds))
    fp = tf.reduce_sum((1 - targets) * preds)

    # TP/(TP + a*FN + b*FP) with a+b = 1, smoothed by `const`
    return (tp + const) / (tp + alpha * fn + (1 - alpha) * fp + const)
175 |
def tversky(y_true, y_pred, alpha=0.5, const=K.epsilon()):
    '''
    Tversky Loss: 1 minus the Tversky coefficient.

    tversky(y_true, y_pred, alpha=0.5, const=K.epsilon())

    ----------
    Hashemi, S.R., Salehi, S.S.M., Erdogmus, D., Prabhu, S.P., Warfield, S.K. and Gholipour, A., 2018.
    Tversky as a loss function for highly unbalanced image segmentation using 3d fully convolutional deep networks.
    arXiv preprint arXiv:1803.11078.

    Input
    ----------
        alpha: tunable parameter within [0, 1]. Alpha handles imbalance classification cases.
        const: a constant that smooths the loss gradient and reduces numerical instabilities.

    '''
    # tf tensor casting
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)

    # drop length-1 dimensions before computing the coefficient
    return 1 - tversky_coef(tf.squeeze(y_true), tf.squeeze(y_pred),
                            alpha=alpha, const=const)
204 |
def focal_tversky(y_true, y_pred, alpha=0.5, gamma=4/3, const=K.epsilon()):

    '''
    Focal Tversky Loss (FTL): (Tversky loss) ** (1/gamma).

    focal_tversky(y_true, y_pred, alpha=0.5, gamma=4/3)

    ----------
    Abraham, N. and Khan, N.M., 2019, April. A novel focal tversky loss function with improved
    attention u-net for lesion segmentation. In 2019 IEEE 16th International Symposium on Biomedical Imaging
    (ISBI 2019) (pp. 683-687). IEEE.

    ----------
    Input
        alpha: tunable parameter within [0, 1]. Alpha handles imbalance classification cases
        gamma: tunable parameter within [1, 3].
        const: a constant that smooths the loss gradient and reduces numerical instabilities.

    '''
    # tf tensor casting
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)

    # drop length-1 dimensions
    y_pred = tf.squeeze(y_pred)
    y_true = tf.squeeze(y_true)

    # focal exponent applied to the plain Tversky loss
    tversky_loss = 1 - tversky_coef(y_true, y_pred, alpha=alpha, const=const)
    return tf.math.pow(tversky_loss, 1/gamma)
236 |
237 | # ========================= #
238 | # MS-SSIM
239 |
def ms_ssim(y_true, y_pred, **kwargs):
    """
    Multiscale structural similarity (MS-SSIM) loss.

    ms_ssim(y_true, y_pred, **tf_ssim_kw)

    ----------
    Wang, Z., Simoncelli, E.P. and Bovik, A.C., 2003, November. Multiscale structural similarity for image quality assessment.
    In The Thirty-Seventh Asilomar Conference on Signals, Systems & Computers, 2003 (Vol. 2, pp. 1398-1402). Ieee.

    ----------
    Input
        kwargs: keywords of `tf.image.ssim_multiscale`
                https://www.tensorflow.org/api_docs/python/tf/image/ssim_multiscale

        *Issues of `tf.image.ssim_multiscale` refers to:
        https://stackoverflow.com/questions/57127626/error-in-calculation-of-inbuilt-ms-ssim-function-in-tensorflow

    """

    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)

    # drop length-1 dimensions.
    # NOTE(review): `tf.image.ssim_multiscale` expects images with a trailing
    # channel axis; squeezing looks like it could drop a length-1 channel dim
    # for single-channel inputs -- confirm against the intended input shapes.
    y_pred = tf.squeeze(y_pred)
    y_true = tf.squeeze(y_true)

    tf_ms_ssim = tf.image.ssim_multiscale(y_true, y_pred, **kwargs)

    # similarity in [0, 1] --> loss (lower is better)
    return 1 - tf_ms_ssim
269 |
270 | # ======================== #
271 |
def iou_box_coef(y_true, y_pred, mode='giou', dtype=tf.float32):

    """
    Intersection over Union (IoU) and generalized IoU coefficients for bounding boxes.

    iou_box_coef(y_true, y_pred, mode='giou', dtype=tf.float32)

    ----------
    Rezatofighi, H., Tsoi, N., Gwak, J., Sadeghian, A., Reid, I. and Savarese, S., 2019.
    Generalized intersection over union: A metric and a loss for bounding box regression.
    In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 658-666).

    ----------
    Input
        y_true: the target bounding box.
        y_pred: the predicted bounding box.

        Elements of a bounding box should be organized as: [y_min, x_min, y_max, x_max].

        mode: 'iou' for IoU coeff (i.e., Jaccard index);
              'giou' for generalized IoU coeff.

        dtype: the data type of input tensors.
               Default is tf.float32.

    """

    zero = tf.convert_to_tensor(0.0, dtype)

    # unpack [y_min, x_min, y_max, x_max] for both boxes
    ymin_t, xmin_t, ymax_t, xmax_t = tf.unstack(y_true, 4, axis=-1)
    ymin_p, xmin_p, ymax_p, xmax_p = tf.unstack(y_pred, 4, axis=-1)

    # box areas (degenerate/negative extents clipped to zero)
    area_true = tf.maximum(zero, xmax_t - xmin_t) * tf.maximum(zero, ymax_t - ymin_t)
    area_pred = tf.maximum(zero, xmax_p - xmin_p) * tf.maximum(zero, ymax_p - ymin_p)

    # overlap of the two boxes
    w_inter = tf.maximum(zero, tf.minimum(xmax_t, xmax_p) - tf.maximum(xmin_t, xmin_p))
    h_inter = tf.maximum(zero, tf.minimum(ymax_t, ymax_p) - tf.maximum(ymin_t, ymin_p))
    area_inter = w_inter * h_inter

    # IoU = intersection / union
    area_union = area_true + area_pred - area_inter
    iou = tf.math.divide_no_nan(area_inter, area_union)

    if mode == "iou":

        return iou

    else:

        # smallest enclosing box of the two, for the generalized IoU penalty
        w_enc = tf.maximum(zero, tf.maximum(xmax_t, xmax_p) - tf.minimum(xmin_t, xmin_p))
        h_enc = tf.maximum(zero, tf.maximum(ymax_t, ymax_p) - tf.minimum(ymin_t, ymin_p))
        area_enc = w_enc * h_enc

        # generalized IoU
        return iou - tf.math.divide_no_nan(area_enc - area_union, area_enc)
350 |
def iou_box(y_true, y_pred, mode='giou', dtype=tf.float32):
    """
    Intersection over Union (IoU) and generalized IoU losses for bounding boxes:
    1 minus the corresponding coefficient.

    iou_box(y_true, y_pred, mode='giou', dtype=tf.float32)

    ----------
    Rezatofighi, H., Tsoi, N., Gwak, J., Sadeghian, A., Reid, I. and Savarese, S., 2019.
    Generalized intersection over union: A metric and a loss for bounding box regression.
    In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 658-666).

    ----------
    Input
        y_true: the target bounding box.
        y_pred: the predicted bounding box.

        Elements of a bounding box should be organized as: [y_min, x_min, y_max, x_max].

        mode: 'iou' for IoU coeff (i.e., Jaccard index);
              'giou' for generalized IoU coeff.

        dtype: the data type of input tensors.
               Default is tf.float32.

    """
    # tensor conversion and casting to the requested dtype
    y_pred = tf.cast(tf.convert_to_tensor(y_pred), dtype)
    y_true = tf.cast(y_true, dtype)

    # drop length-1 dimensions
    y_pred = tf.squeeze(y_pred)
    y_true = tf.squeeze(y_true)

    return 1 - iou_box_coef(y_true, y_pred, mode=mode, dtype=dtype)
386 |
387 |
def iou_seg(y_true, y_pred, dtype=tf.float32):
    """
    Intersection over Union (IoU) loss for segmentation maps.

    iou_seg(y_true, y_pred, dtype=tf.float32)

    ----------
    Rahman, M.A. and Wang, Y., 2016, December. Optimizing intersection-over-union in deep neural networks for
    image segmentation. In International symposium on visual computing (pp. 234-244). Springer, Cham.

    ----------
    Input
        y_true: segmentation targets, c.f. `keras.losses.categorical_crossentropy`
        y_pred: segmentation predictions.

        dtype: the data type of input tensors.
               Default is tf.float32.

    """

    # tensor conversion and casting
    y_pred = tf.cast(tf.convert_to_tensor(y_pred), dtype)
    y_true = tf.cast(y_true, y_pred.dtype)

    # drop length-1 dimensions, then flatten to 1-d
    targets = tf.reshape(tf.squeeze(y_true), [-1])
    preds = tf.reshape(tf.squeeze(y_pred), [-1])

    # soft intersection and union
    area_inter = tf.reduce_sum(tf.multiply(targets, preds))
    area_union = tf.reduce_sum(targets) + tf.reduce_sum(preds) - area_inter

    # `divide_no_nan` guards the empty-masks case
    return 1 - tf.math.divide_no_nan(area_inter, area_union)
426 |
427 | # ========================= #
428 | # Semi-hard triplet
429 |
def triplet_1d(y_true, y_pred, N, margin=5.0):

    '''
    (Experimental)
    Semi-hard triplet loss with one-dimensional vectors of anchor, positive, and negative.

    triplet_1d(y_true, y_pred, N, margin=5.0)

    Input
    ----------
        y_true: a dummy input, not used within this function. Appeared as a requirement
                of the tf.keras loss-function format.
        y_pred: a single pass of triplet training, with `shape=(batch_num, 3*embeded_vector_size)`.
                i.e., `y_pred` is the ordered and concatenated anchor, positive, and negative embeddings.
        N: size (dimensions) of embedded vectors
        margin: a positive number that prevents negative loss.

    '''

    # split the concatenated (anchor | positive | negative) embeddings
    anchor = y_pred[:, 0:N]
    positive = y_pred[:, N:2*N]
    negative = y_pred[:, 2*N:]

    # squared Euclidean distances to the anchor
    d_pos = tf.reduce_sum(tf.square(anchor - positive), 1)
    d_neg = tf.reduce_sum(tf.square(anchor - negative), 1)

    # hinge on the margin, averaged over the batch
    return tf.reduce_mean(tf.maximum(0., margin + d_pos - d_neg))
--------------------------------------------------------------------------------
/keras_unet_collection/models.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | from keras_unet_collection._model_unet_2d import unet_2d
5 | from keras_unet_collection._model_vnet_2d import vnet_2d
6 | from keras_unet_collection._model_unet_plus_2d import unet_plus_2d
7 | from keras_unet_collection._model_r2_unet_2d import r2_unet_2d
8 | from keras_unet_collection._model_att_unet_2d import att_unet_2d
9 | from keras_unet_collection._model_resunet_a_2d import resunet_a_2d
10 | from keras_unet_collection._model_u2net_2d import u2net_2d
11 | from keras_unet_collection._model_unet_3plus_2d import unet_3plus_2d
12 | from keras_unet_collection._model_transunet_2d import transunet_2d
13 | from keras_unet_collection._model_swin_unet_2d import swin_unet_2d
14 |
--------------------------------------------------------------------------------
/keras_unet_collection/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 |
4 | #import tensorflow as tf
5 | from tensorflow import keras
6 |
7 | from PIL import Image
8 |
9 | def dummy_loader(model_path):
10 | '''
11 | Load a stored keras model and return its weights.
12 |
13 | Input
14 | ----------
15 | The file path of the stored keras model.
16 |
17 | Output
18 | ----------
19 | Weights of the model.
20 |
21 | '''
22 | backbone = keras.models.load_model(model_path, compile=False)
23 | W = backbone.get_weights()
24 | return W
25 |
26 | def image_to_array(filenames, size, channel):
27 | '''
28 | Converting RGB images to numpy arrays.
29 |
30 | Input
31 | ----------
32 | filenames: an iterable of the path of image files
33 | size: the output size (height == width) of image.
34 | Processed through PIL.Image.NEAREST
35 | channel: number of image channels, e.g. channel=3 for RGB.
36 |
37 | Output
38 | ----------
39 | An array with shape = (filenum, size, size, channel)
40 |
41 | '''
42 |
43 | # number of files
44 | L = len(filenames)
45 |
46 | # allocation
47 | out = np.empty((L, size, size, channel))
48 |
49 | # loop over filenames
50 | if channel == 1:
51 | for i, name in enumerate(filenames):
52 | with Image.open(name) as pixio:
53 | pix = pixio.resize((size, size), Image.NEAREST)
54 | out[i, ..., 0] = np.array(pix)
55 | else:
56 | for i, name in enumerate(filenames):
57 | with Image.open(name) as pixio:
58 | pix = pixio.resize((size, size), Image.NEAREST)
59 | out[i, ...] = np.array(pix)[..., :channel]
60 | return out[:, ::-1, ...]
61 |
62 | def shuffle_ind(L):
63 | '''
64 | Generating random shuffled indices.
65 |
66 | Input
67 | ----------
68 | L: an int that defines the largest index
69 |
70 | Output
71 | ----------
72 | a numpy array of shuffled indices with shape = (L,)
73 | '''
74 |
75 | ind = np.arange(L)
76 | np.random.shuffle(ind)
77 | return ind
78 |
79 | def freeze_model(model, freeze_batch_norm=False):
80 | '''
81 | freeze a keras model
82 |
83 | Input
84 | ----------
85 | model: a keras model
86 | freeze_batch_norm: False for not freezing batch notmalization layers
87 | '''
88 | if freeze_batch_norm:
89 | for layer in model.layers:
90 | layer.trainable = False
91 | else:
92 | from tensorflow.keras.layers import BatchNormalization
93 | for layer in model.layers:
94 | if isinstance(layer, BatchNormalization):
95 | layer.trainable = True
96 | else:
97 | layer.trainable = False
98 | return model
99 |
100 |
101 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow_gpu==2.5.0
2 | numpy==1.19.5
3 | Pillow==8.3.1
4 | tensorflow==2.5.0
5 |
--------------------------------------------------------------------------------