├── LICENSE.txt
├── README.md
├── __init__.py
├── examples
│   ├── README.md
│   ├── TASM_Example_1.ipynb
│   ├── TASM_Example_2.ipynb
│   ├── TASM_Example_3.ipynb
│   └── TASM_Example_4.ipynb
├── images
│   ├── README.md
│   └── tasm logo big white bg.PNG
├── requirements_macos.rtf
├── requirements_windows.txt
├── setup.py
└── tensorflow_advanced_segmentation_models
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-37.pyc
    │   ├── losses.cpython-37.pyc
    │   └── metrics.cpython-37.pyc
    ├── __version__.py
    ├── backbones
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-37.pyc
    │   │   └── tf_backbones.cpython-37.pyc
    │   └── tf_backbones.py
    ├── base
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-37.pyc
    │   │   ├── functional.cpython-37.pyc
    │   │   └── objects.cpython-37.pyc
    │   ├── functional.py
    │   └── objects.py
    ├── losses.py
    ├── metrics.py
    └── models
        ├── ACFNet.py
        ├── ASPOCRNet.py
        ├── CFNet.py
        ├── DANet.py
        ├── DeepLab.py
        ├── DeepLabV3.py
        ├── DeepLabV3plus.py
        ├── FCN.py
        ├── FPNet.py
        ├── HRNetOCR.py
        ├── OCNet.py
        ├── PSPNet.py
        ├── SpatialOCRNet.py
        ├── UNet.py
        ├── __init__.py
        ├── __pycache__
        │   ├── DANet.cpython-37.pyc
        │   ├── DeepLab.cpython-37.pyc
        │   ├── DeepLabV3.cpython-37.pyc
        │   ├── DeepLabV3plus.cpython-37.pyc
        │   ├── FCN.cpython-37.pyc
        │   ├── FPNet.cpython-37.pyc
        │   ├── OCNet.cpython-37.pyc
        │   ├── PSPNet.cpython-37.pyc
        │   ├── UNet.cpython-37.pyc
        │   ├── __init__.cpython-37.pyc
        │   └── _custom_layers_and_blocks.cpython-37.pyc
        └── _custom_layers_and_blocks.py
/LICENSE.txt: -------------------------------------------------------------------------------- 1 | """ 2 | The MIT License 3 | 4 | Copyright (c) 2020, Jan-Marcel Kezmann 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject
to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in 14 | all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 22 | THE SOFTWARE. 23 | """ 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow Advanced Segmentation Models 2 | A Python Library for High-Level Semantic Segmentation Models. 3 | 4 |
5 |
6 |
Since the breakthrough of Deep Learning, Computer Vision has been one of the core fields that researchers all over the world work on to create better models every day. One Computer Vision area that has received huge attention in the last couple of years is Semantic Segmentation. The task of segmenting every pixel of a given image has led to the invention of many great models, from the classical U-Net to ever more complex neural network structures. But even though many new algorithms have been developed, the availability of easy-to-use open source libraries with High-Level APIs that make the technology accessible to everyone still lags far behind the large amount of research that is published continuously.
10 |Inspired by qubvel's segmentation_models, this repository builds upon his work and extends it with a variety of recently developed models that achieved great results on Cityscapes, PASCAL VOC 2012, PASCAL Context, ADE20K and many more datasets.
11 |To date, the library contains 14 different Semantic Segmentation Model Architectures for multi-class semantic segmentation, as well as many backbones pretrained on ImageNet. An important new feature is the upgrade to TensorFlow 2.x, including the use of the model subclassing feature to build customized segmentation models. Furthermore, the library is now compatible with all major platforms, which means that TASM can run on Windows, Linux and MacOS.
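Each pretrained backbone is created by `create_base_model` in backbones/tf_backbones.py, which validates the input size, builds the matching `tf.keras.applications` model and returns it together with five intermediate layers used as feature taps. The following is a hypothetical, TensorFlow-free sketch of that dispatch written as a lookup table instead of an if/elif chain. `BACKBONES` and `lookup_backbone` are illustrative names, not part of the library, and only four of the 25 supported backbones are listed; their constructor names, minimum input sizes and tap-layer names are copied from tf_backbones.py.

```python
# Hypothetical sketch: table-driven version of the name dispatch in create_base_model().
# Values are copied from backbones/tf_backbones.py; only 4 of 25 backbones are listed.
BACKBONES = {
    # name: (tf.keras.applications constructor, minimum height/width, 5 tap layers)
    "resnet50": ("ResNet50", 32, ["conv1_relu", "conv2_block3_out", "conv3_block4_out",
                                  "conv4_block6_out", "conv5_block3_out"]),
    "mobilenetv2": ("MobileNetV2", 32, ["block_1_expand_relu", "block_3_expand_relu",
                                        "block_6_expand_relu", "block_13_expand_relu",
                                        "out_relu"]),
    "vgg16": ("VGG16", 32, ["block2_conv2", "block3_conv3", "block4_conv3",
                            "block5_conv3", "block5_pool"]),
    "xception": ("Xception", 71, ["block2_sepconv2_act", "block3_sepconv2_act",
                                  "block4_sepconv2_act", "block13_sepconv2_act",
                                  "block14_sepconv2_act"]),
}


def lookup_backbone(name, height, width, channels=3):
    """Validate the input size and return (constructor name, input shape, tap layers)."""
    if not all(isinstance(v, int) for v in (height, width, channels)):
        raise TypeError("'height', 'width' and 'channels' need to be of type 'int'")
    if channels <= 0:
        raise ValueError(f"'channels' must be greater than or equal to 1, but {channels} was given")
    key = name.lower()  # case-insensitive matching, as in create_base_model
    if key not in BACKBONES:
        raise ValueError(f"'name' should be one of {sorted(BACKBONES)}")
    constructor, min_size, layer_names = BACKBONES[key]
    if height < min_size or width < min_size:
        raise ValueError(f"Parameters 'height' and 'width' should not be smaller than {min_size}.")
    return constructor, [height, width, channels], list(layer_names)


constructor, input_shape, taps = lookup_backbone("ResNet50", 256, 256)
print(constructor, input_shape, taps[-1])  # ResNet50 [256, 256, 3] conv5_block3_out
```

With TensorFlow installed, the constructor name could be resolved with `getattr(tf.keras.applications, constructor)`, and the taps collected with `base_model.get_layer(layer_name).output`, which mirrors what `create_base_model` does. Keeping the per-backbone data in one table would make adding a backbone a one-entry change.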
12 | 13 | ### Main Library Features 14 | - High Level API 15 | - 14 Segmentation Model Architectures for multi-class semantic segmentation 16 | - **New:** HRNet + OCR Model 17 | - Many pretrained backbones for each architecture 18 | - Many useful segmentation losses (Dice, Focal, Tversky, Jaccard and many combinations of them) 19 | - **New:** Models can be used as Subclassed or Functional Models 20 | - **New:** TASM now works on all platforms, i.e. Windows, Linux and MacOS with Intel or Apple Silicon chips 21 | 22 | ## Table of Contents 23 | 24 | - [Installation and Setup](#installation-and-setup) 25 | - [Training Pipeline](#training-pipeline) 26 | - [Examples](#examples) 27 | - [Models and Backbones](#models-and-backbones) 28 | - [Citing](#citing) 29 | - [License](#license) 30 | - [References](#references) 31 | 32 | ## Installation and Setup 33 | 34 |To get the repository running, check the following requirements.
35 | 36 | **Requirements** 37 | **Windows or Linux** 38 | 1) Python 3.6 or higher 39 | 2) tensorflow >= 2.3.0 (>= 2.0.0 is sufficient if no EfficientNet backbone is used) 40 | 3) numpy 41 | 4) matplotlib 42 | 43 | **MacOS** 44 | 1) Python 3.9 or higher 45 | 2) tensorflow-macos >= 2.5.0 46 | 3) numpy >= 1.21.0 47 | 4) matplotlib 48 | 49 |Then run the following command to clone the git repository.
50 | 51 | **Clone Repository** 52 | 53 | $ git clone https://github.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models.git 54 | 55 | or directly install it:
Thank you to the authors of all the papers that made this repository possible, and especially to Pavel Yakubovskiy for his initial segmentation models repository.
183 | 184 | - Pavel Yakubovskiy, Segmentation Models, 2019, GitHub, GitHub Repository, https://github.com/qubvel/segmentation_models 185 | - Liang-Chieh Chen and Yukun Zhu and George Papandreou and Florian Schroff and Hartwig Adam, Tensorflow Models DeepLab, 2020, GitHub, GitHub Repository, https://github.com/tensorflow/models/tree/master/research/deeplab 186 | - Fu, Jun and Liu, Jing and Tian, Haijie and Li, Yong and Bao, Yongjun and Fang, Zhiwei and Lu, Hanqing, DANet, 2020, GitHub, GitHub Repository, https://github.com/junfu1115/DANet 187 | - Yuhui Yuan and Jingdong Wang, openseg.OCN.pytorch, 2020, GitHub, GitHub Repository, https://github.com/openseg-group/OCNet.pytorch 188 | - Yuhui Yuan and Xilin Chen and Jingdong Wang, openseg.pytorch, 2020, GitHub, GitHub Repository, https://github.com/openseg-group/openseg.pytorch 189 | - Fan Zhang, Yanqin Chen, Zhihang Li, Zhibin Hong, Jingtuo Liu, Feifei Ma, Junyu Han, Errui Ding, ACFNet, 2020, GitHub, GitHub Repository, https://github.com/zrl4836/ACFNet 190 | - Xie Jingyi, Ke Sun, Jingdong Wang, RainbowSecret, HRNet-Semantic-Segmentation, 2021, GitHub, GitHub Repository, https://github.com/HRNet/HRNet-Semantic-Segmentation 191 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .tensorflow_advanced_segmentation_models import * -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples Section 2 | 3 | - [Jupyter Notebook] Multi-class (3 classes) segmentation (sky, building, background) on CamVid dataset here 4 | - [Jupyter Notebook] Multi-class (11 classes) segmentation on CamVid dataset here 5 | - [Jupyter Notebook] Multi-class (11 classes) segmentation on CamVid dataset with a custom training loop here 6 | - [Jupyter Notebook] Two-class (2 classes) segmentation on Caltech-Birds-2010
dataset here 7 | -------------------------------------------------------------------------------- /images/README.md: -------------------------------------------------------------------------------- 1 | ### Here you can find some logo and icon samples. 2 | -------------------------------------------------------------------------------- /images/tasm logo big white bg.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models/3714839ee49759b26e2b0ae3d3a0aa37b00df962/images/tasm logo big white bg.PNG -------------------------------------------------------------------------------- /requirements_macos.rtf: -------------------------------------------------------------------------------- 1 | tensorflow-macos>=2.5.0 2 | numpy>=1.21.0 3 | matplotlib 4 | -------------------------------------------------------------------------------- /requirements_windows.txt: -------------------------------------------------------------------------------- 1 | tensorflow>=2.0.0 2 | numpy 3 | matplotlib 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import sys 4 | import setuptools 5 | 6 | # Package meta-data. 7 | NAME = 'tensorflow_advanced_segmentation_models' 8 | DESCRIPTION = 'A Python Library for High-Level Semantic Segmentation Models based on TensorFlow and Keras.' 
9 | URL = 'https://github.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models' 10 | EMAIL = 'jankezmann@t-online.de' 11 | AUTHOR = 'Jan Marcel Kezmann' 12 | REQUIRES_PYTHON = '>=3.6.0' 13 | VERSION = None 14 | 15 | here = os.path.abspath(os.path.dirname(__file__)) 16 | 17 | try: 18 | if sys.platform == 'darwin': 19 | with open(os.path.join(here, 'requirements_macos.rtf'), encoding='utf-8') as f:  # must match the requirements file shipped in the repository 20 | REQUIRED = [line.strip() for line in f if line.strip()] 21 | else: 22 | with open(os.path.join(here, 'requirements_windows.txt'), encoding='utf-8') as f: 23 | REQUIRED = [line.strip() for line in f if line.strip()] 24 | except OSError:  # requirements file missing or unreadable 25 | REQUIRED = [] 26 | 27 | # Import the README and use it as the long-description. 28 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file! 29 | try: 30 | with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f: 31 | long_description = '\n' + f.read() 32 | except FileNotFoundError: 33 | long_description = DESCRIPTION 34 | 35 | # Load the package's __version__.py module as a dictionary.
36 | about = {} 37 | if not VERSION: 38 | with open(os.path.join(here, NAME, '__version__.py')) as f: 39 | exec(f.read(), about) 40 | else: 41 | about['__version__'] = VERSION 42 | 43 | 44 | # with open("README.md", "r") as fh: 45 | # long_description = fh.read() 46 | 47 | setuptools.setup( 48 | name=NAME, 49 | version=about["__version__"], 50 | author=AUTHOR, 51 | author_email=EMAIL, 52 | description=DESCRIPTION, 53 | long_description=long_description, 54 | long_description_content_type="text/markdown", 55 | url=URL, 56 | packages=setuptools.find_packages(exclude=("images", "examples")), 57 | install_requires=REQUIRED, 58 | include_package_data=True, 59 | license='MIT', 60 | classifiers=[ 61 | "Programming Language :: Python :: 3", 62 | "License :: OSI Approved :: MIT License", 63 | ], 64 | python_requires=REQUIRES_PYTHON, 65 | ) 66 | -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .__version__ import __version__ 2 | from . import base 3 | from . import losses 4 | from . 
import metrics 5 | 6 | from .backbones.tf_backbones import create_base_model 7 | from .models.FCN import FCN 8 | from .models.UNet import UNet 9 | from .models.OCNet import OCNet 10 | from .models.FPNet import FPNet 11 | from .models.DANet import DANet 12 | from .models.CFNet import CFNet 13 | from .models.ACFNet import ACFNet 14 | from .models.PSPNet import PSPNet 15 | from .models.DeepLab import DeepLab 16 | from .models.HRNetOCR import HRNetOCR 17 | from .models.DeepLabV3 import DeepLabV3 18 | from .models.ASPOCRNet import ASPOCRNet 19 | from .models.SpatialOCRNet import SpatialOCRNet 20 | from .models.DeepLabV3plus import DeepLabV3plus 21 | -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models/3714839ee49759b26e2b0ae3d3a0aa37b00df962/tensorflow_advanced_segmentation_models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/__pycache__/losses.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models/3714839ee49759b26e2b0ae3d3a0aa37b00df962/tensorflow_advanced_segmentation_models/__pycache__/losses.cpython-37.pyc -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/__pycache__/metrics.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models/3714839ee49759b26e2b0ae3d3a0aa37b00df962/tensorflow_advanced_segmentation_models/__pycache__/metrics.cpython-37.pyc 
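The `losses` and `metrics` modules exported here build on the helpers in base/functional.py, which reduce each prediction to per-class true-positive, false-positive and false-negative counts summed over the spatial axes, then average the resulting scores (optionally weighted by `class_weights`). The following is a minimal plain-Python sketch of those formulas for a single class and a single image, assuming flat 0/1 masks. `confusion_counts`, `iou`, `f_beta` and `tversky` here are illustrative helpers, not the library's Keras functions. `f_beta` uses the standard (1 + β²) weighting of the F-beta score, so β = 1 reproduces the Dice coefficient.

```python
# Illustrative sketch of the per-class scores in base/functional.py, written with
# plain Python lists instead of Keras backend tensors (one class, one image).

def confusion_counts(y_true, y_pred):
    """Count tp, fp, fn for one class from flat 0/1 masks of equal length."""
    tp = sum(t * p for t, p in zip(y_true, y_pred))
    fp = sum(y_pred) - tp  # predicted positive, actually negative
    fn = sum(y_true) - tp  # actually positive, predicted negative
    return tp, fp, fn


def iou(tp, fp, fn, smooth=1e-5):
    # Jaccard index: intersection / union, where union = tp + fp + fn
    return (tp + smooth) / (tp + fp + fn + smooth)


def f_beta(tp, fp, fn, beta=1.0, smooth=1e-5):
    # F-beta score with (1 + beta^2) weighting; beta = 1 gives the Dice coefficient
    b2 = beta ** 2
    return ((1 + b2) * tp + smooth) / ((1 + b2) * tp + b2 * fn + fp + smooth)


def tversky(tp, fp, fn, alpha=0.7, smooth=1e-5):
    # alpha weights false negatives, (1 - alpha) weights false positives
    return (tp + smooth) / (tp + alpha * fn + (1 - alpha) * fp + smooth)


y_true = [1, 1, 1, 1, 0, 0]
y_pred = [1, 1, 0, 0, 1, 0]
tp, fp, fn = confusion_counts(y_true, y_pred)   # (2, 1, 2)
print(round(iou(tp, fp, fn), 4))                # 0.4
print(round(f_beta(tp, fp, fn), 4))             # 0.5714
print(round(f_beta(tp, fp, fn, beta=2.0), 4))   # 0.5263
print(round(tversky(tp, fp, fn), 4))            # 0.5405
```

For a single class and image, Dice and IoU are linked by Dice = 2·IoU / (1 + IoU), so they rank predictions identically. With α = 0.7 the Tversky index penalizes the two missed pixels (false negatives) more than the single false positive, which is why it scores below Dice here; with α = 0.5 it equals Dice.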
-------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/__version__.py: -------------------------------------------------------------------------------- 1 | VERSION = (0, 4, 10) 2 | 3 | __version__ = ".".join(map(str, VERSION)) 4 | -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/backbones/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models/3714839ee49759b26e2b0ae3d3a0aa37b00df962/tensorflow_advanced_segmentation_models/backbones/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/backbones/__pycache__/tf_backbones.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models/3714839ee49759b26e2b0ae3d3a0aa37b00df962/tensorflow_advanced_segmentation_models/backbones/__pycache__/tf_backbones.cpython-37.pyc -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/backbones/tf_backbones.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | ################################################################################ 4 | # Backbone 5 | ################################################################################ 6 | 7 | 8 | def create_base_model(name="ResNet50", weights="imagenet", 
height=None, width=None, 9 | channels=3, include_top=False, pooling=None, alpha=1.0, 10 | depth_multiplier=1, dropout=0.001): 11 | if not isinstance(height, int) or not isinstance(width, int) or not isinstance(channels, int): 12 | raise TypeError( 13 | "'height', 'width' and 'channels' need to be of type 'int'") 14 | 15 | if channels <= 0: 16 | raise ValueError( 17 | f"'channels' must be greater than or equal to 1, but {channels} was given") 18 | 19 | input_shape = [height, width, channels] 20 | 21 | if name.lower() == "densenet121": 22 | if height <= 31 or width <= 31: 23 | raise ValueError( 24 | "Parameters 'height' and 'width' should not be smaller than 32.") 25 | base_model = tf.keras.applications.DenseNet121( 26 | include_top=include_top, weights=weights, input_shape=input_shape, pooling=pooling) 27 | layer_names = ["conv1/relu", "pool2_relu", 28 | "pool3_relu", "pool4_relu", "relu"] 29 | elif name.lower() == "densenet169": 30 | if height <= 31 or width <= 31: 31 | raise ValueError( 32 | "Parameters 'height' and 'width' should not be smaller than 32.") 33 | base_model = tf.keras.applications.DenseNet169( 34 | include_top=include_top, weights=weights, input_shape=input_shape, pooling=pooling) 35 | layer_names = ["conv1/relu", "pool2_relu", 36 | "pool3_relu", "pool4_relu", "relu"] 37 | elif name.lower() == "densenet201": 38 | if height <= 31 or width <= 31: 39 | raise ValueError( 40 | "Parameters 'height' and 'width' should not be smaller than 32.") 41 | base_model = tf.keras.applications.DenseNet201( 42 | include_top=include_top, weights=weights, input_shape=input_shape, pooling=pooling) 43 | layer_names = ["conv1/relu", "pool2_relu", 44 | "pool3_relu", "pool4_relu", "relu"] 45 | elif name.lower() == "efficientnetb0": 46 | if height <= 31 or width <= 31: 47 | raise ValueError( 48 | "Parameters 'height' and 'width' should not be smaller than 32.") 49 | base_model = tf.keras.applications.EfficientNetB0( 50 | include_top=include_top, weights=weights, 
input_shape=input_shape, pooling=pooling) 51 | layer_names = ["block2a_expand_activation", "block3a_expand_activation", 52 | "block4a_expand_activation", "block6a_expand_activation", "top_activation"] 53 | elif name.lower() == "efficientnetb1": 54 | if height <= 31 or width <= 31: 55 | raise ValueError( 56 | "Parameters 'height' and 'width' should not be smaller than 32.") 57 | base_model = tf.keras.applications.EfficientNetB1( 58 | include_top=include_top, weights=weights, input_shape=input_shape, pooling=pooling) 59 | layer_names = ["block2a_expand_activation", "block3a_expand_activation", 60 | "block4a_expand_activation", "block6a_expand_activation", "top_activation"] 61 | elif name.lower() == "efficientnetb2": 62 | if height <= 31 or width <= 31: 63 | raise ValueError( 64 | "Parameters 'height' and 'width' should not be smaller than 32.") 65 | base_model = tf.keras.applications.EfficientNetB2( 66 | include_top=include_top, weights=weights, input_shape=input_shape, pooling=pooling) 67 | layer_names = ["block2a_expand_activation", "block3a_expand_activation", 68 | "block4a_expand_activation", "block6a_expand_activation", "top_activation"] 69 | elif name.lower() == "efficientnetb3": 70 | if height <= 31 or width <= 31: 71 | raise ValueError( 72 | "Parameters 'height' and 'width' should not be smaller than 32.") 73 | base_model = tf.keras.applications.EfficientNetB3( 74 | include_top=include_top, weights=weights, input_shape=input_shape, pooling=pooling) 75 | layer_names = ["block2a_expand_activation", "block3a_expand_activation", 76 | "block4a_expand_activation", "block6a_expand_activation", "top_activation"] 77 | elif name.lower() == "efficientnetb4": 78 | if height <= 31 or width <= 31: 79 | raise ValueError( 80 | "Parameters 'height' and 'width' should not be smaller than 32.") 81 | base_model = tf.keras.applications.EfficientNetB4( 82 | include_top=include_top, weights=weights, input_shape=input_shape, pooling=pooling) 83 | layer_names = 
["block2a_expand_activation", "block3a_expand_activation", 84 | "block4a_expand_activation", "block6a_expand_activation", "top_activation"] 85 | elif name.lower() == "efficientnetb5": 86 | if height <= 31 or width <= 31: 87 | raise ValueError( 88 | "Parameters 'height' and 'width' should not be smaller than 32.") 89 | base_model = tf.keras.applications.EfficientNetB5( 90 | include_top=include_top, weights=weights, input_shape=input_shape, pooling=pooling) 91 | layer_names = ["block2a_expand_activation", "block3a_expand_activation", 92 | "block4a_expand_activation", "block6a_expand_activation", "top_activation"] 93 | elif name.lower() == "efficientnetb6": 94 | if height <= 31 or width <= 31: 95 | raise ValueError( 96 | "Parameters 'height' and 'width' should not be smaller than 32.") 97 | base_model = tf.keras.applications.EfficientNetB6( 98 | include_top=include_top, weights=weights, input_shape=input_shape, pooling=pooling) 99 | layer_names = ["block2a_expand_activation", "block3a_expand_activation", 100 | "block4a_expand_activation", "block6a_expand_activation", "top_activation"] 101 | elif name.lower() == "efficientnetb7": 102 | if height <= 31 or width <= 31: 103 | raise ValueError( 104 | "Parameters 'height' and 'width' should not be smaller than 32.") 105 | base_model = tf.keras.applications.EfficientNetB7( 106 | include_top=include_top, weights=weights, input_shape=input_shape, pooling=pooling) 107 | layer_names = ["block2a_expand_activation", "block3a_expand_activation", 108 | "block4a_expand_activation", "block6a_expand_activation", "top_activation"] 109 | elif name.lower() == "mobilenet": 110 | if height <= 31 or width <= 31: 111 | raise ValueError( 112 | "Parameters 'height' and 'width' should not be smaller than 32.") 113 | base_model = tf.keras.applications.MobileNet(include_top=include_top, weights=weights, input_shape=input_shape, 114 | pooling=pooling, alpha=alpha, depth_multiplier=depth_multiplier, dropout=dropout) 115 | layer_names = 
["conv_pw_1_relu", "conv_pw_3_relu", 116 | "conv_pw_5_relu", "conv_pw_11_relu", "conv_pw_13_relu"] 117 | elif name.lower() == "mobilenetv2": 118 | if height <= 31 or width <= 31: 119 | raise ValueError( 120 | "Parameters 'height' and 'width' should not be smaller than 32.") 121 | base_model = tf.keras.applications.MobileNetV2( 122 | include_top=include_top, weights=weights, input_shape=input_shape, pooling=pooling, alpha=alpha) 123 | layer_names = ["block_1_expand_relu", "block_3_expand_relu", 124 | "block_6_expand_relu", "block_13_expand_relu", "out_relu"] 125 | elif name.lower() == "mobilenetv3small": 126 | if height <= 31 or width <= 31: 127 | raise ValueError( 128 | "Parameters 'height' and 'width' should not be smaller than 32.") 129 | base_model = tf.keras.applications.MobileNetV3Small( 130 | include_top=include_top, weights=weights, input_shape=input_shape, pooling=pooling, alpha=alpha) 131 | layer_names = ["multiply", "re_lu_3", 132 | "multiply_1", "multiply_11", "multiply_17"] 133 | elif name.lower() == "nasnetlarge": 134 | if height <= 31 or width <= 31: 135 | raise ValueError( 136 | "Parameters 'height' and 'width' should not be smaller than 32.") 137 | base_model = tf.keras.applications.NASNetLarge( 138 | include_top=include_top, weights=weights, input_shape=input_shape, pooling=pooling) 139 | layer_names = ["zero_padding2d", "cropping2d_1", 140 | "cropping2d_2", "cropping2d_3", "activation_650"] 141 | elif name.lower() == "nasnetmobile": 142 | if height <= 31 or width <= 31: 143 | raise ValueError( 144 | "Parameters 'height' and 'width' should not be smaller than 32.") 145 | base_model = tf.keras.applications.NASNetMobile( 146 | include_top=include_top, weights=weights, input_shape=input_shape, pooling=pooling) 147 | layer_names = ["zero_padding2d_4", "cropping2d_5", 148 | "cropping2d_6", "cropping2d_7", "activation_838"] 149 | elif name.lower() == "resnet50": 150 | if height <= 31 or width <= 31: 151 | raise ValueError( 152 | "Parameters 'height' and 
'width' should not be smaller than 32.") 153 | base_model = tf.keras.applications.ResNet50( 154 | include_top=include_top, weights=weights, input_shape=input_shape, pooling=pooling) 155 | layer_names = ["conv1_relu", "conv2_block3_out", 156 | "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"] 157 | elif name.lower() == "resnet50v2": 158 | if height <= 31 or width <= 31: 159 | raise ValueError( 160 | "Parameters 'height' and 'width' should not be smaller than 32.") 161 | base_model = tf.keras.applications.ResNet50V2( 162 | include_top=include_top, weights=weights, input_shape=input_shape, pooling=pooling) 163 | layer_names = ["conv1_conv", "conv2_block3_preact_relu", 164 | "conv3_block4_preact_relu", "conv4_block6_preact_relu", "post_relu"] 165 | elif name.lower() == "resnet101": 166 | if height <= 31 or width <= 31: 167 | raise ValueError( 168 | "Parameters 'height' and 'width' should not be smaller than 32.") 169 | base_model = tf.keras.applications.ResNet101( 170 | include_top=include_top, weights=weights, input_shape=input_shape, pooling=pooling) 171 | layer_names = ["conv1_relu", "conv2_block3_out", 172 | "conv3_block4_out", "conv4_block23_out", "conv5_block3_out"] 173 | elif name.lower() == "resnet101v2": 174 | if height <= 31 or width <= 31: 175 | raise ValueError( 176 | "Parameters 'height' and 'width' should not be smaller than 32.") 177 | base_model = tf.keras.applications.ResNet101V2( 178 | include_top=include_top, weights=weights, input_shape=input_shape, pooling=pooling) 179 | layer_names = ["conv1_conv", "conv2_block3_preact_relu", 180 | "conv3_block4_preact_relu", "conv4_block23_preact_relu", "post_relu"] 181 | elif name.lower() == "resnet152": 182 | if height <= 31 or width <= 31: 183 | raise ValueError( 184 | "Parameters 'height' and 'width' should not be smaller than 32.") 185 | base_model = tf.keras.applications.ResNet152( 186 | include_top=include_top, weights=weights, input_shape=input_shape, pooling=pooling) 187 | layer_names = 
["conv1_relu", "conv2_block3_out", 188 | "conv3_block8_out", "conv4_block36_out", "conv5_block3_out"] 189 | elif name.lower() == "resnet152v2": 190 | if height <= 31 or width <= 31: 191 | raise ValueError( 192 | "Parameters 'height' and 'width' should not be smaller than 32.") 193 | base_model = tf.keras.applications.ResNet152V2( 194 | include_top=include_top, weights=weights, input_shape=input_shape, pooling=pooling) 195 | layer_names = ["conv1_conv", "conv2_block3_preact_relu", 196 | "conv3_block8_preact_relu", "conv4_block36_preact_relu", "post_relu"] 197 | elif name.lower() == "vgg16": 198 | if height <= 31 or width <= 31: 199 | raise ValueError( 200 | "Parameters 'height' and 'width' should not be smaller than 32.") 201 | base_model = tf.keras.applications.VGG16( 202 | include_top=include_top, weights=weights, input_shape=input_shape, pooling=pooling) 203 | layer_names = ["block2_conv2", "block3_conv3", 204 | "block4_conv3", "block5_conv3", "block5_pool"] 205 | elif name.lower() == "vgg19": 206 | if height <= 31 or width <= 31: 207 | raise ValueError( 208 | "Parameters 'height' and 'width' should not be smaller than 32.") 209 | base_model = tf.keras.applications.VGG19( 210 | include_top=include_top, weights=weights, input_shape=input_shape, pooling=pooling) 211 | layer_names = ["block2_conv2", "block3_conv4", 212 | "block4_conv4", "block5_conv4", "block5_pool"] 213 | elif name.lower() == "xception": 214 | if height <= 70 or width <= 70: 215 | raise ValueError( 216 | "Parameters 'height' and 'width' should not be smaller than 71.") 217 | base_model = tf.keras.applications.Xception( 218 | include_top=include_top, weights=weights, input_shape=input_shape, pooling=pooling) 219 | layer_names = ["block2_sepconv2_act", "block3_sepconv2_act", 220 | "block4_sepconv2_act", "block13_sepconv2_act", "block14_sepconv2_act"] 221 | else: 222 | raise ValueError("'name' should be one of 'densenet121', 'densenet169', 'densenet201', 'efficientnetb0', 'efficientnetb1', 
'efficientnetb2', \ 223 | 'efficientnetb3', 'efficientnetb4', 'efficientnetb5', 'efficientnetb6', 'efficientnetb7','mobilenet', 'mobilenetv2', 'mobilenetv3small', 'nasnetlarge', 'nasnetmobile', \ 224 | 'resnet50', 'resnet50v2', 'resnet101', 'resnet101v2', 'resnet152', 'resnet152v2', 'vgg16', 'vgg19' or 'xception'.") 225 | 226 | layers = [base_model.get_layer( 227 | layer_name).output for layer_name in layer_names] 228 | 229 | return base_model, layers, layer_names 230 | -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .objects import KerasObject, Loss, Metric 2 | from . import functional -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/base/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models/3714839ee49759b26e2b0ae3d3a0aa37b00df962/tensorflow_advanced_segmentation_models/base/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/base/__pycache__/functional.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models/3714839ee49759b26e2b0ae3d3a0aa37b00df962/tensorflow_advanced_segmentation_models/base/__pycache__/functional.cpython-37.pyc -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/base/__pycache__/objects.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models/3714839ee49759b26e2b0ae3d3a0aa37b00df962/tensorflow_advanced_segmentation_models/base/__pycache__/objects.cpython-37.pyc -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/base/functional.py: -------------------------------------------------------------------------------- 1 | import tensorflow.keras.backend as K 2 | 3 | ################################################################################ 4 | # Helper Functions 5 | ################################################################################ 6 | def average(x, class_weights=None): 7 | if class_weights is not None: 8 | x = x * class_weights 9 | return K.mean(x) 10 | 11 | def gather_channels(*xs): 12 | return xs 13 | 14 | def round_if_needed(x, threshold): 15 | if threshold is not None: 16 | x = K.greater(x, threshold) 17 | x = K.cast(x, K.floatx()) 18 | return x 19 | 20 | ################################################################################ 21 | # Metric Functions 22 | ################################################################################ 23 | def iou_score(y_true, y_pred, class_weights=1., smooth=1e-5, threshold=None): 24 | # y_true = K.one_hot(K.squeeze(K.cast(y_true, tf.int32), axis=-1), n_classes) 25 | 26 | y_true, y_pred = gather_channels(y_true, y_pred) 27 | y_pred = round_if_needed(y_pred, threshold) 28 | axes = [1, 2] if K.image_data_format() == "channels_last" else [2, 3] 29 | 30 | intersection = K.sum(y_true * y_pred, axis=axes) 31 | union = K.sum(y_true + y_pred, axis=axes) - intersection 32 | 33 | score = (intersection + smooth) / (union + smooth) 34 | score = average(score, class_weights) 35 | 36 | return score 37 | 38 | def dice_coefficient(y_true, y_pred, beta=1.0, class_weights=1., smooth=1e-5, threshold=None): 39 | # print(y_pred) 40 | y_true, y_pred = gather_channels(y_true, y_pred) 41 | y_pred = 
round_if_needed(y_pred, threshold) 42 | axes = [1, 2] if K.image_data_format() == "channels_last" else [2, 3] 43 | 44 | tp = K.sum(y_true * y_pred, axis=axes) 45 | fp = K.sum(y_pred, axis=axes) - tp 46 | fn = K.sum(y_true, axis=axes) - tp 47 | 48 | score = ((1.0 + beta ** 2.0) * tp + smooth) / ((1.0 + beta ** 2.0) * tp + (beta ** 2.0) * fn + fp + smooth)  # F-beta score; equals Dice for beta = 1 49 | # print("Score, wo avg: " + str(score)) 50 | score = average(score, class_weights) 51 | # print("Score: " + str(score)) 52 | 53 | return score 54 | 55 | def precision(y_true, y_pred, class_weights=1., smooth=1e-5, threshold=None): 56 | y_true, y_pred = gather_channels(y_true, y_pred) 57 | y_pred = round_if_needed(y_pred, threshold) 58 | axes = [1, 2] if K.image_data_format() == "channels_last" else [2, 3] 59 | 60 | tp = K.sum(y_true * y_pred, axis=axes) 61 | fp = K.sum(y_pred, axis=axes) - tp 62 | 63 | score = (tp + smooth) / (tp + fp + smooth) 64 | score = average(score, class_weights) 65 | 66 | return score 67 | 68 | def recall(y_true, y_pred, class_weights=1., smooth=1e-5, threshold=None): 69 | y_true, y_pred = gather_channels(y_true, y_pred) 70 | y_pred = round_if_needed(y_pred, threshold) 71 | axes = [1, 2] if K.image_data_format() == "channels_last" else [2, 3] 72 | 73 | tp = K.sum(y_true * y_pred, axis=axes) 74 | fn = K.sum(y_true, axis=axes) - tp 75 | 76 | score = (tp + smooth) / (tp + fn + smooth) 77 | score = average(score, class_weights) 78 | 79 | return score 80 | 81 | def tversky(y_true, y_pred, alpha=0.7, class_weights=1., smooth=1e-5, threshold=None): 82 | y_true, y_pred = gather_channels(y_true, y_pred) 83 | y_pred = round_if_needed(y_pred, threshold) 84 | axes = [1, 2] if K.image_data_format() == "channels_last" else [2, 3] 85 | 86 | tp = K.sum(y_true * y_pred, axis=axes) 87 | fp = K.sum(y_pred, axis=axes) - tp 88 | fn = K.sum(y_true, axis=axes) - tp 89 | 90 | score = (tp + smooth) / (tp + alpha * fn + (1 - alpha) * fp + smooth) 91 | score = average(score, class_weights) 92 | 93 | return score 94 | 95 | 96
| ################################################################################ 97 | # Loss Functions 98 | ################################################################################ 99 | def categorical_crossentropy(y_true, y_pred, class_weights=1.): 100 | y_true, y_pred = gather_channels(y_true, y_pred) 101 | 102 | axis = 3 if K.image_data_format() == "channels_last" else 1 103 | y_pred /= K.sum(y_pred, axis=axis, keepdims=True) 104 | 105 | y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon()) 106 | 107 | loss = y_true * K.log(y_pred) * class_weights 108 | return - K.mean(loss) 109 | 110 | def binary_crossentropy(y_true, y_pred): 111 | return K.mean(K.binary_crossentropy(y_true, y_pred)) 112 | 113 | def categorical_focal_loss(y_true, y_pred, gamma=2.0, alpha=0.25): 114 | y_true, y_pred = gather_channels(y_true, y_pred) 115 | y_true = K.cast(y_true, K.floatx()) 116 | y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon()) 117 | 118 | loss = - y_true * (alpha * K.pow((1 - y_pred), gamma) * K.log(y_pred)) 119 | 120 | return K.mean(loss) 121 | 122 | # def categorical_focal_dice_loss(y_true, y_pred, gamma=2.0, alpha=0.25, beta=1.0, class_weights=1., smooth=1e-5, threshold=None): 123 | # dice_score = dice_coefficient(y_true, y_pred, beta=beta, class_weights=class_weights, smooth=smooth, threshold=threshold) 124 | 125 | # cat_focal_loss = categorical_focal_loss(y_true, y_pred, gamma=gamma, alpha=alpha) 126 | # return dice_loss + cat_focal_loss 127 | 128 | def binary_focal_loss(y_true, y_pred, gamma=2.0, alpha=0.25): 129 | y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon()) 130 | 131 | loss_a = - y_true * (alpha * K.pow((1 - y_pred), gamma) * K.log(y_pred)) 132 | loss_b = - (1 - y_true) * ((1 - alpha) * K.pow((y_pred), gamma) * K.log(1 - y_pred)) 133 | 134 | return K.mean(loss_a + loss_b) 135 | 136 | def combo(y_true, y_pred, alpha=0.5, beta=1.0, ce_ratio=0.5, class_weights=1., smooth=1e-5, threshold=None): 137 | # alpha < 0.5 penalizes FP more, alpha > 0.5 
penalizes FN more 138 | 139 | y_true, y_pred = gather_channels(y_true, y_pred) 140 | y_pred = round_if_needed(y_pred, threshold) 141 | axes = [1, 2] if K.image_data_format() == "channels_last" else [2, 3] 142 | 143 | tp = K.sum(y_true * y_pred, axis=axes) 144 | fp = K.sum(y_pred, axis=axes) - tp 145 | fn = K.sum(y_true, axis=axes) - tp 146 | 147 | dice = ((1.0 + beta ** 2.0) * tp + smooth) / ((1.0 + beta ** 2.0) * tp + (beta ** 2.0) * fn + fp + smooth) 148 | 149 | y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon()) 150 | 151 | ce = - (alpha * y_true * K.log(y_pred) + (1 - alpha) * (1.0 - y_true) * K.log(1.0 - y_pred)) 152 | ce = K.mean(ce, axis=axes) 153 | 154 | combo = (ce_ratio * ce) - ((1 - ce_ratio) * dice) 155 | loss = average(combo, class_weights) 156 | 157 | return loss 158 | -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/base/objects.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # Objects 3 | ################################################################################ 4 | class KerasObject: 5 | def __init__(self, name=None): 6 | self._name = name 7 | 8 | @property 9 | def __name__(self): 10 | if self._name is None: 11 | return self.__class__.__name__ 12 | return self._name 13 | 14 | @property 15 | def name(self): 16 | return self.__name__ 17 | 18 | @name.setter 19 | def name(self, name): 20 | self._name = name 21 | 22 | class Metric(KerasObject): 23 | pass 24 | 25 | class Loss(KerasObject): 26 | 27 | def __add__(self, other): 28 | if isinstance(other, Loss): 29 | return SumOfLosses(self, other) 30 | else: 31 | raise ValueError('Loss should be inherited from `Loss` class') 32 | 33 | def __radd__(self, other): 34 | return self.__add__(other) 35 | 36 | def __mul__(self, value): 37 | if isinstance(value, (int, float)): 38 | return MultipliedLoss(self, value) 39 |
else: 40 | raise ValueError('Loss multiplier should be a number (int or float)') 41 | 42 | def __rmul__(self, other): 43 | return self.__mul__(other) 44 | 45 | class MultipliedLoss(Loss): 46 | 47 | def __init__(self, loss, multiplier): 48 | 49 | # resolve name 50 | if len(loss.__name__.split('_plus_')) > 1: 51 | name = '{}({})'.format(multiplier, loss.__name__) 52 | else: 53 | name = '{}{}'.format(multiplier, loss.__name__) 54 | super().__init__(name=name) 55 | self.loss = loss 56 | self.multiplier = multiplier 57 | 58 | def __call__(self, gt, pr): 59 | return self.multiplier * self.loss(gt, pr) 60 | 61 | 62 | class SumOfLosses(Loss): 63 | 64 | def __init__(self, l1, l2): 65 | name = '{}_plus_{}'.format(l1.__name__, l2.__name__) 66 | super().__init__(name=name) 67 | self.l1 = l1 68 | self.l2 = l2 69 | 70 | def __call__(self, gt, pr): 71 | return self.l1(gt, pr) + self.l2(gt, pr) -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/losses.py: -------------------------------------------------------------------------------- 1 | from .base import Loss 2 | from .base.functional import * 3 | 4 | ################################################################################ 5 | # Losses 6 | ################################################################################ 7 | class JaccardLoss(Loss): 8 | def __init__(self, class_weights=None, smooth=1e-5): 9 | super().__init__(name="jaccard_loss") 10 | self.class_weights = class_weights if class_weights is not None else 1.0 11 | self.smooth = smooth 12 | 13 | def __call__(self, y_true, y_pred): 14 | return 1.0 - iou_score( 15 | y_true, 16 | y_pred, 17 | class_weights=self.class_weights, 18 | smooth=self.smooth, 19 | threshold=None 20 | ) 21 | 22 | class DiceLoss(Loss): 23 | def __init__(self, beta=1.0, class_weights=None, smooth=1e-5): 24 | super().__init__(name="dice_loss") 25 | self.beta = beta 26 | self.class_weights = class_weights if class_weights is not
None else 1.0 27 | self.smooth = smooth 28 | 29 | def __call__(self, y_true, y_pred): 30 | # print(y_pred) 31 | return 1.0 - dice_coefficient( 32 | y_true, 33 | y_pred, 34 | beta=self.beta, 35 | class_weights=self.class_weights, 36 | smooth=self.smooth, 37 | threshold=None 38 | ) 39 | 40 | 41 | class TverskyLoss(Loss): 42 | def __init__(self, alpha=0.7, class_weights=None, smooth=1e-5): 43 | super().__init__(name="tversky_loss") 44 | self.alpha = alpha 45 | self.class_weights = class_weights if class_weights is not None else 1.0 46 | self.smooth = smooth 47 | 48 | def __call__(self, y_true, y_pred): 49 | return 1.0 - tversky( 50 | y_true, 51 | y_pred, 52 | alpha=self.alpha, 53 | class_weights=self.class_weights, 54 | smooth=self.smooth, 55 | threshold=None 56 | ) 57 | 58 | 59 | class FocalTverskyLoss(Loss): 60 | def __init__(self, alpha=0.7, gamma=1.25, class_weights=None, smooth=1e-5): 61 | super().__init__(name="focal_tversky_loss") 62 | self.alpha = alpha 63 | self.gamma = gamma 64 | self.class_weights = class_weights if class_weights is not None else 1.0 65 | self.smooth = smooth 66 | 67 | def __call__(self, y_true, y_pred): 68 | return K.pow((1.0 - tversky( 69 | y_true, 70 | y_pred, 71 | alpha=self.alpha, 72 | class_weights=self.class_weights, 73 | smooth=self.smooth, 74 | threshold=None 75 | )), 76 | self.gamma 77 | ) 78 | 79 | 80 | class BinaryCELoss(Loss): 81 | def __init__(self): 82 | super().__init__(name="binary_crossentropy") 83 | 84 | def __call__(self, y_true, y_pred): 85 | return binary_crossentropy( 86 | y_true, 87 | y_pred 88 | ) 89 | 90 | 91 | class CategoricalCELoss(Loss): 92 | def __init__(self, class_weights=None): 93 | super().__init__(name="categorical_crossentropy") 94 | self.class_weights = class_weights if class_weights is not None else 1.0 95 | 96 | def __call__(self, y_true, y_pred): 97 | return categorical_crossentropy( 98 | y_true, 99 | y_pred, 100 | class_weights=self.class_weights 101 | ) 102 | 103 | class CategoricalFocalLoss(Loss): 104 | def __init__(self, alpha=0.25,
gamma=2.0): 105 | super().__init__(name="focal_loss") 106 | self.alpha = alpha 107 | self.gamma = gamma 108 | 109 | def __call__(self, y_true, y_pred): 110 | return categorical_focal_loss( 111 | y_true, 112 | y_pred, 113 | alpha=self.alpha, 114 | gamma=self.gamma 115 | ) 116 | 117 | class BinaryFocalLoss(Loss): 118 | def __init__(self, alpha=0.25, gamma=2.0): 119 | super().__init__(name='binary_focal_loss') 120 | self.alpha = alpha 121 | self.gamma = gamma 122 | 123 | def __call__(self, y_true, y_pred): 124 | return binary_focal_loss(y_true, y_pred, alpha=self.alpha, gamma=self.gamma) 125 | 126 | 127 | class ComboLoss(Loss): 128 | def __init__(self, alpha=0.5, beta=1.0, ce_ratio=0.5, class_weights=None, smooth=1e-5): 129 | super().__init__(name="combo_loss") 130 | self.alpha = alpha 131 | self.beta = beta 132 | self.ce_ratio = ce_ratio 133 | self.class_weights = class_weights if class_weights is not None else 1.0 134 | self.smooth = smooth 135 | 136 | def __call__(self, y_true, y_pred): 137 | return combo( 138 | y_true, 139 | y_pred, 140 | alpha=self.alpha, 141 | beta=self.beta, 142 | ce_ratio=self.ce_ratio, 143 | class_weights=self.class_weights, 144 | smooth=self.smooth, 145 | threshold=None 146 | ) -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/metrics.py: -------------------------------------------------------------------------------- 1 | from .base import Metric 2 | from .base.functional import * 3 | 4 | ################################################################################ 5 | # Metrics 6 | ################################################################################ 7 | class IOUScore(Metric): 8 | def __init__(self, class_weights=None, threshold=None, smooth=1e-5, name=None): 9 | name = name or "iou_score" 10 | super().__init__(name=name) 11 | self.class_weights = class_weights if class_weights is not None else 1. 
12 | self.threshold = threshold 13 | self.smooth = smooth 14 | 15 | def __call__(self, y_true, y_pred): 16 | return iou_score( 17 | y_true, 18 | y_pred, 19 | class_weights=self.class_weights, 20 | smooth=self.smooth, 21 | threshold=self.threshold 22 | ) 23 | 24 | class FScore(Metric): 25 | def __init__(self, beta=1, class_weights=None, threshold=None, smooth=1e-5, name=None): 26 | name = name or "f{}-score".format(beta) 27 | super().__init__(name=name) 28 | self.beta = beta 29 | self.class_weights = class_weights if class_weights is not None else 1. 30 | self.threshold = threshold 31 | self.smooth = smooth 32 | 33 | def __call__(self, y_true, y_pred): 34 | return dice_coefficient( 35 | y_true, 36 | y_pred, 37 | beta=self.beta, 38 | class_weights=self.class_weights, 39 | smooth=self.smooth, 40 | threshold=self.threshold 41 | ) 42 | 43 | class Precision(Metric): 44 | def __init__(self, class_weights=None, threshold=None, smooth=1e-5, name=None): 45 | name = name or "precision" 46 | super().__init__(name=name) 47 | self.class_weights = class_weights if class_weights is not None else 1. 48 | self.threshold = threshold 49 | self.smooth = smooth 50 | 51 | def __call__(self, y_true, y_pred): 52 | return precision( 53 | y_true, 54 | y_pred, 55 | class_weights=self.class_weights, 56 | smooth=self.smooth, 57 | threshold=self.threshold 58 | ) 59 | 60 | class Recall(Metric): 61 | def __init__(self, class_weights=None, threshold=None, smooth=1e-5, name=None): 62 | name = name or "recall" 63 | super().__init__(name=name) 64 | self.class_weights = class_weights if class_weights is not None else 1. 
65 | self.threshold = threshold 66 | self.smooth = smooth 67 | 68 | def __call__(self, y_true, y_pred): 69 | return recall( 70 | y_true, 71 | y_pred, 72 | class_weights=self.class_weights, 73 | smooth=self.smooth, 74 | threshold=self.threshold 75 | ) -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/models/ACFNet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.keras.backend as K 3 | 4 | from ._custom_layers_and_blocks import ConvolutionBnActivation, AtrousSpatialPyramidPoolingV3, AttCF_Module 5 | from ..backbones.tf_backbones import create_base_model 6 | 7 | ################################################################################ 8 | # Attentional Class Feature Network 9 | ################################################################################ 10 | class ACFNet(tf.keras.Model): 11 | def __init__(self, n_classes, base_model, output_layers, height=None, width=None, filters=256, 12 | final_activation="softmax", backbone_trainable=False, 13 | dilations=[6, 12, 18], **kwargs): 14 | super(ACFNet, self).__init__(**kwargs) 15 | 16 | self.n_classes = n_classes 17 | self.backbone = None 18 | self.filters = filters 19 | self.final_activation = final_activation 20 | self.height = height 21 | self.width = width 22 | 23 | output_layers = output_layers[:3] 24 | 25 | base_model.trainable = backbone_trainable 26 | self.backbone = tf.keras.Model(inputs=base_model.input, outputs=output_layers) 27 | 28 | # Layers 29 | self.aspp = AtrousSpatialPyramidPoolingV3(dilations, filters) 30 | 31 | self.dropout_2 = tf.keras.layers.Dropout(0.25) 32 | self.conv3x3_bn_relu_2 = ConvolutionBnActivation(filters, (3, 3)) 33 | self.dropout_3 = tf.keras.layers.Dropout(0.1) 34 | self.conv1x1_bn_activation = ConvolutionBnActivation(n_classes, (1, 1), post_activation=final_activation) 35 | 36 | self.acf = AttCF_Module(filters) 37
| 38 | axis = 3 if K.image_data_format() == "channels_last" else 1 39 | self.concat = tf.keras.layers.Concatenate(axis=axis) 40 | 41 | self.conv3x3_bn_relu_3 = ConvolutionBnActivation(filters, (3, 3)) 42 | self.dropout_4 = tf.keras.layers.Dropout(0.1) 43 | self.final_conv1x1_bn_activation = ConvolutionBnActivation(n_classes, (1, 1), post_activation=final_activation) 44 | self.final_upsampling2d = tf.keras.layers.UpSampling2D(size=8, interpolation="bilinear") 45 | 46 | def call(self, inputs, training=None, mask=None): 47 | if training is None: 48 | training = True 49 | 50 | x = self.backbone(inputs, training=training)[-1] 51 | aspp = self.aspp(x, training=training) 52 | 53 | x = self.dropout_2(aspp, training=training) 54 | x = self.conv3x3_bn_relu_2(x, training=training) 55 | x = self.dropout_3(x, training=training) 56 | x = self.conv1x1_bn_activation(x, training=training) # coarse segmentation map 57 | 58 | x = self.acf(aspp, x, training=training) 59 | 60 | x = self.concat([x, aspp]) 61 | 62 | x = self.conv3x3_bn_relu_3(x, training=training) 63 | x = self.dropout_4(x, training=training) 64 | x = self.final_conv1x1_bn_activation(x, training=training) 65 | x = self.final_upsampling2d(x) 66 | 67 | return x 68 | 69 | def model(self): 70 | x = tf.keras.layers.Input(shape=(self.height, self.width, 3)) 71 | return tf.keras.Model(inputs=[x], outputs=self.call(x)) 72 | -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/models/ASPOCRNet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.keras.backend as K 3 | 4 | from ._custom_layers_and_blocks import ConvolutionBnActivation, SpatialOCR_ASP_Module 5 | from ..backbones.tf_backbones import create_base_model 6 | 7 | ################################################################################ 8 | # ASP Object-Contextual Representations Network 9 | 
################################################################################ 10 | class ASPOCRNet(tf.keras.Model): 11 | def __init__(self, n_classes, base_model, output_layers, height=None, width=None, filters=256, 12 | final_activation="softmax", backbone_trainable=False, 13 | spatial_context_scale=1, **kwargs): 14 | super(ASPOCRNet, self).__init__(**kwargs) 15 | 16 | self.n_classes = n_classes 17 | self.backbone = None 18 | self.filters = filters 19 | self.final_activation = final_activation 20 | self.spatial_context_scale = spatial_context_scale 21 | self.height = height 22 | self.width = width 23 | 24 | 25 | output_layers = output_layers[:4] 26 | 27 | base_model.trainable = backbone_trainable 28 | self.backbone = tf.keras.Model(inputs=base_model.input, outputs=output_layers) 29 | 30 | # Layers 31 | self.conv3x3_bn_relu = ConvolutionBnActivation(filters, (3, 3)) 32 | self.dropout = tf.keras.layers.Dropout(0.05) 33 | self.conv1x1_bn_activation = ConvolutionBnActivation(filters, (1, 1), post_activation=final_activation) 34 | self.upsampling2d_x2 = tf.keras.layers.UpSampling2D(size=2, interpolation="bilinear") 35 | 36 | self.asp_ocr = SpatialOCR_ASP_Module(filters, scale=spatial_context_scale) 37 | 38 | self.final_conv1x1_bn_activation = ConvolutionBnActivation(self.n_classes, (1, 1), post_activation=final_activation) 39 | self.final_upsampling2d = tf.keras.layers.UpSampling2D(size=8, interpolation="bilinear") 40 | 41 | def call(self, inputs, training=None, mask=None): 42 | if training is None: 43 | training = True 44 | 45 | x0, x1, x2, x3 = self.backbone(inputs, training=training) 46 | 47 | x_dsn = self.conv3x3_bn_relu(x3, training=training) 48 | x_dsn = self.dropout(x_dsn, training=training) 49 | x_dsn = self.conv1x1_bn_activation(x_dsn, training=training) 50 | x_dsn = self.upsampling2d_x2(x_dsn) 51 | 52 | x = self.asp_ocr(x2, x_dsn, training=training) 53 | 54 | x = self.final_conv1x1_bn_activation(x, training=training) 55 | x = self.final_upsampling2d(x) 56 
| 57 | return x 58 | 59 | def model(self): 60 | x = tf.keras.layers.Input(shape=(self.height, self.width, 3)) 61 | return tf.keras.Model(inputs=[x], outputs=self.call(x)) -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/models/CFNet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.keras.backend as K 3 | 4 | from ._custom_layers_and_blocks import ConvolutionBnActivation, GlobalPooling, AggCF_Module 5 | from ..backbones.tf_backbones import create_base_model 6 | 7 | ################################################################################ 8 | # Co-Occurent Feature Network 9 | ################################################################################ 10 | class CFNet(tf.keras.Model): 11 | # Co-occurent Feature Network 12 | def __init__(self, n_classes, base_model, output_layers, height=None, width=None, filters=256, 13 | final_activation="softmax", backbone_trainable=False, 14 | lateral=True, global_pool=False, acf_pool=True, 15 | acf_kq_transform="conv", acf_concat=False, **kwargs): 16 | super(CFNet, self).__init__(**kwargs) 17 | 18 | self.n_classes = n_classes 19 | self.backbone = None 20 | self.filters = filters 21 | self.final_activation = final_activation 22 | self.lateral = lateral 23 | self.global_pool = global_pool 24 | self.acf_pool = acf_pool 25 | self.acf_kq_transform = acf_kq_transform 26 | self.acf_concat = acf_concat 27 | self.height = height 28 | self.width = width 29 | 30 | 31 | output_layers = output_layers[:4] 32 | 33 | base_model.trainable = backbone_trainable 34 | self.backbone = tf.keras.Model(inputs=base_model.input, outputs=output_layers) 35 | 36 | # Layers 37 | self.conv3x3_bn_relu_1 = ConvolutionBnActivation(filters, (3, 3)) 38 | self.conv3x3_bn_relu_2 = ConvolutionBnActivation(filters, (3, 3)) 39 | self.conv3x3_bn_relu_3 = ConvolutionBnActivation(filters, (3, 3)) 40 | 
self.conv3x3_bn_relu_4 = ConvolutionBnActivation(filters, (3, 3)) 41 | 42 | self.upsample2d_2x = tf.keras.layers.UpSampling2D(size=2, interpolation="bilinear") 43 | self.pool2d = tf.keras.layers.MaxPooling2D((2, 2), padding="same") 44 | 45 | axis = 3 if K.image_data_format() == "channels_last" else 1 46 | self.concat_1 = tf.keras.layers.Concatenate(axis=axis) 47 | self.concat_2 = tf.keras.layers.Concatenate(axis=axis) 48 | 49 | self.glob_pool = GlobalPooling(filters) 50 | 51 | self.acf = AggCF_Module(filters, kq_transform=self.acf_kq_transform, value_transform="conv", 52 | pooling=self.acf_pool, concat=self.acf_concat, dropout=0.1) 53 | 54 | self.final_conv3x3_bn_activation = ConvolutionBnActivation(n_classes, (3, 3), post_activation=final_activation) 55 | self.final_upsampling2d = tf.keras.layers.UpSampling2D(size=8, interpolation="bilinear") 56 | 57 | def call(self, inputs, training=None, mask=None): 58 | if training is None: 59 | training = True 60 | 61 | x0, x1, x2, x3 = self.backbone(inputs, training=training) 62 | 63 | feat = self.conv3x3_bn_relu_1(x3, training=training) 64 | if self.lateral: 65 | feat = self.upsample2d_2x(feat) 66 | c2 = self.conv3x3_bn_relu_2(x1, training=training) 67 | c2 = self.pool2d(c2) 68 | c3 = self.conv3x3_bn_relu_3(x2, training=training) 69 | feat = self.concat_1([feat, c2, c3]) 70 | feat = self.conv3x3_bn_relu_4(feat, training=training) 71 | 72 | if self.global_pool: 73 | pool = self.glob_pool(feat, training=training) 74 | feat = self.acf(feat, training=training) 75 | feat = self.concat_2([pool, feat]) 76 | else: 77 | feat = self.acf(feat, training=training) 78 | 79 | x = self.final_conv3x3_bn_activation(feat, training=training) 80 | x = self.final_upsampling2d(x) 81 | 82 | return x 83 | 84 | def model(self): 85 | x = tf.keras.layers.Input(shape=(self.height, self.width, 3)) 86 | return tf.keras.Model(inputs=[x], outputs=self.call(x)) -------------------------------------------------------------------------------- 
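The lateral branch in `CFNet.call` only concatenates cleanly when the upsampled `x3` map, the max-pooled `x1` map and the `x2` map share one spatial size. The pure-Python sketch below traces those sizes. It assumes (this is not stated in the source) that `output_layers[1]`, `[2]` and `[3]` sit at strides 4, 8 and 16, that each stage rounds odd sizes up as `padding="same"` does, and that `ConvolutionBnActivation` preserves spatial size. `trace_cfnet_shapes` is a hypothetical helper, not part of the library.

```python
import math


def trace_cfnet_shapes(height, width):
    """Sketch: spatial sizes along CFNet's lateral branch.

    Assumes backbone features x1, x2, x3 at strides 4, 8, 16 with ceil
    rounding (like padding="same"); real backbones may round differently.
    """
    def at_stride(stride):
        return (math.ceil(height / stride), math.ceil(width / stride))

    x1, x2, x3 = at_stride(4), at_stride(8), at_stride(16)

    # conv3x3_bn_relu_1 on x3 (assumed size-preserving), then UpSampling2D(size=2)
    feat = (x3[0] * 2, x3[1] * 2)
    # conv3x3_bn_relu_2 on x1, then MaxPooling2D((2, 2), padding="same")
    c2 = (math.ceil(x1[0] / 2), math.ceil(x1[1] / 2))
    # conv3x3_bn_relu_3 on x2 (assumed size-preserving)
    c3 = x2
    # final UpSampling2D(size=8) applied to the fused stride-8 map
    output = (feat[0] * 8, feat[1] * 8)

    return {
        "feat": feat,
        "c2": c2,
        "c3": c3,
        "concat_ok": feat == c2 == c3,
        "output": output,
    }


print(trace_cfnet_shapes(512, 512))
print(trace_cfnet_shapes(500, 500))
```

Under these assumptions a 512×512 input lines up at 64×64 and returns a 512×512 map, while 500×500 gives 64×64 for `feat` but 63×63 for `c2`/`c3`, so the concatenation would fail. That suggests keeping `height` and `width` at multiples of 16 for this model.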
/tensorflow_advanced_segmentation_models/models/DANet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.keras.backend as K 3 | 4 | from ._custom_layers_and_blocks import ConvolutionBnActivation, PAM_Module, CAM_Module 5 | from ..backbones.tf_backbones import create_base_model 6 | 7 | ################################################################################ 8 | # Dual Attention Network 9 | ################################################################################ 10 | class DANet(tf.keras.Model): 11 | def __init__(self, n_classes, base_model, output_layers, height=None, width=None, filters=256, 12 | final_activation="softmax", backbone_trainable=False, 13 | output_stride=8, **kwargs): 14 | super(DANet, self).__init__(**kwargs) 15 | 16 | self.n_classes = n_classes 17 | self.backbone = None 18 | self.filters = filters 19 | self.final_activation = final_activation 20 | self.output_stride = output_stride 21 | self.height = height 22 | self.width = width 23 | 24 | 25 | if self.output_stride == 8: 26 | self.final_upsampling2d = tf.keras.layers.UpSampling2D(size=8, interpolation="bilinear") 27 | output_layers = output_layers[:3] 28 | elif self.output_stride == 16: 29 | self.final_upsampling2d = tf.keras.layers.UpSampling2D(size=16, interpolation="bilinear") 30 | output_layers = output_layers[:4] 31 | else: 32 | raise ValueError("'output_stride' must be one of (8, 16), got {}".format(self.output_stride)) 33 | 34 | base_model.trainable = backbone_trainable 35 | self.backbone = tf.keras.Model(inputs=base_model.input, outputs=output_layers) 36 | 37 | # Layers 38 | self.conv3x3_bn_relu_1 = ConvolutionBnActivation(filters, (3, 3)) 39 | self.conv3x3_bn_relu_2 = ConvolutionBnActivation(filters, (3, 3)) 40 | 41 | self.pam = PAM_Module(filters) 42 | self.cam = CAM_Module(filters) 43 | 44 | self.conv3x3_bn_relu_3 = ConvolutionBnActivation(filters, (3, 3)) 45 |
self.conv3x3_bn_relu_4 = ConvolutionBnActivation(filters, (3, 3)) 46 | 47 | self.dropout_1 = tf.keras.layers.Dropout(0.1) 48 | self.dropout_2 = tf.keras.layers.Dropout(0.1) 49 | self.dropout_3 = tf.keras.layers.Dropout(0.1) 50 | 51 | self.conv1x1_bn_relu_1 = ConvolutionBnActivation(n_classes, (1, 1), post_activation="relu") 52 | self.conv1x1_bn_relu_2 = ConvolutionBnActivation(n_classes, (1, 1), post_activation="relu") 53 | self.conv1x1_bn_relu_3 = ConvolutionBnActivation(n_classes, (1, 1), post_activation="relu") 54 | 55 | axis = 3 if K.image_data_format() == "channels_last" else 1 56 | self.concat_1 = tf.keras.layers.Concatenate(axis=axis) 57 | self.concat_2 = tf.keras.layers.Concatenate(axis=axis) 58 | 59 | self.final_conv1x1_bn_activation = ConvolutionBnActivation(n_classes, (1, 1), post_activation=final_activation) 60 | 61 | def call(self, inputs, training=None, mask=None): 62 | if training is None: 63 | training = True 64 | 65 | x = self.backbone(inputs, training=training)[-1] 66 | 67 | x_pam = self.conv3x3_bn_relu_1(x, training=training) 68 | x_pam_out = self.pam(x_pam, training=training) 69 | x_pam = self.conv3x3_bn_relu_3(x_pam_out, training=training) 70 | x_pam = self.dropout_1(x_pam, training=training) 71 | x_pam = self.conv1x1_bn_relu_1(x_pam, training=training) 72 | 73 | x_cam = self.conv3x3_bn_relu_2(x, training=training) 74 | x_cam_out = self.cam(x_cam, training=training) 75 | x_cam = self.conv3x3_bn_relu_4(x_cam_out, training=training) 76 | x_cam = self.dropout_2(x_cam, training=training) 77 | x_cam = self.conv1x1_bn_relu_2(x_cam, training=training) 78 | 79 | # x_pam_cam = x_pam_out + x_cam_out # maybe add or concat layer 80 | x_pam_cam = self.concat_1([x_pam_out, x_cam_out]) 81 | x = self.dropout_3(x_pam_cam, training=training) 82 | x = self.conv1x1_bn_relu_3(x, training=training) 83 | 84 | x = self.concat_2([x_pam, x_cam, x]) 85 | x = self.final_conv1x1_bn_activation(x, training=training) 86 | x = self.final_upsampling2d(x) 87 | 88 | return x 89 | 
90 | def model(self): 91 | x = tf.keras.layers.Input(shape=(self.height, self.width, 3)) 92 | return tf.keras.Model(inputs=[x], outputs=self.call(x)) -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/models/DeepLab.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.keras.backend as K 3 | 4 | from ._custom_layers_and_blocks import ConvolutionBnActivation, AtrousSpatialPyramidPoolingV1 5 | from ..backbones.tf_backbones import create_base_model 6 | 7 | ################################################################################ 8 | # DeepLab 9 | ################################################################################ 10 | class DeepLab(tf.keras.Model): 11 | def __init__(self, n_classes, base_model, output_layers, height=None, width=None, filters=256, 12 | final_activation="softmax", backbone_trainable=False, 13 | final_upsample_factor=8, **kwargs): 14 | super(DeepLab, self).__init__(**kwargs) 15 | 16 | self.n_classes = n_classes 17 | self.backbone = None 18 | self.final_activation = final_activation 19 | self.filters = filters 20 | self.final_upsample_factor = final_upsample_factor 21 | self.height = height 22 | self.width = width 23 | 24 | 25 | if self.final_upsample_factor == 16: 26 | output_layers = output_layers[:4] 27 | self.final_upsample2d = tf.keras.layers.UpSampling2D(size=16) 28 | elif self.final_upsample_factor == 8: 29 | output_layers = output_layers[:3] 30 | self.final_upsample2d = tf.keras.layers.UpSampling2D(size=8) 31 | elif self.final_upsample_factor == 4: 32 | output_layers = output_layers[:2] 33 | self.final_upsample2d = tf.keras.layers.UpSampling2D(size=4) 34 | else: 35 | raise ValueError("'final_upsample_factor' must be one of (4, 8, 16), got {}".format(self.final_upsample_factor)) 36 | 37 | base_model.trainable = backbone_trainable 38 | self.backbone = 
tf.keras.Model(inputs=base_model.input, outputs=output_layers) 39 | 40 | # Define Layers 41 | self.aspp = AtrousSpatialPyramidPoolingV1(filters) 42 | 43 | self.final_conv1x1_bn_activation = ConvolutionBnActivation(n_classes, kernel_size=(1, 1), post_activation=final_activation) 44 | 45 | def call(self, inputs, training=None, mask=None): 46 | if training is None: 47 | training = True 48 | 49 | x = self.backbone(inputs, training=training)[-1] 50 | 51 | aspp = self.aspp(x, training=training) 52 | 53 | upsample = self.final_upsample2d(aspp) 54 | x = self.final_conv1x1_bn_activation(upsample, training=training) 55 | 56 | return x 57 | 58 | def model(self): 59 | x = tf.keras.layers.Input(shape=(self.height, self.width, 3)) 60 | return tf.keras.Model(inputs=[x], outputs=self.call(x)) -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/models/DeepLabV3.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.keras.backend as K 3 | 4 | from ._custom_layers_and_blocks import ConvolutionBnActivation, AtrousSeparableConvolutionBnReLU, AtrousSpatialPyramidPoolingV3 5 | from ..backbones.tf_backbones import create_base_model 6 | 7 | ################################################################################ 8 | # DeepLabV3 9 | ################################################################################ 10 | class DeepLabV3(tf.keras.Model): 11 | def __init__(self, n_classes, base_model, output_layers, height=None, width=None, filters=256, 12 | final_activation="softmax", backbone_trainable=False, 13 | output_stride=8, dilations=[6, 12, 18], **kwargs): 14 | super(DeepLabV3, self).__init__(**kwargs) 15 | 16 | self.n_classes = n_classes 17 | self.backbone = None 18 | self.filters = filters 19 | self.final_activation = final_activation 20 | self.output_stride = output_stride 21 | self.height = height 22 | self.width = width 23 | 24 | 25 | if
self.output_stride == 8: 26 | self.final_upsampling2d = tf.keras.layers.UpSampling2D(size=8, interpolation="bilinear") 27 | output_layers = output_layers[:3] 28 | self.dilations = [2 * rate for rate in dilations] 29 | elif self.output_stride == 16: 30 | self.final_upsampling2d = tf.keras.layers.UpSampling2D(size=16, interpolation="bilinear") 31 | output_layers = output_layers[:4] 32 | self.dilations = dilations 33 | else: 34 | raise ValueError("'output_stride' must be one of (8, 16), got {}".format(self.output_stride)) 35 | 36 | base_model.trainable = backbone_trainable 37 | self.backbone = tf.keras.Model(inputs=base_model.input, outputs=output_layers) 38 | 39 | # Define Layers 40 | self.atrous_sepconv_bn_relu = AtrousSeparableConvolutionBnReLU(dilation=2, filters=filters, kernel_size=3) 41 | self.aspp = AtrousSpatialPyramidPoolingV3(self.dilations, filters) 42 | 43 | self.conv1x1_bn_relu = ConvolutionBnActivation(filters, (1, 1)) 44 | self.conv1x1_bn_activation = ConvolutionBnActivation(n_classes, (1, 1), post_activation="linear") # final activation is applied once, after upsampling 45 | 46 | self.final_activation = tf.keras.layers.Activation(final_activation) 47 | 48 | def call(self, inputs, training=None, mask=None): 49 | if training is None: 50 | training = True 51 | 52 | x = self.backbone(inputs, training=training)[-1] 53 | 54 | x = self.atrous_sepconv_bn_relu(x, training=training) 55 | x = self.aspp(x, training=training) 56 | 57 | x = self.conv1x1_bn_relu(x, training=training) 58 | x = self.conv1x1_bn_activation(x, training=training) 59 | 60 | x = self.final_upsampling2d(x) 61 | x = self.final_activation(x) 62 | 63 | return x 64 | 65 | def model(self): 66 | x = tf.keras.layers.Input(shape=(self.height, self.width, 3)) 67 | return tf.keras.Model(inputs=[x], outputs=self.call(x)) -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/models/DeepLabV3plus.py: -------------------------------------------------------------------------------- 1 |
import tensorflow as tf 2 | 3 | from ._custom_layers_and_blocks import ConvolutionBnActivation, AtrousSeparableConvolutionBnReLU, AtrousSpatialPyramidPoolingV3 4 | from ..backbones.tf_backbones import create_base_model 5 | 6 | ################################################################################ 7 | # DeepLabV3+ 8 | ################################################################################ 9 | class DeepLabV3plus(tf.keras.Model): 10 | def __init__(self, n_classes, base_model, output_layers, height=None, width=None, filters=256, 11 | final_activation="softmax", backbone_trainable=False, 12 | output_stride=8, dilations=[6, 12, 18], **kwargs): 13 | super(DeepLabV3plus, self).__init__(**kwargs) 14 | 15 | self.n_classes = n_classes 16 | self.backbone = None 17 | self.filters = filters 18 | self.final_activation = final_activation 19 | self.output_stride = output_stride 20 | self.dilations = dilations 21 | self.height = height 22 | self.width = width 23 | 24 | 25 | if self.output_stride == 8: 26 | self.upsampling2d_1 = tf.keras.layers.UpSampling2D(size=2, interpolation="bilinear") 27 | output_layers = output_layers[:3] 28 | self.dilations = [2 * rate for rate in dilations] 29 | elif self.output_stride == 16: 30 | self.upsampling2d_1 = tf.keras.layers.UpSampling2D(size=4, interpolation="bilinear") 31 | output_layers = output_layers[:4] 32 | self.dilations = dilations 33 | else: 34 | raise ValueError("'output_stride' must be one of (8, 16), got {}".format(self.output_stride)) 35 | 36 | base_model.trainable = backbone_trainable 37 | self.backbone = tf.keras.Model(inputs=base_model.input, outputs=output_layers) 38 | 39 | # Define Layers 40 | self.atrous_sepconv_bn_relu_1 = AtrousSeparableConvolutionBnReLU(dilation=2, filters=filters, kernel_size=3) 41 | self.atrous_sepconv_bn_relu_2 = AtrousSeparableConvolutionBnReLU(dilation=2, filters=filters, kernel_size=3) 42 | self.aspp = AtrousSpatialPyramidPoolingV3(self.dilations, filters) 43 | 44 | 
self.conv1x1_bn_relu_1 = ConvolutionBnActivation(filters, 1) 45 | self.conv1x1_bn_relu_2 = ConvolutionBnActivation(64, 1) 46 | 47 | self.upsample2d_1 = tf.keras.layers.UpSampling2D(size=2, interpolation="bilinear") 48 | self.upsample2d_2 = tf.keras.layers.UpSampling2D(size=4, interpolation="bilinear") 49 | 50 | self.concat = tf.keras.layers.Concatenate(axis=3) 51 | 52 | self.conv3x3_bn_relu_1 = ConvolutionBnActivation(filters, 3) 53 | self.conv3x3_bn_relu_2 = ConvolutionBnActivation(filters, 3) 54 | self.conv1x1_bn_sigmoid = ConvolutionBnActivation(self.n_classes, 1, post_activation="linear") 55 | 56 | self.final_activation = tf.keras.layers.Activation(final_activation) 57 | 58 | def call(self, inputs, training=None, mask=None): 59 | if training is None: 60 | training = True 61 | 62 | features = self.backbone(inputs) # run the backbone once and reuse its outputs 63 | x, low_level_features = features[-1], features[1] 64 | 65 | # Encoder Module 66 | encoder = self.atrous_sepconv_bn_relu_1(x, training) 67 | encoder = self.aspp(encoder, training) 68 | encoder = self.conv1x1_bn_relu_1(encoder, training) 69 | encoder = self.upsampling2d_1(encoder) # x2 for output_stride=8, x4 for output_stride=16, to match the low-level features 70 | 71 | # Decoder Module 72 | decoder_low_level_features = self.atrous_sepconv_bn_relu_2(low_level_features, training) 73 | decoder_low_level_features = self.conv1x1_bn_relu_2(decoder_low_level_features, training) 74 | 75 | decoder = self.concat([decoder_low_level_features, encoder]) 76 | 77 | decoder = self.conv3x3_bn_relu_1(decoder, training) 78 | decoder = self.conv3x3_bn_relu_2(decoder, training) 79 | decoder = self.conv1x1_bn_sigmoid(decoder, training) 80 | 81 | decoder = self.upsample2d_2(decoder) 82 | decoder = self.final_activation(decoder) 83 | 84 | return decoder 85 | 86 | def model(self): 87 | x = tf.keras.layers.Input(shape=(self.height, self.width, 3)) 88 | return tf.keras.Model(inputs=[x], outputs=self.call(x)) -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/models/FCN.py:
-------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.keras.backend as K 3 | 4 | from ._custom_layers_and_blocks import ConvolutionBnActivation, Upsample_x2_Add_Block 5 | from ..backbones.tf_backbones import create_base_model 6 | 7 | ################################################################################ 8 | # Fully Convolutional Network 9 | ################################################################################ 10 | class FCN(tf.keras.Model): 11 | def __init__(self, n_classes, base_model, output_layers, height=None, width=None, filters=256, 12 | final_activation="softmax", backbone_trainable=False, 13 | backbone_output_factor=32, **kwargs): 14 | super(FCN, self).__init__(**kwargs) 15 | 16 | self.n_classes = n_classes 17 | self.backbone = None 18 | self.final_activation = final_activation 19 | self.filters = filters 20 | self.backbone_output_factor = backbone_output_factor 21 | self.height = height 22 | self.width = width 23 | 24 | 25 | if self.backbone_output_factor == 32: 26 | output_layers = output_layers[:5] 27 | self.final_upsample2d = tf.keras.layers.UpSampling2D(size=8) 28 | elif self.backbone_output_factor == 16: 29 | output_layers = output_layers[:4] 30 | self.final_upsample2d = tf.keras.layers.UpSampling2D(size=4) 31 | elif self.backbone_output_factor == 8: 32 | output_layers = output_layers[:3] 33 | self.final_upsample2d = tf.keras.layers.UpSampling2D(size=2) 34 | else: 35 | raise ValueError("'backbone_output_factor' must be one of (8, 16, 32), got {}".format(self.backbone_output_factor)) 36 | 37 | base_model.trainable = backbone_trainable 38 | self.backbone = tf.keras.Model(inputs=base_model.input, outputs=output_layers) 39 | 40 | # Define Layers 41 | self.conv1x1_bn_relu = ConvolutionBnActivation(n_classes, kernel_size=(1, 1), post_activation="relu") 42 | 43 | self.upsample2d_x2_add_block1 = Upsample_x2_Add_Block(n_classes) 44 | self.upsample2d_x2_add_block2 
= Upsample_x2_Add_Block(n_classes) 45 | 46 | self.final_conv1x1_bn_activation = ConvolutionBnActivation(n_classes, kernel_size=(1, 1), post_activation=final_activation) 47 | 48 | def call(self, inputs, training=None, mask=None): 49 | if training is None: 50 | training = True 51 | 52 | features = self.backbone(inputs) # run the backbone once and reuse its outputs 53 | x = features[-1] 54 | x = self.conv1x1_bn_relu(x, training=training) 55 | 56 | upsample = self.upsample2d_x2_add_block1(x, features[-2], training) 57 | upsample = self.upsample2d_x2_add_block2(upsample, features[-3], training) 58 | 59 | upsample = self.final_upsample2d(upsample) 60 | x = self.final_conv1x1_bn_activation(upsample, training=training) 61 | 62 | return x 63 | 64 | def model(self): 65 | x = tf.keras.layers.Input(shape=(self.height, self.width, 3)) 66 | return tf.keras.Model(inputs=[x], outputs=self.call(x)) -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/models/FPNet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.keras.backend as K 3 | 4 | from ._custom_layers_and_blocks import ConvolutionBnActivation, FPNBlock 5 | from ..backbones.tf_backbones import create_base_model 6 | 7 | ################################################################################ 8 | # Feature Pyramid Network 9 | ################################################################################ 10 | class FPNet(tf.keras.models.Model): 11 | def __init__(self, n_classes, base_model, output_layers, height=None, width=None, filters=128, 12 | final_activation="softmax", backbone_trainable=False, 13 | pyramid_filters=256, aggregation="sum", dropout=None, **kwargs): 14 | super(FPNet, self).__init__(**kwargs) 15 | 16 | self.n_classes = n_classes 17 | self.backbone = None 18 | self.final_activation = final_activation 19 | self.filters = filters 20 | self.pyramid_filters = pyramid_filters 21 | self.aggregation = aggregation 22 |
self.dropout = dropout 23 | self.height = height 24 | self.width = width 25 | 26 | 27 | self.axis = 3 if K.image_data_format() == "channels_last" else 1 28 | 29 | base_model.trainable = backbone_trainable 30 | self.backbone = tf.keras.Model(inputs=base_model.input, outputs=output_layers) 31 | 32 | # Define Layers 33 | self.fpn_block_p5 = FPNBlock(pyramid_filters) 34 | self.fpn_block_p4 = FPNBlock(pyramid_filters) 35 | self.fpn_block_p3 = FPNBlock(pyramid_filters) 36 | self.fpn_block_p2 = FPNBlock(pyramid_filters) 37 | 38 | self.conv3x3_bn_relu_1 = ConvolutionBnActivation(filters, (3, 3)) 39 | self.conv3x3_bn_relu_2 = ConvolutionBnActivation(filters, (3, 3)) 40 | self.conv3x3_bn_relu_3 = ConvolutionBnActivation(filters, (3, 3)) 41 | self.conv3x3_bn_relu_4 = ConvolutionBnActivation(filters, (3, 3)) 42 | self.conv3x3_bn_relu_5 = ConvolutionBnActivation(filters, (3, 3)) 43 | self.conv3x3_bn_relu_6 = ConvolutionBnActivation(filters, (3, 3)) 44 | self.conv3x3_bn_relu_7 = ConvolutionBnActivation(filters, (3, 3)) 45 | self.conv3x3_bn_relu_8 = ConvolutionBnActivation(filters, (3, 3)) 46 | 47 | self.upsample2d_s5 = tf.keras.layers.UpSampling2D((8, 8), interpolation="nearest") 48 | self.upsample2d_s4 = tf.keras.layers.UpSampling2D((4, 4), interpolation="nearest") 49 | self.upsample2d_s3 = tf.keras.layers.UpSampling2D((2, 2), interpolation="nearest") 50 | 51 | 52 | self.add = tf.keras.layers.Add() 53 | self.concat = tf.keras.layers.Concatenate(axis=self.axis) 54 | 55 | self.spatial_dropout = tf.keras.layers.SpatialDropout2D(dropout) 56 | self.pre_final_conv3x3_bn_relu = ConvolutionBnActivation(filters, (3, 3)) 57 | self.final_upsample2d = tf.keras.layers.UpSampling2D(size=2, interpolation="bilinear") 58 | 59 | self.final_conv3x3 = tf.keras.layers.Conv2D(self.n_classes, (3, 3), strides=(1, 1), padding='same') 60 | self.final_activation = tf.keras.layers.Activation(final_activation) 61 | 62 | 63 | def call(self, inputs, training=None, mask=None): 64 | if self.axis == 3: 65 | if 
inputs.shape[1] % 160 != 0 or inputs.shape[2] % 160 != 0: 66 | raise ValueError("Input height and width must be multiples of 160, got height = " + str(inputs.shape[1]) + " and width = " + str(inputs.shape[2]) + ".") 67 | 68 | if training is None: 69 | training = True 70 | 71 | features = self.backbone(inputs) # run the backbone once and reuse its outputs 72 | x = features[4] 73 | p5 = self.fpn_block_p5(x, features[3], training=training) 74 | p4 = self.fpn_block_p4(p5, features[2], training=training) 75 | p3 = self.fpn_block_p3(p4, features[1], training=training) 76 | p2 = self.fpn_block_p2(p3, features[0], training=training) 77 | 78 | s5 = self.conv3x3_bn_relu_1(p5, training=training) 79 | s5 = self.conv3x3_bn_relu_2(s5, training=training) 80 | s4 = self.conv3x3_bn_relu_3(p4, training=training) 81 | s4 = self.conv3x3_bn_relu_4(s4, training=training) 82 | s3 = self.conv3x3_bn_relu_5(p3, training=training) 83 | s3 = self.conv3x3_bn_relu_6(s3, training=training) 84 | s2 = self.conv3x3_bn_relu_7(p2, training=training) 85 | s2 = self.conv3x3_bn_relu_8(s2, training=training) 86 | 87 | s5 = self.upsample2d_s5(s5) 88 | s4 = self.upsample2d_s4(s4) 89 | s3 = self.upsample2d_s3(s3) 90 | 91 | if self.aggregation == "sum": 92 | x = self.add([s2, s3, s4, s5]) 93 | elif self.aggregation == "concat": 94 | x = self.concat([s2, s3, s4, s5]) 95 | else: 96 | raise ValueError("Aggregation parameter should be one of ['sum', 'concat'], got {}".format(self.aggregation)) 97 | 98 | if self.dropout is not None: 99 | if self.dropout >= 1 or self.dropout < 0: 100 | raise ValueError("'dropout' must be in the range [0, 1), got {}".format(self.dropout)) 101 | else: 102 | x = self.spatial_dropout(x, training=training) 103 | 104 | # Final Stage 105 | x = self.pre_final_conv3x3_bn_relu(x, training=training) 106 | x = self.final_upsample2d(x) 107 | 108 | x = self.final_conv3x3(x) 109 | x = self.final_activation(x) 110 | 111 | return x 112 | 113 | def model(self): 114 | x = tf.keras.layers.Input(shape=(self.height, self.width,
3)) 115 | return tf.keras.Model(inputs=[x], outputs=self.call(x)) -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/models/HRNetOCR.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.keras.backend as K 3 | 4 | from ._custom_layers_and_blocks import ConvolutionBnActivation, BottleneckBlock, HighResolutionModule,SpatialGather_Module, SpatialOCR_Module 5 | from ..backbones.tf_backbones import create_base_model 6 | 7 | ################################################################################ 8 | # High Resolution Network + Object-Contextual Representations 9 | ################################################################################ 10 | class HRNetOCR(tf.keras.Model): 11 | def __init__(self, n_classes, filters=64, height=None, width=None, final_activation="softmax", 12 | spatial_ocr_scale=1, spatial_context_scale=1, **kwargs): 13 | super(HRNetOCR, self).__init__(**kwargs) 14 | 15 | self.n_classes = n_classes 16 | self.filters = filters 17 | self.height = height 18 | self.width = width 19 | self.final_activation = final_activation 20 | self.spatial_ocr_scale = spatial_ocr_scale 21 | self.spatial_context_scale = spatial_context_scale 22 | 23 | axis = 3 if K.image_data_format() == "channels_last" else 1 24 | 25 | # Stem Net 26 | ### Probably set strides to default, i.e. 
(1, 1) 27 | self.conv3x3_bn_relu_1 = ConvolutionBnActivation(filters, (3, 3)) 28 | self.conv3x3_bn_relu_2 = ConvolutionBnActivation(filters, (3, 3)) 29 | 30 | # stage 1 31 | self.bottleneck_downsample = BottleneckBlock(64, downsample=True) 32 | self.bottleneck_1 = BottleneckBlock(64) 33 | self.bottleneck_2 = BottleneckBlock(64) 34 | self.bottleneck_3 = BottleneckBlock(64) 35 | 36 | # Stage 2 37 | # Transition 38 | self.conv3x3_bn_relu_stage2_1 = ConvolutionBnActivation(48, (3, 3), momentum=0.1) 39 | self.conv3x3_bn_relu_stage2_2 = ConvolutionBnActivation(96, (3, 3), strides=(2, 2), momentum=0.1) 40 | 41 | # Stage 42 | # num_modules=1, num_branches=2, blocks=[4, 4], channels=[48, 96] 43 | self.hrn_stage2_module_1 = HighResolutionModule(num_branches=2, blocks=[4, 4], filters=[48, 96]) 44 | self.hrn_stage2_module_2 = HighResolutionModule(num_branches=2, blocks=[4, 4], filters=[48, 96]) 45 | 46 | # Stage 3 47 | # Transition 48 | self.conv3x3_bn_relu_stage3 = ConvolutionBnActivation(192, (3, 3), strides=(2, 2), momentum=0.1) 49 | 50 | # Stage 51 | # num_modules=4, num_branches=3, blocks=[4, 4, 4], channels=[48, 96, 192] 52 | self.hrn_stage3_module_1 = HighResolutionModule(num_branches=3, blocks=[4, 4, 4], filters=[48, 96, 192]) 53 | self.hrn_stage3_module_2 = HighResolutionModule(num_branches=3, blocks=[4, 4, 4], filters=[48, 96, 192]) 54 | self.hrn_stage3_module_3 = HighResolutionModule(num_branches=3, blocks=[4, 4, 4], filters=[48, 96, 192]) 55 | self.hrn_stage3_module_4 = HighResolutionModule(num_branches=3, blocks=[4, 4, 4], filters=[48, 96, 192]) 56 | 57 | # Stage 4 58 | # Transition 59 | self.conv3x3_bn_relu_stage4 = ConvolutionBnActivation(384, (3, 3), strides=(2, 2), momentum=0.1) 60 | 61 | # Stage 62 | # num_modules=3, num_branches=4, num_blocks=[4, 4, 4, 4], num_channels=[48, 96, 192, 384] 63 | self.hrn_stage4_module_1 = HighResolutionModule(num_branches=4, blocks=[4, 4, 4, 4], filters=[48, 96, 192, 384]) 64 | self.hrn_stage4_module_2 = 
HighResolutionModule(num_branches=4, blocks=[4, 4, 4, 4], filters=[48, 96, 192, 384]) 65 | self.hrn_stage4_module_3 = HighResolutionModule(num_branches=4, blocks=[4, 4, 4, 4], filters=[48, 96, 192, 384]) 66 | 67 | # Upsampling and Concatenation of stages 68 | self.upsample_x2 = tf.keras.layers.UpSampling2D(size=2, interpolation="bilinear") 69 | self.upsample_x4 = tf.keras.layers.UpSampling2D(size=4, interpolation="bilinear") 70 | self.upsample_x8 = tf.keras.layers.UpSampling2D(size=8, interpolation="bilinear") 71 | 72 | self.concat = tf.keras.layers.Concatenate(axis=axis) 73 | 74 | # OCR 75 | self.aux_head = tf.keras.Sequential([ 76 | ConvolutionBnActivation(720, (1, 1)), 77 | tf.keras.layers.Conv2D(filters=self.n_classes, kernel_size=(1, 1), use_bias=True), 78 | tf.keras.layers.Activation(final_activation) 79 | ]) 80 | self.conv3x3_bn_relu_ocr = ConvolutionBnActivation(512, (3, 3)) 81 | 82 | self.spatial_context = SpatialGather_Module(scale=spatial_context_scale) 83 | self.spatial_ocr = SpatialOCR_Module(512, scale=spatial_ocr_scale, dropout=0.05) 84 | 85 | self.final_conv3x3 = tf.keras.layers.Conv2D(filters=self.n_classes, kernel_size=(1, 1), use_bias=True) 86 | self.final_activation = tf.keras.layers.Activation(final_activation) 87 | 88 | def call(self, inputs, training=None, mask=None): 89 | if training is None: 90 | training = True 91 | 92 | x = self.conv3x3_bn_relu_1(inputs, training=training) 93 | x = self.conv3x3_bn_relu_2(x, training=training) 94 | 95 | # Stage 1 96 | x = self.bottleneck_downsample(x, training=training) 97 | x = self.bottleneck_1(x, training=training) 98 | x = self.bottleneck_2(x, training=training) 99 | x = self.bottleneck_3(x, training=training) 100 | 101 | # Stage 2 102 | x_1 = self.conv3x3_bn_relu_stage2_1(x, training=training) 103 | x_2 = self.conv3x3_bn_relu_stage2_2(x, training=training) # includes strided convolution 104 | 105 | y_list = self.hrn_stage2_module_1(x_1, x_2, None, None, training=training) 106 | y_list =
self.hrn_stage2_module_2(y_list[0], y_list[1], None, None, training=training) 107 | 108 | # Stage 3 109 | x_3 = self.conv3x3_bn_relu_stage3(y_list[1], training=training) # includes strided convolution 110 | 111 | y_list = self.hrn_stage3_module_1(y_list[0], y_list[1], x_3, None, training=training) 112 | y_list = self.hrn_stage3_module_2(y_list[0], y_list[1], y_list[2], None, training=training) 113 | y_list = self.hrn_stage3_module_3(y_list[0], y_list[1], y_list[2], None, training=training) 114 | y_list = self.hrn_stage3_module_4(y_list[0], y_list[1], y_list[2], None, training=training) 115 | 116 | # Stage 4 117 | x_4 = self.conv3x3_bn_relu_stage4(y_list[2], training=training) 118 | 119 | y_list = self.hrn_stage4_module_1(y_list[0], y_list[1], y_list[2], x_4, training=training) 120 | y_list = self.hrn_stage4_module_2(y_list[0], y_list[1], y_list[2], y_list[3], training=training) 121 | y_list = self.hrn_stage4_module_3(y_list[0], y_list[1], y_list[2], y_list[3], training=training) 122 | 123 | # Upsampling + Concatenation 124 | x_2 = self.upsample_x2(y_list[1]) 125 | x_3 = self.upsample_x4(y_list[2]) 126 | x_4 = self.upsample_x8(y_list[3]) 127 | 128 | feats = self.concat([y_list[0], x_2, x_3, x_4]) 129 | 130 | # OCR 131 | aux = self.aux_head(feats, training=training) 132 | 133 | feats = self.conv3x3_bn_relu_ocr(feats, training=training) 134 | 135 | context = self.spatial_context(feats, aux, training=training) 136 | feats = self.spatial_ocr(feats, context, training=training) 137 | 138 | out = self.final_conv3x3(feats) 139 | out = self.final_activation(out) 140 | 141 | return out 142 | 143 | def model(self): 144 | x = tf.keras.layers.Input(shape=(self.height, self.width, 3)) 145 | return tf.keras.Model(inputs=[x], outputs=self.call(x)) 146 | -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/models/OCNet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.keras.backend as K 3 | 4 | from
._custom_layers_and_blocks import ConvolutionBnActivation, Base_OC_Module, Pyramid_OC_Module, ASP_OC_Module 5 | from ..backbones.tf_backbones import create_base_model 6 | 7 | ################################################################################ 8 | # Object Context Network 9 | ################################################################################ 10 | class OCNet(tf.keras.Model): 11 | def __init__(self, n_classes, base_model, output_layers, height=None, width=None, filters=256, 12 | final_activation="softmax", backbone_trainable=False, 13 | output_stride=8, dilations=[6, 12, 18], oc_module="base_oc", **kwargs): 14 | super(OCNet, self).__init__(**kwargs) 15 | 16 | self.n_classes = n_classes 17 | self.backbone = None 18 | self.filters = filters 19 | self.final_activation = final_activation 20 | self.output_stride = output_stride 21 | self.height = height 22 | self.width = width 23 | 24 | 25 | if self.output_stride == 8: 26 | self.final_upsampling2d = tf.keras.layers.UpSampling2D(size=8, interpolation="bilinear") 27 | output_layers = output_layers[:3] 28 | self.dilations = [2 * rate for rate in dilations] 29 | elif self.output_stride == 16: 30 | self.final_upsampling2d = tf.keras.layers.UpSampling2D(size=16, interpolation="bilinear") 31 | output_layers = output_layers[:4] 32 | self.dilations = dilations 33 | else: 34 | raise ValueError("'output_stride' must be one of (8, 16), got {}".format(self.output_stride)) 35 | 36 | base_model.trainable = backbone_trainable 37 | self.backbone = tf.keras.Model(inputs=base_model.input, outputs=output_layers) 38 | 39 | # Define Layers 40 | self.conv3x3_bn_relu = ConvolutionBnActivation(filters, (3, 3)) 41 | if oc_module == "base_oc": 42 | self.oc = Base_OC_Module(filters) 43 | elif oc_module == "pyramid_oc": 44 | self.oc = Pyramid_OC_Module(filters=filters, levels=[1, 2, 3, 6]) 45 | elif oc_module == "asp_oc": 46 | self.oc = ASP_OC_Module(filters, self.dilations) 47 | else: 48 | raise
ValueError("'oc_module' must be one of ('base_oc', 'pyramid_oc', 'asp_oc'), got {}".format(oc_module)) 49 | 50 | self.final_conv1x1_bn_activation = ConvolutionBnActivation(n_classes, (1, 1), post_activation=final_activation) 51 | 52 | def call(self, inputs, training=None, mask=None): 53 | if training is None: 54 | training = True 55 | 56 | x = self.backbone(inputs, training=training)[-1] 57 | 58 | x = self.conv3x3_bn_relu(x, training=training) 59 | 60 | if K.image_data_format() == "channels_last": 61 | if x.shape[1] % 6 != 0 or x.shape[2] % 6 != 0: 62 | raise ValueError("Height and Width of the backbone output must be divisible by 6, i.e. \ 63 | input_height / output_stride and input_width / output_stride must be divisible by 6.") 64 | else: 65 | if x.shape[2] % 6 != 0 or x.shape[3] % 6 != 0: 66 | raise ValueError("Height and Width of the backbone output must be divisible by 6, i.e. \ 67 | input_height / output_stride and input_width / output_stride must be divisible by 6.") 68 | 69 | x = self.oc(x, training=training) 70 | 71 | x = self.final_conv1x1_bn_activation(x, training=training) 72 | 73 | x = self.final_upsampling2d(x) 74 | 75 | return x 76 | 77 | def model(self): 78 | x = tf.keras.layers.Input(shape=(self.height, self.width, 3)) 79 | return tf.keras.Model(inputs=[x], outputs=self.call(x)) 80 | -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/models/PSPNet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.keras.backend as K 3 | 4 | from ._custom_layers_and_blocks import ConvolutionBnActivation, SpatialContextBlock 5 | from ..backbones.tf_backbones import create_base_model 6 | 7 | ################################################################################ 8 | # Pyramid Scene Parsing Network 9 | ################################################################################ 10 | class PSPNet(tf.keras.models.Model): 11 |
def __init__(self, n_classes, base_model, output_layers, height=None, width=None, filters=256, 12 | final_activation="softmax", backbone_trainable=False, 13 | dropout=None, pooling_type="avg", final_upsample_factor=2, **kwargs): 14 | super(PSPNet, self).__init__(**kwargs) 15 | 16 | self.n_classes = n_classes 17 | self.backbone = None 18 | self.final_activation = final_activation 19 | self.filters = filters 20 | self.dropout = dropout 21 | self.pooling_type = pooling_type 22 | self.final_upsample_factor = final_upsample_factor 23 | self.height = height 24 | self.width = width 25 | 26 | 27 | axis = 3 if K.image_data_format() == "channels_last" else 1 28 | 29 | if self.final_upsample_factor == 8: 30 | output_layers = output_layers[:3] 31 | self.final_upsample2d = tf.keras.layers.UpSampling2D(size=final_upsample_factor, interpolation="bilinear") 32 | elif self.final_upsample_factor == 4: 33 | output_layers = output_layers[:2] 34 | self.final_upsample2d = tf.keras.layers.UpSampling2D(size=final_upsample_factor, interpolation="bilinear") 35 | elif self.final_upsample_factor == 2: 36 | output_layers = output_layers[:1] 37 | self.final_upsample2d = tf.keras.layers.UpSampling2D(size=final_upsample_factor, interpolation="bilinear") 38 | else: 39 | raise ValueError("'final_upsample_factor' must be one of (2, 4, 8), got {}".format(self.final_upsample_factor)) 40 | 41 | base_model.trainable = backbone_trainable 42 | self.backbone = tf.keras.Model(inputs=base_model.input, outputs=output_layers) 43 | 44 | # Define Layers 45 | self.spatial_context_block_1 = SpatialContextBlock(1, filters, pooling_type) 46 | self.spatial_context_block_2 = SpatialContextBlock(2, filters, pooling_type) 47 | self.spatial_context_block_3 = SpatialContextBlock(3, filters, pooling_type) 48 | self.spatial_context_block_4 = SpatialContextBlock(6, filters, pooling_type) 49 | 50 | self.concat = tf.keras.layers.Concatenate(axis=axis) 51 | self.conv1x1_bn_relu = ConvolutionBnActivation(filters, (1, 1)) 52 | 53 | if
dropout is not None: 54 | self.spatial_dropout = tf.keras.layers.SpatialDropout2D(dropout) 55 | 56 | self.final_conv3x3 = tf.keras.layers.Conv2D(self.n_classes, (3, 3), strides=(1, 1), padding='same') 57 | self.final_activation = tf.keras.layers.Activation(final_activation) 58 | 59 | 60 | def call(self, inputs, training=None, mask=None): 61 | if training is None: 62 | training = True 63 | 64 | if self.final_upsample_factor == 2: 65 | x = self.backbone(inputs) 66 | else: 67 | x = self.backbone(inputs)[-1] 68 | 69 | if K.image_data_format() == "channels_last": 70 | if x.shape[1] % 6 != 0 or x.shape[2] % 6 != 0: 71 | raise ValueError("Height and Width of the backbone output must be divisible by 6, i.e. \ 72 | input_height / final_upsample_factor and input_width / final_upsample_factor must be divisible by 6.") 73 | else: 74 | if x.shape[2] % 6 != 0 or x.shape[3] % 6 != 0: 75 | raise ValueError("Height and Width of the backbone output must be divisible by 6, i.e. \ 76 | input_height / final_upsample_factor and input_width / final_upsample_factor must be divisible by 6.") 77 | 78 | 79 | x1 = self.spatial_context_block_1(x, training=training) 80 | x2 = self.spatial_context_block_2(x, training=training) 81 | x3 = self.spatial_context_block_3(x, training=training) 82 | x6 = self.spatial_context_block_4(x, training=training) 83 | 84 | x = self.concat([x1, x2, x3, x6]) 85 | x = self.conv1x1_bn_relu(x, training=training) 86 | 87 | if self.dropout is not None: 88 | x = self.spatial_dropout(x, training=training) 89 | 90 | x = self.final_conv3x3(x) 91 | x = self.final_upsample2d(x) 92 | x = self.final_activation(x) 93 | 94 | return x 95 | 96 | def model(self): 97 | x = tf.keras.layers.Input(shape=(self.height, self.width, 3)) 98 | return tf.keras.Model(inputs=[x], outputs=self.call(x)) -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/models/SpatialOCRNet.py: -------------------------------------------------------------------------------- 1 | import
tensorflow as tf 2 | import tensorflow.keras.backend as K 3 | 4 | from ._custom_layers_and_blocks import ConvolutionBnActivation, SpatialGather_Module, SpatialOCR_Module 5 | from ..backbones.tf_backbones import create_base_model 6 | 7 | ################################################################################ 8 | # Spatial Object-Contextual Representations Network 9 | ################################################################################ 10 | class SpatialOCRNet(tf.keras.Model): 11 | def __init__(self, n_classes, base_model, output_layers, height=None, width=None, filters=256, 12 | final_activation="softmax", backbone_trainable=False, 13 | spatial_ocr_scale=1, spatial_context_scale=1, **kwargs): 14 | super(SpatialOCRNet, self).__init__(**kwargs) 15 | 16 | self.n_classes = n_classes 17 | self.backbone = None 18 | self.filters = filters 19 | self.final_activation = final_activation 20 | self.spatial_ocr_scale = spatial_ocr_scale 21 | self.spatial_context_scale = spatial_context_scale 22 | self.height = height 23 | self.width = width 24 | 25 | 26 | output_layers = output_layers[:4] 27 | 28 | base_model.trainable = backbone_trainable 29 | self.backbone = tf.keras.Model(inputs=base_model.input, outputs=output_layers) 30 | 31 | # Layers 32 | self.conv3x3_bn_relu_1 = ConvolutionBnActivation(filters, (3, 3)) 33 | self.upsampling2d_x2 = tf.keras.layers.UpSampling2D(size=2, interpolation="bilinear") 34 | self.conv3x3_bn_relu_2 = ConvolutionBnActivation(filters, (3, 3)) 35 | self.dropout_1 = tf.keras.layers.Dropout(0.05) 36 | self.conv1x1_bn_activation = ConvolutionBnActivation(filters, (1, 1), post_activation=final_activation) 37 | 38 | self.spatial_context = SpatialGather_Module(scale=spatial_context_scale) 39 | self.spatial_ocr = SpatialOCR_Module(filters, scale=spatial_ocr_scale, dropout=0.05) 40 | 41 | self.conv3x3_bn_relu_3 = ConvolutionBnActivation(filters, (3, 3)) 42 | self.dropout = tf.keras.layers.Dropout(0.05) 43 | self.final_conv1x1_bn_activation 
= ConvolutionBnActivation(self.n_classes, (1, 1), post_activation=final_activation) 44 | self.final_upsampling2d = tf.keras.layers.UpSampling2D(size=8, interpolation="bilinear") 45 | 46 | def call(self, inputs, training=None, mask=None): 47 | if training is None: 48 | training = True 49 | 50 | x0, x1, x2, x3 = self.backbone(inputs, training=training) 51 | 52 | x = self.conv3x3_bn_relu_1(x3, training=training) 53 | x = self.upsampling2d_x2(x) 54 | 55 | x_dsn = self.conv3x3_bn_relu_2(x2, training=training) 56 | x_dsn = self.dropout_1(x_dsn, training=training) 57 | x_dsn = self.conv1x1_bn_activation(x_dsn, training=training) 58 | 59 | context = self.spatial_context(x, x_dsn, training=training) 60 | 61 | x = self.spatial_ocr(x, context, training=training) 62 | 63 | x = self.conv3x3_bn_relu_3(x, training=training) 64 | x = self.dropout(x, training=training) 65 | x = self.final_conv1x1_bn_activation(x, training=training) 66 | x = self.final_upsampling2d(x) 67 | 68 | return x 69 | 70 | def model(self): 71 | x = tf.keras.layers.Input(shape=(self.height, self.width, 3)) 72 | return tf.keras.Model(inputs=[x], outputs=self.call(x)) -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/models/UNet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from ._custom_layers_and_blocks import ConvolutionBnActivation, Upsample_x2_Block 4 | from ..backbones.tf_backbones import create_base_model 5 | 6 | ################################################################################ 7 | # UNet 8 | ################################################################################ 9 | class UNet(tf.keras.Model): 10 | def __init__(self, n_classes, base_model, output_layers, height=None, width=None, filters=128, 11 | final_activation="softmax", backbone_trainable=False, 12 | up_filters=[32, 64, 128, 256, 512], include_top_conv=True, **kwargs): 13 |
super(UNet, self).__init__(**kwargs) 14 | 15 | self.n_classes = n_classes 16 | self.backbone = None 17 | self.final_activation = final_activation 18 | self.filters = filters 19 | self.up_filters = up_filters 20 | self.include_top_conv = include_top_conv 21 | self.height = height 22 | self.width = width 23 | 24 | 25 | base_model.trainable = backbone_trainable 26 | self.backbone = tf.keras.Model(inputs=base_model.input, outputs=output_layers) 27 | 28 | # Define Layers 29 | self.conv3x3_bn_relu1 = ConvolutionBnActivation(filters, kernel_size=(3, 3), post_activation="relu") 30 | self.conv3x3_bn_relu2 = ConvolutionBnActivation(filters, kernel_size=(3, 3), post_activation="relu") 31 | 32 | self.upsample2d_x2_block1 = Upsample_x2_Block(up_filters[4]) 33 | self.upsample2d_x2_block2 = Upsample_x2_Block(up_filters[3]) 34 | self.upsample2d_x2_block3 = Upsample_x2_Block(up_filters[2]) 35 | self.upsample2d_x2_block4 = Upsample_x2_Block(up_filters[1]) 36 | self.upsample2d_x2_block5 = Upsample_x2_Block(up_filters[0]) 37 | 38 | self.final_conv3x3 = tf.keras.layers.Conv2D(self.n_classes, (3, 3), strides=(1, 1), padding='same') 39 | 40 | self.final_activation = tf.keras.layers.Activation(final_activation) 41 | 42 | def call(self, inputs, training=None, mask=None): 43 | if training is None: 44 | training = True 45 | 46 | if self.include_top_conv: 47 | conv1 = self.conv3x3_bn_relu1(inputs, training=training) 48 | conv1 = self.conv3x3_bn_relu2(conv1, training=training) 49 | else: 50 | conv1 = None 51 | 52 | features = self.backbone(inputs) # run the backbone once and reuse its outputs 53 | x = features[4] 54 | upsample = self.upsample2d_x2_block1(x, features[3], training) 55 | upsample = self.upsample2d_x2_block2(upsample, features[2], training) 56 | upsample = self.upsample2d_x2_block3(upsample, features[1], training) 57 | upsample = self.upsample2d_x2_block4(upsample, features[0], training) 58 | upsample = self.upsample2d_x2_block5(upsample, conv1, training) 59 | 60 | x =
self.final_conv3x3(upsample, training=training) 61 | x = self.final_activation(x) 62 | 63 | return x 64 | 65 | def model(self): 66 | x = tf.keras.layers.Input(shape=(self.height, self.width, 3)) 67 | return tf.keras.Model(inputs=[x], outputs=self.call(x)) -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/models/__pycache__/DANet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models/3714839ee49759b26e2b0ae3d3a0aa37b00df962/tensorflow_advanced_segmentation_models/models/__pycache__/DANet.cpython-37.pyc -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/models/__pycache__/DeepLab.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models/3714839ee49759b26e2b0ae3d3a0aa37b00df962/tensorflow_advanced_segmentation_models/models/__pycache__/DeepLab.cpython-37.pyc -------------------------------------------------------------------------------- /tensorflow_advanced_segmentation_models/models/__pycache__/DeepLabV3.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models/3714839ee49759b26e2b0ae3d3a0aa37b00df962/tensorflow_advanced_segmentation_models/models/__pycache__/DeepLabV3.cpython-37.pyc -------------------------------------------------------------------------------- 
/tensorflow_advanced_segmentation_models/models/_custom_layers_and_blocks.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import tensorflow.keras.backend as K

################################################################################
# Layers
################################################################################
from tensorflow.keras import activations


class ConvolutionBnActivation(tf.keras.layers.Layer):
    """Conv2D -> BatchNormalization -> post-activation block.

    BatchNormalization is always applied; ``use_batchnorm=True`` only removes the
    convolution bias. ``groups`` is accepted but not forwarded to Conv2D.
    """
    def __init__(self, filters, kernel_size, strides=(1, 1), padding="same", data_format=None, dilation_rate=(1, 1),
                 groups=1, activation=None, kernel_initializer="glorot_uniform", bias_initializer="zeros", kernel_regularizer=None,
                 bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None, use_batchnorm=False,
                 axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True, trainable=True,
                 post_activation="relu", block_name=None, **kwargs):
        super(ConvolutionBnActivation, self).__init__(**kwargs)

        # 2D Convolution Arguments
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.data_format = data_format
        self.dilation_rate = dilation_rate
        self.activation = activation
        self.use_bias = not use_batchnorm
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer
        self.kernel_regularizer = kernel_regularizer
        self.bias_regularizer = bias_regularizer
        self.activity_regularizer = activity_regularizer
        self.kernel_constraint = kernel_constraint
        self.bias_constraint = bias_constraint

        # Batch Normalization Arguments
        self.axis = axis
        self.momentum = momentum
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.trainable = trainable

        self.block_name = block_name

        self.conv = None
        self.bn = None
        self.post_activation = activations.get(post_activation)

    def build(self, input_shape):
        self.conv = tf.keras.layers.Conv2D(
            filters=self.filters,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
            activation=self.activation,
            use_bias=self.use_bias,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            kernel_regularizer=self.kernel_regularizer,
            bias_regularizer=self.bias_regularizer,
            activity_regularizer=self.activity_regularizer,
            kernel_constraint=self.kernel_constraint,
            bias_constraint=self.bias_constraint,
            name=self.block_name + "_conv" if self.block_name is not None else None
        )

        self.bn = tf.keras.layers.BatchNormalization(
            axis=self.axis,
            momentum=self.momentum,
            epsilon=self.epsilon,
            center=self.center,
            scale=self.scale,
            trainable=self.trainable,
            name=self.block_name + "_bn" if self.block_name is not None else None
        )

    def call(self, x, training=None):
        x = self.conv(x)
        x = self.bn(x, training=training)
        x = self.post_activation(x)

        return x

    def compute_output_shape(self, input_shape):
        # Assumes stride 1, "same" padding and channels_last data
        return [input_shape[0], input_shape[1], input_shape[2], self.filters]

    def get_config(self):
        config = {
            # 2D Convolution Arguments
            "filters": self.filters,
            "kernel_size": self.kernel_size,
            "strides": self.strides,
            "padding": self.padding,
            "data_format": self.data_format,
            "dilation_rate": self.dilation_rate,
            "activation": activations.serialize(self.activation),
            "use_batchnorm": not self.use_bias,
            "kernel_initializer": self.kernel_initializer,
            "bias_initializer": self.bias_initializer,
            "kernel_regularizer":
                self.kernel_regularizer,
            "bias_regularizer": self.bias_regularizer,
            "activity_regularizer": self.activity_regularizer,
            "kernel_constraint": self.kernel_constraint,
            "bias_constraint": self.bias_constraint,
            # Batch Normalization Arguments
            "axis": self.axis,
            "momentum": self.momentum,
            "epsilon": self.epsilon,
            "center": self.center,
            "scale": self.scale,
            "trainable": self.trainable,
            "block_name": self.block_name,
            "post_activation": activations.serialize(self.post_activation),
        }
        base_config = super(ConvolutionBnActivation, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class AtrousSeparableConvolutionBnReLU(tf.keras.layers.Layer):
    """Atrous (dilated) depthwise-separable convolution -> BatchNormalization -> ReLU.

    ``post_activation`` is accepted but unused; the activation is always ReLU.
    """
    def __init__(self, filters, kernel_size, strides=[1, 1, 1, 1], padding="SAME", data_format=None,
                 dilation=1, channel_multiplier=1, axis=-1, momentum=0.99, epsilon=0.001,
                 center=True, scale=True, trainable=True, post_activation=None, block_name=None):
        super(AtrousSeparableConvolutionBnReLU, self).__init__()

        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.data_format = data_format
        self.dilation = dilation
        self.channel_multiplier = channel_multiplier

        # Batch Normalization Arguments
        self.axis = axis
        self.momentum = momentum
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.trainable = trainable

        self.block_name = block_name

        self.bn = None

        self.activation = tf.keras.layers.Activation(tf.keras.activations.relu)

        self.dw_filter = None
        self.pw_filter = None

    def build(self, input_shape):
        in_channels = input_shape[-1]
        self.dw_filter = self.add_weight(
            name="dw_kernel",
            shape=[self.kernel_size, self.kernel_size, in_channels, self.channel_multiplier],
            initializer=tf.keras.initializers.GlorotNormal(),
            regularizer=tf.keras.regularizers.l2(l=1e-4),
            trainable=True
        )
        self.pw_filter = self.add_weight(
            name="pw_kernel",
            shape=[1, 1, in_channels * self.channel_multiplier, self.filters],
            initializer=tf.keras.initializers.GlorotNormal(),
            regularizer=tf.keras.regularizers.l2(l=1e-4),
            trainable=True
        )

        self.bn = tf.keras.layers.BatchNormalization(
            axis=self.axis,
            momentum=self.momentum,
            epsilon=self.epsilon,
            center=self.center,
            scale=self.scale,
            name=self.block_name + "_bn" if self.block_name is not None else None
        )

    def call(self, x, training=None):
        x = tf.nn.separable_conv2d(
            x,
            self.dw_filter,
            self.pw_filter,
            strides=self.strides,
            dilations=[self.dilation, self.dilation],
            padding=self.padding,
        )
        x = self.bn(x, training=training)
        x = self.activation(x)

        return x

    def compute_output_shape(self, input_shape):
        return [input_shape[0], input_shape[1], input_shape[2], self.filters]


class AtrousSpatialPyramidPoolingV3(tf.keras.layers.Layer):
    """DeepLabV3-style ASPP: image pooling, a 1x1 branch and three atrous separable branches."""
    def __init__(self, atrous_rates, filters):
        super(AtrousSpatialPyramidPoolingV3, self).__init__()
        self.filters = filters

        # adapt scale and momentum for bn
        self.conv_bn_relu = ConvolutionBnActivation(filters=filters, kernel_size=1)

        self.atrous_sepconv_bn_relu_1 = AtrousSeparableConvolutionBnReLU(dilation=atrous_rates[0], filters=filters,
            kernel_size=3)
        self.atrous_sepconv_bn_relu_2 = AtrousSeparableConvolutionBnReLU(dilation=atrous_rates[1], filters=filters, kernel_size=3)
        self.atrous_sepconv_bn_relu_3 = AtrousSeparableConvolutionBnReLU(dilation=atrous_rates[2], filters=filters, kernel_size=3)

        # 1x1 reduction convolutions
        self.conv_reduction_1 = tf.keras.layers.Conv2D(
            filters=256,
            kernel_size=1,
            use_bias=False,
            kernel_regularizer=tf.keras.regularizers.l2(l=1e-4))

    def call(self, input_tensor, training=None):
        # Global average pooling of input_tensor, projected and resized back to its spatial size.
        # NOTE: self.conv_bn_relu is shared by this branch and the 1x1 branch below.
        glob_avg_pool = tf.reduce_mean(input_tensor, axis=[1, 2], keepdims=True)
        glob_avg_pool = self.conv_bn_relu(glob_avg_pool, training=training)
        glob_avg_pool = tf.image.resize(glob_avg_pool, [input_tensor.shape[1], input_tensor.shape[2]])

        # process with atrous
        w = self.conv_bn_relu(input_tensor, training=training)
        x = self.atrous_sepconv_bn_relu_1(input_tensor, training=training)
        y = self.atrous_sepconv_bn_relu_2(input_tensor, training=training)
        z = self.atrous_sepconv_bn_relu_3(input_tensor, training=training)

        # concatenation
        net = tf.concat([glob_avg_pool, w, x, y, z], axis=-1)
        net = self.conv_reduction_1(net, training=training)

        return net

    def compute_output_shape(self, input_shape):
        return [input_shape[0], input_shape[1], input_shape[2], 256]


class Upsample_x2_Block(tf.keras.layers.Layer):
    """Bilinear x2 upsampling, a 2x2 conv, optional skip concatenation, then two 3x3 Conv-BN-ReLU blocks."""
    def __init__(self, filters, trainable=True, **kwargs):
        # Pass `trainable` to the base class; assigning None would make the block report no trainable weights
        super(Upsample_x2_Block, self).__init__(trainable=trainable, **kwargs)
        self.filters = filters

        self.upsample2d_size2 = tf.keras.layers.UpSampling2D(size=2, interpolation="bilinear")
        self.conv2x2_bn_relu = tf.keras.layers.Conv2D(filters, kernel_size=(2, 2),
            padding="same")

        self.concat = tf.keras.layers.Concatenate(axis=3)

        self.conv3x3_bn_relu1 = ConvolutionBnActivation(filters, kernel_size=(3, 3), post_activation="relu")
        self.conv3x3_bn_relu2 = ConvolutionBnActivation(filters, kernel_size=(3, 3), post_activation="relu")

    def call(self, x, skip=None, training=None):
        x = self.upsample2d_size2(x)
        x = self.conv2x2_bn_relu(x, training=training)

        if skip is not None:
            x = self.concat([x, skip])

        x = self.conv3x3_bn_relu1(x, training=training)
        x = self.conv3x3_bn_relu2(x, training=training)

        return x

    def compute_output_shape(self, input_shape):
        # Spatial size doubles; the channel count becomes `filters`
        return [input_shape[0], input_shape[1] * 2, input_shape[2] * 2, self.filters]

    def get_config(self):
        config = {
            "filters": self.filters,
            "trainable": self.trainable,
        }
        base_config = super(Upsample_x2_Block, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class Upsample_x2_Add_Block(tf.keras.layers.Layer):
    """Bilinear x2 upsampling followed by addition of a 1x1-projected skip connection."""
    def __init__(self, filters):
        super(Upsample_x2_Add_Block, self).__init__()

        self.upsample2d_size2 = tf.keras.layers.UpSampling2D(size=2, interpolation="bilinear")
        self.conv1x1_bn_relu = tf.keras.layers.Conv2D(filters, kernel_size=(1, 1), padding="same")
        self.add = tf.keras.layers.Add()

    def call(self, x, skip, training=None):
        x = self.upsample2d_size2(x)
        # Project the skip connection (not the upsampled tensor) before adding
        skip = self.conv1x1_bn_relu(skip, training=training)
        x = self.add([x, skip])

        return x

    def compute_output_shape(self, input_shape):
        return [input_shape[0], input_shape[1] * 2, input_shape[2] * 2, input_shape[3]]


class SpatialContextBlock(tf.keras.layers.Layer):
    def __init__(self, level, filters=256, pooling_type="avg"):
        super(SpatialContextBlock,
            self).__init__()

        self.level = level
        self.filters = filters
        self.pooling_type = pooling_type

        self.pooling2d = None
        self.conv1x1_bn_relu = ConvolutionBnActivation(filters, kernel_size=(1, 1))
        self.upsample2d = None

    def build(self, input_shape):
        if self.pooling_type not in ("max", "avg"):
            raise ValueError("Unsupported pooling type - '{}'. Use 'avg' or 'max'".format(self.pooling_type))

        self.pooling2d = tf.keras.layers.MaxPool2D if self.pooling_type == "max" else tf.keras.layers.AveragePooling2D

        spatial_size = input_shape[1:3] if K.image_data_format() == "channels_last" else input_shape[2:]

        pool_size = up_size = [spatial_size[0] // self.level, spatial_size[1] // self.level]
        self.pooling2d = self.pooling2d(pool_size, strides=pool_size, padding="same")

        self.upsample2d = tf.keras.layers.UpSampling2D(up_size, interpolation="bilinear")

    def call(self, x, training=None):
        x = self.pooling2d(x, training=training)
        x = self.conv1x1_bn_relu(x, training=training)
        x = self.upsample2d(x)

        return x

    def compute_output_shape(self, input_shape):
        return input_shape


class FPNBlock(tf.keras.layers.Layer):
    def __init__(self, filters):
        super(FPNBlock, self).__init__()

        self.filters = filters
        self.input_filters = None

        self.conv1x1_1 = tf.keras.layers.Conv2D(filters, (1, 1), padding="same", kernel_initializer="he_uniform")
        self.conv1x1_2 = tf.keras.layers.Conv2D(filters, (1, 1), padding="same", kernel_initializer="he_uniform")

        self.upsample2d = tf.keras.layers.UpSampling2D((2, 2))
        self.add = tf.keras.layers.Add()

    def build(self, input_shape):
        # Project x with conv1x1_1 only when its channel count differs from `filters`
        channel_axis = -1 if K.image_data_format() == "channels_last" else 1
        self.input_filters = input_shape[channel_axis] != self.filters

    def call(self, x, skip, training=None):
        if self.input_filters:
            x = self.conv1x1_1(x)

        skip = self.conv1x1_2(skip)
        x = self.upsample2d(x)
        x = self.add([x, skip])

        return x


class AtrousSpatialPyramidPoolingV1(tf.keras.layers.Layer):
    """ASPP with a 1x1 branch and four 3x3 atrous branches (rates 6, 12, 18, 24), concatenated along channels."""
    def __init__(self, filters):
        super(AtrousSpatialPyramidPoolingV1, self).__init__()

        self.filters = filters

        self.conv1x1_bn_relu = ConvolutionBnActivation(filters, (1, 1), post_activation="relu")
        self.atrous6_conv3x3_bn_relu = ConvolutionBnActivation(filters, (3, 3), dilation_rate=6, post_activation="relu")
        self.atrous12_conv3x3_bn_relu = ConvolutionBnActivation(filters, (3, 3), dilation_rate=12, post_activation="relu")
        self.atrous18_conv3x3_bn_relu = ConvolutionBnActivation(filters, (3, 3), dilation_rate=18, post_activation="relu")
        self.atrous24_conv3x3_bn_relu = ConvolutionBnActivation(filters, (3, 3), dilation_rate=24, post_activation="relu")

        axis = 3 if K.image_data_format() == "channels_last" else 1

        self.concat = tf.keras.layers.Concatenate(axis=axis)

    def call(self, x, training=None):
        x1 = self.conv1x1_bn_relu(x, training=training)
        x3_r6 = self.atrous6_conv3x3_bn_relu(x, training=training)
        x3_r12 = self.atrous12_conv3x3_bn_relu(x, training=training)
        x3_r18 = self.atrous18_conv3x3_bn_relu(x, training=training)
        x3_r24 = self.atrous24_conv3x3_bn_relu(x, training=training)

        x = self.concat([x1, x3_r6, x3_r12, x3_r18, x3_r24])

        return x


class Base_OC_Module(tf.keras.layers.Layer):
    def __init__(self, filters):
        super(Base_OC_Module, self).__init__()

        self.filters = filters

        axis = 3 if K.image_data_format() == "channels_last" else 1

        self.self_attention_block2d = SelfAttentionBlock2D(filters)
        self.concat = tf.keras.layers.Concatenate(axis=axis)
        self.conv1x1_bn_relu = ConvolutionBnActivation(filters, (1, 1))

    def call(self, x, training=None):
        attention = self.self_attention_block2d(x, training=training)
        x = self.concat([attention, x])
        x = self.conv1x1_bn_relu(x, training=training)

        return x


class Pyramid_OC_Module(tf.keras.layers.Layer):
    def __init__(self, levels, filters=256, pooling_type="avg"):
        super(Pyramid_OC_Module, self).__init__()

        self.levels = levels
        self.filters = filters
        self.pooling_type = pooling_type

        self.pyramid_block_1 = SelfAttentionBlock2D(filters)
        self.pyramid_block_2 = SelfAttentionBlock2D(filters)
        self.pyramid_block_3 = SelfAttentionBlock2D(filters)
        self.pyramid_block_6 = SelfAttentionBlock2D(filters)

        self.pooling2d_1 = None
        self.pooling2d_2 = None
        self.pooling2d_3 = None
        self.pooling2d_6 = None

        self.conv1x1_bn_relu_1 = ConvolutionBnActivation(filters, kernel_size=(1, 1))

        self.upsample2d_1 = None
        self.upsample2d_2 = None
        self.upsample2d_3 = None
        self.upsample2d_6 = None

        axis = 3 if K.image_data_format() == "channels_last" else 1

        self.concat = tf.keras.layers.Concatenate(axis=axis)
        self.conv1x1_bn_relu_2 = ConvolutionBnActivation(filters, kernel_size=(1, 1))

    def build(self, input_shape):
        if self.pooling_type not in ("max", "avg"):
            raise ValueError("Unsupported pooling type - '{}'. Use 'avg' or 'max'".format(self.pooling_type))

        pooling_layer = tf.keras.layers.MaxPool2D if self.pooling_type == "max" else tf.keras.layers.AveragePooling2D

        spatial_size = input_shape[1:3] if K.image_data_format() == "channels_last" else input_shape[2:]
        pool_size_1 = up_size_1 = [spatial_size[0] // self.levels[0], spatial_size[1] // self.levels[0]]
        pool_size_2 = up_size_2 = [spatial_size[0] // self.levels[1], spatial_size[1] // self.levels[1]]
        pool_size_3 = up_size_3 = [spatial_size[0] // self.levels[2], spatial_size[1] // self.levels[2]]
        pool_size_6 = up_size_6 = [spatial_size[0] // self.levels[3], spatial_size[1] // self.levels[3]]

        self.pooling2d_1 = pooling_layer(pool_size_1, strides=pool_size_1, padding="same")
        self.pooling2d_2 = pooling_layer(pool_size_2, strides=pool_size_2, padding="same")
        self.pooling2d_3 = pooling_layer(pool_size_3, strides=pool_size_3, padding="same")
        self.pooling2d_6 = pooling_layer(pool_size_6, strides=pool_size_6, padding="same")

        self.upsample2d_1 = tf.keras.layers.UpSampling2D(up_size_1, interpolation="bilinear")
        self.upsample2d_2 = tf.keras.layers.UpSampling2D(up_size_2, interpolation="bilinear")
        self.upsample2d_3 = tf.keras.layers.UpSampling2D(up_size_3, interpolation="bilinear")
        self.upsample2d_6 = tf.keras.layers.UpSampling2D(up_size_6, interpolation="bilinear")

    def call(self, x, training=None):
        attention_1 = self.pooling2d_1(x, training=training)
        attention_1 = self.pyramid_block_1(attention_1, training=training)
        attention_1 = self.upsample2d_1(attention_1)
        attention_2 = self.pooling2d_2(x, training=training)
        attention_2 = self.pyramid_block_2(attention_2, training=training)
        attention_2 = self.upsample2d_2(attention_2)
        attention_3 = self.pooling2d_3(x, training=training)
        attention_3 = self.pyramid_block_3(attention_3, training=training)
        attention_3 = self.upsample2d_3(attention_3)
        attention_6 = self.pooling2d_6(x, training=training)
        attention_6 = self.pyramid_block_6(attention_6, training=training)
        attention_6 = self.upsample2d_6(attention_6)

        x = self.conv1x1_bn_relu_1(x, training=training)

        x = self.concat([attention_1, attention_2, attention_3, attention_6, x])
        x = self.conv1x1_bn_relu_2(x, training=training)

        return x


class ASP_OC_Module(tf.keras.layers.Layer):
    def __init__(self, filters, dilations):
        super(ASP_OC_Module, self).__init__()
        self.filters = filters
        self.dilations = dilations  # NOTE: stored but unused; the atrous rates below are fixed at 6, 12 and 18

        self.conv3x3_bn_relu = ConvolutionBnActivation(filters, (3, 3))
        self.context = Base_OC_Module(filters)

        self.conv1x1_bn_relu_1 = ConvolutionBnActivation(filters, (1, 1), post_activation="relu")
        self.atrous6_conv3x3_bn_relu = ConvolutionBnActivation(filters, (3, 3), dilation_rate=6, post_activation="relu")
        self.atrous12_conv3x3_bn_relu = ConvolutionBnActivation(filters, (3, 3), dilation_rate=12, post_activation="relu")
        self.atrous18_conv3x3_bn_relu = ConvolutionBnActivation(filters, (3, 3), dilation_rate=18, post_activation="relu")

        axis = 3 if K.image_data_format() == "channels_last" else 1

        self.concat = tf.keras.layers.Concatenate(axis=axis)
        self.conv1x1_bn_relu_2 = ConvolutionBnActivation(filters, (1, 1))

    def call(self, x, training=None):
        a = self.conv3x3_bn_relu(x, training=training)
        a = self.context(a, training=training)
        b = self.conv1x1_bn_relu_1(x, training=training)
        c = self.atrous6_conv3x3_bn_relu(x, training=training)
        d = self.atrous12_conv3x3_bn_relu(x, training=training)
        e = self.atrous18_conv3x3_bn_relu(x, training=training)

        x = self.concat([a, b, c, d, e])
        x = self.conv1x1_bn_relu_2(x, training=training)

        return x


class PAM_Module(tf.keras.layers.Layer):
    """Position attention module (DANet)."""
    def __init__(self, filters):
        super(PAM_Module, self).__init__()

        self.filters = filters

        axis = 3 if K.image_data_format() == "channels_last" else 1
        self.concat = tf.keras.layers.Concatenate(axis=axis)

        self.conv1x1_bn_relu_1 = ConvolutionBnActivation(filters, (1, 1))
        self.conv1x1_bn_relu_2 = ConvolutionBnActivation(filters, (1, 1))
        self.conv1x1_bn_relu_3 = ConvolutionBnActivation(filters, (1, 1))

        self.gamma = None

        self.softmax = tf.keras.layers.Activation("softmax")

    def build(self, input_shape):
        self.gamma = self.add_weight(
            shape=(1,),
            initializer="random_normal",
            name="pam_gamma",
            trainable=True,
        )

    def call(self, x, training=None):
        # Unpack the spatial size according to the data format
        if K.image_data_format() == "channels_last":
            _, H, W, _ = x.shape
        else:
            _, _, H, W = x.shape

        query = self.conv1x1_bn_relu_1(x, training=training)
        key = self.conv1x1_bn_relu_2(x, training=training)

        if K.image_data_format() == "channels_last":
            query = tf.keras.layers.Reshape((H * W, -1))(query)         # (BS, N, C')
            key = tf.keras.layers.Reshape((H * W, -1))(key)             # (BS, N, C')

            energy = tf.linalg.matmul(query, key, transpose_b=True)     # (BS, N, N)
        else:
            query = tf.keras.layers.Reshape((-1, H * W))(query)         # (BS, C', N)
            key = tf.keras.layers.Reshape((-1, H * W))(key)             # (BS, C', N)

            energy = tf.linalg.matmul(query, key, transpose_a=True)     # (BS, N, N)

        attention = self.softmax(energy)

        value = self.conv1x1_bn_relu_3(x, training=training)

        if K.image_data_format() == "channels_last":
            value = tf.keras.layers.Reshape((H * W, -1))(value)         # (BS, N, C)
            out = tf.linalg.matmul(attention, value)                     # (BS, N, C)
        else:
            value = tf.keras.layers.Reshape((-1, H * W))(value)         # (BS, C, N)
            out = tf.linalg.matmul(value, attention, transpose_b=True)   # (BS, C, N)

        out = tf.keras.layers.Reshape(x.shape[1:])(out)
        out = self.gamma * out + x

        return out


class CAM_Module(tf.keras.layers.Layer):
    """Channel attention module (DANet)."""
    def __init__(self, filters):
        super(CAM_Module, self).__init__()

        self.filters = filters

        self.gamma = None

        self.softmax = tf.keras.layers.Activation("softmax")

    def build(self, input_shape):
        self.gamma = self.add_weight(
            shape=(1,),
            initializer="random_normal",
            name="cam_gamma",
            trainable=True,
        )

    def call(self, x, training=None):
        # tf.math.reduce_max returns the max values directly (no [0] index as with torch.max)
        if K.image_data_format() == "channels_last":
            C = x.shape[-1]

            query = tf.keras.layers.Reshape((-1, C))(x)                  # (BS, N, C)
            key = tf.keras.layers.Reshape((-1, C))(x)                    # (BS, N, C)

            energy = tf.linalg.matmul(query, key, transpose_a=True)      # (BS, C, C)
            energy_2 = tf.math.reduce_max(energy, axis=-1, keepdims=True) - energy
        else:
            C = x.shape[1]

            query = tf.keras.layers.Reshape((C, -1))(x)                  # (BS, C, N)
            key = tf.keras.layers.Reshape((C, -1))(x)                    # (BS, C, N)

            energy = tf.linalg.matmul(query, key, transpose_b=True)      # (BS, C, C)
            energy_2 = tf.math.reduce_max(energy, axis=-1, keepdims=True) - energy

        attention = self.softmax(energy_2)

        if K.image_data_format() == "channels_last":
            value = tf.keras.layers.Reshape((-1, C))(x)                  # (BS, N, C)
            out = tf.linalg.matmul(value, attention, transpose_b=True)    # (BS, N, C)
        else:
            value = tf.keras.layers.Reshape((C, -1))(x)                  # (BS, C, N)
            out = tf.linalg.matmul(attention, value)                      # (BS, C, N)

        out = tf.keras.layers.Reshape(x.shape[1:])(out)
        out = self.gamma * out + x

        return out


class SelfAttentionBlock2D(tf.keras.layers.Layer):
    def __init__(self, filters):
        super(SelfAttentionBlock2D, self).__init__()

        self.filters = filters

        self.conv1x1_1 = tf.keras.layers.Conv2D(filters // 8, (1, 1), padding="same")
        self.conv1x1_2 = tf.keras.layers.Conv2D(filters // 8, (1, 1), padding="same")
        self.conv1x1_3 = tf.keras.layers.Conv2D(filters // 2, (1, 1), padding="same")
        self.conv1x1_4 = tf.keras.layers.Conv2D(filters, (1, 1), padding="same")

        self.gamma = None

        self.softmax_activation = tf.keras.layers.Activation("softmax")

    def build(self, input_shape):
        self.gamma = self.add_weight(
            shape=(1,),
            initializer="random_normal",
            name="gamma",
            trainable=True,
        )

    def call(self, x, training=None):
        f = self.conv1x1_1(x, training=training)
        g = self.conv1x1_2(x, training=training)
        h = self.conv1x1_3(x, training=training)

        # Flatten the spatial dims; tf.shape keeps the (possibly unknown) batch size dynamic
        batch_size = tf.shape(x)[0]
        g = tf.reshape(g, (batch_size, -1, g.shape[-1]))
        f = tf.reshape(f, (batch_size, -1, f.shape[-1]))

        s = tf.matmul(g, f, transpose_b=True)
        beta = self.softmax_activation(s)

        h = tf.reshape(h, (batch_size, -1, h.shape[-1]))

        o = tf.matmul(beta, h)

        # h (and therefore o) has filters // 2 channels
        o = tf.reshape(o, shape=(batch_size, x.shape[1], x.shape[2], self.filters // 2))
        o = self.conv1x1_4(o, training=training)
        x = self.gamma * o + x

        return x


class GlobalPooling(tf.keras.layers.Layer):
    """Global average pooling, a 1x1 Conv-BN-ReLU, then broadcast back to the input's spatial size."""
    def __init__(self, filters):
        super(GlobalPooling, self).__init__()
        self.filters = filters

        self.conv1x1_bn_relu = ConvolutionBnActivation(filters, (1, 1))

    def call(self, x, training=None):
        if K.image_data_format() == "channels_last":
            _, H, W, _ = x.shape
            glob_avg_pool = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
            glob_avg_pool = self.conv1x1_bn_relu(glob_avg_pool, training=training)
            glob_avg_pool = tf.image.resize(glob_avg_pool, [H, W])
        else:
            _, _, H, W = x.shape
            # Spatial axes are 2 and 3; tf.image.resize expects channels_last, so tile the 1x1 map instead
            glob_avg_pool = tf.reduce_mean(x, axis=[2, 3], keepdims=True)
            glob_avg_pool = self.conv1x1_bn_relu(glob_avg_pool, training=training)
            glob_avg_pool = tf.tile(glob_avg_pool, [1, 1, H, W])

        return glob_avg_pool


class MixtureOfSoftMaxACF(tf.keras.layers.Layer):
    def __init__(self, d_k, att_dropout=0.1):
        super(MixtureOfSoftMaxACF, self).__init__()
        self.temperature = tf.math.pow(tf.cast(d_k, tf.float32), 0.5)
        self.att_dropout = att_dropout

        self.dropout = tf.keras.layers.Dropout(att_dropout)
        self.softmax_1 = tf.keras.layers.Activation("softmax")
        self.softmax_2 = tf.keras.layers.Activation("softmax")

        self.d_k = d_k

    def call(self, qt, kt, vt, training=None):
        if K.image_data_format() == "channels_last":
            BS, N, d_k = qt.shape  # (BS, H * W, C)

            assert d_k == self.d_k
            d = d_k

            q = tf.keras.layers.Reshape((N, d))(qt)          # (BS, N, d)
            N2 = kt.shape[1]
            kt = tf.keras.layers.Reshape((N2, d))(kt)        # (BS, N2, d)

            att = tf.linalg.matmul(q, kt, transpose_b=True)  # (BS, N, N2)
            att = att / self.temperature                     # (BS, N, N2)
            att = self.softmax_2(att)                        # (BS, N, N2)
            att = self.dropout(att, training=training)       # (BS, N, N2)

            out = tf.linalg.matmul(att, vt)                  # (BS, N, d)

        else:
            BS, d_k, N = qt.shape

            assert d_k == self.d_k
            d = d_k

            q = tf.keras.layers.Reshape((d, N))(qt)          # (BS, d, N)
            N2 = kt.shape[2]
            kt = tf.keras.layers.Reshape((d, N2))(kt)        # (BS, d, N2)

            att = tf.linalg.matmul(q, kt, transpose_a=True)  # (BS, N, N2)
            att = att / self.temperature                     # (BS, N, N2)
            att = self.softmax_2(att)                        # (BS, N, N2)
            att = self.dropout(att, training=training)       # (BS, N, N2)

            out = tf.linalg.matmul(att, vt, transpose_b=True)  # (BS, N, d)

        return out


class AggCF_Module(tf.keras.layers.Layer):
    def __init__(self, filters, kq_transform="conv", value_transform="conv",
                 pooling=True, concat=False,
dropout=0.1): 773 | super(AggCF_Module, self).__init__() 774 | self.filters = filters 775 | self.kq_transform = kq_transform 776 | self.value_transform = value_transform 777 | self.pooling = pooling 778 | self.concat = concat # if True concat else Add 779 | self.dropout = dropout 780 | 781 | self.avg_pool2d_1 = tf.keras.layers.AveragePooling2D(pool_size=(2, 2), padding="same", data_format=K.image_data_format()) 782 | self.avg_pool2d_2 = tf.keras.layers.AveragePooling2D(pool_size=(2, 2), padding="same", data_format=K.image_data_format()) 783 | 784 | self.conv_ks_1 = None 785 | self.conv_ks_2 = None 786 | self.conv_vs = None 787 | 788 | self.attention = MixtureOfSoftMaxACF(d_k=filters, att_dropout=0.1) 789 | 790 | self.conv1x1_bn_relu = tf.keras.layers.Conv2D(filters, (1, 1), padding="same") 791 | 792 | axis = 3 if K.image_data_format() == "channels_last" else 1 793 | self.bn = tf.keras.layers.BatchNormalization(axis=axis) 794 | self.concat = tf.keras.layers.Concatenate(axis=axis) 795 | self.add = tf.keras.layers.Add() 796 | 797 | def build(self, input_shape): 798 | if self.kq_transform == "conv": 799 | self.conv_ks_1 = tf.keras.layers.Conv2D(self.filters, (1, 1), padding="same") 800 | self.conv_ks_2 = tf.keras.layers.Conv2D(self.filters, (1, 1), padding="same") 801 | elif self.kq_transform == "ffn": 802 | self.conv_ks_1 = tf.keras.Sequential( 803 | [ConvolutionBnActivation(self.filters, (3, 3)), 804 | tf.keras.layers.Conv2D(self.filters, (1, 1), padding="same")] 805 | ) 806 | self.conv_ks_2 = tf.keras.Sequential( 807 | [ConvolutionBnActivation(self.filters, (3, 3)), 808 | tf.keras.layers.Conv2D(self.filters, (1, 1), padding="same")] 809 | ) 810 | elif self.kq_transform == "dffn": 811 | self.conv_ks_1 = tf.keras.Sequential( 812 | [ConvolutionBnActivation(self.filters, (3, 3), dilation_rate=(4, 4)), 813 | tf.keras.layers.Conv2D(self.filters, (1, 1), padding="same")] 814 | ) 815 | self.conv_ks_2 = tf.keras.Sequential( 816 | [ConvolutionBnActivation(self.filters, (3, 
3), dilation_rate=(4, 4)), 817 | tf.keras.layers.Conv2D(self.filters, (1, 1), padding="same")] 818 | ) 819 | else: 820 | raise NotImplementedError("Allowed options for 'kq_transform' are only ('conv', 'ffn', 'dffn'), got {}".format(self.kq_transform)) 821 | 822 | if self.value_transform == "conv": 823 | self.conv_vs = tf.keras.layers.Conv2D(self.filters, (1, 1), padding="same") 824 | else: 825 | raise NotImplementedError("Allowed options for 'value_transform' is only 'conv', got {}".format(self.kq_transform)) 826 | 827 | def call(self, x, training=None): 828 | residual = x 829 | d_k = self.filters / 8 830 | if K.image_data_format() == "channels_last": 831 | BS, H, W, C = x.shape 832 | 833 | if self.pooling: 834 | qt = self.conv_ks_1(x, training=training) 835 | qt = tf.keras.layers.Reshape((H * W, -1))(qt) # (BS, N, C) 836 | kt = self.avg_pool2d_1(x) 837 | kt = self.conv_ks_2(kt, training=training) 838 | kt = tf.keras.layers.Reshape((H * W // 4, -1))(kt) # (BS, N / 4, C) 839 | vt = self.avg_pool2d_2(x) 840 | vt = self.conv_vs(vt, training=training) 841 | vt = tf.keras.layers.Reshape((H * W // 4, -1))(vt) # (BS, N / 4, C) 842 | else: 843 | qt = self.conv_ks_1(x, training=training) 844 | qt = tf.keras.layers.Reshape((H * W, -1))(qt) # (BS, N, C) 845 | kt = self.conv_ks_2(x, training=training) 846 | kt = tf.keras.layers.Reshape((H * W, -1))(kt) # (BS, N, C) 847 | vt = self.conv_vs(x, training=training) 848 | vt = tf.keras.layers.Reshape((H * W, -1))(vt) # (BS, N, C) 849 | 850 | out = self.attention(qt, kt, vt, training=training) # (BS, N, C) 851 | 852 | # out = tf.transpose(out, perm=[0, 2, 1]) 853 | out = tf.keras.layers.Reshape((H, W, -1))(out) # (BS, H, W, C) 854 | 855 | else: 856 | BS, C, H, W = x.shape 857 | 858 | if self.pooling: 859 | qt = self.conv_ks_1(x, training=training) 860 | qt = tf.keras.layers.Reshape((-1, H * W))(qt) # (BS, C, N) 861 | kt = self.avg_pool2d_1(x) 862 | kt = self.conv_ks_2(kt, training=training) 863 | kt = tf.keras.layers.Reshape((-1, H * 
W // 4))(kt) # (BS, C, N / 4) 864 | vt = self.avg_pool2d_2(x) 865 | vt = self.conv_vs(vt, training=training) 866 | vt = tf.keras.layers.Reshape((-1, H * W // 4))(vt) # (BS, C, N / 4) 867 | else: 868 | qt = self.conv_ks_1(x, training=training) 869 | qt = tf.keras.layers.Reshape((-1, H * W))(qt) # (BS, C, N) 870 | kt = self.conv_ks_2(x, training=training) 871 | kt = tf.keras.layers.Reshape((-1, H * W))(kt) # (BS, C, N) 872 | vt = self.conv_vs(x, training=training) 873 | vt = tf.keras.layers.Reshape((-1, H * W))(vt) # (BS, C, N) 874 | 875 | out = self.attention(qt, kt, vt) # (BS, N, C) 876 | 877 | out = tf.transpose(out, perm=[0, 2, 1]) # (BS, C, N) 878 | out = tf.keras.layers.Reshape((-1, H, W))(out) # (BS, C, H, W) 879 | 880 | out = self.conv1x1_bn_relu(out, training=training) 881 | if self.concat: 882 | out = self.concat([out, residual]) 883 | else: 884 | out = self.add([out, residual]) 885 | 886 | return out 887 | 888 | class SpatialGather_Module(tf.keras.layers.Layer): 889 | def __init__(self, scale=1): 890 | super(SpatialGather_Module, self).__init__() 891 | 892 | self.scale = scale 893 | 894 | self.softmax = tf.keras.layers.Activation("softmax") 895 | 896 | def call(self, features, probabilities, training=None): 897 | if K.image_data_format() == "channels_last": 898 | BS, H, W, C = probabilities.shape 899 | p = tf.keras.layers.Reshape((-1, C))(probabilities) # (BS, N, C) 900 | f = tf.keras.layers.Reshape((-1, features.shape[-1]))(features) # (BS, N, C2) 901 | 902 | p = self.softmax(self.scale * p) # (BS, N, C) 903 | ocr_context = tf.linalg.matmul(p, f, transpose_a=True) # (BS, C, C2) 904 | 905 | else: 906 | BS, C, H, W = probabilities.shape 907 | p = tf.keras.layers.Reshape((C, -1))(probabilities) # (BS, C, N) 908 | f = tf.keras.layers.Reshape((features.shape[1], -1))(features) # (BS, C2, N) 909 | 910 | p = self.softmax(self.scale * p) # (BS, C, N) 911 | ocr_context = tf.linalg.matmul(p, f, transpose_b=True) # (BS, C, C2) 912 | 913 | return ocr_context 914 | 
class ObjectAttentionBlock2D(tf.keras.layers.Layer):
    def __init__(self, filters, scale=1.0):
        super(ObjectAttentionBlock2D, self).__init__()
        self.filters = filters
        # Pooling and upsampling factors must be integers (UpSampling2D rejects a float such as 1.0)
        self.scale = int(scale)

        self.max_pool2d = tf.keras.layers.MaxPooling2D(pool_size=(self.scale, self.scale))
        self.f_pixel = tf.keras.models.Sequential([
            ConvolutionBnActivation(filters, (1, 1)),
            ConvolutionBnActivation(filters, (1, 1))
        ])
        # self.f_object = tf.keras.models.Sequential([
        #     ConvolutionBnActivation(filters, (1, 1)),
        #     ConvolutionBnActivation(filters, (1, 1))
        # ])
        # self.f_down = ConvolutionBnActivation(filters, (1, 1))
        self.f_up = ConvolutionBnActivation(filters, (1, 1))

        self.softmax = tf.keras.layers.Activation("softmax")
        self.upsampling2d = tf.keras.layers.UpSampling2D(size=self.scale, interpolation="bilinear")

    def call(self, feats, ctx, training=None):
        if K.image_data_format() == "channels_last":
            # feats-dim: (BS, H, W, C) & ctx-dim: (BS, C2, C), where C2 is the number of classes
            # (the output of SpatialGather_Module). ctx is used as-is: a Permute followed by a
            # Reshape would reorder its elements instead of transposing them.
            if self.scale > 1:
                feats = self.max_pool2d(feats)
            BS, H, W, C = feats.shape  # spatial size after the optional pooling

            query = self.f_pixel(feats, training=training)  # (BS, H, W, C)
            query = tf.keras.layers.Reshape((-1, C))(query)  # (BS, N, C)
            # key = self.f_object(ctx, training=training)  # (BS, C2, C)
            key = ctx  # (BS, C2, C)
            # value = self.f_down(ctx, training=training)  # (BS, C2, C)
            value = ctx  # (BS, C2, C)

            sim_map = tf.linalg.matmul(query, key, transpose_b=True)  # (BS, N, C2)
            sim_map = (self.filters ** -0.5) * sim_map  # (BS, N, C2)
            sim_map = self.softmax(sim_map)  # (BS, N, C2)

            context = tf.linalg.matmul(sim_map, value)  # (BS, N, C)
            context = tf.keras.layers.Reshape((H, W, C))(context)  # (BS, H, W, C)
            context = self.f_up(context, training=training)  # (BS, H, W, C)
            if self.scale > 1:
                context = self.upsampling2d(context)

        else:
            # feats-dim: (BS, C, H, W) & ctx-dim: (BS, C2, C), where C2 is the number of classes
            if self.scale > 1:
                feats = self.max_pool2d(feats)
            BS, C, H, W = feats.shape  # spatial size after the optional pooling

            query = self.f_pixel(feats, training=training)  # (BS, C, H, W)
            query = tf.keras.layers.Reshape((C, -1))(query)  # (BS, C, N)
            # key = self.f_object(ctx, training=training)  # (BS, C2, C)
            key = ctx  # (BS, C2, C)
            # value = self.f_down(ctx, training=training)  # (BS, C2, C)
            value = ctx  # (BS, C2, C)

            sim_map = tf.linalg.matmul(query, key, transpose_a=True, transpose_b=True)  # (BS, N, C2)
            sim_map = (self.filters ** -0.5) * sim_map  # (BS, N, C2)
            sim_map = self.softmax(sim_map)  # (BS, N, C2)

            context = tf.linalg.matmul(sim_map, value)  # (BS, N, C)
            context = tf.keras.layers.Permute((2, 1))(context)  # (BS, C, N)
            context = tf.keras.layers.Reshape((C, H, W))(context)  # (BS, C, H, W)
            context = self.f_up(context, training=training)  # (BS, C, H, W)
            if self.scale > 1:
                context = self.upsampling2d(context)

        return context


class SpatialOCR_Module(tf.keras.layers.Layer):
    def __init__(self, filters, scale=1.0, dropout=0.1):
        super(SpatialOCR_Module, self).__init__()
        self.filters = filters
        self.scale = scale
        self.dropout = dropout

        self.object_attention = ObjectAttentionBlock2D(filters, scale)

        axis = 3 if K.image_data_format() == "channels_last" else 1
        self.concat = tf.keras.layers.Concatenate(axis=axis)
        self.conv1x1_bn_relu = ConvolutionBnActivation(filters, (1, 1))
        self.dropout = tf.keras.layers.Dropout(dropout)

    def call(self, features, ocr_context, training=None):
        # features-dim: (BS, H, W, C) & ocr_context-dim: (BS, C2, C) with C2 = number of classes (if K.image_data_format() == "channels_last")
        context = self.object_attention(features, ocr_context, training=training)  # (BS, H, W, C)

        output = self.concat([context, features])  # (BS, H, W, 2*C)
        output = self.conv1x1_bn_relu(output, training=training)  # (BS, H, W, C)
        output = self.dropout(output, training=training)  # (BS, H, W, C)

        return output


class SpatialOCR_ASP_Module(tf.keras.layers.Layer):
    def __init__(self, filters, scale=1, dropout=0.1, dilations=(12, 24, 36)):
        super(SpatialOCR_ASP_Module, self).__init__()
        self.filters = filters
        self.scale = scale
        self.dropout = dropout
        self.dilations = dilations

        self.conv3x3_bn_relu_1 = ConvolutionBnActivation(filters, (3, 3))
        self.context = ObjectAttentionBlock2D(filters, scale)
        self.conv1x1_bn_relu_1 = ConvolutionBnActivation(filters, (1, 1))
        self.atrous_sepconv_bn_relu_1 = AtrousSeparableConvolutionBnReLU(dilation=dilations[0], filters=filters, kernel_size=3)
        self.atrous_sepconv_bn_relu_2 = AtrousSeparableConvolutionBnReLU(dilation=dilations[1], filters=filters, kernel_size=3)
        self.atrous_sepconv_bn_relu_3 = AtrousSeparableConvolutionBnReLU(dilation=dilations[2], filters=filters, kernel_size=3)

        self.spatial_context = SpatialGather_Module(scale=scale)

        self.axis = 3 if K.image_data_format() == "channels_last" else 1
        self.concat = tf.keras.layers.Concatenate(axis=self.axis)

        self.conv1x1_bn_relu_2 = ConvolutionBnActivation(filters, (1, 1))
        self.dropout = tf.keras.layers.Dropout(dropout)

    def call(self, x, probabilities, training=None):
        feat1 = self.conv3x3_bn_relu_1(x, training=training)
        context = self.spatial_context(feat1, probabilities, training=training)
        feat1 = self.context(feat1, context, training=training)
        feat2 = self.conv1x1_bn_relu_1(x, training=training)
        feat3 = self.atrous_sepconv_bn_relu_1(x, training=training)
        feat4 = self.atrous_sepconv_bn_relu_2(x, training=training)
        feat5 = self.atrous_sepconv_bn_relu_3(x, training=training)

        output = self.concat([feat1, feat2, feat3, feat4, feat5])
        output = self.conv1x1_bn_relu_2(output, training=training)
        output = self.dropout(output, training=training)

        return output


class AttCF_Module(tf.keras.layers.Layer):
    def __init__(self, filters):
        super(AttCF_Module, self).__init__()

        self.filters = filters

        self.softmax = tf.keras.layers.Activation("softmax")
        self.conv1x1 = tf.keras.layers.Conv2D(filters, (1, 1), padding="same")

    def call(self, aspp, coarse_x, training=None):
        if K.image_data_format() == "channels_last":
            BS, H, W, C = aspp.shape
            _, h, w, c = coarse_x.shape

            # CCB
            q = tf.keras.layers.Reshape((-1, c))(coarse_x)  # (BS, N, c)
            k = tf.keras.layers.Reshape((-1, C))(aspp)  # (BS, N, C)

            e = tf.linalg.matmul(q, k, transpose_a=True)  # (BS, c, C)
            # tf.math.reduce_max returns a tensor, not a (values, indices) tuple like torch.max, so no [0]
            e = tf.math.reduce_max(e, -1, keepdims=True) - e  # (BS, c, C)
            att = self.softmax(e)  # (BS, c, C)

            # CAB
            v = tf.keras.layers.Reshape((-1, c))(coarse_x)  # (BS, N, c)
            output = tf.linalg.matmul(att, v, transpose_a=True, transpose_b=True)  # (BS, C, N)
            output = tf.keras.layers.Permute((2, 1))(output)  # (BS, N, C)
            output = tf.keras.layers.Reshape((H, W, C))(output)  # (BS, H, W, C)

            output = self.conv1x1(output)  # (BS, H, W, C)

        else:
            BS, C, H, W = aspp.shape
            bs, c, h, w = coarse_x.shape

            # CCB
            q = tf.keras.layers.Reshape((c, -1))(coarse_x)  # (BS, c, N)
            k = tf.keras.layers.Reshape((C, -1))(aspp)  # (BS, C, N)

            e = tf.linalg.matmul(q, k, transpose_b=True)  # (BS, c, C)
            e = tf.math.reduce_max(e, -1, keepdims=True) - e  # (BS, c, C)
            att = self.softmax(e)  # (BS, c, C)

            # CAB
            v = tf.keras.layers.Reshape((c, -1))(coarse_x)  # (BS, c, N)
            output = tf.linalg.matmul(att, v, transpose_a=True)  # (BS, C, N)
            output = tf.keras.layers.Reshape((C, H, W))(output)  # (BS, C, H, W)

            output = self.conv1x1(output)  # (BS, C, H, W)

        return output


class BasicBlock(tf.keras.layers.Layer):
    def __init__(self, filters):
        super(BasicBlock, self).__init__()

        self.conv3x3_bn_relu = ConvolutionBnActivation(filters, (3, 3), momentum=0.1)
        self.conv3x3_bn = ConvolutionBnActivation(filters, (3, 3), momentum=0.1, post_activation="linear")

        self.relu = tf.keras.layers.Activation("relu")

    def call(self, input, training=None):
        residual = input

        out = self.conv3x3_bn_relu(input, training=training)
        out = self.conv3x3_bn(out, training=training)

        out = out + residual
        out = self.relu(out)

        return out


class BottleneckBlock(tf.keras.layers.Layer):
    def __init__(self, filters, downsample=False, expansion=4):
        super(BottleneckBlock, self).__init__()

        self.ds = downsample

        self.conv3x3_bn_relu_1 = ConvolutionBnActivation(filters, (1, 1), momentum=0.1)
        self.conv3x3_bn_relu_2 = ConvolutionBnActivation(filters, (3, 3), momentum=0.1)
        self.conv3x3_bn = ConvolutionBnActivation(filters * expansion, (1, 1), momentum=0.1, post_activation="linear")

        if downsample:
            self.downsample = ConvolutionBnActivation(filters * expansion, (1, 1), momentum=0.1)

        self.relu = tf.keras.layers.Activation("relu")

    def call(self, input, training=None):
        residual = input

        out = self.conv3x3_bn_relu_1(input, training=training)
        out = self.conv3x3_bn_relu_2(out, training=training)
        out = self.conv3x3_bn(out, training=training)

        if self.ds:
            residual = self.downsample(input, training=training)

        out = out + residual
        out = self.relu(out)

        return out


class HighResolutionModule(tf.keras.layers.Layer):
    def __init__(self, num_branches, blocks, filters):
        # filters_in unnecessary since it equals filters
        super(HighResolutionModule, self).__init__()

        self.num_branches = num_branches
        self.filters = filters
        self.num_in_channels = filters[0]

        self._check_branches(num_branches, blocks, filters)

        # Make branches: each branch uses four BasicBlocks; `blocks` is only validated in _check_branches
        self.branch_1 = tf.keras.Sequential([BasicBlock(filters[0]), BasicBlock(filters[0]), BasicBlock(filters[0]), BasicBlock(filters[0])])
        self.branch_2 = tf.keras.Sequential([BasicBlock(filters[1]), BasicBlock(filters[1]), BasicBlock(filters[1]), BasicBlock(filters[1])])
        self.branch_3 = tf.keras.Sequential([BasicBlock(filters[2]), BasicBlock(filters[2]), BasicBlock(filters[2]), BasicBlock(filters[2])]) if num_branches >= 3 else None
        self.branch_4 = tf.keras.Sequential([BasicBlock(filters[3]), BasicBlock(filters[3]), BasicBlock(filters[3]), BasicBlock(filters[3])]) if num_branches >= 4 else None

        self.fuse_layers = self._make_fuse_layers()
        self.relu = tf.keras.layers.Activation("relu")

    def _check_branches(self, num_branches, blocks, filters):
        if num_branches != len(blocks):
            raise ValueError("'num_branches' = {} is not equal to length of 'blocks' = {}".format(num_branches, len(blocks)))

        if num_branches != len(filters):
            raise ValueError("'num_branches' = {} is not equal to length of 'filters' = {}".format(num_branches, len(filters)))

    def _make_fuse_layers(self):
        fuse_layers = []
        for i in range(self.num_branches):
            fuse_layer = []
            for j in range(self.num_branches):
                if j > i:
                    fuse_layer.append(ConvolutionBnActivation(self.filters[i], (1, 1), momentum=0.1,
                                                              post_activation="linear"))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            conv3x3s.append(ConvolutionBnActivation(self.filters[i], (3, 3), strides=(2, 2), momentum=0.1, post_activation="linear"))
                        else:
                            conv3x3s.append(ConvolutionBnActivation(self.filters[j], (3, 3), strides=(2, 2), momentum=0.1))
                    fuse_layer.append(tf.keras.Sequential(conv3x3s))

            fuse_layers.append(fuse_layer)

        return fuse_layers

    def call(self, input1, input2, input3=None, input4=None, training=None):
        # input3 / input4 are only required when num_branches >= 3 / >= 4
        x_1 = self.branch_1(input1, training=training)
        x_2 = self.branch_2(input2, training=training)
        x_3 = self.branch_3(input3, training=training) if self.num_branches >= 3 else None
        x_4 = self.branch_4(input4, training=training) if self.num_branches >= 4 else None

        x = [x_1, x_2]
        if x_3 is not None:
            x = [x_1, x_2, x_3]
        if x_4 is not None:
            x = [x_1, x_2, x_3, x_4]

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0], training=training)
            for j in range(1, self.num_branches):
                if i == j:
                    y += x[j]
                elif j > i:
                    # Lower-resolution branch: 1x1 conv, then upsample to the resolution of branch i
                    f = self.fuse_layers[i][j](x[j], training=training)
                    scale_factor = int(x[i].shape[-2] / f.shape[-2])
                    if scale_factor > 1:
                        y += tf.keras.layers.UpSampling2D(size=scale_factor, interpolation="bilinear")(f)
                    else:
                        y += f
                else:
                    # Higher-resolution branch: strided 3x3 convs down to the resolution of branch i
                    y += self.fuse_layers[i][j](x[j], training=training)

            x_fuse.append(self.relu(y))

        return x_fuse


custom_objects = {
    'ConvolutionBnActivation': ConvolutionBnActivation,
    'AtrousSeparableConvolutionBnReLU': AtrousSeparableConvolutionBnReLU,
    'AtrousSpatialPyramidPoolingV3': AtrousSpatialPyramidPoolingV3,
    'Upsample_x2_Block': Upsample_x2_Block,
    'Upsample_x2_Add_Block': Upsample_x2_Add_Block,
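    'SpatialContextBlock': SpatialContextBlock,
    'FPNBlock': FPNBlock,
    'AtrousSpatialPyramidPoolingV1': AtrousSpatialPyramidPoolingV1,
    'Base_OC_Module': Base_OC_Module,
    'Pyramid_OC_Module': Pyramid_OC_Module,
    'ASP_OC_Module': ASP_OC_Module,
    'PAM_Module': PAM_Module,
    'CAM_Module': CAM_Module,
    'SelfAttentionBlock2D': SelfAttentionBlock2D,
    'GlobalPooling': GlobalPooling,
    'MixtureOfSoftMaxACF': MixtureOfSoftMaxACF,
    'AggCF_Module': AggCF_Module,
    'SpatialGather_Module': SpatialGather_Module,
    'ObjectAttentionBlock2D': ObjectAttentionBlock2D,
    'SpatialOCR_Module': SpatialOCR_Module,
    'SpatialOCR_ASP_Module': SpatialOCR_ASP_Module,
    'AttCF_Module': AttCF_Module,
    'BasicBlock': BasicBlock,
    'BottleneckBlock': BottleneckBlock,
    'HighResolutionModule': HighResolutionModule,
}

tf.keras.utils.get_custom_objects().update(custom_objects)
--------------------------------------------------------------------------------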
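The pixel-to-class attention computed by `SpatialGather_Module` and `ObjectAttentionBlock2D` can be checked without TensorFlow. The sketch below is a minimal pure-Python re-derivation for a single batch element, assuming flattened inputs (N pixels × K classes for the probabilities, N × C for the features) and omitting the 1x1 convolutions (`f_pixel`, `f_up`). The helper names (`spatial_gather`, `object_attention`, etc.) are hypothetical and not part of this library. Each class column is softmax-normalised over the N pixels, so every class context is a weighted average of pixel features. Each pixel then attends over the K class contexts with a 1/sqrt(C) scaling, which equals `filters ** -0.5` when the query has `filters` channels.

```python
import math


def transpose(m):
    """Transpose a list-of-lists matrix."""
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    """Multiply an (n x k) matrix by a (k x m) matrix."""
    b_cols = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in b_cols] for row in a]


def softmax(v):
    """Numerically stable softmax of one vector."""
    m = max(v)
    exps = [math.exp(x - m) for x in v]
    total = sum(exps)
    return [e / total for e in exps]


def spatial_gather(probs, feats, scale=1.0):
    """probs: N x K class scores, feats: N x C features -> K x C class contexts.

    Each class column is normalised over the N pixels (not over the K classes).
    """
    weights = [softmax([scale * p for p in col]) for col in transpose(probs)]  # K x N
    return matmul(weights, feats)  # K x C


def object_attention(feats, ctx):
    """feats: N x C, ctx: K x C -> (N x K attention map, N x C context)."""
    c = len(feats[0])
    sim = matmul(feats, transpose(ctx))  # N x K similarities
    att = [softmax([s / math.sqrt(c) for s in row]) for row in sim]  # softmax over K
    return att, matmul(att, ctx)  # N x C


# Toy input: 4 pixels, 2 classes, 3 feature channels
probs = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
feats = [[1.0, 0.0, 2.0], [3.0, 0.0, 2.0], [0.0, 4.0, 0.0], [0.0, 2.0, 0.0]]

ctx = spatial_gather(probs, feats, scale=50.0)  # close to per-class feature means
att, out = object_attention(feats, ctx)
```

Normalising over the pixels is what the `channels_first` branch of `SpatialGather_Module` does (a softmax over the last axis of a `(BS, C, N)` tensor). Applying the default last-axis softmax to a channels-last `(BS, N, C)` tensor would instead normalise over the classes.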
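`AttCF_Module` turns the channel energy `e` into attention weights with `reduce_max(e, -1, keepdims=True) - e` before the softmax. Softmax is invariant to adding a constant along the normalised axis, so this equals `softmax(-e)`: subtracting the row maximum only shifts the values, and channels with lower energy receive larger weights. A minimal pure-Python check of that identity follows; `max_minus_energy_attention` is a hypothetical name, not a library function.

```python
import math


def softmax(v):
    """Numerically stable softmax of one vector."""
    m = max(v)
    exps = [math.exp(x - m) for x in v]
    total = sum(exps)
    return [e / total for e in exps]


def max_minus_energy_attention(energy):
    """energy: c x C rows -> softmax(row_max - energy) applied per row (the CCB step)."""
    return [softmax([max(row) - x for x in row]) for row in energy]


energy = [[1.0, 2.0, 3.0], [10.0, 0.0, -5.0]]
att = max_minus_energy_attention(energy)
att_neg = [softmax([-x for x in row]) for row in energy]  # same weights, up to rounding
```

For the same reason, indexing the result of `tf.math.reduce_max` with `[0]` (a `torch.max` idiom that in TensorFlow selects the first batch element) does not change the softmax output here. It only means the subtracted per-row constant comes from batch element 0.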