├── utils
│   ├── __init__.py
│   ├── compute_overlap.cpython-36m-x86_64-linux-gnu.so
│   ├── model.py
│   ├── keras_version.py
│   ├── compute_overlap.pyx
│   ├── config.py
│   ├── colors.py
│   ├── coco_eval.py
│   ├── visualization.py
│   ├── transform.py
│   ├── eval.py
│   └── image.py
├── yolo
│   ├── __init__.py
│   ├── eval
│   │   ├── __init__.py
│   │   ├── pascal.py
│   │   ├── coco.py
│   │   └── common.py
│   ├── generators
│   │   ├── __init__.py
│   │   ├── coco.py
│   │   ├── pascal.py
│   │   └── csv_.py
│   ├── config.py
│   ├── README.md
│   ├── inference.py
│   └── model.py
├── augmentor
│   ├── __init__.py
│   ├── misc.py
│   └── color.py
├── generators
│   ├── __init__.py
│   ├── coco_generator.py
│   ├── voc_generator.py
│   └── csv_generator.py
├── test
│   ├── 004456.jpg
│   ├── 005770.jpg
│   └── 006408.jpg
├── configure.py
├── requirements.txt
├── .github
│   └── stale.yml
├── initializers.py
├── setup.py
├── .gitignore
├── README.md
├── inference.py
├── models
│   ├── vgg.py
│   ├── densenet.py
│   ├── mobilenet.py
│   ├── __init__.py
│   └── resnet.py
├── util_graphs.py
├── callbacks.py
└── losses.py
/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /yolo/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /augmentor/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /generators/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /yolo/eval/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /yolo/generators/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/004456.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuannianz/FSAF/HEAD/test/004456.jpg -------------------------------------------------------------------------------- /test/005770.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuannianz/FSAF/HEAD/test/005770.jpg -------------------------------------------------------------------------------- /test/006408.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuannianz/FSAF/HEAD/test/006408.jpg -------------------------------------------------------------------------------- /yolo/config.py: -------------------------------------------------------------------------------- 1 | MAX_NUM_GT_BOXES = 100 2 | POS_SCALE = 0.2 3 | IGNORE_SCALE = 0.5 4 | STRIDES = (8, 16, 32) 5 | 6 | -------------------------------------------------------------------------------- /configure.py: -------------------------------------------------------------------------------- 1 | MAX_NUM_GT_BOXES = 100 2 | POS_SCALE = 0.2 3 | IGNORE_SCALE = 0.5 4 | STRIDES = (8, 16, 32, 64, 128) 5 | 6 | -------------------------------------------------------------------------------- /utils/compute_overlap.cpython-36m-x86_64-linux-gnu.so:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuannianz/FSAF/HEAD/utils/compute_overlap.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Keras==2.2.5 2 | keras-resnet==0.2.0 3 | opencv-contrib-python==3.4.2.17 4 | opencv-python==3.4.2.17 5 | Pillow==6.2.0 6 | tensorflow-gpu==1.15.2 7 | progressbar2 8 | git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI 9 | -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # Number of days of inactivity before an issue becomes stale 2 | daysUntilStale: 5 3 | # Number of days of inactivity before a stale issue is closed 4 | daysUntilClose: 3 5 | # Issues with these labels will never be considered stale 6 | exemptLabels: 7 | - pinned 8 | - security 9 | # Label to use when marking an issue as stale 10 | staleLabel: wontfix 11 | # Comment to post when marking an issue as stale. Set to `false` to disable 12 | markComment: > 13 | This issue has been automatically marked as stale because it has not had 14 | recent activity. It will be closed if no further activity occurs. Thank you 15 | for your contributions. 16 | # Comment to post when closing a stale issue. Set to `false` to disable 17 | closeComment: false -------------------------------------------------------------------------------- /utils/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | def freeze(model): 19 | """ 20 | Set all layers in a model to non-trainable. 21 | 22 | The weights for these layers will not be updated during training. 23 | 24 | This function modifies the given model in-place, 25 | but it also returns the modified model to allow easy chaining with other functions. 26 | """ 27 | for layer in model.layers: 28 | layer.trainable = False 29 | return model 30 | -------------------------------------------------------------------------------- /initializers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import keras 18 | 19 | import numpy as np 20 | import math 21 | 22 | 23 | class PriorProbability(keras.initializers.Initializer): 24 | """ Apply a prior probability to the weights. 25 | """ 26 | 27 | def __init__(self, probability=0.01): 28 | self.probability = probability 29 | 30 | def get_config(self): 31 | return { 32 | 'probability': self.probability 33 | } 34 | 35 | def __call__(self, shape, dtype=None): 36 | # set bias to -log((1 - p)/p) for foreground 37 | result = np.ones(shape, dtype=dtype) * -math.log((1 - self.probability) / self.probability) 38 | 39 | return result 40 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | from setuptools.extension import Extension 3 | from distutils.command.build_ext import build_ext as DistUtilsBuildExt 4 | 5 | 6 | class BuildExtension(setuptools.Command): 7 | description = DistUtilsBuildExt.description 8 | user_options = DistUtilsBuildExt.user_options 9 | boolean_options = DistUtilsBuildExt.boolean_options 10 | help_options = DistUtilsBuildExt.help_options 11 | 12 | def __init__(self, *args, **kwargs): 13 | from setuptools.command.build_ext import build_ext as SetupToolsBuildExt 14 | 15 | # Bypass __setattr__ to avoid infinite recursion. 16 | self.__dict__['_command'] = SetupToolsBuildExt(*args, **kwargs) 17 | 18 | def __getattr__(self, name): 19 | return getattr(self._command, name) 20 | 21 | def __setattr__(self, name, value): 22 | setattr(self._command, name, value) 23 | 24 | def initialize_options(self, *args, **kwargs): 25 | return self._command.initialize_options(*args, **kwargs) 26 | 27 | def finalize_options(self, *args, **kwargs): 28 | ret = self._command.finalize_options(*args, **kwargs) 29 | import numpy 30 | self.include_dirs.append(numpy.get_include()) 31 | return ret 32 | 33 | def run(self, *args, **kwargs): 34 | return self._command.run(*args, **kwargs) 35 | 36 | 37 | extensions = [ 38 | Extension( 39 | 'utils.compute_overlap', 40 | ['utils/compute_overlap.pyx'] 41 | ), 42 | ] 43 | 44 | setuptools.setup( 45 | cmdclass={'build_ext': BuildExtension}, 46 | packages=setuptools.find_packages(), 47 | ext_modules=extensions, 48 | setup_requires=["cython>=0.28", "numpy>=1.14.0"] 49 | ) 50 | -------------------------------------------------------------------------------- /utils/keras_version.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from __future__ import print_function 18 | 19 | import keras 20 | import sys 21 | 22 | minimum_keras_version = 2, 2, 4 23 | 24 | 25 | def keras_version(): 26 | """ 27 | Get the Keras version. 28 | 29 | Returns 30 | tuple of (major, minor, patch).
e.g. (2, 2, 4) 31 | """ 32 | return tuple(map(int, keras.__version__.split('.'))) 33 | 34 | 35 | def keras_version_ok(): 36 | """ 37 | Check if the current Keras version is at least the minimum version. 38 | """ 39 | return keras_version() >= minimum_keras_version 40 | 41 | 42 | def assert_keras_version(): 43 | """ 44 | Assert that the Keras version is up to date. 45 | """ 46 | detected = keras.__version__ 47 | required = '.'.join(map(str, minimum_keras_version)) 48 | assert(keras_version() >= minimum_keras_version), 'You are using keras version {}. The minimum required version is {}.'.format(detected, required) 49 | 50 | 51 | def check_keras_version(): 52 | """ 53 | Check that the Keras version is up to date. If it isn't, print an error message and exit the script. 54 | """ 55 | try: 56 | assert_keras_version() 57 | except AssertionError as e: 58 | print(e, file=sys.stderr) 59 | sys.exit(1) 60 | -------------------------------------------------------------------------------- /utils/compute_overlap.pyx: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Sergey Karayev 6 | # -------------------------------------------------------- 7 | 8 | cimport cython 9 | import numpy as np 10 | cimport numpy as np 11 | 12 | 13 | def compute_overlap( 14 | np.ndarray[double, ndim=2] boxes, 15 | np.ndarray[double, ndim=2] query_boxes 16 | ): 17 | """ 18 | Args 19 | boxes: (N, 4) ndarray of float 20 | query_boxes: (K, 4) ndarray of float 21 | 22 | Returns 23 | overlaps: (N, K) ndarray of IoU overlap between boxes and query_boxes 24 | """ 25 | cdef unsigned int N = boxes.shape[0] 26 | cdef unsigned int K = query_boxes.shape[0] 27 | cdef np.ndarray[double, ndim=2] overlaps = np.zeros((N, K), dtype=np.float64) 28 | cdef double iw, ih, box_area 29 | cdef double ua 30 | cdef unsigned int k, n 31 | for k in range(K): 32 | box_area = ( 33 | (query_boxes[k, 2] - query_boxes[k, 0] + 1) * 34 | (query_boxes[k, 3] - query_boxes[k, 1] + 1) 35 | ) 36 | for n in range(N): 37 | iw = ( 38 | min(boxes[n, 2], query_boxes[k, 2]) - 39 | max(boxes[n, 0], query_boxes[k, 0]) + 1 40 | ) 41 | if iw > 0: 42 | ih = ( 43 | min(boxes[n, 3], query_boxes[k, 3]) - 44 | max(boxes[n, 1], query_boxes[k, 1]) + 1 45 | ) 46 | if ih > 0: 47 | ua = np.float64( 48 | (boxes[n, 2] - boxes[n, 0] + 1) * 49 | (boxes[n, 3] - boxes[n, 1] + 1) + 50 | box_area - iw * ih 51 | ) 52 | overlaps[n, k] = iw * ih / ua 53 | return overlaps 54 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License.
15 | """ 16 | 17 | import configparser 18 | import numpy as np 19 | import keras 20 | from utils.anchors import AnchorParameters 21 | 22 | 23 | def read_config_file(config_path): 24 | config = configparser.ConfigParser() 25 | 26 | with open(config_path, 'r') as file: 27 | config.read_file(file) 28 | 29 | assert 'anchor_parameters' in config, \ 30 | "Malformed config file. Verify that it contains the anchor_parameters section." 31 | 32 | config_keys = set(config['anchor_parameters']) 33 | default_keys = set(AnchorParameters.default.__dict__.keys()) 34 | 35 | assert config_keys <= default_keys, \ 36 | "Malformed config file. These keys are not valid: {}".format(config_keys - default_keys) 37 | 38 | return config 39 | 40 | 41 | def parse_anchor_parameters(config): 42 | ratios = np.array(list(map(float, config['anchor_parameters']['ratios'].split(' '))), keras.backend.floatx()) 43 | scales = np.array(list(map(float, config['anchor_parameters']['scales'].split(' '))), keras.backend.floatx()) 44 | sizes = list(map(int, config['anchor_parameters']['sizes'].split(' '))) 45 | strides = list(map(int, config['anchor_parameters']['strides'].split(' '))) 46 | 47 | return AnchorParameters(sizes, strides, ratios, scales) 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | .idea/ 106 | datasets/ 107 | logs/ 108 | !utils/compute_overlap.cpython-36m-x86_64-linux-gnu.so 109 | snapshots/ 110 | checkpoints/ 111 | -------------------------------------------------------------------------------- /yolo/README.md: -------------------------------------------------------------------------------- 1 | # FSAF 2 | This is an implementation of [FSAF](https://arxiv.org/abs/1903.00621) on keras and Tensorflow. 
The project is based on [qqwweee/keras-yolo3](https://github.com/qqwweee/keras-yolo3) and [fizyr/keras-retinanet](https://github.com/fizyr/keras-retinanet). 3 | Thanks for their hard work. 4 | 5 | As the authors write, the FSAF module can be plugged into any single-shot detector with an FPN-like structure smoothly. 6 | I have also tried it on yolo3. Anchor-free yolo3 (with FSAF) gets comparable performance to the anchor-based counterpart. But you don't need to pre-compute the anchor sizes any more. 7 | And it is much better and faster than the one based on retinanet. 8 | 9 | It can also converge quite quickly. After the first epoch (batch_size=64, steps=1000), we can get an mAP50 of 0.6xxx on the val dataset. 10 | 11 | ## Test 12 | 1. I trained on Pascal VOC2012 trainval.txt + Pascal VOC2007 train.txt, and validated on Pascal VOC2007 val.txt. There are 14041 images for training and 2510 images for validation. 13 | 2. The best evaluation result (score_threshold=0.01, mAP50, image_size=416) on VOC2007 test is 0.8358. I have only trained once. 14 | 3. Pretrained yolo and fsaf weights are here: [baidu netdisk](https://pan.baidu.com/s/1QoGXnajcohj9P4yCVwJ4Yw), extract code: qab7 15 | 4. `python3 yolo/inference.py` to test your image by specifying the image path and model path there. 16 | 17 | ## Train 18 | ### build dataset (Pascal VOC, other types please refer to [fizyr/keras-retinanet](https://github.com/fizyr/keras-retinanet)) 19 | * Download VOC2007 and VOC2012, copy all image files from VOC2007 to VOC2012. 20 | * Append VOC2007 train.txt to VOC2012 trainval.txt. 21 | * Overwrite VOC2012 val.txt with VOC2007 val.txt. 22 | ### train 23 | * **STEP1**: `python3 yolo/train.py --freeze-body darknet --gpu 0 --batch-size 32 --random-transform pascal datasets/VOC2012` to start training with lr=1e-3, then set lr=1e-4 when val mAP continues to drop. 24 | * **STEP2**: `python3 yolo/train.py --snapshot --freeze-body none --gpu 0 --batch-size 32 --random-transform pascal datasets/VOC2012` to start training with lr=1e-5, and then set lr=1e-6 when val mAP continues to drop. 25 | ## Evaluate 26 | * `python3 yolo/eval/common.py` to evaluate by specifying the model path there. 27 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FSAF 2 | This is an implementation of [FSAF](https://arxiv.org/abs/1903.00621) on keras and Tensorflow. The project is based on [fizyr/keras-retinanet](https://github.com/fizyr/keras-retinanet) 3 | and the fsaf branch of [zccstig/mmdetection](https://github.com/zccstig/mmdetection/tree/fsaf). 4 | Thanks for their hard work. 5 | 6 | As the authors write, **the FSAF module can be plugged into any single-shot detector with an FPN-like structure smoothly**. 7 | I have also tried it on [yolo3](yolo). Anchor-free yolo3 (with FSAF) gets comparable performance to the anchor-based counterpart. But you don't need to pre-compute the anchor sizes any more. 8 | And it is much better and faster than the one based on retinanet. 9 | 10 | **Updates** 11 | - [03/05/2020] The author of the paper has released a new paper [SAPD](https://arxiv.org/abs/1911.12448), which is based on FSAF. 12 | I have implemented it at [xuannianz/SAPD](https://github.com/xuannianz/SAPD). 13 | 14 | 15 | ## Test 16 | 1. I trained on Pascal VOC2012 trainval.txt + Pascal VOC2007 train.txt, and validated on Pascal VOC2007 val.txt. There are 14041 images for training and 2510 images for validation. 17 | 2.
The best evaluation results (score_threshold=0.05) on VOC2007 test are: 18 | 19 | | backbone | mAP50 | 20 | | ---- | ---- | 21 | | resnet50 | 0.7248 | 22 | | resnet101 | 0.7652 | 23 | 24 | 3. Pretrained models are here: 25 | [baidu netdisk](https://pan.baidu.com/s/1ZdHvR-03XqHvxWG0rLCw1g) extract code: rbrr 26 | [google drive](https://drive.google.com/open?id=1Hcgxp5OwqNsAx-HYgcIhLat1OOHKnvJ2) 27 | 28 | 4. `python3 inference.py` to test your image by specifying the image path and model path there. 29 | 30 | ![image1](test/004456.jpg) 31 | ![image2](test/005770.jpg) 32 | ![image3](test/006408.jpg) 33 | 34 | 35 | ## Train 36 | ### build dataset (Pascal VOC, other types please refer to [fizyr/keras-retinanet](https://github.com/fizyr/keras-retinanet)) 37 | * Download VOC2007 and VOC2012, copy all image files from VOC2007 to VOC2012. 38 | * Append VOC2007 train.txt to VOC2012 trainval.txt. 39 | * Overwrite VOC2012 val.txt with VOC2007 val.txt. 40 | ### train 41 | * `python3 train.py --backbone resnet50 --gpu 0 --random-transform pascal datasets/VOC2012` to start training. 42 | ## Evaluate 43 | * `python3 utils/eval.py` to evaluate by specifying the model path there. 44 | -------------------------------------------------------------------------------- /utils/colors.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | def label_color(label): 5 | """ Return a color from a set of predefined colors. Contains 80 colors in total. 6 | 7 | Args 8 | label: The label to get the color for. 9 | 10 | Returns 11 | A list of three values representing an RGB color. 12 | 13 | If no color is defined for a certain label, the color green is returned and a warning is issued. 14 | """ 15 | if label < len(colors): 16 | return colors[label] 17 | else: 18 | warnings.warn('Label {} has no color, returning default.'.format(label)) 19 | return (0, 255, 0) 20 | 21 | 22 | """ 23 | Generated using: 24 | 25 | ``` 26 | colors = [list((matplotlib.colors.hsv_to_rgb([x, 1.0, 1.0]) * 255).astype(int)) for x in np.arange(0, 1, 1.0 / 80)] 27 | shuffle(colors) 28 | pprint(colors) 29 | ``` 30 | """ 31 | colors = [ 32 | [31 , 0 , 255] , 33 | [0 , 159 , 255] , 34 | [255 , 95 , 0] , 35 | [255 , 19 , 0] , 36 | [255 , 0 , 0] , 37 | [255 , 38 , 0] , 38 | [0 , 255 , 25] , 39 | [255 , 0 , 133] , 40 | [255 , 172 , 0] , 41 | [108 , 0 , 255] , 42 | [0 , 82 , 255] , 43 | [0 , 255 , 6] , 44 | [255 , 0 , 152] , 45 | [223 , 0 , 255] , 46 | [12 , 0 , 255] , 47 | [0 , 255 , 178] , 48 | [108 , 255 , 0] , 49 | [184 , 0 , 255] , 50 | [255 , 0 , 76] , 51 | [146 , 255 , 0] , 52 | [51 , 0 , 255] , 53 | [0 , 197 , 255] , 54 | [255 , 248 , 0] , 55 | [255 , 0 , 19] , 56 | [255 , 0 , 38] , 57 | [89 , 255 , 0] , 58 | [127 , 255 , 0] , 59 | [255 , 153 , 0] , 60 | [0 , 255 , 255] , 61 | [0 , 255 , 216] , 62 | [0 , 255 , 121] , 63 | [255 , 0 , 248] , 64 | [70 , 0 , 255] , 65 | [0 , 255 , 159] , 66 | [0 , 216 , 255] , 67 | [0 , 6 , 255] , 68 | [0 , 63 , 255] , 69 | [31 , 255 , 0] , 70 | [255 , 57 , 0] , 71 | [255 , 0 , 210] , 72 | [0 , 255 , 102] , 73 | [242 , 255 , 0] , 74 | [255 , 191 , 0] , 75 | [0 , 255 , 63] , 76 | [255 , 0 , 95] , 77 | [146 , 0 , 255] , 78 | [184 , 255 , 0] , 79 | [255 , 114 , 0] , 80 | [0 , 255 , 235] , 81 | [255 , 229 , 0] , 82 | [0 , 178 , 255] , 83 | [255 , 0 , 114] , 84 | [255 , 0 , 57] , 85 | [0 , 140 , 255] , 86 | [0 , 121 , 255] , 87 | [12 , 255 , 0] , 88 | [255 , 210 , 0] , 89 | [0 , 255 , 44] , 90 | [165 , 255 , 0] , 91 | [0 , 25 , 255] , 92 | [0 , 255 , 140] , 93 | [0 , 101 ,
255] , 94 | [0 , 255 , 82] , 95 | [223 , 255 , 0] , 96 | [242 , 0 , 255] , 97 | [89 , 0 , 255] , 98 | [165 , 0 , 255] , 99 | [70 , 255 , 0] , 100 | [255 , 0 , 172] , 101 | [255 , 76 , 0] , 102 | [203 , 255 , 0] , 103 | [204 , 0 , 255] , 104 | [255 , 0 , 229] , 105 | [255 , 133 , 0] , 106 | [127 , 0 , 255] , 107 | [0 , 235 , 255] , 108 | [0 , 255 , 197] , 109 | [255 , 0 , 191] , 110 | [0 , 44 , 255] , 111 | [50 , 255 , 0] 112 | ] 113 | -------------------------------------------------------------------------------- /utils/coco_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from pycocotools.cocoeval import COCOeval 18 | 19 | import keras 20 | import numpy as np 21 | import json 22 | 23 | import progressbar 24 | 25 | assert (callable(progressbar.progressbar)), "Using wrong progressbar module, install 'progressbar2' instead." 26 | 27 | 28 | def evaluate_coco(generator, model, threshold=0.05): 29 | """ Use the pycocotools to evaluate a COCO model on a dataset. 30 | 31 | Args 32 | generator : The generator for generating the evaluation data. 33 | model : The model to evaluate. 34 | threshold : The score threshold to use. 
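    Example (a sketch; `generator` is assumed to be a COCO generator from this repo's generators package and `model` a converted inference model):
        stats = evaluate_coco(generator, model, threshold=0.05)
        # stats follows pycocotools' summarize() order: stats[0] = AP@[.5:.95], stats[1] = AP@.5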
35 | """ 36 | # start collecting results 37 | results = [] 38 | image_ids = [] 39 | for index in progressbar.progressbar(range(generator.size()), prefix='COCO evaluation: '): 40 | image = generator.load_image(index) 41 | image = generator.preprocess_image(image) 42 | image, scale = generator.resize_image(image) 43 | 44 | if keras.backend.image_data_format() == 'channels_first': 45 | image = image.transpose((2, 0, 1)) 46 | 47 | # run network 48 | boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0)) 49 | 50 | # correct boxes for image scale 51 | boxes /= scale 52 | 53 | # change to (x, y, w, h) (MS COCO standard) 54 | boxes[:, :, 2] -= boxes[:, :, 0] 55 | boxes[:, :, 3] -= boxes[:, :, 1] 56 | 57 | # compute predicted labels and scores 58 | for box, score, label in zip(boxes[0], scores[0], labels[0]): 59 | # scores are sorted, so we can break 60 | if score < threshold: 61 | break 62 | 63 | # append detection for each positively labeled class 64 | image_result = { 65 | 'image_id': generator.image_ids[index], 66 | 'category_id': generator.label_to_coco_label(label), 67 | 'score': float(score), 68 | 'bbox': box.tolist(), 69 | } 70 | 71 | # append detection to results 72 | results.append(image_result) 73 | 74 | # append image to list of processed images 75 | image_ids.append(generator.image_ids[index]) 76 | 77 | if not len(results): 78 | return 79 | 80 | # write output 81 | json.dump(results, open('{}_bbox_results.json'.format(generator.set_name), 'w'), indent=4) 82 | json.dump(image_ids, open('{}_processed_image_ids.json'.format(generator.set_name), 'w'), indent=4) 83 | 84 | # load results in COCO evaluation tool 85 | coco_true = generator.coco 86 | coco_pred = coco_true.loadRes('{}_bbox_results.json'.format(generator.set_name)) 87 | 88 | # run COCO evaluation 89 | coco_eval = COCOeval(coco_true, coco_pred, 'bbox') 90 | coco_eval.params.imgIds = image_ids 91 | coco_eval.evaluate() 92 | coco_eval.accumulate() 93 | coco_eval.summarize() 94 | return coco_eval.stats 95 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import models 3 | from utils.image import read_image_bgr, preprocess_image, resize_image 4 | from utils.visualization import draw_box, draw_caption 5 | from utils.colors import label_color 6 | 7 | # import miscellaneous modules 8 | import matplotlib.pyplot as plt 9 | import cv2 10 | import os 11 | import numpy as np 12 | import time 13 | import glob 14 | import os.path as osp 15 | 16 | # set tf backend to allow memory to grow, instead of claiming everything 17 | import tensorflow as tf 18 | 19 | 20 | def get_session(): 21 | config = tf.ConfigProto() 22 | config.gpu_options.allow_growth = True 23 | return tf.Session(config=config) 24 | 25 | 26 | # use this environment flag to change which GPU to use 27 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 28 | 29 | # set the modified tf session as backend in keras 30 | keras.backend.set_session(get_session()) 31 | # adjust this to point to your downloaded/trained model 32 | # models can be downloaded here: https://github.com/fizyr/keras-retinanet/releases 33 | model_path = '/home/adam/workspace/github/xuannianz/carrot/fsaf/snapshots/2019-10-05/resnet101_pascal_47_0.7652.h5' 34 | 35 | # load retinanet model 36 | # model = models.load_model(model_path, backbone_name='resnet101') 37 | 38 | # if the model is not converted to an inference model, use the line below 39 | # see: 
https://github.com/fizyr/keras-retinanet#converting-a-training-model-to-inference-model 40 | from models.resnet import resnet_fsaf 41 | from models.retinanet import fsaf_bbox 42 | fsaf = resnet_fsaf(num_classes=20, backbone='resnet101') 43 | model = fsaf_bbox(fsaf) 44 | model.load_weights(model_path, by_name=True) 45 | # load label to names mapping for visualization purposes 46 | voc_classes = { 47 | 'aeroplane': 0, 48 | 'bicycle': 1, 49 | 'bird': 2, 50 | 'boat': 3, 51 | 'bottle': 4, 52 | 'bus': 5, 53 | 'car': 6, 54 | 'cat': 7, 55 | 'chair': 8, 56 | 'cow': 9, 57 | 'diningtable': 10, 58 | 'dog': 11, 59 | 'horse': 12, 60 | 'motorbike': 13, 61 | 'person': 14, 62 | 'pottedplant': 15, 63 | 'sheep': 16, 64 | 'sofa': 17, 65 | 'train': 18, 66 | 'tvmonitor': 19 67 | } 68 | labels_to_names = {} 69 | for key, value in voc_classes.items(): 70 | labels_to_names[value] = key 71 | # load image 72 | image_paths = glob.glob('datasets/voc_test/VOC2007/JPEGImages/*.jpg') 73 | for image_path in image_paths: 74 | print('Handling {}'.format(image_path)) 75 | image = read_image_bgr(image_path) 76 | 77 | # copy to draw on 78 | draw = image.copy() 79 | 80 | # preprocess image for network 81 | image = preprocess_image(image) 82 | image, scale = resize_image(image) 83 | 84 | # process image 85 | start = time.time() 86 | # locations, feature_shapes = model.predict_on_batch(np.expand_dims(image, axis=0)) 87 | boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0)) 88 | print("processing time: ", time.time() - start) 89 | 90 | # correct for image scale 91 | boxes /= scale 92 | labels_to_locations = {} 93 | # visualize detections 94 | for box, score, label in zip(boxes[0], scores[0], labels[0]): 95 | # scores are sorted so we can break 96 | if score < 0.5: 97 | break 98 | start_x = int(box[0]) 99 | start_y = int(box[1]) 100 | end_x = int(box[2]) 101 | end_y = int(box[3]) 102 | color = label_color(label) 103 | 104 | b = box.astype(int) 105 | draw_box(draw, b, color=color) 106 | 107 | caption = "{} {:.3f}".format(labels_to_names[label], score) 108 | draw_caption(draw, b, caption) 109 | 110 | cv2.namedWindow('image', cv2.WINDOW_NORMAL) 111 | cv2.imshow('image', draw) 112 | key = cv2.waitKey(0) 113 | if int(key) == 121: 114 | image_fname = osp.split(image_path)[-1] 115 | cv2.imwrite('test/{}'.format(image_fname), draw) 116 | 117 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 cgratie (https://github.com/cgratie/) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | import keras 19 | from keras.utils import get_file 20 | 21 | from . import retinanet 22 | from . import Backbone 23 | from ..utils.image import preprocess_image 24 | 25 | 26 | class VGGBackbone(Backbone): 27 | """ Describes backbone information and provides utility functions. 
28 | """ 29 | 30 | def retinanet(self, *args, **kwargs): 31 | """ Returns a retinanet model using the correct backbone. 32 | """ 33 | return vgg_retinanet(*args, backbone=self.backbone, **kwargs) 34 | 35 | def download_imagenet(self): 36 | """ Downloads ImageNet weights and returns path to weights file. 37 | Weights can be downloaded at https://github.com/fizyr/keras-models/releases . 38 | """ 39 | if self.backbone == 'vgg16': 40 | resource = keras.applications.vgg16.vgg16.WEIGHTS_PATH_NO_TOP 41 | checksum = '6d6bbae143d832006294945121d1f1fc' 42 | elif self.backbone == 'vgg19': 43 | resource = keras.applications.vgg19.vgg19.WEIGHTS_PATH_NO_TOP 44 | checksum = '253f8cb515780f3b799900260a226db6' 45 | else: 46 | raise ValueError("Backbone '{}' not recognized.".format(self.backbone)) 47 | 48 | return get_file( 49 | '{}_weights_tf_dim_ordering_tf_kernels_notop.h5'.format(self.backbone), 50 | resource, 51 | cache_subdir='models', 52 | file_hash=checksum 53 | ) 54 | 55 | def validate(self): 56 | """ Checks whether the backbone string is correct. 57 | """ 58 | allowed_backbones = ['vgg16', 'vgg19'] 59 | 60 | if self.backbone not in allowed_backbones: 61 | raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(self.backbone, allowed_backbones)) 62 | 63 | def preprocess_image(self, inputs): 64 | """ Takes as input an image and prepares it for being passed through the network. 65 | """ 66 | return preprocess_image(inputs, mode='caffe') 67 | 68 | 69 | def vgg_retinanet(num_classes, backbone='vgg16', inputs=None, modifier=None, **kwargs): 70 | """ Constructs a retinanet model using a vgg backbone. 71 | 72 | Args 73 | num_classes: Number of classes to predict. 74 | backbone: Which backbone to use (one of ('vgg16', 'vgg19')). 75 | inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)). 76 | modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example). 77 | 78 | Returns 79 | RetinaNet model with a VGG backbone. 
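    Example (a minimal sketch; the returned model still needs trained or pretrained weights):
        model = vgg_retinanet(num_classes=20, backbone='vgg16')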
80 | """ 81 | # choose default input 82 | if inputs is None: 83 | inputs = keras.layers.Input(shape=(None, None, 3)) 84 | 85 | # create the vgg backbone 86 | if backbone == 'vgg16': 87 | vgg = keras.applications.VGG16(input_tensor=inputs, include_top=False, weights=None) 88 | elif backbone == 'vgg19': 89 | vgg = keras.applications.VGG19(input_tensor=inputs, include_top=False, weights=None) 90 | else: 91 | raise ValueError("Backbone '{}' not recognized.".format(backbone)) 92 | 93 | if modifier: 94 | vgg = modifier(vgg) 95 | 96 | # create the full model 97 | layer_names = ["block3_pool", "block4_pool", "block5_pool"] 98 | layer_outputs = [vgg.get_layer(name).output for name in layer_names] 99 | return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=layer_outputs, **kwargs) 100 | -------------------------------------------------------------------------------- /yolo/inference.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import glob 3 | import keras 4 | import numpy as np 5 | import os 6 | import os.path as osp 7 | import tensorflow as tf 8 | import time 9 | 10 | from utils.visualization import draw_box, draw_caption 11 | from utils.colors import label_color 12 | from yolo.model import yolo_body 13 | 14 | 15 | # set tf backend to allow memory to grow, instead of claiming everything 16 | def get_session(): 17 | config = tf.ConfigProto() 18 | config.gpu_options.allow_growth = True 19 | return tf.Session(config=config) 20 | 21 | 22 | def preprocess_image(image, image_size=416): 23 | image_height, image_width = image.shape[:2] 24 | if image_height > image_width: 25 | scale = image_size / image_height 26 | resized_height = image_size 27 | resized_width = int(image_width * scale) 28 | else: 29 | scale = image_size / image_width 30 | resized_height = int(image_height * scale) 31 | resized_width = image_size 32 | image = cv2.resize(image, (resized_width, resized_height)) 33 | new_image = np.ones((image_size, image_size, 3), dtype=np.float32) * 128. 34 | offset_h = (image_size - resized_height) // 2 35 | offset_w = (image_size - resized_width) // 2 36 | new_image[offset_h:offset_h + resized_height, offset_w:offset_w + resized_width] = image.astype(np.float32) 37 | new_image /= 255. 
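# Worked example of the letterbox math above (a sketch, not part of the original file):
# for a 500x375 (width x height) image and image_size=416, width is the longer side, so
# scale = 416 / 500 = 0.832, the image is resized to 416x312 and padded with gray (128)
# to 416x416, giving offset_h = (416 - 312) // 2 = 52 and offset_w = 0.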
38 | return new_image, scale, offset_h, offset_w 39 | 40 | 41 | # use this environment flag to change which GPU to use 42 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 43 | 44 | # set the modified tf session as backend in keras 45 | keras.backend.set_session(get_session()) 46 | 47 | model_path = 'pascal_18_6.4112_6.5125_0.8319_0.8358.h5' 48 | 49 | model, prediction_model = yolo_body(num_classes=20) 50 | 51 | prediction_model.load_weights(model_path, by_name=True) 52 | 53 | # load label to names mapping for visualization purposes 54 | voc_classes = { 55 | 'aeroplane': 0, 56 | 'bicycle': 1, 57 | 'bird': 2, 58 | 'boat': 3, 59 | 'bottle': 4, 60 | 'bus': 5, 61 | 'car': 6, 62 | 'cat': 7, 63 | 'chair': 8, 64 | 'cow': 9, 65 | 'diningtable': 10, 66 | 'dog': 11, 67 | 'horse': 12, 68 | 'motorbike': 13, 69 | 'person': 14, 70 | 'pottedplant': 15, 71 | 'sheep': 16, 72 | 'sofa': 17, 73 | 'train': 18, 74 | 'tvmonitor': 19 75 | } 76 | labels_to_names = {} 77 | for key, value in voc_classes.items(): 78 | labels_to_names[value] = key 79 | # load image 80 | image_paths = glob.glob('datasets/voc_test/VOC2007/JPEGImages/*.jpg') 81 | for image_path in image_paths: 82 | print('Handling {}'.format(image_path)) 83 | image = cv2.imread(image_path) 84 | 85 | # copy to draw on 86 | draw = image.copy() 87 | 88 | # preprocess image for network 89 | image, scale, offset_h, offset_w = preprocess_image(image) 90 | 91 | # process image 92 | start = time.time() 93 | # locations, feature_shapes = model.predict_on_batch(np.expand_dims(image, axis=0)) 94 | boxes, scores, labels = prediction_model.predict_on_batch(np.expand_dims(image, axis=0)) 95 | print("processing time: ", time.time() - start) 96 | 97 | # correct boxes for image scale 98 | boxes[0, :, [0, 2]] -= offset_w 99 | boxes[0, :, [1, 3]] -= offset_h 100 | boxes /= scale 101 | 102 | labels_to_locations = {} 103 | # visualize detections 104 | for box, score, label in zip(boxes[0], scores[0], labels[0]): 105 | # scores are sorted so we can break 106 | if score < 0.5: 107 | break 108 | start_x = int(box[0]) 109 | start_y = int(box[1]) 110 | end_x = int(box[2]) 111 | end_y = int(box[3]) 112 | color = label_color(label) 113 | 114 | b = box.astype(int) 115 | draw_box(draw, b, color=color) 116 | 117 | caption = "{} {:.3f}".format(labels_to_names[label], score) 118 | draw_caption(draw, b, caption) 119 | 120 | cv2.namedWindow('image', cv2.WINDOW_NORMAL) 121 | cv2.imshow('image', draw) 122 | key = cv2.waitKey(0) 123 | if int(key) == 121: 124 | image_fname = osp.split(image_path)[-1] 125 | cv2.imwrite('test/{}'.format(image_fname), draw) 126 | -------------------------------------------------------------------------------- /yolo/eval/pascal.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | import keras 18 | from .common import evaluate 19 | 20 | 21 | class Evaluate(keras.callbacks.Callback): 22 | """ 23 | Evaluation callback for arbitrary datasets. 24 | """ 25 | 26 | def __init__( 27 | self, 28 | generator, 29 | model, 30 | iou_threshold=0.5, 31 | score_threshold=0.01, 32 | max_detections=100, 33 | save_path=None, 34 | tensorboard=None, 35 | weighted_average=False, 36 | verbose=1 37 | ): 38 | """ 39 | Evaluate a given dataset using a given model at the end of every epoch during training. 40 | 41 | Args: 42 | generator: The generator that represents the dataset to evaluate. 43 | iou_threshold: The threshold used to consider when a detection is positive or negative. 44 | score_threshold: The score confidence threshold to use for detections. 45 | max_detections: The maximum number of detections to use per image. 46 | save_path: The path to save images with visualized detections to. 47 | tensorboard: Instance of keras.callbacks.TensorBoard used to log the mAP value. 48 | weighted_average: Compute the mAP using the weighted average of precisions among classes. 49 | verbose: Set the verbosity level, by default this is set to 1. 50 | """ 51 | self.generator = generator 52 | self.iou_threshold = iou_threshold 53 | self.score_threshold = score_threshold 54 | self.max_detections = max_detections 55 | self.save_path = save_path 56 | self.tensorboard = tensorboard 57 | self.weighted_average = weighted_average 58 | self.verbose = verbose 59 | self.active_model = model 60 | 61 | super(Evaluate, self).__init__() 62 | 63 | def on_epoch_end(self, epoch, logs=None): 64 | logs = logs or {} 65 | 66 | # run evaluation 67 | average_precisions = evaluate( 68 | self.generator, 69 | self.active_model, 70 | iou_threshold=self.iou_threshold, 71 | score_threshold=self.score_threshold, 72 | max_detections=self.max_detections, 73 | visualize=False 74 | ) 75 | 76 | # compute per class average precision 77 | total_instances = [] 78 | precisions = [] 79 | for label, (average_precision, num_annotations) in average_precisions.items(): 80 | if self.verbose == 1: 81 | print('{:.0f} instances of class'.format(num_annotations), 82 | self.generator.label_to_name(label), 'with average precision: {:.4f}'.format(average_precision)) 83 | total_instances.append(num_annotations) 84 | precisions.append(average_precision) 85 | if self.weighted_average: 86 | self.mean_ap = sum([a * b for a, b in zip(total_instances, precisions)]) / sum(total_instances) 87 | else: 88 | self.mean_ap = sum(precisions) / sum(x > 0 for x in total_instances) 89 | 90 | if self.tensorboard is not None and self.tensorboard.writer is not None: 91 | import tensorflow as tf 92 | summary = tf.Summary() 93 | summary_value = summary.value.add() 94 | summary_value.simple_value = self.mean_ap 95 | summary_value.tag = "mAP" 96 | self.tensorboard.writer.add_summary(summary, epoch) 97 | 98 | logs['mAP'] = self.mean_ap 99 | 100 | if self.verbose == 1: 101 | print('mAP: {:.4f}'.format(self.mean_ap)) 102 | -------------------------------------------------------------------------------- /models/densenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2018 vidosits (https://github.com/vidosits/) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import keras 18 | from keras.applications import densenet 19 | from keras.utils import get_file 20 | 21 | from . import retinanet 22 | from . import Backbone 23 | from ..utils.image import preprocess_image 24 | 25 | 26 | allowed_backbones = { 27 | 'densenet121': ([6, 12, 24, 16], densenet.DenseNet121), 28 | 'densenet169': ([6, 12, 32, 32], densenet.DenseNet169), 29 | 'densenet201': ([6, 12, 48, 32], densenet.DenseNet201), 30 | } 31 | 32 | 33 | class DenseNetBackbone(Backbone): 34 | """ Describes backbone information and provides utility functions. 35 | """ 36 | 37 | def retinanet(self, *args, **kwargs): 38 | """ Returns a retinanet model using the correct backbone. 39 | """ 40 | return densenet_retinanet(*args, backbone=self.backbone, **kwargs) 41 | 42 | def download_imagenet(self): 43 | """ Download pre-trained weights for the specified backbone name. 44 | This name is in the format {backbone}_weights_tf_dim_ordering_tf_kernels_notop 45 | where backbone is the densenet + number of layers (e.g. densenet121). 46 | For more info check the explanation from the keras densenet script itself: 47 | https://github.com/keras-team/keras/blob/master/keras/applications/densenet.py 48 | """ 49 | origin = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/' 50 | file_name = '{}_weights_tf_dim_ordering_tf_kernels_notop.h5' 51 | 52 | # load weights 53 | if keras.backend.image_data_format() == 'channels_first': 54 | raise ValueError('Weights for "channels_first" format are not available.') 55 | 56 | weights_url = origin + file_name.format(self.backbone) 57 | return get_file(file_name.format(self.backbone), weights_url, cache_subdir='models') 58 | 59 | def validate(self): 60 | """ Checks whether the backbone string is correct. 61 | """ 62 | backbone = self.backbone.split('_')[0] 63 | 64 | if backbone not in allowed_backbones: 65 | raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, allowed_backbones.keys())) 66 | 67 | def preprocess_image(self, inputs): 68 | """ Takes as input an image and prepares it for being passed through the network. 69 | """ 70 | return preprocess_image(inputs, mode='tf') 71 | 72 | 73 | def densenet_retinanet(num_classes, backbone='densenet121', inputs=None, modifier=None, **kwargs): 74 | """ Constructs a retinanet model using a densenet backbone. 75 | 76 | Args 77 | num_classes: Number of classes to predict. 78 | backbone: Which backbone to use (one of ('densenet121', 'densenet169', 'densenet201')). 79 | inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)). 80 | modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example). 81 | 82 | Returns 83 | RetinaNet model with a DenseNet backbone. 
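    Example (a minimal sketch, mirroring the vgg variant):
        model = densenet_retinanet(num_classes=20, backbone='densenet121')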
84 | """ 85 | # choose default input 86 | if inputs is None: 87 | inputs = keras.layers.Input((None, None, 3)) 88 | 89 | blocks, creator = allowed_backbones[backbone] 90 | model = creator(input_tensor=inputs, include_top=False, pooling=None, weights=None) 91 | 92 | # get last conv layer from the end of each dense block 93 | layer_outputs = [model.get_layer(name='conv{}_block{}_concat'.format(idx + 2, block_num)).output for idx, block_num in enumerate(blocks)] 94 | 95 | # create the densenet backbone 96 | model = keras.models.Model(inputs=inputs, outputs=layer_outputs[1:], name=model.name) 97 | 98 | # invoke modifier if given 99 | if modifier: 100 | model = modifier(model) 101 | 102 | # create the full model 103 | model = retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=model.outputs, **kwargs) 104 | 105 | return model 106 | -------------------------------------------------------------------------------- /models/mobilenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 lvaleriu (https://github.com/lvaleriu/) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import keras 18 | from keras.applications import mobilenet 19 | from keras.utils import get_file 20 | from ..utils.image import preprocess_image 21 | 22 | from . import retinanet 23 | from . import Backbone 24 | 25 | 26 | class MobileNetBackbone(Backbone): 27 | """ Describes backbone information and provides utility functions. 28 | """ 29 | 30 | allowed_backbones = ['mobilenet128', 'mobilenet160', 'mobilenet192', 'mobilenet224'] 31 | 32 | def retinanet(self, *args, **kwargs): 33 | """ Returns a retinanet model using the correct backbone. 34 | """ 35 | return mobilenet_retinanet(*args, backbone=self.backbone, **kwargs) 36 | 37 | def download_imagenet(self): 38 | """ Download pre-trained weights for the specified backbone name. 39 | This name is in the format mobilenet{rows}_{alpha} where rows is the 40 | imagenet shape dimension and 'alpha' controls the width of the network. 41 | For more info check the explanation from the keras mobilenet script itself. 42 | """ 43 | 44 | alpha = float(self.backbone.split('_')[1]) 45 | rows = int(self.backbone.split('_')[0].replace('mobilenet', '')) 46 | 47 | # load weights 48 | if keras.backend.image_data_format() == 'channels_first': 49 | raise ValueError('Weights for "channels_first" format ' 50 | 'are not available.') 51 | if alpha == 1.0: 52 | alpha_text = '1_0' 53 | elif alpha == 0.75: 54 | alpha_text = '7_5' 55 | elif alpha == 0.50: 56 | alpha_text = '5_0' 57 | else: 58 | alpha_text = '2_5' 59 | 60 | model_name = 'mobilenet_{}_{}_tf_no_top.h5'.format(alpha_text, rows) 61 | weights_url = mobilenet.mobilenet.BASE_WEIGHT_PATH + model_name 62 | weights_path = get_file(model_name, weights_url, cache_subdir='models') 63 | 64 | return weights_path 65 | 66 | def validate(self): 67 | """ Checks whether the backbone string is correct.
68 | """ 69 | backbone = self.backbone.split('_')[0] 70 | 71 | if backbone not in MobileNetBackbone.allowed_backbones: 72 | raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, MobileNetBackbone.allowed_backbones)) 73 | 74 | def preprocess_image(self, inputs): 75 | """ Takes as input an image and prepares it for being passed through the network. 76 | """ 77 | return preprocess_image(inputs, mode='tf') 78 | 79 | 80 | def mobilenet_retinanet(num_classes, backbone='mobilenet224_1.0', inputs=None, modifier=None, **kwargs): 81 | """ Constructs a retinanet model using a mobilenet backbone. 82 | 83 | Args 84 | num_classes: Number of classes to predict. 85 | backbone: Which backbone to use (one of ('mobilenet128', 'mobilenet160', 'mobilenet192', 'mobilenet224')). 86 | inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)). 87 | modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example). 88 | 89 | Returns 90 | RetinaNet model with a MobileNet backbone. 91 | """ 92 | alpha = float(backbone.split('_')[1]) 93 | 94 | # choose default input 95 | if inputs is None: 96 | inputs = keras.layers.Input((None, None, 3)) 97 | 98 | backbone = mobilenet.MobileNet(input_tensor=inputs, alpha=alpha, include_top=False, pooling=None, weights=None) 99 | 100 | # create the full model 101 | layer_names = ['conv_pw_5_relu', 'conv_pw_11_relu', 'conv_pw_13_relu'] 102 | layer_outputs = [backbone.get_layer(name).output for name in layer_names] 103 | backbone = keras.models.Model(inputs=inputs, outputs=layer_outputs, name=backbone.name) 104 | 105 | # invoke modifier if given 106 | if modifier: 107 | backbone = modifier(backbone) 108 | 109 | return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=backbone.outputs, **kwargs) 110 | -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import cv2 18 | import numpy as np 19 | 20 | from .colors import label_color 21 | 22 | 23 | def draw_box(image, box, color, thickness=2): 24 | """ Draws a box on an image with a given color. 25 | 26 | # Arguments 27 | image : The image to draw on. 28 | box : A list of 4 elements (x1, y1, x2, y2). 29 | color : The color of the box. 30 | thickness : The thickness of the lines to draw a box with. 31 | """ 32 | b = np.array(box).astype(int) 33 | cv2.rectangle(image, (b[0], b[1]), (b[2], b[3]), color, thickness, cv2.LINE_AA) 34 | 35 | 36 | def draw_caption(image, box, caption): 37 | """ Draws a caption above the box in an image. 38 | 39 | # Arguments 40 | image : The image to draw on. 41 | box : A list of 4 elements (x1, y1, x2, y2). 42 | caption : String containing the text to draw. 
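    Example (a sketch with hypothetical box coordinates):
        draw_caption(image, [24, 36, 180, 220], 'person 0.98')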
43 | """ 44 | b = np.array(box).astype(int) 45 | ret, baseline = cv2.getTextSize(caption, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1) 46 | cv2.rectangle(image, (b[0], b[3] - ret[1] - baseline), (b[0] + ret[0], b[3]), (255, 255, 255), -1) 47 | cv2.putText(image, caption, (b[0], b[3] - baseline), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1) 48 | 49 | 50 | def draw_boxes(image, boxes, color, thickness=2): 51 | """ Draws boxes on an image with a given color. 52 | 53 | # Arguments 54 | image : The image to draw on. 55 | boxes : A [N, 4] matrix (x1, y1, x2, y2). 56 | color : The color of the boxes. 57 | thickness : The thickness of the lines to draw boxes with. 58 | """ 59 | for b in boxes: 60 | draw_box(image, b, color, thickness=thickness) 61 | 62 | 63 | def draw_detections(image, boxes, scores, labels, color=None, label_to_name=None, score_threshold=0.5): 64 | """ Draws detections in an image. 65 | 66 | # Arguments 67 | image : The image to draw on. 68 | boxes : A [N, 4] matrix (x1, y1, x2, y2). 69 | scores : A list of N classification scores. 70 | labels : A list of N labels. 71 | color : The color of the boxes. By default the color from keras_retinanet.utils.colors.label_color will be used. 72 | label_to_name : (optional) Functor for mapping a label to a name. 73 | score_threshold : Threshold used for determining what detections to draw. 74 | """ 75 | selection = np.where(scores > score_threshold)[0] 76 | 77 | for i in selection: 78 | c = color if color is not None else label_color(labels[i]) 79 | draw_box(image, boxes[i, :], color=c) 80 | 81 | # draw labels 82 | caption = (label_to_name(labels[i]) if label_to_name else labels[i]) + ': {0:.2f}'.format(scores[i]) 83 | draw_caption(image, boxes[i, :], caption) 84 | 85 | 86 | def draw_annotations(image, annotations, color=(0, 255, 0), label_to_name=None): 87 | """ Draws annotations in an image. 88 | 89 | # Arguments 90 | image : The image to draw on. 91 | annotations : A [N, 5] matrix (x1, y1, x2, y2, label) or dictionary containing bboxes (shaped [N, 4]) and labels (shaped [N]). 92 | color : The color of the boxes. By default the color from keras_retinanet.utils.colors.label_color will be used. 93 | label_to_name : (optional) Functor for mapping a label to a name. 94 | """ 95 | if isinstance(annotations, np.ndarray): 96 | annotations = {'bboxes': annotations[:, :4], 'labels': annotations[:, 4]} 97 | 98 | assert('bboxes' in annotations) 99 | assert('labels' in annotations) 100 | assert(annotations['bboxes'].shape[0] == annotations['labels'].shape[0]) 101 | 102 | for i in range(annotations['bboxes'].shape[0]): 103 | label = annotations['labels'][i] 104 | c = color if color is not None else label_color(label) 105 | caption = '{}'.format(label_to_name(label) if label_to_name else label) 106 | draw_caption(image, annotations['bboxes'][i], caption) 107 | draw_box(image, annotations['bboxes'][i], color=c) 108 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import layers 4 | import losses 5 | import initializers 6 | from fsaf_layers import RegressBoxes, Locations, LevelSelect, FSAFTarget 7 | import keras 8 | import tensorflow as tf 9 | 10 | 11 | class Backbone(object): 12 | """ This class stores additional information on backbones. 
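    Concrete subclasses (e.g. ResNetBackbone in models/resnet.py) implement the
    retinanet(), download_imagenet(), validate() and preprocess_image() methods below.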
13 | """ 14 | 15 | def __init__(self, backbone): 16 | # a dictionary mapping custom layer names to the correct classes 17 | self.custom_objects = { 18 | 'UpsampleLike': layers.UpsampleLike, 19 | 'PriorProbability': initializers.PriorProbability, 20 | 'RegressBoxes': RegressBoxes, 21 | 'FilterDetections': layers.FilterDetections, 22 | 'Anchors': layers.Anchors, 23 | 'ClipBoxes': layers.ClipBoxes, 24 | 'cls_loss': losses.focal_with_mask(), 25 | 'regr_loss': losses.iou_with_mask(), 26 | 'Locations': Locations, 27 | 'LevelSelect': LevelSelect, 28 | 'FSAFTarget': FSAFTarget, 29 | 'keras': keras, 30 | 'tf': tf, 31 | 'backend': tf, 32 | '': lambda y_true, y_pred: y_pred 33 | } 34 | 35 | self.backbone = backbone 36 | self.validate() 37 | 38 | def retinanet(self, *args, **kwargs): 39 | """ 40 | Returns a retinanet model using the correct backbone. 41 | """ 42 | raise NotImplementedError('retinanet method not implemented.') 43 | 44 | def download_imagenet(self): 45 | """ 46 | Downloads ImageNet weights and returns path to weights file. 47 | """ 48 | raise NotImplementedError('download_imagenet method not implemented.') 49 | 50 | def validate(self): 51 | """ 52 | Checks whether the backbone string is correct. 53 | """ 54 | raise NotImplementedError('validate method not implemented.') 55 | 56 | def preprocess_image(self, inputs): 57 | """ 58 | Takes as input an image and prepares it for being passed through the network. 59 | Having this function in Backbone allows other backbones to define a specific preprocessing step. 60 | """ 61 | raise NotImplementedError('preprocess_image method not implemented.') 62 | 63 | 64 | def backbone(backbone_name): 65 | """ Returns a backbone object for the given backbone. 66 | """ 67 | if 'resnet' in backbone_name: 68 | from .resnet import ResNetBackbone as b 69 | elif 'mobilenet' in backbone_name: 70 | from .mobilenet import MobileNetBackbone as b 71 | elif 'vgg' in backbone_name: 72 | from .vgg import VGGBackbone as b 73 | elif 'densenet' in backbone_name: 74 | from .densenet import DenseNetBackbone as b 75 | else: 76 | raise NotImplementedError('Backbone class for \'{}\' not implemented.'.format(backbone)) 77 | 78 | return b(backbone_name) 79 | 80 | 81 | def load_model(filepath, backbone_name='resnet50'): 82 | """ Loads a retinanet model using the correct custom objects. 83 | 84 | Args 85 | filepath: one of the following: 86 | - string, path to the saved model, or 87 | - h5py.File object from which to load the model 88 | backbone_name : Backbone with which the model was trained. 89 | 90 | Returns 91 | A keras.models.Model object. 92 | 93 | Raises 94 | ImportError: if h5py is not available. 95 | ValueError: In case of an invalid savefile. 96 | """ 97 | import keras.models 98 | return keras.models.load_model(filepath, custom_objects=backbone(backbone_name).custom_objects) 99 | 100 | 101 | def convert_model(model, nms=True, class_specific_filter=True): 102 | """ Converts a training model to an inference model. 103 | 104 | Args 105 | model : A retinanet training model. 106 | nms : Boolean, whether to add NMS filtering to the converted model. 107 | class_specific_filter : Whether to use class specific filtering or filter for the best scoring class only. 108 | anchor_params : Anchor parameters object. If omitted, default values are used. 109 | 110 | Returns 111 | A keras.models.Model object. 112 | 113 | Raises 114 | ImportError: if h5py is not available. 115 | ValueError: In case of an invalid savefile. 
116 | """ 117 | from .retinanet import fsaf_bbox 118 | return fsaf_bbox(model=model, nms=nms, class_specific_filter=class_specific_filter) 119 | 120 | 121 | def assert_training_model(model): 122 | """ 123 | Assert that the model is a training model. 124 | """ 125 | assert (all(output in model.output_names for output in 126 | ['cls_loss', 'regr_loss', 'fsaf_regression', 'fsaf_classification'])), \ 127 | "Input is not a training model (no 'regression' and 'classification' outputs were found, outputs are: {}).".format( 128 | model.output_names) 129 | 130 | 131 | def check_training_model(model): 132 | """ 133 | Check that model is a training model and exit otherwise. 134 | """ 135 | try: 136 | assert_training_model(model) 137 | except AssertionError as e: 138 | print(e, file=sys.stderr) 139 | sys.exit(1) 140 | -------------------------------------------------------------------------------- /augmentor/misc.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from augmentor.transform import translation_xy, change_transform_origin 4 | 5 | ROTATE_DEGREE = [90, 180, 270] 6 | 7 | 8 | def rotate(image, boxes, prob=0.5): 9 | random_prob = np.random.uniform() 10 | if random_prob < prob: 11 | return image, boxes 12 | rotate_degree = ROTATE_DEGREE[np.random.randint(0, 3)] 13 | h, w = image.shape[:2] 14 | # Compute the rotation matrix. 15 | M = cv2.getRotationMatrix2D(center=(w / 2, h / 2), 16 | angle=rotate_degree, 17 | scale=1) 18 | 19 | # Get the sine and cosine from the rotation matrix. 20 | abs_cos_angle = np.abs(M[0, 0]) 21 | abs_sin_angle = np.abs(M[0, 1]) 22 | 23 | # Compute the new bounding dimensions of the image. 24 | new_w = int(h * abs_sin_angle + w * abs_cos_angle) 25 | new_h = int(h * abs_cos_angle + w * abs_sin_angle) 26 | 27 | # Adjust the rotation matrix to take into account the translation. 28 | M[0, 2] += new_w // 2 - w // 2 29 | M[1, 2] += new_h // 2 - h // 2 30 | 31 | # Rotate the image. 32 | image = cv2.warpAffine(image, M=M, dsize=(new_w, new_h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT, 33 | borderValue=(128, 128, 128)) 34 | 35 | new_boxes = [] 36 | for box in boxes: 37 | x1, y1, x2, y2 = box 38 | points = M.dot([ 39 | [x1, x2, x1, x2], 40 | [y1, y2, y2, y1], 41 | [1, 1, 1, 1], 42 | ]) 43 | 44 | # Extract the min and max corners again. 
45 | min_xy = np.sort(points, axis=1)[:, :2] 46 | min_x = np.mean(min_xy[0]) 47 | min_y = np.mean(min_xy[1]) 48 | max_xy = np.sort(points, axis=1)[:, 2:] 49 | max_x = np.mean(max_xy[0]) 50 | max_y = np.mean(max_xy[1]) 51 | 52 | new_boxes.append([min_x, min_y, max_x, max_y]) 53 | boxes = np.array(new_boxes) 54 | return image, boxes 55 | 56 | 57 | def crop(image, boxes, prob=0.5): 58 | random_prob = np.random.uniform() 59 | if random_prob < prob: 60 | return image, boxes 61 | h, w = image.shape[:2] 62 | min_x1, min_y1 = np.min(boxes, axis=0)[:2] 63 | max_x2, max_y2 = np.max(boxes, axis=0)[2:] 64 | random_x1 = np.random.randint(0, max(min_x1 // 2, 1)) 65 | random_y1 = np.random.randint(0, max(min_y1 // 2, 1)) 66 | random_x2 = np.random.randint(max_x2, max(min(w, max_x2 + (w - max_x2) // 2), max_x2 + 1)) 67 | random_y2 = np.random.randint(max_y2, max(min(h, max_y2 + (h - max_y2) // 2), max_y2 + 1)) 68 | image = image[random_y1:random_y2, random_x1:random_x2] 69 | boxes[:, [0, 2]] = boxes[:, [0, 2]] - random_x1 70 | boxes[:, [1, 3]] = boxes[:, [1, 3]] - random_y1 71 | return image, boxes 72 | 73 | 74 | def translate(image, boxes, prob=0.5): 75 | random_prob = np.random.uniform() 76 | if random_prob < prob: 77 | return image, boxes 78 | h, w = image.shape[:2] 79 | min_x1, min_y1 = np.min(boxes, axis=0)[:2] 80 | max_x2, max_y2 = np.max(boxes, axis=0)[2:] 81 | translation_matrix = translation_xy(min=(min(-min_x1 // 2, 0), min(-min_y1 // 2, 0)), 82 | max=(max((w - max_x2) // 2, 1), max((h - max_y2) // 2, 1)), prob=1.) 83 | translation_matrix = change_transform_origin(translation_matrix, (w / 2, h / 2)) 84 | image = cv2.warpAffine( 85 | image, 86 | translation_matrix[:2, :], 87 | dsize=(w, h), 88 | flags=cv2.INTER_CUBIC, 89 | borderMode=cv2.BORDER_CONSTANT, 90 | borderValue=(128, 128, 128), 91 | ) 92 | new_boxes = [] 93 | for box in boxes: 94 | x1, y1, x2, y2 = box 95 | points = translation_matrix.dot([ 96 | [x1, x2, x1, x2], 97 | [y1, y2, y2, y1], 98 | [1, 1, 1, 1], 99 | ]) 100 | min_x, min_y = np.min(points, axis=1)[:2] 101 | max_x, max_y = np.max(points, axis=1)[:2] 102 | new_boxes.append([min_x, min_y, max_x, max_y]) 103 | boxes = np.array(new_boxes) 104 | return image, boxes 105 | 106 | 107 | class MiscEffect: 108 | def __init__(self, rotate_prob=0.9, crop_prob=0.5, translate_prob=0.5): 109 | self.rotate_prob = rotate_prob 110 | self.crop_prob = crop_prob 111 | self.translate_prob = translate_prob 112 | 113 | def __call__(self, image, boxes): 114 | image, boxes = rotate(image, boxes, prob=self.rotate_prob) 115 | image, boxes = crop(image, boxes, prob=self.crop_prob) 116 | image, boxes = translate(image, boxes, prob=self.translate_prob) 117 | return image, boxes 118 | 119 | 120 | if __name__ == '__main__': 121 | from yolo.generators.pascal import PascalVocGenerator 122 | 123 | train_generator = PascalVocGenerator( 124 | 'datasets/VOC0712', 125 | 'trainval', 126 | skip_difficult=True, 127 | anchors_path='voc_anchors_416.txt', 128 | batch_size=1 129 | ) 130 | misc_effect = MiscEffect() 131 | for i in range(train_generator.size()): 132 | image = train_generator.load_image(i) 133 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 134 | annotations = train_generator.load_annotations(i) 135 | boxes = annotations['bboxes'] 136 | for box in boxes.astype(np.int32): 137 | cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]), (0, 0, 255), 2) 138 | src_image = image.copy() 139 | cv2.namedWindow('src_image', cv2.WINDOW_NORMAL) 140 | cv2.imshow('src_image', src_image) 141 | image, boxes = misc_effect(image, 
boxes) 142 | for box in boxes.astype(np.int32): 143 | cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 1) 144 | cv2.namedWindow('image', cv2.WINDOW_NORMAL) 145 | cv2.imshow('image', image) 146 | cv2.waitKey(0) 147 | -------------------------------------------------------------------------------- /generators/coco_generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from .generator import Generator 18 | from ..utils.image import read_image_bgr 19 | 20 | import os 21 | import numpy as np 22 | 23 | from pycocotools.coco import COCO 24 | 25 | 26 | class CocoGenerator(Generator): 27 | """ Generate data from the COCO dataset. 28 | 29 | See https://github.com/cocodataset/cocoapi/tree/master/PythonAPI for more information. 30 | """ 31 | 32 | def __init__(self, data_dir, set_name, **kwargs): 33 | """ Initialize a COCO data generator. 34 | 35 | Args 36 | data_dir: Path to where the COCO dataset is stored. 37 | set_name: Name of the set to parse. 38 | """ 39 | self.data_dir = data_dir 40 | self.set_name = set_name 41 | self.coco = COCO(os.path.join(data_dir, 'annotations', 'instances_' + set_name + '.json')) 42 | self.image_ids = self.coco.getImgIds() 43 | 44 | self.load_classes() 45 | 46 | super(CocoGenerator, self).__init__(**kwargs) 47 | 48 | def load_classes(self): 49 | """ Loads the class to label mapping (and inverse) for COCO. 50 | """ 51 | # load class names (name -> label) 52 | categories = self.coco.loadCats(self.coco.getCatIds()) 53 | categories.sort(key=lambda x: x['id']) 54 | 55 | self.classes = {} 56 | self.coco_labels = {} 57 | self.coco_labels_inverse = {} 58 | for c in categories: 59 | self.coco_labels[len(self.classes)] = c['id'] 60 | self.coco_labels_inverse[c['id']] = len(self.classes) 61 | self.classes[c['name']] = len(self.classes) 62 | 63 | # also load the reverse (label -> name) 64 | self.labels = {} 65 | for key, value in self.classes.items(): 66 | self.labels[value] = key 67 | 68 | def size(self): 69 | """ Size of the COCO dataset. 70 | """ 71 | return len(self.image_ids) 72 | 73 | def num_classes(self): 74 | """ Number of classes in the dataset. For COCO this is 80. 75 | """ 76 | return len(self.classes) 77 | 78 | def has_label(self, label): 79 | """ Return True if label is a known label. 80 | """ 81 | return label in self.labels 82 | 83 | def has_name(self, name): 84 | """ Returns True if name is a known class. 85 | """ 86 | return name in self.classes 87 | 88 | def name_to_label(self, name): 89 | """ Map name to label. 90 | """ 91 | return self.classes[name] 92 | 93 | def label_to_name(self, label): 94 | """ Map label to name. 95 | """ 96 | return self.labels[label] 97 | 98 | def coco_label_to_label(self, coco_label): 99 | """ Map COCO label to the label as used in the network. 100 | COCO has some gaps in the order of labels. The highest label is 90, but there are 80 classes. 
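For example, COCO category id 90 ('toothbrush') maps to label 79 here.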
101 | """ 102 | return self.coco_labels_inverse[coco_label] 103 | 104 | def coco_label_to_name(self, coco_label): 105 | """ Map COCO label to name. 106 | """ 107 | return self.label_to_name(self.coco_label_to_label(coco_label)) 108 | 109 | def label_to_coco_label(self, label): 110 | """ Map label as used by the network to labels as used by COCO. 111 | """ 112 | return self.coco_labels[label] 113 | 114 | def image_aspect_ratio(self, image_index): 115 | """ Compute the aspect ratio for an image with image_index. 116 | """ 117 | image = self.coco.loadImgs(self.image_ids[image_index])[0] 118 | return float(image['width']) / float(image['height']) 119 | 120 | def load_image(self, image_index): 121 | """ Load an image at the image_index. 122 | """ 123 | image_info = self.coco.loadImgs(self.image_ids[image_index])[0] 124 | path = os.path.join(self.data_dir, 'images', self.set_name, image_info['file_name']) 125 | return read_image_bgr(path) 126 | 127 | def load_annotations(self, image_index): 128 | """ Load annotations for an image_index. 129 | """ 130 | # get ground truth annotations 131 | annotations_ids = self.coco.getAnnIds(imgIds=self.image_ids[image_index], iscrowd=False) 132 | annotations = {'labels': np.empty((0,)), 'bboxes': np.empty((0, 4))} 133 | 134 | # some images appear to miss annotations (like image with id 257034) 135 | if len(annotations_ids) == 0: 136 | return annotations 137 | 138 | # parse annotations 139 | coco_annotations = self.coco.loadAnns(annotations_ids) 140 | for idx, a in enumerate(coco_annotations): 141 | # some annotations have basically no width / height, skip them 142 | if a['bbox'][2] < 1 or a['bbox'][3] < 1: 143 | continue 144 | 145 | annotations['labels'] = np.concatenate( 146 | [annotations['labels'], [self.coco_label_to_label(a['category_id'])]], axis=0) 147 | annotations['bboxes'] = np.concatenate([annotations['bboxes'], [[ 148 | a['bbox'][0], 149 | a['bbox'][1], 150 | a['bbox'][0] + a['bbox'][2], 151 | a['bbox'][1] + a['bbox'][3], 152 | ]]], axis=0) 153 | 154 | return annotations 155 | -------------------------------------------------------------------------------- /yolo/generators/coco.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | import cv2 17 | import numpy as np 18 | import os 19 | from pycocotools.coco import COCO 20 | 21 | from yolo.generators.common import Generator 22 | 23 | 24 | class CocoGenerator(Generator): 25 | """ 26 | Generate data from the COCO dataset. 27 | See https://github.com/cocodataset/cocoapi/tree/master/PythonAPI for more information. 28 | """ 29 | 30 | def __init__(self, data_dir, set_name, **kwargs): 31 | """ 32 | Initialize a COCO data generator. 33 | 34 | Args 35 | data_dir: Path to where the COCO dataset is stored. 36 | set_name: Name of the set to parse. 
37 | """ 38 | self.data_dir = data_dir 39 | self.set_name = set_name 40 | self.coco = COCO(os.path.join(data_dir, 'annotations', 'instances_' + set_name + '.json')) 41 | self.image_ids = self.coco.getImgIds() 42 | 43 | self.load_classes() 44 | 45 | super(CocoGenerator, self).__init__(**kwargs) 46 | 47 | def load_classes(self): 48 | """ 49 | Loads the class to label mapping (and inverse) for COCO. 50 | """ 51 | # load class names (name -> label) 52 | categories = self.coco.loadCats(self.coco.getCatIds()) 53 | categories.sort(key=lambda x: x['id']) 54 | 55 | self.classes = {} 56 | self.coco_labels = {} 57 | self.coco_labels_inverse = {} 58 | for c in categories: 59 | self.coco_labels[len(self.classes)] = c['id'] 60 | self.coco_labels_inverse[c['id']] = len(self.classes) 61 | self.classes[c['name']] = len(self.classes) 62 | 63 | # also load the reverse (label -> name) 64 | self.labels = {} 65 | for key, value in self.classes.items(): 66 | self.labels[value] = key 67 | 68 | def size(self): 69 | """ Size of the COCO dataset. 70 | """ 71 | return len(self.image_ids) 72 | 73 | def num_classes(self): 74 | """ Number of classes in the dataset. For COCO this is 80. 75 | """ 76 | return len(self.classes) 77 | 78 | def has_label(self, label): 79 | """ Return True if label is a known label. 80 | """ 81 | return label in self.labels 82 | 83 | def has_name(self, name): 84 | """ Returns True if name is a known class. 85 | """ 86 | return name in self.classes 87 | 88 | def name_to_label(self, name): 89 | """ Map name to label. 90 | """ 91 | return self.classes[name] 92 | 93 | def label_to_name(self, label): 94 | """ Map label to name. 95 | """ 96 | return self.labels[label] 97 | 98 | def coco_label_to_label(self, coco_label): 99 | """ Map COCO label to the label as used in the network. 100 | COCO has some gaps in the order of labels. The highest label is 90, but there are 80 classes. 101 | """ 102 | return self.coco_labels_inverse[coco_label] 103 | 104 | def coco_label_to_name(self, coco_label): 105 | """ Map COCO label to name. 106 | """ 107 | return self.label_to_name(self.coco_label_to_label(coco_label)) 108 | 109 | def label_to_coco_label(self, label): 110 | """ Map label as used by the network to labels as used by COCO. 111 | """ 112 | return self.coco_labels[label] 113 | 114 | def image_aspect_ratio(self, image_index): 115 | """ Compute the aspect ratio for an image with image_index. 116 | """ 117 | image = self.coco.loadImgs(self.image_ids[image_index])[0] 118 | return float(image['width']) / float(image['height']) 119 | 120 | def load_image(self, image_index): 121 | """ 122 | Load an image at the image_index. 123 | """ 124 | # {'license': 2, 'file_name': '000000259765.jpg', 'coco_url': 'http://images.cocodataset.org/test2017/000000259765.jpg', 'height': 480, 'width': 640, 'date_captured': '2013-11-21 04:02:31', 'id': 259765} 125 | image_info = self.coco.loadImgs(self.image_ids[image_index])[0] 126 | path = os.path.join(self.data_dir, 'images', self.set_name, image_info['file_name']) 127 | image = cv2.imread(path) 128 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 129 | return image 130 | 131 | def load_annotations(self, image_index): 132 | """ Load annotations for an image_index. 
133 | """ 134 | # get ground truth annotations 135 | annotations_ids = self.coco.getAnnIds(imgIds=self.image_ids[image_index], iscrowd=False) 136 | annotations = {'labels': np.empty((0,)), 'bboxes': np.empty((0, 4))} 137 | 138 | # some images appear to miss annotations (like image with id 257034) 139 | if len(annotations_ids) == 0: 140 | return annotations 141 | 142 | # parse annotations 143 | coco_annotations = self.coco.loadAnns(annotations_ids) 144 | for idx, a in enumerate(coco_annotations): 145 | # some annotations have basically no width / height, skip them 146 | if a['bbox'][2] < 1 or a['bbox'][3] < 1: 147 | continue 148 | 149 | annotations['labels'] = np.concatenate( 150 | [annotations['labels'], [self.coco_label_to_label(a['category_id'])]], axis=0) 151 | annotations['bboxes'] = np.concatenate([annotations['bboxes'], [[ 152 | a['bbox'][0], 153 | a['bbox'][1], 154 | a['bbox'][0] + a['bbox'][2], 155 | a['bbox'][1] + a['bbox'][3], 156 | ]]], axis=0) 157 | 158 | return annotations 159 | 160 | 161 | if __name__ == '__main__': 162 | dataset_dir = '/home/adam/.keras/datasets/coco/2017_118_5' 163 | generator = CocoGenerator(data_dir=dataset_dir, set_name='test-dev2017') 164 | print(generator[0]) 165 | -------------------------------------------------------------------------------- /augmentor/color.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image, ImageEnhance, ImageOps 3 | 4 | 5 | def autocontrast(image, prob=0.5): 6 | random_prob = np.random.uniform() 7 | if random_prob > prob: 8 | return image 9 | image = Image.fromarray(image[..., ::-1]) 10 | image = ImageOps.autocontrast(image) 11 | image = np.array(image)[..., ::-1] 12 | return image 13 | 14 | 15 | def equalize(image, prob=0.5): 16 | random_prob = np.random.uniform() 17 | if random_prob > prob: 18 | return image 19 | image = Image.fromarray(image[..., ::-1]) 20 | image = ImageOps.equalize(image) 21 | image = np.array(image)[..., ::-1] 22 | return image 23 | 24 | 25 | def solarize(image, prob=0.5, threshold=128.): 26 | random_prob = np.random.uniform() 27 | if random_prob > prob: 28 | return image 29 | image = Image.fromarray(image[..., ::-1]) 30 | image = ImageOps.solarize(image, threshold=threshold) 31 | image = np.array(image)[..., ::-1] 32 | return image 33 | 34 | 35 | def sharpness(image, prob=0.5, min=0, max=2, factor=None): 36 | random_prob = np.random.uniform() 37 | if random_prob > prob: 38 | return image 39 | if factor is None: 40 | factor = np.random.uniform(min, max) 41 | image = Image.fromarray(image[..., ::-1]) 42 | enhancer = ImageEnhance.Sharpness(image) 43 | image = enhancer.enhance(factor=factor) 44 | return np.array(image)[..., ::-1] 45 | 46 | 47 | def color(image, prob=0.5, min=0., max=1., factor=None): 48 | random_prob = np.random.uniform() 49 | if random_prob > prob: 50 | return image 51 | if factor is None: 52 | factor = np.random.uniform(min, max) 53 | image = Image.fromarray(image[..., ::-1]) 54 | enhancer = ImageEnhance.Color(image) 55 | image = enhancer.enhance(factor=factor) 56 | return np.array(image)[..., ::-1] 57 | 58 | 59 | def contrast(image, prob=0.5, min=0.2, max=1., factor=None): 60 | random_prob = np.random.uniform() 61 | if random_prob > prob: 62 | return image 63 | if factor is None: 64 | factor = np.random.uniform(min, max) 65 | image = Image.fromarray(image[..., ::-1]) 66 | enhancer = ImageEnhance.Contrast(image) 67 | image = enhancer.enhance(factor=factor) 68 | return np.array(image)[..., ::-1] 69 | 70 | 71 | 
def brightness(image, prob=0.5, min=0.8, max=1., factor=None): 72 | random_prob = np.random.uniform() 73 | if random_prob > prob: 74 | return image 75 | if factor is None: 76 | factor = np.random.uniform(min, max) 77 | image = Image.fromarray(image[..., ::-1]) 78 | enhancer = ImageEnhance.Brightness(image) 79 | image = enhancer.enhance(factor=factor) 80 | return np.array(image)[..., ::-1] 81 | 82 | 83 | class VisualEffect: 84 | """ 85 | Holds the parameters for, and applies, random image color transformations. 86 | 87 | Args 88 | solarize_threshold: Pixel values above this threshold are inverted when solarizing. 89 | color_factor: A factor for adjusting color. 90 | contrast_factor: A factor for adjusting contrast. 91 | brightness_factor: A factor for adjusting brightness. 92 | sharpness_factor: A factor for adjusting sharpness. 93 | """ 94 | 95 | def __init__( 96 | self, 97 | color_factor=None, 98 | contrast_factor=None, 99 | brightness_factor=None, 100 | sharpness_factor=None, 101 | color_prob=0.5, 102 | contrast_prob=0.5, 103 | brightness_prob=0.5, 104 | sharpness_prob=0.5, 105 | autocontrast_prob=0.5, 106 | equalize_prob=0.5, 107 | solarize_prob=0.1, 108 | solarize_threshold=128., 109 | 110 | ): 111 | self.color_factor = color_factor 112 | self.contrast_factor = contrast_factor 113 | self.brightness_factor = brightness_factor 114 | self.sharpness_factor = sharpness_factor 115 | self.color_prob = color_prob 116 | self.contrast_prob = contrast_prob 117 | self.brightness_prob = brightness_prob 118 | self.sharpness_prob = sharpness_prob 119 | self.autocontrast_prob = autocontrast_prob 120 | self.equalize_prob = equalize_prob 121 | self.solarize_prob = solarize_prob 122 | self.solarize_threshold = solarize_threshold 123 | 124 | def __call__(self, image): 125 | """ 126 | Apply a visual effect on the image. 127 | 128 | Args 129 | image: Image to adjust 130 | """ 131 | random_enhance_id = np.random.randint(0, 4) 132 | if random_enhance_id == 0: 133 | image = color(image, prob=self.color_prob, factor=self.color_factor) 134 | elif random_enhance_id == 1: 135 | image = contrast(image, prob=self.contrast_prob, factor=self.contrast_factor) 136 | elif random_enhance_id == 2: 137 | image = brightness(image, prob=self.brightness_prob, factor=self.brightness_factor) 138 | else: 139 | image = sharpness(image, prob=self.sharpness_prob, factor=self.sharpness_factor) 140 | 141 | random_ops_id = np.random.randint(0, 3) 142 | if random_ops_id == 0: 143 | image = autocontrast(image, prob=self.autocontrast_prob) 144 | elif random_ops_id == 1: 145 | image = equalize(image, prob=self.equalize_prob) 146 | else: 147 | image = solarize(image, prob=self.solarize_prob, threshold=self.solarize_threshold) 148 | return image 149 | 150 | 151 | if __name__ == '__main__': 152 | from yolo.generators.pascal import PascalVocGenerator 153 | import cv2 154 | 155 | train_generator = PascalVocGenerator( 156 | 'datasets/VOC0712', 157 | 'trainval', 158 | skip_difficult=True, 159 | anchors_path='voc_anchors_416.txt', 160 | batch_size=1 161 | ) 162 | visual_effect = VisualEffect() 163 | for i in range(train_generator.size()): 164 | image = train_generator.load_image(i) 165 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 166 | annotations = train_generator.load_annotations(i) 167 | boxes = annotations['bboxes'] 168 | for box in boxes.astype(np.int32): 169 | cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]), (0, 0, 255), 2) 170 | src_image = image.copy() 171 | image = visual_effect(image) 172 | cv2.namedWindow('image', cv2.WINDOW_NORMAL) 173 | cv2.imshow('image', np.concatenate([src_image, image],
axis=1)) 174 | cv2.waitKey(0) 175 | -------------------------------------------------------------------------------- /yolo/model.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Add, BatchNormalization, Concatenate, Conv2D, Input, Reshape 2 | from keras.layers import Lambda, LeakyReLU, UpSampling2D, ZeroPadding2D, Activation 3 | from keras.regularizers import l2 4 | from keras.models import Model 5 | from functools import reduce 6 | 7 | from layers import FilterDetections, ClipBoxes 8 | from losses import focal_with_mask, iou_with_mask 9 | from yolo import config 10 | from yolo.fsaf_layers import FSAFTarget, LevelSelect, Locations, RegressBoxes 11 | 12 | 13 | def compose(*funcs): 14 | """ 15 | Compose arbitrarily many functions, evaluated left to right. 16 | 17 | Reference: https://mathieularose.com/function-composition-in-python/ 18 | """ 19 | # return lambda x: reduce(lambda v, f: f(v), funcs, x) 20 | if funcs: 21 | return reduce(lambda f, g: lambda *args, **kwargs: g(f(*args, **kwargs)), funcs) 22 | else: 23 | raise ValueError('Composition of empty sequence not supported.') 24 | 25 | 26 | def darknet_conv2d(*args, **kwargs): 27 | """ 28 | Wrapper to set Darknet parameters for Convolution2D. 29 | """ 30 | darknet_conv_kwargs = dict({'kernel_regularizer': l2(5e-4)}) 31 | darknet_conv_kwargs['padding'] = 'valid' if kwargs.get('strides') == (2, 2) else 'same' 32 | darknet_conv_kwargs.update(kwargs) 33 | return Conv2D(*args, **darknet_conv_kwargs) 34 | 35 | 36 | def darknet_conv2d_bn_leaky(*args, **kwargs): 37 | """ 38 | Darknet Convolution2D followed by BatchNormalization and LeakyReLU. 39 | """ 40 | no_bias_kwargs = {'use_bias': False} 41 | no_bias_kwargs.update(kwargs) 42 | return compose( 43 | darknet_conv2d(*args, **no_bias_kwargs), 44 | BatchNormalization(), 45 | LeakyReLU(alpha=0.1)) 46 | 47 | 48 | def resblock_body(x, num_filters, num_blocks): 49 | """ 50 | A series of resblocks starting with a downsampling Convolution2D 51 | """ 52 | # Darknet uses left and top padding instead of 'same' mode 53 | x = ZeroPadding2D(((1, 0), (1, 0)))(x) 54 | x = darknet_conv2d_bn_leaky(num_filters, (3, 3), strides=(2, 2))(x) 55 | for i in range(num_blocks): 56 | y = compose( 57 | darknet_conv2d_bn_leaky(num_filters // 2, (1, 1)), 58 | darknet_conv2d_bn_leaky(num_filters, (3, 3)))(x) 59 | x = Add()([x, y]) 60 | return x 61 | 62 | 63 | def darknet_body(x): 64 | """ 65 | Darknet body having 52 Convolution2D layers 66 | """ 67 | x = darknet_conv2d_bn_leaky(32, (3, 3))(x) 68 | x = resblock_body(x, 64, 1) 69 | x = resblock_body(x, 128, 2) 70 | x = resblock_body(x, 256, 8) 71 | x = resblock_body(x, 512, 8) 72 | x = resblock_body(x, 1024, 4) 73 | return x 74 | 75 | 76 | def make_last_layers(x, num_filters, out_filters): 77 | """ 78 | 6 conv2d_bn_leaky layers followed by a conv2d layer 79 | """ 80 | x = compose(darknet_conv2d_bn_leaky(num_filters, (1, 1)), 81 | darknet_conv2d_bn_leaky(num_filters * 2, (3, 3)), 82 | darknet_conv2d_bn_leaky(num_filters, (1, 1)), 83 | darknet_conv2d_bn_leaky(num_filters * 2, (3, 3)), 84 | darknet_conv2d_bn_leaky(num_filters, (1, 1)))(x) 85 | y = compose(darknet_conv2d_bn_leaky(num_filters * 2, (3, 3)), 86 | darknet_conv2d(out_filters, (1, 1)))(x) 87 | return x, y 88 | 89 | 90 | def yolo_body(num_classes=20, score_threshold=0.01): 91 | """ 92 | Create YOLO_V3 model CNN body in Keras. 
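Builds an anchor-free FSAF head on top of a Darknet-53 backbone.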
93 | 94 | Args: 95 | num_classes: Number of object classes to predict. 96 | score_threshold: Minimum classification score for FilterDetections to keep a detection. 97 | 98 | Returns: 99 | model: the training model (outputs cls_loss and regr_loss); prediction_model: the inference model that outputs filtered detections. 100 | """ 101 | image_input = Input(shape=(None, None, 3), name='image_input') 102 | darknet = Model([image_input], darknet_body(image_input)) 103 | ################################################## 104 | # build fsaf head 105 | ################################################## 106 | x, y1 = make_last_layers(darknet.output, 512, 4 + num_classes) 107 | 108 | x = compose(darknet_conv2d_bn_leaky(256, (1, 1)), UpSampling2D(2))(x) 109 | x = Concatenate()([x, darknet.layers[152].output]) 110 | x, y2 = make_last_layers(x, 256, 4 + num_classes) 111 | x = compose(darknet_conv2d_bn_leaky(128, (1, 1)), UpSampling2D(2))(x) 112 | x = Concatenate()([x, darknet.layers[92].output]) 113 | x, y3 = make_last_layers(x, 128, 4 + num_classes) 114 | y1_ = Reshape((-1, 4 + num_classes))(y1) 115 | y2_ = Reshape((-1, 4 + num_classes))(y2) 116 | y3_ = Reshape((-1, 4 + num_classes))(y3) 117 | y = Concatenate(axis=1)([y1_, y2_, y3_]) 118 | batch_cls_pred = Lambda(lambda x: x[..., 4:])(y) 119 | batch_regr_pred = Lambda(lambda x: x[..., :4])(y) 120 | batch_cls_pred = Activation('sigmoid')(batch_cls_pred) 121 | batch_regr_pred = Activation('relu')(batch_regr_pred) 122 | 123 | gt_boxes_input = Input(shape=(config.MAX_NUM_GT_BOXES, 5), name='gt_boxes_input') 124 | grid_shapes_input = Input((len(config.STRIDES), 2), dtype='int32', name='grid_shapes_input') 125 | batch_gt_box_levels = LevelSelect(name='level_select')( 126 | [batch_cls_pred, batch_regr_pred, grid_shapes_input, gt_boxes_input]) 127 | batch_cls_target, batch_cls_mask, batch_cls_num_pos, batch_regr_target, batch_regr_mask = FSAFTarget( 128 | num_classes=num_classes, 129 | name='fsaf_target')( 130 | [batch_gt_box_levels, grid_shapes_input, gt_boxes_input]) 131 | focal_loss_graph = focal_with_mask() 132 | iou_loss_graph = iou_with_mask() 133 | cls_loss = Lambda(focal_loss_graph, 134 | output_shape=(1,), 135 | name="cls_loss")( 136 | [batch_cls_target, batch_cls_pred, batch_cls_mask, batch_cls_num_pos]) 137 | regr_loss = Lambda(iou_loss_graph, 138 | output_shape=(1,), 139 | name="regr_loss")([batch_regr_target, batch_regr_pred, batch_regr_mask]) 140 | model = Model(inputs=[image_input, gt_boxes_input, grid_shapes_input], 141 | outputs=[cls_loss, regr_loss], 142 | name='fsaf') 143 | 144 | # compute the anchors 145 | features = [y1, y2, y3] 146 | 147 | locations, strides = Locations(strides=config.STRIDES)(features) 148 | 149 | # apply predicted regression to anchors 150 | boxes = RegressBoxes(name='boxes')([locations, strides, batch_regr_pred]) 151 | boxes = ClipBoxes(name='clipped_boxes')([image_input, boxes]) 152 | 153 | # filter detections (apply NMS / score threshold / select top-k) 154 | detections = FilterDetections( 155 | nms=True, 156 | class_specific_filter=True, 157 | name='filtered_detections', 158 | score_threshold=score_threshold 159 | )([boxes, batch_cls_pred]) 160 | 161 | prediction_model = Model(inputs=image_input, outputs=detections, name='fsaf_detection') 162 | return model, prediction_model 163 | -------------------------------------------------------------------------------- /util_graphs.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import keras.backend as K 3 | 4 | 5 | def xyxy2cxcywh(xyxy): 6 | """ 7 | Convert [x1 y1 x2 y2] box format to [cx cy w h] format.
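For example, [0, 0, 10, 20] becomes [5, 10, 10, 20].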
8 | """ 9 | return tf.concat((0.5 * (xyxy[:, 0:2] + xyxy[:, 2:4]), xyxy[:, 2:4] - xyxy[:, 0:2]), axis=-1) 10 | 11 | 12 | def cxcywh2xyxy(xywh): 13 | """ 14 | Convert [cx cy w y] box format to [x1 y1 x2 y2] format. 15 | """ 16 | return tf.concat((xywh[:, 0:2] - 0.5 * xywh[:, 2:4], xywh[:, 0:2] + 0.5 * xywh[:, 2:4]), axis=-1) 17 | 18 | 19 | def prop_box_graph(boxes, scale, width, height): 20 | """ 21 | Compute proportional box coordinates. 22 | 23 | Box centers are fixed. Box w and h scaled by scale. 24 | """ 25 | prop_boxes = xyxy2cxcywh(boxes) 26 | prop_boxes = tf.concat((prop_boxes[:, :2], prop_boxes[:, 2:] * scale), axis=-1) 27 | prop_boxes = cxcywh2xyxy(prop_boxes) 28 | x1 = tf.floor(prop_boxes[:, 0]) 29 | y1 = tf.floor(prop_boxes[:, 1]) 30 | x2 = tf.math.ceil(prop_boxes[:, 2]) 31 | y2 = tf.math.ceil(prop_boxes[:, 3]) 32 | width = tf.cast(width, tf.float32) 33 | height = tf.cast(height, tf.float32) 34 | x2 = tf.cast(tf.clip_by_value(x2, 1, width), tf.int32) 35 | y2 = tf.cast(tf.clip_by_value(y2, 1, height), tf.int32) 36 | x1 = tf.cast(tf.clip_by_value(x1, 0, tf.cast(x2, tf.float32) - 1), tf.int32) 37 | y1 = tf.cast(tf.clip_by_value(y1, 0, tf.cast(y2, tf.float32) - 1), tf.int32) 38 | 39 | return x1, y1, x2, y2 40 | 41 | 42 | def prop_box_graph_2(boxes, scale, width, height): 43 | """ 44 | Compute proportional box coordinates. 45 | 46 | Box centers are fixed. Box w and h scaled by scale. 47 | """ 48 | prop_boxes = xyxy2cxcywh(boxes) 49 | prop_boxes = tf.concat((prop_boxes[:, :2], prop_boxes[:, 2:] * scale), axis=-1) 50 | prop_boxes = cxcywh2xyxy(prop_boxes) 51 | # (n, 1) 52 | x1 = tf.floor(prop_boxes[:, 0:1]) 53 | y1 = tf.floor(prop_boxes[:, 1:2]) 54 | x2 = tf.math.ceil(prop_boxes[:, 2:3]) 55 | y2 = tf.math.ceil(prop_boxes[:, 3:4]) 56 | width = tf.cast(width, tf.float32) 57 | height = tf.cast(height, tf.float32) 58 | x2 = tf.cast(tf.clip_by_value(x2, 1, width), tf.int32) 59 | y2 = tf.cast(tf.clip_by_value(y2, 1, height), tf.int32) 60 | x1 = tf.cast(tf.clip_by_value(x1, 0, tf.cast(x2, tf.float32) - 1), tf.int32) 61 | y1 = tf.cast(tf.clip_by_value(y1, 0, tf.cast(y2, tf.float32) - 1), tf.int32) 62 | 63 | return x1, y1, x2, y2 64 | 65 | 66 | def trim_zeros_graph(boxes, name='trim_zeros'): 67 | """ 68 | Often boxes are represented with matrices of shape [N, 4] and are padded with zeros. 69 | This removes zero boxes. 70 | 71 | Args: 72 | boxes: [N, 4] matrix of boxes. 73 | name: name of tensor 74 | 75 | Returns: 76 | 77 | """ 78 | non_zeros = tf.cast(tf.reduce_sum(tf.abs(boxes), axis=1), tf.bool) 79 | boxes = tf.boolean_mask(boxes, non_zeros, name=name) 80 | return boxes, non_zeros 81 | 82 | 83 | def bbox_transform_inv(boxes, deltas, mean=None, std=None): 84 | """ 85 | Applies deltas (usually regression results) to boxes (usually anchors). 86 | 87 | Before applying the deltas to the boxes, the normalization that was previously applied (in the generator) has to be removed. 88 | The mean and std are the mean and std as applied in the generator. They are unnormalized in this function and then applied to the boxes. 89 | 90 | Args 91 | boxes: np.array of shape (B, N, 4), where B is the batch size, N the number of boxes and 4 values for (x1, y1, x2, y2). 92 | deltas: np.array of same shape as boxes. These deltas (d_x1, d_y1, d_x2, d_y2) are a factor of the width/height. 93 | mean: The mean value used when computing deltas (defaults to [0, 0, 0, 0]). 94 | std: The standard deviation used when computing deltas (defaults to [0.2, 0.2, 0.2, 0.2]). 
95 | 96 | Returns 97 | A np.array of the same shape as boxes, but with deltas applied to each box. 98 | The mean and std are used during training to normalize the regression values (networks love normalization). 99 | """ 100 | if mean is None: 101 | mean = [0, 0, 0, 0] 102 | if std is None: 103 | std = [0.2, 0.2, 0.2, 0.2] 104 | 105 | width = boxes[:, :, 2] - boxes[:, :, 0] 106 | height = boxes[:, :, 3] - boxes[:, :, 1] 107 | 108 | x1 = boxes[:, :, 0] + (deltas[:, :, 0] * std[0] + mean[0]) * width 109 | y1 = boxes[:, :, 1] + (deltas[:, :, 1] * std[1] + mean[1]) * height 110 | x2 = boxes[:, :, 2] + (deltas[:, :, 2] * std[2] + mean[2]) * width 111 | y2 = boxes[:, :, 3] + (deltas[:, :, 3] * std[3] + mean[3]) * height 112 | 113 | pred_boxes = K.stack([x1, y1, x2, y2], axis=2) 114 | 115 | return pred_boxes 116 | 117 | 118 | def shift(shape, stride, anchors): 119 | """ 120 | Produce shifted anchors based on shape of the map and stride size. 121 | 122 | Args 123 | shape: Shape to shift the anchors over. (h,w) 124 | stride: Stride to shift the anchors with over the shape. 125 | anchors: The anchors to apply at each location. 126 | 127 | Returns 128 | shifted_anchors: (fh * fw * num_anchors, 4) 129 | """ 130 | shift_x = (K.arange(0, shape[1], dtype=K.floatx()) + K.constant(0.5, dtype=K.floatx())) * stride 131 | shift_y = (K.arange(0, shape[0], dtype=K.floatx()) + K.constant(0.5, dtype=K.floatx())) * stride 132 | shift_x, shift_y = tf.meshgrid(shift_x, shift_y) 133 | shift_x = K.reshape(shift_x, [-1]) 134 | shift_y = K.reshape(shift_y, [-1]) 135 | 136 | # (4, fh * fw) 137 | shifts = K.stack([ 138 | shift_x, 139 | shift_y, 140 | shift_x, 141 | shift_y 142 | ], axis=0) 143 | # (fh * fw, 4) 144 | shifts = K.transpose(shifts) 145 | number_anchors = K.shape(anchors)[0] 146 | 147 | # number of base points = fh * fw 148 | k = K.shape(shifts)[0] 149 | 150 | # (k=fh*fw, num_anchors, 4) 151 | shifted_anchors = K.reshape(anchors, [1, number_anchors, 4]) + K.cast(K.reshape(shifts, [k, 1, 4]), K.floatx()) 152 | # (k * num_anchors, 4) 153 | shifted_anchors = K.reshape(shifted_anchors, [k * number_anchors, 4]) 154 | 155 | return shifted_anchors 156 | 157 | 158 | def resize_images(images, size, method='bilinear', align_corners=False): 159 | """ See https://www.tensorflow.org/versions/master/api_docs/python/tf/image/resize_images . 160 | 161 | Args 162 | method: The method used for interpolation. One of ('bilinear', 'nearest', 'bicubic', 'area'). 163 | """ 164 | methods = { 165 | 'bilinear': tf.image.ResizeMethod.BILINEAR, 166 | 'nearest': tf.image.ResizeMethod.NEAREST_NEIGHBOR, 167 | 'bicubic': tf.image.ResizeMethod.BICUBIC, 168 | 'area': tf.image.ResizeMethod.AREA, 169 | } 170 | return tf.image.resize_images(images, size, methods[method], align_corners) 171 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import keras 18 | from keras.utils import get_file 19 | import keras_resnet 20 | import keras_resnet.models 21 | import tensorflow as tf 22 | 23 | from models import retinanet 24 | from models import Backbone 25 | from utils.image import preprocess_image 26 | import configure 27 | 28 | 29 | class ResNetBackbone(Backbone): 30 | """ 31 | Describes backbone information and provides utility functions. 32 | """ 33 | 34 | def __init__(self, backbone): 35 | super(ResNetBackbone, self).__init__(backbone) 36 | self.custom_objects.update(keras_resnet.custom_objects) 37 | 38 | def retinanet(self, *args, **kwargs): 39 | """ 40 | Returns a retinanet model using the correct backbone. 41 | """ 42 | return resnet_retinanet(*args, backbone=self.backbone, **kwargs) 43 | 44 | def fsaf(self, num_classes, modifier): 45 | """ 46 | Returns an FSAF model using the correct backbone. 47 | """ 48 | return resnet_fsaf(num_classes=num_classes, backbone=self.backbone, modifier=modifier) 49 | 50 | def download_imagenet(self): 51 | """ 52 | Downloads ImageNet weights and returns path to weights file. 53 | """ 54 | resnet_filename = 'ResNet-{}-model.keras.h5' 55 | resnet_resource = 'https://github.com/fizyr/keras-models/releases/download/v0.0.1/{}'.format(resnet_filename) 56 | depth = int(self.backbone.replace('resnet', '')) 57 | 58 | filename = resnet_filename.format(depth) 59 | resource = resnet_resource.format(depth) 60 | if depth == 50: 61 | checksum = '3e9f4e4f77bbe2c9bec13b53ee1c2319' 62 | elif depth == 101: 63 | checksum = '05dc86924389e5b401a9ea0348a3213c' 64 | elif depth == 152: 65 | checksum = '6ee11ef2b135592f8031058820bb9e71' 66 | else: 67 | raise ValueError('Unknown depth') 68 | 69 | return get_file( 70 | filename, 71 | resource, 72 | cache_subdir='models', 73 | md5_hash=checksum 74 | ) 75 | 76 | def validate(self): 77 | """ 78 | Checks whether the backbone string is correct. 79 | """ 80 | allowed_backbones = ['resnet50', 'resnet101', 'resnet152'] 81 | 82 | if self.backbone not in allowed_backbones: 83 | raise ValueError( 84 | 'Backbone (\'{}\') not in allowed backbones ({}).'.format(self.backbone, allowed_backbones)) 85 | 86 | def preprocess_image(self, inputs): 87 | """ 88 | Takes as input an image and prepares it for being passed through the network. 89 | """ 90 | return preprocess_image(inputs, mode='caffe') 91 | 92 | 93 | def resnet_retinanet(num_classes, backbone='resnet50', inputs=None, modifier=None, **kwargs): 94 | """ 95 | Constructs a retinanet model using a resnet backbone. 96 | 97 | Args 98 | num_classes: Number of classes to predict. 99 | backbone: Which backbone to use (one of ('resnet50', 'resnet101', 'resnet152')). 100 | inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)). 101 | modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example). 102 | 103 | Returns 104 | RetinaNet model with a ResNet backbone.
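Example (illustrative): model = resnet_retinanet(num_classes=20, backbone='resnet50')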
105 | """ 106 | # choose default input 107 | inputs = keras.layers.Input(shape=(None, None, 3)) 108 | 109 | # create the resnet backbone 110 | if backbone == 'resnet50': 111 | resnet = keras_resnet.models.ResNet50(inputs, include_top=False, freeze_bn=True) 112 | elif backbone == 'resnet101': 113 | resnet = keras_resnet.models.ResNet101(inputs, include_top=False, freeze_bn=True) 114 | elif backbone == 'resnet152': 115 | resnet = keras_resnet.models.ResNet152(inputs, include_top=False, freeze_bn=True) 116 | else: 117 | raise ValueError('Backbone (\'{}\') is invalid.'.format(backbone)) 118 | 119 | # invoke modifier if given 120 | if modifier: 121 | resnet = modifier(resnet) 122 | 123 | # create the full model 124 | return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=resnet.outputs[1:], **kwargs) 125 | 126 | 127 | def resnet_fsaf(num_classes, backbone='resnet50', modifier=None): 128 | """ 129 | Constructs a retinanet model using a resnet backbone. 130 | 131 | Args 132 | num_classes: Number of classes to predict. 133 | backbone: Which backbone to use (one of ('resnet50', 'resnet101', 'resnet152')). 134 | inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)). 135 | modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example). 136 | 137 | Returns 138 | RetinaNet model with a ResNet backbone. 139 | """ 140 | image_input = keras.layers.Input(shape=(None, None, 3)) 141 | gt_boxes_input = keras.layers.Input(shape=(configure.MAX_NUM_GT_BOXES, 5)) 142 | feature_shapes_input = keras.layers.Input((5, 2), dtype='int32') 143 | 144 | # create the resnet backbone 145 | if backbone == 'resnet50': 146 | resnet = keras_resnet.models.ResNet50(image_input, include_top=False, freeze_bn=True) 147 | elif backbone == 'resnet101': 148 | resnet = keras_resnet.models.ResNet101(image_input, include_top=False, freeze_bn=True) 149 | elif backbone == 'resnet152': 150 | resnet = keras_resnet.models.ResNet152(image_input, include_top=False, freeze_bn=True) 151 | else: 152 | raise ValueError('Backbone (\'{}\') is invalid.'.format(backbone)) 153 | 154 | # invoke modifier if given 155 | if modifier: 156 | resnet = modifier(resnet) 157 | 158 | # create the full model 159 | return retinanet.fsaf(inputs=[image_input, gt_boxes_input, feature_shapes_input], 160 | num_classes=num_classes, 161 | backbone_layers=resnet.outputs[1:]) 162 | 163 | 164 | def resnet50_retinanet(num_classes, inputs=None, **kwargs): 165 | return resnet_retinanet(num_classes=num_classes, backbone='resnet50', inputs=inputs, **kwargs) 166 | 167 | 168 | def resnet101_retinanet(num_classes, inputs=None, **kwargs): 169 | return resnet_retinanet(num_classes=num_classes, backbone='resnet101', inputs=inputs, **kwargs) 170 | 171 | 172 | def resnet152_retinanet(num_classes, inputs=None, **kwargs): 173 | return resnet_retinanet(num_classes=num_classes, backbone='resnet152', inputs=inputs, **kwargs) 174 | -------------------------------------------------------------------------------- /callbacks.py: -------------------------------------------------------------------------------- 1 | import keras 2 | from utils.eval import evaluate 3 | from utils.coco_eval import evaluate_coco 4 | 5 | 6 | class Evaluate(keras.callbacks.Callback): 7 | """ 8 | Evaluation callback for arbitrary datasets. 
9 | """ 10 | 11 | def __init__( 12 | self, 13 | generator, 14 | iou_threshold=0.5, 15 | score_threshold=0.05, 16 | max_detections=100, 17 | save_path=None, 18 | tensorboard=None, 19 | weighted_average=False, 20 | verbose=1 21 | ): 22 | """ 23 | Evaluate a given dataset using a given model at the end of every epoch during training. 24 | 25 | Args: 26 | generator: The generator that represents the dataset to evaluate. 27 | iou_threshold: The threshold used to consider when a detection is positive or negative. 28 | score_threshold: The score confidence threshold to use for detections. 29 | max_detections: The maximum number of detections to use per image. 30 | save_path: The path to save images with visualized detections to. 31 | tensorboard: Instance of keras.callbacks.TensorBoard used to log the mAP value. 32 | weighted_average: Compute the mAP using the weighted average of precisions among classes. 33 | verbose: Set the verbosity level, by default this is set to 1. 34 | """ 35 | self.generator = generator 36 | self.iou_threshold = iou_threshold 37 | self.score_threshold = score_threshold 38 | self.max_detections = max_detections 39 | self.save_path = save_path 40 | self.tensorboard = tensorboard 41 | self.weighted_average = weighted_average 42 | self.verbose = verbose 43 | 44 | super(Evaluate, self).__init__() 45 | 46 | def on_epoch_end(self, epoch, logs=None): 47 | logs = logs or {} 48 | 49 | # run evaluation 50 | average_precisions = evaluate( 51 | self.generator, 52 | self.model, 53 | iou_threshold=self.iou_threshold, 54 | score_threshold=self.score_threshold, 55 | max_detections=self.max_detections, 56 | visualize=False, 57 | ) 58 | 59 | # compute per class average precision 60 | total_instances = [] 61 | precisions = [] 62 | for label, (average_precision, num_annotations) in average_precisions.items(): 63 | if self.verbose == 1: 64 | print('{:.0f} instances of class'.format(num_annotations), 65 | self.generator.label_to_name(label), 'with average precision: {:.4f}'.format(average_precision)) 66 | total_instances.append(num_annotations) 67 | precisions.append(average_precision) 68 | if self.weighted_average: 69 | self.mean_ap = sum([a * b for a, b in zip(total_instances, precisions)]) / sum(total_instances) 70 | else: 71 | self.mean_ap = sum(precisions) / sum(x > 0 for x in total_instances) 72 | 73 | if self.tensorboard is not None and self.tensorboard.writer is not None: 74 | import tensorflow as tf 75 | summary = tf.Summary() 76 | summary_value = summary.value.add() 77 | summary_value.simple_value = self.mean_ap 78 | summary_value.tag = "mAP" 79 | self.tensorboard.writer.add_summary(summary, epoch) 80 | 81 | logs['mAP'] = self.mean_ap 82 | 83 | if self.verbose == 1: 84 | print('mAP: {:.4f}'.format(self.mean_ap)) 85 | 86 | 87 | class RedirectModel(keras.callbacks.Callback): 88 | """ 89 | Callback which wraps another callback, but executed on a different model. 90 | 91 | ```python 92 | model = keras.models.load_model('model.h5') 93 | model_checkpoint = ModelCheckpoint(filepath='snapshot.h5') 94 | parallel_model = multi_gpu_model(model, gpus=2) 95 | parallel_model.fit(X_train, Y_train, callbacks=[RedirectModel(model_checkpoint, model)]) 96 | ``` 97 | 98 | Args 99 | callback : callback to wrap. 100 | model : model to use when executing callbacks. 
101 | """ 102 | 103 | def __init__(self, 104 | callback, 105 | model): 106 | super(RedirectModel, self).__init__() 107 | 108 | self.callback = callback 109 | self.redirect_model = model 110 | 111 | def on_epoch_begin(self, epoch, logs=None): 112 | self.callback.on_epoch_begin(epoch, logs=logs) 113 | 114 | def on_epoch_end(self, epoch, logs=None): 115 | self.callback.on_epoch_end(epoch, logs=logs) 116 | 117 | def on_batch_begin(self, batch, logs=None): 118 | self.callback.on_batch_begin(batch, logs=logs) 119 | 120 | def on_batch_end(self, batch, logs=None): 121 | self.callback.on_batch_end(batch, logs=logs) 122 | 123 | def on_train_begin(self, logs=None): 124 | # overwrite the model with our custom model 125 | self.callback.set_model(self.redirect_model) 126 | 127 | self.callback.on_train_begin(logs=logs) 128 | 129 | def on_train_end(self, logs=None): 130 | self.callback.on_train_end(logs=logs) 131 | 132 | 133 | class CocoEval(keras.callbacks.Callback): 134 | """ Performs COCO evaluation on each epoch. 135 | """ 136 | def __init__(self, generator, tensorboard=None, threshold=0.05): 137 | """ CocoEval callback intializer. 138 | 139 | Args 140 | generator : The generator used for creating validation data. 141 | tensorboard : If given, the results will be written to tensorboard. 142 | threshold : The score threshold to use. 143 | """ 144 | self.generator = generator 145 | self.threshold = threshold 146 | self.tensorboard = tensorboard 147 | 148 | super(CocoEval, self).__init__() 149 | 150 | def on_epoch_end(self, epoch, logs=None): 151 | logs = logs or {} 152 | 153 | coco_tag = ['AP @[ IoU=0.50:0.95 | area= all | maxDets=100 ]', 154 | 'AP @[ IoU=0.50 | area= all | maxDets=100 ]', 155 | 'AP @[ IoU=0.75 | area= all | maxDets=100 ]', 156 | 'AP @[ IoU=0.50:0.95 | area= small | maxDets=100 ]', 157 | 'AP @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]', 158 | 'AP @[ IoU=0.50:0.95 | area= large | maxDets=100 ]', 159 | 'AR @[ IoU=0.50:0.95 | area= all | maxDets= 1 ]', 160 | 'AR @[ IoU=0.50:0.95 | area= all | maxDets= 10 ]', 161 | 'AR @[ IoU=0.50:0.95 | area= all | maxDets=100 ]', 162 | 'AR @[ IoU=0.50:0.95 | area= small | maxDets=100 ]', 163 | 'AR @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]', 164 | 'AR @[ IoU=0.50:0.95 | area= large | maxDets=100 ]'] 165 | coco_eval_stats = evaluate_coco(self.generator, self.model, self.threshold) 166 | if coco_eval_stats is not None and self.tensorboard is not None and self.tensorboard.writer is not None: 167 | import tensorflow as tf 168 | summary = tf.Summary() 169 | for index, result in enumerate(coco_eval_stats): 170 | summary_value = summary.value.add() 171 | summary_value.simple_value = result 172 | summary_value.tag = '{}. {}'.format(index + 1, coco_tag[index]) 173 | self.tensorboard.writer.add_summary(summary, epoch) 174 | logs[coco_tag[index]] = result 175 | -------------------------------------------------------------------------------- /generators/voc_generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from generators.generator import Generator 18 | from utils.image import read_image_bgr 19 | 20 | import os 21 | import numpy as np 22 | from six import raise_from 23 | from PIL import Image 24 | 25 | try: 26 | import xml.etree.cElementTree as ET 27 | except ImportError: 28 | import xml.etree.ElementTree as ET 29 | 30 | voc_classes = { 31 | 'aeroplane': 0, 32 | 'bicycle': 1, 33 | 'bird': 2, 34 | 'boat': 3, 35 | 'bottle': 4, 36 | 'bus': 5, 37 | 'car': 6, 38 | 'cat': 7, 39 | 'chair': 8, 40 | 'cow': 9, 41 | 'diningtable': 10, 42 | 'dog': 11, 43 | 'horse': 12, 44 | 'motorbike': 13, 45 | 'person': 14, 46 | 'pottedplant': 15, 47 | 'sheep': 16, 48 | 'sofa': 17, 49 | 'train': 18, 50 | 'tvmonitor': 19 51 | } 52 | 53 | 54 | def _findNode(parent, name, debug_name=None, parse=None): 55 | if debug_name is None: 56 | debug_name = name 57 | 58 | result = parent.find(name) 59 | if result is None: 60 | raise ValueError('missing element \'{}\''.format(debug_name)) 61 | if parse is not None: 62 | try: 63 | return parse(result.text) 64 | except ValueError as e: 65 | raise_from(ValueError('illegal value for \'{}\': {}'.format(debug_name, e)), None) 66 | return result 67 | 68 | 69 | class PascalVocGenerator(Generator): 70 | """ 71 | Generate data for a Pascal VOC dataset. 72 | 73 | See http://host.robots.ox.ac.uk/pascal/VOC/ for more information. 74 | """ 75 | 76 | def __init__( 77 | self, 78 | data_dir, 79 | set_name, 80 | classes=voc_classes, 81 | image_extension='.jpg', 82 | skip_truncated=False, 83 | skip_difficult=False, 84 | **kwargs 85 | ): 86 | """ 87 | Initialize a Pascal VOC data generator. 88 | 89 | Args: 90 | data_dir: path of the directory that contains the ImageSets directory 91 | set_name: test|trainval|train|val 92 | classes: class name to id mapping 93 | image_extension: image filename extension (default '.jpg') 94 | skip_truncated: if True, skip objects annotated as truncated 95 | skip_difficult: if True, skip objects annotated as difficult 96 | **kwargs: passed through to the Generator base class 97 | """ 98 | self.data_dir = data_dir 99 | self.set_name = set_name 100 | self.classes = classes 101 | self.image_names = [l.strip().split(None, 1)[0] for l in 102 | open(os.path.join(data_dir, 'ImageSets', 'Main', set_name + '.txt')).readlines()] 103 | self.image_extension = image_extension 104 | self.skip_truncated = skip_truncated 105 | self.skip_difficult = skip_difficult 106 | # class ids to names mapping 107 | self.labels = {} 108 | for key, value in self.classes.items(): 109 | self.labels[value] = key 110 | 111 | super(PascalVocGenerator, self).__init__(**kwargs) 112 | 113 | def size(self): 114 | """ 115 | Size of the dataset. 116 | """ 117 | return len(self.image_names) 118 | 119 | def num_classes(self): 120 | """ 121 | Number of classes in the dataset. 122 | """ 123 | return len(self.classes) 124 | 125 | def has_label(self, label): 126 | """ 127 | Return True if label is a known label. 128 | """ 129 | return label in self.labels 130 | 131 | def has_name(self, name): 132 | """ 133 | Returns True if name is a known class. 134 | """ 135 | return name in self.classes 136 | 137 | def name_to_label(self, name): 138 | """ 139 | Map name to label. 140 | """ 141 | return self.classes[name] 142 | 143 | def label_to_name(self, label): 144 | """ 145 | Map label to name. 146 | """ 147 | return self.labels[label] 148 | 149 | def image_aspect_ratio(self, image_index): 150 | """ 151 | Compute the aspect ratio for an image with image_index.
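The ratio is width / height; PIL reads it from the image header without decoding the pixel data.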
152 | """ 153 | path = os.path.join(self.data_dir, 'JPEGImages', self.image_names[image_index] + self.image_extension) 154 | image = Image.open(path) 155 | return float(image.width) / float(image.height) 156 | 157 | def load_image(self, image_index): 158 | """ 159 | Load an image at the image_index. 160 | """ 161 | path = os.path.join(self.data_dir, 'JPEGImages', self.image_names[image_index] + self.image_extension) 162 | return read_image_bgr(path) 163 | 164 | def __parse_annotation(self, element): 165 | """ 166 | Parse an annotation given an XML element. 167 | """ 168 | truncated = _findNode(element, 'truncated', parse=int) 169 | difficult = _findNode(element, 'difficult', parse=int) 170 | 171 | class_name = _findNode(element, 'name').text 172 | if class_name not in self.classes: 173 | raise ValueError('class name \'{}\' not found in classes: {}'.format(class_name, list(self.classes.keys()))) 174 | 175 | box = np.zeros((4,)) 176 | label = self.name_to_label(class_name) 177 | 178 | bndbox = _findNode(element, 'bndbox') 179 | box[0] = _findNode(bndbox, 'xmin', 'bndbox.xmin', parse=float) - 1 180 | box[1] = _findNode(bndbox, 'ymin', 'bndbox.ymin', parse=float) - 1 181 | box[2] = _findNode(bndbox, 'xmax', 'bndbox.xmax', parse=float) - 1 182 | box[3] = _findNode(bndbox, 'ymax', 'bndbox.ymax', parse=float) - 1 183 | 184 | return truncated, difficult, box, label 185 | 186 | def __parse_annotations(self, xml_root): 187 | """ 188 | Parse all annotations under the xml_root. 189 | """ 190 | annotations = {'labels': np.empty((0,), dtype=np.int32), 191 | 'bboxes': np.empty((0, 4))} 192 | for i, element in enumerate(xml_root.iter('object')): 193 | try: 194 | truncated, difficult, box, label = self.__parse_annotation(element) 195 | except ValueError as e: 196 | raise_from(ValueError('could not parse object #{}: {}'.format(i, e)), None) 197 | 198 | if truncated and self.skip_truncated: 199 | continue 200 | if difficult and self.skip_difficult: 201 | continue 202 | 203 | annotations['bboxes'] = np.concatenate([annotations['bboxes'], [box]]) 204 | annotations['labels'] = np.concatenate([annotations['labels'], [label]]) 205 | 206 | return annotations 207 | 208 | def load_annotations(self, image_index): 209 | """ 210 | Load annotations for an image_index. 211 | """ 212 | filename = self.image_names[image_index] + '.xml' 213 | try: 214 | tree = ET.parse(os.path.join(self.data_dir, 'Annotations', filename)) 215 | return self.__parse_annotations(tree.getroot()) 216 | except ET.ParseError as e: 217 | raise_from(ValueError('invalid annotations file: {}: {}'.format(filename, e)), None) 218 | except ValueError as e: 219 | raise_from(ValueError('invalid annotations file: {}: {}'.format(filename, e)), None) 220 | -------------------------------------------------------------------------------- /yolo/generators/pascal.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | import cv2 18 | import numpy as np 19 | import os 20 | from six import raise_from 21 | 22 | from yolo.generators.common import Generator 23 | 24 | try: 25 | import xml.etree.cElementTree as ET 26 | except ImportError: 27 | import xml.etree.ElementTree as ET 28 | 29 | voc_classes = { 30 | 'aeroplane': 0, 31 | 'bicycle': 1, 32 | 'bird': 2, 33 | 'boat': 3, 34 | 'bottle': 4, 35 | 'bus': 5, 36 | 'car': 6, 37 | 'cat': 7, 38 | 'chair': 8, 39 | 'cow': 9, 40 | 'diningtable': 10, 41 | 'dog': 11, 42 | 'horse': 12, 43 | 'motorbike': 13, 44 | 'person': 14, 45 | 'pottedplant': 15, 46 | 'sheep': 16, 47 | 'sofa': 17, 48 | 'train': 18, 49 | 'tvmonitor': 19 50 | } 51 | 52 | 53 | def _findNode(parent, name, debug_name=None, parse=None): 54 | if debug_name is None: 55 | debug_name = name 56 | 57 | result = parent.find(name) 58 | if result is None: 59 | raise ValueError('missing element \'{}\''.format(debug_name)) 60 | if parse is not None: 61 | try: 62 | return parse(result.text) 63 | except ValueError as e: 64 | raise_from(ValueError('illegal value for \'{}\': {}'.format(debug_name, e)), None) 65 | return result 66 | 67 | 68 | class PascalVocGenerator(Generator): 69 | """ 70 | Generate data for a Pascal VOC dataset. 71 | 72 | See http://host.robots.ox.ac.uk/pascal/VOC/ for more information. 73 | """ 74 | 75 | def __init__( 76 | self, 77 | data_dir, 78 | set_name, 79 | classes=voc_classes, 80 | image_extension='.jpg', 81 | skip_truncated=False, 82 | skip_difficult=False, 83 | **kwargs 84 | ): 85 | """ 86 | Initialize a Pascal VOC data generator. 87 | 88 | Args: 89 | data_dir: the path of directory which contains ImageSets directory 90 | set_name: test|trainval|train|val 91 | classes: class names tos id mapping 92 | image_extension: image filename ext 93 | skip_truncated: 94 | skip_difficult: 95 | **kwargs: 96 | """ 97 | self.data_dir = data_dir 98 | self.set_name = set_name 99 | self.classes = classes 100 | self.image_names = [l.strip().split(None, 1)[0] for l in 101 | open(os.path.join(data_dir, 'ImageSets', 'Main', set_name + '.txt')).readlines()] 102 | self.image_extension = image_extension 103 | self.skip_truncated = skip_truncated 104 | self.skip_difficult = skip_difficult 105 | # class ids to names mapping 106 | self.labels = {} 107 | for key, value in self.classes.items(): 108 | self.labels[value] = key 109 | 110 | super(PascalVocGenerator, self).__init__(**kwargs) 111 | 112 | def size(self): 113 | """ 114 | Size of the dataset. 115 | """ 116 | return len(self.image_names) 117 | 118 | def num_classes(self): 119 | """ 120 | Number of classes in the dataset. 121 | """ 122 | return len(self.classes) 123 | 124 | def has_label(self, label): 125 | """ 126 | Return True if label is a known label. 127 | """ 128 | return label in self.labels 129 | 130 | def has_name(self, name): 131 | """ 132 | Returns True if name is a known class. 133 | """ 134 | return name in self.classes 135 | 136 | def name_to_label(self, name): 137 | """ 138 | Map name to label. 139 | """ 140 | return self.classes[name] 141 | 142 | def label_to_name(self, label): 143 | """ 144 | Map label to name. 145 | """ 146 | return self.labels[label] 147 | 148 | def image_aspect_ratio(self, image_index): 149 | """ 150 | Compute the aspect ratio for an image with image_index. 
151 | """ 152 | path = os.path.join(self.data_dir, 'JPEGImages', self.image_names[image_index] + self.image_extension) 153 | image = cv2.imread(path) 154 | h, w = image.shape[:2] 155 | return float(w) / float(h) 156 | 157 | def load_image(self, image_index): 158 | """ 159 | Load an image at the image_index. 160 | """ 161 | path = os.path.join(self.data_dir, 'JPEGImages', self.image_names[image_index] + self.image_extension) 162 | image = cv2.imread(path) 163 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 164 | return image 165 | 166 | def __parse_annotation(self, element): 167 | """ 168 | Parse an annotation given an XML element. 169 | """ 170 | truncated = _findNode(element, 'truncated', parse=int) 171 | difficult = _findNode(element, 'difficult', parse=int) 172 | 173 | class_name = _findNode(element, 'name').text 174 | if class_name not in self.classes: 175 | raise ValueError('class name \'{}\' not found in classes: {}'.format(class_name, list(self.classes.keys()))) 176 | 177 | box = np.zeros((4,)) 178 | label = self.name_to_label(class_name) 179 | 180 | bndbox = _findNode(element, 'bndbox') 181 | box[0] = _findNode(bndbox, 'xmin', 'bndbox.xmin', parse=float) - 1 182 | box[1] = _findNode(bndbox, 'ymin', 'bndbox.ymin', parse=float) - 1 183 | box[2] = _findNode(bndbox, 'xmax', 'bndbox.xmax', parse=float) - 1 184 | box[3] = _findNode(bndbox, 'ymax', 'bndbox.ymax', parse=float) - 1 185 | 186 | return truncated, difficult, box, label 187 | 188 | def __parse_annotations(self, xml_root): 189 | """ 190 | Parse all annotations under the xml_root. 191 | """ 192 | annotations = {'labels': np.empty((0,), dtype=np.int32), 193 | 'bboxes': np.empty((0, 4))} 194 | for i, element in enumerate(xml_root.iter('object')): 195 | try: 196 | truncated, difficult, box, label = self.__parse_annotation(element) 197 | except ValueError as e: 198 | raise_from(ValueError('could not parse object #{}: {}'.format(i, e)), None) 199 | 200 | if truncated and self.skip_truncated: 201 | continue 202 | if difficult and self.skip_difficult: 203 | continue 204 | 205 | annotations['bboxes'] = np.concatenate([annotations['bboxes'], [box]]) 206 | annotations['labels'] = np.concatenate([annotations['labels'], [label]]) 207 | 208 | return annotations 209 | 210 | def load_annotations(self, image_index): 211 | """ 212 | Load annotations for an image_index. 213 | """ 214 | filename = self.image_names[image_index] + '.xml' 215 | try: 216 | tree = ET.parse(os.path.join(self.data_dir, 'Annotations', filename)) 217 | return self.__parse_annotations(tree.getroot()) 218 | except ET.ParseError as e: 219 | raise_from(ValueError('invalid annotations file: {}: {}'.format(filename, e)), None) 220 | except ValueError as e: 221 | raise_from(ValueError('invalid annotations file: {}: {}'.format(filename, e)), None) 222 | -------------------------------------------------------------------------------- /yolo/eval/coco.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import keras 18 | from pycocotools.cocoeval import COCOeval 19 | import numpy as np 20 | import json 21 | from tqdm import trange 22 | import cv2 23 | 24 | from generators.coco import CocoGenerator 25 | from model import yolo_body 26 | 27 | 28 | def evaluate(generator, model, threshold=0.01): 29 | """ 30 | Use the pycocotools to evaluate a COCO model on a dataset. 31 | 32 | Args 33 | generator: The generator for generating the evaluation data. 34 | model: The model to evaluate. 35 | threshold: The score threshold to use. 36 | """ 37 | # start collecting results 38 | results = [] 39 | image_ids = [] 40 | for index in trange(generator.size(), desc='COCO evaluation: '): 41 | image = generator.load_image(index) 42 | src_image = image.copy() 43 | image_shape = image.shape[:2] 44 | image_shape = np.array(image_shape) 45 | image = generator.preprocess_image(image) 46 | 47 | # run network 48 | detections = model.predict_on_batch([np.expand_dims(image, axis=0), np.expand_dims(image_shape, axis=0)])[0] 49 | 50 | # change to (x, y, w, h) (MS COCO standard) 51 | boxes = np.zeros((detections.shape[0], 4), dtype=np.int32) 52 | # xmin 53 | boxes[:, 0] = np.maximum(np.round(detections[:, 1]).astype(np.int32), 0) 54 | # ymin 55 | boxes[:, 1] = np.maximum(np.round(detections[:, 0]).astype(np.int32), 0) 56 | # w 57 | boxes[:, 2] = np.minimum(np.round(detections[:, 3] - detections[:, 1]).astype(np.int32), image_shape[1]) 58 | # h 59 | boxes[:, 3] = np.minimum(np.round(detections[:, 2] - detections[:, 0]).astype(np.int32), image_shape[0]) 60 | scores = detections[:, 4] 61 | class_ids = detections[:, 5].astype(np.int32) 62 | # compute predicted labels and scores 63 | for box, score, class_id in zip(boxes, scores, class_ids): 64 | # scores are sorted, so we can break 65 | if score < threshold: 66 | break 67 | 68 | # append detection for each positively labeled class 69 | image_result = { 70 | 'image_id': generator.image_ids[index], 71 | 'category_id': generator.label_to_coco_label(class_id), 72 | 'score': float(score), 73 | 'bbox': box.tolist(), 74 | } 75 | # append detection to results 76 | results.append(image_result) 77 | class_name = generator.label_to_name(class_id) 78 | ret, baseline = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1) 79 | cv2.rectangle(src_image, (box[0], box[1]), (box[0] + box[2], box[1] + box[3]), (0, 255, 0), 1) 80 | cv2.putText(src_image, class_name, (box[0], box[1] + box[3] - baseline), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1) 81 | cv2.namedWindow('image', cv2.WINDOW_NORMAL) 82 | cv2.imshow('image', src_image) 83 | cv2.waitKey(0) 84 | # append image to list of processed images 85 | image_ids.append(generator.image_ids[index]) 86 | 87 | if not len(results): 88 | return 89 | 90 | # write output 91 | json.dump(results, open('{}_bbox_results.json'.format(generator.set_name), 'w'), indent=4) 92 | json.dump(image_ids, open('{}_processed_image_ids.json'.format(generator.set_name), 'w'), indent=4) 93 | 94 | # load results in COCO evaluation tool 95 | coco_true = generator.coco 96 | coco_pred = coco_true.loadRes('{}_bbox_results.json'.format(generator.set_name)) 97 | 98 | # run COCO evaluation 99 | coco_eval = COCOeval(coco_true, coco_pred, 'bbox') 100 | coco_eval.params.imgIds = image_ids 101 | coco_eval.evaluate() 102 | coco_eval.accumulate() 103 | coco_eval.summarize() 104 | return coco_eval.stats 105 | 106 | 107 | class 
Evaluate(keras.callbacks.Callback):
108 |     """ Performs COCO evaluation on each epoch.
109 |     """
110 |
111 |     def __init__(self, generator, model, tensorboard=None, threshold=0.01):
112 |         """ Evaluate callback initializer.
113 |
114 |         Args
115 |             generator   : The generator used for creating validation data.
116 |             model       : The prediction model to evaluate with.
117 |             tensorboard : If given, the results will be written to tensorboard.
118 |             threshold   : The score threshold to use.
119 |         """
120 |         self.generator = generator
121 |         self.active_model = model
122 |         self.threshold = threshold
123 |         self.tensorboard = tensorboard
124 |
125 |         super(Evaluate, self).__init__()
126 |
127 |     def on_epoch_end(self, epoch, logs=None):
128 |         logs = logs or {}
129 |
130 |         coco_tag = ['AP @[ IoU=0.50:0.95 | area= all | maxDets=100 ]',
131 |                     'AP @[ IoU=0.50 | area= all | maxDets=100 ]',
132 |                     'AP @[ IoU=0.75 | area= all | maxDets=100 ]',
133 |                     'AP @[ IoU=0.50:0.95 | area= small | maxDets=100 ]',
134 |                     'AP @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]',
135 |                     'AP @[ IoU=0.50:0.95 | area= large | maxDets=100 ]',
136 |                     'AR @[ IoU=0.50:0.95 | area= all | maxDets= 1 ]',
137 |                     'AR @[ IoU=0.50:0.95 | area= all | maxDets= 10 ]',
138 |                     'AR @[ IoU=0.50:0.95 | area= all | maxDets=100 ]',
139 |                     'AR @[ IoU=0.50:0.95 | area= small | maxDets=100 ]',
140 |                     'AR @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]',
141 |                     'AR @[ IoU=0.50:0.95 | area= large | maxDets=100 ]']
142 |         coco_eval_stats = evaluate(self.generator, self.active_model, self.threshold)
143 |         if coco_eval_stats is not None and self.tensorboard is not None and self.tensorboard.writer is not None:
144 |             import tensorflow as tf
145 |             summary = tf.Summary()
146 |             for index, result in enumerate(coco_eval_stats):
147 |                 summary_value = summary.value.add()
148 |                 summary_value.simple_value = result
149 |                 summary_value.tag = '{}. {}'.format(index + 1, coco_tag[index])
150 |                 self.tensorboard.writer.add_summary(summary, epoch)
151 |                 logs[coco_tag[index]] = result
152 |
153 |
154 | if __name__ == '__main__':
155 |     dataset_dir = '/home/adam/.keras/datasets/coco/2017_118_5'
156 |     test_generator = CocoGenerator(
157 |         anchors_path='yolo_anchors.txt',
158 |         data_dir=dataset_dir,
159 |         set_name='test-dev2017',
160 |         shuffle_groups=False,
161 |     )
162 |     input_shape = (416, 416)
163 |     model, prediction_model = yolo_body(test_generator.anchors, num_classes=80)
164 |     model.load_weights('checkpoints/yolov3_weights.h5', by_name=True)
165 |     coco_eval_stats = evaluate(test_generator, prediction_model)
166 |     coco_tag = ['AP @[ IoU=0.50:0.95 | area= all | maxDets=100 ]',
167 |                 'AP @[ IoU=0.50 | area= all | maxDets=100 ]',
168 |                 'AP @[ IoU=0.75 | area= all | maxDets=100 ]',
169 |                 'AP @[ IoU=0.50:0.95 | area= small | maxDets=100 ]',
170 |                 'AP @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]',
171 |                 'AP @[ IoU=0.50:0.95 | area= large | maxDets=100 ]',
172 |                 'AR @[ IoU=0.50:0.95 | area= all | maxDets= 1 ]',
173 |                 'AR @[ IoU=0.50:0.95 | area= all | maxDets= 10 ]',
174 |                 'AR @[ IoU=0.50:0.95 | area= all | maxDets=100 ]',
175 |                 'AR @[ IoU=0.50:0.95 | area= small | maxDets=100 ]',
176 |                 'AR @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]',
177 |                 'AR @[ IoU=0.50:0.95 | area= large | maxDets=100 ]']
178 |     if coco_eval_stats is not None:
179 |         for index, result in enumerate(coco_eval_stats):
180 |             print(coco_tag[index], result)
181 |
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 Fizyr (https://fizyr.com)
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 |     http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import keras
18 | import keras.backend as K
19 | import tensorflow as tf
20 |
21 |
22 | def focal(alpha=0.25, gamma=2.0):
23 |     """
24 |     Create a functor for computing the focal loss.
25 |
26 |     Args
27 |         alpha: Scale the focal weight with alpha.
28 |         gamma: Take the power of the focal weight with gamma.
29 |
30 |     Returns
31 |         A functor that computes the focal loss using the alpha and gamma.
32 |     """
33 |
34 |     def _focal(y_true, y_pred):
35 |         """
36 |         Compute the focal loss given the target tensor and the predicted tensor.
37 |
38 |         As defined in https://arxiv.org/abs/1708.02002
39 |
40 |         Args
41 |             y_true: Tensor of target data from the generator with shape (B, N, num_classes).
42 |             y_pred: Tensor of predicted data from the network with shape (B, N, num_classes).
43 |
44 |         Returns
45 |             The focal loss of y_pred w.r.t. y_true.
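            For reference, this evaluates FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t),
            with p_t = y_pred where y_true == 1 and p_t = 1 - y_pred otherwise.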
46 | """ 47 | # compute the focal loss 48 | alpha_factor = K.ones_like(y_true) * alpha 49 | alpha_factor = tf.where(K.equal(y_true, 1), alpha_factor, 1 - alpha_factor) 50 | focal_weight = tf.where(K.equal(y_true, 1), 1 - y_pred, y_pred) 51 | focal_weight = alpha_factor * focal_weight ** gamma 52 | cls_loss = focal_weight * K.binary_crossentropy(y_true, y_pred) 53 | 54 | # compute the normalizer: the number of positive anchors 55 | normalizer = K.cast(K.shape(y_pred)[1], K.floatx()) 56 | normalizer = K.maximum(K.cast_to_floatx(1.0), normalizer) 57 | 58 | return K.sum(cls_loss) / normalizer 59 | 60 | return _focal 61 | 62 | 63 | def smooth_l1(sigma=3.0): 64 | """ 65 | Create a smooth L1 loss functor. 66 | 67 | Args 68 | sigma: This argument defines the point where the loss changes from L2 to L1. 69 | 70 | Returns 71 | A functor for computing the smooth L1 loss given target data and predicted data. 72 | """ 73 | sigma_squared = sigma ** 2 74 | 75 | def _smooth_l1(y_true, y_pred): 76 | """ Compute the smooth L1 loss of y_pred w.r.t. y_true. 77 | 78 | Args 79 | y_true: Tensor from the generator of shape (B, N, 5). The last value for each box is the state of the anchor (ignore, negative, positive). 80 | y_pred: Tensor from the network of shape (B, N, 4). 81 | 82 | Returns 83 | The smooth L1 loss of y_pred w.r.t. y_true. 84 | """ 85 | # separate target and state 86 | regression = y_pred 87 | regression_target = y_true[:, :, :-1] 88 | anchor_state = y_true[:, :, -1] 89 | 90 | # filter out "ignore" anchors 91 | indices = tf.where(K.equal(anchor_state, 1)) 92 | regression = tf.gather_nd(regression, indices) 93 | regression_target = tf.gather_nd(regression_target, indices) 94 | 95 | # compute smooth L1 loss 96 | # f(x) = 0.5 * (sigma * x)^2 if |x| < 1 / sigma / sigma 97 | # |x| - 0.5 / sigma / sigma otherwise 98 | regression_diff = regression - regression_target 99 | regression_diff = K.abs(regression_diff) 100 | regression_loss = tf.where( 101 | K.less(regression_diff, 1.0 / sigma_squared), 102 | 0.5 * sigma_squared * K.pow(regression_diff, 2), 103 | regression_diff - 0.5 / sigma_squared 104 | ) 105 | 106 | # compute the normalizer: the number of positive anchors 107 | normalizer = K.maximum(1, K.shape(indices)[0]) 108 | normalizer = K.cast(normalizer, dtype=K.floatx()) 109 | return K.sum(regression_loss) / normalizer 110 | 111 | return _smooth_l1 112 | 113 | 114 | def iou(): 115 | def _iou(y_true, y_pred): 116 | y_true = tf.maximum(y_true, 0) 117 | pred_left = y_pred[:, :, 0] 118 | pred_top = y_pred[:, :, 1] 119 | pred_right = y_pred[:, :, 2] 120 | pred_bottom = y_pred[:, :, 3] 121 | 122 | # (num_pos, ) 123 | target_left = y_true[:, :, 0] 124 | target_top = y_true[:, :, 1] 125 | target_right = y_true[:, :, 2] 126 | target_bottom = y_true[:, :, 3] 127 | 128 | target_area = (target_left + target_right) * (target_top + target_bottom) 129 | pred_area = (pred_left + pred_right) * (pred_top + pred_bottom) 130 | w_intersect = tf.minimum(pred_left, target_left) + tf.minimum(pred_right, target_right) 131 | h_intersect = tf.minimum(pred_bottom, target_bottom) + tf.minimum(pred_top, target_top) 132 | 133 | area_intersect = w_intersect * h_intersect 134 | area_union = target_area + pred_area - area_intersect 135 | 136 | # (num_pos, ) 137 | iou_loss = -tf.log((area_intersect + 1e-7) / (area_union + 1e-7)) 138 | # compute the normalizer: the number of positive anchors 139 | normalizer = K.maximum(1, K.shape(y_true)[1]) 140 | normalizer = K.cast(normalizer, dtype=K.floatx()) 141 | return K.sum(iou_loss) / 
normalizer 142 | 143 | return _iou 144 | 145 | 146 | def focal_with_mask(alpha=0.25, gamma=2.0): 147 | """ 148 | Create a functor for computing the focal loss. 149 | 150 | Args 151 | alpha: Scale the focal weight with alpha. 152 | gamma: Take the power of the focal weight with gamma. 153 | 154 | Returns 155 | A functor that computes the focal loss using the alpha and gamma. 156 | """ 157 | 158 | def _focal(inputs): 159 | """ 160 | Compute the focal loss given the target tensor and the predicted tensor. 161 | 162 | As defined in https://arxiv.org/abs/1708.02002 163 | 164 | Args 165 | y_true: Tensor of target data from the generator with shape (B, N, num_classes). 166 | y_pred: Tensor of predicted data from the network with shape (B, N, num_classes). 167 | cls_mask: (B, N) 168 | cls_num_pos: (B, ) 169 | 170 | Returns 171 | The focal loss of y_pred w.r.t. y_true. 172 | """ 173 | # compute the focal loss 174 | y_true, y_pred, cls_mask, cls_num_pos = inputs[0], inputs[1], inputs[2], inputs[3] 175 | alpha_factor = K.ones_like(y_true) * alpha 176 | alpha_factor = tf.where(K.equal(y_true, 1), alpha_factor, 1 - alpha_factor) 177 | focal_weight = tf.where(K.equal(y_true, 1), 1 - y_pred, y_pred) 178 | focal_weight = alpha_factor * focal_weight ** gamma 179 | # (B, N) --> (B, N, 1) 180 | cls_mask = tf.cast(cls_mask, tf.float32) 181 | cls_mask = tf.expand_dims(cls_mask, axis=-1) 182 | # (B, N, num_classes) * (B, N, 1) 183 | masked_cls_loss = focal_weight * K.binary_crossentropy(y_true, y_pred) * cls_mask 184 | # compute the normalizer: the number of positive locations 185 | normalizer = K.maximum(K.cast_to_floatx(1.0), tf.reduce_sum(cls_num_pos)) 186 | return K.sum(masked_cls_loss) / normalizer 187 | 188 | return _focal 189 | 190 | 191 | def iou_with_mask(): 192 | def _iou(inputs): 193 | """ 194 | 195 | Args: 196 | inputs: y_true: (B, N, 4) y_pred: (B, N, 4) regr_mask: (B, N) 197 | 198 | Returns: 199 | 200 | """ 201 | y_true, y_pred, regr_mask = inputs[0], inputs[1], inputs[2] 202 | y_true = tf.maximum(y_true, 0) 203 | pred_left = y_pred[:, :, 0] 204 | pred_top = y_pred[:, :, 1] 205 | pred_right = y_pred[:, :, 2] 206 | pred_bottom = y_pred[:, :, 3] 207 | 208 | # (B, N) 209 | target_left = y_true[:, :, 0] 210 | target_top = y_true[:, :, 1] 211 | target_right = y_true[:, :, 2] 212 | target_bottom = y_true[:, :, 3] 213 | 214 | # (B, N) 215 | target_area = (target_left + target_right) * (target_top + target_bottom) 216 | masked_target_area = tf.boolean_mask(target_area, regr_mask) 217 | pred_area = (pred_left + pred_right) * (pred_top + pred_bottom) 218 | masked_pred_area = tf.boolean_mask(pred_area, regr_mask) 219 | w_intersect = tf.minimum(pred_left, target_left) + tf.minimum(pred_right, target_right) 220 | h_intersect = tf.minimum(pred_bottom, target_bottom) + tf.minimum(pred_top, target_top) 221 | 222 | area_intersect = w_intersect * h_intersect 223 | masked_area_intersect = tf.boolean_mask(area_intersect, regr_mask) 224 | masked_area_union = masked_target_area + masked_pred_area - masked_area_intersect 225 | 226 | # (B, N) 227 | masked_iou_loss = -tf.log((masked_area_intersect + 1e-7) / (masked_area_union + 1e-7)) 228 | regr_mask = tf.cast(regr_mask, tf.float32) 229 | # compute the normalizer: the number of positive locations 230 | regr_num_pos = tf.reduce_sum(regr_mask) 231 | normalizer = K.maximum(K.cast_to_floatx(1.), regr_num_pos) 232 | return K.sum(masked_iou_loss) / normalizer 233 | 234 | return _iou 235 | -------------------------------------------------------------------------------- 
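For orientation: focal(), smooth_l1() and iou() above return closures with the (y_true, y_pred) signature Keras expects, so they can be passed straight to compile, while the *_with_mask variants take a single list of tensors and look intended for use inside the model graph (e.g. via a Lambda layer). A minimal usage sketch of the former, where build_model and the output names are hypothetical placeholders rather than helpers defined in this repository:

import keras
import losses

# Hypothetical detector with named 'classification' and 'regression' outputs;
# build_model stands in for one of the builders under models/.
model = build_model(num_classes=20)
model.compile(
    loss={
        'classification': losses.focal(alpha=0.25, gamma=2.0),
        'regression': losses.iou(),
    },
    optimizer=keras.optimizers.Adam(lr=1e-5)
)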
/generators/csv_generator.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 yhenon (https://github.com/yhenon/)
3 | Copyright 2017-2018 Fizyr (https://fizyr.com)
4 |
5 | Licensed under the Apache License, Version 2.0 (the "License");
6 | you may not use this file except in compliance with the License.
7 | You may obtain a copy of the License at
8 |
9 |     http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | Unless required by applicable law or agreed to in writing, software
12 | distributed under the License is distributed on an "AS IS" BASIS,
13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | See the License for the specific language governing permissions and
15 | limitations under the License.
16 | """
17 |
18 | from .generator import Generator
19 | from utils.image import read_image_bgr
20 |
21 | import numpy as np
22 | from PIL import Image
23 | from six import raise_from
24 | import csv
25 | import sys
26 | import os.path
27 | from collections import OrderedDict
28 |
29 |
30 | def _parse(value, function, fmt):
31 |     """
32 |     Parse a string into a value, and format a nice ValueError if it fails.
33 |
34 |     Returns `function(value)`.
35 |     Any `ValueError` raised is caught and a new `ValueError` is raised
36 |     with message `fmt.format(e)`, where `e` is the caught `ValueError`.
37 |     """
38 |     try:
39 |         return function(value)
40 |     except ValueError as e:
41 |         raise_from(ValueError(fmt.format(e)), None)
42 |
43 |
44 | def _read_classes(csv_reader):
45 |     """
46 |     Parse the classes file given by csv_reader.
47 |     """
48 |     result = OrderedDict()
49 |     for line, row in enumerate(csv_reader):
50 |         line += 1
51 |
52 |         try:
53 |             class_name, class_id = row
54 |         except ValueError:
55 |             raise_from(ValueError('line {}: format should be \'class_name,class_id\''.format(line)), None)
56 |         class_id = _parse(class_id, int, 'line {}: malformed class ID: {{}}'.format(line))
57 |
58 |         if class_name in result:
59 |             raise ValueError('line {}: duplicate class name: \'{}\''.format(line, class_name))
60 |         result[class_name] = class_id
61 |     return result
62 |
63 |
64 | def _read_annotations(csv_reader, classes):
65 |     """
66 |     Read annotations from the csv_reader.
67 |     """
68 |     result = OrderedDict()
69 |     for line, row in enumerate(csv_reader):
70 |         line += 1
71 |
72 |         try:
73 |             img_file, x1, y1, x2, y2, class_name = row[:6]
74 |         except ValueError:
75 |             raise_from(ValueError(
76 |                 'line {}: format should be \'img_file,x1,y1,x2,y2,class_name\' or \'img_file,,,,,\''.format(line)),
77 |                 None)
78 |
79 |         if img_file not in result:
80 |             result[img_file] = []
81 |
82 |         # If a row contains only an image path, it's an image without annotations.
83 |         if (x1, y1, x2, y2, class_name) == ('', '', '', '', ''):
84 |             continue
85 |
86 |         x1 = _parse(x1, int, 'line {}: malformed x1: {{}}'.format(line))
87 |         y1 = _parse(y1, int, 'line {}: malformed y1: {{}}'.format(line))
88 |         x2 = _parse(x2, int, 'line {}: malformed x2: {{}}'.format(line))
89 |         y2 = _parse(y2, int, 'line {}: malformed y2: {{}}'.format(line))
90 |
91 |         # Check that the bounding box is valid.
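        # Boxes use the (x1, y1, x2, y2) corner convention, so width and height must be strictly positive.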
92 | if x2 <= x1: 93 | raise ValueError('line {}: x2 ({}) must be higher than x1 ({})'.format(line, x2, x1)) 94 | if y2 <= y1: 95 | raise ValueError('line {}: y2 ({}) must be higher than y1 ({})'.format(line, y2, y1)) 96 | 97 | # check if the current class name is correctly present 98 | if class_name not in classes: 99 | raise ValueError('line {}: unknown class name: \'{}\' (classes: {})'.format(line, class_name, classes)) 100 | 101 | result[img_file].append({'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2, 'class': class_name}) 102 | return result 103 | 104 | 105 | def _open_for_csv(path): 106 | """ 107 | Open a file with flags suitable for csv.reader. 108 | 109 | This is different for python2 it means with mode 'rb', for python3 this means 'r' with "universal newlines". 110 | """ 111 | if sys.version_info[0] < 3: 112 | return open(path, 'rb') 113 | else: 114 | return open(path, 'r', newline='') 115 | 116 | 117 | class CSVGenerator(Generator): 118 | """ 119 | Generate data for a custom CSV dataset. 120 | 121 | See https://github.com/fizyr/keras-retinanet#csv-datasets for more information. 122 | """ 123 | 124 | def __init__( 125 | self, 126 | csv_data_file, 127 | csv_class_file, 128 | base_dir=None, 129 | **kwargs 130 | ): 131 | """ 132 | Initialize a CSV data generator. 133 | 134 | Args 135 | csv_data_file: Path to the CSV annotations file. 136 | csv_class_file: Path to the CSV classes file. 137 | base_dir: Directory w.r.t. where the files are to be searched (defaults to the directory containing the csv_data_file). 138 | """ 139 | self.image_names = [] 140 | self.image_data = {} 141 | self.base_dir = base_dir 142 | 143 | # Take base_dir from annotations file if not explicitly specified. 144 | if self.base_dir is None: 145 | self.base_dir = os.path.dirname(csv_data_file) 146 | 147 | # parse the provided class file 148 | try: 149 | with _open_for_csv(csv_class_file) as file: 150 | # class_name --> class_id 151 | self.classes = _read_classes(csv.reader(file, delimiter=',')) 152 | except ValueError as e: 153 | raise_from(ValueError('invalid CSV class file: {}: {}'.format(csv_class_file, e)), None) 154 | 155 | self.labels = {} 156 | # class_id --> class_name 157 | for key, value in self.classes.items(): 158 | self.labels[value] = key 159 | 160 | # csv with img_path, x1, y1, x2, y2, class_name 161 | try: 162 | with _open_for_csv(csv_data_file) as file: 163 | # {'img_path1':[{'x1':xx,'y1':xx,'x2':xx,'y2':xx,'class':xx}...],...} 164 | self.image_data = _read_annotations(csv.reader(file, delimiter=','), self.classes) 165 | except ValueError as e: 166 | raise_from(ValueError('invalid CSV annotations file: {}: {}'.format(csv_data_file, e)), None) 167 | self.image_names = list(self.image_data.keys()) 168 | 169 | super(CSVGenerator, self).__init__(**kwargs) 170 | 171 | def size(self): 172 | """ 173 | Size of the dataset. 174 | """ 175 | return len(self.image_names) 176 | 177 | def num_classes(self): 178 | """ 179 | Number of classes in the dataset. 180 | """ 181 | return max(self.classes.values()) + 1 182 | 183 | def has_label(self, label): 184 | """ 185 | Return True if label is a known label. 186 | """ 187 | return label in self.labels 188 | 189 | def has_name(self, name): 190 | """ 191 | Returns True if name is a known class. 192 | """ 193 | return name in self.classes 194 | 195 | def name_to_label(self, name): 196 | """ 197 | Map name to label. 198 | """ 199 | return self.classes[name] 200 | 201 | def label_to_name(self, label): 202 | """ 203 | Map label to name. 
204 | """ 205 | return self.labels[label] 206 | 207 | def image_path(self, image_index): 208 | """ 209 | Returns the image path for image_index. 210 | """ 211 | return os.path.join(self.base_dir, self.image_names[image_index]) 212 | 213 | def image_aspect_ratio(self, image_index): 214 | """ 215 | Compute the aspect ratio for an image with image_index. 216 | """ 217 | # PIL is fast for metadata 218 | image = Image.open(self.image_path(image_index)) 219 | return float(image.width) / float(image.height) 220 | 221 | def load_image(self, image_index): 222 | """ 223 | Load an image at the image_index. 224 | """ 225 | return read_image_bgr(self.image_path(image_index)) 226 | 227 | def load_annotations(self, image_index): 228 | """ 229 | Load annotations for an image_index. 230 | """ 231 | path = self.image_names[image_index] 232 | annotations = {'labels': np.empty((0,), dtype=np.int32), 'bboxes': np.empty((0, 4))} 233 | 234 | for idx, annot in enumerate(self.image_data[path]): 235 | annotations['labels'] = np.concatenate((annotations['labels'], [self.name_to_label(annot['class'])])) 236 | annotations['bboxes'] = np.concatenate((annotations['bboxes'], [[ 237 | float(annot['x1']), 238 | float(annot['y1']), 239 | float(annot['x2']), 240 | float(annot['y2']), 241 | ]])) 242 | 243 | return annotations 244 | 245 | 246 | if __name__ == '__main__': 247 | import cv2 248 | csv_generator = CSVGenerator( 249 | csv_data_file='/home/adam/workspace/github/keras-retinanet_vat/val_gray_annotations_20190615_1255_127.csv', 250 | csv_class_file='/home/adam/workspace/github/keras-retinanet_vat/vat_classes.csv' 251 | ) 252 | for image_group, annotation_group, targets in csv_generator: 253 | locations = targets[0] 254 | batch_regr_targets = targets[1] 255 | batch_cls_targets = targets[2] 256 | batch_centerness_targets = targets[3] 257 | for image, annotation, regr_targets, cls_targets, centerness_targets in zip(image_group, annotation_group, batch_regr_targets, batch_cls_targets, batch_centerness_targets): 258 | gt_boxes = annotation['bboxes'] 259 | for gt_box in gt_boxes: 260 | gt_xmin, gt_ymin, gt_xmax, gt_ymax = gt_box 261 | cv2.rectangle(image, (int(gt_xmin), int(gt_ymin)), (int(gt_xmax), int(gt_ymax)), (0, 255, 0), 2) 262 | pos_indices = np.where(centerness_targets[:, 1] == 1)[0] 263 | for pos_index in pos_indices: 264 | cx, cy = locations[pos_index] 265 | l, t, r, b, _ = regr_targets[pos_index] 266 | xmin = cx - l 267 | ymin = cy - t 268 | xmax = cx + r 269 | ymax = cy + b 270 | class_id = np.argmax(cls_targets[pos_index]) 271 | centerness = centerness_targets[pos_index][0] 272 | cv2.putText(image, '{:.2f}'.format(centerness), (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 2.0, (255, 0, 255), 2) 273 | cv2.putText(image, str(class_id), (xmin, ymin), cv2.FONT_HERSHEY_SIMPLEX, 2.0, (0, 0, 0), 3) 274 | cv2.circle(image, (round(cx), round(cy)), 5, (255, 0, 0), -1) 275 | cv2.rectangle(image, (round(xmin), round(ymin)), (round(xmax), round(ymax)), (0, 0, 255), 2) 276 | cv2.namedWindow('image', cv2.WINDOW_NORMAL) 277 | cv2.imshow('image', image) 278 | cv2.waitKey(0) 279 | pass 280 | -------------------------------------------------------------------------------- /yolo/generators/csv_.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 yhenon (https://github.com/yhenon/) 3 | Copyright 2017-2018 Fizyr (https://fizyr.com) 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 
7 | You may obtain a copy of the License at
8 |
9 |     http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | Unless required by applicable law or agreed to in writing, software
12 | distributed under the License is distributed on an "AS IS" BASIS,
13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | See the License for the specific language governing permissions and
15 | limitations under the License.
16 | """
17 | from collections import OrderedDict
18 | import csv
19 | import numpy as np
20 | import os.path
21 | from PIL import Image
22 | from six import raise_from
23 | import sys
24 |
25 | from utils.image import read_image_bgr
26 | from yolo.generators.common import Generator
27 |
28 |
29 | def _parse(value, function, fmt):
30 |     """
31 |     Parse a string into a value, and format a nice ValueError if it fails.
32 |
33 |     Returns `function(value)`.
34 |     Any `ValueError` raised is caught and a new `ValueError` is raised
35 |     with message `fmt.format(e)`, where `e` is the caught `ValueError`.
36 |     """
37 |     try:
38 |         return function(value)
39 |     except ValueError as e:
40 |         raise_from(ValueError(fmt.format(e)), None)
41 |
42 |
43 | def _read_classes(csv_reader):
44 |     """
45 |     Parse the classes file given by csv_reader.
46 |     """
47 |     result = OrderedDict()
48 |     for line, row in enumerate(csv_reader):
49 |         line += 1
50 |
51 |         try:
52 |             class_name, class_id = row
53 |         except ValueError:
54 |             raise_from(ValueError('line {}: format should be \'class_name,class_id\''.format(line)), None)
55 |         class_id = _parse(class_id, int, 'line {}: malformed class ID: {{}}'.format(line))
56 |
57 |         if class_name in result:
58 |             raise ValueError('line {}: duplicate class name: \'{}\''.format(line, class_name))
59 |         result[class_name] = class_id
60 |     return result
61 |
62 |
63 | def _read_annotations(csv_reader, classes):
64 |     """
65 |     Read annotations from the csv_reader.
66 |     """
67 |     result = OrderedDict()
68 |     for line, row in enumerate(csv_reader):
69 |         line += 1
70 |
71 |         try:
72 |             img_file, x1, y1, x2, y2, class_name = row[:6]
73 |         except ValueError:
74 |             raise_from(ValueError(
75 |                 'line {}: format should be \'img_file,x1,y1,x2,y2,class_name\' or \'img_file,,,,,\''.format(line)),
76 |                 None)
77 |
78 |         if img_file not in result:
79 |             result[img_file] = []
80 |
81 |         # If a row contains only an image path, it's an image without annotations.
82 |         if (x1, y1, x2, y2, class_name) == ('', '', '', '', ''):
83 |             continue
84 |
85 |         x1 = _parse(x1, int, 'line {}: malformed x1: {{}}'.format(line))
86 |         y1 = _parse(y1, int, 'line {}: malformed y1: {{}}'.format(line))
87 |         x2 = _parse(x2, int, 'line {}: malformed x2: {{}}'.format(line))
88 |         y2 = _parse(y2, int, 'line {}: malformed y2: {{}}'.format(line))
89 |
90 |         # Check that the bounding box is valid.
91 |         if x2 <= x1:
92 |             raise ValueError('line {}: x2 ({}) must be higher than x1 ({})'.format(line, x2, x1))
93 |         if y2 <= y1:
94 |             raise ValueError('line {}: y2 ({}) must be higher than y1 ({})'.format(line, y2, y1))
95 |
96 |         # check if the current class name is correctly present
97 |         if class_name not in classes:
98 |             raise ValueError('line {}: unknown class name: \'{}\' (classes: {})'.format(line, class_name, classes))
99 |
100 |         result[img_file].append({'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2, 'class': class_name})
101 |     return result
102 |
103 |
104 | def _open_for_csv(path):
105 |     """
106 |     Open a file with flags suitable for csv.reader.
107 |
108 |     For python2 this means opening with mode 'rb', for python3 this means 'r' with "universal newlines".
109 |     """
110 |     if sys.version_info[0] < 3:
111 |         return open(path, 'rb')
112 |     else:
113 |         return open(path, 'r', newline='')
114 |
115 |
116 | class CSVGenerator(Generator):
117 |     """
118 |     Generate data for a custom CSV dataset.
119 |
120 |     See https://github.com/fizyr/keras-retinanet#csv-datasets for more information.
121 |     """
122 |
123 |     def __init__(
124 |             self,
125 |             csv_data_file,
126 |             csv_class_file,
127 |             base_dir=None,
128 |             **kwargs
129 |     ):
130 |         """
131 |         Initialize a CSV data generator.
132 |
133 |         Args
134 |             csv_data_file: Path to the CSV annotations file.
135 |             csv_class_file: Path to the CSV classes file.
136 |             base_dir: Directory w.r.t. where the files are to be searched (defaults to the directory containing the csv_data_file).
137 |         """
138 |         self.image_names = []
139 |         self.image_data = {}
140 |         self.base_dir = base_dir
141 |
142 |         # Take base_dir from annotations file if not explicitly specified.
143 |         if self.base_dir is None:
144 |             self.base_dir = os.path.dirname(csv_data_file)
145 |
146 |         # parse the provided class file
147 |         try:
148 |             with _open_for_csv(csv_class_file) as file:
149 |                 # class_name --> class_id
150 |                 self.classes = _read_classes(csv.reader(file, delimiter=','))
151 |         except ValueError as e:
152 |             raise_from(ValueError('invalid CSV class file: {}: {}'.format(csv_class_file, e)), None)
153 |
154 |         self.labels = {}
155 |         # class_id --> class_name
156 |         for key, value in self.classes.items():
157 |             self.labels[value] = key
158 |
159 |         # csv with img_path, x1, y1, x2, y2, class_name
160 |         try:
161 |             with _open_for_csv(csv_data_file) as file:
162 |                 # {'img_path1':[{'x1':xx,'y1':xx,'x2':xx,'y2':xx,'class':xx}...],...}
163 |                 self.image_data = _read_annotations(csv.reader(file, delimiter=','), self.classes)
164 |         except ValueError as e:
165 |             raise_from(ValueError('invalid CSV annotations file: {}: {}'.format(csv_data_file, e)), None)
166 |         self.image_names = list(self.image_data.keys())
167 |
168 |         super(CSVGenerator, self).__init__(**kwargs)
169 |
170 |     def size(self):
171 |         """
172 |         Size of the dataset.
173 |         """
174 |         return len(self.image_names)
175 |
176 |     def num_classes(self):
177 |         """
178 |         Number of classes in the dataset.
179 |         """
180 |         return max(self.classes.values()) + 1
181 |
182 |     def has_label(self, label):
183 |         """
184 |         Return True if label is a known label.
185 |         """
186 |         return label in self.labels
187 |
188 |     def has_name(self, name):
189 |         """
190 |         Returns True if name is a known class.
191 |         """
192 |         return name in self.classes
193 |
194 |     def name_to_label(self, name):
195 |         """
196 |         Map name to label.
197 |         """
198 |         return self.classes[name]
199 |
200 |     def label_to_name(self, label):
201 |         """
202 |         Map label to name.
203 |         """
204 |         return self.labels[label]
205 |
206 |     def image_path(self, image_index):
207 |         """
208 |         Returns the image path for image_index.
209 |         """
210 |         return os.path.join(self.base_dir, self.image_names[image_index])
211 |
212 |     def image_aspect_ratio(self, image_index):
213 |         """
214 |         Compute the aspect ratio for an image with image_index.
215 |         """
216 |         # PIL is fast for metadata
217 |         image = Image.open(self.image_path(image_index))
218 |         return float(image.width) / float(image.height)
219 |
220 |     def load_image(self, image_index):
221 |         """
222 |         Load an image at the image_index.
223 | """ 224 | return read_image_bgr(self.image_path(image_index)) 225 | 226 | def load_annotations(self, image_index): 227 | """ 228 | Load annotations for an image_index. 229 | """ 230 | path = self.image_names[image_index] 231 | annotations = {'labels': np.empty((0,), dtype=np.int32), 'bboxes': np.empty((0, 4))} 232 | 233 | for idx, annot in enumerate(self.image_data[path]): 234 | annotations['labels'] = np.concatenate((annotations['labels'], [self.name_to_label(annot['class'])])) 235 | annotations['bboxes'] = np.concatenate((annotations['bboxes'], [[ 236 | float(annot['x1']), 237 | float(annot['y1']), 238 | float(annot['x2']), 239 | float(annot['y2']), 240 | ]])) 241 | 242 | return annotations 243 | 244 | 245 | if __name__ == '__main__': 246 | import cv2 247 | 248 | csv_generator = CSVGenerator( 249 | csv_data_file='/home/adam/workspace/github/keras-retinanet_vat/val_gray_annotations_20190615_1255_127.csv', 250 | csv_class_file='/home/adam/workspace/github/keras-retinanet_vat/vat_classes.csv' 251 | ) 252 | for image_group, annotation_group, targets in csv_generator: 253 | locations = targets[0] 254 | batch_regr_targets = targets[1] 255 | batch_cls_targets = targets[2] 256 | batch_centerness_targets = targets[3] 257 | for image, annotation, regr_targets, cls_targets, centerness_targets in zip(image_group, annotation_group, 258 | batch_regr_targets, 259 | batch_cls_targets, 260 | batch_centerness_targets): 261 | gt_boxes = annotation['bboxes'] 262 | for gt_box in gt_boxes: 263 | gt_xmin, gt_ymin, gt_xmax, gt_ymax = gt_box 264 | cv2.rectangle(image, (int(gt_xmin), int(gt_ymin)), (int(gt_xmax), int(gt_ymax)), (0, 255, 0), 2) 265 | pos_indices = np.where(centerness_targets[:, 1] == 1)[0] 266 | for pos_index in pos_indices: 267 | cx, cy = locations[pos_index] 268 | l, t, r, b, _ = regr_targets[pos_index] 269 | xmin = cx - l 270 | ymin = cy - t 271 | xmax = cx + r 272 | ymax = cy + b 273 | class_id = np.argmax(cls_targets[pos_index]) 274 | centerness = centerness_targets[pos_index][0] 275 | cv2.putText(image, '{:.2f}'.format(centerness), (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 2.0, (255, 0, 255), 276 | 2) 277 | cv2.putText(image, str(class_id), (xmin, ymin), cv2.FONT_HERSHEY_SIMPLEX, 2.0, (0, 0, 0), 3) 278 | cv2.circle(image, (round(cx), round(cy)), 5, (255, 0, 0), -1) 279 | cv2.rectangle(image, (round(xmin), round(ymin)), (round(xmax), round(ymax)), (0, 0, 255), 2) 280 | cv2.namedWindow('image', cv2.WINDOW_NORMAL) 281 | cv2.imshow('image', image) 282 | cv2.waitKey(0) 283 | pass 284 | -------------------------------------------------------------------------------- /utils/transform.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import numpy as np 18 | 19 | DEFAULT_PRNG = np.random 20 | 21 | 22 | def colvec(*args): 23 | """ 24 | Create a numpy array representing a column vector. 
25 | """ 26 | return np.array([args]).T 27 | 28 | 29 | def transform_aabb(transform, aabb): 30 | """ 31 | Apply a transformation to an axis aligned bounding box. 32 | 33 | The result is a new AABB in the same coordinate system as the original AABB. 34 | The new AABB contains all corner points of the original AABB after applying the given transformation. 35 | 36 | Args 37 | transform: The transformation to apply. 38 | x1: The minimum x value of the AABB. 39 | y1: The minimum y value of the AABB. 40 | x2: The maximum x value of the AABB. 41 | y2: The maximum y value of the AABB. 42 | Returns 43 | The new AABB as tuple (x1, y1, x2, y2) 44 | """ 45 | x1, y1, x2, y2 = aabb 46 | # Transform all 4 corners of the AABB. 47 | points = transform.dot([ 48 | [x1, x2, x1, x2], 49 | [y1, y2, y2, y1], 50 | [1, 1, 1, 1], 51 | ]) 52 | 53 | # Extract the min and max corners again. 54 | # (3, ) (min_x, min_y, 1) 55 | min_corner = points.min(axis=1) 56 | # (3, ) (max_x, max_y, 1) 57 | max_corner = points.max(axis=1) 58 | 59 | return [min_corner[0], min_corner[1], max_corner[0], max_corner[1]] 60 | 61 | 62 | def _random_vector(min, max, prng=DEFAULT_PRNG): 63 | """ 64 | Construct a random vector between min and max. 65 | 66 | Args 67 | min: the minimum value for each component, (n, ) 68 | max: the maximum value for each component, (n, ) 69 | """ 70 | min = np.array(min) 71 | max = np.array(max) 72 | assert min.shape == max.shape 73 | assert len(min.shape) == 1 74 | return prng.uniform(min, max) 75 | 76 | 77 | def rotation(angle): 78 | """ 79 | Construct a homogeneous 2D rotation matrix. 80 | 81 | Args 82 | angle: the angle in radians 83 | Returns 84 | the rotation matrix as 3 by 3 numpy array 85 | """ 86 | return np.array([ 87 | [np.cos(angle), -np.sin(angle), 0], 88 | [np.sin(angle), np.cos(angle), 0], 89 | [0, 0, 1] 90 | ]) 91 | 92 | 93 | def random_rotation(min, max, prng=DEFAULT_PRNG): 94 | """ 95 | Construct a random rotation between -max and max. 96 | 97 | Args 98 | min: a scalar for the minimum absolute angle in radians 99 | max: a scalar for the maximum absolute angle in radians 100 | prng: the pseudo-random number generator to use. 101 | Returns 102 | a homogeneous 3 by 3 rotation matrix 103 | """ 104 | return rotation(prng.uniform(min, max)) 105 | 106 | 107 | def translation(translation): 108 | """ 109 | Construct a homogeneous 2D translation matrix. 110 | 111 | Args: 112 | translation: the translation 2D vector 113 | 114 | Returns: 115 | the translation matrix as 3 by 3 numpy array 116 | 117 | """ 118 | return np.array([ 119 | [1, 0, translation[0]], 120 | [0, 1, translation[1]], 121 | [0, 0, 1] 122 | ]) 123 | 124 | 125 | def random_translation(min, max, prng=DEFAULT_PRNG): 126 | """ 127 | Construct a random 2D translation between min and max. 128 | 129 | Args 130 | min: a 2D vector with the minimum translation for each dimension 131 | max: a 2D vector with the maximum translation for each dimension 132 | prng: the pseudo-random number generator to use. 133 | Returns 134 | a homogeneous 3 by 3 translation matrix 135 | """ 136 | return translation(_random_vector(min, max, prng)) 137 | 138 | 139 | def shear(angle): 140 | """ 141 | Construct a homogeneous 2D shear matrix. 
142 |
143 |     Args
144 |         angle: the shear angle in radians
145 |     Returns
146 |         the shear matrix as 3 by 3 numpy array
147 |     """
148 |     return np.array([
149 |         [1, -np.sin(angle), 0],
150 |         [0, np.cos(angle), 0],
151 |         [0, 0, 1]
152 |     ])
153 |
154 |
155 | def random_shear(min, max, prng=DEFAULT_PRNG):
156 |     """
157 |     Construct a random 2D shear matrix with a shear angle between min and max.
158 |
159 |     Args
160 |         min: the minimum shear angle in radians.
161 |         max: the maximum shear angle in radians.
162 |         prng: the pseudo-random number generator to use.
163 |     Returns
164 |         a homogeneous 3 by 3 shear matrix
165 |     """
166 |     return shear(prng.uniform(min, max))
167 |
168 |
169 | def scaling(factor):
170 |     """
171 |     Construct a homogeneous 2D scaling matrix.
172 |
173 |     Args
174 |         factor: a 2D vector for X and Y scaling
175 |     Returns
176 |         the zoom matrix as 3 by 3 numpy array
177 |     """
178 |
179 |     return np.array([
180 |         [factor[0], 0, 0],
181 |         [0, factor[1], 0],
182 |         [0, 0, 1]
183 |     ])
184 |
185 |
186 | def random_scaling(min, max, prng=DEFAULT_PRNG):
187 |     """
188 |     Construct a random 2D scale matrix between min and max.
189 |
190 |     Args
191 |         min: a 2D vector containing the minimum scaling factor for X and Y.
192 |         max: a 2D vector containing the maximum scaling factor for X and Y.
193 |         prng: the pseudo-random number generator to use.
194 |     Returns
195 |         a homogeneous 3 by 3 scaling matrix
196 |     """
197 |     return scaling(_random_vector(min, max, prng))
198 |
199 |
200 | def random_flip(flip_x_chance, flip_y_chance, prng=DEFAULT_PRNG):
201 |     """
202 |     Construct a transformation randomly containing X/Y flips (or not).
203 |
204 |     Args
205 |         flip_x_chance: The chance that the result will contain a flip along the X axis.
206 |         flip_y_chance: The chance that the result will contain a flip along the Y axis.
207 |         prng: The pseudo-random number generator to use.
208 |     Returns
209 |         a homogeneous 3 by 3 transformation matrix
210 |     """
211 |     flip_x = prng.uniform(0, 1) < flip_x_chance
212 |     flip_y = prng.uniform(0, 1) < flip_y_chance
213 |     # 1 - 2 * bool gives 1 for False and -1 for True.
214 |     return scaling((1 - 2 * flip_x, 1 - 2 * flip_y))
215 |
216 |
217 | def change_transform_origin(transform, center):
218 |     """
219 |     Create a new transform representing the same transformation, only with the origin of the linear part changed.
220 |
221 |     Args
222 |         transform: the transformation matrix
223 |         center: the new origin of the transformation
224 |     Returns
225 |         translate(center) * transform * translate(-center)
226 |     """
227 |     center = np.array(center)
228 |     return np.linalg.multi_dot([translation(center), transform, translation(-center)])
229 |
230 |
231 | def random_transform(
232 |         min_rotation=0,
233 |         max_rotation=0,
234 |         min_translation=(0, 0),
235 |         max_translation=(0, 0),
236 |         min_shear=0,
237 |         max_shear=0,
238 |         min_scaling=(1, 1),
239 |         max_scaling=(1, 1),
240 |         flip_x_chance=0,
241 |         flip_y_chance=0,
242 |         prng=DEFAULT_PRNG
243 | ):
244 |     """
245 |     Create a random transformation.
246 |
247 |     The transformation consists of the following operations in this order (from left to right):
248 |       * rotation
249 |       * translation
250 |       * shear
251 |       * scaling
252 |       * flip x (if applied)
253 |       * flip y (if applied)
254 |
255 |     Note that by default, the data generators in `keras_retinanet.preprocessing.generators` interpret the translation
256 |     as factor of the image size. So an X translation of 0.1 would translate the image by 10% of its width.
257 |     Set `relative_translation` to `False` in the `TransformParameters` of a data generator to have it interpret
258 |     the translation directly as pixel distances instead.
259 |
260 |     Args
261 |         min_rotation: The minimum rotation in radians for the transform as scalar.
262 |         max_rotation: The maximum rotation in radians for the transform as scalar.
263 |         min_translation: The minimum translation for the transform as 2D column vector.
264 |         max_translation: The maximum translation for the transform as 2D column vector.
265 |         min_shear: The minimum shear angle for the transform in radians.
266 |         max_shear: The maximum shear angle for the transform in radians.
267 |         min_scaling: The minimum scaling for the transform as 2D column vector.
268 |         max_scaling: The maximum scaling for the transform as 2D column vector.
269 |         flip_x_chance: The chance (0 to 1) that a transform will contain a flip along X direction.
270 |         flip_y_chance: The chance (0 to 1) that a transform will contain a flip along Y direction.
271 |         prng: The pseudo-random number generator to use.
272 |     """
273 |     return np.linalg.multi_dot([
274 |         random_rotation(min_rotation, max_rotation, prng),
275 |         random_translation(min_translation, max_translation, prng),
276 |         random_shear(min_shear, max_shear, prng),
277 |         random_scaling(min_scaling, max_scaling, prng),
278 |         random_flip(flip_x_chance, flip_y_chance, prng)
279 |     ])
280 |
281 |
282 | def random_transform_generator(prng=None, **kwargs):
283 |     """
284 |     Create a random transform generator.
285 |     Uses a dedicated, newly created, properly seeded PRNG by default instead of the global DEFAULT_PRNG.
286 |
287 |     The transformation consists of the following operations in this order (from left to right):
288 |       * rotation
289 |       * translation
290 |       * shear
291 |       * scaling
292 |       * flip x (if applied)
293 |       * flip y (if applied)
294 |
295 |     Note that by default, the data generators in `keras_retinanet.preprocessing.generators` interpret the translation
296 |     as factor of the image size. So an X translation of 0.1 would translate the image by 10% of its width.
297 |     Set `relative_translation` to `False` in the `TransformParameters` of a data generator to have it interpret
298 |     the translation directly as pixel distances instead.
299 |
300 |     Args
301 |         min_rotation: The minimum rotation in radians for the transform as scalar.
302 |         max_rotation: The maximum rotation in radians for the transform as scalar.
303 |         min_translation: The minimum translation for the transform as 2D column vector.
304 |         max_translation: The maximum translation for the transform as 2D column vector.
305 |         min_shear: The minimum shear angle for the transform in radians.
306 |         max_shear: The maximum shear angle for the transform in radians.
307 |         min_scaling: The minimum scaling for the transform as 2D column vector.
308 |         max_scaling: The maximum scaling for the transform as 2D column vector.
309 |         flip_x_chance: The chance (0 to 1) that a transform will contain a flip along X direction.
310 |         flip_y_chance: The chance (0 to 1) that a transform will contain a flip along Y direction.
311 |         prng: The pseudo-random number generator to use.
312 |     """
313 |
314 |     if prng is None:
315 |         # RandomState automatically seeds using the best available method.
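        # Pass a seeded np.random.RandomState via `prng` instead if reproducible augmentation is needed.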
316 |         prng = np.random.RandomState()
317 |
318 |     while True:
319 |         yield random_transform(prng=prng, **kwargs)
320 |
--------------------------------------------------------------------------------
/yolo/eval/common.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 Fizyr (https://fizyr.com)
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 |     http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import numpy as np
18 | import cv2
19 | import progressbar
20 | assert (callable(progressbar.progressbar)), "Using wrong progressbar module, install 'progressbar2' instead."
21 |
22 | from utils.compute_overlap import compute_overlap
23 | from utils.visualization import draw_detections, draw_annotations
24 |
25 |
26 | def _compute_ap(recall, precision):
27 |     """
28 |     Compute the average precision, given the recall and precision curves.
29 |
30 |     Code originally from https://github.com/rbgirshick/py-faster-rcnn.
31 |
32 |     Args:
33 |         recall: The recall curve (list).
34 |         precision: The precision curve (list).
35 |
36 |     Returns:
37 |         The average precision as computed in py-faster-rcnn.
38 |
39 |     """
40 |     # correct AP calculation
41 |     # first append sentinel values at the end
42 |     mrec = np.concatenate(([0.], recall, [1.]))
43 |     mpre = np.concatenate(([0.], precision, [0.]))
44 |
45 |     # compute the precision envelope
46 |     for i in range(mpre.size - 1, 0, -1):
47 |         mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
48 |
49 |     # to calculate area under PR curve, look for points
50 |     # where X axis (recall) changes value
51 |     i = np.where(mrec[1:] != mrec[:-1])[0]
52 |
53 |     # and sum (delta recall) * prec
54 |     ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
55 |     return ap
56 |
57 |
58 | def _get_detections(generator, model, score_threshold=0.01, max_detections=100, visualize=False):
59 |     """
60 |     Get the detections from the model using the generator.
61 |
62 |     The result is a list of lists such that the size is:
63 |         all_detections[num_images][num_classes] = detections[num_class_detections, 5]
64 |
65 |     Args:
66 |         generator: The generator used to run images through the model.
67 |         model: The model to run on the images.
68 |         score_threshold: The score confidence threshold to use.
69 |         max_detections: The maximum number of detections to use per image.
70 |         visualize: If True, display the images with the detections drawn on them.
71 |
72 |     Returns:
73 |         A list of lists containing the detections for each image in the generator.
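        Each filled entry is a (num_class_detections, 5) array with rows (x1, y1, x2, y2, score) for one class.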
74 | 
75 |     """
76 |     all_detections = [[None for i in range(generator.num_classes())] for j in range(generator.size())]  # one slot per class, even for labels missing from the dataset
77 | 
78 |     for i in progressbar.progressbar(range(generator.size()), prefix='Running network: '):
79 |         image = generator.load_image(i)
80 |         src_image = image.copy()
81 |         image, scale, offset_h, offset_w = generator.preprocess_image(image)
82 | 
83 |         # run network
84 |         # the model returns (boxes, scores, labels); keep only the first three outputs
85 |         boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))[:3]
86 | 
87 |         # correct boxes for image scale
88 |         boxes[0, :, [0, 2]] -= offset_w
89 |         boxes[0, :, [1, 3]] -= offset_h
90 |         boxes /= scale
91 | 
92 |         # select indices which have a score above the threshold
93 |         indices = np.where(scores[0, :] > score_threshold)[0]
94 | 
95 |         # select those scores
96 |         scores = scores[0][indices]
97 | 
98 |         # find the order with which to sort the scores
99 |         scores_sort = np.argsort(-scores)[:max_detections]
100 | 
101 |         # select detections
102 |         # (n, 4)
103 |         boxes = boxes[0, indices[scores_sort], :]
104 |         # (n, )
105 |         scores = scores[scores_sort]
106 |         # (n, )
107 |         labels = labels[0, indices[scores_sort]]
108 |         # (n, 6)
109 |         detections = np.concatenate(
110 |             [boxes, np.expand_dims(scores, axis=1), np.expand_dims(labels, axis=1)], axis=1)
111 | 
112 |         if visualize:
113 |             # draw_annotations(src_image, generator.load_annotations(i), label_to_name=generator.label_to_name)
114 |             draw_detections(src_image, boxes[:5], scores[:5], labels[:5],
115 |                             label_to_name=generator.label_to_name,
116 |                             score_threshold=score_threshold)
117 | 
118 |             # cv2.imwrite(os.path.join(save_path, '{}.png'.format(i)), raw_image)
119 |             cv2.namedWindow('{}'.format(i), cv2.WINDOW_NORMAL)
120 |             cv2.imshow('{}'.format(i), src_image)
121 |             cv2.waitKey(0)
122 | 
123 |         # copy detections to all_detections
124 |         for class_id in range(generator.num_classes()):
125 |             all_detections[i][class_id] = detections[detections[:, -1] == class_id, :-1]
126 | 
127 |     return all_detections
128 | 
129 | 
130 | def _get_annotations(generator):
131 |     """
132 |     Get the ground truth annotations from the generator.
133 | 
134 |     The result is a list of lists such that the size is:
135 |         all_annotations[num_images][num_classes] = annotations[num_class_annotations, 5]
136 | 
137 |     Args:
138 |         generator: The generator used to retrieve ground truth annotations.
139 | 
140 |     Returns:
141 |         A list of lists containing the annotations for each image in the generator.
142 | 
143 |     """
144 |     all_annotations = [[None for i in range(generator.num_classes())] for j in range(generator.size())]
145 | 
146 |     for i in progressbar.progressbar(range(generator.size()), prefix='Parsing annotations: '):
147 |         # load the annotations
148 |         annotations = generator.load_annotations(i)
149 | 
150 |         # copy detections to all_annotations
151 |         for label in range(generator.num_classes()):
152 |             if not generator.has_label(label):
153 |                 continue
154 | 
155 |             all_annotations[i][label] = annotations['bboxes'][annotations['labels'] == label, :].copy()
156 | 
157 |     return all_annotations
158 | 
159 | 
160 | def evaluate(
161 |         generator,
162 |         model,
163 |         iou_threshold=0.5,
164 |         score_threshold=0.01,
165 |         max_detections=100,
166 |         visualize=False,
167 |         epoch=0
168 | ):
169 |     """
170 |     Evaluate a given dataset using a given model.
171 | 
172 |     Args:
173 |         generator: The generator that represents the dataset to evaluate.
174 |         model: The model to evaluate.
175 |         iou_threshold: The IoU threshold above which a detection matching an unclaimed ground-truth box counts as a true positive.
176 |         score_threshold: The score confidence threshold to use for detections.
177 |         max_detections: The maximum number of detections to use per image.
178 |         visualize: Whether to display each image with its detections drawn.
179 | 
180 |     Returns:
181 |         A dict mapping each class label to a tuple of (average precision, number of annotations).
182 | 
183 |     """
184 |     # gather all detections and annotations
185 |     all_detections = _get_detections(generator, model, score_threshold=score_threshold, max_detections=max_detections,
186 |                                      visualize=visualize)
187 |     all_annotations = _get_annotations(generator)
188 |     average_precisions = {}
189 | 
190 |     # all_detections = pickle.load(open('all_detections_{}.pkl'.format(epoch + 1), 'rb'))
191 |     # all_annotations = pickle.load(open('all_annotations_{}.pkl'.format(epoch + 1), 'rb'))
192 |     # pickle.dump(all_detections, open('all_detections_{}.pkl'.format(epoch + 1), 'wb'))
193 |     # pickle.dump(all_annotations, open('all_annotations_{}.pkl'.format(epoch + 1), 'wb'))
194 | 
195 |     # process detections and annotations
196 |     for label in range(generator.num_classes()):
197 |         if not generator.has_label(label):
198 |             continue
199 | 
200 |         false_positives = np.zeros((0,))
201 |         true_positives = np.zeros((0,))
202 |         scores = np.zeros((0,))
203 |         num_annotations = 0.0
204 | 
205 |         for i in range(generator.size()):
206 |             detections = all_detections[i][label]
207 |             annotations = all_annotations[i][label]
208 |             num_annotations += annotations.shape[0]
209 |             detected_annotations = []
210 | 
211 |             for d in detections:
212 |                 scores = np.append(scores, d[4])
213 | 
214 |                 if annotations.shape[0] == 0:
215 |                     false_positives = np.append(false_positives, 1)
216 |                     true_positives = np.append(true_positives, 0)
217 |                     continue
218 |                 overlaps = compute_overlap(np.expand_dims(d, axis=0), annotations)
219 |                 assigned_annotation = np.argmax(overlaps, axis=1)
220 |                 max_overlap = overlaps[0, assigned_annotation]
221 | 
222 |                 if max_overlap >= iou_threshold and assigned_annotation not in detected_annotations:
223 |                     false_positives = np.append(false_positives, 0)
224 |                     true_positives = np.append(true_positives, 1)
225 |                     detected_annotations.append(assigned_annotation)
226 |                 else:
227 |                     false_positives = np.append(false_positives, 1)
228 |                     true_positives = np.append(true_positives, 0)
229 | 
230 |         # no annotations -> AP for this class is 0 (is this correct?)
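        # Labels that exist in the label map but have no ground-truth instances are
        # recorded as (0, 0); the mAP computation in __main__ below excludes them
        # from the denominator via `sum(x > 0 for x in total_instances)`.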
231 | if num_annotations == 0: 232 | average_precisions[label] = 0, 0 233 | continue 234 | 235 | # sort by score 236 | indices = np.argsort(-scores) 237 | false_positives = false_positives[indices] 238 | true_positives = true_positives[indices] 239 | 240 | # compute false positives and true positives 241 | false_positives = np.cumsum(false_positives) 242 | true_positives = np.cumsum(true_positives) 243 | 244 | # compute recall and precision 245 | recall = true_positives / num_annotations 246 | precision = true_positives / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps) 247 | 248 | # compute average precision 249 | average_precision = _compute_ap(recall, precision) 250 | average_precisions[label] = average_precision, num_annotations 251 | 252 | return average_precisions 253 | 254 | 255 | if __name__ == '__main__': 256 | from yolo.generators.pascal import PascalVocGenerator 257 | from yolo.model import yolo_body 258 | import os 259 | 260 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 261 | common_args = { 262 | 'batch_size': 1, 263 | 'image_size': 416 264 | } 265 | test_generator = PascalVocGenerator( 266 | 'datasets/voc_test/VOC2007', 267 | 'test', 268 | shuffle_groups=False, 269 | skip_truncated=False, 270 | skip_difficult=True, 271 | anchors_path='voc_anchors_416.txt', 272 | **common_args 273 | ) 274 | model_path = 'pascal_18_6.4112_6.5125_0.8319_0.8358.h5' 275 | num_classes = test_generator.num_classes() 276 | model, prediction_model = yolo_body(num_classes=num_classes) 277 | prediction_model.load_weights(model_path, by_name=True, skip_mismatch=True) 278 | average_precisions = evaluate(test_generator, prediction_model, visualize=False) 279 | # compute per class average precision 280 | total_instances = [] 281 | precisions = [] 282 | for label, (average_precision, num_annotations) in average_precisions.items(): 283 | print('{:.0f} instances of class'.format(num_annotations), test_generator.label_to_name(label), 284 | 'with average precision: {:.4f}'.format(average_precision)) 285 | total_instances.append(num_annotations) 286 | precisions.append(average_precision) 287 | mean_ap = sum(precisions) / sum(x > 0 for x in total_instances) 288 | print('mAP: {:.4f}'.format(mean_ap)) 289 | -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from utils.compute_overlap import compute_overlap 18 | from utils.visualization import draw_detections, draw_annotations 19 | 20 | import keras 21 | import numpy as np 22 | import os 23 | import cv2 24 | import progressbar 25 | import pickle 26 | 27 | assert (callable(progressbar.progressbar)), "Using wrong progressbar module, install 'progressbar2' instead." 
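# The assert above rejects the legacy 'progressbar' package: only progressbar2
# exposes a callable `progressbar.progressbar`.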
28 | 
29 | 
30 | def _compute_ap(recall, precision):
31 |     """
32 |     Compute the average precision, given the recall and precision curves.
33 | 
34 |     Code originally from https://github.com/rbgirshick/py-faster-rcnn.
35 | 
36 |     Args:
37 |         recall: The recall curve (list).
38 |         precision: The precision curve (list).
39 | 
40 |     Returns:
41 |         The average precision as computed in py-faster-rcnn.
42 | 
43 |     """
44 |     # correct AP calculation
45 |     # first append sentinel values at the end
46 |     mrec = np.concatenate(([0.], recall, [1.]))
47 |     mpre = np.concatenate(([0.], precision, [0.]))
48 | 
49 |     # compute the precision envelope
50 |     for i in range(mpre.size - 1, 0, -1):
51 |         mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
52 | 
53 |     # to calculate area under PR curve, look for points
54 |     # where X axis (recall) changes value
55 |     i = np.where(mrec[1:] != mrec[:-1])[0]
56 | 
57 |     # and sum (delta recall) * prec
58 |     ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
59 |     return ap
60 | 
61 | 
62 | def _get_detections(generator, model, score_threshold=0.05, max_detections=100, visualize=False):
63 |     """
64 |     Get the detections from the model using the generator.
65 | 
66 |     The result is a list of lists such that the size is:
67 |         all_detections[num_images][num_classes] = detections[num_class_detections, 5]
68 | 
69 |     Args:
70 |         generator: The generator used to run images through the model.
71 |         model: The model to run on the images.
72 |         score_threshold: The score confidence threshold to use.
73 |         max_detections: The maximum number of detections to use per image.
74 |         visualize: If True, display each image with its detections drawn (uses cv2.imshow).
75 | 
76 |     Returns:
77 |         A list of lists containing the detections for each image in the generator.
78 | 
79 |     """
80 |     all_detections = [[None for i in range(generator.num_classes())] for j in
81 |                       range(generator.size())]  # one slot per class, even for labels missing from the dataset
82 | 
83 |     for i in progressbar.progressbar(range(generator.size()), prefix='Running network: '):
84 |         raw_image = generator.load_image(i)
85 |         image = generator.preprocess_image(raw_image.copy())
86 |         image, scale = generator.resize_image(image)
87 | 
88 |         if keras.backend.image_data_format() == 'channels_first':
89 |             image = image.transpose((2, 0, 1))
90 | 
91 |         # run network
92 |         boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))[:3]
93 | 
94 |         # correct boxes for image scale
95 |         boxes /= scale
96 | 
97 |         # select indices which have a score above the threshold
98 |         indices = np.where(scores[0, :] > score_threshold)[0]
99 | 
100 |         # select those scores
101 |         scores = scores[0][indices]
102 | 
103 |         # find the order with which to sort the scores
104 |         scores_sort = np.argsort(-scores)[:max_detections]
105 | 
106 |         # select detections
107 |         # (n, 4)
108 |         image_boxes = boxes[0, indices[scores_sort], :]
109 |         # (n, )
110 |         image_scores = scores[scores_sort]
111 |         # (n, )
112 |         image_labels = labels[0, indices[scores_sort]]
113 |         # (n, 6)
114 |         image_detections = np.concatenate(
115 |             [image_boxes, np.expand_dims(image_scores, axis=1), np.expand_dims(image_labels, axis=1)], axis=1)
116 | 
117 |         if visualize:
118 |             draw_annotations(raw_image, generator.load_annotations(i), label_to_name=generator.label_to_name)
119 |             draw_detections(raw_image, image_boxes[:5], image_scores[:5], image_labels[:5], label_to_name=generator.label_to_name,
120 |                             score_threshold=score_threshold)
121 | 
122 |             # cv2.imwrite(os.path.join(save_path, '{}.png'.format(i)), raw_image)
123 |             cv2.namedWindow('{}'.format(i), cv2.WINDOW_NORMAL)
124 |             cv2.imshow('{}'.format(i), raw_image)
125 |             cv2.waitKey(0)
126 | 
127 |         # copy detections to all_detections
128 |         for label in range(generator.num_classes()):
129 |             if not generator.has_label(label):
130 |                 continue
131 | 
132 |             all_detections[i][label] = image_detections[image_detections[:, -1] == label, :-1]
133 | 
134 |     return all_detections
135 | 
136 | 
137 | def _get_annotations(generator):
138 |     """
139 |     Get the ground truth annotations from the generator.
140 | 
141 |     The result is a list of lists such that the size is:
142 |         all_annotations[num_images][num_classes] = annotations[num_class_annotations, 5]
143 | 
144 |     Args:
145 |         generator: The generator used to retrieve ground truth annotations.
146 | 
147 |     Returns:
148 |         A list of lists containing the annotations for each image in the generator.
149 | 
150 |     """
151 |     all_annotations = [[None for i in range(generator.num_classes())] for j in range(generator.size())]
152 | 
153 |     for i in progressbar.progressbar(range(generator.size()), prefix='Parsing annotations: '):
154 |         # load the annotations
155 |         annotations = generator.load_annotations(i)
156 | 
157 |         # copy detections to all_annotations
158 |         for label in range(generator.num_classes()):
159 |             if not generator.has_label(label):
160 |                 continue
161 | 
162 |             all_annotations[i][label] = annotations['bboxes'][annotations['labels'] == label, :].copy()
163 | 
164 |     return all_annotations
165 | 
166 | 
167 | def evaluate(
168 |         generator,
169 |         model,
170 |         iou_threshold=0.5,
171 |         score_threshold=0.05,
172 |         max_detections=100,
173 |         visualize=False,
174 |         epoch=0
175 | ):
176 |     """
177 |     Evaluate a given dataset using a given model.
178 | 
179 |     Args:
180 |         generator: The generator that represents the dataset to evaluate.
181 |         model: The model to evaluate.
182 |         iou_threshold: The IoU threshold above which a detection matching an unclaimed ground-truth box counts as a true positive.
183 |         score_threshold: The score confidence threshold to use for detections.
184 |         max_detections: The maximum number of detections to use per image.
185 |         visualize: Whether to display each image with its detections drawn.
186 | 
187 |     Returns:
188 |         A dict mapping each class label to a tuple of (average precision, number of annotations).
189 | 190 | """ 191 | # gather all detections and annotations 192 | all_detections = _get_detections(generator, model, score_threshold=score_threshold, max_detections=max_detections, 193 | visualize=visualize) 194 | all_annotations = _get_annotations(generator) 195 | average_precisions = {} 196 | 197 | # all_detections = pickle.load(open('all_detections_{}.pkl'.format(epoch + 1), 'rb')) 198 | # all_annotations = pickle.load(open('all_annotations_{}.pkl'.format(epoch + 1), 'rb')) 199 | # pickle.dump(all_detections, open('all_detections_{}.pkl'.format(epoch + 1), 'wb')) 200 | # pickle.dump(all_annotations, open('all_annotations_{}.pkl'.format(epoch + 1), 'wb')) 201 | 202 | # process detections and annotations 203 | for label in range(generator.num_classes()): 204 | if not generator.has_label(label): 205 | continue 206 | 207 | false_positives = np.zeros((0,)) 208 | true_positives = np.zeros((0,)) 209 | scores = np.zeros((0,)) 210 | num_annotations = 0.0 211 | 212 | for i in range(generator.size()): 213 | detections = all_detections[i][label] 214 | annotations = all_annotations[i][label] 215 | num_annotations += annotations.shape[0] 216 | detected_annotations = [] 217 | 218 | for d in detections: 219 | scores = np.append(scores, d[4]) 220 | 221 | if annotations.shape[0] == 0: 222 | false_positives = np.append(false_positives, 1) 223 | true_positives = np.append(true_positives, 0) 224 | continue 225 | overlaps = compute_overlap(np.expand_dims(d, axis=0), annotations) 226 | assigned_annotation = np.argmax(overlaps, axis=1) 227 | max_overlap = overlaps[0, assigned_annotation] 228 | 229 | if max_overlap >= iou_threshold and assigned_annotation not in detected_annotations: 230 | false_positives = np.append(false_positives, 0) 231 | true_positives = np.append(true_positives, 1) 232 | detected_annotations.append(assigned_annotation) 233 | else: 234 | false_positives = np.append(false_positives, 1) 235 | true_positives = np.append(true_positives, 0) 236 | 237 | # no annotations -> AP for this class is 0 (is this correct?) 
238 | if num_annotations == 0: 239 | average_precisions[label] = 0, 0 240 | continue 241 | 242 | # sort by score 243 | indices = np.argsort(-scores) 244 | false_positives = false_positives[indices] 245 | true_positives = true_positives[indices] 246 | 247 | # compute false positives and true positives 248 | false_positives = np.cumsum(false_positives) 249 | true_positives = np.cumsum(true_positives) 250 | 251 | # compute recall and precision 252 | recall = true_positives / num_annotations 253 | precision = true_positives / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps) 254 | 255 | # compute average precision 256 | average_precision = _compute_ap(recall, precision) 257 | average_precisions[label] = average_precision, num_annotations 258 | 259 | return average_precisions 260 | 261 | 262 | if __name__ == '__main__': 263 | from generators.voc_generator import PascalVocGenerator 264 | from utils.image import preprocess_image 265 | import models 266 | import os 267 | 268 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 269 | common_args = { 270 | 'batch_size': 1, 271 | 'image_min_side': 800, 272 | 'image_max_side': 1333, 273 | 'preprocess_image': preprocess_image, 274 | } 275 | # generator = PascalVocGenerator( 276 | # 'datasets/voc_trainval/VOC0712', 277 | # 'val', 278 | # shuffle_groups=False, 279 | # skip_truncated=False, 280 | # skip_difficult=True, 281 | # **common_args 282 | # ) 283 | generator = PascalVocGenerator( 284 | 'datasets/voc_test/VOC2007', 285 | 'test', 286 | shuffle_groups=False, 287 | skip_truncated=False, 288 | skip_difficult=True, 289 | **common_args 290 | ) 291 | model_path = '/home/adam/workspace/github/xuannianz/carrot/fsaf/snapshots/2019-10-05/resnet101_pascal_47_0.7652.h5' 292 | # load retinanet model 293 | # import keras.backend as K 294 | # K.set_learning_phase(1) 295 | from models.resnet import resnet_fsaf 296 | from models.retinanet import fsaf_bbox 297 | fsaf = resnet_fsaf(num_classes=20, backbone='resnet101') 298 | model = fsaf_bbox(fsaf) 299 | model.load_weights(model_path, by_name=True) 300 | average_precisions = evaluate(generator, model, visualize=False) 301 | # compute per class average precision 302 | total_instances = [] 303 | precisions = [] 304 | for label, (average_precision, num_annotations) in average_precisions.items(): 305 | print('{:.0f} instances of class'.format(num_annotations), generator.label_to_name(label), 306 | 'with average precision: {:.4f}'.format(average_precision)) 307 | total_instances.append(num_annotations) 308 | precisions.append(average_precision) 309 | mean_ap = sum(precisions) / sum(x > 0 for x in total_instances) 310 | print('mAP: {:.4f}'.format(mean_ap)) 311 | -------------------------------------------------------------------------------- /utils/image.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2017-2018 Fizyr (https://fizyr.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """
16 | 
17 | from __future__ import division
18 | import numpy as np
19 | import cv2
20 | from PIL import Image
21 | 
22 | from .transform import change_transform_origin
23 | 
24 | 
25 | def read_image_bgr(path):
26 |     """
27 |     Read an image in BGR format.
28 | 
29 |     Args
30 |         path: Path to the image.
31 |     """
32 |     # We deliberately don't use cv2.imread here, since it gives no feedback on errors while reading the image.
33 |     image = np.asarray(Image.open(path).convert('RGB'))
34 |     return image[:, :, ::-1].copy()
35 | 
36 | 
37 | def preprocess_image(x, mode='caffe'):
38 |     """
39 |     Preprocess an image by subtracting the ImageNet mean.
40 | 
41 |     Args
42 |         x: np.array of shape (None, None, 3) or (3, None, None).
43 |         mode: One of "caffe" or "tf".
44 |             - caffe: will zero-center each color channel with
45 |               respect to the ImageNet dataset, without scaling.
46 |             - tf: will scale pixels between -1 and 1, sample-wise.
47 | 
48 |     Returns
49 |         The input with the ImageNet mean subtracted.
50 |     """
51 |     # mostly identical to "https://github.com/keras-team/keras-applications/blob/master/keras_applications/imagenet_utils.py"
52 |     # except that the RGB -> BGR conversion is skipped, since the input is assumed to be BGR already
53 | 
54 |     # always convert to float32 to keep compatibility with opencv
55 |     x = x.astype(np.float32)
56 | 
57 |     if mode == 'tf':
58 |         x /= 127.5
59 |         x -= 1.
60 |     elif mode == 'caffe':
61 |         x[..., 0] -= 103.939
62 |         x[..., 1] -= 116.779
63 |         x[..., 2] -= 123.68
64 | 
65 |     return x
66 | 
67 | 
68 | def adjust_transform_for_image(transform, image, relative_translation):
69 |     """
70 |     Adjust a transformation for a specific image.
71 | 
72 |     The translation of the matrix will be scaled with the size of the image.
73 |     The linear part of the transformation will be adjusted so that the origin of the transformation is at the center of the image.
74 |     """
75 |     height, width, channels = image.shape
76 | 
77 |     result = transform
78 | 
79 |     # Scale the translation with the image size if specified.
80 |     if relative_translation:
81 |         result[0:2, 2] *= [width, height]
82 | 
83 |     # Move the origin of transformation.
84 |     result = change_transform_origin(result, (0.5 * width, 0.5 * height))
85 | 
86 |     return result
87 | 
88 | 
89 | class TransformParameters:
90 |     """
91 |     Struct holding parameters determining how to apply a transformation to an image.
92 | 
93 |     Args
94 |         fill_mode: One of: 'constant', 'nearest', 'reflect', 'wrap'
95 |         interpolation: One of: 'nearest', 'linear', 'cubic', 'area', 'lanczos4'
96 |         cval: Fill value to use with fill_mode='constant'
97 |         relative_translation: If true (the default), interpret translation as a factor of the image size.
98 |                               If false, interpret it as absolute pixels.
99 |     """
100 | 
101 |     def __init__(
102 |             self,
103 |             fill_mode='nearest',
104 |             interpolation='linear',
105 |             cval=0,
106 |             relative_translation=True,
107 |     ):
108 |         self.fill_mode = fill_mode
109 |         self.cval = cval
110 |         self.interpolation = interpolation
111 |         self.relative_translation = relative_translation
112 | 
113 |     def cvBorderMode(self):
114 |         if self.fill_mode == 'constant':
115 |             return cv2.BORDER_CONSTANT
116 |         if self.fill_mode == 'nearest':
117 |             return cv2.BORDER_REPLICATE
118 |         if self.fill_mode == 'reflect':
119 |             return cv2.BORDER_REFLECT_101
120 |         if self.fill_mode == 'wrap':
121 |             return cv2.BORDER_WRAP
122 | 
123 |     def cvInterpolation(self):
124 |         if self.interpolation == 'nearest':
125 |             return cv2.INTER_NEAREST
126 |         if self.interpolation == 'linear':
127 |             return cv2.INTER_LINEAR
128 |         if self.interpolation == 'cubic':
129 |             return cv2.INTER_CUBIC
130 |         if self.interpolation == 'area':
131 |             return cv2.INTER_AREA
132 |         if self.interpolation == 'lanczos4':
133 |             return cv2.INTER_LANCZOS4
134 | 
135 | 
136 | def apply_transform(matrix, image, params):
137 |     """
138 |     Apply a transformation to an image.
139 | 
140 |     The origin of transformation is at the top left corner of the image.
141 | 
142 |     The matrix is interpreted such that a point (x, y) on the original image is moved to transform * (x, y) in the generated image.
143 |     Mathematically speaking, that means that the matrix is a transformation from the transformed image space to the original image space.
144 | 
145 |     Args
146 |         matrix: A homogeneous 3 by 3 matrix representing the transformation to apply.
147 |         image: The image to transform.
148 |         params: The transform parameters (see TransformParameters)
149 |     """
150 |     output = cv2.warpAffine(
151 |         image,
152 |         matrix[:2, :],
153 |         dsize=(image.shape[1], image.shape[0]),
154 |         flags=params.cvInterpolation(),
155 |         borderMode=params.cvBorderMode(),
156 |         borderValue=params.cval,
157 |     )
158 |     return output
159 | 
160 | 
161 | def compute_resize_scale(image_shape, min_side=800, max_side=1333):
162 |     """
163 |     Compute an image scale such that the image size is constrained to min_side and max_side.
164 | 
165 |     Args
166 |         min_side: The image's min side will be equal to min_side after resizing.
167 |         max_side: If after resizing the image's max side is above max_side, resize until the max side is equal to max_side.
168 | 
169 |     Returns
170 |         A resizing scale.
171 |     """
172 |     (rows, cols, _) = image_shape
173 | 
174 |     smallest_side = min(rows, cols)
175 | 
176 |     # rescale the image so the smallest side is min_side
177 |     scale = min_side / smallest_side
178 | 
179 |     # check if the largest side is now greater than max_side, which can happen
180 |     # when images have a large aspect ratio
181 |     largest_side = max(rows, cols)
182 |     if largest_side * scale > max_side:
183 |         scale = max_side / largest_side
184 | 
185 |     return scale
186 | 
187 | 
188 | def resize_image(img, min_side=800, max_side=1333):
189 |     """
190 |     Resize an image such that the size is constrained to min_side and max_side.
191 | 
192 |     Args
193 |         min_side: The image's min side will be equal to min_side after resizing.
194 |         max_side: If after resizing the image's max side is above max_side, resize until the max side is equal to max_side.
195 | 
196 |     Returns
197 |         A tuple (resized_image, scale) with the scale that was applied.
198 |     """
199 |     # compute scale to resize the image
200 |     scale = compute_resize_scale(img.shape, min_side=min_side, max_side=max_side)
201 | 
202 |     # resize the image with the computed scale
203 |     img = cv2.resize(img, None, fx=scale, fy=scale)
204 | 
205 |     return img, scale
206 | 
207 | 
208 | def _uniform(val_range):
209 |     """
210 |     Uniformly sample from the given range.
211 | 
212 |     Args
213 |         val_range: A pair of lower and upper bound.
214 |     """
215 |     return np.random.uniform(val_range[0], val_range[1])
216 | 
217 | 
218 | def _check_range(val_range, min_val=None, max_val=None):
219 |     """
220 |     Check whether the range is a valid range.
221 | 
222 |     Args
223 |         val_range: A pair of lower and upper bound.
224 |         min_val: Minimal value for the lower bound.
225 |         max_val: Maximal value for the upper bound.
226 |     """
227 |     if val_range[0] > val_range[1]:
228 |         raise ValueError('interval lower bound > upper bound')
229 |     if min_val is not None and val_range[0] < min_val:
230 |         raise ValueError('invalid interval lower bound')
231 |     if max_val is not None and val_range[1] > max_val:
232 |         raise ValueError('invalid interval upper bound')
233 | 
234 | 
235 | def _clip(image):
236 |     """
237 |     Clip and convert an image to np.uint8.
238 | 
239 |     Args
240 |         image: Image to clip.
241 |     """
242 |     return np.clip(image, 0, 255).astype(np.uint8)
243 | 
244 | 
245 | class VisualEffect:
246 |     """
247 |     Struct holding parameters and applying image color transformation.
248 | 
249 |     Args
250 |         contrast_factor: A factor for adjusting contrast. Should be between 0 and 3.
251 |         brightness_delta: Brightness offset between -1 and 1 added to the pixel values.
252 |         hue_delta: Hue offset between -1 and 1 added to the hue channel.
253 |         saturation_factor: A factor multiplying the saturation values of each pixel.
254 |     """
255 | 
256 |     def __init__(
257 |             self,
258 |             contrast_factor,
259 |             brightness_delta,
260 |             hue_delta,
261 |             saturation_factor,
262 |     ):
263 |         self.contrast_factor = contrast_factor
264 |         self.brightness_delta = brightness_delta
265 |         self.hue_delta = hue_delta
266 |         self.saturation_factor = saturation_factor
267 | 
268 |     def __call__(self, image):
269 |         """
270 |         Apply a visual effect on the image.
271 | 
272 |         Args
273 |             image: Image to adjust
274 |         """
275 | 
276 |         if self.contrast_factor:
277 |             image = adjust_contrast(image, self.contrast_factor)
278 |         if self.brightness_delta:
279 |             image = adjust_brightness(image, self.brightness_delta)
280 | 
281 |         if self.hue_delta or self.saturation_factor:
282 |             image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
283 |             if self.hue_delta:
284 |                 image = adjust_hue(image, self.hue_delta)
285 |             if self.saturation_factor:
286 |                 image = adjust_saturation(image, self.saturation_factor)
287 | 
288 |             image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
289 | 
290 |         return image
291 | 
292 | 
293 | def random_visual_effect_generator(
294 |         contrast_range=(0.9, 1.1),
295 |         brightness_range=(-.1, .1),
296 |         hue_range=(-0.05, 0.05),
297 |         saturation_range=(0.95, 1.05)
298 | ):
299 |     """
300 |     Generate visual effect parameters uniformly sampled from the given intervals.
301 | 
302 |     Args
303 |         contrast_range: An interval between 0 and 3 for the factor adjusting contrast.
304 |         brightness_range: An interval between -1 and 1 for the offset added to the pixel values.
305 |         hue_range: An interval between -1 and 1 for the offset added to the hue channel.
306 |             The values are rotated if they exceed 180.
307 |         saturation_range: An interval for the factor multiplying the saturation values of each
308 |             pixel.
309 |     """
310 |     _check_range(contrast_range, 0)
311 |     _check_range(brightness_range, -1, 1)
312 |     _check_range(hue_range, -1, 1)
313 |     _check_range(saturation_range, 0)
314 | 
315 |     def _generate():
316 |         while True:
317 |             yield VisualEffect(
318 |                 contrast_factor=_uniform(contrast_range),
319 |                 brightness_delta=_uniform(brightness_range),
320 |                 hue_delta=_uniform(hue_range),
321 |                 saturation_factor=_uniform(saturation_range),
322 |             )
323 | 
324 |     return _generate()
325 | 
326 | 
327 | def adjust_contrast(image, factor):
328 |     """
329 |     Adjust contrast of an image.
330 | 
331 |     Args
332 |         image: Image to adjust.
333 |         factor: A factor for adjusting contrast.
334 |     """
335 |     mean = image.mean(axis=0).mean(axis=0)
336 |     return _clip((image - mean) * factor + mean)
337 | 
338 | 
339 | def adjust_brightness(image, delta):
340 |     """
341 |     Adjust brightness of an image.
342 | 
343 |     Args
344 |         image: Image to adjust.
345 |         delta: Brightness offset between -1 and 1 added to the pixel values.
346 |     """
347 |     return _clip(image + delta * 255)
348 | 
349 | 
350 | def adjust_hue(image, delta):
351 |     """
352 |     Adjust hue of an image.
353 | 
354 |     Args
355 |         image: Image to adjust.
356 |         delta: Hue offset between -1 and 1 added to the hue channel.
357 |             The values are rotated if they exceed 180.
358 |     """
359 |     image[..., 0] = np.mod(image[..., 0] + delta * 180, 180)
360 |     return image
361 | 
362 | 
363 | def adjust_saturation(image, factor):
364 |     """
365 |     Adjust saturation of an image.
366 | 
367 |     Args
368 |         image: Image to adjust.
369 |         factor: A factor multiplying the saturation values of each pixel, clipped to [0, 255].
370 |     """
371 |     image[..., 1] = np.clip(image[..., 1] * factor, 0, 255)
372 |     return image
373 | 
--------------------------------------------------------------------------------
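A hedged usage sketch (not part of the repository): the snippet below wires the geometric transform generator and the visual effect generator from the modules above into a minimal augmentation loop, roughly the way the data generators consume them. It assumes the transform helpers shown earlier live in `utils/transform.py` alongside `change_transform_origin`; the sample image path comes from the repository's `test/` directory, and all parameter values are illustrative.

```python
from utils.image import (TransformParameters, adjust_transform_for_image,
                         apply_transform, random_visual_effect_generator,
                         read_image_bgr, resize_image)
from utils.transform import random_transform_generator

# Illustrative augmentation ranges; tune them for your dataset.
transform_generator = random_transform_generator(
    min_rotation=-0.1, max_rotation=0.1,
    min_translation=(-0.1, -0.1), max_translation=(0.1, 0.1),
    min_shear=-0.1, max_shear=0.1,
    min_scaling=(0.9, 0.9), max_scaling=(1.1, 1.1),
    flip_x_chance=0.5, flip_y_chance=0.0,
)
effect_generator = random_visual_effect_generator()
params = TransformParameters()  # relative translation, linear interpolation

image = read_image_bgr('test/004456.jpg')

# Geometric augmentation: sample a 3x3 transform, adapt it to this image, warp.
transform = adjust_transform_for_image(next(transform_generator), image,
                                       params.relative_translation)
image = apply_transform(transform, image, params)

# Photometric augmentation: contrast/brightness/hue/saturation jitter.
image = next(effect_generator)(image)

# Resize so the smallest side is 800 px (long side capped at 1333 px).
image, scale = resize_image(image)
print(image.shape, scale)
```

The same `scale` would also have to be applied to the ground-truth boxes, mirroring how `_get_detections` in `utils/eval.py` divides predicted boxes by it to map detections back to original image coordinates.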