├── .codecov.yml
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── Bug_report.md
│   │   └── Feature_request.md
│   └── stale.yml
├── .gitignore
├── .gitmodules
├── .travis.yml
├── CONTRIBUTORS.md
├── LICENSE
├── README.md
├── examples
│   ├── 000000008021.jpg
│   ├── ResNet50RetinaNet.ipynb
│   └── resnet50_retinanet.py
├── images
│   ├── coco1.png
│   ├── coco2.png
│   └── coco3.png
├── keras_retinanet
│   ├── __init__.py
│   ├── backend
│   │   ├── __init__.py
│   │   └── backend.py
│   ├── bin
│   │   ├── __init__.py
│   │   ├── convert_model.py
│   │   ├── debug.py
│   │   ├── evaluate.py
│   │   └── train.py
│   ├── callbacks
│   │   ├── __init__.py
│   │   ├── coco.py
│   │   ├── common.py
│   │   └── eval.py
│   ├── initializers.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── _misc.py
│   │   └── filter_detections.py
│   ├── losses.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── densenet.py
│   │   ├── effnet.py
│   │   ├── mobilenet.py
│   │   ├── resnet.py
│   │   ├── retinanet.py
│   │   ├── senet.py
│   │   └── vgg.py
│   ├── preprocessing
│   │   ├── __init__.py
│   │   ├── coco.py
│   │   ├── csv_generator.py
│   │   ├── generator.py
│   │   ├── kitti.py
│   │   ├── open_images.py
│   │   └── pascal_voc.py
│   └── utils
│       ├── __init__.py
│       ├── anchors.py
│       ├── coco_eval.py
│       ├── colors.py
│       ├── compute_overlap.pyx
│       ├── config.py
│       ├── eval.py
│       ├── gpu.py
│       ├── image.py
│       ├── model.py
│       ├── tf_version.py
│       ├── transform.py
│       └── visualization.py
├── requirements.txt
├── setup.cfg
├── setup.py
├── snapshots
│   └── .gitignore
└── tests
    ├── __init__.py
    ├── backend
    │   ├── __init__.py
    │   └── test_common.py
    ├── bin
    │   └── test_train.py
    ├── layers
    │   ├── __init__.py
    │   ├── test_filter_detections.py
    │   └── test_misc.py
    ├── models
    │   ├── __init__.py
    │   ├── test_densenet.py
    │   └── test_mobilenet.py
    ├── preprocessing
    │   ├── __init__.py
    │   ├── test_csv_generator.py
    │   ├── test_generator.py
    │   └── test_image.py
    ├── requirements.txt
    ├── test_losses.py
    └── utils
        ├── __init__.py
        ├── test_anchors.py
        └── test_transform.py

/.codecov.yml:
--------------------------------------------------------------------------------
1 | #see https://github.com/codecov/support/wiki/Codecov-Yaml
2 | codecov:
3 |   notify:
4 |     require_ci_to_pass: yes
5 | 
6 | coverage:
7 |   precision: 0  # 2 = xx.xx%, 0 = xx%
8 |   round: nearest  # how coverage is rounded: down/up/nearest
9 |   range: 40...100  # custom range of coverage colors from red -> yellow -> green
10 |   status:
11 |     # https://codecov.readme.io/v1.0/docs/commit-status
12 |     project:
13 |       default:
14 |         against: auto
15 |         target: 90%  # specify the target coverage for each commit status
16 |         threshold: 20%  # allow this little decrease on project
17 |         # https://github.com/codecov/support/wiki/Filtering-Branches
18 |         # branches: master
19 |         if_ci_failed: error
20 |     # https://github.com/codecov/support/wiki/Patch-Status
21 |     patch:
22 |       default:
23 |         against: auto
24 |         target: 40%  # specify the target "X%" coverage to hit
25 |         # threshold: 50%  # allow this much decrease on patch
26 |     changes: false
27 | 
28 | parsers:
29 |   gcov:
30 |     branch_detection:
31 |       conditional: true
32 |       loop: true
33 |       macro: false
34 |       method: false
35 |   javascript:
36 |     enable_partials: false
37 | 
38 | comment:
39 |   layout: header, diff
40 |   require_changes: false
41 |   behavior: default  # update if exists else create new
42 |   # branches: *
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/Bug_report.md:
--------------------------------------------------------------------------------
1 | **Please make sure that you follow the steps below when creating an issue.**
2 | Only use GitHub issues for issues with the implementation, not for issues with specific datasets or general questions about functionality.
3 | If your issue is an implementation question, please ask your question on the #keras-retinanet Slack channel instead of filing a GitHub issue.
4 | Directions for joining the [Slack channel](https://github.com/fizyr/keras-retinanet#discussions) can be found in the README.
5 | 
6 | Thank you!
7 | 
8 | **Steps to follow:**
9 | 
10 | 1. Check that you are up-to-date with
11 |    - the master branch of keras-retinanet,
12 |    - the latest version of [Keras](https://github.com/keras-team/keras),
13 |    - the latest version of TensorFlow (see [installation instructions](https://www.tensorflow.org/get_started/os_setup)).
14 | 2. Check that you have read the entire [README.md](https://github.com/fizyr/keras-retinanet/README.md).
15 |    Most notably, the [FAQ](https://github.com/fizyr/keras-retinanet#faq) section covers common issues.
16 | 3. Clearly describe the issue you're having, including the expected behaviour, the actual behaviour,
17 |    and the steps required to trigger it.
18 | 4. Include relevant output from the commands you're executing, including full stack traces where applicable.
19 | 5. Remove this entire message and replace it with your issue.
20 | 
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/Feature_request.md:
--------------------------------------------------------------------------------
1 | **Is your feature request related to a problem? Please describe.**
2 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
3 | 
4 | **Describe the solution you'd like**
5 | A clear and concise description of what you want to happen.
6 | 
7 | **Describe alternatives you've considered**
8 | A clear and concise description of any alternative solutions or features you've considered.
9 | 
10 | **Additional context**
11 | Add any other context or screenshots about the feature request here.
--------------------------------------------------------------------------------
/.github/stale.yml:
--------------------------------------------------------------------------------
1 | # Number of days of inactivity before an issue becomes stale
2 | daysUntilStale: 60
3 | 
4 | # Number of days of inactivity before a stale issue is closed
5 | daysUntilClose: 10
6 | 
7 | # Issues with these labels will never be considered stale
8 | exemptLabels:
9 |   - wontfix
10 |   - feature request
11 |   - enhancement
12 |   - discussion
13 |   - help wanted
14 | 
15 | # Label to use when marking an issue as stale
16 | staleLabel: stale
17 | 
18 | # Limit to only `issues` or `pulls`
19 | only: issues
20 | 
21 | # Comment to post when marking an issue as stale. Set to `false` to disable
22 | markComment: >
23 |   This issue has been automatically marked as stale due to the lack of
24 |   recent activity. It will be closed if no further activity occurs. Thank you
25 |   for your contributions.
26 | 
27 | # Comment to post when closing a stale issue. Set to `false` to disable
28 | closeComment: false
29 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # Distribution / packaging
7 | .Python
8 | /build/
9 | /dist/
10 | /eggs/
11 | /*-eggs/
12 | .eggs/
13 | /sdist/
14 | /wheels/
15 | /*.egg-info/
16 | .installed.cfg
17 | *.egg
18 | 
19 | # Unit test / coverage reports
20 | .coverage
21 | .coverage.*
22 | coverage.xml
23 | *.cover
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "tests/test-data"]
2 | 	path = tests/test-data
3 | 	url = https://github.com/fizyr/keras-retinanet-test-data.git
4 | 
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | 
3 | sudo: required
4 | 
5 | python:
6 |   - '3.6'
7 |   - '3.7'
8 | 
9 | install:
10 |   - pip install -r requirements.txt
11 |   - pip install -r tests/requirements.txt
12 | 
13 | cache: pip
14 | 
15 | script:
16 |   - python setup.py check -m -s
17 |   - python setup.py build_ext --inplace
18 |   - coverage run --source keras_retinanet -m py.test keras_retinanet tests --doctest-modules --forked --flake8
19 | 
20 | after_success:
21 |   - coverage xml
22 |   - coverage report
23 |   - codecov
--------------------------------------------------------------------------------
/CONTRIBUTORS.md:
--------------------------------------------------------------------------------
1 | # Contributors
2 | 
3 | This is a list of people who contributed patches to keras-retinanet.
4 | 5 | If you feel you should be listed here or if you have any other questions/comments on your listing here, 6 | please create an issue or pull request at https://github.com/fizyr/keras-retinanet/ 7 | 8 | * Hans Gaiser 9 | * Maarten de Vries 10 | * Valerio Carpani 11 | * Ashley Williamson 12 | * Yann Henon 13 | * Valeriu Lacatusu 14 | * András Vidosits 15 | * Cristian Gratie 16 | * jjiunlin 17 | * Sorin Panduru 18 | * Rodrigo Meira de Andrade 19 | * Enrico Liscio 20 | * Mihai Morariu 21 | * pedroconceicao 22 | * jjiun 23 | * Wudi Fang 24 | * Mike Clark 25 | * hannesedvartsen 26 | * Max Van Sande 27 | * Pierre Dérian 28 | * ori 29 | * mxvs 30 | * mwilder 31 | * Muhammed Kocabas 32 | * Koen Vijverberg 33 | * iver56 34 | * hnsywangxin 35 | * Guillaume Erhard 36 | * Eduardo Ramos 37 | * DiegoAgher 38 | * Alexander Pacha 39 | * Agastya Kalra 40 | * Jiri BOROVEC 41 | * ntsagko 42 | * charlie / tianqi 43 | * jsemric 44 | * Martin Zlocha 45 | * Raghav Bhardwaj 46 | * bw4sz 47 | * Morten Back Nielsen 48 | * dshahrokhian 49 | * Alex / adreo00 50 | * simone.merello 51 | * Matt Wilder 52 | * Jinwoo Baek 53 | * Etienne Meunier 54 | * Denis Dowling 55 | * cclauss 56 | * Andrew Grigorev 57 | * ZFTurbo 58 | * UgoLouche 59 | * Richard Higgins 60 | * Rajat / rajat.goel 61 | * philipp.marquardt 62 | * peacherwu 63 | * Paul / pauldesigaud 64 | * Martin Genet 65 | * Leo / leonardvandriel 66 | * Laurens Hagendoorn 67 | * Julius / juliussimonelli 68 | * HolyGuacamole 69 | * Fausto Morales 70 | * borakrc 71 | * Ben Weinstein 72 | * Anil Karaka 73 | * Andrea Panizza 74 | * Bruno Santos -------------------------------------------------------------------------------- /examples/000000008021.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fizyr/keras-retinanet/7ac91dfbbacce77d6d9633fc09e16cd0ee71fd5e/examples/000000008021.jpg -------------------------------------------------------------------------------- /examples/resnet50_retinanet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # Load necessary modules 5 | 6 | import sys 7 | sys.path.insert(0, '../') 8 | 9 | 10 | # import keras_retinanet 11 | from keras_retinanet import models 12 | from keras_retinanet.utils.image import read_image_bgr, preprocess_image, resize_image 13 | from keras_retinanet.utils.visualization import draw_box, draw_caption 14 | from keras_retinanet.utils.colors import label_color 15 | from keras_retinanet.utils.gpu import setup_gpu 16 | 17 | # import miscellaneous modules 18 | import matplotlib.pyplot as plt 19 | import cv2 20 | import os 21 | import numpy as np 22 | import time 23 | 24 | # set tf backend to allow memory to grow, instead of claiming everything 25 | import tensorflow as tf 26 | 27 | # use this to change which GPU to use 28 | gpu = 0 29 | 30 | # set the modified tf session as backend in keras 31 | setup_gpu(gpu) 32 | 33 | 34 | # ## Load RetinaNet model 35 | 36 | # In[ ]: 37 | 38 | 39 | # adjust this to point to your downloaded/trained model 40 | # models can be downloaded here: https://github.com/fizyr/keras-retinanet/releases 41 | model_path = os.path.join('..', 'snapshots', 'resnet50_coco_best_v2.1.0.h5') 42 | 43 | # load retinanet model 44 | model = models.load_model(model_path, backbone_name='resnet50') 45 | 46 | # if the model is not converted to an inference model, use the line below 47 | # see: 
https://github.com/fizyr/keras-retinanet#converting-a-training-model-to-inference-model 48 | # model = models.convert_model(model) 49 | 50 | #print(model.summary()) 51 | 52 | # load label to names mapping for visualization purposes 53 | labels_to_names = {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'} 54 | 55 | 56 | # ## Run detection on example 57 | 58 | # In[ ]: 59 | 60 | 61 | # load image 62 | image = read_image_bgr('000000008021.jpg') 63 | 64 | # copy to draw on 65 | draw = image.copy() 66 | draw = cv2.cvtColor(draw, cv2.COLOR_BGR2RGB) 67 | 68 | # preprocess image for network 69 | image = preprocess_image(image) 70 | image, scale = resize_image(image) 71 | 72 | # process image 73 | start = time.time() 74 | boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0)) 75 | print("processing time: ", time.time() - start) 76 | 77 | # correct for image scale 78 | boxes /= scale 79 | 80 | # visualize detections 81 | for box, score, label in zip(boxes[0], scores[0], labels[0]): 82 | # scores are sorted so we can break 83 | if score < 0.5: 84 | break 85 | 86 | color = label_color(label) 87 | 88 | b = box.astype(int) 89 | draw_box(draw, b, color=color) 90 | 91 | caption = "{} {:.3f}".format(labels_to_names[label], score) 92 | draw_caption(draw, b, caption) 93 | 94 | plt.figure(figsize=(15, 15)) 95 | plt.axis('off') 96 | plt.imshow(draw) 97 | plt.show() 98 | 99 | 100 | # In[ ]: 101 | 102 | 103 | 104 | 105 | 106 | # In[ ]: 107 | -------------------------------------------------------------------------------- /images/coco1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fizyr/keras-retinanet/7ac91dfbbacce77d6d9633fc09e16cd0ee71fd5e/images/coco1.png -------------------------------------------------------------------------------- /images/coco2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fizyr/keras-retinanet/7ac91dfbbacce77d6d9633fc09e16cd0ee71fd5e/images/coco2.png -------------------------------------------------------------------------------- /images/coco3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fizyr/keras-retinanet/7ac91dfbbacce77d6d9633fc09e16cd0ee71fd5e/images/coco3.png -------------------------------------------------------------------------------- 
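The example script above (examples/resnet50_retinanet.py) walks through single-image inference step by step: read the image, preprocess and resize it, run `predict_on_batch`, then divide the boxes by the resize scale so they land back in original-image coordinates. For reuse, a minimal sketch of that same flow wrapped into a helper is shown below; it assumes an already-converted inference model, and the `detect` name and 0.5 score cutoff are illustrative choices, not part of the repository:

```python
# Minimal sketch reusing the utilities from examples/resnet50_retinanet.py.
# Assumes `model` is a converted inference RetinaNet model; the helper name
# and the default score cutoff are illustrative.
import numpy as np

from keras_retinanet.utils.image import read_image_bgr, preprocess_image, resize_image


def detect(model, image_path, score_threshold=0.5):
    """Run one image through the network; return boxes in original image coordinates."""
    image = read_image_bgr(image_path)
    image = preprocess_image(image)
    image, scale = resize_image(image)

    boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))
    boxes /= scale  # undo the resize so boxes match the original image

    keep = scores[0] > score_threshold  # scores come back sorted in descending order
    return boxes[0][keep], scores[0][keep], labels[0][keep]
```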
/keras_retinanet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fizyr/keras-retinanet/7ac91dfbbacce77d6d9633fc09e16cd0ee71fd5e/keras_retinanet/__init__.py -------------------------------------------------------------------------------- /keras_retinanet/backend/__init__.py: -------------------------------------------------------------------------------- 1 | from .backend import * # noqa: F401,F403 2 | -------------------------------------------------------------------------------- /keras_retinanet/backend/backend.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import tensorflow 18 | from tensorflow import keras 19 | 20 | 21 | def bbox_transform_inv(boxes, deltas, mean=None, std=None): 22 | """ Applies deltas (usually regression results) to boxes (usually anchors). 23 | 24 | Before applying the deltas to the boxes, the normalization that was previously applied (in the generator) has to be removed. 25 | The mean and std are the mean and std as applied in the generator. They are unnormalized in this function and then applied to the boxes. 26 | 27 | Args 28 | boxes : np.array of shape (B, N, 4), where B is the batch size, N the number of boxes and 4 values for (x1, y1, x2, y2). 29 | deltas: np.array of same shape as boxes. These deltas (d_x1, d_y1, d_x2, d_y2) are a factor of the width/height. 30 | mean : The mean value used when computing deltas (defaults to [0, 0, 0, 0]). 31 | std : The standard deviation used when computing deltas (defaults to [0.2, 0.2, 0.2, 0.2]). 32 | 33 | Returns 34 | A np.array of the same shape as boxes, but with deltas applied to each box. 35 | The mean and std are used during training to normalize the regression values (networks love normalization). 36 | """ 37 | if mean is None: 38 | mean = [0, 0, 0, 0] 39 | if std is None: 40 | std = [0.2, 0.2, 0.2, 0.2] 41 | 42 | width = boxes[:, :, 2] - boxes[:, :, 0] 43 | height = boxes[:, :, 3] - boxes[:, :, 1] 44 | 45 | x1 = boxes[:, :, 0] + (deltas[:, :, 0] * std[0] + mean[0]) * width 46 | y1 = boxes[:, :, 1] + (deltas[:, :, 1] * std[1] + mean[1]) * height 47 | x2 = boxes[:, :, 2] + (deltas[:, :, 2] * std[2] + mean[2]) * width 48 | y2 = boxes[:, :, 3] + (deltas[:, :, 3] * std[3] + mean[3]) * height 49 | 50 | pred_boxes = keras.backend.stack([x1, y1, x2, y2], axis=2) 51 | 52 | return pred_boxes 53 | 54 | 55 | def shift(shape, stride, anchors): 56 | """ Produce shifted anchors based on shape of the map and stride size. 57 | 58 | Args 59 | shape : Shape to shift the anchors over. 60 | stride : Stride to shift the anchors with over the shape. 61 | anchors: The anchors to apply at each location. 
62 | """ 63 | shift_x = (keras.backend.arange(0, shape[1], dtype=keras.backend.floatx()) + keras.backend.constant(0.5, dtype=keras.backend.floatx())) * stride 64 | shift_y = (keras.backend.arange(0, shape[0], dtype=keras.backend.floatx()) + keras.backend.constant(0.5, dtype=keras.backend.floatx())) * stride 65 | 66 | shift_x, shift_y = tensorflow.meshgrid(shift_x, shift_y) 67 | shift_x = keras.backend.reshape(shift_x, [-1]) 68 | shift_y = keras.backend.reshape(shift_y, [-1]) 69 | 70 | shifts = keras.backend.stack([ 71 | shift_x, 72 | shift_y, 73 | shift_x, 74 | shift_y 75 | ], axis=0) 76 | 77 | shifts = keras.backend.transpose(shifts) 78 | number_of_anchors = keras.backend.shape(anchors)[0] 79 | 80 | k = keras.backend.shape(shifts)[0] # number of base points = feat_h * feat_w 81 | 82 | shifted_anchors = keras.backend.reshape(anchors, [1, number_of_anchors, 4]) + keras.backend.cast(keras.backend.reshape(shifts, [k, 1, 4]), keras.backend.floatx()) 83 | shifted_anchors = keras.backend.reshape(shifted_anchors, [k * number_of_anchors, 4]) 84 | 85 | return shifted_anchors 86 | 87 | 88 | def map_fn(*args, **kwargs): 89 | """ See https://www.tensorflow.org/api_docs/python/tf/map_fn . 90 | """ 91 | 92 | if "shapes" in kwargs: 93 | shapes = kwargs.pop("shapes") 94 | dtype = kwargs.pop("dtype") 95 | sig = [tensorflow.TensorSpec(shapes[i], dtype=t) for i, t in 96 | enumerate(dtype)] 97 | 98 | # Try to use the new feature fn_output_signature in TF 2.3, use fallback if this is not available 99 | try: 100 | return tensorflow.map_fn(*args, **kwargs, fn_output_signature=sig) 101 | except TypeError: 102 | kwargs["dtype"] = dtype 103 | 104 | return tensorflow.map_fn(*args, **kwargs) 105 | 106 | 107 | def resize_images(images, size, method='bilinear', align_corners=False): 108 | """ See https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/image/resize_images . 109 | 110 | Args 111 | method: The method used for interpolation. One of ('bilinear', 'nearest', 'bicubic', 'area'). 112 | """ 113 | methods = { 114 | 'bilinear': tensorflow.image.ResizeMethod.BILINEAR, 115 | 'nearest' : tensorflow.image.ResizeMethod.NEAREST_NEIGHBOR, 116 | 'bicubic' : tensorflow.image.ResizeMethod.BICUBIC, 117 | 'area' : tensorflow.image.ResizeMethod.AREA, 118 | } 119 | return tensorflow.compat.v1.image.resize_images(images, size, methods[method], align_corners) 120 | -------------------------------------------------------------------------------- /keras_retinanet/bin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fizyr/keras-retinanet/7ac91dfbbacce77d6d9633fc09e16cd0ee71fd5e/keras_retinanet/bin/__init__.py -------------------------------------------------------------------------------- /keras_retinanet/bin/convert_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Copyright 2017-2018 Fizyr (https://fizyr.com) 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | """ 18 | 19 | import argparse 20 | import os 21 | import sys 22 | 23 | # Allow relative imports when being executed as script. 24 | if __name__ == "__main__" and __package__ is None: 25 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) 26 | import keras_retinanet.bin # noqa: F401 27 | __package__ = "keras_retinanet.bin" 28 | 29 | # Change these to absolute imports if you copy this script outside the keras_retinanet package. 30 | from .. import models 31 | from ..utils.config import read_config_file, parse_anchor_parameters, parse_pyramid_levels 32 | from ..utils.gpu import setup_gpu 33 | from ..utils.tf_version import check_tf_version 34 | 35 | 36 | def parse_args(args): 37 | parser = argparse.ArgumentParser(description='Script for converting a training model to an inference model.') 38 | 39 | parser.add_argument('model_in', help='The model to convert.') 40 | parser.add_argument('model_out', help='Path to save the converted model to.') 41 | parser.add_argument('--backbone', help='The backbone of the model to convert.', default='resnet50') 42 | parser.add_argument('--no-nms', help='Disables non maximum suppression.', dest='nms', action='store_false') 43 | parser.add_argument('--no-class-specific-filter', help='Disables class specific filtering.', dest='class_specific_filter', action='store_false') 44 | parser.add_argument('--config', help='Path to a configuration parameters .ini file.') 45 | parser.add_argument('--nms-threshold', help='Value for non maximum suppression threshold.', type=float, default=0.5) 46 | parser.add_argument('--score-threshold', help='Threshold for prefiltering boxes.', type=float, default=0.05) 47 | parser.add_argument('--max-detections', help='Maximum number of detections to keep.', type=int, default=300) 48 | parser.add_argument('--parallel-iterations', help='Number of batch items to process in parallel.', type=int, default=32) 49 | 50 | return parser.parse_args(args) 51 | 52 | 53 | def main(args=None): 54 | # parse arguments 55 | if args is None: 56 | args = sys.argv[1:] 57 | args = parse_args(args) 58 | 59 | # make sure tensorflow is the minimum required version 60 | check_tf_version() 61 | 62 | # set modified tf session to avoid using the GPUs 63 | setup_gpu('cpu') 64 | 65 | # optionally load config parameters 66 | anchor_parameters = None 67 | pyramid_levels = None 68 | if args.config: 69 | args.config = read_config_file(args.config) 70 | if 'anchor_parameters' in args.config: 71 | anchor_parameters = parse_anchor_parameters(args.config) 72 | 73 | if 'pyramid_levels' in args.config: 74 | pyramid_levels = parse_pyramid_levels(args.config) 75 | 76 | # load the model 77 | model = models.load_model(args.model_in, backbone_name=args.backbone) 78 | 79 | # check if this is indeed a training model 80 | models.check_training_model(model) 81 | 82 | # convert the model 83 | model = models.convert_model( 84 | model, 85 | nms=args.nms, 86 | class_specific_filter=args.class_specific_filter, 87 | anchor_params=anchor_parameters, 88 | pyramid_levels=pyramid_levels, 89 | nms_threshold=args.nms_threshold, 90 | score_threshold=args.score_threshold, 91 | max_detections=args.max_detections, 92 | parallel_iterations=args.parallel_iterations 93 | ) 94 | 95 | # save model 96 | model.save(args.model_out) 97 | 98 | 99 | if __name__ == '__main__': 100 | main() 101 | -------------------------------------------------------------------------------- 
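convert_model.py above is a thin CLI around `models.load_model`, `models.check_training_model` and `models.convert_model` (all defined later in this package). For reference, a sketch of the same conversion done programmatically; the snapshot paths are placeholders:

```python
# Sketch: the programmatic equivalent of running convert_model.py.
# The snapshot paths are placeholders; only functions from this package are used.
from keras_retinanet import models

model = models.load_model('snapshots/resnet50_csv_50.h5', backbone_name='resnet50')
models.check_training_model(model)  # exits with an error if this is not a training model

model = models.convert_model(model, nms=True, class_specific_filter=True)
model.save('snapshots/resnet50_csv_50_inference.h5')
```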
/keras_retinanet/bin/evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Copyright 2017-2018 Fizyr (https://fizyr.com) 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import argparse 18 | import os 19 | import sys 20 | 21 | # Allow relative imports when being executed as script. 22 | if __name__ == "__main__" and __package__ is None: 23 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) 24 | import keras_retinanet.bin # noqa: F401 25 | __package__ = "keras_retinanet.bin" 26 | 27 | # Change these to absolute imports if you copy this script outside the keras_retinanet package. 28 | from .. import models 29 | from ..preprocessing.csv_generator import CSVGenerator 30 | from ..preprocessing.pascal_voc import PascalVocGenerator 31 | from ..preprocessing.kitti import KittiGenerator 32 | from ..utils.anchors import make_shapes_callback 33 | from ..utils.config import read_config_file, parse_anchor_parameters, parse_pyramid_levels 34 | from ..utils.eval import evaluate 35 | from ..utils.gpu import setup_gpu 36 | from ..utils.tf_version import check_tf_version 37 | 38 | 39 | def create_generator(args, preprocess_image): 40 | """ Create generators for evaluation. 41 | """ 42 | common_args = { 43 | 'config' : args.config, 44 | 'image_min_side' : args.image_min_side, 45 | 'image_max_side' : args.image_max_side, 46 | 'no_resize' : args.no_resize, 47 | 'preprocess_image' : preprocess_image, 48 | 'group_method' : args.group_method 49 | } 50 | 51 | if args.dataset_type == 'coco': 52 | # import here to prevent unnecessary dependency on cocoapi 53 | from ..preprocessing.coco import CocoGenerator 54 | 55 | validation_generator = CocoGenerator( 56 | args.coco_path, 57 | 'val2017', 58 | shuffle_groups=False, 59 | **common_args 60 | ) 61 | elif args.dataset_type == 'pascal': 62 | validation_generator = PascalVocGenerator( 63 | args.pascal_path, 64 | 'test', 65 | image_extension=args.image_extension, 66 | shuffle_groups=False, 67 | **common_args 68 | ) 69 | elif args.dataset_type == 'csv': 70 | validation_generator = CSVGenerator( 71 | args.annotations, 72 | args.classes, 73 | shuffle_groups=False, 74 | **common_args 75 | ) 76 | elif args.dataset_type == 'kitti': 77 | validation_generator = KittiGenerator( 78 | args.kitti_path, 79 | 'val', 80 | shuffle_groups=False, 81 | **common_args 82 | ) 83 | else: 84 | raise ValueError('Invalid data type received: {}'.format(args.dataset_type)) 85 | 86 | return validation_generator 87 | 88 | 89 | def parse_args(args): 90 | """ Parse the arguments. 91 | """ 92 | parser = argparse.ArgumentParser(description='Evaluation script for a RetinaNet network.') 93 | subparsers = parser.add_subparsers(help='Arguments for specific dataset types.', dest='dataset_type') 94 | subparsers.required = True 95 | 96 | coco_parser = subparsers.add_parser('coco') 97 | coco_parser.add_argument('coco_path', help='Path to dataset directory (ie. 
/tmp/COCO).')
98 | 
99 |     pascal_parser = subparsers.add_parser('pascal')
100 |     pascal_parser.add_argument('pascal_path', help='Path to dataset directory (ie. /tmp/VOCdevkit).')
101 |     pascal_parser.add_argument('--image-extension', help='Declares the dataset images\' extension.', default='.jpg')
102 | 
103 |     csv_parser = subparsers.add_parser('csv')
104 |     csv_parser.add_argument('annotations', help='Path to CSV file containing annotations for evaluation.')
105 |     csv_parser.add_argument('classes', help='Path to a CSV file containing class label mapping.')
106 | 
107 |     kitti_parser = subparsers.add_parser('kitti')
108 |     kitti_parser.add_argument('--kitti_path', help='Path to dataset directory.')
109 | 
110 |     parser.add_argument('model', help='Path to RetinaNet model.')
111 |     parser.add_argument('--convert-model', help='Convert the model to an inference model (ie. the input is a training model).', action='store_true')
112 |     parser.add_argument('--backbone', help='The backbone of the model.', default='resnet50')
113 |     parser.add_argument('--gpu', help='Id of the GPU to use (as reported by nvidia-smi).')
114 |     parser.add_argument('--score-threshold', help='Threshold on score to filter detections with (defaults to 0.05).', default=0.05, type=float)
115 |     parser.add_argument('--iou-threshold', help='IoU threshold to count for a positive detection (defaults to 0.5).', default=0.5, type=float)
116 |     parser.add_argument('--max-detections', help='Max detections per image (defaults to 100).', default=100, type=int)
117 |     parser.add_argument('--save-path', help='Path for saving images with detections (doesn\'t work for COCO).')
118 |     parser.add_argument('--image-min-side', help='Rescale the image so the smallest side is min_side.', type=int, default=800)
119 |     parser.add_argument('--image-max-side', help='Rescale the image if the largest side is larger than max_side.', type=int, default=1333)
120 |     parser.add_argument('--no-resize', help='Don\'t rescale the image.', action='store_true')
121 |     parser.add_argument('--config', help='Path to a configuration parameters .ini file (only used with --convert-model).')
122 |     parser.add_argument('--group-method', help='Determines how images are grouped together.', type=str, default='ratio', choices=['none', 'random', 'ratio'])
123 | 
124 |     return parser.parse_args(args)
125 | 
126 | 
127 | def main(args=None):
128 |     # parse arguments
129 |     if args is None:
130 |         args = sys.argv[1:]
131 |     args = parse_args(args)
132 | 
133 |     # make sure tensorflow is the minimum required version
134 |     check_tf_version()
135 | 
136 |     # optionally choose specific GPU
137 |     if args.gpu:
138 |         setup_gpu(args.gpu)
139 | 
140 |     # make save path if it doesn't exist
141 |     if args.save_path is not None and not os.path.exists(args.save_path):
142 |         os.makedirs(args.save_path)
143 | 
144 |     # optionally load config parameters
145 |     if args.config:
146 |         args.config = read_config_file(args.config)
147 | 
148 |     # create the generator
149 |     backbone = models.backbone(args.backbone)
150 |     generator = create_generator(args, backbone.preprocess_image)
151 | 
152 |     # optionally load anchor parameters
153 |     anchor_params = None
154 |     pyramid_levels = None
155 |     if args.config and 'anchor_parameters' in args.config:
156 |         anchor_params = parse_anchor_parameters(args.config)
157 |     if args.config and 'pyramid_levels' in args.config:
158 |         pyramid_levels = parse_pyramid_levels(args.config)
159 | 
160 |     # load the model
161 |     print('Loading model, this may take a second...')
162 |     model = models.load_model(args.model, backbone_name=args.backbone)
163 |     generator.compute_shapes = make_shapes_callback(model)
164 | 
165 |     # optionally convert the model
166 |     if args.convert_model:
167 |         model = models.convert_model(model, anchor_params=anchor_params, pyramid_levels=pyramid_levels)
168 | 
169 |     # print model summary
170 |     # print(model.summary())
171 | 
172 |     # start evaluation
173 |     if args.dataset_type == 'coco':
174 |         from ..utils.coco_eval import evaluate_coco
175 |         evaluate_coco(generator, model, args.score_threshold)
176 |     else:
177 |         average_precisions, inference_time = evaluate(
178 |             generator,
179 |             model,
180 |             iou_threshold=args.iou_threshold,
181 |             score_threshold=args.score_threshold,
182 |             max_detections=args.max_detections,
183 |             save_path=args.save_path
184 |         )
185 | 
186 |         # print evaluation
187 |         total_instances = []
188 |         precisions = []
189 |         for label, (average_precision, num_annotations) in average_precisions.items():
190 |             print('{:.0f} instances of class'.format(num_annotations),
191 |                   generator.label_to_name(label), 'with average precision: {:.4f}'.format(average_precision))
192 |             total_instances.append(num_annotations)
193 |             precisions.append(average_precision)
194 | 
195 |         if sum(total_instances) == 0:
196 |             print('No test instances found.')
197 |             return
198 | 
199 |         print('Inference time for {:.0f} images: {:.4f}'.format(generator.size(), inference_time))
200 | 
201 |         print('mAP using the weighted average of precisions among classes: {:.4f}'.format(sum([a * b for a, b in zip(total_instances, precisions)]) / sum(total_instances)))
202 |         print('mAP: {:.4f}'.format(sum(precisions) / sum(x > 0 for x in total_instances)))
203 | 
204 | 
205 | if __name__ == '__main__':
206 |     main()
207 | 
--------------------------------------------------------------------------------
/keras_retinanet/callbacks/__init__.py:
--------------------------------------------------------------------------------
1 | from .common import * # noqa: F401,F403
2 | 
--------------------------------------------------------------------------------
/keras_retinanet/callbacks/coco.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 Fizyr (https://fizyr.com)
3 | 
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 | 
8 |     http://www.apache.org/licenses/LICENSE-2.0
9 | 
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 | 
17 | from tensorflow import keras
18 | from ..utils.coco_eval import evaluate_coco
19 | 
20 | 
21 | class CocoEval(keras.callbacks.Callback):
22 |     """ Performs COCO evaluation on each epoch.
23 |     """
24 |     def __init__(self, generator, tensorboard=None, threshold=0.05):
25 |         """ CocoEval callback initializer.
26 | 
27 |         Args
28 |             generator   : The generator used for creating validation data.
29 |             tensorboard : If given, the results will be written to tensorboard.
30 |             threshold   : The score threshold to use.
31 | """ 32 | self.generator = generator 33 | self.threshold = threshold 34 | self.tensorboard = tensorboard 35 | 36 | super(CocoEval, self).__init__() 37 | 38 | def on_epoch_end(self, epoch, logs=None): 39 | logs = logs or {} 40 | 41 | coco_tag = ['AP @[ IoU=0.50:0.95 | area= all | maxDets=100 ]', 42 | 'AP @[ IoU=0.50 | area= all | maxDets=100 ]', 43 | 'AP @[ IoU=0.75 | area= all | maxDets=100 ]', 44 | 'AP @[ IoU=0.50:0.95 | area= small | maxDets=100 ]', 45 | 'AP @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]', 46 | 'AP @[ IoU=0.50:0.95 | area= large | maxDets=100 ]', 47 | 'AR @[ IoU=0.50:0.95 | area= all | maxDets= 1 ]', 48 | 'AR @[ IoU=0.50:0.95 | area= all | maxDets= 10 ]', 49 | 'AR @[ IoU=0.50:0.95 | area= all | maxDets=100 ]', 50 | 'AR @[ IoU=0.50:0.95 | area= small | maxDets=100 ]', 51 | 'AR @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]', 52 | 'AR @[ IoU=0.50:0.95 | area= large | maxDets=100 ]'] 53 | coco_eval_stats = evaluate_coco(self.generator, self.model, self.threshold) 54 | 55 | if coco_eval_stats is not None: 56 | for index, result in enumerate(coco_eval_stats): 57 | logs[coco_tag[index]] = result 58 | 59 | if self.tensorboard: 60 | import tensorflow as tf 61 | writer = tf.summary.create_file_writer(self.tensorboard.log_dir) 62 | with writer.as_default(): 63 | for index, result in enumerate(coco_eval_stats): 64 | tf.summary.scalar('{}. {}'.format(index + 1, coco_tag[index]), result, step=epoch) 65 | writer.flush() 66 | -------------------------------------------------------------------------------- /keras_retinanet/callbacks/common.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | 3 | 4 | class RedirectModel(keras.callbacks.Callback): 5 | """Callback which wraps another callback, but executed on a different model. 6 | 7 | ```python 8 | model = keras.models.load_model('model.h5') 9 | model_checkpoint = ModelCheckpoint(filepath='snapshot.h5') 10 | parallel_model = multi_gpu_model(model, gpus=2) 11 | parallel_model.fit(X_train, Y_train, callbacks=[RedirectModel(model_checkpoint, model)]) 12 | ``` 13 | 14 | Args 15 | callback : callback to wrap. 16 | model : model to use when executing callbacks. 17 | """ 18 | 19 | def __init__(self, 20 | callback, 21 | model): 22 | super(RedirectModel, self).__init__() 23 | 24 | self.callback = callback 25 | self.redirect_model = model 26 | 27 | def on_epoch_begin(self, epoch, logs=None): 28 | self.callback.on_epoch_begin(epoch, logs=logs) 29 | 30 | def on_epoch_end(self, epoch, logs=None): 31 | self.callback.on_epoch_end(epoch, logs=logs) 32 | 33 | def on_batch_begin(self, batch, logs=None): 34 | self.callback.on_batch_begin(batch, logs=logs) 35 | 36 | def on_batch_end(self, batch, logs=None): 37 | self.callback.on_batch_end(batch, logs=logs) 38 | 39 | def on_train_begin(self, logs=None): 40 | # overwrite the model with our custom model 41 | self.callback.set_model(self.redirect_model) 42 | 43 | self.callback.on_train_begin(logs=logs) 44 | 45 | def on_train_end(self, logs=None): 46 | self.callback.on_train_end(logs=logs) 47 | -------------------------------------------------------------------------------- /keras_retinanet/callbacks/eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from tensorflow import keras 18 | from ..utils.eval import evaluate 19 | 20 | 21 | class Evaluate(keras.callbacks.Callback): 22 | """ Evaluation callback for arbitrary datasets. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | generator, 28 | iou_threshold=0.5, 29 | score_threshold=0.05, 30 | max_detections=100, 31 | save_path=None, 32 | tensorboard=None, 33 | weighted_average=False, 34 | verbose=1 35 | ): 36 | """ Evaluate a given dataset using a given model at the end of every epoch during training. 37 | 38 | # Arguments 39 | generator : The generator that represents the dataset to evaluate. 40 | iou_threshold : The threshold used to consider when a detection is positive or negative. 41 | score_threshold : The score confidence threshold to use for detections. 42 | max_detections : The maximum number of detections to use per image. 43 | save_path : The path to save images with visualized detections to. 44 | tensorboard : Instance of keras.callbacks.TensorBoard used to log the mAP value. 45 | weighted_average : Compute the mAP using the weighted average of precisions among classes. 46 | verbose : Set the verbosity level, by default this is set to 1. 47 | """ 48 | self.generator = generator 49 | self.iou_threshold = iou_threshold 50 | self.score_threshold = score_threshold 51 | self.max_detections = max_detections 52 | self.save_path = save_path 53 | self.tensorboard = tensorboard 54 | self.weighted_average = weighted_average 55 | self.verbose = verbose 56 | 57 | super(Evaluate, self).__init__() 58 | 59 | def on_epoch_end(self, epoch, logs=None): 60 | logs = logs or {} 61 | 62 | # run evaluation 63 | average_precisions, _ = evaluate( 64 | self.generator, 65 | self.model, 66 | iou_threshold=self.iou_threshold, 67 | score_threshold=self.score_threshold, 68 | max_detections=self.max_detections, 69 | save_path=self.save_path 70 | ) 71 | 72 | # compute per class average precision 73 | total_instances = [] 74 | precisions = [] 75 | for label, (average_precision, num_annotations) in average_precisions.items(): 76 | if self.verbose == 1: 77 | print('{:.0f} instances of class'.format(num_annotations), 78 | self.generator.label_to_name(label), 'with average precision: {:.4f}'.format(average_precision)) 79 | total_instances.append(num_annotations) 80 | precisions.append(average_precision) 81 | if self.weighted_average: 82 | self.mean_ap = sum([a * b for a, b in zip(total_instances, precisions)]) / sum(total_instances) 83 | else: 84 | self.mean_ap = sum(precisions) / sum(x > 0 for x in total_instances) 85 | 86 | if self.tensorboard: 87 | import tensorflow as tf 88 | writer = tf.summary.create_file_writer(self.tensorboard.log_dir) 89 | with writer.as_default(): 90 | tf.summary.scalar("mAP", self.mean_ap, step=epoch) 91 | if self.verbose == 1: 92 | for label, (average_precision, num_annotations) in average_precisions.items(): 93 | tf.summary.scalar("AP_" + self.generator.label_to_name(label), average_precision, step=epoch) 94 | writer.flush() 95 | 96 | logs['mAP'] = self.mean_ap 97 | 98 | if self.verbose == 1: 99 | print('mAP: {:.4f}'.format(self.mean_ap)) 100 | 
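The Evaluate callback above recomputes mAP at the end of every epoch and writes it to the logs (optionally to TensorBoard as well). A sketch of how it is typically wired into a training run follows; `train_generator`, `validation_generator`, `training_model` and `prediction_model` are assumed to exist already (e.g. created as in bin/train.py), and RedirectModel from callbacks/common.py points the callback at the prediction model:

```python
# Sketch: attaching the Evaluate callback to a training run.
# train_generator, validation_generator, training_model and prediction_model
# are assumed to have been created elsewhere.
from keras_retinanet.callbacks import RedirectModel
from keras_retinanet.callbacks.eval import Evaluate

evaluation = Evaluate(validation_generator, weighted_average=True)
# the evaluation runs predictions, so redirect the callback to the prediction model
evaluation = RedirectModel(evaluation, prediction_model)

training_model.fit(
    train_generator,
    epochs=50,
    callbacks=[evaluation],
)
```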
--------------------------------------------------------------------------------
/keras_retinanet/initializers.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 Fizyr (https://fizyr.com)
3 | 
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 | 
8 |     http://www.apache.org/licenses/LICENSE-2.0
9 | 
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 | 
17 | from tensorflow import keras
18 | 
19 | import math
20 | 
21 | 
22 | class PriorProbability(keras.initializers.Initializer):
23 |     """ Apply a prior probability to the weights.
24 |     """
25 | 
26 |     def __init__(self, probability=0.01):
27 |         self.probability = probability
28 | 
29 |     def get_config(self):
30 |         return {
31 |             'probability': self.probability
32 |         }
33 | 
34 |     def __call__(self, shape, dtype=None):
35 |         # set bias to -log((1 - p)/p) for foreground
36 |         result = keras.backend.ones(shape, dtype=dtype) * -math.log((1 - self.probability) / self.probability)
37 | 
38 |         return result
39 | 
--------------------------------------------------------------------------------
/keras_retinanet/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from ._misc import RegressBoxes, UpsampleLike, Anchors, ClipBoxes # noqa: F401
2 | from .filter_detections import FilterDetections # noqa: F401
3 | 
--------------------------------------------------------------------------------
/keras_retinanet/layers/_misc.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 Fizyr (https://fizyr.com)
3 | 
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 | 
8 |     http://www.apache.org/licenses/LICENSE-2.0
9 | 
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 | 
17 | import tensorflow
18 | from tensorflow import keras
19 | from .. import backend
20 | from ..utils import anchors as utils_anchors
21 | 
22 | import numpy as np
23 | 
24 | 
25 | class Anchors(keras.layers.Layer):
26 |     """ Keras layer for generating anchors for a given shape.
27 |     """
28 | 
29 |     def __init__(self, size, stride, ratios=None, scales=None, *args, **kwargs):
30 |         """ Initializer for an Anchors layer.
31 | 
32 |         Args
33 |             size: The base size of the anchors to generate.
34 |             stride: The stride of the anchors to generate.
35 |             ratios: The ratios of the anchors to generate (defaults to AnchorParameters.default.ratios).
36 |             scales: The scales of the anchors to generate (defaults to AnchorParameters.default.scales).
37 |         """
38 |         self.size = size
39 |         self.stride = stride
40 |         self.ratios = ratios
41 |         self.scales = scales
42 | 
43 |         if ratios is None:
44 |             self.ratios = utils_anchors.AnchorParameters.default.ratios
45 |         elif isinstance(ratios, list):
46 |             self.ratios = np.array(ratios)
47 |         if scales is None:
48 |             self.scales = utils_anchors.AnchorParameters.default.scales
49 |         elif isinstance(scales, list):
50 |             self.scales = np.array(scales)
51 | 
52 |         self.num_anchors = len(self.ratios) * len(self.scales)
53 |         self.anchors = utils_anchors.generate_anchors(
54 |             base_size=self.size,
55 |             ratios=self.ratios,
56 |             scales=self.scales,
57 |         ).astype(np.float32)
58 | 
59 |         super(Anchors, self).__init__(*args, **kwargs)
60 | 
61 |     def call(self, inputs, **kwargs):
62 |         features = inputs
63 |         features_shape = keras.backend.shape(features)
64 | 
65 |         # generate proposals from bbox deltas and shifted anchors
66 |         if keras.backend.image_data_format() == 'channels_first':
67 |             anchors = backend.shift(features_shape[2:4], self.stride, self.anchors)
68 |         else:
69 |             anchors = backend.shift(features_shape[1:3], self.stride, self.anchors)
70 |         anchors = keras.backend.tile(keras.backend.expand_dims(anchors, axis=0), (features_shape[0], 1, 1))
71 | 
72 |         return anchors
73 | 
74 |     def compute_output_shape(self, input_shape):
75 |         if None not in input_shape[1:]:
76 |             if keras.backend.image_data_format() == 'channels_first':
77 |                 total = np.prod(input_shape[2:4]) * self.num_anchors
78 |             else:
79 |                 total = np.prod(input_shape[1:3]) * self.num_anchors
80 | 
81 |             return (input_shape[0], total, 4)
82 |         else:
83 |             return (input_shape[0], None, 4)
84 | 
85 |     def get_config(self):
86 |         config = super(Anchors, self).get_config()
87 |         config.update({
88 |             'size'   : self.size,
89 |             'stride' : self.stride,
90 |             'ratios' : self.ratios.tolist(),
91 |             'scales' : self.scales.tolist(),
92 |         })
93 | 
94 |         return config
95 | 
96 | 
97 | class UpsampleLike(keras.layers.Layer):
98 |     """ Keras layer for upsampling a Tensor to be the same shape as another Tensor.
99 |     """
100 | 
101 |     def call(self, inputs, **kwargs):
102 |         source, target = inputs
103 |         target_shape = keras.backend.shape(target)
104 |         if keras.backend.image_data_format() == 'channels_first':
105 |             source = tensorflow.transpose(source, (0, 2, 3, 1))
106 |             output = backend.resize_images(source, (target_shape[2], target_shape[3]), method='nearest')
107 |             output = tensorflow.transpose(output, (0, 3, 1, 2))
108 |             return output
109 |         else:
110 |             return backend.resize_images(source, (target_shape[1], target_shape[2]), method='nearest')
111 | 
112 |     def compute_output_shape(self, input_shape):
113 |         if keras.backend.image_data_format() == 'channels_first':
114 |             return (input_shape[0][0], input_shape[0][1]) + input_shape[1][2:4]
115 |         else:
116 |             return (input_shape[0][0],) + input_shape[1][1:3] + (input_shape[0][-1],)
117 | 
118 | 
119 | class RegressBoxes(keras.layers.Layer):
120 |     """ Keras layer for applying regression values to boxes.
121 |     """
122 | 
123 |     def __init__(self, mean=None, std=None, *args, **kwargs):
124 |         """ Initializer for the RegressBoxes layer.
125 | 
126 |         Args
127 |             mean: The mean value of the regression values which was used for normalization.
128 |             std: The standard deviation of the regression values which was used for normalization.
129 | """ 130 | if mean is None: 131 | mean = np.array([0, 0, 0, 0]) 132 | if std is None: 133 | std = np.array([0.2, 0.2, 0.2, 0.2]) 134 | 135 | if isinstance(mean, (list, tuple)): 136 | mean = np.array(mean) 137 | elif not isinstance(mean, np.ndarray): 138 | raise ValueError('Expected mean to be a np.ndarray, list or tuple. Received: {}'.format(type(mean))) 139 | 140 | if isinstance(std, (list, tuple)): 141 | std = np.array(std) 142 | elif not isinstance(std, np.ndarray): 143 | raise ValueError('Expected std to be a np.ndarray, list or tuple. Received: {}'.format(type(std))) 144 | 145 | self.mean = mean 146 | self.std = std 147 | super(RegressBoxes, self).__init__(*args, **kwargs) 148 | 149 | def call(self, inputs, **kwargs): 150 | anchors, regression = inputs 151 | return backend.bbox_transform_inv(anchors, regression, mean=self.mean, std=self.std) 152 | 153 | def compute_output_shape(self, input_shape): 154 | return input_shape[0] 155 | 156 | def get_config(self): 157 | config = super(RegressBoxes, self).get_config() 158 | config.update({ 159 | 'mean': self.mean.tolist(), 160 | 'std' : self.std.tolist(), 161 | }) 162 | 163 | return config 164 | 165 | 166 | class ClipBoxes(keras.layers.Layer): 167 | """ Keras layer to clip box values to lie inside a given shape. 168 | """ 169 | def call(self, inputs, **kwargs): 170 | image, boxes = inputs 171 | shape = keras.backend.cast(keras.backend.shape(image), keras.backend.floatx()) 172 | if keras.backend.image_data_format() == 'channels_first': 173 | _, _, height, width = tensorflow.unstack(shape, axis=0) 174 | else: 175 | _, height, width, _ = tensorflow.unstack(shape, axis=0) 176 | 177 | x1, y1, x2, y2 = tensorflow.unstack(boxes, axis=-1) 178 | x1 = tensorflow.clip_by_value(x1, 0, width - 1) 179 | y1 = tensorflow.clip_by_value(y1, 0, height - 1) 180 | x2 = tensorflow.clip_by_value(x2, 0, width - 1) 181 | y2 = tensorflow.clip_by_value(y2, 0, height - 1) 182 | 183 | return keras.backend.stack([x1, y1, x2, y2], axis=2) 184 | 185 | def compute_output_shape(self, input_shape): 186 | return input_shape[1] 187 | -------------------------------------------------------------------------------- /keras_retinanet/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import tensorflow 18 | from tensorflow import keras 19 | 20 | 21 | def focal(alpha=0.25, gamma=2.0, cutoff=0.5): 22 | """ Create a functor for computing the focal loss. 23 | 24 | Args 25 | alpha: Scale the focal weight with alpha. 26 | gamma: Take the power of the focal weight with gamma. 27 | cutoff: Positive prediction cutoff for soft targets 28 | 29 | Returns 30 | A functor that computes the focal loss using the alpha and gamma. 31 | """ 32 | def _focal(y_true, y_pred): 33 | """ Compute the focal loss given the target tensor and the predicted tensor. 
34 | 35 | As defined in https://arxiv.org/abs/1708.02002 36 | 37 | Args 38 | y_true: Tensor of target data from the generator with shape (B, N, num_classes). 39 | y_pred: Tensor of predicted data from the network with shape (B, N, num_classes). 40 | 41 | Returns 42 | The focal loss of y_pred w.r.t. y_true. 43 | """ 44 | labels = y_true[:, :, :-1] 45 | anchor_state = y_true[:, :, -1] # -1 for ignore, 0 for background, 1 for object 46 | classification = y_pred 47 | 48 | # filter out "ignore" anchors 49 | indices = tensorflow.where(keras.backend.not_equal(anchor_state, -1)) 50 | labels = tensorflow.gather_nd(labels, indices) 51 | classification = tensorflow.gather_nd(classification, indices) 52 | 53 | # compute the focal loss 54 | alpha_factor = keras.backend.ones_like(labels) * alpha 55 | alpha_factor = tensorflow.where(keras.backend.greater(labels, cutoff), alpha_factor, 1 - alpha_factor) 56 | focal_weight = tensorflow.where(keras.backend.greater(labels, cutoff), 1 - classification, classification) 57 | focal_weight = alpha_factor * focal_weight ** gamma 58 | 59 | cls_loss = focal_weight * keras.backend.binary_crossentropy(labels, classification) 60 | 61 | # compute the normalizer: the number of positive anchors 62 | normalizer = tensorflow.where(keras.backend.equal(anchor_state, 1)) 63 | normalizer = keras.backend.cast(keras.backend.shape(normalizer)[0], keras.backend.floatx()) 64 | normalizer = keras.backend.maximum(keras.backend.cast_to_floatx(1.0), normalizer) 65 | 66 | return keras.backend.sum(cls_loss) / normalizer 67 | 68 | return _focal 69 | 70 | 71 | def smooth_l1(sigma=3.0): 72 | """ Create a smooth L1 loss functor. 73 | 74 | Args 75 | sigma: This argument defines the point where the loss changes from L2 to L1. 76 | 77 | Returns 78 | A functor for computing the smooth L1 loss given target data and predicted data. 79 | """ 80 | sigma_squared = sigma ** 2 81 | 82 | def _smooth_l1(y_true, y_pred): 83 | """ Compute the smooth L1 loss of y_pred w.r.t. y_true. 84 | 85 | Args 86 | y_true: Tensor from the generator of shape (B, N, 5). The last value for each box is the state of the anchor (ignore, negative, positive). 87 | y_pred: Tensor from the network of shape (B, N, 4). 88 | 89 | Returns 90 | The smooth L1 loss of y_pred w.r.t. y_true. 
91 | """ 92 | # separate target and state 93 | regression = y_pred 94 | regression_target = y_true[:, :, :-1] 95 | anchor_state = y_true[:, :, -1] 96 | 97 | # filter out "ignore" anchors 98 | indices = tensorflow.where(keras.backend.equal(anchor_state, 1)) 99 | regression = tensorflow.gather_nd(regression, indices) 100 | regression_target = tensorflow.gather_nd(regression_target, indices) 101 | 102 | # compute smooth L1 loss 103 | # f(x) = 0.5 * (sigma * x)^2 if |x| < 1 / sigma / sigma 104 | # |x| - 0.5 / sigma / sigma otherwise 105 | regression_diff = regression - regression_target 106 | regression_diff = keras.backend.abs(regression_diff) 107 | regression_loss = tensorflow.where( 108 | keras.backend.less(regression_diff, 1.0 / sigma_squared), 109 | 0.5 * sigma_squared * keras.backend.pow(regression_diff, 2), 110 | regression_diff - 0.5 / sigma_squared 111 | ) 112 | 113 | # compute the normalizer: the number of positive anchors 114 | normalizer = keras.backend.maximum(1, keras.backend.shape(indices)[0]) 115 | normalizer = keras.backend.cast(normalizer, dtype=keras.backend.floatx()) 116 | return keras.backend.sum(regression_loss) / normalizer 117 | 118 | return _smooth_l1 119 | -------------------------------------------------------------------------------- /keras_retinanet/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | 4 | 5 | class Backbone(object): 6 | """ This class stores additional information on backbones. 7 | """ 8 | def __init__(self, backbone): 9 | # a dictionary mapping custom layer names to the correct classes 10 | from .. import layers 11 | from .. import losses 12 | from .. import initializers 13 | self.custom_objects = { 14 | 'UpsampleLike' : layers.UpsampleLike, 15 | 'PriorProbability' : initializers.PriorProbability, 16 | 'RegressBoxes' : layers.RegressBoxes, 17 | 'FilterDetections' : layers.FilterDetections, 18 | 'Anchors' : layers.Anchors, 19 | 'ClipBoxes' : layers.ClipBoxes, 20 | '_smooth_l1' : losses.smooth_l1(), 21 | '_focal' : losses.focal(), 22 | } 23 | 24 | self.backbone = backbone 25 | self.validate() 26 | 27 | def retinanet(self, *args, **kwargs): 28 | """ Returns a retinanet model using the correct backbone. 29 | """ 30 | raise NotImplementedError('retinanet method not implemented.') 31 | 32 | def download_imagenet(self): 33 | """ Downloads ImageNet weights and returns path to weights file. 34 | """ 35 | raise NotImplementedError('download_imagenet method not implemented.') 36 | 37 | def validate(self): 38 | """ Checks whether the backbone string is correct. 39 | """ 40 | raise NotImplementedError('validate method not implemented.') 41 | 42 | def preprocess_image(self, inputs): 43 | """ Takes as input an image and prepares it for being passed through the network. 44 | Having this function in Backbone allows other backbones to define a specific preprocessing step. 45 | """ 46 | raise NotImplementedError('preprocess_image method not implemented.') 47 | 48 | 49 | def backbone(backbone_name): 50 | """ Returns a backbone object for the given backbone. 
51 | """ 52 | if 'densenet' in backbone_name: 53 | from .densenet import DenseNetBackbone as b 54 | elif 'seresnext' in backbone_name or 'seresnet' in backbone_name or 'senet' in backbone_name: 55 | from .senet import SeBackbone as b 56 | elif 'resnet' in backbone_name: 57 | from .resnet import ResNetBackbone as b 58 | elif 'mobilenet' in backbone_name: 59 | from .mobilenet import MobileNetBackbone as b 60 | elif 'vgg' in backbone_name: 61 | from .vgg import VGGBackbone as b 62 | elif 'EfficientNet' in backbone_name: 63 | from .effnet import EfficientNetBackbone as b 64 | else: 65 | raise NotImplementedError('Backbone class for \'{}\' not implemented.'.format(backbone)) 66 | 67 | return b(backbone_name) 68 | 69 | 70 | def load_model(filepath, backbone_name='resnet50'): 71 | """ Loads a retinanet model using the correct custom objects. 72 | 73 | Args 74 | filepath: one of the following: 75 | - string, path to the saved model, or 76 | - h5py.File object from which to load the model 77 | backbone_name : Backbone with which the model was trained. 78 | 79 | Returns 80 | A keras.models.Model object. 81 | 82 | Raises 83 | ImportError: if h5py is not available. 84 | ValueError: In case of an invalid savefile. 85 | """ 86 | from tensorflow import keras 87 | return keras.models.load_model(filepath, custom_objects=backbone(backbone_name).custom_objects) 88 | 89 | 90 | def convert_model(model, nms=True, class_specific_filter=True, anchor_params=None, **kwargs): 91 | """ Converts a training model to an inference model. 92 | 93 | Args 94 | model : A retinanet training model. 95 | nms : Boolean, whether to add NMS filtering to the converted model. 96 | class_specific_filter : Whether to use class specific filtering or filter for the best scoring class only. 97 | anchor_params : Anchor parameters object. If omitted, default values are used. 98 | **kwargs : Inference and minimal retinanet model settings. 99 | 100 | Returns 101 | A keras.models.Model object. 102 | 103 | Raises 104 | ImportError: if h5py is not available. 105 | ValueError: In case of an invalid savefile. 106 | """ 107 | from .retinanet import retinanet_bbox 108 | return retinanet_bbox(model=model, nms=nms, class_specific_filter=class_specific_filter, anchor_params=anchor_params, **kwargs) 109 | 110 | 111 | def assert_training_model(model): 112 | """ Assert that the model is a training model. 113 | """ 114 | assert(all(output in model.output_names for output in ['regression', 'classification'])), \ 115 | "Input is not a training model (no 'regression' and 'classification' outputs were found, outputs are: {}).".format(model.output_names) 116 | 117 | 118 | def check_training_model(model): 119 | """ Check that model is a training model and exit otherwise. 120 | """ 121 | try: 122 | assert_training_model(model) 123 | except AssertionError as e: 124 | print(e, file=sys.stderr) 125 | sys.exit(1) 126 | -------------------------------------------------------------------------------- /keras_retinanet/models/densenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2018 vidosits (https://github.com/vidosits/) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from tensorflow import keras 18 | 19 | from . import retinanet 20 | from . import Backbone 21 | from ..utils.image import preprocess_image 22 | 23 | 24 | allowed_backbones = { 25 | 'densenet121': ([6, 12, 24, 16], keras.applications.densenet.DenseNet121), 26 | 'densenet169': ([6, 12, 32, 32], keras.applications.densenet.DenseNet169), 27 | 'densenet201': ([6, 12, 48, 32], keras.applications.densenet.DenseNet201), 28 | } 29 | 30 | 31 | class DenseNetBackbone(Backbone): 32 | """ Describes backbone information and provides utility functions. 33 | """ 34 | 35 | def retinanet(self, *args, **kwargs): 36 | """ Returns a retinanet model using the correct backbone. 37 | """ 38 | return densenet_retinanet(*args, backbone=self.backbone, **kwargs) 39 | 40 | def download_imagenet(self): 41 | """ Download pre-trained weights for the specified backbone name. 42 | This name is in the format {backbone}_weights_tf_dim_ordering_tf_kernels_notop 43 | where backbone is the densenet + number of layers (e.g. densenet121). 44 | For more info check the explanation from the keras densenet script itself: 45 | https://github.com/keras-team/keras/blob/master/keras/applications/densenet.py 46 | """ 47 | origin = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/' 48 | file_name = '{}_weights_tf_dim_ordering_tf_kernels_notop.h5' 49 | 50 | # load weights 51 | if keras.backend.image_data_format() == 'channels_first': 52 | raise ValueError('Weights for "channels_first" format are not available.') 53 | 54 | weights_url = origin + file_name.format(self.backbone) 55 | return keras.utils.get_file(file_name.format(self.backbone), weights_url, cache_subdir='models') 56 | 57 | def validate(self): 58 | """ Checks whether the backbone string is correct. 59 | """ 60 | backbone = self.backbone.split('_')[0] 61 | 62 | if backbone not in allowed_backbones: 63 | raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, allowed_backbones.keys())) 64 | 65 | def preprocess_image(self, inputs): 66 | """ Takes as input an image and prepares it for being passed through the network. 67 | """ 68 | return preprocess_image(inputs, mode='tf') 69 | 70 | 71 | def densenet_retinanet(num_classes, backbone='densenet121', inputs=None, modifier=None, **kwargs): 72 | """ Constructs a retinanet model using a densenet backbone. 73 | 74 | Args 75 | num_classes: Number of classes to predict. 76 | backbone: Which backbone to use (one of ('densenet121', 'densenet169', 'densenet201')). 77 | inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)). 78 | modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example). 79 | 80 | Returns 81 | RetinaNet model with a DenseNet backbone. 
82 | """ 83 | # choose default input 84 | if inputs is None: 85 | inputs = keras.layers.Input((None, None, 3)) 86 | 87 | blocks, creator = allowed_backbones[backbone] 88 | model = creator(input_tensor=inputs, include_top=False, pooling=None, weights=None) 89 | 90 | # get last conv layer from the end of each dense block 91 | layer_outputs = [model.get_layer(name='conv{}_block{}_concat'.format(idx + 2, block_num)).output for idx, block_num in enumerate(blocks)] 92 | 93 | # create the densenet backbone 94 | # layer_outputs contains 4 layers 95 | model = keras.models.Model(inputs=inputs, outputs=layer_outputs, name=model.name) 96 | 97 | # invoke modifier if given 98 | if modifier: 99 | model = modifier(model) 100 | 101 | # create the full model 102 | backbone_layers = { 103 | 'C2': model.outputs[0], 104 | 'C3': model.outputs[1], 105 | 'C4': model.outputs[2], 106 | 'C5': model.outputs[3] 107 | } 108 | 109 | model = retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=backbone_layers, **kwargs) 110 | 111 | return model 112 | -------------------------------------------------------------------------------- /keras_retinanet/models/effnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from tensorflow import keras 18 | 19 | from . import retinanet 20 | from . import Backbone 21 | import efficientnet.keras as efn 22 | 23 | 24 | class EfficientNetBackbone(Backbone): 25 | """ Describes backbone information and provides utility functions. 26 | """ 27 | 28 | def __init__(self, backbone): 29 | super(EfficientNetBackbone, self).__init__(backbone) 30 | self.preprocess_image_func = None 31 | 32 | def retinanet(self, *args, **kwargs): 33 | """ Returns a retinanet model using the correct backbone. 34 | """ 35 | return effnet_retinanet(*args, backbone=self.backbone, **kwargs) 36 | 37 | def download_imagenet(self): 38 | """ Downloads ImageNet weights and returns path to weights file. 39 | """ 40 | from efficientnet.weights import IMAGENET_WEIGHTS_PATH 41 | from efficientnet.weights import IMAGENET_WEIGHTS_HASHES 42 | 43 | model_name = 'efficientnet-b' + self.backbone[-1] 44 | file_name = model_name + '_weights_tf_dim_ordering_tf_kernels_autoaugment_notop.h5' 45 | file_hash = IMAGENET_WEIGHTS_HASHES[model_name][1] 46 | weights_path = keras.utils.get_file(file_name, IMAGENET_WEIGHTS_PATH + file_name, cache_subdir='models', file_hash=file_hash) 47 | return weights_path 48 | 49 | def validate(self): 50 | """ Checks whether the backbone string is correct. 
51 | """ 52 | allowed_backbones = ['EfficientNetB0', 'EfficientNetB1', 'EfficientNetB2', 'EfficientNetB3', 'EfficientNetB4', 53 | 'EfficientNetB5', 'EfficientNetB6', 'EfficientNetB7'] 54 | backbone = self.backbone.split('_')[0] 55 | 56 | if backbone not in allowed_backbones: 57 | raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, allowed_backbones)) 58 | 59 | def preprocess_image(self, inputs): 60 | """ Takes as input an image and prepares it for being passed through the network. 61 | """ 62 | return efn.preprocess_input(inputs) 63 | 64 | 65 | def effnet_retinanet(num_classes, backbone='EfficientNetB0', inputs=None, modifier=None, **kwargs): 66 | """ Constructs a retinanet model using a resnet backbone. 67 | 68 | Args 69 | num_classes: Number of classes to predict. 70 | backbone: Which backbone to use (one of ('resnet50', 'resnet101', 'resnet152')). 71 | inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)). 72 | modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example). 73 | 74 | Returns 75 | RetinaNet model with a ResNet backbone. 76 | """ 77 | # choose default input 78 | if inputs is None: 79 | if keras.backend.image_data_format() == 'channels_first': 80 | inputs = keras.layers.Input(shape=(3, None, None)) 81 | else: 82 | # inputs = keras.layers.Input(shape=(224, 224, 3)) 83 | inputs = keras.layers.Input(shape=(None, None, 3)) 84 | 85 | # get last conv layer from the end of each block [28x28, 14x14, 7x7] 86 | if backbone == 'EfficientNetB0': 87 | model = efn.EfficientNetB0(input_tensor=inputs, include_top=False, weights=None) 88 | elif backbone == 'EfficientNetB1': 89 | model = efn.EfficientNetB1(input_tensor=inputs, include_top=False, weights=None) 90 | elif backbone == 'EfficientNetB2': 91 | model = efn.EfficientNetB2(input_tensor=inputs, include_top=False, weights=None) 92 | elif backbone == 'EfficientNetB3': 93 | model = efn.EfficientNetB3(input_tensor=inputs, include_top=False, weights=None) 94 | elif backbone == 'EfficientNetB4': 95 | model = efn.EfficientNetB4(input_tensor=inputs, include_top=False, weights=None) 96 | elif backbone == 'EfficientNetB5': 97 | model = efn.EfficientNetB5(input_tensor=inputs, include_top=False, weights=None) 98 | elif backbone == 'EfficientNetB6': 99 | model = efn.EfficientNetB6(input_tensor=inputs, include_top=False, weights=None) 100 | elif backbone == 'EfficientNetB7': 101 | model = efn.EfficientNetB7(input_tensor=inputs, include_top=False, weights=None) 102 | else: 103 | raise ValueError('Backbone (\'{}\') is invalid.'.format(backbone)) 104 | 105 | layer_outputs = ['block4a_expand_activation', 'block6a_expand_activation', 'top_activation'] 106 | 107 | layer_outputs = [ 108 | model.get_layer(name=layer_outputs[0]).output, # 28x28 109 | model.get_layer(name=layer_outputs[1]).output, # 14x14 110 | model.get_layer(name=layer_outputs[2]).output, # 7x7 111 | ] 112 | # create the densenet backbone 113 | model = keras.models.Model(inputs=inputs, outputs=layer_outputs, name=model.name) 114 | 115 | # invoke modifier if given 116 | if modifier: 117 | model = modifier(model) 118 | 119 | # C2 not provided 120 | backbone_layers = { 121 | 'C3': model.outputs[0], 122 | 'C4': model.outputs[1], 123 | 'C5': model.outputs[2] 124 | } 125 | 126 | # create the full model 127 | return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=backbone_layers, **kwargs) 128 | 129 | 130 | def 
EfficientNetB0_retinanet(num_classes, inputs=None, **kwargs): 131 | return effnet_retinanet(num_classes=num_classes, backbone='EfficientNetB0', inputs=inputs, **kwargs) 132 | 133 | 134 | def EfficientNetB1_retinanet(num_classes, inputs=None, **kwargs): 135 | return effnet_retinanet(num_classes=num_classes, backbone='EfficientNetB1', inputs=inputs, **kwargs) 136 | 137 | 138 | def EfficientNetB2_retinanet(num_classes, inputs=None, **kwargs): 139 | return effnet_retinanet(num_classes=num_classes, backbone='EfficientNetB2', inputs=inputs, **kwargs) 140 | 141 | 142 | def EfficientNetB3_retinanet(num_classes, inputs=None, **kwargs): 143 | return effnet_retinanet(num_classes=num_classes, backbone='EfficientNetB3', inputs=inputs, **kwargs) 144 | 145 | 146 | def EfficientNetB4_retinanet(num_classes, inputs=None, **kwargs): 147 | return effnet_retinanet(num_classes=num_classes, backbone='EfficientNetB4', inputs=inputs, **kwargs) 148 | 149 | 150 | def EfficientNetB5_retinanet(num_classes, inputs=None, **kwargs): 151 | return effnet_retinanet(num_classes=num_classes, backbone='EfficientNetB5', inputs=inputs, **kwargs) 152 | 153 | 154 | def EfficientNetB6_retinanet(num_classes, inputs=None, **kwargs): 155 | return effnet_retinanet(num_classes=num_classes, backbone='EfficientNetB6', inputs=inputs, **kwargs) 156 | 157 | 158 | def EfficientNetB7_retinanet(num_classes, inputs=None, **kwargs): 159 | return effnet_retinanet(num_classes=num_classes, backbone='EfficientNetB7', inputs=inputs, **kwargs) 160 | -------------------------------------------------------------------------------- /keras_retinanet/models/mobilenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 lvaleriu (https://github.com/lvaleriu/) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from tensorflow import keras 18 | from ..utils.image import preprocess_image 19 | 20 | from . import retinanet 21 | from . import Backbone 22 | 23 | 24 | class MobileNetBackbone(Backbone): 25 | """ Describes backbone information and provides utility functions. 26 | """ 27 | 28 | allowed_backbones = ['mobilenet128', 'mobilenet160', 'mobilenet192', 'mobilenet224'] 29 | 30 | def retinanet(self, *args, **kwargs): 31 | """ Returns a retinanet model using the correct backbone. 32 | """ 33 | return mobilenet_retinanet(*args, backbone=self.backbone, **kwargs) 34 | 35 | def download_imagenet(self): 36 | """ Download pre-trained weights for the specified backbone name. 37 | This name is in the format mobilenet{rows}_{alpha} where rows is the 38 | imagenet shape dimension and 'alpha' controls the width of the network. 39 | For more info check the explanation from the keras mobilenet script itself. 
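# ---------------------------------------------------------------------------
# Editor's note: the per-variant wrappers above reduce construction to a
# one-liner; this requires the third-party 'efficientnet' package imported
# as efn. num_classes is illustrative.
from keras_retinanet.models.effnet import EfficientNetB4_retinanet

model = EfficientNetB4_retinanet(num_classes=20)
# ---------------------------------------------------------------------------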
40 | """ 41 | 42 | alpha = float(self.backbone.split('_')[1]) 43 | rows = int(self.backbone.split('_')[0].replace('mobilenet', '')) 44 | 45 | # load weights 46 | if keras.backend.image_data_format() == 'channels_first': 47 | raise ValueError('Weights for "channels_last" format ' 48 | 'are not available.') 49 | if alpha == 1.0: 50 | alpha_text = '1_0' 51 | elif alpha == 0.75: 52 | alpha_text = '7_5' 53 | elif alpha == 0.50: 54 | alpha_text = '5_0' 55 | else: 56 | alpha_text = '2_5' 57 | 58 | model_name = 'mobilenet_{}_{}_tf_no_top.h5'.format(alpha_text, rows) 59 | weights_url = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.6/' + model_name 60 | weights_path = keras.utils.get_file(model_name, weights_url, cache_subdir='models') 61 | 62 | return weights_path 63 | 64 | def validate(self): 65 | """ Checks whether the backbone string is correct. 66 | """ 67 | backbone = self.backbone.split('_')[0] 68 | 69 | if backbone not in MobileNetBackbone.allowed_backbones: 70 | raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, MobileNetBackbone.allowed_backbones)) 71 | 72 | def preprocess_image(self, inputs): 73 | """ Takes as input an image and prepares it for being passed through the network. 74 | """ 75 | return preprocess_image(inputs, mode='tf') 76 | 77 | 78 | def mobilenet_retinanet(num_classes, backbone='mobilenet224_1.0', inputs=None, modifier=None, **kwargs): 79 | """ Constructs a retinanet model using a mobilenet backbone. 80 | 81 | Args 82 | num_classes: Number of classes to predict. 83 | backbone: Which backbone to use (one of ('mobilenet128', 'mobilenet160', 'mobilenet192', 'mobilenet224')). 84 | inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)). 85 | modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example). 86 | 87 | Returns 88 | RetinaNet model with a MobileNet backbone. 89 | """ 90 | alpha = float(backbone.split('_')[1]) 91 | 92 | # choose default input 93 | if inputs is None: 94 | inputs = keras.layers.Input((None, None, 3)) 95 | 96 | backbone = keras.applications.mobilenet.MobileNet(input_tensor=inputs, alpha=alpha, include_top=False, pooling=None, weights=None) 97 | 98 | # create the full model 99 | layer_names = ['conv_pw_5_relu', 'conv_pw_11_relu', 'conv_pw_13_relu'] 100 | layer_outputs = [backbone.get_layer(name).output for name in layer_names] 101 | backbone = keras.models.Model(inputs=inputs, outputs=layer_outputs, name=backbone.name) 102 | 103 | # invoke modifier if given 104 | if modifier: 105 | backbone = modifier(backbone) 106 | 107 | # C2 not provided 108 | backbone_layers = { 109 | 'C3': backbone.outputs[0], 110 | 'C4': backbone.outputs[1], 111 | 'C5': backbone.outputs[2] 112 | } 113 | 114 | return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=backbone_layers, **kwargs) 115 | -------------------------------------------------------------------------------- /keras_retinanet/models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from tensorflow import keras 18 | import keras_resnet 19 | import keras_resnet.models 20 | 21 | from . import retinanet 22 | from . import Backbone 23 | from ..utils.image import preprocess_image 24 | 25 | 26 | class ResNetBackbone(Backbone): 27 | """ Describes backbone information and provides utility functions. 28 | """ 29 | 30 | def __init__(self, backbone): 31 | super(ResNetBackbone, self).__init__(backbone) 32 | self.custom_objects.update(keras_resnet.custom_objects) 33 | 34 | def retinanet(self, *args, **kwargs): 35 | """ Returns a retinanet model using the correct backbone. 36 | """ 37 | return resnet_retinanet(*args, backbone=self.backbone, **kwargs) 38 | 39 | def download_imagenet(self): 40 | """ Downloads ImageNet weights and returns path to weights file. 41 | """ 42 | resnet_filename = 'ResNet-{}-model.keras.h5' 43 | resnet_resource = 'https://github.com/fizyr/keras-models/releases/download/v0.0.1/{}'.format(resnet_filename) 44 | depth = int(self.backbone.replace('resnet', '')) 45 | 46 | filename = resnet_filename.format(depth) 47 | resource = resnet_resource.format(depth) 48 | if depth == 50: 49 | checksum = '3e9f4e4f77bbe2c9bec13b53ee1c2319' 50 | elif depth == 101: 51 | checksum = '05dc86924389e5b401a9ea0348a3213c' 52 | elif depth == 152: 53 | checksum = '6ee11ef2b135592f8031058820bb9e71' 54 | 55 | return keras.utils.get_file( 56 | filename, 57 | resource, 58 | cache_subdir='models', 59 | md5_hash=checksum 60 | ) 61 | 62 | def validate(self): 63 | """ Checks whether the backbone string is correct. 64 | """ 65 | allowed_backbones = ['resnet50', 'resnet101', 'resnet152'] 66 | backbone = self.backbone.split('_')[0] 67 | 68 | if backbone not in allowed_backbones: 69 | raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, allowed_backbones)) 70 | 71 | def preprocess_image(self, inputs): 72 | """ Takes as input an image and prepares it for being passed through the network. 73 | """ 74 | return preprocess_image(inputs, mode='caffe') 75 | 76 | 77 | def resnet_retinanet(num_classes, backbone='resnet50', inputs=None, modifier=None, **kwargs): 78 | """ Constructs a retinanet model using a resnet backbone. 79 | 80 | Args 81 | num_classes: Number of classes to predict. 82 | backbone: Which backbone to use (one of ('resnet50', 'resnet101', 'resnet152')). 83 | inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)). 84 | modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example). 85 | 86 | Returns 87 | RetinaNet model with a ResNet backbone. 
88 | """ 89 | # choose default input 90 | if inputs is None: 91 | if keras.backend.image_data_format() == 'channels_first': 92 | inputs = keras.layers.Input(shape=(3, None, None)) 93 | else: 94 | inputs = keras.layers.Input(shape=(None, None, 3)) 95 | 96 | # create the resnet backbone 97 | if backbone == 'resnet50': 98 | resnet = keras_resnet.models.ResNet50(inputs, include_top=False, freeze_bn=True) 99 | elif backbone == 'resnet101': 100 | resnet = keras_resnet.models.ResNet101(inputs, include_top=False, freeze_bn=True) 101 | elif backbone == 'resnet152': 102 | resnet = keras_resnet.models.ResNet152(inputs, include_top=False, freeze_bn=True) 103 | else: 104 | raise ValueError('Backbone (\'{}\') is invalid.'.format(backbone)) 105 | 106 | # invoke modifier if given 107 | if modifier: 108 | resnet = modifier(resnet) 109 | 110 | # create the full model 111 | # resnet.outputs contains 4 layers 112 | backbone_layers = { 113 | 'C2': resnet.outputs[0], 114 | 'C3': resnet.outputs[1], 115 | 'C4': resnet.outputs[2], 116 | 'C5': resnet.outputs[3] 117 | } 118 | 119 | return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=backbone_layers, **kwargs) 120 | 121 | 122 | def resnet50_retinanet(num_classes, inputs=None, **kwargs): 123 | return resnet_retinanet(num_classes=num_classes, backbone='resnet50', inputs=inputs, **kwargs) 124 | 125 | 126 | def resnet101_retinanet(num_classes, inputs=None, **kwargs): 127 | return resnet_retinanet(num_classes=num_classes, backbone='resnet101', inputs=inputs, **kwargs) 128 | 129 | 130 | def resnet152_retinanet(num_classes, inputs=None, **kwargs): 131 | return resnet_retinanet(num_classes=num_classes, backbone='resnet152', inputs=inputs, **kwargs) 132 | -------------------------------------------------------------------------------- /keras_retinanet/models/senet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from tensorflow import keras 18 | 19 | from . import retinanet 20 | from . import Backbone 21 | from classification_models.keras import Classifiers 22 | 23 | 24 | class SeBackbone(Backbone): 25 | """ Describes backbone information and provides utility functions. 26 | """ 27 | 28 | def __init__(self, backbone): 29 | super(SeBackbone, self).__init__(backbone) 30 | _, self.preprocess_image_func = Classifiers.get(self.backbone) 31 | 32 | def retinanet(self, *args, **kwargs): 33 | """ Returns a retinanet model using the correct backbone. 34 | """ 35 | return senet_retinanet(*args, backbone=self.backbone, **kwargs) 36 | 37 | def download_imagenet(self): 38 | """ Downloads ImageNet weights and returns path to weights file. 
39 | """ 40 | from classification_models.weights import WEIGHTS_COLLECTION 41 | 42 | weights_path = None 43 | for el in WEIGHTS_COLLECTION: 44 | if el['model'] == self.backbone and not el['include_top']: 45 | weights_path = keras.utils.get_file(el['name'], el['url'], cache_subdir='models', file_hash=el['md5']) 46 | 47 | if weights_path is None: 48 | raise ValueError('Unable to find imagenet weights for backbone {}!'.format(self.backbone)) 49 | 50 | return weights_path 51 | 52 | def validate(self): 53 | """ Checks whether the backbone string is correct. 54 | """ 55 | allowed_backbones = ['seresnet18', 'seresnet34', 'seresnet50', 'seresnet101', 'seresnet152', 56 | 'seresnext50', 'seresnext101', 'senet154'] 57 | backbone = self.backbone.split('_')[0] 58 | 59 | if backbone not in allowed_backbones: 60 | raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, allowed_backbones)) 61 | 62 | def preprocess_image(self, inputs): 63 | """ Takes as input an image and prepares it for being passed through the network. 64 | """ 65 | return self.preprocess_image_func(inputs) 66 | 67 | 68 | def senet_retinanet(num_classes, backbone='seresnext50', inputs=None, modifier=None, **kwargs): 69 | """ Constructs a retinanet model using a resnet backbone. 70 | 71 | Args 72 | num_classes: Number of classes to predict. 73 | backbone: Which backbone to use (one of ('resnet50', 'resnet101', 'resnet152')). 74 | inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)). 75 | modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example). 76 | 77 | Returns 78 | RetinaNet model with a ResNet backbone. 79 | """ 80 | # choose default input 81 | if inputs is None: 82 | if keras.backend.image_data_format() == 'channels_first': 83 | inputs = keras.layers.Input(shape=(3, None, None)) 84 | else: 85 | # inputs = keras.layers.Input(shape=(224, 224, 3)) 86 | inputs = keras.layers.Input(shape=(None, None, 3)) 87 | 88 | classifier, _ = Classifiers.get(backbone) 89 | model = classifier(input_tensor=inputs, include_top=False, weights=None) 90 | 91 | # get last conv layer from the end of each block [28x28, 14x14, 7x7] 92 | if backbone == 'seresnet18' or backbone == 'seresnet34': 93 | layer_outputs = ['stage3_unit1_relu1', 'stage4_unit1_relu1', 'relu1'] 94 | elif backbone == 'seresnet50': 95 | layer_outputs = ['activation_36', 'activation_66', 'activation_81'] 96 | elif backbone == 'seresnet101': 97 | layer_outputs = ['activation_36', 'activation_151', 'activation_166'] 98 | elif backbone == 'seresnet152': 99 | layer_outputs = ['activation_56', 'activation_236', 'activation_251'] 100 | elif backbone == 'seresnext50': 101 | layer_outputs = ['activation_37', 'activation_67', 'activation_81'] 102 | elif backbone == 'seresnext101': 103 | layer_outputs = ['activation_37', 'activation_152', 'activation_166'] 104 | elif backbone == 'senet154': 105 | layer_outputs = ['activation_59', 'activation_239', 'activation_253'] 106 | else: 107 | raise ValueError('Backbone (\'{}\') is invalid.'.format(backbone)) 108 | 109 | layer_outputs = [ 110 | model.get_layer(name=layer_outputs[0]).output, # 28x28 111 | model.get_layer(name=layer_outputs[1]).output, # 14x14 112 | model.get_layer(name=layer_outputs[2]).output, # 7x7 113 | ] 114 | # create the densenet backbone 115 | model = keras.models.Model(inputs=inputs, outputs=layer_outputs, name=model.name) 116 | 117 | # invoke modifier if given 118 | if modifier: 119 | model = 
modifier(model) 120 | 121 | # C2 not provided 122 | backbone_layers = { 123 | 'C3': model.outputs[0], 124 | 'C4': model.outputs[1], 125 | 'C5': model.outputs[2] 126 | } 127 | 128 | # create the full model 129 | return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=backbone_layers, **kwargs) 130 | 131 | 132 | def seresnet18_retinanet(num_classes, inputs=None, **kwargs): 133 | return senet_retinanet(num_classes=num_classes, backbone='seresnet18', inputs=inputs, **kwargs) 134 | 135 | 136 | def seresnet34_retinanet(num_classes, inputs=None, **kwargs): 137 | return senet_retinanet(num_classes=num_classes, backbone='seresnet34', inputs=inputs, **kwargs) 138 | 139 | 140 | def seresnet50_retinanet(num_classes, inputs=None, **kwargs): 141 | return senet_retinanet(num_classes=num_classes, backbone='seresnet50', inputs=inputs, **kwargs) 142 | 143 | 144 | def seresnet101_retinanet(num_classes, inputs=None, **kwargs): 145 | return senet_retinanet(num_classes=num_classes, backbone='seresnet101', inputs=inputs, **kwargs) 146 | 147 | 148 | def seresnet152_retinanet(num_classes, inputs=None, **kwargs): 149 | return senet_retinanet(num_classes=num_classes, backbone='seresnet152', inputs=inputs, **kwargs) 150 | 151 | 152 | def seresnext50_retinanet(num_classes, inputs=None, **kwargs): 153 | return senet_retinanet(num_classes=num_classes, backbone='seresnext50', inputs=inputs, **kwargs) 154 | 155 | 156 | def seresnext101_retinanet(num_classes, inputs=None, **kwargs): 157 | return senet_retinanet(num_classes=num_classes, backbone='seresnext101', inputs=inputs, **kwargs) 158 | 159 | 160 | def senet154_retinanet(num_classes, inputs=None, **kwargs): 161 | return senet_retinanet(num_classes=num_classes, backbone='senet154', inputs=inputs, **kwargs) 162 | -------------------------------------------------------------------------------- /keras_retinanet/models/vgg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 cgratie (https://github.com/cgratie/) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | from tensorflow import keras 19 | 20 | from . import retinanet 21 | from . import Backbone 22 | from ..utils.image import preprocess_image 23 | 24 | 25 | class VGGBackbone(Backbone): 26 | """ Describes backbone information and provides utility functions. 27 | """ 28 | 29 | def retinanet(self, *args, **kwargs): 30 | """ Returns a retinanet model using the correct backbone. 31 | """ 32 | return vgg_retinanet(*args, backbone=self.backbone, **kwargs) 33 | 34 | def download_imagenet(self): 35 | """ Downloads ImageNet weights and returns path to weights file. 36 | Weights can be downloaded at https://github.com/fizyr/keras-models/releases . 
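# ---------------------------------------------------------------------------
# Editor's note: the SE variants are also reachable through the generic
# dispatcher, which routes 'seresnext'/'seresnet'/'senet' names to SeBackbone.
from keras_retinanet import models

model = models.backbone('seresnext50').retinanet(num_classes=20)
# ---------------------------------------------------------------------------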
37 | """ 38 | if self.backbone == 'vgg16': 39 | resource = keras.applications.vgg16.vgg16.WEIGHTS_PATH_NO_TOP 40 | checksum = '6d6bbae143d832006294945121d1f1fc' 41 | elif self.backbone == 'vgg19': 42 | resource = keras.applications.vgg19.vgg19.WEIGHTS_PATH_NO_TOP 43 | checksum = '253f8cb515780f3b799900260a226db6' 44 | else: 45 | raise ValueError("Backbone '{}' not recognized.".format(self.backbone)) 46 | 47 | return keras.utils.get_file( 48 | '{}_weights_tf_dim_ordering_tf_kernels_notop.h5'.format(self.backbone), 49 | resource, 50 | cache_subdir='models', 51 | file_hash=checksum 52 | ) 53 | 54 | def validate(self): 55 | """ Checks whether the backbone string is correct. 56 | """ 57 | allowed_backbones = ['vgg16', 'vgg19'] 58 | 59 | if self.backbone not in allowed_backbones: 60 | raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(self.backbone, allowed_backbones)) 61 | 62 | def preprocess_image(self, inputs): 63 | """ Takes as input an image and prepares it for being passed through the network. 64 | """ 65 | return preprocess_image(inputs, mode='caffe') 66 | 67 | 68 | def vgg_retinanet(num_classes, backbone='vgg16', inputs=None, modifier=None, **kwargs): 69 | """ Constructs a retinanet model using a vgg backbone. 70 | 71 | Args 72 | num_classes: Number of classes to predict. 73 | backbone: Which backbone to use (one of ('vgg16', 'vgg19')). 74 | inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)). 75 | modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example). 76 | 77 | Returns 78 | RetinaNet model with a VGG backbone. 79 | """ 80 | # choose default input 81 | if inputs is None: 82 | inputs = keras.layers.Input(shape=(None, None, 3)) 83 | 84 | # create the vgg backbone 85 | if backbone == 'vgg16': 86 | vgg = keras.applications.VGG16(input_tensor=inputs, include_top=False, weights=None) 87 | elif backbone == 'vgg19': 88 | vgg = keras.applications.VGG19(input_tensor=inputs, include_top=False, weights=None) 89 | else: 90 | raise ValueError("Backbone '{}' not recognized.".format(backbone)) 91 | 92 | if modifier: 93 | vgg = modifier(vgg) 94 | 95 | # create the full model 96 | layer_names = ["block3_pool", "block4_pool", "block5_pool"] 97 | layer_outputs = [vgg.get_layer(name).output for name in layer_names] 98 | 99 | # C2 not provided 100 | backbone_layers = { 101 | 'C3': layer_outputs[0], 102 | 'C4': layer_outputs[1], 103 | 'C5': layer_outputs[2] 104 | } 105 | 106 | return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=backbone_layers, **kwargs) 107 | -------------------------------------------------------------------------------- /keras_retinanet/preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fizyr/keras-retinanet/7ac91dfbbacce77d6d9633fc09e16cd0ee71fd5e/keras_retinanet/preprocessing/__init__.py -------------------------------------------------------------------------------- /keras_retinanet/preprocessing/coco.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from ..preprocessing.generator import Generator 18 | from ..utils.image import read_image_bgr 19 | 20 | import os 21 | import numpy as np 22 | 23 | from pycocotools.coco import COCO 24 | 25 | 26 | class CocoGenerator(Generator): 27 | """ Generate data from the COCO dataset. 28 | 29 | See https://github.com/cocodataset/cocoapi/tree/master/PythonAPI for more information. 30 | """ 31 | 32 | def __init__(self, data_dir, set_name, **kwargs): 33 | """ Initialize a COCO data generator. 34 | 35 | Args 36 | data_dir: Path to where the COCO dataset is stored. 37 | set_name: Name of the set to parse. 38 | """ 39 | self.data_dir = data_dir 40 | self.set_name = set_name 41 | self.coco = COCO(os.path.join(data_dir, 'annotations', 'instances_' + set_name + '.json')) 42 | self.image_ids = self.coco.getImgIds() 43 | 44 | self.load_classes() 45 | 46 | super(CocoGenerator, self).__init__(**kwargs) 47 | 48 | def load_classes(self): 49 | """ Loads the class to label mapping (and inverse) for COCO. 50 | """ 51 | # load class names (name -> label) 52 | categories = self.coco.loadCats(self.coco.getCatIds()) 53 | categories.sort(key=lambda x: x['id']) 54 | 55 | self.classes = {} 56 | self.coco_labels = {} 57 | self.coco_labels_inverse = {} 58 | for c in categories: 59 | self.coco_labels[len(self.classes)] = c['id'] 60 | self.coco_labels_inverse[c['id']] = len(self.classes) 61 | self.classes[c['name']] = len(self.classes) 62 | 63 | # also load the reverse (label -> name) 64 | self.labels = {} 65 | for key, value in self.classes.items(): 66 | self.labels[value] = key 67 | 68 | def size(self): 69 | """ Size of the COCO dataset. 70 | """ 71 | return len(self.image_ids) 72 | 73 | def num_classes(self): 74 | """ Number of classes in the dataset. For COCO this is 80. 75 | """ 76 | return len(self.classes) 77 | 78 | def has_label(self, label): 79 | """ Return True if label is a known label. 80 | """ 81 | return label in self.labels 82 | 83 | def has_name(self, name): 84 | """ Returns True if name is a known class. 85 | """ 86 | return name in self.classes 87 | 88 | def name_to_label(self, name): 89 | """ Map name to label. 90 | """ 91 | return self.classes[name] 92 | 93 | def label_to_name(self, label): 94 | """ Map label to name. 95 | """ 96 | return self.labels[label] 97 | 98 | def coco_label_to_label(self, coco_label): 99 | """ Map COCO label to the label as used in the network. 100 | COCO has some gaps in the order of labels. The highest label is 90, but there are 80 classes. 101 | """ 102 | return self.coco_labels_inverse[coco_label] 103 | 104 | def coco_label_to_name(self, coco_label): 105 | """ Map COCO label to name. 106 | """ 107 | return self.label_to_name(self.coco_label_to_label(coco_label)) 108 | 109 | def label_to_coco_label(self, label): 110 | """ Map label as used by the network to labels as used by COCO. 111 | """ 112 | return self.coco_labels[label] 113 | 114 | def image_path(self, image_index): 115 | """ Returns the image path for image_index. 
116 | """ 117 | image_info = self.coco.loadImgs(self.image_ids[image_index])[0] 118 | path = os.path.join(self.data_dir, 'images', self.set_name, image_info['file_name']) 119 | return path 120 | 121 | def image_aspect_ratio(self, image_index): 122 | """ Compute the aspect ratio for an image with image_index. 123 | """ 124 | image = self.coco.loadImgs(self.image_ids[image_index])[0] 125 | return float(image['width']) / float(image['height']) 126 | 127 | def load_image(self, image_index): 128 | """ Load an image at the image_index. 129 | """ 130 | path = self.image_path(image_index) 131 | return read_image_bgr(path) 132 | 133 | def load_annotations(self, image_index): 134 | """ Load annotations for an image_index. 135 | """ 136 | # get ground truth annotations 137 | annotations_ids = self.coco.getAnnIds(imgIds=self.image_ids[image_index], iscrowd=False) 138 | annotations = {'labels': np.empty((0,)), 'bboxes': np.empty((0, 4))} 139 | 140 | # some images appear to miss annotations (like image with id 257034) 141 | if len(annotations_ids) == 0: 142 | return annotations 143 | 144 | # parse annotations 145 | coco_annotations = self.coco.loadAnns(annotations_ids) 146 | for idx, a in enumerate(coco_annotations): 147 | # some annotations have basically no width / height, skip them 148 | if a['bbox'][2] < 1 or a['bbox'][3] < 1: 149 | continue 150 | 151 | annotations['labels'] = np.concatenate([annotations['labels'], [self.coco_label_to_label(a['category_id'])]], axis=0) 152 | annotations['bboxes'] = np.concatenate([annotations['bboxes'], [[ 153 | a['bbox'][0], 154 | a['bbox'][1], 155 | a['bbox'][0] + a['bbox'][2], 156 | a['bbox'][1] + a['bbox'][3], 157 | ]]], axis=0) 158 | 159 | return annotations 160 | -------------------------------------------------------------------------------- /keras_retinanet/preprocessing/csv_generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 yhenon (https://github.com/yhenon/) 3 | Copyright 2017-2018 Fizyr (https://fizyr.com) 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | """ 17 | 18 | from .generator import Generator 19 | from ..utils.image import read_image_bgr 20 | 21 | import numpy as np 22 | from PIL import Image 23 | from six import raise_from 24 | 25 | import csv 26 | import sys 27 | import os.path 28 | from collections import OrderedDict 29 | 30 | 31 | def _parse(value, function, fmt): 32 | """ 33 | Parse a string into a value, and format a nice ValueError if it fails. 34 | 35 | Returns `function(value)`. 36 | Any `ValueError` raised is catched and a new `ValueError` is raised 37 | with message `fmt.format(e)`, where `e` is the caught `ValueError`. 38 | """ 39 | try: 40 | return function(value) 41 | except ValueError as e: 42 | raise_from(ValueError(fmt.format(e)), None) 43 | 44 | 45 | def _read_classes(csv_reader): 46 | """ Parse the classes file given by csv_reader. 
47 | """ 48 | result = OrderedDict() 49 | for line, row in enumerate(csv_reader): 50 | line += 1 51 | 52 | try: 53 | class_name, class_id = row 54 | except ValueError: 55 | raise_from(ValueError('line {}: format should be \'class_name,class_id\''.format(line)), None) 56 | class_id = _parse(class_id, int, 'line {}: malformed class ID: {{}}'.format(line)) 57 | 58 | if class_name in result: 59 | raise ValueError('line {}: duplicate class name: \'{}\''.format(line, class_name)) 60 | result[class_name] = class_id 61 | return result 62 | 63 | 64 | def _read_annotations(csv_reader, classes): 65 | """ Read annotations from the csv_reader. 66 | """ 67 | result = OrderedDict() 68 | for line, row in enumerate(csv_reader): 69 | line += 1 70 | 71 | try: 72 | img_file, x1, y1, x2, y2, class_name = row[:6] 73 | except ValueError: 74 | raise_from(ValueError('line {}: format should be \'img_file,x1,y1,x2,y2,class_name\' or \'img_file,,,,,\''.format(line)), None) 75 | 76 | if img_file not in result: 77 | result[img_file] = [] 78 | 79 | # If a row contains only an image path, it's an image without annotations. 80 | if (x1, y1, x2, y2, class_name) == ('', '', '', '', ''): 81 | continue 82 | 83 | x1 = _parse(x1, int, 'line {}: malformed x1: {{}}'.format(line)) 84 | y1 = _parse(y1, int, 'line {}: malformed y1: {{}}'.format(line)) 85 | x2 = _parse(x2, int, 'line {}: malformed x2: {{}}'.format(line)) 86 | y2 = _parse(y2, int, 'line {}: malformed y2: {{}}'.format(line)) 87 | 88 | # Check that the bounding box is valid. 89 | if x2 <= x1: 90 | raise ValueError('line {}: x2 ({}) must be higher than x1 ({})'.format(line, x2, x1)) 91 | if y2 <= y1: 92 | raise ValueError('line {}: y2 ({}) must be higher than y1 ({})'.format(line, y2, y1)) 93 | 94 | # check if the current class name is correctly present 95 | if class_name not in classes: 96 | raise ValueError('line {}: unknown class name: \'{}\' (classes: {})'.format(line, class_name, classes)) 97 | 98 | result[img_file].append({'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2, 'class': class_name}) 99 | return result 100 | 101 | 102 | def _open_for_csv(path): 103 | """ Open a file with flags suitable for csv.reader. 104 | 105 | This is different for python2 it means with mode 'rb', 106 | for python3 this means 'r' with "universal newlines". 107 | """ 108 | if sys.version_info[0] < 3: 109 | return open(path, 'rb') 110 | else: 111 | return open(path, 'r', newline='') 112 | 113 | 114 | class CSVGenerator(Generator): 115 | """ Generate data for a custom CSV dataset. 116 | 117 | See https://github.com/fizyr/keras-retinanet#csv-datasets for more information. 118 | """ 119 | 120 | def __init__( 121 | self, 122 | csv_data_file, 123 | csv_class_file, 124 | base_dir=None, 125 | **kwargs 126 | ): 127 | """ Initialize a CSV data generator. 128 | 129 | Args 130 | csv_data_file: Path to the CSV annotations file. 131 | csv_class_file: Path to the CSV classes file. 132 | base_dir: Directory w.r.t. where the files are to be searched (defaults to the directory containing the csv_data_file). 133 | """ 134 | self.image_names = [] 135 | self.image_data = {} 136 | self.base_dir = base_dir 137 | 138 | # Take base_dir from annotations file if not explicitly specified. 
139 | if self.base_dir is None: 140 | self.base_dir = os.path.dirname(csv_data_file) 141 | 142 | # parse the provided class file 143 | try: 144 | with _open_for_csv(csv_class_file) as file: 145 | self.classes = _read_classes(csv.reader(file, delimiter=',')) 146 | except ValueError as e: 147 | raise_from(ValueError('invalid CSV class file: {}: {}'.format(csv_class_file, e)), None) 148 | 149 | self.labels = {} 150 | for key, value in self.classes.items(): 151 | self.labels[value] = key 152 | 153 | # csv with img_path, x1, y1, x2, y2, class_name 154 | try: 155 | with _open_for_csv(csv_data_file) as file: 156 | self.image_data = _read_annotations(csv.reader(file, delimiter=','), self.classes) 157 | except ValueError as e: 158 | raise_from(ValueError('invalid CSV annotations file: {}: {}'.format(csv_data_file, e)), None) 159 | self.image_names = list(self.image_data.keys()) 160 | 161 | super(CSVGenerator, self).__init__(**kwargs) 162 | 163 | def size(self): 164 | """ Size of the dataset. 165 | """ 166 | return len(self.image_names) 167 | 168 | def num_classes(self): 169 | """ Number of classes in the dataset. 170 | """ 171 | return max(self.classes.values()) + 1 172 | 173 | def has_label(self, label): 174 | """ Return True if label is a known label. 175 | """ 176 | return label in self.labels 177 | 178 | def has_name(self, name): 179 | """ Returns True if name is a known class. 180 | """ 181 | return name in self.classes 182 | 183 | def name_to_label(self, name): 184 | """ Map name to label. 185 | """ 186 | return self.classes[name] 187 | 188 | def label_to_name(self, label): 189 | """ Map label to name. 190 | """ 191 | return self.labels[label] 192 | 193 | def image_path(self, image_index): 194 | """ Returns the image path for image_index. 195 | """ 196 | return os.path.join(self.base_dir, self.image_names[image_index]) 197 | 198 | def image_aspect_ratio(self, image_index): 199 | """ Compute the aspect ratio for an image with image_index. 200 | """ 201 | # PIL is fast for metadata 202 | image = Image.open(self.image_path(image_index)) 203 | return float(image.width) / float(image.height) 204 | 205 | def load_image(self, image_index): 206 | """ Load an image at the image_index. 207 | """ 208 | return read_image_bgr(self.image_path(image_index)) 209 | 210 | def load_annotations(self, image_index): 211 | """ Load annotations for an image_index. 212 | """ 213 | path = self.image_names[image_index] 214 | annotations = {'labels': np.empty((0,)), 'bboxes': np.empty((0, 4))} 215 | 216 | for idx, annot in enumerate(self.image_data[path]): 217 | annotations['labels'] = np.concatenate((annotations['labels'], [self.name_to_label(annot['class'])])) 218 | annotations['bboxes'] = np.concatenate((annotations['bboxes'], [[ 219 | float(annot['x1']), 220 | float(annot['y1']), 221 | float(annot['x2']), 222 | float(annot['y2']), 223 | ]])) 224 | 225 | return annotations 226 | -------------------------------------------------------------------------------- /keras_retinanet/preprocessing/kitti.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 lvaleriu (https://github.com/lvaleriu/) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
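# ---------------------------------------------------------------------------
# Editor's note: the CSV format accepted by _read_annotations/_read_classes
# above, with hypothetical file contents:
#
#   annotations.csv (img_file,x1,y1,x2,y2,class_name; an all-empty box marks
#   a background-only image):
#       images/img_001.jpg,837,346,981,456,cow
#       images/img_002.jpg,,,,,
#   classes.csv (class_name,class_id):
#       cow,0
from keras_retinanet.preprocessing.csv_generator import CSVGenerator

generator = CSVGenerator('annotations.csv', 'classes.csv')
# ---------------------------------------------------------------------------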
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import csv 18 | import os.path 19 | 20 | import numpy as np 21 | from PIL import Image 22 | 23 | from .generator import Generator 24 | from ..utils.image import read_image_bgr 25 | 26 | kitti_classes = { 27 | 'Car': 0, 28 | 'Van': 1, 29 | 'Truck': 2, 30 | 'Pedestrian': 3, 31 | 'Person_sitting': 4, 32 | 'Cyclist': 5, 33 | 'Tram': 6, 34 | 'Misc': 7, 35 | 'DontCare': 7 36 | } 37 | 38 | 39 | class KittiGenerator(Generator): 40 | """ Generate data for a KITTI dataset. 41 | 42 | See http://www.cvlibs.net/datasets/kitti/ for more information. 43 | """ 44 | 45 | def __init__( 46 | self, 47 | base_dir, 48 | subset='train', 49 | **kwargs 50 | ): 51 | """ Initialize a KITTI data generator. 52 | 53 | Args 54 | base_dir: Directory w.r.t. where the files are to be searched (defaults to the directory containing the csv_data_file). 55 | subset: The subset to generate data for (defaults to 'train'). 56 | """ 57 | self.base_dir = base_dir 58 | 59 | label_dir = os.path.join(self.base_dir, subset, 'labels') 60 | image_dir = os.path.join(self.base_dir, subset, 'images') 61 | 62 | """ 63 | 1 type Describes the type of object: 'Car', 'Van', 'Truck', 64 | 'Pedestrian', 'Person_sitting', 'Cyclist', 'Tram', 65 | 'Misc' or 'DontCare' 66 | 1 truncated Float from 0 (non-truncated) to 1 (truncated), where 67 | truncated refers to the object leaving image boundaries 68 | 1 occluded Integer (0,1,2,3) indicating occlusion state: 69 | 0 = fully visible, 1 = partly occluded 70 | 2 = largely occluded, 3 = unknown 71 | 1 alpha Observation angle of object, ranging [-pi..pi] 72 | 4 bbox 2D bounding box of object in the image (0-based index): 73 | contains left, top, right, bottom pixel coordinates 74 | 3 dimensions 3D object dimensions: height, width, length (in meters) 75 | 3 location 3D object location x,y,z in camera coordinates (in meters) 76 | 1 rotation_y Rotation ry around Y-axis in camera coordinates [-pi..pi] 77 | """ 78 | 79 | self.labels = {} 80 | self.classes = kitti_classes 81 | for name, label in self.classes.items(): 82 | self.labels[label] = name 83 | 84 | self.image_data = dict() 85 | self.images = [] 86 | for i, fn in enumerate(os.listdir(label_dir)): 87 | label_fp = os.path.join(label_dir, fn) 88 | image_fp = os.path.join(image_dir, fn.replace('.txt', '.png')) 89 | 90 | self.images.append(image_fp) 91 | 92 | fieldnames = ['type', 'truncated', 'occluded', 'alpha', 'left', 'top', 'right', 'bottom', 'dh', 'dw', 'dl', 93 | 'lx', 'ly', 'lz', 'ry'] 94 | with open(label_fp, 'r') as csv_file: 95 | reader = csv.DictReader(csv_file, delimiter=' ', fieldnames=fieldnames) 96 | boxes = [] 97 | for line, row in enumerate(reader): 98 | label = row['type'] 99 | cls_id = kitti_classes[label] 100 | 101 | annotation = {'cls_id': cls_id, 'x1': row['left'], 'x2': row['right'], 'y2': row['bottom'], 'y1': row['top']} 102 | boxes.append(annotation) 103 | 104 | self.image_data[i] = boxes 105 | 106 | super(KittiGenerator, self).__init__(**kwargs) 107 | 108 | def size(self): 109 | """ Size of the dataset. 
110 | """ 111 | return len(self.images) 112 | 113 | def num_classes(self): 114 | """ Number of classes in the dataset. 115 | """ 116 | return max(self.classes.values()) + 1 117 | 118 | def has_label(self, label): 119 | """ Return True if label is a known label. 120 | """ 121 | return label in self.labels 122 | 123 | def has_name(self, name): 124 | """ Returns True if name is a known class. 125 | """ 126 | return name in self.classes 127 | 128 | def name_to_label(self, name): 129 | """ Map name to label. 130 | """ 131 | raise NotImplementedError() 132 | 133 | def label_to_name(self, label): 134 | """ Map label to name. 135 | """ 136 | return self.labels[label] 137 | 138 | def image_aspect_ratio(self, image_index): 139 | """ Compute the aspect ratio for an image with image_index. 140 | """ 141 | # PIL is fast for metadata 142 | image = Image.open(self.images[image_index]) 143 | return float(image.width) / float(image.height) 144 | 145 | def image_path(self, image_index): 146 | """ Get the path to an image. 147 | """ 148 | return self.images[image_index] 149 | 150 | def load_image(self, image_index): 151 | """ Load an image at the image_index. 152 | """ 153 | return read_image_bgr(self.image_path(image_index)) 154 | 155 | def load_annotations(self, image_index): 156 | """ Load annotations for an image_index. 157 | """ 158 | image_data = self.image_data[image_index] 159 | annotations = {'labels': np.empty((len(image_data),)), 'bboxes': np.empty((len(image_data), 4))} 160 | 161 | for idx, ann in enumerate(image_data): 162 | annotations['bboxes'][idx, 0] = float(ann['x1']) 163 | annotations['bboxes'][idx, 1] = float(ann['y1']) 164 | annotations['bboxes'][idx, 2] = float(ann['x2']) 165 | annotations['bboxes'][idx, 3] = float(ann['y2']) 166 | annotations['labels'][idx] = int(ann['cls_id']) 167 | 168 | return annotations 169 | -------------------------------------------------------------------------------- /keras_retinanet/preprocessing/pascal_voc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | from ..preprocessing.generator import Generator 18 | from ..utils.image import read_image_bgr 19 | 20 | import os 21 | import numpy as np 22 | from six import raise_from 23 | from PIL import Image 24 | 25 | try: 26 | import xml.etree.cElementTree as ET 27 | except ImportError: 28 | import xml.etree.ElementTree as ET 29 | 30 | voc_classes = { 31 | 'aeroplane' : 0, 32 | 'bicycle' : 1, 33 | 'bird' : 2, 34 | 'boat' : 3, 35 | 'bottle' : 4, 36 | 'bus' : 5, 37 | 'car' : 6, 38 | 'cat' : 7, 39 | 'chair' : 8, 40 | 'cow' : 9, 41 | 'diningtable' : 10, 42 | 'dog' : 11, 43 | 'horse' : 12, 44 | 'motorbike' : 13, 45 | 'person' : 14, 46 | 'pottedplant' : 15, 47 | 'sheep' : 16, 48 | 'sofa' : 17, 49 | 'train' : 18, 50 | 'tvmonitor' : 19 51 | } 52 | 53 | 54 | def _findNode(parent, name, debug_name=None, parse=None): 55 | if debug_name is None: 56 | debug_name = name 57 | 58 | result = parent.find(name) 59 | if result is None: 60 | raise ValueError('missing element \'{}\''.format(debug_name)) 61 | if parse is not None: 62 | try: 63 | return parse(result.text) 64 | except ValueError as e: 65 | raise_from(ValueError('illegal value for \'{}\': {}'.format(debug_name, e)), None) 66 | return result 67 | 68 | 69 | class PascalVocGenerator(Generator): 70 | """ Generate data for a Pascal VOC dataset. 71 | 72 | See http://host.robots.ox.ac.uk/pascal/VOC/ for more information. 73 | """ 74 | 75 | def __init__( 76 | self, 77 | data_dir, 78 | set_name, 79 | classes=voc_classes, 80 | image_extension='.jpg', 81 | skip_truncated=False, 82 | skip_difficult=False, 83 | **kwargs 84 | ): 85 | """ Initialize a Pascal VOC data generator. 86 | 87 | Args 88 | base_dir: Directory w.r.t. where the files are to be searched (defaults to the directory containing the csv_data_file). 89 | csv_class_file: Path to the CSV classes file. 90 | """ 91 | self.data_dir = data_dir 92 | self.set_name = set_name 93 | self.classes = classes 94 | self.image_names = [line.strip().split(None, 1)[0] for line in open(os.path.join(data_dir, 'ImageSets', 'Main', set_name + '.txt')).readlines()] 95 | self.image_extension = image_extension 96 | self.skip_truncated = skip_truncated 97 | self.skip_difficult = skip_difficult 98 | 99 | self.labels = {} 100 | for key, value in self.classes.items(): 101 | self.labels[value] = key 102 | 103 | super(PascalVocGenerator, self).__init__(**kwargs) 104 | 105 | def size(self): 106 | """ Size of the dataset. 107 | """ 108 | return len(self.image_names) 109 | 110 | def num_classes(self): 111 | """ Number of classes in the dataset. 112 | """ 113 | return len(self.classes) 114 | 115 | def has_label(self, label): 116 | """ Return True if label is a known label. 117 | """ 118 | return label in self.labels 119 | 120 | def has_name(self, name): 121 | """ Returns True if name is a known class. 122 | """ 123 | return name in self.classes 124 | 125 | def name_to_label(self, name): 126 | """ Map name to label. 127 | """ 128 | return self.classes[name] 129 | 130 | def label_to_name(self, label): 131 | """ Map label to name. 132 | """ 133 | return self.labels[label] 134 | 135 | def image_aspect_ratio(self, image_index): 136 | """ Compute the aspect ratio for an image with image_index. 137 | """ 138 | path = os.path.join(self.data_dir, 'JPEGImages', self.image_names[image_index] + self.image_extension) 139 | image = Image.open(path) 140 | return float(image.width) / float(image.height) 141 | 142 | def image_path(self, image_index): 143 | """ Get the path to an image. 
144 | """ 145 | return os.path.join(self.data_dir, 'JPEGImages', self.image_names[image_index] + self.image_extension) 146 | 147 | def load_image(self, image_index): 148 | """ Load an image at the image_index. 149 | """ 150 | return read_image_bgr(self.image_path(image_index)) 151 | 152 | def __parse_annotation(self, element): 153 | """ Parse an annotation given an XML element. 154 | """ 155 | truncated = _findNode(element, 'truncated', parse=int) 156 | difficult = _findNode(element, 'difficult', parse=int) 157 | 158 | class_name = _findNode(element, 'name').text 159 | if class_name not in self.classes: 160 | raise ValueError('class name \'{}\' not found in classes: {}'.format(class_name, list(self.classes.keys()))) 161 | 162 | box = np.zeros((4,)) 163 | label = self.name_to_label(class_name) 164 | 165 | bndbox = _findNode(element, 'bndbox') 166 | box[0] = _findNode(bndbox, 'xmin', 'bndbox.xmin', parse=float) - 1 167 | box[1] = _findNode(bndbox, 'ymin', 'bndbox.ymin', parse=float) - 1 168 | box[2] = _findNode(bndbox, 'xmax', 'bndbox.xmax', parse=float) - 1 169 | box[3] = _findNode(bndbox, 'ymax', 'bndbox.ymax', parse=float) - 1 170 | 171 | return truncated, difficult, box, label 172 | 173 | def __parse_annotations(self, xml_root): 174 | """ Parse all annotations under the xml_root. 175 | """ 176 | annotations = {'labels': np.empty((len(xml_root.findall('object')),)), 'bboxes': np.empty((len(xml_root.findall('object')), 4))} 177 | for i, element in enumerate(xml_root.iter('object')): 178 | try: 179 | truncated, difficult, box, label = self.__parse_annotation(element) 180 | except ValueError as e: 181 | raise_from(ValueError('could not parse object #{}: {}'.format(i, e)), None) 182 | 183 | if truncated and self.skip_truncated: 184 | continue 185 | if difficult and self.skip_difficult: 186 | continue 187 | 188 | annotations['bboxes'][i, :] = box 189 | annotations['labels'][i] = label 190 | 191 | return annotations 192 | 193 | def load_annotations(self, image_index): 194 | """ Load annotations for an image_index. 195 | """ 196 | filename = self.image_names[image_index] + '.xml' 197 | try: 198 | tree = ET.parse(os.path.join(self.data_dir, 'Annotations', filename)) 199 | return self.__parse_annotations(tree.getroot()) 200 | except ET.ParseError as e: 201 | raise_from(ValueError('invalid annotations file: {}: {}'.format(filename, e)), None) 202 | except ValueError as e: 203 | raise_from(ValueError('invalid annotations file: {}: {}'.format(filename, e)), None) 204 | -------------------------------------------------------------------------------- /keras_retinanet/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fizyr/keras-retinanet/7ac91dfbbacce77d6d9633fc09e16cd0ee71fd5e/keras_retinanet/utils/__init__.py -------------------------------------------------------------------------------- /keras_retinanet/utils/coco_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from pycocotools.cocoeval import COCOeval 18 | 19 | from tensorflow import keras 20 | import numpy as np 21 | import json 22 | 23 | import progressbar 24 | assert(callable(progressbar.progressbar)), "Using wrong progressbar module, install 'progressbar2' instead." 25 | 26 | 27 | def evaluate_coco(generator, model, threshold=0.05): 28 | """ Use the pycocotools to evaluate a COCO model on a dataset. 29 | 30 | Args 31 | generator : The generator for generating the evaluation data. 32 | model : The model to evaluate. 33 | threshold : The score threshold to use. 34 | """ 35 | # start collecting results 36 | results = [] 37 | image_ids = [] 38 | for index in progressbar.progressbar(range(generator.size()), prefix='COCO evaluation: '): 39 | image = generator.load_image(index) 40 | image = generator.preprocess_image(image) 41 | image, scale = generator.resize_image(image) 42 | 43 | if keras.backend.image_data_format() == 'channels_first': 44 | image = image.transpose((2, 0, 1)) 45 | 46 | # run network 47 | boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0)) 48 | 49 | # correct boxes for image scale 50 | boxes /= scale 51 | 52 | # change to (x, y, w, h) (MS COCO standard) 53 | boxes[:, :, 2] -= boxes[:, :, 0] 54 | boxes[:, :, 3] -= boxes[:, :, 1] 55 | 56 | # compute predicted labels and scores 57 | for box, score, label in zip(boxes[0], scores[0], labels[0]): 58 | # scores are sorted, so we can break 59 | if score < threshold: 60 | break 61 | 62 | # append detection for each positively labeled class 63 | image_result = { 64 | 'image_id' : generator.image_ids[index], 65 | 'category_id' : generator.label_to_coco_label(label), 66 | 'score' : float(score), 67 | 'bbox' : box.tolist(), 68 | } 69 | 70 | # append detection to results 71 | results.append(image_result) 72 | 73 | # append image to list of processed images 74 | image_ids.append(generator.image_ids[index]) 75 | 76 | if not len(results): 77 | return 78 | 79 | # write output 80 | json.dump(results, open('{}_bbox_results.json'.format(generator.set_name), 'w'), indent=4) 81 | json.dump(image_ids, open('{}_processed_image_ids.json'.format(generator.set_name), 'w'), indent=4) 82 | 83 | # load results in COCO evaluation tool 84 | coco_true = generator.coco 85 | coco_pred = coco_true.loadRes('{}_bbox_results.json'.format(generator.set_name)) 86 | 87 | # run COCO evaluation 88 | coco_eval = COCOeval(coco_true, coco_pred, 'bbox') 89 | coco_eval.params.imgIds = image_ids 90 | coco_eval.evaluate() 91 | coco_eval.accumulate() 92 | coco_eval.summarize() 93 | return coco_eval.stats 94 | -------------------------------------------------------------------------------- /keras_retinanet/utils/colors.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | def label_color(label): 5 | """ Return a color from a set of predefined colors. Contains 80 colors in total. 6 | 7 | Args 8 | label: The label to get the color for. 9 | 10 | Returns 11 | A list of three values representing a RGB color. 12 | 13 | If no color is defined for a certain label, the color green is returned and a warning is printed. 
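Example
    label_color(0) returns the first predefined color, [31, 0, 255]; an
    out-of-range label such as label_color(80) warns and returns the
    default (0, 255, 0).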
14 | """ 15 | if label < len(colors): 16 | return colors[label] 17 | else: 18 | warnings.warn('Label {} has no color, returning default.'.format(label)) 19 | return (0, 255, 0) 20 | 21 | 22 | """ 23 | Generated using: 24 | 25 | ``` 26 | colors = [list((matplotlib.colors.hsv_to_rgb([x, 1.0, 1.0]) * 255).astype(int)) for x in np.arange(0, 1, 1.0 / 80)] 27 | shuffle(colors) 28 | pprint(colors) 29 | ``` 30 | """ 31 | colors = [ 32 | [31 , 0 , 255] , 33 | [0 , 159 , 255] , 34 | [255 , 95 , 0] , 35 | [255 , 19 , 0] , 36 | [255 , 0 , 0] , 37 | [255 , 38 , 0] , 38 | [0 , 255 , 25] , 39 | [255 , 0 , 133] , 40 | [255 , 172 , 0] , 41 | [108 , 0 , 255] , 42 | [0 , 82 , 255] , 43 | [0 , 255 , 6] , 44 | [255 , 0 , 152] , 45 | [223 , 0 , 255] , 46 | [12 , 0 , 255] , 47 | [0 , 255 , 178] , 48 | [108 , 255 , 0] , 49 | [184 , 0 , 255] , 50 | [255 , 0 , 76] , 51 | [146 , 255 , 0] , 52 | [51 , 0 , 255] , 53 | [0 , 197 , 255] , 54 | [255 , 248 , 0] , 55 | [255 , 0 , 19] , 56 | [255 , 0 , 38] , 57 | [89 , 255 , 0] , 58 | [127 , 255 , 0] , 59 | [255 , 153 , 0] , 60 | [0 , 255 , 255] , 61 | [0 , 255 , 216] , 62 | [0 , 255 , 121] , 63 | [255 , 0 , 248] , 64 | [70 , 0 , 255] , 65 | [0 , 255 , 159] , 66 | [0 , 216 , 255] , 67 | [0 , 6 , 255] , 68 | [0 , 63 , 255] , 69 | [31 , 255 , 0] , 70 | [255 , 57 , 0] , 71 | [255 , 0 , 210] , 72 | [0 , 255 , 102] , 73 | [242 , 255 , 0] , 74 | [255 , 191 , 0] , 75 | [0 , 255 , 63] , 76 | [255 , 0 , 95] , 77 | [146 , 0 , 255] , 78 | [184 , 255 , 0] , 79 | [255 , 114 , 0] , 80 | [0 , 255 , 235] , 81 | [255 , 229 , 0] , 82 | [0 , 178 , 255] , 83 | [255 , 0 , 114] , 84 | [255 , 0 , 57] , 85 | [0 , 140 , 255] , 86 | [0 , 121 , 255] , 87 | [12 , 255 , 0] , 88 | [255 , 210 , 0] , 89 | [0 , 255 , 44] , 90 | [165 , 255 , 0] , 91 | [0 , 25 , 255] , 92 | [0 , 255 , 140] , 93 | [0 , 101 , 255] , 94 | [0 , 255 , 82] , 95 | [223 , 255 , 0] , 96 | [242 , 0 , 255] , 97 | [89 , 0 , 255] , 98 | [165 , 0 , 255] , 99 | [70 , 255 , 0] , 100 | [255 , 0 , 172] , 101 | [255 , 76 , 0] , 102 | [203 , 255 , 0] , 103 | [204 , 0 , 255] , 104 | [255 , 0 , 229] , 105 | [255 , 133 , 0] , 106 | [127 , 0 , 255] , 107 | [0 , 235 , 255] , 108 | [0 , 255 , 197] , 109 | [255 , 0 , 191] , 110 | [0 , 44 , 255] , 111 | [50 , 255 , 0] 112 | ] 113 | -------------------------------------------------------------------------------- /keras_retinanet/utils/compute_overlap.pyx: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Sergey Karayev 6 | # -------------------------------------------------------- 7 | 8 | cimport cython 9 | import numpy as np 10 | cimport numpy as np 11 | 12 | 13 | def compute_overlap( 14 | np.ndarray[double, ndim=2] boxes, 15 | np.ndarray[double, ndim=2] query_boxes 16 | ): 17 | """ 18 | Args 19 | a: (N, 4) ndarray of float 20 | b: (K, 4) ndarray of float 21 | 22 | Returns 23 | overlaps: (N, K) ndarray of overlap between boxes and query_boxes 24 | """ 25 | cdef unsigned int N = boxes.shape[0] 26 | cdef unsigned int K = query_boxes.shape[0] 27 | cdef np.ndarray[double, ndim=2] overlaps = np.zeros((N, K), dtype=np.float64) 28 | cdef double iw, ih, box_area 29 | cdef double ua 30 | cdef unsigned int k, n 31 | for k in range(K): 32 | box_area = ( 33 | (query_boxes[k, 2] - query_boxes[k, 0]) * 34 | (query_boxes[k, 3] - query_boxes[k, 1]) 35 | ) 36 | for n in range(N): 37 | iw = ( 38 | 
min(boxes[n, 2], query_boxes[k, 2]) - 39 | max(boxes[n, 0], query_boxes[k, 0]) 40 | ) 41 | if iw > 0: 42 | ih = ( 43 | min(boxes[n, 3], query_boxes[k, 3]) - 44 | max(boxes[n, 1], query_boxes[k, 1]) 45 | ) 46 | if ih > 0: 47 | ua = np.float64( 48 | (boxes[n, 2] - boxes[n, 0]) * 49 | (boxes[n, 3] - boxes[n, 1]) + 50 | box_area - iw * ih 51 | ) 52 | overlaps[n, k] = iw * ih / ua 53 | return overlaps 54 | -------------------------------------------------------------------------------- /keras_retinanet/utils/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import configparser 18 | import numpy as np 19 | from tensorflow import keras 20 | from ..utils.anchors import AnchorParameters 21 | 22 | 23 | def read_config_file(config_path): 24 | config = configparser.ConfigParser() 25 | 26 | with open(config_path, 'r') as file: 27 | config.read_file(file) 28 | 29 | assert 'anchor_parameters' in config, \ 30 | "Malformed config file. Verify that it contains the anchor_parameters section." 31 | 32 | config_keys = set(config['anchor_parameters']) 33 | default_keys = set(AnchorParameters.default.__dict__.keys()) 34 | 35 | assert config_keys <= default_keys, \ 36 | "Malformed config file. These keys are not valid: {}".format(config_keys - default_keys) 37 | 38 | if 'pyramid_levels' in config: 39 | assert('levels' in config['pyramid_levels']), "pyramid levels specified by levels key" 40 | 41 | return config 42 | 43 | 44 | def parse_anchor_parameters(config): 45 | ratios = np.array(list(map(float, config['anchor_parameters']['ratios'].split(' '))), keras.backend.floatx()) 46 | scales = np.array(list(map(float, config['anchor_parameters']['scales'].split(' '))), keras.backend.floatx()) 47 | sizes = list(map(int, config['anchor_parameters']['sizes'].split(' '))) 48 | strides = list(map(int, config['anchor_parameters']['strides'].split(' '))) 49 | assert (len(sizes) == len(strides)), "sizes and strides should have an equal number of values" 50 | 51 | return AnchorParameters(sizes, strides, ratios, scales) 52 | 53 | 54 | def parse_pyramid_levels(config): 55 | levels = list(map(int, config['pyramid_levels']['levels'].split(' '))) 56 | 57 | return levels 58 | -------------------------------------------------------------------------------- /keras_retinanet/utils/eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
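A short sketch of a config file accepted by read_config_file and parse_anchor_parameters above; the file name and the parameter values are illustrative, not defaults:

```python
from keras_retinanet.utils.config import read_config_file, parse_anchor_parameters

# write a hypothetical anchor config; values are space-separated per parser above
with open('anchors.ini', 'w') as f:
    f.write('[anchor_parameters]\n'
            'sizes = 32 64 128 256 512\n'
            'strides = 8 16 32 64 128\n'
            'ratios = 0.5 1 2\n'
            'scales = 1 1.2 1.6\n')

config = read_config_file('anchors.ini')
anchor_params = parse_anchor_parameters(config)
print(len(anchor_params.ratios), len(anchor_params.scales))  # 3 3
```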
6 | You may obtain a copy of the License at
7 | 
8 | http://www.apache.org/licenses/LICENSE-2.0
9 | 
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 | 
17 | from .anchors import compute_overlap
18 | from .visualization import draw_detections, draw_annotations
19 | 
20 | from tensorflow import keras
21 | import numpy as np
22 | import os
23 | import time
24 | 
25 | import cv2
26 | import progressbar
27 | assert(callable(progressbar.progressbar)), "Using wrong progressbar module, install 'progressbar2' instead."
28 | 
29 | 
30 | def _compute_ap(recall, precision):
31 | """ Compute the average precision, given the recall and precision curves.
32 | 
33 | Code originally from https://github.com/rbgirshick/py-faster-rcnn.
34 | 
35 | # Arguments
36 | recall: The recall curve (list).
37 | precision: The precision curve (list).
38 | # Returns
39 | The average precision as computed in py-faster-rcnn.
40 | """
41 | # correct AP calculation
42 | # first append sentinel values at both ends
43 | mrec = np.concatenate(([0.], recall, [1.]))
44 | mpre = np.concatenate(([0.], precision, [0.]))
45 | 
46 | # compute the precision envelope
47 | for i in range(mpre.size - 1, 0, -1):
48 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
49 | 
50 | # to calculate area under PR curve, look for points
51 | # where X axis (recall) changes value
52 | i = np.where(mrec[1:] != mrec[:-1])[0]
53 | 
54 | # and sum (\Delta recall) * prec
55 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
56 | return ap
57 | 
58 | 
59 | def _get_detections(generator, model, score_threshold=0.05, max_detections=100, save_path=None):
60 | """ Get the detections from the model using the generator.
61 | 
62 | The result is a list of lists such that the size is:
63 | all_detections[num_images][num_classes] = detections[num_detections, 5]
64 | 
65 | # Arguments
66 | generator : The generator used to run images through the model.
67 | model : The model to run on the images.
68 | score_threshold : The score confidence threshold to use.
69 | max_detections : The maximum number of detections to use per image.
70 | save_path : The path to save the images with visualized detections to.
71 | # Returns
72 | A list of lists containing the detections for each image in the generator.
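# Example
    An illustrative call (generator and model are assumed to be built
    elsewhere); each all_detections[i][label] is then an (n, 5) array of
    (x1, y1, x2, y2, score) rows for image i and that label:
    all_detections, all_inferences = _get_detections(generator, model, score_threshold=0.3)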
73 | """ 74 | all_detections = [[None for i in range(generator.num_classes()) if generator.has_label(i)] for j in range(generator.size())] 75 | all_inferences = [None for i in range(generator.size())] 76 | 77 | for i in progressbar.progressbar(range(generator.size()), prefix='Running network: '): 78 | raw_image = generator.load_image(i) 79 | image, scale = generator.resize_image(raw_image.copy()) 80 | image = generator.preprocess_image(image) 81 | 82 | if keras.backend.image_data_format() == 'channels_first': 83 | image = image.transpose((2, 0, 1)) 84 | 85 | # run network 86 | start = time.time() 87 | boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))[:3] 88 | inference_time = time.time() - start 89 | 90 | # correct boxes for image scale 91 | boxes /= scale 92 | 93 | # select indices which have a score above the threshold 94 | indices = np.where(scores[0, :] > score_threshold)[0] 95 | 96 | # select those scores 97 | scores = scores[0][indices] 98 | 99 | # find the order with which to sort the scores 100 | scores_sort = np.argsort(-scores)[:max_detections] 101 | 102 | # select detections 103 | image_boxes = boxes[0, indices[scores_sort], :] 104 | image_scores = scores[scores_sort] 105 | image_labels = labels[0, indices[scores_sort]] 106 | image_detections = np.concatenate([image_boxes, np.expand_dims(image_scores, axis=1), np.expand_dims(image_labels, axis=1)], axis=1) 107 | 108 | if save_path is not None: 109 | draw_annotations(raw_image, generator.load_annotations(i), label_to_name=generator.label_to_name) 110 | draw_detections(raw_image, image_boxes, image_scores, image_labels, label_to_name=generator.label_to_name, score_threshold=score_threshold) 111 | 112 | cv2.imwrite(os.path.join(save_path, '{}.png'.format(i)), raw_image) 113 | 114 | # copy detections to all_detections 115 | for label in range(generator.num_classes()): 116 | if not generator.has_label(label): 117 | continue 118 | 119 | all_detections[i][label] = image_detections[image_detections[:, -1] == label, :-1] 120 | 121 | all_inferences[i] = inference_time 122 | 123 | return all_detections, all_inferences 124 | 125 | 126 | def _get_annotations(generator): 127 | """ Get the ground truth annotations from the generator. 128 | 129 | The result is a list of lists such that the size is: 130 | all_detections[num_images][num_classes] = annotations[num_detections, 5] 131 | 132 | # Arguments 133 | generator : The generator used to retrieve ground truth annotations. 134 | # Returns 135 | A list of lists containing the annotations for each image in the generator. 136 | """ 137 | all_annotations = [[None for i in range(generator.num_classes())] for j in range(generator.size())] 138 | 139 | for i in progressbar.progressbar(range(generator.size()), prefix='Parsing annotations: '): 140 | # load the annotations 141 | annotations = generator.load_annotations(i) 142 | 143 | # copy detections to all_annotations 144 | for label in range(generator.num_classes()): 145 | if not generator.has_label(label): 146 | continue 147 | 148 | all_annotations[i][label] = annotations['bboxes'][annotations['labels'] == label, :].copy() 149 | 150 | return all_annotations 151 | 152 | 153 | def evaluate( 154 | generator, 155 | model, 156 | iou_threshold=0.5, 157 | score_threshold=0.05, 158 | max_detections=100, 159 | save_path=None 160 | ): 161 | """ Evaluate a given dataset using a given model. 162 | 163 | # Arguments 164 | generator : The generator that represents the dataset to evaluate. 165 | model : The model to evaluate. 
166 | iou_threshold : The minimum IoU with a ground-truth box for a detection to count as a true positive.
167 | score_threshold : The score confidence threshold to use for detections.
168 | max_detections : The maximum number of detections to use per image.
169 | save_path : The path to save images with visualized detections to.
170 | # Returns
171 | A dict mapping each class label to a tuple of (average precision, number of annotations), together with the mean inference time per image.
172 | """
173 | # gather all detections and annotations
174 | all_detections, all_inferences = _get_detections(generator, model, score_threshold=score_threshold, max_detections=max_detections, save_path=save_path)
175 | all_annotations = _get_annotations(generator)
176 | average_precisions = {}
177 | 
178 | # process detections and annotations
179 | for label in range(generator.num_classes()):
180 | if not generator.has_label(label):
181 | continue
182 | 
183 | false_positives = np.zeros((0,))
184 | true_positives = np.zeros((0,))
185 | scores = np.zeros((0,))
186 | num_annotations = 0.0
187 | 
188 | for i in range(generator.size()):
189 | detections = all_detections[i][label]
190 | annotations = all_annotations[i][label]
191 | num_annotations += annotations.shape[0]
192 | detected_annotations = []
193 | 
194 | for d in detections:
195 | scores = np.append(scores, d[4])
196 | 
197 | if annotations.shape[0] == 0:
198 | false_positives = np.append(false_positives, 1)
199 | true_positives = np.append(true_positives, 0)
200 | continue
201 | 
202 | overlaps = compute_overlap(np.expand_dims(d, axis=0), annotations)
203 | assigned_annotation = np.argmax(overlaps, axis=1)
204 | max_overlap = overlaps[0, assigned_annotation]
205 | 
206 | if max_overlap >= iou_threshold and assigned_annotation not in detected_annotations:
207 | false_positives = np.append(false_positives, 0)
208 | true_positives = np.append(true_positives, 1)
209 | detected_annotations.append(assigned_annotation)
210 | else:
211 | false_positives = np.append(false_positives, 1)
212 | true_positives = np.append(true_positives, 0)
213 | 
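# The matching above is greedy by confidence: detections come score-sorted
# from _get_detections, and a detection counts as a true positive only if
# its highest-IoU ground-truth box clears iou_threshold and has not been
# claimed by an earlier (higher-scoring) detection; otherwise it is a
# false positive.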
214 | # if this class has no ground-truth annotations, its AP is defined as 0
215 | if num_annotations == 0:
216 | average_precisions[label] = 0, 0
217 | continue
218 | 
219 | # sort by score
220 | indices = np.argsort(-scores)
221 | false_positives = false_positives[indices]
222 | true_positives = true_positives[indices]
223 | 
224 | # compute false positives and true positives
225 | false_positives = np.cumsum(false_positives)
226 | true_positives = np.cumsum(true_positives)
227 | 
228 | # compute recall and precision
229 | recall = true_positives / num_annotations
230 | precision = true_positives / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps)
231 | 
232 | # compute average precision
233 | average_precision = _compute_ap(recall, precision)
234 | average_precisions[label] = average_precision, num_annotations
235 | 
236 | # inference time
237 | inference_time = np.sum(all_inferences) / generator.size()
238 | 
239 | return average_precisions, inference_time
240 | 
-------------------------------------------------------------------------------- /keras_retinanet/utils/gpu.py: --------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2019 Fizyr (https://fizyr.com)
3 | 
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 | 
8 | http://www.apache.org/licenses/LICENSE-2.0
9 | 
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 | 
17 | import tensorflow as tf
18 | 
19 | 
20 | def setup_gpu(gpu_id):
21 | try:
22 | visible_gpu_indices = [int(id) for id in gpu_id.split(',')]
23 | available_gpus = tf.config.list_physical_devices('GPU')
24 | visible_gpus = [gpu for idx, gpu in enumerate(available_gpus) if idx in visible_gpu_indices]
25 | 
26 | if visible_gpus:
27 | try:
28 | # Currently, memory growth needs to be the same across GPUs.
29 | for gpu in available_gpus:
30 | tf.config.experimental.set_memory_growth(gpu, True)
31 | 
32 | # Use only the selected gpus.
33 | tf.config.set_visible_devices(visible_gpus, 'GPU')
34 | except RuntimeError as e:
35 | # Visible devices must be set before GPUs have been initialized.
36 | print(e)
37 | 
38 | logical_gpus = tf.config.list_logical_devices('GPU')
39 | print(len(available_gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
40 | else:
41 | tf.config.set_visible_devices([], 'GPU')
42 | except ValueError:
43 | tf.config.set_visible_devices([], 'GPU')
44 | 
-------------------------------------------------------------------------------- /keras_retinanet/utils/model.py: --------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 Fizyr (https://fizyr.com)
3 | 
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 | 
8 | http://www.apache.org/licenses/LICENSE-2.0
9 | 
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
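A small usage sketch for setup_gpu above; the device index is illustrative:

```python
from keras_retinanet.utils.gpu import setup_gpu

setup_gpu('0')      # run on physical GPU 0 with memory growth enabled
# setup_gpu('cpu')  # any non-integer value takes the ValueError path: all GPUs hidden
```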
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | def freeze(model): 19 | """ Set all layers in a model to non-trainable. 20 | 21 | The weights for these layers will not be updated during training. 22 | 23 | This function modifies the given model in-place, 24 | but it also returns the modified model to allow easy chaining with other functions. 25 | """ 26 | for layer in model.layers: 27 | layer.trainable = False 28 | return model 29 | -------------------------------------------------------------------------------- /keras_retinanet/utils/tf_version.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2019 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from __future__ import print_function 18 | 19 | import tensorflow as tf 20 | import sys 21 | 22 | MINIMUM_TF_VERSION = 2, 3, 0 23 | BLACKLISTED_TF_VERSIONS = [] 24 | 25 | 26 | def tf_version(): 27 | """ Get the Tensorflow version. 28 | Returns 29 | tuple of (major, minor, patch). 30 | """ 31 | return tuple(map(int, tf.version.VERSION.split('-')[0].split('.'))) 32 | 33 | 34 | def tf_version_ok(minimum_tf_version=MINIMUM_TF_VERSION, blacklisted=BLACKLISTED_TF_VERSIONS): 35 | """ Check if the current Tensorflow version is higher than the minimum version. 36 | """ 37 | return tf_version() >= minimum_tf_version and tf_version() not in blacklisted 38 | 39 | 40 | def assert_tf_version(minimum_tf_version=MINIMUM_TF_VERSION, blacklisted=BLACKLISTED_TF_VERSIONS): 41 | """ Assert that the Tensorflow version is up to date. 42 | """ 43 | detected = tf.version.VERSION 44 | required = '.'.join(map(str, minimum_tf_version)) 45 | assert(tf_version_ok(minimum_tf_version, blacklisted)), 'You are using tensorflow version {}. The minimum required version is {} (blacklisted: {}).'.format(detected, required, blacklisted) 46 | 47 | 48 | def check_tf_version(): 49 | """ Check that the Tensorflow version is up to date. If it isn't, print an error message and exit the script. 50 | """ 51 | try: 52 | assert_tf_version() 53 | except AssertionError as e: 54 | print(e, file=sys.stderr) 55 | sys.exit(1) 56 | -------------------------------------------------------------------------------- /keras_retinanet/utils/visualization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
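An illustrative use of freeze above; backbone_model stands in for any Keras model built elsewhere:

```python
from tensorflow import keras
from keras_retinanet.utils.model import freeze

backbone_model = keras.Sequential([keras.layers.Dense(4, input_shape=(2,))])
model = freeze(backbone_model)  # same model object, now with no trainable layers
assert not any(layer.trainable for layer in model.layers)
```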
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import cv2 18 | import numpy as np 19 | 20 | from .colors import label_color 21 | 22 | 23 | def draw_box(image, box, color, thickness=2): 24 | """ Draws a box on an image with a given color. 25 | 26 | # Arguments 27 | image : The image to draw on. 28 | box : A list of 4 elements (x1, y1, x2, y2). 29 | color : The color of the box. 30 | thickness : The thickness of the lines to draw a box with. 31 | """ 32 | b = np.array(box).astype(int) 33 | cv2.rectangle(image, (b[0], b[1]), (b[2], b[3]), color, thickness, cv2.LINE_AA) 34 | 35 | 36 | def draw_caption(image, box, caption): 37 | """ Draws a caption above the box in an image. 38 | 39 | # Arguments 40 | image : The image to draw on. 41 | box : A list of 4 elements (x1, y1, x2, y2). 42 | caption : String containing the text to draw. 43 | """ 44 | b = np.array(box).astype(int) 45 | cv2.putText(image, caption, (b[0], b[1] - 10), cv2.FONT_HERSHEY_PLAIN, 1, (0, 0, 0), 2) 46 | cv2.putText(image, caption, (b[0], b[1] - 10), cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255), 1) 47 | 48 | 49 | def draw_boxes(image, boxes, color, thickness=2): 50 | """ Draws boxes on an image with a given color. 51 | 52 | # Arguments 53 | image : The image to draw on. 54 | boxes : A [N, 4] matrix (x1, y1, x2, y2). 55 | color : The color of the boxes. 56 | thickness : The thickness of the lines to draw boxes with. 57 | """ 58 | for b in boxes: 59 | draw_box(image, b, color, thickness=thickness) 60 | 61 | 62 | def draw_detections(image, boxes, scores, labels, color=None, label_to_name=None, score_threshold=0.5): 63 | """ Draws detections in an image. 64 | 65 | # Arguments 66 | image : The image to draw on. 67 | boxes : A [N, 4] matrix (x1, y1, x2, y2). 68 | scores : A list of N classification scores. 69 | labels : A list of N labels. 70 | color : The color of the boxes. By default the color from keras_retinanet.utils.colors.label_color will be used. 71 | label_to_name : (optional) Functor for mapping a label to a name. 72 | score_threshold : Threshold used for determining what detections to draw. 73 | """ 74 | selection = np.where(scores > score_threshold)[0] 75 | 76 | for i in selection: 77 | c = color if color is not None else label_color(labels[i]) 78 | draw_box(image, boxes[i, :], color=c) 79 | 80 | # draw labels 81 | caption = (label_to_name(labels[i]) if label_to_name else labels[i]) + ': {0:.2f}'.format(scores[i]) 82 | draw_caption(image, boxes[i, :], caption) 83 | 84 | 85 | def draw_annotations(image, annotations, color=(0, 255, 0), label_to_name=None): 86 | """ Draws annotations in an image. 87 | 88 | # Arguments 89 | image : The image to draw on. 90 | annotations : A [N, 5] matrix (x1, y1, x2, y2, label) or dictionary containing bboxes (shaped [N, 4]) and labels (shaped [N]). 91 | color : The color of the boxes. By default the color from keras_retinanet.utils.colors.label_color will be used. 92 | label_to_name : (optional) Functor for mapping a label to a name. 
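# Example
    An illustrative call, assuming image is a BGR numpy array loaded elsewhere:
    annotations = {'bboxes': np.array([[10, 10, 50, 50]]), 'labels': np.array([0])}
    draw_annotations(image, annotations, color=(0, 255, 0))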
93 | """ 94 | if isinstance(annotations, np.ndarray): 95 | annotations = {'bboxes': annotations[:, :4], 'labels': annotations[:, 4]} 96 | 97 | assert('bboxes' in annotations) 98 | assert('labels' in annotations) 99 | assert(annotations['bboxes'].shape[0] == annotations['labels'].shape[0]) 100 | 101 | for i in range(annotations['bboxes'].shape[0]): 102 | label = annotations['labels'][i] 103 | c = color if color is not None else label_color(label) 104 | caption = '{}'.format(label_to_name(label) if label_to_name else label) 105 | draw_caption(image, annotations['bboxes'][i], caption) 106 | draw_box(image, annotations['bboxes'][i], color=c) 107 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cython 2 | keras-resnet==0.2.0 3 | git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI 4 | h5py 5 | keras>=2.0.9,<=2.3.1 6 | matplotlib 7 | numpy>=1.14 8 | opencv-python>=3.3.0 9 | pillow 10 | progressbar2 11 | tensorflow>=2.3.0 12 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # ignore: 2 | # E201 whitespace after '[' 3 | # E202 whitespace before ']' 4 | # E203 whitespace before ':' 5 | # E221 multiple spaces before operator 6 | # E241 multiple spaces after ',' 7 | # E251 unexpected spaces around keyword / parameter equals 8 | # E501 line too long (85 > 79 characters) 9 | # W504 line break after binary operator 10 | [tool:pytest] 11 | flake8-max-line-length = 100 12 | flake8-ignore = E201 E202 E203 E221 E241 E251 E402 E501 W504 13 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | from setuptools.extension import Extension 3 | from distutils.command.build_ext import build_ext as DistUtilsBuildExt 4 | 5 | 6 | class BuildExtension(setuptools.Command): 7 | description = DistUtilsBuildExt.description 8 | user_options = DistUtilsBuildExt.user_options 9 | boolean_options = DistUtilsBuildExt.boolean_options 10 | help_options = DistUtilsBuildExt.help_options 11 | 12 | def __init__(self, *args, **kwargs): 13 | from setuptools.command.build_ext import build_ext as SetupToolsBuildExt 14 | 15 | # Bypass __setatrr__ to avoid infinite recursion. 
16 | self.__dict__['_command'] = SetupToolsBuildExt(*args, **kwargs) 17 | 18 | def __getattr__(self, name): 19 | return getattr(self._command, name) 20 | 21 | def __setattr__(self, name, value): 22 | setattr(self._command, name, value) 23 | 24 | def initialize_options(self, *args, **kwargs): 25 | return self._command.initialize_options(*args, **kwargs) 26 | 27 | def finalize_options(self, *args, **kwargs): 28 | ret = self._command.finalize_options(*args, **kwargs) 29 | import numpy 30 | self.include_dirs.append(numpy.get_include()) 31 | return ret 32 | 33 | def run(self, *args, **kwargs): 34 | return self._command.run(*args, **kwargs) 35 | 36 | 37 | extensions = [ 38 | Extension( 39 | 'keras_retinanet.utils.compute_overlap', 40 | ['keras_retinanet/utils/compute_overlap.pyx'] 41 | ), 42 | ] 43 | 44 | 45 | setuptools.setup( 46 | name = 'keras-retinanet', 47 | version = '1.0.0', 48 | description = 'Keras implementation of RetinaNet object detection.', 49 | url = 'https://github.com/fizyr/keras-retinanet', 50 | author = 'Hans Gaiser', 51 | author_email = 'h.gaiser@fizyr.com', 52 | maintainer = 'Hans Gaiser', 53 | maintainer_email = 'h.gaiser@fizyr.com', 54 | cmdclass = {'build_ext': BuildExtension}, 55 | packages = setuptools.find_packages(), 56 | install_requires = ['keras-resnet==0.2.0', 'six', 'numpy', 'cython', 'Pillow', 'opencv-python', 'progressbar2'], 57 | entry_points = { 58 | 'console_scripts': [ 59 | 'retinanet-train=keras_retinanet.bin.train:main', 60 | 'retinanet-evaluate=keras_retinanet.bin.evaluate:main', 61 | 'retinanet-debug=keras_retinanet.bin.debug:main', 62 | 'retinanet-convert-model=keras_retinanet.bin.convert_model:main', 63 | ], 64 | }, 65 | ext_modules = extensions, 66 | setup_requires = ["cython>=0.28", "numpy>=1.14.0"] 67 | ) 68 | -------------------------------------------------------------------------------- /snapshots/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fizyr/keras-retinanet/7ac91dfbbacce77d6d9633fc09e16cd0ee71fd5e/tests/__init__.py -------------------------------------------------------------------------------- /tests/backend/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fizyr/keras-retinanet/7ac91dfbbacce77d6d9633fc09e16cd0ee71fd5e/tests/backend/__init__.py -------------------------------------------------------------------------------- /tests/backend/test_common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | import numpy as np 18 | from tensorflow import keras 19 | import keras_retinanet.backend 20 | 21 | 22 | def test_bbox_transform_inv(): 23 | boxes = np.array([[ 24 | [100, 100, 200, 200], 25 | [100, 100, 300, 300], 26 | [100, 100, 200, 300], 27 | [100, 100, 300, 200], 28 | [80, 120, 200, 200], 29 | [80, 120, 300, 300], 30 | [80, 120, 200, 300], 31 | [80, 120, 300, 200], 32 | ]]) 33 | boxes = keras.backend.variable(boxes) 34 | 35 | deltas = np.array([[ 36 | [0 , 0 , 0 , 0 ], 37 | [0 , 0.1, 0 , 0 ], 38 | [-0.3, 0 , 0 , 0 ], 39 | [0.2 , 0.2, 0 , 0 ], 40 | [0 , 0 , 0.1 , 0 ], 41 | [0 , 0 , 0 , -0.3], 42 | [0 , 0 , 0.2 , 0.2 ], 43 | [0.1 , 0.2, -0.3, 0.4 ], 44 | ]]) 45 | deltas = keras.backend.variable(deltas) 46 | 47 | expected = np.array([[ 48 | [100 , 100 , 200 , 200 ], 49 | [100 , 104 , 300 , 300 ], 50 | [ 94 , 100 , 200 , 300 ], 51 | [108 , 104 , 300 , 200 ], 52 | [ 80 , 120 , 202.4 , 200 ], 53 | [ 80 , 120 , 300 , 289.2], 54 | [ 80 , 120 , 204.8 , 307.2], 55 | [ 84.4, 123.2, 286.8 , 206.4] 56 | ]]) 57 | 58 | result = keras_retinanet.backend.bbox_transform_inv(boxes, deltas) 59 | result = keras.backend.eval(result) 60 | 61 | np.testing.assert_array_almost_equal(result, expected, decimal=2) 62 | 63 | 64 | def test_shift(): 65 | shape = (2, 3) 66 | stride = 8 67 | 68 | anchors = np.array([ 69 | [-8, -8, 8, 8], 70 | [-16, -16, 16, 16], 71 | [-12, -12, 12, 12], 72 | [-12, -16, 12, 16], 73 | [-16, -12, 16, 12] 74 | ], dtype=keras.backend.floatx()) 75 | 76 | expected = [ 77 | # anchors for (0, 0) 78 | [4 - 8, 4 - 8, 4 + 8, 4 + 8], 79 | [4 - 16, 4 - 16, 4 + 16, 4 + 16], 80 | [4 - 12, 4 - 12, 4 + 12, 4 + 12], 81 | [4 - 12, 4 - 16, 4 + 12, 4 + 16], 82 | [4 - 16, 4 - 12, 4 + 16, 4 + 12], 83 | 84 | # anchors for (0, 1) 85 | [12 - 8, 4 - 8, 12 + 8, 4 + 8], 86 | [12 - 16, 4 - 16, 12 + 16, 4 + 16], 87 | [12 - 12, 4 - 12, 12 + 12, 4 + 12], 88 | [12 - 12, 4 - 16, 12 + 12, 4 + 16], 89 | [12 - 16, 4 - 12, 12 + 16, 4 + 12], 90 | 91 | # anchors for (0, 2) 92 | [20 - 8, 4 - 8, 20 + 8, 4 + 8], 93 | [20 - 16, 4 - 16, 20 + 16, 4 + 16], 94 | [20 - 12, 4 - 12, 20 + 12, 4 + 12], 95 | [20 - 12, 4 - 16, 20 + 12, 4 + 16], 96 | [20 - 16, 4 - 12, 20 + 16, 4 + 12], 97 | 98 | # anchors for (1, 0) 99 | [4 - 8, 12 - 8, 4 + 8, 12 + 8], 100 | [4 - 16, 12 - 16, 4 + 16, 12 + 16], 101 | [4 - 12, 12 - 12, 4 + 12, 12 + 12], 102 | [4 - 12, 12 - 16, 4 + 12, 12 + 16], 103 | [4 - 16, 12 - 12, 4 + 16, 12 + 12], 104 | 105 | # anchors for (1, 1) 106 | [12 - 8, 12 - 8, 12 + 8, 12 + 8], 107 | [12 - 16, 12 - 16, 12 + 16, 12 + 16], 108 | [12 - 12, 12 - 12, 12 + 12, 12 + 12], 109 | [12 - 12, 12 - 16, 12 + 12, 12 + 16], 110 | [12 - 16, 12 - 12, 12 + 16, 12 + 12], 111 | 112 | # anchors for (1, 2) 113 | [20 - 8, 12 - 8, 20 + 8, 12 + 8], 114 | [20 - 16, 12 - 16, 20 + 16, 12 + 16], 115 | [20 - 12, 12 - 12, 20 + 12, 12 + 12], 116 | [20 - 12, 12 - 16, 20 + 12, 12 + 16], 117 | [20 - 16, 12 - 12, 20 + 16, 12 + 12], 118 | ] 119 | 120 | result = keras_retinanet.backend.shift(shape, stride, anchors) 121 | result = keras.backend.eval(result) 122 | 123 | np.testing.assert_array_equal(result, expected) 124 | -------------------------------------------------------------------------------- /tests/bin/test_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import keras_retinanet.backend 18 | import keras_retinanet.bin.train 19 | from tensorflow import keras 20 | 21 | import warnings 22 | 23 | import pytest 24 | 25 | 26 | @pytest.fixture(autouse=True) 27 | def clear_session(): 28 | # run before test (do nothing) 29 | yield 30 | # run after test, clear keras session 31 | keras.backend.clear_session() 32 | 33 | 34 | def test_coco(): 35 | # ignore warnings in this test 36 | warnings.simplefilter('ignore') 37 | 38 | # run training / evaluation 39 | keras_retinanet.bin.train.main([ 40 | '--epochs=1', 41 | '--steps=1', 42 | '--no-weights', 43 | '--no-snapshots', 44 | 'coco', 45 | 'tests/test-data/coco', 46 | ]) 47 | 48 | 49 | def test_pascal(): 50 | # ignore warnings in this test 51 | warnings.simplefilter('ignore') 52 | 53 | # run training / evaluation 54 | keras_retinanet.bin.train.main([ 55 | '--epochs=1', 56 | '--steps=1', 57 | '--no-weights', 58 | '--no-snapshots', 59 | 'pascal', 60 | 'tests/test-data/pascal', 61 | ]) 62 | 63 | 64 | def test_csv(): 65 | # ignore warnings in this test 66 | warnings.simplefilter('ignore') 67 | 68 | # run training / evaluation 69 | keras_retinanet.bin.train.main([ 70 | '--epochs=1', 71 | '--steps=1', 72 | '--no-weights', 73 | '--no-snapshots', 74 | 'csv', 75 | 'tests/test-data/csv/annotations.csv', 76 | 'tests/test-data/csv/classes.csv', 77 | ]) 78 | 79 | 80 | def test_vgg(): 81 | # ignore warnings in this test 82 | warnings.simplefilter('ignore') 83 | 84 | # run training / evaluation 85 | keras_retinanet.bin.train.main([ 86 | '--backbone=vgg16', 87 | '--epochs=1', 88 | '--steps=1', 89 | '--no-weights', 90 | '--no-snapshots', 91 | '--freeze-backbone', 92 | 'coco', 93 | 'tests/test-data/coco', 94 | ]) 95 | -------------------------------------------------------------------------------- /tests/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fizyr/keras-retinanet/7ac91dfbbacce77d6d9633fc09e16cd0ee71fd5e/tests/layers/__init__.py -------------------------------------------------------------------------------- /tests/layers/test_filter_detections.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | from tensorflow import keras 18 | import keras_retinanet.backend 19 | import keras_retinanet.layers 20 | 21 | import numpy as np 22 | 23 | 24 | class TestFilterDetections(object): 25 | def test_simple(self): 26 | # create simple FilterDetections layer 27 | filter_detections_layer = keras_retinanet.layers.FilterDetections() 28 | 29 | # create simple input 30 | boxes = np.array([[ 31 | [0, 0, 10, 10], 32 | [0, 0, 10, 10], # this will be suppressed 33 | ]], dtype=keras.backend.floatx()) 34 | boxes = keras.backend.constant(boxes) 35 | 36 | classification = np.array([[ 37 | [0, 0.9], # this will be suppressed 38 | [0, 1], 39 | ]], dtype=keras.backend.floatx()) 40 | classification = keras.backend.constant(classification) 41 | 42 | # compute output 43 | actual_boxes, actual_scores, actual_labels = filter_detections_layer.call([boxes, classification]) 44 | actual_boxes = keras.backend.eval(actual_boxes) 45 | actual_scores = keras.backend.eval(actual_scores) 46 | actual_labels = keras.backend.eval(actual_labels) 47 | 48 | # define expected output 49 | expected_boxes = -1 * np.ones((1, 300, 4), dtype=keras.backend.floatx()) 50 | expected_boxes[0, 0, :] = [0, 0, 10, 10] 51 | 52 | expected_scores = -1 * np.ones((1, 300), dtype=keras.backend.floatx()) 53 | expected_scores[0, 0] = 1 54 | 55 | expected_labels = -1 * np.ones((1, 300), dtype=keras.backend.floatx()) 56 | expected_labels[0, 0] = 1 57 | 58 | # assert actual and expected are equal 59 | np.testing.assert_array_equal(actual_boxes, expected_boxes) 60 | np.testing.assert_array_equal(actual_scores, expected_scores) 61 | np.testing.assert_array_equal(actual_labels, expected_labels) 62 | 63 | def test_simple_with_other(self): 64 | # create simple FilterDetections layer 65 | filter_detections_layer = keras_retinanet.layers.FilterDetections() 66 | 67 | # create simple input 68 | boxes = np.array([[ 69 | [0, 0, 10, 10], 70 | [0, 0, 10, 10], # this will be suppressed 71 | ]], dtype=keras.backend.floatx()) 72 | boxes = keras.backend.constant(boxes) 73 | 74 | classification = np.array([[ 75 | [0, 0.9], # this will be suppressed 76 | [0, 1], 77 | ]], dtype=keras.backend.floatx()) 78 | classification = keras.backend.constant(classification) 79 | 80 | other = [] 81 | other.append(np.array([[ 82 | [0, 1234], # this will be suppressed 83 | [0, 5678], 84 | ]], dtype=keras.backend.floatx())) 85 | other.append(np.array([[ 86 | 5678, # this will be suppressed 87 | 1234, 88 | ]], dtype=keras.backend.floatx())) 89 | other = [keras.backend.constant(o) for o in other] 90 | 91 | # compute output 92 | actual = filter_detections_layer.call([boxes, classification] + other) 93 | actual_boxes = keras.backend.eval(actual[0]) 94 | actual_scores = keras.backend.eval(actual[1]) 95 | actual_labels = keras.backend.eval(actual[2]) 96 | actual_other = [keras.backend.eval(a) for a in actual[3:]] 97 | 98 | # define expected output 99 | expected_boxes = -1 * np.ones((1, 300, 4), dtype=keras.backend.floatx()) 100 | expected_boxes[0, 0, :] = [0, 0, 10, 10] 101 | 102 | expected_scores = -1 * np.ones((1, 300), dtype=keras.backend.floatx()) 103 | expected_scores[0, 0] = 1 104 | 105 | expected_labels = -1 * np.ones((1, 300), dtype=keras.backend.floatx()) 106 | expected_labels[0, 0] = 1 107 | 108 | expected_other = [] 109 | expected_other.append(-1 * np.ones((1, 300, 2), dtype=keras.backend.floatx())) 110 | expected_other[-1][0, 0, :] = [0, 5678] 111 | expected_other.append(-1 * np.ones((1, 300), dtype=keras.backend.floatx())) 112 | expected_other[-1][0, 0] = 1234 113 | 114 
| # assert actual and expected are equal 115 | np.testing.assert_array_equal(actual_boxes, expected_boxes) 116 | np.testing.assert_array_equal(actual_scores, expected_scores) 117 | np.testing.assert_array_equal(actual_labels, expected_labels) 118 | 119 | for a, e in zip(actual_other, expected_other): 120 | np.testing.assert_array_equal(a, e) 121 | 122 | def test_mini_batch(self): 123 | # create simple FilterDetections layer 124 | filter_detections_layer = keras_retinanet.layers.FilterDetections() 125 | 126 | # create input with batch_size=2 127 | boxes = np.array([ 128 | [ 129 | [0, 0, 10, 10], # this will be suppressed 130 | [0, 0, 10, 10], 131 | ], 132 | [ 133 | [100, 100, 150, 150], 134 | [100, 100, 150, 150], # this will be suppressed 135 | ], 136 | ], dtype=keras.backend.floatx()) 137 | boxes = keras.backend.constant(boxes) 138 | 139 | classification = np.array([ 140 | [ 141 | [0, 0.9], # this will be suppressed 142 | [0, 1], 143 | ], 144 | [ 145 | [1, 0], 146 | [0.9, 0], # this will be suppressed 147 | ], 148 | ], dtype=keras.backend.floatx()) 149 | classification = keras.backend.constant(classification) 150 | 151 | # compute output 152 | actual_boxes, actual_scores, actual_labels = filter_detections_layer.call([boxes, classification]) 153 | actual_boxes = keras.backend.eval(actual_boxes) 154 | actual_scores = keras.backend.eval(actual_scores) 155 | actual_labels = keras.backend.eval(actual_labels) 156 | 157 | # define expected output 158 | expected_boxes = -1 * np.ones((2, 300, 4), dtype=keras.backend.floatx()) 159 | expected_boxes[0, 0, :] = [0, 0, 10, 10] 160 | expected_boxes[1, 0, :] = [100, 100, 150, 150] 161 | 162 | expected_scores = -1 * np.ones((2, 300), dtype=keras.backend.floatx()) 163 | expected_scores[0, 0] = 1 164 | expected_scores[1, 0] = 1 165 | 166 | expected_labels = -1 * np.ones((2, 300), dtype=keras.backend.floatx()) 167 | expected_labels[0, 0] = 1 168 | expected_labels[1, 0] = 0 169 | 170 | # assert actual and expected are equal 171 | np.testing.assert_array_equal(actual_boxes, expected_boxes) 172 | np.testing.assert_array_equal(actual_scores, expected_scores) 173 | np.testing.assert_array_equal(actual_labels, expected_labels) 174 | -------------------------------------------------------------------------------- /tests/layers/test_misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | from tensorflow import keras 18 | import keras_retinanet.backend 19 | import keras_retinanet.layers 20 | 21 | import numpy as np 22 | 23 | 24 | class TestAnchors(object): 25 | def test_simple(self): 26 | # create simple Anchors layer 27 | anchors_layer = keras_retinanet.layers.Anchors( 28 | size=32, 29 | stride=8, 30 | ratios=np.array([1], keras.backend.floatx()), 31 | scales=np.array([1], keras.backend.floatx()), 32 | ) 33 | 34 | # create fake features input (only shape is used anyway) 35 | features = np.zeros((1, 2, 2, 1024), dtype=keras.backend.floatx()) 36 | features = keras.backend.variable(features) 37 | 38 | # call the Anchors layer 39 | anchors = anchors_layer.call(features) 40 | anchors = keras.backend.eval(anchors) 41 | 42 | # expected anchor values 43 | expected = np.array([[ 44 | [-12, -12, 20, 20], 45 | [-4 , -12, 28, 20], 46 | [-12, -4 , 20, 28], 47 | [-4 , -4 , 28, 28], 48 | ]], dtype=keras.backend.floatx()) 49 | 50 | # test anchor values 51 | np.testing.assert_array_equal(anchors, expected) 52 | 53 | # mark test to fail 54 | def test_mini_batch(self): 55 | # create simple Anchors layer 56 | anchors_layer = keras_retinanet.layers.Anchors( 57 | size=32, 58 | stride=8, 59 | ratios=np.array([1], dtype=keras.backend.floatx()), 60 | scales=np.array([1], dtype=keras.backend.floatx()), 61 | ) 62 | 63 | # create fake features input with batch_size=2 64 | features = np.zeros((2, 2, 2, 1024), dtype=keras.backend.floatx()) 65 | features = keras.backend.variable(features) 66 | 67 | # call the Anchors layer 68 | anchors = anchors_layer.call(features) 69 | anchors = keras.backend.eval(anchors) 70 | 71 | # expected anchor values 72 | expected = np.array([[ 73 | [-12, -12, 20, 20], 74 | [-4 , -12, 28, 20], 75 | [-12, -4 , 20, 28], 76 | [-4 , -4 , 28, 28], 77 | ]], dtype=keras.backend.floatx()) 78 | expected = np.tile(expected, (2, 1, 1)) 79 | 80 | # test anchor values 81 | np.testing.assert_array_equal(anchors, expected) 82 | 83 | 84 | class TestUpsampleLike(object): 85 | def test_simple(self): 86 | # create simple UpsampleLike layer 87 | upsample_like_layer = keras_retinanet.layers.UpsampleLike() 88 | 89 | # create input source 90 | source = np.zeros((1, 2, 2, 1), dtype=keras.backend.floatx()) 91 | source = keras.backend.variable(source) 92 | target = np.zeros((1, 5, 5, 1), dtype=keras.backend.floatx()) 93 | expected = target 94 | target = keras.backend.variable(target) 95 | 96 | # compute output 97 | actual = upsample_like_layer.call([source, target]) 98 | actual = keras.backend.eval(actual) 99 | 100 | np.testing.assert_array_equal(actual, expected) 101 | 102 | def test_mini_batch(self): 103 | # create simple UpsampleLike layer 104 | upsample_like_layer = keras_retinanet.layers.UpsampleLike() 105 | 106 | # create input source 107 | source = np.zeros((2, 2, 2, 1), dtype=keras.backend.floatx()) 108 | source = keras.backend.variable(source) 109 | 110 | target = np.zeros((2, 5, 5, 1), dtype=keras.backend.floatx()) 111 | expected = target 112 | target = keras.backend.variable(target) 113 | 114 | # compute output 115 | actual = upsample_like_layer.call([source, target]) 116 | actual = keras.backend.eval(actual) 117 | 118 | np.testing.assert_array_equal(actual, expected) 119 | 120 | 121 | class TestRegressBoxes(object): 122 | def test_simple(self): 123 | mean = [0, 0, 0, 0] 124 | std = [0.2, 0.2, 0.2, 0.2] 125 | 126 | # create simple RegressBoxes layer 127 | regress_boxes_layer = keras_retinanet.layers.RegressBoxes(mean=mean, std=std) 128 | 129 | # create input 130 | anchors = 
np.array([[ 131 | [0 , 0 , 10 , 10 ], 132 | [50, 50, 100, 100], 133 | [20, 20, 40 , 40 ], 134 | ]], dtype=keras.backend.floatx()) 135 | anchors = keras.backend.variable(anchors) 136 | 137 | regression = np.array([[ 138 | [0 , 0 , 0 , 0 ], 139 | [0.1, 0.1, 0 , 0 ], 140 | [0 , 0 , 0.1, 0.1], 141 | ]], dtype=keras.backend.floatx()) 142 | regression = keras.backend.variable(regression) 143 | 144 | # compute output 145 | actual = regress_boxes_layer.call([anchors, regression]) 146 | actual = keras.backend.eval(actual) 147 | 148 | # compute expected output 149 | expected = np.array([[ 150 | [0 , 0 , 10 , 10 ], 151 | [51, 51, 100 , 100 ], 152 | [20, 20, 40.4, 40.4], 153 | ]], dtype=keras.backend.floatx()) 154 | 155 | np.testing.assert_array_almost_equal(actual, expected, decimal=2) 156 | 157 | # mark test to fail 158 | def test_mini_batch(self): 159 | mean = [0, 0, 0, 0] 160 | std = [0.2, 0.2, 0.2, 0.2] 161 | 162 | # create simple RegressBoxes layer 163 | regress_boxes_layer = keras_retinanet.layers.RegressBoxes(mean=mean, std=std) 164 | 165 | # create input 166 | anchors = np.array([ 167 | [ 168 | [0 , 0 , 10 , 10 ], # 1 169 | [50, 50, 100, 100], # 2 170 | [20, 20, 40 , 40 ], # 3 171 | ], 172 | [ 173 | [20, 20, 40 , 40 ], # 3 174 | [0 , 0 , 10 , 10 ], # 1 175 | [50, 50, 100, 100], # 2 176 | ], 177 | ], dtype=keras.backend.floatx()) 178 | anchors = keras.backend.variable(anchors) 179 | 180 | regression = np.array([ 181 | [ 182 | [0 , 0 , 0 , 0 ], # 1 183 | [0.1, 0.1, 0 , 0 ], # 2 184 | [0 , 0 , 0.1, 0.1], # 3 185 | ], 186 | [ 187 | [0 , 0 , 0.1, 0.1], # 3 188 | [0 , 0 , 0 , 0 ], # 1 189 | [0.1, 0.1, 0 , 0 ], # 2 190 | ], 191 | ], dtype=keras.backend.floatx()) 192 | regression = keras.backend.variable(regression) 193 | 194 | # compute output 195 | actual = regress_boxes_layer.call([anchors, regression]) 196 | actual = keras.backend.eval(actual) 197 | 198 | # compute expected output 199 | expected = np.array([ 200 | [ 201 | [0 , 0 , 10 , 10 ], # 1 202 | [51, 51, 100 , 100 ], # 2 203 | [20, 20, 40.4, 40.4], # 3 204 | ], 205 | [ 206 | [20, 20, 40.4, 40.4], # 3 207 | [0 , 0 , 10 , 10 ], # 1 208 | [51, 51, 100 , 100 ], # 2 209 | ], 210 | ], dtype=keras.backend.floatx()) 211 | 212 | np.testing.assert_array_almost_equal(actual, expected, decimal=2) 213 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fizyr/keras-retinanet/7ac91dfbbacce77d6d9633fc09e16cd0ee71fd5e/tests/models/__init__.py -------------------------------------------------------------------------------- /tests/models/test_densenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2018 vidosits (https://github.com/vidosits/) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
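A worked check of TestAnchors.test_simple above: with stride 8 the first cell center is at (4, 4), and a square size-32 anchor spans 16 pixels either side of it:

```python
x1, y1, x2, y2 = 4 - 16, 4 - 16, 4 + 16, 4 + 16   # -> (-12, -12, 20, 20), the first expected row
```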
15 | """ 16 | 17 | import warnings 18 | import pytest 19 | import numpy as np 20 | from tensorflow import keras 21 | from keras_retinanet import losses 22 | from keras_retinanet.models.densenet import DenseNetBackbone 23 | 24 | parameters = ['densenet121'] 25 | 26 | 27 | @pytest.mark.parametrize("backbone", parameters) 28 | def test_backbone(backbone): 29 | # ignore warnings in this test 30 | warnings.simplefilter('ignore') 31 | 32 | num_classes = 10 33 | 34 | inputs = np.zeros((1, 200, 400, 3), dtype=np.float32) 35 | targets = [np.zeros((1, 14814, 5), dtype=np.float32), np.zeros((1, 14814, num_classes + 1))] 36 | 37 | inp = keras.layers.Input(inputs[0].shape) 38 | 39 | densenet_backbone = DenseNetBackbone(backbone) 40 | model = densenet_backbone.retinanet(num_classes=num_classes, inputs=inp) 41 | model.summary() 42 | 43 | # compile model 44 | model.compile( 45 | loss={ 46 | 'regression': losses.smooth_l1(), 47 | 'classification': losses.focal() 48 | }, 49 | optimizer=keras.optimizers.Adam(lr=1e-5, clipnorm=0.001)) 50 | 51 | model.fit(inputs, targets, batch_size=1) 52 | -------------------------------------------------------------------------------- /tests/models/test_mobilenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 lvaleriu (https://github.com/lvaleriu/) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | import warnings 18 | import pytest 19 | import numpy as np 20 | from tensorflow import keras 21 | from keras_retinanet import losses 22 | from keras_retinanet.models.mobilenet import MobileNetBackbone 23 | 24 | 25 | alphas = ['1.0'] 26 | parameters = [] 27 | 28 | for backbone in MobileNetBackbone.allowed_backbones: 29 | for alpha in alphas: 30 | parameters.append((backbone, alpha)) 31 | 32 | 33 | @pytest.mark.parametrize("backbone, alpha", parameters) 34 | def test_backbone(backbone, alpha): 35 | # ignore warnings in this test 36 | warnings.simplefilter('ignore') 37 | 38 | num_classes = 10 39 | 40 | inputs = np.zeros((1, 1024, 363, 3), dtype=np.float32) 41 | targets = [np.zeros((1, 68760, 5), dtype=np.float32), np.zeros((1, 68760, num_classes + 1))] 42 | 43 | inp = keras.layers.Input(inputs[0].shape) 44 | 45 | mobilenet_backbone = MobileNetBackbone(backbone='{}_{}'.format(backbone, format(alpha))) 46 | training_model = mobilenet_backbone.retinanet(num_classes=num_classes, inputs=inp) 47 | training_model.summary() 48 | 49 | # compile model 50 | training_model.compile( 51 | loss={ 52 | 'regression': losses.smooth_l1(), 53 | 'classification': losses.focal() 54 | }, 55 | optimizer=keras.optimizers.Adam(lr=1e-5, clipnorm=0.001)) 56 | 57 | training_model.fit(inputs, targets, batch_size=1) 58 | -------------------------------------------------------------------------------- /tests/preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fizyr/keras-retinanet/7ac91dfbbacce77d6d9633fc09e16cd0ee71fd5e/tests/preprocessing/__init__.py -------------------------------------------------------------------------------- /tests/preprocessing/test_csv_generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | import csv 18 | import pytest 19 | try: 20 | from io import StringIO 21 | except ImportError: 22 | from stringio import StringIO 23 | 24 | from keras_retinanet.preprocessing import csv_generator 25 | 26 | 27 | def csv_str(string): 28 | if str == bytes: 29 | string = string.decode('utf-8') 30 | return csv.reader(StringIO(string)) 31 | 32 | 33 | def annotation(x1, y1, x2, y2, class_name): 34 | return {'x1': x1, 'y1': y1, 'x2': x2, 'y2': y2, 'class': class_name} 35 | 36 | 37 | def test_read_classes(): 38 | assert csv_generator._read_classes(csv_str('')) == {} 39 | assert csv_generator._read_classes(csv_str('a,1')) == {'a': 1} 40 | assert csv_generator._read_classes(csv_str('a,1\nb,2')) == {'a': 1, 'b': 2} 41 | 42 | 43 | def test_read_classes_wrong_format(): 44 | with pytest.raises(ValueError): 45 | try: 46 | csv_generator._read_classes(csv_str('a,b,c')) 47 | except ValueError as e: 48 | assert str(e).startswith('line 1: format should be') 49 | raise 50 | with pytest.raises(ValueError): 51 | try: 52 | csv_generator._read_classes(csv_str('a,1\nb,c,d')) 53 | except ValueError as e: 54 | assert str(e).startswith('line 2: format should be') 55 | raise 56 | 57 | 58 | def test_read_classes_malformed_class_id(): 59 | with pytest.raises(ValueError): 60 | try: 61 | csv_generator._read_classes(csv_str('a,b')) 62 | except ValueError as e: 63 | assert str(e).startswith("line 1: malformed class ID:") 64 | raise 65 | 66 | with pytest.raises(ValueError): 67 | try: 68 | csv_generator._read_classes(csv_str('a,1\nb,c')) 69 | except ValueError as e: 70 | assert str(e).startswith('line 2: malformed class ID:') 71 | raise 72 | 73 | 74 | def test_read_classes_duplicate_name(): 75 | with pytest.raises(ValueError): 76 | try: 77 | csv_generator._read_classes(csv_str('a,1\nb,2\na,3')) 78 | except ValueError as e: 79 | assert str(e).startswith('line 3: duplicate class name') 80 | raise 81 | 82 | 83 | def test_read_annotations(): 84 | classes = {'a': 1, 'b': 2, 'c': 4, 'd': 10} 85 | annotations = csv_generator._read_annotations(csv_str( 86 | 'a.png,0,1,2,3,a' '\n' 87 | 'b.png,4,5,6,7,b' '\n' 88 | 'c.png,8,9,10,11,c' '\n' 89 | 'd.png,12,13,14,15,d' '\n' 90 | ), classes) 91 | assert annotations == { 92 | 'a.png': [annotation( 0, 1, 2, 3, 'a')], 93 | 'b.png': [annotation( 4, 5, 6, 7, 'b')], 94 | 'c.png': [annotation( 8, 9, 10, 11, 'c')], 95 | 'd.png': [annotation(12, 13, 14, 15, 'd')], 96 | } 97 | 98 | 99 | def test_read_annotations_multiple(): 100 | classes = {'a': 1, 'b': 2, 'c': 4, 'd': 10} 101 | annotations = csv_generator._read_annotations(csv_str( 102 | 'a.png,0,1,2,3,a' '\n' 103 | 'b.png,4,5,6,7,b' '\n' 104 | 'a.png,8,9,10,11,c' '\n' 105 | ), classes) 106 | assert annotations == { 107 | 'a.png': [ 108 | annotation(0, 1, 2, 3, 'a'), 109 | annotation(8, 9, 10, 11, 'c'), 110 | ], 111 | 'b.png': [annotation(4, 5, 6, 7, 'b')], 112 | } 113 | 114 | 115 | def test_read_annotations_wrong_format(): 116 | classes = {'a': 1, 'b': 2, 'c': 4, 'd': 10} 117 | with pytest.raises(ValueError): 118 | try: 119 | csv_generator._read_annotations(csv_str('a.png,1,2,3,a'), classes) 120 | except ValueError as e: 121 | assert str(e).startswith("line 1: format should be") 122 | raise 123 | 124 | with pytest.raises(ValueError): 125 | try: 126 | csv_generator._read_annotations(csv_str( 127 | 'a.png,0,1,2,3,a' '\n' 128 | 'a.png,1,2,3,a' '\n' 129 | ), classes) 130 | except ValueError as e: 131 | assert str(e).startswith("line 2: format should be") 132 | raise 133 | 134 | 135 | def test_read_annotations_wrong_x1(): 136 | with 
def test_read_annotations_wrong_x1():
    with pytest.raises(ValueError):
        try:
            csv_generator._read_annotations(csv_str('a.png,a,0,1,2,a'), {'a': 1})
        except ValueError as e:
            assert str(e).startswith("line 1: malformed x1:")
            raise


def test_read_annotations_wrong_y1():
    with pytest.raises(ValueError):
        try:
            csv_generator._read_annotations(csv_str('a.png,0,a,1,2,a'), {'a': 1})
        except ValueError as e:
            assert str(e).startswith("line 1: malformed y1:")
            raise


def test_read_annotations_wrong_x2():
    with pytest.raises(ValueError):
        try:
            csv_generator._read_annotations(csv_str('a.png,0,1,a,2,a'), {'a': 1})
        except ValueError as e:
            assert str(e).startswith("line 1: malformed x2:")
            raise


def test_read_annotations_wrong_y2():
    with pytest.raises(ValueError):
        try:
            csv_generator._read_annotations(csv_str('a.png,0,1,2,a,a'), {'a': 1})
        except ValueError as e:
            assert str(e).startswith("line 1: malformed y2:")
            raise


def test_read_annotations_wrong_class():
    with pytest.raises(ValueError):
        try:
            csv_generator._read_annotations(csv_str('a.png,0,1,2,3,g'), {'a': 1})
        except ValueError as e:
            assert str(e).startswith("line 1: unknown class name:")
            raise


def test_read_annotations_invalid_bb_x():
    with pytest.raises(ValueError):
        try:
            csv_generator._read_annotations(csv_str('a.png,1,2,1,3,g'), {'a': 1})
        except ValueError as e:
            assert str(e).startswith("line 1: x2 (1) must be higher than x1 (1)")
            raise
    with pytest.raises(ValueError):
        try:
            csv_generator._read_annotations(csv_str('a.png,9,2,5,3,g'), {'a': 1})
        except ValueError as e:
            assert str(e).startswith("line 1: x2 (5) must be higher than x1 (9)")
            raise


def test_read_annotations_invalid_bb_y():
    with pytest.raises(ValueError):
        try:
            csv_generator._read_annotations(csv_str('a.png,1,2,3,2,a'), {'a': 1})
        except ValueError as e:
            assert str(e).startswith("line 1: y2 (2) must be higher than y1 (2)")
            raise
    with pytest.raises(ValueError):
        try:
            csv_generator._read_annotations(csv_str('a.png,1,8,3,5,a'), {'a': 1})
        except ValueError as e:
            assert str(e).startswith("line 1: y2 (5) must be higher than y1 (8)")
            raise


def test_read_annotations_empty_image():
    # Check that images without annotations are parsed.
    assert csv_generator._read_annotations(csv_str('a.png,,,,,\nb.png,,,,,'), {'a': 1}) == {'a.png': [], 'b.png': []}

    # Check that lines without annotations don't clear earlier annotations.
    assert csv_generator._read_annotations(csv_str('a.png,0,1,2,3,a\na.png,,,,,'), {'a': 1}) == {'a.png': [annotation(0, 1, 2, 3, 'a')]}
--------------------------------------------------------------------------------
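For context, the annotations format exercised above is one `path,x1,y1,x2,y2,class_name` row per box, with all five trailing fields left empty for images without annotations, while the class mapping uses `class_name,id` rows. A hypothetical pair of inputs illustrating the format (file and class names invented here):

# classes.csv
cow,0
cat,1

# annotations.csv (third row: an image with no annotations)
img_001.jpg,837,346,981,456,cow
img_002.jpg,215,312,279,391,cat
img_003.jpg,,,,,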
/tests/preprocessing/test_generator.py:
--------------------------------------------------------------------------------
"""
Copyright 2017-2018 Fizyr (https://fizyr.com)

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from keras_retinanet.preprocessing.generator import Generator

import numpy as np
import pytest


class SimpleGenerator(Generator):
    def __init__(self, bboxes, labels, num_classes=0, image=None):
        assert len(bboxes) == len(labels)
        self.bboxes       = bboxes
        self.labels       = labels
        self.num_classes_ = num_classes
        self.image        = image
        super(SimpleGenerator, self).__init__(group_method='none', shuffle_groups=False)

    def num_classes(self):
        return self.num_classes_

    def load_image(self, image_index):
        return self.image

    def image_path(self, image_index):
        return ''

    def size(self):
        return len(self.bboxes)

    def load_annotations(self, image_index):
        annotations = {'labels': self.labels[image_index], 'bboxes': self.bboxes[image_index]}
        return annotations


class TestLoadAnnotationsGroup(object):
    def test_simple(self):
        input_bboxes_group = [
            np.array([
                [  0,   0,  10,  10],
                [150, 150, 350, 350]
            ]),
        ]
        input_labels_group = [
            np.array([
                1,
                3
            ]),
        ]
        expected_bboxes_group = input_bboxes_group
        expected_labels_group = input_labels_group

        simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group)
        annotations = simple_generator.load_annotations_group(simple_generator.groups[0])

        assert 'bboxes' in annotations[0]
        assert 'labels' in annotations[0]
        np.testing.assert_equal(expected_bboxes_group[0], annotations[0]['bboxes'])
        np.testing.assert_equal(expected_labels_group[0], annotations[0]['labels'])

    def test_multiple(self):
        input_bboxes_group = [
            np.array([
                [  0,   0,  10,  10],
                [150, 150, 350, 350]
            ]),
            np.array([
                [0, 0, 50, 50],
            ]),
        ]
        input_labels_group = [
            np.array([
                1,
                0
            ]),
            np.array([
                3
            ])
        ]
        expected_bboxes_group = input_bboxes_group
        expected_labels_group = input_labels_group

        simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group)
        annotations_group_0 = simple_generator.load_annotations_group(simple_generator.groups[0])
        annotations_group_1 = simple_generator.load_annotations_group(simple_generator.groups[1])

        assert 'bboxes' in annotations_group_0[0]
        assert 'bboxes' in annotations_group_1[0]
        assert 'labels' in annotations_group_0[0]
        assert 'labels' in annotations_group_1[0]
        np.testing.assert_equal(expected_bboxes_group[0], annotations_group_0[0]['bboxes'])
        np.testing.assert_equal(expected_labels_group[0], annotations_group_0[0]['labels'])
        np.testing.assert_equal(expected_bboxes_group[1], annotations_group_1[0]['bboxes'])
        np.testing.assert_equal(expected_labels_group[1], annotations_group_1[0]['labels'])
class TestFilterAnnotations(object):
    def test_simple_filter(self):
        input_bboxes_group = [
            np.array([
                [  0,   0, 10, 10],
                [150, 150, 50, 50]
            ]),
        ]
        input_labels_group = [
            np.array([
                3,
                1
            ]),
        ]

        input_image = np.zeros((500, 500, 3))

        expected_bboxes_group = [
            np.array([
                [0, 0, 10, 10],
            ]),
        ]
        expected_labels_group = [
            np.array([
                3,
            ]),
        ]

        simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group)
        annotations = simple_generator.load_annotations_group(simple_generator.groups[0])
        # expect a UserWarning
        with pytest.warns(UserWarning):
            image_group, annotations_group = simple_generator.filter_annotations([input_image], annotations, simple_generator.groups[0])

        np.testing.assert_equal(expected_bboxes_group[0], annotations_group[0]['bboxes'])
        np.testing.assert_equal(expected_labels_group[0], annotations_group[0]['labels'])

    def test_multiple_filter(self):
        input_bboxes_group = [
            np.array([
                [  0,   0,   10,   10],
                [150, 150,   50,   50],
                [150, 150,  350,  350],
                [350, 350,  150,  150],
                [  1,   1,    2,    2],
                [  2,   2,    1,    1]
            ]),
            np.array([
                [0, 0, -1, -1]
            ]),
            np.array([
                [-10, -10,    0,    0],
                [-10, -10, -100, -100],
                [ 10,  10,  100,  100]
            ]),
            np.array([
                [ 10,  10,  100,  100],
                [ 10,  10,  600,  600]
            ]),
        ]

        input_labels_group = [
            np.array([
                6,
                5,
                4,
                3,
                2,
                1
            ]),
            np.array([
                0
            ]),
            np.array([
                10,
                11,
                12
            ]),
            np.array([
                105,
                107
            ]),
        ]

        input_image = np.zeros((500, 500, 3))

        expected_bboxes_group = [
            np.array([
                [  0,   0,  10,  10],
                [150, 150, 350, 350],
                [  1,   1,   2,   2]
            ]),
            np.zeros((0, 4)),
            np.array([
                [10, 10, 100, 100]
            ]),
            np.array([
                [10, 10, 100, 100]
            ]),
        ]
        expected_labels_group = [
            np.array([
                6,
                4,
                2
            ]),
            np.zeros((0,)),
            np.array([
                12
            ]),
            np.array([
                105
            ]),
        ]

        simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group)
        # expect a UserWarning for each group containing an invalid box
        annotations_group_0 = simple_generator.load_annotations_group(simple_generator.groups[0])
        with pytest.warns(UserWarning):
            image_group, annotations_group_0 = simple_generator.filter_annotations([input_image], annotations_group_0, simple_generator.groups[0])

        annotations_group_1 = simple_generator.load_annotations_group(simple_generator.groups[1])
        with pytest.warns(UserWarning):
            image_group, annotations_group_1 = simple_generator.filter_annotations([input_image], annotations_group_1, simple_generator.groups[1])

        annotations_group_2 = simple_generator.load_annotations_group(simple_generator.groups[2])
        with pytest.warns(UserWarning):
            image_group, annotations_group_2 = simple_generator.filter_annotations([input_image], annotations_group_2, simple_generator.groups[2])

        annotations_group_3 = simple_generator.load_annotations_group(simple_generator.groups[3])
        with pytest.warns(UserWarning):
            image_group, annotations_group_3 = simple_generator.filter_annotations([input_image], annotations_group_3, simple_generator.groups[3])

        np.testing.assert_equal(expected_bboxes_group[0], annotations_group_0[0]['bboxes'])
        np.testing.assert_equal(expected_labels_group[0], annotations_group_0[0]['labels'])

        np.testing.assert_equal(expected_bboxes_group[1], annotations_group_1[0]['bboxes'])
        np.testing.assert_equal(expected_labels_group[1], annotations_group_1[0]['labels'])

        np.testing.assert_equal(expected_bboxes_group[2], annotations_group_2[0]['bboxes'])
        np.testing.assert_equal(expected_labels_group[2], annotations_group_2[0]['labels'])

        np.testing.assert_equal(expected_bboxes_group[3], annotations_group_3[0]['bboxes'])
        np.testing.assert_equal(expected_labels_group[3], annotations_group_3[0]['labels'])

    def test_complete(self):
        input_bboxes_group = [
            np.array([
                [  0,   0, 50, 50],
                [150, 150, 50, 50],  # invalid bbox
            ], dtype=float)
        ]

        input_labels_group = [
            np.array([
                5,  # one object of class 5
                3,  # one object of class 3 with an invalid box
            ], dtype=float)
        ]

        input_image = np.zeros((500, 500, 3), dtype=np.uint8)

        simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group, image=input_image, num_classes=6)
        # expect a UserWarning
        with pytest.warns(UserWarning):
            _, [_, labels_batch] = simple_generator[0]

        # test that only object with class 5 is present in labels_batch
        labels = np.unique(np.argmax(labels_batch == 5, axis=2))
        assert len(labels) == 1 and labels[0] == 0, 'Expected only class 0 to be present, but got classes {}'.format(labels)
--------------------------------------------------------------------------------
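Judging from the expectations above, filter_annotations drops a box when it is degenerate or falls (partly) outside the image. A sketch of that predicate, assuming exactly these conditions (the helper name is invented here):

import numpy as np

def invalid_mask(bboxes, image_shape):
    # bboxes: (N, 4) array of [x1, y1, x2, y2]; image_shape: (height, width, channels)
    height, width = image_shape[:2]
    return (
        (bboxes[:, 2] <= bboxes[:, 0]) |  # x2 <= x1
        (bboxes[:, 3] <= bboxes[:, 1]) |  # y2 <= y1
        (bboxes[:, 0] < 0) | (bboxes[:, 1] < 0) |
        (bboxes[:, 2] > width) | (bboxes[:, 3] > height)
    )

# e.g. on a 500x500 image, [150, 150, 50, 50] and [10, 10, 600, 600] come out
# invalid, while [150, 150, 350, 350] is kept, matching the tests above.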
/tests/preprocessing/test_image.py:
--------------------------------------------------------------------------------
import os
import pytest
from PIL import Image
from keras_retinanet.utils import image
import numpy as np

_STUB_IMG_FNAME = 'stub-image.jpg'


@pytest.fixture(autouse=True)
def run_around_tests(tmp_path):
    """Create a temporary image for each test."""
    rand_img = np.random.randint(0, 255, (3, 3, 3), dtype='uint8')
    Image.fromarray(rand_img).save(os.path.join(tmp_path, _STUB_IMG_FNAME))
    yield


def test_read_image_bgr(tmp_path):
    stub_image_path = os.path.join(tmp_path, _STUB_IMG_FNAME)

    # PIL loads RGB; reverse the channel axis to get the BGR reference
    original_img = np.asarray(Image.open(
        stub_image_path).convert('RGB'))[:, :, ::-1]
    loaded_image = image.read_image_bgr(stub_image_path)

    # Assert images are equal
    np.testing.assert_array_equal(original_img, loaded_image)
--------------------------------------------------------------------------------
/tests/requirements.txt:
--------------------------------------------------------------------------------
check-manifest
image-classifiers
efficientnet
# pytest
pytest-xdist
pytest-cov
pytest-flake8
# flake8
coverage
codecov
--------------------------------------------------------------------------------
/tests/test_losses.py:
--------------------------------------------------------------------------------
import keras_retinanet.losses
from tensorflow import keras

import numpy as np

import pytest


def test_smooth_l1():
    regression = np.array([
        [
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
        ]
    ], dtype=keras.backend.floatx())
    regression = keras.backend.variable(regression)

    # the last column is the anchor state (1 = positive anchor, 0 = ignored)
    regression_target = np.array([
        [
            [0, 0, 0,    1, 1],
            [0, 0, 1,    0, 1],
            [0, 0, 0.05, 0, 1],
            [0, 0, 1,    0, 0],
        ]
    ], dtype=keras.backend.floatx())
    regression_target = keras.backend.variable(regression_target)

    loss = keras_retinanet.losses.smooth_l1()(regression_target, regression)
    loss = keras.backend.eval(loss)

    assert loss == pytest.approx(((1 - 0.5 / 9) * 2 + (0.5 * 9 * 0.05 ** 2)) / 3)
--------------------------------------------------------------------------------
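The asserted value can be reproduced by hand, assuming losses.smooth_l1() uses its default sigma of 3 and normalizes by the number of positive anchors:

# sigma = 3, so sigma^2 = 9; the per-element loss is 0.5 * 9 * x^2 for
# |x| < 1/9 and |x| - 0.5/9 otherwise. Three anchors have state 1 and their
# absolute regression errors are 1, 1 and 0.05:
big   = 1 - 0.5 / 9          # linear branch, |x| = 1
small = 0.5 * 9 * 0.05 ** 2  # quadratic branch, |x| = 0.05
print((2 * big + small) / 3)  # ~0.6334, the value asserted above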
/tests/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fizyr/keras-retinanet/7ac91dfbbacce77d6d9633fc09e16cd0ee71fd5e/tests/utils/__init__.py
--------------------------------------------------------------------------------
/tests/utils/test_anchors.py:
--------------------------------------------------------------------------------
import numpy as np
import configparser
from tensorflow import keras

from keras_retinanet.utils.anchors import anchors_for_shape, AnchorParameters
from keras_retinanet.utils.config import read_config_file, parse_anchor_parameters


def test_config_read():
    config = read_config_file('tests/test-data/config/config.ini')
    assert 'anchor_parameters' in config
    assert 'sizes' in config['anchor_parameters']
    assert 'strides' in config['anchor_parameters']
    assert 'ratios' in config['anchor_parameters']
    assert 'scales' in config['anchor_parameters']
    assert config['anchor_parameters']['sizes']   == '32 64 128 256 512'
    assert config['anchor_parameters']['strides'] == '8 16 32 64 128'
    assert config['anchor_parameters']['ratios']  == '0.5 1 2 3'
    assert config['anchor_parameters']['scales']  == '1 1.2 1.6'


def create_anchor_params_config():
    config = configparser.ConfigParser()
    config['anchor_parameters'] = {}
    config['anchor_parameters']['sizes']   = '32 64 128 256 512'
    config['anchor_parameters']['strides'] = '8 16 32 64 128'
    config['anchor_parameters']['ratios']  = '0.5 1'
    config['anchor_parameters']['scales']  = '1 1.2 1.6'

    return config


def test_parse_anchor_parameters():
    config = create_anchor_params_config()
    anchor_params_parsed = parse_anchor_parameters(config)

    sizes   = [32, 64, 128, 256, 512]
    strides = [8, 16, 32, 64, 128]
    ratios  = np.array([0.5, 1], keras.backend.floatx())
    scales  = np.array([1, 1.2, 1.6], keras.backend.floatx())

    assert sizes == anchor_params_parsed.sizes
    assert strides == anchor_params_parsed.strides
    np.testing.assert_equal(ratios, anchor_params_parsed.ratios)
    np.testing.assert_equal(scales, anchor_params_parsed.scales)


def test_anchors_for_shape_dimensions():
    sizes   = [32, 64, 128]
    strides = [8, 16, 32]
    ratios  = np.array([0.5, 1, 2, 3], keras.backend.floatx())
    scales  = np.array([1, 1.2, 1.6], keras.backend.floatx())
    anchor_params = AnchorParameters(sizes, strides, ratios, scales)

    pyramid_levels = [3, 4, 5]
    image_shape    = (64, 64)
    all_anchors    = anchors_for_shape(image_shape, pyramid_levels=pyramid_levels, anchor_params=anchor_params)

    assert all_anchors.shape == (1008, 4)
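
# The 1008 above follows from the test's shapes: a 64x64 image at pyramid
# levels 3, 4 and 5 yields feature maps of 8x8, 4x4 and 2x2 (64 + 16 + 4 = 84
# locations), and 4 ratios x 3 scales = 12 anchors per location, so
# 84 * 12 = 1008 anchors in total.
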
def test_anchors_for_shape_values():
    sizes   = [12]
    strides = [8]
    ratios  = np.array([1, 2], keras.backend.floatx())
    scales  = np.array([1, 2], keras.backend.floatx())
    anchor_params = AnchorParameters(sizes, strides, ratios, scales)

    pyramid_levels = [3]
    image_shape    = (16, 16)
    all_anchors    = anchors_for_shape(image_shape, pyramid_levels=pyramid_levels, anchor_params=anchor_params)

    # the sixteen anchors are expected in this order: grid positions
    # (row-major, x varying fastest), then ratios, then scales; each box is its
    # center +/- half its extent, with extent
    # (size * scale / sqrt(ratio), size * scale * sqrt(ratio)).
    # using almost_equal for floating point imprecisions
    idx = 0
    for cy in (strides[0] / 2, strides[0] * 3 / 2):
        for cx in (strides[0] / 2, strides[0] * 3 / 2):
            for ratio in ratios:
                for scale in scales:
                    w = sizes[0] * scale / np.sqrt(ratio)
                    h = sizes[0] * scale * np.sqrt(ratio)
                    np.testing.assert_almost_equal(
                        all_anchors[idx, :],
                        [cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2],
                        decimal=6,
                    )
                    idx += 1
    assert idx == all_anchors.shape[0]
--------------------------------------------------------------------------------
/tests/utils/test_transform.py:
--------------------------------------------------------------------------------
import numpy as np
from numpy.testing import assert_almost_equal
from math import pi

from keras_retinanet.utils.transform import (
    colvec,
    transform_aabb,
    rotation, random_rotation,
    translation, random_translation,
    scaling, random_scaling,
    shear, random_shear,
    random_flip,
    random_transform,
    random_transform_generator,
    change_transform_origin,
)


def test_colvec():
    assert np.array_equal(colvec(0), np.array([[0]]))
    assert np.array_equal(colvec(1, 2, 3), np.array([[1], [2], [3]]))
    assert np.array_equal(colvec(-1, -2), np.array([[-1], [-2]]))


def test_rotation():
    assert_almost_equal(colvec( 1,  0, 1), rotation(0.0 * pi).dot(colvec(1, 0, 1)))
    assert_almost_equal(colvec( 0,  1, 1), rotation(0.5 * pi).dot(colvec(1, 0, 1)))
    assert_almost_equal(colvec(-1,  0, 1), rotation(1.0 * pi).dot(colvec(1, 0, 1)))
    assert_almost_equal(colvec( 0, -1, 1), rotation(1.5 * pi).dot(colvec(1, 0, 1)))
    assert_almost_equal(colvec( 1,  0, 1), rotation(2.0 * pi).dot(colvec(1, 0, 1)))

    assert_almost_equal(colvec( 0,  1, 1), rotation(0.0 * pi).dot(colvec(0, 1, 1)))
    assert_almost_equal(colvec(-1,  0, 1), rotation(0.5 * pi).dot(colvec(0, 1, 1)))
    assert_almost_equal(colvec( 0, -1, 1), rotation(1.0 * pi).dot(colvec(0, 1, 1)))
    assert_almost_equal(colvec( 1,  0, 1), rotation(1.5 * pi).dot(colvec(0, 1, 1)))
    assert_almost_equal(colvec( 0,  1, 1), rotation(2.0 * pi).dot(colvec(0, 1, 1)))
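
# The assertions in test_rotation are consistent with rotation(angle) returning
# the homogeneous 2D counter-clockwise rotation matrix:
#
#     [[cos(a), -sin(a), 0],
#      [sin(a),  cos(a), 0],
#      [0,       0,      1]]
#
# e.g. rotation(0.5 * pi).dot(colvec(1, 0, 1)) gives (0, 1, 1), as asserted above.
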
def test_random_rotation():
    prng = np.random.RandomState(0)
    for i in range(100):
        assert_almost_equal(1, np.linalg.det(random_rotation(-i, i, prng)))


def test_translation():
    assert_almost_equal(colvec( 1,  2, 1), translation(colvec( 0,  0)).dot(colvec(1, 2, 1)))
    assert_almost_equal(colvec( 4,  6, 1), translation(colvec( 3,  4)).dot(colvec(1, 2, 1)))
    assert_almost_equal(colvec(-2, -2, 1), translation(colvec(-3, -4)).dot(colvec(1, 2, 1)))


def assert_is_translation(transform, min, max):
    assert transform.shape == (3, 3)
    assert np.array_equal(transform[:, 0:2], np.eye(3, 2))
    assert transform[2, 2] == 1
    assert np.greater_equal(transform[0:2, 2], min).all()
    assert np.less(transform[0:2, 2], max).all()


def test_random_translation():
    prng = np.random.RandomState(0)
    min = (-10, -20)
    max = (20, 10)
    for i in range(100):
        assert_is_translation(random_translation(min, max, prng), min, max)


def test_shear():
    assert_almost_equal(colvec( 1,  2, 1), shear(0.0 * pi).dot(colvec(1, 2, 1)))
    assert_almost_equal(colvec(-1,  0, 1), shear(0.5 * pi).dot(colvec(1, 2, 1)))
    assert_almost_equal(colvec( 1, -2, 1), shear(1.0 * pi).dot(colvec(1, 2, 1)))
    assert_almost_equal(colvec( 3,  0, 1), shear(1.5 * pi).dot(colvec(1, 2, 1)))
    assert_almost_equal(colvec( 1,  2, 1), shear(2.0 * pi).dot(colvec(1, 2, 1)))


def assert_is_shear(transform):
    assert transform.shape == (3, 3)
    assert np.array_equal(transform[:, 0], [1, 0, 0])
    assert np.array_equal(transform[:, 2], [0, 0, 1])
    assert transform[2, 1] == 0
    # sin^2 + cos^2 == 1
    assert_almost_equal(1, transform[0, 1] ** 2 + transform[1, 1] ** 2)


def test_random_shear():
    prng = np.random.RandomState(0)
    for i in range(100):
        assert_is_shear(random_shear(-pi, pi, prng))


def test_scaling():
    assert_almost_equal(colvec(1.0, 2, 1), scaling(colvec(1.0, 1.0)).dot(colvec(1, 2, 1)))
    assert_almost_equal(colvec(0.0, 2, 1), scaling(colvec(0.0, 1.0)).dot(colvec(1, 2, 1)))
    assert_almost_equal(colvec(1.0, 0, 1), scaling(colvec(1.0, 0.0)).dot(colvec(1, 2, 1)))
    assert_almost_equal(colvec(0.5, 4, 1), scaling(colvec(0.5, 2.0)).dot(colvec(1, 2, 1)))


def assert_is_scaling(transform, min, max):
    assert transform.shape == (3, 3)
    assert np.array_equal(transform[2, :], [0, 0, 1])
    assert np.array_equal(transform[:, 2], [0, 0, 1])
    assert transform[1, 0] == 0
    assert transform[0, 1] == 0
    assert np.greater_equal(np.diagonal(transform)[:2], min).all()
    assert np.less(np.diagonal(transform)[:2], max).all()


def test_random_scaling():
    prng = np.random.RandomState(0)
    min = (0.1, 0.2)
    max = (20, 10)
    for i in range(100):
        assert_is_scaling(random_scaling(min, max, prng), min, max)


def assert_is_flip(transform):
    assert transform.shape == (3, 3)
    assert np.array_equal(transform[2, :], [0, 0, 1])
    assert np.array_equal(transform[:, 2], [0, 0, 1])
    assert transform[1, 0] == 0
    assert transform[0, 1] == 0
    assert abs(transform[0, 0]) == 1
    assert abs(transform[1, 1]) == 1
def test_random_flip():
    prng = np.random.RandomState(0)
    for i in range(100):
        assert_is_flip(random_flip(0.5, 0.5, prng))


def test_random_transform():
    prng = np.random.RandomState(0)
    for i in range(100):
        transform = random_transform(prng=prng)
        assert np.array_equal(transform, np.identity(3))

    for i, transform in zip(range(100), random_transform_generator(prng=np.random.RandomState())):
        assert np.array_equal(transform, np.identity(3))


def test_transform_aabb():
    assert np.array_equal([1, 2, 3, 4], transform_aabb(np.identity(3), [1, 2, 3, 4]))
    assert_almost_equal([-3, -4, -1, -2], transform_aabb(rotation(pi), [1, 2, 3, 4]))
    assert_almost_equal([ 2,  4,  4,  6], transform_aabb(translation([1, 2]), [1, 2, 3, 4]))


def test_change_transform_origin():
    assert np.array_equal(change_transform_origin(translation([3, 4]), [1, 2]), translation([3, 4]))
    assert_almost_equal(colvec(1, 2, 1), change_transform_origin(rotation(pi), [1, 2]).dot(colvec(1, 2, 1)))
    assert_almost_equal(colvec(0, 0, 1), change_transform_origin(rotation(pi), [1, 2]).dot(colvec(2, 4, 1)))
    assert_almost_equal(colvec(0, 0, 1), change_transform_origin(scaling([0.5, 0.5]), [-2, -4]).dot(colvec(2, 4, 1)))
--------------------------------------------------------------------------------
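A worked example for the last test above, assuming change_transform_origin(M, origin) conjugates M with translations (translate by -origin, apply M, translate back): a rotation by pi about (1, 2) leaves (1, 2) fixed and maps (2, 4) onto (0, 0), up to floating point error.

from math import pi
from keras_retinanet.utils.transform import change_transform_origin, colvec, rotation

M = change_transform_origin(rotation(pi), [1, 2])
print(M.dot(colvec(1, 2, 1)).ravel())  # ~ [1, 2, 1]: the origin stays put
print(M.dot(colvec(2, 4, 1)).ravel())  # ~ [0, 0, 1]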