├── .github
│   └── FUNDING.yml
├── .gitignore
├── .gitmodules
├── .travis.yml
├── CHANGELOG.md
├── LICENSE
├── MANIFEST.in
├── README.rst
├── __init__.py
├── docs
│   ├── Makefile
│   ├── api.rst
│   ├── conf.py
│   ├── index.rst
│   ├── install.rst
│   ├── support.rst
│   └── tutorial.rst
├── examples
│   ├── binary segmentation (camvid).ipynb
│   └── multiclass segmentation (camvid).ipynb
├── images
│   ├── fpn.png
│   ├── linknet.png
│   ├── logo.png
│   ├── pspnet.png
│   └── unet.png
├── requirements.txt
├── segmentation_models
│   ├── __init__.py
│   ├── __version__.py
│   ├── backbones
│   │   ├── __init__.py
│   │   ├── backbones_factory.py
│   │   ├── inception_resnet_v2.py
│   │   └── inception_v3.py
│   ├── base
│   │   ├── __init__.py
│   │   ├── functional.py
│   │   └── objects.py
│   ├── losses.py
│   ├── metrics.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── _common_blocks.py
│   │   ├── _utils.py
│   │   ├── fpn.py
│   │   ├── linknet.py
│   │   ├── pspnet.py
│   │   └── unet.py
│   └── utils.py
├── setup.py
└── tests
    ├── test_metrics.py
    ├── test_models.py
    └── test_utils.py
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
4 | patreon: # Replace with a single Patreon username
5 | open_collective: # Replace with a single Open Collective username
6 | ko_fi: qubvel
7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9 | liberapay: qubvel
10 | issuehunt: # Replace with a single IssueHunt username
11 | otechie: # Replace with a single Otechie username
12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
13 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | .idea/
6 | *ipynb
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | .hypothesis/
50 | .pytest_cache/
51 |
52 | # Translations
53 | *.mo
54 | *.pot
55 |
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 | db.sqlite3
60 |
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 |
65 | # Scrapy stuff:
66 | .scrapy
67 |
68 | # Sphinx documentation
69 | docs/_build/
70 |
71 | # PyBuilder
72 | target/
73 |
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 |
77 | # pyenv
78 | .python-version
79 |
80 | # celery beat schedule file
81 | celerybeat-schedule
82 |
83 | # SageMath parsed files
84 | *.sage.py
85 |
86 | # Environments
87 | .env
88 | .venv
89 | env/
90 | venv/
91 | ENV/
92 | env.bak/
93 | venv.bak/
94 |
95 | # Spyder project settings
96 | .spyderproject
97 | .spyproject
98 |
99 | # Rope project settings
100 | .ropeproject
101 |
102 | # mkdocs documentation
103 | /site
104 |
105 | # mypy
106 | .mypy_cache/
107 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qubvel/segmentation_models/5d24bbfb28af6134e25e2c0b79e7727f6c0491d0/.gitmodules
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | sudo: required
2 | dist: trusty
3 | language: python
4 |
5 | matrix:
6 | include:
7 | - python: 3.6
8 | env: KERAS_VERSION=2.2.4
9 | - python: 3.6
10 | env: SM_FRAMEWORK='tf.keras'
11 |
12 | git:
13 | submodules: true
14 |
15 |
16 | install:
17 | # code below is taken from http://conda.pydata.org/docs/travis.html
18 | # We do this conditionally because it saves us some downloading if the
19 | # version is the same.
20 | - wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
21 | - bash miniconda.sh -b -p $HOME/miniconda
22 | - export PATH="$HOME/miniconda/bin:$PATH"
23 | - hash -r
24 | - conda config --set always_yes yes --set changeps1 no
25 | - conda update -q conda
26 | # Useful for debugging any issues with conda
27 | - conda info -a
28 |
29 | - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION pytest pandas
30 | - source activate test-environment
31 | - pip install --only-binary=numpy,scipy numpy nose scipy matplotlib h5py theano
32 |
33 |
34 | # set library path
35 | - export LD_LIBRARY_PATH=$HOME/miniconda/envs/test-environment/lib/:$LD_LIBRARY_PATH
36 | - conda install mkl mkl-service
37 |
38 | # install TensorFlow (CPU version).
39 | - pip install tensorflow==1.14
40 |
41 | # install keras
42 | - if [ -z $KERAS_VERSION ]; then
43 | echo "Using tf.keras";
44 | else
45 | echo "Using keras";
46 | pip install keras==$KERAS_VERSION;
47 | fi
48 |
49 | # install lib in develop mode
50 | - pip install -e .[tests]
51 |
52 | # detect whether only a markdown file was changed
53 | - export DOC_ONLY_CHANGED=False;
54 | - if [ $(git diff --name-only HEAD~1 | wc -l) == "1" ] && [[ "$(git diff --name-only HEAD~1)" == *"md" ]]; then
55 | export DOC_ONLY_CHANGED=True;
56 | fi
57 |
58 | # command to run tests
59 | script:
60 | - export MKL_THREADING_LAYER="GNU"
61 | - mkdir -p ~/.keras/models
62 | # set up keras backend
63 | - if [[ "$DOC_ONLY_CHANGED" == "False" ]]; then
64 | PYTHONPATH=$PWD:$PYTHONPATH py.test tests/;
65 | fi
66 |
67 | deploy:
68 | provider: pypi
69 | user: qubvel
70 | password:
71 | secure: QA/UJmkXGlXy/6C8X0E/bPf4izu3rJsztaEmqIM1npxPiv2Uf4WFs43vxkMXwfHrflocdfw8SBM8bWnbunGT2SvDdo/MMCMpol7unE74T/RbODYl6aiJWVM3QKOXL8pQD0oQ+03L1YK3nCeSQdePINEPmuFmvwyO40q8Dwv8HBZIGZlEo4SK4xr8ekxfmtbezxQ7vUL3sNcvCJDXrZX/4UdXrhdRk+zYoN3dv8NmM4FmChajq/m5Am9OPdbdUBHmIYmvk7L3IpwJeMMpG5FVdGNVwYj7XNHlcy+KZ2/CKn9EpslRDPxY4650654PmhSZWDctZG7jiFWLCZBUvowiyAOPZknZNgdu5gJAfdg37XS9IP3HgTZN6Jb5Bm0by3IlKt+dTzyJQcUnRql5B1wwEI0XO3/YWQe1GQQphIO1bli9hT8n8xNDNjc49vDlu4zKyaYnQmLhqNxkyeruXSTpc8qTITuS+EGgkAUrrBj/IaFcutIg9WOzvJ3nZO8X8UG7LlyQx4AOpfHP6bynAmlT+UFccCEq66Zoh7teWLk0lUekuYST2iQJ3pwFoQGYJRCsmxsz7J0B9ayFVVT/fg+GZpZm1oTnnJ27hh8LZWv/Cr/WHOBYc3qvigWx4pDssJ+O6z7de3aWrGvzAVgXr190fRdP55a34HhNbiKZ0YWmrTs=
72 | on:
73 | tags: true
74 | skip_existing: true
75 | distributions: "sdist bdist_wheel"
76 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Change Log
2 |
3 | **Version 1.0.0**
4 |
5 | ###### Areas of improvement
6 | - Support for `keras` and `tf.keras`
7 | - Losses as classes, base loss operations (sum of losses, multiplied loss); see the sketch below
8 | - NCHW and NHWC support
9 | - Removed pure tf operations to work with other keras backends
10 | - Reduced the number of custom objects for better model serialization and deserialization
11 |
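A minimal sketch of the loss-object arithmetic (loss class names as in this release; the scaling factor is illustrative):

```python
import segmentation_models as sm

dice_loss = sm.losses.DiceLoss()
focal_loss = sm.losses.BinaryFocalLoss()

# loss objects can be summed and scaled into a single combined training loss
total_loss = dice_loss + (2.0 * focal_loss)
```
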
12 | ###### New features
13 | - New backbones: EfficientNetB[0-7]
14 | - New loss function: Focal loss
15 | - New metrics: Precision, Recall
16 |
17 | ###### API changes
18 | - `get_preprocessing` moved from `sm.backbones.get_preprocessing` to `sm.get_preprocessing`
19 |
20 | **Version 0.2.1**
21 |
22 | ###### Areas of improvement
23 |
24 | - Added `set_regularization` function
25 | - Added `beta` argument to dice loss
26 | - Added `threshold` argument for metrics
27 | - Fixed `preprocess_input` for mobilenets
28 | - Fixed missing parameter `interpolation` in `ResizeImage` layer config
29 | - Some minor improvements in docs, fixed typos
30 |
31 | **Version 0.2.0**
32 |
33 | ###### Areas of improvement
34 |
35 | - New backbones (SE-ResNets, SE-ResNeXts, SENet154, MobileNets)
36 | - Metrics:
37 | - `iou_score` / `jaccard_score`
38 | - `f_score` / `dice_score`
39 | - Losses:
40 | - `jaccard_loss`
41 | - `bce_jaccard_loss`
42 | - `cce_jaccard_loss`
43 | - `dice_loss`
44 | - `bce_dice_loss`
45 | - `cce_dice_loss`
46 | - Documentation [Read the Docs](https://segmentation-models.readthedocs.io)
47 | - Tests + Travis-CI
48 |
49 | ###### API changes
50 |
51 | - Some parameters renamed (see API docs)
52 | - `encoder_freeze=True` does not freeze the `BatchNormalization` layers of the encoder
53 |
54 | ###### Thanks
55 |
56 | [@IlyaOvodov](https://github.com/IlyaOvodov) [#15](https://github.com/qubvel/segmentation_models/issues/15) [#37](https://github.com/qubvel/segmentation_models/pull/37) investigation of `align_corners` parameter in `ResizeImage` layer
57 | [@NiklasDL](https://github.com/NiklasDL) [#29](https://github.com/qubvel/segmentation_models/issues/29) investigation about convolution kernel in PSPNet final layers
58 |
59 | **Version 0.1.2**
60 |
61 | ###### Areas of improvement
62 |
63 | - Added PSPModel
64 | - Preprocessing functions for all backbones:
65 | ```python
66 | from segmentation_models.backbones import get_preprocessing
67 |
68 | preprocessing_fn = get_preprocessing('resnet34')
69 | X = preprocessing_fn(x)
70 | ```
71 | ###### API changes
72 | - Default param `use_batchnorm=True` for all decoders
73 | - FPN model `Upsample2D` layer renamed to `ResizeImage`
74 |
75 | **Version 0.1.1**
76 | - Added `Linknet` model
77 | - Keras 2.2+ compatibility (fixed import of `_obtain_input_shape`)
78 | - Small code improvements and bug fixes
79 |
80 | **Version 0.1.0**
81 | - `Unet` and `FPN` models
82 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License
2 |
3 | Copyright (c) 2018, Pavel Yakubovskiy
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.rst LICENSE requirements.txt
2 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | Python library with Neural Networks for Image Segmentation based on Keras and TensorFlow.
2 |
18 | **The main features** of this library are:
19 |
20 | - High level API (just two lines of code to create a segmentation model)
21 | - **4** model architectures for binary and multi-class image segmentation
22 | (including legendary **Unet**)
23 | - **25** available backbones for each architecture
24 | - All backbones have **pre-trained** weights for faster and better
25 | convergence
26 | - Helpful segmentation losses (Jaccard, Dice, Focal) and metrics (IoU, F-score)
27 |
28 | **Important note**
29 |
30 | Some models of version ``1.*`` are not compatible with models trained with
31 | previous versions; if you have such models and want to load them, roll back with::
32 |
33 |     $ pip install -U segmentation-models==0.2.1
34 |
35 | Table of Contents
36 | ~~~~~~~~~~~~~~~~~
37 | - `Quick start`_
38 | - `Simple training pipeline`_
39 | - `Examples`_
40 | - `Models and Backbones`_
41 | - `Installation`_
42 | - `Documentation`_
43 | - `Change log`_
44 | - `Citing`_
45 | - `License`_
46 |
47 | Quick start
48 | ~~~~~~~~~~~
49 | The library is built to work with both the Keras and TensorFlow Keras frameworks.
50 |
51 | .. code:: python
52 |
53 | import segmentation_models as sm
54 | # Segmentation Models: using `keras` framework.
55 |
56 | By default it tries to import ``keras``; if it is not installed, it falls back to the ``tensorflow.keras`` framework.
57 | There are two ways to choose the framework explicitly, as sketched below:
58 |
59 | - Provide the environment variable ``SM_FRAMEWORK=keras`` / ``SM_FRAMEWORK=tf.keras`` before importing ``segmentation_models``
60 | - Change the framework at runtime with ``sm.set_framework('keras')`` / ``sm.set_framework('tf.keras')``
61 |
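Both options in one minimal sketch (the environment variable must be set before the first import):

.. code:: python

    import os
    os.environ['SM_FRAMEWORK'] = 'tf.keras'

    import segmentation_models as sm
    # Segmentation Models: using `tf.keras` framework.

    sm.set_framework('keras')  # or switch the framework at runtime (requires `keras` installed)
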
62 | You can also specify which ``image_data_format`` to use: segmentation-models works with both ``channels_last`` and ``channels_first``.
63 | This can be useful for further model conversion to the Nvidia TensorRT format or for optimizing the model for CPU/GPU computations.
64 |
65 | .. code:: python
66 |
67 | import keras
68 | # or from tensorflow import keras
69 |
70 | keras.backend.set_image_data_format('channels_last')
71 | # or keras.backend.set_image_data_format('channels_first')
72 |
73 | The created segmentation model is just an instance of a Keras ``Model``, which can be built as easily as:
74 |
75 | .. code:: python
76 |
77 | model = sm.Unet()
78 |
79 | Depending on the task, you can change the network architecture by choosing backbones with fewer or more parameters and use pretrained weights to initialize it:
80 |
81 | .. code:: python
82 |
83 | model = sm.Unet('resnet34', encoder_weights='imagenet')
84 |
85 | Change number of output classes in the model (choose your case):
86 |
87 | .. code:: python
88 |
89 | # binary segmentation (these parameters are the defaults when you call Unet('resnet34'))
90 | model = sm.Unet('resnet34', classes=1, activation='sigmoid')
91 |
92 | .. code:: python
93 |
94 | # multiclass segmentation with non-overlapping class masks (your classes + background)
95 | model = sm.Unet('resnet34', classes=3, activation='softmax')
96 |
97 | .. code:: python
98 |
99 | # multiclass segmentation with independent overlapping/non-overlapping class masks
100 | model = sm.Unet('resnet34', classes=3, activation='sigmoid')
101 |
102 |
103 | Change input shape of the model:
104 |
105 | .. code:: python
106 |
107 | # if you set input channels not equal to 3, you have to set encoder_weights=None
108 | # how to handle such a case with encoder_weights='imagenet' is described in the docs
109 | model = sm.Unet('resnet34', input_shape=(None, None, 6), encoder_weights=None)
110 |
111 | Simple training pipeline
112 | ~~~~~~~~~~~~~~~~~~~~~~~~
113 |
114 | .. code:: python
115 |
116 | import segmentation_models as sm
117 |
118 | BACKBONE = 'resnet34'
119 | preprocess_input = sm.get_preprocessing(BACKBONE)
120 |
121 | # load your data
122 | x_train, y_train, x_val, y_val = load_data(...)
123 |
124 | # preprocess input
125 | x_train = preprocess_input(x_train)
126 | x_val = preprocess_input(x_val)
127 |
128 | # define model
129 | model = sm.Unet(BACKBONE, encoder_weights='imagenet')
130 | model.compile(
131 | 'Adam',
132 | loss=sm.losses.bce_jaccard_loss,
133 | metrics=[sm.metrics.iou_score],
134 | )
135 |
136 | # fit model
137 | # if you use data generator use model.fit_generator(...) instead of model.fit(...)
138 | # more about `fit_generator` here: https://keras.io/models/sequential/#fit_generator
139 | model.fit(
140 | x=x_train,
141 | y=y_train,
142 | batch_size=16,
143 | epochs=100,
144 | validation_data=(x_val, y_val),
145 | )
146 |
147 | The same manipulations can be done with ``Linknet``, ``PSPNet`` and ``FPN``. For more detailed information about the models API and use cases, see `Read the Docs <https://segmentation-models.readthedocs.io>`__.
148 |
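A minimal sketch of the shared API across the other architectures (parameter values are illustrative; PSPNet pools at fixed grid sizes, so it is given a fixed input shape here):

.. code:: python

    model = sm.FPN(BACKBONE, encoder_weights='imagenet')
    model = sm.Linknet(BACKBONE, encoder_weights='imagenet')
    model = sm.PSPNet(BACKBONE, encoder_weights='imagenet', input_shape=(384, 384, 3))
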
149 | Examples
150 | ~~~~~~~~
151 | Models training examples:
152 | - [Jupyter Notebook] Binary segmentation (`cars`) on CamVid dataset: see ``examples/binary segmentation (camvid).ipynb`` in the repository.
153 | - [Jupyter Notebook] Multi-class segmentation (`cars`, `pedestrians`) on CamVid dataset: see ``examples/multiclass segmentation (camvid).ipynb`` in the repository.
154 |
155 | Models and Backbones
156 | ~~~~~~~~~~~~~~~~~~~~
157 | **Models**
158 |
159 | - `Unet`_
160 | - `FPN`_
161 | - `Linknet`_
162 | - `PSPNet`_
163 |
164 | ============= ==============
165 | Unet Linknet
166 | ============= ==============
167 | |unet_image| |linknet_image|
168 | ============= ==============
169 |
170 | ============= ==============
171 | PSPNet FPN
172 | ============= ==============
173 | |psp_image| |fpn_image|
174 | ============= ==============
175 |
176 | .. _Unet: https://arxiv.org/abs/1505.04597
177 | .. _Linknet: https://arxiv.org/abs/1707.03718
178 | .. _PSPNet: https://arxiv.org/abs/1612.01105
179 | .. _FPN: http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf
180 |
181 | .. |unet_image| image:: https://github.com/qubvel/segmentation_models/blob/master/images/unet.png
182 | .. |linknet_image| image:: https://github.com/qubvel/segmentation_models/blob/master/images/linknet.png
183 | .. |psp_image| image:: https://github.com/qubvel/segmentation_models/blob/master/images/pspnet.png
184 | .. |fpn_image| image:: https://github.com/qubvel/segmentation_models/blob/master/images/fpn.png
185 |
186 | **Backbones**
187 |
188 | .. table::
189 |
190 | ============= =====
191 | Type Names
192 | ============= =====
193 | VGG ``'vgg16' 'vgg19'``
194 | ResNet ``'resnet18' 'resnet34' 'resnet50' 'resnet101' 'resnet152'``
195 | SE-ResNet ``'seresnet18' 'seresnet34' 'seresnet50' 'seresnet101' 'seresnet152'``
196 | ResNeXt ``'resnext50' 'resnext101'``
197 | SE-ResNeXt ``'seresnext50' 'seresnext101'``
198 | SENet154 ``'senet154'``
199 | DenseNet ``'densenet121' 'densenet169' 'densenet201'``
200 | Inception ``'inceptionv3' 'inceptionresnetv2'``
201 | MobileNet ``'mobilenet' 'mobilenetv2'``
202 | EfficientNet ``'efficientnetb0' 'efficientnetb1' 'efficientnetb2' 'efficientnetb3' 'efficientnetb4' 'efficientnetb5' 'efficientnetb6' 'efficientnetb7'``
203 | ============= =====
204 |
205 | .. epigraph::
206 | All backbones have weights trained on 2012 ILSVRC ImageNet dataset (``encoder_weights='imagenet'``).
207 |
208 |
209 | Installation
210 | ~~~~~~~~~~~~
211 |
212 | **Requirements**
213 |
214 | 1) python 3
215 | 2) keras >= 2.2.0 or tensorflow >= 1.13
216 | 3) keras-applications >= 1.0.7, <=1.0.8
217 | 4) image-classifiers == 1.0.*
218 | 5) efficientnet == 1.0.*
219 |
220 | **PyPI stable package**
221 |
222 | .. code:: bash
223 |
224 | $ pip install -U segmentation-models
225 |
226 | **PyPI latest package**
227 |
228 | .. code:: bash
229 |
230 | $ pip install -U --pre segmentation-models
231 |
232 | **Source latest version**
233 |
234 | .. code:: bash
235 |
236 | $ pip install git+https://github.com/qubvel/segmentation_models
237 |
238 | Documentation
239 | ~~~~~~~~~~~~~
240 | Latest **documentation** is available on `Read the
241 | Docs <https://segmentation-models.readthedocs.io>`__
242 |
243 | Change Log
244 | ~~~~~~~~~~
245 | To see important changes between versions look at CHANGELOG.md_
246 |
247 | Citing
248 | ~~~~~~~~
249 |
250 | .. code::
251 |
252 | @misc{Yakubovskiy:2019,
253 | Author = {Pavel Iakubovskii},
254 | Title = {Segmentation Models},
255 | Year = {2019},
256 | Publisher = {GitHub},
257 | Journal = {GitHub repository},
258 | Howpublished = {\url{https://github.com/qubvel/segmentation_models}}
259 | }
260 |
261 | License
262 | ~~~~~~~
263 | Project is distributed under `MIT Licence`_.
264 |
265 | .. _CHANGELOG.md: https://github.com/qubvel/segmentation_models/blob/master/CHANGELOG.md
266 | .. _`MIT Licence`: https://github.com/qubvel/segmentation_models/blob/master/LICENSE
267 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .segmentation_models import *
2 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | SOURCEDIR = .
8 | BUILDDIR = _build
9 |
10 | # Put it first so that "make" without argument is like "make help".
11 | help:
12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
13 |
14 | .PHONY: help Makefile
15 |
16 | # Catch-all target: route all unknown targets to Sphinx using the new
17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
18 | %: Makefile
19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/docs/api.rst:
--------------------------------------------------------------------------------
1 | Segmentation Models Python API
2 | ==============================
3 |
4 |
5 | Getting started with segmentation models is easy.
6 |
7 | Unet
8 | ~~~~
9 | .. autofunction:: segmentation_models.Unet
10 |
11 | Linknet
12 | ~~~~~~~
13 | .. autofunction:: segmentation_models.Linknet
14 |
15 | FPN
16 | ~~~
17 | .. autofunction:: segmentation_models.FPN
18 |
19 | PSPNet
20 | ~~~~~~
21 | .. autofunction:: segmentation_models.PSPNet
22 |
23 | metrics
24 | ~~~~~~~
25 | .. autofunction:: segmentation_models.metrics.IOUScore
26 | .. autofunction:: segmentation_models.metrics.FScore
27 |
28 | losses
29 | ~~~~~~
30 | .. autofunction:: segmentation_models.losses.JaccardLoss
31 | .. autofunction:: segmentation_models.losses.DiceLoss
32 | .. autofunction:: segmentation_models.losses.BinaryCELoss
33 | .. autofunction:: segmentation_models.losses.CategoricalCELoss
34 | .. autofunction:: segmentation_models.losses.BinaryFocalLoss
35 | .. autofunction:: segmentation_models.losses.CategoricalFocalLoss
36 |
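A minimal usage sketch combining the classes above (the loss combination and threshold value are illustrative):

.. code:: python

    import segmentation_models as sm

    model = sm.Unet('resnet34')
    loss = sm.losses.JaccardLoss() + sm.losses.BinaryFocalLoss()
    metric = sm.metrics.IOUScore(threshold=0.5)
    model.compile('Adam', loss=loss, metrics=[metric])
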
37 | utils
38 | ~~~~~
39 | .. autofunction:: segmentation_models.utils.set_trainable
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Configuration file for the Sphinx documentation builder.
4 | #
5 | # This file does only contain a selection of the most common options. For a
6 | # full list see the documentation:
7 | # http://www.sphinx-doc.org/en/master/config
8 |
9 | # -- Path setup --------------------------------------------------------------
10 |
11 | # If extensions (or modules to document with autodoc) are in another directory,
12 | # add these directories to sys.path here. If the directory is relative to the
13 | # documentation root, use os.path.abspath to make it absolute, like shown here.
14 | #
15 | # import os
16 | # import sys
17 | # sys.path.insert(0, os.path.abspath('.'))
18 |
19 | # -- Project information -----------------------------------------------------
20 | import sys
21 | sys.path.append('..')
22 |
23 | project = u'Segmentation Models'
24 | copyright = u'2018, Pavel Yakubovskiy'
25 | author = u'Pavel Yakubovskiy'
26 |
27 | # The short X.Y version
28 | version = u''
29 | # The full version, including alpha/beta/rc tags
30 | release = u'0.1.2'
31 |
32 |
33 | # -- General configuration ---------------------------------------------------
34 |
35 | # If your documentation needs a minimal Sphinx version, state it here.
36 | #
37 | # needs_sphinx = '1.0'
38 |
39 | # Add any Sphinx extension module names here, as strings. They can be
40 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
41 | # ones.
42 | extensions = [
43 | 'sphinx.ext.autodoc',
44 | 'sphinx.ext.coverage',
45 | 'sphinx.ext.napoleon',
46 | ]
47 |
48 | # Add any paths that contain templates here, relative to this directory.
49 | templates_path = ['_templates']
50 |
51 | # The suffix(es) of source filenames.
52 | # You can specify multiple suffix as a list of string:
53 | #
54 | # source_suffix = ['.rst', '.md']
55 | source_suffix = '.rst'
56 |
57 | # The master toctree document.
58 | master_doc = 'index'
59 |
60 | # The language for content autogenerated by Sphinx. Refer to documentation
61 | # for a list of supported languages.
62 | #
63 | # This is also used if you do content translation via gettext catalogs.
64 | # Usually you set "language" from the command line for these cases.
65 | language = None
66 |
67 | # List of patterns, relative to source directory, that match files and
68 | # directories to ignore when looking for source files.
69 | # This pattern also affects html_static_path and html_extra_path.
70 | exclude_patterns = [u'_build', 'Thumbs.db', '.DS_Store']
71 |
72 | # The name of the Pygments (syntax highlighting) style to use.
73 | pygments_style = None
74 |
75 |
76 | # -- Options for HTML output -------------------------------------------------
77 |
78 | # The theme to use for HTML and HTML Help pages. See the documentation for
79 | # a list of builtin themes.
80 | #
81 | # -- Theme setup -------------------------------------------------------------
82 |
83 | import sphinx_rtd_theme
84 |
85 | html_theme = 'sphinx_rtd_theme'
86 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
87 |
88 | # Theme options are theme-specific and customize the look and feel of a theme
89 | # further. For a list of options available for each theme, see the
90 | # documentation.
91 | #
92 | # html_theme_options = {}
93 |
94 | # Add any paths that contain custom static files (such as style sheets) here,
95 | # relative to this directory. They are copied after the builtin static files,
96 | # so a file named "default.css" will overwrite the builtin "default.css".
97 | html_static_path = ['_static']
98 |
99 | # Custom sidebar templates, must be a dictionary that maps document names
100 | # to template names.
101 | #
102 | # The default sidebars (for documents that don't match any pattern) are
103 | # defined by theme itself. Builtin themes are using these templates by
104 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
105 | # 'searchbox.html']``.
106 | #
107 | # html_sidebars = {}
108 |
109 |
110 | # -- Options for HTMLHelp output ---------------------------------------------
111 |
112 | # Output file base name for HTML help builder.
113 | htmlhelp_basename = 'SegmentationModelsdoc'
114 |
115 |
116 | # -- Options for LaTeX output ------------------------------------------------
117 |
118 | latex_elements = {
119 | # The paper size ('letterpaper' or 'a4paper').
120 | #
121 | # 'papersize': 'letterpaper',
122 |
123 | # The font size ('10pt', '11pt' or '12pt').
124 | #
125 | # 'pointsize': '10pt',
126 |
127 | # Additional stuff for the LaTeX preamble.
128 | #
129 | # 'preamble': '',
130 |
131 | # Latex figure (float) alignment
132 | #
133 | # 'figure_align': 'htbp',
134 | }
135 |
136 | # Grouping the document tree into LaTeX files. List of tuples
137 | # (source start file, target name, title,
138 | # author, documentclass [howto, manual, or own class]).
139 | latex_documents = [
140 | (master_doc, 'SegmentationModels.tex', u'Segmentation Models Documentation',
141 | u'Pavel Yakubovskiy', 'manual'),
142 | ]
143 |
144 |
145 | # -- Options for manual page output ------------------------------------------
146 |
147 | # One entry per manual page. List of tuples
148 | # (source start file, name, description, authors, manual section).
149 | man_pages = [
150 | (master_doc, 'segmentationmodels', u'Segmentation Models Documentation',
151 | [author], 1)
152 | ]
153 |
154 |
155 | # -- Options for Texinfo output ----------------------------------------------
156 |
157 | # Grouping the document tree into Texinfo files. List of tuples
158 | # (source start file, target name, title, author,
159 | # dir menu entry, description, category)
160 | texinfo_documents = [
161 | (master_doc, 'SegmentationModels', u'Segmentation Models Documentation',
162 | author, 'SegmentationModels', 'One line description of project.',
163 | 'Miscellaneous'),
164 | ]
165 |
166 |
167 | # -- Options for Epub output -------------------------------------------------
168 |
169 | # Bibliographic Dublin Core info.
170 | epub_title = project
171 |
172 | # The unique identifier of the text. This can be a ISBN number
173 | # or the project homepage.
174 | #
175 | # epub_identifier = ''
176 |
177 | # A unique identification for the text.
178 | #
179 | # epub_uid = ''
180 |
181 | # A list of files that should not be packed into the epub file.
182 | epub_exclude_files = ['search.html']
183 |
184 |
185 | # -- Extension configuration -------------------------------------------------
186 |
187 | autodoc_mock_imports = [
188 | 'skimage',
189 | 'keras',
190 | 'tensorflow',
191 | 'efficientnet',
192 | 'classification_models',
193 | 'keras_applications',
194 | ]
195 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. Segmentation Models documentation master file, created by
2 | sphinx-quickstart on Tue Dec 18 17:37:58 2018.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to Segmentation Models's documentation!
7 | ===============================================
8 |
9 | .. toctree::
10 | :maxdepth: 2
11 | :caption: Contents:
12 |
13 | install
14 | tutorial
15 | api
16 | support
17 |
18 |
19 |
20 | Indices and tables
21 | ==================
22 |
23 | * :ref:`genindex`
24 | * :ref:`modindex`
25 | * :ref:`search`
26 |
--------------------------------------------------------------------------------
/docs/install.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 |
4 | **Requirements**
5 |
6 | 1) Python 3
7 | 2) Keras >= 2.2.0 or TensorFlow >= 1.13
8 | 3) keras-applications >= 1.0.7, <=1.0.8
9 | 4) image-classifiers == 1.0.0
10 | 5) efficientnet == 1.0.0
11 |
12 |
13 | .. note::
14 |
15 |     This library does not include Tensorflow_ in its requirements.txt
16 |     for installation. Please choose a suitable version ('cpu'/'gpu')
17 |     and install it manually using the official Guide_, for example:
18 |
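A minimal sketch (pick the package matching your hardware; package names as published on PyPI for TensorFlow 1.x):

.. code:: bash

    $ pip install tensorflow      # CPU version
    $ pip install tensorflow-gpu  # GPU version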
19 |
20 | **Pip package**
21 |
22 | .. code:: bash
23 |
24 | $ pip install segmentation-models
25 |
26 | **Latest version**
27 |
28 | .. code:: bash
29 |
30 | $ pip install git+https://github.com/qubvel/segmentation_models
31 |
32 | .. _Guide:
33 | https://www.tensorflow.org/install/
34 |
35 | .. _Tensorflow:
36 | https://www.tensorflow.org/
37 |
38 |
--------------------------------------------------------------------------------
/docs/support.rst:
--------------------------------------------------------------------------------
1 | Support
2 | =======
3 |
4 | The easiest way to get help with the project is to create an issue or a PR on GitHub.
5 |
6 | Github: http://github.com/qubvel/segmentation_models/issues
--------------------------------------------------------------------------------
/docs/tutorial.rst:
--------------------------------------------------------------------------------
1 | Tutorial
2 | ========
3 |
4 | **Segmentation models** is a Python library with Neural
5 | Networks for `Image
6 | Segmentation`_ based
7 | on the `Keras`_
8 | (`Tensorflow`_) framework.
9 |
10 | **The main features** of this library are:
11 |
12 | - High level API (just two lines to create a neural network)
13 | - **4** model architectures for binary and multi-class segmentation
14 | (including legendary **Unet**)
15 | - **25** available backbones for each architecture
16 | - All backbones have **pre-trained** weights for faster and better
17 | convergence
18 |
19 | Quick start
20 | ~~~~~~~~~~~
21 | Since the library is built on the Keras framework, the created segmentation model is just a Keras ``Model``, which can be created as easily as:
22 |
23 | .. code:: python
24 |
25 | from segmentation_models import Unet
26 |
27 | model = Unet()
28 |
29 | Depending on the task, you can change the network architecture by choosing backbones with fewer or more parameters and use pretrained weights to initialize it:
30 |
31 | .. code:: python
32 |
33 | model = Unet('resnet34', encoder_weights='imagenet')
34 |
35 | Change number of output classes in the model:
36 |
37 | .. code:: python
38 |
39 | model = Unet('resnet34', classes=3, activation='softmax')
40 |
41 | Change input shape of the model:
42 |
43 | .. code:: python
44 |
45 | model = Unet('resnet34', input_shape=(None, None, 6), encoder_weights=None)
46 |
47 | Simple training pipeline
48 | ~~~~~~~~~~~~~~~~~~~~~~~~
49 |
50 | .. code:: python
51 |
52 | from segmentation_models import Unet
53 | from segmentation_models import get_preprocessing
54 | from segmentation_models.losses import bce_jaccard_loss
55 | from segmentation_models.metrics import iou_score
56 |
57 | BACKBONE = 'resnet34'
58 | preprocess_input = get_preprocessing(BACKBONE)
59 |
60 | # load your data
61 | x_train, y_train, x_val, y_val = load_data(...)
62 |
63 | # preprocess input
64 | x_train = preprocess_input(x_train)
65 | x_val = preprocess_input(x_val)
66 |
67 | # define model
68 | model = Unet(BACKBONE, encoder_weights='imagenet')
69 | model.compile('Adam', loss=bce_jaccard_loss, metrics=[iou_score])
70 |
71 | # fit model
72 | model.fit(
73 | x=x_train,
74 | y=y_train,
75 | batch_size=16,
76 | epochs=100,
77 | validation_data=(x_val, y_val),
78 | )
79 |
80 |
81 | The same manipulations can be done with ``Linknet``, ``PSPNet`` and ``FPN``. For more detailed information about the models API and use cases, see `Read the Docs <https://segmentation-models.readthedocs.io>`__.
82 |
83 | Models and Backbones
84 | ~~~~~~~~~~~~~~~~~~~~
85 | **Models**
86 |
87 | - `Unet`_
88 | - `FPN`_
89 | - `Linknet`_
90 | - `PSPNet`_
91 |
92 | ============= ==============
93 | Unet Linknet
94 | ============= ==============
95 | |unet_image| |linknet_image|
96 | ============= ==============

97 | ============= ==============
98 | PSPNet FPN
99 | ============= ==============
100 | |psp_image| |fpn_image|
101 | ============= ==============
102 |
107 |
108 | .. |unet_image| image:: https://github.com/qubvel/segmentation_models/blob/master/images/unet.png
109 | .. |linknet_image| image:: https://github.com/qubvel/segmentation_models/blob/master/images/linknet.png
110 | .. |psp_image| image:: https://github.com/qubvel/segmentation_models/blob/master/images/pspnet.png
111 | .. |fpn_image| image:: https://github.com/qubvel/segmentation_models/blob/master/images/fpn.png
112 |
113 | **Backbones**
114 |
115 | .. table::
116 |
117 | ============= =====
118 | Type          Names
119 | ============= =====
120 | VGG           ``'vgg16' 'vgg19'``
121 | ResNet        ``'resnet18' 'resnet34' 'resnet50' 'resnet101' 'resnet152'``
122 | SE-ResNet     ``'seresnet18' 'seresnet34' 'seresnet50' 'seresnet101' 'seresnet152'``
123 | ResNeXt       ``'resnext50' 'resnext101'``
124 | SE-ResNeXt    ``'seresnext50' 'seresnext101'``
125 | SENet154      ``'senet154'``
126 | DenseNet      ``'densenet121' 'densenet169' 'densenet201'``
127 | Inception     ``'inceptionv3' 'inceptionresnetv2'``
128 | MobileNet     ``'mobilenet' 'mobilenetv2'``
129 | EfficientNet  ``'efficientnetb0' 'efficientnetb1' 'efficientnetb2' 'efficientnetb3' 'efficientnetb4' 'efficientnetb5' 'efficientnetb6' 'efficientnetb7'``
130 | ============= =====
131 |
132 | .. epigraph::
133 | All backbones have weights trained on 2012 ILSVRC ImageNet dataset (``encoder_weights='imagenet'``).
134 |
135 |
136 | Fine tuning
137 | ~~~~~~~~~~~
138 |
139 | Sometimes it is useful to train only the randomly initialized
140 | *decoder*, in order not to damage the weights of the properly trained
141 | *encoder* with huge gradients during the first steps of training.
142 | In this case, all you need is to pass the ``encoder_freeze=True`` argument
143 | when initializing the model.
144 |
145 | .. code-block:: python
146 |
147 | from segmentation_models import Unet
148 | from segmentation_models.utils import set_trainable
149 |
150 | model = Unet(backbone_name='resnet34', encoder_weights='imagenet', encoder_freeze=True)
151 | model.compile('Adam', 'binary_crossentropy', ['binary_accuracy'])
152 |
153 | # pretrain model decoder
154 | model.fit(x, y, epochs=2)
155 |
156 | # release all layers for training
157 | set_trainable(model) # set all layers trainable and recompile model
158 |
159 | # continue training
160 | model.fit(x, y, epochs=100)
161 |
162 |
163 | Training with non-RGB data
164 | ~~~~~~~~~~~~~~~~~~~~~~~~~~
165 |
166 | In case you have non-RGB images (e.g. grayscale or some medical/remote sensing data)
167 | you have a few different options:
168 |
169 | 1. Train network from scratch with randomly initialized weights
170 |
171 | .. code-block:: python
172 |
173 | from segmentation_models import Unet
174 |
175 | # read/scale/preprocess data
176 | x, y = ...
177 |
178 | # define number of channels
179 | N = x.shape[-1]
180 |
181 | # define model
182 | model = Unet(backbone_name='resnet34', encoder_weights=None, input_shape=(None, None, N))
183 |
184 | # continue with usual steps: compile, fit, etc..
185 |
186 | 2. Add extra convolution layer to map ``N -> 3`` channels data and train with pretrained weights
187 |
188 | .. code-block:: python
189 |
190 | from segmentation_models import Unet
191 | from keras.layers import Input, Conv2D
192 | from keras.models import Model
193 |
194 | # read/scale/preprocess data
195 | x, y = ...
196 |
197 | # define number of channels
198 | N = x.shape[-1]
199 |
200 | base_model = Unet(backbone_name='resnet34', encoder_weights='imagenet')
201 |
202 | inp = Input(shape=(None, None, N))
203 | l1 = Conv2D(3, (1, 1))(inp) # map N channels data to 3 channels
204 | out = base_model(l1)
205 |
206 | model = Model(inp, out, name=base_model.name)
207 |
208 | # continue with usual steps: compile, fit, etc..
209 |
210 | .. _Image Segmentation:
211 | https://en.wikipedia.org/wiki/Image_segmentation
212 |
213 | .. _Tensorflow:
214 | https://www.tensorflow.org/
215 |
216 | .. _Keras:
217 | https://keras.io
218 |
219 | .. _Unet:
220 | https://arxiv.org/pdf/1505.04597
221 |
222 | .. _Linknet:
223 | https://arxiv.org/pdf/1707.03718.pdf
224 |
225 | .. _PSPNet:
226 | https://arxiv.org/pdf/1612.01105.pdf
227 |
228 | .. _FPN:
229 | http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf
230 |
--------------------------------------------------------------------------------
/images/fpn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qubvel/segmentation_models/5d24bbfb28af6134e25e2c0b79e7727f6c0491d0/images/fpn.png
--------------------------------------------------------------------------------
/images/linknet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qubvel/segmentation_models/5d24bbfb28af6134e25e2c0b79e7727f6c0491d0/images/linknet.png
--------------------------------------------------------------------------------
/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qubvel/segmentation_models/5d24bbfb28af6134e25e2c0b79e7727f6c0491d0/images/logo.png
--------------------------------------------------------------------------------
/images/pspnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qubvel/segmentation_models/5d24bbfb28af6134e25e2c0b79e7727f6c0491d0/images/pspnet.png
--------------------------------------------------------------------------------
/images/unet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qubvel/segmentation_models/5d24bbfb28af6134e25e2c0b79e7727f6c0491d0/images/unet.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | keras_applications>=1.0.7,<=1.0.8
2 | image-classifiers==1.0.0
3 | efficientnet==1.1.1
4 |
--------------------------------------------------------------------------------
/segmentation_models/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import functools
3 | from .__version__ import __version__
4 | from . import base
5 |
6 | _KERAS_FRAMEWORK_NAME = 'keras'
7 | _TF_KERAS_FRAMEWORK_NAME = 'tf.keras'
8 |
9 | _DEFAULT_KERAS_FRAMEWORK = _KERAS_FRAMEWORK_NAME
10 | _KERAS_FRAMEWORK = None
11 | _KERAS_BACKEND = None
12 | _KERAS_LAYERS = None
13 | _KERAS_MODELS = None
14 | _KERAS_UTILS = None
15 | _KERAS_LOSSES = None
16 |
17 |
18 | def inject_global_losses(func):  # inject the current keras `losses` module into `func` kwargs
19 | @functools.wraps(func)
20 | def wrapper(*args, **kwargs):
21 | kwargs['losses'] = _KERAS_LOSSES
22 | return func(*args, **kwargs)
23 |
24 | return wrapper
25 |
26 |
27 | def inject_global_submodules(func):  # inject current keras backend/layers/models/utils into `func` kwargs
28 | @functools.wraps(func)
29 | def wrapper(*args, **kwargs):
30 | kwargs['backend'] = _KERAS_BACKEND
31 | kwargs['layers'] = _KERAS_LAYERS
32 | kwargs['models'] = _KERAS_MODELS
33 | kwargs['utils'] = _KERAS_UTILS
34 | return func(*args, **kwargs)
35 |
36 | return wrapper
37 |
38 |
39 | def filter_kwargs(func):  # drop every kwarg except the keras submodule references
40 | @functools.wraps(func)
41 | def wrapper(*args, **kwargs):
42 | new_kwargs = {k: v for k, v in kwargs.items() if k in ['backend', 'layers', 'models', 'utils']}
43 | return func(*args, **new_kwargs)
44 |
45 | return wrapper
46 |
47 |
48 | def framework():
49 | """Return name of Segmentation Models framework"""
50 | return _KERAS_FRAMEWORK
51 |
52 |
53 | def set_framework(name):
54 | """Set framework for Segmentation Models
55 |
56 | Args:
57 | name (str): one of ``keras``, ``tf.keras``, case insensitive.
58 |
59 | Raises:
60 | ValueError: in case of incorrect framework name.
61 | ImportError: in case framework is not installed.
62 |
63 | """
64 | name = name.lower()
65 |
66 | if name == _KERAS_FRAMEWORK_NAME:
67 | import keras
68 | import efficientnet.keras # init custom objects
69 | elif name == _TF_KERAS_FRAMEWORK_NAME:
70 | from tensorflow import keras
71 | import efficientnet.tfkeras # init custom objects
72 | else:
73 | raise ValueError('Not correct module name `{}`, use `{}` or `{}`'.format(
74 | name, _KERAS_FRAMEWORK_NAME, _TF_KERAS_FRAMEWORK_NAME))
75 |
76 | global _KERAS_BACKEND, _KERAS_LAYERS, _KERAS_MODELS
77 | global _KERAS_UTILS, _KERAS_LOSSES, _KERAS_FRAMEWORK
78 |
79 | _KERAS_FRAMEWORK = name
80 | _KERAS_BACKEND = keras.backend
81 | _KERAS_LAYERS = keras.layers
82 | _KERAS_MODELS = keras.models
83 | _KERAS_UTILS = keras.utils
84 | _KERAS_LOSSES = keras.losses
85 |
86 | # allow losses/metrics get keras submodules
87 | base.KerasObject.set_submodules(
88 | backend=keras.backend,
89 | layers=keras.layers,
90 | models=keras.models,
91 | utils=keras.utils,
92 | )
93 |
94 |
95 | # set default framework
96 | _framework = os.environ.get('SM_FRAMEWORK', _DEFAULT_KERAS_FRAMEWORK)
97 | try:
98 | set_framework(_framework)
99 | except ImportError:
100 | other = _TF_KERAS_FRAMEWORK_NAME if _framework == _KERAS_FRAMEWORK_NAME else _KERAS_FRAMEWORK_NAME
101 | set_framework(other)
102 |
103 | print('Segmentation Models: using `{}` framework.'.format(_KERAS_FRAMEWORK))
104 |
105 | # import helper modules
106 | from . import losses
107 | from . import metrics
108 | from . import utils
109 |
110 | # wrap segmentation models with framework modules
111 | from .backbones.backbones_factory import Backbones
112 | from .models.unet import Unet as _Unet
113 | from .models.pspnet import PSPNet as _PSPNet
114 | from .models.linknet import Linknet as _Linknet
115 | from .models.fpn import FPN as _FPN
116 |
117 | Unet = inject_global_submodules(_Unet)
118 | PSPNet = inject_global_submodules(_PSPNet)
119 | Linknet = inject_global_submodules(_Linknet)
120 | FPN = inject_global_submodules(_FPN)
121 | get_available_backbone_names = Backbones.models_names
122 |
123 |
124 | def get_preprocessing(name):
125 | preprocess_input = Backbones.get_preprocessing(name)
126 | # add backend, models, layers, utils submodules in kwargs
127 | preprocess_input = inject_global_submodules(preprocess_input)
128 | # delete other kwargs
129 | # keras-applications preprocessing raise an error if something
130 | # except `backend`, `layers`, `models`, `utils` passed in kwargs
131 | preprocess_input = filter_kwargs(preprocess_input)
132 | return preprocess_input
133 |
134 |
135 | __all__ = [
136 | 'Unet', 'PSPNet', 'FPN', 'Linknet',
137 | 'set_framework', 'framework',
138 | 'get_preprocessing', 'get_available_backbone_names',
139 | 'losses', 'metrics', 'utils',
140 | '__version__',
141 | ]
142 |
--------------------------------------------------------------------------------
/segmentation_models/__version__.py:
--------------------------------------------------------------------------------
1 | VERSION = (1, 0, 1)
2 |
3 | __version__ = '.'.join(map(str, VERSION))
4 |
--------------------------------------------------------------------------------
/segmentation_models/backbones/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qubvel/segmentation_models/5d24bbfb28af6134e25e2c0b79e7727f6c0491d0/segmentation_models/backbones/__init__.py
--------------------------------------------------------------------------------
/segmentation_models/backbones/backbones_factory.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import efficientnet.model as eff
3 | from classification_models.models_factory import ModelsFactory
4 |
5 | from . import inception_resnet_v2 as irv2
6 | from . import inception_v3 as iv3
7 |
8 |
9 | class BackbonesFactory(ModelsFactory):
10 | _default_feature_layers = {
11 |
12 | # List of layers to take features from backbone in the following order:
13 | # (x16, x8, x4, x2, x1) - `x4` means that the features have 4 times smaller spatial
14 | # resolution (Height x Width) than the input image.
15 |
16 | # VGG
17 | 'vgg16': ('block5_conv3', 'block4_conv3', 'block3_conv3', 'block2_conv2', 'block1_conv2'),
18 | 'vgg19': ('block5_conv4', 'block4_conv4', 'block3_conv4', 'block2_conv2', 'block1_conv2'),
19 |
20 | # ResNets
21 | 'resnet18': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
22 | 'resnet34': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
23 | 'resnet50': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
24 | 'resnet101': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
25 | 'resnet152': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
26 |
27 | # ResNeXt
28 | 'resnext50': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
29 | 'resnext101': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
30 |
31 | # Inception
32 | 'inceptionv3': (228, 86, 16, 9),
33 | 'inceptionresnetv2': (594, 260, 16, 9),
34 |
35 | # DenseNet
36 | 'densenet121': (311, 139, 51, 4),
37 | 'densenet169': (367, 139, 51, 4),
38 | 'densenet201': (479, 139, 51, 4),
39 |
40 | # SE models
41 | 'seresnet18': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
42 | 'seresnet34': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
43 | 'seresnet50': (246, 136, 62, 4),
44 | 'seresnet101': (552, 136, 62, 4),
45 | 'seresnet152': (858, 208, 62, 4),
46 | 'seresnext50': (1078, 584, 254, 4),
47 | 'seresnext101': (2472, 584, 254, 4),
48 | 'senet154': (6884, 1625, 454, 12),
49 |
50 | # Mobile Nets
51 | 'mobilenet': ('conv_pw_11_relu', 'conv_pw_5_relu', 'conv_pw_3_relu', 'conv_pw_1_relu'),
52 | 'mobilenetv2': ('block_13_expand_relu', 'block_6_expand_relu', 'block_3_expand_relu',
53 | 'block_1_expand_relu'),
54 |
55 | # EfficientNets
56 | 'efficientnetb0': ('block6a_expand_activation', 'block4a_expand_activation',
57 | 'block3a_expand_activation', 'block2a_expand_activation'),
58 | 'efficientnetb1': ('block6a_expand_activation', 'block4a_expand_activation',
59 | 'block3a_expand_activation', 'block2a_expand_activation'),
60 | 'efficientnetb2': ('block6a_expand_activation', 'block4a_expand_activation',
61 | 'block3a_expand_activation', 'block2a_expand_activation'),
62 | 'efficientnetb3': ('block6a_expand_activation', 'block4a_expand_activation',
63 | 'block3a_expand_activation', 'block2a_expand_activation'),
64 | 'efficientnetb4': ('block6a_expand_activation', 'block4a_expand_activation',
65 | 'block3a_expand_activation', 'block2a_expand_activation'),
66 | 'efficientnetb5': ('block6a_expand_activation', 'block4a_expand_activation',
67 | 'block3a_expand_activation', 'block2a_expand_activation'),
68 | 'efficientnetb6': ('block6a_expand_activation', 'block4a_expand_activation',
69 | 'block3a_expand_activation', 'block2a_expand_activation'),
70 | 'efficientnetb7': ('block6a_expand_activation', 'block4a_expand_activation',
71 | 'block3a_expand_activation', 'block2a_expand_activation'),
72 |
73 | }
74 |
75 | _models_update = {
76 | 'inceptionresnetv2': [irv2.InceptionResNetV2, irv2.preprocess_input],
77 | 'inceptionv3': [iv3.InceptionV3, iv3.preprocess_input],
78 |
79 | 'efficientnetb0': [eff.EfficientNetB0, eff.preprocess_input],
80 | 'efficientnetb1': [eff.EfficientNetB1, eff.preprocess_input],
81 | 'efficientnetb2': [eff.EfficientNetB2, eff.preprocess_input],
82 | 'efficientnetb3': [eff.EfficientNetB3, eff.preprocess_input],
83 | 'efficientnetb4': [eff.EfficientNetB4, eff.preprocess_input],
84 | 'efficientnetb5': [eff.EfficientNetB5, eff.preprocess_input],
85 | 'efficientnetb6': [eff.EfficientNetB6, eff.preprocess_input],
86 | 'efficientnetb7': [eff.EfficientNetB7, eff.preprocess_input],
87 | }
88 |
89 | # currently not supported
90 | _models_delete = ['resnet50v2', 'resnet101v2', 'resnet152v2',
91 | 'nasnetlarge', 'nasnetmobile', 'xception']
92 |
93 | @property
94 | def models(self):
95 | all_models = copy.copy(self._models)
96 | all_models.update(self._models_update)
97 | for k in self._models_delete:
98 | del all_models[k]
99 | return all_models
100 |
101 | def get_backbone(self, name, *args, **kwargs):
102 | model_fn, _ = self.get(name)
103 | model = model_fn(*args, **kwargs)
104 | return model
105 |
106 | def get_feature_layers(self, name, n=5):
107 | return self._default_feature_layers[name][:n]
108 |
109 | def get_preprocessing(self, name):
110 | return self.get(name)[1]
111 |
112 |
113 | Backbones = BackbonesFactory()
114 |
--------------------------------------------------------------------------------
/segmentation_models/backbones/inception_resnet_v2.py:
--------------------------------------------------------------------------------
1 | """Inception-ResNet V2 model for Keras.
2 | Model naming and structure follows TF-slim implementation
3 | (which has some additional layers and different number of
4 | filters from the original arXiv paper):
5 | https://github.com/tensorflow/models/blob/master/research/slim/nets/inception_resnet_v2.py
6 | Pre-trained ImageNet weights are also converted from TF-slim,
7 | which can be found in:
8 | https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models
9 | # Reference
10 | - [Inception-v4, Inception-ResNet and the Impact of
11 | Residual Connections on Learning](https://arxiv.org/abs/1602.07261) (AAAI 2017)
12 | """
13 | from __future__ import absolute_import
14 | from __future__ import division
15 | from __future__ import print_function
16 |
17 | import os
18 |
19 | from keras_applications import imagenet_utils
20 | from keras_applications import get_submodules_from_kwargs
21 |
22 | BASE_WEIGHT_URL = ('https://github.com/fchollet/deep-learning-models/'
23 | 'releases/download/v0.7/')
24 |
25 | backend = None
26 | layers = None
27 | models = None
28 | keras_utils = None
29 |
30 |
31 | def preprocess_input(x, **kwargs):
32 | """Preprocesses a numpy array encoding a batch of images.
33 | # Arguments
34 | x: a 4D numpy array consists of RGB values within [0, 255].
35 | # Returns
36 | Preprocessed array.
37 | """
38 | return imagenet_utils.preprocess_input(x, mode='tf', **kwargs)
39 |
40 |
41 | def conv2d_bn(x,
42 | filters,
43 | kernel_size,
44 | strides=1,
45 | padding='same',
46 | activation='relu',
47 | use_bias=False,
48 | name=None):
49 | """Utility function to apply conv + BN.
50 | # Arguments
51 | x: input tensor.
52 | filters: filters in `Conv2D`.
53 | kernel_size: kernel size as in `Conv2D`.
54 | strides: strides in `Conv2D`.
55 | padding: padding mode in `Conv2D`.
56 | activation: activation in `Conv2D`.
57 | use_bias: whether to use a bias in `Conv2D`.
58 | name: name of the ops; will become `name + '_ac'` for the activation
59 | and `name + '_bn'` for the batch norm layer.
60 | # Returns
61 | Output tensor after applying `Conv2D` and `BatchNormalization`.
62 | """
63 | x = layers.Conv2D(filters,
64 | kernel_size,
65 | strides=strides,
66 | padding=padding,
67 | use_bias=use_bias,
68 | name=name)(x)
69 | if not use_bias:
70 | bn_axis = 1 if backend.image_data_format() == 'channels_first' else 3
71 | bn_name = None if name is None else name + '_bn'
72 | x = layers.BatchNormalization(axis=bn_axis,
73 | scale=False,
74 | name=bn_name)(x)
75 | if activation is not None:
76 | ac_name = None if name is None else name + '_ac'
77 | x = layers.Activation(activation, name=ac_name)(x)
78 | return x
79 |
80 |
81 | def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):
82 | """Adds a Inception-ResNet block.
83 | This function builds 3 types of Inception-ResNet blocks mentioned
84 | in the paper, controlled by the `block_type` argument (which is the
85 | block name used in the official TF-slim implementation):
86 | - Inception-ResNet-A: `block_type='block35'`
87 | - Inception-ResNet-B: `block_type='block17'`
88 | - Inception-ResNet-C: `block_type='block8'`
89 | # Arguments
90 | x: input tensor.
91 | scale: scaling factor to scale the residuals (i.e., the output of
92 | passing `x` through an inception module) before adding them
93 | to the shortcut branch.
94 | Let `r` be the output from the residual branch,
95 | the output of this block will be `x + scale * r`.
96 | block_type: `'block35'`, `'block17'` or `'block8'`, determines
97 | the network structure in the residual branch.
98 | block_idx: an `int` used for generating layer names.
99 | The Inception-ResNet blocks
100 | are repeated many times in this network.
101 | We use `block_idx` to identify
102 | each of the repetitions. For example,
103 | the first Inception-ResNet-A block
104 | will have `block_type='block35', block_idx=0`,
105 | and the layer names will have
106 | a common prefix `'block35_0'`.
107 | activation: activation function to use at the end of the block
108 | (see [activations](../activations.md)).
109 | When `activation=None`, no activation is applied
110 | (i.e., "linear" activation: `a(x) = x`).
111 | # Returns
112 | Output tensor for the block.
113 | # Raises
114 | ValueError: if `block_type` is not one of `'block35'`,
115 | `'block17'` or `'block8'`.
116 | """
117 | if block_type == 'block35':
118 | branch_0 = conv2d_bn(x, 32, 1)
119 | branch_1 = conv2d_bn(x, 32, 1)
120 | branch_1 = conv2d_bn(branch_1, 32, 3)
121 | branch_2 = conv2d_bn(x, 32, 1)
122 | branch_2 = conv2d_bn(branch_2, 48, 3)
123 | branch_2 = conv2d_bn(branch_2, 64, 3)
124 | branches = [branch_0, branch_1, branch_2]
125 | elif block_type == 'block17':
126 | branch_0 = conv2d_bn(x, 192, 1)
127 | branch_1 = conv2d_bn(x, 128, 1)
128 | branch_1 = conv2d_bn(branch_1, 160, [1, 7])
129 | branch_1 = conv2d_bn(branch_1, 192, [7, 1])
130 | branches = [branch_0, branch_1]
131 | elif block_type == 'block8':
132 | branch_0 = conv2d_bn(x, 192, 1)
133 | branch_1 = conv2d_bn(x, 192, 1)
134 | branch_1 = conv2d_bn(branch_1, 224, [1, 3])
135 | branch_1 = conv2d_bn(branch_1, 256, [3, 1])
136 | branches = [branch_0, branch_1]
137 | else:
138 | raise ValueError('Unknown Inception-ResNet block type. '
139 | 'Expects "block35", "block17" or "block8", '
140 | 'but got: ' + str(block_type))
141 |
142 | block_name = block_type + '_' + str(block_idx)
143 | channel_axis = 1 if backend.image_data_format() == 'channels_first' else 3
144 | mixed = layers.Concatenate(
145 | axis=channel_axis, name=block_name + '_mixed')(branches)
146 | up = conv2d_bn(mixed,
147 | backend.int_shape(x)[channel_axis],
148 | 1,
149 | activation=None,
150 | use_bias=True,
151 | name=block_name + '_conv')
152 |
153 | x = layers.Lambda(lambda inputs, scale: inputs[0] + inputs[1] * scale,
154 | output_shape=backend.int_shape(x)[1:],
155 | arguments={'scale': scale},
156 | name=block_name)([x, up])
157 | if activation is not None:
158 | x = layers.Activation(activation, name=block_name + '_ac')(x)
159 | return x
160 |
161 |
162 | def InceptionResNetV2(include_top=True,
163 | weights='imagenet',
164 | input_tensor=None,
165 | input_shape=None,
166 | pooling=None,
167 | classes=1000,
168 | **kwargs):
169 | """Instantiates the Inception-ResNet v2 architecture.
170 | Optionally loads weights pre-trained on ImageNet.
171 | Note that the data format convention used by the model is
172 | the one specified in your Keras config at `~/.keras/keras.json`.
173 | # Arguments
174 | include_top: whether to include the fully-connected
175 | layer at the top of the network.
176 | weights: one of `None` (random initialization),
177 | 'imagenet' (pre-training on ImageNet),
178 | or the path to the weights file to be loaded.
179 | input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
180 | to use as image input for the model.
181 |         input_shape: optional shape tuple, only to be specified
182 |             if `include_top` is `False` (otherwise the input shape
183 |             has to be `(299, 299, 3)` (with `'channels_last'` data format)
184 |             or `(3, 299, 299)` (with `'channels_first'` data format)).
185 |             It should have exactly 3 input channels, and width and
186 |             height should be no smaller than 32 (the `min_size` passed
187 |             to `_obtain_input_shape` below). E.g. `(150, 150, 3)` is valid.
188 | pooling: Optional pooling mode for feature extraction
189 | when `include_top` is `False`.
190 | - `None` means that the output of the model will be
191 | the 4D tensor output of the last convolutional block.
192 | - `'avg'` means that global average pooling
193 | will be applied to the output of the
194 | last convolutional block, and thus
195 | the output of the model will be a 2D tensor.
196 | - `'max'` means that global max pooling will be applied.
197 | classes: optional number of classes to classify images
198 | into, only to be specified if `include_top` is `True`, and
199 | if no `weights` argument is specified.
200 | # Returns
201 | A Keras `Model` instance.
202 | # Raises
203 | ValueError: in case of invalid argument for `weights`,
204 | or invalid input shape.
205 | """
206 | global backend, layers, models, keras_utils
207 | backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs)
208 |
209 | if not (weights in {'imagenet', None} or os.path.exists(weights)):
210 | raise ValueError('The `weights` argument should be either '
211 | '`None` (random initialization), `imagenet` '
212 | '(pre-training on ImageNet), '
213 | 'or the path to the weights file to be loaded.')
214 |
215 | if weights == 'imagenet' and include_top and classes != 1000:
216 | raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
217 | ' as true, `classes` should be 1000')
218 |
219 | # Determine proper input shape
220 | input_shape = imagenet_utils._obtain_input_shape(
221 | input_shape,
222 | default_size=299,
223 | min_size=32,
224 | data_format=backend.image_data_format(),
225 | require_flatten=include_top,
226 | weights=weights)
227 |
228 | if input_tensor is None:
229 | img_input = layers.Input(shape=input_shape)
230 | else:
231 | if not backend.is_keras_tensor(input_tensor):
232 | img_input = layers.Input(tensor=input_tensor, shape=input_shape)
233 | else:
234 | img_input = input_tensor
235 |
236 | # Stem block: 35 x 35 x 192
237 | x = conv2d_bn(img_input, 32, 3, strides=2, padding='same')
238 | x = conv2d_bn(x, 32, 3, padding='same')
239 | x = conv2d_bn(x, 64, 3, padding='same')
240 | x = layers.MaxPooling2D(3, strides=2, padding='same')(x)
241 | x = conv2d_bn(x, 80, 1, padding='same')
242 | x = conv2d_bn(x, 192, 3, padding='same')
243 | x = layers.MaxPooling2D(3, strides=2, padding='same')(x)
244 |
245 | # Mixed 5b (Inception-A block): 35 x 35 x 320
246 | branch_0 = conv2d_bn(x, 96, 1, padding='same')
247 | branch_1 = conv2d_bn(x, 48, 1, padding='same')
248 | branch_1 = conv2d_bn(branch_1, 64, 5, padding='same')
249 | branch_2 = conv2d_bn(x, 64, 1, padding='same')
250 | branch_2 = conv2d_bn(branch_2, 96, 3, padding='same')
251 | branch_2 = conv2d_bn(branch_2, 96, 3, padding='same')
252 | branch_pool = layers.AveragePooling2D(3, strides=1, padding='same')(x)
253 | branch_pool = conv2d_bn(branch_pool, 64, 1, padding='same')
254 | branches = [branch_0, branch_1, branch_2, branch_pool]
255 | channel_axis = 1 if backend.image_data_format() == 'channels_first' else 3
256 | x = layers.Concatenate(axis=channel_axis, name='mixed_5b')(branches)
257 |
258 | # 10x block35 (Inception-ResNet-A block): 35 x 35 x 320
259 | for block_idx in range(1, 11):
260 | x = inception_resnet_block(x,
261 | scale=0.17,
262 | block_type='block35',
263 | block_idx=block_idx)
264 |
265 | # Mixed 6a (Reduction-A block): 17 x 17 x 1088
266 | branch_0 = conv2d_bn(x, 384, 3, strides=2, padding='same')
267 | branch_1 = conv2d_bn(x, 256, 1, padding='same')
268 | branch_1 = conv2d_bn(branch_1, 256, 3, padding='same')
269 | branch_1 = conv2d_bn(branch_1, 384, 3, strides=2, padding='same')
270 | branch_pool = layers.MaxPooling2D(3, strides=2, padding='same')(x)
271 | branches = [branch_0, branch_1, branch_pool]
272 | x = layers.Concatenate(axis=channel_axis, name='mixed_6a')(branches)
273 |
274 | # 20x block17 (Inception-ResNet-B block): 17 x 17 x 1088
275 | for block_idx in range(1, 21):
276 | x = inception_resnet_block(x,
277 | scale=0.1,
278 | block_type='block17',
279 | block_idx=block_idx)
280 |
281 | # Mixed 7a (Reduction-B block): 8 x 8 x 2080
282 | branch_0 = conv2d_bn(x, 256, 1, padding='same')
283 | branch_0 = conv2d_bn(branch_0, 384, 3, strides=2, padding='same')
284 | branch_1 = conv2d_bn(x, 256, 1, padding='same')
285 | branch_1 = conv2d_bn(branch_1, 288, 3, strides=2, padding='same')
286 | branch_2 = conv2d_bn(x, 256, 1, padding='same')
287 | branch_2 = conv2d_bn(branch_2, 288, 3, padding='same')
288 | branch_2 = conv2d_bn(branch_2, 320, 3, strides=2, padding='same')
289 | branch_pool = layers.MaxPooling2D(3, strides=2, padding='same')(x)
290 | branches = [branch_0, branch_1, branch_2, branch_pool]
291 | x = layers.Concatenate(axis=channel_axis, name='mixed_7a')(branches)
292 |
293 | # 10x block8 (Inception-ResNet-C block): 8 x 8 x 2080
294 | for block_idx in range(1, 10):
295 | x = inception_resnet_block(x,
296 | scale=0.2,
297 | block_type='block8',
298 | block_idx=block_idx)
299 | x = inception_resnet_block(x,
300 | scale=1.,
301 | activation=None,
302 | block_type='block8',
303 | block_idx=10)
304 |
305 | # Final convolution block: 8 x 8 x 1536
306 | x = conv2d_bn(x, 1536, 1, name='conv_7b')
307 |
308 | if include_top:
309 | # Classification block
310 | x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
311 | x = layers.Dense(classes, activation='softmax', name='predictions')(x)
312 | else:
313 | if pooling == 'avg':
314 | x = layers.GlobalAveragePooling2D()(x)
315 | elif pooling == 'max':
316 | x = layers.GlobalMaxPooling2D()(x)
317 |
318 | # Ensure that the model takes into account
319 | # any potential predecessors of `input_tensor`.
320 | if input_tensor is not None:
321 | inputs = keras_utils.get_source_inputs(input_tensor)
322 | else:
323 | inputs = img_input
324 |
325 | # Create model.
326 | model = models.Model(inputs, x, name='inception_resnet_v2')
327 |
328 | # Load weights.
329 | if weights == 'imagenet':
330 | if include_top:
331 | fname = 'inception_resnet_v2_weights_tf_dim_ordering_tf_kernels.h5'
332 | weights_path = keras_utils.get_file(
333 | fname,
334 | BASE_WEIGHT_URL + fname,
335 | cache_subdir='models',
336 | file_hash='e693bd0210a403b3192acc6073ad2e96')
337 | else:
338 | fname = ('inception_resnet_v2_weights_'
339 | 'tf_dim_ordering_tf_kernels_notop.h5')
340 | weights_path = keras_utils.get_file(
341 | fname,
342 | BASE_WEIGHT_URL + fname,
343 | cache_subdir='models',
344 | file_hash='d19885ff4a710c122648d3b5c3b684e4')
345 | model.load_weights(weights_path)
346 | elif weights is not None:
347 | model.load_weights(weights)
348 |
349 | return model
350 |
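351 | # Minimal usage sketch: instantiating the backbone above as a headless
352 | # feature extractor. Assumes a plain `keras` installation to supply the
353 | # submodules that `get_submodules_from_kwargs` expects.
354 | if __name__ == '__main__':
355 |     import keras
356 | 
357 |     model = InceptionResNetV2(
358 |         include_top=False,
359 |         weights=None,
360 |         input_shape=(256, 256, 3),
361 |         pooling='avg',
362 |         backend=keras.backend,
363 |         layers=keras.layers,
364 |         models=keras.models,
365 |         utils=keras.utils,
366 |     )
367 |     model.summary()  # ends in a globally pooled 1536-d feature vector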
--------------------------------------------------------------------------------
/segmentation_models/backbones/inception_v3.py:
--------------------------------------------------------------------------------
1 | """Inception V3 model for Keras.
2 | Note that the input image format for this model is different from that of
3 | the VGG16 and ResNet models (299x299 instead of 224x224),
4 | and that the input preprocessing function is also different (same as Xception).
5 | # Reference
6 | - [Rethinking the Inception Architecture for Computer Vision](
7 | http://arxiv.org/abs/1512.00567) (CVPR 2016)
8 | """
9 | from __future__ import absolute_import
10 | from __future__ import division
11 | from __future__ import print_function
12 |
13 | import os
14 |
15 | from keras_applications import imagenet_utils
16 | from keras_applications import get_submodules_from_kwargs
17 |
18 | WEIGHTS_PATH = (
19 | 'https://github.com/fchollet/deep-learning-models/'
20 | 'releases/download/v0.5/'
21 | 'inception_v3_weights_tf_dim_ordering_tf_kernels.h5')
22 | WEIGHTS_PATH_NO_TOP = (
23 | 'https://github.com/fchollet/deep-learning-models/'
24 | 'releases/download/v0.5/'
25 | 'inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5')
26 |
27 | backend = None
28 | layers = None
29 | models = None
30 | keras_utils = None
31 |
32 |
33 | def conv2d_bn(x,
34 | filters,
35 | num_row,
36 | num_col,
37 | padding='same',
38 | strides=(1, 1),
39 | name=None):
40 | """Utility function to apply conv + BN.
41 | # Arguments
42 | x: input tensor.
43 | filters: filters in `Conv2D`.
44 | num_row: height of the convolution kernel.
45 | num_col: width of the convolution kernel.
46 | padding: padding mode in `Conv2D`.
47 | strides: strides in `Conv2D`.
48 | name: name of the ops; will become `name + '_conv'`
49 | for the convolution and `name + '_bn'` for the
50 | batch norm layer.
51 | # Returns
52 | Output tensor after applying `Conv2D` and `BatchNormalization`.
53 | """
54 | if name is not None:
55 | bn_name = name + '_bn'
56 | conv_name = name + '_conv'
57 | else:
58 | bn_name = None
59 | conv_name = None
60 | if backend.image_data_format() == 'channels_first':
61 | bn_axis = 1
62 | else:
63 | bn_axis = 3
64 | x = layers.Conv2D(
65 | filters, (num_row, num_col),
66 | strides=strides,
67 | padding=padding,
68 | use_bias=False,
69 | name=conv_name)(x)
70 | x = layers.BatchNormalization(axis=bn_axis, scale=False, name=bn_name)(x)
71 | x = layers.Activation('relu', name=name)(x)
72 | return x
73 |
74 |
75 | def InceptionV3(include_top=True,
76 | weights='imagenet',
77 | input_tensor=None,
78 | input_shape=None,
79 | pooling=None,
80 | classes=1000,
81 | **kwargs):
82 | """Instantiates the Inception v3 architecture.
83 | Optionally loads weights pre-trained on ImageNet.
84 | Note that the data format convention used by the model is
85 | the one specified in your Keras config at `~/.keras/keras.json`.
86 | # Arguments
87 | include_top: whether to include the fully-connected
88 | layer at the top of the network.
89 | weights: one of `None` (random initialization),
90 | 'imagenet' (pre-training on ImageNet),
91 | or the path to the weights file to be loaded.
92 | input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
93 | to use as image input for the model.
94 | input_shape: optional shape tuple, only to be specified
95 | if `include_top` is False (otherwise the input shape
96 | has to be `(299, 299, 3)` (with `channels_last` data format)
97 |             or `(3, 299, 299)` (with `channels_first` data format)).
98 |             It should have exactly 3 input channels,
99 | and width and height should be no smaller than 75.
100 | E.g. `(150, 150, 3)` would be one valid value.
101 | pooling: Optional pooling mode for feature extraction
102 | when `include_top` is `False`.
103 | - `None` means that the output of the model will be
104 | the 4D tensor output of the
105 | last convolutional block.
106 | - `avg` means that global average pooling
107 | will be applied to the output of the
108 | last convolutional block, and thus
109 | the output of the model will be a 2D tensor.
110 | - `max` means that global max pooling will
111 | be applied.
112 | classes: optional number of classes to classify images
113 | into, only to be specified if `include_top` is True, and
114 | if no `weights` argument is specified.
115 | # Returns
116 | A Keras model instance.
117 | # Raises
118 | ValueError: in case of invalid argument for `weights`,
119 | or invalid input shape.
120 | """
121 | global backend, layers, models, keras_utils
122 | backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs)
123 |
124 | if not (weights in {'imagenet', None} or os.path.exists(weights)):
125 | raise ValueError('The `weights` argument should be either '
126 | '`None` (random initialization), `imagenet` '
127 | '(pre-training on ImageNet), '
128 | 'or the path to the weights file to be loaded.')
129 |
130 | if weights == 'imagenet' and include_top and classes != 1000:
131 | raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
132 | ' as true, `classes` should be 1000')
133 |
134 | # Determine proper input shape
135 | input_shape = imagenet_utils._obtain_input_shape(
136 | input_shape,
137 | default_size=299,
138 | min_size=75,
139 | data_format=backend.image_data_format(),
140 | require_flatten=include_top,
141 | weights=weights)
142 |
143 | if input_tensor is None:
144 | img_input = layers.Input(shape=input_shape)
145 | else:
146 | if not backend.is_keras_tensor(input_tensor):
147 | img_input = layers.Input(tensor=input_tensor, shape=input_shape)
148 | else:
149 | img_input = input_tensor
150 |
151 | if backend.image_data_format() == 'channels_first':
152 | channel_axis = 1
153 | else:
154 | channel_axis = 3
155 |
156 | x = conv2d_bn(img_input, 32, 3, 3, strides=(2, 2), padding='same')
157 | x = conv2d_bn(x, 32, 3, 3, padding='same')
158 | x = conv2d_bn(x, 64, 3, 3, padding='same')
159 | x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
160 |
161 | x = conv2d_bn(x, 80, 1, 1, padding='same')
162 | x = conv2d_bn(x, 192, 3, 3, padding='same')
163 | x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
164 |
165 | # mixed 0: 35 x 35 x 256
166 | branch1x1 = conv2d_bn(x, 64, 1, 1)
167 |
168 | branch5x5 = conv2d_bn(x, 48, 1, 1)
169 | branch5x5 = conv2d_bn(branch5x5, 64, 5, 5)
170 |
171 | branch3x3dbl = conv2d_bn(x, 64, 1, 1)
172 | branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)
173 | branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)
174 |
175 | branch_pool = layers.AveragePooling2D((3, 3),
176 | strides=(1, 1),
177 | padding='same')(x)
178 | branch_pool = conv2d_bn(branch_pool, 32, 1, 1)
179 | x = layers.concatenate(
180 | [branch1x1, branch5x5, branch3x3dbl, branch_pool],
181 | axis=channel_axis,
182 | name='mixed0')
183 |
184 | # mixed 1: 35 x 35 x 288
185 | branch1x1 = conv2d_bn(x, 64, 1, 1)
186 |
187 | branch5x5 = conv2d_bn(x, 48, 1, 1)
188 | branch5x5 = conv2d_bn(branch5x5, 64, 5, 5)
189 |
190 | branch3x3dbl = conv2d_bn(x, 64, 1, 1)
191 | branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)
192 | branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)
193 |
194 | branch_pool = layers.AveragePooling2D((3, 3),
195 | strides=(1, 1),
196 | padding='same')(x)
197 | branch_pool = conv2d_bn(branch_pool, 64, 1, 1)
198 | x = layers.concatenate(
199 | [branch1x1, branch5x5, branch3x3dbl, branch_pool],
200 | axis=channel_axis,
201 | name='mixed1')
202 |
203 | # mixed 2: 35 x 35 x 288
204 | branch1x1 = conv2d_bn(x, 64, 1, 1)
205 |
206 | branch5x5 = conv2d_bn(x, 48, 1, 1)
207 | branch5x5 = conv2d_bn(branch5x5, 64, 5, 5)
208 |
209 | branch3x3dbl = conv2d_bn(x, 64, 1, 1)
210 | branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)
211 | branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)
212 |
213 | branch_pool = layers.AveragePooling2D((3, 3),
214 | strides=(1, 1),
215 | padding='same')(x)
216 | branch_pool = conv2d_bn(branch_pool, 64, 1, 1)
217 | x = layers.concatenate(
218 | [branch1x1, branch5x5, branch3x3dbl, branch_pool],
219 | axis=channel_axis,
220 | name='mixed2')
221 |
222 | # mixed 3: 17 x 17 x 768
223 | branch3x3 = conv2d_bn(x, 384, 3, 3, strides=(2, 2), padding='same')
224 |
225 | branch3x3dbl = conv2d_bn(x, 64, 1, 1)
226 | branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)
227 | branch3x3dbl = conv2d_bn(
228 | branch3x3dbl, 96, 3, 3, strides=(2, 2), padding='same')
229 |
230 | branch_pool = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
231 | x = layers.concatenate(
232 | [branch3x3, branch3x3dbl, branch_pool],
233 | axis=channel_axis,
234 | name='mixed3')
235 |
236 | # mixed 4: 17 x 17 x 768
237 | branch1x1 = conv2d_bn(x, 192, 1, 1)
238 |
239 | branch7x7 = conv2d_bn(x, 128, 1, 1)
240 | branch7x7 = conv2d_bn(branch7x7, 128, 1, 7)
241 | branch7x7 = conv2d_bn(branch7x7, 192, 7, 1)
242 |
243 | branch7x7dbl = conv2d_bn(x, 128, 1, 1)
244 | branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 7, 1)
245 | branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 1, 7)
246 | branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 7, 1)
247 | branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7)
248 |
249 | branch_pool = layers.AveragePooling2D((3, 3),
250 | strides=(1, 1),
251 | padding='same')(x)
252 | branch_pool = conv2d_bn(branch_pool, 192, 1, 1)
253 | x = layers.concatenate(
254 | [branch1x1, branch7x7, branch7x7dbl, branch_pool],
255 | axis=channel_axis,
256 | name='mixed4')
257 |
258 | # mixed 5, 6: 17 x 17 x 768
259 | for i in range(2):
260 | branch1x1 = conv2d_bn(x, 192, 1, 1)
261 |
262 | branch7x7 = conv2d_bn(x, 160, 1, 1)
263 | branch7x7 = conv2d_bn(branch7x7, 160, 1, 7)
264 | branch7x7 = conv2d_bn(branch7x7, 192, 7, 1)
265 |
266 | branch7x7dbl = conv2d_bn(x, 160, 1, 1)
267 | branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 7, 1)
268 | branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 1, 7)
269 | branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 7, 1)
270 | branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7)
271 |
272 | branch_pool = layers.AveragePooling2D(
273 | (3, 3), strides=(1, 1), padding='same')(x)
274 | branch_pool = conv2d_bn(branch_pool, 192, 1, 1)
275 | x = layers.concatenate(
276 | [branch1x1, branch7x7, branch7x7dbl, branch_pool],
277 | axis=channel_axis,
278 | name='mixed' + str(5 + i))
279 |
280 | # mixed 7: 17 x 17 x 768
281 | branch1x1 = conv2d_bn(x, 192, 1, 1)
282 |
283 | branch7x7 = conv2d_bn(x, 192, 1, 1)
284 | branch7x7 = conv2d_bn(branch7x7, 192, 1, 7)
285 | branch7x7 = conv2d_bn(branch7x7, 192, 7, 1)
286 |
287 | branch7x7dbl = conv2d_bn(x, 192, 1, 1)
288 | branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 7, 1)
289 | branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7)
290 | branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 7, 1)
291 | branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7)
292 |
293 | branch_pool = layers.AveragePooling2D((3, 3),
294 | strides=(1, 1),
295 | padding='same')(x)
296 | branch_pool = conv2d_bn(branch_pool, 192, 1, 1)
297 | x = layers.concatenate(
298 | [branch1x1, branch7x7, branch7x7dbl, branch_pool],
299 | axis=channel_axis,
300 | name='mixed7')
301 |
302 | # mixed 8: 8 x 8 x 1280
303 | branch3x3 = conv2d_bn(x, 192, 1, 1)
304 | branch3x3 = conv2d_bn(branch3x3, 320, 3, 3,
305 | strides=(2, 2), padding='same')
306 |
307 | branch7x7x3 = conv2d_bn(x, 192, 1, 1)
308 | branch7x7x3 = conv2d_bn(branch7x7x3, 192, 1, 7)
309 | branch7x7x3 = conv2d_bn(branch7x7x3, 192, 7, 1)
310 | branch7x7x3 = conv2d_bn(
311 | branch7x7x3, 192, 3, 3, strides=(2, 2), padding='same')
312 |
313 | branch_pool = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
314 | x = layers.concatenate(
315 | [branch3x3, branch7x7x3, branch_pool],
316 | axis=channel_axis,
317 | name='mixed8')
318 |
319 | # mixed 9: 8 x 8 x 2048
320 | for i in range(2):
321 | branch1x1 = conv2d_bn(x, 320, 1, 1)
322 |
323 | branch3x3 = conv2d_bn(x, 384, 1, 1)
324 | branch3x3_1 = conv2d_bn(branch3x3, 384, 1, 3)
325 | branch3x3_2 = conv2d_bn(branch3x3, 384, 3, 1)
326 | branch3x3 = layers.concatenate(
327 | [branch3x3_1, branch3x3_2],
328 | axis=channel_axis,
329 | name='mixed9_' + str(i))
330 |
331 | branch3x3dbl = conv2d_bn(x, 448, 1, 1)
332 | branch3x3dbl = conv2d_bn(branch3x3dbl, 384, 3, 3)
333 | branch3x3dbl_1 = conv2d_bn(branch3x3dbl, 384, 1, 3)
334 | branch3x3dbl_2 = conv2d_bn(branch3x3dbl, 384, 3, 1)
335 | branch3x3dbl = layers.concatenate(
336 | [branch3x3dbl_1, branch3x3dbl_2], axis=channel_axis)
337 |
338 | branch_pool = layers.AveragePooling2D(
339 | (3, 3), strides=(1, 1), padding='same')(x)
340 | branch_pool = conv2d_bn(branch_pool, 192, 1, 1)
341 | x = layers.concatenate(
342 | [branch1x1, branch3x3, branch3x3dbl, branch_pool],
343 | axis=channel_axis,
344 | name='mixed' + str(9 + i))
345 | if include_top:
346 | # Classification block
347 | x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
348 | x = layers.Dense(classes, activation='softmax', name='predictions')(x)
349 | else:
350 | if pooling == 'avg':
351 | x = layers.GlobalAveragePooling2D()(x)
352 | elif pooling == 'max':
353 | x = layers.GlobalMaxPooling2D()(x)
354 |
355 | # Ensure that the model takes into account
356 | # any potential predecessors of `input_tensor`.
357 | if input_tensor is not None:
358 | inputs = keras_utils.get_source_inputs(input_tensor)
359 | else:
360 | inputs = img_input
361 | # Create model.
362 | model = models.Model(inputs, x, name='inception_v3')
363 |
364 | # Load weights.
365 | if weights == 'imagenet':
366 | if include_top:
367 | weights_path = keras_utils.get_file(
368 | 'inception_v3_weights_tf_dim_ordering_tf_kernels.h5',
369 | WEIGHTS_PATH,
370 | cache_subdir='models',
371 | file_hash='9a0d58056eeedaa3f26cb7ebd46da564')
372 | else:
373 | weights_path = keras_utils.get_file(
374 | 'inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5',
375 | WEIGHTS_PATH_NO_TOP,
376 | cache_subdir='models',
377 | file_hash='bcbd6486424b2319ff4ef7d526e38f63')
378 | model.load_weights(weights_path)
379 | elif weights is not None:
380 | model.load_weights(weights)
381 |
382 | return model
383 |
384 |
385 | def preprocess_input(x, **kwargs):
386 | """Preprocesses a numpy array encoding a batch of images.
387 | # Arguments
388 |         x: a 4D numpy array consisting of RGB values within [0, 255].
389 | # Returns
390 | Preprocessed array.
391 | """
392 | return imagenet_utils.preprocess_input(x, mode='tf', **kwargs)
393 |
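394 | # Numeric sketch of the preprocessing contract: mode='tf' rescales RGB
395 | # values from [0, 255] to [-1, 1] via x / 127.5 - 1. Passing an explicit
396 | # `data_format` and a float input keeps this call independent of any
397 | # configured Keras backend; `numpy` is assumed to be available.
398 | if __name__ == '__main__':
399 |     import numpy as np
400 | 
401 |     batch = np.random.randint(0, 256, (2, 299, 299, 3)).astype('float32')
402 |     out = preprocess_input(batch, data_format='channels_last')
403 |     assert out.min() >= -1.0 and out.max() <= 1.0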
--------------------------------------------------------------------------------
/segmentation_models/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .objects import KerasObject, Loss, Metric
2 | from . import functional
--------------------------------------------------------------------------------
/segmentation_models/base/functional.py:
--------------------------------------------------------------------------------
1 | SMOOTH = 1e-5
2 |
3 |
4 | # ----------------------------------------------------------------
5 | # Helpers
6 | # ----------------------------------------------------------------
7 |
8 | def _gather_channels(x, indexes, **kwargs):
9 | """Slice tensor along channels axis by given indexes"""
10 | backend = kwargs['backend']
11 | if backend.image_data_format() == 'channels_last':
12 | x = backend.permute_dimensions(x, (3, 0, 1, 2))
13 | x = backend.gather(x, indexes)
14 | x = backend.permute_dimensions(x, (1, 2, 3, 0))
15 | else:
16 | x = backend.permute_dimensions(x, (1, 0, 2, 3))
17 | x = backend.gather(x, indexes)
18 | x = backend.permute_dimensions(x, (1, 0, 2, 3))
19 | return x
20 |
21 |
22 | def get_reduce_axes(per_image, **kwargs):
23 | backend = kwargs['backend']
24 | axes = [1, 2] if backend.image_data_format() == 'channels_last' else [2, 3]
25 | if not per_image:
26 | axes.insert(0, 0)
27 | return axes
28 |
29 |
30 | def gather_channels(*xs, indexes=None, **kwargs):
31 | """Slice tensors along channels axis by given indexes"""
32 | if indexes is None:
33 | return xs
34 |     elif isinstance(indexes, int):
35 | indexes = [indexes]
36 | xs = [_gather_channels(x, indexes=indexes, **kwargs) for x in xs]
37 | return xs
38 |
39 |
40 | def round_if_needed(x, threshold, **kwargs):
41 | backend = kwargs['backend']
42 | if threshold is not None:
43 | x = backend.greater(x, threshold)
44 | x = backend.cast(x, backend.floatx())
45 | return x
46 |
47 |
48 | def average(x, per_image=False, class_weights=None, **kwargs):
49 | backend = kwargs['backend']
50 | if per_image:
51 | x = backend.mean(x, axis=0)
52 | if class_weights is not None:
53 | x = x * class_weights
54 | return backend.mean(x)
55 |
56 |
57 | # ----------------------------------------------------------------
58 | # Metric Functions
59 | # ----------------------------------------------------------------
60 |
61 | def iou_score(gt, pr, class_weights=1., class_indexes=None, smooth=SMOOTH, per_image=False, threshold=None, **kwargs):
62 | r""" The `Jaccard index`_, also known as Intersection over Union and the Jaccard similarity coefficient
63 | (originally coined coefficient de communauté by Paul Jaccard), is a statistic used for comparing the
64 | similarity and diversity of sample sets. The Jaccard coefficient measures similarity between finite sample sets,
65 | and is defined as the size of the intersection divided by the size of the union of the sample sets:
66 |
67 |     .. math:: J(A, B) = \frac{|A \cap B|}{|A \cup B|}
68 |
69 | Args:
70 | gt: ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W)
71 | pr: prediction 4D keras tensor (B, H, W, C) or (B, C, H, W)
72 | class_weights: 1. or list of class weights, len(weights) = C
73 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used.
74 | smooth: value to avoid division by zero
75 | per_image: if ``True``, metric is calculated as mean over images in batch (B),
76 | else over whole batch
77 |         threshold: value to round predictions (use ``>`` comparison), if ``None`` predictions are not rounded
78 |
79 | Returns:
80 | IoU/Jaccard score in range [0, 1]
81 |
82 | .. _`Jaccard index`: https://en.wikipedia.org/wiki/Jaccard_index
83 |
84 | """
85 |
86 | backend = kwargs['backend']
87 |
88 | gt, pr = gather_channels(gt, pr, indexes=class_indexes, **kwargs)
89 | pr = round_if_needed(pr, threshold, **kwargs)
90 | axes = get_reduce_axes(per_image, **kwargs)
91 |
92 | # score calculation
93 | intersection = backend.sum(gt * pr, axis=axes)
94 | union = backend.sum(gt + pr, axis=axes) - intersection
95 |
96 | score = (intersection + smooth) / (union + smooth)
97 | score = average(score, per_image, class_weights, **kwargs)
98 |
99 | return score
100 |
101 |
102 | def f_score(gt, pr, beta=1, class_weights=1, class_indexes=None, smooth=SMOOTH, per_image=False, threshold=None,
103 | **kwargs):
104 | r"""The F-score (Dice coefficient) can be interpreted as a weighted average of the precision and recall,
105 | where an F-score reaches its best value at 1 and worst score at 0.
106 |     The relative contribution of ``precision`` and ``recall`` to the F1-score is equal.
107 | The formula for the F score is:
108 |
109 | .. math:: F_\beta(precision, recall) = (1 + \beta^2) \frac{precision \cdot recall}
110 | {\beta^2 \cdot precision + recall}
111 |
112 | The formula in terms of *Type I* and *Type II* errors:
113 |
114 | .. math:: F_\beta(A, B) = \frac{(1 + \beta^2) TP} {(1 + \beta^2) TP + \beta^2 FN + FP}
115 |
116 |
117 |     where:
118 |         - TP - true positives;
119 |         - FP - false positives;
120 |         - FN - false negatives;
121 |
122 | Args:
123 | gt: ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W)
124 | pr: prediction 4D keras tensor (B, H, W, C) or (B, C, H, W)
125 | class_weights: 1. or list of class weights, len(weights) = C
126 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used.
127 | beta: f-score coefficient
128 | smooth: value to avoid division by zero
129 | per_image: if ``True``, metric is calculated as mean over images in batch (B),
130 | else over whole batch
131 |         threshold: value to round predictions (use ``>`` comparison), if ``None`` predictions are not rounded
132 |
133 | Returns:
134 | F-score in range [0, 1]
135 |
136 | """
137 |
138 | backend = kwargs['backend']
139 |
140 | gt, pr = gather_channels(gt, pr, indexes=class_indexes, **kwargs)
141 | pr = round_if_needed(pr, threshold, **kwargs)
142 | axes = get_reduce_axes(per_image, **kwargs)
143 |
144 | # calculate score
145 | tp = backend.sum(gt * pr, axis=axes)
146 | fp = backend.sum(pr, axis=axes) - tp
147 | fn = backend.sum(gt, axis=axes) - tp
148 |
149 | score = ((1 + beta ** 2) * tp + smooth) \
150 | / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
151 | score = average(score, per_image, class_weights, **kwargs)
152 |
153 | return score
154 |
155 |
156 | def precision(gt, pr, class_weights=1, class_indexes=None, smooth=SMOOTH, per_image=False, threshold=None, **kwargs):
157 | r"""Calculate precision between the ground truth (gt) and the prediction (pr).
158 |
159 |     .. math:: precision(tp, fp) = \frac{tp}{tp + fp}
160 |
161 | where:
162 | - tp - true positives;
163 | - fp - false positives;
164 |
165 | Args:
166 | gt: ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W)
167 | pr: prediction 4D keras tensor (B, H, W, C) or (B, C, H, W)
168 | class_weights: 1. or ``np.array`` of class weights (``len(weights) = num_classes``)
169 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used.
170 | smooth: Float value to avoid division by zero.
171 | per_image: If ``True``, metric is calculated as mean over images in batch (B),
172 | else over whole batch.
173 |         threshold: Float value to round predictions (use ``>`` comparison),
174 |             if ``None`` predictions are not rounded.
175 |
176 | Returns:
177 | float: precision score
178 | """
179 | backend = kwargs['backend']
180 |
181 | gt, pr = gather_channels(gt, pr, indexes=class_indexes, **kwargs)
182 | pr = round_if_needed(pr, threshold, **kwargs)
183 | axes = get_reduce_axes(per_image, **kwargs)
184 |
185 | # score calculation
186 | tp = backend.sum(gt * pr, axis=axes)
187 | fp = backend.sum(pr, axis=axes) - tp
188 |
189 | score = (tp + smooth) / (tp + fp + smooth)
190 | score = average(score, per_image, class_weights, **kwargs)
191 |
192 | return score
193 |
194 |
195 | def recall(gt, pr, class_weights=1, class_indexes=None, smooth=SMOOTH, per_image=False, threshold=None, **kwargs):
196 | r"""Calculate recall between the ground truth (gt) and the prediction (pr).
197 |
198 |     .. math:: recall(tp, fn) = \frac{tp}{tp + fn}
199 |
200 | where:
201 | - tp - true positives;
202 |         - fn - false negatives;
203 |
204 | Args:
205 | gt: ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W)
206 | pr: prediction 4D keras tensor (B, H, W, C) or (B, C, H, W)
207 | class_weights: 1. or ``np.array`` of class weights (``len(weights) = num_classes``)
208 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used.
209 | smooth: Float value to avoid division by zero.
210 | per_image: If ``True``, metric is calculated as mean over images in batch (B),
211 | else over whole batch.
212 |         threshold: Float value to round predictions (use ``>`` comparison),
213 |             if ``None`` predictions are not rounded.
214 |
215 | Returns:
216 | float: recall score
217 | """
218 | backend = kwargs['backend']
219 |
220 | gt, pr = gather_channels(gt, pr, indexes=class_indexes, **kwargs)
221 | pr = round_if_needed(pr, threshold, **kwargs)
222 | axes = get_reduce_axes(per_image, **kwargs)
223 |
224 | tp = backend.sum(gt * pr, axis=axes)
225 | fn = backend.sum(gt, axis=axes) - tp
226 |
227 | score = (tp + smooth) / (tp + fn + smooth)
228 | score = average(score, per_image, class_weights, **kwargs)
229 |
230 | return score
231 |
232 |
233 | # ----------------------------------------------------------------
234 | # Loss Functions
235 | # ----------------------------------------------------------------
236 |
237 | def categorical_crossentropy(gt, pr, class_weights=1., class_indexes=None, **kwargs):
238 | backend = kwargs['backend']
239 |
240 | gt, pr = gather_channels(gt, pr, indexes=class_indexes, **kwargs)
241 |
242 | # scale predictions so that the class probas of each sample sum to 1
243 | axis = 3 if backend.image_data_format() == 'channels_last' else 1
244 | pr /= backend.sum(pr, axis=axis, keepdims=True)
245 |
246 | # clip to prevent NaN's and Inf's
247 | pr = backend.clip(pr, backend.epsilon(), 1 - backend.epsilon())
248 |
249 | # calculate loss
250 | output = gt * backend.log(pr) * class_weights
251 | return - backend.mean(output)
252 |
253 |
254 | def binary_crossentropy(gt, pr, **kwargs):
255 | backend = kwargs['backend']
256 | return backend.mean(backend.binary_crossentropy(gt, pr))
257 |
258 |
259 | def categorical_focal_loss(gt, pr, gamma=2.0, alpha=0.25, class_indexes=None, **kwargs):
260 | r"""Implementation of Focal Loss from the paper in multiclass classification
261 |
262 | Formula:
263 | loss = - gt * alpha * ((1 - pr)^gamma) * log(pr)
264 |
265 | Args:
266 | gt: ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W)
267 | pr: prediction 4D keras tensor (B, H, W, C) or (B, C, H, W)
268 | alpha: the same as weighting factor in balanced cross entropy, default 0.25
269 | gamma: focusing parameter for modulating factor (1-p), default 2.0
270 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used.
271 |
272 | """
273 |
274 | backend = kwargs['backend']
275 | gt, pr = gather_channels(gt, pr, indexes=class_indexes, **kwargs)
276 |
277 | # clip to prevent NaN's and Inf's
278 | pr = backend.clip(pr, backend.epsilon(), 1.0 - backend.epsilon())
279 |
280 | # Calculate focal loss
281 | loss = - gt * (alpha * backend.pow((1 - pr), gamma) * backend.log(pr))
282 |
283 | return backend.mean(loss)
284 |
285 |
286 | def binary_focal_loss(gt, pr, gamma=2.0, alpha=0.25, **kwargs):
287 | r"""Implementation of Focal Loss from the paper in binary classification
288 |
289 | Formula:
290 | loss = - gt * alpha * ((1 - pr)^gamma) * log(pr) \
291 |                - (1 - gt) * (1 - alpha) * (pr^gamma) * log(1 - pr)
292 |
293 | Args:
294 | gt: ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W)
295 | pr: prediction 4D keras tensor (B, H, W, C) or (B, C, H, W)
296 | alpha: the same as weighting factor in balanced cross entropy, default 0.25
297 | gamma: focusing parameter for modulating factor (1-p), default 2.0
298 |
299 | """
300 | backend = kwargs['backend']
301 |
302 | # clip to prevent NaN's and Inf's
303 | pr = backend.clip(pr, backend.epsilon(), 1.0 - backend.epsilon())
304 |
305 | loss_1 = - gt * (alpha * backend.pow((1 - pr), gamma) * backend.log(pr))
306 | loss_0 = - (1 - gt) * ((1 - alpha) * backend.pow((pr), gamma) * backend.log(1 - pr))
307 | loss = backend.mean(loss_0 + loss_1)
308 | return loss
309 |
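310 | # Numeric sanity sketch (assumes `keras` with a TensorFlow backend and
311 | # `numpy` are installed): a perfect prediction scores ~1.0 on IoU and F1
312 | # thanks to the SMOOTH term, a fully disjoint one scores ~0.0, and the
313 | # focal loss of a perfect prediction is ~0.0.
314 | if __name__ == '__main__':
315 |     import numpy as np
316 |     import keras.backend as K
317 | 
318 |     gt = K.constant(np.ones((1, 4, 4, 1), dtype='float32'))
319 |     pr = K.constant(np.ones((1, 4, 4, 1), dtype='float32'))
320 |     zero = K.constant(np.zeros((1, 4, 4, 1), dtype='float32'))
321 |     print(K.eval(iou_score(gt, pr, backend=K)))          # ~1.0
322 |     print(K.eval(f_score(gt, pr, backend=K)))            # ~1.0
323 |     print(K.eval(iou_score(gt, zero, backend=K)))        # ~0.0
324 |     print(K.eval(binary_focal_loss(gt, pr, backend=K)))  # ~0.0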
--------------------------------------------------------------------------------
/segmentation_models/base/objects.py:
--------------------------------------------------------------------------------
1 | class KerasObject:
2 | _backend = None
3 | _models = None
4 | _layers = None
5 | _utils = None
6 |
7 | def __init__(self, name=None):
8 | if (self.backend is None or
9 | self.utils is None or
10 | self.models is None or
11 | self.layers is None):
12 |             raise RuntimeError('You cannot use `KerasObject` before submodules are set; call `KerasObject.set_submodules(...)` first.')
13 |
14 | self._name = name
15 |
16 | @property
17 | def __name__(self):
18 | if self._name is None:
19 | return self.__class__.__name__
20 | return self._name
21 |
22 | @property
23 | def name(self):
24 | return self.__name__
25 |
26 | @name.setter
27 | def name(self, name):
28 | self._name = name
29 |
30 | @classmethod
31 | def set_submodules(cls, backend, layers, models, utils):
32 | cls._backend = backend
33 | cls._layers = layers
34 | cls._models = models
35 | cls._utils = utils
36 |
37 | @property
38 | def submodules(self):
39 | return {
40 | 'backend': self.backend,
41 | 'layers': self.layers,
42 | 'models': self.models,
43 | 'utils': self.utils,
44 | }
45 |
46 | @property
47 | def backend(self):
48 | return self._backend
49 |
50 | @property
51 | def layers(self):
52 | return self._layers
53 |
54 | @property
55 | def models(self):
56 | return self._models
57 |
58 | @property
59 | def utils(self):
60 | return self._utils
61 |
62 |
63 | class Metric(KerasObject):
64 | pass
65 |
66 |
67 | class Loss(KerasObject):
68 |
69 | def __add__(self, other):
70 | if isinstance(other, Loss):
71 | return SumOfLosses(self, other)
72 | else:
73 |             raise ValueError('Loss can only be added to another `Loss` instance')
74 |
75 | def __radd__(self, other):
76 | return self.__add__(other)
77 |
78 | def __mul__(self, value):
79 | if isinstance(value, (int, float)):
80 | return MultipliedLoss(self, value)
81 | else:
82 |             raise ValueError('Loss can only be multiplied by an `int` or `float` value')
83 |
84 | def __rmul__(self, other):
85 | return self.__mul__(other)
86 |
87 |
88 | class MultipliedLoss(Loss):
89 |
90 | def __init__(self, loss, multiplier):
91 |
92 | # resolve name
93 | if len(loss.__name__.split('+')) > 1:
94 | name = '{}({})'.format(multiplier, loss.__name__)
95 | else:
96 | name = '{}{}'.format(multiplier, loss.__name__)
97 | super().__init__(name=name)
98 | self.loss = loss
99 | self.multiplier = multiplier
100 |
101 | def __call__(self, gt, pr):
102 | return self.multiplier * self.loss(gt, pr)
103 |
104 |
105 | class SumOfLosses(Loss):
106 |
107 | def __init__(self, l1, l2):
108 | name = '{}_plus_{}'.format(l1.__name__, l2.__name__)
109 | super().__init__(name=name)
110 | self.l1 = l1
111 | self.l2 = l2
112 |
113 | def __call__(self, gt, pr):
114 | return self.l1(gt, pr) + self.l2(gt, pr)
115 |
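116 | # Sketch of the operator algebra above (assumes `keras` is importable):
117 | # `__add__` builds a `SumOfLosses` and `__mul__`/`__rmul__` a
118 | # `MultipliedLoss`, so losses compose as ordinary expressions. `L1Loss`
119 | # below is a made-up example loss, not part of the library.
120 | if __name__ == '__main__':
121 |     import keras
122 | 
123 |     KerasObject.set_submodules(keras.backend, keras.layers,
124 |                                keras.models, keras.utils)
125 | 
126 |     class L1Loss(Loss):
127 |         def __call__(self, gt, pr):
128 |             return self.backend.mean(self.backend.abs(gt - pr))
129 | 
130 |     total = 0.5 * L1Loss(name='l1') + L1Loss(name='l1_other')
131 |     print(total.name)  # '0.5l1_plus_l1_other'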
--------------------------------------------------------------------------------
/segmentation_models/losses.py:
--------------------------------------------------------------------------------
1 | from .base import Loss
2 | from .base import functional as F
3 |
4 | SMOOTH = 1e-5
5 |
6 |
7 | class JaccardLoss(Loss):
8 | r"""Creates a criterion to measure Jaccard loss:
9 |
10 |     .. math:: L(A, B) = 1 - \frac{|A \cap B|}{|A \cup B|}
11 |
12 | Args:
13 | class_weights: Array (``np.array``) of class weights (``len(weights) = num_classes``).
14 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used.
15 | per_image: If ``True`` loss is calculated for each image in batch and then averaged,
16 | else loss is calculated for the whole batch.
17 | smooth: Value to avoid division by zero.
18 |
19 | Returns:
20 | A callable ``jaccard_loss`` instance. Can be used in ``model.compile(...)`` function
21 | or combined with other losses.
22 |
23 | Example:
24 |
25 | .. code:: python
26 |
27 | loss = JaccardLoss()
28 | model.compile('SGD', loss=loss)
29 | """
30 |
31 | def __init__(self, class_weights=None, class_indexes=None, per_image=False, smooth=SMOOTH):
32 | super().__init__(name='jaccard_loss')
33 | self.class_weights = class_weights if class_weights is not None else 1
34 | self.class_indexes = class_indexes
35 | self.per_image = per_image
36 | self.smooth = smooth
37 |
38 | def __call__(self, gt, pr):
39 | return 1 - F.iou_score(
40 | gt,
41 | pr,
42 | class_weights=self.class_weights,
43 | class_indexes=self.class_indexes,
44 | smooth=self.smooth,
45 | per_image=self.per_image,
46 | threshold=None,
47 | **self.submodules
48 | )
49 |
50 |
51 | class DiceLoss(Loss):
52 | r"""Creates a criterion to measure Dice loss:
53 |
54 | .. math:: L(precision, recall) = 1 - (1 + \beta^2) \frac{precision \cdot recall}
55 | {\beta^2 \cdot precision + recall}
56 |
57 | The formula in terms of *Type I* and *Type II* errors:
58 |
59 |     .. math:: L(tp, fp, fn) = 1 - \frac{(1 + \beta^2) \cdot tp} {(1 + \beta^2) \cdot tp + \beta^2 \cdot fn + fp}
60 |
61 | where:
62 | - tp - true positives;
63 | - fp - false positives;
64 | - fn - false negatives;
65 |
66 | Args:
67 | beta: Float or integer coefficient for precision and recall balance.
68 | class_weights: Array (``np.array``) of class weights (``len(weights) = num_classes``).
69 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used.
70 | per_image: If ``True`` loss is calculated for each image in batch and then averaged,
71 | else loss is calculated for the whole batch.
72 | smooth: Value to avoid division by zero.
73 |
74 | Returns:
75 |         A callable ``dice_loss`` instance. Can be used in ``model.compile(...)`` function
76 | or combined with other losses.
77 |
78 | Example:
79 |
80 | .. code:: python
81 |
82 | loss = DiceLoss()
83 | model.compile('SGD', loss=loss)
84 | """
85 |
86 | def __init__(self, beta=1, class_weights=None, class_indexes=None, per_image=False, smooth=SMOOTH):
87 | super().__init__(name='dice_loss')
88 | self.beta = beta
89 | self.class_weights = class_weights if class_weights is not None else 1
90 | self.class_indexes = class_indexes
91 | self.per_image = per_image
92 | self.smooth = smooth
93 |
94 | def __call__(self, gt, pr):
95 | return 1 - F.f_score(
96 | gt,
97 | pr,
98 | beta=self.beta,
99 | class_weights=self.class_weights,
100 | class_indexes=self.class_indexes,
101 | smooth=self.smooth,
102 | per_image=self.per_image,
103 | threshold=None,
104 | **self.submodules
105 | )
106 |
107 |
108 | class BinaryCELoss(Loss):
109 | """Creates a criterion that measures the Binary Cross Entropy between the
110 | ground truth (gt) and the prediction (pr).
111 |
112 | .. math:: L(gt, pr) = - gt \cdot \log(pr) - (1 - gt) \cdot \log(1 - pr)
113 |
114 | Returns:
115 | A callable ``binary_crossentropy`` instance. Can be used in ``model.compile(...)`` function
116 | or combined with other losses.
117 |
118 | Example:
119 |
120 | .. code:: python
121 |
122 | loss = BinaryCELoss()
123 | model.compile('SGD', loss=loss)
124 | """
125 |
126 | def __init__(self):
127 | super().__init__(name='binary_crossentropy')
128 |
129 | def __call__(self, gt, pr):
130 | return F.binary_crossentropy(gt, pr, **self.submodules)
131 |
132 |
133 | class CategoricalCELoss(Loss):
134 | """Creates a criterion that measures the Categorical Cross Entropy between the
135 | ground truth (gt) and the prediction (pr).
136 |
137 | .. math:: L(gt, pr) = - gt \cdot \log(pr)
138 |
139 | Args:
140 | class_weights: Array (``np.array``) of class weights (``len(weights) = num_classes``).
141 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used.
142 |
143 | Returns:
144 | A callable ``categorical_crossentropy`` instance. Can be used in ``model.compile(...)`` function
145 | or combined with other losses.
146 |
147 | Example:
148 |
149 | .. code:: python
150 |
151 | loss = CategoricalCELoss()
152 | model.compile('SGD', loss=loss)
153 | """
154 |
155 | def __init__(self, class_weights=None, class_indexes=None):
156 | super().__init__(name='categorical_crossentropy')
157 | self.class_weights = class_weights if class_weights is not None else 1
158 | self.class_indexes = class_indexes
159 |
160 | def __call__(self, gt, pr):
161 | return F.categorical_crossentropy(
162 | gt,
163 | pr,
164 | class_weights=self.class_weights,
165 | class_indexes=self.class_indexes,
166 | **self.submodules
167 | )
168 |
169 |
170 | class CategoricalFocalLoss(Loss):
171 | r"""Creates a criterion that measures the Categorical Focal Loss between the
172 | ground truth (gt) and the prediction (pr).
173 |
174 | .. math:: L(gt, pr) = - gt \cdot \alpha \cdot (1 - pr)^\gamma \cdot \log(pr)
175 |
176 | Args:
177 | alpha: Float or integer, the same as weighting factor in balanced cross entropy, default 0.25.
178 | gamma: Float or integer, focusing parameter for modulating factor (1 - p), default 2.0.
179 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used.
180 |
181 | Returns:
182 | A callable ``categorical_focal_loss`` instance. Can be used in ``model.compile(...)`` function
183 | or combined with other losses.
184 |
185 | Example:
186 |
187 | .. code:: python
188 |
189 | loss = CategoricalFocalLoss()
190 | model.compile('SGD', loss=loss)
191 | """
192 |
193 | def __init__(self, alpha=0.25, gamma=2., class_indexes=None):
194 | super().__init__(name='focal_loss')
195 | self.alpha = alpha
196 | self.gamma = gamma
197 | self.class_indexes = class_indexes
198 |
199 | def __call__(self, gt, pr):
200 | return F.categorical_focal_loss(
201 | gt,
202 | pr,
203 | alpha=self.alpha,
204 | gamma=self.gamma,
205 | class_indexes=self.class_indexes,
206 | **self.submodules
207 | )
208 |
209 |
210 | class BinaryFocalLoss(Loss):
211 | r"""Creates a criterion that measures the Binary Focal Loss between the
212 | ground truth (gt) and the prediction (pr).
213 |
214 |     .. math:: L(gt, pr) = - gt \alpha (1 - pr)^\gamma \log(pr) - (1 - gt) (1 - \alpha) pr^\gamma \log(1 - pr)
215 |
216 | Args:
217 | alpha: Float or integer, the same as weighting factor in balanced cross entropy, default 0.25.
218 | gamma: Float or integer, focusing parameter for modulating factor (1 - p), default 2.0.
219 |
220 | Returns:
221 | A callable ``binary_focal_loss`` instance. Can be used in ``model.compile(...)`` function
222 | or combined with other losses.
223 |
224 | Example:
225 |
226 | .. code:: python
227 |
228 | loss = BinaryFocalLoss()
229 | model.compile('SGD', loss=loss)
230 | """
231 |
232 | def __init__(self, alpha=0.25, gamma=2.):
233 | super().__init__(name='binary_focal_loss')
234 | self.alpha = alpha
235 | self.gamma = gamma
236 |
237 | def __call__(self, gt, pr):
238 | return F.binary_focal_loss(gt, pr, alpha=self.alpha, gamma=self.gamma, **self.submodules)
239 |
240 |
241 | # aliases
242 | jaccard_loss = JaccardLoss()
243 | dice_loss = DiceLoss()
244 |
245 | binary_focal_loss = BinaryFocalLoss()
246 | categorical_focal_loss = CategoricalFocalLoss()
247 |
248 | binary_crossentropy = BinaryCELoss()
249 | categorical_crossentropy = CategoricalCELoss()
250 |
251 | # loss combinations
252 | bce_dice_loss = binary_crossentropy + dice_loss
253 | bce_jaccard_loss = binary_crossentropy + jaccard_loss
254 |
255 | cce_dice_loss = categorical_crossentropy + dice_loss
256 | cce_jaccard_loss = categorical_crossentropy + jaccard_loss
257 |
258 | binary_focal_dice_loss = binary_focal_loss + dice_loss
259 | binary_focal_jaccard_loss = binary_focal_loss + jaccard_loss
260 |
261 | categorical_focal_dice_loss = categorical_focal_loss + dice_loss
262 | categorical_focal_jaccard_loss = categorical_focal_loss + jaccard_loss
263 |
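264 | # Composition sketch: because the aliases above are `Loss` instances,
265 | # weighted mixtures fall out of the operator overloads in
266 | # ``base.objects``; the weight 2 below is an arbitrary illustration and
267 | # `model` is hypothetical.
268 | #
269 | #     total_loss = binary_focal_loss + 2 * dice_loss
270 | #     model.compile('Adam', loss=total_loss)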
--------------------------------------------------------------------------------
/segmentation_models/metrics.py:
--------------------------------------------------------------------------------
1 | from .base import Metric
2 | from .base import functional as F
3 |
4 | SMOOTH = 1e-5
5 |
6 |
7 | class IOUScore(Metric):
8 | r""" The `Jaccard index`_, also known as Intersection over Union and the Jaccard similarity coefficient
9 | (originally coined coefficient de communauté by Paul Jaccard), is a statistic used for comparing the
10 | similarity and diversity of sample sets. The Jaccard coefficient measures similarity between finite sample sets,
11 | and is defined as the size of the intersection divided by the size of the union of the sample sets:
12 |
13 |     .. math:: J(A, B) = \frac{|A \cap B|}{|A \cup B|}
14 |
15 | Args:
16 | class_weights: 1. or ``np.array`` of class weights (``len(weights) = num_classes``).
17 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used.
18 | smooth: value to avoid division by zero
19 | per_image: if ``True``, metric is calculated as mean over images in batch (B),
20 | else over whole batch
21 |         threshold: value to round predictions (use ``>`` comparison), if ``None`` predictions are not rounded
22 |
23 | Returns:
24 | A callable ``iou_score`` instance. Can be used in ``model.compile(...)`` function.
25 |
26 | .. _`Jaccard index`: https://en.wikipedia.org/wiki/Jaccard_index
27 |
28 | Example:
29 |
30 | .. code:: python
31 |
32 | metric = IOUScore()
33 | model.compile('SGD', loss=loss, metrics=[metric])
34 | """
35 |
36 | def __init__(
37 | self,
38 | class_weights=None,
39 | class_indexes=None,
40 | threshold=None,
41 | per_image=False,
42 | smooth=SMOOTH,
43 | name=None,
44 | ):
45 | name = name or 'iou_score'
46 | super().__init__(name=name)
47 | self.class_weights = class_weights if class_weights is not None else 1
48 | self.class_indexes = class_indexes
49 | self.threshold = threshold
50 | self.per_image = per_image
51 | self.smooth = smooth
52 |
53 | def __call__(self, gt, pr):
54 | return F.iou_score(
55 | gt,
56 | pr,
57 | class_weights=self.class_weights,
58 | class_indexes=self.class_indexes,
59 | smooth=self.smooth,
60 | per_image=self.per_image,
61 | threshold=self.threshold,
62 | **self.submodules
63 | )
64 |
65 |
66 | class FScore(Metric):
67 | r"""The F-score (Dice coefficient) can be interpreted as a weighted average of the precision and recall,
68 | where an F-score reaches its best value at 1 and worst score at 0.
69 |     The relative contribution of ``precision`` and ``recall`` to the F1-score is equal.
70 | The formula for the F score is:
71 |
72 | .. math:: F_\beta(precision, recall) = (1 + \beta^2) \frac{precision \cdot recall}
73 | {\beta^2 \cdot precision + recall}
74 |
75 | The formula in terms of *Type I* and *Type II* errors:
76 |
77 |     .. math:: F_\beta(tp, fp, fn) = \frac{(1 + \beta^2) \cdot tp} {(1 + \beta^2) \cdot tp + \beta^2 \cdot fn + fp}
78 |
79 | where:
80 | - tp - true positives;
81 | - fp - false positives;
82 | - fn - false negatives;
83 |
84 | Args:
85 |         beta: Integer or float f-score coefficient to balance precision and recall.
86 | class_weights: 1. or ``np.array`` of class weights (``len(weights) = num_classes``)
87 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used.
88 | smooth: Float value to avoid division by zero.
89 | per_image: If ``True``, metric is calculated as mean over images in batch (B),
90 | else over whole batch.
91 |         threshold: Float value to round predictions (use ``>`` comparison), if ``None`` predictions are not rounded.
92 | name: Optional string, if ``None`` default ``f{beta}-score`` name is used.
93 |
94 | Returns:
95 | A callable ``f_score`` instance. Can be used in ``model.compile(...)`` function.
96 |
97 | Example:
98 |
99 | .. code:: python
100 |
101 | metric = FScore()
102 | model.compile('SGD', loss=loss, metrics=[metric])
103 | """
104 |
105 | def __init__(
106 | self,
107 | beta=1,
108 | class_weights=None,
109 | class_indexes=None,
110 | threshold=None,
111 | per_image=False,
112 | smooth=SMOOTH,
113 | name=None,
114 | ):
115 | name = name or 'f{}-score'.format(beta)
116 | super().__init__(name=name)
117 | self.beta = beta
118 | self.class_weights = class_weights if class_weights is not None else 1
119 | self.class_indexes = class_indexes
120 | self.threshold = threshold
121 | self.per_image = per_image
122 | self.smooth = smooth
123 |
124 | def __call__(self, gt, pr):
125 | return F.f_score(
126 | gt,
127 | pr,
128 | beta=self.beta,
129 | class_weights=self.class_weights,
130 | class_indexes=self.class_indexes,
131 | smooth=self.smooth,
132 | per_image=self.per_image,
133 | threshold=self.threshold,
134 | **self.submodules
135 | )
136 |
137 |
138 | class Precision(Metric):
139 | r"""Creates a criterion that measures the Precision between the
140 | ground truth (gt) and the prediction (pr).
141 |
142 |     .. math:: precision(tp, fp) = \frac{tp}{tp + fp}
143 |
144 | where:
145 | - tp - true positives;
146 | - fp - false positives;
147 |
148 | Args:
149 | class_weights: 1. or ``np.array`` of class weights (``len(weights) = num_classes``).
150 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used.
151 | smooth: Float value to avoid division by zero.
152 | per_image: If ``True``, metric is calculated as mean over images in batch (B),
153 | else over whole batch.
154 |         threshold: Float value to round predictions (use ``>`` comparison), if ``None`` predictions are not rounded.
155 | name: Optional string, if ``None`` default ``precision`` name is used.
156 |
157 | Returns:
158 | A callable ``precision`` instance. Can be used in ``model.compile(...)`` function.
159 |
160 | Example:
161 |
162 | .. code:: python
163 |
164 | metric = Precision()
165 | model.compile('SGD', loss=loss, metrics=[metric])
166 | """
167 |
168 | def __init__(
169 | self,
170 | class_weights=None,
171 | class_indexes=None,
172 | threshold=None,
173 | per_image=False,
174 | smooth=SMOOTH,
175 | name=None,
176 | ):
177 | name = name or 'precision'
178 | super().__init__(name=name)
179 | self.class_weights = class_weights if class_weights is not None else 1
180 | self.class_indexes = class_indexes
181 | self.threshold = threshold
182 | self.per_image = per_image
183 | self.smooth = smooth
184 |
185 | def __call__(self, gt, pr):
186 | return F.precision(
187 | gt,
188 | pr,
189 | class_weights=self.class_weights,
190 | class_indexes=self.class_indexes,
191 | smooth=self.smooth,
192 | per_image=self.per_image,
193 | threshold=self.threshold,
194 | **self.submodules
195 | )
196 |
197 |
198 | class Recall(Metric):
199 | r"""Creates a criterion that measures the Precision between the
200 | ground truth (gt) and the prediction (pr).
201 |
202 |     .. math:: recall(tp, fn) = \frac{tp}{tp + fn}
203 |
204 | where:
205 | - tp - true positives;
206 | - fn - false negatives;
207 |
208 | Args:
209 | class_weights: 1. or ``np.array`` of class weights (``len(weights) = num_classes``).
210 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used.
211 | smooth: Float value to avoid division by zero.
212 | per_image: If ``True``, metric is calculated as mean over images in batch (B),
213 | else over whole batch.
214 |         threshold: Float value to round predictions (use ``>`` comparison), if ``None`` predictions are not rounded.
215 | name: Optional string, if ``None`` default ``recall`` name is used.
216 |
217 | Returns:
218 | A callable ``recall`` instance. Can be used in ``model.compile(...)`` function.
219 |
220 | Example:
221 |
222 | .. code:: python
223 |
224 |             metric = Recall()
225 | model.compile('SGD', loss=loss, metrics=[metric])
226 | """
227 |
228 | def __init__(
229 | self,
230 | class_weights=None,
231 | class_indexes=None,
232 | threshold=None,
233 | per_image=False,
234 | smooth=SMOOTH,
235 | name=None,
236 | ):
237 | name = name or 'recall'
238 | super().__init__(name=name)
239 | self.class_weights = class_weights if class_weights is not None else 1
240 | self.class_indexes = class_indexes
241 | self.threshold = threshold
242 | self.per_image = per_image
243 | self.smooth = smooth
244 |
245 | def __call__(self, gt, pr):
246 | return F.recall(
247 | gt,
248 | pr,
249 | class_weights=self.class_weights,
250 | class_indexes=self.class_indexes,
251 | smooth=self.smooth,
252 | per_image=self.per_image,
253 | threshold=self.threshold,
254 | **self.submodules
255 | )
256 |
257 |
258 | # aliases
259 | iou_score = IOUScore()
260 | f1_score = FScore(beta=1)
261 | f2_score = FScore(beta=2)
262 | precision = Precision()
263 | recall = Recall()
264 |
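265 | # Usage sketch with a hypothetical `model` (assumes Keras submodules
266 | # were configured by the package initializer): passing a `threshold`
267 | # makes the metrics score hard masks instead of raw probabilities.
268 | #
269 | #     model.compile(
270 | #         'Adam',
271 | #         loss='binary_crossentropy',
272 | #         metrics=[IOUScore(threshold=0.5), FScore(beta=2, threshold=0.5)],
273 | #     )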
--------------------------------------------------------------------------------
/segmentation_models/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qubvel/segmentation_models/5d24bbfb28af6134e25e2c0b79e7727f6c0491d0/segmentation_models/models/__init__.py
--------------------------------------------------------------------------------
/segmentation_models/models/_common_blocks.py:
--------------------------------------------------------------------------------
1 | from keras_applications import get_submodules_from_kwargs
2 |
3 |
4 | def Conv2dBn(
5 | filters,
6 | kernel_size,
7 | strides=(1, 1),
8 | padding='valid',
9 | data_format=None,
10 | dilation_rate=(1, 1),
11 | activation=None,
12 | kernel_initializer='glorot_uniform',
13 | bias_initializer='zeros',
14 | kernel_regularizer=None,
15 | bias_regularizer=None,
16 | activity_regularizer=None,
17 | kernel_constraint=None,
18 | bias_constraint=None,
19 | use_batchnorm=False,
20 | **kwargs
21 | ):
22 | """Extension of Conv2D layer with batchnorm"""
23 |
24 | conv_name, act_name, bn_name = None, None, None
25 | block_name = kwargs.pop('name', None)
26 | backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs)
27 |
28 | if block_name is not None:
29 | conv_name = block_name + '_conv'
30 |
31 | if block_name is not None and activation is not None:
32 | act_str = activation.__name__ if callable(activation) else str(activation)
33 | act_name = block_name + '_' + act_str
34 |
35 | if block_name is not None and use_batchnorm:
36 | bn_name = block_name + '_bn'
37 |
38 | bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1
39 |
40 | def wrapper(input_tensor):
41 |
42 | x = layers.Conv2D(
43 | filters=filters,
44 | kernel_size=kernel_size,
45 | strides=strides,
46 | padding=padding,
47 | data_format=data_format,
48 | dilation_rate=dilation_rate,
49 | activation=None,
50 | use_bias=not (use_batchnorm),
51 | kernel_initializer=kernel_initializer,
52 | bias_initializer=bias_initializer,
53 | kernel_regularizer=kernel_regularizer,
54 | bias_regularizer=bias_regularizer,
55 | activity_regularizer=activity_regularizer,
56 | kernel_constraint=kernel_constraint,
57 | bias_constraint=bias_constraint,
58 | name=conv_name,
59 | )(input_tensor)
60 |
61 | if use_batchnorm:
62 | x = layers.BatchNormalization(axis=bn_axis, name=bn_name)(x)
63 |
64 | if activation:
65 | x = layers.Activation(activation, name=act_name)(x)
66 |
67 | return x
68 |
69 | return wrapper
70 |
--------------------------------------------------------------------------------
/segmentation_models/models/_utils.py:
--------------------------------------------------------------------------------
1 | from keras_applications import get_submodules_from_kwargs
2 |
3 |
4 | def freeze_model(model, **kwargs):
5 | """Set all layers non trainable, excluding BatchNormalization layers"""
6 | _, layers, _, _ = get_submodules_from_kwargs(kwargs)
7 | for layer in model.layers:
8 | if not isinstance(layer, layers.BatchNormalization):
9 | layer.trainable = False
10 | return
11 |
12 |
13 | def filter_keras_submodules(kwargs):
14 | """Selects only arguments that define keras_application submodules. """
15 | submodule_keys = kwargs.keys() & {'backend', 'layers', 'models', 'utils'}
16 | return {key: kwargs[key] for key in submodule_keys}
17 |
--------------------------------------------------------------------------------
/segmentation_models/models/fpn.py:
--------------------------------------------------------------------------------
1 | from keras_applications import get_submodules_from_kwargs
2 |
3 | from ._common_blocks import Conv2dBn
4 | from ._utils import freeze_model, filter_keras_submodules
5 | from ..backbones.backbones_factory import Backbones
6 |
7 | backend = None
8 | layers = None
9 | models = None
10 | keras_utils = None
11 |
12 |
13 | # ---------------------------------------------------------------------
14 | # Utility functions
15 | # ---------------------------------------------------------------------
16 |
17 | def get_submodules():
18 | return {
19 | 'backend': backend,
20 | 'models': models,
21 | 'layers': layers,
22 | 'utils': keras_utils,
23 | }
24 |
25 |
26 | # ---------------------------------------------------------------------
27 | # Blocks
28 | # ---------------------------------------------------------------------
29 |
30 | def Conv3x3BnReLU(filters, use_batchnorm, name=None):
31 | kwargs = get_submodules()
32 |
33 | def wrapper(input_tensor):
34 | return Conv2dBn(
35 | filters,
36 | kernel_size=3,
37 | activation='relu',
38 | kernel_initializer='he_uniform',
39 | padding='same',
40 | use_batchnorm=use_batchnorm,
41 | name=name,
42 | **kwargs
43 | )(input_tensor)
44 |
45 | return wrapper
46 |
47 |
48 | def DoubleConv3x3BnReLU(filters, use_batchnorm, name=None):
49 | name1, name2 = None, None
50 | if name is not None:
51 | name1 = name + 'a'
52 | name2 = name + 'b'
53 |
54 | def wrapper(input_tensor):
55 | x = Conv3x3BnReLU(filters, use_batchnorm, name=name1)(input_tensor)
56 | x = Conv3x3BnReLU(filters, use_batchnorm, name=name2)(x)
57 | return x
58 |
59 | return wrapper
60 |
61 |
62 | def FPNBlock(pyramid_filters, stage):
63 | conv0_name = 'fpn_stage_p{}_pre_conv'.format(stage)
64 | conv1_name = 'fpn_stage_p{}_conv'.format(stage)
65 | add_name = 'fpn_stage_p{}_add'.format(stage)
66 | up_name = 'fpn_stage_p{}_upsampling'.format(stage)
67 |
68 | channels_axis = 3 if backend.image_data_format() == 'channels_last' else 1
69 |
70 | def wrapper(input_tensor, skip):
71 |         # if the number of input tensor channels differs from the number of
72 |         # pyramid channels, the input tensor and skip cannot be summed,
73 |         # so add an extra 1x1 conv layer to project it
74 | input_filters = backend.int_shape(input_tensor)[channels_axis]
75 | if input_filters != pyramid_filters:
76 | input_tensor = layers.Conv2D(
77 | filters=pyramid_filters,
78 | kernel_size=(1, 1),
79 | kernel_initializer='he_uniform',
80 | name=conv0_name,
81 | )(input_tensor)
82 |
83 | skip = layers.Conv2D(
84 | filters=pyramid_filters,
85 | kernel_size=(1, 1),
86 | kernel_initializer='he_uniform',
87 | name=conv1_name,
88 | )(skip)
89 |
90 | x = layers.UpSampling2D((2, 2), name=up_name)(input_tensor)
91 | x = layers.Add(name=add_name)([x, skip])
92 |
93 | return x
94 |
95 | return wrapper
96 |
97 |
98 | # ---------------------------------------------------------------------
99 | # FPN Decoder
100 | # ---------------------------------------------------------------------
101 |
102 | def build_fpn(
103 | backbone,
104 | skip_connection_layers,
105 | pyramid_filters=256,
106 | segmentation_filters=128,
107 | classes=1,
108 | activation='sigmoid',
109 | use_batchnorm=True,
110 | aggregation='sum',
111 | dropout=None,
112 | ):
113 | input_ = backbone.input
114 | x = backbone.output
115 |
116 | # building decoder blocks with skip connections
117 | skips = ([backbone.get_layer(name=i).output if isinstance(i, str)
118 | else backbone.get_layer(index=i).output for i in skip_connection_layers])
119 |
120 | # build FPN pyramid
121 | p5 = FPNBlock(pyramid_filters, stage=5)(x, skips[0])
122 | p4 = FPNBlock(pyramid_filters, stage=4)(p5, skips[1])
123 | p3 = FPNBlock(pyramid_filters, stage=3)(p4, skips[2])
124 | p2 = FPNBlock(pyramid_filters, stage=2)(p3, skips[3])
125 |
126 | # add segmentation head to each
127 | s5 = DoubleConv3x3BnReLU(segmentation_filters, use_batchnorm, name='segm_stage5')(p5)
128 | s4 = DoubleConv3x3BnReLU(segmentation_filters, use_batchnorm, name='segm_stage4')(p4)
129 | s3 = DoubleConv3x3BnReLU(segmentation_filters, use_batchnorm, name='segm_stage3')(p3)
130 | s2 = DoubleConv3x3BnReLU(segmentation_filters, use_batchnorm, name='segm_stage2')(p2)
131 |
132 | # upsampling to same resolution
133 | s5 = layers.UpSampling2D((8, 8), interpolation='nearest', name='upsampling_stage5')(s5)
134 | s4 = layers.UpSampling2D((4, 4), interpolation='nearest', name='upsampling_stage4')(s4)
135 | s3 = layers.UpSampling2D((2, 2), interpolation='nearest', name='upsampling_stage3')(s3)
136 |
137 | # aggregating results
138 | if aggregation == 'sum':
139 | x = layers.Add(name='aggregation_sum')([s2, s3, s4, s5])
140 | elif aggregation == 'concat':
141 | concat_axis = 3 if backend.image_data_format() == 'channels_last' else 1
142 | x = layers.Concatenate(axis=concat_axis, name='aggregation_concat')([s2, s3, s4, s5])
143 | else:
144 | raise ValueError('Aggregation parameter should be in ("sum", "concat"), '
145 | 'got {}'.format(aggregation))
146 |
147 | if dropout:
148 | x = layers.SpatialDropout2D(dropout, name='pyramid_dropout')(x)
149 |
150 | # final stage
151 | x = Conv3x3BnReLU(segmentation_filters, use_batchnorm, name='final_stage')(x)
152 | x = layers.UpSampling2D(size=(2, 2), interpolation='bilinear', name='final_upsampling')(x)
153 |
154 | # model head (define number of output classes)
155 | x = layers.Conv2D(
156 | filters=classes,
157 | kernel_size=(3, 3),
158 | padding='same',
159 | use_bias=True,
160 | kernel_initializer='glorot_uniform',
161 | name='head_conv',
162 | )(x)
163 | x = layers.Activation(activation, name=activation)(x)
164 |
165 | # create keras model instance
166 | model = models.Model(input_, x)
167 |
168 | return model
169 |
170 |
171 | # ---------------------------------------------------------------------
172 | # FPN Model
173 | # ---------------------------------------------------------------------
174 |
175 | def FPN(
176 | backbone_name='vgg16',
177 | input_shape=(None, None, 3),
178 | classes=21,
179 | activation='softmax',
180 | weights=None,
181 | encoder_weights='imagenet',
182 | encoder_freeze=False,
183 | encoder_features='default',
184 | pyramid_block_filters=256,
185 | pyramid_use_batchnorm=True,
186 | pyramid_aggregation='concat',
187 | pyramid_dropout=None,
188 | **kwargs
189 | ):
190 | """FPN_ is a fully convolution neural network for image semantic segmentation
191 |
192 | Args:
193 | backbone_name: name of classification model (without last dense layers) used as feature
194 | extractor to build segmentation model.
195 |         input_shape: shape of input data/image ``(H, W, C)``; in the general
196 |             case you do not need to set ``H`` and ``W``, just pass ``(None, None, C)`` to make the model
197 |             able to process images of any size, but ``H`` and ``W`` of input images should be divisible by factor ``32``.
198 |         classes: the number of output classes (output shape - ``(h, w, classes)``).
199 |         weights: optional, path to model weights.
200 |         activation: name of one of ``keras.activations`` for the last model layer (e.g. ``sigmoid``, ``softmax``, ``linear``).
201 | encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
202 | encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable.
203 | encoder_features: a list of layer numbers or names starting from top of the model.
204 |             Each of these layers will be used to build the feature pyramid. If ``default`` is used
205 | layer names are taken from ``DEFAULT_FEATURE_PYRAMID_LAYERS``.
206 | pyramid_block_filters: a number of filters in Feature Pyramid Block of FPN_.
207 | pyramid_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers
208 | is used.
209 | pyramid_aggregation: one of 'sum' or 'concat'. The way to aggregate pyramid blocks.
210 | pyramid_dropout: spatial dropout rate for feature pyramid in range (0, 1).
211 |
212 | Returns:
213 | ``keras.models.Model``: **FPN**
214 |
215 | .. _FPN:
216 | http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf
217 |
218 | """
219 | global backend, layers, models, keras_utils
220 | submodule_args = filter_keras_submodules(kwargs)
221 | backend, layers, models, keras_utils = get_submodules_from_kwargs(submodule_args)
222 |
223 | backbone = Backbones.get_backbone(
224 | backbone_name,
225 | input_shape=input_shape,
226 | weights=encoder_weights,
227 | include_top=False,
228 | **kwargs,
229 | )
230 |
231 | if encoder_features == 'default':
232 | encoder_features = Backbones.get_feature_layers(backbone_name, n=4)
233 |
234 | model = build_fpn(
235 | backbone=backbone,
236 | skip_connection_layers=encoder_features,
237 | pyramid_filters=pyramid_block_filters,
238 | segmentation_filters=pyramid_block_filters // 2,
239 | use_batchnorm=pyramid_use_batchnorm,
240 | dropout=pyramid_dropout,
241 | activation=activation,
242 | classes=classes,
243 | aggregation=pyramid_aggregation,
244 | )
245 |
246 | # lock encoder weights for fine-tuning
247 | if encoder_freeze:
248 | freeze_model(backbone, **kwargs)
249 |
250 | # loading model weights
251 | if weights is not None:
252 | model.load_weights(weights)
253 |
254 | return model
255 |
--------------------------------------------------------------------------------
/segmentation_models/models/linknet.py:
--------------------------------------------------------------------------------
1 | from keras_applications import get_submodules_from_kwargs
2 |
3 | from ._common_blocks import Conv2dBn
4 | from ._utils import freeze_model, filter_keras_submodules
5 | from ..backbones.backbones_factory import Backbones
6 |
7 | backend = None
8 | layers = None
9 | models = None
10 | keras_utils = None
11 |
12 |
13 | # ---------------------------------------------------------------------
14 | # Utility functions
15 | # ---------------------------------------------------------------------
16 |
17 | def get_submodules():
18 | return {
19 | 'backend': backend,
20 | 'models': models,
21 | 'layers': layers,
22 | 'utils': keras_utils,
23 | }
24 |
25 |
26 | # ---------------------------------------------------------------------
27 | # Blocks
28 | # ---------------------------------------------------------------------
29 |
30 | def Conv3x3BnReLU(filters, use_batchnorm, name=None):
31 | kwargs = get_submodules()
32 |
33 | def wrapper(input_tensor):
34 | return Conv2dBn(
35 | filters,
36 | kernel_size=3,
37 | activation='relu',
38 | kernel_initializer='he_uniform',
39 | padding='same',
40 | use_batchnorm=use_batchnorm,
41 | name=name,
42 | **kwargs
43 | )(input_tensor)
44 |
45 | return wrapper
46 |
47 |
48 | def Conv1x1BnReLU(filters, use_batchnorm, name=None):
49 | kwargs = get_submodules()
50 |
51 | def wrapper(input_tensor):
52 | return Conv2dBn(
53 | filters,
54 | kernel_size=1,
55 | activation='relu',
56 | kernel_initializer='he_uniform',
57 | padding='same',
58 | use_batchnorm=use_batchnorm,
59 | name=name,
60 | **kwargs
61 | )(input_tensor)
62 |
63 | return wrapper
64 |
65 |
66 | def DecoderUpsamplingX2Block(filters, stage, use_batchnorm):
67 | conv_block1_name = 'decoder_stage{}a'.format(stage)
68 | conv_block2_name = 'decoder_stage{}b'.format(stage)
69 | conv_block3_name = 'decoder_stage{}c'.format(stage)
70 | up_name = 'decoder_stage{}_upsampling'.format(stage)
71 | add_name = 'decoder_stage{}_add'.format(stage)
72 |
73 | channels_axis = 3 if backend.image_data_format() == 'channels_last' else 1
74 |
75 | def wrapper(input_tensor, skip=None):
76 | input_filters = backend.int_shape(input_tensor)[channels_axis]
77 | output_filters = backend.int_shape(skip)[channels_axis] if skip is not None else filters
78 |
79 | x = Conv1x1BnReLU(input_filters // 4, use_batchnorm, name=conv_block1_name)(input_tensor)
80 | x = layers.UpSampling2D((2, 2), name=up_name)(x)
81 | x = Conv3x3BnReLU(input_filters // 4, use_batchnorm, name=conv_block2_name)(x)
82 | x = Conv1x1BnReLU(output_filters, use_batchnorm, name=conv_block3_name)(x)
83 |
84 | if skip is not None:
85 | x = layers.Add(name=add_name)([x, skip])
86 | return x
87 |
88 | return wrapper
89 |
90 |
91 | def DecoderTransposeX2Block(filters, stage, use_batchnorm):
92 | conv_block1_name = 'decoder_stage{}a'.format(stage)
93 | transpose_name = 'decoder_stage{}b_transpose'.format(stage)
94 | bn_name = 'decoder_stage{}b_bn'.format(stage)
95 | relu_name = 'decoder_stage{}b_relu'.format(stage)
96 | conv_block3_name = 'decoder_stage{}c'.format(stage)
97 | add_name = 'decoder_stage{}_add'.format(stage)
98 |
99 | channels_axis = bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1
100 |
101 | def wrapper(input_tensor, skip=None):
102 | input_filters = backend.int_shape(input_tensor)[channels_axis]
103 | output_filters = backend.int_shape(skip)[channels_axis] if skip is not None else filters
104 |
105 | x = Conv1x1BnReLU(input_filters // 4, use_batchnorm, name=conv_block1_name)(input_tensor)
106 | x = layers.Conv2DTranspose(
107 | filters=input_filters // 4,
108 | kernel_size=(4, 4),
109 | strides=(2, 2),
110 | padding='same',
111 | name=transpose_name,
112 | use_bias=not use_batchnorm,
113 | )(x)
114 |
115 | if use_batchnorm:
116 | x = layers.BatchNormalization(axis=bn_axis, name=bn_name)(x)
117 |
118 | x = layers.Activation('relu', name=relu_name)(x)
119 | x = Conv1x1BnReLU(output_filters, use_batchnorm, name=conv_block3_name)(x)
120 |
121 | if skip is not None:
122 | x = layers.Add(name=add_name)([x, skip])
123 |
124 | return x
125 |
126 | return wrapper
127 |
128 |
129 | # ---------------------------------------------------------------------
130 | # LinkNet Decoder
131 | # ---------------------------------------------------------------------
132 |
133 | def build_linknet(
134 | backbone,
135 | decoder_block,
136 | skip_connection_layers,
137 | decoder_filters=(256, 128, 64, 32, 16),
138 | n_upsample_blocks=5,
139 | classes=1,
140 | activation='sigmoid',
141 | use_batchnorm=True,
142 | ):
143 | input_ = backbone.input
144 | x = backbone.output
145 |
146 | # extract skip connections
147 | skips = ([backbone.get_layer(name=i).output if isinstance(i, str)
148 | else backbone.get_layer(index=i).output for i in skip_connection_layers])
149 |
150 | # add center block if previous operation was maxpooling (for vgg models)
151 | if isinstance(backbone.layers[-1], layers.MaxPooling2D):
152 | x = Conv3x3BnReLU(512, use_batchnorm, name='center_block1')(x)
153 | x = Conv3x3BnReLU(512, use_batchnorm, name='center_block2')(x)
154 |
155 | # building decoder blocks
156 | for i in range(n_upsample_blocks):
157 |
158 | if i < len(skips):
159 | skip = skips[i]
160 | else:
161 | skip = None
162 |
163 | x = decoder_block(decoder_filters[i], stage=i, use_batchnorm=use_batchnorm)(x, skip)
164 |
165 | # model head (define number of output classes)
166 | x = layers.Conv2D(
167 | filters=classes,
168 | kernel_size=(3, 3),
169 | padding='same',
170 | use_bias=True,
171 | kernel_initializer='glorot_uniform'
172 | )(x)
173 | x = layers.Activation(activation, name=activation)(x)
174 |
175 | # create keras model instance
176 | model = models.Model(input_, x)
177 |
178 | return model
179 |
180 |
181 | # ---------------------------------------------------------------------
182 | # LinkNet Model
183 | # ---------------------------------------------------------------------
184 |
185 | def Linknet(
186 | backbone_name='vgg16',
187 | input_shape=(None, None, 3),
188 | classes=1,
189 | activation='sigmoid',
190 | weights=None,
191 | encoder_weights='imagenet',
192 | encoder_freeze=False,
193 | encoder_features='default',
194 | decoder_block_type='upsampling',
195 | decoder_filters=(None, None, None, None, 16),
196 | decoder_use_batchnorm=True,
197 | **kwargs
198 | ):
199 | """Linknet_ is a fully convolution neural network for fast image semantic segmentation
200 |
201 | Note:
202 | This implementation by default has 4 skip connections (original - 3).
203 |
204 | Args:
205 | backbone_name: name of classification model (without last dense layers) used as feature
206 | extractor to build segmentation model.
207 |         input_shape: shape of input data/image ``(H, W, C)``; in the general
208 |             case you do not need to set ``H`` and ``W``, just pass ``(None, None, C)`` to make the model
209 |             able to process images of any size, but ``H`` and ``W`` of input images should be divisible by factor ``32``.
210 |         classes: the number of output classes (output shape - ``(h, w, classes)``).
211 | activation: name of one of ``keras.activations`` for last model layer
212 | (e.g. ``sigmoid``, ``softmax``, ``linear``).
213 | weights: optional, path to model weights.
214 | encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
215 | encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable.
216 | encoder_features: a list of layer numbers or names starting from top of the model.
217 | Each of these layers will be concatenated with corresponding decoder block. If ``default`` is used
218 | layer names are taken from ``DEFAULT_SKIP_CONNECTIONS``.
219 |         decoder_filters: list of numbers of ``Conv2D`` layer filters in decoder blocks;
220 |             for blocks with a skip connection the number of filters equals the number of
221 |             filters in the corresponding encoder block (estimated automatically when passed as ``None``).
222 | decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers
223 | is used.
224 |         decoder_block_type: one of
225 |             - `upsampling`: use ``UpSampling2D`` keras layer
226 |             - `transpose`: use ``Conv2DTranspose`` keras layer
227 |
228 | Returns:
229 | ``keras.models.Model``: **Linknet**
230 |
231 | .. _Linknet:
232 | https://arxiv.org/pdf/1707.03718.pdf
233 | """
234 |
235 | global backend, layers, models, keras_utils
236 | submodule_args = filter_keras_submodules(kwargs)
237 | backend, layers, models, keras_utils = get_submodules_from_kwargs(submodule_args)
238 |
239 | if decoder_block_type == 'upsampling':
240 | decoder_block = DecoderUpsamplingX2Block
241 | elif decoder_block_type == 'transpose':
242 | decoder_block = DecoderTransposeX2Block
243 | else:
244 | raise ValueError('Decoder block type should be in ("upsampling", "transpose"). '
245 | 'Got: {}'.format(decoder_block_type))
246 |
247 | backbone = Backbones.get_backbone(
248 | backbone_name,
249 | input_shape=input_shape,
250 | weights=encoder_weights,
251 | include_top=False,
252 | **kwargs,
253 | )
254 |
255 | if encoder_features == 'default':
256 | encoder_features = Backbones.get_feature_layers(backbone_name, n=4)
257 |
258 | model = build_linknet(
259 | backbone=backbone,
260 | decoder_block=decoder_block,
261 | skip_connection_layers=encoder_features,
262 | decoder_filters=decoder_filters,
263 | classes=classes,
264 | activation=activation,
265 | n_upsample_blocks=len(decoder_filters),
266 | use_batchnorm=decoder_use_batchnorm,
267 | )
268 |
269 | # lock encoder weights for fine-tuning
270 | if encoder_freeze:
271 | freeze_model(backbone, **kwargs)
272 |
273 | # loading model weights
274 | if weights is not None:
275 | model.load_weights(weights)
276 |
277 | return model
278 |
--------------------------------------------------------------------------------
/segmentation_models/models/pspnet.py:
--------------------------------------------------------------------------------
1 | from keras_applications import get_submodules_from_kwargs
2 |
3 | from ._common_blocks import Conv2dBn
4 | from ._utils import freeze_model, filter_keras_submodules
5 | from ..backbones.backbones_factory import Backbones
6 |
7 | backend = None
8 | layers = None
9 | models = None
10 | keras_utils = None
11 |
12 |
13 | # ---------------------------------------------------------------------
14 | # Utility functions
15 | # ---------------------------------------------------------------------
16 |
17 | def get_submodules():
18 | return {
19 | 'backend': backend,
20 | 'models': models,
21 | 'layers': layers,
22 | 'utils': keras_utils,
23 | }
24 |
25 |
26 | def check_input_shape(input_shape, factor):
27 | if input_shape is None:
28 | raise ValueError("Input shape should be a tuple of 3 integers, not None!")
29 |
30 | h, w = input_shape[:2] if backend.image_data_format() == 'channels_last' else input_shape[1:]
31 | min_size = factor * 6
32 |
33 | is_wrong_shape = (
34 | h % min_size != 0 or w % min_size != 0 or
35 | h < min_size or w < min_size
36 | )
37 |
38 | if is_wrong_shape:
39 | raise ValueError('Wrong shape {}, input H and W should '.format(input_shape) +
40 | 'be divisible by `{}`'.format(min_size))
41 |
42 |
43 | # ---------------------------------------------------------------------
44 | # Blocks
45 | # ---------------------------------------------------------------------
46 |
47 | def Conv1x1BnReLU(filters, use_batchnorm, name=None):
48 | kwargs = get_submodules()
49 |
50 | def wrapper(input_tensor):
51 | return Conv2dBn(
52 | filters,
53 | kernel_size=1,
54 | activation='relu',
55 | kernel_initializer='he_uniform',
56 | padding='same',
57 | use_batchnorm=use_batchnorm,
58 | name=name,
59 | **kwargs
60 | )(input_tensor)
61 |
62 | return wrapper
63 |
64 |
65 | def SpatialContextBlock(
66 | level,
67 | conv_filters=512,
68 | pooling_type='avg',
69 | use_batchnorm=True,
70 | ):
71 | if pooling_type not in ('max', 'avg'):
72 |         raise ValueError('Unsupported pooling type - `{}`. '.format(pooling_type) +
73 |                          'Use `avg` or `max`.')
74 |
75 | Pooling2D = layers.MaxPool2D if pooling_type == 'max' else layers.AveragePooling2D
76 |
77 | pooling_name = 'psp_level{}_pooling'.format(level)
78 | conv_block_name = 'psp_level{}'.format(level)
79 | upsampling_name = 'psp_level{}_upsampling'.format(level)
80 |
81 | def wrapper(input_tensor):
82 | # extract input feature maps size (h, and w dimensions)
83 | input_shape = backend.int_shape(input_tensor)
84 | spatial_size = input_shape[1:3] if backend.image_data_format() == 'channels_last' else input_shape[2:]
85 |
86 |         # Compute the pooling kernel and stride sizes from the target output
87 |         # resolution: when the kernel size and strides are equal, the output
88 |         # spatial size is the input spatial size divided by the kernel size.
89 |         # The pooled feature map sizes are 1x1, 2x2, 3x3, and 6x6.
90 | pool_size = up_size = [spatial_size[0] // level, spatial_size[1] // level]
91 |
92 | x = Pooling2D(pool_size, strides=pool_size, padding='same', name=pooling_name)(input_tensor)
93 | x = Conv1x1BnReLU(conv_filters, use_batchnorm, name=conv_block_name)(x)
94 | x = layers.UpSampling2D(up_size, interpolation='bilinear', name=upsampling_name)(x)
95 | return x
96 |
97 | return wrapper
98 |
99 |
100 | # ---------------------------------------------------------------------
101 | # PSP Decoder
102 | # ---------------------------------------------------------------------
103 |
104 | def build_psp(
105 | backbone,
106 | psp_layer_idx,
107 | pooling_type='avg',
108 | conv_filters=512,
109 | use_batchnorm=True,
110 | final_upsampling_factor=8,
111 | classes=21,
112 | activation='softmax',
113 | dropout=None,
114 | ):
115 | input_ = backbone.input
116 | x = (backbone.get_layer(name=psp_layer_idx).output if isinstance(psp_layer_idx, str)
117 | else backbone.get_layer(index=psp_layer_idx).output)
118 |
119 | # build spatial pyramid
120 | x1 = SpatialContextBlock(1, conv_filters, pooling_type, use_batchnorm)(x)
121 | x2 = SpatialContextBlock(2, conv_filters, pooling_type, use_batchnorm)(x)
122 | x3 = SpatialContextBlock(3, conv_filters, pooling_type, use_batchnorm)(x)
123 | x6 = SpatialContextBlock(6, conv_filters, pooling_type, use_batchnorm)(x)
124 |
125 | # aggregate spatial pyramid
126 | concat_axis = 3 if backend.image_data_format() == 'channels_last' else 1
127 | x = layers.Concatenate(axis=concat_axis, name='psp_concat')([x, x1, x2, x3, x6])
128 | x = Conv1x1BnReLU(conv_filters, use_batchnorm, name='aggregation')(x)
129 |
130 | # model regularization
131 | if dropout is not None:
132 | x = layers.SpatialDropout2D(dropout, name='spatial_dropout')(x)
133 |
134 | # model head
135 | x = layers.Conv2D(
136 | filters=classes,
137 | kernel_size=(3, 3),
138 | padding='same',
139 | kernel_initializer='glorot_uniform',
140 | name='final_conv',
141 | )(x)
142 |
143 | x = layers.UpSampling2D(final_upsampling_factor, name='final_upsampling', interpolation='bilinear')(x)
144 | x = layers.Activation(activation, name=activation)(x)
145 |
146 | model = models.Model(input_, x)
147 |
148 | return model
149 |
150 |
151 | # ---------------------------------------------------------------------
152 | # PSP Model
153 | # ---------------------------------------------------------------------
154 |
155 | def PSPNet(
156 | backbone_name='vgg16',
157 | input_shape=(384, 384, 3),
158 | classes=21,
159 | activation='softmax',
160 | weights=None,
161 | encoder_weights='imagenet',
162 | encoder_freeze=False,
163 | downsample_factor=8,
164 | psp_conv_filters=512,
165 | psp_pooling_type='avg',
166 | psp_use_batchnorm=True,
167 | psp_dropout=None,
168 | **kwargs
169 | ):
170 | """PSPNet_ is a fully convolution neural network for image semantic segmentation
171 |
172 | Args:
173 | backbone_name: name of classification model used as feature
174 | extractor to build segmentation model.
175 | input_shape: shape of input data/image ``(H, W, C)``.
176 | ``H`` and ``W`` should be divisible by ``6 * downsample_factor`` and **NOT** ``None``!
177 |         classes: the number of output classes (output shape - ``(h, w, classes)``).
178 | activation: name of one of ``keras.activations`` for last model layer
179 | (e.g. ``sigmoid``, ``softmax``, ``linear``).
180 | weights: optional, path to model weights.
181 | encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
182 | encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable.
183 |         downsample_factor: one of 4, 8 and 16. Downsampling rate, i.e. the backbone
184 |             depth at which the PSP module is constructed.
185 | psp_conv_filters: number of filters in ``Conv2D`` layer in each PSP block.
186 | psp_pooling_type: one of 'avg', 'max'. PSP block pooling type (maximum or average).
187 | psp_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers
188 | is used.
189 | psp_dropout: dropout rate between 0 and 1.
190 |
191 | Returns:
192 | ``keras.models.Model``: **PSPNet**
193 |
194 | .. _PSPNet:
195 | https://arxiv.org/pdf/1612.01105.pdf
196 |
197 | """
198 |
199 | global backend, layers, models, keras_utils
200 | submodule_args = filter_keras_submodules(kwargs)
201 | backend, layers, models, keras_utils = get_submodules_from_kwargs(submodule_args)
202 |
203 | # control image input shape
204 | check_input_shape(input_shape, downsample_factor)
205 |
206 | backbone = Backbones.get_backbone(
207 | backbone_name,
208 | input_shape=input_shape,
209 | weights=encoder_weights,
210 | include_top=False,
211 | **kwargs
212 | )
213 |
214 | feature_layers = Backbones.get_feature_layers(backbone_name, n=3)
215 |
216 | if downsample_factor == 16:
217 | psp_layer_idx = feature_layers[0]
218 | elif downsample_factor == 8:
219 | psp_layer_idx = feature_layers[1]
220 | elif downsample_factor == 4:
221 | psp_layer_idx = feature_layers[2]
222 | else:
223 | raise ValueError('Unsupported factor - `{}`, Use 4, 8 or 16.'.format(downsample_factor))
224 |
225 | model = build_psp(
226 | backbone,
227 | psp_layer_idx,
228 | pooling_type=psp_pooling_type,
229 | conv_filters=psp_conv_filters,
230 | use_batchnorm=psp_use_batchnorm,
231 | final_upsampling_factor=downsample_factor,
232 | classes=classes,
233 | activation=activation,
234 | dropout=psp_dropout,
235 | )
236 |
237 | # lock encoder weights for fine-tuning
238 | if encoder_freeze:
239 | freeze_model(backbone, **kwargs)
240 |
241 | # loading model weights
242 | if weights is not None:
243 | model.load_weights(weights)
244 |
245 | return model
246 |
--------------------------------------------------------------------------------
/segmentation_models/models/unet.py:
--------------------------------------------------------------------------------
1 | from keras_applications import get_submodules_from_kwargs
2 |
3 | from ._common_blocks import Conv2dBn
4 | from ._utils import freeze_model, filter_keras_submodules
5 | from ..backbones.backbones_factory import Backbones
6 |
7 | backend = None
8 | layers = None
9 | models = None
10 | keras_utils = None
11 |
12 |
13 | # ---------------------------------------------------------------------
14 | # Utility functions
15 | # ---------------------------------------------------------------------
16 |
17 | def get_submodules():
18 | return {
19 | 'backend': backend,
20 | 'models': models,
21 | 'layers': layers,
22 | 'utils': keras_utils,
23 | }
24 |
25 |
26 | # ---------------------------------------------------------------------
27 | # Blocks
28 | # ---------------------------------------------------------------------
29 |
30 | def Conv3x3BnReLU(filters, use_batchnorm, name=None):
31 | kwargs = get_submodules()
32 |
33 | def wrapper(input_tensor):
34 | return Conv2dBn(
35 | filters,
36 | kernel_size=3,
37 | activation='relu',
38 | kernel_initializer='he_uniform',
39 | padding='same',
40 | use_batchnorm=use_batchnorm,
41 | name=name,
42 | **kwargs
43 | )(input_tensor)
44 |
45 | return wrapper
46 |
47 |
48 | def DecoderUpsamplingX2Block(filters, stage, use_batchnorm=False):
49 | up_name = 'decoder_stage{}_upsampling'.format(stage)
50 | conv1_name = 'decoder_stage{}a'.format(stage)
51 | conv2_name = 'decoder_stage{}b'.format(stage)
52 | concat_name = 'decoder_stage{}_concat'.format(stage)
53 |
54 | concat_axis = 3 if backend.image_data_format() == 'channels_last' else 1
55 |
56 | def wrapper(input_tensor, skip=None):
57 | x = layers.UpSampling2D(size=2, name=up_name)(input_tensor)
58 |
59 | if skip is not None:
60 | x = layers.Concatenate(axis=concat_axis, name=concat_name)([x, skip])
61 |
62 | x = Conv3x3BnReLU(filters, use_batchnorm, name=conv1_name)(x)
63 | x = Conv3x3BnReLU(filters, use_batchnorm, name=conv2_name)(x)
64 |
65 | return x
66 |
67 | return wrapper
68 |
69 |
70 | def DecoderTransposeX2Block(filters, stage, use_batchnorm=False):
71 | transp_name = 'decoder_stage{}a_transpose'.format(stage)
72 | bn_name = 'decoder_stage{}a_bn'.format(stage)
73 | relu_name = 'decoder_stage{}a_relu'.format(stage)
74 | conv_block_name = 'decoder_stage{}b'.format(stage)
75 | concat_name = 'decoder_stage{}_concat'.format(stage)
76 |
77 | concat_axis = bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1
78 |
79 | def layer(input_tensor, skip=None):
80 |
81 | x = layers.Conv2DTranspose(
82 | filters,
83 | kernel_size=(4, 4),
84 | strides=(2, 2),
85 | padding='same',
86 | name=transp_name,
87 | use_bias=not use_batchnorm,
88 | )(input_tensor)
89 |
90 | if use_batchnorm:
91 | x = layers.BatchNormalization(axis=bn_axis, name=bn_name)(x)
92 |
93 | x = layers.Activation('relu', name=relu_name)(x)
94 |
95 | if skip is not None:
96 | x = layers.Concatenate(axis=concat_axis, name=concat_name)([x, skip])
97 |
98 | x = Conv3x3BnReLU(filters, use_batchnorm, name=conv_block_name)(x)
99 |
100 | return x
101 |
102 | return layer
103 |
104 |
105 | # ---------------------------------------------------------------------
106 | # Unet Decoder
107 | # ---------------------------------------------------------------------
108 |
109 | def build_unet(
110 | backbone,
111 | decoder_block,
112 | skip_connection_layers,
113 | decoder_filters=(256, 128, 64, 32, 16),
114 | n_upsample_blocks=5,
115 | classes=1,
116 | activation='sigmoid',
117 | use_batchnorm=True,
118 | ):
119 | input_ = backbone.input
120 | x = backbone.output
121 |
122 | # extract skip connections
123 | skips = ([backbone.get_layer(name=i).output if isinstance(i, str)
124 | else backbone.get_layer(index=i).output for i in skip_connection_layers])
125 |
126 | # add center block if previous operation was maxpooling (for vgg models)
127 | if isinstance(backbone.layers[-1], layers.MaxPooling2D):
128 | x = Conv3x3BnReLU(512, use_batchnorm, name='center_block1')(x)
129 | x = Conv3x3BnReLU(512, use_batchnorm, name='center_block2')(x)
130 |
131 | # building decoder blocks
132 | for i in range(n_upsample_blocks):
133 |
134 | if i < len(skips):
135 | skip = skips[i]
136 | else:
137 | skip = None
138 |
139 | x = decoder_block(decoder_filters[i], stage=i, use_batchnorm=use_batchnorm)(x, skip)
140 |
141 | # model head (define number of output classes)
142 | x = layers.Conv2D(
143 | filters=classes,
144 | kernel_size=(3, 3),
145 | padding='same',
146 | use_bias=True,
147 | kernel_initializer='glorot_uniform',
148 | name='final_conv',
149 | )(x)
150 | x = layers.Activation(activation, name=activation)(x)
151 |
152 | # create keras model instance
153 | model = models.Model(input_, x)
154 |
155 | return model
156 |
157 |
158 | # ---------------------------------------------------------------------
159 | # Unet Model
160 | # ---------------------------------------------------------------------
161 |
162 | def Unet(
163 | backbone_name='vgg16',
164 | input_shape=(None, None, 3),
165 | classes=1,
166 | activation='sigmoid',
167 | weights=None,
168 | encoder_weights='imagenet',
169 | encoder_freeze=False,
170 | encoder_features='default',
171 | decoder_block_type='upsampling',
172 | decoder_filters=(256, 128, 64, 32, 16),
173 | decoder_use_batchnorm=True,
174 | **kwargs
175 | ):
176 | """ Unet_ is a fully convolution neural network for image semantic segmentation
177 |
178 | Args:
179 | backbone_name: name of classification model (without last dense layers) used as feature
180 | extractor to build segmentation model.
181 |         input_shape: shape of input data/image ``(H, W, C)``; in the general
182 |             case you do not need to set ``H`` and ``W``, just pass ``(None, None, C)`` to make the model
183 |             able to process images of any size, but ``H`` and ``W`` of input images should be divisible by factor ``32``.
184 |         classes: the number of output classes (output shape - ``(h, w, classes)``).
185 | activation: name of one of ``keras.activations`` for last model layer
186 | (e.g. ``sigmoid``, ``softmax``, ``linear``).
187 | weights: optional, path to model weights.
188 | encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
189 | encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable.
190 | encoder_features: a list of layer numbers or names starting from top of the model.
191 | Each of these layers will be concatenated with corresponding decoder block. If ``default`` is used
192 | layer names are taken from ``DEFAULT_SKIP_CONNECTIONS``.
193 | decoder_block_type: one of blocks with following layers structure:
194 |
195 | - `upsampling`: ``UpSampling2D`` -> ``Conv2D`` -> ``Conv2D``
196 |             - `transpose`: ``Conv2DTranspose`` -> ``Conv2D``
197 |
198 | decoder_filters: list of numbers of ``Conv2D`` layer filters in decoder blocks
199 | decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers
200 | is used.
201 |
202 | Returns:
203 | ``keras.models.Model``: **Unet**
204 |
205 | .. _Unet:
206 | https://arxiv.org/pdf/1505.04597
207 |
208 | """
209 |
210 | global backend, layers, models, keras_utils
211 | submodule_args = filter_keras_submodules(kwargs)
212 | backend, layers, models, keras_utils = get_submodules_from_kwargs(submodule_args)
213 |
214 | if decoder_block_type == 'upsampling':
215 | decoder_block = DecoderUpsamplingX2Block
216 | elif decoder_block_type == 'transpose':
217 | decoder_block = DecoderTransposeX2Block
218 | else:
219 | raise ValueError('Decoder block type should be in ("upsampling", "transpose"). '
220 | 'Got: {}'.format(decoder_block_type))
221 |
222 | backbone = Backbones.get_backbone(
223 | backbone_name,
224 | input_shape=input_shape,
225 | weights=encoder_weights,
226 | include_top=False,
227 | **kwargs,
228 | )
229 |
230 | if encoder_features == 'default':
231 | encoder_features = Backbones.get_feature_layers(backbone_name, n=4)
232 |
233 | model = build_unet(
234 | backbone=backbone,
235 | decoder_block=decoder_block,
236 | skip_connection_layers=encoder_features,
237 | decoder_filters=decoder_filters,
238 | classes=classes,
239 | activation=activation,
240 | n_upsample_blocks=len(decoder_filters),
241 | use_batchnorm=decoder_use_batchnorm,
242 | )
243 |
244 | # lock encoder weights for fine-tuning
245 | if encoder_freeze:
246 | freeze_model(backbone, **kwargs)
247 |
248 | # loading model weights
249 | if weights is not None:
250 | model.load_weights(weights)
251 |
252 | return model
253 |
--------------------------------------------------------------------------------
/segmentation_models/utils.py:
--------------------------------------------------------------------------------
1 | """ Utility functions for segmentation models """
2 |
3 | from keras_applications import get_submodules_from_kwargs
4 | from . import inject_global_submodules
5 |
6 |
7 | def set_trainable(model, recompile=True, **kwargs):
8 | """Set all layers of model trainable and recompile it
9 |
10 | Note:
11 | Model is recompiled using same optimizer, loss and metrics::
12 |
13 | model.compile(
14 | model.optimizer,
15 | loss=model.loss,
16 | metrics=model.metrics,
17 | loss_weights=model.loss_weights,
18 | sample_weight_mode=model.sample_weight_mode,
19 | weighted_metrics=model.weighted_metrics,
20 | )
21 |
22 | Args:
23 | model (``keras.models.Model``): instance of keras model
24 |
25 | """
26 | for layer in model.layers:
27 | layer.trainable = True
28 |
29 | if recompile:
30 | model.compile(
31 | model.optimizer,
32 | loss=model.loss,
33 | metrics=model.metrics,
34 | loss_weights=model.loss_weights,
35 | sample_weight_mode=model.sample_weight_mode,
36 | weighted_metrics=model.weighted_metrics,
37 | )
38 |
39 |
40 | @inject_global_submodules
41 | def set_regularization(
42 | model,
43 | kernel_regularizer=None,
44 | bias_regularizer=None,
45 | activity_regularizer=None,
46 | beta_regularizer=None,
47 | gamma_regularizer=None,
48 | **kwargs
49 | ):
50 | """Set regularizers to all layers
51 |
52 | Note:
53 | Returned model's config is updated correctly
54 |
55 | Args:
56 | model (``keras.models.Model``): instance of keras model
57 |         kernel_regularizer(``regularizer``): regularizer of kernels
58 | bias_regularizer(``regularizer``): regularizer of bias
59 | activity_regularizer(``regularizer``): regularizer of activity
60 | gamma_regularizer(``regularizer``): regularizer of gamma of BatchNormalization
61 | beta_regularizer(``regularizer``): regularizer of beta of BatchNormalization
62 |
63 | Return:
64 | out (``Model``): config updated model
65 | """
66 | _, _, models, _ = get_submodules_from_kwargs(kwargs)
67 |
68 | for layer in model.layers:
69 | # set kernel_regularizer
70 | if kernel_regularizer is not None and hasattr(layer, 'kernel_regularizer'):
71 | layer.kernel_regularizer = kernel_regularizer
72 | # set bias_regularizer
73 | if bias_regularizer is not None and hasattr(layer, 'bias_regularizer'):
74 | layer.bias_regularizer = bias_regularizer
75 | # set activity_regularizer
76 | if activity_regularizer is not None and hasattr(layer, 'activity_regularizer'):
77 | layer.activity_regularizer = activity_regularizer
78 |
79 | # set beta and gamma of BN layer
80 | if beta_regularizer is not None and hasattr(layer, 'beta_regularizer'):
81 | layer.beta_regularizer = beta_regularizer
82 |
83 | if gamma_regularizer is not None and hasattr(layer, 'gamma_regularizer'):
84 | layer.gamma_regularizer = gamma_regularizer
85 |
86 | out = models.model_from_json(model.to_json())
87 | out.set_weights(model.get_weights())
88 |
89 | return out
90 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # Note: To use the 'upload' functionality of this file, you must:
5 | # $ pip install twine
6 |
7 | import io
8 | import os
9 | import sys
10 | from shutil import rmtree
11 |
12 | from setuptools import find_packages, setup, Command
13 |
14 | # Package meta-data.
15 | NAME = 'segmentation_models'
16 | DESCRIPTION = 'Image segmentation models with pre-trained backbones for Keras.'
17 | URL = 'https://github.com/qubvel/segmentation_models'
18 | EMAIL = 'qubvel@gmail.com'
19 | AUTHOR = 'Pavel Yakubovskiy'
20 | REQUIRES_PYTHON = '>=3.0.0'
21 | VERSION = None
22 |
23 | # The rest you shouldn't have to touch too much :)
24 | # ------------------------------------------------
25 | # Except, perhaps the License and Trove Classifiers!
26 | # If you do change the License, remember to change the Trove Classifier for that!
27 |
28 | here = os.path.abspath(os.path.dirname(__file__))
29 |
30 | # What packages are required for this module to be executed?
31 | try:
32 | with open(os.path.join(here, 'requirements.txt'), encoding='utf-8') as f:
33 | REQUIRED = f.read().split('\n')
34 | except FileNotFoundError:
35 | REQUIRED = []
36 |
37 | # What packages are optional?
38 | EXTRAS = {
39 | 'tests': ['pytest', 'scikit-image'],
40 | }
41 |
42 | # Import the README and use it as the long-description.
43 | # Note: this will only work if 'README.rst' is present in your MANIFEST.in file!
44 | try:
45 |     with io.open(os.path.join(here, 'README.rst'), encoding='utf-8') as f:
46 | long_description = '\n' + f.read()
47 | except FileNotFoundError:
48 | long_description = DESCRIPTION
49 |
50 | # Load the package's __version__.py module as a dictionary.
51 | about = {}
52 | if not VERSION:
53 | with open(os.path.join(here, NAME, '__version__.py')) as f:
54 | exec(f.read(), about)
55 | else:
56 | about['__version__'] = VERSION
57 |
58 |
59 | class UploadCommand(Command):
60 | """Support setup.py upload."""
61 |
62 | description = 'Build and publish the package.'
63 | user_options = []
64 |
65 | @staticmethod
66 | def status(s):
67 | """Prints things in bold."""
68 |         print('\033[1m{0}\033[0m'.format(s))
69 |
70 | def initialize_options(self):
71 | pass
72 |
73 | def finalize_options(self):
74 | pass
75 |
76 | def run(self):
77 | try:
78 | self.status('Removing previous builds...')
79 | rmtree(os.path.join(here, 'dist'))
80 | except OSError:
81 | pass
82 |
83 | self.status('Building Source and Wheel (universal) distribution...')
84 | os.system('{0} setup.py sdist bdist_wheel --universal'.format(sys.executable))
85 |
86 | self.status('Uploading the package to PyPI via Twine...')
87 | os.system('twine upload dist/*')
88 |
89 | self.status('Pushing git tags...')
90 | os.system('git tag v{0}'.format(about['__version__']))
91 | os.system('git push --tags')
92 |
93 | sys.exit()
94 |
95 |
96 | # Where the magic happens:
97 | setup(
98 | name=NAME,
99 | version=about['__version__'],
100 | description=DESCRIPTION,
101 | long_description=long_description,
102 | long_description_content_type='text/x-rst',
103 | author=AUTHOR,
104 | author_email=EMAIL,
105 | python_requires=REQUIRES_PYTHON,
106 | url=URL,
107 | packages=find_packages(exclude=('tests', 'docs', 'images', 'examples')),
108 | # If your package is a single module, use this instead of 'packages':
109 | # py_modules=['mypackage'],
110 |
111 | # entry_points={
112 | # 'console_scripts': ['mycli=mymodule:cli'],
113 | # },
114 | install_requires=REQUIRED,
115 | extras_require=EXTRAS,
116 | include_package_data=True,
117 | license='MIT',
118 | classifiers=[
119 | # Trove classifiers
120 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
121 | 'License :: OSI Approved :: MIT License',
122 | 'Programming Language :: Python',
123 | 'Programming Language :: Python :: 3',
124 | 'Programming Language :: Python :: Implementation :: CPython',
125 | 'Programming Language :: Python :: Implementation :: PyPy'
126 | ],
127 | # $ setup.py publish support.
128 | cmdclass={
129 | 'upload': UploadCommand,
130 | },
131 | )
--------------------------------------------------------------------------------
/tests/test_metrics.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 |
4 | import segmentation_models as sm
5 | from segmentation_models.metrics import IOUScore, FScore
6 | from segmentation_models.losses import JaccardLoss, DiceLoss
7 |
8 | if sm.framework() == sm._TF_KERAS_FRAMEWORK_NAME:
9 | from tensorflow import keras
10 | elif sm.framework() == sm._KERAS_FRAMEWORK_NAME:
11 | import keras
12 | else:
13 | raise ValueError('Incorrect framework {}'.format(sm.framework()))
14 |
15 | METRICS = [
16 | IOUScore,
17 | FScore,
18 | ]
19 |
20 | LOSSES = [
21 | JaccardLoss,
22 | DiceLoss,
23 | ]
24 |
25 | GT0 = np.array(
26 | [
27 | [0, 0, 0],
28 | [0, 0, 0],
29 | [0, 0, 0],
30 | ],
31 | dtype='float32',
32 | )
33 |
34 | GT1 = np.array(
35 | [
36 | [1, 1, 0],
37 | [1, 1, 0],
38 | [0, 0, 0],
39 | ],
40 | dtype='float32',
41 | )
42 |
43 | PR1 = np.array(
44 | [
45 | [0, 0, 0],
46 | [1, 1, 0],
47 | [0, 0, 0],
48 | ],
49 | dtype='float32',
50 | )
51 |
52 | PR2 = np.array(
53 | [
54 | [0, 0, 0],
55 | [1, 1, 0],
56 | [1, 1, 0],
57 | ],
58 | dtype='float32',
59 | )
60 |
61 | PR3 = np.array(
62 | [
63 | [0, 0, 0],
64 | [0, 0, 0],
65 | [1, 0, 0],
66 | ],
67 | dtype='float32',
68 | )
69 |
70 | IOU_CASES = (
71 |
72 | (GT0, GT0, 1.00),
73 | (GT1, GT1, 1.00),
74 |
75 | (GT0, PR1, 0.00),
76 | (GT0, PR2, 0.00),
77 | (GT0, PR3, 0.00),
78 |
79 | (GT1, PR1, 0.50),
80 | (GT1, PR2, 1. / 3.),
81 | (GT1, PR3, 0.00),
82 | )
83 |
84 | F1_CASES = (
85 |
86 | (GT0, GT0, 1.00),
87 | (GT1, GT1, 1.00),
88 |
89 | (GT0, PR1, 0.00),
90 | (GT0, PR2, 0.00),
91 | (GT0, PR3, 0.00),
92 |
93 | (GT1, PR1, 2. / 3.),
94 | (GT1, PR2, 0.50),
95 | (GT1, PR3, 0.00),
96 | )
97 |
98 | F2_CASES = (
99 |
100 | (GT0, GT0, 1.00),
101 | (GT1, GT1, 1.00),
102 |
103 | (GT0, PR1, 0.00),
104 | (GT0, PR2, 0.00),
105 | (GT0, PR3, 0.00),
106 |
107 | (GT1, PR1, 5. / 9.),
108 | (GT1, PR2, 0.50),
109 | (GT1, PR3, 0.00),
110 | )
111 |
112 |
113 | def _to_4d(x):
114 | if x.ndim == 2:
115 | return x[None, :, :, None]
116 | elif x.ndim == 3:
117 | return x[None, :, :]
118 |
119 |
120 | def _add_4d(x):
121 | if x.ndim == 3:
122 | return x[..., None]
123 |
124 |
125 | @pytest.mark.parametrize('case', IOU_CASES)
126 | def test_iou_metric(case):
127 | gt, pr, res = case
128 | gt = _to_4d(gt)
129 | pr = _to_4d(pr)
130 | iou_score = IOUScore(smooth=10e-12)
131 | score = keras.backend.eval(iou_score(gt, pr))
132 | assert np.allclose(score, res)
133 |
134 |
135 | @pytest.mark.parametrize('case', IOU_CASES)
136 | def test_jaccard_loss(case):
137 | gt, pr, res = case
138 | gt = _to_4d(gt)
139 | pr = _to_4d(pr)
140 | jaccard_loss = JaccardLoss(smooth=10e-12)
141 | score = keras.backend.eval(jaccard_loss(gt, pr))
142 | assert np.allclose(score, 1 - res)
143 |
144 |
145 | def _test_f_metric(case, beta=1):
146 | gt, pr, res = case
147 | gt = _to_4d(gt)
148 | pr = _to_4d(pr)
149 | f_score = FScore(beta=beta, smooth=10e-12)
150 | score = keras.backend.eval(f_score(gt, pr))
151 | assert np.allclose(score, res)
152 |
153 |
154 | @pytest.mark.parametrize('case', F1_CASES)
155 | def test_f1_metric(case):
156 | _test_f_metric(case, beta=1)
157 |
158 |
159 | @pytest.mark.parametrize('case', F2_CASES)
160 | def test_f2_metric(case):
161 | _test_f_metric(case, beta=2)
162 |
163 |
164 | @pytest.mark.parametrize('case', F1_CASES)
165 | def test_dice_loss(case):
166 | gt, pr, res = case
167 | gt = _to_4d(gt)
168 | pr = _to_4d(pr)
169 | dice_loss = DiceLoss(smooth=10e-12)
170 | score = keras.backend.eval(dice_loss(gt, pr))
171 | assert np.allclose(score, 1 - res)
172 |
173 |
174 | @pytest.mark.parametrize('func', METRICS + LOSSES)
175 | def test_per_image(func):
176 | gt = np.stack([GT0, GT1], axis=0)
177 | pr = np.stack([PR1, PR2], axis=0)
178 |
179 | gt = _add_4d(gt)
180 | pr = _add_4d(pr)
181 |
182 | # calculate score per image
183 | score_1 = keras.backend.eval(func(per_image=True, smooth=10e-12)(gt, pr))
184 | score_2 = np.mean([
185 | keras.backend.eval(func(smooth=10e-12)(_to_4d(GT0), _to_4d(PR1))),
186 | keras.backend.eval(func(smooth=10e-12)(_to_4d(GT1), _to_4d(PR2))),
187 | ])
188 | assert np.allclose(score_1, score_2)
189 |
190 |
191 | @pytest.mark.parametrize('func', METRICS + LOSSES)
192 | def test_per_batch(func):
193 | gt = np.stack([GT0, GT1], axis=0)
194 | pr = np.stack([PR1, PR2], axis=0)
195 |
196 | gt = _add_4d(gt)
197 | pr = _add_4d(pr)
198 |
199 | # calculate score per batch
200 | score_1 = keras.backend.eval(func(per_image=False, smooth=10e-12)(gt, pr))
201 |
202 | gt1 = np.concatenate([GT0, GT1], axis=0)
203 | pr1 = np.concatenate([PR1, PR2], axis=0)
204 | score_2 = keras.backend.eval(func(per_image=True, smooth=10e-12)(_to_4d(gt1), _to_4d(pr1)))
205 |
206 | assert np.allclose(score_1, score_2)
207 |
208 |
209 | @pytest.mark.parametrize('case', IOU_CASES)
210 | def test_threshold_iou(case):
211 | gt, pr, res = case
212 | gt = _to_4d(gt)
213 | pr = _to_4d(pr) * 0.51
214 | iou_score = IOUScore(smooth=10e-12, threshold=0.5)
215 | score = keras.backend.eval(iou_score(gt, pr))
216 | assert np.allclose(score, res)
217 |
218 |
219 | if __name__ == '__main__':
220 | pytest.main([__file__])
221 |
--------------------------------------------------------------------------------
/tests/test_models.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pytest
3 | import random
4 | import six
5 | import numpy as np
6 |
7 | import segmentation_models as sm
8 | from segmentation_models import Unet
9 | from segmentation_models import Linknet
10 | from segmentation_models import PSPNet
11 | from segmentation_models import FPN
12 | from segmentation_models import get_available_backbone_names
13 |
14 | if sm.framework() == sm._TF_KERAS_FRAMEWORK_NAME:
15 | from tensorflow import keras
16 | elif sm.framework() == sm._KERAS_FRAMEWORK_NAME:
17 | import keras
18 | else:
19 | raise ValueError('Incorrect framework {}'.format(sm.framework()))
20 |
21 | def get_backbones():
22 | is_travis = os.environ.get('TRAVIS', False)
23 | exclude = ['senet154', 'efficientnetb6', 'efficientnetb7']
24 | backbones = get_available_backbone_names()
25 |
26 | if is_travis:
27 | backbones = [b for b in backbones if b not in exclude]
28 | return backbones
29 |
30 |
31 | BACKBONES = get_backbones()
32 |
33 |
34 | def _select_names(names):
35 | is_full = os.environ.get('FULL_TEST', False)
36 | if not is_full:
37 | return [random.choice(names)]
38 | else:
39 | return names
40 |
41 |
42 | def keras_test(func):
43 | """Function wrapper to clean up after TensorFlow tests.
44 | # Arguments
45 | func: test function to clean up after.
46 | # Returns
47 | A function wrapping the input function.
48 | """
49 | @six.wraps(func)
50 | def wrapper(*args, **kwargs):
51 | output = func(*args, **kwargs)
52 | keras.backend.clear_session()
53 | return output
54 | return wrapper
55 |
56 |
57 | @keras_test
58 | def _test_none_shape(model_fn, backbone, *args, **kwargs):
59 |
60 | # define number of channels
61 | input_shape = kwargs.get('input_shape', None)
62 | n_channels = 3 if input_shape is None else input_shape[-1]
63 |
64 | # create test sample
65 | x = np.ones((1, 32, 32, n_channels))
66 |
67 | # define model and process sample
68 | model = model_fn(backbone, *args, **kwargs)
69 | y = model.predict(x)
70 |
71 | # check output dimensions
72 | assert x.shape[:-1] == y.shape[:-1]
73 |
74 |
75 | @keras_test
76 | def _test_shape(model_fn, backbone, input_shape, *args, **kwargs):
77 |
78 | # create test sample
79 | x = np.ones((1, *input_shape))
80 |
81 | # define model and process sample
82 | model = model_fn(backbone, input_shape=input_shape, *args, **kwargs)
83 | y = model.predict(x)
84 |
85 | # check output dimensions
86 | assert x.shape[:-1] == y.shape[:-1]
87 |
88 |
89 | @pytest.mark.parametrize('backbone', _select_names(BACKBONES))
90 | def test_unet(backbone):
91 | _test_none_shape(
92 | Unet, backbone, encoder_weights=None)
93 |
94 | _test_none_shape(
95 | Unet, backbone, encoder_weights='imagenet')
96 |
97 | _test_shape(
98 | Unet, backbone, input_shape=(256, 256, 4), encoder_weights=None)
99 |
100 |
101 | @pytest.mark.parametrize('backbone', _select_names(BACKBONES))
102 | def test_linknet(backbone):
103 | _test_none_shape(
104 | Linknet, backbone, encoder_weights=None)
105 |
106 | _test_none_shape(
107 | Linknet, backbone, encoder_weights='imagenet')
108 |
109 | _test_shape(
110 | Linknet, backbone, input_shape=(256, 256, 4), encoder_weights=None)
111 |
112 |
113 | @pytest.mark.parametrize('backbone', _select_names(BACKBONES))
114 | def test_pspnet(backbone):
115 |
116 | _test_shape(
117 | PSPNet, backbone, input_shape=(384, 384, 4), encoder_weights=None)
118 |
119 | _test_shape(
120 | PSPNet, backbone, input_shape=(384, 384, 3), encoder_weights='imagenet')
121 |
122 |
123 | @pytest.mark.parametrize('backbone', _select_names(BACKBONES))
124 | def test_fpn(backbone):
125 | _test_none_shape(
126 | FPN, backbone, encoder_weights=None)
127 |
128 | _test_none_shape(
129 | FPN, backbone, encoder_weights='imagenet')
130 |
131 | _test_shape(
132 | FPN, backbone, input_shape=(256, 256, 4), encoder_weights=None)
133 |
134 |
135 | if __name__ == '__main__':
136 | pytest.main([__file__])
137 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 |
4 | import segmentation_models as sm
5 | from segmentation_models.utils import set_regularization
6 | from segmentation_models import Unet
7 |
8 | if sm.framework() == sm._TF_KERAS_FRAMEWORK_NAME:
9 | from tensorflow import keras
10 | elif sm.framework() == sm._KERAS_FRAMEWORK_NAME:
11 | import keras
12 | else:
13 | raise ValueError('Incorrect framework {}'.format(sm.framework()))
14 |
15 | X1 = np.ones((1, 32, 32, 3))
16 | Y1 = np.ones((1, 32, 32, 1))
17 | MODEL = Unet
18 | BACKBONE = 'resnet18'
19 | CASE = (
20 |
21 | (X1, Y1, MODEL, BACKBONE),
22 | )
23 |
24 |
25 | def _test_regularizer(model, reg_model, x, y):
26 |
27 | def zero_loss(gt, pr):
28 | return pr * 0
29 |
30 | model.compile('Adam', loss=zero_loss, metrics=['binary_accuracy'])
31 | reg_model.compile('Adam', loss=zero_loss, metrics=['binary_accuracy'])
32 |
33 | loss_1, _ = model.test_on_batch(x, y)
34 | loss_2, _ = reg_model.test_on_batch(x, y)
35 |
36 | assert loss_1 == 0
37 | assert loss_2 > 0
38 |
39 | keras.backend.clear_session()
40 |
41 |
42 | @pytest.mark.parametrize('case', CASE)
43 | def test_kernel_reg(case):
44 |     x, y, model_fn, backbone = case
45 |
46 | l1_reg = keras.regularizers.l1(0.1)
47 | model = model_fn(backbone)
48 | reg_model = set_regularization(model, kernel_regularizer=l1_reg)
49 | _test_regularizer(model, reg_model, x, y)
50 |
51 | l2_reg = keras.regularizers.l2(0.1)
52 | model = model_fn(backbone, encoder_weights=None)
53 | reg_model = set_regularization(model, kernel_regularizer=l2_reg)
54 | _test_regularizer(model, reg_model, x, y)
55 |
56 |
57 | """
58 | Note:
59 | backbone resnet18 uses BN after each conv layer, so no bias is used in these conv layers;
60 | skip the bias regularizer test
61 |
62 | @pytest.mark.parametrize('case', CASE)
63 | def test_bias_reg(case):
64 | x, y, model_fn, backbone = case
65 |
66 | l1_reg = regularizers.l1(1)
67 | model = model_fn(backbone)
68 | reg_model = set_regularization(model, bias_regularizer=l1_reg)
69 | _test_regularizer(model, reg_model, x, y)
70 |
71 | l2_reg = regularizers.l2(1)
72 | model = model_fn(backbone)
73 | reg_model = set_regularization(model, bias_regularizer=l2_reg)
74 | _test_regularizer(model, reg_model, x, y)
75 | """
76 |
77 |
78 | @pytest.mark.parametrize('case', CASE)
79 | def test_bn_reg(case):
80 |     x, y, model_fn, backbone = case
81 |
82 | l1_reg = keras.regularizers.l1(1)
83 | model = model_fn(backbone)
84 | reg_model = set_regularization(model, gamma_regularizer=l1_reg)
85 | _test_regularizer(model, reg_model, x, y)
86 |
87 | model = model_fn(backbone)
88 | reg_model = set_regularization(model, beta_regularizer=l1_reg)
89 | _test_regularizer(model, reg_model, x, y)
90 |
91 | l2_reg = keras.regularizers.l2(1)
92 | model = model_fn(backbone)
93 | reg_model = set_regularization(model, gamma_regularizer=l2_reg)
94 | _test_regularizer(model, reg_model, x, y)
95 |
96 | model = model_fn(backbone)
97 | reg_model = set_regularization(model, beta_regularizer=l2_reg)
98 | _test_regularizer(model, reg_model, x, y)
99 |
100 |
101 | @pytest.mark.parametrize('case', CASE)
102 | def test_activity_reg(case):
103 |     x, y, model_fn, backbone = case
104 |
105 | l2_reg = keras.regularizers.l2(1)
106 | model = model_fn(backbone)
107 | reg_model = set_regularization(model, activity_regularizer=l2_reg)
108 | _test_regularizer(model, reg_model, x, y)
109 |
110 |
111 | if __name__ == '__main__':
112 | pytest.main([__file__])
113 |
--------------------------------------------------------------------------------