├── .codecov.xml ├── .gitignore ├── .readthedocs.yml ├── .travis.yml ├── LICENSE ├── MANIFEST.in ├── README.md ├── ROADMAP.md ├── docker ├── Dockerfile └── tf20 │ └── Dockerfile ├── docs ├── activations.md ├── datasets.md ├── docker.md ├── index.md ├── metrics.md ├── models.md ├── predict.md ├── requirements.txt ├── tflite.md └── usage.md ├── examples ├── Demo.ipynb ├── FindLR.ipynb ├── FindOptimalBatchSize.ipynb ├── Predict.ipynb └── Train.ipynb ├── mkdocs.yml ├── requirements.txt ├── scripts ├── docker-compose-install.sh └── nvidia-docker-setup.sh ├── setup.cfg ├── setup.py ├── sweep.example.yaml ├── tests ├── __init__.py ├── bin │ ├── __init__.py │ ├── tfrecord_writer_test.py │ └── train_test.py ├── data.py ├── datasets │ ├── __init__.py │ ├── test_dataset.py │ ├── test_shapes.py │ ├── test_tfrecord.py │ └── test_utils.py ├── debug │ ├── __init__.py │ └── test_export_dataset.py ├── fixtures.py ├── losses │ ├── __init__.py │ └── test_utils.py ├── processing │ ├── __init__.py │ ├── dataset_test.py │ └── image_test.py ├── requirements.txt ├── test.png ├── test_activations.py ├── test_apps.py ├── test_losses.py ├── test_metrics.py ├── test_models.py ├── test_threading.py ├── test_utils.py └── visualizations │ ├── __init__.py │ └── mask_test.py └── tf_semantic_segmentation ├── __init__.py ├── activations └── __init__.py ├── bin ├── __init__.py ├── convert_tflite.py ├── download.py ├── model_server_config_writer.py ├── tfrecord_analyser.py ├── tfrecord_download.py ├── tfrecord_writer.py ├── train.py └── train_all.py ├── callbacks.py ├── datasets ├── __init__.py ├── ade20k.py ├── bioimage.py ├── camvid.py ├── cityscapes.py ├── cub.py ├── cvc_clinicdb.py ├── dataset.py ├── directory.py ├── isic.py ├── mapping_challenge.py ├── mots_challenge.py ├── ms_coco.py ├── pascal.py ├── shapes.py ├── sun.py ├── taco.py ├── tfrecord.py ├── toy.py └── utils.py ├── debug ├── __init__.py ├── dataset_export.py ├── dataset_vis.py ├── devices.py ├── export_saved_model.py ├── model_parameters.py ├── preprocessing_vis.py ├── record_vis.py └── tflite_test.py ├── evaluation ├── __init__.py ├── compare_models.py ├── eval_loss.py ├── predict.py ├── video.py └── viewer.py ├── layers ├── __init__.py ├── conv.py ├── minibatchstddev.py ├── pixel_norm.py ├── subpixel.py └── utils.py ├── losses ├── __init__.py ├── ce.py ├── combined.py ├── dice.py ├── focal.py ├── lovasz.py ├── ssim.py └── utils.py ├── metrics ├── __init__.py ├── f_scores.py ├── iou_score.py ├── kmetrics.py ├── precision.py ├── psnr.py ├── recall.py └── ssim.py ├── models ├── __init__.py ├── apps │ ├── __init__.py │ ├── inception.py │ ├── mobilenet.py │ ├── resnet50.py │ └── utils.py ├── attention_unet.py ├── deeplabv3.py ├── deeplabv3plus.py ├── erfnet.py ├── fcn.py ├── imagenet_unet.py ├── multires_unet.py ├── nested_unet.py ├── psp.py ├── satellite_unet.py ├── u2net.py └── unet.py ├── optimizers └── __init__.py ├── processing ├── __init__.py ├── dataset.py └── image.py ├── serving.py ├── settings.py ├── threading.py ├── utils.py ├── version.py └── visualizations ├── __init__.py ├── masks.py └── show.py /.codecov.xml: -------------------------------------------------------------------------------- 1 | #see https://github.com/codecov/support/wiki/Codecov-Yaml 2 | codecov: 3 | notify: 4 | require_ci_to_pass: yes 5 | 6 | coverage: 7 | precision: 0 # 2 = xx.xx%, 0 = xx% 8 | round: nearest # how coverage is rounded: down/up/nearest 9 | range: 40...100 # custom range of coverage colors from red -> yellow -> green 10 | status: 11 | # 
https://codecov.readme.io/v1.0/docs/commit-status 12 | project: 13 | default: 14 | against: auto 15 | target: 90% # specify the target coverage for each commit status 16 | threshold: 20% # allow this little decrease on project 17 | # https://github.com/codecov/support/wiki/Filtering-Branches 18 | # branches: master 19 | if_ci_failed: error 20 | # https://github.com/codecov/support/wiki/Patch-Status 21 | patch: 22 | default: 23 | against: auto 24 | target: 40% # specify the target "X%" coverage to hit 25 | # threshold: 50% # allow this much decrease on patch 26 | changes: false 27 | 28 | parsers: 29 | gcov: 30 | branch_detection: 31 | conditional: true 32 | loop: true 33 | macro: false 34 | method: false 35 | javascript: 36 | enable_partials: false 37 | 38 | comment: 39 | layout: header, diff 40 | require_changes: false 41 | behavior: default # update if exists else create new 42 | # branches: * -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # ---> Python 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | env/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *,cover 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | 55 | # Sphinx documentation 56 | docs/_build/ 57 | 58 | # PyBuilder 59 | target/ 60 | 61 | # IDEs 62 | .idea/ 63 | .vscode/ 64 | checkpoints/ 65 | *.wav 66 | *.mp3 67 | *.png 68 | *.pkl 69 | *.swp 70 | *.gz 71 | *.gif 72 | .config/ 73 | .env/ 74 | .pytest_cache/ 75 | **.ipynb_checkpoints/ 76 | .venv/ 77 | logs/ 78 | wandb/ 79 | vast 80 | *.csv 81 | client_secrets.json 82 | experimental/ 83 | .local/ 84 | .keras/ 85 | records/** 86 | gt/** 87 | pets/** -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Build documentation in the docs/ directory with mkdocs 9 | mkdocs: 10 | configuration: mkdocs.yml 11 | fail_on_warning: false 12 | 13 | # Optionally build your docs in additional formats such as PDF 14 | formats: 15 | - pdf 16 | 17 | # Optionally set the version of Python and requirements required to build your docs 18 | python: 19 | version: 3.7 20 | install: 21 | - requirements: docs/requirements.txt -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | sudo: required 3 | python: 4 | - 3.7 5 | before_install: 6 | - sudo apt-get install -y libsm6 libxext6 libxrender-dev 
libyaml-dev libpython3-dev unrar tar zip 7 | install: 8 | - pip install --upgrade pip 9 | - pip install --upgrade setuptools 10 | - pip install -r requirements.txt 11 | - pip install -r tests/requirements.txt 12 | cache: pip 13 | script: 14 | - python setup.py sdist bdist_wheel 15 | - python setup.py clean 16 | - travis_wait py.test --cov-report term -v --cov=tf_semantic_segmentation 17 | - python setup.py install 18 | - python setup.py clean 19 | - pip uninstall -y tf_semantic_segmentation 20 | after_success: 21 | - codecov 22 | deploy: 23 | provider: pypi 24 | user: "__token__" 25 | password: 26 | secure: Qx0XdKlX//uskH7B+FRQA8iKKN+rC3ona7Os+wiJqY0Vge2r7x4fFODtyfAA9mlvj6UYQ0x56H5k1sOeiJmmHbZTV5sulLurLOi87BpcYT9I10+Zsiw//y2CWjDVrfptGFnGYA/u/Qj6M7pOhSDZGQZ9YxoqasOvIE8t/sn3f4HXq8y/PjqMTaJ5mGYDIytz/A5RjN/R1lnPPrzado1NkJ5nf1nmX1kglCne1dny0xnOTJHdGanwa4p6A8sU6FildN+VskmA+vs8sK7Zr3Zuu9RtyTFEI0u+5pxoZQ4YIjZxp1/Wmd0fDX3vAbsKFAfGdUsYOgNigcmRd1+vOYEBSxnIARR7FG8zi+oEz5qJJFMXXflC0mu5kL9mOzS8eM3GrOvlDHrrqRL+vlOmrLIf7MozCnetyjl3fx0K0SrrgB/blPjUleIHjDlr6+NXnQqdh0xUVRE6Geo+FD8P7Giy1TbUo5xLKKzDF6XOPvggMKhz4OJ3/A6IRNF20ut26Uww8cyrJTEIBj7ia30B+9jLRwM0uU1Sw9BzsPq8dNpBs7scF8c0dlNswXw7q+f8gGeOUO0gzX9HuY+szpW2hJqL7xlixyGOHnzIqdIczS81OXaCWBcoQYEqI97SKf4ZngMUclzZ1AR1i2/manVMiYiC5Q/1NVJEDNWv1dkmLF1DRd0= 27 | distributions: "sdist bdist_wheel" 28 | skip_existing: true 29 | on: 30 | tags: true 31 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | exclude experimental/ -------------------------------------------------------------------------------- /ROADMAP.md: -------------------------------------------------------------------------------- 1 | # Roadmap 2 | 3 | ## v0.4.0 4 | 5 | - [ ] pretrained models on 6 | - Ade20k 7 | - Taco 8 | - CVC-ClinicDB 9 | - [ ] New models 10 | - [DeeplabV3](https://arxiv.org/abs/1706.05587) - Resnet50 11 | - [DeeplabV3+](https://arxiv.org/abs/1802.02611) - Xception and Mobile 12 | 13 | ## v0.5.0 14 | 15 | - [ ] performance metrics 16 | - training speed (images per second) 17 | - maximum batch size 18 | - [ ] model comparisons with paper results on PascalVoc -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:2.4.0-gpu-py3 2 | 3 | ADD requirements.txt requirements.txt 4 | 5 | RUN apt-get update && apt-get install -y libsm6 libxext6 libxrender-dev libyaml-dev libpython3-dev 6 | RUN pip install -r requirements.txt 7 | 8 | # install tf addons 9 | RUN pip install tensorflow-addons==0.12.0 10 | RUN apt-get install -y fish tmux curl htop screen 11 | 12 | WORKDIR /root 13 | 14 | # expose tensorboard port 15 | EXPOSE 6006 16 | 17 | # ADD run.sh run.sh 18 | COPY tf_semantic_segmentation/ tf_semantic_segmentation 19 | 20 | # hack for rtx cards to work with tf 2.0, otherwise pooling operation will fail 21 | # see: https://github.com/AlexEMG/DeepLabCut/issues/1 22 | ENV TF_FORCE_GPU_ALLOW_GROWTH 'true' 23 | # ARG record_tag="" 24 | # RUN test -z "$record_tag" || python -m tf_semantic_segmentation.bin.tfrecord_download -t ${record_tag} -r /hdd/datasets/downloaded/${record_tag} 25 | CMD fish -------------------------------------------------------------------------------- /docker/tf20/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM 
tensorflow/tensorflow:2.0.0-gpu-py3
 2 | 
 3 | ADD requirements.txt requirements.txt
 4 | 
 5 | RUN apt-get update && apt-get install -y libsm6 libxext6 libxrender-dev libyaml-dev libpython3-dev
 6 | RUN pip install -r requirements.txt
 7 | 
 8 | # install tf addons
 9 | RUN pip install tensorflow-addons==0.6.0
10 | RUN apt-get install -y fish tmux curl htop screen
11 | 
12 | WORKDIR /root
13 | 
14 | # expose tensorboard port
15 | EXPOSE 6006
16 | 
17 | # ADD run.sh run.sh
18 | COPY tf_semantic_segmentation/ tf_semantic_segmentation
19 | 
20 | # hack for rtx cards to work with tf 2.0, otherwise pooling operation will fail
21 | # see: https://github.com/AlexEMG/DeepLabCut/issues/1
22 | ENV TF_FORCE_GPU_ALLOW_GROWTH 'true'
23 | # ARG record_tag=""
24 | # RUN test -z "$record_tag" || python -m tf_semantic_segmentation.bin.tfrecord_download -t ${record_tag} -r /hdd/datasets/downloaded/${record_tag}
25 | CMD fish
--------------------------------------------------------------------------------
/docs/activations.md:
--------------------------------------------------------------------------------
 1 | - mish
 2 | - swish
 3 | - relu6
 4 | 
 5 | ```python
 6 | # import to use as a keras activation
 7 | from tf_semantic_segmentation import activations
 8 | from tensorflow.keras.layers import Activation
 9 | 
10 | act = Activation("swish")
11 | ```
--------------------------------------------------------------------------------
/docs/datasets.md:
--------------------------------------------------------------------------------
 1 | - Ade20k
 2 | - Camvid
 3 | - Cityscapes
 4 | - MappingChallenge
 5 | - MotsChallenge
 6 | - Coco
 7 | - PascalVoc2012
 8 | - Taco
 9 | - Shapes (randomly creating triangles, rectangles and circles)
10 | - Toy (Overlaying TinyImageNet with MNIST)
11 | - ISIC2018
12 | - CVC-ClinicDB
13 | 
14 | ### Using Dataset (tacobinary)
15 | 
16 | ```python
17 | from tf_semantic_segmentation.datasets import get_dataset_by_name, datasets_by_name, DataType, get_cache_dir
18 | 
19 | # print available dataset names
20 | print(list(datasets_by_name.keys()))
21 | 
22 | # get the binary (waste or not) dataset
23 | data_dir = '/hdd/data/'
24 | name = 'tacobinary'
25 | cache_dir = get_cache_dir(data_dir, name.lower())
26 | ds = get_dataset_by_name(name, cache_dir)
27 | 
28 | # print labels and classes
29 | print(ds.labels)
30 | print(ds.num_classes)
31 | 
32 | # print number of training examples
33 | print(ds.num_examples(DataType.TRAIN))
34 | 
35 | # or simply print the summary
36 | ds.summary()
37 | ```
38 | 
39 | ### Debug Datasets
40 | 
41 | #### Visualize
42 | ```bash
43 | python -m tf_semantic_segmentation.debug.dataset_vis -d ade20k -c '/tmp/data'
44 | ```
45 | 
46 | #### Analyse
47 | 
48 | ```bash
49 | python -m tf_semantic_segmentation.bin.tfrecord_analyser -r records/ --mean
50 | ```
51 | 
52 | ### Create TFRecord
53 | 
54 | #### Using Code
55 | 
56 | ```python
57 | from tf_semantic_segmentation.datasets import TFWriter
58 | ds = ...
59 | writer = TFWriter(record_dir)
60 | writer.write(ds)
61 | writer.validate(ds)
62 | ```
63 | 
64 | #### Inbuilt dataset
65 | 
66 | ```shell
67 | # e.g. cvc_clinicdb; see datasets_by_name for all inbuilt dataset names
68 | tf-semantic-segmentation-tfrecord-writer -d 'cvc_clinicdb' -c '/tmp/data'
69 | ```
70 | 
71 | 
72 | #### From Custom Dataset
73 | 
74 | ```shell
75 | INPUT_DIR=...
76 | tf-semantic-segmentation-tfrecord-writer -dir $INPUT_DIR -r $INPUT_DIR/records
77 | ```
78 | 
79 | The following additional arguments are available:
80 | 
81 | - -s [--size] '$width,$height' (e.g. "512,512")
82 | - -rm [--resize_method] ('resize', 'resize_with_pad', 'resize_with_crop_or_pad')
83 | - -cm [--color_mode] (0=RGB, 1=GRAY, 2=NONE (default))
84 | 
85 | 
86 | 
87 | ### Use your own dataset
88 | 
89 | - Accepted file types are: jpg (jpeg) and png
90 | 
91 | If you already have a train/test/val split, use the following directory structure:
92 | 
93 | ```text
94 | dataset/
95 |     labels.txt
96 |     test/
97 |         images/
98 |         masks/
99 |     train/
100 |         images/
101 |         masks/
102 |     val/
103 |         images/
104 |         masks/
105 | ```
106 | 
107 | or use
108 | 
109 | ```text
110 | dataset/
111 |     labels.txt
112 |     images/
113 |     masks/
114 | ```
115 | 
116 | The labels.txt should contain the list of labels, one label per line (separated by a newline `\n`). For instance:
117 | 
118 | ```text
119 | background
120 | car
121 | pedestrian
122 | ```
123 | 
124 | #### Create TFRecord
125 | 
126 | ```shell
127 | INPUT_DIR=...
128 | tf-semantic-segmentation-tfrecord-writer -dir $INPUT_DIR -r $INPUT_DIR/records
129 | ```
130 | 
131 | The following additional arguments are available:
132 | 
133 | - -s [--size] '$width,$height' (e.g. "512,512")
134 | - -rm [--resize_method] ('resize', 'resize_with_pad', 'resize_with_crop_or_pad')
135 | - -cm [--color_mode] (0=RGB, 1=GRAY, 2=NONE (default))
--------------------------------------------------------------------------------
/docs/docker.md:
--------------------------------------------------------------------------------
 1 | #### Build
 2 | 
 3 | ```shell
 4 | docker build -t tf_semantic_segmentation -f docker/Dockerfile ./
 5 | ```
 6 | 
 7 | #### Pull
 8 | 
 9 | or pull the latest release
10 | 
11 | ```shell
12 | docker pull baudcode/tf_semantic_segmentation:latest
13 | ```
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
 1 | [![Build Status](https://travis-ci.org/baudcode/tf-semantic-segmentation.svg?branch=master)](https://travis-ci.org/baudcode/tf-semantic-segmentation)
 2 | [![PyPI Status Badge](https://badge.fury.io/py/tf-semantic-segmentation.svg)](https://pypi.org/project/tf-semantic-segmentation/)
 3 | [![codecov](https://codecov.io/gh/baudcode/tf-semantic-segmentation/branch/dev/graph/badge.svg)](https://codecov.io/gh/baudcode/tf-semantic-segmentation)
 4 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1xBH4WxhJ7TlnC7pck7ifLjo3NrdYmug-)
 5 | 
 6 | ## Requirements
 7 | 
 8 | ```shell
 9 | sudo apt-get install libsm6 libxext6 libxrender-dev libyaml-dev libpython3-dev
10 | ```
11 | 
12 | #### Tensorflow (2.x) & Tensorflow Addons (optional)
13 | 
14 | ```shell
15 | pip install tensorflow-gpu==2.4.0 --upgrade
16 | pip install tensorflow-addons==0.12.0 --upgrade
17 | ```
18 | 
19 | ## Installation
20 | 
21 | ```shell
22 | pip install tf-semantic-segmentation
23 | ```
24 | 
25 | ## Features
26 | 
27 | - Fast and easy training/prediction on multiple datasets
28 | - Distributed Training on Multiple GPUs
29 | - Hyperparameter Optimization using WandB
30 | - WandB Integration
31 | - Easily create TFRecord from Directory
32 | - Tensorboard visualizations
33 | - Ensemble inference
34 | 
35 | 
36 | ### Datasets
37 | 
38 | - Ade20k
39 | - Camvid
40 | - Cityscapes
41 | - MappingChallenge
42 | - MotsChallenge
43 | - Coco
44 | - PascalVoc2012
45 | - Taco
46 | - Shapes (randomly creating triangles, rectangles and circles)
47 | - Toy (Overlaying TinyImageNet with MNIST)
48 | - ISIC2018
49 | - CVC-ClinicDB
50 | 
51 | 
52 | 
53 | ### Models
54 | 
55 | - [U2Net / U2NetP](https://arxiv.org/abs/2005.09007)
56 | - [Unet](https://arxiv.org/abs/1505.04597)
57 | - [PSP](https://arxiv.org/abs/1612.01105)
58 | - [FCN](https://arxiv.org/abs/1411.4038)
59 | - [Erfnet](https://arxiv.org/abs/1806.08522)
60 | - [MultiResUnet](https://arxiv.org/abs/1902.04049)
61 | - [NestedUnet (Unet++)](https://arxiv.org/abs/1807.10165)
62 | - SatelliteUnet
63 | - MobilenetUnet (unet with mobilenet encoder pre-trained on imagenet)
64 | - InceptionResnetV2Unet (unet with inception-resnet v2 encoder pre-trained on imagenet)
65 | - ResnetUnet (unet with resnet50 encoder pre-trained on imagenet)
66 | - AttentionUnet
67 | 
68 | ### Losses
69 | 
70 | - Categorical Crossentropy
71 | - Binary Crossentropy
72 | - Crossentropy + SSIM
73 | - Dice
74 | - Crossentropy + Dice
75 | - Tversky
76 | - Focal
77 | - Focal + Tversky
78 | 
79 | ### Metrics
80 | 
81 | - f1
82 | - f2
83 | - iou
84 | - precision
85 | - recall
86 | - psnr
87 | - ssim
88 | 
89 | ### Activations
90 | 
91 | - mish
92 | - swish
93 | - relu6
94 | 
95 | ### Optimizers
96 | 
97 | - Ranger
98 | - RAdam
99 | 
100 | ### Normalization
101 | 
102 | - Instance Norm
103 | - Batch Norm
104 | 
105 | ### Augmentations
106 | 
107 | - flip left/right
108 | - flip up/down
109 | - rot 180
110 | - color
--------------------------------------------------------------------------------
/docs/metrics.md:
--------------------------------------------------------------------------------
 1 | ### Losses
 2 | 
 3 | - Categorical Crossentropy
 4 | - Binary Crossentropy
 5 | - Crossentropy + SSIM
 6 | - Dice
 7 | - Crossentropy + Dice
 8 | - Tversky
 9 | - Focal
10 | - Focal + Tversky
11 | 
12 | ##### Code
13 | 
14 | ```python
15 | from tf_semantic_segmentation.losses import get_loss_by_name, losses_by_name
16 | 
17 | losses = list(losses_by_name.keys())
18 | for l in losses:
19 |     fn = get_loss_by_name(l)
20 |     value = fn(y_true, y_pred)
21 | ```
22 | 
23 | ### Metrics
24 | 
25 | - f1
26 | - f2
27 | - iou
28 | - precision
29 | - recall
30 | - psnr
31 | - ssim
32 | 
33 | #### Using Code
34 | 
35 | ```python
36 | from tf_semantic_segmentation.metrics import get_metric_by_name, metrics_by_name
37 | 
38 | metrics = list(metrics_by_name.keys())
39 | for m in metrics:
40 |     fn = get_metric_by_name(m)
41 |     value = fn(y_true, y_pred)
42 | ```
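The loops above assume `y_true` and `y_pred` already exist. Below is a self-contained sketch in the style of the project's tests; note that the `'dice'` key is an assumption, check `losses_by_name` for the exact names:

```python
import numpy as np
import tensorflow as tf
from tf_semantic_segmentation.losses import get_loss_by_name

# fake a one-hot encoded batch of masks: (batch, height, width, num_classes)
num_classes = 5
mask = np.random.randint(0, num_classes, (1, 32, 32))
y_true = tf.one_hot(mask, num_classes)
y_pred = tf.identity(y_true)

loss_fn = get_loss_by_name('dice')
print(loss_fn(y_true, y_pred).numpy())  # identical inputs, so the loss should be close to 0
```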
--------------------------------------------------------------------------------
/docs/models.md:
--------------------------------------------------------------------------------
 1 | - [Unet](https://arxiv.org/abs/1505.04597)
 2 | - [Erfnet](https://arxiv.org/abs/1806.08522)
 3 | - [MultiResUnet](https://arxiv.org/abs/1902.04049)
 4 | - [PSP](https://arxiv.org/abs/1612.01105) (experimental)
 5 | - [FCN](https://arxiv.org/abs/1411.4038) (experimental)
 6 | - [NestedUnet (Unet++)](https://arxiv.org/abs/1807.10165) (experimental)
 7 | - [U2Net / U2NetP](https://arxiv.org/abs/2005.09007) (experimental)
 8 | - SatelliteUnet
 9 | - MobilenetUnet (unet with mobilenet encoder pre-trained on imagenet)
10 | - InceptionResnetV2Unet (unet with inception-resnet v2 encoder pre-trained on imagenet)
11 | - ResnetUnet (unet with resnet50 encoder pre-trained on imagenet)
12 | - AttentionUnet
13 | 
14 | ```python
15 | from tf_semantic_segmentation import models
16 | 
17 | # print all available models
18 | print(list(models.models_by_name.keys()))
19 | 
20 | # returns a model (without the final activation function)
21 | model = models.get_model_by_name('erfnet', {"input_shape": (128, 128, 3), "num_classes": 5})
22 | 
23 | # call models directly
24 | model = models.erfnet(input_shape=(128, 128, 3), num_classes=5)
25 | ```
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
 1 | mkdocs
 2 | mkdocs-material
--------------------------------------------------------------------------------
/docs/tflite.md:
--------------------------------------------------------------------------------
 1 | #### Convert the model
 2 | 
 3 | ```shell
 4 | python -m tf_semantic_segmentation.bin.convert_tflite -i logs/mymodel/saved_model/0/ -o model.tflite
 5 | ```
 6 | 
 7 | #### Test inference on the model
 8 | 
 9 | ```shell
10 | python -m tf_semantic_segmentation.debug.tflite_test -m model.tflite -i Harris_Sparrow_0001_116398.jpg
11 | ```
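To run the converted model from your own code, the standard `tf.lite.Interpreter` API is sufficient. A rough sketch (input size, dtype and normalization depend on how the model was trained and converted):

```python
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# dummy input; replace this with a preprocessed image batch
batch = np.zeros(input_details[0]['shape'], dtype=input_details[0]['dtype'])

interpreter.set_tensor(input_details[0]['index'], batch)
interpreter.invoke()

predictions = interpreter.get_tensor(output_details[0]['index'])
print(predictions.shape)
```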
--------------------------------------------------------------------------------
/docs/usage.md:
--------------------------------------------------------------------------------
 1 | ## Create Dataset for fast training
 2 | 
 3 | - The following command will create the dataset `cvc_clinicdb` in `/tmp/data` and the records
 4 | in `/tmp/data/cvc_clinicdb/records`.
 5 | ```shell
 6 | python -m tf_semantic_segmentation.bin.tfrecord_writer -d cvc_clinicdb -c '/tmp/data'
 7 | ```
 8 | - For using your own dataset please refer to the [datasets](/datasets) section.
 9 | 
10 | 
11 | ## Training
12 | 
13 | - Hint: to see train/test/val images, start tensorboard like this:
14 | 
15 | ```bash
16 | tensorboard --logdir=logs/ --reload_multifile=true
17 | ```
18 | 
19 | #### On inbuilt datasets - generator (slow)
20 | 
21 | ```bash
22 | python -m tf_semantic_segmentation.bin.train -ds 'cvc_clinicdb' -bs 8 -e 100 \
23 |     -logdir 'logs/taco-binary-test' -o 'adam' -lr 5e-3 --size 256,256 \
24 |     -l 'binary_crossentropy' -fa 'sigmoid' \
25 |     --train_on_generator --gpus='0' \
26 |     --tensorboard_train_images --tensorboard_val_images
27 | ```
28 | 
29 | #### Using a fixed record path (fast)
30 | 
31 | ```bash
32 | python -m tf_semantic_segmentation.bin.train --record_dir="/tmp/data/cvc_clinicdb/records" \
33 |     -bs 4 -e 100 -logdir 'logs/cvc-adam-1e-4-mish-bs4' -o 'adam' -lr 1e-4 -l 'categorical_crossentropy' \
34 |     -fa 'softmax' -bufsize 50 --metrics='iou_score,f1_score' -m 'unet' --gpus='0' -a 'mish' \
35 |     --tensorboard_train_images --tensorboard_val_images
36 | ```
37 | 
38 | #### Multi GPU training
39 | 
40 | ```bash
41 | python -m tf_semantic_segmentation.bin.train --record_dir="/tmp/data/cvc_clinicdb/records" \
42 |     -bs 4 -e 100 -logdir 'logs/cvc-adam-1e-4-mish-bs4' -o 'adam' -lr 1e-4 -l 'categorical_crossentropy' \
43 |     -fa 'softmax' -bufsize 50 --metrics='iou_score,f1_score' -m 'unet' --gpus='0,1,2,3' -a 'mish'
44 | ```
45 | 
46 | ## Using Code
47 | 
48 | ```python
49 | from tf_semantic_segmentation.bin.train import train_test_model, get_args
50 | 
51 | # get the default args
52 | args = get_args({})
53 | 
54 | # change some parameters
55 | # !rm -r logs/
56 | args.model = 'unet'
57 | # args['color_mode'] = 0
58 | args.batch_size = 8
59 | args.size = [128, 128]  # resize input dataset to this size
60 | args.epochs = 10
61 | args.learning_rate = 1e-4
62 | args.optimizer = 'adam'  # ['adam', 'radam', 'ranger']
63 | args.loss = 'dice'
64 | args.logdir = 'logs'
65 | args.record_dir = "/tmp/data/cvc_clinicdb/records"
66 | args.final_activation = 'softmax'
67 | 
68 | # train and test
69 | results, model = train_test_model(args)
70 | results['evaluate']  # returns last evaluated results using val dataset
71 | results['history']  # returns the history object from model.fit()
72 | ```
--------------------------------------------------------------------------------
/examples/Train.ipynb:
--------------------------------------------------------------------------------
  1 | {
  2 |  "cells": [
  3 |   {
  4 |    "cell_type": "markdown",
  5 |    "metadata": {},
  6 |    "source": [
  7 |     "### Simply train one of the shipped models"
  8 |    ]
  9 |   },
 10 |   {
 11 |    "cell_type": "code",
 12 |    "execution_count": null,
 13 |    "metadata": {
 14 |     "scrolled": true
 15 |    },
 16 |    "outputs": [],
 17 |    "source": [
 18 |     "from tf_semantic_segmentation.bin import train\n",
 19 |     "from pprint import pprint\n",
 20 |     "# get the default arguments\n",
 21 |     "args = train.get_args({})\n",
 22 |     "\n",
 23 |     "args.model = 'unet'\n",
 24 |     "args.batch_size = 4\n",
 25 |     "args.epochs = 20\n",
 26 |     "args.gpus = [0]  # list of all gpus to use for training\n",
 27 |     "args.color_mode = 0  # 0 = RGB, 1 = GRAY\n",
 28 |     "args.size = [512, 512]  # height, width\n",
 29 |     "\n",
 30 |     "# optimizer with learning rate\n",
 31 |     "args.optimizer = 'ranger'\n",
 32 |     "args.learning_rate = 1e-4\n",
 33 |     "\n",
 34 |     "# define the loss with metrics\n",
 35 |     "args.loss = 'dice_binary_crossentropy'\n",
 36 |     "args.final_activation = 'sigmoid'\n",
 37 |     "args.metrics = ['iou_score', 'f1_score']\n",
 38 |     "\n",
 39 |     "# use the taco binary dataset (0=no_plastic, 1=plastic)\n",
 40 |     "args.dataset = 'tacobinary'\n",
 41 |     "args.train_on_generator = True\n",
 42 |     "args.buffer_size = 50  # increase this depending on how much RAM you have available\n",
 43 |     "\n",
 44 |     "# tensorboard parameters\n",
 45 |     "args.tensorboard_train_images_update_batch_freq = 100  # show predictions during training, every 100 steps\n",
 46 |     "args.tensorboard_val_images = True  # show val predictions at the end of every epoch\n",
 47 |     "\n",
 48 |     "pprint(vars(args))\n",
 49 |     "\n",
 50 |     "\n",
 51 |     "# start training\n",
 52 |     "train.train_test_model(args)"
 53 |    ]
 54 |   },
 55 |   {
 56 |    "cell_type": "markdown",
 57 |    "metadata": {},
 58 |    "source": [
 59 |     "### Train your own model"
 60 |    ]
 61 |   },
 62 |   {
 63 |    "cell_type": "code",
 64 |    "execution_count": null,
 65 |    "metadata": {},
 66 |    "outputs": [],
 67 |    "source": [
 68 |     "from tensorflow.keras.models import Model\n",
 69 |     "from tensorflow.keras.layers import Input, Conv2D\n",
 70 |     "\n",
 71 |     "def model_fn(input_shape=(256, 256, 3), num_classes=2):\n",
 72 |     "    i = Input(input_shape)\n",
 73 |     "    y = Conv2D(num_classes, kernel_size=(1, 1))(i)\n",
 74 |     "    return Model(inputs=i, outputs=y)\n",
 75 |     "\n",
 76 |     "\n",
 77 |     "args.model = model_fn\n",
 78 |     "train.train_test_model(args)"
 79 |    ]
 80 |   }
 81 |  ],
 82 |  "metadata": {
 83 |   "kernelspec": {
 84 |    "display_name": "Python 3",
 85 |    "language": "python",
 86 |    "name": "python3"
 87 |   },
 88 |   "language_info": {
 89 |    "codemirror_mode": {
 90 |     "name": "ipython",
 91 |     "version": 3
 92 |    },
 93 |    "file_extension": ".py",
 94 |    "mimetype": "text/x-python",
 95 |    "name": "python",
 96 |    "nbconvert_exporter": "python",
 97 |    "pygments_lexer": "ipython3",
 98 |    "version": "3.7.6"
 99 |   }
100 |  },
101 |  "nbformat": 4,
102 |  "nbformat_minor": 4
103 | }
104 | 
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
 1 | site_name: TF Semantic Segmentation Documentation
 2 | theme:
 3 |   name: material
 4 | repo_name: baudcode/tf-semantic-segmentation
 5 | repo_url: https://github.com/baudcode/tf-semantic-segmentation
 6 | extra:
 7 |   social:
 8 |     - icon: fontawesome/brands/github-alt
 9 |       link: 
https://github.com/baudcode 10 | markdown_extensions: 11 | - admonition 12 | nav: 13 | - Introduction: 'index.md' 14 | - Usage: 'usage.md' 15 | - Datasets: 'datasets.md' 16 | - Models: 'models.md' 17 | - Metrics and Losses: "metrics.md" 18 | - Custom Activations: "activations.md" 19 | - Predict and Ensemble: "predict.md" 20 | - TFLite: 'tflite.md' 21 | - Docker: 'docker.md' -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | requests 2 | imageio 3 | opencv-python 4 | wandb 5 | tqdm 6 | scipy 7 | xmltodict 8 | pillow 9 | pytz 10 | pyyaml -------------------------------------------------------------------------------- /scripts/docker-compose-install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | sudo apt update && sudo apt upgrade -y 4 | sudo apt install -y apt-transport-https ca-certificates curl software-properties-common 5 | 6 | # install docker 7 | curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add - 8 | sudo add-apt-repository "deb [arch=amd64] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable" 9 | sudo apt update 10 | sudo apt install -y docker-ce 11 | sudo systemctl status docker 12 | docker -v 13 | sudo usermod -aG docker $USER 14 | 15 | # make sure docker runs hello-world 16 | docker container run hello-world 17 | 18 | # install docker compose 19 | sudo curl -L "https://github.com/docker/compose/releases/download/1.23.1/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose 20 | sudo chmod +x /usr/local/bin/docker-compose 21 | docker-compose --version 22 | 23 | -------------------------------------------------------------------------------- /scripts/nvidia-docker-setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | 4 | sudo apt install curl servefile fish 5 | curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add - 6 | sudo apt-key fingerprint 0EBFCD88 7 | sudo add-apt-repository "deb [arch=amd64] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable" 8 | sudo apt-get update 9 | sudo apt-get install docker-ce 10 | sudo service docker restart 11 | 12 | # If you have nvidia-docker 1.0 installed: we need to remove it and all existing GPU containers 13 | docker volume ls -q -f driver=nvidia-docker | xargs -r -I{} -n1 docker ps -q -a -f volume={} | xargs -r docker rm -f 14 | # sudo apt-get purge -y nvidia-docker 15 | 16 | # Add the package repositories 17 | curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | \ 18 | sudo apt-key add - 19 | distribution=$(. 
/etc/os-release;echo $ID$VERSION_ID) 20 | curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | \ 21 | sudo tee /etc/apt/sources.list.d/nvidia-docker.list 22 | sudo apt-get update 23 | 24 | # Install nvidia-docker2 and reload the Docker daemon configuration 25 | sudo apt-get install -y nvidia-docker2 26 | sudo pkill -SIGHUP dockerd 27 | 28 | # Test nvidia-smi with the latest official CUDA image 29 | docker run --runtime=nvidia --rm nvidia/cuda:10.0-base nvidia-smi 30 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [pep8] 2 | max_line_length = 200 3 | ignore = E265,E266,E722,E741,W605 4 | 5 | [tool:pytest] 6 | pep8ignore = E265 E266 E722 E741 W605 7 | pep8maxlinelength = 200 8 | 9 | 10 | [pycodestyle] 11 | exclude = .vscode,.pytest_cache/,__pycache__/,.env/,.eggs/,build/,dist/, tf_semantic_segmentation.egg-info/ 12 | max_line_length = 200 13 | ignore = E265,E266,E722,E741,W605 14 | statistics = True -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.command.install import install 2 | from distutils.command.clean import clean 3 | from distutils.util import convert_path 4 | import setuptools 5 | import distutils 6 | import os 7 | import glob 8 | import shutil 9 | 10 | # See: https://stackoverflow.com/questions/2058802/how-can-i-get-the-version-defined-in-setup-py-setuptools-in-my-package/2073599#2073599 11 | package_name = "tf_semantic_segmentation" 12 | 13 | main_ns = {} 14 | ver_path = convert_path('%s/version.py' % package_name) 15 | with open(ver_path) as ver_file: 16 | exec(ver_file.read(), main_ns) 17 | 18 | with open("requirements.txt", 'r') as h: 19 | requirements = [r.replace("\n", "") for r in h.readlines()] 20 | 21 | here = os.path.dirname(__name__) 22 | 23 | 24 | class CleanCommand(clean): 25 | """Custom clean command to tidy up the project root.""" 26 | CLEAN_FILES = './build ./dist ./.eggs ./*.pyc ./*.tgz ./*.egg-info'.split( 27 | ' ') 28 | 29 | user_options = [] 30 | 31 | def initialize_options(self): 32 | pass 33 | 34 | def finalize_options(self): 35 | pass 36 | 37 | def run(self): 38 | global here 39 | 40 | for path_spec in self.CLEAN_FILES: 41 | # Make paths absolute and relative to this path 42 | abs_paths = glob.glob(os.path.normpath( 43 | os.path.join(here, path_spec))) 44 | for path in [str(p) for p in abs_paths]: 45 | if not path.startswith(here): 46 | # Die if path in CLEAN_FILES is absolute + outside this directory 47 | raise ValueError( 48 | "%s is not a path inside %s" % (path, here)) 49 | print('removing %s' % os.path.relpath(path)) 50 | shutil.rmtree(path) 51 | 52 | 53 | with open("README.md", "r") as fh: 54 | long_description = fh.read() 55 | 56 | setuptools.setup( 57 | name=package_name, 58 | version=main_ns['__version__'], 59 | description='Implementation of various semantic segmentation models in tensorflow & keras including popular datasets', 60 | author='Malte Koch', 61 | license='MIT', 62 | long_description=long_description, 63 | long_description_content_type="text/markdown", 64 | keywords=["keras", "tensorflow", "%s" % package_name, "semantic", "segmentation", "ade20k", "coco", "pascalvoc", "cityscapes"], 65 | author_email='malte-koch@gmx.net', 66 | maintainer='Malte Koch', 67 | maintainer_email='malte-koch@gmx.net', 68 | 
url="https://github.com/baudcode/tf-semantic-segmentation",
69 |     cmdclass={"clean": CleanCommand},
70 |     # namespace_packages=[package_name],
71 |     packages=setuptools.find_packages(include=package_name + "/*"),
72 |     # packages=setuptools.find_namespace_packages(exclude=['tests', 'tests.*', "experimental", "experimental/*"]),
73 |     install_requires=requirements,
74 |     entry_points={
75 |         'console_scripts': [
76 |             "tf-semantic-segmentation-train=tf_semantic_segmentation.bin.train:main",
77 |             "tf-semantic-segmentation-predict=tf_semantic_segmentation.evaluation.predict:main",
78 |             "tf-semantic-segmentation-tfrecord-writer=tf_semantic_segmentation.bin.tfrecord_writer:main",
79 |             "tf-semantic-segmentation-tfrecord-analyser=tf_semantic_segmentation.bin.tfrecord_analyser:main",
80 |             "tf-semantic-segmentation-tfrecord-download=tf_semantic_segmentation.bin.tfrecord_download:main",
81 |             "tf-semantic-segmentation-compare-models=tf_semantic_segmentation.evaluation.compare_models:main",
82 |         ],
83 |     },
84 |     ext_modules=[],
85 |     setup_requires=[],
86 |     classifiers=[
87 |         # Choose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
88 |         'Development Status :: 3 - Alpha',
89 |         'Intended Audience :: Developers',
90 |         'Operating System :: POSIX :: Linux',
91 |         'License :: OSI Approved :: GNU General Public License v3 (GPLv3)',
92 |         'Programming Language :: Python :: 3',
93 |     ],
94 | )
95 | 
--------------------------------------------------------------------------------
/sweep.example.yaml:
--------------------------------------------------------------------------------
 1 | program: tf_semantic_segmentation/bin/train.py
 2 | method: random
 3 | metric:
 4 |   name: val_iou_score
 5 |   goal: maximize
 6 | parameters:
 7 |   learning_rate:
 8 |     min: 0.001
 9 |     max: 0.1
10 |   optimizer:
11 |     values: ["adam", "radam", "ranger"]
12 |   activation:
13 |     values: ["relu", "mish", "swish"]
14 |   buffer_size:
15 |     values: [50]
16 |   loss:
17 |     values: ["binary_crossentropy"]
18 |   final_activation:
19 |     values: ["sigmoid"]
20 |   epochs:
21 |     values: [100]
22 |   batch_size:
23 |     values: [8, 16]
24 |   record_dir:
25 |     values: ["/hdd/datasets/taco/records/tacobinary-256x256-resize"]
26 |   wandb_project:
27 |     values: ["tacobinary-256x256-resize"]
28 |   gpus:
29 |     values: [""]
30 | 
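Running this sweep follows the usual wandb workflow; a sketch, where the sweep id placeholder below is printed by the first command:

```shell
wandb sweep sweep.example.yaml
wandb agent <entity>/<project>/<sweep_id>
```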
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baudcode/tf-semantic-segmentation/6e425287f309f86b931e34bbbb804bbc46963e28/tests/__init__.py
--------------------------------------------------------------------------------
/tests/bin/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baudcode/tf-semantic-segmentation/6e425287f309f86b931e34bbbb804bbc46963e28/tests/bin/__init__.py
--------------------------------------------------------------------------------
/tests/bin/tfrecord_writer_test.py:
--------------------------------------------------------------------------------
 1 | from tf_semantic_segmentation.debug import dataset_export
 2 | from tf_semantic_segmentation.bin import tfrecord_writer
 3 | from tf_semantic_segmentation.datasets import TFReader, DataType
 4 | from tf_semantic_segmentation.processing import ColorMode
 5 | from ..fixtures import dataset
 6 | import tempfile
 7 | import shutil
 8 | import os
 9 | 
10 | 
11 | def test_export(dataset):
12 |     output_dir = tempfile.mkdtemp()
13 |     record_dir = os.path.join(output_dir, 'records')
14 | 
15 |     dataset_export.export(dataset, output_dir, size=None, color_mode=ColorMode.RGB, overwrite=True)
16 |     tfrecord_writer.write_records_from_directory(output_dir, record_dir)
17 |     reader = TFReader(record_dir)
18 |     tfds = reader.get_dataset(DataType.TRAIN)
19 | 
20 |     for image, mask, num_classes in tfds:
21 |         assert(image.numpy().shape == (64, 64, 3))
22 |         assert(mask.numpy().shape == (64, 64))
23 |         assert(num_classes.numpy() == dataset.num_classes)
24 |         break
25 | 
26 |     dataset_export.export(dataset, output_dir, size=(32, 32), color_mode=ColorMode.RGB, overwrite=True)
27 |     tfrecord_writer.write_records_from_directory(output_dir, record_dir, overwrite=True)
28 |     reader = TFReader(record_dir)
29 |     tfds = reader.get_dataset(DataType.TRAIN)
30 | 
31 |     for image, mask, num_classes in tfds:
32 |         assert(image.numpy().shape == (32, 32, 3))
33 |         assert(mask.numpy().shape == (32, 32))
34 |         assert(num_classes.numpy() == dataset.num_classes)
35 |         break
36 | 
37 |     dataset_export.export(dataset, output_dir, size=(31, 31), color_mode=ColorMode.GRAY, overwrite=True)
38 |     tfrecord_writer.write_records_from_directory(output_dir, record_dir, overwrite=True)
39 |     reader = TFReader(record_dir)
40 |     tfds = reader.get_dataset(DataType.TRAIN)
41 | 
42 |     for image, mask, num_classes in tfds:
43 |         assert(image.numpy().shape == (31, 31, 1))
44 |         assert(mask.numpy().shape == (31, 31))
45 |         assert(num_classes.numpy() == dataset.num_classes)
46 |         break
47 | 
48 |     shutil.rmtree(output_dir)
49 | 
50 | 
51 | def test_export_dataset_by_name():
52 |     # test export dataset
53 |     output_dir = tempfile.mkdtemp()
54 |     record_dir = os.path.join(output_dir, 'records')
55 |     name = 'shapesmini'
56 |     expected_num_classes = 2
57 | 
58 |     record_dir = tfrecord_writer.write_records_from_dataset_name(name, output_dir, size=None, overwrite=True)
59 |     reader = TFReader(record_dir)
60 |     tfds = reader.get_dataset(DataType.TRAIN)
61 | 
62 |     for image, mask, num_classes in tfds:
63 |         assert(image.numpy().shape == (32, 32, 1))
64 |         assert(mask.numpy().shape == (32, 32))
65 |         assert(num_classes.numpy() == expected_num_classes)
66 |         break
67 | 
68 |     record_dir = tfrecord_writer.write_records_from_dataset_name(name, output_dir, size=(32, 32), overwrite=True)
69 |     reader = TFReader(record_dir)
70 |     tfds = reader.get_dataset(DataType.TRAIN)
71 | 
72 |     for image, mask, num_classes in tfds:
73 |         assert(image.numpy().shape == (32, 32, 1))
74 |         assert(mask.numpy().shape == (32, 32))
75 |         assert(num_classes.numpy() == expected_num_classes)
76 |         break
77 | 
78 |     record_dir = tfrecord_writer.write_records_from_dataset_name(name, output_dir, size=(31, 31), color_mode=ColorMode.RGB, overwrite=True)
79 |     reader = TFReader(record_dir)
80 |     tfds = reader.get_dataset(DataType.TRAIN)
81 | 
82 |     for image, mask, num_classes in tfds:
83 |         assert(image.numpy().shape == (31, 31, 3))
84 |         assert(mask.numpy().shape == (31, 31))
85 |         assert(num_classes.numpy() == expected_num_classes)
86 |         break
87 | 
88 |     shutil.rmtree(output_dir)
89 | 
--------------------------------------------------------------------------------
/tests/bin/train_test.py:
--------------------------------------------------------------------------------
 1 | from tf_semantic_segmentation.bin import train
 2 | from ..fixtures import tfrecord_reader
 3 | import pytest
 4 | import tempfile
 5 | import shutil
 6 | 
 7 | 
 8 | @pytest.mark.usefixtures('tfrecord_reader')
 9 | def test_simple_train(tfrecord_reader):
10 |     args = train.get_args({})
11 |     args.validation_steps = 1
12 |     args.epochs = 1
13 |     args.steps_per_epoch = 1
14 |     args.batch_size = 1
15 |     args.buffer_size = 1
16 |     args.gpus = ""
17 |     args.record_dir = tfrecord_reader.record_dir
18 |     args.logdir = tempfile.mkdtemp()
19 |     data = train.train_test_model(args)
20 |     shutil.rmtree(args.logdir)
21 | 
--------------------------------------------------------------------------------
/tests/data.py:
--------------------------------------------------------------------------------
 1 | import imageio
 2 | import numpy as np
 3 | import tensorflow as tf
 4 | from numpy import random
 5 | 
 6 | NUM_CLASSES = 5
 7 | TEST_IMG = ((random.random((32, 32)) - 0.001) * NUM_CLASSES).astype(np.uint8)
 8 | TEST_IMG_BINARY = TEST_IMG.astype(np.float32) / float(NUM_CLASSES - 1)
 9 | TEST_IMG_BINARY = np.where(TEST_IMG_BINARY > 0.5, 1.0, 0.0)
10 | TEST_IMG_ONEHOT = tf.one_hot(TEST_IMG, NUM_CLASSES)
11 | TEST_BATCH = np.expand_dims(TEST_IMG_ONEHOT, axis=0)
12 | TEST_BATCH = TEST_BATCH.astype(np.float32)
13 | TEST_BATCH_BINARY = np.expand_dims(np.expand_dims(TEST_IMG_BINARY, axis=0), axis=-1)
14 | 
15 | 
16 | def load_large_mask_batch():
17 |     mask = imageio.imread('tests/test.png')
18 |     print(np.unique(mask))
19 |     print(mask.max(), mask.shape)
20 |     mask_onehot = tf.one_hot(mask, 35)
21 |     mask_batch = tf.expand_dims(mask_onehot, axis=0)
22 |     return mask_batch
23 | 
--------------------------------------------------------------------------------
/tests/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baudcode/tf-semantic-segmentation/6e425287f309f86b931e34bbbb804bbc46963e28/tests/datasets/__init__.py
--------------------------------------------------------------------------------
/tests/datasets/test_dataset.py:
--------------------------------------------------------------------------------
 1 | from tf_semantic_segmentation.datasets import Dataset, DataType
 2 | import pytest
 3 | from ..fixtures import dataset
 4 | import types
 5 | 
 6 | 
 7 | @pytest.mark.usefixtures('dataset')
 8 | def test_dataset_structure(dataset):
 9 |     ds = dataset
10 |     assert(ds.num_examples(DataType.TRAIN) == 80)
11 |     assert(ds.num_examples(DataType.VAL) == 10)
12 |     assert(ds.num_examples(DataType.TEST) == 10)
13 |     assert(ds.total_examples() == 100)
14 |     assert(ds.num_classes == 2)
15 |     assert(all(map(lambda x: x in ds.raw().keys(), DataType.get())))
16 | 
17 |     image, mask = ds.get_random_item()
18 |     assert(image.shape == (64, 64, 3) and image.dtype == 'uint8')
19 |     assert(mask.shape == (64, 64) and mask.dtype == 'uint8')
20 |     assert(mask.max() == 1)
21 | 
22 |     g = ds.get()()
23 |     assert(isinstance(g, types.GeneratorType))
24 | 
--------------------------------------------------------------------------------
/tests/datasets/test_shapes.py:
--------------------------------------------------------------------------------
 1 | from tf_semantic_segmentation.datasets.shapes import DataType
 2 | import os
 3 | import tempfile
 4 | import numpy as np
 5 | import pytest
 6 | from ..fixtures import shapes_ds
 7 | 
 8 | 
 9 | @pytest.mark.usefixtures('shapes_ds')
10 | def test_shapes_dataset(shapes_ds):
11 | 
12 |     gen = shapes_ds.get()
13 |     inputs, targets = next(gen())
14 |     assert(len(targets.shape) == 2)
15 |     assert(targets.shape[:2] == inputs.shape[:2])
16 |     assert(targets.dtype == np.uint8)
17 |     assert(inputs.dtype == np.uint8)
18 | 
19 |     assert(shapes_ds.num_examples(DataType.TRAIN) == 80)
20 |     assert(shapes_ds.num_examples(DataType.TEST) == 10)
21 |     assert(shapes_ds.num_examples(DataType.VAL) == 10)
22 | 
--------------------------------------------------------------------------------
/tests/datasets/test_tfrecord.py:
--------------------------------------------------------------------------------
 1 | from tf_semantic_segmentation.datasets import tfrecord, DataType
 2 | from ..fixtures import dataset
 3 | 
 4 | import os
 5 | import tempfile
 6 | import numpy as np
 7 | import pytest
 8 | import time
 9 | import shutil
10 | 
11 | 
12 | def write_and_read_records(ds, options):
13 |     record_dir = os.path.join(tempfile.gettempdir(), str(time.time()), 'records')
14 |     writer = tfrecord.TFWriter(record_dir, options=options)
15 |     writer.write(ds)
16 | 
17 |     reader = tfrecord.TFReader(record_dir, options=options)
18 |     dataset = reader.get_dataset(DataType.TRAIN)
19 |     for image, mask, num_classes in dataset:
20 |         print(image.shape, mask.shape, num_classes)
21 |         break
22 | 
23 |     assert(writer.num_written(DataType.TRAIN) == ds.num_examples(DataType.TRAIN))
24 |     assert(writer.num_written(DataType.TEST) == ds.num_examples(DataType.TEST))
25 |     assert(writer.num_written(DataType.VAL) == ds.num_examples(DataType.VAL))
26 | 
27 |     assert(reader.num_examples(DataType.TRAIN) == writer.num_written(DataType.TRAIN))
28 |     assert(reader.num_examples(DataType.TEST) == writer.num_written(DataType.TEST))
29 |     assert(reader.num_examples(DataType.VAL) == writer.num_written(DataType.VAL))
30 |     shutil.rmtree(record_dir)
31 | 
32 | 
33 | def test_write_and_read_records_no_compression(dataset):
34 |     write_and_read_records(dataset, "")
35 | 
36 | 
37 | def test_write_and_read_records_gzip_compression(dataset):
38 |     write_and_read_records(dataset, "GZIP")
--------------------------------------------------------------------------------
/tests/datasets/test_utils.py:
--------------------------------------------------------------------------------
 1 | from tf_semantic_segmentation.datasets import utils, DataType
 2 | from ..fixtures import dataset
 3 | import pytest
 4 | import copy
 5 | import tempfile
 6 | import numpy as np
 7 | import imageio
 8 | import os
 9 | import shutil
10 | 
11 | 
12 | def test_all_data_types_exist():
13 |     assert(utils.DataType.get() == ['train', 'test', 'val'])
14 | 
15 | 
16 | @pytest.mark.usefixtures('dataset')
17 | def test_test_dataset(dataset):
18 |     image, mask, num_classes = utils.test_dataset(dataset)
19 | 
20 |     assert(num_classes == dataset.num_classes)
21 |     assert(image.shape == (64, 64, 3))
22 |     assert(mask.shape == (64, 64))
23 | 
24 | 
25 | @pytest.mark.usefixtures('dataset')
26 | def test_tfdataset(dataset):
27 |     tfds = utils.convert2tfdataset(dataset, DataType.TRAIN)
28 |     for image, mask, num_classes in tfds:
29 |         assert(image.numpy().shape == (64, 64, 3) and image.numpy().dtype == 'uint8')
30 |         assert(mask.numpy().shape == (64, 64) and mask.numpy().dtype == 'uint8')
31 |         assert(dataset.num_classes == num_classes.numpy())
32 |         break
33 | 
34 | 
35 | def test_splits():
36 |     # get_train_test_val_from_list
37 |     l = list(range(100))
38 |     train, test, val = utils.get_train_test_val_from_list(copy.deepcopy(l), train_split=0.8, val_split=0.5, shuffle=False)
39 |     assert(len(train) == 80 and len(test) == 10 and len(val) == 10)
40 |     assert(train == list(range(80)) and val == list(range(80, 90)) and test == list(range(90, 100)))
41 | 
42 |     train, test, val = utils.get_train_test_val_from_list(copy.deepcopy(l), train_split=0.8, val_split=0.5, shuffle=True)
43 |     assert(list(sorted(train + test + val)) == l)
44 |     train2, test2, val2 = utils.get_train_test_val_from_list(copy.deepcopy(l), train_split=0.8, val_split=0.5, shuffle=True, rand=lambda: 0.3)
45 |     assert(train2 != train and test2 != test and val2 != val)
46 | 
47 |     train, val = utils.get_split_from_list(copy.copy(l), split=0.7)
48 |     assert(train == list(range(70)) and val == list(range(70, 100)))
49 | 
50 |     images_dir = tempfile.mkdtemp()
51 |     masks_dir = tempfile.mkdtemp()
52 | 
53 |     for i in range(10):
54 |         imageio.imwrite(os.path.join(images_dir, '%d.png' % i), np.zeros((32, 32), np.uint8))
55 |         imageio.imwrite(os.path.join(masks_dir, '%d.png' % i), np.zeros((32, 32), np.uint8))
56 | 
57 |     split = utils.get_split_from_dirs(images_dir, masks_dir, extensions=['png'], train_split=0.8, val_split=0.5)
58 |     assert(all([dt in split.keys() for dt in DataType.get()]))
59 |     assert(len(split[DataType.TRAIN]) == 8 and len(split[DataType.VAL]) == 1 and len(split[DataType.TEST]) == 1)
60 | 
61 |     shutil.rmtree(images_dir)
62 |     shutil.rmtree(masks_dir)
63 | 
64 |     split = utils.get_split(copy.copy(l))
65 |     assert(all([dt in split.keys() for dt in DataType.get()]))
66 |     assert(len(split[DataType.TRAIN]) == 80 and len(split[DataType.VAL]) == 10 and len(split[DataType.TEST]) == 10)
67 | 
--------------------------------------------------------------------------------
/tests/debug/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baudcode/tf-semantic-segmentation/6e425287f309f86b931e34bbbb804bbc46963e28/tests/debug/__init__.py
--------------------------------------------------------------------------------
/tests/debug/test_export_dataset.py:
--------------------------------------------------------------------------------
 1 | from tf_semantic_segmentation.debug import dataset_export
 2 | from tf_semantic_segmentation.processing import ColorMode
 3 | from tf_semantic_segmentation.utils import get_files
 4 | from ..fixtures import dataset
 5 | import tempfile
 6 | import shutil
 7 | import os
 8 | 
 9 | 
10 | def test_dataset_export(dataset):
11 |     output_dir = tempfile.mkdtemp()
12 |     dataset_export.export(dataset, output_dir, size=(64, 64), color_mode=ColorMode.RGB, overwrite=True)
13 |     files = get_files(output_dir)
14 |     assert(len(files) == (dataset.total_examples() * 2) + 1)
15 | 
16 |     masks = get_files(output_dir, ['png'])
17 |     images = get_files(output_dir, ['jpg'])
18 | 
19 |     assert(len(masks) == dataset.total_examples())
20 |     assert(len(images) == dataset.total_examples())
21 |     labels_path = os.path.join(output_dir, 'labels.txt')
22 | 
23 |     assert(os.path.exists(labels_path))
24 |     assert(dataset_export.read_labels(labels_path) == dataset.labels)
25 | 
26 |     shutil.rmtree(output_dir)
--------------------------------------------------------------------------------
/tests/fixtures.py:
--------------------------------------------------------------------------------
 1 | from tf_semantic_segmentation.datasets import Dataset, DataType, utils, TFReader, TFWriter
 2 | from tf_semantic_segmentation.datasets.shapes import ShapesDS
 3 | 
 4 | import pytest
 5 | import os
 6 | import tqdm
 7 | import cv2
 8 | import numpy as np
 9 | import imageio
10 | import tempfile
11 | import pytest
12 | import shutil
13 | 
14 | 
15 | class TestDataset(Dataset):
16 | 
17 |     def __init__(self, cache_dir, num=100, size=(64, 64)):
18 |         super(TestDataset, self).__init__(cache_dir)
19 |         self._num = num
20 |         self._size = size
21 |         self.split = self._generate()
22 | 
23 |     def _generate(self):
24 |         data = []
25 |         for i in tqdm.trange(self._num, desc='creating test dataset'):
26 |             image = np.zeros((self._size[0], self._size[1], 3), np.uint8)
27 |             mask = np.zeros((self._size[0], self._size[1]), np.uint8)
28 |             center = (self._size[1] // 2, self._size[0] // 2)
29 |             image = cv2.circle(image, center, 5, (255, 0, 0), thickness=cv2.FILLED)
30 |             mask = cv2.circle(mask, center, 5, 1, thickness=cv2.FILLED)
31 |             mask_path = os.path.join(self.cache_dir, 'mask-%d.png' % i)
32 |             image_path = os.path.join(self.cache_dir, '%d.jpg' % i)
33 |             imageio.imwrite(image_path, image)
34 |             imageio.imwrite(mask_path, mask)
35 |             data.append((image_path, mask_path))
36 | 
37 |         return utils.get_split(data)
38 | 
39 |     def raw(self):
40 |         return self.split
41 | 
42 |     def delete(self):
43 |         shutil.rmtree(self.cache_dir)
44 | 
45 |     @property
46 |     def labels(self):
47 |         return ['bg', 'circle']
48 | 
49 | 
50 | @pytest.fixture()
51 | def dataset(request):
52 |     cache_dir = os.path.join(tempfile.tempdir, 'testds')  # tempfile.mkdtemp()
53 |     os.makedirs(cache_dir, exist_ok=True)
54 |     ds = TestDataset(cache_dir, num=100, size=(64, 64))
55 |     yield ds
56 |     ds.delete()
57 | 
58 | 
59 | @pytest.fixture()
60 | def tfrecord_reader():
61 |     cache_dir = os.path.join(tempfile.tempdir, 'testds')  # tempfile.mkdtemp()
62 |     record_dir = os.path.join(cache_dir, 'records')
63 | 
64 |     os.makedirs(cache_dir, exist_ok=True)
65 |     ds = TestDataset(cache_dir, num=100, size=(64, 64))
66 |     TFWriter(record_dir).write(ds)
67 |     reader = TFReader(record_dir)
68 | 
69 |     yield reader
70 |     ds.delete()
71 | 
72 | 
73 | @pytest.fixture()
74 | def shapes_ds():
75 |     """ Returns the dataset fixture """
76 |     return ShapesDS(os.path.join(tempfile.tempdir, 'SHAPES'), num_examples=100)
77 | 
--------------------------------------------------------------------------------
/tests/losses/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baudcode/tf-semantic-segmentation/6e425287f309f86b931e34bbbb804bbc46963e28/tests/losses/__init__.py
--------------------------------------------------------------------------------
/tests/losses/test_utils.py:
--------------------------------------------------------------------------------
 1 | from tf_semantic_segmentation.losses import utils
 2 | import numpy as np
 3 | import tensorflow as tf
 4 | 
 5 | 
 6 | def test_to_1d_2d():
 7 |     arr = np.zeros((64, 64, 3), np.uint8)
 8 | 
 9 |     arr1d = utils.to1d(arr)
10 |     assert(arr1d.shape[0] == np.prod(arr.shape))
11 | 
12 |     arr2d = utils.to2d(arr)
13 |     assert(arr2d.shape[0] == arr.shape[0] and arr2d.shape[1] == np.prod(arr.shape[1:]))
14 | 
15 | 
16 | def test_round_if_needed():
17 |     arr = np.ones((64, 64)) * 0.7
18 |     t = utils.round_if_needed(arr, 0.5).numpy()
19 |     np.testing.assert_equal(t, np.ones((64, 64)))
20 | 
21 | 
22 | def test_gather_channels():
23 |     arr = np.ones((1, 32, 32, 3))
24 |     t = utils.gather_channels(arr, indexes=[1, 2])
25 |     assert(t[0].shape == (1, 32, 32, 2))
26 |     np.testing.assert_array_equal(t[0], arr[:, :, :, 1:])
27 | 
28 | 
29 | def test_image2onehot():
30 |     num_classes = 5
31 |     idx = 2
32 |     onehot = np.zeros((1, 64, 64, num_classes), np.float32)
33 |     onehot[:, :, :, idx] = 1.0
34 | 
35 |     image = utils.onehot2image(onehot)
36 |     arr = np.ones((1, 64, 64, 1), np.uint8) / (num_classes - 1.0) * idx
37 |     np.testing.assert_array_equal(arr, image)
38 | 
39 | 
40 | def test_expand_binary():
41 |     binary = np.ones((1, 64, 64, 1), np.float32)
42 |     onehot = utils.expand_binary(binary)
43 |     assert(onehot.shape == (1, 64, 64, 2))
44 |     assert(onehot.numpy().dtype == 'float32')
45 | 
46 |     onehot = onehot.numpy()
47 |     np.testing.assert_array_equal(onehot[:, :, :, 1], np.squeeze(binary, axis=-1))
48 
| np.testing.assert_array_equal(onehot[:, :, :, 0], np.squeeze(np.zeros_like(binary), axis=-1)) 49 | 50 | 51 | def test_average(): 52 | batch = np.ones((5, 64, 64, 1), np.float32) * 5 53 | assert(utils.average(tf.convert_to_tensor(batch)).numpy() == 5.0) 54 | batch[0, :, :, :] = 0.0 55 | assert(utils.average(tf.convert_to_tensor(batch)).numpy() == 4.0) 56 | assert(utils.average(tf.convert_to_tensor(batch), per_image=True).numpy() == 4.0) 57 | -------------------------------------------------------------------------------- /tests/processing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baudcode/tf-semantic-segmentation/6e425287f309f86b931e34bbbb804bbc46963e28/tests/processing/__init__.py -------------------------------------------------------------------------------- /tests/processing/dataset_test.py: -------------------------------------------------------------------------------- 1 | from tf_semantic_segmentation.processing import dataset, ColorMode 2 | from tf_semantic_segmentation.datasets.utils import convert2tfdataset, DataType 3 | import numpy as np 4 | from ..fixtures import dataset as ds 5 | 6 | 7 | def test_preprocess_fn(): 8 | size = (64, 64) 9 | color_mode = ColorMode.RGB 10 | resize_method = 'resize' 11 | 12 | f = dataset.get_preprocess_fn(size, color_mode, resize_method, scale_mask=False) 13 | 14 | num_classes = 3 15 | image = np.ones((32, 32, 1), np.uint8) * 255 16 | mask = np.ones((32, 32), np.uint8) * 2 17 | 18 | pimage, pmask = f(image, mask, num_classes) 19 | assert(pimage.shape == (64, 64, 3)) 20 | assert(pmask.shape == (64, 64, num_classes)) 21 | pmask_argmax = np.argmax(pmask.numpy(), axis=-1) 22 | np.testing.assert_array_equal(pmask_argmax, np.ones_like(pmask_argmax) * 2) 23 | 24 | # scale masks between 0 and 1.0 25 | f = dataset.get_preprocess_fn(size, color_mode, resize_method, scale_mask=True) 26 | pimage, pmask = f(image, mask, num_classes) 27 | 28 | assert(pmask.shape == (64, 64)) 29 | assert(pmask.numpy().max() <= 1.0) 30 | 31 | 32 | def test_prepare_dataset(ds): 33 | for dt in DataType.get(): 34 | tfds = convert2tfdataset(ds, dt) 35 | tfds = tfds.map(dataset.get_preprocess_fn((64, 64), ColorMode.RGB, 'resize', False)) 36 | tfds = dataset.prepare_dataset(tfds, 2) 37 | for image_batch, mask_batch in tfds: 38 | assert(image_batch.shape == (2, 64, 64, 3)) 39 | assert(mask_batch.shape == (2, 64, 64, ds.num_classes)) 40 | break 41 | 42 | tfds = convert2tfdataset(ds, DataType.TRAIN) 43 | tfds = tfds.map(dataset.get_preprocess_fn((64, 64), ColorMode.RGB, 'resize', True)) 44 | tfds = dataset.prepare_dataset(tfds, 2) 45 | for image_batch, mask_batch in tfds: 46 | assert(image_batch.shape == (2, 64, 64, 3)) 47 | assert(mask_batch.shape == (2, 64, 64)) 48 | break 49 | 50 | batch_size, size = 2, (64, 64) 51 | augment_fn = dataset.get_augment_fn(size, batch_size) 52 | tfds = convert2tfdataset(ds, DataType.TRAIN) 53 | tfds = tfds.map(dataset.get_preprocess_fn(size, ColorMode.RGB, 'resize', False)) 54 | tfds = dataset.prepare_dataset(tfds, batch_size, augment_fn=augment_fn) 55 | 56 | for image_batch, mask_batch in tfds: 57 | assert(image_batch.shape == (2, size[0], size[1], 3)) 58 | assert(mask_batch.shape == (2, size[0], size[1], ds.num_classes)) 59 | break 60 | -------------------------------------------------------------------------------- /tests/processing/image_test.py: -------------------------------------------------------------------------------- 1 | from 
tf_semantic_segmentation.processing import image 2 | import numpy as np 3 | 4 | 5 | def test_grid_vis(): 6 | x = np.zeros((9, 16, 16)) 7 | montage = image.grayscale_grid_vis(x, 3, 3) 8 | assert(montage.shape == (48, 48)) 9 | 10 | 11 | def test_fixed_resize(): 12 | x = np.zeros((16, 16, 3)) 13 | assert(image.fixed_resize(x, width=32).shape == (32, 32, 3)) 14 | assert(image.fixed_resize(x, height=32).shape == (32, 32, 3)) 15 | 16 | x = np.zeros((16, 16, 1)) 17 | assert(image.fixed_resize(x, width=32).shape == (32, 32, 1)) 18 | assert(image.fixed_resize(x, height=32).shape == (32, 32, 1)) 19 | 20 | x = np.zeros((32, 16, 1)) 21 | assert(image.fixed_resize(x, width=32).shape == (64, 32, 1)) 22 | assert(image.fixed_resize(x, height=32).shape == (32, 16, 1)) 23 | -------------------------------------------------------------------------------- /tests/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow-gpu==2.4.0 2 | tensorflow-addons==0.12.0 3 | pytest-pep8 4 | pytest-cov 5 | coverage 6 | codecov -------------------------------------------------------------------------------- /tests/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baudcode/tf-semantic-segmentation/6e425287f309f86b931e34bbbb804bbc46963e28/tests/test.png -------------------------------------------------------------------------------- /tests/test_activations.py: -------------------------------------------------------------------------------- 1 | from tf_semantic_segmentation import activations 2 | from tensorflow.keras.layers import Activation 3 | import numpy as np 4 | 5 | TEST_BATCH = np.ones((1, 32, 32, 3), dtype=np.float32) 6 | 7 | for name, activation in activations.custom_objects.items(): 8 | print("testing activation %s" % name) 9 | 10 | # use custom object name (keras Activation) 11 | act = Activation(name) 12 | a1 = act(TEST_BATCH) 13 | 14 | # use activation 15 | print(activation) 16 | print(activation.__dict__) 17 | a2 = activation(TEST_BATCH) 18 | np.testing.assert_allclose(a1, a2) 19 | -------------------------------------------------------------------------------- /tests/test_apps.py: -------------------------------------------------------------------------------- 1 | from tf_semantic_segmentation.models.apps import resnet50, inception, mobilenet 2 | import tensorflow as tf 3 | 4 | 5 | def test_resnet50(): 6 | model = resnet50.ResNet50(input_shape=(256, 256, 3), include_top=False) 7 | print(model.input.shape, model.output.get_shape()) 8 | assert(model.output.shape.as_list() == [None, 1, 1, 2048]) 9 | 10 | 11 | def test_inception(): 12 | model = inception.InceptionResNetV2(input_shape=(256, 256, 3), include_top=False) 13 | print(model.input.shape, model.output.get_shape()) 14 | assert(model.output.shape.as_list() == [None, 8, 8, 1536]) 15 | assert(model.input.shape.as_list() == [None, 256, 256, 3]) 16 | 17 | 18 | def test_mobilenet(): 19 | model = mobilenet.MobileNet(input_shape=(256, 256, 3), include_top=False) 20 | print(model.input.shape, model.output.get_shape()) 21 | assert(model.output.shape.as_list() == [None, 8, 8, 1024]) 22 | assert(model.input.shape.as_list() == [None, 256, 256, 3]) 23 | -------------------------------------------------------------------------------- /tests/test_losses.py: -------------------------------------------------------------------------------- 1 | from tf_semantic_segmentation import losses 2 | import numpy as np 3 | from .data import 
TEST_BATCH_BINARY, TEST_BATCH 4 | import tensorflow as tf 5 | 6 | # x = np.array([-2.2, -1.4, -.8, .2, .4, .8, 1.2, 2.2, 2.9, 4.6]) 7 | # y = np.array([0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) 8 | 9 | 10 | def test_losses(): 11 | print("testing...") 12 | np.testing.assert_allclose(TEST_BATCH_BINARY, TEST_BATCH_BINARY) 13 | for name, loss in losses.losses_by_name.items(): 14 | assert(losses.get_loss_by_name(name) == loss) 15 | print("testing loss %s" % name) 16 | with tf.device("/cpu:0"): 17 | if "binary" in name: 18 | l = loss(tf.convert_to_tensor(TEST_BATCH_BINARY), tf.convert_to_tensor(TEST_BATCH_BINARY)) 19 | else: 20 | l = loss(tf.convert_to_tensor(TEST_BATCH), tf.convert_to_tensor(TEST_BATCH)) 21 | print(l) 22 | 23 | if name == "ce_label_smoothing": 24 | np.testing.assert_almost_equal(l, 1.29, decimal=2) 25 | elif name == "categorical_crossentropy": 26 | np.testing.assert_almost_equal(l, 0.0, decimal=6) 27 | else: 28 | np.testing.assert_almost_equal(l, 0.0) 29 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | from tf_semantic_segmentation import metrics 2 | import imageio 3 | import numpy as np 4 | import tensorflow as tf 5 | from .data import TEST_BATCH 6 | 7 | 8 | def test_metrics_assert_1_0(): 9 | """ 10 | assert(metrics.precision(TEST_BATCH, TEST_BATCH) == 1.0) 11 | assert(metrics.f2_score(TEST_BATCH, TEST_BATCH) == 1.0) 12 | assert(metrics.recall(TEST_BATCH, TEST_BATCH) == 1.0) 13 | assert(metrics.recall(TEST_BATCH, TEST_BATCH) == 1.0) 14 | assert(metrics.iou_score(TEST_BATCH, TEST_BATCH) == 1.0) 15 | """ 16 | for name, metric in reversed(list(metrics.metrics_by_name.items())): 17 | assert(metrics.get_metric_by_name(name) == metric) 18 | print("metrics: %s" % name) 19 | if name == 'mae': 20 | assert(metric(TEST_BATCH, TEST_BATCH).numpy() == 0.0) 21 | elif name == "psnr": 22 | assert(metric(TEST_BATCH, TEST_BATCH).numpy() > 100) 23 | else: 24 | assert(metric(TEST_BATCH, TEST_BATCH).numpy() == 1.0) 25 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | from tf_semantic_segmentation import models 2 | import tensorflow as tf 3 | from tensorflow.keras import backend as K 4 | 5 | 6 | def test_build_models(): 7 | 8 | for name, model_fn in models.models_by_name.items(): 9 | 10 | print("testing model %s with model fn %s" % (name, model_fn)) 11 | # check that the model builds 12 | input_shape = (256, 256, 3) 13 | num_classes = 5 14 | model = models.get_model_by_name(name, {"input_shape": input_shape, "num_classes": num_classes}) 15 | assert(model.output.get_shape().as_list() == [None, input_shape[0], input_shape[1], num_classes])  # compare, don't assert an (always truthy) tuple 16 | K.clear_session() 17 | -------------------------------------------------------------------------------- /tests/test_threading.py: -------------------------------------------------------------------------------- 1 | from tf_semantic_segmentation import threading 2 | import random 3 | 4 | 5 | def f(i): 6 | return i * 2 7 | 8 | 9 | def test_parallize(): 10 | args = list(range(100)) 11 | 12 | results = threading.parallize(f, args, threads=2) 13 | assert(all([(args[i] * 2) == results[i] for i in range(len(args))])) 14 | 15 | results = threading.parallize(f, args, threads=None) 16 | assert(all([(args[i] * 2) == results[i] for i in range(len(args))])) 17 | 18 | results = 
threading.parallize_v2(f, args) 19 | assert(all([(args[i] * 2) == results[i] for i in range(len(args))])) 20 | 21 | results = threading.parallize_v3(f, args) 22 | assert(all([(args[i] * 2) == results[i] for i in range(len(args))])) 23 | 24 | results = threading.parallize_v3(f, args, n_processes=2) 25 | assert(all([(args[i] * 2) == results[i] for i in range(len(args))])) 26 | -------------------------------------------------------------------------------- /tests/visualizations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baudcode/tf-semantic-segmentation/6e425287f309f86b931e34bbbb804bbc46963e28/tests/visualizations/__init__.py -------------------------------------------------------------------------------- /tests/visualizations/mask_test.py: -------------------------------------------------------------------------------- 1 | from tf_semantic_segmentation.visualizations import masks 2 | import numpy as np 3 | 4 | 5 | def test_get_colors(): 6 | colors = masks.get_colors(10) 7 | assert(len(colors) == 10 and all([len(c) == 3 for c in colors])) 8 | 9 | shuffled = masks.get_colors(10, shuffle=True) 10 | assert(len(shuffled) == 10 and all([len(c) == 3 for c in shuffled])) 11 | assert(colors != shuffled) 12 | 13 | 14 | def test_get_colored_segmentation(): 15 | num_classes = 5 16 | predictions = np.ones((1, 32, 32, num_classes), np.float32) 17 | predictions[:, :, :, 2] = 1.0 18 | col = masks.get_colored_segmentation_mask(predictions, num_classes) 19 | assert(col.shape == (1, 32, 32, 3) and col.dtype == 'uint8') 20 | 21 | predictions = np.ones((1, 32, 32, 1), np.float32) 22 | predictions = predictions * 0.6 23 | col = masks.get_colored_segmentation_mask(predictions, num_classes, binary_threshold=0.5) 24 | assert(col.shape == (1, 32, 32, 3) and col.dtype == 'uint8') 25 | 26 | seg_test = np.zeros((1, 32, 32, 3), np.uint8) 27 | seg_test[:, :, :, 0] = 127 28 | np.testing.assert_array_equal(seg_test, col) 29 | 30 | predictions = np.ones((1, 32, 32, 1), np.float32) 31 | predictions = predictions * 0.4 32 | col = masks.get_colored_segmentation_mask(predictions, num_classes, binary_threshold=0.5) 33 | assert(col.shape == (1, 32, 32, 3) and col.dtype == 'uint8') 34 | 35 | seg_test = np.zeros((1, 32, 32, 3), np.uint8) 36 | np.testing.assert_array_equal(seg_test, col) 37 | 38 | 39 | def test_overlay_classes(): 40 | image = np.zeros((32, 32, 3), np.uint8) 41 | num_classes = 5 42 | colors = masks.get_colors(num_classes) 43 | mask = np.ones((32, 32), np.uint8) * 4 44 | overlay = masks.overlay_classes(image.copy(), mask, colors, num_classes, alpha=1.0) 45 | assert(all(np.mean(overlay.mean(axis=0), axis=0) == list(map(float, colors[4])))) 46 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | __import__('pkg_resources').declare_namespace(__name__) # noqa 2 | 3 | from . import activations 4 | from . import datasets 5 | from . import debug 6 | from . import evaluation 7 | from . import layers 8 | from . import losses 9 | from . import metrics 10 | from . import models 11 | from . import optimizers 12 | from . import processing 13 | from . import visualizations 14 | from . import callbacks 15 | from . import serving 16 | from . import settings 17 | from . import threading 18 | from . import utils 19 | from . 
import version 20 | 21 | from .version import __version__ 22 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/activations/__init__.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras import backend as K 2 | from tensorflow.keras.utils import get_custom_objects 3 | from tensorflow.keras.layers import Activation 4 | import tensorflow as tf 5 | 6 | 7 | class Mish(Activation): 8 | 9 | def __init__(self, **kwargs): 10 | super(Mish, self).__init__(lambda x: x * K.tanh(K.softplus(x)), **kwargs) 11 | self.__name__ = "Mish" 12 | 13 | 14 | class ReLU6(Activation): 15 | 16 | def __init__(self, **kwargs): 17 | super(ReLU6, self).__init__(lambda x: K.relu(x, max_value=6), **kwargs)  # relu6 = min(max(x, 0), 6) 18 | self.__name__ = 'ReLU6' 19 | 20 | 21 | class Swish(Activation): 22 | 23 | def __init__(self, **kwargs): 24 | super(Swish, self).__init__(lambda x: x * K.sigmoid(x), **kwargs) 25 | self.__name__ = "Swish" 26 | 27 | 28 | class LeakyReLU(Activation): 29 | 30 | def __init__(self, alpha=0.2, **kwargs): 31 | 32 | super(LeakyReLU, self).__init__(lambda x: tf.nn.leaky_relu(x, alpha=alpha), **kwargs) 33 | self.__name__ = "LeakyReLU" 34 | 35 | 36 | custom_objects = {'relu6': ReLU6(), 'mish': Mish(), 'swish': Swish(), 'leaky_relu': LeakyReLU()} 37 | get_custom_objects().update(custom_objects) 38 | 39 | __all__ = ['Mish', "ReLU6", "Swish", "LeakyReLU"] 40 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/bin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baudcode/tf-semantic-segmentation/6e425287f309f86b931e34bbbb804bbc46963e28/tf_semantic_segmentation/bin/__init__.py -------------------------------------------------------------------------------- /tf_semantic_segmentation/bin/convert_tflite.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import argparse 3 | from ..settings import logger 4 | 5 | 6 | def convert(saved_model_dir, output_path, optimize_for_size=True): 7 | converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) 8 | converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE if optimize_for_size else tf.lite.Optimize.DEFAULT] 9 | logger.info('optimizations: %s', str(converter.optimizations)) 10 | tflite_quant_model = converter.convert() 11 | open(output_path, "wb").write(tflite_quant_model) 12 | 13 | 14 | if __name__ == "__main__": 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("-i", '--saved_model_dir', required=True, help='path to the saved model dir') 18 | parser.add_argument('-o', '--output_path', default='output_model.tflite', help='output model path') 19 | parser.add_argument('--no_size_optimization', action='store_true') 20 | args = parser.parse_args() 21 | 22 | convert(args.saved_model_dir, args.output_path, optimize_for_size=not args.no_size_optimization) 23 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/bin/download.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from ..utils import download_and_extract, download_file, download_from_google_drive 3 | 4 | if __name__ == "__main__": 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('-d', '--destination', help='destination directory', required=True) 7 | parser.add_argument('-u', '--url', 
help='url to download', default=None) 8 | parser.add_argument('-id', '--google_drive_id', help='google drive id to download') 9 | parser.add_argument('-fn', '--filename', help='google drive name to download', default=None) 10 | parser.add_argument('-e', '--extract', action='store_true') 11 | parser.add_argument('-remove', '--remove_archive_on_success', action='store_true') 12 | args = parser.parse_args() 13 | 14 | if args.url: 15 | url = args.url 16 | elif args.google_drive_id and args.filename: 17 | url = (args.google_drive_id, args.filename) 18 | else: 19 | raise Exception("invalid arguments") 20 | 21 | if args.extract: 22 | download_and_extract(url, args.destination, remove_archive_on_success=args.remove_archive_on_success) 23 | else: 24 | if type(url) == tuple: 25 | download_from_google_drive(url[0], args.destination, url[1]) 26 | else: 27 | download_file(url, destination_dir=args.destination) 28 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/bin/model_server_config_writer.py: -------------------------------------------------------------------------------- 1 | from ..serving import write_model_config_from_models_dir 2 | import argparse 3 | 4 | 5 | def main(): 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("-d", '--models_dir', help='path to the directory containing the logdirs/models') 8 | parser.add_argument('-c', '--contains', help='logdir/model_name must contain the given sequence', default=None) 9 | parser.add_argument('-o', '--config_path', default='models.yaml') 10 | args = parser.parse_args() 11 | 12 | write_model_config_from_models_dir(args.models_dir, args.contains, args.config_path) 13 | 14 | 15 | if __name__ == "__main__": 16 | main() 17 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/bin/tfrecord_analyser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from ..datasets import tfrecord, DataType 4 | from ..utils import get_size 5 | import tensorflow as tf 6 | import logging 7 | import imageio 8 | import numpy as np 9 | 10 | from ..settings import logger, logging 11 | 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('-r', '--record_dir', required=True) 16 | parser.add_argument('-d', '--dump_example', action='store_true') 17 | parser.add_argument('-m', '--mean', action='store_true') 18 | args = parser.parse_args() 19 | 20 | logger.setLevel(logging.INFO) 21 | 22 | reader = tfrecord.TFReader(args.record_dir) 23 | for image, mask, num_classes in reader.get_dataset(DataType.TRAIN): 24 | if args.dump_example: 25 | imageio.imwrite('_example.png', (image.numpy() * 255.).astype(np.uint8)) 26 | imageio.imwrite('_example_mask.png', (mask.numpy() * 255. 
/ (num_classes.numpy() - 1)).astype(np.uint8)) 27 | print("=" * 20) 28 | print("image/mask stats:") 29 | print("image: ", image.dtype, "(shape=", image.shape, ",max=", image.numpy().max(), ")") 30 | print("mask: ", mask.dtype, "(shape=", mask.shape, ",max=", mask.numpy().max(), ")") 31 | break 32 | 33 | print("=" * 20) 34 | print("num_classes: ", reader.num_classes) 35 | print("size: ", reader.size) 36 | print("input shape: ", reader.input_shape) 37 | print('Calculating entries...') 38 | sizes = [] 39 | for data_type in [DataType.TRAIN, DataType.TEST, DataType.VAL]: 40 | if args.mean: 41 | n, mean = reader.num_examples_and_mean(data_type) 42 | print("-> mean[%s(%d)] = %s" % (data_type, n, mean.tolist())) 43 | else: 44 | n = reader.num_examples(data_type) 45 | print(data_type, ":", n) 46 | sizes.append(n) 47 | print("-> total: %d" % sum(sizes)) 48 | print("=" * 20) 49 | print("size: %.2f GB" % (get_size(args.record_dir) / 1024. / 1024. / 1024.)) 50 | 51 | 52 | if __name__ == "__main__": 53 | main() 54 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/bin/tfrecord_download.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from ..datasets import download_records 4 | 5 | 6 | def main(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('-t', '--tag', required=True) 9 | parser.add_argument('-r', '--record_dir', required=True) 10 | args = parser.parse_args() 11 | 12 | download_records(args.tag, args.record_dir) 13 | 14 | 15 | if __name__ == "__main__": 16 | main() 17 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/bin/tfrecord_writer.py: -------------------------------------------------------------------------------- 1 | from ..datasets import datasets_by_name, get_dataset_by_name, get_cache_dir, DirectoryDataset 2 | from ..processing.dataset import ColorMode, resize_and_change_color 3 | from ..datasets.tfrecord import TFWriter 4 | from ..settings import logger 5 | 6 | import tensorflow as tf 7 | import argparse 8 | import os 9 | 10 | 11 | def write_records_from_directory(directory, record_dir, size=None, color_mode=ColorMode.NONE, 12 | resize_method='resize', num_examples_per_record=100, overwrite=False): 13 | ds = DirectoryDataset(directory) 14 | 15 | if record_dir is None: 16 | raise AssertionError("record_dir cannot be None") 17 | 18 | write_records(ds, record_dir, size=size, color_mode=color_mode, resize_method=resize_method, 19 | num_examples_per_record=num_examples_per_record, overwrite=overwrite) 20 | 21 | 22 | def write_records_from_dataset_name(dataset, data_dir, record_dir=None, size=None, color_mode=ColorMode.NONE, 23 | resize_method='resize', num_examples_per_record=100, overwrite=False): 24 | 25 | cache_dir = get_cache_dir(data_dir, dataset.lower()) 26 | ds = get_dataset_by_name(dataset, cache_dir) 27 | 28 | # write records 29 | if record_dir: 30 | record_dir = record_dir 31 | else: 32 | record_dir = os.path.join(cache_dir, 'records', dataset.lower()) 33 | 34 | write_records(ds, record_dir, size=size, color_mode=color_mode, resize_method=resize_method, 35 | num_examples_per_record=num_examples_per_record, overwrite=overwrite) 36 | 37 | return record_dir 38 | 39 | 40 | def write_records(ds, record_dir, size=None, color_mode=ColorMode.NONE, resize_method='resize', 41 | num_examples_per_record=100, overwrite=False): 42 | 43 | def preprocess_fn(image, mask): 44 | image = 
tf.image.convert_image_dtype(image, tf.float32) 45 | return resize_and_change_color(image, mask, size, color_mode, resize_method) 46 | 47 | logger.info('writing records to %s' % record_dir) 48 | writer = TFWriter(record_dir) 49 | writer.write(ds, overwrite=overwrite, num_examples_per_record=num_examples_per_record, preprocess_fn=preprocess_fn) 50 | 51 | # validate number of examples written 52 | writer.validate(ds) 53 | 54 | 55 | def main(): 56 | 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument('-d', '--dataset', default=None, choices=list(datasets_by_name.keys())) 59 | parser.add_argument('-r', '--record_dir', default=None) 60 | 61 | parser.add_argument('-dir', '--directory', default=None) 62 | parser.add_argument('-c', '--data_dir', default=None) 63 | parser.add_argument('-num', '--num_examples_per_record', default=100, type=int) 64 | parser.add_argument('-s', '--size', default=None, type=lambda x: list(map(int, x.split(','))), help='height,width') 65 | parser.add_argument('-rm', '--resize_method', default='resize') 66 | parser.add_argument('-cm', '--color_mode', default=ColorMode.NONE, type=int) 67 | 68 | parser.add_argument('-o', '--overwrite', action='store_true') 69 | args = parser.parse_args() 70 | 71 | if args.directory is None and args.dataset is None: 72 | raise AssertionError("please either supply a dataset or a directory containing your data") 73 | 74 | if args.dataset: 75 | assert(args.data_dir is not None), "data_dir argument is required" 76 | write_records_from_dataset_name(args.dataset, args.data_dir, args.record_dir, args.size, args.color_mode, args.resize_method, 77 | args.num_examples_per_record, args.overwrite) 78 | else: 79 | write_records_from_directory(args.directory, args.record_dir, args.size, args.color_mode, args.resize_method, 80 | args.num_examples_per_record, args.overwrite) 81 | 82 | 83 | if __name__ == "__main__": 84 | main() 85 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/bin/train_all.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser 2 | from ..bin.train import train_test_model, get_args 3 | from ..models import models_by_name 4 | from ..utils import logger 5 | from ..datasets import get_dataset_by_name, datasets_by_name, DataType, get_cache_dir 6 | from ..datasets import TFWriter 7 | from tensorflow.keras import backend as K 8 | import os 9 | import json 10 | import time 11 | 12 | 13 | def _create_dataset(name, overwrite=False, data_dir='/tmp/data'): 14 | 15 | cache_dir = get_cache_dir(data_dir, name.lower()) 16 | if os.path.exists(cache_dir) and not overwrite: 17 | return cache_dir 18 | 19 | ds = get_dataset_by_name(name, cache_dir) 20 | 21 | # print labels and classes 22 | print(ds.labels) 23 | print(ds.num_classes) 24 | 25 | # print number of training examples 26 | print(ds.num_examples(DataType.TRAIN)) 27 | 28 | # or simply print the summary 29 | ds.summary() 30 | 31 | writer = TFWriter(cache_dir) 32 | writer.write(ds) 33 | writer.validate(ds) 34 | return cache_dir 35 | 36 | 37 | def train_all(dataset, project='all_models', loss="dice", batch_size=4, epochs=1, steps_per_epoch=1, validation_steps=1, size=[512, 512], overwrite=False, data_dir='/tmp/data'): 38 | 39 | logger.info("creating dataset") 40 | record_dir = _create_dataset(dataset, overwrite=overwrite, data_dir=data_dir) 41 | 42 | results_dir = "/tmp/%s/results/" % project 43 | os.makedirs(results_dir, 
exist_ok=True) 44 | 45 | for model_name in models_by_name.keys(): 46 | # get the default args 47 | args = get_args({}) 48 | 49 | # change some parameters 50 | # !rm -r logs/ 51 | args.model = model_name 52 | args.batch_size = batch_size 53 | args.size = size # resize input dataset to this size 54 | args.epochs = epochs 55 | args.steps_per_epoch = steps_per_epoch 56 | args.validation_steps = validation_steps 57 | args.learning_rate = 1e-4 58 | args.optimizer = 'adam' # ['adam', 'radam', 'ranger'] 59 | args.loss = loss 60 | args.logdir = '/tmp/%s/logs/%s' % (project, model_name) 61 | args.record_dir = record_dir 62 | args.final_activation = 'softmax' 63 | args.wandb_name = "%s-%s-%d" % (dataset, model_name, time.time()) 64 | args.wandb_project = project 65 | # train and test 66 | results, model = train_test_model(args) 67 | json.dump(results, open(os.path.join(results_dir, "%s.json" % model_name), 'w')) 68 | 69 | K.clear_session() 70 | 71 | 72 | def main(): 73 | 74 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 75 | parser.add_argument("-d", '--dataset', required=True, help='dataset name') 76 | parser.add_argument("-p", '--project', default='all_models', help="wandb project name") 77 | parser.add_argument("-steps", '--steps_per_epoch', default=-1, type=int, help='number of steps') 78 | parser.add_argument("-e", '--epochs', default=1, type=int, help='number of epochs') 79 | parser.add_argument("-val_steps", '--validation_steps', default=-1, type=int, help='number of validation steps') 80 | parser.add_argument("-size", '--size', default=[512, 512], type=lambda x: list(map(int, x.split(","))), help='input size (height,width)') 81 | parser.add_argument("-dd", '--data_dir', default="/tmp/data", help='dataset data dir') 82 | parser.add_argument("-l", '--loss', default="dice", help='loss function') 83 | parser.add_argument("-bs", '--batch_size', default=4, type=int, help='batch_size') 84 | 85 | args = parser.parse_args() 86 | 87 | train_all(args.dataset, args.project, args.loss, args.batch_size, args.epochs, args.steps_per_epoch, args.validation_steps, args.size, data_dir=args.data_dir) 88 | 89 | 90 | if __name__ == "__main__": 91 | main() 92 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # datasets 2 | from . import camvid 3 | from . import cityscapes 4 | from . import mots_challenge 5 | from . import mapping_challenge 6 | from . import ms_coco 7 | from . import pascal 8 | from . import taco 9 | from . import shapes 10 | from . import sun 11 | from . import toy 12 | from . import ade20k 13 | from . import isic 14 | from . import cvc_clinicdb 15 | from . import cub 16 | from . 
import bioimage 17 | 18 | from .directory import DirectoryDataset 19 | from .utils import DataType, get_split, get_split_from_list, download_records, google_drive_records_by_tag 20 | from .tfrecord import TFReader, TFWriter 21 | from .dataset import Dataset 22 | import os 23 | 24 | # kitty, sun, pascal-voc, ade20k 25 | 26 | datasets_by_name = { 27 | # "ade20k": ade20k.Ade20k, 28 | "sun": sun.Kinect2Data, 29 | "camvid": camvid.CamSeq01, 30 | "coco2014": ms_coco.Coco2014, 31 | "coco2017": ms_coco.Coco2017, 32 | "cityscapes": cityscapes.Cityscapes, 33 | "tacobinary": taco.TacoBinary, 34 | "tacocategory": taco.TacoCategory, 35 | "tacosupercategory": taco.TacoSuperCategory, 36 | "mots": mots_challenge.MotsChallenge, 37 | "pascalvoc2012": pascal.PascalVOC2012, 38 | # dummy datasets for testing 39 | "shapes": shapes.ShapesDS, 40 | "shapesmini": shapes.ShapesDSMini, 41 | "toy": toy.Toy, 42 | "mappingchallenge": mapping_challenge.MappingChallenge, 43 | "ade20k": ade20k.Ade20k, 44 | "isic2018": isic.ISIC2018, 45 | "cvc_clinicdb": cvc_clinicdb.CVCClinicDB, 46 | 'cub2002011binary': cub.CUB2002011Binary, 47 | 'cub2002011category': cub.CUB2002011Category, 48 | 'bioimagebenchmark': bioimage.BioimageBenchmark 49 | } 50 | 51 | 52 | def get_cache_dir(data_dir: str, name: str) -> str: 53 | if "taco" in name.lower(): 54 | cache_dir = os.path.join(data_dir, 'taco') 55 | elif "cub2002011" in name.lower(): 56 | cache_dir = os.path.join(data_dir, 'cub2002011') 57 | elif 'isic' in name.lower(): 58 | cache_dir = os.path.join(data_dir, 'isic') 59 | elif 'coco' in name.lower(): 60 | cache_dir = os.path.join(data_dir, 'coco') 61 | else: 62 | cache_dir = os.path.join(data_dir, name.lower()) 63 | 64 | return cache_dir 65 | 66 | 67 | def get_dataset_by_name(name: str, cache_dir: str, args: dict = {}) -> Dataset: 68 | if name in datasets_by_name.keys(): 69 | return datasets_by_name[name](cache_dir, **args) 70 | else: 71 | raise Exception("could not find dataset %s" % name) 72 | 73 | 74 | __all__ = ["get_dataset_by_name", "DataType", "download_records", "google_drive_records_by_tag", "TFWriter", "TFReader", 75 | "get_split", "get_split_from_list", "get_cache_dir", "datasets_by_name", 'DirectoryDataset'] 76 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/datasets/ade20k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import csv 4 | import numpy as np 5 | import csv 6 | 7 | from .utils import get_split_from_list 8 | from ..utils import get_files, download_file, download_and_extract 9 | from .dataset import Dataset, DataType 10 | # https://github.com/CSAILVision/sceneparsing/blob/master/objectInfo150.txt 11 | # https://github.com/CSAILVision/sceneparsing/blob/master/objectInfo150.csv 12 | # download: https://github.com/CSAILVision/semantic-segmentation-pytorch/blob/master/download_ADE20K.sh 13 | 14 | 15 | class Ade20k(Dataset): 16 | 17 | DATA_URL = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip" 18 | LABELS_URL = "https://raw.githubusercontent.com/CSAILVision/sceneparsing/master/objectInfo150.csv" 19 | 20 | @property 21 | def labels(self): 22 | download_path = download_file(self.LABELS_URL, self.cache_dir, file_name="objectInfo150.csv") 23 | with open(download_path, newline='') as csvfile: 24 | reader = csv.DictReader(csvfile) 25 | labels = [dict(row)['Name'] for row in reader] 26 | return ['bg'] + labels # add background 27 | 28 | def raw(self): 29 | extract_dir = 
download_and_extract(self.DATA_URL, self.cache_dir) 30 | extract_dir = os.path.join(extract_dir, "ADEChallengeData2016") 31 | print(extract_dir) 32 | 33 | val_images_dir = os.path.join(extract_dir, 'images', 'validation') 34 | train_images_dir = os.path.join(extract_dir, 'images', 'training') 35 | 36 | val_annotations_dir = os.path.join(extract_dir, 'annotations', 'validation') 37 | train_annotations_dir = os.path.join(extract_dir, 'annotations', 'training') 38 | 39 | val_images = get_files(val_images_dir, extensions=['jpg']) 40 | val_annotations = get_files(val_annotations_dir, extensions=['png']) 41 | 42 | train_images = get_files(train_images_dir, extensions=['jpg']) 43 | train_annotations = get_files(train_annotations_dir, extensions=['png']) 44 | 45 | return { 46 | DataType.TRAIN: list(zip(train_images, train_annotations)), 47 | DataType.VAL: list(zip(val_images, val_annotations)), 48 | DataType.TEST: [] 49 | } 50 | 51 | 52 | if __name__ == "__main__": 53 | 54 | from .utils import test_dataset 55 | ade20k = Ade20k('/hdd/datasets/ade20k') 56 | test_dataset(ade20k) 57 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/datasets/bioimage.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset, DataType 2 | from .utils import get_split_from_list 3 | from ..utils import download_and_extract, get_files 4 | import os 5 | import numpy as np 6 | import imageio 7 | 8 | 9 | class BioimageBenchmark(Dataset): 10 | """ 11 | Kaggle 2018 Data Science Bowl 12 | https://data.broadinstitute.org/bbbc/BBBC038/ 13 | 14 | Broad Bioimage Benchmark Collection 15 | 16 | This image data set contains a large number of segmented nuclei images and was created for the Kaggle 2018 Data Science Bowl sponsored by 17 | Booz Allen Hamilton with cash prizes. The image set was a testing ground for the application of novel and cutting edge approaches 18 | in computer vision and machine learning to the segmentation of the nuclei belonging to cells from a breadth of biological contexts. 
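A minimal usage sketch (hypothetical cache path; the archives are downloaded on first access): ds = BioimageBenchmark('/tmp/bioimagebenchmark'); image, mask = ds.get_random_item(). The mask is a uint8 array with 0 = bg and 1 = nucleus.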
19 | """ 20 | 21 | TRAIN_URL = "https://data.broadinstitute.org/bbbc/BBBC038/stage1_train.zip" 22 | TEST_URL = "https://data.broadinstitute.org/bbbc/BBBC038/stage1_test.zip" 23 | TEST_URL_STAGE_2 = "https://data.broadinstitute.org/bbbc/BBBC038/stage2_test_final.zip" 24 | 25 | def __init__(self, cache_dir, use_stage2_testing=False): 26 | super(BioimageBenchmark, self).__init__(cache_dir) 27 | self.use_stage2_testing = use_stage2_testing 28 | 29 | def raw(self): 30 | train_dir = download_and_extract(self.TRAIN_URL, os.path.join(self.cache_dir, 'train')) 31 | test_url = self.TEST_URL_STAGE_2 if self.use_stage2_testing else self.TEST_URL 32 | test_dir = download_and_extract(test_url, os.path.join(self.cache_dir, 'test')) 33 | trainset = [] 34 | 35 | # make train dataset 36 | for dir in os.listdir(train_dir): 37 | abs_dir = os.path.join(train_dir, dir) 38 | images = get_files(os.path.join(abs_dir, 'images'), extensions=['png']) 39 | masks = get_files(os.path.join(abs_dir, 'masks'), extensions=['png']) 40 | if len(images) == 1: 41 | trainset.append((images[0], masks)) 42 | 43 | # use part of train dataset as validation set 44 | trainset, valset = get_split_from_list(trainset, split=0.9) 45 | 46 | # make test dataset 47 | testset = [] 48 | for dir in os.listdir(train_dir): 49 | abs_dir = os.path.join(train_dir, dir) 50 | images = get_files(os.path.join(abs_dir, 'images'), extensions=['png']) 51 | if len(images) == 1: 52 | testset.append((images[0], None)) 53 | return { 54 | DataType.TRAIN: trainset, 55 | DataType.VAL: valset, 56 | DataType.TEST: testset 57 | } 58 | 59 | @property 60 | def labels(self): 61 | return ['bg', 'nucleus'] 62 | 63 | def parse_example(self, example): 64 | image_path, masks = example 65 | 66 | image = imageio.imread(image_path)[:, :, :3] 67 | mask = np.zeros((image.shape[0], image.shape[1]), np.uint8) 68 | 69 | # testset does not contain any masks, return zeros for masks 70 | if masks is None: 71 | return image, mask 72 | 73 | for mask_path in masks: 74 | nucleus = imageio.imread(mask_path) 75 | mask |= nucleus 76 | 77 | mask = np.divide(mask, 255) 78 | mask = mask.astype(np.uint8) 79 | return image, mask 80 | 81 | 82 | if __name__ == "__main__": 83 | from .utils import test_dataset 84 | 85 | ds = BioimageBenchmark('/hdd/datasets/BioimageBenchmark'.lower()) 86 | test_dataset(ds) 87 | ds.summary() 88 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/datasets/camvid.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset 2 | from ..utils import get_files, download_and_extract, download_file 3 | from .utils import get_split, DataType, Color 4 | 5 | import imageio 6 | import os 7 | import numpy as np 8 | 9 | 10 | class CamSeq01(Dataset): 11 | """ 12 | Image Segmentation DataSet of Road Scenes 13 | 14 | Dataset url: http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamSeq01/CamSeq01.zip 15 | """ 16 | 17 | DATA_URL = "http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamSeq01/CamSeq01.zip" 18 | LABEL_COLORS_URL = "http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/data/label_colors.txt" 19 | 20 | def __init__(self, cache_dir): 21 | super(CamSeq01, self).__init__(cache_dir) 22 | self._labels = self.labels 23 | self._colormap = self.colormap 24 | 25 | def raw(self): 26 | dataset_dir = os.path.join(self.cache_dir, 'dataset') 27 | extracted = download_and_extract(self.DATA_URL, dataset_dir) 28 | imgs = get_files(extracted, extensions=["png"]) 29 | images 
= list(filter(lambda x: not x.endswith("_L.png"), imgs)) 30 | labels = list(filter(lambda x: x.endswith("_L.png"), imgs)) 31 | trainset = list(zip(images, labels)) 32 | return get_split(trainset) 33 | 34 | @property 35 | def colormap(self): 36 | file_path = download_file(self.LABEL_COLORS_URL, self.cache_dir) 37 | 38 | color_label_mapping = {} 39 | with open(file_path, "r") as handler: 40 | for line in handler.readlines(): 41 | args = line.split("\t") 42 | color = list(map(lambda x: int(x), args[0].split(" "))) 43 | color = Color(*color) 44 | label = args[-1].replace("\n", "") 45 | color_label_mapping[color] = label 46 | 47 | return color_label_mapping 48 | 49 | @property 50 | def labels(self): 51 | file_path = download_file(self.LABEL_COLORS_URL, self.cache_dir) 52 | 53 | labels = [] 54 | 55 | with open(file_path, "r") as handler: 56 | for line in handler.readlines(): 57 | args = line.split("\t") 58 | label = args[-1].replace("\n", "") 59 | labels.append(label) 60 | 61 | return labels 62 | 63 | def parse_example(self, example): 64 | image_path, target_path = example 65 | i = imageio.imread(image_path) 66 | t = imageio.imread(target_path) 67 | mask = np.zeros((i.shape[0], i.shape[1]), np.uint8) 68 | 69 | for color, label in self._colormap.items(): 70 | color = [color.r, color.g, color.b] 71 | idxs = np.where(np.all(t == color, axis=-1)) 72 | mask[idxs] = self._labels.index(label) 73 | 74 | return i, mask 75 | 76 | 77 | if __name__ == "__main__": 78 | from .utils import test_dataset 79 | 80 | ds = CamSeq01('/hdd/datasets/camvid') 81 | test_dataset(ds) 82 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/datasets/cub.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset, DataType 2 | from ..utils import download_and_extract, get_files 3 | from .utils import get_split_from_list 4 | 5 | import os 6 | import csv 7 | import imageio 8 | import numpy as np 9 | 10 | 11 | def read(path): 12 | with open(path, newline='') as csvfile: 13 | reader = csv.reader(csvfile, delimiter=' ', quotechar='|') 14 | lines = [row for row in reader] 15 | return lines 16 | 17 | 18 | class CUB2002011(Dataset): 19 | IMAGES_URL = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz" 20 | MASKS_URL = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/segmentations.tgz" 21 | 22 | modes = ['binary', 'category'] 23 | 24 | def __init__(self, cache_dir, mode='binary'): 25 | super(CUB2002011, self).__init__(cache_dir) 26 | assert(mode in self.modes) 27 | self.mode = mode 28 | 29 | images_dir = download_and_extract(self.IMAGES_URL, os.path.join(self.cache_dir, 'images')) 30 | masks_dir = download_and_extract(self.MASKS_URL, os.path.join(self.cache_dir, 'masks')) 31 | info_dir = os.path.join(images_dir, 'CUB_200_2011') 32 | 33 | self.images_dir = os.path.join(info_dir, 'images') 34 | self.masks_dir = os.path.join(masks_dir, 'segmentations') 35 | 36 | self._labels = self._read_labels(info_dir) 37 | self.image_id_to_class_id = self._read_image_id_to_class_id(info_dir) 38 | self.bounding_boxes = self._read_bounding_boxes(info_dir) 39 | self.image_id_to_filename = self._read_id_to_filename(info_dir) 40 | self.traindata, self.testdata = self._read_train_test_split(info_dir) 41 | 42 | @property 43 | def labels(self): 44 | if self.mode == 'category': 45 | return self._labels 46 | else: 47 | return ['bg', 'bird'] 48 | 49 | def _read_image_id_to_class_id(self, d, 
filename='image_class_labels.txt'): 50 | data = read(os.path.join(d, filename)) 51 | return {d[0]: d[1] for d in data} 52 | 53 | def _read_bounding_boxes(self, d, filename='bounding_boxes.txt'): 54 | data = read(os.path.join(d, filename)) 55 | return { 56 | d[0]: { 57 | "x": d[1], 58 | "y": d[2], 59 | "width": d[3], 60 | "height": d[4] 61 | } for d in data 62 | } 63 | 64 | def _read_labels(self, d, filename='classes.txt'): 65 | data = read(os.path.join(d, filename)) 66 | labels = [d[1] for d in data] 67 | labels = ['bg'] + labels 68 | return labels 69 | 70 | def _read_id_to_filename(self, d, filename='images.txt'): 71 | data = read(os.path.join(d, filename)) 72 | return {i[0]: i[1] for i in data} 73 | 74 | def _get_split(self, data, id_to_filename, id_to_class_id, images_dir, masks_dir): 75 | items = [] 76 | for image_id in data: 77 | filename = id_to_filename[image_id] 78 | class_id = int(id_to_class_id[image_id]) 79 | base, _ = os.path.splitext(filename) 80 | image_path = os.path.join(images_dir, filename) 81 | masks_path = os.path.join(masks_dir, "%s.png" % base) 82 | items.append((image_path, masks_path, class_id)) 83 | return items 84 | 85 | def _read_train_test_split(self, d, filename='train_test_split.txt'): 86 | data = read(os.path.join(d, filename)) 87 | trainset = [d[0] for d in data if int(d[1]) == 1] 88 | testset = [d[0] for d in data if int(d[1]) == 0] 89 | return trainset, testset 90 | 91 | def parse_example(self, example): 92 | image_path, mask_path, class_id = example 93 | image = imageio.imread(image_path) 94 | 95 | mask = imageio.imread(mask_path) 96 | if len(mask.shape) > 2: 97 | raise Exception("mask %s has invalid shape %s" % (mask_path, mask.shape)) 98 | 99 | mask = (mask / 255.).astype(np.uint8) 100 | 101 | if self.mode == 'category': 102 | mask = mask * class_id 103 | 104 | return image, mask 105 | 106 | def raw(self): 107 | trainset = self._get_split(self.traindata, self.image_id_to_filename, self.image_id_to_class_id, self.images_dir, self.masks_dir) 108 | trainset, valset = get_split_from_list(trainset, split=0.95) 109 | testset = self._get_split(self.testdata, self.image_id_to_filename, self.image_id_to_class_id, self.images_dir, self.masks_dir) 110 | return { 111 | DataType.TRAIN: trainset, 112 | DataType.TEST: testset, 113 | DataType.VAL: valset 114 | } 115 | 116 | 117 | class CUB2002011Binary(CUB2002011): 118 | 119 | def __init__(self, cache_dir): 120 | super(CUB2002011Binary, self).__init__(cache_dir, mode='binary') 121 | 122 | 123 | class CUB2002011Category(CUB2002011): 124 | 125 | def __init__(self, cache_dir): 126 | super(CUB2002011Category, self).__init__(cache_dir, mode='category') 127 | 128 | 129 | if __name__ == "__main__": 130 | from .utils import test_dataset 131 | 132 | ds = CUB2002011Binary('/hdd/datasets/CUB2002011'.lower()) 133 | test_dataset(ds) 134 | 135 | ds = CUB2002011Category('/hdd/datasets/CUB2002011'.lower()) 136 | test_dataset(ds) 137 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/datasets/cvc_clinicdb.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset, DataType 2 | from ..utils import download_and_extract, get_files, ExtractException 3 | from .utils import get_split 4 | 5 | import os 6 | import shutil 7 | import imageio 8 | import numpy as np 9 | 10 | 11 | class CVCClinicDB(Dataset): 12 | """ 13 | This dataset contains segmentations for polyps. 
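A minimal usage sketch (hypothetical cache path; extracting the archive requires rar support): ds = CVCClinicDB('/tmp/cvc_clinicdb'); image, mask = ds.get_random_item(). The mask is a uint8 array with 0 = bg and 1 = polyps.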
14 | 15 | https://polyp.grand-challenge.org/CVCClinicDB/ 16 | """ 17 | DATA_URL = "https://www.dropbox.com/s/p5qe9eotetjnbmq/CVC-ClinicDB.rar?dl=1" 18 | 19 | @property 20 | def labels(self): 21 | return ['bg', 'polyps'] 22 | 23 | def raw(self): 24 | output_path = os.path.join(self.cache_dir, 'data') 25 | try: 26 | extracted = download_and_extract(self.DATA_URL, output_path, file_name='CVC-ClinicDB.rar') 27 | except ExtractException as e: 28 | print(str(e)) 29 | shutil.rmtree(output_path) 30 | 31 | images = get_files(os.path.join(output_path, 'CVC-ClinicDB/Original'), extensions=['tif']) 32 | masks = get_files(os.path.join(output_path, 'CVC-ClinicDB/Ground Truth/'), extensions=['tif']) 33 | dataset = list(zip(images, masks)) 34 | return get_split(dataset) 35 | 36 | def parse_example(self, example): 37 | image_fn, mask_fn = example 38 | image = imageio.imread(image_fn) 39 | mask = (imageio.imread(mask_fn) / 255.).astype(np.uint8) 40 | 41 | return image, mask 42 | 43 | 44 | if __name__ == "__main__": 45 | from .utils import test_dataset 46 | ds = CVCClinicDB('/hdd/datasets/cvc_clinic_db') 47 | test_dataset(ds) 48 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | from ..settings import logger 2 | from ..visualizations import show, masks 3 | 4 | import imageio 5 | import random 6 | import numpy as np 7 | 8 | 9 | class DataType: 10 | TRAIN, TEST, VAL = 'train', 'test', 'val' 11 | 12 | @staticmethod 13 | def get(): 14 | return list(map(lambda x: DataType.__dict__[x], list(filter(lambda k: not k.startswith("__") and type(DataType.__dict__[k]) == str, DataType.__dict__)))) 15 | 16 | 17 | class Dataset(object): 18 | 19 | def __init__(self, cache_dir): 20 | self.cache_dir = cache_dir 21 | 22 | def summary(self): 23 | logger.info("======================%s========================" % self.__class__.__name__) 24 | logger.info("dataset has %d classes" % self.num_classes) 25 | logger.info("labels: %s" % self.labels) 26 | examples = [(data_type, self.num_examples(data_type)) for data_type in [DataType.TRAIN, DataType.VAL, DataType.TEST]] 27 | logger.info("examples: %s" % str(examples)) 28 | logger.info("total: %d" % sum([l for _, l in examples])) 29 | logger.info("================================================") 30 | 31 | @property 32 | def labels(self): 33 | return [] 34 | 35 | @property 36 | def num_classes(self): 37 | return len(self.labels) 38 | 39 | def raw(self): 40 | return { 41 | DataType.TRAIN: [], 42 | DataType.VAL: [], 43 | DataType.TEST: [] 44 | } 45 | 46 | def parse_example(self, example): 47 | return list(map(imageio.imread, example)) 48 | 49 | def num_examples(self, data_type): 50 | return len(self.raw()[data_type]) 51 | 52 | def total_examples(self): 53 | return sum([self.num_examples(dt) for dt in DataType.get()]) 54 | 55 | def get_random_item(self, data_type=DataType.TRAIN): 56 | data = self.raw()[data_type] 57 | example = random.choice(data) 58 | return self.parse_example(example) 59 | 60 | def show_random_item(self, data_type=DataType.TRAIN): 61 | example = self.get_random_item(data_type=data_type) 62 | show.show_images([example[0], example[1].astype(np.float32)]) 63 | 64 | def save_random_item(self, data_type=DataType.TRAIN, image_path='image.png', mask_path='mask.png', mask_mode='rgb', alpha=0.5): 65 | assert(mask_mode in ['gray', 'rgb', 'overlay']), 'mask mode must be in %s' % str(['gray', 'rgb', 'overlay']) 66 | item = 
self.get_random_item(data_type=data_type) 67 | imageio.imwrite(image_path, item[0]) 68 | 69 | mask = item[1] 70 | if mask_mode == 'rgb': 71 | image = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8) 72 | colors = masks.get_colors(self.num_classes) 73 | mask = masks.overlay_classes(image, mask, colors, self.num_classes, alpha=1.0) 74 | 75 | elif mask_mode == 'overlay': 76 | colors = masks.get_colors(self.num_classes) 77 | mask = masks.overlay_classes(item[0], mask, colors, self.num_classes, alpha=alpha) 78 | 79 | elif mask_mode == 'gray': # scale mask to 0 to 255 80 | mask = (mask * (255. / (self.num_classes - 1))).astype(np.uint8) 81 | 82 | imageio.imwrite(mask_path, mask) 83 | 84 | def get(self, data_type=DataType.TRAIN): 85 | data = self.raw()[data_type] 86 | 87 | def gen(): 88 | for example in data: 89 | try: 90 | yield self.parse_example(example) 91 | except: 92 | logger.error("could not read either one of these files %s" % str(example)) 93 | return gen 94 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/datasets/directory.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset, DataType 2 | from ..utils import get_files 3 | from .utils import get_split, load_image 4 | from ..settings import logger 5 | 6 | import os 7 | import tensorflow as tf 8 | import random 9 | 10 | class DirectoryDataset(Dataset): 11 | 12 | def __init__(self, directory, rand=0.2, extensions=['png', 'jpg', 'jpeg']): 13 | super(DirectoryDataset, self).__init__(directory) 14 | 15 | labels_path = os.path.join(directory, 'labels.txt') 16 | 17 | if not os.path.exists(labels_path): 18 | raise FileNotFoundError("Please provide a file containing the labels. 
Cannot find file %s" % labels_path) 19 | 20 | with open(labels_path, 'r') as reader: 21 | self._labels = list(map(lambda x: x.replace("\n", "").strip(), reader.readlines())) 22 | 23 | print("labels: ", self._labels) 24 | 25 | if len(self._labels) < 2: 26 | raise AttributeError("Please provide more than 1 label, only found %s in file %s" % (str(self._labels), labels_path)) 27 | 28 | masks_dir = os.path.join(directory, 'masks') 29 | images_dir = os.path.join(directory, 'images') 30 | 31 | if os.path.exists(os.path.join(directory, 'train')): 32 | logger.info("using train, val, test found in directory") 33 | 34 | self.split = {} 35 | for data_type in [DataType.TRAIN, DataType.TEST, DataType.VAL]: 36 | d_masks_dir = os.path.join(directory, data_type, 'masks') 37 | d_images_dir = os.path.join(directory, data_type, 'images') 38 | 39 | if not os.path.exists(d_masks_dir) or not os.path.exists(d_images_dir): 40 | logger.warning("either %s or %s does not exist, getting 0 examples for data_type %s" % (d_masks_dir, d_images_dir, data_type)) 41 | self.split[data_type] = [] 42 | continue 43 | 44 | masks = get_files(d_masks_dir, extensions=extensions) 45 | images = get_files(d_images_dir, extensions=extensions) 46 | 47 | if len(images) != len(masks): 48 | raise Exception("len(images)=%d (%s) does not equal len(masks)=%d (%s)" % (len(images), d_images_dir, len(masks), d_masks_dir)) 49 | 50 | self.split[data_type] = list(zip(images, masks)) 51 | 52 | elif os.path.exists(masks_dir) and os.path.exists(images_dir): 53 | masks = get_files(masks_dir, extensions=extensions) 54 | if len(masks) == 0: 55 | raise Exception("cannot find any images in masks directory %s" % masks_dir) 56 | 57 | images = get_files(images_dir, extensions=extensions) 58 | if len(images) == 0: 59 | raise Exception("cannot find any pictures in images directory %s" % images_dir) 60 | 61 | if len(images) != len(masks): 62 | raise Exception("len(images)=%d does not equal len(masks)=%d" % (len(images), len(masks))) 63 | 64 | trainset = list(zip(images, masks)) 65 | self.split = get_split(trainset, rand=lambda: rand) 66 | 67 | @property 68 | def labels(self): 69 | return self._labels 70 | 71 | def raw(self): 72 | return self.split 73 | 74 | 75 | def tfdataset(self, data_type=DataType.TRAIN, randomize=False): 76 | 77 | data = self.raw()[data_type] 78 | 79 | if randomize: 80 | random.shuffle(data) 81 | 82 | image_paths = [d[0] for d in data] 83 | mask_paths = [d[1] for d in data] 84 | 85 | assert(len(image_paths) == len(mask_paths)), "len of images does not equal len of masks" 86 | 87 | images_ds = tf.data.Dataset.from_tensor_slices(image_paths).map( 88 | lambda path: load_image(path, tf.uint8), 89 | num_parallel_calls=tf.data.experimental.AUTOTUNE, 90 | ) 91 | masks_ds = tf.data.Dataset.from_tensor_slices(mask_paths).map( 92 | lambda path: load_image(path, tf.uint8, squeeze=True, channels=1), 93 | num_parallel_calls=tf.data.experimental.AUTOTUNE, 94 | ) 95 | nc = self.num_classes 96 | num_classes_ds = tf.data.Dataset.from_tensor_slices([nc for i in range(len(image_paths))]) 97 | 98 | dataset = tf.data.Dataset.zip((images_ds, masks_ds, num_classes_ds)) 99 | return dataset 100 | 101 | if __name__ == "__main__": 102 | from ..processing import dataset 103 | from .utils import convert2tfdataset 104 | from ..visualizations import show 105 | import numpy as np 106 | ds = DirectoryDataset('output') 107 | 108 | tfds = convert2tfdataset(ds, DataType.TRAIN) 109 | fn = dataset.get_preprocess_fn((128, 128), 0, 'resize', True, mode='eager') 110 | tfds = 
tfds.map(fn) 111 | for image, mask in tfds: 112 | show.show_images([image.numpy(), mask.numpy().astype(np.float32)]) 113 | print(image.shape, mask.shape) 114 | break -------------------------------------------------------------------------------- /tf_semantic_segmentation/datasets/isic.py: -------------------------------------------------------------------------------- 1 | from .dataset import DataType, Dataset 2 | from ..utils import extract_zip, get_files 3 | from ..settings import logger 4 | 5 | import os 6 | import imageio 7 | import numpy as np 8 | 9 | 10 | class ISIC2018(Dataset): 11 | 12 | DATASET_URL = "https://challenge.kitware.com/#challenge/5aab46f156357d5e82b00fe5" 13 | 14 | def __init__(self, cache_dir): 15 | super(ISIC2018, self).__init__(cache_dir) 16 | self.extract() 17 | 18 | @property 19 | def labels(self): 20 | return ['bg', 'melona'] 21 | 22 | def extract(self): 23 | filenames = [ 24 | "ISIC2018_Task1-2_Training_Input.zip", 25 | "ISIC2018_Task1_Training_GroundTruth.zip", 26 | "ISIC2018_Task1-2_Validation_Input.zip", 27 | "ISIC2018_Task1-2_Test_Input.zip" 28 | ] 29 | files = [os.path.join(self.cache_dir, f) for f in filenames] 30 | 31 | if not all(map(os.path.exists, files)): 32 | raise FileNotFoundError("Please download the following files %s to %s from %s" % (str(filenames), self.cache_dir, self.DATASET_URL)) 33 | 34 | for f in files: 35 | logger.info("extracting %s, this may take a while" % f) 36 | output_dirname = os.path.splitext(os.path.basename(f))[0] 37 | destination = os.path.join(self.cache_dir, output_dirname) 38 | if not os.path.exists(destination): 39 | extract_zip(f, destination) 40 | else: 41 | logger.info("skipping extraction, because directory already exist") 42 | 43 | def raw(self): 44 | 45 | train_images = get_files(os.path.join(self.cache_dir, 'ISIC2018_Task1-2_Training_Input'), extensions=['jpg']) 46 | train_masks = get_files(os.path.join(self.cache_dir, 'ISIC2018_Task1_Training_GroundTruth'), extensions=['png']) 47 | val_images = get_files(os.path.join(self.cache_dir, 'ISIC2018_Task1-2_Validation_Input'), extensions=['jpg']) 48 | test_images = get_files(os.path.join(self.cache_dir, 'ISIC2018_Task1-2_Test_Input'), extensions=['jpg']) 49 | 50 | return { 51 | DataType.TRAIN: list(zip(train_images, train_masks)), 52 | DataType.VAL: list(zip(val_images, [None] * len(val_images))), 53 | DataType.TEST: list(zip(test_images, [None] * len(test_images))), 54 | } 55 | 56 | def parse_example(self, example): 57 | image_fn, mask_fn = example 58 | image = imageio.imread(image_fn) 59 | 60 | if mask_fn is None: 61 | logger.warning("mask of dataset is None, returning ZEROS") 62 | mask = np.zeros((image.shape[0], image.shape[1]), np.uint8) 63 | else: 64 | mask = (imageio.imread(mask_fn) / 255.).astype(np.uint8) 65 | 66 | return image, mask 67 | 68 | 69 | if __name__ == "__main__": 70 | from .utils import test_dataset 71 | 72 | ds = ISIC2018('/hdd/datasets/isic/') 73 | test_dataset(ds) 74 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/datasets/mapping_challenge.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset, DataType 2 | from ..utils import download_and_extract, get_files 3 | from ..visualizations import masks, show 4 | 5 | import os 6 | import tqdm 7 | import json 8 | import numpy as np 9 | import cv2 10 | import imageio 11 | 12 | 13 | class MappingChallenge(Dataset): 14 | """ 
https://www.crowdai.org/challenges/mapping-challenge/dataset_files """ 15 | 16 | TRAIN_URL = "https://crowdai-prd.s3.eu-central-1.amazonaws.com/dataset_files/challenge_25/8e089a94-555c-4d7b-8f2f-4d733aebb058_train.tar.gz?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAILFF3ZEGG7Y4HXEQ%2F20191218%2Feu-central-1%2Fs3%2Faws4_request&X-Amz-Date=20191218T022310Z&X-Amz-Expires=604800&X-Amz-SignedHeaders=host&X-Amz-Signature=5d49a86529cc2078a6ab4da22e7fe1b9e3b5dff2187fd106190799cb034eaa5c" 17 | VAL_URL = "https://crowdai-prd.s3.eu-central-1.amazonaws.com/dataset_files/challenge_25/0a5c561f-e361-4e9b-a3e2-94f42a003a2b_val.tar.gz?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAILFF3ZEGG7Y4HXEQ%2F20191218%2Feu-central-1%2Fs3%2Faws4_request&X-Amz-Date=20191218T022310Z&X-Amz-Expires=604800&X-Amz-SignedHeaders=host&X-Amz-Signature=76d8cbe79f17cdd673fe44135c3f0c456ae25f89756854b67958441a09d15716" 18 | 19 | def __init__(self, cache_dir): 20 | super(MappingChallenge, self).__init__(cache_dir) 21 | self.already_extracted = False 22 | 23 | @property 24 | def labels(self): 25 | return ['bg', 'building'] 26 | 27 | def read_annotations_extract_masks(self, path, output_dir): 28 | 29 | print("loading annotation data %s" % path) 30 | data = json.load(open(path, 'r')) 31 | 32 | segmentation_data = {} 33 | 34 | categories = {c['id']: c for c in data['categories']} 35 | 36 | # label must start at 1 37 | for i, c in enumerate(data['categories']): 38 | categories[c['id']]['label'] = i + 1 39 | 40 | print(categories) 41 | 42 | for image_info in tqdm.tqdm(data['images'], desc='building image info (annotations)'): 43 | image_id = image_info['id'] 44 | segmentation_data[image_id] = image_info 45 | 46 | for seg_info in tqdm.tqdm(data['annotations'], desc='writing masks'): 47 | info = segmentation_data[seg_info['image_id']] 48 | file_name = info['file_name'] 49 | file_name, ext = os.path.splitext(file_name) 50 | output_path = os.path.join(output_dir, '%s.png' % file_name) 51 | 52 | if os.path.exists(output_path): 53 | continue 54 | 55 | category = categories[seg_info["category_id"]] 56 | 57 | mask = np.zeros((info['height'], info['width']), np.uint8) 58 | for poly in seg_info['segmentation']: 59 | # each polygon is a flat [x1, y1, x2, y2, ...] coordinate list 60 | cnt = np.asarray(poly).reshape((-1, 2)).astype(np.int32) 61 | mask = cv2.drawContours(mask, [cnt], 0, int(category['label']), cv2.FILLED, cv2.LINE_AA) 62 | 63 | # persist the rendered mask for this image 64 | 65 | imageio.imwrite(output_path, mask) 66 | 67 | def raw(self): 68 | 69 | train_dir = os.path.join(self.cache_dir, 'train') 70 | val_dir = os.path.join(self.cache_dir, 'val') 71 | 72 | val_masks_dir = os.path.join(val_dir, 'masks') 73 | train_masks_dir = os.path.join(train_dir, 'masks') 74 | 75 | if not self.already_extracted: 76 | 77 | train_dir = download_and_extract(self.TRAIN_URL, train_dir, file_name='train.tar.gz') 78 | val_dir = download_and_extract(self.VAL_URL, val_dir, file_name='val.tar.gz') 79 | 80 | os.makedirs(train_masks_dir, exist_ok=True) 81 | os.makedirs(val_masks_dir, exist_ok=True) 82 | 83 | train_annotations_path = os.path.join(train_dir, 'train', 'annotation.json') 84 | val_annotations_path = os.path.join(val_dir, 'val', 'annotation.json') 85 | 86 | self.read_annotations_extract_masks(train_annotations_path, train_masks_dir) 87 | self.read_annotations_extract_masks(val_annotations_path, val_masks_dir) 88 | self.already_extracted = True 89 | 90 | train_images = get_files(train_dir, extensions=['jpg']) 91 | val_images = 
get_files(val_dir, extensions=['jpg']) 92 | 93 | train_masks = get_files(train_masks_dir, extensions=['png']) 94 | val_masks = get_files(val_masks_dir, extensions=['png']) 95 | 96 | return { 97 | DataType.TRAIN: list(zip(train_images, train_masks)), 98 | DataType.VAL: list(zip(val_images, val_masks)), 99 | DataType.TEST: [] 100 | } 101 | 102 | 103 | if __name__ == "__main__": 104 | from ..visualizations import show 105 | for image, mask in MappingChallenge('/hdd/datasets/mappingchallenge').get(DataType.TRAIN)(): 106 | show.show_images([image, mask.astype(np.float32)]) 107 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/datasets/mots_challenge.py: -------------------------------------------------------------------------------- 1 | from ..utils import download_and_extract, extract_zip 2 | from .utils import get_split_from_dirs 3 | from .dataset import Dataset, DataType 4 | import os 5 | import imageio 6 | import numpy as np 7 | 8 | class MotsChallenge(Dataset): 9 | """ https://www.vision.rwth-aachen.de/page/mots """ 10 | IMAGE_URL = "https://motchallenge.net/data/MOT17.zip" 11 | ANNOTATIONS_URL = "https://www.vision.rwth-aachen.de/media/resource_files/instances_motschallenge.zip" 12 | 13 | def __init__(self, cache_dir): 14 | self.cache_dir = cache_dir 15 | 16 | @property 17 | def labels(self): 18 | return ['bg', 'car', 'pedestrian'] 19 | 20 | def raw(self): 21 | images_dir = download_and_extract( 22 | self.IMAGE_URL, os.path.join(self.cache_dir, "images")) 23 | 24 | annotations_dir = download_and_extract( 25 | self.ANNOTATIONS_URL, os.path.join(self.cache_dir, 'annotations')) 26 | 27 | return get_split_from_dirs(images_dir, annotations_dir) 28 | 29 | def parse_example(self, example): 30 | image_path, target_path = example 31 | i = imageio.imread(image_path) 32 | t = imageio.imread(target_path) 33 | 34 | obj_ids = np.unique(t) 35 | # MOTS encodes ids as class_id * 1000 + instance_id, e.g. 2005 -> class 2 ('pedestrian'), instance 5 36 | 37 | mask = np.zeros(i.shape[:2], np.uint8) 38 | for obj_id in obj_ids: 39 | class_id = obj_id // 1000 40 | 41 | idxs = np.where(np.all(t == obj_id, axis=-1)) 42 | mask[idxs] = class_id 43 | 44 | return i, mask 45 | 46 | 47 | if __name__ == "__main__": 48 | from ..visualizations import show 49 | import numpy as np 50 | 51 | mots = MotsChallenge('/hdd/datasets/mots/') 52 | gen = mots.get() 53 | for image, target in gen(): 54 | print(np.unique(target)) 55 | show.show_images([image, target.astype(np.float32)]) 56 | break 57 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/datasets/shapes.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import imageio 3 | import os 4 | import numpy as np 5 | import cv2 6 | import random 7 | import shutil 8 | 9 | from .dataset import Dataset 10 | from .utils import DataType, get_split_from_list, get_split 11 | 12 | from ..visualizations import show 13 | from ..utils import download_and_extract, get_files 14 | from ..processing import ColorMode 15 | from ..settings import logger 16 | from ..threading import parallize_v3 17 | 18 | 19 | class ShapesDS(Dataset): 20 | 21 | SHAPES = ['rectangle', 'triangle', 'circle'] 22 | 23 | def __init__(self, cache_dir, num_examples=10000, size=(512, 512), overwrite=False, color_mode=ColorMode.RGB, max_shapes_per_example=3): 24 | super(ShapesDS, self).__init__(cache_dir) 25 | self._num_examples = num_examples 26 | self.images_dir = os.path.join(self.cache_dir, 'images') 27 | self.masks_dir =
os.path.join(self.cache_dir, 'masks') 28 | self.max_shapes_per_example = max_shapes_per_example 29 | self.overwrite = overwrite 30 | self.color_mode = color_mode 31 | self.size = size 32 | 33 | os.makedirs(self.masks_dir, exist_ok=True) 34 | os.makedirs(self.images_dir, exist_ok=True) 35 | 36 | self.trainset = parallize_v3(self.create_example, list(range(self._num_examples)), desc='creating shapes dataset') 37 | 38 | @property 39 | def labels(self): 40 | l = ['bg'] 41 | l.extend(self.SHAPES) 42 | return l 43 | 44 | def draw_shape(self, image, shape, x, y, w, h, color): 45 | assert(shape in self.SHAPES) 46 | if shape == "rectangle": 47 | cv2.rectangle(image, (x, y), (x + w, y + h), color, thickness=cv2.FILLED) 48 | elif shape == "circle": 49 | cv2.circle(image, (x, y), w, color, thickness=cv2.FILLED) 50 | elif shape == 'triangle': 51 | triangle_cnt = np.asarray([[x, y], [x + w, y], [x + w // 2, y + h]], np.int32) 52 | cv2.drawContours(image, [triangle_cnt], 0, color, thickness=cv2.FILLED) 53 | return image 54 | 55 | def get_random_color(self): 56 | if self.color_mode == ColorMode.RGB: 57 | return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) 58 | elif self.color_mode == ColorMode.GRAY: 59 | return (random.randint(0, 255)) 60 | 61 | def draw_shapes(self, image, shapes): 62 | mask = np.zeros((self.size[1], self.size[0]), np.uint8) 63 | for shape in shapes: 64 | color = self.get_random_color() 65 | qx = self.size[0] // 4 66 | qy = self.size[1] // 4 67 | x = random.randint(qx, self.size[0] - qx) 68 | y = random.randint(qy, self.size[1] - qy) 69 | w = random.randint(0, self.size[0] // 2) 70 | h = random.randint(0, self.size[1] // 2) 71 | image = self.draw_shape(image, shape, x, y, w, h, color) 72 | mask = self.draw_shape(mask, shape, x, y, w, h, self.SHAPES.index(shape) + 1) 73 | return image, mask 74 | 75 | def create_example(self, i): 76 | mask_path = os.path.join(self.masks_dir, "%d.png" % i) 77 | image_path = os.path.join(self.images_dir, "%d.png" % i) 78 | 79 | if not self.overwrite and os.path.exists(mask_path) and os.path.exists(image_path): 80 | return image_path, mask_path 81 | 82 | num_shapes = random.randint(0, self.max_shapes_per_example) 83 | random_shapes = [random.choice(self.SHAPES) for _ in range(num_shapes)] 84 | 85 | if self.color_mode == ColorMode.RGB: 86 | image = np.zeros((self.size[1], self.size[0], 3), np.uint8) 87 | 88 | elif self.color_mode == ColorMode.GRAY: 89 | image = np.zeros((self.size[1], self.size[0], 1), np.uint8) 90 | else: 91 | raise Exception("unknown color mode %s" % self.color_mode.name) 92 | 93 | image, mask = self.draw_shapes(image, random_shapes) 94 | imageio.imwrite(image_path, image) 95 | imageio.imwrite(mask_path, mask) 96 | return image_path, mask_path 97 | 98 | def raw(self): 99 | return get_split(self.trainset) 100 | 101 | 102 | class ShapesDSMini(ShapesDS): 103 | 104 | SHAPES = ['circle'] 105 | 106 | def __init__(self, cache_dir, overwrite=False, color_mode=ColorMode.GRAY): 107 | super(ShapesDSMini, self).__init__(cache_dir, num_examples=100, size=(32, 32), overwrite=overwrite, color_mode=color_mode, 108 | max_shapes_per_example=3) 109 | 110 | 111 | if __name__ == "__main__": 112 | 113 | ds = ShapesDS('/hdd/datasets/shapes', 1000) 114 | for image_path, mask_path in ds.raw()[DataType.TRAIN]: 115 | show.show_images([imageio.imread(image_path), imageio.imread(mask_path).astype(np.float32)]) 116 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/datasets/taco.py:
-------------------------------------------------------------------------------- 1 | from .dataset import Dataset, DataType 2 | from .utils import get_split 3 | from ..utils import download_file, get_files 4 | from .ms_coco import CocoAnnotationReader 5 | import json 6 | import os 7 | import tqdm 8 | 9 | 10 | class Taco(Dataset): 11 | 12 | ANNOTATIONS_URL = "https://raw.githubusercontent.com/pedropro/TACO/master/data/annotations.json" 13 | 14 | """ Super Categories: 15 | 16 | {'Scrap metal', 'Plastic utensils', 'Other plastic', 'Lid', 'Shoe', 'Plastic container', 'Bottle cap', 'Blister pack', 17 | 'Aluminium foil', 'Broken glass', 'Carton', 'Can', 'Styrofoam piece', 'Plastic glooves', 'Glass jar', 'Paper bag', 18 | 'Squeezable tube', 'Paper', 'Rope & strings', 'Bottle', 'Cup', 'Food waste', 19 | 'Pop tab', 'Straw', 'Cigarette', 'Unlabeled litter', 'Plastic bag & wrapper', 'Battery'} 20 | """ 21 | MODES = ['supercategory', 'category', 'binary'] 22 | 23 | def __init__(self, cache_dir, mode='supercategory'): 24 | super(Taco, self).__init__(cache_dir) 25 | self.annotations_file = download_file(self.ANNOTATIONS_URL, self.cache_dir) 26 | self.ann_reader = CocoAnnotationReader(self.annotations_file) 27 | self.mode = mode 28 | self.split = self.generate() 29 | 30 | def generate(self): 31 | data_dir = os.path.join(self.cache_dir, 'dataset') 32 | masks_dir = os.path.join(self.cache_dir, 'masks', self.mode) 33 | self.download_files(self.annotations_file, data_dir) 34 | 35 | anns = self.ann_reader.read_annotations() 36 | 37 | output_paths = [] 38 | input_paths = get_files(data_dir, extensions=['jpg']) 39 | for image_path in tqdm.tqdm(input_paths, desc='generating masks'): 40 | output_paths.append(self.ann_reader.generate_masks(image_path, anns, masks_dir, data_dir=data_dir, mode=self.mode)) 41 | 42 | trainset = list(zip(input_paths, output_paths)) 43 | assert(len(trainset) != 0) 44 | 45 | return get_split(trainset) 46 | 47 | @property 48 | def labels(self): 49 | return self.ann_reader.get_labels(self.mode) 50 | 51 | def download_files(self, path, dataset_dir): 52 | with open(path, 'r') as f: 53 | annotations = json.loads(f.read()) 54 | 55 | nr_images = len(annotations['images']) 56 | for i in tqdm.trange(nr_images): 57 | 58 | image = annotations['images'][i] 59 | 60 | file_name = image['file_name'] 61 | url_original = image['flickr_url'] 62 | url_resized = image['flickr_640_url'] 63 | 64 | file_path = os.path.join(dataset_dir, file_name) 65 | 66 | # Create subdir if necessary 67 | subdir = os.path.dirname(file_path) 68 | 69 | if not os.path.isdir(subdir): 70 | os.makedirs(subdir) 71 | 72 | if not os.path.isfile(file_path): 73 | # Load and Save Image 74 | download_file(url_original, os.path.dirname(file_path), file_name=os.path.basename(file_path)) 75 | 76 | def raw(self): 77 | return self.split 78 | 79 | 80 | class TacoBinary(Taco): 81 | 82 | def __init__(self, cache_dir): 83 | super(TacoBinary, self).__init__(cache_dir, mode='binary') 84 | 85 | 86 | class TacoCategory(Taco): 87 | 88 | def __init__(self, cache_dir): 89 | super(TacoCategory, self).__init__(cache_dir, mode='category') 90 | 91 | 92 | class TacoSuperCategory(Taco): 93 | 94 | def __init__(self, cache_dir): 95 | super(TacoSuperCategory, self).__init__(cache_dir, mode='supercategory') 96 | 97 | 98 | if __name__ == "__main__": 99 | import imageio 100 | t = Taco('/hdd/datasets/taco', mode='supercategory') 101 | t.summary() 102 | #for image, labels in t.get()(): 103 | # imageio.imwrite("test.png", labels) 104 | # break 105 | 
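# Minimal usage sketch (illustrative, not part of the module; the cache path
# below is an example). generate() runs on construction, so the first call
# downloads the annotations and images and renders the masks:
#
#   from tf_semantic_segmentation.datasets.taco import TacoBinary
#   from tf_semantic_segmentation.datasets.dataset import DataType
#
#   ds = TacoBinary('/hdd/datasets/taco')
#   train_pairs = ds.raw()[DataType.TRAIN]  # list of (image_path, mask_path)
#   print(len(train_pairs), ds.labels)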
-------------------------------------------------------------------------------- /tf_semantic_segmentation/debug/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baudcode/tf-semantic-segmentation/6e425287f309f86b931e34bbbb804bbc46963e28/tf_semantic_segmentation/debug/__init__.py -------------------------------------------------------------------------------- /tf_semantic_segmentation/debug/dataset_export.py: -------------------------------------------------------------------------------- 1 | from ..datasets import datasets_by_name, get_dataset_by_name, get_cache_dir, DataType 2 | from ..datasets.utils import convert2tfdataset 3 | from ..processing.dataset import ColorMode, resize_and_change_color 4 | from ..datasets.tfrecord import TFWriter 5 | from ..settings import logger 6 | 7 | import tensorflow as tf 8 | import argparse 9 | import os 10 | import imageio 11 | import tqdm 12 | import multiprocessing 13 | 14 | 15 | def write_labels(path, labels): 16 | with open(path, 'w') as writer: 17 | for label in labels: 18 | writer.write(label.strip() + "\n") 19 | 20 | 21 | def read_labels(path): 22 | with open(path, 'r') as reader: 23 | lines = reader.readlines() 24 | labels = list(map(lambda x: x[:-1], lines)) 25 | return labels 26 | 27 | 28 | def export(ds, output_dir, size=None, resize_method="resize_with_pad", color_mode=ColorMode.NONE, overwrite=False, batch_size=4): 29 | 30 | os.makedirs(output_dir, exist_ok=True) 31 | write_labels(os.path.join(output_dir, 'labels.txt'), ds.labels) 32 | 33 | def preprocess_fn(image, mask, num_classes): 34 | image = tf.image.convert_image_dtype(image, tf.float32) 35 | image, mask = resize_and_change_color(image, mask, size, color_mode, resize_method) 36 | image = tf.image.convert_image_dtype(image, tf.uint8) 37 | return image, mask 38 | 39 | for data_type in DataType.get(): 40 | masks_dir = os.path.join(output_dir, data_type, 'masks') 41 | images_dir = os.path.join(output_dir, data_type, 'images') 42 | 43 | os.makedirs(masks_dir, exist_ok=True) 44 | os.makedirs(images_dir, exist_ok=True) 45 | 46 | tfds = convert2tfdataset(ds, data_type) 47 | tfds = tfds.map(preprocess_fn, num_parallel_calls=multiprocessing.cpu_count()) 48 | tfds = tfds.batch(batch_size) 49 | 50 | total_examples = ds.num_examples(data_type) 51 | for k, (images, masks) in tqdm.tqdm(enumerate(tfds), total=total_examples // batch_size, desc="exporting", 52 | postfix=dict(data_type=data_type, batch_size=batch_size, total=total_examples)): 53 | for g in range(len(images)): 54 | i = k * batch_size + g 55 | image = images[g] 56 | mask = masks[g] 57 | 58 | image_path = os.path.join(images_dir, '%d.jpg' % i) 59 | mask_path = os.path.join(masks_dir, '%d.png' % i) 60 | if (not os.path.exists(image_path) or not os.path.exists(mask_path)) or overwrite: 61 | imageio.imwrite(image_path, image) 62 | imageio.imwrite(mask_path, mask) 63 | 64 | 65 | def main(): 66 | 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument('-d', '--dataset', choices=list(datasets_by_name.keys()), required=True) 69 | parser.add_argument('-c', '--data_dir', required=True) 70 | parser.add_argument('-o', '--output_dir', required=True) 71 | parser.add_argument('-bs', '--batch_size', type=int, default=4) 72 | parser.add_argument('-s', '--size', default=None, type=lambda x: list(map(int, x.split(','))), help='height,width') 73 | parser.add_argument('-rm', '--resize_method', default='resize_with_pad') 74 | parser.add_argument('-cm', '--color_mode',
default=ColorMode.NONE, type=int) 75 | parser.add_argument('-overwrite', '--overwrite', action='store_true') 76 | 77 | args = parser.parse_args() 78 | 79 | cache_dir = get_cache_dir(args.data_dir, args.dataset.lower()) 80 | ds = get_dataset_by_name(args.dataset, cache_dir) 81 | 82 | logger.info('writing dataset to %s' % args.output_dir) 83 | export(ds, args.output_dir, args.size, args.resize_method, args.color_mode, args.overwrite, batch_size=args.batch_size) 84 | 85 | 86 | if __name__ == "__main__": 87 | main() 88 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/debug/dataset_vis.py: -------------------------------------------------------------------------------- 1 | from ..datasets import datasets_by_name, get_dataset_by_name, DataType, get_cache_dir 2 | from ..visualizations import masks, show 3 | from ..processing.image import fixed_resize 4 | import os 5 | import argparse 6 | import numpy as np 7 | 8 | if __name__ == "__main__": 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('-d', '--dataset', choices=list(datasets_by_name.keys()), required=True) 12 | parser.add_argument('-data_dir', '--data_dir', default='/hdd/datasets/') 13 | parser.add_argument('-t', '--data_type', default=DataType.TRAIN, choices=DataType.get()) 14 | args = parser.parse_args() 15 | 16 | cache_dir = get_cache_dir(args.data_dir, args.dataset.lower()) 17 | ds = get_dataset_by_name(args.dataset, cache_dir) 18 | 19 | # download and cache data 20 | ds.raw() 21 | 22 | ds.summary() 23 | 24 | colors = masks.get_colors(ds.num_classes) 25 | print("using colors: %s" % str(colors)) 26 | 27 | for image, target in ds.get(DataType.TRAIN)(): 28 | 29 | image = fixed_resize(image, width=500) 30 | labels = np.unique(target) 31 | print("labels: ", labels) 32 | target = fixed_resize(target, width=500) 33 | 34 | labels = np.unique(target) 35 | print("labels resize: ", labels) 36 | print('============================') 37 | # print("counts: ", counts) 38 | # print('target:', target.dtype, target.max(), target.min()) 39 | # target_3d = np.zeros((image.shape[0], image.shape[1], 3), np.uint8) 40 | overlay_on_white = masks.overlay_classes(np.ones_like(image) * 255, target, colors, ds.num_classes, alpha=1.0) 41 | overlay = masks.overlay_classes(image.copy(), target, colors, ds.num_classes) 42 | # target = masks.apply_mask() 43 | show.show_images([image, target.astype(np.float32), overlay, overlay_on_white])
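# Illustrative invocation (module-style; the dataset name and data dir are
# examples): downloads/caches the chosen dataset, prints a summary and shows
# each example with its mask overlaid.
#
#   python -m tf_semantic_segmentation.debug.dataset_vis -d shapes -data_dir /hdd/datasets/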
-------------------------------------------------------------------------------- /tf_semantic_segmentation/debug/devices.py: -------------------------------------------------------------------------------- 1 | from tensorflow.python.client import device_lib 2 | from pprint import pprint 3 | 4 | 5 | def get_available_gpus(): 6 | local_device_protos = device_lib.list_local_devices() 7 | return [(x.name, x.memory_limit / 1000. / 1000.) for x in local_device_protos if x.device_type == 'GPU'] 8 | 9 | 10 | if __name__ == "__main__": 11 | pprint(get_available_gpus()) 12 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/debug/export_saved_model.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.models import load_model 2 | 3 | 4 | import argparse 5 | import os 6 | 7 | 8 | def export_saved_model(logdir, model_name='model-best.h5'): 9 | model_path = os.path.join(logdir, model_name) 10 | model = load_model(model_path, compile=False) 11 | 12 | saved_model_path = os.path.join(logdir, 'saved_model', '0') 13 | model.save(saved_model_path, save_format='tf') 14 | 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('-l', '--logdir', required=True) 19 | parser.add_argument('-m', '--model_name', default='model-best.h5') 20 | args = parser.parse_args() 21 | 22 | export_saved_model(args.logdir, args.model_name) 23 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/debug/model_parameters.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | from ..models import models_by_name 5 | from ..settings import logger 6 | 7 | from tensorflow.keras import backend as K 8 | import tensorflow as tf 9 | import numpy as np 10 | import logging 11 | import argparse 12 | 13 | 14 | def counts(model): 15 | trainable_count = int( 16 | np.sum([K.count_params(p) for p in model.trainable_weights])) 17 | 18 | non_trainable_count = int( 19 | np.sum([K.count_params(p) for p in model.non_trainable_weights])) 20 | 21 | return (trainable_count, non_trainable_count) 22 | 23 | 24 | if __name__ == "__main__": 25 | 26 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # noqa 27 | 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('-s', '--shape', default=[256, 256, 3], help='input shape default: (256, 256, 3)', type=lambda x: list(map(int, x.split(",")))) 30 | parser.add_argument('-nc', '--num_classes', default=10, type=int, help='number of classes default: 10') 31 | args = parser.parse_args() 32 | 33 | logger.setLevel(logging.INFO) 34 | num_classes = args.num_classes 35 | input_shape = tuple(args.shape) 36 | 37 | logger.info("using input shape %s" % str(input_shape)) 38 | logger.info("using %d classes" % num_classes) 39 | logger.info("==================") 40 | 41 | models = {} 42 | 43 | 44 | 45 | for name, model_fn in list(models_by_name.items()): 46 | model, _ = model_fn(input_shape=input_shape, num_classes=num_classes) 47 | models[name] = counts(model) 48 | logger.info("%s has %d trainable and %d non trainable parameters" % (name, models[name][0], models[name][1])) 49 | 50 | K.clear_session() 51 | 52 | print("=" * 20, "summary", "=" * 20) 53 | max_length = max([len(name) for name in models.keys()]) 54 | 55 | row_format = "{:>" + str(max_length) + "} {:>20} {:>20}" 56 | print(row_format.format("name", "trainable params", "non trainable params")) 57 | 58 | row_format = "{:>" + str(max_length) + "} {:>20,} {:>20,}" 59 | for name, (params, non_params) in models.items(): 60 | print(row_format.format(name, params, non_params)) 61 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/debug/preprocessing_vis.py:
-------------------------------------------------------------------------------- 1 | from ..datasets import get_dataset_by_name, DataType 2 | from ..datasets.utils import convert2tfdataset 3 | from ..processing import dataset as ds_preprocessing 4 | from ..processing import ColorMode 5 | from ..visualizations.show import show_images 6 | import numpy as np 7 | 8 | if __name__ == "__main__": 9 | ds = 'bioimagebenchmark' 10 | ds = get_dataset_by_name(ds, '/hdd/datasets/%s' % ds) 11 | tfds = convert2tfdataset(ds, DataType.TRAIN) 12 | 13 | preprocess_fn = ds_preprocessing.get_preprocess_fn((256, 256), ColorMode.RGB, 'patch', scale_mask=True) 14 | tdds = tfds.map(preprocess_fn) 15 | for inputs, targets in tdds: 16 | show_images([inputs, targets]) 17 | break 18 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/debug/record_vis.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from ..visualizations import show 3 | from ..datasets.tfrecord import TFReader 4 | from ..processing import dataset as ds_preprocessing 5 | from ..datasets import DataType 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | 10 | def get_args(): 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('-r', '--record_dir', required=True) 14 | parser.add_argument('-t', '--take', default=8, type=int) 15 | return parser.parse_args() 16 | 17 | 18 | def main(): 19 | 20 | args = get_args() 21 | reader = TFReader(args.record_dir) 22 | dataset = reader.get_dataset(DataType.TRAIN) 23 | dataset = dataset.shuffle(50) 24 | dataset = dataset.take(args.take) 25 | print("num_classes:", reader.num_classes) 26 | print("input shape: ", reader.input_shape) 27 | 28 | for image, _target, num_classes in dataset: 29 | show.show_images([image.numpy(), _target.numpy().astype(np.float32)], titles=['input', 'target']) 30 | 31 | 32 | if __name__ == "__main__": 33 | main() 34 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/debug/tflite_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import time 4 | from ..settings import logger 5 | 6 | 7 | class TFLiteInterpreter(): 8 | 9 | def __init__(self, tflite_model_path): 10 | self.interpreter = tf.lite.Interpreter(model_path=str(tflite_model_path)) 11 | self.interpreter.allocate_tensors() 12 | 13 | self.input_index = self.interpreter.get_input_details()[0]["index"] 14 | self.output_index = self.interpreter.get_output_details()[0]["index"] 15 | 16 | @property 17 | def output_shape(self): 18 | return self.interpreter.get_output_details()[0]['shape'] 19 | 20 | @property 21 | def input_shape(self): 22 | return self.interpreter.get_input_details()[0]['shape'] 23 | 24 | def predict(self, image): 25 | 26 | test_image = np.expand_dims(image, axis=0) 27 | 28 | start = time.time() 29 | self.interpreter.set_tensor(self.input_index, test_image) 30 | self.interpreter.invoke() 31 | output = self.interpreter.get_tensor(self.output_index)[0] 32 | logger.debug("Inference took %.3f seconds" % (time.time() - start)) 33 | return output 34 | 35 | def predict_on_batch(self, batch): 36 | 37 | start = time.time() 38 | self.interpreter.set_tensor(self.input_index, batch) 39 | self.interpreter.invoke() 40 | 41 | output = self.interpreter.get_tensor(self.output_index) 42 | logger.debug("Inference took %.3f seconds" % (time.time() - start)) 43 | 
return output 44 | 45 | 46 | if __name__ == "__main__": 47 | 48 | import argparse 49 | import imageio 50 | 51 | from ..processing.dataset import resize_and_change_color 52 | from ..visualizations import show, masks 53 | 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument('-i', '--image', help='input image path', required=True) 56 | parser.add_argument('-m', '--tflite_model_path', help='path to the tflite model', required=True) 57 | parser.add_argument('-rm', '--resize_method', default='resize', help='method for resizing inputs') 58 | parser.add_argument('-thresh', '--binary_threshold', default=0.5, type=float, help='binary threshold') 59 | args = parser.parse_args() 60 | 61 | interpreter = TFLiteInterpreter(args.tflite_model_path) 62 | logger.info("input shape: %s, output shape: %s" % (str(interpreter.input_shape), str(interpreter.output_shape))) 63 | 64 | # define model parameters 65 | size = interpreter.input_shape[1:3] 66 | color_mode = 0 if interpreter.input_shape[-1] == 3 else 1 67 | resize_method = args.resize_method 68 | scale_mask = interpreter.output_shape[-1] == 1 69 | num_classes = 2 if interpreter.output_shape[-1] == 1 else interpreter.output_shape[-1] 70 | 71 | image = imageio.imread(args.image) 72 | 73 | # scale between 0 and 1 74 | image = tf.image.convert_image_dtype(image, tf.float32) 75 | 76 | # the resize method creates a float32 image anyway 77 | image, _ = resize_and_change_color(image, None, size, color_mode, resize_method=resize_method) 78 | 79 | batch = tf.expand_dims(image, axis=0) 80 | predictions = interpreter.predict_on_batch(batch) 81 | seg_masks = masks.get_colored_segmentation_mask(predictions, num_classes, images=batch.numpy(), binary_threshold=args.binary_threshold) 82 | 83 | show.show_images(seg_masks) 84 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baudcode/tf-semantic-segmentation/6e425287f309f86b931e34bbbb804bbc46963e28/tf_semantic_segmentation/evaluation/__init__.py -------------------------------------------------------------------------------- /tf_semantic_segmentation/evaluation/compare_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from pprint import pprint 4 | import argparse 5 | import tensorflow as tf 6 | 7 | from ..settings import logger 8 | from ..metrics import get_metric_by_name 9 | from ..datasets import get_dataset_by_name, get_cache_dir 10 | from ..datasets.utils import convert2tfdataset, DataType 11 | from ..processing.dataset import get_preprocess_fn, ColorMode 12 | from ..visualizations import show 13 | from ..serving import predict, predict_on_batch, get_models_from_directory, ensemble_inference, retrieve_metadata 14 | 15 | 16 | def get_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('-i', '--models_dir', required=True, help='path to dir containing multiple trained models') 19 | parser.add_argument('-data', '--data_dir', required=True, help='data directory') 20 | parser.add_argument('-rand', '--randomize', action='store_true', help='randomize the dataset examples order') 21 | parser.add_argument('-c', '--contains', default=None, help='model name must contain this value') 22 | parser.add_argument('-d', '--dataset', required=True, help='dataset name') 23 | parser.add_argument('-dt', '--data_type', default=DataType.VAL,
choices=DataType.get()) 24 | parser.add_argument("-rm", '--resize_method', default='resize', help='resize method') 25 | parser.add_argument("-host", '--host', default='localhost', help='tf model server host') 26 | parser.add_argument("-p", '--port', default=8501, type=int, help='tf model server port') 27 | parser.add_argument("-npr", '--num_per_row', default=4, type=int, help='number of images per row') 28 | parser.add_argument("-t", '--threshold', default=0.5, type=float, help='binary threshold') 29 | parser.add_argument("-m", '--metric', default='iou_score', type=str, help='metric to evaluate the models with') 30 | return parser.parse_args() 31 | 32 | 33 | def compare_models(models, image, mask, num_classes, host='localhost', port=8501, threshold=0.5, metric='iou_score', num_per_row=4): 34 | 35 | metric = get_metric_by_name(metric) 36 | 37 | def get_score(mask, p): 38 | p = tf.cast(p, tf.int32) 39 | p = tf.expand_dims(p, axis=0) 40 | 41 | mask = tf.expand_dims(mask, axis=0) 42 | mask = tf.expand_dims(mask, axis=-1) 43 | return metric(mask, p) 44 | 45 | ensemble, predictions = ensemble_inference(models, image, host=host, port=port, threshold=threshold) 46 | model_scores = [get_score(mask, p) for p in predictions] 47 | model_titles = ["%s (score: %.3f)" % (m['name'], model_scores[k]) for k, m in enumerate(models)] 48 | 49 | ensemble_score = get_score(mask, ensemble.astype(np.uint8)) 50 | 51 | titles = ['input'] + model_titles + ['ensemble (score: %.3f)' % ensemble_score] + ['target'] 52 | images = [image] + predictions + [ensemble] + [mask.astype(np.float32)] 53 | show.show_images(images, titles=titles, cols=len(titles) // num_per_row) 54 | 55 | 56 | def main(): 57 | 58 | args = get_args() 59 | models = get_models_from_directory(args.models_dir, contains=args.contains) 60 | 61 | logger.info("=============") 62 | logger.info("Found models:") 63 | pprint(models) 64 | 65 | try: 66 | meta = retrieve_metadata(models[0]['name']) 67 | pprint(meta) 68 | except Exception: 69 | logger.info("Please start the tensorflow model server using `tensorflow_model_server --model_config_file=models.yaml --rest_api_port=%d`" % args.port) 70 | exit(1) 71 | 72 | input_shape = meta['inputs'][list(meta['inputs'].keys())[0]]['shape'] 73 | output_shape = meta['outputs'][list(meta['outputs'].keys())[0]]['shape'] 74 | 75 | logger.info("input shape: %s" % str(input_shape)) 76 | logger.info("output shape: %s" % str(output_shape)) 77 | 78 | # infer from retrieved meta data 79 | size = tuple(input_shape[1:3]) 80 | scale_mask = output_shape[-1] == 1 81 | color_mode = 0 if input_shape[-1] == 3 else 1 82 | num_classes = 2 if output_shape[-1] == 1 else output_shape[-1] 83 | 84 | cache_dir = get_cache_dir(args.data_dir, args.dataset) 85 | ds = get_dataset_by_name(args.dataset, cache_dir) 86 | 87 | ds = convert2tfdataset(ds, args.data_type, randomize=args.randomize) 88 | ds = ds.map(get_preprocess_fn(size, color_mode, args.resize_method, scale_mask=scale_mask)) 89 | 90 | for image, mask in ds: 91 | compare_models(models, image.numpy(), mask.numpy(), num_classes, host=args.host, port=args.port, threshold=args.threshold, 92 | metric=args.metric, num_per_row=args.num_per_row) 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 |
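# compare_models expects a running TensorFlow Serving instance that serves the
# exported models over REST. A hedged sketch of starting one (flags mirror the
# hint logged above; the config file path is an example):
#
#   tensorflow_model_server --model_config_file=models.yaml --rest_api_port=8501
#
# where the config lists one entry per exported saved_model directory.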
-------------------------------------------------------------------------------- /tf_semantic_segmentation/evaluation/video.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageFont, ImageDraw 2 | from timeit import default_timer as timer 3 | 4 | from ..visualizations import masks 5 | from ..processing import dataset as pre_dataset 6 | import cv2 7 | import numpy as np 8 | import tqdm 9 | 10 | 11 | def predict_video(model, video_path, stream=True, output_path=None, resize_method='resize_with_pad'): 12 | 13 | size = tuple(model.input.shape[1:3]) 14 | depth = model.input.shape[-1] 15 | color_mode = pre_dataset.ColorMode.GRAY if depth == 1 else pre_dataset.ColorMode.RGB 16 | 17 | vid = cv2.VideoCapture(video_path) 18 | if not vid.isOpened(): 19 | raise IOError("Couldn't open webcam or video") 20 | 21 | if output_path: 22 | video_FourCC = cv2.VideoWriter_fourcc(*"mp4v") 23 | video_fps = vid.get(cv2.CAP_PROP_FPS) 24 | video_size = (int(vid.get(cv2.CAP_PROP_FRAME_WIDTH)), 25 | int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT))) 26 | 27 | out = cv2.VideoWriter(output_path, video_FourCC, video_fps, video_size) 28 | 29 | accum_time = 0 30 | curr_fps = 0 31 | fps = "FPS: ??" 32 | prev_time = timer() 33 | with tqdm.tqdm(desc="Predict on Video") as tq: 34 | 35 | return_value, frame = vid.read() 36 | while return_value: 37 | # read the frame 38 | image = frame / 255. 39 | image, _ = pre_dataset.resize_and_change_color(image, None, size, color_mode, resize_method=resize_method) 40 | # get detections 41 | images = np.expand_dims(image, axis=0) 42 | p = model.predict(images) 43 | 44 | num_classes = p.shape[-1] if p.shape[-1] > 1 else 2 45 | result = masks.get_colored_segmentation_mask(p, num_classes, images=images)[0] 46 | 47 | # calc fps 48 | curr_time = timer() 49 | exec_time = curr_time - prev_time 50 | prev_time = curr_time 51 | accum_time = accum_time + exec_time 52 | curr_fps = curr_fps + 1 53 | 54 | if accum_time > 1: 55 | accum_time = accum_time - 1 56 | fps = "FPS: " + str(curr_fps) 57 | curr_fps = 0 58 | 59 | if output_path: 60 | out.write(result) 61 | 62 | if stream: 63 | # put fps 64 | cv2.putText(result, text=fps, org=(3, 15), fontFace=cv2.FONT_HERSHEY_SIMPLEX, 65 | fontScale=0.50, color=(255, 0, 0), thickness=2) 66 | cv2.namedWindow("result", cv2.WINDOW_NORMAL) 67 | cv2.imshow("result", result) 68 | 69 | if cv2.waitKey(1) & 0xFF == ord('q'): 70 | break 71 | tq.update(1) 72 | return_value, frame = vid.read() 73 | 74 | 75 | if __name__ == "__main__": 76 | from tensorflow.keras.models import load_model 77 | model_path = "/home/baudcode/Code/tf-semantic-segmentation/logs/unet-v2-tacobinary-generator-ranger-1e-4-bce_dice/model-best.h5" 78 | model = load_model(model_path, compile=False) 79 | 80 | video_path = '../../dwhelper/VideoClip.mp4' 81 | predict_video(model, video_path, stream=True, output_path=None) 82 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/evaluation/viewer.py: -------------------------------------------------------------------------------- 1 | from tf_semantic_segmentation.visualizations import masks as masks_utils 2 | import random 3 | import imageio 4 | import numpy as np 5 | import streamlit as st 6 | import os 7 | import tensorflow as tf 8 | import pandas as pd 9 | from tf_semantic_segmentation.datasets import get_dataset_by_name, datasets_by_name, DataType 10 | from tf_semantic_segmentation.processing import ColorMode 11 | from tf_semantic_segmentation.datasets.utils import convert2tfdataset 12 | from tf_semantic_segmentation.processing.dataset import get_preprocess_fn 13 | from tf_semantic_segmentation.serving import predict_on_batch 14 | from tf_semantic_segmentation.metrics.iou_score import iou_score 15 | from tf_semantic_segmentation.metrics.f_scores import f1_score 16 | from
tf_semantic_segmentation.metrics.recall import recall 17 | from tf_semantic_segmentation.metrics.precision import precision 18 | 19 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' 20 | 21 | size = (256, 256) 22 | color_mode = ColorMode.RGB 23 | resize_method = 'resize' 24 | num_classes = 2 25 | 26 | # @st.cache 27 | 28 | 29 | def get_ds(): 30 | ds = get_dataset_by_name('tacobinary', '/hdd/datasets/taco/') 31 | return ds, ds.raw()[DataType.VAL] 32 | 33 | 34 | @st.cache 35 | def predict(idx): 36 | ds, data = get_ds() 37 | example = data[idx] 38 | image, mask = ds.parse_example(example) 39 | print(tf.shape(image)) 40 | print(image.shape, mask.shape) 41 | image, mask = get_preprocess_fn(size, color_mode, resize_method, scale_mask=False, mode='np')(image, mask, num_classes) 42 | 43 | images = tf.expand_dims(image, axis=0).numpy() 44 | masks_onehot = tf.expand_dims(mask, axis=0).numpy() 45 | predictions = predict_on_batch(images, model_name='0', input_name='input_1') 46 | predictions = np.asarray(predictions) 47 | 48 | print(predictions.max(), predictions.min(), predictions.dtype) 49 | print(predictions.shape) 50 | 51 | print('onehot...') 52 | if predictions[0].shape[-1] == 1: 53 | predictions[predictions > 0.7] = 1.0 54 | predictions[predictions <= 0.7] = 0.0 55 | 56 | predictions = tf.cast(predictions, tf.int32) 57 | predictions = tf.squeeze(predictions, axis=-1) 58 | predictions = tf.one_hot(predictions, 2) 59 | 60 | prediction_onehot = np.asarray(predictions, dtype=np.int32) 61 | return images, masks_onehot.astype(np.float32), prediction_onehot.astype(np.float32) 62 | 63 | 64 | idx = st.sidebar.number_input('batch', min_value=0, max_value=100, value=0) 65 | 66 | 67 | if st.sidebar.button('random'): 68 | idx = random.randint(0, 100) 69 | 70 | images, masks_onehot, prediction_onehot = predict(idx) 71 | 72 | with tf.device("cpu:0"): 73 | df = { 74 | "iou": iou_score()(masks_onehot, prediction_onehot).numpy(), 75 | "precision": precision()(masks_onehot, prediction_onehot).numpy(), 76 | "recall": recall()(masks_onehot, prediction_onehot).numpy(), 77 | "f1_score": f1_score()(masks_onehot, prediction_onehot).numpy() 78 | } 79 | 80 | for name, value in df.items(): 81 | st.sidebar.markdown("- %s: %.2f" % (name, value)) 82 | 83 | masks = [np.argmax(mask, axis=-1).astype(np.float32) for mask in masks_onehot] 84 | 85 | image = np.concatenate(images, axis=1) if len(images) > 1 else images[0] 86 | mask = np.concatenate(masks, axis=1) if len(masks) > 1 else masks[0] 87 | 88 | prediction_masks = [np.argmax(p, axis=-1).astype(np.float32) for p in prediction_onehot] 89 | prediction = np.concatenate(prediction_masks, axis=1) if len(prediction_masks) > 1 else prediction_masks[0] 90 | 91 | predictions_on_image = [masks_utils.overlay_classes((images[k] * 255.).astype(np.uint8), np.argmax(p, axis=-1), masks_utils.get_colors(num_classes), num_classes) 92 | for k, p in enumerate(prediction_onehot)] 93 | 94 | prediction_on_image = np.concatenate(predictions_on_image, axis=1) 95 | 96 | 97 | st.image((image * 255.).astype(np.uint8), 'inputs', use_column_width=False) 98 | print(mask.shape) 99 | st.image(mask, 'masks', use_column_width=False) 100 | st.image(prediction, 'predictions', use_column_width=False) 101 | st.image(prediction_on_image, 'prediction_on_image', use_column_width=False) 102 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .conv import
Fire, MixConv, GroupedConv2D 2 | from .minibatchstddev import MiniBatchStdDev 3 | from .subpixel import Subpixel 4 | from .pixel_norm import PixelNorm 5 | from ..settings import logger 6 | from .utils import get_norm_by_name 7 | import tensorflow as tf 8 | 9 | 10 | __all__ = ['Fire', "MiniBatchStdDev", "PixelNorm", "MixConv", "GroupedConv2D", "Subpixel", 'get_norm_by_name'] 11 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/layers/minibatchstddev.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code taken from https://github.com/johnryh/Face_Embedding_GAN/blob/master/network_utility.py 3 | """ 4 | 5 | from tensorflow.keras import backend as K 6 | from tensorflow.keras.layers import Layer 7 | 8 | 9 | class MiniBatchStdDev(Layer): 10 | """ 11 | https://arxiv.org/abs/1710.10196 12 | 13 | It computes the standard deviations of the 14 | feature map pixels across the batch, and appends them as an extra channel. 15 | """ 16 | 17 | def __init__(self, group_size=4): 18 | super(MiniBatchStdDev, self).__init__() 19 | self.group_size = group_size 20 | 21 | def call(self, x, training=None): 22 | group_size = K.minimum(self.group_size, K.shape(x)[0]) # Minibatch must be divisible by (or smaller than) group_size. 23 | s = K.shape(x) # [NHWC] Input shape. 24 | y = K.reshape(x, [group_size, -1, s[1], s[2], s[3]]) # [GMHWC] Split minibatch into M groups of size G. 25 | y = K.cast(y, 'float32') # [GMHWC] Cast to FP32. 26 | y -= K.mean(y, axis=0, keepdims=True) # [GMHWC] Subtract mean over group. 27 | y = K.mean(K.square(y), axis=0) # [MHWC] Calc variance over group. 28 | y = K.sqrt(y + 1e-8) # [MHWC] Calc stddev over group. 29 | y = K.mean(y, axis=[1, 2, 3], keepdims=True) # [M111] Take average over fmaps and pixels. 30 | y = K.cast(y, x.dtype) # [M111] Cast back to original data type. 31 | y = K.tile(y, [group_size, s[1], s[2], 1]) # [NHW1] Replicate over group and pixels. 32 | return K.concatenate([x, y], axis=-1) 33 | 34 | def get_config(self): 35 | config = { 36 | "group_size": self.group_size, 37 | } 38 | base_config = super(MiniBatchStdDev, self).get_config() 39 | return dict(list(base_config.items()) + list(config.items())) 40 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/layers/pixel_norm.py: -------------------------------------------------------------------------------- 1 | from tensorflow.python.keras.layers import Layer 2 | from tensorflow.keras import backend as K 3 | 4 | 5 | class PixelNorm(Layer): 6 | """ https://arxiv.org/abs/1710.10196 7 | 8 | It normalizes the feature vector in each pixel to unit length, and is applied after the convolutional layers. 9 | This is done to prevent signal magnitudes from spiraling out of control during training.
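Example (illustrative sketch; shapes are arbitrary):

    x = tf.random.normal((2, 8, 8, 64))  # a BHWC feature map
    y = PixelNorm()(x)                   # K.mean(K.square(y), axis=-1) is ~1 at every pixel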
10 | """ 11 | 12 | def __init__(self, epsilon=1e-8): 13 | super(PixelNorm, self).__init__() 14 | self.epsilon = epsilon 15 | 16 | def call(self, x): 17 | return x / K.sqrt(K.mean(K.square(x), axis=-1, keepdims=True) + self.epsilon) 18 | 19 | def get_config(self): 20 | config = { 21 | "epsilon": self.epsilon, 22 | } 23 | base_config = super(PixelNorm, self).get_config() 24 | return dict(list(base_config.items()) + list(config.items())) 25 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/layers/subpixel.py: -------------------------------------------------------------------------------- 1 | """ Code taken from https://github.com/atriumlts/subpixel/blob/a2d9e9a163db9e3268df1b7109672eea814ee356/keras_subpixel.py """ 2 | 3 | 4 | from tensorflow.keras import backend as K 5 | from tensorflow.keras.layers import Conv2D 6 | 7 | 8 | class Subpixel(Conv2D): 9 | 10 | """ https://arxiv.org/abs/1609.05158 11 | 12 | Subpixel Layer as a child class of Conv2D. This layer accepts all normal 13 | arguments, with the exception of dilation_rate(). The argument r indicates 14 | the upsampling factor, which is applied to the normal output of Conv2D. 15 | The output of this layer will have the same number of channels as the 16 | indicated filter field, and thus works for grayscale, color, or as a a 17 | hidden layer. 18 | 19 | Arguments: 20 | *see Keras Docs for Conv2D args, noting that dilation_rate() is removed* 21 | r: upscaling factor, which is applied to the output of normal Conv2D 22 | """ 23 | 24 | def __init__(self, 25 | filters, 26 | kernel_size, 27 | r, 28 | padding='valid', 29 | data_format=None, 30 | strides=(1, 1), 31 | activation=None, 32 | use_bias=True, 33 | kernel_initializer='glorot_uniform', 34 | bias_initializer='zeros', 35 | kernel_regularizer=None, 36 | bias_regularizer=None, 37 | activity_regularizer=None, 38 | kernel_constraint=None, 39 | bias_constraint=None, 40 | **kwargs): 41 | super(Subpixel, self).__init__( 42 | filters=r * r * filters, 43 | kernel_size=kernel_size, 44 | strides=strides, 45 | padding=padding, 46 | data_format=data_format, 47 | activation=activation, 48 | use_bias=use_bias, 49 | kernel_initializer=kernel_initializer, 50 | bias_initializer=bias_initializer, 51 | kernel_regularizer=kernel_regularizer, 52 | bias_regularizer=bias_regularizer, 53 | activity_regularizer=activity_regularizer, 54 | kernel_constraint=kernel_constraint, 55 | bias_constraint=bias_constraint, 56 | **kwargs) 57 | self.r = r 58 | 59 | def _phase_shift(self, I): 60 | r = self.r 61 | bsize, a, b, c = I.get_shape().as_list() 62 | bsize = K.shape(I)[0] # Handling Dimension(None) type for undefined batch dim 63 | X = K.reshape(I, [bsize, a, b, K.cast(c / (r * r), 'int32'), r, r]) # bsize, a, b, c/(r*r), r, r 64 | X = K.permute_dimensions(X, (0, 1, 2, 5, 4, 3)) # bsize, a, b, r, r, c/(r*r) 65 | #Keras backend does not support tf.split, so in future versions this could be nicer 66 | X = [X[:, i, :, :, :, :] for i in range(a)] # a, [bsize, b, r, r, c/(r*r) 67 | X = K.concatenate(X, 2) # bsize, b, a*r, r, c/(r*r) 68 | X = [X[:, i, :, :, :] for i in range(b)] # b, [bsize, r, r, c/(r*r) 69 | X = K.concatenate(X, 2) # bsize, a*r, b*r, c/(r*r) 70 | return X 71 | 72 | def call(self, inputs): 73 | return self._phase_shift(super(Subpixel, self).call(inputs)) 74 | 75 | def compute_output_shape(self, input_shape): 76 | unshifted = super(Subpixel, self).compute_output_shape(input_shape) 77 | return (unshifted[0], self.r * unshifted[1], self.r * unshifted[2], 
unshifted[3] // (self.r * self.r)) 78 | 79 | def get_config(self): 80 | config = super(Subpixel, self).get_config() 81 | config.pop('rank') 82 | config.pop('dilation_rate') 83 | config['filters'] //= self.r * self.r 84 | config['r'] = self.r 85 | return config 86 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/layers/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from .pixel_norm import PixelNorm 3 | from ..utils import logger 4 | 5 | 6 | def get_norm_by_name(name='batch'): 7 | if name == 'batch': 8 | return tf.keras.layers.BatchNormalization(axis=-1) 9 | elif name == 'instance': 10 | import tensorflow_addons as tfa 11 | return tfa.layers.InstanceNormalization(axis=-1) 12 | # elif name == 'pixel': 13 | # return PixelNorm() 14 | else: 15 | logger.warning("using default norm: batch") 16 | return tf.keras.layers.BatchNormalization(axis=-1) 17 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras import backend as K 2 | from .utils import gather_channels, get_reduce_axes, round_if_needed, SMOOTH, average, onehot2image, expand_binary 3 | from .focal import binary_focal_loss, categorical_focal_loss 4 | from .ssim import ssim_loss 5 | from .ce import ce_label_smoothing_loss, categorical_crossentropy_loss, binary_crossentropy_loss 6 | from .dice import dice_loss, tversky_loss, focal_tversky_loss 7 | from .combined import categorical_crossentropy_ssim_loss, binary_crossentropy_ssim_loss, \ 8 | dice_binary_crossentropy_loss, dice_categorical_crossentropy_loss, dice_ssim_loss, \ 9 | dice_ssim_binary_crossentropy_loss, dice_ssim_categorical_crossentropy_loss 10 | from .lovasz import binary_lovasz, categorical_lovasz 11 | 12 | losses_by_name = { 13 | "categorical_crossentropy": categorical_crossentropy_loss(), 14 | "ce_label_smoothing": ce_label_smoothing_loss(smoothing=0.1), 15 | "binary_crossentropy": binary_crossentropy_loss(), 16 | "categorical_focal": categorical_focal_loss(), 17 | "binary_focal": binary_focal_loss(), 18 | "ssim": ssim_loss(), 19 | "dice": dice_loss(), 20 | "tversky": tversky_loss(), 21 | # "binary_lovasz": binary_lovasz(), 22 | "categorical_lovasz": categorical_lovasz(), 23 | "focal_tversky": focal_tversky_loss(), 24 | # combined losses 25 | "binary_crossentropy_ssim": binary_crossentropy_ssim_loss(), 26 | "categorical_crossentropy_ssim": categorical_crossentropy_ssim_loss(), 27 | "dice_binary_crossentropy": dice_binary_crossentropy_loss(), 28 | "dice_categorical_crossentropy": dice_categorical_crossentropy_loss(), 29 | "dice_ssim": dice_ssim_loss(), 30 | "dice_ssim_binary_crossentropy": dice_ssim_binary_crossentropy_loss(), 31 | "dice_ssim_categorical_crossentropy": dice_ssim_categorical_crossentropy_loss() 32 | } 33 | 34 | 35 | def get_loss_by_name(name): 36 | if name in losses_by_name: 37 | return losses_by_name[name] 38 | else: 39 | raise Exception("cannot find loss %s" % name) 40 | 41 | 42 | __all__ = ["categorical_focal_loss", "binary_crossentropy_loss", "binary_focal_loss", 43 | "focal_tversky_loss", "tversky_loss", "dice_loss", "ssim_loss", 44 | "binary_crossentropy_ssim_loss", "categorical_crossentropy_ssim_loss", "dice_binary_crossentropy_loss", 45 | "dice_categorical_crossentropy_loss", "dice_ssim_loss", "dice_ssim_binary_crossentropy_loss",
"dice_ssim_categorical_crossentropy_loss", 46 | "get_loss_by_name", "losses_by_name", 47 | "SMOOTH", "gather_channels", "get_reduce_axes", "round_if_needed", "average", "expand_binary"] 48 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/losses/ce.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.losses import CategoricalCrossentropy, BinaryCrossentropy, Reduction 2 | from tensorflow.keras import backend as K 3 | import tensorflow as tf 4 | from .utils import to2d, to1d 5 | 6 | 7 | def ce_label_smoothing_loss(smoothing=0.1): 8 | def ce_label_smoothing_fixed(y_true, y_pred): 9 | # y_true, y_pred = to2d(y_true), to2d(y_pred) 10 | return K.mean(CategoricalCrossentropy(label_smoothing=smoothing, reduction=Reduction.NONE)(y_true, y_pred)) 11 | return ce_label_smoothing_fixed 12 | 13 | 14 | def categorical_crossentropy_loss(): 15 | def categorical_crossentropy(y_true, y_pred): 16 | # y_true, y_pred = to2d(y_true, y_pred) 17 | # y_true, y_pred = to2d(y_true), to2d(y_pred) 18 | return K.mean(CategoricalCrossentropy(reduction=Reduction.NONE)(y_true, y_pred)) 19 | return categorical_crossentropy 20 | 21 | 22 | def binary_crossentropy_loss(): 23 | def binary_crossentropy(y_true, y_pred): 24 | #y_true, y_pred = to1d(y_true), to1d(y_pred) 25 | r = K.mean(BinaryCrossentropy(reduction=Reduction.NONE)(y_true, y_pred)) 26 | return tf.cast(r, tf.float32) 27 | 28 | return binary_crossentropy 29 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/losses/combined.py: -------------------------------------------------------------------------------- 1 | from .ssim import ssim_loss 2 | from .ce import categorical_crossentropy_loss, binary_crossentropy_loss 3 | from .dice import dice_loss 4 | import tensorflow as tf 5 | 6 | 7 | def categorical_crossentropy_ssim_loss(loss_weight_ce=1.0, loss_weight_ssim=1.0): 8 | def categorical_crossentropy_ssim(y_true, y_pred): 9 | ce = categorical_crossentropy_loss()(y_true, y_pred) 10 | ssim = ssim_loss()(y_true, y_pred) 11 | return ce * loss_weight_ce + loss_weight_ssim * ssim 12 | 13 | return categorical_crossentropy_ssim 14 | 15 | 16 | def binary_crossentropy_ssim_loss(loss_weight_ce=1.0, loss_weight_ssim=1.0): 17 | def binary_crossentropy_ssim(y_true, y_pred): 18 | ce = binary_crossentropy_loss()(y_true, y_pred) 19 | ssim = ssim_loss()(y_true, y_pred) 20 | return ce * loss_weight_ce + loss_weight_ssim * ssim 21 | 22 | return binary_crossentropy_ssim 23 | 24 | 25 | def dice_ssim_loss(loss_weight_dice=1.0, loss_weight_ssim=1.0): 26 | def dice_ssim(y_true, y_pred): 27 | dice = dice_loss()(y_true, y_pred) 28 | ssim = ssim_loss()(y_true, y_pred) 29 | return dice * loss_weight_dice + loss_weight_ssim * ssim 30 | 31 | return dice_ssim 32 | 33 | 34 | def dice_binary_crossentropy_loss(loss_weight_dice=1.0, loss_weight_ce=1.0): 35 | def dice_binary_crossentropy(y_true, y_pred): 36 | dice = dice_loss()(y_true, y_pred) 37 | ce = binary_crossentropy_loss()(y_true, y_pred) 38 | return dice * loss_weight_dice + ce * loss_weight_ce 39 | 40 | return dice_binary_crossentropy 41 | 42 | 43 | def dice_categorical_crossentropy_loss(loss_weight_dice=1.0, loss_weight_ce=1.0): 44 | def dice_categorical_crossentropy(y_true, y_pred): 45 | dice = dice_loss()(y_true, y_pred) 46 | ce = categorical_crossentropy_loss()(y_true, y_pred) 47 | return dice * loss_weight_dice + ce * loss_weight_ce 48 | 49 | return 
dice_categorical_crossentropy 50 | 51 | 52 | def dice_ssim_binary_crossentropy_loss(loss_weight_dice=1.0, loss_weight_ce=1.0, loss_weight_ssim=1.0): 53 | def dice_ssim_binary_crossentropy(y_true, y_pred): 54 | dice = dice_loss()(y_true, y_pred) 55 | dice = tf.cast(dice, tf.float32) 56 | 57 | ce = binary_crossentropy_loss()(y_true, y_pred) 58 | ce = tf.cast(ce, tf.float32) 59 | ssim = ssim_loss()(y_true, y_pred) 60 | return dice * loss_weight_dice + ce * loss_weight_ce + loss_weight_ssim * ssim 61 | 62 | return dice_ssim_binary_crossentropy 63 | 64 | 65 | def dice_ssim_categorical_crossentropy_loss(loss_weight_dice=1.0, loss_weight_ce=1.0, loss_weight_ssim=1.0): 66 | def dice_ssim_categorical_crossentropy(y_true, y_pred): 67 | dice = dice_loss()(y_true, y_pred) 68 | ce = categorical_crossentropy_loss()(y_true, y_pred) 69 | ssim = ssim_loss()(y_true, y_pred) 70 | 71 | return dice * loss_weight_dice + ce * loss_weight_ce + loss_weight_ssim * ssim 72 | 73 | return dice_ssim_categorical_crossentropy 74 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/losses/dice.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def dice_loss(): 5 | def dice_loss(y_true, y_pred): 6 | """ F1 Score """ 7 | numerator = 2 * tf.reduce_sum(y_true * y_pred) 8 | denominator = tf.reduce_sum(y_true + y_pred) 9 | 10 | r = 1 - (numerator + 1) / (denominator + 1) 11 | return tf.cast(r, tf.float32) 12 | 13 | return dice_loss 14 | 15 | 16 | def tversky_loss(beta=0.7): 17 | """ Tversky index (TI) is a generalization of Dice’s coefficient. TI adds a weight to FP (false positives) and FN (false negatives). """ 18 | def tversky_loss(y_true, y_pred): 19 | numerator = tf.reduce_sum(y_true * y_pred) 20 | denominator = y_true * y_pred + beta * (1 - y_true) * y_pred + (1 - beta) * y_true * (1 - y_pred) 21 | 22 | r = 1 - (numerator + 1) / (tf.reduce_sum(denominator) + 1) 23 | return tf.cast(r, tf.float32) 24 | 25 | return tversky_loss 26 | 27 | 28 | def focal_tversky_loss(beta=0.7, gamma=0.75): 29 | def focal_tversky(y_true, y_pred): 30 | loss = tversky_loss(beta)(y_true, y_pred) 31 | return tf.pow(loss, gamma) 32 | 33 | return focal_tversky 34 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/losses/focal.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras import backend as K 2 | import tensorflow as tf 3 | from .utils import gather_channels 4 | 5 | 6 | def binary_focal_loss(gamma=2.0, alpha=0.25, **kwargs): 7 | r"""Implementation of Focal Loss from the paper in binary classification 8 | 9 | Formula: 10 | loss = - gt * alpha * ((1 - pr)^gamma) * log(pr) \ 11 | - (1 - gt) * (1 - alpha) * (pr^gamma) * log(1 - pr) 12 | 13 | Args: 14 | gt: ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W) 15 | pr: prediction 4D keras tensor (B, H, W, C) or (B, C, H, W) 16 | alpha: the same as weighting factor in balanced cross entropy, default 0.25 17 | gamma: focusing parameter for modulating factor (1-p), default 2.0 18 | 19 | """ 20 | 21 | def binary_focal(gt, pr): 22 | # clip to prevent NaN's and Inf's 23 | pr = K.clip(pr, K.epsilon(), 1.0 - K.epsilon()) 24 | 25 | loss_1 = - gt * (alpha * K.pow((1 - pr), gamma) * K.log(pr)) 26 | loss_0 = - (1 - gt) * ((1 - alpha) * K.pow((pr), gamma) * K.log(1 - pr)) 27 | loss = K.mean(loss_0 + loss_1) 28 | return loss 29 |
return binary_focal 30 | 31 | 32 | def categorical_focal_loss(gamma=2.0, alpha=0.25, class_indexes=None, **kwargs): 33 | r"""Implementation of Focal Loss from the paper in multiclass classification 34 | 35 | Formula: 36 | loss = - gt * alpha * ((1 - pr)^gamma) * log(pr) 37 | 38 | Args: 39 | gt: ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W) 40 | pr: prediction 4D keras tensor (B, H, W, C) or (B, C, H, W) 41 | alpha: the same as weighting factor in balanced cross entropy, default 0.25 42 | gamma: focusing parameter for modulating factor (1-p), default 2.0 43 | class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used. 44 | 45 | """ 46 | 47 | def categorical_focal(gt, pr): 48 | gt, pr = gather_channels(gt, pr, indexes=class_indexes, **kwargs) 49 | 50 | # clip to prevent NaN's and Inf's 51 | pr = K.clip(pr, K.epsilon(), 1.0 - K.epsilon()) 52 | 53 | # Calculate focal loss 54 | loss = - gt * (alpha * K.pow((1 - pr), gamma) * K.log(pr)) 55 | 56 | return K.mean(loss) 57 | return categorical_focal 58 | 59 | 60 | def smooth_l1(sigma=3.0): 61 | """ Compute the smooth L1 loss of y_pred w.r.t. y_true. 62 | Args 63 | sigma: This argument defines the point where the loss changes from L2 to L1. 64 | Returns 65 | The smooth L1 loss of y_pred w.r.t. y_true. 66 | """ 67 | def smooth_l1(y_true, y_pred): 68 | 69 | sigma_squared = sigma ** 2 70 | # separate target and state 71 | regression = y_pred 72 | regression_target = y_true 73 | 74 | # compute smooth L1 loss 75 | # f(x) = 0.5 * (sigma * x)^2 if |x| < 1 / sigma / sigma 76 | # |x| - 0.5 / sigma / sigma otherwise 77 | regression_diff = regression - regression_target 78 | regression_diff = tf.abs(regression_diff) 79 | regression_loss = tf.where( 80 | tf.less(regression_diff, 1.0 / sigma_squared), 81 | 0.5 * sigma_squared * tf.pow(regression_diff, 2), 82 | regression_diff - 0.5 / sigma_squared 83 | ) 84 | 85 | return tf.reduce_mean(regression_loss) 86 | return smooth_l1 87 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/losses/ssim.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from .utils import onehot2image 3 | 4 | 5 | def ssim_loss(): 6 | def ssim(y_true, y_pred): 7 | y_true = onehot2image(y_true) 8 | y_pred = onehot2image(y_pred) 9 | 10 | ssim_batch = tf.image.ssim(y_true, y_pred, max_val=1.0, filter_size=3, filter_sigma=1.5, k1=0.01, k2=0.03) 11 | ssim = tf.reduce_mean(ssim_batch, axis=-1) 12 | 13 | return 1.0 - ssim 14 | return ssim 15 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/losses/utils.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras import backend 2 | SMOOTH = 1e-5 3 | 4 | # ---------------------------------------------------------------- 5 | # Helpers 6 | # ---------------------------------------------------------------- 7 | 8 | 9 | def to1d(t): 10 | entries = backend.prod(t.shape) 11 | return backend.reshape(t, (entries, )) 12 | 13 | 14 | def to2d(t): 15 | s = backend.prod(backend.shape(t)[1:]) 16 | t = backend.reshape(t, [-1, s]) 17 | return t 18 | 19 | 20 | def onehot2image(y): 21 | """ 22 | Arguments: 23 | - y: Tensor BHWC (onehot) 24 | 25 | Scales input of shape BHWC to BHW1 image of range (0, 1) tf.float32 26 | """ 27 | if y.shape[-1] == 1: 28 | # assume masks are scaled using sigmoid 29 | return y 30 | 31 | num_classes = 
--------------------------------------------------------------------------------
/tf_semantic_segmentation/losses/utils.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras import backend
2 | SMOOTH = 1e-5
3 |
4 | # ----------------------------------------------------------------
5 | # Helpers
6 | # ----------------------------------------------------------------
7 |
8 |
9 | def to1d(t):
10 |     entries = backend.prod(t.shape)
11 |     return backend.reshape(t, (entries, ))
12 |
13 |
14 | def to2d(t):
15 |     s = backend.prod(backend.shape(t)[1:])
16 |     t = backend.reshape(t, [-1, s])
17 |     return t
18 |
19 |
20 | def onehot2image(y):
21 |     """
22 |     Arguments:
23 |         - y: Tensor BHWC (onehot)
24 |
25 |     Scales input of shape BHWC to BHW1 image of range (0, 1) tf.float32
26 |     """
27 |     if y.shape[-1] == 1:
28 |         # assume masks are scaled using sigmoid
29 |         return y
30 |
31 |     num_classes = y.shape[-1]
32 |     y = backend.argmax(y, axis=-1)
33 |     y = backend.expand_dims(y, axis=-1)
34 |     y = backend.cast(y, backend.floatx())
35 |     y = y / backend.cast(num_classes - 1, backend.floatx())
36 |     return y
37 |
38 |
39 | def _gather_channels(x, indexes):
40 |     """Slice tensor along channels axis by given indexes"""
41 |     if backend.image_data_format() == 'channels_last':
42 |         x = backend.permute_dimensions(x, (3, 0, 1, 2))
43 |         x = backend.gather(x, indexes)
44 |         x = backend.permute_dimensions(x, (1, 2, 3, 0))
45 |     else:
46 |         x = backend.permute_dimensions(x, (1, 0, 2, 3))
47 |         x = backend.gather(x, indexes)
48 |         x = backend.permute_dimensions(x, (1, 0, 2, 3))
49 |     return x
50 |
51 |
52 | def get_reduce_axes(per_image):
53 |     axes = [1, 2] if backend.image_data_format() == 'channels_last' else [2, 3]
54 |     if not per_image:
55 |         axes.insert(0, 0)
56 |     return axes
57 |
58 |
59 | def gather_channels(*xs, indexes=None):
60 |     """Slice tensors along channels axis by given indexes"""
61 |     if indexes is None:
62 |         return xs
63 |     elif isinstance(indexes, int):
64 |         indexes = [indexes]
65 |     xs = [_gather_channels(x, indexes=indexes) for x in xs]
66 |     return xs
67 |
68 |
69 | def round_if_needed(x, threshold):
70 |     if threshold is not None:
71 |         x = backend.greater(x, threshold)
72 |         x = backend.cast(x, backend.floatx())
73 |     return x
74 |
75 |
76 | def average(x, per_image=False, class_weights=None):
77 |     if per_image:
78 |         x = backend.mean(x, axis=0)
79 |     if class_weights is not None:
80 |         x = x * class_weights
81 |     return backend.mean(x)
82 |
83 |
84 | def expand_binary(x):
85 |     # remove last dim
86 |     x = backend.squeeze(x, axis=-1)
87 |     # scale to 0 or 1
88 |     x = backend.round(x)
89 |     x = backend.cast(x, 'int32')
90 |     x = backend.one_hot(x, 2)
91 |     return x
92 |
--------------------------------------------------------------------------------
/tf_semantic_segmentation/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .f_scores import f1_score, f2_score
2 | from .iou_score import iou_score
3 | from .precision import precision
4 | from .recall import recall
5 | from .psnr import psnr
6 | from .ssim import ssim
7 | from .kmetrics import binary_accuracy, categorical_accuracy, mae
8 |
9 |
10 | metrics_by_name = {
11 |     "f1_score": f1_score(),
12 |     "f2_score": f2_score(),
13 |     "precision": precision(),
14 |     "recall": recall(),
15 |     "iou_score": iou_score(),
16 |     "psnr": psnr(),
17 |     "ssim": ssim(),
18 |     "mae": mae,
19 |     "binary_accuracy": binary_accuracy,
20 |     "categorical_accuracy": categorical_accuracy
21 | }
22 |
23 |
24 | def get_metric_by_name(name):
25 |     if name in metrics_by_name:
26 |         return metrics_by_name[name]
27 |     else:
28 |         raise Exception("cannot find metric %s" % name)
29 |
30 |
31 | __all__ = list(metrics_by_name.keys())
32 |
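Every entry in metrics_by_name is a ready callable, so metrics can be resolved by name and passed straight to compile (a sketch; the commented line assumes a hypothetical model):

from tf_semantic_segmentation.metrics import get_metric_by_name

metrics = [get_metric_by_name(name) for name in ('iou_score', 'f1_score', 'recall')]
# model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=metrics)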
--------------------------------------------------------------------------------
/tf_semantic_segmentation/metrics/f_scores.py:
--------------------------------------------------------------------------------
1 | from ..losses import SMOOTH, gather_channels, round_if_needed, get_reduce_axes, average, expand_binary
2 | from tensorflow.keras import backend as K
3 |
4 | """ Taken from https://github.com/qubvel/segmentation_models/blob/master/segmentation_models/metrics.py """
5 |
6 |
7 | def _f_score(gt, pr, beta=1, class_weights=1, class_indexes=None, smooth=SMOOTH, per_image=False, threshold=None):
8 |     r"""The F-score (Dice coefficient) can be interpreted as a weighted average of the precision and recall,
9 |     where an F-score reaches its best value at 1 and worst score at 0.
10 |     The relative contribution of ``precision`` and ``recall`` to the F1-score is equal.
11 |     The formula for the F score is:
12 |
13 |     .. math:: F_\beta(precision, recall) = (1 + \beta^2) \frac{precision \cdot recall}
14 |         {\beta^2 \cdot precision + recall}
15 |
16 |     The formula in terms of *Type I* and *Type II* errors:
17 |
18 |     .. math:: F_\beta(A, B) = \frac{(1 + \beta^2) TP} {(1 + \beta^2) TP + \beta^2 FN + FP}
19 |
20 |
21 |     where:
22 |         TP - true positive;
23 |         FP - false positive;
24 |         FN - false negative;
25 |
26 |     Args:
27 |         gt: ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W)
28 |         pr: prediction 4D keras tensor (B, H, W, C) or (B, C, H, W)
29 |         class_weights: 1. or list of class weights, len(weights) = C
30 |         class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used.
31 |         beta: f-score coefficient
32 |         smooth: value to avoid division by zero
33 |         per_image: if ``True``, metric is calculated as mean over images in batch (B),
34 |             else over whole batch
35 |         threshold: value to round predictions (use ``>`` comparison), if ``None`` predictions will not be rounded
36 |
37 |     Returns:
38 |         F-score in range [0, 1]
39 |
40 |     """
41 |     if gt.shape[-1] == 1:
42 |         # assuming binary
43 |         gt, pr = expand_binary(gt), expand_binary(pr)
44 |
45 |     gt, pr = gather_channels(gt, pr, indexes=class_indexes)
46 |     pr = round_if_needed(pr, threshold)
47 |     axes = get_reduce_axes(per_image)
48 |
49 |     # calculate score
50 |     tp = K.cast(K.sum(gt * pr, axis=axes), "float64")
51 |     fp = K.cast(K.sum(pr, axis=axes), 'float64') - tp
52 |     fn = K.cast(K.sum(gt, axis=axes), 'float64') - tp
53 |
54 |     score = ((1 + beta ** 2) * tp + smooth) \
55 |         / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
56 |     score = average(score, per_image, class_weights)
57 |
58 |     return score
59 |
60 |
61 | def f1_score(class_weights=1):
62 |     def f1_score(gt, pr):
63 |         return _f_score(gt, pr, beta=1, class_weights=class_weights)
64 |     return f1_score
65 |
66 |
67 | def f2_score(class_weights=1):
68 |     def f2_score(gt, pr):
69 |         return _f_score(gt, pr, beta=2, class_weights=class_weights)
70 |     return f2_score
71 |
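A small worked check of the smoothed F-score (a sketch; shapes are illustrative): for identical one-hot masks, fp = fn = 0, so both scores evaluate to ~1.

import tensorflow as tf
from tf_semantic_segmentation.metrics.f_scores import f1_score, f2_score

y = tf.one_hot(tf.random.uniform((2, 32, 32), 0, 3, tf.int32), depth=3)
print(float(f1_score()(y, y)), float(f2_score()(y, y)))  # both ~1.0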
--------------------------------------------------------------------------------
/tf_semantic_segmentation/metrics/iou_score.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras import backend as K
2 | from ..losses import SMOOTH, gather_channels, round_if_needed, get_reduce_axes, average, expand_binary
3 |
4 | """ Taken from https://github.com/qubvel/segmentation_models/blob/master/segmentation_models/metrics.py """
5 |
6 |
7 | def iou_score(class_weights=1., class_indexes=None, smooth=SMOOTH, per_image=False, threshold=None):
8 |     r""" The `Jaccard index`_, also known as Intersection over Union and the Jaccard similarity coefficient
9 |     (originally coined coefficient de communauté by Paul Jaccard), is a statistic used for comparing the
10 |     similarity and diversity of sample sets. The Jaccard coefficient measures similarity between finite sample sets,
11 |     and is defined as the size of the intersection divided by the size of the union of the sample sets:
12 |
13 |     .. math:: J(A, B) = \frac{A \cap B}{A \cup B}
14 |
15 |     Args:
16 |         gt: ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W)
17 |         pr: prediction 4D keras tensor (B, H, W, C) or (B, C, H, W)
18 |         class_weights: 1. or list of class weights, len(weights) = C
19 |         class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used.
20 |         smooth: value to avoid division by zero
21 |         per_image: if ``True``, metric is calculated as mean over images in batch (B),
22 |             else over whole batch
23 |         threshold: value to round predictions (use ``>`` comparison), if ``None`` predictions will not be rounded
24 |
25 |     Returns:
26 |         IoU / Jaccard score in range [0, 1]
27 |
28 |     .. _`Jaccard index`: https://en.wikipedia.org/wiki/Jaccard_index
29 |
30 |     """
31 |     def iou_score(gt, pr):
32 |
33 |         if gt.shape[-1] == 1:
34 |             # assuming binary
35 |             gt, pr = expand_binary(gt), expand_binary(pr)
36 |
37 |         gt, pr = gather_channels(gt, pr, indexes=class_indexes)
38 |         pr = round_if_needed(pr, threshold)
39 |         axes = get_reduce_axes(per_image)
40 |
41 |         # score calculation
42 |         intersection = K.sum(gt * pr, axis=axes)
43 |         union = K.sum(gt + pr, axis=axes) - intersection
44 |
45 |         score = (intersection + smooth) / (union + smooth)
46 |         score = average(score, per_image, class_weights)
47 |
48 |         return score
49 |
50 |     return iou_score
51 |
--------------------------------------------------------------------------------
/tf_semantic_segmentation/metrics/kmetrics.py:
--------------------------------------------------------------------------------
1 |
2 | from tensorflow.keras import backend as K
3 | from tensorflow.keras import metrics
4 |
5 |
6 | def mae(y_true, y_pred): return K.mean(metrics.mae(y_true, y_pred))
7 |
8 |
9 | def binary_accuracy(y_true, y_pred): return K.mean(metrics.binary_accuracy(y_true, y_pred))
10 |
11 |
12 | def categorical_accuracy(y_true, y_pred): return K.mean(metrics.categorical_accuracy(y_true, y_pred))
13 |
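class_indexes makes it easy to monitor a single class, e.g. foreground-only IoU (a sketch; shapes are illustrative):

import tensorflow as tf
from tf_semantic_segmentation.metrics import iou_score

foreground_iou = iou_score(class_indexes=1, per_image=True)
y = tf.one_hot(tf.random.uniform((2, 32, 32), 0, 2, tf.int32), depth=2)
print(float(foreground_iou(y, y)))  # identical masks -> ~1.0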
--------------------------------------------------------------------------------
/tf_semantic_segmentation/metrics/precision.py:
--------------------------------------------------------------------------------
1 | from ..losses import gather_channels, round_if_needed, get_reduce_axes, SMOOTH, average, expand_binary
2 | from tensorflow.keras import backend as K
3 |
4 |
5 | """ Taken from https://github.com/qubvel/segmentation_models/blob/master/segmentation_models/metrics.py """
6 |
7 |
8 | def precision(class_weights=1, class_indexes=None, smooth=SMOOTH, per_image=False, threshold=None, **kwargs):
9 |     r"""Calculate precision between the ground truth (gt) and the prediction (pr).
10 |
11 |     .. math:: precision = \frac{tp} {(tp + fp)}
12 |
13 |     where:
14 |          - tp - true positives;
15 |          - fp - false positives;
16 |
17 |     Args:
18 |         gt: ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W)
19 |         pr: prediction 4D keras tensor (B, H, W, C) or (B, C, H, W)
20 |         class_weights: 1. or ``np.array`` of class weights (``len(weights) = num_classes``)
21 |         class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used.
22 |         smooth: Float value to avoid division by zero.
23 |         per_image: If ``True``, metric is calculated as mean over images in batch (B),
24 |             else over whole batch.
25 |         threshold: Float value to round predictions (use ``>`` comparison); if ``None``
26 |             predictions will not be rounded.
27 |
28 |     Returns:
29 |         float: precision score
30 |     """
31 |     def precision(gt, pr):
32 |         if gt.shape[-1] == 1:
33 |             # assuming binary
34 |             gt, pr = expand_binary(gt), expand_binary(pr)
35 |
36 |         gt, pr = gather_channels(gt, pr, indexes=class_indexes)
37 |         pr = round_if_needed(pr, threshold)
38 |         axes = get_reduce_axes(per_image)
39 |
40 |         # score calculation
41 |         tp = K.cast(K.sum(gt * pr, axis=axes), 'float64')
42 |         fp = K.cast(K.sum(pr, axis=axes), 'float64') - tp
43 |
44 |         score = (tp + smooth) / (tp + fp + smooth)
45 |         score = average(score, per_image, class_weights)
46 |
47 |         return score
48 |     return precision
49 |
--------------------------------------------------------------------------------
/tf_semantic_segmentation/metrics/psnr.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras import backend as K
2 | import numpy as np
3 | from ..losses import onehot2image
4 |
5 |
6 | def psnr(SMOOTH=1e-9):
7 |     def psnr(y_true, y_pred):
8 |         # scale between 0 and 255
9 |         y_true = onehot2image(y_true) * 255.
10 |         y_pred = onehot2image(y_pred) * 255.
11 |
12 |         # mean squared error and scale
13 |         mse = K.mean(K.square(y_pred - y_true)) + SMOOTH
14 |         k1 = 20 * K.log(255.0) / K.log(10.0)
15 |         k2 = np.float32(10.0 / np.log(10)) * K.log(mse)
16 |         return k1 - k2
17 |     return psnr
18 |
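Identical masks hit the SMOOTH floor, mse ~ 1e-9, so the metric returns roughly 20*log10(255) + 90, about 138 dB (a sketch; shapes are illustrative):

import tensorflow as tf
from tf_semantic_segmentation.metrics import psnr

y = tf.one_hot(tf.random.uniform((2, 32, 32), 0, 4, tf.int32), depth=4)
print(float(psnr()(y, y)))  # ~138.1 for identical masks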
--------------------------------------------------------------------------------
/tf_semantic_segmentation/metrics/recall.py:
--------------------------------------------------------------------------------
1 | from ..losses import SMOOTH, gather_channels, round_if_needed, get_reduce_axes, average, expand_binary
2 | from tensorflow.keras import backend as K
3 |
4 | """ Taken from https://github.com/qubvel/segmentation_models/blob/master/segmentation_models/metrics.py """
5 |
6 |
7 | def recall(class_weights=1, class_indexes=None, smooth=SMOOTH, per_image=False, threshold=None):
8 |     r"""Calculate recall between the ground truth (gt) and the prediction (pr).
9 |
10 |     .. math:: recall = \frac{tp} {(tp + fn)}
11 |
12 |     where:
13 |          - tp - true positives;
14 |          - fn - false negatives;
15 |
16 |     Args:
17 |         gt: ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W)
18 |         pr: prediction 4D keras tensor (B, H, W, C) or (B, C, H, W)
19 |         class_weights: 1. or ``np.array`` of class weights (``len(weights) = num_classes``)
20 |         class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used.
21 |         smooth: Float value to avoid division by zero.
22 |         per_image: If ``True``, metric is calculated as mean over images in batch (B),
23 |             else over whole batch.
24 |         threshold: Float value to round predictions (use ``>`` comparison); if ``None``
25 |             predictions will not be rounded.
26 |
27 |     Returns:
28 |         float: recall score
29 |     """
30 |     def recall(gt, pr):
31 |         if gt.shape[-1] == 1:
32 |             # assuming binary
33 |             gt, pr = expand_binary(gt), expand_binary(pr)
34 |
35 |         gt, pr = gather_channels(gt, pr, indexes=class_indexes)
36 |         pr = round_if_needed(pr, threshold)
37 |         axes = get_reduce_axes(per_image)
38 |
39 |         tp = K.cast(K.sum(gt * pr, axis=axes), 'float64')
40 |         fn = K.cast(K.sum(gt, axis=axes), 'float64') - tp
41 |
42 |         score = (tp + smooth) / (tp + fn + smooth)
43 |         score = average(score, per_image, class_weights)
44 |
45 |         return score
46 |     return recall
47 |
--------------------------------------------------------------------------------
/tf_semantic_segmentation/metrics/ssim.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from ..losses import onehot2image
3 |
4 |
5 | def ssim(filter_size=3, filter_sigma=1.5, k1=0.01, k2=0.03):
6 |     def ssim(y_true, y_pred):
7 |         # find label max
8 |         y_true = onehot2image(y_true)
9 |         y_pred = onehot2image(y_pred)
10 |
11 |         ssim_batch = tf.image.ssim(y_true, y_pred, max_val=1.0, filter_size=filter_size, filter_sigma=filter_sigma, k1=k1, k2=k2)
12 |         return tf.reduce_mean(ssim_batch, axis=-1)
13 |     return ssim
14 |
--------------------------------------------------------------------------------
/tf_semantic_segmentation/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .erfnet import erfnet
2 | from .unet import unet
3 | from .imagenet_unet import unet_mobilenet, unet_inception_resnet_v2, unet_resnet
4 | from .satellite_unet import satellite_unet
5 | from .multires_unet import multires_unet
6 | from .attention_unet import attention_unet
7 | from .nested_unet import nested_unet
8 | from .psp import psp
9 | from .fcn import fcn
10 | from .u2net import u2net, u2netp
11 |
12 | from tensorflow.keras.models import Model
13 | import inspect
14 |
15 | models_by_name = {
16 |     "erfnet": erfnet,
17 |     "unet": unet,
18 |     "unet_mobilenet": unet_mobilenet,
19 |     "unet_inception_resnet_v2": unet_inception_resnet_v2,
20 |     "unet_resnet": unet_resnet,
21 |     "satellite_unet": satellite_unet,
22 |     "multires_unet": multires_unet,
23 |     "attention_unet": attention_unet,
24 |     "nested_unet": nested_unet,
25 |     "psp": psp,
26 |     "fcn": fcn,
27 |     "u2net": u2net,
28 |     "u2netp": u2netp,
29 | }
30 |
31 |
32 | def get_model_description(name):
33 |     return inspect.getdoc(models_by_name[name])
34 |
35 |
36 | def get_model_by_name(name, args) -> Model:
37 |     if name in models_by_name.keys():
38 |         return models_by_name[name](**args)
39 |     else:
40 |         raise Exception("cannot find model %s" % name)
41 |
42 |
43 | __all__ = ['erfnet', 'unet', 'multires_unet', "unet_mobilenet", "unet_inception_resnet_v2", "unet_resnet", "satellite_unet",
44 |            "attention_unet", "nested_unet", "psp", "fcn", "u2net", "u2netp",
45 |            'get_model_by_name', 'models_by_name']
46 |
--------------------------------------------------------------------------------
/tf_semantic_segmentation/models/apps/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baudcode/tf-semantic-segmentation/6e425287f309f86b931e34bbbb804bbc46963e28/tf_semantic_segmentation/models/apps/__init__.py
--------------------------------------------------------------------------------
/tf_semantic_segmentation/models/apps/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import warnings
3 |
4 |
5 | def 
_obtain_input_shape(input_shape, 6 | default_size, 7 | min_size, 8 | data_format, 9 | require_flatten, 10 | weights=None): 11 | """ 12 | Taken from: 13 | https://github.com/keras-team/keras-applications/blob/master/keras_applications/imagenet_utils.py 14 | 15 | Internal utility to compute/validate a model's input shape. 16 | # Arguments 17 | input_shape: Either None (will return the default network input shape), 18 | or a user-provided shape to be validated. 19 | default_size: Default input width/height for the model. 20 | min_size: Minimum input width/height accepted by the model. 21 | data_format: Image data format to use. 22 | require_flatten: Whether the model is expected to 23 | be linked to a classifier via a Flatten layer. 24 | weights: One of `None` (random initialization) 25 | or 'imagenet' (pre-training on ImageNet). 26 | If weights='imagenet' input channels must be equal to 3. 27 | # Returns 28 | An integer shape tuple (may include None entries). 29 | # Raises 30 | ValueError: In case of invalid argument values. 31 | """ 32 | if weights != 'imagenet' and input_shape and len(input_shape) == 3: 33 | if data_format == 'channels_first': 34 | if input_shape[0] not in {1, 3}: 35 | warnings.warn( 36 | 'This model usually expects 1 or 3 input channels. ' 37 | 'However, it was passed an input_shape with ' + 38 | str(input_shape[0]) + ' input channels.') 39 | default_shape = (input_shape[0], default_size, default_size) 40 | else: 41 | if input_shape[-1] not in {1, 3}: 42 | warnings.warn( 43 | 'This model usually expects 1 or 3 input channels. ' 44 | 'However, it was passed an input_shape with ' + 45 | str(input_shape[-1]) + ' input channels.') 46 | default_shape = (default_size, default_size, input_shape[-1]) 47 | else: 48 | if data_format == 'channels_first': 49 | default_shape = (3, default_size, default_size) 50 | else: 51 | default_shape = (default_size, default_size, 3) 52 | if weights == 'imagenet' and require_flatten: 53 | if input_shape is not None: 54 | if input_shape != default_shape: 55 | raise ValueError('When setting `include_top=True` ' 56 | 'and loading `imagenet` weights, ' 57 | '`input_shape` should be ' + 58 | str(default_shape) + '.') 59 | return default_shape 60 | if input_shape: 61 | if data_format == 'channels_first': 62 | if input_shape is not None: 63 | if len(input_shape) != 3: 64 | raise ValueError( 65 | '`input_shape` must be a tuple of three integers.') 66 | if input_shape[0] != 3 and weights == 'imagenet': 67 | raise ValueError('The input must have 3 channels; got ' 68 | '`input_shape=' + str(input_shape) + '`') 69 | if ((input_shape[1] is not None and input_shape[1] < min_size) or 70 | (input_shape[2] is not None and input_shape[2] < min_size)): 71 | raise ValueError('Input size must be at least ' + 72 | str(min_size) + 'x' + str(min_size) + 73 | '; got `input_shape=' + 74 | str(input_shape) + '`') 75 | else: 76 | if input_shape is not None: 77 | if len(input_shape) != 3: 78 | raise ValueError( 79 | '`input_shape` must be a tuple of three integers.') 80 | if input_shape[-1] != 3 and weights == 'imagenet': 81 | raise ValueError('The input must have 3 channels; got ' 82 | '`input_shape=' + str(input_shape) + '`') 83 | if ((input_shape[0] is not None and input_shape[0] < min_size) or 84 | (input_shape[1] is not None and input_shape[1] < min_size)): 85 | raise ValueError('Input size must be at least ' + 86 | str(min_size) + 'x' + str(min_size) + 87 | '; got `input_shape=' + 88 | str(input_shape) + '`') 89 | else: 90 | if require_flatten: 91 | input_shape = 
default_shape
 92 |             else:
 93 |                 if data_format == 'channels_first':
 94 |                     input_shape = (3, None, None)
 95 |                 else:
 96 |                     input_shape = (None, None, 3)
 97 |     if require_flatten:
 98 |         if None in input_shape:
 99 |             raise ValueError('If `include_top` is True, '
100 |                              'you should specify a static `input_shape`. '
101 |                              'Got `input_shape=' + str(input_shape) + '`')
102 |     return input_shape
103 |
--------------------------------------------------------------------------------
/tf_semantic_segmentation/models/attention_unet.py:
--------------------------------------------------------------------------------
1 | from .unet import conv, upsample, downsample, logger
2 | import math
3 | from tensorflow.keras.layers import Input
4 | from tensorflow.keras.models import Model
5 | from tensorflow.keras import backend as K
6 | import tensorflow as tf
7 |
8 |
9 | def attention(filters, x, g):
10 |     # conv, norm on g with stride=1, kernel_size=1, no activation
11 |     g2 = conv(g, filters, kernel_size=(1, 1), strides=(1, 1), activation=None, norm='batch')
12 |     # conv norm on x
13 |     x2 = conv(x, filters, kernel_size=(1, 1), strides=(1, 1), activation=None, norm='batch')
14 |
15 |     psi = tf.nn.relu(x2 + g2)
16 |     psi = conv(psi, 1, kernel_size=(1, 1), strides=(1, 1), activation=None, norm='batch')
17 |     psi = tf.nn.sigmoid(psi)
18 |     return x * psi
19 |
20 |
21 | def attention_unet(input_shape=(256, 256, 1), num_classes=3, depth=5, activation='relu', num_first_filters=64, l2=None,
22 |                    upsampling_method='conv', downsampling_method='max_pool', conv_type='conv'):
23 |     """
24 |     Attention U-Net (https://arxiv.org/abs/1804.03999), built on U-Net (https://arxiv.org/pdf/1505.04597.pdf)
25 |     """
26 |     logger.debug("building model attention_unet with args %s" % (str(locals())))
27 |     inputs = Input(input_shape)
28 |
29 |     y = inputs
30 |     layers = []
31 |
32 |     features = [int(pow(2, math.log2(num_first_filters) + i)) for i in range(depth)]
33 |
34 |     for k, num_filters in enumerate(features):
35 |         y = conv(y, num_filters, activation=activation, l2=l2, conv_type=conv_type)
36 |         y = conv(y, num_filters, activation=activation, l2=l2, conv_type=conv_type)
37 |         layers.append(y)
38 |
39 |         if k != (len(features) - 1):
40 |             y = downsample(y, method=downsampling_method, activation=activation, l2=l2)
41 |
42 |         logger.debug("encoder - features: %d, shape: %s" % (num_filters, str(y.shape)))
43 |
44 |     for k, num_filters in enumerate(reversed(features[:-1])):
45 |         y = upsample(y, method=upsampling_method, activation=activation, l2=l2)
46 |         # y = conv(y, num_filters, kernel_size=(2, 2), activation=activation, l2=l2, conv_type=conv_type)
47 |         att = attention(num_filters, layers[-(k + 2)], y)
48 |         y = K.concatenate([att, y])
49 |         logger.debug("concat shape: %s" % str(y.shape))
50 |         y = conv(y, num_filters, activation=activation, l2=l2, conv_type=conv_type)
51 |         y = conv(y, num_filters, activation=activation, l2=l2, conv_type=conv_type)
52 |         logger.debug("decoder - features: %d, shape: %s" % (num_filters, str(y.shape)))
53 |
54 |     y = conv(y, num_classes, kernel_size=(1, 1), activation=None, norm=None)
55 |     return Model(inputs, y)
56 |
57 |
58 | if __name__ == "__main__":
59 |     model = attention_unet()
60 |     model.summary()
61 |
--------------------------------------------------------------------------------
/tf_semantic_segmentation/models/deeplabv3.py:
--------------------------------------------------------------------------------
1 | # https://arxiv.org/pdf/1706.05587.pdf
2 | # PASCAL VOC: 77.21 (val set)
3 | # Cityscapes: 79.30 % IOU (val), 81.3 % (test)
4 | # Crop Size: 513
5 | # output stride=16
6 | # eval output stride: 8
7 |
8 
| from tf_semantic_segmentation.models.apps.resnet50 import resnet50 9 | from tensorflow.keras.models import Model 10 | from tensorflow.keras import layers 11 | import tensorflow as tf 12 | from ..layers import get_norm_by_name 13 | 14 | 15 | def conv(x, filters, kernel_size=(3, 3), dilation=1, activation='relu', norm='batch'): 16 | y = layers.Conv2D(filters, kernel_size=kernel_size, dilation_rate=dilation, activation=None, padding='same')(x) 17 | y = get_norm_by_name(norm)(y) 18 | if activation: 19 | y = layers.Activation(activation)(y) 20 | return y 21 | 22 | 23 | def atrous_spatial_pyramid_pooling(x, depth=256, norm='batch'): 24 | shape = x.shape 25 | 26 | # pooling 27 | pool = tf.reduce_mean(x, axis=[1, 2], keepdims=True) 28 | pool = conv(pool, depth, kernel_size=(1, 1)) 29 | pool = tf.image.resize(pool, (shape[1], shape[2])) 30 | 31 | # 4 level pyramid 32 | l1 = conv(x, depth, (1, 1), norm=norm) 33 | l2 = conv(x, depth, (3, 3), dilation=6, norm=norm) 34 | l3 = conv(x, depth, (3, 3), dilation=12, norm=norm) 35 | l4 = conv(x, depth, (3, 3), dilation=18, norm=norm) 36 | 37 | # concat features and 1x1 conv 38 | y = tf.concat([l1, l2, l3, l4, pool], axis=-1) 39 | y = conv(y, depth, kernel_size=(1, 1)) 40 | return y 41 | 42 | 43 | def deeplabv3(input_shape=(512, 512, 3), num_classes=2, encoder_weights='imagenet'): 44 | 45 | base_model = resnet50(input_shape=input_shape, encoder_weights=encoder_weights) 46 | 47 | y = base_model.outputs[-3] 48 | y = atrous_spatial_pyramid_pooling(y) 49 | y = layers.Conv2D(num_classes, kernel_size=1)(y) 50 | 51 | # upsampling 52 | y = tf.image.resize(y, input_shape[:2]) 53 | return Model(inputs=base_model.inputs, outputs=y) 54 | 55 | 56 | if __name__ == "__main__": 57 | model = deeplabv3() 58 | model.summary() 59 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/models/erfnet.py: -------------------------------------------------------------------------------- 1 | from ..layers import get_norm_by_name 2 | from ..settings import logger 3 | 4 | from tensorflow.keras.models import Model 5 | from tensorflow.keras import layers 6 | from tensorflow.keras import regularizers 7 | from tensorflow.keras import backend as K 8 | 9 | 10 | def conv(x, filters, kernel_size, strides=1, norm='batch', activation='relu', l2=None, rate=1, deconv=False): 11 | 12 | if deconv: 13 | output_shape = list(K.int_shape(x)[1:]) 14 | output_shape[0] = int(output_shape[0] * strides) 15 | output_shape[1] = int(output_shape[1] * strides) 16 | output_shape[2] = filters 17 | logger.debug(str(output_shape)) 18 | 19 | y = layers.Conv2DTranspose(filters, kernel_size, strides=strides, padding='SAME', activation=activation, 20 | dilation_rate=rate, kernel_regularizer=regularizers.l2(l2) if l2 else None)(x) 21 | # y = K.reshape(y, [-1] + output_shape) 22 | # y = layers.Reshape(output_shape)(y) 23 | else: 24 | y = layers.Conv2D(filters, kernel_size, strides=strides, padding='SAME', activation=activation, 25 | kernel_regularizer=regularizers.l2(l2) if l2 else None, 26 | dilation_rate=rate)(x) 27 | 28 | # kernel_regularizer=regularizers.l2(l2) if l2 else None 29 | if norm: 30 | y = get_norm_by_name(norm)(y) 31 | 32 | return y 33 | 34 | 35 | def factorized_module(x, dropout=0.3, dilation=[1, 1], l2=None): 36 | logger.debug("factorized: %s" % str(locals())) 37 | n = K.int_shape(x)[-1] 38 | y = conv(x, n, [3, 1], rate=dilation[0], norm=None, l2=l2) 39 | y = conv(y, n, [1, 3], rate=dilation[0], l2=l2) 40 | y = conv(y, n, [3, 1], rate=dilation[1], 
norm=None, l2=l2)
41 |     y = conv(y, n, [1, 3], rate=dilation[1], l2=l2)
42 |     y = layers.Dropout(dropout)(y)
43 |     y = layers.Add()([x, y])
44 |     return y
45 |
46 |
47 | def downsample(x, n, activation='relu', norm=None, l2=None):
48 |     logger.debug('downsample: %s' % str(locals()))
49 |     f_in = int(K.int_shape(x)[-1])
50 |     f_conv = int(n - f_in)
51 |     branch_1 = conv(x, f_conv, 3, strides=2, l2=l2)
52 |     branch_2 = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding="SAME")(x)
53 |     return layers.Concatenate(axis=-1)([branch_1, branch_2])
54 |
55 |
56 | def upsample(x, n, norm=None, activation=None, l2=None):
57 |     return conv(x, n, 3, strides=2, deconv=True, l2=l2)
58 |
59 |
60 | def erfnet(input_shape=(256, 256, 1), num_classes=3, l2=None):
61 |     x = layers.Input(shape=input_shape, name='inputs')
62 |
63 |     y = downsample(x, 16, l2=l2)
64 |     y = downsample(y, 64, l2=l2)
65 |
66 |     for i in range(5):
67 |         y = factorized_module(y, dilation=[1, 1], l2=l2)
68 |
69 |     y = downsample(y, 128, l2=l2)
70 |     for k in range(2):
71 |         for i in range(4):
72 |             y = factorized_module(y, dilation=[1, pow(2, i + 1)], l2=l2)
73 |
74 |     logger.debug("upsampling...")
75 |     y = upsample(y, 64, l2=l2)
76 |     for i in range(2):
77 |         y = factorized_module(y, dilation=[1, 1], l2=l2)
78 |
79 |     y = upsample(y, 16, l2=l2)
80 |     for i in range(2):
81 |         y = factorized_module(y, dilation=[1, 1], l2=l2)
82 |
83 |     y = upsample(y, num_classes, l2=l2)
84 |     return Model(inputs=x, outputs=y)
85 |
86 |
87 | if __name__ == "__main__":
88 |     erfnet().summary()
89 |
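The factory returns a plain tf.keras Model, e.g. (a sketch; shapes are illustrative):

from tf_semantic_segmentation.models import erfnet

model = erfnet(input_shape=(256, 256, 1), num_classes=2)
model.summary()  # fully convolutional, output shape (None, 256, 256, 2)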
raise Exception("upsample factor %d is invalid" % upsample_factor) 41 | 42 | y = layers.Conv2D(num_classes, kernel_size=1)(y) 43 | print("shape:", y.shape) 44 | # x = interpolate(x, imsize, **self._up_kwargs) 45 | return Model(inputs=base_model.inputs, outputs=y) 46 | 47 | 48 | if __name__ == "__main__": 49 | fcn(input_shape=(256, 256, 3), num_classes=3, upsample_factor=32).summary() 50 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/models/nested_unet.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras import layers, regularizers 2 | from ..layers import get_norm_by_name 3 | from tensorflow.keras import backend as K 4 | from tensorflow.keras.models import Model 5 | 6 | 7 | def conv(x, filters, kernel_size=(3, 3), l2=None, padding='SAME', activation='relu'): 8 | y = layers.Conv2D(filters, kernel_size=kernel_size, 9 | kernel_regularizer=regularizers.l2(l2) if l2 else None, 10 | activation=None, 11 | padding=padding)(x) 12 | y = get_norm_by_name('batch')(y) 13 | y = layers.Activation(activation)(y) 14 | return y 15 | 16 | 17 | def nested_conv(x, filters, residual=False): 18 | y = conv(x, filters) 19 | y = conv(y, filters) 20 | if residual: 21 | return x + y 22 | else: 23 | return y 24 | 25 | 26 | def nested_unet(input_shape=(256, 256, 3), num_classes=2, num_first_filters=64, depth=4): 27 | """UNet++ aims to improve segmentation accuracy, with a series of nested, dense skip pathways. 28 | 29 | Redesigned skip pathways made optimisation easier with the semantically similar feature maps. 30 | Dense skip connections improve segmentation accuracy and improve gradient flow. 31 | 32 | https://arxiv.org/abs/1807.10165 33 | """ 34 | n1 = num_first_filters 35 | filters = [num_first_filters * pow(2, i) for i in range(5)] 36 | assert(depth in [3, 4, 5]), 'depth has to be either 3, 4 or 5' 37 | 38 | pool = layers.MaxPooling2D(pool_size=(2, 2)) 39 | up = layers.UpSampling2D((2, 2), interpolation='bilinear') 40 | 41 | x = layers.Input(shape=input_shape, name='inputs') 42 | x0_0 = nested_conv(x, filters[0]) 43 | 44 | x1_0 = nested_conv(pool(x0_0), filters[1]) 45 | x0_1 = nested_conv(K.concatenate([x0_0, up(x1_0)]), filters[0]) 46 | 47 | x2_0 = nested_conv(pool(x1_0), filters[1]) 48 | x1_1 = nested_conv(K.concatenate([x1_0, up(x2_0)]), filters[1]) 49 | x0_2 = nested_conv(K.concatenate([x0_0, x0_1, up(x1_1)]), filters[0]) 50 | 51 | # 3 52 | x3_0 = nested_conv(pool(x2_0), filters[2]) 53 | x2_1 = nested_conv(K.concatenate([x2_0, up(x3_0)]), filters[2]) 54 | x1_2 = nested_conv(K.concatenate([x1_0, x1_1, up(x2_1)]), filters[1]) 55 | x0_3 = nested_conv(K.concatenate([x0_0, x0_1, x0_2, up(x1_2)]), filters[0]) 56 | 57 | # 4 58 | x4_0 = nested_conv(pool(x3_0), filters[3]) 59 | x3_1 = nested_conv(K.concatenate([x3_0, up(x4_0)]), filters[3]) 60 | x2_2 = nested_conv(K.concatenate([x2_0, x2_1, up(x3_1)]), filters[2]) 61 | x1_3 = nested_conv(K.concatenate([x1_0, x1_1, x1_2, up(x2_2)]), filters[1]) 62 | x0_4 = nested_conv(K.concatenate([x0_0, x0_1, x0_2, x0_3, up(x1_3)]), filters[0]) 63 | 64 | # 5 65 | x5_0 = nested_conv(pool(x4_0), filters[4]) 66 | x4_1 = nested_conv(K.concatenate([x4_0, up(x5_0)]), filters[4]) 67 | x3_2 = nested_conv(K.concatenate([x3_0, x3_1, up(x4_1)]), filters[3]) 68 | x2_3 = nested_conv(K.concatenate([x2_0, x2_1, x2_2, up(x3_2)]), filters[2]) 69 | x1_4 = nested_conv(K.concatenate([x1_0, x1_1, x1_2, x1_3, up(x2_3)]), filters[1]) 70 | x0_5 = nested_conv(K.concatenate([x0_0, x0_1, 
--------------------------------------------------------------------------------
/tf_semantic_segmentation/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras import optimizers
2 |
3 |
4 | def get_optimizer_by_name(name, lr) -> optimizers.Optimizer:
5 |     if name == 'adam':
6 |         return optimizers.Adam(learning_rate=lr)
7 |
8 |     elif name == 'radam':
9 |         import tensorflow_addons as tfa
10 |         return tfa.optimizers.RectifiedAdam(learning_rate=lr)
11 |
12 |     elif name == 'ranger':
13 |         import tensorflow_addons as tfa
14 |         radam = tfa.optimizers.RectifiedAdam(learning_rate=lr)
15 |         ranger = tfa.optimizers.Lookahead(radam, sync_period=6, slow_step_size=0.5)
16 |         return ranger
17 |     else:
18 |         raise Exception("unknown optimizer %s" % name)
19 |
20 |
21 | names = ['adam', 'radam', 'ranger']
22 |
23 | __all__ = ['names', 'get_optimizer_by_name']
24 |
--------------------------------------------------------------------------------
/tf_semantic_segmentation/processing/__init__.py:
--------------------------------------------------------------------------------
1 | from enum import IntEnum
2 |
3 |
4 | class ColorMode(IntEnum):
5 |     RGB, GRAY, NONE = range(3)
6 |
--------------------------------------------------------------------------------
/tf_semantic_segmentation/processing/image.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import imageio
4 |
5 |
6 | def fixed_resize(image, width=None, height=None, interpolation=cv2.INTER_NEAREST):
7 |     assert width is not None or height is not None, 'either width or height must be set'
8 |     size = image.shape[:2][::-1]
9 |     depth = image.shape[-1] if len(image.shape) == 3 else 0
10 |     if width:
11 |         f = width / size[0]
12 |         height = int(size[1] * f)
13 |     else:
14 |         f = height / size[1]
15 |         width = int(size[0] * f)
16 |
17 |     if image.shape[0] != height or image.shape[1] != width:
18 |         image = cv2.resize(image, (width, height), interpolation=interpolation)
19 |
20 |     if depth and len(image.shape) == 2:
21 |         image = np.expand_dims(image, axis=-1)
22 |     return image
23 |
24 |
25 | def grayscale_grid_vis(X, nh, nw, save_path=None):
26 |     """ https://github.com/Newmu/dcgan_code/blob/master/lib/vis.py """
27 |     h, w = X[0].shape[:2]
28 |     img = np.zeros((h * nh, w * nw), np.uint8)
29 |     for n, x in enumerate(X):
30 |         j = int(n / nw)
31 |         i = n % nw
32 |         img[int(j * h):int(j * h + h), int(i * w):int(i * w + w)] = x
33 |
34 |     if save_path is not None:
35 |         imageio.imwrite(save_path, img)
36 |     return img
37 |
--------------------------------------------------------------------------------
/tf_semantic_segmentation/settings.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | def get_logger(name='tf-semantic-segmentation'):
5 |
6 |     logger = logging.getLogger(name)
7 |     logger.setLevel(logging.DEBUG)
8 |
9 |     # create formatter and add it 
to the handlers 10 | _formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 11 | console_handler = logging.StreamHandler() 12 | 13 | # output of subprocess call will not be printed into the console if using logging.INFO 14 | # this prohibits printing of subprocess call output twice 15 | console_handler.setLevel(logging.DEBUG) 16 | console_handler.setFormatter(_formatter) 17 | if len(logger.handlers) == 0: 18 | logger.addHandler(console_handler) 19 | return logger 20 | 21 | 22 | logger = get_logger() 23 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/threading.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | import sys 3 | import subprocess 4 | import multiprocessing 5 | import threading 6 | import tqdm 7 | import os 8 | 9 | 10 | class ThreadWithReturnValue(threading.Thread): 11 | def __init__(self, group=None, target=None, name=None, 12 | args=(), kwargs={}, Verbose=None): 13 | threading.Thread.__init__(self, group, target, name, args, kwargs) 14 | self._return = None 15 | 16 | def run(self): 17 | if self._target is not None: 18 | self._return = self._target(*self._args, **self._kwargs) 19 | 20 | def join(self, *args): 21 | if sys.version_info >= (3, 0): 22 | threading.Thread.join(self, *args) 23 | else: 24 | threading.Thread.join(self) 25 | 26 | return self._return 27 | 28 | 29 | def parallize_v2(f, args, desc='threading'): 30 | threads = multiprocessing.cpu_count() 31 | return parallize(f, args, threads=threads, desc=desc) 32 | 33 | 34 | def parallize(f, args, threads=None, desc='threading'): 35 | """ 36 | Args: 37 | - f: function 38 | - args: list or list(tuple), list when threads not None 39 | """ 40 | def parse_arg(arg): 41 | if type(arg) == list or type(arg) == set or type(arg) == tuple: 42 | return tuple(arg) 43 | elif type(arg) == dict: 44 | return arg 45 | else: 46 | return (arg, ) 47 | 48 | if threads is not None: 49 | 50 | results = [] 51 | for i in tqdm.trange(0, len(args), threads, desc=desc): 52 | func_args = [parse_arg(arg) for arg in args[i:i + threads]] 53 | if len(func_args) == 0: 54 | continue 55 | 56 | if type(func_args[0]) == dict: 57 | active_threads = [ThreadWithReturnValue( 58 | target=f, kwargs=arg) for arg in func_args] 59 | else: 60 | active_threads = [ThreadWithReturnValue( 61 | target=f, args=arg) for arg in func_args] 62 | [thread.start() for thread in active_threads] 63 | results += [thread.join() for thread in active_threads] 64 | return results 65 | else: 66 | if len(args) == 0: 67 | return [] 68 | 69 | if type(args[0]) == dict: 70 | active_threads = [ThreadWithReturnValue( 71 | target=f, kwargs=arg) for arg in args] 72 | else: 73 | args = [parse_arg(arg) for arg in args] 74 | active_threads = [ThreadWithReturnValue( 75 | target=f, args=arg) for arg in args] 76 | 77 | [thread.start() for thread in active_threads] 78 | return [thread.join() for thread in active_threads] 79 | 80 | 81 | def parallize_v3(f, args, n_processes=None, desc='parallize_v3'): 82 | if n_processes == None: 83 | n_processes = multiprocessing.cpu_count() 84 | 85 | with multiprocessing.Pool(n_processes) as pool: 86 | results = [r for r in tqdm.tqdm(pool.imap(f, args), desc=desc, total=len(args))] 87 | 88 | return results 89 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/version.py: -------------------------------------------------------------------------------- 1 | version = 
(0, 3, 1) 2 | __version__ = ".".join(map(str, version)) 3 | -------------------------------------------------------------------------------- /tf_semantic_segmentation/visualizations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baudcode/tf-semantic-segmentation/6e425287f309f86b931e34bbbb804bbc46963e28/tf_semantic_segmentation/visualizations/__init__.py -------------------------------------------------------------------------------- /tf_semantic_segmentation/visualizations/masks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import ImageColor 3 | import colorsys 4 | import random 5 | import cv2 6 | 7 | 8 | def get_colors(N, shuffle=False, bright=True): 9 | """ 10 | https://github.com/pedropro/TACO/blob/master/detector/visualize.py 11 | 12 | Generate random colors. 13 | To get visually distinct colors, generate them in HSV space then 14 | convert to RGB. 15 | """ 16 | colors = [[0, 0, 0]] 17 | if N == 1: 18 | return colors 19 | 20 | brightness = 1.0 if bright else 0.7 21 | hsv = [(i / (N - 1), 1, brightness) for i in range(N - 1)] 22 | colors.extend(list(map(lambda c: list(map(lambda x: int(x * 255), colorsys.hsv_to_rgb(*c))), hsv))) 23 | 24 | if shuffle: 25 | random.shuffle(colors) 26 | return colors 27 | 28 | 29 | def apply_mask(image, mask, color, alpha=0.5): 30 | for c in range(3): 31 | image[:, :, c] = np.where(mask == 1, 32 | image[:, :, c] * 33 | (1 - alpha) + alpha * color[c], 34 | image[:, :, c]) 35 | return image 36 | 37 | 38 | def overlay_classes(image, target, colors, num_classes, alpha=0.5): 39 | assert(len(colors) == num_classes) 40 | for k, color in enumerate(colors): 41 | mask = np.where(target == k, 1, 0) 42 | image = apply_mask(image, mask, color, alpha=alpha) 43 | return image 44 | 45 | 46 | def get_colored_segmentation_mask(predictions, num_classes, images=None, binary_threshold=0.5, alpha=0.5): 47 | """ 48 | Arguments: 49 | 50 | predictions: ndarray - BHWC (if C=1 - np.float32 else np.uint8) probabilities 51 | num_classes: int - number of classes 52 | images: ndarray - float32 (0..1) or uint8 (0..255) 53 | binary_threshold: float - when predicting only 1 value, threshold to set label to 1 54 | alpha: float - overlay percentage 55 | """ 56 | predictions = predictions.copy() 57 | colors = get_colors(num_classes) 58 | 59 | if images is None: 60 | shape = (predictions.shape[0], predictions.shape[1], predictions.shape[2], 3) 61 | images = np.zeros(shape, np.uint8) 62 | else: 63 | images = images.copy() 64 | 65 | if images.dtype == "float32" or images.dtype == 'float64': 66 | images = (images * 255).astype(np.uint8) 67 | 68 | if images.shape[-1] == 1: 69 | images = [cv2.cvtColor(i.copy(), cv2.COLOR_GRAY2RGB) for i in images] 70 | images = np.asarray(images) 71 | 72 | if predictions.shape[-1] == 1: 73 | # remove channel dimension 74 | predictions = np.squeeze(predictions, axis=-1) 75 | 76 | # set either zero or one 77 | predictions[predictions > binary_threshold] = 1.0 78 | predictions[predictions <= binary_threshold] = 0.0 79 | else: 80 | # find the argmax channel from all channels 81 | predictions = np.argmax(predictions, axis=-1) 82 | 83 | predictions = predictions.astype(np.uint8) 84 | 85 | for i in range(len(predictions)): 86 | images[i, :, :, :] = overlay_classes(images[i, :, :, :].copy(), predictions[i], colors, num_classes, alpha=alpha) 87 | 88 | return images 89 | 
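Overlaying a batch of predictions (a sketch; random probabilities stand in for real model output):

import numpy as np
from tf_semantic_segmentation.visualizations.masks import get_colored_segmentation_mask

predictions = np.random.uniform(size=(4, 64, 64, 5)).astype(np.float32)
overlays = get_colored_segmentation_mask(predictions, num_classes=5, alpha=0.5)
print(overlays.shape, overlays.dtype)  # (4, 64, 64, 3) uint8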
--------------------------------------------------------------------------------
/tf_semantic_segmentation/visualizations/show.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def show_images(images, cols=1, titles=None, show=True):
5 |     """Display a list of images in a single figure with matplotlib.
6 |     https://gist.github.com/soply/f3eec2e79c165e39c9d540e916142ae1
7 |
8 |     Parameters
9 |     ---------
10 |     images: List of np.arrays compatible with plt.imshow.
11 |
12 |     cols (Default = 1): Number of columns in figure (number of rows is
13 |                         set to np.ceil(n_images/float(cols))).
14 |
15 |     titles: List of titles corresponding to each image. Must have
16 |             the same length as images.
17 |     """
18 |     from matplotlib import pyplot as plt
19 |     assert (titles is None) or (len(images) == len(titles))
20 |     n_images = len(images)
21 |     if titles is None:
22 |         titles = ['Image (%d)' % i for i in range(1, n_images + 1)]
23 |     fig = plt.figure()
24 |     for n, (image, title) in enumerate(zip(images, titles)):
25 |         a = fig.add_subplot(cols, int(np.ceil(n_images / float(cols))), n + 1)
26 |         if image.ndim == 3 and image.shape[2] == 1:
27 |             image = np.squeeze(image, axis=2)
28 |             gray = True
29 |         if image.ndim == 2:
30 |             gray = True
31 |         else:
32 |             gray = False
33 |
34 |         if gray:
35 |             plt.gray()
36 |
37 |         if image.dtype == np.uint8:
38 |             a.imshow(image, cmap='gray' if gray else 'jet', vmin=0, vmax=255)
39 |         else:
40 |             a.imshow(image, cmap='gray' if gray else 'jet')
41 |
42 |         a.set_title(title)
43 |     fig.set_size_inches(np.array(fig.get_size_inches()) * n_images)
44 |     if show:
45 |         plt.show()
46 |
--------------------------------------------------------------------------------
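Displaying the overlays interactively (a sketch; random images stand in for real data):

import numpy as np
from tf_semantic_segmentation.visualizations.show import show_images

images = [np.random.randint(0, 255, (64, 64, 3), np.uint8) for _ in range(4)]
show_images(images, cols=2, titles=['a', 'b', 'c', 'd'])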