├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
├── docker
│   ├── Dockerfile
│   └── README.md
├── keras_segmentation
│   ├── __init__.py
│   ├── __main__.py
│   ├── cli_interface.py
│   ├── data_utils
│   │   ├── __init__.py
│   │   ├── augmentation.py
│   │   ├── data_loader.py
│   │   └── visualize_dataset.py
│   ├── metrics.py
│   ├── model_compression.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── _pspnet_2.py
│   │   ├── all_models.py
│   │   ├── basic_models.py
│   │   ├── config.py
│   │   ├── fcn.py
│   │   ├── mobilenet.py
│   │   ├── model.py
│   │   ├── model_utils.py
│   │   ├── pspnet.py
│   │   ├── resnet50.py
│   │   ├── segnet.py
│   │   ├── unet.py
│   │   └── vgg16.py
│   ├── predict.py
│   ├── pretrained.py
│   └── train.py
├── requirements.txt
├── sample_images
│   ├── 1_input.jpg
│   ├── 1_output.png
│   ├── 2_input.jpg
│   ├── 2_output.png
│   ├── 3_input.jpg
│   ├── 3_output.png
│   ├── liner_dataset.png
│   ├── liner_export.png
│   ├── liner_testing.png
│   └── liner_training.png
├── setup.cfg
├── setup.py
└── test
    ├── __init__.py
    ├── example_dataset
    │   ├── annotations_prepped_test
    │   │   ├── 0016E5_07959.png
    │   │   ├── 0016E5_07961.png
    │   │   └── 0016E5_07963.png
    │   ├── annotations_prepped_train
    │   │   ├── 0001TP_006690.png
    │   │   ├── 0001TP_006720.png
    │   │   ├── 0001TP_006750.png
    │   │   ├── 0001TP_006780.png
    │   │   └── 0001TP_006810.png
    │   ├── images_prepped_test
    │   │   ├── 0016E5_07959.png
    │   │   ├── 0016E5_07961.png
    │   │   └── 0016E5_07963.png
    │   └── images_prepped_train
    │       ├── 0001TP_006690.png
    │       ├── 0001TP_006720.png
    │       ├── 0001TP_006750.png
    │       ├── 0001TP_006780.png
    │       └── 0001TP_006810.png
    ├── test_models.py
    └── unit
        ├── data_utils
        │   ├── test_augmentation.py
        │   ├── test_data_loader.py
        │   └── test_visualize_dataset.py
        ├── models
        │   └── test_basic_models.py
        ├── test_metrics.py
        ├── test_predict.py
        ├── test_pretrained.py
        └── test_train.py

/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 | 
3 | 
4 | 
5 | sync.expect
6 | 
7 | 
8 | # OS generated files #
9 | ######################
10 | .DS_Store
11 | .DS_Store?
12 | ehthumbs.db
13 | Icon?
14 | Thumbs.db
15 | 
16 | 
17 | 
18 | # Byte-compiled / optimized / DLL files
19 | __pycache__/
20 | *.py[cod]
21 | *$py.class
22 | 
23 | # C extensions
24 | *.so
25 | 
26 | # Distribution / packaging
27 | .Python
28 | env/
29 | build/
30 | develop-eggs/
31 | dist/
32 | downloads/
33 | eggs/
34 | .eggs/
35 | lib/
36 | lib64/
37 | parts/
38 | sdist/
39 | var/
40 | *.egg-info/
41 | .installed.cfg
42 | *.egg
43 | 
44 | # PyInstaller
45 | # Usually these files are written by a python script from a template
46 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
47 | *.manifest
48 | *.spec
49 | 
50 | # Installer logs
51 | pip-log.txt
52 | pip-delete-this-directory.txt
53 | 
54 | # Unit test / coverage reports
55 | htmlcov/
56 | .tox/
57 | .coverage
58 | .coverage.*
59 | .cache
60 | nosetests.xml
61 | coverage.xml
62 | *.cover
63 | .hypothesis/
64 | 
65 | # Translations
66 | *.mo
67 | *.pot
68 | 
69 | # Django stuff:
70 | *.log
71 | local_settings.py
72 | 
73 | # Flask instance folder
74 | instance/
75 | 
76 | # Scrapy stuff:
77 | .scrapy
78 | 
79 | # Sphinx documentation
80 | docs/_build/
81 | 
82 | # PyBuilder
83 | target/
84 | 
85 | # IPython Notebook
86 | .ipynb_checkpoints
87 | 
88 | # pyenv
89 | .python-version
90 | 
91 | # celery beat schedule file
92 | celerybeat-schedule
93 | 
94 | # dotenv
95 | .env
96 | 
97 | # virtualenv
98 | venv/
99 | ENV/
100 | 
101 | # Spyder project settings
102 | .spyderproject
103 | 
104 | # Rope project settings
105 | .ropeproject
106 | 
107 | # VSCode settings
108 | .vscode/
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | python:
3 |   - "2.7"
4 |   # - "3.5"  # removed due to an unexplained error on Travis
5 |   - "3.6"  # current default Python on Travis CI
6 |   - "3.7"
7 | # command to install dependencies
8 | install:
9 |   # Install with the "tests-default" extras_require
10 |   - pip install .[tests-default]
11 | # command to run tests
12 | script: pytest
13 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Divam Gupta
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Image Segmentation Keras : Implementation of Segnet, FCN, UNet, PSPNet and other models in Keras.
2 | 
3 | [![PyPI version](https://badge.fury.io/py/keras-segmentation.svg)](https://badge.fury.io/py/keras-segmentation)
4 | [![Downloads](https://pepy.tech/badge/keras-segmentation)](https://pepy.tech/project/keras-segmentation)
5 | [![Build Status](https://travis-ci.org/divamgupta/image-segmentation-keras.png)](https://travis-ci.org/divamgupta/image-segmentation-keras)
6 | [![MIT license](https://img.shields.io/badge/License-MIT-blue.svg)](http://perso.crans.org/besson/LICENSE.html)
7 | [![Twitter](https://img.shields.io/twitter/url.svg?label=Follow%20%40divamgupta&style=social&url=https%3A%2F%2Ftwitter.com%2Fdivamgupta)](https://twitter.com/divamgupta)
8 | 
9 | 
10 | 
11 | Implementation of various deep image segmentation models in Keras.
12 | 
13 | ### News : Some functionality of this repository has been integrated with https://liner.ai . Check it out!
14 | 
15 | 
16 | 
17 | 
20 | 
21 | Link to the full blog post with tutorial: https://divamgupta.com/image-segmentation/2019/06/06/deep-learning-semantic-segmentation-keras.html
22 | 
23 | 
24 | ## Working Google Colab Examples:
25 | * Python Interface: https://colab.research.google.com/drive/1q_eCYEzKxixpCKH1YDsLnsvgxl92ORcv?usp=sharing
26 | * CLI Interface: https://colab.research.google.com/drive/1Kpy4QGFZ2ZHm69mPfkmLSUes8kj6Bjyi?usp=sharing
27 | 
28 | ## Training using GUI interface
29 | You can also train segmentation models on your computer with https://liner.ai
30 | 
31 | Train | Inference / Export
32 | :-------------------------:|:-------------------------:
33 | ![https://liner.ai ](sample_images/liner_dataset.png)  |  ![https://liner.ai ](sample_images/liner_testing.png)
34 | ![https://liner.ai ](sample_images/liner_training.png)  |  ![https://liner.ai ](sample_images/liner_export.png)
35 | 
36 | 
37 | ## Models
38 | 
39 | The following models are supported:
40 | 
41 | | model_name       | Base Model        | Segmentation Model |
42 | |------------------|-------------------|--------------------|
43 | | fcn_8            | Vanilla CNN       | FCN8               |
44 | | fcn_32           | Vanilla CNN       | FCN32              |
45 | | fcn_8_vgg        | VGG 16            | FCN8               |
46 | | fcn_32_vgg       | VGG 16            | FCN32              |
47 | | fcn_8_resnet50   | Resnet-50         | FCN8               |
48 | | fcn_32_resnet50  | Resnet-50         | FCN32              |
49 | | fcn_8_mobilenet  | MobileNet         | FCN8               |
50 | | fcn_32_mobilenet | MobileNet         | FCN32              |
51 | | pspnet           | Vanilla CNN       | PSPNet             |
52 | | pspnet_50        | Vanilla CNN       | PSPNet             |
53 | | pspnet_101       | Vanilla CNN       | PSPNet             |
54 | | vgg_pspnet       | VGG 16            | PSPNet             |
55 | | resnet50_pspnet  | Resnet-50         | PSPNet             |
56 | | unet_mini        | Vanilla Mini CNN  | U-Net              |
57 | | unet             | Vanilla CNN       | U-Net              |
58 | | vgg_unet         | VGG 16            | U-Net              |
59 | | resnet50_unet    | Resnet-50         | U-Net              |
60 | | mobilenet_unet   | MobileNet         | U-Net              |
61 | | segnet           | Vanilla CNN       | Segnet             |
62 | | vgg_segnet       | VGG 16            | Segnet             |
63 | | resnet50_segnet  | Resnet-50         | Segnet             |
64 | | mobilenet_segnet | MobileNet         | Segnet             |
65 | 
66 | 
67 | Example results for the provided pre-trained models:
68 | 
69 | Input Image | Output Segmentation Image
70 | :-------------------------:|:-------------------------:
71 | ![](sample_images/1_input.jpg)  |  ![](sample_images/1_output.png)
72 | ![](sample_images/3_input.jpg)  |  ![](sample_images/3_output.png)
73 | 
74 | 
75 | ## How to cite
76 | 
77 | If you are using this library, please cite it using:
78 | 
79 | ```
80 | @article{gupta2023image,
81 |   title={Image segmentation keras: Implementation of segnet, fcn, unet, pspnet and other models in keras},
82 |   author={Gupta, Divam},
83 |   journal={arXiv preprint arXiv:2307.13215},
84 |   year={2023}
85 | }
86 | 
87 | ```
88 | 
89 | 
90 | ## Getting Started
91 | 
92 | ### Prerequisites
93 | 
94 | * Keras (recommended version: 2.4.3)
95 | * OpenCV for Python
96 | * Tensorflow (recommended version: 2.4.1)
97 | 
98 | ```shell
99 | apt-get install -y libsm6 libxext6 libxrender-dev
100 | pip install opencv-python
101 | ```
102 | 
103 | ### Installing
104 | 
105 | Install the module
106 | 
107 | Recommended way:
108 | ```shell
109 | pip install --upgrade git+https://github.com/divamgupta/image-segmentation-keras
110 | ```
111 | 
112 | ### or
113 | 
114 | ```shell
115 | pip install keras-segmentation
116 | ```
117 | 
118 | ### or
119 | 
120 | ```shell
121 | git clone https://github.com/divamgupta/image-segmentation-keras
122 | cd image-segmentation-keras
123 | python setup.py install
124 | ```
125 | 
126 | 
127 | ## Pre-trained models:
128 | ```python
129 | from keras_segmentation.pretrained import pspnet_50_ADE_20K, pspnet_101_cityscapes, pspnet_101_voc12
130 | 
131 | model = pspnet_50_ADE_20K() # load the pretrained model trained on ADE20k dataset
132 | 
133 | model = pspnet_101_cityscapes() # load the pretrained model trained on Cityscapes dataset
134 | 
135 | model = pspnet_101_voc12() # load the pretrained model trained on Pascal VOC 2012 dataset
136 | 
137 | # load any of the 3 pretrained models
138 | 
139 | out = model.predict_segmentation(
140 |     inp="input_image.jpg",
141 |     out_fname="out.png"
142 | )
143 | 
144 | ```
145 | 
146 | 
147 | ### Preparing the data for training
148 | 
149 | You need to make two folders:
150 | 
151 | * Images Folder - For all the training images
152 | * Annotations Folder - For the corresponding ground truth segmentation images
153 | 
154 | The filenames of the annotation images should be the same as the filenames of the RGB images.
155 | 
156 | The size of each annotation image should be the same as that of its corresponding RGB image.
157 | 
158 | For each pixel in the RGB image, the class label of that pixel is the value of the blue channel at the same location in the annotation image.
159 | 
160 | Example code to generate annotation images:
161 | 
162 | ```python
163 | import cv2
164 | import numpy as np
165 | 
166 | ann_img = np.zeros((30,30,3)).astype('uint8')
167 | ann_img[ 3 , 4 ] = 1 # this sets the label of pixel (3,4) to 1
168 | 
169 | cv2.imwrite( "ann_1.png" ,ann_img )
170 | ```
171 | 
172 | Only use bmp or png format for the annotation images.
173 | 
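The dataset checks used by the `verify_dataset` CLI command (shown later) are also available directly as a Python helper, `verify_segmentation_dataset`, defined in `keras_segmentation/data_utils/data_loader.py` further down in this repository. A quick sketch, assuming the `dataset1/` sample dataset described in the next section:

```python
from keras_segmentation.data_utils.data_loader import verify_segmentation_dataset

# Checks that every image has a matching annotation of the same size and
# that all label values fall in [0, n_classes); returns True when valid.
ok = verify_segmentation_dataset(
    images_path="dataset1/images_prepped_train/",
    segs_path="dataset1/annotations_prepped_train/",
    n_classes=50
)
print("dataset valid:", ok)
```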
174 | ## Download the sample prepared dataset
175 | 
176 | Download and extract the following:
177 | 
178 | https://drive.google.com/file/d/0B0d9ZiqAgFkiOHR1NTJhWVJMNEU/view?usp=sharing
179 | 
180 | You will get a folder named dataset1/
181 | 
182 | 
183 | ## Using the Python module
184 | 
185 | You can import keras_segmentation in your Python script and use the API:
186 | 
187 | ```python
188 | from keras_segmentation.models.unet import vgg_unet
189 | 
190 | model = vgg_unet(n_classes=51, input_height=416, input_width=608)
191 | 
192 | model.train(
193 |     train_images = "dataset1/images_prepped_train/",
194 |     train_annotations = "dataset1/annotations_prepped_train/",
195 |     checkpoints_path = "/tmp/vgg_unet_1" , epochs=5
196 | )
197 | 
198 | out = model.predict_segmentation(
199 |     inp="dataset1/images_prepped_test/0016E5_07965.png",
200 |     out_fname="/tmp/out.png"
201 | )
202 | 
203 | import matplotlib.pyplot as plt
204 | plt.imshow(out)
205 | 
206 | # evaluating the model
207 | print(model.evaluate_segmentation( inp_images_dir="dataset1/images_prepped_test/" , annotations_dir="dataset1/annotations_prepped_test/" ) )
208 | 
209 | ```
210 | 
211 | 
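For batch inference on a whole directory of images, `keras_segmentation.predict` also provides `predict_multiple`, the same helper the `predict` CLI command dispatches to when its input path is a directory. A minimal sketch (the output directory is just an example path; create it first if needed):

```python
from keras_segmentation.predict import predict_multiple

# Runs the checkpointed model on every image in inp_dir and writes the
# predicted segmentation images into out_dir.
predict_multiple(
    inp_dir="dataset1/images_prepped_test/",
    out_dir="/tmp/predictions/",
    checkpoints_path="/tmp/vgg_unet_1"
)
```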
212 | ## Usage via command line
213 | You can also use the tool directly from the command line
214 | 
215 | ### Visualizing the prepared data
216 | 
217 | You can also visualize your prepared annotations to verify the prepared data.
218 | 
219 | 
220 | ```shell
221 | python -m keras_segmentation verify_dataset \
222 |  --images_path="dataset1/images_prepped_train/" \
223 |  --segs_path="dataset1/annotations_prepped_train/"  \
224 |  --n_classes=50
225 | ```
226 | 
227 | ```shell
228 | python -m keras_segmentation visualize_dataset \
229 |  --images_path="dataset1/images_prepped_train/" \
230 |  --segs_path="dataset1/annotations_prepped_train/"  \
231 |  --n_classes=50
232 | ```
233 | 
234 | 
235 | 
236 | ### Training the Model
237 | 
238 | To train the model run the following command:
239 | 
240 | ```shell
241 | python -m keras_segmentation train \
242 |  --checkpoints_path="path_to_checkpoints" \
243 |  --train_images="dataset1/images_prepped_train/" \
244 |  --train_annotations="dataset1/annotations_prepped_train/" \
245 |  --val_images="dataset1/images_prepped_test/" \
246 |  --val_annotations="dataset1/annotations_prepped_test/" \
247 |  --n_classes=50 \
248 |  --input_height=320 \
249 |  --input_width=640 \
250 |  --model_name="vgg_unet"
251 | ```
252 | 
253 | Choose `model_name` from the table above.
254 | 
255 | 
256 | 
257 | ### Getting the predictions
258 | 
259 | To get the predictions of a trained model
260 | 
261 | ```shell
262 | python -m keras_segmentation predict \
263 |  --checkpoints_path="path_to_checkpoints" \
264 |  --input_path="dataset1/images_prepped_test/" \
265 |  --output_path="path_to_predictions"
266 | 
267 | ```
268 | 
269 | 
270 | 
271 | ### Video inference
272 | 
273 | To get predictions of a video
274 | ```shell
275 | python -m keras_segmentation predict_video \
276 |  --checkpoints_path="path_to_checkpoints" \
277 |  --input="path_to_video" \
278 |  --output_file="path_for_save_inferenced_video" \
279 |  --display
280 | ```
281 | 
282 | If you want to make predictions on your webcam, don't use `--input`, or pass your device number: `--input 0`.
283 | `--display` opens a window with the predicted video. Remove this argument when using a headless system.
284 | 
285 | 
286 | ### Model Evaluation
287 | 
288 | To get the IoU scores
289 | 
290 | ```shell
291 | python -m keras_segmentation evaluate_model \
292 |  --checkpoints_path="path_to_checkpoints" \
293 |  --images_path="dataset1/images_prepped_test/" \
294 |  --segs_path="dataset1/annotations_prepped_test/"
295 | ```
296 | 
297 | 
298 | 
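The reported scores are per-class intersection-over-union. The same computation is available standalone via `keras_segmentation.metrics.get_iou` (defined in `keras_segmentation/metrics.py` further down in this repository), which takes plain label arrays; a tiny worked example:

```python
import numpy as np
from keras_segmentation.metrics import get_iou

gt = np.array([[0, 0], [1, 1]])  # ground-truth labels
pr = np.array([[0, 1], [1, 1]])  # predicted labels

# Per class c: IoU = |gt==c AND pr==c| / |gt==c OR pr==c|
print(get_iou(gt, pr, 2))  # approximately [0.5, 0.6667]
```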
299 | ## Fine-tuning from existing segmentation model
300 | 
301 | The following example shows how to fine-tune a model to a new set of classes.
302 | 
303 | ```python
304 | from keras_segmentation.models.model_utils import transfer_weights
305 | from keras_segmentation.pretrained import pspnet_50_ADE_20K
306 | from keras_segmentation.models.pspnet import pspnet_50
307 | 
308 | pretrained_model = pspnet_50_ADE_20K()
309 | 
310 | new_model = pspnet_50( n_classes=51 )
311 | 
312 | transfer_weights( pretrained_model , new_model ) # transfer weights from pre-trained model to your model
313 | 
314 | new_model.train(
315 |     train_images = "dataset1/images_prepped_train/",
316 |     train_annotations = "dataset1/annotations_prepped_train/",
317 |     checkpoints_path = "/tmp/vgg_unet_1" , epochs=5
318 | )
319 | 
320 | 
321 | ```
322 | 
323 | 
324 | 
325 | ## Knowledge distillation for compressing the model
326 | 
327 | The following example shows how to transfer knowledge from a larger (and more accurate) model to a smaller one. In most cases, the smaller model trained via knowledge distillation is more accurate than the same model trained using vanilla supervised learning.
328 | 
329 | ```python
330 | from keras_segmentation.predict import model_from_checkpoint_path
331 | from keras_segmentation.models.unet import unet_mini
332 | from keras_segmentation.model_compression import perform_distilation
333 | 
334 | model_large = model_from_checkpoint_path( "/checkpoints/path/of/trained/model" )
335 | model_small = unet_mini( n_classes=51, input_height=300, input_width=400 )
336 | 
337 | perform_distilation ( data_path="/path/to/large_image_set/" , checkpoints_path="path/to/save/checkpoints" ,
338 |     teacher_model=model_large , student_model=model_small , distilation_loss='kl' , feats_distilation_loss='pa' )
339 | 
340 | ```
341 | 
342 | 
343 | 
344 | 
345 | 
346 | ## Adding a custom augmentation function to training
347 | 
348 | The following example shows how to define a custom augmentation function for training.
349 | 
350 | ```python
351 | 
352 | from keras_segmentation.models.unet import vgg_unet
353 | from imgaug import augmenters as iaa
354 | 
355 | def custom_augmentation():
356 |     return iaa.Sequential(
357 |         [
358 |             # apply the following augmenters to most images
359 |             iaa.Fliplr(0.5),  # horizontally flip 50% of all images
360 |             iaa.Flipud(0.5),  # vertically flip 50% of all images
361 |         ])
362 | 
363 | model = vgg_unet(n_classes=51 , input_height=416, input_width=608)
364 | 
365 | model.train(
366 |     train_images = "dataset1/images_prepped_train/",
367 |     train_annotations = "dataset1/annotations_prepped_train/",
368 |     checkpoints_path = "/tmp/vgg_unet_1" , epochs=5,
369 |     do_augment=True, # enable augmentation
370 |     custom_augmentation=custom_augmentation # sets the augmentation function to use
371 | )
372 | ```
373 | ## Custom number of input channels
374 | 
375 | The following example shows how to set the number of input channels.
376 | 
377 | ```python
378 | 
379 | from keras_segmentation.models.unet import vgg_unet
380 | 
381 | model = vgg_unet(n_classes=51 , input_height=416, input_width=608,
382 |                  channels=1 # Sets the number of input channels
383 |                  )
384 | 
385 | model.train(
386 |     train_images = "dataset1/images_prepped_train/",
387 |     train_annotations = "dataset1/annotations_prepped_train/",
388 |     checkpoints_path = "/tmp/vgg_unet_1" , epochs=5,
389 |     read_image_type=0  # Sets how OpenCV will read the images
390 |                        # cv2.IMREAD_COLOR = 1 (3-channel BGR color),
391 |                        # cv2.IMREAD_GRAYSCALE = 0,
392 |                        # cv2.IMREAD_UNCHANGED = -1 (4 channels like RGBA)
393 | )
394 | ```
395 | 
396 | ## Custom preprocessing
397 | 
398 | The following example shows how to set a custom image preprocessing function.
399 | 
400 | ```python
401 | 
402 | from keras_segmentation.models.unet import vgg_unet
403 | 
404 | def image_preprocessing(image):
405 |     return image + 1
406 | 
407 | model = vgg_unet(n_classes=51 , input_height=416, input_width=608)
408 | 
409 | model.train(
410 |     train_images = "dataset1/images_prepped_train/",
411 |     train_annotations = "dataset1/annotations_prepped_train/",
412 |     checkpoints_path = "/tmp/vgg_unet_1" , epochs=5,
413 |     preprocessing=image_preprocessing # Sets the preprocessing function
414 | )
415 | ```
416 | 
417 | ## Custom callbacks
418 | 
419 | The following example shows how to set custom callbacks for the model training.
420 | 
421 | ```python
422 | 
423 | from keras_segmentation.models.unet import vgg_unet
424 | from keras.callbacks import ModelCheckpoint, EarlyStopping
425 | 
426 | model = vgg_unet(n_classes=51 , input_height=416, input_width=608 )
427 | 
428 | # When using custom callbacks, the default checkpoint saver is removed
429 | callbacks = [
430 |     ModelCheckpoint(
431 |         filepath="checkpoints/" + model.name + ".{epoch:05d}",
432 |         save_weights_only=True,
433 |         verbose=True
434 |     ),
435 |     EarlyStopping()
436 | ]
437 | 
438 | model.train(
439 |     train_images = "dataset1/images_prepped_train/",
440 |     train_annotations = "dataset1/annotations_prepped_train/",
441 |     checkpoints_path = "/tmp/vgg_unet_1" , epochs=5,
442 |     callbacks=callbacks
443 | )
444 | ```
445 | 
446 | ## Multiple image inputs
447 | 
448 | The following example shows how to add additional image inputs for models.
449 | 
450 | ```python
451 | 
452 | from keras_segmentation.models.unet import vgg_unet
453 | 
454 | model = vgg_unet(n_classes=51 , input_height=416, input_width=608)
455 | 
456 | model.train(
457 |     train_images = "dataset1/images_prepped_train/",
458 |     train_annotations = "dataset1/annotations_prepped_train/",
459 |     checkpoints_path = "/tmp/vgg_unet_1" , epochs=5,
460 |     other_inputs_paths=[
461 |         "/path/to/other/directory"
462 |     ],
463 | 
464 | 
465 |     # Preprocessing can also be added here
466 |     preprocessing=[lambda x: x+1, lambda x: x+2, lambda x: x+3], # Different preprocessing for each input
467 |     # OR, the same preprocessing for every input:
468 |     # preprocessing=lambda x: x+1,
469 | )
470 | ```
471 | 
472 | 
473 | ## Projects using keras-segmentation
474 | Here are a few projects which are using our library:
475 | * https://github.com/SteliosTsop/QF-image-segmentation-keras [paper](https://arxiv.org/pdf/1908.02242.pdf)
476 | * https://github.com/willembressers/bouquet_quality
477 | * https://github.com/jqueguiner/image-segmentation
478 | * https://github.com/pan0rama/CS230-Microcrystal-Facet-Segmentation
479 | * https://github.com/theerawatramchuen/Keras_Segmentation
480 | * https://github.com/neheller/labels18
481 | * https://github.com/Divyam10/Face-Matting-using-Unet
482 | * https://github.com/shsh-a/segmentation-over-web
483 | * https://github.com/chenwe73/deep_active_learning_segmentation
484 | * https://github.com/vigneshrajap/vision-based-navigation-agri-fields
485 | * https://github.com/ronalddas/Pneumonia-Detection
486 | * https://github.com/Aiwiscal/ECG_UNet
487 | * https://github.com/TianzhongSong/Unet-for-Person-Segmentation
488 | * https://github.com/Guyanqi/GMDNN
489 | * https://github.com/kozemzak/prostate-lesion-segmentation
490 | * https://github.com/lixiaoyu12138/fcn-date
491 | * https://github.com/sagarbhokre/LyftChallenge
492 | * https://github.com/TianzhongSong/Person-Segmentation-Keras
493 | * https://github.com/divyanshpuri02/COCO_2018-Stuff-Segmentation-Challenge
494 | * https://github.com/XiangbingJi/Stanford-cs230-final-project
495 | * https://github.com/lsh1994/keras-segmentation
496 | * https://github.com/SpirinEgor/mobile_semantic_segmentation
497 | * https://github.com/LeadingIndiaAI/COCO-DATASET-STUFF-SEGMENTATION-CHALLENGE
498 | * https://github.com/lidongyue12138/Image-Segmentation-by-Keras
499 | * https://github.com/laoj2/segnet_crfasrnn
500 | * https://github.com/rancheng/AirSimProjects
501 | * https://github.com/RadiumScriptTang/cartoon_segmentation
502 | * https://github.com/dquail/NerveSegmentation
503 | * https://github.com/Bhomik/SemanticHumanMatting
504 | * https://github.com/Symefa/FP-Biomedik-Breast-Cancer
505 | * https://github.com/Alpha-Monocerotis/PDF_FigureTable_Extraction
506 | * https://github.com/rusito-23/mobile_unet_segmentation
507 | * https://github.com/Philliec459/ThinSection-image-segmentation-keras
508 | * https://github.com/imsadia/cv-assignment-three.git
509 | * https://github.com/kejitan/ESVGscale
510 | 
511 | If you use our code in a publicly available project, please add the link here (by posting an issue or creating a PR)
512 | 
513 | 
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM tanmaniac/opencv3-cudagl
2 | 
3 | # install prerequisites
4 | RUN apt-get update \
5 |     && apt-get install -y wget git curl nano \
6 |     && apt-get install -y libsm6 libxext6 libxrender-dev
7 | 
8 | # install Cudnn
9 | ENV CUDNN_VERSION 7.6.0.64
10 | RUN apt-get update && apt-get install -y --no-install-recommends \
11 |     libcudnn7=$CUDNN_VERSION-1+cuda9.0 \
12 |     libcudnn7-dev=$CUDNN_VERSION-1+cuda9.0 && \
13 |     apt-mark hold libcudnn7 && \
14 |     rm -rf /var/lib/apt/lists/*
15 | 
16 | # Install Miniconda
17 | RUN curl -so /miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py37_4.8.2-Linux-x86_64.sh \
18 |     && chmod +x /miniconda.sh \
19 |     && /miniconda.sh -b -p /miniconda \
20 |     && rm /miniconda.sh
21 | 
22 | # Create a Python 3.6 environment
23 | ENV PATH=/miniconda/bin:$PATH
24 | 
25 | RUN /miniconda/bin/conda install -y conda-build \
26 |     && /miniconda/bin/conda create -y --name unet python=3.6.7 \
27 |     && /miniconda/bin/conda clean -ya
28 | 
29 | ENV CONDA_DEFAULT_ENV=unet
30 | ENV CONDA_PREFIX=/miniconda/envs/$CONDA_DEFAULT_ENV
31 | ENV PATH=$CONDA_PREFIX/bin:$PATH
32 | ENV CONDA_AUTO_UPDATE_CONDA=false
33 | 
34 | RUN conda install -y ipython tensorflow-gpu=1.14.0 keras=2.3.1
35 | 
36 | # install model library
37 | RUN git clone https://github.com/divamgupta/image-segmentation-keras.git
38 | WORKDIR /image-segmentation-keras
39 | RUN python setup.py install
--------------------------------------------------------------------------------
/docker/README.md:
--------------------------------------------------------------------------------
1 | ## Installing Docker
2 | 
3 | General installation instructions are
4 | [on the Docker site](https://docs.docker.com/installation/), but we give some
5 | quick links here:
6 | 
7 | * [OSX](https://docs.docker.com/installation/mac/): [docker toolbox](https://www.docker.com/toolbox)
8 | * [ubuntu](https://docs.docker.com/installation/ubuntulinux/)
9 | 
10 | For GPU support, install compatible NVIDIA drivers with CUDA 9.0 and cuDNN 7.6.
11 | 
12 | ## Running the container
13 | 
14 | Build the container:
15 | 
16 |     $ docker build -t isk .
17 | 
18 | To run the image:
19 | 
20 |     $ docker run --gpus all -it isk
21 | 
22 | If you want to train with a dataset on your local machine, or run inference on images or videos, mount a volume to share this data with the docker container:
23 | 
24 |     $ docker run --gpus all -v /path/to/data/folder:/image-segmentation-keras/share -it isk
25 | 
26 | If a graphical interface is needed to show results, as with `predict_video --display`, first allow Docker to use the host's display.
In your local host, type this line once: 27 | 28 | $ xhost +local:docker 29 | 30 | And run the container with access to X11: 31 | 32 | $ docker run --gpus all -v /path/to/data/folder:/image-segmentation-keras/share -e DISPLAY=$DISPLAY -v /tmp/.X11-unix/:/tmp/.X11-unix --env QT_X11_NO_MITSHM=1 -it isk 33 | -------------------------------------------------------------------------------- /keras_segmentation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/keras_segmentation/__init__.py -------------------------------------------------------------------------------- /keras_segmentation/__main__.py: -------------------------------------------------------------------------------- 1 | 2 | def main(): 3 | from . import cli_interface 4 | cli_interface.main() 5 | 6 | if __name__ == "__main__": 7 | main() 8 | -------------------------------------------------------------------------------- /keras_segmentation/cli_interface.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys 4 | import argparse 5 | 6 | from .train import train 7 | from .predict import predict, predict_multiple, predict_video, evaluate 8 | from .data_utils.data_loader import verify_segmentation_dataset 9 | from .data_utils.visualize_dataset import visualize_segmentation_dataset 10 | 11 | 12 | def train_action(command_parser): 13 | parser = command_parser.add_parser('train') 14 | parser.add_argument("--model_name", type=str, required=True) 15 | parser.add_argument("--train_images", type=str, required=True) 16 | parser.add_argument("--train_annotations", type=str, required=True) 17 | 18 | parser.add_argument("--n_classes", type=int, required=True) 19 | parser.add_argument("--input_height", type=int, default=None) 20 | parser.add_argument("--input_width", type=int, default=None) 21 | 22 | parser.add_argument('--not_verify_dataset', action='store_false') 23 | parser.add_argument("--checkpoints_path", type=str, default=None) 24 | parser.add_argument("--epochs", type=int, default=5) 25 | parser.add_argument("--batch_size", type=int, default=2) 26 | 27 | parser.add_argument('--validate', action='store_true') 28 | parser.add_argument("--val_images", type=str, default="") 29 | parser.add_argument("--val_annotations", type=str, default="") 30 | 31 | parser.add_argument("--val_batch_size", type=int, default=2) 32 | parser.add_argument("--load_weights", type=str, default=None) 33 | parser.add_argument('--auto_resume_checkpoint', action='store_true') 34 | 35 | parser.add_argument("--steps_per_epoch", type=int, default=512) 36 | parser.add_argument("--optimizer_name", type=str, default="adam") 37 | 38 | def action(args): 39 | return train(model=args.model_name, 40 | train_images=args.train_images, 41 | train_annotations=args.train_annotations, 42 | input_height=args.input_height, 43 | input_width=args.input_width, 44 | n_classes=args.n_classes, 45 | verify_dataset=args.not_verify_dataset, 46 | checkpoints_path=args.checkpoints_path, 47 | epochs=args.epochs, 48 | batch_size=args.batch_size, 49 | validate=args.validate, 50 | val_images=args.val_images, 51 | val_annotations=args.val_annotations, 52 | val_batch_size=args.val_batch_size, 53 | auto_resume_checkpoint=args.auto_resume_checkpoint, 54 | load_weights=args.load_weights, 55 | steps_per_epoch=args.steps_per_epoch, 56 | optimizer_name=args.optimizer_name) 57 | 58 | 
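# Register this sub-command's handler; main() dispatches the parsed args to it via args.func(args)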
parser.set_defaults(func=action) 59 | 60 | 61 | def predict_action(command_parser): 62 | 63 | parser = command_parser.add_parser('predict') 64 | parser.add_argument("--checkpoints_path", type=str, required=True) 65 | parser.add_argument("--input_path", type=str, default="", required=True) 66 | parser.add_argument("--output_path", type=str, default="", required=True) 67 | 68 | def action(args): 69 | input_path_extension = args.input_path.split('.')[-1] 70 | if input_path_extension in ['jpg', 'jpeg', 'png']: 71 | return predict(inp=args.input_path, out_fname=args.output_path, 72 | checkpoints_path=args.checkpoints_path) 73 | else: 74 | return predict_multiple(inp_dir=args.input_path, 75 | out_dir=args.output_path, 76 | checkpoints_path=args.checkpoints_path) 77 | 78 | parser.set_defaults(func=action) 79 | 80 | 81 | def predict_video_action(command_parser): 82 | parser = command_parser.add_parser('predict_video') 83 | parser.add_argument("--input", type=str, default=0, required=False) 84 | parser.add_argument("--output_file", type=str, default="", required=False) 85 | parser.add_argument("--checkpoints_path", required=True) 86 | parser.add_argument("--display", action='store_true', required=False) 87 | 88 | def action(args): 89 | return predict_video(inp=args.input, 90 | output=args.output_file, 91 | checkpoints_path=args.checkpoints_path, 92 | display=args.display, 93 | ) 94 | 95 | parser.set_defaults(func=action) 96 | 97 | 98 | def evaluate_model_action(command_parser): 99 | 100 | parser = command_parser.add_parser('evaluate_model') 101 | parser.add_argument("--images_path", type=str, required=True) 102 | parser.add_argument("--segs_path", type=str, required=True) 103 | parser.add_argument("--checkpoints_path", type=str, required=True) 104 | 105 | def action(args): 106 | print(evaluate( 107 | inp_images_dir=args.images_path, annotations_dir=args.segs_path, 108 | checkpoints_path=args.checkpoints_path)) 109 | 110 | parser.set_defaults(func=action) 111 | 112 | 113 | def verify_dataset_action(command_parser): 114 | 115 | parser = command_parser.add_parser('verify_dataset') 116 | parser.add_argument("--images_path", type=str) 117 | parser.add_argument("--segs_path", type=str) 118 | parser.add_argument("--n_classes", type=int) 119 | 120 | def action(args): 121 | verify_segmentation_dataset( 122 | args.images_path, args.segs_path, args.n_classes) 123 | 124 | parser.set_defaults(func=action) 125 | 126 | 127 | def visualize_dataset_action(command_parser): 128 | 129 | parser = command_parser.add_parser('visualize_dataset') 130 | parser.add_argument("--images_path", type=str) 131 | parser.add_argument("--segs_path", type=str) 132 | parser.add_argument("--n_classes", type=int) 133 | parser.add_argument('--do_augment', action='store_true') 134 | 135 | def action(args): 136 | visualize_segmentation_dataset(args.images_path, args.segs_path, 137 | args.n_classes, 138 | do_augment=args.do_augment) 139 | 140 | parser.set_defaults(func=action) 141 | 142 | 143 | def main(): 144 | assert len(sys.argv) >= 2, \ 145 | "python -m keras_segmentation " 146 | 147 | main_parser = argparse.ArgumentParser() 148 | command_parser = main_parser.add_subparsers() 149 | 150 | # Add individual commands 151 | train_action(command_parser) 152 | predict_action(command_parser) 153 | predict_video_action(command_parser) 154 | verify_dataset_action(command_parser) 155 | visualize_dataset_action(command_parser) 156 | evaluate_model_action(command_parser) 157 | 158 | args = main_parser.parse_args() 159 | 160 | args.func(args) 161 | 
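Every sub-command above follows the same registration pattern: create a parser with `command_parser.add_parser(...)`, define an `action(args)` handler, and attach it with `parser.set_defaults(func=action)` so that `main()` can dispatch via `args.func(args)`. A minimal sketch of what a new (purely hypothetical, not part of the library) sub-command would look like under this pattern:

```python
def dataset_stats_action(command_parser):
    # Hypothetical example: report the distinct label values present
    # in a folder of annotation images.
    parser = command_parser.add_parser('dataset_stats')
    parser.add_argument("--segs_path", type=str, required=True)

    def action(args):
        import os
        import cv2
        import numpy as np

        labels = set()
        for fname in os.listdir(args.segs_path):
            seg = cv2.imread(os.path.join(args.segs_path, fname))
            if seg is not None:
                labels.update(int(v) for v in np.unique(seg[:, :, 0]))
        print("distinct labels:", sorted(labels))

    parser.set_defaults(func=action)

# It would then be registered in main() next to the other *_action calls.
```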
-------------------------------------------------------------------------------- /keras_segmentation/data_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/keras_segmentation/data_utils/__init__.py -------------------------------------------------------------------------------- /keras_segmentation/data_utils/augmentation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | try: 4 | import imgaug as ia 5 | from imgaug import augmenters as iaa 6 | except ImportError: 7 | print("Error in loading augmentation, can't import imgaug." 8 | "Please make sure it is installed.") 9 | 10 | 11 | IMAGE_AUGMENTATION_SEQUENCE = None 12 | IMAGE_AUGMENTATION_NUM_TRIES = 10 13 | 14 | loaded_augmentation_name = "" 15 | 16 | 17 | def _load_augmentation_aug_geometric(): 18 | return iaa.OneOf([ 19 | iaa.Sequential([iaa.Fliplr(0.5), iaa.Flipud(0.2)]), 20 | iaa.CropAndPad(percent=(-0.05, 0.1), 21 | pad_mode='constant', 22 | pad_cval=(0, 255)), 23 | iaa.Crop(percent=(0.0, 0.1)), 24 | iaa.Crop(percent=(0.3, 0.5)), 25 | iaa.Crop(percent=(0.3, 0.5)), 26 | iaa.Crop(percent=(0.3, 0.5)), 27 | iaa.Sequential([ 28 | iaa.Affine( 29 | # scale images to 80-120% of their size, 30 | # individually per axis 31 | scale={"x": (0.8, 1.2), "y": (0.8, 1.2)}, 32 | # translate by -20 to +20 percent (per axis) 33 | translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)}, 34 | rotate=(-45, 45), # rotate by -45 to +45 degrees 35 | shear=(-16, 16), # shear by -16 to +16 degrees 36 | # use nearest neighbour or bilinear interpolation (fast) 37 | order=[0, 1], 38 | # if mode is constant, use a cval between 0 and 255 39 | mode='constant', 40 | cval=(0, 255), 41 | # use any of scikit-image's warping modes 42 | # (see 2nd image from the top for examples) 43 | ), 44 | iaa.Sometimes(0.3, iaa.Crop(percent=(0.3, 0.5)))]) 45 | ]) 46 | 47 | 48 | def _load_augmentation_aug_non_geometric(): 49 | return iaa.Sequential([ 50 | iaa.Sometimes(0.3, iaa.Multiply((0.5, 1.5), per_channel=0.5)), 51 | iaa.Sometimes(0.2, iaa.JpegCompression(compression=(70, 99))), 52 | iaa.Sometimes(0.2, iaa.GaussianBlur(sigma=(0, 3.0))), 53 | iaa.Sometimes(0.2, iaa.MotionBlur(k=15, angle=[-45, 45])), 54 | iaa.Sometimes(0.2, iaa.MultiplyHue((0.5, 1.5))), 55 | iaa.Sometimes(0.2, iaa.MultiplySaturation((0.5, 1.5))), 56 | iaa.Sometimes(0.34, iaa.MultiplyHueAndSaturation((0.5, 1.5), 57 | per_channel=True)), 58 | iaa.Sometimes(0.34, iaa.Grayscale(alpha=(0.0, 1.0))), 59 | iaa.Sometimes(0.2, iaa.ChangeColorTemperature((1100, 10000))), 60 | iaa.Sometimes(0.1, iaa.GammaContrast((0.5, 2.0))), 61 | iaa.Sometimes(0.2, iaa.SigmoidContrast(gain=(3, 10), 62 | cutoff=(0.4, 0.6))), 63 | iaa.Sometimes(0.1, iaa.CLAHE()), 64 | iaa.Sometimes(0.1, iaa.HistogramEqualization()), 65 | iaa.Sometimes(0.2, iaa.LinearContrast((0.5, 2.0), per_channel=0.5)), 66 | iaa.Sometimes(0.1, iaa.Emboss(alpha=(0, 1.0), strength=(0, 2.0))) 67 | ]) 68 | 69 | 70 | def _load_augmentation_aug_all2(): 71 | return iaa.Sequential([ 72 | iaa.Sometimes(0.65, _load_augmentation_aug_non_geometric()), 73 | iaa.Sometimes(0.65, _load_augmentation_aug_geometric()) 74 | ]) 75 | 76 | 77 | def _load_augmentation_aug_all(): 78 | """ Load image augmentation model """ 79 | 80 | def sometimes(aug): 81 | return iaa.Sometimes(0.5, aug) 82 | 83 | return iaa.Sequential( 84 | [ 85 | # apply the following augmenters to most images 86 
| iaa.Fliplr(0.5), # horizontally flip 50% of all images 87 | iaa.Flipud(0.2), # vertically flip 20% of all images 88 | # crop images by -5% to 10% of their height/width 89 | sometimes(iaa.CropAndPad( 90 | percent=(-0.05, 0.1), 91 | pad_mode='constant', 92 | pad_cval=(0, 255) 93 | )), 94 | sometimes(iaa.Affine( 95 | # scale images to 80-120% of their size, individually per axis 96 | scale={"x": (0.8, 1.2), "y": (0.8, 1.2)}, 97 | # translate by -20 to +20 percent (per axis) 98 | translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)}, 99 | rotate=(-45, 45), # rotate by -45 to +45 degrees 100 | shear=(-16, 16), # shear by -16 to +16 degrees 101 | # use nearest neighbour or bilinear interpolation (fast) 102 | order=[0, 1], 103 | # if mode is constant, use a cval between 0 and 255 104 | cval=(0, 255), 105 | # use any of scikit-image's warping modes 106 | # (see 2nd image from the top for examples) 107 | mode='constant' 108 | )), 109 | # execute 0 to 5 of the following (less important) augmenters per 110 | # image don't execute all of them, as that would often be way too 111 | # strong 112 | iaa.SomeOf((0, 5), 113 | [ 114 | # convert images into their superpixel representation 115 | sometimes(iaa.Superpixels( 116 | p_replace=(0, 1.0), n_segments=(20, 200))), 117 | iaa.OneOf([ 118 | # blur images with a sigma between 0 and 3.0 119 | iaa.GaussianBlur((0, 3.0)), 120 | # blur image using local means with kernel sizes 121 | # between 2 and 7 122 | iaa.AverageBlur(k=(2, 7)), 123 | # blur image using local medians with kernel sizes 124 | # between 2 and 7 125 | iaa.MedianBlur(k=(3, 11)), 126 | ]), 127 | iaa.Sharpen(alpha=(0, 1.0), lightness=( 128 | 0.75, 1.5)), # sharpen images 129 | iaa.Emboss(alpha=(0, 1.0), strength=( 130 | 0, 2.0)), # emboss images 131 | # search either for all edges or for directed edges, 132 | # blend the result with the original image using a blobby mask 133 | iaa.BlendAlphaSimplexNoise(iaa.OneOf([ 134 | iaa.EdgeDetect(alpha=(0.5, 1.0)), 135 | iaa.DirectedEdgeDetect( 136 | alpha=(0.5, 1.0), direction=(0.0, 1.0)), 137 | ])), 138 | # add gaussian noise to images 139 | iaa.AdditiveGaussianNoise(loc=0, scale=( 140 | 0.0, 0.05*255), per_channel=0.5), 141 | iaa.OneOf([ 142 | # randomly remove up to 10% of the pixels 143 | iaa.Dropout((0.01, 0.1), per_channel=0.5), 144 | iaa.CoarseDropout((0.03, 0.15), size_percent=( 145 | 0.02, 0.05), per_channel=0.2), 146 | ]), 147 | # invert color channels 148 | iaa.Invert(0.05, per_channel=True), 149 | # change brightness of images (by -10 to 10 of original value) 150 | iaa.Add((-10, 10), per_channel=0.5), 151 | # change hue and saturation 152 | iaa.AddToHueAndSaturation((-20, 20)), 153 | # either change the brightness of the whole image (sometimes 154 | # per channel) or change the brightness of subareas 155 | iaa.OneOf([ 156 | iaa.Multiply( 157 | (0.5, 1.5), per_channel=0.5), 158 | iaa.BlendAlphaFrequencyNoise( 159 | exponent=(-4, 0), 160 | foreground=iaa.Multiply( 161 | (0.5, 1.5), per_channel=True), 162 | background=iaa.contrast.LinearContrast( 163 | (0.5, 2.0)) 164 | ) 165 | ]), 166 | # improve or worsen the contrast 167 | iaa.contrast.LinearContrast((0.5, 2.0), per_channel=0.5), 168 | iaa.Grayscale(alpha=(0.0, 1.0)), 169 | # move pixels locally around (with random strengths) 170 | sometimes(iaa.ElasticTransformation( 171 | alpha=(0.5, 3.5), sigma=0.25)), 172 | # sometimes move parts of the image around 173 | sometimes(iaa.PiecewiseAffine(scale=(0.01, 0.05))), 174 | sometimes(iaa.PerspectiveTransform(scale=(0.01, 0.1))) 175 | ], 176 | 
random_order=True 177 | ) 178 | ], 179 | random_order=True 180 | ) 181 | 182 | 183 | augmentation_functions = { 184 | "aug_all": _load_augmentation_aug_all, 185 | "aug_all2": _load_augmentation_aug_all2, 186 | "aug_geometric": _load_augmentation_aug_geometric, 187 | "aug_non_geometric": _load_augmentation_aug_non_geometric 188 | } 189 | 190 | 191 | def _load_augmentation(augmentation_name="aug_all"): 192 | 193 | global IMAGE_AUGMENTATION_SEQUENCE 194 | 195 | if augmentation_name not in augmentation_functions: 196 | raise ValueError("Augmentation name not supported") 197 | 198 | IMAGE_AUGMENTATION_SEQUENCE = augmentation_functions[augmentation_name]() 199 | 200 | 201 | def _augment_seg(img, seg, augmentation_name="aug_all", other_imgs=None): 202 | 203 | global loaded_augmentation_name 204 | 205 | if (not IMAGE_AUGMENTATION_SEQUENCE) or\ 206 | (augmentation_name != loaded_augmentation_name): 207 | _load_augmentation(augmentation_name) 208 | loaded_augmentation_name = augmentation_name 209 | 210 | # Create a deterministic augmentation from the random one 211 | aug_det = IMAGE_AUGMENTATION_SEQUENCE.to_deterministic() 212 | # Augment the input image 213 | image_aug = aug_det.augment_image(img) 214 | 215 | if other_imgs is not None: 216 | image_aug = [image_aug] 217 | 218 | for other_img in other_imgs: 219 | image_aug.append(aug_det.augment_image(other_img)) 220 | 221 | segmap = ia.SegmentationMapsOnImage( 222 | seg, shape=img.shape) 223 | segmap_aug = aug_det.augment_segmentation_maps(segmap) 224 | segmap_aug = segmap_aug.get_arr() 225 | 226 | return image_aug, segmap_aug 227 | 228 | 229 | def _custom_augment_seg(img, seg, augmentation_function, other_imgs=None): 230 | augmentation_functions['custom_aug'] = augmentation_function 231 | 232 | return _augment_seg(img, seg, "custom_aug", other_imgs=other_imgs) 233 | 234 | 235 | def _try_n_times(fn, n, *args, **kargs): 236 | """ Try a function N times """ 237 | attempts = 0 238 | while attempts < n: 239 | try: 240 | return fn(*args, **kargs) 241 | except Exception: 242 | attempts += 1 243 | 244 | return fn(*args, **kargs) 245 | 246 | 247 | def augment_seg(img, seg, augmentation_name="aug_all", other_imgs=None): 248 | return _try_n_times(_augment_seg, IMAGE_AUGMENTATION_NUM_TRIES, 249 | img, seg, augmentation_name=augmentation_name, 250 | other_imgs=other_imgs) 251 | 252 | 253 | def custom_augment_seg(img, seg, augmentation_function, other_imgs=None): 254 | return _try_n_times(_custom_augment_seg, IMAGE_AUGMENTATION_NUM_TRIES, 255 | img, seg, augmentation_function=augmentation_function, 256 | other_imgs=other_imgs) -------------------------------------------------------------------------------- /keras_segmentation/data_utils/data_loader.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import os 3 | import random 4 | import six 5 | import numpy as np 6 | import cv2 7 | 8 | try: 9 | from collections.abc import Sequence 10 | except ImportError: 11 | from collections import Sequence 12 | 13 | try: 14 | from tqdm import tqdm 15 | except ImportError: 16 | print("tqdm not found, disabling progress bars") 17 | 18 | def tqdm(iter): 19 | return iter 20 | 21 | 22 | from ..models.config import IMAGE_ORDERING 23 | from .augmentation import augment_seg, custom_augment_seg 24 | 25 | DATA_LOADER_SEED = 0 26 | 27 | random.seed(DATA_LOADER_SEED) 28 | class_colors = [(random.randint(0, 255), random.randint( 29 | 0, 255), random.randint(0, 255)) for _ in range(5000)] 30 | 31 | 32 | ACCEPTABLE_IMAGE_FORMATS = 
[".jpg", ".jpeg", ".png", ".bmp"] 33 | ACCEPTABLE_SEGMENTATION_FORMATS = [".png", ".bmp"] 34 | 35 | 36 | class DataLoaderError(Exception): 37 | pass 38 | 39 | 40 | 41 | def get_image_list_from_path(images_path ): 42 | image_files = [] 43 | for dir_entry in os.listdir(images_path): 44 | if os.path.isfile(os.path.join(images_path, dir_entry)) and \ 45 | os.path.splitext(dir_entry)[1] in ACCEPTABLE_IMAGE_FORMATS: 46 | file_name, file_extension = os.path.splitext(dir_entry) 47 | image_files.append(os.path.join(images_path, dir_entry)) 48 | return image_files 49 | 50 | 51 | def get_pairs_from_paths(images_path, segs_path, ignore_non_matching=False, other_inputs_paths=None): 52 | """ Find all the images from the images_path directory and 53 | the segmentation images from the segs_path directory 54 | while checking integrity of data """ 55 | 56 | 57 | 58 | image_files = [] 59 | segmentation_files = {} 60 | 61 | for dir_entry in os.listdir(images_path): 62 | if os.path.isfile(os.path.join(images_path, dir_entry)) and \ 63 | os.path.splitext(dir_entry)[1] in ACCEPTABLE_IMAGE_FORMATS: 64 | file_name, file_extension = os.path.splitext(dir_entry) 65 | image_files.append((file_name, file_extension, 66 | os.path.join(images_path, dir_entry))) 67 | 68 | if other_inputs_paths is not None: 69 | other_inputs_files = [] 70 | 71 | for i, other_inputs_path in enumerate(other_inputs_paths): 72 | temp = [] 73 | 74 | for y, dir_entry in enumerate(os.listdir(other_inputs_path)): 75 | if os.path.isfile(os.path.join(other_inputs_path, dir_entry)) and \ 76 | os.path.splitext(dir_entry)[1] in ACCEPTABLE_IMAGE_FORMATS: 77 | file_name, file_extension = os.path.splitext(dir_entry) 78 | 79 | temp.append((file_name, file_extension, 80 | os.path.join(other_inputs_path, dir_entry))) 81 | 82 | other_inputs_files.append(temp) 83 | 84 | for dir_entry in os.listdir(segs_path): 85 | if os.path.isfile(os.path.join(segs_path, dir_entry)) and \ 86 | os.path.splitext(dir_entry)[1] in ACCEPTABLE_SEGMENTATION_FORMATS: 87 | file_name, file_extension = os.path.splitext(dir_entry) 88 | full_dir_entry = os.path.join(segs_path, dir_entry) 89 | if file_name in segmentation_files: 90 | raise DataLoaderError("Segmentation file with filename {0}" 91 | " already exists and is ambiguous to" 92 | " resolve with path {1}." 93 | " Please remove or rename the latter." 94 | .format(file_name, full_dir_entry)) 95 | 96 | segmentation_files[file_name] = (file_extension, full_dir_entry) 97 | 98 | return_value = [] 99 | # Match the images and segmentations 100 | for image_file, _, image_full_path in image_files: 101 | if image_file in segmentation_files: 102 | if other_inputs_paths is not None: 103 | other_inputs = [] 104 | for file_paths in other_inputs_files: 105 | success = False 106 | 107 | for (other_file, _, other_full_path) in file_paths: 108 | if image_file == other_file: 109 | other_inputs.append(other_full_path) 110 | success = True 111 | break 112 | 113 | if not success: 114 | raise ValueError("There was no matching other input to", image_file, "in directory") 115 | 116 | return_value.append((image_full_path, 117 | segmentation_files[image_file][1], other_inputs)) 118 | else: 119 | return_value.append((image_full_path, 120 | segmentation_files[image_file][1])) 121 | elif ignore_non_matching: 122 | continue 123 | else: 124 | # Error out 125 | raise DataLoaderError("No corresponding segmentation " 126 | "found for image {0}." 
127 | .format(image_full_path)) 128 | 129 | return return_value 130 | 131 | 132 | def get_image_array(image_input, 133 | width, height, 134 | imgNorm="sub_mean", ordering='channels_first', read_image_type=1): 135 | """ Load image array from input """ 136 | 137 | if type(image_input) is np.ndarray: 138 | # It is already an array, use it as it is 139 | img = image_input 140 | elif isinstance(image_input, six.string_types): 141 | if not os.path.isfile(image_input): 142 | raise DataLoaderError("get_image_array: path {0} doesn't exist" 143 | .format(image_input)) 144 | img = cv2.imread(image_input, read_image_type) 145 | else: 146 | raise DataLoaderError("get_image_array: Can't process input type {0}" 147 | .format(str(type(image_input)))) 148 | 149 | if imgNorm == "sub_and_divide": 150 | img = np.float32(cv2.resize(img, (width, height))) / 127.5 - 1 151 | elif imgNorm == "sub_mean": 152 | img = cv2.resize(img, (width, height)) 153 | img = img.astype(np.float32) 154 | img = np.atleast_3d(img) 155 | 156 | means = [103.939, 116.779, 123.68] 157 | 158 | for i in range(min(img.shape[2], len(means))): 159 | img[:, :, i] -= means[i] 160 | 161 | img = img[:, :, ::-1] 162 | elif imgNorm == "divide": 163 | img = cv2.resize(img, (width, height)) 164 | img = img.astype(np.float32) 165 | img = img/255.0 166 | 167 | if ordering == 'channels_first': 168 | img = np.rollaxis(img, 2, 0) 169 | return img 170 | 171 | 172 | def get_segmentation_array(image_input, nClasses, 173 | width, height, no_reshape=False, read_image_type=1): 174 | """ Load segmentation array from input """ 175 | 176 | seg_labels = np.zeros((height, width, nClasses)) 177 | 178 | if type(image_input) is np.ndarray: 179 | # It is already an array, use it as it is 180 | img = image_input 181 | elif isinstance(image_input, six.string_types): 182 | if not os.path.isfile(image_input): 183 | raise DataLoaderError("get_segmentation_array: " 184 | "path {0} doesn't exist".format(image_input)) 185 | img = cv2.imread(image_input, read_image_type) 186 | else: 187 | raise DataLoaderError("get_segmentation_array: " 188 | "Can't process input type {0}" 189 | .format(str(type(image_input)))) 190 | 191 | img = cv2.resize(img, (width, height), interpolation=cv2.INTER_NEAREST) 192 | img = img[:, :, 0] 193 | 194 | for c in range(nClasses): 195 | seg_labels[:, :, c] = (img == c).astype(int) 196 | 197 | if not no_reshape: 198 | seg_labels = np.reshape(seg_labels, (width*height, nClasses)) 199 | 200 | return seg_labels 201 | 202 | 203 | def verify_segmentation_dataset(images_path, segs_path, 204 | n_classes, show_all_errors=False): 205 | try: 206 | img_seg_pairs = get_pairs_from_paths(images_path, segs_path) 207 | if not len(img_seg_pairs): 208 | print("Couldn't load any data from images_path: " 209 | "{0} and segmentations path: {1}" 210 | .format(images_path, segs_path)) 211 | return False 212 | 213 | return_value = True 214 | for im_fn, seg_fn in tqdm(img_seg_pairs): 215 | img = cv2.imread(im_fn) 216 | seg = cv2.imread(seg_fn) 217 | # Check dimensions match 218 | if not img.shape == seg.shape: 219 | return_value = False 220 | print("The size of image {0} and its segmentation {1} " 221 | "doesn't match (possibly the files are corrupt)." 222 | .format(im_fn, seg_fn)) 223 | if not show_all_errors: 224 | break 225 | else: 226 | max_pixel_value = np.max(seg[:, :, 0]) 227 | if max_pixel_value >= n_classes: 228 | return_value = False 229 | print("The pixel values of the segmentation image {0} " 230 | "violating range [0, {1}]. 
" 231 | "Found maximum pixel value {2}" 232 | .format(seg_fn, str(n_classes - 1), max_pixel_value)) 233 | if not show_all_errors: 234 | break 235 | if return_value: 236 | print("Dataset verified! ") 237 | else: 238 | print("Dataset not verified!") 239 | return return_value 240 | except DataLoaderError as e: 241 | print("Found error during data loading\n{0}".format(str(e))) 242 | return False 243 | 244 | 245 | def image_segmentation_generator(images_path, segs_path, batch_size, 246 | n_classes, input_height, input_width, 247 | output_height, output_width, 248 | do_augment=False, 249 | augmentation_name="aug_all", 250 | custom_augmentation=None, 251 | other_inputs_paths=None, preprocessing=None, 252 | read_image_type=cv2.IMREAD_COLOR , ignore_segs=False ): 253 | 254 | 255 | if not ignore_segs: 256 | img_seg_pairs = get_pairs_from_paths(images_path, segs_path, other_inputs_paths=other_inputs_paths) 257 | random.shuffle(img_seg_pairs) 258 | zipped = itertools.cycle(img_seg_pairs) 259 | else: 260 | img_list = get_image_list_from_path( images_path ) 261 | random.shuffle( img_list ) 262 | img_list_gen = itertools.cycle( img_list ) 263 | 264 | 265 | while True: 266 | X = [] 267 | Y = [] 268 | for _ in range(batch_size): 269 | if other_inputs_paths is None: 270 | 271 | if ignore_segs: 272 | im = next( img_list_gen ) 273 | seg = None 274 | else: 275 | im, seg = next(zipped) 276 | seg = cv2.imread(seg, 1) 277 | 278 | im = cv2.imread(im, read_image_type) 279 | 280 | 281 | if do_augment: 282 | 283 | assert ignore_segs == False , "Not supported yet" 284 | 285 | if custom_augmentation is None: 286 | im, seg[:, :, 0] = augment_seg(im, seg[:, :, 0], 287 | augmentation_name) 288 | else: 289 | im, seg[:, :, 0] = custom_augment_seg(im, seg[:, :, 0], 290 | custom_augmentation) 291 | 292 | if preprocessing is not None: 293 | im = preprocessing(im) 294 | 295 | X.append(get_image_array(im, input_width, 296 | input_height, ordering=IMAGE_ORDERING)) 297 | else: 298 | 299 | assert ignore_segs == False , "Not supported yet" 300 | 301 | im, seg, others = next(zipped) 302 | 303 | im = cv2.imread(im, read_image_type) 304 | seg = cv2.imread(seg, 1) 305 | 306 | oth = [] 307 | for f in others: 308 | oth.append(cv2.imread(f, read_image_type)) 309 | 310 | if do_augment: 311 | if custom_augmentation is None: 312 | ims, seg[:, :, 0] = augment_seg(im, seg[:, :, 0], 313 | augmentation_name, other_imgs=oth) 314 | else: 315 | ims, seg[:, :, 0] = custom_augment_seg(im, seg[:, :, 0], 316 | custom_augmentation, other_imgs=oth) 317 | else: 318 | ims = [im] 319 | ims.extend(oth) 320 | 321 | oth = [] 322 | for i, image in enumerate(ims): 323 | oth_im = get_image_array(image, input_width, 324 | input_height, ordering=IMAGE_ORDERING) 325 | 326 | if preprocessing is not None: 327 | if isinstance(preprocessing, Sequence): 328 | oth_im = preprocessing[i](oth_im) 329 | else: 330 | oth_im = preprocessing(oth_im) 331 | 332 | oth.append(oth_im) 333 | 334 | X.append(oth) 335 | 336 | if not ignore_segs: 337 | Y.append(get_segmentation_array( 338 | seg, n_classes, output_width, output_height)) 339 | 340 | if ignore_segs: 341 | yield np.array(X) 342 | else: 343 | yield np.array(X), np.array(Y) 344 | -------------------------------------------------------------------------------- /keras_segmentation/data_utils/visualize_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import random 4 | 5 | import numpy as np 6 | import cv2 7 | 8 | from .augmentation import augment_seg, 
custom_augment_seg 9 | from .data_loader import \ 10 | get_pairs_from_paths, DATA_LOADER_SEED, class_colors, DataLoaderError 11 | 12 | random.seed(DATA_LOADER_SEED) 13 | 14 | 15 | def _get_colored_segmentation_image(img, seg, colors, 16 | n_classes, do_augment=False, augment_name='aug_all', custom_aug=None): 17 | """ Return a colored segmented image """ 18 | seg_img = np.zeros_like(seg) 19 | 20 | if do_augment: 21 | if custom_aug is not None: 22 | img, seg[:, :, 0] = custom_augment_seg(img, seg[:, :, 0], augmentation_function=custom_aug) 23 | else: 24 | img, seg[:, :, 0] = augment_seg(img, seg[:, :, 0], augmentation_name=augment_name) 25 | 26 | for c in range(n_classes): 27 | seg_img[:, :, 0] += ((seg[:, :, 0] == c) 28 | * (colors[c][0])).astype('uint8') 29 | seg_img[:, :, 1] += ((seg[:, :, 0] == c) 30 | * (colors[c][1])).astype('uint8') 31 | seg_img[:, :, 2] += ((seg[:, :, 0] == c) 32 | * (colors[c][2])).astype('uint8') 33 | 34 | return img, seg_img 35 | 36 | 37 | def visualize_segmentation_dataset(images_path, segs_path, n_classes, 38 | do_augment=False, ignore_non_matching=False, 39 | no_show=False, image_size=None, augment_name="aug_all", custom_aug=None): 40 | try: 41 | # Get image-segmentation pairs 42 | img_seg_pairs = get_pairs_from_paths( 43 | images_path, segs_path, 44 | ignore_non_matching=ignore_non_matching) 45 | 46 | # Get the colors for the classes 47 | colors = class_colors 48 | 49 | print("Please press any key to display the next image") 50 | for im_fn, seg_fn in img_seg_pairs: 51 | img = cv2.imread(im_fn) 52 | seg = cv2.imread(seg_fn) 53 | print("Found the following classes in the segmentation image:", 54 | np.unique(seg)) 55 | img, seg_img = _get_colored_segmentation_image( 56 | img, seg, colors, 57 | n_classes, 58 | do_augment=do_augment, augment_name=augment_name, custom_aug=custom_aug) 59 | 60 | if image_size is not None: 61 | img = cv2.resize(img, image_size) 62 | seg_img = cv2.resize(seg_img, image_size) 63 | 64 | print("Please press any key to display the next image") 65 | cv2.imshow("img", img) 66 | cv2.imshow("seg_img", seg_img) 67 | cv2.waitKey() 68 | except DataLoaderError as e: 69 | print("Found error during data loading\n{0}".format(str(e))) 70 | return False 71 | 72 | 73 | def visualize_segmentation_dataset_one(images_path, segs_path, n_classes, 74 | do_augment=False, no_show=False, 75 | ignore_non_matching=False): 76 | 77 | img_seg_pairs = get_pairs_from_paths( 78 | images_path, segs_path, 79 | ignore_non_matching=ignore_non_matching) 80 | 81 | colors = class_colors 82 | 83 | im_fn, seg_fn = random.choice(img_seg_pairs) 84 | 85 | img = cv2.imread(im_fn) 86 | seg = cv2.imread(seg_fn) 87 | print("Found the following classes " 88 | "in the segmentation image:", np.unique(seg)) 89 | 90 | img, seg_img = _get_colored_segmentation_image( 91 | img, seg, colors, 92 | n_classes, do_augment=do_augment) 93 | 94 | if not no_show: 95 | cv2.imshow("img", img) 96 | cv2.imshow("seg_img", seg_img) 97 | cv2.waitKey() 98 | 99 | return img, seg_img 100 | 101 | 102 | if __name__ == "__main__": 103 | import argparse 104 | parser = argparse.ArgumentParser() 105 | parser.add_argument("--images", type=str) 106 | parser.add_argument("--annotations", type=str) 107 | parser.add_argument("--n_classes", type=int) 108 | args = parser.parse_args() 109 | 110 | visualize_segmentation_dataset( 111 | args.images, args.annotations, args.n_classes) 112 | -------------------------------------------------------------------------------- /keras_segmentation/metrics.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | EPS = 1e-12 4 | 5 | 6 | def get_iou(gt, pr, n_classes): 7 | class_wise = np.zeros(n_classes) 8 | for cl in range(n_classes): 9 | intersection = np.sum((gt == cl)*(pr == cl)) 10 | union = np.sum(np.maximum((gt == cl), (pr == cl))) 11 | iou = float(intersection)/(union + EPS) 12 | class_wise[cl] = iou 13 | return class_wise 14 | -------------------------------------------------------------------------------- /keras_segmentation/model_compression.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import tensorflow as tf 3 | 4 | from tqdm import tqdm 5 | import numpy as np 6 | import six 7 | import os 8 | import json 9 | import sys 10 | 11 | from .data_utils.data_loader import image_segmentation_generator 12 | from .train import CheckpointsCallback 13 | 14 | from keras.models import Model 15 | 16 | 17 | 18 | def get_pariwise_similarities( feats ): 19 | feats_i = tf.reshape( feats , (-1 , 1 , feats.shape[1]*feats.shape[2] , feats.shape[3])) 20 | feats_j = tf.reshape( feats , (-1 , feats.shape[1]*feats.shape[2] , 1 , feats.shape[3])) 21 | 22 | feats_i = feats_i / (( tf.reduce_sum(feats_i**2 , axis=-1 ) )**(0.5))[ ... , None ] 23 | feats_j = feats_j / (( tf.reduce_sum(feats_j**2 , axis=-1 ) )**(0.5))[ ... , None ] 24 | 25 | feats_ixj = feats_i*feats_j 26 | 27 | return tf.reduce_sum( feats_ixj , axis=-1 ) 28 | 29 | 30 | 31 | def pairwise_dist_loss( feats_t , feats_s ): 32 | 33 | # todo max POOL 34 | pool_factor = 4 35 | 36 | feats_t = tf.nn.max_pool(feats_t , (pool_factor,pool_factor) , strides=(pool_factor,pool_factor) , padding="VALID" ) 37 | feats_s = tf.nn.max_pool(feats_s , (pool_factor,pool_factor) , strides=(pool_factor,pool_factor) , padding="VALID" ) 38 | 39 | sims_t = get_pariwise_similarities( feats_t ) 40 | sims_s = get_pariwise_similarities( feats_s ) 41 | n_pixs = sims_s.shape[1] 42 | 43 | return tf.reduce_sum(tf.reduce_sum(((sims_t - sims_s )**2 ) , axis=1), axis=1)/(n_pixs**2 ) 44 | 45 | 46 | 47 | class Distiller(keras.Model): 48 | def __init__(self, student, teacher , distilation_loss , feats_distilation_loss=None , feats_distilation_loss_w=0.1 ): 49 | super(Distiller, self).__init__() 50 | self.teacher = teacher 51 | self.student = student 52 | self.distilation_loss = distilation_loss 53 | 54 | self.feats_distilation_loss = feats_distilation_loss 55 | self.feats_distilation_loss_w = feats_distilation_loss_w 56 | 57 | if not feats_distilation_loss is None: 58 | try: 59 | s_feat_out = student.get_layer("seg_feats").output 60 | except: 61 | s_feat_out = student.get_layer(student.seg_feats_layer_name ).output 62 | 63 | 64 | try: 65 | t_feat_out = teacher.get_layer("seg_feats").output 66 | except: 67 | t_feat_out = teacher.get_layer(teacher.seg_feats_layer_name ).output 68 | 69 | 70 | self.student_feat_model = Model( student.input , s_feat_out ) 71 | self.teacher_feat_model = Model( teacher.input , t_feat_out ) 72 | 73 | def compile( 74 | self, 75 | optimizer, 76 | metrics, 77 | 78 | ): 79 | super(Distiller, self).compile(optimizer=optimizer, metrics=metrics) 80 | 81 | 82 | def train_step(self, data): 83 | teacher_input , = data 84 | 85 | student_input = tf.image.resize( teacher_input , ( self.student.input_height , self.student.input_width ) ) 86 | 87 | teacher_predictions = self.teacher(teacher_input, training=False) 88 | teacher_predictions_reshape = tf.reshape(teacher_predictions , ((-1 , self.teacher.output_height , 
self.teacher.output_width , self.teacher.output_shape[-1]))) 89 | 90 | 91 | if not self.feats_distilation_loss is None: 92 | teacher_feats = self.teacher_feat_model(teacher_input, training=False) 93 | 94 | with tf.GradientTape() as tape: 95 | student_predictions = self.student( student_input , training=True) 96 | student_predictions_resize = tf.reshape(student_predictions , ((-1, self.student.output_height , self.student.output_width , self.student.output_shape[-1]))) 97 | student_predictions_resize = tf.image.resize( student_predictions_resize , ( self.teacher.output_height , self.teacher.output_width ) ) 98 | 99 | loss = self.distilation_loss( teacher_predictions_reshape , student_predictions_resize ) 100 | 101 | if not self.feats_distilation_loss is None: 102 | student_feats = self.student_feat_model( student_input , training=True) 103 | student_feats_resize = tf.image.resize( student_feats , ( teacher_feats.shape[1] , teacher_feats.shape[2] ) ) 104 | loss += self.feats_distilation_loss_w*self.feats_distilation_loss( teacher_feats , student_feats_resize ) 105 | 106 | 107 | 108 | 109 | trainable_vars = self.student.trainable_variables 110 | gradients = tape.gradient(loss, trainable_vars) 111 | 112 | self.optimizer.apply_gradients(zip(gradients, trainable_vars)) 113 | self.compiled_metrics.update_state(teacher_predictions_reshape , student_predictions_resize ) 114 | 115 | 116 | results = {m.name: m.result() for m in self.metrics} 117 | results.update( 118 | { "distillation_loss": loss} 119 | ) 120 | return results 121 | 122 | 123 | # created a simple custom fit generator due to some issue in keras 124 | def fit_generator_custom( model , gen , epochs , steps_per_epoch , callback=None ): 125 | for ep in range( epochs ): 126 | print("Epoch %d/%d"%(ep+1 , epochs )) 127 | bar = tqdm( range(steps_per_epoch)) 128 | losses = [ ] 129 | for i in bar: 130 | x = next( gen ) 131 | l = model.train_on_batch( x ) 132 | losses.append( l ) 133 | bar.set_description("Loss : %s"%str(np.mean( np.array(losses) ))) 134 | if not callback is None: 135 | callback.model = model.student 136 | callback.on_epoch_end( ep ) 137 | 138 | 139 | def perform_distilation(teacher_model ,student_model, data_path , distilation_loss='kl' , 140 | batch_size = 6 ,checkpoints_path=None , epochs = 32 , steps_per_epoch=512, 141 | feats_distilation_loss=None , feats_distilation_loss_w=0.1 ): 142 | 143 | 144 | losses_dict = { 'l1':keras.losses.MeanAbsoluteError() , "l2": keras.losses.MeanSquaredError() , "kl":keras.losses.KLDivergence() , 'pa':pairwise_dist_loss } 145 | 146 | if isinstance( distilation_loss , six.string_types): 147 | distilation_loss = losses_dict[ distilation_loss ] 148 | 149 | if isinstance( feats_distilation_loss , six.string_types): 150 | feats_distilation_loss = losses_dict[ feats_distilation_loss ] 151 | 152 | 153 | distill_model = Distiller( student=student_model , teacher=teacher_model , distilation_loss=distilation_loss, feats_distilation_loss=feats_distilation_loss , feats_distilation_loss_w=feats_distilation_loss_w ) 154 | 155 | img_gen = image_segmentation_generator(images_path=data_path , segs_path=None, batch_size=batch_size, 156 | n_classes=teacher_model.n_classes , input_height=teacher_model.input_height, input_width=teacher_model.input_width, 157 | output_height=None, output_width=None , ignore_segs=True) 158 | 159 | distill_model.compile( 160 | optimizer='adam', 161 | metrics=[ distilation_loss ] 162 | ) 163 | 164 | if checkpoints_path is not None: 165 | config_file = checkpoints_path + "_config.json" 
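# NOTE: the JSON written below follows the same "<checkpoints_path>_config.json"
# convention used elsewhere in this package for training checkpoints, so the
# distilled student can presumably be rebuilt later from checkpoints_path with
# the package's usual checkpoint-loading helpers.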
166 | dir_name = os.path.dirname(config_file)
167 | 
168 | if not os.path.exists(dir_name):
169 | os.makedirs(dir_name)
170 | 
171 | with open(config_file, "w") as f:
172 | json.dump({
173 | "model_class": student_model.model_name,
174 | "n_classes": student_model.n_classes,
175 | "input_height": student_model.input_height,
176 | "input_width": student_model.input_width,
177 | "output_height": student_model.output_height,
178 | "output_width": student_model.output_width
179 | }, f)
180 | 
181 | cb = CheckpointsCallback(checkpoints_path)
182 | else:
183 | cb = None
184 | 
185 | fit_generator_custom(distill_model, img_gen, steps_per_epoch=steps_per_epoch, epochs=epochs, callback=cb)
186 | 
187 | print("done")
188 | 
189 | 
190 | 
--------------------------------------------------------------------------------
/keras_segmentation/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/keras_segmentation/models/__init__.py
--------------------------------------------------------------------------------
/keras_segmentation/models/_pspnet_2.py:
--------------------------------------------------------------------------------
1 | # This code was provided by Vladkryvoruchko; small modifications were made by me.
2 | 
3 | from math import ceil
4 | from sys import exit
5 | from keras import layers
6 | from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
7 | from keras.layers import BatchNormalization, Activation,\
8 | Input, Dropout, ZeroPadding2D
9 | from keras.layers import Concatenate, Add
10 | import tensorflow as tf
11 | 
12 | from .config import IMAGE_ORDERING
13 | from .model_utils import get_segmentation_model
14 | 
15 | 
16 | learning_rate = 1e-3 # Layer specific learning rate
17 | # Weight decay not implemented
18 | 
19 | 
20 | def BN(name=""):
21 | return BatchNormalization(momentum=0.95, name=name, epsilon=1e-5)
22 | 
23 | 
24 | class Interp(layers.Layer):
25 | 
26 | def __init__(self, new_size, **kwargs):
27 | self.new_size = new_size
28 | super(Interp, self).__init__(**kwargs)
29 | 
30 | def build(self, input_shape):
31 | super(Interp, self).build(input_shape)
32 | 
33 | def call(self, inputs, **kwargs):
34 | new_height, new_width = self.new_size
35 | try:
36 | resized = tf.image.resize(inputs, [new_height, new_width])
37 | except AttributeError:
38 | resized = tf.image.resize_images(inputs, [new_height, new_width],
39 | align_corners=True)
40 | return resized
41 | 
42 | def compute_output_shape(self, input_shape):
43 | return tuple([None,
44 | self.new_size[0],
45 | self.new_size[1],
46 | input_shape[3]])
47 | 
48 | def get_config(self):
49 | config = super(Interp, self).get_config()
50 | config['new_size'] = self.new_size
51 | return config
52 | 
53 | 
54 | # def Interp(x, shape):
55 | # new_height, new_width = shape
56 | # resized = tf.image.resize_images(x, [new_height, new_width],
57 | # align_corners=True)
58 | # return resized
59 | 
60 | 
61 | def residual_conv(prev, level, pad=1, lvl=1, sub_lvl=1, modify_stride=False):
62 | lvl = str(lvl)
63 | sub_lvl = str(sub_lvl)
64 | names = ["conv" + lvl + "_" + sub_lvl + "_1x1_reduce",
65 | "conv" + lvl + "_" + sub_lvl + "_1x1_reduce_bn",
66 | "conv" + lvl + "_" + sub_lvl + "_3x3",
67 | "conv" + lvl + "_" + sub_lvl + "_3x3_bn",
68 | "conv" + lvl + "_" + sub_lvl + "_1x1_increase",
69 | "conv" + lvl + "_" + sub_lvl + "_1x1_increase_bn"]
70 | if modify_stride is False:
71 | prev = Conv2D(64 * 
level, (1, 1), strides=(1, 1), name=names[0], 72 | use_bias=False)(prev) 73 | elif modify_stride is True: 74 | prev = Conv2D(64 * level, (1, 1), strides=(2, 2), name=names[0], 75 | use_bias=False)(prev) 76 | 77 | prev = BN(name=names[1])(prev) 78 | prev = Activation('relu')(prev) 79 | 80 | prev = ZeroPadding2D(padding=(pad, pad))(prev) 81 | prev = Conv2D(64 * level, (3, 3), strides=(1, 1), dilation_rate=pad, 82 | name=names[2], use_bias=False)(prev) 83 | 84 | prev = BN(name=names[3])(prev) 85 | prev = Activation('relu')(prev) 86 | prev = Conv2D(256 * level, (1, 1), strides=(1, 1), name=names[4], 87 | use_bias=False)(prev) 88 | prev = BN(name=names[5])(prev) 89 | return prev 90 | 91 | 92 | def short_convolution_branch(prev, level, lvl=1, sub_lvl=1, 93 | modify_stride=False): 94 | lvl = str(lvl) 95 | sub_lvl = str(sub_lvl) 96 | names = ["conv" + lvl + "_" + sub_lvl + "_1x1_proj", 97 | "conv" + lvl + "_" + sub_lvl + "_1x1_proj_bn"] 98 | 99 | if modify_stride is False: 100 | prev = Conv2D(256 * level, (1, 1), strides=(1, 1), name=names[0], 101 | use_bias=False)(prev) 102 | elif modify_stride is True: 103 | prev = Conv2D(256 * level, (1, 1), strides=(2, 2), name=names[0], 104 | use_bias=False)(prev) 105 | 106 | prev = BN(name=names[1])(prev) 107 | return prev 108 | 109 | 110 | def empty_branch(prev): 111 | return prev 112 | 113 | 114 | def residual_short(prev_layer, level, pad=1, lvl=1, sub_lvl=1, 115 | modify_stride=False): 116 | prev_layer = Activation('relu')(prev_layer) 117 | block_1 = residual_conv(prev_layer, level, 118 | pad=pad, lvl=lvl, sub_lvl=sub_lvl, 119 | modify_stride=modify_stride) 120 | 121 | block_2 = short_convolution_branch(prev_layer, level, 122 | lvl=lvl, sub_lvl=sub_lvl, 123 | modify_stride=modify_stride) 124 | added = Add()([block_1, block_2]) 125 | return added 126 | 127 | 128 | def residual_empty(prev_layer, level, pad=1, lvl=1, sub_lvl=1): 129 | prev_layer = Activation('relu')(prev_layer) 130 | 131 | block_1 = residual_conv(prev_layer, level, pad=pad, 132 | lvl=lvl, sub_lvl=sub_lvl) 133 | block_2 = empty_branch(prev_layer) 134 | added = Add()([block_1, block_2]) 135 | return added 136 | 137 | 138 | def ResNet(inp, layers): 139 | # Names for the first couple layers of model 140 | names = ["conv1_1_3x3_s2", 141 | "conv1_1_3x3_s2_bn", 142 | "conv1_2_3x3", 143 | "conv1_2_3x3_bn", 144 | "conv1_3_3x3", 145 | "conv1_3_3x3_bn"] 146 | 147 | # Short branch(only start of network) 148 | 149 | cnv1 = Conv2D(64, (3, 3), strides=(2, 2), padding='same', name=names[0], 150 | use_bias=False)(inp) # "conv1_1_3x3_s2" 151 | bn1 = BN(name=names[1])(cnv1) # "conv1_1_3x3_s2/bn" 152 | relu1 = Activation('relu')(bn1) # "conv1_1_3x3_s2/relu" 153 | 154 | cnv1 = Conv2D(64, (3, 3), strides=(1, 1), padding='same', name=names[2], 155 | use_bias=False)(relu1) # "conv1_2_3x3" 156 | bn1 = BN(name=names[3])(cnv1) # "conv1_2_3x3/bn" 157 | relu1 = Activation('relu')(bn1) # "conv1_2_3x3/relu" 158 | 159 | cnv1 = Conv2D(128, (3, 3), strides=(1, 1), padding='same', name=names[4], 160 | use_bias=False)(relu1) # "conv1_3_3x3" 161 | bn1 = BN(name=names[5])(cnv1) # "conv1_3_3x3/bn" 162 | relu1 = Activation('relu')(bn1) # "conv1_3_3x3/relu" 163 | 164 | res = MaxPooling2D(pool_size=(3, 3), padding='same', 165 | strides=(2, 2))(relu1) # "pool1_3x3_s2" 166 | 167 | # ---Residual layers(body of network) 168 | 169 | """ 170 | Modify_stride --Used only once in first 3_1 convolutions block. 
171 | changes stride of first convolution from 1 -> 2 172 | """ 173 | 174 | # 2_1- 2_3 175 | res = residual_short(res, 1, pad=1, lvl=2, sub_lvl=1) 176 | for i in range(2): 177 | res = residual_empty(res, 1, pad=1, lvl=2, sub_lvl=i + 2) 178 | 179 | # 3_1 - 3_3 180 | res = residual_short(res, 2, pad=1, lvl=3, sub_lvl=1, modify_stride=True) 181 | for i in range(3): 182 | res = residual_empty(res, 2, pad=1, lvl=3, sub_lvl=i + 2) 183 | if layers == 50: 184 | # 4_1 - 4_6 185 | res = residual_short(res, 4, pad=2, lvl=4, sub_lvl=1) 186 | for i in range(5): 187 | res = residual_empty(res, 4, pad=2, lvl=4, sub_lvl=i + 2) 188 | elif layers == 101: 189 | # 4_1 - 4_23 190 | res = residual_short(res, 4, pad=2, lvl=4, sub_lvl=1) 191 | for i in range(22): 192 | res = residual_empty(res, 4, pad=2, lvl=4, sub_lvl=i + 2) 193 | else: 194 | print("This ResNet is not implemented") 195 | 196 | # 5_1 - 5_3 197 | res = residual_short(res, 8, pad=4, lvl=5, sub_lvl=1) 198 | for i in range(2): 199 | res = residual_empty(res, 8, pad=4, lvl=5, sub_lvl=i + 2) 200 | 201 | res = Activation('relu')(res) 202 | return res 203 | 204 | 205 | def interp_block(prev_layer, level, feature_map_shape, input_shape): 206 | if input_shape == (473, 473): 207 | kernel_strides_map = {1: 60, 208 | 2: 30, 209 | 3: 20, 210 | 6: 10} 211 | elif input_shape == (713, 713): 212 | kernel_strides_map = {1: 90, 213 | 2: 45, 214 | 3: 30, 215 | 6: 15} 216 | else: 217 | print("Pooling parameters for input shape ", 218 | input_shape, " are not defined.") 219 | exit(1) 220 | 221 | names = [ 222 | "conv5_3_pool" + str(level) + "_conv", 223 | "conv5_3_pool" + str(level) + "_conv_bn" 224 | ] 225 | kernel = (kernel_strides_map[level], kernel_strides_map[level]) 226 | strides = (kernel_strides_map[level], kernel_strides_map[level]) 227 | prev_layer = AveragePooling2D(kernel, strides=strides)(prev_layer) 228 | prev_layer = Conv2D(512, (1, 1), strides=(1, 1), name=names[0], 229 | use_bias=False)(prev_layer) 230 | prev_layer = BN(name=names[1])(prev_layer) 231 | prev_layer = Activation('relu')(prev_layer) 232 | # prev_layer = Lambda(Interp, arguments={ 233 | # 'shape': feature_map_shape})(prev_layer) 234 | prev_layer = Interp(feature_map_shape)(prev_layer) 235 | return prev_layer 236 | 237 | 238 | def build_pyramid_pooling_module(res, input_shape): 239 | """Build the Pyramid Pooling Module.""" 240 | # ---PSPNet concat layers with Interpolation 241 | feature_map_size = tuple(int(ceil(input_dim / 8.0)) 242 | for input_dim in input_shape) 243 | 244 | interp_block1 = interp_block(res, 1, feature_map_size, input_shape) 245 | interp_block2 = interp_block(res, 2, feature_map_size, input_shape) 246 | interp_block3 = interp_block(res, 3, feature_map_size, input_shape) 247 | interp_block6 = interp_block(res, 6, feature_map_size, input_shape) 248 | 249 | # concat all these layers. 
resulted 250 | # shape=(1,feature_map_size_x,feature_map_size_y,4096) 251 | res = Concatenate()([res, 252 | interp_block6, 253 | interp_block3, 254 | interp_block2, 255 | interp_block1]) 256 | return res 257 | 258 | 259 | def _build_pspnet(nb_classes, resnet_layers, input_shape, 260 | activation='softmax', channels=3): 261 | 262 | assert IMAGE_ORDERING == 'channels_last' 263 | 264 | inp = Input((input_shape[0], input_shape[1], channels)) 265 | 266 | res = ResNet(inp, layers=resnet_layers) 267 | 268 | psp = build_pyramid_pooling_module(res, input_shape) 269 | 270 | x = Conv2D(512, (3, 3), strides=(1, 1), padding="same", name="conv5_4", 271 | use_bias=False)(psp) 272 | x = BN(name="conv5_4_bn")(x) 273 | x = Activation('relu')(x) 274 | x = Dropout(0.1)(x) 275 | 276 | x = Conv2D(nb_classes, (1, 1), strides=(1, 1), name="conv6")(x) 277 | # x = Lambda(Interp, arguments={'shape': ( 278 | # input_shape[0], input_shape[1])})(x) 279 | x = Interp([input_shape[0], input_shape[1]])(x) 280 | 281 | model = get_segmentation_model(inp, x) 282 | model.seg_feats_layer_name = "conv5_4" 283 | 284 | return model 285 | -------------------------------------------------------------------------------- /keras_segmentation/models/all_models.py: -------------------------------------------------------------------------------- 1 | from . import pspnet 2 | from . import unet 3 | from . import segnet 4 | from . import fcn 5 | model_from_name = {} 6 | 7 | 8 | model_from_name["fcn_8"] = fcn.fcn_8 9 | model_from_name["fcn_32"] = fcn.fcn_32 10 | model_from_name["fcn_8_vgg"] = fcn.fcn_8_vgg 11 | model_from_name["fcn_32_vgg"] = fcn.fcn_32_vgg 12 | model_from_name["fcn_8_resnet50"] = fcn.fcn_8_resnet50 13 | model_from_name["fcn_32_resnet50"] = fcn.fcn_32_resnet50 14 | model_from_name["fcn_8_mobilenet"] = fcn.fcn_8_mobilenet 15 | model_from_name["fcn_32_mobilenet"] = fcn.fcn_32_mobilenet 16 | 17 | 18 | model_from_name["pspnet"] = pspnet.pspnet 19 | model_from_name["vgg_pspnet"] = pspnet.vgg_pspnet 20 | model_from_name["resnet50_pspnet"] = pspnet.resnet50_pspnet 21 | 22 | model_from_name["vgg_pspnet"] = pspnet.vgg_pspnet 23 | model_from_name["resnet50_pspnet"] = pspnet.resnet50_pspnet 24 | 25 | model_from_name["pspnet_50"] = pspnet.pspnet_50 26 | model_from_name["pspnet_101"] = pspnet.pspnet_101 27 | 28 | 29 | # model_from_name["mobilenet_pspnet"] = pspnet.mobilenet_pspnet 30 | 31 | 32 | model_from_name["unet_mini"] = unet.unet_mini 33 | model_from_name["unet"] = unet.unet 34 | model_from_name["vgg_unet"] = unet.vgg_unet 35 | model_from_name["resnet50_unet"] = unet.resnet50_unet 36 | model_from_name["mobilenet_unet"] = unet.mobilenet_unet 37 | 38 | 39 | model_from_name["segnet"] = segnet.segnet 40 | model_from_name["vgg_segnet"] = segnet.vgg_segnet 41 | model_from_name["resnet50_segnet"] = segnet.resnet50_segnet 42 | model_from_name["mobilenet_segnet"] = segnet.mobilenet_segnet 43 | -------------------------------------------------------------------------------- /keras_segmentation/models/basic_models.py: -------------------------------------------------------------------------------- 1 | from keras.models import * 2 | from keras.layers import * 3 | import keras.backend as K 4 | 5 | from .config import IMAGE_ORDERING 6 | 7 | 8 | def vanilla_encoder(input_height=224, input_width=224, channels=3): 9 | 10 | kernel = 3 11 | filter_size = 64 12 | pad = 1 13 | pool_size = 2 14 | 15 | if IMAGE_ORDERING == 'channels_first': 16 | img_input = Input(shape=(channels, input_height, input_width)) 17 | elif IMAGE_ORDERING == 'channels_last': 18 
| img_input = Input(shape=(input_height, input_width, channels)) 19 | 20 | x = img_input 21 | levels = [] 22 | 23 | x = (ZeroPadding2D((pad, pad), data_format=IMAGE_ORDERING))(x) 24 | x = (Conv2D(filter_size, (kernel, kernel), 25 | data_format=IMAGE_ORDERING, padding='valid'))(x) 26 | x = (BatchNormalization())(x) 27 | x = (Activation('relu'))(x) 28 | x = (MaxPooling2D((pool_size, pool_size), data_format=IMAGE_ORDERING))(x) 29 | levels.append(x) 30 | 31 | x = (ZeroPadding2D((pad, pad), data_format=IMAGE_ORDERING))(x) 32 | x = (Conv2D(128, (kernel, kernel), data_format=IMAGE_ORDERING, 33 | padding='valid'))(x) 34 | x = (BatchNormalization())(x) 35 | x = (Activation('relu'))(x) 36 | x = (MaxPooling2D((pool_size, pool_size), data_format=IMAGE_ORDERING))(x) 37 | levels.append(x) 38 | 39 | for _ in range(3): 40 | x = (ZeroPadding2D((pad, pad), data_format=IMAGE_ORDERING))(x) 41 | x = (Conv2D(256, (kernel, kernel), 42 | data_format=IMAGE_ORDERING, padding='valid'))(x) 43 | x = (BatchNormalization())(x) 44 | x = (Activation('relu'))(x) 45 | x = (MaxPooling2D((pool_size, pool_size), 46 | data_format=IMAGE_ORDERING))(x) 47 | levels.append(x) 48 | 49 | return img_input, levels 50 | -------------------------------------------------------------------------------- /keras_segmentation/models/config.py: -------------------------------------------------------------------------------- 1 | IMAGE_ORDERING_CHANNELS_LAST = "channels_last" 2 | IMAGE_ORDERING_CHANNELS_FIRST = "channels_first" 3 | 4 | # Default IMAGE_ORDERING = channels_last 5 | IMAGE_ORDERING = IMAGE_ORDERING_CHANNELS_LAST 6 | -------------------------------------------------------------------------------- /keras_segmentation/models/fcn.py: -------------------------------------------------------------------------------- 1 | from keras.models import * 2 | from keras.layers import * 3 | 4 | from .config import IMAGE_ORDERING 5 | from .model_utils import get_segmentation_model 6 | from .vgg16 import get_vgg_encoder 7 | from .mobilenet import get_mobilenet_encoder 8 | from .basic_models import vanilla_encoder 9 | from .resnet50 import get_resnet50_encoder 10 | 11 | 12 | # crop o1 wrt o2 13 | def crop(o1, o2, i): 14 | o_shape2 = Model(i, o2).output_shape 15 | 16 | if IMAGE_ORDERING == 'channels_first': 17 | output_height2 = o_shape2[2] 18 | output_width2 = o_shape2[3] 19 | else: 20 | output_height2 = o_shape2[1] 21 | output_width2 = o_shape2[2] 22 | 23 | o_shape1 = Model(i, o1).output_shape 24 | if IMAGE_ORDERING == 'channels_first': 25 | output_height1 = o_shape1[2] 26 | output_width1 = o_shape1[3] 27 | else: 28 | output_height1 = o_shape1[1] 29 | output_width1 = o_shape1[2] 30 | 31 | cx = abs(output_width1 - output_width2) 32 | cy = abs(output_height2 - output_height1) 33 | 34 | if output_width1 > output_width2: 35 | o1 = Cropping2D(cropping=((0, 0), (0, cx)), 36 | data_format=IMAGE_ORDERING)(o1) 37 | else: 38 | o2 = Cropping2D(cropping=((0, 0), (0, cx)), 39 | data_format=IMAGE_ORDERING)(o2) 40 | 41 | if output_height1 > output_height2: 42 | o1 = Cropping2D(cropping=((0, cy), (0, 0)), 43 | data_format=IMAGE_ORDERING)(o1) 44 | else: 45 | o2 = Cropping2D(cropping=((0, cy), (0, 0)), 46 | data_format=IMAGE_ORDERING)(o2) 47 | 48 | return o1, o2 49 | 50 | 51 | def fcn_8(n_classes, encoder=vanilla_encoder, input_height=416, 52 | input_width=608, channels=3): 53 | 54 | img_input, levels = encoder( 55 | input_height=input_height, input_width=input_width, channels=channels) 56 | [f1, f2, f3, f4, f5] = levels 57 | 58 | o = f5 59 | 60 | o = (Conv2D(4096, (7, 
7), activation='relu', 61 | padding='same', data_format=IMAGE_ORDERING))(o) 62 | o = Dropout(0.5)(o) 63 | o = (Conv2D(4096, (1, 1), activation='relu', 64 | padding='same', data_format=IMAGE_ORDERING))(o) 65 | o = Dropout(0.5)(o) 66 | 67 | o = (Conv2D(n_classes, (1, 1), kernel_initializer='he_normal', 68 | data_format=IMAGE_ORDERING))(o) 69 | o = Conv2DTranspose(n_classes, kernel_size=(4, 4), strides=( 70 | 2, 2), use_bias=False, data_format=IMAGE_ORDERING)(o) 71 | 72 | o2 = f4 73 | o2 = (Conv2D(n_classes, (1, 1), kernel_initializer='he_normal', 74 | data_format=IMAGE_ORDERING))(o2) 75 | 76 | o, o2 = crop(o, o2, img_input) 77 | 78 | o = Add()([o, o2]) 79 | 80 | o = Conv2DTranspose(n_classes, kernel_size=(4, 4), strides=( 81 | 2, 2), use_bias=False, data_format=IMAGE_ORDERING)(o) 82 | o2 = f3 83 | o2 = (Conv2D(n_classes, (1, 1), kernel_initializer='he_normal', 84 | data_format=IMAGE_ORDERING))(o2) 85 | o2, o = crop(o2, o, img_input) 86 | o = Add( name="seg_feats" )([o2, o]) 87 | 88 | o = Conv2DTranspose(n_classes, kernel_size=(16, 16), strides=( 89 | 8, 8), use_bias=False, data_format=IMAGE_ORDERING)(o) 90 | 91 | model = get_segmentation_model(img_input, o) 92 | model.model_name = "fcn_8" 93 | return model 94 | 95 | 96 | def fcn_32(n_classes, encoder=vanilla_encoder, input_height=416, 97 | input_width=608, channels=3): 98 | 99 | img_input, levels = encoder( 100 | input_height=input_height, input_width=input_width, channels=channels) 101 | [f1, f2, f3, f4, f5] = levels 102 | 103 | o = f5 104 | 105 | o = (Conv2D(4096, (7, 7), activation='relu', 106 | padding='same', data_format=IMAGE_ORDERING))(o) 107 | o = Dropout(0.5)(o) 108 | o = (Conv2D(4096, (1, 1), activation='relu', 109 | padding='same', data_format=IMAGE_ORDERING))(o) 110 | o = Dropout(0.5)(o) 111 | 112 | o = (Conv2D(n_classes, (1, 1), kernel_initializer='he_normal', 113 | data_format=IMAGE_ORDERING , name="seg_feats" ))(o) 114 | o = Conv2DTranspose(n_classes, kernel_size=(64, 64), strides=( 115 | 32, 32), use_bias=False, data_format=IMAGE_ORDERING)(o) 116 | 117 | model = get_segmentation_model(img_input, o) 118 | model.model_name = "fcn_32" 119 | return model 120 | 121 | 122 | def fcn_8_vgg(n_classes, input_height=416, input_width=608, channels=3): 123 | model = fcn_8(n_classes, get_vgg_encoder, 124 | input_height=input_height, input_width=input_width, channels=channels) 125 | model.model_name = "fcn_8_vgg" 126 | return model 127 | 128 | 129 | def fcn_32_vgg(n_classes, input_height=416, input_width=608, channels=3): 130 | model = fcn_32(n_classes, get_vgg_encoder, 131 | input_height=input_height, input_width=input_width, channels=channels) 132 | model.model_name = "fcn_32_vgg" 133 | return model 134 | 135 | 136 | def fcn_8_resnet50(n_classes, input_height=416, input_width=608, channels=3): 137 | model = fcn_8(n_classes, get_resnet50_encoder, 138 | input_height=input_height, input_width=input_width, channels=channels) 139 | model.model_name = "fcn_8_resnet50" 140 | return model 141 | 142 | 143 | def fcn_32_resnet50(n_classes, input_height=416, input_width=608, channels=3): 144 | model = fcn_32(n_classes, get_resnet50_encoder, 145 | input_height=input_height, input_width=input_width, channels=channels) 146 | model.model_name = "fcn_32_resnet50" 147 | return model 148 | 149 | 150 | def fcn_8_mobilenet(n_classes, input_height=224, input_width=224, channels=3): 151 | model = fcn_8(n_classes, get_mobilenet_encoder, 152 | input_height=input_height, input_width=input_width, channels=channels) 153 | model.model_name = "fcn_8_mobilenet" 154 | 
return model 155 | 156 | 157 | def fcn_32_mobilenet(n_classes, input_height=224, input_width=224, channels=3): 158 | model = fcn_32(n_classes, get_mobilenet_encoder, 159 | input_height=input_height, input_width=input_width, channels=channels) 160 | model.model_name = "fcn_32_mobilenet" 161 | return model 162 | 163 | 164 | if __name__ == '__main__': 165 | m = fcn_8(101) 166 | m = fcn_32(101) 167 | -------------------------------------------------------------------------------- /keras_segmentation/models/mobilenet.py: -------------------------------------------------------------------------------- 1 | from keras.models import * 2 | from keras.layers import * 3 | import keras.backend as K 4 | import keras 5 | import tensorflow as tf 6 | from .config import IMAGE_ORDERING 7 | 8 | BASE_WEIGHT_PATH = ('https://github.com/fchollet/deep-learning-models/' 9 | 'releases/download/v0.6/') 10 | 11 | 12 | def relu6(x): 13 | return K.relu(x, max_value=6) 14 | 15 | 16 | def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)): 17 | 18 | channel_axis = 1 if IMAGE_ORDERING == 'channels_first' else -1 19 | filters = int(filters * alpha) 20 | x = ZeroPadding2D(padding=(1, 1), name='conv1_pad', 21 | data_format=IMAGE_ORDERING)(inputs) 22 | x = Conv2D(filters, kernel, data_format=IMAGE_ORDERING, 23 | padding='valid', 24 | use_bias=False, 25 | strides=strides, 26 | name='conv1')(x) 27 | x = BatchNormalization(axis=channel_axis, name='conv1_bn')(x) 28 | return Activation(relu6, name='conv1_relu')(x) 29 | 30 | 31 | def _depthwise_conv_block(inputs, pointwise_conv_filters, alpha, 32 | depth_multiplier=1, strides=(1, 1), block_id=1): 33 | 34 | channel_axis = 1 if IMAGE_ORDERING == 'channels_first' else -1 35 | pointwise_conv_filters = int(pointwise_conv_filters * alpha) 36 | 37 | x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING, 38 | name='conv_pad_%d' % block_id)(inputs) 39 | x = DepthwiseConv2D((3, 3), data_format=IMAGE_ORDERING, 40 | padding='valid', 41 | depth_multiplier=depth_multiplier, 42 | strides=strides, 43 | use_bias=False, 44 | name='conv_dw_%d' % block_id)(x) 45 | x = BatchNormalization( 46 | axis=channel_axis, name='conv_dw_%d_bn' % block_id)(x) 47 | x = Activation(relu6, name='conv_dw_%d_relu' % block_id)(x) 48 | 49 | x = Conv2D(pointwise_conv_filters, (1, 1), data_format=IMAGE_ORDERING, 50 | padding='same', 51 | use_bias=False, 52 | strides=(1, 1), 53 | name='conv_pw_%d' % block_id)(x) 54 | x = BatchNormalization(axis=channel_axis, 55 | name='conv_pw_%d_bn' % block_id)(x) 56 | return Activation(relu6, name='conv_pw_%d_relu' % block_id)(x) 57 | 58 | 59 | def get_mobilenet_encoder(input_height=224, input_width=224, 60 | pretrained='imagenet', channels=3): 61 | 62 | # todo add more alpha and stuff 63 | 64 | assert (K.image_data_format() == 65 | 'channels_last'), "Currently only channels last mode is supported" 66 | assert (IMAGE_ORDERING == 67 | 'channels_last'), "Currently only channels last mode is supported" 68 | 69 | assert input_height % 32 == 0 70 | assert input_width % 32 == 0 71 | 72 | alpha = 1.0 73 | depth_multiplier = 1 74 | dropout = 1e-3 75 | 76 | img_input = Input(shape=(input_height, input_width, channels)) 77 | 78 | x = _conv_block(img_input, 32, alpha, strides=(2, 2)) 79 | x = _depthwise_conv_block(x, 64, alpha, depth_multiplier, block_id=1) 80 | f1 = x 81 | 82 | x = _depthwise_conv_block(x, 128, alpha, depth_multiplier, 83 | strides=(2, 2), block_id=2) 84 | x = _depthwise_conv_block(x, 128, alpha, depth_multiplier, block_id=3) 85 | f2 = x 86 | 87 | x = 
_depthwise_conv_block(x, 256, alpha, depth_multiplier, 88 | strides=(2, 2), block_id=4) 89 | x = _depthwise_conv_block(x, 256, alpha, depth_multiplier, block_id=5) 90 | f3 = x 91 | 92 | x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, 93 | strides=(2, 2), block_id=6) 94 | x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=7) 95 | x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=8) 96 | x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=9) 97 | x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=10) 98 | x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=11) 99 | f4 = x 100 | 101 | x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier, 102 | strides=(2, 2), block_id=12) 103 | x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier, block_id=13) 104 | f5 = x 105 | 106 | if pretrained == 'imagenet': 107 | model_name = 'mobilenet_%s_%d_tf_no_top.h5' % ('1_0', 224) 108 | 109 | weight_path = BASE_WEIGHT_PATH + model_name 110 | weights_path = tf.keras.utils.get_file(model_name, weight_path) 111 | 112 | Model(img_input, x).load_weights(weights_path, by_name=True, skip_mismatch=True) 113 | 114 | return img_input, [f1, f2, f3, f4, f5] 115 | -------------------------------------------------------------------------------- /keras_segmentation/models/model.py: -------------------------------------------------------------------------------- 1 | """ Definition for the generic Model class """ 2 | 3 | 4 | class Model: 5 | def __init__(self, n_classes, input_height=None, input_width=None): 6 | pass 7 | -------------------------------------------------------------------------------- /keras_segmentation/models/model_utils.py: -------------------------------------------------------------------------------- 1 | from types import MethodType 2 | 3 | from keras.models import * 4 | from keras.layers import * 5 | import keras.backend as K 6 | from tqdm import tqdm 7 | 8 | from .config import IMAGE_ORDERING 9 | from ..train import train 10 | from ..predict import predict, predict_multiple, evaluate 11 | 12 | 13 | # source m1 , dest m2 14 | def transfer_weights(m1, m2, verbose=True): 15 | 16 | assert len(m1.layers) == len( 17 | m2.layers), "Both models should have same number of layers" 18 | 19 | nSet = 0 20 | nNotSet = 0 21 | 22 | if verbose: 23 | print("Copying weights ") 24 | bar = tqdm(zip(m1.layers, m2.layers)) 25 | else: 26 | bar = zip(m1.layers, m2.layers) 27 | 28 | for l, ll in bar: 29 | 30 | if not any([w.shape != ww.shape for w, ww in zip(list(l.weights), 31 | list(ll.weights))]): 32 | if len(list(l.weights)) > 0: 33 | ll.set_weights(l.get_weights()) 34 | nSet += 1 35 | else: 36 | nNotSet += 1 37 | 38 | if verbose: 39 | print("Copied weights of %d layers and skipped %d layers" % 40 | (nSet, nNotSet)) 41 | 42 | 43 | def resize_image(inp, s, data_format): 44 | 45 | try: 46 | 47 | return Lambda(lambda x: K.resize_images(x, 48 | height_factor=s[0], 49 | width_factor=s[1], 50 | data_format=data_format, 51 | interpolation='bilinear'))(inp) 52 | 53 | except Exception as e: 54 | # if keras is old, then rely on the tf function 55 | # Sorry theano/cntk users!!! 
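# NOTE: this fallback is taken when the Lambda above raises, e.g. on an older
# Keras whose K.resize_images() does not accept the interpolation argument;
# in that case the tensor is resized with the TensorFlow op directly, which
# is why the channels_last asserts below are required.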
56 | assert data_format == 'channels_last' 57 | assert IMAGE_ORDERING == 'channels_last' 58 | 59 | import tensorflow as tf 60 | 61 | return Lambda( 62 | lambda x: tf.image.resize_images( 63 | x, (K.int_shape(x)[1]*s[0], K.int_shape(x)[2]*s[1])) 64 | )(inp) 65 | 66 | 67 | def get_segmentation_model(input, output): 68 | 69 | img_input = input 70 | o = output 71 | 72 | o_shape = Model(img_input, o).output_shape 73 | i_shape = Model(img_input, o).input_shape 74 | 75 | if IMAGE_ORDERING == 'channels_first': 76 | output_height = o_shape[2] 77 | output_width = o_shape[3] 78 | input_height = i_shape[2] 79 | input_width = i_shape[3] 80 | n_classes = o_shape[1] 81 | o = (Reshape((-1, output_height*output_width)))(o) 82 | o = (Permute((2, 1)))(o) 83 | elif IMAGE_ORDERING == 'channels_last': 84 | output_height = o_shape[1] 85 | output_width = o_shape[2] 86 | input_height = i_shape[1] 87 | input_width = i_shape[2] 88 | n_classes = o_shape[3] 89 | o = (Reshape((output_height*output_width, -1)))(o) 90 | 91 | o = (Activation('softmax'))(o) 92 | model = Model(img_input, o) 93 | model.output_width = output_width 94 | model.output_height = output_height 95 | model.n_classes = n_classes 96 | model.input_height = input_height 97 | model.input_width = input_width 98 | model.model_name = "" 99 | 100 | model.train = MethodType(train, model) 101 | model.predict_segmentation = MethodType(predict, model) 102 | model.predict_multiple = MethodType(predict_multiple, model) 103 | model.evaluate_segmentation = MethodType(evaluate, model) 104 | 105 | return model 106 | -------------------------------------------------------------------------------- /keras_segmentation/models/pspnet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import keras 3 | from keras.models import * 4 | from keras.layers import * 5 | import keras.backend as K 6 | 7 | from .config import IMAGE_ORDERING 8 | from .model_utils import get_segmentation_model, resize_image 9 | from .vgg16 import get_vgg_encoder 10 | from .basic_models import vanilla_encoder 11 | from .resnet50 import get_resnet50_encoder 12 | 13 | 14 | if IMAGE_ORDERING == 'channels_first': 15 | MERGE_AXIS = 1 16 | elif IMAGE_ORDERING == 'channels_last': 17 | MERGE_AXIS = -1 18 | 19 | 20 | def pool_block(feats, pool_factor): 21 | 22 | if IMAGE_ORDERING == 'channels_first': 23 | h = K.int_shape(feats)[2] 24 | w = K.int_shape(feats)[3] 25 | elif IMAGE_ORDERING == 'channels_last': 26 | h = K.int_shape(feats)[1] 27 | w = K.int_shape(feats)[2] 28 | 29 | pool_size = strides = [ 30 | int(np.round(float(h) / pool_factor)), 31 | int(np.round(float(w) / pool_factor))] 32 | 33 | x = AveragePooling2D(pool_size, data_format=IMAGE_ORDERING, 34 | strides=strides, padding='same')(feats) 35 | x = Conv2D(512, (1, 1), data_format=IMAGE_ORDERING, 36 | padding='same', use_bias=False)(x) 37 | x = BatchNormalization()(x) 38 | x = Activation('relu')(x) 39 | 40 | x = resize_image(x, strides, data_format=IMAGE_ORDERING) 41 | 42 | return x 43 | 44 | 45 | def _pspnet(n_classes, encoder, input_height=384, input_width=576, channels=3): 46 | 47 | assert input_height % 192 == 0 48 | assert input_width % 192 == 0 49 | 50 | img_input, levels = encoder( 51 | input_height=input_height, input_width=input_width, channels=channels) 52 | [f1, f2, f3, f4, f5] = levels 53 | 54 | o = f5 55 | 56 | pool_factors = [1, 2, 3, 6] 57 | pool_outs = [o] 58 | 59 | for p in pool_factors: 60 | pooled = pool_block(o, p) 61 | pool_outs.append(pooled) 62 | 63 | o = 
Concatenate(axis=MERGE_AXIS)(pool_outs) 64 | 65 | o = Conv2D(512, (1, 1), data_format=IMAGE_ORDERING, use_bias=False , name="seg_feats" )(o) 66 | o = BatchNormalization()(o) 67 | o = Activation('relu')(o) 68 | 69 | o = Conv2D(n_classes, (3, 3), data_format=IMAGE_ORDERING, 70 | padding='same')(o) 71 | o = resize_image(o, (8, 8), data_format=IMAGE_ORDERING) 72 | 73 | model = get_segmentation_model(img_input, o) 74 | return model 75 | 76 | 77 | def pspnet(n_classes, input_height=384, input_width=576, channels=3): 78 | 79 | model = _pspnet(n_classes, vanilla_encoder, 80 | input_height=input_height, input_width=input_width, channels=channels) 81 | model.model_name = "pspnet" 82 | return model 83 | 84 | 85 | def vgg_pspnet(n_classes, input_height=384, input_width=576, channels=3): 86 | 87 | model = _pspnet(n_classes, get_vgg_encoder, 88 | input_height=input_height, input_width=input_width, channels=channels) 89 | model.model_name = "vgg_pspnet" 90 | return model 91 | 92 | 93 | def resnet50_pspnet(n_classes, input_height=384, input_width=576, channels=3): 94 | 95 | model = _pspnet(n_classes, get_resnet50_encoder, 96 | input_height=input_height, input_width=input_width, channels=channels) 97 | model.model_name = "resnet50_pspnet" 98 | return model 99 | 100 | 101 | def pspnet_50(n_classes, input_height=473, input_width=473, channels=3): 102 | from ._pspnet_2 import _build_pspnet 103 | 104 | nb_classes = n_classes 105 | resnet_layers = 50 106 | input_shape = (input_height, input_width) 107 | model = _build_pspnet(nb_classes=nb_classes, 108 | resnet_layers=resnet_layers, 109 | input_shape=input_shape, channels=channels) 110 | model.model_name = "pspnet_50" 111 | return model 112 | 113 | 114 | def pspnet_101(n_classes, input_height=473, input_width=473, channels=3): 115 | from ._pspnet_2 import _build_pspnet 116 | 117 | nb_classes = n_classes 118 | resnet_layers = 101 119 | input_shape = (input_height, input_width) 120 | model = _build_pspnet(nb_classes=nb_classes, 121 | resnet_layers=resnet_layers, 122 | input_shape=input_shape, channels=channels) 123 | model.model_name = "pspnet_101" 124 | return model 125 | 126 | 127 | # def mobilenet_pspnet( n_classes , input_height=224, input_width=224 ): 128 | 129 | # model = _pspnet(n_classes, get_mobilenet_encoder, 130 | # input_height=input_height, input_width=input_width) 131 | # model.model_name = "mobilenet_pspnet" 132 | # return model 133 | 134 | 135 | if __name__ == '__main__': 136 | 137 | m = _pspnet(101, vanilla_encoder) 138 | # m = _pspnet( 101 , get_mobilenet_encoder ,True , 224 , 224 ) 139 | m = _pspnet(101, get_vgg_encoder) 140 | m = _pspnet(101, get_resnet50_encoder) 141 | -------------------------------------------------------------------------------- /keras_segmentation/models/resnet50.py: -------------------------------------------------------------------------------- 1 | import keras 2 | from keras.models import * 3 | from keras.layers import * 4 | from keras import layers 5 | import tensorflow as tf 6 | 7 | # Source: 8 | # https://github.com/fchollet/deep-learning-models/blob/master/resnet50.py 9 | 10 | 11 | from .config import IMAGE_ORDERING 12 | 13 | 14 | if IMAGE_ORDERING == 'channels_first': 15 | pretrained_url = "https://github.com/fchollet/deep-learning-models/" \ 16 | "releases/download/v0.2/" \ 17 | "resnet50_weights_th_dim_ordering_th_kernels_notop.h5" 18 | elif IMAGE_ORDERING == 'channels_last': 19 | pretrained_url = "https://github.com/fchollet/deep-learning-models/" \ 20 | "releases/download/v0.2/" \ 21 | 
"resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5" 22 | 23 | 24 | def one_side_pad(x): 25 | x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x) 26 | if IMAGE_ORDERING == 'channels_first': 27 | x = Lambda(lambda x: x[:, :, :-1, :-1])(x) 28 | elif IMAGE_ORDERING == 'channels_last': 29 | x = Lambda(lambda x: x[:, :-1, :-1, :])(x) 30 | return x 31 | 32 | 33 | def identity_block(input_tensor, kernel_size, filters, stage, block): 34 | """The identity block is the block that has no conv layer at shortcut. 35 | # Arguments 36 | input_tensor: input tensor 37 | kernel_size: defualt 3, the kernel size of middle conv layer at 38 | main path 39 | filters: list of integers, the filterss of 3 conv layer at main path 40 | stage: integer, current stage label, used for generating layer names 41 | block: 'a','b'..., current block label, used for generating layer names 42 | # Returns 43 | Output tensor for the block. 44 | """ 45 | filters1, filters2, filters3 = filters 46 | 47 | if IMAGE_ORDERING == 'channels_last': 48 | bn_axis = 3 49 | else: 50 | bn_axis = 1 51 | 52 | conv_name_base = 'res' + str(stage) + block + '_branch' 53 | bn_name_base = 'bn' + str(stage) + block + '_branch' 54 | 55 | x = Conv2D(filters1, (1, 1), data_format=IMAGE_ORDERING, 56 | name=conv_name_base + '2a')(input_tensor) 57 | x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x) 58 | x = Activation('relu')(x) 59 | 60 | x = Conv2D(filters2, kernel_size, data_format=IMAGE_ORDERING, 61 | padding='same', name=conv_name_base + '2b')(x) 62 | x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) 63 | x = Activation('relu')(x) 64 | 65 | x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, 66 | name=conv_name_base + '2c')(x) 67 | x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) 68 | 69 | x = layers.add([x, input_tensor]) 70 | x = Activation('relu')(x) 71 | return x 72 | 73 | 74 | def conv_block(input_tensor, kernel_size, filters, stage, block, 75 | strides=(2, 2)): 76 | """conv_block is the block that has a conv layer at shortcut 77 | # Arguments 78 | input_tensor: input tensor 79 | kernel_size: defualt 3, the kernel size of middle conv layer at 80 | main path 81 | filters: list of integers, the filterss of 3 conv layer at main path 82 | stage: integer, current stage label, used for generating layer names 83 | block: 'a','b'..., current block label, used for generating layer names 84 | # Returns 85 | Output tensor for the block. 
86 | Note that from stage 3, the first conv layer at main path is with 87 | strides=(2,2) and the shortcut should have strides=(2,2) as well 88 | """ 89 | filters1, filters2, filters3 = filters 90 | 91 | if IMAGE_ORDERING == 'channels_last': 92 | bn_axis = 3 93 | else: 94 | bn_axis = 1 95 | 96 | conv_name_base = 'res' + str(stage) + block + '_branch' 97 | bn_name_base = 'bn' + str(stage) + block + '_branch' 98 | 99 | x = Conv2D(filters1, (1, 1), data_format=IMAGE_ORDERING, strides=strides, 100 | name=conv_name_base + '2a')(input_tensor) 101 | x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x) 102 | x = Activation('relu')(x) 103 | 104 | x = Conv2D(filters2, kernel_size, data_format=IMAGE_ORDERING, 105 | padding='same', name=conv_name_base + '2b')(x) 106 | x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) 107 | x = Activation('relu')(x) 108 | 109 | x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, 110 | name=conv_name_base + '2c')(x) 111 | x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) 112 | 113 | shortcut = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, 114 | strides=strides, name=conv_name_base + '1')(input_tensor) 115 | shortcut = BatchNormalization( 116 | axis=bn_axis, name=bn_name_base + '1')(shortcut) 117 | 118 | x = layers.add([x, shortcut]) 119 | x = Activation('relu')(x) 120 | return x 121 | 122 | 123 | def get_resnet50_encoder(input_height=224, input_width=224, 124 | pretrained='imagenet', 125 | include_top=True, weights='imagenet', 126 | input_tensor=None, input_shape=None, 127 | pooling=None, 128 | classes=1000, channels=3): 129 | 130 | assert input_height % 32 == 0 131 | assert input_width % 32 == 0 132 | 133 | if IMAGE_ORDERING == 'channels_first': 134 | img_input = Input(shape=(channels, input_height, input_width)) 135 | elif IMAGE_ORDERING == 'channels_last': 136 | img_input = Input(shape=(input_height, input_width, channels)) 137 | 138 | if IMAGE_ORDERING == 'channels_last': 139 | bn_axis = 3 140 | else: 141 | bn_axis = 1 142 | 143 | x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input) 144 | x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING, 145 | strides=(2, 2), name='conv1')(x) 146 | f1 = x 147 | 148 | x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) 149 | x = Activation('relu')(x) 150 | x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x) 151 | 152 | x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) 153 | x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') 154 | x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') 155 | f2 = one_side_pad(x) 156 | 157 | x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') 158 | x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') 159 | x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') 160 | x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') 161 | f3 = x 162 | 163 | x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') 164 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') 165 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') 166 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') 167 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') 168 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') 169 | f4 = x 170 | 171 | x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') 172 | x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') 173 | x = identity_block(x, 3, [512, 512, 2048], 
stage=5, block='c') 174 | f5 = x 175 | 176 | x = AveragePooling2D( 177 | (7, 7), data_format=IMAGE_ORDERING, name='avg_pool')(x) 178 | # f6 = x 179 | 180 | if pretrained == 'imagenet': 181 | weights_path = tf.keras.utils.get_file( 182 | pretrained_url.split("/")[-1], pretrained_url) 183 | Model(img_input, x).load_weights(weights_path, by_name=True, skip_mismatch=True) 184 | 185 | return img_input, [f1, f2, f3, f4, f5] 186 | -------------------------------------------------------------------------------- /keras_segmentation/models/segnet.py: -------------------------------------------------------------------------------- 1 | from keras.models import * 2 | from keras.layers import * 3 | 4 | from .config import IMAGE_ORDERING 5 | from .model_utils import get_segmentation_model 6 | from .vgg16 import get_vgg_encoder 7 | from .mobilenet import get_mobilenet_encoder 8 | from .basic_models import vanilla_encoder 9 | from .resnet50 import get_resnet50_encoder 10 | 11 | 12 | def segnet_decoder(f, n_classes, n_up=3): 13 | 14 | assert n_up >= 2 15 | 16 | o = f 17 | o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) 18 | o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING))(o) 19 | o = (BatchNormalization())(o) 20 | 21 | o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) 22 | o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) 23 | o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING))(o) 24 | o = (BatchNormalization())(o) 25 | 26 | for _ in range(n_up-2): 27 | o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) 28 | o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) 29 | o = (Conv2D(128, (3, 3), padding='valid', 30 | data_format=IMAGE_ORDERING))(o) 31 | o = (BatchNormalization())(o) 32 | 33 | o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) 34 | o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) 35 | o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING, name="seg_feats"))(o) 36 | o = (BatchNormalization())(o) 37 | 38 | o = Conv2D(n_classes, (3, 3), padding='same', 39 | data_format=IMAGE_ORDERING)(o) 40 | 41 | return o 42 | 43 | 44 | def _segnet(n_classes, encoder, input_height=416, input_width=608, 45 | encoder_level=3, channels=3): 46 | 47 | img_input, levels = encoder( 48 | input_height=input_height, input_width=input_width, channels=channels) 49 | 50 | feat = levels[encoder_level] 51 | o = segnet_decoder(feat, n_classes, n_up=3) 52 | model = get_segmentation_model(img_input, o) 53 | 54 | return model 55 | 56 | 57 | def segnet(n_classes, input_height=416, input_width=608, encoder_level=3, channels=3): 58 | 59 | model = _segnet(n_classes, vanilla_encoder, input_height=input_height, 60 | input_width=input_width, encoder_level=encoder_level, channels=channels) 61 | model.model_name = "segnet" 62 | return model 63 | 64 | 65 | def vgg_segnet(n_classes, input_height=416, input_width=608, encoder_level=3, channels=3): 66 | 67 | model = _segnet(n_classes, get_vgg_encoder, input_height=input_height, 68 | input_width=input_width, encoder_level=encoder_level, channels=channels) 69 | model.model_name = "vgg_segnet" 70 | return model 71 | 72 | 73 | def resnet50_segnet(n_classes, input_height=416, input_width=608, 74 | encoder_level=3, channels=3): 75 | 76 | model = _segnet(n_classes, get_resnet50_encoder, input_height=input_height, 77 | input_width=input_width, encoder_level=encoder_level, channels=channels) 78 | model.model_name = "resnet50_segnet" 79 | return model 80 | 81 | 82 | def 
mobilenet_segnet(n_classes, input_height=224, input_width=224, 83 | encoder_level=3, channels=3): 84 | 85 | model = _segnet(n_classes, get_mobilenet_encoder, 86 | input_height=input_height, 87 | input_width=input_width, encoder_level=encoder_level, channels=channels) 88 | model.model_name = "mobilenet_segnet" 89 | return model 90 | 91 | 92 | if __name__ == '__main__': 93 | m = vgg_segnet(101) 94 | m = segnet(101) 95 | # m = mobilenet_segnet( 101 ) 96 | # from keras.utils import plot_model 97 | # plot_model( m , show_shapes=True , to_file='model.png') 98 | -------------------------------------------------------------------------------- /keras_segmentation/models/unet.py: -------------------------------------------------------------------------------- 1 | from keras.models import * 2 | from keras.layers import * 3 | 4 | from .config import IMAGE_ORDERING 5 | from .model_utils import get_segmentation_model 6 | from .vgg16 import get_vgg_encoder 7 | from .mobilenet import get_mobilenet_encoder 8 | from .basic_models import vanilla_encoder 9 | from .resnet50 import get_resnet50_encoder 10 | 11 | 12 | if IMAGE_ORDERING == 'channels_first': 13 | MERGE_AXIS = 1 14 | elif IMAGE_ORDERING == 'channels_last': 15 | MERGE_AXIS = -1 16 | 17 | 18 | def unet_mini(n_classes, input_height=360, input_width=480, channels=3): 19 | 20 | if IMAGE_ORDERING == 'channels_first': 21 | img_input = Input(shape=(channels, input_height, input_width)) 22 | elif IMAGE_ORDERING == 'channels_last': 23 | img_input = Input(shape=(input_height, input_width, channels)) 24 | 25 | conv1 = Conv2D(32, (3, 3), data_format=IMAGE_ORDERING, 26 | activation='relu', padding='same')(img_input) 27 | conv1 = Dropout(0.2)(conv1) 28 | conv1 = Conv2D(32, (3, 3), data_format=IMAGE_ORDERING, 29 | activation='relu', padding='same')(conv1) 30 | pool1 = MaxPooling2D((2, 2), data_format=IMAGE_ORDERING)(conv1) 31 | 32 | conv2 = Conv2D(64, (3, 3), data_format=IMAGE_ORDERING, 33 | activation='relu', padding='same')(pool1) 34 | conv2 = Dropout(0.2)(conv2) 35 | conv2 = Conv2D(64, (3, 3), data_format=IMAGE_ORDERING, 36 | activation='relu', padding='same')(conv2) 37 | pool2 = MaxPooling2D((2, 2), data_format=IMAGE_ORDERING)(conv2) 38 | 39 | conv3 = Conv2D(128, (3, 3), data_format=IMAGE_ORDERING, 40 | activation='relu', padding='same')(pool2) 41 | conv3 = Dropout(0.2)(conv3) 42 | conv3 = Conv2D(128, (3, 3), data_format=IMAGE_ORDERING, 43 | activation='relu', padding='same')(conv3) 44 | 45 | up1 = concatenate([UpSampling2D((2, 2), data_format=IMAGE_ORDERING)( 46 | conv3), conv2], axis=MERGE_AXIS) 47 | conv4 = Conv2D(64, (3, 3), data_format=IMAGE_ORDERING, 48 | activation='relu', padding='same')(up1) 49 | conv4 = Dropout(0.2)(conv4) 50 | conv4 = Conv2D(64, (3, 3), data_format=IMAGE_ORDERING, 51 | activation='relu', padding='same')(conv4) 52 | 53 | up2 = concatenate([UpSampling2D((2, 2), data_format=IMAGE_ORDERING)( 54 | conv4), conv1], axis=MERGE_AXIS) 55 | conv5 = Conv2D(32, (3, 3), data_format=IMAGE_ORDERING, 56 | activation='relu', padding='same')(up2) 57 | conv5 = Dropout(0.2)(conv5) 58 | conv5 = Conv2D(32, (3, 3), data_format=IMAGE_ORDERING, 59 | activation='relu', padding='same' , name="seg_feats")(conv5) 60 | 61 | o = Conv2D(n_classes, (1, 1), data_format=IMAGE_ORDERING, 62 | padding='same')(conv5) 63 | 64 | model = get_segmentation_model(img_input, o) 65 | model.model_name = "unet_mini" 66 | return model 67 | 68 | 69 | def _unet(n_classes, encoder, l1_skip_conn=True, input_height=416, 70 | input_width=608, channels=3): 71 | 72 | img_input, levels = 
encoder( 73 | input_height=input_height, input_width=input_width, channels=channels) 74 | [f1, f2, f3, f4, f5] = levels 75 | 76 | o = f4 77 | 78 | o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) 79 | o = (Conv2D(512, (3, 3), padding='valid' , activation='relu' , data_format=IMAGE_ORDERING))(o) 80 | o = (BatchNormalization())(o) 81 | 82 | o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) 83 | o = (concatenate([o, f3], axis=MERGE_AXIS)) 84 | o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) 85 | o = (Conv2D(256, (3, 3), padding='valid', activation='relu' , data_format=IMAGE_ORDERING))(o) 86 | o = (BatchNormalization())(o) 87 | 88 | o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) 89 | o = (concatenate([o, f2], axis=MERGE_AXIS)) 90 | o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) 91 | o = (Conv2D(128, (3, 3), padding='valid' , activation='relu' , data_format=IMAGE_ORDERING))(o) 92 | o = (BatchNormalization())(o) 93 | 94 | o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o) 95 | 96 | if l1_skip_conn: 97 | o = (concatenate([o, f1], axis=MERGE_AXIS)) 98 | 99 | o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o) 100 | o = (Conv2D(64, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING, name="seg_feats"))(o) 101 | o = (BatchNormalization())(o) 102 | 103 | o = Conv2D(n_classes, (3, 3), padding='same', 104 | data_format=IMAGE_ORDERING)(o) 105 | 106 | model = get_segmentation_model(img_input, o) 107 | 108 | return model 109 | 110 | 111 | def unet(n_classes, input_height=416, input_width=608, encoder_level=3, channels=3): 112 | 113 | model = _unet(n_classes, vanilla_encoder, 114 | input_height=input_height, input_width=input_width, channels=channels) 115 | model.model_name = "unet" 116 | return model 117 | 118 | 119 | def vgg_unet(n_classes, input_height=416, input_width=608, encoder_level=3, channels=3): 120 | 121 | model = _unet(n_classes, get_vgg_encoder, 122 | input_height=input_height, input_width=input_width, channels=channels) 123 | model.model_name = "vgg_unet" 124 | return model 125 | 126 | 127 | def resnet50_unet(n_classes, input_height=416, input_width=608, 128 | encoder_level=3, channels=3): 129 | 130 | model = _unet(n_classes, get_resnet50_encoder, 131 | input_height=input_height, input_width=input_width, channels=channels) 132 | model.model_name = "resnet50_unet" 133 | return model 134 | 135 | 136 | def mobilenet_unet(n_classes, input_height=224, input_width=224, 137 | encoder_level=3, channels=3): 138 | 139 | model = _unet(n_classes, get_mobilenet_encoder, 140 | input_height=input_height, input_width=input_width, channels=channels) 141 | model.model_name = "mobilenet_unet" 142 | return model 143 | 144 | 145 | if __name__ == '__main__': 146 | m = unet_mini(101) 147 | m = _unet(101, vanilla_encoder) 148 | # m = _unet( 101 , get_mobilenet_encoder ,True , 224 , 224 ) 149 | m = _unet(101, get_vgg_encoder) 150 | m = _unet(101, get_resnet50_encoder) 151 | -------------------------------------------------------------------------------- /keras_segmentation/models/vgg16.py: -------------------------------------------------------------------------------- 1 | import keras 2 | from keras.models import * 3 | from keras.layers import * 4 | import tensorflow as tf 5 | from .config import IMAGE_ORDERING 6 | 7 | if IMAGE_ORDERING == 'channels_first': 8 | pretrained_url = "https://github.com/fchollet/deep-learning-models/" \ 9 | "releases/download/v0.1/" \ 10 | "vgg16_weights_th_dim_ordering_th_kernels_notop.h5" 11 | elif 
IMAGE_ORDERING == 'channels_last': 12 | pretrained_url = "https://github.com/fchollet/deep-learning-models/" \ 13 | "releases/download/v0.1/" \ 14 | "vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5" 15 | 16 | 17 | def get_vgg_encoder(input_height=224, input_width=224, pretrained='imagenet', channels=3): 18 | 19 | assert input_height % 32 == 0 20 | assert input_width % 32 == 0 21 | 22 | if IMAGE_ORDERING == 'channels_first': 23 | img_input = Input(shape=(channels, input_height, input_width)) 24 | elif IMAGE_ORDERING == 'channels_last': 25 | img_input = Input(shape=(input_height, input_width, channels)) 26 | 27 | x = Conv2D(64, (3, 3), activation='relu', padding='same', 28 | name='block1_conv1', data_format=IMAGE_ORDERING)(img_input) 29 | x = Conv2D(64, (3, 3), activation='relu', padding='same', 30 | name='block1_conv2', data_format=IMAGE_ORDERING)(x) 31 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool', 32 | data_format=IMAGE_ORDERING)(x) 33 | f1 = x 34 | # Block 2 35 | x = Conv2D(128, (3, 3), activation='relu', padding='same', 36 | name='block2_conv1', data_format=IMAGE_ORDERING)(x) 37 | x = Conv2D(128, (3, 3), activation='relu', padding='same', 38 | name='block2_conv2', data_format=IMAGE_ORDERING)(x) 39 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool', 40 | data_format=IMAGE_ORDERING)(x) 41 | f2 = x 42 | 43 | # Block 3 44 | x = Conv2D(256, (3, 3), activation='relu', padding='same', 45 | name='block3_conv1', data_format=IMAGE_ORDERING)(x) 46 | x = Conv2D(256, (3, 3), activation='relu', padding='same', 47 | name='block3_conv2', data_format=IMAGE_ORDERING)(x) 48 | x = Conv2D(256, (3, 3), activation='relu', padding='same', 49 | name='block3_conv3', data_format=IMAGE_ORDERING)(x) 50 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool', 51 | data_format=IMAGE_ORDERING)(x) 52 | f3 = x 53 | 54 | # Block 4 55 | x = Conv2D(512, (3, 3), activation='relu', padding='same', 56 | name='block4_conv1', data_format=IMAGE_ORDERING)(x) 57 | x = Conv2D(512, (3, 3), activation='relu', padding='same', 58 | name='block4_conv2', data_format=IMAGE_ORDERING)(x) 59 | x = Conv2D(512, (3, 3), activation='relu', padding='same', 60 | name='block4_conv3', data_format=IMAGE_ORDERING)(x) 61 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool', 62 | data_format=IMAGE_ORDERING)(x) 63 | f4 = x 64 | 65 | # Block 5 66 | x = Conv2D(512, (3, 3), activation='relu', padding='same', 67 | name='block5_conv1', data_format=IMAGE_ORDERING)(x) 68 | x = Conv2D(512, (3, 3), activation='relu', padding='same', 69 | name='block5_conv2', data_format=IMAGE_ORDERING)(x) 70 | x = Conv2D(512, (3, 3), activation='relu', padding='same', 71 | name='block5_conv3', data_format=IMAGE_ORDERING)(x) 72 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool', 73 | data_format=IMAGE_ORDERING)(x) 74 | f5 = x 75 | 76 | if pretrained == 'imagenet': 77 | VGG_Weights_path = tf.keras.utils.get_file( 78 | pretrained_url.split("/")[-1], pretrained_url) 79 | Model(img_input, x).load_weights(VGG_Weights_path, by_name=True, skip_mismatch=True) 80 | 81 | return img_input, [f1, f2, f3, f4, f5] 82 | -------------------------------------------------------------------------------- /keras_segmentation/predict.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import random 3 | import json 4 | import os 5 | import six 6 | 7 | import cv2 8 | import numpy as np 9 | from tqdm import tqdm 10 | from time import time 11 | 12 | from .train import find_latest_checkpoint 13 | from 
.data_utils.data_loader import get_image_array, get_segmentation_array,\ 14 | DATA_LOADER_SEED, class_colors, get_pairs_from_paths 15 | from .models.config import IMAGE_ORDERING 16 | 17 | 18 | random.seed(DATA_LOADER_SEED) 19 | 20 | 21 | def model_from_checkpoint_path(checkpoints_path): 22 | 23 | from .models.all_models import model_from_name 24 | assert (os.path.isfile(checkpoints_path+"_config.json") 25 | ), "Checkpoint config file not found." 26 | with open(checkpoints_path+"_config.json", "r") as config_file: 27 | model_config = json.loads(config_file.read()) 28 | latest_weights = find_latest_checkpoint(checkpoints_path) 29 | assert (latest_weights is not None), "Checkpoint weights not found." 30 | model = model_from_name[model_config['model_class']]( 31 | model_config['n_classes'], input_height=model_config['input_height'], 32 | input_width=model_config['input_width']) 33 | print("loaded weights ", latest_weights) 34 | status = model.load_weights(latest_weights) 35 | 36 | if status is not None: 37 | status.expect_partial() 38 | 39 | return model 40 | 41 | 42 | def get_colored_segmentation_image(seg_arr, n_classes, colors=class_colors): 43 | output_height = seg_arr.shape[0] 44 | output_width = seg_arr.shape[1] 45 | 46 | seg_img = np.zeros((output_height, output_width, 3)) 47 | 48 | for c in range(n_classes): 49 | seg_arr_c = seg_arr[:, :] == c 50 | seg_img[:, :, 0] += ((seg_arr_c)*(colors[c][0])).astype('uint8') 51 | seg_img[:, :, 1] += ((seg_arr_c)*(colors[c][1])).astype('uint8') 52 | seg_img[:, :, 2] += ((seg_arr_c)*(colors[c][2])).astype('uint8') 53 | 54 | return seg_img 55 | 56 | 57 | def get_legends(class_names, colors=class_colors): 58 | 59 | n_classes = len(class_names) 60 | legend = np.zeros(((len(class_names) * 25) + 25, 125, 3), 61 | dtype="uint8") + 255 62 | 63 | class_names_colors = enumerate(zip(class_names[:n_classes], 64 | colors[:n_classes])) 65 | 66 | for (i, (class_name, color)) in class_names_colors: 67 | color = [int(c) for c in color] 68 | cv2.putText(legend, class_name, (5, (i * 25) + 17), 69 | cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1) 70 | cv2.rectangle(legend, (100, (i * 25)), (125, (i * 25) + 25), 71 | tuple(color), -1) 72 | 73 | return legend 74 | 75 | 76 | def overlay_seg_image(inp_img, seg_img): 77 | original_h = inp_img.shape[0] 78 | original_w = inp_img.shape[1] 79 | seg_img = cv2.resize(seg_img, (original_w, original_h), interpolation=cv2.INTER_NEAREST) 80 | 81 | fused_img = (inp_img/2 + seg_img/2).astype('uint8') 82 | return fused_img 83 | 84 | 85 | def concat_legends(seg_img, legend_img): 86 | 87 | new_h = np.maximum(seg_img.shape[0], legend_img.shape[0]) 88 | new_w = seg_img.shape[1] + legend_img.shape[1] 89 | 90 | out_img = np.zeros((new_h, new_w, 3)).astype('uint8') + legend_img[0, 0, 0] 91 | 92 | out_img[:legend_img.shape[0], : legend_img.shape[1]] = np.copy(legend_img) 93 | out_img[:seg_img.shape[0], legend_img.shape[1]:] = np.copy(seg_img) 94 | 95 | return out_img 96 | 97 | 98 | def visualize_segmentation(seg_arr, inp_img=None, n_classes=None, 99 | colors=class_colors, class_names=None, 100 | overlay_img=False, show_legends=False, 101 | prediction_width=None, prediction_height=None): 102 | 103 | if n_classes is None: 104 | n_classes = np.max(seg_arr) 105 | 106 | seg_img = get_colored_segmentation_image(seg_arr, n_classes, colors=colors) 107 | 108 | if inp_img is not None: 109 | original_h = inp_img.shape[0] 110 | original_w = inp_img.shape[1] 111 | seg_img = cv2.resize(seg_img, (original_w, original_h), interpolation=cv2.INTER_NEAREST) 112 | 113 | if (prediction_height is not None) and 
(prediction_width is not None): 114 | seg_img = cv2.resize(seg_img, (prediction_width, prediction_height), interpolation=cv2.INTER_NEAREST) 115 | if inp_img is not None: 116 | inp_img = cv2.resize(inp_img, 117 | (prediction_width, prediction_height)) 118 | 119 | if overlay_img: 120 | assert inp_img is not None 121 | seg_img = overlay_seg_image(inp_img, seg_img) 122 | 123 | if show_legends: 124 | assert class_names is not None 125 | legend_img = get_legends(class_names, colors=colors) 126 | 127 | seg_img = concat_legends(seg_img, legend_img) 128 | 129 | return seg_img 130 | 131 | 132 | def predict(model=None, inp=None, out_fname=None, 133 | checkpoints_path=None, overlay_img=False, 134 | class_names=None, show_legends=False, colors=class_colors, 135 | prediction_width=None, prediction_height=None, 136 | read_image_type=1): 137 | 138 | if model is None and (checkpoints_path is not None): 139 | model = model_from_checkpoint_path(checkpoints_path) 140 | 141 | assert (inp is not None) 142 | assert ((type(inp) is np.ndarray) or isinstance(inp, six.string_types)),\ 143 | "Input should be a NumPy array (CV image) or an input file name" 144 | 145 | if isinstance(inp, six.string_types): 146 | inp = cv2.imread(inp, read_image_type) 147 | 148 | assert (len(inp.shape) == 3 or len(inp.shape) == 1 or len(inp.shape) == 4), "Image should be h,w,3" 149 | 150 | output_width = model.output_width 151 | output_height = model.output_height 152 | input_width = model.input_width 153 | input_height = model.input_height 154 | n_classes = model.n_classes 155 | 156 | x = get_image_array(inp, input_width, input_height, 157 | ordering=IMAGE_ORDERING) 158 | pr = model.predict(np.array([x]))[0] 159 | pr = pr.reshape((output_height, output_width, n_classes)).argmax(axis=2) 160 | 161 | seg_img = visualize_segmentation(pr, inp, n_classes=n_classes, 162 | colors=colors, overlay_img=overlay_img, 163 | show_legends=show_legends, 164 | class_names=class_names, 165 | prediction_width=prediction_width, 166 | prediction_height=prediction_height) 167 | 168 | if out_fname is not None: 169 | cv2.imwrite(out_fname, seg_img) 170 | 171 | return pr 172 | 173 | 174 | def predict_multiple(model=None, inps=None, inp_dir=None, out_dir=None, 175 | checkpoints_path=None, overlay_img=False, 176 | class_names=None, show_legends=False, colors=class_colors, 177 | prediction_width=None, prediction_height=None, read_image_type=1): 178 | 179 | if model is None and (checkpoints_path is not None): 180 | model = model_from_checkpoint_path(checkpoints_path) 181 | 182 | if inps is None and (inp_dir is not None): 183 | inps = glob.glob(os.path.join(inp_dir, "*.jpg")) + glob.glob( 184 | os.path.join(inp_dir, "*.png")) + \ 185 | glob.glob(os.path.join(inp_dir, "*.jpeg")) 186 | inps = sorted(inps) 187 | 188 | assert type(inps) is list 189 | 190 | all_prs = [] 191 | 192 | if out_dir is not None: 193 | if not os.path.exists(out_dir): 194 | os.makedirs(out_dir) 195 | 196 | 197 | for i, inp in enumerate(tqdm(inps)): 198 | if out_dir is None: 199 | out_fname = None 200 | else: 201 | if isinstance(inp, six.string_types): 202 | out_fname = os.path.join(out_dir, os.path.basename(inp)) 203 | else: 204 | out_fname = os.path.join(out_dir, str(i) + ".jpg") 205 | 206 | pr = predict(model, inp, out_fname, 207 | overlay_img=overlay_img, class_names=class_names, 208 | show_legends=show_legends, colors=colors, 209 | prediction_width=prediction_width, 210 | prediction_height=prediction_height, read_image_type=read_image_type) 211 | 212 | all_prs.append(pr) 213 | 214 | return all_prs 215
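For reference, the `predict` and `predict_multiple` functions above can be driven either with an in-memory `model` or with the `checkpoints_path` prefix written by `train` (see `keras_segmentation/train.py` below). A minimal usage sketch, assuming a hypothetical `/tmp/vgg_unet_1` checkpoint prefix and placeholder image paths:

```python
# Usage sketch for the predict API above; the checkpoint prefix and the
# image paths are hypothetical placeholders.
from keras_segmentation.predict import predict, predict_multiple

# Single image: returns the (output_height, output_width) class-index map
# and optionally writes a colorized visualization to out_fname.
pr = predict(checkpoints_path="/tmp/vgg_unet_1",
             inp="images_prepped_test/0016E5_07959.png",
             out_fname="/tmp/out.png")

# A whole directory: one visualization per *.jpg/*.png/*.jpeg input.
predict_multiple(checkpoints_path="/tmp/vgg_unet_1",
                 inp_dir="images_prepped_test/",
                 out_dir="/tmp/outputs/")
```

Note that `checkpoints_path` is a file prefix rather than a directory: `model_from_checkpoint_path` expects a `<prefix>_config.json` sidecar plus at least one `<prefix>.<epoch>` weights file next to it.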
| 216 | 217 | def set_video(inp, video_name): 218 | cap = cv2.VideoCapture(inp) 219 | fps = int(cap.get(cv2.CAP_PROP_FPS)) 220 | video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 221 | video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 222 | size = (video_width, video_height) 223 | fourcc = cv2.VideoWriter_fourcc(*"XVID") 224 | video = cv2.VideoWriter(video_name, fourcc, fps, size) 225 | return cap, video, fps 226 | 227 | 228 | def predict_video(model=None, inp=None, output=None, 229 | checkpoints_path=None, display=False, overlay_img=True, 230 | class_names=None, show_legends=False, colors=class_colors, 231 | prediction_width=None, prediction_height=None): 232 | 233 | if model is None and (checkpoints_path is not None): 234 | model = model_from_checkpoint_path(checkpoints_path) 235 | n_classes = model.n_classes 236 | 237 | cap, video, fps = set_video(inp, output) 238 | while cap.isOpened(): 239 | prev_time = time() 240 | ret, frame = cap.read() 241 | if frame is not None: 242 | pr = predict(model=model, inp=frame) 243 | fused_img = visualize_segmentation( 244 | pr, frame, n_classes=n_classes, 245 | colors=colors, 246 | overlay_img=overlay_img, 247 | show_legends=show_legends, 248 | class_names=class_names, 249 | prediction_width=prediction_width, 250 | prediction_height=prediction_height 251 | ) 252 | else: 253 | break 254 | print("FPS: {}".format(1/(time() - prev_time))) 255 | if output is not None: 256 | video.write(fused_img) 257 | if display: 258 | cv2.imshow('Frame masked', fused_img) 259 | if cv2.waitKey(fps) & 0xFF == ord('q'): 260 | break 261 | cap.release() 262 | if output is not None: 263 | video.release() 264 | cv2.destroyAllWindows() 265 | 266 | 267 | def evaluate(model=None, inp_images=None, annotations=None, 268 | inp_images_dir=None, annotations_dir=None, checkpoints_path=None, read_image_type=1): 269 | 270 | if model is None: 271 | assert (checkpoints_path is not None),\ 272 | "Please provide the model or the checkpoints_path" 273 | model = model_from_checkpoint_path(checkpoints_path) 274 | 275 | if inp_images is None: 276 | assert (inp_images_dir is not None),\ 277 | "Please provide inp_images or inp_images_dir" 278 | assert (annotations_dir is not None),\ 279 | "Please provide annotations or annotations_dir" 280 | 281 | paths = get_pairs_from_paths(inp_images_dir, annotations_dir) 282 | paths = list(zip(*paths)) 283 | inp_images = list(paths[0]) 284 | annotations = list(paths[1]) 285 | 286 | assert type(inp_images) is list 287 | assert type(annotations) is list 288 | 289 | tp = np.zeros(model.n_classes) 290 | fp = np.zeros(model.n_classes) 291 | fn = np.zeros(model.n_classes) 292 | n_pixels = np.zeros(model.n_classes) 293 | 294 | for inp, ann in tqdm(zip(inp_images, annotations)): 295 | pr = predict(model, inp, read_image_type=read_image_type) 296 | gt = get_segmentation_array(ann, model.n_classes, 297 | model.output_width, model.output_height, 298 | no_reshape=True, read_image_type=read_image_type) 299 | gt = gt.argmax(-1) 300 | pr = pr.flatten() 301 | gt = gt.flatten() 302 | 303 | for cl_i in range(model.n_classes): 304 | 305 | tp[cl_i] += np.sum((pr == cl_i) * (gt == cl_i)) 306 | fp[cl_i] += np.sum((pr == cl_i) * (gt != cl_i)) 307 | fn[cl_i] += np.sum((pr != cl_i) * (gt == cl_i)) 308 | n_pixels[cl_i] += np.sum(gt == cl_i) 309 | 310 | cl_wise_score = tp / (tp + fp + fn + 1e-12) 311 | n_pixels_norm = n_pixels / np.sum(n_pixels) 312 | frequency_weighted_IU = np.sum(cl_wise_score*n_pixels_norm) 313 | mean_IU = np.mean(cl_wise_score) 314 | 315 | return { 
316 | "frequency_weighted_IU": frequency_weighted_IU, 317 | "mean_IU": mean_IU, 318 | "class_wise_IU": cl_wise_score 319 | } 320 | -------------------------------------------------------------------------------- /keras_segmentation/pretrained.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import tensorflow as tf 3 | from .models.all_models import model_from_name 4 | 5 | 6 | def model_from_checkpoint_path(model_config, latest_weights): 7 | 8 | model = model_from_name[model_config['model_class']]( 9 | model_config['n_classes'], input_height=model_config['input_height'], 10 | input_width=model_config['input_width']) 11 | model.load_weights(latest_weights) 12 | return model 13 | 14 | 15 | def resnet_pspnet_VOC12_v0_1(): 16 | 17 | model_config = { 18 | "output_height": 96, 19 | "input_height": 384, 20 | "input_width": 576, 21 | "n_classes": 151, 22 | "model_class": "resnet50_pspnet", 23 | "output_width": 144 24 | } 25 | 26 | REPO_URL = "https://github.com/divamgupta/image-segmentation-keras" 27 | MODEL_PATH = "pretrained_model_1/r2_voc12_resnetpspnet_384x576.24" 28 | model_url = "{0}/releases/download/{1}".format(REPO_URL, MODEL_PATH) 29 | latest_weights = tf.keras.utils.get_file(model_url.split("/")[-1], model_url) 30 | 31 | return model_from_checkpoint_path(model_config, latest_weights) 32 | 33 | 34 | # pretrained model converted from caffe by Vladkryvoruchko ... thanks ! 35 | def pspnet_50_ADE_20K(): 36 | 37 | model_config = { 38 | "input_height": 473, 39 | "input_width": 473, 40 | "n_classes": 150, 41 | "model_class": "pspnet_50", 42 | } 43 | 44 | model_url = "https://www.dropbox.com/s/" \ 45 | "0uxn14y26jcui4v/pspnet50_ade20k.h5?dl=1" 46 | latest_weights = tf.keras.utils.get_file("pspnet50_ade20k.h5", model_url) 47 | 48 | return model_from_checkpoint_path(model_config, latest_weights) 49 | 50 | 51 | def pspnet_101_cityscapes(): 52 | 53 | model_config = { 54 | "input_height": 713, 55 | "input_width": 713, 56 | "n_classes": 19, 57 | "model_class": "pspnet_101", 58 | } 59 | 60 | model_url = "https://www.dropbox.com/s/" \ 61 | "c17g94n946tpalb/pspnet101_cityscapes.h5?dl=1" 62 | latest_weights = tf.keras.utils.get_file("pspnet101_cityscapes.h5", model_url) 63 | 64 | return model_from_checkpoint_path(model_config, latest_weights) 65 | 66 | 67 | def pspnet_101_voc12(): 68 | 69 | model_config = { 70 | "input_height": 473, 71 | "input_width": 473, 72 | "n_classes": 21, 73 | "model_class": "pspnet_101", 74 | } 75 | 76 | model_url = "https://www.dropbox.com/s/" \ 77 | "uvqj2cjo4b9c5wg/pspnet101_voc2012.h5?dl=1" 78 | latest_weights = tf.keras.utils.get_file("pspnet101_voc2012.h5", model_url) 79 | 80 | return model_from_checkpoint_path(model_config, latest_weights) 81 | -------------------------------------------------------------------------------- /keras_segmentation/train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from .data_utils.data_loader import image_segmentation_generator, \ 5 | verify_segmentation_dataset 6 | import six 7 | from keras.callbacks import Callback 8 | from keras.callbacks import ModelCheckpoint 9 | import tensorflow as tf 10 | import glob 11 | import sys 12 | 13 | def find_latest_checkpoint(checkpoints_path, fail_safe=True): 14 | 15 | # This is legacy code, there should always be a "checkpoint" file in your directory 16 | 17 | def get_epoch_number_from_path(path): 18 | return path.replace(checkpoints_path, "").strip(".") 19 | 20 | # Get all 
matching files 21 | all_checkpoint_files = glob.glob(checkpoints_path + ".*") 22 | if len(all_checkpoint_files) == 0: 23 | all_checkpoint_files = glob.glob(checkpoints_path + "*.*") 24 | all_checkpoint_files = [ff.replace(".index", "") for ff in 25 | all_checkpoint_files] # to make it work for newer versions of keras 26 | # Keep only entries whose epoch-number part is a pure number 27 | all_checkpoint_files = list(filter(lambda f: get_epoch_number_from_path(f) 28 | .isdigit(), all_checkpoint_files)) 29 | if not len(all_checkpoint_files): 30 | # No checkpoint files exist for this checkpoints_path 31 | if not fail_safe: 32 | raise ValueError("Checkpoint path {0} invalid" 33 | .format(checkpoints_path)) 34 | else: 35 | return None 36 | 37 | # Find the checkpoint file with the maximum epoch 38 | latest_epoch_checkpoint = max(all_checkpoint_files, 39 | key=lambda f: 40 | int(get_epoch_number_from_path(f))) 41 | 42 | return latest_epoch_checkpoint 43 | 44 | def masked_categorical_crossentropy(gt, pr): 45 | from keras.losses import categorical_crossentropy 46 | mask = 1 - gt[:, :, 0]  # pixels whose ground truth is class 0 contribute nothing to the loss 47 | return categorical_crossentropy(gt, pr) * mask 48 | 49 | 50 | class CheckpointsCallback(Callback): 51 | def __init__(self, checkpoints_path): 52 | self.checkpoints_path = checkpoints_path 53 | 54 | def on_epoch_end(self, epoch, logs=None): 55 | if self.checkpoints_path is not None: 56 | self.model.save_weights(self.checkpoints_path + "." + str(epoch)) 57 | print("saved ", self.checkpoints_path + "." + str(epoch)) 58 | 59 | 60 | def train(model, 61 | train_images, 62 | train_annotations, 63 | input_height=None, 64 | input_width=None, 65 | n_classes=None, 66 | verify_dataset=True, 67 | checkpoints_path=None, 68 | epochs=5, 69 | batch_size=2, 70 | validate=False, 71 | val_images=None, 72 | val_annotations=None, 73 | val_batch_size=2, 74 | auto_resume_checkpoint=False, 75 | load_weights=None, 76 | steps_per_epoch=512, 77 | val_steps_per_epoch=512, 78 | gen_use_multiprocessing=False, 79 | ignore_zero_class=False, 80 | optimizer_name='adam', 81 | do_augment=False, 82 | augmentation_name="aug_all", 83 | callbacks=None, 84 | custom_augmentation=None, 85 | other_inputs_paths=None, 86 | preprocessing=None, 87 | read_image_type=1 # cv2.IMREAD_COLOR = 1 (3-channel BGR), 88 | # cv2.IMREAD_GRAYSCALE = 0, 89 | # cv2.IMREAD_UNCHANGED = -1 (4 channels like RGBA) 90 | ): 91 | from .models.all_models import model_from_name 92 | # check if user gives model name instead of the model object 93 | if isinstance(model, six.string_types): 94 | # create the model from the name 95 | assert (n_classes is not None), "Please provide the n_classes" 96 | if (input_height is not None) and (input_width is not None): 97 | model = model_from_name[model]( 98 | n_classes, input_height=input_height, input_width=input_width) 99 | else: 100 | model = model_from_name[model](n_classes) 101 | 102 | n_classes = model.n_classes 103 | input_height = model.input_height 104 | input_width = model.input_width 105 | output_height = model.output_height 106 | output_width = model.output_width 107 | 108 | if validate: 109 | assert val_images is not None 110 | assert val_annotations is not None 111 | 112 | if optimizer_name is not None: 113 | 114 | if ignore_zero_class: 115 | loss_k = masked_categorical_crossentropy 116 | else: 117 | loss_k = 'categorical_crossentropy' 118 | 119 | model.compile(loss=loss_k, 120 | optimizer=optimizer_name, 121 | metrics=['accuracy']) 122 | 123 | if checkpoints_path is not None: 124 | config_file = checkpoints_path + "_config.json" 125 | 
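# The sidecar written below is what predict.model_from_checkpoint_path()
# reads back to rebuild the model before loading weights. For a
# hypothetical checkpoints_path of "/tmp/vgg_unet_1", training writes
# "/tmp/vgg_unet_1_config.json" containing values like:
# {"model_class": "vgg_unet", "n_classes": 51, "input_height": 416,
#  "input_width": 608, "output_height": 208, "output_width": 304}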
dir_name = os.path.dirname(config_file) 126 | 127 | if (not os.path.exists(dir_name)) and len(dir_name) > 0: 128 | os.makedirs(dir_name) 129 | 130 | with open(config_file, "w") as f: 131 | json.dump({ 132 | "model_class": model.model_name, 133 | "n_classes": n_classes, 134 | "input_height": input_height, 135 | "input_width": input_width, 136 | "output_height": output_height, 137 | "output_width": output_width 138 | }, f) 139 | 140 | if load_weights is not None and len(load_weights) > 0: 141 | print("Loading weights from ", load_weights) 142 | model.load_weights(load_weights) 143 | 144 | initial_epoch = 0 145 | 146 | if auto_resume_checkpoint and (checkpoints_path is not None): 147 | latest_checkpoint = find_latest_checkpoint(checkpoints_path) 148 | if latest_checkpoint is not None: 149 | print("Loading the weights from latest checkpoint ", 150 | latest_checkpoint) 151 | model.load_weights(latest_checkpoint) 152 | 153 | initial_epoch = int(latest_checkpoint.split('.')[-1]) 154 | 155 | if verify_dataset: 156 | print("Verifying training dataset") 157 | verified = verify_segmentation_dataset(train_images, 158 | train_annotations, 159 | n_classes) 160 | assert verified 161 | if validate: 162 | print("Verifying validation dataset") 163 | verified = verify_segmentation_dataset(val_images, 164 | val_annotations, 165 | n_classes) 166 | assert verified 167 | 168 | train_gen = image_segmentation_generator( 169 | train_images, train_annotations, batch_size, n_classes, 170 | input_height, input_width, output_height, output_width, 171 | do_augment=do_augment, augmentation_name=augmentation_name, 172 | custom_augmentation=custom_augmentation, other_inputs_paths=other_inputs_paths, 173 | preprocessing=preprocessing, read_image_type=read_image_type) 174 | 175 | if validate: 176 | val_gen = image_segmentation_generator( 177 | val_images, val_annotations, val_batch_size, 178 | n_classes, input_height, input_width, output_height, output_width, 179 | other_inputs_paths=other_inputs_paths, 180 | preprocessing=preprocessing, read_image_type=read_image_type) 181 | 182 | if callbacks is None and (checkpoints_path is not None): 183 | default_callback = ModelCheckpoint( 184 | filepath=checkpoints_path + ".{epoch:05d}", 185 | save_weights_only=True, 186 | verbose=True 187 | ) 188 | 189 | if sys.version_info[0] < 3: # for Python 2 190 | default_callback = CheckpointsCallback(checkpoints_path) 191 | 192 | callbacks = [ 193 | default_callback 194 | ] 195 | 196 | if callbacks is None: 197 | callbacks = [] 198 | 199 | if not validate: 200 | model.fit(train_gen, steps_per_epoch=steps_per_epoch, 201 | epochs=epochs, callbacks=callbacks, initial_epoch=initial_epoch) 202 | else: 203 | model.fit(train_gen, 204 | steps_per_epoch=steps_per_epoch, 205 | validation_data=val_gen, 206 | validation_steps=val_steps_per_epoch, 207 | epochs=epochs, callbacks=callbacks, 208 | use_multiprocessing=gen_use_multiprocessing, initial_epoch=initial_epoch) 209 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | h5py 3 | tqdm 4 | keras>=2.3.0 5 | tensorflow>=2.2 6 | opencv-python 7 | imageio>=2.5.0 8 | imgaug>=0.4.0 9 | -------------------------------------------------------------------------------- /sample_images/1_input.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/sample_images/1_input.jpg -------------------------------------------------------------------------------- /sample_images/1_output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/sample_images/1_output.png -------------------------------------------------------------------------------- /sample_images/2_input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/sample_images/2_input.jpg -------------------------------------------------------------------------------- /sample_images/2_output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/sample_images/2_output.png -------------------------------------------------------------------------------- /sample_images/3_input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/sample_images/3_input.jpg -------------------------------------------------------------------------------- /sample_images/3_output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/sample_images/3_output.png -------------------------------------------------------------------------------- /sample_images/liner_dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/sample_images/liner_dataset.png -------------------------------------------------------------------------------- /sample_images/liner_export.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/sample_images/liner_export.png -------------------------------------------------------------------------------- /sample_images/liner_testing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/sample_images/liner_testing.png -------------------------------------------------------------------------------- /sample_images/liner_training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/sample_images/liner_training.png -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [egg_info] 2 | tag_build = 3 | tag_date = 0 4 | 5 | [metadata] 6 | description-file = README.md -------------------------------------------------------------------------------- /setup.py: 
-------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | import sys 4 | 5 | cv_ver = "" 6 | keras_ver = ">=2.0.0" 7 | if sys.version_info.major < 3: 8 | cv_ver = "<=4.2.0.32" 9 | keras_ver = "<=2.3.0" 10 | 11 | 12 | setup(name="keras_segmentation", 13 | version="0.3.0", 14 | description="Image Segmentation toolkit for keras", 15 | author="Divam Gupta", 16 | author_email='divamgupta@gmail.com', 17 | platforms=["any"], # or more specific, e.g. "win32", "cygwin", "osx" 18 | license="MIT", # matches the MIT LICENSE file at the repository root 19 | url="https://github.com/divamgupta/image-segmentation-keras", 20 | packages=find_packages(exclude=["test"]), 21 | entry_points={ 22 | 'console_scripts': [ 23 | 'keras_segmentation = keras_segmentation.__main__:main' 24 | ] 25 | }, 26 | install_requires=[ 27 | "h5py<=2.10.0", 28 | "Keras"+keras_ver, 29 | "imageio==2.5.0", 30 | "imgaug>=0.4.0", 31 | "opencv-python"+cv_ver, 32 | "tqdm"], 33 | extras_require={ 34 | # These requires provide different backends available with Keras 35 | "tensorflow": ["tensorflow"], 36 | "cntk": ["cntk"], 37 | "theano": ["theano"], 38 | # Default testing with tensorflow 39 | "tests-default": ["tensorflow", "pytest"] 40 | } 41 | ) 42 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/__init__.py -------------------------------------------------------------------------------- /test/example_dataset/annotations_prepped_test/0016E5_07959.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/annotations_prepped_test/0016E5_07959.png -------------------------------------------------------------------------------- /test/example_dataset/annotations_prepped_test/0016E5_07961.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/annotations_prepped_test/0016E5_07961.png -------------------------------------------------------------------------------- /test/example_dataset/annotations_prepped_test/0016E5_07963.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/annotations_prepped_test/0016E5_07963.png -------------------------------------------------------------------------------- /test/example_dataset/annotations_prepped_train/0001TP_006690.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/annotations_prepped_train/0001TP_006690.png -------------------------------------------------------------------------------- /test/example_dataset/annotations_prepped_train/0001TP_006720.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/annotations_prepped_train/0001TP_006720.png -------------------------------------------------------------------------------- /test/example_dataset/annotations_prepped_train/0001TP_006750.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/annotations_prepped_train/0001TP_006750.png -------------------------------------------------------------------------------- /test/example_dataset/annotations_prepped_train/0001TP_006780.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/annotations_prepped_train/0001TP_006780.png -------------------------------------------------------------------------------- /test/example_dataset/annotations_prepped_train/0001TP_006810.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/annotations_prepped_train/0001TP_006810.png -------------------------------------------------------------------------------- /test/example_dataset/images_prepped_test/0016E5_07959.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/images_prepped_test/0016E5_07959.png -------------------------------------------------------------------------------- /test/example_dataset/images_prepped_test/0016E5_07961.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/images_prepped_test/0016E5_07961.png -------------------------------------------------------------------------------- /test/example_dataset/images_prepped_test/0016E5_07963.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/images_prepped_test/0016E5_07963.png -------------------------------------------------------------------------------- /test/example_dataset/images_prepped_train/0001TP_006690.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/images_prepped_train/0001TP_006690.png -------------------------------------------------------------------------------- /test/example_dataset/images_prepped_train/0001TP_006720.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/images_prepped_train/0001TP_006720.png -------------------------------------------------------------------------------- /test/example_dataset/images_prepped_train/0001TP_006750.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/images_prepped_train/0001TP_006750.png -------------------------------------------------------------------------------- /test/example_dataset/images_prepped_train/0001TP_006780.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/images_prepped_train/0001TP_006780.png -------------------------------------------------------------------------------- /test/example_dataset/images_prepped_train/0001TP_006810.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/images_prepped_train/0001TP_006810.png -------------------------------------------------------------------------------- /test/test_models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tempfile 3 | 4 | import sys 5 | 6 | import keras 7 | 8 | from keras_segmentation.models import all_models 9 | from keras_segmentation.data_utils.data_loader import \ 10 | verify_segmentation_dataset, image_segmentation_generator 11 | from keras_segmentation.predict import predict_multiple, predict, evaluate 12 | 13 | 14 | from keras_segmentation.model_compression import perform_distilation 15 | from keras_segmentation.pretrained import pspnet_50_ADE_20K 16 | 17 | tr_im = "test/example_dataset/images_prepped_train" 18 | tr_an = "test/example_dataset/annotations_prepped_train" 19 | te_im = "test/example_dataset/images_prepped_test" 20 | te_an = "test/example_dataset/annotations_prepped_test" 21 | 22 | 23 | 24 | def test_models(): 25 | 26 | n_c = 100 27 | 28 | models = [ ( "unet_mini" , 124 , 156 ) , ( "vgg_unet" , 224 , 224*2 ) , 29 | ( 'resnet50_pspnet', 192*2 , 192*3 ) , ( 'mobilenet_unet', 224 , 224 ), ( 'mobilenet_unet', 224+32 , 224+32 ),( 'segnet', 224 , 224*2 ),( 'vgg_segnet', 224 , 224*2 ) ,( 'fcn_32', 224 , 224*2 ) ,( 'fcn_8_vgg', 224 , 224*2 ) ] 30 | 31 | for model_name , h , w in models: 32 | m = all_models.model_from_name[model_name]( n_c, input_height=h, input_width=w) 33 | 34 | m.train(train_images=tr_im, 35 | train_annotations=tr_an, 36 | steps_per_epoch=2, 37 | epochs=2 ) 38 | 39 | keras.backend.clear_session() 40 | 41 | 42 | 43 | 44 | 45 | def test_verify(): 46 | verify_segmentation_dataset(tr_im, tr_an, 50) 47 | 48 | 49 | def test_datag(): 50 | g = image_segmentation_generator(images_path=tr_im, segs_path=tr_an, 51 | batch_size=3, n_classes=50, 52 | input_height=224, input_width=324, 53 | output_height=114, output_width=134, 54 | do_augment=False) 55 | 56 | x, y = next(g) 57 | assert x.shape[0] == 3 58 | assert y.shape[0] == 3 59 | assert y.shape[-1] == 50 60 | 61 | 62 | # with augmentation 63 | def test_datag2(): 64 | g = image_segmentation_generator(images_path=tr_im, segs_path=tr_an, 65 | batch_size=3, n_classes=50, 66 | input_height=224, input_width=324, 67 | output_height=114, output_width=134, 68 | do_augment=True) 69 | 70 | x, y = next(g) 71 | assert x.shape[0] == 3 72 | assert y.shape[0] == 3 73 | assert y.shape[-1] == 50 74 | 75 | 76 | def test_model(): 77 | model_name = "fcn_8" 78 | h = 224 79 | w = 256 80 | n_c = 100 81 | 
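# tempfile.mktemp() below yields a fresh path that train() treats as a
# checkpoint *prefix*: it writes <prefix>.<epoch> weight files plus a
# <prefix>_config.json sidecar, which predict()/evaluate() resolve again
# through find_latest_checkpoint().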
check_path = tempfile.mktemp() 82 | 83 | m = all_models.model_from_name[model_name]( 84 | n_c, input_height=h, input_width=w) 85 | 86 | m.train(train_images=tr_im, 87 | train_annotations=tr_an, 88 | steps_per_epoch=2, 89 | epochs=2, 90 | checkpoints_path=check_path 91 | ) 92 | 93 | m.train(train_images=tr_im, 94 | train_annotations=tr_an, 95 | steps_per_epoch=2, 96 | epochs=2, 97 | checkpoints_path=check_path, 98 | augmentation_name='aug_geometric', do_augment=True 99 | ) 100 | 101 | m.predict_segmentation(np.zeros((h, w, 3))).shape 102 | 103 | predict_multiple( 104 | inp_dir=te_im, checkpoints_path=check_path, out_dir="/tmp") 105 | predict_multiple(inps=[np.zeros((h, w, 3))]*3, 106 | checkpoints_path=check_path, out_dir="/tmp") 107 | 108 | ev = m.evaluate_segmentation(inp_images_dir=te_im, annotations_dir=te_an) 109 | assert ev['frequency_weighted_IU'] > 0.01 110 | print(ev) 111 | o = predict(inp=np.zeros((h, w, 3)), checkpoints_path=check_path) 112 | 113 | o = predict(inp=np.zeros((h, w, 3)), checkpoints_path=check_path, 114 | overlay_img=True, class_names=['nn']*n_c, show_legends=True) 115 | print("pr") 116 | 117 | o.shape 118 | 119 | ev = evaluate(inp_images_dir=te_im, 120 | annotations_dir=te_an, 121 | checkpoints_path=check_path) 122 | assert ev['frequency_weighted_IU'] > 0.01 123 | 124 | 125 | def test_kd(): 126 | 127 | if sys.version_info.major < 3: 128 | # KD won't work with Python 2 129 | return 130 | 131 | model_name = "fcn_8" 132 | h = 224 133 | w = 256 134 | n_c = 100 135 | check_path1 = tempfile.mktemp() 136 | 137 | m1 = all_models.model_from_name[model_name]( 138 | n_c, input_height=h, input_width=w) 139 | 140 | 141 | 142 | model_name = "unet_mini" 143 | h = 124 144 | w = 156 145 | n_c = 100 146 | check_path2 = tempfile.mktemp() 147 | 148 | m2 = all_models.model_from_name[model_name]( 149 | n_c, input_height=h, input_width=w) 150 | 151 | 152 | m1.train(train_images=tr_im, 153 | train_annotations=tr_an, 154 | steps_per_epoch=2, 155 | epochs=2, 156 | checkpoints_path=check_path1 157 | ) 158 | 159 | perform_distilation(m1, m2, tr_im, distilation_loss='kl', 160 | batch_size=2, checkpoints_path=check_path2, epochs=2, steps_per_epoch=2) 161 | 162 | 163 | perform_distilation(m1, m2, tr_im, distilation_loss='l2', 164 | batch_size=2, checkpoints_path=check_path2, epochs=2, steps_per_epoch=2) 165 | 166 | 167 | perform_distilation(m1, m2, tr_im, distilation_loss='l2', 168 | batch_size=2, checkpoints_path=check_path2, epochs=2, steps_per_epoch=2, feats_distilation_loss='pa') 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | def test_pretrained(): 180 | 181 | 182 | model = pspnet_50_ADE_20K() 183 | 184 | out = model.predict_segmentation( 185 | inp=te_im+"/0016E5_07959.png", 186 | out_fname="/tmp/out.png" 187 | ) 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | # def test_models(): 197 | 198 | 199 | # unet_models = [ , models.unet.vgg_unet , models.unet.resnet50_unet ] 200 | # args = [(101, 416, 608), (101, 224 ,224), (101, 256, 256), (2, 32*4, 32*5)] 201 | # en_level = [ 1,2,3,4 ] 202 | 203 | # for mf in unet_models: 204 | # for en in en_level: 205 | # for ar in args: 206 | # m = mf( *ar , encoder_level=en ) 207 | 208 | 209 | # m = models.unet.mobilenet_unet( 55 ) 210 | # for ar in args: 211 | # m = unet_mini( *ar ) 212 | -------------------------------------------------------------------------------- /test/unit/data_utils/test_augmentation.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/unit/data_utils/test_augmentation.py -------------------------------------------------------------------------------- /test/unit/data_utils/test_data_loader.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import unittest 3 | import tempfile 4 | from shutil import rmtree 5 | import os 6 | import six 7 | from keras_segmentation.data_utils import data_loader 8 | import random 9 | import cv2 10 | from imgaug import augmenters as iaa 11 | import shutil 12 | import numpy as np 13 | 14 | class TestGetPairsFromPaths(unittest.TestCase): 15 | """ Test data loader facilities """ 16 | 17 | def _setup_images_and_segs(self, images, segs): 18 | for file_name in images: 19 | open(os.path.join(self.img_path, file_name), 'a').close() 20 | for file_name in segs: 21 | open(os.path.join(self.seg_path, file_name), 'a').close() 22 | 23 | @classmethod 24 | def _cleanup_folder(cls, folder_path): 25 | return rmtree(folder_path) 26 | 27 | def setUp(self): 28 | self.tmp_dir = tempfile.mkdtemp() 29 | self.img_path = os.path.join(self.tmp_dir, "images") 30 | self.seg_path = os.path.join(self.tmp_dir, "segs") 31 | os.mkdir(self.img_path) 32 | os.mkdir(self.seg_path) 33 | 34 | def tearDown(self): 35 | rmtree(self.tmp_dir) 36 | 37 | def test_get_pairs_from_paths_1(self): 38 | """ Normal execution """ 39 | images = ["A.jpg", "B.jpg", "C.jpeg", "D.png"] 40 | segs = ["A.png", "B.png", "C.png", "D.png"] 41 | self._setup_images_and_segs(images, segs) 42 | 43 | expected = [("A.jpg", "A.png"), 44 | ("B.jpg", "B.png"), 45 | ("C.jpeg", "C.png"), 46 | ("D.png", "D.png")] 47 | expected_values = [] 48 | # Transform paths 49 | for (x, y) in expected: 50 | expected_values.append((os.path.join(self.img_path, x), os.path.join(self.seg_path, y))) 51 | self.assertEqual(expected_values, sorted(data_loader.get_pairs_from_paths(self.img_path, self.seg_path))) 52 | 53 | def test_get_pairs_from_paths_2(self): 54 | """ Normal execution with extra files """ 55 | images = ["A.jpg", "B.jpg", "C.jpeg", "D.png", "E.txt"] 56 | segs = ["A.png", "B.png", "C.png", "D.png", "E.png"] 57 | self._setup_images_and_segs(images, segs) 58 | 59 | expected = [("A.jpg", "A.png"), 60 | ("B.jpg", "B.png"), 61 | ("C.jpeg", "C.png"), 62 | ("D.png", "D.png")] 63 | expected_values = [] 64 | # Transform paths 65 | for (x, y) in expected: 66 | expected_values.append((os.path.join(self.img_path, x), os.path.join(self.seg_path, y))) 67 | self.assertEqual(expected_values, sorted(data_loader.get_pairs_from_paths(self.img_path, self.seg_path))) 68 | 69 | 70 | def test_get_pairs_from_paths_3(self): 71 | """ Normal execution with multiple images pointing to one """ 72 | images = ["A.jpg", "B.jpg", "C.jpeg", "D.png", "D.jpg"] 73 | segs = ["A.png", "B.png", "C.png", "D.png"] 74 | self._setup_images_and_segs(images, segs) 75 | 76 | expected = [("A.jpg", "A.png"), 77 | ("B.jpg", "B.png"), 78 | ("C.jpeg", "C.png"), 79 | ("D.jpg", "D.png"), 80 | ("D.png", "D.png"),] 81 | expected_values = [] 82 | # Transform paths 83 | for (x, y) in expected: 84 | expected_values.append((os.path.join(self.img_path, x), os.path.join(self.seg_path, y))) 85 | self.assertEqual(expected_values, sorted(data_loader.get_pairs_from_paths(self.img_path, self.seg_path))) 86 | 87 | def test_get_pairs_from_paths_with_invalid_segs(self): 88 | images = ["A.jpg", "B.jpg", "C.jpeg", "D.png"] 89 | segs = ["A.png", "B.png", "C.png", "D.png", 
"D.jpg"] 90 | self._setup_images_and_segs(images, segs) 91 | 92 | expected = [("A.jpg", "A.png"), 93 | ("B.jpg", "B.png"), 94 | ("C.jpeg", "C.png"), 95 | ("D.png", "D.png"),] 96 | expected_values = [] 97 | # Transform paths 98 | for (x, y) in expected: 99 | expected_values.append((os.path.join(self.img_path, x), os.path.join(self.seg_path, y))) 100 | self.assertEqual(expected_values, sorted(data_loader.get_pairs_from_paths(self.img_path, self.seg_path))) 101 | 102 | def test_get_pairs_from_paths_with_no_matching_segs(self): 103 | images = ["A.jpg", "B.jpg", "C.jpeg", "D.png"] 104 | segs = ["A.png", "B.png", "C.png"] 105 | self._setup_images_and_segs(images, segs) 106 | 107 | expected = [("A.jpg", "A.png"), 108 | ("B.jpg", "B.png"), 109 | ("C.jpeg", "C.png")] 110 | expected_values = [] 111 | # Transform paths 112 | for (x, y) in expected: 113 | expected_values.append((os.path.join(self.img_path, x), os.path.join(self.seg_path, y))) 114 | six.assertRaisesRegex(self, data_loader.DataLoaderError, "No corresponding segmentation found for image", data_loader.get_pairs_from_paths, self.img_path, self.seg_path) 115 | 116 | def test_get_pairs_from_paths_with_no_matching_segs_with_escape(self): 117 | images = ["A.jpg", "B.jpg", "C.jpeg", "D.png"] 118 | segs = ["A.png", "B.png", "C.png"] 119 | self._setup_images_and_segs(images, segs) 120 | 121 | expected = [("A.jpg", "A.png"), 122 | ("B.jpg", "B.png"), 123 | ("C.jpeg", "C.png")] 124 | expected_values = [] 125 | # Transform paths 126 | for (x, y) in expected: 127 | expected_values.append((os.path.join(self.img_path, x), os.path.join(self.seg_path, y))) 128 | self.assertEqual(expected_values, sorted(data_loader.get_pairs_from_paths(self.img_path, self.seg_path, ignore_non_matching=True))) 129 | 130 | class TestGetImageArray(unittest.TestCase): 131 | def test_get_image_array_normal(self): 132 | """ Stub test 133 | TODO(divamgupta): Fill with actual test 134 | """ 135 | pass 136 | 137 | class TestGetSegmentationArray(unittest.TestCase): 138 | def test_get_segmentation_array_normal(self): 139 | """ Stub test 140 | TODO(divamgupta): Fill with actual test 141 | """ 142 | pass 143 | 144 | class TestVerifySegmentationDataset(unittest.TestCase): 145 | def test_verify_segmentation_dataset(self): 146 | """ Stub test 147 | TODO(divamgupta): Fill with actual test 148 | """ 149 | pass 150 | 151 | class TestImageSegmentationGenerator(unittest.TestCase): 152 | def setUp(self): 153 | self.train_temp_dir = tempfile.mkdtemp() 154 | self.test_temp_dir = tempfile.mkdtemp() 155 | self.other_temp_dir = tempfile.mkdtemp() 156 | self.other_temp_dir_2 = tempfile.mkdtemp() 157 | 158 | self.image_size = 4 159 | 160 | # Training 161 | train_image = np.arange(self.image_size * self.image_size) 162 | train_image = train_image.reshape((self.image_size, self.image_size)) 163 | 164 | train_file = os.path.join(self.train_temp_dir, "train.png") 165 | test_file = os.path.join(self.test_temp_dir, "train.png") 166 | 167 | cv2.imwrite(train_file, train_image) 168 | cv2.imwrite(test_file, train_image) 169 | 170 | # Testing 171 | train_image = np.arange(start=self.image_size * self.image_size, 172 | stop=self.image_size * self.image_size * 2) 173 | train_image = train_image.reshape((self.image_size,self.image_size)) 174 | 175 | train_file = os.path.join(self.train_temp_dir, "train2.png") 176 | test_file = os.path.join(self.test_temp_dir, "train2.png") 177 | 178 | cv2.imwrite(train_file, train_image) 179 | cv2.imwrite(test_file, train_image) 180 | 181 | # Extra one 182 | 183 | i = 0 184 | for 
folder in [self.other_temp_dir, self.other_temp_dir_2]: 185 | extra_image = np.arange(start=self.image_size * self.image_size * (2 + i), 186 | stop=self.image_size * self.image_size * (2 + i + 1)) 187 | extra_image = extra_image.reshape((self.image_size, self.image_size)) 188 | 189 | extra_file = os.path.join(folder, "train.png") 190 | cv2.imwrite(extra_file, extra_image) 191 | i += 1 192 | 193 | extra_image = np.arange(start=self.image_size * self.image_size * (2 + i), 194 | stop=self.image_size * self.image_size * (2 + i + 1)) 195 | extra_image = extra_image.reshape((self.image_size, self.image_size)) 196 | 197 | extra_file = os.path.join(folder, "train2.png") 198 | cv2.imwrite(extra_file, extra_image) 199 | i += 1 200 | 201 | def tearDown(self): 202 | shutil.rmtree(self.train_temp_dir) 203 | shutil.rmtree(self.test_temp_dir) 204 | shutil.rmtree(self.other_temp_dir) 205 | shutil.rmtree(self.other_temp_dir_2) 206 | 207 | def custom_aug(self): 208 | return iaa.Sequential( 209 | [ 210 | iaa.Fliplr(1), # horizontally flip 100% of all images 211 | ]) 212 | 213 | def test_image_segmentation_generator_custom_augmentation(self): 214 | random.seed(0) 215 | image_seg_pairs = data_loader.get_pairs_from_paths(self.train_temp_dir, self.test_temp_dir) 216 | 217 | random.seed(0) 218 | random.shuffle(image_seg_pairs) 219 | 220 | random.seed(0) 221 | 222 | generator = data_loader.image_segmentation_generator( 223 | self.train_temp_dir, self.test_temp_dir, 1, 224 | self.image_size * self.image_size, self.image_size, self.image_size, self.image_size, self.image_size, 225 | do_augment=True, custom_augmentation=self.custom_aug 226 | ) 227 | 228 | i = 0 229 | for (aug_im, aug_an), (expt_im_f, expt_an_f) in zip(generator, image_seg_pairs): 230 | if i >= len(image_seg_pairs): 231 | break 232 | 233 | expt_im = data_loader.get_image_array(expt_im_f, self.image_size, self.image_size, ordering='channel_last') 234 | 235 | expt_im = cv2.flip(expt_im, flipCode=1) 236 | self.assertTrue(np.equal(expt_im, aug_im).all()) 237 | 238 | i += 1 239 | 240 | def test_image_segmentation_generator_custom_augmentation_with_other_inputs(self): 241 | other_paths = [ 242 | self.other_temp_dir, self.other_temp_dir_2 243 | ] 244 | random.seed(0) 245 | image_seg_pairs = data_loader.get_pairs_from_paths(self.train_temp_dir, 246 | self.test_temp_dir, 247 | other_inputs_paths=other_paths) 248 | 249 | random.seed(0) 250 | random.shuffle(image_seg_pairs) 251 | 252 | random.seed(0) 253 | generator = data_loader.image_segmentation_generator( 254 | self.train_temp_dir, self.test_temp_dir, 1, 255 | self.image_size * self.image_size, self.image_size, self.image_size, self.image_size, 256 | self.image_size, 257 | do_augment=True, custom_augmentation=self.custom_aug, other_inputs_paths=other_paths 258 | ) 259 | 260 | i = 0 261 | for (aug_im, aug_an), (expt_im_f, expt_an_f, expt_oth) in zip(generator, image_seg_pairs): 262 | if i >= len(image_seg_pairs): 263 | break 264 | 265 | ims = [expt_im_f] 266 | ims.extend(expt_oth) 267 | 268 | for j in range(aug_im.shape[1]):  # j avoids shadowing the outer pair counter i 269 | expt_im = data_loader.get_image_array(ims[j], self.image_size, self.image_size, 270 | ordering='channel_last') 271 | 272 | expt_im = cv2.flip(expt_im, flipCode=1) 273 | 274 | self.assertTrue(np.equal(expt_im, aug_im[0, j, :, :]).all()) 275 | 276 | i += 1 277 | 278 | def test_image_segmentation_generator_with_other_inputs(self): 279 | other_paths = [ 280 | self.other_temp_dir, self.other_temp_dir_2 281 | ] 282 | random.seed(0) 283 | image_seg_pairs = 
data_loader.get_pairs_from_paths(self.train_temp_dir, 284 | self.test_temp_dir, 285 | other_inputs_paths=other_paths) 286 | 287 | random.seed(0) 288 | random.shuffle(image_seg_pairs) 289 | 290 | random.seed(0) 291 | generator = data_loader.image_segmentation_generator( 292 | self.train_temp_dir, self.test_temp_dir, 1, 293 | self.image_size * self.image_size, self.image_size, self.image_size, self.image_size, 294 | self.image_size, 295 | other_inputs_paths=other_paths 296 | ) 297 | 298 | i = 0 299 | for (aug_im, aug_an), (expt_im_f, expt_an_f, expt_oth) in zip(generator, image_seg_pairs): 300 | if i >= len(image_seg_pairs): 301 | break 302 | 303 | ims = [expt_im_f] 304 | ims.extend(expt_oth) 305 | 306 | for j in range(aug_im.shape[1]):  # j avoids shadowing the outer pair counter i 307 | expt_im = data_loader.get_image_array(ims[j], self.image_size, self.image_size, 308 | ordering='channel_last') 309 | self.assertTrue(np.equal(expt_im, aug_im[0, j, :, :]).all()) 310 | 311 | i += 1 312 | 313 | def test_image_segmentation_generator_preprocessing(self): 314 | image_seg_pairs = data_loader.get_pairs_from_paths(self.train_temp_dir, self.test_temp_dir) 315 | 316 | random.seed(0) 317 | random.shuffle(image_seg_pairs) 318 | 319 | random.seed(0) 320 | 321 | generator = data_loader.image_segmentation_generator( 322 | self.train_temp_dir, self.test_temp_dir, 1, 323 | self.image_size * self.image_size, self.image_size, self.image_size, self.image_size, 324 | self.image_size, 325 | preprocessing=lambda x: x + 1 326 | ) 327 | 328 | i = 0 329 | for (aug_im, aug_an), (expt_im_f, expt_an_f) in zip(generator, image_seg_pairs): 330 | if i >= len(image_seg_pairs): 331 | break 332 | 333 | expt_im = data_loader.get_image_array(expt_im_f, self.image_size, self.image_size, 334 | ordering='channel_last') 335 | 336 | expt_im += 1 337 | self.assertTrue(np.equal(expt_im, aug_im[0, :, :]).all()) 338 | 339 | i += 1 340 | 341 | def test_single_image_segmentation_generator_preprocessing_with_other_inputs(self): 342 | other_paths = [ 343 | self.train_temp_dir, self.test_temp_dir 344 | ] 345 | random.seed(0) 346 | image_seg_pairs = data_loader.get_pairs_from_paths(self.train_temp_dir, 347 | self.test_temp_dir, 348 | other_inputs_paths=other_paths) 349 | 350 | random.seed(0) 351 | random.shuffle(image_seg_pairs) 352 | 353 | random.seed(0) 354 | generator = data_loader.image_segmentation_generator( 355 | self.train_temp_dir, self.test_temp_dir, 1, 356 | self.image_size * self.image_size, self.image_size, self.image_size, self.image_size, 357 | self.image_size, 358 | preprocessing=lambda x: x+1, other_inputs_paths=other_paths 359 | ) 360 | 361 | i = 0 362 | for (aug_im, aug_an), (expt_im_f, expt_an_f, expt_oth) in zip(generator, image_seg_pairs): 363 | if i >= len(image_seg_pairs): 364 | break 365 | 366 | ims = [expt_im_f] 367 | ims.extend(expt_oth) 368 | 369 | for j in range(aug_im.shape[1]):  # j avoids shadowing the outer pair counter i 370 | expt_im = data_loader.get_image_array(ims[j], self.image_size, self.image_size, 371 | ordering='channel_last') 372 | 373 | self.assertTrue(np.equal(expt_im + 1, aug_im[0, j, :, :]).all()) 374 | 375 | i += 1 376 | 377 | def test_multi_image_segmentation_generator_preprocessing_with_other_inputs(self): 378 | other_paths = [ 379 | self.other_temp_dir, self.other_temp_dir_2 380 | ] 381 | random.seed(0) 382 | image_seg_pairs = data_loader.get_pairs_from_paths(self.train_temp_dir, 383 | self.test_temp_dir, 384 | other_inputs_paths=other_paths) 385 | 386 | random.seed(0) 387 | random.shuffle(image_seg_pairs) 388 | 389 | random.seed(0) 390 | generator = 
data_loader.image_segmentation_generator( 391 | self.train_temp_dir, self.test_temp_dir, 1, 392 | self.image_size * self.image_size, self.image_size, self.image_size, self.image_size, 393 | self.image_size, 394 | preprocessing=[lambda x: x+1, lambda x: x+2, lambda x: x+3], other_inputs_paths=other_paths 395 | ) 396 | 397 | i = 0 398 | for (aug_im, aug_an), (expt_im_f, expt_an_f, expt_oth) in zip(generator, image_seg_pairs): 399 | if i >= len(image_seg_pairs): 400 | break 401 | 402 | ims = [expt_im_f] 403 | ims.extend(expt_oth) 404 | 405 | for j in range(aug_im.shape[1]):  # j avoids shadowing the outer pair counter i 406 | expt_im = data_loader.get_image_array(ims[j], self.image_size, self.image_size, 407 | ordering='channel_last') 408 | 409 | self.assertTrue(np.equal(expt_im + (j + 1), aug_im[0, j, :, :]).all()) 410 | 411 | i += 1 412 | 413 | 414 | 415 | -------------------------------------------------------------------------------- /test/unit/data_utils/test_visualize_dataset.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/unit/data_utils/test_visualize_dataset.py -------------------------------------------------------------------------------- /test/unit/models/test_basic_models.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/unit/models/test_basic_models.py -------------------------------------------------------------------------------- /test/unit/test_metrics.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/unit/test_metrics.py -------------------------------------------------------------------------------- /test/unit/test_predict.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/unit/test_predict.py -------------------------------------------------------------------------------- /test/unit/test_pretrained.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/unit/test_pretrained.py -------------------------------------------------------------------------------- /test/unit/test_train.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import tempfile 3 | from shutil import rmtree 4 | import os 5 | import six 6 | from keras_segmentation import train 7 | 8 | class TestTrainInternalFunctions(unittest.TestCase): 9 | """ Test internal functions of the module """ 10 | 11 | def setUp(self): 12 | self.tmp_dir = tempfile.mkdtemp() 13 | 14 | def tearDown(self): 15 | rmtree(self.tmp_dir) 16 | 17 | def test_find_latest_checkpoint(self): 18 | # Start with an empty folder and query for a checkpoint 19 | checkpoints_path = os.path.join(self.tmp_dir, "test1") 20 | # No checkpoint files exist yet 21 | self.assertEqual(None, train.find_latest_checkpoint(checkpoints_path)) 22 | # When fail_safe is turned off, throw an exception when no checkpoint is found. 
23 | six.assertRaisesRegex(self, ValueError, "Checkpoint path", train.find_latest_checkpoint, checkpoints_path, False) 24 | for suffix in ["0", "2", "4", "12", "_config.json", "ABC"]: 25 | open(checkpoints_path + '.' + suffix, 'a').close() 26 | self.assertEqual(checkpoints_path + ".12", train.find_latest_checkpoint(checkpoints_path)) --------------------------------------------------------------------------------
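Taken together, the model constructors, `train`, and the predict/evaluate helpers above compose into a short end-to-end workflow. A minimal sketch using the bundled `test/example_dataset` paths; the class count and the `/tmp` checkpoint prefix are illustrative assumptions rather than values fixed by the library:

```python
# End-to-end sketch built only from APIs defined in this repository; the
# n_classes value and the /tmp paths are assumptions for illustration.
from keras_segmentation.models.unet import vgg_unet
from keras_segmentation.predict import evaluate

model = vgg_unet(n_classes=51, input_height=416, input_width=608)

# Writes /tmp/vgg_unet_1.<epoch> weight files and /tmp/vgg_unet_1_config.json.
model.train(
    train_images="test/example_dataset/images_prepped_train/",
    train_annotations="test/example_dataset/annotations_prepped_train/",
    checkpoints_path="/tmp/vgg_unet_1",
    epochs=5,
)

# Per-pixel class map for one test image, plus a colorized visualization.
out = model.predict_segmentation(
    inp="test/example_dataset/images_prepped_test/0016E5_07959.png",
    out_fname="/tmp/out.png",
)

# Class-wise, mean, and frequency-weighted IU over the test split.
print(evaluate(
    model=model,
    inp_images_dir="test/example_dataset/images_prepped_test/",
    annotations_dir="test/example_dataset/annotations_prepped_test/",
))
```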