├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
├── docker
│   ├── Dockerfile
│   └── README.md
├── keras_segmentation
│   ├── __init__.py
│   ├── __main__.py
│   ├── cli_interface.py
│   ├── data_utils
│   │   ├── __init__.py
│   │   ├── augmentation.py
│   │   ├── data_loader.py
│   │   └── visualize_dataset.py
│   ├── metrics.py
│   ├── model_compression.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── _pspnet_2.py
│   │   ├── all_models.py
│   │   ├── basic_models.py
│   │   ├── config.py
│   │   ├── fcn.py
│   │   ├── mobilenet.py
│   │   ├── model.py
│   │   ├── model_utils.py
│   │   ├── pspnet.py
│   │   ├── resnet50.py
│   │   ├── segnet.py
│   │   ├── unet.py
│   │   └── vgg16.py
│   ├── predict.py
│   ├── pretrained.py
│   └── train.py
├── requirements.txt
├── sample_images
│   ├── 1_input.jpg
│   ├── 1_output.png
│   ├── 2_input.jpg
│   ├── 2_output.png
│   ├── 3_input.jpg
│   ├── 3_output.png
│   ├── liner_dataset.png
│   ├── liner_export.png
│   ├── liner_testing.png
│   └── liner_training.png
├── setup.cfg
├── setup.py
└── test
    ├── __init__.py
    ├── example_dataset
    │   ├── annotations_prepped_test
    │   │   ├── 0016E5_07959.png
    │   │   ├── 0016E5_07961.png
    │   │   └── 0016E5_07963.png
    │   ├── annotations_prepped_train
    │   │   ├── 0001TP_006690.png
    │   │   ├── 0001TP_006720.png
    │   │   ├── 0001TP_006750.png
    │   │   ├── 0001TP_006780.png
    │   │   └── 0001TP_006810.png
    │   ├── images_prepped_test
    │   │   ├── 0016E5_07959.png
    │   │   ├── 0016E5_07961.png
    │   │   └── 0016E5_07963.png
    │   └── images_prepped_train
    │       ├── 0001TP_006690.png
    │       ├── 0001TP_006720.png
    │       ├── 0001TP_006750.png
    │       ├── 0001TP_006780.png
    │       └── 0001TP_006810.png
    ├── test_models.py
    └── unit
        ├── data_utils
        │   ├── test_augmentation.py
        │   ├── test_data_loader.py
        │   └── test_visualize_dataset.py
        ├── models
        │   └── test_basic_models.py
        ├── test_metrics.py
        ├── test_predict.py
        ├── test_pretrained.py
        └── test_train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 |
3 |
4 |
5 | sync.expect
6 |
7 |
8 | # OS generated files #
9 | ######################
10 | .DS_Store
11 | .DS_Store?
12 | ehthumbs.db
13 | Icon?
14 | Thumbs.db
15 |
16 |
17 |
18 | # Byte-compiled / optimized / DLL files
19 | __pycache__/
20 | *.py[cod]
21 | *$py.class
22 |
23 | # C extensions
24 | *.so
25 |
26 | # Distribution / packaging
27 | .Python
28 | env/
29 | build/
30 | develop-eggs/
31 | dist/
32 | downloads/
33 | eggs/
34 | .eggs/
35 | lib/
36 | lib64/
37 | parts/
38 | sdist/
39 | var/
40 | *.egg-info/
41 | .installed.cfg
42 | *.egg
43 |
44 | # PyInstaller
45 | # Usually these files are written by a python script from a template
46 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
47 | *.manifest
48 | *.spec
49 |
50 | # Installer logs
51 | pip-log.txt
52 | pip-delete-this-directory.txt
53 |
54 | # Unit test / coverage reports
55 | htmlcov/
56 | .tox/
57 | .coverage
58 | .coverage.*
59 | .cache
60 | nosetests.xml
61 | coverage.xml
62 | *.cover
63 | .hypothesis/
64 |
65 | # Translations
66 | *.mo
67 | *.pot
68 |
69 | # Django stuff:
70 | *.log
71 | local_settings.py
72 |
73 | # Flask instance folder
74 | instance/
75 |
76 | # Scrapy stuff:
77 | .scrapy
78 |
79 | # Sphinx documentation
80 | docs/_build/
81 |
82 | # PyBuilder
83 | target/
84 |
85 | # IPython Notebook
86 | .ipynb_checkpoints
87 |
88 | # pyenv
89 | .python-version
90 |
91 | # celery beat schedule file
92 | celerybeat-schedule
93 |
94 | # dotenv
95 | .env
96 |
97 | # virtualenv
98 | venv/
99 | ENV/
100 |
101 | # Spyder project settings
102 | .spyderproject
103 |
104 | # Rope project settings
105 | .ropeproject
106 |
107 | # VSCode settings
108 | .vscode/
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | python:
3 | - "2.7"
4 | # - "3.5" removing 3.5 due to some strange error on travis.
5 | - "3.6" # current default Python on Travis CI
6 | - "3.7"
7 | # command to install dependencies
8 | install:
9 | # Install with tests default extras-require
10 | - pip install .[tests-default]
11 | # command to run tests
12 | script: pytest
13 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Divam Gupta
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Image Segmentation Keras : Implementation of Segnet, FCN, UNet, PSPNet and other models in Keras.
2 |
3 | [](https://badge.fury.io/py/keras-segmentation)
4 | [](https://pepy.tech/project/keras-segmentation)
5 | [](https://travis-ci.org/divamgupta/image-segmentation-keras)
6 | [](http://perso.crans.org/besson/LICENSE.html)
7 | [](https://twitter.com/divamgupta)
8 |
9 |
10 |
11 | Implementation of various deep image segmentation models in Keras.
12 |
13 | ### News : Some functionality of this repository has been integrated with https://liner.ai . Check it out!!
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 | Link to the full blog post with tutorial : https://divamgupta.com/image-segmentation/2019/06/06/deep-learning-semantic-segmentation-keras.html
22 |
23 |
24 | ## Working Google Colab Examples:
25 | * Python Interface: https://colab.research.google.com/drive/1q_eCYEzKxixpCKH1YDsLnsvgxl92ORcv?usp=sharing
26 | * CLI Interface: https://colab.research.google.com/drive/1Kpy4QGFZ2ZHm69mPfkmLSUes8kj6Bjyi?usp=sharing
27 |
28 | ## Training using GUI interface
29 | You can also train segmentation models on your computer with https://liner.ai
30 |
31 | Train | Inference / Export
32 | :-------------------------:|:-------------------------:
33 | ![](sample_images/liner_training.png) | ![](sample_images/liner_export.png)
34 | ![](sample_images/liner_dataset.png) | ![](sample_images/liner_testing.png)
35 |
36 |
37 | ## Models
38 |
39 | The following models are supported:
40 |
41 | | model_name | Base Model | Segmentation Model |
42 | |------------------|-------------------|--------------------|
43 | | fcn_8 | Vanilla CNN | FCN8 |
44 | | fcn_32 | Vanilla CNN | FCN32 |
45 | | fcn_8_vgg | VGG 16 | FCN8 |
46 | | fcn_32_vgg | VGG 16 | FCN32 |
47 | | fcn_8_resnet50 | Resnet-50 | FCN8 |
48 | | fcn_32_resnet50 | Resnet-50 | FCN32 |
49 | | fcn_8_mobilenet | MobileNet | FCN8 |
50 | | fcn_32_mobilenet | MobileNet | FCN32 |
51 | | pspnet | Vanilla CNN | PSPNet |
52 | | pspnet_50 | Vanilla CNN | PSPNet |
53 | | pspnet_101 | Vanilla CNN | PSPNet |
54 | | vgg_pspnet | VGG 16 | PSPNet |
55 | | resnet50_pspnet | Resnet-50 | PSPNet |
56 | | unet_mini | Vanilla Mini CNN | U-Net |
57 | | unet | Vanilla CNN | U-Net |
58 | | vgg_unet | VGG 16 | U-Net |
59 | | resnet50_unet | Resnet-50 | U-Net |
60 | | mobilenet_unet | MobileNet | U-Net |
61 | | segnet | Vanilla CNN | Segnet |
62 | | vgg_segnet | VGG 16 | Segnet |
63 | | resnet50_segnet | Resnet-50 | Segnet |
64 | | mobilenet_segnet | MobileNet | Segnet |
65 |
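Each `model_name` pairs a base encoder with a segmentation architecture. Besides the direct imports used later in this README (e.g. `from keras_segmentation.models.unet import vgg_unet`), the names in the table can be resolved programmatically; a minimal sketch, assuming the `model_from_name` registry in `keras_segmentation/models/all_models.py`:

```python
# Sketch: look up a model constructor by one of the model_name strings above.
# Assumes keras_segmentation.models.all_models exposes a model_from_name dict
# mapping names to constructors; verify against your installed version.
from keras_segmentation.models.all_models import model_from_name

model = model_from_name["vgg_unet"](n_classes=51,
                                    input_height=416, input_width=608)
```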
66 |
67 | Example results for the provided pre-trained models:
68 |
69 | Input Image | Output Segmentation Image
70 | :-------------------------:|:-------------------------:
71 | ![](sample_images/1_input.jpg) | ![](sample_images/1_output.png)
72 | ![](sample_images/3_input.jpg) | ![](sample_images/3_output.png)
73 |
74 |
75 | ## How to cite
76 |
77 | If you are using this library, please cite using:
78 |
79 | ```
80 | @article{gupta2023image,
81 | title={Image segmentation keras: Implementation of segnet, fcn, unet, pspnet and other models in keras},
82 | author={Gupta, Divam},
83 | journal={arXiv preprint arXiv:2307.13215},
84 | year={2023}
85 | }
86 |
87 | ```
88 |
89 |
90 | ## Getting Started
91 |
92 | ### Prerequisites
93 |
94 | * Keras ( recommended version : 2.4.3 )
95 | * OpenCV for Python
96 | * Tensorflow ( recommended version : 2.4.1 )
97 |
98 | ```shell
99 | apt-get install -y libsm6 libxext6 libxrender-dev
100 | pip install opencv-python
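# Keras / TensorFlow at the recommended versions above (an assumption:
# adjust the pins to match your environment)
pip install "tensorflow==2.4.1" "keras==2.4.3"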
101 | ```
102 |
103 | ### Installing
104 |
105 | Install the module
106 |
107 | Recommended way:
108 | ```shell
109 | pip install --upgrade git+https://github.com/divamgupta/image-segmentation-keras
110 | ```
111 |
112 | ### or
113 |
114 | ```shell
115 | pip install keras-segmentation
116 | ```
117 |
118 | ### or
119 |
120 | ```shell
121 | git clone https://github.com/divamgupta/image-segmentation-keras
122 | cd image-segmentation-keras
123 | python setup.py install
124 | ```
125 |
126 |
127 | ## Pre-trained models:
128 | ```python
129 | from keras_segmentation.pretrained import pspnet_50_ADE_20K , pspnet_101_cityscapes, pspnet_101_voc12
130 |
131 | model = pspnet_50_ADE_20K() # load the pretrained model trained on ADE20k dataset
132 |
133 | model = pspnet_101_cityscapes() # load the pretrained model trained on Cityscapes dataset
134 |
135 | model = pspnet_101_voc12() # load the pretrained model trained on Pascal VOC 2012 dataset
136 |
137 | # load any of the 3 pretrained models
138 |
139 | out = model.predict_segmentation(
140 | inp="input_image.jpg",
141 | out_fname="out.png"
142 | )
143 |
144 | ```
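Here `out` is the predicted label map (a 2D array with one class index per pixel), while `out_fname` receives a colorized visualization of the segmentation.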
145 |
146 |
147 | ### Preparing the data for training
148 |
149 | You need to make two folders
150 |
151 | * Images Folder - For all the training images
152 | * Annotations Folder - For the corresponding ground truth segmentation images
153 |
154 | The filenames of the annotation images should be the same as the filenames of the RGB images.
155 | 
156 | Each annotation image should be the same size as its corresponding RGB image.
157 | 
158 | For each pixel in the RGB image, the class label of that pixel is the value of the blue channel of the corresponding pixel in the annotation image.
159 |
160 | Example code to generate annotation images :
161 |
162 | ```python
163 | import cv2
164 | import numpy as np
165 |
166 | ann_img = np.zeros((30,30,3)).astype('uint8')
167 | ann_img[ 3 , 4 ] = 1 # sets the label of the pixel at row 3, column 4 to 1 (written to all three channels)
168 |
169 | cv2.imwrite( "ann_1.png" ,ann_img )
170 | ```
171 |
172 | Only use bmp or png format for the annotation images; lossy formats such as jpg would corrupt the class labels.
173 |
174 | ## Download the sample prepared dataset
175 |
176 | Download and extract the following:
177 |
178 | https://drive.google.com/file/d/0B0d9ZiqAgFkiOHR1NTJhWVJMNEU/view?usp=sharing
179 |
180 | You will get a folder named dataset1/
181 |
182 |
183 | ## Using the python module
184 |
185 | You can import keras_segmentation in your python script and use the API
186 |
187 | ```python
188 | from keras_segmentation.models.unet import vgg_unet
189 |
190 | model = vgg_unet(n_classes=51 , input_height=416, input_width=608 )
191 |
192 | model.train(
193 | train_images = "dataset1/images_prepped_train/",
194 | train_annotations = "dataset1/annotations_prepped_train/",
195 | checkpoints_path = "/tmp/vgg_unet_1" , epochs=5
196 | )
197 |
198 | out = model.predict_segmentation(
199 | inp="dataset1/images_prepped_test/0016E5_07965.png",
200 | out_fname="/tmp/out.png"
201 | )
202 |
203 | import matplotlib.pyplot as plt
204 | plt.imshow(out)
205 |
206 | # evaluating the model
207 | print(model.evaluate_segmentation( inp_images_dir="dataset1/images_prepped_test/" , annotations_dir="dataset1/annotations_prepped_test/" ) )
208 |
209 | ```
210 |
211 |
212 | ## Usage via command line
213 | You can also use the tool directly from the command line.
214 |
215 | ### Visualizing the prepared data
216 |
217 | You can visualize your prepared annotations to verify that the data has been prepared correctly.
218 |
219 |
220 | ```shell
221 | python -m keras_segmentation verify_dataset \
222 | --images_path="dataset1/images_prepped_train/" \
223 | --segs_path="dataset1/annotations_prepped_train/" \
224 | --n_classes=50
225 | ```
226 |
227 | ```shell
228 | python -m keras_segmentation visualize_dataset \
229 | --images_path="dataset1/images_prepped_train/" \
230 | --segs_path="dataset1/annotations_prepped_train/" \
231 | --n_classes=50
232 | ```
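The same checks are available from Python. A minimal sketch using the functions these commands wrap (`verify_segmentation_dataset` and `visualize_segmentation_dataset` from `keras_segmentation/data_utils/`):

```python
from keras_segmentation.data_utils.data_loader import verify_segmentation_dataset
from keras_segmentation.data_utils.visualize_dataset import visualize_segmentation_dataset

# Returns True when every image has a matching, same-sized annotation
# whose labels stay inside [0, n_classes - 1].
ok = verify_segmentation_dataset("dataset1/images_prepped_train/",
                                 "dataset1/annotations_prepped_train/",
                                 50)

if ok:
    # Opens OpenCV windows; press any key to step to the next pair.
    visualize_segmentation_dataset("dataset1/images_prepped_train/",
                                   "dataset1/annotations_prepped_train/",
                                   50)
```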
233 |
234 |
235 |
236 | ### Training the Model
237 |
238 | To train the model run the following command:
239 |
240 | ```shell
241 | python -m keras_segmentation train \
242 | --checkpoints_path="path_to_checkpoints" \
243 | --train_images="dataset1/images_prepped_train/" \
244 | --train_annotations="dataset1/annotations_prepped_train/" \
245 | --val_images="dataset1/images_prepped_test/" \
246 | --val_annotations="dataset1/annotations_prepped_test/" \
247 | --n_classes=50 \
248 | --input_height=320 \
249 | --input_width=640 \
250 | --model_name="vgg_unet"
251 | ```
252 |
253 | Choose `model_name` from the table above.
254 |
255 |
256 |
257 | ### Getting the predictions
258 |
259 | To get the predictions of a trained model
260 |
261 | ```shell
262 | python -m keras_segmentation predict \
263 | --checkpoints_path="path_to_checkpoints" \
264 | --input_path="dataset1/images_prepped_test/" \
265 | --output_path="path_to_predictions"
266 |
267 | ```
268 |
269 |
270 |
271 | ### Video inference
272 |
273 | To get predictions of a video
274 | ```shell
275 | python -m keras_segmentation predict_video \
276 | --checkpoints_path="path_to_checkpoints" \
277 | --input="path_to_video" \
278 | --output_file="path_for_save_inferenced_video" \
279 | --display
280 | ```
281 |
282 | To run predictions on your webcam, omit `--input` or pass your device number (e.g. `--input 0`).
283 | `--display` opens a window showing the predicted video; remove this argument on headless systems.
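The same entry point is available from Python; a minimal sketch mirroring the CLI wrapper in `keras_segmentation/cli_interface.py`:

```python
from keras_segmentation.predict import predict_video

# inp can be a video file path, or a device number (e.g. 0) for a webcam
predict_video(inp="path_to_video",
              output="path_for_save_inferenced_video",
              checkpoints_path="path_to_checkpoints",
              display=False)  # True opens a preview window (needs a display)
```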
284 |
285 |
286 | ### Model Evaluation
287 |
288 | To get the IoU scores
289 |
290 | ```shell
291 | python -m keras_segmentation evaluate_model \
292 | --checkpoints_path="path_to_checkpoints" \
293 | --images_path="dataset1/images_prepped_test/" \
294 | --segs_path="dataset1/annotations_prepped_test/"
295 | ```
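For reference, the reported score is the per-class intersection over union; `keras_segmentation/metrics.py` computes it essentially as in this sketch:

```python
import numpy as np

EPS = 1e-12  # guards against division by zero for absent classes

def class_iou(gt, pr, cl):
    # gt, pr: integer label maps of identical shape
    intersection = np.sum((gt == cl) * (pr == cl))
    union = np.sum(np.maximum(gt == cl, pr == cl))
    return float(intersection) / (union + EPS)
```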
296 |
297 |
298 |
299 | ## Fine-tuning from existing segmentation model
300 |
301 | The following example shows how to fine-tune a model for a new number of classes (51 here) while reusing the weights of a pre-trained model.
302 |
303 | ```python
304 | from keras_segmentation.models.model_utils import transfer_weights
305 | from keras_segmentation.pretrained import pspnet_50_ADE_20K
306 | from keras_segmentation.models.pspnet import pspnet_50
307 |
308 | pretrained_model = pspnet_50_ADE_20K()
309 |
310 | new_model = pspnet_50( n_classes=51 )
311 |
312 | transfer_weights( pretrained_model , new_model ) # transfer weights from pre-trained model to your model
313 |
314 | new_model.train(
315 | train_images = "dataset1/images_prepped_train/",
316 | train_annotations = "dataset1/annotations_prepped_train/",
317 | checkpoints_path = "/tmp/vgg_unet_1" , epochs=5
318 | )
319 |
320 |
321 | ```
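As far as the library's behavior goes, `transfer_weights` copies the weights of each layer whose shapes match between the two models and skips the rest (such as the final classification layer, whose shape depends on `n_classes`), which is what makes fine-tuning to a different number of classes possible.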
322 |
323 |
324 |
325 | ## Knowledge distillation for compressing the model
326 |
327 | The following example shows how to transfer knowledge from a larger (and more accurate) model to a smaller one. In most cases, a smaller model trained via knowledge distillation is more accurate than the same model trained with vanilla supervised learning.
328 |
329 | ```python
330 | from keras_segmentation.predict import model_from_checkpoint_path
331 | from keras_segmentation.models.unet import unet_mini
332 | from keras_segmentation.model_compression import perform_distilation
333 |
334 | model_large = model_from_checkpoint_path( "/checkpoints/path/of/trained/model" )
335 | model_small = unet_mini( n_classes=51, input_height=300, input_width=400 )
336 |
337 | perform_distilation ( data_path="/path/to/large_image_set/" , checkpoints_path="path/to/save/checkpoints" ,
338 | teacher_model=model_large , student_model=model_small , distilation_loss='kl' , feats_distilation_loss='pa' )
339 |
340 | ```
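Here `distilation_loss='kl'` (note the spelling used by the API) controls how the student's outputs are matched to the teacher's, presumably a KL-divergence loss, while `feats_distilation_loss='pa'` additionally matches pairwise similarities between intermediate feature vectors, as implemented by `pairwise_dist_loss` in `keras_segmentation/model_compression.py`.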
341 |
342 |
343 |
344 |
345 |
346 | ## Adding custom augmentation function to training
347 |
348 | The following example shows how to define a custom augmentation function for training.
349 |
350 | ```python
351 |
352 | from keras_segmentation.models.unet import vgg_unet
353 | from imgaug import augmenters as iaa
354 |
355 | def custom_augmentation():
356 | return iaa.Sequential(
357 | [
358 | # apply the following augmenters to most images
359 | iaa.Fliplr(0.5), # horizontally flip 50% of all images
360 |         iaa.Flipud(0.5), # vertically flip 50% of all images
361 | ])
362 |
363 | model = vgg_unet(n_classes=51 , input_height=416, input_width=608)
364 |
365 | model.train(
366 | train_images = "dataset1/images_prepped_train/",
367 | train_annotations = "dataset1/annotations_prepped_train/",
368 | checkpoints_path = "/tmp/vgg_unet_1" , epochs=5,
369 | do_augment=True, # enable augmentation
370 |     custom_augmentation=custom_augmentation # sets the augmentation function to use
371 | )
372 | ```
373 | ## Custom number of input channels
374 |
375 | The following example shows how to set the number of input channels.
376 |
377 | ```python
378 |
379 | from keras_segmentation.models.unet import vgg_unet
380 |
381 | model = vgg_unet(n_classes=51 , input_height=416, input_width=608,
382 | channels=1 # Sets the number of input channels
383 | )
384 |
385 | model.train(
386 | train_images = "dataset1/images_prepped_train/",
387 | train_annotations = "dataset1/annotations_prepped_train/",
388 | checkpoints_path = "/tmp/vgg_unet_1" , epochs=5,
389 |     read_image_type=0 # Sets how OpenCV will read the images
390 |     # cv2.IMREAD_COLOR = 1 (BGR),
391 | # cv2.IMREAD_GRAYSCALE = 0,
392 | # cv2.IMREAD_UNCHANGED = -1 (4 channels like RGBA)
393 | )
394 | ```
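Keep `read_image_type` consistent with `channels`: for example, `channels=1` goes with `read_image_type=0` (`cv2.IMREAD_GRAYSCALE`).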
395 |
396 | ## Custom preprocessing
397 |
398 | The following example shows how to set a custom image preprocessing function.
399 |
400 | ```python
401 |
402 | from keras_segmentation.models.unet import vgg_unet
403 |
404 | def image_preprocessing(image):
405 | return image + 1
406 |
407 | model = vgg_unet(n_classes=51 , input_height=416, input_width=608)
408 |
409 | model.train(
410 | train_images = "dataset1/images_prepped_train/",
411 | train_annotations = "dataset1/annotations_prepped_train/",
412 | checkpoints_path = "/tmp/vgg_unet_1" , epochs=5,
413 | preprocessing=image_preprocessing # Sets the preprocessing function
414 | )
415 | ```
416 |
417 | ## Custom callbacks
418 |
419 | The following example shows how to set custom callbacks for the model training.
420 |
421 | ```python
422 |
423 | from keras_segmentation.models.unet import vgg_unet
424 | from keras.callbacks import ModelCheckpoint, EarlyStopping
425 |
426 | model = vgg_unet(n_classes=51 , input_height=416, input_width=608 )
427 |
428 | # When using custom callbacks, the default checkpoint saver is removed
429 | callbacks = [
430 | ModelCheckpoint(
431 | filepath="checkpoints/" + model.name + ".{epoch:05d}",
432 | save_weights_only=True,
433 | verbose=True
434 | ),
435 | EarlyStopping()
436 | ]
437 |
438 | model.train(
439 | train_images = "dataset1/images_prepped_train/",
440 | train_annotations = "dataset1/annotations_prepped_train/",
441 | checkpoints_path = "/tmp/vgg_unet_1" , epochs=5,
442 | callbacks=callbacks
443 | )
444 | ```
445 |
446 | ## Multi input image input
447 |
448 | The following example shows how to add additional image inputs for models.
449 |
450 | ```python
451 |
452 | from keras_segmentation.models.unet import vgg_unet
453 |
454 | model = vgg_unet(n_classes=51 , input_height=416, input_width=608)
455 |
456 | model.train(
457 | train_images = "dataset1/images_prepped_train/",
458 | train_annotations = "dataset1/annotations_prepped_train/",
459 | checkpoints_path = "/tmp/vgg_unet_1" , epochs=5,
460 | other_inputs_paths=[
461 | "/path/to/other/directory"
462 | ],
463 |
464 |
465 |     # Optional preprocessing: one function per input
466 |     preprocessing=[lambda x: x+1, lambda x: x+2], # different preprocessing for the main image and the extra input
467 |     # or the same function for every input:
468 |     # preprocessing=lambda x: x+1,
469 | )
470 | ```
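Files in each `other_inputs_paths` directory are matched to the training images by filename, so every additional directory must contain a correspondingly named file for each training image; otherwise the data loader raises an error.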
471 |
472 |
473 | ## Projects using keras-segmentation
474 | Here are a few projects which are using our library:
475 | * https://github.com/SteliosTsop/QF-image-segmentation-keras [paper](https://arxiv.org/pdf/1908.02242.pdf)
476 | * https://github.com/willembressers/bouquet_quality
477 | * https://github.com/jqueguiner/image-segmentation
478 | * https://github.com/pan0rama/CS230-Microcrystal-Facet-Segmentation
479 | * https://github.com/theerawatramchuen/Keras_Segmentation
480 | * https://github.com/neheller/labels18
481 | * https://github.com/Divyam10/Face-Matting-using-Unet
482 | * https://github.com/shsh-a/segmentation-over-web
483 | * https://github.com/chenwe73/deep_active_learning_segmentation
484 | * https://github.com/vigneshrajap/vision-based-navigation-agri-fields
485 | * https://github.com/ronalddas/Pneumonia-Detection
486 | * https://github.com/Aiwiscal/ECG_UNet
487 | * https://github.com/TianzhongSong/Unet-for-Person-Segmentation
488 | * https://github.com/Guyanqi/GMDNN
489 | * https://github.com/kozemzak/prostate-lesion-segmentation
490 | * https://github.com/lixiaoyu12138/fcn-date
491 | * https://github.com/sagarbhokre/LyftChallenge
492 | * https://github.com/TianzhongSong/Person-Segmentation-Keras
493 | * https://github.com/divyanshpuri02/COCO_2018-Stuff-Segmentation-Challenge
494 | * https://github.com/XiangbingJi/Stanford-cs230-final-project
495 | * https://github.com/lsh1994/keras-segmentation
496 | * https://github.com/SpirinEgor/mobile_semantic_segmentation
497 | * https://github.com/LeadingIndiaAI/COCO-DATASET-STUFF-SEGMENTATION-CHALLENGE
498 | * https://github.com/lidongyue12138/Image-Segmentation-by-Keras
499 | * https://github.com/laoj2/segnet_crfasrnn
500 | * https://github.com/rancheng/AirSimProjects
501 | * https://github.com/RadiumScriptTang/cartoon_segmentation
502 | * https://github.com/dquail/NerveSegmentation
503 | * https://github.com/Bhomik/SemanticHumanMatting
504 | * https://github.com/Symefa/FP-Biomedik-Breast-Cancer
505 | * https://github.com/Alpha-Monocerotis/PDF_FigureTable_Extraction
506 | * https://github.com/rusito-23/mobile_unet_segmentation
507 | * https://github.com/Philliec459/ThinSection-image-segmentation-keras
508 | * https://github.com/imsadia/cv-assignment-three.git
509 | * https://github.com/kejitan/ESVGscale
510 |
511 | If you use our code in a publicly available project, please add the link here (by posting an issue or creating a PR).
512 |
513 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM tanmaniac/opencv3-cudagl
2 |
3 | # install prerequisites
4 | RUN apt-get update \
5 | && apt-get install -y wget git curl nano \
6 | && apt-get install -y libsm6 libxext6 libxrender-dev
7 |
8 | # install Cudnn
9 | ENV CUDNN_VERSION 7.6.0.64
10 | RUN apt-get update && apt-get install -y --no-install-recommends \
11 | libcudnn7=$CUDNN_VERSION-1+cuda9.0 \
12 | libcudnn7-dev=$CUDNN_VERSION-1+cuda9.0 && \
13 | apt-mark hold libcudnn7 && \
14 | rm -rf /var/lib/apt/lists/*
15 |
16 | # Install Miniconda
17 | RUN curl -so /miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py37_4.8.2-Linux-x86_64.sh \
18 | && chmod +x /miniconda.sh \
19 | && /miniconda.sh -b -p /miniconda \
20 | && rm /miniconda.sh
21 |
22 | # Create a Python 3.6 environment
23 | ENV PATH=/miniconda/bin:$PATH
24 |
25 | RUN /miniconda/bin/conda install -y conda-build \
26 | && /miniconda/bin/conda create -y --name unet python=3.6.7 \
27 | && /miniconda/bin/conda clean -ya
28 |
29 | ENV CONDA_DEFAULT_ENV=unet
30 | ENV CONDA_PREFIX=/miniconda/envs/$CONDA_DEFAULT_ENV
31 | ENV PATH=$CONDA_PREFIX/bin:$PATH
32 | ENV CONDA_AUTO_UPDATE_CONDA=false
33 |
34 | RUN conda install -y ipython tensorflow-gpu=1.14.0 keras=2.3.1
35 |
36 | # install model library
37 | RUN git clone https://github.com/divamgupta/image-segmentation-keras.git
38 | WORKDIR /image-segmentation-keras
39 | RUN python setup.py install
40 |
--------------------------------------------------------------------------------
/docker/README.md:
--------------------------------------------------------------------------------
1 | ## Installing Docker
2 |
3 | General installation instructions are
4 | [on the Docker site](https://docs.docker.com/installation/), but we give some
5 | quick links here:
6 |
7 | * [OSX](https://docs.docker.com/installation/mac/): [docker toolbox](https://www.docker.com/toolbox)
8 | * [ubuntu](https://docs.docker.com/installation/ubuntulinux/)
9 |
10 | For GPU support, install compatible NVIDIA drivers with CUDA 9.0 and cuDNN 7.6.
11 |
12 | ## Running the container
13 |
14 | Build the image (from the `docker/` directory):
15 | 
16 |     $ docker build -t isk .
17 |
18 | To run the image:
19 |
20 | $ docker run --gpus all -it isk
21 |
22 | If you want to train with a dataset on your local machine, or make inference on images or videos, mount a volume to share this data with the docker container:
23 |
24 | $ docker run --gpus all -v /path/to/data/folder:/image-segmentation-keras/share -it isk
25 |
26 | If a graphical interface is needed to show results (e.g. `predict_video --display`), first allow Docker to use the host's display. On your local host, run this once:
27 |
28 | $ xhost +local:docker
29 |
30 | And run the container with access to X11:
31 |
32 | $ docker run --gpus all -v /path/to/data/folder:/image-segmentation-keras/share -e DISPLAY=$DISPLAY -v /tmp/.X11-unix/:/tmp/.X11-unix --env QT_X11_NO_MITSHM=1 -it isk
33 |
--------------------------------------------------------------------------------
/keras_segmentation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/keras_segmentation/__init__.py
--------------------------------------------------------------------------------
/keras_segmentation/__main__.py:
--------------------------------------------------------------------------------
1 |
2 | def main():
3 | from . import cli_interface
4 | cli_interface.main()
5 |
6 | if __name__ == "__main__":
7 | main()
8 |
--------------------------------------------------------------------------------
/keras_segmentation/cli_interface.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import sys
4 | import argparse
5 |
6 | from .train import train
7 | from .predict import predict, predict_multiple, predict_video, evaluate
8 | from .data_utils.data_loader import verify_segmentation_dataset
9 | from .data_utils.visualize_dataset import visualize_segmentation_dataset
10 |
11 |
12 | def train_action(command_parser):
13 | parser = command_parser.add_parser('train')
14 | parser.add_argument("--model_name", type=str, required=True)
15 | parser.add_argument("--train_images", type=str, required=True)
16 | parser.add_argument("--train_annotations", type=str, required=True)
17 |
18 | parser.add_argument("--n_classes", type=int, required=True)
19 | parser.add_argument("--input_height", type=int, default=None)
20 | parser.add_argument("--input_width", type=int, default=None)
21 |
22 | parser.add_argument('--not_verify_dataset', action='store_false')
23 | parser.add_argument("--checkpoints_path", type=str, default=None)
24 | parser.add_argument("--epochs", type=int, default=5)
25 | parser.add_argument("--batch_size", type=int, default=2)
26 |
27 | parser.add_argument('--validate', action='store_true')
28 | parser.add_argument("--val_images", type=str, default="")
29 | parser.add_argument("--val_annotations", type=str, default="")
30 |
31 | parser.add_argument("--val_batch_size", type=int, default=2)
32 | parser.add_argument("--load_weights", type=str, default=None)
33 | parser.add_argument('--auto_resume_checkpoint', action='store_true')
34 |
35 | parser.add_argument("--steps_per_epoch", type=int, default=512)
36 | parser.add_argument("--optimizer_name", type=str, default="adam")
37 |
38 | def action(args):
39 | return train(model=args.model_name,
40 | train_images=args.train_images,
41 | train_annotations=args.train_annotations,
42 | input_height=args.input_height,
43 | input_width=args.input_width,
44 | n_classes=args.n_classes,
45 | verify_dataset=args.not_verify_dataset,
46 | checkpoints_path=args.checkpoints_path,
47 | epochs=args.epochs,
48 | batch_size=args.batch_size,
49 | validate=args.validate,
50 | val_images=args.val_images,
51 | val_annotations=args.val_annotations,
52 | val_batch_size=args.val_batch_size,
53 | auto_resume_checkpoint=args.auto_resume_checkpoint,
54 | load_weights=args.load_weights,
55 | steps_per_epoch=args.steps_per_epoch,
56 | optimizer_name=args.optimizer_name)
57 |
58 | parser.set_defaults(func=action)
59 |
60 |
61 | def predict_action(command_parser):
62 |
63 | parser = command_parser.add_parser('predict')
64 | parser.add_argument("--checkpoints_path", type=str, required=True)
65 | parser.add_argument("--input_path", type=str, default="", required=True)
66 | parser.add_argument("--output_path", type=str, default="", required=True)
67 |
68 | def action(args):
69 | input_path_extension = args.input_path.split('.')[-1]
70 | if input_path_extension in ['jpg', 'jpeg', 'png']:
71 | return predict(inp=args.input_path, out_fname=args.output_path,
72 | checkpoints_path=args.checkpoints_path)
73 | else:
74 | return predict_multiple(inp_dir=args.input_path,
75 | out_dir=args.output_path,
76 | checkpoints_path=args.checkpoints_path)
77 |
78 | parser.set_defaults(func=action)
79 |
80 |
81 | def predict_video_action(command_parser):
82 | parser = command_parser.add_parser('predict_video')
83 | parser.add_argument("--input", type=str, default=0, required=False)
84 | parser.add_argument("--output_file", type=str, default="", required=False)
85 | parser.add_argument("--checkpoints_path", required=True)
86 | parser.add_argument("--display", action='store_true', required=False)
87 |
88 | def action(args):
89 | return predict_video(inp=args.input,
90 | output=args.output_file,
91 | checkpoints_path=args.checkpoints_path,
92 | display=args.display,
93 | )
94 |
95 | parser.set_defaults(func=action)
96 |
97 |
98 | def evaluate_model_action(command_parser):
99 |
100 | parser = command_parser.add_parser('evaluate_model')
101 | parser.add_argument("--images_path", type=str, required=True)
102 | parser.add_argument("--segs_path", type=str, required=True)
103 | parser.add_argument("--checkpoints_path", type=str, required=True)
104 |
105 | def action(args):
106 | print(evaluate(
107 | inp_images_dir=args.images_path, annotations_dir=args.segs_path,
108 | checkpoints_path=args.checkpoints_path))
109 |
110 | parser.set_defaults(func=action)
111 |
112 |
113 | def verify_dataset_action(command_parser):
114 |
115 | parser = command_parser.add_parser('verify_dataset')
116 | parser.add_argument("--images_path", type=str)
117 | parser.add_argument("--segs_path", type=str)
118 | parser.add_argument("--n_classes", type=int)
119 |
120 | def action(args):
121 | verify_segmentation_dataset(
122 | args.images_path, args.segs_path, args.n_classes)
123 |
124 | parser.set_defaults(func=action)
125 |
126 |
127 | def visualize_dataset_action(command_parser):
128 |
129 | parser = command_parser.add_parser('visualize_dataset')
130 | parser.add_argument("--images_path", type=str)
131 | parser.add_argument("--segs_path", type=str)
132 | parser.add_argument("--n_classes", type=int)
133 | parser.add_argument('--do_augment', action='store_true')
134 |
135 | def action(args):
136 | visualize_segmentation_dataset(args.images_path, args.segs_path,
137 | args.n_classes,
138 | do_augment=args.do_augment)
139 |
140 | parser.set_defaults(func=action)
141 |
142 |
143 | def main():
144 |     assert len(sys.argv) >= 2, \
145 |         "Usage: python -m keras_segmentation <command> <arguments>"
146 |
147 | main_parser = argparse.ArgumentParser()
148 | command_parser = main_parser.add_subparsers()
149 |
150 | # Add individual commands
151 | train_action(command_parser)
152 | predict_action(command_parser)
153 | predict_video_action(command_parser)
154 | verify_dataset_action(command_parser)
155 | visualize_dataset_action(command_parser)
156 | evaluate_model_action(command_parser)
157 |
158 | args = main_parser.parse_args()
159 |
160 | args.func(args)
161 |
--------------------------------------------------------------------------------
/keras_segmentation/data_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/keras_segmentation/data_utils/__init__.py
--------------------------------------------------------------------------------
/keras_segmentation/data_utils/augmentation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | try:
4 | import imgaug as ia
5 | from imgaug import augmenters as iaa
6 | except ImportError:
7 | print("Error in loading augmentation, can't import imgaug."
8 | "Please make sure it is installed.")
9 |
10 |
11 | IMAGE_AUGMENTATION_SEQUENCE = None
12 | IMAGE_AUGMENTATION_NUM_TRIES = 10
13 |
14 | loaded_augmentation_name = ""
15 |
16 |
17 | def _load_augmentation_aug_geometric():
18 | return iaa.OneOf([
19 | iaa.Sequential([iaa.Fliplr(0.5), iaa.Flipud(0.2)]),
20 | iaa.CropAndPad(percent=(-0.05, 0.1),
21 | pad_mode='constant',
22 | pad_cval=(0, 255)),
23 | iaa.Crop(percent=(0.0, 0.1)),
24 | iaa.Crop(percent=(0.3, 0.5)),
25 | iaa.Crop(percent=(0.3, 0.5)),
26 | iaa.Crop(percent=(0.3, 0.5)),
27 | iaa.Sequential([
28 | iaa.Affine(
29 | # scale images to 80-120% of their size,
30 | # individually per axis
31 | scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},
32 | # translate by -20 to +20 percent (per axis)
33 | translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},
34 | rotate=(-45, 45), # rotate by -45 to +45 degrees
35 | shear=(-16, 16), # shear by -16 to +16 degrees
36 | # use nearest neighbour or bilinear interpolation (fast)
37 | order=[0, 1],
38 | # if mode is constant, use a cval between 0 and 255
39 | mode='constant',
40 | cval=(0, 255),
41 | # use any of scikit-image's warping modes
42 | # (see 2nd image from the top for examples)
43 | ),
44 | iaa.Sometimes(0.3, iaa.Crop(percent=(0.3, 0.5)))])
45 | ])
46 |
47 |
48 | def _load_augmentation_aug_non_geometric():
49 | return iaa.Sequential([
50 | iaa.Sometimes(0.3, iaa.Multiply((0.5, 1.5), per_channel=0.5)),
51 | iaa.Sometimes(0.2, iaa.JpegCompression(compression=(70, 99))),
52 | iaa.Sometimes(0.2, iaa.GaussianBlur(sigma=(0, 3.0))),
53 | iaa.Sometimes(0.2, iaa.MotionBlur(k=15, angle=[-45, 45])),
54 | iaa.Sometimes(0.2, iaa.MultiplyHue((0.5, 1.5))),
55 | iaa.Sometimes(0.2, iaa.MultiplySaturation((0.5, 1.5))),
56 | iaa.Sometimes(0.34, iaa.MultiplyHueAndSaturation((0.5, 1.5),
57 | per_channel=True)),
58 | iaa.Sometimes(0.34, iaa.Grayscale(alpha=(0.0, 1.0))),
59 | iaa.Sometimes(0.2, iaa.ChangeColorTemperature((1100, 10000))),
60 | iaa.Sometimes(0.1, iaa.GammaContrast((0.5, 2.0))),
61 | iaa.Sometimes(0.2, iaa.SigmoidContrast(gain=(3, 10),
62 | cutoff=(0.4, 0.6))),
63 | iaa.Sometimes(0.1, iaa.CLAHE()),
64 | iaa.Sometimes(0.1, iaa.HistogramEqualization()),
65 | iaa.Sometimes(0.2, iaa.LinearContrast((0.5, 2.0), per_channel=0.5)),
66 | iaa.Sometimes(0.1, iaa.Emboss(alpha=(0, 1.0), strength=(0, 2.0)))
67 | ])
68 |
69 |
70 | def _load_augmentation_aug_all2():
71 | return iaa.Sequential([
72 | iaa.Sometimes(0.65, _load_augmentation_aug_non_geometric()),
73 | iaa.Sometimes(0.65, _load_augmentation_aug_geometric())
74 | ])
75 |
76 |
77 | def _load_augmentation_aug_all():
78 | """ Load image augmentation model """
79 |
80 | def sometimes(aug):
81 | return iaa.Sometimes(0.5, aug)
82 |
83 | return iaa.Sequential(
84 | [
85 | # apply the following augmenters to most images
86 | iaa.Fliplr(0.5), # horizontally flip 50% of all images
87 | iaa.Flipud(0.2), # vertically flip 20% of all images
88 | # crop images by -5% to 10% of their height/width
89 | sometimes(iaa.CropAndPad(
90 | percent=(-0.05, 0.1),
91 | pad_mode='constant',
92 | pad_cval=(0, 255)
93 | )),
94 | sometimes(iaa.Affine(
95 | # scale images to 80-120% of their size, individually per axis
96 | scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},
97 | # translate by -20 to +20 percent (per axis)
98 | translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},
99 | rotate=(-45, 45), # rotate by -45 to +45 degrees
100 | shear=(-16, 16), # shear by -16 to +16 degrees
101 | # use nearest neighbour or bilinear interpolation (fast)
102 | order=[0, 1],
103 | # if mode is constant, use a cval between 0 and 255
104 | cval=(0, 255),
105 | # use any of scikit-image's warping modes
106 | # (see 2nd image from the top for examples)
107 | mode='constant'
108 | )),
109 | # execute 0 to 5 of the following (less important) augmenters per
110 | # image don't execute all of them, as that would often be way too
111 | # strong
112 | iaa.SomeOf((0, 5),
113 | [
114 | # convert images into their superpixel representation
115 | sometimes(iaa.Superpixels(
116 | p_replace=(0, 1.0), n_segments=(20, 200))),
117 | iaa.OneOf([
118 | # blur images with a sigma between 0 and 3.0
119 | iaa.GaussianBlur((0, 3.0)),
120 | # blur image using local means with kernel sizes
121 | # between 2 and 7
122 | iaa.AverageBlur(k=(2, 7)),
123 | # blur image using local medians with kernel sizes
124 | # between 2 and 7
125 | iaa.MedianBlur(k=(3, 11)),
126 | ]),
127 | iaa.Sharpen(alpha=(0, 1.0), lightness=(
128 | 0.75, 1.5)), # sharpen images
129 | iaa.Emboss(alpha=(0, 1.0), strength=(
130 | 0, 2.0)), # emboss images
131 | # search either for all edges or for directed edges,
132 | # blend the result with the original image using a blobby mask
133 | iaa.BlendAlphaSimplexNoise(iaa.OneOf([
134 | iaa.EdgeDetect(alpha=(0.5, 1.0)),
135 | iaa.DirectedEdgeDetect(
136 | alpha=(0.5, 1.0), direction=(0.0, 1.0)),
137 | ])),
138 | # add gaussian noise to images
139 | iaa.AdditiveGaussianNoise(loc=0, scale=(
140 | 0.0, 0.05*255), per_channel=0.5),
141 | iaa.OneOf([
142 | # randomly remove up to 10% of the pixels
143 | iaa.Dropout((0.01, 0.1), per_channel=0.5),
144 | iaa.CoarseDropout((0.03, 0.15), size_percent=(
145 | 0.02, 0.05), per_channel=0.2),
146 | ]),
147 | # invert color channels
148 | iaa.Invert(0.05, per_channel=True),
149 | # change brightness of images (by -10 to 10 of original value)
150 | iaa.Add((-10, 10), per_channel=0.5),
151 | # change hue and saturation
152 | iaa.AddToHueAndSaturation((-20, 20)),
153 | # either change the brightness of the whole image (sometimes
154 | # per channel) or change the brightness of subareas
155 | iaa.OneOf([
156 | iaa.Multiply(
157 | (0.5, 1.5), per_channel=0.5),
158 | iaa.BlendAlphaFrequencyNoise(
159 | exponent=(-4, 0),
160 | foreground=iaa.Multiply(
161 | (0.5, 1.5), per_channel=True),
162 | background=iaa.contrast.LinearContrast(
163 | (0.5, 2.0))
164 | )
165 | ]),
166 | # improve or worsen the contrast
167 | iaa.contrast.LinearContrast((0.5, 2.0), per_channel=0.5),
168 | iaa.Grayscale(alpha=(0.0, 1.0)),
169 | # move pixels locally around (with random strengths)
170 | sometimes(iaa.ElasticTransformation(
171 | alpha=(0.5, 3.5), sigma=0.25)),
172 | # sometimes move parts of the image around
173 | sometimes(iaa.PiecewiseAffine(scale=(0.01, 0.05))),
174 | sometimes(iaa.PerspectiveTransform(scale=(0.01, 0.1)))
175 | ],
176 | random_order=True
177 | )
178 | ],
179 | random_order=True
180 | )
181 |
182 |
183 | augmentation_functions = {
184 | "aug_all": _load_augmentation_aug_all,
185 | "aug_all2": _load_augmentation_aug_all2,
186 | "aug_geometric": _load_augmentation_aug_geometric,
187 | "aug_non_geometric": _load_augmentation_aug_non_geometric
188 | }
189 |
190 |
191 | def _load_augmentation(augmentation_name="aug_all"):
192 |
193 | global IMAGE_AUGMENTATION_SEQUENCE
194 |
195 | if augmentation_name not in augmentation_functions:
196 | raise ValueError("Augmentation name not supported")
197 |
198 | IMAGE_AUGMENTATION_SEQUENCE = augmentation_functions[augmentation_name]()
199 |
200 |
201 | def _augment_seg(img, seg, augmentation_name="aug_all", other_imgs=None):
202 |
203 | global loaded_augmentation_name
204 |
205 | if (not IMAGE_AUGMENTATION_SEQUENCE) or\
206 | (augmentation_name != loaded_augmentation_name):
207 | _load_augmentation(augmentation_name)
208 | loaded_augmentation_name = augmentation_name
209 |
210 | # Create a deterministic augmentation from the random one
211 | aug_det = IMAGE_AUGMENTATION_SEQUENCE.to_deterministic()
212 | # Augment the input image
213 | image_aug = aug_det.augment_image(img)
214 |
215 | if other_imgs is not None:
216 | image_aug = [image_aug]
217 |
218 | for other_img in other_imgs:
219 | image_aug.append(aug_det.augment_image(other_img))
220 |
221 | segmap = ia.SegmentationMapsOnImage(
222 | seg, shape=img.shape)
223 | segmap_aug = aug_det.augment_segmentation_maps(segmap)
224 | segmap_aug = segmap_aug.get_arr()
225 |
226 | return image_aug, segmap_aug
227 |
228 |
229 | def _custom_augment_seg(img, seg, augmentation_function, other_imgs=None):
230 | augmentation_functions['custom_aug'] = augmentation_function
231 |
232 | return _augment_seg(img, seg, "custom_aug", other_imgs=other_imgs)
233 |
234 |
235 | def _try_n_times(fn, n, *args, **kargs):
236 | """ Try a function N times """
237 | attempts = 0
238 | while attempts < n:
239 | try:
240 | return fn(*args, **kargs)
241 | except Exception:
242 | attempts += 1
243 |
244 | return fn(*args, **kargs)
245 |
246 |
247 | def augment_seg(img, seg, augmentation_name="aug_all", other_imgs=None):
248 | return _try_n_times(_augment_seg, IMAGE_AUGMENTATION_NUM_TRIES,
249 | img, seg, augmentation_name=augmentation_name,
250 | other_imgs=other_imgs)
251 |
252 |
253 | def custom_augment_seg(img, seg, augmentation_function, other_imgs=None):
254 | return _try_n_times(_custom_augment_seg, IMAGE_AUGMENTATION_NUM_TRIES,
255 | img, seg, augmentation_function=augmentation_function,
256 | other_imgs=other_imgs)
--------------------------------------------------------------------------------
/keras_segmentation/data_utils/data_loader.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import os
3 | import random
4 | import six
5 | import numpy as np
6 | import cv2
7 |
8 | try:
9 | from collections.abc import Sequence
10 | except ImportError:
11 | from collections import Sequence
12 |
13 | try:
14 | from tqdm import tqdm
15 | except ImportError:
16 | print("tqdm not found, disabling progress bars")
17 |
18 | def tqdm(iter):
19 | return iter
20 |
21 |
22 | from ..models.config import IMAGE_ORDERING
23 | from .augmentation import augment_seg, custom_augment_seg
24 |
25 | DATA_LOADER_SEED = 0
26 |
27 | random.seed(DATA_LOADER_SEED)
28 | class_colors = [(random.randint(0, 255), random.randint(
29 | 0, 255), random.randint(0, 255)) for _ in range(5000)]
30 |
31 |
32 | ACCEPTABLE_IMAGE_FORMATS = [".jpg", ".jpeg", ".png", ".bmp"]
33 | ACCEPTABLE_SEGMENTATION_FORMATS = [".png", ".bmp"]
34 |
35 |
36 | class DataLoaderError(Exception):
37 | pass
38 |
39 |
40 |
41 | def get_image_list_from_path(images_path):
42 | image_files = []
43 | for dir_entry in os.listdir(images_path):
44 | if os.path.isfile(os.path.join(images_path, dir_entry)) and \
45 | os.path.splitext(dir_entry)[1] in ACCEPTABLE_IMAGE_FORMATS:
46 | file_name, file_extension = os.path.splitext(dir_entry)
47 | image_files.append(os.path.join(images_path, dir_entry))
48 | return image_files
49 |
50 |
51 | def get_pairs_from_paths(images_path, segs_path, ignore_non_matching=False, other_inputs_paths=None):
52 | """ Find all the images from the images_path directory and
53 | the segmentation images from the segs_path directory
54 | while checking integrity of data """
55 |
56 |
57 |
58 | image_files = []
59 | segmentation_files = {}
60 |
61 | for dir_entry in os.listdir(images_path):
62 | if os.path.isfile(os.path.join(images_path, dir_entry)) and \
63 | os.path.splitext(dir_entry)[1] in ACCEPTABLE_IMAGE_FORMATS:
64 | file_name, file_extension = os.path.splitext(dir_entry)
65 | image_files.append((file_name, file_extension,
66 | os.path.join(images_path, dir_entry)))
67 |
68 | if other_inputs_paths is not None:
69 | other_inputs_files = []
70 |
71 | for i, other_inputs_path in enumerate(other_inputs_paths):
72 | temp = []
73 |
74 | for y, dir_entry in enumerate(os.listdir(other_inputs_path)):
75 | if os.path.isfile(os.path.join(other_inputs_path, dir_entry)) and \
76 | os.path.splitext(dir_entry)[1] in ACCEPTABLE_IMAGE_FORMATS:
77 | file_name, file_extension = os.path.splitext(dir_entry)
78 |
79 | temp.append((file_name, file_extension,
80 | os.path.join(other_inputs_path, dir_entry)))
81 |
82 | other_inputs_files.append(temp)
83 |
84 | for dir_entry in os.listdir(segs_path):
85 | if os.path.isfile(os.path.join(segs_path, dir_entry)) and \
86 | os.path.splitext(dir_entry)[1] in ACCEPTABLE_SEGMENTATION_FORMATS:
87 | file_name, file_extension = os.path.splitext(dir_entry)
88 | full_dir_entry = os.path.join(segs_path, dir_entry)
89 | if file_name in segmentation_files:
90 | raise DataLoaderError("Segmentation file with filename {0}"
91 | " already exists and is ambiguous to"
92 | " resolve with path {1}."
93 | " Please remove or rename the latter."
94 | .format(file_name, full_dir_entry))
95 |
96 | segmentation_files[file_name] = (file_extension, full_dir_entry)
97 |
98 | return_value = []
99 | # Match the images and segmentations
100 | for image_file, _, image_full_path in image_files:
101 | if image_file in segmentation_files:
102 | if other_inputs_paths is not None:
103 | other_inputs = []
104 | for file_paths in other_inputs_files:
105 | success = False
106 |
107 | for (other_file, _, other_full_path) in file_paths:
108 | if image_file == other_file:
109 | other_inputs.append(other_full_path)
110 | success = True
111 | break
112 |
113 | if not success:
114 |                         raise ValueError("No matching additional input found for image {0}".format(image_file))
115 |
116 | return_value.append((image_full_path,
117 | segmentation_files[image_file][1], other_inputs))
118 | else:
119 | return_value.append((image_full_path,
120 | segmentation_files[image_file][1]))
121 | elif ignore_non_matching:
122 | continue
123 | else:
124 | # Error out
125 | raise DataLoaderError("No corresponding segmentation "
126 | "found for image {0}."
127 | .format(image_full_path))
128 |
129 | return return_value
130 |
131 |
132 | def get_image_array(image_input,
133 | width, height,
134 | imgNorm="sub_mean", ordering='channels_first', read_image_type=1):
135 | """ Load image array from input """
136 |
137 | if type(image_input) is np.ndarray:
138 | # It is already an array, use it as it is
139 | img = image_input
140 | elif isinstance(image_input, six.string_types):
141 | if not os.path.isfile(image_input):
142 | raise DataLoaderError("get_image_array: path {0} doesn't exist"
143 | .format(image_input))
144 | img = cv2.imread(image_input, read_image_type)
145 | else:
146 | raise DataLoaderError("get_image_array: Can't process input type {0}"
147 | .format(str(type(image_input))))
148 |
149 | if imgNorm == "sub_and_divide":
150 | img = np.float32(cv2.resize(img, (width, height))) / 127.5 - 1
151 | elif imgNorm == "sub_mean":
152 | img = cv2.resize(img, (width, height))
153 | img = img.astype(np.float32)
154 | img = np.atleast_3d(img)
155 |
156 | means = [103.939, 116.779, 123.68]
157 |
158 | for i in range(min(img.shape[2], len(means))):
159 | img[:, :, i] -= means[i]
160 |
161 | img = img[:, :, ::-1]
162 | elif imgNorm == "divide":
163 | img = cv2.resize(img, (width, height))
164 | img = img.astype(np.float32)
165 | img = img/255.0
166 |
167 | if ordering == 'channels_first':
168 | img = np.rollaxis(img, 2, 0)
169 | return img
170 |
171 |
172 | def get_segmentation_array(image_input, nClasses,
173 | width, height, no_reshape=False, read_image_type=1):
174 | """ Load segmentation array from input """
175 |
176 | seg_labels = np.zeros((height, width, nClasses))
177 |
178 | if type(image_input) is np.ndarray:
179 | # It is already an array, use it as it is
180 | img = image_input
181 | elif isinstance(image_input, six.string_types):
182 | if not os.path.isfile(image_input):
183 | raise DataLoaderError("get_segmentation_array: "
184 | "path {0} doesn't exist".format(image_input))
185 | img = cv2.imread(image_input, read_image_type)
186 | else:
187 | raise DataLoaderError("get_segmentation_array: "
188 | "Can't process input type {0}"
189 | .format(str(type(image_input))))
190 |
191 | img = cv2.resize(img, (width, height), interpolation=cv2.INTER_NEAREST)
192 | img = img[:, :, 0]
193 |
194 | for c in range(nClasses):
195 | seg_labels[:, :, c] = (img == c).astype(int)
196 |
197 | if not no_reshape:
198 | seg_labels = np.reshape(seg_labels, (width*height, nClasses))
199 |
200 | return seg_labels
201 |
202 |
203 | def verify_segmentation_dataset(images_path, segs_path,
204 | n_classes, show_all_errors=False):
205 | try:
206 | img_seg_pairs = get_pairs_from_paths(images_path, segs_path)
207 | if not len(img_seg_pairs):
208 | print("Couldn't load any data from images_path: "
209 | "{0} and segmentations path: {1}"
210 | .format(images_path, segs_path))
211 | return False
212 |
213 | return_value = True
214 | for im_fn, seg_fn in tqdm(img_seg_pairs):
215 | img = cv2.imread(im_fn)
216 | seg = cv2.imread(seg_fn)
217 | # Check dimensions match
218 |             if img.shape != seg.shape:
219 | return_value = False
220 | print("The size of image {0} and its segmentation {1} "
221 | "doesn't match (possibly the files are corrupt)."
222 | .format(im_fn, seg_fn))
223 | if not show_all_errors:
224 | break
225 | else:
226 | max_pixel_value = np.max(seg[:, :, 0])
227 | if max_pixel_value >= n_classes:
228 | return_value = False
229 | print("The pixel values of the segmentation image {0} "
230 | "violating range [0, {1}]. "
231 | "Found maximum pixel value {2}"
232 | .format(seg_fn, str(n_classes - 1), max_pixel_value))
233 | if not show_all_errors:
234 | break
235 | if return_value:
236 | print("Dataset verified! ")
237 | else:
238 | print("Dataset not verified!")
239 | return return_value
240 | except DataLoaderError as e:
241 | print("Found error during data loading\n{0}".format(str(e)))
242 | return False
243 |
244 |
245 | def image_segmentation_generator(images_path, segs_path, batch_size,
246 | n_classes, input_height, input_width,
247 | output_height, output_width,
248 | do_augment=False,
249 | augmentation_name="aug_all",
250 | custom_augmentation=None,
251 | other_inputs_paths=None, preprocessing=None,
252 | read_image_type=cv2.IMREAD_COLOR , ignore_segs=False ):
253 |
254 |
255 | if not ignore_segs:
256 | img_seg_pairs = get_pairs_from_paths(images_path, segs_path, other_inputs_paths=other_inputs_paths)
257 | random.shuffle(img_seg_pairs)
258 | zipped = itertools.cycle(img_seg_pairs)
259 | else:
260 | img_list = get_image_list_from_path( images_path )
261 | random.shuffle( img_list )
262 | img_list_gen = itertools.cycle( img_list )
263 |
264 |
265 | while True:
266 | X = []
267 | Y = []
268 | for _ in range(batch_size):
269 | if other_inputs_paths is None:
270 |
271 | if ignore_segs:
272 | im = next( img_list_gen )
273 | seg = None
274 | else:
275 | im, seg = next(zipped)
276 | seg = cv2.imread(seg, 1)
277 |
278 | im = cv2.imread(im, read_image_type)
279 |
280 |
281 | if do_augment:
282 |
283 |                 assert not ignore_segs, "Not supported yet"
284 |
285 | if custom_augmentation is None:
286 | im, seg[:, :, 0] = augment_seg(im, seg[:, :, 0],
287 | augmentation_name)
288 | else:
289 | im, seg[:, :, 0] = custom_augment_seg(im, seg[:, :, 0],
290 | custom_augmentation)
291 |
292 | if preprocessing is not None:
293 | im = preprocessing(im)
294 |
295 | X.append(get_image_array(im, input_width,
296 | input_height, ordering=IMAGE_ORDERING))
297 | else:
298 |
299 |                 assert not ignore_segs, "Not supported yet"
300 |
301 | im, seg, others = next(zipped)
302 |
303 | im = cv2.imread(im, read_image_type)
304 | seg = cv2.imread(seg, 1)
305 |
306 | oth = []
307 | for f in others:
308 | oth.append(cv2.imread(f, read_image_type))
309 |
310 | if do_augment:
311 | if custom_augmentation is None:
312 | ims, seg[:, :, 0] = augment_seg(im, seg[:, :, 0],
313 | augmentation_name, other_imgs=oth)
314 | else:
315 | ims, seg[:, :, 0] = custom_augment_seg(im, seg[:, :, 0],
316 | custom_augmentation, other_imgs=oth)
317 | else:
318 | ims = [im]
319 | ims.extend(oth)
320 |
321 | oth = []
322 | for i, image in enumerate(ims):
323 | oth_im = get_image_array(image, input_width,
324 | input_height, ordering=IMAGE_ORDERING)
325 |
326 | if preprocessing is not None:
327 | if isinstance(preprocessing, Sequence):
328 | oth_im = preprocessing[i](oth_im)
329 | else:
330 | oth_im = preprocessing(oth_im)
331 |
332 | oth.append(oth_im)
333 |
334 | X.append(oth)
335 |
336 | if not ignore_segs:
337 | Y.append(get_segmentation_array(
338 | seg, n_classes, output_width, output_height))
339 |
340 | if ignore_segs:
341 | yield np.array(X)
342 | else:
343 | yield np.array(X), np.array(Y)
344 |
--------------------------------------------------------------------------------
/keras_segmentation/data_utils/visualize_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import random
4 |
5 | import numpy as np
6 | import cv2
7 |
8 | from .augmentation import augment_seg, custom_augment_seg
9 | from .data_loader import \
10 | get_pairs_from_paths, DATA_LOADER_SEED, class_colors, DataLoaderError
11 |
12 | random.seed(DATA_LOADER_SEED)
13 |
14 |
15 | def _get_colored_segmentation_image(img, seg, colors,
16 | n_classes, do_augment=False, augment_name='aug_all', custom_aug=None):
17 | """ Return a colored segmented image """
18 | seg_img = np.zeros_like(seg)
19 |
20 | if do_augment:
21 | if custom_aug is not None:
22 | img, seg[:, :, 0] = custom_augment_seg(img, seg[:, :, 0], augmentation_function=custom_aug)
23 | else:
24 | img, seg[:, :, 0] = augment_seg(img, seg[:, :, 0], augmentation_name=augment_name)
25 |
26 | for c in range(n_classes):
27 | seg_img[:, :, 0] += ((seg[:, :, 0] == c)
28 | * (colors[c][0])).astype('uint8')
29 | seg_img[:, :, 1] += ((seg[:, :, 0] == c)
30 | * (colors[c][1])).astype('uint8')
31 | seg_img[:, :, 2] += ((seg[:, :, 0] == c)
32 | * (colors[c][2])).astype('uint8')
33 |
34 | return img, seg_img
35 |
36 |
37 | def visualize_segmentation_dataset(images_path, segs_path, n_classes,
38 | do_augment=False, ignore_non_matching=False,
39 | no_show=False, image_size=None, augment_name="aug_all", custom_aug=None):
40 | try:
41 | # Get image-segmentation pairs
42 | img_seg_pairs = get_pairs_from_paths(
43 | images_path, segs_path,
44 | ignore_non_matching=ignore_non_matching)
45 |
46 | # Get the colors for the classes
47 | colors = class_colors
48 |
49 | print("Please press any key to display the next image")
50 | for im_fn, seg_fn in img_seg_pairs:
51 | img = cv2.imread(im_fn)
52 | seg = cv2.imread(seg_fn)
53 | print("Found the following classes in the segmentation image:",
54 | np.unique(seg))
55 | img, seg_img = _get_colored_segmentation_image(
56 | img, seg, colors,
57 | n_classes,
58 | do_augment=do_augment, augment_name=augment_name, custom_aug=custom_aug)
59 |
60 | if image_size is not None:
61 | img = cv2.resize(img, image_size)
62 | seg_img = cv2.resize(seg_img, image_size)
63 |
64 | print("Please press any key to display the next image")
65 | cv2.imshow("img", img)
66 | cv2.imshow("seg_img", seg_img)
67 | cv2.waitKey()
68 | except DataLoaderError as e:
69 | print("Found error during data loading\n{0}".format(str(e)))
70 | return False
71 |
72 |
73 | def visualize_segmentation_dataset_one(images_path, segs_path, n_classes,
74 | do_augment=False, no_show=False,
75 | ignore_non_matching=False):
76 |
77 | img_seg_pairs = get_pairs_from_paths(
78 | images_path, segs_path,
79 | ignore_non_matching=ignore_non_matching)
80 |
81 | colors = class_colors
82 |
83 | im_fn, seg_fn = random.choice(img_seg_pairs)
84 |
85 | img = cv2.imread(im_fn)
86 | seg = cv2.imread(seg_fn)
87 | print("Found the following classes "
88 | "in the segmentation image:", np.unique(seg))
89 |
90 | img, seg_img = _get_colored_segmentation_image(
91 | img, seg, colors,
92 | n_classes, do_augment=do_augment)
93 |
94 | if not no_show:
95 | cv2.imshow("img", img)
96 | cv2.imshow("seg_img", seg_img)
97 | cv2.waitKey()
98 |
99 | return img, seg_img
100 |
101 |
102 | if __name__ == "__main__":
103 | import argparse
104 | parser = argparse.ArgumentParser()
105 | parser.add_argument("--images", type=str)
106 | parser.add_argument("--annotations", type=str)
107 | parser.add_argument("--n_classes", type=int)
108 | args = parser.parse_args()
109 |
110 | visualize_segmentation_dataset(
111 | args.images, args.annotations, args.n_classes)
112 |
--------------------------------------------------------------------------------
/keras_segmentation/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | EPS = 1e-12
4 |
5 |
6 | def get_iou(gt, pr, n_classes):
7 | class_wise = np.zeros(n_classes)
8 | for cl in range(n_classes):
9 | intersection = np.sum((gt == cl)*(pr == cl))
10 | union = np.sum(np.maximum((gt == cl), (pr == cl)))
11 | iou = float(intersection)/(union + EPS)
12 | class_wise[cl] = iou
13 | return class_wise
14 |
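15 |
16 | # Illustrative usage sketch (not part of the original file): per-class and
17 | # mean IoU for two integer class maps; the arrays here are placeholders.
18 | if __name__ == "__main__":
19 |     gt = np.random.randint(0, 3, (4, 4))
20 |     pr = np.random.randint(0, 3, (4, 4))
21 |     per_class_iou = get_iou(gt, pr, 3)
22 |     print("per-class IoU:", per_class_iou)
23 |     print("mean IoU:", per_class_iou.mean())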
--------------------------------------------------------------------------------
/keras_segmentation/model_compression.py:
--------------------------------------------------------------------------------
1 | import keras
2 | import tensorflow as tf
3 |
4 | from tqdm import tqdm
5 | import numpy as np
6 | import six
7 | import os
8 | import json
9 | import sys
10 |
11 | from .data_utils.data_loader import image_segmentation_generator
12 | from .train import CheckpointsCallback
13 |
14 | from keras.models import Model
15 |
16 |
17 |
18 | def get_pairwise_similarities(feats):
19 |     feats_i = tf.reshape(feats, (-1, 1, feats.shape[1]*feats.shape[2], feats.shape[3]))
20 |     feats_j = tf.reshape(feats, (-1, feats.shape[1]*feats.shape[2], 1, feats.shape[3]))
21 |
22 |     feats_i = feats_i / ((tf.reduce_sum(feats_i**2, axis=-1))**0.5)[..., None]
23 |     feats_j = feats_j / ((tf.reduce_sum(feats_j**2, axis=-1))**0.5)[..., None]
24 |
25 |     feats_ixj = feats_i*feats_j
26 |
27 |     return tf.reduce_sum(feats_ixj, axis=-1)
28 |
29 |
30 |
31 | def pairwise_dist_loss(feats_t, feats_s):
32 |
33 |     # Max-pool the feature maps first so the pairwise similarity matrix stays small
34 |     pool_factor = 4
35 |
36 |     feats_t = tf.nn.max_pool(feats_t, (pool_factor, pool_factor), strides=(pool_factor, pool_factor), padding="VALID")
37 |     feats_s = tf.nn.max_pool(feats_s, (pool_factor, pool_factor), strides=(pool_factor, pool_factor), padding="VALID")
38 |
39 |     sims_t = get_pairwise_similarities(feats_t)
40 |     sims_s = get_pairwise_similarities(feats_s)
41 |     n_pixs = sims_s.shape[1]
42 |
43 |     return tf.reduce_sum(tf.reduce_sum(((sims_t - sims_s)**2), axis=1), axis=1)/(n_pixs**2)
44 |
45 |
46 |
47 | class Distiller(keras.Model):
48 |     def __init__(self, student, teacher, distilation_loss, feats_distilation_loss=None, feats_distilation_loss_w=0.1):
49 |         super(Distiller, self).__init__()
50 |         self.teacher = teacher
51 |         self.student = student
52 |         self.distilation_loss = distilation_loss
53 |
54 |         self.feats_distilation_loss = feats_distilation_loss
55 |         self.feats_distilation_loss_w = feats_distilation_loss_w
56 |
57 |         if feats_distilation_loss is not None:
58 |             try:
59 |                 s_feat_out = student.get_layer("seg_feats").output
60 |             except ValueError:
61 |                 s_feat_out = student.get_layer(student.seg_feats_layer_name).output
62 |
63 |
64 |             try:
65 |                 t_feat_out = teacher.get_layer("seg_feats").output
66 |             except ValueError:
67 |                 t_feat_out = teacher.get_layer(teacher.seg_feats_layer_name).output
68 |
69 |
70 |             self.student_feat_model = Model(student.input, s_feat_out)
71 |             self.teacher_feat_model = Model(teacher.input, t_feat_out)
72 |
73 |     def compile(
74 |         self,
75 |         optimizer,
76 |         metrics,
77 |     ):
78 |         super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
79 |
80 |
81 |
82 |     def train_step(self, data):
83 |         teacher_input, = data
84 |
85 |         student_input = tf.image.resize(teacher_input, (self.student.input_height, self.student.input_width))
86 |
87 |         teacher_predictions = self.teacher(teacher_input, training=False)
88 |         teacher_predictions_reshape = tf.reshape(teacher_predictions, (-1, self.teacher.output_height, self.teacher.output_width, self.teacher.output_shape[-1]))
89 |
90 |         # The teacher runs in inference mode; its outputs are the targets.
91 |         if self.feats_distilation_loss is not None:
92 |             teacher_feats = self.teacher_feat_model(teacher_input, training=False)
93 |
94 |         with tf.GradientTape() as tape:
95 |             student_predictions = self.student(student_input, training=True)
96 |             student_predictions_resize = tf.reshape(student_predictions, (-1, self.student.output_height, self.student.output_width, self.student.output_shape[-1]))
97 |             student_predictions_resize = tf.image.resize(student_predictions_resize, (self.teacher.output_height, self.teacher.output_width))
98 |
99 |             loss = self.distilation_loss(teacher_predictions_reshape, student_predictions_resize)
100 |
101 |             if self.feats_distilation_loss is not None:
102 |                 student_feats = self.student_feat_model(student_input, training=True)
103 |                 student_feats_resize = tf.image.resize(student_feats, (teacher_feats.shape[1], teacher_feats.shape[2]))
104 |                 loss += self.feats_distilation_loss_w*self.feats_distilation_loss(teacher_feats, student_feats_resize)
105 |
106 |         # Only the student's weights are updated; the teacher stays frozen.
107 |         trainable_vars = self.student.trainable_variables
108 |         gradients = tape.gradient(loss, trainable_vars)
109 |
110 |         self.optimizer.apply_gradients(zip(gradients, trainable_vars))
111 |         self.compiled_metrics.update_state(teacher_predictions_reshape, student_predictions_resize)
112 |
113 |         # Report the compiled metrics along with the raw distillation loss.
114 |         results = {m.name: m.result() for m in self.metrics}
115 |         results.update(
116 |             {"distillation_loss": loss}
117 |         )
118 |         return results
119 |
120 |
121 |
122 |
123 | # A simple custom fit loop, written because of an issue with Keras's built-in fit
124 | def fit_generator_custom(model, gen, epochs, steps_per_epoch, callback=None):
125 |     for ep in range(epochs):
126 |         print("Epoch %d/%d" % (ep + 1, epochs))
127 |         bar = tqdm(range(steps_per_epoch))
128 |         losses = []
129 |         for i in bar:
130 |             x = next(gen)
131 |             l = model.train_on_batch(x)
132 |             losses.append(l)
133 |             bar.set_description("Loss : %s" % str(np.mean(np.array(losses))))
134 |         if callback is not None:
135 |             callback.model = model.student
136 |             callback.on_epoch_end(ep)
137 |
138 |
139 | def perform_distilation(teacher_model, student_model, data_path, distilation_loss='kl',
140 |                         batch_size=6, checkpoints_path=None, epochs=32, steps_per_epoch=512,
141 |                         feats_distilation_loss=None, feats_distilation_loss_w=0.1):
142 |
143 |
144 |     losses_dict = {'l1': keras.losses.MeanAbsoluteError(), "l2": keras.losses.MeanSquaredError(), "kl": keras.losses.KLDivergence(), 'pa': pairwise_dist_loss}
145 |
146 |     if isinstance(distilation_loss, six.string_types):
147 |         distilation_loss = losses_dict[distilation_loss]
148 |
149 |     if isinstance(feats_distilation_loss, six.string_types):
150 |         feats_distilation_loss = losses_dict[feats_distilation_loss]
151 |
152 |
153 |     distill_model = Distiller(student=student_model, teacher=teacher_model, distilation_loss=distilation_loss, feats_distilation_loss=feats_distilation_loss, feats_distilation_loss_w=feats_distilation_loss_w)
154 |
155 |     img_gen = image_segmentation_generator(images_path=data_path, segs_path=None, batch_size=batch_size,
156 |                                            n_classes=teacher_model.n_classes, input_height=teacher_model.input_height, input_width=teacher_model.input_width,
157 |                                            output_height=None, output_width=None, ignore_segs=True)
158 |
159 |     distill_model.compile(
160 |         optimizer='adam',
161 |         metrics=[distilation_loss]
162 |     )
163 |
164 | if checkpoints_path is not None:
165 | config_file = checkpoints_path + "_config.json"
166 | dir_name = os.path.dirname(config_file)
167 |
168 | if not os.path.exists(dir_name):
169 | os.makedirs(dir_name)
170 |
171 | with open(config_file, "w") as f:
172 | json.dump({
173 | "model_class": student_model.model_name,
174 | "n_classes": student_model.n_classes,
175 | "input_height": student_model.input_height,
176 | "input_width": student_model.input_width,
177 | "output_height": student_model.output_height,
178 | "output_width": student_model.output_width
179 | }, f)
180 |
181 |         cb = CheckpointsCallback(checkpoints_path)
182 | else:
183 | cb = None
184 |
185 |     fit_generator_custom(distill_model, img_gen, steps_per_epoch=steps_per_epoch, epochs=epochs, callback=cb)
186 |
187 |     print("done")
188 |
189 |
190 |
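191 | # Illustrative usage sketch (not part of the original file): distilling a
192 | # larger teacher into a smaller student on a folder of unlabelled images.
193 | # The data path is a placeholder, and in practice the teacher would already
194 | # be trained; 'pa' selects the pairwise loss defined above.
195 | if __name__ == "__main__":
196 |     from keras_segmentation.models.unet import vgg_unet, unet_mini
197 |     teacher = vgg_unet(n_classes=20, input_height=416, input_width=608)
198 |     student = unet_mini(n_classes=20)
199 |     perform_distilation(teacher, student, "data/unlabelled_images/",
200 |                         distilation_loss='kl', feats_distilation_loss='pa',
201 |                         epochs=2, steps_per_epoch=8)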
--------------------------------------------------------------------------------
/keras_segmentation/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/keras_segmentation/models/__init__.py
--------------------------------------------------------------------------------
/keras_segmentation/models/_pspnet_2.py:
--------------------------------------------------------------------------------
1 | # This code was provided by Vladkryvoruchko, with small modifications by me.
2 |
3 | from math import ceil
4 | from sys import exit
5 | from keras import layers
6 | from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
7 | from keras.layers import BatchNormalization, Activation,\
8 | Input, Dropout, ZeroPadding2D
9 | from keras.layers import Concatenate, Add
10 | import tensorflow as tf
11 |
12 | from .config import IMAGE_ORDERING
13 | from .model_utils import get_segmentation_model
14 |
15 |
16 | learning_rate = 1e-3 # Layer specific learning rate
17 | # Weight decay not implemented
18 |
19 |
20 | def BN(name=""):
21 | return BatchNormalization(momentum=0.95, name=name, epsilon=1e-5)
22 |
23 |
24 | class Interp(layers.Layer):
25 |
26 | def __init__(self, new_size, **kwargs):
27 | self.new_size = new_size
28 | super(Interp, self).__init__(**kwargs)
29 |
30 | def build(self, input_shape):
31 | super(Interp, self).build(input_shape)
32 |
33 | def call(self, inputs, **kwargs):
34 | new_height, new_width = self.new_size
35 | try:
36 | resized = tf.image.resize(inputs, [new_height, new_width])
37 | except AttributeError:
38 | resized = tf.image.resize_images(inputs, [new_height, new_width],
39 | align_corners=True)
40 | return resized
41 |
42 | def compute_output_shape(self, input_shape):
43 | return tuple([None,
44 | self.new_size[0],
45 | self.new_size[1],
46 | input_shape[3]])
47 |
48 | def get_config(self):
49 | config = super(Interp, self).get_config()
50 | config['new_size'] = self.new_size
51 | return config
52 |
53 |
54 | # def Interp(x, shape):
55 | # new_height, new_width = shape
56 | # resized = tf.image.resize_images(x, [new_height, new_width],
57 | # align_corners=True)
58 | # return resized
59 |
60 |
61 | def residual_conv(prev, level, pad=1, lvl=1, sub_lvl=1, modify_stride=False):
62 | lvl = str(lvl)
63 | sub_lvl = str(sub_lvl)
64 | names = ["conv" + lvl + "_" + sub_lvl + "_1x1_reduce",
65 | "conv" + lvl + "_" + sub_lvl + "_1x1_reduce_bn",
66 | "conv" + lvl + "_" + sub_lvl + "_3x3",
67 | "conv" + lvl + "_" + sub_lvl + "_3x3_bn",
68 | "conv" + lvl + "_" + sub_lvl + "_1x1_increase",
69 | "conv" + lvl + "_" + sub_lvl + "_1x1_increase_bn"]
70 | if modify_stride is False:
71 | prev = Conv2D(64 * level, (1, 1), strides=(1, 1), name=names[0],
72 | use_bias=False)(prev)
73 | elif modify_stride is True:
74 | prev = Conv2D(64 * level, (1, 1), strides=(2, 2), name=names[0],
75 | use_bias=False)(prev)
76 |
77 | prev = BN(name=names[1])(prev)
78 | prev = Activation('relu')(prev)
79 |
80 | prev = ZeroPadding2D(padding=(pad, pad))(prev)
81 | prev = Conv2D(64 * level, (3, 3), strides=(1, 1), dilation_rate=pad,
82 | name=names[2], use_bias=False)(prev)
83 |
84 | prev = BN(name=names[3])(prev)
85 | prev = Activation('relu')(prev)
86 | prev = Conv2D(256 * level, (1, 1), strides=(1, 1), name=names[4],
87 | use_bias=False)(prev)
88 | prev = BN(name=names[5])(prev)
89 | return prev
90 |
91 |
92 | def short_convolution_branch(prev, level, lvl=1, sub_lvl=1,
93 | modify_stride=False):
94 | lvl = str(lvl)
95 | sub_lvl = str(sub_lvl)
96 | names = ["conv" + lvl + "_" + sub_lvl + "_1x1_proj",
97 | "conv" + lvl + "_" + sub_lvl + "_1x1_proj_bn"]
98 |
99 | if modify_stride is False:
100 | prev = Conv2D(256 * level, (1, 1), strides=(1, 1), name=names[0],
101 | use_bias=False)(prev)
102 | elif modify_stride is True:
103 | prev = Conv2D(256 * level, (1, 1), strides=(2, 2), name=names[0],
104 | use_bias=False)(prev)
105 |
106 | prev = BN(name=names[1])(prev)
107 | return prev
108 |
109 |
110 | def empty_branch(prev):
111 | return prev
112 |
113 |
114 | def residual_short(prev_layer, level, pad=1, lvl=1, sub_lvl=1,
115 | modify_stride=False):
116 | prev_layer = Activation('relu')(prev_layer)
117 | block_1 = residual_conv(prev_layer, level,
118 | pad=pad, lvl=lvl, sub_lvl=sub_lvl,
119 | modify_stride=modify_stride)
120 |
121 | block_2 = short_convolution_branch(prev_layer, level,
122 | lvl=lvl, sub_lvl=sub_lvl,
123 | modify_stride=modify_stride)
124 | added = Add()([block_1, block_2])
125 | return added
126 |
127 |
128 | def residual_empty(prev_layer, level, pad=1, lvl=1, sub_lvl=1):
129 | prev_layer = Activation('relu')(prev_layer)
130 |
131 | block_1 = residual_conv(prev_layer, level, pad=pad,
132 | lvl=lvl, sub_lvl=sub_lvl)
133 | block_2 = empty_branch(prev_layer)
134 | added = Add()([block_1, block_2])
135 | return added
136 |
137 |
138 | def ResNet(inp, layers):
139 | # Names for the first couple layers of model
140 | names = ["conv1_1_3x3_s2",
141 | "conv1_1_3x3_s2_bn",
142 | "conv1_2_3x3",
143 | "conv1_2_3x3_bn",
144 | "conv1_3_3x3",
145 | "conv1_3_3x3_bn"]
146 |
147 | # Short branch(only start of network)
148 |
149 | cnv1 = Conv2D(64, (3, 3), strides=(2, 2), padding='same', name=names[0],
150 | use_bias=False)(inp) # "conv1_1_3x3_s2"
151 | bn1 = BN(name=names[1])(cnv1) # "conv1_1_3x3_s2/bn"
152 | relu1 = Activation('relu')(bn1) # "conv1_1_3x3_s2/relu"
153 |
154 | cnv1 = Conv2D(64, (3, 3), strides=(1, 1), padding='same', name=names[2],
155 | use_bias=False)(relu1) # "conv1_2_3x3"
156 | bn1 = BN(name=names[3])(cnv1) # "conv1_2_3x3/bn"
157 | relu1 = Activation('relu')(bn1) # "conv1_2_3x3/relu"
158 |
159 | cnv1 = Conv2D(128, (3, 3), strides=(1, 1), padding='same', name=names[4],
160 | use_bias=False)(relu1) # "conv1_3_3x3"
161 | bn1 = BN(name=names[5])(cnv1) # "conv1_3_3x3/bn"
162 | relu1 = Activation('relu')(bn1) # "conv1_3_3x3/relu"
163 |
164 | res = MaxPooling2D(pool_size=(3, 3), padding='same',
165 | strides=(2, 2))(relu1) # "pool1_3x3_s2"
166 |
167 | # ---Residual layers(body of network)
168 |
169 | """
170 | Modify_stride --Used only once in first 3_1 convolutions block.
171 | changes stride of first convolution from 1 -> 2
172 | """
173 |
174 | # 2_1- 2_3
175 | res = residual_short(res, 1, pad=1, lvl=2, sub_lvl=1)
176 | for i in range(2):
177 | res = residual_empty(res, 1, pad=1, lvl=2, sub_lvl=i + 2)
178 |
179 | # 3_1 - 3_3
180 | res = residual_short(res, 2, pad=1, lvl=3, sub_lvl=1, modify_stride=True)
181 | for i in range(3):
182 | res = residual_empty(res, 2, pad=1, lvl=3, sub_lvl=i + 2)
183 | if layers == 50:
184 | # 4_1 - 4_6
185 | res = residual_short(res, 4, pad=2, lvl=4, sub_lvl=1)
186 | for i in range(5):
187 | res = residual_empty(res, 4, pad=2, lvl=4, sub_lvl=i + 2)
188 | elif layers == 101:
189 | # 4_1 - 4_23
190 | res = residual_short(res, 4, pad=2, lvl=4, sub_lvl=1)
191 | for i in range(22):
192 | res = residual_empty(res, 4, pad=2, lvl=4, sub_lvl=i + 2)
193 | else:
194 | print("This ResNet is not implemented")
195 |
196 | # 5_1 - 5_3
197 | res = residual_short(res, 8, pad=4, lvl=5, sub_lvl=1)
198 | for i in range(2):
199 | res = residual_empty(res, 8, pad=4, lvl=5, sub_lvl=i + 2)
200 |
201 | res = Activation('relu')(res)
202 | return res
203 |
204 |
205 | def interp_block(prev_layer, level, feature_map_shape, input_shape):
206 | if input_shape == (473, 473):
207 | kernel_strides_map = {1: 60,
208 | 2: 30,
209 | 3: 20,
210 | 6: 10}
211 | elif input_shape == (713, 713):
212 | kernel_strides_map = {1: 90,
213 | 2: 45,
214 | 3: 30,
215 | 6: 15}
216 | else:
217 | print("Pooling parameters for input shape ",
218 | input_shape, " are not defined.")
219 | exit(1)
220 |
221 | names = [
222 | "conv5_3_pool" + str(level) + "_conv",
223 | "conv5_3_pool" + str(level) + "_conv_bn"
224 | ]
225 | kernel = (kernel_strides_map[level], kernel_strides_map[level])
226 | strides = (kernel_strides_map[level], kernel_strides_map[level])
227 | prev_layer = AveragePooling2D(kernel, strides=strides)(prev_layer)
228 | prev_layer = Conv2D(512, (1, 1), strides=(1, 1), name=names[0],
229 | use_bias=False)(prev_layer)
230 | prev_layer = BN(name=names[1])(prev_layer)
231 | prev_layer = Activation('relu')(prev_layer)
232 | # prev_layer = Lambda(Interp, arguments={
233 | # 'shape': feature_map_shape})(prev_layer)
234 | prev_layer = Interp(feature_map_shape)(prev_layer)
235 | return prev_layer
236 |
237 |
238 | def build_pyramid_pooling_module(res, input_shape):
239 | """Build the Pyramid Pooling Module."""
240 | # ---PSPNet concat layers with Interpolation
241 | feature_map_size = tuple(int(ceil(input_dim / 8.0))
242 | for input_dim in input_shape)
243 |
244 | interp_block1 = interp_block(res, 1, feature_map_size, input_shape)
245 | interp_block2 = interp_block(res, 2, feature_map_size, input_shape)
246 | interp_block3 = interp_block(res, 3, feature_map_size, input_shape)
247 | interp_block6 = interp_block(res, 6, feature_map_size, input_shape)
248 |
249 | # concat all these layers. resulted
250 | # shape=(1,feature_map_size_x,feature_map_size_y,4096)
251 | res = Concatenate()([res,
252 | interp_block6,
253 | interp_block3,
254 | interp_block2,
255 | interp_block1])
256 | return res
257 |
258 |
259 | def _build_pspnet(nb_classes, resnet_layers, input_shape,
260 | activation='softmax', channels=3):
261 |
262 | assert IMAGE_ORDERING == 'channels_last'
263 |
264 | inp = Input((input_shape[0], input_shape[1], channels))
265 |
266 | res = ResNet(inp, layers=resnet_layers)
267 |
268 | psp = build_pyramid_pooling_module(res, input_shape)
269 |
270 | x = Conv2D(512, (3, 3), strides=(1, 1), padding="same", name="conv5_4",
271 | use_bias=False)(psp)
272 | x = BN(name="conv5_4_bn")(x)
273 | x = Activation('relu')(x)
274 | x = Dropout(0.1)(x)
275 |
276 | x = Conv2D(nb_classes, (1, 1), strides=(1, 1), name="conv6")(x)
277 | # x = Lambda(Interp, arguments={'shape': (
278 | # input_shape[0], input_shape[1])})(x)
279 | x = Interp([input_shape[0], input_shape[1]])(x)
280 |
281 | model = get_segmentation_model(inp, x)
282 | model.seg_feats_layer_name = "conv5_4"
283 |
284 | return model
285 |
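286 |
287 | # Illustrative usage sketch (not part of the original file): building the
288 | # PSPNet-50 graph directly. The input must be 473x473 or 713x713, since the
289 | # pyramid pooling kernel sizes are only defined for those two shapes.
290 | if __name__ == "__main__":
291 |     model = _build_pspnet(nb_classes=21, resnet_layers=50,
292 |                           input_shape=(473, 473))
293 |     model.summary()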
--------------------------------------------------------------------------------
/keras_segmentation/models/all_models.py:
--------------------------------------------------------------------------------
1 | from . import pspnet
2 | from . import unet
3 | from . import segnet
4 | from . import fcn
5 | model_from_name = {}
6 |
7 |
8 | model_from_name["fcn_8"] = fcn.fcn_8
9 | model_from_name["fcn_32"] = fcn.fcn_32
10 | model_from_name["fcn_8_vgg"] = fcn.fcn_8_vgg
11 | model_from_name["fcn_32_vgg"] = fcn.fcn_32_vgg
12 | model_from_name["fcn_8_resnet50"] = fcn.fcn_8_resnet50
13 | model_from_name["fcn_32_resnet50"] = fcn.fcn_32_resnet50
14 | model_from_name["fcn_8_mobilenet"] = fcn.fcn_8_mobilenet
15 | model_from_name["fcn_32_mobilenet"] = fcn.fcn_32_mobilenet
16 |
17 |
18 | model_from_name["pspnet"] = pspnet.pspnet
19 | model_from_name["vgg_pspnet"] = pspnet.vgg_pspnet
20 | model_from_name["resnet50_pspnet"] = pspnet.resnet50_pspnet
21 |
22 |
23 |
24 |
25 | model_from_name["pspnet_50"] = pspnet.pspnet_50
26 | model_from_name["pspnet_101"] = pspnet.pspnet_101
27 |
28 |
29 | # model_from_name["mobilenet_pspnet"] = pspnet.mobilenet_pspnet
30 |
31 |
32 | model_from_name["unet_mini"] = unet.unet_mini
33 | model_from_name["unet"] = unet.unet
34 | model_from_name["vgg_unet"] = unet.vgg_unet
35 | model_from_name["resnet50_unet"] = unet.resnet50_unet
36 | model_from_name["mobilenet_unet"] = unet.mobilenet_unet
37 |
38 |
39 | model_from_name["segnet"] = segnet.segnet
40 | model_from_name["vgg_segnet"] = segnet.vgg_segnet
41 | model_from_name["resnet50_segnet"] = segnet.resnet50_segnet
42 | model_from_name["mobilenet_segnet"] = segnet.mobilenet_segnet
43 |
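44 |
45 | # Illustrative usage sketch (not part of the original file): any registered
46 | # constructor can be looked up by its string name, which lets callers
47 | # resolve a model from a config value.
48 | if __name__ == "__main__":
49 |     model = model_from_name["vgg_unet"](n_classes=20,
50 |                                         input_height=416, input_width=608)
51 |     print(model.model_name, model.n_classes)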
--------------------------------------------------------------------------------
/keras_segmentation/models/basic_models.py:
--------------------------------------------------------------------------------
1 | from keras.models import *
2 | from keras.layers import *
3 | import keras.backend as K
4 |
5 | from .config import IMAGE_ORDERING
6 |
7 |
8 | def vanilla_encoder(input_height=224, input_width=224, channels=3):
9 |
10 | kernel = 3
11 | filter_size = 64
12 | pad = 1
13 | pool_size = 2
14 |
15 | if IMAGE_ORDERING == 'channels_first':
16 | img_input = Input(shape=(channels, input_height, input_width))
17 | elif IMAGE_ORDERING == 'channels_last':
18 | img_input = Input(shape=(input_height, input_width, channels))
19 |
20 | x = img_input
21 | levels = []
22 |
23 | x = (ZeroPadding2D((pad, pad), data_format=IMAGE_ORDERING))(x)
24 | x = (Conv2D(filter_size, (kernel, kernel),
25 | data_format=IMAGE_ORDERING, padding='valid'))(x)
26 | x = (BatchNormalization())(x)
27 | x = (Activation('relu'))(x)
28 | x = (MaxPooling2D((pool_size, pool_size), data_format=IMAGE_ORDERING))(x)
29 | levels.append(x)
30 |
31 | x = (ZeroPadding2D((pad, pad), data_format=IMAGE_ORDERING))(x)
32 | x = (Conv2D(128, (kernel, kernel), data_format=IMAGE_ORDERING,
33 | padding='valid'))(x)
34 | x = (BatchNormalization())(x)
35 | x = (Activation('relu'))(x)
36 | x = (MaxPooling2D((pool_size, pool_size), data_format=IMAGE_ORDERING))(x)
37 | levels.append(x)
38 |
39 | for _ in range(3):
40 | x = (ZeroPadding2D((pad, pad), data_format=IMAGE_ORDERING))(x)
41 | x = (Conv2D(256, (kernel, kernel),
42 | data_format=IMAGE_ORDERING, padding='valid'))(x)
43 | x = (BatchNormalization())(x)
44 | x = (Activation('relu'))(x)
45 | x = (MaxPooling2D((pool_size, pool_size),
46 | data_format=IMAGE_ORDERING))(x)
47 | levels.append(x)
48 |
49 | return img_input, levels
50 |
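51 |
52 | # Illustrative usage sketch (not part of the original file): the encoder
53 | # returns the input tensor plus five feature levels; printing each level's
54 | # shape shows the progressive downsampling.
55 | if __name__ == "__main__":
56 |     img_input, levels = vanilla_encoder(input_height=224, input_width=224)
57 |     for level in levels:
58 |         print(Model(img_input, level).output_shape)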
--------------------------------------------------------------------------------
/keras_segmentation/models/config.py:
--------------------------------------------------------------------------------
1 | IMAGE_ORDERING_CHANNELS_LAST = "channels_last"
2 | IMAGE_ORDERING_CHANNELS_FIRST = "channels_first"
3 |
4 | # Default IMAGE_ORDERING = channels_last
5 | IMAGE_ORDERING = IMAGE_ORDERING_CHANNELS_LAST
6 |
--------------------------------------------------------------------------------
/keras_segmentation/models/fcn.py:
--------------------------------------------------------------------------------
1 | from keras.models import *
2 | from keras.layers import *
3 |
4 | from .config import IMAGE_ORDERING
5 | from .model_utils import get_segmentation_model
6 | from .vgg16 import get_vgg_encoder
7 | from .mobilenet import get_mobilenet_encoder
8 | from .basic_models import vanilla_encoder
9 | from .resnet50 import get_resnet50_encoder
10 |
11 |
12 | # crop o1 with respect to o2 so both have the same spatial dimensions
13 | def crop(o1, o2, i):
14 | o_shape2 = Model(i, o2).output_shape
15 |
16 | if IMAGE_ORDERING == 'channels_first':
17 | output_height2 = o_shape2[2]
18 | output_width2 = o_shape2[3]
19 | else:
20 | output_height2 = o_shape2[1]
21 | output_width2 = o_shape2[2]
22 |
23 | o_shape1 = Model(i, o1).output_shape
24 | if IMAGE_ORDERING == 'channels_first':
25 | output_height1 = o_shape1[2]
26 | output_width1 = o_shape1[3]
27 | else:
28 | output_height1 = o_shape1[1]
29 | output_width1 = o_shape1[2]
30 |
31 | cx = abs(output_width1 - output_width2)
32 | cy = abs(output_height2 - output_height1)
33 |
34 | if output_width1 > output_width2:
35 | o1 = Cropping2D(cropping=((0, 0), (0, cx)),
36 | data_format=IMAGE_ORDERING)(o1)
37 | else:
38 | o2 = Cropping2D(cropping=((0, 0), (0, cx)),
39 | data_format=IMAGE_ORDERING)(o2)
40 |
41 | if output_height1 > output_height2:
42 | o1 = Cropping2D(cropping=((0, cy), (0, 0)),
43 | data_format=IMAGE_ORDERING)(o1)
44 | else:
45 | o2 = Cropping2D(cropping=((0, cy), (0, 0)),
46 | data_format=IMAGE_ORDERING)(o2)
47 |
48 | return o1, o2
49 |
50 |
51 | def fcn_8(n_classes, encoder=vanilla_encoder, input_height=416,
52 | input_width=608, channels=3):
53 |
54 | img_input, levels = encoder(
55 | input_height=input_height, input_width=input_width, channels=channels)
56 | [f1, f2, f3, f4, f5] = levels
57 |
58 | o = f5
59 |
60 | o = (Conv2D(4096, (7, 7), activation='relu',
61 | padding='same', data_format=IMAGE_ORDERING))(o)
62 | o = Dropout(0.5)(o)
63 | o = (Conv2D(4096, (1, 1), activation='relu',
64 | padding='same', data_format=IMAGE_ORDERING))(o)
65 | o = Dropout(0.5)(o)
66 |
67 | o = (Conv2D(n_classes, (1, 1), kernel_initializer='he_normal',
68 | data_format=IMAGE_ORDERING))(o)
69 | o = Conv2DTranspose(n_classes, kernel_size=(4, 4), strides=(
70 | 2, 2), use_bias=False, data_format=IMAGE_ORDERING)(o)
71 |
72 | o2 = f4
73 | o2 = (Conv2D(n_classes, (1, 1), kernel_initializer='he_normal',
74 | data_format=IMAGE_ORDERING))(o2)
75 |
76 | o, o2 = crop(o, o2, img_input)
77 |
78 | o = Add()([o, o2])
79 |
80 | o = Conv2DTranspose(n_classes, kernel_size=(4, 4), strides=(
81 | 2, 2), use_bias=False, data_format=IMAGE_ORDERING)(o)
82 | o2 = f3
83 | o2 = (Conv2D(n_classes, (1, 1), kernel_initializer='he_normal',
84 | data_format=IMAGE_ORDERING))(o2)
85 | o2, o = crop(o2, o, img_input)
86 |     o = Add(name="seg_feats")([o2, o])
87 |
88 | o = Conv2DTranspose(n_classes, kernel_size=(16, 16), strides=(
89 | 8, 8), use_bias=False, data_format=IMAGE_ORDERING)(o)
90 |
91 | model = get_segmentation_model(img_input, o)
92 | model.model_name = "fcn_8"
93 | return model
94 |
95 |
96 | def fcn_32(n_classes, encoder=vanilla_encoder, input_height=416,
97 | input_width=608, channels=3):
98 |
99 | img_input, levels = encoder(
100 | input_height=input_height, input_width=input_width, channels=channels)
101 | [f1, f2, f3, f4, f5] = levels
102 |
103 | o = f5
104 |
105 | o = (Conv2D(4096, (7, 7), activation='relu',
106 | padding='same', data_format=IMAGE_ORDERING))(o)
107 | o = Dropout(0.5)(o)
108 | o = (Conv2D(4096, (1, 1), activation='relu',
109 | padding='same', data_format=IMAGE_ORDERING))(o)
110 | o = Dropout(0.5)(o)
111 |
112 | o = (Conv2D(n_classes, (1, 1), kernel_initializer='he_normal',
113 |                 data_format=IMAGE_ORDERING, name="seg_feats"))(o)
114 | o = Conv2DTranspose(n_classes, kernel_size=(64, 64), strides=(
115 | 32, 32), use_bias=False, data_format=IMAGE_ORDERING)(o)
116 |
117 | model = get_segmentation_model(img_input, o)
118 | model.model_name = "fcn_32"
119 | return model
120 |
121 |
122 | def fcn_8_vgg(n_classes, input_height=416, input_width=608, channels=3):
123 | model = fcn_8(n_classes, get_vgg_encoder,
124 | input_height=input_height, input_width=input_width, channels=channels)
125 | model.model_name = "fcn_8_vgg"
126 | return model
127 |
128 |
129 | def fcn_32_vgg(n_classes, input_height=416, input_width=608, channels=3):
130 | model = fcn_32(n_classes, get_vgg_encoder,
131 | input_height=input_height, input_width=input_width, channels=channels)
132 | model.model_name = "fcn_32_vgg"
133 | return model
134 |
135 |
136 | def fcn_8_resnet50(n_classes, input_height=416, input_width=608, channels=3):
137 | model = fcn_8(n_classes, get_resnet50_encoder,
138 | input_height=input_height, input_width=input_width, channels=channels)
139 | model.model_name = "fcn_8_resnet50"
140 | return model
141 |
142 |
143 | def fcn_32_resnet50(n_classes, input_height=416, input_width=608, channels=3):
144 | model = fcn_32(n_classes, get_resnet50_encoder,
145 | input_height=input_height, input_width=input_width, channels=channels)
146 | model.model_name = "fcn_32_resnet50"
147 | return model
148 |
149 |
150 | def fcn_8_mobilenet(n_classes, input_height=224, input_width=224, channels=3):
151 | model = fcn_8(n_classes, get_mobilenet_encoder,
152 | input_height=input_height, input_width=input_width, channels=channels)
153 | model.model_name = "fcn_8_mobilenet"
154 | return model
155 |
156 |
157 | def fcn_32_mobilenet(n_classes, input_height=224, input_width=224, channels=3):
158 | model = fcn_32(n_classes, get_mobilenet_encoder,
159 | input_height=input_height, input_width=input_width, channels=channels)
160 | model.model_name = "fcn_32_mobilenet"
161 | return model
162 |
163 |
164 | if __name__ == '__main__':
165 | m = fcn_8(101)
166 | m = fcn_32(101)
167 |
--------------------------------------------------------------------------------
/keras_segmentation/models/mobilenet.py:
--------------------------------------------------------------------------------
1 | from keras.models import *
2 | from keras.layers import *
3 | import keras.backend as K
4 | import keras
5 | import tensorflow as tf
6 | from .config import IMAGE_ORDERING
7 |
8 | BASE_WEIGHT_PATH = ('https://github.com/fchollet/deep-learning-models/'
9 | 'releases/download/v0.6/')
10 |
11 |
12 | def relu6(x):
13 | return K.relu(x, max_value=6)
14 |
15 |
16 | def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)):
17 |
18 | channel_axis = 1 if IMAGE_ORDERING == 'channels_first' else -1
19 | filters = int(filters * alpha)
20 | x = ZeroPadding2D(padding=(1, 1), name='conv1_pad',
21 | data_format=IMAGE_ORDERING)(inputs)
22 | x = Conv2D(filters, kernel, data_format=IMAGE_ORDERING,
23 | padding='valid',
24 | use_bias=False,
25 | strides=strides,
26 | name='conv1')(x)
27 | x = BatchNormalization(axis=channel_axis, name='conv1_bn')(x)
28 | return Activation(relu6, name='conv1_relu')(x)
29 |
30 |
31 | def _depthwise_conv_block(inputs, pointwise_conv_filters, alpha,
32 | depth_multiplier=1, strides=(1, 1), block_id=1):
33 |
34 | channel_axis = 1 if IMAGE_ORDERING == 'channels_first' else -1
35 | pointwise_conv_filters = int(pointwise_conv_filters * alpha)
36 |
37 | x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING,
38 | name='conv_pad_%d' % block_id)(inputs)
39 | x = DepthwiseConv2D((3, 3), data_format=IMAGE_ORDERING,
40 | padding='valid',
41 | depth_multiplier=depth_multiplier,
42 | strides=strides,
43 | use_bias=False,
44 | name='conv_dw_%d' % block_id)(x)
45 | x = BatchNormalization(
46 | axis=channel_axis, name='conv_dw_%d_bn' % block_id)(x)
47 | x = Activation(relu6, name='conv_dw_%d_relu' % block_id)(x)
48 |
49 | x = Conv2D(pointwise_conv_filters, (1, 1), data_format=IMAGE_ORDERING,
50 | padding='same',
51 | use_bias=False,
52 | strides=(1, 1),
53 | name='conv_pw_%d' % block_id)(x)
54 | x = BatchNormalization(axis=channel_axis,
55 | name='conv_pw_%d_bn' % block_id)(x)
56 | return Activation(relu6, name='conv_pw_%d_relu' % block_id)(x)
57 |
58 |
59 | def get_mobilenet_encoder(input_height=224, input_width=224,
60 | pretrained='imagenet', channels=3):
61 |
62 |     # TODO: support other values of alpha and depth_multiplier
63 |
64 | assert (K.image_data_format() ==
65 | 'channels_last'), "Currently only channels last mode is supported"
66 | assert (IMAGE_ORDERING ==
67 | 'channels_last'), "Currently only channels last mode is supported"
68 |
69 | assert input_height % 32 == 0
70 | assert input_width % 32 == 0
71 |
72 | alpha = 1.0
73 | depth_multiplier = 1
74 | dropout = 1e-3
75 |
76 | img_input = Input(shape=(input_height, input_width, channels))
77 |
78 | x = _conv_block(img_input, 32, alpha, strides=(2, 2))
79 | x = _depthwise_conv_block(x, 64, alpha, depth_multiplier, block_id=1)
80 | f1 = x
81 |
82 | x = _depthwise_conv_block(x, 128, alpha, depth_multiplier,
83 | strides=(2, 2), block_id=2)
84 | x = _depthwise_conv_block(x, 128, alpha, depth_multiplier, block_id=3)
85 | f2 = x
86 |
87 | x = _depthwise_conv_block(x, 256, alpha, depth_multiplier,
88 | strides=(2, 2), block_id=4)
89 | x = _depthwise_conv_block(x, 256, alpha, depth_multiplier, block_id=5)
90 | f3 = x
91 |
92 | x = _depthwise_conv_block(x, 512, alpha, depth_multiplier,
93 | strides=(2, 2), block_id=6)
94 | x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=7)
95 | x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=8)
96 | x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=9)
97 | x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=10)
98 | x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=11)
99 | f4 = x
100 |
101 | x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier,
102 | strides=(2, 2), block_id=12)
103 | x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier, block_id=13)
104 | f5 = x
105 |
106 | if pretrained == 'imagenet':
107 | model_name = 'mobilenet_%s_%d_tf_no_top.h5' % ('1_0', 224)
108 |
109 | weight_path = BASE_WEIGHT_PATH + model_name
110 | weights_path = tf.keras.utils.get_file(model_name, weight_path)
111 |
112 | Model(img_input, x).load_weights(weights_path, by_name=True, skip_mismatch=True)
113 |
114 | return img_input, [f1, f2, f3, f4, f5]
115 |
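116 |
117 | # Illustrative usage sketch (not part of the original file): building the
118 | # encoder without the ImageNet download by passing pretrained=None.
119 | if __name__ == "__main__":
120 |     img_input, feats = get_mobilenet_encoder(input_height=224,
121 |                                               input_width=224,
122 |                                               pretrained=None)
123 |     print([Model(img_input, f).output_shape for f in feats])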
--------------------------------------------------------------------------------
/keras_segmentation/models/model.py:
--------------------------------------------------------------------------------
1 | """ Definition for the generic Model class """
2 |
3 |
4 | class Model:
5 | def __init__(self, n_classes, input_height=None, input_width=None):
6 | pass
7 |
--------------------------------------------------------------------------------
/keras_segmentation/models/model_utils.py:
--------------------------------------------------------------------------------
1 | from types import MethodType
2 |
3 | from keras.models import *
4 | from keras.layers import *
5 | import keras.backend as K
6 | from tqdm import tqdm
7 |
8 | from .config import IMAGE_ORDERING
9 | from ..train import train
10 | from ..predict import predict, predict_multiple, evaluate
11 |
12 |
13 | # source m1 , dest m2
14 | def transfer_weights(m1, m2, verbose=True):
15 |
16 | assert len(m1.layers) == len(
17 | m2.layers), "Both models should have same number of layers"
18 |
19 | nSet = 0
20 | nNotSet = 0
21 |
22 | if verbose:
23 | print("Copying weights ")
24 | bar = tqdm(zip(m1.layers, m2.layers))
25 | else:
26 | bar = zip(m1.layers, m2.layers)
27 |
28 | for l, ll in bar:
29 |
30 | if not any([w.shape != ww.shape for w, ww in zip(list(l.weights),
31 | list(ll.weights))]):
32 | if len(list(l.weights)) > 0:
33 | ll.set_weights(l.get_weights())
34 |                 nSet += 1
35 | else:
36 | nNotSet += 1
37 |
38 | if verbose:
39 | print("Copied weights of %d layers and skipped %d layers" %
40 | (nSet, nNotSet))
41 |
42 |
43 | def resize_image(inp, s, data_format):
44 |
45 | try:
46 |
47 | return Lambda(lambda x: K.resize_images(x,
48 | height_factor=s[0],
49 | width_factor=s[1],
50 | data_format=data_format,
51 | interpolation='bilinear'))(inp)
52 |
53 | except Exception as e:
54 | # if keras is old, then rely on the tf function
55 | # Sorry theano/cntk users!!!
56 | assert data_format == 'channels_last'
57 | assert IMAGE_ORDERING == 'channels_last'
58 |
59 | import tensorflow as tf
60 |
61 | return Lambda(
62 | lambda x: tf.image.resize_images(
63 | x, (K.int_shape(x)[1]*s[0], K.int_shape(x)[2]*s[1]))
64 | )(inp)
65 |
66 |
67 | def get_segmentation_model(input, output):
68 |
69 | img_input = input
70 | o = output
71 |
72 | o_shape = Model(img_input, o).output_shape
73 | i_shape = Model(img_input, o).input_shape
74 |
75 | if IMAGE_ORDERING == 'channels_first':
76 | output_height = o_shape[2]
77 | output_width = o_shape[3]
78 | input_height = i_shape[2]
79 | input_width = i_shape[3]
80 | n_classes = o_shape[1]
81 | o = (Reshape((-1, output_height*output_width)))(o)
82 | o = (Permute((2, 1)))(o)
83 | elif IMAGE_ORDERING == 'channels_last':
84 | output_height = o_shape[1]
85 | output_width = o_shape[2]
86 | input_height = i_shape[1]
87 | input_width = i_shape[2]
88 | n_classes = o_shape[3]
89 | o = (Reshape((output_height*output_width, -1)))(o)
90 |
91 | o = (Activation('softmax'))(o)
92 | model = Model(img_input, o)
93 | model.output_width = output_width
94 | model.output_height = output_height
95 | model.n_classes = n_classes
96 | model.input_height = input_height
97 | model.input_width = input_width
98 | model.model_name = ""
99 |
100 | model.train = MethodType(train, model)
101 | model.predict_segmentation = MethodType(predict, model)
102 | model.predict_multiple = MethodType(predict_multiple, model)
103 | model.evaluate_segmentation = MethodType(evaluate, model)
104 |
105 | return model
106 |
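107 |
108 | # Illustrative usage sketch (not part of the original file): transfer_weights
109 | # copies weights layer-by-layer between two models with identical
110 | # architectures, e.g. to warm-start a fresh copy of a trained network.
111 | if __name__ == "__main__":
112 |     from keras_segmentation.models.unet import unet_mini
113 |     m1 = unet_mini(n_classes=10)
114 |     m2 = unet_mini(n_classes=10)
115 |     transfer_weights(m1, m2)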
--------------------------------------------------------------------------------
/keras_segmentation/models/pspnet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import keras
3 | from keras.models import *
4 | from keras.layers import *
5 | import keras.backend as K
6 |
7 | from .config import IMAGE_ORDERING
8 | from .model_utils import get_segmentation_model, resize_image
9 | from .vgg16 import get_vgg_encoder
10 | from .basic_models import vanilla_encoder
11 | from .resnet50 import get_resnet50_encoder
12 |
13 |
14 | if IMAGE_ORDERING == 'channels_first':
15 | MERGE_AXIS = 1
16 | elif IMAGE_ORDERING == 'channels_last':
17 | MERGE_AXIS = -1
18 |
19 |
20 | def pool_block(feats, pool_factor):
21 |
22 | if IMAGE_ORDERING == 'channels_first':
23 | h = K.int_shape(feats)[2]
24 | w = K.int_shape(feats)[3]
25 | elif IMAGE_ORDERING == 'channels_last':
26 | h = K.int_shape(feats)[1]
27 | w = K.int_shape(feats)[2]
28 |
29 | pool_size = strides = [
30 | int(np.round(float(h) / pool_factor)),
31 | int(np.round(float(w) / pool_factor))]
32 |
33 | x = AveragePooling2D(pool_size, data_format=IMAGE_ORDERING,
34 | strides=strides, padding='same')(feats)
35 | x = Conv2D(512, (1, 1), data_format=IMAGE_ORDERING,
36 | padding='same', use_bias=False)(x)
37 | x = BatchNormalization()(x)
38 | x = Activation('relu')(x)
39 |
40 | x = resize_image(x, strides, data_format=IMAGE_ORDERING)
41 |
42 | return x
43 |
44 |
45 | def _pspnet(n_classes, encoder, input_height=384, input_width=576, channels=3):
46 |
47 | assert input_height % 192 == 0
48 | assert input_width % 192 == 0
49 |
50 | img_input, levels = encoder(
51 | input_height=input_height, input_width=input_width, channels=channels)
52 | [f1, f2, f3, f4, f5] = levels
53 |
54 | o = f5
55 |
56 | pool_factors = [1, 2, 3, 6]
57 | pool_outs = [o]
58 |
59 | for p in pool_factors:
60 | pooled = pool_block(o, p)
61 | pool_outs.append(pooled)
62 |
63 | o = Concatenate(axis=MERGE_AXIS)(pool_outs)
64 |
65 |     o = Conv2D(512, (1, 1), data_format=IMAGE_ORDERING, use_bias=False, name="seg_feats")(o)
66 | o = BatchNormalization()(o)
67 | o = Activation('relu')(o)
68 |
69 | o = Conv2D(n_classes, (3, 3), data_format=IMAGE_ORDERING,
70 | padding='same')(o)
71 | o = resize_image(o, (8, 8), data_format=IMAGE_ORDERING)
72 |
73 | model = get_segmentation_model(img_input, o)
74 | return model
75 |
76 |
77 | def pspnet(n_classes, input_height=384, input_width=576, channels=3):
78 |
79 | model = _pspnet(n_classes, vanilla_encoder,
80 | input_height=input_height, input_width=input_width, channels=channels)
81 | model.model_name = "pspnet"
82 | return model
83 |
84 |
85 | def vgg_pspnet(n_classes, input_height=384, input_width=576, channels=3):
86 |
87 | model = _pspnet(n_classes, get_vgg_encoder,
88 | input_height=input_height, input_width=input_width, channels=channels)
89 | model.model_name = "vgg_pspnet"
90 | return model
91 |
92 |
93 | def resnet50_pspnet(n_classes, input_height=384, input_width=576, channels=3):
94 |
95 | model = _pspnet(n_classes, get_resnet50_encoder,
96 | input_height=input_height, input_width=input_width, channels=channels)
97 | model.model_name = "resnet50_pspnet"
98 | return model
99 |
100 |
101 | def pspnet_50(n_classes, input_height=473, input_width=473, channels=3):
102 | from ._pspnet_2 import _build_pspnet
103 |
104 | nb_classes = n_classes
105 | resnet_layers = 50
106 | input_shape = (input_height, input_width)
107 | model = _build_pspnet(nb_classes=nb_classes,
108 | resnet_layers=resnet_layers,
109 | input_shape=input_shape, channels=channels)
110 | model.model_name = "pspnet_50"
111 | return model
112 |
113 |
114 | def pspnet_101(n_classes, input_height=473, input_width=473, channels=3):
115 | from ._pspnet_2 import _build_pspnet
116 |
117 | nb_classes = n_classes
118 | resnet_layers = 101
119 | input_shape = (input_height, input_width)
120 | model = _build_pspnet(nb_classes=nb_classes,
121 | resnet_layers=resnet_layers,
122 | input_shape=input_shape, channels=channels)
123 | model.model_name = "pspnet_101"
124 | return model
125 |
126 |
127 | # def mobilenet_pspnet( n_classes , input_height=224, input_width=224 ):
128 |
129 | # model = _pspnet(n_classes, get_mobilenet_encoder,
130 | # input_height=input_height, input_width=input_width)
131 | # model.model_name = "mobilenet_pspnet"
132 | # return model
133 |
134 |
135 | if __name__ == '__main__':
136 |
137 | m = _pspnet(101, vanilla_encoder)
138 | # m = _pspnet( 101 , get_mobilenet_encoder ,True , 224 , 224 )
139 | m = _pspnet(101, get_vgg_encoder)
140 | m = _pspnet(101, get_resnet50_encoder)
141 |
--------------------------------------------------------------------------------
/keras_segmentation/models/resnet50.py:
--------------------------------------------------------------------------------
1 | import keras
2 | from keras.models import *
3 | from keras.layers import *
4 | from keras import layers
5 | import tensorflow as tf
6 |
7 | # Source:
8 | # https://github.com/fchollet/deep-learning-models/blob/master/resnet50.py
9 |
10 |
11 | from .config import IMAGE_ORDERING
12 |
13 |
14 | if IMAGE_ORDERING == 'channels_first':
15 | pretrained_url = "https://github.com/fchollet/deep-learning-models/" \
16 | "releases/download/v0.2/" \
17 | "resnet50_weights_th_dim_ordering_th_kernels_notop.h5"
18 | elif IMAGE_ORDERING == 'channels_last':
19 | pretrained_url = "https://github.com/fchollet/deep-learning-models/" \
20 | "releases/download/v0.2/" \
21 | "resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5"
22 |
23 |
24 | def one_side_pad(x):
25 | x = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(x)
26 | if IMAGE_ORDERING == 'channels_first':
27 | x = Lambda(lambda x: x[:, :, :-1, :-1])(x)
28 | elif IMAGE_ORDERING == 'channels_last':
29 | x = Lambda(lambda x: x[:, :-1, :-1, :])(x)
30 | return x
31 |
32 |
33 | def identity_block(input_tensor, kernel_size, filters, stage, block):
34 | """The identity block is the block that has no conv layer at shortcut.
35 | # Arguments
36 | input_tensor: input tensor
37 |         kernel_size: default 3, the kernel size of the middle conv layer at
38 |             main path
39 |         filters: list of integers, the filters of the 3 conv layers at main path
40 | stage: integer, current stage label, used for generating layer names
41 | block: 'a','b'..., current block label, used for generating layer names
42 | # Returns
43 | Output tensor for the block.
44 | """
45 | filters1, filters2, filters3 = filters
46 |
47 | if IMAGE_ORDERING == 'channels_last':
48 | bn_axis = 3
49 | else:
50 | bn_axis = 1
51 |
52 | conv_name_base = 'res' + str(stage) + block + '_branch'
53 | bn_name_base = 'bn' + str(stage) + block + '_branch'
54 |
55 | x = Conv2D(filters1, (1, 1), data_format=IMAGE_ORDERING,
56 | name=conv_name_base + '2a')(input_tensor)
57 | x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
58 | x = Activation('relu')(x)
59 |
60 | x = Conv2D(filters2, kernel_size, data_format=IMAGE_ORDERING,
61 | padding='same', name=conv_name_base + '2b')(x)
62 | x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
63 | x = Activation('relu')(x)
64 |
65 | x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING,
66 | name=conv_name_base + '2c')(x)
67 | x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
68 |
69 | x = layers.add([x, input_tensor])
70 | x = Activation('relu')(x)
71 | return x
72 |
73 |
74 | def conv_block(input_tensor, kernel_size, filters, stage, block,
75 | strides=(2, 2)):
76 | """conv_block is the block that has a conv layer at shortcut
77 | # Arguments
78 | input_tensor: input tensor
79 |         kernel_size: default 3, the kernel size of the middle conv layer at
80 |             main path
81 |         filters: list of integers, the filters of the 3 conv layers at main path
82 | stage: integer, current stage label, used for generating layer names
83 | block: 'a','b'..., current block label, used for generating layer names
84 | # Returns
85 | Output tensor for the block.
86 | Note that from stage 3, the first conv layer at main path is with
87 | strides=(2,2) and the shortcut should have strides=(2,2) as well
88 | """
89 | filters1, filters2, filters3 = filters
90 |
91 | if IMAGE_ORDERING == 'channels_last':
92 | bn_axis = 3
93 | else:
94 | bn_axis = 1
95 |
96 | conv_name_base = 'res' + str(stage) + block + '_branch'
97 | bn_name_base = 'bn' + str(stage) + block + '_branch'
98 |
99 | x = Conv2D(filters1, (1, 1), data_format=IMAGE_ORDERING, strides=strides,
100 | name=conv_name_base + '2a')(input_tensor)
101 | x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
102 | x = Activation('relu')(x)
103 |
104 | x = Conv2D(filters2, kernel_size, data_format=IMAGE_ORDERING,
105 | padding='same', name=conv_name_base + '2b')(x)
106 | x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
107 | x = Activation('relu')(x)
108 |
109 | x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING,
110 | name=conv_name_base + '2c')(x)
111 | x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
112 |
113 | shortcut = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING,
114 | strides=strides, name=conv_name_base + '1')(input_tensor)
115 | shortcut = BatchNormalization(
116 | axis=bn_axis, name=bn_name_base + '1')(shortcut)
117 |
118 | x = layers.add([x, shortcut])
119 | x = Activation('relu')(x)
120 | return x
121 |
122 |
123 | def get_resnet50_encoder(input_height=224, input_width=224,
124 | pretrained='imagenet',
125 | include_top=True, weights='imagenet',
126 | input_tensor=None, input_shape=None,
127 | pooling=None,
128 | classes=1000, channels=3):
129 |
130 | assert input_height % 32 == 0
131 | assert input_width % 32 == 0
132 |
133 | if IMAGE_ORDERING == 'channels_first':
134 | img_input = Input(shape=(channels, input_height, input_width))
135 | elif IMAGE_ORDERING == 'channels_last':
136 | img_input = Input(shape=(input_height, input_width, channels))
137 |
138 | if IMAGE_ORDERING == 'channels_last':
139 | bn_axis = 3
140 | else:
141 | bn_axis = 1
142 |
143 | x = ZeroPadding2D((3, 3), data_format=IMAGE_ORDERING)(img_input)
144 | x = Conv2D(64, (7, 7), data_format=IMAGE_ORDERING,
145 | strides=(2, 2), name='conv1')(x)
146 | f1 = x
147 |
148 | x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
149 | x = Activation('relu')(x)
150 | x = MaxPooling2D((3, 3), data_format=IMAGE_ORDERING, strides=(2, 2))(x)
151 |
152 | x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
153 | x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
154 | x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
155 | f2 = one_side_pad(x)
156 |
157 | x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
158 | x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
159 | x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
160 | x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
161 | f3 = x
162 |
163 | x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
164 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
165 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
166 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
167 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
168 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
169 | f4 = x
170 |
171 | x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
172 | x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
173 | x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
174 | f5 = x
175 |
176 | x = AveragePooling2D(
177 | (7, 7), data_format=IMAGE_ORDERING, name='avg_pool')(x)
178 | # f6 = x
179 |
180 | if pretrained == 'imagenet':
181 | weights_path = tf.keras.utils.get_file(
182 | pretrained_url.split("/")[-1], pretrained_url)
183 | Model(img_input, x).load_weights(weights_path, by_name=True, skip_mismatch=True)
184 |
185 | return img_input, [f1, f2, f3, f4, f5]
186 |
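187 |
188 | # Illustrative usage sketch (not part of the original file): the encoder
189 | # exposes five feature levels f1..f5; pretrained=None skips the download.
190 | if __name__ == "__main__":
191 |     img_input, feats = get_resnet50_encoder(input_height=224,
192 |                                             input_width=224,
193 |                                             pretrained=None)
194 |     print([Model(img_input, f).output_shape for f in feats])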
--------------------------------------------------------------------------------
/keras_segmentation/models/segnet.py:
--------------------------------------------------------------------------------
1 | from keras.models import *
2 | from keras.layers import *
3 |
4 | from .config import IMAGE_ORDERING
5 | from .model_utils import get_segmentation_model
6 | from .vgg16 import get_vgg_encoder
7 | from .mobilenet import get_mobilenet_encoder
8 | from .basic_models import vanilla_encoder
9 | from .resnet50 import get_resnet50_encoder
10 |
11 |
12 | def segnet_decoder(f, n_classes, n_up=3):
13 |
14 | assert n_up >= 2
15 |
16 | o = f
17 | o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
18 | o = (Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING))(o)
19 | o = (BatchNormalization())(o)
20 |
21 | o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
22 | o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
23 | o = (Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING))(o)
24 | o = (BatchNormalization())(o)
25 |
26 | for _ in range(n_up-2):
27 | o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
28 | o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
29 | o = (Conv2D(128, (3, 3), padding='valid',
30 | data_format=IMAGE_ORDERING))(o)
31 | o = (BatchNormalization())(o)
32 |
33 | o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
34 | o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
35 | o = (Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING, name="seg_feats"))(o)
36 | o = (BatchNormalization())(o)
37 |
38 | o = Conv2D(n_classes, (3, 3), padding='same',
39 | data_format=IMAGE_ORDERING)(o)
40 |
41 | return o
42 |
43 |
44 | def _segnet(n_classes, encoder, input_height=416, input_width=608,
45 | encoder_level=3, channels=3):
46 |
47 | img_input, levels = encoder(
48 | input_height=input_height, input_width=input_width, channels=channels)
49 |
50 | feat = levels[encoder_level]
51 | o = segnet_decoder(feat, n_classes, n_up=3)
52 | model = get_segmentation_model(img_input, o)
53 |
54 | return model
55 |
56 |
57 | def segnet(n_classes, input_height=416, input_width=608, encoder_level=3, channels=3):
58 |
59 | model = _segnet(n_classes, vanilla_encoder, input_height=input_height,
60 | input_width=input_width, encoder_level=encoder_level, channels=channels)
61 | model.model_name = "segnet"
62 | return model
63 |
64 |
65 | def vgg_segnet(n_classes, input_height=416, input_width=608, encoder_level=3, channels=3):
66 |
67 | model = _segnet(n_classes, get_vgg_encoder, input_height=input_height,
68 | input_width=input_width, encoder_level=encoder_level, channels=channels)
69 | model.model_name = "vgg_segnet"
70 | return model
71 |
72 |
73 | def resnet50_segnet(n_classes, input_height=416, input_width=608,
74 | encoder_level=3, channels=3):
75 |
76 | model = _segnet(n_classes, get_resnet50_encoder, input_height=input_height,
77 | input_width=input_width, encoder_level=encoder_level, channels=channels)
78 | model.model_name = "resnet50_segnet"
79 | return model
80 |
81 |
82 | def mobilenet_segnet(n_classes, input_height=224, input_width=224,
83 | encoder_level=3, channels=3):
84 |
85 | model = _segnet(n_classes, get_mobilenet_encoder,
86 | input_height=input_height,
87 | input_width=input_width, encoder_level=encoder_level, channels=channels)
88 | model.model_name = "mobilenet_segnet"
89 | return model
90 |
91 |
92 | if __name__ == '__main__':
93 | m = vgg_segnet(101)
94 | m = segnet(101)
95 | # m = mobilenet_segnet( 101 )
96 | # from keras.utils import plot_model
97 | # plot_model( m , show_shapes=True , to_file='model.png')
98 |
--------------------------------------------------------------------------------
/keras_segmentation/models/unet.py:
--------------------------------------------------------------------------------
1 | from keras.models import *
2 | from keras.layers import *
3 |
4 | from .config import IMAGE_ORDERING
5 | from .model_utils import get_segmentation_model
6 | from .vgg16 import get_vgg_encoder
7 | from .mobilenet import get_mobilenet_encoder
8 | from .basic_models import vanilla_encoder
9 | from .resnet50 import get_resnet50_encoder
10 |
11 |
12 | if IMAGE_ORDERING == 'channels_first':
13 | MERGE_AXIS = 1
14 | elif IMAGE_ORDERING == 'channels_last':
15 | MERGE_AXIS = -1
16 |
17 |
18 | def unet_mini(n_classes, input_height=360, input_width=480, channels=3):
19 |
20 | if IMAGE_ORDERING == 'channels_first':
21 | img_input = Input(shape=(channels, input_height, input_width))
22 | elif IMAGE_ORDERING == 'channels_last':
23 | img_input = Input(shape=(input_height, input_width, channels))
24 |
25 | conv1 = Conv2D(32, (3, 3), data_format=IMAGE_ORDERING,
26 | activation='relu', padding='same')(img_input)
27 | conv1 = Dropout(0.2)(conv1)
28 | conv1 = Conv2D(32, (3, 3), data_format=IMAGE_ORDERING,
29 | activation='relu', padding='same')(conv1)
30 | pool1 = MaxPooling2D((2, 2), data_format=IMAGE_ORDERING)(conv1)
31 |
32 | conv2 = Conv2D(64, (3, 3), data_format=IMAGE_ORDERING,
33 | activation='relu', padding='same')(pool1)
34 | conv2 = Dropout(0.2)(conv2)
35 | conv2 = Conv2D(64, (3, 3), data_format=IMAGE_ORDERING,
36 | activation='relu', padding='same')(conv2)
37 | pool2 = MaxPooling2D((2, 2), data_format=IMAGE_ORDERING)(conv2)
38 |
39 | conv3 = Conv2D(128, (3, 3), data_format=IMAGE_ORDERING,
40 | activation='relu', padding='same')(pool2)
41 | conv3 = Dropout(0.2)(conv3)
42 | conv3 = Conv2D(128, (3, 3), data_format=IMAGE_ORDERING,
43 | activation='relu', padding='same')(conv3)
44 |
45 | up1 = concatenate([UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(
46 | conv3), conv2], axis=MERGE_AXIS)
47 | conv4 = Conv2D(64, (3, 3), data_format=IMAGE_ORDERING,
48 | activation='relu', padding='same')(up1)
49 | conv4 = Dropout(0.2)(conv4)
50 | conv4 = Conv2D(64, (3, 3), data_format=IMAGE_ORDERING,
51 | activation='relu', padding='same')(conv4)
52 |
53 | up2 = concatenate([UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(
54 | conv4), conv1], axis=MERGE_AXIS)
55 | conv5 = Conv2D(32, (3, 3), data_format=IMAGE_ORDERING,
56 | activation='relu', padding='same')(up2)
57 | conv5 = Dropout(0.2)(conv5)
58 | conv5 = Conv2D(32, (3, 3), data_format=IMAGE_ORDERING,
59 |                    activation='relu', padding='same', name="seg_feats")(conv5)
60 |
61 | o = Conv2D(n_classes, (1, 1), data_format=IMAGE_ORDERING,
62 | padding='same')(conv5)
63 |
64 | model = get_segmentation_model(img_input, o)
65 | model.model_name = "unet_mini"
66 | return model
67 |
68 |
69 | def _unet(n_classes, encoder, l1_skip_conn=True, input_height=416,
70 | input_width=608, channels=3):
71 |
72 | img_input, levels = encoder(
73 | input_height=input_height, input_width=input_width, channels=channels)
74 | [f1, f2, f3, f4, f5] = levels
75 |
76 | o = f4
77 |
78 | o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
79 |     o = (Conv2D(512, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING))(o)
80 | o = (BatchNormalization())(o)
81 |
82 | o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
83 | o = (concatenate([o, f3], axis=MERGE_AXIS))
84 | o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
85 |     o = (Conv2D(256, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING))(o)
86 | o = (BatchNormalization())(o)
87 |
88 | o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
89 | o = (concatenate([o, f2], axis=MERGE_AXIS))
90 | o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
91 |     o = (Conv2D(128, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING))(o)
92 | o = (BatchNormalization())(o)
93 |
94 | o = (UpSampling2D((2, 2), data_format=IMAGE_ORDERING))(o)
95 |
96 | if l1_skip_conn:
97 | o = (concatenate([o, f1], axis=MERGE_AXIS))
98 |
99 | o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
100 | o = (Conv2D(64, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING, name="seg_feats"))(o)
101 | o = (BatchNormalization())(o)
102 |
103 | o = Conv2D(n_classes, (3, 3), padding='same',
104 | data_format=IMAGE_ORDERING)(o)
105 |
106 | model = get_segmentation_model(img_input, o)
107 |
108 | return model
109 |
110 |
111 | def unet(n_classes, input_height=416, input_width=608, encoder_level=3, channels=3):
112 |
113 | model = _unet(n_classes, vanilla_encoder,
114 | input_height=input_height, input_width=input_width, channels=channels)
115 | model.model_name = "unet"
116 | return model
117 |
118 |
119 | def vgg_unet(n_classes, input_height=416, input_width=608, encoder_level=3, channels=3):
120 |
121 | model = _unet(n_classes, get_vgg_encoder,
122 | input_height=input_height, input_width=input_width, channels=channels)
123 | model.model_name = "vgg_unet"
124 | return model
125 |
126 |
127 | def resnet50_unet(n_classes, input_height=416, input_width=608,
128 | encoder_level=3, channels=3):
129 |
130 | model = _unet(n_classes, get_resnet50_encoder,
131 | input_height=input_height, input_width=input_width, channels=channels)
132 | model.model_name = "resnet50_unet"
133 | return model
134 |
135 |
136 | def mobilenet_unet(n_classes, input_height=224, input_width=224,
137 | encoder_level=3, channels=3):
138 |
139 | model = _unet(n_classes, get_mobilenet_encoder,
140 | input_height=input_height, input_width=input_width, channels=channels)
141 | model.model_name = "mobilenet_unet"
142 | return model
143 |
144 |
145 | if __name__ == '__main__':
146 | m = unet_mini(101)
147 | m = _unet(101, vanilla_encoder)
148 | # m = _unet( 101 , get_mobilenet_encoder ,True , 224 , 224 )
149 | m = _unet(101, get_vgg_encoder)
150 | m = _unet(101, get_resnet50_encoder)
151 |
--------------------------------------------------------------------------------
/keras_segmentation/models/vgg16.py:
--------------------------------------------------------------------------------
1 | import keras
2 | from keras.models import *
3 | from keras.layers import *
4 | import tensorflow as tf
5 | from .config import IMAGE_ORDERING
6 |
7 | if IMAGE_ORDERING == 'channels_first':
8 | pretrained_url = "https://github.com/fchollet/deep-learning-models/" \
9 | "releases/download/v0.1/" \
10 | "vgg16_weights_th_dim_ordering_th_kernels_notop.h5"
11 | elif IMAGE_ORDERING == 'channels_last':
12 | pretrained_url = "https://github.com/fchollet/deep-learning-models/" \
13 | "releases/download/v0.1/" \
14 | "vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5"
15 |
16 |
17 | def get_vgg_encoder(input_height=224, input_width=224, pretrained='imagenet', channels=3):
18 |
19 | assert input_height % 32 == 0
20 | assert input_width % 32 == 0
21 |
22 | if IMAGE_ORDERING == 'channels_first':
23 | img_input = Input(shape=(channels, input_height, input_width))
24 | elif IMAGE_ORDERING == 'channels_last':
25 | img_input = Input(shape=(input_height, input_width, channels))
26 |
27 | x = Conv2D(64, (3, 3), activation='relu', padding='same',
28 | name='block1_conv1', data_format=IMAGE_ORDERING)(img_input)
29 | x = Conv2D(64, (3, 3), activation='relu', padding='same',
30 | name='block1_conv2', data_format=IMAGE_ORDERING)(x)
31 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool',
32 | data_format=IMAGE_ORDERING)(x)
33 | f1 = x
34 | # Block 2
35 | x = Conv2D(128, (3, 3), activation='relu', padding='same',
36 | name='block2_conv1', data_format=IMAGE_ORDERING)(x)
37 | x = Conv2D(128, (3, 3), activation='relu', padding='same',
38 | name='block2_conv2', data_format=IMAGE_ORDERING)(x)
39 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool',
40 | data_format=IMAGE_ORDERING)(x)
41 | f2 = x
42 |
43 | # Block 3
44 | x = Conv2D(256, (3, 3), activation='relu', padding='same',
45 | name='block3_conv1', data_format=IMAGE_ORDERING)(x)
46 | x = Conv2D(256, (3, 3), activation='relu', padding='same',
47 | name='block3_conv2', data_format=IMAGE_ORDERING)(x)
48 | x = Conv2D(256, (3, 3), activation='relu', padding='same',
49 | name='block3_conv3', data_format=IMAGE_ORDERING)(x)
50 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool',
51 | data_format=IMAGE_ORDERING)(x)
52 | f3 = x
53 |
54 | # Block 4
55 | x = Conv2D(512, (3, 3), activation='relu', padding='same',
56 | name='block4_conv1', data_format=IMAGE_ORDERING)(x)
57 | x = Conv2D(512, (3, 3), activation='relu', padding='same',
58 | name='block4_conv2', data_format=IMAGE_ORDERING)(x)
59 | x = Conv2D(512, (3, 3), activation='relu', padding='same',
60 | name='block4_conv3', data_format=IMAGE_ORDERING)(x)
61 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool',
62 | data_format=IMAGE_ORDERING)(x)
63 | f4 = x
64 |
65 | # Block 5
66 | x = Conv2D(512, (3, 3), activation='relu', padding='same',
67 | name='block5_conv1', data_format=IMAGE_ORDERING)(x)
68 | x = Conv2D(512, (3, 3), activation='relu', padding='same',
69 | name='block5_conv2', data_format=IMAGE_ORDERING)(x)
70 | x = Conv2D(512, (3, 3), activation='relu', padding='same',
71 | name='block5_conv3', data_format=IMAGE_ORDERING)(x)
72 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool',
73 | data_format=IMAGE_ORDERING)(x)
74 | f5 = x
75 |
76 | if pretrained == 'imagenet':
77 | VGG_Weights_path = tf.keras.utils.get_file(
78 | pretrained_url.split("/")[-1], pretrained_url)
79 | Model(img_input, x).load_weights(VGG_Weights_path, by_name=True, skip_mismatch=True)
80 |
81 | return img_input, [f1, f2, f3, f4, f5]
82 |
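   | # Usage sketch: the encoder returns the input tensor plus the five feature
   | # maps (f1..f5) taken after each stride-2 pooling stage, i.e. at 1/2, 1/4,
   | # 1/8, 1/16 and 1/32 of the input resolution, e.g.:
   | #   img_input, [f1, f2, f3, f4, f5] = get_vgg_encoder(416, 608)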
--------------------------------------------------------------------------------
/keras_segmentation/predict.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import random
3 | import json
4 | import os
5 | import six
6 |
7 | import cv2
8 | import numpy as np
9 | from tqdm import tqdm
10 | from time import time
11 |
12 | from .train import find_latest_checkpoint
13 | from .data_utils.data_loader import get_image_array, get_segmentation_array,\
14 | DATA_LOADER_SEED, class_colors, get_pairs_from_paths
15 | from .models.config import IMAGE_ORDERING
16 |
17 |
18 | random.seed(DATA_LOADER_SEED)
19 |
20 |
21 | def model_from_checkpoint_path(checkpoints_path):
22 |
23 | from .models.all_models import model_from_name
24 | assert (os.path.isfile(checkpoints_path+"_config.json")
25 | ), "Checkpoint not found."
26 |     with open(checkpoints_path+"_config.json", "r") as config_file:
27 |         model_config = json.loads(config_file.read())
28 | latest_weights = find_latest_checkpoint(checkpoints_path)
29 | assert (latest_weights is not None), "Checkpoint not found."
30 | model = model_from_name[model_config['model_class']](
31 | model_config['n_classes'], input_height=model_config['input_height'],
32 | input_width=model_config['input_width'])
33 | print("loaded weights ", latest_weights)
34 | status = model.load_weights(latest_weights)
35 |
36 | if status is not None:
37 | status.expect_partial()
38 |
39 | return model
40 |
41 |
42 | def get_colored_segmentation_image(seg_arr, n_classes, colors=class_colors):
43 | output_height = seg_arr.shape[0]
44 | output_width = seg_arr.shape[1]
45 |
46 | seg_img = np.zeros((output_height, output_width, 3))
47 |
48 | for c in range(n_classes):
49 | seg_arr_c = seg_arr[:, :] == c
50 | seg_img[:, :, 0] += ((seg_arr_c)*(colors[c][0])).astype('uint8')
51 | seg_img[:, :, 1] += ((seg_arr_c)*(colors[c][1])).astype('uint8')
52 | seg_img[:, :, 2] += ((seg_arr_c)*(colors[c][2])).astype('uint8')
53 |
54 | return seg_img
55 |
56 |
57 | def get_legends(class_names, colors=class_colors):
58 |
59 | n_classes = len(class_names)
60 | legend = np.zeros(((len(class_names) * 25) + 25, 125, 3),
61 | dtype="uint8") + 255
62 |
63 | class_names_colors = enumerate(zip(class_names[:n_classes],
64 | colors[:n_classes]))
65 |
66 | for (i, (class_name, color)) in class_names_colors:
67 | color = [int(c) for c in color]
68 | cv2.putText(legend, class_name, (5, (i * 25) + 17),
69 | cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1)
70 | cv2.rectangle(legend, (100, (i * 25)), (125, (i * 25) + 25),
71 | tuple(color), -1)
72 |
73 | return legend
74 |
75 |
76 | def overlay_seg_image(inp_img, seg_img):
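   |     # Resize the mask to the input image's size, then alpha-blend the two 50/50.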
77 |     original_h = inp_img.shape[0]
78 |     original_w = inp_img.shape[1]
79 |     seg_img = cv2.resize(seg_img, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
80 |
81 | fused_img = (inp_img/2 + seg_img/2).astype('uint8')
82 | return fused_img
83 |
84 |
85 | def concat_legends(seg_img, legend_img):
86 |
87 | new_h = np.maximum(seg_img.shape[0], legend_img.shape[0])
88 | new_w = seg_img.shape[1] + legend_img.shape[1]
89 |
90 | out_img = np.zeros((new_h, new_w, 3)).astype('uint8') + legend_img[0, 0, 0]
91 |
92 | out_img[:legend_img.shape[0], : legend_img.shape[1]] = np.copy(legend_img)
93 | out_img[:seg_img.shape[0], legend_img.shape[1]:] = np.copy(seg_img)
94 |
95 | return out_img
96 |
97 |
98 | def visualize_segmentation(seg_arr, inp_img=None, n_classes=None,
99 | colors=class_colors, class_names=None,
100 | overlay_img=False, show_legends=False,
101 | prediction_width=None, prediction_height=None):
102 |
103 | if n_classes is None:
104 | n_classes = np.max(seg_arr)
105 |
106 | seg_img = get_colored_segmentation_image(seg_arr, n_classes, colors=colors)
107 |
108 | if inp_img is not None:
109 | original_h = inp_img.shape[0]
110 | original_w = inp_img.shape[1]
111 | seg_img = cv2.resize(seg_img, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
112 |
113 | if (prediction_height is not None) and (prediction_width is not None):
114 | seg_img = cv2.resize(seg_img, (prediction_width, prediction_height), interpolation=cv2.INTER_NEAREST)
115 | if inp_img is not None:
116 | inp_img = cv2.resize(inp_img,
117 | (prediction_width, prediction_height))
118 |
119 | if overlay_img:
120 | assert inp_img is not None
121 | seg_img = overlay_seg_image(inp_img, seg_img)
122 |
123 | if show_legends:
124 | assert class_names is not None
125 | legend_img = get_legends(class_names, colors=colors)
126 |
127 |         seg_img = concat_legends(seg_img, legend_img)
128 |
129 | return seg_img
130 |
131 |
132 | def predict(model=None, inp=None, out_fname=None,
133 | checkpoints_path=None, overlay_img=False,
134 | class_names=None, show_legends=False, colors=class_colors,
135 | prediction_width=None, prediction_height=None,
136 | read_image_type=1):
137 |
138 | if model is None and (checkpoints_path is not None):
139 | model = model_from_checkpoint_path(checkpoints_path)
140 |
141 | assert (inp is not None)
142 | assert ((type(inp) is np.ndarray) or isinstance(inp, six.string_types)),\
143 | "Input should be the CV image or the input file name"
144 |
145 | if isinstance(inp, six.string_types):
146 | inp = cv2.imread(inp, read_image_type)
147 |
148 |     assert (len(inp.shape) == 3 or len(inp.shape) == 2 or len(inp.shape) == 4), "Image should be of shape (h, w) or (h, w, channels)"
149 |
150 | output_width = model.output_width
151 | output_height = model.output_height
152 | input_width = model.input_width
153 | input_height = model.input_height
154 | n_classes = model.n_classes
155 |
156 | x = get_image_array(inp, input_width, input_height,
157 | ordering=IMAGE_ORDERING)
158 | pr = model.predict(np.array([x]))[0]
159 | pr = pr.reshape((output_height, output_width, n_classes)).argmax(axis=2)
160 |
161 | seg_img = visualize_segmentation(pr, inp, n_classes=n_classes,
162 | colors=colors, overlay_img=overlay_img,
163 | show_legends=show_legends,
164 | class_names=class_names,
165 | prediction_width=prediction_width,
166 | prediction_height=prediction_height)
167 |
168 | if out_fname is not None:
169 | cv2.imwrite(out_fname, seg_img)
170 |
171 | return pr
172 |
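   | # Usage sketch (checkpoint and image paths are hypothetical):
   | #   pr = predict(checkpoints_path="checkpoints/vgg_unet_1",
   | #                inp="input.jpg", out_fname="out.png", overlay_img=True)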
173 |
174 | def predict_multiple(model=None, inps=None, inp_dir=None, out_dir=None,
175 | checkpoints_path=None, overlay_img=False,
176 | class_names=None, show_legends=False, colors=class_colors,
177 | prediction_width=None, prediction_height=None, read_image_type=1):
178 |
179 | if model is None and (checkpoints_path is not None):
180 | model = model_from_checkpoint_path(checkpoints_path)
181 |
182 | if inps is None and (inp_dir is not None):
183 | inps = glob.glob(os.path.join(inp_dir, "*.jpg")) + glob.glob(
184 | os.path.join(inp_dir, "*.png")) + \
185 | glob.glob(os.path.join(inp_dir, "*.jpeg"))
186 | inps = sorted(inps)
187 |
188 | assert type(inps) is list
189 |
190 | all_prs = []
191 |
192 |     if out_dir is not None:
193 | if not os.path.exists(out_dir):
194 | os.makedirs(out_dir)
195 |
196 |
197 | for i, inp in enumerate(tqdm(inps)):
198 | if out_dir is None:
199 | out_fname = None
200 | else:
201 | if isinstance(inp, six.string_types):
202 | out_fname = os.path.join(out_dir, os.path.basename(inp))
203 | else:
204 | out_fname = os.path.join(out_dir, str(i) + ".jpg")
205 |
206 | pr = predict(model, inp, out_fname,
207 | overlay_img=overlay_img, class_names=class_names,
208 | show_legends=show_legends, colors=colors,
209 | prediction_width=prediction_width,
210 | prediction_height=prediction_height, read_image_type=read_image_type)
211 |
212 | all_prs.append(pr)
213 |
214 | return all_prs
215 |
216 |
217 | def set_video(inp, video_name):
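   |     # Open the input video and create an XVID-encoded writer that matches
   |     # its fps and frame size.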
218 | cap = cv2.VideoCapture(inp)
219 | fps = int(cap.get(cv2.CAP_PROP_FPS))
220 | video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
221 | video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
222 | size = (video_width, video_height)
223 | fourcc = cv2.VideoWriter_fourcc(*"XVID")
224 | video = cv2.VideoWriter(video_name, fourcc, fps, size)
225 | return cap, video, fps
226 |
227 |
228 | def predict_video(model=None, inp=None, output=None,
229 | checkpoints_path=None, display=False, overlay_img=True,
230 | class_names=None, show_legends=False, colors=class_colors,
231 | prediction_width=None, prediction_height=None):
232 |
233 | if model is None and (checkpoints_path is not None):
234 | model = model_from_checkpoint_path(checkpoints_path)
235 | n_classes = model.n_classes
236 |
237 | cap, video, fps = set_video(inp, output)
238 |     while cap.isOpened():
239 | prev_time = time()
240 | ret, frame = cap.read()
241 | if frame is not None:
242 | pr = predict(model=model, inp=frame)
243 | fused_img = visualize_segmentation(
244 | pr, frame, n_classes=n_classes,
245 | colors=colors,
246 | overlay_img=overlay_img,
247 | show_legends=show_legends,
248 | class_names=class_names,
249 | prediction_width=prediction_width,
250 | prediction_height=prediction_height
251 | )
252 | else:
253 | break
254 | print("FPS: {}".format(1/(time() - prev_time)))
255 | if output is not None:
256 | video.write(fused_img)
257 | if display:
258 | cv2.imshow('Frame masked', fused_img)
259 | if cv2.waitKey(fps) & 0xFF == ord('q'):
260 | break
261 | cap.release()
262 | if output is not None:
263 | video.release()
264 | cv2.destroyAllWindows()
265 |
266 |
267 | def evaluate(model=None, inp_images=None, annotations=None,
268 | inp_images_dir=None, annotations_dir=None, checkpoints_path=None, read_image_type=1):
269 |
270 | if model is None:
271 | assert (checkpoints_path is not None),\
272 | "Please provide the model or the checkpoints_path"
273 | model = model_from_checkpoint_path(checkpoints_path)
274 |
275 | if inp_images is None:
276 | assert (inp_images_dir is not None),\
277 | "Please provide inp_images or inp_images_dir"
278 | assert (annotations_dir is not None),\
279 | "Please provide inp_images or inp_images_dir"
280 |
281 | paths = get_pairs_from_paths(inp_images_dir, annotations_dir)
282 | paths = list(zip(*paths))
283 | inp_images = list(paths[0])
284 | annotations = list(paths[1])
285 |
286 | assert type(inp_images) is list
287 | assert type(annotations) is list
288 |
289 | tp = np.zeros(model.n_classes)
290 | fp = np.zeros(model.n_classes)
291 | fn = np.zeros(model.n_classes)
292 | n_pixels = np.zeros(model.n_classes)
293 |
294 | for inp, ann in tqdm(zip(inp_images, annotations)):
295 | pr = predict(model, inp, read_image_type=read_image_type)
296 | gt = get_segmentation_array(ann, model.n_classes,
297 | model.output_width, model.output_height,
298 | no_reshape=True, read_image_type=read_image_type)
299 | gt = gt.argmax(-1)
300 | pr = pr.flatten()
301 | gt = gt.flatten()
302 |
303 | for cl_i in range(model.n_classes):
304 |
305 | tp[cl_i] += np.sum((pr == cl_i) * (gt == cl_i))
306 | fp[cl_i] += np.sum((pr == cl_i) * ((gt != cl_i)))
307 | fn[cl_i] += np.sum((pr != cl_i) * ((gt == cl_i)))
308 | n_pixels[cl_i] += np.sum(gt == cl_i)
309 |
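   |     # Per-class IoU = TP / (TP + FP + FN); the epsilon avoids division by
   |     # zero for classes absent from both prediction and ground truth.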
310 |     cl_wise_score = tp / (tp + fp + fn + 1e-12)
311 | n_pixels_norm = n_pixels / np.sum(n_pixels)
312 | frequency_weighted_IU = np.sum(cl_wise_score*n_pixels_norm)
313 | mean_IU = np.mean(cl_wise_score)
314 |
315 | return {
316 | "frequency_weighted_IU": frequency_weighted_IU,
317 | "mean_IU": mean_IU,
318 | "class_wise_IU": cl_wise_score
319 | }
320 |
--------------------------------------------------------------------------------
/keras_segmentation/pretrained.py:
--------------------------------------------------------------------------------
1 | import keras
2 | import tensorflow as tf
3 | from .models.all_models import model_from_name
4 |
5 |
6 | def model_from_checkpoint_path(model_config, latest_weights):
7 |
8 | model = model_from_name[model_config['model_class']](
9 | model_config['n_classes'], input_height=model_config['input_height'],
10 | input_width=model_config['input_width'])
11 | model.load_weights(latest_weights)
12 | return model
13 |
14 |
15 | def resnet_pspnet_VOC12_v0_1():
16 |
17 | model_config = {
18 | "output_height": 96,
19 | "input_height": 384,
20 | "input_width": 576,
21 | "n_classes": 151,
22 | "model_class": "resnet50_pspnet",
23 | "output_width": 144
24 | }
25 |
26 | REPO_URL = "https://github.com/divamgupta/image-segmentation-keras"
27 | MODEL_PATH = "pretrained_model_1/r2_voc12_resnetpspnet_384x576.24"
28 | model_url = "{0}/releases/download/{1}".format(REPO_URL, MODEL_PATH)
29 | latest_weights = tf.keras.utils.get_file(model_url.split("/")[-1], model_url)
30 |
31 | return model_from_checkpoint_path(model_config, latest_weights)
32 |
33 |
34 | # Pretrained model converted from Caffe by Vladkryvoruchko, thanks!
35 | def pspnet_50_ADE_20K():
36 |
37 | model_config = {
38 | "input_height": 473,
39 | "input_width": 473,
40 | "n_classes": 150,
41 | "model_class": "pspnet_50",
42 | }
43 |
44 | model_url = "https://www.dropbox.com/s/" \
45 | "0uxn14y26jcui4v/pspnet50_ade20k.h5?dl=1"
46 | latest_weights = tf.keras.utils.get_file("pspnet50_ade20k.h5", model_url)
47 |
48 | return model_from_checkpoint_path(model_config, latest_weights)
49 |
50 |
51 | def pspnet_101_cityscapes():
52 |
53 | model_config = {
54 | "input_height": 713,
55 | "input_width": 713,
56 | "n_classes": 19,
57 | "model_class": "pspnet_101",
58 | }
59 |
60 | model_url = "https://www.dropbox.com/s/" \
61 | "c17g94n946tpalb/pspnet101_cityscapes.h5?dl=1"
62 | latest_weights = tf.keras.utils.get_file("pspnet101_cityscapes.h5", model_url)
63 |
64 | return model_from_checkpoint_path(model_config, latest_weights)
65 |
66 |
67 | def pspnet_101_voc12():
68 |
69 | model_config = {
70 | "input_height": 473,
71 | "input_width": 473,
72 | "n_classes": 21,
73 | "model_class": "pspnet_101",
74 | }
75 |
76 | model_url = "https://www.dropbox.com/s/" \
77 | "uvqj2cjo4b9c5wg/pspnet101_voc2012.h5?dl=1"
78 | latest_weights = tf.keras.utils.get_file("pspnet101_voc2012.h5", model_url)
79 |
80 | return model_from_checkpoint_path(model_config, latest_weights)
81 |
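   | # Usage sketch (the input path is hypothetical); weights are downloaded on
   | # first use and cached by tf.keras.utils.get_file:
   | #   model = pspnet_50_ADE_20K()
   | #   out = model.predict_segmentation(inp="scene.jpg", out_fname="/tmp/out.png")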
--------------------------------------------------------------------------------
/keras_segmentation/train.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | from .data_utils.data_loader import image_segmentation_generator, \
5 | verify_segmentation_dataset
6 | import six
7 | from keras.callbacks import Callback
8 | from keras.callbacks import ModelCheckpoint
9 | import tensorflow as tf
10 | import glob
11 | import sys
12 |
13 | def find_latest_checkpoint(checkpoints_path, fail_safe=True):
14 |
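   |     # Checkpoints are saved as "<checkpoints_path>.<epoch>" (see the
   |     # callbacks below); this helper returns the one with the highest epoch.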
15 |     # This is legacy behaviour: newer checkpoint formats also leave a "checkpoint" file in the directory.
16 |
17 | def get_epoch_number_from_path(path):
18 | return path.replace(checkpoints_path, "").strip(".")
19 |
20 | # Get all matching files
21 | all_checkpoint_files = glob.glob(checkpoints_path + ".*")
22 | if len(all_checkpoint_files) == 0:
23 | all_checkpoint_files = glob.glob(checkpoints_path + "*.*")
24 | all_checkpoint_files = [ff.replace(".index", "") for ff in
25 | all_checkpoint_files] # to make it work for newer versions of keras
26 |     # Keep only the entries whose epoch-number suffix is purely numeric
27 | all_checkpoint_files = list(filter(lambda f: get_epoch_number_from_path(f)
28 | .isdigit(), all_checkpoint_files))
29 |     if not all_checkpoint_files:
30 |         # No checkpoint files were found for this checkpoints_path
31 | if not fail_safe:
32 | raise ValueError("Checkpoint path {0} invalid"
33 | .format(checkpoints_path))
34 | else:
35 | return None
36 |
37 | # Find the checkpoint file with the maximum epoch
38 | latest_epoch_checkpoint = max(all_checkpoint_files,
39 | key=lambda f:
40 | int(get_epoch_number_from_path(f)))
41 |
42 | return latest_epoch_checkpoint
43 |
44 | def masked_categorical_crossentropy(gt, pr):
45 | from keras.losses import categorical_crossentropy
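   |     # Zero out the loss at pixels whose one-hot label is class 0, so that
   |     # class 0 acts as an "ignore" label (see ignore_zero_class in train()).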
46 | mask = 1 - gt[:, :, 0]
47 | return categorical_crossentropy(gt, pr) * mask
48 |
49 |
50 | class CheckpointsCallback(Callback):
51 | def __init__(self, checkpoints_path):
52 | self.checkpoints_path = checkpoints_path
53 |
54 | def on_epoch_end(self, epoch, logs=None):
55 | if self.checkpoints_path is not None:
56 | self.model.save_weights(self.checkpoints_path + "." + str(epoch))
57 | print("saved ", self.checkpoints_path + "." + str(epoch))
58 |
59 |
60 | def train(model,
61 | train_images,
62 | train_annotations,
63 | input_height=None,
64 | input_width=None,
65 | n_classes=None,
66 | verify_dataset=True,
67 | checkpoints_path=None,
68 | epochs=5,
69 | batch_size=2,
70 | validate=False,
71 | val_images=None,
72 | val_annotations=None,
73 | val_batch_size=2,
74 | auto_resume_checkpoint=False,
75 | load_weights=None,
76 | steps_per_epoch=512,
77 | val_steps_per_epoch=512,
78 | gen_use_multiprocessing=False,
79 | ignore_zero_class=False,
80 | optimizer_name='adam',
81 | do_augment=False,
82 | augmentation_name="aug_all",
83 | callbacks=None,
84 | custom_augmentation=None,
85 | other_inputs_paths=None,
86 | preprocessing=None,
87 |           read_image_type=1  # cv2.IMREAD_COLOR = 1 (3-channel BGR),
88 | # cv2.IMREAD_GRAYSCALE = 0,
89 | # cv2.IMREAD_UNCHANGED = -1 (4 channels like RGBA)
90 | ):
91 | from .models.all_models import model_from_name
92 | # check if user gives model name instead of the model object
93 | if isinstance(model, six.string_types):
94 | # create the model from the name
95 | assert (n_classes is not None), "Please provide the n_classes"
96 | if (input_height is not None) and (input_width is not None):
97 | model = model_from_name[model](
98 | n_classes, input_height=input_height, input_width=input_width)
99 | else:
100 | model = model_from_name[model](n_classes)
101 |
102 | n_classes = model.n_classes
103 | input_height = model.input_height
104 | input_width = model.input_width
105 | output_height = model.output_height
106 | output_width = model.output_width
107 |
108 | if validate:
109 | assert val_images is not None
110 | assert val_annotations is not None
111 |
112 | if optimizer_name is not None:
113 |
114 | if ignore_zero_class:
115 | loss_k = masked_categorical_crossentropy
116 | else:
117 | loss_k = 'categorical_crossentropy'
118 |
119 | model.compile(loss=loss_k,
120 | optimizer=optimizer_name,
121 | metrics=['accuracy'])
122 |
123 | if checkpoints_path is not None:
124 | config_file = checkpoints_path + "_config.json"
125 | dir_name = os.path.dirname(config_file)
126 |
127 |         if (not os.path.exists(dir_name)) and len(dir_name) > 0:
128 | os.makedirs(dir_name)
129 |
130 | with open(config_file, "w") as f:
131 | json.dump({
132 | "model_class": model.model_name,
133 | "n_classes": n_classes,
134 | "input_height": input_height,
135 | "input_width": input_width,
136 | "output_height": output_height,
137 | "output_width": output_width
138 | }, f)
139 |
140 | if load_weights is not None and len(load_weights) > 0:
141 | print("Loading weights from ", load_weights)
142 | model.load_weights(load_weights)
143 |
144 | initial_epoch = 0
145 |
146 | if auto_resume_checkpoint and (checkpoints_path is not None):
147 | latest_checkpoint = find_latest_checkpoint(checkpoints_path)
148 | if latest_checkpoint is not None:
149 | print("Loading the weights from latest checkpoint ",
150 | latest_checkpoint)
151 | model.load_weights(latest_checkpoint)
152 |
153 | initial_epoch = int(latest_checkpoint.split('.')[-1])
154 |
155 | if verify_dataset:
156 | print("Verifying training dataset")
157 | verified = verify_segmentation_dataset(train_images,
158 | train_annotations,
159 | n_classes)
160 | assert verified
161 | if validate:
162 | print("Verifying validation dataset")
163 | verified = verify_segmentation_dataset(val_images,
164 | val_annotations,
165 | n_classes)
166 | assert verified
167 |
168 | train_gen = image_segmentation_generator(
169 | train_images, train_annotations, batch_size, n_classes,
170 | input_height, input_width, output_height, output_width,
171 | do_augment=do_augment, augmentation_name=augmentation_name,
172 | custom_augmentation=custom_augmentation, other_inputs_paths=other_inputs_paths,
173 | preprocessing=preprocessing, read_image_type=read_image_type)
174 |
175 | if validate:
176 | val_gen = image_segmentation_generator(
177 | val_images, val_annotations, val_batch_size,
178 | n_classes, input_height, input_width, output_height, output_width,
179 | other_inputs_paths=other_inputs_paths,
180 | preprocessing=preprocessing, read_image_type=read_image_type)
181 |
182 |     if callbacks is None and (checkpoints_path is not None):
183 | default_callback = ModelCheckpoint(
184 | filepath=checkpoints_path + ".{epoch:05d}",
185 | save_weights_only=True,
186 | verbose=True
187 | )
188 |
189 |         if sys.version_info[0] < 3:  # for Python 2
190 | default_callback = CheckpointsCallback(checkpoints_path)
191 |
192 | callbacks = [
193 | default_callback
194 | ]
195 |
196 | if callbacks is None:
197 | callbacks = []
198 |
199 | if not validate:
200 | model.fit(train_gen, steps_per_epoch=steps_per_epoch,
201 | epochs=epochs, callbacks=callbacks, initial_epoch=initial_epoch)
202 | else:
203 | model.fit(train_gen,
204 | steps_per_epoch=steps_per_epoch,
205 | validation_data=val_gen,
206 | validation_steps=val_steps_per_epoch,
207 | epochs=epochs, callbacks=callbacks,
208 | use_multiprocessing=gen_use_multiprocessing, initial_epoch=initial_epoch)
209 |
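   | # Usage sketch (dataset paths and n_classes are hypothetical):
   | #   from keras_segmentation.models.unet import vgg_unet
   | #   model = vgg_unet(n_classes=51, input_height=416, input_width=608)
   | #   train(model, train_images="dataset/images_prepped_train",
   | #         train_annotations="dataset/annotations_prepped_train",
   | #         checkpoints_path="/tmp/vgg_unet_1", epochs=5)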
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | h5py
3 | tqdm
4 | keras>=2.3.0
5 | tensorflow>=2.2
6 | opencv-python
7 | imageio>=2.5.0
8 | imgaug>=0.4.0
10 |
--------------------------------------------------------------------------------
/sample_images/1_input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/sample_images/1_input.jpg
--------------------------------------------------------------------------------
/sample_images/1_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/sample_images/1_output.png
--------------------------------------------------------------------------------
/sample_images/2_input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/sample_images/2_input.jpg
--------------------------------------------------------------------------------
/sample_images/2_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/sample_images/2_output.png
--------------------------------------------------------------------------------
/sample_images/3_input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/sample_images/3_input.jpg
--------------------------------------------------------------------------------
/sample_images/3_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/sample_images/3_output.png
--------------------------------------------------------------------------------
/sample_images/liner_dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/sample_images/liner_dataset.png
--------------------------------------------------------------------------------
/sample_images/liner_export.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/sample_images/liner_export.png
--------------------------------------------------------------------------------
/sample_images/liner_testing.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/sample_images/liner_testing.png
--------------------------------------------------------------------------------
/sample_images/liner_training.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/sample_images/liner_training.png
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [egg_info]
2 | tag_build =
3 | tag_date = 0
4 |
5 | [metadata]
6 | description-file = README.md
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | import sys
4 |
5 | cv_ver = ""
6 | keras_ver = ">=2.0.0"
7 | if sys.version_info.major < 3:
8 | cv_ver = "<=4.2.0.32"
9 | keras_ver = "<=2.3.0"
10 |
11 |
12 | setup(name="keras_segmentation",
13 | version="0.3.0",
14 | description="Image Segmentation toolkit for keras",
15 | author="Divam Gupta",
16 | author_email='divamgupta@gmail.com',
17 | platforms=["any"], # or more specific, e.g. "win32", "cygwin", "osx"
18 | license="GPLv3",
19 | url="https://github.com/divamgupta/image-segmentation-keras",
20 | packages=find_packages(exclude=["test"]),
21 | entry_points={
22 | 'console_scripts': [
23 | 'keras_segmentation = keras_segmentation.__main__:main'
24 | ]
25 | },
26 | install_requires=[
27 | "h5py<=2.10.0",
28 | "Keras"+keras_ver,
29 | "imageio==2.5.0",
30 | "imgaug>=0.4.0",
31 | "opencv-python"+cv_ver,
32 | "tqdm"],
33 | extras_require={
34 | # These requires provide different backends available with Keras
35 | "tensorflow": ["tensorflow"],
36 | "cntk": ["cntk"],
37 | "theano": ["theano"],
38 | # Default testing with tensorflow
39 | "tests-default": ["tensorflow", "pytest"]
40 | }
41 | )
42 |
--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/__init__.py
--------------------------------------------------------------------------------
/test/example_dataset/annotations_prepped_test/0016E5_07959.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/annotations_prepped_test/0016E5_07959.png
--------------------------------------------------------------------------------
/test/example_dataset/annotations_prepped_test/0016E5_07961.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/annotations_prepped_test/0016E5_07961.png
--------------------------------------------------------------------------------
/test/example_dataset/annotations_prepped_test/0016E5_07963.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/annotations_prepped_test/0016E5_07963.png
--------------------------------------------------------------------------------
/test/example_dataset/annotations_prepped_train/0001TP_006690.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/annotations_prepped_train/0001TP_006690.png
--------------------------------------------------------------------------------
/test/example_dataset/annotations_prepped_train/0001TP_006720.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/annotations_prepped_train/0001TP_006720.png
--------------------------------------------------------------------------------
/test/example_dataset/annotations_prepped_train/0001TP_006750.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/annotations_prepped_train/0001TP_006750.png
--------------------------------------------------------------------------------
/test/example_dataset/annotations_prepped_train/0001TP_006780.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/annotations_prepped_train/0001TP_006780.png
--------------------------------------------------------------------------------
/test/example_dataset/annotations_prepped_train/0001TP_006810.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/annotations_prepped_train/0001TP_006810.png
--------------------------------------------------------------------------------
/test/example_dataset/images_prepped_test/0016E5_07959.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/images_prepped_test/0016E5_07959.png
--------------------------------------------------------------------------------
/test/example_dataset/images_prepped_test/0016E5_07961.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/images_prepped_test/0016E5_07961.png
--------------------------------------------------------------------------------
/test/example_dataset/images_prepped_test/0016E5_07963.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/images_prepped_test/0016E5_07963.png
--------------------------------------------------------------------------------
/test/example_dataset/images_prepped_train/0001TP_006690.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/images_prepped_train/0001TP_006690.png
--------------------------------------------------------------------------------
/test/example_dataset/images_prepped_train/0001TP_006720.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/images_prepped_train/0001TP_006720.png
--------------------------------------------------------------------------------
/test/example_dataset/images_prepped_train/0001TP_006750.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/images_prepped_train/0001TP_006750.png
--------------------------------------------------------------------------------
/test/example_dataset/images_prepped_train/0001TP_006780.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/images_prepped_train/0001TP_006780.png
--------------------------------------------------------------------------------
/test/example_dataset/images_prepped_train/0001TP_006810.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/example_dataset/images_prepped_train/0001TP_006810.png
--------------------------------------------------------------------------------
/test/test_models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tempfile
3 |
4 | import sys
5 |
6 | import keras
7 |
8 | from keras_segmentation.models import all_models
9 | from keras_segmentation.data_utils.data_loader import \
10 | verify_segmentation_dataset, image_segmentation_generator
11 | from keras_segmentation.predict import predict_multiple, predict, evaluate
12 |
13 |
14 | from keras_segmentation.model_compression import perform_distilation
15 | from keras_segmentation.pretrained import pspnet_50_ADE_20K
16 |
17 | tr_im = "test/example_dataset/images_prepped_train"
18 | tr_an = "test/example_dataset/annotations_prepped_train"
19 | te_im = "test/example_dataset/images_prepped_test"
20 | te_an = "test/example_dataset/annotations_prepped_test"
21 |
22 |
23 |
24 | def test_models():
25 |
26 | n_c = 100
27 |
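   |     # (model_name, input_height, input_width) triples; note that the
   |     # encoder-backed models assert height/width % 32 == 0 (see get_vgg_encoder).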
28 | models = [ ( "unet_mini" , 124 , 156 ) , ( "vgg_unet" , 224 , 224*2 ) ,
29 | ( 'resnet50_pspnet', 192*2 , 192*3 ) , ( 'mobilenet_unet', 224 , 224 ), ( 'mobilenet_unet', 224+32 , 224+32 ),( 'segnet', 224 , 224*2 ),( 'vgg_segnet', 224 , 224*2 ) ,( 'fcn_32', 224 , 224*2 ) ,( 'fcn_8_vgg', 224 , 224*2 ) ]
30 |
31 |     for model_name, h, w in models:
32 |         m = all_models.model_from_name[model_name](n_c, input_height=h, input_width=w)
33 |
34 | m.train(train_images=tr_im,
35 | train_annotations=tr_an,
36 | steps_per_epoch=2,
37 | epochs=2 )
38 |
39 | keras.backend.clear_session()
40 |
41 |
42 |
43 |
44 |
45 | def test_verify():
46 | verify_segmentation_dataset(tr_im, tr_an, 50)
47 |
48 |
49 | def test_datag():
50 | g = image_segmentation_generator(images_path=tr_im, segs_path=tr_an,
51 | batch_size=3, n_classes=50,
52 | input_height=224, input_width=324,
53 | output_height=114, output_width=134,
54 | do_augment=False)
55 |
56 | x, y = next(g)
57 | assert x.shape[0] == 3
58 | assert y.shape[0] == 3
59 | assert y.shape[-1] == 50
60 |
61 |
62 | # with augmentation
63 | def test_datag2():
64 | g = image_segmentation_generator(images_path=tr_im, segs_path=tr_an,
65 | batch_size=3, n_classes=50,
66 | input_height=224, input_width=324,
67 | output_height=114, output_width=134,
68 | do_augment=True)
69 |
70 | x, y = next(g)
71 | assert x.shape[0] == 3
72 | assert y.shape[0] == 3
73 | assert y.shape[-1] == 50
74 |
75 |
76 | def test_model():
77 | model_name = "fcn_8"
78 | h = 224
79 | w = 256
80 | n_c = 100
81 | check_path = tempfile.mktemp()
82 |
83 | m = all_models.model_from_name[model_name](
84 | n_c, input_height=h, input_width=w)
85 |
86 | m.train(train_images=tr_im,
87 | train_annotations=tr_an,
88 | steps_per_epoch=2,
89 | epochs=2,
90 | checkpoints_path=check_path
91 | )
92 |
93 | m.train(train_images=tr_im,
94 | train_annotations=tr_an,
95 | steps_per_epoch=2,
96 | epochs=2,
97 | checkpoints_path=check_path,
98 | augmentation_name='aug_geometric', do_augment=True
99 | )
100 |
101 | m.predict_segmentation(np.zeros((h, w, 3))).shape
102 |
103 | predict_multiple(
104 | inp_dir=te_im, checkpoints_path=check_path, out_dir="/tmp")
105 | predict_multiple(inps=[np.zeros((h, w, 3))]*3,
106 | checkpoints_path=check_path, out_dir="/tmp")
107 |
108 | ev = m.evaluate_segmentation(inp_images_dir=te_im, annotations_dir=te_an)
109 | assert ev['frequency_weighted_IU'] > 0.01
110 | print(ev)
111 | o = predict(inp=np.zeros((h, w, 3)), checkpoints_path=check_path)
112 |
113 | o = predict(inp=np.zeros((h, w, 3)), checkpoints_path=check_path,
114 | overlay_img=True, class_names=['nn']*n_c, show_legends=True)
115 | print("pr")
116 |
117 | o.shape
118 |
119 | ev = evaluate(inp_images_dir=te_im,
120 | annotations_dir=te_an,
121 | checkpoints_path=check_path)
122 | assert ev['frequency_weighted_IU'] > 0.01
123 |
124 |
125 | def test_kd():
126 |
127 | if sys.version_info.major < 3:
128 |         # KD won't work with Python 2
129 | return
130 |
131 | model_name = "fcn_8"
132 | h = 224
133 | w = 256
134 | n_c = 100
135 | check_path1 = tempfile.mktemp()
136 |
137 | m1 = all_models.model_from_name[model_name](
138 | n_c, input_height=h, input_width=w)
139 |
140 |
141 |
142 | model_name = "unet_mini"
143 | h = 124
144 | w = 156
145 | n_c = 100
146 | check_path2 = tempfile.mktemp()
147 |
148 | m2 = all_models.model_from_name[model_name](
149 | n_c, input_height=h, input_width=w)
150 |
151 |
152 | m1.train(train_images=tr_im,
153 | train_annotations=tr_an,
154 | steps_per_epoch=2,
155 | epochs=2,
156 | checkpoints_path=check_path1
157 | )
158 |
159 |     perform_distilation(m1, m2, tr_im, distilation_loss='kl',
160 |                         batch_size=2, checkpoints_path=check_path2, epochs=2, steps_per_epoch=2)
161 |
162 |
163 |     perform_distilation(m1, m2, tr_im, distilation_loss='l2',
164 |                         batch_size=2, checkpoints_path=check_path2, epochs=2, steps_per_epoch=2)
165 |
166 |
167 |     perform_distilation(m1, m2, tr_im, distilation_loss='l2',
168 |                         batch_size=2, checkpoints_path=check_path2, epochs=2, steps_per_epoch=2, feats_distilation_loss='pa')
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 | def test_pretrained():
180 |
181 |
182 | model = pspnet_50_ADE_20K()
183 |
184 | out = model.predict_segmentation(
185 | inp=te_im+"/0016E5_07959.png",
186 | out_fname="/tmp/out.png"
187 | )
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 | # def test_models():
197 |
198 |
199 | # unet_models = [ , models.unet.vgg_unet , models.unet.resnet50_unet ]
200 | # args = [(101, 416, 608), (101, 224 ,224), (101, 256, 256), (2, 32*4, 32*5)]
201 | # en_level = [ 1,2,3,4 ]
202 |
203 | # for mf in unet_models:
204 | # for en in en_level:
205 | # for ar in args:
206 | # m = mf( *ar , encoder_level=en )
207 |
208 |
209 | # m = models.unet.mobilenet_unet( 55 )
210 | # for ar in args:
211 | # m = unet_mini( *ar )
212 |
--------------------------------------------------------------------------------
/test/unit/data_utils/test_augmentation.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/unit/data_utils/test_augmentation.py
--------------------------------------------------------------------------------
/test/unit/data_utils/test_data_loader.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import unittest
3 | import tempfile
4 | from shutil import rmtree
5 | import os
6 | import six
7 | from keras_segmentation.data_utils import data_loader
8 | import random
9 | import cv2
10 | from imgaug import augmenters as iaa
11 | import shutil
12 | import numpy as np
13 |
14 | class TestGetPairsFromPaths(unittest.TestCase):
15 | """ Test data loader facilities """
16 |
17 | def _setup_images_and_segs(self, images, segs):
18 | for file_name in images:
19 | open(os.path.join(self.img_path, file_name), 'a').close()
20 | for file_name in segs:
21 | open(os.path.join(self.seg_path, file_name), 'a').close()
22 |
23 | @classmethod
24 | def _cleanup_folder(cls, folder_path):
25 | return rmtree(folder_path)
26 |
27 | def setUp(self):
28 | self.tmp_dir = tempfile.mkdtemp()
29 | self.img_path = os.path.join(self.tmp_dir, "images")
30 | self.seg_path = os.path.join(self.tmp_dir, "segs")
31 | os.mkdir(self.img_path)
32 | os.mkdir(self.seg_path)
33 |
34 | def tearDown(self):
35 | rmtree(self.tmp_dir)
36 |
37 | def test_get_pairs_from_paths_1(self):
38 | """ Normal execution """
39 | images = ["A.jpg", "B.jpg", "C.jpeg", "D.png"]
40 | segs = ["A.png", "B.png", "C.png", "D.png"]
41 | self._setup_images_and_segs(images, segs)
42 |
43 | expected = [("A.jpg", "A.png"),
44 | ("B.jpg", "B.png"),
45 | ("C.jpeg", "C.png"),
46 | ("D.png", "D.png")]
47 | expected_values = []
48 | # Transform paths
49 | for (x, y) in expected:
50 | expected_values.append((os.path.join(self.img_path, x), os.path.join(self.seg_path, y)))
51 | self.assertEqual(expected_values, sorted(data_loader.get_pairs_from_paths(self.img_path, self.seg_path)))
52 |
53 | def test_get_pairs_from_paths_2(self):
54 | """ Normal execution with extra files """
55 | images = ["A.jpg", "B.jpg", "C.jpeg", "D.png", "E.txt"]
56 | segs = ["A.png", "B.png", "C.png", "D.png", "E.png"]
57 | self._setup_images_and_segs(images, segs)
58 |
59 | expected = [("A.jpg", "A.png"),
60 | ("B.jpg", "B.png"),
61 | ("C.jpeg", "C.png"),
62 | ("D.png", "D.png")]
63 | expected_values = []
64 | # Transform paths
65 | for (x, y) in expected:
66 | expected_values.append((os.path.join(self.img_path, x), os.path.join(self.seg_path, y)))
67 | self.assertEqual(expected_values, sorted(data_loader.get_pairs_from_paths(self.img_path, self.seg_path)))
68 |
69 |
70 | def test_get_pairs_from_paths_3(self):
71 | """ Normal execution with multiple images pointing to one """
72 | images = ["A.jpg", "B.jpg", "C.jpeg", "D.png", "D.jpg"]
73 | segs = ["A.png", "B.png", "C.png", "D.png"]
74 | self._setup_images_and_segs(images, segs)
75 |
76 | expected = [("A.jpg", "A.png"),
77 | ("B.jpg", "B.png"),
78 | ("C.jpeg", "C.png"),
79 | ("D.jpg", "D.png"),
80 | ("D.png", "D.png"),]
81 | expected_values = []
82 | # Transform paths
83 | for (x, y) in expected:
84 | expected_values.append((os.path.join(self.img_path, x), os.path.join(self.seg_path, y)))
85 | self.assertEqual(expected_values, sorted(data_loader.get_pairs_from_paths(self.img_path, self.seg_path)))
86 |
87 | def test_get_pairs_from_paths_with_invalid_segs(self):
88 | images = ["A.jpg", "B.jpg", "C.jpeg", "D.png"]
89 | segs = ["A.png", "B.png", "C.png", "D.png", "D.jpg"]
90 | self._setup_images_and_segs(images, segs)
91 |
92 | expected = [("A.jpg", "A.png"),
93 | ("B.jpg", "B.png"),
94 | ("C.jpeg", "C.png"),
95 | ("D.png", "D.png"),]
96 | expected_values = []
97 | # Transform paths
98 | for (x, y) in expected:
99 | expected_values.append((os.path.join(self.img_path, x), os.path.join(self.seg_path, y)))
100 | self.assertEqual(expected_values, sorted(data_loader.get_pairs_from_paths(self.img_path, self.seg_path)))
101 |
102 | def test_get_pairs_from_paths_with_no_matching_segs(self):
103 | images = ["A.jpg", "B.jpg", "C.jpeg", "D.png"]
104 | segs = ["A.png", "B.png", "C.png"]
105 | self._setup_images_and_segs(images, segs)
106 |
107 | expected = [("A.jpg", "A.png"),
108 | ("B.jpg", "B.png"),
109 | ("C.jpeg", "C.png")]
110 | expected_values = []
111 | # Transform paths
112 | for (x, y) in expected:
113 | expected_values.append((os.path.join(self.img_path, x), os.path.join(self.seg_path, y)))
114 | six.assertRaisesRegex(self, data_loader.DataLoaderError, "No corresponding segmentation found for image", data_loader.get_pairs_from_paths, self.img_path, self.seg_path)
115 |
116 | def test_get_pairs_from_paths_with_no_matching_segs_with_escape(self):
117 | images = ["A.jpg", "B.jpg", "C.jpeg", "D.png"]
118 | segs = ["A.png", "B.png", "C.png"]
119 | self._setup_images_and_segs(images, segs)
120 |
121 | expected = [("A.jpg", "A.png"),
122 | ("B.jpg", "B.png"),
123 | ("C.jpeg", "C.png")]
124 | expected_values = []
125 | # Transform paths
126 | for (x, y) in expected:
127 | expected_values.append((os.path.join(self.img_path, x), os.path.join(self.seg_path, y)))
128 | self.assertEqual(expected_values, sorted(data_loader.get_pairs_from_paths(self.img_path, self.seg_path, ignore_non_matching=True)))
129 |
130 | class TestGetImageArray(unittest.TestCase):
131 | def test_get_image_array_normal(self):
132 | """ Stub test
133 | TODO(divamgupta): Fill with actual test
134 | """
135 | pass
136 |
137 | class TestGetSegmentationArray(unittest.TestCase):
138 | def test_get_segmentation_array_normal(self):
139 | """ Stub test
140 | TODO(divamgupta): Fill with actual test
141 | """
142 | pass
143 |
144 | class TestVerifySegmentationDataset(unittest.TestCase):
145 | def test_verify_segmentation_dataset(self):
146 | """ Stub test
147 | TODO(divamgupta): Fill with actual test
148 | """
149 | pass
150 |
151 | class TestImageSegmentationGenerator(unittest.TestCase):
152 | def setUp(self):
153 | self.train_temp_dir = tempfile.mkdtemp()
154 | self.test_temp_dir = tempfile.mkdtemp()
155 | self.other_temp_dir = tempfile.mkdtemp()
156 | self.other_temp_dir_2 = tempfile.mkdtemp()
157 |
158 | self.image_size = 4
159 |
160 | # Training
161 | train_image = np.arange(self.image_size * self.image_size)
162 | train_image = train_image.reshape((self.image_size, self.image_size))
163 |
164 | train_file = os.path.join(self.train_temp_dir, "train.png")
165 | test_file = os.path.join(self.test_temp_dir, "train.png")
166 |
167 | cv2.imwrite(train_file, train_image)
168 | cv2.imwrite(test_file, train_image)
169 |
170 | # Testing
171 | train_image = np.arange(start=self.image_size * self.image_size,
172 | stop=self.image_size * self.image_size * 2)
173 |         train_image = train_image.reshape((self.image_size, self.image_size))
174 |
175 | train_file = os.path.join(self.train_temp_dir, "train2.png")
176 | test_file = os.path.join(self.test_temp_dir, "train2.png")
177 |
178 | cv2.imwrite(train_file, train_image)
179 | cv2.imwrite(test_file, train_image)
180 |
181 | # Extra one
182 |
183 | i = 0
184 |         for other_dir in [self.other_temp_dir, self.other_temp_dir_2]:
185 | extra_image = np.arange(start=self.image_size * self.image_size * (2 + i),
186 | stop=self.image_size * self.image_size * (2 + i + 1))
187 | extra_image = extra_image.reshape((self.image_size, self.image_size))
188 |
189 | extra_file = os.path.join(dir, "train.png")
190 | cv2.imwrite(extra_file, extra_image)
191 | i += 1
192 |
193 | extra_image = np.arange(start=self.image_size * self.image_size * (2 + i),
194 | stop=self.image_size * self.image_size * (2 + i + 1))
195 | extra_image = extra_image.reshape((self.image_size, self.image_size))
196 |
197 | extra_file = os.path.join(dir, "train2.png")
198 | cv2.imwrite(extra_file, extra_image)
199 | i += 1
200 |
201 | def tearDown(self):
202 | shutil.rmtree(self.train_temp_dir)
203 | shutil.rmtree(self.test_temp_dir)
204 | shutil.rmtree(self.other_temp_dir)
205 | shutil.rmtree(self.other_temp_dir_2)
206 |
207 | def custom_aug(self):
208 | return iaa.Sequential(
209 | [
210 | iaa.Fliplr(1), # horizontally flip 100% of all images
211 | ])
212 |
213 | def test_image_segmentation_generator_custom_augmentation(self):
214 | random.seed(0)
215 | image_seg_pairs = data_loader.get_pairs_from_paths(self.train_temp_dir, self.test_temp_dir)
216 |
217 | random.seed(0)
218 | random.shuffle(image_seg_pairs)
219 |
220 | random.seed(0)
221 |
222 | generator = data_loader.image_segmentation_generator(
223 | self.train_temp_dir, self.test_temp_dir, 1,
224 | self.image_size * self.image_size, self.image_size, self.image_size, self.image_size, self.image_size,
225 | do_augment=True, custom_augmentation=self.custom_aug
226 | )
227 |
228 | i = 0
229 | for (aug_im, aug_an), (expt_im_f, expt_an_f) in zip(generator, image_seg_pairs):
230 | if i >= len(image_seg_pairs):
231 | break
232 |
233 |             expt_im = data_loader.get_image_array(expt_im_f, self.image_size, self.image_size, ordering='channels_last')
234 |
235 | expt_im = cv2.flip(expt_im, flipCode=1)
236 | self.assertTrue(np.equal(expt_im, aug_im).all())
237 |
238 | i += 1
239 |
240 | def test_image_segmentation_generator_custom_augmentation_with_other_inputs(self):
241 | other_paths = [
242 | self.other_temp_dir, self.other_temp_dir_2
243 | ]
244 | random.seed(0)
245 | image_seg_pairs = data_loader.get_pairs_from_paths(self.train_temp_dir,
246 | self.test_temp_dir,
247 | other_inputs_paths=other_paths)
248 |
249 | random.seed(0)
250 | random.shuffle(image_seg_pairs)
251 |
252 | random.seed(0)
253 | generator = data_loader.image_segmentation_generator(
254 | self.train_temp_dir, self.test_temp_dir, 1,
255 | self.image_size * self.image_size, self.image_size, self.image_size, self.image_size,
256 | self.image_size,
257 | do_augment=True, custom_augmentation=self.custom_aug, other_inputs_paths=other_paths
258 | )
259 |
260 | i = 0
261 | for (aug_im, aug_an), (expt_im_f, expt_an_f, expt_oth) in zip(generator, image_seg_pairs):
262 | if i >= len(image_seg_pairs):
263 | break
264 |
265 | ims = [expt_im_f]
266 | ims.extend(expt_oth)
267 |
268 |             for j in range(aug_im.shape[1]):
269 |                 expt_im = data_loader.get_image_array(ims[j], self.image_size, self.image_size,
270 |                                                       ordering='channels_last')
271 |
272 |                 expt_im = cv2.flip(expt_im, flipCode=1)
273 |
274 |                 self.assertTrue(np.equal(expt_im, aug_im[0, j, :, :]).all())
275 |
276 | i += 1
277 |
278 | def test_image_segmentation_generator_with_other_inputs(self):
279 | other_paths = [
280 | self.other_temp_dir, self.other_temp_dir_2
281 | ]
282 | random.seed(0)
283 | image_seg_pairs = data_loader.get_pairs_from_paths(self.train_temp_dir,
284 | self.test_temp_dir,
285 | other_inputs_paths=other_paths)
286 |
287 | random.seed(0)
288 | random.shuffle(image_seg_pairs)
289 |
290 | random.seed(0)
291 | generator = data_loader.image_segmentation_generator(
292 | self.train_temp_dir, self.test_temp_dir, 1,
293 | self.image_size * self.image_size, self.image_size, self.image_size, self.image_size,
294 | self.image_size,
295 | other_inputs_paths=other_paths
296 | )
297 |
298 | i = 0
299 | for (aug_im, aug_an), (expt_im_f, expt_an_f, expt_oth) in zip(generator, image_seg_pairs):
300 | if i >= len(image_seg_pairs):
301 | break
302 |
303 | ims = [expt_im_f]
304 | ims.extend(expt_oth)
305 |
306 |             for j in range(aug_im.shape[1]):
307 |                 expt_im = data_loader.get_image_array(ims[j], self.image_size, self.image_size,
308 |                                                       ordering='channels_last')
309 |                 self.assertTrue(np.equal(expt_im, aug_im[0, j, :, :]).all())
310 |
311 | i += 1
312 |
313 | def test_image_segmentation_generator_preprocessing(self):
314 | image_seg_pairs = data_loader.get_pairs_from_paths(self.train_temp_dir, self.test_temp_dir)
315 |
316 | random.seed(0)
317 | random.shuffle(image_seg_pairs)
318 |
319 | random.seed(0)
320 |
321 | generator = data_loader.image_segmentation_generator(
322 | self.train_temp_dir, self.test_temp_dir, 1,
323 | self.image_size * self.image_size, self.image_size, self.image_size, self.image_size,
324 | self.image_size,
325 | preprocessing=lambda x: x + 1
326 | )
327 |
328 | i = 0
329 | for (aug_im, aug_an), (expt_im_f, expt_an_f) in zip(generator, image_seg_pairs):
330 | if i >= len(image_seg_pairs):
331 | break
332 |
333 | expt_im = data_loader.get_image_array(expt_im_f, self.image_size, self.image_size,
334 |                                                   ordering='channels_last')
335 |
336 | expt_im += 1
337 | self.assertTrue(np.equal(expt_im, aug_im[0, :, :]).all())
338 |
339 | i += 1
340 |
341 | def test_single_image_segmentation_generator_preprocessing_with_other_inputs(self):
342 | other_paths = [
343 | self.train_temp_dir, self.test_temp_dir
344 | ]
345 | random.seed(0)
346 | image_seg_pairs = data_loader.get_pairs_from_paths(self.train_temp_dir,
347 | self.test_temp_dir,
348 | other_inputs_paths=other_paths)
349 |
350 | random.seed(0)
351 | random.shuffle(image_seg_pairs)
352 |
353 | random.seed(0)
354 | generator = data_loader.image_segmentation_generator(
355 | self.train_temp_dir, self.test_temp_dir, 1,
356 | self.image_size * self.image_size, self.image_size, self.image_size, self.image_size,
357 | self.image_size,
358 | preprocessing=lambda x: x + 1, other_inputs_paths=other_paths
359 | )
360 |
361 | i = 0
362 | for (aug_im, aug_an), (expt_im_f, expt_an_f, expt_oth) in zip(generator, image_seg_pairs):
363 | if i >= len(image_seg_pairs):
364 | break
365 |
366 | ims = [expt_im_f]
367 | ims.extend(expt_oth)
368 |
369 | for j in range(aug_im.shape[1]):
370 | expt_im = data_loader.get_image_array(ims[j], self.image_size, self.image_size,
371 | ordering='channel_last')
372 |
373 | self.assertTrue(np.equal(expt_im + 1, aug_im[0, j, :, :]).all())
374 |
375 | i += 1
376 |
377 | def test_multi_image_segmentation_generator_preprocessing_with_other_inputs(self):
378 | other_paths = [
379 | self.other_temp_dir, self.other_temp_dir_2
380 | ]
381 | random.seed(0)
382 | image_seg_pairs = data_loader.get_pairs_from_paths(self.train_temp_dir,
383 | self.test_temp_dir,
384 | other_inputs_paths=other_paths)
385 |
386 | random.seed(0)
387 | random.shuffle(image_seg_pairs)
388 |
389 | random.seed(0)
390 | generator = data_loader.image_segmentation_generator(
391 | self.train_temp_dir, self.test_temp_dir, 1,
392 | self.image_size * self.image_size, self.image_size, self.image_size, self.image_size,
393 | self.image_size,
394 | preprocessing=[lambda x: x + 1, lambda x: x + 2, lambda x: x + 3], other_inputs_paths=other_paths
395 | )
396 |
397 | i = 0
398 | for (aug_im, aug_an), (expt_im_f, expt_an_f, expt_oth) in zip(generator, image_seg_pairs):
399 | if i >= len(image_seg_pairs):
400 | break
401 |
402 | ims = [expt_im_f]
403 | ims.extend(expt_oth)
404 |
405 | for j in range(aug_im.shape[1]):
406 | expt_im = data_loader.get_image_array(ims[j], self.image_size, self.image_size,
407 | ordering='channel_last')
408 |
409 | self.assertTrue(np.equal(expt_im + (j + 1), aug_im[0, j, :, :]).all())
410 |
411 | i += 1
412 |
413 |
--------------------------------------------------------------------------------
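A note on the triple `random.seed(0)` pattern used throughout the tests above: `image_segmentation_generator` is assumed to shuffle its image/annotation pairs via Python's module-level `random` state, so each test reseeds once before building the expected (shuffled) pair list and again just before constructing the generator, which makes the two shuffles draw the same pseudo-random sequence. The sketch below is a minimal, self-contained illustration of that synchronization idea; `shuffled_stream` is a hypothetical stand-in for the generator's internal shuffle, not the library's actual code.

import random

def shuffled_stream(pairs):
    # Hypothetical stand-in: shuffles with the global `random` state,
    # as the tests assume image_segmentation_generator does internally.
    pairs = list(pairs)
    random.shuffle(pairs)
    for pair in pairs:
        yield pair

pairs = [("img_%d.png" % k, "ann_%d.png" % k) for k in range(5)]

random.seed(0)
expected = list(pairs)
random.shuffle(expected)   # expected order, computed up front

random.seed(0)             # reseed so the stream's shuffle is identical
assert list(shuffled_stream(pairs)) == expected

Were the generator to use its own `random.Random` instance instead of the module-level state, this reseeding trick would silently stop working.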
/test/unit/data_utils/test_visualize_dataset.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/unit/data_utils/test_visualize_dataset.py
--------------------------------------------------------------------------------
/test/unit/models/test_basic_models.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/unit/models/test_basic_models.py
--------------------------------------------------------------------------------
/test/unit/test_metrics.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/unit/test_metrics.py
--------------------------------------------------------------------------------
/test/unit/test_predict.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/unit/test_predict.py
--------------------------------------------------------------------------------
/test/unit/test_pretrained.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divamgupta/image-segmentation-keras/1b2ba53ae49387c2d1abbd9a2f4a9a45eea6912f/test/unit/test_pretrained.py
--------------------------------------------------------------------------------
/test/unit/test_train.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import tempfile
3 | from shutil import rmtree
4 | import os
5 | import six
6 | from keras_segmentation import train
7 |
8 | class TestTrainInternalFunctions(unittest.TestCase):
9 | """ Test internal functions of the module """
10 |
11 | def setUp(self):
12 | self.tmp_dir = tempfile.mkdtemp()
13 |
14 | def tearDown(self):
15 | rmtree(self.tmp_dir)
16 |
17 | def test_find_latest_checkpoint(self):
18 | # Exercise checkpoint discovery against a path that has no checkpoints yet
19 | checkpoints_path = os.path.join(self.tmp_dir, "test1")
20 | # With the default fail_safe=True, a missing checkpoint yields None
21 | self.assertIsNone(train.find_latest_checkpoint(checkpoints_path))
22 | # When fail_safe is turned off, a missing checkpoint raises an exception instead
23 | six.assertRaisesRegex(self, ValueError, "Checkpoint path", train.find_latest_checkpoint, checkpoints_path, False)
24 | for suffix in ["0", "2", "4", "12", "_config.json", "ABC"]:  # only numeric suffixes count
25 | open(checkpoints_path + '.' + suffix, 'a').close()
26 | self.assertEqual(checkpoints_path + ".12", train.find_latest_checkpoint(checkpoints_path))
--------------------------------------------------------------------------------
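The assertions above pin down the contract of `train.find_latest_checkpoint`: checkpoint files are named `<checkpoints_path>.<suffix>`, suffixes are compared as integers (so ".12" outranks ".4"), non-numeric suffixes such as "_config.json" are ignored, and when nothing is found the function returns None under the default fail_safe=True or raises a ValueError mentioning "Checkpoint path" when fail_safe is False. The following is a minimal re-implementation consistent with those expectations, sketched here for reference; it is not necessarily the exact code in keras_segmentation.train.

import glob

def find_latest_checkpoint_sketch(checkpoints_path, fail_safe=True):
    # Gather "<checkpoints_path>.<suffix>" files whose suffix is a plain
    # integer; "_config.json", "ABC" and the like are ignored.
    candidates = []
    for path in glob.glob(checkpoints_path + ".*"):
        suffix = path[len(checkpoints_path) + 1:]
        if suffix.isdigit():
            candidates.append((int(suffix), path))
    if not candidates:
        if fail_safe:
            return None
        raise ValueError("Checkpoint path {0} invalid".format(checkpoints_path))
    # Numeric comparison, not lexicographic: 12 > 4 even though "12" < "4"
    return max(candidates)[1]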