├── utils
│   ├── __init__.py
│   ├── compute_overlap.cpython-36m-x86_64-linux-gnu.so
│   ├── model.py
│   ├── keras_version.py
│   ├── compute_overlap.pyx
│   ├── config.py
│   ├── colors.py
│   ├── coco_eval.py
│   ├── visualization.py
│   ├── transform.py
│   ├── eval.py
│   └── image.py
├── yolo
│   ├── __init__.py
│   ├── eval
│   │   ├── __init__.py
│   │   ├── pascal.py
│   │   ├── coco.py
│   │   └── common.py
│   ├── generators
│   │   ├── __init__.py
│   │   ├── coco.py
│   │   ├── pascal.py
│   │   └── csv_.py
│   ├── config.py
│   ├── README.md
│   ├── inference.py
│   └── model.py
├── augmentor
│   ├── __init__.py
│   ├── misc.py
│   └── color.py
├── generators
│   ├── __init__.py
│   ├── coco_generator.py
│   ├── voc_generator.py
│   └── csv_generator.py
├── test
│   ├── 004456.jpg
│   ├── 005770.jpg
│   └── 006408.jpg
├── configure.py
├── requirements.txt
├── .github
│   └── stale.yml
├── initializers.py
├── setup.py
├── .gitignore
├── README.md
├── inference.py
├── models
│   ├── vgg.py
│   ├── densenet.py
│   ├── mobilenet.py
│   ├── __init__.py
│   └── resnet.py
├── util_graphs.py
├── callbacks.py
└── losses.py
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/yolo/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/augmentor/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/generators/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/yolo/eval/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/yolo/generators/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/test/004456.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuannianz/FSAF/HEAD/test/004456.jpg
--------------------------------------------------------------------------------
/test/005770.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuannianz/FSAF/HEAD/test/005770.jpg
--------------------------------------------------------------------------------
/test/006408.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuannianz/FSAF/HEAD/test/006408.jpg
--------------------------------------------------------------------------------
/yolo/config.py:
--------------------------------------------------------------------------------
1 | MAX_NUM_GT_BOXES = 100
2 | POS_SCALE = 0.2
3 | IGNORE_SCALE = 0.5
4 | STRIDES = (8, 16, 32)
5 |
6 |
--------------------------------------------------------------------------------
/configure.py:
--------------------------------------------------------------------------------
1 | MAX_NUM_GT_BOXES = 100
2 | POS_SCALE = 0.2
3 | IGNORE_SCALE = 0.5
4 | STRIDES = (8, 16, 32, 64, 128)
5 |
6 |
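7 | # POS_SCALE and IGNORE_SCALE are the shrink factors applied to each ground-truth box to obtain the
8 | # "effective" and "ignoring" regions used when building the anchor-free targets (cf. the FSAF paper);
9 | # STRIDES are the output strides of the FPN levels P3-P7. MAX_NUM_GT_BOXES presumably caps the number
10 | # of ground-truth boxes passed to the target layers per image.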
--------------------------------------------------------------------------------
/utils/compute_overlap.cpython-36m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuannianz/FSAF/HEAD/utils/compute_overlap.cpython-36m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | Keras==2.2.5
2 | keras-resnet==0.2.0
3 | opencv-contrib-python==3.4.2.17
4 | opencv-python==3.4.2.17
5 | Pillow==6.2.0
6 | tensorflow-gpu==1.15.2
7 | progressbar2
8 | git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI
9 |
--------------------------------------------------------------------------------
/.github/stale.yml:
--------------------------------------------------------------------------------
1 | # Number of days of inactivity before an issue becomes stale
2 | daysUntilStale: 5
3 | # Number of days of inactivity before a stale issue is closed
4 | daysUntilClose: 3
5 | # Issues with these labels will never be considered stale
6 | exemptLabels:
7 | - pinned
8 | - security
9 | # Label to use when marking an issue as stale
10 | staleLabel: wontfix
11 | # Comment to post when marking an issue as stale. Set to `false` to disable
12 | markComment: >
13 | This issue has been automatically marked as stale because it has not had
14 | recent activity. It will be closed if no further activity occurs. Thank you
15 | for your contributions.
16 | # Comment to post when closing a stale issue. Set to `false` to disable
17 | closeComment: false
--------------------------------------------------------------------------------
/utils/model.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 Fizyr (https://fizyr.com)
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | def freeze(model):
19 | """
20 | Set all layers in a model to non-trainable.
21 |
22 | The weights for these layers will not be updated during training.
23 |
24 | This function modifies the given model in-place,
25 | but it also returns the modified model to allow easy chaining with other functions.
26 | """
27 | for layer in model.layers:
28 | layer.trainable = False
29 | return model
30 |
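31 | # Example usage: model = freeze(model)  # every layer in `model` is now non-trainable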
--------------------------------------------------------------------------------
/initializers.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 Fizyr (https://fizyr.com)
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import keras
18 |
19 | import numpy as np
20 | import math
21 |
22 |
23 | class PriorProbability(keras.initializers.Initializer):
24 | """ Apply a prior probability to the weights.
25 | """
26 |
27 | def __init__(self, probability=0.01):
28 | self.probability = probability
29 |
30 | def get_config(self):
31 | return {
32 | 'probability': self.probability
33 | }
34 |
35 | def __call__(self, shape, dtype=None):
36 | # set bias to -log((1 - p)/p) for foreground
37 | result = np.ones(shape, dtype=dtype) * -math.log((1 - self.probability) / self.probability)
38 |
39 | return result
40 |
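41 | # Note: with the default probability of 0.01, the tensor is filled with
42 | # -log((1 - 0.01) / 0.01) ≈ -4.595, so a sigmoid applied to a bias initialized this way
43 | # starts out at ≈ 0.01, matching the foreground prior used for the classification head with focal loss.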
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | from setuptools.extension import Extension
3 | from distutils.command.build_ext import build_ext as DistUtilsBuildExt
4 |
5 |
6 | class BuildExtension(setuptools.Command):
7 | description = DistUtilsBuildExt.description
8 | user_options = DistUtilsBuildExt.user_options
9 | boolean_options = DistUtilsBuildExt.boolean_options
10 | help_options = DistUtilsBuildExt.help_options
11 |
12 | def __init__(self, *args, **kwargs):
13 | from setuptools.command.build_ext import build_ext as SetupToolsBuildExt
14 |
15 |         # Bypass __setattr__ to avoid infinite recursion.
16 | self.__dict__['_command'] = SetupToolsBuildExt(*args, **kwargs)
17 |
18 | def __getattr__(self, name):
19 | return getattr(self._command, name)
20 |
21 | def __setattr__(self, name, value):
22 | setattr(self._command, name, value)
23 |
24 | def initialize_options(self, *args, **kwargs):
25 | return self._command.initialize_options(*args, **kwargs)
26 |
27 | def finalize_options(self, *args, **kwargs):
28 | ret = self._command.finalize_options(*args, **kwargs)
29 | import numpy
30 | self.include_dirs.append(numpy.get_include())
31 | return ret
32 |
33 | def run(self, *args, **kwargs):
34 | return self._command.run(*args, **kwargs)
35 |
36 |
37 | extensions = [
38 | Extension(
39 | 'utils.compute_overlap',
40 | ['utils/compute_overlap.pyx']
41 | ),
42 | ]
43 |
44 | setuptools.setup(
45 | cmdclass={'build_ext': BuildExtension},
46 | packages=setuptools.find_packages(),
47 | ext_modules=extensions,
48 | setup_requires=["cython>=0.28", "numpy>=1.14.0"]
49 | )
50 |
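51 | # The compute_overlap extension can be built in place with:
52 | #     python setup.py build_ext --inplace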
--------------------------------------------------------------------------------
/utils/keras_version.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 Fizyr (https://fizyr.com)
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from __future__ import print_function
18 |
19 | import keras
20 | import sys
21 |
22 | minimum_keras_version = 2, 2, 4
23 |
24 |
25 | def keras_version():
26 | """
27 | Get the Keras version.
28 |
29 | Returns
30 |         tuple of (major, minor, patch), e.g. (2, 2, 4).
31 | """
32 | return tuple(map(int, keras.__version__.split('.')))
33 |
34 |
35 | def keras_version_ok():
36 | """
37 |     Check that the current Keras version is at least the minimum required version.
38 | """
39 | return keras_version() >= minimum_keras_version
40 |
41 |
42 | def assert_keras_version():
43 | """
44 | Assert that the Keras version is up to date.
45 | """
46 | detected = keras.__version__
47 | required = '.'.join(map(str, minimum_keras_version))
48 | assert(keras_version() >= minimum_keras_version), 'You are using keras version {}. The minimum required version is {}.'.format(detected, required)
49 |
50 |
51 | def check_keras_version():
52 | """
53 | Check that the Keras version is up to date. If it isn't, print an error message and exit the script.
54 | """
55 | try:
56 | assert_keras_version()
57 | except AssertionError as e:
58 | print(e, file=sys.stderr)
59 | sys.exit(1)
60 |
--------------------------------------------------------------------------------
/utils/compute_overlap.pyx:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Fast R-CNN
3 | # Copyright (c) 2015 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Sergey Karayev
6 | # --------------------------------------------------------
7 |
8 | cimport cython
9 | import numpy as np
10 | cimport numpy as np
11 |
12 |
13 | def compute_overlap(
14 | np.ndarray[double, ndim=2] boxes,
15 | np.ndarray[double, ndim=2] query_boxes
16 | ):
17 | """
18 | Args
19 |         boxes: (N, 4) ndarray of float
20 |         query_boxes: (K, 4) ndarray of float
21 |
22 | Returns
23 | overlaps: (N, K) ndarray of overlap between boxes and query_boxes
24 | """
25 | cdef unsigned int N = boxes.shape[0]
26 | cdef unsigned int K = query_boxes.shape[0]
27 | cdef np.ndarray[double, ndim=2] overlaps = np.zeros((N, K), dtype=np.float64)
28 | cdef double iw, ih, box_area
29 | cdef double ua
30 | cdef unsigned int k, n
31 | for k in range(K):
32 | box_area = (
33 | (query_boxes[k, 2] - query_boxes[k, 0] + 1) *
34 | (query_boxes[k, 3] - query_boxes[k, 1] + 1)
35 | )
36 | for n in range(N):
37 | iw = (
38 | min(boxes[n, 2], query_boxes[k, 2]) -
39 | max(boxes[n, 0], query_boxes[k, 0]) + 1
40 | )
41 | if iw > 0:
42 | ih = (
43 | min(boxes[n, 3], query_boxes[k, 3]) -
44 | max(boxes[n, 1], query_boxes[k, 1]) + 1
45 | )
46 | if ih > 0:
47 | ua = np.float64(
48 | (boxes[n, 2] - boxes[n, 0] + 1) *
49 | (boxes[n, 3] - boxes[n, 1] + 1) +
50 | box_area - iw * ih
51 | )
52 | overlaps[n, k] = iw * ih / ua
53 | return overlaps
54 |
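55 | # Example usage (assumes the extension has been built, e.g. via `python setup.py build_ext --inplace`):
56 | #
57 | #     import numpy as np
58 | #     from utils.compute_overlap import compute_overlap
59 | #
60 | #     boxes = np.array([[0., 0., 10., 10.]])
61 | #     query_boxes = np.array([[5., 5., 15., 15.]])
62 | #     ious = compute_overlap(boxes, query_boxes)  # (1, 1) array of IoU values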
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 Fizyr (https://fizyr.com)
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import configparser
18 | import numpy as np
19 | import keras
20 | from utils.anchors import AnchorParameters
21 |
22 |
23 | def read_config_file(config_path):
24 | config = configparser.ConfigParser()
25 |
26 | with open(config_path, 'r') as file:
27 | config.read_file(file)
28 |
29 | assert 'anchor_parameters' in config, \
30 | "Malformed config file. Verify that it contains the anchor_parameters section."
31 |
32 | config_keys = set(config['anchor_parameters'])
33 | default_keys = set(AnchorParameters.default.__dict__.keys())
34 |
35 | assert config_keys <= default_keys, \
36 | "Malformed config file. These keys are not valid: {}".format(config_keys - default_keys)
37 |
38 | return config
39 |
40 |
41 | def parse_anchor_parameters(config):
42 | ratios = np.array(list(map(float, config['anchor_parameters']['ratios'].split(' '))), keras.backend.floatx())
43 | scales = np.array(list(map(float, config['anchor_parameters']['scales'].split(' '))), keras.backend.floatx())
44 | sizes = list(map(int, config['anchor_parameters']['sizes'].split(' ')))
45 | strides = list(map(int, config['anchor_parameters']['strides'].split(' ')))
46 |
47 | return AnchorParameters(sizes, strides, ratios, scales)
48 |
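49 | # Example of the expected file format (the values below are only illustrative; entries are
50 | # space-separated, matching the parsing above):
51 | #
52 | #     [anchor_parameters]
53 | #     sizes   = 32 64 128 256 512
54 | #     ratios  = 0.5 1 2
55 | #     scales  = 1 1.2 1.6
56 | #     strides = 8 16 32 64 128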
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 | .idea/
106 | datasets/
107 | logs/
108 | !utils/compute_overlap.cpython-36m-x86_64-linux-gnu.so
109 | snapshots/
110 | checkpoints/
111 |
--------------------------------------------------------------------------------
/yolo/README.md:
--------------------------------------------------------------------------------
1 | # FSAF
2 | This is an implementation of [FSAF](https://arxiv.org/abs/1903.00621) in Keras and TensorFlow. The project is based on [qqwweee/keras-yolo3](https://github.com/qqwweee/keras-yolo3) and [fizyr/keras-retinanet](https://github.com/fizyr/keras-retinanet).
3 | Thanks for their hard work.
4 |
5 | As the authors write, the FSAF module can be plugged smoothly into any single-shot detector with an FPN-like structure.
6 | I have also tried it on yolo3. Anchor-free yolo3 (with FSAF) gets performance comparable to the anchor-based counterpart, but you no longer need to pre-compute the anchor sizes.
7 | It is also much better and faster than the one based on retinanet.
8 |
9 | It also converges quite quickly: after the first epoch (batch_size=64, steps=1000), we get an mAP50 of 0.6xxx on the val dataset.
10 |
11 | ## Test
12 | 1. I trained on Pascal VOC2012 trainval.txt + Pascal VOC2007 train.txt, and validated on Pascal VOC2007 val.txt. There are 14041 images for training and 2510 images for validation.
13 | 2. The best evaluation result (score_threshold=0.01, mAP50, image_size=416) on VOC2007 test is 0.8358. I have only trained once.
14 | 3. Pretrained yolo and fsaf weights are available at [baidu netdisk](https://pan.baidu.com/s/1QoGXnajcohj9P4yCVwJ4Yw), extract code: qab7.
15 | 4. Run `python3 yolo/inference.py` to test your own images; set the image path and model path in the script.
16 |
17 | ## Train
18 | ### build dataset (Pascal VOC, other types please refer to [fizyr/keras-retinanet](https://github.com/fizyr/keras-retinanet))
19 | * Download VOC2007 and VOC2012, copy all image files from VOC2007 to VOC2012.
20 | * Append VOC2007 train.txt to VOC2012 trainval.txt.
21 | * Overwrite VOC2012 val.txt by VOC2007 val.txt.
22 | ### train
23 | * **STEP1**: `python3 yolo/train.py --freeze-body darknet --gpu 0 --batch-size 32 --random-transform pascal datasets/VOC2012` to start training with lr=1e-3, then set lr=1e-4 when val mAP stops improving.
24 | * **STEP2**: `python3 yolo/train.py --snapshot --freeze-body none --gpu 0 --batch-size 32 --random-transform pascal datasets/VOC2012` to continue training with lr=1e-5, then set lr=1e-6 when val mAP stops improving.
25 | ## Evaluate
26 | * Run `python3 yolo/eval/common.py` to evaluate; set the model path in the script.
27 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FSAF
2 | This is an implementation of [FSAF](https://arxiv.org/abs/1903.00621) in Keras and TensorFlow. The project is based on [fizyr/keras-retinanet](https://github.com/fizyr/keras-retinanet)
3 | and fsaf branch of [zccstig/mmdetection](https://github.com/zccstig/mmdetection/tree/fsaf).
4 | Thanks for their hard work.
5 |
6 | As the authors write, **the FSAF module can be plugged smoothly into any single-shot detector with an FPN-like structure**.
7 | I have also tried it on [yolo3](yolo). Anchor-free yolo3 (with FSAF) gets performance comparable to the anchor-based counterpart, but you no longer need to pre-compute the anchor sizes.
8 | It is also much better and faster than the one based on retinanet.
9 |
10 | **Updates**
11 | - [03/05/2020] The author of the paper has released a new paper [SAPD](https://arxiv.org/abs/1911.12448), which is based on FSAF.
12 | I have implemented it at [xuannianz/SAPD](https://github.com/xuannianz/SAPD).
13 |
14 |
15 | ## Test
16 | 1. I trained on Pascal VOC2012 trainval.txt + Pascal VOC2007 train.txt, and validated on Pascal VOC2007 val.txt. There are 14041 images for training and 2510 images for validation.
17 | 2. The best evaluation results (score_threshold=0.05) on VOC2007 test are:
18 |
19 | | backbone | mAP50 |
20 | | ---- | ---- |
21 | | resnet50 | 0.7248 |
22 | | resnet101 | 0.7652 |
23 |
24 | 3. Pretrained models are here.
25 | [baidu netdisk](https://pan.baidu.com/s/1ZdHvR-03XqHvxWG0rLCw1g) extract code: rbrr
26 | [Google Drive](https://drive.google.com/open?id=1Hcgxp5OwqNsAx-HYgcIhLat1OOHKnvJ2)
27 |
28 | 4. Run `python3 inference.py` to test your own images; set the image path and model path in the script.
29 |
30 | 
31 | 
32 | 
33 |
34 |
35 | ## Train
36 | ### build dataset (Pascal VOC, other types please refer to [fizyr/keras-retinanet](https://github.com/fizyr/keras-retinanet))
37 | * Download VOC2007 and VOC2012, copy all image files from VOC2007 to VOC2012.
38 | * Append VOC2007 train.txt to VOC2012 trainval.txt.
39 | * Overwrite VOC2012 val.txt by VOC2007 val.txt.
40 | ### train
41 | * `python3 train.py --backbone resnet50 --gpu 0 --random-transform pascal datasets/VOC2012` to start training.
42 | ## Evaluate
43 | * Run `python3 utils/eval.py` to evaluate; set the model path in the script.
44 |
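45 | ### appendix: dataset merge sketch
46 | A minimal Python sketch of the "build dataset" steps above, assuming VOC2007 and VOC2012 are extracted under `datasets/` as used by the train command (paths are illustrative; adjust them to your layout):
47 | ```python
48 | import glob
49 | import shutil
50 |
51 | voc2007 = 'datasets/VOC2007'
52 | voc2012 = 'datasets/VOC2012'
53 |
54 | # copy all image files from VOC2007 to VOC2012
55 | for path in glob.glob('{}/JPEGImages/*.jpg'.format(voc2007)):
56 |     shutil.copy(path, '{}/JPEGImages/'.format(voc2012))
57 |
58 | # append VOC2007 train.txt to VOC2012 trainval.txt
59 | with open('{}/ImageSets/Main/train.txt'.format(voc2007)) as src, \
60 |         open('{}/ImageSets/Main/trainval.txt'.format(voc2012), 'a') as dst:
61 |     dst.write(src.read())
62 |
63 | # overwrite VOC2012 val.txt with VOC2007 val.txt
64 | shutil.copy('{}/ImageSets/Main/val.txt'.format(voc2007),
65 |             '{}/ImageSets/Main/val.txt'.format(voc2012))
66 | ```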
--------------------------------------------------------------------------------
/utils/colors.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 |
4 | def label_color(label):
5 | """ Return a color from a set of predefined colors. Contains 80 colors in total.
6 |
7 | Args
8 | label: The label to get the color for.
9 |
10 | Returns
11 | A list of three values representing a RGB color.
12 |
13 | If no color is defined for a certain label, the color green is returned and a warning is printed.
14 | """
15 | if label < len(colors):
16 | return colors[label]
17 | else:
18 | warnings.warn('Label {} has no color, returning default.'.format(label))
19 | return (0, 255, 0)
20 |
21 |
22 | """
23 | Generated using:
24 |
25 | ```
26 | colors = [list((matplotlib.colors.hsv_to_rgb([x, 1.0, 1.0]) * 255).astype(int)) for x in np.arange(0, 1, 1.0 / 80)]
27 | shuffle(colors)
28 | pprint(colors)
29 | ```
30 | """
31 | colors = [
32 | [31 , 0 , 255] ,
33 | [0 , 159 , 255] ,
34 | [255 , 95 , 0] ,
35 | [255 , 19 , 0] ,
36 | [255 , 0 , 0] ,
37 | [255 , 38 , 0] ,
38 | [0 , 255 , 25] ,
39 | [255 , 0 , 133] ,
40 | [255 , 172 , 0] ,
41 | [108 , 0 , 255] ,
42 | [0 , 82 , 255] ,
43 | [0 , 255 , 6] ,
44 | [255 , 0 , 152] ,
45 | [223 , 0 , 255] ,
46 | [12 , 0 , 255] ,
47 | [0 , 255 , 178] ,
48 | [108 , 255 , 0] ,
49 | [184 , 0 , 255] ,
50 | [255 , 0 , 76] ,
51 | [146 , 255 , 0] ,
52 | [51 , 0 , 255] ,
53 | [0 , 197 , 255] ,
54 | [255 , 248 , 0] ,
55 | [255 , 0 , 19] ,
56 | [255 , 0 , 38] ,
57 | [89 , 255 , 0] ,
58 | [127 , 255 , 0] ,
59 | [255 , 153 , 0] ,
60 | [0 , 255 , 255] ,
61 | [0 , 255 , 216] ,
62 | [0 , 255 , 121] ,
63 | [255 , 0 , 248] ,
64 | [70 , 0 , 255] ,
65 | [0 , 255 , 159] ,
66 | [0 , 216 , 255] ,
67 | [0 , 6 , 255] ,
68 | [0 , 63 , 255] ,
69 | [31 , 255 , 0] ,
70 | [255 , 57 , 0] ,
71 | [255 , 0 , 210] ,
72 | [0 , 255 , 102] ,
73 | [242 , 255 , 0] ,
74 | [255 , 191 , 0] ,
75 | [0 , 255 , 63] ,
76 | [255 , 0 , 95] ,
77 | [146 , 0 , 255] ,
78 | [184 , 255 , 0] ,
79 | [255 , 114 , 0] ,
80 | [0 , 255 , 235] ,
81 | [255 , 229 , 0] ,
82 | [0 , 178 , 255] ,
83 | [255 , 0 , 114] ,
84 | [255 , 0 , 57] ,
85 | [0 , 140 , 255] ,
86 | [0 , 121 , 255] ,
87 | [12 , 255 , 0] ,
88 | [255 , 210 , 0] ,
89 | [0 , 255 , 44] ,
90 | [165 , 255 , 0] ,
91 | [0 , 25 , 255] ,
92 | [0 , 255 , 140] ,
93 | [0 , 101 , 255] ,
94 | [0 , 255 , 82] ,
95 | [223 , 255 , 0] ,
96 | [242 , 0 , 255] ,
97 | [89 , 0 , 255] ,
98 | [165 , 0 , 255] ,
99 | [70 , 255 , 0] ,
100 | [255 , 0 , 172] ,
101 | [255 , 76 , 0] ,
102 | [203 , 255 , 0] ,
103 | [204 , 0 , 255] ,
104 | [255 , 0 , 229] ,
105 | [255 , 133 , 0] ,
106 | [127 , 0 , 255] ,
107 | [0 , 235 , 255] ,
108 | [0 , 255 , 197] ,
109 | [255 , 0 , 191] ,
110 | [0 , 44 , 255] ,
111 | [50 , 255 , 0]
112 | ]
113 |
--------------------------------------------------------------------------------
/utils/coco_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 Fizyr (https://fizyr.com)
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from pycocotools.cocoeval import COCOeval
18 |
19 | import keras
20 | import numpy as np
21 | import json
22 |
23 | import progressbar
24 |
25 | assert (callable(progressbar.progressbar)), "Using wrong progressbar module, install 'progressbar2' instead."
26 |
27 |
28 | def evaluate_coco(generator, model, threshold=0.05):
29 | """ Use the pycocotools to evaluate a COCO model on a dataset.
30 |
31 | Args
32 | generator : The generator for generating the evaluation data.
33 | model : The model to evaluate.
34 | threshold : The score threshold to use.
35 | """
36 | # start collecting results
37 | results = []
38 | image_ids = []
39 | for index in progressbar.progressbar(range(generator.size()), prefix='COCO evaluation: '):
40 | image = generator.load_image(index)
41 | image = generator.preprocess_image(image)
42 | image, scale = generator.resize_image(image)
43 |
44 | if keras.backend.image_data_format() == 'channels_first':
45 | image = image.transpose((2, 0, 1))
46 |
47 | # run network
48 | boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))
49 |
50 | # correct boxes for image scale
51 | boxes /= scale
52 |
53 | # change to (x, y, w, h) (MS COCO standard)
54 | boxes[:, :, 2] -= boxes[:, :, 0]
55 | boxes[:, :, 3] -= boxes[:, :, 1]
56 |
57 | # compute predicted labels and scores
58 | for box, score, label in zip(boxes[0], scores[0], labels[0]):
59 | # scores are sorted, so we can break
60 | if score < threshold:
61 | break
62 |
63 | # append detection for each positively labeled class
64 | image_result = {
65 | 'image_id': generator.image_ids[index],
66 | 'category_id': generator.label_to_coco_label(label),
67 | 'score': float(score),
68 | 'bbox': box.tolist(),
69 | }
70 |
71 | # append detection to results
72 | results.append(image_result)
73 |
74 | # append image to list of processed images
75 | image_ids.append(generator.image_ids[index])
76 |
77 | if not len(results):
78 | return
79 |
80 | # write output
81 | json.dump(results, open('{}_bbox_results.json'.format(generator.set_name), 'w'), indent=4)
82 | json.dump(image_ids, open('{}_processed_image_ids.json'.format(generator.set_name), 'w'), indent=4)
83 |
84 | # load results in COCO evaluation tool
85 | coco_true = generator.coco
86 | coco_pred = coco_true.loadRes('{}_bbox_results.json'.format(generator.set_name))
87 |
88 | # run COCO evaluation
89 | coco_eval = COCOeval(coco_true, coco_pred, 'bbox')
90 | coco_eval.params.imgIds = image_ids
91 | coco_eval.evaluate()
92 | coco_eval.accumulate()
93 | coco_eval.summarize()
94 | return coco_eval.stats
95 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import keras
2 | import models
3 | from utils.image import read_image_bgr, preprocess_image, resize_image
4 | from utils.visualization import draw_box, draw_caption
5 | from utils.colors import label_color
6 |
7 | # import miscellaneous modules
8 | import matplotlib.pyplot as plt
9 | import cv2
10 | import os
11 | import numpy as np
12 | import time
13 | import glob
14 | import os.path as osp
15 |
16 | # set tf backend to allow memory to grow, instead of claiming everything
17 | import tensorflow as tf
18 |
19 |
20 | def get_session():
21 | config = tf.ConfigProto()
22 | config.gpu_options.allow_growth = True
23 | return tf.Session(config=config)
24 |
25 |
26 | # use this environment flag to change which GPU to use
27 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
28 |
29 | # set the modified tf session as backend in keras
30 | keras.backend.set_session(get_session())
31 | # adjust this to point to your downloaded/trained model
32 | # models can be downloaded here: https://github.com/fizyr/keras-retinanet/releases
33 | model_path = '/home/adam/workspace/github/xuannianz/carrot/fsaf/snapshots/2019-10-05/resnet101_pascal_47_0.7652.h5'
34 |
35 | # load retinanet model
36 | # model = models.load_model(model_path, backbone_name='resnet101')
37 |
38 | # if the model is not converted to an inference model, use the line below
39 | # see: https://github.com/fizyr/keras-retinanet#converting-a-training-model-to-inference-model
40 | from models.resnet import resnet_fsaf
41 | from models.retinanet import fsaf_bbox
42 | fsaf = resnet_fsaf(num_classes=20, backbone='resnet101')
43 | model = fsaf_bbox(fsaf)
44 | model.load_weights(model_path, by_name=True)
45 | # load label to names mapping for visualization purposes
46 | voc_classes = {
47 | 'aeroplane': 0,
48 | 'bicycle': 1,
49 | 'bird': 2,
50 | 'boat': 3,
51 | 'bottle': 4,
52 | 'bus': 5,
53 | 'car': 6,
54 | 'cat': 7,
55 | 'chair': 8,
56 | 'cow': 9,
57 | 'diningtable': 10,
58 | 'dog': 11,
59 | 'horse': 12,
60 | 'motorbike': 13,
61 | 'person': 14,
62 | 'pottedplant': 15,
63 | 'sheep': 16,
64 | 'sofa': 17,
65 | 'train': 18,
66 | 'tvmonitor': 19
67 | }
68 | labels_to_names = {}
69 | for key, value in voc_classes.items():
70 | labels_to_names[value] = key
71 | # load image
72 | image_paths = glob.glob('datasets/voc_test/VOC2007/JPEGImages/*.jpg')
73 | for image_path in image_paths:
74 | print('Handling {}'.format(image_path))
75 | image = read_image_bgr(image_path)
76 |
77 | # copy to draw on
78 | draw = image.copy()
79 |
80 | # preprocess image for network
81 | image = preprocess_image(image)
82 | image, scale = resize_image(image)
83 |
84 | # process image
85 | start = time.time()
86 | # locations, feature_shapes = model.predict_on_batch(np.expand_dims(image, axis=0))
87 | boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))
88 | print("processing time: ", time.time() - start)
89 |
90 | # correct for image scale
91 | boxes /= scale
92 | labels_to_locations = {}
93 | # visualize detections
94 | for box, score, label in zip(boxes[0], scores[0], labels[0]):
95 | # scores are sorted so we can break
96 | if score < 0.5:
97 | break
98 | start_x = int(box[0])
99 | start_y = int(box[1])
100 | end_x = int(box[2])
101 | end_y = int(box[3])
102 | color = label_color(label)
103 |
104 | b = box.astype(int)
105 | draw_box(draw, b, color=color)
106 |
107 | caption = "{} {:.3f}".format(labels_to_names[label], score)
108 | draw_caption(draw, b, caption)
109 |
110 | cv2.namedWindow('image', cv2.WINDOW_NORMAL)
111 | cv2.imshow('image', draw)
112 | key = cv2.waitKey(0)
113 | if int(key) == 121:
114 | image_fname = osp.split(image_path)[-1]
115 | cv2.imwrite('test/{}'.format(image_fname), draw)
116 |
117 |
--------------------------------------------------------------------------------
/models/vgg.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 cgratie (https://github.com/cgratie/)
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | import keras
19 | from keras.utils import get_file
20 |
21 | from . import retinanet
22 | from . import Backbone
23 | from ..utils.image import preprocess_image
24 |
25 |
26 | class VGGBackbone(Backbone):
27 | """ Describes backbone information and provides utility functions.
28 | """
29 |
30 | def retinanet(self, *args, **kwargs):
31 | """ Returns a retinanet model using the correct backbone.
32 | """
33 | return vgg_retinanet(*args, backbone=self.backbone, **kwargs)
34 |
35 | def download_imagenet(self):
36 | """ Downloads ImageNet weights and returns path to weights file.
37 | Weights can be downloaded at https://github.com/fizyr/keras-models/releases .
38 | """
39 | if self.backbone == 'vgg16':
40 | resource = keras.applications.vgg16.vgg16.WEIGHTS_PATH_NO_TOP
41 | checksum = '6d6bbae143d832006294945121d1f1fc'
42 | elif self.backbone == 'vgg19':
43 | resource = keras.applications.vgg19.vgg19.WEIGHTS_PATH_NO_TOP
44 | checksum = '253f8cb515780f3b799900260a226db6'
45 | else:
46 | raise ValueError("Backbone '{}' not recognized.".format(self.backbone))
47 |
48 | return get_file(
49 | '{}_weights_tf_dim_ordering_tf_kernels_notop.h5'.format(self.backbone),
50 | resource,
51 | cache_subdir='models',
52 | file_hash=checksum
53 | )
54 |
55 | def validate(self):
56 | """ Checks whether the backbone string is correct.
57 | """
58 | allowed_backbones = ['vgg16', 'vgg19']
59 |
60 | if self.backbone not in allowed_backbones:
61 | raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(self.backbone, allowed_backbones))
62 |
63 | def preprocess_image(self, inputs):
64 | """ Takes as input an image and prepares it for being passed through the network.
65 | """
66 | return preprocess_image(inputs, mode='caffe')
67 |
68 |
69 | def vgg_retinanet(num_classes, backbone='vgg16', inputs=None, modifier=None, **kwargs):
70 | """ Constructs a retinanet model using a vgg backbone.
71 |
72 | Args
73 | num_classes: Number of classes to predict.
74 | backbone: Which backbone to use (one of ('vgg16', 'vgg19')).
75 | inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)).
76 | modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example).
77 |
78 | Returns
79 | RetinaNet model with a VGG backbone.
80 | """
81 | # choose default input
82 | if inputs is None:
83 | inputs = keras.layers.Input(shape=(None, None, 3))
84 |
85 | # create the vgg backbone
86 | if backbone == 'vgg16':
87 | vgg = keras.applications.VGG16(input_tensor=inputs, include_top=False, weights=None)
88 | elif backbone == 'vgg19':
89 | vgg = keras.applications.VGG19(input_tensor=inputs, include_top=False, weights=None)
90 | else:
91 | raise ValueError("Backbone '{}' not recognized.".format(backbone))
92 |
93 | if modifier:
94 | vgg = modifier(vgg)
95 |
96 | # create the full model
97 | layer_names = ["block3_pool", "block4_pool", "block5_pool"]
98 | layer_outputs = [vgg.get_layer(name).output for name in layer_names]
99 | return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=layer_outputs, **kwargs)
100 |
--------------------------------------------------------------------------------
/yolo/inference.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import glob
3 | import keras
4 | import numpy as np
5 | import os
6 | import os.path as osp
7 | import tensorflow as tf
8 | import time
9 |
10 | from utils.visualization import draw_box, draw_caption
11 | from utils.colors import label_color
12 | from yolo.model import yolo_body
13 |
14 |
15 | # set tf backend to allow memory to grow, instead of claiming everything
16 | def get_session():
17 | config = tf.ConfigProto()
18 | config.gpu_options.allow_growth = True
19 | return tf.Session(config=config)
20 |
21 |
22 | def preprocess_image(image, image_size=416):
23 | image_height, image_width = image.shape[:2]
24 | if image_height > image_width:
25 | scale = image_size / image_height
26 | resized_height = image_size
27 | resized_width = int(image_width * scale)
28 | else:
29 | scale = image_size / image_width
30 | resized_height = int(image_height * scale)
31 | resized_width = image_size
32 | image = cv2.resize(image, (resized_width, resized_height))
33 | new_image = np.ones((image_size, image_size, 3), dtype=np.float32) * 128.
34 | offset_h = (image_size - resized_height) // 2
35 | offset_w = (image_size - resized_width) // 2
36 | new_image[offset_h:offset_h + resized_height, offset_w:offset_w + resized_width] = image.astype(np.float32)
37 | new_image /= 255.
38 | return new_image, scale, offset_h, offset_w
39 |
40 |
41 | # use this environment flag to change which GPU to use
42 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
43 |
44 | # set the modified tf session as backend in keras
45 | keras.backend.set_session(get_session())
46 |
47 | model_path = 'pascal_18_6.4112_6.5125_0.8319_0.8358.h5'
48 |
49 | model, prediction_model = yolo_body(num_classes=20)
50 |
51 | prediction_model.load_weights(model_path, by_name=True)
52 |
53 | # load label to names mapping for visualization purposes
54 | voc_classes = {
55 | 'aeroplane': 0,
56 | 'bicycle': 1,
57 | 'bird': 2,
58 | 'boat': 3,
59 | 'bottle': 4,
60 | 'bus': 5,
61 | 'car': 6,
62 | 'cat': 7,
63 | 'chair': 8,
64 | 'cow': 9,
65 | 'diningtable': 10,
66 | 'dog': 11,
67 | 'horse': 12,
68 | 'motorbike': 13,
69 | 'person': 14,
70 | 'pottedplant': 15,
71 | 'sheep': 16,
72 | 'sofa': 17,
73 | 'train': 18,
74 | 'tvmonitor': 19
75 | }
76 | labels_to_names = {}
77 | for key, value in voc_classes.items():
78 | labels_to_names[value] = key
79 | # load image
80 | image_paths = glob.glob('datasets/voc_test/VOC2007/JPEGImages/*.jpg')
81 | for image_path in image_paths:
82 | print('Handling {}'.format(image_path))
83 | image = cv2.imread(image_path)
84 |
85 | # copy to draw on
86 | draw = image.copy()
87 |
88 | # preprocess image for network
89 | image, scale, offset_h, offset_w = preprocess_image(image)
90 |
91 | # process image
92 | start = time.time()
93 | # locations, feature_shapes = model.predict_on_batch(np.expand_dims(image, axis=0))
94 | boxes, scores, labels = prediction_model.predict_on_batch(np.expand_dims(image, axis=0))
95 | print("processing time: ", time.time() - start)
96 |
97 | # correct boxes for image scale
98 | boxes[0, :, [0, 2]] -= offset_w
99 | boxes[0, :, [1, 3]] -= offset_h
100 | boxes /= scale
101 |
102 | labels_to_locations = {}
103 | # visualize detections
104 | for box, score, label in zip(boxes[0], scores[0], labels[0]):
105 | # scores are sorted so we can break
106 | if score < 0.5:
107 | break
108 | start_x = int(box[0])
109 | start_y = int(box[1])
110 | end_x = int(box[2])
111 | end_y = int(box[3])
112 | color = label_color(label)
113 |
114 | b = box.astype(int)
115 | draw_box(draw, b, color=color)
116 |
117 | caption = "{} {:.3f}".format(labels_to_names[label], score)
118 | draw_caption(draw, b, caption)
119 |
120 | cv2.namedWindow('image', cv2.WINDOW_NORMAL)
121 | cv2.imshow('image', draw)
122 | key = cv2.waitKey(0)
123 | if int(key) == 121:
124 | image_fname = osp.split(image_path)[-1]
125 | cv2.imwrite('test/{}'.format(image_fname), draw)
126 |
--------------------------------------------------------------------------------
/yolo/eval/pascal.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 Fizyr (https://fizyr.com)
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import keras
18 | from .common import evaluate
19 |
20 |
21 | class Evaluate(keras.callbacks.Callback):
22 | """
23 | Evaluation callback for arbitrary datasets.
24 | """
25 |
26 | def __init__(
27 | self,
28 | generator,
29 | model,
30 | iou_threshold=0.5,
31 | score_threshold=0.01,
32 | max_detections=100,
33 | save_path=None,
34 | tensorboard=None,
35 | weighted_average=False,
36 | verbose=1
37 | ):
38 | """
39 | Evaluate a given dataset using a given model at the end of every epoch during training.
40 |
41 | Args:
42 | generator: The generator that represents the dataset to evaluate.
43 | iou_threshold: The threshold used to consider when a detection is positive or negative.
44 | score_threshold: The score confidence threshold to use for detections.
45 | max_detections: The maximum number of detections to use per image.
46 | save_path: The path to save images with visualized detections to.
47 | tensorboard: Instance of keras.callbacks.TensorBoard used to log the mAP value.
48 | weighted_average: Compute the mAP using the weighted average of precisions among classes.
49 | verbose: Set the verbosity level, by default this is set to 1.
50 | """
51 | self.generator = generator
52 | self.iou_threshold = iou_threshold
53 | self.score_threshold = score_threshold
54 | self.max_detections = max_detections
55 | self.save_path = save_path
56 | self.tensorboard = tensorboard
57 | self.weighted_average = weighted_average
58 | self.verbose = verbose
59 | self.active_model = model
60 |
61 | super(Evaluate, self).__init__()
62 |
63 | def on_epoch_end(self, epoch, logs=None):
64 | logs = logs or {}
65 |
66 | # run evaluation
67 | average_precisions = evaluate(
68 | self.generator,
69 | self.active_model,
70 | iou_threshold=self.iou_threshold,
71 | score_threshold=self.score_threshold,
72 | max_detections=self.max_detections,
73 | visualize=False
74 | )
75 |
76 | # compute per class average precision
77 | total_instances = []
78 | precisions = []
79 | for label, (average_precision, num_annotations) in average_precisions.items():
80 | if self.verbose == 1:
81 | print('{:.0f} instances of class'.format(num_annotations),
82 | self.generator.label_to_name(label), 'with average precision: {:.4f}'.format(average_precision))
83 | total_instances.append(num_annotations)
84 | precisions.append(average_precision)
85 | if self.weighted_average:
86 | self.mean_ap = sum([a * b for a, b in zip(total_instances, precisions)]) / sum(total_instances)
87 | else:
88 | self.mean_ap = sum(precisions) / sum(x > 0 for x in total_instances)
89 |
90 | if self.tensorboard is not None and self.tensorboard.writer is not None:
91 | import tensorflow as tf
92 | summary = tf.Summary()
93 | summary_value = summary.value.add()
94 | summary_value.simple_value = self.mean_ap
95 | summary_value.tag = "mAP"
96 | self.tensorboard.writer.add_summary(summary, epoch)
97 |
98 | logs['mAP'] = self.mean_ap
99 |
100 | if self.verbose == 1:
101 | print('mAP: {:.4f}'.format(self.mean_ap))
102 |
--------------------------------------------------------------------------------
/models/densenet.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2018 vidosits (https://github.com/vidosits/)
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import keras
18 | from keras.applications import densenet
19 | from keras.utils import get_file
20 |
21 | from . import retinanet
22 | from . import Backbone
23 | from ..utils.image import preprocess_image
24 |
25 |
26 | allowed_backbones = {
27 | 'densenet121': ([6, 12, 24, 16], densenet.DenseNet121),
28 | 'densenet169': ([6, 12, 32, 32], densenet.DenseNet169),
29 | 'densenet201': ([6, 12, 48, 32], densenet.DenseNet201),
30 | }
31 |
32 |
33 | class DenseNetBackbone(Backbone):
34 | """ Describes backbone information and provides utility functions.
35 | """
36 |
37 | def retinanet(self, *args, **kwargs):
38 | """ Returns a retinanet model using the correct backbone.
39 | """
40 | return densenet_retinanet(*args, backbone=self.backbone, **kwargs)
41 |
42 | def download_imagenet(self):
43 | """ Download pre-trained weights for the specified backbone name.
44 | This name is in the format {backbone}_weights_tf_dim_ordering_tf_kernels_notop
45 | where backbone is the densenet + number of layers (e.g. densenet121).
46 | For more info check the explanation from the keras densenet script itself:
47 | https://github.com/keras-team/keras/blob/master/keras/applications/densenet.py
48 | """
49 | origin = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/'
50 | file_name = '{}_weights_tf_dim_ordering_tf_kernels_notop.h5'
51 |
52 | # load weights
53 | if keras.backend.image_data_format() == 'channels_first':
54 | raise ValueError('Weights for "channels_first" format are not available.')
55 |
56 | weights_url = origin + file_name.format(self.backbone)
57 | return get_file(file_name.format(self.backbone), weights_url, cache_subdir='models')
58 |
59 | def validate(self):
60 | """ Checks whether the backbone string is correct.
61 | """
62 | backbone = self.backbone.split('_')[0]
63 |
64 | if backbone not in allowed_backbones:
65 | raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, allowed_backbones.keys()))
66 |
67 | def preprocess_image(self, inputs):
68 | """ Takes as input an image and prepares it for being passed through the network.
69 | """
70 | return preprocess_image(inputs, mode='tf')
71 |
72 |
73 | def densenet_retinanet(num_classes, backbone='densenet121', inputs=None, modifier=None, **kwargs):
74 | """ Constructs a retinanet model using a densenet backbone.
75 |
76 | Args
77 | num_classes: Number of classes to predict.
78 | backbone: Which backbone to use (one of ('densenet121', 'densenet169', 'densenet201')).
79 | inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)).
80 | modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example).
81 |
82 | Returns
83 | RetinaNet model with a DenseNet backbone.
84 | """
85 | # choose default input
86 | if inputs is None:
87 | inputs = keras.layers.Input((None, None, 3))
88 |
89 | blocks, creator = allowed_backbones[backbone]
90 | model = creator(input_tensor=inputs, include_top=False, pooling=None, weights=None)
91 |
92 | # get last conv layer from the end of each dense block
93 | layer_outputs = [model.get_layer(name='conv{}_block{}_concat'.format(idx + 2, block_num)).output for idx, block_num in enumerate(blocks)]
94 |
95 | # create the densenet backbone
96 | model = keras.models.Model(inputs=inputs, outputs=layer_outputs[1:], name=model.name)
97 |
98 | # invoke modifier if given
99 | if modifier:
100 | model = modifier(model)
101 |
102 | # create the full model
103 | model = retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=model.outputs, **kwargs)
104 |
105 | return model
106 |
--------------------------------------------------------------------------------
/models/mobilenet.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 lvaleriu (https://github.com/lvaleriu/)
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import keras
18 | from keras.applications import mobilenet
19 | from keras.utils import get_file
20 | from ..utils.image import preprocess_image
21 |
22 | from . import retinanet
23 | from . import Backbone
24 |
25 |
26 | class MobileNetBackbone(Backbone):
27 | """ Describes backbone information and provides utility functions.
28 | """
29 |
30 | allowed_backbones = ['mobilenet128', 'mobilenet160', 'mobilenet192', 'mobilenet224']
31 |
32 | def retinanet(self, *args, **kwargs):
33 | """ Returns a retinanet model using the correct backbone.
34 | """
35 | return mobilenet_retinanet(*args, backbone=self.backbone, **kwargs)
36 |
37 | def download_imagenet(self):
38 | """ Download pre-trained weights for the specified backbone name.
39 | This name is in the format mobilenet{rows}_{alpha} where rows is the
40 | imagenet shape dimension and 'alpha' controls the width of the network.
41 | For more info check the explanation from the keras mobilenet script itself.
42 | """
43 |
44 | alpha = float(self.backbone.split('_')[1])
45 | rows = int(self.backbone.split('_')[0].replace('mobilenet', ''))
46 |
47 | # load weights
48 | if keras.backend.image_data_format() == 'channels_first':
49 |             raise ValueError('Weights for "channels_first" format '
50 |                              'are not available.')
51 | if alpha == 1.0:
52 | alpha_text = '1_0'
53 | elif alpha == 0.75:
54 | alpha_text = '7_5'
55 | elif alpha == 0.50:
56 | alpha_text = '5_0'
57 | else:
58 | alpha_text = '2_5'
59 |
60 | model_name = 'mobilenet_{}_{}_tf_no_top.h5'.format(alpha_text, rows)
61 | weights_url = mobilenet.mobilenet.BASE_WEIGHT_PATH + model_name
62 | weights_path = get_file(model_name, weights_url, cache_subdir='models')
63 |
64 | return weights_path
65 |
66 | def validate(self):
67 | """ Checks whether the backbone string is correct.
68 | """
69 | backbone = self.backbone.split('_')[0]
70 |
71 | if backbone not in MobileNetBackbone.allowed_backbones:
72 | raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, MobileNetBackbone.allowed_backbones))
73 |
74 | def preprocess_image(self, inputs):
75 | """ Takes as input an image and prepares it for being passed through the network.
76 | """
77 | return preprocess_image(inputs, mode='tf')
78 |
79 |
80 | def mobilenet_retinanet(num_classes, backbone='mobilenet224_1.0', inputs=None, modifier=None, **kwargs):
81 | """ Constructs a retinanet model using a mobilenet backbone.
82 |
83 | Args
84 | num_classes: Number of classes to predict.
85 | backbone: Which backbone to use (one of ('mobilenet128', 'mobilenet160', 'mobilenet192', 'mobilenet224')).
86 | inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)).
87 | modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example).
88 |
89 | Returns
90 | RetinaNet model with a MobileNet backbone.
91 | """
92 | alpha = float(backbone.split('_')[1])
93 |
94 | # choose default input
95 | if inputs is None:
96 | inputs = keras.layers.Input((None, None, 3))
97 |
98 | backbone = mobilenet.MobileNet(input_tensor=inputs, alpha=alpha, include_top=False, pooling=None, weights=None)
99 |
100 | # create the full model
101 | layer_names = ['conv_pw_5_relu', 'conv_pw_11_relu', 'conv_pw_13_relu']
102 | layer_outputs = [backbone.get_layer(name).output for name in layer_names]
103 | backbone = keras.models.Model(inputs=inputs, outputs=layer_outputs, name=backbone.name)
104 |
105 | # invoke modifier if given
106 | if modifier:
107 | backbone = modifier(backbone)
108 |
109 | return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=backbone.outputs, **kwargs)
110 |
--------------------------------------------------------------------------------
/utils/visualization.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 Fizyr (https://fizyr.com)
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import cv2
18 | import numpy as np
19 |
20 | from .colors import label_color
21 |
22 |
23 | def draw_box(image, box, color, thickness=2):
24 | """ Draws a box on an image with a given color.
25 |
26 | # Arguments
27 | image : The image to draw on.
28 | box : A list of 4 elements (x1, y1, x2, y2).
29 | color : The color of the box.
30 | thickness : The thickness of the lines to draw a box with.
31 | """
32 | b = np.array(box).astype(int)
33 | cv2.rectangle(image, (b[0], b[1]), (b[2], b[3]), color, thickness, cv2.LINE_AA)
34 |
35 |
36 | def draw_caption(image, box, caption):
37 | """ Draws a caption above the box in an image.
38 |
39 | # Arguments
40 | image : The image to draw on.
41 | box : A list of 4 elements (x1, y1, x2, y2).
42 | caption : String containing the text to draw.
43 | """
44 | b = np.array(box).astype(int)
45 | ret, baseline = cv2.getTextSize(caption, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
46 | cv2.rectangle(image, (b[0], b[3] - ret[1] - baseline), (b[0] + ret[0], b[3]), (255, 255, 255), -1)
47 | cv2.putText(image, caption, (b[0], b[3] - baseline), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
48 |
49 |
50 | def draw_boxes(image, boxes, color, thickness=2):
51 | """ Draws boxes on an image with a given color.
52 |
53 | # Arguments
54 | image : The image to draw on.
55 | boxes : A [N, 4] matrix (x1, y1, x2, y2).
56 | color : The color of the boxes.
57 | thickness : The thickness of the lines to draw boxes with.
58 | """
59 | for b in boxes:
60 | draw_box(image, b, color, thickness=thickness)
61 |
62 |
63 | def draw_detections(image, boxes, scores, labels, color=None, label_to_name=None, score_threshold=0.5):
64 | """ Draws detections in an image.
65 |
66 | # Arguments
67 | image : The image to draw on.
68 | boxes : A [N, 4] matrix (x1, y1, x2, y2).
69 | scores : A list of N classification scores.
70 | labels : A list of N labels.
71 | color : The color of the boxes. By default the color from keras_retinanet.utils.colors.label_color will be used.
72 | label_to_name : (optional) Functor for mapping a label to a name.
73 | score_threshold : Threshold used for determining what detections to draw.
74 | """
75 | selection = np.where(scores > score_threshold)[0]
76 |
77 | for i in selection:
78 | c = color if color is not None else label_color(labels[i])
79 | draw_box(image, boxes[i, :], color=c)
80 |
81 | # draw labels
82 | caption = (label_to_name(labels[i]) if label_to_name else labels[i]) + ': {0:.2f}'.format(scores[i])
83 | draw_caption(image, boxes[i, :], caption)
84 |
85 |
86 | def draw_annotations(image, annotations, color=(0, 255, 0), label_to_name=None):
87 | """ Draws annotations in an image.
88 |
89 | # Arguments
90 | image : The image to draw on.
91 | annotations : A [N, 5] matrix (x1, y1, x2, y2, label) or dictionary containing bboxes (shaped [N, 4]) and labels (shaped [N]).
92 | color : The color of the boxes. By default the color from keras_retinanet.utils.colors.label_color will be used.
93 | label_to_name : (optional) Functor for mapping a label to a name.
94 | """
95 | if isinstance(annotations, np.ndarray):
96 | annotations = {'bboxes': annotations[:, :4], 'labels': annotations[:, 4]}
97 |
98 | assert('bboxes' in annotations)
99 | assert('labels' in annotations)
100 | assert(annotations['bboxes'].shape[0] == annotations['labels'].shape[0])
101 |
102 | for i in range(annotations['bboxes'].shape[0]):
103 | label = annotations['labels'][i]
104 | c = color if color is not None else label_color(label)
105 | caption = '{}'.format(label_to_name(label) if label_to_name else label)
106 | draw_caption(image, annotations['bboxes'][i], caption)
107 | draw_box(image, annotations['bboxes'][i], color=c)
108 |
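109 |
110 | if __name__ == '__main__':
111 |     # Minimal usage sketch (run with `python -m utils.visualization` so the relative
112 |     # import resolves). The image path and the detections below are placeholders,
113 |     # not real model output.
114 |     demo_image = cv2.imread('test/004456.jpg')
115 |     demo_boxes = np.array([[50, 60, 200, 220], [250, 80, 400, 300]])
116 |     demo_scores = np.array([0.9, 0.3])
117 |     demo_labels = np.array([0, 1])
118 |
119 |     # only the first box passes the default score_threshold of 0.5
120 |     draw_detections(demo_image, demo_boxes, demo_scores, demo_labels,
121 |                     label_to_name=lambda label: 'class_{}'.format(label))
122 |     cv2.imwrite('004456_detections.jpg', demo_image)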
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import sys
3 | import layers
4 | import losses
5 | import initializers
6 | from fsaf_layers import RegressBoxes, Locations, LevelSelect, FSAFTarget
7 | import keras
8 | import tensorflow as tf
9 |
10 |
11 | class Backbone(object):
12 | """ This class stores additional information on backbones.
13 | """
14 |
15 | def __init__(self, backbone):
16 | # a dictionary mapping custom layer names to the correct classes
17 | self.custom_objects = {
18 | 'UpsampleLike': layers.UpsampleLike,
19 | 'PriorProbability': initializers.PriorProbability,
20 | 'RegressBoxes': RegressBoxes,
21 | 'FilterDetections': layers.FilterDetections,
22 | 'Anchors': layers.Anchors,
23 | 'ClipBoxes': layers.ClipBoxes,
24 | 'cls_loss': losses.focal_with_mask(),
25 | 'regr_loss': losses.iou_with_mask(),
26 | 'Locations': Locations,
27 | 'LevelSelect': LevelSelect,
28 | 'FSAFTarget': FSAFTarget,
29 | 'keras': keras,
30 | 'tf': tf,
31 | 'backend': tf,
32 | '': lambda y_true, y_pred: y_pred
33 | }
34 |
35 | self.backbone = backbone
36 | self.validate()
37 |
38 | def retinanet(self, *args, **kwargs):
39 | """
40 | Returns a retinanet model using the correct backbone.
41 | """
42 | raise NotImplementedError('retinanet method not implemented.')
43 |
44 | def download_imagenet(self):
45 | """
46 | Downloads ImageNet weights and returns path to weights file.
47 | """
48 | raise NotImplementedError('download_imagenet method not implemented.')
49 |
50 | def validate(self):
51 | """
52 | Checks whether the backbone string is correct.
53 | """
54 | raise NotImplementedError('validate method not implemented.')
55 |
56 | def preprocess_image(self, inputs):
57 | """
58 | Takes as input an image and prepares it for being passed through the network.
59 | Having this function in Backbone allows other backbones to define a specific preprocessing step.
60 | """
61 | raise NotImplementedError('preprocess_image method not implemented.')
62 |
63 |
64 | def backbone(backbone_name):
65 | """ Returns a backbone object for the given backbone.
66 | """
67 | if 'resnet' in backbone_name:
68 | from .resnet import ResNetBackbone as b
69 | elif 'mobilenet' in backbone_name:
70 | from .mobilenet import MobileNetBackbone as b
71 | elif 'vgg' in backbone_name:
72 | from .vgg import VGGBackbone as b
73 | elif 'densenet' in backbone_name:
74 | from .densenet import DenseNetBackbone as b
75 | else:
76 |         raise NotImplementedError('Backbone class for \'{}\' not implemented.'.format(backbone_name))
77 |
78 | return b(backbone_name)
79 |
80 |
81 | def load_model(filepath, backbone_name='resnet50'):
82 | """ Loads a retinanet model using the correct custom objects.
83 |
84 | Args
85 | filepath: one of the following:
86 | - string, path to the saved model, or
87 | - h5py.File object from which to load the model
88 | backbone_name : Backbone with which the model was trained.
89 |
90 | Returns
91 | A keras.models.Model object.
92 |
93 | Raises
94 | ImportError: if h5py is not available.
95 | ValueError: In case of an invalid savefile.
96 | """
97 | import keras.models
98 | return keras.models.load_model(filepath, custom_objects=backbone(backbone_name).custom_objects)
99 |
100 |
101 | def convert_model(model, nms=True, class_specific_filter=True):
102 | """ Converts a training model to an inference model.
103 |
104 | Args
105 | model : A retinanet training model.
106 | nms : Boolean, whether to add NMS filtering to the converted model.
107 | class_specific_filter : Whether to use class specific filtering or filter for the best scoring class only.
109 |
110 | Returns
111 | A keras.models.Model object.
112 |
113 | Raises
114 | ImportError: if h5py is not available.
115 | ValueError: In case of an invalid savefile.
116 | """
117 | from .retinanet import fsaf_bbox
118 | return fsaf_bbox(model=model, nms=nms, class_specific_filter=class_specific_filter)
119 |
120 |
121 | def assert_training_model(model):
122 | """
123 | Assert that the model is a training model.
124 | """
125 | assert (all(output in model.output_names for output in
126 | ['cls_loss', 'regr_loss', 'fsaf_regression', 'fsaf_classification'])), \
127 |         "Input is not a training model (expected 'cls_loss', 'regr_loss', 'fsaf_regression' and 'fsaf_classification' outputs, got: {}).".format(
128 | model.output_names)
129 |
130 |
131 | def check_training_model(model):
132 | """
133 | Check that model is a training model and exit otherwise.
134 | """
135 | try:
136 | assert_training_model(model)
137 | except AssertionError as e:
138 | print(e, file=sys.stderr)
139 | sys.exit(1)
140 |
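141 |
142 | if __name__ == '__main__':
143 |     # Illustrative load-and-convert flow (adapt into a script run from the repository
144 |     # root); the snapshot path is a placeholder for a training model that was saved
145 |     # with a ResNet-50 backbone.
146 |     training_model = load_model('snapshots/resnet50_fsaf.h5', backbone_name='resnet50')
147 |     check_training_model(training_model)
148 |     inference_model = convert_model(training_model, nms=True, class_specific_filter=True)
149 |     inference_model.summary()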
--------------------------------------------------------------------------------
/augmentor/misc.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from augmentor.transform import translation_xy, change_transform_origin
4 |
5 | ROTATE_DEGREE = [90, 180, 270]
6 |
7 |
8 | def rotate(image, boxes, prob=0.5):
9 | random_prob = np.random.uniform()
10 | if random_prob < prob:
11 | return image, boxes
12 | rotate_degree = ROTATE_DEGREE[np.random.randint(0, 3)]
13 | h, w = image.shape[:2]
14 | # Compute the rotation matrix.
15 | M = cv2.getRotationMatrix2D(center=(w / 2, h / 2),
16 | angle=rotate_degree,
17 | scale=1)
18 |
19 | # Get the sine and cosine from the rotation matrix.
20 | abs_cos_angle = np.abs(M[0, 0])
21 | abs_sin_angle = np.abs(M[0, 1])
22 |
23 | # Compute the new bounding dimensions of the image.
24 | new_w = int(h * abs_sin_angle + w * abs_cos_angle)
25 | new_h = int(h * abs_cos_angle + w * abs_sin_angle)
26 |
27 | # Adjust the rotation matrix to take into account the translation.
28 | M[0, 2] += new_w // 2 - w // 2
29 | M[1, 2] += new_h // 2 - h // 2
30 |
31 | # Rotate the image.
32 | image = cv2.warpAffine(image, M=M, dsize=(new_w, new_h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT,
33 | borderValue=(128, 128, 128))
34 |
35 | new_boxes = []
36 | for box in boxes:
37 | x1, y1, x2, y2 = box
38 | points = M.dot([
39 | [x1, x2, x1, x2],
40 | [y1, y2, y2, y1],
41 | [1, 1, 1, 1],
42 | ])
43 |
44 | # Extract the min and max corners again.
45 | min_xy = np.sort(points, axis=1)[:, :2]
46 | min_x = np.mean(min_xy[0])
47 | min_y = np.mean(min_xy[1])
48 | max_xy = np.sort(points, axis=1)[:, 2:]
49 | max_x = np.mean(max_xy[0])
50 | max_y = np.mean(max_xy[1])
51 |
52 | new_boxes.append([min_x, min_y, max_x, max_y])
53 | boxes = np.array(new_boxes)
54 | return image, boxes
55 |
56 |
57 | def crop(image, boxes, prob=0.5):
58 | random_prob = np.random.uniform()
59 | if random_prob < prob:
60 | return image, boxes
61 | h, w = image.shape[:2]
62 | min_x1, min_y1 = np.min(boxes, axis=0)[:2]
63 | max_x2, max_y2 = np.max(boxes, axis=0)[2:]
64 | random_x1 = np.random.randint(0, max(min_x1 // 2, 1))
65 | random_y1 = np.random.randint(0, max(min_y1 // 2, 1))
66 | random_x2 = np.random.randint(max_x2, max(min(w, max_x2 + (w - max_x2) // 2), max_x2 + 1))
67 | random_y2 = np.random.randint(max_y2, max(min(h, max_y2 + (h - max_y2) // 2), max_y2 + 1))
68 | image = image[random_y1:random_y2, random_x1:random_x2]
69 | boxes[:, [0, 2]] = boxes[:, [0, 2]] - random_x1
70 | boxes[:, [1, 3]] = boxes[:, [1, 3]] - random_y1
71 | return image, boxes
72 |
73 |
74 | def translate(image, boxes, prob=0.5):
75 | random_prob = np.random.uniform()
76 | if random_prob < prob:
77 | return image, boxes
78 | h, w = image.shape[:2]
79 | min_x1, min_y1 = np.min(boxes, axis=0)[:2]
80 | max_x2, max_y2 = np.max(boxes, axis=0)[2:]
81 | translation_matrix = translation_xy(min=(min(-min_x1 // 2, 0), min(-min_y1 // 2, 0)),
82 | max=(max((w - max_x2) // 2, 1), max((h - max_y2) // 2, 1)), prob=1.)
83 | translation_matrix = change_transform_origin(translation_matrix, (w / 2, h / 2))
84 | image = cv2.warpAffine(
85 | image,
86 | translation_matrix[:2, :],
87 | dsize=(w, h),
88 | flags=cv2.INTER_CUBIC,
89 | borderMode=cv2.BORDER_CONSTANT,
90 | borderValue=(128, 128, 128),
91 | )
92 | new_boxes = []
93 | for box in boxes:
94 | x1, y1, x2, y2 = box
95 | points = translation_matrix.dot([
96 | [x1, x2, x1, x2],
97 | [y1, y2, y2, y1],
98 | [1, 1, 1, 1],
99 | ])
100 | min_x, min_y = np.min(points, axis=1)[:2]
101 | max_x, max_y = np.max(points, axis=1)[:2]
102 | new_boxes.append([min_x, min_y, max_x, max_y])
103 | boxes = np.array(new_boxes)
104 | return image, boxes
105 |
106 |
107 | class MiscEffect:
108 | def __init__(self, rotate_prob=0.9, crop_prob=0.5, translate_prob=0.5):
109 | self.rotate_prob = rotate_prob
110 | self.crop_prob = crop_prob
111 | self.translate_prob = translate_prob
112 |
113 | def __call__(self, image, boxes):
114 | image, boxes = rotate(image, boxes, prob=self.rotate_prob)
115 | image, boxes = crop(image, boxes, prob=self.crop_prob)
116 | image, boxes = translate(image, boxes, prob=self.translate_prob)
117 | return image, boxes
118 |
119 |
120 | if __name__ == '__main__':
121 | from yolo.generators.pascal import PascalVocGenerator
122 |
123 | train_generator = PascalVocGenerator(
124 | 'datasets/VOC0712',
125 | 'trainval',
126 | skip_difficult=True,
127 | anchors_path='voc_anchors_416.txt',
128 | batch_size=1
129 | )
130 | misc_effect = MiscEffect()
131 | for i in range(train_generator.size()):
132 | image = train_generator.load_image(i)
133 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
134 | annotations = train_generator.load_annotations(i)
135 | boxes = annotations['bboxes']
136 | for box in boxes.astype(np.int32):
137 | cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]), (0, 0, 255), 2)
138 | src_image = image.copy()
139 | cv2.namedWindow('src_image', cv2.WINDOW_NORMAL)
140 | cv2.imshow('src_image', src_image)
141 | image, boxes = misc_effect(image, boxes)
142 | for box in boxes.astype(np.int32):
143 | cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 1)
144 | cv2.namedWindow('image', cv2.WINDOW_NORMAL)
145 | cv2.imshow('image', image)
146 | cv2.waitKey(0)
147 |
--------------------------------------------------------------------------------
/generators/coco_generator.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 Fizyr (https://fizyr.com)
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from .generator import Generator
18 | from ..utils.image import read_image_bgr
19 |
20 | import os
21 | import numpy as np
22 |
23 | from pycocotools.coco import COCO
24 |
25 |
26 | class CocoGenerator(Generator):
27 | """ Generate data from the COCO dataset.
28 |
29 | See https://github.com/cocodataset/cocoapi/tree/master/PythonAPI for more information.
30 | """
31 |
32 | def __init__(self, data_dir, set_name, **kwargs):
33 | """ Initialize a COCO data generator.
34 |
35 | Args
36 | data_dir: Path to where the COCO dataset is stored.
37 | set_name: Name of the set to parse.
38 | """
39 | self.data_dir = data_dir
40 | self.set_name = set_name
41 | self.coco = COCO(os.path.join(data_dir, 'annotations', 'instances_' + set_name + '.json'))
42 | self.image_ids = self.coco.getImgIds()
43 |
44 | self.load_classes()
45 |
46 | super(CocoGenerator, self).__init__(**kwargs)
47 |
48 | def load_classes(self):
49 | """ Loads the class to label mapping (and inverse) for COCO.
50 | """
51 | # load class names (name -> label)
52 | categories = self.coco.loadCats(self.coco.getCatIds())
53 | categories.sort(key=lambda x: x['id'])
54 |
55 | self.classes = {}
56 | self.coco_labels = {}
57 | self.coco_labels_inverse = {}
58 | for c in categories:
59 | self.coco_labels[len(self.classes)] = c['id']
60 | self.coco_labels_inverse[c['id']] = len(self.classes)
61 | self.classes[c['name']] = len(self.classes)
62 |
63 | # also load the reverse (label -> name)
64 | self.labels = {}
65 | for key, value in self.classes.items():
66 | self.labels[value] = key
67 |
68 | def size(self):
69 | """ Size of the COCO dataset.
70 | """
71 | return len(self.image_ids)
72 |
73 | def num_classes(self):
74 | """ Number of classes in the dataset. For COCO this is 80.
75 | """
76 | return len(self.classes)
77 |
78 | def has_label(self, label):
79 | """ Return True if label is a known label.
80 | """
81 | return label in self.labels
82 |
83 | def has_name(self, name):
84 | """ Returns True if name is a known class.
85 | """
86 | return name in self.classes
87 |
88 | def name_to_label(self, name):
89 | """ Map name to label.
90 | """
91 | return self.classes[name]
92 |
93 | def label_to_name(self, label):
94 | """ Map label to name.
95 | """
96 | return self.labels[label]
97 |
98 | def coco_label_to_label(self, coco_label):
99 | """ Map COCO label to the label as used in the network.
100 | COCO has some gaps in the order of labels. The highest label is 90, but there are 80 classes.
101 | """
102 | return self.coco_labels_inverse[coco_label]
103 |
104 | def coco_label_to_name(self, coco_label):
105 | """ Map COCO label to name.
106 | """
107 | return self.label_to_name(self.coco_label_to_label(coco_label))
108 |
109 | def label_to_coco_label(self, label):
110 | """ Map label as used by the network to labels as used by COCO.
111 | """
112 | return self.coco_labels[label]
113 |
114 | def image_aspect_ratio(self, image_index):
115 | """ Compute the aspect ratio for an image with image_index.
116 | """
117 | image = self.coco.loadImgs(self.image_ids[image_index])[0]
118 | return float(image['width']) / float(image['height'])
119 |
120 | def load_image(self, image_index):
121 | """ Load an image at the image_index.
122 | """
123 | image_info = self.coco.loadImgs(self.image_ids[image_index])[0]
124 | path = os.path.join(self.data_dir, 'images', self.set_name, image_info['file_name'])
125 | return read_image_bgr(path)
126 |
127 | def load_annotations(self, image_index):
128 | """ Load annotations for an image_index.
129 | """
130 | # get ground truth annotations
131 | annotations_ids = self.coco.getAnnIds(imgIds=self.image_ids[image_index], iscrowd=False)
132 | annotations = {'labels': np.empty((0,)), 'bboxes': np.empty((0, 4))}
133 |
134 | # some images appear to miss annotations (like image with id 257034)
135 | if len(annotations_ids) == 0:
136 | return annotations
137 |
138 | # parse annotations
139 | coco_annotations = self.coco.loadAnns(annotations_ids)
140 | for idx, a in enumerate(coco_annotations):
141 | # some annotations have basically no width / height, skip them
142 | if a['bbox'][2] < 1 or a['bbox'][3] < 1:
143 | continue
144 |
145 | annotations['labels'] = np.concatenate(
146 | [annotations['labels'], [self.coco_label_to_label(a['category_id'])]], axis=0)
147 | annotations['bboxes'] = np.concatenate([annotations['bboxes'], [[
148 | a['bbox'][0],
149 | a['bbox'][1],
150 | a['bbox'][0] + a['bbox'][2],
151 | a['bbox'][1] + a['bbox'][3],
152 | ]]], axis=0)
153 |
154 | return annotations
155 |
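156 |
157 | if __name__ == '__main__':
158 |     # Illustrative sketch; the dataset directory is a placeholder. Note the label
159 |     # remapping the docstrings describe: COCO category ids have gaps (up to 90),
160 |     # while the network uses contiguous labels 0..79.
161 |     generator = CocoGenerator(data_dir='/path/to/coco', set_name='val2017')
162 |     print('{} images, {} classes'.format(generator.size(), generator.num_classes()))
163 |     print(generator.label_to_coco_label(0), generator.label_to_name(0))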
--------------------------------------------------------------------------------
/yolo/generators/coco.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 Fizyr (https://fizyr.com)
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 | import cv2
17 | import numpy as np
18 | import os
19 | from pycocotools.coco import COCO
20 |
21 | from yolo.generators.common import Generator
22 |
23 |
24 | class CocoGenerator(Generator):
25 | """
26 | Generate data from the COCO dataset.
27 | See https://github.com/cocodataset/cocoapi/tree/master/PythonAPI for more information.
28 | """
29 |
30 | def __init__(self, data_dir, set_name, **kwargs):
31 | """
32 | Initialize a COCO data generator.
33 |
34 | Args
35 | data_dir: Path to where the COCO dataset is stored.
36 | set_name: Name of the set to parse.
37 | """
38 | self.data_dir = data_dir
39 | self.set_name = set_name
40 | self.coco = COCO(os.path.join(data_dir, 'annotations', 'instances_' + set_name + '.json'))
41 | self.image_ids = self.coco.getImgIds()
42 |
43 | self.load_classes()
44 |
45 | super(CocoGenerator, self).__init__(**kwargs)
46 |
47 | def load_classes(self):
48 | """
49 | Loads the class to label mapping (and inverse) for COCO.
50 | """
51 | # load class names (name -> label)
52 | categories = self.coco.loadCats(self.coco.getCatIds())
53 | categories.sort(key=lambda x: x['id'])
54 |
55 | self.classes = {}
56 | self.coco_labels = {}
57 | self.coco_labels_inverse = {}
58 | for c in categories:
59 | self.coco_labels[len(self.classes)] = c['id']
60 | self.coco_labels_inverse[c['id']] = len(self.classes)
61 | self.classes[c['name']] = len(self.classes)
62 |
63 | # also load the reverse (label -> name)
64 | self.labels = {}
65 | for key, value in self.classes.items():
66 | self.labels[value] = key
67 |
68 | def size(self):
69 | """ Size of the COCO dataset.
70 | """
71 | return len(self.image_ids)
72 |
73 | def num_classes(self):
74 | """ Number of classes in the dataset. For COCO this is 80.
75 | """
76 | return len(self.classes)
77 |
78 | def has_label(self, label):
79 | """ Return True if label is a known label.
80 | """
81 | return label in self.labels
82 |
83 | def has_name(self, name):
84 | """ Returns True if name is a known class.
85 | """
86 | return name in self.classes
87 |
88 | def name_to_label(self, name):
89 | """ Map name to label.
90 | """
91 | return self.classes[name]
92 |
93 | def label_to_name(self, label):
94 | """ Map label to name.
95 | """
96 | return self.labels[label]
97 |
98 | def coco_label_to_label(self, coco_label):
99 | """ Map COCO label to the label as used in the network.
100 | COCO has some gaps in the order of labels. The highest label is 90, but there are 80 classes.
101 | """
102 | return self.coco_labels_inverse[coco_label]
103 |
104 | def coco_label_to_name(self, coco_label):
105 | """ Map COCO label to name.
106 | """
107 | return self.label_to_name(self.coco_label_to_label(coco_label))
108 |
109 | def label_to_coco_label(self, label):
110 | """ Map label as used by the network to labels as used by COCO.
111 | """
112 | return self.coco_labels[label]
113 |
114 | def image_aspect_ratio(self, image_index):
115 | """ Compute the aspect ratio for an image with image_index.
116 | """
117 | image = self.coco.loadImgs(self.image_ids[image_index])[0]
118 | return float(image['width']) / float(image['height'])
119 |
120 | def load_image(self, image_index):
121 | """
122 | Load an image at the image_index.
123 | """
124 | # {'license': 2, 'file_name': '000000259765.jpg', 'coco_url': 'http://images.cocodataset.org/test2017/000000259765.jpg', 'height': 480, 'width': 640, 'date_captured': '2013-11-21 04:02:31', 'id': 259765}
125 | image_info = self.coco.loadImgs(self.image_ids[image_index])[0]
126 | path = os.path.join(self.data_dir, 'images', self.set_name, image_info['file_name'])
127 | image = cv2.imread(path)
128 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
129 | return image
130 |
131 | def load_annotations(self, image_index):
132 | """ Load annotations for an image_index.
133 | """
134 | # get ground truth annotations
135 | annotations_ids = self.coco.getAnnIds(imgIds=self.image_ids[image_index], iscrowd=False)
136 | annotations = {'labels': np.empty((0,)), 'bboxes': np.empty((0, 4))}
137 |
138 | # some images appear to miss annotations (like image with id 257034)
139 | if len(annotations_ids) == 0:
140 | return annotations
141 |
142 | # parse annotations
143 | coco_annotations = self.coco.loadAnns(annotations_ids)
144 | for idx, a in enumerate(coco_annotations):
145 | # some annotations have basically no width / height, skip them
146 | if a['bbox'][2] < 1 or a['bbox'][3] < 1:
147 | continue
148 |
149 | annotations['labels'] = np.concatenate(
150 | [annotations['labels'], [self.coco_label_to_label(a['category_id'])]], axis=0)
151 | annotations['bboxes'] = np.concatenate([annotations['bboxes'], [[
152 | a['bbox'][0],
153 | a['bbox'][1],
154 | a['bbox'][0] + a['bbox'][2],
155 | a['bbox'][1] + a['bbox'][3],
156 | ]]], axis=0)
157 |
158 | return annotations
159 |
160 |
161 | if __name__ == '__main__':
162 | dataset_dir = '/home/adam/.keras/datasets/coco/2017_118_5'
163 | generator = CocoGenerator(data_dir=dataset_dir, set_name='test-dev2017')
164 | print(generator[0])
165 |
--------------------------------------------------------------------------------
/augmentor/color.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image, ImageEnhance, ImageOps
3 |
4 |
5 | def autocontrast(image, prob=0.5):
6 | random_prob = np.random.uniform()
7 | if random_prob > prob:
8 | return image
9 | image = Image.fromarray(image[..., ::-1])
10 | image = ImageOps.autocontrast(image)
11 | image = np.array(image)[..., ::-1]
12 | return image
13 |
14 |
15 | def equalize(image, prob=0.5):
16 | random_prob = np.random.uniform()
17 | if random_prob > prob:
18 | return image
19 | image = Image.fromarray(image[..., ::-1])
20 | image = ImageOps.equalize(image)
21 | image = np.array(image)[..., ::-1]
22 | return image
23 |
24 |
25 | def solarize(image, prob=0.5, threshold=128.):
26 | random_prob = np.random.uniform()
27 | if random_prob > prob:
28 | return image
29 | image = Image.fromarray(image[..., ::-1])
30 | image = ImageOps.solarize(image, threshold=threshold)
31 | image = np.array(image)[..., ::-1]
32 | return image
33 |
34 |
35 | def sharpness(image, prob=0.5, min=0, max=2, factor=None):
36 | random_prob = np.random.uniform()
37 | if random_prob > prob:
38 | return image
39 | if factor is None:
40 | factor = np.random.uniform(min, max)
41 | image = Image.fromarray(image[..., ::-1])
42 | enhancer = ImageEnhance.Sharpness(image)
43 | image = enhancer.enhance(factor=factor)
44 | return np.array(image)[..., ::-1]
45 |
46 |
47 | def color(image, prob=0.5, min=0., max=1., factor=None):
48 | random_prob = np.random.uniform()
49 | if random_prob > prob:
50 | return image
51 | if factor is None:
52 | factor = np.random.uniform(min, max)
53 | image = Image.fromarray(image[..., ::-1])
54 | enhancer = ImageEnhance.Color(image)
55 | image = enhancer.enhance(factor=factor)
56 | return np.array(image)[..., ::-1]
57 |
58 |
59 | def contrast(image, prob=0.5, min=0.2, max=1., factor=None):
60 | random_prob = np.random.uniform()
61 | if random_prob > prob:
62 | return image
63 | if factor is None:
64 | factor = np.random.uniform(min, max)
65 | image = Image.fromarray(image[..., ::-1])
66 | enhancer = ImageEnhance.Contrast(image)
67 | image = enhancer.enhance(factor=factor)
68 | return np.array(image)[..., ::-1]
69 |
70 |
71 | def brightness(image, prob=0.5, min=0.8, max=1., factor=None):
72 | random_prob = np.random.uniform()
73 | if random_prob > prob:
74 | return image
75 | if factor is None:
76 | factor = np.random.uniform(min, max)
77 | image = Image.fromarray(image[..., ::-1])
78 | enhancer = ImageEnhance.Brightness(image)
79 | image = enhancer.enhance(factor=factor)
80 | return np.array(image)[..., ::-1]
81 |
82 |
83 | class VisualEffect:
84 | """
85 | Struct holding parameters and applying image color transformation.
86 |
87 | Args
88 |         solarize_threshold: Pixel intensity above which pixels are inverted when solarizing.
89 | color_factor: A factor for adjusting color.
90 | contrast_factor: A factor for adjusting contrast.
91 | brightness_factor: A factor for adjusting brightness.
92 | sharpness_factor: A factor for adjusting sharpness.
93 | """
94 |
95 | def __init__(
96 | self,
97 | color_factor=None,
98 | contrast_factor=None,
99 | brightness_factor=None,
100 | sharpness_factor=None,
101 | color_prob=0.5,
102 | contrast_prob=0.5,
103 | brightness_prob=0.5,
104 | sharpness_prob=0.5,
105 | autocontrast_prob=0.5,
106 | equalize_prob=0.5,
107 | solarize_prob=0.1,
108 | solarize_threshold=128.,
109 |
110 | ):
111 | self.color_factor = color_factor
112 | self.contrast_factor = contrast_factor
113 | self.brightness_factor = brightness_factor
114 | self.sharpness_factor = sharpness_factor
115 | self.color_prob = color_prob
116 | self.contrast_prob = contrast_prob
117 | self.brightness_prob = brightness_prob
118 | self.sharpness_prob = sharpness_prob
119 | self.autocontrast_prob = autocontrast_prob
120 | self.equalize_prob = equalize_prob
121 | self.solarize_prob = solarize_prob
122 | self.solarize_threshold = solarize_threshold
123 |
124 | def __call__(self, image):
125 | """
126 | Apply a visual effect on the image.
127 |
128 | Args
129 | image: Image to adjust
130 | """
131 | random_enhance_id = np.random.randint(0, 4)
132 | if random_enhance_id == 0:
133 | image = color(image, prob=self.color_prob, factor=self.color_factor)
134 | elif random_enhance_id == 1:
135 | image = contrast(image, prob=self.contrast_prob, factor=self.contrast_factor)
136 | elif random_enhance_id == 2:
137 | image = brightness(image, prob=self.brightness_prob, factor=self.brightness_factor)
138 | else:
139 | image = sharpness(image, prob=self.sharpness_prob, factor=self.sharpness_factor)
140 |
141 | random_ops_id = np.random.randint(0, 3)
142 | if random_ops_id == 0:
143 | image = autocontrast(image, prob=self.autocontrast_prob)
144 | elif random_ops_id == 1:
145 | image = equalize(image, prob=self.equalize_prob)
146 | else:
147 | image = solarize(image, prob=self.solarize_prob, threshold=self.solarize_threshold)
148 | return image
149 |
150 |
151 | if __name__ == '__main__':
152 | from yolo.generators.pascal import PascalVocGenerator
153 | import cv2
154 |
155 | train_generator = PascalVocGenerator(
156 | 'datasets/VOC0712',
157 | 'trainval',
158 | skip_difficult=True,
159 | anchors_path='voc_anchors_416.txt',
160 | batch_size=1
161 | )
162 | visual_effect = VisualEffect()
163 | for i in range(train_generator.size()):
164 | image = train_generator.load_image(i)
165 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
166 | annotations = train_generator.load_annotations(i)
167 | boxes = annotations['bboxes']
168 | for box in boxes.astype(np.int32):
169 | cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]), (0, 0, 255), 2)
170 | src_image = image.copy()
171 | image = visual_effect(image)
172 | cv2.namedWindow('image', cv2.WINDOW_NORMAL)
173 | cv2.imshow('image', np.concatenate([src_image, image], axis=1))
174 | cv2.waitKey(0)
175 |
--------------------------------------------------------------------------------
/yolo/model.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Add, BatchNormalization, Concatenate, Conv2D, Input, Reshape
2 | from keras.layers import Lambda, LeakyReLU, UpSampling2D, ZeroPadding2D, Activation
3 | from keras.regularizers import l2
4 | from keras.models import Model
5 | from functools import reduce
6 |
7 | from layers import FilterDetections, ClipBoxes
8 | from losses import focal_with_mask, iou_with_mask
9 | from yolo import config
10 | from yolo.fsaf_layers import FSAFTarget, LevelSelect, Locations, RegressBoxes
11 |
12 |
13 | def compose(*funcs):
14 | """
15 | Compose arbitrarily many functions, evaluated left to right.
16 |
17 | Reference: https://mathieularose.com/function-composition-in-python/
18 | """
19 | # return lambda x: reduce(lambda v, f: f(v), funcs, x)
20 | if funcs:
21 | return reduce(lambda f, g: lambda *args, **kwargs: g(f(*args, **kwargs)), funcs)
22 | else:
23 | raise ValueError('Composition of empty sequence not supported.')
24 |
25 |
26 | def darknet_conv2d(*args, **kwargs):
27 | """
28 | Wrapper to set Darknet parameters for Convolution2D.
29 | """
30 |     darknet_conv_kwargs = {'kernel_regularizer': l2(5e-4)}
31 | darknet_conv_kwargs['padding'] = 'valid' if kwargs.get('strides') == (2, 2) else 'same'
32 | darknet_conv_kwargs.update(kwargs)
33 | return Conv2D(*args, **darknet_conv_kwargs)
34 |
35 |
36 | def darknet_conv2d_bn_leaky(*args, **kwargs):
37 | """
38 | Darknet Convolution2D followed by BatchNormalization and LeakyReLU.
39 | """
40 | no_bias_kwargs = {'use_bias': False}
41 | no_bias_kwargs.update(kwargs)
42 | return compose(
43 | darknet_conv2d(*args, **no_bias_kwargs),
44 | BatchNormalization(),
45 | LeakyReLU(alpha=0.1))
46 |
47 |
48 | def resblock_body(x, num_filters, num_blocks):
49 | """
50 | A series of resblocks starting with a downsampling Convolution2D
51 | """
52 | # Darknet uses left and top padding instead of 'same' mode
53 | x = ZeroPadding2D(((1, 0), (1, 0)))(x)
54 | x = darknet_conv2d_bn_leaky(num_filters, (3, 3), strides=(2, 2))(x)
55 | for i in range(num_blocks):
56 | y = compose(
57 | darknet_conv2d_bn_leaky(num_filters // 2, (1, 1)),
58 | darknet_conv2d_bn_leaky(num_filters, (3, 3)))(x)
59 | x = Add()([x, y])
60 | return x
61 |
62 |
63 | def darknet_body(x):
64 | """
65 | Darknet body having 52 Convolution2D layers
66 | """
67 | x = darknet_conv2d_bn_leaky(32, (3, 3))(x)
68 | x = resblock_body(x, 64, 1)
69 | x = resblock_body(x, 128, 2)
70 | x = resblock_body(x, 256, 8)
71 | x = resblock_body(x, 512, 8)
72 | x = resblock_body(x, 1024, 4)
73 | return x
74 |
75 |
76 | def make_last_layers(x, num_filters, out_filters):
77 | """
78 | 6 conv2d_bn_leaky layers followed by a conv2d layer
79 | """
80 | x = compose(darknet_conv2d_bn_leaky(num_filters, (1, 1)),
81 | darknet_conv2d_bn_leaky(num_filters * 2, (3, 3)),
82 | darknet_conv2d_bn_leaky(num_filters, (1, 1)),
83 | darknet_conv2d_bn_leaky(num_filters * 2, (3, 3)),
84 | darknet_conv2d_bn_leaky(num_filters, (1, 1)))(x)
85 | y = compose(darknet_conv2d_bn_leaky(num_filters * 2, (3, 3)),
86 | darknet_conv2d(out_filters, (1, 1)))(x)
87 | return x, y
88 |
89 |
90 | def yolo_body(num_classes=20, score_threshold=0.01):
91 | """
92 |     Create the YOLOv3 (Darknet-53) based FSAF model in Keras.
93 |
94 |     Args:
95 |         num_classes: Number of object classes to predict.
96 |         score_threshold: Minimum classification score for a detection to be kept by FilterDetections.
97 |
98 |     Returns:
99 |         A (training model, prediction model) tuple: the training model outputs the cls/regr losses, the prediction model outputs filtered detections.
100 | """
101 | image_input = Input(shape=(None, None, 3), name='image_input')
102 | darknet = Model([image_input], darknet_body(image_input))
103 | ##################################################
104 | # build fsaf head
105 | ##################################################
106 | x, y1 = make_last_layers(darknet.output, 512, 4 + num_classes)
107 |
108 | x = compose(darknet_conv2d_bn_leaky(256, (1, 1)), UpSampling2D(2))(x)
109 | x = Concatenate()([x, darknet.layers[152].output])
110 | x, y2 = make_last_layers(x, 256, 4 + num_classes)
111 | x = compose(darknet_conv2d_bn_leaky(128, (1, 1)), UpSampling2D(2))(x)
112 | x = Concatenate()([x, darknet.layers[92].output])
113 | x, y3 = make_last_layers(x, 128, 4 + num_classes)
114 | y1_ = Reshape((-1, 4 + num_classes))(y1)
115 | y2_ = Reshape((-1, 4 + num_classes))(y2)
116 | y3_ = Reshape((-1, 4 + num_classes))(y3)
117 | y = Concatenate(axis=1)([y1_, y2_, y3_])
118 | batch_cls_pred = Lambda(lambda x: x[..., 4:])(y)
119 | batch_regr_pred = Lambda(lambda x: x[..., :4])(y)
120 | batch_cls_pred = Activation('sigmoid')(batch_cls_pred)
121 | batch_regr_pred = Activation('relu')(batch_regr_pred)
122 |
123 | gt_boxes_input = Input(shape=(config.MAX_NUM_GT_BOXES, 5), name='gt_boxes_input')
124 | grid_shapes_input = Input((len(config.STRIDES), 2), dtype='int32', name='grid_shapes_input')
125 | batch_gt_box_levels = LevelSelect(name='level_select')(
126 | [batch_cls_pred, batch_regr_pred, grid_shapes_input, gt_boxes_input])
127 | batch_cls_target, batch_cls_mask, batch_cls_num_pos, batch_regr_target, batch_regr_mask = FSAFTarget(
128 | num_classes=num_classes,
129 | name='fsaf_target')(
130 | [batch_gt_box_levels, grid_shapes_input, gt_boxes_input])
131 | focal_loss_graph = focal_with_mask()
132 | iou_loss_graph = iou_with_mask()
133 | cls_loss = Lambda(focal_loss_graph,
134 | output_shape=(1,),
135 | name="cls_loss")(
136 | [batch_cls_target, batch_cls_pred, batch_cls_mask, batch_cls_num_pos])
137 | regr_loss = Lambda(iou_loss_graph,
138 | output_shape=(1,),
139 | name="regr_loss")([batch_regr_target, batch_regr_pred, batch_regr_mask])
140 | model = Model(inputs=[image_input, gt_boxes_input, grid_shapes_input],
141 | outputs=[cls_loss, regr_loss],
142 | name='fsaf')
143 |
144 | # compute the anchors
145 | features = [y1, y2, y3]
146 |
147 | locations, strides = Locations(strides=config.STRIDES)(features)
148 |
149 | # apply predicted regression to anchors
150 | boxes = RegressBoxes(name='boxes')([locations, strides, batch_regr_pred])
151 | boxes = ClipBoxes(name='clipped_boxes')([image_input, boxes])
152 |
153 | # filter detections (apply NMS / score threshold / select top-k)
154 | detections = FilterDetections(
155 | nms=True,
156 | class_specific_filter=True,
157 | name='filtered_detections',
158 | score_threshold=score_threshold
159 | )([boxes, batch_cls_pred])
160 |
161 | prediction_model = Model(inputs=image_input, outputs=detections, name='fsaf_detection')
162 | return model, prediction_model
163 |
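164 |
165 | if __name__ == '__main__':
166 |     # Illustrative sketch: build the FSAF training and inference graphs for the 20
167 |     # VOC classes. The training model emits its losses as outputs, so it can be
168 |     # compiled with identity losses on 'cls_loss' and 'regr_loss' (a common Keras
169 |     # pattern for loss-as-output models); weight loading and the data generators
170 |     # are omitted here.
171 |     training_model, prediction_model = yolo_body(num_classes=20, score_threshold=0.01)
172 |     training_model.compile(optimizer='adam',
173 |                            loss={'cls_loss': lambda y_true, y_pred: y_pred,
174 |                                  'regr_loss': lambda y_true, y_pred: y_pred})
175 |     training_model.summary()
176 |     prediction_model.summary()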
--------------------------------------------------------------------------------
/util_graphs.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import keras.backend as K
3 |
4 |
5 | def xyxy2cxcywh(xyxy):
6 | """
7 |     Convert [x1 y1 x2 y2] box format to [cx cy w h] format.
8 | """
9 | return tf.concat((0.5 * (xyxy[:, 0:2] + xyxy[:, 2:4]), xyxy[:, 2:4] - xyxy[:, 0:2]), axis=-1)
10 |
11 |
12 | def cxcywh2xyxy(xywh):
13 | """
14 |     Convert [cx cy w h] box format to [x1 y1 x2 y2] format.
15 | """
16 | return tf.concat((xywh[:, 0:2] - 0.5 * xywh[:, 2:4], xywh[:, 0:2] + 0.5 * xywh[:, 2:4]), axis=-1)
17 |
18 |
19 | def prop_box_graph(boxes, scale, width, height):
20 | """
21 | Compute proportional box coordinates.
22 |
23 | Box centers are fixed. Box w and h scaled by scale.
24 | """
25 | prop_boxes = xyxy2cxcywh(boxes)
26 | prop_boxes = tf.concat((prop_boxes[:, :2], prop_boxes[:, 2:] * scale), axis=-1)
27 | prop_boxes = cxcywh2xyxy(prop_boxes)
28 | x1 = tf.floor(prop_boxes[:, 0])
29 | y1 = tf.floor(prop_boxes[:, 1])
30 | x2 = tf.math.ceil(prop_boxes[:, 2])
31 | y2 = tf.math.ceil(prop_boxes[:, 3])
32 | width = tf.cast(width, tf.float32)
33 | height = tf.cast(height, tf.float32)
34 | x2 = tf.cast(tf.clip_by_value(x2, 1, width), tf.int32)
35 | y2 = tf.cast(tf.clip_by_value(y2, 1, height), tf.int32)
36 | x1 = tf.cast(tf.clip_by_value(x1, 0, tf.cast(x2, tf.float32) - 1), tf.int32)
37 | y1 = tf.cast(tf.clip_by_value(y1, 0, tf.cast(y2, tf.float32) - 1), tf.int32)
38 |
39 | return x1, y1, x2, y2
40 |
41 |
42 | def prop_box_graph_2(boxes, scale, width, height):
43 | """
44 | Compute proportional box coordinates.
45 |
46 | Box centers are fixed. Box w and h scaled by scale.
47 | """
48 | prop_boxes = xyxy2cxcywh(boxes)
49 | prop_boxes = tf.concat((prop_boxes[:, :2], prop_boxes[:, 2:] * scale), axis=-1)
50 | prop_boxes = cxcywh2xyxy(prop_boxes)
51 | # (n, 1)
52 | x1 = tf.floor(prop_boxes[:, 0:1])
53 | y1 = tf.floor(prop_boxes[:, 1:2])
54 | x2 = tf.math.ceil(prop_boxes[:, 2:3])
55 | y2 = tf.math.ceil(prop_boxes[:, 3:4])
56 | width = tf.cast(width, tf.float32)
57 | height = tf.cast(height, tf.float32)
58 | x2 = tf.cast(tf.clip_by_value(x2, 1, width), tf.int32)
59 | y2 = tf.cast(tf.clip_by_value(y2, 1, height), tf.int32)
60 | x1 = tf.cast(tf.clip_by_value(x1, 0, tf.cast(x2, tf.float32) - 1), tf.int32)
61 | y1 = tf.cast(tf.clip_by_value(y1, 0, tf.cast(y2, tf.float32) - 1), tf.int32)
62 |
63 | return x1, y1, x2, y2
64 |
65 |
66 | def trim_zeros_graph(boxes, name='trim_zeros'):
67 | """
68 | Often boxes are represented with matrices of shape [N, 4] and are padded with zeros.
69 | This removes zero boxes.
70 |
71 | Args:
72 | boxes: [N, 4] matrix of boxes.
73 | name: name of tensor
74 |
75 | Returns:
76 |         boxes: the input boxes with all-zero rows removed. non_zeros: a boolean mask marking which rows were kept.
77 | """
78 | non_zeros = tf.cast(tf.reduce_sum(tf.abs(boxes), axis=1), tf.bool)
79 | boxes = tf.boolean_mask(boxes, non_zeros, name=name)
80 | return boxes, non_zeros
81 |
82 |
83 | def bbox_transform_inv(boxes, deltas, mean=None, std=None):
84 | """
85 | Applies deltas (usually regression results) to boxes (usually anchors).
86 |
87 |     Before the deltas are applied to the boxes, the normalization that was applied in the generator has to be undone.
88 |     The mean and std arguments are therefore the same mean and std used in the generator; the deltas are denormalized with them here and then applied to the boxes.
89 |
90 | Args
91 | boxes: np.array of shape (B, N, 4), where B is the batch size, N the number of boxes and 4 values for (x1, y1, x2, y2).
92 | deltas: np.array of same shape as boxes. These deltas (d_x1, d_y1, d_x2, d_y2) are a factor of the width/height.
93 | mean: The mean value used when computing deltas (defaults to [0, 0, 0, 0]).
94 | std: The standard deviation used when computing deltas (defaults to [0.2, 0.2, 0.2, 0.2]).
95 |
96 | Returns
97 | A np.array of the same shape as boxes, but with deltas applied to each box.
98 | The mean and std are used during training to normalize the regression values (networks love normalization).
99 | """
100 | if mean is None:
101 | mean = [0, 0, 0, 0]
102 | if std is None:
103 | std = [0.2, 0.2, 0.2, 0.2]
104 |
105 | width = boxes[:, :, 2] - boxes[:, :, 0]
106 | height = boxes[:, :, 3] - boxes[:, :, 1]
107 |
108 | x1 = boxes[:, :, 0] + (deltas[:, :, 0] * std[0] + mean[0]) * width
109 | y1 = boxes[:, :, 1] + (deltas[:, :, 1] * std[1] + mean[1]) * height
110 | x2 = boxes[:, :, 2] + (deltas[:, :, 2] * std[2] + mean[2]) * width
111 | y2 = boxes[:, :, 3] + (deltas[:, :, 3] * std[3] + mean[3]) * height
112 |
113 | pred_boxes = K.stack([x1, y1, x2, y2], axis=2)
114 |
115 | return pred_boxes
116 |
117 |
118 | def shift(shape, stride, anchors):
119 | """
120 | Produce shifted anchors based on shape of the map and stride size.
121 |
122 | Args
123 | shape: Shape to shift the anchors over. (h,w)
124 | stride: Stride to shift the anchors with over the shape.
125 | anchors: The anchors to apply at each location.
126 |
127 | Returns
128 | shifted_anchors: (fh * fw * num_anchors, 4)
129 | """
130 | shift_x = (K.arange(0, shape[1], dtype=K.floatx()) + K.constant(0.5, dtype=K.floatx())) * stride
131 | shift_y = (K.arange(0, shape[0], dtype=K.floatx()) + K.constant(0.5, dtype=K.floatx())) * stride
132 | shift_x, shift_y = tf.meshgrid(shift_x, shift_y)
133 | shift_x = K.reshape(shift_x, [-1])
134 | shift_y = K.reshape(shift_y, [-1])
135 |
136 | # (4, fh * fw)
137 | shifts = K.stack([
138 | shift_x,
139 | shift_y,
140 | shift_x,
141 | shift_y
142 | ], axis=0)
143 | # (fh * fw, 4)
144 | shifts = K.transpose(shifts)
145 | number_anchors = K.shape(anchors)[0]
146 |
147 | # number of base points = fh * fw
148 | k = K.shape(shifts)[0]
149 |
150 | # (k=fh*fw, num_anchors, 4)
151 | shifted_anchors = K.reshape(anchors, [1, number_anchors, 4]) + K.cast(K.reshape(shifts, [k, 1, 4]), K.floatx())
152 | # (k * num_anchors, 4)
153 | shifted_anchors = K.reshape(shifted_anchors, [k * number_anchors, 4])
154 |
155 | return shifted_anchors
156 |
157 |
158 | def resize_images(images, size, method='bilinear', align_corners=False):
159 | """ See https://www.tensorflow.org/versions/master/api_docs/python/tf/image/resize_images .
160 |
161 | Args
162 | method: The method used for interpolation. One of ('bilinear', 'nearest', 'bicubic', 'area').
163 | """
164 | methods = {
165 | 'bilinear': tf.image.ResizeMethod.BILINEAR,
166 | 'nearest': tf.image.ResizeMethod.NEAREST_NEIGHBOR,
167 | 'bicubic': tf.image.ResizeMethod.BICUBIC,
168 | 'area': tf.image.ResizeMethod.AREA,
169 | }
170 | return tf.image.resize_images(images, size, methods[method], align_corners)
171 |
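172 |
173 | if __name__ == '__main__':
174 |     # Small numeric sanity check (illustrative): the two box conversions are inverses,
175 |     # e.g. [10, 20, 50, 60] <-> centre (30, 40) with width/height (40, 40), and
176 |     # bbox_transform_inv with all-zero deltas returns the input boxes unchanged.
177 |     boxes = tf.constant([[10., 20., 50., 60.]])
178 |     cxcywh = xyxy2cxcywh(boxes)
179 |     print(K.eval(cxcywh))               # [[30. 40. 40. 40.]]
180 |     print(K.eval(cxcywh2xyxy(cxcywh)))  # [[10. 20. 50. 60.]]
181 |
182 |     batch_boxes = tf.constant([[[10., 20., 50., 60.]]])
183 |     print(K.eval(bbox_transform_inv(batch_boxes, tf.zeros_like(batch_boxes))))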
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 Fizyr (https://fizyr.com)
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import keras
18 | from keras.utils import get_file
19 | import keras_resnet
20 | import keras_resnet.models
21 | import tensorflow as tf
22 |
23 | from models import retinanet
24 | from models import Backbone
25 | from utils.image import preprocess_image
26 | import configure
27 |
28 |
29 | class ResNetBackbone(Backbone):
30 | """
31 | Describes backbone information and provides utility functions.
32 | """
33 |
34 | def __init__(self, backbone):
35 | super(ResNetBackbone, self).__init__(backbone)
36 | self.custom_objects.update(keras_resnet.custom_objects)
37 |
38 | def retinanet(self, *args, **kwargs):
39 | """
40 | Returns a retinanet model using the correct backbone.
41 | """
42 | return resnet_retinanet(*args, backbone=self.backbone, **kwargs)
43 |
44 | def fsaf(self, num_classes, modifier):
45 | """
46 |         Returns an FSAF model using the correct backbone.
47 | """
48 | return resnet_fsaf(num_classes=num_classes, backbone=self.backbone, modifier=modifier)
49 |
50 | def download_imagenet(self):
51 | """
52 | Downloads ImageNet weights and returns path to weights file.
53 | """
54 | resnet_filename = 'ResNet-{}-model.keras.h5'
55 | resnet_resource = 'https://github.com/fizyr/keras-models/releases/download/v0.0.1/{}'.format(resnet_filename)
56 | depth = int(self.backbone.replace('resnet', ''))
57 |
58 | filename = resnet_filename.format(depth)
59 | resource = resnet_resource.format(depth)
60 | if depth == 50:
61 | checksum = '3e9f4e4f77bbe2c9bec13b53ee1c2319'
62 | elif depth == 101:
63 | checksum = '05dc86924389e5b401a9ea0348a3213c'
64 | elif depth == 152:
65 | checksum = '6ee11ef2b135592f8031058820bb9e71'
66 | else:
67 | raise ValueError('Unknown depth')
68 |
69 | return get_file(
70 | filename,
71 | resource,
72 | cache_subdir='models',
73 | md5_hash=checksum
74 | )
75 |
76 | def validate(self):
77 | """
78 | Checks whether the backbone string is correct.
79 | """
80 | allowed_backbones = ['resnet50', 'resnet101', 'resnet152']
81 |
82 | if self.backbone not in allowed_backbones:
83 | raise ValueError(
84 | 'Backbone (\'{}\') not in allowed backbones ({}).'.format(self.backbone, allowed_backbones))
85 |
86 | def preprocess_image(self, inputs):
87 | """
88 | Takes as input an image and prepares it for being passed through the network.
89 | """
90 | return preprocess_image(inputs, mode='caffe')
91 |
92 |
93 | def resnet_retinanet(num_classes, backbone='resnet50', inputs=None, modifier=None, **kwargs):
94 | """
95 | Constructs a retinanet model using a resnet backbone.
96 |
97 | Args
98 | num_classes: Number of classes to predict.
99 | backbone: Which backbone to use (one of ('resnet50', 'resnet101', 'resnet152')).
100 | inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)).
101 | modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example).
102 |
103 | Returns
104 | RetinaNet model with a ResNet backbone.
105 | """
106 |     # choose default input if none is given
107 |     inputs = inputs if inputs is not None else keras.layers.Input(shape=(None, None, 3))
108 |
109 | # create the resnet backbone
110 | if backbone == 'resnet50':
111 | resnet = keras_resnet.models.ResNet50(inputs, include_top=False, freeze_bn=True)
112 | elif backbone == 'resnet101':
113 | resnet = keras_resnet.models.ResNet101(inputs, include_top=False, freeze_bn=True)
114 | elif backbone == 'resnet152':
115 | resnet = keras_resnet.models.ResNet152(inputs, include_top=False, freeze_bn=True)
116 | else:
117 | raise ValueError('Backbone (\'{}\') is invalid.'.format(backbone))
118 |
119 | # invoke modifier if given
120 | if modifier:
121 | resnet = modifier(resnet)
122 |
123 | # create the full model
124 | return retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=resnet.outputs[1:], **kwargs)
125 |
126 |
127 | def resnet_fsaf(num_classes, backbone='resnet50', modifier=None):
128 | """
129 |     Constructs an FSAF training model using a resnet backbone.
130 |
131 |     Args
132 |         num_classes: Number of classes to predict.
133 |         backbone: Which backbone to use (one of ('resnet50', 'resnet101', 'resnet152')).
134 |         modifier: A function handler which can modify the backbone before using it in the FSAF model
135 |             (this can be used to freeze backbone layers for example).
136 |
137 |     Returns
138 |         FSAF training model with a ResNet backbone.
139 | """
140 | image_input = keras.layers.Input(shape=(None, None, 3))
141 | gt_boxes_input = keras.layers.Input(shape=(configure.MAX_NUM_GT_BOXES, 5))
142 | feature_shapes_input = keras.layers.Input((5, 2), dtype='int32')
143 |
144 | # create the resnet backbone
145 | if backbone == 'resnet50':
146 | resnet = keras_resnet.models.ResNet50(image_input, include_top=False, freeze_bn=True)
147 | elif backbone == 'resnet101':
148 | resnet = keras_resnet.models.ResNet101(image_input, include_top=False, freeze_bn=True)
149 | elif backbone == 'resnet152':
150 | resnet = keras_resnet.models.ResNet152(image_input, include_top=False, freeze_bn=True)
151 | else:
152 | raise ValueError('Backbone (\'{}\') is invalid.'.format(backbone))
153 |
154 | # invoke modifier if given
155 | if modifier:
156 | resnet = modifier(resnet)
157 |
158 | # create the full model
159 | return retinanet.fsaf(inputs=[image_input, gt_boxes_input, feature_shapes_input],
160 | num_classes=num_classes,
161 | backbone_layers=resnet.outputs[1:])
162 |
163 |
164 | def resnet50_retinanet(num_classes, inputs=None, **kwargs):
165 | return resnet_retinanet(num_classes=num_classes, backbone='resnet50', inputs=inputs, **kwargs)
166 |
167 |
168 | def resnet101_retinanet(num_classes, inputs=None, **kwargs):
169 | return resnet_retinanet(num_classes=num_classes, backbone='resnet101', inputs=inputs, **kwargs)
170 |
171 |
172 | def resnet152_retinanet(num_classes, inputs=None, **kwargs):
173 | return resnet_retinanet(num_classes=num_classes, backbone='resnet152', inputs=inputs, **kwargs)
174 |
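175 |
176 | if __name__ == '__main__':
177 |     # Illustrative sketch: instantiate the backbone helper, fetch the ImageNet
178 |     # weights, and build the FSAF training model for the 20 VOC classes.
179 |     # `fsaf` is assumed here to return a single keras.models.Model.
180 |     resnet_backbone = ResNetBackbone('resnet50')
181 |     weights_path = resnet_backbone.download_imagenet()
182 |     fsaf_model = resnet_backbone.fsaf(num_classes=20, modifier=None)
183 |     fsaf_model.load_weights(weights_path, by_name=True)
184 |     fsaf_model.summary()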
--------------------------------------------------------------------------------
/callbacks.py:
--------------------------------------------------------------------------------
1 | import keras
2 | from utils.eval import evaluate
3 | from utils.coco_eval import evaluate_coco
4 |
5 |
6 | class Evaluate(keras.callbacks.Callback):
7 | """
8 | Evaluation callback for arbitrary datasets.
9 | """
10 |
11 | def __init__(
12 | self,
13 | generator,
14 | iou_threshold=0.5,
15 | score_threshold=0.05,
16 | max_detections=100,
17 | save_path=None,
18 | tensorboard=None,
19 | weighted_average=False,
20 | verbose=1
21 | ):
22 | """
23 | Evaluate a given dataset using a given model at the end of every epoch during training.
24 |
25 | Args:
26 | generator: The generator that represents the dataset to evaluate.
27 | iou_threshold: The threshold used to consider when a detection is positive or negative.
28 | score_threshold: The score confidence threshold to use for detections.
29 | max_detections: The maximum number of detections to use per image.
30 | save_path: The path to save images with visualized detections to.
31 | tensorboard: Instance of keras.callbacks.TensorBoard used to log the mAP value.
32 | weighted_average: Compute the mAP using the weighted average of precisions among classes.
33 | verbose: Set the verbosity level, by default this is set to 1.
34 | """
35 | self.generator = generator
36 | self.iou_threshold = iou_threshold
37 | self.score_threshold = score_threshold
38 | self.max_detections = max_detections
39 | self.save_path = save_path
40 | self.tensorboard = tensorboard
41 | self.weighted_average = weighted_average
42 | self.verbose = verbose
43 |
44 | super(Evaluate, self).__init__()
45 |
46 | def on_epoch_end(self, epoch, logs=None):
47 | logs = logs or {}
48 |
49 | # run evaluation
50 | average_precisions = evaluate(
51 | self.generator,
52 | self.model,
53 | iou_threshold=self.iou_threshold,
54 | score_threshold=self.score_threshold,
55 | max_detections=self.max_detections,
56 | visualize=False,
57 | )
58 |
59 | # compute per class average precision
60 | total_instances = []
61 | precisions = []
62 | for label, (average_precision, num_annotations) in average_precisions.items():
63 | if self.verbose == 1:
64 | print('{:.0f} instances of class'.format(num_annotations),
65 | self.generator.label_to_name(label), 'with average precision: {:.4f}'.format(average_precision))
66 | total_instances.append(num_annotations)
67 | precisions.append(average_precision)
68 | if self.weighted_average:
69 | self.mean_ap = sum([a * b for a, b in zip(total_instances, precisions)]) / sum(total_instances)
70 | else:
71 | self.mean_ap = sum(precisions) / sum(x > 0 for x in total_instances)
72 |
73 | if self.tensorboard is not None and self.tensorboard.writer is not None:
74 | import tensorflow as tf
75 | summary = tf.Summary()
76 | summary_value = summary.value.add()
77 | summary_value.simple_value = self.mean_ap
78 | summary_value.tag = "mAP"
79 | self.tensorboard.writer.add_summary(summary, epoch)
80 |
81 | logs['mAP'] = self.mean_ap
82 |
83 | if self.verbose == 1:
84 | print('mAP: {:.4f}'.format(self.mean_ap))
85 |
86 |
87 | class RedirectModel(keras.callbacks.Callback):
88 | """
89 | Callback which wraps another callback, but executed on a different model.
90 |
91 | ```python
92 | model = keras.models.load_model('model.h5')
93 | model_checkpoint = ModelCheckpoint(filepath='snapshot.h5')
94 | parallel_model = multi_gpu_model(model, gpus=2)
95 | parallel_model.fit(X_train, Y_train, callbacks=[RedirectModel(model_checkpoint, model)])
96 | ```
97 |
98 | Args
99 | callback : callback to wrap.
100 | model : model to use when executing callbacks.
101 | """
102 |
103 | def __init__(self,
104 | callback,
105 | model):
106 | super(RedirectModel, self).__init__()
107 |
108 | self.callback = callback
109 | self.redirect_model = model
110 |
111 | def on_epoch_begin(self, epoch, logs=None):
112 | self.callback.on_epoch_begin(epoch, logs=logs)
113 |
114 | def on_epoch_end(self, epoch, logs=None):
115 | self.callback.on_epoch_end(epoch, logs=logs)
116 |
117 | def on_batch_begin(self, batch, logs=None):
118 | self.callback.on_batch_begin(batch, logs=logs)
119 |
120 | def on_batch_end(self, batch, logs=None):
121 | self.callback.on_batch_end(batch, logs=logs)
122 |
123 | def on_train_begin(self, logs=None):
124 | # overwrite the model with our custom model
125 | self.callback.set_model(self.redirect_model)
126 |
127 | self.callback.on_train_begin(logs=logs)
128 |
129 | def on_train_end(self, logs=None):
130 | self.callback.on_train_end(logs=logs)
131 |
132 |
133 | class CocoEval(keras.callbacks.Callback):
134 | """ Performs COCO evaluation on each epoch.
135 | """
136 | def __init__(self, generator, tensorboard=None, threshold=0.05):
137 |         """ CocoEval callback initializer.
138 |
139 | Args
140 | generator : The generator used for creating validation data.
141 | tensorboard : If given, the results will be written to tensorboard.
142 | threshold : The score threshold to use.
143 | """
144 | self.generator = generator
145 | self.threshold = threshold
146 | self.tensorboard = tensorboard
147 |
148 | super(CocoEval, self).__init__()
149 |
150 | def on_epoch_end(self, epoch, logs=None):
151 | logs = logs or {}
152 |
153 | coco_tag = ['AP @[ IoU=0.50:0.95 | area= all | maxDets=100 ]',
154 | 'AP @[ IoU=0.50 | area= all | maxDets=100 ]',
155 | 'AP @[ IoU=0.75 | area= all | maxDets=100 ]',
156 | 'AP @[ IoU=0.50:0.95 | area= small | maxDets=100 ]',
157 | 'AP @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]',
158 | 'AP @[ IoU=0.50:0.95 | area= large | maxDets=100 ]',
159 | 'AR @[ IoU=0.50:0.95 | area= all | maxDets= 1 ]',
160 | 'AR @[ IoU=0.50:0.95 | area= all | maxDets= 10 ]',
161 | 'AR @[ IoU=0.50:0.95 | area= all | maxDets=100 ]',
162 | 'AR @[ IoU=0.50:0.95 | area= small | maxDets=100 ]',
163 | 'AR @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]',
164 | 'AR @[ IoU=0.50:0.95 | area= large | maxDets=100 ]']
165 | coco_eval_stats = evaluate_coco(self.generator, self.model, self.threshold)
166 | if coco_eval_stats is not None and self.tensorboard is not None and self.tensorboard.writer is not None:
167 | import tensorflow as tf
168 | summary = tf.Summary()
169 | for index, result in enumerate(coco_eval_stats):
170 | summary_value = summary.value.add()
171 | summary_value.simple_value = result
172 | summary_value.tag = '{}. {}'.format(index + 1, coco_tag[index])
173 | self.tensorboard.writer.add_summary(summary, epoch)
174 | logs[coco_tag[index]] = result
175 |
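176 |
177 | # Illustrative wiring sketch: Evaluate runs on `self.model`, which Keras sets to the
178 | # model being trained. When that model is not the one you want to evaluate (e.g. a
179 | # multi-GPU wrapper, or the loss-only FSAF training graph), wrap the callback in
180 | # RedirectModel so it is executed against the prediction model instead.
181 | # `train_generator`, `validation_generator`, `prediction_model` and `training_model`
182 | # below are placeholders created elsewhere (e.g. by the training script):
183 | #
184 | #     evaluation = Evaluate(validation_generator, weighted_average=True)
185 | #     evaluation = RedirectModel(evaluation, prediction_model)
186 | #     training_model.fit_generator(train_generator, epochs=50, callbacks=[evaluation])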
--------------------------------------------------------------------------------
/generators/voc_generator.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 Fizyr (https://fizyr.com)
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from generators.generator import Generator
18 | from utils.image import read_image_bgr
19 |
20 | import os
21 | import numpy as np
22 | from six import raise_from
23 | from PIL import Image
24 |
25 | try:
26 | import xml.etree.cElementTree as ET
27 | except ImportError:
28 | import xml.etree.ElementTree as ET
29 |
30 | voc_classes = {
31 | 'aeroplane': 0,
32 | 'bicycle': 1,
33 | 'bird': 2,
34 | 'boat': 3,
35 | 'bottle': 4,
36 | 'bus': 5,
37 | 'car': 6,
38 | 'cat': 7,
39 | 'chair': 8,
40 | 'cow': 9,
41 | 'diningtable': 10,
42 | 'dog': 11,
43 | 'horse': 12,
44 | 'motorbike': 13,
45 | 'person': 14,
46 | 'pottedplant': 15,
47 | 'sheep': 16,
48 | 'sofa': 17,
49 | 'train': 18,
50 | 'tvmonitor': 19
51 | }
52 |
53 |
54 | def _findNode(parent, name, debug_name=None, parse=None):
55 | if debug_name is None:
56 | debug_name = name
57 |
58 | result = parent.find(name)
59 | if result is None:
60 | raise ValueError('missing element \'{}\''.format(debug_name))
61 | if parse is not None:
62 | try:
63 | return parse(result.text)
64 | except ValueError as e:
65 | raise_from(ValueError('illegal value for \'{}\': {}'.format(debug_name, e)), None)
66 | return result
67 |
68 |
69 | class PascalVocGenerator(Generator):
70 | """
71 | Generate data for a Pascal VOC dataset.
72 |
73 | See http://host.robots.ox.ac.uk/pascal/VOC/ for more information.
74 | """
75 |
76 | def __init__(
77 | self,
78 | data_dir,
79 | set_name,
80 | classes=voc_classes,
81 | image_extension='.jpg',
82 | skip_truncated=False,
83 | skip_difficult=False,
84 | **kwargs
85 | ):
86 | """
87 | Initialize a Pascal VOC data generator.
88 |
89 | Args:
90 | data_dir: the path of directory which contains ImageSets directory
91 | set_name: test|trainval|train|val
92 | classes: class names tos id mapping
93 | image_extension: image filename ext
94 | skip_truncated:
95 | skip_difficult:
96 | **kwargs:
97 | """
98 | self.data_dir = data_dir
99 | self.set_name = set_name
100 | self.classes = classes
101 | self.image_names = [l.strip().split(None, 1)[0] for l in
102 | open(os.path.join(data_dir, 'ImageSets', 'Main', set_name + '.txt')).readlines()]
103 | self.image_extension = image_extension
104 | self.skip_truncated = skip_truncated
105 | self.skip_difficult = skip_difficult
106 | # class ids to names mapping
107 | self.labels = {}
108 | for key, value in self.classes.items():
109 | self.labels[value] = key
110 |
111 | super(PascalVocGenerator, self).__init__(**kwargs)
112 |
113 | def size(self):
114 | """
115 | Size of the dataset.
116 | """
117 | return len(self.image_names)
118 |
119 | def num_classes(self):
120 | """
121 | Number of classes in the dataset.
122 | """
123 | return len(self.classes)
124 |
125 | def has_label(self, label):
126 | """
127 | Return True if label is a known label.
128 | """
129 | return label in self.labels
130 |
131 | def has_name(self, name):
132 | """
133 | Returns True if name is a known class.
134 | """
135 | return name in self.classes
136 |
137 | def name_to_label(self, name):
138 | """
139 | Map name to label.
140 | """
141 | return self.classes[name]
142 |
143 | def label_to_name(self, label):
144 | """
145 | Map label to name.
146 | """
147 | return self.labels[label]
148 |
149 | def image_aspect_ratio(self, image_index):
150 | """
151 | Compute the aspect ratio for an image with image_index.
152 | """
153 | path = os.path.join(self.data_dir, 'JPEGImages', self.image_names[image_index] + self.image_extension)
154 | image = Image.open(path)
155 | return float(image.width) / float(image.height)
156 |
157 | def load_image(self, image_index):
158 | """
159 | Load an image at the image_index.
160 | """
161 | path = os.path.join(self.data_dir, 'JPEGImages', self.image_names[image_index] + self.image_extension)
162 | return read_image_bgr(path)
163 |
164 | def __parse_annotation(self, element):
165 | """
166 | Parse an annotation given an XML element.
167 | """
168 | truncated = _findNode(element, 'truncated', parse=int)
169 | difficult = _findNode(element, 'difficult', parse=int)
170 |
171 | class_name = _findNode(element, 'name').text
172 | if class_name not in self.classes:
173 | raise ValueError('class name \'{}\' not found in classes: {}'.format(class_name, list(self.classes.keys())))
174 |
175 | box = np.zeros((4,))
176 | label = self.name_to_label(class_name)
177 |
178 | bndbox = _findNode(element, 'bndbox')
179 | box[0] = _findNode(bndbox, 'xmin', 'bndbox.xmin', parse=float) - 1
180 | box[1] = _findNode(bndbox, 'ymin', 'bndbox.ymin', parse=float) - 1
181 | box[2] = _findNode(bndbox, 'xmax', 'bndbox.xmax', parse=float) - 1
182 | box[3] = _findNode(bndbox, 'ymax', 'bndbox.ymax', parse=float) - 1
183 |
184 | return truncated, difficult, box, label
185 |
186 | def __parse_annotations(self, xml_root):
187 | """
188 | Parse all annotations under the xml_root.
189 | """
190 | annotations = {'labels': np.empty((0,), dtype=np.int32),
191 | 'bboxes': np.empty((0, 4))}
192 | for i, element in enumerate(xml_root.iter('object')):
193 | try:
194 | truncated, difficult, box, label = self.__parse_annotation(element)
195 | except ValueError as e:
196 | raise_from(ValueError('could not parse object #{}: {}'.format(i, e)), None)
197 |
198 | if truncated and self.skip_truncated:
199 | continue
200 | if difficult and self.skip_difficult:
201 | continue
202 |
203 | annotations['bboxes'] = np.concatenate([annotations['bboxes'], [box]])
204 | annotations['labels'] = np.concatenate([annotations['labels'], [label]])
205 |
206 | return annotations
207 |
208 | def load_annotations(self, image_index):
209 | """
210 | Load annotations for an image_index.
211 | """
212 | filename = self.image_names[image_index] + '.xml'
213 | try:
214 | tree = ET.parse(os.path.join(self.data_dir, 'Annotations', filename))
215 | return self.__parse_annotations(tree.getroot())
216 | except ET.ParseError as e:
217 | raise_from(ValueError('invalid annotations file: {}: {}'.format(filename, e)), None)
218 | except ValueError as e:
219 | raise_from(ValueError('invalid annotations file: {}: {}'.format(filename, e)), None)
220 |
--------------------------------------------------------------------------------
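A minimal usage sketch (illustration only) for the PascalVocGenerator defined in /generators/voc_generator.py above. It assumes the base Generator accepts its default keyword arguments; the dataset path is a placeholder for a standard VOCdevkit layout.

    # Hypothetical usage sketch; the dataset path below is a placeholder.
    from generators.voc_generator import PascalVocGenerator

    generator = PascalVocGenerator(
        data_dir='/path/to/VOCdevkit/VOC2007',  # must contain ImageSets/Main, JPEGImages and Annotations
        set_name='trainval',
        skip_difficult=True,
    )
    print('images:', generator.size(), 'classes:', generator.num_classes())
    image = generator.load_image(0)              # BGR numpy array from read_image_bgr
    annotations = generator.load_annotations(0)  # {'labels': (n,), 'bboxes': (n, 4)} with (x1, y1, x2, y2) boxes
    print(annotations['bboxes'].shape, [generator.label_to_name(l) for l in annotations['labels']])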
/yolo/generators/pascal.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 Fizyr (https://fizyr.com)
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import cv2
18 | import numpy as np
19 | import os
20 | from six import raise_from
21 |
22 | from yolo.generators.common import Generator
23 |
24 | try:
25 | import xml.etree.cElementTree as ET
26 | except ImportError:
27 | import xml.etree.ElementTree as ET
28 |
29 | voc_classes = {
30 | 'aeroplane': 0,
31 | 'bicycle': 1,
32 | 'bird': 2,
33 | 'boat': 3,
34 | 'bottle': 4,
35 | 'bus': 5,
36 | 'car': 6,
37 | 'cat': 7,
38 | 'chair': 8,
39 | 'cow': 9,
40 | 'diningtable': 10,
41 | 'dog': 11,
42 | 'horse': 12,
43 | 'motorbike': 13,
44 | 'person': 14,
45 | 'pottedplant': 15,
46 | 'sheep': 16,
47 | 'sofa': 17,
48 | 'train': 18,
49 | 'tvmonitor': 19
50 | }
51 |
52 |
53 | def _findNode(parent, name, debug_name=None, parse=None):
54 | if debug_name is None:
55 | debug_name = name
56 |
57 | result = parent.find(name)
58 | if result is None:
59 | raise ValueError('missing element \'{}\''.format(debug_name))
60 | if parse is not None:
61 | try:
62 | return parse(result.text)
63 | except ValueError as e:
64 | raise_from(ValueError('illegal value for \'{}\': {}'.format(debug_name, e)), None)
65 | return result
66 |
67 |
68 | class PascalVocGenerator(Generator):
69 | """
70 | Generate data for a Pascal VOC dataset.
71 |
72 | See http://host.robots.ox.ac.uk/pascal/VOC/ for more information.
73 | """
74 |
75 | def __init__(
76 | self,
77 | data_dir,
78 | set_name,
79 | classes=voc_classes,
80 | image_extension='.jpg',
81 | skip_truncated=False,
82 | skip_difficult=False,
83 | **kwargs
84 | ):
85 | """
86 | Initialize a Pascal VOC data generator.
87 |
88 | Args:
89 |             data_dir: the path of the directory which contains the ImageSets directory
90 |             set_name: test|trainval|train|val
91 |             classes: class name to id mapping
92 |             image_extension: image filename extension
93 |             skip_truncated: if True, skip objects marked as truncated
94 |             skip_difficult: if True, skip objects marked as difficult
95 | **kwargs:
96 | """
97 | self.data_dir = data_dir
98 | self.set_name = set_name
99 | self.classes = classes
100 | self.image_names = [l.strip().split(None, 1)[0] for l in
101 | open(os.path.join(data_dir, 'ImageSets', 'Main', set_name + '.txt')).readlines()]
102 | self.image_extension = image_extension
103 | self.skip_truncated = skip_truncated
104 | self.skip_difficult = skip_difficult
105 | # class ids to names mapping
106 | self.labels = {}
107 | for key, value in self.classes.items():
108 | self.labels[value] = key
109 |
110 | super(PascalVocGenerator, self).__init__(**kwargs)
111 |
112 | def size(self):
113 | """
114 | Size of the dataset.
115 | """
116 | return len(self.image_names)
117 |
118 | def num_classes(self):
119 | """
120 | Number of classes in the dataset.
121 | """
122 | return len(self.classes)
123 |
124 | def has_label(self, label):
125 | """
126 | Return True if label is a known label.
127 | """
128 | return label in self.labels
129 |
130 | def has_name(self, name):
131 | """
132 | Returns True if name is a known class.
133 | """
134 | return name in self.classes
135 |
136 | def name_to_label(self, name):
137 | """
138 | Map name to label.
139 | """
140 | return self.classes[name]
141 |
142 | def label_to_name(self, label):
143 | """
144 | Map label to name.
145 | """
146 | return self.labels[label]
147 |
148 | def image_aspect_ratio(self, image_index):
149 | """
150 | Compute the aspect ratio for an image with image_index.
151 | """
152 | path = os.path.join(self.data_dir, 'JPEGImages', self.image_names[image_index] + self.image_extension)
153 | image = cv2.imread(path)
154 | h, w = image.shape[:2]
155 | return float(w) / float(h)
156 |
157 | def load_image(self, image_index):
158 | """
159 | Load an image at the image_index.
160 | """
161 | path = os.path.join(self.data_dir, 'JPEGImages', self.image_names[image_index] + self.image_extension)
162 | image = cv2.imread(path)
163 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
164 | return image
165 |
166 | def __parse_annotation(self, element):
167 | """
168 | Parse an annotation given an XML element.
169 | """
170 | truncated = _findNode(element, 'truncated', parse=int)
171 | difficult = _findNode(element, 'difficult', parse=int)
172 |
173 | class_name = _findNode(element, 'name').text
174 | if class_name not in self.classes:
175 | raise ValueError('class name \'{}\' not found in classes: {}'.format(class_name, list(self.classes.keys())))
176 |
177 | box = np.zeros((4,))
178 | label = self.name_to_label(class_name)
179 |
180 | bndbox = _findNode(element, 'bndbox')
181 | box[0] = _findNode(bndbox, 'xmin', 'bndbox.xmin', parse=float) - 1
182 | box[1] = _findNode(bndbox, 'ymin', 'bndbox.ymin', parse=float) - 1
183 | box[2] = _findNode(bndbox, 'xmax', 'bndbox.xmax', parse=float) - 1
184 | box[3] = _findNode(bndbox, 'ymax', 'bndbox.ymax', parse=float) - 1
185 |
186 | return truncated, difficult, box, label
187 |
188 | def __parse_annotations(self, xml_root):
189 | """
190 | Parse all annotations under the xml_root.
191 | """
192 | annotations = {'labels': np.empty((0,), dtype=np.int32),
193 | 'bboxes': np.empty((0, 4))}
194 | for i, element in enumerate(xml_root.iter('object')):
195 | try:
196 | truncated, difficult, box, label = self.__parse_annotation(element)
197 | except ValueError as e:
198 | raise_from(ValueError('could not parse object #{}: {}'.format(i, e)), None)
199 |
200 | if truncated and self.skip_truncated:
201 | continue
202 | if difficult and self.skip_difficult:
203 | continue
204 |
205 | annotations['bboxes'] = np.concatenate([annotations['bboxes'], [box]])
206 | annotations['labels'] = np.concatenate([annotations['labels'], [label]])
207 |
208 | return annotations
209 |
210 | def load_annotations(self, image_index):
211 | """
212 | Load annotations for an image_index.
213 | """
214 | filename = self.image_names[image_index] + '.xml'
215 | try:
216 | tree = ET.parse(os.path.join(self.data_dir, 'Annotations', filename))
217 | return self.__parse_annotations(tree.getroot())
218 | except ET.ParseError as e:
219 | raise_from(ValueError('invalid annotations file: {}: {}'.format(filename, e)), None)
220 | except ValueError as e:
221 | raise_from(ValueError('invalid annotations file: {}: {}'.format(filename, e)), None)
222 |
--------------------------------------------------------------------------------
/yolo/eval/coco.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 Fizyr (https://fizyr.com)
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import keras
18 | from pycocotools.cocoeval import COCOeval
19 | import numpy as np
20 | import json
21 | from tqdm import trange
22 | import cv2
23 |
24 | from generators.coco import CocoGenerator
25 | from model import yolo_body
26 |
27 |
28 | def evaluate(generator, model, threshold=0.01):
29 | """
30 |     Use pycocotools to evaluate a COCO model on a dataset.
31 |
32 | Args
33 | generator: The generator for generating the evaluation data.
34 | model: The model to evaluate.
35 | threshold: The score threshold to use.
36 | """
37 | # start collecting results
38 | results = []
39 | image_ids = []
40 | for index in trange(generator.size(), desc='COCO evaluation: '):
41 | image = generator.load_image(index)
42 | src_image = image.copy()
43 | image_shape = image.shape[:2]
44 | image_shape = np.array(image_shape)
45 | image = generator.preprocess_image(image)
46 |
47 | # run network
48 | detections = model.predict_on_batch([np.expand_dims(image, axis=0), np.expand_dims(image_shape, axis=0)])[0]
49 |
50 | # change to (x, y, w, h) (MS COCO standard)
51 | boxes = np.zeros((detections.shape[0], 4), dtype=np.int32)
52 | # xmin
53 | boxes[:, 0] = np.maximum(np.round(detections[:, 1]).astype(np.int32), 0)
54 | # ymin
55 | boxes[:, 1] = np.maximum(np.round(detections[:, 0]).astype(np.int32), 0)
56 | # w
57 | boxes[:, 2] = np.minimum(np.round(detections[:, 3] - detections[:, 1]).astype(np.int32), image_shape[1])
58 | # h
59 | boxes[:, 3] = np.minimum(np.round(detections[:, 2] - detections[:, 0]).astype(np.int32), image_shape[0])
60 | scores = detections[:, 4]
61 | class_ids = detections[:, 5].astype(np.int32)
62 | # compute predicted labels and scores
63 | for box, score, class_id in zip(boxes, scores, class_ids):
64 | # scores are sorted, so we can break
65 | if score < threshold:
66 | break
67 |
68 | # append detection for each positively labeled class
69 | image_result = {
70 | 'image_id': generator.image_ids[index],
71 | 'category_id': generator.label_to_coco_label(class_id),
72 | 'score': float(score),
73 | 'bbox': box.tolist(),
74 | }
75 | # append detection to results
76 | results.append(image_result)
77 | class_name = generator.label_to_name(class_id)
78 | ret, baseline = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
79 | cv2.rectangle(src_image, (box[0], box[1]), (box[0] + box[2], box[1] + box[3]), (0, 255, 0), 1)
80 | cv2.putText(src_image, class_name, (box[0], box[1] + box[3] - baseline), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
81 | cv2.namedWindow('image', cv2.WINDOW_NORMAL)
82 | cv2.imshow('image', src_image)
83 | cv2.waitKey(0)
84 | # append image to list of processed images
85 | image_ids.append(generator.image_ids[index])
86 |
87 | if not len(results):
88 | return
89 |
90 | # write output
91 | json.dump(results, open('{}_bbox_results.json'.format(generator.set_name), 'w'), indent=4)
92 | json.dump(image_ids, open('{}_processed_image_ids.json'.format(generator.set_name), 'w'), indent=4)
93 |
94 | # load results in COCO evaluation tool
95 | coco_true = generator.coco
96 | coco_pred = coco_true.loadRes('{}_bbox_results.json'.format(generator.set_name))
97 |
98 | # run COCO evaluation
99 | coco_eval = COCOeval(coco_true, coco_pred, 'bbox')
100 | coco_eval.params.imgIds = image_ids
101 | coco_eval.evaluate()
102 | coco_eval.accumulate()
103 | coco_eval.summarize()
104 | return coco_eval.stats
105 |
106 |
107 | class Evaluate(keras.callbacks.Callback):
108 | """ Performs COCO evaluation on each epoch.
109 | """
110 |
111 | def __init__(self, generator, model, tensorboard=None, threshold=0.01):
112 | """ Evaluate callback initializer.
113 |
114 | Args
115 | generator : The generator used for creating validation data.
116 |             model : The prediction model used for evaluation.
117 | tensorboard : If given, the results will be written to tensorboard.
118 | threshold : The score threshold to use.
119 | """
120 | self.generator = generator
121 | self.active_model = model
122 | self.threshold = threshold
123 | self.tensorboard = tensorboard
124 |
125 | super(Evaluate, self).__init__()
126 |
127 | def on_epoch_end(self, epoch, logs=None):
128 | logs = logs or {}
129 |
130 | coco_tag = ['AP @[ IoU=0.50:0.95 | area= all | maxDets=100 ]',
131 | 'AP @[ IoU=0.50 | area= all | maxDets=100 ]',
132 | 'AP @[ IoU=0.75 | area= all | maxDets=100 ]',
133 | 'AP @[ IoU=0.50:0.95 | area= small | maxDets=100 ]',
134 | 'AP @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]',
135 | 'AP @[ IoU=0.50:0.95 | area= large | maxDets=100 ]',
136 | 'AR @[ IoU=0.50:0.95 | area= all | maxDets= 1 ]',
137 | 'AR @[ IoU=0.50:0.95 | area= all | maxDets= 10 ]',
138 | 'AR @[ IoU=0.50:0.95 | area= all | maxDets=100 ]',
139 | 'AR @[ IoU=0.50:0.95 | area= small | maxDets=100 ]',
140 | 'AR @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]',
141 | 'AR @[ IoU=0.50:0.95 | area= large | maxDets=100 ]']
142 |         coco_eval_stats = evaluate(self.generator, self.active_model, self.threshold)  # use the model passed at construction; Keras rebinds self.model to the training model
143 | if coco_eval_stats is not None and self.tensorboard is not None and self.tensorboard.writer is not None:
144 | import tensorflow as tf
145 | summary = tf.Summary()
146 | for index, result in enumerate(coco_eval_stats):
147 | summary_value = summary.value.add()
148 | summary_value.simple_value = result
149 | summary_value.tag = '{}. {}'.format(index + 1, coco_tag[index])
150 | self.tensorboard.writer.add_summary(summary, epoch)
151 | logs[coco_tag[index]] = result
152 |
153 |
154 | if __name__ == '__main__':
155 | dataset_dir = '/home/adam/.keras/datasets/coco/2017_118_5'
156 | test_generator = CocoGenerator(
157 | anchors_path='yolo_anchors.txt',
158 | data_dir=dataset_dir,
159 | set_name='test-dev2017',
160 | shuffle_groups=False,
161 | )
162 | input_shape = (416, 416)
163 | model, prediction_model = yolo_body(test_generator.anchors, num_classes=80)
164 | model.load_weights('checkpoints/yolov3_weights.h5', by_name=True)
165 | coco_eval_stats = evaluate(test_generator, model)
166 | coco_tag = ['AP @[ IoU=0.50:0.95 | area= all | maxDets=100 ]',
167 | 'AP @[ IoU=0.50 | area= all | maxDets=100 ]',
168 | 'AP @[ IoU=0.75 | area= all | maxDets=100 ]',
169 | 'AP @[ IoU=0.50:0.95 | area= small | maxDets=100 ]',
170 | 'AP @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]',
171 | 'AP @[ IoU=0.50:0.95 | area= large | maxDets=100 ]',
172 | 'AR @[ IoU=0.50:0.95 | area= all | maxDets= 1 ]',
173 | 'AR @[ IoU=0.50:0.95 | area= all | maxDets= 10 ]',
174 | 'AR @[ IoU=0.50:0.95 | area= all | maxDets=100 ]',
175 | 'AR @[ IoU=0.50:0.95 | area= small | maxDets=100 ]',
176 | 'AR @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]',
177 | 'AR @[ IoU=0.50:0.95 | area= large | maxDets=100 ]']
178 | if coco_eval_stats is not None:
179 | for index, result in enumerate(coco_eval_stats):
180 | print([coco_tag[index]], result)
181 |
--------------------------------------------------------------------------------
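For reference, a small self-contained numpy sketch (illustration only, with invented values) of the box conversion performed inside evaluate() above: detections arrive as (ymin, xmin, ymax, xmax, score, class) and pycocotools expects [x, y, width, height].

    import numpy as np

    detections = np.array([[40.2, 10.7, 120.9, 90.4, 0.87, 3]])  # toy (ymin, xmin, ymax, xmax, score, class)
    image_shape = np.array([200, 300])                            # (height, width)

    boxes = np.zeros((detections.shape[0], 4), dtype=np.int32)
    boxes[:, 0] = np.maximum(np.round(detections[:, 1]).astype(np.int32), 0)                                   # x = xmin
    boxes[:, 1] = np.maximum(np.round(detections[:, 0]).astype(np.int32), 0)                                   # y = ymin
    boxes[:, 2] = np.minimum(np.round(detections[:, 3] - detections[:, 1]).astype(np.int32), image_shape[1])   # w
    boxes[:, 3] = np.minimum(np.round(detections[:, 2] - detections[:, 0]).astype(np.int32), image_shape[0])   # h
    print(boxes)  # [[11 40 80 81]]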
/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 Fizyr (https://fizyr.com)
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import keras
18 | import keras.backend as K
19 | import tensorflow as tf
20 |
21 |
22 | def focal(alpha=0.25, gamma=2.0):
23 | """
24 | Create a functor for computing the focal loss.
25 |
26 | Args
27 | alpha: Scale the focal weight with alpha.
28 | gamma: Take the power of the focal weight with gamma.
29 |
30 | Returns
31 | A functor that computes the focal loss using the alpha and gamma.
32 | """
33 |
34 | def _focal(y_true, y_pred):
35 | """
36 | Compute the focal loss given the target tensor and the predicted tensor.
37 |
38 | As defined in https://arxiv.org/abs/1708.02002
39 |
40 | Args
41 | y_true: Tensor of target data from the generator with shape (B, N, num_classes).
42 | y_pred: Tensor of predicted data from the network with shape (B, N, num_classes).
43 |
44 | Returns
45 | The focal loss of y_pred w.r.t. y_true.
46 | """
47 | # compute the focal loss
48 | alpha_factor = K.ones_like(y_true) * alpha
49 | alpha_factor = tf.where(K.equal(y_true, 1), alpha_factor, 1 - alpha_factor)
50 | focal_weight = tf.where(K.equal(y_true, 1), 1 - y_pred, y_pred)
51 | focal_weight = alpha_factor * focal_weight ** gamma
52 | cls_loss = focal_weight * K.binary_crossentropy(y_true, y_pred)
53 |
54 | # compute the normalizer: the number of positive anchors
55 | normalizer = K.cast(K.shape(y_pred)[1], K.floatx())
56 | normalizer = K.maximum(K.cast_to_floatx(1.0), normalizer)
57 |
58 | return K.sum(cls_loss) / normalizer
59 |
60 | return _focal
61 |
62 |
63 | def smooth_l1(sigma=3.0):
64 | """
65 | Create a smooth L1 loss functor.
66 |
67 | Args
68 | sigma: This argument defines the point where the loss changes from L2 to L1.
69 |
70 | Returns
71 | A functor for computing the smooth L1 loss given target data and predicted data.
72 | """
73 | sigma_squared = sigma ** 2
74 |
75 | def _smooth_l1(y_true, y_pred):
76 | """ Compute the smooth L1 loss of y_pred w.r.t. y_true.
77 |
78 | Args
79 | y_true: Tensor from the generator of shape (B, N, 5). The last value for each box is the state of the anchor (ignore, negative, positive).
80 | y_pred: Tensor from the network of shape (B, N, 4).
81 |
82 | Returns
83 | The smooth L1 loss of y_pred w.r.t. y_true.
84 | """
85 | # separate target and state
86 | regression = y_pred
87 | regression_target = y_true[:, :, :-1]
88 | anchor_state = y_true[:, :, -1]
89 |
90 | # filter out "ignore" anchors
91 | indices = tf.where(K.equal(anchor_state, 1))
92 | regression = tf.gather_nd(regression, indices)
93 | regression_target = tf.gather_nd(regression_target, indices)
94 |
95 | # compute smooth L1 loss
96 | # f(x) = 0.5 * (sigma * x)^2 if |x| < 1 / sigma / sigma
97 | # |x| - 0.5 / sigma / sigma otherwise
98 | regression_diff = regression - regression_target
99 | regression_diff = K.abs(regression_diff)
100 | regression_loss = tf.where(
101 | K.less(regression_diff, 1.0 / sigma_squared),
102 | 0.5 * sigma_squared * K.pow(regression_diff, 2),
103 | regression_diff - 0.5 / sigma_squared
104 | )
105 |
106 | # compute the normalizer: the number of positive anchors
107 | normalizer = K.maximum(1, K.shape(indices)[0])
108 | normalizer = K.cast(normalizer, dtype=K.floatx())
109 | return K.sum(regression_loss) / normalizer
110 |
111 | return _smooth_l1
112 |
113 |
114 | def iou():
115 | def _iou(y_true, y_pred):
116 | y_true = tf.maximum(y_true, 0)
117 | pred_left = y_pred[:, :, 0]
118 | pred_top = y_pred[:, :, 1]
119 | pred_right = y_pred[:, :, 2]
120 | pred_bottom = y_pred[:, :, 3]
121 |
122 | # (num_pos, )
123 | target_left = y_true[:, :, 0]
124 | target_top = y_true[:, :, 1]
125 | target_right = y_true[:, :, 2]
126 | target_bottom = y_true[:, :, 3]
127 |
128 | target_area = (target_left + target_right) * (target_top + target_bottom)
129 | pred_area = (pred_left + pred_right) * (pred_top + pred_bottom)
130 | w_intersect = tf.minimum(pred_left, target_left) + tf.minimum(pred_right, target_right)
131 | h_intersect = tf.minimum(pred_bottom, target_bottom) + tf.minimum(pred_top, target_top)
132 |
133 | area_intersect = w_intersect * h_intersect
134 | area_union = target_area + pred_area - area_intersect
135 |
136 | # (num_pos, )
137 | iou_loss = -tf.log((area_intersect + 1e-7) / (area_union + 1e-7))
138 | # compute the normalizer: the number of positive anchors
139 | normalizer = K.maximum(1, K.shape(y_true)[1])
140 | normalizer = K.cast(normalizer, dtype=K.floatx())
141 | return K.sum(iou_loss) / normalizer
142 |
143 | return _iou
144 |
145 |
146 | def focal_with_mask(alpha=0.25, gamma=2.0):
147 | """
148 | Create a functor for computing the focal loss.
149 |
150 | Args
151 | alpha: Scale the focal weight with alpha.
152 | gamma: Take the power of the focal weight with gamma.
153 |
154 | Returns
155 | A functor that computes the focal loss using the alpha and gamma.
156 | """
157 |
158 | def _focal(inputs):
159 | """
160 | Compute the focal loss given the target tensor and the predicted tensor.
161 |
162 | As defined in https://arxiv.org/abs/1708.02002
163 |
164 | Args
165 | y_true: Tensor of target data from the generator with shape (B, N, num_classes).
166 | y_pred: Tensor of predicted data from the network with shape (B, N, num_classes).
167 | cls_mask: (B, N)
168 | cls_num_pos: (B, )
169 |
170 | Returns
171 | The focal loss of y_pred w.r.t. y_true.
172 | """
173 | # compute the focal loss
174 | y_true, y_pred, cls_mask, cls_num_pos = inputs[0], inputs[1], inputs[2], inputs[3]
175 | alpha_factor = K.ones_like(y_true) * alpha
176 | alpha_factor = tf.where(K.equal(y_true, 1), alpha_factor, 1 - alpha_factor)
177 | focal_weight = tf.where(K.equal(y_true, 1), 1 - y_pred, y_pred)
178 | focal_weight = alpha_factor * focal_weight ** gamma
179 | # (B, N) --> (B, N, 1)
180 | cls_mask = tf.cast(cls_mask, tf.float32)
181 | cls_mask = tf.expand_dims(cls_mask, axis=-1)
182 | # (B, N, num_classes) * (B, N, 1)
183 | masked_cls_loss = focal_weight * K.binary_crossentropy(y_true, y_pred) * cls_mask
184 | # compute the normalizer: the number of positive locations
185 | normalizer = K.maximum(K.cast_to_floatx(1.0), tf.reduce_sum(cls_num_pos))
186 | return K.sum(masked_cls_loss) / normalizer
187 |
188 | return _focal
189 |
190 |
191 | def iou_with_mask():
192 | def _iou(inputs):
193 | """
194 |         Compute the masked IoU loss between target and predicted boxes.
195 | 
196 |         Args:
197 |             inputs: [y_true (B, N, 4), y_pred (B, N, 4), regr_mask (B, N)]
198 | 
199 |         Returns: the IoU loss summed over positive locations, divided by their count.
200 | """
201 | y_true, y_pred, regr_mask = inputs[0], inputs[1], inputs[2]
202 | y_true = tf.maximum(y_true, 0)
203 | pred_left = y_pred[:, :, 0]
204 | pred_top = y_pred[:, :, 1]
205 | pred_right = y_pred[:, :, 2]
206 | pred_bottom = y_pred[:, :, 3]
207 |
208 | # (B, N)
209 | target_left = y_true[:, :, 0]
210 | target_top = y_true[:, :, 1]
211 | target_right = y_true[:, :, 2]
212 | target_bottom = y_true[:, :, 3]
213 |
214 | # (B, N)
215 | target_area = (target_left + target_right) * (target_top + target_bottom)
216 | masked_target_area = tf.boolean_mask(target_area, regr_mask)
217 | pred_area = (pred_left + pred_right) * (pred_top + pred_bottom)
218 | masked_pred_area = tf.boolean_mask(pred_area, regr_mask)
219 | w_intersect = tf.minimum(pred_left, target_left) + tf.minimum(pred_right, target_right)
220 | h_intersect = tf.minimum(pred_bottom, target_bottom) + tf.minimum(pred_top, target_top)
221 |
222 | area_intersect = w_intersect * h_intersect
223 | masked_area_intersect = tf.boolean_mask(area_intersect, regr_mask)
224 | masked_area_union = masked_target_area + masked_pred_area - masked_area_intersect
225 |
226 | # (B, N)
227 | masked_iou_loss = -tf.log((masked_area_intersect + 1e-7) / (masked_area_union + 1e-7))
228 | regr_mask = tf.cast(regr_mask, tf.float32)
229 | # compute the normalizer: the number of positive locations
230 | regr_num_pos = tf.reduce_sum(regr_mask)
231 | normalizer = K.maximum(K.cast_to_floatx(1.), regr_num_pos)
232 | return K.sum(masked_iou_loss) / normalizer
233 |
234 | return _iou
235 |
--------------------------------------------------------------------------------
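A small numpy sketch (illustration only) of the focal weighting implemented by _focal above, for one positive and one negative target; the normalization by the number of anchors is omitted here.

    import numpy as np

    alpha, gamma = 0.25, 2.0
    y_true = np.array([1.0, 0.0])   # one positive target, one negative target
    y_pred = np.array([0.9, 0.2])   # corresponding predicted probabilities

    alpha_factor = np.where(y_true == 1, alpha, 1 - alpha)
    focal_weight = np.where(y_true == 1, 1 - y_pred, y_pred)
    focal_weight = alpha_factor * focal_weight ** gamma
    bce = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    cls_loss = focal_weight * bce
    print(cls_loss)  # both examples are easy, so their contributions are strongly down-weighted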
/generators/csv_generator.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 yhenon (https://github.com/yhenon/)
3 | Copyright 2017-2018 Fizyr (https://fizyr.com)
4 |
5 | Licensed under the Apache License, Version 2.0 (the "License");
6 | you may not use this file except in compliance with the License.
7 | You may obtain a copy of the License at
8 |
9 | http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | Unless required by applicable law or agreed to in writing, software
12 | distributed under the License is distributed on an "AS IS" BASIS,
13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | See the License for the specific language governing permissions and
15 | limitations under the License.
16 | """
17 |
18 | from .generator import Generator
19 | from utils.image import read_image_bgr
20 |
21 | import numpy as np
22 | from PIL import Image
23 | from six import raise_from
24 | import csv
25 | import sys
26 | import os.path
27 | from collections import OrderedDict
28 |
29 |
30 | def _parse(value, function, fmt):
31 | """
32 | Parse a string into a value, and format a nice ValueError if it fails.
33 |
34 | Returns `function(value)`.
35 |     Any `ValueError` raised is caught and a new `ValueError` is raised
36 | with message `fmt.format(e)`, where `e` is the caught `ValueError`.
37 | """
38 | try:
39 | return function(value)
40 | except ValueError as e:
41 | raise_from(ValueError(fmt.format(e)), None)
42 |
43 |
44 | def _read_classes(csv_reader):
45 | """
46 | Parse the classes file given by csv_reader.
47 | """
48 | result = OrderedDict()
49 | for line, row in enumerate(csv_reader):
50 | line += 1
51 |
52 | try:
53 | class_name, class_id = row
54 | except ValueError:
55 | raise_from(ValueError('line {}: format should be \'class_name,class_id\''.format(line)), None)
56 | class_id = _parse(class_id, int, 'line {}: malformed class ID: {{}}'.format(line))
57 |
58 | if class_name in result:
59 | raise ValueError('line {}: duplicate class name: \'{}\''.format(line, class_name))
60 | result[class_name] = class_id
61 | return result
62 |
63 |
64 | def _read_annotations(csv_reader, classes):
65 | """
66 | Read annotations from the csv_reader.
67 | """
68 | result = OrderedDict()
69 | for line, row in enumerate(csv_reader):
70 | line += 1
71 |
72 | try:
73 | img_file, x1, y1, x2, y2, class_name = row[:6]
74 | except ValueError:
75 | raise_from(ValueError(
76 | 'line {}: format should be \'img_file,x1,y1,x2,y2,class_name\' or \'img_file,,,,,\''.format(line)),
77 | None)
78 |
79 | if img_file not in result:
80 | result[img_file] = []
81 |
82 | # If a row contains only an image path, it's an image without annotations.
83 | if (x1, y1, x2, y2, class_name) == ('', '', '', '', ''):
84 | continue
85 |
86 | x1 = _parse(x1, int, 'line {}: malformed x1: {{}}'.format(line))
87 | y1 = _parse(y1, int, 'line {}: malformed y1: {{}}'.format(line))
88 | x2 = _parse(x2, int, 'line {}: malformed x2: {{}}'.format(line))
89 | y2 = _parse(y2, int, 'line {}: malformed y2: {{}}'.format(line))
90 |
91 | # Check that the bounding box is valid.
92 | if x2 <= x1:
93 | raise ValueError('line {}: x2 ({}) must be higher than x1 ({})'.format(line, x2, x1))
94 | if y2 <= y1:
95 | raise ValueError('line {}: y2 ({}) must be higher than y1 ({})'.format(line, y2, y1))
96 |
97 | # check if the current class name is correctly present
98 | if class_name not in classes:
99 | raise ValueError('line {}: unknown class name: \'{}\' (classes: {})'.format(line, class_name, classes))
100 |
101 | result[img_file].append({'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2, 'class': class_name})
102 | return result
103 |
104 |
105 | def _open_for_csv(path):
106 | """
107 | Open a file with flags suitable for csv.reader.
108 |
109 |     For Python 2 this means opening with mode 'rb'; for Python 3 this means mode 'r' with newline='', as csv.reader expects.
110 | """
111 | if sys.version_info[0] < 3:
112 | return open(path, 'rb')
113 | else:
114 | return open(path, 'r', newline='')
115 |
116 |
117 | class CSVGenerator(Generator):
118 | """
119 | Generate data for a custom CSV dataset.
120 |
121 | See https://github.com/fizyr/keras-retinanet#csv-datasets for more information.
122 | """
123 |
124 | def __init__(
125 | self,
126 | csv_data_file,
127 | csv_class_file,
128 | base_dir=None,
129 | **kwargs
130 | ):
131 | """
132 | Initialize a CSV data generator.
133 |
134 | Args
135 | csv_data_file: Path to the CSV annotations file.
136 | csv_class_file: Path to the CSV classes file.
137 | base_dir: Directory w.r.t. where the files are to be searched (defaults to the directory containing the csv_data_file).
138 | """
139 | self.image_names = []
140 | self.image_data = {}
141 | self.base_dir = base_dir
142 |
143 | # Take base_dir from annotations file if not explicitly specified.
144 | if self.base_dir is None:
145 | self.base_dir = os.path.dirname(csv_data_file)
146 |
147 | # parse the provided class file
148 | try:
149 | with _open_for_csv(csv_class_file) as file:
150 | # class_name --> class_id
151 | self.classes = _read_classes(csv.reader(file, delimiter=','))
152 | except ValueError as e:
153 | raise_from(ValueError('invalid CSV class file: {}: {}'.format(csv_class_file, e)), None)
154 |
155 | self.labels = {}
156 | # class_id --> class_name
157 | for key, value in self.classes.items():
158 | self.labels[value] = key
159 |
160 | # csv with img_path, x1, y1, x2, y2, class_name
161 | try:
162 | with _open_for_csv(csv_data_file) as file:
163 | # {'img_path1':[{'x1':xx,'y1':xx,'x2':xx,'y2':xx,'class':xx}...],...}
164 | self.image_data = _read_annotations(csv.reader(file, delimiter=','), self.classes)
165 | except ValueError as e:
166 | raise_from(ValueError('invalid CSV annotations file: {}: {}'.format(csv_data_file, e)), None)
167 | self.image_names = list(self.image_data.keys())
168 |
169 | super(CSVGenerator, self).__init__(**kwargs)
170 |
171 | def size(self):
172 | """
173 | Size of the dataset.
174 | """
175 | return len(self.image_names)
176 |
177 | def num_classes(self):
178 | """
179 | Number of classes in the dataset.
180 | """
181 | return max(self.classes.values()) + 1
182 |
183 | def has_label(self, label):
184 | """
185 | Return True if label is a known label.
186 | """
187 | return label in self.labels
188 |
189 | def has_name(self, name):
190 | """
191 | Returns True if name is a known class.
192 | """
193 | return name in self.classes
194 |
195 | def name_to_label(self, name):
196 | """
197 | Map name to label.
198 | """
199 | return self.classes[name]
200 |
201 | def label_to_name(self, label):
202 | """
203 | Map label to name.
204 | """
205 | return self.labels[label]
206 |
207 | def image_path(self, image_index):
208 | """
209 | Returns the image path for image_index.
210 | """
211 | return os.path.join(self.base_dir, self.image_names[image_index])
212 |
213 | def image_aspect_ratio(self, image_index):
214 | """
215 | Compute the aspect ratio for an image with image_index.
216 | """
217 | # PIL is fast for metadata
218 | image = Image.open(self.image_path(image_index))
219 | return float(image.width) / float(image.height)
220 |
221 | def load_image(self, image_index):
222 | """
223 | Load an image at the image_index.
224 | """
225 | return read_image_bgr(self.image_path(image_index))
226 |
227 | def load_annotations(self, image_index):
228 | """
229 | Load annotations for an image_index.
230 | """
231 | path = self.image_names[image_index]
232 | annotations = {'labels': np.empty((0,), dtype=np.int32), 'bboxes': np.empty((0, 4))}
233 |
234 | for idx, annot in enumerate(self.image_data[path]):
235 | annotations['labels'] = np.concatenate((annotations['labels'], [self.name_to_label(annot['class'])]))
236 | annotations['bboxes'] = np.concatenate((annotations['bboxes'], [[
237 | float(annot['x1']),
238 | float(annot['y1']),
239 | float(annot['x2']),
240 | float(annot['y2']),
241 | ]]))
242 |
243 | return annotations
244 |
245 |
246 | if __name__ == '__main__':
247 | import cv2
248 | csv_generator = CSVGenerator(
249 | csv_data_file='/home/adam/workspace/github/keras-retinanet_vat/val_gray_annotations_20190615_1255_127.csv',
250 | csv_class_file='/home/adam/workspace/github/keras-retinanet_vat/vat_classes.csv'
251 | )
252 | for image_group, annotation_group, targets in csv_generator:
253 | locations = targets[0]
254 | batch_regr_targets = targets[1]
255 | batch_cls_targets = targets[2]
256 | batch_centerness_targets = targets[3]
257 | for image, annotation, regr_targets, cls_targets, centerness_targets in zip(image_group, annotation_group, batch_regr_targets, batch_cls_targets, batch_centerness_targets):
258 | gt_boxes = annotation['bboxes']
259 | for gt_box in gt_boxes:
260 | gt_xmin, gt_ymin, gt_xmax, gt_ymax = gt_box
261 | cv2.rectangle(image, (int(gt_xmin), int(gt_ymin)), (int(gt_xmax), int(gt_ymax)), (0, 255, 0), 2)
262 | pos_indices = np.where(centerness_targets[:, 1] == 1)[0]
263 | for pos_index in pos_indices:
264 | cx, cy = locations[pos_index]
265 | l, t, r, b, _ = regr_targets[pos_index]
266 | xmin = cx - l
267 | ymin = cy - t
268 | xmax = cx + r
269 | ymax = cy + b
270 | class_id = np.argmax(cls_targets[pos_index])
271 | centerness = centerness_targets[pos_index][0]
272 |                 cv2.putText(image, '{:.2f}'.format(centerness), (round(cx), round(cy)), cv2.FONT_HERSHEY_SIMPLEX, 2.0, (255, 0, 255), 2)
273 |                 cv2.putText(image, str(class_id), (round(xmin), round(ymin)), cv2.FONT_HERSHEY_SIMPLEX, 2.0, (0, 0, 0), 3)
274 | cv2.circle(image, (round(cx), round(cy)), 5, (255, 0, 0), -1)
275 | cv2.rectangle(image, (round(xmin), round(ymin)), (round(xmax), round(ymax)), (0, 0, 255), 2)
276 | cv2.namedWindow('image', cv2.WINDOW_NORMAL)
277 | cv2.imshow('image', image)
278 | cv2.waitKey(0)
279 | pass
280 |
--------------------------------------------------------------------------------
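To make the expected CSV layout concrete, here is a small self-contained sketch (illustration only) that feeds in-memory CSV text through the _read_classes and _read_annotations helpers above; the image file names are hypothetical.

    import csv
    import io

    from generators.csv_generator import _read_classes, _read_annotations

    classes_csv = "cow,0\ncat,1\n"
    annotations_csv = (
        "img_001.jpg,96,55,324,222,cow\n"
        "img_002.jpg,22,5,89,84,cat\n"
        "img_003.jpg,,,,,\n"  # an image with no annotations (kept as a negative example)
    )

    classes = _read_classes(csv.reader(io.StringIO(classes_csv), delimiter=','))
    annotations = _read_annotations(csv.reader(io.StringIO(annotations_csv), delimiter=','), classes)
    print(classes)                     # OrderedDict([('cow', 0), ('cat', 1)])
    print(annotations['img_003.jpg'])  # [] (the image is kept without boxes)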
/yolo/generators/csv_.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 yhenon (https://github.com/yhenon/)
3 | Copyright 2017-2018 Fizyr (https://fizyr.com)
4 |
5 | Licensed under the Apache License, Version 2.0 (the "License");
6 | you may not use this file except in compliance with the License.
7 | You may obtain a copy of the License at
8 |
9 | http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | Unless required by applicable law or agreed to in writing, software
12 | distributed under the License is distributed on an "AS IS" BASIS,
13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | See the License for the specific language governing permissions and
15 | limitations under the License.
16 | """
17 | from collections import OrderedDict
18 | import csv
19 | import numpy as np
20 | import os.path
21 | from PIL import Image
22 | from six import raise_from
23 | import sys
24 |
25 | from utils.image import read_image_bgr
26 | from yolo.generators.common import Generator
27 |
28 |
29 | def _parse(value, function, fmt):
30 | """
31 | Parse a string into a value, and format a nice ValueError if it fails.
32 |
33 | Returns `function(value)`.
34 |     Any `ValueError` raised is caught and a new `ValueError` is raised
35 | with message `fmt.format(e)`, where `e` is the caught `ValueError`.
36 | """
37 | try:
38 | return function(value)
39 | except ValueError as e:
40 | raise_from(ValueError(fmt.format(e)), None)
41 |
42 |
43 | def _read_classes(csv_reader):
44 | """
45 | Parse the classes file given by csv_reader.
46 | """
47 | result = OrderedDict()
48 | for line, row in enumerate(csv_reader):
49 | line += 1
50 |
51 | try:
52 | class_name, class_id = row
53 | except ValueError:
54 | raise_from(ValueError('line {}: format should be \'class_name,class_id\''.format(line)), None)
55 | class_id = _parse(class_id, int, 'line {}: malformed class ID: {{}}'.format(line))
56 |
57 | if class_name in result:
58 | raise ValueError('line {}: duplicate class name: \'{}\''.format(line, class_name))
59 | result[class_name] = class_id
60 | return result
61 |
62 |
63 | def _read_annotations(csv_reader, classes):
64 | """
65 | Read annotations from the csv_reader.
66 | """
67 | result = OrderedDict()
68 | for line, row in enumerate(csv_reader):
69 | line += 1
70 |
71 | try:
72 | img_file, x1, y1, x2, y2, class_name = row[:6]
73 | except ValueError:
74 | raise_from(ValueError(
75 | 'line {}: format should be \'img_file,x1,y1,x2,y2,class_name\' or \'img_file,,,,,\''.format(line)),
76 | None)
77 |
78 | if img_file not in result:
79 | result[img_file] = []
80 |
81 | # If a row contains only an image path, it's an image without annotations.
82 | if (x1, y1, x2, y2, class_name) == ('', '', '', '', ''):
83 | continue
84 |
85 | x1 = _parse(x1, int, 'line {}: malformed x1: {{}}'.format(line))
86 | y1 = _parse(y1, int, 'line {}: malformed y1: {{}}'.format(line))
87 | x2 = _parse(x2, int, 'line {}: malformed x2: {{}}'.format(line))
88 | y2 = _parse(y2, int, 'line {}: malformed y2: {{}}'.format(line))
89 |
90 | # Check that the bounding box is valid.
91 | if x2 <= x1:
92 | raise ValueError('line {}: x2 ({}) must be higher than x1 ({})'.format(line, x2, x1))
93 | if y2 <= y1:
94 | raise ValueError('line {}: y2 ({}) must be higher than y1 ({})'.format(line, y2, y1))
95 |
96 | # check if the current class name is correctly present
97 | if class_name not in classes:
98 | raise ValueError('line {}: unknown class name: \'{}\' (classes: {})'.format(line, class_name, classes))
99 |
100 | result[img_file].append({'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2, 'class': class_name})
101 | return result
102 |
103 |
104 | def _open_for_csv(path):
105 | """
106 | Open a file with flags suitable for csv.reader.
107 |
108 |     For Python 2 this means opening with mode 'rb'; for Python 3 this means mode 'r' with newline='', as csv.reader expects.
109 | """
110 | if sys.version_info[0] < 3:
111 | return open(path, 'rb')
112 | else:
113 | return open(path, 'r', newline='')
114 |
115 |
116 | class CSVGenerator(Generator):
117 | """
118 | Generate data for a custom CSV dataset.
119 |
120 | See https://github.com/fizyr/keras-retinanet#csv-datasets for more information.
121 | """
122 |
123 | def __init__(
124 | self,
125 | csv_data_file,
126 | csv_class_file,
127 | base_dir=None,
128 | **kwargs
129 | ):
130 | """
131 | Initialize a CSV data generator.
132 |
133 | Args
134 | csv_data_file: Path to the CSV annotations file.
135 | csv_class_file: Path to the CSV classes file.
136 | base_dir: Directory w.r.t. where the files are to be searched (defaults to the directory containing the csv_data_file).
137 | """
138 | self.image_names = []
139 | self.image_data = {}
140 | self.base_dir = base_dir
141 |
142 | # Take base_dir from annotations file if not explicitly specified.
143 | if self.base_dir is None:
144 | self.base_dir = os.path.dirname(csv_data_file)
145 |
146 | # parse the provided class file
147 | try:
148 | with _open_for_csv(csv_class_file) as file:
149 | # class_name --> class_id
150 | self.classes = _read_classes(csv.reader(file, delimiter=','))
151 | except ValueError as e:
152 | raise_from(ValueError('invalid CSV class file: {}: {}'.format(csv_class_file, e)), None)
153 |
154 | self.labels = {}
155 | # class_id --> class_name
156 | for key, value in self.classes.items():
157 | self.labels[value] = key
158 |
159 | # csv with img_path, x1, y1, x2, y2, class_name
160 | try:
161 | with _open_for_csv(csv_data_file) as file:
162 | # {'img_path1':[{'x1':xx,'y1':xx,'x2':xx,'y2':xx,'class':xx}...],...}
163 | self.image_data = _read_annotations(csv.reader(file, delimiter=','), self.classes)
164 | except ValueError as e:
165 | raise_from(ValueError('invalid CSV annotations file: {}: {}'.format(csv_data_file, e)), None)
166 | self.image_names = list(self.image_data.keys())
167 |
168 | super(CSVGenerator, self).__init__(**kwargs)
169 |
170 | def size(self):
171 | """
172 | Size of the dataset.
173 | """
174 | return len(self.image_names)
175 |
176 | def num_classes(self):
177 | """
178 | Number of classes in the dataset.
179 | """
180 | return max(self.classes.values()) + 1
181 |
182 | def has_label(self, label):
183 | """
184 | Return True if label is a known label.
185 | """
186 | return label in self.labels
187 |
188 | def has_name(self, name):
189 | """
190 | Returns True if name is a known class.
191 | """
192 | return name in self.classes
193 |
194 | def name_to_label(self, name):
195 | """
196 | Map name to label.
197 | """
198 | return self.classes[name]
199 |
200 | def label_to_name(self, label):
201 | """
202 | Map label to name.
203 | """
204 | return self.labels[label]
205 |
206 | def image_path(self, image_index):
207 | """
208 | Returns the image path for image_index.
209 | """
210 | return os.path.join(self.base_dir, self.image_names[image_index])
211 |
212 | def image_aspect_ratio(self, image_index):
213 | """
214 | Compute the aspect ratio for an image with image_index.
215 | """
216 | # PIL is fast for metadata
217 | image = Image.open(self.image_path(image_index))
218 | return float(image.width) / float(image.height)
219 |
220 | def load_image(self, image_index):
221 | """
222 | Load an image at the image_index.
223 | """
224 | return read_image_bgr(self.image_path(image_index))
225 |
226 | def load_annotations(self, image_index):
227 | """
228 | Load annotations for an image_index.
229 | """
230 | path = self.image_names[image_index]
231 | annotations = {'labels': np.empty((0,), dtype=np.int32), 'bboxes': np.empty((0, 4))}
232 |
233 | for idx, annot in enumerate(self.image_data[path]):
234 | annotations['labels'] = np.concatenate((annotations['labels'], [self.name_to_label(annot['class'])]))
235 | annotations['bboxes'] = np.concatenate((annotations['bboxes'], [[
236 | float(annot['x1']),
237 | float(annot['y1']),
238 | float(annot['x2']),
239 | float(annot['y2']),
240 | ]]))
241 |
242 | return annotations
243 |
244 |
245 | if __name__ == '__main__':
246 | import cv2
247 |
248 | csv_generator = CSVGenerator(
249 | csv_data_file='/home/adam/workspace/github/keras-retinanet_vat/val_gray_annotations_20190615_1255_127.csv',
250 | csv_class_file='/home/adam/workspace/github/keras-retinanet_vat/vat_classes.csv'
251 | )
252 | for image_group, annotation_group, targets in csv_generator:
253 | locations = targets[0]
254 | batch_regr_targets = targets[1]
255 | batch_cls_targets = targets[2]
256 | batch_centerness_targets = targets[3]
257 | for image, annotation, regr_targets, cls_targets, centerness_targets in zip(image_group, annotation_group,
258 | batch_regr_targets,
259 | batch_cls_targets,
260 | batch_centerness_targets):
261 | gt_boxes = annotation['bboxes']
262 | for gt_box in gt_boxes:
263 | gt_xmin, gt_ymin, gt_xmax, gt_ymax = gt_box
264 | cv2.rectangle(image, (int(gt_xmin), int(gt_ymin)), (int(gt_xmax), int(gt_ymax)), (0, 255, 0), 2)
265 | pos_indices = np.where(centerness_targets[:, 1] == 1)[0]
266 | for pos_index in pos_indices:
267 | cx, cy = locations[pos_index]
268 | l, t, r, b, _ = regr_targets[pos_index]
269 | xmin = cx - l
270 | ymin = cy - t
271 | xmax = cx + r
272 | ymax = cy + b
273 | class_id = np.argmax(cls_targets[pos_index])
274 | centerness = centerness_targets[pos_index][0]
275 |                 cv2.putText(image, '{:.2f}'.format(centerness), (round(cx), round(cy)), cv2.FONT_HERSHEY_SIMPLEX, 2.0,
276 |                             (255, 0, 255), 2)
277 |                 cv2.putText(image, str(class_id), (round(xmin), round(ymin)), cv2.FONT_HERSHEY_SIMPLEX, 2.0, (0, 0, 0), 3)
278 | cv2.circle(image, (round(cx), round(cy)), 5, (255, 0, 0), -1)
279 | cv2.rectangle(image, (round(xmin), round(ymin)), (round(xmax), round(ymax)), (0, 0, 255), 2)
280 | cv2.namedWindow('image', cv2.WINDOW_NORMAL)
281 | cv2.imshow('image', image)
282 | cv2.waitKey(0)
283 | pass
284 |
--------------------------------------------------------------------------------
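The debug loops in both CSV generators above decode a positive location back into a box by offsetting the location with its (l, t, r, b) regression target. A tiny sketch of that arithmetic with invented numbers:

    import numpy as np

    location = np.array([128.0, 96.0])                 # (cx, cy) on the input image
    regr_target = np.array([30.0, 20.0, 50.0, 40.0])   # distances to the left, top, right, bottom edges

    cx, cy = location
    l, t, r, b = regr_target
    box = np.array([cx - l, cy - t, cx + r, cy + b])   # (xmin, ymin, xmax, ymax)
    print(box)  # [ 98.  76. 178. 136.]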
/utils/transform.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 Fizyr (https://fizyr.com)
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import numpy as np
18 |
19 | DEFAULT_PRNG = np.random
20 |
21 |
22 | def colvec(*args):
23 | """
24 | Create a numpy array representing a column vector.
25 | """
26 | return np.array([args]).T
27 |
28 |
29 | def transform_aabb(transform, aabb):
30 | """
31 | Apply a transformation to an axis aligned bounding box.
32 |
33 | The result is a new AABB in the same coordinate system as the original AABB.
34 | The new AABB contains all corner points of the original AABB after applying the given transformation.
35 |
36 | Args
37 |         transform: The transformation to apply.
38 |         aabb: The AABB as a tuple (x1, y1, x2, y2), where
39 |             (x1, y1) is the minimum corner and
40 |             (x2, y2) is the maximum corner of the box.
41 | 
42 | Returns
43 | The new AABB as tuple (x1, y1, x2, y2)
44 | """
45 | x1, y1, x2, y2 = aabb
46 | # Transform all 4 corners of the AABB.
47 | points = transform.dot([
48 | [x1, x2, x1, x2],
49 | [y1, y2, y2, y1],
50 | [1, 1, 1, 1],
51 | ])
52 |
53 | # Extract the min and max corners again.
54 | # (3, ) (min_x, min_y, 1)
55 | min_corner = points.min(axis=1)
56 | # (3, ) (max_x, max_y, 1)
57 | max_corner = points.max(axis=1)
58 |
59 | return [min_corner[0], min_corner[1], max_corner[0], max_corner[1]]
60 |
61 |
62 | def _random_vector(min, max, prng=DEFAULT_PRNG):
63 | """
64 | Construct a random vector between min and max.
65 |
66 | Args
67 | min: the minimum value for each component, (n, )
68 | max: the maximum value for each component, (n, )
69 | """
70 | min = np.array(min)
71 | max = np.array(max)
72 | assert min.shape == max.shape
73 | assert len(min.shape) == 1
74 | return prng.uniform(min, max)
75 |
76 |
77 | def rotation(angle):
78 | """
79 | Construct a homogeneous 2D rotation matrix.
80 |
81 | Args
82 | angle: the angle in radians
83 | Returns
84 | the rotation matrix as 3 by 3 numpy array
85 | """
86 | return np.array([
87 | [np.cos(angle), -np.sin(angle), 0],
88 | [np.sin(angle), np.cos(angle), 0],
89 | [0, 0, 1]
90 | ])
91 |
92 |
93 | def random_rotation(min, max, prng=DEFAULT_PRNG):
94 | """
95 |     Construct a random rotation between min and max.
96 |
97 | Args
98 | min: a scalar for the minimum absolute angle in radians
99 | max: a scalar for the maximum absolute angle in radians
100 | prng: the pseudo-random number generator to use.
101 | Returns
102 | a homogeneous 3 by 3 rotation matrix
103 | """
104 | return rotation(prng.uniform(min, max))
105 |
106 |
107 | def translation(translation):
108 | """
109 | Construct a homogeneous 2D translation matrix.
110 |
111 | Args:
112 | translation: the translation 2D vector
113 |
114 | Returns:
115 | the translation matrix as 3 by 3 numpy array
116 |
117 | """
118 | return np.array([
119 | [1, 0, translation[0]],
120 | [0, 1, translation[1]],
121 | [0, 0, 1]
122 | ])
123 |
124 |
125 | def random_translation(min, max, prng=DEFAULT_PRNG):
126 | """
127 | Construct a random 2D translation between min and max.
128 |
129 | Args
130 | min: a 2D vector with the minimum translation for each dimension
131 | max: a 2D vector with the maximum translation for each dimension
132 | prng: the pseudo-random number generator to use.
133 | Returns
134 | a homogeneous 3 by 3 translation matrix
135 | """
136 | return translation(_random_vector(min, max, prng))
137 |
138 |
139 | def shear(angle):
140 | """
141 | Construct a homogeneous 2D shear matrix.
142 |
143 | Args
144 | angle: the shear angle in radians
145 | Returns
146 | the shear matrix as 3 by 3 numpy array
147 | """
148 | return np.array([
149 | [1, -np.sin(angle), 0],
150 | [0, np.cos(angle), 0],
151 | [0, 0, 1]
152 | ])
153 |
154 |
155 | def random_shear(min, max, prng=DEFAULT_PRNG):
156 | """
157 |     Construct a random 2D shear matrix with shear angle between min and max.
158 |
159 | Args
160 | min: the minimum shear angle in radians.
161 | max: the maximum shear angle in radians.
162 | prng: the pseudo-random number generator to use.
163 | Returns
164 | a homogeneous 3 by 3 shear matrix
165 | """
166 | return shear(prng.uniform(min, max))
167 |
168 |
169 | def scaling(factor):
170 | """
171 | Construct a homogeneous 2D scaling matrix.
172 |
173 | Args
174 | factor: a 2D vector for X and Y scaling
175 | Returns
176 | the zoom matrix as 3 by 3 numpy array
177 | """
178 |
179 | return np.array([
180 | [factor[0], 0, 0],
181 | [0, factor[1], 0],
182 | [0, 0, 1]
183 | ])
184 |
185 |
186 | def random_scaling(min, max, prng=DEFAULT_PRNG):
187 | """
188 |     Construct a random 2D scale matrix with scaling factors between min and max.
189 |
190 | Args
191 | min: a 2D vector containing the minimum scaling factor for X and Y.
192 |         max: a 2D vector containing the maximum scaling factor for X and Y.
193 | prng: the pseudo-random number generator to use.
194 | Returns
195 | a homogeneous 3 by 3 scaling matrix
196 | """
197 | return scaling(_random_vector(min, max, prng))
198 |
199 |
200 | def random_flip(flip_x_chance, flip_y_chance, prng=DEFAULT_PRNG):
201 | """
202 | Construct a transformation randomly containing X/Y flips (or not).
203 |
204 | Args
205 | flip_x_chance: The chance that the result will contain a flip along the X axis.
206 | flip_y_chance: The chance that the result will contain a flip along the Y axis.
207 | prng: The pseudo-random number generator to use.
208 | Returns
209 | a homogeneous 3 by 3 transformation matrix
210 | """
211 | flip_x = prng.uniform(0, 1) < flip_x_chance
212 | flip_y = prng.uniform(0, 1) < flip_y_chance
213 | # 1 - 2 * bool gives 1 for False and -1 for True.
214 | return scaling((1 - 2 * flip_x, 1 - 2 * flip_y))
215 |
216 |
217 | def change_transform_origin(transform, center):
218 | """
219 | Create a new transform representing the same transformation, only with the origin of the linear part changed.
220 |
221 | Args
222 | transform: the transformation matrix
223 | center: the new origin of the transformation
224 | Returns
225 | translate(center) * transform * translate(-center)
226 | """
227 | center = np.array(center)
228 | return np.linalg.multi_dot([translation(center), transform, translation(-center)])
229 |
230 |
231 | def random_transform(
232 | min_rotation=0,
233 | max_rotation=0,
234 | min_translation=(0, 0),
235 | max_translation=(0, 0),
236 | min_shear=0,
237 | max_shear=0,
238 | min_scaling=(1, 1),
239 | max_scaling=(1, 1),
240 | flip_x_chance=0,
241 | flip_y_chance=0,
242 | prng=DEFAULT_PRNG
243 | ):
244 | """
245 | Create a random transformation.
246 |
247 | The transformation consists of the following operations in this order (from left to right):
248 | * rotation
249 | * translation
250 | * shear
251 | * scaling
252 | * flip x (if applied)
253 | * flip y (if applied)
254 |
255 | Note that by default, the data generators in `keras_retinanet.preprocessing.generators` interpret the translation
256 |     as a factor of the image size, so an X translation of 0.1 would translate the image by 10% of its width.
257 | Set `relative_translation` to `False` in the `TransformParameters` of a data generator to have it interpret
258 | the translation directly as pixel distances instead.
259 |
260 | Args
261 | min_rotation: The minimum rotation in radians for the transform as scalar.
262 | max_rotation: The maximum rotation in radians for the transform as scalar.
263 | min_translation: The minimum translation for the transform as 2D column vector.
264 | max_translation: The maximum translation for the transform as 2D column vector.
265 | min_shear: The minimum shear angle for the transform in radians.
266 | max_shear: The maximum shear angle for the transform in radians.
267 | min_scaling: The minimum scaling for the transform as 2D column vector.
268 | max_scaling: The maximum scaling for the transform as 2D column vector.
269 | flip_x_chance: The chance (0 to 1) that a transform will contain a flip along X direction.
270 | flip_y_chance: The chance (0 to 1) that a transform will contain a flip along Y direction.
271 | prng: The pseudo-random number generator to use.
272 | """
273 | return np.linalg.multi_dot([
274 | random_rotation(min_rotation, max_rotation, prng),
275 | random_translation(min_translation, max_translation, prng),
276 | random_shear(min_shear, max_shear, prng),
277 | random_scaling(min_scaling, max_scaling, prng),
278 | random_flip(flip_x_chance, flip_y_chance, prng)
279 | ])
280 |
281 |
282 | def random_transform_generator(prng=None, **kwargs):
283 | """
284 | Create a random transform generator.
285 | Uses a dedicated, newly created, properly seeded PRNG by default instead of the global DEFAULT_PRNG.
286 |
287 | The transformation consists of the following operations in this order (from left to right):
288 | * rotation
289 | * translation
290 | * shear
291 | * scaling
292 | * flip x (if applied)
293 | * flip y (if applied)
294 |
295 | Note that by default, the data generators in `keras_retinanet.preprocessing.generators` interpret the translation
296 | as a factor of the image size. So an X translation of 0.1 would translate the image by 10% of its width.
297 | Set `relative_translation` to `False` in the `TransformParameters` of a data generator to have it interpret
298 | the translation directly as pixel distances instead.
299 |
300 | Args
301 | min_rotation: The minimum rotation in radians for the transform as scalar.
302 | max_rotation: The maximum rotation in radians for the transform as scalar.
303 | min_translation: The minimum translation for the transform as 2D column vector.
304 | max_translation: The maximum translation for the transform as 2D column vector.
305 | min_shear: The minimum shear angle for the transform in radians.
306 | max_shear: The maximum shear angle for the transform in radians.
307 | min_scaling: The minimum scaling for the transform as 2D column vector.
308 | max_scaling: The maximum scaling for the transform as 2D column vector.
309 | flip_x_chance: The chance (0 to 1) that a transform will contain a flip along X direction.
310 | flip_y_chance: The chance (0 to 1) that a transform will contain a flip along Y direction.
311 | prng: The pseudo-random number generator to use.
312 | """
313 |
314 | if prng is None:
315 | # RandomState automatically seeds using the best available method.
316 | prng = np.random.RandomState()
317 |
318 | while True:
319 | yield random_transform(prng=prng, **kwargs)
320 |
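As a quick sanity check of the two helpers above, the sketch below (not part of the repository; it assumes the repository root is on PYTHONPATH so the file imports as `utils.transform`) draws one random transform and re-centers it on a hypothetical 800x600 image:

    from utils.transform import random_transform_generator, change_transform_origin

    # sample mild rotations/translations/scalings and a 50% chance of a horizontal flip
    transform_generator = random_transform_generator(
        min_rotation=-0.1, max_rotation=0.1,
        min_translation=(-0.1, -0.1), max_translation=(0.1, 0.1),
        min_scaling=(0.9, 0.9), max_scaling=(1.1, 1.1),
        flip_x_chance=0.5,
    )

    matrix = next(transform_generator)                      # homogeneous 3x3 matrix
    centered = change_transform_origin(matrix, (400, 300))  # same transform about the center of an 800x600 image
    print(matrix.shape, centered.shape)                     # (3, 3) (3, 3)
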
--------------------------------------------------------------------------------
/yolo/eval/common.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 Fizyr (https://fizyr.com)
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import numpy as np
18 | import cv2
19 | import progressbar
20 | assert (callable(progressbar.progressbar)), "Using the wrong progressbar module; install 'progressbar2' instead."
21 |
22 | from utils.compute_overlap import compute_overlap
23 | from utils.visualization import draw_detections, draw_annotations
24 |
25 |
26 | def _compute_ap(recall, precision):
27 | """
28 | Compute the average precision, given the recall and precision curves.
29 |
30 | Code originally from https://github.com/rbgirshick/py-faster-rcnn.
31 |
32 | Args:
33 | recall: The recall curve (list).
34 | precision: The precision curve (list).
35 |
36 | Returns:
37 | The average precision as computed in py-faster-rcnn.
38 |
39 | """
40 | # correct AP calculation
41 | # first append sentinel values at the end
42 | mrec = np.concatenate(([0.], recall, [1.]))
43 | mpre = np.concatenate(([0.], precision, [0.]))
44 |
45 | # compute the precision envelope
46 | for i in range(mpre.size - 1, 0, -1):
47 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
48 |
49 | # to calculate area under PR curve, look for points
50 | # where X axis (recall) changes value
51 | i = np.where(mrec[1:] != mrec[:-1])[0]
52 |
53 | # and sum (delta recall) * prec
54 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
55 | return ap
56 |
57 |
58 | def _get_detections(generator, model, score_threshold=0.01, max_detections=100, visualize=False):
59 | """
60 | Get the detections from the model using the generator.
61 |
62 | The result is a list of lists such that the size is:
63 | all_detections[num_images][num_classes] = detections[num_class_detections, 5]
64 |
65 | Args:
66 | generator: The generator used to run images through the model.
67 | model: The model to run on the images.
68 | score_threshold: The score confidence threshold to use.
69 | max_detections: The maximum number of detections to use per image.
70 | visualize: Show the visualized detections or not.
71 |
72 | Returns:
73 | A list of lists containing the detections for each image in the generator.
74 |
75 | """
76 | all_detections = [[None for i in range(generator.num_classes()) if generator.has_label(i)] for j in range(generator.size())]
77 |
78 | for i in progressbar.progressbar(range(generator.size()), prefix='Running network: '):
79 | image = generator.load_image(i)
80 | src_image = image.copy()
81 | image, scale, offset_h, offset_w = generator.preprocess_image(image)
82 |
83 | # run network
84 | # the outputs keep a batch dimension of 1, hence the [0, ...] indexing below
85 | boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))[:3]
86 |
87 | # correct boxes for image scale
88 | boxes[0, :, [0, 2]] -= offset_w
89 | boxes[0, :, [1, 3]] -= offset_h
90 | boxes /= scale
91 |
92 | # select indices which have a score above the threshold
93 | indices = np.where(scores[0, :] > score_threshold)[0]
94 |
95 | # select those scores
96 | scores = scores[0][indices]
97 |
98 | # find the order with which to sort the scores
99 | scores_sort = np.argsort(-scores)[:max_detections]
100 |
101 | # select detections
102 | # (n, 4)
103 | boxes = boxes[0, indices[scores_sort], :]
104 | # (n, )
105 | scores = scores[scores_sort]
106 | # (n, )
107 | labels = labels[0, indices[scores_sort]]
108 | # (n, 6)
109 | detections = np.concatenate(
110 | [boxes, np.expand_dims(scores, axis=1), np.expand_dims(labels, axis=1)], axis=1)
111 |
112 | if visualize:
113 | # draw_annotations(src_image, generator.load_annotations(i), label_to_name=generator.label_to_name)
114 | draw_detections(src_image, boxes[:5], scores[:5], labels[:5],
115 | label_to_name=generator.label_to_name,
116 | score_threshold=score_threshold)
117 |
118 | # cv2.imwrite(os.path.join(save_path, '{}.png'.format(i)), raw_image)
119 | cv2.namedWindow('{}'.format(i), cv2.WINDOW_NORMAL)
120 | cv2.imshow('{}'.format(i), src_image)
121 | cv2.waitKey(0)
122 |
123 | # copy detections to all_detections
124 | for class_id in range(generator.num_classes()):
125 | all_detections[i][class_id] = detections[detections[:, -1] == class_id, :-1]
126 |
127 | return all_detections
128 |
129 |
130 | def _get_annotations(generator):
131 | """
132 | Get the ground truth annotations from the generator.
133 |
134 | The result is a list of lists such that the size is:
135 | all_annotations[num_images][num_classes] = annotations[num_class_annotations, 5]
136 |
137 | Args:
138 | generator: The generator used to retrieve ground truth annotations.
139 |
140 | Returns:
141 | A list of lists containing the annotations for each image in the generator.
142 |
143 | """
144 | all_annotations = [[None for i in range(generator.num_classes())] for j in range(generator.size())]
145 |
146 | for i in progressbar.progressbar(range(generator.size()), prefix='Parsing annotations: '):
147 | # load the annotations
148 | annotations = generator.load_annotations(i)
149 |
150 | # copy detections to all_annotations
151 | for label in range(generator.num_classes()):
152 | if not generator.has_label(label):
153 | continue
154 |
155 | all_annotations[i][label] = annotations['bboxes'][annotations['labels'] == label, :].copy()
156 |
157 | return all_annotations
158 |
159 |
160 | def evaluate(
161 | generator,
162 | model,
163 | iou_threshold=0.5,
164 | score_threshold=0.01,
165 | max_detections=100,
166 | visualize=False,
167 | epoch=0
168 | ):
169 | """
170 | Evaluate a given dataset using a given model.
171 |
172 | Args:
173 | generator: The generator that represents the dataset to evaluate.
174 | model: The model to evaluate.
175 | iou_threshold: The threshold used to consider when a detection is positive or negative.
176 | score_threshold: The score confidence threshold to use for detections.
177 | max_detections: The maximum number of detections to use per image.
178 | visualize: Show the visualized detections or not.
179 |
180 | Returns:
181 | A dict mapping class labels to (average precision, number of annotations) tuples.
182 |
183 | """
184 | # gather all detections and annotations
185 | all_detections = _get_detections(generator, model, score_threshold=score_threshold, max_detections=max_detections,
186 | visualize=visualize)
187 | all_annotations = _get_annotations(generator)
188 | average_precisions = {}
189 |
190 | # all_detections = pickle.load(open('all_detections_{}.pkl'.format(epoch + 1), 'rb'))
191 | # all_annotations = pickle.load(open('all_annotations_{}.pkl'.format(epoch + 1), 'rb'))
192 | # pickle.dump(all_detections, open('all_detections_{}.pkl'.format(epoch + 1), 'wb'))
193 | # pickle.dump(all_annotations, open('all_annotations_{}.pkl'.format(epoch + 1), 'wb'))
194 |
195 | # process detections and annotations
196 | for label in range(generator.num_classes()):
197 | if not generator.has_label(label):
198 | continue
199 |
200 | false_positives = np.zeros((0,))
201 | true_positives = np.zeros((0,))
202 | scores = np.zeros((0,))
203 | num_annotations = 0.0
204 |
205 | for i in range(generator.size()):
206 | detections = all_detections[i][label]
207 | annotations = all_annotations[i][label]
208 | num_annotations += annotations.shape[0]
209 | detected_annotations = []
210 |
211 | for d in detections:
212 | scores = np.append(scores, d[4])
213 |
214 | if annotations.shape[0] == 0:
215 | false_positives = np.append(false_positives, 1)
216 | true_positives = np.append(true_positives, 0)
217 | continue
218 | overlaps = compute_overlap(np.expand_dims(d, axis=0), annotations)
219 | assigned_annotation = np.argmax(overlaps, axis=1)
220 | max_overlap = overlaps[0, assigned_annotation]
221 |
222 | if max_overlap >= iou_threshold and assigned_annotation not in detected_annotations:
223 | false_positives = np.append(false_positives, 0)
224 | true_positives = np.append(true_positives, 1)
225 | detected_annotations.append(assigned_annotation)
226 | else:
227 | false_positives = np.append(false_positives, 1)
228 | true_positives = np.append(true_positives, 0)
229 |
230 | # no annotations for this class -> report AP as 0
231 | if num_annotations == 0:
232 | average_precisions[label] = 0, 0
233 | continue
234 |
235 | # sort by score
236 | indices = np.argsort(-scores)
237 | false_positives = false_positives[indices]
238 | true_positives = true_positives[indices]
239 |
240 | # compute false positives and true positives
241 | false_positives = np.cumsum(false_positives)
242 | true_positives = np.cumsum(true_positives)
243 |
244 | # compute recall and precision
245 | recall = true_positives / num_annotations
246 | precision = true_positives / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps)
247 |
248 | # compute average precision
249 | average_precision = _compute_ap(recall, precision)
250 | average_precisions[label] = average_precision, num_annotations
251 |
252 | return average_precisions
253 |
254 |
255 | if __name__ == '__main__':
256 | from yolo.generators.pascal import PascalVocGenerator
257 | from yolo.model import yolo_body
258 | import os
259 |
260 | os.environ['CUDA_VISIBLE_DEVICES'] = '1'
261 | common_args = {
262 | 'batch_size': 1,
263 | 'image_size': 416
264 | }
265 | test_generator = PascalVocGenerator(
266 | 'datasets/voc_test/VOC2007',
267 | 'test',
268 | shuffle_groups=False,
269 | skip_truncated=False,
270 | skip_difficult=True,
271 | anchors_path='voc_anchors_416.txt',
272 | **common_args
273 | )
274 | model_path = 'pascal_18_6.4112_6.5125_0.8319_0.8358.h5'
275 | num_classes = test_generator.num_classes()
276 | model, prediction_model = yolo_body(num_classes=num_classes)
277 | prediction_model.load_weights(model_path, by_name=True, skip_mismatch=True)
278 | average_precisions = evaluate(test_generator, prediction_model, visualize=False)
279 | # compute per class average precision
280 | total_instances = []
281 | precisions = []
282 | for label, (average_precision, num_annotations) in average_precisions.items():
283 | print('{:.0f} instances of class'.format(num_annotations), test_generator.label_to_name(label),
284 | 'with average precision: {:.4f}'.format(average_precision))
285 | total_instances.append(num_annotations)
286 | precisions.append(average_precision)
287 | mean_ap = sum(precisions) / sum(x > 0 for x in total_instances)
288 | print('mAP: {:.4f}'.format(mean_ap))
289 |
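The AP computation in `_compute_ap` is easiest to see on a toy curve. The sketch below (not part of the repository; it assumes the repository root is on PYTHONPATH) models two ground-truth boxes and three score-sorted detections of which the first and third are true positives, so the cumulative recall is [0.5, 0.5, 1.0] and the cumulative precision is [1.0, 0.5, 2/3]; the precision envelope then gives AP = 0.5 * 1.0 + 0.5 * 2/3 ≈ 0.8333:

    import numpy as np
    from yolo.eval.common import _compute_ap

    recall = np.array([0.5, 0.5, 1.0])
    precision = np.array([1.0, 0.5, 2.0 / 3.0])

    print(_compute_ap(recall, precision))  # ~0.8333
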
--------------------------------------------------------------------------------
/utils/eval.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 Fizyr (https://fizyr.com)
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from utils.compute_overlap import compute_overlap
18 | from utils.visualization import draw_detections, draw_annotations
19 |
20 | import keras
21 | import numpy as np
22 | import os
23 | import cv2
24 | import progressbar
25 | import pickle
26 |
27 | assert (callable(progressbar.progressbar)), "Using the wrong progressbar module; install 'progressbar2' instead."
28 |
29 |
30 | def _compute_ap(recall, precision):
31 | """
32 | Compute the average precision, given the recall and precision curves.
33 |
34 | Code originally from https://github.com/rbgirshick/py-faster-rcnn.
35 |
36 | Args:
37 | recall: The recall curve (list).
38 | precision: The precision curve (list).
39 |
40 | Returns:
41 | The average precision as computed in py-faster-rcnn.
42 |
43 | """
44 | # correct AP calculation
45 | # first append sentinel values at the end
46 | mrec = np.concatenate(([0.], recall, [1.]))
47 | mpre = np.concatenate(([0.], precision, [0.]))
48 |
49 | # compute the precision envelope
50 | for i in range(mpre.size - 1, 0, -1):
51 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
52 |
53 | # to calculate area under PR curve, look for points
54 | # where X axis (recall) changes value
55 | i = np.where(mrec[1:] != mrec[:-1])[0]
56 |
57 | # and sum (delta recall) * prec
58 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
59 | return ap
60 |
61 |
62 | def _get_detections(generator, model, score_threshold=0.05, max_detections=100, visualize=False):
63 | """
64 | Get the detections from the model using the generator.
65 |
66 | The result is a list of lists such that the size is:
67 | all_detections[num_images][num_classes] = detections[num_class_detections, 5]
68 |
69 | Args:
70 | generator: The generator used to run images through the model.
71 | model: The model to run on the images.
72 | score_threshold: The score confidence threshold to use.
73 | max_detections: The maximum number of detections to use per image.
74 | visualize: Show the visualized detections or not.
75 |
76 | Returns:
77 | A list of lists containing the detections for each image in the generator.
78 |
79 | """
80 | all_detections = [[None for i in range(generator.num_classes()) if generator.has_label(i)] for j in
81 | range(generator.size())]
82 |
83 | for i in progressbar.progressbar(range(generator.size()), prefix='Running network: '):
84 | raw_image = generator.load_image(i)
85 | image = generator.preprocess_image(raw_image.copy())
86 | image, scale = generator.resize_image(image)
87 |
88 | if keras.backend.image_data_format() == 'channels_first':
89 | image = image.transpose((2, 0, 1))
90 |
91 | # run network
92 | boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))[:3]
93 |
94 | # correct boxes for image scale
95 | boxes /= scale
96 |
97 | # select indices which have a score above the threshold
98 | indices = np.where(scores[0, :] > score_threshold)[0]
99 |
100 | # select those scores
101 | scores = scores[0][indices]
102 |
103 | # find the order with which to sort the scores
104 | scores_sort = np.argsort(-scores)[:max_detections]
105 |
106 | # select detections
107 | # (n, 4)
108 | image_boxes = boxes[0, indices[scores_sort], :]
109 | # (n, )
110 | image_scores = scores[scores_sort]
111 | # (n, )
112 | image_labels = labels[0, indices[scores_sort]]
113 | # (n, 6)
114 | image_detections = np.concatenate(
115 | [image_boxes, np.expand_dims(image_scores, axis=1), np.expand_dims(image_labels, axis=1)], axis=1)
116 |
117 | if visualize:
118 | draw_annotations(raw_image, generator.load_annotations(i), label_to_name=generator.label_to_name)
119 | draw_detections(raw_image, image_boxes[:5], image_scores[:5], image_labels[:5], label_to_name=generator.label_to_name,
120 | score_threshold=score_threshold)
121 |
122 | # cv2.imwrite(os.path.join(save_path, '{}.png'.format(i)), raw_image)
123 | cv2.namedWindow('{}'.format(i), cv2.WINDOW_NORMAL)
124 | cv2.imshow('{}'.format(i), raw_image)
125 | cv2.waitKey(0)
126 |
127 | # copy detections to all_detections
128 | for label in range(generator.num_classes()):
129 | if not generator.has_label(label):
130 | continue
131 |
132 | all_detections[i][label] = image_detections[image_detections[:, -1] == label, :-1]
133 |
134 | return all_detections
135 |
136 |
137 | def _get_annotations(generator):
138 | """
139 | Get the ground truth annotations from the generator.
140 |
141 | The result is a list of lists such that the size is:
142 | all_annotations[num_images][num_classes] = annotations[num_class_annotations, 5]
143 |
144 | Args:
145 | generator: The generator used to retrieve ground truth annotations.
146 |
147 | Returns:
148 | A list of lists containing the annotations for each image in the generator.
149 |
150 | """
151 | all_annotations = [[None for i in range(generator.num_classes())] for j in range(generator.size())]
152 |
153 | for i in progressbar.progressbar(range(generator.size()), prefix='Parsing annotations: '):
154 | # load the annotations
155 | annotations = generator.load_annotations(i)
156 |
157 | # copy detections to all_annotations
158 | for label in range(generator.num_classes()):
159 | if not generator.has_label(label):
160 | continue
161 |
162 | all_annotations[i][label] = annotations['bboxes'][annotations['labels'] == label, :].copy()
163 |
164 | return all_annotations
165 |
166 |
167 | def evaluate(
168 | generator,
169 | model,
170 | iou_threshold=0.5,
171 | score_threshold=0.05,
172 | max_detections=100,
173 | visualize=False,
174 | epoch=0
175 | ):
176 | """
177 | Evaluate a given dataset using a given model.
178 |
179 | Args:
180 | generator: The generator that represents the dataset to evaluate.
181 | model: The model to evaluate.
182 | iou_threshold: The threshold used to consider when a detection is positive or negative.
183 | score_threshold: The score confidence threshold to use for detections.
184 | max_detections: The maximum number of detections to use per image.
185 | visualize: Show the visualized detections or not.
186 |
187 | Returns:
188 | A dict mapping class labels to (average precision, number of annotations) tuples.
189 |
190 | """
191 | # gather all detections and annotations
192 | all_detections = _get_detections(generator, model, score_threshold=score_threshold, max_detections=max_detections,
193 | visualize=visualize)
194 | all_annotations = _get_annotations(generator)
195 | average_precisions = {}
196 |
197 | # all_detections = pickle.load(open('all_detections_{}.pkl'.format(epoch + 1), 'rb'))
198 | # all_annotations = pickle.load(open('all_annotations_{}.pkl'.format(epoch + 1), 'rb'))
199 | # pickle.dump(all_detections, open('all_detections_{}.pkl'.format(epoch + 1), 'wb'))
200 | # pickle.dump(all_annotations, open('all_annotations_{}.pkl'.format(epoch + 1), 'wb'))
201 |
202 | # process detections and annotations
203 | for label in range(generator.num_classes()):
204 | if not generator.has_label(label):
205 | continue
206 |
207 | false_positives = np.zeros((0,))
208 | true_positives = np.zeros((0,))
209 | scores = np.zeros((0,))
210 | num_annotations = 0.0
211 |
212 | for i in range(generator.size()):
213 | detections = all_detections[i][label]
214 | annotations = all_annotations[i][label]
215 | num_annotations += annotations.shape[0]
216 | detected_annotations = []
217 |
218 | for d in detections:
219 | scores = np.append(scores, d[4])
220 |
221 | if annotations.shape[0] == 0:
222 | false_positives = np.append(false_positives, 1)
223 | true_positives = np.append(true_positives, 0)
224 | continue
225 | overlaps = compute_overlap(np.expand_dims(d, axis=0), annotations)
226 | assigned_annotation = np.argmax(overlaps, axis=1)
227 | max_overlap = overlaps[0, assigned_annotation]
228 |
229 | if max_overlap >= iou_threshold and assigned_annotation not in detected_annotations:
230 | false_positives = np.append(false_positives, 0)
231 | true_positives = np.append(true_positives, 1)
232 | detected_annotations.append(assigned_annotation)
233 | else:
234 | false_positives = np.append(false_positives, 1)
235 | true_positives = np.append(true_positives, 0)
236 |
237 | # no annotations for this class -> report AP as 0
238 | if num_annotations == 0:
239 | average_precisions[label] = 0, 0
240 | continue
241 |
242 | # sort by score
243 | indices = np.argsort(-scores)
244 | false_positives = false_positives[indices]
245 | true_positives = true_positives[indices]
246 |
247 | # compute false positives and true positives
248 | false_positives = np.cumsum(false_positives)
249 | true_positives = np.cumsum(true_positives)
250 |
251 | # compute recall and precision
252 | recall = true_positives / num_annotations
253 | precision = true_positives / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps)
254 |
255 | # compute average precision
256 | average_precision = _compute_ap(recall, precision)
257 | average_precisions[label] = average_precision, num_annotations
258 |
259 | return average_precisions
260 |
261 |
262 | if __name__ == '__main__':
263 | from generators.voc_generator import PascalVocGenerator
264 | from utils.image import preprocess_image
265 | import models
266 | import os
267 |
268 | os.environ['CUDA_VISIBLE_DEVICES'] = '1'
269 | common_args = {
270 | 'batch_size': 1,
271 | 'image_min_side': 800,
272 | 'image_max_side': 1333,
273 | 'preprocess_image': preprocess_image,
274 | }
275 | # generator = PascalVocGenerator(
276 | # 'datasets/voc_trainval/VOC0712',
277 | # 'val',
278 | # shuffle_groups=False,
279 | # skip_truncated=False,
280 | # skip_difficult=True,
281 | # **common_args
282 | # )
283 | generator = PascalVocGenerator(
284 | 'datasets/voc_test/VOC2007',
285 | 'test',
286 | shuffle_groups=False,
287 | skip_truncated=False,
288 | skip_difficult=True,
289 | **common_args
290 | )
291 | model_path = '/home/adam/workspace/github/xuannianz/carrot/fsaf/snapshots/2019-10-05/resnet101_pascal_47_0.7652.h5'
292 | # load retinanet model
293 | # import keras.backend as K
294 | # K.set_learning_phase(1)
295 | from models.resnet import resnet_fsaf
296 | from models.retinanet import fsaf_bbox
297 | fsaf = resnet_fsaf(num_classes=20, backbone='resnet101')
298 | model = fsaf_bbox(fsaf)
299 | model.load_weights(model_path, by_name=True)
300 | average_precisions = evaluate(generator, model, visualize=False)
301 | # compute per class average precision
302 | total_instances = []
303 | precisions = []
304 | for label, (average_precision, num_annotations) in average_precisions.items():
305 | print('{:.0f} instances of class'.format(num_annotations), generator.label_to_name(label),
306 | 'with average precision: {:.4f}'.format(average_precision))
307 | total_instances.append(num_annotations)
308 | precisions.append(average_precision)
309 | mean_ap = sum(precisions) / sum(x > 0 for x in total_instances)
310 | print('mAP: {:.4f}'.format(mean_ap))
311 |
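Inside `evaluate`, each detection is greedily matched to the ground-truth box of its class with the highest IoU. A minimal sketch of that matching step (not part of the repository; it assumes the compiled `utils.compute_overlap` extension has been built and that the repository root is on PYTHONPATH):

    import numpy as np
    from utils.compute_overlap import compute_overlap

    # one detection box and two ground-truth boxes, as float64 [x1, y1, x2, y2]
    detection = np.array([[10.0, 10.0, 50.0, 50.0]])
    annotations = np.array([[12.0, 12.0, 48.0, 48.0],
                            [100.0, 100.0, 150.0, 150.0]])

    overlaps = compute_overlap(detection, annotations)  # IoU matrix of shape (1, 2)
    assigned_annotation = np.argmax(overlaps, axis=1)   # index of the best-matching ground-truth box
    max_overlap = overlaps[0, assigned_annotation]
    print(assigned_annotation, max_overlap, max_overlap >= 0.5)
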
--------------------------------------------------------------------------------
/utils/image.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2017-2018 Fizyr (https://fizyr.com)
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from __future__ import division
18 | import numpy as np
19 | import cv2
20 | from PIL import Image
21 |
22 | from .transform import change_transform_origin
23 |
24 |
25 | def read_image_bgr(path):
26 | """
27 | Read an image in BGR format.
28 |
29 | Args
30 | path: Path to the image.
31 | """
32 | # We deliberately don't use cv2.imread here, since it gives no feedback on errors while reading the image.
33 | image = np.asarray(Image.open(path).convert('RGB'))
34 | return image[:, :, ::-1].copy()
35 |
36 |
37 | def preprocess_image(x, mode='caffe'):
38 | """
39 | Preprocess an image by subtracting the ImageNet mean.
40 |
41 | Args
42 | x: np.array of shape (None, None, 3) or (3, None, None).
43 | mode: One of "caffe" or "tf".
44 | - caffe: will zero-center each color channel with
45 | respect to the ImageNet dataset, without scaling.
46 | - tf: will scale pixels between -1 and 1, sample-wise.
47 |
48 | Returns
49 | The input with the ImageNet mean subtracted.
50 | """
51 | # mostly identical to "https://github.com/keras-team/keras-applications/blob/master/keras_applications/imagenet_utils.py"
52 | # except for converting RGB -> BGR since we assume BGR already
53 |
54 | # always convert to float32 to keep compatibility with OpenCV
55 | x = x.astype(np.float32)
56 |
57 | if mode == 'tf':
58 | x /= 127.5
59 | x -= 1.
60 | elif mode == 'caffe':
61 | x[..., 0] -= 103.939
62 | x[..., 1] -= 116.779
63 | x[..., 2] -= 123.68
64 |
65 | return x
66 |
67 |
68 | def adjust_transform_for_image(transform, image, relative_translation):
69 | """
70 | Adjust a transformation for a specific image.
71 |
72 | The translation of the matrix will be scaled with the size of the image.
73 | The linear part of the transformation will be adjusted so that the origin of the transformation is at the center of the image.
74 | """
75 | height, width, channels = image.shape
76 |
77 | result = transform
78 |
79 | # Scale the translation with the image size if specified.
80 | if relative_translation:
81 | result[0:2, 2] *= [width, height]
82 |
83 | # Move the origin of transformation.
84 | result = change_transform_origin(result, (0.5 * width, 0.5 * height))
85 |
86 | return result
87 |
88 |
89 | class TransformParameters:
90 | """
91 | Struct holding parameters determining how to apply a transformation to an image.
92 |
93 | Args
94 | fill_mode: One of: 'constant', 'nearest', 'reflect', 'wrap'
95 | interpolation: One of: 'nearest', 'linear', 'cubic', 'area', 'lanczos4'
96 | cval: Fill value to use with fill_mode='constant'
97 | relative_translation: If true (the default), interpret translation as a factor of the image size.
98 | If false, interpret it as absolute pixels.
99 | """
100 |
101 | def __init__(
102 | self,
103 | fill_mode='nearest',
104 | interpolation='linear',
105 | cval=0,
106 | relative_translation=True,
107 | ):
108 | self.fill_mode = fill_mode
109 | self.cval = cval
110 | self.interpolation = interpolation
111 | self.relative_translation = relative_translation
112 |
113 | def cvBorderMode(self):
114 | if self.fill_mode == 'constant':
115 | return cv2.BORDER_CONSTANT
116 | if self.fill_mode == 'nearest':
117 | return cv2.BORDER_REPLICATE
118 | if self.fill_mode == 'reflect':
119 | return cv2.BORDER_REFLECT_101
120 | if self.fill_mode == 'wrap':
121 | return cv2.BORDER_WRAP
122 |
123 | def cvInterpolation(self):
124 | if self.interpolation == 'nearest':
125 | return cv2.INTER_NEAREST
126 | if self.interpolation == 'linear':
127 | return cv2.INTER_LINEAR
128 | if self.interpolation == 'cubic':
129 | return cv2.INTER_CUBIC
130 | if self.interpolation == 'area':
131 | return cv2.INTER_AREA
132 | if self.interpolation == 'lanczos4':
133 | return cv2.INTER_LANCZOS4
134 |
135 |
136 | def apply_transform(matrix, image, params):
137 | """
138 | Apply a transformation to an image.
139 |
140 | The origin of transformation is at the top left corner of the image.
141 |
142 | The matrix is interpreted such that a point (x, y) on the original image is moved to transform * (x, y) in the generated image.
143 | Mathematically speaking, that means that the matrix maps points from the original image space to the transformed image space.
144 |
145 | Args
146 | matrix: A homogeneous 3 by 3 matrix representing the transformation to apply.
147 | image: The image to transform.
148 | params: The transform parameters (see TransformParameters)
149 | """
150 | output = cv2.warpAffine(
151 | image,
152 | matrix[:2, :],
153 | dsize=(image.shape[1], image.shape[0]),
154 | flags=params.cvInterpolation(),
155 | borderMode=params.cvBorderMode(),
156 | borderValue=params.cval,
157 | )
158 | return output
159 |
160 |
161 | def compute_resize_scale(image_shape, min_side=800, max_side=1333):
162 | """
163 | Compute an image scale such that the image size is constrained to min_side and max_side.
164 |
165 | Args
166 | min_side: The image's min side will be equal to min_side after resizing.
167 | max_side: If after resizing the image's max side is above max_side, resize until the max side is equal to max_side.
168 |
169 | Returns
170 | A resizing scale.
171 | """
172 | (rows, cols, _) = image_shape
173 |
174 | smallest_side = min(rows, cols)
175 |
176 | # rescale the image so the smallest side is min_side
177 | scale = min_side / smallest_side
178 |
179 | # check if the largest side is now greater than max_side, which can happen
180 | # when images have a large aspect ratio
181 | largest_side = max(rows, cols)
182 | if largest_side * scale > max_side:
183 | scale = max_side / largest_side
184 |
185 | return scale
186 |
187 |
188 | def resize_image(img, min_side=800, max_side=1333):
189 | """
190 | Resize an image such that the size is constrained to min_side and max_side.
191 |
192 | Args
193 | min_side: The image's min side will be equal to min_side after resizing.
194 | max_side: If after resizing the image's max side is above max_side, resize until the max side is equal to max_side.
195 |
196 | Returns
197 | A tuple of the resized image and the scale that was applied.
198 | """
199 | # compute scale to resize the image
200 | scale = compute_resize_scale(img.shape, min_side=min_side, max_side=max_side)
201 |
202 | # resize the image with the computed scale
203 | img = cv2.resize(img, None, fx=scale, fy=scale)
204 |
205 | return img, scale
206 |
207 |
208 | def _uniform(val_range):
209 | """
210 | Uniformly sample from the given range.
211 |
212 | Args
213 | val_range: A pair of lower and upper bound.
214 | """
215 | return np.random.uniform(val_range[0], val_range[1])
216 |
217 |
218 | def _check_range(val_range, min_val=None, max_val=None):
219 | """
220 | Check whether the range is a valid range.
221 |
222 | Args
223 | val_range: A pair of lower and upper bound.
224 | min_val: Minimal value for the lower bound.
225 | max_val: Maximal value for the upper bound.
226 | """
227 | if val_range[0] > val_range[1]:
228 | raise ValueError('interval lower bound > upper bound')
229 | if min_val is not None and val_range[0] < min_val:
230 | raise ValueError('invalid interval lower bound')
231 | if max_val is not None and val_range[1] > max_val:
232 | raise ValueError('invalid interval upper bound')
233 |
234 |
235 | def _clip(image):
236 | """
237 | Clip and convert an image to np.uint8.
238 |
239 | Args
240 | image: Image to clip.
241 | """
242 | return np.clip(image, 0, 255).astype(np.uint8)
243 |
244 |
245 | class VisualEffect:
246 | """
247 | Struct holding parameters and applying image color transformation.
248 |
249 | Args
250 | contrast_factor: A factor for adjusting contrast. Should be between 0 and 3.
251 | brightness_delta: Brightness offset between -1 and 1 added to the pixel values.
252 | hue_delta: Hue offset between -1 and 1 added to the hue channel.
253 | saturation_factor: A factor multiplying the saturation values of each pixel.
254 | """
255 |
256 | def __init__(
257 | self,
258 | contrast_factor,
259 | brightness_delta,
260 | hue_delta,
261 | saturation_factor,
262 | ):
263 | self.contrast_factor = contrast_factor
264 | self.brightness_delta = brightness_delta
265 | self.hue_delta = hue_delta
266 | self.saturation_factor = saturation_factor
267 |
268 | def __call__(self, image):
269 | """
270 | Apply a visual effect on the image.
271 |
272 | Args
273 | image: Image to adjust
274 | """
275 |
276 | if self.contrast_factor:
277 | image = adjust_contrast(image, self.contrast_factor)
278 | if self.brightness_delta:
279 | image = adjust_brightness(image, self.brightness_delta)
280 |
281 | if self.hue_delta or self.saturation_factor:
282 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
283 | if self.hue_delta:
284 | image = adjust_hue(image, self.hue_delta)
285 | if self.saturation_factor:
286 | image = adjust_saturation(image, self.saturation_factor)
287 |
288 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
289 |
290 | return image
291 |
292 |
293 | def random_visual_effect_generator(
294 | contrast_range=(0.9, 1.1),
295 | brightness_range=(-.1, .1),
296 | hue_range=(-0.05, 0.05),
297 | saturation_range=(0.95, 1.05)
298 | ):
299 | """
300 | Generate visual effect parameters uniformly sampled from the given intervals.
301 |
302 | Args
303 | contrast_range: A factor interval for adjusting contrast. Should be between 0 and 3.
304 | brightness_range: An interval between -1 and 1 for the offset added to the pixel values.
305 | hue_range: An interval between -1 and 1 for the offset added to the hue channel.
306 | The values are rotated if they exceed 180.
307 | saturation_range: An interval for the factor multiplying the saturation values of each
308 | pixel.
309 | """
310 | _check_range(contrast_range, 0)
311 | _check_range(brightness_range, -1, 1)
312 | _check_range(hue_range, -1, 1)
313 | _check_range(saturation_range, 0)
314 |
315 | def _generate():
316 | while True:
317 | yield VisualEffect(
318 | contrast_factor=_uniform(contrast_range),
319 | brightness_delta=_uniform(brightness_range),
320 | hue_delta=_uniform(hue_range),
321 | saturation_factor=_uniform(saturation_range),
322 | )
323 |
324 | return _generate()
325 |
326 |
327 | def adjust_contrast(image, factor):
328 | """
329 | Adjust contrast of an image.
330 |
331 | Args
332 | image: Image to adjust.
333 | factor: A factor for adjusting contrast.
334 | """
335 | mean = image.mean(axis=0).mean(axis=0)
336 | return _clip((image - mean) * factor + mean)
337 |
338 |
339 | def adjust_brightness(image, delta):
340 | """
341 | Adjust brightness of an image
342 |
343 | Args
344 | image: Image to adjust.
345 | delta: Brightness offset between -1 and 1 added to the pixel values.
346 | """
347 | return _clip(image + delta * 255)
348 |
349 |
350 | def adjust_hue(image, delta):
351 | """
352 | Adjust hue of an image.
353 |
354 | Args
355 | image: Image to adjust.
356 | delta: An offset between -1 and 1 added to the hue channel.
357 | The values are rotated if they exceed 180.
358 | """
359 | image[..., 0] = np.mod(image[..., 0] + delta * 180, 180)
360 | return image
361 |
362 |
363 | def adjust_saturation(image, factor):
364 | """
365 | Adjust saturation of an image.
366 |
367 | Args
368 | image: Image to adjust.
369 | factor: A factor multiplying the saturation values of each pixel.
370 | """
371 | image[..., 1] = np.clip(image[..., 1] * factor, 0, 255)
372 | return image
373 |
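Taken together, the helpers in this file cover the full preprocessing chain a data generator would apply. The sketch below (not part of the repository; it assumes it is run from the repository root so that `test/004456.jpg` and the `utils` package resolve) applies a photometric jitter, a random geometric transform, the default 800/1333 resize, and the 'caffe' mean subtraction:

    from utils.image import (TransformParameters, adjust_transform_for_image, apply_transform,
                             preprocess_image, random_visual_effect_generator, read_image_bgr,
                             resize_image)
    from utils.transform import random_transform_generator

    image = read_image_bgr('test/004456.jpg')  # uint8 BGR image

    # photometric augmentation on the raw image
    image = next(random_visual_effect_generator())(image)

    # geometric augmentation: sample a transform, re-center it on the image, then warp
    params = TransformParameters()
    transform = adjust_transform_for_image(next(random_transform_generator(flip_x_chance=0.5)),
                                           image, params.relative_translation)
    image = apply_transform(transform, image, params)

    # resize so the short side is 800 px (long side capped at 1333) and subtract the ImageNet BGR mean
    image, scale = resize_image(image)
    image = preprocess_image(image)
    print(image.shape, scale)  # boxes predicted later would be divided by `scale`
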
--------------------------------------------------------------------------------