├── tests ├── __init__.py ├── utils │ ├── __init__.py │ ├── test_transform.py │ └── test_anchors.py ├── backend │ ├── __init__.py │ └── test_common.py ├── layers │ ├── __init__.py │ ├── test_filter_detections.py │ └── test_misc.py ├── models │ ├── __init__.py │ ├── test_densenet.py │ └── test_mobilenet.py ├── preprocessing │ ├── __init__.py │ ├── test_csv_generator.py │ └── test_generator.py ├── test_losses.py └── bin │ └── test_train.py ├── keras_retinanet ├── __init__.py ├── bin │ ├── __init__.py │ └── convert_model.py ├── utils │ ├── __init__.py │ ├── model.py │ ├── keras_version.py │ ├── compute_overlap.pyx │ ├── config.py │ ├── colors.py │ ├── coco_eval.py │ └── visualization.py ├── preprocessing │ ├── __init__.py │ ├── coco.py │ ├── kitti.py │ ├── pascal_voc.py │ └── csv_generator.py ├── callbacks │ ├── __init__.py │ ├── coco.py │ ├── common.py │ └── eval.py ├── backend │ ├── __init__.py │ ├── cntk_backend.py │ ├── theano_backend.py │ ├── dynamic.py │ ├── tensorflow_backend.py │ └── common.py ├── layers │ ├── __init__.py │ ├── _misc.py │ └── filter_detections.py ├── initializers.py ├── models │ ├── densenet.py │ ├── mobilenet.py │ ├── __init__.py │ ├── resnet.py │ └── vgg.py └── losses.py ├── slurm_template.sh ├── examples ├── 20130320T004433.135911.Cam6_52.png ├── 20130320T004433.707425.Cam6_54.png ├── 20130320T004443.802883.Cam6_42.png ├── 20130320T004445.136158.Cam6_23.png ├── .ipynb_checkpoints │ ├── Inference-checkpoint.ipynb │ ├── Load_model-checkpoint.ipynb │ └── convert_csvs-checkpoint.ipynb ├── config_202.ini ├── config_512.ini ├── inference.sh ├── data_preproccessing.ipynb ├── inference.py ├── convert_annotations.py └── anchor_optimization.py ├── .gitignore ├── .gitmodules ├── setup.cfg ├── run_eval.sh ├── .travis.yml ├── CONTRIBUTORS.md ├── ISSUE_TEMPLATE.md ├── setup.py └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /keras_retinanet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/backend/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /keras_retinanet/bin/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /keras_retinanet/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /keras_retinanet/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /keras_retinanet/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .common 
import * # noqa: F401,F403 2 | -------------------------------------------------------------------------------- /keras_retinanet/backend/__init__.py: -------------------------------------------------------------------------------- 1 | from .dynamic import * # noqa: F401,F403 2 | from .common import * # noqa: F401,F403 3 | -------------------------------------------------------------------------------- /slurm_template.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -p lyceum 4 | 5 | #SBATCH --time=60:00:00 # walltime 6 | #SBATCH --gres=gpu:1 7 | -------------------------------------------------------------------------------- /examples/20130320T004433.135911.Cam6_52.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikostsagk/Apple-detection/HEAD/examples/20130320T004433.135911.Cam6_52.png -------------------------------------------------------------------------------- /examples/20130320T004433.707425.Cam6_54.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikostsagk/Apple-detection/HEAD/examples/20130320T004433.707425.Cam6_54.png -------------------------------------------------------------------------------- /examples/20130320T004443.802883.Cam6_42.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikostsagk/Apple-detection/HEAD/examples/20130320T004443.802883.Cam6_42.png -------------------------------------------------------------------------------- /examples/20130320T004445.136158.Cam6_23.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikostsagk/Apple-detection/HEAD/examples/20130320T004445.136158.Cam6_23.png -------------------------------------------------------------------------------- 
/examples/.ipynb_checkpoints/Inference-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /examples/.ipynb_checkpoints/Load_model-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /examples/config_202.ini: -------------------------------------------------------------------------------- 1 | [anchor_parameters] 2 | sizes = 32 64 128 256 512 3 | strides = 8 16 32 64 128 4 | ratios = 0.805 1.0 1.242 5 | scales = 0.696 1.0 1.313 6 | -------------------------------------------------------------------------------- /examples/config_512.ini: -------------------------------------------------------------------------------- 1 | [anchor_parameters] 2 | sizes = 32 64 128 256 512 3 | strides = 8 16 32 64 128 4 | ratios = 0.66 1.0 1.514 5 | scales = 1.271 1.577 2.0 6 | -------------------------------------------------------------------------------- /examples/inference.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -p lyceum 4 | 5 | #SBATCH --time=60:00:00 # walltime 6 | #SBATCH --gres=gpu:1 7 | 8 | python inference.py -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.pyc 3 | *.c 4 | *.so 5 | /.pytest_cache 6 | /.cache 7 | *.h5 8 | *.zip 9 | data/* 10 | *.sh 11 | results/ 12 | imagenet-weights/ 13 | -------------------------------------------------------------------------------- /keras_retinanet/layers/__init__.py: 
-------------------------------------------------------------------------------- 1 | from ._misc import RegressBoxes, UpsampleLike, Anchors, ClipBoxes # noqa: F401 2 | from .filter_detections import FilterDetections # noqa: F401 3 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "tests/test-data"] 2 | path = tests/test-data 3 | url = https://github.com/fizyr/keras-retinanet-test-data.git 4 | [submodule "msc-project-report"] 5 | path = msc-project-report 6 | url = https://github.com/nikostsagk/msc-project-report.git 7 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # ignore: 2 | # E201 whitespace after '[' 3 | # E202 whitespace before ']' 4 | # E203 whitespace before ':' 5 | # E221 multiple spaces before operator 6 | # E241 multiple spaces after ',' 7 | # E251 unexpected spaces around keyword / parameter equals 8 | # E501 line too long (85 > 79 characters) 9 | # W504 line break after binary operator 10 | [tool:pytest] 11 | flake8-max-line-length = 127 12 | flake8-ignore = E201 E202 E203 E221 E241 E251 E402 E501 W504 13 | -------------------------------------------------------------------------------- /run_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -p lyceum 4 | 5 | #SBATCH --time=60:00:00 # walltime 6 | #SBATCH --gres=gpu:1 7 | 8 | python ./keras_retinanet/bin/evaluate.py --backbone 'resnet152' --iou-threshold 0.5 --convert-model --save-path ./dummy_predictions --score-threshold 0.99 \ 9 | csv ./data/acfr-fruit-dataset/apples/rectangular_annotations/dummy_val_annotations.csv ./data/acfr-fruit-dataset/apples/rectangular_annotations/classes.csv \ 10 | ./resnet152_csv_10.h5 11 | 12 | 13 | 
-------------------------------------------------------------------------------- /keras_retinanet/backend/cntk_backend.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | -------------------------------------------------------------------------------- /keras_retinanet/backend/theano_backend.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
import os

# Backend that is used when the KERAS_BACKEND environment variable is not set.
_BACKEND = "tensorflow"

if "KERAS_BACKEND" in os.environ:
    _backend = os.environ["KERAS_BACKEND"]

    backends = {
        "cntk",
        "tensorflow",
        "theano"
    }

    # Validate with an explicit exception instead of `assert`: assertions are
    # removed when Python runs with -O and fail without any message otherwise.
    # This is the same error the final `else` branch below raises.
    if _backend not in backends:
        raise ValueError("Unknown backend: " + str(_backend))

    _BACKEND = _backend

# Re-export everything from the backend specific module.
if _BACKEND == "cntk":
    from .cntk_backend import *  # noqa: F401,F403
elif _BACKEND == "theano":
    from .theano_backend import *  # noqa: F401,F403
elif _BACKEND == "tensorflow":
    from .tensorflow_backend import *  # noqa: F401,F403
else:
    raise ValueError("Unknown backend: " + str(_BACKEND))
regression = np.array([ 11 | [ 12 | [0, 0, 0, 0], 13 | [0, 0, 0, 0], 14 | [0, 0, 0, 0], 15 | [0, 0, 0, 0], 16 | ] 17 | ], dtype=keras.backend.floatx()) 18 | regression = keras.backend.variable(regression) 19 | 20 | regression_target = np.array([ 21 | [ 22 | [0, 0, 0, 1, 1], 23 | [0, 0, 1, 0, 1], 24 | [0, 0, 0.05, 0, 1], 25 | [0, 0, 1, 0, 0], 26 | ] 27 | ], dtype=keras.backend.floatx()) 28 | regression_target = keras.backend.variable(regression_target) 29 | 30 | loss = keras_retinanet.losses.smooth_l1()(regression_target, regression) 31 | loss = keras.backend.eval(loss) 32 | 33 | assert loss == pytest.approx((((1 - 0.5 / 9) * 2 + (0.5 * 9 * 0.05 ** 2)) / 3)) 34 | -------------------------------------------------------------------------------- /CONTRIBUTORS.md: -------------------------------------------------------------------------------- 1 | # Contributors 2 | 3 | This is a list of people who contributed patches to keras-retinanet. 4 | 5 | If you feel you should be listed here or if you have any other questions/comments on your listing here, 6 | please create an issue or pull request at https://github.com/fizyr/keras-retinanet/ 7 | 8 | * Hans Gaiser 9 | * Maarten de Vries 10 | * Ashley Williamson 11 | * Yann Henon 12 | * Valeriu Lacatusu 13 | * András Vidosits 14 | * Cristian Gratie 15 | * jjiunlin 16 | * Sorin Panduru 17 | * Rodrigo Meira de Andrade 18 | * Enrico Liscio 19 | * Mihai Morariu 20 | * pedroconceicao 21 | * jjiun 22 | * Wudi Fang 23 | * Mike Clark 24 | * hannesedvartsen 25 | * Max Van Sande 26 | * Pierre Dérian 27 | * ori 28 | * mxvs 29 | * mwilder 30 | * Muhammed Kocabas 31 | * Max Van Sande 32 | * Koen Vijverberg 33 | * iver56 34 | * hnsywangxin 35 | * Guillaume Erhard 36 | * Eduardo Ramos 37 | * DiegoAgher 38 | * Alexander Pacha 39 | -------------------------------------------------------------------------------- /keras_retinanet/utils/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 
def freeze(model):
    """ Mark every layer of a model as non-trainable.

    Frozen layers keep their current weights; training does not update them.

    The given model is changed in-place and is also returned,
    so the call can be chained with other functions.
    """
    for frozen_layer in model.layers:
        setattr(frozen_layer, 'trainable', False)

    return model
class PriorProbability(keras.initializers.Initializer):
    """ Initializer that fills the weights with the value belonging to a prior probability.
    """

    def __init__(self, probability=0.01):
        self.probability = probability

    def get_config(self):
        # only the probability is needed to re-create this initializer
        return dict(probability=self.probability)

    def __call__(self, shape, dtype=None):
        prior = self.probability

        # bias of -log((1 - p) / p) for foreground
        bias = -math.log((1 - prior) / prior)

        return np.ones(shape, dtype=dtype) * bias
Clearly describe the issues you're having including the expected behaviour, the actual behaviour 15 | and the steps required to trigger the issue. 16 | 6. Include relevant output from the commands you're executing, including full stack traces where relevant. 17 | 7. Remove this entire message and replace it with your issue. 18 | -------------------------------------------------------------------------------- /tests/models/test_densenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2018 vidosits (https://github.com/vidosits/) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
@pytest.mark.parametrize("backbone", parameters)
def test_backbone(backbone):
    """ Build a RetinaNet on the given DenseNet backbone and run one fit call on an all-zero sample.
    """
    # silence every warning raised while this test runs
    warnings.simplefilter('ignore')

    num_classes = 10

    # a single all-zero image with all-zero regression / classification targets
    images = np.zeros((1, 200, 400, 3), dtype=np.float32)
    regression_target = np.zeros((1, 14814, 5), dtype=np.float32)
    classification_target = np.zeros((1, 14814, num_classes + 1))

    image_input = keras.layers.Input(images[0].shape)

    model = DenseNetBackbone(backbone).retinanet(num_classes=num_classes, inputs=image_input)
    model.summary()

    # compile model
    loss_functions = {
        'regression': losses.smooth_l1(),
        'classification': losses.focal()
    }
    model.compile(loss=loss_functions, optimizer=keras.optimizers.adam(lr=1e-5, clipnorm=0.001))

    model.fit(images, [regression_target, classification_target], batch_size=1)
def keras_version(version=None):
    """ Get the Keras version.

    Only the leading digits of every dot-separated component are used, so a
    version with a suffix such as '2.2.4-tf' or '2.3.0rc1' is parsed instead of
    raising a ValueError. Parsing stops at the first component that does not
    start with a digit (for example the 'dev1' in '2.4.0.dev1').

    Args
        version: Version string to parse. Defaults to keras.__version__.

    Returns
        tuple of ints; (major, minor, patch) for a regular 'X.Y.Z' version.
    """
    if version is None:
        version = keras.__version__

    parsed = []
    for component in version.split('.'):
        digits = ''
        for character in component:
            if character not in '0123456789':
                break
            digits += character

        if not digits:
            # nothing numeric left that could take part in a version comparison
            break

        parsed.append(int(digits))

    return tuple(parsed)
def compute_overlap(
    np.ndarray[double, ndim=2] boxes,
    np.ndarray[double, ndim=2] query_boxes
):
    """ Compute the pairwise overlap (intersection over union) between two sets of boxes.

    Columns 0 and 2 of a box are used as its lower / upper bound along one axis,
    columns 1 and 3 as the bounds along the other axis.
    Every extent is computed as `upper - lower + 1`.

    Args
        boxes: (N, 4) ndarray of float
        query_boxes: (K, 4) ndarray of float

    Returns
        overlaps: (N, K) ndarray of overlap between boxes and query_boxes
    """
    cdef unsigned int N = boxes.shape[0]
    cdef unsigned int K = query_boxes.shape[0]
    cdef np.ndarray[double, ndim=2] overlaps = np.zeros((N, K), dtype=np.float64)
    cdef double iw, ih, box_area
    cdef double ua
    cdef unsigned int k, n
    for k in range(K):
        # the area of the query box is shared by all N comparisons
        box_area = (
            (query_boxes[k, 2] - query_boxes[k, 0] + 1) *
            (query_boxes[k, 3] - query_boxes[k, 1] + 1)
        )
        for n in range(N):
            iw = (
                min(boxes[n, 2], query_boxes[k, 2]) -
                max(boxes[n, 0], query_boxes[k, 0]) + 1
            )
            if iw > 0:
                ih = (
                    min(boxes[n, 3], query_boxes[k, 3]) -
                    max(boxes[n, 1], query_boxes[k, 1]) + 1
                )
                if ih > 0:
                    # union = area(box) + area(query box) - intersection;
                    # `ua` is a C double already, so no conversion through
                    # a Python level np.float64 object is needed.
                    ua = (
                        (boxes[n, 2] - boxes[n, 0] + 1) *
                        (boxes[n, 3] - boxes[n, 1] + 1) +
                        box_area - iw * ih
                    )
                    overlaps[n, k] = iw * ih / ua
    return overlaps
-------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import configparser 18 | import numpy as np 19 | import keras 20 | from ..utils.anchors import AnchorParameters 21 | 22 | 23 | def read_config_file(config_path): 24 | config = configparser.ConfigParser() 25 | 26 | with open(config_path, 'r') as file: 27 | config.read_file(file) 28 | 29 | assert 'anchor_parameters' in config, \ 30 | "Malformed config file. Verify that it contains the anchor_parameters section." 31 | 32 | config_keys = set(config['anchor_parameters']) 33 | default_keys = set(AnchorParameters.default.__dict__.keys()) 34 | 35 | assert config_keys <= default_keys, \ 36 | "Malformed config file. 
def parse_anchor_parameters(config):
    """ Create AnchorParameters from the 'anchor_parameters' section of a config.

    The 'ratios', 'scales', 'sizes' and 'strides' entries are whitespace separated lists of numbers.
    Any run of whitespace (repeated spaces, tabs, leading / trailing blanks) separates two values.

    Args
        config: Mapping with an 'anchor_parameters' section that holds the four entries as strings.

    Returns
        The result of AnchorParameters(sizes, strides, ratios, scales), where ratios and scales are
        numpy arrays of type keras.backend.floatx() and sizes and strides are lists of int.

    Raises
        KeyError: If the section or one of the four entries is missing.
        ValueError: If a value can not be converted to a number.
    """
    section = config['anchor_parameters']

    # split() without argument never produces empty tokens, where split(' ') returns ''
    # for every doubled space and makes float('') / int('') fail.
    ratios = np.array([float(value) for value in section['ratios'].split()], keras.backend.floatx())
    scales = np.array([float(value) for value in section['scales'].split()], keras.backend.floatx())
    sizes = [int(value) for value in section['sizes'].split()]
    strides = [int(value) for value in section['strides'].split()]

    return AnchorParameters(sizes, strides, ratios, scales)
@pytest.mark.parametrize("backbone, alpha", parameters)
def test_backbone(backbone, alpha):
    """ Build a RetinaNet on the given MobileNet backbone and run one fit call on an all-zero sample.
    """
    # silence every warning raised while this test runs
    warnings.simplefilter('ignore')

    num_classes = 10

    # a single all-zero image with all-zero regression / classification targets
    images = np.zeros((1, 1024, 363, 3), dtype=np.float32)
    regression_target = np.zeros((1, 68760, 5), dtype=np.float32)
    classification_target = np.zeros((1, 68760, num_classes + 1))

    image_input = keras.layers.Input(images[0].shape)

    # the backbone is addressed as '<name>_<alpha>'
    backbone_name = '{}_{}'.format(backbone, alpha)
    training_model = MobileNetBackbone(backbone=backbone_name).retinanet(num_classes=num_classes, inputs=image_input)
    training_model.summary()

    # compile model
    loss_functions = {
        'regression': losses.smooth_l1(),
        'classification': losses.focal()
    }
    training_model.compile(loss=loss_functions, optimizer=keras.optimizers.adam(lr=1e-5, clipnorm=0.001))

    training_model.fit(images, [regression_target, classification_target], batch_size=1)
"train_set = os.path.join(os.path.join(script_path, os.pardir), 'data/sets/train.txt')" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 59, 26 | "metadata": {}, 27 | "outputs": [ 28 | { 29 | "name": "stdout", 30 | "output_type": "stream", 31 | "text": [ 32 | "(896, 202, 308, 3)\n" 33 | ] 34 | } 35 | ], 36 | "source": [ 37 | "def read_image_bgr(path):\n", 38 | " image = np.asarray(Image.open(path).convert('RGB'))\n", 39 | " return image[:, :, ::-1].copy()\n", 40 | "\n", 41 | "\n", 42 | "image_batch = []\n", 43 | "with open(train_set) as f:\n", 44 | " for n, name in enumerate(f):\n", 45 | " image_batch.append(read_image_bgr(image_path + name[:-1] + '.png')) #name[:-1] to ignore \"\\n\"\n", 46 | "\n", 47 | "image_array = np.asarray(image_batch)\n", 48 | "print(image_array.shape) # Stored in BGR" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 73, 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "name": "stdout", 58 | "output_type": "stream", 59 | "text": [ 60 | "Mean in B G R: [135.19350556 121.36561177 104.94353401]\n" 61 | ] 62 | } 63 | ], 64 | "source": [ 65 | "mean = np.mean(image_array, axis=(0,1,2))\n", 66 | "print('Mean in B G R:', mean)" 67 | ] 68 | } 69 | ], 70 | "metadata": { 71 | "kernelspec": { 72 | "display_name": "Python 3", 73 | "language": "python", 74 | "name": "python3" 75 | }, 76 | "language_info": { 77 | "codemirror_mode": { 78 | "name": "ipython", 79 | "version": 3 80 | }, 81 | "file_extension": ".py", 82 | "mimetype": "text/x-python", 83 | "name": "python", 84 | "nbconvert_exporter": "python", 85 | "pygments_lexer": "ipython3", 86 | "version": "3.6.8" 87 | } 88 | }, 89 | "nbformat": 4, 90 | "nbformat_minor": 2 91 | } 92 | -------------------------------------------------------------------------------- /tests/bin/test_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache 
License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import keras_retinanet.bin.train 18 | import keras.backend 19 | 20 | import warnings 21 | 22 | import pytest 23 | 24 | 25 | @pytest.fixture(autouse=True) 26 | def clear_session(): 27 | # run before test (do nothing) 28 | yield 29 | # run after test, clear keras session 30 | keras.backend.clear_session() 31 | 32 | 33 | def test_coco(): 34 | # ignore warnings in this test 35 | warnings.simplefilter('ignore') 36 | 37 | # run training / evaluation 38 | keras_retinanet.bin.train.main([ 39 | '--epochs=1', 40 | '--steps=1', 41 | '--no-weights', 42 | '--no-snapshots', 43 | 'coco', 44 | 'tests/test-data/coco', 45 | ]) 46 | 47 | 48 | def test_pascal(): 49 | # ignore warnings in this test 50 | warnings.simplefilter('ignore') 51 | 52 | # run training / evaluation 53 | keras_retinanet.bin.train.main([ 54 | '--epochs=1', 55 | '--steps=1', 56 | '--no-weights', 57 | '--no-snapshots', 58 | 'pascal', 59 | 'tests/test-data/pascal', 60 | ]) 61 | 62 | 63 | def test_csv(): 64 | # ignore warnings in this test 65 | warnings.simplefilter('ignore') 66 | 67 | # run training / evaluation 68 | keras_retinanet.bin.train.main([ 69 | '--epochs=1', 70 | '--steps=1', 71 | '--no-weights', 72 | '--no-snapshots', 73 | 'csv', 74 | 'tests/test-data/csv/annotations.csv', 75 | 'tests/test-data/csv/classes.csv', 76 | ]) 77 | 78 | 79 | def test_vgg(): 80 | # ignore warnings in this test 81 | warnings.simplefilter('ignore') 82 | 83 | # run training / evaluation 
class BuildExtension(setuptools.Command):
    """ Proxy for setuptools' ``build_ext`` that adds the NumPy include directory.

    The option metadata is copied from the distutils ``build_ext`` command.
    The real setuptools command is created lazily in ``__init__`` and every
    attribute read or write on this object is forwarded to it.
    """
    description = DistUtilsBuildExt.description
    user_options = DistUtilsBuildExt.user_options
    boolean_options = DistUtilsBuildExt.boolean_options
    help_options = DistUtilsBuildExt.help_options

    def __init__(self, *args, **kwargs):
        from setuptools.command.build_ext import build_ext as SetupToolsBuildExt

        # Bypass __setattr__ to avoid infinite recursion: a plain assignment
        # would be forwarded to self._command, which does not exist yet.
        wrapped_command = SetupToolsBuildExt(*args, **kwargs)
        self.__dict__['_command'] = wrapped_command

    def __getattr__(self, name):
        # Only called for names not found on the proxy itself.
        return getattr(self._command, name)

    def __setattr__(self, name, value):
        # All attribute writes land on the wrapped command.
        setattr(self._command, name, value)

    def initialize_options(self, *args, **kwargs):
        """ Forward to the wrapped command. """
        return self._command.initialize_options(*args, **kwargs)

    def finalize_options(self, *args, **kwargs):
        """ Finalize the wrapped command, then append NumPy's header directory. """
        result = self._command.finalize_options(*args, **kwargs)

        # numpy is imported here rather than at module level; presumably so that
        # setup.py can be loaded before numpy is installed -- confirm.
        import numpy
        self.include_dirs.append(numpy.get_include())
        return result

    def run(self, *args, **kwargs):
        """ Forward to the wrapped command. """
        return self._command.run(*args, **kwargs)
'https://github.com/fizyr/keras-retinanet', 50 | author = 'Hans Gaiser', 51 | author_email = 'h.gaiser@fizyr.com', 52 | maintainer = 'Hans Gaiser', 53 | maintainer_email = 'h.gaiser@fizyr.com', 54 | cmdclass = {'build_ext': BuildExtension}, 55 | packages = setuptools.find_packages(), 56 | install_requires = ['keras', 'keras-resnet', 'six', 'scipy', 'cython', 'Pillow', 'opencv-python', 'progressbar2'], 57 | entry_points = { 58 | 'console_scripts': [ 59 | 'retinanet-train=keras_retinanet.bin.train:main', 60 | 'retinanet-evaluate=keras_retinanet.bin.evaluate:main', 61 | 'retinanet-debug=keras_retinanet.bin.debug:main', 62 | 'retinanet-convert-model=keras_retinanet.bin.convert_model:main', 63 | ], 64 | }, 65 | ext_modules = extensions, 66 | setup_requires = ["cython>=0.28", "numpy>=1.14.0"] 67 | ) 68 | -------------------------------------------------------------------------------- /keras_retinanet/callbacks/coco.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import keras 18 | from ..utils.coco_eval import evaluate_coco 19 | 20 | 21 | class CocoEval(keras.callbacks.Callback): 22 | """ Performs COCO evaluation on each epoch. 23 | """ 24 | def __init__(self, generator, tensorboard=None, threshold=0.05): 25 | """ CocoEval callback intializer. 
import warnings


def label_color(label):
    """ Return a color from a set of predefined colors. Contains 80 colors in total.

    Args
        label: The label to get the color for.

    Returns
        A list of three values representing a RGB color.

        If no color is defined for a certain label (the label is negative or
        not smaller than the number of predefined colors), the color green is
        returned as a tuple and a warning is issued.
    """
    # The lower bound matters: a negative label would otherwise index from the
    # end of the list and silently return an unrelated color.
    if 0 <= label < len(colors):
        return colors[label]

    warnings.warn('Label {} has no color, returning default.'.format(label))
    return (0, 255, 0)


# Generated using:
#
#   colors = [list((matplotlib.colors.hsv_to_rgb([x, 1.0, 1.0]) * 255).astype(int)) for x in np.arange(0, 1, 1.0 / 80)]
#   shuffle(colors)
#   pprint(colors)
colors = [
    [31, 0, 255],
    [0, 159, 255],
    [255, 95, 0],
    [255, 19, 0],
    [255, 0, 0],
    [255, 38, 0],
    [0, 255, 25],
    [255, 0, 133],
    [255, 172, 0],
    [108, 0, 255],
    [0, 82, 255],
    [0, 255, 6],
    [255, 0, 152],
    [223, 0, 255],
    [12, 0, 255],
    [0, 255, 178],
    [108, 255, 0],
    [184, 0, 255],
    [255, 0, 76],
    [146, 255, 0],
    [51, 0, 255],
    [0, 197, 255],
    [255, 248, 0],
    [255, 0, 19],
    [255, 0, 38],
    [89, 255, 0],
    [127, 255, 0],
    [255, 153, 0],
    [0, 255, 255],
    [0, 255, 216],
    [0, 255, 121],
    [255, 0, 248],
    [70, 0, 255],
    [0, 255, 159],
    [0, 216, 255],
    [0, 6, 255],
    [0, 63, 255],
    [31, 255, 0],
    [255, 57, 0],
    [255, 0, 210],
    [0, 255, 102],
    [242, 255, 0],
    [255, 191, 0],
    [0, 255, 63],
    [255, 0, 95],
    [146, 0, 255],
    [184, 255, 0],
    [255, 114, 0],
    [0, 255, 235],
    [255, 229, 0],
    [0, 178, 255],
    [255, 0, 114],
    [255, 0, 57],
    [0, 140, 255],
    [0, 121, 255],
    [12, 255, 0],
    [255, 210, 0],
    [0, 255, 44],
    [165, 255, 0],
    [0, 25, 255],
    [0, 255, 140],
    [0, 101, 255],
    [0, 255, 82],
    [223, 255, 0],
    [242, 0, 255],
    [89, 0, 255],
    [165, 0, 255],
    [70, 255, 0],
    [255, 0, 172],
    [255, 76, 0],
    [203, 255, 0],
    [204, 0, 255],
    [255, 0, 229],
    [255, 133, 0],
    [127, 0, 255],
    [0, 235, 255],
    [0, 255, 197],
    [255, 0, 191],
    [0, 44, 255],
    [50, 255, 0],
]
def get_session():
    """ Construct a modified tf session.

    CUDA_VISIBLE_DEVICES is set to an empty string before the session is
    created; the caller uses this session to avoid using the GPUs.
    """
    session_config = tf.ConfigProto()
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    return tf.Session(config=session_config)


def parse_args(args):
    """ Parse the command line arguments of the model conversion script.

    Args
        args : List of argument strings (without the program name).

    Returns
        The argparse namespace with model_in, model_out, backbone, nms,
        class_specific_filter and config.
    """
    # (flags, keyword options) in registration order; the order of the two
    # positionals determines how they are matched on the command line.
    argument_table = (
        (('model_in',), dict(help='The model to convert.')),
        (('model_out',), dict(help='Path to save the converted model to.')),
        (('--backbone',), dict(help='The backbone of the model to convert.', default='resnet50')),
        (('--no-nms',), dict(help='Disables non maximum suppression.', dest='nms', action='store_false')),
        (('--no-class-specific-filter',), dict(help='Disables class specific filtering.', dest='class_specific_filter', action='store_false')),
        (('--config',), dict(help='Path to a configuration parameters .ini file.')),
    )

    parser = argparse.ArgumentParser(description='Script for converting a training model to an inference model.')
    for flags, options in argument_table:
        parser.add_argument(*flags, **options)

    return parser.parse_args(args)
def default_lr_scheduler(base_lr=1e-5, steps=np.array([15, 25])):
    """ Build a two-step learning rate schedule.

    Args
        base_lr : Learning rate the decayed values are derived from.
        steps   : Indexable with two epoch thresholds (steps[0], steps[1]).

    Returns
        A function (epoch, lr) -> lr. While the epoch is not past steps[0] the
        incoming lr is returned unchanged; past steps[0] and up to and
        including steps[1] it returns base_lr * 1e-1; past steps[1] it returns
        base_lr * 1e-2. The chosen rate is printed on every call.
    """

    def default_lr_scheduler_(epoch, lr):
        first_step, second_step = steps[0], steps[1]

        in_first_decay = (epoch > first_step) and (epoch <= second_step)
        if in_first_decay:
            new_lr = base_lr * 1e-1
        elif epoch > second_step:
            new_lr = base_lr * 1e-2
        else:
            # Before the first threshold the current rate is left as it is.
            new_lr = lr

        print('Learning rate: ', new_lr)
        return new_lr

    return default_lr_scheduler_
class RedirectModel(keras.callbacks.Callback):
    """Callback which wraps another callback, but executed on a different model.

    Every hook below is forwarded unchanged to the wrapped callback. The only
    extra step happens in `on_train_begin`, where the wrapped callback is
    pointed at the redirect model before its own `on_train_begin` runs.

    ```python
    model = keras.models.load_model('model.h5')
    model_checkpoint = ModelCheckpoint(filepath='snapshot.h5')
    parallel_model = multi_gpu_model(model, gpus=2)
    parallel_model.fit(X_train, Y_train, callbacks=[RedirectModel(model_checkpoint, model)])
    ```

    Args
        callback : callback to wrap.
        model    : model to use when executing callbacks.
    """

    def __init__(self, callback, model):
        super(RedirectModel, self).__init__()

        self.callback = callback
        self.redirect_model = model

    # Hooks are listed in the order they fire during training.

    def on_train_begin(self, logs=None):
        # overwrite the model with our custom model
        self.callback.set_model(self.redirect_model)

        self.callback.on_train_begin(logs=logs)

    def on_epoch_begin(self, epoch, logs=None):
        self.callback.on_epoch_begin(epoch, logs=logs)

    def on_batch_begin(self, batch, logs=None):
        self.callback.on_batch_begin(batch, logs=logs)

    def on_batch_end(self, batch, logs=None):
        self.callback.on_batch_end(batch, logs=logs)

    def on_epoch_end(self, epoch, logs=None):
        self.callback.on_epoch_end(epoch, logs=logs)

    def on_train_end(self, logs=None):
        self.callback.on_train_end(logs=logs)
def evaluate_coco(generator, model, threshold=0.05):
    """ Use the pycocotools to evaluate a COCO model on a dataset.

    Detections are written to '<set_name>_bbox_results.json' and the ids of the
    processed images to '<set_name>_processed_image_ids.json' in the current
    working directory before the COCO evaluation tool is run on them.

    Args
        generator : The generator for generating the evaluation data.
        model     : The model to evaluate.
        threshold : The score threshold to use.

    Returns
        The `stats` of the COCO evaluation, or None when no detection reached
        the score threshold (in which case nothing is written).
    """
    # start collecting results
    results = []
    image_ids = []
    for index in progressbar.progressbar(range(generator.size()), prefix='COCO evaluation: '):
        image = generator.load_image(index)
        image = generator.preprocess_image(image)
        image, scale = generator.resize_image(image)

        if keras.backend.image_data_format() == 'channels_first':
            image = image.transpose((2, 0, 1))

        # run network
        boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))

        # correct boxes for image scale
        boxes /= scale

        # change to (x, y, w, h) (MS COCO standard)
        boxes[:, :, 2] -= boxes[:, :, 0]
        boxes[:, :, 3] -= boxes[:, :, 1]

        # compute predicted labels and scores
        for box, score, label in zip(boxes[0], scores[0], labels[0]):
            # scores are sorted, so we can break
            if score < threshold:
                break

            # append detection for each positively labeled class
            image_result = {
                'image_id'    : generator.image_ids[index],
                'category_id' : generator.label_to_coco_label(label),
                'score'       : float(score),
                'bbox'        : box.tolist(),
            }

            # append detection to results
            results.append(image_result)

        # append image to list of processed images
        image_ids.append(generator.image_ids[index])

    if not len(results):
        return

    # write output; the files are closed (and therefore flushed) explicitly,
    # because the results file is read back by loadRes() right below
    results_path = '{}_bbox_results.json'.format(generator.set_name)
    image_ids_path = '{}_processed_image_ids.json'.format(generator.set_name)
    with open(results_path, 'w') as results_file:
        json.dump(results, results_file, indent=4)
    with open(image_ids_path, 'w') as image_ids_file:
        json.dump(image_ids, image_ids_file, indent=4)

    # load results in COCO evaluation tool
    coco_true = generator.coco
    coco_pred = coco_true.loadRes(results_path)

    # run COCO evaluation
    coco_eval = COCOeval(coco_true, coco_pred, 'bbox')
    coco_eval.params.imgIds = image_ids
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    return coco_eval.stats
""" 31 | # parser = argparse.ArgumentParser(description='Infer an image and visualize detections.') 32 | # parser.add_argument('--model', help='The path to model weights', dest='model', default=None) 33 | # parser.add_argument('--convert-model', help='Convert the model to inference model', action='store_true') 34 | # parser.add_argument('--image', help='The path to the image to infer', default=None) 35 | # 36 | # return paser.parse_args(args) 37 | 38 | def get_session(): 39 | config = tf.ConfigProto() 40 | config.gpu_options.allow_growth = True 41 | return tf.Session(config=config) 42 | 43 | # use this environment flag to change which GPU to use 44 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 45 | 46 | # set the modified tf session as backend in keras 47 | keras.backend.tensorflow_backend.set_session(get_session()) 48 | 49 | 50 | 51 | # adjust this to point to your downloaded/trained model 52 | # models can be downloaded here: https://github.com/fizyr/keras-retinanet/releases 53 | #if args.model is not None: 54 | # model_path = args.model 55 | #else: 56 | # print('Model to path is missing.ß') 57 | 58 | # load retinanet model 59 | model = models.load_model('../results/test_02/after_10_init_epochs/vgg16_csv_08.h5', backbone_name='vgg16') 60 | 61 | # if the model is not converted to an inference model, use the line below 62 | # see: https://github.com/fizyr/keras-retinanet#converting-a-training-model-to-inference-model 63 | model = models.convert_model(model) 64 | print(model.summary()) 65 | 66 | # load label to names mapping for visualization purposes 67 | labels_to_names = {0: 'apple'} 68 | 69 | 70 | # load image 71 | image_name = '20130320T012856.619229_62.png' 72 | image = read_image_bgr(image_name) 73 | 74 | # copy to draw on 75 | draw = image.copy() 76 | draw = cv2.cvtColor(draw, cv2.COLOR_BGR2RGB) 77 | 78 | # preprocess image for network 79 | image = preprocess_image(image) 80 | image, scale = resize_image(image) 81 | 82 | # process image 83 | start = time.time() 
boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))
print("processing time: ", time.time() - start)

# correct for image scale
boxes /= scale

# visualize detections
for box, score, label in zip(boxes[0], scores[0], labels[0]):
    # scores are sorted so we can break
    if score < 0.5:
        break

    color = label_color(label)

    b = box.astype(int)
    draw_box(draw, b, color=color)

    caption = "{} {:.3f}".format(labels_to_names[label], score)
    draw_caption(draw, b, caption)

plt.figure(figsize=(15, 15))
plt.axis('off')
plt.imshow(draw)

# Save before show(): once the window shown by show() is closed, the current
# figure may no longer hold the drawing and the saved file can come out empty.
# bbox_inches belongs to savefig(); it used to be passed to str.format(), where
# it was silently ignored.
plt.savefig('{}.pdf'.format(image_name), bbox_inches='tight')
plt.show()
def map_fn(*args, **kwargs):
    """ See https://www.tensorflow.org/versions/master/api_docs/python/tf/map_fn .
    """
    return tensorflow.map_fn(*args, **kwargs)


def pad(*args, **kwargs):
    """ See https://www.tensorflow.org/versions/master/api_docs/python/tf/pad .
    """
    return tensorflow.pad(*args, **kwargs)


def top_k(*args, **kwargs):
    """ See https://www.tensorflow.org/versions/master/api_docs/python/tf/nn/top_k .
    """
    return tensorflow.nn.top_k(*args, **kwargs)


def clip_by_value(*args, **kwargs):
    """ See https://www.tensorflow.org/versions/master/api_docs/python/tf/clip_by_value .
    """
    return tensorflow.clip_by_value(*args, **kwargs)


def resize_images(images, size, method='bilinear', align_corners=False):
    """ See https://www.tensorflow.org/versions/master/api_docs/python/tf/image/resize_images .

    Args
        images        : Passed through to tensorflow.image.resize_images.
        size          : Passed through to tensorflow.image.resize_images.
        method        : The method used for interpolation. One of ('bilinear', 'nearest', 'bicubic', 'area').
        align_corners : Passed through to tensorflow.image.resize_images.

    Raises
        KeyError if method is not one of the names listed above.
    """
    resize_method = tensorflow.image.ResizeMethod

    # translate the string name into the tensorflow enum value
    interpolation_by_name = dict(
        bilinear=resize_method.BILINEAR,
        nearest=resize_method.NEAREST_NEIGHBOR,
        bicubic=resize_method.BICUBIC,
        area=resize_method.AREA,
    )
    interpolation = interpolation_by_name[method]

    return tensorflow.image.resize_images(images, size, interpolation, align_corners)


def non_max_suppression(*args, **kwargs):
    """ See https://www.tensorflow.org/versions/master/api_docs/python/tf/image/non_max_suppression .
    """
    return tensorflow.image.non_max_suppression(*args, **kwargs)
def bbox_transform_inv(boxes, deltas, mean=None, std=None):
    """ Applies deltas (usually regression results) to boxes (usually anchors).

    The deltas are first de-normalized with std and mean (the normalization
    applied in the generator), then scaled by the box width (x coordinates) or
    height (y coordinates) and added to the corresponding box coordinate.

    Args
        boxes : Array of shape (B, N, 4), where B is the batch size, N the number of boxes and 4 values for (x1, y1, x2, y2).
        deltas: Array of same shape as boxes. These deltas (d_x1, d_y1, d_x2, d_y2) are a factor of the width/height.
        mean  : The mean value used when computing deltas (defaults to [0, 0, 0, 0]).
        std   : The standard deviation used when computing deltas (defaults to [0.2, 0.2, 0.2, 0.2]).

    Returns
        The four shifted coordinates stacked along axis 2, i.e. the same layout as boxes but with deltas applied to each box.
    """
    if mean is None:
        mean = [0, 0, 0, 0]
    if std is None:
        std = [0.2, 0.2, 0.2, 0.2]

    box_width  = boxes[:, :, 2] - boxes[:, :, 0]
    box_height = boxes[:, :, 3] - boxes[:, :, 1]

    # x1 and x2 scale with the width, y1 and y2 with the height
    extents = (box_width, box_height, box_width, box_height)

    shifted = []
    for coord, extent in enumerate(extents):
        offset = (deltas[:, :, coord] * std[coord] + mean[coord]) * extent
        shifted.append(boxes[:, :, coord] + offset)

    return keras.backend.stack(shifted, axis=2)
def is_in_set(key, ID, sets_path):
    """
    Return whether image ID is listed in the set file path/to/data/sets/<key>.txt.

    The file is read as comma-separated values and only the first field of
    each row is compared against ID.

    Args
        key       : Name of the set, i.e. the file name without the '.txt' extension.
        ID        : Image identifier to look for.
        sets_path : Directory that contains the set files.

    Returns
        True if any row starts with ID, False otherwise.
    """
    with open(os.path.join(sets_path, key + '.txt')) as f:
        reader = csv.reader(f, delimiter=',')
        # csv.reader yields an empty list for a blank line (e.g. a trailing
        # newline at the end of the file); such rows are skipped instead of
        # failing on row[0].
        return any(bool(row) and row[0] == ID for row in reader)
'_annotations.csv'), mode='w') as f: 74 | file = csv.writer(f, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL) 75 | for row in data[key]: 76 | file.writerow(row) 77 | 78 | #all__annotations = os.path.join(rect_annotations_path, 'all_annotations.csv') 79 | #train_annotations = os.path.join(rect_annotations_path, 'train_annotations.csv') 80 | new_classes = os.path.join(rect_annotations_path, 'classes.csv') 81 | 82 | with open(new_classes, mode='w') as f: 83 | file = csv.writer(f, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL) 84 | file.writerow(['apple',0]) 85 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Improving Apple Detection and Counting Using RetinaNet 2 | 3 | This work aims to investigate the apple detection problem through the deployment of the RetinaNet object detection framework in conjunction with the VGG architecture. Following hyper-parameters’ optimisation, the performance scaling with the backbone’s network depth is examined through four different proposed deployments for the side-network. Analysis of the relationship between performance and training size establishes that 10 samples are enough to achieve adequate performance, while 200 samples are enough to achieve state-of-the-art performance. Moreover, a novel lightweight model is proposed that achieves an F1-score of 0.908 and inference time of nearly 70FPS. These results outperform previous state-of-the-art models in both performance and detection rates. Finally, the results are discussed regarding the model’s limitations, and insights for future work are provided. 4 | 5 | # Dataset 6 | The dataset used for this project is the [ACFR dataset](http://data.acfr.usyd.edu.au/ag/treecrops/2016-multifruit/) and can be downloaded [here](https://github.com/nikostsagk/Apple-detection/releases/download/dataset/Archive.zip). 
It consists of images of three different fruits (apples, mangoes & almonds), but only the apple set was used. The original train/val/test split was preserved in order to allow comparison with previous studies. 7 | 8 | The apple set contains 1120 images of 308x202 pixels. The original annotations are circular: `examples/convert_annotations.py` reads a centre (x, y) and a radius for each apple and converts them to rectangular boxes in `path, x0, y0, x1, y1, class` format. More info can be found in the `readme.txt` file in the dataset folder. 9 | 10 |
11 |
12 | example 1 13 |
14 |
15 | example 2 16 |
17 |
18 | example 3 19 |
20 |
21 | example 4 22 |
23 |
24 | 25 | # Architectures 26 | 27 | The repository contains four side-network architectures, each implemented on its own branch. 28 | 29 | * `master` : The original side-network architecture. 30 | 31 |

32 | 33 |

34 | 35 | * `retinanet_p3p4p5` : The original side-network architecture without the strided convolutional filters right after the VGG network. 36 | 37 |

38 | 39 |

40 | 41 | * `retinanet_ci_multiclassifiers` : The `retinanet_p3p4p5` implementation with separate classification and regression heads for the predictions. 42 | 43 |

44 | 45 |

46 | 47 | * `retinanet_ci` : A lightweight implementation where common classification and regression heads make predictions right after the Ci reduced blocks, without the upsampling-merging technique. 48 | 49 |

50 | 51 |

52 | 53 | # Installation 54 | Clone the repo and follow the instructions in: [fizyr/keras-retinanet](https://github.com/fizyr/keras-retinanet) 55 | 56 | # Sources 57 | 1) [fizyr/keras-retinanet](https://github.com/fizyr/keras-retinanet) 58 | 2) [Martin Zlocha](https://github.com/martinzlocha/anchor-optimization) 59 | 3) [ACFR FRUIT DATASET](http://data.acfr.usyd.edu.au/ag/treecrops/2016-multifruit/) 60 | -------------------------------------------------------------------------------- /tests/backend/test_common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
def test_shift():
    """ Check that shift() places every base anchor at every cell of a 2x3 map with stride 8.
    """
    feature_shape = (2, 3)
    step = 8

    base_anchors = [
        [-8, -8, 8, 8],
        [-16, -16, 16, 16],
        [-12, -12, 12, 12],
        [-12, -16, 12, 16],
        [-16, -12, 16, 12],
    ]
    anchors = np.array(base_anchors, dtype=keras.backend.floatx())

    # Expected rows are grouped per cell in row-major order, (0, 0), (0, 1), (0, 2),
    # (1, 0), ... with x offsets 4, 12, 20 and y offsets 4, 12, and follow the
    # base anchor order inside each cell.
    expected = [
        [cx + x1, cy + y1, cx + x2, cy + y2]
        for cy in (4, 12)
        for cx in (4, 12, 20)
        for x1, y1, x2, y2 in base_anchors
    ]

    shifted = keras_retinanet.backend.shift(feature_shape, step, anchors)
    shifted = keras.backend.eval(shifted)

    np.testing.assert_array_equal(shifted, expected)
15 | """ 16 | 17 | import keras 18 | from keras.applications import densenet 19 | from keras.utils import get_file 20 | 21 | from . import retinanet 22 | from . import Backbone 23 | from ..utils.image import preprocess_image 24 | 25 | 26 | allowed_backbones = { 27 | 'densenet121': ([6, 12, 24, 16], densenet.DenseNet121), 28 | 'densenet169': ([6, 12, 32, 32], densenet.DenseNet169), 29 | 'densenet201': ([6, 12, 48, 32], densenet.DenseNet201), 30 | } 31 | 32 | 33 | class DenseNetBackbone(Backbone): 34 | """ Describes backbone information and provides utility functions. 35 | """ 36 | 37 | def retinanet(self, *args, **kwargs): 38 | """ Returns a retinanet model using the correct backbone. 39 | """ 40 | return densenet_retinanet(*args, backbone=self.backbone, **kwargs) 41 | 42 | def download_imagenet(self): 43 | """ Download pre-trained weights for the specified backbone name. 44 | This name is in the format {backbone}_weights_tf_dim_ordering_tf_kernels_notop 45 | where backbone is the densenet + number of layers (e.g. densenet121). 46 | For more info check the explanation from the keras densenet script itself: 47 | https://github.com/keras-team/keras/blob/master/keras/applications/densenet.py 48 | """ 49 | origin = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/' 50 | file_name = '{}_weights_tf_dim_ordering_tf_kernels_notop.h5' 51 | 52 | # load weights 53 | if keras.backend.image_data_format() == 'channels_first': 54 | raise ValueError('Weights for "channels_first" format are not available.') 55 | 56 | weights_url = origin + file_name.format(self.backbone) 57 | return get_file(file_name.format(self.backbone), weights_url, cache_subdir='models') 58 | 59 | def validate(self): 60 | """ Checks whether the backbone string is correct. 
def densenet_retinanet(num_classes, backbone='densenet121', inputs=None, modifier=None, **kwargs):
    """ Build a RetinaNet model on top of a DenseNet feature extractor.

    Args
        num_classes: Number of classes to predict.
        backbone: Key into allowed_backbones (one of ('densenet121', 'densenet169', 'densenet201')).
        inputs: The inputs to the network; a keras.layers.Input of shape (None, None, 3) is created when omitted.
        modifier: Optional callable that receives the backbone model and returns the model to use (this can be used to freeze backbone layers for example).
        **kwargs: Forwarded to retinanet.retinanet.

    Returns
        RetinaNet model with a DenseNet backbone.
    """
    if inputs is None:
        # no input given: fall back to a 3-channel input with unspecified spatial size
        inputs = keras.layers.Input((None, None, 3))

    block_sizes, build_densenet = allowed_backbones[backbone]
    densenet_model = build_densenet(input_tensor=inputs, include_top=False, pooling=None, weights=None)

    # collect the concat output that closes each dense block (conv2 .. conv5)
    block_outputs = []
    for stage, last_block in enumerate(block_sizes, 2):
        layer_name = 'conv{}_block{}_concat'.format(stage, last_block)
        block_outputs.append(densenet_model.get_layer(name=layer_name).output)

    # the output of the first dense block is dropped; the rest forms the backbone
    feature_extractor = keras.models.Model(inputs=inputs, outputs=block_outputs[1:], name=densenet_model.name)

    # let the caller adjust the backbone before the heads are attached
    if modifier:
        feature_extractor = modifier(feature_extractor)

    return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=feature_extractor.outputs, **kwargs)
class MobileNetBackbone(Backbone):
    """ Describes backbone information and provides utility functions.
    """

    # Accepted prefixes of the backbone name (the part before the first '_').
    allowed_backbones = ['mobilenet128', 'mobilenet160', 'mobilenet192', 'mobilenet224']

    def retinanet(self, *args, **kwargs):
        """ Returns a retinanet model using the correct backbone.
        """
        return mobilenet_retinanet(*args, backbone=self.backbone, **kwargs)

    def download_imagenet(self):
        """ Download pre-trained weights for the specified backbone name.
        This name is in the format mobilenet{rows}_{alpha} where rows is the
        imagenet shape dimension and 'alpha' controls the width of the network.
        For more info check the explanation from the keras mobilenet script itself.

        Returns
            Path to the downloaded weights file, as returned by get_file.

        Raises
            ValueError: if the keras image data format is 'channels_first'.
        """

        alpha = float(self.backbone.split('_')[1])
        rows = int(self.backbone.split('_')[0].replace('mobilenet', ''))

        # load weights
        if keras.backend.image_data_format() == 'channels_first':
            # The message must name the format that was actually detected.
            raise ValueError('Weights for "channels_first" format '
                             'are not available.')
        if alpha == 1.0:
            alpha_text = '1_0'
        elif alpha == 0.75:
            alpha_text = '7_5'
        elif alpha == 0.50:
            alpha_text = '5_0'
        else:
            # NOTE(review): every other alpha value, not only 0.25, ends up here.
            alpha_text = '2_5'

        model_name = 'mobilenet_{}_{}_tf_no_top.h5'.format(alpha_text, rows)
        weights_url = mobilenet.mobilenet.BASE_WEIGHT_PATH + model_name
        weights_path = get_file(model_name, weights_url, cache_subdir='models')

        return weights_path

    def validate(self):
        """ Checks whether the backbone string is correct.

        Only the part of the name before the first '_' is checked.

        Raises
            ValueError: if that part is not in allowed_backbones.
        """
        backbone = self.backbone.split('_')[0]

        if backbone not in MobileNetBackbone.allowed_backbones:
            raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, MobileNetBackbone.allowed_backbones))

    def preprocess_image(self, inputs):
        """ Takes as input an image and prepares it for being passed through the network.
        """
        return preprocess_image(inputs, mode='tf')
def draw_caption(image, box, caption, fontScale=1):
    """ Draws a caption above the box in an image.

    # Arguments
        image     : The image to draw on.
        box       : A list of 4 elements (x1, y1, x2, y2).
        caption   : String containing the text to draw.
        fontScale : Font scale passed to cv2.putText (defaults to 1).
    """
    b = np.array(box).astype(int)
    # The text is drawn twice at the same position, 3 pixels above the box's y1:
    # first in black with thickness 2, then in white with thickness 1 on top,
    # which leaves a dark outline around the white text.
    cv2.putText(image, caption, (b[0], b[1] - 3), cv2.FONT_HERSHEY_DUPLEX, fontScale, (0, 0, 0), 2)
    cv2.putText(image, caption, (b[0], b[1] - 3), cv2.FONT_HERSHEY_DUPLEX, fontScale, (255, 255, 255), 1)
def draw_detections(image, boxes, scores, labels, color=None, label_to_name=None, score_threshold=0.5):
    """ Draws detections in an image.

    # Arguments
        image           : The image to draw on.
        boxes           : A [N, 4] matrix (x1, y1, x2, y2).
        scores          : A list of N classification scores.
        labels          : A list of N labels.
        color           : The color of the boxes. By default the color from keras_retinanet.utils.colors.label_color will be used.
        label_to_name   : (optional) Functor for mapping a label to a name.
        score_threshold : Threshold used for determining what detections to draw.
    """
    # Convert first so that a plain Python list of scores can be compared too.
    selection = np.where(np.asarray(scores) > score_threshold)[0]

    # The caption size depends only on the image height, so compute it once.
    fontScale = 0.7 * image.shape[0] / 800

    for i in selection:
        c = color if color is not None else label_color(labels[i])
        draw_box(image, boxes[i, :], color=c)

        # draw labels; formatting (rather than '+') also works when no
        # label_to_name is given and the raw label is not a string
        name = label_to_name(labels[i]) if label_to_name else labels[i]
        caption = '{}: {:.2f}'.format(name, scores[i])
        draw_caption(image, boxes[i, :], caption, fontScale)
94 | """ 95 | if isinstance(annotations, np.ndarray): 96 | annotations = {'bboxes': annotations[:, :4], 'labels': annotations[:, 4]} 97 | 98 | assert('bboxes' in annotations) 99 | assert('labels' in annotations) 100 | assert(annotations['bboxes'].shape[0] == annotations['labels'].shape[0]) 101 | 102 | for i in range(annotations['bboxes'].shape[0]): 103 | label = annotations['labels'][i] 104 | c = color if color is not None else label_color(label) 105 | caption = '{}'.format(label_to_name(label) if label_to_name else label) 106 | fontScale = 0.7 * image.shape[0] / 800 107 | #draw_caption(image, annotations['bboxes'][i], caption, fontScale) 108 | draw_box(image, annotations['bboxes'][i], color=c) 109 | -------------------------------------------------------------------------------- /keras_retinanet/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | 4 | 5 | class Backbone(object): 6 | """ This class stores additional information on backbones. 7 | """ 8 | def __init__(self, backbone, *args): 9 | # a dictionary mapping custom layer names to the correct classes 10 | from .. import layers 11 | from .. import losses 12 | from .. import initializers 13 | self.custom_objects = { 14 | 'UpsampleLike' : layers.UpsampleLike, 15 | 'PriorProbability' : initializers.PriorProbability, 16 | 'RegressBoxes' : layers.RegressBoxes, 17 | 'FilterDetections' : layers.FilterDetections, 18 | 'Anchors' : layers.Anchors, 19 | 'ClipBoxes' : layers.ClipBoxes, 20 | '_smooth_l1' : losses.smooth_l1(), 21 | '_focal' : losses.focal(), 22 | } 23 | 24 | self.backbone = backbone 25 | self.validate() 26 | 27 | def retinanet(self, *args, **kwargs): 28 | """ Returns a retinanet model using the correct backbone. 29 | """ 30 | raise NotImplementedError('retinanet method not implemented.') 31 | 32 | def download_imagenet(self): 33 | """ Downloads ImageNet weights and returns path to weights file. 
def backbone(backbone_name):
    """ Returns a backbone object for the given backbone.

    The backbone class is chosen by substring match on backbone_name
    ('resnet', 'mobilenet', 'vgg', 'densenet', checked in that order) and is
    imported only when it is needed.

    Args
        backbone_name : Name of the backbone; passed on to the backbone class.

    Returns
        An instance of the matching backbone class.

    Raises
        NotImplementedError: if backbone_name matches none of the known backbones.
    """
    if 'resnet' in backbone_name:
        from .resnet import ResNetBackbone as b
    elif 'mobilenet' in backbone_name:
        from .mobilenet import MobileNetBackbone as b
    elif 'vgg' in backbone_name:
        from .vgg import VGGBackbone as b
    elif 'densenet' in backbone_name:
        from .densenet import DenseNetBackbone as b
    else:
        # Report the requested name (not this function object) in the message.
        raise NotImplementedError('Backbone class for \'{}\' not implemented.'.format(backbone_name))

    return b(backbone_name)
def convert_model(model, nms=True, class_specific_filter=True, nms_threshold=0.5, score_threshold=0.05, max_detections=300, anchor_params=None):
    """ Converts a training model to an inference model.

    All arguments are forwarded unchanged, by keyword, to retinanet_bbox.

    Args
        model                 : A retinanet training model.
        nms                   : Boolean, whether to add NMS filtering to the converted model.
        class_specific_filter : Whether to use class specific filtering or filter for the best scoring class only.
        nms_threshold         : Forwarded to retinanet_bbox (defaults to 0.5); presumably the overlap threshold used by NMS - TODO confirm.
        score_threshold       : Forwarded to retinanet_bbox (defaults to 0.05); presumably the minimum score a detection needs to be kept - TODO confirm.
        max_detections        : Forwarded to retinanet_bbox (defaults to 300); presumably the maximum number of detections kept - TODO confirm.
        anchor_params         : Anchor parameters object. If omitted, default values are used.

    Returns
        A keras.models.Model object.

    Raises
        ImportError: if h5py is not available.
        ValueError: In case of an invalid savefile.
        NOTE(review): these two entries look copied from load_model; nothing is
        loaded from a file in this function - confirm and drop if not applicable.
    """
    # imported here rather than at module level
    from .retinanet import retinanet_bbox
    return retinanet_bbox(
        model                 = model,
        nms                   = nms,
        class_specific_filter = class_specific_filter,
        nms_threshold         = nms_threshold,
        score_threshold       = score_threshold,
        max_detections        = max_detections,
        anchor_params         = anchor_params
    )
class ResNetBackbone(Backbone):
    """ Describes backbone information and provides utility functions.
    """

    def __init__(self, backbone):
        super(ResNetBackbone, self).__init__(backbone)
        # register the keras_resnet custom layers next to the retinanet ones
        self.custom_objects.update(keras_resnet.custom_objects)

    def retinanet(self, *args, **kwargs):
        """ Returns a retinanet model using the correct backbone.
        """
        return resnet_retinanet(*args, backbone=self.backbone, **kwargs)

    def download_imagenet(self):
        """ Downloads ImageNet weights and returns path to weights file.

        Raises
            ValueError: if the depth parsed from the backbone name is not 50, 101 or 152.
        """
        resnet_filename = 'ResNet-{}-model.keras.h5'
        resnet_resource = 'https://github.com/fizyr/keras-models/releases/download/v0.0.1/{}'.format(resnet_filename)
        depth = int(self.backbone.replace('resnet', ''))

        # md5 checksums of the released weight files, keyed by network depth
        checksums = {
            50: '3e9f4e4f77bbe2c9bec13b53ee1c2319',
            101: '05dc86924389e5b401a9ea0348a3213c',
            152: '6ee11ef2b135592f8031058820bb9e71',
        }
        if depth not in checksums:
            # Fail with a clear message instead of an UnboundLocalError on the checksum.
            raise ValueError('No ImageNet weights available for ResNet depth {}.'.format(depth))

        return get_file(
            resnet_filename.format(depth),
            resnet_resource.format(depth),
            cache_subdir='models',
            md5_hash=checksums[depth]
        )

    def validate(self):
        """ Checks whether the backbone string is correct.

        Only the part of the name before the first '_' is checked.

        Raises
            ValueError: if that part is not one of the allowed backbones.
        """
        allowed_backbones = ['resnet50', 'resnet101', 'resnet152']
        backbone = self.backbone.split('_')[0]

        if backbone not in allowed_backbones:
            raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, allowed_backbones))

    def preprocess_image(self, inputs):
        """ Takes as input an image and prepares it for being passed through the network.
        """
        return preprocess_image(inputs, mode='caffe')
89 | """ 90 | # choose default input 91 | if inputs is None: 92 | if keras.backend.image_data_format() == 'channels_first': 93 | inputs = keras.layers.Input(shape=(3, None, None)) 94 | else: 95 | inputs = keras.layers.Input(shape=(None, None, 3)) 96 | 97 | # create the resnet backbone 98 | if backbone == 'resnet50': 99 | resnet = keras_resnet.models.ResNet50(inputs, include_top=False, freeze_bn=True) 100 | elif backbone == 'resnet101': 101 | resnet = keras_resnet.models.ResNet101(inputs, include_top=False, freeze_bn=True) 102 | elif backbone == 'resnet152': 103 | resnet = keras_resnet.models.ResNet152(inputs, include_top=False, freeze_bn=True) 104 | else: 105 | raise ValueError('Backbone (\'{}\') is invalid.'.format(backbone)) 106 | 107 | # invoke modifier if given 108 | if modifier: 109 | resnet = modifier(resnet) 110 | 111 | # create the full model 112 | return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=resnet.outputs[1:], **kwargs) 113 | 114 | 115 | def resnet50_retinanet(num_classes, inputs=None, **kwargs): 116 | return resnet_retinanet(num_classes=num_classes, backbone='resnet50', inputs=inputs, **kwargs) 117 | 118 | 119 | def resnet101_retinanet(num_classes, inputs=None, **kwargs): 120 | return resnet_retinanet(num_classes=num_classes, backbone='resnet101', inputs=inputs, **kwargs) 121 | 122 | 123 | def resnet152_retinanet(num_classes, inputs=None, **kwargs): 124 | return resnet_retinanet(num_classes=num_classes, backbone='resnet152', inputs=inputs, **kwargs) 125 | -------------------------------------------------------------------------------- /keras_retinanet/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import keras 18 | import tensorflow as tf 19 | from . import backend 20 | 21 | 22 | def focal(alpha=0.25, gamma=2.0): 23 | """ Create a functor for computing the focal loss. 24 | 25 | Args 26 | alpha: Scale the focal weight with alpha. 27 | gamma: Take the power of the focal weight with gamma. 28 | 29 | Returns 30 | A functor that computes the focal loss using the alpha and gamma. 31 | """ 32 | 33 | average_focal = tf.Variable(1.0) 34 | 35 | def _focal(y_true, y_pred): 36 | """ Compute the focal loss given the target tensor and the predicted tensor. 37 | 38 | As defined in https://arxiv.org/abs/1708.02002 39 | 40 | Args 41 | y_true: Tensor of target data from the generator with shape (B, N, num_classes). 42 | y_pred: Tensor of predicted data from the network with shape (B, N, num_classes). 43 | 44 | Returns 45 | The focal loss of y_pred w.r.t. y_true. 
def smooth_l1(sigma=3.0):
    """ Create a smooth L1 loss functor.

    Args
        sigma: This argument defines the point where the loss changes from L2 to L1.

    Returns
        A functor for computing the smooth L1 loss given target data and predicted data.
    """
    sigma_squared = sigma ** 2
    # running average of the normalizer; used instead of the anchor count
    # for batches that contain no positive anchors
    average_smooth = tf.Variable(1.0)

    def _smooth_l1(y_true, y_pred):
        """ Compute the smooth L1 loss of y_pred w.r.t. y_true.

        Args
            y_true: Tensor from the generator of shape (B, N, 5). The last value for each box is the state of the anchor (ignore, negative, positive).
            y_pred: Tensor from the network of shape (B, N, 4).

        Returns
            The smooth L1 loss of y_pred w.r.t. y_true.
        """
        # separate target and state
        regression = y_pred
        regression_target = y_true[:, :, :-1]
        anchor_state = y_true[:, :, -1]

        # keep only positive anchors (state 1); all other anchors do not contribute to the regression loss
        indices = backend.where(keras.backend.equal(anchor_state, 1))
        regression = backend.gather_nd(regression, indices)
        regression_target = backend.gather_nd(regression_target, indices)

        # compute smooth L1 loss
        # f(x) = 0.5 * (sigma * x)^2          if |x| < 1 / sigma / sigma
        #        |x| - 0.5 / sigma / sigma    otherwise
        regression_diff = regression - regression_target
        regression_diff = keras.backend.abs(regression_diff)
        regression_loss = backend.where(
            keras.backend.less(regression_diff, 1.0 / sigma_squared),
            0.5 * sigma_squared * keras.backend.pow(regression_diff, 2),
            regression_diff - 0.5 / sigma_squared
        )

        # compute the normalizer: the number of positive anchors.
        # The count is deliberately NOT clamped to 1 here: clamping made the fallback below
        # unreachable. With no positive anchors the running average is used instead, exactly
        # as the focal loss does, and the average itself is left unchanged by the update.
        normalizer = keras.backend.cast(keras.backend.shape(indices)[0], dtype=keras.backend.floatx())
        normalizer = backend.where(keras.backend.less(normalizer, 1.0), average_smooth, normalizer)

        assign_op_smooth = average_smooth.assign(0.99 * average_smooth.value() + 0.01 * normalizer)

        # force the running-average update to run whenever the loss is evaluated
        with tf.control_dependencies([assign_op_smooth]):
            return keras.backend.sum(regression_loss) / normalizer

    return _smooth_l1
class Evaluate(keras.callbacks.Callback):
    """ Evaluation callback for arbitrary datasets.
    """

    def __init__(
        self,
        generator,
        iou_threshold=0.5,
        score_threshold=0.05,
        max_detections=100,
        save_path=None,
        tensorboard=None,
        weighted_average=False,
        verbose=1
    ):
        """ Evaluate a given dataset using a given model at the end of every epoch during training.

        # Arguments
            generator        : The generator that represents the dataset to evaluate.
            iou_threshold    : The threshold used to consider when a detection is positive or negative.
            score_threshold  : The score confidence threshold to use for detections.
            max_detections   : The maximum number of detections to use per image.
            save_path        : The path to save images with visualized detections to.
            tensorboard      : Instance of keras.callbacks.TensorBoard used to log the mAP value.
            weighted_average : Compute the mAP using the weighted average of precisions among classes.
            verbose          : Set the verbosity level, by default this is set to 1.
                               1 prints a per-class detection summary and the mean metrics,
                               2 prints the average precision of every class.
        """
        self.generator = generator
        self.iou_threshold = iou_threshold
        self.score_threshold = score_threshold
        self.max_detections = max_detections
        self.save_path = save_path
        self.tensorboard = tensorboard
        self.weighted_average = weighted_average
        self.verbose = verbose

        super(Evaluate, self).__init__()

    @staticmethod
    def _ratio(numerator, denominator):
        """ Divide numerator by denominator, yielding 0.0 when the denominator is zero.
        """
        return numerator / denominator if denominator else 0.0

    def on_epoch_end(self, epoch, logs=None):
        """ Run the evaluation and publish mAP, mF1 and mIoU.

        The results are stored on the callback (mean_ap, mean_f1, mean_iou), written into
        logs and, when a TensorBoard writer is available, the mAP is logged as a summary.

        # Arguments
            epoch : Index of the epoch that just finished.
            logs  : Dictionary of metrics for the epoch; updated in place.
        """
        # only replace a missing dict: an empty dict handed in by the caller must still receive the metrics
        if logs is None:
            logs = {}

        # run evaluation
        average_precisions, pr_curves = evaluate(
            self.generator,
            self.model,
            iou_threshold=self.iou_threshold,
            score_threshold=self.score_threshold,
            max_detections=self.max_detections,
            save_path=self.save_path
        )

        # compute per class average precision and F1-score
        instances_per_label = {}
        total_instances = []
        precisions = []
        f1_scores = []
        mean_ious = []

        for label, (average_precision, num_annotations) in average_precisions.items():
            if self.verbose == 2:
                print('{:.0f} instances of class'.format(num_annotations),
                      self.generator.label_to_name(label), 'with average precision: {:.4f}'.format(average_precision))
            instances_per_label[label] = num_annotations
            total_instances.append(num_annotations)
            precisions.append(average_precision)

            # compute both values before appending so the lists always stay aligned,
            # even when only one of the lookups fails
            try:
                best_f1 = np.max(pr_curves[label]['f1_score'])
                class_iou = np.mean(pr_curves[label]['average_iou'])
            except (KeyError, IndexError, TypeError, ValueError):
                # no usable curve for this class (missing entry or empty curve)
                best_f1, class_iou = 0, 0
            f1_scores.append(best_f1)
            mean_ious.append(class_iou)

        annotation_count = sum(total_instances)
        annotated_classes = sum(x > 0 for x in total_instances)

        if self.weighted_average:
            self.mean_ap = self._ratio(sum(a * b for a, b in zip(total_instances, precisions)), annotation_count)
            self.mean_f1 = self._ratio(sum(a * b for a, b in zip(total_instances, f1_scores)), annotation_count)
        else:
            self.mean_ap = self._ratio(sum(precisions), annotated_classes)
            self.mean_f1 = self._ratio(sum(f1_scores), annotated_classes)
        # the mean IoU is averaged over annotated classes in both modes
        self.mean_iou = self._ratio(sum(mean_ious), annotated_classes)

        if self.tensorboard is not None and self.tensorboard.writer is not None:
            import tensorflow as tf
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = self.mean_ap
            summary_value.tag = "mAP"
            self.tensorboard.writer.add_summary(summary, epoch)

        logs['mAP'] = self.mean_ap
        logs['mF1'] = self.mean_f1
        logs['mIoU'] = self.mean_iou

        if self.verbose == 1:
            for label in range(self.generator.num_classes()):
                true_positives = int(pr_curves[label]['TP'][-1]) if len(pr_curves[label]['TP']) > 0 else 0
                false_positives = int(pr_curves[label]['FP'][-1]) if len(pr_curves[label]['FP']) > 0 else 0
                print('\nClass {}: Instances: {} | Predictions: {} | False positives: {} | True positives: {}'.format(
                    self.generator.label_to_name(label),
                    # looked up by label rather than by list position
                    int(instances_per_label.get(label, 0)),
                    true_positives + false_positives,
                    false_positives, true_positives))

            print('mAP: {:.4f}'.format(self.mean_ap), 'mF1-score: {:.4f}'.format(self.mean_f1), 'mIoU: {:.4f}'.format(self.mean_iou))
class CocoGenerator(Generator):
    """ Generate data from the COCO dataset.

    See https://github.com/cocodataset/cocoapi/tree/master/PythonAPI for more information.
    """

    def __init__(self, data_dir, set_name, **kwargs):
        """ Initialize a COCO data generator.

        Args
            data_dir: Path to where the COCO dataset is stored.
            set_name: Name of the set to parse.
        """
        self.data_dir = data_dir
        self.set_name = set_name
        self.coco = COCO(os.path.join(data_dir, 'annotations', 'instances_' + set_name + '.json'))
        self.image_ids = self.coco.getImgIds()

        self.load_classes()

        super(CocoGenerator, self).__init__(**kwargs)

    def load_classes(self):
        """ Loads the class to label mapping (and inverse) for COCO.
        """
        # load class names (name -> label), ordered by COCO category id
        categories = self.coco.loadCats(self.coco.getCatIds())
        categories.sort(key=lambda x: x['id'])

        self.classes = {}
        self.coco_labels = {}
        self.coco_labels_inverse = {}
        for c in categories:
            # the network label is the number of classes registered so far
            self.coco_labels[len(self.classes)] = c['id']
            self.coco_labels_inverse[c['id']] = len(self.classes)
            self.classes[c['name']] = len(self.classes)

        # also load the reverse (label -> name)
        self.labels = {label: name for name, label in self.classes.items()}

    def size(self):
        """ Size of the COCO dataset.
        """
        return len(self.image_ids)

    def num_classes(self):
        """ Number of classes in the dataset. For COCO this is 80.
        """
        return len(self.classes)

    def has_label(self, label):
        """ Return True if label is a known label.
        """
        return label in self.labels

    def has_name(self, name):
        """ Returns True if name is a known class.
        """
        return name in self.classes

    def name_to_label(self, name):
        """ Map name to label.
        """
        return self.classes[name]

    def label_to_name(self, label):
        """ Map label to name.
        """
        return self.labels[label]

    def coco_label_to_label(self, coco_label):
        """ Map COCO label to the label as used in the network.
        COCO has some gaps in the order of labels. The highest label is 90, but there are 80 classes.
        """
        return self.coco_labels_inverse[coco_label]

    def coco_label_to_name(self, coco_label):
        """ Map COCO label to name.
        """
        return self.label_to_name(self.coco_label_to_label(coco_label))

    def label_to_coco_label(self, label):
        """ Map label as used by the network to labels as used by COCO.
        """
        return self.coco_labels[label]

    def image_aspect_ratio(self, image_index):
        """ Compute the aspect ratio for an image with image_index.
        """
        image = self.coco.loadImgs(self.image_ids[image_index])[0]
        return float(image['width']) / float(image['height'])

    def load_image(self, image_index):
        """ Load an image at the image_index.
        """
        image_info = self.coco.loadImgs(self.image_ids[image_index])[0]
        path = os.path.join(self.data_dir, 'images', self.set_name, image_info['file_name'])
        return read_image_bgr(path)

    def load_annotations(self, image_index):
        """ Load annotations for an image_index.

        Returns
            A dict with 'labels' (float array of shape (K,)) and 'bboxes' (float array of
            shape (K, 4) holding x1, y1, x2, y2), where K is the number of kept annotations.
        """
        # get ground truth annotations
        annotations_ids = self.coco.getAnnIds(imgIds=self.image_ids[image_index], iscrowd=False)

        # collect in plain lists and build the arrays once; growing the arrays with
        # np.concatenate for every annotation copied all previous rows each time
        labels = []
        bboxes = []

        # some images appear to miss annotations (like image with id 257034)
        if len(annotations_ids) > 0:
            for a in self.coco.loadAnns(annotations_ids):
                x, y, width, height = a['bbox'][0], a['bbox'][1], a['bbox'][2], a['bbox'][3]

                # some annotations have basically no width / height, skip them
                if width < 1 or height < 1:
                    continue

                labels.append(self.coco_label_to_label(a['category_id']))
                # COCO stores (x, y, width, height); convert to corner coordinates
                bboxes.append([x, y, x + width, y + height])

        # the reshape keeps the (0,) and (0, 4) shapes when nothing was kept
        return {
            'labels': np.array(labels, dtype=np.float64).reshape((-1,)),
            'bboxes': np.array(bboxes, dtype=np.float64).reshape((-1, 4)),
        }
class KittiGenerator(Generator):
    """ Generate data for a KITTI dataset.

    See http://www.cvlibs.net/datasets/kitti/ for more information.
    """

    def __init__(
        self,
        base_dir,
        subset='train',
        **kwargs
    ):
        """ Initialize a KITTI data generator.

        Args
            base_dir: Root directory of the dataset. Label files are read from
                <base_dir>/<subset>/labels and images are expected in <base_dir>/<subset>/images.
            subset: The subset to generate data for (defaults to 'train').
        """
        self.base_dir = base_dir

        label_dir = os.path.join(self.base_dir, subset, 'labels')
        image_dir = os.path.join(self.base_dir, subset, 'images')

        # Columns of a KITTI label file (space separated):
        #   1    type         Describes the type of object: 'Car', 'Van', 'Truck',
        #                     'Pedestrian', 'Person_sitting', 'Cyclist', 'Tram',
        #                     'Misc' or 'DontCare'
        #   1    truncated    Float from 0 (non-truncated) to 1 (truncated), where
        #                     truncated refers to the object leaving image boundaries
        #   1    occluded     Integer (0,1,2,3) indicating occlusion state:
        #                     0 = fully visible, 1 = partly occluded
        #                     2 = largely occluded, 3 = unknown
        #   1    alpha        Observation angle of object, ranging [-pi..pi]
        #   4    bbox         2D bounding box of object in the image (0-based index):
        #                     contains left, top, right, bottom pixel coordinates
        #   3    dimensions   3D object dimensions: height, width, length (in meters)
        #   3    location     3D object location x,y,z in camera coordinates (in meters)
        #   1    rotation_y   Rotation ry around Y-axis in camera coordinates [-pi..pi]
        fieldnames = ['type', 'truncated', 'occluded', 'alpha', 'left', 'top', 'right', 'bottom', 'dh', 'dw', 'dl',
                      'lx', 'ly', 'lz', 'ry']

        self.labels = {}
        self.classes = kitti_classes
        for name, label in self.classes.items():
            self.labels[label] = name

        self.image_data = dict()
        self.images = []
        # os.listdir returns entries in arbitrary order; sort so that an image index
        # refers to the same file on every run and platform
        for i, fn in enumerate(sorted(os.listdir(label_dir))):
            label_fp = os.path.join(label_dir, fn)
            image_fp = os.path.join(image_dir, fn.replace('.txt', '.png'))

            self.images.append(image_fp)

            with open(label_fp, 'r') as csv_file:
                reader = csv.DictReader(csv_file, delimiter=' ', fieldnames=fieldnames)
                # coordinates are kept as the strings read from the file; load_annotations converts them
                boxes = [
                    {
                        'cls_id': kitti_classes[row['type']],
                        'x1': row['left'],
                        'x2': row['right'],
                        'y2': row['bottom'],
                        'y1': row['top'],
                    }
                    for row in reader
                ]

            self.image_data[i] = boxes

        super(KittiGenerator, self).__init__(**kwargs)

    def size(self):
        """ Size of the dataset.
        """
        return len(self.images)

    def num_classes(self):
        """ Number of classes in the dataset.
        """
        return max(self.classes.values()) + 1

    def has_label(self, label):
        """ Return True if label is a known label.
        """
        return label in self.labels

    def has_name(self, name):
        """ Returns True if name is a known class.
        """
        return name in self.classes

    def name_to_label(self, name):
        """ Map name to label.
        """
        return self.classes[name]

    def label_to_name(self, label):
        """ Map label to name.
        """
        return self.labels[label]

    def image_aspect_ratio(self, image_index):
        """ Compute the aspect ratio for an image with image_index.
        """
        # PIL is fast for metadata; the context manager closes the file handle again
        with Image.open(self.images[image_index]) as image:
            return float(image.width) / float(image.height)

    def load_image(self, image_index):
        """ Load an image at the image_index.
        """
        return read_image_bgr(self.images[image_index])

    def load_annotations(self, image_index):
        """ Load annotations for an image_index.

        Returns
            A dict with 'labels' (array of shape (K,)) and 'bboxes' (array of shape (K, 4)
            holding x1, y1, x2, y2) for the K boxes of the image.
        """
        image_data = self.image_data[image_index]
        annotations = {'labels': np.empty((len(image_data),)), 'bboxes': np.empty((len(image_data), 4))}

        for idx, ann in enumerate(image_data):
            annotations['bboxes'][idx, 0] = float(ann['x1'])
            annotations['bboxes'][idx, 1] = float(ann['y1'])
            annotations['bboxes'][idx, 2] = float(ann['x2'])
            annotations['bboxes'][idx, 3] = float(ann['y2'])
            annotations['labels'][idx] = int(ann['cls_id'])

        return annotations
def assert_is_translation(transform, min, max):
    """ Assert that transform is a 3x3 homogeneous translation with an offset in [min, max).
    """
    assert transform.shape == (3, 3)

    # the first two columns must be those of the identity
    linear_columns = transform[:, :2]
    assert np.array_equal(linear_columns, np.eye(3, 2))
    assert transform[2, 2] == 1

    # the translation itself lives in the last column
    offset = transform[:2, 2]
    assert np.all(offset >= min)
    assert np.all(offset < max)
def assert_is_scaling(transform, min, max):
    """ Assert that transform is a 3x3 homogeneous scaling with factors in [min, max).
    """
    assert transform.shape == (3, 3)

    # last row and last column must be those of the identity
    homogeneous = [0, 0, 1]
    assert np.array_equal(transform[2, :], homogeneous)
    assert np.array_equal(transform[:, 2], homogeneous)

    # no shear / rotation terms
    assert transform[1, 0] == 0
    assert transform[0, 1] == 0

    # the scale factors sit on the diagonal
    factors = np.diagonal(transform)[:2]
    assert np.all(factors >= min)
    assert np.all(factors < max)
def test_random_transform():
    """ With default arguments every random transform must be the identity.
    """
    identity = np.identity(3)

    # direct calls sharing one seeded generator
    prng = np.random.RandomState(0)
    for _ in range(100):
        assert np.array_equal(random_transform(prng=prng), identity)

    # first 100 transforms produced by the generator variant
    transforms = random_transform_generator(prng=np.random.RandomState())
    for _, matrix in zip(range(100), transforms):
        assert np.array_equal(matrix, identity)
Traceback (most recent call last)", 15 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mrow\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mreader\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0mrow\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimages_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfile\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreplace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'csv'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m'png'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m \u001b[0mradius\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrow\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 26\u001b[0m \u001b[0mrow\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mround\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrow\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mradius\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0mrow\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mround\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrow\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mradius\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 16 | "\u001b[0;31mValueError\u001b[0m: could not convert string to float: " 17 | ] 18 | } 19 | ], 20 | "source": [ 21 | "import csv\n", 22 | "import os\n", 23 | "\n", 24 | "keras_retinanet_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) #this file is in examples/\n", 25 | "annotations_path = os.path.join(keras_retinanet_path,'data/acfr-fruit-dataset/apples/annotations/')\n", 26 | "images_path = os.path.join(keras_retinanet_path,'data/acfr-fruit-dataset/apples/images/')\n", 27 | "\n", 28 | "csv_list = [f for f in os.listdir(annotations_path) if os.path.isfile(os.path.join(annotations_path, f))]\n", 29 | "\n", 30 | "data = []\n", 31 | "c=0\n", 32 | "for file in csv_list:\n", 33 | " with open(os.path.join(annotations_path, file)) as f: \n", 34 | " reader = csv.reader(f)\n", 35 | " \n", 36 | " if sum(1 for line in f) == 1: #empty csv\n", 37 | " row = [os.path.join(images_path,file.replace('csv','png')),'','','','','']\n", 38 | " data.append(row)\n", 39 | " else: #not empty\n", 40 | " with open(os.path.join(annotations_path, file)) as f:\n", 41 | " reader = csv.reader(f)\n", 42 | " next(reader)\n", 43 | " for row in reader:\n", 44 | " row[0] = os.path.join(images_path, file.replace('csv','png'))\n", 45 | " print(row[0])\n", 46 | " radius = float(row[3])\n", 47 | " row[1] = round(float(row[1]) - radius) - 1\n", 48 | " row[2] = round(float(row[2]) - radius) - 1 \n", 49 | " row[3] = round(float(row[1]) + radius)\n", 50 | " row[4] = round(float(row[2]) + radius)\n", 51 | " row.append(1)\n", 52 | " data.append(row)\n", 53 | " c+=1\n", 54 | "\n", 55 | "with open(os.path.join(annotations_path, 
'rectangular_annot.csv'), mode='w') as f:\n", 56 | " file = csv.writer(f, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL)\n", 57 | " for row in data:\n", 58 | " file.writerow(row)\n", 59 | "\n", 60 | "with open(os.path.join(annotations_path, 'classes.csv'), mode='w') as f:\n", 61 | " file = csv.writer(f, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL)\n", 62 | " file.writerow('apple',1)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": { 69 | "scrolled": true 70 | }, 71 | "outputs": [], 72 | "source": [] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [] 101 | } 102 | ], 103 | "metadata": { 104 | "kernelspec": { 105 | "display_name": "Python 3", 106 | "language": "python", 107 | "name": "python3" 108 | }, 109 | "language_info": { 110 | "codemirror_mode": { 111 | "name": "ipython", 112 | "version": 3 113 | }, 114 | "file_extension": ".py", 115 | "mimetype": "text/x-python", 116 | "name": "python", 117 | "nbconvert_exporter": "python", 118 | "pygments_lexer": "ipython3", 119 | "version": "3.6.8" 120 | } 121 | }, 122 | "nbformat": 4, 123 | "nbformat_minor": 2 124 | } 125 | -------------------------------------------------------------------------------- /tests/layers/test_filter_detections.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in 
class TestFilterDetections(object):
    """ Checks for the FilterDetections layer.

    Every test feeds duplicate boxes with different scores and expects a single
    surviving detection per image, with the remaining output slots filled with -1
    (outputs are compared against arrays holding 300 entries per image).
    """

    def test_simple(self):
        floatx = keras.backend.floatx()
        layer = keras_retinanet.layers.FilterDetections()

        # two identical boxes; only the best scoring one should survive
        box_tensor = keras.backend.constant(np.array([[
            [0, 0, 10, 10],  # lower score, expected to be suppressed
            [0, 0, 10, 10],
        ]], dtype=floatx))

        score_tensor = keras.backend.constant(np.array([[
            [0, 0.9],  # lower score, expected to be suppressed
            [0, 1],
        ]], dtype=floatx))

        # run the layer and evaluate its three outputs
        out_boxes, out_scores, out_labels = layer.call([box_tensor, score_tensor])
        out_boxes = keras.backend.eval(out_boxes)
        out_scores = keras.backend.eval(out_scores)
        out_labels = keras.backend.eval(out_labels)

        # one detection in slot 0, everything else is padding
        want_boxes = np.full((1, 300, 4), -1, dtype=floatx)
        want_boxes[0, 0, :] = [0, 0, 10, 10]

        want_scores = np.full((1, 300), -1, dtype=floatx)
        want_scores[0, 0] = 1

        want_labels = np.full((1, 300), -1, dtype=floatx)
        want_labels[0, 0] = 1

        np.testing.assert_array_equal(out_boxes, want_boxes)
        np.testing.assert_array_equal(out_scores, want_scores)
        np.testing.assert_array_equal(out_labels, want_labels)

    def test_simple_with_other(self):
        floatx = keras.backend.floatx()
        layer = keras_retinanet.layers.FilterDetections()

        box_tensor = keras.backend.constant(np.array([[
            [0, 0, 10, 10],  # lower score, expected to be suppressed
            [0, 0, 10, 10],
        ]], dtype=floatx))

        score_tensor = keras.backend.constant(np.array([[
            [0, 0.9],  # lower score, expected to be suppressed
            [0, 1],
        ]], dtype=floatx))

        # extra per-detection tensors that must be filtered alongside the boxes
        extras = [
            np.array([[
                [0, 1234],  # belongs to the suppressed detection
                [0, 5678],
            ]], dtype=floatx),
            np.array([[
                5678,  # belongs to the suppressed detection
                1234,
            ]], dtype=floatx),
        ]
        extras = [keras.backend.constant(extra) for extra in extras]

        # run the layer; outputs are boxes, scores, labels, then the extras
        outputs = layer.call([box_tensor, score_tensor] + extras)
        out_boxes = keras.backend.eval(outputs[0])
        out_scores = keras.backend.eval(outputs[1])
        out_labels = keras.backend.eval(outputs[2])
        out_extras = [keras.backend.eval(extra) for extra in outputs[3:]]

        want_boxes = np.full((1, 300, 4), -1, dtype=floatx)
        want_boxes[0, 0, :] = [0, 0, 10, 10]

        want_scores = np.full((1, 300), -1, dtype=floatx)
        want_scores[0, 0] = 1

        want_labels = np.full((1, 300), -1, dtype=floatx)
        want_labels[0, 0] = 1

        want_first = np.full((1, 300, 2), -1, dtype=floatx)
        want_first[0, 0, :] = [0, 5678]
        want_second = np.full((1, 300), -1, dtype=floatx)
        want_second[0, 0] = 1234
        want_extras = [want_first, want_second]

        np.testing.assert_array_equal(out_boxes, want_boxes)
        np.testing.assert_array_equal(out_scores, want_scores)
        np.testing.assert_array_equal(out_labels, want_labels)

        for got, want in zip(out_extras, want_extras):
            np.testing.assert_array_equal(got, want)

    def test_mini_batch(self):
        floatx = keras.backend.floatx()
        layer = keras_retinanet.layers.FilterDetections()

        # batch of two images, each holding a pair of identical boxes
        box_tensor = keras.backend.constant(np.array([
            [
                [0, 0, 10, 10],  # lower score, expected to be suppressed
                [0, 0, 10, 10],
            ],
            [
                [100, 100, 150, 150],
                [100, 100, 150, 150],  # lower score, expected to be suppressed
            ],
        ], dtype=floatx))

        score_tensor = keras.backend.constant(np.array([
            [
                [0, 0.9],  # lower score, expected to be suppressed
                [0, 1],
            ],
            [
                [1, 0],
                [0.9, 0],  # lower score, expected to be suppressed
            ],
        ], dtype=floatx))

        out_boxes, out_scores, out_labels = layer.call([box_tensor, score_tensor])
        out_boxes = keras.backend.eval(out_boxes)
        out_scores = keras.backend.eval(out_scores)
        out_labels = keras.backend.eval(out_labels)

        # each image keeps one detection in slot 0
        want_boxes = np.full((2, 300, 4), -1, dtype=floatx)
        want_boxes[0, 0, :] = [0, 0, 10, 10]
        want_boxes[1, 0, :] = [100, 100, 150, 150]

        want_scores = np.full((2, 300), -1, dtype=floatx)
        want_scores[0, 0] = 1
        want_scores[1, 0] = 1

        want_labels = np.full((2, 300), -1, dtype=floatx)
        want_labels[0, 0] = 1
        want_labels[1, 0] = 0

        np.testing.assert_array_equal(out_boxes, want_boxes)
        np.testing.assert_array_equal(out_scores, want_scores)
        np.testing.assert_array_equal(out_labels, want_labels)
def _findNode(parent, name, debug_name=None, parse=None):
    """ Look up the child element `name` of `parent`, optionally parsing its text.

    Args
        parent: Element to search in; only its `find(name)` method is used.
        name: Tag (or path) handed to `parent.find`.
        debug_name: Name used in error messages (defaults to `name`).
        parse: Optional callable applied to the text of the found element.

    Returns
        The found element when `parse` is None, otherwise `parse(element.text)`.

    Raises
        ValueError: If the element is missing or its text cannot be parsed.
    """
    if debug_name is None:
        debug_name = name

    result = parent.find(name)
    if result is None:
        raise ValueError('missing element \'{}\''.format(debug_name))
    if parse is None:
        return result

    try:
        return parse(result.text)
    except (ValueError, TypeError) as e:
        # An empty element (e.g. <xmin/>) has text None, for which int()/float() raise
        # TypeError rather than ValueError. Report both the same way so callers that
        # catch ValueError see every malformed value.
        raise_from(ValueError('illegal value for \'{}\': {}'.format(debug_name, e)), None)
90 | """ 91 | self.data_dir = data_dir 92 | self.set_name = set_name 93 | self.classes = classes 94 | self.image_names = [l.strip().split(None, 1)[0] for l in open(os.path.join(data_dir, 'ImageSets', 'Main', set_name + '.txt')).readlines()] 95 | self.image_extension = image_extension 96 | self.skip_truncated = skip_truncated 97 | self.skip_difficult = skip_difficult 98 | 99 | self.labels = {} 100 | for key, value in self.classes.items(): 101 | self.labels[value] = key 102 | 103 | super(PascalVocGenerator, self).__init__(**kwargs) 104 | 105 | def size(self): 106 | """ Size of the dataset. 107 | """ 108 | return len(self.image_names) 109 | 110 | def num_classes(self): 111 | """ Number of classes in the dataset. 112 | """ 113 | return len(self.classes) 114 | 115 | def has_label(self, label): 116 | """ Return True if label is a known label. 117 | """ 118 | return label in self.labels 119 | 120 | def has_name(self, name): 121 | """ Returns True if name is a known class. 122 | """ 123 | return name in self.classes 124 | 125 | def name_to_label(self, name): 126 | """ Map name to label. 127 | """ 128 | return self.classes[name] 129 | 130 | def label_to_name(self, label): 131 | """ Map label to name. 132 | """ 133 | return self.labels[label] 134 | 135 | def image_aspect_ratio(self, image_index): 136 | """ Compute the aspect ratio for an image with image_index. 137 | """ 138 | path = os.path.join(self.data_dir, 'JPEGImages', self.image_names[image_index] + self.image_extension) 139 | image = Image.open(path) 140 | return float(image.width) / float(image.height) 141 | 142 | def load_image(self, image_index): 143 | """ Load an image at the image_index. 144 | """ 145 | path = os.path.join(self.data_dir, 'JPEGImages', self.image_names[image_index] + self.image_extension) 146 | return read_image_bgr(path) 147 | 148 | def __parse_annotation(self, element): 149 | """ Parse an annotation given an XML element. 
150 | """ 151 | truncated = _findNode(element, 'truncated', parse=int) 152 | difficult = _findNode(element, 'difficult', parse=int) 153 | 154 | class_name = _findNode(element, 'name').text 155 | if class_name not in self.classes: 156 | raise ValueError('class name \'{}\' not found in classes: {}'.format(class_name, list(self.classes.keys()))) 157 | 158 | box = np.zeros((4,)) 159 | label = self.name_to_label(class_name) 160 | 161 | bndbox = _findNode(element, 'bndbox') 162 | box[0] = _findNode(bndbox, 'xmin', 'bndbox.xmin', parse=float) - 1 163 | box[1] = _findNode(bndbox, 'ymin', 'bndbox.ymin', parse=float) - 1 164 | box[2] = _findNode(bndbox, 'xmax', 'bndbox.xmax', parse=float) - 1 165 | box[3] = _findNode(bndbox, 'ymax', 'bndbox.ymax', parse=float) - 1 166 | 167 | return truncated, difficult, box, label 168 | 169 | def __parse_annotations(self, xml_root): 170 | """ Parse all annotations under the xml_root. 171 | """ 172 | annotations = {'labels': np.empty((len(xml_root.findall('object')),)), 'bboxes': np.empty((len(xml_root.findall('object')), 4))} 173 | for i, element in enumerate(xml_root.iter('object')): 174 | try: 175 | truncated, difficult, box, label = self.__parse_annotation(element) 176 | except ValueError as e: 177 | raise_from(ValueError('could not parse object #{}: {}'.format(i, e)), None) 178 | 179 | if truncated and self.skip_truncated: 180 | continue 181 | if difficult and self.skip_difficult: 182 | continue 183 | 184 | annotations['bboxes'][i, :] = box 185 | annotations['labels'][i] = label 186 | 187 | return annotations 188 | 189 | def load_annotations(self, image_index): 190 | """ Load annotations for an image_index. 
class Anchors(keras.layers.Layer):
    """ Keras layer for generating anchors for a given shape.
    """

    def __init__(self, size, stride, ratios=None, scales=None, *args, **kwargs):
        """ Initializer for an Anchors layer.

        Args
            size: The base size of the anchors to generate.
            stride: The stride of the anchors to generate.
            ratios: The ratios of the anchors to generate (defaults to AnchorParameters.default.ratios).
            scales: The scales of the anchors to generate (defaults to AnchorParameters.default.scales).
        """
        self.size = size
        self.stride = stride
        self.ratios = ratios
        self.scales = scales

        # Lists and tuples are converted to arrays so that get_config can call .tolist().
        if ratios is None:
            self.ratios = utils_anchors.AnchorParameters.default.ratios
        elif isinstance(ratios, (list, tuple)):
            self.ratios = np.array(ratios)
        if scales is None:
            self.scales = utils_anchors.AnchorParameters.default.scales
        elif isinstance(scales, (list, tuple)):
            self.scales = np.array(scales)

        # Use the resolved values: the raw arguments may be None, for which len() fails.
        self.num_anchors = len(self.ratios) * len(self.scales)
        self.anchors = keras.backend.variable(utils_anchors.generate_anchors(
            base_size=size,
            ratios=self.ratios,
            scales=self.scales,
        ))

        super(Anchors, self).__init__(*args, **kwargs)

    def call(self, inputs, **kwargs):
        """ Shift the base anchors over the spatial grid of `inputs` and repeat them per batch item.
        """
        features = inputs
        features_shape = keras.backend.shape(features)

        # the spatial axes depend on the image data format
        if keras.backend.image_data_format() == 'channels_first':
            anchors = backend.shift(features_shape[2:4], self.stride, self.anchors)
        else:
            anchors = backend.shift(features_shape[1:3], self.stride, self.anchors)
        # add a batch axis and tile it to the batch size of the features
        anchors = keras.backend.tile(keras.backend.expand_dims(anchors, axis=0), (features_shape[0], 1, 1))

        return anchors

    def compute_output_shape(self, input_shape):
        """ Return (batch, total_anchors, 4); total_anchors is None when the input shape is not fully known.
        """
        if None not in input_shape[1:]:
            if keras.backend.image_data_format() == 'channels_first':
                total = np.prod(input_shape[2:4]) * self.num_anchors
            else:
                total = np.prod(input_shape[1:3]) * self.num_anchors

            return (input_shape[0], total, 4)
        else:
            return (input_shape[0], None, 4)

    def get_config(self):
        """ Return the layer configuration, extended with the anchor settings.
        """
        config = super(Anchors, self).get_config()
        config.update({
            'size'   : self.size,
            'stride' : self.stride,
            'ratios' : self.ratios.tolist(),
            'scales' : self.scales.tolist(),
        })

        return config
class RegressBoxes(keras.layers.Layer):
    """ Keras layer that turns anchors and regression values into boxes.
    """

    def __init__(self, mean=None, std=None, *args, **kwargs):
        """ Initializer for the RegressBoxes layer.

        Args
            mean: The mean of the regression values used for normalization (defaults to [0, 0, 0, 0]).
            std: The standard deviation of the regression values used for normalization (defaults to [0.2, 0.2, 0.2, 0.2]).

        Raises
            ValueError: If mean or std is not a np.ndarray, list or tuple.
        """
        mean = np.array([0, 0, 0, 0]) if mean is None else mean
        std = np.array([0.2, 0.2, 0.2, 0.2]) if std is None else std

        # validate and convert mean first, then std
        if not isinstance(mean, (list, tuple, np.ndarray)):
            raise ValueError('Expected mean to be a np.ndarray, list or tuple. Received: {}'.format(type(mean)))
        if isinstance(mean, (list, tuple)):
            mean = np.array(mean)

        if not isinstance(std, (list, tuple, np.ndarray)):
            raise ValueError('Expected std to be a np.ndarray, list or tuple. Received: {}'.format(type(std)))
        if isinstance(std, (list, tuple)):
            std = np.array(std)

        self.mean = mean
        self.std = std
        super(RegressBoxes, self).__init__(*args, **kwargs)

    def call(self, inputs, **kwargs):
        """ Apply the regression (second input) to the anchors (first input).
        """
        anchor_boxes, deltas = inputs
        return backend.bbox_transform_inv(anchor_boxes, deltas, mean=self.mean, std=self.std)

    def compute_output_shape(self, input_shape):
        """ The output has the shape of the first input (the anchors).
        """
        return input_shape[0]

    def get_config(self):
        """ Return the layer configuration, extended with mean and std as plain lists.
        """
        config = super(RegressBoxes, self).get_config()
        config['mean'] = self.mean.tolist()
        config['std'] = self.std.tolist()
        return config
class TestAnchors(object):
    """ Checks for the Anchors layer using a single ratio and a single scale.
    """

    def test_simple(self):
        floatx = keras.backend.floatx()

        # one anchor of size 32 per location, locations are 8 apart
        layer = keras_retinanet.layers.Anchors(
            size=32,
            stride=8,
            ratios=np.array([1], dtype=floatx),
            scales=np.array([1], dtype=floatx),
        )

        # fake 2x2 feature map for a single image (only its shape matters)
        feature_map = keras.backend.variable(np.zeros((1, 2, 2, 1024), dtype=floatx))

        result = keras.backend.eval(layer.call(feature_map))

        # one anchor for each of the four feature locations
        wanted = np.array([[
            [-12, -12, 20, 20],
            [-4 , -12, 28, 20],
            [-12, -4 , 20, 28],
            [-4 , -4 , 28, 28],
        ]], dtype=floatx)

        np.testing.assert_array_equal(result, wanted)

    def test_mini_batch(self):
        floatx = keras.backend.floatx()

        layer = keras_retinanet.layers.Anchors(
            size=32,
            stride=8,
            ratios=np.array([1], dtype=floatx),
            scales=np.array([1], dtype=floatx),
        )

        # fake 2x2 feature maps for a batch of two images
        feature_map = keras.backend.variable(np.zeros((2, 2, 2, 1024), dtype=floatx))

        result = keras.backend.eval(layer.call(feature_map))

        # the same four anchors are expected for both batch items
        per_image = np.array([[
            [-12, -12, 20, 20],
            [-4 , -12, 28, 20],
            [-12, -4 , 20, 28],
            [-4 , -4 , 28, 28],
        ]], dtype=floatx)
        wanted = np.tile(per_image, (2, 1, 1))

        np.testing.assert_array_equal(result, wanted)
100], 132 | [20, 20, 40 , 40 ], 133 | ]], dtype=keras.backend.floatx()) 134 | anchors = keras.backend.variable(anchors) 135 | 136 | regression = np.array([[ 137 | [0 , 0 , 0 , 0 ], 138 | [0.1, 0.1, 0 , 0 ], 139 | [0 , 0 , 0.1, 0.1], 140 | ]], dtype=keras.backend.floatx()) 141 | regression = keras.backend.variable(regression) 142 | 143 | # compute output 144 | actual = regress_boxes_layer.call([anchors, regression]) 145 | actual = keras.backend.eval(actual) 146 | 147 | # compute expected output 148 | expected = np.array([[ 149 | [0 , 0 , 10 , 10 ], 150 | [51, 51, 100 , 100 ], 151 | [20, 20, 40.4, 40.4], 152 | ]], dtype=keras.backend.floatx()) 153 | 154 | np.testing.assert_array_almost_equal(actual, expected, decimal=2) 155 | 156 | # mark test to fail 157 | def test_mini_batch(self): 158 | mean = [0, 0, 0, 0] 159 | std = [0.2, 0.2, 0.2, 0.2] 160 | 161 | # create simple RegressBoxes layer 162 | regress_boxes_layer = keras_retinanet.layers.RegressBoxes(mean=mean, std=std) 163 | 164 | # create input 165 | anchors = np.array([ 166 | [ 167 | [0 , 0 , 10 , 10 ], # 1 168 | [50, 50, 100, 100], # 2 169 | [20, 20, 40 , 40 ], # 3 170 | ], 171 | [ 172 | [20, 20, 40 , 40 ], # 3 173 | [0 , 0 , 10 , 10 ], # 1 174 | [50, 50, 100, 100], # 2 175 | ], 176 | ], dtype=keras.backend.floatx()) 177 | anchors = keras.backend.variable(anchors) 178 | 179 | regression = np.array([ 180 | [ 181 | [0 , 0 , 0 , 0 ], # 1 182 | [0.1, 0.1, 0 , 0 ], # 2 183 | [0 , 0 , 0.1, 0.1], # 3 184 | ], 185 | [ 186 | [0 , 0 , 0.1, 0.1], # 3 187 | [0 , 0 , 0 , 0 ], # 1 188 | [0.1, 0.1, 0 , 0 ], # 2 189 | ], 190 | ], dtype=keras.backend.floatx()) 191 | regression = keras.backend.variable(regression) 192 | 193 | # compute output 194 | actual = regress_boxes_layer.call([anchors, regression]) 195 | actual = keras.backend.eval(actual) 196 | 197 | # compute expected output 198 | expected = np.array([ 199 | [ 200 | [0 , 0 , 10 , 10 ], # 1 201 | [51, 51, 100 , 100 ], # 2 202 | [20, 20, 40.4, 40.4], # 3 203 | ], 204 | [ 205 
def base_anchors_for_shape(pyramid_levels=None, anchor_params=None):
    """ Stack the base anchors of every pyramid level into one array.

    Args
        pyramid_levels: Pyramid levels to generate anchors for (defaults to [3, 4, 5, 6, 7]).
            Only the number of levels is used, to index into anchor_params.sizes.
        anchor_params: Anchor parameters to use (defaults to AnchorParameters.default).

    Returns
        The anchors of all levels concatenated along the first axis.
    """
    levels = [3, 4, 5, 6, 7] if pyramid_levels is None else pyramid_levels
    params = AnchorParameters.default if anchor_params is None else anchor_params

    # start from an empty (0, 4) array so the result is well defined without any level
    stacked = [np.zeros((0, 4))]
    for level_index, _level in enumerate(levels):
        stacked.append(generate_anchors(
            base_size=params.sizes[level_index],
            ratios=params.ratios,
            scales=params.scales
        ))

    return np.concatenate(stacked, axis=0)
if __name__ == "__main__":
    # Command line entry point: load the annotated boxes, then search for the anchor ratios
    # and scales that minimise average_overlap via differential evolution.
    # `args`, `entries` and `image_shape` are module globals read by average_overlap.
    parser = argparse.ArgumentParser(description='Optimize RetinaNet anchor configuration')
    parser.add_argument('annotations', help='Path to CSV file containing annotations for anchor optimization.')
    parser.add_argument('--scales', type=int, help='Number of scales.', default=3)
    parser.add_argument('--ratios', type=int, help='Number of ratios, has to be an odd number.', default=3)
    # NOTE(review): with this flag set the full set of shifted anchors is evaluated, which looks
    # like the slower and more accurate mode; the help text seems to say the opposite - confirm.
    parser.add_argument('--include-stride', action='store_true',
                        help='Should stride of the anchors be taken into account. Setting this to false will give '
                             'more accurate results however it is much slower.')
    parser.add_argument('--objective', type=str, default='focal',
                        help='Function used to weight the difference between the target and proposed anchors. '
                             'Options: focal, avg, ce.')
    parser.add_argument('--popsize', type=int, default=20,
                        help='The total population size multiplier used by differential evolution.')
    parser.add_argument('--no-resize', help='Disable image resizing.', dest='resize', action='store_false')
    parser.add_argument('--image-min-side', help='Rescale the image so the smallest side is min_side.', type=int,
                        default=800)
    parser.add_argument('--image-max-side', help='Rescale the image if the largest side is larger than max_side.',
                        type=int, default=1333)
    parser.add_argument('--seed', type=int, help='Seed value to use for differential evolution.')
    parser.add_argument('--workers', type=int, help='The number of workers to assess the job to.', default=1)
    args = parser.parse_args()

    if args.ratios % 2 != 1:
        raise Exception('The number of ratios has to be odd.')

    max_x = 0
    max_y = 0

    # Compare against None: 0 is a valid seed and must not fall back to a random state.
    if args.seed is not None:
        seed = np.random.RandomState(args.seed)
    else:
        seed = np.random.RandomState()

    print('Loading object dimensions.')

    # Boxes are collected in a list and converted once; appending to an array copies it every time.
    boxes = []
    with _open_for_csv(args.annotations) as file:
        for row in csv.reader(file, delimiter=','):
            # rows without coordinates describe images without annotations
            if row[1:5] == ['', '', '', '']:
                continue

            x1, y1, x2, y2 = [int(value) for value in row[1:5]]

            # NOTE(review): this also drops boxes with a coordinate equal to 0, i.e. boxes
            # touching the top or left image border - confirm that this is intended.
            if not x1 or not y1 or not x2 or not y2:
                continue

            if args.resize:
                # only the image size is needed, so release the file handle immediately
                with Image.open(row[0]) as img:
                    img_shape = (img.height, img.width, 1)
                scale = compute_resize_scale(img_shape, min_side=args.image_min_side, max_side=args.image_max_side)
                x1, y1, x2, y2 = [value * scale for value in (x1, y1, x2, y2)]

            max_x = max(x2, max_x)
            max_y = max(y2, max_y)

            if args.include_stride:
                boxes.append([x1, y1, x2, y2])
            else:
                # without stride only the box size matters, so centre the box on the origin
                width = x2 - x1
                height = y2 - y1
                boxes.append([-width / 2, -height / 2, width / 2, height / 2])

    entries = np.array(boxes, dtype=np.float64).reshape(-1, 4)
    image_shape = [max_y, max_x]

    print('Optimising anchors.')

    # one (1, 4) bound per ratio pair, followed by one (0.4, 2) bound per scale
    bounds = [(1, 4)] * ((args.ratios - 1) // 2) + [(0.4, 2)] * args.scales

    result = scipy.optimize.differential_evolution(average_overlap, bounds=bounds, popsize=args.popsize, seed=seed, workers=args.workers)

    if hasattr(result, 'success') and result.success:
        print('Optimization ended successfully!')
    elif not hasattr(result, 'success'):
        print('Optimization ended!')
    else:
        print('Optimization ended unsuccessfully!')
        print(f'Reason: {result.message}')

    values = result.x
    anchor_params = calculate_config(values, args.ratios)
    # evaluating the best solution once more also refreshes the global not_matched
    avg = average_overlap(values)

    print()
    print('Final best anchor configuration')
    print(f'Ratios: {sorted(np.round(anchor_params.ratios, 3))}')
    print(f'Scales: {sorted(np.round(anchor_params.scales, 3))}')
    # NOTE(review): 1 - avg equals the average overlap only for --objective avg.
    print(f'Average overlap: {1 - avg}')
    print(f'Number of labels that don\'t have any matching anchor: {not_matched}')
class VGGBackbone(Backbone):
    """ Backbone descriptor for the VGG variants handled in this module
    ('vgg16', 'vgg19', 'vgg13' and 'vgg11').
    """

    def retinanet(self, *args, **kwargs):
        """ Builds a retinanet model on top of the backbone named by self.backbone.

        All positional and keyword arguments are forwarded to vgg_retinanet.
        """
        return vgg_retinanet(*args, backbone=self.backbone, **kwargs)

    def download_imagenet(self):
        """ Fetches the ImageNet weights for the selected backbone and returns the path of the cached file.

        'vgg13' and 'vgg11' reuse the VGG16 resource and checksum; only the name
        of the cached file depends on the backbone.
        Weights can be downloaded at https://github.com/fizyr/keras-models/releases .

        Raises
            ValueError: if self.backbone is not one of the supported VGG variants.
        """
        name = self.backbone

        if name == 'vgg19':
            resource = keras.applications.vgg19.vgg19.WEIGHTS_PATH_NO_TOP
            checksum = '253f8cb515780f3b799900260a226db6'
        elif name in ('vgg16', 'vgg13', 'vgg11'):
            resource = keras.applications.vgg16.vgg16.WEIGHTS_PATH_NO_TOP
            checksum = '6d6bbae143d832006294945121d1f1fc'
        else:
            raise ValueError("Backbone '{}' not recognized.".format(name))

        filename = '{}_weights_tf_dim_ordering_tf_kernels_notop.h5'.format(name)
        return get_file(filename, resource, cache_subdir='models', file_hash=checksum)

    def validate(self):
        """ Raises a ValueError when self.backbone is not one of the allowed backbone names.
        """
        allowed_backbones = ['vgg16', 'vgg19', 'vgg13', 'vgg11']

        if self.backbone in allowed_backbones:
            return

        raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(self.backbone, allowed_backbones))

    def preprocess_image(self, inputs, mode='caffe'):
        """ Prepares an image for the network by delegating to utils.image.preprocess_image.

        Args
            inputs: The image to preprocess.
            mode: Preprocessing mode handed to preprocess_image (defaults to 'caffe').
        """
        return preprocess_image(inputs, mode=mode)
def vgg_retinanet(num_classes, backbone='vgg16', inputs=None, modifier=None, **kwargs):
    """ Constructs a retinanet model using a vgg backbone.

    Args
        num_classes: Number of classes to predict.
        backbone: Which backbone to use (one of ('vgg16', 'vgg19', 'vgg13', 'vgg11')).
        inputs: Either an existing Keras input tensor (used as-is) or an input shape such as (512, 512, 3),
            from which a new Input layer is created. Defaults to an Input of shape (None, None, 3).
        modifier: A function handler which can modify the backbone before using it in retinanet
            (this can be used to freeze backbone layers for example).
        **kwargs: Forwarded unchanged to retinanet.retinanet.

    Returns
        RetinaNet model with a VGG backbone.

    Raises
        ValueError: if the backbone name is not recognized.
    """
    # choose default input
    if inputs is None:
        inputs = keras.layers.Input(shape=(None, None, 3))
    else:
        # is_keras_tensor raises ValueError for non-tensor arguments (e.g. a shape tuple),
        # which is exactly the case where an Input layer still has to be created.
        try:
            already_tensor = keras.backend.is_keras_tensor(inputs)
        except ValueError:
            already_tensor = False

        if not already_tensor:
            inputs = keras.layers.Input(shape=inputs)

    # create the vgg backbone
    if backbone == 'vgg16':
        vgg = keras.applications.VGG16(input_tensor=inputs, include_top=False, weights=None)
    elif backbone == 'vgg19':
        vgg = keras.applications.VGG19(input_tensor=inputs, include_top=False, weights=None)
    elif backbone == 'vgg13':
        vgg = vgg13(input_tensor=inputs, weights=None)
    elif backbone == 'vgg11':
        vgg = vgg11(input_tensor=inputs, weights=None)
    else:
        raise ValueError("Backbone '{}' not recognized.".format(backbone))

    if modifier:
        vgg = modifier(vgg)

    # create the full model from the outputs of the last three pooling layers
    layer_names = ["block3_pool", "block4_pool", "block5_pool"]
    layer_outputs = [vgg.get_layer(name).get_output_at(0) for name in layer_names]
    return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=layer_outputs, **kwargs)
def vgg13(input_tensor=None, weights=None):
    """ Builds a VGG13-style convolutional model: five blocks of two 3x3 convolutions followed by 2x2 max pooling.

    Args
        input_tensor: Tensor the first convolution is applied to; also used as the model input.
        weights: Accepted for signature parity with the keras.applications constructors; not used here.

    Returns
        A lighter version of VGG16.
    """
    # number of filters of every convolution in blocks 1..5
    filters_per_block = (64, 128, 256, 512, 512)

    x = input_tensor
    for block, filters in enumerate(filters_per_block, start=1):
        for conv in (1, 2):
            x = keras.layers.Conv2D(
                filters, (3, 3),
                activation='relu',
                padding='same',
                name='block{}_conv{}'.format(block, conv)
            )(x)
        x = keras.layers.MaxPooling2D((2, 2), strides=(2, 2), name='block{}_pool'.format(block))(x)

    return keras.Model(input_tensor, x, name='VGG13')
def vgg11(input_tensor=None, weights=None):
    """ Builds a VGG11-style convolutional model: blocks 1 and 2 hold a single 3x3 convolution,
    blocks 3 to 5 hold two, and every block ends with 2x2 max pooling.

    Args
        input_tensor: Tensor the first convolution is applied to; also used as the model input.
        weights: Accepted for signature parity with the keras.applications constructors; not used here.

    Returns
        A lighter version of VGG16.
    """
    # (filters, number of convolutions) for blocks 1..5
    block_specs = ((64, 1), (128, 1), (256, 2), (512, 2), (512, 2))

    x = input_tensor
    for block, (filters, num_convs) in enumerate(block_specs, start=1):
        for conv in range(1, num_convs + 1):
            x = keras.layers.Conv2D(
                filters, (3, 3),
                activation='relu',
                padding='same',
                name='block{}_conv{}'.format(block, conv)
            )(x)
        x = keras.layers.MaxPooling2D((2, 2), strides=(2, 2), name='block{}_pool'.format(block))(x)

    return keras.Model(input_tensor, x, name='VGG11')
def csv_str(string):
    """ Wraps a CSV text in a csv.reader, decoding it first when `str` is `bytes` (Python 2).
    """
    text = string.decode('utf-8') if str == bytes else string
    return csv.reader(StringIO(text))


def annotation(x1, y1, x2, y2, class_name):
    """ Builds the annotation dict that _read_annotations is expected to produce for one box.
    """
    return {'x1': x1, 'y1': y1, 'x2': x2, 'y2': y2, 'class': class_name}


def test_read_classes():
    """ Well-formed class files map every class name to its integer id.
    """
    assert csv_generator._read_classes(csv_str('')) == {}
    assert csv_generator._read_classes(csv_str('a,1')) == {'a': 1}
    assert csv_generator._read_classes(csv_str('a,1\nb,2')) == {'a': 1, 'b': 2}


def test_read_classes_wrong_format():
    """ Rows with the wrong number of columns are reported with their line number.
    """
    with pytest.raises(ValueError) as excinfo:
        csv_generator._read_classes(csv_str('a,b,c'))
    assert str(excinfo.value).startswith('line 1: format should be')

    with pytest.raises(ValueError) as excinfo:
        csv_generator._read_classes(csv_str('a,1\nb,c,d'))
    assert str(excinfo.value).startswith('line 2: format should be')


def test_read_classes_malformed_class_id():
    """ Non-integer class ids are reported with their line number.
    """
    with pytest.raises(ValueError) as excinfo:
        csv_generator._read_classes(csv_str('a,b'))
    assert str(excinfo.value).startswith("line 1: malformed class ID:")

    with pytest.raises(ValueError) as excinfo:
        csv_generator._read_classes(csv_str('a,1\nb,c'))
    assert str(excinfo.value).startswith('line 2: malformed class ID:')


def test_read_classes_duplicate_name():
    """ A class name that appears twice is rejected.
    """
    with pytest.raises(ValueError) as excinfo:
        csv_generator._read_classes(csv_str('a,1\nb,2\na,3'))
    assert str(excinfo.value).startswith('line 3: duplicate class name')


def test_read_annotations():
    """ One annotation per image is grouped under its image path.
    """
    classes = {'a': 1, 'b': 2, 'c': 4, 'd': 10}
    content = (
        'a.png,0,1,2,3,a\n'
        'b.png,4,5,6,7,b\n'
        'c.png,8,9,10,11,c\n'
        'd.png,12,13,14,15,d\n'
    )
    annotations = csv_generator._read_annotations(csv_str(content), classes)
    assert annotations == {
        'a.png': [annotation(0, 1, 2, 3, 'a')],
        'b.png': [annotation(4, 5, 6, 7, 'b')],
        'c.png': [annotation(8, 9, 10, 11, 'c')],
        'd.png': [annotation(12, 13, 14, 15, 'd')],
    }


def test_read_annotations_multiple():
    """ Several annotations of the same image are collected in file order.
    """
    classes = {'a': 1, 'b': 2, 'c': 4, 'd': 10}
    content = (
        'a.png,0,1,2,3,a\n'
        'b.png,4,5,6,7,b\n'
        'a.png,8,9,10,11,c\n'
    )
    annotations = csv_generator._read_annotations(csv_str(content), classes)
    assert annotations == {
        'a.png': [
            annotation(0, 1, 2, 3, 'a'),
            annotation(8, 9, 10, 11, 'c'),
        ],
        'b.png': [annotation(4, 5, 6, 7, 'b')],
    }


def test_read_annotations_wrong_format():
    """ Rows with too few columns are reported with their line number.
    """
    classes = {'a': 1, 'b': 2, 'c': 4, 'd': 10}

    with pytest.raises(ValueError) as excinfo:
        csv_generator._read_annotations(csv_str('a.png,1,2,3,a'), classes)
    assert str(excinfo.value).startswith("line 1: format should be")

    with pytest.raises(ValueError) as excinfo:
        csv_generator._read_annotations(csv_str('a.png,0,1,2,3,a\na.png,1,2,3,a\n'), classes)
    assert str(excinfo.value).startswith("line 2: format should be")


def test_read_annotations_wrong_x1():
    """ A non-integer x1 is rejected.
    """
    with pytest.raises(ValueError) as excinfo:
        csv_generator._read_annotations(csv_str('a.png,a,0,1,2,a'), {'a': 1})
    assert str(excinfo.value).startswith("line 1: malformed x1:")


def test_read_annotations_wrong_y1():
    """ A non-integer y1 is rejected.
    """
    with pytest.raises(ValueError) as excinfo:
        csv_generator._read_annotations(csv_str('a.png,0,a,1,2,a'), {'a': 1})
    assert str(excinfo.value).startswith("line 1: malformed y1:")


def test_read_annotations_wrong_x2():
    """ A non-integer x2 is rejected.
    """
    with pytest.raises(ValueError) as excinfo:
        csv_generator._read_annotations(csv_str('a.png,0,1,a,2,a'), {'a': 1})
    assert str(excinfo.value).startswith("line 1: malformed x2:")


def test_read_annotations_wrong_y2():
    """ A non-integer y2 is rejected.
    """
    with pytest.raises(ValueError) as excinfo:
        csv_generator._read_annotations(csv_str('a.png,0,1,2,a,a'), {'a': 1})
    assert str(excinfo.value).startswith("line 1: malformed y2:")


def test_read_annotations_wrong_class():
    """ A class name missing from the class mapping is rejected.
    """
    with pytest.raises(ValueError) as excinfo:
        csv_generator._read_annotations(csv_str('a.png,0,1,2,3,g'), {'a': 1})
    assert str(excinfo.value).startswith("line 1: unknown class name:")


def test_read_annotations_invalid_bb_x():
    """ Boxes whose x2 is not strictly larger than x1 are rejected.
    """
    with pytest.raises(ValueError) as excinfo:
        csv_generator._read_annotations(csv_str('a.png,1,2,1,3,g'), {'a': 1})
    assert str(excinfo.value).startswith("line 1: x2 (1) must be higher than x1 (1)")

    with pytest.raises(ValueError) as excinfo:
        csv_generator._read_annotations(csv_str('a.png,9,2,5,3,g'), {'a': 1})
    assert str(excinfo.value).startswith("line 1: x2 (5) must be higher than x1 (9)")


def test_read_annotations_invalid_bb_y():
    """ Boxes whose y2 is not strictly larger than y1 are rejected.
    """
    with pytest.raises(ValueError) as excinfo:
        csv_generator._read_annotations(csv_str('a.png,1,2,3,2,a'), {'a': 1})
    assert str(excinfo.value).startswith("line 1: y2 (2) must be higher than y1 (2)")

    with pytest.raises(ValueError) as excinfo:
        csv_generator._read_annotations(csv_str('a.png,1,8,3,5,a'), {'a': 1})
    assert str(excinfo.value).startswith("line 1: y2 (5) must be higher than y1 (8)")


def test_read_annotations_empty_image():
    """ Rows that only name an image register the image without adding annotations.
    """
    # Check that images without annotations are parsed.
    only_empty = csv_generator._read_annotations(csv_str('a.png,,,,,\nb.png,,,,,'), {'a': 1})
    assert only_empty == {'a.png': [], 'b.png': []}

    # Check that lines without annotations don't clear earlier annotations.
    mixed = csv_generator._read_annotations(csv_str('a.png,0,1,2,3,a\na.png,,,,,'), {'a': 1})
    assert mixed == {'a.png': [annotation(0, 1, 2, 3, 'a')]}
def _parse(value, function, fmt):
    """
    Convert `value` with `function`, turning a failure into a nicely formatted ValueError.

    Returns `function(value)`.
    Any `ValueError` raised by `function` is caught and replaced by a new `ValueError`
    with message `fmt.format(e)`, where `e` is the caught `ValueError`.
    """
    try:
        parsed = function(value)
    except ValueError as e:
        raise_from(ValueError(fmt.format(e)), None)
    else:
        return parsed


def _read_classes(csv_reader):
    """ Parse the classes file given by csv_reader into a {class_name: class_id} dict.

    Raises a ValueError (mentioning the 1-based line number) for rows that do not have
    exactly two columns, for non-integer class ids and for duplicate class names.
    """
    name_to_id = {}

    for line, row in enumerate(csv_reader, start=1):
        try:
            class_name, class_id = row
        except ValueError:
            raise_from(ValueError('line {}: format should be \'class_name,class_id\''.format(line)), None)

        class_id = _parse(class_id, int, 'line {}: malformed class ID: {{}}'.format(line))

        if class_name in name_to_id:
            raise ValueError('line {}: duplicate class name: \'{}\''.format(line, class_name))

        name_to_id[class_name] = class_id

    return name_to_id


def _read_annotations(csv_reader, classes):
    """ Read annotations from the csv_reader into a {img_file: [annotation, ...]} dict.

    Every annotation is a dict with the keys 'x1', 'x2', 'y1', 'y2' and 'class'.
    Only the first six columns of a row are used. A row whose coordinates and class
    name are all empty registers the image without adding an annotation.

    Raises a ValueError (mentioning the 1-based line number) for rows with too few
    columns, non-integer coordinates, empty or inverted boxes and unknown class names.
    """
    per_image = {}

    for line, row in enumerate(csv_reader, start=1):
        try:
            img_file, x1, y1, x2, y2, class_name = row[:6]
        except ValueError:
            raise_from(ValueError('line {}: format should be \'img_file,x1,y1,x2,y2,class_name\' or \'img_file,,,,,\''.format(line)), None)

        boxes = per_image.setdefault(img_file, [])

        # If a row contains only an image path, it's an image without annotations.
        if all(field == '' for field in (x1, y1, x2, y2, class_name)):
            continue

        # Parse the coordinates in column order so the first malformed one is reported.
        coords = {}
        for key, raw in (('x1', x1), ('y1', y1), ('x2', x2), ('y2', y2)):
            coords[key] = _parse(raw, int, 'line {}: malformed {}: {{}}'.format(line, key))
        x1, y1, x2, y2 = coords['x1'], coords['y1'], coords['x2'], coords['y2']

        # Check that the bounding box is valid.
        if x2 <= x1:
            raise ValueError('line {}: x2 ({}) must be higher than x1 ({})'.format(line, x2, x1))
        if y2 <= y1:
            raise ValueError('line {}: y2 ({}) must be higher than y1 ({})'.format(line, y2, y1))

        # check if the current class name is correctly present
        if class_name not in classes:
            raise ValueError('line {}: unknown class name: \'{}\' (classes: {})'.format(line, class_name, classes))

        boxes.append({'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2, 'class': class_name})

    return per_image


def _open_for_csv(path):
    """ Open a file with flags suitable for csv.reader.

    On python3 this means mode 'r' with newline='' so csv.reader handles line endings itself,
    on python2 it means binary mode 'rb'.
    """
    if sys.version_info[0] >= 3:
        return open(path, 'r', newline='')

    return open(path, 'rb')
class CSVGenerator(Generator):
    """ Generate data for a custom CSV dataset.

    See https://github.com/fizyr/keras-retinanet#csv-datasets for more information.
    """

    def __init__(
        self,
        csv_data_file,
        csv_class_file,
        base_dir=None,
        **kwargs
    ):
        """ Initialize a CSV data generator.

        Args
            csv_data_file: Path to the CSV annotations file.
            csv_class_file: Path to the CSV classes file.
            base_dir: Directory w.r.t. where the files are to be searched (defaults to the directory containing the csv_data_file).
            **kwargs: Forwarded unchanged to the Generator base class.

        Raises
            ValueError: if the class file or the annotations file cannot be parsed.
        """
        self.image_names = []
        self.image_data  = {}
        self.base_dir    = base_dir

        # Take base_dir from annotations file if not explicitly specified.
        if self.base_dir is None:
            self.base_dir = os.path.dirname(csv_data_file)

        # parse the provided class file
        try:
            with _open_for_csv(csv_class_file) as file:
                self.classes = _read_classes(csv.reader(file, delimiter=','))
        except ValueError as e:
            raise_from(ValueError('invalid CSV class file: {}: {}'.format(csv_class_file, e)), None)

        # reverse mapping: class id -> class name
        self.labels = {class_id: class_name for class_name, class_id in self.classes.items()}

        # csv with img_path, x1, y1, x2, y2, class_name
        try:
            with _open_for_csv(csv_data_file) as file:
                self.image_data = _read_annotations(csv.reader(file, delimiter=','), self.classes)
        except ValueError as e:
            raise_from(ValueError('invalid CSV annotations file: {}: {}'.format(csv_data_file, e)), None)
        self.image_names = list(self.image_data.keys())

        super(CSVGenerator, self).__init__(**kwargs)

    def size(self):
        """ Size of the dataset (number of distinct image paths in the annotations file).
        """
        return len(self.image_names)

    def num_classes(self):
        """ Number of classes in the dataset, computed as the highest class id plus one.
        """
        return max(self.classes.values()) + 1

    def has_label(self, label):
        """ Return True if label is a known label.
        """
        return label in self.labels

    def has_name(self, name):
        """ Returns True if name is a known class.
        """
        return name in self.classes

    def name_to_label(self, name):
        """ Map name to label.
        """
        return self.classes[name]

    def label_to_name(self, label):
        """ Map label to name.
        """
        return self.labels[label]

    def image_path(self, image_index):
        """ Returns the image path for image_index, joined onto base_dir.
        """
        return os.path.join(self.base_dir, self.image_names[image_index])

    def image_aspect_ratio(self, image_index):
        """ Compute the aspect ratio (width / height) for an image with image_index.
        """
        # PIL is fast for metadata; the context manager releases the file handle
        # that Image.open would otherwise keep open.
        with Image.open(self.image_path(image_index)) as image:
            return float(image.width) / float(image.height)

    def load_image(self, image_index):
        """ Load an image at the image_index.
        """
        return read_image_bgr(self.image_path(image_index))

    def load_annotations(self, image_index):
        """ Load annotations for an image_index.

        Returns
            A dict with 'labels' (float array of shape (N,)) and 'bboxes'
            (float array of shape (N, 4) holding x1, y1, x2, y2), where N is the
            number of annotations of the image (possibly 0).
        """
        entries = self.image_data[self.image_names[image_index]]

        # Build both arrays in one go instead of growing them with one concatenate per annotation.
        labels = np.array([self.name_to_label(entry['class']) for entry in entries], dtype=np.float64)
        bboxes = np.array(
            [
                [float(entry['x1']), float(entry['y1']), float(entry['x2']), float(entry['y2'])]
                for entry in entries
            ],
            dtype=np.float64
        ).reshape(-1, 4)  # keeps the (0, 4) shape for images without annotations

        return {'labels': labels, 'bboxes': bboxes}
def test_config_read():
    """ The test config file exposes the anchor parameters as raw strings.
    """
    config = read_config_file('tests/test-data/config/config.ini')
    section = 'anchor_parameters'
    expected = (
        ('sizes', '32 64 128 256 512'),
        ('strides', '8 16 32 64 128'),
        ('ratios', '0.5 1 2 3'),
        ('scales', '1 1.2 1.6'),
    )

    assert section in config
    for key, _ in expected:
        assert key in config[section]
    for key, value in expected:
        assert config[section][key] == value


def create_anchor_params_config():
    """ Builds an in-memory config holding only an 'anchor_parameters' section.
    """
    config = configparser.ConfigParser()
    config['anchor_parameters'] = {
        'sizes': '32 64 128 256 512',
        'strides': '8 16 32 64 128',
        'ratios': '0.5 1',
        'scales': '1 1.2 1.6',
    }

    return config


def test_parse_anchor_parameters():
    """ parse_anchor_parameters converts the config strings into lists and float arrays.
    """
    parsed = parse_anchor_parameters(create_anchor_params_config())

    expected_sizes   = [32, 64, 128, 256, 512]
    expected_strides = [8, 16, 32, 64, 128]
    expected_ratios  = np.array([0.5, 1], keras.backend.floatx())
    expected_scales  = np.array([1, 1.2, 1.6], keras.backend.floatx())

    assert expected_sizes == parsed.sizes
    assert expected_strides == parsed.strides
    np.testing.assert_equal(expected_ratios, parsed.ratios)
    np.testing.assert_equal(expected_scales, parsed.scales)


def test_anchors_for_shape_dimensions():
    """ The number of generated anchors matches the pyramid configuration.
    """
    anchor_params = AnchorParameters(
        [32, 64, 128],
        [8, 16, 32],
        np.array([0.5, 1, 2, 3], keras.backend.floatx()),
        np.array([1, 1.2, 1.6], keras.backend.floatx()),
    )

    all_anchors = anchors_for_shape((64, 64), pyramid_levels=[3, 4, 5], anchor_params=anchor_params)

    assert all_anchors.shape == (1008, 4)


def test_anchors_for_shape_values():
    """ Every anchor of a 2x2 grid has the expected coordinates.

    Anchors are expected row by row over the grid; within a grid cell the ratios
    vary slowest and the scales fastest.
    """
    sizes   = [12]
    strides = [8]
    ratios  = np.array([1, 2], keras.backend.floatx())
    scales  = np.array([1, 2], keras.backend.floatx())
    anchor_params = AnchorParameters(sizes, strides, ratios, scales)

    pyramid_levels = [3]
    image_shape    = (16, 16)
    all_anchors    = anchors_for_shape(image_shape, pyramid_levels=pyramid_levels, anchor_params=anchor_params)

    index = 0
    for row in range(2):
        for col in range(2):
            # anchor centers sit in the middle of each stride-sized cell
            center_x = strides[0] * (2 * col + 1) / 2
            center_y = strides[0] * (2 * row + 1) / 2

            for ratio in ratios:
                for scale in scales:
                    half_width  = (sizes[0] * scale / np.sqrt(ratio)) / 2
                    half_height = (sizes[0] * scale * np.sqrt(ratio)) / 2

                    # using almost_equal for floating point imprecisions
                    np.testing.assert_almost_equal(all_anchors[index, :], [
                        center_x - half_width,
                        center_y - half_height,
                        center_x + half_width,
                        center_y + half_height,
                    ], decimal=6)
                    index += 1
from keras_retinanet.preprocessing.generator import Generator

import numpy as np
import pytest


class SimpleGenerator(Generator):
    """ Minimal in-memory Generator used as a fixture by the tests in this module.

    Annotations are served from the lists handed to the constructor and every
    image index resolves to the same (optional) image.
    """

    def __init__(self, bboxes, labels, num_classes=0, image=None):
        """ Store the fixture data and initialise the base Generator.

        Args
            bboxes      : List with one array of boxes per image.
            labels      : List with one array of labels per image (same length as bboxes).
            num_classes : Value reported by num_classes().
            image       : Image returned by load_image() for every index.
        """
        assert(len(bboxes) == len(labels))
        self.bboxes       = bboxes
        self.labels       = labels
        self.num_classes_ = num_classes
        self.image        = image
        super(SimpleGenerator, self).__init__(group_method='none', shuffle_groups=False)

    def num_classes(self):
        """ Return the number of classes passed to the constructor.
        """
        return self.num_classes_

    def load_image(self, image_index):
        """ Return the single fixture image, regardless of image_index.
        """
        return self.image

    def size(self):
        """ Return the number of images, i.e. the number of bbox arrays.
        """
        return len(self.bboxes)

    def load_annotations(self, image_index):
        """ Return the annotations of image_index as a dict with 'labels' and 'bboxes'.
        """
        annotations = {'labels': self.labels[image_index], 'bboxes': self.bboxes[image_index]}
        return annotations


class TestLoadAnnotationsGroup(object):
    """ Tests for Generator.load_annotations_group.
    """

    def test_simple(self):
        """ A single image with valid boxes is loaded unchanged.
        """
        input_bboxes_group = [
            np.array([
                [  0,   0,  10,  10],
                [150, 150, 350, 350]
            ]),
        ]
        input_labels_group = [
            np.array([
                1,
                3
            ]),
        ]
        expected_bboxes_group = input_bboxes_group
        expected_labels_group = input_labels_group

        simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group)
        annotations = simple_generator.load_annotations_group(simple_generator.groups[0])

        assert('bboxes' in annotations[0])
        assert('labels' in annotations[0])
        np.testing.assert_equal(expected_bboxes_group[0], annotations[0]['bboxes'])
        np.testing.assert_equal(expected_labels_group[0], annotations[0]['labels'])

    def test_multiple(self):
        """ Two images end up in two groups, each loaded unchanged.
        """
        input_bboxes_group = [
            np.array([
                [  0,   0,  10,  10],
                [150, 150, 350, 350]
            ]),
            np.array([
                [0, 0, 50, 50],
            ]),
        ]
        input_labels_group = [
            np.array([
                1,
                0
            ]),
            np.array([
                3
            ])
        ]
        expected_bboxes_group = input_bboxes_group
        expected_labels_group = input_labels_group

        simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group)
        annotations_group_0 = simple_generator.load_annotations_group(simple_generator.groups[0])
        annotations_group_1 = simple_generator.load_annotations_group(simple_generator.groups[1])

        assert('bboxes' in annotations_group_0[0])
        assert('bboxes' in annotations_group_1[0])
        assert('labels' in annotations_group_0[0])
        assert('labels' in annotations_group_1[0])
        np.testing.assert_equal(expected_bboxes_group[0], annotations_group_0[0]['bboxes'])
        np.testing.assert_equal(expected_labels_group[0], annotations_group_0[0]['labels'])
        np.testing.assert_equal(expected_bboxes_group[1], annotations_group_1[0]['bboxes'])
        np.testing.assert_equal(expected_labels_group[1], annotations_group_1[0]['labels'])


class TestFilterAnnotations(object):
    """ Tests for Generator.filter_annotations.
    """

    def test_simple_filter(self):
        """ A box with x2 < x1 and y2 < y1 is dropped together with its label.
        """
        input_bboxes_group = [
            np.array([
                [  0,   0, 10, 10],
                [150, 150, 50, 50]
            ]),
        ]
        input_labels_group = [
            np.array([
                3,
                1
            ]),
        ]

        input_image = np.zeros((500, 500, 3))

        expected_bboxes_group = [
            np.array([
                [0, 0, 10, 10],
            ]),
        ]
        expected_labels_group = [
            np.array([
                3,
            ]),
        ]

        simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group)
        annotations = simple_generator.load_annotations_group(simple_generator.groups[0])
        # expect a UserWarning
        with pytest.warns(UserWarning):
            image_group, annotations_group = simple_generator.filter_annotations([input_image], annotations, simple_generator.groups[0])

        np.testing.assert_equal(expected_bboxes_group[0], annotations_group[0]['bboxes'])
        np.testing.assert_equal(expected_labels_group[0], annotations_group[0]['labels'])

    def test_multiple_filter(self):
        """ Every group is filtered independently and only the expected boxes survive.
        """
        input_bboxes_group = [
            np.array([
                [  0,   0,  10,  10],
                [150, 150,  50,  50],
                [150, 150, 350, 350],
                [350, 350, 150, 150],
                [  1,   1,   2,   2],
                [  2,   2,   1,   1]
            ]),
            np.array([
                [0, 0, -1, -1]
            ]),
            np.array([
                [-10, -10,    0,    0],
                [-10, -10, -100, -100],
                [ 10,  10,  100,  100]
            ]),
            np.array([
                [ 10,  10, 100, 100],
                [ 10,  10, 600, 600]
            ]),
        ]

        input_labels_group = [
            np.array([
                6,
                5,
                4,
                3,
                2,
                1
            ]),
            np.array([
                0
            ]),
            np.array([
                10,
                11,
                12
            ]),
            np.array([
                105,
                107
            ]),
        ]

        input_image = np.zeros((500, 500, 3))

        expected_bboxes_group = [
            np.array([
                [  0,   0,  10,  10],
                [150, 150, 350, 350],
                [  1,   1,   2,   2]
            ]),
            np.zeros((0, 4)),
            np.array([
                [10, 10, 100, 100]
            ]),
            np.array([
                [ 10,  10, 100, 100]
            ]),
        ]
        expected_labels_group = [
            np.array([
                6,
                4,
                2
            ]),
            np.zeros((0,)),
            np.array([
                12
            ]),
            np.array([
                105
            ]),
        ]

        simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group)

        # Check every declared group, including the last one (a box that extends
        # beyond the 500x500 image), which was declared but never asserted before.
        expectations = zip(expected_bboxes_group, expected_labels_group)
        for index, (expected_bboxes, expected_labels) in enumerate(expectations):
            group       = simple_generator.groups[index]
            annotations = simple_generator.load_annotations_group(group)

            # each input group holds at least one box that the expectations drop,
            # so filtering is expected to emit a UserWarning
            with pytest.warns(UserWarning):
                _, annotations = simple_generator.filter_annotations([input_image], annotations, group)

            np.testing.assert_equal(expected_bboxes, annotations[0]['bboxes'])
            np.testing.assert_equal(expected_labels, annotations[0]['labels'])

    def test_complete(self):
        """ Fetch a full batch from a generator that holds one valid and one invalid box.
        """
        input_bboxes_group = [
            np.array([
                [  0,   0, 50, 50],
                [150, 150, 50, 50],  # invalid bbox
            ], dtype=float)
        ]

        input_labels_group = [
            np.array([
                5,  # one object of class 5
                3,  # one object of class 3 with an invalid box
            ], dtype=float)
        ]

        input_image = np.zeros((500, 500, 3), dtype=np.uint8)

        simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group, image=input_image, num_classes=6)
        # expect a UserWarning
        with pytest.warns(UserWarning):
            _, [_, labels_batch] = simple_generator[0]

        # test that only object with class 5 is present in labels_batch
        # NOTE(review): the comment above refers to class 5 while the assertion
        # message refers to class 0 -- confirm which one is intended.
        labels = np.unique(np.argmax(labels_batch == 5, axis=2))
        assert(len(labels) == 1 and labels[0] == 0), 'Expected only class 0 to be present, but got classes {}'.format(labels)
"""
Copyright 2017-2018 Fizyr (https://fizyr.com)

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import keras
from .. import backend


def filter_detections(
    boxes,
    classification,
    other                 = None,
    class_specific_filter = True,
    nms                   = True,
    score_threshold       = 0.05,
    max_detections        = 300,
    nms_threshold         = 0.5
):
    """ Filter detections using the boxes and classification values.

    Args
        boxes                 : Tensor of shape (num_boxes, 4) containing the boxes in (x1, y1, x2, y2) format.
        classification        : Tensor of shape (num_boxes, num_classes) containing the classification scores.
        other                 : List of tensors of shape (num_boxes, ...) to filter along with the boxes and classification scores.
                                None (the default) is treated as an empty list.
        class_specific_filter : Whether to perform filtering per class, or take the best scoring class and filter those.
        nms                   : Flag to enable/disable non maximum suppression.
        score_threshold       : Threshold used to prefilter the boxes with.
        max_detections        : Maximum number of detections to keep.
        nms_threshold         : Threshold for the IoU value to determine when a box should be suppressed.

    Returns
        A list of [boxes, scores, labels, other[0], other[1], ...].
        boxes is shaped (max_detections, 4) and contains the (x1, y1, x2, y2) of the non-suppressed boxes.
        scores is shaped (max_detections,) and contains the scores of the predicted class.
        labels is shaped (max_detections,) and contains the predicted label.
        other[i] is shaped (max_detections, ...) and contains the filtered other[i] data.
        In case there are less than max_detections detections, the tensors are padded with -1's.
    """
    # a None sentinel avoids sharing one mutable default list between calls
    if other is None:
        other = []

    def _filter_detections(scores, labels):
        """ Return (box index, label) pairs that pass the score threshold and, optionally, NMS.
        """
        # threshold based on score
        indices = backend.where(keras.backend.greater(scores, score_threshold))

        if nms:
            filtered_boxes  = backend.gather_nd(boxes, indices)
            filtered_scores = keras.backend.gather(scores, indices)[:, 0]

            # perform NMS
            nms_indices = backend.non_max_suppression(filtered_boxes, filtered_scores, max_output_size=max_detections, iou_threshold=nms_threshold)

            # filter indices based on NMS
            indices = keras.backend.gather(indices, nms_indices)

        # add indices to list of all indices
        labels  = backend.gather_nd(labels, indices)
        indices = keras.backend.stack([indices[:, 0], labels], axis=1)

        return indices

    if class_specific_filter:
        all_indices = []
        # perform per class filtering
        for c in range(int(classification.shape[1])):
            scores = classification[:, c]
            labels = c * backend.ones((keras.backend.shape(scores)[0],), dtype='int64')
            all_indices.append(_filter_detections(scores, labels))

        # concatenate indices to single tensor
        indices = keras.backend.concatenate(all_indices, axis=0)
    else:
        scores  = keras.backend.max(classification, axis = 1)
        labels  = keras.backend.argmax(classification, axis = 1)
        indices = _filter_detections(scores, labels)

    # select top k
    scores              = backend.gather_nd(classification, indices)
    labels              = indices[:, 1]
    scores, top_indices = backend.top_k(scores, k=keras.backend.minimum(max_detections, keras.backend.shape(scores)[0]))

    # filter input using the final set of indices
    indices = keras.backend.gather(indices[:, 0], top_indices)
    boxes   = keras.backend.gather(boxes, indices)
    labels  = keras.backend.gather(labels, top_indices)
    other_  = [keras.backend.gather(o, indices) for o in other]

    # pad the outputs with -1 up to max_detections entries
    pad_size = keras.backend.maximum(0, max_detections - keras.backend.shape(scores)[0])
    boxes    = backend.pad(boxes, [[0, pad_size], [0, 0]], constant_values=-1)
    scores   = backend.pad(scores, [[0, pad_size]], constant_values=-1)
    labels   = backend.pad(labels, [[0, pad_size]], constant_values=-1)
    labels   = keras.backend.cast(labels, 'int32')
    other_   = [backend.pad(o, [[0, pad_size]] + [[0, 0] for _ in range(1, len(o.shape))], constant_values=-1) for o in other_]

    # set shapes, since we know what they are
    boxes.set_shape([max_detections, 4])
    scores.set_shape([max_detections])
    labels.set_shape([max_detections])
    for padded, original in zip(other_, other):
        # the trailing dimensions are taken from the unfiltered input tensor
        padded.set_shape([max_detections] + list(keras.backend.int_shape(original))[1:])

    return [boxes, scores, labels] + other_


class FilterDetections(keras.layers.Layer):
    """ Keras layer for filtering detections using score threshold and NMS.
    """

    def __init__(
        self,
        nms                   = True,
        class_specific_filter = True,
        nms_threshold         = 0.5,
        score_threshold       = 0.05,
        max_detections        = 300,
        parallel_iterations   = 32,
        **kwargs
    ):
        """ Filters detections using score threshold, NMS and selecting the top-k detections.

        Args
            nms                   : Flag to enable/disable NMS.
            class_specific_filter : Whether to perform filtering per class, or take the best scoring class and filter those.
            nms_threshold         : Threshold for the IoU value to determine when a box should be suppressed.
            score_threshold       : Threshold used to prefilter the boxes with.
            max_detections        : Maximum number of detections to keep.
            parallel_iterations   : Number of batch items to process in parallel.
        """
        self.nms                   = nms
        self.class_specific_filter = class_specific_filter
        self.nms_threshold         = nms_threshold
        self.score_threshold       = score_threshold
        self.max_detections        = max_detections
        self.parallel_iterations   = parallel_iterations
        super(FilterDetections, self).__init__(**kwargs)

    def call(self, inputs, **kwargs):
        """ Constructs the NMS graph.

        Args
            inputs : List of [boxes, classification, other[0], other[1], ...] tensors.

        Returns
            The result of mapping filter_detections over the batch items:
            [filtered_boxes, filtered_scores, filtered_labels, filtered_other[0], ...].
        """
        boxes          = inputs[0]
        classification = inputs[1]
        other          = inputs[2:]

        # wrap nms with our parameters
        def _filter_detections(args):
            boxes          = args[0]
            classification = args[1]
            other          = args[2]

            return filter_detections(
                boxes,
                classification,
                other,
                nms                   = self.nms,
                class_specific_filter = self.class_specific_filter,
                score_threshold       = self.score_threshold,
                max_detections        = self.max_detections,
                nms_threshold         = self.nms_threshold,
            )

        # call filter_detections on each batch
        outputs = backend.map_fn(
            _filter_detections,
            elems=[boxes, classification, other],
            dtype=[keras.backend.floatx(), keras.backend.floatx(), 'int32'] + [o.dtype for o in other],
            parallel_iterations=self.parallel_iterations
        )

        return outputs

    def compute_output_shape(self, input_shape):
        """ Computes the output shapes given the input shapes.

        Args
            input_shape : List of input shapes [boxes, classification, other[0], other[1], ...].

        Returns
            List of tuples representing the output shapes:
            [filtered_boxes.shape, filtered_scores.shape, filtered_labels.shape, filtered_other[0].shape, filtered_other[1].shape, ...]
        """
        return [
            (input_shape[0][0], self.max_detections, 4),
            (input_shape[1][0], self.max_detections),
            (input_shape[1][0], self.max_detections),
        ] + [
            tuple([input_shape[i][0], self.max_detections] + list(input_shape[i][2:])) for i in range(2, len(input_shape))
        ]

    def compute_mask(self, inputs, mask=None):
        """ This is required in Keras when there is more than 1 output.

        Returns one None per output; there is one more output than inputs
        because classification yields both scores and labels.
        """
        return (len(inputs) + 1) * [None]

    def get_config(self):
        """ Gets the configuration of this layer.

        Returns
            Dictionary containing the parameters of this layer.
        """
        config = super(FilterDetections, self).get_config()
        config.update({
            'nms'                   : self.nms,
            'class_specific_filter' : self.class_specific_filter,
            'nms_threshold'         : self.nms_threshold,
            'score_threshold'       : self.score_threshold,
            'max_detections'        : self.max_detections,
            'parallel_iterations'   : self.parallel_iterations,
        })

        return config