├── .github └── FUNDING.yml ├── .gitignore ├── .gitmodules ├── .travis.yml ├── CHANGELOG.md ├── LICENSE ├── MANIFEST.in ├── README.rst ├── __init__.py ├── docs ├── Makefile ├── api.rst ├── conf.py ├── index.rst ├── install.rst ├── support.rst └── tutorial.rst ├── examples ├── binary segmentation (camvid).ipynb └── multiclass segmentation (camvid).ipynb ├── images ├── fpn.png ├── linknet.png ├── logo.png ├── pspnet.png └── unet.png ├── requirements.txt ├── segmentation_models ├── __init__.py ├── __version__.py ├── backbones │ ├── __init__.py │ ├── backbones_factory.py │ ├── inception_resnet_v2.py │ └── inception_v3.py ├── base │ ├── __init__.py │ ├── functional.py │ └── objects.py ├── losses.py ├── metrics.py ├── models │ ├── __init__.py │ ├── _common_blocks.py │ ├── _utils.py │ ├── fpn.py │ ├── linknet.py │ ├── pspnet.py │ └── unet.py └── utils.py ├── setup.py └── tests ├── test_metrics.py ├── test_models.py └── test_utils.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: qubvel 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: qubvel 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .idea/ 6 | *ipynb 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qubvel/segmentation_models/5d24bbfb28af6134e25e2c0b79e7727f6c0491d0/.gitmodules -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: required 2 | dist: trusty 3 | language: python 4 | 5 | matrix: 6 | include: 7 | - python: 3.6 8 | env: KERAS_VERSION=2.2.4 9 | - python: 3.6 10 | env: SM_FRAMEWORK='tf.keras' 11 | 12 | git: 13 | submodules: true 14 | 15 | 16 | install: 17 | # code below is taken from http://conda.pydata.org/docs/travis.html 18 | # We do this conditionally because it saves us some downloading if the 19 | # version is the same. 20 | - wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh; 21 | - bash miniconda.sh -b -p $HOME/miniconda 22 | - export PATH="$HOME/miniconda/bin:$PATH" 23 | - hash -r 24 | - conda config --set always_yes yes --set changeps1 no 25 | - conda update -q conda 26 | # Useful for debugging any issues with conda 27 | - conda info -a 28 | 29 | - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION pytest pandas 30 | - source activate test-environment 31 | - pip install --only-binary=numpy,scipy numpy nose scipy matplotlib h5py theano 32 | 33 | 34 | # set library path 35 | - export LD_LIBRARY_PATH=$HOME/miniconda/envs/test-environment/lib/:$LD_LIBRARY_PATH 36 | - conda install mkl mkl-service 37 | 38 | # install TensorFlow (CPU version). 
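# note: the 1.14 pin below matches the Keras 2.2.4 / keras-applications stack used in this matrix
# (an assumption — other TF 1.x releases may also work, but are untested here)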
39 | - pip install tensorflow==1.14
40 |
41 | # install keras
42 | - if [ -z $KERAS_VERSION ]; then
43 | echo "Using tf.keras";
44 | else
45 | echo "Using keras";
46 | pip install keras==$KERAS_VERSION;
47 | fi
48 |
49 | # install lib in develop mode
50 | - pip install -e .[tests]
51 |
52 | # detect whether only a markdown file was changed
53 | - export DOC_ONLY_CHANGED=False;
54 | - if [ $(git diff --name-only HEAD~1 | wc -l) == "1" ] && [[ "$(git diff --name-only HEAD~1)" == *"md" ]]; then
55 | export DOC_ONLY_CHANGED=True;
56 | fi
57 |
58 | # command to run tests
59 | script:
60 | - export MKL_THREADING_LAYER="GNU"
61 | - mkdir -p ~/.keras/models
62 | # set up keras backend
63 | - if [[ "$DOC_ONLY_CHANGED" == "False" ]]; then
64 | PYTHONPATH=$PWD:$PYTHONPATH py.test tests/;
65 | fi
66 |
67 | deploy:
68 | provider: pypi
69 | user: qubvel
70 | password:
71 | secure: QA/UJmkXGlXy/6C8X0E/bPf4izu3rJsztaEmqIM1npxPiv2Uf4WFs43vxkMXwfHrflocdfw8SBM8bWnbunGT2SvDdo/MMCMpol7unE74T/RbODYl6aiJWVM3QKOXL8pQD0oQ+03L1YK3nCeSQdePINEPmuFmvwyO40q8Dwv8HBZIGZlEo4SK4xr8ekxfmtbezxQ7vUL3sNcvCJDXrZX/4UdXrhdRk+zYoN3dv8NmM4FmChajq/m5Am9OPdbdUBHmIYmvk7L3IpwJeMMpG5FVdGNVwYj7XNHlcy+KZ2/CKn9EpslRDPxY4650654PmhSZWDctZG7jiFWLCZBUvowiyAOPZknZNgdu5gJAfdg37XS9IP3HgTZN6Jb5Bm0by3IlKt+dTzyJQcUnRql5B1wwEI0XO3/YWQe1GQQphIO1bli9hT8n8xNDNjc49vDlu4zKyaYnQmLhqNxkyeruXSTpc8qTITuS+EGgkAUrrBj/IaFcutIg9WOzvJ3nZO8X8UG7LlyQx4AOpfHP6bynAmlT+UFccCEq66Zoh7teWLk0lUekuYST2iQJ3pwFoQGYJRCsmxsz7J0B9ayFVVT/fg+GZpZm1oTnnJ27hh8LZWv/Cr/WHOBYc3qvigWx4pDssJ+O6z7de3aWrGvzAVgXr190fRdP55a34HhNbiKZ0YWmrTs=
72 | on:
73 | tags: true
74 | skip_existing: true
75 | distributions: "sdist bdist_wheel"
76 |
-------------------------------------------------------------------------------- /CHANGELOG.md: --------------------------------------------------------------------------------
1 | # Change Log
2 |
3 | **Version 1.0.0**
4 |
5 | ###### Areas of improvement
6 | - Support for `keras` and `tf.keras`
7 | - Losses as classes, base loss operations (sum of losses, multiplied loss)
8 | - NCHW and NHWC support
9 | - Removed pure tf operations to work with other keras backends
10 | - Reduced the number of custom objects for better model serialization and deserialization
11 |
12 | ###### New features
13 | - New backbones: EfficientNetB[0-7]
14 | - New loss function: Focal loss
15 | - New metrics: Precision, Recall
16 |
17 | ###### API changes
18 | - `get_preprocessing` moved from `sm.backbones.get_preprocessing` to `sm.get_preprocessing`
19 |
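For example, a minimal sketch of the `1.0.0` call (mirroring the `0.1.2` snippet further below):

```python
import segmentation_models as sm

# preprocessing is now fetched from the package root
preprocess_input = sm.get_preprocessing('resnet34')
X = preprocess_input(x)
```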
20 | **Version 0.2.1**
21 |
22 | ###### Areas of improvement
23 |
24 | - Added `set_regularization` function
25 | - Added `beta` argument to dice loss
26 | - Added `threshold` argument for metrics
27 | - Fixed `preprocess_input` for mobilenets
28 | - Fixed missing parameter `interpolation` in `ResizeImage` layer config
29 | - Some minor improvements in docs, fixed typos
30 |
31 | **Version 0.2.0**
32 |
33 | ###### Areas of improvement
34 |
35 | - New backbones (SE-ResNets, SE-ResNeXts, SENet154, MobileNets)
36 | - Metrics:
37 | - `iou_score` / `jaccard_score`
38 | - `f_score` / `dice_score`
39 | - Losses:
40 | - `jaccard_loss`
41 | - `bce_jaccard_loss`
42 | - `cce_jaccard_loss`
43 | - `dice_loss`
44 | - `bce_dice_loss`
45 | - `cce_dice_loss`
46 | - Documentation [Read the Docs](https://segmentation-models.readthedocs.io)
47 | - Tests + Travis-CI
48 |
49 | ###### API changes
50 |
51 | - Some parameters renamed (see API docs)
52 | - `encoder_freeze=True` does not freeze the BatchNormalization layers of the encoder
53 |
54 | ###### Thanks
55 |
56 | [@IlyaOvodov](https://github.com/IlyaOvodov) [#15](https://github.com/qubvel/segmentation_models/issues/15) [#37](https://github.com/qubvel/segmentation_models/pull/37) investigation of the `align_corners` parameter in the `ResizeImage` layer
57 | [@NiklasDL](https://github.com/NiklasDL) [#29](https://github.com/qubvel/segmentation_models/issues/29) investigation of the convolution kernel in the PSPNet final layers
58 |
59 | **Version 0.1.2**
60 |
61 | ###### Areas of improvement
62 |
63 | - Added PSPModel
64 | - Preprocessing functions for all backbones:
65 | ```python
66 | from segmentation_models.backbones import get_preprocessing
67 |
68 | preprocessing_fn = get_preprocessing('resnet34')
69 | X = preprocessing_fn(x)
70 | ```
71 | ###### API changes
72 | - Default param `use_batchnorm=True` for all decoders
73 | - FPN model `Upsample2D` layer renamed to `ResizeImage`
74 |
75 | **Version 0.1.1**
76 | - Added `Linknet` model
77 | - Keras 2.2+ compatibility (fixed import of `_obtain_input_shape`)
78 | - Small code improvements and bug fixes
79 |
80 | **Version 0.1.0**
81 | - `Unet` and `FPN` models
82 |
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------
1 | The MIT License
2 |
3 | Copyright (c) 2018, Pavel Yakubovskiy
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
-------------------------------------------------------------------------------- /MANIFEST.in: --------------------------------------------------------------------------------
1 | include README.rst LICENSE requirements.txt
2 |
-------------------------------------------------------------------------------- /README.rst: --------------------------------------------------------------------------------

5 | Python library with Neural Networks for Image Segmentation based on Keras and TensorFlow.

16 |
17 |
18 | **The main features** of this library are:
19 |
20 | - High level API (just two lines of code to create a segmentation model)
21 | - **4** model architectures for binary and multi-class image segmentation
22 | (including the legendary **Unet**)
23 | - **25** available backbones for each architecture
24 | - All backbones have **pre-trained** weights for faster and better
25 | convergence
26 | - Helpful segmentation losses (Jaccard, Dice, Focal) and metrics (IoU, F-score)
27 |
28 | **Important note**
29 |
30 | Some models of version ``1.*`` are not compatible with models trained with previous versions;
31 | if you have such models and want to load them, roll back with:
32 |
33 | $ pip install -U segmentation-models==0.2.1
34 |
35 | Table of Contents
36 | ~~~~~~~~~~~~~~~~~
37 | - `Quick start`_
38 | - `Simple training pipeline`_
39 | - `Examples`_
40 | - `Models and Backbones`_
41 | - `Installation`_
42 | - `Documentation`_
43 | - `Change log`_
44 | - `Citing`_
45 | - `License`_
46 |
47 | Quick start
48 | ~~~~~~~~~~~
49 | The library is built to work with both the Keras and TensorFlow Keras (``tf.keras``) frameworks.
50 |
51 | .. code:: python
52 |
53 | import segmentation_models as sm
54 | # Segmentation Models: using `keras` framework.
55 |
56 | By default it tries to import ``keras``; if that is not installed, it falls back to the ``tensorflow.keras`` framework.
57 | There are several ways to choose the framework:
58 |
59 | - Provide the environment variable ``SM_FRAMEWORK=keras`` / ``SM_FRAMEWORK=tf.keras`` before importing ``segmentation_models``
60 | - Change the framework with ``sm.set_framework('keras')`` / ``sm.set_framework('tf.keras')``
61 |
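A minimal sketch of both options (assuming ``tf.keras`` is the desired framework):

.. code:: python

    import os

    # option 1: set the environment variable before the first import
    os.environ['SM_FRAMEWORK'] = 'tf.keras'
    import segmentation_models as sm

    # option 2: switch the framework after import
    sm.set_framework('tf.keras')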
62 | You can also specify which ``image_data_format`` to use; segmentation-models works with both ``channels_last`` and ``channels_first``.
63 | This can be useful for further model conversion to the Nvidia TensorRT format or for optimizing the model for CPU/GPU computations.
64 |
65 | .. code:: python
66 |
67 | import keras
68 | # or from tensorflow import keras
69 |
70 | keras.backend.set_image_data_format('channels_last')
71 | # or keras.backend.set_image_data_format('channels_first')
72 |
73 | A created segmentation model is just an instance of a Keras ``Model``, which can be built as easily as:
74 |
75 | .. code:: python
76 |
77 | model = sm.Unet()
78 |
79 | Depending on the task, you can change the network architecture by choosing backbones with fewer or more parameters and use pretrained weights to initialize it:
80 |
81 | .. code:: python
82 |
83 | model = sm.Unet('resnet34', encoder_weights='imagenet')
84 |
85 | Change the number of output classes in the model (choose your case):
86 |
87 | .. code:: python
88 |
89 | # binary segmentation (these parameters are the defaults when you call Unet('resnet34'))
90 | model = sm.Unet('resnet34', classes=1, activation='sigmoid')
91 |
92 | .. code:: python
93 |
94 | # multiclass segmentation with non-overlapping class masks (your classes + background)
95 | model = sm.Unet('resnet34', classes=3, activation='softmax')
96 |
97 | .. code:: python
98 |
99 | # multiclass segmentation with independent overlapping/non-overlapping class masks
100 | model = sm.Unet('resnet34', classes=3, activation='sigmoid')
101 |
102 |
103 | Change the input shape of the model:
104 |
105 | .. code:: python
106 |
107 | # if you set the number of input channels to something other than 3, you have to set encoder_weights=None
108 | # (how to handle such a case with encoder_weights='imagenet' is described in the docs)
109 | model = sm.Unet('resnet34', input_shape=(None, None, 6), encoder_weights=None)
110 |
111 | Simple training pipeline
112 | ~~~~~~~~~~~~~~~~~~~~~~~~
113 |
114 | .. code:: python
115 |
116 | import segmentation_models as sm
117 |
118 | BACKBONE = 'resnet34'
119 | preprocess_input = sm.get_preprocessing(BACKBONE)
120 |
121 | # load your data
122 | x_train, y_train, x_val, y_val = load_data(...)
123 |
124 | # preprocess input
125 | x_train = preprocess_input(x_train)
126 | x_val = preprocess_input(x_val)
127 |
128 | # define model
129 | model = sm.Unet(BACKBONE, encoder_weights='imagenet')
130 | model.compile(
131 | 'Adam',
132 | loss=sm.losses.bce_jaccard_loss,
133 | metrics=[sm.metrics.iou_score],
134 | )
135 |
136 | # fit model
137 | # if you use a data generator, use model.fit_generator(...) instead of model.fit(...)
138 | # more about `fit_generator` here: https://keras.io/models/sequential/#fit_generator
139 | model.fit(
140 | x=x_train,
141 | y=y_train,
142 | batch_size=16,
143 | epochs=100,
144 | validation_data=(x_val, y_val),
145 | )
146 |
147 | The same manipulations can be done with ``Linknet``, ``PSPNet`` and ``FPN``. For more detailed information about the models' API and use cases, `Read the Docs <https://segmentation-models.readthedocs.io>`__.
148 |
149 | Examples
150 | ~~~~~~~~
151 | Model training examples:
152 | - [Jupyter Notebook] Binary segmentation (`cars`) on the CamVid dataset `here <https://github.com/qubvel/segmentation_models/blob/master/examples/binary%20segmentation%20(camvid).ipynb>`__.
153 | - [Jupyter Notebook] Multi-class segmentation (`cars`, `pedestrians`) on the CamVid dataset `here <https://github.com/qubvel/segmentation_models/blob/master/examples/multiclass%20segmentation%20(camvid).ipynb>`__.
154 |
155 | Models and Backbones
156 | ~~~~~~~~~~~~~~~~~~~~
157 | **Models**
158 |
159 | - `Unet`_
160 | - `FPN`_
161 | - `Linknet`_
162 | - `PSPNet`_
163 |
164 | ============= ==============
165 | Unet Linknet
166 | ============= ==============
167 | |unet_image| |linknet_image|
168 | ============= ==============
169 |
170 | ============= ==============
171 | PSPNet FPN
172 | ============= ==============
173 | |psp_image| |fpn_image|
174 | ============= ==============
175 |
176 | .. _Unet: https://arxiv.org/abs/1505.04597
177 | .. _Linknet: https://arxiv.org/abs/1707.03718
178 | .. _PSPNet: https://arxiv.org/abs/1612.01105
179 | .. _FPN: http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf
180 |
181 | .. |unet_image| image:: https://github.com/qubvel/segmentation_models/blob/master/images/unet.png
182 | .. |linknet_image| image:: https://github.com/qubvel/segmentation_models/blob/master/images/linknet.png
183 | .. |psp_image| image:: https://github.com/qubvel/segmentation_models/blob/master/images/pspnet.png
184 | .. |fpn_image| image:: https://github.com/qubvel/segmentation_models/blob/master/images/fpn.png
185 |
186 | **Backbones**
187 |
188 | .. table::
189 |
190 | ============= =====
191 | Type Names
192 | ============= =====
193 | VGG ``'vgg16' 'vgg19'``
194 | ResNet ``'resnet18' 'resnet34' 'resnet50' 'resnet101' 'resnet152'``
195 | SE-ResNet ``'seresnet18' 'seresnet34' 'seresnet50' 'seresnet101' 'seresnet152'``
196 | ResNeXt ``'resnext50' 'resnext101'``
197 | SE-ResNeXt ``'seresnext50' 'seresnext101'``
198 | SENet154 ``'senet154'``
199 | DenseNet ``'densenet121' 'densenet169' 'densenet201'``
200 | Inception ``'inceptionv3' 'inceptionresnetv2'``
201 | MobileNet ``'mobilenet' 'mobilenetv2'``
202 | EfficientNet ``'efficientnetb0' 'efficientnetb1' 'efficientnetb2' 'efficientnetb3' 'efficientnetb4' 'efficientnetb5' 'efficientnetb6' 'efficientnetb7'``
203 | ============= =====
204 |
205 | .. epigraph::
206 | All backbones have weights trained on the 2012 ILSVRC ImageNet dataset (``encoder_weights='imagenet'``).
207 |
208 |
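The full list of backbone names can also be queried at runtime (a sketch using the helper exported in ``segmentation_models/__init__.py``):

.. code:: python

    import segmentation_models as sm

    print(sm.get_available_backbone_names())
    # e.g. ['vgg16', 'vgg19', 'resnet18', ...]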
209 | Installation
210 | ~~~~~~~~~~~~
211 |
212 | **Requirements**
213 |
214 | 1) Python 3
215 | 2) keras >= 2.2.0 or tensorflow >= 1.13
216 | 3) keras-applications >= 1.0.7, <=1.0.8
217 | 4) image-classifiers == 1.0.*
218 | 5) efficientnet == 1.0.*
219 |
220 | **PyPI stable package**
221 |
222 | .. code:: bash
223 |
224 | $ pip install -U segmentation-models
225 |
226 | **PyPI latest package**
227 |
228 | .. code:: bash
229 |
230 | $ pip install -U --pre segmentation-models
231 |
232 | **Source latest version**
233 |
234 | .. code:: bash
235 |
236 | $ pip install git+https://github.com/qubvel/segmentation_models
237 |
238 | Documentation
239 | ~~~~~~~~~~~~~
240 | The latest **documentation** is available on `Read the
241 | Docs <https://segmentation-models.readthedocs.io>`__
242 |
243 | Change Log
244 | ~~~~~~~~~~
245 | To see important changes between versions, look at CHANGELOG.md_
246 |
247 | Citing
248 | ~~~~~~~~
249 |
250 | .. code::
251 |
252 | @misc{Yakubovskiy:2019,
253 | Author = {Pavel Iakubovskii},
254 | Title = {Segmentation Models},
255 | Year = {2019},
256 | Publisher = {GitHub},
257 | Journal = {GitHub repository},
258 | Howpublished = {\url{https://github.com/qubvel/segmentation_models}}
259 | }
260 |
261 | License
262 | ~~~~~~~
263 | The project is distributed under the `MIT Licence`_.
264 |
265 | .. _CHANGELOG.md: https://github.com/qubvel/segmentation_models/blob/master/CHANGELOG.md
266 | .. _`MIT Licence`: https://github.com/qubvel/segmentation_models/blob/master/LICENSE
267 |
-------------------------------------------------------------------------------- /__init__.py: --------------------------------------------------------------------------------
1 | from .segmentation_models import *
2 |
-------------------------------------------------------------------------------- /docs/Makefile: --------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | SOURCEDIR = .
8 | BUILDDIR = _build
9 |
10 | # Put it first so that "make" without argument is like "make help".
11 | help:
12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
13 |
14 | .PHONY: help Makefile
15 |
16 | # Catch-all target: route all unknown targets to Sphinx using the new
17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
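# For example, `make html` or `make latexpdf` fall through to the target below
# and are routed to `sphinx-build -M <target>` (illustration, not part of the original file).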
18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | Segmentation Models Python API 2 | ============================== 3 | 4 | 5 | Getting started with segmentation models is easy. 6 | 7 | Unet 8 | ~~~~ 9 | .. autofunction:: segmentation_models.Unet 10 | 11 | Linknet 12 | ~~~~~~~ 13 | .. autofunction:: segmentation_models.Linknet 14 | 15 | FPN 16 | ~~~ 17 | .. autofunction:: segmentation_models.FPN 18 | 19 | PSPNet 20 | ~~~~~~ 21 | .. autofunction:: segmentation_models.PSPNet 22 | 23 | metrics 24 | ~~~~~~~ 25 | .. autofunction:: segmentation_models.metrics.IOUScore 26 | .. autofunction:: segmentation_models.metrics.FScore 27 | 28 | losses 29 | ~~~~~~ 30 | .. autofunction:: segmentation_models.losses.JaccardLoss 31 | .. autofunction:: segmentation_models.losses.DiceLoss 32 | .. autofunction:: segmentation_models.losses.BinaryCELoss 33 | .. autofunction:: segmentation_models.losses.CategoricalCELoss 34 | .. autofunction:: segmentation_models.losses.BinaryFocalLoss 35 | .. autofunction:: segmentation_models.losses.CategoricalFocalLoss 36 | 37 | utils 38 | ~~~~~ 39 | .. autofunction:: segmentation_models.utils.set_trainable -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 14 | # 15 | # import os 16 | # import sys 17 | # sys.path.insert(0, os.path.abspath('.')) 18 | 19 | # -- Project information ----------------------------------------------------- 20 | import sys 21 | sys.path.append('..') 22 | 23 | project = u'Segmentation Models' 24 | copyright = u'2018, Pavel Yakubovskiy' 25 | author = u'Pavel Yakubovskiy' 26 | 27 | # The short X.Y version 28 | version = u'' 29 | # The full version, including alpha/beta/rc tags 30 | release = u'0.1.2' 31 | 32 | 33 | # -- General configuration --------------------------------------------------- 34 | 35 | # If your documentation needs a minimal Sphinx version, state it here. 36 | # 37 | # needs_sphinx = '1.0' 38 | 39 | # Add any Sphinx extension module names here, as strings. They can be 40 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 41 | # ones. 42 | extensions = [ 43 | 'sphinx.ext.autodoc', 44 | 'sphinx.ext.coverage', 45 | 'sphinx.ext.napoleon', 46 | ] 47 | 48 | # Add any paths that contain templates here, relative to this directory. 49 | templates_path = ['_templates'] 50 | 51 | # The suffix(es) of source filenames. 52 | # You can specify multiple suffix as a list of string: 53 | # 54 | # source_suffix = ['.rst', '.md'] 55 | source_suffix = '.rst' 56 | 57 | # The master toctree document. 58 | master_doc = 'index' 59 | 60 | # The language for content autogenerated by Sphinx. 
Refer to documentation 61 | # for a list of supported languages. 62 | # 63 | # This is also used if you do content translation via gettext catalogs. 64 | # Usually you set "language" from the command line for these cases. 65 | language = None 66 | 67 | # List of patterns, relative to source directory, that match files and 68 | # directories to ignore when looking for source files. 69 | # This pattern also affects html_static_path and html_extra_path. 70 | exclude_patterns = [u'_build', 'Thumbs.db', '.DS_Store'] 71 | 72 | # The name of the Pygments (syntax highlighting) style to use. 73 | pygments_style = None 74 | 75 | 76 | # -- Options for HTML output ------------------------------------------------- 77 | 78 | # The theme to use for HTML and HTML Help pages. See the documentation for 79 | # a list of builtin themes. 80 | # 81 | # -- Theme setup ------------------------------------------------------------- 82 | 83 | import sphinx_rtd_theme 84 | 85 | html_theme = 'sphinx_rtd_theme' 86 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 87 | 88 | # Theme options are theme-specific and customize the look and feel of a theme 89 | # further. For a list of options available for each theme, see the 90 | # documentation. 91 | # 92 | # html_theme_options = {} 93 | 94 | # Add any paths that contain custom static files (such as style sheets) here, 95 | # relative to this directory. They are copied after the builtin static files, 96 | # so a file named "default.css" will overwrite the builtin "default.css". 97 | html_static_path = ['_static'] 98 | 99 | # Custom sidebar templates, must be a dictionary that maps document names 100 | # to template names. 101 | # 102 | # The default sidebars (for documents that don't match any pattern) are 103 | # defined by theme itself. Builtin themes are using these templates by 104 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 105 | # 'searchbox.html']``. 106 | # 107 | # html_sidebars = {} 108 | 109 | 110 | # -- Options for HTMLHelp output --------------------------------------------- 111 | 112 | # Output file base name for HTML help builder. 113 | htmlhelp_basename = 'SegmentationModelsdoc' 114 | 115 | 116 | # -- Options for LaTeX output ------------------------------------------------ 117 | 118 | latex_elements = { 119 | # The paper size ('letterpaper' or 'a4paper'). 120 | # 121 | # 'papersize': 'letterpaper', 122 | 123 | # The font size ('10pt', '11pt' or '12pt'). 124 | # 125 | # 'pointsize': '10pt', 126 | 127 | # Additional stuff for the LaTeX preamble. 128 | # 129 | # 'preamble': '', 130 | 131 | # Latex figure (float) alignment 132 | # 133 | # 'figure_align': 'htbp', 134 | } 135 | 136 | # Grouping the document tree into LaTeX files. List of tuples 137 | # (source start file, target name, title, 138 | # author, documentclass [howto, manual, or own class]). 139 | latex_documents = [ 140 | (master_doc, 'SegmentationModels.tex', u'Segmentation Models Documentation', 141 | u'Pavel Yakubovskiy', 'manual'), 142 | ] 143 | 144 | 145 | # -- Options for manual page output ------------------------------------------ 146 | 147 | # One entry per manual page. List of tuples 148 | # (source start file, name, description, authors, manual section). 149 | man_pages = [ 150 | (master_doc, 'segmentationmodels', u'Segmentation Models Documentation', 151 | [author], 1) 152 | ] 153 | 154 | 155 | # -- Options for Texinfo output ---------------------------------------------- 156 | 157 | # Grouping the document tree into Texinfo files. 
List of tuples
158 | # (source start file, target name, title, author,
159 | # dir menu entry, description, category)
160 | texinfo_documents = [
161 | (master_doc, 'SegmentationModels', u'Segmentation Models Documentation',
162 | author, 'SegmentationModels', 'One line description of project.',
163 | 'Miscellaneous'),
164 | ]
165 |
166 |
167 | # -- Options for Epub output -------------------------------------------------
168 |
169 | # Bibliographic Dublin Core info.
170 | epub_title = project
171 |
172 | # The unique identifier of the text. This can be an ISBN number
173 | # or the project homepage.
174 | #
175 | # epub_identifier = ''
176 |
177 | # A unique identification for the text.
178 | #
179 | # epub_uid = ''
180 |
181 | # A list of files that should not be packed into the epub file.
182 | epub_exclude_files = ['search.html']
183 |
184 |
185 | # -- Extension configuration -------------------------------------------------
186 |
187 | autodoc_mock_imports = [
188 | 'skimage',
189 | 'keras',
190 | 'tensorflow',
191 | 'efficientnet',
192 | 'classification_models',
193 | 'keras_applications',
194 | ]
195 |
-------------------------------------------------------------------------------- /docs/index.rst: --------------------------------------------------------------------------------
1 | .. Segmentation Models documentation master file, created by
2 | sphinx-quickstart on Tue Dec 18 17:37:58 2018.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to Segmentation Models' documentation!
7 | ===============================================
8 |
9 | .. toctree::
10 | :maxdepth: 2
11 | :caption: Contents:
12 |
13 | install
14 | tutorial
15 | api
16 | support
17 |
18 |
19 |
20 | Indices and tables
21 | ==================
22 |
23 | * :ref:`genindex`
24 | * :ref:`modindex`
25 | * :ref:`search`
26 |
-------------------------------------------------------------------------------- /docs/install.rst: --------------------------------------------------------------------------------
1 | Installation
2 | ============
3 |
4 | **Requirements**
5 |
6 | 1) Python 3
7 | 2) Keras >= 2.2.0 or TensorFlow >= 1.13
8 | 3) keras-applications >= 1.0.7, <=1.0.8
9 | 4) image-classifiers == 1.0.0
10 | 5) efficientnet == 1.0.0
11 |
12 |
13 | .. note::
14 |
15 | This library does not have Tensorflow_ in its requirements.txt,
16 | so it is not installed automatically. Please choose a suitable version
17 | ('cpu'/'gpu') and install it manually using the official Guide_.
18 |
19 |
20 | **Pip package**
21 |
22 | .. code:: bash
23 |
24 | $ pip install segmentation-models
25 |
26 | **Latest version**
27 |
28 | .. code:: bash
29 |
30 | $ pip install git+https://github.com/qubvel/segmentation_models
31 |
32 | .. _Guide:
33 | https://www.tensorflow.org/install/
34 |
35 | .. _Tensorflow:
36 | https://www.tensorflow.org/
37 |
-------------------------------------------------------------------------------- /docs/support.rst: --------------------------------------------------------------------------------
1 | Support
2 | =======
3 |
4 | The easiest way to get help with the project is to create an issue or PR on GitHub.
5 |
6 | Github: http://github.com/qubvel/segmentation_models/issues
-------------------------------------------------------------------------------- /docs/tutorial.rst: --------------------------------------------------------------------------------
1 | Tutorial
2 | ========
3 |
4 | **Segmentation models** is a Python library with Neural Networks for
5 | `Image
6 | Segmentation`_ based
7 | on the `Keras`_
8 | (`Tensorflow`_) framework.
9 |
10 | **The main features** of this library are:
11 |
12 | - High level API (just two lines to create a NN)
13 | - **4** model architectures for binary and multi-class segmentation
14 | (including the legendary **Unet**)
15 | - **25** available backbones for each architecture
16 | - All backbones have **pre-trained** weights for faster and better
17 | convergence
18 |
19 | Quick start
20 | ~~~~~~~~~~~
21 | Since the library is built on the Keras framework, a created segmentation model is just a Keras ``Model``, which can be created as easily as:
22 |
23 | .. code:: python
24 |
25 | from segmentation_models import Unet
26 |
27 | model = Unet()
28 |
29 | Depending on the task, you can change the network architecture by choosing backbones with fewer or more parameters and use pretrained weights to initialize it:
30 |
31 | .. code:: python
32 |
33 | model = Unet('resnet34', encoder_weights='imagenet')
34 |
35 | Change the number of output classes in the model:
36 |
37 | .. code:: python
38 |
39 | model = Unet('resnet34', classes=3, activation='softmax')
40 |
41 | Change the input shape of the model:
42 |
43 | .. code:: python
44 |
45 | model = Unet('resnet34', input_shape=(None, None, 6), encoder_weights=None)
46 |
47 | Simple training pipeline
48 | ~~~~~~~~~~~~~~~~~~~~~~~~
49 |
50 | .. code:: python
51 |
52 | from segmentation_models import Unet
53 | from segmentation_models import get_preprocessing
54 | from segmentation_models.losses import bce_jaccard_loss
55 | from segmentation_models.metrics import iou_score
56 |
57 | BACKBONE = 'resnet34'
58 | preprocess_input = get_preprocessing(BACKBONE)
59 |
60 | # load your data
61 | x_train, y_train, x_val, y_val = load_data(...)
62 |
63 | # preprocess input
64 | x_train = preprocess_input(x_train)
65 | x_val = preprocess_input(x_val)
66 |
67 | # define model
68 | model = Unet(BACKBONE, encoder_weights='imagenet')
69 | model.compile('Adam', loss=bce_jaccard_loss, metrics=[iou_score])
70 |
71 | # fit model
72 | model.fit(
73 | x=x_train,
74 | y=y_train,
75 | batch_size=16,
76 | epochs=100,
77 | validation_data=(x_val, y_val),
78 | )
79 |
80 |
81 | The same manipulations can be done with ``Linknet``, ``PSPNet`` and ``FPN``. For more detailed information about the models' API and use cases, `Read the Docs <https://segmentation-models.readthedocs.io>`__.
82 |
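Once trained, the model is used like any other Keras model; a minimal inference sketch (assuming ``x_val`` is already preprocessed as above and the default binary ``sigmoid`` output is used):

.. code:: python

    # predicted per-pixel probabilities, shape (samples, H, W, 1)
    probabilities = model.predict(x_val)

    # binarize with a 0.5 threshold to obtain final masks
    masks = (probabilities > 0.5).astype('uint8')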
83 | Models and Backbones
84 | ~~~~~~~~~~~~~~~~~~~~
85 | **Models**
86 |
87 | - `Unet`_
88 | - `FPN`_
89 | - `Linknet`_
90 | - `PSPNet`_
91 |
92 | ============= ==============
93 | Unet Linknet
94 | ============= ==============
95 | |unet_image| |linknet_image|
96 | ============= ==============
97 |
98 | ============= ==============
99 | PSPNet FPN
100 | ============= ==============
101 | |psp_image| |fpn_image|
102 | ============= ==============
103 |
108 | .. |unet_image| image:: https://cdn1.imggmi.com/uploads/2019/2/8/3a873a00c9742dc1fb33105ed846d5b5-full.png
109 | .. |linknet_image| image:: https://cdn1.imggmi.com/uploads/2019/2/8/1a996c4ef05531ff3861d80823c373d9-full.png
110 | .. |psp_image| image:: https://cdn1.imggmi.com/uploads/2019/2/8/aaabb97f89197b40e4879a7299b3c801-full.png
111 | .. |fpn_image| image:: https://cdn1.imggmi.com/uploads/2019/2/8/af00f11ef6bc8a64efd29ed873fcb0c4-full.png
112 |
113 | **Backbones**
114 |
115 | .. table::
116 |
117 | ============ =====
118 | Type Names
119 | ============ =====
120 | VGG ``'vgg16' 'vgg19'``
121 | ResNet ``'resnet18' 'resnet34' 'resnet50' 'resnet101' 'resnet152'``
122 | SE-ResNet ``'seresnet18' 'seresnet34' 'seresnet50' 'seresnet101' 'seresnet152'``
123 | ResNeXt ``'resnext50' 'resnext101'``
124 | SE-ResNeXt ``'seresnext50' 'seresnext101'``
125 | SENet154 ``'senet154'``
126 | DenseNet ``'densenet121' 'densenet169' 'densenet201'``
127 | Inception ``'inceptionv3' 'inceptionresnetv2'``
128 | MobileNet ``'mobilenet' 'mobilenetv2'``
129 | EfficientNet ``'efficientnetb0' 'efficientnetb1' 'efficientnetb2' 'efficientnetb3' 'efficientnetb4' 'efficientnetb5' 'efficientnetb6' 'efficientnetb7'``
130 | ============ =====
131 |
132 | .. epigraph::
133 | All backbones have weights trained on the 2012 ILSVRC ImageNet dataset (``encoder_weights='imagenet'``).
134 |
135 |
136 | Fine tuning
137 | ~~~~~~~~~~~
138 |
139 | Sometimes it is useful to train only the randomly initialized
140 | *decoder*, in order not to damage the weights of the properly trained
141 | *encoder* with huge gradients during the first steps of training.
142 | In this case all you need to do is pass the ``encoder_freeze=True`` argument
143 | when initializing the model.
144 |
145 | .. code-block:: python
146 |
147 | from segmentation_models import Unet
148 | from segmentation_models.utils import set_trainable
149 |
150 | model = Unet(backbone_name='resnet34', encoder_weights='imagenet', encoder_freeze=True)
151 | model.compile('Adam', 'binary_crossentropy', ['binary_accuracy'])
152 |
153 | # pretrain model decoder
154 | model.fit(x, y, epochs=2)
155 |
156 | # release all layers for training
157 | set_trainable(model) # set all layers trainable and recompile model
158 |
159 | # continue training
160 | model.fit(x, y, epochs=100)
161 |
162 |
163 | Training with non-RGB data
164 | ~~~~~~~~~~~~~~~~~~~~~~~~~~
165 |
166 | In case you have non-RGB images (e.g. grayscale or some medical/remote sensing data),
167 | you have a few different options:
168 |
169 | 1. Train the network from scratch with randomly initialized weights
170 |
171 | .. code-block:: python
172 |
173 | from segmentation_models import Unet
174 |
175 | # read/scale/preprocess data
176 | x, y = ...
177 |
178 | # define number of channels
179 | N = x.shape[-1]
180 |
181 | # define model
182 | model = Unet(backbone_name='resnet34', encoder_weights=None, input_shape=(None, None, N))
183 |
184 | # continue with the usual steps: compile, fit, etc.
185 |
186 | 2. Add an extra convolution layer to map ``N -> 3`` channels and train with pretrained weights
187 |
188 | .. code-block:: python
189 |
190 | from segmentation_models import Unet
191 | from keras.layers import Input, Conv2D
192 | from keras.models import Model
193 |
194 | # read/scale/preprocess data
195 | x, y = ...
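# (assumption: x is a numpy array of shape (samples, height, width, N),
# e.g. N=1 for grayscale or N > 3 for multispectral data)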
196 | 197 | # define number of channels 198 | N = x.shape[-1] 199 | 200 | base_model = Unet(backbone_name='resnet34', encoder_weights='imagenet') 201 | 202 | inp = Input(shape=(None, None, N)) 203 | l1 = Conv2D(3, (1, 1))(inp) # map N channels data to 3 channels 204 | out = base_model(l1) 205 | 206 | model = Model(inp, out, name=base_model.name) 207 | 208 | # continue with usual steps: compile, fit, etc.. 209 | 210 | .. _Image Segmentation: 211 | https://en.wikipedia.org/wiki/Image_segmentation 212 | 213 | .. _Tensorflow: 214 | https://www.tensorflow.org/ 215 | 216 | .. _Keras: 217 | https://keras.io 218 | 219 | .. _Unet: 220 | https://arxiv.org/pdf/1505.04597 221 | 222 | .. _Linknet: 223 | https://arxiv.org/pdf/1707.03718.pdf 224 | 225 | .. _PSPNet: 226 | https://arxiv.org/pdf/1612.01105.pdf 227 | 228 | .. _FPN: 229 | http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf 230 | -------------------------------------------------------------------------------- /images/fpn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qubvel/segmentation_models/5d24bbfb28af6134e25e2c0b79e7727f6c0491d0/images/fpn.png -------------------------------------------------------------------------------- /images/linknet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qubvel/segmentation_models/5d24bbfb28af6134e25e2c0b79e7727f6c0491d0/images/linknet.png -------------------------------------------------------------------------------- /images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qubvel/segmentation_models/5d24bbfb28af6134e25e2c0b79e7727f6c0491d0/images/logo.png -------------------------------------------------------------------------------- /images/pspnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qubvel/segmentation_models/5d24bbfb28af6134e25e2c0b79e7727f6c0491d0/images/pspnet.png -------------------------------------------------------------------------------- /images/unet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qubvel/segmentation_models/5d24bbfb28af6134e25e2c0b79e7727f6c0491d0/images/unet.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | keras_applications>=1.0.7,<=1.0.8 2 | image-classifiers==1.0.0 3 | efficientnet==1.1.1 4 | -------------------------------------------------------------------------------- /segmentation_models/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import functools 3 | from .__version__ import __version__ 4 | from . 
import base 5 | 6 | _KERAS_FRAMEWORK_NAME = 'keras' 7 | _TF_KERAS_FRAMEWORK_NAME = 'tf.keras' 8 | 9 | _DEFAULT_KERAS_FRAMEWORK = _KERAS_FRAMEWORK_NAME 10 | _KERAS_FRAMEWORK = None 11 | _KERAS_BACKEND = None 12 | _KERAS_LAYERS = None 13 | _KERAS_MODELS = None 14 | _KERAS_UTILS = None 15 | _KERAS_LOSSES = None 16 | 17 | 18 | def inject_global_losses(func): 19 | @functools.wraps(func) 20 | def wrapper(*args, **kwargs): 21 | kwargs['losses'] = _KERAS_LOSSES 22 | return func(*args, **kwargs) 23 | 24 | return wrapper 25 | 26 | 27 | def inject_global_submodules(func): 28 | @functools.wraps(func) 29 | def wrapper(*args, **kwargs): 30 | kwargs['backend'] = _KERAS_BACKEND 31 | kwargs['layers'] = _KERAS_LAYERS 32 | kwargs['models'] = _KERAS_MODELS 33 | kwargs['utils'] = _KERAS_UTILS 34 | return func(*args, **kwargs) 35 | 36 | return wrapper 37 | 38 | 39 | def filter_kwargs(func): 40 | @functools.wraps(func) 41 | def wrapper(*args, **kwargs): 42 | new_kwargs = {k: v for k, v in kwargs.items() if k in ['backend', 'layers', 'models', 'utils']} 43 | return func(*args, **new_kwargs) 44 | 45 | return wrapper 46 | 47 | 48 | def framework(): 49 | """Return name of Segmentation Models framework""" 50 | return _KERAS_FRAMEWORK 51 | 52 | 53 | def set_framework(name): 54 | """Set framework for Segmentation Models 55 | 56 | Args: 57 | name (str): one of ``keras``, ``tf.keras``, case insensitive. 58 | 59 | Raises: 60 | ValueError: in case of incorrect framework name. 61 | ImportError: in case framework is not installed. 62 | 63 | """ 64 | name = name.lower() 65 | 66 | if name == _KERAS_FRAMEWORK_NAME: 67 | import keras 68 | import efficientnet.keras # init custom objects 69 | elif name == _TF_KERAS_FRAMEWORK_NAME: 70 | from tensorflow import keras 71 | import efficientnet.tfkeras # init custom objects 72 | else: 73 | raise ValueError('Not correct module name `{}`, use `{}` or `{}`'.format( 74 | name, _KERAS_FRAMEWORK_NAME, _TF_KERAS_FRAMEWORK_NAME)) 75 | 76 | global _KERAS_BACKEND, _KERAS_LAYERS, _KERAS_MODELS 77 | global _KERAS_UTILS, _KERAS_LOSSES, _KERAS_FRAMEWORK 78 | 79 | _KERAS_FRAMEWORK = name 80 | _KERAS_BACKEND = keras.backend 81 | _KERAS_LAYERS = keras.layers 82 | _KERAS_MODELS = keras.models 83 | _KERAS_UTILS = keras.utils 84 | _KERAS_LOSSES = keras.losses 85 | 86 | # allow losses/metrics get keras submodules 87 | base.KerasObject.set_submodules( 88 | backend=keras.backend, 89 | layers=keras.layers, 90 | models=keras.models, 91 | utils=keras.utils, 92 | ) 93 | 94 | 95 | # set default framework 96 | _framework = os.environ.get('SM_FRAMEWORK', _DEFAULT_KERAS_FRAMEWORK) 97 | try: 98 | set_framework(_framework) 99 | except ImportError: 100 | other = _TF_KERAS_FRAMEWORK_NAME if _framework == _KERAS_FRAMEWORK_NAME else _KERAS_FRAMEWORK_NAME 101 | set_framework(other) 102 | 103 | print('Segmentation Models: using `{}` framework.'.format(_KERAS_FRAMEWORK)) 104 | 105 | # import helper modules 106 | from . import losses 107 | from . import metrics 108 | from . 
import utils
109 |
110 | # wrap segmentation models with framework modules
111 | from .backbones.backbones_factory import Backbones
112 | from .models.unet import Unet as _Unet
113 | from .models.pspnet import PSPNet as _PSPNet
114 | from .models.linknet import Linknet as _Linknet
115 | from .models.fpn import FPN as _FPN
116 |
117 | Unet = inject_global_submodules(_Unet)
118 | PSPNet = inject_global_submodules(_PSPNet)
119 | Linknet = inject_global_submodules(_Linknet)
120 | FPN = inject_global_submodules(_FPN)
121 | get_available_backbone_names = Backbones.models_names
122 |
123 |
124 | def get_preprocessing(name):
125 | preprocess_input = Backbones.get_preprocessing(name)
126 | # add backend, layers, models, utils submodules in kwargs
127 | preprocess_input = inject_global_submodules(preprocess_input)
128 | # delete other kwargs:
129 | # keras-applications preprocessing raises an error if anything
130 | # except `backend`, `layers`, `models`, `utils` is passed in kwargs
131 | preprocess_input = filter_kwargs(preprocess_input)
132 | return preprocess_input
133 |
134 |
135 | __all__ = [
136 | 'Unet', 'PSPNet', 'FPN', 'Linknet',
137 | 'set_framework', 'framework',
138 | 'get_preprocessing', 'get_available_backbone_names',
139 | 'losses', 'metrics', 'utils',
140 | '__version__',
141 | ]
142 |
-------------------------------------------------------------------------------- /segmentation_models/__version__.py: --------------------------------------------------------------------------------
1 | VERSION = (1, 0, 1)
2 |
3 | __version__ = '.'.join(map(str, VERSION))
4 |
-------------------------------------------------------------------------------- /segmentation_models/backbones/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qubvel/segmentation_models/5d24bbfb28af6134e25e2c0b79e7727f6c0491d0/segmentation_models/backbones/__init__.py -------------------------------------------------------------------------------- /segmentation_models/backbones/backbones_factory.py: --------------------------------------------------------------------------------
1 | import copy
2 | import efficientnet.model as eff
3 | from classification_models.models_factory import ModelsFactory
4 |
5 | from . import inception_resnet_v2 as irv2
6 | from . import inception_v3 as iv3
7 |
8 |
9 | class BackbonesFactory(ModelsFactory):
10 | _default_feature_layers = {
11 |
12 | # List of layers to take features from the backbone, in the following order:
13 | # (x16, x8, x4, x2, x1) - `x4` means the features have 4 times lower spatial
14 | # resolution (Height x Width) than the input image.
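# For example (illustrative, matching get_feature_layers() defined below):
#   Backbones.get_feature_layers('resnet34', n=4)
#   -> ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0')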
15 | 16 | # VGG 17 | 'vgg16': ('block5_conv3', 'block4_conv3', 'block3_conv3', 'block2_conv2', 'block1_conv2'), 18 | 'vgg19': ('block5_conv4', 'block4_conv4', 'block3_conv4', 'block2_conv2', 'block1_conv2'), 19 | 20 | # ResNets 21 | 'resnet18': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), 22 | 'resnet34': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), 23 | 'resnet50': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), 24 | 'resnet101': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), 25 | 'resnet152': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), 26 | 27 | # ResNeXt 28 | 'resnext50': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), 29 | 'resnext101': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), 30 | 31 | # Inception 32 | 'inceptionv3': (228, 86, 16, 9), 33 | 'inceptionresnetv2': (594, 260, 16, 9), 34 | 35 | # DenseNet 36 | 'densenet121': (311, 139, 51, 4), 37 | 'densenet169': (367, 139, 51, 4), 38 | 'densenet201': (479, 139, 51, 4), 39 | 40 | # SE models 41 | 'seresnet18': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), 42 | 'seresnet34': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), 43 | 'seresnet50': (246, 136, 62, 4), 44 | 'seresnet101': (552, 136, 62, 4), 45 | 'seresnet152': (858, 208, 62, 4), 46 | 'seresnext50': (1078, 584, 254, 4), 47 | 'seresnext101': (2472, 584, 254, 4), 48 | 'senet154': (6884, 1625, 454, 12), 49 | 50 | # Mobile Nets 51 | 'mobilenet': ('conv_pw_11_relu', 'conv_pw_5_relu', 'conv_pw_3_relu', 'conv_pw_1_relu'), 52 | 'mobilenetv2': ('block_13_expand_relu', 'block_6_expand_relu', 'block_3_expand_relu', 53 | 'block_1_expand_relu'), 54 | 55 | # EfficientNets 56 | 'efficientnetb0': ('block6a_expand_activation', 'block4a_expand_activation', 57 | 'block3a_expand_activation', 'block2a_expand_activation'), 58 | 'efficientnetb1': ('block6a_expand_activation', 'block4a_expand_activation', 59 | 'block3a_expand_activation', 'block2a_expand_activation'), 60 | 'efficientnetb2': ('block6a_expand_activation', 'block4a_expand_activation', 61 | 'block3a_expand_activation', 'block2a_expand_activation'), 62 | 'efficientnetb3': ('block6a_expand_activation', 'block4a_expand_activation', 63 | 'block3a_expand_activation', 'block2a_expand_activation'), 64 | 'efficientnetb4': ('block6a_expand_activation', 'block4a_expand_activation', 65 | 'block3a_expand_activation', 'block2a_expand_activation'), 66 | 'efficientnetb5': ('block6a_expand_activation', 'block4a_expand_activation', 67 | 'block3a_expand_activation', 'block2a_expand_activation'), 68 | 'efficientnetb6': ('block6a_expand_activation', 'block4a_expand_activation', 69 | 'block3a_expand_activation', 'block2a_expand_activation'), 70 | 'efficientnetb7': ('block6a_expand_activation', 'block4a_expand_activation', 71 | 'block3a_expand_activation', 'block2a_expand_activation'), 72 | 73 | } 74 | 75 | _models_update = { 76 | 'inceptionresnetv2': [irv2.InceptionResNetV2, irv2.preprocess_input], 77 | 'inceptionv3': [iv3.InceptionV3, iv3.preprocess_input], 78 | 79 | 'efficientnetb0': [eff.EfficientNetB0, eff.preprocess_input], 80 | 'efficientnetb1': [eff.EfficientNetB1, eff.preprocess_input], 81 | 'efficientnetb2': [eff.EfficientNetB2, eff.preprocess_input], 82 | 'efficientnetb3': [eff.EfficientNetB3, eff.preprocess_input], 83 | 'efficientnetb4': [eff.EfficientNetB4, eff.preprocess_input], 
84 | 'efficientnetb5': [eff.EfficientNetB5, eff.preprocess_input], 85 | 'efficientnetb6': [eff.EfficientNetB6, eff.preprocess_input], 86 | 'efficientnetb7': [eff.EfficientNetB7, eff.preprocess_input], 87 | } 88 | 89 | # currently not supported 90 | _models_delete = ['resnet50v2', 'resnet101v2', 'resnet152v2', 91 | 'nasnetlarge', 'nasnetmobile', 'xception'] 92 | 93 | @property 94 | def models(self): 95 | all_models = copy.copy(self._models) 96 | all_models.update(self._models_update) 97 | for k in self._models_delete: 98 | del all_models[k] 99 | return all_models 100 | 101 | def get_backbone(self, name, *args, **kwargs): 102 | model_fn, _ = self.get(name) 103 | model = model_fn(*args, **kwargs) 104 | return model 105 | 106 | def get_feature_layers(self, name, n=5): 107 | return self._default_feature_layers[name][:n] 108 | 109 | def get_preprocessing(self, name): 110 | return self.get(name)[1] 111 | 112 | 113 | Backbones = BackbonesFactory() 114 | -------------------------------------------------------------------------------- /segmentation_models/backbones/inception_resnet_v2.py: -------------------------------------------------------------------------------- 1 | """Inception-ResNet V2 model for Keras. 2 | Model naming and structure follows TF-slim implementation 3 | (which has some additional layers and different number of 4 | filters from the original arXiv paper): 5 | https://github.com/tensorflow/models/blob/master/research/slim/nets/inception_resnet_v2.py 6 | Pre-trained ImageNet weights are also converted from TF-slim, 7 | which can be found in: 8 | https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models 9 | # Reference 10 | - [Inception-v4, Inception-ResNet and the Impact of 11 | Residual Connections on Learning](https://arxiv.org/abs/1602.07261) (AAAI 2017) 12 | """ 13 | from __future__ import absolute_import 14 | from __future__ import division 15 | from __future__ import print_function 16 | 17 | import os 18 | 19 | from keras_applications import imagenet_utils 20 | from keras_applications import get_submodules_from_kwargs 21 | 22 | BASE_WEIGHT_URL = ('https://github.com/fchollet/deep-learning-models/' 23 | 'releases/download/v0.7/') 24 | 25 | backend = None 26 | layers = None 27 | models = None 28 | keras_utils = None 29 | 30 | 31 | def preprocess_input(x, **kwargs): 32 | """Preprocesses a numpy array encoding a batch of images. 33 | # Arguments 34 | x: a 4D numpy array consists of RGB values within [0, 255]. 35 | # Returns 36 | Preprocessed array. 37 | """ 38 | return imagenet_utils.preprocess_input(x, mode='tf', **kwargs) 39 | 40 | 41 | def conv2d_bn(x, 42 | filters, 43 | kernel_size, 44 | strides=1, 45 | padding='same', 46 | activation='relu', 47 | use_bias=False, 48 | name=None): 49 | """Utility function to apply conv + BN. 50 | # Arguments 51 | x: input tensor. 52 | filters: filters in `Conv2D`. 53 | kernel_size: kernel size as in `Conv2D`. 54 | strides: strides in `Conv2D`. 55 | padding: padding mode in `Conv2D`. 56 | activation: activation in `Conv2D`. 57 | use_bias: whether to use a bias in `Conv2D`. 58 | name: name of the ops; will become `name + '_ac'` for the activation 59 | and `name + '_bn'` for the batch norm layer. 60 | # Returns 61 | Output tensor after applying `Conv2D` and `BatchNormalization`. 
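# Example
(illustrative, not from the original docstring) `conv2d_bn(x, 64, 3, strides=2)`
maps `x` to 64 channels at half the spatial resolution, followed by
batch normalization and a ReLU activation.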
62 | """ 63 | x = layers.Conv2D(filters, 64 | kernel_size, 65 | strides=strides, 66 | padding=padding, 67 | use_bias=use_bias, 68 | name=name)(x) 69 | if not use_bias: 70 | bn_axis = 1 if backend.image_data_format() == 'channels_first' else 3 71 | bn_name = None if name is None else name + '_bn' 72 | x = layers.BatchNormalization(axis=bn_axis, 73 | scale=False, 74 | name=bn_name)(x) 75 | if activation is not None: 76 | ac_name = None if name is None else name + '_ac' 77 | x = layers.Activation(activation, name=ac_name)(x) 78 | return x 79 | 80 | 81 | def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'): 82 | """Adds a Inception-ResNet block. 83 | This function builds 3 types of Inception-ResNet blocks mentioned 84 | in the paper, controlled by the `block_type` argument (which is the 85 | block name used in the official TF-slim implementation): 86 | - Inception-ResNet-A: `block_type='block35'` 87 | - Inception-ResNet-B: `block_type='block17'` 88 | - Inception-ResNet-C: `block_type='block8'` 89 | # Arguments 90 | x: input tensor. 91 | scale: scaling factor to scale the residuals (i.e., the output of 92 | passing `x` through an inception module) before adding them 93 | to the shortcut branch. 94 | Let `r` be the output from the residual branch, 95 | the output of this block will be `x + scale * r`. 96 | block_type: `'block35'`, `'block17'` or `'block8'`, determines 97 | the network structure in the residual branch. 98 | block_idx: an `int` used for generating layer names. 99 | The Inception-ResNet blocks 100 | are repeated many times in this network. 101 | We use `block_idx` to identify 102 | each of the repetitions. For example, 103 | the first Inception-ResNet-A block 104 | will have `block_type='block35', block_idx=0`, 105 | and the layer names will have 106 | a common prefix `'block35_0'`. 107 | activation: activation function to use at the end of the block 108 | (see [activations](../activations.md)). 109 | When `activation=None`, no activation is applied 110 | (i.e., "linear" activation: `a(x) = x`). 111 | # Returns 112 | Output tensor for the block. 113 | # Raises 114 | ValueError: if `block_type` is not one of `'block35'`, 115 | `'block17'` or `'block8'`. 116 | """ 117 | if block_type == 'block35': 118 | branch_0 = conv2d_bn(x, 32, 1) 119 | branch_1 = conv2d_bn(x, 32, 1) 120 | branch_1 = conv2d_bn(branch_1, 32, 3) 121 | branch_2 = conv2d_bn(x, 32, 1) 122 | branch_2 = conv2d_bn(branch_2, 48, 3) 123 | branch_2 = conv2d_bn(branch_2, 64, 3) 124 | branches = [branch_0, branch_1, branch_2] 125 | elif block_type == 'block17': 126 | branch_0 = conv2d_bn(x, 192, 1) 127 | branch_1 = conv2d_bn(x, 128, 1) 128 | branch_1 = conv2d_bn(branch_1, 160, [1, 7]) 129 | branch_1 = conv2d_bn(branch_1, 192, [7, 1]) 130 | branches = [branch_0, branch_1] 131 | elif block_type == 'block8': 132 | branch_0 = conv2d_bn(x, 192, 1) 133 | branch_1 = conv2d_bn(x, 192, 1) 134 | branch_1 = conv2d_bn(branch_1, 224, [1, 3]) 135 | branch_1 = conv2d_bn(branch_1, 256, [3, 1]) 136 | branches = [branch_0, branch_1] 137 | else: 138 | raise ValueError('Unknown Inception-ResNet block type. 
' 139 | 'Expects "block35", "block17" or "block8", ' 140 | 'but got: ' + str(block_type)) 141 | 142 | block_name = block_type + '_' + str(block_idx) 143 | channel_axis = 1 if backend.image_data_format() == 'channels_first' else 3 144 | mixed = layers.Concatenate( 145 | axis=channel_axis, name=block_name + '_mixed')(branches) 146 | up = conv2d_bn(mixed, 147 | backend.int_shape(x)[channel_axis], 148 | 1, 149 | activation=None, 150 | use_bias=True, 151 | name=block_name + '_conv') 152 | 153 | x = layers.Lambda(lambda inputs, scale: inputs[0] + inputs[1] * scale, 154 | output_shape=backend.int_shape(x)[1:], 155 | arguments={'scale': scale}, 156 | name=block_name)([x, up]) 157 | if activation is not None: 158 | x = layers.Activation(activation, name=block_name + '_ac')(x) 159 | return x 160 | 161 | 162 | def InceptionResNetV2(include_top=True, 163 | weights='imagenet', 164 | input_tensor=None, 165 | input_shape=None, 166 | pooling=None, 167 | classes=1000, 168 | **kwargs): 169 | """Instantiates the Inception-ResNet v2 architecture. 170 | Optionally loads weights pre-trained on ImageNet. 171 | Note that the data format convention used by the model is 172 | the one specified in your Keras config at `~/.keras/keras.json`. 173 | # Arguments 174 | include_top: whether to include the fully-connected 175 | layer at the top of the network. 176 | weights: one of `None` (random initialization), 177 | 'imagenet' (pre-training on ImageNet), 178 | or the path to the weights file to be loaded. 179 | input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) 180 | to use as image input for the model. 181 | input_shape: optional shape tuple, only to be specified 182 | if `include_top` is `False` (otherwise the input shape 183 | has to be `(299, 299, 3)` (with `'channels_last'` data format) 184 | or `(3, 299, 299)` (with `'channels_first'` data format). 185 | It should have exactly 3 inputs channels, 186 | and width and height should be no smaller than 75. 187 | E.g. `(150, 150, 3)` would be one valid value. 188 | pooling: Optional pooling mode for feature extraction 189 | when `include_top` is `False`. 190 | - `None` means that the output of the model will be 191 | the 4D tensor output of the last convolutional block. 192 | - `'avg'` means that global average pooling 193 | will be applied to the output of the 194 | last convolutional block, and thus 195 | the output of the model will be a 2D tensor. 196 | - `'max'` means that global max pooling will be applied. 197 | classes: optional number of classes to classify images 198 | into, only to be specified if `include_top` is `True`, and 199 | if no `weights` argument is specified. 200 | # Returns 201 | A Keras `Model` instance. 202 | # Raises 203 | ValueError: in case of invalid argument for `weights`, 204 | or invalid input shape. 
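# Example
A usage sketch (illustrative; the explicit submodule keyword arguments are an assumption about the calling code and mirror how this package invokes the constructor):
    import keras
    model = InceptionResNetV2(include_top=False, weights='imagenet',
                              input_shape=(256, 256, 3),
                              backend=keras.backend, layers=keras.layers,
                              models=keras.models, utils=keras.utils)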
205 | """ 206 | global backend, layers, models, keras_utils 207 | backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs) 208 | 209 | if not (weights in {'imagenet', None} or os.path.exists(weights)): 210 | raise ValueError('The `weights` argument should be either ' 211 | '`None` (random initialization), `imagenet` ' 212 | '(pre-training on ImageNet), ' 213 | 'or the path to the weights file to be loaded.') 214 | 215 | if weights == 'imagenet' and include_top and classes != 1000: 216 | raise ValueError('If using `weights` as `"imagenet"` with `include_top`' 217 | ' as true, `classes` should be 1000') 218 | 219 | # Determine proper input shape 220 | input_shape = imagenet_utils._obtain_input_shape( 221 | input_shape, 222 | default_size=299, 223 | min_size=32, 224 | data_format=backend.image_data_format(), 225 | require_flatten=include_top, 226 | weights=weights) 227 | 228 | if input_tensor is None: 229 | img_input = layers.Input(shape=input_shape) 230 | else: 231 | if not backend.is_keras_tensor(input_tensor): 232 | img_input = layers.Input(tensor=input_tensor, shape=input_shape) 233 | else: 234 | img_input = input_tensor 235 | 236 | # Stem block: 35 x 35 x 192 237 | x = conv2d_bn(img_input, 32, 3, strides=2, padding='same') 238 | x = conv2d_bn(x, 32, 3, padding='same') 239 | x = conv2d_bn(x, 64, 3, padding='same') 240 | x = layers.MaxPooling2D(3, strides=2, padding='same')(x) 241 | x = conv2d_bn(x, 80, 1, padding='same') 242 | x = conv2d_bn(x, 192, 3, padding='same') 243 | x = layers.MaxPooling2D(3, strides=2, padding='same')(x) 244 | 245 | # Mixed 5b (Inception-A block): 35 x 35 x 320 246 | branch_0 = conv2d_bn(x, 96, 1, padding='same') 247 | branch_1 = conv2d_bn(x, 48, 1, padding='same') 248 | branch_1 = conv2d_bn(branch_1, 64, 5, padding='same') 249 | branch_2 = conv2d_bn(x, 64, 1, padding='same') 250 | branch_2 = conv2d_bn(branch_2, 96, 3, padding='same') 251 | branch_2 = conv2d_bn(branch_2, 96, 3, padding='same') 252 | branch_pool = layers.AveragePooling2D(3, strides=1, padding='same')(x) 253 | branch_pool = conv2d_bn(branch_pool, 64, 1, padding='same') 254 | branches = [branch_0, branch_1, branch_2, branch_pool] 255 | channel_axis = 1 if backend.image_data_format() == 'channels_first' else 3 256 | x = layers.Concatenate(axis=channel_axis, name='mixed_5b')(branches) 257 | 258 | # 10x block35 (Inception-ResNet-A block): 35 x 35 x 320 259 | for block_idx in range(1, 11): 260 | x = inception_resnet_block(x, 261 | scale=0.17, 262 | block_type='block35', 263 | block_idx=block_idx) 264 | 265 | # Mixed 6a (Reduction-A block): 17 x 17 x 1088 266 | branch_0 = conv2d_bn(x, 384, 3, strides=2, padding='same') 267 | branch_1 = conv2d_bn(x, 256, 1, padding='same') 268 | branch_1 = conv2d_bn(branch_1, 256, 3, padding='same') 269 | branch_1 = conv2d_bn(branch_1, 384, 3, strides=2, padding='same') 270 | branch_pool = layers.MaxPooling2D(3, strides=2, padding='same')(x) 271 | branches = [branch_0, branch_1, branch_pool] 272 | x = layers.Concatenate(axis=channel_axis, name='mixed_6a')(branches) 273 | 274 | # 20x block17 (Inception-ResNet-B block): 17 x 17 x 1088 275 | for block_idx in range(1, 21): 276 | x = inception_resnet_block(x, 277 | scale=0.1, 278 | block_type='block17', 279 | block_idx=block_idx) 280 | 281 | # Mixed 7a (Reduction-B block): 8 x 8 x 2080 282 | branch_0 = conv2d_bn(x, 256, 1, padding='same') 283 | branch_0 = conv2d_bn(branch_0, 384, 3, strides=2, padding='same') 284 | branch_1 = conv2d_bn(x, 256, 1, padding='same') 285 | branch_1 = conv2d_bn(branch_1, 288, 3, 
strides=2, padding='same') 286 | branch_2 = conv2d_bn(x, 256, 1, padding='same') 287 | branch_2 = conv2d_bn(branch_2, 288, 3, padding='same') 288 | branch_2 = conv2d_bn(branch_2, 320, 3, strides=2, padding='same') 289 | branch_pool = layers.MaxPooling2D(3, strides=2, padding='same')(x) 290 | branches = [branch_0, branch_1, branch_2, branch_pool] 291 | x = layers.Concatenate(axis=channel_axis, name='mixed_7a')(branches) 292 | 293 | # 10x block8 (Inception-ResNet-C block): 8 x 8 x 2080 294 | for block_idx in range(1, 10): 295 | x = inception_resnet_block(x, 296 | scale=0.2, 297 | block_type='block8', 298 | block_idx=block_idx) 299 | x = inception_resnet_block(x, 300 | scale=1., 301 | activation=None, 302 | block_type='block8', 303 | block_idx=10) 304 | 305 | # Final convolution block: 8 x 8 x 1536 306 | x = conv2d_bn(x, 1536, 1, name='conv_7b') 307 | 308 | if include_top: 309 | # Classification block 310 | x = layers.GlobalAveragePooling2D(name='avg_pool')(x) 311 | x = layers.Dense(classes, activation='softmax', name='predictions')(x) 312 | else: 313 | if pooling == 'avg': 314 | x = layers.GlobalAveragePooling2D()(x) 315 | elif pooling == 'max': 316 | x = layers.GlobalMaxPooling2D()(x) 317 | 318 | # Ensure that the model takes into account 319 | # any potential predecessors of `input_tensor`. 320 | if input_tensor is not None: 321 | inputs = keras_utils.get_source_inputs(input_tensor) 322 | else: 323 | inputs = img_input 324 | 325 | # Create model. 326 | model = models.Model(inputs, x, name='inception_resnet_v2') 327 | 328 | # Load weights. 329 | if weights == 'imagenet': 330 | if include_top: 331 | fname = 'inception_resnet_v2_weights_tf_dim_ordering_tf_kernels.h5' 332 | weights_path = keras_utils.get_file( 333 | fname, 334 | BASE_WEIGHT_URL + fname, 335 | cache_subdir='models', 336 | file_hash='e693bd0210a403b3192acc6073ad2e96') 337 | else: 338 | fname = ('inception_resnet_v2_weights_' 339 | 'tf_dim_ordering_tf_kernels_notop.h5') 340 | weights_path = keras_utils.get_file( 341 | fname, 342 | BASE_WEIGHT_URL + fname, 343 | cache_subdir='models', 344 | file_hash='d19885ff4a710c122648d3b5c3b684e4') 345 | model.load_weights(weights_path) 346 | elif weights is not None: 347 | model.load_weights(weights) 348 | 349 | return model 350 | -------------------------------------------------------------------------------- /segmentation_models/backbones/inception_v3.py: -------------------------------------------------------------------------------- 1 | """Inception V3 model for Keras. 2 | Note that the input image format for this model is different than for 3 | the VGG16 and ResNet models (299x299 instead of 224x224), 4 | and that the input preprocessing function is also different (same as Xception). 
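Concretely, `preprocess_input` at the bottom of this module uses the 'tf' mode of `imagenet_utils.preprocess_input`, which maps RGB values from [0, 255] to [-1, 1] (roughly `x = x / 127.5 - 1.0`) instead of the caffe-style mean subtraction used by VGG16/ResNet.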
5 | # Reference 6 | - [Rethinking the Inception Architecture for Computer Vision]( 7 | http://arxiv.org/abs/1512.00567) (CVPR 2016) 8 | """ 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from __future__ import print_function 12 | 13 | import os 14 | 15 | from keras_applications import imagenet_utils 16 | from keras_applications import get_submodules_from_kwargs 17 | 18 | WEIGHTS_PATH = ( 19 | 'https://github.com/fchollet/deep-learning-models/' 20 | 'releases/download/v0.5/' 21 | 'inception_v3_weights_tf_dim_ordering_tf_kernels.h5') 22 | WEIGHTS_PATH_NO_TOP = ( 23 | 'https://github.com/fchollet/deep-learning-models/' 24 | 'releases/download/v0.5/' 25 | 'inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5') 26 | 27 | backend = None 28 | layers = None 29 | models = None 30 | keras_utils = None 31 | 32 | 33 | def conv2d_bn(x, 34 | filters, 35 | num_row, 36 | num_col, 37 | padding='same', 38 | strides=(1, 1), 39 | name=None): 40 | """Utility function to apply conv + BN. 41 | # Arguments 42 | x: input tensor. 43 | filters: filters in `Conv2D`. 44 | num_row: height of the convolution kernel. 45 | num_col: width of the convolution kernel. 46 | padding: padding mode in `Conv2D`. 47 | strides: strides in `Conv2D`. 48 | name: name of the ops; will become `name + '_conv'` 49 | for the convolution and `name + '_bn'` for the 50 | batch norm layer. 51 | # Returns 52 | Output tensor after applying `Conv2D` and `BatchNormalization`. 53 | """ 54 | if name is not None: 55 | bn_name = name + '_bn' 56 | conv_name = name + '_conv' 57 | else: 58 | bn_name = None 59 | conv_name = None 60 | if backend.image_data_format() == 'channels_first': 61 | bn_axis = 1 62 | else: 63 | bn_axis = 3 64 | x = layers.Conv2D( 65 | filters, (num_row, num_col), 66 | strides=strides, 67 | padding=padding, 68 | use_bias=False, 69 | name=conv_name)(x) 70 | x = layers.BatchNormalization(axis=bn_axis, scale=False, name=bn_name)(x) 71 | x = layers.Activation('relu', name=name)(x) 72 | return x 73 | 74 | 75 | def InceptionV3(include_top=True, 76 | weights='imagenet', 77 | input_tensor=None, 78 | input_shape=None, 79 | pooling=None, 80 | classes=1000, 81 | **kwargs): 82 | """Instantiates the Inception v3 architecture. 83 | Optionally loads weights pre-trained on ImageNet. 84 | Note that the data format convention used by the model is 85 | the one specified in your Keras config at `~/.keras/keras.json`. 86 | # Arguments 87 | include_top: whether to include the fully-connected 88 | layer at the top of the network. 89 | weights: one of `None` (random initialization), 90 | 'imagenet' (pre-training on ImageNet), 91 | or the path to the weights file to be loaded. 92 | input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) 93 | to use as image input for the model. 94 | input_shape: optional shape tuple, only to be specified 95 | if `include_top` is False (otherwise the input shape 96 | has to be `(299, 299, 3)` (with `channels_last` data format) 97 | or `(3, 299, 299)` (with `channels_first` data format). 98 | It should have exactly 3 inputs channels, 99 | and width and height should be no smaller than 75. 100 | E.g. `(150, 150, 3)` would be one valid value. 101 | pooling: Optional pooling mode for feature extraction 102 | when `include_top` is `False`. 103 | - `None` means that the output of the model will be 104 | the 4D tensor output of the 105 | last convolutional block. 
106 | - `avg` means that global average pooling 107 | will be applied to the output of the 108 | last convolutional block, and thus 109 | the output of the model will be a 2D tensor. 110 | - `max` means that global max pooling will 111 | be applied. 112 | classes: optional number of classes to classify images 113 | into, only to be specified if `include_top` is True, and 114 | if no `weights` argument is specified. 115 | # Returns 116 | A Keras model instance. 117 | # Raises 118 | ValueError: in case of invalid argument for `weights`, 119 | or invalid input shape. 120 | """ 121 | global backend, layers, models, keras_utils 122 | backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs) 123 | 124 | if not (weights in {'imagenet', None} or os.path.exists(weights)): 125 | raise ValueError('The `weights` argument should be either ' 126 | '`None` (random initialization), `imagenet` ' 127 | '(pre-training on ImageNet), ' 128 | 'or the path to the weights file to be loaded.') 129 | 130 | if weights == 'imagenet' and include_top and classes != 1000: 131 | raise ValueError('If using `weights` as `"imagenet"` with `include_top`' 132 | ' as true, `classes` should be 1000') 133 | 134 | # Determine proper input shape 135 | input_shape = imagenet_utils._obtain_input_shape( 136 | input_shape, 137 | default_size=299, 138 | min_size=75, 139 | data_format=backend.image_data_format(), 140 | require_flatten=include_top, 141 | weights=weights) 142 | 143 | if input_tensor is None: 144 | img_input = layers.Input(shape=input_shape) 145 | else: 146 | if not backend.is_keras_tensor(input_tensor): 147 | img_input = layers.Input(tensor=input_tensor, shape=input_shape) 148 | else: 149 | img_input = input_tensor 150 | 151 | if backend.image_data_format() == 'channels_first': 152 | channel_axis = 1 153 | else: 154 | channel_axis = 3 155 | 156 | x = conv2d_bn(img_input, 32, 3, 3, strides=(2, 2), padding='same') 157 | x = conv2d_bn(x, 32, 3, 3, padding='same') 158 | x = conv2d_bn(x, 64, 3, 3, padding='same') 159 | x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x) 160 | 161 | x = conv2d_bn(x, 80, 1, 1, padding='same') 162 | x = conv2d_bn(x, 192, 3, 3, padding='same') 163 | x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x) 164 | 165 | # mixed 0: 35 x 35 x 256 166 | branch1x1 = conv2d_bn(x, 64, 1, 1) 167 | 168 | branch5x5 = conv2d_bn(x, 48, 1, 1) 169 | branch5x5 = conv2d_bn(branch5x5, 64, 5, 5) 170 | 171 | branch3x3dbl = conv2d_bn(x, 64, 1, 1) 172 | branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) 173 | branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) 174 | 175 | branch_pool = layers.AveragePooling2D((3, 3), 176 | strides=(1, 1), 177 | padding='same')(x) 178 | branch_pool = conv2d_bn(branch_pool, 32, 1, 1) 179 | x = layers.concatenate( 180 | [branch1x1, branch5x5, branch3x3dbl, branch_pool], 181 | axis=channel_axis, 182 | name='mixed0') 183 | 184 | # mixed 1: 35 x 35 x 288 185 | branch1x1 = conv2d_bn(x, 64, 1, 1) 186 | 187 | branch5x5 = conv2d_bn(x, 48, 1, 1) 188 | branch5x5 = conv2d_bn(branch5x5, 64, 5, 5) 189 | 190 | branch3x3dbl = conv2d_bn(x, 64, 1, 1) 191 | branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) 192 | branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) 193 | 194 | branch_pool = layers.AveragePooling2D((3, 3), 195 | strides=(1, 1), 196 | padding='same')(x) 197 | branch_pool = conv2d_bn(branch_pool, 64, 1, 1) 198 | x = layers.concatenate( 199 | [branch1x1, branch5x5, branch3x3dbl, branch_pool], 200 | axis=channel_axis, 201 | name='mixed1') 202 | 203 | # mixed 2: 35 
x 35 x 288 204 | branch1x1 = conv2d_bn(x, 64, 1, 1) 205 | 206 | branch5x5 = conv2d_bn(x, 48, 1, 1) 207 | branch5x5 = conv2d_bn(branch5x5, 64, 5, 5) 208 | 209 | branch3x3dbl = conv2d_bn(x, 64, 1, 1) 210 | branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) 211 | branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) 212 | 213 | branch_pool = layers.AveragePooling2D((3, 3), 214 | strides=(1, 1), 215 | padding='same')(x) 216 | branch_pool = conv2d_bn(branch_pool, 64, 1, 1) 217 | x = layers.concatenate( 218 | [branch1x1, branch5x5, branch3x3dbl, branch_pool], 219 | axis=channel_axis, 220 | name='mixed2') 221 | 222 | # mixed 3: 17 x 17 x 768 223 | branch3x3 = conv2d_bn(x, 384, 3, 3, strides=(2, 2), padding='same') 224 | 225 | branch3x3dbl = conv2d_bn(x, 64, 1, 1) 226 | branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) 227 | branch3x3dbl = conv2d_bn( 228 | branch3x3dbl, 96, 3, 3, strides=(2, 2), padding='same') 229 | 230 | branch_pool = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x) 231 | x = layers.concatenate( 232 | [branch3x3, branch3x3dbl, branch_pool], 233 | axis=channel_axis, 234 | name='mixed3') 235 | 236 | # mixed 4: 17 x 17 x 768 237 | branch1x1 = conv2d_bn(x, 192, 1, 1) 238 | 239 | branch7x7 = conv2d_bn(x, 128, 1, 1) 240 | branch7x7 = conv2d_bn(branch7x7, 128, 1, 7) 241 | branch7x7 = conv2d_bn(branch7x7, 192, 7, 1) 242 | 243 | branch7x7dbl = conv2d_bn(x, 128, 1, 1) 244 | branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 7, 1) 245 | branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 1, 7) 246 | branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 7, 1) 247 | branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7) 248 | 249 | branch_pool = layers.AveragePooling2D((3, 3), 250 | strides=(1, 1), 251 | padding='same')(x) 252 | branch_pool = conv2d_bn(branch_pool, 192, 1, 1) 253 | x = layers.concatenate( 254 | [branch1x1, branch7x7, branch7x7dbl, branch_pool], 255 | axis=channel_axis, 256 | name='mixed4') 257 | 258 | # mixed 5, 6: 17 x 17 x 768 259 | for i in range(2): 260 | branch1x1 = conv2d_bn(x, 192, 1, 1) 261 | 262 | branch7x7 = conv2d_bn(x, 160, 1, 1) 263 | branch7x7 = conv2d_bn(branch7x7, 160, 1, 7) 264 | branch7x7 = conv2d_bn(branch7x7, 192, 7, 1) 265 | 266 | branch7x7dbl = conv2d_bn(x, 160, 1, 1) 267 | branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 7, 1) 268 | branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 1, 7) 269 | branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 7, 1) 270 | branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7) 271 | 272 | branch_pool = layers.AveragePooling2D( 273 | (3, 3), strides=(1, 1), padding='same')(x) 274 | branch_pool = conv2d_bn(branch_pool, 192, 1, 1) 275 | x = layers.concatenate( 276 | [branch1x1, branch7x7, branch7x7dbl, branch_pool], 277 | axis=channel_axis, 278 | name='mixed' + str(5 + i)) 279 | 280 | # mixed 7: 17 x 17 x 768 281 | branch1x1 = conv2d_bn(x, 192, 1, 1) 282 | 283 | branch7x7 = conv2d_bn(x, 192, 1, 1) 284 | branch7x7 = conv2d_bn(branch7x7, 192, 1, 7) 285 | branch7x7 = conv2d_bn(branch7x7, 192, 7, 1) 286 | 287 | branch7x7dbl = conv2d_bn(x, 192, 1, 1) 288 | branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 7, 1) 289 | branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7) 290 | branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 7, 1) 291 | branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7) 292 | 293 | branch_pool = layers.AveragePooling2D((3, 3), 294 | strides=(1, 1), 295 | padding='same')(x) 296 | branch_pool = conv2d_bn(branch_pool, 192, 1, 1) 297 | x = layers.concatenate( 298 | [branch1x1, branch7x7, branch7x7dbl, branch_pool], 299 | axis=channel_axis, 300 | name='mixed7') 301 | 302 | # 
mixed 8: 8 x 8 x 1280 303 | branch3x3 = conv2d_bn(x, 192, 1, 1) 304 | branch3x3 = conv2d_bn(branch3x3, 320, 3, 3, 305 | strides=(2, 2), padding='same') 306 | 307 | branch7x7x3 = conv2d_bn(x, 192, 1, 1) 308 | branch7x7x3 = conv2d_bn(branch7x7x3, 192, 1, 7) 309 | branch7x7x3 = conv2d_bn(branch7x7x3, 192, 7, 1) 310 | branch7x7x3 = conv2d_bn( 311 | branch7x7x3, 192, 3, 3, strides=(2, 2), padding='same') 312 | 313 | branch_pool = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x) 314 | x = layers.concatenate( 315 | [branch3x3, branch7x7x3, branch_pool], 316 | axis=channel_axis, 317 | name='mixed8') 318 | 319 | # mixed 9: 8 x 8 x 2048 320 | for i in range(2): 321 | branch1x1 = conv2d_bn(x, 320, 1, 1) 322 | 323 | branch3x3 = conv2d_bn(x, 384, 1, 1) 324 | branch3x3_1 = conv2d_bn(branch3x3, 384, 1, 3) 325 | branch3x3_2 = conv2d_bn(branch3x3, 384, 3, 1) 326 | branch3x3 = layers.concatenate( 327 | [branch3x3_1, branch3x3_2], 328 | axis=channel_axis, 329 | name='mixed9_' + str(i)) 330 | 331 | branch3x3dbl = conv2d_bn(x, 448, 1, 1) 332 | branch3x3dbl = conv2d_bn(branch3x3dbl, 384, 3, 3) 333 | branch3x3dbl_1 = conv2d_bn(branch3x3dbl, 384, 1, 3) 334 | branch3x3dbl_2 = conv2d_bn(branch3x3dbl, 384, 3, 1) 335 | branch3x3dbl = layers.concatenate( 336 | [branch3x3dbl_1, branch3x3dbl_2], axis=channel_axis) 337 | 338 | branch_pool = layers.AveragePooling2D( 339 | (3, 3), strides=(1, 1), padding='same')(x) 340 | branch_pool = conv2d_bn(branch_pool, 192, 1, 1) 341 | x = layers.concatenate( 342 | [branch1x1, branch3x3, branch3x3dbl, branch_pool], 343 | axis=channel_axis, 344 | name='mixed' + str(9 + i)) 345 | if include_top: 346 | # Classification block 347 | x = layers.GlobalAveragePooling2D(name='avg_pool')(x) 348 | x = layers.Dense(classes, activation='softmax', name='predictions')(x) 349 | else: 350 | if pooling == 'avg': 351 | x = layers.GlobalAveragePooling2D()(x) 352 | elif pooling == 'max': 353 | x = layers.GlobalMaxPooling2D()(x) 354 | 355 | # Ensure that the model takes into account 356 | # any potential predecessors of `input_tensor`. 357 | if input_tensor is not None: 358 | inputs = keras_utils.get_source_inputs(input_tensor) 359 | else: 360 | inputs = img_input 361 | # Create model. 362 | model = models.Model(inputs, x, name='inception_v3') 363 | 364 | # Load weights. 365 | if weights == 'imagenet': 366 | if include_top: 367 | weights_path = keras_utils.get_file( 368 | 'inception_v3_weights_tf_dim_ordering_tf_kernels.h5', 369 | WEIGHTS_PATH, 370 | cache_subdir='models', 371 | file_hash='9a0d58056eeedaa3f26cb7ebd46da564') 372 | else: 373 | weights_path = keras_utils.get_file( 374 | 'inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5', 375 | WEIGHTS_PATH_NO_TOP, 376 | cache_subdir='models', 377 | file_hash='bcbd6486424b2319ff4ef7d526e38f63') 378 | model.load_weights(weights_path) 379 | elif weights is not None: 380 | model.load_weights(weights) 381 | 382 | return model 383 | 384 | 385 | def preprocess_input(x, **kwargs): 386 | """Preprocesses a numpy array encoding a batch of images. 387 | # Arguments 388 | x: a 4D numpy array consists of RGB values within [0, 255]. 389 | # Returns 390 | Preprocessed array. 391 | """ 392 | return imagenet_utils.preprocess_input(x, mode='tf', **kwargs) 393 | -------------------------------------------------------------------------------- /segmentation_models/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .objects import KerasObject, Loss, Metric 2 | from . 
import functional -------------------------------------------------------------------------------- /segmentation_models/base/functional.py: -------------------------------------------------------------------------------- 1 | SMOOTH = 1e-5 2 | 3 | 4 | # ---------------------------------------------------------------- 5 | # Helpers 6 | # ---------------------------------------------------------------- 7 | 8 | def _gather_channels(x, indexes, **kwargs): 9 | """Slice tensor along channels axis by given indexes""" 10 | backend = kwargs['backend'] 11 | if backend.image_data_format() == 'channels_last': 12 | x = backend.permute_dimensions(x, (3, 0, 1, 2)) 13 | x = backend.gather(x, indexes) 14 | x = backend.permute_dimensions(x, (1, 2, 3, 0)) 15 | else: 16 | x = backend.permute_dimensions(x, (1, 0, 2, 3)) 17 | x = backend.gather(x, indexes) 18 | x = backend.permute_dimensions(x, (1, 0, 2, 3)) 19 | return x 20 | 21 | 22 | def get_reduce_axes(per_image, **kwargs): 23 | backend = kwargs['backend'] 24 | axes = [1, 2] if backend.image_data_format() == 'channels_last' else [2, 3] 25 | if not per_image: 26 | axes.insert(0, 0) 27 | return axes 28 | 29 | 30 | def gather_channels(*xs, indexes=None, **kwargs): 31 | """Slice tensors along channels axis by given indexes""" 32 | if indexes is None: 33 | return xs 34 | elif isinstance(indexes, (int)): 35 | indexes = [indexes] 36 | xs = [_gather_channels(x, indexes=indexes, **kwargs) for x in xs] 37 | return xs 38 | 39 | 40 | def round_if_needed(x, threshold, **kwargs): 41 | backend = kwargs['backend'] 42 | if threshold is not None: 43 | x = backend.greater(x, threshold) 44 | x = backend.cast(x, backend.floatx()) 45 | return x 46 | 47 | 48 | def average(x, per_image=False, class_weights=None, **kwargs): 49 | backend = kwargs['backend'] 50 | if per_image: 51 | x = backend.mean(x, axis=0) 52 | if class_weights is not None: 53 | x = x * class_weights 54 | return backend.mean(x) 55 | 56 | 57 | # ---------------------------------------------------------------- 58 | # Metric Functions 59 | # ---------------------------------------------------------------- 60 | 61 | def iou_score(gt, pr, class_weights=1., class_indexes=None, smooth=SMOOTH, per_image=False, threshold=None, **kwargs): 62 | r""" The `Jaccard index`_, also known as Intersection over Union and the Jaccard similarity coefficient 63 | (originally coined coefficient de communauté by Paul Jaccard), is a statistic used for comparing the 64 | similarity and diversity of sample sets. The Jaccard coefficient measures similarity between finite sample sets, 65 | and is defined as the size of the intersection divided by the size of the union of the sample sets: 66 | 67 | .. math:: J(A, B) = \frac{A \cap B}{A \cup B} 68 | 69 | Args: 70 | gt: ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W) 71 | pr: prediction 4D keras tensor (B, H, W, C) or (B, C, H, W) 72 | class_weights: 1. or list of class weights, len(weights) = C 73 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used. 74 | smooth: value to avoid division by zero 75 | per_image: if ``True``, metric is calculated as mean over images in batch (B), 76 | else over whole batch 77 | threshold: value to round predictions (use ``>`` comparison), if ``None`` prediction will not be round 78 | 79 | Returns: 80 | IoU/Jaccard score in range [0, 1] 81 | 82 | .. 
_`Jaccard index`: https://en.wikipedia.org/wiki/Jaccard_index 83 | 84 | """ 85 | 86 | backend = kwargs['backend'] 87 | 88 | gt, pr = gather_channels(gt, pr, indexes=class_indexes, **kwargs) 89 | pr = round_if_needed(pr, threshold, **kwargs) 90 | axes = get_reduce_axes(per_image, **kwargs) 91 | 92 | # score calculation 93 | intersection = backend.sum(gt * pr, axis=axes) 94 | union = backend.sum(gt + pr, axis=axes) - intersection 95 | 96 | score = (intersection + smooth) / (union + smooth) 97 | score = average(score, per_image, class_weights, **kwargs) 98 | 99 | return score 100 | 101 | 102 | def f_score(gt, pr, beta=1, class_weights=1, class_indexes=None, smooth=SMOOTH, per_image=False, threshold=None, 103 | **kwargs): 104 | r"""The F-score (Dice coefficient) can be interpreted as a weighted average of the precision and recall, 105 | where an F-score reaches its best value at 1 and worst score at 0. 106 | The relative contribution of ``precision`` and ``recall`` to the F1-score are equal. 107 | The formula for the F score is: 108 | 109 | .. math:: F_\beta(precision, recall) = (1 + \beta^2) \frac{precision \cdot recall} 110 | {\beta^2 \cdot precision + recall} 111 | 112 | The formula in terms of *Type I* and *Type II* errors: 113 | 114 | .. math:: F_\beta(A, B) = \frac{(1 + \beta^2) TP} {(1 + \beta^2) TP + \beta^2 FN + FP} 115 | 116 | 117 | where: 118 | TP - true positive; 119 | FP - false positive; 120 | FN - false negative; 121 | 122 | Args: 123 | gt: ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W) 124 | pr: prediction 4D keras tensor (B, H, W, C) or (B, C, H, W) 125 | class_weights: 1. or list of class weights, len(weights) = C 126 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used. 127 | beta: f-score coefficient 128 | smooth: value to avoid division by zero 129 | per_image: if ``True``, metric is calculated as mean over images in batch (B), 130 | else over whole batch 131 | threshold: value to round predictions (use ``>`` comparison), if ``None`` prediction will not be round 132 | 133 | Returns: 134 | F-score in range [0, 1] 135 | 136 | """ 137 | 138 | backend = kwargs['backend'] 139 | 140 | gt, pr = gather_channels(gt, pr, indexes=class_indexes, **kwargs) 141 | pr = round_if_needed(pr, threshold, **kwargs) 142 | axes = get_reduce_axes(per_image, **kwargs) 143 | 144 | # calculate score 145 | tp = backend.sum(gt * pr, axis=axes) 146 | fp = backend.sum(pr, axis=axes) - tp 147 | fn = backend.sum(gt, axis=axes) - tp 148 | 149 | score = ((1 + beta ** 2) * tp + smooth) \ 150 | / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth) 151 | score = average(score, per_image, class_weights, **kwargs) 152 | 153 | return score 154 | 155 | 156 | def precision(gt, pr, class_weights=1, class_indexes=None, smooth=SMOOTH, per_image=False, threshold=None, **kwargs): 157 | r"""Calculate precision between the ground truth (gt) and the prediction (pr). 158 | 159 | .. math:: F_\beta(tp, fp) = \frac{tp} {(tp + fp)} 160 | 161 | where: 162 | - tp - true positives; 163 | - fp - false positives; 164 | 165 | Args: 166 | gt: ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W) 167 | pr: prediction 4D keras tensor (B, H, W, C) or (B, C, H, W) 168 | class_weights: 1. or ``np.array`` of class weights (``len(weights) = num_classes``) 169 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used. 170 | smooth: Float value to avoid division by zero. 
171 | per_image: If ``True``, metric is calculated as mean over images in batch (B), 172 | else over whole batch. 173 | threshold: Float value to round predictions (use ``>`` comparison), if ``None`` prediction will not be round. 174 | 175 | 176 | Returns: 177 | float: precision score 178 | """ 179 | backend = kwargs['backend'] 180 | 181 | gt, pr = gather_channels(gt, pr, indexes=class_indexes, **kwargs) 182 | pr = round_if_needed(pr, threshold, **kwargs) 183 | axes = get_reduce_axes(per_image, **kwargs) 184 | 185 | # score calculation 186 | tp = backend.sum(gt * pr, axis=axes) 187 | fp = backend.sum(pr, axis=axes) - tp 188 | 189 | score = (tp + smooth) / (tp + fp + smooth) 190 | score = average(score, per_image, class_weights, **kwargs) 191 | 192 | return score 193 | 194 | 195 | def recall(gt, pr, class_weights=1, class_indexes=None, smooth=SMOOTH, per_image=False, threshold=None, **kwargs): 196 | r"""Calculate recall between the ground truth (gt) and the prediction (pr). 197 | 198 | .. math:: F_\beta(tp, fn) = \frac{tp} {(tp + fn)} 199 | 200 | where: 201 | - tp - true positives; 202 | - fn - false negatives; 203 | 204 | Args: 205 | gt: ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W) 206 | pr: prediction 4D keras tensor (B, H, W, C) or (B, C, H, W) 207 | class_weights: 1. or ``np.array`` of class weights (``len(weights) = num_classes``) 208 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used. 209 | smooth: Float value to avoid division by zero. 210 | per_image: If ``True``, metric is calculated as mean over images in batch (B), 211 | else over whole batch. 212 | threshold: Float value to round predictions (use ``>`` comparison), if ``None`` prediction will not be round. 213 |
214 | 215 | Returns: 216 | float: recall score 217 | """ 218 | backend = kwargs['backend'] 219 | 220 | gt, pr = gather_channels(gt, pr, indexes=class_indexes, **kwargs) 221 | pr = round_if_needed(pr, threshold, **kwargs) 222 | axes = get_reduce_axes(per_image, **kwargs) 223 | 224 | tp = backend.sum(gt * pr, axis=axes) 225 | fn = backend.sum(gt, axis=axes) - tp 226 | 227 | score = (tp + smooth) / (tp + fn + smooth) 228 | score = average(score, per_image, class_weights, **kwargs) 229 | 230 | return score 231 | 232 | 233 | # ---------------------------------------------------------------- 234 | # Loss Functions 235 | # ---------------------------------------------------------------- 236 | 237 | def categorical_crossentropy(gt, pr, class_weights=1., class_indexes=None, **kwargs): 238 | backend = kwargs['backend'] 239 | 240 | gt, pr = gather_channels(gt, pr, indexes=class_indexes, **kwargs) 241 | 242 | # scale predictions so that the class probas of each sample sum to 1 243 | axis = 3 if backend.image_data_format() == 'channels_last' else 1 244 | pr /= backend.sum(pr, axis=axis, keepdims=True) 245 | 246 | # clip to prevent NaN's and Inf's 247 | pr = backend.clip(pr, backend.epsilon(), 1 - backend.epsilon()) 248 | 249 | # calculate loss 250 | output = gt * backend.log(pr) * class_weights 251 | return - backend.mean(output) 252 | 253 | 254 | def binary_crossentropy(gt, pr, **kwargs): 255 | backend = kwargs['backend'] 256 | return backend.mean(backend.binary_crossentropy(gt, pr)) 257 | 258 | 259 | def categorical_focal_loss(gt, pr, gamma=2.0, alpha=0.25, class_indexes=None, **kwargs): 260 | r"""Implementation of Focal Loss from the paper in multiclass classification 261 | 262 | Formula: 263 | loss = - gt * alpha * ((1 - pr)^gamma) * log(pr) 264 | 265 | Args: 266 | gt: ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W) 267 | pr: prediction 4D keras tensor (B, H, W, C) or (B, C, H, W) 268 | alpha: the same as weighting factor in balanced cross entropy, default 0.25 269 | gamma: focusing parameter for modulating factor (1-p), default 2.0 270 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used. 
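Example:
with the default ``gamma=2.0``, a well-classified pixel with ``pr = 0.9`` is scaled by the modulating factor (1 - 0.9)^2 = 0.01, while a hard pixel with ``pr = 0.1`` is scaled by (1 - 0.1)^2 = 0.81, so the loss concentrates on hard examples (the probabilities here are illustrative).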
271 | 272 | """ 273 | 274 | backend = kwargs['backend'] 275 | gt, pr = gather_channels(gt, pr, indexes=class_indexes, **kwargs) 276 | 277 | # clip to prevent NaN's and Inf's 278 | pr = backend.clip(pr, backend.epsilon(), 1.0 - backend.epsilon()) 279 | 280 | # Calculate focal loss 281 | loss = - gt * (alpha * backend.pow((1 - pr), gamma) * backend.log(pr)) 282 | 283 | return backend.mean(loss) 284 | 285 | 286 | def binary_focal_loss(gt, pr, gamma=2.0, alpha=0.25, **kwargs): 287 | r"""Implementation of Focal Loss from the paper in binary classification 288 | 289 | Formula: 290 | loss = - gt * alpha * ((1 - pr)^gamma) * log(pr) \ 291 | - (1 - gt) * alpha * (pr^gamma) * log(1 - pr) 292 | 293 | Args: 294 | gt: ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W) 295 | pr: prediction 4D keras tensor (B, H, W, C) or (B, C, H, W) 296 | alpha: the same as weighting factor in balanced cross entropy, default 0.25 297 | gamma: focusing parameter for modulating factor (1-p), default 2.0 298 | 299 | """ 300 | backend = kwargs['backend'] 301 | 302 | # clip to prevent NaN's and Inf's 303 | pr = backend.clip(pr, backend.epsilon(), 1.0 - backend.epsilon()) 304 | 305 | loss_1 = - gt * (alpha * backend.pow((1 - pr), gamma) * backend.log(pr)) 306 | loss_0 = - (1 - gt) * ((1 - alpha) * backend.pow((pr), gamma) * backend.log(1 - pr)) 307 | loss = backend.mean(loss_0 + loss_1) 308 | return loss 309 | -------------------------------------------------------------------------------- /segmentation_models/base/objects.py: -------------------------------------------------------------------------------- 1 | class KerasObject: 2 | _backend = None 3 | _models = None 4 | _layers = None 5 | _utils = None 6 | 7 | def __init__(self, name=None): 8 | if (self.backend is None or 9 | self.utils is None or 10 | self.models is None or 11 | self.layers is None): 12 | raise RuntimeError('You cannot use `KerasObjects` with None submodules.') 13 | 14 | self._name = name 15 | 16 | @property 17 | def __name__(self): 18 | if self._name is None: 19 | return self.__class__.__name__ 20 | return self._name 21 | 22 | @property 23 | def name(self): 24 | return self.__name__ 25 | 26 | @name.setter 27 | def name(self, name): 28 | self._name = name 29 | 30 | @classmethod 31 | def set_submodules(cls, backend, layers, models, utils): 32 | cls._backend = backend 33 | cls._layers = layers 34 | cls._models = models 35 | cls._utils = utils 36 | 37 | @property 38 | def submodules(self): 39 | return { 40 | 'backend': self.backend, 41 | 'layers': self.layers, 42 | 'models': self.models, 43 | 'utils': self.utils, 44 | } 45 | 46 | @property 47 | def backend(self): 48 | return self._backend 49 | 50 | @property 51 | def layers(self): 52 | return self._layers 53 | 54 | @property 55 | def models(self): 56 | return self._models 57 | 58 | @property 59 | def utils(self): 60 | return self._utils 61 | 62 | 63 | class Metric(KerasObject): 64 | pass 65 | 66 | 67 | class Loss(KerasObject): 68 | 69 | def __add__(self, other): 70 | if isinstance(other, Loss): 71 | return SumOfLosses(self, other) 72 | else: 73 | raise ValueError('Loss should be inherited from `Loss` class') 74 | 75 | def __radd__(self, other): 76 | return self.__add__(other) 77 | 78 | def __mul__(self, value): 79 | if isinstance(value, (int, float)): 80 | return MultipliedLoss(self, value) 81 | else: 82 | raise ValueError('Loss multiplier should be `int` or `float` value') 83 | 84 | def __rmul__(self, other): 85 | return self.__mul__(other) 86 | 87 | 88 | class MultipliedLoss(Loss): 89 | 90 | def 
__init__(self, loss, multiplier): 91 | 92 | # resolve name 93 | if len(loss.__name__.split('+')) > 1: 94 | name = '{}({})'.format(multiplier, loss.__name__) 95 | else: 96 | name = '{}{}'.format(multiplier, loss.__name__) 97 | super().__init__(name=name) 98 | self.loss = loss 99 | self.multiplier = multiplier 100 | 101 | def __call__(self, gt, pr): 102 | return self.multiplier * self.loss(gt, pr) 103 | 104 | 105 | class SumOfLosses(Loss): 106 | 107 | def __init__(self, l1, l2): 108 | name = '{}_plus_{}'.format(l1.__name__, l2.__name__) 109 | super().__init__(name=name) 110 | self.l1 = l1 111 | self.l2 = l2 112 | 113 | def __call__(self, gt, pr): 114 | return self.l1(gt, pr) + self.l2(gt, pr) 115 | -------------------------------------------------------------------------------- /segmentation_models/losses.py: -------------------------------------------------------------------------------- 1 | from .base import Loss 2 | from .base import functional as F 3 | 4 | SMOOTH = 1e-5 5 | 6 | 7 | class JaccardLoss(Loss): 8 | r"""Creates a criterion to measure Jaccard loss: 9 | 10 | .. math:: L(A, B) = 1 - \frac{A \cap B}{A \cup B} 11 | 12 | Args: 13 | class_weights: Array (``np.array``) of class weights (``len(weights) = num_classes``). 14 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used. 15 | per_image: If ``True`` loss is calculated for each image in batch and then averaged, 16 | else loss is calculated for the whole batch. 17 | smooth: Value to avoid division by zero. 18 | 19 | Returns: 20 | A callable ``jaccard_loss`` instance. Can be used in ``model.compile(...)`` function 21 | or combined with other losses. 22 | 23 | Example: 24 | 25 | .. code:: python 26 | 27 | loss = JaccardLoss() 28 | model.compile('SGD', loss=loss) 29 | """ 30 | 31 | def __init__(self, class_weights=None, class_indexes=None, per_image=False, smooth=SMOOTH): 32 | super().__init__(name='jaccard_loss') 33 | self.class_weights = class_weights if class_weights is not None else 1 34 | self.class_indexes = class_indexes 35 | self.per_image = per_image 36 | self.smooth = smooth 37 | 38 | def __call__(self, gt, pr): 39 | return 1 - F.iou_score( 40 | gt, 41 | pr, 42 | class_weights=self.class_weights, 43 | class_indexes=self.class_indexes, 44 | smooth=self.smooth, 45 | per_image=self.per_image, 46 | threshold=None, 47 | **self.submodules 48 | ) 49 | 50 | 51 | class DiceLoss(Loss): 52 | r"""Creates a criterion to measure Dice loss: 53 | 54 | .. math:: L(precision, recall) = 1 - (1 + \beta^2) \frac{precision \cdot recall} 55 | {\beta^2 \cdot precision + recall} 56 | 57 | The formula in terms of *Type I* and *Type II* errors: 58 | 59 | .. math:: L(tp, fp, fn) = 1 - \frac{(1 + \beta^2) \cdot tp} {(1 + \beta^2) \cdot tp + \beta^2 \cdot fn + fp} 60 | 61 | where: 62 | - tp - true positives; 63 | - fp - false positives; 64 | - fn - false negatives; 65 | 66 | Args: 67 | beta: Float or integer coefficient for precision and recall balance. 68 | class_weights: Array (``np.array``) of class weights (``len(weights) = num_classes``). 69 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used. 70 | per_image: If ``True`` loss is calculated for each image in batch and then averaged, 71 | else loss is calculated for the whole batch. 72 | smooth: Value to avoid division by zero. 73 | 74 | Returns: 75 | A callable ``dice_loss`` instance. Can be used in ``model.compile(...)`` function 76 | or combined with other losses.
77 | 78 | Example: 79 | 80 | .. code:: python 81 | 82 | loss = DiceLoss() 83 | model.compile('SGD', loss=loss) 84 | """ 85 | 86 | def __init__(self, beta=1, class_weights=None, class_indexes=None, per_image=False, smooth=SMOOTH): 87 | super().__init__(name='dice_loss') 88 | self.beta = beta 89 | self.class_weights = class_weights if class_weights is not None else 1 90 | self.class_indexes = class_indexes 91 | self.per_image = per_image 92 | self.smooth = smooth 93 | 94 | def __call__(self, gt, pr): 95 | return 1 - F.f_score( 96 | gt, 97 | pr, 98 | beta=self.beta, 99 | class_weights=self.class_weights, 100 | class_indexes=self.class_indexes, 101 | smooth=self.smooth, 102 | per_image=self.per_image, 103 | threshold=None, 104 | **self.submodules 105 | ) 106 | 107 | 108 | class BinaryCELoss(Loss): 109 | """Creates a criterion that measures the Binary Cross Entropy between the 110 | ground truth (gt) and the prediction (pr). 111 | 112 | .. math:: L(gt, pr) = - gt \cdot \log(pr) - (1 - gt) \cdot \log(1 - pr) 113 | 114 | Returns: 115 | A callable ``binary_crossentropy`` instance. Can be used in ``model.compile(...)`` function 116 | or combined with other losses. 117 | 118 | Example: 119 | 120 | .. code:: python 121 | 122 | loss = BinaryCELoss() 123 | model.compile('SGD', loss=loss) 124 | """ 125 | 126 | def __init__(self): 127 | super().__init__(name='binary_crossentropy') 128 | 129 | def __call__(self, gt, pr): 130 | return F.binary_crossentropy(gt, pr, **self.submodules) 131 | 132 | 133 | class CategoricalCELoss(Loss): 134 | """Creates a criterion that measures the Categorical Cross Entropy between the 135 | ground truth (gt) and the prediction (pr). 136 | 137 | .. math:: L(gt, pr) = - gt \cdot \log(pr) 138 | 139 | Args: 140 | class_weights: Array (``np.array``) of class weights (``len(weights) = num_classes``). 141 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used. 142 | 143 | Returns: 144 | A callable ``categorical_crossentropy`` instance. Can be used in ``model.compile(...)`` function 145 | or combined with other losses. 146 | 147 | Example: 148 | 149 | .. code:: python 150 | 151 | loss = CategoricalCELoss() 152 | model.compile('SGD', loss=loss) 153 | """ 154 | 155 | def __init__(self, class_weights=None, class_indexes=None): 156 | super().__init__(name='categorical_crossentropy') 157 | self.class_weights = class_weights if class_weights is not None else 1 158 | self.class_indexes = class_indexes 159 | 160 | def __call__(self, gt, pr): 161 | return F.categorical_crossentropy( 162 | gt, 163 | pr, 164 | class_weights=self.class_weights, 165 | class_indexes=self.class_indexes, 166 | **self.submodules 167 | ) 168 | 169 | 170 | class CategoricalFocalLoss(Loss): 171 | r"""Creates a criterion that measures the Categorical Focal Loss between the 172 | ground truth (gt) and the prediction (pr). 173 | 174 | .. math:: L(gt, pr) = - gt \cdot \alpha \cdot (1 - pr)^\gamma \cdot \log(pr) 175 | 176 | Args: 177 | alpha: Float or integer, the same as weighting factor in balanced cross entropy, default 0.25. 178 | gamma: Float or integer, focusing parameter for modulating factor (1 - p), default 2.0. 179 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used. 180 | 181 | Returns: 182 | A callable ``categorical_focal_loss`` instance. Can be used in ``model.compile(...)`` function 183 | or combined with other losses. 184 | 185 | Example: 186 | 187 | .. 
code:: python 188 | 189 | loss = CategoricalFocalLoss() 190 | model.compile('SGD', loss=loss) 191 | """ 192 | 193 | def __init__(self, alpha=0.25, gamma=2., class_indexes=None): 194 | super().__init__(name='focal_loss') 195 | self.alpha = alpha 196 | self.gamma = gamma 197 | self.class_indexes = class_indexes 198 | 199 | def __call__(self, gt, pr): 200 | return F.categorical_focal_loss( 201 | gt, 202 | pr, 203 | alpha=self.alpha, 204 | gamma=self.gamma, 205 | class_indexes=self.class_indexes, 206 | **self.submodules 207 | ) 208 | 209 | 210 | class BinaryFocalLoss(Loss): 211 | r"""Creates a criterion that measures the Binary Focal Loss between the 212 | ground truth (gt) and the prediction (pr). 213 | 214 | .. math:: L(gt, pr) = - gt \alpha (1 - pr)^\gamma \log(pr) - (1 - gt) \alpha pr^\gamma \log(1 - pr) 215 | 216 | Args: 217 | alpha: Float or integer, the same as weighting factor in balanced cross entropy, default 0.25. 218 | gamma: Float or integer, focusing parameter for modulating factor (1 - p), default 2.0. 219 | 220 | Returns: 221 | A callable ``binary_focal_loss`` instance. Can be used in ``model.compile(...)`` function 222 | or combined with other losses. 223 | 224 | Example: 225 | 226 | .. code:: python 227 | 228 | loss = BinaryFocalLoss() 229 | model.compile('SGD', loss=loss) 230 | """ 231 | 232 | def __init__(self, alpha=0.25, gamma=2.): 233 | super().__init__(name='binary_focal_loss') 234 | self.alpha = alpha 235 | self.gamma = gamma 236 | 237 | def __call__(self, gt, pr): 238 | return F.binary_focal_loss(gt, pr, alpha=self.alpha, gamma=self.gamma, **self.submodules) 239 | 240 | 241 | # aliases 242 | jaccard_loss = JaccardLoss() 243 | dice_loss = DiceLoss() 244 | 245 | binary_focal_loss = BinaryFocalLoss() 246 | categorical_focal_loss = CategoricalFocalLoss() 247 | 248 | binary_crossentropy = BinaryCELoss() 249 | categorical_crossentropy = CategoricalCELoss() 250 | 251 | # loss combinations 252 | bce_dice_loss = binary_crossentropy + dice_loss 253 | bce_jaccard_loss = binary_crossentropy + jaccard_loss 254 | 255 | cce_dice_loss = categorical_crossentropy + dice_loss 256 | cce_jaccard_loss = categorical_crossentropy + jaccard_loss 257 | 258 | binary_focal_dice_loss = binary_focal_loss + dice_loss 259 | binary_focal_jaccard_loss = binary_focal_loss + jaccard_loss 260 | 261 | categorical_focal_dice_loss = categorical_focal_loss + dice_loss 262 | categorical_focal_jaccard_loss = categorical_focal_loss + jaccard_loss 263 | -------------------------------------------------------------------------------- /segmentation_models/metrics.py: -------------------------------------------------------------------------------- 1 | from .base import Metric 2 | from .base import functional as F 3 | 4 | SMOOTH = 1e-5 5 | 6 | 7 | class IOUScore(Metric): 8 | r""" The `Jaccard index`_, also known as Intersection over Union and the Jaccard similarity coefficient 9 | (originally coined coefficient de communauté by Paul Jaccard), is a statistic used for comparing the 10 | similarity and diversity of sample sets. The Jaccard coefficient measures similarity between finite sample sets, 11 | and is defined as the size of the intersection divided by the size of the union of the sample sets: 12 | 13 | .. math:: J(A, B) = \frac{A \cap B}{A \cup B} 14 | 15 | Args: 16 | class_weights: 1. or ``np.array`` of class weights (``len(weights) = num_classes``). 17 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used. 
18 | smooth: value to avoid division by zero 19 | per_image: if ``True``, metric is calculated as mean over images in batch (B), 20 | else over whole batch 21 | threshold: value to round predictions (use ``>`` comparison), if ``None`` prediction will not be round 22 | 23 | Returns: 24 | A callable ``iou_score`` instance. Can be used in ``model.compile(...)`` function. 25 | 26 | .. _`Jaccard index`: https://en.wikipedia.org/wiki/Jaccard_index 27 | 28 | Example: 29 | 30 | .. code:: python 31 | 32 | metric = IOUScore() 33 | model.compile('SGD', loss=loss, metrics=[metric]) 34 | """ 35 | 36 | def __init__( 37 | self, 38 | class_weights=None, 39 | class_indexes=None, 40 | threshold=None, 41 | per_image=False, 42 | smooth=SMOOTH, 43 | name=None, 44 | ): 45 | name = name or 'iou_score' 46 | super().__init__(name=name) 47 | self.class_weights = class_weights if class_weights is not None else 1 48 | self.class_indexes = class_indexes 49 | self.threshold = threshold 50 | self.per_image = per_image 51 | self.smooth = smooth 52 | 53 | def __call__(self, gt, pr): 54 | return F.iou_score( 55 | gt, 56 | pr, 57 | class_weights=self.class_weights, 58 | class_indexes=self.class_indexes, 59 | smooth=self.smooth, 60 | per_image=self.per_image, 61 | threshold=self.threshold, 62 | **self.submodules 63 | ) 64 | 65 | 66 | class FScore(Metric): 67 | r"""The F-score (Dice coefficient) can be interpreted as a weighted average of the precision and recall, 68 | where an F-score reaches its best value at 1 and worst score at 0. 69 | The relative contribution of ``precision`` and ``recall`` to the F1-score are equal. 70 | The formula for the F score is: 71 | 72 | .. math:: F_\beta(precision, recall) = (1 + \beta^2) \frac{precision \cdot recall} 73 | {\beta^2 \cdot precision + recall} 74 | 75 | The formula in terms of *Type I* and *Type II* errors: 76 | 77 | .. math:: F_\beta(tp, fp, fn) = \frac{(1 + \beta^2) \cdot tp} {(1 + \beta^2) \cdot tp + \beta^2 \cdot fn + fp} 78 | 79 | where: 80 | - tp - true positives; 81 | - fp - false positives; 82 | - fn - false negatives; 83 | 84 | Args: 85 | beta: Integer or float f-score coefficient to balance precision and recall. 86 | class_weights: 1. or ``np.array`` of class weights (``len(weights) = num_classes``) 87 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used. 88 | smooth: Float value to avoid division by zero. 89 | per_image: If ``True``, metric is calculated as mean over images in batch (B), 90 | else over whole batch. 91 | threshold: Float value to round predictions (use ``>`` comparison), if ``None`` prediction will not be round. 92 | name: Optional string, if ``None`` default ``f{beta}-score`` name is used. 93 | 94 | Returns: 95 | A callable ``f_score`` instance. Can be used in ``model.compile(...)`` function. 96 | 97 | Example: 98 | 99 | ..
code:: python 100 | 101 | metric = FScore() 102 | model.compile('SGD', loss=loss, metrics=[metric]) 103 | """ 104 | 105 | def __init__( 106 | self, 107 | beta=1, 108 | class_weights=None, 109 | class_indexes=None, 110 | threshold=None, 111 | per_image=False, 112 | smooth=SMOOTH, 113 | name=None, 114 | ): 115 | name = name or 'f{}-score'.format(beta) 116 | super().__init__(name=name) 117 | self.beta = beta 118 | self.class_weights = class_weights if class_weights is not None else 1 119 | self.class_indexes = class_indexes 120 | self.threshold = threshold 121 | self.per_image = per_image 122 | self.smooth = smooth 123 | 124 | def __call__(self, gt, pr): 125 | return F.f_score( 126 | gt, 127 | pr, 128 | beta=self.beta, 129 | class_weights=self.class_weights, 130 | class_indexes=self.class_indexes, 131 | smooth=self.smooth, 132 | per_image=self.per_image, 133 | threshold=self.threshold, 134 | **self.submodules 135 | ) 136 | 137 | 138 | class Precision(Metric): 139 | r"""Creates a criterion that measures the Precision between the 140 | ground truth (gt) and the prediction (pr). 141 | 142 | .. math:: F_\beta(tp, fp) = \frac{tp} {(tp + fp)} 143 | 144 | where: 145 | - tp - true positives; 146 | - fp - false positives; 147 | 148 | Args: 149 | class_weights: 1. or ``np.array`` of class weights (``len(weights) = num_classes``). 150 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used. 151 | smooth: Float value to avoid division by zero. 152 | per_image: If ``True``, metric is calculated as mean over images in batch (B), 153 | else over whole batch. 154 | threshold: Float value to round predictions (use ``>`` comparison), if ``None`` prediction will not be round. 155 | name: Optional string, if ``None`` default ``precision`` name is used. 156 | 157 | Returns: 158 | A callable ``precision`` instance. Can be used in ``model.compile(...)`` function. 159 | 160 | Example: 161 | 162 | .. code:: python 163 | 164 | metric = Precision() 165 | model.compile('SGD', loss=loss, metrics=[metric]) 166 | """ 167 | 168 | def __init__( 169 | self, 170 | class_weights=None, 171 | class_indexes=None, 172 | threshold=None, 173 | per_image=False, 174 | smooth=SMOOTH, 175 | name=None, 176 | ): 177 | name = name or 'precision' 178 | super().__init__(name=name) 179 | self.class_weights = class_weights if class_weights is not None else 1 180 | self.class_indexes = class_indexes 181 | self.threshold = threshold 182 | self.per_image = per_image 183 | self.smooth = smooth 184 | 185 | def __call__(self, gt, pr): 186 | return F.precision( 187 | gt, 188 | pr, 189 | class_weights=self.class_weights, 190 | class_indexes=self.class_indexes, 191 | smooth=self.smooth, 192 | per_image=self.per_image, 193 | threshold=self.threshold, 194 | **self.submodules 195 | ) 196 | 197 | 198 | class Recall(Metric): 199 | r"""Creates a criterion that measures the Recall between the 200 | ground truth (gt) and the prediction (pr). 201 | 202 | .. math:: F_\beta(tp, fn) = \frac{tp} {(tp + fn)} 203 | 204 | where: 205 | - tp - true positives; 206 | - fn - false negatives; 207 | 208 | Args: 209 | class_weights: 1. or ``np.array`` of class weights (``len(weights) = num_classes``). 210 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used. 211 | smooth: Float value to avoid division by zero. 212 | per_image: If ``True``, metric is calculated as mean over images in batch (B), 213 | else over whole batch.
214 | threshold: Float value to round predictions (use ``>`` comparison), if ``None`` prediction will not be round. 215 | name: Optional string, if ``None`` default ``recall`` name is used. 216 | 217 | Returns: 218 | A callable ``recall`` instance. Can be used in ``model.compile(...)`` function. 219 | 220 | Example: 221 | 222 | .. code:: python 223 | 224 | metric = Recall() 225 | model.compile('SGD', loss=loss, metrics=[metric]) 226 | """ 227 | 228 | def __init__( 229 | self, 230 | class_weights=None, 231 | class_indexes=None, 232 | threshold=None, 233 | per_image=False, 234 | smooth=SMOOTH, 235 | name=None, 236 | ): 237 | name = name or 'recall' 238 | super().__init__(name=name) 239 | self.class_weights = class_weights if class_weights is not None else 1 240 | self.class_indexes = class_indexes 241 | self.threshold = threshold 242 | self.per_image = per_image 243 | self.smooth = smooth 244 | 245 | def __call__(self, gt, pr): 246 | return F.recall( 247 | gt, 248 | pr, 249 | class_weights=self.class_weights, 250 | class_indexes=self.class_indexes, 251 | smooth=self.smooth, 252 | per_image=self.per_image, 253 | threshold=self.threshold, 254 | **self.submodules 255 | ) 256 | 257 | 258 | # aliases 259 | iou_score = IOUScore() 260 | f1_score = FScore(beta=1) 261 | f2_score = FScore(beta=2) 262 | precision = Precision() 263 | recall = Recall() 264 | -------------------------------------------------------------------------------- /segmentation_models/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qubvel/segmentation_models/5d24bbfb28af6134e25e2c0b79e7727f6c0491d0/segmentation_models/models/__init__.py -------------------------------------------------------------------------------- /segmentation_models/models/_common_blocks.py: -------------------------------------------------------------------------------- 1 | from keras_applications import get_submodules_from_kwargs 2 | 3 | 4 | def Conv2dBn( 5 | filters, 6 | kernel_size, 7 | strides=(1, 1), 8 | padding='valid', 9 | data_format=None, 10 | dilation_rate=(1, 1), 11 | activation=None, 12 | kernel_initializer='glorot_uniform', 13 | bias_initializer='zeros', 14 | kernel_regularizer=None, 15 | bias_regularizer=None, 16 | activity_regularizer=None, 17 | kernel_constraint=None, 18 | bias_constraint=None, 19 | use_batchnorm=False, 20 | **kwargs 21 | ): 22 | """Extension of Conv2D layer with batchnorm""" 23 | 24 | conv_name, act_name, bn_name = None, None, None 25 | block_name = kwargs.pop('name', None) 26 | backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs) 27 | 28 | if block_name is not None: 29 | conv_name = block_name + '_conv' 30 | 31 | if block_name is not None and activation is not None: 32 | act_str = activation.__name__ if callable(activation) else str(activation) 33 | act_name = block_name + '_' + act_str 34 | 35 | if block_name is not None and use_batchnorm: 36 | bn_name = block_name + '_bn' 37 | 38 | bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1 39 | 40 | def wrapper(input_tensor): 41 | 42 | x = layers.Conv2D( 43 | filters=filters, 44 | kernel_size=kernel_size, 45 | strides=strides, 46 | padding=padding, 47 | data_format=data_format, 48 | dilation_rate=dilation_rate, 49 | activation=None, 50 | use_bias=not (use_batchnorm), 51 | kernel_initializer=kernel_initializer, 52 | bias_initializer=bias_initializer, 53 | kernel_regularizer=kernel_regularizer, 54 | bias_regularizer=bias_regularizer, 55 |
activity_regularizer=activity_regularizer, 56 | kernel_constraint=kernel_constraint, 57 | bias_constraint=bias_constraint, 58 | name=conv_name, 59 | )(input_tensor) 60 | 61 | if use_batchnorm: 62 | x = layers.BatchNormalization(axis=bn_axis, name=bn_name)(x) 63 | 64 | if activation: 65 | x = layers.Activation(activation, name=act_name)(x) 66 | 67 | return x 68 | 69 | return wrapper 70 | -------------------------------------------------------------------------------- /segmentation_models/models/_utils.py: -------------------------------------------------------------------------------- 1 | from keras_applications import get_submodules_from_kwargs 2 | 3 | 4 | def freeze_model(model, **kwargs): 5 | """Set all layers non-trainable, excluding BatchNormalization layers""" 6 | _, layers, _, _ = get_submodules_from_kwargs(kwargs) 7 | for layer in model.layers: 8 | if not isinstance(layer, layers.BatchNormalization): 9 | layer.trainable = False 10 | return 11 | 12 | 13 | def filter_keras_submodules(kwargs): 14 | """Selects only arguments that define keras_application submodules. """ 15 | submodule_keys = kwargs.keys() & {'backend', 'layers', 'models', 'utils'} 16 | return {key: kwargs[key] for key in submodule_keys} 17 | -------------------------------------------------------------------------------- /segmentation_models/models/fpn.py: -------------------------------------------------------------------------------- 1 | from keras_applications import get_submodules_from_kwargs 2 | 3 | from ._common_blocks import Conv2dBn 4 | from ._utils import freeze_model, filter_keras_submodules 5 | from ..backbones.backbones_factory import Backbones 6 | 7 | backend = None 8 | layers = None 9 | models = None 10 | keras_utils = None 11 | 12 | 13 | # --------------------------------------------------------------------- 14 | # Utility functions 15 | # --------------------------------------------------------------------- 16 | 17 | def get_submodules(): 18 | return { 19 | 'backend': backend, 20 | 'models': models, 21 | 'layers': layers, 22 | 'utils': keras_utils, 23 | } 24 | 25 | 26 | # --------------------------------------------------------------------- 27 | # Blocks 28 | # --------------------------------------------------------------------- 29 | 30 | def Conv3x3BnReLU(filters, use_batchnorm, name=None): 31 | kwargs = get_submodules() 32 | 33 | def wrapper(input_tensor): 34 | return Conv2dBn( 35 | filters, 36 | kernel_size=3, 37 | activation='relu', 38 | kernel_initializer='he_uniform', 39 | padding='same', 40 | use_batchnorm=use_batchnorm, 41 | name=name, 42 | **kwargs 43 | )(input_tensor) 44 | 45 | return wrapper 46 | 47 | 48 | def DoubleConv3x3BnReLU(filters, use_batchnorm, name=None): 49 | name1, name2 = None, None 50 | if name is not None: 51 | name1 = name + 'a' 52 | name2 = name + 'b' 53 | 54 | def wrapper(input_tensor): 55 | x = Conv3x3BnReLU(filters, use_batchnorm, name=name1)(input_tensor) 56 | x = Conv3x3BnReLU(filters, use_batchnorm, name=name2)(x) 57 | return x 58 | 59 | return wrapper 60 | 61 | 62 | def FPNBlock(pyramid_filters, stage): 63 | conv0_name = 'fpn_stage_p{}_pre_conv'.format(stage) 64 | conv1_name = 'fpn_stage_p{}_conv'.format(stage) 65 | add_name = 'fpn_stage_p{}_add'.format(stage) 66 | up_name = 'fpn_stage_p{}_upsampling'.format(stage) 67 | 68 | channels_axis = 3 if backend.image_data_format() == 'channels_last' else 1 69 | 70 | def wrapper(input_tensor, skip): 71 | # if the input tensor's channels are not equal to the pyramid channels, 72 | # we will not be able to sum the input tensor and the skip, 73 | # so we add an extra conv layer to transform it
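# (Descriptive note: the 1x1 convolutions below are FPN's "lateral
# connections" -- they project the coarser pyramid tensor, when its channel
# count differs, and the encoder skip to the same `pyramid_filters` channel
# count, so the element-wise Add after the 2x upsampling is shape-compatible.)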
74 | input_filters = backend.int_shape(input_tensor)[channels_axis] 75 | if input_filters != pyramid_filters: 76 | input_tensor = layers.Conv2D( 77 | filters=pyramid_filters, 78 | kernel_size=(1, 1), 79 | kernel_initializer='he_uniform', 80 | name=conv0_name, 81 | )(input_tensor) 82 | 83 | skip = layers.Conv2D( 84 | filters=pyramid_filters, 85 | kernel_size=(1, 1), 86 | kernel_initializer='he_uniform', 87 | name=conv1_name, 88 | )(skip) 89 | 90 | x = layers.UpSampling2D((2, 2), name=up_name)(input_tensor) 91 | x = layers.Add(name=add_name)([x, skip]) 92 | 93 | return x 94 | 95 | return wrapper 96 | 97 | 98 | # --------------------------------------------------------------------- 99 | # FPN Decoder 100 | # --------------------------------------------------------------------- 101 | 102 | def build_fpn( 103 | backbone, 104 | skip_connection_layers, 105 | pyramid_filters=256, 106 | segmentation_filters=128, 107 | classes=1, 108 | activation='sigmoid', 109 | use_batchnorm=True, 110 | aggregation='sum', 111 | dropout=None, 112 | ): 113 | input_ = backbone.input 114 | x = backbone.output 115 | 116 | # building decoder blocks with skip connections 117 | skips = ([backbone.get_layer(name=i).output if isinstance(i, str) 118 | else backbone.get_layer(index=i).output for i in skip_connection_layers]) 119 | 120 | # build FPN pyramid 121 | p5 = FPNBlock(pyramid_filters, stage=5)(x, skips[0]) 122 | p4 = FPNBlock(pyramid_filters, stage=4)(p5, skips[1]) 123 | p3 = FPNBlock(pyramid_filters, stage=3)(p4, skips[2]) 124 | p2 = FPNBlock(pyramid_filters, stage=2)(p3, skips[3]) 125 | 126 | # add segmentation head to each 127 | s5 = DoubleConv3x3BnReLU(segmentation_filters, use_batchnorm, name='segm_stage5')(p5) 128 | s4 = DoubleConv3x3BnReLU(segmentation_filters, use_batchnorm, name='segm_stage4')(p4) 129 | s3 = DoubleConv3x3BnReLU(segmentation_filters, use_batchnorm, name='segm_stage3')(p3) 130 | s2 = DoubleConv3x3BnReLU(segmentation_filters, use_batchnorm, name='segm_stage2')(p2) 131 | 132 | # upsampling to same resolution 133 | s5 = layers.UpSampling2D((8, 8), interpolation='nearest', name='upsampling_stage5')(s5) 134 | s4 = layers.UpSampling2D((4, 4), interpolation='nearest', name='upsampling_stage4')(s4) 135 | s3 = layers.UpSampling2D((2, 2), interpolation='nearest', name='upsampling_stage3')(s3) 136 | 137 | # aggregating results 138 | if aggregation == 'sum': 139 | x = layers.Add(name='aggregation_sum')([s2, s3, s4, s5]) 140 | elif aggregation == 'concat': 141 | concat_axis = 3 if backend.image_data_format() == 'channels_last' else 1 142 | x = layers.Concatenate(axis=concat_axis, name='aggregation_concat')([s2, s3, s4, s5]) 143 | else: 144 | raise ValueError('Aggregation parameter should be in ("sum", "concat"), ' 145 | 'got {}'.format(aggregation)) 146 | 147 | if dropout: 148 | x = layers.SpatialDropout2D(dropout, name='pyramid_dropout')(x) 149 | 150 | # final stage 151 | x = Conv3x3BnReLU(segmentation_filters, use_batchnorm, name='final_stage')(x) 152 | x = layers.UpSampling2D(size=(2, 2), interpolation='bilinear', name='final_upsampling')(x) 153 | 154 | # model head (define number of output classes) 155 | x = layers.Conv2D( 156 | filters=classes, 157 | kernel_size=(3, 3), 158 | padding='same', 159 | use_bias=True, 160 | kernel_initializer='glorot_uniform', 161 | name='head_conv', 162 | )(x) 163 | x = layers.Activation(activation, name=activation)(x) 164 | 165 | # create keras model instance 166 | model = models.Model(input_, x) 167 | 168 | return model 169 | 170 |
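# (Scale bookkeeping for the decoder above -- a sketch, assuming a standard
# 32x-downsampling backbone whose 4 default skip layers sit at strides
# 16/8/4/2: p5..p2 then come out at 1/16, 1/8, 1/4 and 1/2 of the input
# resolution, the (8, 4, 2)x nearest upsamplings bring s5..s3 to s2's 1/2
# scale for aggregation, and the final bilinear 2x upsampling restores the
# full input resolution.)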
171 | # --------------------------------------------------------------------- 172 | # FPN Model 173 | # --------------------------------------------------------------------- 174 | 175 | def FPN( 176 | backbone_name='vgg16', 177 | input_shape=(None, None, 3), 178 | classes=21, 179 | activation='softmax', 180 | weights=None, 181 | encoder_weights='imagenet', 182 | encoder_freeze=False, 183 | encoder_features='default', 184 | pyramid_block_filters=256, 185 | pyramid_use_batchnorm=True, 186 | pyramid_aggregation='concat', 187 | pyramid_dropout=None, 188 | **kwargs 189 | ): 190 | """FPN_ is a fully convolutional neural network for image semantic segmentation 191 | 192 | Args: 193 | backbone_name: name of classification model (without last dense layers) used as feature 194 | extractor to build segmentation model. 195 | input_shape: shape of input data/image ``(H, W, C)``, in general 196 | case you do not need to set ``H`` and ``W`` shapes, just pass ``(None, None, C)`` to make your model 197 | able to process images of any size, but ``H`` and ``W`` of input images should be divisible by factor ``32``. 198 | classes: a number of classes for output (output shape - ``(h, w, classes)``). 199 | weights: optional, path to model weights. 200 | activation: name of one of ``keras.activations`` for last model layer (e.g. ``sigmoid``, ``softmax``, ``linear``). 201 | encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet). 202 | encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable. 203 | encoder_features: a list of layer numbers or names starting from top of the model. 204 | Each of these layers will be used to build features pyramid. If ``default`` is used 205 | layer names are taken from ``DEFAULT_FEATURE_PYRAMID_LAYERS``. 206 | pyramid_block_filters: a number of filters in Feature Pyramid Block of FPN_. 207 | pyramid_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers 208 | is used. 209 | pyramid_aggregation: one of 'sum' or 'concat'. The way to aggregate pyramid blocks. 210 | pyramid_dropout: spatial dropout rate for feature pyramid in range (0, 1). 211 | 212 | Returns: 213 | ``keras.models.Model``: **FPN** 214 | 215 | .. 
_FPN: 216 | http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf 217 | 218 | """ 219 | global backend, layers, models, keras_utils 220 | submodule_args = filter_keras_submodules(kwargs) 221 | backend, layers, models, keras_utils = get_submodules_from_kwargs(submodule_args) 222 | 223 | backbone = Backbones.get_backbone( 224 | backbone_name, 225 | input_shape=input_shape, 226 | weights=encoder_weights, 227 | include_top=False, 228 | **kwargs, 229 | ) 230 | 231 | if encoder_features == 'default': 232 | encoder_features = Backbones.get_feature_layers(backbone_name, n=4) 233 | 234 | model = build_fpn( 235 | backbone=backbone, 236 | skip_connection_layers=encoder_features, 237 | pyramid_filters=pyramid_block_filters, 238 | segmentation_filters=pyramid_block_filters // 2, 239 | use_batchnorm=pyramid_use_batchnorm, 240 | dropout=pyramid_dropout, 241 | activation=activation, 242 | classes=classes, 243 | aggregation=pyramid_aggregation, 244 | ) 245 | 246 | # lock encoder weights for fine-tuning 247 | if encoder_freeze: 248 | freeze_model(backbone, **kwargs) 249 | 250 | # loading model weights 251 | if weights is not None: 252 | model.load_weights(weights) 253 | 254 | return model 255 | -------------------------------------------------------------------------------- /segmentation_models/models/linknet.py: -------------------------------------------------------------------------------- 1 | from keras_applications import get_submodules_from_kwargs 2 | 3 | from ._common_blocks import Conv2dBn 4 | from ._utils import freeze_model, filter_keras_submodules 5 | from ..backbones.backbones_factory import Backbones 6 | 7 | backend = None 8 | layers = None 9 | models = None 10 | keras_utils = None 11 | 12 | 13 | # --------------------------------------------------------------------- 14 | # Utility functions 15 | # --------------------------------------------------------------------- 16 | 17 | def get_submodules(): 18 | return { 19 | 'backend': backend, 20 | 'models': models, 21 | 'layers': layers, 22 | 'utils': keras_utils, 23 | } 24 | 25 | 26 | # --------------------------------------------------------------------- 27 | # Blocks 28 | # --------------------------------------------------------------------- 29 | 30 | def Conv3x3BnReLU(filters, use_batchnorm, name=None): 31 | kwargs = get_submodules() 32 | 33 | def wrapper(input_tensor): 34 | return Conv2dBn( 35 | filters, 36 | kernel_size=3, 37 | activation='relu', 38 | kernel_initializer='he_uniform', 39 | padding='same', 40 | use_batchnorm=use_batchnorm, 41 | name=name, 42 | **kwargs 43 | )(input_tensor) 44 | 45 | return wrapper 46 | 47 | 48 | def Conv1x1BnReLU(filters, use_batchnorm, name=None): 49 | kwargs = get_submodules() 50 | 51 | def wrapper(input_tensor): 52 | return Conv2dBn( 53 | filters, 54 | kernel_size=1, 55 | activation='relu', 56 | kernel_initializer='he_uniform', 57 | padding='same', 58 | use_batchnorm=use_batchnorm, 59 | name=name, 60 | **kwargs 61 | )(input_tensor) 62 | 63 | return wrapper 64 | 65 | 66 | def DecoderUpsamplingX2Block(filters, stage, use_batchnorm): 67 | conv_block1_name = 'decoder_stage{}a'.format(stage) 68 | conv_block2_name = 'decoder_stage{}b'.format(stage) 69 | conv_block3_name = 'decoder_stage{}c'.format(stage) 70 | up_name = 'decoder_stage{}_upsampling'.format(stage) 71 | add_name = 'decoder_stage{}_add'.format(stage) 72 | 73 | channels_axis = 3 if backend.image_data_format() == 'channels_last' else 1 74 | 75 | def wrapper(input_tensor, skip=None): 76 | input_filters = backend.int_shape(input_tensor)[channels_axis] 
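# (Descriptive note: this block follows the LinkNet decoder recipe -- a 1x1
# conv cuts the channels to a quarter, the features are upsampled 2x and
# refined with a 3x3 conv, and a final 1x1 conv restores the skip's channel
# count so the residual Add below is shape-compatible.)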
77 | output_filters = backend.int_shape(skip)[channels_axis] if skip is not None else filters 78 | 79 | x = Conv1x1BnReLU(input_filters // 4, use_batchnorm, name=conv_block1_name)(input_tensor) 80 | x = layers.UpSampling2D((2, 2), name=up_name)(x) 81 | x = Conv3x3BnReLU(input_filters // 4, use_batchnorm, name=conv_block2_name)(x) 82 | x = Conv1x1BnReLU(output_filters, use_batchnorm, name=conv_block3_name)(x) 83 | 84 | if skip is not None: 85 | x = layers.Add(name=add_name)([x, skip]) 86 | return x 87 | 88 | return wrapper 89 | 90 | 91 | def DecoderTransposeX2Block(filters, stage, use_batchnorm): 92 | conv_block1_name = 'decoder_stage{}a'.format(stage) 93 | transpose_name = 'decoder_stage{}b_transpose'.format(stage) 94 | bn_name = 'decoder_stage{}b_bn'.format(stage) 95 | relu_name = 'decoder_stage{}b_relu'.format(stage) 96 | conv_block3_name = 'decoder_stage{}c'.format(stage) 97 | add_name = 'decoder_stage{}_add'.format(stage) 98 | 99 | channels_axis = bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1 100 | 101 | def wrapper(input_tensor, skip=None): 102 | input_filters = backend.int_shape(input_tensor)[channels_axis] 103 | output_filters = backend.int_shape(skip)[channels_axis] if skip is not None else filters 104 | 105 | x = Conv1x1BnReLU(input_filters // 4, use_batchnorm, name=conv_block1_name)(input_tensor) 106 | x = layers.Conv2DTranspose( 107 | filters=input_filters // 4, 108 | kernel_size=(4, 4), 109 | strides=(2, 2), 110 | padding='same', 111 | name=transpose_name, 112 | use_bias=not use_batchnorm, 113 | )(x) 114 | 115 | if use_batchnorm: 116 | x = layers.BatchNormalization(axis=bn_axis, name=bn_name)(x) 117 | 118 | x = layers.Activation('relu', name=relu_name)(x) 119 | x = Conv1x1BnReLU(output_filters, use_batchnorm, name=conv_block3_name)(x) 120 | 121 | if skip is not None: 122 | x = layers.Add(name=add_name)([x, skip]) 123 | 124 | return x 125 | 126 | return wrapper 127 | 128 | 129 | # --------------------------------------------------------------------- 130 | # LinkNet Decoder 131 | # --------------------------------------------------------------------- 132 | 133 | def build_linknet( 134 | backbone, 135 | decoder_block, 136 | skip_connection_layers, 137 | decoder_filters=(256, 128, 64, 32, 16), 138 | n_upsample_blocks=5, 139 | classes=1, 140 | activation='sigmoid', 141 | use_batchnorm=True, 142 | ): 143 | input_ = backbone.input 144 | x = backbone.output 145 | 146 | # extract skip connections 147 | skips = ([backbone.get_layer(name=i).output if isinstance(i, str) 148 | else backbone.get_layer(index=i).output for i in skip_connection_layers]) 149 | 150 | # add center block if previous operation was maxpooling (for vgg models) 151 | if isinstance(backbone.layers[-1], layers.MaxPooling2D): 152 | x = Conv3x3BnReLU(512, use_batchnorm, name='center_block1')(x) 153 | x = Conv3x3BnReLU(512, use_batchnorm, name='center_block2')(x) 154 | 155 | # building decoder blocks 156 | for i in range(n_upsample_blocks): 157 | 158 | if i < len(skips): 159 | skip = skips[i] 160 | else: 161 | skip = None 162 | 163 | x = decoder_block(decoder_filters[i], stage=i, use_batchnorm=use_batchnorm)(x, skip) 164 | 165 | # model head (define number of output classes) 166 | x = layers.Conv2D( 167 | filters=classes, 168 | kernel_size=(3, 3), 169 | padding='same', 170 | use_bias=True, 171 | kernel_initializer='glorot_uniform' 172 | )(x) 173 | x = layers.Activation(activation, name=activation)(x) 174 | 175 | # create keras model instance 176 | model = models.Model(input_, x) 177 | 178 | return 
model 179 | 180 | 181 | # --------------------------------------------------------------------- 182 | # LinkNet Model 183 | # --------------------------------------------------------------------- 184 | 185 | def Linknet( 186 | backbone_name='vgg16', 187 | input_shape=(None, None, 3), 188 | classes=1, 189 | activation='sigmoid', 190 | weights=None, 191 | encoder_weights='imagenet', 192 | encoder_freeze=False, 193 | encoder_features='default', 194 | decoder_block_type='upsampling', 195 | decoder_filters=(None, None, None, None, 16), 196 | decoder_use_batchnorm=True, 197 | **kwargs 198 | ): 199 | """Linknet_ is a fully convolutional neural network for fast image semantic segmentation 200 | 201 | Note: 202 | This implementation by default has 4 skip connections (the original has 3). 203 | 204 | Args: 205 | backbone_name: name of classification model (without last dense layers) used as feature 206 | extractor to build segmentation model. 207 | input_shape: shape of input data/image ``(H, W, C)``, in general 208 | case you do not need to set ``H`` and ``W`` shapes, just pass ``(None, None, C)`` to make your model 209 | able to process images of any size, but ``H`` and ``W`` of input images should be divisible by factor ``32``. 210 | classes: a number of classes for output (output shape - ``(h, w, classes)``). 211 | activation: name of one of ``keras.activations`` for last model layer 212 | (e.g. ``sigmoid``, ``softmax``, ``linear``). 213 | weights: optional, path to model weights. 214 | encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet). 215 | encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable. 216 | encoder_features: a list of layer numbers or names starting from top of the model. 217 | Each of these layers will be concatenated with corresponding decoder block. If ``default`` is used 218 | layer names are taken from ``DEFAULT_SKIP_CONNECTIONS``. 219 | decoder_filters: list of numbers of ``Conv2D`` layer filters in decoder blocks, 220 | for block with skip connection a number of filters is equal to number of filters in 221 | corresponding encoder block (estimated automatically and can be passed as ``None`` value). 222 | decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers 223 | is used. 224 | decoder_block_type: one of 225 | - `upsampling`: use ``UpSampling2D`` keras layer 226 | - `transpose`: use ``Conv2DTranspose`` keras layer 227 | 228 | Returns: 229 | ``keras.models.Model``: **Linknet** 230 | 231 | .. _Linknet: 232 | https://arxiv.org/pdf/1707.03718.pdf 233 | """ 234 | 235 | global backend, layers, models, keras_utils 236 | submodule_args = filter_keras_submodules(kwargs) 237 | backend, layers, models, keras_utils = get_submodules_from_kwargs(submodule_args) 238 | 239 | if decoder_block_type == 'upsampling': 240 | decoder_block = DecoderUpsamplingX2Block 241 | elif decoder_block_type == 'transpose': 242 | decoder_block = DecoderTransposeX2Block 243 | else: 244 | raise ValueError('Decoder block type should be in ("upsampling", "transpose"). 
' 245 | 'Got: {}'.format(decoder_block_type)) 246 | 247 | backbone = Backbones.get_backbone( 248 | backbone_name, 249 | input_shape=input_shape, 250 | weights=encoder_weights, 251 | include_top=False, 252 | **kwargs, 253 | ) 254 | 255 | if encoder_features == 'default': 256 | encoder_features = Backbones.get_feature_layers(backbone_name, n=4) 257 | 258 | model = build_linknet( 259 | backbone=backbone, 260 | decoder_block=decoder_block, 261 | skip_connection_layers=encoder_features, 262 | decoder_filters=decoder_filters, 263 | classes=classes, 264 | activation=activation, 265 | n_upsample_blocks=len(decoder_filters), 266 | use_batchnorm=decoder_use_batchnorm, 267 | ) 268 | 269 | # lock encoder weights for fine-tuning 270 | if encoder_freeze: 271 | freeze_model(backbone, **kwargs) 272 | 273 | # loading model weights 274 | if weights is not None: 275 | model.load_weights(weights) 276 | 277 | return model 278 | -------------------------------------------------------------------------------- /segmentation_models/models/pspnet.py: -------------------------------------------------------------------------------- 1 | from keras_applications import get_submodules_from_kwargs 2 | 3 | from ._common_blocks import Conv2dBn 4 | from ._utils import freeze_model, filter_keras_submodules 5 | from ..backbones.backbones_factory import Backbones 6 | 7 | backend = None 8 | layers = None 9 | models = None 10 | keras_utils = None 11 | 12 | 13 | # --------------------------------------------------------------------- 14 | # Utility functions 15 | # --------------------------------------------------------------------- 16 | 17 | def get_submodules(): 18 | return { 19 | 'backend': backend, 20 | 'models': models, 21 | 'layers': layers, 22 | 'utils': keras_utils, 23 | } 24 | 25 | 26 | def check_input_shape(input_shape, factor): 27 | if input_shape is None: 28 | raise ValueError("Input shape should be a tuple of 3 integers, not None!") 29 | 30 | h, w = input_shape[:2] if backend.image_data_format() == 'channels_last' else input_shape[1:] 31 | min_size = factor * 6 32 | 33 | is_wrong_shape = ( 34 | h % min_size != 0 or w % min_size != 0 or 35 | h < min_size or w < min_size 36 | ) 37 | 38 | if is_wrong_shape: 39 | raise ValueError('Wrong shape {}, input H and W should '.format(input_shape) + 40 | 'be divisible by `{}`'.format(min_size)) 41 | 42 | 43 | # --------------------------------------------------------------------- 44 | # Blocks 45 | # --------------------------------------------------------------------- 46 | 47 | def Conv1x1BnReLU(filters, use_batchnorm, name=None): 48 | kwargs = get_submodules() 49 | 50 | def wrapper(input_tensor): 51 | return Conv2dBn( 52 | filters, 53 | kernel_size=1, 54 | activation='relu', 55 | kernel_initializer='he_uniform', 56 | padding='same', 57 | use_batchnorm=use_batchnorm, 58 | name=name, 59 | **kwargs 60 | )(input_tensor) 61 | 62 | return wrapper 63 | 64 | 65 | def SpatialContextBlock( 66 | level, 67 | conv_filters=512, 68 | pooling_type='avg', 69 | use_batchnorm=True, 70 | ): 71 | if pooling_type not in ('max', 'avg'): 72 | raise ValueError('Unsupported pooling type - `{}`. '.format(pooling_type) + 73 | 'Use `avg` or `max`.') 74 | 75 | Pooling2D = layers.MaxPool2D if pooling_type == 'max' else layers.AveragePooling2D 76 | 77 | pooling_name = 'psp_level{}_pooling'.format(level) 78 | conv_block_name = 'psp_level{}'.format(level) 79 | upsampling_name = 'psp_level{}_upsampling'.format(level) 80 | 81 | def wrapper(input_tensor): 82 | # extract input feature maps size (h and w dimensions)
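# (Worked example, assuming the default PSPNet input of 384x384 with
# downsample_factor=8: the feature map here is 48x48, so level=6 pools
# with 8x8 windows down to a 6x6 map, and the matching upsample factor
# restores it to 48x48 before concatenation.)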
83 | input_shape = backend.int_shape(input_tensor) 84 | spatial_size = input_shape[1:3] if backend.image_data_format() == 'channels_last' else input_shape[2:] 85 | 86 | # Compute the pooling kernel and stride sizes according to how large the final feature map should be. 87 | # When the pooling kernel and stride are equal, the output spatial size is the input size divided 88 | # by that kernel/stride factor, so windows of spatial_size // level yield a level x level map. 89 | # The final feature map sizes are 1x1, 2x2, 3x3, and 6x6. 90 | pool_size = up_size = [spatial_size[0] // level, spatial_size[1] // level] 91 | 92 | x = Pooling2D(pool_size, strides=pool_size, padding='same', name=pooling_name)(input_tensor) 93 | x = Conv1x1BnReLU(conv_filters, use_batchnorm, name=conv_block_name)(x) 94 | x = layers.UpSampling2D(up_size, interpolation='bilinear', name=upsampling_name)(x) 95 | return x 96 | 97 | return wrapper 98 | 99 | 100 | # --------------------------------------------------------------------- 101 | # PSP Decoder 102 | # --------------------------------------------------------------------- 103 | 104 | def build_psp( 105 | backbone, 106 | psp_layer_idx, 107 | pooling_type='avg', 108 | conv_filters=512, 109 | use_batchnorm=True, 110 | final_upsampling_factor=8, 111 | classes=21, 112 | activation='softmax', 113 | dropout=None, 114 | ): 115 | input_ = backbone.input 116 | x = (backbone.get_layer(name=psp_layer_idx).output if isinstance(psp_layer_idx, str) 117 | else backbone.get_layer(index=psp_layer_idx).output) 118 | 119 | # build spatial pyramid 120 | x1 = SpatialContextBlock(1, conv_filters, pooling_type, use_batchnorm)(x) 121 | x2 = SpatialContextBlock(2, conv_filters, pooling_type, use_batchnorm)(x) 122 | x3 = SpatialContextBlock(3, conv_filters, pooling_type, use_batchnorm)(x) 123 | x6 = SpatialContextBlock(6, conv_filters, pooling_type, use_batchnorm)(x) 124 | 125 | # aggregate spatial pyramid 126 | concat_axis = 3 if backend.image_data_format() == 'channels_last' else 1 127 | x = layers.Concatenate(axis=concat_axis, name='psp_concat')([x, x1, x2, x3, x6]) 128 | x = Conv1x1BnReLU(conv_filters, use_batchnorm, name='aggregation')(x) 129 | 130 | # model regularization 131 | if dropout is not None: 132 | x = layers.SpatialDropout2D(dropout, name='spatial_dropout')(x) 133 | 134 | # model head 135 | x = layers.Conv2D( 136 | filters=classes, 137 | kernel_size=(3, 3), 138 | padding='same', 139 | kernel_initializer='glorot_uniform', 140 | name='final_conv', 141 | )(x) 142 | 143 | x = layers.UpSampling2D(final_upsampling_factor, name='final_upsampling', interpolation='bilinear')(x) 144 | x = layers.Activation(activation, name=activation)(x) 145 | 146 | model = models.Model(input_, x) 147 | 148 | return model 149 | 150 | 151 | # --------------------------------------------------------------------- 152 | # PSP Model 153 | # --------------------------------------------------------------------- 154 | 155 | def PSPNet( 156 | backbone_name='vgg16', 157 | input_shape=(384, 384, 3), 158 | classes=21, 159 | activation='softmax', 160 | weights=None, 161 | encoder_weights='imagenet', 162 | encoder_freeze=False, 163 | downsample_factor=8, 164 | psp_conv_filters=512, 165 | psp_pooling_type='avg', 166 | psp_use_batchnorm=True, 167 | psp_dropout=None, 168 | **kwargs 169 | ): 170 | """PSPNet_ is a fully convolutional neural network for image semantic segmentation 171 | 172 | Args: 173 | backbone_name: name of classification model used as feature 174 | extractor to build segmentation model.
175 | input_shape: shape of input data/image ``(H, W, C)``. 176 | ``H`` and ``W`` should be divisible by ``6 * downsample_factor`` and **NOT** ``None``! 177 | classes: a number of classes for output (output shape - ``(h, w, classes)``). 178 | activation: name of one of ``keras.activations`` for last model layer 179 | (e.g. ``sigmoid``, ``softmax``, ``linear``). 180 | weights: optional, path to model weights. 181 | encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet). 182 | encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable. 183 | downsample_factor: one of 4, 8 and 16. Downsampling rate, i.e. the backbone depth 184 | at which the PSP module is constructed. 185 | psp_conv_filters: number of filters in ``Conv2D`` layer in each PSP block. 186 | psp_pooling_type: one of 'avg', 'max'. PSP block pooling type (maximum or average). 187 | psp_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers 188 | is used. 189 | psp_dropout: dropout rate between 0 and 1. 190 | 191 | Returns: 192 | ``keras.models.Model``: **PSPNet** 193 | 194 | .. _PSPNet: 195 | https://arxiv.org/pdf/1612.01105.pdf 196 | 197 | """ 198 | 199 | global backend, layers, models, keras_utils 200 | submodule_args = filter_keras_submodules(kwargs) 201 | backend, layers, models, keras_utils = get_submodules_from_kwargs(submodule_args) 202 | 203 | # control image input shape 204 | check_input_shape(input_shape, downsample_factor) 205 | 206 | backbone = Backbones.get_backbone( 207 | backbone_name, 208 | input_shape=input_shape, 209 | weights=encoder_weights, 210 | include_top=False, 211 | **kwargs 212 | ) 213 | 214 | feature_layers = Backbones.get_feature_layers(backbone_name, n=3) 215 | 216 | if downsample_factor == 16: 217 | psp_layer_idx = feature_layers[0] 218 | elif downsample_factor == 8: 219 | psp_layer_idx = feature_layers[1] 220 | elif downsample_factor == 4: 221 | psp_layer_idx = feature_layers[2] 222 | else: 223 | raise ValueError('Unsupported factor - `{}`. Use 4, 8 or 16.'.format(downsample_factor)) 224 | 225 | model = build_psp( 226 | backbone, 227 | psp_layer_idx, 228 | pooling_type=psp_pooling_type, 229 | conv_filters=psp_conv_filters, 230 | use_batchnorm=psp_use_batchnorm, 231 | final_upsampling_factor=downsample_factor, 232 | classes=classes, 233 | activation=activation, 234 | dropout=psp_dropout, 235 | ) 236 | 237 | # lock encoder weights for fine-tuning 238 | if encoder_freeze: 239 | freeze_model(backbone, **kwargs) 240 | 241 | # loading model weights 242 | if weights is not None: 243 | model.load_weights(weights) 244 | 245 | return model 246 | -------------------------------------------------------------------------------- /segmentation_models/models/unet.py: -------------------------------------------------------------------------------- 1 | from keras_applications import get_submodules_from_kwargs 2 | 3 | from ._common_blocks import Conv2dBn 4 | from ._utils import freeze_model, filter_keras_submodules 5 | from ..backbones.backbones_factory import Backbones 6 | 7 | backend = None 8 | layers = None 9 | models = None 10 | keras_utils = None 11 | 12 | 13 | # --------------------------------------------------------------------- 14 | # Utility functions 15 | # --------------------------------------------------------------------- 16 | 17 | def get_submodules(): 18 | return { 19 | 'backend': backend, 20 | 'models': models, 21 | 'layers': layers, 22 | 'utils': keras_utils, 23 | } 24 | 25 | 26 | # 
--------------------------------------------------------------------- 27 | # Blocks 28 | # --------------------------------------------------------------------- 29 | 30 | def Conv3x3BnReLU(filters, use_batchnorm, name=None): 31 | kwargs = get_submodules() 32 | 33 | def wrapper(input_tensor): 34 | return Conv2dBn( 35 | filters, 36 | kernel_size=3, 37 | activation='relu', 38 | kernel_initializer='he_uniform', 39 | padding='same', 40 | use_batchnorm=use_batchnorm, 41 | name=name, 42 | **kwargs 43 | )(input_tensor) 44 | 45 | return wrapper 46 | 47 | 48 | def DecoderUpsamplingX2Block(filters, stage, use_batchnorm=False): 49 | up_name = 'decoder_stage{}_upsampling'.format(stage) 50 | conv1_name = 'decoder_stage{}a'.format(stage) 51 | conv2_name = 'decoder_stage{}b'.format(stage) 52 | concat_name = 'decoder_stage{}_concat'.format(stage) 53 | 54 | concat_axis = 3 if backend.image_data_format() == 'channels_last' else 1 55 | 56 | def wrapper(input_tensor, skip=None): 57 | x = layers.UpSampling2D(size=2, name=up_name)(input_tensor) 58 | 59 | if skip is not None: 60 | x = layers.Concatenate(axis=concat_axis, name=concat_name)([x, skip]) 61 | 62 | x = Conv3x3BnReLU(filters, use_batchnorm, name=conv1_name)(x) 63 | x = Conv3x3BnReLU(filters, use_batchnorm, name=conv2_name)(x) 64 | 65 | return x 66 | 67 | return wrapper 68 | 69 | 70 | def DecoderTransposeX2Block(filters, stage, use_batchnorm=False): 71 | transp_name = 'decoder_stage{}a_transpose'.format(stage) 72 | bn_name = 'decoder_stage{}a_bn'.format(stage) 73 | relu_name = 'decoder_stage{}a_relu'.format(stage) 74 | conv_block_name = 'decoder_stage{}b'.format(stage) 75 | concat_name = 'decoder_stage{}_concat'.format(stage) 76 | 77 | concat_axis = bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1 78 | 79 | def layer(input_tensor, skip=None): 80 | 81 | x = layers.Conv2DTranspose( 82 | filters, 83 | kernel_size=(4, 4), 84 | strides=(2, 2), 85 | padding='same', 86 | name=transp_name, 87 | use_bias=not use_batchnorm, 88 | )(input_tensor) 89 | 90 | if use_batchnorm: 91 | x = layers.BatchNormalization(axis=bn_axis, name=bn_name)(x) 92 | 93 | x = layers.Activation('relu', name=relu_name)(x) 94 | 95 | if skip is not None: 96 | x = layers.Concatenate(axis=concat_axis, name=concat_name)([x, skip]) 97 | 98 | x = Conv3x3BnReLU(filters, use_batchnorm, name=conv_block_name)(x) 99 | 100 | return x 101 | 102 | return layer 103 | 104 | 105 | # --------------------------------------------------------------------- 106 | # Unet Decoder 107 | # --------------------------------------------------------------------- 108 | 109 | def build_unet( 110 | backbone, 111 | decoder_block, 112 | skip_connection_layers, 113 | decoder_filters=(256, 128, 64, 32, 16), 114 | n_upsample_blocks=5, 115 | classes=1, 116 | activation='sigmoid', 117 | use_batchnorm=True, 118 | ): 119 | input_ = backbone.input 120 | x = backbone.output 121 | 122 | # extract skip connections 123 | skips = ([backbone.get_layer(name=i).output if isinstance(i, str) 124 | else backbone.get_layer(index=i).output for i in skip_connection_layers]) 125 | 126 | # add center block if previous operation was maxpooling (for vgg models) 127 | if isinstance(backbone.layers[-1], layers.MaxPooling2D): 128 | x = Conv3x3BnReLU(512, use_batchnorm, name='center_block1')(x) 129 | x = Conv3x3BnReLU(512, use_batchnorm, name='center_block2')(x) 130 | 131 | # building decoder blocks 132 | for i in range(n_upsample_blocks): 133 | 134 | if i < len(skips): 135 | skip = skips[i] 136 | else: 137 | skip = None 138 | 139 | x = 
decoder_block(decoder_filters[i], stage=i, use_batchnorm=use_batchnorm)(x, skip) 140 | 141 | # model head (define number of output classes) 142 | x = layers.Conv2D( 143 | filters=classes, 144 | kernel_size=(3, 3), 145 | padding='same', 146 | use_bias=True, 147 | kernel_initializer='glorot_uniform', 148 | name='final_conv', 149 | )(x) 150 | x = layers.Activation(activation, name=activation)(x) 151 | 152 | # create keras model instance 153 | model = models.Model(input_, x) 154 | 155 | return model 156 | 157 | 158 | # --------------------------------------------------------------------- 159 | # Unet Model 160 | # --------------------------------------------------------------------- 161 | 162 | def Unet( 163 | backbone_name='vgg16', 164 | input_shape=(None, None, 3), 165 | classes=1, 166 | activation='sigmoid', 167 | weights=None, 168 | encoder_weights='imagenet', 169 | encoder_freeze=False, 170 | encoder_features='default', 171 | decoder_block_type='upsampling', 172 | decoder_filters=(256, 128, 64, 32, 16), 173 | decoder_use_batchnorm=True, 174 | **kwargs 175 | ): 176 | """ Unet_ is a fully convolutional neural network for image semantic segmentation 177 | 178 | Args: 179 | backbone_name: name of classification model (without last dense layers) used as feature 180 | extractor to build segmentation model. 181 | input_shape: shape of input data/image ``(H, W, C)``, in general 182 | case you do not need to set ``H`` and ``W`` shapes, just pass ``(None, None, C)`` to make your model 183 | able to process images of any size, but ``H`` and ``W`` of input images should be divisible by factor ``32``. 184 | classes: a number of classes for output (output shape - ``(h, w, classes)``). 185 | activation: name of one of ``keras.activations`` for last model layer 186 | (e.g. ``sigmoid``, ``softmax``, ``linear``). 187 | weights: optional, path to model weights. 188 | encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet). 189 | encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable. 190 | encoder_features: a list of layer numbers or names starting from top of the model. 191 | Each of these layers will be concatenated with corresponding decoder block. If ``default`` is used 192 | layer names are taken from ``DEFAULT_SKIP_CONNECTIONS``. 193 | decoder_block_type: one of blocks with following layers structure: 194 | 195 | - `upsampling`: ``UpSampling2D`` -> ``Conv2D`` -> ``Conv2D`` 196 | - `transpose`: ``Conv2DTranspose`` -> ``Conv2D`` 197 | 198 | decoder_filters: list of numbers of ``Conv2D`` layer filters in decoder blocks 199 | decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers 200 | is used. 201 | 202 | Returns: 203 | ``keras.models.Model``: **Unet** 204 | 205 | .. _Unet: 206 | https://arxiv.org/pdf/1505.04597 207 | 208 | """ 209 | 210 | global backend, layers, models, keras_utils 211 | submodule_args = filter_keras_submodules(kwargs) 212 | backend, layers, models, keras_utils = get_submodules_from_kwargs(submodule_args) 213 | 214 | if decoder_block_type == 'upsampling': 215 | decoder_block = DecoderUpsamplingX2Block 216 | elif decoder_block_type == 'transpose': 217 | decoder_block = DecoderTransposeX2Block 218 | else: 219 | raise ValueError('Decoder block type should be in ("upsampling", "transpose"). 
' 220 | 'Got: {}'.format(decoder_block_type)) 221 | 222 | backbone = Backbones.get_backbone( 223 | backbone_name, 224 | input_shape=input_shape, 225 | weights=encoder_weights, 226 | include_top=False, 227 | **kwargs, 228 | ) 229 | 230 | if encoder_features == 'default': 231 | encoder_features = Backbones.get_feature_layers(backbone_name, n=4) 232 | 233 | model = build_unet( 234 | backbone=backbone, 235 | decoder_block=decoder_block, 236 | skip_connection_layers=encoder_features, 237 | decoder_filters=decoder_filters, 238 | classes=classes, 239 | activation=activation, 240 | n_upsample_blocks=len(decoder_filters), 241 | use_batchnorm=decoder_use_batchnorm, 242 | ) 243 | 244 | # lock encoder weights for fine-tuning 245 | if encoder_freeze: 246 | freeze_model(backbone, **kwargs) 247 | 248 | # loading model weights 249 | if weights is not None: 250 | model.load_weights(weights) 251 | 252 | return model 253 | -------------------------------------------------------------------------------- /segmentation_models/utils.py: -------------------------------------------------------------------------------- 1 | """ Utility functions for segmentation models """ 2 | 3 | from keras_applications import get_submodules_from_kwargs 4 | from . import inject_global_submodules 5 | 6 | 7 | def set_trainable(model, recompile=True, **kwargs): 8 | """Set all layers of model trainable and recompile it 9 | 10 | Note: 11 | Model is recompiled using same optimizer, loss and metrics:: 12 | 13 | model.compile( 14 | model.optimizer, 15 | loss=model.loss, 16 | metrics=model.metrics, 17 | loss_weights=model.loss_weights, 18 | sample_weight_mode=model.sample_weight_mode, 19 | weighted_metrics=model.weighted_metrics, 20 | ) 21 | 22 | Args: 23 | model (``keras.models.Model``): instance of keras model 24 | 25 | """ 26 | for layer in model.layers: 27 | layer.trainable = True 28 | 29 | if recompile: 30 | model.compile( 31 | model.optimizer, 32 | loss=model.loss, 33 | metrics=model.metrics, 34 | loss_weights=model.loss_weights, 35 | sample_weight_mode=model.sample_weight_mode, 36 | weighted_metrics=model.weighted_metrics, 37 | ) 38 | 39 | 40 | @inject_global_submodules 41 | def set_regularization( 42 | model, 43 | kernel_regularizer=None, 44 | bias_regularizer=None, 45 | activity_regularizer=None, 46 | beta_regularizer=None, 47 | gamma_regularizer=None, 48 | **kwargs 49 | ): 50 | """Set regularizers to all layers 51 | 52 | Note: 53 | Returned model's config is updated correctly 54 | 55 | Args: 56 | model (``keras.models.Model``): instance of keras model 57 | kernel_regularizer(``regularizer``): regularizer of kernels 58 | bias_regularizer(``regularizer``): regularizer of bias 59 | activity_regularizer(``regularizer``): regularizer of activity 60 | gamma_regularizer(``regularizer``): regularizer of gamma of BatchNormalization 61 | beta_regularizer(``regularizer``): regularizer of beta of BatchNormalization 62 | 63 | Return: 64 | out (``Model``): config-updated model 65 | """ 66 | _, _, models, _ = get_submodules_from_kwargs(kwargs) 67 | 68 | for layer in model.layers: 69 | # set kernel_regularizer 70 | if kernel_regularizer is not None and hasattr(layer, 'kernel_regularizer'): 71 | layer.kernel_regularizer = kernel_regularizer 72 | # set bias_regularizer 73 | if bias_regularizer is not None and hasattr(layer, 'bias_regularizer'): 74 | layer.bias_regularizer = bias_regularizer 75 | # set activity_regularizer 76 | if activity_regularizer is not None and hasattr(layer, 'activity_regularizer'): 77 | layer.activity_regularizer = 
activity_regularizer 78 | 79 | # set beta and gamma of BN layer 80 | if beta_regularizer is not None and hasattr(layer, 'beta_regularizer'): 81 | layer.beta_regularizer = beta_regularizer 82 | 83 | if gamma_regularizer is not None and hasattr(layer, 'gamma_regularizer'): 84 | layer.gamma_regularizer = gamma_regularizer 85 | 86 | out = models.model_from_json(model.to_json()) 87 | out.set_weights(model.get_weights()) 88 | 89 | return out 90 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Note: To use the 'upload' functionality of this file, you must: 5 | # $ pip install twine 6 | 7 | import io 8 | import os 9 | import sys 10 | from shutil import rmtree 11 | 12 | from setuptools import find_packages, setup, Command 13 | 14 | # Package meta-data. 15 | NAME = 'segmentation_models' 16 | DESCRIPTION = 'Image segmentation models with pre-trained backbones with Keras.' 17 | URL = 'https://github.com/qubvel/segmentation_models' 18 | EMAIL = 'qubvel@gmail.com' 19 | AUTHOR = 'Pavel Yakubovskiy' 20 | REQUIRES_PYTHON = '>=3.0.0' 21 | VERSION = None 22 | 23 | # The rest you shouldn't have to touch too much :) 24 | # ------------------------------------------------ 25 | # Except, perhaps the License and Trove Classifiers! 26 | # If you do change the License, remember to change the Trove Classifier for that! 27 | 28 | here = os.path.abspath(os.path.dirname(__file__)) 29 | 30 | # What packages are required for this module to be executed? 31 | try: 32 | with open(os.path.join(here, 'requirements.txt'), encoding='utf-8') as f: 33 | REQUIRED = f.read().split('\n') 34 | except (IOError, OSError): 35 | REQUIRED = [] 36 | 37 | # What packages are optional? 38 | EXTRAS = { 39 | 'tests': ['pytest', 'scikit-image'], 40 | } 41 | 42 | # Import the README and use it as the long-description. 43 | # Note: this will only work if 'README.rst' is present in your MANIFEST.in file! 44 | try: 45 | with io.open(os.path.join(here, 'README.rst'), encoding='utf-8') as f: 46 | long_description = '\n' + f.read() 47 | except FileNotFoundError: 48 | long_description = DESCRIPTION 49 | 50 | # Load the package's __version__.py module as a dictionary. 51 | about = {} 52 | if not VERSION: 53 | with open(os.path.join(here, NAME, '__version__.py')) as f: 54 | exec(f.read(), about) 55 | else: 56 | about['__version__'] = VERSION 57 | 58 | 59 | class UploadCommand(Command): 60 | """Support setup.py upload.""" 61 | 62 | description = 'Build and publish the package.'
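# (distutils requires every Command subclass to define `user_options`,
# even when, as below, the command takes none)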
63 | user_options = [] 64 | 65 | @staticmethod 66 | def status(s): 67 | """Print a status message.""" 68 | print(s) 69 | 70 | def initialize_options(self): 71 | pass 72 | 73 | def finalize_options(self): 74 | pass 75 | 76 | def run(self): 77 | try: 78 | self.status('Removing previous builds...') 79 | rmtree(os.path.join(here, 'dist')) 80 | except OSError: 81 | pass 82 | 83 | self.status('Building Source and Wheel (universal) distribution...') 84 | os.system('{0} setup.py sdist bdist_wheel --universal'.format(sys.executable)) 85 | 86 | self.status('Uploading the package to PyPI via Twine...') 87 | os.system('twine upload dist/*') 88 | 89 | self.status('Pushing git tags...') 90 | os.system('git tag v{0}'.format(about['__version__'])) 91 | os.system('git push --tags') 92 | 93 | sys.exit() 94 | 95 | 96 | # Where the magic happens: 97 | setup( 98 | name=NAME, 99 | version=about['__version__'], 100 | description=DESCRIPTION, 101 | long_description=long_description, 102 | long_description_content_type='text/x-rst', 103 | author=AUTHOR, 104 | author_email=EMAIL, 105 | python_requires=REQUIRES_PYTHON, 106 | url=URL, 107 | packages=find_packages(exclude=('tests', 'docs', 'images', 'examples')), 108 | # If your package is a single module, use this instead of 'packages': 109 | # py_modules=['mypackage'], 110 | 111 | # entry_points={ 112 | # 'console_scripts': ['mycli=mymodule:cli'], 113 | # }, 114 | install_requires=REQUIRED, 115 | extras_require=EXTRAS, 116 | include_package_data=True, 117 | license='MIT', 118 | classifiers=[ 119 | # Trove classifiers 120 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 121 | 'License :: OSI Approved :: MIT License', 122 | 'Programming Language :: Python', 123 | 'Programming Language :: Python :: 3', 124 | 'Programming Language :: Python :: Implementation :: CPython', 125 | 'Programming Language :: Python :: Implementation :: PyPy' 126 | ], 127 | # $ setup.py publish support. 
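# (registering UploadCommand under the 'upload' key below is what makes
# `python setup.py upload` run the build-and-twine flow defined above)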
128 | cmdclass={ 129 | 'upload': UploadCommand, 130 | }, 131 | ) -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | 4 | import segmentation_models as sm 5 | from segmentation_models.metrics import IOUScore, FScore 6 | from segmentation_models.losses import JaccardLoss, DiceLoss 7 | 8 | if sm.framework() == sm._TF_KERAS_FRAMEWORK_NAME: 9 | from tensorflow import keras 10 | elif sm.framework() == sm._KERAS_FRAMEWORK_NAME: 11 | import keras 12 | else: 13 | raise ValueError('Incorrect framework {}'.format(sm.framework())) 14 | 15 | METRICS = [ 16 | IOUScore, 17 | FScore, 18 | ] 19 | 20 | LOSSES = [ 21 | JaccardLoss, 22 | DiceLoss, 23 | ] 24 | 25 | GT0 = np.array( 26 | [ 27 | [0, 0, 0], 28 | [0, 0, 0], 29 | [0, 0, 0], 30 | ], 31 | dtype='float32', 32 | ) 33 | 34 | GT1 = np.array( 35 | [ 36 | [1, 1, 0], 37 | [1, 1, 0], 38 | [0, 0, 0], 39 | ], 40 | dtype='float32', 41 | ) 42 | 43 | PR1 = np.array( 44 | [ 45 | [0, 0, 0], 46 | [1, 1, 0], 47 | [0, 0, 0], 48 | ], 49 | dtype='float32', 50 | ) 51 | 52 | PR2 = np.array( 53 | [ 54 | [0, 0, 0], 55 | [1, 1, 0], 56 | [1, 1, 0], 57 | ], 58 | dtype='float32', 59 | ) 60 | 61 | PR3 = np.array( 62 | [ 63 | [0, 0, 0], 64 | [0, 0, 0], 65 | [1, 0, 0], 66 | ], 67 | dtype='float32', 68 | ) 69 | 70 | IOU_CASES = ( 71 | 72 | (GT0, GT0, 1.00), 73 | (GT1, GT1, 1.00), 74 | 75 | (GT0, PR1, 0.00), 76 | (GT0, PR2, 0.00), 77 | (GT0, PR3, 0.00), 78 | 79 | (GT1, PR1, 0.50), 80 | (GT1, PR2, 1. / 3.), 81 | (GT1, PR3, 0.00), 82 | ) 83 | 84 | F1_CASES = ( 85 | 86 | (GT0, GT0, 1.00), 87 | (GT1, GT1, 1.00), 88 | 89 | (GT0, PR1, 0.00), 90 | (GT0, PR2, 0.00), 91 | (GT0, PR3, 0.00), 92 | 93 | (GT1, PR1, 2. / 3.), 94 | (GT1, PR2, 0.50), 95 | (GT1, PR3, 0.00), 96 | ) 97 | 98 | F2_CASES = ( 99 | 100 | (GT0, GT0, 1.00), 101 | (GT1, GT1, 1.00), 102 | 103 | (GT0, PR1, 0.00), 104 | (GT0, PR2, 0.00), 105 | (GT0, PR3, 0.00), 106 | 107 | (GT1, PR1, 5. 
/ 9.), 108 | (GT1, PR2, 0.50), 109 | (GT1, PR3, 0.00), 110 | ) 111 | 112 | 113 | def _to_4d(x): 114 | if x.ndim == 2: 115 | return x[None, :, :, None] 116 | elif x.ndim == 3: 117 | return x[None, :, :] 118 | 119 | 120 | def _add_4d(x): 121 | if x.ndim == 3: 122 | return x[..., None] 123 | 124 | 125 | @pytest.mark.parametrize('case', IOU_CASES) 126 | def test_iou_metric(case): 127 | gt, pr, res = case 128 | gt = _to_4d(gt) 129 | pr = _to_4d(pr) 130 | iou_score = IOUScore(smooth=10e-12) 131 | score = keras.backend.eval(iou_score(gt, pr)) 132 | assert np.allclose(score, res) 133 | 134 | 135 | @pytest.mark.parametrize('case', IOU_CASES) 136 | def test_jaccard_loss(case): 137 | gt, pr, res = case 138 | gt = _to_4d(gt) 139 | pr = _to_4d(pr) 140 | jaccard_loss = JaccardLoss(smooth=10e-12) 141 | score = keras.backend.eval(jaccard_loss(gt, pr)) 142 | assert np.allclose(score, 1 - res) 143 | 144 | 145 | def _test_f_metric(case, beta=1): 146 | gt, pr, res = case 147 | gt = _to_4d(gt) 148 | pr = _to_4d(pr) 149 | f_score = FScore(beta=beta, smooth=10e-12) 150 | score = keras.backend.eval(f_score(gt, pr)) 151 | assert np.allclose(score, res) 152 | 153 | 154 | @pytest.mark.parametrize('case', F1_CASES) 155 | def test_f1_metric(case): 156 | _test_f_metric(case, beta=1) 157 | 158 | 159 | @pytest.mark.parametrize('case', F2_CASES) 160 | def test_f2_metric(case): 161 | _test_f_metric(case, beta=2) 162 | 163 | 164 | @pytest.mark.parametrize('case', F1_CASES) 165 | def test_dice_loss(case): 166 | gt, pr, res = case 167 | gt = _to_4d(gt) 168 | pr = _to_4d(pr) 169 | dice_loss = DiceLoss(smooth=10e-12) 170 | score = keras.backend.eval(dice_loss(gt, pr)) 171 | assert np.allclose(score, 1 - res) 172 | 173 | 174 | @pytest.mark.parametrize('func', METRICS + LOSSES) 175 | def test_per_image(func): 176 | gt = np.stack([GT0, GT1], axis=0) 177 | pr = np.stack([PR1, PR2], axis=0) 178 | 179 | gt = _add_4d(gt) 180 | pr = _add_4d(pr) 181 | 182 | # calculate score per image 183 | score_1 = keras.backend.eval(func(per_image=True, smooth=10e-12)(gt, pr)) 184 | score_2 = np.mean([ 185 | keras.backend.eval(func(smooth=10e-12)(_to_4d(GT0), _to_4d(PR1))), 186 | keras.backend.eval(func(smooth=10e-12)(_to_4d(GT1), _to_4d(PR2))), 187 | ]) 188 | assert np.allclose(score_1, score_2) 189 | 190 | 191 | @pytest.mark.parametrize('func', METRICS + LOSSES) 192 | def test_per_batch(func): 193 | gt = np.stack([GT0, GT1], axis=0) 194 | pr = np.stack([PR1, PR2], axis=0) 195 | 196 | gt = _add_4d(gt) 197 | pr = _add_4d(pr) 198 | 199 | # calculate score per batch 200 | score_1 = keras.backend.eval(func(per_image=False, smooth=10e-12)(gt, pr)) 201 | 202 | gt1 = np.concatenate([GT0, GT1], axis=0) 203 | pr1 = np.concatenate([PR1, PR2], axis=0) 204 | score_2 = keras.backend.eval(func(per_image=True, smooth=10e-12)(_to_4d(gt1), _to_4d(pr1))) 205 | 206 | assert np.allclose(score_1, score_2) 207 | 208 | 209 | @pytest.mark.parametrize('case', IOU_CASES) 210 | def test_threshold_iou(case): 211 | gt, pr, res = case 212 | gt = _to_4d(gt) 213 | pr = _to_4d(pr) * 0.51 214 | iou_score = IOUScore(smooth=10e-12, threshold=0.5) 215 | score = keras.backend.eval(iou_score(gt, pr)) 216 | assert np.allclose(score, res) 217 | 218 | 219 | if __name__ == '__main__': 220 | pytest.main([__file__]) 221 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | import random 4 | import six 5 | import numpy as np 6 |
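# (Note: segmentation_models picks its Keras implementation at import time,
# via the SM_FRAMEWORK environment variable; sm.framework() below reports
# whether plain `keras` or `tf.keras` is active.)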
7 | import segmentation_models as sm 8 | from segmentation_models import Unet 9 | from segmentation_models import Linknet 10 | from segmentation_models import PSPNet 11 | from segmentation_models import FPN 12 | from segmentation_models import get_available_backbone_names 13 | 14 | if sm.framework() == sm._TF_KERAS_FRAMEWORK_NAME: 15 | from tensorflow import keras 16 | elif sm.framework() == sm._KERAS_FRAMEWORK_NAME: 17 | import keras 18 | else: 19 | raise ValueError('Incorrect framework {}'.format(sm.framework())) 20 | 21 | def get_backbones(): 22 | is_travis = os.environ.get('TRAVIS', False) 23 | exclude = ['senet154', 'efficientnetb6', 'efficientnetb7'] 24 | backbones = get_available_backbone_names() 25 | 26 | if is_travis: 27 | backbones = [b for b in backbones if b not in exclude] 28 | return backbones 29 | 30 | 31 | BACKBONES = get_backbones() 32 | 33 | 34 | def _select_names(names): 35 | is_full = os.environ.get('FULL_TEST', False) 36 | if not is_full: 37 | return [random.choice(names)] 38 | else: 39 | return names 40 | 41 | 42 | def keras_test(func): 43 | """Function wrapper to clean up after TensorFlow tests. 44 | # Arguments 45 | func: test function to clean up after. 46 | # Returns 47 | A function wrapping the input function. 48 | """ 49 | @six.wraps(func) 50 | def wrapper(*args, **kwargs): 51 | output = func(*args, **kwargs) 52 | keras.backend.clear_session() 53 | return output 54 | return wrapper 55 | 56 | 57 | @keras_test 58 | def _test_none_shape(model_fn, backbone, *args, **kwargs): 59 | 60 | # define number of channels 61 | input_shape = kwargs.get('input_shape', None) 62 | n_channels = 3 if input_shape is None else input_shape[-1] 63 | 64 | # create test sample 65 | x = np.ones((1, 32, 32, n_channels)) 66 | 67 | # define model and process sample 68 | model = model_fn(backbone, *args, **kwargs) 69 | y = model.predict(x) 70 | 71 | # check output dimensions 72 | assert x.shape[:-1] == y.shape[:-1] 73 | 74 | 75 | @keras_test 76 | def _test_shape(model_fn, backbone, input_shape, *args, **kwargs): 77 | 78 | # create test sample 79 | x = np.ones((1, *input_shape)) 80 | 81 | # define model and process sample 82 | model = model_fn(backbone, input_shape=input_shape, *args, **kwargs) 83 | y = model.predict(x) 84 | 85 | # check output dimensions 86 | assert x.shape[:-1] == y.shape[:-1] 87 | 88 | 89 | @pytest.mark.parametrize('backbone', _select_names(BACKBONES)) 90 | def test_unet(backbone): 91 | _test_none_shape( 92 | Unet, backbone, encoder_weights=None) 93 | 94 | _test_none_shape( 95 | Unet, backbone, encoder_weights='imagenet') 96 | 97 | _test_shape( 98 | Unet, backbone, input_shape=(256, 256, 4), encoder_weights=None) 99 | 100 | 101 | @pytest.mark.parametrize('backbone', _select_names(BACKBONES)) 102 | def test_linknet(backbone): 103 | _test_none_shape( 104 | Linknet, backbone, encoder_weights=None) 105 | 106 | _test_none_shape( 107 | Linknet, backbone, encoder_weights='imagenet') 108 | 109 | _test_shape( 110 | Linknet, backbone, input_shape=(256, 256, 4), encoder_weights=None) 111 | 112 | 113 | @pytest.mark.parametrize('backbone', _select_names(BACKBONES)) 114 | def test_pspnet(backbone): 115 | 116 | _test_shape( 117 | PSPNet, backbone, input_shape=(384, 384, 4), encoder_weights=None) 118 | 119 | _test_shape( 120 | PSPNet, backbone, input_shape=(384, 384, 3), encoder_weights='imagenet') 121 | 122 | 123 | @pytest.mark.parametrize('backbone', _select_names(BACKBONES)) 124 | def test_fpn(backbone): 125 | _test_none_shape( 126 | FPN, backbone, encoder_weights=None) 127 | 128 | 
_test_none_shape( 129 | FPN, backbone, encoder_weights='imagenet') 130 | 131 | _test_shape( 132 | FPN, backbone, input_shape=(256, 256, 4), encoder_weights=None) 133 | 134 | 135 | if __name__ == '__main__': 136 | pytest.main([__file__]) 137 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | 4 | import segmentation_models as sm 5 | from segmentation_models.utils import set_regularization 6 | from segmentation_models import Unet 7 | 8 | if sm.framework() == sm._TF_KERAS_FRAMEWORK_NAME: 9 | from tensorflow import keras 10 | elif sm.framework() == sm._KERAS_FRAMEWORK_NAME: 11 | import keras 12 | else: 13 | raise ValueError('Incorrect framework {}'.format(sm.framework())) 14 | 15 | X1 = np.ones((1, 32, 32, 3)) 16 | Y1 = np.ones((1, 32, 32, 1)) 17 | MODEL = Unet 18 | BACKBONE = 'resnet18' 19 | CASE = ( 20 | 21 | (X1, Y1, MODEL, BACKBONE), 22 | ) 23 | 24 | 25 | def _test_regularizer(model, reg_model, x, y): 26 | 27 | def zero_loss(gt, pr): 28 | return pr * 0 29 | 30 | model.compile('Adam', loss=zero_loss, metrics=['binary_accuracy']) 31 | reg_model.compile('Adam', loss=zero_loss, metrics=['binary_accuracy']) 32 | 33 | loss_1, _ = model.test_on_batch(x, y) 34 | loss_2, _ = reg_model.test_on_batch(x, y) 35 | 36 | assert loss_1 == 0 37 | assert loss_2 > 0 38 | 39 | keras.backend.clear_session() 40 | 41 | 42 | @pytest.mark.parametrize('case', CASE) 43 | def test_kernel_reg(case): 44 | x, y, model_fn, backbone = case 45 | 46 | l1_reg = keras.regularizers.l1(0.1) 47 | model = model_fn(backbone) 48 | reg_model = set_regularization(model, kernel_regularizer=l1_reg) 49 | _test_regularizer(model, reg_model, x, y) 50 | 51 | l2_reg = keras.regularizers.l2(0.1) 52 | model = model_fn(backbone, encoder_weights=None) 53 | reg_model = set_regularization(model, kernel_regularizer=l2_reg) 54 | _test_regularizer(model, reg_model, x, y) 55 | 56 | 57 | """ 58 | Note: 59 | backbone resnet18 uses BN after each conv layer, so no bias is used in those conv layers; 60 | skip the bias regularizer test 61 | 62 | @pytest.mark.parametrize('case', CASE) 63 | def test_bias_reg(case): 64 | x, y, model_fn, backbone = case 65 | 66 | l1_reg = regularizers.l1(1) 67 | model = model_fn(backbone) 68 | reg_model = set_regularization(model, bias_regularizer=l1_reg) 69 | _test_regularizer(model, reg_model, x, y) 70 | 71 | l2_reg = regularizers.l2(1) 72 | model = model_fn(backbone) 73 | reg_model = set_regularization(model, bias_regularizer=l2_reg) 74 | _test_regularizer(model, reg_model, x, y) 75 | """ 76 | 77 | 78 | @pytest.mark.parametrize('case', CASE) 79 | def test_bn_reg(case): 80 | x, y, model_fn, backbone = case 81 | 82 | l1_reg = keras.regularizers.l1(1) 83 | model = model_fn(backbone) 84 | reg_model = set_regularization(model, gamma_regularizer=l1_reg) 85 | _test_regularizer(model, reg_model, x, y) 86 | 87 | model = model_fn(backbone) 88 | reg_model = set_regularization(model, beta_regularizer=l1_reg) 89 | _test_regularizer(model, reg_model, x, y) 90 | 91 | l2_reg = keras.regularizers.l2(1) 92 | model = model_fn(backbone) 93 | reg_model = set_regularization(model, gamma_regularizer=l2_reg) 94 | _test_regularizer(model, reg_model, x, y) 95 | 96 | model = model_fn(backbone) 97 | reg_model = set_regularization(model, beta_regularizer=l2_reg) 98 | _test_regularizer(model, reg_model, x, y) 99 | 100 | 101 | @pytest.mark.parametrize('case', CASE) 102 | def
test_activity_reg(case): 103 | x, y, model_fn, backbone = case 104 | 105 | l2_reg = keras.regularizers.l2(1) 106 | model = model_fn(backbone) 107 | reg_model = set_regularization(model, activity_regularizer=l2_reg) 108 | _test_regularizer(model, reg_model, x, y) 109 | 110 | 111 | if __name__ == '__main__': 112 | pytest.main([__file__]) 113 | --------------------------------------------------------------------------------
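Appendix (editor's usage sketch, not part of the repository): a minimal end-to-end example of the library, assuming the package is installed with a working Keras backend; it reuses only names exercised by the tests above (the `resnet18` backbone, `Unet`, `JaccardLoss`, `IOUScore`).

import numpy as np
from segmentation_models import Unet
from segmentation_models.losses import JaccardLoss
from segmentation_models.metrics import IOUScore

# binary-segmentation U-Net with an ImageNet-pretrained encoder
model = Unet('resnet18', classes=1, activation='sigmoid', encoder_weights='imagenet')
model.compile('Adam', loss=JaccardLoss(), metrics=[IOUScore()])

# dummy batch; H and W must be divisible by 32 for the U-Net decoder
x = np.ones((1, 32, 32, 3), dtype='float32')
y = np.ones((1, 32, 32, 1), dtype='float32')
model.train_on_batch(x, y)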